├── .github
│   ├── ISSUE_TEMPLATE
│   │   ├── bug_report.md
│   │   └── config.yml
│   └── workflows
│       ├── ci.yml
│       ├── eval-model.yml
│       └── publish-to-pypi.yml
├── .gitignore
├── CITATION.cff
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── NOTICE
├── README.md
├── ci
│   └── evaluate
│       └── backtest_config.yaml
├── figures
│   ├── chronos-logo.png
│   ├── main-figure.png
│   └── zero_shot-agg_scaled_score.svg
├── notebooks
│   └── deploy-chronos-bolt-to-amazon-sagemaker.ipynb
├── pyproject.toml
├── scripts
│   ├── README.md
│   ├── evaluation
│   │   ├── agg-relative-score.py
│   │   ├── configs
│   │   │   ├── in-domain.yaml
│   │   │   └── zero-shot.yaml
│   │   ├── evaluate.py
│   │   └── results
│   │       ├── chronos-bolt-base-agg-rel-scores.csv
│   │       ├── chronos-bolt-base-in-domain.csv
│   │       ├── chronos-bolt-base-zero-shot.csv
│   │       ├── chronos-bolt-mini-agg-rel-scores.csv
│   │       ├── chronos-bolt-mini-in-domain.csv
│   │       ├── chronos-bolt-mini-zero-shot.csv
│   │       ├── chronos-bolt-small-agg-rel-scores.csv
│   │       ├── chronos-bolt-small-in-domain.csv
│   │       ├── chronos-bolt-small-zero-shot.csv
│   │       ├── chronos-bolt-tiny-agg-rel-scores.csv
│   │       ├── chronos-bolt-tiny-in-domain.csv
│   │       ├── chronos-bolt-tiny-zero-shot.csv
│   │       ├── chronos-t5-base-agg-rel-scores.csv
│   │       ├── chronos-t5-base-in-domain.csv
│   │       ├── chronos-t5-base-zero-shot.csv
│   │       ├── chronos-t5-large-agg-rel-scores.csv
│   │       ├── chronos-t5-large-in-domain.csv
│   │       ├── chronos-t5-large-zero-shot.csv
│   │       ├── chronos-t5-mini-agg-rel-scores.csv
│   │       ├── chronos-t5-mini-in-domain.csv
│   │       ├── chronos-t5-mini-zero-shot.csv
│   │       ├── chronos-t5-small-agg-rel-scores.csv
│   │       ├── chronos-t5-small-in-domain.csv
│   │       ├── chronos-t5-small-zero-shot.csv
│   │       ├── chronos-t5-tiny-agg-rel-scores.csv
│   │       ├── chronos-t5-tiny-in-domain.csv
│   │       ├── chronos-t5-tiny-zero-shot.csv
│   │       ├── seasonal-naive-in-domain.csv
│   │       └── seasonal-naive-zero-shot.csv
│   ├── kernel-synth.py
│   └── training
│       ├── configs
│       │   ├── chronos-gpt2.yaml
│       │   ├── chronos-t5-base.yaml
│       │   ├── chronos-t5-large.yaml
│       │   ├── chronos-t5-mini.yaml
│       │   ├── chronos-t5-small.yaml
│       │   └── chronos-t5-tiny.yaml
│       └── train.py
├── src
│   └── chronos
│       ├── __init__.py
│       ├── base.py
│       ├── chronos.py
│       ├── chronos_bolt.py
│       └── utils.py
└── test
    ├── __init__.py
    ├── dummy-chronos-bolt-model
    │   ├── config.json
    │   └── model.safetensors
    ├── dummy-chronos-model
    │   ├── config.json
    │   ├── generation_config.json
    │   └── pytorch_model.bin
    ├── test_chronos.py
    ├── test_chronos_bolt.py
    ├── test_utils.py
    └── util.py
/.github/ISSUE_TEMPLATE/bug_report.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Bug report
3 | about: Create a report to help us reproduce and fix a bug
4 | title: "[BUG]"
5 | labels: ['bug']
6 | assignees: ''
7 |
8 | ---
9 |
10 | **Bug report checklist**
11 |
12 | - [ ] I provided code that demonstrates a minimal reproducible example.
13 | - [ ] I confirmed the bug exists on the latest mainline of Chronos via a source install.
14 |
15 | **Describe the bug**
16 |
17 |
18 | **Expected behavior**
19 |
20 |
21 | **To reproduce**
22 |
25 |
26 | **Environment description**
27 | Operating system:
28 | Python version:
29 | CUDA version:
30 | PyTorch version:
31 | HuggingFace transformers version:
32 | HuggingFace accelerate version:
33 |
34 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/config.yml:
--------------------------------------------------------------------------------
1 | blank_issues_enabled: false
2 | contact_links:
3 | - name: Frequently asked questions
4 | url: https://github.com/amazon-science/chronos-forecasting/issues?q=is%3Aissue+label%3AFAQ
5 | about: Check the frequently asked questions before opening a new one
6 | - name: Discussions
7 | url: https://github.com/amazon-science/chronos-forecasting/discussions/new
8 | about: Use this to ask questions and start discussions
9 |
--------------------------------------------------------------------------------
/.github/workflows/ci.yml:
--------------------------------------------------------------------------------
1 | name: CI
2 |
3 | on:
4 | push:
5 | branches: ["main"] # Run only on main branch
6 | pull_request:
7 | branches: ["**"] # Run on any branch
8 | schedule:
9 | - cron: "0 8 * * *" # Run at 8 AM UTC
10 |
11 | jobs:
12 | type-check:
13 | strategy:
14 | max-parallel: 4
15 | fail-fast: false
16 | matrix:
17 | python-version: ["3.11"]
18 | platform: [ubuntu-latest]
19 |
20 | runs-on: ${{ matrix.platform }}
21 |
22 | steps:
23 | - uses: actions/checkout@v4
24 | - name: Set up Python ${{ matrix.python-version }}
25 | uses: actions/setup-python@v5
26 | with:
27 | python-version: ${{ matrix.python-version }}
28 | - name: Install dependencies
29 | run: pip install ".[typecheck]" -f https://download.pytorch.org/whl/cpu/torch_stable.html
30 | - name: Type checks with mypy
31 | run: mypy src test
32 |
33 | test:
34 | strategy:
35 | max-parallel: 4
36 | fail-fast: false
37 | matrix:
38 | python-version: ["3.9", "3.10", "3.11", "3.12"]
39 | platform: [ubuntu-latest]
40 |
41 | runs-on: ${{ matrix.platform }}
42 |
43 | steps:
44 | - uses: actions/checkout@v4
45 | - name: Set up Python ${{ matrix.python-version }}
46 | uses: actions/setup-python@v5
47 | with:
48 | python-version: ${{ matrix.python-version }}
49 | - name: Install dependencies
50 | run: pip install ".[test]" -f https://download.pytorch.org/whl/cpu/torch_stable.html
51 | - name: Test with pytest
52 | run: pytest
53 |
--------------------------------------------------------------------------------
/.github/workflows/eval-model.yml:
--------------------------------------------------------------------------------
1 | # Evaluates Chronos-Bolt (Small) model on selected datasets
2 | name: Evaluate
3 |
4 | on:
5 |   # Runs only with read privileges for the GITHUB_TOKEN
6 | pull_request:
7 | branches: ["main"] # Run on PRs to main branch
8 | types:
9 | - opened # When a PR is created
10 | - reopened # When a closed PR is reopened
11 | - synchronize # When new commits are pushed to the PR
12 | - labeled # When a label is added to the PR
13 |
14 | jobs:
15 | evaluate-and-print:
16 | if: contains(github.event.pull_request.labels.*.name, 'run-eval') # Only run if 'run-eval' label is added
17 | runs-on: ubuntu-latest
18 | env:
19 | RESULTS_CSV: "eval-ci-metrics-${{ github.event.pull_request.number }}.csv"
20 |
21 | steps:
22 | - name: Checkout Repository
23 | uses: actions/checkout@v4
24 |
25 | - name: Set up Python
26 | uses: actions/setup-python@v5
27 | with:
28 | python-version: '3.11'
29 |
30 | - name: Install Dependencies
31 | run: pip install ".[evaluation]" -f https://download.pytorch.org/whl/cpu/torch_stable.html
32 |
33 | - name: Run Eval Script
34 | run: python scripts/evaluation/evaluate.py ci/evaluate/backtest_config.yaml $RESULTS_CSV --chronos-model-id=amazon/chronos-bolt-small --device=cpu --torch-dtype=float32
35 |
36 | - name: Print CSV
37 | run: cat $RESULTS_CSV
38 |
--------------------------------------------------------------------------------
/.github/workflows/publish-to-pypi.yml:
--------------------------------------------------------------------------------
1 | name: Publish Python Package to PyPI
2 |
3 | on:
4 | release:
5 | types: [published]
6 |
7 | jobs:
8 | deploy-to-pypi:
9 | runs-on: ubuntu-latest
10 | environment: release
11 | permissions:
12 | id-token: write
13 | steps:
14 | - uses: actions/checkout@v4
15 | - name: Set up Python
16 | uses: actions/setup-python@v5
17 | with:
18 | python-version: '3.11'
19 | - name: Install dependencies
20 | run: |
21 | python -m pip install -U pip
22 | python -m pip install setuptools wheel build
23 | - name: Build package
24 | run: |
25 | python -m build
26 |       - name: Publish to PyPI
27 | uses: pypa/gh-action-pypi-publish@release/v1
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
162 | # macOS stuff
163 | .DS_store
--------------------------------------------------------------------------------
/CITATION.cff:
--------------------------------------------------------------------------------
1 | cff-version: 1.2.0
2 | title: "Chronos: Learning the Language of Time Series"
3 | message: "If you find Chronos models useful for your research, please consider citing the associated paper."
4 | authors:
5 | - family-names: Ansari
6 | given-names: Abdul Fatir
7 | - family-names: Stella
8 | given-names: Lorenzo
9 | - family-names: Turkmen
10 | given-names: Caner
11 | - family-names: Zhang
12 | given-names: Xiyuan
13 | - family-names: Mercado
14 | given-names: Pedro
15 | - family-names: Shen
16 | given-names: Huibin
17 | - family-names: Shchur
18 | given-names: Oleksandr
19 | - family-names: Rangapuram
20 |     given-names: Syama Sundar
21 | - family-names: Arango
22 | given-names: Sebastian Pineda
23 | - family-names: Kapoor
24 | given-names: Shubham
25 | - family-names: Zschiegner
26 | given-names: Jasper
27 | - family-names: Maddix
28 | given-names: Danielle C.
29 | - family-names: Mahoney
30 | given-names: Michael W.
31 | - family-names: Torkkola
32 | given-names: Kari
33 | - family-names: Wilson
34 | given-names: Andrew Gordon
35 | - family-names: Bohlke-Schneider
36 | given-names: Michael
37 | - family-names: Wang
38 | given-names: Yuyang
39 | preferred-citation:
40 | type: article
41 | authors:
42 | - family-names: Ansari
43 | given-names: Abdul Fatir
44 | - family-names: Stella
45 | given-names: Lorenzo
46 | - family-names: Turkmen
47 | given-names: Caner
48 | - family-names: Zhang
49 | given-names: Xiyuan
50 | - family-names: Mercado
51 | given-names: Pedro
52 | - family-names: Shen
53 | given-names: Huibin
54 | - family-names: Shchur
55 | given-names: Oleksandr
56 | - family-names: Rangapuram
57 |     given-names: Syama Sundar
58 | - family-names: Arango
59 | given-names: Sebastian Pineda
60 | - family-names: Kapoor
61 | given-names: Shubham
62 | - family-names: Zschiegner
63 | given-names: Jasper
64 | - family-names: Maddix
65 | given-names: Danielle C.
66 | - family-names: Mahoney
67 | given-names: Michael W.
68 | - family-names: Torkkola
69 | given-names: Kari
70 | - family-names: Wilson
71 | given-names: Andrew Gordon
72 | - family-names: Bohlke-Schneider
73 | given-names: Michael
74 | - family-names: Wang
75 | given-names: Yuyang
76 | title: "Chronos: Learning the Language of Time Series"
77 | journal: "arXiv preprint arXiv:2403.07815"
78 | year: 2024
79 |
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | ## Code of Conduct
2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
4 | opensource-codeofconduct@amazon.com with any additional questions or comments.
5 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing Guidelines
2 |
3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional
4 | documentation, we greatly value feedback and contributions from our community.
5 |
6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary
7 | information to effectively respond to your bug report or contribution.
8 |
9 |
10 | ## Reporting Bugs/Feature Requests
11 |
12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features.
13 |
14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already
15 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful:
16 |
17 | * A reproducible test case or series of steps
18 | * The version of our code being used
19 | * Any modifications you've made relevant to the bug
20 | * Anything unusual about your environment or deployment
21 |
22 |
23 | ## Contributing via Pull Requests
24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that:
25 |
26 | 1. You are working against the latest source on the *main* branch.
27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already.
28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted.
29 |
30 | To send us a pull request, please:
31 |
32 | 1. Fork the repository.
33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change.
34 | 3. Ensure local tests pass.
35 | 4. Commit to your fork using clear commit messages.
36 | 5. Send us a pull request, answering any default questions in the pull request interface.
37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation.
38 |
39 | GitHub provides additional documentation on [forking a repository](https://help.github.com/articles/fork-a-repo/) and
40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/).
41 |
42 |
43 | ## Finding contributions to work on
44 | Looking at the existing issues is a great way to find something to contribute to. As our projects use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start.
45 |
46 |
47 | ## Code of Conduct
48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
50 | opensource-codeofconduct@amazon.com with any additional questions or comments.
51 |
52 |
53 | ## Security issue notifications
54 | If you discover a potential security issue in this project, we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public GitHub issue.
55 |
56 |
57 | ## Licensing
58 |
59 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution.
60 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
--------------------------------------------------------------------------------
/NOTICE:
--------------------------------------------------------------------------------
1 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |

3 |
4 |
5 |
6 |
7 | # Chronos: Learning the Language of Time Series
8 |
9 | [](https://arxiv.org/abs/2403.07815)
10 | [](https://huggingface.co/datasets/autogluon/chronos_datasets)
11 | [](https://huggingface.co/collections/amazon/chronos-models-65f1791d630a8d57cb718444)
12 | [](https://github.com/autogluon/fev)
13 | [](notebooks/deploy-chronos-bolt-to-amazon-sagemaker.ipynb)
14 | [](https://github.com/amazon-science/chronos-forecasting/issues?q=is%3Aissue+label%3AFAQ)
15 | [](https://opensource.org/licenses/Apache-2.0)
16 |
17 |
18 |
19 |
20 | ## 🚀 News
21 | - **14 Feb 2025**: 🚀 Chronos-Bolt is now available on Amazon SageMaker JumpStart! Check out the [tutorial notebook](notebooks/deploy-chronos-bolt-to-amazon-sagemaker.ipynb) to learn how to deploy Chronos endpoints for production use in 3 lines of code.
22 | - **12 Dec 2024**: 📊 We released [`fev`](https://github.com/autogluon/fev), a lightweight package for benchmarking time series forecasting models based on the [Hugging Face `datasets`](https://huggingface.co/docs/datasets/en/index) library.
23 | - **26 Nov 2024**: ⚡️ Chronos-Bolt models released [on HuggingFace](https://huggingface.co/collections/amazon/chronos-models-65f1791d630a8d57cb718444). Chronos-Bolt models are more accurate (5% lower error), up to 250x faster and 20x more memory efficient than the original Chronos models of the same size!
24 | - **27 Jun 2024**: 🚀 [Released datasets](https://huggingface.co/datasets/autogluon/chronos_datasets) used in the paper and an [evaluation script](./scripts/README.md#evaluating-chronos-models) to compute the WQL and MASE scores reported in the paper.
25 | - **17 May 2024**: 🐛 Fixed an off-by-one error in bin indices in the `output_transform`. This simple fix significantly improves the overall performance of Chronos. We will update the results in the next revision on ArXiv.
26 | - **10 May 2024**: 🚀 We added the code for pretraining and fine-tuning Chronos models. You can find it in [this folder](./scripts/training). We also added [a script](./scripts/kernel-synth.py) for generating synthetic time series data from Gaussian processes (KernelSynth; see Section 4.2 in the paper for details). Check out the [usage examples](./scripts/).
27 | - **19 Apr 2024**: 🚀 Chronos is now supported on [AutoGluon-TimeSeries](https://auto.gluon.ai/stable/tutorials/timeseries/index.html), the powerful AutoML package for time series forecasting which enables model ensembles, cloud deployments, and much more. Get started with the [tutorial](https://auto.gluon.ai/stable/tutorials/timeseries/forecasting-chronos.html).
28 | - **08 Apr 2024**: 🧪 Experimental [MLX inference support](https://github.com/amazon-science/chronos-forecasting/tree/mlx) added. If you have an Apple Silicon Mac, you can now obtain significantly faster forecasts from Chronos compared to CPU inference. This provides an alternative way to exploit the GPU on your Apple Silicon Macs together with the "mps" support in PyTorch.
29 | - **25 Mar 2024**: 🚀 [v1.1.0 released](https://github.com/amazon-science/chronos-forecasting/releases/tag/v1.1.0) with inference optimizations and `pipeline.embed` to extract encoder embeddings from Chronos.
30 | - **13 Mar 2024**: 🚀 Chronos [paper](https://arxiv.org/abs/2403.07815) and inference code released.
31 |
32 | ## ✨ Introduction
33 |
34 | Chronos is a family of **pretrained time series forecasting models** based on language model architectures. A time series is transformed into a sequence of tokens via scaling and quantization, and a language model is trained on these tokens using the cross-entropy loss. Once trained, probabilistic forecasts are obtained by sampling multiple future trajectories given the historical context. Chronos models have been trained on a large corpus of publicly available time series data, as well as synthetic data generated using Gaussian processes.
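
To make the tokenization step concrete, here is a minimal, illustrative sketch of mean scaling followed by uniform binning. This is *not* the actual Chronos tokenizer (bin placement, special tokens, and edge handling differ); it only conveys the basic idea:

```python
import torch

def naive_tokenize(context: torch.Tensor, num_bins: int = 4096, low: float = -15.0, high: float = 15.0):
    """Illustrative only: scale by the mean absolute value, then bin into a fixed vocabulary."""
    scale = context.abs().mean().clamp_min(1e-8)        # mean scaling
    scaled = context / scale
    bin_edges = torch.linspace(low, high, num_bins - 1)  # uniform bins (the real tokenizer also reserves special tokens)
    tokens = torch.bucketize(scaled, bin_edges)           # one token id per time step
    return tokens, scale

tokens, scale = naive_tokenize(torch.arange(100.0))  # e.g. a simple upward trend
```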
35 |
36 | For details on Chronos models, training data and procedures, and experimental results, please refer to the paper [Chronos: Learning the Language of Time Series](https://arxiv.org/abs/2403.07815).
37 |
38 |
39 |
40 |
41 |
42 | Fig. 1: High-level depiction of Chronos. (Left) The input time series is scaled and quantized to obtain a sequence of tokens. (Center) The tokens are fed into a language model which may either be an encoder-decoder or a decoder-only model. The model is trained using the cross-entropy loss. (Right) During inference, we autoregressively sample tokens from the model and map them back to numerical values. Multiple trajectories are sampled to obtain a predictive distribution.
43 |
44 |
45 |
46 | ### Architecture
47 |
48 | The models in this repository are based on the [T5 architecture](https://arxiv.org/abs/1910.10683). The only difference is the vocabulary size: Chronos-T5 models use 4096 tokens, compared to the 32,128 of the original T5 models, which results in fewer parameters.
49 |
50 |
51 |
52 | | Model | Parameters | Based on |
53 | | ---------------------------------------------------------------------- | ---------- | ---------------------------------------------------------------------- |
54 | | [**chronos-t5-tiny**](https://huggingface.co/amazon/chronos-t5-tiny) | 8M | [t5-efficient-tiny](https://huggingface.co/google/t5-efficient-tiny) |
55 | | [**chronos-t5-mini**](https://huggingface.co/amazon/chronos-t5-mini) | 20M | [t5-efficient-mini](https://huggingface.co/google/t5-efficient-mini) |
56 | | [**chronos-t5-small**](https://huggingface.co/amazon/chronos-t5-small) | 46M | [t5-efficient-small](https://huggingface.co/google/t5-efficient-small) |
57 | | [**chronos-t5-base**](https://huggingface.co/amazon/chronos-t5-base) | 200M | [t5-efficient-base](https://huggingface.co/google/t5-efficient-base) |
58 | | [**chronos-t5-large**](https://huggingface.co/amazon/chronos-t5-large) | 710M | [t5-efficient-large](https://huggingface.co/google/t5-efficient-large) |
59 | | [**chronos-bolt-tiny**](https://huggingface.co/amazon/chronos-bolt-tiny) | 9M | [t5-efficient-tiny](https://huggingface.co/google/t5-efficient-tiny) |
60 | | [**chronos-bolt-mini**](https://huggingface.co/amazon/chronos-bolt-mini) | 21M | [t5-efficient-mini](https://huggingface.co/google/t5-efficient-mini) |
61 | | [**chronos-bolt-small**](https://huggingface.co/amazon/chronos-bolt-small) | 48M | [t5-efficient-small](https://huggingface.co/google/t5-efficient-small) |
62 | | [**chronos-bolt-base**](https://huggingface.co/amazon/chronos-bolt-base) | 205M | [t5-efficient-base](https://huggingface.co/google/t5-efficient-base) |
63 |
64 |
65 |
66 | ### Zero-Shot Results
67 |
68 | The following figure showcases the remarkable **zero-shot** performance of Chronos and Chronos-Bolt models on 27 datasets against local models, task-specific models and other pretrained models. For details on the evaluation setup and other results, please refer to [the paper](https://arxiv.org/abs/2403.07815).
69 |
70 |
71 |
72 |
73 |
74 | Fig. 2: Performance of different models on Benchmark II, comprising 27 datasets not seen by Chronos and Chronos-Bolt models during training. This benchmark provides insights into the zero-shot performance of Chronos and Chronos-Bolt models against local statistical models, which fit parameters individually for each time series, task-specific models trained on each task, and pretrained models trained on a large corpus of time series. Pretrained Models (Other) indicates that some (or all) of the datasets in Benchmark II may have been in the training corpus of these models. The probabilistic (WQL) and point (MASE) forecasting metrics were normalized using the scores of the Seasonal Naive baseline and aggregated through a geometric mean to obtain the Agg. Relative WQL and MASE, respectively.
75 |
76 |
77 |
78 | ## 📈 Usage
79 |
80 | The easiest way to perform inference with Chronos or Chronos-Bolt models is to install this package through `pip`:
81 |
82 | ```sh
83 | pip install chronos-forecasting
84 | ```
85 |
86 | If you're interested in pretraining, fine-tuning, and other research & development, clone and install the package from source:
87 |
88 | ```sh
89 | # Clone the repository
90 | git clone https://github.com/amazon-science/chronos-forecasting.git
91 |
92 | # Install in editable mode with extra training-related dependencies
93 | cd chronos-forecasting && pip install --editable ".[training]"
94 | ```
95 |
96 | > [!TIP]
97 | > This repository is intended for research purposes and provides a minimal interface to Chronos models. For reliable production use, we recommend the following options:
98 | > - [AutoGluon](https://auto.gluon.ai) provides effortless fine-tuning, augmentation of Chronos models with exogenous information through covariate regressors, and ensembling with other statistical and machine learning models. Check out the AutoGluon Chronos [tutorial](https://auto.gluon.ai/stable/tutorials/timeseries/forecasting-chronos.html).
99 | > - SageMaker JumpStart makes it easy to deploy Chronos inference endpoints to AWS with just a few lines of code. Check out [this tutorial](notebooks/deploy-chronos-bolt-to-amazon-sagemaker.ipynb) for more details.
100 |
101 | ### Forecasting
102 |
103 | A minimal example showing how to perform forecasting using Chronos and Chronos-Bolt models:
104 |
105 | ```python
106 | import pandas as pd # requires: pip install pandas
107 | import torch
108 | from chronos import BaseChronosPipeline
109 |
110 | pipeline = BaseChronosPipeline.from_pretrained(
111 | "amazon/chronos-t5-small", # use "amazon/chronos-bolt-small" for the corresponding Chronos-Bolt model
112 | device_map="cuda", # use "cpu" for CPU inference
113 | torch_dtype=torch.bfloat16,
114 | )
115 |
116 | df = pd.read_csv(
117 | "https://raw.githubusercontent.com/AileenNielsen/TimeSeriesAnalysisWithPython/master/data/AirPassengers.csv"
118 | )
119 |
120 | # context must be either a 1D tensor, a list of 1D tensors,
121 | # or a left-padded 2D tensor with batch as the first dimension
122 | # quantiles is an fp32 tensor with shape [batch_size, prediction_length, num_quantile_levels]
123 | # mean is an fp32 tensor with shape [batch_size, prediction_length]
124 | quantiles, mean = pipeline.predict_quantiles(
125 | context=torch.tensor(df["#Passengers"]),
126 | prediction_length=12,
127 | quantile_levels=[0.1, 0.5, 0.9],
128 | )
129 | ```
130 |
131 | For the original Chronos models, `pipeline.predict` can be used to draw forecast samples. More options for `predict_kwargs` in `pipeline.predict_quantiles` can be found with:
132 |
133 | ```python
134 | from chronos import ChronosPipeline, ChronosBoltPipeline
135 |
136 | print(ChronosPipeline.predict.__doc__) # for Chronos models
137 | print(ChronosBoltPipeline.predict.__doc__) # for Chronos-Bolt models
138 | ```
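
Returning to `pipeline.predict` mentioned above, a minimal sketch of drawing sample paths with the original (sample-based) Chronos models, reusing the pipeline and data loaded earlier:

```python
# Applies to the original Chronos models; Chronos-Bolt pipelines return quantiles rather than samples.
# samples is expected to have shape [num_series, num_samples, prediction_length]
samples = pipeline.predict(
    context=torch.tensor(df["#Passengers"]),
    prediction_length=12,
)
```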
139 |
140 | We can now visualize the forecast:
141 |
142 | ```python
143 | import matplotlib.pyplot as plt # requires: pip install matplotlib
144 |
145 | forecast_index = range(len(df), len(df) + 12)
146 | low, median, high = quantiles[0, :, 0], quantiles[0, :, 1], quantiles[0, :, 2]
147 |
148 | plt.figure(figsize=(8, 4))
149 | plt.plot(df["#Passengers"], color="royalblue", label="historical data")
150 | plt.plot(forecast_index, median, color="tomato", label="median forecast")
151 | plt.fill_between(forecast_index, low, high, color="tomato", alpha=0.3, label="80% prediction interval")
152 | plt.legend()
153 | plt.grid()
154 | plt.show()
155 | ```
156 |
157 | ### Extracting Encoder Embeddings
158 |
159 | A minimal example showing how to extract encoder embeddings from Chronos models:
160 |
161 | ```python
162 | import pandas as pd
163 | import torch
164 | from chronos import ChronosPipeline
165 |
166 | pipeline = ChronosPipeline.from_pretrained(
167 | "amazon/chronos-t5-small",
168 | device_map="cuda",
169 | torch_dtype=torch.bfloat16,
170 | )
171 |
172 | df = pd.read_csv("https://raw.githubusercontent.com/AileenNielsen/TimeSeriesAnalysisWithPython/master/data/AirPassengers.csv")
173 |
174 | # context must be either a 1D tensor, a list of 1D tensors,
175 | # or a left-padded 2D tensor with batch as the first dimension
176 | context = torch.tensor(df["#Passengers"])
177 | embeddings, tokenizer_state = pipeline.embed(context)
178 | ```
179 |
180 | ### Pretraining, fine-tuning and evaluation
181 |
182 | Scripts for pretraining, fine-tuning and evaluating Chronos models can be found in [this folder](./scripts/).
183 |
184 | ## :floppy_disk: Datasets
185 |
186 | Datasets used in the Chronos paper for pretraining and evaluation (both in-domain and zero-shot) are available through the HuggingFace repos: [`autogluon/chronos_datasets`](https://huggingface.co/datasets/autogluon/chronos_datasets) and [`autogluon/chronos_datasets_extra`](https://huggingface.co/datasets/autogluon/chronos_datasets_extra). Check out these repos for instructions on how to download and use the datasets.
187 |
188 | ## 🔥 Coverage
189 |
190 | - [Adapting language model architectures for time series forecasting](https://www.amazon.science/blog/adapting-language-model-architectures-for-time-series-forecasting) (Amazon Science blog post)
191 | - [Amazon AI Researchers Introduce Chronos: A New Machine Learning Framework for Pretrained Probabilistic Time Series Models](https://www.marktechpost.com/2024/03/15/amazon-ai-researchers-introduce-chronos-a-new-machine-learning-framework-for-pretrained-probabilistic-time-series-models/) (Marktechpost blog post)
192 | - [Chronos: The Rise of Foundation Models for Time Series Forecasting](https://towardsdatascience.com/chronos-the-rise-of-foundation-models-for-time-series-forecasting-aaeba62d9da3) (Towards Data Science blog post by Luís Roque and Rafael Guedes)
193 | - [Moirai: Time Series Foundation Models for Universal Forecasting](https://towardsdatascience.com/moirai-time-series-foundation-models-for-universal-forecasting-dc93f74b330f) (Towards Data Science blog post by Luís Roque and Rafael Guedes, includes a comparison of Chronos with Moirai)
194 | - [Chronos: The Latest Time Series Forecasting Foundation Model by Amazon](https://towardsdatascience.com/chronos-the-latest-time-series-forecasting-foundation-model-by-amazon-2687d641705a) (Towards Data Science blog post by Marco Peixeiro)
195 | - The original article had a critical bug affecting the metric computation for Chronos. We opened a [pull request](https://github.com/marcopeix/time-series-analysis/pull/10) to fix it.
196 | - [How to Effectively Forecast Time Series with Amazon's New Time Series Forecasting Model](https://towardsdatascience.com/how-to-effectively-forecast-time-series-with-amazons-new-time-series-forecasting-model-9e04d4ccf67e) (Towards Data Science blog post by Eivind Kjosbakken)
197 | - [Chronos: Learning the Language of Time Series](https://minimizeregret.com/linked/2024/03/27/chronos-forecasting/) (Minimize Regret blog post by Tim Radtke)
198 | - [Chronos: Another Zero-Shot Time Series Forecaster LLM](https://levelup.gitconnected.com/chronos-another-zero-shot-time-series-forecaster-llm-0e80753a7ad0) (Level Up Coding blog post by Level Up Coding AI TutorMaster)
199 | - [Paper Review: Chronos: Learning the Language of Time Series](https://andlukyane.com/blog/paper-review-chronos) (Review by Andrey Lukyanenko)
200 | - [Foundation Models for Forecasting: the Future or Folly?](https://insights.radix.ai/blog/foundation-models-for-forecasting-the-future-or-folly) (Blog post by Radix)
201 | - [Learning the Language of Time Series with Chronos](https://medium.com/@ManueleCaddeo/learning-the-language-of-time-series-with-chronos-fea7d0fedde4) (Medium post by Manuele Caddeo)
202 | - [The latest advancement in Time Series Forecasting from AWS: Chronos](https://medium.com/chat-gpt-now-writes-all-my-articles/the-latest-advancement-in-time-series-forecasting-from-aws-chronos-python-code-included-0205d01248f3) (Medium post by Abish Pius)
203 | - [Decoding the Future: How Chronos Redefines Time Series Forecasting with the Art of Language](https://medium.com/@zamalbabar/decoding-the-future-how-chronos-redefines-time-series-forecasting-with-the-art-of-language-cecc2174e400) (Medium post by Zamal)
204 | - [Comparison of Chronos against the SCUM ensemble of statistical models](https://github.com/Nixtla/nixtla/tree/main/experiments/amazon-chronos) (Benchmark by Nixtla)
205 | - We opened a [pull request](https://github.com/Nixtla/nixtla/pull/281) extending the analysis to 28 datasets (200K+ time series) and showing that **zero-shot** Chronos models perform comparably to this strong ensemble of 4 statistical models while being significantly faster on average. Our complete response can be [found here](https://www.linkedin.com/pulse/extended-comparison-chronos-against-statistical-ensemble-ansari-4aste/).
206 | - [Comparison of Chronos against a variety of forecasting models](https://www.linkedin.com/feed/update/urn:li:activity:7178398371815051267/) (Benchmark by ReadyTensor)
207 |
208 | ## 📝 Citation
209 |
210 | If you find Chronos models useful for your research, please consider citing the associated [paper](https://arxiv.org/abs/2403.07815):
211 |
212 | ```
213 | @article{ansari2024chronos,
214 | title={Chronos: Learning the Language of Time Series},
215 |   author={Ansari, Abdul Fatir and Stella, Lorenzo and Turkmen, Caner and Zhang, Xiyuan and Mercado, Pedro and Shen, Huibin and Shchur, Oleksandr and Rangapuram, Syama Sundar and Pineda Arango, Sebastian and Kapoor, Shubham and Zschiegner, Jasper and Maddix, Danielle C. and Mahoney, Michael W. and Torkkola, Kari and Gordon Wilson, Andrew and Bohlke-Schneider, Michael and Wang, Yuyang},
216 | journal={Transactions on Machine Learning Research},
217 | issn={2835-8856},
218 | year={2024},
219 | url={https://openreview.net/forum?id=gerNCVqqtR}
220 | }
221 | ```
222 |
223 | ## 🛡️ Security
224 |
225 | See [CONTRIBUTING](CONTRIBUTING.md#security-issue-notifications) for more information.
226 |
227 | ## 📃 License
228 |
229 | This project is licensed under the Apache-2.0 License.
230 |
--------------------------------------------------------------------------------
/ci/evaluate/backtest_config.yaml:
--------------------------------------------------------------------------------
1 | # From In-domain
2 | - name: taxi_30min # 30 min
3 | hf_repo: autogluon/chronos_datasets
4 | offset: -48
5 | prediction_length: 48
6 | num_rolls: 1
7 | # From Zero-shot
8 | - name: ETTh # Hourly
9 | hf_repo: autogluon/chronos_datasets_extra
10 | offset: -24
11 | prediction_length: 24
12 | num_rolls: 1
13 | - name: monash_covid_deaths # Daily
14 | hf_repo: autogluon/chronos_datasets
15 | offset: -30
16 | prediction_length: 30
17 | num_rolls: 1
18 | - name: monash_nn5_weekly # Weekly
19 | hf_repo: autogluon/chronos_datasets
20 | offset: -8
21 | prediction_length: 8
22 | num_rolls: 1
23 | - name: monash_fred_md # Monthly
24 | hf_repo: autogluon/chronos_datasets
25 | offset: -12
26 | prediction_length: 12
27 | num_rolls: 1
28 | - name: monash_m3_quarterly # Quarterly
29 | hf_repo: autogluon/chronos_datasets
30 | offset: -8
31 | prediction_length: 8
32 | num_rolls: 1
33 | - name: monash_tourism_yearly # Yearly
34 | hf_repo: autogluon/chronos_datasets
35 | offset: -4
36 | prediction_length: 4
37 | num_rolls: 1
--------------------------------------------------------------------------------
/figures/chronos-logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/chronos-forecasting/6a9c8dadac04eb85befc935043e3e2cce914267f/figures/chronos-logo.png
--------------------------------------------------------------------------------
/figures/main-figure.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/chronos-forecasting/6a9c8dadac04eb85befc935043e3e2cce914267f/figures/main-figure.png
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "chronos-forecasting"
3 | version = "1.5.2"
4 | authors = [
5 | { name="Abdul Fatir Ansari", email="ansarnd@amazon.com" },
6 | { name="Lorenzo Stella", email="stellalo@amazon.com" },
7 | { name="Caner Turkmen", email="atturkm@amazon.com" },
8 | ]
9 | description = "Chronos: Pretrained models for time series forecasting"
10 | readme = "README.md"
11 | license = { file = "LICENSE" }
12 | requires-python = ">=3.9"
13 | dependencies = [
14 | "torch>=2.0,<3", # package was tested on 2.2
15 | "transformers>=4.48,<5",
16 | "accelerate>=0.32,<2",
17 | ]
18 | classifiers = [
19 | "Programming Language :: Python :: 3",
20 | "License :: OSI Approved :: Apache Software License",
21 | "Operating System :: OS Independent",
22 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
23 | ]
24 |
25 | [build-system]
26 | requires = ["hatchling"]
27 | build-backend = "hatchling.build"
28 |
29 | [tool.hatch.build.targets.wheel]
30 | packages = ["src/chronos"]
31 |
32 | [project.optional-dependencies]
33 | test = ["pytest~=8.0", "numpy~=1.21"]
34 | typecheck = ["mypy~=1.9"]
35 | training = ["gluonts[pro]~=0.15", "numpy~=1.21", "datasets~=2.18", "typer", "typer-config", "joblib", "scikit-learn", "tensorboard"]
36 | evaluation = ["gluonts[pro]~=0.15", "numpy~=1.21", "datasets~=2.18", "typer"]
37 |
38 | [project.urls]
39 | Homepage = "https://github.com/amazon-science/chronos-forecasting"
40 | Issues = "https://github.com/amazon-science/chronos-forecasting/issues"
41 | Paper = "https://arxiv.org/abs/2403.07815"
42 |
43 | [tool.mypy]
44 | ignore_missing_imports = true
45 |
--------------------------------------------------------------------------------
/scripts/README.md:
--------------------------------------------------------------------------------
1 | # Usage Examples
2 |
3 | ## Generating Synthetic Time Series (KernelSynth)
4 |
5 | - Install this package with the `training` extra:
6 | ```
7 | pip install "chronos-forecasting[training] @ git+https://github.com/amazon-science/chronos-forecasting.git"
8 | ```
9 | - Run `kernel-synth.py`:
10 | ```sh
11 | # With defaults used in the paper (1M time series and 5 max_kernels)
12 | python kernel-synth.py
13 |
14 | # You may optionally specify num-series and max-kernels
15 | python kernel-synth.py \
16 | python kernel-synth.py \
17 |     --num-series <num_series> \
18 |     --max-kernels <max_kernels>
18 | ```
19 | The generated time series will be saved in a [GluonTS](https://github.com/awslabs/gluonts)-compatible arrow file `kernelsynth-data.arrow`.
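
For intuition, the snippet below is a minimal sketch of the KernelSynth idea: compose a few randomly chosen Gaussian process kernels and sample a series from the resulting prior. It is illustrative only and not the actual `kernel-synth.py` implementation (the kernel bank and parameters there differ); it uses `scikit-learn`, which is installed with the `training` extra.

```python
# Illustrative sketch of KernelSynth (not the actual kernel-synth.py script).
import numpy as np
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, ExpSineSquared, DotProduct

rng = np.random.default_rng(0)
X = np.linspace(0, 1, 1024).reshape(-1, 1)  # time grid for one series

# A small, hypothetical kernel bank; compose up to max_kernels of them at random.
kernel_bank = [RBF(length_scale=0.1), ExpSineSquared(periodicity=0.25), DotProduct()]
max_kernels = 5
kernel = kernel_bank[rng.integers(len(kernel_bank))]
for _ in range(rng.integers(0, max_kernels)):
    other = kernel_bank[rng.integers(len(kernel_bank))]
    kernel = kernel + other if rng.random() < 0.5 else kernel * other

# Sample one synthetic series from the GP prior defined by the composed kernel.
ts = GaussianProcessRegressor(kernel=kernel).sample_y(X, n_samples=1, random_state=0).squeeze()
```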
20 |
21 | ## Pretraining (and fine-tuning) Chronos models
22 | - Install this package with the `training` extra:
23 | ```
24 | pip install "chronos-forecasting[training] @ git+https://github.com/amazon-science/chronos-forecasting.git"
25 | ```
26 | - Convert your time series dataset into a GluonTS-compatible file dataset. We recommend using the arrow format. You may use the `convert_to_arrow` function from the following snippet for that. Optionally, you may use [synthetic data from KernelSynth](#generating-synthetic-time-series-kernelsynth) to follow along.
27 | ```py
28 | from pathlib import Path
29 | from typing import List, Union
30 |
31 | import numpy as np
32 | from gluonts.dataset.arrow import ArrowWriter
33 |
34 |
35 | def convert_to_arrow(
36 | path: Union[str, Path],
37 | time_series: Union[List[np.ndarray], np.ndarray],
38 | compression: str = "lz4",
39 | ):
40 | """
41 | Store a given set of series into Arrow format at the specified path.
42 |
43 | Input data can be either a list of 1D numpy arrays, or a single 2D
44 | numpy array of shape (num_series, time_length).
45 | """
46 | assert isinstance(time_series, list) or (
47 | isinstance(time_series, np.ndarray) and
48 | time_series.ndim == 2
49 | )
50 |
51 | # Set an arbitrary start time
52 | start = np.datetime64("2000-01-01 00:00", "s")
53 |
54 | dataset = [
55 | {"start": start, "target": ts} for ts in time_series
56 | ]
57 |
58 | ArrowWriter(compression=compression).write_to_file(
59 | dataset,
60 | path=path,
61 | )
62 |
63 |
64 | if __name__ == "__main__":
65 | # Generate 20 random time series of length 1024
66 | time_series = [np.random.randn(1024) for i in range(20)]
67 |
68 | # Convert to GluonTS arrow format
69 | convert_to_arrow("./noise-data.arrow", time_series=time_series)
70 | ```
71 | - Modify the [training configs](training/configs) to use your data. Let's use the KernelSynth data as an example.
72 | ```yaml
73 | # List of training data files
74 | training_data_paths:
75 | - "/path/to/kernelsynth-data.arrow"
76 | # Mixing probability of each dataset file
77 | probability:
78 | - 1.0
79 | ```
80 | You may optionally change other parameters of the config file, as required. For instance, if you're interested in fine-tuning the model from a pretrained Chronos checkpoint, you should change the `model_id`, set `random_init: false`, and (optionally) change other parameters such as `max_steps` and `learning_rate`.
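  For instance, a fine-tuning run might override the following keys (a sketch based only on the parameters mentioned above; consult the files in [training/configs](training/configs) for the full set of options and their exact names):
  ```yaml
  # Hypothetical fine-tuning overrides; other keys keep the values from the shipped configs.
  training_data_paths:
    - "/path/to/kernelsynth-data.arrow"
  probability:
    - 1.0
  model_id: amazon/chronos-t5-small   # start from a pretrained Chronos checkpoint
  random_init: false                  # fine-tune instead of training from scratch
  max_steps: 1000
  learning_rate: 0.001
  ```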
81 | - Start the training (or fine-tuning) job:
82 | ```sh
83 | # On single GPU
84 | CUDA_VISIBLE_DEVICES=0 python training/train.py --config /path/to/modified/config.yaml
85 |
86 | # On multiple GPUs (example with 8 GPUs)
87 | torchrun --nproc-per-node=8 training/train.py --config /path/to/modified/config.yaml
88 |
89 | # Fine-tune `amazon/chronos-t5-small` for 1000 steps with initial learning rate of 1e-3
90 | CUDA_VISIBLE_DEVICES=0 python training/train.py --config /path/to/modified/config.yaml \
91 | --model-id amazon/chronos-t5-small \
92 | --no-random-init \
93 | --max-steps 1000 \
94 | --learning-rate 0.001
95 | ```
96 | The output and checkpoints will be saved in `output/run-{id}/`.
97 | > [!TIP]
98 | > If the initial training step is too slow, you might want to change the `shuffle_buffer_length` and/or set `torch_compile` to `false`.
99 |
100 | > [!IMPORTANT]
101 | > When pretraining causal models (such as GPT2), the training script applies [`LastValueImputation`](https://github.com/awslabs/gluonts/blob/f0f2266d520cb980f4c1ce18c28b003ad5cd2599/src/gluonts/transform/feature.py#L103) to missing values by default. If you pretrain causal models, please ensure that missing values are imputed similarly before passing the context tensor to `ChronosPipeline.predict()` to obtain accurate results.
102 | - (Optional) Once trained, you can easily push your fine-tuned model to HuggingFace🤗 Hub. Before that, do not forget to [create an access token](https://huggingface.co/settings/tokens) with **write permissions** and put it in `~/.cache/huggingface/token`. Here's a snippet that will push a fine-tuned model to HuggingFace🤗 Hub at `<your_username>/chronos-t5-small-fine-tuned`.
103 | ```py
104 | from chronos import ChronosPipeline
105 |
106 | pipeline = ChronosPipeline.from_pretrained("/path/to/fine-tuned/model/ckpt/dir/")
107 | pipeline.model.model.push_to_hub("chronos-t5-small-fine-tuned")
108 | ```
109 |
110 | ## Evaluating Chronos models
111 |
112 | Follow these steps to compute the WQL and MASE values for the in-domain and zero-shot benchmarks in our paper.
113 |
114 | - Install this package with the `evaluation` extra:
115 | ```
116 | pip install "chronos-forecasting[evaluation] @ git+https://github.com/amazon-science/chronos-forecasting.git"
117 | ```
118 | - Run the evaluation script:
119 | ```sh
120 | # In-domain evaluation
121 | # Results will be saved in: evaluation/results/chronos-t5-small-in-domain.csv
122 | python evaluation/evaluate.py evaluation/configs/in-domain.yaml evaluation/results/chronos-t5-small-in-domain.csv \
123 | --chronos-model-id "amazon/chronos-t5-small" \
124 | --batch-size=32 \
125 | --device=cuda:0 \
126 | --num-samples 20
127 |
128 | # Zero-shot evaluation
129 | # Results will be saved in: evaluation/results/chronos-t5-small-zero-shot.csv
130 | python evaluation/evaluate.py evaluation/configs/zero-shot.yaml evaluation/results/chronos-t5-small-zero-shot.csv \
131 | --chronos-model-id "amazon/chronos-t5-small" \
132 | --batch-size=32 \
133 | --device=cuda:0 \
134 | --num-samples 20
135 | ```
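The same commands work with Chronos-Bolt checkpoints. Since these models output quantile forecasts directly, `--num-samples` has no effect, and substantially larger batch sizes are usually feasible; the batch size below is only illustrative:
```sh
# Zero-shot evaluation with a Chronos-Bolt model
# Results will be saved in: evaluation/results/chronos-bolt-small-zero-shot.csv
python evaluation/evaluate.py evaluation/configs/zero-shot.yaml evaluation/results/chronos-bolt-small-zero-shot.csv \
    --chronos-model-id "amazon/chronos-bolt-small" \
    --batch-size=256 \
    --device=cuda:0
```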
136 | - Use the following snippet to compute the aggregated relative WQL and MASE scores:
137 | ```py
138 | import pandas as pd
139 | from scipy.stats import gmean # requires: pip install scipy
140 |
141 |
142 | def agg_relative_score(model_df: pd.DataFrame, baseline_df: pd.DataFrame):
143 | relative_score = model_df.drop("model", axis="columns") / baseline_df.drop(
144 | "model", axis="columns"
145 | )
146 | return relative_score.agg(gmean)
147 |
148 |
149 | result_df = pd.read_csv("evaluation/results/chronos-t5-small-in-domain.csv").set_index("dataset")
150 | baseline_df = pd.read_csv("evaluation/results/seasonal-naive-in-domain.csv").set_index("dataset")
151 |
152 | agg_score_df = agg_relative_score(result_df, baseline_df)
153 | ```
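The same computation is also wrapped as a small CLI in [`evaluation/agg-relative-score.py`](evaluation/agg-relative-score.py), which reads `{model_name}-in-domain.csv` and `{model_name}-zero-shot.csv` from the results directory and writes `{model_name}-agg-rel-scores.csv`. For example:
```sh
# Aggregate relative scores for chronos-t5-small against the seasonal-naive baseline
python evaluation/agg-relative-score.py chronos-t5-small
```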
--------------------------------------------------------------------------------
/scripts/evaluation/agg-relative-score.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import typer
3 | from scipy.stats import gmean
4 | from pathlib import Path
5 |
6 | app = typer.Typer(pretty_exceptions_enable=False)
7 | DEFAULT_RESULTS_DIR = Path(__file__).parent / "results"
8 |
9 |
10 | def agg_relative_score(model_csv: Path, baseline_csv: Path):
11 | model_df = pd.read_csv(model_csv).set_index("dataset")
12 | baseline_df = pd.read_csv(baseline_csv).set_index("dataset")
13 | relative_score = model_df.drop("model", axis="columns") / baseline_df.drop(
14 | "model", axis="columns"
15 | )
16 | return relative_score.agg(gmean)
17 |
18 |
19 | @app.command()
20 | def main(
21 | model_name: str,
22 | baseline_name: str = "seasonal-naive",
23 | results_dir: Path = DEFAULT_RESULTS_DIR,
24 | ):
25 | """
26 | Compute the aggregated relative score as reported in the Chronos paper.
27 | Results will be saved to {results_dir}/{model_name}-agg-rel-scores.csv
28 |
29 | Parameters
30 | ----------
31 | model_name : str
32 | Name of the model used in the CSV files. The in-domain and zero-shot CSVs
33 | are expected to be named {model_name}-in-domain.csv and {model_name}-zero-shot.csv.
34 | results_dir : Path, optional, default = results/
35 | Directory where results CSVs generated by evaluate.py are stored
36 | """
37 |
38 | in_domain_agg_score_df = agg_relative_score(
39 | results_dir / f"{model_name}-in-domain.csv",
40 | results_dir / f"{baseline_name}-in-domain.csv",
41 | )
42 | in_domain_agg_score_df.name = "value"
43 | in_domain_agg_score_df.index.name = "metric"
44 |
45 | zero_shot_agg_score_df = agg_relative_score(
46 | results_dir / f"{model_name}-zero-shot.csv",
47 | results_dir / f"{baseline_name}-zero-shot.csv",
48 | )
49 | zero_shot_agg_score_df.name = "value"
50 | zero_shot_agg_score_df.index.name = "metric"
51 |
52 | agg_score_df = pd.concat(
53 | {"in-domain": in_domain_agg_score_df, "zero-shot": zero_shot_agg_score_df},
54 | names=["benchmark"],
55 | )
56 | agg_score_df.to_csv(f"{results_dir}/{model_name}-agg-rel-scores.csv")
57 |
58 |
59 | if __name__ == "__main__":
60 | app()
61 |
--------------------------------------------------------------------------------
/scripts/evaluation/configs/in-domain.yaml:
--------------------------------------------------------------------------------
1 | # Backtest configs for the 15 "in-domain" datasets.
2 | # The training portion of these datasets was part of the
3 | # training corpus for Chronos models.
4 | - name: electricity_15min
5 | hf_repo: autogluon/chronos_datasets
6 | offset: -5376
7 | prediction_length: 24
8 | num_rolls: 1
9 | - name: monash_electricity_hourly
10 | hf_repo: autogluon/chronos_datasets
11 | offset: -24
12 | prediction_length: 24
13 | num_rolls: 1
14 | - name: monash_electricity_weekly
15 | hf_repo: autogluon/chronos_datasets
16 | offset: -8
17 | prediction_length: 8
18 | num_rolls: 1
19 | - name: monash_kdd_cup_2018
20 | hf_repo: autogluon/chronos_datasets
21 | offset: -48
22 | prediction_length: 48
23 | num_rolls: 1
24 | - name: m4_daily
25 | hf_repo: autogluon/chronos_datasets
26 | offset: -14
27 | prediction_length: 14
28 | num_rolls: 1
29 | - name: m4_hourly
30 | hf_repo: autogluon/chronos_datasets
31 | offset: -48
32 | prediction_length: 48
33 | num_rolls: 1
34 | - name: m4_monthly
35 | hf_repo: autogluon/chronos_datasets
36 | offset: -18
37 | prediction_length: 18
38 | num_rolls: 1
39 | - name: m4_weekly
40 | hf_repo: autogluon/chronos_datasets
41 | offset: -13
42 | prediction_length: 13
43 | num_rolls: 1
44 | - name: monash_pedestrian_counts
45 | hf_repo: autogluon/chronos_datasets
46 | offset: -48
47 | prediction_length: 48
48 | num_rolls: 1
49 | - name: taxi_30min
50 | hf_repo: autogluon/chronos_datasets
51 | offset: -48
52 | prediction_length: 48
53 | num_rolls: 1
54 | - name: uber_tlc_hourly
55 | hf_repo: autogluon/chronos_datasets
56 | offset: -24
57 | prediction_length: 24
58 | num_rolls: 1
59 | - name: uber_tlc_daily
60 | hf_repo: autogluon/chronos_datasets
61 | offset: -7
62 | prediction_length: 7
63 | num_rolls: 1
64 | - name: monash_rideshare
65 | hf_repo: autogluon/chronos_datasets
66 | offset: -24
67 | prediction_length: 24
68 | num_rolls: 1
69 | - name: monash_temperature_rain
70 | hf_repo: autogluon/chronos_datasets
71 | offset: -30
72 | prediction_length: 30
73 | num_rolls: 1
74 | - name: monash_london_smart_meters
75 | hf_repo: autogluon/chronos_datasets
76 | offset: -48
77 | prediction_length: 48
78 | num_rolls: 1
79 |
--------------------------------------------------------------------------------
/scripts/evaluation/configs/zero-shot.yaml:
--------------------------------------------------------------------------------
1 | # Backtest configs for the 27 "zero-shot" datasets.
2 | # These datasets were not seen by Chronos models during training.
3 | - name: monash_traffic
4 | hf_repo: autogluon/chronos_datasets
5 | offset: -24
6 | prediction_length: 24
7 | num_rolls: 1
8 | - name: monash_australian_electricity
9 | hf_repo: autogluon/chronos_datasets
10 | offset: -48
11 | prediction_length: 48
12 | num_rolls: 1
13 | - name: ercot
14 | hf_repo: autogluon/chronos_datasets
15 | offset: -24
16 | prediction_length: 24
17 | num_rolls: 1
18 | - name: ETTm
19 | hf_repo: autogluon/chronos_datasets_extra
20 | offset: -96
21 | prediction_length: 24
22 | num_rolls: 1
23 | - name: ETTh
24 | hf_repo: autogluon/chronos_datasets_extra
25 | offset: -24
26 | prediction_length: 24
27 | num_rolls: 1
28 | - name: exchange_rate
29 | hf_repo: autogluon/chronos_datasets
30 | offset: -30
31 | prediction_length: 30
32 | num_rolls: 1
33 | - name: nn5
34 | hf_repo: autogluon/chronos_datasets
35 | offset: -56
36 | prediction_length: 56
37 | num_rolls: 1
38 | - name: monash_nn5_weekly
39 | hf_repo: autogluon/chronos_datasets
40 | offset: -8
41 | prediction_length: 8
42 | num_rolls: 1
43 | - name: monash_weather
44 | hf_repo: autogluon/chronos_datasets
45 | offset: -30
46 | prediction_length: 30
47 | num_rolls: 1
48 | - name: monash_covid_deaths
49 | hf_repo: autogluon/chronos_datasets
50 | offset: -30
51 | prediction_length: 30
52 | num_rolls: 1
53 | - name: monash_fred_md
54 | hf_repo: autogluon/chronos_datasets
55 | offset: -12
56 | prediction_length: 12
57 | num_rolls: 1
58 | - name: m4_quarterly
59 | hf_repo: autogluon/chronos_datasets
60 | offset: -8
61 | prediction_length: 8
62 | num_rolls: 1
63 | - name: m4_yearly
64 | hf_repo: autogluon/chronos_datasets
65 | offset: -6
66 | prediction_length: 6
67 | num_rolls: 1
68 | - name: dominick
69 | hf_repo: autogluon/chronos_datasets
70 | offset: -8
71 | prediction_length: 8
72 | num_rolls: 1
73 | - name: m5
74 | hf_repo: autogluon/chronos_datasets
75 | offset: -28
76 | prediction_length: 28
77 | num_rolls: 1
78 | - name: monash_tourism_monthly
79 | hf_repo: autogluon/chronos_datasets
80 | offset: -24
81 | prediction_length: 24
82 | num_rolls: 1
83 | - name: monash_tourism_quarterly
84 | hf_repo: autogluon/chronos_datasets
85 | offset: -8
86 | prediction_length: 8
87 | num_rolls: 1
88 | - name: monash_tourism_yearly
89 | hf_repo: autogluon/chronos_datasets
90 | offset: -4
91 | prediction_length: 4
92 | num_rolls: 1
93 | - name: monash_car_parts
94 | hf_repo: autogluon/chronos_datasets
95 | offset: -12
96 | prediction_length: 12
97 | num_rolls: 1
98 | - name: monash_hospital
99 | hf_repo: autogluon/chronos_datasets
100 | offset: -12
101 | prediction_length: 12
102 | num_rolls: 1
103 | - name: monash_cif_2016
104 | hf_repo: autogluon/chronos_datasets
105 | offset: -12
106 | prediction_length: 12
107 | num_rolls: 1
108 | - name: monash_m1_yearly
109 | hf_repo: autogluon/chronos_datasets
110 | offset: -6
111 | prediction_length: 6
112 | num_rolls: 1
113 | - name: monash_m1_quarterly
114 | hf_repo: autogluon/chronos_datasets
115 | offset: -8
116 | prediction_length: 8
117 | num_rolls: 1
118 | - name: monash_m1_monthly
119 | hf_repo: autogluon/chronos_datasets
120 | offset: -18
121 | prediction_length: 18
122 | num_rolls: 1
123 | - name: monash_m3_monthly
124 | hf_repo: autogluon/chronos_datasets
125 | offset: -18
126 | prediction_length: 18
127 | num_rolls: 1
128 | - name: monash_m3_yearly
129 | hf_repo: autogluon/chronos_datasets
130 | offset: -6
131 | prediction_length: 6
132 | num_rolls: 1
133 | - name: monash_m3_quarterly
134 | hf_repo: autogluon/chronos_datasets
135 | offset: -8
136 | prediction_length: 8
137 | num_rolls: 1
--------------------------------------------------------------------------------
/scripts/evaluation/evaluate.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from pathlib import Path
3 | from typing import Iterable, Optional
4 |
5 | import datasets
6 | import numpy as np
7 | import pandas as pd
8 | import torch
9 | import typer
10 | import yaml
11 | from gluonts.dataset.split import split
12 | from gluonts.ev.metrics import MASE, MeanWeightedSumQuantileLoss
13 | from gluonts.itertools import batcher
14 | from gluonts.model.evaluation import evaluate_forecasts
15 | from gluonts.model.forecast import QuantileForecast, SampleForecast
16 | from tqdm.auto import tqdm
17 |
18 | from chronos import (
19 | BaseChronosPipeline,
20 | ChronosBoltPipeline,
21 | ChronosPipeline,
22 | ForecastType,
23 | )
24 |
25 | app = typer.Typer(pretty_exceptions_enable=False)
26 |
27 |
28 | def to_gluonts_univariate(hf_dataset: datasets.Dataset):
29 | series_fields = [
30 | col
31 | for col in hf_dataset.features
32 | if isinstance(hf_dataset.features[col], datasets.Sequence)
33 | ]
34 | series_fields.remove("timestamp")
35 | dataset_length = hf_dataset.info.splits["train"].num_examples * len(series_fields)
36 |
37 | # Assumes that all time series in the dataset have the same frequency
38 | dataset_freq = pd.DatetimeIndex(hf_dataset[0]["timestamp"]).to_period()[0].freqstr
39 |
40 | gts_dataset = []
41 | for hf_entry in hf_dataset:
42 | for field in series_fields:
43 | gts_dataset.append(
44 | {
45 | "start": pd.Period(
46 | hf_entry["timestamp"][0],
47 | freq=dataset_freq,
48 | ),
49 | "target": hf_entry[field],
50 | }
51 | )
52 | assert len(gts_dataset) == dataset_length
53 |
54 | return gts_dataset
55 |
56 |
57 | def load_and_split_dataset(backtest_config: dict):
58 | hf_repo = backtest_config["hf_repo"]
59 | dataset_name = backtest_config["name"]
60 | offset = backtest_config["offset"]
61 | prediction_length = backtest_config["prediction_length"]
62 | num_rolls = backtest_config["num_rolls"]
63 |
64 | # This is needed because the datasets in autogluon/chronos_datasets_extra cannot
65 |     # be distributed due to license restrictions and must be generated on the fly
66 |     trust_remote_code = hf_repo == "autogluon/chronos_datasets_extra"
67 |
68 | ds = datasets.load_dataset(
69 | hf_repo, dataset_name, split="train", trust_remote_code=trust_remote_code
70 | )
71 | ds.set_format("numpy")
72 |
73 | gts_dataset = to_gluonts_univariate(ds)
74 |
75 | # Split dataset for evaluation
76 | _, test_template = split(gts_dataset, offset=offset)
77 | test_data = test_template.generate_instances(prediction_length, windows=num_rolls)
78 |
79 | return test_data
80 |
81 |
82 | def generate_forecasts(
83 | test_data_input: Iterable,
84 | pipeline: BaseChronosPipeline,
85 | prediction_length: int,
86 | batch_size: int,
87 | **predict_kwargs,
88 | ):
89 | # Generate forecasts
90 | forecast_outputs = []
91 | for batch in tqdm(batcher(test_data_input, batch_size=batch_size)):
92 | context = [torch.tensor(entry["target"]) for entry in batch]
93 | forecast_outputs.append(
94 | pipeline.predict(
95 | context,
96 | prediction_length=prediction_length,
97 | **predict_kwargs,
98 | ).numpy()
99 | )
100 | forecast_outputs = np.concatenate(forecast_outputs)
101 |
102 | # Convert forecast samples into gluonts Forecast objects
103 | forecasts = []
104 | for item, ts in zip(forecast_outputs, test_data_input):
105 | forecast_start_date = ts["start"] + len(ts["target"])
106 |
107 | if pipeline.forecast_type == ForecastType.SAMPLES:
108 | forecasts.append(
109 | SampleForecast(samples=item, start_date=forecast_start_date)
110 | )
111 | elif pipeline.forecast_type == ForecastType.QUANTILES:
112 | forecasts.append(
113 | QuantileForecast(
114 | forecast_arrays=item,
115 | forecast_keys=list(map(str, pipeline.quantiles)),
116 | start_date=forecast_start_date,
117 | )
118 | )
119 |
120 | return forecasts
121 |
122 |
123 | @app.command()
124 | def main(
125 | config_path: Path,
126 | metrics_path: Path,
127 | chronos_model_id: str = "amazon/chronos-t5-small",
128 | device: str = "cuda",
129 | torch_dtype: str = "bfloat16",
130 | batch_size: int = 32,
131 | num_samples: int = 20,
132 | temperature: Optional[float] = None,
133 | top_k: Optional[int] = None,
134 | top_p: Optional[float] = None,
135 | ):
136 | """Evaluate Chronos models.
137 |
138 | Parameters
139 | ----------
140 | config_path : Path
141 | Path to the evaluation config. See ./configs/.
142 | metrics_path : Path
143 | Path to the CSV file where metrics will be saved.
144 | chronos_model_id : str, optional, default = "amazon/chronos-t5-small"
145 | HuggingFace ID of the Chronos model or local path
146 | Available models on HuggingFace:
147 | Chronos:
148 | - amazon/chronos-t5-tiny
149 | - amazon/chronos-t5-mini
150 | - amazon/chronos-t5-small
151 | - amazon/chronos-t5-base
152 | - amazon/chronos-t5-large
153 | Chronos-Bolt:
154 | - amazon/chronos-bolt-tiny
155 | - amazon/chronos-bolt-mini
156 | - amazon/chronos-bolt-small
157 | - amazon/chronos-bolt-base
158 | device : str, optional, default = "cuda"
159 | Device on which inference will be performed
160 | torch_dtype : str, optional
161 | Model's dtype, by default "bfloat16"
162 | batch_size : int, optional, default = 32
163 | Batch size for inference. For Chronos-Bolt models, significantly larger
164 | batch sizes can be used
165 | num_samples : int, optional, default = 20
166 | Number of samples to draw when using the original Chronos models
167 | temperature : Optional[float], optional, default = 1.0
168 |         Softmax temperature used for the original Chronos models
169 | top_k : Optional[int], optional, default = 50
170 |         Top-k sampling parameter used for the original Chronos models
171 | top_p : Optional[float], optional, default = 1.0
172 |         Top-p sampling parameter used for the original Chronos models
173 | """
174 | if isinstance(torch_dtype, str):
175 | torch_dtype = getattr(torch, torch_dtype)
176 | assert isinstance(torch_dtype, torch.dtype)
177 |
178 | # Load Chronos
179 | pipeline = BaseChronosPipeline.from_pretrained(
180 | chronos_model_id,
181 | device_map=device,
182 | torch_dtype=torch_dtype,
183 | )
184 |
185 | if isinstance(pipeline, ChronosPipeline):
186 | predict_kwargs = dict(
187 | num_samples=num_samples,
188 | temperature=temperature,
189 | top_k=top_k,
190 | top_p=top_p,
191 | )
192 | elif isinstance(pipeline, ChronosBoltPipeline):
193 | predict_kwargs = {}
194 |
195 | # Load backtest configs
196 | with open(config_path) as fp:
197 | backtest_configs = yaml.safe_load(fp)
198 |
199 | result_rows = []
200 | for config in backtest_configs:
201 | dataset_name = config["name"]
202 | prediction_length = config["prediction_length"]
203 |
204 | logger.info(f"Loading {dataset_name}")
205 | test_data = load_and_split_dataset(backtest_config=config)
206 |
207 | logger.info(
208 | f"Generating forecasts for {dataset_name} "
209 | f"({len(test_data.input)} time series)"
210 | )
211 | forecasts = generate_forecasts(
212 | test_data.input,
213 | pipeline=pipeline,
214 | prediction_length=prediction_length,
215 | batch_size=batch_size,
216 | **predict_kwargs,
217 | )
218 |
219 | logger.info(f"Evaluating forecasts for {dataset_name}")
220 | metrics = (
221 | evaluate_forecasts(
222 | forecasts,
223 | test_data=test_data,
224 | metrics=[
225 | MASE(),
226 | MeanWeightedSumQuantileLoss(np.arange(0.1, 1.0, 0.1)),
227 | ],
228 | batch_size=5000,
229 | )
230 | .reset_index(drop=True)
231 | .to_dict(orient="records")
232 | )
233 | result_rows.append(
234 | {"dataset": dataset_name, "model": chronos_model_id, **metrics[0]}
235 | )
236 |
237 | # Save results to a CSV file
238 | results_df = (
239 | pd.DataFrame(result_rows)
240 | .rename(
241 | {"MASE[0.5]": "MASE", "mean_weighted_sum_quantile_loss": "WQL"},
242 | axis="columns",
243 | )
244 | .sort_values(by="dataset")
245 | )
246 | results_df.to_csv(metrics_path, index=False)
247 |
248 |
249 | if __name__ == "__main__":
250 | logging.basicConfig(format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
251 | logger = logging.getLogger("Chronos Evaluation")
252 | logger.setLevel(logging.INFO)
253 | app()
254 |
--------------------------------------------------------------------------------
/scripts/evaluation/results/chronos-bolt-base-agg-rel-scores.csv:
--------------------------------------------------------------------------------
1 | benchmark,metric,value
2 | in-domain,MASE,0.6800133628315155
3 | in-domain,WQL,0.5339263811489279
4 | zero-shot,MASE,0.7914551113353537
5 | zero-shot,WQL,0.6241424984163773
6 |
--------------------------------------------------------------------------------
/scripts/evaluation/results/chronos-bolt-base-in-domain.csv:
--------------------------------------------------------------------------------
1 | dataset,model,MASE,WQL
2 | electricity_15min,amazon/chronos-bolt-base,0.41069374835605243,0.0703533790998506
3 | m4_daily,amazon/chronos-bolt-base,3.205192517121196,0.02110308498174413
4 | m4_hourly,amazon/chronos-bolt-base,0.8350129849014075,0.025353803894164
5 | m4_monthly,amazon/chronos-bolt-base,0.9491758928362231,0.09382496106659234
6 | m4_weekly,amazon/chronos-bolt-base,2.0847827409162742,0.03816605075768161
7 | monash_electricity_hourly,amazon/chronos-bolt-base,1.254966217685461,0.09442192616975713
8 | monash_electricity_weekly,amazon/chronos-bolt-base,1.8391546050108039,0.06410971963960499
9 | monash_kdd_cup_2018,amazon/chronos-bolt-base,0.6405985809360102,0.2509172188706336
10 | monash_london_smart_meters,amazon/chronos-bolt-base,0.701398572604996,0.3218915088923906
11 | monash_pedestrian_counts,amazon/chronos-bolt-base,0.2646412642278343,0.18789459806066328
12 | monash_rideshare,amazon/chronos-bolt-base,0.7695376426829713,0.11637119433040358
13 | monash_temperature_rain,amazon/chronos-bolt-base,0.8983612698773724,0.6050555216496304
14 | taxi_30min,amazon/chronos-bolt-base,0.7688908266765317,0.2363178601205094
15 | uber_tlc_daily,amazon/chronos-bolt-base,0.8231767493519677,0.0926036406916842
16 | uber_tlc_hourly,amazon/chronos-bolt-base,0.6632193728217927,0.14987786887626975
17 |
--------------------------------------------------------------------------------
/scripts/evaluation/results/chronos-bolt-base-zero-shot.csv:
--------------------------------------------------------------------------------
1 | dataset,model,MASE,WQL
2 | ETTh,amazon/chronos-bolt-base,0.7479154031956647,0.07062173821055001
3 | ETTm,amazon/chronos-bolt-base,0.6334357237512225,0.052261607745858835
4 | dominick,amazon/chronos-bolt-base,0.8560272479913918,0.3453573743726445
5 | ercot,amazon/chronos-bolt-base,0.6933217425507392,0.02142183038021456
6 | exchange_rate,amazon/chronos-bolt-base,1.7095176257412634,0.01200682136751536
7 | m4_quarterly,amazon/chronos-bolt-base,1.2244670010522907,0.0771066518089854
8 | m4_yearly,amazon/chronos-bolt-base,3.513752058541554,0.12142798053483984
9 | m5,amazon/chronos-bolt-base,0.9152230096463854,0.561999688057527
10 | monash_australian_electricity,amazon/chronos-bolt-base,0.7403239930185613,0.03584034231329335
11 | monash_car_parts,amazon/chronos-bolt-base,0.8550263912438314,0.9945122291263591
12 | monash_cif_2016,amazon/chronos-bolt-base,0.9988541862779904,0.016456104842296485
13 | monash_covid_deaths,amazon/chronos-bolt-base,38.901749109066415,0.047410971217640714
14 | monash_fred_md,amazon/chronos-bolt-base,0.6468787708795645,0.04185083716355386
15 | monash_hospital,amazon/chronos-bolt-base,0.6883138394434054,0.057032869931903894
16 | monash_m1_monthly,amazon/chronos-bolt-base,1.0997677446267855,0.1392311148066238
17 | monash_m1_quarterly,amazon/chronos-bolt-base,1.7737851980875563,0.1007118219350403
18 | monash_m1_yearly,amazon/chronos-bolt-base,4.404672537832342,0.1504617654430952
19 | monash_m3_monthly,amazon/chronos-bolt-base,0.8510696834878182,0.09269673913736748
20 | monash_m3_quarterly,amazon/chronos-bolt-base,1.2890908822598466,0.07615133571216029
21 | monash_m3_yearly,amazon/chronos-bolt-base,2.9067097980770082,0.12934285625258413
22 | monash_nn5_weekly,amazon/chronos-bolt-base,0.9158766337957451,0.08352114810139548
23 | monash_tourism_monthly,amazon/chronos-bolt-base,1.5283388458731357,0.09026425492612797
24 | monash_tourism_quarterly,amazon/chronos-bolt-base,1.756127005530011,0.06448060953595125
25 | monash_tourism_yearly,amazon/chronos-bolt-base,3.691545772463519,0.16548820700844424
26 | monash_traffic,amazon/chronos-bolt-base,0.7843310867739336,0.23148632068725078
27 | monash_weather,amazon/chronos-bolt-base,0.8115247139672316,0.13350830777170594
28 | nn5,amazon/chronos-bolt-base,0.5764084996361287,0.1500519584148468
29 |
--------------------------------------------------------------------------------
/scripts/evaluation/results/chronos-bolt-mini-agg-rel-scores.csv:
--------------------------------------------------------------------------------
1 | benchmark,metric,value
2 | in-domain,MASE,0.7268373301543752
3 | in-domain,WQL,0.565140251955324
4 | zero-shot,MASE,0.8221798917822493
5 | zero-shot,WQL,0.6441645845380903
6 |
--------------------------------------------------------------------------------
/scripts/evaluation/results/chronos-bolt-mini-in-domain.csv:
--------------------------------------------------------------------------------
1 | dataset,model,MASE,WQL
2 | electricity_15min,amazon/chronos-bolt-mini,0.44185193304080733,0.0731477927531107
3 | m4_daily,amazon/chronos-bolt-mini,3.1342608828747456,0.0206872246743766
4 | m4_hourly,amazon/chronos-bolt-mini,0.9218285923038745,0.024383114886067574
5 | m4_monthly,amazon/chronos-bolt-mini,0.9628339921394529,0.09502498697494888
6 | m4_weekly,amazon/chronos-bolt-mini,2.2330452369879255,0.039393515325238534
7 | monash_electricity_hourly,amazon/chronos-bolt-mini,1.6195944363428718,0.11468972600782207
8 | monash_electricity_weekly,amazon/chronos-bolt-mini,1.866105365159433,0.06019900031840434
9 | monash_kdd_cup_2018,amazon/chronos-bolt-mini,0.74790954883436,0.3012661161484388
10 | monash_london_smart_meters,amazon/chronos-bolt-mini,0.7187830347765344,0.32984510693830227
11 | monash_pedestrian_counts,amazon/chronos-bolt-mini,0.308633944815819,0.23331301029432483
12 | monash_rideshare,amazon/chronos-bolt-mini,0.818948044410056,0.1297966960374544
13 | monash_temperature_rain,amazon/chronos-bolt-mini,0.9035244443682741,0.605031064086567
14 | taxi_30min,amazon/chronos-bolt-mini,0.812010120941363,0.25232294549917317
15 | uber_tlc_daily,amazon/chronos-bolt-mini,0.8507256206478295,0.10101757743084538
16 | uber_tlc_hourly,amazon/chronos-bolt-mini,0.6685484898085609,0.1515245941548974
17 |
--------------------------------------------------------------------------------
/scripts/evaluation/results/chronos-bolt-mini-zero-shot.csv:
--------------------------------------------------------------------------------
1 | dataset,model,MASE,WQL
2 | ETTh,amazon/chronos-bolt-mini,0.8057126710113404,0.07740387596411452
3 | ETTm,amazon/chronos-bolt-mini,0.6100793941108849,0.05129333450944573
4 | dominick,amazon/chronos-bolt-mini,0.8664152477208024,0.3499696999160997
5 | ercot,amazon/chronos-bolt-mini,0.6871250728215426,0.02448804863744021
6 | exchange_rate,amazon/chronos-bolt-mini,1.3520551553333662,0.00934663373172766
7 | m4_quarterly,amazon/chronos-bolt-mini,1.2569644266281508,0.07833787023275976
8 | m4_yearly,amazon/chronos-bolt-mini,3.7611003052413796,0.12931927951165456
9 | m5,amazon/chronos-bolt-mini,0.9188876472137485,0.5661303206519673
10 | monash_australian_electricity,amazon/chronos-bolt-mini,0.8823559450287066,0.04493688824488474
11 | monash_car_parts,amazon/chronos-bolt-mini,0.8604081423647779,1.0041876404811494
12 | monash_cif_2016,amazon/chronos-bolt-mini,1.0762361363763873,0.017641893717784202
13 | monash_covid_deaths,amazon/chronos-bolt-mini,38.83915011538576,0.06098317835750057
14 | monash_fred_md,amazon/chronos-bolt-mini,0.6169859211923081,0.03256236965040934
15 | monash_hospital,amazon/chronos-bolt-mini,0.6924431064606051,0.05766349075348645
16 | monash_m1_monthly,amazon/chronos-bolt-mini,1.147893030263777,0.13270222658510553
17 | monash_m1_quarterly,amazon/chronos-bolt-mini,1.8662100001165818,0.09846363409254102
18 | monash_m1_yearly,amazon/chronos-bolt-mini,5.319154632748303,0.16167328827180308
19 | monash_m3_monthly,amazon/chronos-bolt-mini,0.8758452776118432,0.09493431248614057
20 | monash_m3_quarterly,amazon/chronos-bolt-mini,1.3555175243802005,0.07808062465932723
21 | monash_m3_yearly,amazon/chronos-bolt-mini,3.605769430055575,0.15711010456482008
22 | monash_nn5_weekly,amazon/chronos-bolt-mini,0.9347141924977239,0.08522899825844342
23 | monash_tourism_monthly,amazon/chronos-bolt-mini,1.649587479665881,0.0979648261309891
24 | monash_tourism_quarterly,amazon/chronos-bolt-mini,1.8471553663088986,0.06501077791766902
25 | monash_tourism_yearly,amazon/chronos-bolt-mini,3.9932920493826245,0.1743539122097316
26 | monash_traffic,amazon/chronos-bolt-mini,0.8355442361271347,0.24351051123330386
27 | monash_weather,amazon/chronos-bolt-mini,0.800013628350165,0.13041050756802045
28 | nn5,amazon/chronos-bolt-mini,0.611917632501032,0.1570111102680171
29 |
--------------------------------------------------------------------------------
/scripts/evaluation/results/chronos-bolt-small-agg-rel-scores.csv:
--------------------------------------------------------------------------------
1 | benchmark,metric,value
2 | in-domain,MASE,0.7030801652116672
3 | in-domain,WQL,0.5443547623341555
4 | zero-shot,MASE,0.8192127745093378
5 | zero-shot,WQL,0.6356097843099521
6 |
--------------------------------------------------------------------------------
/scripts/evaluation/results/chronos-bolt-small-in-domain.csv:
--------------------------------------------------------------------------------
1 | dataset,model,MASE,WQL
2 | electricity_15min,amazon/chronos-bolt-small,0.44920089250026723,0.08115291306964295
3 | m4_daily,amazon/chronos-bolt-small,3.201966619014735,0.02143368277732494
4 | m4_hourly,amazon/chronos-bolt-small,0.8686298207618999,0.020368729287465817
5 | m4_monthly,amazon/chronos-bolt-small,0.9537717737278778,0.0939247807527992
6 | m4_weekly,amazon/chronos-bolt-small,2.1236755094789177,0.03785184715517262
7 | monash_electricity_hourly,amazon/chronos-bolt-small,1.3728906161330452,0.09452411472431674
8 | monash_electricity_weekly,amazon/chronos-bolt-small,1.8703239487242378,0.06648479071326366
9 | monash_kdd_cup_2018,amazon/chronos-bolt-small,0.6458631909979771,0.25148489931571666
10 | monash_london_smart_meters,amazon/chronos-bolt-small,0.7126939688565166,0.326874529903459
11 | monash_pedestrian_counts,amazon/chronos-bolt-small,0.3015070035798365,0.2285590441093863
12 | monash_rideshare,amazon/chronos-bolt-small,0.823726965741684,0.12409769473500927
13 | monash_temperature_rain,amazon/chronos-bolt-small,0.8980348827836525,0.5984819599873311
14 | taxi_30min,amazon/chronos-bolt-small,0.7597818149895785,0.2348569752311862
15 | uber_tlc_daily,amazon/chronos-bolt-small,0.8460854328036702,0.09666483354735897
16 | uber_tlc_hourly,amazon/chronos-bolt-small,0.6662547495017634,0.1524256346268063
17 |
--------------------------------------------------------------------------------
/scripts/evaluation/results/chronos-bolt-small-zero-shot.csv:
--------------------------------------------------------------------------------
1 | dataset,model,MASE,WQL
2 | ETTh,amazon/chronos-bolt-small,0.792521748651108,0.07590654063011319
3 | ETTm,amazon/chronos-bolt-small,0.6209623928936988,0.05056189722606397
4 | dominick,amazon/chronos-bolt-small,0.8706134610400587,0.34811141409475416
5 | ercot,amazon/chronos-bolt-small,0.7562857616685997,0.02596064260343696
6 | exchange_rate,amazon/chronos-bolt-small,1.774835301692689,0.011363548847621512
7 | m4_quarterly,amazon/chronos-bolt-small,1.2478142413437487,0.07808795122806232
8 | m4_yearly,amazon/chronos-bolt-small,3.6925595655002574,0.12772564181388502
9 | m5,amazon/chronos-bolt-small,0.9195435643571084,0.5668430814831332
10 | monash_australian_electricity,amazon/chronos-bolt-small,0.8128424798841111,0.041509852162861564
11 | monash_car_parts,amazon/chronos-bolt-small,0.8584574663781737,1.0074689402521324
12 | monash_cif_2016,amazon/chronos-bolt-small,1.0182471909074982,0.01581964877692293
13 | monash_covid_deaths,amazon/chronos-bolt-small,36.467595559655145,0.0427382859406882
14 | monash_fred_md,amazon/chronos-bolt-small,0.6132863794635253,0.03730410577241995
15 | monash_hospital,amazon/chronos-bolt-small,0.6954489513780618,0.058119864671526154
16 | monash_m1_monthly,amazon/chronos-bolt-small,1.1277621848099244,0.1335656174632902
17 | monash_m1_quarterly,amazon/chronos-bolt-small,1.8356144904231688,0.09363028483838018
18 | monash_m1_yearly,amazon/chronos-bolt-small,5.098146069746402,0.15669928873371905
19 | monash_m3_monthly,amazon/chronos-bolt-small,0.8685125121306435,0.09396568468255145
20 | monash_m3_quarterly,amazon/chronos-bolt-small,1.3269103591066727,0.07691022995374203
21 | monash_m3_yearly,amazon/chronos-bolt-small,3.40993282700627,0.1547639821304127
22 | monash_nn5_weekly,amazon/chronos-bolt-small,0.9266513350636507,0.08452821221908001
23 | monash_tourism_monthly,amazon/chronos-bolt-small,1.6106732721197876,0.09362336754317802
24 | monash_tourism_quarterly,amazon/chronos-bolt-small,1.8357819365308639,0.06734337535269994
25 | monash_tourism_yearly,amazon/chronos-bolt-small,3.8963100495394194,0.16766064312072784
26 | monash_traffic,amazon/chronos-bolt-small,0.8598507749866499,0.25173786112983054
27 | monash_weather,amazon/chronos-bolt-small,0.8020408743877911,0.13258563963844888
28 | nn5,amazon/chronos-bolt-small,0.5833047644729239,0.15066847836762787
29 |
--------------------------------------------------------------------------------
/scripts/evaluation/results/chronos-bolt-tiny-agg-rel-scores.csv:
--------------------------------------------------------------------------------
1 | benchmark,metric,value
2 | in-domain,MASE,0.7403252781013574
3 | in-domain,WQL,0.5733728165523524
4 | zero-shot,MASE,0.8445407343705457
5 | zero-shot,WQL,0.6678781905023173
6 |
--------------------------------------------------------------------------------
/scripts/evaluation/results/chronos-bolt-tiny-in-domain.csv:
--------------------------------------------------------------------------------
1 | dataset,model,MASE,WQL
2 | electricity_15min,amazon/chronos-bolt-tiny,0.4676384089765091,0.0861229808117837
3 | m4_daily,amazon/chronos-bolt-tiny,3.1789994761356795,0.020961883512815756
4 | m4_hourly,amazon/chronos-bolt-tiny,0.9348005698736752,0.021087527284114574
5 | m4_monthly,amazon/chronos-bolt-tiny,0.965298729632761,0.0950380483243082
6 | m4_weekly,amazon/chronos-bolt-tiny,2.261575511029903,0.04093653263178429
7 | monash_electricity_hourly,amazon/chronos-bolt-tiny,1.5739346351263623,0.10808418398945202
8 | monash_electricity_weekly,amazon/chronos-bolt-tiny,1.8628689103722829,0.05773335283584782
9 | monash_kdd_cup_2018,amazon/chronos-bolt-tiny,0.6869549985391232,0.28012801092758166
10 | monash_london_smart_meters,amazon/chronos-bolt-tiny,0.7284234905933779,0.33496438244693033
11 | monash_pedestrian_counts,amazon/chronos-bolt-tiny,0.32338947321773864,0.2530637833749087
12 | monash_rideshare,amazon/chronos-bolt-tiny,0.8562780835002918,0.1304317657933891
13 | monash_temperature_rain,amazon/chronos-bolt-tiny,0.9030707620825977,0.6064087080755548
14 | taxi_30min,amazon/chronos-bolt-tiny,0.9122159603256838,0.28002194370731626
15 | uber_tlc_daily,amazon/chronos-bolt-tiny,0.9087055420190513,0.11193388685815164
16 | uber_tlc_hourly,amazon/chronos-bolt-tiny,0.6716569179590032,0.15310845458208555
17 |
--------------------------------------------------------------------------------
/scripts/evaluation/results/chronos-bolt-tiny-zero-shot.csv:
--------------------------------------------------------------------------------
1 | dataset,model,MASE,WQL
2 | ETTh,amazon/chronos-bolt-tiny,0.7941225847155844,0.07480860969990633
3 | ETTm,amazon/chronos-bolt-tiny,0.6508270995240056,0.05440068825429993
4 | dominick,amazon/chronos-bolt-tiny,0.876060127216559,0.35175949052933253
5 | ercot,amazon/chronos-bolt-tiny,0.7309134980173839,0.02468604544464515
6 | exchange_rate,amazon/chronos-bolt-tiny,1.6857262567077134,0.011477224264784112
7 | m4_quarterly,amazon/chronos-bolt-tiny,1.2605908919338378,0.0789049420017836
8 | m4_yearly,amazon/chronos-bolt-tiny,3.7118394116161757,0.1286932555969197
9 | m5,amazon/chronos-bolt-tiny,0.9195469670062033,0.5634881835998845
10 | monash_australian_electricity,amazon/chronos-bolt-tiny,0.8419304693259403,0.042040993880313904
11 | monash_car_parts,amazon/chronos-bolt-tiny,0.8625579150452282,1.0009987800801836
12 | monash_cif_2016,amazon/chronos-bolt-tiny,1.095219642027011,0.017550336784241796
13 | monash_covid_deaths,amazon/chronos-bolt-tiny,40.674057986280744,0.06723714516685976
14 | monash_fred_md,amazon/chronos-bolt-tiny,0.6127387450520702,0.04747523852271518
15 | monash_hospital,amazon/chronos-bolt-tiny,0.6980246281225624,0.05864223243167421
16 | monash_m1_monthly,amazon/chronos-bolt-tiny,1.1625495971731141,0.13142237467151166
17 | monash_m1_quarterly,amazon/chronos-bolt-tiny,1.8941765599193754,0.09972207844232561
18 | monash_m1_yearly,amazon/chronos-bolt-tiny,5.136332694531757,0.160331813128038
19 | monash_m3_monthly,amazon/chronos-bolt-tiny,0.8744553726704598,0.09435519378597752
20 | monash_m3_quarterly,amazon/chronos-bolt-tiny,1.364563776692303,0.07875066385737857
21 | monash_m3_yearly,amazon/chronos-bolt-tiny,3.3685961410254928,0.15158076519486274
22 | monash_nn5_weekly,amazon/chronos-bolt-tiny,0.9324436794013877,0.0847385189968909
23 | monash_tourism_monthly,amazon/chronos-bolt-tiny,1.7895936775088157,0.1058167042693116
24 | monash_tourism_quarterly,amazon/chronos-bolt-tiny,2.095262637810499,0.0710732570354461
25 | monash_tourism_yearly,amazon/chronos-bolt-tiny,4.042821441327848,0.172613367251472
26 | monash_traffic,amazon/chronos-bolt-tiny,0.8836032533767518,0.2574297134210491
27 | monash_weather,amazon/chronos-bolt-tiny,0.8005348255663177,0.13111355494466076
28 | nn5,amazon/chronos-bolt-tiny,0.7228248498869763,0.1816913098894226
29 |
--------------------------------------------------------------------------------
/scripts/evaluation/results/chronos-t5-base-agg-rel-scores.csv:
--------------------------------------------------------------------------------
1 | benchmark,metric,value
2 | in-domain,MASE,0.7007558507277635
3 | in-domain,WQL,0.5786300105297922
4 | zero-shot,MASE,0.8155209321160994
5 | zero-shot,WQL,0.6424634919486323
6 |
--------------------------------------------------------------------------------
/scripts/evaluation/results/chronos-t5-base-in-domain.csv:
--------------------------------------------------------------------------------
1 | dataset,model,MASE,WQL
2 | electricity_15min,amazon/chronos-t5-base,0.39879754957261204,0.07738953262286181
3 | m4_daily,amazon/chronos-t5-base,3.160575865614404,0.02194256368254537
4 | m4_hourly,amazon/chronos-t5-base,0.6938747745332102,0.026354948301302205
5 | m4_monthly,amazon/chronos-t5-base,0.971951848755026,0.10355213196432872
6 | m4_weekly,amazon/chronos-t5-base,2.0143841267657945,0.03639741235815474
7 | monash_electricity_hourly,amazon/chronos-t5-base,1.5717251971297332,0.1078882125804548
8 | monash_electricity_weekly,amazon/chronos-t5-base,1.7862927210886668,0.06255982783148449
9 | monash_kdd_cup_2018,amazon/chronos-t5-base,0.6335225775496138,0.2684272353843692
10 | monash_london_smart_meters,amazon/chronos-t5-base,0.8362014889190201,0.4265549499082726
11 | monash_pedestrian_counts,amazon/chronos-t5-base,0.2817708325561419,0.20810108090665583
12 | monash_rideshare,amazon/chronos-t5-base,0.8614480533175364,0.1356591190888703
13 | monash_temperature_rain,amazon/chronos-t5-base,0.9692405156151607,0.660155448791624
14 | taxi_30min,amazon/chronos-t5-base,0.8186287575356217,0.26236060366367003
15 | uber_tlc_daily,amazon/chronos-t5-base,0.8338648311528079,0.0970875577681834
16 | uber_tlc_hourly,amazon/chronos-t5-base,0.6647193438331641,0.15436646659512512
17 |
--------------------------------------------------------------------------------
/scripts/evaluation/results/chronos-t5-base-zero-shot.csv:
--------------------------------------------------------------------------------
1 | dataset,model,MASE,WQL
2 | ETTh,amazon/chronos-t5-base,0.7653491494991778,0.08087267701042929
3 | ETTm,amazon/chronos-t5-base,0.7737006634032871,0.07008650633028274
4 | dominick,amazon/chronos-t5-base,0.8194044957573132,0.33201307438298133
5 | ercot,amazon/chronos-t5-base,0.5014399265038706,0.013589435745554596
6 | exchange_rate,amazon/chronos-t5-base,2.055616906406159,0.011066070028466317
7 | m4_quarterly,amazon/chronos-t5-base,1.2253036947743137,0.08327936201395683
8 | m4_yearly,amazon/chronos-t5-base,3.639991540990927,0.13539258375263963
9 | m5,amazon/chronos-t5-base,0.9391874615167101,0.5867234116216755
10 | monash_australian_electricity,amazon/chronos-t5-base,1.2944069383163321,0.07070604202031877
11 | monash_car_parts,amazon/chronos-t5-base,0.9071940271035218,1.077797124337994
12 | monash_cif_2016,amazon/chronos-t5-base,0.9840747802099565,0.011825556826558836
13 | monash_covid_deaths,amazon/chronos-t5-base,42.68503365359237,0.042229910495746356
14 | monash_fred_md,amazon/chronos-t5-base,0.4857773806790164,0.021204829049512715
15 | monash_hospital,amazon/chronos-t5-base,0.7053005021431749,0.05630687524507516
16 | monash_m1_monthly,amazon/chronos-t5-base,1.1153039466137842,0.12724419775326076
17 | monash_m1_quarterly,amazon/chronos-t5-base,1.746093728928804,0.1123583549291933
18 | monash_m1_yearly,amazon/chronos-t5-base,4.401291522370069,0.18541586641719554
19 | monash_m3_monthly,amazon/chronos-t5-base,0.8627172231908679,0.09640536232169555
20 | monash_m3_quarterly,amazon/chronos-t5-base,1.1696030904401578,0.07392876900131434
21 | monash_m3_yearly,amazon/chronos-t5-base,3.1298600218573775,0.1486674447940158
22 | monash_nn5_weekly,amazon/chronos-t5-base,0.9334860602210187,0.08972736821598823
23 | monash_tourism_monthly,amazon/chronos-t5-base,1.7937702435879332,0.10260220444264027
24 | monash_tourism_quarterly,amazon/chronos-t5-base,1.7791494997972261,0.06852507950474919
25 | monash_tourism_yearly,amazon/chronos-t5-base,3.8359926053603197,0.20722699382964643
26 | monash_traffic,amazon/chronos-t5-base,0.8015262383138622,0.25565153982140926
27 | monash_weather,amazon/chronos-t5-base,0.8159511190589147,0.13802320967454584
28 | nn5,amazon/chronos-t5-base,0.5927076179914024,0.1630476065585159
29 |
--------------------------------------------------------------------------------
/scripts/evaluation/results/chronos-t5-large-agg-rel-scores.csv:
--------------------------------------------------------------------------------
1 | benchmark,metric,value
2 | in-domain,MASE,0.6944869734691035
3 | in-domain,WQL,0.5596857927462495
4 | zero-shot,MASE,0.8213682201405101
5 | zero-shot,WQL,0.6504834081319559
6 |
--------------------------------------------------------------------------------
/scripts/evaluation/results/chronos-t5-large-in-domain.csv:
--------------------------------------------------------------------------------
1 | dataset,model,MASE,WQL
2 | electricity_15min,amazon/chronos-t5-large,0.3866310906621673,0.07759528615667297
3 | m4_daily,amazon/chronos-t5-large,3.134560968849699,0.02158279722410466
4 | m4_hourly,amazon/chronos-t5-large,0.6975930649233378,0.02086427219957674
5 | m4_monthly,amazon/chronos-t5-large,0.9585550091429409,0.10091221432814867
6 | m4_weekly,amazon/chronos-t5-large,2.0191422600104425,0.036912838355537186
7 | monash_electricity_hourly,amazon/chronos-t5-large,1.4069912853901292,0.09642382339452431
8 | monash_electricity_weekly,amazon/chronos-t5-large,1.7501880036182798,0.05765306465830232
9 | monash_kdd_cup_2018,amazon/chronos-t5-large,0.6788042816175427,0.2853553329804835
10 | monash_london_smart_meters,amazon/chronos-t5-large,0.8290300790418726,0.4235436387853963
11 | monash_pedestrian_counts,amazon/chronos-t5-large,0.2764118100521592,0.18692234491663473
12 | monash_rideshare,amazon/chronos-t5-large,0.8758058784466208,0.140260325368757
13 | monash_temperature_rain,amazon/chronos-t5-large,0.9738403865035117,0.6604571928063249
14 | taxi_30min,amazon/chronos-t5-large,0.8245662397270109,0.2653520120326771
15 | uber_tlc_daily,amazon/chronos-t5-large,0.8044165990021739,0.09499035584302248
16 | uber_tlc_hourly,amazon/chronos-t5-large,0.6700665937164474,0.15190288476653066
17 |
--------------------------------------------------------------------------------
/scripts/evaluation/results/chronos-t5-large-zero-shot.csv:
--------------------------------------------------------------------------------
1 | dataset,model,MASE,WQL
2 | ETTh,amazon/chronos-t5-large,0.78160443631164,0.07884375667736107
3 | ETTm,amazon/chronos-t5-large,0.7325919639389967,0.06656858270921162
4 | dominick,amazon/chronos-t5-large,0.8200108271155829,0.3311575649734524
5 | ercot,amazon/chronos-t5-large,0.6050812633742764,0.01822996942395577
6 | exchange_rate,amazon/chronos-t5-large,2.3439287001928744,0.014841231672174684
7 | m4_quarterly,amazon/chronos-t5-large,1.2169666607868148,0.08235162400898562
8 | m4_yearly,amazon/chronos-t5-large,3.5524979814018947,0.1325675848907479
9 | m5,amazon/chronos-t5-large,0.9422990989146737,0.585615077637479
10 | monash_australian_electricity,amazon/chronos-t5-large,1.480849838497958,0.07973968848149568
11 | monash_car_parts,amazon/chronos-t5-large,0.901547374873302,1.0467398096496576
12 | monash_cif_2016,amazon/chronos-t5-large,0.9906388185665337,0.011966178555329998
13 | monash_covid_deaths,amazon/chronos-t5-large,44.07354193681227,0.06108999981222163
14 | monash_fred_md,amazon/chronos-t5-large,0.5184400880318044,0.01675533888399231
15 | monash_hospital,amazon/chronos-t5-large,0.7055308474630898,0.0552450850258613
16 | monash_m1_monthly,amazon/chronos-t5-large,1.0888995301234758,0.12729911122909737
17 | monash_m1_quarterly,amazon/chronos-t5-large,1.7477134564031453,0.10618253695380094
18 | monash_m1_yearly,amazon/chronos-t5-large,4.250667049416348,0.17128879333643188
19 | monash_m3_monthly,amazon/chronos-t5-large,0.8559326975903808,0.09572577431396007
20 | monash_m3_quarterly,amazon/chronos-t5-large,1.1867267751420676,0.07449254281607631
21 | monash_m3_yearly,amazon/chronos-t5-large,3.0239493021840635,0.14814710375646464
22 | monash_nn5_weekly,amazon/chronos-t5-large,0.9228721852437364,0.08948447200571868
23 | monash_tourism_monthly,amazon/chronos-t5-large,1.7304427846580348,0.09983169221760163
24 | monash_tourism_quarterly,amazon/chronos-t5-large,1.6437184365114073,0.0690906057781915
25 | monash_tourism_yearly,amazon/chronos-t5-large,3.6268503118928535,0.17732007043832695
26 | monash_traffic,amazon/chronos-t5-large,0.7985975530866148,0.25313515740581755
27 | monash_weather,amazon/chronos-t5-large,0.8187388457436171,0.1387756772600068
28 | nn5,amazon/chronos-t5-large,0.5755260854173723,0.15733693855465292
29 |
--------------------------------------------------------------------------------
/scripts/evaluation/results/chronos-t5-mini-agg-rel-scores.csv:
--------------------------------------------------------------------------------
1 | benchmark,metric,value
2 | in-domain,MASE,0.7249816823595568
3 | in-domain,WQL,0.5965372489622094
4 | zero-shot,MASE,0.8411995116926901
5 | zero-shot,WQL,0.6888397962259065
6 |
--------------------------------------------------------------------------------
/scripts/evaluation/results/chronos-t5-mini-in-domain.csv:
--------------------------------------------------------------------------------
1 | dataset,model,MASE,WQL
2 | electricity_15min,amazon/chronos-t5-mini,0.4446629660227641,0.08114657599239496
3 | m4_daily,amazon/chronos-t5-mini,3.1533349226194005,0.022000507013584743
4 | m4_hourly,amazon/chronos-t5-mini,0.7616830292996938,0.024630575107847653
5 | m4_monthly,amazon/chronos-t5-mini,0.9934074425853089,0.10402168689068064
6 | m4_weekly,amazon/chronos-t5-mini,2.1407189608104416,0.04138058102434373
7 | monash_electricity_hourly,amazon/chronos-t5-mini,1.3698378948313894,0.09189698159081384
8 | monash_electricity_weekly,amazon/chronos-t5-mini,1.9238345295706893,0.07015383787479901
9 | monash_kdd_cup_2018,amazon/chronos-t5-mini,0.6027861468526459,0.25493489598663444
10 | monash_london_smart_meters,amazon/chronos-t5-mini,0.8570035850603943,0.4356582737588471
11 | monash_pedestrian_counts,amazon/chronos-t5-mini,0.30374539593979855,0.2374083216051065
12 | monash_rideshare,amazon/chronos-t5-mini,0.8157349455509949,0.12963515638823117
13 | monash_temperature_rain,amazon/chronos-t5-mini,1.010161905102516,0.6919171702485583
14 | taxi_30min,amazon/chronos-t5-mini,0.9318379552979712,0.31229508015999674
15 | uber_tlc_daily,amazon/chronos-t5-mini,0.9213437323817685,0.10475291429149586
16 | uber_tlc_hourly,amazon/chronos-t5-mini,0.6812621470377416,0.15982192635434303
17 |
--------------------------------------------------------------------------------
/scripts/evaluation/results/chronos-t5-mini-zero-shot.csv:
--------------------------------------------------------------------------------
1 | dataset,model,MASE,WQL
2 | ETTh,amazon/chronos-t5-mini,0.789678971785092,0.08068969536800001
3 | ETTm,amazon/chronos-t5-mini,0.7521219674190734,0.06791782942706617
4 | dominick,amazon/chronos-t5-mini,0.8207116999488602,0.34004499734299765
5 | ercot,amazon/chronos-t5-mini,0.5462749489237783,0.015035001020343136
6 | exchange_rate,amazon/chronos-t5-mini,2.1326718165798657,0.015073846769933199
7 | m4_quarterly,amazon/chronos-t5-mini,1.271761811062081,0.08575942238385105
8 | m4_yearly,amazon/chronos-t5-mini,3.7340853642679126,0.13938781939783162
9 | m5,amazon/chronos-t5-mini,0.9421556321929742,0.5961689098871504
10 | monash_australian_electricity,amazon/chronos-t5-mini,1.046297291920238,0.05424453772723559
11 | monash_car_parts,amazon/chronos-t5-mini,0.8913523483805221,1.0174797526818506
12 | monash_cif_2016,amazon/chronos-t5-mini,1.0674111822055679,0.016800831829085764
13 | monash_covid_deaths,amazon/chronos-t5-mini,43.69727825485175,0.08788117644141617
14 | monash_fred_md,amazon/chronos-t5-mini,0.46227452519609524,0.01871860604459728
15 | monash_hospital,amazon/chronos-t5-mini,0.7112593459108532,0.05831005112661489
16 | monash_m1_monthly,amazon/chronos-t5-mini,1.1756557848450433,0.14192178371159841
17 | monash_m1_quarterly,amazon/chronos-t5-mini,1.795009199698074,0.11760148522768847
18 | monash_m1_yearly,amazon/chronos-t5-mini,5.078889706085604,0.1882823108615221
19 | monash_m3_monthly,amazon/chronos-t5-mini,0.900404391663476,0.09935931092075681
20 | monash_m3_quarterly,amazon/chronos-t5-mini,1.2604342624229292,0.07807204797138119
21 | monash_m3_yearly,amazon/chronos-t5-mini,3.4395976709464255,0.16085249526114198
22 | monash_nn5_weekly,amazon/chronos-t5-mini,0.9459117943913629,0.09042762527674755
23 | monash_tourism_monthly,amazon/chronos-t5-mini,1.920865545569713,0.10791754513335952
24 | monash_tourism_quarterly,amazon/chronos-t5-mini,1.7957439111869486,0.07514539225156464
25 | monash_tourism_yearly,amazon/chronos-t5-mini,4.134958090482728,0.2202036957350168
26 | monash_traffic,amazon/chronos-t5-mini,0.8546792774237857,0.2668831661775284
27 | monash_weather,amazon/chronos-t5-mini,0.8607748244159247,0.15031866806333247
28 | nn5,amazon/chronos-t5-mini,0.6497211196906223,0.17352254058241523
29 |
--------------------------------------------------------------------------------
/scripts/evaluation/results/chronos-t5-small-agg-rel-scores.csv:
--------------------------------------------------------------------------------
1 | benchmark,metric,value
2 | in-domain,MASE,0.7296140269944743
3 | in-domain,WQL,0.6086958548874499
4 | zero-shot,MASE,0.8303721909132112
5 | zero-shot,WQL,0.6649587072099045
6 |
--------------------------------------------------------------------------------
/scripts/evaluation/results/chronos-t5-small-in-domain.csv:
--------------------------------------------------------------------------------
1 | dataset,model,MASE,WQL
2 | electricity_15min,amazon/chronos-t5-small,0.4115559557750193,0.08085148902238105
3 | m4_daily,amazon/chronos-t5-small,3.1384304946608896,0.02129901023419818
4 | m4_hourly,amazon/chronos-t5-small,0.7300874075370588,0.024686127211237932
5 | m4_monthly,amazon/chronos-t5-small,0.9797264456494642,0.10297069145186107
6 | m4_weekly,amazon/chronos-t5-small,2.0802214537692607,0.03959222330783002
7 | monash_electricity_hourly,amazon/chronos-t5-small,1.530308399040219,0.10765947926209926
8 | monash_electricity_weekly,amazon/chronos-t5-small,1.9249616494404531,0.07593976499899265
9 | monash_kdd_cup_2018,amazon/chronos-t5-small,0.6911172359201715,0.2863722811236367
10 | monash_london_smart_meters,amazon/chronos-t5-small,0.8405756252443325,0.4300875548402115
11 | monash_pedestrian_counts,amazon/chronos-t5-small,0.30836963006151696,0.2442543970311678
12 | monash_rideshare,amazon/chronos-t5-small,0.8436277753840817,0.1363421932158997
13 | monash_temperature_rain,amazon/chronos-t5-small,1.0176003932416664,0.6847726381172435
14 | taxi_30min,amazon/chronos-t5-small,0.976277213614167,0.32770172988517626
15 | uber_tlc_daily,amazon/chronos-t5-small,0.8694727058784919,0.0994889223610958
16 | uber_tlc_hourly,amazon/chronos-t5-small,0.6738672444888639,0.1573990617753753
17 |
--------------------------------------------------------------------------------
/scripts/evaluation/results/chronos-t5-small-zero-shot.csv:
--------------------------------------------------------------------------------
1 | dataset,model,MASE,WQL
2 | ETTh,amazon/chronos-t5-small,0.8516754221042285,0.08667817580712385
3 | ETTm,amazon/chronos-t5-small,0.6825432730635727,0.06076472147001207
4 | dominick,amazon/chronos-t5-small,0.8108766032127683,0.3368104617474581
5 | ercot,amazon/chronos-t5-small,0.564879593858422,0.015547628920969682
6 | exchange_rate,amazon/chronos-t5-small,1.8143459139100264,0.014492477372711763
7 | m4_quarterly,amazon/chronos-t5-small,1.2415331521819728,0.08383826063189778
8 | m4_yearly,amazon/chronos-t5-small,3.738749650935195,0.1384514201649314
9 | m5,amazon/chronos-t5-small,0.9368713240675598,0.5896066252181699
10 | monash_australian_electricity,amazon/chronos-t5-small,1.2241146217392032,0.06951399165882449
11 | monash_car_parts,amazon/chronos-t5-small,0.8917508090523597,1.0314986717260015
12 | monash_cif_2016,amazon/chronos-t5-small,1.0187937383419037,0.014633240218233142
13 | monash_covid_deaths,amazon/chronos-t5-small,42.298997211368935,0.06339512778191682
14 | monash_fred_md,amazon/chronos-t5-small,0.4742159923922472,0.01486734736993978
15 | monash_hospital,amazon/chronos-t5-small,0.709814741753487,0.05704674270057172
16 | monash_m1_monthly,amazon/chronos-t5-small,1.1723041163998773,0.13799049510465802
17 | monash_m1_quarterly,amazon/chronos-t5-small,1.8077827825737092,0.11323432989795904
18 | monash_m1_yearly,amazon/chronos-t5-small,4.739967673537301,0.1730738338876877
19 | monash_m3_monthly,amazon/chronos-t5-small,0.8856577322724943,0.09985251429658573
20 | monash_m3_quarterly,amazon/chronos-t5-small,1.278907982396775,0.08094041554590593
21 | monash_m3_yearly,amazon/chronos-t5-small,3.382470310192457,0.157363937435307
22 | monash_nn5_weekly,amazon/chronos-t5-small,0.9277396908126303,0.08963913763368506
23 | monash_tourism_monthly,amazon/chronos-t5-small,1.9251180766131313,0.10943962474253494
24 | monash_tourism_quarterly,amazon/chronos-t5-small,1.7623454951333655,0.06862432764377493
25 | monash_tourism_yearly,amazon/chronos-t5-small,3.987690476709746,0.19960492460202509
26 | monash_traffic,amazon/chronos-t5-small,0.8204223927835267,0.2571189517024486
27 | monash_weather,amazon/chronos-t5-small,0.8550633590487968,0.1479701971025123
28 | nn5,amazon/chronos-t5-small,0.6130789183153671,0.16771392719859998
29 |
--------------------------------------------------------------------------------
/scripts/evaluation/results/chronos-t5-tiny-agg-rel-scores.csv:
--------------------------------------------------------------------------------
1 | benchmark,metric,value
2 | in-domain,MASE,0.7649019745781727
3 | in-domain,WQL,0.6288613368129368
4 | zero-shot,MASE,0.8704764463925718
5 | zero-shot,WQL,0.7108912052035352
6 |
--------------------------------------------------------------------------------
/scripts/evaluation/results/chronos-t5-tiny-in-domain.csv:
--------------------------------------------------------------------------------
1 | dataset,model,MASE,WQL
2 | electricity_15min,amazon/chronos-t5-tiny,0.5091784254243783,0.08236334376190152
3 | m4_daily,amazon/chronos-t5-tiny,3.203164895930929,0.022152192084951595
4 | m4_hourly,amazon/chronos-t5-tiny,0.8171321441164723,0.027490760558343874
5 | m4_monthly,amazon/chronos-t5-tiny,1.005839207921131,0.10388015368939435
6 | m4_weekly,amazon/chronos-t5-tiny,2.2148332313370735,0.043429655561156084
7 | monash_electricity_hourly,amazon/chronos-t5-tiny,1.6190021089002615,0.10967453530956882
8 | monash_electricity_weekly,amazon/chronos-t5-tiny,2.0774597917676734,0.08159998975612164
9 | monash_kdd_cup_2018,amazon/chronos-t5-tiny,0.6730886827096076,0.2616610603634618
10 | monash_london_smart_meters,amazon/chronos-t5-tiny,0.8830447519225436,0.4499607073491794
11 | monash_pedestrian_counts,amazon/chronos-t5-tiny,0.3042105240185045,0.23387631681117071
12 | monash_rideshare,amazon/chronos-t5-tiny,0.8431350112476247,0.1378817076926394
13 | monash_temperature_rain,amazon/chronos-t5-tiny,0.9887398447367799,0.6957797286648015
14 | taxi_30min,amazon/chronos-t5-tiny,1.035544060665179,0.3450476958104713
15 | uber_tlc_daily,amazon/chronos-t5-tiny,0.93025919000775,0.1105323649942084
16 | uber_tlc_hourly,amazon/chronos-t5-tiny,0.697558054147913,0.16320255844336232
17 |
--------------------------------------------------------------------------------
/scripts/evaluation/results/chronos-t5-tiny-zero-shot.csv:
--------------------------------------------------------------------------------
1 | dataset,model,MASE,WQL
2 | ETTh,amazon/chronos-t5-tiny,0.8184074113571701,0.08578203438707048
3 | ETTm,amazon/chronos-t5-tiny,0.9103621000781905,0.07975361086322658
4 | dominick,amazon/chronos-t5-tiny,0.8538295532466194,0.3597090770361857
5 | ercot,amazon/chronos-t5-tiny,0.7273437589773705,0.020843170924006626
6 | exchange_rate,amazon/chronos-t5-tiny,1.6621128608546154,0.01085145980896454
7 | m4_quarterly,amazon/chronos-t5-tiny,1.2696259955861924,0.0861404188925996
8 | m4_yearly,amazon/chronos-t5-tiny,3.5293881164900527,0.13281575565500411
9 | m5,amazon/chronos-t5-tiny,0.9394059505709506,0.5981531758388589
10 | monash_australian_electricity,amazon/chronos-t5-tiny,1.4558820561269024,0.07673567331332948
11 | monash_car_parts,amazon/chronos-t5-tiny,0.9058206654011024,1.0236307963149358
12 | monash_cif_2016,amazon/chronos-t5-tiny,1.09349564130852,0.014066593076202984
13 | monash_covid_deaths,amazon/chronos-t5-tiny,46.53079664940016,0.09201919385053775
14 | monash_fred_md,amazon/chronos-t5-tiny,0.48008374212956456,0.03219550761153211
15 | monash_hospital,amazon/chronos-t5-tiny,0.7062562198194838,0.05790409320432609
16 | monash_m1_monthly,amazon/chronos-t5-tiny,1.214892145549996,0.14723095246308077
17 | monash_m1_quarterly,amazon/chronos-t5-tiny,1.8968576926613199,0.11026972972622998
18 | monash_m1_yearly,amazon/chronos-t5-tiny,4.829453202075546,0.17286063726000958
19 | monash_m3_monthly,amazon/chronos-t5-tiny,0.9095746605884618,0.10117875324490073
20 | monash_m3_quarterly,amazon/chronos-t5-tiny,1.3234957548639883,0.08209032993637215
21 | monash_m3_yearly,amazon/chronos-t5-tiny,3.1489371074890093,0.1492445630072877
22 | monash_nn5_weekly,amazon/chronos-t5-tiny,0.9637480731663901,0.09205994784693056
23 | monash_tourism_monthly,amazon/chronos-t5-tiny,2.151677532807024,0.11356761694754255
24 | monash_tourism_quarterly,amazon/chronos-t5-tiny,1.9116538900950555,0.07191734222366106
25 | monash_tourism_yearly,amazon/chronos-t5-tiny,3.820615532600914,0.19709256337364625
26 | monash_traffic,amazon/chronos-t5-tiny,0.878709088458116,0.2632101606272236
27 | monash_weather,amazon/chronos-t5-tiny,0.8504899606521996,0.14787595319625085
28 | nn5,amazon/chronos-t5-tiny,0.7021735456568664,0.19071330483289695
29 |
--------------------------------------------------------------------------------
/scripts/evaluation/results/seasonal-naive-in-domain.csv:
--------------------------------------------------------------------------------
1 | dataset,model,MASE,WQL
2 | electricity_15min,seasonal-naive,0.4978697476132387,0.1169378163151378
3 | m4_daily,seasonal-naive,3.278424323759728,0.0279332664832445
4 | m4_hourly,seasonal-naive,1.1932105781333862,0.0483091941403194
5 | m4_monthly,seasonal-naive,1.2597170386001693,0.1455332906092934
6 | m4_weekly,seasonal-naive,2.777295109814942,0.0633986476090776
7 | monash_electricity_hourly,seasonal-naive,1.839634785956572,0.1468968206229902
8 | monash_electricity_weekly,seasonal-naive,3.0371656285424,0.1979332504059267
9 | monash_kdd_cup_2018,seasonal-naive,0.9943785889052376,0.5555856702439576
10 | monash_london_smart_meters,seasonal-naive,0.9661872287141056,0.5413187715028914
11 | monash_pedestrian_counts,seasonal-naive,0.3691951941442247,0.3185271550430794
12 | monash_rideshare,seasonal-naive,1.2495987545425715,0.1860080644135506
13 | monash_temperature_rain,seasonal-naive,2.243384627173123,1.4244854980220072
14 | taxi_30min,seasonal-naive,1.160268631066241,0.4711417890926274
15 | uber_tlc_daily,seasonal-naive,1.37803447078482,0.2313550175912078
16 | uber_tlc_hourly,seasonal-naive,0.930916273455971,0.298849044501192
17 |
--------------------------------------------------------------------------------
/scripts/evaluation/results/seasonal-naive-zero-shot.csv:
--------------------------------------------------------------------------------
1 | dataset,model,MASE,WQL
2 | ETTh,seasonal-naive,0.9316203114697056,0.1220896585205886
3 | ETTm,seasonal-naive,1.1693053852270578,0.1413480385734046
4 | dominick,seasonal-naive,0.8706150115348875,0.4529164093744346
5 | ercot,seasonal-naive,0.7613354813741452,0.0366036447606282
6 | exchange_rate,seasonal-naive,1.7401824286954128,0.0129841406759913
7 | m4_quarterly,seasonal-naive,1.6022471766126911,0.1186484661559648
8 | m4_yearly,seasonal-naive,3.974360261259571,0.1614389663357925
9 | m5,seasonal-naive,1.399206213076729,1.0240883478068443
10 | monash_australian_electricity,seasonal-naive,1.2533189641227642,0.0836951323308387
11 | monash_car_parts,seasonal-naive,1.2014638390969912,1.5999522140809177
12 | monash_cif_2016,seasonal-naive,1.289290577415544,0.0150830409089921
13 | monash_covid_deaths,seasonal-naive,46.91239825526407,0.1330848762571827
14 | monash_fred_md,seasonal-naive,1.1008000463101226,0.1222237702571737
15 | monash_hospital,seasonal-naive,0.9205278266364826,0.0726263373268254
16 | monash_m1_monthly,seasonal-naive,1.3144614957646543,0.1914632595030148
17 | monash_m1_quarterly,seasonal-naive,2.077536550805995,0.1495022062865622
18 | monash_m1_yearly,seasonal-naive,4.894322225232431,0.2092955931101782
19 | monash_m3_monthly,seasonal-naive,1.1462045758327934,0.1485446007554992
20 | monash_m3_quarterly,seasonal-naive,1.425343793700714,0.1012520529806161
21 | monash_m3_yearly,seasonal-naive,3.1717102364409517,0.1665329650420048
22 | monash_nn5_weekly,seasonal-naive,1.0628482559107015,0.1226908962169196
23 | monash_tourism_monthly,seasonal-naive,1.630939994944413,0.1041824322151567
24 | monash_tourism_quarterly,seasonal-naive,1.6989892627474672,0.1193750169177449
25 | monash_tourism_yearly,seasonal-naive,3.5520097206480883,0.2091826587673241
26 | monash_traffic,seasonal-naive,1.0767397173107436,0.3618532196990004
27 | monash_weather,seasonal-naive,1.0038475713182748,0.2165947349654047
28 | nn5,seasonal-naive,1.2917285866431214,0.4246208074843067
29 |
--------------------------------------------------------------------------------
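
The per-dataset CSVs above feed the aggregated scores in the *-agg-rel-scores.csv files. Below is a minimal sketch of that aggregation, assuming the convention of the Chronos evaluation setup (geometric mean of each model's metric relative to the seasonal-naive baseline); the exact logic lives in scripts/evaluation/agg-relative-score.py, and the file paths here are illustrative.

import numpy as np
import pandas as pd

def agg_relative_score(model_csv: str, baseline_csv: str, metric: str) -> float:
    # Align the per-dataset scores of the model and the baseline on the dataset column.
    model_df = pd.read_csv(model_csv).set_index("dataset")
    baseline_df = pd.read_csv(baseline_csv).set_index("dataset")
    relative = model_df[metric] / baseline_df.loc[model_df.index, metric]
    # Geometric mean of the per-dataset relative scores.
    return float(np.exp(np.log(relative).mean()))

score = agg_relative_score(
    "chronos-t5-tiny-zero-shot.csv", "seasonal-naive-zero-shot.csv", metric="WQL"
)
print(f"zero-shot,WQL,{score}")
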
/scripts/kernel-synth.py:
--------------------------------------------------------------------------------
1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 |
4 | import argparse
5 | import functools
6 | from pathlib import Path
7 | from typing import Optional
8 |
9 | import numpy as np
10 | from gluonts.dataset.arrow import ArrowWriter
11 | from joblib import Parallel, delayed
12 | from sklearn.gaussian_process import GaussianProcessRegressor
13 | from sklearn.gaussian_process.kernels import (
14 | RBF,
15 | ConstantKernel,
16 | DotProduct,
17 | ExpSineSquared,
18 | Kernel,
19 | RationalQuadratic,
20 | WhiteKernel,
21 | )
22 | from tqdm.auto import tqdm
23 |
24 | LENGTH = 1024
25 | KERNEL_BANK = [
26 | ExpSineSquared(periodicity=24 / LENGTH), # H
27 | ExpSineSquared(periodicity=48 / LENGTH), # 0.5H
28 | ExpSineSquared(periodicity=96 / LENGTH), # 0.25H
29 | ExpSineSquared(periodicity=24 * 7 / LENGTH), # H
30 | ExpSineSquared(periodicity=48 * 7 / LENGTH), # 0.5H
31 | ExpSineSquared(periodicity=96 * 7 / LENGTH), # 0.25H
32 | ExpSineSquared(periodicity=7 / LENGTH), # D
33 | ExpSineSquared(periodicity=14 / LENGTH), # 0.5D
34 | ExpSineSquared(periodicity=30 / LENGTH), # D
35 | ExpSineSquared(periodicity=60 / LENGTH), # 0.5D
36 | ExpSineSquared(periodicity=365 / LENGTH), # D
37 | ExpSineSquared(periodicity=365 * 2 / LENGTH), # 0.5D
38 | ExpSineSquared(periodicity=4 / LENGTH), # W
39 | ExpSineSquared(periodicity=26 / LENGTH), # W
40 | ExpSineSquared(periodicity=52 / LENGTH), # W
41 | ExpSineSquared(periodicity=4 / LENGTH), # M
42 | ExpSineSquared(periodicity=6 / LENGTH), # M
43 | ExpSineSquared(periodicity=12 / LENGTH), # M
44 | ExpSineSquared(periodicity=4 / LENGTH), # Q
45 | ExpSineSquared(periodicity=4 * 10 / LENGTH), # Q
46 | ExpSineSquared(periodicity=10 / LENGTH), # Y
47 | DotProduct(sigma_0=0.0),
48 | DotProduct(sigma_0=1.0),
49 | DotProduct(sigma_0=10.0),
50 | RBF(length_scale=0.1),
51 | RBF(length_scale=1.0),
52 | RBF(length_scale=10.0),
53 | RationalQuadratic(alpha=0.1),
54 | RationalQuadratic(alpha=1.0),
55 | RationalQuadratic(alpha=10.0),
56 | WhiteKernel(noise_level=0.1),
57 | WhiteKernel(noise_level=1.0),
58 | ConstantKernel(),
59 | ]
60 |
61 |
62 | def random_binary_map(a: Kernel, b: Kernel):
63 | """
64 | Applies a random binary operator (+ or *), chosen with equal probability,
65 | to kernels ``a`` and ``b``.
66 |
67 | Parameters
68 | ----------
69 | a
70 | A GP kernel.
71 | b
72 | A GP kernel.
73 |
74 | Returns
75 | -------
76 | The composite kernel `a + b` or `a * b`.
77 | """
78 | binary_maps = [lambda x, y: x + y, lambda x, y: x * y]
79 | return np.random.choice(binary_maps)(a, b)
80 |
81 |
82 | def sample_from_gp_prior(
83 | kernel: Kernel, X: np.ndarray, random_seed: Optional[int] = None
84 | ):
85 | """
86 | Draw a sample from a GP prior.
87 |
88 | Parameters
89 | ----------
90 | kernel
91 | The GP covariance kernel.
92 | X
93 | The input "time" points.
94 | random_seed, optional
95 | The random seed for sampling, by default None.
96 |
97 | Returns
98 | -------
99 | A time series sampled from the GP prior.
100 | """
101 | if X.ndim == 1:
102 | X = X[:, None]
103 |
104 | assert X.ndim == 2
105 | gpr = GaussianProcessRegressor(kernel=kernel)
106 | ts = gpr.sample_y(X, n_samples=1, random_state=random_seed)
107 |
108 | return ts
109 |
110 |
111 | def sample_from_gp_prior_efficient(
112 | kernel: Kernel,
113 | X: np.ndarray,
114 | random_seed: Optional[int] = None,
115 | method: str = "eigh",
116 | ):
117 | """
118 | Draw a sample from a GP prior. An efficient version that allows specification
119 | of the sampling method. The default sampling method used in GaussianProcessRegressor
120 | is based on SVD, which is significantly slower than alternatives such as `eigh` and
121 | `cholesky`.
122 |
123 | Parameters
124 | ----------
125 | kernel
126 | The GP covariance kernel.
127 | X
128 | The input "time" points.
129 | random_seed, optional
130 | The random seed for sampling, by default None.
131 | method, optional
132 | The sampling method for multivariate_normal, by default `eigh`.
133 |
134 | Returns
135 | -------
136 | A time series sampled from the GP prior.
137 | """
138 | if X.ndim == 1:
139 | X = X[:, None]
140 |
141 | assert X.ndim == 2
142 |
143 | cov = kernel(X)
144 | ts = np.random.default_rng(seed=random_seed).multivariate_normal(
145 | mean=np.zeros(X.shape[0]), cov=cov, method=method
146 | )
147 |
148 | return ts
149 |
150 |
151 | def generate_time_series(max_kernels: int = 5):
152 | """Generate a synthetic time series from KernelSynth.
153 |
154 | Parameters
155 | ----------
156 | max_kernels, optional
157 | The maximum number of base kernels to use for each time series, by default 5
158 |
159 | Returns
160 | -------
161 | A time series generated by KernelSynth.
162 | """
163 | while True:
164 | X = np.linspace(0, 1, LENGTH)
165 |
166 | # Randomly select up to max_kernels kernels from the KERNEL_BANK
167 | selected_kernels = np.random.choice(
168 | KERNEL_BANK, np.random.randint(1, max_kernels + 1), replace=True
169 | )
170 |
171 | # Combine the sampled kernels using random binary operators
172 | kernel = functools.reduce(random_binary_map, selected_kernels)
173 |
174 | # Sample a time series from the GP prior
175 | try:
176 | ts = sample_from_gp_prior(kernel=kernel, X=X)
177 | except np.linalg.LinAlgError as err:
178 | print("Error caught:", err)
179 | continue
180 |
181 | # The timestamp is arbitrary
182 | return {"start": np.datetime64("2000-01-01 00:00", "s"), "target": ts.squeeze()}
183 |
184 |
185 | if __name__ == "__main__":
186 | parser = argparse.ArgumentParser()
187 | parser.add_argument("-N", "--num-series", type=int, default=1000_000)
188 | parser.add_argument("-J", "--max-kernels", type=int, default=5)
189 | args = parser.parse_args()
190 | path = Path(__file__).parent / "kernelsynth-data.arrow"
191 |
192 | generated_dataset = Parallel(n_jobs=-1)(
193 | delayed(generate_time_series)(max_kernels=args.max_kernels)
194 | for _ in tqdm(range(args.num_series))
195 | )
196 |
197 | ArrowWriter(compression="lz4").write_to_file(
198 | generated_dataset,
199 | path=path,
200 | )
201 |
--------------------------------------------------------------------------------
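
A quick, hedged smoke test for the KernelSynth script above: generate a handful of series in-process and round-trip them through the same ArrowWriter/FileDataset pair that the training script uses. The loader-based import is only needed because the file name contains a hyphen; paths and counts are illustrative.

import importlib.util
from pathlib import Path

from gluonts.dataset.arrow import ArrowWriter
from gluonts.dataset.common import FileDataset

# Load kernel-synth.py as a module (its hyphenated name prevents a plain import).
spec = importlib.util.spec_from_file_location("kernel_synth", "scripts/kernel-synth.py")
kernel_synth = importlib.util.module_from_spec(spec)
spec.loader.exec_module(kernel_synth)

# Generate a few synthetic series and write them in the same format as the script.
series = [kernel_synth.generate_time_series(max_kernels=3) for _ in range(10)]
out_path = Path("kernelsynth-sample.arrow")
ArrowWriter(compression="lz4").write_to_file(series, path=out_path)

# Read the file back the way train.py consumes its training data.
dataset = FileDataset(path=out_path, freq="h")
print(f"wrote {len(series)} series; first target length: {len(next(iter(dataset))['target'])}")
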
/scripts/training/configs/chronos-gpt2.yaml:
--------------------------------------------------------------------------------
1 | training_data_paths:
2 | - "/home/ubuntu/tsmixup-data.arrow"
3 | - "/home/ubuntu/kernelsynth-data.arrow"
4 | probability:
5 | - 0.9
6 | - 0.1
7 | context_length: 512
8 | prediction_length: 64
9 | min_past: 60
10 | max_steps: 200_000
11 | save_steps: 100_000
12 | log_steps: 500
13 | per_device_train_batch_size: 32
14 | learning_rate: 0.001
15 | optim: adamw_torch_fused
16 | num_samples: 20
17 | shuffle_buffer_length: 100_000
18 | gradient_accumulation_steps: 1
19 | model_id: openai-community/gpt2
20 | model_type: causal
21 | random_init: false
22 | tie_embeddings: false
23 | output_dir: ./output/
24 | tf32: true
25 | torch_compile: true
26 | tokenizer_class: "MeanScaleUniformBins"
27 | tokenizer_kwargs:
28 | low_limit: -15.0
29 | high_limit: 15.0
30 | n_tokens: 4096
31 | lr_scheduler_type: linear
32 | warmup_ratio: 0.0
33 | dataloader_num_workers: 1
34 | max_missing_prop: 0.1
35 | use_eos_token: true
36 |
--------------------------------------------------------------------------------
/scripts/training/configs/chronos-t5-base.yaml:
--------------------------------------------------------------------------------
1 | training_data_paths:
2 | - "/home/ubuntu/tsmixup-data.arrow"
3 | - "/home/ubuntu/kernelsynth-data.arrow"
4 | probability:
5 | - 0.9
6 | - 0.1
7 | context_length: 512
8 | prediction_length: 64
9 | min_past: 60
10 | max_steps: 200_000
11 | save_steps: 100_000
12 | log_steps: 500
13 | per_device_train_batch_size: 32
14 | learning_rate: 0.001
15 | optim: adamw_torch_fused
16 | num_samples: 20
17 | shuffle_buffer_length: 100_000
18 | gradient_accumulation_steps: 1
19 | model_id: google/t5-efficient-base
20 | model_type: seq2seq
21 | random_init: true
22 | tie_embeddings: true
23 | output_dir: ./output/
24 | tf32: true
25 | torch_compile: true
26 | tokenizer_class: "MeanScaleUniformBins"
27 | tokenizer_kwargs:
28 | low_limit: -15.0
29 | high_limit: 15.0
30 | n_tokens: 4096
31 | lr_scheduler_type: linear
32 | warmup_ratio: 0.0
33 | dataloader_num_workers: 1
34 | max_missing_prop: 0.9
35 | use_eos_token: true
36 |
--------------------------------------------------------------------------------
/scripts/training/configs/chronos-t5-large.yaml:
--------------------------------------------------------------------------------
1 | training_data_paths:
2 | - "/home/ubuntu/tsmixup-data.arrow"
3 | - "/home/ubuntu/kernelsynth-data.arrow"
4 | probability:
5 | - 0.9
6 | - 0.1
7 | context_length: 512
8 | prediction_length: 64
9 | min_past: 60
10 | max_steps: 200_000
11 | save_steps: 100_000
12 | log_steps: 500
13 | per_device_train_batch_size: 8
14 | learning_rate: 0.001
15 | optim: adamw_torch_fused
16 | num_samples: 20
17 | shuffle_buffer_length: 100_000
18 | gradient_accumulation_steps: 4
19 | model_id: google/t5-efficient-large
20 | model_type: seq2seq
21 | random_init: true
22 | tie_embeddings: true
23 | output_dir: ./output/
24 | tf32: true
25 | torch_compile: true
26 | tokenizer_class: "MeanScaleUniformBins"
27 | tokenizer_kwargs:
28 | low_limit: -15.0
29 | high_limit: 15.0
30 | n_tokens: 4096
31 | lr_scheduler_type: linear
32 | warmup_ratio: 0.0
33 | dataloader_num_workers: 1
34 | max_missing_prop: 0.9
35 | use_eos_token: true
36 |
--------------------------------------------------------------------------------
/scripts/training/configs/chronos-t5-mini.yaml:
--------------------------------------------------------------------------------
1 | training_data_paths:
2 | - "/home/ubuntu/tsmixup-data.arrow"
3 | - "/home/ubuntu/kernelsynth-data.arrow"
4 | probability:
5 | - 0.9
6 | - 0.1
7 | context_length: 512
8 | prediction_length: 64
9 | min_past: 60
10 | max_steps: 200_000
11 | save_steps: 100_000
12 | log_steps: 500
13 | per_device_train_batch_size: 32
14 | learning_rate: 0.001
15 | optim: adamw_torch_fused
16 | num_samples: 20
17 | shuffle_buffer_length: 100_000
18 | gradient_accumulation_steps: 1
19 | model_id: google/t5-efficient-mini
20 | model_type: seq2seq
21 | random_init: true
22 | tie_embeddings: true
23 | output_dir: ./output/
24 | tf32: true
25 | torch_compile: true
26 | tokenizer_class: "MeanScaleUniformBins"
27 | tokenizer_kwargs:
28 | low_limit: -15.0
29 | high_limit: 15.0
30 | n_tokens: 4096
31 | lr_scheduler_type: linear
32 | warmup_ratio: 0.0
33 | dataloader_num_workers: 1
34 | max_missing_prop: 0.9
35 | use_eos_token: true
36 |
--------------------------------------------------------------------------------
/scripts/training/configs/chronos-t5-small.yaml:
--------------------------------------------------------------------------------
1 | training_data_paths:
2 | - "/home/ubuntu/tsmixup-data.arrow"
3 | - "/home/ubuntu/kernelsynth-data.arrow"
4 | probability:
5 | - 0.9
6 | - 0.1
7 | context_length: 512
8 | prediction_length: 64
9 | min_past: 60
10 | max_steps: 200_000
11 | save_steps: 100_000
12 | log_steps: 500
13 | per_device_train_batch_size: 32
14 | learning_rate: 0.001
15 | optim: adamw_torch_fused
16 | num_samples: 20
17 | shuffle_buffer_length: 100_000
18 | gradient_accumulation_steps: 1
19 | model_id: google/t5-efficient-small
20 | model_type: seq2seq
21 | random_init: true
22 | tie_embeddings: true
23 | output_dir: ./output/
24 | tf32: true
25 | torch_compile: true
26 | tokenizer_class: "MeanScaleUniformBins"
27 | tokenizer_kwargs:
28 | low_limit: -15.0
29 | high_limit: 15.0
30 | n_tokens: 4096
31 | lr_scheduler_type: linear
32 | warmup_ratio: 0.0
33 | dataloader_num_workers: 1
34 | max_missing_prop: 0.9
35 | use_eos_token: true
36 |
--------------------------------------------------------------------------------
/scripts/training/configs/chronos-t5-tiny.yaml:
--------------------------------------------------------------------------------
1 | training_data_paths:
2 | - "/home/ubuntu/tsmixup-data.arrow"
3 | - "/home/ubuntu/kernelsynth-data.arrow"
4 | probability:
5 | - 0.9
6 | - 0.1
7 | context_length: 512
8 | prediction_length: 64
9 | min_past: 60
10 | max_steps: 200_000
11 | save_steps: 100_000
12 | log_steps: 500
13 | per_device_train_batch_size: 32
14 | learning_rate: 0.001
15 | optim: adamw_torch_fused
16 | num_samples: 20
17 | shuffle_buffer_length: 100_000
18 | gradient_accumulation_steps: 1
19 | model_id: google/t5-efficient-tiny
20 | model_type: seq2seq
21 | random_init: true
22 | tie_embeddings: true
23 | output_dir: ./output/
24 | tf32: true
25 | torch_compile: true
26 | tokenizer_class: "MeanScaleUniformBins"
27 | tokenizer_kwargs:
28 | low_limit: -15.0
29 | high_limit: 15.0
30 | n_tokens: 4096
31 | lr_scheduler_type: linear
32 | warmup_ratio: 0.0
33 | dataloader_num_workers: 1
34 | max_missing_prop: 0.9
35 | use_eos_token: true
36 |
--------------------------------------------------------------------------------
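
The YAML files above are consumed by scripts/training/train.py through typer-config's @use_yaml_config decorator, so every top-level key corresponds to a CLI option of main() and the whole file is typically passed as --config <path>. Below is a small sanity-check sketch; the config path is illustrative and the assertions are assumptions about sensible configs rather than checks train.py enforces (train.py renormalizes the mixing probabilities itself).

import yaml

with open("scripts/training/configs/chronos-t5-tiny.yaml") as fp:
    config = yaml.safe_load(fp)

# One mixing probability per training data path.
assert len(config["training_data_paths"]) == len(config["probability"])
# Mixing weights should roughly sum to one, although train.py renormalizes them anyway.
assert abs(sum(config["probability"]) - 1.0) < 1e-6
# model_type is asserted inside train.py; lengths should be positive.
assert config["model_type"] in ("seq2seq", "causal")
assert config["context_length"] > 0 and config["prediction_length"] > 0

print(config["model_id"], config["max_steps"], config["learning_rate"])
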
/scripts/training/train.py:
--------------------------------------------------------------------------------
1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 |
4 | import ast
5 | import logging
6 | import os
7 | import re
8 | import sys
9 | import json
10 | import itertools
11 | import random
12 | from copy import deepcopy
13 | from pathlib import Path
14 | from functools import partial
15 | from typing import List, Iterator, Optional, Dict
16 |
17 | import typer
18 | from typer_config import use_yaml_config
19 | import numpy as np
20 | import torch
21 | import torch.distributed as dist
22 | from torch.utils.data import IterableDataset, get_worker_info
23 | import transformers
24 | from transformers import (
25 | AutoModelForSeq2SeqLM,
26 | AutoModelForCausalLM,
27 | AutoConfig,
28 | T5Config,
29 | Trainer,
30 | TrainingArguments,
31 | )
32 | import accelerate
33 | import gluonts
34 | from gluonts.dataset.common import FileDataset
35 | from gluonts.itertools import Cyclic, Map, Filter
36 | from gluonts.transform import (
37 | FilterTransformation,
38 | TestSplitSampler,
39 | ValidationSplitSampler,
40 | InstanceSplitter,
41 | ExpectedNumInstanceSampler,
42 | MissingValueImputation,
43 | LeavesMissingValues,
44 | LastValueImputation,
45 | )
46 |
47 | from chronos import ChronosConfig, ChronosTokenizer
48 |
49 |
50 | app = typer.Typer(pretty_exceptions_enable=False)
51 |
52 |
53 | def is_main_process() -> bool:
54 | """
55 | Check if we're on the main process.
56 | """
57 | if not dist.is_torchelastic_launched():
58 | return True
59 | return int(os.environ["RANK"]) == 0
60 |
61 |
62 | def log_on_main(msg: str, logger: logging.Logger, log_level: int = logging.INFO):
63 | """
64 | Log the given message using the given logger, if we're on the main process.
65 | """
66 | if is_main_process():
67 | logger.log(log_level, msg)
68 |
69 |
70 | def get_training_job_info() -> Dict:
71 | """
72 | Returns info about this training job.
73 | """
74 | job_info = {}
75 |
76 | # CUDA info
77 | job_info["cuda_available"] = torch.cuda.is_available()
78 | if torch.cuda.is_available():
79 | job_info["device_count"] = torch.cuda.device_count()
80 |
81 | job_info["device_names"] = {
82 | idx: torch.cuda.get_device_name(idx)
83 | for idx in range(torch.cuda.device_count())
84 | }
85 | job_info["mem_info"] = {
86 | idx: torch.cuda.mem_get_info(device=idx)
87 | for idx in range(torch.cuda.device_count())
88 | }
89 |
90 | # DDP info
91 | job_info["torchelastic_launched"] = dist.is_torchelastic_launched()
92 |
93 | if dist.is_torchelastic_launched():
94 | job_info["world_size"] = dist.get_world_size()
95 |
96 | # Versions
97 | job_info["python_version"] = sys.version.replace("\n", " ")
98 | job_info["torch_version"] = torch.__version__
99 | job_info["numpy_version"] = np.__version__
100 | job_info["gluonts_version"] = gluonts.__version__
101 | job_info["transformers_version"] = transformers.__version__
102 | job_info["accelerate_version"] = accelerate.__version__
103 |
104 | return job_info
105 |
106 |
107 | def save_training_info(ckpt_path: Path, training_config: Dict):
108 | """
109 | Save info about this training job in a json file for documentation.
110 | """
111 | assert ckpt_path.is_dir()
112 | with open(ckpt_path / "training_info.json", "w") as fp:
113 | json.dump(
114 | {"training_config": training_config, "job_info": get_training_job_info()},
115 | fp,
116 | indent=4,
117 | )
118 |
119 |
120 | def get_next_path(
121 | base_fname: str,
122 | base_dir: Path,
123 | file_type: str = "yaml",
124 | separator: str = "-",
125 | ):
126 | """
127 | Gets the next available path in a directory. For example, if `base_fname="results"`
128 | and `base_dir` has files ["results-0.yaml", "results-1.yaml"], this function returns
129 | "results-2.yaml".
130 | """
131 | if file_type == "":
132 | # Directory
133 | items = filter(
134 | lambda x: x.is_dir() and re.match(f"^{base_fname}{separator}\\d+$", x.stem),
135 | base_dir.glob("*"),
136 | )
137 | else:
138 | # File
139 | items = filter(
140 | lambda x: re.match(f"^{base_fname}{separator}\\d+$", x.stem),
141 | base_dir.glob(f"*.{file_type}"),
142 | )
143 | run_nums = list(
144 | map(lambda x: int(x.stem.replace(base_fname + separator, "")), items)
145 | ) + [-1]
146 |
147 | next_num = max(run_nums) + 1
148 | fname = f"{base_fname}{separator}{next_num}" + (
149 | f".{file_type}" if file_type != "" else ""
150 | )
151 |
152 | return base_dir / fname
153 |
154 |
155 | def load_model(
156 | model_id="google/t5-efficient-tiny",
157 | model_type="seq2seq",
158 | vocab_size=4096,
159 | random_init=False,
160 | tie_embeddings=False,
161 | pad_token_id=0,
162 | eos_token_id=1,
163 | ):
164 | """
165 | Load the specified HuggingFace model, adjusting the vocabulary
166 | size, special token IDs, and initialization options.
167 |
168 | This allows setting up a model for training on a new vocabulary
169 | of tokens.
170 | """
171 | assert model_type in ["seq2seq", "causal"]
172 | AutoModelClass = (
173 | AutoModelForSeq2SeqLM if model_type == "seq2seq" else AutoModelForCausalLM
174 | )
175 | if random_init:
176 | log_on_main("Using random initialization", logger)
177 | config = AutoConfig.from_pretrained(model_id)
178 | if isinstance(config, T5Config):
179 | # The default initializer_factor (1.0) in transformers is too large
180 | config.initializer_factor = 0.05
181 | config.tie_word_embeddings = tie_embeddings
182 | model = AutoModelClass.from_config(config)
183 | else:
184 | log_on_main(f"Using pretrained initialization from {model_id}", logger)
185 | model = AutoModelClass.from_pretrained(model_id)
186 |
187 | model.resize_token_embeddings(vocab_size)
188 |
189 | model.config.pad_token_id = model.generation_config.pad_token_id = pad_token_id
190 | model.config.eos_token_id = model.generation_config.eos_token_id = eos_token_id
191 |
192 | return model
193 |
194 |
195 | def has_enough_observations(
196 | entry: dict, min_length: int = 0, max_missing_prop: float = 1.0
197 | ) -> bool:
198 | """
199 | Check if the given entry has enough observations in the ``"target"`` attribute.
200 |
201 | Parameters
202 | ----------
203 | entry
204 | The data entry (dictionary) to be tested.
205 | min_length
206 | The minimum length the ``"target"`` attribute must have.
207 | max_missing_prop
208 | The maximum proportion of missing data allowed in the ``"target"``
209 | attribute.
210 | """
211 | if (
212 | len(entry["target"]) >= min_length
213 | and np.isnan(entry["target"]).mean() <= max_missing_prop
214 | ):
215 | return True
216 | return False
217 |
218 |
219 | class PseudoShuffledIterableDataset(IterableDataset):
220 | """
221 | Shuffle entries from an iterable by temporarily accumulating them
222 | in an intermediate buffer.
223 |
224 | Parameters
225 | ----------
226 | base_dataset
227 | The original iterable object, representing the dataset.
228 | shuffle_buffer_length
229 | Size of the buffer used to shuffle entries from the base dataset.
230 | """
231 |
232 | def __init__(self, base_dataset, shuffle_buffer_length: int = 100) -> None:
233 | super().__init__()
234 | self.base_dataset = base_dataset
235 | self.shuffle_buffer_length = shuffle_buffer_length
236 | self.generator = torch.Generator()
237 |
238 | def __iter__(self):
239 | shuffle_buffer = []
240 |
241 | for element in self.base_dataset:
242 | shuffle_buffer.append(element)
243 | if len(shuffle_buffer) >= self.shuffle_buffer_length:
244 | idx = torch.randint(
245 | len(shuffle_buffer), size=(), generator=self.generator
246 | )
247 | yield shuffle_buffer.pop(idx)
248 |
249 | while shuffle_buffer:
250 | idx = torch.randint(len(shuffle_buffer), size=(), generator=self.generator)
251 | yield shuffle_buffer.pop(idx)
252 |
253 |
254 | class ShuffleMixin:
255 | """
256 | Mix-in class that datasets can inherit from to get
257 | shuffling functionality.
258 | """
259 |
260 | def shuffle(self, shuffle_buffer_length: int = 100):
261 | return PseudoShuffledIterableDataset(self, shuffle_buffer_length)
262 |
263 |
264 | class ChronosDataset(IterableDataset, ShuffleMixin):
265 | """
266 | Dataset wrapper, using a ``ChronosTokenizer`` to turn data from a time series
267 | into a HuggingFace-compatible set of ``input_ids``, ``attention_mask`` and
268 | ``labels``.
269 |
270 | Entries from the original datasets are assumed to have a ``"start"`` attribute
271 | (of type ``pd.Period``), and a ``"target"`` attribute (of type ``np.ndarray``).
272 |
273 | Parameters
274 | ----------
275 | datasets
276 | Datasets containing the original time series data.
277 | probabilities
278 | In training mode, data will be sampled from each of the original datasets
279 | with these probabilities.
280 | tokenizer
281 | Tokenizer to be used to turn sequences of real numbers into token IDs.
282 | context_length
283 | The context of each sample will be limited to this length.
284 | prediction_length
285 | The labels of each sample will be limited to this length.
286 | drop_prob
287 | In training mode, observations from a sample are replaced by ``np.nan``
288 | (i.e. marked as missing) with a probability drawn uniformly from [0, drop_prob].
289 | min_past
290 | Data samples will be considered only if there's at least ``min_past``-many
291 | historical observations.
292 | mode
293 | One of ``"training"``, ``"validation"``, or ``"test"``.
294 | np_dtype
295 | Numpy float data type.
296 | """
297 |
298 | def __init__(
299 | self,
300 | datasets: list,
301 | probabilities: List[float],
302 | tokenizer: ChronosTokenizer,
303 | context_length: int = 512,
304 | prediction_length: int = 64,
305 | drop_prob: float = 0.2,
306 | min_past: Optional[int] = None,
307 | model_type: str = "seq2seq",
308 | imputation_method: Optional[MissingValueImputation] = None,
309 | mode: str = "training",
310 | np_dtype=np.float32,
311 | ) -> None:
312 | super().__init__()
313 |
314 | assert len(probabilities) == len(datasets)
315 | assert mode in ("training", "validation", "test")
316 | assert model_type in ("seq2seq", "causal")
317 |
318 | self.datasets = datasets
319 | self.probabilities = probabilities
320 | self.tokenizer = tokenizer
321 | self.context_length = context_length
322 | self.prediction_length = prediction_length
323 | self.drop_prob = drop_prob if model_type == "seq2seq" else 0.0
324 | self.min_past = min_past or prediction_length
325 | self.model_type = model_type
326 | self.imputation_method = imputation_method or LeavesMissingValues()
327 | self.mode = mode
328 | self.np_dtype = np_dtype
329 |
330 | def preprocess_entry(self, entry: dict, mode: str) -> dict:
331 | entry = {f: entry[f] for f in ["start", "target"]}
332 | entry["target"] = np.asarray(entry["target"], dtype=self.np_dtype)
333 | assert entry["target"].ndim == 1, f"got {entry['target'].ndim=}, expected 1"
334 |
335 | if self.model_type == "causal":
336 | # Causal models do not play nice with missing values, so it is
337 | # recommended to use an imputation method, e.g., LastValueImputation
338 | entry["target"] = self.imputation_method(entry["target"])
339 |
340 | if mode == "training" and self.drop_prob > 0:
341 | target = entry["target"].copy()
342 | drop_p = np.random.uniform(low=0.0, high=self.drop_prob)
343 | mask = np.random.choice(
344 | [True, False], size=len(target), p=[drop_p, 1 - drop_p]
345 | )
346 | target[mask] = np.nan
347 | entry["target"] = target
348 |
349 | return entry
350 |
351 | def _create_instance_splitter(self, mode: str):
352 | assert mode in ["training", "test", "validation"]
353 |
354 | instance_sampler = {
355 | "training": ExpectedNumInstanceSampler(
356 | num_instances=1.0,
357 | min_instances=1,
358 | min_past=self.min_past,
359 | min_future=self.prediction_length,
360 | ),
361 | "test": TestSplitSampler(),
362 | "validation": ValidationSplitSampler(min_future=self.prediction_length),
363 | }[mode]
364 |
365 | return InstanceSplitter(
366 | target_field="target",
367 | is_pad_field="is_pad",
368 | start_field="start",
369 | forecast_start_field="forecast_start",
370 | instance_sampler=instance_sampler,
371 | past_length=self.context_length,
372 | future_length=self.prediction_length,
373 | dummy_value=np.nan,
374 | )
375 |
376 | def create_training_data(self, data):
377 | data = Cyclic(data)
378 | split_transform = self._create_instance_splitter(
379 | "training"
380 | ) + FilterTransformation(
381 | condition=lambda entry: (~np.isnan(entry["past_target"])).sum() > 0
382 | )
383 | data = split_transform.apply(data, is_train=True)
384 | return data
385 |
386 | def create_test_data(self, data):
387 | data = self._create_instance_splitter("test").apply(data, is_train=False)
388 | return data
389 |
390 | def create_validation_data(self, data):
391 | data = self._create_instance_splitter("validation").apply(data, is_train=False)
392 | return data
393 |
394 | def to_hf_format(self, entry: dict) -> dict:
395 | past_target = torch.tensor(entry["past_target"]).unsqueeze(0)
396 | input_ids, attention_mask, scale = self.tokenizer.context_input_transform(
397 | past_target
398 | )
399 | future_target = torch.tensor(entry["future_target"]).unsqueeze(0)
400 | labels, labels_mask = self.tokenizer.label_input_transform(future_target, scale)
401 | labels[labels_mask == 0] = -100
402 |
403 | if self.model_type == "causal":
404 | # The InstanceSplitter pads time series on the left to be equal to the
405 | # context_length. However, certain models (e.g., GPT2) with absolute
406 | # position embeddings should not be trained with left padding.
407 | # The following piece of code moves padding from left to right.
408 |
409 | assert input_ids.shape[-1] == entry["past_is_pad"].shape[0]
410 |
411 | # Find the index where padding starts
412 | pad_start_idx = np.searchsorted(1 - entry["past_is_pad"], 1)
413 | padded_input_ids, obs_input_ids = torch.tensor_split(
414 | input_ids, [pad_start_idx], dim=-1
415 | )
416 | padded_attention_mask, obs_attention_mask = torch.tensor_split(
417 | attention_mask, [pad_start_idx], dim=-1
418 | )
419 |
420 | # Move padding to the right
421 | input_ids = torch.cat(
422 | [
423 | obs_input_ids,
424 | labels,
425 | padded_input_ids,
426 | ],
427 | axis=-1,
428 | )
429 | attention_mask = torch.cat(
430 | [
431 | obs_attention_mask,
432 | labels_mask,
433 | padded_attention_mask,
434 | ],
435 | axis=-1,
436 | )
437 |
438 | # Labels for causal models are the same as the input_ids;
439 | # internally, transformers shifts the labels by one during training.
440 | labels = input_ids.clone()
441 | input_ids[~attention_mask] = self.tokenizer.config.pad_token_id
442 | labels[~attention_mask] = -100
443 |
444 | return {
445 | "input_ids": input_ids.squeeze(0),
446 | "attention_mask": attention_mask.squeeze(0),
447 | "labels": labels.squeeze(0),
448 | }
449 |
450 | def __iter__(self) -> Iterator:
451 | preprocessed_datasets = [
452 | Map(
453 | partial(self.preprocess_entry, mode=self.mode),
454 | dataset,
455 | )
456 | for dataset in self.datasets
457 | ]
458 |
459 | if self.mode == "training":
460 | iterables = [
461 | self.create_training_data(dataset) for dataset in preprocessed_datasets
462 | ]
463 | elif self.mode == "test":
464 | iterables = [
465 | self.create_test_data(dataset) for dataset in preprocessed_datasets
466 | ]
467 | else:
468 | iterables = [
469 | self.create_validation_data(dataset)
470 | for dataset in preprocessed_datasets
471 | ]
472 |
473 | worker_info = get_worker_info()
474 | if worker_info is None:
475 | probs = list(self.probabilities)
476 | else:
477 | worker_id = worker_info.id
478 | num_workers = worker_info.num_workers
479 | iterables = list(itertools.islice(iterables, worker_id, None, num_workers))
480 | probs = list(
481 | itertools.islice(self.probabilities, worker_id, None, num_workers)
482 | )
483 |
484 | probs = [prob / sum(probs) for prob in probs]
485 |
486 | iterators = list(map(iter, iterables))
487 | if self.mode == "training":
488 | while True:
489 | idx = np.random.choice(range(len(iterators)), p=probs)
490 | try:
491 | yield self.to_hf_format(next(iterators[idx]))
492 | except StopIteration:
493 | probs[idx] = 0
494 | if sum(probs) == 0:
495 | return
496 | probs = [prob / sum(probs) for prob in probs]
497 | else:
498 | for entry in itertools.chain(*iterators):
499 | yield self.to_hf_format(entry)
500 |
501 |
502 | @app.command()
503 | @use_yaml_config(param_name="config")
504 | def main(
505 | training_data_paths: str,
506 | probability: Optional[str] = None,
507 | context_length: int = 512,
508 | prediction_length: int = 64,
509 | min_past: int = 64,
510 | max_steps: int = 200_000,
511 | save_steps: int = 50_000,
512 | log_steps: int = 500,
513 | per_device_train_batch_size: int = 32,
514 | learning_rate: float = 1e-3,
515 | optim: str = "adamw_torch_fused",
516 | shuffle_buffer_length: int = 100,
517 | gradient_accumulation_steps: int = 2,
518 | model_id: str = "google/t5-efficient-tiny",
519 | model_type: str = "seq2seq",
520 | random_init: bool = False,
521 | tie_embeddings: bool = False,
522 | output_dir: str = "./output/",
523 | tf32: bool = True,
524 | torch_compile: bool = True,
525 | tokenizer_class: str = "MeanScaleUniformBins",
526 | tokenizer_kwargs: str = "{'low_limit': -15.0, 'high_limit': 15.0}",
527 | n_tokens: int = 4096,
528 | n_special_tokens: int = 2,
529 | pad_token_id: int = 0,
530 | eos_token_id: int = 1,
531 | use_eos_token: bool = True,
532 | lr_scheduler_type: str = "linear",
533 | warmup_ratio: float = 0.0,
534 | dataloader_num_workers: int = 1,
535 | max_missing_prop: float = 0.9,
536 | num_samples: int = 20,
537 | temperature: float = 1.0,
538 | top_k: int = 50,
539 | top_p: float = 1.0,
540 | seed: Optional[int] = None,
541 | ):
542 | if tf32 and not (
543 | torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8
544 | ):
545 | # TF32 floating point format is available only on NVIDIA GPUs
546 | # with compute capability 8 and above. See link for details.
547 | # https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#compute-capability-8-x
548 | log_on_main(
549 | "TF32 format is only available on devices with compute capability >= 8. "
550 | "Setting tf32 to False.",
551 | logger,
552 | )
553 | tf32 = False
554 |
555 | if seed is None:
556 | seed = random.randint(0, 2**32)
557 |
558 | log_on_main(f"Using SEED: {seed}", logger)
559 | transformers.set_seed(seed=seed)
560 |
561 | raw_training_config = deepcopy(locals())
562 | output_dir = Path(output_dir)
563 | training_data_paths = ast.literal_eval(training_data_paths)
564 | assert isinstance(training_data_paths, list)
565 |
566 | if isinstance(probability, str):
567 | probability = ast.literal_eval(probability)
568 | elif probability is None:
569 | probability = [1.0 / len(training_data_paths)] * len(training_data_paths)
570 | assert isinstance(probability, list)
571 |
572 | assert len(training_data_paths) == len(probability)
573 |
574 | if dataloader_num_workers > len(training_data_paths):
575 | log_on_main(
576 | f"Setting the number of data loader workers to {len(training_data_paths)}, "
577 | f"instead of {dataloader_num_workers}.",
578 | logger,
579 | )
580 | dataloader_num_workers = len(training_data_paths)
581 |
582 | if isinstance(tokenizer_kwargs, str):
583 | tokenizer_kwargs = ast.literal_eval(tokenizer_kwargs)
584 | assert isinstance(tokenizer_kwargs, dict)
585 |
586 | assert model_type in ["seq2seq", "causal"]
587 |
588 | output_dir = get_next_path("run", base_dir=output_dir, file_type="")
589 |
590 | log_on_main(f"Logging dir: {output_dir}", logger)
591 | log_on_main(
592 | f"Loading and filtering {len(training_data_paths)} datasets "
593 | f"for training: {training_data_paths}",
594 | logger,
595 | )
596 |
597 | log_on_main(
598 | f"Mixing probabilities: {probability}",
599 | logger,
600 | )
601 |
602 | train_datasets = [
603 | Filter(
604 | partial(
605 | has_enough_observations,
606 | min_length=min_past + prediction_length,
607 | max_missing_prop=max_missing_prop,
608 | ),
609 | FileDataset(path=Path(data_path), freq="h"),
610 | )
611 | for data_path in training_data_paths
612 | ]
613 |
614 | log_on_main("Initializing model", logger)
615 |
616 | model = load_model(
617 | model_id=model_id,
618 | model_type=model_type,
619 | vocab_size=n_tokens,
620 | random_init=random_init,
621 | tie_embeddings=tie_embeddings,
622 | pad_token_id=pad_token_id,
623 | eos_token_id=eos_token_id,
624 | )
625 |
626 | chronos_config = ChronosConfig(
627 | tokenizer_class=tokenizer_class,
628 | tokenizer_kwargs=tokenizer_kwargs,
629 | n_tokens=n_tokens,
630 | n_special_tokens=n_special_tokens,
631 | pad_token_id=pad_token_id,
632 | eos_token_id=eos_token_id,
633 | use_eos_token=use_eos_token,
634 | model_type=model_type,
635 | context_length=context_length,
636 | prediction_length=prediction_length,
637 | num_samples=num_samples,
638 | temperature=temperature,
639 | top_k=top_k,
640 | top_p=top_p,
641 | )
642 |
643 | # Add extra items to model config so that it's saved in the ckpt
644 | model.config.chronos_config = chronos_config.__dict__
645 |
646 | shuffled_train_dataset = ChronosDataset(
647 | datasets=train_datasets,
648 | probabilities=probability,
649 | tokenizer=chronos_config.create_tokenizer(),
650 | context_length=context_length,
651 | prediction_length=prediction_length,
652 | min_past=min_past,
653 | model_type=model_type,
654 | imputation_method=LastValueImputation() if model_type == "causal" else None,
655 | mode="training",
656 | ).shuffle(shuffle_buffer_length=shuffle_buffer_length)
657 |
658 | # Define training args
659 | training_args = TrainingArguments(
660 | output_dir=str(output_dir),
661 | per_device_train_batch_size=per_device_train_batch_size,
662 | learning_rate=learning_rate,
663 | lr_scheduler_type=lr_scheduler_type,
664 | warmup_ratio=warmup_ratio,
665 | optim=optim,
666 | logging_dir=str(output_dir / "logs"),
667 | logging_strategy="steps",
668 | logging_steps=log_steps,
669 | save_strategy="steps",
670 | save_steps=save_steps,
671 | report_to=["tensorboard"],
672 | max_steps=max_steps,
673 | gradient_accumulation_steps=gradient_accumulation_steps,
674 | dataloader_num_workers=dataloader_num_workers,
675 | tf32=tf32, # remove this if not using Ampere GPUs (e.g., A100)
676 | torch_compile=torch_compile,
677 | ddp_find_unused_parameters=False,
678 | remove_unused_columns=False,
679 | )
680 |
681 | # Create Trainer instance
682 | trainer = Trainer(
683 | model=model,
684 | args=training_args,
685 | train_dataset=shuffled_train_dataset,
686 | )
687 | log_on_main("Training", logger)
688 |
689 | trainer.train()
690 |
691 | if is_main_process():
692 | model.save_pretrained(output_dir / "checkpoint-final")
693 | save_training_info(
694 | output_dir / "checkpoint-final", training_config=raw_training_config
695 | )
696 |
697 |
698 | if __name__ == "__main__":
699 | logging.basicConfig(format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
700 | logger = logging.getLogger(__file__)
701 | logger.setLevel(logging.INFO)
702 | app()
703 |
--------------------------------------------------------------------------------
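
A hedged launch sketch for the training script above: invoke the Typer CLI with one of the YAML configs and override a couple of options on the command line (typer-config treats the YAML values as defaults). Paths assume the repository root as the working directory; for multi-GPU runs the same script can be launched with torchrun, which train.py detects via dist.is_torchelastic_launched().

import subprocess
import sys

subprocess.run(
    [
        sys.executable,
        "scripts/training/train.py",
        "--config", "scripts/training/configs/chronos-t5-tiny.yaml",
        "--max-steps", "1000",   # assumed: CLI values take precedence over the YAML
        "--no-torch-compile",    # Typer's --flag/--no-flag convention for booleans
    ],
    check=True,
)
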
/src/chronos/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 |
4 | from .base import BaseChronosPipeline, ForecastType
5 | from .chronos import (
6 | ChronosConfig,
7 | ChronosModel,
8 | ChronosPipeline,
9 | ChronosTokenizer,
10 | MeanScaleUniformBins,
11 | )
12 | from .chronos_bolt import ChronosBoltConfig, ChronosBoltPipeline
13 |
14 | __all__ = [
15 | "BaseChronosPipeline",
16 | "ForecastType",
17 | "ChronosConfig",
18 | "ChronosModel",
19 | "ChronosPipeline",
20 | "ChronosTokenizer",
21 | "MeanScaleUniformBins",
22 | "ChronosBoltConfig",
23 | "ChronosBoltPipeline",
24 | ]
25 |
--------------------------------------------------------------------------------
/src/chronos/base.py:
--------------------------------------------------------------------------------
1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 |
4 | # Authors: Caner Turkmen, Abdul Fatir Ansari, Lorenzo Stella
5 | # Original source:
6 | # https://github.com/autogluon/autogluon/blob/f57beb26cb769c6e0d484a6af2b89eab8aee73a8/timeseries/src/autogluon/timeseries/models/chronos/pipeline/base.py
7 |
8 | from enum import Enum
9 | from pathlib import Path
10 | from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
11 |
12 | import torch
13 |
14 | if TYPE_CHECKING:
15 | from transformers import PreTrainedModel
16 |
17 | from .utils import left_pad_and_stack_1D
18 |
19 |
20 | class ForecastType(Enum):
21 | SAMPLES = "samples"
22 | QUANTILES = "quantiles"
23 |
24 |
25 | class PipelineRegistry(type):
26 | REGISTRY: Dict[str, "PipelineRegistry"] = {}
27 |
28 | def __new__(cls, name, bases, attrs):
29 | """See, https://github.com/faif/python-patterns."""
30 | new_cls = type.__new__(cls, name, bases, attrs)
31 | if name is not None:
32 | cls.REGISTRY[name] = new_cls
33 |
34 | return new_cls
35 |
36 |
37 | class BaseChronosPipeline(metaclass=PipelineRegistry):
38 | forecast_type: ForecastType
39 | dtypes = {"bfloat16": torch.bfloat16, "float32": torch.float32}
40 |
41 | def __init__(self, inner_model: "PreTrainedModel"):
42 | """
43 | Parameters
44 | ----------
45 | inner_model : PreTrainedModel
46 | A hugging-face transformers PreTrainedModel, e.g., T5ForConditionalGeneration
47 | """
48 | # for easy access to the inner HF-style model
49 | self.inner_model = inner_model
50 |
51 | def _prepare_and_validate_context(
52 | self, context: Union[torch.Tensor, List[torch.Tensor]]
53 | ):
54 | if isinstance(context, list):
55 | context = left_pad_and_stack_1D(context)
56 | assert isinstance(context, torch.Tensor)
57 | if context.ndim == 1:
58 | context = context.unsqueeze(0)
59 | assert context.ndim == 2
60 |
61 | return context
62 |
63 | def predict(
64 | self,
65 | context: Union[torch.Tensor, List[torch.Tensor]],
66 | prediction_length: Optional[int] = None,
67 | **kwargs,
68 | ):
69 | """
70 | Get forecasts for the given time series. Predictions will be
71 | returned in fp32 on the CPU.
72 |
73 | Parameters
74 | ----------
75 | context
76 | Input series. This is either a 1D tensor, or a list
77 | of 1D tensors, or a 2D tensor whose first dimension
78 | is batch. In the latter case, use left-padding with
79 | ``torch.nan`` to align series of different lengths.
80 | prediction_length
81 | Time steps to predict. Defaults to a model-dependent
82 | value if not given.
83 |
84 | Returns
85 | -------
86 | forecasts
87 | Tensor containing forecasts. The layout and meaning
88 | of the forecasts values depends on ``self.forecast_type``.
89 | """
90 | raise NotImplementedError()
91 |
92 | def predict_quantiles(
93 | self,
94 | context: Union[torch.Tensor, List[torch.Tensor]],
95 | prediction_length: Optional[int] = None,
96 | quantile_levels: List[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
97 | **kwargs,
98 | ) -> Tuple[torch.Tensor, torch.Tensor]:
99 | """
100 | Get quantile and mean forecasts for given time series.
101 | Predictions will be returned in fp32 on the CPU.
102 |
103 | Parameters
104 | ----------
105 | context : Union[torch.Tensor, List[torch.Tensor]]
106 | Input series. This is either a 1D tensor, or a list
107 | of 1D tensors, or a 2D tensor whose first dimension
108 | is batch. In the latter case, use left-padding with
109 | ``torch.nan`` to align series of different lengths.
110 | prediction_length : Optional[int], optional
111 | Time steps to predict. Defaults to a model-dependent
112 | value if not given.
113 | quantile_levels : List[float], optional
114 | Quantile levels to compute, by default [0.1, 0.2, ..., 0.9]
115 |
116 | Returns
117 | -------
118 | quantiles
119 | Tensor containing quantile forecasts. Shape
120 | (batch_size, prediction_length, num_quantiles)
121 | mean
122 | Tensor containing mean (point) forecasts. Shape
123 | (batch_size, prediction_length)
124 | """
125 | raise NotImplementedError()
126 |
127 | @classmethod
128 | def from_pretrained(
129 | cls,
130 | pretrained_model_name_or_path: Union[str, Path],
131 | *model_args,
132 | **kwargs,
133 | ):
134 | """
135 | Load the model, either from a local path or from the HuggingFace Hub.
136 | Supports the same arguments as ``AutoConfig`` and ``AutoModel``
137 | from ``transformers``.
138 | """
139 | from transformers import AutoConfig
140 |
141 | torch_dtype = kwargs.get("torch_dtype", "auto")
142 | if torch_dtype != "auto" and isinstance(torch_dtype, str):
143 | kwargs["torch_dtype"] = cls.dtypes[torch_dtype]
144 |
145 | config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
146 | is_valid_config = hasattr(config, "chronos_pipeline_class") or hasattr(
147 | config, "chronos_config"
148 | )
149 |
150 | if not is_valid_config:
151 | raise ValueError("Not a Chronos config file")
152 |
153 | pipeline_class_name = getattr(
154 | config, "chronos_pipeline_class", "ChronosPipeline"
155 | )
156 | class_ = PipelineRegistry.REGISTRY.get(pipeline_class_name)
157 | if class_ is None:
158 | raise ValueError(
159 | f"Trying to load unknown pipeline class: {pipeline_class_name}"
160 | )
161 |
162 | return class_.from_pretrained( # type: ignore[attr-defined]
163 | pretrained_model_name_or_path, *model_args, **kwargs
164 | )
165 |
--------------------------------------------------------------------------------
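
A short usage sketch of the pipeline interface defined above: load a published checkpoint and request quantile and mean forecasts. The model id, toy context, and quantile levels are illustrative.

import torch
from chronos import BaseChronosPipeline

pipeline = BaseChronosPipeline.from_pretrained(
    "amazon/chronos-t5-tiny",   # dispatched to the registered ChronosPipeline class
    device_map="cpu",
    torch_dtype="float32",      # string dtypes are mapped via BaseChronosPipeline.dtypes
)

context = torch.sin(torch.linspace(0, 20, 100))  # toy 1D series
quantiles, mean = pipeline.predict_quantiles(
    context, prediction_length=12, quantile_levels=[0.1, 0.5, 0.9]
)
# Per the docstring above: (batch_size, prediction_length, num_quantiles) and
# (batch_size, prediction_length), i.e. (1, 12, 3) and (1, 12) here.
print(quantiles.shape, mean.shape)
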
/src/chronos/chronos.py:
--------------------------------------------------------------------------------
1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 |
4 | # Authors: Abdul Fatir Ansari, Lorenzo Stella, Caner Turkmen
5 |
6 | import logging
7 | from dataclasses import dataclass
8 | from typing import Any, Dict, List, Literal, Optional, Tuple, Union
9 |
10 | import torch
11 | import torch.nn as nn
12 | from transformers import (
13 | AutoConfig,
14 | AutoModelForCausalLM,
15 | AutoModelForSeq2SeqLM,
16 | GenerationConfig,
17 | PreTrainedModel,
18 | )
19 |
20 | import chronos
21 | from chronos.base import BaseChronosPipeline, ForecastType
22 | from chronos.utils import left_pad_and_stack_1D
23 |
24 | logger = logging.getLogger(__file__)
25 |
26 |
27 | @dataclass
28 | class ChronosConfig:
29 | """
30 | This class holds all the configuration parameters to be used
31 | by ``ChronosTokenizer`` and ``ChronosModel``.
32 | """
33 |
34 | tokenizer_class: str
35 | tokenizer_kwargs: Dict[str, Any]
36 | context_length: int
37 | prediction_length: int
38 | n_tokens: int
39 | n_special_tokens: int
40 | pad_token_id: int
41 | eos_token_id: int
42 | use_eos_token: bool
43 | model_type: Literal["causal", "seq2seq"]
44 | num_samples: int
45 | temperature: float
46 | top_k: int
47 | top_p: float
48 |
49 | def __post_init__(self):
50 | assert (
51 | self.pad_token_id < self.n_special_tokens
52 | and self.eos_token_id < self.n_special_tokens
53 | ), f"Special token id's must be smaller than {self.n_special_tokens=}"
54 |
55 | def create_tokenizer(self) -> "ChronosTokenizer":
56 | class_ = getattr(chronos, self.tokenizer_class)
57 | return class_(**self.tokenizer_kwargs, config=self)
58 |
59 |
60 | class ChronosTokenizer:
61 | """
62 | A ``ChronosTokenizer`` defines how time series are mapped into token IDs
63 | and back.
64 |
65 | For details, see the ``context_input_transform``, ``label_input_transform``,
66 | and ``output_transform`` methods, which concrete classes must implement.
67 | """
68 |
69 | def context_input_transform(
70 | self,
71 | context: torch.Tensor,
72 | ) -> Tuple:
73 | """
74 | Turn a batch of time series into token IDs, attention map, and tokenizer_state.
75 |
76 | Parameters
77 | ----------
78 | context
79 | A tensor shaped (batch_size, time_length), containing the
80 | time series to forecast. Use left-padding with ``torch.nan``
81 | to align time series of different lengths.
82 |
83 | Returns
84 | -------
85 | token_ids
86 | A tensor of integers, shaped (batch_size, time_length + 1)
87 | if ``config.use_eos_token`` and (batch_size, time_length)
88 | otherwise, containing token IDs for the input series.
89 | attention_mask
90 | A boolean tensor, same shape as ``token_ids``, indicating
91 | which input observations are not ``torch.nan`` (i.e. not
92 | missing nor padding).
93 | tokenizer_state
94 | An object that can be passed to ``label_input_transform``
95 | and ``output_transform``. Contains the relevant information
96 | to decode output samples into real values,
97 | such as location and scale parameters.
98 | """
99 | raise NotImplementedError()
100 |
101 | def label_input_transform(self, label: torch.Tensor, tokenizer_state: Any) -> Tuple:
102 | """
103 | Turn a batch of label slices of time series into token IDs and attention map
104 | using the ``tokenizer_state`` provided by ``context_input_transform``.
105 |
106 | Parameters
107 | ----------
108 | label
109 | A tensor shaped (batch_size, time_length), containing the
110 | label slice of the time series, i.e. the future values that
111 | serve as the prediction target.
112 | tokenizer_state
113 | An object returned by ``context_input_transform`` containing
114 | relevant information to preprocess data, such as location and
115 | scale. The nature of this depends on the specific tokenizer.
116 | This is used for tokenizing the label, in order to use the same
117 | scaling used to tokenize the context.
118 |
119 | Returns
120 | -------
121 | token_ids
122 | A tensor of integers, shaped (batch_size, time_length + 1)
123 | if ``config.use_eos_token`` and (batch_size, time_length)
124 | otherwise, containing token IDs for the input series.
125 | attention_mask
126 | A boolean tensor, same shape as ``token_ids``, indicating
127 | which input observations are not ``torch.nan`` (i.e. not
128 | missing nor padding).
129 | """
130 | raise NotImplementedError()
131 |
132 | def output_transform(
133 | self, samples: torch.Tensor, tokenizer_state: Any
134 | ) -> torch.Tensor:
135 | """
136 | Turn a batch of sample token IDs into real values.
137 |
138 | Parameters
139 | ----------
140 | samples
141 | A tensor of integers, shaped (batch_size, num_samples, time_length),
142 | containing token IDs of sample trajectories.
143 | tokenizer_state
144 | An object returned by ``context_input_transform`` containing
145 | relevant context to decode samples, such as location and scale.
146 | The nature of this depends on the specific tokenizer.
147 |
148 | Returns
149 | -------
150 | forecasts
151 | A real tensor, shaped (batch_size, num_samples, time_length),
152 | containing forecasted sample paths.
153 | """
154 | raise NotImplementedError()
155 |
156 |
157 | class MeanScaleUniformBins(ChronosTokenizer):
158 | def __init__(
159 | self, low_limit: float, high_limit: float, config: ChronosConfig
160 | ) -> None:
161 | self.config = config
162 | self.centers = torch.linspace(
163 | low_limit,
164 | high_limit,
165 | config.n_tokens - config.n_special_tokens - 1,
166 | )
167 | self.boundaries = torch.concat(
168 | (
169 | torch.tensor([-1e20], device=self.centers.device),
170 | (self.centers[1:] + self.centers[:-1]) / 2,
171 | torch.tensor([1e20], device=self.centers.device),
172 | )
173 | )
174 |
175 | def _input_transform(
176 | self, context: torch.Tensor, scale: Optional[torch.Tensor] = None
177 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
178 | context = context.to(dtype=torch.float32)
179 | attention_mask = ~torch.isnan(context)
180 |
181 | if scale is None:
182 | scale = torch.nansum(
183 | torch.abs(context) * attention_mask, dim=-1
184 | ) / torch.nansum(attention_mask, dim=-1)
185 | scale[~(scale > 0)] = 1.0
186 |
187 | scaled_context = context / scale.unsqueeze(dim=-1)
188 | token_ids = (
189 | torch.bucketize(
190 | input=scaled_context,
191 | boundaries=self.boundaries,
192 | # buckets are open to the right, see:
193 | # https://pytorch.org/docs/2.1/generated/torch.bucketize.html#torch-bucketize
194 | right=True,
195 | )
196 | + self.config.n_special_tokens
197 | )
198 |
199 | token_ids.clamp_(0, self.config.n_tokens - 1)
200 |
201 | token_ids[~attention_mask] = self.config.pad_token_id
202 |
203 | return token_ids, attention_mask, scale
204 |
205 | def _append_eos_token(
206 | self, token_ids: torch.Tensor, attention_mask: torch.Tensor
207 | ) -> Tuple[torch.Tensor, torch.Tensor]:
208 | batch_size = token_ids.shape[0]
209 | eos_tokens = torch.full((batch_size, 1), fill_value=self.config.eos_token_id)
210 | token_ids = torch.concat((token_ids, eos_tokens), dim=1)
211 | eos_mask = torch.full((batch_size, 1), fill_value=True)
212 | attention_mask = torch.concat((attention_mask, eos_mask), dim=1)
213 |
214 | return token_ids, attention_mask
215 |
216 | def context_input_transform(
217 | self, context: torch.Tensor
218 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
219 | length = context.shape[-1]
220 |
221 | if length > self.config.context_length:
222 | context = context[..., -self.config.context_length :]
223 |
224 | token_ids, attention_mask, scale = self._input_transform(context=context)
225 |
226 | if self.config.use_eos_token and self.config.model_type == "seq2seq":
227 | token_ids, attention_mask = self._append_eos_token(
228 | token_ids=token_ids, attention_mask=attention_mask
229 | )
230 |
231 | return token_ids, attention_mask, scale
232 |
233 | def label_input_transform(
234 | self, label: torch.Tensor, scale: torch.Tensor
235 | ) -> Tuple[torch.Tensor, torch.Tensor]:
236 | length = label.shape[-1]
237 |
238 | assert length == self.config.prediction_length
239 | token_ids, attention_mask, _ = self._input_transform(context=label, scale=scale)
240 |
241 | if self.config.use_eos_token:
242 | token_ids, attention_mask = self._append_eos_token(
243 | token_ids=token_ids, attention_mask=attention_mask
244 | )
245 |
246 | return token_ids, attention_mask
247 |
248 | def output_transform(
249 | self, samples: torch.Tensor, scale: torch.Tensor
250 | ) -> torch.Tensor:
251 | scale_unsqueezed = scale.unsqueeze(-1).unsqueeze(-1)
252 | indices = torch.clamp(
253 | samples - self.config.n_special_tokens - 1,
254 | min=0,
255 | max=len(self.centers) - 1,
256 | )
257 | return self.centers[indices] * scale_unsqueezed
258 |
259 |
260 | class ChronosModel(nn.Module):
261 | """
262 | A ``ChronosModel`` wraps a ``PreTrainedModel`` object from ``transformers``
263 | and uses it to predict sample paths for time series tokens.
264 |
265 | Parameters
266 | ----------
267 | config
268 | The configuration to use.
269 | model
270 | The pretrained model to use.
271 | """
272 |
273 | def __init__(self, config: ChronosConfig, model: PreTrainedModel) -> None:
274 | super().__init__()
275 | self.config = config
276 | self.model = model
277 |
278 | @property
279 | def device(self):
280 | return self.model.device
281 |
282 | def encode(
283 | self,
284 | input_ids: torch.Tensor,
285 | attention_mask: torch.Tensor,
286 | ):
287 | """
288 | Extract the encoder embedding for the given token sequences.
289 |
290 | Parameters
291 | ----------
292 | input_ids
293 | Tensor of indices of input sequence tokens in the vocabulary
294 | with shape (batch_size, sequence_length).
295 | attention_mask
296 | A mask tensor of the same shape as input_ids to avoid attending
297 | on padding or missing tokens.
298 |
299 | Returns
300 | -------
301 | embedding
302 | A tensor of encoder embeddings with shape
303 | (batch_size, sequence_length, d_model).
304 | """
305 | assert (
306 | self.config.model_type == "seq2seq"
307 | ), "Encoder embeddings are only supported for encoder-decoder models"
308 | assert hasattr(self.model, "encoder")
309 |
310 | return self.model.encoder(
311 | input_ids=input_ids, attention_mask=attention_mask
312 | ).last_hidden_state
313 |
314 | def forward(
315 | self,
316 | input_ids: torch.Tensor,
317 | attention_mask: torch.Tensor,
318 | prediction_length: Optional[int] = None,
319 | num_samples: Optional[int] = None,
320 | temperature: Optional[float] = None,
321 | top_k: Optional[int] = None,
322 | top_p: Optional[float] = None,
323 | ) -> torch.Tensor:
324 | """
325 | Predict future sample tokens for the given token sequences.
326 |
327 | Arguments ``prediction_length``, ``num_samples``, ``temperature``,
328 | ``top_k``, ``top_p`` can be used to customize the model inference,
329 | and default to the corresponding attributes in ``self.config`` if
330 | not provided.
331 |
332 | Returns
333 | -------
334 | samples
335 | A tensor of integers, shaped (batch_size, num_samples, time_length),
336 | containing forecasted sample paths.
337 | """
338 | if prediction_length is None:
339 | prediction_length = self.config.prediction_length
340 | if num_samples is None:
341 | num_samples = self.config.num_samples
342 | if temperature is None:
343 | temperature = self.config.temperature
344 | if top_k is None:
345 | top_k = self.config.top_k
346 | if top_p is None:
347 | top_p = self.config.top_p
348 |
349 | assert hasattr(self.model, "generate")
350 |
351 | preds = self.model.generate(
352 | input_ids=input_ids,
353 | attention_mask=attention_mask,
354 | generation_config=GenerationConfig(
355 | min_new_tokens=prediction_length,
356 | max_new_tokens=prediction_length,
357 | do_sample=True,
358 | num_return_sequences=num_samples,
359 | eos_token_id=self.config.eos_token_id,
360 | pad_token_id=self.config.pad_token_id,
361 | temperature=temperature,
362 | top_k=top_k,
363 | top_p=top_p,
364 | ),
365 | )
366 |
367 | if self.config.model_type == "seq2seq":
368 | preds = preds[..., 1:] # remove the decoder start token
369 | else:
370 | assert self.config.model_type == "causal"
371 | assert preds.size(-1) == input_ids.size(-1) + prediction_length
372 | preds = preds[..., -prediction_length:]
373 |
374 | return preds.reshape(input_ids.size(0), num_samples, -1)
375 |
376 |
377 | class ChronosPipeline(BaseChronosPipeline):
378 | """
379 | A ``ChronosPipeline`` uses the given tokenizer and model to forecast
380 | input time series.
381 |
382 | Use the ``from_pretrained`` class method to load serialized models.
383 | Use the ``predict`` method to get forecasts.
384 |
385 | Parameters
386 | ----------
387 | tokenizer
388 | The tokenizer object to use.
389 | model
390 | The model to use.
391 | """
392 |
393 | tokenizer: ChronosTokenizer
394 | model: ChronosModel
395 | forecast_type: ForecastType = ForecastType.SAMPLES
396 |
397 | def __init__(self, tokenizer, model):
398 | super().__init__(inner_model=model.model)
399 | self.tokenizer = tokenizer
400 | self.model = model
401 |
402 | def _prepare_and_validate_context(
403 | self, context: Union[torch.Tensor, List[torch.Tensor]]
404 | ):
405 | if isinstance(context, list):
406 | context = left_pad_and_stack_1D(context)
407 | assert isinstance(context, torch.Tensor)
408 | if context.ndim == 1:
409 | context = context.unsqueeze(0)
410 | assert context.ndim == 2
411 |
412 | return context
413 |
414 | @torch.no_grad()
415 | def embed(
416 | self, context: Union[torch.Tensor, List[torch.Tensor]]
417 | ) -> Tuple[torch.Tensor, Any]:
418 | """
419 | Get encoder embeddings for the given time series.
420 |
421 | Parameters
422 | ----------
423 | context
424 | Input series. This is either a 1D tensor, or a list
425 | of 1D tensors, or a 2D tensor whose first dimension
426 | is batch. In the latter case, use left-padding with
427 | ``torch.nan`` to align series of different lengths.
428 |
429 | Returns
430 | -------
431 | embeddings, tokenizer_state
432 | A tuple of two tensors: the encoder embeddings and the tokenizer_state,
433 | e.g., the scale of the time series in the case of mean scaling.
434 | The encoder embeddings are shaped (batch_size, context_length, d_model),
435 | or (batch_size, context_length + 1, d_model) when an EOS token is
436 | appended. Here, context_length is the size of the context along the
437 | time axis if a 2D tensor was provided, or the length of the longest
438 | time series if a list of 1D tensors was provided.
439 | """
440 | context_tensor = self._prepare_and_validate_context(context=context)
441 | token_ids, attention_mask, tokenizer_state = (
442 | self.tokenizer.context_input_transform(context_tensor)
443 | )
444 | embeddings = self.model.encode(
445 | input_ids=token_ids.to(self.model.device),
446 | attention_mask=attention_mask.to(self.model.device),
447 | ).cpu()
448 | return embeddings, tokenizer_state
449 |
450 | def predict( # type: ignore[override]
451 | self,
452 | context: Union[torch.Tensor, List[torch.Tensor]],
453 | prediction_length: Optional[int] = None,
454 | num_samples: Optional[int] = None,
455 | temperature: Optional[float] = None,
456 | top_k: Optional[int] = None,
457 | top_p: Optional[float] = None,
458 | limit_prediction_length: bool = False,
459 | ) -> torch.Tensor:
460 | """
461 | Get forecasts for the given time series.
462 |
463 | Refer to the base method (``BaseChronosPipeline.predict``)
464 | for details on shared parameters.
465 |
466 | Additional parameters
467 | ---------------------
468 | num_samples
469 | Number of sample paths to predict. Defaults to the value
470 | specified in ``self.model.config``.
471 | temperature
472 | Temperature to use for generating sample tokens.
473 | Defaults to the value specified in ``self.model.config``.
474 | top_k
475 | Top-k parameter to use for generating sample tokens.
476 | Defaults to the value specified in ``self.model.config``.
477 | top_p
478 | Top-p parameter to use for generating sample tokens.
479 | Defaults to the value specified in ``self.model.config``.
480 | limit_prediction_length
481 | Require the prediction length to be at most the model's
482 | built-in prediction length. False by default. When true,
483 | an error is raised if a longer prediction is requested;
484 | otherwise, longer predictions are allowed and a warning is logged.
485 |
486 | Returns
487 | -------
488 | samples
489 | Tensor of sample forecasts, of shape
490 | (batch_size, num_samples, prediction_length).
491 | """
492 | context_tensor = self._prepare_and_validate_context(context=context)
493 |
494 | if prediction_length is None:
495 | prediction_length = self.model.config.prediction_length
496 |
497 | if prediction_length > self.model.config.prediction_length:
498 | msg = (
499 | f"We recommend keeping prediction length <= {self.model.config.prediction_length}. "
500 | "The quality of longer predictions may degrade since the model is not optimized for it. "
501 | )
502 | if limit_prediction_length:
503 | msg += "You can turn off this check by setting `limit_prediction_length=False`."
504 | raise ValueError(msg)
505 | logger.warning(msg)
506 |
507 | predictions = []
508 | remaining = prediction_length
509 |
510 | while remaining > 0:
511 | token_ids, attention_mask, scale = self.tokenizer.context_input_transform(
512 | context_tensor
513 | )
514 | samples = self.model(
515 | token_ids.to(self.model.device),
516 | attention_mask.to(self.model.device),
517 | min(remaining, self.model.config.prediction_length),
518 | num_samples,
519 | temperature,
520 | top_k,
521 | top_p,
522 | )
523 | prediction = self.tokenizer.output_transform(
524 | samples.to(scale.device), scale
525 | )
526 |
527 | predictions.append(prediction)
528 | remaining -= prediction.shape[-1]
529 |
530 | if remaining <= 0:
531 | break
532 |
533 | context_tensor = torch.cat(
534 | [context_tensor, prediction.median(dim=1).values], dim=-1
535 | )
536 |
537 | return torch.cat(predictions, dim=-1).to(dtype=torch.float32, device="cpu")
538 |
539 | def predict_quantiles(
540 | self,
541 | context: Union[torch.Tensor, List[torch.Tensor]],
542 | prediction_length: Optional[int] = None,
543 | quantile_levels: List[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
544 | **predict_kwargs,
545 | ) -> Tuple[torch.Tensor, torch.Tensor]:
546 | """
547 | Refer to the base method (``BaseChronosPipeline.predict_quantiles``).
548 | """
549 | prediction_samples = (
550 | self.predict(context, prediction_length=prediction_length, **predict_kwargs)
551 | .detach()
552 | .swapaxes(1, 2)
553 | )
554 | mean = prediction_samples.mean(dim=-1)
555 | quantiles = torch.quantile(
556 | prediction_samples,
557 | q=torch.tensor(quantile_levels, dtype=prediction_samples.dtype),
558 | dim=-1,
559 | ).permute(1, 2, 0)
560 |
561 | return quantiles, mean
562 |
563 | @classmethod
564 | def from_pretrained(cls, *args, **kwargs):
565 | """
566 | Load the model, either from a local path or from the HuggingFace Hub.
567 | Supports the same arguments as ``AutoConfig`` and ``AutoModel``
568 | from ``transformers``.
569 | """
570 |
571 | config = AutoConfig.from_pretrained(*args, **kwargs)
572 |
573 | assert hasattr(config, "chronos_config"), "Not a Chronos config file"
574 |
575 | chronos_config = ChronosConfig(**config.chronos_config)
576 |
577 | if chronos_config.model_type == "seq2seq":
578 | inner_model = AutoModelForSeq2SeqLM.from_pretrained(*args, **kwargs)
579 | else:
580 | assert chronos_config.model_type == "causal"
581 | inner_model = AutoModelForCausalLM.from_pretrained(*args, **kwargs)
582 |
583 | return cls(
584 | tokenizer=chronos_config.create_tokenizer(),
585 | model=ChronosModel(config=chronos_config, model=inner_model),
586 | )
587 |
--------------------------------------------------------------------------------
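A minimal usage sketch for ``ChronosPipeline`` (illustrative only; it assumes the
``amazon/chronos-t5-tiny`` checkpoint, also used in the tests below, can be
downloaded from the HuggingFace Hub):

import torch
from chronos import ChronosPipeline

# from_pretrained reads the `chronos_config` block of the checkpoint and
# builds both the tokenizer and the inner transformers model.
pipeline = ChronosPipeline.from_pretrained(
    "amazon/chronos-t5-tiny", device_map="cpu", torch_dtype=torch.bfloat16
)

context = torch.rand(2, 64)  # batch of 2 series with 64 observations each

# Sample forecasts: (batch_size, num_samples, prediction_length)
samples = pipeline.predict(context, prediction_length=12, num_samples=20)

# Quantile and mean forecasts:
# quantiles -> (batch_size, prediction_length, len(quantile_levels))
# mean      -> (batch_size, prediction_length)
quantiles, mean = pipeline.predict_quantiles(
    context, prediction_length=12, quantile_levels=[0.1, 0.5, 0.9]
)

# Encoder embeddings plus the tokenizer state (the mean-scaling factor)
embeddings, scale = pipeline.embed(context)

Requests longer than the model's built-in ``prediction_length`` are served
autoregressively: each chunk's per-step median is appended to the context before
the next chunk is generated, and a warning about possible quality degradation is
logged.
--------------------------------------------------------------------------------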
/src/chronos/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 |
4 |
5 | from typing import List
6 |
7 | import torch
8 |
9 |
10 | def left_pad_and_stack_1D(tensors: List[torch.Tensor]) -> torch.Tensor:
11 | max_len = max(len(c) for c in tensors)
12 | padded = []
13 | for c in tensors:
14 | assert isinstance(c, torch.Tensor)
15 | assert c.ndim == 1
16 | padding = torch.full(
17 | size=(max_len - len(c),), fill_value=torch.nan, device=c.device
18 | )
19 | padded.append(torch.concat((padding, c), dim=-1))
20 | return torch.stack(padded)
21 |
--------------------------------------------------------------------------------
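A small sketch of what ``left_pad_and_stack_1D`` produces for series of unequal
length (values are illustrative only):

import torch
from chronos.utils import left_pad_and_stack_1D

batch = left_pad_and_stack_1D(
    [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0, 5.0])]
)
# Shorter series are left-padded with NaN up to the longest length:
# tensor([[nan, 1., 2.],
#         [3., 4., 5.]])
--------------------------------------------------------------------------------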
/test/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 |
--------------------------------------------------------------------------------
/test/dummy-chronos-bolt-model/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "architectures": [
3 | "ChronosBoltModelForForecasting"
4 | ],
5 | "chronos_config": {
6 | "context_length": 512,
7 | "input_patch_size": 16,
8 | "input_patch_stride": 16,
9 | "prediction_length": 64,
10 | "quantiles": [
11 | 0.1,
12 | 0.2,
13 | 0.3,
14 | 0.4,
15 | 0.5,
16 | 0.6,
17 | 0.7,
18 | 0.8,
19 | 0.9
20 | ],
21 | "use_reg_token": true
22 | },
23 | "chronos_pipeline_class": "ChronosBoltPipeline",
24 | "classifier_dropout": 0.0,
25 | "d_ff": 8,
26 | "d_kv": 4,
27 | "d_model": 8,
28 | "decoder_start_token_id": 0,
29 | "dense_act_fn": "relu",
30 | "dropout_rate": 0.1,
31 | "eos_token_id": 1,
32 | "feed_forward_proj": "relu",
33 | "initializer_factor": 0.05,
34 | "is_encoder_decoder": true,
35 | "is_gated_act": false,
36 | "layer_norm_epsilon": 1e-06,
37 | "model_type": "t5",
38 | "n_positions": 512,
39 | "num_decoder_layers": 4,
40 | "num_heads": 4,
41 | "num_layers": 4,
42 | "pad_token_id": 0,
43 | "reg_token_id": 1,
44 | "relative_attention_max_distance": 128,
45 | "relative_attention_num_buckets": 32,
46 | "torch_dtype": "float32",
47 | "transformers_version": "4.40.2",
48 | "use_cache": true,
49 | "vocab_size": 2
50 | }
51 |
--------------------------------------------------------------------------------
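The ``chronos_config`` block above is carried through unchanged to the loaded
model; a small sketch of reading it back at runtime (assuming the repository
root is the working directory so the dummy checkpoint path resolves):

from chronos import ChronosBoltPipeline

pipeline = ChronosBoltPipeline.from_pretrained(
    "test/dummy-chronos-bolt-model", device_map="cpu"
)
chronos_config = pipeline.model.config.chronos_config
chronos_config["context_length"]    # 512
chronos_config["input_patch_size"]  # 16
chronos_config["use_reg_token"]     # True
len(pipeline.quantiles)             # 9
--------------------------------------------------------------------------------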
/test/dummy-chronos-bolt-model/model.safetensors:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/chronos-forecasting/6a9c8dadac04eb85befc935043e3e2cce914267f/test/dummy-chronos-bolt-model/model.safetensors
--------------------------------------------------------------------------------
/test/dummy-chronos-model/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "architectures": [
3 | "T5ForConditionalGeneration"
4 | ],
5 | "d_ff": 32,
6 | "d_kv": 16,
7 | "d_model": 64,
8 | "decoder_start_token_id": 0,
9 | "dense_act_fn": "relu",
10 | "dropout_rate": 0.1,
11 | "eos_token_id": 1,
12 | "feed_forward_proj": "relu",
13 | "initializer_factor": 0.05,
14 | "is_encoder_decoder": true,
15 | "is_gated_act": false,
16 | "layer_norm_epsilon": 1e-06,
17 | "model_type": "t5",
18 | "n_positions": 512,
19 | "num_decoder_layers": 1,
20 | "num_heads": 1,
21 | "num_layers": 1,
22 | "pad_token_id": 0,
23 | "relative_attention_max_distance": 128,
24 | "relative_attention_num_buckets": 32,
25 | "torch_dtype": "bfloat16",
26 | "transformers_version": "4.31.0",
27 | "use_cache": true,
28 | "vocab_size": 32,
29 | "chronos_config": {
30 | "tokenizer_class": "MeanScaleUniformBins",
31 | "tokenizer_kwargs": {
32 | "low_limit": -15.0,
33 | "high_limit": 15.0
34 | },
35 | "n_tokens": 32,
36 | "n_special_tokens": 2,
37 | "pad_token_id": 0,
38 | "eos_token_id": 1,
39 | "use_eos_token": true,
40 | "model_type": "seq2seq",
41 | "context_length": 512,
42 | "prediction_length": 64,
43 | "num_samples": 20,
44 | "temperature": 1.0,
45 | "top_k": 50,
46 | "top_p": 1.0
47 | }
48 | }
49 |
--------------------------------------------------------------------------------
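The ``chronos_config`` block above is what ``ChronosPipeline.from_pretrained``
unpacks into a ``ChronosConfig``; a sketch of building the same tokenizer by
hand, with the keyword values copied from this dummy config:

from chronos import ChronosConfig

chronos_config = ChronosConfig(
    tokenizer_class="MeanScaleUniformBins",
    tokenizer_kwargs={"low_limit": -15.0, "high_limit": 15.0},
    n_tokens=32,
    n_special_tokens=2,
    pad_token_id=0,
    eos_token_id=1,
    use_eos_token=True,
    model_type="seq2seq",
    context_length=512,
    prediction_length=64,
    num_samples=20,
    temperature=1.0,
    top_k=50,
    top_p=1.0,
)
# create_tokenizer instantiates the class named in `tokenizer_class`,
# here a MeanScaleUniformBins tokenizer.
tokenizer = chronos_config.create_tokenizer()
--------------------------------------------------------------------------------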
/test/dummy-chronos-model/generation_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "_from_model_config": true,
3 | "decoder_start_token_id": 0,
4 | "eos_token_id": 1,
5 | "pad_token_id": 0,
6 | "transformers_version": "4.31.0"
7 | }
8 |
--------------------------------------------------------------------------------
/test/dummy-chronos-model/pytorch_model.bin:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/chronos-forecasting/6a9c8dadac04eb85befc935043e3e2cce914267f/test/dummy-chronos-model/pytorch_model.bin
--------------------------------------------------------------------------------
/test/test_chronos.py:
--------------------------------------------------------------------------------
1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 |
4 | from pathlib import Path
5 |
6 | import pytest
7 | import torch
8 |
9 | from chronos import (
10 | BaseChronosPipeline,
11 | ChronosConfig,
12 | ChronosPipeline,
13 | MeanScaleUniformBins,
14 | )
15 | from test.util import validate_tensor
16 |
17 |
18 | def test_base_chronos_pipeline_loads_from_huggingface():
19 | BaseChronosPipeline.from_pretrained("amazon/chronos-t5-tiny", device_map="cpu")
20 |
21 |
22 | @pytest.mark.parametrize("n_numerical_tokens", [5, 10, 27])
23 | @pytest.mark.parametrize("n_special_tokens", [2, 5, 13])
24 | def test_tokenizer_consistency(n_numerical_tokens: int, n_special_tokens: int):
25 | n_tokens = n_numerical_tokens + n_special_tokens
26 |
27 | config = ChronosConfig(
28 | tokenizer_class="MeanScaleUniformBins",
29 | tokenizer_kwargs=dict(low_limit=-1.0, high_limit=1.0),
30 | n_tokens=n_tokens,
31 | n_special_tokens=n_special_tokens,
32 | pad_token_id=0,
33 | eos_token_id=1,
34 | use_eos_token=True,
35 | model_type="seq2seq",
36 | context_length=512,
37 | prediction_length=64,
38 | num_samples=20,
39 | temperature=1.0,
40 | top_k=50,
41 | top_p=1.0,
42 | )
43 |
44 | tokenizer = config.create_tokenizer()
45 | assert isinstance(tokenizer, MeanScaleUniformBins)
46 |
47 | context = tokenizer.centers.unsqueeze(0) # add batch dimension
48 | scale = torch.ones((1,)) # fix the scale to one to turn off scaling
49 |
50 | token_ids, _, _ = tokenizer._input_transform(context, scale=scale)
51 |
52 | samples = tokenizer.output_transform(
53 | token_ids.unsqueeze(1), # add sample dimension
54 | scale=scale,
55 | )
56 |
57 | assert (samples[0, 0, :] == context).all()
58 |
59 |
60 | @pytest.mark.xfail
61 | @pytest.mark.parametrize("n_numerical_tokens", [5, 10, 27])
62 | @pytest.mark.parametrize("n_special_tokens", [2, 5, 13])
63 | @pytest.mark.parametrize("use_eos_token", [False, True])
64 | def test_tokenizer_fixed_data(
65 | n_numerical_tokens: int, n_special_tokens: int, use_eos_token: bool
66 | ):
67 | n_tokens = n_numerical_tokens + n_special_tokens
68 | context_length = 3
69 |
70 | config = ChronosConfig(
71 | tokenizer_class="MeanScaleUniformBins",
72 | tokenizer_kwargs=dict(low_limit=-1.0, high_limit=1.0),
73 | n_tokens=n_tokens,
74 | n_special_tokens=n_special_tokens,
75 | pad_token_id=0,
76 | eos_token_id=1,
77 | use_eos_token=use_eos_token,
78 | model_type="seq2seq",
79 | context_length=512,
80 | prediction_length=64,
81 | num_samples=20,
82 | temperature=1.0,
83 | top_k=50,
84 | top_p=1.0,
85 | )
86 |
87 | tokenizer = config.create_tokenizer()
88 |
89 | context = torch.tensor(
90 | [
91 | [-3.7, 3.7],
92 | [-42.0, 42.0],
93 | ]
94 | )
95 | batch_size, _ = context.shape
96 |
97 | token_ids, attention_mask, scale = tokenizer.context_input_transform(context)
98 |
99 | assert token_ids.shape == (batch_size, context_length + 1 * use_eos_token)
100 | assert all(token_ids[:, 0] == torch.tensor([0]).repeat(batch_size))
101 | assert all(token_ids[:, 1] == torch.tensor([n_special_tokens]).repeat(batch_size))
102 | assert all(token_ids[:, 2] == torch.tensor([n_tokens - 1]).repeat(batch_size))
103 |
104 | if use_eos_token:
105 | assert all(token_ids[:, 3] == torch.tensor([1]).repeat(batch_size))
106 |
107 | samples = tokenizer.output_transform(
108 | torch.arange(n_special_tokens, n_tokens).unsqueeze(0).repeat(batch_size, 1, 1),
109 | tokenizer_state=scale,
110 | )
111 |
112 | assert (samples[:, 0, [0, -1]] == context).all()
113 |
114 |
115 | @pytest.mark.xfail
116 | @pytest.mark.parametrize("use_eos_token", [False, True])
117 | def test_tokenizer_random_data(use_eos_token: bool):
118 | context_length = 8
119 | n_tokens = 256
120 | n_special_tokens = 2
121 |
122 | config = ChronosConfig(
123 | tokenizer_class="MeanScaleUniformBins",
124 | tokenizer_kwargs=dict(low_limit=-1.0, high_limit=1.0),
125 | n_tokens=n_tokens,
126 | n_special_tokens=n_special_tokens,
127 | pad_token_id=0,
128 | eos_token_id=1,
129 | use_eos_token=use_eos_token,
130 | model_type="seq2seq",
131 | context_length=context_length,
132 | prediction_length=64,
133 | num_samples=20,
134 | temperature=1.0,
135 | top_k=50,
136 | top_p=1.0,
137 | )
138 |
139 | tokenizer = config.create_tokenizer()
140 |
141 | context = torch.tensor(
142 | [
143 | [torch.nan, torch.nan, 1.0, 1.1, torch.nan, 2.0],
144 | [3.0, torch.nan, 3.9, 4.0, 4.1, 4.9],
145 | ]
146 | )
147 |
148 | token_ids, attention_mask, scale = tokenizer.context_input_transform(context)
149 |
150 | assert token_ids.shape == (
151 | *context.shape[:-1],
152 | context_length + 1 * use_eos_token,
153 | )
154 | assert attention_mask.shape == (
155 | *context.shape[:-1],
156 | context_length + 1 * use_eos_token,
157 | )
158 | assert scale.shape == context.shape[:1]
159 |
160 | sample_ids = torch.randint(low=n_special_tokens, high=n_tokens, size=(2, 10, 4))
161 | sample_ids[0, 0, 0] = n_special_tokens
162 | sample_ids[-1, -1, -1] = n_tokens - 1
163 |
164 | samples = tokenizer.output_transform(sample_ids, scale)
165 |
166 | assert samples.shape == (2, 10, 4)
167 |
168 |
169 | @pytest.mark.parametrize("model_dtype", [torch.float32, torch.bfloat16])
170 | @pytest.mark.parametrize("input_dtype", [torch.float32, torch.bfloat16, torch.int64])
171 | def test_pipeline_predict(model_dtype: torch.dtype, input_dtype: torch.dtype):
172 | pipeline = ChronosPipeline.from_pretrained(
173 | Path(__file__).parent / "dummy-chronos-model",
174 | device_map="cpu",
175 | torch_dtype=model_dtype,
176 | )
177 | context = 10 * torch.rand(size=(4, 16)) + 10
178 | context = context.to(dtype=input_dtype)
179 |
180 | # input: tensor of shape (batch_size, context_length)
181 |
182 | samples = pipeline.predict(context, num_samples=12, prediction_length=3)
183 | validate_tensor(samples, shape=(4, 12, 3), dtype=torch.float32)
184 |
185 | with pytest.raises(ValueError):
186 | samples = pipeline.predict(
187 | context, num_samples=7, prediction_length=65, limit_prediction_length=True
188 | )
189 |
190 | samples = pipeline.predict(
191 | context, num_samples=7, prediction_length=65, limit_prediction_length=False
192 | )
193 | validate_tensor(samples, shape=(4, 7, 65), dtype=torch.float32)
194 |
195 | # input: batch_size-long list of tensors of shape (context_length,)
196 |
197 | samples = pipeline.predict(list(context), num_samples=12, prediction_length=3)
198 | validate_tensor(samples, shape=(4, 12, 3), dtype=torch.float32)
199 |
200 | with pytest.raises(ValueError):
201 | samples = pipeline.predict(
202 | list(context),
203 | num_samples=7,
204 | prediction_length=65,
205 | limit_prediction_length=True,
206 | )
207 |
208 | samples = pipeline.predict(
209 | list(context),
210 | num_samples=7,
211 | prediction_length=65,
212 | limit_prediction_length=False,
213 | )
214 | validate_tensor(samples, shape=(4, 7, 65), dtype=torch.float32)
215 |
216 | # input: tensor of shape (context_length,)
217 |
218 | samples = pipeline.predict(context[0, ...], num_samples=12, prediction_length=3)
219 | validate_tensor(samples, shape=(1, 12, 3), dtype=torch.float32)
220 |
221 | with pytest.raises(ValueError):
222 | samples = pipeline.predict(
223 | context[0, ...],
224 | num_samples=7,
225 | prediction_length=65,
226 | limit_prediction_length=True,
227 | )
228 |
229 | samples = pipeline.predict(
230 | context[0, ...],
231 | num_samples=7,
232 | prediction_length=65,
233 | )
234 | validate_tensor(samples, shape=(1, 7, 65), dtype=torch.float32)
235 |
236 |
237 | @pytest.mark.parametrize("model_dtype", [torch.float32, torch.bfloat16])
238 | @pytest.mark.parametrize("input_dtype", [torch.float32, torch.bfloat16, torch.int64])
239 | @pytest.mark.parametrize("prediction_length", [3, 65])
240 | @pytest.mark.parametrize(
241 | "quantile_levels", [[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], [0.1, 0.5, 0.9]]
242 | )
243 | def test_pipeline_predict_quantiles(
244 | model_dtype: torch.dtype,
245 | input_dtype: torch.dtype,
246 | prediction_length: int,
247 | quantile_levels: list[float],
248 | ):
249 | pipeline = ChronosPipeline.from_pretrained(
250 | Path(__file__).parent / "dummy-chronos-model",
251 | device_map="cpu",
252 | torch_dtype=model_dtype,
253 | )
254 | context = 10 * torch.rand(size=(4, 16)) + 10
255 | context = context.to(dtype=input_dtype)
256 |
257 | num_expected_quantiles = len(quantile_levels)
258 | # input: tensor of shape (batch_size, context_length)
259 |
260 | quantiles, mean = pipeline.predict_quantiles(
261 | context,
262 | num_samples=12,
263 | prediction_length=prediction_length,
264 | quantile_levels=quantile_levels,
265 | )
266 | validate_tensor(
267 | quantiles, (4, prediction_length, num_expected_quantiles), dtype=torch.float32
268 | )
269 | validate_tensor(mean, (4, prediction_length), dtype=torch.float32)
270 |
271 | # input: batch_size-long list of tensors of shape (context_length,)
272 |
273 | quantiles, mean = pipeline.predict_quantiles(
274 | list(context),
275 | num_samples=12,
276 | prediction_length=prediction_length,
277 | quantile_levels=quantile_levels,
278 | )
279 | validate_tensor(
280 | quantiles, (4, prediction_length, num_expected_quantiles), dtype=torch.float32
281 | )
282 | validate_tensor(mean, (4, prediction_length), dtype=torch.float32)
283 |
284 | # input: tensor of shape (context_length,)
285 |
286 | quantiles, mean = pipeline.predict_quantiles(
287 | context[0, ...],
288 | num_samples=12,
289 | prediction_length=prediction_length,
290 | quantile_levels=quantile_levels,
291 | )
292 | validate_tensor(
293 | quantiles, (1, prediction_length, num_expected_quantiles), dtype=torch.float32
294 | )
295 | validate_tensor(mean, (1, prediction_length), dtype=torch.float32)
296 |
297 |
298 | @pytest.mark.parametrize("model_dtype", [torch.float32, torch.bfloat16])
299 | @pytest.mark.parametrize("input_dtype", [torch.float32, torch.bfloat16, torch.int64])
300 | def test_pipeline_embed(model_dtype: torch.dtype, input_dtype: torch.dtype):
301 | pipeline = ChronosPipeline.from_pretrained(
302 | Path(__file__).parent / "dummy-chronos-model",
303 | device_map="cpu",
304 | torch_dtype=model_dtype,
305 | )
306 | d_model = pipeline.model.model.config.d_model
307 | context = 10 * torch.rand(size=(4, 16)) + 10
308 | context = context.to(dtype=input_dtype)
309 | expected_embed_length = 16 + (1 if pipeline.model.config.use_eos_token else 0)
310 |
311 | # input: tensor of shape (batch_size, context_length)
312 |
313 | embedding, scale = pipeline.embed(context)
314 | validate_tensor(
315 | embedding, shape=(4, expected_embed_length, d_model), dtype=model_dtype
316 | )
317 | validate_tensor(scale, shape=(4,), dtype=torch.float32)
318 |
319 | # input: batch_size-long list of tensors of shape (context_length,)
320 |
321 | embedding, scale = pipeline.embed(list(context))
322 | validate_tensor(
323 | embedding, shape=(4, expected_embed_length, d_model), dtype=model_dtype
324 | )
325 | validate_tensor(scale, shape=(4,), dtype=torch.float32)
326 |
327 | # input: tensor of shape (context_length,)
328 | embedding, scale = pipeline.embed(context[0, ...])
329 | validate_tensor(
330 | embedding, shape=(1, expected_embed_length, d_model), dtype=model_dtype
331 | )
332 | validate_tensor(scale, shape=(1,), dtype=torch.float32)
333 |
334 |
335 | @pytest.mark.parametrize("n_tokens", [10, 1000, 10000])
336 | def test_tokenizer_number_of_buckets(n_tokens):
337 | config = ChronosConfig(
338 | tokenizer_class="MeanScaleUniformBins",
339 | tokenizer_kwargs=dict(low_limit=-1.0, high_limit=1.0),
340 | n_tokens=n_tokens,
341 | n_special_tokens=2,
342 | pad_token_id=0,
343 | eos_token_id=1,
344 | use_eos_token=True,
345 | model_type="seq2seq",
346 | context_length=512,
347 | prediction_length=64,
348 | num_samples=20,
349 | temperature=1.0,
350 | top_k=50,
351 | top_p=1.0,
352 | )
353 | tokenizer = config.create_tokenizer()
354 |
355 | n_numerical_tokens = config.n_tokens - config.n_special_tokens
356 |
357 | # The tokenizer has one bucket too many as a result of an early bug. This is
358 | # kept as is for consistency with the originally trained models; however, token
359 | # ids are clipped to a maximum of `n_tokens - 1` to avoid out-of-bounds errors.
360 | assert len(tokenizer.centers) == (n_numerical_tokens - 1)
361 | assert len(tokenizer.boundaries) == n_numerical_tokens
362 |
363 |
364 | @pytest.mark.parametrize("n_tokens", [10, 1000, 10000])
365 | def test_token_clipping(n_tokens):
366 | config = ChronosConfig(
367 | tokenizer_class="MeanScaleUniformBins",
368 | tokenizer_kwargs={"low_limit": -15, "high_limit": 15},
369 | n_tokens=n_tokens,
370 | n_special_tokens=2,
371 | pad_token_id=0,
372 | eos_token_id=1,
373 | use_eos_token=True,
374 | model_type="seq2seq",
375 | context_length=512,
376 | prediction_length=64,
377 | num_samples=20,
378 | temperature=1.0,
379 | top_k=50,
380 | top_p=1.0,
381 | )
382 | tokenizer = config.create_tokenizer()
383 |
384 | huge_value = 1e22 # this large value is assigned to the largest bucket
385 | token_ids, _, _ = tokenizer._input_transform(
386 | context=torch.tensor([[huge_value]]), scale=torch.tensor([1.0])
387 | )
388 | assert token_ids[0, 0] == config.n_tokens - 1 # and it's clipped to n_tokens - 1
389 |
--------------------------------------------------------------------------------
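A short sketch of the bucket bookkeeping checked by the last two tests above,
using the dummy checkpoint (assumes the repository root is the working
directory so the checkpoint path resolves):

import torch
from chronos import ChronosPipeline

pipeline = ChronosPipeline.from_pretrained("test/dummy-chronos-model", device_map="cpu")
tokenizer = pipeline.tokenizer

# With n_tokens = 32 and n_special_tokens = 2 in the dummy config, the tokenizer
# keeps 30 bucket boundaries but only 29 centers (one bucket too many, retained
# for consistency with the released models).
assert len(tokenizer.boundaries) == 30
assert len(tokenizer.centers) == 29

# A huge value falls into the top bucket and is clipped to n_tokens - 1 = 31.
token_ids, _, _ = tokenizer._input_transform(
    context=torch.tensor([[1e22]]), scale=torch.tensor([1.0])
)
assert token_ids[0, 0] == 31
--------------------------------------------------------------------------------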
/test/test_chronos_bolt.py:
--------------------------------------------------------------------------------
1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 |
4 | from pathlib import Path
5 |
6 | import pytest
7 | import torch
8 |
9 | from chronos import BaseChronosPipeline, ChronosBoltPipeline
10 | from chronos.chronos_bolt import InstanceNorm, Patch
11 | from test.util import validate_tensor
12 |
13 |
14 | def test_base_chronos_pipeline_loads_from_huggingface():
15 | BaseChronosPipeline.from_pretrained("amazon/chronos-bolt-tiny", device_map="cpu")
16 |
17 |
18 | @pytest.mark.parametrize("torch_dtype", [torch.float32, torch.bfloat16])
19 | @pytest.mark.parametrize("input_dtype", [torch.float32, torch.bfloat16, torch.int64])
20 | def test_pipeline_predict(torch_dtype: torch.dtype, input_dtype: torch.dtype):
21 | pipeline = ChronosBoltPipeline.from_pretrained(
22 | Path(__file__).parent / "dummy-chronos-bolt-model",
23 | device_map="cpu",
24 | torch_dtype=torch_dtype,
25 | )
26 | context = 10 * torch.rand(size=(4, 16)) + 10
27 | context = context.to(dtype=input_dtype)
28 | expected_num_quantiles = len(pipeline.quantiles)
29 |
30 | # input: tensor of shape (batch_size, context_length)
31 |
32 | quantiles = pipeline.predict(context, prediction_length=3)
33 | validate_tensor(quantiles, (4, expected_num_quantiles, 3), dtype=torch.float32)
34 |
35 | with pytest.raises(ValueError):
36 | quantiles = pipeline.predict(
37 | context, prediction_length=65, limit_prediction_length=True
38 | )
39 |
40 | quantiles = pipeline.predict(context, prediction_length=65)
41 | validate_tensor(quantiles, (4, expected_num_quantiles, 65))
42 |
43 | # input: batch_size-long list of tensors of shape (context_length,)
44 |
45 | quantiles = pipeline.predict(list(context), prediction_length=3)
46 | validate_tensor(quantiles, (4, expected_num_quantiles, 3), dtype=torch.float32)
47 |
48 | with pytest.raises(ValueError):
49 | quantiles = pipeline.predict(
50 | list(context),
51 | prediction_length=65,
52 | limit_prediction_length=True,
53 | )
54 |
55 | quantiles = pipeline.predict(list(context), prediction_length=65)
56 | validate_tensor(quantiles, (4, expected_num_quantiles, 65), dtype=torch.float32)
57 |
58 | # input: tensor of shape (context_length,)
59 |
60 | quantiles = pipeline.predict(context[0, ...], prediction_length=3)
61 | validate_tensor(quantiles, (1, expected_num_quantiles, 3), dtype=torch.float32)
62 |
63 | with pytest.raises(ValueError):
64 | quantiles = pipeline.predict(
65 | context[0, ...],
66 | prediction_length=65,
67 | limit_prediction_length=True,
68 | )
69 |
70 | quantiles = pipeline.predict(
71 | context[0, ...],
72 | prediction_length=65,
73 | )
74 | validate_tensor(quantiles, (1, expected_num_quantiles, 65), dtype=torch.float32)
75 |
76 |
77 | @pytest.mark.parametrize("torch_dtype", [torch.float32, torch.bfloat16])
78 | @pytest.mark.parametrize("input_dtype", [torch.float32, torch.bfloat16, torch.int64])
79 | @pytest.mark.parametrize("prediction_length", [3, 65])
80 | @pytest.mark.parametrize(
81 | "quantile_levels", [[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], [0.1, 0.5, 0.9]]
82 | )
83 | def test_pipeline_predict_quantiles(
84 | torch_dtype: torch.dtype,
85 | input_dtype: torch.dtype,
86 | prediction_length: int,
87 | quantile_levels: list[float],
88 | ):
89 | pipeline = ChronosBoltPipeline.from_pretrained(
90 | Path(__file__).parent / "dummy-chronos-bolt-model",
91 | device_map="cpu",
92 | torch_dtype=torch_dtype,
93 | )
94 | context = 10 * torch.rand(size=(4, 16)) + 10
95 | context = context.to(dtype=input_dtype)
96 |
97 | num_expected_quantiles = len(quantile_levels)
98 | # input: tensor of shape (batch_size, context_length)
99 |
100 | quantiles, mean = pipeline.predict_quantiles(
101 | context,
102 | prediction_length=prediction_length,
103 | quantile_levels=quantile_levels,
104 | )
105 | validate_tensor(
106 | quantiles, (4, prediction_length, num_expected_quantiles), dtype=torch.float32
107 | )
108 | validate_tensor(mean, (4, prediction_length), dtype=torch.float32)
109 |
110 | # input: batch_size-long list of tensors of shape (context_length,)
111 |
112 | quantiles, mean = pipeline.predict_quantiles(
113 | list(context),
114 | prediction_length=prediction_length,
115 | quantile_levels=quantile_levels,
116 | )
117 | validate_tensor(
118 | quantiles, (4, prediction_length, num_expected_quantiles), dtype=torch.float32
119 | )
120 | validate_tensor(mean, (4, prediction_length), dtype=torch.float32)
121 |
122 | # input: tensor of shape (context_length,)
123 |
124 | quantiles, mean = pipeline.predict_quantiles(
125 | context[0, ...],
126 | prediction_length=prediction_length,
127 | quantile_levels=quantile_levels,
128 | )
129 | validate_tensor(
130 | quantiles, (1, prediction_length, num_expected_quantiles), dtype=torch.float32
131 | )
132 | validate_tensor(mean, (1, prediction_length), dtype=torch.float32)
133 |
134 |
135 | @pytest.mark.parametrize("model_dtype", [torch.float32, torch.bfloat16])
136 | @pytest.mark.parametrize("input_dtype", [torch.float32, torch.bfloat16, torch.int64])
137 | def test_pipeline_embed(model_dtype: torch.dtype, input_dtype: torch.dtype):
138 | pipeline = ChronosBoltPipeline.from_pretrained(
139 | Path(__file__).parent / "dummy-chronos-bolt-model",
140 | device_map="cpu",
141 | torch_dtype=model_dtype,
142 | )
143 | d_model = pipeline.model.config.d_model
144 | context = 10 * torch.rand(size=(4, 16)) + 10
145 | context = context.to(dtype=input_dtype)
146 |
147 | # the patch size of the dummy model is 16, so only 1 patch is created
148 | expected_embed_length = 1 + (
149 | 1 if pipeline.model.config.chronos_config["use_reg_token"] else 0
150 | )
151 |
152 | # input: tensor of shape (batch_size, context_length)
153 |
154 | embedding, loc_scale = pipeline.embed(context)
155 | validate_tensor(
156 | embedding, shape=(4, expected_embed_length, d_model), dtype=model_dtype
157 | )
158 | validate_tensor(loc_scale[0], shape=(4,), dtype=torch.float32)
159 | validate_tensor(loc_scale[1], shape=(4,), dtype=torch.float32)
160 |
161 | # input: batch_size-long list of tensors of shape (context_length,)
162 |
163 | embedding, loc_scale = pipeline.embed(list(context))
164 | validate_tensor(
165 | embedding, shape=(4, expected_embed_length, d_model), dtype=model_dtype
166 | )
167 | validate_tensor(loc_scale[0], shape=(4,), dtype=torch.float32)
168 | validate_tensor(loc_scale[1], shape=(4,), dtype=torch.float32)
169 |
170 | # input: tensor of shape (context_length,)
171 | embedding, loc_scale = pipeline.embed(context[0, ...])
172 | validate_tensor(
173 | embedding, shape=(1, expected_embed_length, d_model), dtype=model_dtype
174 | )
175 | validate_tensor(loc_scale[0], shape=(1,), dtype=torch.float32)
176 | validate_tensor(loc_scale[1], shape=(1,), dtype=torch.float32)
177 |
178 |
179 | # The following tests have been taken from
180 | # https://github.com/autogluon/autogluon/blob/f57beb26cb769c6e0d484a6af2b89eab8aee73a8/timeseries/tests/unittests/models/chronos/pipeline/test_chronos_bolt.py
181 | # Author: Caner Turkmen
182 |
183 |
184 | def test_given_even_data_patch_operator_output_is_correct():
185 | batch_size = 17
186 | patch_len = 16
187 |
188 | patch = Patch(patch_len, patch_len)
189 |
190 | batch = (
191 | torch.stack([torch.arange(512)] * batch_size)
192 | + torch.arange(batch_size)[:, None]
193 | )
194 | output = patch(batch)
195 |
196 | assert output.shape == (batch_size, 512 // patch_len, patch_len)
197 |
198 | assert torch.allclose(
199 | output[:, 0],
200 | torch.stack([torch.arange(patch_len)] * batch_size)
201 | + torch.arange(batch_size)[:, None],
202 | atol=1e-5,
203 | )
204 | assert torch.allclose(
205 | output[:, 1],
206 | torch.stack([torch.arange(patch_len, 2 * patch_len)] * batch_size)
207 | + torch.arange(batch_size)[:, None],
208 | atol=1e-5,
209 | )
210 | assert not torch.isnan(output).any()
211 |
212 |
213 | def test_given_even_data_and_strides_patch_operator_output_is_correct():
214 | batch_size = 17
215 | patch_len, patch_stride = 16, 8
216 |
217 | patch = Patch(patch_len, patch_stride)
218 |
219 | offset = torch.arange(batch_size)[:, None]
220 | batch = torch.stack([torch.arange(512)] * batch_size) + offset
221 | output = patch(batch)
222 |
223 | assert torch.allclose(
224 | output[:, 1],
225 | torch.stack([torch.arange(patch_stride, patch_stride + patch_len)] * batch_size)
226 | + offset,
227 | atol=1e-5,
228 | )
229 | assert not torch.isnan(output).any()
230 |
231 |
232 | def test_given_uneven_data_patch_operator_pads_and_output_is_correct():
233 | batch_size = 17
234 | patch_len = 16
235 |
236 | patch = Patch(patch_len, patch_len)
237 |
238 | batch = (
239 | torch.stack([torch.arange(512 - patch_len + 1)] * batch_size)
240 | + torch.arange(batch_size)[:, None]
241 | ).float()
242 | output = patch(batch)
243 |
244 | assert output.shape == (batch_size, 512 // patch_len, patch_len)
245 |
246 | # check the first portion is padded
247 | assert torch.isnan(output[:, 0, :-1]).all()
248 |
249 | # check nowhere else is nan
250 | assert not torch.isnan(output[:, 1:]).any()
251 |
252 |
253 | def test_when_instancenorm_applied_then_standardization_correct():
254 | inorm = InstanceNorm()
255 |
256 | input_ = torch.tensor(
257 | [
258 | [1, 2, 3, 4, 5],
259 | [2, 3, 4, 5, 6],
260 | ]
261 | ).float()
262 |
263 | normalized, (loc, scale) = inorm(input_)
264 |
265 | assert normalized.shape == input_.shape
266 | assert torch.allclose(normalized[0], normalized[1])
267 | assert torch.allclose(loc.squeeze(), torch.tensor([3.0, 4.0]))
268 | assert torch.allclose(scale.squeeze(), torch.tensor(1.41421))
269 |
270 |
271 | def test_when_instancenorm_applied_and_reversed_then_nans_preserved():
272 | inorm = InstanceNorm()
273 |
274 | input_ = torch.tensor(
275 | [
276 | [1, torch.nan, 3, 4, 5],
277 | [2, 3, 4, 5, torch.nan],
278 | ]
279 | ).float()
280 |
281 | normalized, (loc, scale) = inorm(input_)
282 | assert torch.allclose(normalized.isnan(), input_.isnan())
283 |
284 | output = inorm.inverse(normalized, (loc, scale))
285 | assert torch.allclose(output, input_, equal_nan=True)
286 |
287 |
288 | def test_when_instancenorm_applied_and_reversed_then_output_correct():
289 | inorm = InstanceNorm()
290 |
291 | input_ = torch.tensor(
292 | [
293 | [1, 2, 3, 4, 5],
294 | [2, 3, 4, 5, 1000],
295 | ]
296 | ).float()
297 |
298 | normalized, loc_scale = inorm(input_)
299 | output = inorm.inverse(normalized, loc_scale)
300 |
301 | assert torch.allclose(output, input_)
302 |
--------------------------------------------------------------------------------
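A minimal usage sketch for ``ChronosBoltPipeline``, mirroring the tests above and
assuming the ``amazon/chronos-bolt-tiny`` checkpoint can be fetched from the
HuggingFace Hub:

import torch
from chronos import ChronosBoltPipeline

pipeline = ChronosBoltPipeline.from_pretrained(
    "amazon/chronos-bolt-tiny", device_map="cpu", torch_dtype=torch.bfloat16
)
context = torch.rand(2, 64)

# Chronos-Bolt predicts quantiles directly:
# (batch_size, len(pipeline.quantiles), prediction_length)
quantiles = pipeline.predict(context, prediction_length=12)

# Quantile and mean forecasts at requested levels:
# quantiles -> (batch_size, prediction_length, len(quantile_levels))
# mean      -> (batch_size, prediction_length)
quantiles, mean = pipeline.predict_quantiles(
    context, prediction_length=12, quantile_levels=[0.1, 0.5, 0.9]
)

# Encoder embeddings plus the (loc, scale) statistics used to normalize the context
embeddings, loc_scale = pipeline.embed(context)
--------------------------------------------------------------------------------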
/test/test_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 |
4 | import pytest
5 | import torch
6 |
7 | from chronos.utils import left_pad_and_stack_1D
8 |
9 |
10 | @pytest.mark.parametrize(
11 | "tensors",
12 | [
13 | [
14 | torch.tensor([2.0, 3.0], dtype=dtype),
15 | torch.tensor([4.0, 5.0, 6.0], dtype=dtype),
16 | torch.tensor([7.0, 8.0, 9.0, 10.0], dtype=dtype),
17 | ]
18 | for dtype in [torch.int, torch.float16, torch.float32]
19 | ],
20 | )
21 | def test_pad_and_stack(tensors: list):
22 | stacked_and_padded = left_pad_and_stack_1D(tensors)
23 |
24 | assert stacked_and_padded.dtype == torch.float32
25 | assert stacked_and_padded.shape == (len(tensors), max(len(t) for t in tensors))
26 |
27 | ref = torch.concat(tensors).to(dtype=stacked_and_padded.dtype)
28 |
29 | assert torch.sum(torch.nan_to_num(stacked_and_padded, nan=0)) == torch.sum(ref)
30 |
--------------------------------------------------------------------------------
/test/util.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Tuple
2 |
3 | import torch
4 |
5 |
6 | def validate_tensor(
7 | a: torch.Tensor, shape: Tuple[int, ...], dtype: Optional[torch.dtype] = None
8 | ) -> None:
9 | assert isinstance(a, torch.Tensor)
10 | assert a.shape == shape
11 |
12 | if dtype is not None:
13 | assert a.dtype == dtype
14 |
--------------------------------------------------------------------------------