├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.md │ └── config.yml └── workflows │ ├── ci.yml │ ├── eval-model.yml │ └── publish-to-pypi.yml ├── .gitignore ├── CITATION.cff ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── NOTICE ├── README.md ├── ci └── evaluate │ └── backtest_config.yaml ├── figures ├── chronos-logo.png ├── main-figure.png └── zero_shot-agg_scaled_score.svg ├── notebooks └── deploy-chronos-bolt-to-amazon-sagemaker.ipynb ├── pyproject.toml ├── scripts ├── README.md ├── evaluation │ ├── agg-relative-score.py │ ├── configs │ │ ├── in-domain.yaml │ │ └── zero-shot.yaml │ ├── evaluate.py │ └── results │ │ ├── chronos-bolt-base-agg-rel-scores.csv │ │ ├── chronos-bolt-base-in-domain.csv │ │ ├── chronos-bolt-base-zero-shot.csv │ │ ├── chronos-bolt-mini-agg-rel-scores.csv │ │ ├── chronos-bolt-mini-in-domain.csv │ │ ├── chronos-bolt-mini-zero-shot.csv │ │ ├── chronos-bolt-small-agg-rel-scores.csv │ │ ├── chronos-bolt-small-in-domain.csv │ │ ├── chronos-bolt-small-zero-shot.csv │ │ ├── chronos-bolt-tiny-agg-rel-scores.csv │ │ ├── chronos-bolt-tiny-in-domain.csv │ │ ├── chronos-bolt-tiny-zero-shot.csv │ │ ├── chronos-t5-base-agg-rel-scores.csv │ │ ├── chronos-t5-base-in-domain.csv │ │ ├── chronos-t5-base-zero-shot.csv │ │ ├── chronos-t5-large-agg-rel-scores.csv │ │ ├── chronos-t5-large-in-domain.csv │ │ ├── chronos-t5-large-zero-shot.csv │ │ ├── chronos-t5-mini-agg-rel-scores.csv │ │ ├── chronos-t5-mini-in-domain.csv │ │ ├── chronos-t5-mini-zero-shot.csv │ │ ├── chronos-t5-small-agg-rel-scores.csv │ │ ├── chronos-t5-small-in-domain.csv │ │ ├── chronos-t5-small-zero-shot.csv │ │ ├── chronos-t5-tiny-agg-rel-scores.csv │ │ ├── chronos-t5-tiny-in-domain.csv │ │ ├── chronos-t5-tiny-zero-shot.csv │ │ ├── seasonal-naive-in-domain.csv │ │ └── seasonal-naive-zero-shot.csv ├── kernel-synth.py └── training │ ├── configs │ ├── chronos-gpt2.yaml │ ├── chronos-t5-base.yaml │ ├── chronos-t5-large.yaml │ ├── chronos-t5-mini.yaml │ ├── chronos-t5-small.yaml │ └── chronos-t5-tiny.yaml │ └── train.py ├── src └── chronos │ ├── __init__.py │ ├── base.py │ ├── chronos.py │ ├── chronos_bolt.py │ └── utils.py └── test ├── __init__.py ├── dummy-chronos-bolt-model ├── config.json └── model.safetensors ├── dummy-chronos-model ├── config.json ├── generation_config.json └── pytorch_model.bin ├── test_chronos.py ├── test_chronos_bolt.py ├── test_utils.py └── util.py /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us reproduce and fix a bug 4 | title: "[BUG]" 5 | labels: ['bug'] 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Bug report checklist** 11 | 12 | - [ ] I provided code that demonstrates a minimal reproducible example. 13 | - [ ] I confirmed bug exists on the latest mainline of Chronos via source install. 
14 | 15 | **Describe the bug** 16 | 17 | 18 | **Expected behavior** 19 | 20 | 21 | **To reproduce** 22 | 25 | 26 | **Environment description** 27 | Operating system: 28 | Python version: 29 | CUDA version: 30 | PyTorch version: 31 | HuggingFace transformers version: 32 | HuggingFace accelerate version: 33 | 34 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: false 2 | contact_links: 3 | - name: Frequently asked questions 4 | url: https://github.com/amazon-science/chronos-forecasting/issues?q=is%3Aissue+label%3AFAQ 5 | about: Check the frequently asked questions before opening a new one 6 | - name: Discussions 7 | url: https://github.com/amazon-science/chronos-forecasting/discussions/new 8 | about: Use this to ask questions and start discussions 9 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | push: 5 | branches: ["main"] # Run only on main branch 6 | pull_request: 7 | branches: ["**"] # Run on any branch 8 | schedule: 9 | - cron: "0 8 * * *" # Run at 8 AM UTC 10 | 11 | jobs: 12 | type-check: 13 | strategy: 14 | max-parallel: 4 15 | fail-fast: false 16 | matrix: 17 | python-version: ["3.11"] 18 | platform: [ubuntu-latest] 19 | 20 | runs-on: ${{ matrix.platform }} 21 | 22 | steps: 23 | - uses: actions/checkout@v4 24 | - name: Set up Python ${{ matrix.python-version }} 25 | uses: actions/setup-python@v5 26 | with: 27 | python-version: ${{ matrix.python-version }} 28 | - name: Install dependencies 29 | run: pip install ".[typecheck]" -f https://download.pytorch.org/whl/cpu/torch_stable.html 30 | - name: Type checks with mypy 31 | run: mypy src test 32 | 33 | test: 34 | strategy: 35 | max-parallel: 4 36 | fail-fast: false 37 | matrix: 38 | python-version: ["3.9", "3.10", "3.11", "3.12"] 39 | platform: [ubuntu-latest] 40 | 41 | runs-on: ${{ matrix.platform }} 42 | 43 | steps: 44 | - uses: actions/checkout@v4 45 | - name: Set up Python ${{ matrix.python-version }} 46 | uses: actions/setup-python@v5 47 | with: 48 | python-version: ${{ matrix.python-version }} 49 | - name: Install dependencies 50 | run: pip install ".[test]" -f https://download.pytorch.org/whl/cpu/torch_stable.html 51 | - name: Test with pytest 52 | run: pytest 53 | -------------------------------------------------------------------------------- /.github/workflows/eval-model.yml: -------------------------------------------------------------------------------- 1 | # Evaluates Chronos-Bolt (Small) model on selected datasets 2 | name: Evaluate 3 | 4 | on: 5 | # Runs only with read privilages for the GITHUB_TOKEN 6 | pull_request: 7 | branches: ["main"] # Run on PRs to main branch 8 | types: 9 | - opened # When a PR is created 10 | - reopened # When a closed PR is reopened 11 | - synchronize # When new commits are pushed to the PR 12 | - labeled # When a label is added to the PR 13 | 14 | jobs: 15 | evaluate-and-print: 16 | if: contains(github.event.pull_request.labels.*.name, 'run-eval') # Only run if 'run-eval' label is added 17 | runs-on: ubuntu-latest 18 | env: 19 | RESULTS_CSV: "eval-ci-metrics-${{ github.event.pull_request.number }}.csv" 20 | 21 | steps: 22 | - name: Checkout Repository 23 | uses: actions/checkout@v4 24 | 25 | - name: Set up Python 26 | uses: actions/setup-python@v5 27 
| with: 28 | python-version: '3.11' 29 | 30 | - name: Install Dependencies 31 | run: pip install ".[evaluation]" -f https://download.pytorch.org/whl/cpu/torch_stable.html 32 | 33 | - name: Run Eval Script 34 | run: python scripts/evaluation/evaluate.py ci/evaluate/backtest_config.yaml $RESULTS_CSV --chronos-model-id=amazon/chronos-bolt-small --device=cpu --torch-dtype=float32 35 | 36 | - name: Print CSV 37 | run: cat $RESULTS_CSV 38 | -------------------------------------------------------------------------------- /.github/workflows/publish-to-pypi.yml: -------------------------------------------------------------------------------- 1 | name: Publish Python Package to PyPi 2 | 3 | on: 4 | release: 5 | types: [published] 6 | 7 | jobs: 8 | deploy-to-pypi: 9 | runs-on: ubuntu-latest 10 | environment: release 11 | permissions: 12 | id-token: write 13 | steps: 14 | - uses: actions/checkout@v4 15 | - name: Set up Python 16 | uses: actions/setup-python@v5 17 | with: 18 | python-version: '3.11' 19 | - name: Install dependencies 20 | run: | 21 | python -m pip install -U pip 22 | python -m pip install setuptools wheel build 23 | - name: Build package 24 | run: | 25 | python -m build 26 | - name: Publish to PyPi 27 | uses: pypa/gh-action-pypi-publish@release/v1 -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | 162 | # macOS stuff 163 | .DS_store -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | title: "Chronos: Learning the Language of Time Series" 3 | message: "If you find Chronos models useful for your research, please consider citing the associated paper." 4 | authors: 5 | - family-names: Ansari 6 | given-names: Abdul Fatir 7 | - family-names: Stella 8 | given-names: Lorenzo 9 | - family-names: Turkmen 10 | given-names: Caner 11 | - family-names: Zhang 12 | given-names: Xiyuan 13 | - family-names: Mercado 14 | given-names: Pedro 15 | - family-names: Shen 16 | given-names: Huibin 17 | - family-names: Shchur 18 | given-names: Oleksandr 19 | - family-names: Rangapuram 20 | given-names: Syama Syndar 21 | - family-names: Arango 22 | given-names: Sebastian Pineda 23 | - family-names: Kapoor 24 | given-names: Shubham 25 | - family-names: Zschiegner 26 | given-names: Jasper 27 | - family-names: Maddix 28 | given-names: Danielle C. 29 | - family-names: Mahoney 30 | given-names: Michael W. 
31 | - family-names: Torkkola 32 | given-names: Kari 33 | - family-names: Wilson 34 | given-names: Andrew Gordon 35 | - family-names: Bohlke-Schneider 36 | given-names: Michael 37 | - family-names: Wang 38 | given-names: Yuyang 39 | preferred-citation: 40 | type: article 41 | authors: 42 | - family-names: Ansari 43 | given-names: Abdul Fatir 44 | - family-names: Stella 45 | given-names: Lorenzo 46 | - family-names: Turkmen 47 | given-names: Caner 48 | - family-names: Zhang 49 | given-names: Xiyuan 50 | - family-names: Mercado 51 | given-names: Pedro 52 | - family-names: Shen 53 | given-names: Huibin 54 | - family-names: Shchur 55 | given-names: Oleksandr 56 | - family-names: Rangapuram 57 | given-names: Syama Syndar 58 | - family-names: Arango 59 | given-names: Sebastian Pineda 60 | - family-names: Kapoor 61 | given-names: Shubham 62 | - family-names: Zschiegner 63 | given-names: Jasper 64 | - family-names: Maddix 65 | given-names: Danielle C. 66 | - family-names: Mahoney 67 | given-names: Michael W. 68 | - family-names: Torkkola 69 | given-names: Kari 70 | - family-names: Wilson 71 | given-names: Andrew Gordon 72 | - family-names: Bohlke-Schneider 73 | given-names: Michael 74 | - family-names: Wang 75 | given-names: Yuyang 76 | title: "Chronos: Learning the Language of Time Series" 77 | journal: "arXiv preprint arXiv:2403.07815" 78 | year: 2024 79 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | ## Code of Conduct 2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 4 | opensource-codeofconduct@amazon.com with any additional questions or comments. 5 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Guidelines 2 | 3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional 4 | documentation, we greatly value feedback and contributions from our community. 5 | 6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary 7 | information to effectively respond to your bug report or contribution. 8 | 9 | 10 | ## Reporting Bugs/Feature Requests 11 | 12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features. 13 | 14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already 15 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful: 16 | 17 | * A reproducible test case or series of steps 18 | * The version of our code being used 19 | * Any modifications you've made relevant to the bug 20 | * Anything unusual about your environment or deployment 21 | 22 | 23 | ## Contributing via Pull Requests 24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that: 25 | 26 | 1. You are working against the latest source on the *main* branch. 27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already. 28 | 3. 
You open an issue to discuss any significant work - we would hate for your time to be wasted. 29 | 30 | To send us a pull request, please: 31 | 32 | 1. Fork the repository. 33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change. 34 | 3. Ensure local tests pass. 35 | 4. Commit to your fork using clear commit messages. 36 | 5. Send us a pull request, answering any default questions in the pull request interface. 37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation. 38 | 39 | GitHub provides additional document on [forking a repository](https://help.github.com/articles/fork-a-repo/) and 40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/). 41 | 42 | 43 | ## Finding contributions to work on 44 | Looking at the existing issues is a great way to find something to contribute on. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start. 45 | 46 | 47 | ## Code of Conduct 48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 50 | opensource-codeofconduct@amazon.com with any additional questions or comments. 51 | 52 | 53 | ## Security issue notifications 54 | If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public github issue. 55 | 56 | 57 | ## Licensing 58 | 59 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution. 60 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 
30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. 
If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
 2 | ![Chronos logo](figures/chronos-logo.png) 3 | 
4 | 5 |
 6 |  7 | # Chronos: Learning the Language of Time Series 8 | 9 | [![preprint](https://img.shields.io/static/v1?label=arXiv&message=2403.07815&color=B31B1B&logo=arXiv)](https://arxiv.org/abs/2403.07815) 10 | [![huggingface](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Datasets-FFD21E)](https://huggingface.co/datasets/autogluon/chronos_datasets) 11 | [![huggingface](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Models-FFD21E)](https://huggingface.co/collections/amazon/chronos-models-65f1791d630a8d57cb718444) 12 | [![fev](https://img.shields.io/static/v1?label=fev&message=Benchmark&color=B31B1B&logo=github)](https://github.com/autogluon/fev) 13 | [![aws](https://img.shields.io/static/v1?label=SageMaker&message=Deploy&color=FF9900&logo=amazon-web-services)](notebooks/deploy-chronos-bolt-to-amazon-sagemaker.ipynb) 14 | [![faq](https://img.shields.io/badge/FAQ-Questions%3F-blue)](https://github.com/amazon-science/chronos-forecasting/issues?q=is%3Aissue+label%3AFAQ) 15 | [![License: Apache-2.0](https://img.shields.io/badge/License-Apache--2.0-green.svg)](https://opensource.org/licenses/Apache-2.0) 16 | 17 | 
18 | 19 | 20 | ## 🚀 News 21 | - **14 Feb 2025**: 🚀 Chronos-Bolt is now available on Amazon SageMaker JumpStart! Check out the [tutorial notebook](notebooks/deploy-chronos-bolt-to-amazon-sagemaker.ipynb) to learn how to deploy Chronos endpoints for production use in 3 lines of code. 22 | - **12 Dec 2024**: 📊 We released [`fev`](https://github.com/autogluon/fev), a lightweight package for benchmarking time series forecasting models based on the [Hugging Face `datasets`](https://huggingface.co/docs/datasets/en/index) library. 23 | - **26 Nov 2024**: ⚡️ Chronos-Bolt models released [on HuggingFace](https://huggingface.co/collections/amazon/chronos-models-65f1791d630a8d57cb718444). Chronos-Bolt models are more accurate (5% lower error), up to 250x faster and 20x more memory efficient than the original Chronos models of the same size! 24 | - **27 Jun 2024**: 🚀 [Released datasets](https://huggingface.co/datasets/autogluon/chronos_datasets) used in the paper and an [evaluation script](./scripts/README.md#evaluating-chronos-models) to compute the WQL and MASE scores reported in the paper. 25 | - **17 May 2024**: 🐛 Fixed an off-by-one error in bin indices in the `output_transform`. This simple fix significantly improves the overall performance of Chronos. We will update the results in the next revision on ArXiv. 26 | - **10 May 2024**: 🚀 We added the code for pretraining and fine-tuning Chronos models. You can find it in [this folder](./scripts/training). We also added [a script](./scripts/kernel-synth.py) for generating synthetic time series data from Gaussian processes (KernelSynth; see Section 4.2 in the paper for details). Check out the [usage examples](./scripts/). 27 | - **19 Apr 2024**: 🚀 Chronos is now supported on [AutoGluon-TimeSeries](https://auto.gluon.ai/stable/tutorials/timeseries/index.html), the powerful AutoML package for time series forecasting which enables model ensembles, cloud deployments, and much more. Get started with the [tutorial](https://auto.gluon.ai/stable/tutorials/timeseries/forecasting-chronos.html). 28 | - **08 Apr 2024**: 🧪 Experimental [MLX inference support](https://github.com/amazon-science/chronos-forecasting/tree/mlx) added. If you have an Apple Silicon Mac, you can now obtain significantly faster forecasts from Chronos compared to CPU inference. This provides an alternative way to exploit the GPU on your Apple Silicon Macs together with the "mps" support in PyTorch. 29 | - **25 Mar 2024**: 🚀 [v1.1.0 released](https://github.com/amazon-science/chronos-forecasting/releases/tag/v1.1.0) with inference optimizations and `pipeline.embed` to extract encoder embeddings from Chronos. 30 | - **13 Mar 2024**: 🚀 Chronos [paper](https://arxiv.org/abs/2403.07815) and inference code released. 31 | 32 | ## ✨ Introduction 33 | 34 | Chronos is a family of **pretrained time series forecasting models** based on language model architectures. A time series is transformed into a sequence of tokens via scaling and quantization, and a language model is trained on these tokens using the cross-entropy loss. Once trained, probabilistic forecasts are obtained by sampling multiple future trajectories given the historical context. Chronos models have been trained on a large corpus of publicly available time series data, as well as synthetic data generated using Gaussian processes. 35 | 36 | For details on Chronos models, training data and procedures, and experimental results, please refer to the paper [Chronos: Learning the Language of Time Series](https://arxiv.org/abs/2403.07815). 
37 | 38 |

 39 | ![Fig. 1](figures/main-figure.png) 40 | 
41 | 42 | Fig. 1: High-level depiction of Chronos. (Left) The input time series is scaled and quantized to obtain a sequence of tokens. (Center) The tokens are fed into a language model which may either be an encoder-decoder or a decoder-only model. The model is trained using the cross-entropy loss. (Right) During inference, we autoregressively sample tokens from the model and map them back to numerical values. Multiple trajectories are sampled to obtain a predictive distribution. 43 | 44 |
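To make the scaling-and-quantization step sketched above (and in Fig. 1) concrete, here is a minimal, illustrative example of mean scaling followed by uniform binning into a fixed token vocabulary. The bin range, vocabulary size, and function names here are assumptions chosen for illustration only; this is not the exact Chronos tokenizer.

```python
import numpy as np

def tokenize(series: np.ndarray, n_tokens: int = 4096, low: float = -15.0, high: float = 15.0):
    """Mean-scale the series and map each value to a bin index in [0, n_tokens)."""
    scale = float(np.abs(series).mean()) or 1.0   # guard against an all-zero context
    scaled = series / scale
    edges = np.linspace(low, high, n_tokens - 1)  # n_tokens - 1 edges define n_tokens bins
    return np.digitize(scaled, edges), scale      # one token id per time step

def detokenize(token_ids: np.ndarray, scale: float, n_tokens: int = 4096, low: float = -15.0, high: float = 15.0):
    """Map token ids back to approximate numerical values (inverse of tokenize)."""
    centers = np.linspace(low, high, n_tokens)
    return centers[np.clip(token_ids, 0, n_tokens - 1)] * scale

series = 100 + 50 * np.sin(np.linspace(0, 10, 120))
token_ids, scale = tokenize(series)
reconstructed = detokenize(token_ids, scale)  # matches `series` up to quantization error
```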

45 | 46 | ### Architecture 47 | 48 | The models in this repository are based on the [T5 architecture](https://arxiv.org/abs/1910.10683). The only difference is in the vocabulary size: Chronos-T5 models use 4096 different tokens, compared to 32128 of the original T5 models, resulting in fewer parameters. 49 | 50 |
51 | 52 | | Model | Parameters | Based on | 53 | | ---------------------------------------------------------------------- | ---------- | ---------------------------------------------------------------------- | 54 | | [**chronos-t5-tiny**](https://huggingface.co/amazon/chronos-t5-tiny) | 8M | [t5-efficient-tiny](https://huggingface.co/google/t5-efficient-tiny) | 55 | | [**chronos-t5-mini**](https://huggingface.co/amazon/chronos-t5-mini) | 20M | [t5-efficient-mini](https://huggingface.co/google/t5-efficient-mini) | 56 | | [**chronos-t5-small**](https://huggingface.co/amazon/chronos-t5-small) | 46M | [t5-efficient-small](https://huggingface.co/google/t5-efficient-small) | 57 | | [**chronos-t5-base**](https://huggingface.co/amazon/chronos-t5-base) | 200M | [t5-efficient-base](https://huggingface.co/google/t5-efficient-base) | 58 | | [**chronos-t5-large**](https://huggingface.co/amazon/chronos-t5-large) | 710M | [t5-efficient-large](https://huggingface.co/google/t5-efficient-large) | 59 | | [**chronos-bolt-tiny**](https://huggingface.co/amazon/chronos-bolt-tiny) | 9M | [t5-efficient-tiny](https://huggingface.co/google/t5-efficient-tiny) | 60 | | [**chronos-bolt-mini**](https://huggingface.co/amazon/chronos-bolt-mini) | 21M | [t5-efficient-mini](https://huggingface.co/google/t5-efficient-mini) | 61 | | [**chronos-bolt-small**](https://huggingface.co/amazon/chronos-bolt-small) | 48M | [t5-efficient-small](https://huggingface.co/google/t5-efficient-small) | 62 | | [**chronos-bolt-base**](https://huggingface.co/amazon/chronos-bolt-base) | 205M | [t5-efficient-base](https://huggingface.co/google/t5-efficient-base) | 63 | 64 |
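The smaller vocabulary mentioned above can be checked directly from the published model configurations. The snippet below is a quick, optional check that assumes access to the Hugging Face Hub; the expected values simply restate the numbers in the paragraph above.

```python
from transformers import AutoConfig

chronos_config = AutoConfig.from_pretrained("amazon/chronos-t5-small")
t5_config = AutoConfig.from_pretrained("google/t5-efficient-small")

print(chronos_config.vocab_size)  # expected: 4096
print(t5_config.vocab_size)       # expected: 32128
```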
65 | 66 | ### Zero-Shot Results 67 | 68 | The following figure showcases the remarkable **zero-shot** performance of Chronos and Chronos-Bolt models on 27 datasets against local models, task-specific models and other pretrained models. For details on the evaluation setup and other results, please refer to [the paper](https://arxiv.org/abs/2403.07815). 69 | 70 |

 71 | ![Fig. 2](figures/zero_shot-agg_scaled_score.svg) 72 | 
73 | 74 | Fig. 2: Performance of different models on Benchmark II, comprising 27 datasets not seen by Chronos and Chronos-Bolt models during training. This benchmark provides insights into the zero-shot performance of Chronos and Chronos-Bolt models against local statistical models, which fit parameters individually for each time series, task-specific models trained on each task, and pretrained models trained on a large corpus of time series. Pretrained Models (Other) indicates that some (or all) of the datasets in Benchmark II may have been in the training corpus of these models. The probabilistic (WQL) and point (MASE) forecasting metrics were normalized using the scores of the Seasonal Naive baseline and aggregated through a geometric mean to obtain the Agg. Relative WQL and MASE, respectively. 75 | 76 |

77 | 78 | ## 📈 Usage 79 | 80 | To perform inference with Chronos or Chronos-Bolt models, the easiest way is to install this package through `pip`: 81 | 82 | ```sh 83 | pip install chronos-forecasting 84 | ``` 85 | 86 | If you're interested in pretraining, fine-tuning, and other research & development, clone and install the package from source: 87 | 88 | ```sh 89 | # Clone the repository 90 | git clone https://github.com/amazon-science/chronos-forecasting.git 91 | 92 | # Install in editable mode with extra training-related dependencies 93 | cd chronos-forecasting && pip install --editable ".[training]" 94 | ``` 95 | 96 | > [!TIP] 97 | > This repository is intended for research purposes and provides a minimal interface to Chronos models. For reliable production use, we recommend the following options: 98 | > - [AutoGluon](https://auto.gluon.ai) provides effortless fine-tuning, augmenting Chronos models with exogenous information through covariate regressors, ensembling with other statistical and machine learning models. Check out the AutoGluon Chronos [tutorial](https://auto.gluon.ai/stable/tutorials/timeseries/forecasting-chronos.html). 99 | > - SageMaker JumpStart makes it easy to deploy Chronos inference endpoints to AWS with just a few lines of code. Check out [this tutorial](notebooks/deploy-chronos-bolt-to-amazon-sagemaker.ipynb) for more details. 100 | 101 | ### Forecasting 102 | 103 | A minimal example showing how to perform forecasting using Chronos and Chronos-Bolt models: 104 | 105 | ```python 106 | import pandas as pd # requires: pip install pandas 107 | import torch 108 | from chronos import BaseChronosPipeline 109 | 110 | pipeline = BaseChronosPipeline.from_pretrained( 111 | "amazon/chronos-t5-small", # use "amazon/chronos-bolt-small" for the corresponding Chronos-Bolt model 112 | device_map="cuda", # use "cpu" for CPU inference 113 | torch_dtype=torch.bfloat16, 114 | ) 115 | 116 | df = pd.read_csv( 117 | "https://raw.githubusercontent.com/AileenNielsen/TimeSeriesAnalysisWithPython/master/data/AirPassengers.csv" 118 | ) 119 | 120 | # context must be either a 1D tensor, a list of 1D tensors, 121 | # or a left-padded 2D tensor with batch as the first dimension 122 | # quantiles is an fp32 tensor with shape [batch_size, prediction_length, num_quantile_levels] 123 | # mean is an fp32 tensor with shape [batch_size, prediction_length] 124 | quantiles, mean = pipeline.predict_quantiles( 125 | context=torch.tensor(df["#Passengers"]), 126 | prediction_length=12, 127 | quantile_levels=[0.1, 0.5, 0.9], 128 | ) 129 | ``` 130 | 131 | For the original Chronos models, `pipeline.predict` can be used to draw forecast samples. 
More options for `predict_kwargs` in `pipeline.predict_quantiles` can be found with: 132 | 133 | ```python 134 | from chronos import ChronosPipeline, ChronosBoltPipeline 135 | 136 | print(ChronosPipeline.predict.__doc__) # for Chronos models 137 | print(ChronosBoltPipeline.predict.__doc__) # for Chronos-Bolt models 138 | ``` 139 | 140 | We can now visualize the forecast: 141 | 142 | ```python 143 | import matplotlib.pyplot as plt # requires: pip install matplotlib 144 | 145 | forecast_index = range(len(df), len(df) + 12) 146 | low, median, high = quantiles[0, :, 0], quantiles[0, :, 1], quantiles[0, :, 2] 147 | 148 | plt.figure(figsize=(8, 4)) 149 | plt.plot(df["#Passengers"], color="royalblue", label="historical data") 150 | plt.plot(forecast_index, median, color="tomato", label="median forecast") 151 | plt.fill_between(forecast_index, low, high, color="tomato", alpha=0.3, label="80% prediction interval") 152 | plt.legend() 153 | plt.grid() 154 | plt.show() 155 | ``` 156 | 157 | ### Extracting Encoder Embeddings 158 | 159 | A minimal example showing how to extract encoder embeddings from Chronos models: 160 | 161 | ```python 162 | import pandas as pd 163 | import torch 164 | from chronos import ChronosPipeline 165 | 166 | pipeline = ChronosPipeline.from_pretrained( 167 | "amazon/chronos-t5-small", 168 | device_map="cuda", 169 | torch_dtype=torch.bfloat16, 170 | ) 171 | 172 | df = pd.read_csv("https://raw.githubusercontent.com/AileenNielsen/TimeSeriesAnalysisWithPython/master/data/AirPassengers.csv") 173 | 174 | # context must be either a 1D tensor, a list of 1D tensors, 175 | # or a left-padded 2D tensor with batch as the first dimension 176 | context = torch.tensor(df["#Passengers"]) 177 | embeddings, tokenizer_state = pipeline.embed(context) 178 | ``` 179 | 180 | ### Pretraining, fine-tuning and evaluation 181 | 182 | Scripts for pretraining, fine-tuning and evaluating Chronos models can be found in [this folder](./scripts/). 183 | 184 | ## :floppy_disk: Datasets 185 | 186 | Datasets used in the Chronos paper for pretraining and evaluation (both in-domain and zero-shot) are available through the HuggingFace repos: [`autogluon/chronos_datasets`](https://huggingface.co/datasets/autogluon/chronos_datasets) and [`autogluon/chronos_datasets_extra`](https://huggingface.co/datasets/autogluon/chronos_datasets_extra). Check out these repos for instructions on how to download and use the datasets. 
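As a minimal sketch, the datasets can be loaded with the Hugging Face `datasets` library as shown below. The config name `m4_hourly` is one of the datasets listed in the evaluation configs of this repository; the `target` column name is an assumption based on the dataset cards, so refer to the linked repos for the authoritative usage.

```python
from datasets import load_dataset  # requires: pip install datasets

ds = load_dataset("autogluon/chronos_datasets", "m4_hourly", split="train")
ds.set_format("numpy")  # return sequences as numpy arrays

print(ds[0]["target"][:10])  # first few observations of the first series
```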
187 | 188 | ## 🔥 Coverage 189 | 190 | - [Adapting language model architectures for time series forecasting](https://www.amazon.science/blog/adapting-language-model-architectures-for-time-series-forecasting) (Amazon Science blog post) 191 | - [Amazon AI Researchers Introduce Chronos: A New Machine Learning Framework for Pretrained Probabilistic Time Series Models](https://www.marktechpost.com/2024/03/15/amazon-ai-researchers-introduce-chronos-a-new-machine-learning-framework-for-pretrained-probabilistic-time-series-models/) (Marktechpost blog post) 192 | - [Chronos: The Rise of Foundation Models for Time Series Forecasting](https://towardsdatascience.com/chronos-the-rise-of-foundation-models-for-time-series-forecasting-aaeba62d9da3) (Towards Data Science blog post by Luís Roque and Rafael Guedes) 193 | - [Moirai: Time Series Foundation Models for Universal Forecasting](https://towardsdatascience.com/moirai-time-series-foundation-models-for-universal-forecasting-dc93f74b330f) (Towards Data Science blog post by Luís Roque and Rafael Guedes, includes comparison of Chronos with Moirai) 194 | - [Chronos: The Latest Time Series Forecasting Foundation Model by Amazon](https://towardsdatascience.com/chronos-the-latest-time-series-forecasting-foundation-model-by-amazon-2687d641705a) (Towards Data Science blog post by Marco Peixeiro) 195 | - The original article had a critical bug affecting the metric computation for Chronos. We opened a [pull request](https://github.com/marcopeix/time-series-analysis/pull/10) to fix it. 196 | - [How to Effectively Forecast Time Series with Amazon's New Time Series Forecasting Model](https://towardsdatascience.com/how-to-effectively-forecast-time-series-with-amazons-new-time-series-forecasting-model-9e04d4ccf67e) (Towards Data Science blog post by Eivind Kjosbakken) 197 | - [Chronos: Learning the Language of Time Series](https://minimizeregret.com/linked/2024/03/27/chronos-forecasting/) (Minimize Regret blog post by Tim Radtke) 198 | - [Chronos: Another Zero-Shot Time Series Forecaster LLM](https://levelup.gitconnected.com/chronos-another-zero-shot-time-series-forecaster-llm-0e80753a7ad0) (Level Up Coding blog post by Level Up Coding AI TutorMaster) 199 | - [Paper Review: Chronos: Learning the Language of Time Series](https://andlukyane.com/blog/paper-review-chronos) (Review by Andrey Lukyanenko) 200 | - [Foundation Models for Forecasting: the Future or Folly?](https://insights.radix.ai/blog/foundation-models-for-forecasting-the-future-or-folly) (Blog post by Radix) 201 | - [Learning the Language of Time Series with Chronos](https://medium.com/@ManueleCaddeo/learning-the-language-of-time-series-with-chronos-fea7d0fedde4) (Medium post by Manuele Caddeo) 202 | - [The latest advancement in Time Series Forecasting from AWS: Chronos](https://medium.com/chat-gpt-now-writes-all-my-articles/the-latest-advancement-in-time-series-forecasting-from-aws-chronos-python-code-included-0205d01248f3) (Medium post by Abish Pius) 203 | - [Decoding the Future: How Chronos Redefines Time Series Forecasting with the Art of Language](https://medium.com/@zamalbabar/decoding-the-future-how-chronos-redefines-time-series-forecasting-with-the-art-of-language-cecc2174e400) (Medium post by Zamal) 204 | - [Comparison of Chronos against the SCUM ensemble of statistical models](https://github.com/Nixtla/nixtla/tree/main/experiments/amazon-chronos) (Benchmark by Nixtla) 205 | - We opened a [pull request](https://github.com/Nixtla/nixtla/pull/281) extending the analysis to 28 datasets (200K+ time 
series) and showing that **zero-shot** Chronos models perform comparably to this strong ensemble of 4 statistical models while being significantly faster on average. Our complete response can be [found here](https://www.linkedin.com/pulse/extended-comparison-chronos-against-statistical-ensemble-ansari-4aste/). 206 | - [Comparison of Chronos against a variety of forecasting models](https://www.linkedin.com/feed/update/urn:li:activity:7178398371815051267/) (Benchmark by ReadyTensor) 207 | 208 | ## 📝 Citation 209 | 210 | If you find Chronos models useful for your research, please consider citing the associated [paper](https://arxiv.org/abs/2403.07815): 211 | 212 | ``` 213 | @article{ansari2024chronos, 214 | title={Chronos: Learning the Language of Time Series}, 215 | author={Ansari, Abdul Fatir and Stella, Lorenzo and Turkmen, Caner and Zhang, Xiyuan, and Mercado, Pedro and Shen, Huibin and Shchur, Oleksandr and Rangapuram, Syama Syndar and Pineda Arango, Sebastian and Kapoor, Shubham and Zschiegner, Jasper and Maddix, Danielle C. and Mahoney, Michael W. and Torkkola, Kari and Gordon Wilson, Andrew and Bohlke-Schneider, Michael and Wang, Yuyang}, 216 | journal={Transactions on Machine Learning Research}, 217 | issn={2835-8856}, 218 | year={2024}, 219 | url={https://openreview.net/forum?id=gerNCVqqtR} 220 | } 221 | ``` 222 | 223 | ## 🛡️ Security 224 | 225 | See [CONTRIBUTING](CONTRIBUTING.md#security-issue-notifications) for more information. 226 | 227 | ## 📃 License 228 | 229 | This project is licensed under the Apache-2.0 License. 230 | -------------------------------------------------------------------------------- /ci/evaluate/backtest_config.yaml: -------------------------------------------------------------------------------- 1 | # From In-domain 2 | - name: taxi_30min # 30 min 3 | hf_repo: autogluon/chronos_datasets 4 | offset: -48 5 | prediction_length: 48 6 | num_rolls: 1 7 | # From Zero-shot 8 | - name: ETTh # Hourly 9 | hf_repo: autogluon/chronos_datasets_extra 10 | offset: -24 11 | prediction_length: 24 12 | num_rolls: 1 13 | - name: monash_covid_deaths # Daily 14 | hf_repo: autogluon/chronos_datasets 15 | offset: -30 16 | prediction_length: 30 17 | num_rolls: 1 18 | - name: monash_nn5_weekly # Weekly 19 | hf_repo: autogluon/chronos_datasets 20 | offset: -8 21 | prediction_length: 8 22 | num_rolls: 1 23 | - name: monash_fred_md # Monthly 24 | hf_repo: autogluon/chronos_datasets 25 | offset: -12 26 | prediction_length: 12 27 | num_rolls: 1 28 | - name: monash_m3_quarterly # Quarterly 29 | hf_repo: autogluon/chronos_datasets 30 | offset: -8 31 | prediction_length: 8 32 | num_rolls: 1 33 | - name: monash_tourism_yearly # Yearly 34 | hf_repo: autogluon/chronos_datasets 35 | offset: -4 36 | prediction_length: 4 37 | num_rolls: 1 -------------------------------------------------------------------------------- /figures/chronos-logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/chronos-forecasting/6a9c8dadac04eb85befc935043e3e2cce914267f/figures/chronos-logo.png -------------------------------------------------------------------------------- /figures/main-figure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/chronos-forecasting/6a9c8dadac04eb85befc935043e3e2cce914267f/figures/main-figure.png -------------------------------------------------------------------------------- /pyproject.toml: 
-------------------------------------------------------------------------------- 1 | [project] 2 | name = "chronos-forecasting" 3 | version = "1.5.2" 4 | authors = [ 5 | { name="Abdul Fatir Ansari", email="ansarnd@amazon.com" }, 6 | { name="Lorenzo Stella", email="stellalo@amazon.com" }, 7 | { name="Caner Turkmen", email="atturkm@amazon.com" }, 8 | ] 9 | description = "Chronos: Pretrained models for time series forecasting" 10 | readme = "README.md" 11 | license = { file = "LICENSE" } 12 | requires-python = ">=3.9" 13 | dependencies = [ 14 | "torch>=2.0,<3", # package was tested on 2.2 15 | "transformers>=4.48,<5", 16 | "accelerate>=0.32,<2", 17 | ] 18 | classifiers = [ 19 | "Programming Language :: Python :: 3", 20 | "License :: OSI Approved :: Apache Software License", 21 | "Operating System :: OS Independent", 22 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 23 | ] 24 | 25 | [build-system] 26 | requires = ["hatchling"] 27 | build-backend = "hatchling.build" 28 | 29 | [tool.hatch.build.targets.wheel] 30 | packages = ["src/chronos"] 31 | 32 | [project.optional-dependencies] 33 | test = ["pytest~=8.0", "numpy~=1.21"] 34 | typecheck = ["mypy~=1.9"] 35 | training = ["gluonts[pro]~=0.15", "numpy~=1.21", "datasets~=2.18", "typer", "typer-config", "joblib", "scikit-learn", "tensorboard"] 36 | evaluation = ["gluonts[pro]~=0.15", "numpy~=1.21", "datasets~=2.18", "typer"] 37 | 38 | [project.urls] 39 | Homepage = "https://github.com/amazon-science/chronos-forecasting" 40 | Issues = "https://github.com/amazon-science/chronos-forecasting/issues" 41 | Paper = "https://arxiv.org/abs/2403.07815" 42 | 43 | [tool.mypy] 44 | ignore_missing_imports = true 45 | -------------------------------------------------------------------------------- /scripts/README.md: -------------------------------------------------------------------------------- 1 | # Usage Examples 2 | 3 | ## Generating Synthetic Time Series (KernelSynth) 4 | 5 | - Install this package with with the `training` extra: 6 | ``` 7 | pip install "chronos-forecasting[training] @ git+https://github.com/amazon-science/chronos-forecasting.git" 8 | ``` 9 | - Run `kernel-synth.py`: 10 | ```sh 11 | # With defaults used in the paper (1M time series and 5 max_kernels) 12 | python kernel-synth.py 13 | 14 | # You may optionally specify num-series and max-kernels 15 | python kernel-synth.py \ 16 | --num-series \ 17 | --max-kernels 18 | ``` 19 | The generated time series will be saved in a [GluonTS](https://github.com/awslabs/gluonts)-comptabile arrow file `kernelsynth-data.arrow`. 20 | 21 | ## Pretraining (and fine-tuning) Chronos models 22 | - Install this package with with the `training` extra: 23 | ``` 24 | pip install "chronos-forecasting[training] @ git+https://github.com/amazon-science/chronos-forecasting.git" 25 | ``` 26 | - Convert your time series dataset into a GluonTS-compatible file dataset. We recommend using the arrow format. You may use the `convert_to_arrow` function from the following snippet for that. Optionally, you may use [synthetic data from KernelSynth](#generating-synthetic-time-series-kernelsynth) to follow along. 27 | ```py 28 | from pathlib import Path 29 | from typing import List, Union 30 | 31 | import numpy as np 32 | from gluonts.dataset.arrow import ArrowWriter 33 | 34 | 35 | def convert_to_arrow( 36 | path: Union[str, Path], 37 | time_series: Union[List[np.ndarray], np.ndarray], 38 | compression: str = "lz4", 39 | ): 40 | """ 41 | Store a given set of series into Arrow format at the specified path. 
42 | 43 | Input data can be either a list of 1D numpy arrays, or a single 2D 44 | numpy array of shape (num_series, time_length). 45 | """ 46 | assert isinstance(time_series, list) or ( 47 | isinstance(time_series, np.ndarray) and 48 | time_series.ndim == 2 49 | ) 50 | 51 | # Set an arbitrary start time 52 | start = np.datetime64("2000-01-01 00:00", "s") 53 | 54 | dataset = [ 55 | {"start": start, "target": ts} for ts in time_series 56 | ] 57 | 58 | ArrowWriter(compression=compression).write_to_file( 59 | dataset, 60 | path=path, 61 | ) 62 | 63 | 64 | if __name__ == "__main__": 65 | # Generate 20 random time series of length 1024 66 | time_series = [np.random.randn(1024) for i in range(20)] 67 | 68 | # Convert to GluonTS arrow format 69 | convert_to_arrow("./noise-data.arrow", time_series=time_series) 70 | ``` 71 | - Modify the [training configs](training/configs) to use your data. Let's use the KernelSynth data as an example. 72 | ```yaml 73 | # List of training data files 74 | training_data_paths: 75 | - "/path/to/kernelsynth-data.arrow" 76 | # Mixing probability of each dataset file 77 | probability: 78 | - 1.0 79 | ``` 80 | You may optionally change other parameters of the config file, as required. For instance, if you're interested in fine-tuning the model from a pretrained Chronos checkpoint, you should change the `model_id`, set `random_init: false`, and (optionally) change other parameters such as `max_steps` and `learning_rate`. 81 | - Start the training (or fine-tuning) job: 82 | ```sh 83 | # On single GPU 84 | CUDA_VISIBLE_DEVICES=0 python training/train.py --config /path/to/modified/config.yaml 85 | 86 | # On multiple GPUs (example with 8 GPUs) 87 | torchrun --nproc-per-node=8 training/train.py --config /path/to/modified/config.yaml 88 | 89 | # Fine-tune `amazon/chronos-t5-small` for 1000 steps with initial learning rate of 1e-3 90 | CUDA_VISIBLE_DEVICES=0 python training/train.py --config /path/to/modified/config.yaml \ 91 | --model-id amazon/chronos-t5-small \ 92 | --no-random-init \ 93 | --max-steps 1000 \ 94 | --learning-rate 0.001 95 | ``` 96 | The output and checkpoints will be saved in `output/run-{id}/`. 97 | > [!TIP] 98 | > If the initial training step is too slow, you might want to change the `shuffle_buffer_length` and/or set `torch_compile` to `false`. 99 | 100 | > [!IMPORTANT] 101 | > When pretraining causal models (such as GPT2), the training script does [`LastValueImputation`](https://github.com/awslabs/gluonts/blob/f0f2266d520cb980f4c1ce18c28b003ad5cd2599/src/gluonts/transform/feature.py#L103) for missing values by default. If you pretrain causal models, please ensure that missing values are imputed similarly before passing the context tensor to `ChronosPipeline.predict()` for accurate results. 102 | - (Optional) Once trained, you can easily push your fine-tuned model to HuggingFace🤗 Hub. Before that, do not forget to [create an access token](https://huggingface.co/settings/tokens) with **write permissions** and put it in `~/.cache/huggingface/token`. Here's a snippet that will push a fine-tuned model to HuggingFace🤗 Hub at `/chronos-t5-small-fine-tuned`. 103 | ```py 104 | from chronos import ChronosPipeline 105 | 106 | pipeline = ChronosPipeline.from_pretrained("/path/to/fine-tuned/model/ckpt/dir/") 107 | pipeline.model.model.push_to_hub("chronos-t5-small-fine-tuned") 108 | ``` 109 | 110 | ## Evaluating Chronos models 111 | 112 | Follow these steps to compute the WQL and MASE values for the in-domain and zero-shot benchmarks in our paper. 
113 | 114 | - Install this package with with the `evaluation` extra: 115 | ``` 116 | pip install "chronos-forecasting[evaluation] @ git+https://github.com/amazon-science/chronos-forecasting.git" 117 | ``` 118 | - Run the evaluation script: 119 | ```sh 120 | # In-domain evaluation 121 | # Results will be saved in: evaluation/results/chronos-t5-small-in-domain.csv 122 | python evaluation/evaluate.py evaluation/configs/in-domain.yaml evaluation/results/chronos-t5-small-in-domain.csv \ 123 | --chronos-model-id "amazon/chronos-t5-small" \ 124 | --batch-size=32 \ 125 | --device=cuda:0 \ 126 | --num-samples 20 127 | 128 | # Zero-shot evaluation 129 | # Results will be saved in: evaluation/results/chronos-t5-small-zero-shot.csv 130 | python evaluation/evaluate.py evaluation/configs/zero-shot.yaml evaluation/results/chronos-t5-small-zero-shot.csv \ 131 | --chronos-model-id "amazon/chronos-t5-small" \ 132 | --batch-size=32 \ 133 | --device=cuda:0 \ 134 | --num-samples 20 135 | ``` 136 | - Use the following snippet to compute the aggregated relative WQL and MASE scores: 137 | ```py 138 | import pandas as pd 139 | from scipy.stats import gmean # requires: pip install scipy 140 | 141 | 142 | def agg_relative_score(model_df: pd.DataFrame, baseline_df: pd.DataFrame): 143 | relative_score = model_df.drop("model", axis="columns") / baseline_df.drop( 144 | "model", axis="columns" 145 | ) 146 | return relative_score.agg(gmean) 147 | 148 | 149 | result_df = pd.read_csv("evaluation/results/chronos-t5-small-in-domain.csv").set_index("dataset") 150 | baseline_df = pd.read_csv("evaluation/results/seasonal-naive-in-domain.csv").set_index("dataset") 151 | 152 | agg_score_df = agg_relative_score(result_df, baseline_df) 153 | ``` -------------------------------------------------------------------------------- /scripts/evaluation/agg-relative-score.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import typer 3 | from scipy.stats import gmean 4 | from pathlib import Path 5 | 6 | app = typer.Typer(pretty_exceptions_enable=False) 7 | DEFAULT_RESULTS_DIR = Path(__file__).parent / "results" 8 | 9 | 10 | def agg_relative_score(model_csv: Path, baseline_csv: Path): 11 | model_df = pd.read_csv(model_csv).set_index("dataset") 12 | baseline_df = pd.read_csv(baseline_csv).set_index("dataset") 13 | relative_score = model_df.drop("model", axis="columns") / baseline_df.drop( 14 | "model", axis="columns" 15 | ) 16 | return relative_score.agg(gmean) 17 | 18 | 19 | @app.command() 20 | def main( 21 | model_name: str, 22 | baseline_name: str = "seasonal-naive", 23 | results_dir: Path = DEFAULT_RESULTS_DIR, 24 | ): 25 | """ 26 | Compute the aggregated relative score as reported in the Chronos paper. 27 | Results will be saved to {results_dir}/{model_name}-agg-rel-scores.csv 28 | 29 | Parameters 30 | ---------- 31 | model_name : str 32 | Name of the model used in the CSV files. The in-domain and zero-shot CSVs 33 | are expected to be named {model_name}-in-domain.csv and {model_name}-zero-shot.csv. 
34 | results_dir : Path, optional, default = results/ 35 | Directory where results CSVs generated by evaluate.py are stored 36 | """ 37 | 38 | in_domain_agg_score_df = agg_relative_score( 39 | results_dir / f"{model_name}-in-domain.csv", 40 | results_dir / f"{baseline_name}-in-domain.csv", 41 | ) 42 | in_domain_agg_score_df.name = "value" 43 | in_domain_agg_score_df.index.name = "metric" 44 | 45 | zero_shot_agg_score_df = agg_relative_score( 46 | results_dir / f"{model_name}-zero-shot.csv", 47 | results_dir / f"{baseline_name}-zero-shot.csv", 48 | ) 49 | zero_shot_agg_score_df.name = "value" 50 | zero_shot_agg_score_df.index.name = "metric" 51 | 52 | agg_score_df = pd.concat( 53 | {"in-domain": in_domain_agg_score_df, "zero-shot": zero_shot_agg_score_df}, 54 | names=["benchmark"], 55 | ) 56 | agg_score_df.to_csv(f"{results_dir}/{model_name}-agg-rel-scores.csv") 57 | 58 | 59 | if __name__ == "__main__": 60 | app() 61 | -------------------------------------------------------------------------------- /scripts/evaluation/configs/in-domain.yaml: -------------------------------------------------------------------------------- 1 | # Backtest configs for the 15 "in-domain" datasets. 2 | # The training portion of these datasets was part of the 3 | # training corpus for Chronos models. 4 | - name: electricity_15min 5 | hf_repo: autogluon/chronos_datasets 6 | offset: -5376 7 | prediction_length: 24 8 | num_rolls: 1 9 | - name: monash_electricity_hourly 10 | hf_repo: autogluon/chronos_datasets 11 | offset: -24 12 | prediction_length: 24 13 | num_rolls: 1 14 | - name: monash_electricity_weekly 15 | hf_repo: autogluon/chronos_datasets 16 | offset: -8 17 | prediction_length: 8 18 | num_rolls: 1 19 | - name: monash_kdd_cup_2018 20 | hf_repo: autogluon/chronos_datasets 21 | offset: -48 22 | prediction_length: 48 23 | num_rolls: 1 24 | - name: m4_daily 25 | hf_repo: autogluon/chronos_datasets 26 | offset: -14 27 | prediction_length: 14 28 | num_rolls: 1 29 | - name: m4_hourly 30 | hf_repo: autogluon/chronos_datasets 31 | offset: -48 32 | prediction_length: 48 33 | num_rolls: 1 34 | - name: m4_monthly 35 | hf_repo: autogluon/chronos_datasets 36 | offset: -18 37 | prediction_length: 18 38 | num_rolls: 1 39 | - name: m4_weekly 40 | hf_repo: autogluon/chronos_datasets 41 | offset: -13 42 | prediction_length: 13 43 | num_rolls: 1 44 | - name: monash_pedestrian_counts 45 | hf_repo: autogluon/chronos_datasets 46 | offset: -48 47 | prediction_length: 48 48 | num_rolls: 1 49 | - name: taxi_30min 50 | hf_repo: autogluon/chronos_datasets 51 | offset: -48 52 | prediction_length: 48 53 | num_rolls: 1 54 | - name: uber_tlc_hourly 55 | hf_repo: autogluon/chronos_datasets 56 | offset: -24 57 | prediction_length: 24 58 | num_rolls: 1 59 | - name: uber_tlc_daily 60 | hf_repo: autogluon/chronos_datasets 61 | offset: -7 62 | prediction_length: 7 63 | num_rolls: 1 64 | - name: monash_rideshare 65 | hf_repo: autogluon/chronos_datasets 66 | offset: -24 67 | prediction_length: 24 68 | num_rolls: 1 69 | - name: monash_temperature_rain 70 | hf_repo: autogluon/chronos_datasets 71 | offset: -30 72 | prediction_length: 30 73 | num_rolls: 1 74 | - name: monash_london_smart_meters 75 | hf_repo: autogluon/chronos_datasets 76 | offset: -48 77 | prediction_length: 48 78 | num_rolls: 1 79 | -------------------------------------------------------------------------------- /scripts/evaluation/configs/zero-shot.yaml: -------------------------------------------------------------------------------- 1 | # Backtest configs for the 27 "zero-shot" 
datasets. 2 | # These datasets were not seen by Chronos models during training. 3 | - name: monash_traffic 4 | hf_repo: autogluon/chronos_datasets 5 | offset: -24 6 | prediction_length: 24 7 | num_rolls: 1 8 | - name: monash_australian_electricity 9 | hf_repo: autogluon/chronos_datasets 10 | offset: -48 11 | prediction_length: 48 12 | num_rolls: 1 13 | - name: ercot 14 | hf_repo: autogluon/chronos_datasets 15 | offset: -24 16 | prediction_length: 24 17 | num_rolls: 1 18 | - name: ETTm 19 | hf_repo: autogluon/chronos_datasets_extra 20 | offset: -96 21 | prediction_length: 24 22 | num_rolls: 1 23 | - name: ETTh 24 | hf_repo: autogluon/chronos_datasets_extra 25 | offset: -24 26 | prediction_length: 24 27 | num_rolls: 1 28 | - name: exchange_rate 29 | hf_repo: autogluon/chronos_datasets 30 | offset: -30 31 | prediction_length: 30 32 | num_rolls: 1 33 | - name: nn5 34 | hf_repo: autogluon/chronos_datasets 35 | offset: -56 36 | prediction_length: 56 37 | num_rolls: 1 38 | - name: monash_nn5_weekly 39 | hf_repo: autogluon/chronos_datasets 40 | offset: -8 41 | prediction_length: 8 42 | num_rolls: 1 43 | - name: monash_weather 44 | hf_repo: autogluon/chronos_datasets 45 | offset: -30 46 | prediction_length: 30 47 | num_rolls: 1 48 | - name: monash_covid_deaths 49 | hf_repo: autogluon/chronos_datasets 50 | offset: -30 51 | prediction_length: 30 52 | num_rolls: 1 53 | - name: monash_fred_md 54 | hf_repo: autogluon/chronos_datasets 55 | offset: -12 56 | prediction_length: 12 57 | num_rolls: 1 58 | - name: m4_quarterly 59 | hf_repo: autogluon/chronos_datasets 60 | offset: -8 61 | prediction_length: 8 62 | num_rolls: 1 63 | - name: m4_yearly 64 | hf_repo: autogluon/chronos_datasets 65 | offset: -6 66 | prediction_length: 6 67 | num_rolls: 1 68 | - name: dominick 69 | hf_repo: autogluon/chronos_datasets 70 | offset: -8 71 | prediction_length: 8 72 | num_rolls: 1 73 | - name: m5 74 | hf_repo: autogluon/chronos_datasets 75 | offset: -28 76 | prediction_length: 28 77 | num_rolls: 1 78 | - name: monash_tourism_monthly 79 | hf_repo: autogluon/chronos_datasets 80 | offset: -24 81 | prediction_length: 24 82 | num_rolls: 1 83 | - name: monash_tourism_quarterly 84 | hf_repo: autogluon/chronos_datasets 85 | offset: -8 86 | prediction_length: 8 87 | num_rolls: 1 88 | - name: monash_tourism_yearly 89 | hf_repo: autogluon/chronos_datasets 90 | offset: -4 91 | prediction_length: 4 92 | num_rolls: 1 93 | - name: monash_car_parts 94 | hf_repo: autogluon/chronos_datasets 95 | offset: -12 96 | prediction_length: 12 97 | num_rolls: 1 98 | - name: monash_hospital 99 | hf_repo: autogluon/chronos_datasets 100 | offset: -12 101 | prediction_length: 12 102 | num_rolls: 1 103 | - name: monash_cif_2016 104 | hf_repo: autogluon/chronos_datasets 105 | offset: -12 106 | prediction_length: 12 107 | num_rolls: 1 108 | - name: monash_m1_yearly 109 | hf_repo: autogluon/chronos_datasets 110 | offset: -6 111 | prediction_length: 6 112 | num_rolls: 1 113 | - name: monash_m1_quarterly 114 | hf_repo: autogluon/chronos_datasets 115 | offset: -8 116 | prediction_length: 8 117 | num_rolls: 1 118 | - name: monash_m1_monthly 119 | hf_repo: autogluon/chronos_datasets 120 | offset: -18 121 | prediction_length: 18 122 | num_rolls: 1 123 | - name: monash_m3_monthly 124 | hf_repo: autogluon/chronos_datasets 125 | offset: -18 126 | prediction_length: 18 127 | num_rolls: 1 128 | - name: monash_m3_yearly 129 | hf_repo: autogluon/chronos_datasets 130 | offset: -6 131 | prediction_length: 6 132 | num_rolls: 1 133 | - name: monash_m3_quarterly 134 | 
hf_repo: autogluon/chronos_datasets 135 | offset: -8 136 | prediction_length: 8 137 | num_rolls: 1 -------------------------------------------------------------------------------- /scripts/evaluation/evaluate.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from pathlib import Path 3 | from typing import Iterable, Optional 4 | 5 | import datasets 6 | import numpy as np 7 | import pandas as pd 8 | import torch 9 | import typer 10 | import yaml 11 | from gluonts.dataset.split import split 12 | from gluonts.ev.metrics import MASE, MeanWeightedSumQuantileLoss 13 | from gluonts.itertools import batcher 14 | from gluonts.model.evaluation import evaluate_forecasts 15 | from gluonts.model.forecast import QuantileForecast, SampleForecast 16 | from tqdm.auto import tqdm 17 | 18 | from chronos import ( 19 | BaseChronosPipeline, 20 | ChronosBoltPipeline, 21 | ChronosPipeline, 22 | ForecastType, 23 | ) 24 | 25 | app = typer.Typer(pretty_exceptions_enable=False) 26 | 27 | 28 | def to_gluonts_univariate(hf_dataset: datasets.Dataset): 29 | series_fields = [ 30 | col 31 | for col in hf_dataset.features 32 | if isinstance(hf_dataset.features[col], datasets.Sequence) 33 | ] 34 | series_fields.remove("timestamp") 35 | dataset_length = hf_dataset.info.splits["train"].num_examples * len(series_fields) 36 | 37 | # Assumes that all time series in the dataset have the same frequency 38 | dataset_freq = pd.DatetimeIndex(hf_dataset[0]["timestamp"]).to_period()[0].freqstr 39 | 40 | gts_dataset = [] 41 | for hf_entry in hf_dataset: 42 | for field in series_fields: 43 | gts_dataset.append( 44 | { 45 | "start": pd.Period( 46 | hf_entry["timestamp"][0], 47 | freq=dataset_freq, 48 | ), 49 | "target": hf_entry[field], 50 | } 51 | ) 52 | assert len(gts_dataset) == dataset_length 53 | 54 | return gts_dataset 55 | 56 | 57 | def load_and_split_dataset(backtest_config: dict): 58 | hf_repo = backtest_config["hf_repo"] 59 | dataset_name = backtest_config["name"] 60 | offset = backtest_config["offset"] 61 | prediction_length = backtest_config["prediction_length"] 62 | num_rolls = backtest_config["num_rolls"] 63 | 64 | # This is needed because the datasets in autogluon/chronos_datasets_extra cannot 65 | # be distribued due to license restrictions and must be generated on the fly 66 | trust_remote_code = True if hf_repo == "autogluon/chronos_datasets_extra" else False 67 | 68 | ds = datasets.load_dataset( 69 | hf_repo, dataset_name, split="train", trust_remote_code=trust_remote_code 70 | ) 71 | ds.set_format("numpy") 72 | 73 | gts_dataset = to_gluonts_univariate(ds) 74 | 75 | # Split dataset for evaluation 76 | _, test_template = split(gts_dataset, offset=offset) 77 | test_data = test_template.generate_instances(prediction_length, windows=num_rolls) 78 | 79 | return test_data 80 | 81 | 82 | def generate_forecasts( 83 | test_data_input: Iterable, 84 | pipeline: BaseChronosPipeline, 85 | prediction_length: int, 86 | batch_size: int, 87 | **predict_kwargs, 88 | ): 89 | # Generate forecasts 90 | forecast_outputs = [] 91 | for batch in tqdm(batcher(test_data_input, batch_size=batch_size)): 92 | context = [torch.tensor(entry["target"]) for entry in batch] 93 | forecast_outputs.append( 94 | pipeline.predict( 95 | context, 96 | prediction_length=prediction_length, 97 | **predict_kwargs, 98 | ).numpy() 99 | ) 100 | forecast_outputs = np.concatenate(forecast_outputs) 101 | 102 | # Convert forecast samples into gluonts Forecast objects 103 | forecasts = [] 104 | for item, ts in 
zip(forecast_outputs, test_data_input): 105 | forecast_start_date = ts["start"] + len(ts["target"]) 106 | 107 | if pipeline.forecast_type == ForecastType.SAMPLES: 108 | forecasts.append( 109 | SampleForecast(samples=item, start_date=forecast_start_date) 110 | ) 111 | elif pipeline.forecast_type == ForecastType.QUANTILES: 112 | forecasts.append( 113 | QuantileForecast( 114 | forecast_arrays=item, 115 | forecast_keys=list(map(str, pipeline.quantiles)), 116 | start_date=forecast_start_date, 117 | ) 118 | ) 119 | 120 | return forecasts 121 | 122 | 123 | @app.command() 124 | def main( 125 | config_path: Path, 126 | metrics_path: Path, 127 | chronos_model_id: str = "amazon/chronos-t5-small", 128 | device: str = "cuda", 129 | torch_dtype: str = "bfloat16", 130 | batch_size: int = 32, 131 | num_samples: int = 20, 132 | temperature: Optional[float] = None, 133 | top_k: Optional[int] = None, 134 | top_p: Optional[float] = None, 135 | ): 136 | """Evaluate Chronos models. 137 | 138 | Parameters 139 | ---------- 140 | config_path : Path 141 | Path to the evaluation config. See ./configs/. 142 | metrics_path : Path 143 | Path to the CSV file where metrics will be saved. 144 | chronos_model_id : str, optional, default = "amazon/chronos-t5-small" 145 | HuggingFace ID of the Chronos model or local path 146 | Available models on HuggingFace: 147 | Chronos: 148 | - amazon/chronos-t5-tiny 149 | - amazon/chronos-t5-mini 150 | - amazon/chronos-t5-small 151 | - amazon/chronos-t5-base 152 | - amazon/chronos-t5-large 153 | Chronos-Bolt: 154 | - amazon/chronos-bolt-tiny 155 | - amazon/chronos-bolt-mini 156 | - amazon/chronos-bolt-small 157 | - amazon/chronos-bolt-base 158 | device : str, optional, default = "cuda" 159 | Device on which inference will be performed 160 | torch_dtype : str, optional 161 | Model's dtype, by default "bfloat16" 162 | batch_size : int, optional, default = 32 163 | Batch size for inference. 
For Chronos-Bolt models, significantly larger 164 | batch sizes can be used 165 | num_samples : int, optional, default = 20 166 | Number of samples to draw when using the original Chronos models 167 | temperature : Optional[float], optional, default = 1.0 168 | Softmax temperature to used for the original Chronos models 169 | top_k : Optional[int], optional, default = 50 170 | Top-K sampling, by default None 171 | top_p : Optional[float], optional, default = 1.0 172 | Top-p sampling, by default None 173 | """ 174 | if isinstance(torch_dtype, str): 175 | torch_dtype = getattr(torch, torch_dtype) 176 | assert isinstance(torch_dtype, torch.dtype) 177 | 178 | # Load Chronos 179 | pipeline = BaseChronosPipeline.from_pretrained( 180 | chronos_model_id, 181 | device_map=device, 182 | torch_dtype=torch_dtype, 183 | ) 184 | 185 | if isinstance(pipeline, ChronosPipeline): 186 | predict_kwargs = dict( 187 | num_samples=num_samples, 188 | temperature=temperature, 189 | top_k=top_k, 190 | top_p=top_p, 191 | ) 192 | elif isinstance(pipeline, ChronosBoltPipeline): 193 | predict_kwargs = {} 194 | 195 | # Load backtest configs 196 | with open(config_path) as fp: 197 | backtest_configs = yaml.safe_load(fp) 198 | 199 | result_rows = [] 200 | for config in backtest_configs: 201 | dataset_name = config["name"] 202 | prediction_length = config["prediction_length"] 203 | 204 | logger.info(f"Loading {dataset_name}") 205 | test_data = load_and_split_dataset(backtest_config=config) 206 | 207 | logger.info( 208 | f"Generating forecasts for {dataset_name} " 209 | f"({len(test_data.input)} time series)" 210 | ) 211 | forecasts = generate_forecasts( 212 | test_data.input, 213 | pipeline=pipeline, 214 | prediction_length=prediction_length, 215 | batch_size=batch_size, 216 | **predict_kwargs, 217 | ) 218 | 219 | logger.info(f"Evaluating forecasts for {dataset_name}") 220 | metrics = ( 221 | evaluate_forecasts( 222 | forecasts, 223 | test_data=test_data, 224 | metrics=[ 225 | MASE(), 226 | MeanWeightedSumQuantileLoss(np.arange(0.1, 1.0, 0.1)), 227 | ], 228 | batch_size=5000, 229 | ) 230 | .reset_index(drop=True) 231 | .to_dict(orient="records") 232 | ) 233 | result_rows.append( 234 | {"dataset": dataset_name, "model": chronos_model_id, **metrics[0]} 235 | ) 236 | 237 | # Save results to a CSV file 238 | results_df = ( 239 | pd.DataFrame(result_rows) 240 | .rename( 241 | {"MASE[0.5]": "MASE", "mean_weighted_sum_quantile_loss": "WQL"}, 242 | axis="columns", 243 | ) 244 | .sort_values(by="dataset") 245 | ) 246 | results_df.to_csv(metrics_path, index=False) 247 | 248 | 249 | if __name__ == "__main__": 250 | logging.basicConfig(format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") 251 | logger = logging.getLogger("Chronos Evaluation") 252 | logger.setLevel(logging.INFO) 253 | app() 254 | -------------------------------------------------------------------------------- /scripts/evaluation/results/chronos-bolt-base-agg-rel-scores.csv: -------------------------------------------------------------------------------- 1 | benchmark,metric,value 2 | in-domain,MASE,0.6800133628315155 3 | in-domain,WQL,0.5339263811489279 4 | zero-shot,MASE,0.7914551113353537 5 | zero-shot,WQL,0.6241424984163773 6 | -------------------------------------------------------------------------------- /scripts/evaluation/results/chronos-bolt-base-in-domain.csv: -------------------------------------------------------------------------------- 1 | dataset,model,MASE,WQL 2 | electricity_15min,amazon/chronos-bolt-base,0.41069374835605243,0.0703533790998506 
3 | m4_daily,amazon/chronos-bolt-base,3.205192517121196,0.02110308498174413 4 | m4_hourly,amazon/chronos-bolt-base,0.8350129849014075,0.025353803894164 5 | m4_monthly,amazon/chronos-bolt-base,0.9491758928362231,0.09382496106659234 6 | m4_weekly,amazon/chronos-bolt-base,2.0847827409162742,0.03816605075768161 7 | monash_electricity_hourly,amazon/chronos-bolt-base,1.254966217685461,0.09442192616975713 8 | monash_electricity_weekly,amazon/chronos-bolt-base,1.8391546050108039,0.06410971963960499 9 | monash_kdd_cup_2018,amazon/chronos-bolt-base,0.6405985809360102,0.2509172188706336 10 | monash_london_smart_meters,amazon/chronos-bolt-base,0.701398572604996,0.3218915088923906 11 | monash_pedestrian_counts,amazon/chronos-bolt-base,0.2646412642278343,0.18789459806066328 12 | monash_rideshare,amazon/chronos-bolt-base,0.7695376426829713,0.11637119433040358 13 | monash_temperature_rain,amazon/chronos-bolt-base,0.8983612698773724,0.6050555216496304 14 | taxi_30min,amazon/chronos-bolt-base,0.7688908266765317,0.2363178601205094 15 | uber_tlc_daily,amazon/chronos-bolt-base,0.8231767493519677,0.0926036406916842 16 | uber_tlc_hourly,amazon/chronos-bolt-base,0.6632193728217927,0.14987786887626975 17 | -------------------------------------------------------------------------------- /scripts/evaluation/results/chronos-bolt-base-zero-shot.csv: -------------------------------------------------------------------------------- 1 | dataset,model,MASE,WQL 2 | ETTh,amazon/chronos-bolt-base,0.7479154031956647,0.07062173821055001 3 | ETTm,amazon/chronos-bolt-base,0.6334357237512225,0.052261607745858835 4 | dominick,amazon/chronos-bolt-base,0.8560272479913918,0.3453573743726445 5 | ercot,amazon/chronos-bolt-base,0.6933217425507392,0.02142183038021456 6 | exchange_rate,amazon/chronos-bolt-base,1.7095176257412634,0.01200682136751536 7 | m4_quarterly,amazon/chronos-bolt-base,1.2244670010522907,0.0771066518089854 8 | m4_yearly,amazon/chronos-bolt-base,3.513752058541554,0.12142798053483984 9 | m5,amazon/chronos-bolt-base,0.9152230096463854,0.561999688057527 10 | monash_australian_electricity,amazon/chronos-bolt-base,0.7403239930185613,0.03584034231329335 11 | monash_car_parts,amazon/chronos-bolt-base,0.8550263912438314,0.9945122291263591 12 | monash_cif_2016,amazon/chronos-bolt-base,0.9988541862779904,0.016456104842296485 13 | monash_covid_deaths,amazon/chronos-bolt-base,38.901749109066415,0.047410971217640714 14 | monash_fred_md,amazon/chronos-bolt-base,0.6468787708795645,0.04185083716355386 15 | monash_hospital,amazon/chronos-bolt-base,0.6883138394434054,0.057032869931903894 16 | monash_m1_monthly,amazon/chronos-bolt-base,1.0997677446267855,0.1392311148066238 17 | monash_m1_quarterly,amazon/chronos-bolt-base,1.7737851980875563,0.1007118219350403 18 | monash_m1_yearly,amazon/chronos-bolt-base,4.404672537832342,0.1504617654430952 19 | monash_m3_monthly,amazon/chronos-bolt-base,0.8510696834878182,0.09269673913736748 20 | monash_m3_quarterly,amazon/chronos-bolt-base,1.2890908822598466,0.07615133571216029 21 | monash_m3_yearly,amazon/chronos-bolt-base,2.9067097980770082,0.12934285625258413 22 | monash_nn5_weekly,amazon/chronos-bolt-base,0.9158766337957451,0.08352114810139548 23 | monash_tourism_monthly,amazon/chronos-bolt-base,1.5283388458731357,0.09026425492612797 24 | monash_tourism_quarterly,amazon/chronos-bolt-base,1.756127005530011,0.06448060953595125 25 | monash_tourism_yearly,amazon/chronos-bolt-base,3.691545772463519,0.16548820700844424 26 | monash_traffic,amazon/chronos-bolt-base,0.7843310867739336,0.23148632068725078 27 
| monash_weather,amazon/chronos-bolt-base,0.8115247139672316,0.13350830777170594 28 | nn5,amazon/chronos-bolt-base,0.5764084996361287,0.1500519584148468 29 | -------------------------------------------------------------------------------- /scripts/evaluation/results/chronos-bolt-mini-agg-rel-scores.csv: -------------------------------------------------------------------------------- 1 | benchmark,metric,value 2 | in-domain,MASE,0.7268373301543752 3 | in-domain,WQL,0.565140251955324 4 | zero-shot,MASE,0.8221798917822493 5 | zero-shot,WQL,0.6441645845380903 6 | -------------------------------------------------------------------------------- /scripts/evaluation/results/chronos-bolt-mini-in-domain.csv: -------------------------------------------------------------------------------- 1 | dataset,model,MASE,WQL 2 | electricity_15min,amazon/chronos-bolt-mini,0.44185193304080733,0.0731477927531107 3 | m4_daily,amazon/chronos-bolt-mini,3.1342608828747456,0.0206872246743766 4 | m4_hourly,amazon/chronos-bolt-mini,0.9218285923038745,0.024383114886067574 5 | m4_monthly,amazon/chronos-bolt-mini,0.9628339921394529,0.09502498697494888 6 | m4_weekly,amazon/chronos-bolt-mini,2.2330452369879255,0.039393515325238534 7 | monash_electricity_hourly,amazon/chronos-bolt-mini,1.6195944363428718,0.11468972600782207 8 | monash_electricity_weekly,amazon/chronos-bolt-mini,1.866105365159433,0.06019900031840434 9 | monash_kdd_cup_2018,amazon/chronos-bolt-mini,0.74790954883436,0.3012661161484388 10 | monash_london_smart_meters,amazon/chronos-bolt-mini,0.7187830347765344,0.32984510693830227 11 | monash_pedestrian_counts,amazon/chronos-bolt-mini,0.308633944815819,0.23331301029432483 12 | monash_rideshare,amazon/chronos-bolt-mini,0.818948044410056,0.1297966960374544 13 | monash_temperature_rain,amazon/chronos-bolt-mini,0.9035244443682741,0.605031064086567 14 | taxi_30min,amazon/chronos-bolt-mini,0.812010120941363,0.25232294549917317 15 | uber_tlc_daily,amazon/chronos-bolt-mini,0.8507256206478295,0.10101757743084538 16 | uber_tlc_hourly,amazon/chronos-bolt-mini,0.6685484898085609,0.1515245941548974 17 | -------------------------------------------------------------------------------- /scripts/evaluation/results/chronos-bolt-mini-zero-shot.csv: -------------------------------------------------------------------------------- 1 | dataset,model,MASE,WQL 2 | ETTh,amazon/chronos-bolt-mini,0.8057126710113404,0.07740387596411452 3 | ETTm,amazon/chronos-bolt-mini,0.6100793941108849,0.05129333450944573 4 | dominick,amazon/chronos-bolt-mini,0.8664152477208024,0.3499696999160997 5 | ercot,amazon/chronos-bolt-mini,0.6871250728215426,0.02448804863744021 6 | exchange_rate,amazon/chronos-bolt-mini,1.3520551553333662,0.00934663373172766 7 | m4_quarterly,amazon/chronos-bolt-mini,1.2569644266281508,0.07833787023275976 8 | m4_yearly,amazon/chronos-bolt-mini,3.7611003052413796,0.12931927951165456 9 | m5,amazon/chronos-bolt-mini,0.9188876472137485,0.5661303206519673 10 | monash_australian_electricity,amazon/chronos-bolt-mini,0.8823559450287066,0.04493688824488474 11 | monash_car_parts,amazon/chronos-bolt-mini,0.8604081423647779,1.0041876404811494 12 | monash_cif_2016,amazon/chronos-bolt-mini,1.0762361363763873,0.017641893717784202 13 | monash_covid_deaths,amazon/chronos-bolt-mini,38.83915011538576,0.06098317835750057 14 | monash_fred_md,amazon/chronos-bolt-mini,0.6169859211923081,0.03256236965040934 15 | monash_hospital,amazon/chronos-bolt-mini,0.6924431064606051,0.05766349075348645 16 | 
monash_m1_monthly,amazon/chronos-bolt-mini,1.147893030263777,0.13270222658510553 17 | monash_m1_quarterly,amazon/chronos-bolt-mini,1.8662100001165818,0.09846363409254102 18 | monash_m1_yearly,amazon/chronos-bolt-mini,5.319154632748303,0.16167328827180308 19 | monash_m3_monthly,amazon/chronos-bolt-mini,0.8758452776118432,0.09493431248614057 20 | monash_m3_quarterly,amazon/chronos-bolt-mini,1.3555175243802005,0.07808062465932723 21 | monash_m3_yearly,amazon/chronos-bolt-mini,3.605769430055575,0.15711010456482008 22 | monash_nn5_weekly,amazon/chronos-bolt-mini,0.9347141924977239,0.08522899825844342 23 | monash_tourism_monthly,amazon/chronos-bolt-mini,1.649587479665881,0.0979648261309891 24 | monash_tourism_quarterly,amazon/chronos-bolt-mini,1.8471553663088986,0.06501077791766902 25 | monash_tourism_yearly,amazon/chronos-bolt-mini,3.9932920493826245,0.1743539122097316 26 | monash_traffic,amazon/chronos-bolt-mini,0.8355442361271347,0.24351051123330386 27 | monash_weather,amazon/chronos-bolt-mini,0.800013628350165,0.13041050756802045 28 | nn5,amazon/chronos-bolt-mini,0.611917632501032,0.1570111102680171 29 | -------------------------------------------------------------------------------- /scripts/evaluation/results/chronos-bolt-small-agg-rel-scores.csv: -------------------------------------------------------------------------------- 1 | benchmark,metric,value 2 | in-domain,MASE,0.7030801652116672 3 | in-domain,WQL,0.5443547623341555 4 | zero-shot,MASE,0.8192127745093378 5 | zero-shot,WQL,0.6356097843099521 6 | -------------------------------------------------------------------------------- /scripts/evaluation/results/chronos-bolt-small-in-domain.csv: -------------------------------------------------------------------------------- 1 | dataset,model,MASE,WQL 2 | electricity_15min,amazon/chronos-bolt-small,0.44920089250026723,0.08115291306964295 3 | m4_daily,amazon/chronos-bolt-small,3.201966619014735,0.02143368277732494 4 | m4_hourly,amazon/chronos-bolt-small,0.8686298207618999,0.020368729287465817 5 | m4_monthly,amazon/chronos-bolt-small,0.9537717737278778,0.0939247807527992 6 | m4_weekly,amazon/chronos-bolt-small,2.1236755094789177,0.03785184715517262 7 | monash_electricity_hourly,amazon/chronos-bolt-small,1.3728906161330452,0.09452411472431674 8 | monash_electricity_weekly,amazon/chronos-bolt-small,1.8703239487242378,0.06648479071326366 9 | monash_kdd_cup_2018,amazon/chronos-bolt-small,0.6458631909979771,0.25148489931571666 10 | monash_london_smart_meters,amazon/chronos-bolt-small,0.7126939688565166,0.326874529903459 11 | monash_pedestrian_counts,amazon/chronos-bolt-small,0.3015070035798365,0.2285590441093863 12 | monash_rideshare,amazon/chronos-bolt-small,0.823726965741684,0.12409769473500927 13 | monash_temperature_rain,amazon/chronos-bolt-small,0.8980348827836525,0.5984819599873311 14 | taxi_30min,amazon/chronos-bolt-small,0.7597818149895785,0.2348569752311862 15 | uber_tlc_daily,amazon/chronos-bolt-small,0.8460854328036702,0.09666483354735897 16 | uber_tlc_hourly,amazon/chronos-bolt-small,0.6662547495017634,0.1524256346268063 17 | -------------------------------------------------------------------------------- /scripts/evaluation/results/chronos-bolt-small-zero-shot.csv: -------------------------------------------------------------------------------- 1 | dataset,model,MASE,WQL 2 | ETTh,amazon/chronos-bolt-small,0.792521748651108,0.07590654063011319 3 | ETTm,amazon/chronos-bolt-small,0.6209623928936988,0.05056189722606397 4 | 
dominick,amazon/chronos-bolt-small,0.8706134610400587,0.34811141409475416 5 | ercot,amazon/chronos-bolt-small,0.7562857616685997,0.02596064260343696 6 | exchange_rate,amazon/chronos-bolt-small,1.774835301692689,0.011363548847621512 7 | m4_quarterly,amazon/chronos-bolt-small,1.2478142413437487,0.07808795122806232 8 | m4_yearly,amazon/chronos-bolt-small,3.6925595655002574,0.12772564181388502 9 | m5,amazon/chronos-bolt-small,0.9195435643571084,0.5668430814831332 10 | monash_australian_electricity,amazon/chronos-bolt-small,0.8128424798841111,0.041509852162861564 11 | monash_car_parts,amazon/chronos-bolt-small,0.8584574663781737,1.0074689402521324 12 | monash_cif_2016,amazon/chronos-bolt-small,1.0182471909074982,0.01581964877692293 13 | monash_covid_deaths,amazon/chronos-bolt-small,36.467595559655145,0.0427382859406882 14 | monash_fred_md,amazon/chronos-bolt-small,0.6132863794635253,0.03730410577241995 15 | monash_hospital,amazon/chronos-bolt-small,0.6954489513780618,0.058119864671526154 16 | monash_m1_monthly,amazon/chronos-bolt-small,1.1277621848099244,0.1335656174632902 17 | monash_m1_quarterly,amazon/chronos-bolt-small,1.8356144904231688,0.09363028483838018 18 | monash_m1_yearly,amazon/chronos-bolt-small,5.098146069746402,0.15669928873371905 19 | monash_m3_monthly,amazon/chronos-bolt-small,0.8685125121306435,0.09396568468255145 20 | monash_m3_quarterly,amazon/chronos-bolt-small,1.3269103591066727,0.07691022995374203 21 | monash_m3_yearly,amazon/chronos-bolt-small,3.40993282700627,0.1547639821304127 22 | monash_nn5_weekly,amazon/chronos-bolt-small,0.9266513350636507,0.08452821221908001 23 | monash_tourism_monthly,amazon/chronos-bolt-small,1.6106732721197876,0.09362336754317802 24 | monash_tourism_quarterly,amazon/chronos-bolt-small,1.8357819365308639,0.06734337535269994 25 | monash_tourism_yearly,amazon/chronos-bolt-small,3.8963100495394194,0.16766064312072784 26 | monash_traffic,amazon/chronos-bolt-small,0.8598507749866499,0.25173786112983054 27 | monash_weather,amazon/chronos-bolt-small,0.8020408743877911,0.13258563963844888 28 | nn5,amazon/chronos-bolt-small,0.5833047644729239,0.15066847836762787 29 | -------------------------------------------------------------------------------- /scripts/evaluation/results/chronos-bolt-tiny-agg-rel-scores.csv: -------------------------------------------------------------------------------- 1 | benchmark,metric,value 2 | in-domain,MASE,0.7403252781013574 3 | in-domain,WQL,0.5733728165523524 4 | zero-shot,MASE,0.8445407343705457 5 | zero-shot,WQL,0.6678781905023173 6 | -------------------------------------------------------------------------------- /scripts/evaluation/results/chronos-bolt-tiny-in-domain.csv: -------------------------------------------------------------------------------- 1 | dataset,model,MASE,WQL 2 | electricity_15min,amazon/chronos-bolt-tiny,0.4676384089765091,0.0861229808117837 3 | m4_daily,amazon/chronos-bolt-tiny,3.1789994761356795,0.020961883512815756 4 | m4_hourly,amazon/chronos-bolt-tiny,0.9348005698736752,0.021087527284114574 5 | m4_monthly,amazon/chronos-bolt-tiny,0.965298729632761,0.0950380483243082 6 | m4_weekly,amazon/chronos-bolt-tiny,2.261575511029903,0.04093653263178429 7 | monash_electricity_hourly,amazon/chronos-bolt-tiny,1.5739346351263623,0.10808418398945202 8 | monash_electricity_weekly,amazon/chronos-bolt-tiny,1.8628689103722829,0.05773335283584782 9 | monash_kdd_cup_2018,amazon/chronos-bolt-tiny,0.6869549985391232,0.28012801092758166 10 | 
monash_london_smart_meters,amazon/chronos-bolt-tiny,0.7284234905933779,0.33496438244693033 11 | monash_pedestrian_counts,amazon/chronos-bolt-tiny,0.32338947321773864,0.2530637833749087 12 | monash_rideshare,amazon/chronos-bolt-tiny,0.8562780835002918,0.1304317657933891 13 | monash_temperature_rain,amazon/chronos-bolt-tiny,0.9030707620825977,0.6064087080755548 14 | taxi_30min,amazon/chronos-bolt-tiny,0.9122159603256838,0.28002194370731626 15 | uber_tlc_daily,amazon/chronos-bolt-tiny,0.9087055420190513,0.11193388685815164 16 | uber_tlc_hourly,amazon/chronos-bolt-tiny,0.6716569179590032,0.15310845458208555 17 | -------------------------------------------------------------------------------- /scripts/evaluation/results/chronos-bolt-tiny-zero-shot.csv: -------------------------------------------------------------------------------- 1 | dataset,model,MASE,WQL 2 | ETTh,amazon/chronos-bolt-tiny,0.7941225847155844,0.07480860969990633 3 | ETTm,amazon/chronos-bolt-tiny,0.6508270995240056,0.05440068825429993 4 | dominick,amazon/chronos-bolt-tiny,0.876060127216559,0.35175949052933253 5 | ercot,amazon/chronos-bolt-tiny,0.7309134980173839,0.02468604544464515 6 | exchange_rate,amazon/chronos-bolt-tiny,1.6857262567077134,0.011477224264784112 7 | m4_quarterly,amazon/chronos-bolt-tiny,1.2605908919338378,0.0789049420017836 8 | m4_yearly,amazon/chronos-bolt-tiny,3.7118394116161757,0.1286932555969197 9 | m5,amazon/chronos-bolt-tiny,0.9195469670062033,0.5634881835998845 10 | monash_australian_electricity,amazon/chronos-bolt-tiny,0.8419304693259403,0.042040993880313904 11 | monash_car_parts,amazon/chronos-bolt-tiny,0.8625579150452282,1.0009987800801836 12 | monash_cif_2016,amazon/chronos-bolt-tiny,1.095219642027011,0.017550336784241796 13 | monash_covid_deaths,amazon/chronos-bolt-tiny,40.674057986280744,0.06723714516685976 14 | monash_fred_md,amazon/chronos-bolt-tiny,0.6127387450520702,0.04747523852271518 15 | monash_hospital,amazon/chronos-bolt-tiny,0.6980246281225624,0.05864223243167421 16 | monash_m1_monthly,amazon/chronos-bolt-tiny,1.1625495971731141,0.13142237467151166 17 | monash_m1_quarterly,amazon/chronos-bolt-tiny,1.8941765599193754,0.09972207844232561 18 | monash_m1_yearly,amazon/chronos-bolt-tiny,5.136332694531757,0.160331813128038 19 | monash_m3_monthly,amazon/chronos-bolt-tiny,0.8744553726704598,0.09435519378597752 20 | monash_m3_quarterly,amazon/chronos-bolt-tiny,1.364563776692303,0.07875066385737857 21 | monash_m3_yearly,amazon/chronos-bolt-tiny,3.3685961410254928,0.15158076519486274 22 | monash_nn5_weekly,amazon/chronos-bolt-tiny,0.9324436794013877,0.0847385189968909 23 | monash_tourism_monthly,amazon/chronos-bolt-tiny,1.7895936775088157,0.1058167042693116 24 | monash_tourism_quarterly,amazon/chronos-bolt-tiny,2.095262637810499,0.0710732570354461 25 | monash_tourism_yearly,amazon/chronos-bolt-tiny,4.042821441327848,0.172613367251472 26 | monash_traffic,amazon/chronos-bolt-tiny,0.8836032533767518,0.2574297134210491 27 | monash_weather,amazon/chronos-bolt-tiny,0.8005348255663177,0.13111355494466076 28 | nn5,amazon/chronos-bolt-tiny,0.7228248498869763,0.1816913098894226 29 | -------------------------------------------------------------------------------- /scripts/evaluation/results/chronos-t5-base-agg-rel-scores.csv: -------------------------------------------------------------------------------- 1 | benchmark,metric,value 2 | in-domain,MASE,0.7007558507277635 3 | in-domain,WQL,0.5786300105297922 4 | zero-shot,MASE,0.8155209321160994 5 | zero-shot,WQL,0.6424634919486323 6 | 
-------------------------------------------------------------------------------- /scripts/evaluation/results/chronos-t5-base-in-domain.csv: -------------------------------------------------------------------------------- 1 | dataset,model,MASE,WQL 2 | electricity_15min,amazon/chronos-t5-base,0.39879754957261204,0.07738953262286181 3 | m4_daily,amazon/chronos-t5-base,3.160575865614404,0.02194256368254537 4 | m4_hourly,amazon/chronos-t5-base,0.6938747745332102,0.026354948301302205 5 | m4_monthly,amazon/chronos-t5-base,0.971951848755026,0.10355213196432872 6 | m4_weekly,amazon/chronos-t5-base,2.0143841267657945,0.03639741235815474 7 | monash_electricity_hourly,amazon/chronos-t5-base,1.5717251971297332,0.1078882125804548 8 | monash_electricity_weekly,amazon/chronos-t5-base,1.7862927210886668,0.06255982783148449 9 | monash_kdd_cup_2018,amazon/chronos-t5-base,0.6335225775496138,0.2684272353843692 10 | monash_london_smart_meters,amazon/chronos-t5-base,0.8362014889190201,0.4265549499082726 11 | monash_pedestrian_counts,amazon/chronos-t5-base,0.2817708325561419,0.20810108090665583 12 | monash_rideshare,amazon/chronos-t5-base,0.8614480533175364,0.1356591190888703 13 | monash_temperature_rain,amazon/chronos-t5-base,0.9692405156151607,0.660155448791624 14 | taxi_30min,amazon/chronos-t5-base,0.8186287575356217,0.26236060366367003 15 | uber_tlc_daily,amazon/chronos-t5-base,0.8338648311528079,0.0970875577681834 16 | uber_tlc_hourly,amazon/chronos-t5-base,0.6647193438331641,0.15436646659512512 17 | -------------------------------------------------------------------------------- /scripts/evaluation/results/chronos-t5-base-zero-shot.csv: -------------------------------------------------------------------------------- 1 | dataset,model,MASE,WQL 2 | ETTh,amazon/chronos-t5-base,0.7653491494991778,0.08087267701042929 3 | ETTm,amazon/chronos-t5-base,0.7737006634032871,0.07008650633028274 4 | dominick,amazon/chronos-t5-base,0.8194044957573132,0.33201307438298133 5 | ercot,amazon/chronos-t5-base,0.5014399265038706,0.013589435745554596 6 | exchange_rate,amazon/chronos-t5-base,2.055616906406159,0.011066070028466317 7 | m4_quarterly,amazon/chronos-t5-base,1.2253036947743137,0.08327936201395683 8 | m4_yearly,amazon/chronos-t5-base,3.639991540990927,0.13539258375263963 9 | m5,amazon/chronos-t5-base,0.9391874615167101,0.5867234116216755 10 | monash_australian_electricity,amazon/chronos-t5-base,1.2944069383163321,0.07070604202031877 11 | monash_car_parts,amazon/chronos-t5-base,0.9071940271035218,1.077797124337994 12 | monash_cif_2016,amazon/chronos-t5-base,0.9840747802099565,0.011825556826558836 13 | monash_covid_deaths,amazon/chronos-t5-base,42.68503365359237,0.042229910495746356 14 | monash_fred_md,amazon/chronos-t5-base,0.4857773806790164,0.021204829049512715 15 | monash_hospital,amazon/chronos-t5-base,0.7053005021431749,0.05630687524507516 16 | monash_m1_monthly,amazon/chronos-t5-base,1.1153039466137842,0.12724419775326076 17 | monash_m1_quarterly,amazon/chronos-t5-base,1.746093728928804,0.1123583549291933 18 | monash_m1_yearly,amazon/chronos-t5-base,4.401291522370069,0.18541586641719554 19 | monash_m3_monthly,amazon/chronos-t5-base,0.8627172231908679,0.09640536232169555 20 | monash_m3_quarterly,amazon/chronos-t5-base,1.1696030904401578,0.07392876900131434 21 | monash_m3_yearly,amazon/chronos-t5-base,3.1298600218573775,0.1486674447940158 22 | monash_nn5_weekly,amazon/chronos-t5-base,0.9334860602210187,0.08972736821598823 23 | monash_tourism_monthly,amazon/chronos-t5-base,1.7937702435879332,0.10260220444264027 24 | 
monash_tourism_quarterly,amazon/chronos-t5-base,1.7791494997972261,0.06852507950474919 25 | monash_tourism_yearly,amazon/chronos-t5-base,3.8359926053603197,0.20722699382964643 26 | monash_traffic,amazon/chronos-t5-base,0.8015262383138622,0.25565153982140926 27 | monash_weather,amazon/chronos-t5-base,0.8159511190589147,0.13802320967454584 28 | nn5,amazon/chronos-t5-base,0.5927076179914024,0.1630476065585159 29 | -------------------------------------------------------------------------------- /scripts/evaluation/results/chronos-t5-large-agg-rel-scores.csv: -------------------------------------------------------------------------------- 1 | benchmark,metric,value 2 | in-domain,MASE,0.6944869734691035 3 | in-domain,WQL,0.5596857927462495 4 | zero-shot,MASE,0.8213682201405101 5 | zero-shot,WQL,0.6504834081319559 6 | -------------------------------------------------------------------------------- /scripts/evaluation/results/chronos-t5-large-in-domain.csv: -------------------------------------------------------------------------------- 1 | dataset,model,MASE,WQL 2 | electricity_15min,amazon/chronos-t5-large,0.3866310906621673,0.07759528615667297 3 | m4_daily,amazon/chronos-t5-large,3.134560968849699,0.02158279722410466 4 | m4_hourly,amazon/chronos-t5-large,0.6975930649233378,0.02086427219957674 5 | m4_monthly,amazon/chronos-t5-large,0.9585550091429409,0.10091221432814867 6 | m4_weekly,amazon/chronos-t5-large,2.0191422600104425,0.036912838355537186 7 | monash_electricity_hourly,amazon/chronos-t5-large,1.4069912853901292,0.09642382339452431 8 | monash_electricity_weekly,amazon/chronos-t5-large,1.7501880036182798,0.05765306465830232 9 | monash_kdd_cup_2018,amazon/chronos-t5-large,0.6788042816175427,0.2853553329804835 10 | monash_london_smart_meters,amazon/chronos-t5-large,0.8290300790418726,0.4235436387853963 11 | monash_pedestrian_counts,amazon/chronos-t5-large,0.2764118100521592,0.18692234491663473 12 | monash_rideshare,amazon/chronos-t5-large,0.8758058784466208,0.140260325368757 13 | monash_temperature_rain,amazon/chronos-t5-large,0.9738403865035117,0.6604571928063249 14 | taxi_30min,amazon/chronos-t5-large,0.8245662397270109,0.2653520120326771 15 | uber_tlc_daily,amazon/chronos-t5-large,0.8044165990021739,0.09499035584302248 16 | uber_tlc_hourly,amazon/chronos-t5-large,0.6700665937164474,0.15190288476653066 17 | -------------------------------------------------------------------------------- /scripts/evaluation/results/chronos-t5-large-zero-shot.csv: -------------------------------------------------------------------------------- 1 | dataset,model,MASE,WQL 2 | ETTh,amazon/chronos-t5-large,0.78160443631164,0.07884375667736107 3 | ETTm,amazon/chronos-t5-large,0.7325919639389967,0.06656858270921162 4 | dominick,amazon/chronos-t5-large,0.8200108271155829,0.3311575649734524 5 | ercot,amazon/chronos-t5-large,0.6050812633742764,0.01822996942395577 6 | exchange_rate,amazon/chronos-t5-large,2.3439287001928744,0.014841231672174684 7 | m4_quarterly,amazon/chronos-t5-large,1.2169666607868148,0.08235162400898562 8 | m4_yearly,amazon/chronos-t5-large,3.5524979814018947,0.1325675848907479 9 | m5,amazon/chronos-t5-large,0.9422990989146737,0.585615077637479 10 | monash_australian_electricity,amazon/chronos-t5-large,1.480849838497958,0.07973968848149568 11 | monash_car_parts,amazon/chronos-t5-large,0.901547374873302,1.0467398096496576 12 | monash_cif_2016,amazon/chronos-t5-large,0.9906388185665337,0.011966178555329998 13 | monash_covid_deaths,amazon/chronos-t5-large,44.07354193681227,0.06108999981222163 14 | 
monash_fred_md,amazon/chronos-t5-large,0.5184400880318044,0.01675533888399231 15 | monash_hospital,amazon/chronos-t5-large,0.7055308474630898,0.0552450850258613 16 | monash_m1_monthly,amazon/chronos-t5-large,1.0888995301234758,0.12729911122909737 17 | monash_m1_quarterly,amazon/chronos-t5-large,1.7477134564031453,0.10618253695380094 18 | monash_m1_yearly,amazon/chronos-t5-large,4.250667049416348,0.17128879333643188 19 | monash_m3_monthly,amazon/chronos-t5-large,0.8559326975903808,0.09572577431396007 20 | monash_m3_quarterly,amazon/chronos-t5-large,1.1867267751420676,0.07449254281607631 21 | monash_m3_yearly,amazon/chronos-t5-large,3.0239493021840635,0.14814710375646464 22 | monash_nn5_weekly,amazon/chronos-t5-large,0.9228721852437364,0.08948447200571868 23 | monash_tourism_monthly,amazon/chronos-t5-large,1.7304427846580348,0.09983169221760163 24 | monash_tourism_quarterly,amazon/chronos-t5-large,1.6437184365114073,0.0690906057781915 25 | monash_tourism_yearly,amazon/chronos-t5-large,3.6268503118928535,0.17732007043832695 26 | monash_traffic,amazon/chronos-t5-large,0.7985975530866148,0.25313515740581755 27 | monash_weather,amazon/chronos-t5-large,0.8187388457436171,0.1387756772600068 28 | nn5,amazon/chronos-t5-large,0.5755260854173723,0.15733693855465292 29 | -------------------------------------------------------------------------------- /scripts/evaluation/results/chronos-t5-mini-agg-rel-scores.csv: -------------------------------------------------------------------------------- 1 | benchmark,metric,value 2 | in-domain,MASE,0.7249816823595568 3 | in-domain,WQL,0.5965372489622094 4 | zero-shot,MASE,0.8411995116926901 5 | zero-shot,WQL,0.6888397962259065 6 | -------------------------------------------------------------------------------- /scripts/evaluation/results/chronos-t5-mini-in-domain.csv: -------------------------------------------------------------------------------- 1 | dataset,model,MASE,WQL 2 | electricity_15min,amazon/chronos-t5-mini,0.4446629660227641,0.08114657599239496 3 | m4_daily,amazon/chronos-t5-mini,3.1533349226194005,0.022000507013584743 4 | m4_hourly,amazon/chronos-t5-mini,0.7616830292996938,0.024630575107847653 5 | m4_monthly,amazon/chronos-t5-mini,0.9934074425853089,0.10402168689068064 6 | m4_weekly,amazon/chronos-t5-mini,2.1407189608104416,0.04138058102434373 7 | monash_electricity_hourly,amazon/chronos-t5-mini,1.3698378948313894,0.09189698159081384 8 | monash_electricity_weekly,amazon/chronos-t5-mini,1.9238345295706893,0.07015383787479901 9 | monash_kdd_cup_2018,amazon/chronos-t5-mini,0.6027861468526459,0.25493489598663444 10 | monash_london_smart_meters,amazon/chronos-t5-mini,0.8570035850603943,0.4356582737588471 11 | monash_pedestrian_counts,amazon/chronos-t5-mini,0.30374539593979855,0.2374083216051065 12 | monash_rideshare,amazon/chronos-t5-mini,0.8157349455509949,0.12963515638823117 13 | monash_temperature_rain,amazon/chronos-t5-mini,1.010161905102516,0.6919171702485583 14 | taxi_30min,amazon/chronos-t5-mini,0.9318379552979712,0.31229508015999674 15 | uber_tlc_daily,amazon/chronos-t5-mini,0.9213437323817685,0.10475291429149586 16 | uber_tlc_hourly,amazon/chronos-t5-mini,0.6812621470377416,0.15982192635434303 17 | -------------------------------------------------------------------------------- /scripts/evaluation/results/chronos-t5-mini-zero-shot.csv: -------------------------------------------------------------------------------- 1 | dataset,model,MASE,WQL 2 | ETTh,amazon/chronos-t5-mini,0.789678971785092,0.08068969536800001 3 | 
ETTm,amazon/chronos-t5-mini,0.7521219674190734,0.06791782942706617 4 | dominick,amazon/chronos-t5-mini,0.8207116999488602,0.34004499734299765 5 | ercot,amazon/chronos-t5-mini,0.5462749489237783,0.015035001020343136 6 | exchange_rate,amazon/chronos-t5-mini,2.1326718165798657,0.015073846769933199 7 | m4_quarterly,amazon/chronos-t5-mini,1.271761811062081,0.08575942238385105 8 | m4_yearly,amazon/chronos-t5-mini,3.7340853642679126,0.13938781939783162 9 | m5,amazon/chronos-t5-mini,0.9421556321929742,0.5961689098871504 10 | monash_australian_electricity,amazon/chronos-t5-mini,1.046297291920238,0.05424453772723559 11 | monash_car_parts,amazon/chronos-t5-mini,0.8913523483805221,1.0174797526818506 12 | monash_cif_2016,amazon/chronos-t5-mini,1.0674111822055679,0.016800831829085764 13 | monash_covid_deaths,amazon/chronos-t5-mini,43.69727825485175,0.08788117644141617 14 | monash_fred_md,amazon/chronos-t5-mini,0.46227452519609524,0.01871860604459728 15 | monash_hospital,amazon/chronos-t5-mini,0.7112593459108532,0.05831005112661489 16 | monash_m1_monthly,amazon/chronos-t5-mini,1.1756557848450433,0.14192178371159841 17 | monash_m1_quarterly,amazon/chronos-t5-mini,1.795009199698074,0.11760148522768847 18 | monash_m1_yearly,amazon/chronos-t5-mini,5.078889706085604,0.1882823108615221 19 | monash_m3_monthly,amazon/chronos-t5-mini,0.900404391663476,0.09935931092075681 20 | monash_m3_quarterly,amazon/chronos-t5-mini,1.2604342624229292,0.07807204797138119 21 | monash_m3_yearly,amazon/chronos-t5-mini,3.4395976709464255,0.16085249526114198 22 | monash_nn5_weekly,amazon/chronos-t5-mini,0.9459117943913629,0.09042762527674755 23 | monash_tourism_monthly,amazon/chronos-t5-mini,1.920865545569713,0.10791754513335952 24 | monash_tourism_quarterly,amazon/chronos-t5-mini,1.7957439111869486,0.07514539225156464 25 | monash_tourism_yearly,amazon/chronos-t5-mini,4.134958090482728,0.2202036957350168 26 | monash_traffic,amazon/chronos-t5-mini,0.8546792774237857,0.2668831661775284 27 | monash_weather,amazon/chronos-t5-mini,0.8607748244159247,0.15031866806333247 28 | nn5,amazon/chronos-t5-mini,0.6497211196906223,0.17352254058241523 29 | -------------------------------------------------------------------------------- /scripts/evaluation/results/chronos-t5-small-agg-rel-scores.csv: -------------------------------------------------------------------------------- 1 | benchmark,metric,value 2 | in-domain,MASE,0.7296140269944743 3 | in-domain,WQL,0.6086958548874499 4 | zero-shot,MASE,0.8303721909132112 5 | zero-shot,WQL,0.6649587072099045 6 | -------------------------------------------------------------------------------- /scripts/evaluation/results/chronos-t5-small-in-domain.csv: -------------------------------------------------------------------------------- 1 | dataset,model,MASE,WQL 2 | electricity_15min,amazon/chronos-t5-small,0.4115559557750193,0.08085148902238105 3 | m4_daily,amazon/chronos-t5-small,3.1384304946608896,0.02129901023419818 4 | m4_hourly,amazon/chronos-t5-small,0.7300874075370588,0.024686127211237932 5 | m4_monthly,amazon/chronos-t5-small,0.9797264456494642,0.10297069145186107 6 | m4_weekly,amazon/chronos-t5-small,2.0802214537692607,0.03959222330783002 7 | monash_electricity_hourly,amazon/chronos-t5-small,1.530308399040219,0.10765947926209926 8 | monash_electricity_weekly,amazon/chronos-t5-small,1.9249616494404531,0.07593976499899265 9 | monash_kdd_cup_2018,amazon/chronos-t5-small,0.6911172359201715,0.2863722811236367 10 | monash_london_smart_meters,amazon/chronos-t5-small,0.8405756252443325,0.4300875548402115 11 | 
monash_pedestrian_counts,amazon/chronos-t5-small,0.30836963006151696,0.2442543970311678 12 | monash_rideshare,amazon/chronos-t5-small,0.8436277753840817,0.1363421932158997 13 | monash_temperature_rain,amazon/chronos-t5-small,1.0176003932416664,0.6847726381172435 14 | taxi_30min,amazon/chronos-t5-small,0.976277213614167,0.32770172988517626 15 | uber_tlc_daily,amazon/chronos-t5-small,0.8694727058784919,0.0994889223610958 16 | uber_tlc_hourly,amazon/chronos-t5-small,0.6738672444888639,0.1573990617753753 17 | -------------------------------------------------------------------------------- /scripts/evaluation/results/chronos-t5-small-zero-shot.csv: -------------------------------------------------------------------------------- 1 | dataset,model,MASE,WQL 2 | ETTh,amazon/chronos-t5-small,0.8516754221042285,0.08667817580712385 3 | ETTm,amazon/chronos-t5-small,0.6825432730635727,0.06076472147001207 4 | dominick,amazon/chronos-t5-small,0.8108766032127683,0.3368104617474581 5 | ercot,amazon/chronos-t5-small,0.564879593858422,0.015547628920969682 6 | exchange_rate,amazon/chronos-t5-small,1.8143459139100264,0.014492477372711763 7 | m4_quarterly,amazon/chronos-t5-small,1.2415331521819728,0.08383826063189778 8 | m4_yearly,amazon/chronos-t5-small,3.738749650935195,0.1384514201649314 9 | m5,amazon/chronos-t5-small,0.9368713240675598,0.5896066252181699 10 | monash_australian_electricity,amazon/chronos-t5-small,1.2241146217392032,0.06951399165882449 11 | monash_car_parts,amazon/chronos-t5-small,0.8917508090523597,1.0314986717260015 12 | monash_cif_2016,amazon/chronos-t5-small,1.0187937383419037,0.014633240218233142 13 | monash_covid_deaths,amazon/chronos-t5-small,42.298997211368935,0.06339512778191682 14 | monash_fred_md,amazon/chronos-t5-small,0.4742159923922472,0.01486734736993978 15 | monash_hospital,amazon/chronos-t5-small,0.709814741753487,0.05704674270057172 16 | monash_m1_monthly,amazon/chronos-t5-small,1.1723041163998773,0.13799049510465802 17 | monash_m1_quarterly,amazon/chronos-t5-small,1.8077827825737092,0.11323432989795904 18 | monash_m1_yearly,amazon/chronos-t5-small,4.739967673537301,0.1730738338876877 19 | monash_m3_monthly,amazon/chronos-t5-small,0.8856577322724943,0.09985251429658573 20 | monash_m3_quarterly,amazon/chronos-t5-small,1.278907982396775,0.08094041554590593 21 | monash_m3_yearly,amazon/chronos-t5-small,3.382470310192457,0.157363937435307 22 | monash_nn5_weekly,amazon/chronos-t5-small,0.9277396908126303,0.08963913763368506 23 | monash_tourism_monthly,amazon/chronos-t5-small,1.9251180766131313,0.10943962474253494 24 | monash_tourism_quarterly,amazon/chronos-t5-small,1.7623454951333655,0.06862432764377493 25 | monash_tourism_yearly,amazon/chronos-t5-small,3.987690476709746,0.19960492460202509 26 | monash_traffic,amazon/chronos-t5-small,0.8204223927835267,0.2571189517024486 27 | monash_weather,amazon/chronos-t5-small,0.8550633590487968,0.1479701971025123 28 | nn5,amazon/chronos-t5-small,0.6130789183153671,0.16771392719859998 29 | -------------------------------------------------------------------------------- /scripts/evaluation/results/chronos-t5-tiny-agg-rel-scores.csv: -------------------------------------------------------------------------------- 1 | benchmark,metric,value 2 | in-domain,MASE,0.7649019745781727 3 | in-domain,WQL,0.6288613368129368 4 | zero-shot,MASE,0.8704764463925718 5 | zero-shot,WQL,0.7108912052035352 6 | -------------------------------------------------------------------------------- /scripts/evaluation/results/chronos-t5-tiny-in-domain.csv: 
-------------------------------------------------------------------------------- 1 | dataset,model,MASE,WQL 2 | electricity_15min,amazon/chronos-t5-tiny,0.5091784254243783,0.08236334376190152 3 | m4_daily,amazon/chronos-t5-tiny,3.203164895930929,0.022152192084951595 4 | m4_hourly,amazon/chronos-t5-tiny,0.8171321441164723,0.027490760558343874 5 | m4_monthly,amazon/chronos-t5-tiny,1.005839207921131,0.10388015368939435 6 | m4_weekly,amazon/chronos-t5-tiny,2.2148332313370735,0.043429655561156084 7 | monash_electricity_hourly,amazon/chronos-t5-tiny,1.6190021089002615,0.10967453530956882 8 | monash_electricity_weekly,amazon/chronos-t5-tiny,2.0774597917676734,0.08159998975612164 9 | monash_kdd_cup_2018,amazon/chronos-t5-tiny,0.6730886827096076,0.2616610603634618 10 | monash_london_smart_meters,amazon/chronos-t5-tiny,0.8830447519225436,0.4499607073491794 11 | monash_pedestrian_counts,amazon/chronos-t5-tiny,0.3042105240185045,0.23387631681117071 12 | monash_rideshare,amazon/chronos-t5-tiny,0.8431350112476247,0.1378817076926394 13 | monash_temperature_rain,amazon/chronos-t5-tiny,0.9887398447367799,0.6957797286648015 14 | taxi_30min,amazon/chronos-t5-tiny,1.035544060665179,0.3450476958104713 15 | uber_tlc_daily,amazon/chronos-t5-tiny,0.93025919000775,0.1105323649942084 16 | uber_tlc_hourly,amazon/chronos-t5-tiny,0.697558054147913,0.16320255844336232 17 | -------------------------------------------------------------------------------- /scripts/evaluation/results/chronos-t5-tiny-zero-shot.csv: -------------------------------------------------------------------------------- 1 | dataset,model,MASE,WQL 2 | ETTh,amazon/chronos-t5-tiny,0.8184074113571701,0.08578203438707048 3 | ETTm,amazon/chronos-t5-tiny,0.9103621000781905,0.07975361086322658 4 | dominick,amazon/chronos-t5-tiny,0.8538295532466194,0.3597090770361857 5 | ercot,amazon/chronos-t5-tiny,0.7273437589773705,0.020843170924006626 6 | exchange_rate,amazon/chronos-t5-tiny,1.6621128608546154,0.01085145980896454 7 | m4_quarterly,amazon/chronos-t5-tiny,1.2696259955861924,0.0861404188925996 8 | m4_yearly,amazon/chronos-t5-tiny,3.5293881164900527,0.13281575565500411 9 | m5,amazon/chronos-t5-tiny,0.9394059505709506,0.5981531758388589 10 | monash_australian_electricity,amazon/chronos-t5-tiny,1.4558820561269024,0.07673567331332948 11 | monash_car_parts,amazon/chronos-t5-tiny,0.9058206654011024,1.0236307963149358 12 | monash_cif_2016,amazon/chronos-t5-tiny,1.09349564130852,0.014066593076202984 13 | monash_covid_deaths,amazon/chronos-t5-tiny,46.53079664940016,0.09201919385053775 14 | monash_fred_md,amazon/chronos-t5-tiny,0.48008374212956456,0.03219550761153211 15 | monash_hospital,amazon/chronos-t5-tiny,0.7062562198194838,0.05790409320432609 16 | monash_m1_monthly,amazon/chronos-t5-tiny,1.214892145549996,0.14723095246308077 17 | monash_m1_quarterly,amazon/chronos-t5-tiny,1.8968576926613199,0.11026972972622998 18 | monash_m1_yearly,amazon/chronos-t5-tiny,4.829453202075546,0.17286063726000958 19 | monash_m3_monthly,amazon/chronos-t5-tiny,0.9095746605884618,0.10117875324490073 20 | monash_m3_quarterly,amazon/chronos-t5-tiny,1.3234957548639883,0.08209032993637215 21 | monash_m3_yearly,amazon/chronos-t5-tiny,3.1489371074890093,0.1492445630072877 22 | monash_nn5_weekly,amazon/chronos-t5-tiny,0.9637480731663901,0.09205994784693056 23 | monash_tourism_monthly,amazon/chronos-t5-tiny,2.151677532807024,0.11356761694754255 24 | monash_tourism_quarterly,amazon/chronos-t5-tiny,1.9116538900950555,0.07191734222366106 25 | 
monash_tourism_yearly,amazon/chronos-t5-tiny,3.820615532600914,0.19709256337364625 26 | monash_traffic,amazon/chronos-t5-tiny,0.878709088458116,0.2632101606272236 27 | monash_weather,amazon/chronos-t5-tiny,0.8504899606521996,0.14787595319625085 28 | nn5,amazon/chronos-t5-tiny,0.7021735456568664,0.19071330483289695 29 | -------------------------------------------------------------------------------- /scripts/evaluation/results/seasonal-naive-in-domain.csv: -------------------------------------------------------------------------------- 1 | dataset,model,MASE,WQL 2 | electricity_15min,seasonal-naive,0.4978697476132387,0.1169378163151378 3 | m4_daily,seasonal-naive,3.278424323759728,0.0279332664832445 4 | m4_hourly,seasonal-naive,1.1932105781333862,0.0483091941403194 5 | m4_monthly,seasonal-naive,1.2597170386001693,0.1455332906092934 6 | m4_weekly,seasonal-naive,2.777295109814942,0.0633986476090776 7 | monash_electricity_hourly,seasonal-naive,1.839634785956572,0.1468968206229902 8 | monash_electricity_weekly,seasonal-naive,3.0371656285424,0.1979332504059267 9 | monash_kdd_cup_2018,seasonal-naive,0.9943785889052376,0.5555856702439576 10 | monash_london_smart_meters,seasonal-naive,0.9661872287141056,0.5413187715028914 11 | monash_pedestrian_counts,seasonal-naive,0.3691951941442247,0.3185271550430794 12 | monash_rideshare,seasonal-naive,1.2495987545425715,0.1860080644135506 13 | monash_temperature_rain,seasonal-naive,2.243384627173123,1.4244854980220072 14 | taxi_30min,seasonal-naive,1.160268631066241,0.4711417890926274 15 | uber_tlc_daily,seasonal-naive,1.37803447078482,0.2313550175912078 16 | uber_tlc_hourly,seasonal-naive,0.930916273455971,0.298849044501192 17 | -------------------------------------------------------------------------------- /scripts/evaluation/results/seasonal-naive-zero-shot.csv: -------------------------------------------------------------------------------- 1 | dataset,model,MASE,WQL 2 | ETTh,seasonal-naive,0.9316203114697056,0.1220896585205886 3 | ETTm,seasonal-naive,1.1693053852270578,0.1413480385734046 4 | dominick,seasonal-naive,0.8706150115348875,0.4529164093744346 5 | ercot,seasonal-naive,0.7613354813741452,0.0366036447606282 6 | exchange_rate,seasonal-naive,1.7401824286954128,0.0129841406759913 7 | m4_quarterly,seasonal-naive,1.6022471766126911,0.1186484661559648 8 | m4_yearly,seasonal-naive,3.974360261259571,0.1614389663357925 9 | m5,seasonal-naive,1.399206213076729,1.0240883478068443 10 | monash_australian_electricity,seasonal-naive,1.2533189641227642,0.0836951323308387 11 | monash_car_parts,seasonal-naive,1.2014638390969912,1.5999522140809177 12 | monash_cif_2016,seasonal-naive,1.289290577415544,0.0150830409089921 13 | monash_covid_deaths,seasonal-naive,46.91239825526407,0.1330848762571827 14 | monash_fred_md,seasonal-naive,1.1008000463101226,0.1222237702571737 15 | monash_hospital,seasonal-naive,0.9205278266364826,0.0726263373268254 16 | monash_m1_monthly,seasonal-naive,1.3144614957646543,0.1914632595030148 17 | monash_m1_quarterly,seasonal-naive,2.077536550805995,0.1495022062865622 18 | monash_m1_yearly,seasonal-naive,4.894322225232431,0.2092955931101782 19 | monash_m3_monthly,seasonal-naive,1.1462045758327934,0.1485446007554992 20 | monash_m3_quarterly,seasonal-naive,1.425343793700714,0.1012520529806161 21 | monash_m3_yearly,seasonal-naive,3.1717102364409517,0.1665329650420048 22 | monash_nn5_weekly,seasonal-naive,1.0628482559107015,0.1226908962169196 23 | monash_tourism_monthly,seasonal-naive,1.630939994944413,0.1041824322151567 24 | 
monash_tourism_quarterly,seasonal-naive,1.6989892627474672,0.1193750169177449 25 | monash_tourism_yearly,seasonal-naive,3.5520097206480883,0.2091826587673241 26 | monash_traffic,seasonal-naive,1.0767397173107436,0.3618532196990004 27 | monash_weather,seasonal-naive,1.0038475713182748,0.2165947349654047 28 | nn5,seasonal-naive,1.2917285866431214,0.4246208074843067 29 | -------------------------------------------------------------------------------- /scripts/kernel-synth.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import argparse 5 | import functools 6 | from pathlib import Path 7 | from typing import Optional 8 | 9 | import numpy as np 10 | from gluonts.dataset.arrow import ArrowWriter 11 | from joblib import Parallel, delayed 12 | from sklearn.gaussian_process import GaussianProcessRegressor 13 | from sklearn.gaussian_process.kernels import ( 14 | RBF, 15 | ConstantKernel, 16 | DotProduct, 17 | ExpSineSquared, 18 | Kernel, 19 | RationalQuadratic, 20 | WhiteKernel, 21 | ) 22 | from tqdm.auto import tqdm 23 | 24 | LENGTH = 1024 25 | KERNEL_BANK = [ 26 | ExpSineSquared(periodicity=24 / LENGTH), # H 27 | ExpSineSquared(periodicity=48 / LENGTH), # 0.5H 28 | ExpSineSquared(periodicity=96 / LENGTH), # 0.25H 29 | ExpSineSquared(periodicity=24 * 7 / LENGTH), # H 30 | ExpSineSquared(periodicity=48 * 7 / LENGTH), # 0.5H 31 | ExpSineSquared(periodicity=96 * 7 / LENGTH), # 0.25H 32 | ExpSineSquared(periodicity=7 / LENGTH), # D 33 | ExpSineSquared(periodicity=14 / LENGTH), # 0.5D 34 | ExpSineSquared(periodicity=30 / LENGTH), # D 35 | ExpSineSquared(periodicity=60 / LENGTH), # 0.5D 36 | ExpSineSquared(periodicity=365 / LENGTH), # D 37 | ExpSineSquared(periodicity=365 * 2 / LENGTH), # 0.5D 38 | ExpSineSquared(periodicity=4 / LENGTH), # W 39 | ExpSineSquared(periodicity=26 / LENGTH), # W 40 | ExpSineSquared(periodicity=52 / LENGTH), # W 41 | ExpSineSquared(periodicity=4 / LENGTH), # M 42 | ExpSineSquared(periodicity=6 / LENGTH), # M 43 | ExpSineSquared(periodicity=12 / LENGTH), # M 44 | ExpSineSquared(periodicity=4 / LENGTH), # Q 45 | ExpSineSquared(periodicity=4 * 10 / LENGTH), # Q 46 | ExpSineSquared(periodicity=10 / LENGTH), # Y 47 | DotProduct(sigma_0=0.0), 48 | DotProduct(sigma_0=1.0), 49 | DotProduct(sigma_0=10.0), 50 | RBF(length_scale=0.1), 51 | RBF(length_scale=1.0), 52 | RBF(length_scale=10.0), 53 | RationalQuadratic(alpha=0.1), 54 | RationalQuadratic(alpha=1.0), 55 | RationalQuadratic(alpha=10.0), 56 | WhiteKernel(noise_level=0.1), 57 | WhiteKernel(noise_level=1.0), 58 | ConstantKernel(), 59 | ] 60 | 61 | 62 | def random_binary_map(a: Kernel, b: Kernel): 63 | """ 64 | Applies a random binary operator (+ or *) with equal probability 65 | on kernels ``a`` and ``b``. 66 | 67 | Parameters 68 | ---------- 69 | a 70 | A GP kernel. 71 | b 72 | A GP kernel. 73 | 74 | Returns 75 | ------- 76 | The composite kernel `a + b` or `a * b`. 77 | """ 78 | binary_maps = [lambda x, y: x + y, lambda x, y: x * y] 79 | return np.random.choice(binary_maps)(a, b) 80 | 81 | 82 | def sample_from_gp_prior( 83 | kernel: Kernel, X: np.ndarray, random_seed: Optional[int] = None 84 | ): 85 | """ 86 | Draw a sample from a GP prior. 87 | 88 | Parameters 89 | ---------- 90 | kernel 91 | The GP covaraince kernel. 92 | X 93 | The input "time" points. 94 | random_seed, optional 95 | The random seed for sampling, by default None. 
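    Examples
    --------
    A small sketch of how this function is used by ``generate_time_series`` below
    (``LENGTH``, ``KERNEL_BANK`` and ``random_binary_map`` are defined above)::

        X = np.linspace(0, 1, LENGTH)
        kernel = functools.reduce(
            random_binary_map, np.random.choice(KERNEL_BANK, 3, replace=True)
        )
        ts = sample_from_gp_prior(kernel=kernel, X=X)  # shape (LENGTH, 1)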
96 | 97 | Returns 98 | ------- 99 | A time series sampled from the GP prior. 100 | """ 101 | if X.ndim == 1: 102 | X = X[:, None] 103 | 104 | assert X.ndim == 2 105 | gpr = GaussianProcessRegressor(kernel=kernel) 106 | ts = gpr.sample_y(X, n_samples=1, random_state=random_seed) 107 | 108 | return ts 109 | 110 | 111 | def sample_from_gp_prior_efficient( 112 | kernel: Kernel, 113 | X: np.ndarray, 114 | random_seed: Optional[int] = None, 115 | method: str = "eigh", 116 | ): 117 | """ 118 | Draw a sample from a GP prior. An efficient version that allows specification 119 | of the sampling method. The default sampling method used in GaussianProcessRegressor 120 | is based on SVD which is significantly slower that alternatives such as `eigh` and 121 | `cholesky`. 122 | 123 | Parameters 124 | ---------- 125 | kernel 126 | The GP covaraince kernel. 127 | X 128 | The input "time" points. 129 | random_seed, optional 130 | The random seed for sampling, by default None. 131 | method, optional 132 | The sampling method for multivariate_normal, by default `eigh`. 133 | 134 | Returns 135 | ------- 136 | A time series sampled from the GP prior. 137 | """ 138 | if X.ndim == 1: 139 | X = X[:, None] 140 | 141 | assert X.ndim == 2 142 | 143 | cov = kernel(X) 144 | ts = np.random.default_rng(seed=random_seed).multivariate_normal( 145 | mean=np.zeros(X.shape[0]), cov=cov, method=method 146 | ) 147 | 148 | return ts 149 | 150 | 151 | def generate_time_series(max_kernels: int = 5): 152 | """Generate a synthetic time series from KernelSynth. 153 | 154 | Parameters 155 | ---------- 156 | max_kernels, optional 157 | The maximum number of base kernels to use for each time series, by default 5 158 | 159 | Returns 160 | ------- 161 | A time series generated by KernelSynth. 162 | """ 163 | while True: 164 | X = np.linspace(0, 1, LENGTH) 165 | 166 | # Randomly select upto max_kernels kernels from the KERNEL_BANK 167 | selected_kernels = np.random.choice( 168 | KERNEL_BANK, np.random.randint(1, max_kernels + 1), replace=True 169 | ) 170 | 171 | # Combine the sampled kernels using random binary operators 172 | kernel = functools.reduce(random_binary_map, selected_kernels) 173 | 174 | # Sample a time series from the GP prior 175 | try: 176 | ts = sample_from_gp_prior(kernel=kernel, X=X) 177 | except np.linalg.LinAlgError as err: 178 | print("Error caught:", err) 179 | continue 180 | 181 | # The timestamp is arbitrary 182 | return {"start": np.datetime64("2000-01-01 00:00", "s"), "target": ts.squeeze()} 183 | 184 | 185 | if __name__ == "__main__": 186 | parser = argparse.ArgumentParser() 187 | parser.add_argument("-N", "--num-series", type=int, default=1000_000) 188 | parser.add_argument("-J", "--max-kernels", type=int, default=5) 189 | args = parser.parse_args() 190 | path = Path(__file__).parent / "kernelsynth-data.arrow" 191 | 192 | generated_dataset = Parallel(n_jobs=-1)( 193 | delayed(generate_time_series)(max_kernels=args.max_kernels) 194 | for _ in tqdm(range(args.num_series)) 195 | ) 196 | 197 | ArrowWriter(compression="lz4").write_to_file( 198 | generated_dataset, 199 | path=path, 200 | ) 201 | -------------------------------------------------------------------------------- /scripts/training/configs/chronos-gpt2.yaml: -------------------------------------------------------------------------------- 1 | training_data_paths: 2 | - "/home/ubuntu/tsmixup-data.arrow" 3 | - "/home/ubuntu/kernelsynth-data.arrow" 4 | probability: 5 | - 0.9 6 | - 0.1 7 | context_length: 512 8 | prediction_length: 64 9 | min_past: 60 10 | 
max_steps: 200_000 11 | save_steps: 100_000 12 | log_steps: 500 13 | per_device_train_batch_size: 32 14 | learning_rate: 0.001 15 | optim: adamw_torch_fused 16 | num_samples: 20 17 | shuffle_buffer_length: 100_000 18 | gradient_accumulation_steps: 1 19 | model_id: openai-community/gpt2 20 | model_type: causal 21 | random_init: false 22 | tie_embeddings: false 23 | output_dir: ./output/ 24 | tf32: true 25 | torch_compile: true 26 | tokenizer_class: "MeanScaleUniformBins" 27 | tokenizer_kwargs: 28 | low_limit: -15.0 29 | high_limit: 15.0 30 | n_tokens: 4096 31 | lr_scheduler_type: linear 32 | warmup_ratio: 0.0 33 | dataloader_num_workers: 1 34 | max_missing_prop: 0.1 35 | use_eos_token: true 36 | -------------------------------------------------------------------------------- /scripts/training/configs/chronos-t5-base.yaml: -------------------------------------------------------------------------------- 1 | training_data_paths: 2 | - "/home/ubuntu/tsmixup-data.arrow" 3 | - "/home/ubuntu/kernelsynth-data.arrow" 4 | probability: 5 | - 0.9 6 | - 0.1 7 | context_length: 512 8 | prediction_length: 64 9 | min_past: 60 10 | max_steps: 200_000 11 | save_steps: 100_000 12 | log_steps: 500 13 | per_device_train_batch_size: 32 14 | learning_rate: 0.001 15 | optim: adamw_torch_fused 16 | num_samples: 20 17 | shuffle_buffer_length: 100_000 18 | gradient_accumulation_steps: 1 19 | model_id: google/t5-efficient-base 20 | model_type: seq2seq 21 | random_init: true 22 | tie_embeddings: true 23 | output_dir: ./output/ 24 | tf32: true 25 | torch_compile: true 26 | tokenizer_class: "MeanScaleUniformBins" 27 | tokenizer_kwargs: 28 | low_limit: -15.0 29 | high_limit: 15.0 30 | n_tokens: 4096 31 | lr_scheduler_type: linear 32 | warmup_ratio: 0.0 33 | dataloader_num_workers: 1 34 | max_missing_prop: 0.9 35 | use_eos_token: true 36 | -------------------------------------------------------------------------------- /scripts/training/configs/chronos-t5-large.yaml: -------------------------------------------------------------------------------- 1 | training_data_paths: 2 | - "/home/ubuntu/tsmixup-data.arrow" 3 | - "/home/ubuntu/kernelsynth-data.arrow" 4 | probability: 5 | - 0.9 6 | - 0.1 7 | context_length: 512 8 | prediction_length: 64 9 | min_past: 60 10 | max_steps: 200_000 11 | save_steps: 100_000 12 | log_steps: 500 13 | per_device_train_batch_size: 8 14 | learning_rate: 0.001 15 | optim: adamw_torch_fused 16 | num_samples: 20 17 | shuffle_buffer_length: 100_000 18 | gradient_accumulation_steps: 4 19 | model_id: google/t5-efficient-large 20 | model_type: seq2seq 21 | random_init: true 22 | tie_embeddings: true 23 | output_dir: ./output/ 24 | tf32: true 25 | torch_compile: true 26 | tokenizer_class: "MeanScaleUniformBins" 27 | tokenizer_kwargs: 28 | low_limit: -15.0 29 | high_limit: 15.0 30 | n_tokens: 4096 31 | lr_scheduler_type: linear 32 | warmup_ratio: 0.0 33 | dataloader_num_workers: 1 34 | max_missing_prop: 0.9 35 | use_eos_token: true 36 | -------------------------------------------------------------------------------- /scripts/training/configs/chronos-t5-mini.yaml: -------------------------------------------------------------------------------- 1 | training_data_paths: 2 | - "/home/ubuntu/tsmixup-data.arrow" 3 | - "/home/ubuntu/kernelsynth-data.arrow" 4 | probability: 5 | - 0.9 6 | - 0.1 7 | context_length: 512 8 | prediction_length: 64 9 | min_past: 60 10 | max_steps: 200_000 11 | save_steps: 100_000 12 | log_steps: 500 13 | per_device_train_batch_size: 32 14 | learning_rate: 0.001 15 | optim: 
adamw_torch_fused 16 | num_samples: 20 17 | shuffle_buffer_length: 100_000 18 | gradient_accumulation_steps: 1 19 | model_id: google/t5-efficient-mini 20 | model_type: seq2seq 21 | random_init: true 22 | tie_embeddings: true 23 | output_dir: ./output/ 24 | tf32: true 25 | torch_compile: true 26 | tokenizer_class: "MeanScaleUniformBins" 27 | tokenizer_kwargs: 28 | low_limit: -15.0 29 | high_limit: 15.0 30 | n_tokens: 4096 31 | lr_scheduler_type: linear 32 | warmup_ratio: 0.0 33 | dataloader_num_workers: 1 34 | max_missing_prop: 0.9 35 | use_eos_token: true 36 | -------------------------------------------------------------------------------- /scripts/training/configs/chronos-t5-small.yaml: -------------------------------------------------------------------------------- 1 | training_data_paths: 2 | - "/home/ubuntu/tsmixup-data.arrow" 3 | - "/home/ubuntu/kernelsynth-data.arrow" 4 | probability: 5 | - 0.9 6 | - 0.1 7 | context_length: 512 8 | prediction_length: 64 9 | min_past: 60 10 | max_steps: 200_000 11 | save_steps: 100_000 12 | log_steps: 500 13 | per_device_train_batch_size: 32 14 | learning_rate: 0.001 15 | optim: adamw_torch_fused 16 | num_samples: 20 17 | shuffle_buffer_length: 100_000 18 | gradient_accumulation_steps: 1 19 | model_id: google/t5-efficient-small 20 | model_type: seq2seq 21 | random_init: true 22 | tie_embeddings: true 23 | output_dir: ./output/ 24 | tf32: true 25 | torch_compile: true 26 | tokenizer_class: "MeanScaleUniformBins" 27 | tokenizer_kwargs: 28 | low_limit: -15.0 29 | high_limit: 15.0 30 | n_tokens: 4096 31 | lr_scheduler_type: linear 32 | warmup_ratio: 0.0 33 | dataloader_num_workers: 1 34 | max_missing_prop: 0.9 35 | use_eos_token: true 36 | -------------------------------------------------------------------------------- /scripts/training/configs/chronos-t5-tiny.yaml: -------------------------------------------------------------------------------- 1 | training_data_paths: 2 | - "/home/ubuntu/tsmixup-data.arrow" 3 | - "/home/ubuntu/kernelsynth-data.arrow" 4 | probability: 5 | - 0.9 6 | - 0.1 7 | context_length: 512 8 | prediction_length: 64 9 | min_past: 60 10 | max_steps: 200_000 11 | save_steps: 100_000 12 | log_steps: 500 13 | per_device_train_batch_size: 32 14 | learning_rate: 0.001 15 | optim: adamw_torch_fused 16 | num_samples: 20 17 | shuffle_buffer_length: 100_000 18 | gradient_accumulation_steps: 1 19 | model_id: google/t5-efficient-tiny 20 | model_type: seq2seq 21 | random_init: true 22 | tie_embeddings: true 23 | output_dir: ./output/ 24 | tf32: true 25 | torch_compile: true 26 | tokenizer_class: "MeanScaleUniformBins" 27 | tokenizer_kwargs: 28 | low_limit: -15.0 29 | high_limit: 15.0 30 | n_tokens: 4096 31 | lr_scheduler_type: linear 32 | warmup_ratio: 0.0 33 | dataloader_num_workers: 1 34 | max_missing_prop: 0.9 35 | use_eos_token: true 36 | -------------------------------------------------------------------------------- /scripts/training/train.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
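# Usage sketch (inferred from the typer + typer_config CLI defined in main() below;
# the config path is one of the files under scripts/training/configs):
#
#     python scripts/training/train.py \
#         --config scripts/training/configs/chronos-t5-tiny.yaml
#
# Individual options should also be overridable on the command line (e.g.
# `--max-steps 1000`), since the YAML only supplies defaults via typer_config.
# Multi-GPU (DDP) launches via `torchrun` are expected to work as well: the script
# checks `torch.distributed.is_torchelastic_launched()` and only logs and saves
# checkpoints on the main process.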
2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import ast 5 | import logging 6 | import os 7 | import re 8 | import sys 9 | import json 10 | import itertools 11 | import random 12 | from copy import deepcopy 13 | from pathlib import Path 14 | from functools import partial 15 | from typing import List, Iterator, Optional, Dict 16 | 17 | import typer 18 | from typer_config import use_yaml_config 19 | import numpy as np 20 | import torch 21 | import torch.distributed as dist 22 | from torch.utils.data import IterableDataset, get_worker_info 23 | import transformers 24 | from transformers import ( 25 | AutoModelForSeq2SeqLM, 26 | AutoModelForCausalLM, 27 | AutoConfig, 28 | T5Config, 29 | Trainer, 30 | TrainingArguments, 31 | ) 32 | import accelerate 33 | import gluonts 34 | from gluonts.dataset.common import FileDataset 35 | from gluonts.itertools import Cyclic, Map, Filter 36 | from gluonts.transform import ( 37 | FilterTransformation, 38 | TestSplitSampler, 39 | ValidationSplitSampler, 40 | InstanceSplitter, 41 | ExpectedNumInstanceSampler, 42 | MissingValueImputation, 43 | LeavesMissingValues, 44 | LastValueImputation, 45 | ) 46 | 47 | from chronos import ChronosConfig, ChronosTokenizer 48 | 49 | 50 | app = typer.Typer(pretty_exceptions_enable=False) 51 | 52 | 53 | def is_main_process() -> bool: 54 | """ 55 | Check if we're on the main process. 56 | """ 57 | if not dist.is_torchelastic_launched(): 58 | return True 59 | return int(os.environ["RANK"]) == 0 60 | 61 | 62 | def log_on_main(msg: str, logger: logging.Logger, log_level: int = logging.INFO): 63 | """ 64 | Log the given message using the given logger, if we're on the main process. 65 | """ 66 | if is_main_process(): 67 | logger.log(log_level, msg) 68 | 69 | 70 | def get_training_job_info() -> Dict: 71 | """ 72 | Returns info about this training job. 73 | """ 74 | job_info = {} 75 | 76 | # CUDA info 77 | job_info["cuda_available"] = torch.cuda.is_available() 78 | if torch.cuda.is_available(): 79 | job_info["device_count"] = torch.cuda.device_count() 80 | 81 | job_info["device_names"] = { 82 | idx: torch.cuda.get_device_name(idx) 83 | for idx in range(torch.cuda.device_count()) 84 | } 85 | job_info["mem_info"] = { 86 | idx: torch.cuda.mem_get_info(device=idx) 87 | for idx in range(torch.cuda.device_count()) 88 | } 89 | 90 | # DDP info 91 | job_info["torchelastic_launched"] = dist.is_torchelastic_launched() 92 | 93 | if dist.is_torchelastic_launched(): 94 | job_info["world_size"] = dist.get_world_size() 95 | 96 | # Versions 97 | job_info["python_version"] = sys.version.replace("\n", " ") 98 | job_info["torch_version"] = torch.__version__ 99 | job_info["numpy_version"] = np.__version__ 100 | job_info["gluonts_version"] = gluonts.__version__ 101 | job_info["transformers_version"] = transformers.__version__ 102 | job_info["accelerate_version"] = accelerate.__version__ 103 | 104 | return job_info 105 | 106 | 107 | def save_training_info(ckpt_path: Path, training_config: Dict): 108 | """ 109 | Save info about this training job in a json file for documentation. 110 | """ 111 | assert ckpt_path.is_dir() 112 | with open(ckpt_path / "training_info.json", "w") as fp: 113 | json.dump( 114 | {"training_config": training_config, "job_info": get_training_job_info()}, 115 | fp, 116 | indent=4, 117 | ) 118 | 119 | 120 | def get_next_path( 121 | base_fname: str, 122 | base_dir: Path, 123 | file_type: str = "yaml", 124 | separator: str = "-", 125 | ): 126 | """ 127 | Gets the next available path in a directory. 
For example, if `base_fname="results"` 128 | and `base_dir` has files ["results-0.yaml", "results-1.yaml"], this function returns 129 | "results-2.yaml". 130 | """ 131 | if file_type == "": 132 | # Directory 133 | items = filter( 134 | lambda x: x.is_dir() and re.match(f"^{base_fname}{separator}\\d+$", x.stem), 135 | base_dir.glob("*"), 136 | ) 137 | else: 138 | # File 139 | items = filter( 140 | lambda x: re.match(f"^{base_fname}{separator}\\d+$", x.stem), 141 | base_dir.glob(f"*.{file_type}"), 142 | ) 143 | run_nums = list( 144 | map(lambda x: int(x.stem.replace(base_fname + separator, "")), items) 145 | ) + [-1] 146 | 147 | next_num = max(run_nums) + 1 148 | fname = f"{base_fname}{separator}{next_num}" + ( 149 | f".{file_type}" if file_type != "" else "" 150 | ) 151 | 152 | return base_dir / fname 153 | 154 | 155 | def load_model( 156 | model_id="google/t5-efficient-tiny", 157 | model_type="seq2seq", 158 | vocab_size=4096, 159 | random_init=False, 160 | tie_embeddings=False, 161 | pad_token_id=0, 162 | eos_token_id=1, 163 | ): 164 | """ 165 | Load the specified HuggingFace model, adjusting the vocabulary 166 | size, special token IDs, and initialization options. 167 | 168 | This allows to set a model up for training on a new vocabulary 169 | of tokens. 170 | """ 171 | assert model_type in ["seq2seq", "causal"] 172 | AutoModelClass = ( 173 | AutoModelForSeq2SeqLM if model_type == "seq2seq" else AutoModelForCausalLM 174 | ) 175 | if random_init: 176 | log_on_main("Using random initialization", logger) 177 | config = AutoConfig.from_pretrained(model_id) 178 | if isinstance(config, T5Config): 179 | # The default initializer_factor (1.0) in transformers is too large 180 | config.initializer_factor = 0.05 181 | config.tie_word_embeddings = tie_embeddings 182 | model = AutoModelClass.from_config(config) 183 | else: 184 | log_on_main(f"Using pretrained initialization from {model_id}", logger) 185 | model = AutoModelClass.from_pretrained(model_id) 186 | 187 | model.resize_token_embeddings(vocab_size) 188 | 189 | model.config.pad_token_id = model.generation_config.pad_token_id = pad_token_id 190 | model.config.eos_token_id = model.generation_config.eos_token_id = eos_token_id 191 | 192 | return model 193 | 194 | 195 | def has_enough_observations( 196 | entry: dict, min_length: int = 0, max_missing_prop: float = 1.0 197 | ) -> bool: 198 | """ 199 | Check if the given entry has enough observations in the ``"target"`` attribute. 200 | 201 | Parameters 202 | ---------- 203 | entry 204 | The data entry (dictionary) to be tested. 205 | min_length 206 | The minimum length the ``"target"`` attribute must have. 207 | max_missing_prop 208 | The maximum proportion of missing data allowed in the ``"target"`` 209 | attribute. 210 | """ 211 | if ( 212 | len(entry["target"]) >= min_length 213 | and np.isnan(entry["target"]).mean() <= max_missing_prop 214 | ): 215 | return True 216 | return False 217 | 218 | 219 | class PseudoShuffledIterableDataset(IterableDataset): 220 | """ 221 | Shuffle entries from an iterable by temporarily accumulating them 222 | in an intermediate buffer. 223 | 224 | Parameters 225 | ---------- 226 | base_dataset 227 | The original iterable object, representing the dataset. 228 | shuffle_buffer_length 229 | Size of the buffer use to shuffle entries from the base dataset. 
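    Examples
    --------
    A tiny illustrative run (any iterable can serve as ``base_dataset``):

    >>> ds = PseudoShuffledIterableDataset(range(10), shuffle_buffer_length=4)
    >>> sorted(ds) == list(range(10))  # same elements, pseudo-shuffled order
    True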
230 | """ 231 | 232 | def __init__(self, base_dataset, shuffle_buffer_length: int = 100) -> None: 233 | super().__init__() 234 | self.base_dataset = base_dataset 235 | self.shuffle_buffer_length = shuffle_buffer_length 236 | self.generator = torch.Generator() 237 | 238 | def __iter__(self): 239 | shuffle_buffer = [] 240 | 241 | for element in self.base_dataset: 242 | shuffle_buffer.append(element) 243 | if len(shuffle_buffer) >= self.shuffle_buffer_length: 244 | idx = torch.randint( 245 | len(shuffle_buffer), size=(), generator=self.generator 246 | ) 247 | yield shuffle_buffer.pop(idx) 248 | 249 | while shuffle_buffer: 250 | idx = torch.randint(len(shuffle_buffer), size=(), generator=self.generator) 251 | yield shuffle_buffer.pop(idx) 252 | 253 | 254 | class ShuffleMixin: 255 | """ 256 | Mix-in class that datasets can inherit from to get 257 | shuffling functionality. 258 | """ 259 | 260 | def shuffle(self, shuffle_buffer_length: int = 100): 261 | return PseudoShuffledIterableDataset(self, shuffle_buffer_length) 262 | 263 | 264 | class ChronosDataset(IterableDataset, ShuffleMixin): 265 | """ 266 | Dataset wrapper, using a ``ChronosTokenizer`` to turn data from a time series 267 | into a HuggingFace-compatible set of ``input_ids``, ``attention_mask`` and 268 | ``labels``. 269 | 270 | Entries from the original datasets are assumed to have a ``"start"`` attribute 271 | (of type ``pd.Period``), and a ``"target"`` attribute (of type ``np.ndarray``). 272 | 273 | Parameters 274 | ---------- 275 | datasets 276 | Datasets containing the original time series data. 277 | probabilities 278 | In training mode, data will be sampled from each of the original datasets 279 | with these probabilities. 280 | tokenizer 281 | Tokenizer to be used to turn sequences of real numbers into token IDs. 282 | context_length 283 | Samples context will be limited to this length. 284 | prediction_length 285 | Samples labels will be limited to this length. 286 | drop_prob 287 | In training mode, observations from a sample will be turned into ``np.nan``, 288 | i.e. turned into missing values, with this probability. 289 | min_past 290 | Data samples will be considered only if there's at least ``min_past``-many 291 | historical observations. 292 | mode 293 | One of ``"training"``, ``"validation"``, or ``"test"``. 294 | np_dtype 295 | Numpy float data type. 
296 | """ 297 | 298 | def __init__( 299 | self, 300 | datasets: list, 301 | probabilities: List[float], 302 | tokenizer: ChronosTokenizer, 303 | context_length: int = 512, 304 | prediction_length: int = 64, 305 | drop_prob: float = 0.2, 306 | min_past: Optional[int] = None, 307 | model_type: str = "seq2seq", 308 | imputation_method: Optional[MissingValueImputation] = None, 309 | mode: str = "training", 310 | np_dtype=np.float32, 311 | ) -> None: 312 | super().__init__() 313 | 314 | assert len(probabilities) == len(datasets) 315 | assert mode in ("training", "validation", "test") 316 | assert model_type in ("seq2seq", "causal") 317 | 318 | self.datasets = datasets 319 | self.probabilities = probabilities 320 | self.tokenizer = tokenizer 321 | self.context_length = context_length 322 | self.prediction_length = prediction_length 323 | self.drop_prob = drop_prob if model_type == "seq2seq" else 0.0 324 | self.min_past = min_past or prediction_length 325 | self.model_type = model_type 326 | self.imputation_method = imputation_method or LeavesMissingValues() 327 | self.mode = mode 328 | self.np_dtype = np_dtype 329 | 330 | def preprocess_entry(self, entry: dict, mode: str) -> dict: 331 | entry = {f: entry[f] for f in ["start", "target"]} 332 | entry["target"] = np.asarray(entry["target"], dtype=self.np_dtype) 333 | assert entry["target"].ndim == 1, f"got {entry['target'].ndim=}, expected 1" 334 | 335 | if self.model_type == "causal": 336 | # Causal models do not play nice with missing values, so it is 337 | # recommended to use an imputation method, e.g., LastValueImputation 338 | entry["target"] = self.imputation_method(entry["target"]) 339 | 340 | if mode == "training" and self.drop_prob > 0: 341 | target = entry["target"].copy() 342 | drop_p = np.random.uniform(low=0.0, high=self.drop_prob) 343 | mask = np.random.choice( 344 | [True, False], size=len(target), p=[drop_p, 1 - drop_p] 345 | ) 346 | target[mask] = np.nan 347 | entry["target"] = target 348 | 349 | return entry 350 | 351 | def _create_instance_splitter(self, mode: str): 352 | assert mode in ["training", "test", "validation"] 353 | 354 | instance_sampler = { 355 | "training": ExpectedNumInstanceSampler( 356 | num_instances=1.0, 357 | min_instances=1, 358 | min_past=self.min_past, 359 | min_future=self.prediction_length, 360 | ), 361 | "test": TestSplitSampler(), 362 | "validation": ValidationSplitSampler(min_future=self.prediction_length), 363 | }[mode] 364 | 365 | return InstanceSplitter( 366 | target_field="target", 367 | is_pad_field="is_pad", 368 | start_field="start", 369 | forecast_start_field="forecast_start", 370 | instance_sampler=instance_sampler, 371 | past_length=self.context_length, 372 | future_length=self.prediction_length, 373 | dummy_value=np.nan, 374 | ) 375 | 376 | def create_training_data(self, data): 377 | data = Cyclic(data) 378 | split_transform = self._create_instance_splitter( 379 | "training" 380 | ) + FilterTransformation( 381 | condition=lambda entry: (~np.isnan(entry["past_target"])).sum() > 0 382 | ) 383 | data = split_transform.apply(data, is_train=True) 384 | return data 385 | 386 | def create_test_data(self, data): 387 | data = self._create_instance_splitter("test").apply(data, is_train=False) 388 | return data 389 | 390 | def create_validation_data(self, data): 391 | data = self._create_instance_splitter("validation").apply(data, is_train=False) 392 | return data 393 | 394 | def to_hf_format(self, entry: dict) -> dict: 395 | past_target = torch.tensor(entry["past_target"]).unsqueeze(0) 396 | 
input_ids, attention_mask, scale = self.tokenizer.context_input_transform( 397 | past_target 398 | ) 399 | future_target = torch.tensor(entry["future_target"]).unsqueeze(0) 400 | labels, labels_mask = self.tokenizer.label_input_transform(future_target, scale) 401 | labels[labels_mask == 0] = -100 402 | 403 | if self.model_type == "causal": 404 | # The InstanceSplitter pads time series on the left to be equal to the 405 | # context_length. However, certain models (e.g., GPT2) with absolute 406 | # position embeddings should not be trained with left padding. 407 | # The following piece of code moves padding from left to right. 408 | 409 | assert input_ids.shape[-1] == entry["past_is_pad"].shape[0] 410 | 411 | # Find the index where padding starts 412 | pad_start_idx = np.searchsorted(1 - entry["past_is_pad"], 1) 413 | padded_input_ids, obs_input_ids = torch.tensor_split( 414 | input_ids, [pad_start_idx], dim=-1 415 | ) 416 | padded_attention_mask, obs_attention_mask = torch.tensor_split( 417 | attention_mask, [pad_start_idx], dim=-1 418 | ) 419 | 420 | # Move padding to the right 421 | input_ids = torch.cat( 422 | [ 423 | obs_input_ids, 424 | labels, 425 | padded_input_ids, 426 | ], 427 | axis=-1, 428 | ) 429 | attention_mask = torch.cat( 430 | [ 431 | obs_attention_mask, 432 | labels_mask, 433 | padded_attention_mask, 434 | ], 435 | axis=-1, 436 | ) 437 | 438 | # labels for causal models are same as the input_ids. 439 | # Internally transformers shifts the labels by one during training. 440 | labels = input_ids.clone() 441 | input_ids[~attention_mask] = self.tokenizer.config.pad_token_id 442 | labels[~attention_mask] = -100 443 | 444 | return { 445 | "input_ids": input_ids.squeeze(0), 446 | "attention_mask": attention_mask.squeeze(0), 447 | "labels": labels.squeeze(0), 448 | } 449 | 450 | def __iter__(self) -> Iterator: 451 | preprocessed_datasets = [ 452 | Map( 453 | partial(self.preprocess_entry, mode=self.mode), 454 | dataset, 455 | ) 456 | for dataset in self.datasets 457 | ] 458 | 459 | if self.mode == "training": 460 | iterables = [ 461 | self.create_training_data(dataset) for dataset in preprocessed_datasets 462 | ] 463 | elif self.mode == "test": 464 | iterables = [ 465 | self.create_test_data(dataset) for dataset in preprocessed_datasets 466 | ] 467 | else: 468 | iterables = [ 469 | self.create_validation_data(dataset) 470 | for dataset in preprocessed_datasets 471 | ] 472 | 473 | worker_info = get_worker_info() 474 | if worker_info is None: 475 | probs = list(self.probabilities) 476 | else: 477 | worker_id = worker_info.id 478 | num_workers = worker_info.num_workers 479 | iterables = list(itertools.islice(iterables, worker_id, None, num_workers)) 480 | probs = list( 481 | itertools.islice(self.probabilities, worker_id, None, num_workers) 482 | ) 483 | 484 | probs = [prob / sum(probs) for prob in probs] 485 | 486 | iterators = list(map(iter, iterables)) 487 | if self.mode == "training": 488 | while True: 489 | idx = np.random.choice(range(len(iterators)), p=probs) 490 | try: 491 | yield self.to_hf_format(next(iterators[idx])) 492 | except StopIteration: 493 | probs[idx] = 0 494 | if sum(probs) == 0: 495 | return 496 | probs = [prob / sum(probs) for prob in probs] 497 | else: 498 | for entry in itertools.chain(*iterators): 499 | yield self.to_hf_format(entry) 500 | 501 | 502 | @app.command() 503 | @use_yaml_config(param_name="config") 504 | def main( 505 | training_data_paths: str, 506 | probability: Optional[str] = None, 507 | context_length: int = 512, 508 | prediction_length: int 
= 64, 509 | min_past: int = 64, 510 | max_steps: int = 200_000, 511 | save_steps: int = 50_000, 512 | log_steps: int = 500, 513 | per_device_train_batch_size: int = 32, 514 | learning_rate: float = 1e-3, 515 | optim: str = "adamw_torch_fused", 516 | shuffle_buffer_length: int = 100, 517 | gradient_accumulation_steps: int = 2, 518 | model_id: str = "google/t5-efficient-tiny", 519 | model_type: str = "seq2seq", 520 | random_init: bool = False, 521 | tie_embeddings: bool = False, 522 | output_dir: str = "./output/", 523 | tf32: bool = True, 524 | torch_compile: bool = True, 525 | tokenizer_class: str = "MeanScaleUniformBins", 526 | tokenizer_kwargs: str = "{'low_limit': -15.0, 'high_limit': 15.0}", 527 | n_tokens: int = 4096, 528 | n_special_tokens: int = 2, 529 | pad_token_id: int = 0, 530 | eos_token_id: int = 1, 531 | use_eos_token: bool = True, 532 | lr_scheduler_type: str = "linear", 533 | warmup_ratio: float = 0.0, 534 | dataloader_num_workers: int = 1, 535 | max_missing_prop: float = 0.9, 536 | num_samples: int = 20, 537 | temperature: float = 1.0, 538 | top_k: int = 50, 539 | top_p: float = 1.0, 540 | seed: Optional[int] = None, 541 | ): 542 | if tf32 and not ( 543 | torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8 544 | ): 545 | # TF32 floating point format is available only on NVIDIA GPUs 546 | # with compute capability 8 and above. See link for details. 547 | # https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#compute-capability-8-x 548 | log_on_main( 549 | "TF32 format is only available on devices with compute capability >= 8. " 550 | "Setting tf32 to False.", 551 | logger, 552 | ) 553 | tf32 = False 554 | 555 | if seed is None: 556 | seed = random.randint(0, 2**32) 557 | 558 | log_on_main(f"Using SEED: {seed}", logger) 559 | transformers.set_seed(seed=seed) 560 | 561 | raw_training_config = deepcopy(locals()) 562 | output_dir = Path(output_dir) 563 | training_data_paths = ast.literal_eval(training_data_paths) 564 | assert isinstance(training_data_paths, list) 565 | 566 | if isinstance(probability, str): 567 | probability = ast.literal_eval(probability) 568 | elif probability is None: 569 | probability = [1.0 / len(training_data_paths)] * len(training_data_paths) 570 | assert isinstance(probability, list) 571 | 572 | assert len(training_data_paths) == len(probability) 573 | 574 | if dataloader_num_workers > len(training_data_paths): 575 | log_on_main( 576 | f"Setting the number of data loader workers to {len(training_data_paths)}, " 577 | f"instead of {dataloader_num_workers}.", 578 | logger, 579 | ) 580 | dataloader_num_workers = len(training_data_paths) 581 | 582 | if isinstance(tokenizer_kwargs, str): 583 | tokenizer_kwargs = ast.literal_eval(tokenizer_kwargs) 584 | assert isinstance(tokenizer_kwargs, dict) 585 | 586 | assert model_type in ["seq2seq", "causal"] 587 | 588 | output_dir = get_next_path("run", base_dir=output_dir, file_type="") 589 | 590 | log_on_main(f"Logging dir: {output_dir}", logger) 591 | log_on_main( 592 | f"Loading and filtering {len(training_data_paths)} datasets " 593 | f"for training: {training_data_paths}", 594 | logger, 595 | ) 596 | 597 | log_on_main( 598 | f"Mixing probabilities: {probability}", 599 | logger, 600 | ) 601 | 602 | train_datasets = [ 603 | Filter( 604 | partial( 605 | has_enough_observations, 606 | min_length=min_past + prediction_length, 607 | max_missing_prop=max_missing_prop, 608 | ), 609 | FileDataset(path=Path(data_path), freq="h"), 610 | ) 611 | for data_path in training_data_paths 612 | ] 
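    # Illustrative check of the filter above (not part of the training flow): with
    # the CLI defaults (min_past=64, prediction_length=64, max_missing_prop=0.9),
    # a series is kept only if it has at least 128 observations and at most 90%
    # missing values. For example:
    #
    #   has_enough_observations(
    #       {"target": np.concatenate([np.random.rand(120), np.full(10, np.nan)])},
    #       min_length=128,
    #       max_missing_prop=0.9,
    #   )  # -> True: 130 points, ~7.7% missing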
613 | 614 | log_on_main("Initializing model", logger) 615 | 616 | model = load_model( 617 | model_id=model_id, 618 | model_type=model_type, 619 | vocab_size=n_tokens, 620 | random_init=random_init, 621 | tie_embeddings=tie_embeddings, 622 | pad_token_id=pad_token_id, 623 | eos_token_id=eos_token_id, 624 | ) 625 | 626 | chronos_config = ChronosConfig( 627 | tokenizer_class=tokenizer_class, 628 | tokenizer_kwargs=tokenizer_kwargs, 629 | n_tokens=n_tokens, 630 | n_special_tokens=n_special_tokens, 631 | pad_token_id=pad_token_id, 632 | eos_token_id=eos_token_id, 633 | use_eos_token=use_eos_token, 634 | model_type=model_type, 635 | context_length=context_length, 636 | prediction_length=prediction_length, 637 | num_samples=num_samples, 638 | temperature=temperature, 639 | top_k=top_k, 640 | top_p=top_p, 641 | ) 642 | 643 | # Add extra items to model config so that it's saved in the ckpt 644 | model.config.chronos_config = chronos_config.__dict__ 645 | 646 | shuffled_train_dataset = ChronosDataset( 647 | datasets=train_datasets, 648 | probabilities=probability, 649 | tokenizer=chronos_config.create_tokenizer(), 650 | context_length=context_length, 651 | prediction_length=prediction_length, 652 | min_past=min_past, 653 | model_type=model_type, 654 | imputation_method=LastValueImputation() if model_type == "causal" else None, 655 | mode="training", 656 | ).shuffle(shuffle_buffer_length=shuffle_buffer_length) 657 | 658 | # Define training args 659 | training_args = TrainingArguments( 660 | output_dir=str(output_dir), 661 | per_device_train_batch_size=per_device_train_batch_size, 662 | learning_rate=learning_rate, 663 | lr_scheduler_type=lr_scheduler_type, 664 | warmup_ratio=warmup_ratio, 665 | optim=optim, 666 | logging_dir=str(output_dir / "logs"), 667 | logging_strategy="steps", 668 | logging_steps=log_steps, 669 | save_strategy="steps", 670 | save_steps=save_steps, 671 | report_to=["tensorboard"], 672 | max_steps=max_steps, 673 | gradient_accumulation_steps=gradient_accumulation_steps, 674 | dataloader_num_workers=dataloader_num_workers, 675 | tf32=tf32, # remove this if not using Ampere GPUs (e.g., A100) 676 | torch_compile=torch_compile, 677 | ddp_find_unused_parameters=False, 678 | remove_unused_columns=False, 679 | ) 680 | 681 | # Create Trainer instance 682 | trainer = Trainer( 683 | model=model, 684 | args=training_args, 685 | train_dataset=shuffled_train_dataset, 686 | ) 687 | log_on_main("Training", logger) 688 | 689 | trainer.train() 690 | 691 | if is_main_process(): 692 | model.save_pretrained(output_dir / "checkpoint-final") 693 | save_training_info( 694 | output_dir / "checkpoint-final", training_config=raw_training_config 695 | ) 696 | 697 | 698 | if __name__ == "__main__": 699 | logging.basicConfig(format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") 700 | logger = logging.getLogger(__file__) 701 | logger.setLevel(logging.INFO) 702 | app() 703 | -------------------------------------------------------------------------------- /src/chronos/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
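# Example usage of the public API re-exported below (the checkpoint id is an
# assumed published Chronos model; any compatible checkpoint works):
#
#   import torch
#   from chronos import BaseChronosPipeline
#
#   pipeline = BaseChronosPipeline.from_pretrained(
#       "amazon/chronos-t5-small", torch_dtype=torch.float32
#   )
#   quantiles, mean = pipeline.predict_quantiles(
#       torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0]),
#       prediction_length=12,
#       quantile_levels=[0.1, 0.5, 0.9],
#   )
#   # quantiles: shape (1, 12, 3); mean: shape (1, 12)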
2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | from .base import BaseChronosPipeline, ForecastType 5 | from .chronos import ( 6 | ChronosConfig, 7 | ChronosModel, 8 | ChronosPipeline, 9 | ChronosTokenizer, 10 | MeanScaleUniformBins, 11 | ) 12 | from .chronos_bolt import ChronosBoltConfig, ChronosBoltPipeline 13 | 14 | __all__ = [ 15 | "BaseChronosPipeline", 16 | "ForecastType", 17 | "ChronosConfig", 18 | "ChronosModel", 19 | "ChronosPipeline", 20 | "ChronosTokenizer", 21 | "MeanScaleUniformBins", 22 | "ChronosBoltConfig", 23 | "ChronosBoltPipeline", 24 | ] 25 | -------------------------------------------------------------------------------- /src/chronos/base.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | # Authors: Caner Turkmen , Abdul Fatir Ansari , Lorenzo Stella 5 | # Original source: 6 | # https://github.com/autogluon/autogluon/blob/f57beb26cb769c6e0d484a6af2b89eab8aee73a8/timeseries/src/autogluon/timeseries/models/chronos/pipeline/base.py 7 | 8 | from enum import Enum 9 | from pathlib import Path 10 | from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union 11 | 12 | import torch 13 | 14 | if TYPE_CHECKING: 15 | from transformers import PreTrainedModel 16 | 17 | from .utils import left_pad_and_stack_1D 18 | 19 | 20 | class ForecastType(Enum): 21 | SAMPLES = "samples" 22 | QUANTILES = "quantiles" 23 | 24 | 25 | class PipelineRegistry(type): 26 | REGISTRY: Dict[str, "PipelineRegistry"] = {} 27 | 28 | def __new__(cls, name, bases, attrs): 29 | """See, https://github.com/faif/python-patterns.""" 30 | new_cls = type.__new__(cls, name, bases, attrs) 31 | if name is not None: 32 | cls.REGISTRY[name] = new_cls 33 | 34 | return new_cls 35 | 36 | 37 | class BaseChronosPipeline(metaclass=PipelineRegistry): 38 | forecast_type: ForecastType 39 | dtypes = {"bfloat16": torch.bfloat16, "float32": torch.float32} 40 | 41 | def __init__(self, inner_model: "PreTrainedModel"): 42 | """ 43 | Parameters 44 | ---------- 45 | inner_model : PreTrainedModel 46 | A hugging-face transformers PreTrainedModel, e.g., T5ForConditionalGeneration 47 | """ 48 | # for easy access to the inner HF-style model 49 | self.inner_model = inner_model 50 | 51 | def _prepare_and_validate_context( 52 | self, context: Union[torch.Tensor, List[torch.Tensor]] 53 | ): 54 | if isinstance(context, list): 55 | context = left_pad_and_stack_1D(context) 56 | assert isinstance(context, torch.Tensor) 57 | if context.ndim == 1: 58 | context = context.unsqueeze(0) 59 | assert context.ndim == 2 60 | 61 | return context 62 | 63 | def predict( 64 | self, 65 | context: Union[torch.Tensor, List[torch.Tensor]], 66 | prediction_length: Optional[int] = None, 67 | **kwargs, 68 | ): 69 | """ 70 | Get forecasts for the given time series. Predictions will be 71 | returned in fp32 on the cpu. 72 | 73 | Parameters 74 | ---------- 75 | context 76 | Input series. This is either a 1D tensor, or a list 77 | of 1D tensors, or a 2D tensor whose first dimension 78 | is batch. In the latter case, use left-padding with 79 | ``torch.nan`` to align series of different lengths. 80 | prediction_length 81 | Time steps to predict. Defaults to a model-dependent 82 | value if not given. 83 | 84 | Returns 85 | ------- 86 | forecasts 87 | Tensor containing forecasts. The layout and meaning 88 | of the forecasts values depends on ``self.forecast_type``. 
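        For ``ChronosPipeline`` (sample-based) the result is shaped
        (batch_size, num_samples, prediction_length); quantile-based pipelines
        such as ``ChronosBoltPipeline`` return quantile forecasts instead.

        Examples
        --------
        A sketch of a typical call on a loaded pipeline (the checkpoint id is an
        assumed published model)::

            import torch
            from chronos import BaseChronosPipeline

            pipeline = BaseChronosPipeline.from_pretrained(
                "amazon/chronos-t5-tiny", torch_dtype=torch.float32
            )
            forecast = pipeline.predict(
                torch.arange(20, dtype=torch.float32), prediction_length=8
            )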
89 | """ 90 | raise NotImplementedError() 91 | 92 | def predict_quantiles( 93 | self, 94 | context: Union[torch.Tensor, List[torch.Tensor]], 95 | prediction_length: Optional[int] = None, 96 | quantile_levels: List[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], 97 | **kwargs, 98 | ) -> Tuple[torch.Tensor, torch.Tensor]: 99 | """ 100 | Get quantile and mean forecasts for given time series. 101 | Predictions will be returned in fp32 on the cpu. 102 | 103 | Parameters 104 | ---------- 105 | context : Union[torch.Tensor, List[torch.Tensor]] 106 | Input series. This is either a 1D tensor, or a list 107 | of 1D tensors, or a 2D tensor whose first dimension 108 | is batch. In the latter case, use left-padding with 109 | ``torch.nan`` to align series of different lengths. 110 | prediction_length : Optional[int], optional 111 | Time steps to predict. Defaults to a model-dependent 112 | value if not given. 113 | quantile_levels : List[float], optional 114 | Quantile levels to compute, by default [0.1, 0.2, ..., 0.9] 115 | 116 | Returns 117 | ------- 118 | quantiles 119 | Tensor containing quantile forecasts. Shape 120 | (batch_size, prediction_length, num_quantiles) 121 | mean 122 | Tensor containing mean (point) forecasts. Shape 123 | (batch_size, prediction_length) 124 | """ 125 | raise NotImplementedError() 126 | 127 | @classmethod 128 | def from_pretrained( 129 | cls, 130 | pretrained_model_name_or_path: Union[str, Path], 131 | *model_args, 132 | **kwargs, 133 | ): 134 | """ 135 | Load the model, either from a local path or from the HuggingFace Hub. 136 | Supports the same arguments as ``AutoConfig`` and ``AutoModel`` 137 | from ``transformers``. 138 | """ 139 | from transformers import AutoConfig 140 | 141 | torch_dtype = kwargs.get("torch_dtype", "auto") 142 | if torch_dtype != "auto" and isinstance(torch_dtype, str): 143 | kwargs["torch_dtype"] = cls.dtypes[torch_dtype] 144 | 145 | config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) 146 | is_valid_config = hasattr(config, "chronos_pipeline_class") or hasattr( 147 | config, "chronos_config" 148 | ) 149 | 150 | if not is_valid_config: 151 | raise ValueError("Not a Chronos config file") 152 | 153 | pipeline_class_name = getattr( 154 | config, "chronos_pipeline_class", "ChronosPipeline" 155 | ) 156 | class_ = PipelineRegistry.REGISTRY.get(pipeline_class_name) 157 | if class_ is None: 158 | raise ValueError( 159 | f"Trying to load unknown pipeline class: {pipeline_class_name}" 160 | ) 161 | 162 | return class_.from_pretrained( # type: ignore[attr-defined] 163 | pretrained_model_name_or_path, *model_args, **kwargs 164 | ) 165 | -------------------------------------------------------------------------------- /src/chronos/chronos.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | # Authors: Abdul Fatir Ansari , Lorenzo Stella , Caner Turkmen 5 | 6 | import logging 7 | from dataclasses import dataclass 8 | from typing import Any, Dict, List, Literal, Optional, Tuple, Union 9 | 10 | import torch 11 | import torch.nn as nn 12 | from transformers import ( 13 | AutoConfig, 14 | AutoModelForCausalLM, 15 | AutoModelForSeq2SeqLM, 16 | GenerationConfig, 17 | PreTrainedModel, 18 | ) 19 | 20 | import chronos 21 | from chronos.base import BaseChronosPipeline, ForecastType 22 | from chronos.utils import left_pad_and_stack_1D 23 | 24 | logger = logging.getLogger(__file__) 25 | 26 | 27 | @dataclass 28 | class ChronosConfig: 29 | """ 30 | This class holds all the configuration parameters to be used 31 | by ``ChronosTokenizer`` and ``ChronosModel``. 32 | """ 33 | 34 | tokenizer_class: str 35 | tokenizer_kwargs: Dict[str, Any] 36 | context_length: int 37 | prediction_length: int 38 | n_tokens: int 39 | n_special_tokens: int 40 | pad_token_id: int 41 | eos_token_id: int 42 | use_eos_token: bool 43 | model_type: Literal["causal", "seq2seq"] 44 | num_samples: int 45 | temperature: float 46 | top_k: int 47 | top_p: float 48 | 49 | def __post_init__(self): 50 | assert ( 51 | self.pad_token_id < self.n_special_tokens 52 | and self.eos_token_id < self.n_special_tokens 53 | ), f"Special token id's must be smaller than {self.n_special_tokens=}" 54 | 55 | def create_tokenizer(self) -> "ChronosTokenizer": 56 | class_ = getattr(chronos, self.tokenizer_class) 57 | return class_(**self.tokenizer_kwargs, config=self) 58 | 59 | 60 | class ChronosTokenizer: 61 | """ 62 | A ``ChronosTokenizer`` definines how time series are mapped into token IDs 63 | and back. 64 | 65 | For details, see the ``input_transform`` and ``output_transform`` methods, 66 | which concrete classes must implement. 67 | """ 68 | 69 | def context_input_transform( 70 | self, 71 | context: torch.Tensor, 72 | ) -> Tuple: 73 | """ 74 | Turn a batch of time series into token IDs, attention map, and tokenizer_state. 75 | 76 | Parameters 77 | ---------- 78 | context 79 | A tensor shaped (batch_size, time_length), containing the 80 | timeseries to forecast. Use left-padding with ``torch.nan`` 81 | to align time series of different lengths. 82 | 83 | Returns 84 | ------- 85 | token_ids 86 | A tensor of integers, shaped (batch_size, time_length + 1) 87 | if ``config.use_eos_token`` and (batch_size, time_length) 88 | otherwise, containing token IDs for the input series. 89 | attention_mask 90 | A boolean tensor, same shape as ``token_ids``, indicating 91 | which input observations are not ``torch.nan`` (i.e. not 92 | missing nor padding). 93 | tokenizer_state 94 | An object that can be passed to ``label_input_transform`` 95 | and ``output_transform``. Contains the relevant information 96 | to decode output samples into real values, 97 | such as location and scale parameters. 98 | """ 99 | raise NotImplementedError() 100 | 101 | def label_input_transform(self, label: torch.Tensor, tokenizer_state: Any) -> Tuple: 102 | """ 103 | Turn a batch of label slices of time series into token IDs and attention map 104 | using the ``tokenizer_state`` provided by ``context_input_transform``. 105 | 106 | Parameters 107 | ---------- 108 | context 109 | A tensor shaped (batch_size, time_length), containing the 110 | timeseries to forecast. Use left-padding with ``torch.nan`` 111 | to align time series of different lengths. 
112 | tokenizer_state 113 | An object returned by ``context_input_transform`` containing 114 | relevant information to preprocess data, such as location and 115 | scale. The nature of this depends on the specific tokenizer. 116 | This is used for tokenizing the label, in order to use the same 117 | scaling used to tokenize the context. 118 | 119 | Returns 120 | ------- 121 | token_ids 122 | A tensor of integers, shaped (batch_size, time_length + 1) 123 | if ``config.use_eos_token`` and (batch_size, time_length) 124 | otherwise, containing token IDs for the input series. 125 | attention_mask 126 | A boolean tensor, same shape as ``token_ids``, indicating 127 | which input observations are not ``torch.nan`` (i.e. not 128 | missing nor padding). 129 | """ 130 | raise NotImplementedError() 131 | 132 | def output_transform( 133 | self, samples: torch.Tensor, tokenizer_state: Any 134 | ) -> torch.Tensor: 135 | """ 136 | Turn a batch of sample token IDs into real values. 137 | 138 | Parameters 139 | ---------- 140 | samples 141 | A tensor of integers, shaped (batch_size, num_samples, time_length), 142 | containing token IDs of sample trajectories. 143 | tokenizer_state 144 | An object returned by ``input_transform`` containing 145 | relevant context to decode samples, such as location and scale. 146 | The nature of this depends on the specific tokenizer. 147 | 148 | Returns 149 | ------- 150 | forecasts 151 | A real tensor, shaped (batch_size, num_samples, time_length), 152 | containing forecasted sample paths. 153 | """ 154 | raise NotImplementedError() 155 | 156 | 157 | class MeanScaleUniformBins(ChronosTokenizer): 158 | def __init__( 159 | self, low_limit: float, high_limit: float, config: ChronosConfig 160 | ) -> None: 161 | self.config = config 162 | self.centers = torch.linspace( 163 | low_limit, 164 | high_limit, 165 | config.n_tokens - config.n_special_tokens - 1, 166 | ) 167 | self.boundaries = torch.concat( 168 | ( 169 | torch.tensor([-1e20], device=self.centers.device), 170 | (self.centers[1:] + self.centers[:-1]) / 2, 171 | torch.tensor([1e20], device=self.centers.device), 172 | ) 173 | ) 174 | 175 | def _input_transform( 176 | self, context: torch.Tensor, scale: Optional[torch.Tensor] = None 177 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 178 | context = context.to(dtype=torch.float32) 179 | attention_mask = ~torch.isnan(context) 180 | 181 | if scale is None: 182 | scale = torch.nansum( 183 | torch.abs(context) * attention_mask, dim=-1 184 | ) / torch.nansum(attention_mask, dim=-1) 185 | scale[~(scale > 0)] = 1.0 186 | 187 | scaled_context = context / scale.unsqueeze(dim=-1) 188 | token_ids = ( 189 | torch.bucketize( 190 | input=scaled_context, 191 | boundaries=self.boundaries, 192 | # buckets are open to the right, see: 193 | # https://pytorch.org/docs/2.1/generated/torch.bucketize.html#torch-bucketize 194 | right=True, 195 | ) 196 | + self.config.n_special_tokens 197 | ) 198 | 199 | token_ids.clamp_(0, self.config.n_tokens - 1) 200 | 201 | token_ids[~attention_mask] = self.config.pad_token_id 202 | 203 | return token_ids, attention_mask, scale 204 | 205 | def _append_eos_token( 206 | self, token_ids: torch.Tensor, attention_mask: torch.Tensor 207 | ) -> Tuple[torch.Tensor, torch.Tensor]: 208 | batch_size = token_ids.shape[0] 209 | eos_tokens = torch.full((batch_size, 1), fill_value=self.config.eos_token_id) 210 | token_ids = torch.concat((token_ids, eos_tokens), dim=1) 211 | eos_mask = torch.full((batch_size, 1), fill_value=True) 212 | attention_mask = 
torch.concat((attention_mask, eos_mask), dim=1) 213 | 214 | return token_ids, attention_mask 215 | 216 | def context_input_transform( 217 | self, context: torch.Tensor 218 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 219 | length = context.shape[-1] 220 | 221 | if length > self.config.context_length: 222 | context = context[..., -self.config.context_length :] 223 | 224 | token_ids, attention_mask, scale = self._input_transform(context=context) 225 | 226 | if self.config.use_eos_token and self.config.model_type == "seq2seq": 227 | token_ids, attention_mask = self._append_eos_token( 228 | token_ids=token_ids, attention_mask=attention_mask 229 | ) 230 | 231 | return token_ids, attention_mask, scale 232 | 233 | def label_input_transform( 234 | self, label: torch.Tensor, scale: torch.Tensor 235 | ) -> Tuple[torch.Tensor, torch.Tensor]: 236 | length = label.shape[-1] 237 | 238 | assert length == self.config.prediction_length 239 | token_ids, attention_mask, _ = self._input_transform(context=label, scale=scale) 240 | 241 | if self.config.use_eos_token: 242 | token_ids, attention_mask = self._append_eos_token( 243 | token_ids=token_ids, attention_mask=attention_mask 244 | ) 245 | 246 | return token_ids, attention_mask 247 | 248 | def output_transform( 249 | self, samples: torch.Tensor, scale: torch.Tensor 250 | ) -> torch.Tensor: 251 | scale_unsqueezed = scale.unsqueeze(-1).unsqueeze(-1) 252 | indices = torch.clamp( 253 | samples - self.config.n_special_tokens - 1, 254 | min=0, 255 | max=len(self.centers) - 1, 256 | ) 257 | return self.centers[indices] * scale_unsqueezed 258 | 259 | 260 | class ChronosModel(nn.Module): 261 | """ 262 | A ``ChronosModel`` wraps a ``PreTrainedModel`` object from ``transformers`` 263 | and uses it to predict sample paths for time series tokens. 264 | 265 | Parameters 266 | ---------- 267 | config 268 | The configuration to use. 269 | model 270 | The pretrained model to use. 271 | """ 272 | 273 | def __init__(self, config: ChronosConfig, model: PreTrainedModel) -> None: 274 | super().__init__() 275 | self.config = config 276 | self.model = model 277 | 278 | @property 279 | def device(self): 280 | return self.model.device 281 | 282 | def encode( 283 | self, 284 | input_ids: torch.Tensor, 285 | attention_mask: torch.Tensor, 286 | ): 287 | """ 288 | Extract the encoder embedding for the given token sequences. 289 | 290 | Parameters 291 | ---------- 292 | input_ids 293 | Tensor of indices of input sequence tokens in the vocabulary 294 | with shape (batch_size, sequence_length). 295 | attention_mask 296 | A mask tensor of the same shape as input_ids to avoid attending 297 | on padding or missing tokens. 298 | 299 | Returns 300 | ------- 301 | embedding 302 | A tensor of encoder embeddings with shape 303 | (batch_size, sequence_length, d_model). 304 | """ 305 | assert ( 306 | self.config.model_type == "seq2seq" 307 | ), "Encoder embeddings are only supported for encoder-decoder models" 308 | assert hasattr(self.model, "encoder") 309 | 310 | return self.model.encoder( 311 | input_ids=input_ids, attention_mask=attention_mask 312 | ).last_hidden_state 313 | 314 | def forward( 315 | self, 316 | input_ids: torch.Tensor, 317 | attention_mask: torch.Tensor, 318 | prediction_length: Optional[int] = None, 319 | num_samples: Optional[int] = None, 320 | temperature: Optional[float] = None, 321 | top_k: Optional[int] = None, 322 | top_p: Optional[float] = None, 323 | ) -> torch.Tensor: 324 | """ 325 | Predict future sample tokens for the given token sequences. 
326 | 327 | Arguments ``prediction_length``, ``num_samples``, ``temperature``, 328 | ``top_k``, ``top_p`` can be used to customize the model inference, 329 | and default to the corresponding attributes in ``self.config`` if 330 | not provided. 331 | 332 | Returns 333 | ------- 334 | samples 335 | A tensor of integers, shaped (batch_size, num_samples, time_length), 336 | containing forecasted sample paths. 337 | """ 338 | if prediction_length is None: 339 | prediction_length = self.config.prediction_length 340 | if num_samples is None: 341 | num_samples = self.config.num_samples 342 | if temperature is None: 343 | temperature = self.config.temperature 344 | if top_k is None: 345 | top_k = self.config.top_k 346 | if top_p is None: 347 | top_p = self.config.top_p 348 | 349 | assert hasattr(self.model, "generate") 350 | 351 | preds = self.model.generate( 352 | input_ids=input_ids, 353 | attention_mask=attention_mask, 354 | generation_config=GenerationConfig( 355 | min_new_tokens=prediction_length, 356 | max_new_tokens=prediction_length, 357 | do_sample=True, 358 | num_return_sequences=num_samples, 359 | eos_token_id=self.config.eos_token_id, 360 | pad_token_id=self.config.pad_token_id, 361 | temperature=temperature, 362 | top_k=top_k, 363 | top_p=top_p, 364 | ), 365 | ) 366 | 367 | if self.config.model_type == "seq2seq": 368 | preds = preds[..., 1:] # remove the decoder start token 369 | else: 370 | assert self.config.model_type == "causal" 371 | assert preds.size(-1) == input_ids.size(-1) + prediction_length 372 | preds = preds[..., -prediction_length:] 373 | 374 | return preds.reshape(input_ids.size(0), num_samples, -1) 375 | 376 | 377 | class ChronosPipeline(BaseChronosPipeline): 378 | """ 379 | A ``ChronosPipeline`` uses the given tokenizer and model to forecast 380 | input time series. 381 | 382 | Use the ``from_pretrained`` class method to load serialized models. 383 | Use the ``predict`` method to get forecasts. 384 | 385 | Parameters 386 | ---------- 387 | tokenizer 388 | The tokenizer object to use. 389 | model 390 | The model to use. 391 | """ 392 | 393 | tokenizer: ChronosTokenizer 394 | model: ChronosModel 395 | forecast_type: ForecastType = ForecastType.SAMPLES 396 | 397 | def __init__(self, tokenizer, model): 398 | super().__init__(inner_model=model.model) 399 | self.tokenizer = tokenizer 400 | self.model = model 401 | 402 | def _prepare_and_validate_context( 403 | self, context: Union[torch.Tensor, List[torch.Tensor]] 404 | ): 405 | if isinstance(context, list): 406 | context = left_pad_and_stack_1D(context) 407 | assert isinstance(context, torch.Tensor) 408 | if context.ndim == 1: 409 | context = context.unsqueeze(0) 410 | assert context.ndim == 2 411 | 412 | return context 413 | 414 | @torch.no_grad() 415 | def embed( 416 | self, context: Union[torch.Tensor, List[torch.Tensor]] 417 | ) -> Tuple[torch.Tensor, Any]: 418 | """ 419 | Get encoder embeddings for the given time series. 420 | 421 | Parameters 422 | ---------- 423 | context 424 | Input series. This is either a 1D tensor, or a list 425 | of 1D tensors, or a 2D tensor whose first dimension 426 | is batch. In the latter case, use left-padding with 427 | ``torch.nan`` to align series of different lengths. 428 | 429 | Returns 430 | ------- 431 | embeddings, tokenizer_state 432 | A tuple of two tensors: the encoder embeddings and the tokenizer_state, 433 | e.g., the scale of the time series in the case of mean scaling. 
434 | The encoder embeddings are shaped (batch_size, context_length, d_model) 435 | or (batch_size, context_length + 1, d_model), where context_length 436 | is the size of the context along the time axis if a 2D tensor was provided, 437 | or the length of the longest time series if a list of 1D tensors was 438 | provided; the extra 1 accounts for the EOS token. 439 | """ 440 | context_tensor = self._prepare_and_validate_context(context=context) 441 | token_ids, attention_mask, tokenizer_state = ( 442 | self.tokenizer.context_input_transform(context_tensor) 443 | ) 444 | embeddings = self.model.encode( 445 | input_ids=token_ids.to(self.model.device), 446 | attention_mask=attention_mask.to(self.model.device), 447 | ).cpu() 448 | return embeddings, tokenizer_state 449 | 450 | def predict( # type: ignore[override] 451 | self, 452 | context: Union[torch.Tensor, List[torch.Tensor]], 453 | prediction_length: Optional[int] = None, 454 | num_samples: Optional[int] = None, 455 | temperature: Optional[float] = None, 456 | top_k: Optional[int] = None, 457 | top_p: Optional[float] = None, 458 | limit_prediction_length: bool = False, 459 | ) -> torch.Tensor: 460 | """ 461 | Get forecasts for the given time series. 462 | 463 | Refer to the base method (``BaseChronosPipeline.predict``) 464 | for details on shared parameters. 465 | 466 | Additional parameters 467 | --------------------- 468 | num_samples 469 | Number of sample paths to predict. Defaults to what is 470 | specified in ``self.model.config``. 471 | temperature 472 | Temperature to use for generating sample tokens. 473 | Defaults to what is specified in ``self.model.config``. 474 | top_k 475 | Top-k parameter to use for generating sample tokens. 476 | Defaults to what is specified in ``self.model.config``. 477 | top_p 478 | Top-p parameter to use for generating sample tokens. 479 | Defaults to what is specified in ``self.model.config``. 480 | limit_prediction_length 481 | Force the prediction length to be smaller than or equal to the 482 | model's built-in prediction length. False by 483 | default. When true, fail loudly if longer predictions 484 | are requested; otherwise, longer predictions are allowed. 485 | 486 | Returns 487 | ------- 488 | samples 489 | Tensor of sample forecasts, of shape 490 | (batch_size, num_samples, prediction_length). 491 | """ 492 | context_tensor = self._prepare_and_validate_context(context=context) 493 | 494 | if prediction_length is None: 495 | prediction_length = self.model.config.prediction_length 496 | 497 | if prediction_length > self.model.config.prediction_length: 498 | msg = ( 499 | f"We recommend keeping prediction length <= {self.model.config.prediction_length}. " 500 | "The quality of longer predictions may degrade since the model is not optimized for it. " 501 | ) 502 | if limit_prediction_length: 503 | msg += "You can turn off this check by setting `limit_prediction_length=False`."
504 | raise ValueError(msg) 505 | logger.warning(msg) 506 | 507 | predictions = [] 508 | remaining = prediction_length 509 | 510 | while remaining > 0: 511 | token_ids, attention_mask, scale = self.tokenizer.context_input_transform( 512 | context_tensor 513 | ) 514 | samples = self.model( 515 | token_ids.to(self.model.device), 516 | attention_mask.to(self.model.device), 517 | min(remaining, self.model.config.prediction_length), 518 | num_samples, 519 | temperature, 520 | top_k, 521 | top_p, 522 | ) 523 | prediction = self.tokenizer.output_transform( 524 | samples.to(scale.device), scale 525 | ) 526 | 527 | predictions.append(prediction) 528 | remaining -= prediction.shape[-1] 529 | 530 | if remaining <= 0: 531 | break 532 | 533 | context_tensor = torch.cat( 534 | [context_tensor, prediction.median(dim=1).values], dim=-1 535 | ) 536 | 537 | return torch.cat(predictions, dim=-1).to(dtype=torch.float32, device="cpu") 538 | 539 | def predict_quantiles( 540 | self, 541 | context: Union[torch.Tensor, List[torch.Tensor]], 542 | prediction_length: Optional[int] = None, 543 | quantile_levels: List[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], 544 | **predict_kwargs, 545 | ) -> Tuple[torch.Tensor, torch.Tensor]: 546 | """ 547 | Refer to the base method (``BaseChronosPipeline.predict_quantiles``). 548 | """ 549 | prediction_samples = ( 550 | self.predict(context, prediction_length=prediction_length, **predict_kwargs) 551 | .detach() 552 | .swapaxes(1, 2) 553 | ) 554 | mean = prediction_samples.mean(dim=-1) 555 | quantiles = torch.quantile( 556 | prediction_samples, 557 | q=torch.tensor(quantile_levels, dtype=prediction_samples.dtype), 558 | dim=-1, 559 | ).permute(1, 2, 0) 560 | 561 | return quantiles, mean 562 | 563 | @classmethod 564 | def from_pretrained(cls, *args, **kwargs): 565 | """ 566 | Load the model, either from a local path or from the HuggingFace Hub. 567 | Supports the same arguments as ``AutoConfig`` and ``AutoModel`` 568 | from ``transformers``. 569 | """ 570 | 571 | config = AutoConfig.from_pretrained(*args, **kwargs) 572 | 573 | assert hasattr(config, "chronos_config"), "Not a Chronos config file" 574 | 575 | chronos_config = ChronosConfig(**config.chronos_config) 576 | 577 | if chronos_config.model_type == "seq2seq": 578 | inner_model = AutoModelForSeq2SeqLM.from_pretrained(*args, **kwargs) 579 | else: 580 | assert chronos_config.model_type == "causal" 581 | inner_model = AutoModelForCausalLM.from_pretrained(*args, **kwargs) 582 | 583 | return cls( 584 | tokenizer=chronos_config.create_tokenizer(), 585 | model=ChronosModel(config=chronos_config, model=inner_model), 586 | ) 587 | -------------------------------------------------------------------------------- /src/chronos/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | 5 | from typing import List 6 | 7 | import torch 8 | 9 | 10 | def left_pad_and_stack_1D(tensors: List[torch.Tensor]) -> torch.Tensor: 11 | max_len = max(len(c) for c in tensors) 12 | padded = [] 13 | for c in tensors: 14 | assert isinstance(c, torch.Tensor) 15 | assert c.ndim == 1 16 | padding = torch.full( 17 | size=(max_len - len(c),), fill_value=torch.nan, device=c.device 18 | ) 19 | padded.append(torch.concat((padding, c), dim=-1)) 20 | return torch.stack(padded) 21 | -------------------------------------------------------------------------------- /test/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | -------------------------------------------------------------------------------- /test/dummy-chronos-bolt-model/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "ChronosBoltModelForForecasting" 4 | ], 5 | "chronos_config": { 6 | "context_length": 512, 7 | "input_patch_size": 16, 8 | "input_patch_stride": 16, 9 | "prediction_length": 64, 10 | "quantiles": [ 11 | 0.1, 12 | 0.2, 13 | 0.3, 14 | 0.4, 15 | 0.5, 16 | 0.6, 17 | 0.7, 18 | 0.8, 19 | 0.9 20 | ], 21 | "use_reg_token": true 22 | }, 23 | "chronos_pipeline_class": "ChronosBoltPipeline", 24 | "classifier_dropout": 0.0, 25 | "d_ff": 8, 26 | "d_kv": 4, 27 | "d_model": 8, 28 | "decoder_start_token_id": 0, 29 | "dense_act_fn": "relu", 30 | "dropout_rate": 0.1, 31 | "eos_token_id": 1, 32 | "feed_forward_proj": "relu", 33 | "initializer_factor": 0.05, 34 | "is_encoder_decoder": true, 35 | "is_gated_act": false, 36 | "layer_norm_epsilon": 1e-06, 37 | "model_type": "t5", 38 | "n_positions": 512, 39 | "num_decoder_layers": 4, 40 | "num_heads": 4, 41 | "num_layers": 4, 42 | "pad_token_id": 0, 43 | "reg_token_id": 1, 44 | "relative_attention_max_distance": 128, 45 | "relative_attention_num_buckets": 32, 46 | "torch_dtype": "float32", 47 | "transformers_version": "4.40.2", 48 | "use_cache": true, 49 | "vocab_size": 2 50 | } 51 | -------------------------------------------------------------------------------- /test/dummy-chronos-bolt-model/model.safetensors: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/chronos-forecasting/6a9c8dadac04eb85befc935043e3e2cce914267f/test/dummy-chronos-bolt-model/model.safetensors -------------------------------------------------------------------------------- /test/dummy-chronos-model/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "T5ForConditionalGeneration" 4 | ], 5 | "d_ff": 32, 6 | "d_kv": 16, 7 | "d_model": 64, 8 | "decoder_start_token_id": 0, 9 | "dense_act_fn": "relu", 10 | "dropout_rate": 0.1, 11 | "eos_token_id": 1, 12 | "feed_forward_proj": "relu", 13 | "initializer_factor": 0.05, 14 | "is_encoder_decoder": true, 15 | "is_gated_act": false, 16 | "layer_norm_epsilon": 1e-06, 17 | "model_type": "t5", 18 | "n_positions": 512, 19 | "num_decoder_layers": 1, 20 | "num_heads": 1, 21 | "num_layers": 1, 22 | "pad_token_id": 0, 23 | "relative_attention_max_distance": 128, 24 | "relative_attention_num_buckets": 32, 25 | "torch_dtype": "bfloat16", 26 | "transformers_version": "4.31.0", 27 | "use_cache": true, 28 | "vocab_size": 32, 29 | "chronos_config": { 30 | 
"tokenizer_class": "MeanScaleUniformBins", 31 | "tokenizer_kwargs": { 32 | "low_limit": -15.0, 33 | "high_limit": 15.0 34 | }, 35 | "n_tokens": 32, 36 | "n_special_tokens": 2, 37 | "pad_token_id": 0, 38 | "eos_token_id": 1, 39 | "use_eos_token": true, 40 | "model_type": "seq2seq", 41 | "context_length": 512, 42 | "prediction_length": 64, 43 | "num_samples": 20, 44 | "temperature": 1.0, 45 | "top_k": 50, 46 | "top_p": 1.0 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /test/dummy-chronos-model/generation_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_from_model_config": true, 3 | "decoder_start_token_id": 0, 4 | "eos_token_id": 1, 5 | "pad_token_id": 0, 6 | "transformers_version": "4.31.0" 7 | } 8 | -------------------------------------------------------------------------------- /test/dummy-chronos-model/pytorch_model.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/chronos-forecasting/6a9c8dadac04eb85befc935043e3e2cce914267f/test/dummy-chronos-model/pytorch_model.bin -------------------------------------------------------------------------------- /test/test_chronos.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | from pathlib import Path 5 | 6 | import pytest 7 | import torch 8 | 9 | from chronos import ( 10 | BaseChronosPipeline, 11 | ChronosConfig, 12 | ChronosPipeline, 13 | MeanScaleUniformBins, 14 | ) 15 | from test.util import validate_tensor 16 | 17 | 18 | def test_base_chronos_pipeline_loads_from_huggingface(): 19 | BaseChronosPipeline.from_pretrained("amazon/chronos-t5-tiny", device_map="cpu") 20 | 21 | 22 | @pytest.mark.parametrize("n_numerical_tokens", [5, 10, 27]) 23 | @pytest.mark.parametrize("n_special_tokens", [2, 5, 13]) 24 | def test_tokenizer_consistency(n_numerical_tokens: int, n_special_tokens: int): 25 | n_tokens = n_numerical_tokens + n_special_tokens 26 | 27 | config = ChronosConfig( 28 | tokenizer_class="MeanScaleUniformBins", 29 | tokenizer_kwargs=dict(low_limit=-1.0, high_limit=1.0), 30 | n_tokens=n_tokens, 31 | n_special_tokens=n_special_tokens, 32 | pad_token_id=0, 33 | eos_token_id=1, 34 | use_eos_token=True, 35 | model_type="seq2seq", 36 | context_length=512, 37 | prediction_length=64, 38 | num_samples=20, 39 | temperature=1.0, 40 | top_k=50, 41 | top_p=1.0, 42 | ) 43 | 44 | tokenizer = config.create_tokenizer() 45 | assert isinstance(tokenizer, MeanScaleUniformBins) 46 | 47 | context = tokenizer.centers.unsqueeze(0) # add batch dimension 48 | scale = torch.ones((1,)) # fix the scale to one to turn off scaling 49 | 50 | token_ids, _, _ = tokenizer._input_transform(context, scale=scale) 51 | 52 | samples = tokenizer.output_transform( 53 | token_ids.unsqueeze(1), # add sample dimension 54 | scale=scale, 55 | ) 56 | 57 | assert (samples[0, 0, :] == context).all() 58 | 59 | 60 | @pytest.mark.xfail 61 | @pytest.mark.parametrize("n_numerical_tokens", [5, 10, 27]) 62 | @pytest.mark.parametrize("n_special_tokens", [2, 5, 13]) 63 | @pytest.mark.parametrize("use_eos_token", [False, True]) 64 | def test_tokenizer_fixed_data( 65 | n_numerical_tokens: int, n_special_tokens: int, use_eos_token: bool 66 | ): 67 | n_tokens = n_numerical_tokens + n_special_tokens 68 | context_length = 3 69 | 70 | config = ChronosConfig( 71 | 
tokenizer_class="MeanScaleUniformBins", 72 | tokenizer_kwargs=dict(low_limit=-1.0, high_limit=1.0), 73 | n_tokens=n_tokens, 74 | n_special_tokens=n_special_tokens, 75 | pad_token_id=0, 76 | eos_token_id=1, 77 | use_eos_token=use_eos_token, 78 | model_type="seq2seq", 79 | context_length=512, 80 | prediction_length=64, 81 | num_samples=20, 82 | temperature=1.0, 83 | top_k=50, 84 | top_p=1.0, 85 | ) 86 | 87 | tokenizer = config.create_tokenizer() 88 | 89 | context = torch.tensor( 90 | [ 91 | [-3.7, 3.7], 92 | [-42.0, 42.0], 93 | ] 94 | ) 95 | batch_size, _ = context.shape 96 | 97 | token_ids, attention_mask, scale = tokenizer.context_input_transform(context) 98 | 99 | assert token_ids.shape == (batch_size, context_length + 1 * use_eos_token) 100 | assert all(token_ids[:, 0] == torch.tensor([0]).repeat(batch_size)) 101 | assert all(token_ids[:, 1] == torch.tensor([n_special_tokens]).repeat(batch_size)) 102 | assert all(token_ids[:, 2] == torch.tensor([n_tokens - 1]).repeat(batch_size)) 103 | 104 | if use_eos_token: 105 | assert all(token_ids[:, 3] == torch.tensor([1]).repeat(batch_size)) 106 | 107 | samples = tokenizer.output_transform( 108 | torch.arange(n_special_tokens, n_tokens).unsqueeze(0).repeat(batch_size, 1, 1), 109 | tokenizer_state=scale, 110 | ) 111 | 112 | assert (samples[:, 0, [0, -1]] == context).all() 113 | 114 | 115 | @pytest.mark.xfail 116 | @pytest.mark.parametrize("use_eos_token", [False, True]) 117 | def test_tokenizer_random_data(use_eos_token: bool): 118 | context_length = 8 119 | n_tokens = 256 120 | n_special_tokens = 2 121 | 122 | config = ChronosConfig( 123 | tokenizer_class="MeanScaleUniformBins", 124 | tokenizer_kwargs=dict(low_limit=-1.0, high_limit=1.0), 125 | n_tokens=n_tokens, 126 | n_special_tokens=n_special_tokens, 127 | pad_token_id=0, 128 | eos_token_id=1, 129 | use_eos_token=use_eos_token, 130 | model_type="seq2seq", 131 | context_length=context_length, 132 | prediction_length=64, 133 | num_samples=20, 134 | temperature=1.0, 135 | top_k=50, 136 | top_p=1.0, 137 | ) 138 | 139 | tokenizer = config.create_tokenizer() 140 | 141 | context = torch.tensor( 142 | [ 143 | [torch.nan, torch.nan, 1.0, 1.1, torch.nan, 2.0], 144 | [3.0, torch.nan, 3.9, 4.0, 4.1, 4.9], 145 | ] 146 | ) 147 | 148 | token_ids, attention_mask, scale = tokenizer.context_input_transform(context) 149 | 150 | assert token_ids.shape == ( 151 | *context.shape[:-1], 152 | context_length + 1 * use_eos_token, 153 | ) 154 | assert attention_mask.shape == ( 155 | *context.shape[:-1], 156 | context_length + 1 * use_eos_token, 157 | ) 158 | assert scale.shape == context.shape[:1] 159 | 160 | sample_ids = torch.randint(low=n_special_tokens, high=n_tokens, size=(2, 10, 4)) 161 | sample_ids[0, 0, 0] = n_special_tokens 162 | sample_ids[-1, -1, -1] = n_tokens - 1 163 | 164 | samples = tokenizer.output_transform(sample_ids, scale) 165 | 166 | assert samples.shape == (2, 10, 4) 167 | 168 | 169 | @pytest.mark.parametrize("model_dtype", [torch.float32, torch.bfloat16]) 170 | @pytest.mark.parametrize("input_dtype", [torch.float32, torch.bfloat16, torch.int64]) 171 | def test_pipeline_predict(model_dtype: torch.dtype, input_dtype: torch.dtype): 172 | pipeline = ChronosPipeline.from_pretrained( 173 | Path(__file__).parent / "dummy-chronos-model", 174 | device_map="cpu", 175 | torch_dtype=model_dtype, 176 | ) 177 | context = 10 * torch.rand(size=(4, 16)) + 10 178 | context = context.to(dtype=input_dtype) 179 | 180 | # input: tensor of shape (batch_size, context_length) 181 | 182 | samples = pipeline.predict(context, 
num_samples=12, prediction_length=3) 183 | validate_tensor(samples, shape=(4, 12, 3), dtype=torch.float32) 184 | 185 | with pytest.raises(ValueError): 186 | samples = pipeline.predict( 187 | context, num_samples=7, prediction_length=65, limit_prediction_length=True 188 | ) 189 | 190 | samples = pipeline.predict( 191 | context, num_samples=7, prediction_length=65, limit_prediction_length=False 192 | ) 193 | validate_tensor(samples, shape=(4, 7, 65), dtype=torch.float32) 194 | 195 | # input: batch_size-long list of tensors of shape (context_length,) 196 | 197 | samples = pipeline.predict(list(context), num_samples=12, prediction_length=3) 198 | validate_tensor(samples, shape=(4, 12, 3), dtype=torch.float32) 199 | 200 | with pytest.raises(ValueError): 201 | samples = pipeline.predict( 202 | list(context), 203 | num_samples=7, 204 | prediction_length=65, 205 | limit_prediction_length=True, 206 | ) 207 | 208 | samples = pipeline.predict( 209 | list(context), 210 | num_samples=7, 211 | prediction_length=65, 212 | limit_prediction_length=False, 213 | ) 214 | validate_tensor(samples, shape=(4, 7, 65), dtype=torch.float32) 215 | 216 | # input: tensor of shape (context_length,) 217 | 218 | samples = pipeline.predict(context[0, ...], num_samples=12, prediction_length=3) 219 | validate_tensor(samples, shape=(1, 12, 3), dtype=torch.float32) 220 | 221 | with pytest.raises(ValueError): 222 | samples = pipeline.predict( 223 | context[0, ...], 224 | num_samples=7, 225 | prediction_length=65, 226 | limit_prediction_length=True, 227 | ) 228 | 229 | samples = pipeline.predict( 230 | context[0, ...], 231 | num_samples=7, 232 | prediction_length=65, 233 | ) 234 | validate_tensor(samples, shape=(1, 7, 65), dtype=torch.float32) 235 | 236 | 237 | @pytest.mark.parametrize("model_dtype", [torch.float32, torch.bfloat16]) 238 | @pytest.mark.parametrize("input_dtype", [torch.float32, torch.bfloat16, torch.int64]) 239 | @pytest.mark.parametrize("prediction_length", [3, 65]) 240 | @pytest.mark.parametrize( 241 | "quantile_levels", [[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], [0.1, 0.5, 0.9]] 242 | ) 243 | def test_pipeline_predict_quantiles( 244 | model_dtype: torch.dtype, 245 | input_dtype: torch.dtype, 246 | prediction_length: int, 247 | quantile_levels: list[int], 248 | ): 249 | pipeline = ChronosPipeline.from_pretrained( 250 | Path(__file__).parent / "dummy-chronos-model", 251 | device_map="cpu", 252 | torch_dtype=model_dtype, 253 | ) 254 | context = 10 * torch.rand(size=(4, 16)) + 10 255 | context = context.to(dtype=input_dtype) 256 | 257 | num_expected_quantiles = len(quantile_levels) 258 | # input: tensor of shape (batch_size, context_length) 259 | 260 | quantiles, mean = pipeline.predict_quantiles( 261 | context, 262 | num_samples=12, 263 | prediction_length=prediction_length, 264 | quantile_levels=quantile_levels, 265 | ) 266 | validate_tensor( 267 | quantiles, (4, prediction_length, num_expected_quantiles), dtype=torch.float32 268 | ) 269 | validate_tensor(mean, (4, prediction_length), dtype=torch.float32) 270 | 271 | # input: batch_size-long list of tensors of shape (context_length,) 272 | 273 | quantiles, mean = pipeline.predict_quantiles( 274 | list(context), 275 | num_samples=12, 276 | prediction_length=prediction_length, 277 | quantile_levels=quantile_levels, 278 | ) 279 | validate_tensor( 280 | quantiles, (4, prediction_length, num_expected_quantiles), dtype=torch.float32 281 | ) 282 | validate_tensor(mean, (4, prediction_length), dtype=torch.float32) 283 | 284 | # input: tensor of shape 
(context_length,) 285 | 286 | quantiles, mean = pipeline.predict_quantiles( 287 | context[0, ...], 288 | num_samples=12, 289 | prediction_length=prediction_length, 290 | quantile_levels=quantile_levels, 291 | ) 292 | validate_tensor( 293 | quantiles, (1, prediction_length, num_expected_quantiles), dtype=torch.float32 294 | ) 295 | validate_tensor(mean, (1, prediction_length), dtype=torch.float32) 296 | 297 | 298 | @pytest.mark.parametrize("model_dtype", [torch.float32, torch.bfloat16]) 299 | @pytest.mark.parametrize("input_dtype", [torch.float32, torch.bfloat16, torch.int64]) 300 | def test_pipeline_embed(model_dtype: torch.dtype, input_dtype: torch.dtype): 301 | pipeline = ChronosPipeline.from_pretrained( 302 | Path(__file__).parent / "dummy-chronos-model", 303 | device_map="cpu", 304 | torch_dtype=model_dtype, 305 | ) 306 | d_model = pipeline.model.model.config.d_model 307 | context = 10 * torch.rand(size=(4, 16)) + 10 308 | context = context.to(dtype=input_dtype) 309 | expected_embed_length = 16 + (1 if pipeline.model.config.use_eos_token else 0) 310 | 311 | # input: tensor of shape (batch_size, context_length) 312 | 313 | embedding, scale = pipeline.embed(context) 314 | validate_tensor( 315 | embedding, shape=(4, expected_embed_length, d_model), dtype=model_dtype 316 | ) 317 | validate_tensor(scale, shape=(4,), dtype=torch.float32) 318 | 319 | # input: batch_size-long list of tensors of shape (context_length,) 320 | 321 | embedding, scale = pipeline.embed(list(context)) 322 | validate_tensor( 323 | embedding, shape=(4, expected_embed_length, d_model), dtype=model_dtype 324 | ) 325 | validate_tensor(scale, shape=(4,), dtype=torch.float32) 326 | 327 | # input: tensor of shape (context_length,) 328 | embedding, scale = pipeline.embed(context[0, ...]) 329 | validate_tensor( 330 | embedding, shape=(1, expected_embed_length, d_model), dtype=model_dtype 331 | ) 332 | validate_tensor(scale, shape=(1,), dtype=torch.float32) 333 | 334 | 335 | @pytest.mark.parametrize("n_tokens", [10, 1000, 10000]) 336 | def test_tokenizer_number_of_buckets(n_tokens): 337 | config = ChronosConfig( 338 | tokenizer_class="MeanScaleUniformBins", 339 | tokenizer_kwargs=dict(low_limit=-1.0, high_limit=1.0), 340 | n_tokens=n_tokens, 341 | n_special_tokens=2, 342 | pad_token_id=0, 343 | eos_token_id=1, 344 | use_eos_token=True, 345 | model_type="seq2seq", 346 | context_length=512, 347 | prediction_length=64, 348 | num_samples=20, 349 | temperature=1.0, 350 | top_k=50, 351 | top_p=1.0, 352 | ) 353 | tokenizer = config.create_tokenizer() 354 | 355 | n_numerical_tokens = config.n_tokens - config.n_special_tokens 356 | 357 | # The tokenizer has one bucket too many as a result of an early bug. In order to 358 | # keep consistent with the original trained models, this is kept as it is. However, 359 | # token ids are clipped to a maximum of `n_tokens - 1` to avoid out-of-bounds errors. 
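# Concretely: with n_tokens=10 and n_special_tokens=2 there are 8 numerical tokens, so the assertions below expect len(tokenizer.centers) == 7 and len(tokenizer.boundaries) == 8.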
360 | assert len(tokenizer.centers) == (n_numerical_tokens - 1) 361 | assert len(tokenizer.boundaries) == n_numerical_tokens 362 | 363 | 364 | @pytest.mark.parametrize("n_tokens", [10, 1000, 10000]) 365 | def test_token_clipping(n_tokens): 366 | config = ChronosConfig( 367 | tokenizer_class="MeanScaleUniformBins", 368 | tokenizer_kwargs={"low_limit": -15, "high_limit": 15}, 369 | n_tokens=n_tokens, 370 | n_special_tokens=2, 371 | pad_token_id=0, 372 | eos_token_id=1, 373 | use_eos_token=True, 374 | model_type="seq2seq", 375 | context_length=512, 376 | prediction_length=64, 377 | num_samples=20, 378 | temperature=1.0, 379 | top_k=50, 380 | top_p=1.0, 381 | ) 382 | tokenizer = config.create_tokenizer() 383 | 384 | huge_value = 1e22 # this large value is assigned to the largest bucket 385 | token_ids, _, _ = tokenizer._input_transform( 386 | context=torch.tensor([[huge_value]]), scale=torch.tensor(([1])) 387 | ) 388 | assert token_ids[0, 0] == config.n_tokens - 1 # and it's clipped to n_tokens - 1 389 | -------------------------------------------------------------------------------- /test/test_chronos_bolt.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | from pathlib import Path 5 | 6 | import pytest 7 | import torch 8 | 9 | from chronos import BaseChronosPipeline, ChronosBoltPipeline 10 | from chronos.chronos_bolt import InstanceNorm, Patch 11 | from test.util import validate_tensor 12 | 13 | 14 | def test_base_chronos_pipeline_loads_from_huggingface(): 15 | BaseChronosPipeline.from_pretrained("amazon/chronos-bolt-tiny", device_map="cpu") 16 | 17 | 18 | @pytest.mark.parametrize("torch_dtype", [torch.float32, torch.bfloat16]) 19 | @pytest.mark.parametrize("input_dtype", [torch.float32, torch.bfloat16, torch.int64]) 20 | def test_pipeline_predict(torch_dtype: torch.dtype, input_dtype: torch.dtype): 21 | pipeline = ChronosBoltPipeline.from_pretrained( 22 | Path(__file__).parent / "dummy-chronos-bolt-model", 23 | device_map="cpu", 24 | torch_dtype=torch_dtype, 25 | ) 26 | context = 10 * torch.rand(size=(4, 16)) + 10 27 | context = context.to(dtype=input_dtype) 28 | expected_num_quantiles = len(pipeline.quantiles) 29 | 30 | # input: tensor of shape (batch_size, context_length) 31 | 32 | quantiles = pipeline.predict(context, prediction_length=3) 33 | validate_tensor(quantiles, (4, expected_num_quantiles, 3), dtype=torch.float32) 34 | 35 | with pytest.raises(ValueError): 36 | quantiles = pipeline.predict( 37 | context, prediction_length=65, limit_prediction_length=True 38 | ) 39 | 40 | quantiles = pipeline.predict(context, prediction_length=65) 41 | validate_tensor(quantiles, (4, expected_num_quantiles, 65)) 42 | 43 | # input: batch_size-long list of tensors of shape (context_length,) 44 | 45 | quantiles = pipeline.predict(list(context), prediction_length=3) 46 | validate_tensor(quantiles, (4, expected_num_quantiles, 3), dtype=torch.float32) 47 | 48 | with pytest.raises(ValueError): 49 | quantiles = pipeline.predict( 50 | list(context), 51 | prediction_length=65, 52 | limit_prediction_length=True, 53 | ) 54 | 55 | quantiles = pipeline.predict(list(context), prediction_length=65) 56 | validate_tensor(quantiles, (4, expected_num_quantiles, 65), dtype=torch.float32) 57 | 58 | # input: tensor of shape (context_length,) 59 | 60 | quantiles = pipeline.predict(context[0, ...], prediction_length=3) 61 | validate_tensor(quantiles, (1, 
expected_num_quantiles, 3), dtype=torch.float32) 62 | 63 | with pytest.raises(ValueError): 64 | quantiles = pipeline.predict( 65 | context[0, ...], 66 | prediction_length=65, 67 | limit_prediction_length=True, 68 | ) 69 | 70 | quantiles = pipeline.predict( 71 | context[0, ...], 72 | prediction_length=65, 73 | ) 74 | validate_tensor(quantiles, (1, expected_num_quantiles, 65), dtype=torch.float32) 75 | 76 | 77 | @pytest.mark.parametrize("torch_dtype", [torch.float32, torch.bfloat16]) 78 | @pytest.mark.parametrize("input_dtype", [torch.float32, torch.bfloat16, torch.int64]) 79 | @pytest.mark.parametrize("prediction_length", [3, 65]) 80 | @pytest.mark.parametrize( 81 | "quantile_levels", [[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], [0.1, 0.5, 0.9]] 82 | ) 83 | def test_pipeline_predict_quantiles( 84 | torch_dtype: torch.dtype, 85 | input_dtype: torch.dtype, 86 | prediction_length: int, 87 | quantile_levels: list[int], 88 | ): 89 | pipeline = ChronosBoltPipeline.from_pretrained( 90 | Path(__file__).parent / "dummy-chronos-bolt-model", 91 | device_map="cpu", 92 | torch_dtype=torch_dtype, 93 | ) 94 | context = 10 * torch.rand(size=(4, 16)) + 10 95 | context = context.to(dtype=input_dtype) 96 | 97 | num_expected_quantiles = len(quantile_levels) 98 | # input: tensor of shape (batch_size, context_length) 99 | 100 | quantiles, mean = pipeline.predict_quantiles( 101 | context, 102 | prediction_length=prediction_length, 103 | quantile_levels=quantile_levels, 104 | ) 105 | validate_tensor( 106 | quantiles, (4, prediction_length, num_expected_quantiles), dtype=torch.float32 107 | ) 108 | validate_tensor(mean, (4, prediction_length), dtype=torch.float32) 109 | 110 | # input: batch_size-long list of tensors of shape (context_length,) 111 | 112 | quantiles, mean = pipeline.predict_quantiles( 113 | list(context), 114 | prediction_length=prediction_length, 115 | quantile_levels=quantile_levels, 116 | ) 117 | validate_tensor( 118 | quantiles, (4, prediction_length, num_expected_quantiles), dtype=torch.float32 119 | ) 120 | validate_tensor(mean, (4, prediction_length), dtype=torch.float32) 121 | 122 | # input: tensor of shape (context_length,) 123 | 124 | quantiles, mean = pipeline.predict_quantiles( 125 | context[0, ...], 126 | prediction_length=prediction_length, 127 | quantile_levels=quantile_levels, 128 | ) 129 | validate_tensor( 130 | quantiles, (1, prediction_length, num_expected_quantiles), dtype=torch.float32 131 | ) 132 | validate_tensor(mean, (1, prediction_length), dtype=torch.float32) 133 | 134 | 135 | @pytest.mark.parametrize("model_dtype", [torch.float32, torch.bfloat16]) 136 | @pytest.mark.parametrize("input_dtype", [torch.float32, torch.bfloat16, torch.int64]) 137 | def test_pipeline_embed(model_dtype: torch.dtype, input_dtype: torch.dtype): 138 | pipeline = ChronosBoltPipeline.from_pretrained( 139 | Path(__file__).parent / "dummy-chronos-bolt-model", 140 | device_map="cpu", 141 | torch_dtype=model_dtype, 142 | ) 143 | d_model = pipeline.model.config.d_model 144 | context = 10 * torch.rand(size=(4, 16)) + 10 145 | context = context.to(dtype=input_dtype) 146 | 147 | # the patch size of dummy model is 16, so only 1 patch is created 148 | expected_embed_length = 1 + ( 149 | 1 if pipeline.model.config.chronos_config["use_reg_token"] else 0 150 | ) 151 | 152 | # input: tensor of shape (batch_size, context_length) 153 | 154 | embedding, loc_scale = pipeline.embed(context) 155 | validate_tensor( 156 | embedding, shape=(4, expected_embed_length, d_model), dtype=model_dtype 157 | ) 158 | 
validate_tensor(loc_scale[0], shape=(4,), dtype=torch.float32) 159 | validate_tensor(loc_scale[1], shape=(4,), dtype=torch.float32) 160 | 161 | # input: batch_size-long list of tensors of shape (context_length,) 162 | 163 | embedding, loc_scale = pipeline.embed(list(context)) 164 | validate_tensor( 165 | embedding, shape=(4, expected_embed_length, d_model), dtype=model_dtype 166 | ) 167 | validate_tensor(loc_scale[0], shape=(4,), dtype=torch.float32) 168 | validate_tensor(loc_scale[1], shape=(4,), dtype=torch.float32) 169 | 170 | # input: tensor of shape (context_length,) 171 | embedding, loc_scale = pipeline.embed(context[0, ...]) 172 | validate_tensor( 173 | embedding, shape=(1, expected_embed_length, d_model), dtype=model_dtype 174 | ) 175 | validate_tensor(loc_scale[0], shape=(1,), dtype=torch.float32) 176 | validate_tensor(loc_scale[1], shape=(1,), dtype=torch.float32) 177 | 178 | 179 | # The following tests have been taken from 180 | # https://github.com/autogluon/autogluon/blob/f57beb26cb769c6e0d484a6af2b89eab8aee73a8/timeseries/tests/unittests/models/chronos/pipeline/test_chronos_bolt.py 181 | # Author: Caner Turkmen 182 | 183 | 184 | def test_given_even_data_patch_operator_output_is_correct(): 185 | batch_size = 17 186 | patch_len = 16 187 | 188 | patch = Patch(patch_len, patch_len) 189 | 190 | batch = ( 191 | torch.stack([torch.arange(512)] * batch_size) 192 | + torch.arange(batch_size)[:, None] 193 | ) 194 | output = patch(batch) 195 | 196 | assert output.shape == (batch_size, 512 // patch_len, patch_len) 197 | 198 | assert torch.allclose( 199 | output[:, 0], 200 | torch.stack([torch.arange(patch_len)] * batch_size) 201 | + torch.arange(batch_size)[:, None], 202 | atol=1e-5, 203 | ) 204 | assert torch.allclose( 205 | output[:, 1], 206 | torch.stack([torch.arange(patch_len, 2 * patch_len)] * batch_size) 207 | + torch.arange(batch_size)[:, None], 208 | atol=1e-5, 209 | ) 210 | assert not torch.isnan(output).any() 211 | 212 | 213 | def test_given_even_data_and_strides_patch_operator_output_is_correct(): 214 | batch_size = 17 215 | patch_len, patch_stride = 16, 8 216 | 217 | patch = Patch(patch_len, patch_stride) 218 | 219 | offset = torch.arange(batch_size)[:, None] 220 | batch = torch.stack([torch.arange(512)] * batch_size) + offset 221 | output = patch(batch) 222 | 223 | assert torch.allclose( 224 | output[:, 1], 225 | torch.stack([torch.arange(patch_stride, patch_stride + patch_len)] * batch_size) 226 | + offset, 227 | atol=1e-5, 228 | ) 229 | assert not torch.isnan(output).any() 230 | 231 | 232 | def test_given_uneven_data_patch_operator_pads_and_output_is_correct(): 233 | batch_size = 17 234 | patch_len = 16 235 | 236 | patch = Patch(patch_len, patch_len) 237 | 238 | batch = ( 239 | torch.stack([torch.arange(512 - patch_len + 1)] * batch_size) 240 | + torch.arange(batch_size)[:, None] 241 | ).float() 242 | output = patch(batch) 243 | 244 | assert output.shape == (batch_size, 512 // patch_len, patch_len) 245 | 246 | # check the first portion is padded 247 | assert torch.isnan(output[:, 0, :-1]).all() 248 | 249 | # check nowhere else is nan 250 | assert not torch.isnan(output[:, 1:]).any() 251 | 252 | 253 | def test_when_instancenorm_applied_then_standardization_correct(): 254 | inorm = InstanceNorm() 255 | 256 | input_ = torch.tensor( 257 | [ 258 | [1, 2, 3, 4, 5], 259 | [2, 3, 4, 5, 6], 260 | ] 261 | ).float() 262 | 263 | normalized, (loc, scale) = inorm(input_) 264 | 265 | assert normalized.shape == input_.shape 266 | assert torch.allclose(normalized[0], normalized[1]) 267 | 
assert torch.allclose(loc.squeeze(), torch.tensor([3.0, 4.0])) 268 | assert torch.allclose(scale.squeeze(), torch.tensor(1.41421)) 269 | 270 | 271 | def test_when_instancenorm_applied_and_reversed_then_nans_preserved(): 272 | inorm = InstanceNorm() 273 | 274 | input_ = torch.tensor( 275 | [ 276 | [1, torch.nan, 3, 4, 5], 277 | [2, 3, 4, 5, torch.nan], 278 | ] 279 | ).float() 280 | 281 | normalized, (loc, scale) = inorm(input_) 282 | assert torch.allclose(normalized.isnan(), input_.isnan()) 283 | 284 | output = inorm.inverse(normalized, (loc, scale)) 285 | assert torch.allclose(output, input_, equal_nan=True) 286 | 287 | 288 | def test_when_instancenorm_applied_and_reversed_then_output_correct(): 289 | inorm = InstanceNorm() 290 | 291 | input_ = torch.tensor( 292 | [ 293 | [1, 2, 3, 4, 5], 294 | [2, 3, 4, 5, 1000], 295 | ] 296 | ).float() 297 | 298 | normalized, loc_scale = inorm(input_) 299 | output = inorm.inverse(normalized, loc_scale) 300 | 301 | assert torch.allclose(output, input_) 302 | -------------------------------------------------------------------------------- /test/test_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import pytest 5 | import torch 6 | 7 | from chronos.utils import left_pad_and_stack_1D 8 | 9 | 10 | @pytest.mark.parametrize( 11 | "tensors", 12 | [ 13 | [ 14 | torch.tensor([2.0, 3.0], dtype=dtype), 15 | torch.tensor([4.0, 5.0, 6.0], dtype=dtype), 16 | torch.tensor([7.0, 8.0, 9.0, 10.0], dtype=dtype), 17 | ] 18 | for dtype in [torch.int, torch.float16, torch.float32] 19 | ], 20 | ) 21 | def test_pad_and_stack(tensors: list): 22 | stacked_and_padded = left_pad_and_stack_1D(tensors) 23 | 24 | assert stacked_and_padded.dtype == torch.float32 25 | assert stacked_and_padded.shape == (len(tensors), max(len(t) for t in tensors)) 26 | 27 | ref = torch.concat(tensors).to(dtype=stacked_and_padded.dtype) 28 | 29 | assert torch.sum(torch.nan_to_num(stacked_and_padded, nan=0)) == torch.sum(ref) 30 | -------------------------------------------------------------------------------- /test/util.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | 3 | import torch 4 | 5 | 6 | def validate_tensor( 7 | a: torch.Tensor, shape: Tuple[int, ...], dtype: Optional[torch.dtype] = None 8 | ) -> None: 9 | assert isinstance(a, torch.Tensor) 10 | assert a.shape == shape 11 | 12 | if dtype is not None: 13 | assert a.dtype == dtype 14 | --------------------------------------------------------------------------------
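For orientation, the sketch below shows one plausible way to exercise the APIs covered by the tests above. It is an illustrative example rather than a file from the repository: the Hub model IDs (amazon/chronos-t5-tiny, amazon/chronos-bolt-tiny) are the ones referenced in test_chronos.py and test_chronos_bolt.py, network access to download them is assumed, and the random context and prediction length are arbitrary choices.

import torch

from chronos import BaseChronosPipeline, ChronosBoltPipeline, ChronosPipeline

# A batch of 4 random series of length 16, mirroring the contexts used in the tests.
context = 10 * torch.rand(size=(4, 16)) + 10

# Original Chronos (T5): autoregressive sampling; the output has shape
# (batch_size, num_samples, prediction_length).
pipeline = ChronosPipeline.from_pretrained("amazon/chronos-t5-tiny", device_map="cpu")
samples = pipeline.predict(context, prediction_length=3, num_samples=12)
quantiles, mean = pipeline.predict_quantiles(
    context, prediction_length=3, quantile_levels=[0.1, 0.5, 0.9]
)

# Chronos-Bolt: direct multi-quantile output with shape
# (batch_size, num_quantiles, prediction_length).
bolt = ChronosBoltPipeline.from_pretrained("amazon/chronos-bolt-tiny", device_map="cpu")
bolt_quantiles = bolt.predict(context, prediction_length=3)

# Either family can also be loaded generically through the shared base class,
# which dispatches to the appropriate pipeline based on the model config.
generic = BaseChronosPipeline.from_pretrained("amazon/chronos-t5-tiny", device_map="cpu")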