├── .gitignore ├── .gitmodules ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── WD ├── WeatherBenchDataset.py ├── dataset_old.py ├── datasets.py ├── io.py ├── plotting.py ├── regridding.py └── utils.py ├── config ├── data.yaml ├── ds_format │ └── zarr.yaml ├── experiment │ ├── diffusion.yaml │ ├── diffusion_2csteps.yaml │ ├── diffusion_CosineAnnealing.yaml │ ├── diffusion_MSE_Loss.yaml │ ├── diffusion_MSE_Loss_more_patient_deeper.yaml │ ├── diffusion_deeper.yaml │ ├── diffusion_fourcast.yaml │ ├── diffusion_more_patient_deeper.yaml │ ├── diffusion_t2m_3day_highres.yaml │ ├── diffusion_t2m_5day_highres.yaml │ ├── diffusion_t_850_1day.yaml │ ├── diffusion_t_850_2day.yaml │ ├── diffusion_t_850_3day.yaml │ ├── diffusion_t_850_3day_highres.yaml │ ├── diffusion_t_850_4day.yaml │ ├── diffusion_t_850_5day.yaml │ ├── diffusion_t_850_5day_highres.yaml │ ├── diffusion_wider.yaml │ ├── diffusion_z_500_1day.yaml │ ├── diffusion_z_500_2day.yaml │ ├── diffusion_z_500_3day.yaml │ ├── diffusion_z_500_3day_highres.yaml │ ├── diffusion_z_500_4day.yaml │ ├── diffusion_z_500_5day.yaml │ ├── diffusion_z_500_5day_highres.yaml │ ├── fourcastnet.yaml │ ├── iterative_diffusion_reduced_set.yaml │ ├── iterative_diffusion_t_850.yaml │ ├── iterative_diffusion_t_850_highres.yaml │ ├── iterative_diffusion_z_500.yaml │ ├── iterative_diffusion_z_500_highres.yaml │ ├── iterative_diffusion_z_500_t_850.yaml │ ├── unet.yaml │ ├── unet_highres.yaml │ ├── unet_highres_t2m.yaml │ └── unet_highres_t2m_3day.yaml ├── inference.yaml ├── paths │ └── default_paths.yaml ├── template │ ├── iterative_rasp_thuerey.yaml │ ├── iterative_reduced_set.yaml │ ├── iterative_t_850.yaml │ ├── iterative_t_850_highres.yaml │ ├── iterative_z_500.yaml │ ├── iterative_z_500_highres.yaml │ ├── iterative_z_500_t_850.yaml │ ├── rasp_thuerey_highres_t2m_3day.yaml │ ├── rasp_thuerey_highres_t2m_5day.yaml │ ├── rasp_thuerey_highres_t_850_3day.yaml │ ├── rasp_thuerey_highres_t_850_5day.yaml │ ├── rasp_thuerey_highres_z_500_3day.yaml │ ├── rasp_thuerey_highres_z_500_5day.yaml │ ├── rasp_thuerey_t_850_1day.yaml │ ├── rasp_thuerey_t_850_2day.yaml │ ├── rasp_thuerey_t_850_3day.yaml │ ├── rasp_thuerey_t_850_4day.yaml │ ├── rasp_thuerey_t_850_5day.yaml │ ├── rasp_thuerey_z_500_1day.yaml │ ├── rasp_thuerey_z_500_2day.yaml │ ├── rasp_thuerey_z_500_3day.yaml │ ├── rasp_thuerey_z_500_3day_2csteps.yaml │ ├── rasp_thuerey_z_500_4day.yaml │ └── rasp_thuerey_z_500_5day.yaml └── train.yaml ├── env_data.yml ├── env_eval.yml ├── env_model.yml ├── images ├── chronologic_timesteps.jpg ├── ensemble_condition.jpg ├── ensemble_predictions.jpg ├── ensemble_stats.jpg ├── ensemble_std.jpg ├── heatwave_predictions_step_0.png ├── heatwave_predictions_step_1.png ├── heatwave_predictions_step_10.png ├── heatwave_predictions_step_2.png ├── heatwave_predictions_step_20.png ├── heatwave_predictions_step_32.png ├── heatwave_predictions_step_5.png ├── heatwave_true_anomaly.png ├── performance_leadtime_version_0.jpg ├── performance_leadtime_version_1.jpg ├── performance_leadtime_version_2.jpg ├── performance_leadtime_version_3.jpg ├── performance_leadtime_version_4.jpg ├── predictions.jpg ├── spectra.png ├── spectra_no_entries.png ├── t_850_lowres.gif ├── timeseries.jpg └── z_500_lowres.gif ├── nb_ensemble_eval.ipynb ├── nb_heatwave.ipynb ├── nb_performance_over_lead_time.ipynb ├── nb_results.ipynb ├── nb_spectral_analysis.ipynb ├── nb_test_predictions.ipynb ├── pyproject.toml ├── s10_write_predictions_vae.py ├── s11_train_LFD.py ├── s12_write_predictions_LFD.py ├── 
s13_write_predictions_iterative.py ├── s14_very_long_iterative_run.py ├── s1_write_dataset.py ├── s2_train_conditional_pixel_diffusion.py ├── s3_write_predictions_conditional_pixel_diffusion.py ├── s4_train_val_test.py ├── s5_train_FourCastNet.py ├── s6_write_predictions_FourCastNet.py ├── s7_train_unet.py ├── s8_write_predictions_unet.py ├── s9_train_vae.py ├── submit_script_10_inference_vae.sh ├── submit_script_11_train_LFD.sh ├── submit_script_12_inference_LFD.sh ├── submit_script_13_inference_iterative.sh ├── submit_script_14_inference_iterative_very_long.sh ├── submit_script_1_dataset_creation.sh ├── submit_script_2_run_model.sh ├── submit_script_3_inference.sh ├── submit_script_4_eval_epoch.sh ├── submit_script_5_train_FourCastNet.sh ├── submit_script_6_inference_fourcastnet.sh ├── submit_script_7_train_unet.sh ├── submit_script_8_inference_UNet.sh └── submit_script_9_train_vae.sh /.gitignore: -------------------------------------------------------------------------------- 1 | # Editors 2 | .vscode/ 3 | .idea/ 4 | 5 | # Vagrant 6 | .vagrant/ 7 | 8 | # Mac/OSX 9 | .DS_Store 10 | 11 | # Windows 12 | Thumbs.db 13 | 14 | # Source for the following rules: https://raw.githubusercontent.com/github/gitignore/master/Python.gitignore 15 | # Byte-compiled / optimized / DLL files 16 | __pycache__/ 17 | *.py[cod] 18 | *$py.class 19 | 20 | # C extensions 21 | *.so 22 | 23 | # Distribution / packaging 24 | .Python 25 | build/ 26 | develop-eggs/ 27 | dist/ 28 | downloads/ 29 | eggs/ 30 | .eggs/ 31 | lib/ 32 | lib64/ 33 | parts/ 34 | sdist/ 35 | var/ 36 | wheels/ 37 | *.egg-info/ 38 | .installed.cfg 39 | *.egg 40 | MANIFEST 41 | 42 | # PyInstaller 43 | # Usually these files are written by a python script from a template 44 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
45 | *.manifest 46 | *.spec 47 | 48 | # Installer logs 49 | pip-log.txt 50 | pip-delete-this-directory.txt 51 | 52 | # Unit test / coverage reports 53 | htmlcov/ 54 | .tox/ 55 | .nox/ 56 | .coverage 57 | .coverage.* 58 | .cache 59 | nosetests.xml 60 | coverage.xml 61 | *.cover 62 | .hypothesis/ 63 | .pytest_cache/ 64 | 65 | # Translations 66 | *.mo 67 | *.pot 68 | 69 | # Django stuff: 70 | *.log 71 | local_settings.py 72 | db.sqlite3 73 | 74 | # Flask stuff: 75 | instance/ 76 | .webassets-cache 77 | 78 | # Scrapy stuff: 79 | .scrapy 80 | 81 | # Sphinx documentation 82 | docs/_build/ 83 | 84 | # PyBuilder 85 | target/ 86 | 87 | # Jupyter Notebook 88 | .ipynb_checkpoints 89 | 90 | # IPython 91 | profile_default/ 92 | ipython_config.py 93 | 94 | # pyenv 95 | .python-version 96 | 97 | # celery beat schedule file 98 | celerybeat-schedule 99 | 100 | # SageMath parsed files 101 | *.sage.py 102 | 103 | # Environments 104 | .env 105 | .venv 106 | env/ 107 | venv/ 108 | ENV/ 109 | env.bak/ 110 | venv.bak/ 111 | 112 | # Spyder project settings 113 | .spyderproject 114 | .spyproject 115 | 116 | # Rope project settings 117 | .ropeproject 118 | 119 | # mkdocs documentation 120 | /site 121 | 122 | # mypy 123 | .mypy_cache/ 124 | .dmypy.json 125 | dmypy.json 126 | 127 | # ML files 128 | 129 | lightning_logs 130 | lightning_logs/ 131 | 132 | 133 | # log files 134 | *.err 135 | *.out 136 | 137 | # pycache 138 | 139 | **/__pycache__/ -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "benchmark"] 2 | path = benchmark 3 | url = https://github.com/melioristic/benchmark 4 | [submodule "dm_zoo"] 5 | path = dm_zoo 6 | url = https://github.com/melioristic/dm_zoo 7 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/psf/black 3 | rev: 23.3.0 4 | hooks: 5 | - id: black 6 | - repo: https://github.com/pycqa/flake8 7 | rev: 6.0.0 8 | hooks: 9 | - id: flake8 -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # WeatherDiff 2 | 3 | ![image](https://github.com/ECMWFCode4Earth/diffusion-models-for-weather-prediction/blob/main/images/z_500_lowres.gif) 4 | 5 | ## Description 6 | This Code4Earth challenge explores the potential of Diffusion Models for weather prediction; more specifically, we test them on the [WeatherBench](https://github.com/pangeo-data/WeatherBench) benchmark data set. 7 | 8 | This repository contains functions to benchmark the diffusion models developed in [diffusion-models-for-weather-prediction](https://github.com/ECMWFCode4Earth/diffusion-models-for-weather-prediction). It builds on existing code from [WeatherBench](https://github.com/pangeo-data/WeatherBench). 9 | 10 | ## Roadmap 11 | This repository is part of an [ECMWF Code4Earth Project](https://github.com/ECMWFCode4Earth/diffusion-models-for-weather-prediction), which took place between May 1, 2023 and September 20, 2023. 12 | 13 | ## Installation 14 | 15 | ### Repository: 16 | - The repository is formatted with the Black formatter and also uses pre-commit 17 | - make sure the pre-commit package is installed (`pip install pre-commit`) 18 | - run `pre-commit install` to set up the git hook scripts. 19 | 20 | - The main repository has two submodules that can be installed as follows: 21 | 22 | Clone the main repository. 23 | Make sure you have access to the submodules (`benchmark` and `dm_zoo`). Then: 24 | 25 | 1. `git submodule init` 26 | 2. `git submodule update` 27 | 28 | ### Data download: 29 | - Our code requires the WeatherBench data to be downloaded as described in [this repository](https://github.com/pangeo-data/WeatherBench/tree/master). We tested the 5.625° and 2.8125° resolutions. 30 | 31 | ### Setup: 32 | - Setting up conda environments. We create three environments; the requirements for each are contained in a .yml file. Run `conda env create -f <file>` to create each environment. 33 | - `env_data.yml` creates an environment `WD_data` that is used to preprocess the data. 34 | - `env_model.yml` creates an environment `WD_model` that is used to train and make predictions with machine learning models. 35 | - `env_eval.yml` creates an environment `WD_eval` with packages required to analyse and plot results. 36 | - The workflow requires paths to be set for a few different directories. These paths are specified in the `config/paths/` directory and cover the following choices: 37 | - `dir_WeatherBench`: Directory the WeatherBench dataset was downloaded to. 38 | - `dir_PreprocessedDatasets`: Preprocessed datasets get stored here. 39 | - `dir_SavedModels`: Checkpoints and tensorboard logs are stored here. 40 | - `dir_HydraConfigs`: When running jobs, the selected configuration files are logged here. 41 | - `dir_ModelOutput`: Predictions with the ML models get saved here. 42 | 43 | ### Workflow: 44 | The workflow to train and predict with the diffusion models is as follows: 45 | - Dataset creation: Creating a preprocessed dataset from the raw WeatherBench dataset. This can be obtained with `s1_write_dataset.py` and `submit_script_1_dataset_creation.sh` (if submitting jobscripts is required). 46 | - configurations for the dataset creation process and other parameter choices in the process are managed with [hydra](https://hydra.cc). The name of a configuration ("template") has to be selected when running the script, e.g. `python s1_write_dataset.py +template=<template_name>`.
The corresponding file `<template_name>.yaml` should be contained in the `config/template` directory. 47 | - preprocessed datasets get saved as [zarr](https://zarr.readthedocs.io/en/stable/) files in the `dir_PreprocessedDatasets/` directory. 48 | - Training a model: Select the appropriate script (e.g. `s2_train_conditional_pixel_diffusion.py`). Configuration choices are made in the `config/train.yaml` file, and experiment-specific choices (model architecture, dataset, ...) are listed in the files in the `/config/experiment` directory. An experiment name has to be given, analogously to the dataset creation. A model can for example be trained by `python s2_train_conditional_pixel_diffusion.py +experiment=<experiment_name>`. The selected configuration, including the experiment, gets logged to `dir_HydraConfigs`. 49 | - The training progress can be monitored with tensorboard. 50 | - Once the training is finished, predictions can be written with the trained models. Selecting an appropriate script (e.g. `s3_write_predictions_conditional_pixel_diffusion.py`), predictions can be made as follows: `python s3_write_predictions_conditional_pixel_diffusion.py +data.template=<template_name> +experiment=<experiment_name> +model_name=<model_name> n_ensemble_members=<n_ensemble_members>`. Here `<template_name>` and `<experiment_name>` are the choices made when creating the employed dataset and training the model. By default, `<model_name>` should be the time that the model run was started. To find this information, have a look at the logged configurations for training in `dir_HydraConfigs/training`. As the name suggests, `<n_ensemble_members>` determines how many ensemble predictions should be produced simultaneously. The predictions and ground truth get rescaled and saved in `.nc` files in `dir_ModelOutput`. They can be opened with [xarray](https://docs.xarray.dev/en/stable) (see the example at the end of this README), and contain data of the following dimensionality: `(ensemble_member, init_time, lead_time, lat, lon)`. `init_time` is the "starting/initialization" time of the forecast, and `lead_time` specifies how far one wants to predict into the future. 51 | 52 | ## Contributing 53 | Guidelines for contributions will be added in the future. 54 | 55 | ## Authors and acknowledgment 56 | Participants: 57 | - [Mohit Anand](https://github.com/melioristic) 58 | - [Jonathan Wider](https://github.com/jonathanwider) 59 | 60 | Mentors: 61 | - [Jesper Dramsch](https://github.com/JesperDramsch) 62 | - [Florian Pinault](https://github.com/floriankrb) 63 | 64 | ## License 65 | This project is licensed under the [Apache 2.0 License](https://github.com/melioristic/benchmark/blob/main/LICENSE). The submodules contain code from external sources and are subject to the licenses included in these submodules. 66 | 67 | ## Project status 68 | Code4Earth project finished.
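## Example: opening prediction files

A minimal sketch of how the prediction files described in the Workflow section could be inspected with xarray; the file name `predictions.nc` and the variable name `z_500` are placeholders and depend on the chosen template and experiment.

```python
import xarray as xr

# Hypothetical prediction file written to dir_ModelOutput.
ds = xr.open_dataset("predictions.nc")

# Variables have dimensions (ensemble_member, init_time, lead_time, lat, lon).
z500 = ds["z_500"]

# Ensemble mean and spread for the first initialization time.
z500_mean = z500.isel(init_time=0).mean("ensemble_member")
z500_std = z500.isel(init_time=0).std("ensemble_member")
print(z500_mean, z500_std)
```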
69 | 70 | 71 | -------------------------------------------------------------------------------- /WD/WeatherBenchDataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import os 3 | 4 | # import kornia 5 | import kornia.augmentation as KA 6 | import torch 7 | import xarray as xr 8 | 9 | 10 | class TestDataset(Dataset): 11 | """Dataset returning images in a folder.""" 12 | 13 | def __init__(self, root_dir, transforms=None): 14 | self.root_dir = root_dir 15 | self.transforms = transforms 16 | 17 | z500 = xr.open_mfdataset( 18 | os.path.join(root_dir, "geopotential_500/*.nc"), 19 | combine="by_coords", 20 | chunks=100, 21 | ).z 22 | 23 | zmin = z500.min().compute() 24 | zmax = z500.max().compute() 25 | self.normalized_z500 = (z500 - zmin) / (zmax - zmin) 26 | 27 | # set up transforms 28 | if self.transforms is not None: 29 | self.input_T = KA.container.AugmentationSequential( 30 | *self.transforms, data_keys=["input"], same_on_batch=False 31 | ) 32 | 33 | def __len__(self): 34 | return len(self.normalized_z500.time) 35 | 36 | def __getitem__(self, idx): 37 | if torch.is_tensor(idx): 38 | idx = idx.tolist() 39 | 40 | image = self.normalized_z500.isel(time=idx).values[None, ...] 41 | 42 | if self.transforms is not None: 43 | image = self.input_T(image)[0] 44 | 45 | return torch.tensor(image) 46 | 47 | 48 | def single_torch_file_from_dataset(root_dir): 49 | z500 = xr.open_mfdataset( 50 | os.path.join(root_dir, "geopotential_500/*.nc"), 51 | combine="by_coords", 52 | ).z 53 | zmin = z500.min() 54 | zmax = z500.max() 55 | normalized_z500 = torch.tensor(((z500 - zmin) / (zmax - zmin)).values)[ 56 | :, None, ... 57 | ] 58 | torch.save( 59 | normalized_z500, 60 | os.path.join(root_dir, "complete_dataset.pt"), 61 | ) 62 | 63 | 64 | class SingleDataset(Dataset): 65 | """Dataset returning images in a folder.""" 66 | 67 | def __init__(self, root_dir, transforms=None): 68 | self.root_dir = root_dir 69 | self.transforms = transforms 70 | 71 | self.data = torch.load(os.path.join(root_dir, "complete_dataset.pt")) 72 | 73 | def __len__(self): 74 | return len(self.data) 75 | 76 | def __getitem__(self, idx): 77 | if torch.is_tensor(idx): 78 | idx = idx.tolist() 79 | 80 | image = self.data[idx] 81 | 82 | if self.transforms is not None: 83 | image = self.input_T(image)[0] 84 | 85 | return torch.tensor(image) 86 | -------------------------------------------------------------------------------- /WD/dataset_old.py: -------------------------------------------------------------------------------- 1 | import xarray as xr 2 | from datetime import datetime 3 | import os 4 | from typing import Union, Tuple, Dict, List 5 | import torch 6 | from torch.utils.data import Dataset 7 | from WD.utils import ( 8 | inverse_transform_precipitation, 9 | generate_uid, 10 | ) 11 | 12 | from WD.io import write_config, load_config 13 | 14 | 15 | def write_conditional_datasets(config_path: str) -> None: 16 | """Save a preprocessed version of the WeatherBench dataset into a single file. 17 | 18 | Args: 19 | root_dir (str): Directory in which the WeatherBench Dataset is stored. 20 | train_limits (Tuple[datetime, datetime]): Start and end date of the training set. 21 | test_limits (Tuple[datetime, datetime]): Start and end date of the test set. 22 | conditioning_variables (Dict[str, str]): Variables we want to use for conditioning. Dict containing the filename as keys and the variables as values. 
23 | output_variables (Dict[str, str]): Variables we want the model to forecast. Dict containing the filename as keys and the variables as values. 24 | conditioning_timesteps (List[int]): List of timesteps that should be used as conditioning information. If None, no conditioning will be applied. 25 | lead_time (Union[int,None]): The lead time at which we want to produce the prediction. In units of delta_t. 26 | validation_limits (Union[Tuple[datetime, datetime], None], optional): Start and end date of the training set. Can be None, if no validation set is used.. Defaults to None. 27 | spatial_resolution (str, optional): The spatial resolution of the dataset we want to load. Defaults to "5.625deg". 28 | delta_t (int, optional): Interval between consecutive timesteps in hours. Defaults to 6. 29 | out_dir (Union[None, str], optional): Directory to save the datasets in, if None use the same as the input directory. Defaults to None. 30 | out_filename (Union[None, str], optional): Name to save the dataset as, if not provided use a default name. Defaults to None. 31 | """ # noqa: E501 32 | 33 | # root_dir: str, 34 | # train_limits: Tuple[datetime, datetime], 35 | # test_limits: Tuple[datetime, datetime], 36 | # conditioning_variables: List[int], 37 | # output_variables: List[int], 38 | # conditioning_timesteps: List[int], 39 | # lead_time: Union[int, None], 40 | # validation_limits: Union[Tuple[datetime, datetime], None] = None, 41 | # spatial_resolution: str = "5.625deg", 42 | # delta_t: int = 6, 43 | # out_dir: Union[None, str] = None, 44 | # out_filename: Union[None, str] = None, 45 | 46 | # Read the config file 47 | 48 | config = load_config(config_path) 49 | 50 | from_train = config.exp_data.train.time_start 51 | to_train = config.exp_data.train.time_end 52 | 53 | from_val = config.exp_data.val.time_start 54 | to_val = config.exp_data.val.time_end 55 | 56 | from_test = config.exp_data.test.time_start 57 | to_test = config.exp_data.test.time_end 58 | 59 | root_dir = config.file_structure.dir_WeatherBench 60 | out_dir = config.file_structure.dir_pytorch_data 61 | train_limits = (from_train, to_train) 62 | validation_limits = (from_val, to_val) 63 | test_limits = (from_test, to_test) 64 | conditioning_variables = config.data_specs.conditioning_vars 65 | output_variables = config.data_specs.output_vars 66 | conditioning_timesteps = config.data_specs.conditioning_time_step 67 | lead_time = config.data_specs.lead_t 68 | spatial_resolution = config.data_specs.spatial_resolution 69 | delta_t = config.data_specs.delta_t 70 | constant_vars = config.data_specs.constant_vars 71 | 72 | config.ds_id = generate_uid() 73 | out_filename = f"{config.ds_id}" 74 | 75 | # load all files: 76 | output_datasets = [] 77 | conditioning_datasets = [] 78 | 79 | print("Open files.") 80 | # output files: 81 | for foldername, var_config in output_variables.items(): 82 | path = os.path.join( 83 | root_dir, 84 | foldername, 85 | "*_{}.nc".format(spatial_resolution), 86 | ) 87 | ds = xr.open_mfdataset(path) 88 | 89 | assert len(ds.keys()) == 1 90 | varname = list(ds.keys())[0] 91 | 92 | if varname == "tp": 93 | ds = ds.rolling(time=6).sum() # take 6 hour average 94 | ds["tp"] = inverse_transform_precipitation(ds) 95 | 96 | # extract desired pressure levels: 97 | if var_config.levels is not None: 98 | ds = ds.sel({"level": var_config.levels}) 99 | 100 | grouped = ds.groupby("level") 101 | group_indices = grouped.groups 102 | datasets = [] 103 | for ( 104 | group_name, 105 | group_index, 106 | ) in group_indices.items(): 107 | 
group_data = ds.isel(level=group_index) 108 | renamed_vars = {} 109 | for ( 110 | var_name, 111 | var_data, 112 | ) in group_data.data_vars.items(): 113 | new_var_name = f"{var_name}_{group_name}" 114 | renamed_vars[new_var_name] = var_data 115 | group_ds = xr.Dataset(renamed_vars).drop_vars("level") 116 | datasets.append(group_ds) 117 | output_datasets.extend(datasets) 118 | else: 119 | if "level" in ds.var(): 120 | ds = ds.drop_vars("level") 121 | output_datasets.append(ds) 122 | output_dataset = xr.merge(output_datasets) 123 | 124 | # conditioning files: 125 | for ( 126 | foldername, 127 | var_config, 128 | ) in conditioning_variables.items(): 129 | path = os.path.join( 130 | root_dir, 131 | foldername, 132 | "*_{}.nc".format(spatial_resolution), 133 | ) 134 | print(foldername, path) 135 | ds = xr.open_mfdataset(path) 136 | 137 | assert len(ds.keys()) == 1 138 | varname = list(ds.keys())[0] 139 | 140 | if varname == "tp": 141 | ds = ds.rolling(time=6).sum() # take 6 hour average 142 | ds["tp"] = inverse_transform_precipitation(ds) 143 | 144 | # extract desired pressure levels: 145 | if var_config.levels is not None: 146 | ds = ds.sel({"level": var_config.levels}) 147 | grouped = ds.groupby("level") 148 | group_indices = grouped.groups 149 | datasets = [] 150 | for ( 151 | group_name, 152 | group_index, 153 | ) in group_indices.items(): 154 | group_data = ds.isel(level=group_index) 155 | renamed_vars = {} 156 | for ( 157 | var_name, 158 | var_data, 159 | ) in group_data.data_vars.items(): 160 | new_var_name = f"{var_name}_{group_name}" 161 | renamed_vars[new_var_name] = var_data 162 | group_ds = xr.Dataset(renamed_vars).drop_vars("level") 163 | datasets.append(group_ds) 164 | conditioning_datasets.extend(datasets) 165 | else: 166 | conditioning_datasets.append(ds) 167 | 168 | # append constant fields: 169 | ds_constants = xr.open_dataset( 170 | os.path.join( 171 | root_dir, 172 | "constants", 173 | "constants_{}.nc".format(spatial_resolution), 174 | ) 175 | ) # "/data/compoundx/WeatherBench/constants/constants_5.625deg.nc" 176 | 177 | if constant_vars is not None: 178 | for cv in constant_vars: 179 | ds = ds_constants[ 180 | [ 181 | cv, 182 | ] 183 | ] 184 | conditioning_datasets.append(ds) 185 | conditioning_dataset = xr.merge(conditioning_datasets) 186 | 187 | print( 188 | "Number of conditioning variables:", 189 | len(list(conditioning_dataset.keys())), 190 | ) 191 | 192 | # pre-processing: 193 | 194 | # filter to temporal resolution delta_t 195 | output_dataset = output_dataset.resample( 196 | time="{}H".format(delta_t) 197 | ).nearest() 198 | conditioning_dataset = conditioning_dataset.resample( 199 | time="{}H".format(delta_t) 200 | ).nearest() 201 | 202 | # calculate training set maxima and minima - will need these to 203 | # rescale the data to [0,1] range. 204 | print("Compute train set minima and maxima.") 205 | 206 | # use these to rescale the datasets. 
207 | print("Rescale datasets") 208 | conditioning_dataset = rescale_dataset(conditioning_dataset, train_limits) 209 | output_dataset = rescale_dataset(output_dataset, train_limits) 210 | 211 | print("Split into train, test, validation sets.") 212 | 213 | # create datasets for train, test and validation 214 | train_targets, _ = prepare_datasets( 215 | output_dataset.sel({"time": slice(*train_limits)}), 216 | lead_time=lead_time, 217 | conditioning_timesteps=conditioning_timesteps, 218 | ) 219 | _, train_inputs = prepare_datasets( 220 | conditioning_dataset.sel({"time": slice(*train_limits)}), 221 | lead_time=lead_time, 222 | conditioning_timesteps=conditioning_timesteps, 223 | ) 224 | 225 | test_targets, _ = prepare_datasets( 226 | output_dataset.sel({"time": slice(*test_limits)}), 227 | lead_time=lead_time, 228 | conditioning_timesteps=conditioning_timesteps, 229 | ) 230 | _, test_inputs = prepare_datasets( 231 | conditioning_dataset.sel({"time": slice(*test_limits)}), 232 | lead_time=lead_time, 233 | conditioning_timesteps=conditioning_timesteps, 234 | ) 235 | 236 | assert bool(train_targets.to_array().notnull().all().any()), ( 237 | "train_targets data set contains missing values," 238 | " possibly because of the precipitation" 239 | " computation." 240 | ) 241 | assert bool(train_inputs.to_array().notnull().all().any()), ( 242 | "train_inputs data set contains missing values," 243 | " possibly because of the precipitation" 244 | " computation." 245 | ) 246 | assert bool(test_targets.to_array().notnull().all().any()), ( 247 | "test_targets data set contains missing values," 248 | " possibly because of the precipitation" 249 | " computation." 250 | ) 251 | assert bool(test_inputs.to_array().notnull().all().any()), ( 252 | "test_inputs data set contains missing values," 253 | " possibly because of the precipitation" 254 | " computation." 255 | ) 256 | 257 | if validation_limits is not None: 258 | validation_targets, _ = prepare_datasets( 259 | output_dataset.sel({"time": slice(*validation_limits)}), 260 | lead_time=lead_time, 261 | conditioning_timesteps=conditioning_timesteps, 262 | ) 263 | _, validation_inputs = prepare_datasets( 264 | conditioning_dataset.sel({"time": slice(*validation_limits)}), 265 | lead_time=lead_time, 266 | conditioning_timesteps=conditioning_timesteps, 267 | ) 268 | assert bool(validation_targets.to_array().notnull().all().any()), ( 269 | "validation_targets data set contains missing" 270 | " values, possibly because of the precipitation" 271 | " computation." 272 | ) 273 | assert bool(validation_inputs.to_array().notnull().all().any()), ( 274 | "validation_inputs data set contains missing" 275 | " values, possibly because of the precipitation" 276 | " computation." 
277 | ) 278 | 279 | # write the files: 280 | if out_filename is None: 281 | out_filename = "ds" 282 | 283 | print("write output") 284 | torch.save( 285 | { 286 | "inputs": torch.tensor( 287 | xr.Dataset.to_array(train_inputs).values 288 | ).transpose(1, 0), 289 | "targets": torch.tensor( 290 | xr.Dataset.to_array(train_targets).values 291 | ).transpose(1, 0), 292 | }, 293 | os.path.join(out_dir, "{}_train.pt".format(out_filename)), 294 | ) 295 | torch.save( 296 | { 297 | "inputs": torch.tensor( 298 | xr.Dataset.to_array(test_inputs).values 299 | ).transpose(1, 0), 300 | "targets": torch.tensor( 301 | xr.Dataset.to_array(test_targets).values 302 | ).transpose(1, 0), 303 | }, 304 | os.path.join(out_dir, "{}_test.pt".format(out_filename)), 305 | ) 306 | # if we want a validation set, create one: 307 | if validation_limits is not None: 308 | torch.save( 309 | { 310 | "inputs": torch.tensor( 311 | xr.Dataset.to_array(validation_inputs).values 312 | ).transpose(1, 0), 313 | "targets": torch.tensor( 314 | xr.Dataset.to_array(validation_targets).values 315 | ).transpose(1, 0), 316 | }, 317 | os.path.join(out_dir, "{}_val.pt".format(out_filename)), 318 | ) 319 | 320 | write_config(config) 321 | 322 | 323 | def prepare_datasets( 324 | ds: xr.DataArray, 325 | lead_time: int, 326 | conditioning_timesteps: List[int], 327 | ) -> Tuple[xr.Dataset, xr.Dataset]: 328 | """Given a dataset, a lead time and conditioning timesteps, which we want to use for conditioning the prediction, 329 | return the one dataset that contains all valid target data and 330 | one dataset that contains the combined conditioning information for the target data. 331 | 332 | Args: 333 | ds (xr.DataArray): A dataset we want to work with - ideally already restricted to train / test / validation set. 334 | lead_time (int): Lead time at which we want to make predictions, in units of delta_t. 335 | conditioning_timesteps (List[int]): Timesteps we want to use in the conditioning, in units of delta_t, e.g. 0 for current time step. 336 | 337 | Returns: 338 | Tuple[xr.Dataset, xr.Dataset]: Target and Conditioning datasets. 
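Example (illustrative sketch; assumes ds is an already rescaled dataset restricted to one of the train/val/test periods, with delta_t-spaced timesteps):

    # Condition on the previous and the current state, predict 12 steps ahead.
    targets, conditions = prepare_datasets(ds, lead_time=12, conditioning_timesteps=[-1, 0])
    # Both datasets share the same "time" axis; `conditions` holds one renamed copy of
    # each time-dependent variable per conditioning step, e.g. "z_500_-1" and "z_500_0".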
339 | """ # noqa: E501 340 | 341 | ds_target = ds.isel( 342 | { 343 | "time": slice( 344 | -(min(conditioning_timesteps)) + lead_time, 345 | None, 346 | ) 347 | } 348 | ) 349 | ds_target["time"] = ds.isel( 350 | {"time": slice(-(min(conditioning_timesteps)), -lead_time)} 351 | )["time"] 352 | 353 | dsts_condition = [] 354 | 355 | conditioning_vars_constant = [ 356 | k for k in ds.keys() if "time" not in ds[k].coords 357 | ] 358 | conditioning_vars_nonconstant = [ 359 | k for k in ds.keys() if "time" in ds[k].coords 360 | ] 361 | 362 | ds_conditional_nconst = ds[conditioning_vars_nonconstant] 363 | ds_conditional_const = ds[conditioning_vars_constant] 364 | 365 | for t_c in conditioning_timesteps: 366 | ds_c = ds_conditional_nconst.isel( 367 | { 368 | "time": slice( 369 | t_c - (min(conditioning_timesteps)), 370 | -lead_time + t_c, 371 | ) 372 | } 373 | ) 374 | ds_c["time"] = ds_conditional_nconst.isel( 375 | { 376 | "time": slice( 377 | -(min(conditioning_timesteps)), 378 | -lead_time, 379 | ) 380 | } 381 | )["time"] 382 | keys = ds_c.keys() 383 | values = ["{}_{}".format(k, t_c) for k in keys] 384 | renaming_dict = dict(zip(keys, values)) 385 | ds_c = xr.Dataset.rename_vars(ds_c, name_dict=renaming_dict) 386 | dsts_condition.append(ds_c) 387 | 388 | # add time dimension to constant fields: 389 | ds_conditional_const = ds_conditional_const.expand_dims( 390 | time=ds.isel( 391 | { 392 | "time": slice( 393 | -(min(conditioning_timesteps)), 394 | -lead_time, 395 | ) 396 | } 397 | )["time"] 398 | ) 399 | 400 | ds_condition = xr.merge(dsts_condition) 401 | ds_condition = xr.merge([ds_condition, ds_conditional_const]) 402 | 403 | return ds_target, ds_condition 404 | 405 | 406 | def write_datasets( 407 | root_dir: str, 408 | train_limits: Tuple[datetime, datetime], 409 | test_limits: Tuple[datetime, datetime], 410 | output_variables: Dict[str, str], 411 | validation_limits: Union[Tuple[datetime, datetime], None] = None, 412 | spatial_resolution: str = "5.625deg", 413 | delta_t: int = 6, 414 | out_dir: Union[None, str] = None, 415 | out_filename: Union[None, str] = None, 416 | ) -> None: 417 | """Save a preprocessed version of the WeatherBench dataset into a single file. 418 | 419 | Args: 420 | root_dir (str): Directory in which the WeatherBench Dataset is stored. 421 | train_limits (Tuple[datetime, datetime]): Start and end date of the training set. 422 | test_limits (Tuple[datetime, datetime]): Start and end date of the test set. 423 | output_variables (Dict[str, str]): Variables we want the model to forecast. Dict containing the filename as keys and the variables as values. 424 | validation_limits (Union[Tuple[datetime, datetime], None], optional): Start and end date of the training set. Can be None, if no validation set is used.. Defaults to None. 425 | spatial_resolution (str, optional): The spatial resolution of the dataset we want to load. Defaults to "5.625deg". 426 | delta_t (int, optional): Interval between consecutive timesteps in hours. Defaults to 6. 427 | out_dir (Union[None, str], optional): Directory to save the datasets in, if None use the same as the input directory. Defaults to None. 428 | out_filename (Union[None, str], optional): Name to save the dataset as, if not provided use a default name. Defaults to None. 
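Example (illustrative sketch; the paths, dates and variable mapping are assumptions):

    write_datasets(
        root_dir="/data/WeatherBench",
        train_limits=(datetime(1979, 1, 1), datetime(2015, 12, 31)),
        test_limits=(datetime(2017, 1, 1), datetime(2018, 12, 31)),
        output_variables={"geopotential_500": "z"},
        validation_limits=(datetime(2016, 1, 1), datetime(2016, 12, 31)),
        spatial_resolution="5.625deg",
        delta_t=6,
        out_dir="/data/preprocessed",
    )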
429 | """ # noqa: E501 430 | 431 | # load all files: 432 | output_datasets = [] 433 | 434 | # output files: 435 | for foldername, varname in output_variables.items(): 436 | path = os.path.join( 437 | root_dir, 438 | foldername, 439 | "*_{}.nc".format(spatial_resolution), 440 | ) 441 | ds = xr.open_mfdataset(path) 442 | if varname == "tp": 443 | ds = ds.rolling(time=6).sum() # take 6 hour average 444 | output_datasets.append(ds) 445 | 446 | output_dataset = xr.merge(output_datasets) 447 | 448 | # pre-processing: 449 | 450 | # filter to temporal resolution delta_t 451 | output_dataset = output_dataset.resample( 452 | time="{}H".format(delta_t) 453 | ).nearest() 454 | 455 | # calculate training set maxima and minima - 456 | # will need these to rescale the data to [0,1] range. 457 | print("Compute train set minima and maxima.") 458 | train_output_set_max = ( 459 | output_dataset.sel({"time": slice(*train_limits)}).max().compute() 460 | ) 461 | train_output_set_min = ( 462 | output_dataset.sel({"time": slice(*train_limits)}).min().compute() 463 | ) 464 | 465 | # use these to rescale the datasets. 466 | print("rescale datasets") 467 | output_dataset = (output_dataset - train_output_set_min) / ( 468 | train_output_set_max - train_output_set_min 469 | ) 470 | 471 | # create datasets for train, test and validation 472 | train_targets = output_dataset.sel({"time": slice(*train_limits)}) 473 | assert bool( 474 | train_targets.to_array().notnull().all().any() 475 | ), ( # assert that there are no missing values in the training set: 476 | "Training data set contains missing values," 477 | " possibly because of the precipitation" 478 | " computation." 479 | ) 480 | 481 | test_targets = output_dataset.sel({"time": slice(*test_limits)}) 482 | assert bool( 483 | test_targets.to_array().notnull().all().any() 484 | ), ( # assert that there are no missing values in the test set: 485 | "Training data set contains missing values," 486 | " possibly because of the precipitation" 487 | " computation." 488 | ) 489 | 490 | if validation_limits is not None: 491 | validation_targets = output_dataset.sel( 492 | {"time": slice(*validation_limits)} 493 | ) 494 | assert bool( 495 | validation_targets.to_array().notnull().all().any() 496 | ), ( # assert that there are no missing values in the val set. 497 | "Validation data set contains missing values," 498 | " possibly because of the precipitation" 499 | " computation." 
500 | ) 501 | 502 | # write the files: 503 | if out_filename is None: 504 | out_filename = "ds" 505 | 506 | print("write output") 507 | torch.save( 508 | torch.tensor(xr.Dataset.to_array(train_targets).values).transpose( 509 | 1, 0 510 | ), 511 | os.path.join(out_dir, "{}_train.pt".format(out_filename)), 512 | ) 513 | torch.save( 514 | torch.tensor(xr.Dataset.to_array(test_targets).values).transpose(1, 0), 515 | os.path.join(out_dir, "{}_test.pt".format(out_filename)), 516 | ) 517 | # if we want a validation set, create one: 518 | if validation_limits is not None: 519 | torch.save( 520 | torch.tensor( 521 | xr.Dataset.to_array(validation_targets).values 522 | ).transpose(1, 0), 523 | os.path.join(out_dir, "{}_val.pt".format(out_filename)), 524 | ) 525 | 526 | 527 | class Conditional_Dataset(Dataset): 528 | """Dataset when using past steps as conditioning information 529 | and predicting into the future.""" 530 | 531 | def __init__(self, path): 532 | self.path = path 533 | data = torch.load(self.path) 534 | 535 | self.inputs = data["inputs"] 536 | self.targets = data["targets"] 537 | 538 | def __len__(self): 539 | return len(self.inputs) 540 | 541 | def __getitem__(self, idx): 542 | if torch.is_tensor(idx): 543 | idx = idx.tolist() 544 | 545 | input = self.inputs[idx] 546 | target = self.targets[idx] 547 | 548 | return input, target 549 | 550 | 551 | class Unconditional_Dataset(Dataset): 552 | """Dataset for unconditional image generation.""" 553 | 554 | def __init__(self, path): 555 | self.path = path 556 | self.data = torch.load(self.path) 557 | 558 | def __len__(self): 559 | return len(self.data) 560 | 561 | def __getitem__(self, idx): 562 | if torch.is_tensor(idx): 563 | idx = idx.tolist() 564 | 565 | sample = self.data[idx] 566 | 567 | return sample 568 | 569 | 570 | def rescale_dataset(dataset, limits): 571 | print("Compute minima and maxima.") 572 | set_max = dataset.sel({"time": slice(*limits)}).max().compute() 573 | set_min = dataset.sel({"time": slice(*limits)}).min().compute() 574 | 575 | dataset = (dataset - set_min) / (set_max - set_min) 576 | 577 | return dataset 578 | -------------------------------------------------------------------------------- /WD/io.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from omegaconf import DictConfig 4 | 5 | import xarray as xr 6 | import numpy as np 7 | import pandas as pd 8 | 9 | import torch 10 | from datetime import datetime, timedelta 11 | 12 | from WD.utils import ( 13 | get_git_revision_hash, 14 | inverse_transform_precipitation, 15 | ) 16 | import os 17 | import yaml 18 | # from munch import Munch 19 | 20 | 21 | def create_xr(data: np.array, var: List, data_description: str): 22 | start_dt = datetime(2000, 1, 1) 23 | lon_deg = 360 / data.shape[3] 24 | lat_deg = 180 / data.shape[2] 25 | 26 | assert lon_deg == lat_deg 27 | deg = lon_deg 28 | lon = np.arange(0, 360, deg) + deg / 2 29 | lat = np.arange(-90, 90, deg) + deg / 2 30 | time_series = [start_dt + timedelta(days=i) for i in range(data.shape[0])] 31 | ds = xr.Dataset( 32 | data_vars={ 33 | var[i]: ( 34 | ["time", "lat", "lon"], 35 | data[:, i, :, :], 36 | ) 37 | for i in range(data.shape[1]) 38 | }, 39 | coords=dict( 40 | lon=(["lon"], lon), 41 | lat=(["lat"], lat), 42 | time=time_series, 43 | reference_time=start_dt, 44 | ), 45 | attrs=dict(description=data_description), 46 | ) 47 | 48 | return ds 49 | 50 | 51 | def write_config(config): 52 | if type(config) == dict: 53 | config = Munch(config) 54 | 55 | 
dm_zoo, weather_diff = get_git_revision_hash() 56 | 57 | config.git_rev_parse = ( 58 | Munch() 59 | ) # need to initialize this subthing first, otherwise get errors below. 60 | config.git_rev_parse.dm_zoo = dm_zoo 61 | config.git_rev_parse.WeatherDiff = weather_diff 62 | 63 | fname = config.ds_id 64 | if "model_id" in config: 65 | fname = fname + "_" + config.model_id 66 | print("Writing model configuration file.") 67 | else: 68 | print("Writing dataset configuration file.") 69 | 70 | path = f"/data/compoundx/WeatherDiff/config_file/{fname}.yml" 71 | 72 | config_dict = config.toDict() 73 | 74 | with open(path, "w") as f: 75 | yaml.dump(config_dict, f, default_flow_style=False) 76 | 77 | os.chmod(path, 0o444) 78 | 79 | print(f"File {fname}.yml written (locked in read-only mode).") 80 | 81 | 82 | def load_config(config): 83 | if type(config) == str: 84 | with open(config) as f: 85 | config = yaml.safe_load(f) 86 | config = Munch.fromDict(config) 87 | if "data_specs" in config: 88 | config.n_generated_channels = n_generated_channels(config) 89 | config.n_condition_channels = n_condition_channels(config) 90 | 91 | return config 92 | 93 | 94 | def n_generated_channels(config): 95 | ov = config.data_specs.output_vars 96 | n_level = 0 97 | for k, v in ov.items(): 98 | if v is None: 99 | n_level += 1 100 | else: 101 | n_level = ( 102 | n_level + len(v["level"]) 103 | if v["level"] is not None 104 | else n_level + 1 105 | ) 106 | return n_level 107 | 108 | 109 | def n_condition_channels(config): 110 | n_level = 0 111 | for ( 112 | k, 113 | v, 114 | ) in config.data_specs.conditioning_vars.items(): 115 | if v is None: 116 | n_level += 1 117 | else: 118 | n_level = ( 119 | n_level + len(v["level"]) 120 | if v["level"] is not None 121 | else n_level + 1 122 | ) 123 | n_level = n_level * len(config.data_specs.conditioning_time_step) 124 | n_level = ( 125 | n_level + len(config.data_specs.constants) 126 | if config.data_specs.constants is not None 127 | else n_level 128 | ) 129 | return n_level 130 | 131 | 132 | def undo_scaling( 133 | dataset: xr.Dataset, dataset_min_max: xr.Dataset 134 | ) -> xr.Dataset: 135 | res = xr.Dataset() 136 | for varname in dataset.var(): 137 | dmin = dataset_min_max[varname + "_min"] 138 | dmax = dataset_min_max[varname + "_max"] 139 | 140 | res[varname] = dataset[varname] * (dmax - dmin) + dmin 141 | 142 | if varname == "tp": 143 | res[varname] = inverse_transform_precipitation( 144 | dataset[ 145 | [ 146 | varname, 147 | ] 148 | ] 149 | ) 150 | 151 | return res 152 | 153 | 154 | def create_xr_output_variables( 155 | data: torch.tensor, 156 | zarr_path: str, 157 | config: DictConfig, 158 | min_max_file_path: str, 159 | ) -> None: 160 | """Create an xarray dataset with dimensions [ensemble_member, init_time, lead_time, lat, lon] from a data tensor with shape (n_ensemble_members, n_init_times, n_variables, n_lat, n_lon) 161 | 162 | Args: 163 | data (torch.tensor): Data to be rescaled and read into an xarray dataset. 164 | dates (str): Path to the zarr file where the dataset is saved. Is required because we need to load the time axis from there. 165 | config_file_path (str): Path to the used configuration file 166 | min_max_file_path (str): Path to the netcdf4 file in which training set maxima and minima are stored. 
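Example (illustrative sketch; the tensor shape, paths and Hydra config object `cfg` are assumptions):

    # predictions: tensor of shape (n_ensemble_members, n_init_times, n_channels, n_lat, n_lon)
    ds = create_xr_output_variables(
        predictions,
        zarr_path="/path/to/preprocessed/test.zarr",
        config=cfg,
        min_max_file_path="/path/to/preprocessed/min_max.nc",
    )
    ds.to_netcdf("/path/to/model_output/predictions.nc")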
167 | """ # noqa: E501 168 | # loading config information: 169 | 170 | spatial_resolution = config.template.data_specs.spatial_resolution 171 | root_dir = config.paths.dir_WeatherBench 172 | lead_time = config.template.data_specs.lead_time 173 | max_conditioning_time_steps = max(abs(np.array(config.template.data_specs.conditioning_time_step))) 174 | # load time: 175 | dates = xr.open_zarr(zarr_path).time.rename({"time":"init_time"}).isel({"init_time": slice(max_conditioning_time_steps, -lead_time)}) 176 | 177 | # create dataset and set up coordinates: 178 | ds = xr.Dataset() 179 | 180 | assert os.path.isfile( 181 | os.path.join( 182 | root_dir, 183 | spatial_resolution, 184 | "constants.nc", 185 | ) 186 | ), ( 187 | "The file {} is required to extract the coordinates, but doesn't" 188 | " exist.".format( 189 | os.path.join( 190 | root_dir, 191 | spatial_resolution, 192 | "constants.nc", 193 | ) 194 | ) 195 | ) 196 | coords = xr.open_dataset( 197 | os.path.join( 198 | root_dir, 199 | spatial_resolution, 200 | "constants.nc", 201 | ) 202 | ).coords 203 | ds.coords.update(coords) 204 | 205 | if data.ndim == 5: # (ensemble_member, bs, channels, lat, lon) 206 | ds = ds.expand_dims({"lead_time": 1}).assign_coords( 207 | {"lead_time": [lead_time]} 208 | ) 209 | elif data.ndim == 6: # (ensemble_member, bs, len_traj, channels, lat, lon) 210 | ds = ds.expand_dims({"lead_time": 1}).assign_coords( 211 | {"lead_time": [(i+1)*lead_time for i in range(data.shape[2])]} 212 | ) 213 | else: 214 | raise ValueError("Invalid number of dimensions of input data.") 215 | 216 | ds = ds.expand_dims({"ensemble_member": data.shape[0]}).assign_coords( 217 | {"ensemble_member": np.arange(data.shape[0])} 218 | ) 219 | ds = ds.expand_dims(init_time=dates) 220 | 221 | # get list of variables: 222 | assert os.path.isfile(min_max_file_path), ( 223 | "The file {} is required to extract minima and" 224 | " maxima, but doesn't exist.".format(min_max_file_path) 225 | ) 226 | ds_min_max = xr.open_dataset(min_max_file_path) 227 | # get list of variables, hopefully in the same order as the channels: 228 | var_names = [ 229 | name.replace("_max", "") 230 | for name in list(ds_min_max.var()) 231 | if "_max" in name 232 | ] 233 | 234 | if data.ndim == 5: # (ensemble_member, bs, channels, lat, lon) 235 | for i in range(data.shape[-3]): 236 | ds[var_names[i]] = xr.DataArray( 237 | data[..., i : i + 1, :, :], 238 | dims=("ensemble_member", "init_time", "lead_time", "lat", "lon"), 239 | coords={ 240 | "ensemble_member": ds.ensemble_member, 241 | "lat": ds.lat, 242 | "lon": ds.lon, 243 | "lead_time": ds.lead_time, 244 | "init_time": ds.init_time, 245 | }, 246 | ) 247 | elif data.ndim == 6: # (ensemble_member, bs, len_traj, channels, lat, lon) 248 | for i in range(data.shape[-3]): 249 | ds[var_names[i]] = xr.DataArray( 250 | data[..., i, :, :], 251 | dims=("ensemble_member", "init_time", "lead_time", "lat", "lon"), 252 | coords={ 253 | "ensemble_member": ds.ensemble_member, 254 | "lat": ds.lat, 255 | "lon": ds.lon, 256 | "lead_time": ds.lead_time, 257 | "init_time": ds.init_time, 258 | }, 259 | ) 260 | else: 261 | raise ValueError("Invalid number of dimensions of input data.") 262 | return undo_scaling(ds, ds_min_max) 263 | -------------------------------------------------------------------------------- /WD/plotting.py: -------------------------------------------------------------------------------- 1 | import cartopy.crs as ccrs 2 | import matplotlib 3 | import matplotlib.pyplot as plt 4 | 5 | import numpy as np 6 | from cartopy.util 
import add_cyclic_point 7 | 8 | import xarray as xr 9 | from typing import Dict 10 | 11 | 12 | def plot_map( 13 | ax: plt.Axes, 14 | data: xr.Dataset, 15 | plotting_config: Dict, 16 | title: str = "", 17 | ) -> None: 18 | """Plot an xarray dataset as a map. The dataset is assumed to have dimensions lat and 19 | lon and trivial other dimensions and only contain a single variable. The function works 20 | by modifying an existing Axes object. 21 | 22 | Args: 23 | ax (plt.Axes): A matplotlib Axes to plot in. 24 | data (xr.Dataset): Data to plot, should have non-trivial dimensions lat and lon only 25 | plotting_config (Dict): Some configurations regarding plotting, infer details from code below. 26 | title (str, optional): Title for this specific panel of the plot. Defaults to "". 27 | """ # noqa: E501 28 | 29 | p_data = data[list(data.keys())[0]] 30 | 31 | lat = data.lat.values 32 | lon = data.lon.values 33 | 34 | ax.set_global() 35 | # remove white line 36 | field, lon_plot = add_cyclic_point(p_data, coord=lon) 37 | lo, la = np.meshgrid(lon_plot, lat) 38 | mesh = ax.pcolormesh( 39 | lo, 40 | la, 41 | field, 42 | transform=ccrs.PlateCarree(), 43 | cmap=plotting_config["CMAP"], 44 | norm=plotting_config["NORM"], 45 | rasterized=plotting_config["RASTERIZED"], 46 | ) 47 | 48 | if plotting_config["SHOW_COLORBAR"]: 49 | cbar = plt.colorbar( 50 | matplotlib.cm.ScalarMappable( 51 | cmap=plotting_config["CMAP"], 52 | norm=plotting_config["NORM"], 53 | ), 54 | spacing="proportional", 55 | orientation=plotting_config["CBAR_ORIENTATION"], 56 | extend=plotting_config["CBAR_EXTEND"], 57 | ax=ax, 58 | ) 59 | if plotting_config["SHOW_COLORBAR_LABEL"]: 60 | cbar.set_label(plotting_config["CBAR_LABEL"]) 61 | 62 | ax.coastlines() 63 | if title != "": 64 | ax.set_title( 65 | title, 66 | fontsize=plotting_config["TITLE_FONTSIZE"], 67 | ) 68 | return mesh 69 | 70 | 71 | def add_label_to_axes(ax, label, fontsize=None): 72 | if fontsize is None: 73 | ax.text( 74 | 0.01, 75 | 0.99, 76 | label, 77 | ha="left", 78 | va="top", 79 | transform=ax.transAxes, 80 | ) 81 | else: 82 | ax.text( 83 | 0.01, 84 | 0.99, 85 | label, 86 | ha="left", 87 | va="top", 88 | transform=ax.transAxes, 89 | fontsize=fontsize, 90 | ) 91 | -------------------------------------------------------------------------------- /WD/regridding.py: -------------------------------------------------------------------------------- 1 | # adapted from https://github.com/pangeo-data/WeatherBench/blob/master/src/regrid.py 2 | 3 | import xarray as xr 4 | import xesmf as xe 5 | 6 | def regrid( 7 | ds_in, 8 | ds_res, 9 | method='bilinear', 10 | reuse_weights=True 11 | ): 12 | """ 13 | Regrid horizontally. 
14 | :param ds_in: Input xarray dataset 15 | :param ds_res: Output xarray dataset used to extract the output resolution 16 | :param method: Regridding method 17 | :param reuse_weights: Reuse weights for regridding 18 | :return: ds_out: Regridded dataset 19 | """ 20 | # Rename to ESMF compatible coordinates 21 | if 'latitude' in ds_in.coords: 22 | ds_in = ds_in.rename({'latitude': 'lat', 'longitude': 'lon'}) 23 | 24 | if 'latitude' in ds_res.coords: 25 | ds_res = ds_res.rename({'latitude': 'lat', 'longitude': 'lon'}) 26 | 27 | # Create regridder 28 | regridder = xe.Regridder( 29 | ds_in, ds_res, method, periodic=True, reuse_weights=reuse_weights, 30 | ) 31 | 32 | """ 33 | # Hack to speed up regridding of large files 34 | ds_list = [] 35 | chunk_size = 500 36 | n_chunks = len(ds_in.time) // chunk_size + 1 37 | for i in range(n_chunks): 38 | ds_small = ds_in.isel(time=slice(i*chunk_size, (i+1)*chunk_size)) 39 | ds_list.append(regridder(ds_small).astype('float32')) 40 | ds_out = xr.concat(ds_list, dim='time') 41 | """ 42 | 43 | ds_out = regridder(ds_in).astype('float32') 44 | 45 | # Set attributes since they get lost during regridding 46 | for var in ds_out: 47 | ds_out[var].attrs = ds_in[var].attrs 48 | ds_out.attrs.update(ds_in.attrs) 49 | 50 | # # Regrid dataset 51 | # ds_out = regridder(ds_in) 52 | return ds_out 53 | 54 | 55 | def regrid_to_res(ds_in:xr.Dataset, out_res: str, weatherbench_path:str="/data/compoundx/WeatherBench", *args, **kwargs)-> xr.Dataset: 56 | """Regrid to a given resolution contained in the WeatherBench dataset. 57 | 58 | Args: 59 | ds_in (xr.Dataset): Dataset to be interpolated. 60 | out_res (str): Target resolution. Must be contained in WeatherBench, and corresponding files must be downloaded. 61 | weatherbench_path (str): Path under which the WeatherBench directory is stored. 62 | Returns: 63 | xr.Dataset: Interpolated dataset. 
64 | """ 65 | 66 | assert "deg" in out_res, "Resolution specification must be of the type 1.2345deg" 67 | 68 | ds_res = xr.open_dataset("{}/{}/constants.nc".format(weatherbench_path, out_res)) 69 | 70 | return regrid(ds_in, ds_res, *args, **kwargs) 71 | -------------------------------------------------------------------------------- /WD/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import subprocess 3 | from pathlib import Path 4 | import xarray as xr 5 | import numpy as np 6 | import uuid 7 | # import h5py 8 | 9 | from torch import tensor 10 | 11 | 12 | def check_devices(): 13 | if torch.cuda.is_available(): 14 | print(f"Number of Devices {torch.cuda.device_count()}") 15 | for i in range(torch.cuda.device_count()): 16 | print(f"Device {i}: {torch.cuda.get_device_name(i)}") 17 | else: 18 | print("Cuda is not available") 19 | 20 | 21 | def get_git_revision_hash() -> list: 22 | grv_dm_zoo = ( 23 | subprocess.check_output(["git", "rev-parse", "HEAD:dm_zoo"]) 24 | .decode("ascii") 25 | .strip() 26 | ) 27 | grv_weather_diff = ( 28 | subprocess.check_output(["git", "rev-parse", "HEAD"]) 29 | .decode("ascii") 30 | .strip() 31 | ) 32 | 33 | return [grv_dm_zoo, grv_weather_diff] 34 | 35 | 36 | def generate_uid(): 37 | return uuid.uuid4().hex.upper()[0:6] 38 | 39 | 40 | def transformation_function(x, eps=0.001): 41 | return np.log1p(x / eps) 42 | 43 | 44 | def inverse_transformation_function(y, eps=0.001): 45 | return eps * np.expm1(y) 46 | 47 | 48 | def transform_precipitation(pr: xr.Dataset) -> xr.Dataset: 49 | """Apply a transformation to the precipitation values to make training easier. 50 | For now, follow Rasp & Thuerey. 51 | 52 | Args: 53 | pr (xr.Dataset): The precipitation array on the original scale 54 | 55 | Returns: 56 | xr.Dataset: A rescaled version of the precipitation array 57 | """ # noqa: E501 58 | return xr.apply_ufunc( 59 | transformation_function, 60 | pr["tp"], 61 | dask="parallelized", 62 | ) 63 | 64 | 65 | def inverse_transform_precipitation( 66 | pr_transform: xr.Dataset, 67 | ) -> xr.Dataset: 68 | """Undo the precipitation transformation.
69 | 70 | Args: 71 | pr_transform (xr.Dataset): The rescaled precipitation array 72 | 73 | Returns: 74 | xr.Dataset: The precipitation, rescaled to the original resolution 75 | """ 76 | return xr.apply_ufunc( 77 | inverse_transformation_function, 78 | pr_transform["tp"], 79 | dask="parallelized", 80 | ) 81 | 82 | 83 | def create_dir(path): 84 | Path(path).mkdir(parents=True, exist_ok=True) 85 | 86 | 87 | def n_generated_channels(config): 88 | n_level = 0 89 | for k, v in config.data_specs.output_vars.items(): 90 | n_level = ( 91 | n_level + len(v.levels) if v.levels is not None else n_level + 1 92 | ) 93 | 94 | return n_level 95 | 96 | 97 | def n_condition_channels(config): 98 | n_level = 0 99 | for k, v in config.data_specs.conditioning_vars.items(): 100 | n_level = ( 101 | n_level + len(v.levels) if v.levels is not None else n_level + 1 102 | ) 103 | n_level = n_level * len(config.data_specs.conditioning_time_step) 104 | n_level = ( 105 | n_level + len(config.data_specs.constant_vars) 106 | if config.data_specs.constant_vars is not None 107 | else n_level 108 | ) 109 | return n_level 110 | 111 | """ 112 | def norm_area(deg=5.625): 113 | with h5py.File("/data/compoundx/WeatherBench/gridarea.h5", "r") as f: 114 | return f[f"norm_gridarea_{int(360/deg)}x{int(180/deg)}"][:] 115 | """ 116 | """ 117 | class AreaWeightedMSELoss: 118 | def __init__(self, spatial_res): 119 | deg = float(spatial_res[:-3]) 120 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 121 | self.weight = torch.tensor(norm_area(deg=deg), device=device) 122 | 123 | def loss_fn(self, input: tensor, target: tensor): 124 | return torch.sum(self.weight * (input - target) ** 2) 125 | """ 126 | 127 | class WeightedMSELoss: 128 | def __init__(self, weights): 129 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 130 | self.weights = torch.tensor(weights, device=device) 131 | 132 | def loss_fn(self, input: tensor, target: tensor): 133 | return (self.weights * (input - target) ** 2).mean() 134 | 135 | class AreaWeightedMSELoss(WeightedMSELoss): 136 | def __init__(self, lat, lon): 137 | super().__init__(weights=comp_area_weights_simple(lat, lon)) 138 | 139 | 140 | 141 | def comp_area_weights_simple(lat: np.ndarray, lon: np.ndarray) -> np.ndarray: 142 | """An easier way to calculate the (already normalized) area weights. 143 | 144 | Args: 145 | lat (np.ndarray): Array of latitudes of grid center points 146 | lon (np.ndarray): Array of lontigutes of grid center points 147 | 148 | Returns: 149 | np.ndarray: 2d array of relative area sizes. 
150 | """ 151 | area_weights = np.cos(lat * (2 * np.pi) / 360) 152 | area_weights = area_weights.reshape(-1, 1).repeat(lon.shape[0],axis=-1) 153 | area_weights = (lat.shape[0]*lon.shape[0] / np.sum(area_weights)) * area_weights 154 | return area_weights -------------------------------------------------------------------------------- /config/data.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - ds_format: zarr 3 | - paths: default_paths 4 | - _self_ 5 | 6 | template_name: ${hydra.runtime.choices.template} 7 | 8 | hydra: 9 | job: 10 | chdir: False 11 | run: 12 | dir: ${paths.dir_HydraConfigs}/data/${template_name} 13 | 14 | -------------------------------------------------------------------------------- /config/ds_format/zarr.yaml: -------------------------------------------------------------------------------- 1 | max_chunksize: 0.5 # maximal size per chunk in GB -------------------------------------------------------------------------------- /config/experiment/diffusion.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | template: rasp_thuerey_z_500_3day 3 | ds_format: zarr 4 | train_shuffle_chunks: True 5 | train_shuffle_in_chunks: True 6 | val_shuffle_chunks: False 7 | val_shuffle_in_chunks: False 8 | 9 | training: 10 | max_steps: 5000000 11 | ema_decay: 0.9999 12 | limit_val_batches: 10 13 | accelerator: "cuda" 14 | devices: 1 15 | patience: 10 16 | 17 | setup: 18 | loss_fn_name: AreaWeighted_MSE_Loss 19 | sampler_name: DDPM 20 | 21 | pixel_diffusion: 22 | lr_scheduler_name: Constant 23 | num_diffusion_steps_inference: 200 24 | batch_size: 64 25 | learning_rate: 0.0001 26 | num_workers: 1 27 | denoising_diffusion_process: 28 | unet_type: UnetConvNextBlock 29 | noise_schedule: linear 30 | use_cyclical_padding: True 31 | num_diffusion_steps: 1000 32 | dims_mults: [1,2,4,8] 33 | num_channels_base: 64 -------------------------------------------------------------------------------- /config/experiment/diffusion_2csteps.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | template: rasp_thuerey_z_500_3day_2csteps 3 | ds_format: zarr 4 | train_shuffle_chunks: True 5 | train_shuffle_in_chunks: True 6 | val_shuffle_chunks: False 7 | val_shuffle_in_chunks: False 8 | 9 | training: 10 | max_steps: 5000000 11 | ema_decay: 0.9999 12 | limit_val_batches: 10 13 | accelerator: "cuda" 14 | devices: 1 15 | patience: 10 16 | 17 | setup: 18 | loss_fn_name: AreaWeighted_MSE_Loss 19 | sampler_name: DDPM 20 | 21 | pixel_diffusion: 22 | lr_scheduler_name: Constant 23 | num_diffusion_steps_inference: 200 24 | batch_size: 64 25 | learning_rate: 0.0001 26 | num_workers: 1 27 | denoising_diffusion_process: 28 | unet_type: UnetConvNextBlock 29 | noise_schedule: linear 30 | use_cyclical_padding: True 31 | num_diffusion_steps: 1000 32 | dims_mults: [1,2,4,8] 33 | num_channels_base: 64 -------------------------------------------------------------------------------- /config/experiment/diffusion_CosineAnnealing.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | template: rasp_thuerey_z_500_3day 3 | ds_format: zarr 4 | train_shuffle_chunks: True 5 | train_shuffle_in_chunks: True 6 | val_shuffle_chunks: False 7 | val_shuffle_in_chunks: False 8 | 9 | training: 10 | max_steps: 5000000 11 | ema_decay: 0.9999 12 | limit_val_batches: 10 13 | accelerator: "cuda" 14 | devices: 1 15 | patience: 10 16 | 17 | setup: 18 | 
loss_fn_name: AreaWeighted_MSE_Loss 19 | sampler_name: DDPM 20 | 21 | pixel_diffusion: 22 | lr_scheduler_name: CosineAnnealingLR 23 | num_diffusion_steps_inference: 200 24 | batch_size: 64 25 | learning_rate: 0.0001 26 | num_workers: 1 27 | denoising_diffusion_process: 28 | unet_type: UnetConvNextBlock 29 | noise_schedule: linear 30 | use_cyclical_padding: True 31 | num_diffusion_steps: 1000 32 | dims_mults: [1,2,4,8] 33 | num_channels_base: 64 -------------------------------------------------------------------------------- /config/experiment/diffusion_MSE_Loss.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | template: rasp_thuerey_z_500_3day 3 | ds_format: zarr 4 | train_shuffle_chunks: True 5 | train_shuffle_in_chunks: True 6 | val_shuffle_chunks: False 7 | val_shuffle_in_chunks: False 8 | 9 | training: 10 | max_steps: 5000000 11 | ema_decay: 0.9999 12 | limit_val_batches: 10 13 | accelerator: "cuda" 14 | devices: 1 15 | patience: 10 16 | 17 | setup: 18 | loss_fn_name: MSE_Loss 19 | sampler_name: DDPM 20 | 21 | pixel_diffusion: 22 | lr_scheduler_name: Constant 23 | num_diffusion_steps_inference: 200 24 | batch_size: 64 25 | learning_rate: 0.0001 26 | num_workers: 1 27 | denoising_diffusion_process: 28 | unet_type: UnetConvNextBlock 29 | noise_schedule: linear 30 | use_cyclical_padding: True 31 | num_diffusion_steps: 1000 32 | dims_mults: [1,2,4,8] 33 | num_channels_base: 64 -------------------------------------------------------------------------------- /config/experiment/diffusion_MSE_Loss_more_patient_deeper.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | template: rasp_thuerey_z_500_3day 3 | ds_format: zarr 4 | train_shuffle_chunks: True 5 | train_shuffle_in_chunks: True 6 | val_shuffle_chunks: False 7 | val_shuffle_in_chunks: False 8 | 9 | training: 10 | max_steps: 5000000 11 | ema_decay: 0.9999 12 | limit_val_batches: 10 13 | accelerator: "cuda" 14 | devices: 1 15 | patience: 40 16 | 17 | setup: 18 | loss_fn_name: MSE_Loss 19 | sampler_name: DDPM 20 | 21 | pixel_diffusion: 22 | lr_scheduler_name: Constant 23 | num_diffusion_steps_inference: 200 24 | batch_size: 64 25 | learning_rate: 0.0001 26 | num_workers: 1 27 | denoising_diffusion_process: 28 | unet_type: UnetConvNextBlock 29 | noise_schedule: linear 30 | use_cyclical_padding: True 31 | num_diffusion_steps: 1000 32 | dims_mults: [1,2,4,8,16] 33 | num_channels_base: 64 -------------------------------------------------------------------------------- /config/experiment/diffusion_deeper.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | template: rasp_thuerey_z_500_3day 3 | ds_format: zarr 4 | train_shuffle_chunks: True 5 | train_shuffle_in_chunks: True 6 | val_shuffle_chunks: False 7 | val_shuffle_in_chunks: False 8 | 9 | training: 10 | max_steps: 5000000 11 | ema_decay: 0.9999 12 | limit_val_batches: 10 13 | accelerator: "cuda" 14 | devices: 1 15 | patience: 10 16 | 17 | setup: 18 | loss_fn_name: AreaWeighted_MSE_Loss 19 | sampler_name: DDPM 20 | 21 | pixel_diffusion: 22 | lr_scheduler_name: Constant 23 | num_diffusion_steps_inference: 200 24 | batch_size: 64 25 | learning_rate: 0.0001 26 | num_workers: 1 27 | denoising_diffusion_process: 28 | unet_type: UnetConvNextBlock 29 | noise_schedule: linear 30 | use_cyclical_padding: True 31 | num_diffusion_steps: 1000 32 | dims_mults: [1,2,4,8,16] 33 | num_channels_base: 64 
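
Note on the "_deeper" experiment files above: they differ from the base diffusion.yaml mainly in dims_mults gaining a fifth entry. A minimal sketch of inspecting this outside of Hydra, run from the repository root; the reading of dims_mults as per-resolution-level channel multipliers for the U-Net is an assumption based on the config names, not something these YAML files state themselves:

from omegaconf import OmegaConf

# Load one experiment config directly, bypassing Hydra composition.
cfg = OmegaConf.load("config/experiment/diffusion_deeper.yaml")

ddp = cfg.pixel_diffusion.denoising_diffusion_process
# Assumed reading: each entry of dims_mults scales num_channels_base at one
# U-Net resolution level, so [1,2,4,8,16] adds one level compared to [1,2,4,8].
print([ddp.num_channels_base * m for m in ddp.dims_mults])  # [64, 128, 256, 512, 1024]
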
-------------------------------------------------------------------------------- /config/experiment/diffusion_fourcast.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | template: rasp_thuerey_z_500_3day 3 | ds_format: zarr 4 | train_shuffle_chunks: True 5 | train_shuffle_in_chunks: True 6 | val_shuffle_chunks: False 7 | val_shuffle_in_chunks: False 8 | 9 | training: 10 | max_steps: 5000000 11 | ema_decay: 0.9999 12 | limit_val_batches: 10 13 | accelerator: "cuda" 14 | devices: 1 15 | patience: 40 16 | 17 | model: 18 | loss_fn_name: AreaWeighted_MSE_Loss 19 | lr_scheduler_name: ReduceLROnPlateau 20 | batch_size: 64 21 | learning_rate: 0.00001 22 | num_workers: 1 23 | 24 | diffusion: 25 | num_diffusion_steps_inference: 200 26 | sampler_name: DDPM 27 | denoising_diffusion_process: 28 | unet_type: UnetConvNextBlock 29 | noise_schedule: linear 30 | use_cyclical_padding: True 31 | num_diffusion_steps: 1000 32 | dims_mults: [1,2,4,8,16] 33 | num_channels_base: 64 34 | 35 | forecast: 36 | AFNO: 37 | patch_size: 2 38 | depth: 16 39 | embed_dim: 256 40 | num_blocks: 16 41 | drop_path_rate: 0. 42 | drop_rate: 0. 43 | mlp_ratio: 4. 44 | sparsity_threshold: 0.01 45 | hard_thresholding_fraction: 1. 46 | final_act: null -------------------------------------------------------------------------------- /config/experiment/diffusion_more_patient_deeper.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | template: rasp_thuerey_z_500_3day 3 | ds_format: zarr 4 | train_shuffle_chunks: True 5 | train_shuffle_in_chunks: True 6 | val_shuffle_chunks: False 7 | val_shuffle_in_chunks: False 8 | 9 | training: 10 | max_steps: 5000000 11 | ema_decay: 0.9999 12 | limit_val_batches: 10 13 | accelerator: "cuda" 14 | devices: 1 15 | patience: 40 16 | 17 | setup: 18 | loss_fn_name: AreaWeighted_MSE_Loss 19 | sampler_name: DDPM 20 | 21 | pixel_diffusion: 22 | lr_scheduler_name: Constant 23 | num_diffusion_steps_inference: 200 24 | batch_size: 64 25 | learning_rate: 0.0001 26 | num_workers: 1 27 | denoising_diffusion_process: 28 | unet_type: UnetConvNextBlock 29 | noise_schedule: linear 30 | use_cyclical_padding: True 31 | num_diffusion_steps: 1000 32 | dims_mults: [1,2,4,8,16] 33 | num_channels_base: 64 -------------------------------------------------------------------------------- /config/experiment/diffusion_t2m_3day_highres.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | template: rasp_thuerey_highres_t2m_3day 3 | ds_format: zarr 4 | train_shuffle_chunks: True 5 | train_shuffle_in_chunks: True 6 | val_shuffle_chunks: False 7 | val_shuffle_in_chunks: False 8 | 9 | training: 10 | max_steps: 5000000 11 | ema_decay: 0.9999 12 | limit_val_batches: 10 13 | accelerator: "cuda" 14 | devices: 1 15 | patience: 40 16 | 17 | setup: 18 | loss_fn_name: AreaWeighted_MSE_Loss 19 | sampler_name: DDPM 20 | 21 | pixel_diffusion: 22 | lr_scheduler_name: Constant 23 | num_diffusion_steps_inference: 200 24 | batch_size: 64 25 | learning_rate: 0.00003 26 | num_workers: 1 27 | denoising_diffusion_process: 28 | unet_type: UnetConvNextBlock 29 | noise_schedule: linear 30 | use_cyclical_padding: True 31 | num_diffusion_steps: 1000 32 | dims_mults: [1,2,4,8,16] 33 | num_channels_base: 64 -------------------------------------------------------------------------------- /config/experiment/diffusion_t2m_5day_highres.yaml: -------------------------------------------------------------------------------- 1 
| data: 2 | template: rasp_thuerey_highres_t2m_5day 3 | ds_format: zarr 4 | train_shuffle_chunks: True 5 | train_shuffle_in_chunks: True 6 | val_shuffle_chunks: False 7 | val_shuffle_in_chunks: False 8 | 9 | training: 10 | max_steps: 5000000 11 | ema_decay: 0.9999 12 | limit_val_batches: 10 13 | accelerator: "cuda" 14 | devices: 1 15 | patience: 40 16 | 17 | setup: 18 | loss_fn_name: AreaWeighted_MSE_Loss 19 | sampler_name: DDPM 20 | 21 | pixel_diffusion: 22 | lr_scheduler_name: Constant 23 | num_diffusion_steps_inference: 200 24 | batch_size: 64 25 | learning_rate: 0.00003 26 | num_workers: 1 27 | denoising_diffusion_process: 28 | unet_type: UnetConvNextBlock 29 | noise_schedule: linear 30 | use_cyclical_padding: True 31 | num_diffusion_steps: 1000 32 | dims_mults: [1,2,4,8,16] 33 | num_channels_base: 64 -------------------------------------------------------------------------------- /config/experiment/diffusion_t_850_1day.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | template: rasp_thuerey_t_850_1day 3 | ds_format: zarr 4 | train_shuffle_chunks: True 5 | train_shuffle_in_chunks: True 6 | val_shuffle_chunks: False 7 | val_shuffle_in_chunks: False 8 | 9 | training: 10 | max_steps: 5000000 11 | ema_decay: 0.9999 12 | limit_val_batches: 10 13 | accelerator: "cuda" 14 | devices: 1 15 | patience: 40 16 | 17 | setup: 18 | loss_fn_name: AreaWeighted_MSE_Loss 19 | sampler_name: DDPM 20 | 21 | pixel_diffusion: 22 | lr_scheduler_name: Constant 23 | num_diffusion_steps_inference: 200 24 | batch_size: 64 25 | learning_rate: 0.0001 26 | num_workers: 1 27 | denoising_diffusion_process: 28 | unet_type: UnetConvNextBlock 29 | noise_schedule: linear 30 | use_cyclical_padding: True 31 | num_diffusion_steps: 1000 32 | dims_mults: [1,2,4,8,16] 33 | num_channels_base: 64 -------------------------------------------------------------------------------- /config/experiment/diffusion_t_850_2day.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | template: rasp_thuerey_t_850_2day 3 | ds_format: zarr 4 | train_shuffle_chunks: True 5 | train_shuffle_in_chunks: True 6 | val_shuffle_chunks: False 7 | val_shuffle_in_chunks: False 8 | 9 | training: 10 | max_steps: 5000000 11 | ema_decay: 0.9999 12 | limit_val_batches: 10 13 | accelerator: "cuda" 14 | devices: 1 15 | patience: 40 16 | 17 | setup: 18 | loss_fn_name: AreaWeighted_MSE_Loss 19 | sampler_name: DDPM 20 | 21 | pixel_diffusion: 22 | lr_scheduler_name: Constant 23 | num_diffusion_steps_inference: 200 24 | batch_size: 64 25 | learning_rate: 0.0001 26 | num_workers: 1 27 | denoising_diffusion_process: 28 | unet_type: UnetConvNextBlock 29 | noise_schedule: linear 30 | use_cyclical_padding: True 31 | num_diffusion_steps: 1000 32 | dims_mults: [1,2,4,8,16] 33 | num_channels_base: 64 -------------------------------------------------------------------------------- /config/experiment/diffusion_t_850_3day.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | template: rasp_thuerey_t_850_3day 3 | ds_format: zarr 4 | train_shuffle_chunks: True 5 | train_shuffle_in_chunks: True 6 | val_shuffle_chunks: False 7 | val_shuffle_in_chunks: False 8 | 9 | training: 10 | max_steps: 5000000 11 | ema_decay: 0.9999 12 | limit_val_batches: 10 13 | accelerator: "cuda" 14 | devices: 1 15 | patience: 40 16 | 17 | setup: 18 | loss_fn_name: AreaWeighted_MSE_Loss 19 | sampler_name: DDPM 20 | 21 | pixel_diffusion: 22 | lr_scheduler_name: Constant 
23 | num_diffusion_steps_inference: 200 24 | batch_size: 64 25 | learning_rate: 0.0001 26 | num_workers: 1 27 | denoising_diffusion_process: 28 | unet_type: UnetConvNextBlock 29 | noise_schedule: linear 30 | use_cyclical_padding: True 31 | num_diffusion_steps: 1000 32 | dims_mults: [1,2,4,8,16] 33 | num_channels_base: 64 -------------------------------------------------------------------------------- /config/experiment/diffusion_t_850_3day_highres.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | template: rasp_thuerey_highres_t_850_3day 3 | ds_format: zarr 4 | train_shuffle_chunks: True 5 | train_shuffle_in_chunks: True 6 | val_shuffle_chunks: False 7 | val_shuffle_in_chunks: False 8 | 9 | training: 10 | max_steps: 5000000 11 | ema_decay: 0.9999 12 | limit_val_batches: 10 13 | accelerator: "cuda" 14 | devices: 1 15 | patience: 40 16 | 17 | setup: 18 | loss_fn_name: AreaWeighted_MSE_Loss 19 | sampler_name: DDPM 20 | 21 | pixel_diffusion: 22 | lr_scheduler_name: Constant 23 | num_diffusion_steps_inference: 200 24 | batch_size: 64 25 | learning_rate: 0.00003 26 | num_workers: 1 27 | denoising_diffusion_process: 28 | unet_type: UnetConvNextBlock 29 | noise_schedule: linear 30 | use_cyclical_padding: True 31 | num_diffusion_steps: 1000 32 | dims_mults: [1,2,4,8,16] 33 | num_channels_base: 64 -------------------------------------------------------------------------------- /config/experiment/diffusion_t_850_4day.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | template: rasp_thuerey_t_850_4day 3 | ds_format: zarr 4 | train_shuffle_chunks: True 5 | train_shuffle_in_chunks: True 6 | val_shuffle_chunks: False 7 | val_shuffle_in_chunks: False 8 | 9 | training: 10 | max_steps: 5000000 11 | ema_decay: 0.9999 12 | limit_val_batches: 10 13 | accelerator: "cuda" 14 | devices: 1 15 | patience: 40 16 | 17 | setup: 18 | loss_fn_name: AreaWeighted_MSE_Loss 19 | sampler_name: DDPM 20 | 21 | pixel_diffusion: 22 | lr_scheduler_name: Constant 23 | num_diffusion_steps_inference: 200 24 | batch_size: 64 25 | learning_rate: 0.0001 26 | num_workers: 1 27 | denoising_diffusion_process: 28 | unet_type: UnetConvNextBlock 29 | noise_schedule: linear 30 | use_cyclical_padding: True 31 | num_diffusion_steps: 1000 32 | dims_mults: [1,2,4,8,16] 33 | num_channels_base: 64 -------------------------------------------------------------------------------- /config/experiment/diffusion_t_850_5day.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | template: rasp_thuerey_t_850_5day 3 | ds_format: zarr 4 | train_shuffle_chunks: True 5 | train_shuffle_in_chunks: True 6 | val_shuffle_chunks: False 7 | val_shuffle_in_chunks: False 8 | 9 | training: 10 | max_steps: 5000000 11 | ema_decay: 0.9999 12 | limit_val_batches: 10 13 | accelerator: "cuda" 14 | devices: 1 15 | patience: 40 16 | 17 | setup: 18 | loss_fn_name: AreaWeighted_MSE_Loss 19 | sampler_name: DDPM 20 | 21 | pixel_diffusion: 22 | lr_scheduler_name: Constant 23 | num_diffusion_steps_inference: 200 24 | batch_size: 64 25 | learning_rate: 0.0001 26 | num_workers: 1 27 | denoising_diffusion_process: 28 | unet_type: UnetConvNextBlock 29 | noise_schedule: linear 30 | use_cyclical_padding: True 31 | num_diffusion_steps: 1000 32 | dims_mults: [1,2,4,8,16] 33 | num_channels_base: 64 -------------------------------------------------------------------------------- /config/experiment/diffusion_t_850_5day_highres.yaml: 
-------------------------------------------------------------------------------- 1 | data: 2 | template: rasp_thuerey_highres_t_850_5day 3 | ds_format: zarr 4 | train_shuffle_chunks: True 5 | train_shuffle_in_chunks: True 6 | val_shuffle_chunks: False 7 | val_shuffle_in_chunks: False 8 | 9 | training: 10 | max_steps: 5000000 11 | ema_decay: 0.9999 12 | limit_val_batches: 10 13 | accelerator: "cuda" 14 | devices: 1 15 | patience: 40 16 | 17 | setup: 18 | loss_fn_name: AreaWeighted_MSE_Loss 19 | sampler_name: DDPM 20 | 21 | pixel_diffusion: 22 | lr_scheduler_name: Constant 23 | num_diffusion_steps_inference: 200 24 | batch_size: 64 25 | learning_rate: 0.00003 26 | num_workers: 1 27 | denoising_diffusion_process: 28 | unet_type: UnetConvNextBlock 29 | noise_schedule: linear 30 | use_cyclical_padding: True 31 | num_diffusion_steps: 1000 32 | dims_mults: [1,2,4,8,16] 33 | num_channels_base: 64 -------------------------------------------------------------------------------- /config/experiment/diffusion_wider.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | template: rasp_thuerey_z_500_3day 3 | ds_format: zarr 4 | train_shuffle_chunks: True 5 | train_shuffle_in_chunks: True 6 | val_shuffle_chunks: False 7 | val_shuffle_in_chunks: False 8 | 9 | training: 10 | max_steps: 5000000 11 | ema_decay: 0.9999 12 | limit_val_batches: 10 13 | accelerator: "cuda" 14 | devices: 1 15 | patience: 10 16 | 17 | setup: 18 | loss_fn_name: AreaWeighted_MSE_Loss 19 | sampler_name: DDPM 20 | 21 | pixel_diffusion: 22 | lr_scheduler_name: Constant 23 | num_diffusion_steps_inference: 200 24 | batch_size: 64 25 | learning_rate: 0.0001 26 | num_workers: 1 27 | denoising_diffusion_process: 28 | unet_type: UnetConvNextBlock 29 | noise_schedule: linear 30 | use_cyclical_padding: True 31 | num_diffusion_steps: 1000 32 | dims_mults: [1,2,4,8] 33 | num_channels_base: 256 -------------------------------------------------------------------------------- /config/experiment/diffusion_z_500_1day.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | template: rasp_thuerey_z_500_1day 3 | ds_format: zarr 4 | train_shuffle_chunks: True 5 | train_shuffle_in_chunks: True 6 | val_shuffle_chunks: False 7 | val_shuffle_in_chunks: False 8 | 9 | training: 10 | max_steps: 5000000 11 | ema_decay: 0.9999 12 | limit_val_batches: 10 13 | accelerator: "cuda" 14 | devices: 1 15 | patience: 40 16 | 17 | setup: 18 | loss_fn_name: AreaWeighted_MSE_Loss 19 | sampler_name: DDPM 20 | 21 | pixel_diffusion: 22 | lr_scheduler_name: Constant 23 | num_diffusion_steps_inference: 200 24 | batch_size: 64 25 | learning_rate: 0.0001 26 | num_workers: 1 27 | denoising_diffusion_process: 28 | unet_type: UnetConvNextBlock 29 | noise_schedule: linear 30 | use_cyclical_padding: True 31 | num_diffusion_steps: 1000 32 | dims_mults: [1,2,4,8,16] 33 | num_channels_base: 64 -------------------------------------------------------------------------------- /config/experiment/diffusion_z_500_2day.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | template: rasp_thuerey_z_500_2day 3 | ds_format: zarr 4 | train_shuffle_chunks: True 5 | train_shuffle_in_chunks: True 6 | val_shuffle_chunks: False 7 | val_shuffle_in_chunks: False 8 | 9 | training: 10 | max_steps: 5000000 11 | ema_decay: 0.9999 12 | limit_val_batches: 10 13 | accelerator: "cuda" 14 | devices: 1 15 | patience: 40 16 | 17 | setup: 18 | loss_fn_name: AreaWeighted_MSE_Loss 19 | 
sampler_name: DDPM 20 | 21 | pixel_diffusion: 22 | lr_scheduler_name: Constant 23 | num_diffusion_steps_inference: 200 24 | batch_size: 64 25 | learning_rate: 0.0001 26 | num_workers: 1 27 | denoising_diffusion_process: 28 | unet_type: UnetConvNextBlock 29 | noise_schedule: linear 30 | use_cyclical_padding: True 31 | num_diffusion_steps: 1000 32 | dims_mults: [1,2,4,8,16] 33 | num_channels_base: 64 -------------------------------------------------------------------------------- /config/experiment/diffusion_z_500_3day.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | template: rasp_thuerey_z_500_3day 3 | ds_format: zarr 4 | train_shuffle_chunks: True 5 | train_shuffle_in_chunks: True 6 | val_shuffle_chunks: False 7 | val_shuffle_in_chunks: False 8 | 9 | training: 10 | max_steps: 5000000 11 | ema_decay: 0.9999 12 | limit_val_batches: 10 13 | accelerator: "cuda" 14 | devices: 1 15 | patience: 40 16 | 17 | setup: 18 | loss_fn_name: AreaWeighted_MSE_Loss 19 | sampler_name: DDPM 20 | 21 | pixel_diffusion: 22 | lr_scheduler_name: Constant 23 | num_diffusion_steps_inference: 200 24 | batch_size: 64 25 | learning_rate: 0.0001 26 | num_workers: 1 27 | denoising_diffusion_process: 28 | unet_type: UnetConvNextBlock 29 | noise_schedule: linear 30 | use_cyclical_padding: True 31 | num_diffusion_steps: 1000 32 | dims_mults: [1,2,4,8,16] 33 | num_channels_base: 64 -------------------------------------------------------------------------------- /config/experiment/diffusion_z_500_3day_highres.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | template: rasp_thuerey_highres_z_500_3day 3 | ds_format: zarr 4 | train_shuffle_chunks: True 5 | train_shuffle_in_chunks: True 6 | val_shuffle_chunks: False 7 | val_shuffle_in_chunks: False 8 | 9 | training: 10 | max_steps: 5000000 11 | ema_decay: 0.9999 12 | limit_val_batches: 10 13 | accelerator: "cuda" 14 | devices: 1 15 | patience: 40 16 | 17 | setup: 18 | loss_fn_name: AreaWeighted_MSE_Loss 19 | sampler_name: DDPM 20 | 21 | pixel_diffusion: 22 | lr_scheduler_name: Constant 23 | num_diffusion_steps_inference: 200 24 | batch_size: 64 25 | learning_rate: 0.00003 26 | num_workers: 1 27 | denoising_diffusion_process: 28 | unet_type: UnetConvNextBlock 29 | noise_schedule: linear 30 | use_cyclical_padding: True 31 | num_diffusion_steps: 1000 32 | dims_mults: [1,2,4,8,16] 33 | num_channels_base: 64 -------------------------------------------------------------------------------- /config/experiment/diffusion_z_500_4day.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | template: rasp_thuerey_z_500_4day 3 | ds_format: zarr 4 | train_shuffle_chunks: True 5 | train_shuffle_in_chunks: True 6 | val_shuffle_chunks: False 7 | val_shuffle_in_chunks: False 8 | 9 | training: 10 | max_steps: 5000000 11 | ema_decay: 0.9999 12 | limit_val_batches: 10 13 | accelerator: "cuda" 14 | devices: 1 15 | patience: 40 16 | 17 | setup: 18 | loss_fn_name: AreaWeighted_MSE_Loss 19 | sampler_name: DDPM 20 | 21 | pixel_diffusion: 22 | lr_scheduler_name: Constant 23 | num_diffusion_steps_inference: 200 24 | batch_size: 64 25 | learning_rate: 0.0001 26 | num_workers: 1 27 | denoising_diffusion_process: 28 | unet_type: UnetConvNextBlock 29 | noise_schedule: linear 30 | use_cyclical_padding: True 31 | num_diffusion_steps: 1000 32 | dims_mults: [1,2,4,8,16] 33 | num_channels_base: 64 
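
Note on the diffusion_z_500_*day experiments above: they share identical model and training settings and differ only in the data template they reference, which sets the forecast horizon. A small sketch of reading that horizon back out of a template with OmegaConf; horizon_hours is a hypothetical helper, and it assumes (as the templates' inline comments like "lead_time: 12 # 3 days" suggest) that lead_time counts steps of delta_t hours:

from omegaconf import OmegaConf

def horizon_hours(template_path: str) -> int:
    # lead_time counts steps of delta_t hours, e.g. 12 * 6 h = 72 h = 3 days
    specs = OmegaConf.load(template_path).data_specs
    return specs.delta_t * specs.lead_time

print(horizon_hours("config/template/rasp_thuerey_z_500_3day.yaml") / 24, "days")  # 3.0 days
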
-------------------------------------------------------------------------------- /config/experiment/diffusion_z_500_5day.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | template: rasp_thuerey_z_500_5day 3 | ds_format: zarr 4 | train_shuffle_chunks: True 5 | train_shuffle_in_chunks: True 6 | val_shuffle_chunks: False 7 | val_shuffle_in_chunks: False 8 | 9 | training: 10 | max_steps: 5000000 11 | ema_decay: 0.9999 12 | limit_val_batches: 10 13 | accelerator: "cuda" 14 | devices: 1 15 | patience: 40 16 | 17 | setup: 18 | loss_fn_name: AreaWeighted_MSE_Loss 19 | sampler_name: DDPM 20 | 21 | pixel_diffusion: 22 | lr_scheduler_name: Constant 23 | num_diffusion_steps_inference: 200 24 | batch_size: 64 25 | learning_rate: 0.0001 26 | num_workers: 1 27 | denoising_diffusion_process: 28 | unet_type: UnetConvNextBlock 29 | noise_schedule: linear 30 | use_cyclical_padding: True 31 | num_diffusion_steps: 1000 32 | dims_mults: [1,2,4,8,16] 33 | num_channels_base: 64 -------------------------------------------------------------------------------- /config/experiment/diffusion_z_500_5day_highres.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | template: rasp_thuerey_highres_z_500_5day 3 | ds_format: zarr 4 | train_shuffle_chunks: True 5 | train_shuffle_in_chunks: True 6 | val_shuffle_chunks: False 7 | val_shuffle_in_chunks: False 8 | 9 | training: 10 | max_steps: 5000000 11 | ema_decay: 0.9999 12 | limit_val_batches: 10 13 | accelerator: "cuda" 14 | devices: 1 15 | patience: 40 16 | 17 | setup: 18 | loss_fn_name: AreaWeighted_MSE_Loss 19 | sampler_name: DDPM 20 | 21 | pixel_diffusion: 22 | lr_scheduler_name: Constant 23 | num_diffusion_steps_inference: 200 24 | batch_size: 64 25 | learning_rate: 0.00003 26 | num_workers: 1 27 | denoising_diffusion_process: 28 | unet_type: UnetConvNextBlock 29 | noise_schedule: linear 30 | use_cyclical_padding: True 31 | num_diffusion_steps: 1000 32 | dims_mults: [1,2,4,8,16] 33 | num_channels_base: 64 -------------------------------------------------------------------------------- /config/experiment/fourcastnet.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | template: rasp_thuerey_z_500_3day 3 | ds_format: zarr 4 | train_shuffle_chunks: True 5 | train_shuffle_in_chunks: True 6 | val_shuffle_chunks: False 7 | val_shuffle_in_chunks: False 8 | 9 | setup: 10 | loss_fn_name: AreaWeighted_MSE_Loss 11 | 12 | fourcastnet: 13 | lr_scheduler_name: Constant 14 | batch_size: 64 15 | learning_rate: 0.0005 16 | num_workers: 1 17 | afno_net: 18 | patch_size: 2 19 | depth: 8 20 | embed_dim: 256 21 | num_blocks: 8 22 | drop_path_rate: 0. 23 | drop_rate: 0. 24 | mlp_ratio: 4. 25 | sparsity_threshold: 0.01 26 | hard_thresholding_fraction: 1. 
27 | 28 | training: 29 | max_steps: 5000000 30 | ema_decay: 0.9999 31 | limit_val_batches: 10 32 | accelerator: "cuda" 33 | devices: -1 34 | patience: 10 -------------------------------------------------------------------------------- /config/experiment/iterative_diffusion_reduced_set.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | template: iterative_reduced_set 3 | ds_format: zarr 4 | train_shuffle_chunks: True 5 | train_shuffle_in_chunks: True 6 | val_shuffle_chunks: False 7 | val_shuffle_in_chunks: False 8 | 9 | training: 10 | max_steps: 5000000 11 | ema_decay: 0.9999 12 | limit_val_batches: 20 13 | accelerator: "cuda" 14 | devices: 1 15 | patience: 10 16 | 17 | setup: 18 | loss_fn_name: AreaWeighted_MSE_Loss 19 | sampler_name: DDPM 20 | 21 | pixel_diffusion: 22 | lr_scheduler_name: CosineAnnealingLR 23 | num_diffusion_steps_inference: 200 24 | batch_size: 64 25 | learning_rate: 0.0001 26 | num_workers: 1 27 | denoising_diffusion_process: 28 | unet_type: UnetConvNextBlock 29 | noise_schedule: linear 30 | use_cyclical_padding: True 31 | num_diffusion_steps: 1000 32 | dims_mults: [1,2,4,8] 33 | num_channels_base: 64 -------------------------------------------------------------------------------- /config/experiment/iterative_diffusion_t_850.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | template: iterative_t_850 3 | ds_format: zarr 4 | train_shuffle_chunks: True 5 | train_shuffle_in_chunks: True 6 | val_shuffle_chunks: False 7 | val_shuffle_in_chunks: False 8 | 9 | training: 10 | max_steps: 5000000 11 | ema_decay: 0.9999 12 | limit_val_batches: 20 13 | accelerator: "cuda" 14 | devices: 1 15 | patience: 10 16 | 17 | setup: 18 | loss_fn_name: AreaWeighted_MSE_Loss 19 | sampler_name: DDPM 20 | 21 | pixel_diffusion: 22 | lr_scheduler_name: CosineAnnealingLR 23 | num_diffusion_steps_inference: 200 24 | batch_size: 64 25 | learning_rate: 0.0001 26 | num_workers: 1 27 | denoising_diffusion_process: 28 | unet_type: UnetConvNextBlock 29 | noise_schedule: linear 30 | use_cyclical_padding: True 31 | num_diffusion_steps: 1000 32 | dims_mults: [1,2,4,8] 33 | num_channels_base: 64 -------------------------------------------------------------------------------- /config/experiment/iterative_diffusion_t_850_highres.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | template: iterative_t_850_highres 3 | ds_format: zarr 4 | train_shuffle_chunks: True 5 | train_shuffle_in_chunks: True 6 | val_shuffle_chunks: False 7 | val_shuffle_in_chunks: False 8 | 9 | training: 10 | max_steps: 5000000 11 | ema_decay: 0.9999 12 | limit_val_batches: 20 13 | accelerator: "cuda" 14 | devices: 1 15 | patience: 10 16 | 17 | setup: 18 | loss_fn_name: AreaWeighted_MSE_Loss 19 | sampler_name: DDPM 20 | 21 | pixel_diffusion: 22 | lr_scheduler_name: CosineAnnealingLR 23 | num_diffusion_steps_inference: 200 24 | batch_size: 64 25 | learning_rate: 0.0001 26 | num_workers: 1 27 | denoising_diffusion_process: 28 | unet_type: UnetConvNextBlock 29 | noise_schedule: linear 30 | use_cyclical_padding: True 31 | num_diffusion_steps: 1000 32 | dims_mults: [1,2,4,8] 33 | num_channels_base: 64 -------------------------------------------------------------------------------- /config/experiment/iterative_diffusion_z_500.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | template: iterative_z_500 3 | ds_format: zarr 4 | train_shuffle_chunks: True 
5 | train_shuffle_in_chunks: True 6 | val_shuffle_chunks: False 7 | val_shuffle_in_chunks: False 8 | 9 | training: 10 | max_steps: 5000000 11 | ema_decay: 0.9999 12 | limit_val_batches: 20 13 | accelerator: "cuda" 14 | devices: 1 15 | patience: 10 16 | 17 | setup: 18 | loss_fn_name: AreaWeighted_MSE_Loss 19 | sampler_name: DDPM 20 | 21 | pixel_diffusion: 22 | lr_scheduler_name: CosineAnnealingLR 23 | num_diffusion_steps_inference: 200 24 | batch_size: 64 25 | learning_rate: 0.0001 26 | num_workers: 1 27 | denoising_diffusion_process: 28 | unet_type: UnetConvNextBlock 29 | noise_schedule: linear 30 | use_cyclical_padding: True 31 | num_diffusion_steps: 1000 32 | dims_mults: [1,2,4,8] 33 | num_channels_base: 64 -------------------------------------------------------------------------------- /config/experiment/iterative_diffusion_z_500_highres.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | template: iterative_z_500_highres 3 | ds_format: zarr 4 | train_shuffle_chunks: True 5 | train_shuffle_in_chunks: True 6 | val_shuffle_chunks: False 7 | val_shuffle_in_chunks: False 8 | 9 | training: 10 | max_steps: 5000000 11 | ema_decay: 0.9999 12 | limit_val_batches: 20 13 | accelerator: "cuda" 14 | devices: 1 15 | patience: 10 16 | 17 | setup: 18 | loss_fn_name: AreaWeighted_MSE_Loss 19 | sampler_name: DDPM 20 | 21 | pixel_diffusion: 22 | lr_scheduler_name: CosineAnnealingLR 23 | num_diffusion_steps_inference: 200 24 | batch_size: 64 25 | learning_rate: 0.0001 26 | num_workers: 1 27 | denoising_diffusion_process: 28 | unet_type: UnetConvNextBlock 29 | noise_schedule: linear 30 | use_cyclical_padding: True 31 | num_diffusion_steps: 1000 32 | dims_mults: [1,2,4,8] 33 | num_channels_base: 64 -------------------------------------------------------------------------------- /config/experiment/iterative_diffusion_z_500_t_850.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | template: iterative_z_500_t_850 3 | ds_format: zarr 4 | train_shuffle_chunks: True 5 | train_shuffle_in_chunks: True 6 | val_shuffle_chunks: False 7 | val_shuffle_in_chunks: False 8 | 9 | training: 10 | max_steps: 5000000 11 | ema_decay: 0.9999 12 | limit_val_batches: 20 13 | accelerator: "cuda" 14 | devices: 1 15 | patience: 10 16 | 17 | setup: 18 | loss_fn_name: AreaWeighted_MSE_Loss 19 | sampler_name: DDPM 20 | 21 | pixel_diffusion: 22 | lr_scheduler_name: CosineAnnealingLR 23 | num_diffusion_steps_inference: 200 24 | batch_size: 64 25 | learning_rate: 0.0001 26 | num_workers: 1 27 | denoising_diffusion_process: 28 | unet_type: UnetConvNextBlock 29 | noise_schedule: linear 30 | use_cyclical_padding: True 31 | num_diffusion_steps: 1000 32 | dims_mults: [1,2,4,8] 33 | num_channels_base: 64 -------------------------------------------------------------------------------- /config/experiment/unet.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | template: rasp_thuerey_z_500_3day 3 | ds_format: zarr 4 | train_shuffle_chunks: True 5 | train_shuffle_in_chunks: True 6 | val_shuffle_chunks: False 7 | val_shuffle_in_chunks: False 8 | 9 | training: 10 | max_steps: 5000000 11 | ema_decay: 0.9999 12 | limit_val_batches: 1.0 13 | accelerator: "cuda" 14 | devices: 1 15 | patience: 50 16 | 17 | setup: 18 | loss_fn_name: AreaWeighted_MSE_Loss 19 | 20 | unet_regression: 21 | lr_scheduler_name: Constant 22 | batch_size: 64 23 | learning_rate: 0.0001 24 | num_workers: 1 25 | direct_unet_prediction: 26 | 
unet_type: UnetConvNextBlock 27 | use_cyclical_padding: True 28 | dims_mults: [1,2,4,8] 29 | num_channels_base: 64 -------------------------------------------------------------------------------- /config/experiment/unet_highres.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | template: rasp_thuerey_highres_z_500_3day 3 | ds_format: zarr 4 | train_shuffle_chunks: True 5 | train_shuffle_in_chunks: True 6 | val_shuffle_chunks: False 7 | val_shuffle_in_chunks: False 8 | 9 | training: 10 | max_steps: 5000000 11 | ema_decay: 0.9999 12 | limit_val_batches: 1.0 13 | accelerator: "cuda" 14 | devices: 1 15 | patience: 50 16 | 17 | setup: 18 | loss_fn_name: AreaWeighted_MSE_Loss 19 | 20 | unet_regression: 21 | lr_scheduler_name: Constant 22 | batch_size: 64 23 | learning_rate: 0.0001 24 | num_workers: 1 25 | direct_unet_prediction: 26 | unet_type: UnetConvNextBlock 27 | use_cyclical_padding: True 28 | dims_mults: [1,2,4,8,16] 29 | num_channels_base: 64 -------------------------------------------------------------------------------- /config/experiment/unet_highres_t2m.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | template: rasp_thuerey_highres_t2m_5day 3 | ds_format: zarr 4 | train_shuffle_chunks: True 5 | train_shuffle_in_chunks: True 6 | val_shuffle_chunks: False 7 | val_shuffle_in_chunks: False 8 | 9 | training: 10 | max_steps: 5000000 11 | ema_decay: 0.9999 12 | limit_val_batches: 1.0 13 | accelerator: "cuda" 14 | devices: 1 15 | patience: 50 16 | 17 | setup: 18 | loss_fn_name: AreaWeighted_MSE_Loss 19 | 20 | unet_regression: 21 | lr_scheduler_name: Constant 22 | batch_size: 64 23 | learning_rate: 0.0001 24 | num_workers: 1 25 | direct_unet_prediction: 26 | unet_type: UnetConvNextBlock 27 | use_cyclical_padding: True 28 | dims_mults: [1,2,4,8,16] 29 | num_channels_base: 64 -------------------------------------------------------------------------------- /config/experiment/unet_highres_t2m_3day.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | template: rasp_thuerey_highres_t2m_3day 3 | ds_format: zarr 4 | train_shuffle_chunks: True 5 | train_shuffle_in_chunks: True 6 | val_shuffle_chunks: False 7 | val_shuffle_in_chunks: False 8 | 9 | training: 10 | max_steps: 5000000 11 | ema_decay: 0.9999 12 | limit_val_batches: 1.0 13 | accelerator: "cuda" 14 | devices: 1 15 | patience: 50 16 | 17 | setup: 18 | loss_fn_name: AreaWeighted_MSE_Loss 19 | 20 | unet_regression: 21 | lr_scheduler_name: Constant 22 | batch_size: 64 23 | learning_rate: 0.0001 24 | num_workers: 1 25 | direct_unet_prediction: 26 | unet_type: UnetConvNextBlock 27 | use_cyclical_padding: True 28 | dims_mults: [1,2,4,8,16] 29 | num_channels_base: 64 -------------------------------------------------------------------------------- /config/inference.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - paths: default_paths 3 | - _self_ 4 | 5 | batchsize: 128 6 | shuffle_chunks: False 7 | shuffle_in_chunks: False 8 | sampler: null 9 | loss_fn: null 10 | 11 | hydra: 12 | job: 13 | chdir: False 14 | run: 15 | dir: ${paths.dir_HydraConfigs}/inference/${data.template}/${hydra.runtime.choices.experiment}/${model_name}/${now:%Y-%m-%d_%H-%M-%S} -------------------------------------------------------------------------------- /config/paths/default_paths.yaml: -------------------------------------------------------------------------------- 1 | 
dir_WeatherBench: "/data/compoundx/WeatherBench/" # directory the weatherBench dataset was downloaded to 2 | dir_PreprocessedDatasets: "/data/compoundx/WeatherDiff/model_input/" # preprocessed datasets 3 | dir_SavedModels: "/data/compoundx/WeatherDiff/saved_model/" # checkpoints, tensorboard logs etc. 4 | dir_HydraConfigs: "/data/compoundx/WeatherDiff/config_logs/" # when running jobs, the configurations get logged here 5 | dir_ModelOutput: "/data/compoundx/WeatherDiff/model_output/" # predictions with the ML models -------------------------------------------------------------------------------- /config/template/iterative_rasp_thuerey.yaml: -------------------------------------------------------------------------------- 1 | data_specs: 2 | conditioning_time_step: [0] 3 | max_chunksize: 0.5 # maximal size per chunk in GB 4 | conditioning_vars: 5 | # fields with pressure levels 6 | temperature: 7 | level: [50, 250, 500, 600, 700, 850, 925] 8 | geopotential: 9 | level: [50, 250, 500, 600, 700, 850, 925] 10 | u_component_of_wind: 11 | level: [50, 250, 500, 600, 700, 850, 925] 12 | v_component_of_wind: 13 | level: [50, 250, 500, 600, 700, 850, 925] 14 | specific_humidity: 15 | level: [50, 250, 500, 600, 700, 850, 925] 16 | # 2d fields: 17 | 2m_temperature: 18 | level: 19 | total_precipitation: 20 | level: 21 | toa_incident_solar_radiation: 22 | level: 23 | # constant input fields 24 | constants: 25 | - orography 26 | - lat2d 27 | - lsm 28 | delta_t: 6 29 | lead_time: 1 # 3 days (3 times 4 times 6 hours) 30 | output_vars: # same as conditioning variables 31 | # fields with pressure levels 32 | temperature: 33 | level: [50, 250, 500, 600, 700, 850, 925] 34 | geopotential: 35 | level: [50, 250, 500, 600, 700, 850, 925] 36 | u_component_of_wind: 37 | level: [50, 250, 500, 600, 700, 850, 925] 38 | v_component_of_wind: 39 | level: [50, 250, 500, 600, 700, 850, 925] 40 | specific_humidity: 41 | level: [50, 250, 500, 600, 700, 850, 925] 42 | # 2d fields: 43 | 2m_temperature: 44 | level: 45 | total_precipitation: 46 | level: 47 | toa_incident_solar_radiation: 48 | level: 49 | spatial_resolution: 5.625deg 50 | exp_data: 51 | test: 52 | start: 2017-01-01 00:00:00 53 | end: 2018-12-31 00:00:00 54 | train: 55 | start: 1979-01-02 00:00:00 56 | end: 2015-12-31 00:00:00 57 | val: 58 | start: 2016-01-01 00:00:00 59 | end: 2016-12-31 00:00:00 60 | 61 | -------------------------------------------------------------------------------- /config/template/iterative_reduced_set.yaml: -------------------------------------------------------------------------------- 1 | data_specs: 2 | conditioning_time_step: [0] 3 | max_chunksize: 0.5 # maximal size per chunk in GB 4 | conditioning_vars: 5 | # fields with pressure levels 6 | temperature: 7 | level: [50, 250, 500, 925] 8 | geopotential: 9 | level: [50, 250, 500, 925] 10 | specific_humidity: 11 | level: [50, 250, 500, 925] 12 | # 2d fields: 13 | 2m_temperature: 14 | level: 15 | # constant input fields 16 | constants: 17 | - orography 18 | - lat2d 19 | - lsm 20 | delta_t: 6 21 | lead_time: 1 # 3 days (3 times 4 times 6 hours) 22 | output_vars: # same as conditioning variables 23 | # fields with pressure levels 24 | temperature: 25 | level: [50, 250, 500, 925] 26 | geopotential: 27 | level: [50, 250, 500, 925] 28 | specific_humidity: 29 | level: [50, 250, 500, 925] 30 | # 2d fields: 31 | 2m_temperature: 32 | level: 33 | spatial_resolution: 5.625deg 34 | exp_data: 35 | test: 36 | start: 2017-01-01 00:00:00 37 | end: 2018-12-31 00:00:00 38 | train: 39 | start: 1979-01-02 
00:00:00 40 | end: 2015-12-31 00:00:00 41 | val: 42 | start: 2016-01-01 00:00:00 43 | end: 2016-12-31 00:00:00 44 | 45 | -------------------------------------------------------------------------------- /config/template/iterative_t_850.yaml: -------------------------------------------------------------------------------- 1 | data_specs: 2 | conditioning_time_step: [0] 3 | max_chunksize: 0.5 # maximal size per chunk in GB 4 | conditioning_vars: 5 | # fields with pressure levels 6 | temperature: 7 | level: [850] 8 | # constant input fields 9 | constants: 10 | - orography 11 | - lat2d 12 | - lsm 13 | delta_t: 6 14 | lead_time: 1 # 3 days (3 times 4 times 6 hours) 15 | output_vars: # same as conditioning variables 16 | # fields with pressure levels 17 | temperature: 18 | level: [850] 19 | spatial_resolution: 5.625deg 20 | exp_data: 21 | test: 22 | start: 2017-01-01 00:00:00 23 | end: 2018-12-31 00:00:00 24 | train: 25 | start: 1979-01-02 00:00:00 26 | end: 2015-12-31 00:00:00 27 | val: 28 | start: 2016-01-01 00:00:00 29 | end: 2016-12-31 00:00:00 30 | 31 | -------------------------------------------------------------------------------- /config/template/iterative_t_850_highres.yaml: -------------------------------------------------------------------------------- 1 | data_specs: 2 | conditioning_time_step: [0] 3 | max_chunksize: 0.5 # maximal size per chunk in GB 4 | conditioning_vars: 5 | # fields with pressure levels 6 | temperature: 7 | level: [850] 8 | # constant input fields 9 | constants: 10 | - orography 11 | - lat2d 12 | - lsm 13 | delta_t: 6 14 | lead_time: 1 # 3 days (3 times 4 times 6 hours) 15 | output_vars: # same as conditioning variables 16 | # fields with pressure levels 17 | temperature: 18 | level: [850] 19 | spatial_resolution: 2.8125deg 20 | exp_data: 21 | test: 22 | start: 2017-01-01 00:00:00 23 | end: 2018-12-31 00:00:00 24 | train: 25 | start: 1979-01-02 00:00:00 26 | end: 2015-12-31 00:00:00 27 | val: 28 | start: 2016-01-01 00:00:00 29 | end: 2016-12-31 00:00:00 30 | 31 | -------------------------------------------------------------------------------- /config/template/iterative_z_500.yaml: -------------------------------------------------------------------------------- 1 | data_specs: 2 | conditioning_time_step: [0] 3 | max_chunksize: 0.5 # maximal size per chunk in GB 4 | conditioning_vars: 5 | # fields with pressure levels 6 | geopotential: 7 | level: [500] 8 | # constant input fields 9 | constants: 10 | - orography 11 | - lat2d 12 | - lsm 13 | delta_t: 6 14 | lead_time: 1 # 3 days (3 times 4 times 6 hours) 15 | output_vars: # same as conditioning variables 16 | # fields with pressure levels 17 | geopotential: 18 | level: [500] 19 | spatial_resolution: 5.625deg 20 | exp_data: 21 | test: 22 | start: 2017-01-01 00:00:00 23 | end: 2018-12-31 00:00:00 24 | train: 25 | start: 1979-01-02 00:00:00 26 | end: 2015-12-31 00:00:00 27 | val: 28 | start: 2016-01-01 00:00:00 29 | end: 2016-12-31 00:00:00 30 | 31 | -------------------------------------------------------------------------------- /config/template/iterative_z_500_highres.yaml: -------------------------------------------------------------------------------- 1 | data_specs: 2 | conditioning_time_step: [0] 3 | max_chunksize: 0.5 # maximal size per chunk in GB 4 | conditioning_vars: 5 | # fields with pressure levels 6 | geopotential: 7 | level: [500] 8 | # constant input fields 9 | constants: 10 | - orography 11 | - lat2d 12 | - lsm 13 | delta_t: 6 14 | lead_time: 1 # 3 days (3 times 4 times 6 hours) 15 | output_vars: # same 
as conditioning variables 16 | # fields with pressure levels 17 | geopotential: 18 | level: [500] 19 | spatial_resolution: 2.8125deg 20 | exp_data: 21 | test: 22 | start: 2017-01-01 00:00:00 23 | end: 2018-12-31 00:00:00 24 | train: 25 | start: 1979-01-02 00:00:00 26 | end: 2015-12-31 00:00:00 27 | val: 28 | start: 2016-01-01 00:00:00 29 | end: 2016-12-31 00:00:00 30 | 31 | -------------------------------------------------------------------------------- /config/template/iterative_z_500_t_850.yaml: -------------------------------------------------------------------------------- 1 | data_specs: 2 | conditioning_time_step: [0] 3 | max_chunksize: 0.5 # maximal size per chunk in GB 4 | conditioning_vars: 5 | # fields with pressure levels 6 | geopotential: 7 | level: [500] 8 | temperature: 9 | level: [850] 10 | # constant input fields 11 | constants: 12 | - orography 13 | - lat2d 14 | - lsm 15 | delta_t: 6 16 | lead_time: 1 # 3 days (3 times 4 times 6 hours) 17 | output_vars: # same as conditioning variables 18 | # fields with pressure levels 19 | geopotential: 20 | level: [500] 21 | temperature: 22 | level: [850] 23 | spatial_resolution: 5.625deg 24 | exp_data: 25 | test: 26 | start: 2017-01-01 00:00:00 27 | end: 2018-12-31 00:00:00 28 | train: 29 | start: 1979-01-02 00:00:00 30 | end: 2015-12-31 00:00:00 31 | val: 32 | start: 2016-01-01 00:00:00 33 | end: 2016-12-31 00:00:00 34 | 35 | -------------------------------------------------------------------------------- /config/template/rasp_thuerey_highres_t2m_3day.yaml: -------------------------------------------------------------------------------- 1 | data_specs: 2 | conditioning_time_step: [0, -1, -2] 3 | max_chunksize: 0.5 # maximal size per chunk in GB 4 | conditioning_vars: 5 | # fields with pressure levels 6 | temperature: 7 | level: [50, 250, 500, 600, 700, 850, 925] 8 | geopotential: 9 | level: [50, 250, 500, 600, 700, 850, 925] 10 | u_component_of_wind: 11 | level: [50, 250, 500, 600, 700, 850, 925] 12 | v_component_of_wind: 13 | level: [50, 250, 500, 600, 700, 850, 925] 14 | specific_humidity: 15 | level: [50, 250, 500, 600, 700, 850, 925] 16 | # 2d fields: 17 | 2m_temperature: 18 | level: 19 | total_precipitation: 20 | level: 21 | toa_incident_solar_radiation: 22 | level: 23 | # constant input fields 24 | constants: 25 | - orography 26 | - lat2d 27 | - lsm 28 | delta_t: 6 29 | lead_time: 12 # 3 days (3 times 4 times 6 hours) 30 | output_vars: 31 | # 2d fields: 32 | 2m_temperature: 33 | level: 34 | spatial_resolution: 2.8125deg 35 | exp_data: 36 | test: 37 | start: 2017-01-01 00:00:00 38 | end: 2018-12-31 00:00:00 39 | train: 40 | start: 1979-01-02 00:00:00 41 | end: 2015-12-31 00:00:00 42 | val: 43 | start: 2016-01-01 00:00:00 44 | end: 2016-12-31 00:00:00 45 | 46 | -------------------------------------------------------------------------------- /config/template/rasp_thuerey_highres_t2m_5day.yaml: -------------------------------------------------------------------------------- 1 | data_specs: 2 | conditioning_time_step: [0, -1, -2] 3 | max_chunksize: 0.5 # maximal size per chunk in GB 4 | conditioning_vars: 5 | # fields with pressure levels 6 | temperature: 7 | level: [50, 250, 500, 600, 700, 850, 925] 8 | geopotential: 9 | level: [50, 250, 500, 600, 700, 850, 925] 10 | u_component_of_wind: 11 | level: [50, 250, 500, 600, 700, 850, 925] 12 | v_component_of_wind: 13 | level: [50, 250, 500, 600, 700, 850, 925] 14 | specific_humidity: 15 | level: [50, 250, 500, 600, 700, 850, 925] 16 | # 2d fields: 17 | 2m_temperature: 18 | level: 19 | 
total_precipitation: 20 | level: 21 | toa_incident_solar_radiation: 22 | level: 23 | # constant input fields 24 | constants: 25 | - orography 26 | - lat2d 27 | - lsm 28 | delta_t: 6 29 | lead_time: 20 # 3 days (5 times 4 times 6 hours) 30 | output_vars: 31 | # 2d fields: 32 | 2m_temperature: 33 | level: 34 | spatial_resolution: 2.8125deg 35 | exp_data: 36 | test: 37 | start: 2017-01-01 00:00:00 38 | end: 2018-12-31 00:00:00 39 | train: 40 | start: 1979-01-02 00:00:00 41 | end: 2015-12-31 00:00:00 42 | val: 43 | start: 2016-01-01 00:00:00 44 | end: 2016-12-31 00:00:00 45 | 46 | -------------------------------------------------------------------------------- /config/template/rasp_thuerey_highres_t_850_3day.yaml: -------------------------------------------------------------------------------- 1 | data_specs: 2 | conditioning_time_step: [0, -1, -2] 3 | max_chunksize: 0.5 # maximal size per chunk in GB 4 | conditioning_vars: 5 | # fields with pressure levels 6 | temperature: 7 | level: [50, 250, 500, 600, 700, 850, 925] 8 | geopotential: 9 | level: [50, 250, 500, 600, 700, 850, 925] 10 | u_component_of_wind: 11 | level: [50, 250, 500, 600, 700, 850, 925] 12 | v_component_of_wind: 13 | level: [50, 250, 500, 600, 700, 850, 925] 14 | specific_humidity: 15 | level: [50, 250, 500, 600, 700, 850, 925] 16 | # 2d fields: 17 | 2m_temperature: 18 | level: 19 | total_precipitation: 20 | level: 21 | toa_incident_solar_radiation: 22 | level: 23 | # constant input fields 24 | constants: 25 | - orography 26 | - lat2d 27 | - lsm 28 | delta_t: 6 29 | lead_time: 12 # 3 days (3 times 4 times 6 hours) 30 | output_vars: 31 | # 2d fields: 32 | temperature: 33 | level: [850,] 34 | spatial_resolution: 2.8125deg 35 | exp_data: 36 | test: 37 | start: 2017-01-01 00:00:00 38 | end: 2018-12-31 00:00:00 39 | train: 40 | start: 1979-01-02 00:00:00 41 | end: 2015-12-31 00:00:00 42 | val: 43 | start: 2016-01-01 00:00:00 44 | end: 2016-12-31 00:00:00 45 | 46 | -------------------------------------------------------------------------------- /config/template/rasp_thuerey_highres_t_850_5day.yaml: -------------------------------------------------------------------------------- 1 | data_specs: 2 | conditioning_time_step: [0, -1, -2] 3 | max_chunksize: 0.5 # maximal size per chunk in GB 4 | conditioning_vars: 5 | # fields with pressure levels 6 | temperature: 7 | level: [50, 250, 500, 600, 700, 850, 925] 8 | geopotential: 9 | level: [50, 250, 500, 600, 700, 850, 925] 10 | u_component_of_wind: 11 | level: [50, 250, 500, 600, 700, 850, 925] 12 | v_component_of_wind: 13 | level: [50, 250, 500, 600, 700, 850, 925] 14 | specific_humidity: 15 | level: [50, 250, 500, 600, 700, 850, 925] 16 | # 2d fields: 17 | 2m_temperature: 18 | level: 19 | total_precipitation: 20 | level: 21 | toa_incident_solar_radiation: 22 | level: 23 | # constant input fields 24 | constants: 25 | - orography 26 | - lat2d 27 | - lsm 28 | delta_t: 6 29 | lead_time: 20 # 5 days (5 times 4 times 6 hours) 30 | output_vars: 31 | # 2d fields: 32 | temperature: 33 | level: [850,] 34 | spatial_resolution: 2.8125deg 35 | exp_data: 36 | test: 37 | start: 2017-01-01 00:00:00 38 | end: 2018-12-31 00:00:00 39 | train: 40 | start: 1979-01-02 00:00:00 41 | end: 2015-12-31 00:00:00 42 | val: 43 | start: 2016-01-01 00:00:00 44 | end: 2016-12-31 00:00:00 45 | 46 | -------------------------------------------------------------------------------- /config/template/rasp_thuerey_highres_z_500_3day.yaml: -------------------------------------------------------------------------------- 1 | 
data_specs: 2 | conditioning_time_step: [0, -1, -2] 3 | max_chunksize: 0.5 # maximal size per chunk in GB 4 | conditioning_vars: 5 | # fields with pressure levels 6 | temperature: 7 | level: [50, 250, 500, 600, 700, 850, 925] 8 | geopotential: 9 | level: [50, 250, 500, 600, 700, 850, 925] 10 | u_component_of_wind: 11 | level: [50, 250, 500, 600, 700, 850, 925] 12 | v_component_of_wind: 13 | level: [50, 250, 500, 600, 700, 850, 925] 14 | specific_humidity: 15 | level: [50, 250, 500, 600, 700, 850, 925] 16 | # 2d fields: 17 | 2m_temperature: 18 | level: 19 | total_precipitation: 20 | level: 21 | toa_incident_solar_radiation: 22 | level: 23 | # constant input fields 24 | constants: 25 | - orography 26 | - lat2d 27 | - lsm 28 | delta_t: 6 29 | lead_time: 12 # 3 days (3 times 4 times 6 hours) 30 | output_vars: 31 | # 2d fields: 32 | geopotential: 33 | level: [500,] 34 | spatial_resolution: 2.8125deg 35 | exp_data: 36 | test: 37 | start: 2017-01-01 00:00:00 38 | end: 2018-12-31 00:00:00 39 | train: 40 | start: 1979-01-02 00:00:00 41 | end: 2015-12-31 00:00:00 42 | val: 43 | start: 2016-01-01 00:00:00 44 | end: 2016-12-31 00:00:00 45 | 46 | -------------------------------------------------------------------------------- /config/template/rasp_thuerey_highres_z_500_5day.yaml: -------------------------------------------------------------------------------- 1 | data_specs: 2 | conditioning_time_step: [0, -1, -2] 3 | max_chunksize: 0.5 # maximal size per chunk in GB 4 | conditioning_vars: 5 | # fields with pressure levels 6 | temperature: 7 | level: [50, 250, 500, 600, 700, 850, 925] 8 | geopotential: 9 | level: [50, 250, 500, 600, 700, 850, 925] 10 | u_component_of_wind: 11 | level: [50, 250, 500, 600, 700, 850, 925] 12 | v_component_of_wind: 13 | level: [50, 250, 500, 600, 700, 850, 925] 14 | specific_humidity: 15 | level: [50, 250, 500, 600, 700, 850, 925] 16 | # 2d fields: 17 | 2m_temperature: 18 | level: 19 | total_precipitation: 20 | level: 21 | toa_incident_solar_radiation: 22 | level: 23 | # constant input fields 24 | constants: 25 | - orography 26 | - lat2d 27 | - lsm 28 | delta_t: 6 29 | lead_time: 20 # 5 days (5 times 4 times 6 hours) 30 | output_vars: 31 | # 2d fields: 32 | geopotential: 33 | level: [500,] 34 | spatial_resolution: 2.8125deg 35 | exp_data: 36 | test: 37 | start: 2017-01-01 00:00:00 38 | end: 2018-12-31 00:00:00 39 | train: 40 | start: 1979-01-02 00:00:00 41 | end: 2015-12-31 00:00:00 42 | val: 43 | start: 2016-01-01 00:00:00 44 | end: 2016-12-31 00:00:00 45 | 46 | -------------------------------------------------------------------------------- /config/template/rasp_thuerey_t_850_1day.yaml: -------------------------------------------------------------------------------- 1 | data_specs: 2 | conditioning_time_step: [0, -1, -2] 3 | max_chunksize: 0.5 # maximal size per chunk in GB 4 | conditioning_vars: 5 | # fields with pressure levels 6 | temperature: 7 | level: [50, 250, 500, 600, 700, 850, 925] 8 | geopotential: 9 | level: [50, 250, 500, 600, 700, 850, 925] 10 | u_component_of_wind: 11 | level: [50, 250, 500, 600, 700, 850, 925] 12 | v_component_of_wind: 13 | level: [50, 250, 500, 600, 700, 850, 925] 14 | specific_humidity: 15 | level: [50, 250, 500, 600, 700, 850, 925] 16 | # 2d fields: 17 | 2m_temperature: 18 | level: 19 | total_precipitation: 20 | level: 21 | toa_incident_solar_radiation: 22 | level: 23 | # constant input fields 24 | constants: 25 | - orography 26 | - lat2d 27 | - lsm 28 | delta_t: 6 29 | lead_time: 4 # 1 days (1 times 4 times 6 hours) 30 | output_vars: 
31 | # 2d fields: 32 | temperature: 33 | level: [850,] 34 | spatial_resolution: 5.625deg 35 | exp_data: 36 | test: 37 | start: 2017-01-01 00:00:00 38 | end: 2018-12-31 00:00:00 39 | train: 40 | start: 1979-01-02 00:00:00 41 | end: 2015-12-31 00:00:00 42 | val: 43 | start: 2016-01-01 00:00:00 44 | end: 2016-12-31 00:00:00 45 | 46 | -------------------------------------------------------------------------------- /config/template/rasp_thuerey_t_850_2day.yaml: -------------------------------------------------------------------------------- 1 | data_specs: 2 | conditioning_time_step: [0, -1, -2] 3 | max_chunksize: 0.5 # maximal size per chunk in GB 4 | conditioning_vars: 5 | # fields with pressure levels 6 | temperature: 7 | level: [50, 250, 500, 600, 700, 850, 925] 8 | geopotential: 9 | level: [50, 250, 500, 600, 700, 850, 925] 10 | u_component_of_wind: 11 | level: [50, 250, 500, 600, 700, 850, 925] 12 | v_component_of_wind: 13 | level: [50, 250, 500, 600, 700, 850, 925] 14 | specific_humidity: 15 | level: [50, 250, 500, 600, 700, 850, 925] 16 | # 2d fields: 17 | 2m_temperature: 18 | level: 19 | total_precipitation: 20 | level: 21 | toa_incident_solar_radiation: 22 | level: 23 | # constant input fields 24 | constants: 25 | - orography 26 | - lat2d 27 | - lsm 28 | delta_t: 6 29 | lead_time: 8 # 2 days (2 times 4 times 6 hours) 30 | output_vars: 31 | # 2d fields: 32 | temperature: 33 | level: [850,] 34 | spatial_resolution: 5.625deg 35 | exp_data: 36 | test: 37 | start: 2017-01-01 00:00:00 38 | end: 2018-12-31 00:00:00 39 | train: 40 | start: 1979-01-02 00:00:00 41 | end: 2015-12-31 00:00:00 42 | val: 43 | start: 2016-01-01 00:00:00 44 | end: 2016-12-31 00:00:00 45 | 46 | -------------------------------------------------------------------------------- /config/template/rasp_thuerey_t_850_3day.yaml: -------------------------------------------------------------------------------- 1 | data_specs: 2 | conditioning_time_step: [0, -1, -2] 3 | max_chunksize: 0.5 # maximal size per chunk in GB 4 | conditioning_vars: 5 | # fields with pressure levels 6 | temperature: 7 | level: [50, 250, 500, 600, 700, 850, 925] 8 | geopotential: 9 | level: [50, 250, 500, 600, 700, 850, 925] 10 | u_component_of_wind: 11 | level: [50, 250, 500, 600, 700, 850, 925] 12 | v_component_of_wind: 13 | level: [50, 250, 500, 600, 700, 850, 925] 14 | specific_humidity: 15 | level: [50, 250, 500, 600, 700, 850, 925] 16 | # 2d fields: 17 | 2m_temperature: 18 | level: 19 | total_precipitation: 20 | level: 21 | toa_incident_solar_radiation: 22 | level: 23 | # constant input fields 24 | constants: 25 | - orography 26 | - lat2d 27 | - lsm 28 | delta_t: 6 29 | lead_time: 12 # 3 days (3 times 4 times 6 hours) 30 | output_vars: 31 | # 2d fields: 32 | temperature: 33 | level: [850,] 34 | spatial_resolution: 5.625deg 35 | exp_data: 36 | test: 37 | start: 2017-01-01 00:00:00 38 | end: 2018-12-31 00:00:00 39 | train: 40 | start: 1979-01-02 00:00:00 41 | end: 2015-12-31 00:00:00 42 | val: 43 | start: 2016-01-01 00:00:00 44 | end: 2016-12-31 00:00:00 45 | 46 | -------------------------------------------------------------------------------- /config/template/rasp_thuerey_t_850_4day.yaml: -------------------------------------------------------------------------------- 1 | data_specs: 2 | conditioning_time_step: [0, -1, -2] 3 | max_chunksize: 0.5 # maximal size per chunk in GB 4 | conditioning_vars: 5 | # fields with pressure levels 6 | temperature: 7 | level: [50, 250, 500, 600, 700, 850, 925] 8 | geopotential: 9 | level: [50, 250, 500, 600, 700, 
850, 925] 10 | u_component_of_wind: 11 | level: [50, 250, 500, 600, 700, 850, 925] 12 | v_component_of_wind: 13 | level: [50, 250, 500, 600, 700, 850, 925] 14 | specific_humidity: 15 | level: [50, 250, 500, 600, 700, 850, 925] 16 | # 2d fields: 17 | 2m_temperature: 18 | level: 19 | total_precipitation: 20 | level: 21 | toa_incident_solar_radiation: 22 | level: 23 | # constant input fields 24 | constants: 25 | - orography 26 | - lat2d 27 | - lsm 28 | delta_t: 6 29 | lead_time: 16 # 4 days (4 times 4 times 6 hours) 30 | output_vars: 31 | # 2d fields: 32 | temperature: 33 | level: [850,] 34 | spatial_resolution: 5.625deg 35 | exp_data: 36 | test: 37 | start: 2017-01-01 00:00:00 38 | end: 2018-12-31 00:00:00 39 | train: 40 | start: 1979-01-02 00:00:00 41 | end: 2015-12-31 00:00:00 42 | val: 43 | start: 2016-01-01 00:00:00 44 | end: 2016-12-31 00:00:00 45 | 46 | -------------------------------------------------------------------------------- /config/template/rasp_thuerey_t_850_5day.yaml: -------------------------------------------------------------------------------- 1 | data_specs: 2 | conditioning_time_step: [0, -1, -2] 3 | max_chunksize: 0.5 # maximal size per chunk in GB 4 | conditioning_vars: 5 | # fields with pressure levels 6 | temperature: 7 | level: [50, 250, 500, 600, 700, 850, 925] 8 | geopotential: 9 | level: [50, 250, 500, 600, 700, 850, 925] 10 | u_component_of_wind: 11 | level: [50, 250, 500, 600, 700, 850, 925] 12 | v_component_of_wind: 13 | level: [50, 250, 500, 600, 700, 850, 925] 14 | specific_humidity: 15 | level: [50, 250, 500, 600, 700, 850, 925] 16 | # 2d fields: 17 | 2m_temperature: 18 | level: 19 | total_precipitation: 20 | level: 21 | toa_incident_solar_radiation: 22 | level: 23 | # constant input fields 24 | constants: 25 | - orography 26 | - lat2d 27 | - lsm 28 | delta_t: 6 29 | lead_time: 20 # 5 days (5 times 4 times 6 hours) 30 | output_vars: 31 | # 2d fields: 32 | temperature: 33 | level: [850,] 34 | spatial_resolution: 5.625deg 35 | exp_data: 36 | test: 37 | start: 2017-01-01 00:00:00 38 | end: 2018-12-31 00:00:00 39 | train: 40 | start: 1979-01-02 00:00:00 41 | end: 2015-12-31 00:00:00 42 | val: 43 | start: 2016-01-01 00:00:00 44 | end: 2016-12-31 00:00:00 45 | 46 | -------------------------------------------------------------------------------- /config/template/rasp_thuerey_z_500_1day.yaml: -------------------------------------------------------------------------------- 1 | data_specs: 2 | conditioning_time_step: [0, -1, -2] 3 | max_chunksize: 0.5 # maximal size per chunk in GB 4 | conditioning_vars: 5 | # fields with pressure levels 6 | temperature: 7 | level: [50, 250, 500, 600, 700, 850, 925] 8 | geopotential: 9 | level: [50, 250, 500, 600, 700, 850, 925] 10 | u_component_of_wind: 11 | level: [50, 250, 500, 600, 700, 850, 925] 12 | v_component_of_wind: 13 | level: [50, 250, 500, 600, 700, 850, 925] 14 | specific_humidity: 15 | level: [50, 250, 500, 600, 700, 850, 925] 16 | # 2d fields: 17 | 2m_temperature: 18 | level: 19 | total_precipitation: 20 | level: 21 | toa_incident_solar_radiation: 22 | level: 23 | # constant input fields 24 | constants: 25 | - orography 26 | - lat2d 27 | - lsm 28 | delta_t: 6 29 | lead_time: 4 # 1 days (1 times 4 times 6 hours) 30 | output_vars: 31 | # 2d fields: 32 | geopotential: 33 | level: [500,] 34 | spatial_resolution: 5.625deg 35 | exp_data: 36 | test: 37 | start: 2017-01-01 00:00:00 38 | end: 2018-12-31 00:00:00 39 | train: 40 | start: 1979-01-02 00:00:00 41 | end: 2015-12-31 00:00:00 42 | val: 43 | start: 2016-01-01 
00:00:00 44 | end: 2016-12-31 00:00:00 45 | 46 | -------------------------------------------------------------------------------- /config/template/rasp_thuerey_z_500_2day.yaml: -------------------------------------------------------------------------------- 1 | data_specs: 2 | conditioning_time_step: [0, -1, -2] 3 | max_chunksize: 0.5 # maximal size per chunk in GB 4 | conditioning_vars: 5 | # fields with pressure levels 6 | temperature: 7 | level: [50, 250, 500, 600, 700, 850, 925] 8 | geopotential: 9 | level: [50, 250, 500, 600, 700, 850, 925] 10 | u_component_of_wind: 11 | level: [50, 250, 500, 600, 700, 850, 925] 12 | v_component_of_wind: 13 | level: [50, 250, 500, 600, 700, 850, 925] 14 | specific_humidity: 15 | level: [50, 250, 500, 600, 700, 850, 925] 16 | # 2d fields: 17 | 2m_temperature: 18 | level: 19 | total_precipitation: 20 | level: 21 | toa_incident_solar_radiation: 22 | level: 23 | # constant input fields 24 | constants: 25 | - orography 26 | - lat2d 27 | - lsm 28 | delta_t: 6 29 | lead_time: 8 # 2 days (2 times 4 times 6 hours) 30 | output_vars: 31 | # 2d fields: 32 | geopotential: 33 | level: [500,] 34 | spatial_resolution: 5.625deg 35 | exp_data: 36 | test: 37 | start: 2017-01-01 00:00:00 38 | end: 2018-12-31 00:00:00 39 | train: 40 | start: 1979-01-02 00:00:00 41 | end: 2015-12-31 00:00:00 42 | val: 43 | start: 2016-01-01 00:00:00 44 | end: 2016-12-31 00:00:00 45 | 46 | -------------------------------------------------------------------------------- /config/template/rasp_thuerey_z_500_3day.yaml: -------------------------------------------------------------------------------- 1 | data_specs: 2 | conditioning_time_step: [0, -1, -2] 3 | max_chunksize: 0.5 # maximal size per chunk in GB 4 | conditioning_vars: 5 | # fields with pressure levels 6 | temperature: 7 | level: [50, 250, 500, 600, 700, 850, 925] 8 | geopotential: 9 | level: [50, 250, 500, 600, 700, 850, 925] 10 | u_component_of_wind: 11 | level: [50, 250, 500, 600, 700, 850, 925] 12 | v_component_of_wind: 13 | level: [50, 250, 500, 600, 700, 850, 925] 14 | specific_humidity: 15 | level: [50, 250, 500, 600, 700, 850, 925] 16 | # 2d fields: 17 | 2m_temperature: 18 | level: 19 | total_precipitation: 20 | level: 21 | toa_incident_solar_radiation: 22 | level: 23 | # constant input fields 24 | constants: 25 | - orography 26 | - lat2d 27 | - lsm 28 | delta_t: 6 29 | lead_time: 12 # 3 days (3 times 4 times 6 hours) 30 | output_vars: 31 | # 2d fields: 32 | geopotential: 33 | level: [500,] 34 | spatial_resolution: 5.625deg 35 | exp_data: 36 | test: 37 | start: 2017-01-01 00:00:00 38 | end: 2018-12-31 00:00:00 39 | train: 40 | start: 1979-01-02 00:00:00 41 | end: 2015-12-31 00:00:00 42 | val: 43 | start: 2016-01-01 00:00:00 44 | end: 2016-12-31 00:00:00 45 | 46 | -------------------------------------------------------------------------------- /config/template/rasp_thuerey_z_500_3day_2csteps.yaml: -------------------------------------------------------------------------------- 1 | data_specs: 2 | conditioning_time_step: [0, -1] 3 | max_chunksize: 0.5 # maximal size per chunk in GB 4 | conditioning_vars: 5 | # fields with pressure levels 6 | temperature: 7 | level: [50, 250, 500, 600, 700, 850, 925] 8 | geopotential: 9 | level: [50, 250, 500, 600, 700, 850, 925] 10 | u_component_of_wind: 11 | level: [50, 250, 500, 600, 700, 850, 925] 12 | v_component_of_wind: 13 | level: [50, 250, 500, 600, 700, 850, 925] 14 | specific_humidity: 15 | level: [50, 250, 500, 600, 700, 850, 925] 16 | # 2d fields: 17 | 2m_temperature: 18 | level: 19 
| total_precipitation: 20 | level: 21 | toa_incident_solar_radiation: 22 | level: 23 | # constant input fields 24 | constants: 25 | - orography 26 | - lat2d 27 | - lsm 28 | delta_t: 6 29 | lead_time: 12 # 3 days (3 times 4 times 6 hours) 30 | output_vars: 31 | # 2d fields: 32 | geopotential: 33 | level: [500,] 34 | spatial_resolution: 5.625deg 35 | exp_data: 36 | test: 37 | start: 2017-01-01 00:00:00 38 | end: 2018-12-31 00:00:00 39 | train: 40 | start: 1979-01-02 00:00:00 41 | end: 2015-12-31 00:00:00 42 | val: 43 | start: 2016-01-01 00:00:00 44 | end: 2016-12-31 00:00:00 45 | 46 | -------------------------------------------------------------------------------- /config/template/rasp_thuerey_z_500_4day.yaml: -------------------------------------------------------------------------------- 1 | data_specs: 2 | conditioning_time_step: [0, -1, -2] 3 | max_chunksize: 0.5 # maximal size per chunk in GB 4 | conditioning_vars: 5 | # fields with pressure levels 6 | temperature: 7 | level: [50, 250, 500, 600, 700, 850, 925] 8 | geopotential: 9 | level: [50, 250, 500, 600, 700, 850, 925] 10 | u_component_of_wind: 11 | level: [50, 250, 500, 600, 700, 850, 925] 12 | v_component_of_wind: 13 | level: [50, 250, 500, 600, 700, 850, 925] 14 | specific_humidity: 15 | level: [50, 250, 500, 600, 700, 850, 925] 16 | # 2d fields: 17 | 2m_temperature: 18 | level: 19 | total_precipitation: 20 | level: 21 | toa_incident_solar_radiation: 22 | level: 23 | # constant input fields 24 | constants: 25 | - orography 26 | - lat2d 27 | - lsm 28 | delta_t: 6 29 | lead_time: 16 # 4 days (4 times 4 times 6 hours) 30 | output_vars: 31 | # 2d fields: 32 | geopotential: 33 | level: [500,] 34 | spatial_resolution: 5.625deg 35 | exp_data: 36 | test: 37 | start: 2017-01-01 00:00:00 38 | end: 2018-12-31 00:00:00 39 | train: 40 | start: 1979-01-02 00:00:00 41 | end: 2015-12-31 00:00:00 42 | val: 43 | start: 2016-01-01 00:00:00 44 | end: 2016-12-31 00:00:00 45 | 46 | -------------------------------------------------------------------------------- /config/template/rasp_thuerey_z_500_5day.yaml: -------------------------------------------------------------------------------- 1 | data_specs: 2 | conditioning_time_step: [0, -1, -2] 3 | max_chunksize: 0.5 # maximal size per chunk in GB 4 | conditioning_vars: 5 | # fields with pressure levels 6 | temperature: 7 | level: [50, 250, 500, 600, 700, 850, 925] 8 | geopotential: 9 | level: [50, 250, 500, 600, 700, 850, 925] 10 | u_component_of_wind: 11 | level: [50, 250, 500, 600, 700, 850, 925] 12 | v_component_of_wind: 13 | level: [50, 250, 500, 600, 700, 850, 925] 14 | specific_humidity: 15 | level: [50, 250, 500, 600, 700, 850, 925] 16 | # 2d fields: 17 | 2m_temperature: 18 | level: 19 | total_precipitation: 20 | level: 21 | toa_incident_solar_radiation: 22 | level: 23 | # constant input fields 24 | constants: 25 | - orography 26 | - lat2d 27 | - lsm 28 | delta_t: 6 29 | lead_time: 20 # 5 days (5 times 4 times 6 hours) 30 | output_vars: 31 | # 2d fields: 32 | geopotential: 33 | level: [500,] 34 | spatial_resolution: 5.625deg 35 | exp_data: 36 | test: 37 | start: 2017-01-01 00:00:00 38 | end: 2018-12-31 00:00:00 39 | train: 40 | start: 1979-01-02 00:00:00 41 | end: 2015-12-31 00:00:00 42 | val: 43 | start: 2016-01-01 00:00:00 44 | end: 2016-12-31 00:00:00 45 | 46 | -------------------------------------------------------------------------------- /config/train.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - paths: default_paths 3 | - _self_ 
4 | 5 | hydra: 6 | job: 7 | chdir: False 8 | run: 9 | dir: ${paths.dir_HydraConfigs}/training/${experiment.data.template}/${hydra.runtime.choices.experiment}/${now:%Y-%m-%d_%H-%M-%S} -------------------------------------------------------------------------------- /env_data.yml: -------------------------------------------------------------------------------- 1 | name: WD_data # for creating datasets. 2 | channels: 3 | - defaults 4 | - conda-forge 5 | dependencies: 6 | - ipykernel # jupyter notebooks 7 | - xarray # data processing 8 | - netcdf4 9 | - dask 10 | - zarr 11 | - pytorch # to be able to write to .pt files if wanted 12 | - omegaconf # config 13 | - hydra-core 14 | - matplotlib # for simple diagnostic plotting stuff -------------------------------------------------------------------------------- /env_eval.yml: -------------------------------------------------------------------------------- 1 | name: WD_eval # to be used for evaluation purposes: i.e. plotting, computing metrics, ... 2 | channels: 3 | - conda-forge 4 | - defaults 5 | dependencies: 6 | - numpy<=1.24 7 | - esmpy=8.0.1 8 | - xesmf=0.3.0 9 | - dask 10 | - pyshtools 11 | - ipykernel 12 | - netcdf4 13 | - matplotlib 14 | - tensorboard 15 | - cartopy 16 | - scipy 17 | - xskillscore 18 | - zarr 19 | - hydra-core 20 | - omegaconf 21 | -------------------------------------------------------------------------------- /env_model.yml: -------------------------------------------------------------------------------- 1 | name: WD_model # for training or predicting with a model 2 | channels: 3 | - nvidia 4 | - pytorch 5 | - defaults 6 | dependencies: 7 | - python=3.11 8 | - pytorch-cuda=11.7 9 | - torchaudio 10 | - pytorch # =2.0.1=py3.11_cuda11.7_cudnn8.5.0_0 11 | - torchvision 12 | - cudatoolkit=11.7.0 13 | - cudnn 14 | - xarray 15 | # - numpy 16 | - ipykernel 17 | - pytorch-lightning 18 | # - python-wget 19 | # - kornia 20 | - einops 21 | - netcdf4 22 | - dask 23 | - tensorboard 24 | - matplotlib 25 | # - cartopy 26 | # - munch 27 | # - pytest 28 | # - scipy 29 | # - xskillscore 30 | - zarr 31 | - timm # for running fourcastnet.
32 | - omegaconf 33 | - hydra-core 34 | # - xesmf 35 | -------------------------------------------------------------------------------- /images/chronologic_timesteps.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ECMWFCode4Earth/diffusion-models-for-weather-prediction/936cb79e7abe5978ccb95e9c571c203db8b10629/images/chronologic_timesteps.jpg -------------------------------------------------------------------------------- /images/ensemble_condition.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ECMWFCode4Earth/diffusion-models-for-weather-prediction/936cb79e7abe5978ccb95e9c571c203db8b10629/images/ensemble_condition.jpg -------------------------------------------------------------------------------- /images/ensemble_predictions.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ECMWFCode4Earth/diffusion-models-for-weather-prediction/936cb79e7abe5978ccb95e9c571c203db8b10629/images/ensemble_predictions.jpg -------------------------------------------------------------------------------- /images/ensemble_stats.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ECMWFCode4Earth/diffusion-models-for-weather-prediction/936cb79e7abe5978ccb95e9c571c203db8b10629/images/ensemble_stats.jpg -------------------------------------------------------------------------------- /images/ensemble_std.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ECMWFCode4Earth/diffusion-models-for-weather-prediction/936cb79e7abe5978ccb95e9c571c203db8b10629/images/ensemble_std.jpg -------------------------------------------------------------------------------- /images/heatwave_predictions_step_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ECMWFCode4Earth/diffusion-models-for-weather-prediction/936cb79e7abe5978ccb95e9c571c203db8b10629/images/heatwave_predictions_step_0.png -------------------------------------------------------------------------------- /images/heatwave_predictions_step_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ECMWFCode4Earth/diffusion-models-for-weather-prediction/936cb79e7abe5978ccb95e9c571c203db8b10629/images/heatwave_predictions_step_1.png -------------------------------------------------------------------------------- /images/heatwave_predictions_step_10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ECMWFCode4Earth/diffusion-models-for-weather-prediction/936cb79e7abe5978ccb95e9c571c203db8b10629/images/heatwave_predictions_step_10.png -------------------------------------------------------------------------------- /images/heatwave_predictions_step_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ECMWFCode4Earth/diffusion-models-for-weather-prediction/936cb79e7abe5978ccb95e9c571c203db8b10629/images/heatwave_predictions_step_2.png -------------------------------------------------------------------------------- /images/heatwave_predictions_step_20.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ECMWFCode4Earth/diffusion-models-for-weather-prediction/936cb79e7abe5978ccb95e9c571c203db8b10629/images/heatwave_predictions_step_20.png -------------------------------------------------------------------------------- /images/heatwave_predictions_step_32.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ECMWFCode4Earth/diffusion-models-for-weather-prediction/936cb79e7abe5978ccb95e9c571c203db8b10629/images/heatwave_predictions_step_32.png -------------------------------------------------------------------------------- /images/heatwave_predictions_step_5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ECMWFCode4Earth/diffusion-models-for-weather-prediction/936cb79e7abe5978ccb95e9c571c203db8b10629/images/heatwave_predictions_step_5.png -------------------------------------------------------------------------------- /images/heatwave_true_anomaly.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ECMWFCode4Earth/diffusion-models-for-weather-prediction/936cb79e7abe5978ccb95e9c571c203db8b10629/images/heatwave_true_anomaly.png -------------------------------------------------------------------------------- /images/performance_leadtime_version_0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ECMWFCode4Earth/diffusion-models-for-weather-prediction/936cb79e7abe5978ccb95e9c571c203db8b10629/images/performance_leadtime_version_0.jpg -------------------------------------------------------------------------------- /images/performance_leadtime_version_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ECMWFCode4Earth/diffusion-models-for-weather-prediction/936cb79e7abe5978ccb95e9c571c203db8b10629/images/performance_leadtime_version_1.jpg -------------------------------------------------------------------------------- /images/performance_leadtime_version_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ECMWFCode4Earth/diffusion-models-for-weather-prediction/936cb79e7abe5978ccb95e9c571c203db8b10629/images/performance_leadtime_version_2.jpg -------------------------------------------------------------------------------- /images/performance_leadtime_version_3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ECMWFCode4Earth/diffusion-models-for-weather-prediction/936cb79e7abe5978ccb95e9c571c203db8b10629/images/performance_leadtime_version_3.jpg -------------------------------------------------------------------------------- /images/performance_leadtime_version_4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ECMWFCode4Earth/diffusion-models-for-weather-prediction/936cb79e7abe5978ccb95e9c571c203db8b10629/images/performance_leadtime_version_4.jpg -------------------------------------------------------------------------------- /images/predictions.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ECMWFCode4Earth/diffusion-models-for-weather-prediction/936cb79e7abe5978ccb95e9c571c203db8b10629/images/predictions.jpg 
-------------------------------------------------------------------------------- /images/spectra.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ECMWFCode4Earth/diffusion-models-for-weather-prediction/936cb79e7abe5978ccb95e9c571c203db8b10629/images/spectra.png -------------------------------------------------------------------------------- /images/spectra_no_entries.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ECMWFCode4Earth/diffusion-models-for-weather-prediction/936cb79e7abe5978ccb95e9c571c203db8b10629/images/spectra_no_entries.png -------------------------------------------------------------------------------- /images/t_850_lowres.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ECMWFCode4Earth/diffusion-models-for-weather-prediction/936cb79e7abe5978ccb95e9c571c203db8b10629/images/t_850_lowres.gif -------------------------------------------------------------------------------- /images/timeseries.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ECMWFCode4Earth/diffusion-models-for-weather-prediction/936cb79e7abe5978ccb95e9c571c203db8b10629/images/timeseries.jpg -------------------------------------------------------------------------------- /images/z_500_lowres.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ECMWFCode4Earth/diffusion-models-for-weather-prediction/936cb79e7abe5978ccb95e9c571c203db8b10629/images/z_500_lowres.gif -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | experimental-string-processing = true 3 | line-length=79 4 | py36 = true 5 | include = '\.pyi?$' 6 | exclude = ''' 7 | /( 8 | \.git 9 | | \.hg 10 | | \.mypy_cache 11 | | \.tox 12 | | \.venv 13 | | _build 14 | | buck-out 15 | | build 16 | | dist 17 | 18 | # The following are specific to Black, you probably don't want those. 19 | | blib2to3 20 | | tests/data 21 | )/ 22 | ''' -------------------------------------------------------------------------------- /s10_write_predictions_vae.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | import torch 5 | from torch.utils.data import DataLoader 6 | import lightning as L 7 | 8 | import hydra 9 | from omegaconf import DictConfig, OmegaConf 10 | from WD.utils import create_dir 11 | 12 | from dm_zoo.latent.vae.vae_lightning_module import VAE 13 | from WD.datasets import Conditional_Dataset_Zarr_Iterable 14 | from WD.io import create_xr_output_variables 15 | 16 | import numpy as np 17 | 18 | @hydra.main(version_base=None, config_path="./config", config_name="inference") 19 | def vae_inference(config: DictConfig) -> None: 20 | hydra_cfg = hydra.core.hydra_config.HydraConfig.get() 21 | dir_name = hydra_cfg['runtime']['output_dir'] # the directory the hydra log is written to. 22 | dir_name = os.path.basename(os.path.normpath(dir_name)) # we only need the last part 23 | 24 | model_name = config.model_name # we have to pass this to the bash file every time! (should contain a string). 
25 | experiment_name = hydra_cfg['runtime']['choices']['experiment'] 26 | 27 | ds_config = OmegaConf.load(f"{config.paths.dir_HydraConfigs}/data/{config.data.template}/.hydra/config.yaml") 28 | ml_config = OmegaConf.load(f"{config.paths.dir_HydraConfigs}/training/{config.data.template}/{experiment_name}/{config.model_name}/.hydra/config.yaml") 29 | 30 | model_output_dir = config.paths.dir_ModelOutput 31 | 32 | model_load_dir = Path(f"{config.paths.dir_SavedModels}/{config.data.template}/{experiment_name}/{config.model_name}/lightning_logs/version_0/checkpoints/") 33 | 34 | test_ds_path = f"{config.paths.dir_PreprocessedData}{config.data.template}_test.zarr" 35 | 36 | ds = Conditional_Dataset_Zarr_Iterable(test_ds_path, ds_config.template, shuffle_chunks=config.shuffle_chunks, 37 | shuffle_in_chunks=config.shuffle_in_chunks) 38 | 39 | conditioning_channels = ds.array_inputs.shape[1] * len(ds.conditioning_timesteps) + ds.array_constants.shape[0] 40 | generated_channels = ds.array_targets.shape[1] 41 | img_size = ds.array_targets.shape[-2:] 42 | 43 | print(ml_config) 44 | 45 | if ml_config.experiment.vae.type == "input": 46 | n_channel = conditioning_channels 47 | elif ml_config.experiment.vae.type == "output": 48 | n_channel = generated_channels 49 | else: 50 | raise AssertionError 51 | 52 | in_shape = (n_channel, img_size[0], img_size[1]) 53 | 54 | model_ckpt = [x for x in model_load_dir.iterdir()][0] 55 | 56 | restored_model = VAE.load_from_checkpoint(model_ckpt, map_location="cpu", 57 | inp_shape = in_shape, 58 | dim=ml_config.experiment.vae.dim, 59 | channel_mult = ml_config.experiment.vae.channel_mult, 60 | batch_size = ml_config.experiment.vae.batch_size, 61 | lr = ml_config.experiment.vae.lr, 62 | lr_scheduler_name=ml_config.experiment.vae.lr_scheduler_name, 63 | num_workers = ml_config.experiment.vae.num_workers, 64 | beta = ml_config.experiment.vae.beta, 65 | data_type = ml_config.experiment.vae.type) 66 | 67 | dl = DataLoader(ds, batch_size=ml_config.experiment.vae.batch_size) 68 | 69 | out = [] 70 | for i, data in enumerate(dl): 71 | 72 | r, _, x, _ = restored_model(data) 73 | 74 | if i==0: 75 | print(r.shape, x.shape) 76 | print(f"Input reduction factor: {np.round(np.prod(r.shape[1:])/np.prod(x.shape[1:]), decimals=2)}") 77 | 78 | out.append(r) 79 | 80 | out = torch.cat(out, dim=0).unsqueeze(dim=0) # to keep compatible with the version that uses ensemble members 81 | 82 | model_output_dir = os.path.join(model_output_dir, config.data.template, experiment_name, model_name, dir_name) 83 | create_dir(model_output_dir) 84 | 85 | # need the view to create axis for 86 | # different ensemble members (although only 1 here). 
87 | 88 | targets = torch.tensor(ds.data.targets.data[ds.start+ds.lead_time:ds.stop+ds.lead_time], dtype=torch.float).unsqueeze(dim=0) 89 | 90 | 91 | gen_xr = create_xr_output_variables( 92 | out, 93 | zarr_path=f"{config.paths.dir_PreprocessedDatasets}/{config.data.template}_test.zarr/targets", 94 | config=ds_config, 95 | min_max_file_path=f"{config.paths.dir_PreprocessedDatasets}/{config.data.template}_output_min_max.nc" 96 | ) 97 | 98 | 99 | target_xr = create_xr_output_variables( 100 | targets, 101 | zarr_path=f"{config.paths.dir_PreprocessedDatasets}/{config.data.template}_test.zarr/targets", 102 | config=ds_config, 103 | min_max_file_path=f"{config.paths.dir_PreprocessedDatasets}/{config.data.template}_output_min_max.nc" 104 | ) 105 | 106 | gen_dir = os.path.join(model_output_dir, "gen.nc") 107 | gen_xr.to_netcdf(gen_dir) 108 | print(f"Generated data written at: {gen_dir}") 109 | 110 | target_dir = os.path.join(model_output_dir, "target.nc") 111 | target_xr.to_netcdf(target_dir) 112 | print(f"Target data written at: {target_dir}") 113 | 114 | if __name__ == '__main__': 115 | vae_inference() 116 | 117 | -------------------------------------------------------------------------------- /s11_train_LFD.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import lightning as L 3 | from lightning.pytorch import loggers as pl_loggers 4 | from lightning.pytorch.callbacks import LearningRateMonitor 5 | from lightning.pytorch.callbacks import ( 6 | EarlyStopping, 7 | ) 8 | import os 9 | 10 | import hydra 11 | from omegaconf import DictConfig, OmegaConf, open_dict 12 | 13 | from dm_zoo.dff.EMA import EMA 14 | from dm_zoo.diffusion.LFD_lightning import LatentForecastDiffusion 15 | 16 | from WD.datasets import Conditional_Dataset_Zarr_Iterable 17 | from WD.utils import check_devices, create_dir, AreaWeightedMSELoss 18 | 19 | @hydra.main(version_base=None, config_path="./config", config_name="train") 20 | 21 | def train_LFD(config: DictConfig) -> None: 22 | hydra_cfg = hydra.core.hydra_config.HydraConfig.get() 23 | dir_name = hydra_cfg['runtime']['output_dir'] # the directory the hydra log is written to. 
24 | dir_name = os.path.basename(os.path.normpath(dir_name)) # we only need the last part 25 | exp_name = hydra_cfg['runtime']['choices']['experiment'] 26 | 27 | print(f"The torch version being used is {torch.__version__}") 28 | check_devices() 29 | 30 | # load config 31 | print(f"Loading dataset {config.experiment.data.template}") 32 | # ds_config_path = os.path.join(conf.base_path, f"{conf.template}.yml") 33 | # ds_config = load_config(ds_config_path) 34 | ds_config = OmegaConf.load(f"{config.paths.dir_HydraConfigs}/data/{config.experiment.data.template}/.hydra/config.yaml") 35 | 36 | train_ds_path = config.paths.dir_PreprocessedDatasets + f"{config.experiment.data.template}_train.zarr" 37 | train_ds = Conditional_Dataset_Zarr_Iterable(train_ds_path, ds_config.template, shuffle_chunks=config.experiment.data.train_shuffle_chunks, 38 | shuffle_in_chunks=config.experiment.data.train_shuffle_in_chunks) 39 | 40 | val_ds_path = config.paths.dir_PreprocessedDatasets + f"{config.experiment.data.template}_val.zarr" 41 | val_ds = Conditional_Dataset_Zarr_Iterable(val_ds_path, ds_config.template, shuffle_chunks=config.experiment.data.val_shuffle_chunks, shuffle_in_chunks=config.experiment.data.val_shuffle_in_chunks) 42 | 43 | # select loss_fn: 44 | if config.experiment.model.loss_fn_name == "MSE_Loss": 45 | loss_fn = torch.nn.functional.mse_loss 46 | elif config.experiment.model.loss_fn_name == "AreaWeighted_MSE_Loss": 47 | lat_grid = train_ds.data.targets.lat[:] 48 | lon_grid = train_ds.data.targets.lon[:] 49 | loss_fn = AreaWeightedMSELoss(lat_grid, lon_grid).loss_fn 50 | else: 51 | raise NotImplementedError("Invalid loss function.") 52 | 53 | print(config.experiment.model.diffusion.sampler_name) 54 | if config.experiment.model.diffusion.sampler_name == "DDPM": # this is the default case 55 | sampler = None 56 | else: 57 | raise NotImplementedError("This sampler has not been implemented.") 58 | 59 | model_dir = f"{config.paths.dir_SavedModels}/{config.experiment.data.template}/{exp_name}/{dir_name}/" 60 | create_dir(model_dir) 61 | 62 | # set up logger: 63 | tb_logger = pl_loggers.TensorBoardLogger(save_dir=model_dir) 64 | 65 | # set up diffusion model: 66 | 67 | conditioning_channels = train_ds.array_inputs.shape[1] * len(train_ds.conditioning_timesteps) + train_ds.array_constants.shape[0] 68 | generated_channels = train_ds.array_targets.shape[1] 69 | print("generated channels: {} conditioning channels: {}".format(generated_channels, conditioning_channels)) 70 | 71 | image_size = train_ds.array_inputs.shape[-2:] 72 | 73 | with open_dict(config): 74 | config.experiment.model.image_size = image_size 75 | config.experiment.model.generated_channels = generated_channels 76 | config.experiment.model.conditioning_channels = conditioning_channels 77 | 78 | 79 | model= LatentForecastDiffusion(config.experiment.model, 80 | train_dataset=train_ds, 81 | valid_dataset=val_ds, 82 | loss_fn = loss_fn, 83 | sampler = sampler) 84 | 85 | lr_monitor = LearningRateMonitor(logging_interval="step") 86 | 87 | early_stopping = EarlyStopping( 88 | monitor="val_reconstruction_loss", mode="min", patience=config.experiment.training.patience 89 | ) 90 | 91 | trainer = L.Trainer( 92 | max_steps=config.experiment.training.max_steps, 93 | limit_val_batches=config.experiment.training.limit_val_batches, 94 | accelerator=config.experiment.training.accelerator, 95 | devices=config.experiment.training.devices, 96 | callbacks=[EMA(config.experiment.training.ema_decay), lr_monitor, early_stopping], 97 | logger=tb_logger, 98 | 
gradient_clip_val=0.5 # Gradient clip value for exploding gradient 99 | ) 100 | 101 | trainer.fit(model) 102 | 103 | if __name__ == '__main__': 104 | train_LFD() -------------------------------------------------------------------------------- /s12_write_predictions_LFD.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | import torch 5 | from torch.utils.data import DataLoader 6 | import lightning as L 7 | 8 | import hydra 9 | from omegaconf import DictConfig, OmegaConf, open_dict 10 | from WD.utils import create_dir 11 | 12 | from dm_zoo.diffusion.LFD_lightning import LatentForecastDiffusion 13 | 14 | from WD.datasets import Conditional_Dataset_Zarr_Iterable 15 | from WD.io import create_xr_output_variables 16 | 17 | from WD.utils import AreaWeightedMSELoss 18 | 19 | 20 | @hydra.main(version_base=None, config_path="./config", config_name="inference") 21 | def LFD_inference(config: DictConfig) -> None: 22 | hydra_cfg = hydra.core.hydra_config.HydraConfig.get() 23 | dir_name = hydra_cfg['runtime']['output_dir'] # the directory the hydra log is written to. 24 | dir_name = os.path.basename(os.path.normpath(dir_name)) # we only need the last part 25 | 26 | model_name = config.model_name # we have to pass this to the bash file every time! (should contain a string). 27 | nens = config.n_ensemble_members 28 | experiment_name = hydra_cfg['runtime']['choices']['experiment'] 29 | 30 | ds_config = OmegaConf.load(f"{config.paths.dir_HydraConfigs}/data/{config.data.template}/.hydra/config.yaml") 31 | ml_config = OmegaConf.load(f"{config.paths.dir_HydraConfigs}/training/{config.data.template}/{experiment_name}/{config.model_name}/.hydra/config.yaml") 32 | 33 | model_output_dir = config.paths.dir_ModelOutput 34 | 35 | model_load_dir = Path(f"{config.paths.dir_SavedModels}/{config.data.template}/{experiment_name}/{config.model_name}/lightning_logs/version_0/checkpoints/") 36 | 37 | test_ds_path = f"{config.paths.dir_PreprocessedDatasets}{config.data.template}_test.zarr" 38 | 39 | ds = Conditional_Dataset_Zarr_Iterable(test_ds_path, ds_config.template, shuffle_chunks=config.shuffle_chunks, 40 | shuffle_in_chunks=config.shuffle_in_chunks) 41 | 42 | conditioning_channels = ds.array_inputs.shape[1] * len(ds.conditioning_timesteps) + ds.array_constants.shape[0] 43 | generated_channels = ds.array_targets.shape[1] 44 | img_size = ds.array_targets.shape[-2:] 45 | 46 | print(ml_config) 47 | 48 | model_ckpt = [x for x in model_load_dir.iterdir()][0] 49 | if ml_config.experiment.model.loss_fn_name == "MSE_Loss": 50 | loss_fn = torch.nn.functional.mse_loss 51 | elif ml_config.experiment.model.loss_fn_name == "AreaWeighted_MSE_Loss": 52 | lat_grid = ds.data.targets.lat[:] 53 | lon_grid = ds.data.targets.lon[:] 54 | loss_fn = AreaWeightedMSELoss(lat_grid, lon_grid).loss_fn 55 | else: 56 | raise NotImplementedError("Invalid loss function.") 57 | 58 | if ml_config.experiment.model.diffusion.sampler_name == "DDPM": # this is the default case 59 | sampler = None 60 | else: 61 | raise NotImplementedError("This sampler has not been implemented.") 62 | 63 | with open_dict(ml_config): 64 | ml_config.experiment.model.image_size = img_size 65 | ml_config.experiment.model.generated_channels = generated_channels 66 | ml_config.experiment.model.conditioning_channels = conditioning_channels 67 | 68 | restored_model = LatentForecastDiffusion.load_from_checkpoint(model_ckpt, map_location="cpu", 69 | model_config = ml_config.experiment.model, 70 | 
loss_fn = loss_fn, 71 | sampler = sampler) 72 | 73 | dl = DataLoader(ds, batch_size=ml_config.experiment.model.batch_size) 74 | trainer = L.Trainer() 75 | 76 | out = [] 77 | for i in range(nens): 78 | pred = trainer.predict(restored_model, dl) 79 | pred = torch.cat(pred, dim=0).unsqueeze(dim=0) 80 | out.append(pred) 81 | 82 | out = torch.cat(out, dim=0) # to keep compatible with the version that uses ensemble members 83 | 84 | print(out.shape) 85 | model_output_dir = os.path.join(model_output_dir, config.data.template, experiment_name, model_name, dir_name) 86 | create_dir(model_output_dir) 87 | 88 | # need the view to create axis for 89 | # different ensemble members (although only 1 here). 90 | 91 | targets = torch.tensor(ds.data.targets.data[ds.start+ds.lead_time:ds.stop+ds.lead_time], dtype=torch.float).unsqueeze(dim=0) 92 | 93 | 94 | gen_xr = create_xr_output_variables( 95 | out, 96 | zarr_path=f"{config.paths.dir_PreprocessedDatasets}/{config.data.template}_test.zarr/targets", 97 | config=ds_config, 98 | min_max_file_path=f"{config.paths.dir_PreprocessedDatasets}/{config.data.template}_output_min_max.nc" 99 | ) 100 | 101 | 102 | target_xr = create_xr_output_variables( 103 | targets, 104 | zarr_path=f"{config.paths.dir_PreprocessedDatasets}/{config.data.template}_test.zarr/targets", 105 | config=ds_config, 106 | min_max_file_path=f"{config.paths.dir_PreprocessedDatasets}/{config.data.template}_output_min_max.nc" 107 | ) 108 | 109 | gen_dir = os.path.join(model_output_dir, "gen.nc") 110 | gen_xr.to_netcdf(gen_dir) 111 | print(f"Generated data written at: {gen_dir}") 112 | 113 | target_dir = os.path.join(model_output_dir, "target.nc") 114 | target_xr.to_netcdf(target_dir) 115 | print(f"Target data written at: {target_dir}") 116 | 117 | if __name__ == '__main__': 118 | LFD_inference() 119 | 120 | -------------------------------------------------------------------------------- /s13_write_predictions_iterative.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | import hydra 5 | from omegaconf import DictConfig, OmegaConf 6 | 7 | import torch 8 | from torch.utils.data import DataLoader 9 | 10 | from dm_zoo.dff.PixelDiffusion import ( 11 | PixelDiffusionConditional, 12 | ) 13 | from WD.datasets import Conditional_Dataset_Zarr_Iterable 14 | from WD.utils import create_dir 15 | from WD.io import create_xr_output_variables 16 | import lightning as L 17 | 18 | 19 | @hydra.main(version_base=None, config_path="./config", config_name="inference") 20 | def main(config: DictConfig) -> None: 21 | hydra_cfg = hydra.core.hydra_config.HydraConfig.get() 22 | dir_name = hydra_cfg['runtime']['output_dir'] # the directory the hydra log is written to. 23 | dir_name = os.path.basename(os.path.normpath(dir_name)) # we only need the last part 24 | 25 | model_name = config.model_name # we have to pass this to the bash file every time! (should contain a string of the date the run was started). 26 | experiment_name = hydra_cfg['runtime']['choices']['experiment'] 27 | nens = config.n_ensemble_members # we have to pass this to the bash file every time! 
28 | 29 | ds_config = OmegaConf.load(f"{config.paths.dir_HydraConfigs}/data/{config.data.template}/.hydra/config.yaml") 30 | ml_config = OmegaConf.load(f"{config.paths.dir_HydraConfigs}/training/{config.data.template}/{experiment_name}/{config.model_name}/.hydra/config.yaml") 31 | 32 | model_output_dir = config.paths.dir_ModelOutput 33 | 34 | model_load_dir = Path(f"{config.paths.dir_SavedModels}/{config.data.template}/{experiment_name}/{config.model_name}/lightning_logs/version_0/checkpoints/") 35 | 36 | test_ds_path = f"{config.paths.dir_PreprocessedDatasets}{config.data.template}_test.zarr" 37 | 38 | 39 | assert config.shuffle_in_chunks is False, "no shuffling allowed for iterative predictions" 40 | assert config.shuffle_chunks is False, "no shuffling allowed for iterative predictions" 41 | 42 | ds = Conditional_Dataset_Zarr_Iterable(test_ds_path, ds_config.template, shuffle_chunks=config.shuffle_chunks, 43 | shuffle_in_chunks=config.shuffle_in_chunks) 44 | 45 | model_ckpt = [x for x in model_load_dir.iterdir()][0] 46 | 47 | conditioning_channels = ds.array_inputs.shape[1] * len(ds.conditioning_timesteps) + ds.array_constants.shape[0] 48 | generated_channels = ds.array_targets.shape[1] 49 | 50 | restored_model = PixelDiffusionConditional.load_from_checkpoint( 51 | model_ckpt, 52 | config=ml_config.experiment.pixel_diffusion, 53 | conditioning_channels=conditioning_channels, 54 | generated_channels=generated_channels, 55 | loss_fn=config.loss_fn, 56 | sampler=config.sampler, 57 | ) 58 | 59 | dl = DataLoader(ds, batch_size=config.batchsize) 60 | 61 | n_steps = config.n_steps 62 | 63 | constants = torch.tensor(ds.array_constants[:], dtype=torch.float).to(restored_model.device) 64 | 65 | out = [] 66 | for i in range(nens): # loop over ensemble members 67 | ts = [] 68 | for b in dl: # loop over batches in test set 69 | input = b 70 | trajectories = torch.zeros(size=(b[1].shape[0], n_steps, *b[1].shape[1:])) 71 | for step in range(n_steps): 72 | restored_model.eval() 73 | with torch.no_grad(): 74 | res = restored_model.forward(input) # is this a list of tensors or a tensor? 75 | trajectories[:,step,...] 
= res 76 | input = [torch.concatenate([res, constants.unsqueeze(0).expand(res.size(0), *constants.size())], dim=1), None] # we don't need the true target here 77 | ts.append(trajectories) 78 | out.append(torch.cat(ts, dim=0)) 79 | out = torch.stack(out, dim=0) 80 | 81 | 82 | # get the targets: 83 | targets = torch.tensor(ds.data.targets.data[ds.start+ds.lead_time:ds.stop+ds.lead_time]) 84 | l = len(targets) 85 | 86 | indices = torch.stack([torch.arange(i, i+n_steps) for i in range(l-n_steps+1)], dim=0) 87 | targets = targets[indices,:,:,:].unsqueeze(dim=0) 88 | # fill targets with infty values until we reach the shape we need 89 | targets = torch.cat([targets, torch.ones((targets.shape[0], n_steps-1, *targets.shape[2:]))*torch.inf], dim=1) 90 | 91 | print(out.shape, targets.shape) 92 | 93 | model_output_dir = os.path.join(model_output_dir, config.data.template, experiment_name, model_name, dir_name) 94 | create_dir(model_output_dir) 95 | 96 | gen_xr = create_xr_output_variables( 97 | out, 98 | zarr_path=f"{config.paths.dir_PreprocessedDatasets}/{config.data.template}_test.zarr/targets", 99 | config=ds_config, 100 | min_max_file_path=f"{config.paths.dir_PreprocessedDatasets}/{config.data.template}_output_min_max.nc" 101 | ) 102 | 103 | target_xr = create_xr_output_variables( 104 | targets, 105 | zarr_path=f"{config.paths.dir_PreprocessedDatasets}/{config.data.template}_test.zarr/targets", 106 | config=ds_config, 107 | min_max_file_path=f"{config.paths.dir_PreprocessedDatasets}/{config.data.template}_output_min_max.nc" 108 | ) 109 | 110 | gen_dir = os.path.join(model_output_dir, "gen.nc") 111 | gen_xr.to_netcdf(gen_dir) 112 | print(f"Generated data written at: {gen_dir}") 113 | 114 | target_dir = os.path.join(model_output_dir, "target.nc") 115 | target_xr.to_netcdf(target_dir) 116 | print(f"Target data written at: {target_dir}") 117 | 118 | if __name__ == '__main__': 119 | main() 120 | 121 | -------------------------------------------------------------------------------- /s14_very_long_iterative_run.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | import hydra 5 | from omegaconf import DictConfig, OmegaConf 6 | 7 | import torch 8 | from torch.utils.data import DataLoader 9 | 10 | from dm_zoo.dff.PixelDiffusion import ( 11 | PixelDiffusionConditional, 12 | ) 13 | from WD.datasets import Conditional_Dataset_Zarr_Iterable 14 | from WD.utils import create_dir 15 | from WD.io import create_xr_output_variables 16 | import lightning as L 17 | 18 | 19 | @hydra.main(version_base=None, config_path="./config", config_name="inference") 20 | def main(config: DictConfig) -> None: 21 | hydra_cfg = hydra.core.hydra_config.HydraConfig.get() 22 | dir_name = hydra_cfg['runtime']['output_dir'] # the directory the hydra log is written to. 23 | dir_name = os.path.basename(os.path.normpath(dir_name)) # we only need the last part 24 | 25 | model_name = config.model_name # we have to pass this to the bash file every time! (should contain a string of the date the run was started). 26 | nens = config.n_ensemble_members # we have to pass this to the bash file every time! 
27 | experiment_name = hydra_cfg['runtime']['choices']['experiment'] 28 | 29 | ds_config = OmegaConf.load(f"{config.paths.dir_HydraConfigs}/data/{config.data.template}/.hydra/config.yaml") 30 | ml_config = OmegaConf.load(f"{config.paths.dir_HydraConfigs}/training/{config.data.template}/{experiment_name}/{config.model_name}/.hydra/config.yaml") 31 | 32 | model_output_dir = config.paths.dir_ModelOutput 33 | 34 | model_load_dir = Path(f"{config.paths.dir_SavedModels}/{config.data.template}/{experiment_name}/{config.model_name}/lightning_logs/version_0/checkpoints/") 35 | 36 | test_ds_path = f"{config.paths.dir_PreprocessedDatasets}{config.data.template}_test.zarr" 37 | 38 | 39 | assert config.shuffle_in_chunks is False, "no shuffling allowed for iterative predictions" 40 | assert config.shuffle_chunks is False, "no shuffling allowed for iterative predictions" 41 | 42 | ds = Conditional_Dataset_Zarr_Iterable(test_ds_path, ds_config.template, shuffle_chunks=config.shuffle_chunks, 43 | shuffle_in_chunks=config.shuffle_in_chunks) 44 | 45 | model_ckpt = [x for x in model_load_dir.iterdir()][0] 46 | 47 | conditioning_channels = ds.array_inputs.shape[1] * len(ds.conditioning_timesteps) + ds.array_constants.shape[0] 48 | generated_channels = ds.array_targets.shape[1] 49 | 50 | restored_model = PixelDiffusionConditional.load_from_checkpoint( 51 | model_ckpt, 52 | config=ml_config.experiment.pixel_diffusion, 53 | conditioning_channels=conditioning_channels, 54 | generated_channels=generated_channels, 55 | loss_fn=config.loss_fn, 56 | sampler=config.sampler, 57 | ) 58 | 59 | dl = DataLoader(ds, batch_size=1) 60 | 61 | n_steps = len(ds.data.targets.data[ds.start+ds.lead_time:ds.stop+ds.lead_time]) 62 | 63 | constants = torch.tensor(ds.array_constants[:], dtype=torch.float).to(restored_model.device) 64 | 65 | out = [] 66 | input = next(iter(dl)) 67 | for i in range(nens): # loop over ensemble members 68 | trajectories = torch.zeros(size=(input[1].shape[0], n_steps, *input[1].shape[1:])) 69 | ts = [] 70 | for step in range(n_steps): 71 | print(step) 72 | restored_model.eval() 73 | with torch.no_grad(): 74 | res = restored_model.forward(input) # is this a list of tensors or a tensor? 75 | trajectories[:,step,...] 
= res 76 | input = [torch.concatenate([res, constants.unsqueeze(0).expand(res.size(0), *constants.size())], dim=1), None] # we don't need the true target here 77 | ts.append(trajectories) 78 | out.append(torch.cat(ts, dim=0)) 79 | out = torch.stack(out, dim=0) 80 | 81 | print(out.shape) 82 | 83 | model_output_dir = os.path.join(model_output_dir, config.data.template, experiment_name, model_name, dir_name) 84 | create_dir(model_output_dir) 85 | 86 | gen_xr = create_xr_output_variables( 87 | out, 88 | zarr_path=f"{config.paths.dir_PreprocessedDatasets}/{config.data.template}_test.zarr/targets", 89 | config=ds_config, 90 | min_max_file_path=f"{config.paths.dir_PreprocessedDatasets}/{config.data.template}_output_min_max.nc" 91 | ) 92 | 93 | gen_dir = os.path.join(model_output_dir, "gen.nc") 94 | gen_xr.to_netcdf(gen_dir) 95 | print(f"Generated data written at: {gen_dir}") 96 | 97 | if __name__ == '__main__': 98 | main() 99 | 100 | -------------------------------------------------------------------------------- /s1_write_dataset.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | from omegaconf import DictConfig 3 | 4 | from WD.datasets import write_conditional_datasets 5 | 6 | @hydra.main(version_base=None, config_path="./config", config_name="data") 7 | def main(conf: DictConfig) -> None: 8 | hydra_cfg = hydra.core.hydra_config.HydraConfig.get() 9 | template_name = hydra_cfg['runtime']['choices']['template'] 10 | write_conditional_datasets(conf, template_name) 11 | 12 | if __name__ == '__main__': 13 | main() 14 | -------------------------------------------------------------------------------- /s2_train_conditional_pixel_diffusion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import hydra 4 | from omegaconf import DictConfig, OmegaConf 5 | 6 | import lightning as L 7 | from lightning.pytorch import loggers as pl_loggers 8 | 9 | from dm_zoo.dff.EMA import EMA 10 | from dm_zoo.dff.PixelDiffusion import ( 11 | PixelDiffusionConditional, 12 | ) 13 | 14 | from WD.datasets import Conditional_Dataset_Zarr_Iterable 15 | import torch 16 | 17 | import os 18 | 19 | 20 | from WD.utils import check_devices, create_dir, AreaWeightedMSELoss 21 | from lightning.pytorch.callbacks import LearningRateMonitor 22 | from lightning.pytorch.callbacks import ( 23 | EarlyStopping, 24 | ) 25 | 26 | 27 | @hydra.main(version_base=None, config_path="./config", config_name="train") 28 | def main(config: DictConfig) -> None: 29 | hydra_cfg = hydra.core.hydra_config.HydraConfig.get() 30 | dir_name = hydra_cfg['runtime']['output_dir'] # the directory the hydra log is written to. 
31 | dir_name = os.path.basename(os.path.normpath(dir_name)) # we only need the last part 32 | exp_name = hydra_cfg['runtime']['choices']['experiment'] 33 | 34 | print(f"The torch version being used is {torch.__version__}") 35 | check_devices() 36 | 37 | # load config 38 | print(f"Loading dataset {config.experiment.data.template}") 39 | # ds_config_path = os.path.join(conf.base_path, f"{conf.template}.yml") 40 | # ds_config = load_config(ds_config_path) 41 | ds_config = OmegaConf.load(f"{config.paths.dir_HydraConfigs}/data/{config.experiment.data.template}/.hydra/config.yaml") 42 | 43 | # set up datasets: 44 | 45 | train_ds_path = config.paths.dir_PreprocessedDatasets + f"{config.experiment.data.template}_train.zarr" 46 | train_ds = Conditional_Dataset_Zarr_Iterable(train_ds_path, ds_config.template, shuffle_chunks=config.experiment.data.train_shuffle_chunks, 47 | shuffle_in_chunks=config.experiment.data.train_shuffle_in_chunks) 48 | 49 | val_ds_path = config.paths.dir_PreprocessedDatasets + f"{config.experiment.data.template}_val.zarr" 50 | val_ds = Conditional_Dataset_Zarr_Iterable(val_ds_path, ds_config.template, shuffle_chunks=config.experiment.data.val_shuffle_chunks, shuffle_in_chunks=config.experiment.data.val_shuffle_in_chunks) 51 | 52 | # select loss_fn: 53 | if config.experiment.setup.loss_fn_name == "MSE_Loss": 54 | loss_fn = torch.nn.functional.mse_loss 55 | elif config.experiment.setup.loss_fn_name == "AreaWeighted_MSE_Loss": 56 | lat_grid = train_ds.data.targets.lat[:] 57 | lon_grid = train_ds.data.targets.lon[:] 58 | loss_fn = AreaWeightedMSELoss(lat_grid, lon_grid).loss_fn 59 | else: 60 | raise NotImplementedError("Invalid loss function.") 61 | 62 | if config.experiment.setup.sampler_name == "DDPM": # this is the default case 63 | sampler = None 64 | else: 65 | raise NotImplementedError("This sampler has not been implemented.") 66 | 67 | # create unique model id and create directory to save model in: 68 | model_dir = f"{config.paths.dir_SavedModels}/{config.experiment.data.template}/{exp_name}/{dir_name}/" 69 | create_dir(model_dir) 70 | 71 | # set up logger: 72 | tb_logger = pl_loggers.TensorBoardLogger(save_dir=model_dir) 73 | 74 | # set up diffusion model: 75 | 76 | conditioning_channels = train_ds.array_inputs.shape[1] * len(train_ds.conditioning_timesteps) + train_ds.array_constants.shape[0] 77 | generated_channels = train_ds.array_targets.shape[1] 78 | print("generated channels: {} conditioning channels: {}".format(generated_channels, conditioning_channels)) 79 | 80 | model = PixelDiffusionConditional( 81 | config.experiment.pixel_diffusion, 82 | generated_channels=generated_channels, 83 | conditioning_channels=conditioning_channels, 84 | loss_fn=loss_fn, 85 | sampler=sampler, 86 | train_dataset=train_ds, 87 | valid_dataset=val_ds 88 | ) 89 | 90 | lr_monitor = LearningRateMonitor(logging_interval="step") 91 | 92 | early_stopping = EarlyStopping( 93 | monitor="val_loss_new", mode="min", patience=config.experiment.training.patience 94 | ) 95 | 96 | trainer = L.Trainer( 97 | max_steps=config.experiment.training.max_steps, 98 | limit_val_batches=config.experiment.training.limit_val_batches, 99 | accelerator=config.experiment.training.accelerator, 100 | devices=config.experiment.training.devices, 101 | callbacks=[EMA(config.experiment.training.ema_decay), lr_monitor, early_stopping], 102 | logger=tb_logger 103 | ) 104 | 105 | trainer.fit(model) 106 | 107 | if __name__ == '__main__': 108 | main() 
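Note on how the data templates and the training scripts above fit together (an illustrative sketch, not a file from the repository): the forecast horizon of a template is lead_time * delta_t hours, and the scripts derive the conditioning channel count from the dataset arrays, which amounts to the number of fields per conditioning time step times the number of conditioning time steps, plus the number of constant fields. The short Python sketch below redoes that arithmetic with values hard-coded from config/template/rasp_thuerey_z_500_3day.yaml; all variable names in it are hypothetical and chosen only for illustration.

# Illustrative sketch only -- not a repository file. It reproduces the
# lead-time and channel arithmetic implied by the templates above and by the
# conditioning-channel computation in the training/inference scripts.
delta_t = 6                            # hours between consecutive states
lead_time = 12                         # steps ahead: 12 * 6 h = 72 h = 3 days
conditioning_time_steps = [0, -1, -2]  # three past states used as conditioning

n_levels = 7     # pressure levels [50, 250, 500, 600, 700, 850, 925]
n_3d_vars = 5    # temperature, geopotential, u/v wind components, specific humidity
n_2d_vars = 3    # 2m_temperature, total_precipitation, toa_incident_solar_radiation
n_constants = 3  # orography, lat2d, lsm

fields_per_step = n_3d_vars * n_levels + n_2d_vars                 # 38
conditioning_channels = (fields_per_step * len(conditioning_time_steps)
                         + n_constants)                            # 117
forecast_horizon_hours = lead_time * delta_t                       # 72

print(f"forecast horizon: {forecast_horizon_hours} h")
print(f"conditioning channels: {conditioning_channels}")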
-------------------------------------------------------------------------------- /s3_write_predictions_conditional_pixel_diffusion.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from pathlib import Path 4 | 5 | import hydra 6 | from omegaconf import DictConfig, OmegaConf 7 | 8 | import torch 9 | from torch.utils.data import DataLoader 10 | 11 | from dm_zoo.dff.PixelDiffusion import ( 12 | PixelDiffusionConditional, 13 | ) 14 | from WD.datasets import Conditional_Dataset_Zarr_Iterable 15 | from WD.utils import create_dir 16 | from WD.io import create_xr_output_variables 17 | import lightning as L 18 | 19 | 20 | 21 | @hydra.main(version_base=None, config_path="./config", config_name="inference") 22 | def main(config: DictConfig) -> None: 23 | hydra_cfg = hydra.core.hydra_config.HydraConfig.get() 24 | dir_name = hydra_cfg['runtime']['output_dir'] # the directory the hydra log is written to. 25 | dir_name = os.path.basename(os.path.normpath(dir_name)) # we only need the last part 26 | 27 | experiment_name = hydra_cfg['runtime']['choices']['experiment'] 28 | model_name = config.model_name # we have to pass this to the bash file every time! (should contain a string). 29 | nens = config.n_ensemble_members # we have to pass this to the bash file every time! 30 | 31 | ds_config = OmegaConf.load(f"{config.paths.dir_HydraConfigs}/data/{config.data.template}/.hydra/config.yaml") 32 | ml_config = OmegaConf.load(f"{config.paths.dir_HydraConfigs}/training/{config.data.template}/{experiment_name}/{config.model_name}/.hydra/config.yaml") 33 | 34 | model_output_dir = config.paths.dir_ModelOutput 35 | 36 | model_load_dir = Path(f"{config.paths.dir_SavedModels}/{config.data.template}/{experiment_name}/{config.model_name}/lightning_logs/version_0/checkpoints/") 37 | 38 | test_ds_path = f"{config.paths.dir_PreprocessedDatasets}{config.data.template}_test.zarr" 39 | 40 | ds = Conditional_Dataset_Zarr_Iterable(test_ds_path, ds_config.template, shuffle_chunks=config.shuffle_chunks, 41 | shuffle_in_chunks=config.shuffle_in_chunks) 42 | 43 | model_ckpt = [x for x in model_load_dir.iterdir()][0] 44 | 45 | conditioning_channels = ds.array_inputs.shape[1] * len(ds.conditioning_timesteps) + ds.array_constants.shape[0] 46 | generated_channels = ds.array_targets.shape[1] 47 | 48 | restored_model = PixelDiffusionConditional.load_from_checkpoint( 49 | model_ckpt, 50 | config=ml_config.experiment.pixel_diffusion, 51 | conditioning_channels=conditioning_channels, 52 | generated_channels=generated_channels, 53 | loss_fn=config.loss_fn, 54 | sampler=config.sampler, 55 | ) 56 | 57 | dl = DataLoader(ds, batch_size=config.batchsize) 58 | trainer = L.Trainer() 59 | 60 | out = [] 61 | for i in range(nens): 62 | out.extend(trainer.predict(restored_model, dl)) 63 | 64 | out = torch.cat(out, dim=0) 65 | out = out.view(nens, -1, *out.shape[1:]) 66 | 67 | model_output_dir = os.path.join(model_output_dir, config.data.template, experiment_name, model_name, dir_name) 68 | create_dir(model_output_dir) 69 | 70 | # need the view to create axis for 71 | # different ensemble members (although only 1 here). 
72 | 73 | targets = torch.tensor(ds.data.targets.data[ds.start+ds.lead_time:ds.stop+ds.lead_time], dtype=torch.float).unsqueeze(dim=0) 74 | 75 | gen_xr = create_xr_output_variables( 76 | out, 77 | zarr_path=f"{config.paths.dir_PreprocessedDatasets}/{config.data.template}_test.zarr/targets", 78 | config=ds_config, 79 | min_max_file_path=f"{config.paths.dir_PreprocessedDatasets}/{config.data.template}_output_min_max.nc" 80 | ) 81 | 82 | target_xr = create_xr_output_variables( 83 | targets, 84 | zarr_path=f"{config.paths.dir_PreprocessedDatasets}/{config.data.template}_test.zarr/targets", 85 | config=ds_config, 86 | min_max_file_path=f"{config.paths.dir_PreprocessedDatasets}/{config.data.template}_output_min_max.nc" 87 | ) 88 | 89 | gen_dir = os.path.join(model_output_dir, "gen.nc") 90 | gen_xr.to_netcdf(gen_dir) 91 | print(f"Generated data written at: {gen_dir}") 92 | 93 | target_dir = os.path.join(model_output_dir, "target.nc") 94 | target_xr.to_netcdf(target_dir) 95 | print(f"Target data written at: {target_dir}") 96 | 97 | 98 | if __name__ == '__main__': 99 | main() 100 | 101 | -------------------------------------------------------------------------------- /s4_train_val_test.py: -------------------------------------------------------------------------------- 1 | # inbuilt packages 2 | import argparse 3 | from pathlib import Path 4 | from time import time 5 | 6 | # Standard packages 7 | import numpy as np 8 | import torch 9 | import lightning as L 10 | import torch.nn.functional as F 11 | from torch.utils.data import DataLoader 12 | 13 | # Custom packages 14 | from dm_zoo.dff.PixelDiffusion import ( 15 | PixelDiffusionConditional, 16 | ) 17 | from WD.datasets import Conditional_Dataset, custom_collate 18 | from WD.utils import create_dir 19 | from WD.io import load_config, create_xr_output_variables 20 | 21 | parser = argparse.ArgumentParser( 22 | prog="Evaluate Model", 23 | description="Evaluate Model based on dataset id and model id", 24 | epilog="Arg parser for vanilla conditional diffusion model", 25 | ) 26 | 27 | parser.add_argument( 28 | "-did", 29 | "--dataset_id", 30 | type=str, 31 | help="path under which the selected config file is stored.", 32 | ) 33 | 34 | 35 | parser.add_argument( 36 | "-mid", 37 | "--model_id", 38 | type=str, 39 | help="path under which the selected config file is stored.", 40 | ) 41 | 42 | parser.add_argument( 43 | "-nens", 44 | "--n_ensemble_members", 45 | type=int, 46 | help="the number of ensemble members to be produced.", 47 | ) 48 | 49 | args = parser.parse_args() 50 | 51 | ds_id = args.dataset_id 52 | run_id = args.model_id 53 | nens = args.n_ensemble_members 54 | 55 | 56 | B = 1024 57 | num_copies = nens 58 | 59 | start_time = time() 60 | 61 | 62 | def write_dataset( 63 | restored_model, ds, B, num_copies, model_config, epoch, model_output_dir 64 | ): 65 | dl = DataLoader( 66 | ds, 67 | batch_size=B, 68 | shuffle=False, 69 | collate_fn=lambda x: custom_collate(x, num_copies=num_copies), 70 | ) 71 | trainer = L.Trainer() 72 | out = trainer.predict(restored_model, dl) 73 | 74 | out = torch.cat(out, dim=0) 75 | out = out.view(-1, num_copies, *out.shape[1:]).transpose(0, 1) 76 | 77 | model_output_dir = model_output_dir / model_config.ds_id / str(epoch) 78 | create_dir(model_output_dir) 79 | 80 | targets = ds[:][1].view(1, *ds[:][1].shape) 81 | # need the view to create axis for different 82 | # ensemble members (although only 1 here).
83 | dates = ds[:][2] 84 | 85 | gen_xr = create_xr_output_variables( 86 | out, 87 | dates=dates, 88 | config_file_path=( 89 | "/data/compoundx/WeatherDiff/config_file/{}.yml".format(ds_id) 90 | ), 91 | min_max_file_path=( 92 | "/data/compoundx/WeatherDiff/model_input/{}_output_min_max.nc" 93 | .format(ds_id) 94 | ), 95 | ) 96 | 97 | target_xr = create_xr_output_variables( 98 | targets, 99 | dates=dates, 100 | config_file_path=( 101 | "/data/compoundx/WeatherDiff/config_file/{}.yml".format(ds_id) 102 | ), 103 | min_max_file_path=( 104 | "/data/compoundx/WeatherDiff/model_input/{}_output_min_max.nc" 105 | .format(ds_id) 106 | ), 107 | ) 108 | 109 | return F.mse_loss( 110 | torch.tensor(gen_xr["z_500"].values), 111 | torch.tensor(target_xr["z_500"].values), 112 | ) 113 | 114 | 115 | model_config_path = "/data/compoundx/WeatherDiff/config_file/{}_{}.yml".format( 116 | ds_id, run_id 117 | ) 118 | model_output_dir = Path("/data/compoundx/WeatherDiff/model_output/") 119 | 120 | print(model_config_path) 121 | model_config = load_config(model_config_path) 122 | 123 | model_load_dir = ( 124 | Path(model_config.file_structure.dir_saved_model) 125 | / "lightning_logs/version_0/checkpoints/" 126 | ) 127 | 128 | mse_loss_test = [] 129 | mse_loss_val = [] 130 | epoch_list = [] 131 | for epoch in range(0, 120, 5): 132 | model_ckpt = [ 133 | x for x in model_load_dir.iterdir() if f"epoch={epoch}-" in str(x) 134 | ][0] 135 | 136 | restored_model = PixelDiffusionConditional.load_from_checkpoint( 137 | model_ckpt, 138 | generated_channels=model_config.model_hparam["generated_channels"], 139 | condition_channels=model_config.model_hparam["condition_channels"], 140 | cylindrical_padding=True, 141 | ) 142 | print(epoch) 143 | ds_test = Conditional_Dataset( 144 | "/data/compoundx/WeatherDiff/model_input/{}_test.pt".format(ds_id), 145 | "/data/compoundx/WeatherDiff/config_file/{}.yml".format(ds_id), 146 | ) 147 | 148 | loss = write_dataset( 149 | restored_model, 150 | ds_test, 151 | B, 152 | num_copies, 153 | model_config, 154 | epoch, 155 | model_output_dir, 156 | ) 157 | mse_loss_test.append(loss) 158 | print(loss) 159 | ds_val = Conditional_Dataset( 160 | "/data/compoundx/WeatherDiff/model_input/{}_val.pt".format(ds_id), 161 | "/data/compoundx/WeatherDiff/config_file/{}.yml".format(ds_id), 162 | ) 163 | 164 | loss = write_dataset( 165 | restored_model, 166 | ds_val, 167 | B, 168 | num_copies, 169 | model_config, 170 | epoch, 171 | model_output_dir, 172 | ) 173 | mse_loss_val.append(loss) 174 | epoch_list.append(epoch) 175 | print(loss) 176 | 177 | mse_loss = np.vstack([epoch_list, mse_loss_test, mse_loss_val]) 178 | np.savetxt("test_val_loss.txt", np.array(mse_loss)) 179 | 180 | print(f"Total time taken is {np.round(time()-start_time, 2)} seconds") 181 | 182 | # ds_train = Conditional_Dataset( 183 | # "/data/compoundx/WeatherDiff/model_input/{}_train.pt".format(ds_id), 184 | # "/data/compoundx/WeatherDiff/config_file/{}.yml".format(ds_id), 185 | # ) 186 | 187 | 188 | # ds_val = Conditional_Dataset( 189 | # "/data/compoundx/WeatherDiff/model_input/{}_val.pt".format(ds_id), 190 | # "/data/compoundx/WeatherDiff/config_file/{}.yml".format(ds_id), 191 | # ) 192 | 193 | 194 | # model_config.file_structure.dir_model_output = str(model_output_dir) 195 | 196 | 197 | # write_config(model_config) 198 | # Write config is possible deletes and rewrites 199 | -------------------------------------------------------------------------------- /s5_train_FourCastNet.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import hydra 4 | from omegaconf import DictConfig, OmegaConf 5 | 6 | import lightning as L 7 | from lightning.pytorch import loggers as pl_loggers 8 | 9 | from dm_zoo.dff.EMA import EMA 10 | 11 | 12 | from WD.datasets import Conditional_Dataset_Zarr_Iterable 13 | import torch 14 | 15 | import os 16 | 17 | from dm_zoo.fourcast.train_FourCastNet import FourCastNetDirect 18 | 19 | from WD.utils import check_devices, create_dir, AreaWeightedMSELoss 20 | 21 | from lightning.pytorch.callbacks import LearningRateMonitor 22 | from lightning.pytorch.callbacks import ( 23 | EarlyStopping, 24 | ) 25 | 26 | 27 | @hydra.main(version_base=None, config_path="./config", config_name="train") 28 | def main(config: DictConfig) -> None: 29 | hydra_cfg = hydra.core.hydra_config.HydraConfig.get() 30 | dir_name = hydra_cfg['runtime']['output_dir'] # the directory the hydra log is written to. 31 | dir_name = os.path.basename(os.path.normpath(dir_name)) # we only need the last part 32 | exp_name = hydra_cfg['runtime']['choices']['experiment'] 33 | 34 | print(f"The torch version being used is {torch.__version__}") 35 | check_devices() 36 | 37 | # load config 38 | print(f"Loading dataset {config.experiment.data.template}") 39 | # ds_config_path = os.path.join(conf.base_path, f"{conf.template}.yml") 40 | # ds_config = load_config(ds_config_path) 41 | ds_config = OmegaConf.load(f"{config.paths.dir_HydraConfigs}/data/{config.experiment.data.template}/.hydra/config.yaml") 42 | 43 | # set up datasets: 44 | 45 | train_ds_path = config.paths.dir_PreprocessedDatasets + f"{config.experiment.data.template}_train.zarr" 46 | train_ds = Conditional_Dataset_Zarr_Iterable(train_ds_path, ds_config.template, shuffle_chunks=config.experiment.data.train_shuffle_chunks, 47 | shuffle_in_chunks=config.experiment.data.train_shuffle_in_chunks) 48 | 49 | val_ds_path = config.paths.dir_PreprocessedDatasets + f"{config.experiment.data.template}_val.zarr" 50 | val_ds = Conditional_Dataset_Zarr_Iterable(val_ds_path, ds_config.template, shuffle_chunks=config.experiment.data.val_shuffle_chunks, shuffle_in_chunks=config.experiment.data.val_shuffle_in_chunks) 51 | 52 | # select loss_fn: 53 | if config.experiment.setup.loss_fn_name == "MSE_Loss": 54 | loss_fn = torch.nn.functional.mse_loss 55 | elif config.experiment.setup.loss_fn_name == "AreaWeighted_MSE_Loss": 56 | lat_grid = train_ds.data.targets.lat[:] 57 | lon_grid = train_ds.data.targets.lon[:] 58 | loss_fn = AreaWeightedMSELoss(lat_grid, lon_grid).loss_fn 59 | else: 60 | raise NotImplementedError("Invalid loss function.") 61 | 62 | # create unique model id and create directory to save model in: 63 | model_dir = f"{config.paths.dir_SavedModels}/{config.experiment.data.template}/{exp_name}/{dir_name}/" 64 | create_dir(model_dir) 65 | 66 | # set up logger: 67 | tb_logger = pl_loggers.TensorBoardLogger(save_dir=model_dir) 68 | 69 | # set up FourCastNet model: 70 | 71 | conditioning_channels = train_ds.array_inputs.shape[1] * len(train_ds.conditioning_timesteps) + train_ds.array_constants.shape[0] 72 | generated_channels = train_ds.array_targets.shape[1] 73 | print("generated channels: {} conditioning channels: {}".format(generated_channels, conditioning_channels)) 74 | 75 | image_size = train_ds.array_inputs.shape[-2:] 76 | 77 | model = FourCastNetDirect( 78 | config.experiment.fourcastnet, 79 | img_size=image_size, 80 | out_channels=generated_channels, 81 | in_channels=conditioning_channels,
82 | loss_fn=loss_fn, 83 | train_dataset=train_ds, 84 | valid_dataset=val_ds 85 | ) 86 | 87 | lr_monitor = LearningRateMonitor(logging_interval="step") 88 | 89 | early_stopping = EarlyStopping( 90 | monitor="val_loss", mode="min", patience=config.experiment.training.patience 91 | ) 92 | 93 | trainer = L.Trainer( 94 | max_steps=config.experiment.training.max_steps, 95 | limit_val_batches=config.experiment.training.limit_val_batches, 96 | accelerator=config.experiment.training.accelerator, 97 | devices=config.experiment.training.devices, 98 | callbacks=[EMA(config.experiment.training.ema_decay), lr_monitor, early_stopping], 99 | logger=tb_logger 100 | ) 101 | 102 | trainer.fit(model) 103 | 104 | if __name__ == '__main__': 105 | main() -------------------------------------------------------------------------------- /s6_write_predictions_FourCastNet.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | import torch 5 | from torch.utils.data import DataLoader 6 | import lightning as L 7 | 8 | import hydra 9 | from omegaconf import DictConfig, OmegaConf 10 | 11 | from dm_zoo.fourcast.train_FourCastNet import FourCastNetDirect 12 | 13 | from WD.datasets import Conditional_Dataset_Zarr_Iterable 14 | from WD.utils import create_dir 15 | from WD.io import create_xr_output_variables 16 | # from WD.io import load_config, write_config # noqa F401 17 | 18 | 19 | 20 | @hydra.main(version_base=None, config_path="./config", config_name="inference") 21 | def main(config: DictConfig) -> None: 22 | hydra_cfg = hydra.core.hydra_config.HydraConfig.get() 23 | dir_name = hydra_cfg['runtime']['output_dir'] # the directory the hydra log is written to. 24 | dir_name = os.path.basename(os.path.normpath(dir_name)) # we only need the last part 25 | 26 | model_name = config.model_name # we have to pass this to the bash file every time! (should contain a string).
27 | experiment_name = hydra_cfg['runtime']['choices']['experiment'] 28 | 29 | ds_config = OmegaConf.load(f"{config.paths.dir_HydraConfigs}/data/{config.data.template}/.hydra/config.yaml") 30 | ml_config = OmegaConf.load(f"{config.paths.dir_HydraConfigs}/training/{config.data.template}/{experiment_name}/{config.model_name}/.hydra/config.yaml") 31 | 32 | model_output_dir = config.paths.dir_ModelOutput 33 | 34 | model_load_dir = Path(f"{config.paths.dir_SavedModels}/{config.data.template}/{experiment_name}/{config.model_name}/lightning_logs/version_0/checkpoints/") 35 | 36 | test_ds_path = f"{config.paths.dir_PreprocessedDatasets}{config.data.template}_test.zarr" 37 | 38 | ds = Conditional_Dataset_Zarr_Iterable(test_ds_path, ds_config.template, shuffle_chunks=config.shuffle_chunks, 39 | shuffle_in_chunks=config.shuffle_in_chunks) 40 | 41 | model_ckpt = [x for x in model_load_dir.iterdir()][0] 42 | 43 | conditioning_channels = ds.array_inputs.shape[1] * len(ds.conditioning_timesteps) + ds.array_constants.shape[0] 44 | generated_channels = ds.array_targets.shape[1] 45 | img_size = ds.array_targets.shape[-2:] 46 | 47 | restored_model = FourCastNetDirect.load_from_checkpoint( 48 | model_ckpt, 49 | config=ml_config.experiment.fourcastnet, 50 | img_size=img_size, 51 | in_channels=conditioning_channels, 52 | out_channels=generated_channels, 53 | loss_fn=config.loss_fn 54 | ) 55 | 56 | dl = DataLoader(ds, batch_size=config.batchsize) 57 | trainer = L.Trainer() 58 | 59 | out = trainer.predict(restored_model, dl) 60 | out = torch.cat(out, dim=0).unsqueeze(dim=0) # to keep compatible with the version that uses ensemble members 61 | print(out.shape) 62 | 63 | model_output_dir = os.path.join(model_output_dir, config.data.template, experiment_name, model_name, dir_name) 64 | create_dir(model_output_dir) 65 | 66 | # need the view to create axis for 67 | # different ensemble members (although only 1 here). 
68 | 69 | targets = torch.tensor(ds.data.targets.data[ds.start+ds.lead_time:ds.stop+ds.lead_time], dtype=torch.float).unsqueeze(dim=0) 70 | 71 | print(targets.shape) 72 | gen_xr = create_xr_output_variables( 73 | out, 74 | zarr_path=f"{config.paths.dir_PreprocessedDatasets}/{config.data.template}_test.zarr/targets", 75 | config=ds_config, 76 | min_max_file_path=f"{config.paths.dir_PreprocessedDatasets}/{config.data.template}_output_min_max.nc" 77 | ) 78 | 79 | target_xr = create_xr_output_variables( 80 | targets, 81 | zarr_path=f"{config.paths.dir_PreprocessedDatasets}/{config.data.template}_test.zarr/targets", 82 | config=ds_config, 83 | min_max_file_path=f"{config.paths.dir_PreprocessedDatasets}/{config.data.template}_output_min_max.nc" 84 | ) 85 | 86 | gen_dir = os.path.join(model_output_dir, "gen.nc") 87 | gen_xr.to_netcdf(gen_dir) 88 | print(f"Generated data written at: {gen_dir}") 89 | 90 | target_dir = os.path.join(model_output_dir, "target.nc") 91 | target_xr.to_netcdf(target_dir) 92 | print(f"Target data written at: {target_dir}") 93 | 94 | 95 | if __name__ == '__main__': 96 | main() 97 | 98 | -------------------------------------------------------------------------------- /s7_train_unet.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import lightning as L 3 | from lightning.pytorch import loggers as pl_loggers 4 | 5 | import hydra 6 | from omegaconf import DictConfig, OmegaConf 7 | 8 | import os 9 | 10 | from dm_zoo.dff.EMA import EMA 11 | from dm_zoo.dff.UNetRegression import ( 12 | UNetRegression, 13 | ) 14 | from WD.datasets import Conditional_Dataset_Zarr_Iterable, Conditional_Dataset 15 | import torch 16 | from WD.utils import check_devices, create_dir, generate_uid, AreaWeightedMSELoss 17 | from WD.io import write_config, load_config 18 | from lightning.pytorch.callbacks import LearningRateMonitor 19 | from lightning.pytorch.callbacks import ( 20 | EarlyStopping, 21 | ) 22 | 23 | 24 | @hydra.main(version_base=None, config_path="./config", config_name="train") 25 | def main(config: DictConfig) -> None: 26 | hydra_cfg = hydra.core.hydra_config.HydraConfig.get() 27 | dir_name = hydra_cfg['runtime']['output_dir'] # the directory the hydra log is written to. 
28 | dir_name = os.path.basename(os.path.normpath(dir_name)) # we only need the last part 29 | exp_name = hydra_cfg['runtime']['choices']['experiment'] 30 | 31 | print(f"The torch version being used is {torch.__version__}") 32 | check_devices() 33 | 34 | # load config 35 | print(f"Loading dataset {config.experiment.data.template}") 36 | ds_config = OmegaConf.load(f"{config.paths.dir_HydraConfigs}/data/{config.experiment.data.template}/.hydra/config.yaml") 37 | 38 | # set up datasets: 39 | 40 | train_ds_path = config.paths.dir_PreprocessedDatasets + f"{config.experiment.data.template}_train.zarr" 41 | train_ds = Conditional_Dataset_Zarr_Iterable(train_ds_path, ds_config.template, shuffle_chunks=config.experiment.data.train_shuffle_chunks, 42 | shuffle_in_chunks=config.experiment.data.train_shuffle_in_chunks) 43 | 44 | val_ds_path = config.paths.dir_PreprocessedDatasets + f"{config.experiment.data.template}_val.zarr" 45 | val_ds = Conditional_Dataset_Zarr_Iterable(val_ds_path, ds_config.template, shuffle_chunks=config.experiment.data.val_shuffle_chunks, shuffle_in_chunks=config.experiment.data.val_shuffle_in_chunks) 46 | 47 | # select loss_fn: 48 | if config.experiment.setup.loss_fn_name == "MSE_Loss": 49 | loss_fn = torch.nn.functional.mse_loss 50 | elif config.experiment.setup.loss_fn_name == "AreaWeighted_MSE_Loss": 51 | lat_grid = train_ds.data.targets.lat[:] 52 | lon_grid = train_ds.data.targets.lon[:] 53 | loss_fn = AreaWeightedMSELoss(lat_grid, lon_grid).loss_fn 54 | else: 55 | raise NotImplementedError("Invalid loss function.") 56 | 57 | # create unique model id and create directory to save model in: 58 | model_dir = f"{config.paths.dir_SavedModels}/{config.experiment.data.template}/{exp_name}/{dir_name}/" 59 | create_dir(model_dir) 60 | 61 | # set up logger: 62 | tb_logger = pl_loggers.TensorBoardLogger(save_dir=model_dir) 63 | 64 | # set up diffusion model: 65 | conditioning_channels = train_ds.array_inputs.shape[1] * len(train_ds.conditioning_timesteps) + train_ds.array_constants.shape[0] 66 | generated_channels = train_ds.array_targets.shape[1] 67 | print("generated channels: {} conditioning channels: {}".format(generated_channels, conditioning_channels)) 68 | 69 | model = UNetRegression( 70 | config.experiment.unet_regression, 71 | train_dataset=train_ds, 72 | valid_dataset=val_ds, 73 | generated_channels=generated_channels, 74 | condition_channels=conditioning_channels, 75 | loss_fn=loss_fn, 76 | ) 77 | 78 | lr_monitor = LearningRateMonitor(logging_interval="step") 79 | 80 | early_stopping = EarlyStopping( 81 | monitor="val_loss", mode="min", patience=config.experiment.training.patience 82 | ) 83 | 84 | trainer = L.Trainer( 85 | max_steps=config.experiment.training.max_steps, 86 | limit_val_batches=config.experiment.training.limit_val_batches, 87 | accelerator=config.experiment.training.accelerator, 88 | devices=config.experiment.training.devices, 89 | callbacks=[EMA(config.experiment.training.ema_decay), lr_monitor, early_stopping], 90 | logger=tb_logger 91 | ) 92 | 93 | trainer.fit(model) 94 | 95 | 96 | if __name__ == '__main__': 97 | main() -------------------------------------------------------------------------------- /s8_write_predictions_unet.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | import hydra 5 | from omegaconf import DictConfig, OmegaConf 6 | 7 | import torch 8 | from torch.utils.data import DataLoader 9 | 10 | from dm_zoo.dff.UNetRegression import ( 11 | UNetRegression, 12 
| ) 13 | from WD.datasets import Conditional_Dataset_Zarr_Iterable 14 | from WD.utils import create_dir 15 | from WD.io import create_xr_output_variables 16 | import lightning as L 17 | 18 | 19 | @hydra.main(version_base=None, config_path="./config", config_name="inference") 20 | def main(config: DictConfig) -> None: 21 | hydra_cfg = hydra.core.hydra_config.HydraConfig.get() 22 | dir_name = hydra_cfg['runtime']['output_dir'] # the directory the hydra log is written to. 23 | dir_name = os.path.basename(os.path.normpath(dir_name)) # we only need the last part 24 | 25 | experiment_name = hydra_cfg['runtime']['choices']['experiment'] 26 | model_name = config.model_name # we have to pass this to the bash file every time! (should contain a string). 27 | 28 | ds_config = OmegaConf.load(f"{config.paths.dir_HydraConfigs}/data/{config.data.template}/.hydra/config.yaml") 29 | ml_config = OmegaConf.load(f"{config.paths.dir_HydraConfigs}/training/{config.data.template}/{experiment_name}/{config.model_name}/.hydra/config.yaml") 30 | 31 | model_output_dir = config.paths.dir_ModelOutput 32 | 33 | model_load_dir = Path(f"{config.paths.dir_SavedModels}/{config.data.template}/{experiment_name}/{config.model_name}/lightning_logs/version_0/checkpoints/") 34 | 35 | test_ds_path = f"{config.paths.dir_PreprocessedDatasets}{config.data.template}_test.zarr" 36 | 37 | ds = Conditional_Dataset_Zarr_Iterable(test_ds_path, ds_config.template, shuffle_chunks=config.shuffle_chunks, 38 | shuffle_in_chunks=config.shuffle_in_chunks) 39 | 40 | model_ckpt = [x for x in model_load_dir.iterdir()][0] 41 | 42 | conditioning_channels = ds.array_inputs.shape[1] * len(ds.conditioning_timesteps) + ds.array_constants.shape[0] 43 | generated_channels = ds.array_targets.shape[1] 44 | 45 | restored_model = UNetRegression.load_from_checkpoint( 46 | model_ckpt, 47 | config=ml_config.experiment.unet_regression, 48 | generated_channels=generated_channels, 49 | condition_channels=conditioning_channels, 50 | loss_fn=config.loss_fn, 51 | ) 52 | 53 | dl = DataLoader(ds, batch_size=config.batchsize) 54 | trainer = L.Trainer() 55 | 56 | out = trainer.predict(restored_model, dl) 57 | 58 | out = torch.cat(out, dim=0) 59 | out = out.view(1, *out.shape) 60 | print(out.shape) 61 | 62 | model_output_dir = os.path.join(model_output_dir, config.data.template, experiment_name, model_name, dir_name) 63 | create_dir(model_output_dir) 64 | 65 | targets = torch.tensor(ds.data.targets.data[ds.start+ds.lead_time:ds.stop+ds.lead_time], dtype=torch.float).unsqueeze(dim=0) 66 | 67 | gen_xr = create_xr_output_variables( 68 | out, 69 | zarr_path=f"{config.paths.dir_PreprocessedDatasets}/{config.data.template}_test.zarr/targets", 70 | config=ds_config, 71 | min_max_file_path=f"{config.paths.dir_PreprocessedDatasets}/{config.data.template}_output_min_max.nc" 72 | ) 73 | 74 | target_xr = create_xr_output_variables( 75 | targets, 76 | zarr_path=f"{config.paths.dir_PreprocessedDatasets}/{config.data.template}_test.zarr/targets", 77 | config=ds_config, 78 | min_max_file_path=f"{config.paths.dir_PreprocessedDatasets}/{config.data.template}_output_min_max.nc" 79 | ) 80 | 81 | gen_dir = os.path.join(model_output_dir, "gen.nc") 82 | gen_xr.to_netcdf(gen_dir) 83 | print(f"Generated data written at: {gen_dir}") 84 | 85 | target_dir = os.path.join(model_output_dir, "target.nc") 86 | target_xr.to_netcdf(target_dir) 87 | print(f"Target data written at: {target_dir}") 88 | 89 | 90 | if __name__ == '__main__': 91 | main() 
-------------------------------------------------------------------------------- /s9_train_vae.py: -------------------------------------------------------------------------------- 1 | import lightning as L 2 | from lightning.pytorch import loggers as pl_loggers 3 | from dm_zoo.dff.EMA import EMA 4 | 5 | import os 6 | import hydra 7 | from omegaconf import DictConfig, OmegaConf 8 | 9 | from WD.io import load_config 10 | from WD.datasets import Conditional_Dataset_Zarr_Iterable 11 | from WD.utils import create_dir, generate_uid, check_devices 12 | from dm_zoo.latent.vae.vae_lightning_module import VAE 13 | import torch 14 | from lightning.pytorch.callbacks import LearningRateMonitor 15 | from lightning.pytorch.callbacks import ( 16 | EarlyStopping, 17 | ) 18 | 19 | @hydra.main(version_base=None, config_path="./config", config_name="train") 20 | def main(config: DictConfig) -> None: 21 | hydra_cfg = hydra.core.hydra_config.HydraConfig.get() 22 | exp_name = hydra_cfg['runtime']['choices']['experiment'] 23 | dir_name = hydra_cfg['runtime']['output_dir'] # the directory the hydra log is written to. 24 | dir_name = os.path.basename(os.path.normpath(dir_name)) # we only need the last part 25 | 26 | ds_config = OmegaConf.load(f"{config.paths.dir_HydraConfigs}/data/{config.experiment.data.template}/.hydra/config.yaml") 27 | 28 | print(f"The torch version being used is {torch.__version__}") 29 | check_devices() 30 | 31 | # load config 32 | print(f"Loading dataset {config.experiment.data.template}") 33 | 34 | train_ds_path = config.paths.dir_PreprocessedDatasets + f"{config.experiment.data.template}_train.zarr" 35 | train_ds = Conditional_Dataset_Zarr_Iterable(train_ds_path, ds_config.template, shuffle_chunks=config.experiment.data.train_shuffle_chunks, 36 | shuffle_in_chunks=config.experiment.data.train_shuffle_in_chunks) 37 | 38 | val_ds_path = config.paths.dir_PreprocessedDatasets + f"{config.experiment.data.template}_val.zarr" 39 | val_ds = Conditional_Dataset_Zarr_Iterable(val_ds_path, ds_config.template, shuffle_chunks=config.experiment.data.val_shuffle_chunks, shuffle_in_chunks=config.experiment.data.val_shuffle_in_chunks) 40 | 41 | if config.experiment.vae.type == "input": 42 | n_channels = train_ds.array_inputs.shape[1] * len(train_ds.conditioning_timesteps) + train_ds.array_constants.shape[0] 43 | else: 44 | n_channels = train_ds.array_targets.shape[1] 45 | in_shape = (n_channels, *train_ds.array_targets.shape[:-2]) 46 | 47 | model = VAE(inp_shape = in_shape, train_dataset=train_ds, valid_dataset=val_ds, 48 | dim=config.experiment.vae.dim, 49 | channel_mult = config.experiment.vae.channel_mult, 50 | batch_size = config.experiment.vae.batch_size, 51 | lr = config.experiment.vae.lr, 52 | lr_scheduler_name=config.experiment.vae.lr_scheduler_name, 53 | num_workers = config.experiment.vae.num_workers, 54 | beta = config.experiment.vae.beta, 55 | data_type = config.experiment.vae.type) 56 | 57 | model_dir = f"{config.paths.dir_SavedModels}/{config.experiment.data.template}/{exp_name}/{dir_name}/" 58 | create_dir(model_dir) 59 | tb_logger = pl_loggers.TensorBoardLogger(save_dir=model_dir) 60 | 61 | lr_monitor = LearningRateMonitor(logging_interval="step") 62 | 63 | trainer = L.Trainer( 64 | max_steps=config.experiment.training.max_steps, 65 | limit_val_batches=config.experiment.training.limit_val_batches, 66 | accelerator=config.experiment.training.accelerator, 67 | devices=config.experiment.training.devices, 68 | callbacks=[EMA(config.experiment.training.ema_decay), lr_monitor], #, 
early_stopping], 69 | logger=tb_logger 70 | ) 71 | 72 | trainer.fit(model) 73 | 74 | 75 | if __name__ == '__main__': 76 | main() -------------------------------------------------------------------------------- /submit_script_10_inference_vae.sh: -------------------------------------------------------------------------------- 1 | python s10_write_predictions_vae.py +data.template=geopotential_500_highres +experiment=vae_geopotential_highres_v1 +model_name=2023-08-23_00-43-26 -------------------------------------------------------------------------------- /submit_script_11_train_LFD.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=LFD 4 | #SBATCH --time=3-10:35:00 5 | #SBATCH -G nvidia-a100:1 6 | #SBATCH --mem-per-cpu=64G 7 | #SBATCH --cpus-per-task=1 8 | # output files 9 | #SBATCH -o /data/compoundx/WeatherDiff/job_log/%x-%u-%j.out 10 | #SBATCH -e /data/compoundx/WeatherDiff/job_log/%x-%u-%j.err 11 | 12 | # read in command line arguments: 13 | helpFunction() 14 | { 15 | echo "" 16 | echo "Usage: $0 -e experiment_name" 17 | echo -e "\t-e The name of the experiment template to be used." 18 | exit 1 # Exit script after printing help 19 | } 20 | 21 | while getopts "e:" opt 22 | do 23 | case "$opt" in 24 | e ) experiment_name="$OPTARG" ;; 25 | ? ) helpFunction ;; # Print helpFunction in case parameter is non-existent 26 | esac 27 | done 28 | 29 | # Print helpFunction in case parameters are empty 30 | if [ -z "$experiment_name" ] 31 | then 32 | echo "Some or all of the parameters are empty."; 33 | helpFunction 34 | fi 35 | # stop reading command line arguments 36 | 37 | module load Anaconda3/2020.07 38 | source $EBROOTANACONDA3/etc/profile.d/conda.sh 39 | 40 | conda activate TORCH311 41 | 42 | python s11_train_LFD.py +experiment=$experiment_name -------------------------------------------------------------------------------- /submit_script_12_inference_LFD.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=EvalLFD 4 | #SBATCH --time=0-06:45:00 5 | #SBATCH -G nvidia-a100:1 6 | #SBATCH --mem-per-cpu=16G 7 | # output files 8 | #SBATCH -o /data/compoundx/WeatherDiff/job_log/%x-%u-%j.out 9 | #SBATCH -e /data/compoundx/WeatherDiff/job_log/%x-%u-%j.err 10 | 11 | # begin reading command line arguments 12 | helpFunction() 13 | { 14 | echo "" 15 | echo "Usage: $0 -t DatasetTemplateName -e ExperimentName -m modelName -n NEnsembleMembers" 16 | echo -e "\t-t The name of the dataset template that should be used." 17 | echo -e "\t-e The name of the experiment conducted on the dataset." 18 | echo -e "\t-m The name of the model the predictions should be created with." 19 | echo -e "\t-n The number of ensemble members to be created." 20 | exit 1 # Exit script after printing help 21 | } 22 | 23 | while getopts "t:e:m:n:" opt 24 | do 25 | case "$opt" in 26 | t ) TemplateName="$OPTARG" ;; 27 | e ) ExperimentName="$OPTARG" ;; 28 | m ) ModelID="$OPTARG" ;; 29 | n ) EnsembleMembers="$OPTARG" ;; 30 | ?
) helpFunction ;; # Print helpFunction in case parameter is non-existent 31 | esac 32 | done 33 | 34 | # Print helpFunction in case parameters are empty 35 | if [ -z "$TemplateName" ] || [ -z "$ExperimentName" ] || [ -z "$ModelID" ] || [ -z "$EnsembleMembers" ] 36 | then 37 | echo "Some or all of the parameters are empty."; 38 | helpFunction 39 | fi 40 | # stop reading command line arguments 41 | 42 | module load Anaconda3/2020.07 43 | source $EBROOTANACONDA3/etc/profile.d/conda.sh 44 | 45 | conda activate TORCH311 46 | 47 | python s12_write_predictions_LFD.py +data.template=$TemplateName +experiment=$ExperimentName +model_name=$ModelID +n_ensemble_members=$EnsembleMembers 48 | -------------------------------------------------------------------------------- /submit_script_13_inference_iterative.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=EvalCondIter 4 | #SBATCH --time=1-07:45:00 5 | #SBATCH -G nvidia-a100:1 6 | #SBATCH --mem-per-cpu=35G 7 | # output files 8 | #SBATCH -o /data/compoundx/WeatherDiff/job_log/%x-%u-%j.out 9 | #SBATCH -e /data/compoundx/WeatherDiff/job_log/%x-%u-%j.err 10 | 11 | 12 | # begin reading command line arguments 13 | helpFunction() 14 | { 15 | echo "" 16 | echo "Usage: $0 -t DatasetTemplateName -e ExperimentName -m modelName -n NEnsembleMembers -s NSteps" 17 | echo -e "\t-t The name of the dataset template that should be used." 18 | echo -e "\t-e The name of the experiment conducted on the dataset." 19 | echo -e "\t-m The name of the model the predictions should be created with." 20 | echo -e "\t-n The number of ensemble members to be created." 21 | echo -e "\t-s The number of steps in the trajectories created." 22 | exit 1 # Exit script after printing help 23 | } 24 | 25 | while getopts "t:e:m:n:s:" opt 26 | do 27 | case "$opt" in 28 | t ) TemplateName="$OPTARG" ;; 29 | e ) ExperimentName="$OPTARG" ;; 30 | m ) ModelID="$OPTARG" ;; 31 | n ) EnsembleMembers="$OPTARG" ;; 32 | s ) Steps="$OPTARG" ;; 33 | ? ) helpFunction ;; # Print helpFunction in case parameter is non-existent 34 | esac 35 | done 36 | 37 | # Print helpFunction in case parameters are empty 38 | if [ -z "$TemplateName" ] || [ -z "$ExperimentName" ] || [ -z "$ModelID" ] || [ -z "$EnsembleMembers" ] || [ -z "$Steps" ] 39 | then 40 | echo "Some or all of the parameters are empty."; 41 | helpFunction 42 | fi 43 | # stop reading command line arguments 44 | 45 | module load Anaconda3/2020.07 46 | source $EBROOTANACONDA3/etc/profile.d/conda.sh 47 | 48 | conda activate WD_model 49 | 50 | python s13_write_predictions_iterative.py +data.template=$TemplateName +experiment=$ExperimentName +model_name=$ModelID +n_ensemble_members=$EnsembleMembers +n_steps=$Steps 51 | -------------------------------------------------------------------------------- /submit_script_14_inference_iterative_very_long.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=EvalCondIterLong 4 | #SBATCH --time=0-20:00:00 5 | #SBATCH -G nvidia-a100:1 6 | #SBATCH --mem-per-cpu=45G 7 | # output files 8 | #SBATCH -o /data/compoundx/WeatherDiff/job_log/%x-%u-%j.out 9 | #SBATCH -e /data/compoundx/WeatherDiff/job_log/%x-%u-%j.err 10 | 11 | # begin reading command line arguments 12 | helpFunction() 13 | { 14 | echo "" 15 | echo "Usage: $0 -t DatasetTemplateName -e ExperimentName -m modelName -n NEnsembleMembers" 16 | echo -e "\t-t The name of the dataset template that should be used."
17 | echo -e "\t-e The name of the experiment conducted on the dataset." 18 | echo -e "\t-m The name of the model the predictions should be created with." 19 | echo -e "\t-n The number of ensemble members to be created." 20 | exit 1 # Exit script after printing help 21 | } 22 | 23 | while getopts "t:e:m:n:" opt 24 | do 25 | case "$opt" in 26 | t ) TemplateName="$OPTARG" ;; 27 | e ) ExperimentName="$OPTARG" ;; 28 | m ) ModelID="$OPTARG" ;; 29 | n ) EnsembleMembers="$OPTARG" ;; 30 | ? ) helpFunction ;; # Print helpFunction in case parameter is non-existent 31 | esac 32 | done 33 | 34 | # Print helpFunction in case parameters are empty 35 | if [ -z "$TemplateName" ] || [ -z "$ExperimentName" ] || [ -z "$ModelID" ] || [ -z "$EnsembleMembers" ] 36 | then 37 | echo "Some or all of the parameters are empty."; 38 | helpFunction 39 | fi 40 | # stop reading command line arguments 41 | 42 | module load Anaconda3/2020.07 43 | source $EBROOTANACONDA3/etc/profile.d/conda.sh 44 | 45 | conda activate WD_model 46 | 47 | python s14_very_long_iterative_run.py +data.template=$TemplateName +experiment=$ExperimentName +model_name=$ModelID +n_ensemble_members=$EnsembleMembers 48 | -------------------------------------------------------------------------------- /submit_script_1_dataset_creation.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=CreateData 4 | #SBATCH --time=0-04:00:00 5 | #SBATCH --mem-per-cpu=16G 6 | 7 | # output files 8 | #SBATCH -o /data/compoundx/WeatherDiff/job_log/%x-%u-%j.out 9 | #SBATCH -e /data/compoundx/WeatherDiff/job_log/%x-%u-%j.err 10 | 11 | # read in command line arguments: 12 | helpFunction() 13 | { 14 | echo "" 15 | echo "Usage: $0 -t Template" 16 | echo -e "\t-t The name of the template to be used." 17 | exit 1 # Exit script after printing help 18 | } 19 | 20 | while getopts "t:" opt 21 | do 22 | case "$opt" in 23 | t ) TemplateName="$OPTARG" ;; 24 | ? ) helpFunction ;; # Print helpFunction in case parameter is non-existent 25 | esac 26 | done 27 | 28 | # Print helpFunction in case parameters are empty 29 | if [ -z "$TemplateName" ] 30 | then 31 | echo "Some or all of the parameters are empty."; 32 | helpFunction 33 | fi 34 | # stop reading command line arguments 35 | 36 | module load Anaconda3/2020.07 37 | source $EBROOTANACONDA3/etc/profile.d/conda.sh 38 | 39 | conda activate WD_data 40 | 41 | python s1_write_dataset.py +template=$TemplateName -------------------------------------------------------------------------------- /submit_script_2_run_model.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=RunCond 4 | #SBATCH --time=1-2:45:00 5 | #SBATCH -G nvidia-a100:1 6 | #SBATCH --mem-per-cpu=25G 7 | #SBATCH --cpus-per-task=1 8 | # output files 9 | #SBATCH -o /data/compoundx/WeatherDiff/job_log/%x-%u-%j.out 10 | #SBATCH -e /data/compoundx/WeatherDiff/job_log/%x-%u-%j.err 11 | 12 | # read in command line arguments: 13 | helpFunction() 14 | { 15 | echo "" 16 | echo "Usage: $0 -e experiment_name" 17 | echo -e "\t-e The name of the experiment template to be used." 18 | exit 1 # Exit script after printing help 19 | } 20 | 21 | while getopts "e:" opt 22 | do 23 | case "$opt" in 24 | e ) experiment_name="$OPTARG" ;; 25 | ? 
) helpFunction ;; # Print helpFunction in case parameter is non-existent 26 | esac 27 | done 28 | 29 | # Print helpFunction in case parameters are empty 30 | if [ -z "$experiment_name" ] 31 | then 32 | echo "Some or all of the parameters are empty."; 33 | helpFunction 34 | fi 35 | # stop reading command line arguments 36 | 37 | module load Anaconda3/2023.03 38 | 39 | source $EBROOTANACONDA3/etc/profile.d/conda.sh 40 | conda activate TORCH21 41 | 42 | python s2_train_conditional_pixel_diffusion.py +experiment=$experiment_name 43 | -------------------------------------------------------------------------------- /submit_script_3_inference.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=EvalCond 4 | #SBATCH --time=0-03:45:00 5 | #SBATCH -G nvidia-a100:1 6 | #SBATCH --mem-per-cpu=16G 7 | # output files 8 | #SBATCH -o /data/compoundx/WeatherDiff/job_log/%x-%u-%j.out 9 | #SBATCH -e /data/compoundx/WeatherDiff/job_log/%x-%u-%j.err 10 | 11 | # begin reading command line arguments 12 | helpFunction() 13 | { 14 | echo "" 15 | echo "Usage: $0 -t DatasetTemplateName -e ExperimentName -m modelName -n NEnsembleMembers" 16 | echo -e "\t-t The name of the dataset template that should be used." 17 | echo -e "\t-e The name of the experiment conducted on the dataset." 18 | echo -e "\t-m The name of the model the predictions should be created with." 19 | echo -e "\t-n The number of ensemble members to be created." 20 | exit 1 # Exit script after printing help 21 | } 22 | 23 | while getopts "t:e:m:n:" opt 24 | do 25 | case "$opt" in 26 | t ) TemplateName="$OPTARG" ;; 27 | e ) ExperimentName="$OPTARG" ;; 28 | m ) ModelID="$OPTARG" ;; 29 | n ) EnsembleMembers="$OPTARG" ;; 30 | ? ) helpFunction ;; # Print helpFunction in case parameter is non-existent 31 | esac 32 | done 33 | 34 | # Print helpFunction in case parameters are empty 35 | if [ -z "$TemplateName" ] || [ -z "$ExperimentName" ] || [ -z "$ModelID" ] || [ -z "$EnsembleMembers" ] 36 | then 37 | echo "Some or all of the parameters are empty."; 38 | helpFunction 39 | fi 40 | # stop reading command line arguments 41 | 42 | module load Anaconda3/2020.07 43 | source $EBROOTANACONDA3/etc/profile.d/conda.sh 44 | 45 | conda activate TORCH311 46 | 47 | python s3_write_predictions_conditional_pixel_diffusion.py +data.template=$TemplateName +experiment=$ExperimentName +model_name=$ModelID +n_ensemble_members=$EnsembleMembers -------------------------------------------------------------------------------- /submit_script_4_eval_epoch.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=EvalEpoch 4 | #SBATCH --time=0-03:45:00 5 | #SBATCH -G nvidia-a100:1 6 | #SBATCH --mem-per-cpu=16G 7 | # output files 8 | #SBATCH -o /data/compoundx/WeatherDiff/job_log/%x-%u-%j.out 9 | #SBATCH -e /data/compoundx/WeatherDiff/job_log/%x-%u-%j.err 10 | 11 | # begin reading command line arguments 12 | helpFunction() 13 | { 14 | echo "" 15 | echo "Usage: $0 -d DatasetID -m ModelID -e EnsembleMembers" 16 | echo -e "\t-d The ID of the dataset the model was trained on." 17 | echo -e "\t-m The ID of the model the predictions were created with." 18 | echo -e "\t-e The number of ensemble members to be created." 19 | exit 1 # Exit script after printing help 20 | } 21 | 22 | while getopts "d:m:e:" opt 23 | do 24 | case "$opt" in 25 | d ) DatasetID="$OPTARG" ;; 26 | m ) ModelID="$OPTARG" ;; 27 | e ) EnsembleMembers="$OPTARG" ;; 28 | ?
) helpFunction ;; # Print helpFunction in case parameter is non-existent 29 | esac 30 | done 31 | 32 | # Print helpFunction in case parameters are empty 33 | if [ -z "$DatasetID" ] || [ -z "$ModelID" ] || [ -z "$EnsembleMembers" ] 34 | then 35 | echo "Some or all of the parameters are empty."; 36 | helpFunction 37 | fi 38 | # stop reading command line arguments 39 | 40 | module load Anaconda3/2020.07 41 | source $EBROOTANACONDA3/etc/profile.d/conda.sh 42 | 43 | conda activate TORCH311 44 | 45 | python s4_train_val_test.py -did $DatasetID -mid $ModelID -nens $EnsembleMembers 46 | -------------------------------------------------------------------------------- /submit_script_5_train_FourCastNet.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=RunFour 4 | #SBATCH --time=1-2:45:00 5 | #SBATCH -G nvidia-a100:1 6 | #SBATCH --mem-per-cpu=20G 7 | #SBATCH --cpus-per-task=2 8 | # output files 9 | #SBATCH -o /data/compoundx/WeatherDiff/job_log/%x-%u-%j.out 10 | #SBATCH -e /data/compoundx/WeatherDiff/job_log/%x-%u-%j.err 11 | 12 | # read in command line arguments: 13 | helpFunction() 14 | { 15 | echo "" 16 | echo "Usage: $0 -e experiment_name" 17 | echo -e "\t-e The name of the experiment template to be used." 18 | exit 1 # Exit script after printing help 19 | } 20 | 21 | while getopts "e:" opt 22 | do 23 | case "$opt" in 24 | e ) experiment_name="$OPTARG" ;; 25 | ? ) helpFunction ;; # Print helpFunction in case parameter is non-existent 26 | esac 27 | done 28 | 29 | # Print helpFunction in case parameters are empty 30 | if [ -z "$experiment_name" ] 31 | then 32 | echo "Some or all of the parameters are empty."; 33 | helpFunction 34 | fi 35 | # stop reading command line arguments 36 | 37 | module load Anaconda3/2020.07 38 | source $EBROOTANACONDA3/etc/profile.d/conda.sh 39 | 40 | conda activate WD_model 41 | 42 | python s5_train_FourCastNet.py +experiment=$experiment_name 43 | 44 | -------------------------------------------------------------------------------- /submit_script_6_inference_fourcastnet.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=EvalFour 4 | #SBATCH --time=0-06:45:00 5 | #SBATCH -G nvidia-a100:1 6 | #SBATCH --mem-per-cpu=16G 7 | # output files 8 | #SBATCH -o /data/compoundx/WeatherDiff/job_log/%x-%u-%j.out 9 | #SBATCH -e /data/compoundx/WeatherDiff/job_log/%x-%u-%j.err 10 | 11 | # begin reading command line arguments 12 | helpFunction() 13 | { 14 | echo "" 15 | echo "Usage: $0 -t DatasetTemplateName -e ExperimentName -m modelName" 16 | echo -e "\t-t The name of the dataset template that should be used." 17 | echo -e "\t-e The name of the experiment conducted on the dataset." 18 | echo -e "\t-m The name of the model the predictions should be created with." 19 | exit 1 # Exit script after printing help 20 | } 21 | 22 | while getopts "t:e:m:" opt 23 | do 24 | case "$opt" in 25 | t ) TemplateName="$OPTARG" ;; 26 | e ) ExperimentName="$OPTARG" ;; 27 | m ) ModelID="$OPTARG" ;; 28 | ?
) helpFunction ;; # Print helpFunction in case parameter is non-existent 29 | esac 30 | done 31 | 32 | # Print helpFunction in case parameters are empty 33 | if [ -z "$TemplateName" ] || [ -z "$ExperimentName" ] || [ -z "$ModelID" ] 34 | then 35 | echo "Some or all of the parameters are empty."; 36 | helpFunction 37 | fi 38 | # stop reading command line arguments 39 | 40 | module load Anaconda3/2020.07 41 | source $EBROOTANACONDA3/etc/profile.d/conda.sh 42 | 43 | conda activate WD_model 44 | 45 | python s6_write_predictions_FourCastNet.py +data.template=$TemplateName +experiment=$ExperimentName +model_name=$ModelID 46 | -------------------------------------------------------------------------------- /submit_script_7_train_unet.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=TrainUnet 4 | #SBATCH --time=0-23:45:00 5 | #SBATCH -G 1 # nvidia-a100:1 6 | #SBATCH --mem-per-cpu=16G 7 | #SBATCH --cpus-per-task=4 8 | # output files 9 | #SBATCH -o /data/compoundx/WeatherDiff/job_log/%x-%u-%j.out 10 | #SBATCH -e /data/compoundx/WeatherDiff/job_log/%x-%u-%j.err 11 | 12 | helpFunction() 13 | { 14 | echo "" 15 | echo "Usage: $0 -e experiment_name" 16 | echo -e "\t-e The name of the experiment template to be used." 17 | exit 1 # Exit script after printing help 18 | } 19 | 20 | while getopts "e:" opt 21 | do 22 | case "$opt" in 23 | e ) experiment_name="$OPTARG" ;; 24 | ? ) helpFunction ;; # Print helpFunction in case parameter is non-existent 25 | esac 26 | done 27 | 28 | # Print helpFunction in case parameters are empty 29 | if [ -z "$experiment_name" ] 30 | then 31 | echo "Some or all of the parameters are empty."; 32 | helpFunction 33 | fi 34 | # stop reading command line arguments 35 | 36 | module load Anaconda3/2020.07 37 | source $EBROOTANACONDA3/etc/profile.d/conda.sh 38 | 39 | conda activate WD_model 40 | 41 | python s7_train_unet.py +experiment=$experiment_name 42 | 43 | -------------------------------------------------------------------------------- /submit_script_8_inference_UNet.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=EvalUNet 4 | #SBATCH --time=0-03:45:00 5 | #SBATCH -G 1 # nvidia-a100:1 6 | #SBATCH --mem-per-cpu=16G 7 | # output files 8 | #SBATCH -o /data/compoundx/WeatherDiff/job_log/%x-%u-%j.out 9 | #SBATCH -e /data/compoundx/WeatherDiff/job_log/%x-%u-%j.err 10 | 11 | # begin reading command line arguments 12 | helpFunction() 13 | { 14 | echo "" 15 | echo "Usage: $0 -t DatasetTemplateName -e ExperimentName -m modelName" 16 | echo -e "\t-t The name of the dataset template that should be used." 17 | echo -e "\t-e The name of the experiment conducted on the dataset." 18 | echo -e "\t-m The name of the model the predictions should be created with." 19 | exit 1 # Exit script after printing help 20 | } 21 | 22 | while getopts "t:e:m:" opt 23 | do 24 | case "$opt" in 25 | t ) TemplateName="$OPTARG" ;; 26 | e ) ExperimentName="$OPTARG" ;; 27 | m ) ModelID="$OPTARG" ;; 28 | ?
) helpFunction ;; # Print helpFunction in case parameter is non-existent 29 | esac 30 | done 31 | 32 | # Print helpFunction in case parameters are empty 33 | if [ -z "$TemplateName" ] || [ -z "$ExperimentName" ] || [ -z "$ModelID" ] 34 | then 35 | echo "Some or all of the parameters are empty."; 36 | helpFunction 37 | fi 38 | # stop reading command line arguments 39 | 40 | module load Anaconda3/2020.07 41 | source $EBROOTANACONDA3/etc/profile.d/conda.sh 42 | 43 | conda activate WD_model 44 | 45 | python s8_write_predictions_unet.py +data.template=$TemplateName +experiment=$ExperimentName +model_name=$ModelID 46 | 47 | -------------------------------------------------------------------------------- /submit_script_9_train_vae.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=VAE 4 | #SBATCH --time=1-6:05:00 5 | #SBATCH -G nvidia-a100:1 6 | #SBATCH --mem-per-cpu=64G 7 | #SBATCH --cpus-per-task=1 8 | # output files 9 | #SBATCH -o /data/compoundx/WeatherDiff/job_log/%x-%u-%j.out 10 | #SBATCH -e /data/compoundx/WeatherDiff/job_log/%x-%u-%j.err 11 | 12 | # read in command line arguments: 13 | helpFunction() 14 | { 15 | echo "" 16 | echo "Usage: $0 -e experiment_name" 17 | echo -e "\t-e The name of the experiment template to be used." 18 | exit 1 # Exit script after printing help 19 | } 20 | 21 | while getopts "e:" opt 22 | do 23 | case "$opt" in 24 | e ) experiment_name="$OPTARG" ;; 25 | ? ) helpFunction ;; # Print helpFunction in case parameter is non-existent 26 | esac 27 | done 28 | 29 | # Print helpFunction in case parameters are empty 30 | if [ -z "$experiment_name" ] 31 | then 32 | echo "Some or all of the parameters are empty."; 33 | helpFunction 34 | fi 35 | # stop reading command line arguments 36 | 37 | module load Anaconda3/2020.07 38 | source $EBROOTANACONDA3/etc/profile.d/conda.sh 39 | 40 | conda activate TORCH311 41 | 42 | python s9_train_vae.py +experiment=$experiment_name 43 | 44 | --------------------------------------------------------------------------------