├── .gitignore ├── README.md ├── configs ├── global_forecast_weatherode.yaml └── regional_forecast_weatherode.yaml ├── environment.yml ├── pyproject.toml ├── scripts ├── global │ ├── train12h.sh │ ├── train18h.sh │ ├── train24h.sh │ └── train6h.sh └── regional │ ├── Australia │ ├── train12h.sh │ ├── train18h.sh │ ├── train24h.sh │ └── train6h.sh │ ├── NorthAmerica │ ├── train12h.sh │ ├── train18h.sh │ ├── train24h.sh │ └── train6h.sh │ └── SouthAmerica │ ├── train12h.sh │ ├── train18h.sh │ ├── train24h.sh │ └── train6h.sh └── src ├── data_preprocessing ├── nc2np_equally_era5.py ├── regrid.py └── regrid_climatebench.py └── weatherode ├── __init__.py ├── blocks.py ├── c3d.py ├── cnn_dit.py ├── dit.py ├── global_forecast ├── __init__.py ├── datamodule.py ├── module.py └── train.py ├── ode.py ├── ode_utils.py ├── pretrain ├── __init__.py ├── datamodule.py ├── dataset.py ├── module.py └── train.py ├── regional_forecast ├── __init__.py ├── datamodule.py ├── module.py ├── ode.py └── train.py └── utils ├── data_utils.py ├── lr_scheduler.py ├── metrics.py └── pos_embed.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # vscode setting 132 | .vscode 133 | 134 | # experiments 135 | exps 136 | 137 | # snakemake logs 138 | .snakemake 139 | 140 | # MacOS 141 | .DS_Store -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # WeatherODE 2 | 3 | Peiyuan Liu, Tian Zhou, Liang Sun, Rong Jin, "Mitigating Time Discretization Challenges with WeatherODE: A Sandwich Physics-Driven Neural ODE for Weather Forecasting". [[paper](https://arxiv.org/abs/2410.06560)] 4 | 5 | ## Overview 6 | 7 | WeatherODE is a comprehensive framework designed for global and regional weather forecasting based on the ERA5 dataset. The package includes preprocessing scripts, model training pipelines, and evaluation tools tailored for different forecasting horizons (6h, 12h, 18h, and 24h). It supports global and regional forecasting for various areas including Australia, North America, and South America. 8 | 9 | ## Installation 10 | 11 | ### Setting Up the Environment 12 | 13 | To get started, create and activate a conda environment using the provided configuration: 14 | 15 | ```bash 16 | conda env create -f environment.yml 17 | conda activate weatherode 18 | ``` 19 | 20 | ### Installing WeatherODE 21 | 22 | Install the WeatherODE package in editable mode: 23 | 24 | ```bash 25 | pip install -e . 26 | ``` 27 | 28 | ## Data Preparation 29 | 30 | ### Download ERA5 Data 31 | 32 | Download the ERA5 reanalysis dataset from the [WeatherBench](https://dataserv.ub.tum.de/index.php/s/m1524895). Organize the data directory as follows: 33 | 34 | ``` 35 | 5.625deg 36 | ├── 10m_u_component_of_wind 37 | ├── 10m_v_component_of_wind 38 | ├── 2m_temperature 39 | ├── constants.nc 40 | ├── geopotential 41 | ├── relative_humidity 42 | ├── specific_humidity 43 | ├── temperature 44 | ├── toa_incident_solar_radiation 45 | ├── total_precipitation 46 | ├── u_component_of_wind 47 | └── v_component_of_wind 48 | ``` 49 | 50 | ### Preprocessing 51 | 52 | Convert the raw NetCDF files into smaller, more manageable NumPy files and compute essential statistical measures. 
Execute the following script: 53 | 54 | ```bash 55 | python src/data_preprocessing/nc2np_equally_era5.py \ 56 | --root_dir /mnt/data/5.625deg \ 57 | --save_dir /mnt/data/5.625deg_npz \ 58 | --start_train_year 1979 --start_val_year 2016 \ 59 | --start_test_year 2017 --end_year 2019 --num_shards 8 60 | ``` 61 | 62 | The preprocessed data directory will have the following structure: 63 | 64 | ``` 65 | 5.625deg_npz 66 | ├── train 67 | ├── val 68 | ├── test 69 | ├── normalize_mean.npz 70 | ├── normalize_std.npz 71 | ├── lat.npy 72 | └── lon.npy 73 | ``` 74 | 75 | ## Training 76 | 77 | ### Global Forecasting 78 | 79 | To train a global forecasting model with a 6-hour prediction horizon, use the following command: 80 | 81 | ```bash 82 | bash ./scripts/global/train6h.sh 83 | ``` 84 | 85 | Scripts for 12-hour, 18-hour, and 24-hour forecast models are also available. 86 | 87 | ### Regional Forecasting 88 | 89 | For regional forecasting in Australia with a 6-hour prediction horizon, use the following command: 90 | 91 | ```bash 92 | bash ./scripts/regional/Australia/train6h.sh 93 | ``` 94 | 95 | Scripts are also provided for 12-hour, 18-hour, and 24-hour forecasts. Additional regions include North America and South America. 96 | 97 | 98 | ## Acknowledgements 99 | 100 | We acknowledge the use of the ERA5 reanalysis data provided by the European Centre for Medium-Range Weather Forecasts (ECMWF) and the WeatherBench dataset for benchmarking. 101 | 102 | ## Citation 103 | If you find this repo useful, please cite our paper. 104 | ``` 105 | @article{liu2024mitigating, 106 | title={Mitigating Time Discretization Challenges with WeatherODE: A Sandwich Physics-Driven Neural ODE for Weather Forecasting}, 107 | author={Liu, Peiyuan and Zhou, Tian and Sun, Liang and Jin, Rong}, 108 | journal={arXiv preprint arXiv:2410.06560}, 109 | year={2024} 110 | } 111 | ``` -------------------------------------------------------------------------------- /configs/global_forecast_weatherode.yaml: -------------------------------------------------------------------------------- 1 | seed_everything: 42 2 | 3 | # ---------------------------- TRAINER ------------------------------------------- 4 | trainer: 5 | default_root_dir: ${oc.env:OUTPUT_DIR,./202408_36hours} 6 | 7 | precision: 16 8 | 9 | gpus: null 10 | num_nodes: 1 11 | accelerator: gpu 12 | strategy: ddp 13 | 14 | min_epochs: 1 15 | max_epochs: 100 16 | enable_progress_bar: true 17 | 18 | sync_batchnorm: True 19 | enable_checkpointing: True 20 | resume_from_checkpoint: null 21 | 22 | # debugging 23 | fast_dev_run: false 24 | 25 | logger: 26 | class_path: pytorch_lightning.loggers.wandb.WandbLogger 27 | init_args: 28 | # name: "${model.name}_${oc.env:SLURM_JOB_ID}" # name of the run (normally generated by wandb) 29 | name: "predict_72h_multi_gt_with_noise_ViT_ode_linear_term_72steps_skip_all_nan_0.001_all_multi_val_ode_lr=1e-4_else_5e-4_3DNoise_with_v_x" 30 | # name: "predict_6h_multi_gt_finetune_noise_1e-4_2DNoise" 31 | save_dir: ${trainer.default_root_dir}/logs 32 | offline: False 33 | id: null # pass correct id to resume experiment!
34 | anonymous: null # enable anonymous logging 35 | project: "era5" 36 | log_model: False # upload lightning ckpts 37 | prefix: "" # a string to put at the beginning of metric keys 38 | entity: "WeatherODE" # set to name of your wandb team 39 | group: "WeatherODE_exp" 40 | tags: ['era5', 'forecast', 'multi_gt'] 41 | job_type: "" 42 | 43 | # logger: 44 | # class_path: pytorch_lightning.loggers.tensorboard.TensorBoardLogger 45 | # init_args: 46 | # save_dir: ${trainer.default_root_dir}/logs 47 | # name: null 48 | # version: null 49 | # log_graph: False 50 | # default_hp_metric: True 51 | # prefix: "" 52 | 53 | callbacks: 54 | - class_path: pytorch_lightning.callbacks.LearningRateMonitor 55 | init_args: 56 | logging_interval: "step" 57 | 58 | - class_path: pytorch_lightning.callbacks.ModelCheckpoint 59 | init_args: 60 | dirpath: "${trainer.default_root_dir}/checkpoints" 61 | monitor: "val/w_rmse" # name of the logged metric which determines when model is improving 62 | mode: "min" # "max" means higher metric value is better, can be also "min" 63 | save_top_k: 1 # save k best models (determined by above metric) 64 | save_last: True # additionaly always save model from last epoch 65 | verbose: False 66 | filename: "epoch_{epoch:03d}" 67 | auto_insert_metric_name: False 68 | 69 | - class_path: pytorch_lightning.callbacks.EarlyStopping 70 | init_args: 71 | monitor: "val/w_rmse" # name of the logged metric which determines when model is improving 72 | mode: "min" # "max" means higher metric value is better, can be also "min" 73 | patience: 2 # how many validation epochs of not improving until training stops 74 | min_delta: 0. # minimum change in the monitored metric needed to qualify as an improvement 75 | 76 | - class_path: pytorch_lightning.callbacks.RichModelSummary 77 | init_args: 78 | max_depth: -1 79 | 80 | - class_path: pytorch_lightning.callbacks.RichProgressBar 81 | 82 | # ---------------------------- MODEL ------------------------------------------- 83 | model: 84 | lr: 5e-3 85 | ode_lr: 5e-6 86 | beta_1: 0.9 87 | beta_2: 0.99 88 | weight_decay: 1e-5 89 | warmup_epochs: 20000 90 | max_epochs: 100000 91 | warmup_start_lr: 1e-8 92 | eta_min: 1e-8 93 | pretrained_path: "" 94 | gradient_clip_val: 0.1 95 | gradient_clip_algorithm: "value" 96 | train_noise_only: False 97 | 98 | net: 99 | class_path: weatherode.ode.WeatherODE 100 | init_args: 101 | default_vars: [ 102 | "land_sea_mask", 103 | "orography", 104 | "lattitude", 105 | "2m_temperature", 106 | "10m_u_component_of_wind", 107 | "10m_v_component_of_wind", 108 | "geopotential_50", 109 | "geopotential_250", 110 | "geopotential_500", 111 | "geopotential_600", 112 | "geopotential_700", 113 | "geopotential_850", 114 | "geopotential_925", 115 | "u_component_of_wind_50", 116 | "u_component_of_wind_250", 117 | "u_component_of_wind_500", 118 | "u_component_of_wind_600", 119 | "u_component_of_wind_700", 120 | "u_component_of_wind_850", 121 | "u_component_of_wind_925", 122 | "v_component_of_wind_50", 123 | "v_component_of_wind_250", 124 | "v_component_of_wind_500", 125 | "v_component_of_wind_600", 126 | "v_component_of_wind_700", 127 | "v_component_of_wind_850", 128 | "v_component_of_wind_925", 129 | "temperature_50", 130 | "temperature_250", 131 | "temperature_500", 132 | "temperature_600", 133 | "temperature_700", 134 | "temperature_850", 135 | "temperature_925", 136 | "relative_humidity_50", 137 | "relative_humidity_250", 138 | "relative_humidity_500", 139 | "relative_humidity_600", 140 | "relative_humidity_700", 141 | "relative_humidity_850", 
142 | "relative_humidity_925", 143 | "specific_humidity_50", 144 | "specific_humidity_250", 145 | "specific_humidity_500", 146 | "specific_humidity_600", 147 | "specific_humidity_700", 148 | "specific_humidity_850", 149 | "specific_humidity_925", 150 | ] 151 | img_size: [32, 64] 152 | layers: [5, 5, 3, 2] 153 | hidden: [512, 128, 64] 154 | depth: 4 155 | method: "euler" 156 | drop_rate: 0.1 157 | time_steps: 36 158 | time_interval: 0.001 159 | rtol: 1e-9 160 | atol: 1e-11 161 | predict_list: [1,2,3,4,5,6] 162 | gradient_loss: False 163 | err_type: DiT 164 | err_with_v: False 165 | err_with_x: False 166 | 167 | # ---------------------------- DATA ------------------------------------------- 168 | data: 169 | root_dir: /mnt/workgroup/5.625deg_npz 170 | variables: [ 171 | "land_sea_mask", 172 | "orography", 173 | "lattitude", 174 | "2m_temperature", 175 | "10m_u_component_of_wind", 176 | "10m_v_component_of_wind", 177 | "geopotential_50", 178 | "geopotential_250", 179 | "geopotential_500", 180 | "geopotential_600", 181 | "geopotential_700", 182 | "geopotential_850", 183 | "geopotential_925", 184 | "u_component_of_wind_50", 185 | "u_component_of_wind_250", 186 | "u_component_of_wind_500", 187 | "u_component_of_wind_600", 188 | "u_component_of_wind_700", 189 | "u_component_of_wind_850", 190 | "u_component_of_wind_925", 191 | "v_component_of_wind_50", 192 | "v_component_of_wind_250", 193 | "v_component_of_wind_500", 194 | "v_component_of_wind_600", 195 | "v_component_of_wind_700", 196 | "v_component_of_wind_850", 197 | "v_component_of_wind_925", 198 | "temperature_50", 199 | "temperature_250", 200 | "temperature_500", 201 | "temperature_600", 202 | "temperature_700", 203 | "temperature_850", 204 | "temperature_925", 205 | "relative_humidity_50", 206 | "relative_humidity_250", 207 | "relative_humidity_500", 208 | "relative_humidity_600", 209 | "relative_humidity_700", 210 | "relative_humidity_850", 211 | "relative_humidity_925", 212 | "specific_humidity_50", 213 | "specific_humidity_250", 214 | "specific_humidity_500", 215 | "specific_humidity_600", 216 | "specific_humidity_700", 217 | "specific_humidity_850", 218 | "specific_humidity_925", 219 | ] 220 | out_variables: [ 221 | "2m_temperature", 222 | "10m_u_component_of_wind", 223 | "10m_v_component_of_wind", 224 | "geopotential_50", 225 | "geopotential_250", 226 | "geopotential_500", 227 | "geopotential_600", 228 | "geopotential_700", 229 | "geopotential_850", 230 | "geopotential_925", 231 | "u_component_of_wind_50", 232 | "u_component_of_wind_250", 233 | "u_component_of_wind_500", 234 | "u_component_of_wind_600", 235 | "u_component_of_wind_700", 236 | "u_component_of_wind_850", 237 | "u_component_of_wind_925", 238 | "v_component_of_wind_50", 239 | "v_component_of_wind_250", 240 | "v_component_of_wind_500", 241 | "v_component_of_wind_600", 242 | "v_component_of_wind_700", 243 | "v_component_of_wind_850", 244 | "v_component_of_wind_925", 245 | "temperature_50", 246 | "temperature_250", 247 | "temperature_500", 248 | "temperature_600", 249 | "temperature_700", 250 | "temperature_850", 251 | "temperature_925", 252 | "relative_humidity_50", 253 | "relative_humidity_250", 254 | "relative_humidity_500", 255 | "relative_humidity_600", 256 | "relative_humidity_700", 257 | "relative_humidity_850", 258 | "relative_humidity_925", 259 | "specific_humidity_50", 260 | "specific_humidity_250", 261 | "specific_humidity_500", 262 | "specific_humidity_600", 263 | "specific_humidity_700", 264 | "specific_humidity_850", 265 | "specific_humidity_925", 266 | ] 267 | 
predict_range: 6 268 | hrs_each_step: 1 269 | buffer_size: 10000 270 | batch_size: 128 271 | num_workers: 1 272 | pin_memory: False 273 | -------------------------------------------------------------------------------- /configs/regional_forecast_weatherode.yaml: -------------------------------------------------------------------------------- 1 | seed_everything: 42 2 | 3 | # ---------------------------- TRAINER ------------------------------------------- 4 | trainer: 5 | default_root_dir: ${oc.env:OUTPUT_DIR,/home/t-tungnguyen/WeatherODE/exps/regional_forecast_weatherode} 6 | 7 | precision: 16 8 | 9 | gpus: null 10 | num_nodes: 1 11 | accelerator: gpu 12 | strategy: ddp 13 | 14 | min_epochs: 1 15 | max_epochs: 100 16 | enable_progress_bar: true 17 | 18 | sync_batchnorm: True 19 | enable_checkpointing: True 20 | resume_from_checkpoint: null 21 | 22 | # debugging 23 | fast_dev_run: false 24 | 25 | logger: 26 | class_path: pytorch_lightning.loggers.wandb.WandbLogger 27 | init_args: 28 | # name: "${model.name}_${oc.env:SLURM_JOB_ID}" # name of the run (normally generated by wandb) 29 | name: "predict_6h_NorthAmerican" 30 | # name: "predict_6h_multi_gt_finetune_noise_1e-4_2DNoise" 31 | save_dir: ${trainer.default_root_dir}/logs 32 | offline: False 33 | id: null # pass correct id to resume experiment! 34 | anonymous: null # enable anonymous logging 35 | project: "era5" 36 | log_model: False # upload lightning ckpts 37 | prefix: "" # a string to put at the beginning of metric keys 38 | entity: "WeatherODE" # set to name of your wandb team 39 | group: "WeatherODE_exp" 40 | tags: ['era5', 'forecast', 'multi_gt', 'regional'] 41 | job_type: "" 42 | 43 | # logger: 44 | # class_path: pytorch_lightning.loggers.tensorboard.TensorBoardLogger 45 | # init_args: 46 | # save_dir: ${trainer.default_root_dir}/logs 47 | # name: null 48 | # version: null 49 | # log_graph: False 50 | # default_hp_metric: True 51 | # prefix: "" 52 | 53 | callbacks: 54 | - class_path: pytorch_lightning.callbacks.LearningRateMonitor 55 | init_args: 56 | logging_interval: "step" 57 | 58 | - class_path: pytorch_lightning.callbacks.ModelCheckpoint 59 | init_args: 60 | dirpath: "${trainer.default_root_dir}/checkpoints" 61 | monitor: "val/w_rmse" # name of the logged metric which determines when model is improving 62 | mode: "min" # "max" means higher metric value is better, can be also "min" 63 | save_top_k: 1 # save k best models (determined by above metric) 64 | save_last: True # additionaly always save model from last epoch 65 | verbose: False 66 | filename: "epoch_{epoch:03d}" 67 | auto_insert_metric_name: False 68 | 69 | - class_path: pytorch_lightning.callbacks.EarlyStopping 70 | init_args: 71 | monitor: "val/w_rmse" # name of the logged metric which determines when model is improving 72 | mode: "min" # "max" means higher metric value is better, can be also "min" 73 | patience: 5 # how many validation epochs of not improving until training stops 74 | min_delta: 0. 
# minimum change in the monitored metric needed to qualify as an improvement 75 | 76 | - class_path: pytorch_lightning.callbacks.RichModelSummary 77 | init_args: 78 | max_depth: -1 79 | 80 | - class_path: pytorch_lightning.callbacks.RichProgressBar 81 | 82 | # ---------------------------- MODEL ------------------------------------------- 83 | model: 84 | lr: 5e-4 85 | beta_1: 0.9 86 | beta_2: 0.99 87 | weight_decay: 1e-5 88 | warmup_epochs: 10000 89 | max_epochs: 100000 90 | warmup_start_lr: 1e-8 91 | eta_min: 1e-8 92 | pretrained_path: "" 93 | 94 | net: 95 | class_path: weatherode.regional_forecast.ode.RegionalWeatherODE 96 | init_args: 97 | default_vars: [ 98 | "land_sea_mask", 99 | "orography", 100 | "lattitude", 101 | "2m_temperature", 102 | "10m_u_component_of_wind", 103 | "10m_v_component_of_wind", 104 | "geopotential_50", 105 | "geopotential_250", 106 | "geopotential_500", 107 | "geopotential_600", 108 | "geopotential_700", 109 | "geopotential_850", 110 | "geopotential_925", 111 | "u_component_of_wind_50", 112 | "u_component_of_wind_250", 113 | "u_component_of_wind_500", 114 | "u_component_of_wind_600", 115 | "u_component_of_wind_700", 116 | "u_component_of_wind_850", 117 | "u_component_of_wind_925", 118 | "v_component_of_wind_50", 119 | "v_component_of_wind_250", 120 | "v_component_of_wind_500", 121 | "v_component_of_wind_600", 122 | "v_component_of_wind_700", 123 | "v_component_of_wind_850", 124 | "v_component_of_wind_925", 125 | "temperature_50", 126 | "temperature_250", 127 | "temperature_500", 128 | "temperature_600", 129 | "temperature_700", 130 | "temperature_850", 131 | "temperature_925", 132 | "relative_humidity_50", 133 | "relative_humidity_250", 134 | "relative_humidity_500", 135 | "relative_humidity_600", 136 | "relative_humidity_700", 137 | "relative_humidity_850", 138 | "relative_humidity_925", 139 | "specific_humidity_50", 140 | "specific_humidity_250", 141 | "specific_humidity_500", 142 | "specific_humidity_600", 143 | "specific_humidity_700", 144 | "specific_humidity_850", 145 | "specific_humidity_925", 146 | ] 147 | img_size: [32, 64] 148 | layers: [5, 5, 3, 2] 149 | hidden: [512, 128, 64] 150 | depth: 4 151 | method: "euler" 152 | drop_rate: 0.1 153 | time_steps: 36 154 | time_interval: 0.001 155 | rtol: 1e-9 156 | atol: 1e-11 157 | predict_list: [1,2,3,4,5,6] 158 | gradient_loss: False 159 | err_type: DiT 160 | err_with_v: False 161 | 162 | # ---------------------------- DATA ------------------------------------------- 163 | data: 164 | root_dir: /datadrive/datasets/5.625deg_equally_np/ 165 | variables: [ 166 | "land_sea_mask", 167 | "orography", 168 | "lattitude", 169 | "2m_temperature", 170 | "10m_u_component_of_wind", 171 | "10m_v_component_of_wind", 172 | "geopotential_50", 173 | "geopotential_250", 174 | "geopotential_500", 175 | "geopotential_600", 176 | "geopotential_700", 177 | "geopotential_850", 178 | "geopotential_925", 179 | "u_component_of_wind_50", 180 | "u_component_of_wind_250", 181 | "u_component_of_wind_500", 182 | "u_component_of_wind_600", 183 | "u_component_of_wind_700", 184 | "u_component_of_wind_850", 185 | "u_component_of_wind_925", 186 | "v_component_of_wind_50", 187 | "v_component_of_wind_250", 188 | "v_component_of_wind_500", 189 | "v_component_of_wind_600", 190 | "v_component_of_wind_700", 191 | "v_component_of_wind_850", 192 | "v_component_of_wind_925", 193 | "temperature_50", 194 | "temperature_250", 195 | "temperature_500", 196 | "temperature_600", 197 | "temperature_700", 198 | "temperature_850", 199 | "temperature_925", 200 | 
"relative_humidity_50", 201 | "relative_humidity_250", 202 | "relative_humidity_500", 203 | "relative_humidity_600", 204 | "relative_humidity_700", 205 | "relative_humidity_850", 206 | "relative_humidity_925", 207 | "specific_humidity_50", 208 | "specific_humidity_250", 209 | "specific_humidity_500", 210 | "specific_humidity_600", 211 | "specific_humidity_700", 212 | "specific_humidity_850", 213 | "specific_humidity_925", 214 | ] 215 | out_variables: ["geopotential_500", "temperature_850", "2m_temperature", "10m_u_component_of_wind", "10m_v_component_of_wind"] 216 | region: "NorthAmerica" 217 | predict_range: 72 218 | hrs_each_step: 1 219 | buffer_size: 10000 220 | batch_size: 128 221 | num_workers: 1 222 | pin_memory: False 223 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: weatherode 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=conda_forge 8 | - _openmp_mutex=4.5=2_kmp_llvm 9 | - appdirs=1.4.4=pyh9f0ad1d_0 10 | - asciitree=0.3.3=py_2 11 | - blas=1.0=mkl 12 | - bokeh=2.4.3=pyhd8ed1ab_3 13 | - bottleneck=1.3.6=py38h7e4f40d_0 14 | - brotlipy=0.7.0=py38h27cfd23_1003 15 | - bzip2=1.0.8=h7b6447c_0 16 | - c-ares=1.18.1=h7f98852_0 17 | - ca-certificates=2024.3.11=h06a4308_0 18 | - certifi=2024.6.2=py38h06a4308_0 19 | - cf_xarray=0.7.9=pyhd8ed1ab_0 20 | - cffi=1.15.1=py38h5eee18b_3 21 | - cftime=1.6.2=py38h26c90d9_1 22 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 23 | - click=8.1.3=unix_pyhd8ed1ab_2 24 | - cloudpickle=2.2.1=pyhd8ed1ab_0 25 | - cryptography=38.0.4=py38h9ce1e76_0 26 | - cudatoolkit=11.3.1=h2bc3f7f_2 27 | - curl=7.87.0=h6312ad2_0 28 | - cytoolz=0.12.0=py38h0a891b7_1 29 | - dask=2023.1.1=pyhd8ed1ab_0 30 | - dask-core=2023.1.1=pyhd8ed1ab_0 31 | - distributed=2023.1.1=pyhd8ed1ab_0 32 | - entrypoints=0.4=pyhd8ed1ab_0 33 | - esmf=8.4.0=nompi_hdb2cfa9_3 34 | - esmpy=8.4.0=nompi_py38h2b78397_1 35 | - fasteners=0.17.3=pyhd8ed1ab_0 36 | - ffmpeg=4.3=hf484d3e_0 37 | - fftw=3.3.10=nompi_hf0379b8_106 38 | - flit-core=3.6.0=pyhd3eb1b0_0 39 | - freetype=2.12.1=h4a9f257_0 40 | - geos=3.11.1=h27087fc_0 41 | - giflib=5.2.1=h5eee18b_1 42 | - gmp=6.2.1=h295c915_3 43 | - gnutls=3.6.15=he1e5248_0 44 | - hdf4=4.2.15=h9772cbc_5 45 | - hdf5=1.12.2=nompi_h2386368_101 46 | - heapdict=1.0.1=py_0 47 | - idna=3.4=py38h06a4308_0 48 | - importlib-metadata=6.0.0=pyha770c72_0 49 | - intel-openmp=2021.4.0=h06a4308_3561 50 | - jinja2=3.1.2=pyhd8ed1ab_1 51 | - jpeg=9e=h7f8727e_0 52 | - keyutils=1.6.1=h166bdaf_0 53 | - krb5=1.20.1=hf9c8cef_0 54 | - lame=3.100=h7b6447c_0 55 | - lcms2=2.12=h3be6417_0 56 | - ld_impl_linux-64=2.38=h1181459_1 57 | - lerc=3.0=h295c915_0 58 | - libaec=1.0.6=hcb278e6_1 59 | - libblas=3.9.0=12_linux64_mkl 60 | - libcblas=3.9.0=12_linux64_mkl 61 | - libcurl=7.87.0=h6312ad2_0 62 | - libdeflate=1.8=h7f8727e_5 63 | - libedit=3.1.20191231=he28a2e2_2 64 | - libev=4.33=h516909a_1 65 | - libffi=3.4.2=h6a678d5_6 66 | - libgcc-ng=12.2.0=h65d4601_19 67 | - libgfortran-ng=12.2.0=h69a702a_19 68 | - libgfortran5=12.2.0=h337968e_19 69 | - libiconv=1.16=h7f8727e_2 70 | - libidn2=2.3.2=h7f8727e_0 71 | - liblapack=3.9.0=12_linux64_mkl 72 | - libllvm11=11.1.0=he0ac6c6_5 73 | - libnetcdf=4.8.1=nompi_h21705cb_104 74 | - libnghttp2=1.51.0=hdcd2b5c_0 75 | - libpng=1.6.37=hbc83047_0 76 | - libssh2=1.10.0=haa6b8db_3 77 | - libstdcxx-ng=12.2.0=h46fd767_19 78 | - libtasn1=4.16.0=h27cfd23_0 79 | - libtiff=4.5.0=h6a678d5_1 80 
| - libunistring=0.9.10=h27cfd23_0 81 | - libwebp=1.2.4=h11a3e52_0 82 | - libwebp-base=1.2.4=h5eee18b_0 83 | - libzip=1.9.2=hc869a4a_1 84 | - libzlib=1.2.13=h166bdaf_4 85 | - llvm-openmp=15.0.7=h0cdce71_0 86 | - llvmlite=0.39.1=py38h38d86a4_1 87 | - locket=1.0.0=pyhd8ed1ab_0 88 | - lz4=4.2.0=py38hd012fdc_0 89 | - lz4-c=1.9.4=h6a678d5_0 90 | - markupsafe=2.1.2=py38h1de0b5d_0 91 | - mkl=2021.4.0=h06a4308_640 92 | - mkl-service=2.4.0=py38h7f8727e_0 93 | - mkl_fft=1.3.1=py38hd3c417c_0 94 | - mkl_random=1.2.2=py38h51133e4_0 95 | - msgpack-python=1.0.4=py38h43d8883_1 96 | - ncurses=6.4=h6a678d5_0 97 | - netcdf-fortran=4.6.0=nompi_he1eeb6f_102 98 | - netcdf4=1.6.2=nompi_py38h2250339_100 99 | - nettle=3.7.3=hbbd107a_1 100 | - numba=0.56.4=py38h9a4aae9_0 101 | - numcodecs=0.11.0=py38h8dc9893_1 102 | - numpy=1.23.5=py38h14f4228_0 103 | - numpy-base=1.23.5=py38h31eccc5_0 104 | - openh264=2.1.1=h4ff587b_0 105 | - openssl=1.1.1w=h7f8727e_0 106 | - packaging=23.0=pyhd8ed1ab_0 107 | - partd=1.3.0=pyhd8ed1ab_0 108 | - pillow=9.3.0=py38h6a678d5_2 109 | - pip=22.3.1=py38h06a4308_0 110 | - pooch=1.6.0=pyhd8ed1ab_0 111 | - portalocker=2.3.0=py38h06a4308_0 112 | - psutil=5.9.4=py38h0a891b7_0 113 | - pycparser=2.21=pyhd3eb1b0_0 114 | - pyopenssl=22.0.0=pyhd3eb1b0_0 115 | - pysocks=1.7.1=py38h06a4308_0 116 | - python=3.8.16=h7a1cb2a_2 117 | - python-dateutil=2.8.2=pyhd8ed1ab_0 118 | - python_abi=3.8=2_cp38 119 | - pytorch-mutex=1.0=cuda 120 | - pytz=2022.7.1=pyhd8ed1ab_0 121 | - pyyaml=6.0=py38h0a891b7_5 122 | - readline=8.2=h5eee18b_0 123 | - requests=2.28.1=py38h06a4308_0 124 | - setuptools=65.6.3=py38h06a4308_0 125 | - shapely=1.8.5=py38hafd38ec_2 126 | - six=1.16.0=pyhd3eb1b0_1 127 | - sortedcontainers=2.4.0=pyhd8ed1ab_0 128 | - sparse=0.13.0=pyhd8ed1ab_0 129 | - sqlite=3.40.1=h5082296_0 130 | - tblib=1.7.0=pyhd8ed1ab_0 131 | - tk=8.6.12=h1ccaba5_0 132 | - toolz=0.12.0=pyhd8ed1ab_0 133 | - torchaudio=0.12.1=py38_cu113 134 | - tornado=6.2=py38h0a891b7_1 135 | - urllib3=1.26.14=py38h06a4308_0 136 | - wheel=0.37.1=pyhd3eb1b0_0 137 | - xarray=2023.1.0=pyhd8ed1ab_0 138 | - xesmf=0.7.0=pyhd8ed1ab_0 139 | - xz=5.2.10=h5eee18b_1 140 | - yaml=0.2.5=h7f98852_2 141 | - zarr=2.13.6=pyhd8ed1ab_0 142 | - zict=2.2.0=pyhd8ed1ab_0 143 | - zipp=3.12.1=pyhd8ed1ab_0 144 | - zlib=1.2.13=h166bdaf_4 145 | - zstd=1.5.2=ha4553b6_0 146 | - pip: 147 | - absl-py==2.1.0 148 | - aiohappyeyeballs==2.4.0 149 | - aiohttp==3.10.5 150 | - aiosignal==1.3.1 151 | - antlr4-python3-runtime==4.9.3 152 | - async-timeout==4.0.3 153 | - attrs==24.2.0 154 | - basemap==1.4.1 155 | - basemap-data==1.3.2 156 | - cachetools==5.5.0 157 | - contourpy==1.1.1 158 | - cycler==0.12.1 159 | - docker-pycreds==0.4.0 160 | - docstring-parser==0.16 161 | - filelock==3.16.1 162 | - fire==0.6.0 163 | - fonttools==4.53.1 164 | - frozenlist==1.4.1 165 | - fsspec==2024.9.0 166 | - gitdb==4.0.11 167 | - gitpython==3.1.43 168 | - google-auth==2.34.0 169 | - google-auth-oauthlib==1.0.0 170 | - grpcio==1.66.1 171 | - huggingface-hub==0.25.0 172 | - importlib-resources==6.4.5 173 | - jsonargparse==4.32.1 174 | - kiwisolver==1.4.7 175 | - lightning-lite==1.8.0 176 | - lightning-utilities==0.3.0 177 | - markdown==3.7 178 | - markdown-it-py==3.0.0 179 | - matplotlib==3.7.5 180 | - mdurl==0.1.2 181 | - multidict==6.1.0 182 | - nvidia-cublas-cu11==11.10.3.66 183 | - nvidia-cuda-nvrtc-cu11==11.7.99 184 | - nvidia-cuda-runtime-cu11==11.7.99 185 | - nvidia-cudnn-cu11==8.5.0.96 186 | - oauthlib==3.2.2 187 | - omegaconf==2.3.0 188 | - pandas==2.0.3 189 | - platformdirs==4.3.6 190 | - 
protobuf==5.28.1 191 | - pyasn1==0.6.1 192 | - pyasn1-modules==0.4.1 193 | - pygments==2.18.0 194 | - pyparsing==3.1.4 195 | - pyproj==3.5.0 196 | - pyshp==2.3.1 197 | - pytorch-lightning==1.8.0 198 | - requests-oauthlib==2.0.0 199 | - rich==13.8.1 200 | - rsa==4.9 201 | - safetensors==0.4.5 202 | - scipy==1.9.1 203 | - sentry-sdk==2.14.0 204 | - setproctitle==1.3.3 205 | - smmap==5.0.1 206 | - tensorboard==2.14.0 207 | - tensorboard-data-server==0.7.2 208 | - termcolor==2.4.0 209 | - timm==0.6.12 210 | - torch==1.12.1+cu116 211 | - torchdata==0.4.1 212 | - torchdiffeq==0.2.4 213 | - torchmetrics==0.11.4 214 | - torchvision==0.13.1+cu116 215 | - tqdm==4.66.5 216 | - typeshed-client==2.7.0 217 | - typing-extensions==4.12.2 218 | - tzdata==2024.1 219 | - wandb==0.18.1 220 | - werkzeug==3.0.4 221 | - yarl==1.11.1 222 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "setuptools-scm"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "weatherode" 7 | version = "0.0.1" 8 | authors =[ 9 | {name="Tung Nguyen", email="tungnd@g.ucla.edu"}, 10 | {name="Jayesh K. Gupta", email="mail@rejuvyesh.com"} 11 | ] 12 | description = "" 13 | readme = "README.md" 14 | requires-python = ">=3.8" 15 | classifiers = [ 16 | "Programming Language :: Python :: 3", 17 | "License :: OSI Approved :: MIT License", 18 | ] 19 | dependencies = [ 20 | 21 | ] 22 | 23 | [tool.setuptools.packages.find] 24 | where = ["src"] 25 | -------------------------------------------------------------------------------- /scripts/global/train12h.sh: -------------------------------------------------------------------------------- 1 | export NCCL_ALGO=Tree 2 | export OUTPUT_DIR=./Global_12h 3 | 4 | python src/weatherode/global_forecast/train.py --config configs/global_forecast_weatherode.yaml \ 5 | --trainer.strategy=ddp \ 6 | --trainer.devices=4 \ 7 | --trainer.max_epochs=50 \ 8 | --data.root_dir=/jupyter/weather_prediction/weather_prediction/5.625deg_npz \ 9 | --data.predict_range=12 \ 10 | --model.pretrained_path='' --data.out_variables=["geopotential_500","temperature_850",'2m_temperature','10m_u_component_of_wind','10m_v_component_of_wind'] \ 11 | --model.lr=5e-4 --model.ode_lr=1e-4 --model.beta_1="0.9" --model.beta_2="0.99" \ 12 | --model.weight_decay=1e-5 \ 13 | --model.net.init_args.time_steps=12 \ 14 | --model.net.depth=4 \ 15 | --model.net.err_type=3D \ 16 | --model.net.predict_list=[6,12] \ 17 | --model.net.gradient_loss=False \ 18 | --model.net.err_with_x=True \ 19 | --model.net.err_with_v=True \ 20 | --data.batch_size=8 \ 21 | --model.warmup_epochs=20000 \ 22 | --model.max_epochs=80000 \ 23 | -------------------------------------------------------------------------------- /scripts/global/train18h.sh: -------------------------------------------------------------------------------- 1 | export NCCL_ALGO=Tree 2 | export OUTPUT_DIR=./Global_18h 3 | 4 | python src/weatherode/global_forecast/train.py --config configs/global_forecast_weatherode.yaml \ 5 | --trainer.strategy=ddp \ 6 | --trainer.devices=4 \ 7 | --trainer.max_epochs=50 \ 8 | --data.root_dir=/jupyter/weather_prediction/weather_prediction/5.625deg_npz \ 9 | --data.predict_range=18 \ 10 | --model.pretrained_path='' --data.out_variables=["geopotential_500","temperature_850",'2m_temperature','10m_u_component_of_wind','10m_v_component_of_wind'] \ 11 | --model.lr=5e-4 
--model.ode_lr=1e-4 --model.beta_1="0.9" --model.beta_2="0.99" \ 12 | --model.weight_decay=1.2e-5 \ 13 | --model.net.init_args.time_steps=18 \ 14 | --model.net.depth=4 \ 15 | --model.net.err_type=3D \ 16 | --model.net.predict_list=[6,12,18] \ 17 | --model.net.gradient_loss=False \ 18 | --model.net.err_with_x=True \ 19 | --model.net.err_with_v=True \ 20 | --data.batch_size=8 \ 21 | --model.warmup_epochs=20000 \ 22 | --model.max_epochs=100000 \ 23 | -------------------------------------------------------------------------------- /scripts/global/train24h.sh: -------------------------------------------------------------------------------- 1 | export NCCL_ALGO=Tree 2 | export OUTPUT_DIR=./Global_24h 3 | 4 | python src/weatherode/global_forecast/train.py --config configs/global_forecast_weatherode.yaml \ 5 | --trainer.strategy=ddp \ 6 | --trainer.devices=4 \ 7 | --trainer.max_epochs=50 \ 8 | --data.root_dir=/jupyter/weather_prediction/weather_prediction/5.625deg_npz \ 9 | --data.predict_range=24 \ 10 | --model.pretrained_path='' --data.out_variables=["geopotential_500","temperature_850",'2m_temperature','10m_u_component_of_wind','10m_v_component_of_wind'] \ 11 | --model.lr=5e-4 --model.ode_lr=1e-4 --model.beta_1="0.9" --model.beta_2="0.99" \ 12 | --model.weight_decay=1.2e-5 \ 13 | --model.net.init_args.time_steps=24 \ 14 | --model.net.depth=4 \ 15 | --model.net.err_type=3D \ 16 | --model.net.predict_list=[6,12,18,24] \ 17 | --model.net.gradient_loss=False \ 18 | --model.net.err_with_x=True \ 19 | --model.net.err_with_v=True \ 20 | --data.batch_size=8 \ 21 | --model.warmup_epochs=20000 \ 22 | --model.max_epochs=100000 \ 23 | -------------------------------------------------------------------------------- /scripts/global/train6h.sh: -------------------------------------------------------------------------------- 1 | export NCCL_ALGO=Tree 2 | export OUTPUT_DIR=./Global_6h 3 | 4 | python src/weatherode/global_forecast/train.py --config configs/global_forecast_weatherode.yaml \ 5 | --trainer.strategy=ddp \ 6 | --trainer.devices=4 \ 7 | --trainer.max_epochs=50 \ 8 | --data.root_dir=/jupyter/weather_prediction/weather_prediction/5.625deg_npz \ 9 | --data.predict_range=6 \ 10 | --model.pretrained_path='' --data.out_variables=["geopotential_500","temperature_850",'2m_temperature','10m_u_component_of_wind','10m_v_component_of_wind'] \ 11 | --model.lr=5e-4 --model.ode_lr=1e-4 --model.beta_1="0.9" --model.beta_2="0.99" \ 12 | --model.weight_decay=1e-5 \ 13 | --model.net.init_args.time_steps=6 \ 14 | --model.net.depth=4 \ 15 | --model.net.err_type=3D \ 16 | --model.net.predict_list=[6] \ 17 | --model.net.gradient_loss=False \ 18 | --model.net.err_with_x=True \ 19 | --model.net.err_with_v=True \ 20 | --data.batch_size=8 \ 21 | --model.warmup_epochs=20000 \ 22 | --model.max_epochs=80000 \ 23 | -------------------------------------------------------------------------------- /scripts/regional/Australia/train12h.sh: -------------------------------------------------------------------------------- 1 | export NCCL_ALGO=Tree 2 | export OUTPUT_DIR=./Australia_12h 3 | export CUDA_VISIBLE_DEVICES=4,5,6,7 4 | 5 | python src/weatherode/regional_forecast/train.py --config configs/regional_forecast_weatherode.yaml \ 6 | --trainer.strategy=ddp \ 7 | --trainer.devices=4 \ 8 | --trainer.max_epochs=50 \ 9 | --trainer.logger.name=predict_12h_Australia \ 10 | --data.root_dir=/jupyter/weather_prediction/weather_prediction/5.625deg_npz \ 11 | --data.predict_range=12 \ 12 | --model.pretrained_path='' 
--data.out_variables=["geopotential_500","temperature_850",'2m_temperature','10m_u_component_of_wind','10m_v_component_of_wind'] \ 13 | --model.lr=5e-4 --model.beta_1="0.9" --model.beta_2="0.99" \ 14 | --model.weight_decay=1e-5 \ 15 | --model.ode_lr=1e-4 \ 16 | --model.net.time_steps=12 \ 17 | --model.net.patch_size=2 \ 18 | --model.net.img_size=[10,14] \ 19 | --model.net.depth=4 \ 20 | --model.net.err_type=3D \ 21 | --model.net.predict_list=[6,12] \ 22 | --model.net.gradient_loss=False \ 23 | --model.net.err_with_x=True \ 24 | --model.net.err_with_v=True \ 25 | --data.batch_size=8 \ 26 | --data.region=Australia \ 27 | --model.warmup_epochs=20000 \ 28 | --model.max_epochs=80000 \ 29 | -------------------------------------------------------------------------------- /scripts/regional/Australia/train18h.sh: -------------------------------------------------------------------------------- 1 | export NCCL_ALGO=Tree 2 | export OUTPUT_DIR=./Australia_18h 3 | export CUDA_VISIBLE_DEVICES=0,1,2,3 4 | 5 | python src/weatherode/regional_forecast/train.py --config configs/regional_forecast_weatherode.yaml \ 6 | --trainer.strategy=ddp \ 7 | --trainer.devices=4 \ 8 | --trainer.max_epochs=50 \ 9 | --trainer.logger.name=predict_18h_Australia \ 10 | --data.root_dir=/jupyter/weather_prediction/weather_prediction/5.625deg_npz \ 11 | --data.predict_range=18 \ 12 | --model.pretrained_path='' --data.out_variables=["geopotential_500","temperature_850",'2m_temperature','10m_u_component_of_wind','10m_v_component_of_wind'] \ 13 | --model.lr=5e-4 --model.beta_1="0.9" --model.beta_2="0.99" \ 14 | --model.weight_decay=1e-5 \ 15 | --model.ode_lr=1e-4 \ 16 | --model.net.time_steps=18 \ 17 | --model.net.patch_size=2 \ 18 | --model.net.img_size=[10,14] \ 19 | --model.net.depth=4 \ 20 | --model.net.err_type=3D \ 21 | --model.net.predict_list=[6,12,18] \ 22 | --model.net.gradient_loss=False \ 23 | --model.net.err_with_x=False \ 24 | --model.net.err_with_v=False \ 25 | --data.batch_size=8 \ 26 | --data.region=Australia \ 27 | --model.warmup_epochs=10000 \ 28 | --model.max_epochs=40000 \ 29 | -------------------------------------------------------------------------------- /scripts/regional/Australia/train24h.sh: -------------------------------------------------------------------------------- 1 | export NCCL_ALGO=Tree 2 | export OUTPUT_DIR=./Australia_24h 3 | export CUDA_VISIBLE_DEVICES=4,5,6,7 4 | 5 | python src/weatherode/regional_forecast/train.py --config configs/regional_forecast_weatherode.yaml \ 6 | --trainer.strategy=ddp \ 7 | --trainer.devices=4 \ 8 | --trainer.max_epochs=50 \ 9 | --trainer.logger.name=predict_24h_Australia \ 10 | --data.root_dir=/jupyter/weather_prediction/weather_prediction/5.625deg_npz \ 11 | --data.predict_range=24 \ 12 | --model.pretrained_path='' --data.out_variables=["geopotential_500","temperature_850",'2m_temperature','10m_u_component_of_wind','10m_v_component_of_wind'] \ 13 | --model.lr=5e-4 --model.beta_1="0.9" --model.beta_2="0.99" \ 14 | --model.weight_decay=1e-5 \ 15 | --model.ode_lr=1e-4 \ 16 | --model.net.time_steps=24 \ 17 | --model.net.patch_size=2 \ 18 | --model.net.img_size=[10,14] \ 19 | --model.net.depth=4 \ 20 | --model.net.err_type=3D \ 21 | --model.net.predict_list=[6,12,18,24] \ 22 | --model.net.gradient_loss=False \ 23 | --model.net.err_with_x=False \ 24 | --model.net.err_with_v=False \ 25 | --data.batch_size=8 \ 26 | --data.region=Australia \ 27 | --model.warmup_epochs=10000 \ 28 | --model.max_epochs=40000 \ 29 | 
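In the horizon-specific launch scripts of this repository, the forecast horizon enters through three coupled overrides (data.predict_range, model.net.time_steps, and model.net.predict_list), while the remaining flags are tuned per experiment. A small illustrative sketch of that coupling follows; the helper name is made up and is not repository code:

```python
def horizon_overrides(predict_range_hours: int, step_hours: int = 6) -> dict:
    """Hypothetical helper reproducing the flag pattern of the 6h/12h/18h/24h scripts."""
    return {
        "data.predict_range": predict_range_hours,
        # every provided script sets time_steps equal to the horizon in hours
        "model.net.time_steps": predict_range_hours,
        # every provided script sets predict_list to the 6-hour multiples up to the horizon
        "model.net.predict_list": list(range(step_hours, predict_range_hours + 1, step_hours)),
    }

print(horizon_overrides(24))
# {'data.predict_range': 24, 'model.net.time_steps': 24, 'model.net.predict_list': [6, 12, 18, 24]}
```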
-------------------------------------------------------------------------------- /scripts/regional/Australia/train6h.sh: -------------------------------------------------------------------------------- 1 | export NCCL_ALGO=Tree 2 | export OUTPUT_DIR=./Australia_6h 3 | export CUDA_VISIBLE_DEVICES=0,1,2,3 4 | 5 | python src/weatherode/regional_forecast/train.py --config configs/regional_forecast_weatherode.yaml \ 6 | --trainer.strategy=ddp \ 7 | --trainer.devices=4 \ 8 | --trainer.max_epochs=50 \ 9 | --trainer.logger.name=predict_6h_Australia \ 10 | --data.root_dir=/jupyter/weather_prediction/weather_prediction/5.625deg_npz \ 11 | --data.predict_range=6 \ 12 | --model.pretrained_path='' --data.out_variables=["geopotential_500","temperature_850",'2m_temperature','10m_u_component_of_wind','10m_v_component_of_wind'] \ 13 | --model.lr=5e-4 --model.beta_1="0.9" --model.beta_2="0.99" \ 14 | --model.weight_decay=1e-5 \ 15 | --model.ode_lr=1e-4 \ 16 | --model.net.time_steps=6 \ 17 | --model.net.patch_size=2 \ 18 | --model.net.img_size=[10,14] \ 19 | --model.net.depth=4 \ 20 | --model.net.err_type=3D \ 21 | --model.net.predict_list=[6] \ 22 | --model.net.gradient_loss=False \ 23 | --model.net.err_with_x=True \ 24 | --model.net.err_with_v=True \ 25 | --data.batch_size=8 \ 26 | --data.region=Australia \ 27 | --model.warmup_epochs=20000 \ 28 | --model.max_epochs=80000 \ 29 | -------------------------------------------------------------------------------- /scripts/regional/NorthAmerica/train12h.sh: -------------------------------------------------------------------------------- 1 | export NCCL_ALGO=Tree 2 | export OUTPUT_DIR=./NorthAmerican_12h 3 | export CUDA_VISIBLE_DEVICES=4,5,6,7 4 | 5 | python src/weatherode/regional_forecast/train.py --config configs/regional_forecast_weatherode.yaml \ 6 | --trainer.strategy=ddp \ 7 | --trainer.devices=4 \ 8 | --trainer.max_epochs=50 \ 9 | --trainer.logger.name=predict_12h_NorthAmerican \ 10 | --data.root_dir=/jupyter/weather_prediction/weather_prediction/5.625deg_npz \ 11 | --data.predict_range=12 \ 12 | --model.pretrained_path='' --data.out_variables=["geopotential_500","temperature_850",'2m_temperature','10m_u_component_of_wind','10m_v_component_of_wind'] \ 13 | --model.lr=5e-4 --model.beta_1="0.9" --model.beta_2="0.99" \ 14 | --model.weight_decay=1e-5 \ 15 | --model.ode_lr=1e-4 \ 16 | --model.net.time_steps=12 \ 17 | --model.net.patch_size=2 \ 18 | --model.net.img_size=[8,14] \ 19 | --model.net.depth=4 \ 20 | --model.net.err_type=3D \ 21 | --model.net.predict_list=[6,12] \ 22 | --model.net.gradient_loss=False \ 23 | --model.net.err_with_x=True \ 24 | --model.net.err_with_v=True \ 25 | --data.batch_size=8 \ 26 | --model.warmup_epochs=20000 \ 27 | --model.max_epochs=80000 \ 28 | -------------------------------------------------------------------------------- /scripts/regional/NorthAmerica/train18h.sh: -------------------------------------------------------------------------------- 1 | export NCCL_ALGO=Tree 2 | export OUTPUT_DIR=./NorthAmerican_18h 3 | export CUDA_VISIBLE_DEVICES=0,1,2,3 4 | 5 | python src/weatherode/regional_forecast/train.py --config configs/regional_forecast_weatherode.yaml \ 6 | --trainer.strategy=ddp \ 7 | --trainer.devices=4 \ 8 | --trainer.max_epochs=50 \ 9 | --trainer.logger.name=predict_18h_NorthAmerican \ 10 | --data.root_dir=/jupyter/weather_prediction/weather_prediction/5.625deg_npz \ 11 | --data.predict_range=18 \ 12 | --model.pretrained_path='' 
--data.out_variables=["geopotential_500","temperature_850",'2m_temperature','10m_u_component_of_wind','10m_v_component_of_wind'] \ 13 | --model.lr=5e-4 --model.beta_1="0.9" --model.beta_2="0.99" \ 14 | --model.weight_decay=1e-5 \ 15 | --model.ode_lr=1e-4 \ 16 | --model.net.time_steps=18 \ 17 | --model.net.img_size=[8,14] \ 18 | --model.net.depth=4 \ 19 | --model.net.err_type=3D \ 20 | --model.net.predict_list=[6,12,18] \ 21 | --model.net.gradient_loss=False \ 22 | --model.net.err_with_x=True \ 23 | --model.net.err_with_v=True \ 24 | --data.batch_size=8 \ 25 | --model.warmup_epochs=20000 \ 26 | --model.max_epochs=80000 \ 27 | -------------------------------------------------------------------------------- /scripts/regional/NorthAmerica/train24h.sh: -------------------------------------------------------------------------------- 1 | export NCCL_ALGO=Tree 2 | export OUTPUT_DIR=./NorthAmerican_24h 3 | export CUDA_VISIBLE_DEVICES=4,5,6,7 4 | 5 | python src/weatherode/regional_forecast/train.py --config configs/regional_forecast_weatherode.yaml \ 6 | --trainer.strategy=ddp \ 7 | --trainer.devices=4 \ 8 | --trainer.max_epochs=50 \ 9 | --trainer.logger.name=predict_24h_NorthAmerican \ 10 | --data.root_dir=/jupyter/weather_prediction/weather_prediction/5.625deg_npz \ 11 | --data.predict_range=24 \ 12 | --model.pretrained_path='' --data.out_variables=["geopotential_500","temperature_850",'2m_temperature','10m_u_component_of_wind','10m_v_component_of_wind'] \ 13 | --model.lr=5e-4 --model.beta_1="0.9" --model.beta_2="0.99" \ 14 | --model.weight_decay=1e-5 \ 15 | --model.ode_lr=1e-4 \ 16 | --model.net.time_steps=24 \ 17 | --model.net.img_size=[8,14] \ 18 | --model.net.depth=4 \ 19 | --model.net.err_type=3D \ 20 | --model.net.predict_list=[6,12,18,24] \ 21 | --model.net.gradient_loss=False \ 22 | --model.net.err_with_x=True \ 23 | --model.net.err_with_v=True \ 24 | --data.batch_size=8 \ 25 | --model.warmup_epochs=20000 \ 26 | --model.max_epochs=80000 \ 27 | -------------------------------------------------------------------------------- /scripts/regional/NorthAmerica/train6h.sh: -------------------------------------------------------------------------------- 1 | export NCCL_ALGO=Tree 2 | export OUTPUT_DIR=./NorthAmerican_6h 3 | export CUDA_VISIBLE_DEVICES=0,1,2,3 4 | 5 | python src/weatherode/regional_forecast/train.py --config configs/regional_forecast_weatherode.yaml \ 6 | --trainer.strategy=ddp \ 7 | --trainer.devices=4 \ 8 | --trainer.max_epochs=50 \ 9 | --trainer.logger.name=predict_6h_NorthAmerican \ 10 | --data.root_dir=/jupyter/weather_prediction/weather_prediction/5.625deg_npz \ 11 | --data.predict_range=6 \ 12 | --model.pretrained_path='' --data.out_variables=["geopotential_500","temperature_850",'2m_temperature','10m_u_component_of_wind','10m_v_component_of_wind'] \ 13 | --model.lr=5e-4 --model.beta_1="0.9" --model.beta_2="0.99" \ 14 | --model.weight_decay=1e-5 \ 15 | --model.ode_lr=1e-4 \ 16 | --model.net.time_steps=6 \ 17 | --model.net.patch_size=2 \ 18 | --model.net.img_size=[8,14] \ 19 | --model.net.depth=4 \ 20 | --model.net.err_type=3D \ 21 | --model.net.predict_list=[6] \ 22 | --model.net.gradient_loss=False \ 23 | --model.net.err_with_x=True \ 24 | --model.net.err_with_v=True \ 25 | --data.batch_size=8 \ 26 | --model.warmup_epochs=20000 \ 27 | --model.max_epochs=80000 \ 28 | -------------------------------------------------------------------------------- /scripts/regional/SouthAmerica/train12h.sh: -------------------------------------------------------------------------------- 1 | 
export NCCL_ALGO=Tree 2 | export OUTPUT_DIR=./SouthAmerican_12h 3 | export CUDA_VISIBLE_DEVICES=4,5,6,7 4 | 5 | python src/weatherode/regional_forecast/train.py --config configs/regional_forecast_weatherode.yaml \ 6 | --trainer.strategy=ddp \ 7 | --trainer.devices=4 \ 8 | --trainer.max_epochs=50 \ 9 | --trainer.logger.name=predict_12h_SouthAmerican \ 10 | --data.root_dir=/jupyter/weather_prediction/weather_prediction/5.625deg_npz \ 11 | --data.predict_range=12 \ 12 | --model.pretrained_path='' --data.out_variables=["geopotential_500","temperature_850",'2m_temperature','10m_u_component_of_wind','10m_v_component_of_wind'] \ 13 | --model.lr=5e-4 --model.beta_1="0.9" --model.beta_2="0.99" \ 14 | --model.weight_decay=1e-5 \ 15 | --model.ode_lr=1e-4 \ 16 | --model.net.time_steps=12 \ 17 | --model.net.patch_size=2 \ 18 | --model.net.img_size=[14,10] \ 19 | --model.net.depth=4 \ 20 | --model.net.err_type=3D \ 21 | --model.net.predict_list=[6,12] \ 22 | --model.net.gradient_loss=False \ 23 | --model.net.err_with_x=True \ 24 | --model.net.err_with_v=True \ 25 | --data.batch_size=8 \ 26 | --data.region=SouthAmerica \ 27 | --model.warmup_epochs=20000 \ 28 | --model.max_epochs=80000 29 | -------------------------------------------------------------------------------- /scripts/regional/SouthAmerica/train18h.sh: -------------------------------------------------------------------------------- 1 | export NCCL_ALGO=Tree 2 | export OUTPUT_DIR=./SouthAmerican_18h 3 | export CUDA_VISIBLE_DEVICES=0,1,2,3 4 | 5 | python src/weatherode/regional_forecast/train.py --config configs/regional_forecast_weatherode.yaml \ 6 | --trainer.strategy=ddp \ 7 | --trainer.devices=4 \ 8 | --trainer.max_epochs=50 \ 9 | --trainer.logger.name=predict_18h_SouthAmerican \ 10 | --data.root_dir=/jupyter/weather_prediction/weather_prediction/5.625deg_npz \ 11 | --data.predict_range=18 \ 12 | --model.pretrained_path='' --data.out_variables=["geopotential_500","temperature_850",'2m_temperature','10m_u_component_of_wind','10m_v_component_of_wind'] \ 13 | --model.lr=5e-4 --model.beta_1="0.9" --model.beta_2="0.99" \ 14 | --model.weight_decay=1e-5 \ 15 | --model.ode_lr=1e-4 \ 16 | --model.net.time_steps=18 \ 17 | --model.net.patch_size=2 \ 18 | --model.net.img_size=[14,10] \ 19 | --model.net.depth=4 \ 20 | --model.net.err_type=3D \ 21 | --model.net.predict_list=[6,12,18] \ 22 | --model.net.gradient_loss=False \ 23 | --model.net.err_with_x=False \ 24 | --model.net.err_with_v=False \ 25 | --data.batch_size=8 \ 26 | --data.region=SouthAmerica \ 27 | --model.warmup_epochs=10000 \ 28 | --model.max_epochs=40000 29 | -------------------------------------------------------------------------------- /scripts/regional/SouthAmerica/train24h.sh: -------------------------------------------------------------------------------- 1 | export NCCL_ALGO=Tree 2 | export OUTPUT_DIR=./SouthAmerican_24h 3 | export CUDA_VISIBLE_DEVICES=4,5,6,7 4 | 5 | python src/weatherode/regional_forecast/train.py --config configs/regional_forecast_weatherode.yaml \ 6 | --trainer.strategy=ddp \ 7 | --trainer.devices=4 \ 8 | --trainer.max_epochs=50 \ 9 | --trainer.logger.name=predict_24h_SouthAmerican \ 10 | --data.root_dir=/jupyter/weather_prediction/weather_prediction/5.625deg_npz \ 11 | --data.predict_range=24 \ 12 | --model.pretrained_path='' --data.out_variables=["geopotential_500","temperature_850",'2m_temperature','10m_u_component_of_wind','10m_v_component_of_wind'] \ 13 | --model.lr=5e-4 --model.beta_1="0.9" --model.beta_2="0.99" \ 14 | --model.weight_decay=1e-5 \ 15 | 
--model.ode_lr=1e-4 \ 16 | --model.net.time_steps=24 \ 17 | --model.net.patch_size=2 \ 18 | --model.net.img_size=[14,10] \ 19 | --model.net.depth=4 \ 20 | --model.net.err_type=3D \ 21 | --model.net.predict_list=[6,12,18,24] \ 22 | --model.net.gradient_loss=False \ 23 | --model.net.err_with_x=False \ 24 | --model.net.err_with_v=False \ 25 | --data.batch_size=8 \ 26 | --data.region=SouthAmerica \ 27 | --model.warmup_epochs=10000 \ 28 | --model.max_epochs=40000 \ 29 | -------------------------------------------------------------------------------- /scripts/regional/SouthAmerica/train6h.sh: -------------------------------------------------------------------------------- 1 | export NCCL_ALGO=Tree 2 | export OUTPUT_DIR=./SouthAmerican_6h 3 | export CUDA_VISIBLE_DEVICES=0,1,2,3 4 | 5 | python src/weatherode/regional_forecast/train.py --config configs/regional_forecast_weatherode.yaml \ 6 | --trainer.strategy=ddp \ 7 | --trainer.devices=4 \ 8 | --trainer.max_epochs=50 \ 9 | --trainer.logger.name=predict_6h_SouthAmerican \ 10 | --data.root_dir=/jupyter/weather_prediction/weather_prediction/5.625deg_npz \ 11 | --data.predict_range=6 \ 12 | --model.pretrained_path='' --data.out_variables=["geopotential_500","temperature_850",'2m_temperature','10m_u_component_of_wind','10m_v_component_of_wind'] \ 13 | --model.lr=5e-4 --model.beta_1="0.9" --model.beta_2="0.99" \ 14 | --model.weight_decay=1e-5 \ 15 | --model.ode_lr=1e-4 \ 16 | --model.net.time_steps=6 \ 17 | --model.net.patch_size=2 \ 18 | --model.net.img_size=[14,10] \ 19 | --model.net.depth=4 \ 20 | --model.net.err_type=3D \ 21 | --model.net.predict_list=[6] \ 22 | --model.net.gradient_loss=False \ 23 | --model.net.err_with_x=True \ 24 | --model.net.err_with_v=True \ 25 | --data.batch_size=8 \ 26 | --data.region=SouthAmerica \ 27 | --model.warmup_epochs=20000 \ 28 | --model.max_epochs=80000 -------------------------------------------------------------------------------- /src/data_preprocessing/nc2np_equally_era5.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 
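# This script converts the yearly ERA5 NetCDF files (one sub-directory per variable
# under --root_dir) into sharded NumPy archives. For each year it loads every requested
# variable, keeps the first HOURS_PER_YEAR (8760) hourly steps, selects the pressure
# levels listed in DEFAULT_PRESSURE_LEVELS for multi-level variables, and writes the
# result to <save_dir>/<partition>/<year>_<shard_id>.npz. For the training partition it
# additionally aggregates per-variable normalization statistics (normalize_mean.npz,
# normalize_std.npz) and a climatology (train/climatology.npz); main() also stores
# lat.npy and lon.npy next to the partition folders.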
3 | 4 | import glob 5 | import os 6 | 7 | import click 8 | import numpy as np 9 | import xarray as xr 10 | from tqdm import tqdm 11 | 12 | from weatherode.utils.data_utils import DEFAULT_PRESSURE_LEVELS, NAME_TO_VAR 13 | 14 | HOURS_PER_YEAR = 8760 # 365-day year 15 | 16 | 17 | def nc2np(path, variables, years, save_dir, partition, num_shards_per_year): 18 | os.makedirs(os.path.join(save_dir, partition), exist_ok=True) 19 | 20 | if partition == "train": 21 | normalize_mean = {} 22 | normalize_std = {} 23 | climatology = {} 24 | 25 | constants = xr.open_mfdataset(os.path.join(path, "constants.nc"), combine="by_coords", parallel=True) 26 | constant_fields = ["land_sea_mask", "orography", "lattitude"] 27 | constant_values = {} 28 | for f in constant_fields: 29 | constant_values[f] = np.expand_dims(constants[NAME_TO_VAR[f]].to_numpy(), axis=(0, 1)).repeat( 30 | HOURS_PER_YEAR, axis=0 31 | ) 32 | if partition == "train": 33 | normalize_mean[f] = constant_values[f].mean(axis=(0, 2, 3)) 34 | normalize_std[f] = constant_values[f].std(axis=(0, 2, 3)) 35 | 36 | for year in tqdm(years): 37 | np_vars = {} 38 | 39 | # constant variables 40 | for f in constant_fields: 41 | np_vars[f] = constant_values[f] 42 | 43 | # non-constant fields 44 | for var in variables: 45 | ps = glob.glob(os.path.join(path, var, f"*{year}*.nc")) 46 | ds = xr.open_mfdataset(ps, combine="by_coords", parallel=True) # dataset for a single variable 47 | code = NAME_TO_VAR[var] 48 | 49 | if len(ds[code].shape) == 3: # surface level variables 50 | ds[code] = ds[code].expand_dims("val", axis=1) 51 | # remove the last 24 hours if this year has 366 days 52 | np_vars[var] = ds[code].to_numpy()[:HOURS_PER_YEAR] 53 | 54 | if partition == "train": # compute mean and std of each var in each year 55 | var_mean_yearly = np_vars[var].mean(axis=(0, 2, 3)) 56 | var_std_yearly = np_vars[var].std(axis=(0, 2, 3)) 57 | if var not in normalize_mean: 58 | normalize_mean[var] = [var_mean_yearly] 59 | normalize_std[var] = [var_std_yearly] 60 | else: 61 | normalize_mean[var].append(var_mean_yearly) 62 | normalize_std[var].append(var_std_yearly) 63 | 64 | clim_yearly = np_vars[var].mean(axis=0) 65 | if var not in climatology: 66 | climatology[var] = [clim_yearly] 67 | else: 68 | climatology[var].append(clim_yearly) 69 | 70 | else: # multiple-level variables, only use a subset 71 | assert len(ds[code].shape) == 4 72 | all_levels = ds["level"][:].to_numpy() 73 | all_levels = np.intersect1d(all_levels, DEFAULT_PRESSURE_LEVELS) 74 | for level in all_levels: 75 | ds_level = ds.sel(level=[level]) 76 | level = int(level) 77 | # remove the last 24 hours if this year has 366 days 78 | np_vars[f"{var}_{level}"] = ds_level[code].to_numpy()[:HOURS_PER_YEAR] 79 | 80 | if partition == "train": # compute mean and std of each var in each year 81 | var_mean_yearly = np_vars[f"{var}_{level}"].mean(axis=(0, 2, 3)) 82 | var_std_yearly = np_vars[f"{var}_{level}"].std(axis=(0, 2, 3)) 83 | if var not in normalize_mean: 84 | normalize_mean[f"{var}_{level}"] = [var_mean_yearly] 85 | normalize_std[f"{var}_{level}"] = [var_std_yearly] 86 | else: 87 | normalize_mean[f"{var}_{level}"].append(var_mean_yearly) 88 | normalize_std[f"{var}_{level}"].append(var_std_yearly) 89 | 90 | clim_yearly = np_vars[f"{var}_{level}"].mean(axis=0) 91 | if f"{var}_{level}" not in climatology: 92 | climatology[f"{var}_{level}"] = [clim_yearly] 93 | else: 94 | climatology[f"{var}_{level}"].append(clim_yearly) 95 | 96 | assert HOURS_PER_YEAR % num_shards_per_year == 0 97 | num_hrs_per_shard = HOURS_PER_YEAR 
// num_shards_per_year 98 | for shard_id in range(num_shards_per_year): 99 | start_id = shard_id * num_hrs_per_shard 100 | end_id = start_id + num_hrs_per_shard 101 | sharded_data = {k: np_vars[k][start_id:end_id] for k in np_vars.keys()} 102 | np.savez( 103 | os.path.join(save_dir, partition, f"{year}_{shard_id}.npz"), 104 | **sharded_data, 105 | ) 106 | 107 | if partition == "train": 108 | for var in normalize_mean.keys(): 109 | if var not in constant_fields: 110 | normalize_mean[var] = np.stack(normalize_mean[var], axis=0) 111 | normalize_std[var] = np.stack(normalize_std[var], axis=0) 112 | 113 | for var in normalize_mean.keys(): # aggregate over the years 114 | if var not in constant_fields: 115 | mean, std = normalize_mean[var], normalize_std[var] 116 | # var(X) = E[var(X|Y)] + var(E[X|Y]) 117 | variance = (std**2).mean(axis=0) + (mean**2).mean(axis=0) - mean.mean(axis=0) ** 2 118 | std = np.sqrt(variance) 119 | # E[X] = E[E[X|Y]] 120 | mean = mean.mean(axis=0) 121 | normalize_mean[var] = mean 122 | normalize_std[var] = std 123 | 124 | np.savez(os.path.join(save_dir, "normalize_mean.npz"), **normalize_mean) 125 | np.savez(os.path.join(save_dir, "normalize_std.npz"), **normalize_std) 126 | 127 | for var in climatology.keys(): 128 | climatology[var] = np.stack(climatology[var], axis=0) 129 | climatology = {k: np.mean(v, axis=0) for k, v in climatology.items()} 130 | np.savez( 131 | os.path.join(save_dir, partition, "climatology.npz"), 132 | **climatology, 133 | ) 134 | 135 | 136 | @click.command() 137 | @click.option("--root_dir", type=click.Path(exists=True)) 138 | @click.option("--save_dir", type=str) 139 | @click.option( 140 | "--variables", 141 | "-v", 142 | type=click.STRING, 143 | multiple=True, 144 | default=[ 145 | "2m_temperature", 146 | "10m_u_component_of_wind", 147 | "10m_v_component_of_wind", 148 | "toa_incident_solar_radiation", 149 | "total_precipitation", 150 | "geopotential", 151 | "u_component_of_wind", 152 | "v_component_of_wind", 153 | "temperature", 154 | "relative_humidity", 155 | "specific_humidity", 156 | ], 157 | ) 158 | @click.option("--start_train_year", type=int, default=1979) 159 | @click.option("--start_val_year", type=int, default=2016) 160 | @click.option("--start_test_year", type=int, default=2017) 161 | @click.option("--end_year", type=int, default=2019) 162 | @click.option("--num_shards", type=int, default=8) 163 | def main( 164 | root_dir, 165 | save_dir, 166 | variables, 167 | start_train_year, 168 | start_val_year, 169 | start_test_year, 170 | end_year, 171 | num_shards, 172 | ): 173 | assert start_val_year > start_train_year and start_test_year > start_val_year and end_year > start_test_year 174 | train_years = range(start_train_year, start_val_year) 175 | val_years = range(start_val_year, start_test_year) 176 | test_years = range(start_test_year, end_year) 177 | 178 | os.makedirs(save_dir, exist_ok=True) 179 | 180 | nc2np(root_dir, variables, train_years, save_dir, "train", num_shards) 181 | nc2np(root_dir, variables, val_years, save_dir, "val", num_shards) 182 | nc2np(root_dir, variables, test_years, save_dir, "test", num_shards) 183 | 184 | # save lat and lon data 185 | ps = glob.glob(os.path.join(root_dir, variables[0], f"*{train_years[0]}*.nc")) 186 | x = xr.open_mfdataset(ps[0], parallel=True) 187 | lat = x["lat"].to_numpy() 188 | lon = x["lon"].to_numpy() 189 | np.save(os.path.join(save_dir, "lat.npy"), lat) 190 | np.save(os.path.join(save_dir, "lon.npy"), lon) 191 | 192 | 193 | if __name__ == "__main__": 194 | main() 195 | 
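The script above writes plain NumPy archives, so its output can be inspected without the training pipeline. Below is a minimal sketch (not part of the repository) of loading one shard and applying the stored normalization statistics; the paths and the example year follow the preprocessing command in the README, and the variable key is one of the surface fields listed in the configs:

```python
import numpy as np

root = "/mnt/data/5.625deg_npz"               # --save_dir from the README example
shard = np.load(f"{root}/train/1979_0.npz")   # shards are written as <year>_<shard_id>.npz
mean = np.load(f"{root}/normalize_mean.npz")  # per-variable statistics computed on the training split
std = np.load(f"{root}/normalize_std.npz")
lat = np.load(f"{root}/lat.npy")
lon = np.load(f"{root}/lon.npy")

var = "2m_temperature"
x = shard[var]                                # shape (hours_per_shard, 1, n_lat, n_lon)
x_norm = (x - mean[var]) / std[var]           # z-score normalization with the stored statistics
print(var, x.shape, len(lat), len(lon), float(x_norm.mean()), float(x_norm.std()))
```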
-------------------------------------------------------------------------------- /src/data_preprocessing/regrid.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import xarray as xr 3 | import numpy as np 4 | import xesmf as xe 5 | from glob import glob 6 | import os 7 | 8 | def regrid( 9 | ds_in, 10 | ddeg_out, 11 | method='bilinear', 12 | reuse_weights=True, 13 | cmip=False, 14 | rename=None 15 | ): 16 | """ 17 | Regrid horizontally. 18 | :param ds_in: Input xarray dataset 19 | :param ddeg_out: Output resolution 20 | :param method: Regridding method 21 | :param reuse_weights: Reuse weights for regridding 22 | :return: ds_out: Regridded dataset 23 | """ 24 | # import pdb; pdb.set_trace() 25 | # Rename to ESMF compatible coordinates 26 | if 'latitude' in ds_in.coords: 27 | ds_in = ds_in.rename({'latitude': 'lat', 'longitude': 'lon'}) 28 | if cmip: 29 | ds_in = ds_in.drop(('lat_bnds', 'lon_bnds')) 30 | if hasattr(ds_in, 'plev_bnds'): 31 | ds_in = ds_in.drop(('plev_bnds')) 32 | if hasattr(ds_in, 'time_bnds'): 33 | ds_in = ds_in.drop(('time_bnds')) 34 | if rename is not None: 35 | ds_in = ds_in.rename({rename[0]: rename[1]}) 36 | 37 | # Create output grid 38 | grid_out = xr.Dataset( 39 | { 40 | 'lat': (['lat'], np.arange(-90+ddeg_out/2, 90, ddeg_out)), 41 | 'lon': (['lon'], np.arange(0, 360, ddeg_out)), 42 | } 43 | ) 44 | 45 | # Create regridder 46 | regridder = xe.Regridder( 47 | ds_in, grid_out, method, periodic=True, reuse_weights=reuse_weights 48 | ) 49 | 50 | ds_out = regridder(ds_in, keep_attrs=True).astype('float32') 51 | 52 | # # Set attributes since they get lost during regridding 53 | # for var in ds_out: 54 | # ds_out[var].attrs = ds_in[var].attrs 55 | # ds_out.attrs.update(ds_in.attrs) 56 | 57 | if rename is not None: 58 | if rename[0] == 'zg': 59 | ds_out['z'] *= 9.807 60 | if rename[0] == 'rsdt': 61 | ds_out['tisr'] *= 60*60 62 | ds_out = ds_out.isel(time=slice(1, None, 12)) 63 | ds_out = ds_out.assign_coords({'time': ds_out.time + np.timedelta64(90, 'm')}) 64 | 65 | # # Regrid dataset 66 | # ds_out = regridder(ds_in) 67 | return ds_out 68 | 69 | 70 | def main( 71 | input_fns, 72 | output_dir, 73 | ddeg_out, 74 | method='bilinear', 75 | reuse_weights=True, 76 | custom_fn=None, 77 | file_ending='nc', 78 | cmip=False, 79 | rename=None 80 | ): 81 | """ 82 | :param input_fns: Input files. Can use *. If more than one, loop over them 83 | :param output_dir: Output directory 84 | :param ddeg_out: Output resolution 85 | :param method: Regridding method 86 | :param reuse_weights: Reuse weights for regridding 87 | :param custom_fn: If not None, use custom file name. Otherwise infer from parameters. 88 | :param file_ending: Default = nc 89 | """ 90 | 91 | # Make sure output directory exists 92 | os.makedirs(output_dir, exist_ok=True) 93 | # Get files for starred expressions 94 | if '*' in input_fns[0]: 95 | input_fns = sorted(glob(input_fns[0])) 96 | # Loop over input files 97 | for fn in input_fns: 98 | print(f'Regridding file: {fn}') 99 | ds_in = xr.open_dataset(fn) 100 | ds_out = regrid(ds_in, ddeg_out, method, reuse_weights, cmip, rename) 101 | fn_out = ( 102 | custom_fn or 103 | '_'.join(fn.split('/')[-1][:-3].split('_')[:-1]) + '_' + str(ddeg_out) + 'deg.' 
+ file_ending 104 | ) 105 | print(f"Saving file: {output_dir + '/' + fn_out}") 106 | ds_out.to_netcdf(output_dir + '/' + fn_out) 107 | ds_in.close(); ds_out.close() 108 | 109 | if __name__ == '__main__': 110 | 111 | parser = argparse.ArgumentParser() 112 | parser.add_argument( 113 | '--input_fns', 114 | type=str, 115 | nargs='+', 116 | help="Input files (full path). Can use *. If more than one, loop over them", 117 | required=True 118 | ) 119 | parser.add_argument( 120 | '--output_dir', 121 | type=str, 122 | help="Output directory", 123 | required=True 124 | ) 125 | parser.add_argument( 126 | '--ddeg_out', 127 | type=float, 128 | help="Output resolution", 129 | required=True 130 | ) 131 | parser.add_argument( 132 | '--reuse_weights', 133 | type=int, 134 | help="Reuse weights for regridding. 0 or 1 (default)", 135 | # default=1, 136 | default=0 137 | ) 138 | parser.add_argument( 139 | '--custom_fn', 140 | type=str, 141 | help="If not None, use custom file name. Otherwise infer from parameters.", 142 | default=None 143 | ) 144 | parser.add_argument( 145 | '--file_ending', 146 | type=str, 147 | help="File ending. Default = nc", 148 | default='nc' 149 | ) 150 | parser.add_argument( 151 | '--cmip', 152 | type=int, 153 | help="Is CMIP data. 0 or 1 (default)", 154 | default=0 155 | ) 156 | parser.add_argument( 157 | '--rename', 158 | type=str, 159 | nargs='+', 160 | help="Rename var in dataset", 161 | default=None 162 | ) 163 | args = parser.parse_args() 164 | 165 | main( 166 | input_fns=args.input_fns, 167 | output_dir=args.output_dir, 168 | ddeg_out=args.ddeg_out, 169 | reuse_weights=args.reuse_weights, 170 | custom_fn=args.custom_fn, 171 | file_ending=args.file_ending, 172 | cmip=args.cmip, 173 | rename=args.rename 174 | ) -------------------------------------------------------------------------------- /src/data_preprocessing/regrid_climatebench.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | 4 | import click 5 | import xarray as xr 6 | import numpy as np 7 | import xesmf as xe 8 | 9 | def regrid( 10 | ds_in, 11 | ddeg_out, 12 | method='bilinear', 13 | reuse_weights=True, 14 | cmip=False, 15 | rename=None 16 | ): 17 | """ 18 | Regrid horizontally. 
19 | :param ds_in: Input xarray dataset 20 | :param ddeg_out: Output resolution 21 | :param method: Regridding method 22 | :param reuse_weights: Reuse weights for regridding 23 | :return: ds_out: Regridded dataset 24 | """ 25 | # import pdb; pdb.set_trace() 26 | # Rename to ESMF compatible coordinates 27 | if 'latitude' in ds_in.coords: 28 | ds_in = ds_in.rename({'latitude': 'lat', 'longitude': 'lon'}) 29 | if cmip: 30 | ds_in = ds_in.drop(('lat_bnds', 'lon_bnds')) 31 | if hasattr(ds_in, 'plev_bnds'): 32 | ds_in = ds_in.drop(('plev_bnds')) 33 | if hasattr(ds_in, 'time_bnds'): 34 | ds_in = ds_in.drop(('time_bnds')) 35 | if rename is not None: 36 | ds_in = ds_in.rename({rename[0]: rename[1]}) 37 | 38 | # Create output grid 39 | grid_out = xr.Dataset( 40 | { 41 | 'lat': (['lat'], np.arange(-90+ddeg_out/2, 90, ddeg_out)), 42 | 'lon': (['lon'], np.arange(0, 360, ddeg_out)), 43 | } 44 | ) 45 | 46 | # Create regridder 47 | regridder = xe.Regridder( 48 | ds_in, grid_out, method, periodic=True, reuse_weights=reuse_weights 49 | ) 50 | 51 | # Hack to speed up regridding of large files 52 | ds_out = regridder(ds_in, keep_attrs=True).astype('float32') 53 | 54 | if rename is not None: 55 | if rename[0] == 'zg': 56 | ds_out['z'] *= 9.807 57 | if rename[0] == 'rsdt': 58 | ds_out['tisr'] *= 60*60 59 | ds_out = ds_out.isel(time=slice(1, None, 12)) 60 | ds_out = ds_out.assign_coords({'time': ds_out.time + np.timedelta64(90, 'm')}) 61 | 62 | # # Regrid dataset 63 | # ds_out = regridder(ds_in) 64 | return ds_out 65 | 66 | @click.command() 67 | @click.argument("path", type=click.Path(exists=True)) 68 | @click.option("--save_path", type=str) 69 | @click.option("--ddeg_out", type=float, default=5.625) 70 | def main( 71 | path, 72 | save_path, 73 | ddeg_out 74 | ): 75 | if not os.path.exists(save_path): 76 | os.makedirs(save_path, exist_ok=True) 77 | 78 | list_simu = ['hist-GHG.nc', 'hist-aer.nc', 'historical.nc', 'ssp126.nc', 'ssp370.nc', 'ssp585.nc', 'ssp245.nc'] 79 | ps = glob(os.path.join(path, f"*.nc")) 80 | ps_ = [] 81 | for p in ps: 82 | for simu in list_simu: 83 | if simu in p: 84 | ps_.append(p) 85 | ps = ps_ 86 | 87 | constant_vars = ['CO2', 'CH4'] 88 | for p in ps: 89 | x = xr.open_dataset(p) 90 | if 'input' in p: 91 | for v in constant_vars: 92 | x[v] = x[v].expand_dims(dim={'latitude': 96, 'longitude': 144}, axis=(1,2)) 93 | x_regridded = regrid(x, ddeg_out, reuse_weights=False) 94 | x_regridded.to_netcdf(os.path.join(save_path, os.path.basename(p))) 95 | 96 | if __name__ == "__main__": 97 | main() -------------------------------------------------------------------------------- /src/weatherode/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 
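Before moving on to the model code, here is a hedged usage sketch for the two regridding utilities above (all paths and the 5.625-degree target resolution are illustrative placeholders). `regrid.py` is argparse-based and loops over globbed NetCDF files, while `regrid_climatebench.py` takes the data directory as a positional argument:

```bash
# Generic WeatherBench-style regridding to a coarser grid
python src/data_preprocessing/regrid.py \
    --input_fns "/path/to/raw/geopotential/*.nc" \
    --output_dir /path/to/5.625deg/geopotential \
    --ddeg_out 5.625 \
    --reuse_weights 0

# ClimateBench files, regridded to the same resolution
python src/data_preprocessing/regrid_climatebench.py /path/to/climatebench \
    --save_path /path/to/climatebench_5.625deg \
    --ddeg_out 5.625
```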
3 | 
-------------------------------------------------------------------------------- /src/weatherode/c3d.py: --------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.distributed as dist
4 | 
5 | class Conv2Plus1D(nn.Sequential):
6 |     def __init__(self, in_planes: int, out_planes: int, midplanes: int, stride: int = 1, padding: int = 1) -> None:
7 |         super().__init__(
8 |             nn.Conv3d(
9 |                 in_planes,
10 |                 midplanes,
11 |                 kernel_size=(1, 3, 3),
12 |                 stride=(1, stride, stride),
13 |                 padding=(0, padding, padding),
14 |                 bias=False,
15 |             ),
16 |             nn.BatchNorm3d(midplanes),
17 |             nn.ReLU(inplace=True),
18 |             nn.Conv3d(
19 |                 midplanes, out_planes, kernel_size=(3, 1, 1), stride=(stride, 1, 1), padding=(padding, 0, 0), bias=False
20 |             ),
21 |             nn.BatchNorm3d(out_planes),
22 |             nn.ReLU(inplace=True),
23 |         )
24 | 
25 | class ResidualBlock2Plus1D(nn.Module):
26 |     def __init__(self, in_channels, out_channels, norm=False, n_groups=1):
27 |         super().__init__()
28 |         mid_channels = (in_channels * out_channels * 3 * 3 * 3) // (in_channels * 3 * 3 + 3 * out_channels)
29 |         self.conv1 = Conv2Plus1D(in_channels, out_channels, mid_channels)
30 |         self.conv2 = Conv2Plus1D(out_channels, out_channels, mid_channels)
31 |         self.bn1 = nn.BatchNorm3d(out_channels)
32 |         self.bn2 = nn.BatchNorm3d(out_channels)
33 |         self.activation = nn.LeakyReLU(0.3)
34 |         # self.attention = nn.MultiheadAttention(out_channels, num_heads)
35 | 
36 |         self.nan = nn.Identity()
37 |         self.drop = nn.Dropout(0.1)
38 | 
39 |         self.shortcut = nn.Conv3d(in_channels, out_channels, kernel_size=1) if in_channels != out_channels else nn.Identity()
40 | 
41 |         self.norm1 = nn.GroupNorm(n_groups, out_channels) if norm else nn.Identity()
42 |         self.norm2 = nn.GroupNorm(n_groups, out_channels) if norm else nn.Identity()
43 | 
44 |     def forward(self, x: torch.Tensor):
45 |         # First convolution layer
46 | 
47 |         test = self.conv2(self.norm2(self.conv1(self.norm1(x.permute(1, 2, 0, 3, 4)))))
48 |         if self._check_for_nan(test, "2Plus1D Conv"):
49 |             h = self.activation(self.nan(self.conv1(self.norm1(x.permute(1, 2, 0, 3, 4)))))
50 |             # Second convolution layer
51 |             h = self.activation(self.nan(self.conv2(self.norm2(h))))
52 |             h = self.drop(h)
53 |         else:
54 |             h = self.activation(self.bn1(self.conv1(self.norm1(x.permute(1, 2, 0, 3, 4)))))
55 |             # Second convolution layer
56 |             h = self.activation(self.bn2(self.conv2(self.norm2(h))))
57 |             h = self.drop(h)
58 | 
59 |         if torch.isnan(self.bn1.running_mean).any() or torch.isnan(self.bn2.running_mean).any():
60 |             print("NAN!!!!!!!!!\n\n\n")
61 |             breakpoint()
62 | 
63 |         # Add the shortcut connection and return
64 |         return (h + self.shortcut(x.permute(1, 2, 0, 3, 4))).permute(2, 0, 1, 3, 4)
65 | 
66 |     def _check_for_nan(self, tensor: torch.Tensor, step: str) -> bool:
67 |         has_nan = torch.isnan(tensor).any().float()
68 | 
69 |         if dist.is_initialized():
70 |             # if running distributed training, use a global all-reduce operation
71 |             dist.all_reduce(has_nan, op=dist.ReduceOp.SUM)
72 | 
73 |         if has_nan > 0:
74 |             if dist.is_initialized():
75 |                 rank = dist.get_rank()
76 |                 print(f"NaN detected on GPU {rank} at step: {step}")
77 |                 dist.barrier()  # synchronize so that all processes stay in the same state
78 |             else:
79 |                 print(f"NaN detected in single-GPU training at step: {step}")
80 |             return True
81 |         return False
82 | 
83 | 
84 | class ClimateResNet2Plus1D(nn.Module):
85 |     def __init__(self, num_channels, layers, hidden_size):
86 |         super().__init__()
87 |         cnn_layers = []
88 | 
89 |         self.residual_block_class = ResidualBlock2Plus1D
90 |         self.inplanes = num_channels
91 | 
92 
| for idx in range(len(layers)): 93 | in_channels = num_channels if idx == 0 else hidden_size[idx - 1] 94 | out_channels = hidden_size[idx] 95 | cnn_layers.append( 96 | self.create_layer( 97 | self.residual_block_class, in_channels, out_channels, layers[idx] 98 | ) 99 | ) 100 | 101 | self.cnn_layer_modules = nn.ModuleList(cnn_layers) 102 | 103 | def create_layer(self, block, in_channels, out_channels, reps): 104 | layers = [] 105 | layers.append(block(in_channels, out_channels)) 106 | for i in range(1, reps): 107 | layers.append(block(out_channels, out_channels)) 108 | 109 | return nn.Sequential(*layers) 110 | 111 | def forward(self, data): 112 | output = data.float() 113 | 114 | for layer in self.cnn_layer_modules: 115 | output = layer(output) 116 | 117 | return output 118 | -------------------------------------------------------------------------------- /src/weatherode/cnn_dit.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | 5 | 6 | def modulate(x, shift, scale): 7 | return x * (1 + scale.unsqueeze(2).unsqueeze(3)) + shift.unsqueeze(2).unsqueeze(3) 8 | 9 | 10 | class TimestepEmbedder(nn.Module): 11 | """ 12 | Embeds scalar timesteps into vector representations. 13 | """ 14 | def __init__(self, hidden_size, frequency_embedding_size=256): 15 | super().__init__() 16 | self.mlp = nn.Sequential( 17 | nn.Linear(frequency_embedding_size, hidden_size, bias=True), 18 | nn.SiLU(), 19 | nn.Linear(hidden_size, hidden_size, bias=True), 20 | ) 21 | self.frequency_embedding_size = frequency_embedding_size 22 | 23 | @staticmethod 24 | def timestep_embedding(t, dim, max_period=10000): 25 | half = dim // 2 26 | freqs = torch.exp( 27 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 28 | ).to(device=t.device) 29 | args = t[:, None].float() * freqs[None] 30 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 31 | if dim % 2: 32 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 33 | return embedding 34 | 35 | def forward(self, t): 36 | t_freq = self.timestep_embedding(t, self.frequency_embedding_size) 37 | t_emb = self.mlp(t_freq) 38 | return t_emb 39 | 40 | 41 | class ResidualBlock(nn.Module): 42 | def __init__( 43 | self, 44 | in_channels: int, 45 | out_channels: int, 46 | t_embed_dim: int, 47 | activation: str = "gelu", 48 | norm: bool = False, 49 | n_groups: int = 1, 50 | ): 51 | super().__init__() 52 | self.activation = nn.LeakyReLU(0.3) 53 | self.conv1 = nn.Conv2d( 54 | in_channels, out_channels, kernel_size=(3, 3), padding="same" 55 | ) 56 | self.conv2 = nn.Conv2d( 57 | out_channels, out_channels, kernel_size=(3, 3), padding="same" 58 | ) 59 | self.conv3 = nn.Conv2d( 60 | out_channels, out_channels, kernel_size=(3, 3), padding="same" 61 | ) 62 | 63 | self.bn1 = nn.BatchNorm2d(out_channels) 64 | self.bn2 = nn.BatchNorm2d(out_channels) 65 | self.bn3 = nn.BatchNorm2d(out_channels) 66 | 67 | self.drop = nn.Dropout(0.1) 68 | 69 | self.shortcut = ( 70 | nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1)) 71 | if in_channels != out_channels 72 | else nn.Identity() 73 | ) 74 | 75 | self.norm1 = nn.GroupNorm(n_groups, out_channels) if norm else nn.Identity() 76 | self.norm2 = nn.GroupNorm(n_groups, out_channels) if norm else nn.Identity() 77 | 78 | self.adaLN_modulation = nn.Sequential( 79 | nn.SiLU(), 80 | nn.Linear(t_embed_dim, 6 * out_channels, bias=True) 81 | ) 82 | 83 | def forward(self, x: torch.Tensor, t_emb: 
torch.Tensor): 84 | shift1, scale1, gate1, shift2, scale2, gate2 = self.adaLN_modulation(t_emb).chunk(6, dim=1) 85 | 86 | h = self.activation(self.bn1(self.conv1(self.norm1(x)))) 87 | h = modulate(h, shift1, scale1) 88 | # First convolution layer 89 | h = self.activation(self.bn1(self.conv2(h))) 90 | h = h * gate1.unsqueeze(2).unsqueeze(3) 91 | 92 | # Second convolution layer 93 | h = modulate(self.norm2(h), shift2, scale2) 94 | h = self.activation(self.bn3(self.conv3(h))) 95 | h = h * gate2.unsqueeze(2).unsqueeze(3) 96 | 97 | h = self.drop(h) 98 | # Add the shortcut connection and return 99 | return h + self.shortcut(x) 100 | 101 | 102 | class ClimateResNet2DTime(nn.Module): 103 | def __init__(self, num_channels, layers, hidden_size, t_embed_dim=256): 104 | super().__init__() 105 | cnn_layers = [] 106 | 107 | self.residual_block_class = ResidualBlock 108 | self.inplanes = num_channels 109 | 110 | for idx in range(len(layers)): 111 | in_channels = num_channels if idx == 0 else hidden_size[idx - 1] 112 | out_channels = hidden_size[idx] 113 | cnn_layers.append( 114 | self.create_layer( 115 | self.residual_block_class, in_channels, out_channels, layers[idx], t_embed_dim 116 | ) 117 | ) 118 | 119 | self.cnn_layer_modules = nn.ModuleList(cnn_layers) 120 | self.t_embedder = TimestepEmbedder(hidden_size[0], t_embed_dim) 121 | 122 | def create_layer(self, block, in_channels, out_channels, reps, t_embed_dim): 123 | layers = [] 124 | layers.append(block(in_channels, out_channels, t_embed_dim=t_embed_dim)) 125 | for i in range(1, reps): 126 | layers.append(block(out_channels, out_channels, t_embed_dim=t_embed_dim)) 127 | 128 | return nn.Sequential(*layers) 129 | 130 | def forward(self, data, t): 131 | output = data.float() 132 | t_emb = self.t_embedder(t) 133 | 134 | for layer in self.cnn_layer_modules: 135 | for block in layer: 136 | output = block(output, t_emb) 137 | 138 | return output 139 | -------------------------------------------------------------------------------- /src/weatherode/dit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # GLIDE: https://github.com/openai/glide-text2im 9 | # MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py 10 | # -------------------------------------------------------- 11 | 12 | import torch 13 | import torch.nn as nn 14 | import numpy as np 15 | import math 16 | from timm.models.vision_transformer import PatchEmbed, Attention, Mlp 17 | 18 | 19 | def modulate(x, shift, scale): 20 | return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) 21 | 22 | 23 | ################################################################################# 24 | # Embedding Layers for Timesteps and Class Labels # 25 | ################################################################################# 26 | 27 | class TimestepEmbedder(nn.Module): 28 | """ 29 | Embeds scalar timesteps into vector representations. 
30 | """ 31 | def __init__(self, hidden_size, frequency_embedding_size=256): 32 | super().__init__() 33 | self.mlp = nn.Sequential( 34 | nn.Linear(frequency_embedding_size, hidden_size, bias=True), 35 | nn.SiLU(), 36 | nn.Linear(hidden_size, hidden_size, bias=True), 37 | ) 38 | self.frequency_embedding_size = frequency_embedding_size 39 | 40 | @staticmethod 41 | def timestep_embedding(t, dim, max_period=24): 42 | """ 43 | Create sinusoidal timestep embeddings. 44 | :param t: a 1-D Tensor of N indices, one per batch element. 45 | These may be fractional. 46 | :param dim: the dimension of the output. 47 | :param max_period: controls the minimum frequency of the embeddings. 48 | :return: an (N, D) Tensor of positional embeddings. 49 | """ 50 | # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py 51 | half = dim // 2 52 | freqs = torch.exp( 53 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 54 | ).to(device=t.device) 55 | args = t[:, None].float() * freqs[None] 56 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 57 | if dim % 2: 58 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 59 | return embedding 60 | 61 | def forward(self, t): 62 | t_freq = self.timestep_embedding(t, self.frequency_embedding_size) 63 | t_emb = self.mlp(t_freq) 64 | return t_emb 65 | 66 | 67 | class LabelEmbedder(nn.Module): 68 | """ 69 | Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. 70 | """ 71 | def __init__(self, num_classes, hidden_size, dropout_prob): 72 | super().__init__() 73 | use_cfg_embedding = dropout_prob > 0 74 | self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size) 75 | self.num_classes = num_classes 76 | self.dropout_prob = dropout_prob 77 | 78 | def token_drop(self, labels, force_drop_ids=None): 79 | """ 80 | Drops labels to enable classifier-free guidance. 81 | """ 82 | if force_drop_ids is None: 83 | drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob 84 | else: 85 | drop_ids = force_drop_ids == 1 86 | labels = torch.where(drop_ids, self.num_classes, labels) 87 | return labels 88 | 89 | def forward(self, labels, train, force_drop_ids=None): 90 | use_dropout = self.dropout_prob > 0 91 | if (train and use_dropout) or (force_drop_ids is not None): 92 | labels = self.token_drop(labels, force_drop_ids) 93 | embeddings = self.embedding_table(labels) 94 | return embeddings 95 | 96 | 97 | ################################################################################# 98 | # Core DiT Model # 99 | ################################################################################# 100 | 101 | class DiTBlock(nn.Module): 102 | """ 103 | A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning. 
104 | """ 105 | def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs): 106 | super().__init__() 107 | self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 108 | self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, attn_drop=0.1, proj_drop=0.1, **block_kwargs) 109 | self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 110 | mlp_hidden_dim = int(hidden_size * mlp_ratio) 111 | approx_gelu = lambda: nn.GELU(approximate="tanh") 112 | self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0.1) 113 | self.adaLN_modulation = nn.Sequential( 114 | nn.SiLU(), 115 | nn.Linear(hidden_size, 6 * hidden_size, bias=True) 116 | ) 117 | 118 | def forward(self, x, c): 119 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1) 120 | x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa)) 121 | x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) 122 | return x 123 | 124 | 125 | class FinalLayer(nn.Module): 126 | """ 127 | The final layer of DiT. 128 | """ 129 | def __init__(self, hidden_size, patch_size, out_channels): 130 | super().__init__() 131 | self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 132 | self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) 133 | self.adaLN_modulation = nn.Sequential( 134 | nn.SiLU(), 135 | nn.Linear(hidden_size, 2 * hidden_size, bias=True) 136 | ) 137 | 138 | def forward(self, x, c): 139 | shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) 140 | x = modulate(self.norm_final(x), shift, scale) 141 | x = self.linear(x) 142 | return x 143 | 144 | 145 | class DiT(nn.Module): 146 | """ 147 | Diffusion model with a Transformer backbone. 
148 | """ 149 | def __init__( 150 | self, 151 | input_size=32, 152 | patch_size=2, 153 | in_channels=4, 154 | out_channels=None, 155 | hidden_size=768, 156 | depth=28, 157 | num_heads=12, 158 | mlp_ratio=4.0, 159 | class_dropout_prob=0.1, 160 | num_classes=1000, 161 | learn_sigma=True, 162 | ): 163 | super().__init__() 164 | self.input_size = input_size 165 | self.learn_sigma = learn_sigma 166 | self.in_channels = in_channels 167 | self.out_channels = in_channels if out_channels is None else out_channels # in_channels * 2 if learn_sigma else in_channels 168 | self.patch_size = patch_size 169 | self.num_heads = num_heads 170 | 171 | self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True) 172 | self.t_embedder = TimestepEmbedder(hidden_size) 173 | # self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob) 174 | num_patches = self.x_embedder.num_patches 175 | # Will use fixed sin-cos embedding: 176 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False) 177 | 178 | self.drop = nn.Dropout(0.1) 179 | 180 | self.blocks = nn.ModuleList([ 181 | DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth) 182 | ]) 183 | self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels) 184 | self.initialize_weights() 185 | 186 | def initialize_weights(self): 187 | # Initialize transformer layers: 188 | def _basic_init(module): 189 | if isinstance(module, nn.Linear): 190 | torch.nn.init.xavier_uniform_(module.weight) 191 | if module.bias is not None: 192 | nn.init.constant_(module.bias, 0) 193 | self.apply(_basic_init) 194 | 195 | # Initialize (and freeze) pos_embed by sin-cos embedding: 196 | pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], (self.input_size[0] // self.patch_size, self.input_size[1] // self.patch_size)) 197 | self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) 198 | 199 | # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): 200 | w = self.x_embedder.proj.weight.data 201 | nn.init.xavier_uniform_(w.view([w.shape[0], -1])) 202 | nn.init.constant_(self.x_embedder.proj.bias, 0) 203 | 204 | # Initialize label embedding table: 205 | # nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) 206 | 207 | # Initialize timestep embedding MLP: 208 | nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) 209 | nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) 210 | 211 | # Zero-out adaLN modulation layers in DiT blocks: 212 | for block in self.blocks: 213 | nn.init.constant_(block.adaLN_modulation[-1].weight, 0) 214 | nn.init.constant_(block.adaLN_modulation[-1].bias, 0) 215 | 216 | # Zero-out output layers: 217 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) 218 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) 219 | nn.init.constant_(self.final_layer.linear.weight, 0) 220 | nn.init.constant_(self.final_layer.linear.bias, 0) 221 | 222 | def unpatchify(self, x): 223 | """ 224 | x: (N, T, patch_size**2 * C) 225 | imgs: (N, H, W, C) 226 | """ 227 | c = self.out_channels 228 | p = self.x_embedder.patch_size[0] 229 | # breakpoint() 230 | # h = w = int(x.shape[1] ** 0.5) 231 | h, w = self.input_size[0] // p, self.input_size[1] // p 232 | assert h * w == x.shape[1] 233 | 234 | x = x.reshape(shape=(x.shape[0], h, w, p, p, c)) 235 | x = torch.einsum('nhwpqc->nchpwq', x) 236 | imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p)) 237 | return imgs 238 | 239 | def forward(self, x, t): 240 
| """ 241 | Forward pass of DiT. 242 | x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) 243 | t: (N,) tensor of diffusion timesteps 244 | y: (N,) tensor of class labels 245 | """ 246 | x = self.x_embedder(x) + self.pos_embed # (N, T, D), where T = H * W / patch_size ** 2 247 | 248 | x = self.drop(x) 249 | 250 | t = self.t_embedder(t) # (N, D) 251 | # y = self.y_embedder(y, self.training) # (N, D) 252 | c = t # (N, D) 253 | for block in self.blocks: 254 | x = block(x, c) # (N, T, D) 255 | x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels) 256 | x = self.unpatchify(x) # (N, out_channels, H, W) 257 | return x 258 | 259 | def forward_with_cfg(self, x, t, y, cfg_scale): 260 | """ 261 | Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance. 262 | """ 263 | # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb 264 | half = x[: len(x) // 2] 265 | combined = torch.cat([half, half], dim=0) 266 | model_out = self.forward(combined, t, y) 267 | # For exact reproducibility reasons, we apply classifier-free guidance on only 268 | # three channels by default. The standard approach to cfg applies it to all channels. 269 | # This can be done by uncommenting the following line and commenting-out the line following that. 270 | # eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:] 271 | eps, rest = model_out[:, :3], model_out[:, 3:] 272 | cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) 273 | half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) 274 | eps = torch.cat([half_eps, half_eps], dim=0) 275 | return torch.cat([eps, rest], dim=1) 276 | 277 | 278 | ################################################################################# 279 | # Sine/Cosine Positional Embedding Functions # 280 | ################################################################################# 281 | # https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py 282 | 283 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0): 284 | """ 285 | grid_size: int of the grid height and width 286 | return: 287 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 288 | """ 289 | if not isinstance(grid_size, tuple): 290 | grid_size = (grid_size, grid_size) 291 | 292 | grid_h = np.arange(grid_size[0], dtype=np.float32) 293 | grid_w = np.arange(grid_size[1], dtype=np.float32) 294 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 295 | grid = np.stack(grid, axis=0) 296 | 297 | grid = grid.reshape([2, 1, grid_size[0], grid_size[1]]) 298 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 299 | if cls_token and extra_tokens > 0: 300 | pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) 301 | return pos_embed 302 | 303 | 304 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 305 | assert embed_dim % 2 == 0 306 | 307 | # use half of dimensions to encode grid_h 308 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 309 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 310 | 311 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 312 | return emb 313 | 314 | 315 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 316 | """ 317 | embed_dim: output dimension for each position 318 | pos: a list of positions to be encoded: size (M,) 319 | out: (M, D) 320 | """ 321 | 
assert embed_dim % 2 == 0
322 |     omega = np.arange(embed_dim // 2, dtype=np.float64)
323 |     omega /= embed_dim / 2.
324 |     omega = 1. / 10000**omega # (D/2,)
325 | 
326 |     pos = pos.reshape(-1) # (M,)
327 |     out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
328 | 
329 |     emb_sin = np.sin(out) # (M, D/2)
330 |     emb_cos = np.cos(out) # (M, D/2)
331 | 
332 |     emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
333 |     return emb
334 | 
335 | 
336 | if __name__=='__main__':
337 |     model = DiT(input_size=(32, 64), depth=12, hidden_size=384, patch_size=2, num_heads=6)
338 | 
339 |     input = torch.randn(4, 4, 32, 64)
340 | 
341 |     model(input, torch.randn(4))
-------------------------------------------------------------------------------- /src/weatherode/global_forecast/__init__.py: --------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 | 
-------------------------------------------------------------------------------- /src/weatherode/global_forecast/datamodule.py: --------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 | 
4 | import os
5 | from typing import Optional
6 | 
7 | import numpy as np
8 | import torch
9 | import torchdata.datapipes as dp
10 | from pytorch_lightning import LightningDataModule
11 | from torch.utils.data import DataLoader, IterableDataset
12 | from torchvision.transforms import transforms
13 | 
14 | from weatherode.pretrain.datamodule import collate_fn
15 | from weatherode.pretrain.dataset import (
16 |     Forecast,
17 |     IndividualForecastDataIter,
18 |     NpyReader,
19 |     ShuffleIterableDataset,
20 | )
21 | 
22 | 
23 | class GlobalForecastDataModule(LightningDataModule):
24 |     """DataModule for global forecast data.
25 | 
26 |     Args:
27 |         root_dir (str): Root directory for sharded data.
28 |         variables (list): List of input variables.
29 |         buffer_size (int): Buffer size for shuffling.
30 |         out_variables (list, optional): List of output variables.
31 |         predict_range (int, optional): Predict range.
32 |         hrs_each_step (int, optional): Hours each step.
33 |         batch_size (int, optional): Batch size.
34 |         num_workers (int, optional): Number of workers.
35 |         pin_memory (bool, optional): Whether to pin memory.
36 |     """
37 | 
38 |     def __init__(
39 |         self,
40 |         root_dir,
41 |         variables,
42 |         buffer_size,
43 |         out_variables=None,
44 |         predict_range: int = 6,
45 |         hrs_each_step: int = 1,
46 |         batch_size: int = 64,
47 |         num_workers: int = 0,
48 |         pin_memory: bool = False,
49 |     ):
50 |         super().__init__()
51 |         if num_workers > 1:
52 |             raise NotImplementedError(
53 |                 "num_workers > 1 is not supported yet. Performance will also likely degrade with larger num_workers."
54 | ) 55 | 56 | # this line allows to access init params with 'self.hparams' attribute 57 | self.save_hyperparameters(logger=False) 58 | 59 | if isinstance(out_variables, str): 60 | out_variables = [out_variables] 61 | self.hparams.out_variables = out_variables 62 | 63 | self.lister_train = list(dp.iter.FileLister(os.path.join(root_dir, "train"))) 64 | self.lister_val = list(dp.iter.FileLister(os.path.join(root_dir, "val"))) 65 | self.lister_test = list(dp.iter.FileLister(os.path.join(root_dir, "test"))) 66 | 67 | self.transforms = self.get_normalize() 68 | self.output_transforms = self.get_normalize(out_variables) 69 | 70 | self.val_clim = self.get_climatology("val", out_variables) 71 | self.test_clim = self.get_climatology("test", out_variables) 72 | 73 | self.data_train: Optional[IterableDataset] = None 74 | self.data_val: Optional[IterableDataset] = None 75 | self.data_test: Optional[IterableDataset] = None 76 | 77 | def get_normalize(self, variables=None): 78 | if variables is None: 79 | variables = self.hparams.variables 80 | normalize_mean = dict(np.load(os.path.join(self.hparams.root_dir, "normalize_mean.npz"))) 81 | mean = [] 82 | for var in variables: 83 | if var != "total_precipitation": 84 | mean.append(normalize_mean[var]) 85 | else: 86 | mean.append(np.array([0.0])) 87 | normalize_mean = np.concatenate(mean) 88 | normalize_std = dict(np.load(os.path.join(self.hparams.root_dir, "normalize_std.npz"))) 89 | normalize_std = np.concatenate([normalize_std[var] for var in variables]) 90 | return transforms.Normalize(normalize_mean, normalize_std) 91 | 92 | def get_lat_lon(self): 93 | lat = np.load(os.path.join(self.hparams.root_dir, "lat.npy")) 94 | lon = np.load(os.path.join(self.hparams.root_dir, "lon.npy")) 95 | self.lat = lat 96 | self.lon = lon 97 | return lat, lon 98 | 99 | def get_climatology(self, partition="val", variables=None): 100 | path = os.path.join(self.hparams.root_dir, partition, "climatology.npz") 101 | clim_dict = np.load(path) 102 | if variables is None: 103 | variables = self.hparams.variables 104 | clim = np.concatenate([clim_dict[var] for var in variables]) 105 | clim = torch.from_numpy(clim) 106 | return clim 107 | 108 | def setup(self, stage: Optional[str] = None): 109 | # load datasets only if they're not loaded already 110 | if not self.data_train and not self.data_val and not self.data_test: 111 | self.data_train = ShuffleIterableDataset( 112 | IndividualForecastDataIter( 113 | Forecast( 114 | NpyReader( 115 | file_list=self.lister_train, 116 | start_idx=0, 117 | end_idx=1, 118 | variables=self.hparams.variables, 119 | out_variables=self.hparams.out_variables, 120 | shuffle=True, 121 | multi_dataset_training=False, 122 | ), 123 | max_predict_range=self.hparams.predict_range, 124 | random_lead_time=False, 125 | hrs_each_step=self.hparams.hrs_each_step, 126 | ), 127 | transforms=self.transforms, 128 | output_transforms=self.output_transforms, 129 | 130 | ), 131 | buffer_size=self.hparams.buffer_size, 132 | ) 133 | 134 | self.data_val = IndividualForecastDataIter( 135 | Forecast( 136 | NpyReader( 137 | file_list=self.lister_val, 138 | start_idx=0, 139 | end_idx=1, 140 | variables=self.hparams.variables, 141 | out_variables=self.hparams.out_variables, 142 | shuffle=False, 143 | multi_dataset_training=False, 144 | ), 145 | max_predict_range=self.hparams.predict_range, 146 | random_lead_time=False, 147 | hrs_each_step=self.hparams.hrs_each_step, 148 | ), 149 | transforms=self.transforms, 150 | output_transforms=self.output_transforms, 151 | ) 152 | 153 | 
self.data_test = IndividualForecastDataIter( 154 | Forecast( 155 | NpyReader( 156 | file_list=self.lister_test, 157 | start_idx=0, 158 | end_idx=1, 159 | variables=self.hparams.variables, 160 | out_variables=self.hparams.out_variables, 161 | shuffle=False, 162 | multi_dataset_training=False, 163 | ), 164 | max_predict_range=self.hparams.predict_range, 165 | random_lead_time=False, 166 | hrs_each_step=self.hparams.hrs_each_step, 167 | ), 168 | transforms=self.transforms, 169 | output_transforms=self.output_transforms, 170 | ) 171 | 172 | def train_dataloader(self): 173 | return DataLoader( 174 | self.data_train, 175 | batch_size=self.hparams.batch_size, 176 | drop_last=False, 177 | num_workers=self.hparams.num_workers, 178 | pin_memory=self.hparams.pin_memory, 179 | collate_fn=collate_fn, 180 | ) 181 | 182 | def val_dataloader(self): 183 | return DataLoader( 184 | self.data_val, 185 | batch_size=self.hparams.batch_size, 186 | shuffle=False, 187 | drop_last=False, 188 | num_workers=self.hparams.num_workers, 189 | pin_memory=self.hparams.pin_memory, 190 | collate_fn=collate_fn, 191 | ) 192 | 193 | def test_dataloader(self): 194 | return DataLoader( 195 | self.data_test, 196 | batch_size=self.hparams.batch_size, 197 | shuffle=False, 198 | drop_last=False, 199 | num_workers=self.hparams.num_workers, 200 | pin_memory=self.hparams.pin_memory, 201 | collate_fn=collate_fn, 202 | ) 203 | -------------------------------------------------------------------------------- /src/weatherode/global_forecast/module.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | # credits: https://github.com/ashleve/lightning-hydra-template/blob/main/src/models/mnist_module.py 5 | from typing import Any 6 | 7 | import wandb 8 | import torch 9 | from pytorch_lightning import LightningModule 10 | from torchvision.transforms import transforms 11 | 12 | from weatherode.ode import WeatherODE 13 | from weatherode.utils.lr_scheduler import LinearWarmupCosineAnnealingLR 14 | from weatherode.utils.metrics import ( 15 | lat_weighted_acc, 16 | lat_weighted_mse, 17 | lat_weighted_mse_val, 18 | lat_weighted_rmse, 19 | lat_weighted_mse_velocity_guess 20 | ) 21 | from weatherode.utils.pos_embed import interpolate_pos_embed 22 | 23 | from tqdm import tqdm 24 | 25 | import matplotlib.pyplot as plt 26 | from mpl_toolkits.basemap import Basemap 27 | import numpy as np 28 | 29 | class GlobalForecastModule(LightningModule): 30 | """Lightning module for global forecasting with the WeatherODE model. 31 | 32 | Args: 33 | net (WeatherODE): WeatherODE model. 34 | pretrained_path (str, optional): Path to pre-trained checkpoint. 35 | lr (float, optional): Learning rate. 36 | beta_1 (float, optional): Beta 1 for AdamW. 37 | beta_2 (float, optional): Beta 2 for AdamW. 38 | weight_decay (float, optional): Weight decay for AdamW. 39 | warmup_epochs (int, optional): Number of warmup epochs. 40 | max_epochs (int, optional): Number of total epochs. 41 | warmup_start_lr (float, optional): Starting learning rate for warmup. 42 | eta_min (float, optional): Minimum learning rate. 
43 | """ 44 | 45 | def __init__( 46 | self, 47 | net: WeatherODE, 48 | pretrained_path: str = "", 49 | lr: float = 5e-4, 50 | ode_lr: float = 5e-5, 51 | beta_1: float = 0.9, 52 | beta_2: float = 0.99, 53 | weight_decay: float = 1e-5, 54 | warmup_epochs: int = 10000, 55 | max_epochs: int = 200000, 56 | warmup_start_lr: float = 1e-8, 57 | eta_min: float = 1e-8, 58 | gradient_clip_val: float = 0.5, 59 | gradient_clip_algorithm: str = "value", 60 | train_noise_only: bool = False, 61 | ): 62 | super().__init__() 63 | self.save_hyperparameters(logger=False, ignore=["net"]) 64 | self.net = net 65 | self.skip_optimization = False 66 | if len(pretrained_path) > 0: 67 | self.load_pretrained_weights(pretrained_path) 68 | 69 | def load_pretrained_weights(self, pretrained_path): 70 | if pretrained_path.startswith("http"): 71 | checkpoint = torch.hub.load_state_dict_from_url(pretrained_path) 72 | else: 73 | checkpoint = torch.load(pretrained_path, map_location=torch.device("cpu")) 74 | print("Loading pre-trained checkpoint from: %s" % pretrained_path) 75 | checkpoint_model = checkpoint["state_dict"] 76 | # interpolate positional embedding 77 | # interpolate_pos_embed(self.net, checkpoint_model, new_size=self.net.img_size) 78 | 79 | state_dict = self.state_dict() 80 | # if self.net.parallel_patch_embed: 81 | # if "token_embeds.proj_weights" not in checkpoint_model.keys(): 82 | # raise ValueError( 83 | # "Pretrained checkpoint does not have token_embeds.proj_weights for parallel processing. Please convert the checkpoints first or disable parallel patch_embed tokenization." 84 | # ) 85 | 86 | # checkpoint_keys = list(checkpoint_model.keys()) 87 | for k in list(checkpoint_model.keys()): 88 | if "channel" in k: 89 | checkpoint_model[k.replace("channel", "var")] = checkpoint_model[k] 90 | del checkpoint_model[k] 91 | for k in list(checkpoint_model.keys()): 92 | if k not in state_dict.keys() or checkpoint_model[k].shape != state_dict[k].shape: 93 | print(f"Removing key {k} from pretrained checkpoint") 94 | del checkpoint_model[k] 95 | 96 | # load pre-trained model 97 | msg = self.load_state_dict(checkpoint_model, strict=False) 98 | print(msg) 99 | 100 | def set_denormalization(self, mean, std): 101 | self.denormalization = transforms.Normalize(mean, std) 102 | 103 | def set_lat_lon(self, lat, lon): 104 | self.lat = lat 105 | self.lon = lon 106 | 107 | def set_pred_range(self, r): 108 | self.pred_range = r 109 | 110 | def set_val_clim(self, clim): 111 | self.val_clim = clim 112 | 113 | def set_test_clim(self, clim): 114 | self.test_clim = clim 115 | 116 | def training_step(self, batch: Any, batch_idx: int): 117 | x, y, predict_ranges, variables, out_variables = batch 118 | 119 | # init_vx, init_vy = optimize_vel(x, prev_x, self.kernel) 120 | 121 | loss_dict, _ = self.net.forward(x, y, predict_ranges, variables, out_variables, [lat_weighted_mse_velocity_guess], lat=self.lat, lon=self.lon, epoch=self.current_epoch) 122 | loss_dict = loss_dict[0] 123 | 124 | # check nan 125 | has_nan = False 126 | for var, loss_value in loss_dict.items(): 127 | if torch.isnan(loss_value).any(): 128 | has_nan = True 129 | break 130 | 131 | # sum nan 132 | has_nan_tensor = torch.tensor(float(has_nan), device=self.device) 133 | torch.distributed.all_reduce(has_nan_tensor, op=torch.distributed.ReduceOp.SUM) 134 | 135 | if has_nan_tensor.item() > 0: 136 | self.log("train/has_nan", True, on_step=True, on_epoch=False, prog_bar=True) 137 | self.skip_optimization = True 138 | return loss_dict["loss"] #torch.tensor(1., device=self.device, 
requires_grad=True).to(torch.float32) 139 | 140 | self.log("train/has_nan", False, on_step=True, on_epoch=False, prog_bar=True) 141 | self.skip_optimization = False 142 | 143 | for var in loss_dict.keys(): 144 | self.log( 145 | "train/" + var, 146 | loss_dict[var], 147 | on_step=True, 148 | on_epoch=False, 149 | prog_bar=True, 150 | ) 151 | loss = loss_dict["loss"] 152 | 153 | return loss 154 | 155 | def plot_weather_maps(self, preds, out_variables, batch_idx, lat, lon, type="gt", test=False): 156 | """ 157 | Plots weather maps for given predictions and logs them to the experiment logger. 158 | 159 | Args: 160 | preds (torch.Tensor): The predictions tensor of shape [batch_size, 5, 32, 64]. 161 | out_variables (list): List of variable names for the channels. 162 | batch_idx (int): The batch index. 163 | logger: The experiment logger. 164 | lat (np.array): Latitude values. 165 | lon (np.array): Longitude values. 166 | """ 167 | prefix = "test_images" if test else "val_images" 168 | 169 | batch_size, num_vars, height, width = preds.shape 170 | for var_idx in range(num_vars): 171 | fig, ax = plt.subplots() 172 | m = Basemap(projection='cyl', resolution='c', ax=ax, 173 | llcrnrlat=lat.min(), urcrnrlat=lat.max(), 174 | llcrnrlon=lon.min(), urcrnrlon=lon.max()) 175 | m.drawcoastlines() 176 | m.drawcountries() 177 | m.drawmapboundary() 178 | m.drawparallels(np.arange(-90., 91., 30.), labels=[1, 0, 0, 0]) 179 | m.drawmeridians(np.arange(-180., 181., 60.), labels=[0, 0, 0, 1]) 180 | 181 | data = preds[0, var_idx].cpu().detach().numpy() 182 | 183 | lon_grid, lat_grid = np.meshgrid(lon, lat) 184 | xi, yi = m(lon_grid, lat_grid) 185 | 186 | # Interpolate the data if needed (you can adjust the method and resolution) 187 | data = np.interp(data, (data.min(), data.max()), (0, 1)) 188 | 189 | # Plot the data 190 | cs = m.pcolormesh(xi, yi, data, cmap='RdBu') 191 | fig.colorbar(cs, ax=ax, orientation='vertical', label=out_variables[var_idx]) 192 | 193 | ax.set_title(f"{type}_{out_variables[var_idx]}") 194 | 195 | # Log the figure to the experiment logger 196 | try: 197 | self.logger.experiment.log({f"{prefix}/{out_variables[var_idx]}_{type}_{batch_idx}": wandb.Image(fig)}, step=self.global_step) 198 | except: 199 | self.logger.experiment.add_figure(f"{prefix}/{out_variables[var_idx]}_{type}_{batch_idx}", fig, global_step=self.global_step) 200 | plt.close(fig) 201 | 202 | def validation_step(self, batch: Any, batch_idx: int): 203 | x, y, predict_ranges, variables, out_variables = batch 204 | 205 | # init_vx, init_vy = optimize_vel(x, prev_x, self.kernel) 206 | 207 | if self.pred_range < 24: 208 | log_postfix = f"{self.pred_range}_hours" 209 | else: 210 | days = int(self.pred_range / 24) 211 | log_postfix = f"{days}_days" 212 | 213 | all_loss_dicts, preds, ode_preds, noise_preds = self.net.evaluate( 214 | x, 215 | y, 216 | predict_ranges, 217 | variables, 218 | out_variables, 219 | transform=self.denormalization, 220 | metrics=[lat_weighted_mse_val, lat_weighted_rmse, lat_weighted_acc], 221 | lat=self.lat, 222 | lon=self.lon, 223 | clim=self.val_clim, 224 | log_postfix=log_postfix, 225 | ) 226 | 227 | if batch_idx % 50 == 0: 228 | self.plot_weather_maps(y[:, -1], out_variables, batch_idx, self.lat, self.lon, type="gt") 229 | self.plot_weather_maps(ode_preds, out_variables, batch_idx, self.lat, self.lon, type="ode") 230 | self.plot_weather_maps(noise_preds, out_variables, batch_idx, self.lat, self.lon, type="noise") 231 | 232 | loss_dict = {} 233 | for d in all_loss_dicts: 234 | for k in d.keys(): 235 | 
loss_dict[k] = d[k] 236 | 237 | for var in loss_dict.keys(): 238 | self.log( 239 | "val/" + var, 240 | loss_dict[var], 241 | on_step=False, 242 | on_epoch=True, 243 | prog_bar=False, 244 | sync_dist=True, 245 | ) 246 | return loss_dict 247 | 248 | def test_step(self, batch: Any, batch_idx: int): 249 | x, y, predict_ranges, variables, out_variables = batch 250 | 251 | # init_vx, init_vy = optimize_vel(x, prev_x, self.kernel) 252 | 253 | if self.pred_range < 24: 254 | log_postfix = f"{self.pred_range}_hours" 255 | else: 256 | days = int(self.pred_range / 24) 257 | log_postfix = f"{days}_days" 258 | 259 | all_loss_dicts, preds, ode_preds, noise_preds = self.net.evaluate( 260 | x, 261 | y, 262 | predict_ranges, 263 | variables, 264 | out_variables, 265 | transform=self.denormalization, 266 | metrics=[lat_weighted_mse_val, lat_weighted_rmse, lat_weighted_acc], 267 | lat=self.lat, 268 | lon=self.lon, 269 | clim=self.test_clim, 270 | log_postfix=log_postfix, 271 | ) 272 | 273 | if batch_idx % 50 == 0: 274 | self.plot_weather_maps(y[:, -1], out_variables, batch_idx, self.lat, self.lon, type="gt", test=True) 275 | self.plot_weather_maps(ode_preds, out_variables, batch_idx, self.lat, self.lon, type="ode", test=True) 276 | self.plot_weather_maps(noise_preds, out_variables, batch_idx, self.lat, self.lon, type="noise", test=True) 277 | 278 | loss_dict = {} 279 | for d in all_loss_dicts: 280 | for k in d.keys(): 281 | loss_dict[k] = d[k] 282 | 283 | for var in loss_dict.keys(): 284 | self.log( 285 | "test/" + var, 286 | loss_dict[var], 287 | on_step=False, 288 | on_epoch=True, 289 | prog_bar=False, 290 | sync_dist=True, 291 | ) 292 | return loss_dict 293 | 294 | def configure_gradient_clipping(self, optimizer, optimizer_idx, gradient_clip_val, gradient_clip_algorithm): 295 | self.clip_gradients(optimizer, gradient_clip_val=gradient_clip_val, gradient_clip_algorithm=gradient_clip_algorithm) 296 | 297 | def configure_optimizers(self): 298 | decay = [] 299 | no_decay = [] 300 | 301 | ode = [] 302 | noise_net = [] 303 | for name, m in self.named_parameters(): 304 | if "var_embed" in name or "pos_embed" in name or "time_pos_embed" in name: 305 | no_decay.append(m) 306 | else: 307 | if "net.model" in name: 308 | ode.append(m) 309 | else: 310 | decay.append(m) 311 | 312 | if self.hparams.train_noise_only: 313 | if "net.noise_model" in name: 314 | noise_net.append(m) 315 | else: 316 | m.requires_grad = False 317 | 318 | if self.hparams.train_noise_only: 319 | optimizer = torch.optim.AdamW( 320 | [ 321 | { 322 | "params": noise_net, 323 | "lr": self.hparams.lr, 324 | "betas": (self.hparams.beta_1, self.hparams.beta_2), 325 | "weight_decay": self.hparams.weight_decay, 326 | } 327 | ] 328 | ) 329 | else: 330 | optimizer = torch.optim.AdamW( 331 | [ 332 | { 333 | "params": ode, 334 | "lr": self.hparams.ode_lr, 335 | "betas": (self.hparams.beta_1, self.hparams.beta_2), 336 | "weight_decay": self.hparams.weight_decay, 337 | }, 338 | { 339 | "params": decay, 340 | "lr": self.hparams.lr, 341 | "betas": (self.hparams.beta_1, self.hparams.beta_2), 342 | "weight_decay": self.hparams.weight_decay, 343 | }, 344 | { 345 | "params": no_decay, 346 | "lr": self.hparams.lr, 347 | "betas": (self.hparams.beta_1, self.hparams.beta_2), 348 | "weight_decay": 0, 349 | }, 350 | ] 351 | ) 352 | 353 | lr_scheduler = LinearWarmupCosineAnnealingLR( 354 | optimizer, 355 | self.hparams.warmup_epochs, 356 | self.hparams.max_epochs, 357 | self.hparams.warmup_start_lr, 358 | self.hparams.eta_min, 359 | ) 360 | scheduler = {"scheduler": 
lr_scheduler, "interval": "step", "frequency": 1} 361 | 362 | return {"optimizer": optimizer, "lr_scheduler": scheduler} 363 | 364 | 365 | def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, on_tpu, using_native_amp, using_lbfgs): 366 | # check loss 367 | 368 | if self.skip_optimization: 369 | optimizer.zero_grad() 370 | optimizer_closure() 371 | return 372 | 373 | optimizer.step(closure=optimizer_closure) -------------------------------------------------------------------------------- /src/weatherode/global_forecast/train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | import os 5 | os.environ['NCCL_ALGO'] = 'Tree' 6 | from weatherode.global_forecast.datamodule import GlobalForecastDataModule 7 | from weatherode.global_forecast.module import GlobalForecastModule 8 | from pytorch_lightning.cli import LightningCLI 9 | 10 | 11 | def main(): 12 | # Initialize Lightning with the model and data modules, and instruct it to parse the config yml 13 | cli = LightningCLI( 14 | model_class=GlobalForecastModule, 15 | datamodule_class=GlobalForecastDataModule, 16 | seed_everything_default=42, 17 | save_config_overwrite=True, 18 | run=False, 19 | auto_registry=True, 20 | parser_kwargs={"parser_mode": "omegaconf", "error_handler": None}, 21 | ) 22 | os.makedirs(cli.trainer.default_root_dir, exist_ok=True) 23 | 24 | normalization = cli.datamodule.output_transforms 25 | mean_norm, std_norm = normalization.mean, normalization.std 26 | mean_denorm, std_denorm = -mean_norm / std_norm, 1 / std_norm 27 | cli.model.set_denormalization(mean_denorm, std_denorm) 28 | cli.model.set_lat_lon(*cli.datamodule.get_lat_lon()) 29 | # cli.model.set_gauss_kernel() 30 | cli.model.set_pred_range(cli.datamodule.hparams.predict_range) 31 | cli.model.set_val_clim(cli.datamodule.val_clim) 32 | cli.model.set_test_clim(cli.datamodule.test_clim) 33 | 34 | cli.trainer.gradient_clip_val = cli.model.hparams.gradient_clip_val 35 | cli.trainer.gradient_clip_algorithm = cli.model.hparams.gradient_clip_algorithm 36 | 37 | # fit() runs the training 38 | cli.trainer.fit(cli.model, datamodule=cli.datamodule) 39 | 40 | # test the trained model 41 | cli.trainer.test(cli.model, datamodule=cli.datamodule, ckpt_path="best") 42 | 43 | 44 | if __name__ == "__main__": 45 | main() 46 | -------------------------------------------------------------------------------- /src/weatherode/ode.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 
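The LightningCLI entry point above (`global_forecast/train.py`) is normally driven by the YAML files under `configs/`, and the shell scripts in `scripts/global/` wrap equivalent calls. A minimal launch sketch, assuming the shipped config and illustrative override keys (the exact option names depend on that YAML and on the datamodule/Trainer arguments; the paths are placeholders):

```bash
python src/weatherode/global_forecast/train.py \
    --config configs/global_forecast_weatherode.yaml \
    --data.root_dir /path/to/5.625deg_npz \
    --data.predict_range 6 \
    --trainer.default_root_dir exps/global_6h
```

Because the CLI is created with `run=False`, a single invocation both fits the model and then tests the best checkpoint (`ckpt_path="best"`).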
3 | 4 | from functools import lru_cache 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | 10 | from weatherode.utils.pos_embed import ( 11 | get_1d_sincos_pos_embed_from_grid, 12 | get_2d_sincos_pos_embed, 13 | ) 14 | 15 | from weatherode.ode_utils import ClimateResNet2D, ClimateResNet3D, SelfAttentionConv, ViT, SparseLinear 16 | from weatherode.dit import DiT 17 | from weatherode.cnn_dit import ClimateResNet2DTime 18 | from weatherode.c3d import ClimateResNet2Plus1D 19 | from torchdiffeq import odeint 20 | import torch.distributed as dist 21 | 22 | class WeatherODE(nn.Module): 23 | """Implements the WeatherODE model as described in the paper, 24 | https://arxiv.org/abs/2301.10343 25 | 26 | Args: 27 | default_vars (list): list of default variables to be used for training 28 | img_size (list): image size of the input data 29 | patch_size (int): patch size of the input data 30 | embed_dim (int): embedding dimension 31 | depth (int): number of transformer layers 32 | decoder_depth (int): number of decoder layers 33 | num_heads (int): number of attention heads 34 | mlp_ratio (float): ratio of mlp hidden dimension to embedding dimension 35 | drop_path (float): stochastic depth rate 36 | drop_rate (float): dropout rate 37 | parallel_patch_embed (bool): whether to use parallel patch embedding 38 | """ 39 | 40 | def __init__( 41 | self, 42 | default_vars, 43 | method, 44 | img_size=[32, 64], 45 | patch_size=2, 46 | layers=[5, 5, 3, 2], # [5, 3, 2], 47 | hidden=[512, 128, 64], #[256, 64], 48 | depth=4, 49 | use_err=True, 50 | err_type="2D", 51 | err_with_x=False, 52 | err_with_v=False, 53 | err_with_std=False, 54 | drop_rate=0.1, 55 | time_steps=12, 56 | time_interval=0.001, 57 | rtol=1e-9, 58 | atol=1e-11, 59 | predict_list=[6], 60 | gradient_loss=False 61 | ): 62 | super().__init__() 63 | 64 | self.default_vars = default_vars 65 | self.method = method 66 | self.patch_size = patch_size 67 | self.time_steps = time_steps 68 | self.time_interval = time_interval 69 | self.rtol = rtol 70 | self.atol = atol 71 | 72 | self.layers = layers 73 | self.hidden = hidden + [2 * len(self.default_vars)] 74 | 75 | self.use_err = use_err 76 | self.drop_rate = drop_rate 77 | self.predict_list = predict_list 78 | self.gradient_loss = gradient_loss 79 | self.err_with_x = err_with_x 80 | self.err_with_v = err_with_v 81 | self.err_type = err_type 82 | 83 | self.v_net = ClimateResNet2D(3 * len(self.default_vars), self.layers, self.hidden) 84 | # self.v_net = ViT(3 * len(self.default_vars), 2 * len(self.default_vars)) 85 | 86 | # t(1), day_t(2), sea_t(2), x(5), nabla_x(10), v(10), lat_lon(2), pos(6), pos_time(24) 87 | input_channels = 37 + len(self.default_vars) * 5 88 | 89 | self.model = ViT(input_channels, 2 * len(self.default_vars), depth=depth, patch_size=patch_size, img_size=img_size) 90 | # self.model = ClimateResNet2D(input_channels, self.layers, self.hidden) 91 | 92 | self.linear_model = SparseLinear(input_channels, 2 * len(self.default_vars)) 93 | # ClimateResNet2D(input_channels, self.layers, self.hidden) 94 | 95 | # x(5), nabla_x(10), v(10), lat_lon(2), pos(6) 96 | if self.use_err: 97 | noise_input_channels = 8 + len(self.default_vars) 98 | noise_input_channels = noise_input_channels + 2 * len(self.default_vars) if self.err_with_v else noise_input_channels 99 | noise_input_channels = noise_input_channels + len(self.default_vars) if self.err_with_x else noise_input_channels 100 | 101 | noise_hidden = hidden + [len(self.default_vars)] 102 | 103 | if err_type == "vit": 104 | 
self.noise_model = ViT(noise_input_channels, len(self.default_vars), patch_size=patch_size, img_size=img_size) 105 | elif err_type == "2D": 106 | self.noise_model = ClimateResNet2D(noise_input_channels, self.layers, noise_hidden) 107 | elif err_type == "3D": 108 | self.noise_model = ClimateResNet3D(noise_input_channels, self.layers, noise_hidden) 109 | elif err_type == "2+1D": 110 | self.noise_model = ClimateResNet2Plus1D(noise_input_channels, self.layers, noise_hidden) 111 | elif err_type == "DiT": 112 | self.noise_model = DiT(input_size=tuple(img_size), in_channels=noise_input_channels, out_channels=len(self.default_vars), depth=self.layers) 113 | elif err_type == "2DTime": 114 | self.noise_model = ClimateResNet2DTime(noise_input_channels, self.layers, noise_hidden, t_embed_dim=hidden[0]) 115 | 116 | self.var_map = self.create_var_map() 117 | 118 | def create_var_map(self): 119 | var_map = {} 120 | idx = 0 121 | for var in self.default_vars: 122 | var_map[var] = idx 123 | idx += 1 124 | return var_map 125 | 126 | def get_var_ids(self, vars, device): 127 | ids = np.array([self.var_map[var] for var in vars]) 128 | return torch.from_numpy(ids).to(device) 129 | 130 | def pde(self, t, x): 131 | vx = x[:, len(self.default_vars) : 2 * len(self.default_vars)] 132 | vy = x[:, 2 * len(self.default_vars) : 3 * len(self.default_vars)] 133 | 134 | v = torch.cat([vx, vy], 1) 135 | 136 | new_lat_lon = x[:, 3 * len(self.default_vars): 3 * len(self.default_vars) + 2] 137 | 138 | pos_feats = x[:, 3 * len(self.default_vars) + 2:] 139 | 140 | x = x[:, : len(self.default_vars)] 141 | 142 | x_grad_x = torch.gradient(x, dim=3)[0] 143 | x_grad_y = torch.gradient(x, dim=2)[0] 144 | 145 | nabla_x = torch.cat([x_grad_x, x_grad_y], 1) 146 | 147 | t_emb = ((t * (1 / self.time_interval)) % 24).view(1, 1, 1, 1).expand(x.shape[0], 1, x.shape[2], x.shape[3]) 148 | 149 | sin_t_emb = torch.sin(torch.pi * t_emb / 12 - torch.pi / 2) 150 | cos_t_emb = torch.cos(torch.pi * t_emb / 12 - torch.pi / 2) 151 | 152 | sin_seas_emb = torch.sin(torch.pi * t_emb/ (12 * 365) - torch.pi / 2) 153 | cos_seas_emb = torch.cos(torch.pi * t_emb / (12 * 365) - torch.pi / 2) 154 | 155 | day_emb = torch.cat([sin_t_emb, cos_t_emb], 1) 156 | seas_emb = torch.cat([sin_seas_emb, cos_seas_emb], 1) 157 | 158 | t_cyc_emb = torch.cat([day_emb, seas_emb], 1) 159 | 160 | t_cyc_emb_expanded = t_cyc_emb.unsqueeze(2) 161 | pos_feats_expanded = pos_feats.unsqueeze(1) 162 | pos_time_ft = (t_cyc_emb_expanded * pos_feats_expanded).view(t_cyc_emb.shape[0], -1, t_cyc_emb.shape[2], t_cyc_emb.shape[3]) 163 | 164 | comb_rep = torch.cat([t_emb / 24, day_emb, seas_emb, nabla_x, v, x, new_lat_lon, pos_feats, pos_time_ft], 1) 165 | 166 | dv = self.model(comb_rep) 167 | 168 | dv += self.linear_model(comb_rep.reshape(comb_rep.shape[0], comb_rep.shape[2], comb_rep.shape[3], -1)).reshape(*dv.shape) 169 | 170 | adv1 = vx * x_grad_x + vy * x_grad_y 171 | adv2 = x * (torch.gradient(vx, dim=3)[0] + torch.gradient(vy, dim=2)[0]) 172 | 173 | x = adv1 + adv2 174 | 175 | return torch.cat([x, dv, new_lat_lon, pos_feats], 1) 176 | 177 | 178 | def forward(self, x, y, predict_range, variables, out_variables, metric, lat, lon, vis_noise=False, epoch=0): 179 | 180 | v_net_input = torch.cat([x, torch.gradient(x, dim=3)[0], torch.gradient(x, dim=2)[0]], 1) 181 | 182 | v_output = self.v_net(v_net_input) 183 | vx, vy = v_output[:, :x.shape[1]], v_output[:, x.shape[1]:] 184 | 185 | new_lat = torch.tensor(lat).float().expand(x.shape[3], x.shape[2]).T.to(x.device).expand(x.shape[0], 1, 
x.shape[2], x.shape[3]) * torch.pi / 180 186 | new_lon = torch.tensor(lon).float().expand(x.shape[2], x.shape[3]).to(x.device).expand(x.shape[0], 1, x.shape[2], x.shape[3]) * torch.pi / 180 187 | 188 | new_lat_lon = torch.cat([new_lat, new_lon], 1) 189 | 190 | cos_lat_map, sin_lat_map = torch.cos(new_lat), torch.sin(new_lat) 191 | cos_lon_map, sin_lon_map = torch.cos(new_lon), torch.sin(new_lon) 192 | 193 | pos_feats = torch.cat([cos_lat_map, cos_lon_map, sin_lat_map, sin_lon_map, sin_lat_map * cos_lon_map, sin_lat_map * sin_lon_map], 1) 194 | 195 | ode_x = torch.cat([x, vx, vy, new_lat_lon, pos_feats], 1) 196 | 197 | new_time_steps = torch.linspace(int(predict_range[0] / self.time_steps), int(predict_range[0]), self.time_steps).float().to(x.device) * self.time_interval 198 | 199 | final_result = odeint(self.pde, ode_x, new_time_steps, method=self.method, rtol=self.rtol, atol=self.atol) 200 | 201 | preds = final_result[:, :, :len(self.default_vars)] 202 | 203 | out_ids = self.get_var_ids(tuple(out_variables), preds.device) 204 | y_ = y.permute(1,0,2,3,4)[torch.linspace(int(predict_range[0] / self.time_steps), int(predict_range[0]), self.time_steps).long() - 1] 205 | 206 | if self._check_for_nan(preds, "ode"): 207 | if metric is None: 208 | loss = None 209 | if vis_noise: 210 | return preds[:, :, out_ids], preds[:, :, out_ids], preds[:, :, out_ids] 211 | else: 212 | preds = preds[:, :, out_ids] 213 | loss = [m(preds, preds, preds, y_, out_variables, lat) for m in metric] 214 | return loss, preds 215 | 216 | if self.use_err: 217 | noise_x = torch.cat([preds, new_lat_lon.expand(preds.shape[0], *new_lat_lon.shape), pos_feats.expand(preds.shape[0], *pos_feats.shape)], 2) 218 | if self.err_with_x: 219 | noise_x = torch.cat([noise_x, x.expand(preds.shape[0], *x.shape)], 2) 220 | if self.err_with_v: 221 | noise_x = torch.cat([noise_x, v_output.expand(preds.shape[0], *v_output.shape)], 2) 222 | 223 | if self.err_type == "2D": 224 | noise_x = noise_x.reshape(-1, *noise_x.shape[2:]) 225 | noise_output = self.noise_model(noise_x) 226 | noise_output = noise_output.view(preds.shape[0], -1, *noise_output.shape[1:]) 227 | elif self.err_type == "2DTime": 228 | noise_x = noise_x.reshape(-1, *noise_x.shape[2:]) 229 | 230 | time_embedding = torch.repeat_interleave(torch.linspace(int(predict_range[0] / self.time_steps), int(predict_range[0]), self.time_steps, device=preds.device), preds.shape[1]) 231 | 232 | noise_output = self.noise_model(noise_x, time_embedding) 233 | 234 | noise_output = noise_output.view(preds.shape[0], -1, *noise_output.shape[1:]) 235 | elif self.err_type == "DiT": 236 | noise_x = noise_x.reshape(-1, *noise_x.shape[2:]) 237 | 238 | time_embedding = torch.repeat_interleave(torch.linspace(int(predict_range[0] / self.time_steps), int(predict_range[0]), self.time_steps, device=preds.device), preds.shape[1]) 239 | 240 | noise_output = self.noise_model(noise_x, time_embedding) 241 | noise_output = noise_output.view(preds.shape[0], -1, *noise_output.shape[1:]) 242 | elif self.err_type == "3D": 243 | noise_output = self.noise_model(noise_x) 244 | elif self.err_type == "2+1D": 245 | noise_output = self.noise_model(noise_x) 246 | elif self.err_type == "vit": 247 | noise_x = noise_x.reshape(-1, *noise_x.shape[2:]) 248 | noise_output = self.noise_model(noise_x) 249 | noise_output = noise_output.view(preds.shape[0], -1, *noise_output.shape[1:]) 250 | 251 | if torch.isnan(noise_output).any(): 252 | print("noise nan \n") 253 | 254 | if self._check_for_nan(noise_output, "noise net"): 255 | if metric is 
None: 256 | loss = None 257 | if vis_noise: 258 | return preds[:, :, out_ids], preds[:, :, out_ids], preds[:, :, out_ids] 259 | else: 260 | loss = [m(preds[:, :, out_ids], preds[:, :, out_ids], preds[:, :, out_ids], y_, out_variables, lat) for m in metric] 261 | return loss, preds 262 | 263 | final_preds = preds + noise_output[:, :, :len(self.default_vars)] 264 | else: 265 | final_preds = preds.clone() 266 | noise_output = preds.clone() 267 | 268 | final_preds = final_preds[:, :, out_ids] 269 | 270 | if metric is None: 271 | # preds = preds[-1] 272 | loss = None 273 | if vis_noise: 274 | return final_preds, preds[:, :, out_ids], noise_output[:, :, :len(self.default_vars)][:, :, out_ids] 275 | else: 276 | loss = [m(final_preds, preds[:, :, out_ids], noise_output[:, :, :len(self.default_vars)][:, :, out_ids], y_, out_variables, lat, gradient_loss=self.gradient_loss, epoch=epoch) for m in metric] 277 | 278 | return loss, final_preds 279 | 280 | def evaluate(self, x, y, predict_range, variables, out_variables, transform, metrics, lat, lon, clim, log_postfix): 281 | preds, ode_preds, noise_preds = self.forward(x, y, predict_range, variables, out_variables, metric=None, lat=lat, lon=lon, vis_noise=True) 282 | 283 | ratio = int(predict_range.mean()) // preds.shape[0] 284 | 285 | loss_dict = [] 286 | 287 | for pred_range in self.predict_list: 288 | if pred_range < 24: 289 | log_postfix = f"{pred_range}_hours" 290 | else: 291 | days = pred_range // 24 292 | if pred_range > days * 24: 293 | log_postfix = f"{days}_days_{pred_range - days * 24}_hours" 294 | else: 295 | log_postfix = f"{days}_days" 296 | 297 | steps = pred_range // ratio 298 | 299 | dic_list = [m(preds[steps - 1], y.permute(1,0,2,3,4)[pred_range - 1], transform, out_variables, lat, clim, log_postfix) for m in metrics] 300 | 301 | if pred_range != int(predict_range.mean()): 302 | for dic in dic_list: 303 | dic.pop('w_rmse', None) 304 | 305 | loss_dict += dic_list 306 | 307 | return loss_dict, preds[-1], ode_preds[-1], noise_preds[-1] 308 | 309 | def _check_for_nan(self, tensor: torch.Tensor, step: str) -> bool: 310 | has_nan = torch.isnan(tensor).any().float() 311 | 312 | if dist.is_initialized(): 313 | dist.all_reduce(has_nan, op=dist.ReduceOp.SUM) 314 | 315 | if has_nan > 0: 316 | if dist.is_initialized(): 317 | rank = dist.get_rank() 318 | print(f"NaN detected on GPU {rank} at step: {step}") 319 | dist.barrier() 320 | else: 321 | print(f"NaN detected in single-GPU training at step: {step}") 322 | return True 323 | return False 324 | -------------------------------------------------------------------------------- /src/weatherode/pretrain/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | -------------------------------------------------------------------------------- /src/weatherode/pretrain/datamodule.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 
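# --- Editor's note -----------------------------------------------------------
# MultiSourceDataModule (defined below) is configured with per-source
# dictionaries keyed by dataset name. A hedged, purely illustrative example of
# such a configuration (the "mpi-esm" key mirrors the defaults below; the root
# directory path and variable list are hypothetical):
#
#     dm = MultiSourceDataModule(
#         dict_root_dirs={"mpi-esm": "/data/mpi-esm_5.625deg_np"},   # hypothetical path
#         dict_start_idx={"mpi-esm": 0.0},
#         dict_end_idx={"mpi-esm": 1.0},
#         dict_buffer_sizes={"mpi-esm": 10000},
#         dict_in_variables={"mpi-esm": ["2m_temperature", "geopotential_500"]},
#         dict_out_variables={"mpi-esm": None},   # None -> falls back to the input variables
#         dict_max_predict_ranges={"mpi-esm": 28},
#         dict_random_lead_time={"mpi-esm": True},
#         dict_hrs_each_step={"mpi-esm": 6},
#         batch_size=64,
#         num_workers=1,
#     )
# -----------------------------------------------------------------------------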
3 | 
4 | import os
5 | from typing import Dict, Optional
6 | 
7 | import numpy as np
8 | import torch
9 | import torchdata.datapipes as dp
10 | from pytorch_lightning import LightningDataModule
11 | from torch.utils.data import DataLoader
12 | from torchvision.transforms import transforms
13 | 
14 | from weatherode.pretrain.dataset import (
15 |     Forecast,
16 |     IndividualForecastDataIter,
17 |     NpyReader,
18 |     ShuffleIterableDataset,
19 | )
20 | 
21 | 
22 | def collate_fn(batch):
23 |     inp = torch.stack([batch[i][0] for i in range(len(batch))])
24 |     out = torch.stack([batch[i][1] for i in range(len(batch))])
25 |     predict_range = torch.stack([batch[i][2] for i in range(len(batch))])
26 |     variables = batch[0][3]
27 |     out_variables = batch[0][4]
28 |     return (
29 |         inp,
30 |         out,
31 |         predict_range,
32 |         [v for v in variables],
33 |         [v for v in out_variables],
34 |     )
35 | 
36 | 
37 | class MultiSourceDataModule(LightningDataModule):
38 |     """DataModule for multi-source data.
39 | 
40 |     Args:
41 |         dict_root_dirs (Dict): Dictionary of root directories for each source.
42 |         dict_start_idx (Dict): Dictionary of start indices ratio (between 0.0 and 1.0) for each source.
43 |         dict_end_idx (Dict): Dictionary of end indices ratio (between 0.0 and 1.0) for each source.
44 |         dict_buffer_sizes (Dict): Dictionary of buffer sizes for each source.
45 |         dict_in_variables (Dict): Dictionary of input variables for each source.
46 |         dict_out_variables (Dict): Dictionary of output variables for each source.
47 |         dict_max_predict_ranges (Dict, optional): Dictionary of maximum predict ranges for each source.
48 |         dict_random_lead_time (Dict, optional): Dictionary of whether to use random lead time for each source.
49 |         dict_hrs_each_step (Dict, optional): Dictionary of hours each step for each source.
50 |         batch_size (int, optional): Batch size.
51 |         num_workers (int, optional): Number of workers.
52 |         pin_memory (bool, optional): Whether to pin memory.
53 |     """
54 | 
55 |     def __init__(
56 |         self,
57 |         dict_root_dirs: Dict,
58 |         dict_start_idx: Dict,
59 |         dict_end_idx: Dict,
60 |         dict_buffer_sizes: Dict,
61 |         dict_in_variables: Dict,
62 |         dict_out_variables: Dict,
63 |         dict_max_predict_ranges: Dict = {"mpi-esm": 28},
64 |         dict_random_lead_time: Dict = {"mpi-esm": True},
65 |         dict_hrs_each_step: Dict = {"mpi-esm": 6},
66 |         batch_size: int = 64,
67 |         num_workers: int = 0,
68 |         pin_memory: bool = False,
69 |     ):
70 |         super().__init__()
71 |         if num_workers > 1:
72 |             raise NotImplementedError(
73 |                 "num_workers > 1 is not supported yet. Performance will also likely degrade with larger num_workers."
74 | ) 75 | # this line allows to access init params with 'self.hparams' attribute 76 | self.save_hyperparameters(logger=False) 77 | 78 | out_variables = {} 79 | for k, list_out in dict_out_variables.items(): 80 | if list_out is not None: 81 | out_variables[k] = list_out 82 | else: 83 | out_variables[k] = dict_in_variables[k] 84 | self.hparams.dict_out_variables = out_variables 85 | 86 | self.dict_lister_trains = { 87 | k: list(dp.iter.FileLister(os.path.join(root_dir, "train"))) for k, root_dir in dict_root_dirs.items() 88 | } 89 | self.train_dataset_args = { 90 | k: { 91 | "max_predict_range": dict_max_predict_ranges[k], 92 | "random_lead_time": dict_random_lead_time[k], 93 | "hrs_each_step": dict_hrs_each_step[k], 94 | } 95 | for k in dict_root_dirs.keys() 96 | } 97 | 98 | self.transforms = self.get_normalize() 99 | self.output_transforms = self.get_normalize(self.hparams.dict_out_variables) 100 | 101 | self.dict_data_train: Optional[Dict] = None 102 | 103 | def get_normalize(self, dict_variables: Optional[Dict] = None): 104 | if dict_variables is None: 105 | dict_variables = self.hparams.dict_in_variables 106 | dict_transforms = {} 107 | for k in dict_variables.keys(): 108 | root_dir = self.hparams.dict_root_dirs[k] 109 | variables = dict_variables[k] 110 | normalize_mean = dict(np.load(os.path.join(root_dir, "normalize_mean.npz"))) 111 | mean = [] 112 | for var in variables: 113 | if var != "total_precipitation": 114 | mean.append(normalize_mean[var]) 115 | else: 116 | mean.append(np.array([0.0])) 117 | normalize_mean = np.concatenate(mean) 118 | normalize_std = dict(np.load(os.path.join(root_dir, "normalize_std.npz"))) 119 | normalize_std = np.concatenate([normalize_std[var] for var in variables]) 120 | dict_transforms[k] = transforms.Normalize(normalize_mean, normalize_std) 121 | return dict_transforms 122 | 123 | def get_lat_lon(self): 124 | # assume different data sources have the same lat and lon coverage 125 | lat = np.load(os.path.join(list(self.hparams.dict_root_dirs.values())[0], "lat.npy")) 126 | lon = np.load(os.path.join(list(self.hparams.dict_root_dirs.values())[0], "lon.npy")) 127 | return lat, lon 128 | 129 | def setup(self, stage: Optional[str] = None): 130 | # load datasets only if they're not loaded already 131 | if not self.dict_data_train: 132 | dict_data_train = {} 133 | for k in self.dict_lister_trains.keys(): 134 | lister_train = self.dict_lister_trains[k] 135 | start_idx = self.hparams.dict_start_idx[k] 136 | end_idx = self.hparams.dict_end_idx[k] 137 | variables = self.hparams.dict_in_variables[k] 138 | out_variables = self.hparams.dict_out_variables[k] 139 | max_predict_range = self.hparams.dict_max_predict_ranges[k] 140 | random_lead_time = self.hparams.dict_random_lead_time[k] 141 | hrs_each_step = self.hparams.dict_hrs_each_step[k] 142 | transforms = self.transforms[k] 143 | output_transforms = self.output_transforms[k] 144 | buffer_size = self.hparams.dict_buffer_sizes[k] 145 | dict_data_train[k] = ShuffleIterableDataset( 146 | IndividualForecastDataIter( 147 | Forecast( 148 | NpyReader( 149 | lister_train, 150 | start_idx=start_idx, 151 | end_idx=end_idx, 152 | variables=variables, 153 | out_variables=out_variables, 154 | shuffle=True, 155 | multi_dataset_training=True, 156 | ), 157 | max_predict_range=max_predict_range, 158 | random_lead_time=random_lead_time, 159 | hrs_each_step=hrs_each_step, 160 | ), 161 | transforms, 162 | output_transforms, 163 | ), 164 | buffer_size, 165 | ) 166 | self.dict_data_train = dict_data_train 167 | 168 | def 
train_dataloader(self): 169 | if not torch.distributed.is_initialized(): 170 | raise NotImplementedError("Only support distributed training") 171 | else: 172 | node_rank = int(os.environ["NODE_RANK"]) 173 | # assert that number of datasets is the same as number of nodes 174 | num_nodes = os.environ.get("NODES", None) 175 | if num_nodes is not None: 176 | num_nodes = int(num_nodes) 177 | assert num_nodes == len(self.dict_data_train.keys()) 178 | 179 | for idx, k in enumerate(self.dict_data_train.keys()): 180 | if idx == node_rank: 181 | data_train = self.dict_data_train[k] 182 | break 183 | 184 | # This assumes that the number of datapoints are going to be the same for all datasets 185 | return DataLoader( 186 | data_train, 187 | batch_size=self.hparams.batch_size, 188 | drop_last=True, 189 | num_workers=self.hparams.num_workers, 190 | pin_memory=self.hparams.pin_memory, 191 | collate_fn=collate_fn, 192 | ) 193 | -------------------------------------------------------------------------------- /src/weatherode/pretrain/dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | import math 5 | import os 6 | import random 7 | 8 | import numpy as np 9 | import torch 10 | from torch.utils.data import IterableDataset 11 | 12 | class NpyReader(IterableDataset): 13 | def __init__( 14 | self, 15 | file_list, 16 | start_idx, 17 | end_idx, 18 | variables, 19 | out_variables, 20 | shuffle: bool = False, 21 | multi_dataset_training=False, 22 | ) -> None: 23 | super().__init__() 24 | start_idx = int(start_idx * len(file_list)) 25 | end_idx = int(end_idx * len(file_list)) 26 | file_list = file_list[start_idx:end_idx] 27 | self.file_list = [f for f in file_list if "climatology" not in f] 28 | self.variables = variables 29 | self.out_variables = out_variables if out_variables is not None else variables 30 | self.shuffle = shuffle 31 | self.multi_dataset_training = multi_dataset_training 32 | 33 | def __iter__(self): 34 | if self.shuffle: 35 | random.shuffle(self.file_list) 36 | worker_info = torch.utils.data.get_worker_info() 37 | if worker_info is None: 38 | iter_start = 0 39 | iter_end = len(self.file_list) 40 | else: 41 | if not torch.distributed.is_initialized(): 42 | rank = 0 43 | world_size = 1 44 | else: 45 | rank = torch.distributed.get_rank() 46 | world_size = torch.distributed.get_world_size() 47 | num_workers_per_ddp = worker_info.num_workers 48 | if self.multi_dataset_training: 49 | num_nodes = 1 50 | # num_nodes = int(os.environ.get("NODES", None)) 51 | num_gpus_per_node = int(world_size / num_nodes) 52 | num_shards = num_workers_per_ddp * num_gpus_per_node 53 | rank = rank % num_gpus_per_node 54 | else: 55 | num_shards = num_workers_per_ddp * world_size 56 | per_worker = int(math.floor(len(self.file_list) / float(num_shards))) 57 | worker_id = rank * num_workers_per_ddp + worker_info.id 58 | iter_start = worker_id * per_worker 59 | iter_end = iter_start + per_worker 60 | 61 | for idx in range(iter_start, iter_end): 62 | path = self.file_list[idx] 63 | data = np.load(path) 64 | yield {k: data[k] for k in self.variables}, self.variables, self.out_variables 65 | 66 | 67 | class Forecast(IterableDataset): 68 | def __init__( 69 | self, dataset: NpyReader, max_predict_range: int = 6, random_lead_time: bool = False, hrs_each_step: int = 1 70 | ) -> None: 71 | super().__init__() 72 | self.dataset = dataset 73 | 74 | self.max_predict_range = max_predict_range 75 | 
self.random_lead_time = random_lead_time 76 | self.hrs_each_step = hrs_each_step 77 | 78 | def __iter__(self): 79 | for data, variables, out_variables in self.dataset: 80 | # [1095, 5, 32, 64] 81 | x = np.concatenate([data[k].astype(np.float32) for k in data.keys()], axis=1) 82 | x = torch.from_numpy(x) 83 | 84 | # [1095, 5, 32, 64] 85 | y = np.concatenate([data[k].astype(np.float32) for k in out_variables], axis=1) 86 | y = torch.from_numpy(y) 87 | 88 | inputs = x[: -self.max_predict_range] # N, C, H, W 89 | 90 | # t-2, t-1 and t 91 | # prev_inputs = torch.stack([inputs[i:i+3] for i in range(inputs.size(0) - 2)], dim=0) 92 | 93 | # 2 for torchcubicspline 94 | # inputs = inputs[2:] 95 | 96 | if self.random_lead_time: 97 | predict_ranges = torch.randint(low=1, high=self.max_predict_range, size=(inputs.shape[0],)) 98 | else: 99 | predict_ranges = torch.ones(inputs.shape[0]).to(torch.long) * self.max_predict_range 100 | # lead_times = self.hrs_each_step * predict_ranges / 100 101 | # lead_times = lead_times.to(inputs.dtype) 102 | 103 | base_index = torch.arange(inputs.shape[0]).unsqueeze(0) 104 | 105 | range_index = torch.arange(1, self.max_predict_range + 1).unsqueeze(1) 106 | 107 | # 2 for torchcubicspline 108 | final_index = base_index + range_index 109 | 110 | outputs = y[final_index].permute(1, 0, 2, 3, 4) 111 | 112 | yield inputs, outputs, predict_ranges.to(inputs.dtype), variables, out_variables 113 | 114 | 115 | class IndividualForecastDataIter(IterableDataset): 116 | def __init__(self, dataset, transforms: torch.nn.Module, output_transforms: torch.nn.Module, region_info = None): 117 | super().__init__() 118 | self.dataset = dataset 119 | self.transforms = transforms 120 | self.output_transforms = output_transforms 121 | self.region_info = region_info 122 | 123 | # self.lat = lat 124 | # self.lon = lon 125 | # self.set_gauss_kernel() 126 | # breakpoint() 127 | # for (inp, prev_inp, out, predict_ranges, variables, out_variables) in self.dataset: 128 | # optimize_vel(inp, prev_inp, self.kernel) 129 | 130 | def __iter__(self): 131 | for (inp, out, predict_ranges, variables, out_variables) in self.dataset: 132 | assert inp.shape[0] == out.shape[0] 133 | for i in range(inp.shape[0]): 134 | if self.region_info is not None: 135 | yield self.transforms(inp[i]), self.output_transforms(out[i]), predict_ranges[i], variables, out_variables, self.region_info 136 | else: 137 | yield self.transforms(inp[i]), self.output_transforms(out[i]), predict_ranges[i], variables, out_variables 138 | 139 | 140 | class ShuffleIterableDataset(IterableDataset): 141 | def __init__(self, dataset, buffer_size: int) -> None: 142 | super().__init__() 143 | assert buffer_size > 0 144 | self.dataset = dataset 145 | self.buffer_size = buffer_size 146 | 147 | def __iter__(self): 148 | buf = [] 149 | for x in self.dataset: 150 | if len(buf) == self.buffer_size: 151 | idx = random.randint(0, self.buffer_size - 1) 152 | yield buf[idx] 153 | buf[idx] = x 154 | else: 155 | buf.append(x) 156 | random.shuffle(buf) 157 | while buf: 158 | yield buf.pop() 159 | -------------------------------------------------------------------------------- /src/weatherode/pretrain/module.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 
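# --- Editor's note -----------------------------------------------------------
# Each training batch produced by the pretraining datamodule's collate_fn (and
# consumed by training_step below) is a 5-tuple; with B samples, V variables,
# an HxW grid (32x64 at 5.625 deg), and a prediction range of R steps it is
# shaped roughly as:
#
#     x, y, lead_times, variables, out_variables = batch
#     # x: [B, V, H, W]           input state at time t
#     # y: [B, R, V, H, W]        targets at t+1 ... t+R
#     # lead_times: [B]           per-sample lead time in steps
#     # variables / out_variables: lists of variable names
#
# (Shapes inferred from Forecast and collate_fn above; treat this as a reading
# aid, not a formal spec.)
# -----------------------------------------------------------------------------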
3 | 4 | from typing import Any 5 | 6 | import torch 7 | from pytorch_lightning import LightningModule 8 | 9 | from weatherode.ode import WeatherODE 10 | from weatherode.utils.lr_scheduler import LinearWarmupCosineAnnealingLR 11 | from weatherode.utils.metrics import lat_weighted_mse 12 | 13 | 14 | class PretrainModule(LightningModule): 15 | """Lightning module for pretraining the WeatherODE model. 16 | 17 | Args: 18 | net (WeatherODE): WeatherODE model. 19 | lr (float, optional): Learning rate. 20 | beta_1 (float, optional): Beta 1 for AdamW. 21 | beta_2 (float, optional): Beta 2 for AdamW. 22 | weight_decay (float, optional): Weight decay for AdamW. 23 | warmup_steps (int, optional): Number of warmup steps. 24 | max_steps (int, optional): Number of total steps. 25 | warmup_start_lr (float, optional): Starting learning rate for warmup. 26 | eta_min (float, optional): Minimum learning rate. 27 | """ 28 | 29 | def __init__( 30 | self, 31 | net: WeatherODE, 32 | lr: float = 5e-4, 33 | beta_1: float = 0.9, 34 | beta_2: float = 0.95, 35 | weight_decay: float = 1e-5, 36 | warmup_steps: int = 10000, 37 | max_steps: int = 200000, 38 | warmup_start_lr: float = 1e-8, 39 | eta_min: float = 1e-8, 40 | ): 41 | super().__init__() 42 | self.save_hyperparameters(logger=False, ignore=["net"]) 43 | self.net = net 44 | 45 | def set_lat_lon(self, lat, lon): 46 | self.lat = lat 47 | self.lon = lon 48 | 49 | def training_step(self, batch: Any, batch_idx: int): 50 | x, y, lead_times, variables, out_variables = batch 51 | 52 | loss_dict, _ = self.net.forward(x, y, lead_times, variables, out_variables, [lat_weighted_mse], lat=self.lat) 53 | loss_dict = loss_dict[0] 54 | for var in loss_dict.keys(): 55 | self.log( 56 | "train/" + var, 57 | loss_dict[var], 58 | on_step=True, 59 | on_epoch=False, 60 | prog_bar=True, 61 | ) 62 | loss = loss_dict["loss"] 63 | 64 | return loss 65 | 66 | def configure_optimizers(self): 67 | decay = [] 68 | no_decay = [] 69 | for name, m in self.named_parameters(): 70 | if "var_embed" in name or "pos_embed" in name or "time_pos_embed" in name: 71 | no_decay.append(m) 72 | else: 73 | decay.append(m) 74 | 75 | optimizer = torch.optim.AdamW( 76 | [ 77 | { 78 | "params": decay, 79 | "lr": self.hparams.lr, 80 | "betas": (self.hparams.beta_1, self.hparams.beta_2), 81 | "weight_decay": self.hparams.weight_decay, 82 | }, 83 | { 84 | "params": no_decay, 85 | "lr": self.hparams.lr, 86 | "betas": (self.hparams.beta_1, self.hparams.beta_2), 87 | "weight_decay": 0, 88 | }, 89 | ] 90 | ) 91 | 92 | lr_scheduler = LinearWarmupCosineAnnealingLR( 93 | optimizer, 94 | self.hparams.warmup_steps, 95 | self.hparams.max_steps, 96 | self.hparams.warmup_start_lr, 97 | self.hparams.eta_min, 98 | ) 99 | scheduler = {"scheduler": lr_scheduler, "interval": "step", "frequency": 1} 100 | 101 | return {"optimizer": optimizer, "lr_scheduler": scheduler} 102 | -------------------------------------------------------------------------------- /src/weatherode/pretrain/train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 
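# --- Editor's note -----------------------------------------------------------
# The datamodules above build their training streams by composing the iterable
# classes from pretrain/dataset.py as
# NpyReader -> Forecast -> IndividualForecastDataIter -> ShuffleIterableDataset.
# A minimal sketch of that wiring (illustration only; the file list path and
# normalization statistics are hypothetical placeholders):
#
#     import torchdata.datapipes as dp
#     from torch.utils.data import DataLoader
#     from torchvision.transforms import transforms
#     from weatherode.pretrain.dataset import (
#         NpyReader, Forecast, IndividualForecastDataIter, ShuffleIterableDataset,
#     )
#     from weatherode.pretrain.datamodule import collate_fn
#
#     files = list(dp.iter.FileLister("/data/era5_np/train"))   # hypothetical path
#     norm = transforms.Normalize([0.0], [1.0])                 # placeholder stats
#     dataset = ShuffleIterableDataset(
#         IndividualForecastDataIter(
#             Forecast(
#                 NpyReader(files, start_idx=0.0, end_idx=1.0,
#                           variables=["2m_temperature"], out_variables=None, shuffle=True),
#                 max_predict_range=6, random_lead_time=False, hrs_each_step=1,
#             ),
#             norm, norm,
#         ),
#         buffer_size=1000,
#     )
#     loader = DataLoader(dataset, batch_size=32, collate_fn=collate_fn)
# -----------------------------------------------------------------------------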
3 | 4 | import os 5 | 6 | from pytorch_lightning.cli import LightningCLI 7 | 8 | from weatherode.pretrain.datamodule import MultiSourceDataModule 9 | from weatherode.pretrain.module import PretrainModule 10 | 11 | 12 | def main(): 13 | # Initialize Lightning with the model and data modules, and instruct it to parse the config yml 14 | cli = LightningCLI( 15 | model_class=PretrainModule, 16 | datamodule_class=MultiSourceDataModule, 17 | seed_everything_default=42, 18 | save_config_overwrite=True, 19 | run=False, 20 | auto_registry=True, 21 | parser_kwargs={"parser_mode": "omegaconf", "error_handler": None}, 22 | ) 23 | os.makedirs(cli.trainer.default_root_dir, exist_ok=True) 24 | 25 | cli.model.set_lat_lon(*cli.datamodule.get_lat_lon()) 26 | 27 | # fit() runs the training 28 | cli.trainer.fit(cli.model, datamodule=cli.datamodule) 29 | 30 | 31 | if __name__ == "__main__": 32 | main() 33 | -------------------------------------------------------------------------------- /src/weatherode/regional_forecast/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | -------------------------------------------------------------------------------- /src/weatherode/regional_forecast/datamodule.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | import os 5 | from typing import Optional 6 | 7 | import numpy as np 8 | import torch 9 | import torchdata.datapipes as dp 10 | from pytorch_lightning import LightningDataModule 11 | from torch.utils.data import DataLoader, IterableDataset 12 | from torchvision.transforms import transforms 13 | 14 | from weatherode.pretrain.dataset import ( 15 | Forecast, 16 | IndividualForecastDataIter, 17 | NpyReader, 18 | ShuffleIterableDataset, 19 | ) 20 | from weatherode.utils.data_utils import get_region_info 21 | 22 | 23 | def collate_fn_regional(batch): 24 | inp = torch.stack([batch[i][0] for i in range(len(batch))]) 25 | out = torch.stack([batch[i][1] for i in range(len(batch))]) 26 | lead_times = torch.stack([batch[i][2] for i in range(len(batch))]) 27 | variables = batch[0][3] 28 | out_variables = batch[0][4] 29 | region_info = batch[0][5] 30 | return ( 31 | inp, 32 | out, 33 | lead_times, 34 | [v for v in variables], 35 | [v for v in out_variables], 36 | region_info, 37 | ) 38 | 39 | 40 | class RegionalForecastDataModule(LightningDataModule): 41 | """DataModule for regional forecast data. 42 | 43 | Args: 44 | root_dir (str): Root directory for sharded data. 45 | variables (list): List of input variables. 46 | buffer_size (int): Buffer size for shuffling. 47 | out_variables (list, optional): List of output variables. 48 | region (str, optional): The name of the region to finetune WeatherODE on. 49 | predict_range (int, optional): Predict range. 50 | hrs_each_step (int, optional): Hours each step. 51 | batch_size (int, optional): Batch size. 52 | num_workers (int, optional): Number of workers. 53 | pin_memory (bool, optional): Whether to pin memory. 
54 | """ 55 | 56 | def __init__( 57 | self, 58 | root_dir, 59 | variables, 60 | buffer_size, 61 | out_variables=None, 62 | region: str = 'NorthAmerica', 63 | predict_range: int = 6, 64 | hrs_each_step: int = 1, 65 | batch_size: int = 64, 66 | num_workers: int = 0, 67 | pin_memory: bool = False, 68 | ): 69 | super().__init__() 70 | 71 | # this line allows to access init params with 'self.hparams' attribute 72 | self.save_hyperparameters(logger=False) 73 | 74 | if isinstance(out_variables, str): 75 | out_variables = [out_variables] 76 | self.hparams.out_variables = out_variables 77 | 78 | self.lister_train = list(dp.iter.FileLister(os.path.join(root_dir, "train"))) 79 | self.lister_val = list(dp.iter.FileLister(os.path.join(root_dir, "val"))) 80 | self.lister_test = list(dp.iter.FileLister(os.path.join(root_dir, "test"))) 81 | 82 | self.transforms = self.get_normalize() 83 | self.output_transforms = self.get_normalize(out_variables) 84 | 85 | self.val_clim = self.get_climatology("val", out_variables) 86 | self.test_clim = self.get_climatology("test", out_variables) 87 | 88 | self.data_train: Optional[IterableDataset] = None 89 | self.data_val: Optional[IterableDataset] = None 90 | self.data_test: Optional[IterableDataset] = None 91 | 92 | def get_normalize(self, variables=None): 93 | if variables is None: 94 | variables = self.hparams.variables 95 | normalize_mean = dict(np.load(os.path.join(self.hparams.root_dir, "normalize_mean.npz"))) 96 | mean = [] 97 | for var in variables: 98 | if var != "total_precipitation": 99 | mean.append(normalize_mean[var]) 100 | else: 101 | mean.append(np.array([0.0])) 102 | normalize_mean = np.concatenate(mean) 103 | normalize_std = dict(np.load(os.path.join(self.hparams.root_dir, "normalize_std.npz"))) 104 | normalize_std = np.concatenate([normalize_std[var] for var in variables]) 105 | return transforms.Normalize(normalize_mean, normalize_std) 106 | 107 | def get_lat_lon(self): 108 | lat = np.load(os.path.join(self.hparams.root_dir, "lat.npy")) 109 | lon = np.load(os.path.join(self.hparams.root_dir, "lon.npy")) 110 | return lat, lon 111 | 112 | def get_climatology(self, partition="val", variables=None): 113 | path = os.path.join(self.hparams.root_dir, partition, "climatology.npz") 114 | clim_dict = np.load(path) 115 | if variables is None: 116 | variables = self.hparams.variables 117 | clim = np.concatenate([clim_dict[var] for var in variables]) 118 | clim = torch.from_numpy(clim) 119 | return clim 120 | 121 | def set_patch_size(self, p): 122 | self.patch_size = p 123 | 124 | def setup(self, stage: Optional[str] = None): 125 | lat, lon = self.get_lat_lon() 126 | region_info = get_region_info(self.hparams.region, lat, lon, self.patch_size) 127 | # load datasets only if they're not loaded already 128 | if not self.data_train and not self.data_val and not self.data_test: 129 | self.data_train = ShuffleIterableDataset( 130 | IndividualForecastDataIter( 131 | Forecast( 132 | NpyReader( 133 | file_list=self.lister_train, 134 | start_idx=0, 135 | end_idx=1, 136 | variables=self.hparams.variables, 137 | out_variables=self.hparams.out_variables, 138 | shuffle=True, 139 | multi_dataset_training=False, 140 | ), 141 | max_predict_range=self.hparams.predict_range, 142 | random_lead_time=False, 143 | hrs_each_step=self.hparams.hrs_each_step, 144 | ), 145 | transforms=self.transforms, 146 | output_transforms=self.output_transforms, 147 | region_info=region_info 148 | ), 149 | buffer_size=self.hparams.buffer_size, 150 | ) 151 | 152 | self.data_val = 
IndividualForecastDataIter( 153 | Forecast( 154 | NpyReader( 155 | file_list=self.lister_val, 156 | start_idx=0, 157 | end_idx=1, 158 | variables=self.hparams.variables, 159 | out_variables=self.hparams.out_variables, 160 | shuffle=False, 161 | multi_dataset_training=False, 162 | ), 163 | max_predict_range=self.hparams.predict_range, 164 | random_lead_time=False, 165 | hrs_each_step=self.hparams.hrs_each_step, 166 | ), 167 | transforms=self.transforms, 168 | output_transforms=self.output_transforms, 169 | region_info=region_info 170 | ) 171 | 172 | self.data_test = IndividualForecastDataIter( 173 | Forecast( 174 | NpyReader( 175 | file_list=self.lister_test, 176 | start_idx=0, 177 | end_idx=1, 178 | variables=self.hparams.variables, 179 | out_variables=self.hparams.out_variables, 180 | shuffle=False, 181 | multi_dataset_training=False, 182 | ), 183 | max_predict_range=self.hparams.predict_range, 184 | random_lead_time=False, 185 | hrs_each_step=self.hparams.hrs_each_step, 186 | ), 187 | transforms=self.transforms, 188 | output_transforms=self.output_transforms, 189 | region_info=region_info 190 | ) 191 | 192 | def train_dataloader(self): 193 | return DataLoader( 194 | self.data_train, 195 | batch_size=self.hparams.batch_size, 196 | drop_last=False, 197 | num_workers=self.hparams.num_workers, 198 | pin_memory=self.hparams.pin_memory, 199 | collate_fn=collate_fn_regional, 200 | ) 201 | 202 | def val_dataloader(self): 203 | return DataLoader( 204 | self.data_val, 205 | batch_size=self.hparams.batch_size, 206 | shuffle=False, 207 | drop_last=False, 208 | num_workers=self.hparams.num_workers, 209 | pin_memory=self.hparams.pin_memory, 210 | collate_fn=collate_fn_regional, 211 | ) 212 | 213 | def test_dataloader(self): 214 | return DataLoader( 215 | self.data_test, 216 | batch_size=self.hparams.batch_size, 217 | shuffle=False, 218 | drop_last=False, 219 | num_workers=self.hparams.num_workers, 220 | pin_memory=self.hparams.pin_memory, 221 | collate_fn=collate_fn_regional, 222 | ) 223 | -------------------------------------------------------------------------------- /src/weatherode/regional_forecast/module.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | # credits: https://github.com/ashleve/lightning-hydra-template/blob/main/src/models/mnist_module.py 5 | from typing import Any 6 | 7 | import torch 8 | from pytorch_lightning import LightningModule 9 | from torchvision.transforms import transforms 10 | 11 | from weatherode.regional_forecast.ode import RegionalWeatherODE 12 | from weatherode.utils.lr_scheduler import LinearWarmupCosineAnnealingLR 13 | from weatherode.utils.metrics import ( 14 | lat_weighted_acc, 15 | lat_weighted_mse, 16 | lat_weighted_mse_val, 17 | lat_weighted_rmse, 18 | lat_weighted_mse_velocity_guess 19 | ) 20 | from weatherode.utils.pos_embed import interpolate_pos_embed 21 | 22 | from tqdm import tqdm 23 | 24 | import matplotlib.pyplot as plt 25 | from mpl_toolkits.basemap import Basemap 26 | import numpy as np 27 | 28 | 29 | class RegionalForecastModule(LightningModule): 30 | """Lightning module for regional forecasting with the WeatherODE model. 31 | 32 | Args: 33 | net (WeatherODE): WeatherODE model. 34 | pretrained_path (str, optional): Path to pre-trained checkpoint. 35 | lr (float, optional): Learning rate. 36 | beta_1 (float, optional): Beta 1 for AdamW. 37 | beta_2 (float, optional): Beta 2 for AdamW. 
38 | weight_decay (float, optional): Weight decay for AdamW. 39 | warmup_epochs (int, optional): Number of warmup epochs. 40 | max_epochs (int, optional): Number of total epochs. 41 | warmup_start_lr (float, optional): Starting learning rate for warmup. 42 | eta_min (float, optional): Minimum learning rate. 43 | """ 44 | 45 | def __init__( 46 | self, 47 | net: RegionalWeatherODE, 48 | pretrained_path: str = "", 49 | lr: float = 5e-4, 50 | ode_lr: float = 5e-5, 51 | beta_1: float = 0.9, 52 | beta_2: float = 0.99, 53 | weight_decay: float = 1e-5, 54 | warmup_epochs: int = 10000, 55 | max_epochs: int = 200000, 56 | warmup_start_lr: float = 1e-8, 57 | eta_min: float = 1e-8, 58 | gradient_clip_val: float = 0.5, 59 | gradient_clip_algorithm: str = "value", 60 | train_noise_only: bool = False, 61 | ): 62 | super().__init__() 63 | self.save_hyperparameters(logger=False, ignore=["net"]) 64 | self.net = net 65 | self.skip_optimization = False 66 | if len(pretrained_path) > 0: 67 | self.load_pretrained_weights(pretrained_path) 68 | 69 | def load_pretrained_weights(self, pretrained_path): 70 | if pretrained_path.startswith("http"): 71 | checkpoint = torch.hub.load_state_dict_from_url(pretrained_path) 72 | else: 73 | checkpoint = torch.load(pretrained_path, map_location=torch.device("cpu")) 74 | 75 | print("Loading pre-trained checkpoint from: %s" % pretrained_path) 76 | checkpoint_model = checkpoint["state_dict"] 77 | # interpolate positional embedding 78 | # interpolate_pos_embed(self.net, checkpoint_model, new_size=self.net.img_size) 79 | 80 | state_dict = self.state_dict() 81 | # if self.net.parallel_patch_embed: 82 | # if "token_embeds.proj_weights" not in checkpoint_model.keys(): 83 | # raise ValueError( 84 | # "Pretrained checkpoint does not have token_embeds.proj_weights for parallel processing. Please convert the checkpoints first or disable parallel patch_embed tokenization." 
85 | # ) 86 | 87 | # checkpoint_keys = list(checkpoint_model.keys()) 88 | for k in list(checkpoint_model.keys()): 89 | if "channel" in k: 90 | checkpoint_model[k.replace("channel", "var")] = checkpoint_model[k] 91 | del checkpoint_model[k] 92 | for k in list(checkpoint_model.keys()): 93 | if k not in state_dict.keys() or checkpoint_model[k].shape != state_dict[k].shape: 94 | print(f"Removing key {k} from pretrained checkpoint") 95 | del checkpoint_model[k] 96 | 97 | # load pre-trained model 98 | msg = self.load_state_dict(checkpoint_model, strict=False) 99 | print(msg) 100 | 101 | def set_denormalization(self, mean, std): 102 | self.denormalization = transforms.Normalize(mean, std) 103 | 104 | def set_lat_lon(self, lat, lon): 105 | self.lat = lat 106 | self.lon = lon 107 | 108 | def set_pred_range(self, r): 109 | self.pred_range = r 110 | 111 | def set_val_clim(self, clim): 112 | self.val_clim = clim 113 | 114 | def set_test_clim(self, clim): 115 | self.test_clim = clim 116 | 117 | def get_patch_size(self): 118 | return self.net.patch_size 119 | 120 | def training_step(self, batch: Any, batch_idx: int): 121 | x, y, predict_ranges, variables, out_variables, region_info = batch 122 | 123 | min_h, max_h = region_info['min_h'], region_info['max_h'] 124 | min_w, max_w = region_info['min_w'], region_info['max_w'] 125 | x = x[:, :, min_h:max_h+1, min_w:max_w+1] 126 | y = y[:, :, :, min_h:max_h+1, min_w:max_w+1] 127 | lat = self.lat[min_h:max_h+1] 128 | lon = self.lon[min_w:max_w+1] 129 | 130 | loss_dict, _ = self.net.forward( 131 | x, y, predict_ranges, variables, out_variables, [lat_weighted_mse_velocity_guess], lat=lat, lon=lon, epoch=self.current_epoch 132 | ) 133 | loss_dict = loss_dict[0] 134 | 135 | # check nan 136 | has_nan = False 137 | for var, loss_value in loss_dict.items(): 138 | if torch.isnan(loss_value).any(): 139 | has_nan = True 140 | break 141 | 142 | # sum nan 143 | has_nan_tensor = torch.tensor(float(has_nan), device=self.device) 144 | torch.distributed.all_reduce(has_nan_tensor, op=torch.distributed.ReduceOp.SUM) 145 | 146 | if has_nan_tensor.item() > 0: 147 | self.log("train/has_nan", True, on_step=True, on_epoch=False, prog_bar=True) 148 | self.skip_optimization = True 149 | return torch.tensor(0., device=self.device, requires_grad=True) 150 | 151 | self.log("train/has_nan", False, on_step=True, on_epoch=False, prog_bar=True) 152 | self.skip_optimization = False 153 | 154 | for var in loss_dict.keys(): 155 | self.log( 156 | "train/" + var, 157 | loss_dict[var], 158 | on_step=True, 159 | on_epoch=False, 160 | prog_bar=True, 161 | ) 162 | loss = loss_dict["loss"] 163 | 164 | return loss 165 | 166 | 167 | def plot_weather_maps(self, preds, out_variables, batch_idx, lat, lon, type="gt", test=False): 168 | """ 169 | Plots weather maps for given predictions and logs them to the experiment logger. 170 | 171 | Args: 172 | preds (torch.Tensor): The predictions tensor of shape [batch_size, 5, 32, 64]. 173 | out_variables (list): List of variable names for the channels. 174 | batch_idx (int): The batch index. 175 | logger: The experiment logger. 176 | lat (np.array): Latitude values. 177 | lon (np.array): Longitude values. 
178 | """ 179 | prefix = "test_images" if test else "val_images" 180 | 181 | batch_size, num_vars, height, width = preds.shape 182 | for var_idx in range(num_vars): 183 | fig, ax = plt.subplots() 184 | m = Basemap(projection='cyl', resolution='c', ax=ax, 185 | llcrnrlat=lat.min(), urcrnrlat=lat.max(), 186 | llcrnrlon=lon.min(), urcrnrlon=lon.max()) 187 | m.drawcoastlines() 188 | m.drawcountries() 189 | m.drawmapboundary() 190 | m.drawparallels(np.arange(-90., 91., 30.), labels=[1, 0, 0, 0]) 191 | m.drawmeridians(np.arange(-180., 181., 60.), labels=[0, 0, 0, 1]) 192 | 193 | data = preds[0, var_idx].cpu().detach().numpy() 194 | 195 | lon_grid, lat_grid = np.meshgrid(lon, lat) 196 | xi, yi = m(lon_grid, lat_grid) 197 | 198 | # Interpolate the data if needed (you can adjust the method and resolution) 199 | data = np.interp(data, (data.min(), data.max()), (0, 1)) 200 | 201 | # Plot the data 202 | cs = m.pcolormesh(xi, yi, data, cmap='RdBu') 203 | fig.colorbar(cs, ax=ax, orientation='vertical', label=out_variables[var_idx]) 204 | 205 | ax.set_title(f"{type}_{out_variables[var_idx]}") 206 | 207 | # Log the figure to the experiment logger 208 | try: 209 | self.logger.experiment.log({f"{prefix}/{out_variables[var_idx]}_{type}_{batch_idx}": wandb.Image(fig)}, step=self.global_step) 210 | except: 211 | self.logger.experiment.add_figure(f"{prefix}/{out_variables[var_idx]}_{type}_{batch_idx}", fig, global_step=self.global_step) 212 | plt.close(fig) 213 | 214 | def validation_step(self, batch: Any, batch_idx: int): 215 | x, y, predict_ranges, variables, out_variables, region_info = batch 216 | 217 | if self.pred_range < 24: 218 | log_postfix = f"{self.pred_range}_hours" 219 | else: 220 | days = int(self.pred_range / 24) 221 | log_postfix = f"{days}_days" 222 | 223 | all_loss_dicts, preds, ode_preds, noise_preds = self.net.evaluate( 224 | x, 225 | y, 226 | predict_ranges, 227 | variables, 228 | out_variables, 229 | transform=self.denormalization, 230 | metrics=[lat_weighted_mse_val, lat_weighted_rmse, lat_weighted_acc], 231 | lat=self.lat, 232 | lon=self.lon, 233 | clim=self.val_clim, 234 | log_postfix=log_postfix, 235 | region_info=region_info, 236 | ) 237 | 238 | loss_dict = {} 239 | for d in all_loss_dicts: 240 | for k in d.keys(): 241 | loss_dict[k] = d[k] 242 | 243 | for var in loss_dict.keys(): 244 | self.log( 245 | "val/" + var, 246 | loss_dict[var], 247 | on_step=False, 248 | on_epoch=True, 249 | prog_bar=False, 250 | sync_dist=True, 251 | ) 252 | return loss_dict 253 | 254 | def test_step(self, batch: Any, batch_idx: int): 255 | x, y, predict_ranges, variables, out_variables, region_info = batch 256 | 257 | if self.pred_range < 24: 258 | log_postfix = f"{self.pred_range}_hours" 259 | else: 260 | days = int(self.pred_range / 24) 261 | log_postfix = f"{days}_days" 262 | 263 | all_loss_dicts, preds, ode_preds, noise_preds = self.net.evaluate( 264 | x, 265 | y, 266 | predict_ranges, 267 | variables, 268 | out_variables, 269 | transform=self.denormalization, 270 | metrics=[lat_weighted_mse_val, lat_weighted_rmse, lat_weighted_acc], 271 | lat=self.lat, 272 | lon=self.lon, 273 | clim=self.test_clim, 274 | log_postfix=log_postfix, 275 | region_info=region_info, 276 | ) 277 | 278 | loss_dict = {} 279 | for d in all_loss_dicts: 280 | for k in d.keys(): 281 | loss_dict[k] = d[k] 282 | 283 | for var in loss_dict.keys(): 284 | self.log( 285 | "test/" + var, 286 | loss_dict[var], 287 | on_step=False, 288 | on_epoch=True, 289 | prog_bar=False, 290 | sync_dist=True, 291 | ) 292 | return loss_dict 293 | 294 | 
def configure_gradient_clipping(self, optimizer, optimizer_idx, gradient_clip_val, gradient_clip_algorithm): 295 | self.clip_gradients(optimizer, gradient_clip_val=gradient_clip_val, gradient_clip_algorithm=gradient_clip_algorithm) 296 | 297 | def configure_optimizers(self): 298 | decay = [] 299 | no_decay = [] 300 | 301 | ode = [] 302 | noise_net = [] 303 | for name, m in self.named_parameters(): 304 | if "var_embed" in name or "pos_embed" in name or "time_pos_embed" in name: 305 | no_decay.append(m) 306 | else: 307 | if "net.model" in name: 308 | ode.append(m) 309 | else: 310 | decay.append(m) 311 | 312 | if self.hparams.train_noise_only: 313 | if "net.noise_model" in name: 314 | noise_net.append(m) 315 | else: 316 | m.requires_grad = False 317 | 318 | if self.hparams.train_noise_only: 319 | optimizer = torch.optim.AdamW( 320 | [ 321 | { 322 | "params": noise_net, 323 | "lr": self.hparams.lr, 324 | "betas": (self.hparams.beta_1, self.hparams.beta_2), 325 | "weight_decay": self.hparams.weight_decay, 326 | } 327 | ] 328 | ) 329 | else: 330 | optimizer = torch.optim.AdamW( 331 | [ 332 | { 333 | "params": ode, 334 | "lr": self.hparams.ode_lr, 335 | "betas": (self.hparams.beta_1, self.hparams.beta_2), 336 | "weight_decay": self.hparams.weight_decay, 337 | }, 338 | { 339 | "params": decay, 340 | "lr": self.hparams.lr, 341 | "betas": (self.hparams.beta_1, self.hparams.beta_2), 342 | "weight_decay": self.hparams.weight_decay, 343 | }, 344 | { 345 | "params": no_decay, 346 | "lr": self.hparams.lr, 347 | "betas": (self.hparams.beta_1, self.hparams.beta_2), 348 | "weight_decay": 0, 349 | }, 350 | ] 351 | ) 352 | 353 | lr_scheduler = LinearWarmupCosineAnnealingLR( 354 | optimizer, 355 | self.hparams.warmup_epochs, 356 | self.hparams.max_epochs, 357 | self.hparams.warmup_start_lr, 358 | self.hparams.eta_min, 359 | ) 360 | scheduler = {"scheduler": lr_scheduler, "interval": "step", "frequency": 1} 361 | 362 | return {"optimizer": optimizer, "lr_scheduler": scheduler} 363 | 364 | def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, on_tpu, using_native_amp, using_lbfgs): 365 | # check loss 366 | if self.skip_optimization: 367 | optimizer_closure() 368 | optimizer.zero_grad() 369 | return 370 | optimizer.step(closure=optimizer_closure) 371 | 372 | -------------------------------------------------------------------------------- /src/weatherode/regional_forecast/ode.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from weatherode.ode import WeatherODE 3 | from torchdiffeq import odeint 4 | import torch.distributed as dist 5 | 6 | class RegionalWeatherODE(WeatherODE): 7 | def __init__( 8 | self, 9 | default_vars, 10 | method, 11 | img_size=[32, 64], 12 | patch_size=2, 13 | layers=[5, 5, 3, 2], 14 | hidden=[512, 128, 64], 15 | depth=4, 16 | use_err=True, 17 | err_type="2D", 18 | err_with_x=False, 19 | err_with_v=False, 20 | err_with_std=False, 21 | drop_rate=0.1, 22 | time_steps=12, 23 | time_interval=0.001, 24 | rtol=1e-9, 25 | atol=1e-11, 26 | predict_list=[6], 27 | gradient_loss=False 28 | ): 29 | super().__init__(default_vars, method, img_size, patch_size, layers, hidden, depth, use_err, err_type, err_with_x, err_with_v, err_with_std, drop_rate, time_steps, time_interval, rtol, atol, predict_list, gradient_loss) 30 | 31 | def forward(self, x, y, predict_range, variables, out_variables, metric, lat, lon, vis_noise=False, epoch=0): 32 | 33 | v_net_input = torch.cat([x, torch.gradient(x, dim=3)[0], torch.gradient(x, 
dim=2)[0]], 1) 34 | 35 | v_output = self.v_net(v_net_input) 36 | vx, vy = v_output[:, :x.shape[1]], v_output[:, x.shape[1]:] 37 | 38 | new_lat = torch.tensor(lat).float().expand(x.shape[3], x.shape[2]).T.to(x.device).expand(x.shape[0], 1, x.shape[2], x.shape[3]) * torch.pi / 180 39 | new_lon = torch.tensor(lon).float().expand(x.shape[2], x.shape[3]).to(x.device).expand(x.shape[0], 1, x.shape[2], x.shape[3]) * torch.pi / 180 40 | 41 | new_lat_lon = torch.cat([new_lat, new_lon], 1) 42 | 43 | cos_lat_map, sin_lat_map = torch.cos(new_lat), torch.sin(new_lat) 44 | cos_lon_map, sin_lon_map = torch.cos(new_lon), torch.sin(new_lon) 45 | 46 | pos_feats = torch.cat([cos_lat_map, cos_lon_map, sin_lat_map, sin_lon_map, sin_lat_map * cos_lon_map, sin_lat_map * sin_lon_map], 1) 47 | 48 | ode_x = torch.cat([x, vx, vy, new_lat_lon, pos_feats], 1) 49 | 50 | new_time_steps = torch.linspace(int(predict_range[0] / self.time_steps), int(predict_range[0]), self.time_steps).float().to(x.device) * self.time_interval 51 | 52 | final_result = odeint(self.pde, ode_x, new_time_steps, method=self.method, rtol=self.rtol, atol=self.atol) 53 | 54 | preds = final_result[:, :, :len(self.default_vars)] 55 | 56 | out_ids = self.get_var_ids(tuple(out_variables), preds.device) 57 | y_ = y.permute(1,0,2,3,4)[torch.linspace(int(predict_range[0] / self.time_steps), int(predict_range[0]), self.time_steps).long() - 1] 58 | 59 | if self._check_for_nan(preds, "ode"): 60 | if metric is None: 61 | loss = None 62 | if vis_noise: 63 | return x.expand(self.time_steps, *x.shape)[:, :, out_ids], x.expand(self.time_steps, *x.shape)[:, :, out_ids], x.expand(self.time_steps, *x.shape)[:, :, out_ids] 64 | else: 65 | preds = preds[:, :, out_ids] 66 | loss = [m(preds, preds, preds, y_, out_variables, lat) for m in metric] 67 | return loss, preds 68 | 69 | if self.use_err: 70 | noise_x = torch.cat([preds, new_lat_lon.expand(preds.shape[0], *new_lat_lon.shape), pos_feats.expand(preds.shape[0], *pos_feats.shape)], 2) 71 | if self.err_with_x: 72 | noise_x = torch.cat([noise_x, x.expand(preds.shape[0], *x.shape)], 2) 73 | if self.err_with_v: 74 | noise_x = torch.cat([noise_x, v_output.expand(preds.shape[0], *v_output.shape)], 2) 75 | 76 | if self.err_type == "2D": 77 | noise_x = noise_x.reshape(-1, *noise_x.shape[2:]) 78 | noise_output = self.noise_model(noise_x) 79 | noise_output = noise_output.view(preds.shape[0], -1, *noise_output.shape[1:]) 80 | elif self.err_type == "2DTime": 81 | noise_x = noise_x.reshape(-1, *noise_x.shape[2:]) 82 | 83 | time_embedding = torch.repeat_interleave(torch.linspace(int(predict_range[0] / self.time_steps), int(predict_range[0]), self.time_steps, device=preds.device), preds.shape[1]) 84 | 85 | noise_output = self.noise_model(noise_x, time_embedding) 86 | 87 | noise_output = noise_output.view(preds.shape[0], -1, *noise_output.shape[1:]) 88 | elif self.err_type == "DiT": 89 | noise_x = noise_x.reshape(-1, *noise_x.shape[2:]) 90 | 91 | time_embedding = torch.repeat_interleave(torch.linspace(int(predict_range[0] / self.time_steps), int(predict_range[0]), self.time_steps, device=preds.device), preds.shape[1]) 92 | 93 | noise_output = self.noise_model(noise_x, time_embedding) 94 | noise_output = noise_output.view(preds.shape[0], -1, *noise_output.shape[1:]) 95 | elif self.err_type == "3D": 96 | noise_output = self.noise_model(noise_x) 97 | elif self.err_type == "2+1D": 98 | noise_output = self.noise_model(noise_x) 99 | 100 | if torch.isnan(noise_output).any(): 101 | print("noise nan \n") 102 | 103 | if 
self._check_for_nan(noise_output, "noise net"): 104 | if metric is None: 105 | loss = None 106 | if vis_noise: 107 | return preds[:, :, out_ids], preds[:, :, out_ids], preds[:, :, out_ids] 108 | else: 109 | loss = [m(preds[:, :, out_ids], preds[:, :, out_ids], preds[:, :, out_ids], y_, out_variables, lat) for m in metric] 110 | return loss, preds 111 | 112 | final_preds = preds + noise_output[:, :, :len(self.default_vars)] 113 | 114 | final_preds = final_preds[:, :, out_ids] 115 | 116 | if metric is None: 117 | # preds = preds[-1] 118 | loss = None 119 | if vis_noise: 120 | return final_preds, preds[:, :, out_ids], noise_output[:, :, :len(self.default_vars)][:, :, out_ids] 121 | else: 122 | loss = [m(final_preds, preds[:, :, out_ids], noise_output[:, :, :len(self.default_vars)][:, :, out_ids], y_, out_variables, lat, gradient_loss=self.gradient_loss, epoch=epoch) for m in metric] 123 | 124 | return loss, final_preds 125 | 126 | def evaluate(self, x, y, predict_range, variables, out_variables, transform, metrics, lat, lon, clim, log_postfix, region_info): 127 | min_h, max_h = region_info['min_h'], region_info['max_h'] 128 | min_w, max_w = region_info['min_w'], region_info['max_w'] 129 | x = x[:, :, min_h:max_h+1, min_w:max_w+1] 130 | y = y[:, :, :, min_h:max_h+1, min_w:max_w+1] 131 | lat = lat[min_h:max_h+1] 132 | lon = lon[min_w:max_w+1] 133 | clim = clim[:, min_h:max_h+1, min_w:max_w+1] 134 | 135 | preds, ode_preds, noise_preds = self.forward(x, y, predict_range, variables, out_variables, metric=None, lat=lat, lon=lon, vis_noise=True) 136 | 137 | ratio = int(predict_range.mean()) // preds.shape[0] 138 | 139 | loss_dict = [] 140 | 141 | for pred_range in self.predict_list: 142 | if pred_range < 24: 143 | log_postfix = f"{pred_range}_hours" 144 | else: 145 | days = pred_range // 24 146 | if pred_range > days * 24: 147 | log_postfix = f"{days}_days_{pred_range - days * 24}_hours" 148 | else: 149 | log_postfix = f"{days}_days" 150 | 151 | steps = pred_range // ratio 152 | 153 | dic_list = [m(preds[steps - 1], y.permute(1,0,2,3,4)[pred_range - 1], transform, out_variables, lat, clim, log_postfix) for m in metrics] 154 | 155 | if pred_range != int(predict_range.mean()): 156 | for dic in dic_list: 157 | dic.pop('w_rmse', None) 158 | 159 | loss_dict += dic_list 160 | 161 | return loss_dict, preds[-1], ode_preds[-1], noise_preds[-1] 162 | -------------------------------------------------------------------------------- /src/weatherode/regional_forecast/train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 
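# --- Editor's note -----------------------------------------------------------
# Like the global script, this entry point builds the de-normalization used at
# evaluation time from the output transform's statistics:
#     mean_denorm, std_denorm = -mean / std, 1 / std
# Applying transforms.Normalize(mean_denorm, std_denorm) to a normalized value
# z = (x - mean) / std gives (z + mean/std) / (1/std) = std * z + mean = x,
# i.e. it recovers the original physical units. A tiny self-contained check
# (values are illustrative only):
#
#     import torch
#     from torchvision.transforms import transforms
#
#     mean, std = torch.tensor([2.0]), torch.tensor([4.0])
#     norm = transforms.Normalize(mean, std)
#     denorm = transforms.Normalize(-mean / std, 1 / std)
#     x = torch.full((1, 3, 3), 10.0)
#     assert torch.allclose(denorm(norm(x)), x)
# -----------------------------------------------------------------------------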
3 | 4 | import os 5 | 6 | from weatherode.regional_forecast.datamodule import RegionalForecastDataModule 7 | from weatherode.regional_forecast.module import RegionalForecastModule 8 | from pytorch_lightning.cli import LightningCLI 9 | 10 | 11 | def main(): 12 | # Initialize Lightning with the model and data modules, and instruct it to parse the config yml 13 | cli = LightningCLI( 14 | model_class=RegionalForecastModule, 15 | datamodule_class=RegionalForecastDataModule, 16 | seed_everything_default=42, 17 | save_config_overwrite=True, 18 | run=False, 19 | auto_registry=True, 20 | parser_kwargs={"parser_mode": "omegaconf", "error_handler": None}, 21 | ) 22 | os.makedirs(cli.trainer.default_root_dir, exist_ok=True) 23 | 24 | cli.datamodule.set_patch_size(cli.model.get_patch_size()) 25 | 26 | normalization = cli.datamodule.output_transforms 27 | mean_norm, std_norm = normalization.mean, normalization.std 28 | mean_denorm, std_denorm = -mean_norm / std_norm, 1 / std_norm 29 | cli.model.set_denormalization(mean_denorm, std_denorm) 30 | cli.model.set_lat_lon(*cli.datamodule.get_lat_lon()) 31 | cli.model.set_pred_range(cli.datamodule.hparams.predict_range) 32 | cli.model.set_val_clim(cli.datamodule.val_clim) 33 | cli.model.set_test_clim(cli.datamodule.test_clim) 34 | 35 | # fit() runs the training 36 | cli.trainer.fit(cli.model, datamodule=cli.datamodule) 37 | 38 | # test the trained model 39 | cli.trainer.test(cli.model, datamodule=cli.datamodule, ckpt_path="best") 40 | 41 | 42 | if __name__ == "__main__": 43 | main() 44 | -------------------------------------------------------------------------------- /src/weatherode/utils/data_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 
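# --- Editor's note -----------------------------------------------------------
# get_region_info() (defined below) converts a named region from BOUNDARIES
# into grid indices and patch ids on the model grid. A hedged usage sketch for
# the 5.625-degree grid (the lat/lon arrays here are illustrative stand-ins for
# the lat.npy / lon.npy files):
#
#     import numpy as np
#
#     lat = np.linspace(-87.1875, 87.1875, 32)   # hypothetical 32-point latitude grid
#     lon = np.linspace(0.0, 354.375, 64)        # hypothetical 64-point longitude grid
#     info = get_region_info("NorthAmerica", lat, lon, patch_size=2)
#     # info["min_h"]:info["max_h"]+1 and info["min_w"]:info["max_w"]+1 give the
#     # crop used by the regional module, and info["patch_ids"] lists the ViT
#     # patches that fall entirely inside the region.
# -----------------------------------------------------------------------------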
3 | 4 | import numpy as np 5 | 6 | NAME_TO_VAR = { 7 | "2m_temperature": "t2m", 8 | "10m_u_component_of_wind": "u10", 9 | "10m_v_component_of_wind": "v10", 10 | "mean_sea_level_pressure": "msl", 11 | "surface_pressure": "sp", 12 | "toa_incident_solar_radiation": "tisr", 13 | "total_precipitation": "tp", 14 | "land_sea_mask": "lsm", 15 | "orography": "orography", 16 | "lattitude": "lat2d", 17 | "geopotential": "z", 18 | "u_component_of_wind": "u", 19 | "v_component_of_wind": "v", 20 | "temperature": "t", 21 | "relative_humidity": "r", 22 | "specific_humidity": "q", 23 | } 24 | 25 | VAR_TO_NAME = {v: k for k, v in NAME_TO_VAR.items()} 26 | 27 | SINGLE_LEVEL_VARS = [ 28 | "2m_temperature", 29 | "10m_u_component_of_wind", 30 | "10m_v_component_of_wind", 31 | "mean_sea_level_pressure", 32 | "surface_pressure", 33 | "toa_incident_solar_radiation", 34 | "total_precipitation", 35 | "land_sea_mask", 36 | "orography", 37 | "lattitude", 38 | ] 39 | PRESSURE_LEVEL_VARS = [ 40 | "geopotential", 41 | "u_component_of_wind", 42 | "v_component_of_wind", 43 | "temperature", 44 | "relative_humidity", 45 | "specific_humidity", 46 | ] 47 | DEFAULT_PRESSURE_LEVELS = [50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000] 48 | 49 | NAME_LEVEL_TO_VAR_LEVEL = {} 50 | 51 | for var in SINGLE_LEVEL_VARS: 52 | NAME_LEVEL_TO_VAR_LEVEL[var] = NAME_TO_VAR[var] 53 | 54 | for var in PRESSURE_LEVEL_VARS: 55 | for l in DEFAULT_PRESSURE_LEVELS: 56 | NAME_LEVEL_TO_VAR_LEVEL[var + "_" + str(l)] = NAME_TO_VAR[var] + "_" + str(l) 57 | 58 | VAR_LEVEL_TO_NAME_LEVEL = {v: k for k, v in NAME_LEVEL_TO_VAR_LEVEL.items()} 59 | 60 | BOUNDARIES = { 61 | 'NorthAmerica': { # 8x14 62 | 'lat_range': (15, 65), 63 | 'lon_range': (220, 300) 64 | }, 65 | 'SouthAmerica': { # 14x10 66 | 'lat_range': (-55, 20), 67 | 'lon_range': (270, 330) 68 | }, 69 | 'Europe': { # 6x8 70 | 'lat_range': (30, 65), 71 | 'lon_range': (0, 40) 72 | }, 73 | 'SouthAsia': { # 10, 14 74 | 'lat_range': (-15, 45), 75 | 'lon_range': (25, 110) 76 | }, 77 | 'EastAsia': { # 10, 12 78 | 'lat_range': (5, 65), 79 | 'lon_range': (70, 150) 80 | }, 81 | 'Australia': { # 10x14 82 | 'lat_range': (-50, 10), 83 | 'lon_range': (100, 180) 84 | }, 85 | 'Global': { # 32, 64 86 | 'lat_range': (-90, 90), 87 | 'lon_range': (0, 360) 88 | } 89 | } 90 | 91 | def get_region_info(region, lat, lon, patch_size): 92 | region = BOUNDARIES[region] 93 | lat_range = region['lat_range'] 94 | lon_range = region['lon_range'] 95 | lat = lat[::-1] # -90 to 90 from south (bottom) to north (top) 96 | h, w = len(lat), len(lon) 97 | lat_matrix = np.expand_dims(lat, axis=1).repeat(w, axis=1) 98 | lon_matrix = np.expand_dims(lon, axis=0).repeat(h, axis=0) 99 | valid_cells = (lat_matrix >= lat_range[0]) & (lat_matrix <= lat_range[1]) & (lon_matrix >= lon_range[0]) & (lon_matrix <= lon_range[1]) 100 | h_ids, w_ids = np.nonzero(valid_cells) 101 | h_from, h_to = h_ids[0], h_ids[-1] 102 | w_from, w_to = w_ids[0], w_ids[-1] 103 | patch_idx = -1 104 | p = patch_size 105 | valid_patch_ids = [] 106 | min_h, max_h = 1e5, -1e5 107 | min_w, max_w = 1e5, -1e5 108 | for i in range(0, h, p): 109 | for j in range(0, w, p): 110 | patch_idx += 1 111 | if (i >= h_from) & (i + p - 1 <= h_to) & (j >= w_from) & (j + p - 1 <= w_to): 112 | valid_patch_ids.append(patch_idx) 113 | min_h = min(min_h, i) 114 | max_h = max(max_h, i + p - 1) 115 | min_w = min(min_w, j) 116 | max_w = max(max_w, j + p - 1) 117 | return { 118 | 'patch_ids': valid_patch_ids, 119 | 'min_h': min_h, 120 | 'max_h': max_h, 121 | 'min_w': min_w, 122 | 'max_w': 
max_w 123 | } -------------------------------------------------------------------------------- /src/weatherode/utils/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | import math 5 | import warnings 6 | from typing import List 7 | 8 | from torch.optim import Optimizer 9 | from torch.optim.lr_scheduler import _LRScheduler 10 | 11 | 12 | class LinearWarmupCosineAnnealingLR(_LRScheduler): 13 | """Sets the learning rate of each parameter group to follow a linear warmup schedule between 14 | warmup_start_lr and base_lr followed by a cosine annealing schedule between base_lr and 15 | eta_min.""" 16 | 17 | def __init__( 18 | self, 19 | optimizer: Optimizer, 20 | warmup_epochs: int, 21 | max_epochs: int, 22 | warmup_start_lr: float = 0.0, 23 | eta_min: float = 0.0, 24 | last_epoch: int = -1, 25 | ) -> None: 26 | """ 27 | Args: 28 | optimizer (Optimizer): Wrapped optimizer. 29 | warmup_epochs (int): Maximum number of iterations for linear warmup 30 | max_epochs (int): Maximum number of iterations 31 | warmup_start_lr (float): Learning rate to start the linear warmup. Default: 0. 32 | eta_min (float): Minimum learning rate. Default: 0. 33 | last_epoch (int): The index of last epoch. Default: -1. 34 | """ 35 | self.warmup_epochs = warmup_epochs 36 | self.max_epochs = max_epochs 37 | self.warmup_start_lr = warmup_start_lr 38 | self.eta_min = eta_min 39 | 40 | super().__init__(optimizer, last_epoch) 41 | 42 | def get_lr(self) -> List[float]: 43 | """Compute learning rate using chainable form of the scheduler.""" 44 | if not self._get_lr_called_within_step: 45 | warnings.warn( 46 | "To get the last learning rate computed by the scheduler, " "please use `get_last_lr()`.", 47 | UserWarning, 48 | ) 49 | 50 | if self.last_epoch == self.warmup_epochs: 51 | return self.base_lrs 52 | if self.last_epoch == 0: 53 | return [self.warmup_start_lr] * len(self.base_lrs) 54 | if self.last_epoch < self.warmup_epochs: 55 | return [ 56 | group["lr"] + (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1) 57 | for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) 58 | ] 59 | if (self.last_epoch - 1 - self.max_epochs) % (2 * (self.max_epochs - self.warmup_epochs)) == 0: 60 | return [ 61 | group["lr"] 62 | + (base_lr - self.eta_min) * (1 - math.cos(math.pi / (self.max_epochs - self.warmup_epochs))) / 2 63 | for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) 64 | ] 65 | 66 | return [ 67 | (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs))) 68 | / ( 69 | 1 70 | + math.cos( 71 | math.pi * (self.last_epoch - self.warmup_epochs - 1) / (self.max_epochs - self.warmup_epochs) 72 | ) 73 | ) 74 | * (group["lr"] - self.eta_min) 75 | + self.eta_min 76 | for group in self.optimizer.param_groups 77 | ] 78 | 79 | def _get_closed_form_lr(self) -> List[float]: 80 | """Called when epoch is passed as a param to the `step` function of the scheduler.""" 81 | if self.last_epoch < self.warmup_epochs: 82 | return [ 83 | self.warmup_start_lr 84 | + self.last_epoch * (base_lr - self.warmup_start_lr) / max(1, self.warmup_epochs - 1) 85 | for base_lr in self.base_lrs 86 | ] 87 | 88 | return [ 89 | self.eta_min 90 | + 0.5 91 | * (base_lr - self.eta_min) 92 | * (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs))) 93 | for base_lr in self.base_lrs 94 | ] 95 | 
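A minimal usage sketch for `LinearWarmupCosineAnnealingLR` above (not part of the repository; the toy model, optimizer, and epoch counts are illustrative placeholders): the learning rate ramps linearly from `warmup_start_lr` to each parameter group's base LR over `warmup_epochs`, then follows a cosine decay down to `eta_min` by `max_epochs`.

```python
import torch

from weatherode.utils.lr_scheduler import LinearWarmupCosineAnnealingLR

# Placeholder model/optimizer purely for demonstration.
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4)

scheduler = LinearWarmupCosineAnnealingLR(
    optimizer,
    warmup_epochs=5,       # linear ramp from warmup_start_lr to the base lr
    max_epochs=50,         # cosine decay from the base lr to eta_min afterwards
    warmup_start_lr=1e-8,
    eta_min=1e-8,
)

for epoch in range(50):
    # ... one epoch of training would run here ...
    optimizer.step()       # standard PyTorch ordering: optimizer first,
    scheduler.step()       # then the scheduler
    print(epoch, scheduler.get_last_lr())
```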
-------------------------------------------------------------------------------- /src/weatherode/utils/metrics.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from scipy import stats 8 | 9 | 10 | def gradient(x): 11 | left = x 12 | right = F.pad(x, [0, 1, 0, 0])[:, :, :, 1:] 13 | top = x 14 | bottom = F.pad(x, [0, 0, 0, 1])[:, :, 1:, :] 15 | dx = right - left 16 | dy = bottom - top 17 | dx[:, :, :, -1] = 0 18 | dy[:, :, -1, :] = 0 19 | return dx, dy 20 | 21 | def compute_gradient_loss(pred, target): 22 | pred_dx, pred_dy = gradient(pred) 23 | target_dx, target_dy = gradient(target) 24 | 25 | return torch.mean(torch.abs(pred_dx - target_dx) + torch.abs(pred_dy - target_dy)) 26 | 27 | 28 | def mse(pred, y, vars, lat=None, mask=None): 29 | """Mean squared error 30 | 31 | Args: 32 | pred: [B, L, V*p*p] 33 | y: [B, V, H, W] 34 | vars: list of variable names 35 | """ 36 | 37 | loss = (pred - y) ** 2 38 | 39 | loss_dict = {} 40 | 41 | with torch.no_grad(): 42 | for i, var in enumerate(vars): 43 | if mask is not None: 44 | loss_dict[var] = (loss[:, i] * mask).sum() / mask.sum() 45 | else: 46 | loss_dict[var] = loss[:, i].mean() 47 | 48 | if mask is not None: 49 | loss_dict["loss"] = (loss.mean(dim=1) * mask).sum() / mask.sum() 50 | else: 51 | loss_dict["loss"] = loss.mean(dim=1).mean() 52 | 53 | return loss_dict 54 | 55 | 56 | def lat_weighted_mse(pred, y, vars, lat, mask=None): 57 | """Latitude weighted mean squared error 58 | 59 | Allows to weight the loss by the cosine of the latitude to account for gridding differences at equator vs. poles. 60 | 61 | Args: 62 | y: [B, V, H, W] 63 | pred: [B, V, H, W] 64 | vars: list of variable names 65 | lat: H 66 | """ 67 | 68 | pred = pred.to(torch.float32) 69 | 70 | error = (pred - y) ** 2 # [N, C, H, W] 71 | 72 | # lattitude weights 73 | w_lat = np.cos(np.deg2rad(lat)) 74 | w_lat = w_lat / w_lat.mean() # (H, ) 75 | w_lat = torch.from_numpy(w_lat).unsqueeze(0).unsqueeze(-1).to(dtype=error.dtype, device=error.device) # (1, H, 1) 76 | 77 | loss_dict = {} 78 | with torch.no_grad(): 79 | for i, var in enumerate(vars): 80 | if mask is not None: 81 | loss_dict[var] = (error[:, i] * w_lat * mask).sum() / mask.sum() 82 | else: 83 | loss_dict[var] = (error[:, i] * w_lat).mean() 84 | 85 | if mask is not None: 86 | loss_dict["loss"] = ((error * w_lat.unsqueeze(1)).mean(dim=1) * mask).sum() / mask.sum() 87 | else: 88 | loss_dict["loss"] = (error * w_lat.unsqueeze(1)).mean(dim=1).mean() 89 | 90 | return loss_dict 91 | 92 | 93 | def lat_weighted_mse_val(pred, y, transform, vars, lat, clim, log_postfix): 94 | """Latitude weighted mean squared error 95 | Args: 96 | y: [B, V, H, W] 97 | pred: [B, V, H, W] 98 | vars: list of variable names 99 | lat: H 100 | """ 101 | pred = pred.to(torch.float32) 102 | 103 | error = (pred - y) ** 2 # [B, V, H, W] 104 | 105 | # lattitude weights 106 | w_lat = np.cos(np.deg2rad(lat)) 107 | w_lat = w_lat / w_lat.mean() # (H, ) 108 | w_lat = torch.from_numpy(w_lat).unsqueeze(0).unsqueeze(-1).to(dtype=error.dtype, device=error.device) # (1, H, 1) 109 | 110 | loss_dict = {} 111 | with torch.no_grad(): 112 | for i, var in enumerate(vars): 113 | loss_dict[f"w_mse_{var}_{log_postfix}"] = (error[:, i] * w_lat).mean() 114 | 115 | loss_dict["w_mse"] = np.mean([loss_dict[k].cpu() for k in loss_dict.keys()]) 116 | 117 | return loss_dict 118 | 119 | 120 | 
def lat_weighted_rmse(pred, y, transform, vars, lat, clim, log_postfix): 121 | """Latitude weighted root mean squared error 122 | 123 | Args: 124 | y: [B, V, H, W] 125 | pred: [B, V, H, W] 126 | vars: list of variable names 127 | lat: H 128 | """ 129 | 130 | pred = transform(pred.to(torch.float32)) 131 | y = transform(y) 132 | 133 | error = (pred - y) ** 2 # [B, V, H, W] 134 | 135 | # lattitude weights 136 | w_lat = np.cos(np.deg2rad(lat)) 137 | w_lat = w_lat / w_lat.mean() # (H, ) 138 | w_lat = torch.from_numpy(w_lat).unsqueeze(0).unsqueeze(-1).to(dtype=error.dtype, device=error.device) 139 | 140 | loss_dict = {} 141 | with torch.no_grad(): 142 | for i, var in enumerate(vars): 143 | loss_dict[f"w_rmse_{var}_{log_postfix}"] = torch.mean( 144 | torch.sqrt(torch.mean(error[:, i] * w_lat, dim=(-2, -1))) 145 | ) 146 | 147 | loss_dict["w_rmse"] = np.mean([loss_dict[k].cpu() for k in loss_dict.keys()]) 148 | 149 | return loss_dict 150 | 151 | 152 | def lat_weighted_acc(pred, y, transform, vars, lat, clim, log_postfix): 153 | """ 154 | y: [B, V, H, W] 155 | pred: [B V, H, W] 156 | vars: list of variable names 157 | lat: H 158 | """ 159 | 160 | pred = transform(pred.to(torch.float32)) 161 | y = transform(y) 162 | 163 | # lattitude weights 164 | w_lat = np.cos(np.deg2rad(lat)) 165 | w_lat = w_lat / w_lat.mean() # (H, ) 166 | w_lat = torch.from_numpy(w_lat).unsqueeze(0).unsqueeze(-1).to(dtype=pred.dtype, device=pred.device) # [1, H, 1] 167 | 168 | # clim = torch.mean(y, dim=(0, 1), keepdim=True) 169 | clim = clim.to(device=y.device).unsqueeze(0) 170 | pred = pred - clim 171 | y = y - clim 172 | loss_dict = {} 173 | 174 | with torch.no_grad(): 175 | for i, var in enumerate(vars): 176 | pred_prime = pred[:, i] - torch.mean(pred[:, i]) 177 | y_prime = y[:, i] - torch.mean(y[:, i]) 178 | loss_dict[f"acc_{var}_{log_postfix}"] = torch.sum(w_lat * pred_prime * y_prime) / torch.sqrt( 179 | torch.sum(w_lat * pred_prime**2) * torch.sum(w_lat * y_prime**2) 180 | ) 181 | 182 | loss_dict["acc"] = np.mean([loss_dict[k].cpu() for k in loss_dict.keys()]) 183 | 184 | return loss_dict 185 | 186 | 187 | def lat_weighted_nrmses(pred, y, transform, vars, lat, clim, log_postfix): 188 | """ 189 | y: [B, V, H, W] 190 | pred: [B V, H, W] 191 | vars: list of variable names 192 | lat: H 193 | """ 194 | 195 | pred = transform(pred.to(torch.float32)) 196 | y = transform(y) 197 | y_normalization = clim 198 | 199 | # lattitude weights 200 | w_lat = np.cos(np.deg2rad(lat)) 201 | w_lat = w_lat / w_lat.mean() # (H, ) 202 | w_lat = torch.from_numpy(w_lat).unsqueeze(-1).to(dtype=y.dtype, device=y.device) # (H, 1) 203 | 204 | loss_dict = {} 205 | with torch.no_grad(): 206 | for i, var in enumerate(vars): 207 | pred_ = pred[:, i] # B, H, W 208 | y_ = y[:, i] # B, H, W 209 | error = (torch.mean(pred_, dim=0) - torch.mean(y_, dim=0)) ** 2 # H, W 210 | error = torch.mean(error * w_lat) 211 | loss_dict[f"w_nrmses_{var}"] = torch.sqrt(error) / y_normalization 212 | 213 | return loss_dict 214 | 215 | 216 | def lat_weighted_nrmseg(pred, y, transform, vars, lat, clim, log_postfix): 217 | """ 218 | y: [B, V, H, W] 219 | pred: [B V, H, W] 220 | vars: list of variable names 221 | lat: H 222 | """ 223 | 224 | pred = transform(pred) 225 | y = transform(y) 226 | y_normalization = clim 227 | 228 | # lattitude weights 229 | w_lat = np.cos(np.deg2rad(lat)) 230 | w_lat = w_lat / w_lat.mean() # (H, ) 231 | w_lat = torch.from_numpy(w_lat).unsqueeze(0).unsqueeze(-1).to(dtype=y.dtype, device=y.device) # (1, H, 1) 232 | 233 | loss_dict = {} 234 | with 
torch.no_grad(): 235 | for i, var in enumerate(vars): 236 | pred_ = pred[:, i] # B, H, W 237 | pred_ = torch.mean(pred_ * w_lat, dim=(-2, -1)) # B 238 | y_ = y[:, i] # B, H, W 239 | y_ = torch.mean(y_ * w_lat, dim=(-2, -1)) # B 240 | error = torch.mean((pred_ - y_) ** 2) 241 | loss_dict[f"w_nrmseg_{var}"] = torch.sqrt(error) / y_normalization 242 | 243 | return loss_dict 244 | 245 | 246 | def lat_weighted_nrmse(pred, y, transform, vars, lat, clim, log_postfix): 247 | """ 248 | y: [B, V, H, W] 249 | pred: [B V, H, W] 250 | vars: list of variable names 251 | lat: H 252 | """ 253 | 254 | nrmses = lat_weighted_nrmses(pred, y, transform, vars, lat, clim, log_postfix) 255 | nrmseg = lat_weighted_nrmseg(pred, y, transform, vars, lat, clim, log_postfix) 256 | loss_dict = {} 257 | for var in vars: 258 | loss_dict[f"w_nrmses_{var}"] = nrmses[f"w_nrmses_{var}"] 259 | loss_dict[f"w_nrmseg_{var}"] = nrmseg[f"w_nrmseg_{var}"] 260 | loss_dict[f"w_nrmse_{var}"] = nrmses[f"w_nrmses_{var}"] + 5 * nrmseg[f"w_nrmseg_{var}"] 261 | return loss_dict 262 | 263 | 264 | def remove_nans(pred: torch.Tensor, gt: torch.Tensor): 265 | # pred and gt are two flattened arrays 266 | pred_nan_ids = torch.isnan(pred) | torch.isinf(pred) 267 | pred = pred[~pred_nan_ids] 268 | gt = gt[~pred_nan_ids] 269 | 270 | gt_nan_ids = torch.isnan(gt) | torch.isinf(gt) 271 | pred = pred[~gt_nan_ids] 272 | gt = gt[~gt_nan_ids] 273 | 274 | return pred, gt 275 | 276 | 277 | def pearson(pred, y, transform, vars, lat, log_steps, log_days, clim): 278 | """ 279 | y: [N, T, 3, H, W] 280 | pred: [N, T, 3, H, W] 281 | vars: list of variable names 282 | lat: H 283 | """ 284 | 285 | pred = transform(pred) 286 | y = transform(y) 287 | 288 | loss_dict = {} 289 | with torch.no_grad(): 290 | for i, var in enumerate(vars): 291 | for day, step in zip(log_days, log_steps): 292 | pred_, y_ = pred[:, step - 1, i].flatten(), y[:, step - 1, i].flatten() 293 | pred_, y_ = remove_nans(pred_, y_) 294 | loss_dict[f"pearsonr_{var}_day_{day}"] = stats.pearsonr(pred_.cpu().numpy(), y_.cpu().numpy())[0] 295 | 296 | loss_dict["pearsonr"] = np.mean([loss_dict[k] for k in loss_dict.keys()]) 297 | 298 | return loss_dict 299 | 300 | 301 | def lat_weighted_mean_bias(pred, y, transform, vars, lat, log_steps, log_days, clim): 302 | """ 303 | y: [N, T, 3, H, W] 304 | pred: [N, T, 3, H, W] 305 | vars: list of variable names 306 | lat: H 307 | """ 308 | 309 | pred = transform(pred) 310 | y = transform(y) 311 | 312 | # lattitude weights 313 | w_lat = np.cos(np.deg2rad(lat)) 314 | w_lat = w_lat / w_lat.mean() # (H, ) 315 | w_lat = torch.from_numpy(w_lat).unsqueeze(0).unsqueeze(-1).to(dtype=pred.dtype, device=pred.device) # [1, H, 1] 316 | 317 | loss_dict = {} 318 | with torch.no_grad(): 319 | for i, var in enumerate(vars): 320 | for day, step in zip(log_days, log_steps): 321 | pred_, y_ = pred[:, step - 1, i].flatten(), y[:, step - 1, i].flatten() 322 | pred_, y_ = remove_nans(pred_, y_) 323 | loss_dict[f"mean_bias_{var}_day_{day}"] = pred_.mean() - y_.mean() 324 | 325 | # pred_mean = torch.mean(w_lat * pred[:, step - 1, i]) 326 | # y_mean = torch.mean(w_lat * y[:, step - 1, i]) 327 | # loss_dict[f"mean_bias_{var}_day_{day}"] = y_mean - pred_mean 328 | 329 | loss_dict["mean_bias"] = np.mean([loss_dict[k].cpu() for k in loss_dict.keys()]) 330 | 331 | return loss_dict 332 | 333 | 334 | # def lat_weighted_mse_velocity_guess(mean, std, y, guess_delta_x, delta_x, vars, lat, mask=None): 335 | # normal_lkl = torch.distributions.normal.Normal(mean, 1e-3 + std) 336 | 337 | # lkl = 
-normal_lkl.log_prob(y) 338 | 339 | # loss_val = lkl.mean() + 0.0001 * (std ** 2).sum() + 0.0001 * ((guess_delta_x - delta_x) ** 2).sum() 340 | 341 | # return {"loss": loss_val} 342 | 343 | 344 | def lat_weighted_mse_velocity_guess(pred, ode_pred, noise_pred, y, vars, lat, mask=None, gradient_loss=False, epoch=0): 345 | """Latitude weighted mean squared error 346 | 347 | Allows to weight the loss by the cosine of the latitude to account for gridding differences at equator vs. poles. 348 | 349 | Args: 350 | y: [B, V, H, W] 351 | pred: [B, V, H, W] 352 | vars: list of variable names 353 | lat: H 354 | """ 355 | 356 | if gradient_loss: 357 | if epoch < 2: 358 | error = (ode_pred - y) ** 2 359 | else: 360 | error = (noise_pred - (y - ode_pred.detach())) ** 2 361 | else: 362 | error = (pred - y) ** 2 363 | 364 | # ============================= time dimension weights ============================= 365 | # t = error.size(0) 366 | # indices = torch.arange(t, dtype=torch.float32) 367 | # weights = torch.sigmoid((indices - (t-1) / 2) / ((t-1) / 8)).to(error.device).view(t, 1, 1, 1, 1) 368 | # error *= weights 369 | # =================================================================================== 370 | 371 | error = error.reshape(-1, error.shape[2], error.shape[3], error.shape[4]) 372 | 373 | # lattitude weights 374 | w_lat = np.cos(np.deg2rad(lat)) 375 | w_lat = w_lat / w_lat.mean() # (H, ) 376 | w_lat = torch.from_numpy(w_lat).unsqueeze(0).unsqueeze(-1).to(dtype=error.dtype, device=error.device) # (1, H, 1) 377 | 378 | loss_dict = {} 379 | with torch.no_grad(): 380 | for i, var in enumerate(vars): 381 | if mask is not None: 382 | loss_dict[var] = (error[:, i] * w_lat * mask).sum() / mask.sum() 383 | else: 384 | loss_dict[var] = (error[:, i] * w_lat).mean() 385 | 386 | # loss_dict["velocity_guess"] = ((guess_delta_x - delta_x) ** 2).mean() 387 | 388 | # loss_dict["variable_loss"] = (error * w_lat.unsqueeze(1)).mean(dim=1).mean() 389 | 390 | if mask is not None: 391 | loss_dict["loss"] = ((error * w_lat.unsqueeze(1)).mean(dim=1) * mask).sum() / mask.sum() # + ((guess_delta_x - delta_x) ** 2).mean() 392 | else: 393 | loss_dict["loss"] = (error * w_lat.unsqueeze(1)).mean(dim=1).mean() # + ((guess_delta_x - delta_x) ** 2).mean() 394 | 395 | return loss_dict 396 | 397 | -------------------------------------------------------------------------------- /src/weatherode/utils/pos_embed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # -------------------------------------------------------- 7 | # Position embedding utils 8 | # -------------------------------------------------------- 9 | 10 | 11 | import numpy as np 12 | import torch 13 | 14 | 15 | # -------------------------------------------------------- 16 | # 2D sine-cosine position embedding 17 | # References: 18 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py 19 | # MoCo v3: https://github.com/facebookresearch/moco-v3 20 | # -------------------------------------------------------- 21 | def get_2d_sincos_pos_embed(embed_dim, grid_size_h, grid_size_w, cls_token=False): 22 | """ 23 | grid_size: int of the grid height and width 24 | return: 25 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 26 | """ 27 | grid_h = np.arange(grid_size_h, dtype=np.float32) 28 | grid_w = np.arange(grid_size_w, dtype=np.float32) 29 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 30 | grid = np.stack(grid, axis=0) 31 | 32 | grid = grid.reshape([2, 1, grid_size_h, grid_size_w]) 33 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 34 | if cls_token: 35 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 36 | return pos_embed 37 | 38 | 39 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 40 | assert embed_dim % 2 == 0 41 | 42 | # use half of dimensions to encode grid_h 43 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 44 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 45 | 46 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 47 | return emb 48 | 49 | 50 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 51 | """ 52 | embed_dim: output dimension for each position 53 | pos: a list of positions to be encoded: size (M,) 54 | out: (M, D) 55 | """ 56 | assert embed_dim % 2 == 0 57 | omega = np.arange(embed_dim // 2, dtype=np.float) 58 | omega /= embed_dim / 2.0 59 | omega = 1.0 / 10000**omega # (D/2,) 60 | 61 | pos = pos.reshape(-1) # (M,) 62 | out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product 63 | 64 | emb_sin = np.sin(out) # (M, D/2) 65 | emb_cos = np.cos(out) # (M, D/2) 66 | 67 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 68 | return emb 69 | 70 | 71 | # -------------------------------------------------------- 72 | # Interpolate position embeddings for high-resolution 73 | # References: 74 | # DeiT: https://github.com/facebookresearch/deit 75 | # -------------------------------------------------------- 76 | def interpolate_pos_embed(model, checkpoint_model, new_size=(64, 128)): 77 | if "net.pos_embed" in checkpoint_model: 78 | pos_embed_checkpoint = checkpoint_model["net.pos_embed"] 79 | embedding_size = pos_embed_checkpoint.shape[-1] 80 | orig_num_patches = pos_embed_checkpoint.shape[-2] 81 | patch_size = model.patch_size 82 | w_h_ratio = 2 83 | orig_h = int((orig_num_patches // w_h_ratio) ** 0.5) 84 | orig_w = w_h_ratio * orig_h 85 | orig_size = (orig_h, orig_w) 86 | new_size = (new_size[0] // patch_size, new_size[1] // patch_size) 87 | # print (orig_size) 88 | # print (new_size) 89 | if orig_size[0] != new_size[0]: 90 | print("Interpolate PEs from %dx%d to %dx%d" % (orig_size[0], orig_size[1], new_size[0], new_size[1])) 91 | pos_tokens = pos_embed_checkpoint.reshape(-1, orig_size[0], orig_size[1], embedding_size).permute( 92 | 0, 3, 1, 2 93 | ) 94 | new_pos_tokens = torch.nn.functional.interpolate( 95 | 
pos_tokens, size=(new_size[0], new_size[1]), mode="bicubic", align_corners=False 96 | ) 97 | new_pos_tokens = new_pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 98 | checkpoint_model["net.pos_embed"] = new_pos_tokens 99 | 100 | 101 | def interpolate_channel_embed(checkpoint_model, new_len): 102 | if "net.channel_embed" in checkpoint_model: 103 | channel_embed_checkpoint = checkpoint_model["net.channel_embed"] 104 | old_len = channel_embed_checkpoint.shape[1] 105 | if new_len <= old_len: 106 | checkpoint_model["net.channel_embed"] = channel_embed_checkpoint[:, :new_len] 107 | --------------------------------------------------------------------------------
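A small sanity-check sketch for the 2D sine-cosine embedding defined in `pos_embed.py` above (not part of the repository; the embedding width and grid shape are illustrative, and it assumes a NumPy version in which the legacy `np.float` alias used by `get_1d_sincos_pos_embed_from_grid` still exists, i.e. older than 1.24):

```python
import torch

from weatherode.utils.pos_embed import get_2d_sincos_pos_embed

# Illustrative sizes: a 768-dim embedding over a 32 x 64 token grid.
# embed_dim must be divisible by 4 (half for the h axis, half for the w axis,
# and each half is split again into sin/cos pairs).
embed_dim, grid_h, grid_w = 768, 32, 64

pos_embed = get_2d_sincos_pos_embed(embed_dim, grid_h, grid_w, cls_token=False)
print(pos_embed.shape)  # (32 * 64, 768) = (2048, 768)

# Models typically keep this as a fixed (non-trainable) tensor with a batch dim.
pos_embed_t = torch.from_numpy(pos_embed).float().unsqueeze(0)  # (1, 2048, 768)
```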
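The latitude-weighted metrics in `metrics.py` above all share the same weighting: each grid row is scaled by cos(latitude), normalized so the weights average to 1, which down-weights the over-sampled polar rows of an equiangular grid. A tiny sketch (not from the repository) of calling `lat_weighted_rmse` outside a Lightning module; the tensors, variable names, and identity transform are placeholders for the de-normalization transform and real model output:

```python
import numpy as np
import torch

from weatherode.utils.metrics import lat_weighted_rmse

B, V, H, W = 2, 3, 32, 64
pred = torch.randn(B, V, H, W)            # stand-in model output
y = torch.randn(B, V, H, W)               # stand-in ground truth
lat = np.linspace(-90, 90, H)             # one latitude value per grid row
identity = lambda x: x                    # stand-in for the de-normalization transform
clim = None                               # accepted but unused by lat_weighted_rmse
vars_ = ["2m_temperature", "10m_u_component_of_wind", "10m_v_component_of_wind"]

scores = lat_weighted_rmse(pred, y, identity, vars_, lat, clim, log_postfix="6h")
print(scores["w_rmse"], scores["w_rmse_2m_temperature_6h"])
```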