├── .github └── workflows │ ├── dependabot.yml │ └── unittest.yml ├── .gitignore ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── NOTICE ├── README.md ├── datasets └── nbody │ └── mnist.npz ├── figures ├── cuboid_illustration.gif ├── data │ ├── earthnet │ │ ├── vis_earthnet_ndvi_train0.png │ │ └── vis_earthnet_rgb_train0.png │ ├── nbody │ │ ├── chaos_nbody_dis01.png │ │ ├── chaos_nbody_dis02_seed0.png │ │ ├── chaos_seed9 │ │ │ ├── mm_dis01.png │ │ │ ├── mm_dis01_seq0.npy │ │ │ ├── mm_dis01_seq1.npy │ │ │ ├── nbody_r0_dis02.png │ │ │ ├── nbody_r0_dis02_seq0.npy │ │ │ └── nbody_r0_dis02_seq1.npy │ │ ├── mnist_chaos_dis01.png │ │ ├── vis_chaos_seed9.pdf │ │ └── vis_chaos_seed9.png │ └── sevir │ │ ├── sevir_example.png │ │ └── sevir_example_len7.pdf └── teaser.png ├── scripts ├── baselines │ └── persistence │ │ └── earthnet │ │ ├── README.md │ │ ├── cfg.yaml │ │ └── test_persistence_earthnet.py ├── cuboid_transformer │ ├── earthnet_w_meso │ │ ├── README.md │ │ ├── cfg.yaml │ │ ├── earthformer_earthnet_v1.yaml │ │ ├── inference_tutorial_earthformer_earthnet2021.ipynb │ │ ├── inference_tutorial_earthformer_earthnet2021x.ipynb │ │ ├── predict_cuboid_en21x.py │ │ └── train_cuboid_earthnet.py │ ├── enso │ │ ├── README.md │ │ ├── cfg.yaml │ │ ├── earthformer_enso_v1.yaml │ │ └── train_cuboid_enso.py │ ├── moving_mnist │ │ ├── README.md │ │ ├── cfg.yaml │ │ └── train_cuboid_mnist.py │ ├── nbody │ │ ├── README.md │ │ ├── cfg.yaml │ │ └── train_cuboid_nbody.py │ └── sevir │ │ ├── README.md │ │ ├── cfg_sevir.yaml │ │ ├── cfg_sevirlr.yaml │ │ ├── earthformer_sevir_v1.yaml │ │ └── train_cuboid_sevir.py └── datasets │ ├── nbody │ ├── README.md │ ├── cfg.yaml │ ├── download_nbody_paper.py │ └── generate_nbody_dataset.py │ └── sevir │ └── download_sevir.py ├── setup.py ├── src └── earthformer │ ├── __init__.py │ ├── baselines │ ├── __init__.py │ └── persistence.py │ ├── config.py │ ├── cuboid_transformer │ ├── __init__.py │ ├── cuboid_transformer.py │ ├── 
cuboid_transformer_patterns.py │ ├── cuboid_transformer_unet_dec.py │ └── utils.py │ ├── datasets │ ├── __init__.py │ ├── augmentation.py │ ├── earthnet │ │ ├── __init__.py │ │ ├── earthnet21x_dataloader.py │ │ ├── earthnet_dataloader.py │ │ ├── earthnet_scores.py │ │ ├── earthnet_toolkit │ │ │ ├── __init__.py │ │ │ ├── coords.py │ │ │ ├── coords_dict.py │ │ │ ├── download.py │ │ │ ├── download_links.py │ │ │ ├── parallel_score.py │ │ │ └── plot_cube.py │ │ └── visualization.py │ ├── enso │ │ ├── __init__.py │ │ └── enso_dataloader.py │ ├── moving_mnist │ │ ├── __init__.py │ │ └── moving_mnist.py │ ├── nbody │ │ ├── __init__.py │ │ ├── nbody_mnist.py │ │ └── nbody_mnist_torch_wrap.py │ └── sevir │ │ ├── __init__.py │ │ ├── sevir_dataloader.py │ │ └── sevir_torch_wrap.py │ ├── metrics │ ├── __init__.py │ ├── enso.py │ ├── sevir.py │ ├── skill_scores.py │ └── torchmetrics_wo_compute.py │ ├── utils │ ├── __init__.py │ ├── apex_ddp.py │ ├── checkpoint.py │ ├── layout.py │ ├── optim.py │ ├── registry.py │ └── utils.py │ └── visualization │ ├── __init__.py │ ├── nbody.py │ ├── sevir │ ├── __init__.py │ ├── dataset_statistics.py │ ├── sevir_cmap.py │ └── sevir_vis_seq.py │ └── utils.py └── tests ├── README.md ├── test_cuboid.py └── unittests └── test_pretrained_checkpoints.py /.github/workflows/dependabot.yml: -------------------------------------------------------------------------------- 1 | # To get started with Dependabot version updates, you'll need to specify which 2 | # package ecosystems to update and where the package manifests are located. 
3 | # Please see the documentation for all configuration options: 4 | # https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates 5 | 6 | version: 2 7 | updates: 8 | - package-ecosystem: "pip" # See documentation for possible values 9 | directory: "/" # Location of package manifests 10 | schedule: 11 | interval: "weekly" 12 | -------------------------------------------------------------------------------- /.github/workflows/unittest.yml: -------------------------------------------------------------------------------- 1 | name: Unit-Test 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | build-linux: 7 | runs-on: ubuntu-latest 8 | strategy: 9 | max-parallel: 5 10 | 11 | steps: 12 | - uses: actions/checkout@v3 13 | - name: Set up Python 3.9 14 | uses: actions/setup-python@v4 15 | with: 16 | python-version: '3.9' 17 | - name: Add conda to system path 18 | run: | 19 | # $CONDA is an environment variable pointing to the root of the miniconda directory 20 | echo $CONDA/bin >> $GITHUB_PATH 21 | - name: Install dependencies 22 | run: | 23 | pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cpu 24 | pip install pytorch_lightning==1.6.4 25 | pip install xarray netcdf4 opencv-python 26 | pip install -U -e . --no-build-isolation 27 | - name: Test with pytest 28 | run: | 29 | pip install pytest 30 | cd tests/unittests && pytest . 
31 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 
92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | .idea/ 113 | 114 | # Spyder project settings 115 | .spyderproject 116 | .spyproject 117 | 118 | # Rope project settings 119 | .ropeproject 120 | 121 | # mkdocs documentation 122 | /site 123 | 124 | # mypy 125 | .mypy_cache/ 126 | .dmypy.json 127 | dmypy.json 128 | 129 | # Pyre type checker 130 | .pyre/ 131 | 132 | # MacOS 133 | .DS_Store 134 | 135 | # experiment results 136 | /experiments/ 137 | /scripts/cuboid_transformer/moving_mnist/experiments/ 138 | /scripts/cuboid_transformer/nbody/experiments/ 139 | /scripts/cuboid_transformer/sevir/experiments/ 140 | /scripts/cuboid_transformer/enso/experiments/ 141 | /scripts/cuboid_transformer/earthnet_w_meso/experiments/ 142 | 143 | # local documents 144 | /docs/ 145 | *.pptx 146 | 147 | # datasets 148 | /datasets/moving_mnist/ 149 | /datasets/sevir/ 150 | /datasets/sevir_lr/ 151 | /datasets/earthnet2021/ 152 | /datasets/icar_enso_2021/ 153 | /tests/unittests/test_pretrained_checkpoints_data/ 154 | 155 | # pretrained checkpoints 156 | *.pt 157 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | ## Code of Conduct 2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 4 | opensource-codeofconduct@amazon.com with any additional questions or comments. 
5 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Guidelines 2 | 3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional 4 | documentation, we greatly value feedback and contributions from our community. 5 | 6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary 7 | information to effectively respond to your bug report or contribution. 8 | 9 | 10 | ## Reporting Bugs/Feature Requests 11 | 12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features. 13 | 14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already 15 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful: 16 | 17 | * A reproducible test case or series of steps 18 | * The version of our code being used 19 | * Any modifications you've made relevant to the bug 20 | * Anything unusual about your environment or deployment 21 | 22 | 23 | ## Contributing via Pull Requests 24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that: 25 | 26 | 1. You are working against the latest source on the *main* branch. 27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already. 28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted. 29 | 30 | To send us a pull request, please: 31 | 32 | 1. Fork the repository. 33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change. 34 | 3. Ensure local tests pass. 35 | 4. 
Commit to your fork using clear commit messages. 36 | 5. Send us a pull request, answering any default questions in the pull request interface. 37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation. 38 | 39 | GitHub provides additional document on [forking a repository](https://help.github.com/articles/fork-a-repo/) and 40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/). 41 | 42 | 43 | ## Finding contributions to work on 44 | Looking at the existing issues is a great way to find something to contribute on. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start. 45 | 46 | 47 | ## Code of Conduct 48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 50 | opensource-codeofconduct@amazon.com with any additional questions or comments. 51 | 52 | 53 | ## Security issue notifications 54 | If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public github issue. 55 | 56 | 57 | ## Licensing 58 | 59 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution. 60 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 
9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. 
For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. 
In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Earthformer 2 | 3 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/earthformer-exploring-space-time-transformers/weather-forecasting-on-sevir)](https://paperswithcode.com/sota/weather-forecasting-on-sevir?p=earthformer-exploring-space-time-transformers) 4 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/earthformer-exploring-space-time-transformers/earth-surface-forecasting-on-earthnet2021-iid)](https://paperswithcode.com/sota/earth-surface-forecasting-on-earthnet2021-iid?p=earthformer-exploring-space-time-transformers) 5 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/earthformer-exploring-space-time-transformers/earth-surface-forecasting-on-earthnet2021-ood)](https://paperswithcode.com/sota/earth-surface-forecasting-on-earthnet2021-ood?p=earthformer-exploring-space-time-transformers) 6 | 7 | By [Zhihan Gao](https://scholar.google.com/citations?user=P6ACUAUAAAAJ&hl=en), [Xingjian Shi](https://github.com/sxjscience), [Hao Wang](http://www.wanghao.in/), [Yi Zhu](https://bryanyzhu.github.io/), [Yuyang Wang](https://scholar.google.com/citations?user=IKUm624AAAAJ&hl=en), [Mu Li](https://github.com/mli), [Dit-Yan Yeung](https://scholar.google.com/citations?user=nEsOOx8AAAAJ&hl=en). 8 | 9 | This repo is the official implementation of ["Earthformer: Exploring Space-Time Transformers for Earth System Forecasting"](https://www.amazon.science/publications/earthformer-exploring-space-time-transformers-for-earth-system-forecasting) that will appear in NeurIPS 2022. 10 | 11 | Check our [poster](https://earthformer.s3.amazonaws.com/docs/Earthformer_poster_NeurIPS22.pdf). 
12 | 13 | ## Tutorials 14 | 15 | - [Inference Tutorial of Earthformer on EarthNet2021](./scripts/cuboid_transformer/earthnet_w_meso/inference_tutorial_earthformer_earthnet2021.ipynb). [![Open In Studio Lab](https://studiolab.sagemaker.aws/studiolab.svg)](https://studiolab.sagemaker.aws/import/github/amazon-science/earth-forecasting-transformer/blob/main/scripts/cuboid_transformer/earthnet_w_meso/inference_tutorial_earthformer_earthnet2021.ipynb) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/amazon-science/earth-forecasting-transformer/blob/main/scripts/cuboid_transformer/earthnet_w_meso/inference_tutorial_earthformer_earthnet2021.ipynb) 16 | 17 | ## Introduction 18 | 19 | Conventionally, Earth system (e.g., weather and climate) forecasting relies on numerical simulation with complex physical models and are hence both 20 | expensive in computation and demanding on domain expertise. With the explosive growth of the spatiotemporal Earth observation data in the past decade, 21 | data-driven models that apply Deep Learning (DL) are demonstrating impressive potential for various Earth system forecasting tasks. 22 | The Transformer as an emerging DL architecture, despite its broad success in other domains, has limited adoption in this area. 23 | In this paper, we propose **Earthformer**, a space-time Transformer for Earth system forecasting. 24 | Earthformer is based on a generic, flexible and efficient space-time attention block, named **Cuboid Attention**. 25 | The idea is to decompose the data into cuboids and apply cuboid-level self-attention in parallel. These cuboids are further connected with a collection of global vectors. 
26 | 27 | Earthformer achieves strong results in synthetic datasets like MovingMNIST and N-body MNIST dataset, and also outperforms non-Transformer models (like ConvLSTM, CNN-U-Net) in SEVIR (precipitation nowcasting) and ICAR-ENSO2021 (El Nino/Southern Oscillation forecasting). 28 | 29 | 30 | ![teaser](figures/teaser.png) 31 | 32 | 33 | ### Cuboid Attention Illustration 34 | 35 | ![cuboid_attention_illustration](figures/cuboid_illustration.gif) 36 | 37 | ## Installation 38 | We recommend managing the environment through Anaconda. 39 | 40 | First, find out where CUDA is installed on your machine. It is usually under `/usr/local/cuda` or `/opt/cuda`. 41 | 42 | Next, check which version of CUDA you have installed on your machine: 43 | 44 | ```bash 45 | nvcc --version 46 | ``` 47 | 48 | Then, create a new conda environment: 49 | 50 | ```bash 51 | conda create -n earthformer python=3.9 52 | conda activate earthformer 53 | ``` 54 | 55 | Lastly, install dependencies. For example, if you have CUDA 11.6 installed under `/usr/local/cuda`, run: 56 | 57 | ```bash 58 | python3 -m pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 -f https://download.pytorch.org/whl/torch_stable.html 59 | python3 -m pip install pytorch_lightning==1.6.4 60 | python3 -m pip install xarray netcdf4 opencv-python earthnet==0.3.9 61 | cd ROOT_DIR/earth-forecasting-transformer 62 | python3 -m pip install -U -e . 
--no-build-isolation 63 | 64 | # Install Apex 65 | CUDA_HOME=/usr/local/cuda python3 -m pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" pytorch-extension git+https://github.com/NVIDIA/apex.git 66 | ``` 67 | 68 | If you have CUDA 11.7 installed under `/opt/cuda`, run: 69 | 70 | ```bash 71 | python3 -m pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 -f https://download.pytorch.org/whl/torch_stable.html 72 | python3 -m pip install pytorch_lightning==1.6.4 73 | python3 -m pip install xarray netcdf4 opencv-python earthnet==0.3.9 74 | cd ROOT_DIR/earth-forecasting-transformer 75 | python3 -m pip install -U -e . --no-build-isolation 76 | 77 | # Install Apex 78 | CUDA_HOME=/opt/cuda python3 -m pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" pytorch-extension git+https://github.com/NVIDIA/apex.git 79 | ``` 80 | 81 | ## Dataset 82 | ### MovingMNIST 83 | We follow [*Unsupervised Learning of Video Representations using LSTMs (ICML2015)*](http://www.cs.toronto.edu/~nitish/unsup_video.pdf) to use [MovingMNIST](https://github.com/mansimov/unsupervised-videos) that contains 10,000 sequences each of length 20 showing 2 digits moving in a $64\times 64$ frame. 84 | 85 | Our [MovingMNIST DataModule](src/earthformer/datasets/moving_mnist/moving_mnist.py) automatically downloads it to [datasets/moving_mnist](./datasets/moving_mnist). 
86 | 87 | ### N-body MNIST 88 | The underlying dynamics in the N-body MNIST dataset is governed by the Newton's law of universal gravitation: 89 | 90 | $\frac{d^2\boldsymbol{x}\_{i}}{dt^2} = - \sum\_{j\neq i}\frac{G m\_j (\boldsymbol{x}\_{i}-\boldsymbol{x}\_{j})}{(\|\boldsymbol{x}\_i-\boldsymbol{x}\_j\|+d\_{\text{soft}})^r}$ 91 | 92 | where $\boldsymbol{x}\_{i}$ is the spatial coordinates of the $i$-th digit, $G$ is the gravitational constant, $m\_j$ is the mass of the $j$-th digit, $r$ is a constant representing the power scale in the gravitational law, $d\_{\text{soft}}$ is a small softening distance that ensures numerical stability. 93 | 94 | The N-body MNIST dataset we used in the paper can be downloaded from https://earthformer.s3.amazonaws.com/nbody/nbody_paper.zip . 95 | 96 | In addition, you can also use the following script for downloading / extracting the data: 97 | ```bash 98 | cd ROOT_DIR/earth-forecasting-transformer 99 | python ./scripts/datasets/nbody/download_nbody_paper.py 100 | ``` 101 | 102 | Alternatively, run the following commands to generate N-body MNIST dataset. 103 | ```bash 104 | cd ROOT_DIR/earth-forecasting-transformer 105 | python ./scripts/datasets/nbody/generate_nbody_dataset.py --cfg ./scripts/datasets/nbody/cfg.yaml 106 | ``` 107 | 108 | ### SEVIR 109 | [Storm EVent ImageRy (SEVIR) dataset](https://sevir.mit.edu/) is a spatiotemporally aligned dataset containing over 10,000 weather events. 110 | We adopt NEXRAD Vertically Integrated Liquid (VIL) mosaics in SEVIR for benchmarking precipitation nowcasting, i.e., to predict the future VIL up to 60 minutes given 65 minutes context VIL. 111 | The resolution is thus $13\times 384\times 384\rightarrow 12\times 384\times 384$. 
112 | 113 | To download SEVIR dataset from AWS S3, run: 114 | ```bash 115 | cd ROOT_DIR/earth-forecasting-transformer 116 | python ./scripts/datasets/sevir/download_sevir.py --dataset sevir 117 | ``` 118 | 119 | A visualization example of SEVIR VIL sequence: 120 | ![Example_SEVIR_VIL_sequence](./figures/data/sevir/sevir_example.png) 121 | 122 | ### ICAR-ENSO 123 | ICAR-ENSO consists of historical climate observation and simulation data provided by Institute for Climate and Application Research (ICAR). 124 | We forecast the SST anomalies up to 14 steps (2 steps more than one year for calculating three-month-moving-average), 125 | given a context of 12 steps (one year) of SST anomalies observations. 126 | 127 | To download the dataset, you need to follow the instructions on the [official website](https://tianchi.aliyun.com/dataset/dataDetail?dataId=98942). 128 | You can download a zip-file named `enso_round1_train_20210201.zip`. Put it under `./datasets/` and extract the zip file with the following command: 129 | 130 | ```bash 131 | unzip datasets/enso_round1_train_20210201.zip -d datasets/icar_enso_2021 132 | ``` 133 | 134 | ### EarthNet2021 135 | 136 | You may follow the [official instructions](https://www.earthnet.tech/en21/ds-download/) for downloading [EarthNet2021 dataset](https://www.earthnet.tech/en21/ch-task/). 137 | We recommend downloading it via the [earthnet_toolkit](https://github.com/earthnet2021/earthnet-toolkit). 138 | ```python 139 | import earthnet as en 140 | en.download(dataset="earthnet2021", splits="all", save_directory="./datasets/earthnet2021") 141 | ``` 142 | Alternatively, you may download [EarthNet2021x dataset](https://www.earthnet.tech/en21x/download/), which is the same as [EarthNet2021 dataset](https://www.earthnet.tech/en21/ch-task/) except for the file format (`.npz` for EarthNet2021 and `.nc` for EarthNet2021x). 
143 | ```python 144 | import earthnet as en 145 | en.download(dataset="earthnet2021x", splits="all", save_directory="./datasets/earthnet2021x") 146 | ``` 147 | 148 | It requires 455G disk space in total. 149 | 150 | ## Earthformer Training 151 | Find detailed instructions in the corresponding training script folder 152 | - [N-body MNIST](./scripts/cuboid_transformer/nbody/README.md) 153 | - [SEVIR&SEVIR-LR](./scripts/cuboid_transformer/sevir/README.md) 154 | - [ENSO](./scripts/cuboid_transformer/enso/README.md) 155 | - [EarthNet2021](./scripts/cuboid_transformer/earthnet_w_meso/README.md) 156 | 157 | ## Training Script and Pretrained Models 158 | 159 | Find detailed instructions in how to train the models or running inference with our pretrained models in the corresponding script folder. 160 | 161 | | Dataset | Script Folder | Pretrained Weights | Config | 162 | |---------------|----------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------|-------------------------------------------------------------------------------------| 163 | | SEVIR | [scripts](./scripts/cuboid_transformer/sevir) | [link](https://earthformer.s3.amazonaws.com/pretrained_checkpoints/earthformer_sevir.pt) | [config](./scripts/cuboid_transformer/sevir/earthformer_sevir_v1.yaml) | 164 | | ICAR-ENSO | [scripts](./scripts/cuboid_transformer/enso) | [link](https://earthformer.s3.amazonaws.com/pretrained_checkpoints/earthformer_icarenso2021.pt) | [config](./scripts/cuboid_transformer/enso/earthformer_enso_v1.yaml) | 165 | | EarthNet2021 | [scripts](./scripts/cuboid_transformer/earthnet_w_meso) | [link](https://earthformer.s3.amazonaws.com/pretrained_checkpoints/earthformer_earthnet2021.pt) | [config](./scripts/cuboid_transformer/earthnet_w_meso/earthformer_earthnet_v1.yaml) | 166 | | N-body MNIST | [scripts](./scripts/cuboid_transformer/nbody) | - | - | 167 | 168 | ## Citing 
Earthformer 169 | 170 | ``` 171 | @inproceedings{gao2022earthformer, 172 | title={Earthformer: Exploring Space-Time Transformers for Earth System Forecasting}, 173 | author={Gao, Zhihan and Shi, Xingjian and Wang, Hao and Zhu, Yi and Wang, Yuyang and Li, Mu and Yeung, Dit-Yan}, 174 | booktitle={NeurIPS}, 175 | year={2022} 176 | } 177 | ``` 178 | 179 | ## Security 180 | 181 | See [CONTRIBUTING](CONTRIBUTING.md#security-issue-notifications) for more information. 182 | 183 | 184 | ## Credits 185 | Third-party libraries: 186 | - [PyTorch](https://pytorch.org/) 187 | - [PyTorch Lightning](https://www.pytorchlightning.ai/) 188 | - [Apex](https://github.com/NVIDIA/apex) 189 | - [OpenCV](https://opencv.org/) 190 | - [TensorBoard](https://www.tensorflow.org/tensorboard) 191 | - [OmegaConf](https://github.com/omry/omegaconf) 192 | - [YACS](https://github.com/rbgirshick/yacs) 193 | - [Pillow](https://python-pillow.org/) 194 | - [scikit-learn](https://scikit-learn.org/stable/) 195 | 196 | ## License 197 | 198 | This project is licensed under the Apache-2.0 License. 
199 | 200 | -------------------------------------------------------------------------------- /datasets/nbody/mnist.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/earth-forecasting-transformer/7732b03bdb366110563516c3502315deab4c2026/datasets/nbody/mnist.npz -------------------------------------------------------------------------------- /figures/cuboid_illustration.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/earth-forecasting-transformer/7732b03bdb366110563516c3502315deab4c2026/figures/cuboid_illustration.gif -------------------------------------------------------------------------------- /figures/data/earthnet/vis_earthnet_ndvi_train0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/earth-forecasting-transformer/7732b03bdb366110563516c3502315deab4c2026/figures/data/earthnet/vis_earthnet_ndvi_train0.png -------------------------------------------------------------------------------- /figures/data/earthnet/vis_earthnet_rgb_train0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/earth-forecasting-transformer/7732b03bdb366110563516c3502315deab4c2026/figures/data/earthnet/vis_earthnet_rgb_train0.png -------------------------------------------------------------------------------- /figures/data/nbody/chaos_nbody_dis01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/earth-forecasting-transformer/7732b03bdb366110563516c3502315deab4c2026/figures/data/nbody/chaos_nbody_dis01.png -------------------------------------------------------------------------------- /figures/data/nbody/chaos_nbody_dis02_seed0.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/earth-forecasting-transformer/7732b03bdb366110563516c3502315deab4c2026/figures/data/nbody/chaos_nbody_dis02_seed0.png -------------------------------------------------------------------------------- /figures/data/nbody/chaos_seed9/mm_dis01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/earth-forecasting-transformer/7732b03bdb366110563516c3502315deab4c2026/figures/data/nbody/chaos_seed9/mm_dis01.png -------------------------------------------------------------------------------- /figures/data/nbody/chaos_seed9/mm_dis01_seq0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/earth-forecasting-transformer/7732b03bdb366110563516c3502315deab4c2026/figures/data/nbody/chaos_seed9/mm_dis01_seq0.npy -------------------------------------------------------------------------------- /figures/data/nbody/chaos_seed9/mm_dis01_seq1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/earth-forecasting-transformer/7732b03bdb366110563516c3502315deab4c2026/figures/data/nbody/chaos_seed9/mm_dis01_seq1.npy -------------------------------------------------------------------------------- /figures/data/nbody/chaos_seed9/nbody_r0_dis02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/earth-forecasting-transformer/7732b03bdb366110563516c3502315deab4c2026/figures/data/nbody/chaos_seed9/nbody_r0_dis02.png -------------------------------------------------------------------------------- /figures/data/nbody/chaos_seed9/nbody_r0_dis02_seq0.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/amazon-science/earth-forecasting-transformer/7732b03bdb366110563516c3502315deab4c2026/figures/data/nbody/chaos_seed9/nbody_r0_dis02_seq0.npy -------------------------------------------------------------------------------- /figures/data/nbody/chaos_seed9/nbody_r0_dis02_seq1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/earth-forecasting-transformer/7732b03bdb366110563516c3502315deab4c2026/figures/data/nbody/chaos_seed9/nbody_r0_dis02_seq1.npy -------------------------------------------------------------------------------- /figures/data/nbody/mnist_chaos_dis01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/earth-forecasting-transformer/7732b03bdb366110563516c3502315deab4c2026/figures/data/nbody/mnist_chaos_dis01.png -------------------------------------------------------------------------------- /figures/data/nbody/vis_chaos_seed9.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/earth-forecasting-transformer/7732b03bdb366110563516c3502315deab4c2026/figures/data/nbody/vis_chaos_seed9.pdf -------------------------------------------------------------------------------- /figures/data/nbody/vis_chaos_seed9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/earth-forecasting-transformer/7732b03bdb366110563516c3502315deab4c2026/figures/data/nbody/vis_chaos_seed9.png -------------------------------------------------------------------------------- /figures/data/sevir/sevir_example.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/amazon-science/earth-forecasting-transformer/7732b03bdb366110563516c3502315deab4c2026/figures/data/sevir/sevir_example.png -------------------------------------------------------------------------------- /figures/data/sevir/sevir_example_len7.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/earth-forecasting-transformer/7732b03bdb366110563516c3502315deab4c2026/figures/data/sevir/sevir_example_len7.pdf -------------------------------------------------------------------------------- /figures/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/earth-forecasting-transformer/7732b03bdb366110563516c3502315deab4c2026/figures/teaser.png -------------------------------------------------------------------------------- /scripts/baselines/persistence/earthnet/README.md: -------------------------------------------------------------------------------- 1 | # Test Persistence on EarthNet2021 2 | Run the following command to test Persistence on EarthNet2021 dataset. 
3 | Change the configurations in [corresponding cfg.yaml](./cfg.yaml) 4 | ```bash 5 | cd ROOT_DIR/earth-forecasting-transformer 6 | MASTER_ADDR=localhost MASTER_PORT=10001 python ./scripts/baselines/persistence/earthnet/test_persistence_earthnet.py --gpus 2 --cfg ./scripts/baselines/persistence/earthnet/cfg.yaml --save tmp_earthnet_persistence 7 | ``` 8 | Run the tensorboard command to upload experiment records 9 | ```bash 10 | cd ROOT_DIR/earth-forecasting-transformer 11 | tensorboard dev upload --logdir ./experiments/tmp_earthnet/lightning_logs --name 'tmp_earthnet' 12 | -------------------------------------------------------------------------------- /scripts/baselines/persistence/earthnet/cfg.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | return_mode: "minimal" 3 | test_subset_name: ["iid", "ood"] 4 | in_len: 10 5 | out_len: 20 6 | layout: "THWC" 7 | static_layout: null 8 | val_ratio: 0.1 9 | train_val_split_seed: null 10 | highresstatic_expand_t: false 11 | mesostatic_expand_t: false 12 | meso_crop: null 13 | fp16: false 14 | optim: 15 | micro_batch_size: 4 16 | seed: 0 17 | logging: 18 | logging_prefix: "Persistence_EarthNet" 19 | use_wandb: false 20 | vis: 21 | train_example_data_idx_list: [0, ] 22 | val_example_data_idx_list: [0, ] 23 | test_example_data_idx_list: [0, ] 24 | eval_example_only: false 25 | -------------------------------------------------------------------------------- /scripts/cuboid_transformer/earthnet_w_meso/README.md: -------------------------------------------------------------------------------- 1 | # Earthformer Training on EarthNet2021 with auxiliary meso scale data 2 | Run the following command to train Earthformer on EarthNet2021 dataset. 
3 | Change the configurations in [cfg.yaml](./cfg.yaml) 4 | ```bash 5 | MASTER_ADDR=localhost MASTER_PORT=10001 python train_cuboid_earthnet.py --gpus 2 --cfg cfg.yaml --ckpt_name last.ckpt --save tmp_earthnet_w_meso 6 | ``` 7 | 8 | Or run the following command to directly load pretrained checkpoint for test. 9 | ```bash 10 | MASTER_ADDR=localhost MASTER_PORT=10001 python train_cuboid_earthnet.py --gpus 2 --pretrained --save tmp_earthnet_w_meso 11 | ``` 12 | Run the tensorboard command to upload experiment records 13 | ```bash 14 | tensorboard dev upload --logdir ./experiments/tmp_earthnet_w_meso/lightning_logs --name 'tmp_earthnet_w_meso' 15 | ``` 16 | -------------------------------------------------------------------------------- /scripts/cuboid_transformer/earthnet_w_meso/cfg.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | return_mode: "default" 3 | data_aug_mode: null 4 | data_aug_cfg: 5 | mixup_alpha: 0.2 6 | test_subset_name: ["iid", "ood"] 7 | in_len: 10 8 | out_len: 20 9 | layout: "THWC" 10 | static_layout: "CHW" 11 | val_ratio: 0.1 12 | train_val_split_seed: null 13 | highresstatic_expand_t: false 14 | mesostatic_expand_t: false 15 | meso_crop: null 16 | fp16: false 17 | layout: 18 | in_len: 10 19 | out_len: 20 20 | layout: "NTHWC" 21 | img_height: 128 22 | img_width: 128 23 | optim: 24 | total_batch_size: 32 25 | micro_batch_size: 2 26 | seed: 0 27 | method: "adamw" 28 | lr: 0.0005 29 | wd: 1.0e-05 30 | gradient_clip_val: 1.0 31 | max_epochs: 200 32 | # scheduler 33 | lr_scheduler_mode: "cosine" 34 | min_lr_ratio: 1.0e-3 35 | warmup_min_lr_ratio: 0.0 36 | warmup_percentage: 0.2 37 | # early stopping 38 | early_stop: true 39 | early_stop_mode: "min" 40 | early_stop_patience: 5 41 | save_top_k: 5 42 | logging: 43 | logging_prefix: "Cuboid_EarthNet" 44 | monitor_lr: true 45 | monitor_device: false 46 | track_grad_norm: -1 47 | use_wandb: false 48 | trainer: 49 | check_val_every_n_epoch: 5 50 | 
log_step_ratio: 0.001 51 | precision: 32 52 | vis: 53 | train_example_data_idx_list: [0, ] 54 | val_example_data_idx_list: [0, ] 55 | test_example_data_idx_list: [0, 40, 80, 120, 160, 200, 240, 280, 320, 360, 400] 56 | eval_example_only: false 57 | model: 58 | input_shape: [10, 128, 128, 4] 59 | target_shape: [20, 128, 128, 4] 60 | base_units: 256 61 | # block_units: null 62 | scale_alpha: 1.0 63 | 64 | enc_depth: [1, 1] 65 | dec_depth: [1, 1] 66 | enc_use_inter_ffn: true 67 | dec_use_inter_ffn: true 68 | dec_hierarchical_pos_embed: false 69 | 70 | downsample: 2 71 | downsample_type: "patch_merge" 72 | upsample_type: "upsample" 73 | 74 | num_global_vectors: 8 75 | use_dec_self_global: false 76 | dec_self_update_global: true 77 | use_dec_cross_global: false 78 | use_global_vector_ffn: false 79 | use_global_self_attn: true 80 | separate_global_qkv: true 81 | global_dim_ratio: 1 82 | 83 | self_pattern: "axial" 84 | cross_self_pattern: "axial" 85 | cross_pattern: "cross_1x1" 86 | dec_cross_last_n_frames: null 87 | 88 | attn_drop: 0.1 89 | proj_drop: 0.1 90 | ffn_drop: 0.1 91 | num_heads: 4 92 | 93 | ffn_activation: "gelu" 94 | gated_ffn: false 95 | norm_layer: "layer_norm" 96 | padding_type: "zeros" 97 | pos_embed_type: "t+hw" 98 | use_relative_pos: true 99 | self_attn_use_final_proj: true 100 | 101 | checkpoint_level: 0 102 | 103 | initial_downsample_type: "stack_conv" 104 | initial_downsample_activation: "leaky" 105 | initial_downsample_stack_conv_num_layers: 2 106 | initial_downsample_stack_conv_dim_list: [64, 256] 107 | initial_downsample_stack_conv_downscale_list: [2, 2] 108 | initial_downsample_stack_conv_num_conv_list: [2, 2] 109 | 110 | attn_linear_init_mode: "0" 111 | ffn_linear_init_mode: "0" 112 | conv_init_mode: "0" 113 | down_up_linear_init_mode: "0" 114 | norm_init_mode: "0" 115 | # different from CuboidTransformerModel, no arg `dec_use_first_self_attn=False` 116 | auxiliary_channels: 7 117 | unet_dec_cross_mode: "both" 118 | 
-------------------------------------------------------------------------------- /scripts/cuboid_transformer/earthnet_w_meso/earthformer_earthnet_v1.yaml: -------------------------------------------------------------------------------- 1 | # Configurations for pretrained checkpoint. Please do not modify it. 2 | dataset: 3 | return_mode: "default" 4 | data_aug_mode: null 5 | data_aug_cfg: 6 | mixup_alpha: 0.2 7 | test_subset_name: ["iid", "ood"] 8 | in_len: 10 9 | out_len: 20 10 | layout: "THWC" 11 | static_layout: "CHW" 12 | val_ratio: 0.1 13 | train_val_split_seed: null 14 | highresstatic_expand_t: false 15 | mesostatic_expand_t: false 16 | meso_crop: null 17 | fp16: false 18 | layout: 19 | in_len: 10 20 | out_len: 20 21 | layout: "NTHWC" 22 | img_height: 128 23 | img_width: 128 24 | optim: 25 | total_batch_size: 32 26 | micro_batch_size: 2 27 | seed: 0 28 | method: "adamw" 29 | lr: 0.0005 30 | wd: 1.0e-05 31 | gradient_clip_val: 1.0 32 | max_epochs: 200 33 | # scheduler 34 | lr_scheduler_mode: "cosine" 35 | min_lr_ratio: 1.0e-3 36 | warmup_min_lr_ratio: 0.0 37 | warmup_percentage: 0.2 38 | # early stopping 39 | early_stop: true 40 | early_stop_mode: "min" 41 | early_stop_patience: 5 42 | save_top_k: 5 43 | logging: 44 | logging_prefix: "Cuboid_EarthNet" 45 | monitor_lr: true 46 | monitor_device: false 47 | track_grad_norm: -1 48 | use_wandb: false 49 | trainer: 50 | check_val_every_n_epoch: 5 51 | log_step_ratio: 0.001 52 | precision: 32 53 | vis: 54 | train_example_data_idx_list: [0, ] 55 | val_example_data_idx_list: [0, ] 56 | test_example_data_idx_list: [0, 40, 80, 120, 160, 200, 240, 280, 320, 360, 400] 57 | eval_example_only: false 58 | model: 59 | input_shape: [10, 128, 128, 4] 60 | target_shape: [20, 128, 128, 4] 61 | base_units: 256 62 | # block_units: null 63 | scale_alpha: 1.0 64 | 65 | enc_depth: [1, 1] 66 | dec_depth: [1, 1] 67 | enc_use_inter_ffn: true 68 | dec_use_inter_ffn: true 69 | dec_hierarchical_pos_embed: false 70 | 71 | downsample: 2 72 | 
downsample_type: "patch_merge" 73 | upsample_type: "upsample" 74 | 75 | num_global_vectors: 8 76 | use_dec_self_global: false 77 | dec_self_update_global: true 78 | use_dec_cross_global: false 79 | use_global_vector_ffn: false 80 | use_global_self_attn: true 81 | separate_global_qkv: true 82 | global_dim_ratio: 1 83 | 84 | self_pattern: "axial" 85 | cross_self_pattern: "axial" 86 | cross_pattern: "cross_1x1" 87 | dec_cross_last_n_frames: null 88 | 89 | attn_drop: 0.1 90 | proj_drop: 0.1 91 | ffn_drop: 0.1 92 | num_heads: 4 93 | 94 | ffn_activation: "gelu" 95 | gated_ffn: false 96 | norm_layer: "layer_norm" 97 | padding_type: "zeros" 98 | pos_embed_type: "t+hw" 99 | use_relative_pos: true 100 | self_attn_use_final_proj: true 101 | 102 | checkpoint_level: 0 103 | 104 | initial_downsample_type: "stack_conv" 105 | initial_downsample_activation: "leaky" 106 | initial_downsample_stack_conv_num_layers: 2 107 | initial_downsample_stack_conv_dim_list: [64, 256] 108 | initial_downsample_stack_conv_downscale_list: [2, 2] 109 | initial_downsample_stack_conv_num_conv_list: [2, 2] 110 | 111 | attn_linear_init_mode: "0" 112 | ffn_linear_init_mode: "0" 113 | conv_init_mode: "0" 114 | down_up_linear_init_mode: "0" 115 | norm_init_mode: "0" 116 | # different from CuboidTransformerModel, no arg `dec_use_first_self_attn=False` 117 | auxiliary_channels: 7 118 | unet_dec_cross_mode: "both" 119 | -------------------------------------------------------------------------------- /scripts/cuboid_transformer/earthnet_w_meso/predict_cuboid_en21x.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import xarray as xr 4 | 5 | import torch 6 | import os 7 | from train_cuboid_earthnet import CuboidEarthNet2021PLModule 8 | from earthformer.utils.utils import download 9 | from omegaconf import OmegaConf 10 | from earthformer.datasets.earthnet.earthnet21x_dataloader import EarthNet2021xLightningDataModule, EarthNet2021xTestDataset 11 | from 
def process_sample(sample):
    """Extract and sanitize the four data cubes of one EarthNet2021x test sample.

    Parameters
    ----------
    sample : dict
        Must contain the keys ``'highresdynamic'``, ``'highresstatic'``,
        ``'mesodynamic'`` and ``'mesostatic'`` holding numpy arrays
        (layouts are configured on the dataset; not checked here).

    Returns
    -------
    tuple of numpy.ndarray
        ``(highresdynamic, highresstatic, mesodynamic, mesostatic)`` with
        NaN/Inf values replaced and the high-res dynamic cube clipped to [0, 1].
    """
    # High-resolution Earth surface data. The channels are [blue, green, red, nir, cloud]
    highresdynamic = sample['highresdynamic']
    highresstatic = sample['highresstatic']

    # The meso-scale data. The channels are ["precipitation", "pressure", "temp mean", "temp min", "temp max"]
    mesodynamic = sample['mesodynamic']
    mesostatic = sample['mesostatic']

    # Reflectances are expected in [0, 1]: map NaN->0, +inf->1, -inf->0, then clip.
    highresdynamic = np.nan_to_num(highresdynamic, nan=0.0, posinf=1.0, neginf=0.0)
    highresdynamic = np.clip(highresdynamic, a_min=0.0, a_max=1.0)
    # For the remaining cubes only NaN is zeroed; +/-inf fall back to
    # np.nan_to_num's defaults (saturation to large finite values).
    mesodynamic = np.nan_to_num(mesodynamic, nan=0.0)
    highresstatic = np.nan_to_num(highresstatic, nan=0.0)
    mesostatic = np.nan_to_num(mesostatic, nan=0.0)
    return highresdynamic, highresstatic, mesodynamic, mesostatic

def main():
    """Run Earthformer inference over the EarthNet2021x "iid" test split.

    Pipeline: load the pretrained-checkpoint config, build the test dataset,
    download and load the pretrained weights into the Lightning module's
    torch model, then, for every sample, predict the output sequence,
    derive NDVI from the predicted [blue, green, red, nir] channels and
    write it as a NetCDF cube under ``./experiments/preds_en21x/``.

    Side effects: downloads the checkpoint (network), uses CUDA, and writes
    one ``.nc`` file per test sample (skipping files that already exist).
    """
    print("Loading Config")
    pred_dir = Path("./experiments/preds_en21x/")

    config_file = "./earthformer_earthnet_v1.yaml"
    config = OmegaConf.load(open(config_file, "r"))
    in_len = config.layout.in_len
    out_len = config.layout.out_len

    seed = config.optim.seed
    dataset_cfg = OmegaConf.to_object(config.dataset)

    # Seed everything (including dataloader workers) for reproducibility.
    seed_everything(seed, workers=True)

    micro_batch_size = 1
    print("Loading Dataset")
    # NOTE(review): data_dir is a hard-coded cluster path — adjust for your
    # own environment before running.
    earthnet_iid_testset = EarthNet2021xTestDataset(subset_name="iid",
                                                    data_dir="/Net/Groups/BGI/work_1/scratch/s3/earthnet/earthnet2021x/iid/",
                                                    layout=config.dataset.layout,
                                                    static_layout=config.dataset.static_layout,
                                                    highresstatic_expand_t=config.dataset.highresstatic_expand_t,
                                                    mesostatic_expand_t=config.dataset.mesostatic_expand_t,
                                                    meso_crop=None,
                                                    fp16=False)

    save_dir = "./experiments"
    print("Loading Model")
    pl_module = CuboidEarthNet2021PLModule(
        total_num_steps=None,
        save_dir="./experiments",
        oc_file=config_file
    )

    # Fetch the published checkpoint and load the raw state dict into the
    # wrapped torch module (not the Lightning module itself).
    pretrained_checkpoint_url = "https://earthformer.s3.amazonaws.com/pretrained_checkpoints/earthformer_earthnet2021.pt"
    local_checkpoint_path = os.path.join(save_dir, "earthformer_earthnet2021.pt")
    download(url=pretrained_checkpoint_url, path=local_checkpoint_path)

    state_dict = torch.load(local_checkpoint_path, map_location=torch.device("cpu"))
    pl_module.torch_nn_module.load_state_dict(state_dict=state_dict)

    pl_module.torch_nn_module.cuda()
    pl_module.torch_nn_module.eval()
    print("Starting Predictions")
    for idx in tqdm(range(len(earthnet_iid_testset))):
        highresdynamic, highresstatic, mesodynamic, mesostatic = process_sample(earthnet_iid_testset[idx])


        # Batch dimension of 1 is added via expand_dims; the module returns
        # (pred_seq, loss, in_seq, target_seq, mask) — only pred_seq is used.
        with torch.no_grad():
            pred_seq, loss, in_seq, target_seq, mask = pl_module({"highresdynamic": torch.tensor(np.expand_dims(highresdynamic, axis=0)).cuda(),
                                                                  "highresstatic": torch.tensor(np.expand_dims(highresstatic, axis=0)).cuda(),
                                                                  "mesodynamic": torch.tensor(np.expand_dims(mesodynamic, axis=0)).cuda(),
                                                                  "mesostatic": torch.tensor(np.expand_dims(mesostatic, axis=0)).cuda()})
        pred_seq_np = pred_seq.detach().cpu().numpy()

        # Ground-truth cube only supplies coordinates (time/lat/lon) here.
        targ_path = Path(earthnet_iid_testset.nc_path_list[idx])

        targ_cube = xr.open_dataset(targ_path)

        lat = targ_cube.lat
        lon = targ_cube.lon

        # Channel order [blue, green, red, nir] follows process_sample's
        # channel comment — assumed layout (N, T, H, W, C); TODO confirm.
        blue_pred = pred_seq_np[0,:,:,:,0]
        green_pred = pred_seq_np[0,:,:,:,1]
        red_pred = pred_seq_np[0,:,:,:,2]
        nir_pred = pred_seq_np[0,:,:,:,3]
        # NDVI = (NIR - Red) / (NIR + Red); epsilon guards against divide-by-zero.
        ndvi_pred = ((nir_pred - red_pred) / (nir_pred + red_pred + 1e-8))

        # Time axis: every 5th step starting at index 4, then the prediction
        # window [in_len, in_len + out_len) — presumably aligning the target
        # cube's time resolution with the model's steps; verify against data.
        pred_cube = xr.Dataset({"ndvi_pred": xr.DataArray(data = ndvi_pred, coords = {"time": targ_cube.time.isel(time = slice(4,None,5)).isel(time = slice(in_len, in_len + out_len)), "lat": lat, "lon": lon}, dims = ["time","lat", "lon"])})


        # Mirror the source layout (<region>/<cubename>.nc) under pred_dir;
        # never overwrite an existing prediction file.
        pred_path = pred_dir/targ_path.parent.stem/targ_path.name
        pred_path.parent.mkdir(parents = True, exist_ok = True)
        if not pred_path.is_file():
            pred_cube.to_netcdf(pred_path, encoding={"ndvi_pred":{"dtype": "float32"}})

if __name__ == "__main__":
    main()
early_stop: true 30 | early_stop_mode: "min" 31 | early_stop_patience: 5 32 | save_top_k: 5 33 | logging: 34 | logging_prefix: "Cuboid_ENSO" 35 | monitor_lr: true 36 | monitor_device: false 37 | track_grad_norm: -1 38 | use_wandb: false 39 | trainer: 40 | check_val_every_n_epoch: 5 41 | log_step_ratio: 0.001 42 | precision: 32 43 | vis: 44 | train_example_data_idx_list: [0, ] 45 | val_example_data_idx_list: [0, ] 46 | test_example_data_idx_list: [0, ] 47 | eval_example_only: false 48 | model: 49 | input_shape: [12, 24, 48, 1] 50 | target_shape: [14, 24, 48, 1] 51 | base_units: 64 52 | # block_units: null 53 | scale_alpha: 1.0 54 | 55 | enc_depth: [1, 1] 56 | dec_depth: [1, 1] 57 | enc_use_inter_ffn: true 58 | dec_use_inter_ffn: true 59 | dec_hierarchical_pos_embed: false 60 | 61 | downsample: 2 62 | downsample_type: "patch_merge" 63 | upsample_type: "upsample" 64 | 65 | num_global_vectors: 0 66 | use_dec_self_global: false 67 | dec_self_update_global: true 68 | use_dec_cross_global: false 69 | use_global_vector_ffn: false 70 | use_global_self_attn: false 71 | separate_global_qkv: false 72 | global_dim_ratio: 1 73 | 74 | self_pattern: "axial" 75 | cross_self_pattern: "axial" 76 | cross_pattern: "cross_1x1" 77 | dec_cross_last_n_frames: null 78 | 79 | attn_drop: 0.1 80 | proj_drop: 0.1 81 | ffn_drop: 0.1 82 | num_heads: 4 83 | 84 | ffn_activation: "gelu" 85 | gated_ffn: false 86 | norm_layer: "layer_norm" 87 | padding_type: "zeros" 88 | pos_embed_type: "t+h+w" 89 | use_relative_pos: true 90 | self_attn_use_final_proj: true 91 | dec_use_first_self_attn: false 92 | 93 | z_init_method: "zeros" 94 | initial_downsample_type: "conv" 95 | initial_downsample_activation: "leaky" 96 | initial_downsample_scale: [1, 1, 2] 97 | initial_downsample_conv_layers: 2 98 | final_upsample_conv_layers: 1 99 | checkpoint_level: 0 100 | 101 | attn_linear_init_mode: "0" 102 | ffn_linear_init_mode: "0" 103 | conv_init_mode: "0" 104 | down_up_linear_init_mode: "0" 105 | norm_init_mode: "0" 106 
| -------------------------------------------------------------------------------- /scripts/cuboid_transformer/enso/earthformer_enso_v1.yaml: -------------------------------------------------------------------------------- 1 | # Configurations for pretrained checkpoint. Please do not modify it. 2 | dataset: 3 | in_len: 12 4 | out_len: 14 5 | nino_window_t: 3 6 | in_stride: 1 7 | out_stride: 1 8 | train_samples_gap: 1 9 | eval_samples_gap: 1 10 | normalize_sst: true 11 | layout: 12 | in_len: 12 13 | out_len: 14 14 | layout: "NTHWC" 15 | optim: 16 | total_batch_size: 64 17 | micro_batch_size: 16 18 | seed: 0 19 | method: "adamw" 20 | lr: 0.0001 21 | wd: 1.0e-05 22 | gradient_clip_val: 1.0 23 | max_epochs: 100 24 | # scheduler 25 | lr_scheduler_mode: "cosine" 26 | min_lr_ratio: 1.0e-3 27 | warmup_min_lr_ratio: 0.0 28 | warmup_percentage: 0.2 29 | # early stopping 30 | early_stop: true 31 | early_stop_mode: "min" 32 | early_stop_patience: 5 33 | save_top_k: 5 34 | logging: 35 | logging_prefix: "Cuboid_ENSO" 36 | monitor_lr: true 37 | monitor_device: false 38 | track_grad_norm: -1 39 | use_wandb: false 40 | trainer: 41 | check_val_every_n_epoch: 5 42 | log_step_ratio: 0.001 43 | precision: 32 44 | vis: 45 | train_example_data_idx_list: [0, ] 46 | val_example_data_idx_list: [0, ] 47 | test_example_data_idx_list: [0, ] 48 | eval_example_only: false 49 | model: 50 | input_shape: [12, 24, 48, 1] 51 | target_shape: [14, 24, 48, 1] 52 | base_units: 64 53 | # block_units: null 54 | scale_alpha: 1.0 55 | 56 | enc_depth: [1, 1] 57 | dec_depth: [1, 1] 58 | enc_use_inter_ffn: true 59 | dec_use_inter_ffn: true 60 | dec_hierarchical_pos_embed: false 61 | 62 | downsample: 2 63 | downsample_type: "patch_merge" 64 | upsample_type: "upsample" 65 | 66 | num_global_vectors: 0 67 | use_dec_self_global: false 68 | dec_self_update_global: true 69 | use_dec_cross_global: false 70 | use_global_vector_ffn: false 71 | use_global_self_attn: false 72 | separate_global_qkv: false 73 | 
global_dim_ratio: 1 74 | 75 | self_pattern: "axial" 76 | cross_self_pattern: "axial" 77 | cross_pattern: "cross_1x1" 78 | dec_cross_last_n_frames: null 79 | 80 | attn_drop: 0.1 81 | proj_drop: 0.1 82 | ffn_drop: 0.1 83 | num_heads: 4 84 | 85 | ffn_activation: "gelu" 86 | gated_ffn: false 87 | norm_layer: "layer_norm" 88 | padding_type: "zeros" 89 | pos_embed_type: "t+h+w" 90 | use_relative_pos: true 91 | self_attn_use_final_proj: true 92 | dec_use_first_self_attn: false 93 | 94 | z_init_method: "zeros" 95 | initial_downsample_type: "conv" 96 | initial_downsample_activation: "leaky" 97 | initial_downsample_scale: [1, 1, 2] 98 | initial_downsample_conv_layers: 2 99 | final_upsample_conv_layers: 1 100 | checkpoint_level: 0 101 | 102 | attn_linear_init_mode: "0" 103 | ffn_linear_init_mode: "0" 104 | conv_init_mode: "0" 105 | down_up_linear_init_mode: "0" 106 | norm_init_mode: "0" 107 | -------------------------------------------------------------------------------- /scripts/cuboid_transformer/moving_mnist/README.md: -------------------------------------------------------------------------------- 1 | # Earthformer Training on Moving MNIST 2 | Run the following command to train Earthformer on Moving MNIST dataset. 
3 | Change the configurations in [cfg.yaml](./cfg.yaml) 4 | ```bash 5 | MASTER_ADDR=localhost MASTER_PORT=10001 python train_cuboid_mnist.py --gpus 2 --cfg cfg.yaml --ckpt_name last.ckpt --save tmp_mnist 6 | ``` 7 | Run the tensorboard command to upload experiment records 8 | ```bash 9 | tensorboard dev upload --logdir ./experiments/tmp_mnist/lightning_logs --name 'tmp_mnist' 10 | ``` 11 | -------------------------------------------------------------------------------- /scripts/cuboid_transformer/moving_mnist/cfg.yaml: -------------------------------------------------------------------------------- 1 | layout: 2 | in_len: 10 3 | out_len: 10 4 | layout: "NTHWC" 5 | data_seq_len: 20 6 | optim: 7 | total_batch_size: 32 8 | micro_batch_size: 2 9 | seed: 0 10 | method: "adamw" 11 | lr: 0.001 12 | wd: 1.0e-05 13 | gradient_clip_val: 1.0 14 | max_epochs: 100 15 | # scheduler 16 | lr_scheduler_mode: "cosine" 17 | min_lr_ratio: 1.0e-3 18 | warmup_min_lr_ratio: 0.0 19 | warmup_percentage: 0.2 20 | # early stopping 21 | early_stop: true 22 | early_stop_mode: "min" 23 | early_stop_patience: 20 24 | save_top_k: 1 25 | logging: 26 | logging_prefix: "Cuboid_MovingMNIST" 27 | monitor_lr: true 28 | monitor_device: false 29 | track_grad_norm: -1 30 | use_wandb: false 31 | trainer: 32 | check_val_every_n_epoch: 1 33 | log_step_ratio: 0.001 34 | precision: 32 35 | vis: 36 | train_example_data_idx_list: [0, ] 37 | val_example_data_idx_list: [0, ] 38 | test_example_data_idx_list: [0, ] 39 | eval_example_only: false 40 | model: 41 | input_shape: [10, 64, 64, 1] 42 | target_shape: [10, 64, 64, 1] 43 | base_units: 64 44 | # block_units: null 45 | scale_alpha: 1.0 46 | 47 | enc_depth: [4, 4] 48 | dec_depth: [4, 4] 49 | enc_use_inter_ffn: true 50 | dec_use_inter_ffn: true 51 | dec_hierarchical_pos_embed: false 52 | 53 | downsample: 2 54 | downsample_type: "patch_merge" 55 | upsample_type: "upsample" 56 | 57 | num_global_vectors: 0 58 | use_dec_self_global: false 59 | dec_self_update_global: 
true 60 | use_dec_cross_global: false 61 | use_global_vector_ffn: false 62 | use_global_self_attn: false 63 | separate_global_qkv: false 64 | global_dim_ratio: 1 65 | 66 | self_pattern: "axial" 67 | cross_self_pattern: "axial" 68 | cross_pattern: "cross_1x1" 69 | dec_cross_last_n_frames: null 70 | 71 | attn_drop: 0.1 72 | proj_drop: 0.1 73 | ffn_drop: 0.1 74 | num_heads: 4 75 | 76 | ffn_activation: "gelu" 77 | gated_ffn: false 78 | norm_layer: "layer_norm" 79 | padding_type: "zeros" 80 | pos_embed_type: "t+hw" 81 | use_relative_pos: true 82 | self_attn_use_final_proj: true 83 | dec_use_first_self_attn: false 84 | 85 | z_init_method: "zeros" 86 | initial_downsample_type: "conv" 87 | initial_downsample_activation: "leaky" 88 | initial_downsample_scale: 2 89 | initial_downsample_conv_layers: 2 90 | final_upsample_conv_layers: 1 91 | checkpoint_level: 0 92 | 93 | attn_linear_init_mode: "0" 94 | ffn_linear_init_mode: "0" 95 | conv_init_mode: "0" 96 | down_up_linear_init_mode: "0" 97 | norm_init_mode: "0" 98 | -------------------------------------------------------------------------------- /scripts/cuboid_transformer/nbody/README.md: -------------------------------------------------------------------------------- 1 | # Earthformer Training on N-body MNIST 2 | Run the following command to train Earthformer on N-body MNIST dataset. 
3 | Change the configurations in [cfg.yaml](./cfg.yaml) 4 | ```bash 5 | MASTER_ADDR=localhost MASTER_PORT=10001 python train_cuboid_nbody.py --gpus 2 --cfg cfg.yaml --ckpt_name last.ckpt --save tmp_nbody 6 | ``` 7 | Run the tensorboard command to upload experiment records 8 | ```bash 9 | tensorboard dev upload --logdir ./experiments/tmp_nbody/lightning_logs --name 'tmp_nbody' 10 | ``` 11 | -------------------------------------------------------------------------------- /scripts/cuboid_transformer/nbody/cfg.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | dataset_name: "nbody_digits3_len20_size64_r0_train20k" 3 | num_train_samples: 20000 4 | num_val_samples: 1000 5 | num_test_samples: 1000 6 | digit_num: null 7 | img_size: 64 8 | raw_img_size: 128 9 | seq_len: 20 10 | raw_seq_len_multiplier: 5 11 | distractor_num: null 12 | distractor_size: 5 13 | max_velocity_scale: 2.0 14 | initial_velocity_range: [0.0, 2.0] 15 | random_acceleration_range: [0.0, 0.0] 16 | scale_variation_range: [1.0, 1.0] 17 | rotation_angle_range: [-0, 0] 18 | illumination_factor_range: [1.0, 1.0] 19 | period: 5 20 | global_rotation_prob: 0.5 21 | index_range: [0, 40000] 22 | mnist_data_path: null 23 | # N-Body params 24 | nbody_acc_mode: "r0" 25 | nbody_G: 0.05 26 | nbody_softening_distance: 10.0 27 | nbody_mass: null 28 | layout: 29 | in_len: 10 30 | out_len: 10 31 | layout: "NTHWC" 32 | data_seq_len: 20 33 | optim: 34 | total_batch_size: 32 35 | micro_batch_size: 2 36 | seed: 0 37 | method: "adamw" 38 | lr: 0.001 39 | wd: 1.0e-05 40 | gradient_clip_val: 1.0 41 | max_epochs: 100 42 | # scheduler 43 | lr_scheduler_mode: "cosine" 44 | min_lr_ratio: 1.0e-3 45 | warmup_min_lr_ratio: 0.0 46 | warmup_percentage: 0.2 47 | # early stopping 48 | early_stop: true 49 | early_stop_mode: "min" 50 | early_stop_patience: 20 51 | save_top_k: 1 52 | logging: 53 | logging_prefix: "Cuboid_NBody" 54 | monitor_lr: true 55 | monitor_device: false 56 | 
track_grad_norm: -1 57 | use_wandb: false 58 | trainer: 59 | check_val_every_n_epoch: 1 60 | log_step_ratio: 0.001 61 | precision: 32 62 | vis: 63 | train_example_data_idx_list: [0, ] 64 | val_example_data_idx_list: [0, ] 65 | test_example_data_idx_list: [0, ] 66 | eval_example_only: false 67 | model: 68 | input_shape: [10, 64, 64, 1] 69 | target_shape: [10, 64, 64, 1] 70 | base_units: 64 71 | # block_units: null 72 | scale_alpha: 1.0 73 | 74 | enc_depth: [4, 4] 75 | dec_depth: [4, 4] 76 | enc_use_inter_ffn: true 77 | dec_use_inter_ffn: true 78 | dec_hierarchical_pos_embed: false 79 | 80 | downsample: 2 81 | downsample_type: "patch_merge" 82 | upsample_type: "upsample" 83 | 84 | num_global_vectors: 0 85 | use_dec_self_global: false 86 | dec_self_update_global: true 87 | use_dec_cross_global: false 88 | use_global_vector_ffn: false 89 | use_global_self_attn: false 90 | separate_global_qkv: false 91 | global_dim_ratio: 1 92 | 93 | self_pattern: "axial" 94 | cross_self_pattern: "axial" 95 | cross_pattern: "cross_1x1" 96 | dec_cross_last_n_frames: null 97 | 98 | attn_drop: 0.1 99 | proj_drop: 0.1 100 | ffn_drop: 0.1 101 | num_heads: 4 102 | 103 | ffn_activation: "gelu" 104 | gated_ffn: false 105 | norm_layer: "layer_norm" 106 | padding_type: "zeros" 107 | pos_embed_type: "t+hw" 108 | use_relative_pos: true 109 | self_attn_use_final_proj: true 110 | dec_use_first_self_attn: false 111 | 112 | z_init_method: "zeros" 113 | initial_downsample_type: "conv" 114 | initial_downsample_activation: "leaky" 115 | initial_downsample_scale: 2 116 | initial_downsample_conv_layers: 2 117 | final_upsample_conv_layers: 1 118 | checkpoint_level: 0 119 | 120 | attn_linear_init_mode: "0" 121 | ffn_linear_init_mode: "0" 122 | conv_init_mode: "0" 123 | down_up_linear_init_mode: "0" 124 | norm_init_mode: "0" 125 | -------------------------------------------------------------------------------- /scripts/cuboid_transformer/sevir/README.md: 
-------------------------------------------------------------------------------- 1 | # Earthformer Training on SEVIR&SEVIR-LR 2 | ## SEVIR 3 | Run the following command to train Earthformer on SEVIR dataset. 4 | Change the configurations in [cfg_sevir.yaml](./cfg_sevir.yaml) 5 | ```bash 6 | MASTER_ADDR=localhost MASTER_PORT=10001 python train_cuboid_sevir.py --gpus 2 --cfg cfg_sevir.yaml --ckpt_name last.ckpt --save tmp_sevir 7 | ``` 8 | Or run the following command to directly load pretrained checkpoint for test. 9 | ```bash 10 | MASTER_ADDR=localhost MASTER_PORT=10001 python train_cuboid_sevir.py --gpus 2 --pretrained --save tmp_sevir 11 | ``` 12 | Run the tensorboard command to upload experiment records 13 | ```bash 14 | tensorboard dev upload --logdir ./experiments/tmp_sevir/lightning_logs --name 'tmp_sevir' 15 | ``` 16 | ## SEVIR-LR 17 | Run the following command to train Earthformer on SEVIR-LR dataset. 18 | Change the configurations in [cfg_sevirlr.yaml](./cfg_sevirlr.yaml) 19 | ```bash 20 | MASTER_ADDR=localhost MASTER_PORT=10001 python train_cuboid_sevir.py --gpus 2 --cfg cfg_sevirlr.yaml --ckpt_name last.ckpt --save tmp_sevirlr 21 | ``` 22 | Run the tensorboard command to upload experiment records 23 | ```bash 24 | tensorboard dev upload --logdir ./experiments/tmp_sevirlr/lightning_logs --name 'tmp_sevirlr' 25 | ``` 26 | -------------------------------------------------------------------------------- /scripts/cuboid_transformer/sevir/cfg_sevir.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | dataset_name: "sevir" 3 | img_height: 384 4 | img_width: 384 5 | in_len: 13 6 | out_len: 12 7 | seq_len: 25 8 | plot_stride: 2 9 | interval_real_time: 5 10 | sample_mode: "sequent" 11 | stride: 12 12 | layout: "NTHWC" 13 | start_date: null 14 | train_val_split_date: [2019, 1, 1] 15 | train_test_split_date: [2019, 6, 1] 16 | end_date: null 17 | metrics_mode: "0" 18 | metrics_list: ['csi', 'pod', 'sucr', 'bias'] 19 
| threshold_list: [16, 74, 133, 160, 181, 219] 20 | layout: 21 | in_len: 13 22 | out_len: 12 23 | layout: "NTHWC" 24 | optim: 25 | total_batch_size: 32 26 | micro_batch_size: 2 27 | seed: 0 28 | method: "adamw" 29 | lr: 0.001 30 | wd: 0.0 31 | gradient_clip_val: 1.0 32 | max_epochs: 100 33 | # scheduler 34 | lr_scheduler_mode: "cosine" 35 | min_lr_ratio: 1.0e-3 36 | warmup_min_lr_ratio: 0.0 37 | warmup_percentage: 0.2 38 | # early stopping 39 | early_stop: true 40 | early_stop_mode: "min" 41 | early_stop_patience: 20 42 | save_top_k: 1 43 | logging: 44 | logging_prefix: "Cuboid_SEVIR" 45 | monitor_lr: true 46 | monitor_device: false 47 | track_grad_norm: -1 48 | use_wandb: false 49 | trainer: 50 | check_val_every_n_epoch: 1 51 | log_step_ratio: 0.001 52 | precision: 32 53 | vis: 54 | train_example_data_idx_list: [0, ] 55 | val_example_data_idx_list: [80, ] 56 | test_example_data_idx_list: [0, 80, 160, 240, 320, 400] 57 | eval_example_only: false 58 | plot_stride: 2 59 | model: 60 | input_shape: [13, 384, 384, 1] 61 | target_shape: [12, 384, 384, 1] 62 | base_units: 128 63 | # block_units: null 64 | scale_alpha: 1.0 65 | 66 | enc_depth: [1, 1] 67 | dec_depth: [1, 1] 68 | enc_use_inter_ffn: true 69 | dec_use_inter_ffn: true 70 | dec_hierarchical_pos_embed: false 71 | 72 | downsample: 2 73 | downsample_type: "patch_merge" 74 | upsample_type: "upsample" 75 | 76 | num_global_vectors: 8 77 | use_dec_self_global: false 78 | dec_self_update_global: true 79 | use_dec_cross_global: false 80 | use_global_vector_ffn: false 81 | use_global_self_attn: true 82 | separate_global_qkv: true 83 | global_dim_ratio: 1 84 | 85 | self_pattern: "axial" 86 | cross_self_pattern: "axial" 87 | cross_pattern: "cross_1x1" 88 | dec_cross_last_n_frames: null 89 | 90 | attn_drop: 0.1 91 | proj_drop: 0.1 92 | ffn_drop: 0.1 93 | num_heads: 4 94 | 95 | ffn_activation: "gelu" 96 | gated_ffn: false 97 | norm_layer: "layer_norm" 98 | padding_type: "zeros" 99 | pos_embed_type: "t+h+w" 100 | 
use_relative_pos: true 101 | self_attn_use_final_proj: true 102 | dec_use_first_self_attn: false 103 | 104 | z_init_method: "zeros" 105 | checkpoint_level: 0 106 | 107 | initial_downsample_type: "stack_conv" 108 | initial_downsample_activation: "leaky" 109 | initial_downsample_stack_conv_num_layers: 3 110 | initial_downsample_stack_conv_dim_list: [16, 64, 128] 111 | initial_downsample_stack_conv_downscale_list: [3, 2, 2] 112 | initial_downsample_stack_conv_num_conv_list: [2, 2, 2] 113 | 114 | attn_linear_init_mode: "0" 115 | ffn_linear_init_mode: "0" 116 | conv_init_mode: "0" 117 | down_up_linear_init_mode: "0" 118 | norm_init_mode: "0" 119 | -------------------------------------------------------------------------------- /scripts/cuboid_transformer/sevir/cfg_sevirlr.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | dataset_name: "sevir_lr" 3 | img_height: 128 4 | img_width: 128 5 | in_len: 7 6 | out_len: 6 7 | seq_len: 13 8 | plot_stride: 1 9 | interval_real_time: 10 10 | sample_mode: "sequent" 11 | stride: 6 12 | layout: "NTHWC" 13 | start_date: null 14 | train_val_split_date: [2019, 1, 1] 15 | train_test_split_date: [2019, 6, 1] 16 | end_date: null 17 | metrics_mode: "0" 18 | metrics_list: ['csi', 'pod', 'sucr', 'bias'] 19 | threshold_list: [16, 74, 133, 160, 181, 219] 20 | layout: 21 | in_len: 7 22 | out_len: 6 23 | layout: "NTHWC" 24 | optim: 25 | total_batch_size: 32 26 | micro_batch_size: 2 27 | seed: 0 28 | method: "adamw" 29 | lr: 0.001 30 | wd: 0.0 31 | gradient_clip_val: 1.0 32 | max_epochs: 100 33 | # scheduler 34 | lr_scheduler_mode: "cosine" 35 | min_lr_ratio: 1.0e-3 36 | warmup_min_lr_ratio: 0.0 37 | warmup_percentage: 0.2 38 | # early stopping 39 | early_stop: true 40 | early_stop_mode: "min" 41 | early_stop_patience: 20 42 | save_top_k: 1 43 | logging: 44 | logging_prefix: "Cuboid_SEVIR_LR" 45 | monitor_lr: true 46 | monitor_device: false 47 | track_grad_norm: -1 48 | use_wandb: false 49 | 
trainer: 50 | check_val_every_n_epoch: 1 51 | log_step_ratio: 0.001 52 | precision: 32 53 | vis: 54 | train_example_data_idx_list: [0, ] 55 | val_example_data_idx_list: [80, ] 56 | test_example_data_idx_list: [0, 80, 160, 240, 320, 400] 57 | eval_example_only: false 58 | plot_stride: 2 59 | model: 60 | input_shape: [7, 128, 128, 1] 61 | target_shape: [6, 128, 128, 1] 62 | base_units: 64 63 | # block_units: null 64 | scale_alpha: 1.0 65 | 66 | enc_depth: [1, 1] 67 | dec_depth: [1, 1] 68 | enc_use_inter_ffn: true 69 | dec_use_inter_ffn: true 70 | dec_hierarchical_pos_embed: false 71 | 72 | downsample: 2 73 | downsample_type: "patch_merge" 74 | upsample_type: "upsample" 75 | 76 | num_global_vectors: 0 77 | use_dec_self_global: false 78 | dec_self_update_global: true 79 | use_dec_cross_global: false 80 | use_global_vector_ffn: false 81 | use_global_self_attn: false 82 | separate_global_qkv: false 83 | global_dim_ratio: 1 84 | 85 | self_pattern: "axial" 86 | cross_self_pattern: "axial" 87 | cross_pattern: "cross_1x1" 88 | dec_cross_last_n_frames: null 89 | 90 | attn_drop: 0.1 91 | proj_drop: 0.1 92 | ffn_drop: 0.1 93 | num_heads: 4 94 | 95 | ffn_activation: "gelu" 96 | gated_ffn: false 97 | norm_layer: "layer_norm" 98 | padding_type: "zeros" 99 | pos_embed_type: "t+h+w" 100 | use_relative_pos: true 101 | self_attn_use_final_proj: true 102 | dec_use_first_self_attn: false 103 | 104 | z_init_method: "zeros" 105 | checkpoint_level: 0 106 | 107 | initial_downsample_type: "stack_conv" 108 | initial_downsample_activation: "leaky" 109 | initial_downsample_stack_conv_num_layers: 3 110 | initial_downsample_stack_conv_dim_list: [4, 16, 64] 111 | initial_downsample_stack_conv_downscale_list: [1, 2, 2] 112 | initial_downsample_stack_conv_num_conv_list: [2, 2, 2] 113 | 114 | attn_linear_init_mode: "0" 115 | ffn_linear_init_mode: "0" 116 | conv_init_mode: "0" 117 | down_up_linear_init_mode: "0" 118 | norm_init_mode: "0" 119 | 
-------------------------------------------------------------------------------- /scripts/cuboid_transformer/sevir/earthformer_sevir_v1.yaml: -------------------------------------------------------------------------------- 1 | # Configurations for pretrained checkpoint. Please do not modify it. 2 | dataset: 3 | dataset_name: "sevir" 4 | img_height: 384 5 | img_width: 384 6 | in_len: 13 7 | out_len: 12 8 | seq_len: 25 9 | plot_stride: 2 10 | interval_real_time: 5 11 | sample_mode: "sequent" 12 | stride: 12 13 | layout: "NTHWC" 14 | start_date: null 15 | train_val_split_date: [2019, 1, 1] 16 | train_test_split_date: [2019, 6, 1] 17 | end_date: null 18 | metrics_mode: "0" 19 | metrics_list: ['csi', 'pod', 'sucr', 'bias'] 20 | threshold_list: [16, 74, 133, 160, 181, 219] 21 | layout: 22 | in_len: 13 23 | out_len: 12 24 | layout: "NTHWC" 25 | optim: 26 | total_batch_size: 32 27 | micro_batch_size: 2 28 | seed: 0 29 | method: "adamw" 30 | lr: 0.001 31 | wd: 0.0 32 | gradient_clip_val: 1.0 33 | max_epochs: 100 34 | # scheduler 35 | lr_scheduler_mode: "cosine" 36 | min_lr_ratio: 1.0e-3 37 | warmup_min_lr_ratio: 0.0 38 | warmup_percentage: 0.2 39 | # early stopping 40 | early_stop: true 41 | early_stop_mode: "min" 42 | early_stop_patience: 20 43 | save_top_k: 1 44 | logging: 45 | logging_prefix: "Cuboid_SEVIR" 46 | monitor_lr: true 47 | monitor_device: false 48 | track_grad_norm: -1 49 | use_wandb: false 50 | trainer: 51 | check_val_every_n_epoch: 1 52 | log_step_ratio: 0.001 53 | precision: 32 54 | vis: 55 | train_example_data_idx_list: [0, ] 56 | val_example_data_idx_list: [80, ] 57 | test_example_data_idx_list: [0, 80, 160, 240, 320, 400] 58 | eval_example_only: false 59 | plot_stride: 2 60 | model: 61 | input_shape: [13, 384, 384, 1] 62 | target_shape: [12, 384, 384, 1] 63 | base_units: 128 64 | # block_units: null 65 | scale_alpha: 1.0 66 | 67 | enc_depth: [1, 1] 68 | dec_depth: [1, 1] 69 | enc_use_inter_ffn: true 70 | dec_use_inter_ffn: true 71 | 
dec_hierarchical_pos_embed: false 72 | 73 | downsample: 2 74 | downsample_type: "patch_merge" 75 | upsample_type: "upsample" 76 | 77 | num_global_vectors: 8 78 | use_dec_self_global: false 79 | dec_self_update_global: true 80 | use_dec_cross_global: false 81 | use_global_vector_ffn: false 82 | use_global_self_attn: true 83 | separate_global_qkv: true 84 | global_dim_ratio: 1 85 | 86 | self_pattern: "axial" 87 | cross_self_pattern: "axial" 88 | cross_pattern: "cross_1x1" 89 | dec_cross_last_n_frames: null 90 | 91 | attn_drop: 0.1 92 | proj_drop: 0.1 93 | ffn_drop: 0.1 94 | num_heads: 4 95 | 96 | ffn_activation: "gelu" 97 | gated_ffn: false 98 | norm_layer: "layer_norm" 99 | padding_type: "zeros" 100 | pos_embed_type: "t+h+w" 101 | use_relative_pos: true 102 | self_attn_use_final_proj: true 103 | dec_use_first_self_attn: false 104 | 105 | z_init_method: "zeros" 106 | checkpoint_level: 0 107 | 108 | initial_downsample_type: "stack_conv" 109 | initial_downsample_activation: "leaky" 110 | initial_downsample_stack_conv_num_layers: 3 111 | initial_downsample_stack_conv_dim_list: [16, 64, 128] 112 | initial_downsample_stack_conv_downscale_list: [3, 2, 2] 113 | initial_downsample_stack_conv_num_conv_list: [2, 2, 2] 114 | 115 | attn_linear_init_mode: "0" 116 | ffn_linear_init_mode: "0" 117 | conv_init_mode: "0" 118 | down_up_linear_init_mode: "0" 119 | norm_init_mode: "0" 120 | -------------------------------------------------------------------------------- /scripts/datasets/nbody/README.md: -------------------------------------------------------------------------------- 1 | # N-body MNIST 2 | The underlying dynamics in the N-body MNIST dataset is governed by the Newton's law of universal gravitation: 3 | 4 | $\frac{d^2\boldsymbol{x}\_{i}}{dt^2} = - \sum\_{j\neq i}\frac{G m\_j (\boldsymbol{x}\_{i}-\boldsymbol{x}\_{j})}{(\|\boldsymbol{x}\_i-\boldsymbol{x}\_j\|+d\_{\text{soft}})^r}$ 5 | 6 | where $\boldsymbol{x}\_{i}$ is the spatial coordinates of the $i$-th digit, $G$ is the 
gravitational constant, $m\_j$ is the mass of the $j$-th digit, $r$ is a constant representing the power scale in the gravitational law, $d\_{\text{soft}}$ is a small softening distance that ensures numerical stability. 7 | 8 | To download the N-body MNIST dataset used in our paper from AWS S3 run: 9 | ```bash 10 | cd ROOT_DIR/earth-forecasting-transformer 11 | python ./scripts/datasets/nbody/download_nbody_paper.py 12 | ``` 13 | 14 | Alternatively, run the following commands to generate N-body MNIST dataset. 15 | ```bash 16 | cd ROOT_DIR/earth-forecasting-transformer 17 | python ./scripts/datasets/nbody/generate_nbody_dataset.py --cfg ./scripts/datasets/nbody/cfg.yaml 18 | ``` 19 | -------------------------------------------------------------------------------- /scripts/datasets/nbody/cfg.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | dataset_name: "nbody_digits3_len20_size64_r0_train20k" 3 | num_train_samples: 20000 4 | num_val_samples: 1000 5 | num_test_samples: 1000 6 | digit_num: null 7 | img_size: 64 8 | raw_img_size: 128 9 | seq_len: 20 10 | raw_seq_len_multiplier: 5 11 | distractor_num: null 12 | distractor_size: 5 13 | max_velocity_scale: 2.0 14 | initial_velocity_range: [0.0, 2.0] 15 | random_acceleration_range: [0.0, 0.0] 16 | scale_variation_range: [1.0, 1.0] 17 | rotation_angle_range: [-0, 0] 18 | illumination_factor_range: [1.0, 1.0] 19 | period: 5 20 | global_rotation_prob: 0.5 21 | index_range: [0, 40000] 22 | mnist_data_path: null 23 | # N-Body params 24 | nbody_acc_mode: "r0" 25 | nbody_G: 0.05 26 | nbody_softening_distance: 10.0 27 | nbody_mass: null 28 | -------------------------------------------------------------------------------- /scripts/datasets/nbody/download_nbody_paper.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import os 3 | 4 | 5 | nbody_paper_name = "nbody_digits3_len20_size64_r0_train20k" 6 | nbody_paper_zip_name 
nbody_paper_zip_name = "nbody_paper.zip"

def s3_download_nbody_paper(save_dir=None, exist_ok=False):
    """Download and extract the paper's pre-generated N-body MNIST dataset
    from the public `earthformer` S3 bucket.

    Parameters
    ----------
    save_dir
        Directory to place the dataset under. Defaults to
        `<cfg.datasets_dir>/nbody`.
    exist_ok
        If False and the dataset directory already exists, only a warning is
        issued and nothing is downloaded.
    """
    if save_dir is None:
        # Deferred import so the module can be used without the full package config.
        from earthformer.config import cfg
        save_dir = os.path.join(cfg.datasets_dir, "nbody")
    dataset_dir = os.path.join(save_dir, nbody_paper_name)
    if os.path.exists(dataset_dir) and not exist_ok:
        warnings.warn(f"N-body dataset {dataset_dir} already exists!")
        return
    os.makedirs(save_dir, exist_ok=True)
    # Anonymous (no-sign-request) download of the zipped dataset, then unzip in place.
    os.system(f"aws s3 cp --no-sign-request s3://earthformer/nbody/{nbody_paper_zip_name} "
              f"{save_dir}")
    os.system(f"unzip {os.path.join(save_dir, nbody_paper_zip_name)} "
              f"-d {save_dir}")

if __name__ == "__main__":
    s3_download_nbody_paper(exist_ok=False)
def generate_nbody_dataset(save_dir=None, oc_file_path=None):
    """Generate an N-body MNIST dataset from a YAML config and save a preview GIF.

    Parameters
    ----------
    save_dir
        Root directory under which a per-dataset directory is created.
        Defaults to `<cfg.datasets_dir>/nbody`.
    oc_file_path
        Path to the OmegaConf YAML file whose `dataset` section parameterizes
        generation. Defaults to `cfg.yaml` located next to this script.

    Raises
    ------
    ValueError
        If the target dataset directory already exists (no overwrite).
    """
    if oc_file_path is None:
        # Default to the cfg.yaml that ships alongside this script.
        oc_file_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "cfg.yaml")
    # Only the `dataset` section of the config is used; convert to a plain dict.
    dataset_oc = OmegaConf.to_object(OmegaConf.load(open(oc_file_path, "r")).dataset)
    if save_dir is None:
        from earthformer.config import cfg
        save_dir = os.path.join(cfg.datasets_dir, "nbody")
    data_dir = os.path.join(save_dir, dataset_oc["dataset_name"])
    if os.path.exists(data_dir):
        # Refuse to overwrite an existing dataset directory.
        raise ValueError(f"data_dir {data_dir} already exists!")
    else:
        os.makedirs(data_dir)
    # Keep a copy of the generating config next to the data for reproducibility.
    copyfile(oc_file_path, os.path.join(data_dir, "nbody_dataset_cfg.yaml"))
    # The datamodule performs the actual generation in prepare_data()/setup().
    dm = NBodyMovingMNISTLightningDataModule(
        data_dir=data_dir,
        force_regenerate=False,
        num_train_samples=dataset_oc["num_train_samples"],
        num_val_samples=dataset_oc["num_val_samples"],
        num_test_samples=dataset_oc["num_test_samples"],
        digit_num=dataset_oc["digit_num"],
        img_size=dataset_oc["img_size"],
        raw_img_size=dataset_oc["raw_img_size"],
        seq_len=dataset_oc["seq_len"],
        raw_seq_len_multiplier=dataset_oc["raw_seq_len_multiplier"],
        distractor_num=dataset_oc["distractor_num"],
        distractor_size=dataset_oc["distractor_size"],
        max_velocity_scale=dataset_oc["max_velocity_scale"],
        initial_velocity_range=dataset_oc["initial_velocity_range"],
        random_acceleration_range=dataset_oc["random_acceleration_range"],
        scale_variation_range=dataset_oc["scale_variation_range"],
        rotation_angle_range=dataset_oc["rotation_angle_range"],
        illumination_factor_range=dataset_oc["illumination_factor_range"],
        period=dataset_oc["period"],
        global_rotation_prob=dataset_oc["global_rotation_prob"],
        index_range=dataset_oc["index_range"],
        mnist_data_path=dataset_oc["mnist_data_path"],
        # N-Body params
        nbody_acc_mode=dataset_oc["nbody_acc_mode"],
        nbody_G=dataset_oc["nbody_G"],
        nbody_softening_distance=dataset_oc["nbody_softening_distance"],
        nbody_mass=dataset_oc["nbody_mass"],
        # datamodule_only
        batch_size=1,
        num_workers=8, )
    dm.prepare_data()
    dm.setup()
    # First training sequence; the [:, :, :, 0] slice suggests a THWC layout
    # with a single channel — NOTE(review): confirm against the datamodule.
    seq = dm.nbody_train[0]
    save_gif(single_seq=seq[:, :, :, 0].numpy().astype(np.float32),
             fname=os.path.join(data_dir, "nbody_example.gif"))

def get_parser():
    # CLI: only the config-file path is configurable.
    parser = argparse.ArgumentParser()
    parser.add_argument('--cfg', default=None, type=str)
    return parser

def main():
    """CLI entry point: parse `--cfg` and run dataset generation."""
    parser = get_parser()
    args = parser.parse_args()
    generate_nbody_dataset(oc_file_path=args.cfg)

if __name__ == "__main__":
    main()
def read(*names, **kwargs):
    """Return the text content of a file located relative to this setup.py.

    Path components in `*names` are joined onto this file's directory;
    `encoding` may be overridden via kwargs (default: utf8).
    """
    path = os.path.join(os.path.dirname(__file__), *names)
    encoding = kwargs.get("encoding", "utf8")
    with io.open(path, encoding=encoding) as fp:
        return fp.read()


def find_version(*file_paths):
    """Extract the `__version__` string from the file at `*file_paths`.

    Raises
    ------
    RuntimeError
        If no `__version__ = '...'` assignment is found.
    """
    contents = read(*file_paths)
    match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", contents, re.M)
    if match is None:
        raise RuntimeError("Unable to find version string.")
    return match.group(1)
class Persistence(nn.Module):
    """Persistence baseline: predicts the future by repeating the last
    observed frame of the input sequence.

    Parameters
    ----------
    layout
        Axis layout string of the sequences, e.g. "NTHWC"; it must contain
        "T" marking the temporal axis and must match the rank of the inputs.
    """

    def __init__(self, layout="NTHWC"):
        super(Persistence, self).__init__()
        self.layout = layout
        # Position of the temporal axis within the layout string.
        self.t_axis = self.layout.find("T")
        # Index expression that keeps only the final time step of a sequence.
        selector = [slice(None, None), ] * len(self.layout)
        selector[self.t_axis] = slice(-1, None)
        self.in_seq_slice = selector

    def forward(self, in_seq, out_seq):
        """Repeat the last input frame to the target length.

        Parameters
        ----------
        in_seq
            Observed sequence tensor laid out per `self.layout`.
        out_seq
            Ground-truth future sequence (same layout); used for its length
            and for the MSE loss.

        Returns
        -------
        (prediction, loss)
            Prediction tensor shaped like `out_seq`, and its MSE vs `out_seq`.
        """
        target_len = out_seq.shape[self.t_axis]
        last_frame = in_seq[self.in_seq_slice]
        prediction = torch.repeat_interleave(last_frame,
                                             repeats=target_len,
                                             dim=self.t_axis)
        return prediction, F.mse_loss(prediction, out_seq)
def full_attention(input_shape):
    """Single cuboid spanning the whole (T, H, W) volume — plain full attention.

    Parameters
    ----------
    input_shape
        (T, H, W, C); the channel dimension is ignored.

    Returns
    -------
    cuboid_size, strategy, shift_size
        One local, unshifted cuboid covering everything.
    """
    T, H, W, _ = input_shape
    return [(T, H, W)], [('l', 'l', 'l')], [(0, 0, 0)]
def self_video_swin(input_shape, P=2, M=4):
    """Window + shifted-window attention following Video Swin Transformer
    (https://arxiv.org/pdf/2106.13230.pdf).

    Parameters
    ----------
    input_shape
        (T, H, W, C); the channel dimension is ignored.
    P
        Temporal window size; clipped to T.
    M
        Spatial window size; clipped to min(H, W).

    Returns
    -------
    cuboid_size, strategy, shift_size
        Two identical local windows: the first unshifted, the second shifted
        by half the window along every axis.
    """
    T, H, W, _ = input_shape
    p = min(P, T)
    m = min(M, H, W)
    window = (p, m, m)
    cuboid_size = [window, window]
    strategy = [('l', 'l', 'l')] * 2
    shift_size = [(0, 0, 0), (p // 2, m // 2, m // 2)]
    return cuboid_size, strategy, shift_size
def self_axial_space_dilate_K(input_shape, K=2):
    """Axial attention with paired dilated ('d') and local ('l') passes per
    spatial axis, after an initial temporal-only pass.

    Parameters
    ----------
    input_shape
        (T, H, W, C)
    K
        Dilation factor; clipped to min(H, W).

    Returns
    -------
    cuboid_size
    strategy
    shift_size
    """
    t_len, height, width = input_shape[0], input_shape[1], input_shape[2]
    k = min(K, height, width)
    cuboid_size = [(t_len, 1, 1)]
    cuboid_size += [(1, height // k, 1)] * 2
    cuboid_size += [(1, 1, width // k)] * 2
    local, dilated = ('l', 'l', 'l'), ('d', 'd', 'd')
    strategy = [local, dilated, local, dilated, local]
    shift_size = [(0, 0, 0)] * 5
    return cuboid_size, strategy, shift_size


def cross_KxK(mem_shape, K):
    """Cross attention over a single local K x K spatial window.

    Parameters
    ----------
    mem_shape
        (T_mem, H, W, C) of the memory.
    K
        Window size; clipped to min(H, W).

    Returns
    -------
    cuboid_hw
    shift_hw
    strategy
    n_temporal
    """
    height, width = mem_shape[1], mem_shape[2]
    k = min(K, height, width)
    return [(k, k)], [(0, 0)], [('l', 'l', 'l')], [1]
def round_to(dat, c):
    """Round ``dat`` up to the nearest multiple of ``c``.

    E.g. ``round_to(10, 4) == 12`` and ``round_to(8, 4) == 8``.

    BUGFIX: the previous expression ``dat + (dat - dat % c) % c`` was a no-op
    for integer inputs — ``dat - dat % c`` is already a multiple of ``c``, so
    the outer ``% c`` always evaluated to 0 and ``dat`` was returned unchanged.
    The intended round-up is ``dat + (c - dat % c) % c``.
    """
    return dat + (c - dat % c) % c
class RMSNorm(nn.Module):
    def __init__(self, d, p=-1., eps=1e-8, bias=False):
        """Root Mean Square Layer Normalization proposed in "[NeurIPS2019] Root Mean Square Layer Normalization"

        Parameters
        ----------
        d
            model size (size of the normalized last dimension)
        p
            partial RMSNorm: fraction of the last dim used to estimate the RMS,
            valid value [0, 1], default -1.0 (disabled; full dim is used)
        eps
            epsilon value added to the RMS for numerical stability, default 1e-8
        bias
            whether use bias term for RMSNorm, disabled by
            default because RMSNorm doesn't enforce re-centering invariance.
        """
        super(RMSNorm, self).__init__()

        self.eps = eps
        self.d = d
        self.p = p
        self.bias = bias

        # Assigning an nn.Parameter attribute registers it automatically, so the
        # explicit register_parameter() calls of the original were redundant and
        # have been removed (same registered parameters, same names).
        self.scale = nn.Parameter(torch.ones(d))
        if self.bias:
            self.offset = nn.Parameter(torch.zeros(d))

    def forward(self, x):
        # Partial mode (0 <= p <= 1) estimates the RMS from only the first
        # int(d * p) elements of the last dim; otherwise the full dim is used.
        if 0. <= self.p <= 1.:
            partial_size = int(self.d * self.p)
            partial_x, _ = torch.split(x, [partial_size, self.d - partial_size], dim=-1)

            norm_x = partial_x.norm(2, dim=-1, keepdim=True)
            d_x = partial_size
        else:
            norm_x = x.norm(2, dim=-1, keepdim=True)
            d_x = self.d

        rms_x = norm_x * d_x ** (-1. / 2)
        x_normed = x / (rms_x + self.eps)

        if self.bias:
            return self.scale * x_normed + self.offset

        return self.scale * x_normed

def get_norm_layer(normalization: str = 'layer_norm',
                   axis: int = -1,
                   epsilon: float = 1e-5,
                   in_channels: int = 0, **kwargs):
    """Get the normalization layer based on the provided type

    Parameters
    ----------
    normalization
        The type of the layer normalization from ['layer_norm', 'rms_norm'];
        None returns an identity layer.
    axis
        The axis to normalize over; only -1 (the last axis) is supported.
    epsilon
        The epsilon of the normalization layer
    in_channels
        Input channel

    Returns
    -------
    norm_layer
        The layer normalization layer

    Raises
    ------
    NotImplementedError
        If ``normalization`` is an unknown string or not a str/None.
    """
    if isinstance(normalization, str):
        if normalization == 'layer_norm':
            assert in_channels > 0
            assert axis == -1
            norm_layer = nn.LayerNorm(normalized_shape=in_channels, eps=epsilon, **kwargs)
        elif normalization == 'rms_norm':
            assert axis == -1
            norm_layer = RMSNorm(d=in_channels, eps=epsilon, **kwargs)
        else:
            raise NotImplementedError('normalization={} is not supported'.format(normalization))
        return norm_layer
    elif normalization is None:
        return nn.Identity()
    else:
        raise NotImplementedError('The type of normalization must be str')
Shape will be (B, T + pad_t, H + pad_h, W + pad_w, C) 159 | """ 160 | if pad_t == 0 and pad_h == 0 and pad_w == 0: 161 | return x 162 | 163 | assert padding_type in ['zeros', 'ignore', 'nearest'] 164 | B, T, H, W, C = x.shape 165 | 166 | if padding_type == 'nearest': 167 | return F.interpolate(x.permute(0, 4, 1, 2, 3), size=(T + pad_t, H + pad_h, W + pad_w)).permute(0, 2, 3, 4, 1) 168 | else: 169 | if t_pad_left: 170 | return F.pad(x, (0, 0, 0, pad_w, 0, pad_h, pad_t, 0)) 171 | else: 172 | return F.pad(x, (0, 0, 0, pad_w, 0, pad_h, 0, pad_t)) 173 | 174 | 175 | def _generalize_unpadding(x, pad_t, pad_h, pad_w, padding_type): 176 | assert padding_type in['zeros', 'ignore', 'nearest'] 177 | B, T, H, W, C = x.shape 178 | if pad_t == 0 and pad_h == 0 and pad_w == 0: 179 | return x 180 | 181 | if padding_type == 'nearest': 182 | return F.interpolate(x.permute(0, 4, 1, 2, 3), size=(T - pad_t, H - pad_h, W - pad_w)).permute(0, 2, 3, 4, 1) 183 | else: 184 | return x[:, :(T - pad_t), :(H - pad_h), :(W - pad_w), :].contiguous() 185 | 186 | def apply_initialization(m, 187 | linear_mode="0", 188 | conv_mode="0", 189 | norm_mode="0", 190 | embed_mode="0"): 191 | if isinstance(m, nn.Linear): 192 | 193 | if linear_mode in ("0", ): 194 | nn.init.kaiming_normal_(m.weight, 195 | mode='fan_in', nonlinearity="linear") 196 | elif linear_mode in ("1", ): 197 | nn.init.kaiming_normal_(m.weight, 198 | a=0.1, 199 | mode='fan_out', 200 | nonlinearity="leaky_relu") 201 | else: 202 | raise NotImplementedError 203 | if hasattr(m, 'bias') and m.bias is not None: 204 | nn.init.zeros_(m.bias) 205 | elif isinstance(m, (nn.Conv2d, nn.Conv3d, nn.ConvTranspose2d, nn.ConvTranspose3d)): 206 | if conv_mode in ("0", ): 207 | nn.init.kaiming_normal_(m.weight, 208 | a=0.1, 209 | mode='fan_out', 210 | nonlinearity="leaky_relu") 211 | else: 212 | raise NotImplementedError 213 | if hasattr(m, 'bias') and m.bias is not None: 214 | nn.init.zeros_(m.bias) 215 | elif isinstance(m, nn.LayerNorm): 216 | if norm_mode 
class TransformsFixRotation:
    r"""Rotation transform that picks uniformly at random from a fixed set of angles.

    Example: `rotation_transform = MyRotationTransform(angles=[-30, -15, 0, 15, 30])`
    """

    def __init__(self, angles):
        # A scalar angle is promoted to a one-element list so that
        # random.choice always operates on a sequence.
        self.angles = angles if isinstance(angles, Sequence) else [angles, ]

    def __call__(self, x):
        return TF.rotate(x, random.choice(self.angles))
37 | self.add_state("num_MAD", 38 | default=torch.tensor(0), 39 | dist_reduce_fx="sum") 40 | self.add_state("num_OLS", 41 | default=torch.tensor(0), 42 | dist_reduce_fx="sum") 43 | self.add_state("num_EMD", 44 | default=torch.tensor(0), 45 | dist_reduce_fx="sum") 46 | self.add_state("num_SSIM", 47 | default=torch.tensor(0), 48 | dist_reduce_fx="sum") 49 | 50 | @property 51 | def einops_default_layout(self): 52 | if not hasattr(self, "_einops_default_layout"): 53 | self._einops_default_layout = " ".join(self.default_layout) 54 | return self._einops_default_layout 55 | 56 | @property 57 | def einops_layout(self): 58 | if not hasattr(self, "_einops_layout"): 59 | self._einops_layout = " ".join(self.layout) 60 | return self._einops_layout 61 | 62 | def update(self, pred: torch.Tensor, target: torch.Tensor, mask: torch.Tensor = None): 63 | r""" 64 | 65 | Parameters 66 | ---------- 67 | pred, target: torch.Tensor 68 | With the first dim as batch dim, and 4 channels (RGB and infrared) 69 | mask: torch.Tensor 70 | With the first dim as batch dim, and 1 channel 71 | """ 72 | pred_np = rearrange(pred.detach(), f"{self.einops_layout} -> {self.einops_default_layout}").cpu().numpy() 73 | target_np = rearrange(target.detach(), f"{self.einops_layout} -> {self.einops_default_layout}").cpu().numpy() 74 | # layout = "NHWCT" 75 | if mask is None: 76 | mask_np = np.ones_like(target_np) 77 | else: 78 | mask_np = torch.repeat_interleave( 79 | rearrange(1 - mask.detach(), f"{self.einops_layout} -> {self.einops_default_layout}"), 80 | repeats=self.channels, dim=self.default_channel_axis).cpu().numpy() 81 | for preds, targs, masks in zip(pred_np, target_np, mask_np): 82 | # Code is adapted from `load_file()` in ./earthnet_toolkit/parallel_score.py 83 | preds[preds < 0] = 0 84 | preds[preds > 1] = 1 85 | 86 | targs[np.isnan(targs)] = 0 87 | targs[targs > 1] = 1 88 | targs[targs < 0] = 0 89 | 90 | ndvi_preds = ((preds[:, :, 3, :] - preds[:, :, 2, :]) / (preds[:, :, 3, :] + preds[:, :, 2, :] 
+ 1e-6))[:, 91 | :, np.newaxis, :] 92 | ndvi_targs = ((targs[:, :, 3, :] - targs[:, :, 2, :]) / (targs[:, :, 3, :] + targs[:, :, 2, :] + 1e-6))[:, 93 | :, np.newaxis, :] 94 | ndvi_masks = masks[:, :, 0, :][:, :, np.newaxis, :] 95 | # Code is adapted from `get_scores()` in ./earthnet_toolkit/parallel_score.py 96 | debug_info = {} 97 | mad, debug_info["MAD"] = EN_CubeCalculator.MAD(preds, targs, masks) 98 | ols, debug_info["OLS"] = EN_CubeCalculator.OLS(ndvi_preds, ndvi_targs, ndvi_masks) 99 | emd, debug_info["EMD"] = EN_CubeCalculator.EMD(ndvi_preds, ndvi_targs, ndvi_masks) 100 | ssim, debug_info["SSIM"] = EN_CubeCalculator.SSIM(preds, targs, masks) 101 | # does not count if NaN 102 | if mad is not None and not np.isnan(mad): 103 | self.MAD += mad 104 | self.num_MAD += 1 105 | if ols is not None and not np.isnan(ols): 106 | self.OLS += ols 107 | self.num_OLS += 1 108 | if emd is not None and not np.isnan(emd): 109 | self.EMD += emd 110 | self.num_EMD += 1 111 | if ssim is not None and not np.isnan(ssim): 112 | self.SSIM += ssim 113 | self.num_SSIM += 1 114 | 115 | def compute(self): 116 | MAD_mean = (self.MAD / (self.num_MAD + self.eps)).cpu().item() 117 | OLS_mean = (self.OLS / (self.num_OLS + self.eps)).cpu().item() 118 | EMD_mean = (self.EMD / (self.num_EMD + self.eps)).cpu().item() 119 | SSIM_mean = (self.SSIM / (self.num_SSIM + self.eps)).cpu().item() 120 | ENS = hmean([MAD_mean, OLS_mean, EMD_mean, SSIM_mean]) 121 | return { 122 | "MAD": MAD_mean, 123 | "OLS": OLS_mean, 124 | "EMD": EMD_mean, 125 | "SSIM":SSIM_mean, 126 | "EarthNetScore": ENS, 127 | } 128 | 129 | class EarthNet2021ScoreUpdateWithoutCompute(MetricsUpdateWithoutCompute): 130 | 131 | default_layout = "NHWCT" 132 | default_channel_axis = 3 133 | channels = 4 134 | 135 | def __init__(self, 136 | layout: str = "NTHWC", 137 | eps: float = 1e-4, 138 | dist_sync_on_step: bool = False, ): 139 | super(EarthNet2021ScoreUpdateWithoutCompute, self).__init__(dist_sync_on_step=dist_sync_on_step) 140 | 
self.layout = layout 141 | self.eps = eps 142 | 143 | self.add_state("MAD", 144 | default=torch.tensor(0.0), 145 | dist_reduce_fx="sum") 146 | self.add_state("OLS", 147 | default=torch.tensor(0.0), 148 | dist_reduce_fx="sum") 149 | self.add_state("EMD", 150 | default=torch.tensor(0.0), 151 | dist_reduce_fx="sum") 152 | self.add_state("SSIM", 153 | default=torch.tensor(0.0), 154 | dist_reduce_fx="sum") 155 | # does not count if NaN 156 | self.add_state("num_MAD", 157 | default=torch.tensor(0), 158 | dist_reduce_fx="sum") 159 | self.add_state("num_OLS", 160 | default=torch.tensor(0), 161 | dist_reduce_fx="sum") 162 | self.add_state("num_EMD", 163 | default=torch.tensor(0), 164 | dist_reduce_fx="sum") 165 | self.add_state("num_SSIM", 166 | default=torch.tensor(0), 167 | dist_reduce_fx="sum") 168 | 169 | @property 170 | def einops_default_layout(self): 171 | if not hasattr(self, "_einops_default_layout"): 172 | self._einops_default_layout = " ".join(self.default_layout) 173 | return self._einops_default_layout 174 | 175 | @property 176 | def einops_layout(self): 177 | if not hasattr(self, "_einops_layout"): 178 | self._einops_layout = " ".join(self.layout) 179 | return self._einops_layout 180 | 181 | def update(self, pred: torch.Tensor, target: torch.Tensor, mask: torch.Tensor = None): 182 | r""" 183 | 184 | Parameters 185 | ---------- 186 | pred, target: torch.Tensor 187 | With the first dim as batch dim, and 4 channels (RGB and infrared) 188 | mask: torch.Tensor 189 | With the first dim as batch dim, and 1 channel 190 | """ 191 | pred_np = rearrange(pred.detach(), f"{self.einops_layout} -> {self.einops_default_layout}").cpu().numpy() 192 | target_np = rearrange(target.detach(), f"{self.einops_layout} -> {self.einops_default_layout}").cpu().numpy() 193 | # layout = "NHWCT" 194 | if mask is None: 195 | mask_np = np.ones_like(target_np) 196 | else: 197 | mask_np = torch.repeat_interleave( 198 | rearrange(1 - mask.detach(), f"{self.einops_layout} -> 
{self.einops_default_layout}"), 199 | repeats=self.channels, dim=self.default_channel_axis).cpu().numpy() 200 | for preds, targs, masks in zip(pred_np, target_np, mask_np): 201 | # Code is adapted from `load_file()` in ./earthnet_toolkit/parallel_score.py 202 | preds[preds < 0] = 0 203 | preds[preds > 1] = 1 204 | 205 | targs[np.isnan(targs)] = 0 206 | targs[targs > 1] = 1 207 | targs[targs < 0] = 0 208 | 209 | ndvi_preds = ((preds[:, :, 3, :] - preds[:, :, 2, :]) / (preds[:, :, 3, :] + preds[:, :, 2, :] + 1e-6))[:, 210 | :, np.newaxis, :] 211 | ndvi_targs = ((targs[:, :, 3, :] - targs[:, :, 2, :]) / (targs[:, :, 3, :] + targs[:, :, 2, :] + 1e-6))[:, 212 | :, np.newaxis, :] 213 | ndvi_masks = masks[:, :, 0, :][:, :, np.newaxis, :] 214 | # Code is adapted from `get_scores()` in ./earthnet_toolkit/parallel_score.py 215 | debug_info = {} 216 | mad, debug_info["MAD"] = EN_CubeCalculator.MAD(preds, targs, masks) 217 | ols, debug_info["OLS"] = EN_CubeCalculator.OLS(ndvi_preds, ndvi_targs, ndvi_masks) 218 | emd, debug_info["EMD"] = EN_CubeCalculator.EMD(ndvi_preds, ndvi_targs, ndvi_masks) 219 | ssim, debug_info["SSIM"] = EN_CubeCalculator.SSIM(preds, targs, masks) 220 | # does not count if NaN 221 | if mad is not None and not np.isnan(mad): 222 | self.MAD += mad 223 | self.num_MAD += 1 224 | if ols is not None and not np.isnan(ols): 225 | self.OLS += ols 226 | self.num_OLS += 1 227 | if emd is not None and not np.isnan(emd): 228 | self.EMD += emd 229 | self.num_EMD += 1 230 | if ssim is not None and not np.isnan(ssim): 231 | self.SSIM += ssim 232 | self.num_SSIM += 1 233 | 234 | def compute(self): 235 | MAD_mean = (self.MAD / (self.num_MAD + self.eps)).cpu().item() 236 | OLS_mean = (self.OLS / (self.num_OLS + self.eps)).cpu().item() 237 | EMD_mean = (self.EMD / (self.num_EMD + self.eps)).cpu().item() 238 | SSIM_mean = (self.SSIM / (self.num_SSIM + self.eps)).cpu().item() 239 | ENS = hmean([MAD_mean, OLS_mean, EMD_mean, SSIM_mean]) 240 | return { 241 | "MAD": MAD_mean, 242 
| "OLS": OLS_mean, 243 | "EMD": EMD_mean, 244 | "SSIM":SSIM_mean, 245 | "EarthNetScore": ENS, 246 | } 247 | -------------------------------------------------------------------------------- /src/earthformer/datasets/earthnet/earthnet_toolkit/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/earth-forecasting-transformer/7732b03bdb366110563516c3502315deab4c2026/src/earthformer/datasets/earthnet/earthnet_toolkit/__init__.py -------------------------------------------------------------------------------- /src/earthformer/datasets/earthnet/earthnet_toolkit/coords.py: -------------------------------------------------------------------------------- 1 | """Code is adapted from https://github.com/earthnet2021/earthnet-toolkit.""" 2 | import warnings 3 | import os 4 | import json 5 | from pyproj import Transformer 6 | from .coords_dict import COORDS 7 | 8 | 9 | EXTREME_TILES = ["32UMC", "32UNC", "32UPC", "32UQC"] 10 | 11 | def get_coords_from_cube(cubename: str, return_meso: bool = False, ignore_warning = False): 12 | """ 13 | 14 | Get the coordinates for a Cube in Lon-Lat-Grid. 15 | 16 | Args: 17 | cubename (str): cubename (has format tile_startyear_startmonth_startday_endyear_endmonth_endday_hrxmin_hrxmax_hrymin_hrymax_mesoxmin_mesoxmax_mesoymin_mesoymax.npz) 18 | return_meso (bool, optional): If True returns also the coordinates for the Meso-scale variables in the cube. Defaults to False. 19 | 20 | Returns: 21 | tuple: Min-Lon, Min-Lat, Max-Lon, Max-Lat or Min-Lon-HR, Min-Lat-HR, Max-Lon-HR, Max-Lat-HR, Min-Lon-Meso, Min-Lat-Meso, Max-Lon-Meso, Max-Lat-Meso 22 | """ 23 | 24 | if not ignore_warning: 25 | warnings.warn('Getting coordinates to a cube is experimental. The resulting coordinates on Lon-Lat-Grid will never be pixel perfect. Under certain circumstances, the whole bounding box might shifted by up to 0.02° in either direction. Use with caution. 
EarthNet2021 does not provide geo-referenced data.') 26 | 27 | cubetile,_, _,hr_x_min, hr_x_max, hr_y_min, hr_y_max, meso_x_min, meso_x_max, meso_y_min, meso_y_max = os.path.splitext(cubename)[0].split("_") 28 | 29 | tile = COORDS[cubetile] 30 | 31 | transformer = Transformer.from_crs(tile["EPSG"], 4326, always_xy = True) 32 | 33 | tile_x_min, tile_y_max = transformer.transform(tile["MinLon"],tile["MaxLat"], direction = "INVERSE") 34 | 35 | if cubetile in EXTREME_TILES: 36 | hr_x_min = int(hr_x_min) + 57 37 | hr_x_max = int(hr_x_max) + 57 38 | 39 | cube_x_min = tile_x_min + 20 * float(hr_y_min) 40 | cube_x_max = tile_x_min + 20 * float(hr_y_max) 41 | cube_y_min = tile_y_max - 20 * float(hr_x_min) 42 | cube_y_max = tile_y_max - 20 * float(hr_x_max) 43 | 44 | cube_lon_min, cube_lat_min = transformer.transform(cube_x_min, cube_y_max) 45 | cube_lon_max, cube_lat_max = transformer.transform(cube_x_max, cube_y_min) 46 | 47 | if return_meso: 48 | meso_x_min = tile_x_min + 20 * float(meso_x_min) 49 | meso_x_max = tile_x_min + 20 * float(meso_x_max) 50 | meso_y_min = tile_y_max - 20 * float(meso_y_min) 51 | meso_y_max = tile_y_max - 20 * float(meso_y_max) 52 | 53 | meso_lon_min, meso_lat_min = transformer.transform(meso_x_min, meso_y_max) 54 | meso_lon_max, meso_lat_max = transformer.transform(meso_x_max, meso_y_min) 55 | 56 | return cube_lon_min, cube_lat_min, cube_lon_max, cube_lat_max, meso_lon_min, meso_lat_min, meso_lon_max, meso_lat_max 57 | 58 | else: 59 | return cube_lon_min, cube_lat_min, cube_lon_max, cube_lat_max 60 | 61 | def get_coords_from_tile(tilename: str): 62 | """ 63 | Get the Coordinates for a Tile in Lon-lat-grid. 
class DownloadProgressBar(tqdm):
    """tqdm progress bar whose ``update_to`` matches the ``urlretrieve`` reporthook signature."""

    def update_to(self, b=1, bsize=1, tsize=None):
        if tsize is not None:
            self.total = tsize
        self.update(b * bsize - self.n)

def get_sha_of_file(file: str, buf_size: int = 100*1024*1024):
    """Compute the SHA-256 hex digest of ``file``, reading in ``buf_size``-byte chunks."""
    sha = hashlib.sha256()
    with open(file, 'rb') as f:
        while True:
            data = f.read(buf_size)
            if not data:
                break
            sha.update(data)
    return sha.hexdigest()

class Downloader():
    """Downloader Class for EarthNet2021
    """
    __URL__ = DOWNLOAD_LINKS

    def __init__(self, data_dir: str):
        """Initialize Downloader Class

        Args:
            data_dir (str): The directory where the data shall be saved in, we recommend data/dataset/
        """
        self.data_dir = data_dir
        os.makedirs(data_dir, exist_ok = True)

    @classmethod
    def get(cls, data_dir: str, splits: Union[str,Sequence[str]], overwrite: bool = False, delete: bool = True):
        """Download the EarthNet2021 Dataset

        Before downloading, ensure that you have enough free disk space. We recommend 1 TB.

        Specify the directory data_dir, where it should be saved. Then choose, which of the splits you want to download.
        All available splits: ["train","iid","ood","extreme","seasonal"]
        You can either give "all" to splits or a List of splits, for example ["train","iid"].

        Args:
            data_dir (str): The directory where the data shall be saved in, we recommend data/dataset/
            splits (Sequence[str]): Either "all" or a subset of ["train","iid","ood","extreme","seasonal"]. This determines the splits that are downloaded.
            overwrite (bool, optional): If True, overwrites an existing gzipped tarball by downloading it again. Defaults to False.
            delete (bool, optional): If True, deletes the downloaded tarball after unpacking it. Defaults to True.
        """
        self = cls(data_dir)
        print(splits)
        # Normalize `splits` to a sequence plus a set for validation.
        if isinstance(splits, str):
            if splits == "all":
                splits = ["train","iid","ood","extreme","seasonal"]
        if isinstance(splits, str):
            splits_set = {splits}
            splits = (splits,)
        elif isinstance(splits, list) and len(splits) == 1:
            splits_set = {splits[0]}
        else:
            splits_set = set(splits)

        assert(splits_set.issubset(set(["train","iid","ood","extreme","seasonal"])))

        # Resume support: names of fully processed tarballs are persisted in
        # a .PROGRESS pickle. BUGFIX: the bare ``except:`` also swallowed
        # KeyboardInterrupt/SystemExit; narrowed to the errors a missing or
        # corrupt progress file can actually raise, with the same fallback.
        progress_file = os.path.join(self.data_dir, ".PROGRESS")
        try:
            with open(progress_file, "rb") as fp:
                progress_list = pickle.load(fp)
            print("Resuming Download.")
        except (OSError, pickle.UnpicklingError, EOFError, AttributeError):
            progress_list = []

        for split in splits:
            print(f"Downloading split {split}")

            for filename, dl_url, sha in tqdm(self.__URL__[split]):
                if filename in progress_list and not overwrite:
                    # BUGFIX: this f-string had no replacement field and printed
                    # a meaningless placeholder; report which file is skipped.
                    print(f"{filename} already downloaded")
                    continue
                tmp_path = os.path.join(self.data_dir, filename)
                print("Downloading...")
                with DownloadProgressBar(unit='B', unit_scale=True, miniters=1, desc=filename) as t:
                    urllib.request.urlretrieve(dl_url, filename = tmp_path, reporthook=t.update_to)

                print("Downloaded!")
                print("Asserting SHA256 Hash.")
                assert(sha == get_sha_of_file(tmp_path))
                print("SHA256 Hash is correct!")
                print("Extracting tarball...")
                # NOTE(review): tar.extract() is vulnerable to path traversal for
                # untrusted archives; the SHA-256 check above pins the content to
                # the official EarthNet tarballs, which mitigates this here.
                with tarfile.open(tmp_path, 'r:gz') as tar:
                    members = tar.getmembers()
                    for member in tqdm(iterable=members, total=len(members)):
                        tar.extract(member=member,path=self.data_dir)
                print("Extracted!")
                if delete:
                    print("Deleting tarball...")
                    os.remove(tmp_path)

                # Record success only after extraction (and optional deletion)
                # so an interrupted download is retried on resume.
                progress_list.append(filename)
                with open(progress_file, "wb") as fp:
                    pickle.dump(progress_list, fp)
default_layout = "NHWCT" 41 | pred_np: np.ndarray 42 | default_layout = "NHWCT" 43 | ncols: int 44 | Number of columns in plot 45 | layout: str 46 | The layout of np.ndarray 47 | variable: str 48 | One of "rgb", "ndvi", "rr","pp","tg","tn","tx". Defaults to "rgb". 49 | vegetation_mask: np.ndarray 50 | If given uses this as red mask over non-vegetation. S2GLC data. Defaults to None. 51 | cloud_mask: bool 52 | If True tries to use the last channel from the cubes sat imgs as blue cloud mask, 1 where no clouds, 0 where there are clouds. Defaults to True. 53 | save_path: str 54 | If given, saves PNG to this path. Defaults to None. 55 | 56 | Returns 57 | ------- 58 | fig: plt.Figure 59 | """ 60 | fontproperties = FontProperties() 61 | fontproperties.set_family('serif') 62 | # font.set_name('Times New Roman') 63 | fontproperties.set_size(font_size) 64 | # font.set_weight("bold") 65 | 66 | default_layout = "NTHWC" 67 | data_np_list = [] 68 | label_list = [] 69 | if context_np is not None: 70 | context_np = einops_change_layout( 71 | data=context_np, 72 | einops_in_layout=" ".join(layout), 73 | einops_out_layout=" ".join(default_layout))[batch_idx, ...] 74 | data_np_list.append(context_np) 75 | label_list.append("context") 76 | if target_np is not None: 77 | target_np = einops_change_layout( 78 | data=target_np, 79 | einops_in_layout=" ".join(layout), 80 | einops_out_layout=" ".join(default_layout))[batch_idx, ...] 81 | data_np_list.append(target_np) 82 | label_list.append("target") 83 | if pred_np is not None: 84 | pred_np = einops_change_layout( 85 | data=pred_np, 86 | einops_in_layout=" ".join(layout), 87 | einops_out_layout=" ".join(default_layout))[batch_idx, ...] 
88 | data_np_list.append(pred_np) 89 | label_list.append("pred") 90 | 91 | fig, axes = plt.subplots( 92 | nrows=len(data_np_list), 93 | figsize=figsize, dpi=dpi, 94 | constrained_layout=True) 95 | for data, label, ax in zip(data_np_list, label_list, axes): 96 | if variable == "rgb": 97 | targ = np.stack([data[:, :, :, 2], data[:, :, :, 1], data[:, :, :, 0]], axis=-1) 98 | targ[targ < 0] = 0 99 | targ[targ > 0.5] = 0.5 100 | targ = 2 * targ 101 | if data.shape[-1] > 4 and cloud_mask: 102 | mask = data[:, :, :, -1] 103 | zeros = np.zeros_like(targ) 104 | zeros[:, :, :, 2] = 0.1 105 | targ = np.where(np.stack([mask] * 3, -1).astype(np.uint8) | np.isnan(targ).astype(np.uint8), zeros, targ) 106 | else: 107 | targ[np.isnan(targ)] = 0 108 | 109 | elif variable == "ndvi": 110 | if data.shape[-1] == 1: 111 | targ = data[:, :, :, 0] 112 | else: 113 | targ = (data[:, :, :, 3] - data[:, :, :, 2]) / (data[:, :, :, 2] + data[:, :, :, 3] + 1e-6) 114 | if data.shape[-1] > 4 and cloud_mask: 115 | cld_mask = 1 - data[:, :, :, -1] 116 | else: 117 | cld_mask = None 118 | 119 | if vegetation_mask is not None: 120 | if isinstance(vegetation_mask, str) or isinstance(vegetation_mask, Path): 121 | vegetation_mask = np.load(vegetation_mask) 122 | if isinstance(vegetation_mask, np.lib.npyio.NpzFile): 123 | vegetation_mask = vegetation_mask["landcover"] 124 | vegetation_mask = vegetation_mask.reshape(hw, hw) 125 | lc_mask = 1 - (vegetation_mask > 63) & (vegetation_mask < 105) 126 | lc_mask = np.repeat(lc_mask[np.newaxis, :, :], targ.shape[0], axis=0) 127 | else: 128 | lc_mask = None 129 | targ = colorize(targ, colormap="ndvi", mask_red=lc_mask, mask_blue=cld_mask) 130 | 131 | elif variable == "rr": 132 | targ = data[:, :, :, 0] 133 | targ = colorize(targ, colormap='Blues', mask_red=np.isnan(targ)) 134 | elif variable == "pp": 135 | targ = data[:, :, :, 1] 136 | targ = colorize(targ, colormap='rainbow', mask_red=np.isnan(targ)) 137 | elif variable in ["tg", "tn", "tx"]: 138 | targ = data[:, :, 
:, 2 if variable == "tg" else 3 if variable == "tn" else 4] 139 | targ = colorize(targ, colormap='coolwarm', mask_red=np.isnan(targ)) 140 | else: 141 | raise ValueError(f"Invalid variable {variable}!") 142 | 143 | grid = gallery(targ, ncols=ncols) 144 | ax.set_ylabel(ylabel=label, fontproperties=fontproperties, rotation=y_label_rotation) 145 | ax.yaxis.set_label_coords(y_label_offset[0], y_label_offset[1]) 146 | ax.xaxis.set_ticks([]) 147 | ax.yaxis.set_ticks([]) 148 | ax.imshow(grid) 149 | 150 | if variable != "rgb": 151 | colormap = \ 152 | {"ndvi": "ndvi", "rr": "Blues", "pp": "rainbow", "tg": "coolwarm", "tn": "coolwarm", "tx": "coolwarm"}[variable] 153 | cmap = clr.LinearSegmentedColormap.from_list('ndvi', 154 | ["#cbbe9a", "#fffde4", "#bccea5", "#66985b", "#2e6a32", "#123f1e", 155 | "#0e371a", "#01140f", "#000d0a"], 156 | N=256) if colormap == "ndvi" else copy.copy(plt.get_cmap(colormap)) 157 | # divider = make_axes_locatable(plt.gca()) 158 | # cax = divider.append_axes("right", size="5%", pad=0.1) 159 | vmin, vmax = \ 160 | {"ndvi": (0, 1), "rr": (0, 50), "pp": (900, 1100), "tg": (-50, 50), "tn": (-50, 50), "tx": (-50, 50)}[variable] 161 | colorbar_label = { 162 | "ndvi": "NDVI", "rr": "Precipitation in mm/d", "pp": "Sea-level pressure in hPa", 163 | "tg": "Mean temperature in °C", "tn": "Minimum Temperature in °C", "tx": "Maximum Temperature in °C" 164 | }[variable] 165 | fig.colorbar(mappable=cm.ScalarMappable(norm=clr.Normalize(vmin=vmin, vmax=vmax), cmap=cmap), 166 | # cax=cax, 167 | label=colorbar_label, 168 | ax=axes, shrink=0.9, location="right") 169 | 170 | if save_path is not None: 171 | save_path = Path(save_path) 172 | save_path.parents[0].mkdir(parents=True, exist_ok=True) 173 | plt.savefig(save_path, dpi=300, bbox_inches='tight', transparent=True) 174 | 175 | return fig 176 | -------------------------------------------------------------------------------- /src/earthformer/datasets/enso/__init__.py: 
from typing import Tuple, Optional, Union
import numpy as np
import torch
from torchmetrics import Metric
from ..datasets.enso.enso_dataloader import NINO_WINDOW_T, scale_back_sst


def compute_enso_score(
        y_pred, y_true,
        acc_weight: Optional[Union[str, np.ndarray, torch.Tensor]] = None):
    r"""
    Compute the (optionally weighted) correlation skill and RMSE between
    predicted and observed Nino index sequences.

    Parameters
    ----------
    y_pred: torch.Tensor
    y_true: torch.Tensor
    acc_weight: Optional[Union[str, np.ndarray, torch.Tensor]]
        None: not used
        default: use default acc_weight specified at https://tianchi.aliyun.com/competition/entrance/531871/information
        np.ndarray: custom weights

    Returns
    -------
    acc
        Sum over lead times of the (weighted) anomaly correlation.
    rmse
        Sum over lead times of the per-lead-time RMSE.
    """
    # Center each lead time over the batch dim, then compute the per-lead-time
    # Pearson correlation; eps guards against division by zero.
    pred = y_pred - y_pred.mean(dim=0, keepdim=True)  # (N, 24)
    true = y_true - y_true.mean(dim=0, keepdim=True)  # (N, 24)
    cor = (pred * true).sum(dim=0) / (torch.sqrt(torch.sum(pred**2, dim=0) * torch.sum(true**2, dim=0)) + 1e-6)

    if acc_weight is None:
        acc = cor.sum()
    else:
        nino_out_len = y_true.shape[-1]
        if acc_weight == "default":
            # Piecewise competition weights (1.5/2/3/4 by lead-time range),
            # scaled by log(lead + 1); the first lead therefore gets weight 0.
            acc_weight = torch.tensor([1.5] * 4 + [2] * 7 + [3] * 7 + [4] * (nino_out_len - 18))[:nino_out_len] \
                         * torch.log(torch.arange(nino_out_len) + 1)
        elif isinstance(acc_weight, np.ndarray):
            acc_weight = torch.from_numpy(acc_weight[:nino_out_len])
        elif isinstance(acc_weight, torch.Tensor):
            acc_weight = acc_weight[:nino_out_len]
        else:
            raise ValueError(f"Invalid acc_weight {acc_weight}!")
        acc_weight = acc_weight.to(y_pred)  # match device/dtype of predictions
        acc = (acc_weight * cor).sum()
    rmse = torch.mean((y_pred - y_true)**2, dim=0).sqrt().sum()
    return acc, rmse

def sst_to_nino(sst: torch.Tensor,
                normalize_sst: bool = True,
                detach: bool = True):
    r"""
    Convert an SST sequence into a Nino index sequence.

    Parameters
    ----------
    sst: torch.Tensor
        Shape = (N, T, H, W)
    normalize_sst: bool
        If True, `sst` is assumed to be normalized and is scaled back first.
    detach: bool
        If True, detach from the autograd graph before computing.

    Returns
    -------
    nino_index: torch.Tensor
        Shape = (N, T-NINO_WINDOW_T+1)
    """
    if detach:
        nino_index = sst.detach()
    else:
        nino_index = sst
    if normalize_sst:
        nino_index = scale_back_sst(nino_index)
    # Spatial mean over the fixed Nino region, then a sliding temporal mean
    # of width NINO_WINDOW_T.
    nino_index = nino_index[:, :, 10:13, 19:30].mean(dim=[2, 3])  # (N, 26)
    nino_index = nino_index.unfold(dimension=1, size=NINO_WINDOW_T, step=1).mean(dim=2)  # (N, 24)
    return nino_index

class ENSOScore(Metric):
    # Accumulating metric: tracks pixelwise MSE over SST fields and the
    # weighted correlation skill over the derived Nino indices, across batches
    # (and across processes via torchmetrics state reduction).

    def __init__(self,
                 layout="NTHW",
                 out_len=26,
                 normalize_sst=True):
        super(ENSOScore, self).__init__()
        assert layout in ["NTHW", "NTHWC"], f"layout {layout} not supported"
        self.layout = layout
        self.normalize_sst = normalize_sst
        self.out_len = out_len
        # Number of Nino index steps derivable from `out_len` SST steps.
        self.nino_out_len = out_len - NINO_WINDOW_T + 1
        # Same piecewise-log weighting as `compute_enso_score(acc_weight="default")`.
        self.nino_weight = torch.from_numpy(np.array([1.5]*4 + [2]*7 + [3]*7 + [4]*(self.nino_out_len-18))
                                            * np.log(np.arange(self.nino_out_len)+1))

        self.add_state("sum_squared_error", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("num_pixels", default=torch.tensor(0), dist_reduce_fx="sum")

        # Lists of per-sample Nino sequences; concatenated on distributed sync.
        self.add_state("nino_preds", default=[], dist_reduce_fx="cat")
        self.add_state("nino_target", default=[], dist_reduce_fx="cat")
        # self.add_state("nino_preds", default=torch.zeros(0, 24), dist_reduce_fx="cat")
        # self.add_state("nino_target", default=torch.zeros(0, 24), dist_reduce_fx="cat")

    def update(self,
               preds: torch.Tensor, target: torch.Tensor,
               nino_preds: torch.Tensor = None, nino_target: torch.Tensor = None) -> None:
        r"""
        Accumulate one batch of predictions and targets.

        Parameters
        ----------
        preds
            Shape = (N, T, H, W) if self.layout == "NTHW"
            or (N, T, H, W, 1) if self.layout == "NTHWC"
        target
            Shape = (N, T, H, W) if self.layout == "NTHW"
            or (N, T, H, W, 1) if self.layout == "NTHWC"
        nino_preds
            Shape = (N, T-NINO_WINDOW_T+1). Derived from `preds` if None.
        nino_target
            Shape = (N, T-NINO_WINDOW_T+1). Derived from `target` if None.
        """
        if self.layout == "NTHWC":
            preds = preds[..., 0]
            target = target[..., 0]
        diff = preds - target
        sum_squared_error = torch.sum(diff * diff)
        num_pixels = target.numel()
        self.sum_squared_error += sum_squared_error
        self.num_pixels += num_pixels

        if nino_preds is None:
            nino_preds = sst_to_nino(sst=preds,
                                     normalize_sst=self.normalize_sst)
        if nino_target is None:
            nino_target = sst_to_nino(sst=target,
                                      normalize_sst=self.normalize_sst)
        # Store each sample separately so dist_reduce_fx="cat" can merge lists.
        nino_preds_list = [ele for ele in nino_preds]
        nino_target_list = [ele for ele in nino_target]
        self.nino_preds.extend(nino_preds_list)
        self.nino_target.extend(nino_target_list)
        # self.nino_preds = torch.cat((self.nino_preds, nino_preds))
        # self.nino_target = torch.cat((self.nino_target, nino_target))

    def compute(self) -> Tuple[float]:
        # Mean squared error over all pixels accumulated so far.
        mse = self.sum_squared_error / self.num_pixels
        y_pred = torch.stack(self.nino_preds, dim=0)
        y_true = torch.stack(self.nino_target, dim=0)
        acc, nino_rmse = compute_enso_score(
            y_pred=y_pred, y_true=y_true,
            acc_weight=self.nino_weight)
        # Returns (weighted correlation skill, SST MSE); nino_rmse is discarded.
        return acc.cpu().item(), mse.cpu().item(),
14 | """ 15 | full_state_update: bool = True 16 | 17 | def __init__(self, 18 | layout: str = "NHWT", 19 | mode: str = "0", 20 | seq_len: Optional[int] = None, 21 | preprocess_type: str = "sevir", 22 | threshold_list: Sequence[int] = (16, 74, 133, 160, 181, 219), 23 | metrics_list: Sequence[str] = ("csi", "bias", "sucr", "pod"), 24 | eps: float = 1e-4, 25 | dist_sync_on_step: bool = False, 26 | ): 27 | r""" 28 | Parameters 29 | ---------- 30 | seq_len 31 | layout 32 | mode: str 33 | Should be in ("0", "1", "2") 34 | "0": 35 | cumulates hits/misses/fas of all test pixels 36 | score_avg takes average over all thresholds 37 | return 38 | score_thresh shape = (1, ) 39 | score_avg shape = (1, ) 40 | "1": 41 | cumulates hits/misses/fas of each step 42 | score_avg takes average over all thresholds while keeps the seq_len dim 43 | return 44 | score_thresh shape = (seq_len, ) 45 | score_avg shape = (seq_len, ) 46 | "2": 47 | cumulates hits/misses/fas of each step 48 | score_avg takes average over all thresholds, then takes average over the seq_len dim 49 | return 50 | score_thresh shape = (1, ) 51 | score_avg shape = (1, ) 52 | preprocess_type 53 | threshold_list 54 | dist_sync_on_step 55 | """ 56 | super().__init__(dist_sync_on_step=dist_sync_on_step) 57 | self.layout = layout 58 | self.preprocess_type = preprocess_type 59 | self.threshold_list = threshold_list 60 | self.metrics_list = metrics_list 61 | self.eps = eps 62 | self.mode = mode 63 | self.seq_len = seq_len 64 | if mode in ("0", ): 65 | self.keep_seq_len_dim = False 66 | state_shape = (len(self.threshold_list), ) 67 | elif mode in ("1", "2"): 68 | self.keep_seq_len_dim = True 69 | assert isinstance(self.seq_len, int), "seq_len must be provided when we need to keep seq_len dim." 
70 | state_shape = (len(self.threshold_list), self.seq_len) 71 | 72 | else: 73 | raise NotImplementedError(f"mode {mode} not supported!") 74 | 75 | self.add_state("hits", 76 | default=torch.zeros(state_shape), 77 | dist_reduce_fx="sum") 78 | self.add_state("misses", 79 | default=torch.zeros(state_shape), 80 | dist_reduce_fx="sum") 81 | self.add_state("fas", 82 | default=torch.zeros(state_shape), 83 | dist_reduce_fx="sum") 84 | 85 | @property 86 | def hits_misses_fas_reduce_dims(self): 87 | if not hasattr(self, "_hits_misses_fas_reduce_dims"): 88 | seq_dim = self.layout.find('T') 89 | self._hits_misses_fas_reduce_dims = list(range(len(self.layout))) 90 | if self.keep_seq_len_dim: 91 | self._hits_misses_fas_reduce_dims.pop(seq_dim) 92 | return self._hits_misses_fas_reduce_dims 93 | 94 | @staticmethod 95 | def pod(hits, misses, fas, eps): 96 | return hits / (hits + misses + eps) 97 | 98 | @staticmethod 99 | def sucr(hits, misses, fas, eps): 100 | return hits / (hits + fas + eps) 101 | 102 | @staticmethod 103 | def csi(hits, misses, fas, eps): 104 | return hits / (hits + misses + fas + eps) 105 | 106 | @staticmethod 107 | def bias(hits, misses, fas, eps): 108 | bias = (hits + fas) / (hits + misses + eps) 109 | logbias = torch.pow(bias / torch.log(torch.tensor(2.0)), 2.0) 110 | return logbias 111 | 112 | def calc_seq_hits_misses_fas(self, pred, target, threshold): 113 | """ 114 | Parameters 115 | ---------- 116 | pred, target: torch.Tensor 117 | threshold: int 118 | 119 | Returns 120 | ------- 121 | hits, misses, fas: torch.Tensor 122 | each has shape (seq_len, ) 123 | """ 124 | with torch.no_grad(): 125 | t, p = _threshold(target, pred, threshold) 126 | hits = torch.sum(t * p, dim=self.hits_misses_fas_reduce_dims).int() 127 | misses = torch.sum(t * (1 - p), dim=self.hits_misses_fas_reduce_dims).int() 128 | fas = torch.sum((1 - t) * p, dim=self.hits_misses_fas_reduce_dims).int() 129 | return hits, misses, fas 130 | 131 | def preprocess(self, pred, target): 132 | if 
self.preprocess_type == "sevir": 133 | pred = SEVIRDataLoader.process_data_dict_back( 134 | data_dict={'vil': pred.detach().float()})['vil'] 135 | target = SEVIRDataLoader.process_data_dict_back( 136 | data_dict={'vil': target.detach().float()})['vil'] 137 | else: 138 | raise NotImplementedError 139 | return pred, target 140 | 141 | def update(self, pred: torch.Tensor, target: torch.Tensor): 142 | pred, target = self.preprocess(pred, target) 143 | for i, threshold in enumerate(self.threshold_list): 144 | hits, misses, fas = self.calc_seq_hits_misses_fas(pred, target, threshold) 145 | self.hits[i] += hits 146 | self.misses[i] += misses 147 | self.fas[i] += fas 148 | 149 | def compute(self): 150 | metrics_dict = { 151 | 'pod': self.pod, 152 | 'csi': self.csi, 153 | 'sucr': self.sucr, 154 | 'bias': self.bias} 155 | ret = {} 156 | for threshold in self.threshold_list: 157 | ret[threshold] = {} 158 | ret["avg"] = {} 159 | for metrics in self.metrics_list: 160 | if self.keep_seq_len_dim: 161 | score_avg = np.zeros((self.seq_len, )) 162 | else: 163 | score_avg = 0 164 | # shape = (len(threshold_list), seq_len) if self.keep_seq_len_dim, 165 | # else shape = (len(threshold_list),) 166 | scores = metrics_dict[metrics](self.hits, self.misses, self.fas, self.eps) 167 | scores = scores.detach().cpu().numpy() 168 | for i, threshold in enumerate(self.threshold_list): 169 | if self.keep_seq_len_dim: 170 | score = scores[i] # shape = (seq_len, ) 171 | else: 172 | score = scores[i].item() # shape = (1, ) 173 | if self.mode in ("0", "1"): 174 | ret[threshold][metrics] = score 175 | elif self.mode in ("2", ): 176 | ret[threshold][metrics] = np.mean(score).item() 177 | else: 178 | raise NotImplementedError 179 | score_avg += score 180 | score_avg /= len(self.threshold_list) 181 | if self.mode in ("0", "1"): 182 | ret["avg"][metrics] = score_avg 183 | elif self.mode in ("2",): 184 | ret["avg"][metrics] = np.mean(score_avg).item() 185 | else: 186 | raise NotImplementedError 187 | return 
ret 188 | -------------------------------------------------------------------------------- /src/earthformer/metrics/skill_scores.py: -------------------------------------------------------------------------------- 1 | """Code is adapted from https://github.com/MIT-AI-Accelerator/neurips-2020-sevir. Their license is MIT License.""" 2 | 3 | import numpy as np 4 | import torch 5 | from torchmetrics import Metric 6 | from torch import nn 7 | from torch.nn import init 8 | import torch.nn.functional as F 9 | 10 | 11 | def _threshold(target, pred ,T): 12 | """ 13 | Returns binary tensors t,p the same shape as target & pred. t = 1 wherever 14 | target > t. p =1 wherever pred > t. p and t are set to 0 wherever EITHER 15 | t or p are nan. 16 | This is useful for counts that don't involve correct rejections. 17 | 18 | Parameters 19 | ---------- 20 | target 21 | torch.Tensor 22 | pred 23 | torch.Tensor 24 | T 25 | numeric_type: threshold 26 | Returns 27 | ------- 28 | t 29 | p 30 | """ 31 | t = (target >= T).float() 32 | p = (pred >= T).float() 33 | is_nan = torch.logical_or(torch.isnan(target), 34 | torch.isnan(pred)) 35 | t[is_nan] = 0 36 | p[is_nan] = 0 37 | return t, p 38 | 39 | def _calc_hits_misses_fas(t, p): 40 | hits = torch.sum(t * p) 41 | misses = torch.sum(t * (1 - p)) 42 | fas = torch.sum((1 - t) * p) 43 | return hits, misses, fas 44 | 45 | def _pod(target, pred ,T, eps=1e-6): 46 | """ 47 | Single channel version of probability_of_detection 48 | """ 49 | t, p = _threshold(target, pred ,T) 50 | hits, misses, fas = _calc_hits_misses_fas(t, p) 51 | # return (hits + eps) / (hits + misses + eps) 52 | return hits / (hits + misses + eps) 53 | 54 | 55 | def _sucr(target, pred, T, eps=1e-6): 56 | """ 57 | Single channel version of success_rate 58 | """ 59 | t, p = _threshold(target, pred, T) 60 | hits, misses, fas = _calc_hits_misses_fas(t, p) 61 | # return (hits + eps) / (hits + fas + eps) 62 | return hits / (hits + fas + eps) 63 | 64 | def _csi(target, pred, T, eps=1e-6): 
65 | """ 66 | Single channel version of csi 67 | """ 68 | t, p = _threshold(target, pred, T) 69 | hits, misses, fas = _calc_hits_misses_fas(t, p) 70 | # return (hits + eps) / (hits + misses + fas + eps) 71 | return hits / (hits + misses + fas + eps) 72 | 73 | def _bias(target, pred, T, eps=1e-6): 74 | """ 75 | Single channel version of csi 76 | """ 77 | t, p = _threshold(target, pred, T) 78 | hits, misses, fas = _calc_hits_misses_fas(t, p) 79 | # return (hits + fas + eps) / (hits + misses + eps) 80 | return (hits + fas) / (hits + misses + eps) 81 | -------------------------------------------------------------------------------- /src/earthformer/metrics/torchmetrics_wo_compute.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | import torch 3 | from torchmetrics import Metric 4 | from torchmetrics.utilities.exceptions import TorchMetricsUserError 5 | 6 | 7 | class MetricsUpdateWithoutCompute(Metric): 8 | r""" 9 | Delete `batch_val = self.compute()` in `forward()` to reduce unnecessary computation 10 | """ 11 | 12 | def __init__(self, **kwargs: Any): 13 | super(MetricsUpdateWithoutCompute, self).__init__(**kwargs) 14 | 15 | @torch.jit.unused 16 | def forward(self, *args: Any, **kwargs: Any): 17 | """``forward`` serves the dual purpose of both computing the metric on the current batch of inputs but also 18 | add the batch statistics to the overall accumululating metric state. 19 | 20 | Input arguments are the exact same as corresponding ``update`` method. The returned output is the exact same as 21 | the output of ``compute``. 22 | """ 23 | # check if states are already synced 24 | if self._is_synced: 25 | raise TorchMetricsUserError( 26 | "The Metric shouldn't be synced when performing ``forward``. " 27 | "HINT: Did you forget to call ``unsync`` ?." 
    def _forward_full_state_update_without_compute(self, *args: Any, **kwargs: Any):
        """forward computation using two calls to `update` to calculate the metric value on the current batch and
        accumulate global state.

        Doing this secures that metrics that need access to the full metric state during `update` works as expected.
        Unlike the upstream torchmetrics variant, this method intentionally skips the
        per-batch ``compute()`` call, so it returns nothing.
        """
        # global accumulation
        self.update(*args, **kwargs)
        _update_count = self._update_count

        self._to_sync = self.dist_sync_on_step  # type: ignore
        # skip restore cache operation from compute as cache is stored below.
        self._should_unsync = False
        # skip computing on cpu for the batch
        _temp_compute_on_cpu = self.compute_on_cpu
        self.compute_on_cpu = False

        # save context before switch
        cache = {attr: getattr(self, attr) for attr in self._defaults}

        # call reset, update, compute, on single batch
        self._enable_grad = True  # allow grads for batch computation
        self.reset()
        self.update(*args, **kwargs)

        # restore context: put the accumulated global state back
        for attr, val in cache.items():
            setattr(self, attr, val)
        self._update_count = _update_count

        # restore context: reset sync/grad flags to their steady-state values
        self._is_synced = False
        self._should_unsync = True
        self._to_sync = self.sync_on_compute
        self._computed = None
        self._enable_grad = False
        self.compute_on_cpu = _temp_compute_on_cpu
78 | """ 79 | # store global state and reset to default 80 | global_state = {attr: getattr(self, attr) for attr in self._defaults.keys()} 81 | _update_count = self._update_count 82 | self.reset() 83 | 84 | # local synchronization settings 85 | self._to_sync = self.dist_sync_on_step 86 | self._should_unsync = False 87 | _temp_compute_on_cpu = self.compute_on_cpu 88 | self.compute_on_cpu = False 89 | self._enable_grad = True # allow grads for batch computation 90 | 91 | # calculate batch state and compute batch value 92 | self.update(*args, **kwargs) 93 | 94 | # reduce batch and global state 95 | self._update_count = _update_count + 1 96 | with torch.no_grad(): 97 | self._reduce_states(global_state) 98 | 99 | # restore context 100 | self._is_synced = False 101 | self._should_unsync = True 102 | self._to_sync = self.sync_on_compute 103 | self._computed = None 104 | self._enable_grad = False 105 | self.compute_on_cpu = _temp_compute_on_cpu 106 | -------------------------------------------------------------------------------- /src/earthformer/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/earth-forecasting-transformer/7732b03bdb366110563516c3502315deab4c2026/src/earthformer/utils/__init__.py -------------------------------------------------------------------------------- /src/earthformer/utils/apex_ddp.py: -------------------------------------------------------------------------------- 1 | # Find the original code and discussion at https://github.com/PyTorchLightning/pytorch-lightning/discussions/10922 2 | # We will need to use the AMP implementation from apex because https://discuss.pytorch.org/t/using-torch-utils-checkpoint-checkpoint-with-dataparallel/78452 3 | 4 | from apex.parallel import DistributedDataParallel 5 | from pytorch_lightning.strategies.ddp import DDPStrategy 6 | from pytorch_lightning.overrides.base import ( 7 | _LightningModuleWrapperBase, 8 | 
_LightningPrecisionModuleWrapperBase, 9 | ) 10 | 11 | def unwrap_lightning_module(wrapped_model): 12 | model = wrapped_model 13 | if isinstance(model, DistributedDataParallel): 14 | model = unwrap_lightning_module(model.module) 15 | if isinstance( 16 | model, (_LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase) 17 | ): 18 | model = unwrap_lightning_module(model.module) 19 | return model 20 | 21 | 22 | class ApexDDPStrategy(DDPStrategy): 23 | def _setup_model(self, model): 24 | return DistributedDataParallel(model, delay_allreduce=False) 25 | 26 | @property 27 | def lightning_module(self): 28 | return unwrap_lightning_module(self._model) 29 | 30 | 31 | if __name__ == "__main__": 32 | # Correct usage of apex DDP, which can avoid error caused by using `torch.utils.checkpoint` 33 | # when using `strategy="ddp"` in pl. 34 | import pytorch_lightning as pl 35 | trainer = pl.Trainer( 36 | strategy=ApexDDPStrategy(find_unused_parameters=False, delay_allreduce=True), # "ddp", 37 | ) 38 | -------------------------------------------------------------------------------- /src/earthformer/utils/checkpoint.py: -------------------------------------------------------------------------------- 1 | import os 2 | import warnings 3 | from typing import Callable, Dict, Optional, Union, List 4 | import torch 5 | from pytorch_lightning.utilities.cloud_io import load as pl_load 6 | 7 | 8 | def average_checkpoints(checkpoint_paths: List[str] = None): 9 | r""" Code is adapted from https://github.com/awslabs/autogluon/blob/a818545e047f2bcda13569568a7fd611abdfc876/multimodal/src/autogluon/multimodal/utils/checkpoint.py#L13 10 | 11 | Average a list of checkpoints' state_dicts. 12 | Reference: https://github.com/rwightman/pytorch-image-models/blob/master/avg_checkpoints.py 13 | Parameters 14 | ---------- 15 | checkpoint_paths 16 | A list of model checkpoint paths. 17 | Returns 18 | ------- 19 | The averaged state_dict. 
20 | """ 21 | if len(checkpoint_paths) > 1: 22 | avg_state_dict = {} 23 | avg_counts = {} 24 | for per_path in checkpoint_paths: 25 | state_dict = torch.load(per_path, map_location=torch.device("cpu"))["state_dict"] 26 | for k, v in state_dict.items(): 27 | if k not in avg_state_dict: 28 | avg_state_dict[k] = v.clone().to(dtype=torch.float64) 29 | avg_counts[k] = 1 30 | else: 31 | avg_state_dict[k] += v.to(dtype=torch.float64) 32 | avg_counts[k] += 1 33 | del state_dict 34 | 35 | for k, v in avg_state_dict.items(): 36 | v.div_(avg_counts[k]) 37 | 38 | # convert to float32. 39 | float32_info = torch.finfo(torch.float32) 40 | for k in avg_state_dict: 41 | avg_state_dict[k].clamp_(float32_info.min, float32_info.max).to(dtype=torch.float32) 42 | else: 43 | avg_state_dict = torch.load(checkpoint_paths[0], map_location=torch.device("cpu"))["state_dict"] 44 | 45 | return avg_state_dict 46 | 47 | def average_pl_checkpoints(pl_checkpoint_paths: List[str] = None, delete_prefix_len: int = len("")): 48 | r""" Code is adapted from https://github.com/awslabs/autogluon/blob/a818545e047f2bcda13569568a7fd611abdfc876/multimodal/src/autogluon/multimodal/utils/checkpoint.py#L13 49 | 50 | Average a list of checkpoints' state_dicts. 51 | Reference: https://github.com/rwightman/pytorch-image-models/blob/master/avg_checkpoints.py 52 | Parameters 53 | ---------- 54 | checkpoint_paths 55 | A list of model checkpoint paths. 56 | Returns 57 | ------- 58 | The averaged state_dict. 
59 | """ 60 | if len(pl_checkpoint_paths) > 1: 61 | avg_state_dict = {} 62 | avg_counts = {} 63 | for per_path in pl_checkpoint_paths: 64 | state_dict = pl_ckpt_to_pytorch_state_dict(per_path, 65 | map_location=torch.device("cpu"), 66 | delete_prefix_len=delete_prefix_len) 67 | for k, v in state_dict.items(): 68 | if k not in avg_state_dict: 69 | avg_state_dict[k] = v.clone().to(dtype=torch.float64) 70 | avg_counts[k] = 1 71 | else: 72 | avg_state_dict[k] += v.to(dtype=torch.float64) 73 | avg_counts[k] += 1 74 | del state_dict 75 | 76 | for k, v in avg_state_dict.items(): 77 | v.div_(avg_counts[k]) 78 | 79 | # convert to float32. 80 | float32_info = torch.finfo(torch.float32) 81 | for k in avg_state_dict: 82 | avg_state_dict[k].clamp_(float32_info.min, float32_info.max).to(dtype=torch.float32) 83 | else: 84 | avg_state_dict = pl_ckpt_to_pytorch_state_dict(pl_checkpoint_paths[0], 85 | map_location=torch.device("cpu"), 86 | delete_prefix_len=delete_prefix_len) 87 | 88 | return avg_state_dict 89 | 90 | def pl_ckpt_to_pytorch_state_dict( 91 | checkpoint_path: str, 92 | map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None, 93 | delete_prefix_len: int = len("")): 94 | r""" 95 | Parameters 96 | ---------- 97 | checkpoint_path: str 98 | map_location 99 | A function, torch.device, string or a dict specifying how to remap storage locations. 100 | The same as the arg `map_location` in `torch.load()`. 101 | delete_prefix_len: int 102 | Delete the first several characters in the keys of state_dict. 
def layout_to_in_out_slice(layout, in_len, out_len=None):
    """
    Build slicing lists that split the 'T' axis of `layout` into an input
    segment of length `in_len` and the following output segment.

    Parameters
    ----------
    layout: str
        Axis-name string containing 'T', e.g. "NTHW".
    in_len: int
        Length of the input segment along the T axis.
    out_len: Optional[int]
        Length of the output segment; if None, the output slice extends to
        the end of the T axis.

    Returns
    -------
    (in_slice, out_slice): two lists of `slice` objects, one entry per axis.
    """
    t_axis = layout.find("T")
    in_slice = [slice(None, None) for _ in layout]
    out_slice = [slice(None, None) for _ in layout]
    in_slice[t_axis] = slice(None, in_len)
    out_end = None if out_len is None else in_len + out_len
    out_slice[t_axis] = slice(in_len, out_end)
    return in_slice, out_slice
in_layout == 'NWHT': 28 | data = np.transpose(data, 29 | axes=(0, 2, 1, 3)) 30 | elif in_layout == 'NTCHW': 31 | data = data[:, :, 0, :, :] 32 | data = np.transpose(data, 33 | axes=(0, 2, 3, 1)) 34 | elif in_layout == 'NTHWC': 35 | data = data[:, :, :, :, 0] 36 | data = np.transpose(data, 37 | axes=(0, 2, 3, 1)) 38 | elif in_layout == 'NTWHC': 39 | data = data[:, :, :, :, 0] 40 | data = np.transpose(data, 41 | axes=(0, 3, 2, 1)) 42 | elif in_layout == 'TNHW': 43 | data = np.transpose(data, 44 | axes=(1, 2, 3, 0)) 45 | elif in_layout == 'TNCHW': 46 | data = data[:, :, 0, :, :] 47 | data = np.transpose(data, 48 | axes=(1, 2, 3, 0)) 49 | else: 50 | raise NotImplementedError 51 | 52 | if out_layout == 'NHWT': 53 | pass 54 | elif out_layout == 'NTHW': 55 | data = np.transpose(data, 56 | axes=(0, 3, 1, 2)) 57 | elif out_layout == 'NWHT': 58 | data = np.transpose(data, 59 | axes=(0, 2, 1, 3)) 60 | elif out_layout == 'NTCHW': 61 | data = np.transpose(data, 62 | axes=(0, 3, 1, 2)) 63 | data = np.expand_dims(data, axis=2) 64 | elif out_layout == 'NTHWC': 65 | data = np.transpose(data, 66 | axes=(0, 3, 1, 2)) 67 | data = np.expand_dims(data, axis=-1) 68 | elif out_layout == 'NTWHC': 69 | data = np.transpose(data, 70 | axes=(0, 3, 2, 1)) 71 | data = np.expand_dims(data, axis=-1) 72 | elif out_layout == 'TNHW': 73 | data = np.transpose(data, 74 | axes=(3, 0, 1, 2)) 75 | elif out_layout == 'TNCHW': 76 | data = np.transpose(data, 77 | axes=(3, 0, 1, 2)) 78 | data = np.expand_dims(data, axis=2) 79 | else: 80 | raise NotImplementedError 81 | if ret_contiguous: 82 | data = data.ascontiguousarray() 83 | return data 84 | 85 | def change_layout_torch(data, 86 | in_layout='NHWT', out_layout='NHWT', 87 | ret_contiguous=False): 88 | # first convert to 'NHWT' 89 | if in_layout == 'NHWT': 90 | pass 91 | elif in_layout == 'NTHW': 92 | data = data.permute(0, 2, 3, 1) 93 | elif in_layout == 'NTCHW': 94 | data = data[:, :, 0, :, :] 95 | data = data.permute(0, 2, 3, 1) 96 | elif in_layout == 
'NTHWC': 97 | data = data[:, :, :, :, 0] 98 | data = data.permute(0, 2, 3, 1) 99 | elif in_layout == 'TNHW': 100 | data = data.permute(1, 2, 3, 0) 101 | elif in_layout == 'TNCHW': 102 | data = data[:, :, 0, :, :] 103 | data = data.permute(1, 2, 3, 0) 104 | else: 105 | raise NotImplementedError 106 | 107 | if out_layout == 'NHWT': 108 | pass 109 | elif out_layout == 'NTHW': 110 | data = data.permute(0, 3, 1, 2) 111 | elif out_layout == 'NTCHW': 112 | data = data.permute(0, 3, 1, 2) 113 | data = torch.unsqueeze(data, dim=2) 114 | elif out_layout == 'NTHWC': 115 | data = data.permute(0, 3, 1, 2) 116 | data = torch.unsqueeze(data, dim=-1) 117 | elif out_layout == 'TNHW': 118 | data = data.permute(3, 0, 1, 2) 119 | elif out_layout == 'TNCHW': 120 | data = data.permute(3, 0, 1, 2) 121 | data = torch.unsqueeze(data, dim=2) 122 | else: 123 | raise NotImplementedError 124 | if ret_contiguous: 125 | data = data.contiguous() 126 | return data 127 | -------------------------------------------------------------------------------- /src/earthformer/utils/optim.py: -------------------------------------------------------------------------------- 1 | from packaging import version 2 | import torch 3 | 4 | 5 | if version.parse(torch.__version__) >= version.parse('1.11.0'): 6 | # Starting from torch>=1.11.0, the attribute `optimizer` is set in SequentialLR. See https://github.com/pytorch/pytorch/pull/67406 7 | from torch.optim.lr_scheduler import SequentialLR 8 | else: 9 | from torch.optim.lr_scheduler import _LRScheduler 10 | from bisect import bisect_right 11 | 12 | class SequentialLR(_LRScheduler): 13 | """Receives the list of schedulers that is expected to be called sequentially during 14 | optimization process and milestone points that provides exact intervals to reflect 15 | which scheduler is supposed to be called at a given epoch. 16 | 17 | Args: 18 | schedulers (list): List of chained schedulers. 19 | milestones (list): List of integers that reflects milestone points. 

        Example:
            >>> # Assuming optimizer uses lr = 1. for all groups
            >>> # lr = 0.1     if epoch == 0
            >>> # lr = 0.1     if epoch == 1
            >>> # lr = 0.9     if epoch == 2
            >>> # lr = 0.81    if epoch == 3
            >>> # lr = 0.729   if epoch == 4
            >>> scheduler1 = ConstantLR(self.opt, factor=0.1, total_iters=2)
            >>> scheduler2 = ExponentialLR(self.opt, gamma=0.9)
            >>> scheduler = SequentialLR(self.opt, schedulers=[scheduler1, scheduler2], milestones=[2])
            >>> for epoch in range(100):
            >>>     train(...)
            >>>     validate(...)
            >>>     scheduler.step()
        """

        # NOTE: this is a verbatim backport of torch 1.11's SequentialLR for
        # torch<1.11 environments; keep it in sync with upstream rather than
        # refactoring it locally.
        def __init__(self, optimizer, schedulers, milestones, last_epoch=-1, verbose=False):
            # All chained schedulers must drive the same optimizer instance.
            for scheduler_idx in range(1, len(schedulers)):
                if (schedulers[scheduler_idx].optimizer != schedulers[0].optimizer):
                    raise ValueError(
                        "Sequential Schedulers expects all schedulers to belong to the same optimizer, but "
                        "got schedulers at index {} and {} to be different".format(0, scheduler_idx)
                    )
            # N schedulers need N-1 switch points.
            if (len(milestones) != len(schedulers) - 1):
                raise ValueError(
                    "Sequential Schedulers expects number of schedulers provided to be one more "
                    "than the number of milestone points, but got number of schedulers {} and the "
                    "number of milestones to be equal to {}".format(len(schedulers), len(milestones))
                )
            self.optimizer = optimizer
            self._schedulers = schedulers
            self._milestones = milestones
            # Mirrors upstream: the constructor counts as the first step.
            self.last_epoch = last_epoch + 1

        def step(self):
            self.last_epoch += 1
            # Find which scheduler owns the current epoch.
            idx = bisect_right(self._milestones, self.last_epoch)
            if idx > 0 and self._milestones[idx - 1] == self.last_epoch:
                # Exactly at a milestone: restart the new scheduler from epoch 0.
                self._schedulers[idx].step(0)
            else:
                self._schedulers[idx].step()

        def state_dict(self):
            """Returns the state of the scheduler as a :class:`dict`.

            It contains an entry for every variable in self.__dict__ which
            is not the optimizer.
            The wrapped scheduler states will also be saved.
            """
            state_dict = {key: value for key, value in self.__dict__.items() if key not in ('optimizer', '_schedulers')}
            state_dict['_schedulers'] = [None] * len(self._schedulers)

            for idx, s in enumerate(self._schedulers):
                state_dict['_schedulers'][idx] = s.state_dict()

            return state_dict

        def load_state_dict(self, state_dict):
            """Loads the schedulers state.

            Args:
                state_dict (dict): scheduler state. Should be an object returned
                    from a call to :meth:`state_dict`.
            """
            _schedulers = state_dict.pop('_schedulers')
            self.__dict__.update(state_dict)
            # Restore state_dict keys in order to prevent side effects
            # https://github.com/pytorch/pytorch/issues/32756
            state_dict['_schedulers'] = _schedulers

            for idx, s in enumerate(_schedulers):
                self._schedulers[idx].load_state_dict(s)

def warmup_lambda(warmup_steps, min_lr_ratio=0.1):
    # Returns a LambdaLR multiplier that ramps linearly from `min_lr_ratio`
    # to 1.0 over `warmup_steps` steps, then stays at 1.0.
    def ret_lambda(epoch):
        if epoch <= warmup_steps:
            return min_lr_ratio + (1.0 - min_lr_ratio) * epoch / warmup_steps
        else:
            return 1.0
    return ret_lambda
--------------------------------------------------------------------------------
/src/earthformer/utils/registry.py:
--------------------------------------------------------------------------------
# Licensed to the GluonNLP team under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.
You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, 12 | # software distributed under the License is distributed on an 13 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | # KIND, either express or implied. See the License for the 15 | # specific language governing permissions and limitations 16 | # under the License. 17 | """Create a registry.""" 18 | 19 | from typing import Optional, List 20 | import json 21 | from json import JSONDecodeError 22 | 23 | 24 | class Registry: 25 | """Create the registry that will map name to object. This facilitates the users to create 26 | custom registry. 27 | 28 | Parameters 29 | ---------- 30 | name 31 | The name of the registry 32 | 33 | Examples 34 | -------- 35 | 36 | >>> from earthformer.utils.registry import Registry 37 | >>> # Create a registry 38 | >>> MODEL_REGISTRY = Registry('MODEL') 39 | >>> 40 | >>> # To register a class/function with decorator 41 | >>> @MODEL_REGISTRY.register() 42 | ... class MyModel: 43 | ... pass 44 | >>> @MODEL_REGISTRY.register() 45 | ... def my_model(): 46 | ... return 47 | >>> 48 | >>> # To register a class object with decorator and provide nickname: 49 | >>> @MODEL_REGISTRY.register('test_class') 50 | ... class MyModelWithNickName: 51 | ... pass 52 | >>> @MODEL_REGISTRY.register('test_function') 53 | ... def my_model_with_nick_name(): 54 | ... return 55 | >>> 56 | >>> # To register a class/function object by function call 57 | ... class MyModel2: 58 | ... 
pass 59 | >>> MODEL_REGISTRY.register(MyModel2) 60 | >>> # To register with a given name 61 | >>> MODEL_REGISTRY.register('my_model2', MyModel2) 62 | >>> # To list all the registered objects: 63 | >>> MODEL_REGISTRY.list_keys() 64 | 65 | ['MyModel', 'my_model', 'test_class', 'test_function', 'MyModel2', 'my_model2'] 66 | 67 | >>> # To get the registered object/class 68 | >>> MODEL_REGISTRY.get('test_class') 69 | 70 | __main__.MyModelWithNickName 71 | 72 | """ 73 | 74 | def __init__(self, name: str) -> None: 75 | self._name: str = name 76 | self._obj_map: dict[str, object] = dict() 77 | 78 | def _do_register(self, name: str, obj: object) -> None: 79 | assert ( 80 | name not in self._obj_map 81 | ), "An object named '{}' was already registered in '{}' registry!".format( 82 | name, self._name 83 | ) 84 | self._obj_map[name] = obj 85 | 86 | def register(self, *args): 87 | """ 88 | Register the given object under either the nickname or `obj.__name__`. It can be used as 89 | either a decorator or not. See docstring of this class for usage. 
90 | """ 91 | if len(args) == 2: 92 | # Register an object with nick name by function call 93 | nickname, obj = args 94 | self._do_register(nickname, obj) 95 | elif len(args) == 1: 96 | if isinstance(args[0], str): 97 | # Register an object with nick name by decorator 98 | nickname = args[0] 99 | def deco(func_or_class: object) -> object: 100 | self._do_register(nickname, func_or_class) 101 | return func_or_class 102 | return deco 103 | else: 104 | # Register an object by function call 105 | self._do_register(args[0].__name__, args[0]) 106 | elif len(args) == 0: 107 | # Register an object by decorator 108 | def deco(func_or_class: object) -> object: 109 | self._do_register(func_or_class.__name__, func_or_class) 110 | return func_or_class 111 | return deco 112 | else: 113 | raise ValueError('Do not support the usage!') 114 | 115 | def get(self, name: str) -> object: 116 | ret = self._obj_map.get(name) 117 | if ret is None: 118 | raise KeyError( 119 | "No object named '{}' found in '{}' registry!".format( 120 | name, self._name 121 | ) 122 | ) 123 | return ret 124 | 125 | def list_keys(self) -> List: 126 | return list(self._obj_map.keys()) 127 | 128 | def __repr__(self) -> str: 129 | s = '{name}(keys={keys})'.format(name=self._name, 130 | keys=self.list_keys()) 131 | return s 132 | 133 | def create(self, name: str, *args, **kwargs) -> object: 134 | """Create the class object with the given args and kwargs 135 | 136 | Parameters 137 | ---------- 138 | name 139 | The name in the registry 140 | args 141 | kwargs 142 | 143 | Returns 144 | ------- 145 | ret 146 | The created object 147 | """ 148 | obj = self.get(name) 149 | try: 150 | return obj(*args, **kwargs) 151 | except Exception as exp: 152 | print('Cannot create name="{}" --> {} with the provided arguments!\n' 153 | ' args={},\n' 154 | ' kwargs={},\n' 155 | .format(name, obj, args, kwargs)) 156 | raise exp 157 | 158 | def create_with_json(self, name: str, json_str: str): 159 | """ 160 | 161 | Parameters 162 | 
---------- 163 | name 164 | json_str 165 | 166 | Returns 167 | ------- 168 | 169 | """ 170 | try: 171 | args = json.loads(json_str) 172 | except JSONDecodeError: 173 | raise ValueError('Unable to decode the json string: json_str="{}"' 174 | .format(json_str)) 175 | if isinstance(args, (list, tuple)): 176 | return self.create(name, *args) 177 | elif isinstance(args, dict): 178 | return self.create(name, **args) 179 | else: 180 | raise NotImplementedError('The format of json string is not supported! We only support ' 181 | 'list/dict. json_str="{}".' 182 | .format(json_str)) 183 | -------------------------------------------------------------------------------- /src/earthformer/utils/utils.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import os 3 | from typing import Optional 4 | import tqdm 5 | import hashlib 6 | import sys 7 | import functools 8 | import uuid 9 | import requests 10 | 11 | 12 | if not sys.platform.startswith('win32'): 13 | # refer to https://github.com/untitaker/python-atomicwrites 14 | def replace_file(src, dst): 15 | """Implement atomic os.replace with linux and OSX. 16 | Parameters 17 | ---------- 18 | src : source file path 19 | dst : destination file path 20 | """ 21 | try: 22 | os.rename(src, dst) 23 | except OSError: 24 | try: 25 | os.remove(src) 26 | except OSError: 27 | pass 28 | finally: 29 | raise OSError( 30 | 'Moving downloaded temp file - {}, to {} failed. \ 31 | Please retry the download.'.format(src, dst)) 32 | else: 33 | import ctypes 34 | 35 | _MOVEFILE_REPLACE_EXISTING = 0x1 36 | # Setting this value guarantees that a move performed as a copy 37 | # and delete operation is flushed to disk before the function returns. 38 | # The flush occurs at the end of the copy operation. 39 | _MOVEFILE_WRITE_THROUGH = 0x8 40 | _windows_default_flags = _MOVEFILE_WRITE_THROUGH 41 | 42 | def _str_to_unicode(x): 43 | """Handle text decoding. 
        Internal use only"""
        if not isinstance(x, str):
            # Decode bytes paths with the filesystem encoding for the Win32 API.
            return x.decode(sys.getfilesystemencoding())
        return x

    def _handle_errors(rv, src):
        """Handle WinError. Internal use only"""
        if not rv:
            msg = ctypes.FormatError(ctypes.GetLastError())
            # if the MoveFileExW fails(e.g. fail to acquire file lock), removes the tempfile
            try:
                os.remove(src)
            except OSError:
                pass
            finally:
                raise OSError(msg)

    def replace_file(src, dst):
        """Implement atomic os.replace with windows.
        refer to https://docs.microsoft.com/en-us/windows/desktop/api/winbase/nf-winbase-movefileexw
        The function fails when one of the process(copy, flush, delete) fails.
        Parameters
        ----------
        src : source file path
        dst : destination file path
        """
        _handle_errors(ctypes.windll.kernel32.MoveFileExW(
            _str_to_unicode(src), _str_to_unicode(dst),
            _windows_default_flags | _MOVEFILE_REPLACE_EXISTING
        ), src)


def path_splitall(path):
    # Split a path into all of its components, e.g. "a/b/c" -> ["a", "b", "c"].
    allparts = []
    while 1:
        parts = os.path.split(path)
        if parts[0] == path:  # sentinel for absolute paths
            allparts.insert(0, parts[0])
            break
        elif parts[1] == path:  # sentinel for relative paths
            allparts.insert(0, parts[1])
            break
        else:
            path = parts[0]
            allparts.insert(0, parts[1])
    return allparts

def get_parameter_names(model, forbidden_layer_types):
    r"""
    Returns the names of the model parameters that are not inside a forbidden layer.

    Borrowed from https://github.com/huggingface/transformers/blob/623b4f7c63f60cce917677ee704d6c93ee960b4b/src/transformers/trainer_pt_utils.py#L996
    """
    result = []
    # Recurse into children; drop a child's parameters entirely when the child
    # is an instance of any forbidden layer type.
    for name, child in model.named_children():
        result += [
            f"{name}.{n}"
            for n in get_parameter_names(child, forbidden_layer_types)
            if not isinstance(child, tuple(forbidden_layer_types))
        ]
    # Add model specific parameters (defined with nn.Parameter) since they are not in any child.
    result += list(model._parameters.keys())
    return result


def sha1sum(filename):
    """Calculate the sha1sum of a file
    Parameters
    ----------
    filename
        Name of the file
    Returns
    -------
    ret
        The sha1sum
    """
    with open(filename, mode='rb') as f:
        d = hashlib.sha1()
        # Read in 100 KiB chunks so arbitrarily large files fit in memory.
        for buf in iter(functools.partial(f.read, 1024*100), b''):
            d.update(buf)
    return d.hexdigest()


def download(url: str,
             path: Optional[str] = None,
             overwrite: Optional[bool] = False,
             sha1_hash: Optional[str] = None,
             retries: Optional[int] = 5,
             verify_ssl: Optional[bool] = True) -> str:
    """Download a given URL

    Parameters
    ----------
    url
        URL to download
    path
        Destination path to store downloaded file. By default stores to the
        current directory with same name as in url.
    overwrite
        Whether to overwrite destination file if already exists.
    sha1_hash
        Expected sha1 hash in hexadecimal digits. Will ignore existing file when hash is specified
        but doesn't match.
    retries
        The number of times to attempt the download in case of failure or non 200 return codes
    verify_ssl
        Verify SSL certificates.
    Returns
    -------
    fname
        The file path of the downloaded file.
    """
    if path is None:
        fname = url.split('/')[-1]
        # Empty filenames are invalid
        assert fname, 'Can\'t construct file-name from this URL. ' \
                      'Please set the `path` option manually.'
    else:
        path = os.path.expanduser(path)
        if os.path.isdir(path):
            fname = os.path.join(path, url.split('/')[-1])
        else:
            fname = path

    if not verify_ssl:
        warnings.warn('Unverified HTTPS request is being made (verify_ssl=False). '
                      'Adding certificate verification is strongly advised.')

    assert retries >= 0, f"Number of retries should be at least 0, currently it's {retries}"
    # Skip the download entirely when the target exists and matches the hash.
    if overwrite or not os.path.exists(fname) or (sha1_hash and not sha1sum(fname) == sha1_hash):
        is_s3 = url.startswith('s3://')
        if is_s3:
            import boto3
            session = boto3.Session()
            s3 = session.resource('s3')
            if session.get_credentials() is None:
                # Anonymous access for public buckets.
                from botocore.handlers import disable_signing
                s3.meta.client.meta.events.register('choose-signer.s3.*', disable_signing)
            components = url[len('s3://'):].split('/')
            if len(components) < 2:
                raise ValueError('Invalid S3 url. Received url={}'.format(url))
            s3_bucket_name = components[0]
            s3_key = '/'.join(components[1:])

        dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname)))
        if not os.path.exists(dirname):
            os.makedirs(dirname, exist_ok=True)
        while retries + 1 > 0:
            # Disable pyling too broad Exception
            # pylint: disable=W0703
            try:
                print('Downloading {} from {}...'.format(fname, url))
                if is_s3:
                    response = s3.meta.client.head_object(Bucket=s3_bucket_name, Key=s3_key)
                    total_size = int(response.get('ContentLength', 0))
                    random_uuid = str(uuid.uuid4())
                    tmp_path = '{}.{}'.format(fname, random_uuid)
                    # NOTE(review): `tqdm` is imported unconditionally above, so this
                    # None-check is vestigial (inherited from code with an optional
                    # tqdm import) — it always takes the progress-bar branch.
                    if tqdm is not None:
                        def hook(t_obj):
                            def inner(bytes_amount):
                                t_obj.update(bytes_amount)

                            return inner

                        with tqdm.tqdm(total=total_size, unit='iB', unit_scale=True) as t:
                            s3.meta.client.download_file(s3_bucket_name, s3_key, tmp_path,
                                                         Callback=hook(t))
                    else:
                        s3.meta.client.download_file(s3_bucket_name, s3_key, tmp_path)
                else:
                    r = requests.get(url, stream=True, verify=verify_ssl)
                    if r.status_code != 200:
                        raise RuntimeError('Failed downloading url {}'.format(url))
                    # create uuid for temporary files
                    random_uuid = str(uuid.uuid4())
                    total_size = int(r.headers.get('content-length', 0))
                    chunk_size = 1024
                    if tqdm is not None:
                        t = tqdm.tqdm(total=total_size, unit='iB', unit_scale=True)
                    with open('{}.{}'.format(fname, random_uuid), 'wb') as f:
                        for chunk in r.iter_content(chunk_size=chunk_size):
                            if chunk:  # filter out keep-alive new chunks
                                if tqdm is not None:
                                    t.update(len(chunk))
                                f.write(chunk)
                    if tqdm is not None:
                        t.close()
                # if the target file exists(created by other processes)
                # and have the same hash with target file
                # delete the temporary file
                if not os.path.exists(fname) or (sha1_hash and not sha1sum(fname) == sha1_hash):
                    # atomic operation in the same file system
                    replace_file('{}.{}'.format(fname, random_uuid), fname)
                else:
                    try:
                        os.remove('{}.{}'.format(fname, random_uuid))
                    except OSError:
                        pass
                    finally:
                        warnings.warn(
                            'File {} exists in file system so the downloaded file is deleted'.format(fname))
                # NOTE(review): raising a Warning subclass as an exception is unusual
                # but intentional here — a hash mismatch aborts with this message.
                if sha1_hash and not sha1sum(fname) == sha1_hash:
                    raise UserWarning(
                        'File {} is downloaded but the content hash does not match.'
                        ' The repo may be outdated or download may be incomplete. '
                        'If the "repo_url" is overridden, consider switching to '
                        'the default repo.'.format(fname))
                break
            except Exception as e:
                retries -= 1
                if retries <= 0:
                    raise e

                print('download failed due to {}, retrying, {} attempt{} left'
                      .format(repr(e), retries, 's' if retries > 1 else ''))

    return fname
--------------------------------------------------------------------------------
/src/earthformer/visualization/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/earth-forecasting-transformer/7732b03bdb366110563516c3502315deab4c2026/src/earthformer/visualization/__init__.py
--------------------------------------------------------------------------------
/src/earthformer/visualization/nbody.py:
--------------------------------------------------------------------------------
import os
import numpy as np
from matplotlib import pyplot as plt
from ..utils.layout import change_layout_np


def save_example_vis_results(
        save_dir, save_prefix,
        in_seq, target_seq,
        pred_seq, label,
        layout='NHWT', idx=0,
        plot_stride=1, fs=10, norm="none"):
    r"""
    Parameters
    ----------
    in_seq: np.array
        float value 0-1
    target_seq: np.array
        float value 0-1
    pred_seq: np.array or
List[np.array] 21 | float value 0-1 22 | """ 23 | in_seq = change_layout_np(in_seq, in_layout=layout).astype(np.float32) 24 | target_seq = change_layout_np(target_seq, in_layout=layout).astype(np.float32) 25 | if isinstance(pred_seq, list): 26 | pred_seq_list = [change_layout_np(ele, in_layout=layout).astype(np.float32) 27 | for ele in pred_seq] 28 | assert isinstance(label, list) and len(label) == len(pred_seq) 29 | else: 30 | pred_seq_list = [change_layout_np(pred_seq, in_layout=layout).astype(np.float32), ] 31 | label_list = [label, ] 32 | fig_path = os.path.join(save_dir, f'{save_prefix}.png') 33 | if norm == "none": 34 | norm = {'scale': 1.0, 35 | 'shift': 0.0} 36 | elif norm == "to255": 37 | norm = {'scale': 255, 38 | 'shift': 0} 39 | else: 40 | raise NotImplementedError 41 | in_len = in_seq.shape[-1] 42 | out_len = target_seq.shape[-1] 43 | max_len = max(in_len, out_len) 44 | ncols = (max_len - 1) // plot_stride + 1 45 | fig, ax = plt.subplots(nrows=2 + len(pred_seq_list), 46 | ncols=ncols, 47 | figsize=(24, 8)) 48 | 49 | ax[0][0].set_ylabel('Inputs\n', fontsize=fs) 50 | for i in range(0, max_len, plot_stride): 51 | if i < in_len: 52 | xt = in_seq[idx, :, :, i] * norm['scale'] + norm['shift'] 53 | ax[0][i // plot_stride].imshow(xt, cmap='gray') 54 | else: 55 | ax[0][i // plot_stride].axis('off') 56 | 57 | ax[1][0].set_ylabel('Target\n', fontsize=fs) 58 | for i in range(0, max_len, plot_stride): 59 | if i < out_len: 60 | xt = target_seq[idx, :, :, i] * norm['scale'] + norm['shift'] 61 | ax[1][i // plot_stride].imshow(xt, cmap='gray') 62 | else: 63 | ax[1][i // plot_stride].axis('off') 64 | 65 | y_preds = [pred_seq[idx:idx + 1] * norm['scale'] + norm['shift'] 66 | for pred_seq in pred_seq_list] 67 | 68 | # Plot model predictions 69 | for k in range(len(pred_seq_list)): 70 | for i in range(0, max_len, plot_stride): 71 | if i < out_len: 72 | ax[2 + k][i // plot_stride].imshow(y_preds[k][0, :, :, i], cmap='gray') 73 | else: 74 | ax[2 + k][i // 
plot_stride].axis('off') 75 | 76 | ax[2 + k][0].set_ylabel(label_list[k], fontsize=fs) 77 | 78 | for i in range(0, max_len, plot_stride): 79 | if i < out_len: 80 | ax[-1][i // plot_stride].set_title(f"step {int(i + plot_stride)}", y=-0.25, fontsize=fs) 81 | 82 | for j in range(len(ax)): 83 | for i in range(len(ax[j])): 84 | ax[j][i].xaxis.set_ticks([]) 85 | ax[j][i].yaxis.set_ticks([]) 86 | 87 | plt.subplots_adjust(hspace=0.05, wspace=0.05) 88 | plt.savefig(fig_path) 89 | plt.close(fig) 90 | -------------------------------------------------------------------------------- /src/earthformer/visualization/sevir/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/earth-forecasting-transformer/7732b03bdb366110563516c3502315deab4c2026/src/earthformer/visualization/sevir/__init__.py -------------------------------------------------------------------------------- /src/earthformer/visualization/sevir/dataset_statistics.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | 3 | import numpy as np 4 | from matplotlib import pyplot as plt 5 | 6 | from earthformer.datasets.sevir.sevir_dataloader import SEVIR_CATALOG, SEVIR_DATA_DIR, SEVIR_RAW_SEQ_LEN, \ 7 | SEVIR_LR_CATALOG, SEVIR_LR_DATA_DIR, SEVIR_LR_RAW_SEQ_LEN, SEVIRDataLoader 8 | 9 | 10 | def report_SEVIR_statistics( 11 | dataset="sevir", 12 | sanity_check=True, 13 | hist_save_path="tmp_sevir_hist.png"): 14 | r""" 15 | Report important statistics of SEVIR dataset, including: 16 | - The distribution of pixel values (from 0 to 255). 
17 | 18 | Refer to https://discuss.pytorch.org/t/plot-a-histogram-for-multiple-images-full-dataset/67600 19 | """ 20 | if sanity_check: 21 | start_date = datetime.datetime(2019, 5, 27) 22 | train_val_split_date = datetime.datetime(2019, 5, 29) 23 | train_test_split_date = datetime.datetime(2019, 6, 1) 24 | end_date = datetime.datetime(2019, 6, 3) 25 | else: 26 | # total number of event = 14926 for the whole training set 27 | # val percentage is 20% in original SEVIR paper 28 | # this split results in 11906 / 14926 = 79.77% 29 | # if use (2019, 2, 1), the ratio becomes 12489 / 14926 = 83.67% 30 | start_date = None 31 | train_val_split_date = datetime.datetime(2019, 1, 1) 32 | # train_val_split_date = datetime.datetime(2019, 2, 1) 33 | train_test_split_date = datetime.datetime(2019, 6, 1) 34 | end_date = None 35 | if dataset == "sevir": 36 | catalog_path = SEVIR_CATALOG 37 | data_dir = SEVIR_DATA_DIR 38 | raw_seq_len = SEVIR_RAW_SEQ_LEN 39 | elif dataset == "sevir_lr": 40 | catalog_path = SEVIR_LR_CATALOG 41 | data_dir = SEVIR_LR_DATA_DIR 42 | raw_seq_len = SEVIR_LR_RAW_SEQ_LEN 43 | else: 44 | raise ValueError(f"Invalid dataset: {dataset}") 45 | 46 | batch_size = 1 47 | data_types = ["vil", ] 48 | layout = "NHWT" 49 | seq_len = raw_seq_len 50 | stride = seq_len 51 | sample_mode = "sequent" 52 | 53 | train_dataloader = SEVIRDataLoader( 54 | sevir_catalog=catalog_path, sevir_data_dir=data_dir, 55 | raw_seq_len=raw_seq_len, split_mode="uneven", 56 | data_types=data_types, seq_len=seq_len, stride=stride, 57 | sample_mode=sample_mode, batch_size=batch_size, layout=layout, 58 | num_shard=1, rank=0, 59 | start_date=start_date, end_date=train_val_split_date) 60 | # val_dataloader = SEVIRDataLoader( 61 | # sevir_catalog=catalog_path, sevir_data_dir=data_dir, 62 | # raw_seq_len=raw_seq_len, split_mode="uneven", 63 | # data_types=data_types, seq_len=seq_len, stride=stride, 64 | # sample_mode=sample_mode, batch_size=batch_size, layout=layout, 65 | # num_shard=1, rank=0, 66 | # 
start_date=train_val_split_date, end_date=train_test_split_date) 67 | # test_dataloader = SEVIRDataLoader( 68 | # sevir_catalog=catalog_path, sevir_data_dir=data_dir, 69 | # raw_seq_len=raw_seq_len, split_mode="uneven", 70 | # data_types=data_types, seq_len=seq_len, stride=stride, 71 | # sample_mode=sample_mode, batch_size=batch_size, layout=layout, 72 | # num_shard=1, rank=0, 73 | # start_date=train_test_split_date, end_date=end_date) 74 | 75 | data_loader = train_dataloader 76 | # data_loader = val_dataloader 77 | # data_loader = test_dataloader 78 | 79 | num_bins = 256 80 | count = np.zeros(num_bins) 81 | for data_idx, data in enumerate(data_loader): 82 | data_seq = data['vil'] 83 | hist = np.histogram(data_seq, bins=num_bins, range=[0, 255]) 84 | count += hist[0] 85 | bins = hist[1] 86 | fig = plt.figure() 87 | plt.bar(bins[:-1], count, color='b', alpha=0.5) 88 | plt.savefig(hist_save_path) -------------------------------------------------------------------------------- /src/earthformer/visualization/sevir/sevir_vis_seq.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List 3 | import numpy as np 4 | from matplotlib import pyplot as plt 5 | from matplotlib.colors import ListedColormap 6 | from matplotlib.patches import Patch 7 | from ...utils.layout import change_layout_np 8 | from .sevir_cmap import get_cmap, VIL_COLORS, VIL_LEVELS 9 | 10 | 11 | HMF_COLORS = np.array([ 12 | [82, 82, 82], 13 | [252, 141, 89], 14 | [255, 255, 191], 15 | [145, 191, 219] 16 | ]) / 255 17 | 18 | THRESHOLDS = (0, 16, 74, 133, 160, 181, 219, 255) 19 | 20 | def plot_hit_miss_fa(ax, y_true, y_pred, thres): 21 | mask = np.zeros_like(y_true) 22 | mask[np.logical_and(y_true >= thres, y_pred >= thres)] = 4 23 | mask[np.logical_and(y_true >= thres, y_pred < thres)] = 3 24 | mask[np.logical_and(y_true < thres, y_pred >= thres)] = 2 25 | mask[np.logical_and(y_true < thres, y_pred < thres)] = 1 26 | cmap = 
ListedColormap(HMF_COLORS)
    ax.imshow(mask, cmap=cmap)

def plot_hit_miss_fa_all_thresholds(ax, y_true, y_pred, **unused_kwargs):
    # Per-pixel agreement between the VIL-threshold bins of truth and prediction:
    # 4 = same bin, 3 = under-forecast, 2 = over-forecast,
    # 1 = both below the lowest challenging threshold (ignored).
    fig = np.zeros(y_true.shape)
    y_true_idx = np.searchsorted(THRESHOLDS, y_true)
    y_pred_idx = np.searchsorted(THRESHOLDS, y_pred)
    fig[y_true_idx == y_pred_idx] = 4
    fig[y_true_idx > y_pred_idx] = 3
    fig[y_true_idx < y_pred_idx] = 2
    # do not count results in these not challenging areas.
    fig[np.logical_and(y_true < THRESHOLDS[1], y_pred < THRESHOLDS[1])] = 1
    cmap = ListedColormap(HMF_COLORS)
    ax.imshow(fig, cmap=cmap)

def visualize_result(
        in_seq: np.array, target_seq: np.array,
        pred_seq_list: List[np.array], label_list: List[str],
        interval_real_time: float = 10.0, idx=0, norm=None, plot_stride=2,
        figsize=(24, 8), fs=10,
        vis_thresh=THRESHOLDS[2], vis_hits_misses_fas=True):
    """
    Plot an input/target sequence alongside one or more predicted sequences,
    optionally with per-model hit/miss/false-alarm score rows.

    Parameters
    ----------
    in_seq: np.array
        Input sequence in 'NHWT' layout.
    target_seq: np.array
        Target sequence in 'NHWT' layout.
    pred_seq_list: list of np.array
        One predicted sequence per model, each in 'NHWT' layout.
    label_list: list of str
        Row label for each predicted sequence.
    interval_real_time: float
        The minutes of each plot interval
    idx: int
        Which batch element to plot.
    norm: dict
        {'scale': s, 'shift': b} mapping to display range; defaults to 0-255.
    vis_thresh
        Threshold used for the single-threshold score row.
    vis_hits_misses_fas: bool
        If True, add two score rows (single-threshold and all-threshold) per model.
    """
    if norm is None:
        norm = {'scale': 255,
                'shift': 0}
    cmap_dict = lambda s: {'cmap': get_cmap(s, encoded=True)[0],
                           'norm': get_cmap(s, encoded=True)[1],
                           'vmin': get_cmap(s, encoded=True)[2],
                           'vmax': get_cmap(s, encoded=True)[3]}
    in_len = in_seq.shape[-1]
    out_len = target_seq.shape[-1]
    max_len = max(in_len, out_len)
    ncols = (max_len - 1) // plot_stride + 1
    # 3 rows per model (prediction + 2 score rows) when scores are shown, else 1.
    if vis_hits_misses_fas:
        fig, ax = plt.subplots(nrows=2 + 3 * len(pred_seq_list),
                               ncols=ncols,
                               figsize=figsize)
    else:
        fig, ax = plt.subplots(nrows=2 + len(pred_seq_list),
                               ncols=ncols,
                               figsize=figsize)

    ax[0][0].set_ylabel('Inputs', fontsize=fs)
    for i in range(0, max_len, plot_stride):
        if i < in_len:
            xt = in_seq[idx, :, :, i] * norm['scale'] + norm['shift']
            ax[0][i // plot_stride].imshow(xt, **cmap_dict('vil'))
        else:
            ax[0][i // plot_stride].axis('off')

    ax[1][0].set_ylabel('Target', fontsize=fs)
    for i in range(0, max_len, plot_stride):
        if i < out_len:
            xt = target_seq[idx, :, :, i] * norm['scale'] + norm['shift']
            ax[1][i // plot_stride].imshow(xt, **cmap_dict('vil'))
            # ax[1][i // plot_stride].set_title(f'{5*(i+plot_stride)} Minutes')
        else:
            ax[1][i // plot_stride].axis('off')

    target_seq = target_seq[idx:idx + 1] * norm['scale'] + norm['shift']
    y_preds = [pred_seq[idx:idx + 1] * norm['scale'] + norm['shift']
               for pred_seq in pred_seq_list]

    # Plot model predictions
    if vis_hits_misses_fas:
        for k in range(len(pred_seq_list)):
            for i in range(0, max_len, plot_stride):
                if i < out_len:
                    ax[2 + 3 * k][i // plot_stride].imshow(y_preds[k][0, :, :, i], **cmap_dict('vil'))
                    plot_hit_miss_fa(ax[2 + 1 + 3 * k][i // plot_stride], target_seq[0, :, :, i], y_preds[k][0, :, :, i],
                                     vis_thresh)
                    plot_hit_miss_fa_all_thresholds(ax[2 + 2 + 3 * k][i // plot_stride], target_seq[0, :, :, i],
                                                    y_preds[k][0, :, :, i])
                else:
                    ax[2 + 3 * k][i // plot_stride].axis('off')
                    ax[2 + 1 + 3 * k][i // plot_stride].axis('off')
                    ax[2 + 2 + 3 * k][i // plot_stride].axis('off')

            ax[2 + 3 * k][0].set_ylabel(label_list[k] + '\nPrediction', fontsize=fs)
            ax[2 + 1 + 3 * k][0].set_ylabel(label_list[k] + f'\nScores\nThresh={vis_thresh}', fontsize=fs)
            ax[2 + 2 + 3 * k][0].set_ylabel(label_list[k] + '\nScores\nAll Thresh', fontsize=fs)
    else:
        for k in range(len(pred_seq_list)):
            for i in range(0, max_len, plot_stride):
                if i < out_len:
                    ax[2 + k][i // plot_stride].imshow(y_preds[k][0, :, :, i], **cmap_dict('vil'))
                else:
                    ax[2 + k][i // plot_stride].axis('off')

            ax[2 + k][0].set_ylabel(label_list[k] + '\nPrediction', fontsize=fs)

    for i in range(0, max_len, plot_stride):
        if i < out_len:
            ax[-1][i // plot_stride].set_title(f'{int(interval_real_time * (i + plot_stride))} Minutes', y=-0.25)

    # Strip tick marks from every panel.
    for j in range(len(ax)):
        for i in range(len(ax[j])):
            ax[j][i].xaxis.set_ticks([])
            ax[j][i].yaxis.set_ticks([])

    # Legend of thresholds
    num_thresh_legend = len(VIL_LEVELS) - 1
    legend_elements = [Patch(facecolor=VIL_COLORS[i],
                             label=f'{int(VIL_LEVELS[i - 1])}-{int(VIL_LEVELS[i])}')
                       for i in range(1, num_thresh_legend + 1)]
    ax[0][0].legend(handles=legend_elements, loc='center left',
                    bbox_to_anchor=(-1.2, -0.),
                    borderaxespad=0, frameon=False, fontsize='10')
    if vis_hits_misses_fas:
        # Legend of Hit, Miss and False Alarm
        legend_elements = [Patch(facecolor=HMF_COLORS[3], edgecolor='k', label='Hit'),
                           Patch(facecolor=HMF_COLORS[2], edgecolor='k', label='Miss'),
                           Patch(facecolor=HMF_COLORS[1], edgecolor='k', label='False Alarm')]
        # ax[-1][0].legend(handles=legend_elements, loc='lower right',
        #                  bbox_to_anchor=(6., -.6),
        #                  ncol=5, borderaxespad=0, frameon=False, fontsize='16')
        ax[3][0].legend(handles=legend_elements, loc='center left',
                        bbox_to_anchor=(-2.2, -0.),
                        borderaxespad=0, frameon=False, fontsize='16')

    plt.subplots_adjust(hspace=0.05, wspace=0.05)
    return fig, ax

def save_example_vis_results(
        save_dir, save_prefix, in_seq, target_seq, pred_seq, label,
        layout='NHWT', interval_real_time: float = 10.0, idx=0,
        plot_stride=2, fs=10, norm=None):
    """
    Parameters
    ----------
    in_seq: np.array
        float value 0-1
    target_seq: np.array
        float value 0-1
    pred_seq: np.array
        float value 0-1
    interval_real_time: float
        The minutes of each plot interval
    """
    in_seq = change_layout_np(in_seq, in_layout=layout).astype(np.float32)
    target_seq = change_layout_np(target_seq,
in_layout=layout).astype(np.float32) 176 | pred_seq = change_layout_np(pred_seq, in_layout=layout).astype(np.float32) 177 | fig_path = os.path.join(save_dir, f'{save_prefix}.png') 178 | fig, ax = visualize_result( 179 | in_seq=in_seq, target_seq=target_seq, pred_seq_list=[pred_seq,], 180 | label_list=[label, ], interval_real_time=interval_real_time, idx=idx, 181 | plot_stride=plot_stride, fs=fs, norm=norm) 182 | plt.savefig(fig_path) 183 | plt.close(fig) 184 | -------------------------------------------------------------------------------- /src/earthformer/visualization/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | 4 | 5 | def save_gif(single_seq, fname): 6 | """Save a single gif consisting of image sequence in single_seq to fname.""" 7 | img_seq = [Image.fromarray(img.astype(np.float32) * 255, 'F').convert("L") for img in single_seq] 8 | img = img_seq[0] 9 | img.save(fname, save_all=True, append_images=img_seq[1:]) 10 | -------------------------------------------------------------------------------- /tests/README.md: -------------------------------------------------------------------------------- 1 | # Run Unittests 2 | 3 | Use the following command to run unittest 4 | 5 | ```bash 6 | python3 -m pytest . 
# --- tests/test_cuboid.py ---
# Smoke test: build a CuboidTransformerModel from the SEVIR config and run a
# forward pass on random data, printing the output shape.
import torch
from torch import nn
from omegaconf import OmegaConf
from earthformer.cuboid_transformer.cuboid_transformer import CuboidTransformerModel
import os


def _expand_attn_pattern(pattern, num_blocks):
    """Broadcast a single pattern name to every block; convert explicit
    per-block pattern lists to plain containers."""
    if isinstance(pattern, str):
        return [pattern] * num_blocks
    return OmegaConf.to_container(pattern)


oc_file = "./scripts/cuboid_transformer/sevir/earthformer_sevir_v1.yaml"
# Pass the path directly so OmegaConf manages the file handle (the previous
# `open(oc_file, "r")` was never closed).
model_cfg = OmegaConf.to_object(OmegaConf.load(oc_file).model)
num_blocks = len(model_cfg["enc_depth"])
# The three attention-pattern options share the same "string or per-block
# list" convention; expand them uniformly instead of repeating the branch.
enc_attn_patterns = _expand_attn_pattern(model_cfg["self_pattern"], num_blocks)
dec_self_attn_patterns = _expand_attn_pattern(model_cfg["cross_self_pattern"], num_blocks)
dec_cross_attn_patterns = _expand_attn_pattern(model_cfg["cross_pattern"], num_blocks)

model = CuboidTransformerModel(
    input_shape=model_cfg["input_shape"],
    target_shape=model_cfg["target_shape"],
    base_units=model_cfg["base_units"],
    # block_units=model_cfg["block_units"],
    scale_alpha=model_cfg["scale_alpha"],
    enc_depth=model_cfg["enc_depth"],
    dec_depth=model_cfg["dec_depth"],
    enc_use_inter_ffn=model_cfg["enc_use_inter_ffn"],
    dec_use_inter_ffn=model_cfg["dec_use_inter_ffn"],
    dec_hierarchical_pos_embed=model_cfg["dec_hierarchical_pos_embed"],
    downsample=model_cfg["downsample"],
    downsample_type=model_cfg["downsample_type"],
    enc_attn_patterns=enc_attn_patterns,
    dec_self_attn_patterns=dec_self_attn_patterns,
    dec_cross_attn_patterns=dec_cross_attn_patterns,
    dec_cross_last_n_frames=model_cfg["dec_cross_last_n_frames"],
    dec_use_first_self_attn=model_cfg["dec_use_first_self_attn"],
    num_heads=model_cfg["num_heads"],
    attn_drop=model_cfg["attn_drop"],
    proj_drop=model_cfg["proj_drop"],
    ffn_drop=model_cfg["ffn_drop"],
    upsample_type=model_cfg["upsample_type"],
    ffn_activation=model_cfg["ffn_activation"],
    gated_ffn=model_cfg["gated_ffn"],
    norm_layer=model_cfg["norm_layer"],
    # global vectors
    num_global_vectors=model_cfg["num_global_vectors"],
    use_dec_self_global=model_cfg["use_dec_self_global"],
    dec_self_update_global=model_cfg["dec_self_update_global"],
    use_dec_cross_global=model_cfg["use_dec_cross_global"],
    use_global_vector_ffn=model_cfg["use_global_vector_ffn"],
    use_global_self_attn=model_cfg["use_global_self_attn"],
    separate_global_qkv=model_cfg["separate_global_qkv"],
    global_dim_ratio=model_cfg["global_dim_ratio"],
    # initial_downsample
    initial_downsample_type=model_cfg["initial_downsample_type"],
    initial_downsample_activation=model_cfg["initial_downsample_activation"],
    # initial_downsample_type=="stack_conv"
    initial_downsample_stack_conv_num_layers=model_cfg["initial_downsample_stack_conv_num_layers"],
    initial_downsample_stack_conv_dim_list=model_cfg["initial_downsample_stack_conv_dim_list"],
    initial_downsample_stack_conv_downscale_list=model_cfg["initial_downsample_stack_conv_downscale_list"],
    initial_downsample_stack_conv_num_conv_list=model_cfg["initial_downsample_stack_conv_num_conv_list"],
    # misc
    padding_type=model_cfg["padding_type"],
    z_init_method=model_cfg["z_init_method"],
    checkpoint_level=model_cfg["checkpoint_level"],
    pos_embed_type=model_cfg["pos_embed_type"],
    use_relative_pos=model_cfg["use_relative_pos"],
    self_attn_use_final_proj=model_cfg["self_attn_use_final_proj"],
    # initialization
    attn_linear_init_mode=model_cfg["attn_linear_init_mode"],
    ffn_linear_init_mode=model_cfg["ffn_linear_init_mode"],
    conv_init_mode=model_cfg["conv_init_mode"],
    down_up_linear_init_mode=model_cfg["down_up_linear_init_mode"],
    norm_init_mode=model_cfg["norm_init_mode"],
)

input_shape = model_cfg["input_shape"]
target_shape = model_cfg["target_shape"]
batch_size = 2
data = torch.rand((batch_size, *input_shape))
out = model(data)
print(out.shape)

# --- tests/unittests/test_pretrained_checkpoints.py (header) ---
import warnings
import os
from omegaconf import OmegaConf
import pytest
import numpy as np
import torch
from torch.nn import functional as F
import torchmetrics
from einops import rearrange
from earthformer.config import cfg
from earthformer.utils.checkpoint import s3_download_pretrained_ckpt
from earthformer.utils.layout import layout_to_in_out_slice
from earthformer.utils.utils import download
from earthformer.cuboid_transformer.cuboid_transformer import CuboidTransformerModel
from earthformer.cuboid_transformer.cuboid_transformer_unet_dec import CuboidTransformerAuxModel


NUM_TEST_ITER = 16  # max = 32 since saved `unittest_data.pt` only contains the first 0 to 31 data entries.
test_data_dir = os.path.join(cfg.root_dir, "tests", "unittests", "test_pretrained_checkpoints_data")

def s3_download_unittest_data(data_name):
    """Download the unittest data file `data_name` from S3 into
    `test_data_dir` unless it is already cached locally."""
    test_data_path = os.path.join(test_data_dir, data_name)
    if not os.path.exists(test_data_path):
        os.makedirs(test_data_dir, exist_ok=True)
        download(url=f"s3://earthformer/unittests/{data_name}", path=test_data_path)


def config_cuboid_transformer(cfg, model_type="CuboidTransformerModel"):
    """Build a cuboid-transformer model from an OmegaConf config.

    The config's `self_pattern` / `cross_self_pattern` / `cross_pattern`
    entries (a single name or one name per block) are popped and expanded into
    the `enc_attn_patterns` / `dec_self_attn_patterns` /
    `dec_cross_attn_patterns` kwargs expected by the model constructors.

    Parameters
    ----------
    cfg
        OmegaConf config with a `model` section.
    model_type
        Either "CuboidTransformerModel" or "CuboidTransformerAuxModel".

    Returns
    -------
    The constructed (untrained) model.

    Raises
    ------
    ValueError
        If `model_type` is not one of the two supported names.
    """
    model_cfg = OmegaConf.to_object(cfg.model)
    num_blocks = len(model_cfg["enc_depth"])
    if isinstance(model_cfg["self_pattern"], str):
        enc_attn_patterns = [model_cfg.pop("self_pattern")] * num_blocks
    else:
        enc_attn_patterns = OmegaConf.to_container(model_cfg.pop("self_pattern"))
    model_cfg["enc_attn_patterns"] = enc_attn_patterns
    if isinstance(model_cfg["cross_self_pattern"], str):
        dec_self_attn_patterns = [model_cfg.pop("cross_self_pattern")] * num_blocks
    else:
        dec_self_attn_patterns = OmegaConf.to_container(model_cfg.pop("cross_self_pattern"))
    model_cfg["dec_self_attn_patterns"] = dec_self_attn_patterns
    if isinstance(model_cfg["cross_pattern"], str):
        dec_cross_attn_patterns = [model_cfg.pop("cross_pattern")] * num_blocks
    else:
        dec_cross_attn_patterns = OmegaConf.to_container(model_cfg.pop("cross_pattern"))
    model_cfg["dec_cross_attn_patterns"] = dec_cross_attn_patterns
    if model_type == "CuboidTransformerModel":
        model = CuboidTransformerModel(**model_cfg)
    elif model_type == "CuboidTransformerAuxModel":
        model = CuboidTransformerAuxModel(**model_cfg)
    else:
        # Fixed error message: it previously omitted 'CuboidTransformerAuxModel'.
        raise ValueError(f"Invalid model_type {model_type}. "
                         f"Must be 'CuboidTransformerModel' or 'CuboidTransformerAuxModel'.")
    return model


def _load_pretrained_model(pretrained_cfg, pretrained_ckpt_name, model_type, device):
    """Build a model from `pretrained_cfg`, fetch the named checkpoint from S3
    if it is not cached, load its state dict strictly (no missing/unexpected
    keys tolerated) and return the model in eval mode on `device`."""
    model = config_cuboid_transformer(
        cfg=pretrained_cfg,
        model_type=model_type).to(device)
    model.eval()
    ckpt_path = os.path.join(cfg.pretrained_checkpoints_dir, pretrained_ckpt_name)
    if not os.path.exists(ckpt_path):
        s3_download_pretrained_ckpt(ckpt_name=pretrained_ckpt_name,
                                    save_dir=cfg.pretrained_checkpoints_dir,
                                    exist_ok=False)
    state_dict = torch.load(ckpt_path, map_location=device)
    missing_keys, unexpected_keys = model.load_state_dict(state_dict=state_dict, strict=False)
    assert len(missing_keys) == 0, f"missing_keys {missing_keys} when loading pretrained state_dict."
    # Fixed message: it previously said "missing_keys" for unexpected keys too.
    assert len(unexpected_keys) == 0, f"unexpected_keys {unexpected_keys} when loading pretrained state_dict."
    return model


def test_sevir():
    """Run the pretrained SEVIR model on cached test batches and check that
    MSE/MAE stay below the published reference bounds."""
    pretrained_ckpt_name = "earthformer_sevir.pt"
    test_data_name = "unittest_sevir_data_bs1_idx0to31.pt"
    s3_download_unittest_data(data_name=test_data_name)
    test_data_path = os.path.join(test_data_dir, test_data_name)
    device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
    # Load pretrained model
    pretrained_cfg_path = os.path.join(cfg.root_dir, "scripts", "cuboid_transformer", "sevir", "earthformer_sevir_v1.yaml")
    pretrained_cfg = OmegaConf.load(open(pretrained_cfg_path, "r"))
    model = _load_pretrained_model(pretrained_cfg=pretrained_cfg,
                                   pretrained_ckpt_name=pretrained_ckpt_name,
                                   model_type="CuboidTransformerModel",
                                   device=device)
    # Test on SEVIR test
    layout_cfg = pretrained_cfg.layout
    in_slice, out_slice = layout_to_in_out_slice(layout=layout_cfg.layout,
                                                 in_len=layout_cfg.in_len,
                                                 out_len=layout_cfg.out_len)
    test_mse_metrics = torchmetrics.MeanSquaredError().to(device)
    test_mae_metrics = torchmetrics.MeanAbsoluteError().to(device)
    test_data = torch.load(test_data_path)
    counter = 0
    with torch.no_grad():
        for batch in test_data:
            data_seq = batch['vil'].contiguous().to(device)
            x = data_seq[in_slice]
            y = data_seq[out_slice]
            y_hat = model(x)
            test_mse_metrics(y_hat, y)
            test_mae_metrics(y_hat, y)
            counter += 1
            if counter >= NUM_TEST_ITER:
                break
    test_mse = test_mse_metrics.compute()
    test_mae = test_mae_metrics.compute()
    assert test_mse < 1E-2
    assert test_mae < 5E-2

def test_enso():
    """Run the pretrained ICAR-ENSO model on cached test batches and check
    that MSE/MAE stay below the published reference bounds."""
    pretrained_ckpt_name = "earthformer_icarenso2021.pt"
    test_data_name = "unittest_icarenso2021_data_bs1_idx0to31.pt"
    s3_download_unittest_data(data_name=test_data_name)
    test_data_path = os.path.join(test_data_dir, test_data_name)
    device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
    # Load pretrained model
    pretrained_cfg_path = os.path.join(cfg.root_dir, "scripts", "cuboid_transformer", "enso", "earthformer_enso_v1.yaml")
    pretrained_cfg = OmegaConf.load(open(pretrained_cfg_path, "r"))
    model = _load_pretrained_model(pretrained_cfg=pretrained_cfg,
                                   pretrained_ckpt_name=pretrained_ckpt_name,
                                   model_type="CuboidTransformerModel",
                                   device=device)
    # Test on ENSO test
    layout_cfg = pretrained_cfg.layout
    in_slice, out_slice = layout_to_in_out_slice(layout=layout_cfg.layout,
                                                 in_len=layout_cfg.in_len,
                                                 out_len=layout_cfg.out_len)
    test_mse_metrics = torchmetrics.MeanSquaredError().to(device)
    test_mae_metrics = torchmetrics.MeanAbsoluteError().to(device)
    test_data = torch.load(test_data_path)
    counter = 0
    with torch.no_grad():
        for batch in test_data:
            sst_seq, nino_target = batch
            # Add a trailing channel axis so the SST sequence matches the
            # model's expected input layout.
            data_seq = sst_seq.float().unsqueeze(-1).to(device)
            x = data_seq[in_slice]
            y = data_seq[out_slice]
            y_hat = model(x)
            test_mse_metrics(y_hat, y)
            test_mae_metrics(y_hat, y)
            counter += 1
            if counter >= NUM_TEST_ITER:
                break
    test_mse = test_mse_metrics.compute()
    test_mae = test_mae_metrics.compute()
    assert test_mse < 5E-4
    assert test_mae < 2E-2

def test_earthnet():
    """Run the pretrained EarthNet2021 aux model on cached test batches and
    check that masked MSE/MAE stay below the published reference bounds."""
    data_channels = 4
    pretrained_ckpt_name = "earthformer_earthnet2021.pt"
    test_data_name = "unittest_earthnet2021_data_bs1_idx0to31.pt"
    s3_download_unittest_data(data_name=test_data_name)
    test_data_path = os.path.join(test_data_dir, test_data_name)
    device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
    # Load pretrained model
    pretrained_cfg_path = os.path.join(cfg.root_dir, "scripts", "cuboid_transformer", "earthnet_w_meso", "earthformer_earthnet_v1.yaml")
    pretrained_cfg = OmegaConf.load(open(pretrained_cfg_path, "r"))
    model = _load_pretrained_model(pretrained_cfg=pretrained_cfg,
                                   pretrained_ckpt_name=pretrained_ckpt_name,
                                   model_type="CuboidTransformerAuxModel",
                                   device=device)
    # Test on EarthNet2021 test
    layout_cfg = pretrained_cfg.layout
    in_slice, out_slice = layout_to_in_out_slice(layout=layout_cfg.layout,
                                                 in_len=layout_cfg.in_len,
                                                 out_len=layout_cfg.out_len)
    test_mse_metrics = torchmetrics.MeanSquaredError().to(device)
    test_mae_metrics = torchmetrics.MeanAbsoluteError().to(device)
    test_data = torch.load(test_data_path)
    counter = 0
    with torch.no_grad():
        for batch in test_data:
            highresdynamic = batch["highresdynamic"].to(device)
            seq = highresdynamic[..., :data_channels]
            # mask from dataloader: 1 for mask and 0 for non-masked
            mask = highresdynamic[..., data_channels: data_channels + 1][out_slice]
            in_seq = seq[in_slice]
            target_seq = seq[out_slice]
            # process aux data: interpolate every auxiliary input to the full
            # (in_len + out_len, H, W) grid so they can be concatenated.
            highresstatic = batch["highresstatic"].to(device)  # (b c h w)
            mesodynamic = batch["mesodynamic"].to(device)  # (b t h w c)
            mesostatic = batch["mesostatic"].to(device)  # (b c h w)
            mesodynamic_interp = rearrange(mesodynamic,
                                           "b t h w c -> b c t h w")
            mesodynamic_interp = F.interpolate(mesodynamic_interp,
                                               size=(layout_cfg.in_len + layout_cfg.out_len,
                                                     layout_cfg.img_height,
                                                     layout_cfg.img_width),
                                               mode="nearest")
            highresstatic_interp = rearrange(highresstatic,
                                             "b c h w -> b c 1 h w")
            highresstatic_interp = F.interpolate(highresstatic_interp,
                                                 size=(layout_cfg.in_len + layout_cfg.out_len,
                                                       layout_cfg.img_height,
                                                       layout_cfg.img_width),
                                                 mode="nearest")
            mesostatic_interp = rearrange(mesostatic,
                                          "b c h w -> b c 1 h w")
            mesostatic_interp = F.interpolate(mesostatic_interp,
                                              size=(layout_cfg.in_len + layout_cfg.out_len,
                                                    layout_cfg.img_height,
                                                    layout_cfg.img_width),
                                              mode="nearest")
            aux_data = torch.cat((highresstatic_interp, mesodynamic_interp, mesostatic_interp),
                                 dim=1)
            aux_data = rearrange(aux_data,
                                 "b c t h w -> b t h w c")
            pred_seq = model(in_seq, aux_data[in_slice], aux_data[out_slice])
            # Only score non-masked pixels (mask == 0 means valid).
            test_mse_metrics(pred_seq * (1 - mask), target_seq * (1 - mask))
            test_mae_metrics(pred_seq * (1 - mask), target_seq * (1 - mask))
            counter += 1
            if counter >= NUM_TEST_ITER:
                break
    test_mse = test_mse_metrics.compute()
    test_mae = test_mae_metrics.compute()
    assert test_mse < 5E-4
    assert test_mae < 1E-2