├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── analysis └── wandb_api.py ├── climart ├── META_INFO.json ├── __init__.py ├── data_loading │ ├── __init__.py │ ├── constants.py │ ├── data_variables.py │ └── h5_dataset.py ├── data_transform │ ├── __init__.py │ ├── normalization.py │ └── transforms.py ├── datamodules │ ├── __init__.py │ └── pl_climart_datamodule.py ├── interface.py ├── models │ ├── CNNs │ │ ├── CNN.py │ │ └── __init__.py │ ├── GraphNet │ │ ├── __init__.py │ │ ├── constants.py │ │ ├── graph_network.py │ │ └── graph_network_block.py │ ├── MLP.py │ ├── __init__.py │ ├── base_model.py │ ├── baseline.py │ └── modules │ │ ├── __init__.py │ │ ├── additional_layers.py │ │ └── mlp.py ├── train.py └── utils │ ├── __init__.py │ ├── callbacks.py │ ├── config_utils.py │ ├── evaluation.py │ ├── naming.py │ ├── optimization.py │ ├── plotting.py │ ├── postprocessing.py │ ├── utils.py │ └── wandb_callbacks.py ├── configs ├── callbacks │ ├── default.yaml │ ├── none.yaml │ └── wandb.yaml ├── experiment │ ├── example.yaml │ └── reproduce_paper2021_cnn.yaml ├── hparams_search │ └── mnist_optuna.yaml ├── input_transform │ ├── flatten.yaml │ ├── graphnet_level_nodes.yaml │ ├── none.yaml │ └── repeat_global_vars.yaml ├── local │ └── .gitkeep ├── logger │ ├── comet.yaml │ ├── csv.yaml │ ├── many_loggers.yaml │ ├── mlflow.yaml │ ├── neptune.yaml │ ├── none.yaml │ ├── tensorboard.yaml │ └── wandb.yaml ├── main_config.yaml ├── mode │ ├── debug.yaml │ ├── default.yaml │ └── exp.yaml ├── model │ ├── cnn.yaml │ ├── graphnet.yaml │ └── mlp.yaml ├── optimizer │ ├── adam.yaml │ ├── adamw.yaml │ └── sgd.yaml ├── trainer │ ├── ddp.yaml │ ├── debug.yaml │ └── default.yaml └── transform │ └── default.yaml ├── download_climart.sh ├── download_data_subset.sh ├── env.yml ├── images └── variable_table.png ├── notebooks └── 2022-06-06-get-predictions-pl.ipynb ├── run.py ├── setup.cfg ├── setup.py └── tests ├── test_utils.py └── test_variables.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .venv 106 | env/ 107 | venv/ 108 | ENV/ 109 | env.bak/ 110 | venv.bak/ 111 | 112 | # Spyder project settings 113 | .spyderproject 114 | .spyproject 115 | 116 | # Rope project settings 117 | .ropeproject 118 | 119 | # mkdocs documentation 120 | /site 121 | 122 | # mypy 123 | .mypy_cache/ 124 | .dmypy.json 125 | dmypy.json 126 | 127 | # Pyre type checker 128 | .pyre/ 129 | 130 | ### VisualStudioCode 131 | .vscode/* 132 | !.vscode/settings.json 133 | !.vscode/tasks.json 134 | !.vscode/launch.json 135 | !.vscode/extensions.json 136 | *.code-workspace 137 | **/.vscode 138 | 139 | # JetBrains 140 | .idea/ 141 | 142 | # Lightning-Hydra-Template 143 | configs/local/default.yaml 144 | data/ 145 | logs/ 146 | wandb/ 147 | .env 148 | .autoenv 149 | 150 | out/* 151 | outputs/* 152 | *.out 153 | *.txt 154 | *__pycache__* 155 | .idea 156 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v3.4.0 4 | hooks: 5 | # list of supported hooks: https://pre-commit.com/hooks.html 6 | - id: trailing-whitespace 7 | - id: end-of-file-fixer 8 | - id: check-yaml 9 | - id: check-added-large-files 10 | - id: debug-statements 11 | - id: detect-private-key 12 | 13 | # python code formatting 14 | - repo: https://github.com/psf/black 15 | rev: 20.8b1 16 | hooks: 17 | - id: black 18 | args: [--line-length, "99"] 19 | 20 | # python import sorting 21 | - repo: https://github.com/PyCQA/isort 22 | rev: 5.8.0 23 | hooks: 24 | - id: isort 25 | args: ["--profile", "black", "--filter-files"] 26 | 27 | # yaml formatting 28 | - repo: https://github.com/pre-commit/mirrors-prettier 29 | rev: v2.3.0 30 | hooks: 31 | - id: prettier 32 | types: [yaml] 33 | 34 | # python code analysis 35 | - repo: https://github.com/PyCQA/flake8 36 | rev: 3.9.2 37 | hooks: 38 | - id: flake8 39 | -------------------------------------------------------------------------------- /README.md: 
-------------------------------------------------------------------------------- 1 | # ***ClimART*** - A Benchmark Dataset for Emulating Atmospheric Radiative Transfer in Weather and Climate Models 2 | Python 3 | PyTorch 4 | Lightning 5 | Config: hydra 6 | ![CC BY 4.0][cc-by-image] 7 | 8 | [cc-by-image]: https://i.creativecommons.org/l/by/4.0/88x31.png 9 | [cc-by-shield]: https://img.shields.io/badge/License-CC%20BY%204.0-lightgrey.svg 10 | 11 | ## Official PyTorch Implementation 12 | 13 | ### Using deep learning to optimise radiative transfer calculations. 14 | 15 | Our NeurIPS 2021 Datasets Track paper: https://arxiv.org/abs/2111.14671 16 | 17 | Abstract: *Numerical simulations of Earth's weather and climate require substantial amounts of computation. This has led to a growing interest in replacing subroutines that explicitly compute physical processes with approximate machine learning (ML) methods that are fast at inference time. Within weather and climate models, atmospheric radiative transfer (RT) calculations are especially expensive. This has made them a popular target for neural network-based emulators. However, prior work is hard to compare due to the lack of a comprehensive dataset and standardized best practices for ML benchmarking. To fill this gap, we build a large dataset, ClimART, with more than **10 million** samples from present, pre-industrial, and future climate conditions, based on the Canadian Earth System Model. 18 | ClimART poses several methodological challenges for the ML community, such as multiple out-of-distribution test sets, underlying domain physics, and a trade-off between accuracy and inference speed. We also present several novel baselines that indicate shortcomings of datasets and network architectures used in prior work.* 19 | 20 | **Contact:** Venkatesh Ramesh [(venka97 at gmail)](mailto:venka97@gmail.com) or Salva Rühling Cachay [(salvaruehling at gmail)](mailto:salvaruehling@gmail.com).
21 | 22 | ## Overview: 23 | 24 | * ``climart/``: Package with the main code, baselines, and ML training logic. 25 | * ``analysis/``: Scripts to create visualizations of the results (requires logging). 26 | * ``configs/``: YAML configuration files for Hydra that define the (hyper-)parameters in a modular way. 27 | 28 | ## Getting Started 29 |

30 |

Requirements 31 |

32 |

38 |
39 | 40 |

41 |

Downloading the ClimART Dataset 42 |

43 | By default, only a subset of ClimART is downloaded. 44 | To download the train/val/test years you want, please change the loop in ``download_climart.sh`` appropriately. 45 | To download the whole ClimART dataset, you can simply run 46 | 47 | sudo bash download_climart.sh 48 |

49 |
50 | 51 | **Note:** If you have issues with downloading the data, please let us know so that we can help you. 52 | 53 | conda env create -f env.yml # create new environment with all dependencies 54 | conda activate climart # activate the environment called 'climart' 55 | sudo bash download_data_subset.sh # download the dataset (or a subset of it, see above) 56 | python run.py trainer.gpus=0 datamodule.train_years="2000" # train an MLP emulator on the year 2000 57 | 58 | ## Data Structure 59 | 60 | To avoid storage redundancy, we store a single input array for both pristine- and clear-sky conditions. The dimensions of ClimART’s input arrays are: 61 | N x 82 for the globals, N x 49 x Dlay for the layers, and N x 50 x 4 for the levels, 66 | 67 | where N is the data dimension (i.e. the number of examples of a specific year or, during training, of a batch), 68 | 49 and 50 are the number of layers and levels in a column, respectively, and Dlay, 4, and 82 are the number of features/channels for layers, levels, and globals, respectively. 69 | 70 | For pristine-sky conditions Dlay = 14, while for clear-sky conditions Dlay = 45, since the latter contains extra aerosol-related variables. The array for pristine-sky conditions can be easily accessed by slicing the first 14 features out of the stored array, e.g.: 71 | ``` pristine_array = layers_array[:, :, : 14] ```. This is automatically done for you when you set the atmospheric 72 | condition type via ```datamodule.exp_type=pristine``` or ```datamodule.exp_type=clear_sky```. A small sketch illustrating these array shapes is included at the top of the Tips section below. 73 | 74 | 75 | ## Baselines 76 | 77 | To reproduce our paper results (for seed = 7), you may choose any of our pre-defined configs in the 78 | [configs/model](configs/model) folder and train it as follows: 79 | 80 | ``` 81 | # You can replace mlp with "graphnet", "gcn", or "cnn" to run a different ML model 82 | # To train on the CPU, choose trainer.gpus=0 83 | # Specify the directory where the ClimART data is saved with datamodule.data_dir="" 84 | # Test on the OOD subsets by setting arg datamodule.{test_ood_historic, test_ood_1991, test_ood_future}=True 85 | python run.py seed=7 model=mlp trainer.gpus=1 86 | ``` 87 | 88 | To reproduce the exact CNN model used in the paper, you can use the following command: 89 | ``` 90 | python run.py experiment=reproduce_paper2021_cnn seed=7 # feel free to run for more/other seeds 91 | ``` 92 | Note: You can also take a look at 93 | [this WandB report](https://wandb.ai/salv47/ClimART-public-runs/reports/ClimART-paper-CNN-runs--VmlldzozMDUyOTUy) 94 | which shows the results of three runs of the CNN model from the paper. 95 | 96 | ### Inference 97 | Check out [this notebook](notebooks/2022-06-06-get-predictions-pl.ipynb) for simple code on how to extract the predictions 98 | for each target variable from a trained model (for arbitrary years of the ClimART dataset). 99 | 100 | ## Tips 101 | 102 |
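**Working with the raw arrays:** The following minimal sketch illustrates the array shapes described in the Data Structure section above. It uses synthetic NumPy data only and does not read the actual HDF5 files; all variable names here are purely illustrative.

```python
import numpy as np

N = 8                                   # number of examples (e.g., one batch)
layers = np.random.rand(N, 49, 45)      # clear-sky layer array:      N x 49 x Dlay (Dlay = 45)
levels = np.random.rand(N, 50, 4)       # level array:                N x 50 x 4
globals_ = np.random.rand(N, 82)        # global (per-column) array:  N x 82

# Pristine-sky inputs are the first 14 layer features of the stored clear-sky array:
pristine_layers = layers[:, :, :14]     # -> shape (N, 49, 14)
print(layers.shape, levels.shape, globals_.shape, pristine_layers.shape)
```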

103 |

Reproducibility & Data Generation code 104 |

105 | To best reproduce our baselines and experiments and/or look into how the ClimART dataset was created/designed, 106 | have a look at our `research_code` branch. It uses pure PyTorch and has a less polished interface and codebase 107 | than our main branch -- if you have any questions, let us know! 108 |

109 | 110 |

111 |

Testing on OOD data subsets 112 |

113 | By default, tests run on the main test dataset only (2007-14). To test on the 114 | historic, future, or anomaly test subsets, you need to pass/change the arg 115 | datamodule.test_ood_historic=True (and/or test_ood_future=True, test_ood_1991=True), 116 | in addition to downloading those data files, e.g. via the download_climart.sh script. 117 | 118 |
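If you prefer to configure this in code rather than on the command line, the `ClimartDataModule` exposes the same flags directly. Below is a minimal sketch, assuming the corresponding data files have already been downloaded; the argument values are illustrative (see the class docstring in `climart/datamodules/pl_climart_datamodule.py` for the accepted options):

```python
from climart.datamodules.pl_climart_datamodule import ClimartDataModule

dm = ClimartDataModule(
    exp_type="pristine",           # or "clear_sky"
    target_type="shortwave",       # or "longwave"
    target_variable="fluxes",      # or "heating-rate" (per the class docstring)
    # data_dir="path/to/ClimART_DATA",  # optional; defaults to the package's DATA_DIR
    test_main_dataset=True,        # main 2007-14 test set
    test_ood_1991=True,            # anomaly year (Mt. Pinatubo eruption)
    test_ood_historic=True,        # 1850-52
    test_ood_future=True,          # 2097-99
)
dm.setup(stage="test")             # builds one dataset per enabled test subset
print(dm.test_set_names)           # e.g. ['2007', ..., '2014', 'OOD', 'historic', 'future']
```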

119 | 120 |

121 |

Overriding nested Hydra config groups 122 |

123 | Nested config groups need to be overridden with the '@' notation rather than with a dot, since a dotted override would otherwise be interpreted as a plain string value. 124 | For example, if you want to change the optimizer of the model you want to train, you should run: 125 | python run.py model=graphnet optimizer@model.optimizer=SGD 126 |
127 |

128 | 129 |

130 |

Local configurations 131 |

132 | You can easily use a local config file (that, e.g., overrides data paths, working dir, etc.) by putting such a YAML config 133 | in the configs/local subdirectory (Hydra searches for and, by default, uses the file configs/local/default.yaml, if it exists). 134 |

135 | 136 |

137 |

Wandb 138 |

139 | If you use Wandb, make sure to select the "Group first prefix" option in the panel settings of the web app. 140 | This will make it easier to browse through the logged metrics. 141 |
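If you want to analyze your logged runs programmatically, the helper functions in ``analysis/wandb_api.py`` can pull them into a pandas DataFrame. A minimal sketch, assuming you are logged into W&B, have access to the project that is hard-coded in that module, and run this from the repository root:

```python
from analysis.wandb_api import get_runs_df

# All finished runs of the pristine-sky project, with their configs and summary metrics
df = get_runs_df(run_pre_filters="has_finished", exp_type="pristine")

# The 5 runs with the lowest shortwave heating-rate test RMSE
top5 = get_runs_df(run_pre_filters="has_finished", run_post_filters="top5", exp_type="pristine")
print(top5[["name", "model", "test/hrsc/rmse"]])
```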

142 | 143 |

144 |

Credits & Resources 145 |

146 | The following template was extremely useful for getting started with the PL+Hydra implementation: 147 | [ashleve/lightning-hydra-template](https://github.com/ashleve/lightning-hydra-template) 148 |

149 | 150 | 151 | 152 | ## License: 153 | This work is made available under [Attribution 4.0 International (CC BY 4.0)](https://creativecommons.org/licenses/by/4.0/legalcode) license. ![CC BY 4.0][cc-by-shield] 154 | 155 | ## Development 156 | 157 | This repository is currently under active development and you may encounter bugs with some functionality. 158 | Any feedback, extensions & suggestions are welcome! 159 | 160 | 161 | ## Citation 162 | If you find ClimART or this repository helpful, feel free to cite our publication: 163 | 164 | @inproceedings{cachay2021climart, 165 | title={{ClimART}: A Benchmark Dataset for Emulating Atmospheric Radiative Transfer in Weather and Climate Models}, 166 | author={Salva R{\"u}hling Cachay and Venkatesh Ramesh and Jason N. S. Cole and Howard Barker and David Rolnick}, 167 | booktitle={Thirty-fifth Conference on Neural Information Processing Systems Datasets and Benchmarks Track}, 168 | year={2021}, 169 | url={https://arxiv.org/abs/2111.14671} 170 | } -------------------------------------------------------------------------------- /analysis/wandb_api.py: -------------------------------------------------------------------------------- 1 | from typing import Union, Callable, List 2 | 3 | import wandb 4 | import pandas as pd 5 | 6 | DF_MAPPING = Callable[[pd.DataFrame], pd.DataFrame] 7 | 8 | exp_to_wandb_project = { 9 | 'pristine': "ClimART", 10 | 'clear_sky': "ClimART" 11 | } 12 | 13 | 14 | # Pre-filters 15 | def has_finished(run) -> bool: 16 | return run.state == "finished" 17 | 18 | 19 | def has_final_metric(run) -> bool: 20 | return 'test/hrsc/rmse' in run.summary.keys() 21 | 22 | 23 | def has_max_metric_value(metric: str = 'test/hrsc/rmse', max_metric_value: float = 1.0) -> Callable: 24 | return lambda run: run.summary[metric] <= max_metric_value 25 | 26 | 27 | def has_tags(tags: Union[str, List[str]]) -> Callable: 28 | if isinstance(tags, str): 29 | tags = [tags] 30 | return lambda run: all([tag in run.tags for tag in tags]) 31 | 32 | 33 | def hasnt_tags(tags: Union[str, List[str]]) -> Callable: 34 | if isinstance(tags, str): 35 | tags = [tags] 36 | return lambda run: all([tag not in run.tags for tag in tags]) 37 | 38 | 39 | def has_hyperparam_values(**kwargs) -> Callable: 40 | return lambda run: all(hasattr(run.config, hyperparam) and value == run.config[hyperparam] 41 | for hyperparam, value in kwargs.items()) 42 | 43 | 44 | def larger_than(**kwargs) -> Callable: 45 | return lambda run: all(hasattr(run.config, hyperparam) and value > run.config[hyperparam] 46 | for hyperparam, value in kwargs.items()) 47 | 48 | 49 | def lower_than(**kwargs) -> Callable: 50 | return lambda run: all(hasattr(run.config, hyperparam) and value < run.config[hyperparam] 51 | for hyperparam, value in kwargs.items()) 52 | 53 | 54 | def df_larger_than(**kwargs) -> DF_MAPPING: 55 | def f(df) -> pd.DataFrame: 56 | for k, v in kwargs.items(): 57 | df = df.loc[getattr(df, k) > v] 58 | return df 59 | 60 | return f 61 | 62 | 63 | def df_lower_than(**kwargs) -> DF_MAPPING: 64 | def f(df) -> pd.DataFrame: 65 | for k, v in kwargs.items(): 66 | df = df.loc[getattr(df, k) < v] 67 | return df 68 | 69 | return f 70 | 71 | 72 | def is_model_type(model: str) -> Callable: 73 | return lambda run: model.lower() in run.config['model'].lower() 74 | 75 | 76 | str_to_run_pre_filter = { 77 | 'has_finished': has_finished, 78 | 'has_final_metric': has_final_metric 79 | } 80 | 81 | 82 | # Post-filters 83 | def topk_runs(k: int = 5, 84 | metric: str = 'test/hrsc/rmse', 85 | lower_is_better: bool = 
True) -> DF_MAPPING: 86 | if lower_is_better: 87 | return lambda df: df.nsmallest(k, metric) 88 | else: 89 | return lambda df: df.nlargest(k, metric) 90 | 91 | 92 | def topk_run_of_each_model_type(k: int = 1, 93 | metric: str = 'test/hrsc/rmse', 94 | lower_is_better: bool = True) -> DF_MAPPING: 95 | topk_filter = topk_runs(k, metric, lower_is_better) 96 | 97 | def topk_runs_per_model(df: pd.DataFrame) -> pd.DataFrame: 98 | models = df.model.unique() 99 | dfs = [] 100 | for model in models: 101 | dfs += [topk_filter(df[df.model == model])] 102 | return pd.concat(dfs) 103 | 104 | return topk_runs_per_model 105 | 106 | 107 | def flatten_column_dicts(df: pd.DataFrame) -> pd.DataFrame: 108 | types = df.dtypes 109 | df = pd.concat([df.drop(['preprocessing_dict'], axis=1), df['preprocessing_dict'].apply(pd.Series)], axis=1) 110 | df = pd.concat([df.drop(['spatial_dim'], axis=1), df['spatial_dim'].apply(pd.Series)], axis=1) 111 | df = pd.concat([df.drop(['input_dim'], axis=1), df['input_dim'].apply(pd.Series)], axis=1) 112 | df = pd.concat([df.drop(['target_variable'], axis=1), df['target_variable'].apply(pd.Series)], axis=1) 113 | df = pd.concat([df.drop(['target_type'], axis=1), df['target_type'].apply(pd.Series)], axis=1) 114 | df = pd.concat([df.drop(['hidden_dims'], axis=1), df['hidden_dims'].apply(tuple)], axis=1) 115 | if 'channels_list' in df.columns: 116 | df = df.drop('channels_list', axis=1) 117 | # df = pd.concat([df.drop(['channels_list'], axis=1), df['channels_list'].apply(tuple)], axis=1) 118 | 119 | # df['channels_list'] = df['channels_list'].apply(frozenset) 120 | for col, dtype in types.items(): 121 | if dtype == dict and dtype != object: 122 | df = pd.concat([df.drop([col], axis=1), df[col].apply(pd.Series)], axis=1) 123 | 124 | return df 125 | 126 | 127 | def non_unique_cols_dropper(df: pd.DataFrame) -> pd.DataFrame: 128 | nunique = df.nunique() 129 | cols_to_drop = nunique[nunique == 1].index 130 | df = df.drop(cols_to_drop, axis=1) 131 | return df 132 | 133 | 134 | def groupby(df: pd.DataFrame, group_by='seed', metric='Test/MAE'): 135 | grouped_df = df.groupby(group_by) 136 | stats = grouped_df[[metric, 'name']].mean() 137 | stats['std'] = grouped_df[[metric, 'name']].std() 138 | return stats 139 | 140 | 141 | str_to_run_post_filter = { 142 | **{ 143 | f"top{k}": topk_runs(k=k) 144 | for k in range(1, 21) 145 | }, 146 | 'best_per_model': topk_run_of_each_model_type(k=1), 147 | **{ 148 | f'top{k}_per_model': topk_run_of_each_model_type(k=k) 149 | for k in range(1, 6) 150 | }, 151 | 'unique_columns': non_unique_cols_dropper, 152 | 'flatten_dicts': flatten_column_dicts 153 | } 154 | 155 | 156 | def get_runs_df( 157 | get_metrics: bool = True, 158 | run_pre_filters: Union[str, List[Union[Callable, str]]] = 'has_finished', 159 | run_post_filters: Union[str, List[Union[DF_MAPPING, str]]] = None, 160 | exp_type: str = 'pristine', verbose: bool = False 161 | ) -> pd.DataFrame: 162 | if run_pre_filters is None: 163 | run_pre_filters = [] 164 | elif not isinstance(run_pre_filters, list): 165 | run_pre_filters: List[Union[Callable, str]] = [run_pre_filters] 166 | run_pre_filters = [(f if callable(f) else str_to_run_pre_filter[f.lower()]) for f in run_pre_filters] 167 | if run_post_filters is None: 168 | run_post_filters = [] 169 | elif not isinstance(run_post_filters, list): 170 | run_post_filters: List[Union[Callable, str]] = [run_post_filters] 171 | run_post_filters = [(f if callable(f) else str_to_run_post_filter[f.lower()]) for f in run_post_filters] 172 | 173 | api = wandb.Api() 
174 | # Project is specified by 175 | runs = api.runs(f"ecc-mila7/{exp_to_wandb_project[exp_type]}") 176 | summary_list = [] 177 | config_list = [] 178 | group_list = [] 179 | name_list = [] 180 | tag_list = [] 181 | id_list = [] 182 | for run in runs: 183 | # run.summary are the output key/values like accuracy. 184 | # We call ._json_dict to omit large files 185 | if 'model' not in run.config.keys(): 186 | if verbose: 187 | print(f"Run {run.config['wandb_name'] if 'wandb_name' in run.config else run} filtered out, I.") 188 | continue 189 | 190 | def filter_out(): 191 | for filtr in run_pre_filters: 192 | if not filtr(run): 193 | if verbose: 194 | print(f"Run {run.config['wandb_name']} filtered out, by {filtr.__qualname__}.") 195 | return False 196 | return True 197 | 198 | b = filter_out() 199 | if not b: 200 | continue 201 | 202 | id_list.append(str(run.id)) 203 | tag_list.append(str(run.tags)) 204 | if get_metrics: 205 | summary_list.append(run.summary._json_dict) 206 | # run.config is the input metrics. 207 | config_list.append(run.config) 208 | 209 | # run.name is the name of the run. 210 | name_list.append(run.name) 211 | group_list.append(run.group) 212 | 213 | summary_df = pd.DataFrame.from_records(summary_list) 214 | config_df = pd.DataFrame.from_records(config_list) 215 | name_df = pd.DataFrame({'name': name_list, 'id': id_list, 'tags': tag_list}) 216 | group_df = pd.DataFrame({'group': group_list}) 217 | all_df = pd.concat([name_df, config_df, summary_df, group_df], axis=1) 218 | 219 | cols = [c for c in all_df.columns if not c.startswith('gradients/') and c != 'graph_0'] 220 | all_df = all_df[cols] 221 | if all_df.empty: 222 | raise ValueError('Empty DF!') 223 | for post_filter in run_post_filters: 224 | all_df = post_filter(all_df) 225 | return all_df 226 | -------------------------------------------------------------------------------- /climart/META_INFO.json: -------------------------------------------------------------------------------- 1 | { 2 | "feature_by_var": { 3 | "globals": { 4 | "cszrow": { 5 | "start": 0, 6 | "end": 1 7 | }, 8 | "gtrow": { 9 | "start": 1, 10 | "end": 2 11 | }, 12 | "pressg": { 13 | "start": 2, 14 | "end": 3 15 | }, 16 | "oztop": { 17 | "start": 3, 18 | "end": 4 19 | }, 20 | "emisrow": { 21 | "start": 4, 22 | "end": 5 23 | }, 24 | "salbrol": { 25 | "start": 5, 26 | "end": 9 27 | }, 28 | "csalrol": { 29 | "start": 9, 30 | "end": 13 31 | }, 32 | "emisrot": { 33 | "start": 13, 34 | "end": 19 35 | }, 36 | "gtrot": { 37 | "start": 19, 38 | "end": 25 39 | }, 40 | "farerot": { 41 | "start": 25, 42 | "end": 31 43 | }, 44 | "salbrot": { 45 | "start": 31, 46 | "end": 55 47 | }, 48 | "csalrot": { 49 | "start": 55, 50 | "end": 79 51 | }, 52 | "x_cord": { 53 | "start": 79, 54 | "end": 80 55 | }, 56 | "y_cord": { 57 | "start": 80, 58 | "end": 81 59 | }, 60 | "z_cord": { 61 | "start": 81, 62 | "end": 82 63 | } 64 | }, 65 | "levels": { 66 | "shtj": { 67 | "start": 0, 68 | "end": 1 69 | }, 70 | "tfrow": { 71 | "start": 1, 72 | "end": 2 73 | }, 74 | "level_pressure": { 75 | "start": 2, 76 | "end": 3 77 | }, 78 | "height": { 79 | "start": 3, 80 | "end": 4 81 | } 82 | }, 83 | "layers": { 84 | "shj": { 85 | "start": 0, 86 | "end": 1 87 | }, 88 | "tlayer": { 89 | "start": 1, 90 | "end": 2 91 | }, 92 | "layer_pressure": { 93 | "start": 2, 94 | "end": 3 95 | }, 96 | "ozphs": { 97 | "start": 3, 98 | "end": 4 99 | }, 100 | "qc": { 101 | "start": 4, 102 | "end": 5 103 | }, 104 | "dz": { 105 | "start": 5, 106 | "end": 6 107 | }, 108 | "dshj": { 109 | "start": 6, 110 | "end": 
7 111 | }, 112 | "co2rox": { 113 | "start": 7, 114 | "end": 8 115 | }, 116 | "ch4rox": { 117 | "start": 8, 118 | "end": 9 119 | }, 120 | "n2orox": { 121 | "start": 9, 122 | "end": 10 123 | }, 124 | "f11rox": { 125 | "start": 10, 126 | "end": 11 127 | }, 128 | "f12rox": { 129 | "start": 11, 130 | "end": 12 131 | }, 132 | "layer_thickness": { 133 | "start": 12, 134 | "end": 13 135 | }, 136 | "temp_diff": { 137 | "start": 13, 138 | "end": 14 139 | }, 140 | "rhc": { 141 | "start": 14, 142 | "end": 15 143 | }, 144 | "aerin": { 145 | "start": 15, 146 | "end": 24 147 | }, 148 | "sw_ext_sa": { 149 | "start": 24, 150 | "end": 28 151 | }, 152 | "sw_ssa_sa": { 153 | "start": 28, 154 | "end": 32 155 | }, 156 | "sw_g_sa": { 157 | "start": 32, 158 | "end": 36 159 | }, 160 | "lw_abs_sa": { 161 | "start": 36, 162 | "end": 45 163 | } 164 | }, 165 | "outputs_pristine": { 166 | "rldc": { 167 | "start": 0, 168 | "end": 50 169 | }, 170 | "rluc": { 171 | "start": 0, 172 | "end": 50 173 | }, 174 | "rsdc": { 175 | "start": 0, 176 | "end": 50 177 | }, 178 | "rsuc": { 179 | "start": 0, 180 | "end": 50 181 | }, 182 | "hrlc": { 183 | "start": 0, 184 | "end": 49 185 | }, 186 | "hrsc": { 187 | "start": 0, 188 | "end": 49 189 | } 190 | }, 191 | "outputs_clear_sky": { 192 | "rldc": { 193 | "start": 0, 194 | "end": 50 195 | }, 196 | "rluc": { 197 | "start": 0, 198 | "end": 50 199 | }, 200 | "rsdc": { 201 | "start": 0, 202 | "end": 50 203 | }, 204 | "rsuc": { 205 | "start": 0, 206 | "end": 50 207 | }, 208 | "hrlc": { 209 | "start": 0, 210 | "end": 49 211 | }, 212 | "hrsc": { 213 | "start": 0, 214 | "end": 49 215 | } 216 | }, 217 | "outputs_all_sky": { 218 | "rld": { 219 | "start": 0, 220 | "end": 50 221 | }, 222 | "rlu": { 223 | "start": 0, 224 | "end": 50 225 | }, 226 | "rsd": { 227 | "start": 0, 228 | "end": 50 229 | }, 230 | "rsu": { 231 | "start": 0, 232 | "end": 50 233 | }, 234 | "hrl": { 235 | "start": 0, 236 | "end": 49 237 | }, 238 | "hrs": { 239 | "start": 0, 240 | "end": 49 241 | } 242 | } 243 | }, 244 | "input_dims": { 245 | "pristine": { 246 | "globals": 82, 247 | "layers": 14, 248 | "levels": 4 249 | }, 250 | "clear_sky": { 251 | "globals": 82, 252 | "layers": 45, 253 | "levels": 4 254 | } 255 | }, 256 | "variables": { 257 | "cszrow": { 258 | "data_type": "globals", 259 | "num_features": 1, 260 | "only_clear_sky": false 261 | }, 262 | "shj": { 263 | "data_type": "layers", 264 | "num_features": 1, 265 | "only_clear_sky": false 266 | }, 267 | "shtj": { 268 | "data_type": "levels", 269 | "num_features": 1, 270 | "only_clear_sky": false 271 | }, 272 | "gtrow": { 273 | "data_type": "globals", 274 | "num_features": 1, 275 | "only_clear_sky": false 276 | }, 277 | "tlayer": { 278 | "data_type": "layers", 279 | "num_features": 1, 280 | "only_clear_sky": false 281 | }, 282 | "tfrow": { 283 | "data_type": "levels", 284 | "num_features": 1, 285 | "only_clear_sky": false 286 | }, 287 | "pressg": { 288 | "data_type": "globals", 289 | "num_features": 1, 290 | "only_clear_sky": false 291 | }, 292 | "layer_pressure": { 293 | "data_type": "layers", 294 | "num_features": 1, 295 | "only_clear_sky": false 296 | }, 297 | "level_pressure": { 298 | "data_type": "levels", 299 | "num_features": 1, 300 | "only_clear_sky": false 301 | }, 302 | "oztop": { 303 | "data_type": "globals", 304 | "num_features": 1, 305 | "only_clear_sky": false 306 | }, 307 | "ozphs": { 308 | "data_type": "layers", 309 | "num_features": 1, 310 | "only_clear_sky": false 311 | }, 312 | "qc": { 313 | "data_type": "layers", 314 | "num_features": 1, 315 | 
"only_clear_sky": false 316 | }, 317 | "dz": { 318 | "data_type": "layers", 319 | "num_features": 1, 320 | "only_clear_sky": false 321 | }, 322 | "dshj": { 323 | "data_type": "layers", 324 | "num_features": 1, 325 | "only_clear_sky": false 326 | }, 327 | "co2rox": { 328 | "data_type": "layers", 329 | "num_features": 1, 330 | "only_clear_sky": false 331 | }, 332 | "ch4rox": { 333 | "data_type": "layers", 334 | "num_features": 1, 335 | "only_clear_sky": false 336 | }, 337 | "n2orox": { 338 | "data_type": "layers", 339 | "num_features": 1, 340 | "only_clear_sky": false 341 | }, 342 | "f11rox": { 343 | "data_type": "layers", 344 | "num_features": 1, 345 | "only_clear_sky": false 346 | }, 347 | "f12rox": { 348 | "data_type": "layers", 349 | "num_features": 1, 350 | "only_clear_sky": false 351 | }, 352 | "f113rox": true, 353 | "f114rox": true, 354 | "emisrow": { 355 | "data_type": "globals", 356 | "num_features": 1, 357 | "only_clear_sky": false 358 | }, 359 | "salbrol": { 360 | "data_type": "globals", 361 | "num_features": 4, 362 | "only_clear_sky": false 363 | }, 364 | "csalrol": { 365 | "data_type": "globals", 366 | "num_features": 4, 367 | "only_clear_sky": false 368 | }, 369 | "emisrot": { 370 | "data_type": "globals", 371 | "num_features": 6, 372 | "only_clear_sky": false 373 | }, 374 | "gtrot": { 375 | "data_type": "globals", 376 | "num_features": 6, 377 | "only_clear_sky": false 378 | }, 379 | "farerot": { 380 | "data_type": "globals", 381 | "num_features": 6, 382 | "only_clear_sky": false 383 | }, 384 | "salbrot": { 385 | "data_type": "globals", 386 | "num_features": 24, 387 | "only_clear_sky": false 388 | }, 389 | "csalrot": { 390 | "data_type": "globals", 391 | "num_features": 24, 392 | "only_clear_sky": false 393 | }, 394 | "layer_thickness": { 395 | "data_type": "layers", 396 | "num_features": 1, 397 | "only_clear_sky": false 398 | }, 399 | "temp_diff": { 400 | "data_type": "layers", 401 | "num_features": 1, 402 | "only_clear_sky": false 403 | }, 404 | "height": { 405 | "data_type": "levels", 406 | "num_features": 1, 407 | "only_clear_sky": false 408 | }, 409 | "x_cord": { 410 | "data_type": "globals", 411 | "num_features": 1, 412 | "only_clear_sky": false 413 | }, 414 | "y_cord": { 415 | "data_type": "globals", 416 | "num_features": 1, 417 | "only_clear_sky": false 418 | }, 419 | "z_cord": { 420 | "data_type": "globals", 421 | "num_features": 1, 422 | "only_clear_sky": false 423 | }, 424 | "rhc": { 425 | "data_type": "layers", 426 | "num_features": 1, 427 | "only_clear_sky": true 428 | }, 429 | "aerin": { 430 | "data_type": "layers", 431 | "num_features": 9, 432 | "only_clear_sky": true 433 | }, 434 | "sw_ext_sa": { 435 | "data_type": "layers", 436 | "num_features": 4, 437 | "only_clear_sky": true 438 | }, 439 | "sw_ssa_sa": { 440 | "data_type": "layers", 441 | "num_features": 4, 442 | "only_clear_sky": true 443 | }, 444 | "sw_g_sa": { 445 | "data_type": "layers", 446 | "num_features": 4, 447 | "only_clear_sky": true 448 | }, 449 | "lw_abs_sa": { 450 | "data_type": "layers", 451 | "num_features": 9, 452 | "only_clear_sky": true 453 | }, 454 | "rldc": { 455 | "data_type": "outputs", 456 | "num_features": 50, 457 | "only_clear_sky": true 458 | }, 459 | "rluc": { 460 | "data_type": "outputs", 461 | "num_features": 50, 462 | "only_clear_sky": true 463 | }, 464 | "rsdc": { 465 | "data_type": "outputs", 466 | "num_features": 50, 467 | "only_clear_sky": true 468 | }, 469 | "rsuc": { 470 | "data_type": "outputs", 471 | "num_features": 50, 472 | "only_clear_sky": true 473 | }, 474 | "hrlc": 
{ 475 | "data_type": "outputs", 476 | "num_features": 49, 477 | "only_clear_sky": true 478 | }, 479 | "hrsc": { 480 | "data_type": "outputs", 481 | "num_features": 49, 482 | "only_clear_sky": true 483 | }, 484 | "rld": { 485 | "data_type": "outputs", 486 | "num_features": 50, 487 | "only_clear_sky": false 488 | }, 489 | "rlu": { 490 | "data_type": "outputs", 491 | "num_features": 50, 492 | "only_clear_sky": false 493 | }, 494 | "rsd": { 495 | "data_type": "outputs", 496 | "num_features": 50, 497 | "only_clear_sky": false 498 | }, 499 | "rsu": { 500 | "data_type": "outputs", 501 | "num_features": 50, 502 | "only_clear_sky": false 503 | }, 504 | "hrl": { 505 | "data_type": "outputs", 506 | "num_features": 49, 507 | "only_clear_sky": false 508 | }, 509 | "hrs": { 510 | "data_type": "outputs", 511 | "num_features": 49, 512 | "only_clear_sky": false 513 | } 514 | } 515 | } -------------------------------------------------------------------------------- /climart/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RolnickLab/climart/6d12fa478de5db5209842e0369a2a24377079edd/climart/__init__.py -------------------------------------------------------------------------------- /climart/data_loading/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RolnickLab/climart/6d12fa478de5db5209842e0369a2a24377079edd/climart/data_loading/__init__.py -------------------------------------------------------------------------------- /climart/data_loading/constants.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | import numpy as np 5 | from typing import Optional, Dict 6 | 7 | import xarray 8 | from hydra.utils import get_original_cwd, to_absolute_path 9 | 10 | GLOBALS = "globals" 11 | LAYERS = "layers" 12 | LEVELS = "levels" 13 | INPUTS = 'inputs' 14 | PRISTINE = 'pristine' 15 | CLEAR_SKY = 'clear_sky' 16 | OUTPUTS_PRISTINE = f"outputs_{PRISTINE}" 17 | OUTPUTS_CLEARSKY = f"outputs_{CLEAR_SKY}" 18 | SHORTWAVE = 'shortwave' 19 | LONGWAVE = 'longwave' 20 | HEATING_RATES = 'heating_rates' 21 | FLUXES = 'fluxes' 22 | TOA_FLUXES = 'toa_fluxes' 23 | SURFACE_FLUXES = 'surface_fluxes' 24 | 25 | INPUT_TYPES = [GLOBALS, LEVELS, LAYERS] 26 | DATA_TYPES = [GLOBALS, LAYERS, LEVELS, OUTPUTS_PRISTINE, OUTPUTS_CLEARSKY] 27 | EXP_TYPES = [PRISTINE, CLEAR_SKY] 28 | 29 | SPATIAL_DATA_TYPES = [LAYERS, LEVELS, OUTPUTS_PRISTINE, OUTPUTS_CLEARSKY] 30 | DATA_TYPE_DIMS = {GLOBALS: 82, LEVELS: 4, LAYERS: 45, OUTPUTS_CLEARSKY: 1, OUTPUTS_PRISTINE: 1} 31 | 32 | TRAIN_YEARS = list(range(1979, 1991)) + list(range(1994, 2005)) 33 | VAL_YEARS = [2005, 2006] 34 | TEST_YEARS = list(range(2007, 2015)) 35 | OOD_PRESENT_YEARS = [1991] 36 | OOD_FUTURE_YEARS = [2097, 2098, 2099] 37 | OOD_HISTORIC_YEARS = [1850, 1851, 1852] 38 | ALL_YEARS = TRAIN_YEARS + VAL_YEARS + TEST_YEARS + OOD_PRESENT_YEARS + OOD_FUTURE_YEARS + OOD_HISTORIC_YEARS 39 | 40 | # DATA_DIR should be an absolute path, since Hydra changes the working dir... to_absolute_path() solves this. 
41 | DATA_DIR = to_absolute_path("ClimART_DATA/") 42 | 43 | 44 | def get_data_subdirs(data_dir: str) -> Dict[str, str]: 45 | d = {INPUTS: os.path.join(data_dir, INPUTS), 46 | OUTPUTS_PRISTINE: os.path.join(data_dir, OUTPUTS_PRISTINE), 47 | OUTPUTS_CLEARSKY: os.path.join(data_dir, OUTPUTS_CLEARSKY) 48 | } 49 | return d 50 | 51 | 52 | def get_metadata(data_dir: str = None): 53 | if data_dir is None: 54 | data_dir = DATA_DIR 55 | path = os.path.join(data_dir, 'META_INFO.json') 56 | 57 | if not os.path.isfile(path): 58 | if os.path.isfile(to_absolute_path(path)): 59 | path = to_absolute_path(path) 60 | else: 61 | err_msg = f' Not able to recover meta information from {path}, as it is not a file!' 62 | raise ValueError(err_msg) 63 | with open(path, 'r') as fp: 64 | meta_info = json.load(fp) 65 | return meta_info 66 | 67 | 68 | def get_statistics(data_dir: str = None): 69 | if data_dir is None: 70 | data_dir = DATA_DIR 71 | path = os.path.join(data_dir, 'statistics.npz') 72 | if not os.path.isfile(path): 73 | if os.path.isfile(to_absolute_path(path)): 74 | path = to_absolute_path(path) 75 | else: 76 | err_msg = f' Not able to recover statistics file from {path}' 77 | raise ValueError(err_msg) 78 | statistics = np.load(path) 79 | return statistics 80 | 81 | 82 | def get_coordinates(data_dir: str = None): 83 | if data_dir is None: 84 | data_dir = DATA_DIR 85 | path = os.path.join(data_dir, 'areacella_fx_CanESM5.nc') 86 | if not os.path.isfile(path): 87 | err_msg = f' Not able to recover coordinates/latitudes/longitudes from {path}' 88 | raise ValueError(err_msg) 89 | return xarray.open_dataset(path) 90 | 91 | 92 | def get_data_dims(exp_type: str) -> (Dict[str, int], Dict[str, int]): 93 | spatial_dim = {GLOBALS: 0, LEVELS: 50, LAYERS: 49} 94 | in_dim = {GLOBALS: 82, LEVELS: 4, LAYERS: 14 if exp_type.lower() == PRISTINE else 45} 95 | out_dim = 100 96 | return {'spatial_dim': spatial_dim, 'input_dim': in_dim, 'output_dim': out_dim} 97 | 98 | 99 | def get_flux_mean(): 100 | import numpy 101 | return numpy.array([ 102 | 296.68795572, 295.59927749, 295.1482046, 294.64596736, 294.11837758, 103 | 293.58230163, 293.04465391, 292.50202651, 291.93796111, 291.40537197, 104 | 290.73892157, 290.04539078, 289.39613274, 288.70448136, 288.01931166, 105 | 287.30607635, 286.62727984, 285.93099547, 285.23425452, 284.56900363, 106 | 283.87613833, 283.12472721, 282.30417958, 281.35523027, 280.2217935, 107 | 278.83756356, 277.10443911, 274.98218953, 272.46624459, 269.56133595, 108 | 266.19161612, 262.41764801, 258.2832474, 254.0212223, 250.09734285, 109 | 246.61271747, 243.56633538, 240.91876951, 238.62519398, 236.65407597, 110 | 234.96650554, 233.53089125, 232.3124631, 231.20520435, 230.09577872, 111 | 228.99557943, 227.9148505, 226.85982104, 226.01209087, 225.32944844, 112 | 113 | 59.0339687, 59.04093876, 59.03547336, 59.02984351, 59.0244958, 114 | 59.02004596, 59.01671905, 59.0139414, 59.01162252, 59.00892333, 115 | 59.00462615, 58.99430663, 58.97929168, 58.94997202, 58.90542686, 116 | 58.83393261, 58.73876127, 58.60566996, 58.42574877, 58.2087681, 117 | 57.94035763, 57.6132036, 57.23794982, 56.81170941, 56.32479326, 118 | 55.77942549, 55.16748318, 54.49360675, 53.77000922, 52.99446721, 119 | 52.16208255, 51.29397739, 50.39621889, 49.53103161, 48.77409974, 120 | 48.12633895, 47.58119014, 47.12617189, 46.74971375, 46.44106945, 121 | 46.19017345, 45.98759973, 45.82464207, 45.68331371, 45.54797404, 122 | 45.41889672, 45.29627039, 45.18034025, 45.09040981, 45.01957376 123 | ], dtype=numpy.float64) 124 | 
-------------------------------------------------------------------------------- /climart/data_loading/data_variables.py: -------------------------------------------------------------------------------- 1 | _ALL_INPUT_VARS = [ 2 | 'shtj', 3 | 'tfrow', 4 | 'shj', 5 | 'dshj', 6 | 'dz', 7 | 'tlayer', 8 | 'ozphs', 9 | 'qc', 10 | 'co2rox', 11 | 'ch4rox', 12 | 'n2orox', 13 | 'f11rox', 14 | 'f12rox', 15 | 'ccld', 16 | 'rhc', 17 | 'anu', 18 | 'eta', 19 | 'aerin', 20 | 'sw_ext_sa', 21 | 'sw_ssa_sa', 22 | 'sw_g_sa', 23 | 'lw_abs_sa', 24 | 'pressg', 25 | 'gtrow', 26 | 'oztop', 27 | 'cszrow', 28 | 'vtaurow', 29 | 'troprow', 30 | 'emisrow', 31 | 'cldtrol', 32 | 'ncldy', 33 | 'salbrol', 34 | 'csalrol', 35 | 'emisrot', 36 | 'gtrot', 37 | 'farerot', 38 | 'salbrot', 39 | 'csalrot', 40 | 'rel_sub', 41 | 'rei_sub', 42 | 'clw_sub', 43 | 'cic_sub', 44 | 'layer_pressure', 45 | 'level_pressure', 46 | 'layer_thickness', 47 | 'x_cord', 48 | 'y_cord', 49 | 'z_cord', 50 | 'temp_diff', 51 | 'height' 52 | ] 53 | 54 | # -------------------------------------- OUTPUTS/TARGET VARIABLES 55 | LW_HEATING_RATE = 'hrlc' 56 | SW_HEATING_RATE = 'hrsc' 57 | 58 | OUT_HEATING_RATE_CLOUDS = [ 59 | 'hrl', # heating rate (long-wave) 60 | 'hrs' # heating rate (short-wave) 61 | ] 62 | 63 | OUT_HEATING_RATE_NOCLOUDS = [ 64 | LW_HEATING_RATE, # heating rate (long-wave) 65 | SW_HEATING_RATE # heating rate (short-wave) 66 | ] 67 | 68 | OUT_SHORTWAVE_CLOUDS = [ 69 | 'rsd', # solar flux down 70 | 'rsu', # solar flux up 71 | ] 72 | 73 | OUT_SHORTWAVE_NOCLOUDS = [ 74 | 'rsdc', # solar flux down 75 | 'rsuc', # solar flux up 76 | ] 77 | 78 | OUT_LONGWAVE_CLOUDS = [ 79 | "rld", # thermal flux down 80 | 'rlu', # thermal flux up 81 | ] 82 | 83 | OUT_LONGWAVE_NOCLOUDS = [ 84 | "rldc", # thermal flux down 85 | 'rluc', # thermal flux up 86 | ] 87 | 88 | _ALL_OUTPUT_VARS = OUT_SHORTWAVE_NOCLOUDS + OUT_LONGWAVE_NOCLOUDS + OUT_HEATING_RATE_NOCLOUDS 89 | _ALL_VARS = _ALL_INPUT_VARS + _ALL_OUTPUT_VARS 90 | 91 | INPUT_VARS_CLOUDS = [ 92 | 'ccld', # Cloud amount profile 93 | 'anu', # Cloud water content horizontal variability parameter 94 | 'eta', # Fraction black carbon in liquid cloud droplets 95 | 'vtaurow', # Vertically integrated optical thickness at 550 nm for stratospheric aerosols 96 | 'troprow', # Layer index of the tropopause 97 | 'cldtrol', # Total vertically projected cloud fraction 98 | 'ncldy', # Number of cloudy subcolumns in CanAM grid 99 | 'rel_sub', # Liquid cloud effective radius for subcolumns in CanAM grid 100 | 'rei_sub', # Ice cloud effective radius for subcolumns in CanAM grid 101 | 'clw_sub', # Liquid cloud water path for subcolumns in CanAM grid 102 | 'cic_sub' 103 | ] 104 | 105 | INPUT_VARS_AEROSOLS = [ 106 | 'rhc', # Relative humidity 107 | 'aerin', # Relative humidity 108 | 'sw_ext_sa', # Cloud water content horizontal variability parameter 109 | 'sw_ssa_sa', # solar flux up 110 | 'sw_g_sa', # heating rate (long-wave?) 111 | 'lw_abs_sa' # heating rate (short-wave?) 
112 | ] 113 | 114 | 115 | # ----------------------------------------------------- 116 | 117 | def no_clouds_exp(name): 118 | return name in ['pristine', 'clear-sky', 'clearsky', 'clear_sky', 'aerosols'] 119 | 120 | 121 | def get_flux_output_variables(target_type: str): 122 | if target_type.lower() == "shortwave": 123 | return OUT_SHORTWAVE_NOCLOUDS 124 | if target_type.lower() == "longwave": 125 | return OUT_LONGWAVE_NOCLOUDS 126 | if target_type.lower() == "shortwave+longwave": 127 | return OUT_SHORTWAVE_NOCLOUDS + OUT_LONGWAVE_NOCLOUDS 128 | raise ValueError(f" Unexpected arg {target_type} for target_type") 129 | 130 | 131 | EXP_TYPES = ['pristine', 'clear_sky'] 132 | -------------------------------------------------------------------------------- /climart/data_transform/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RolnickLab/climart/6d12fa478de5db5209842e0369a2a24377079edd/climart/data_transform/__init__.py -------------------------------------------------------------------------------- /climart/data_transform/normalization.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from abc import ABC 3 | from typing import Optional, Union, Dict, Iterable, Sequence, List, Callable 4 | from omegaconf import DictConfig 5 | import numpy as np 6 | import torch 7 | from torch import Tensor 8 | from climart.data_loading.constants import LEVELS, LAYERS, GLOBALS, get_metadata, get_statistics 9 | from climart.data_loading import constants 10 | from climart.data_loading.data_variables import get_flux_output_variables 11 | from climart.utils.utils import get_logger, get_target_variable_names, get_identity_callable, identity 12 | 13 | NP_ARRAY_MAPPING = Callable[[np.ndarray], np.ndarray] 14 | 15 | log = get_logger(__name__) 16 | 17 | 18 | def in_degree_normalization(adj: np.ndarray): 19 | """ 20 | For Graph networks 21 | :param adj: A N x N adjacency matrix 22 | :return: In-degree normalized matrix (Row-normalized matrix) 23 | """ 24 | rowsum = np.array(adj.sum(1)) 25 | r_inv = np.power(rowsum, -1).flatten() 26 | r_inv[np.isinf(r_inv)] = 0.0 27 | r_mat_inv = np.diag(r_inv) 28 | mx = r_mat_inv @ adj 29 | return mx 30 | 31 | 32 | 33 | class NormalizationMethod(ABC): 34 | def __init__(self, *args, **kwargs): 35 | pass 36 | 37 | def normalize(self, data: np.ndarray, axis=0, *args, **kwargs): 38 | return data 39 | 40 | def inverse_normalize(self, normalized_data: np.ndarray): 41 | return normalized_data 42 | 43 | def stored_values(self): 44 | return dict() 45 | 46 | def __copy__(self): 47 | return type(self)(**self.stored_values()) 48 | 49 | def copy(self): 50 | return self.__copy__() 51 | 52 | def change_input_type(self, new_type): 53 | for attribute, value in self.__dict__.items(): 54 | if new_type in [torch.Tensor, torch.TensorType]: 55 | if isinstance(value, np.ndarray): 56 | setattr(self, attribute, torch.from_numpy(value)) 57 | elif new_type == np.ndarray: 58 | if torch.is_tensor(value): 59 | setattr(self, attribute, value.numpy().cpu()) 60 | else: 61 | setattr(self, attribute, new_type(value)) 62 | 63 | def apply_torch_func(self, fn): 64 | """ 65 | Function to be called to apply a torch function to all tensors of this class, e.g. 
apply .to(), .cuda(), ..., 66 | Just call this function from within the model's nn.Module._apply() 67 | """ 68 | for attribute, value in self.__dict__.items(): 69 | if torch.is_tensor(value): 70 | setattr(self, attribute, fn(value)) 71 | 72 | 73 | class Z_Normalizer(NormalizationMethod): 74 | def __init__(self, mean, std, **kwargs): 75 | super().__init__(**kwargs) 76 | self.mean = mean 77 | self.std = std 78 | 79 | def normalize(self, data, axis=None, *args, **kwargs): 80 | return self(data) 81 | 82 | def inverse_normalize(self, normalized_data): 83 | data = normalized_data * self.std + self.mean 84 | return data 85 | 86 | def stored_values(self): 87 | return {'mean': self.mean, 'std': self.std} 88 | 89 | def __call__(self, data): 90 | return (data - self.mean) / self.std 91 | 92 | 93 | class MinMax_Normalizer(NormalizationMethod): 94 | def __init__(self, min=None, max_minus_min=None, max=None, **kwargs): 95 | super().__init__(**kwargs) 96 | self.min = min 97 | if min: 98 | assert max_minus_min or max 99 | self.max_minus_min = max_minus_min or max - min 100 | 101 | def normalize(self, data, axis=None, *args, **kwargs): 102 | # self.min = np.min(data, axis=axis) 103 | # self.max_minus_min = (np.max(data, axis=axis) - self.min) 104 | return self(data) 105 | 106 | def inverse_normalize(self, normalized_data): 107 | shapes = normalized_data.shape 108 | if len(shapes) >= 2: 109 | normalized_data = normalized_data.reshape(normalized_data.shape[0], -1) 110 | data = normalized_data * self.max_minus_min + self.min 111 | if len(shapes) >= 2: 112 | data = data.reshape(shapes) 113 | return data 114 | 115 | def stored_values(self): 116 | return {'min': self.min, 'max_minus_min': self.max_minus_min} 117 | 118 | def __call__(self, data): 119 | return (data - self.min) / self.max_minus_min 120 | 121 | 122 | class LogNormalizer(NormalizationMethod): 123 | def normalize(self, data, *args, **kwargs): 124 | normalized_data = self(data) 125 | return normalized_data 126 | 127 | def inverse_normalize(self, normalized_data): 128 | data = np.exp(normalized_data) 129 | return data 130 | 131 | def __call__(self, data: np.ndarray, *args, **kwargs): 132 | return np.log(data) 133 | 134 | 135 | class LogZ_Normalizer(NormalizationMethod): 136 | def __init__(self, mean=None, std=None, **kwargs): 137 | super().__init__(**kwargs) 138 | self.z_normalizer = Z_Normalizer(mean, std) 139 | 140 | def normalize(self, data, *args, **kwargs): 141 | normalized_data = np.log(data + 1e-5) 142 | normalized_data = self.z_normalizer.normalize(normalized_data) 143 | return normalized_data 144 | 145 | def inverse_normalize(self, normalized_data): 146 | data = self.z_normalizer.inverse_normalize(normalized_data) 147 | data = np.exp(data) - 1e-5 148 | return data 149 | 150 | def stored_values(self): 151 | return self.z_normalizer.stored_values() 152 | 153 | def change_input_type(self, new_type): 154 | self.z_normalizer.change_input_type(new_type) 155 | 156 | def apply_torch_func(self, fn): 157 | self.z_normalizer.apply_torch_func(fn) 158 | 159 | def __call__(self, data, *args, **kwargs): 160 | normalized_data = np.log(data + 1e-5) 161 | return self.z_normalizer(normalized_data) 162 | 163 | 164 | class MinMax_LogNormalizer(NormalizationMethod): 165 | def __init__(self, min=None, max_minus_min=None, **kwargs): 166 | super().__init__(**kwargs) 167 | self.min_max_normalizer = MinMax_Normalizer(min, max_minus_min) 168 | 169 | def normalize(self, data, *args, **kwargs): 170 | normalized_data = self.min_max_normalizer.normalize(data) 171 | 
normalized_data = np.log(normalized_data) 172 | return normalized_data 173 | 174 | def inverse_normalize(self, normalized_data): 175 | data = np.exp(normalized_data) 176 | data = self.min_max_normalizer.inverse_normalize(data) 177 | return data 178 | 179 | def stored_values(self): 180 | return self.min_max_normalizer.stored_values() 181 | 182 | def change_input_type(self, new_type): 183 | self.min_max_normalizer.change_input_type(new_type) 184 | 185 | def apply_torch_func(self, fn): 186 | self.min_max_normalizer.apply_torch_func(fn) 187 | 188 | 189 | def get_normalizer(normalizer='z', *args, **kwargs) -> NormalizationMethod: 190 | normalizer = normalizer.lower().strip().replace('-', '_').replace('&', '+') 191 | supported_normalizers = ['z', 192 | 'min_max', 193 | 'min_max+log', 'min_max_log', 194 | 'log_z', 195 | 'log', 196 | 'none'] 197 | assert normalizer in supported_normalizers, f"Unsupported Normalization {normalizer} not in {str(supported_normalizers)}" 198 | if normalizer == 'z': 199 | return Z_Normalizer(*args, **kwargs) 200 | elif normalizer == 'min_max': 201 | return MinMax_Normalizer(*args, **kwargs) 202 | elif normalizer in ['min_max+log', 'min_max_log']: 203 | return MinMax_LogNormalizer(*args, **kwargs) 204 | elif normalizer in ['logz', 'log_z']: 205 | return LogZ_Normalizer(*args, **kwargs) 206 | elif normalizer == 'log': 207 | return LogNormalizer(*args, **kwargs) 208 | else: 209 | return NormalizationMethod(*args, **kwargs) # like no normalizer 210 | 211 | 212 | class Normalizer: 213 | def __init__( 214 | self, 215 | datamodule_config: DictConfig, 216 | input_normalization: Optional[str] = None, 217 | output_normalization: Optional[str] = None, 218 | spatial_normalization_in: bool = False, 219 | spatial_normalization_out: bool = False, 220 | log_scaling: Union[bool, List[str]] = False, 221 | data_dir: Optional[str] = None, 222 | verbose: bool = True 223 | ): 224 | """ 225 | input_normalization (str): "z" for z-scaling (zero mean and unit standard deviation) 226 | """ 227 | if not verbose: 228 | log.setLevel(logging.WARNING) 229 | 230 | if data_dir is None: 231 | data_dir = datamodule_config.get("data_dir") or constants.DATA_DIR 232 | exp_type = datamodule_config.get("exp_type") 233 | target_type = datamodule_config.get("target_type") 234 | target_variable = datamodule_config.get("target_variable") 235 | 236 | self._layer_mask = 45 if exp_type == constants.CLEAR_SKY else 14 237 | self._recover_meta_info(data_dir) 238 | self._input_normalizer: Dict[str, NormalizationMethod] = dict() 239 | self._output_normalizer: Optional[Dict[str, NormalizationMethod]] = None 240 | 241 | self._target_variables = get_target_variable_names(target_type, target_variable) 242 | if input_normalization is not None: 243 | norma_type = '_spatial' if spatial_normalization_in else '' 244 | info_msg = f" Applying {norma_type.lstrip('_')} {input_normalization} normalization to input data," \ 245 | f" based on pre-computed stats." 246 | log.info(info_msg) 247 | 248 | precomputed_stats = get_statistics(data_dir) 249 | precomputed_stats = {k: precomputed_stats[k] for k in precomputed_stats.keys() if 250 | (('spatial' in k and spatial_normalization_in) or ('spatial' not in k))} 251 | if isinstance(log_scaling, list) or log_scaling: 252 | log.info(' Log scaling pressure and height variables! 
(no other normalization is applied to them)') 253 | post_log_vals = dict(pressg=(11.473797, 0.10938317), 254 | layer_pressure=(9.29207, 2.097411), 255 | dz=(6.5363674, 1.044927), 256 | layer_thickness=(6.953938313568889, 1.3751644503732554), 257 | level_pressure=(9.252319, 2.1721559)) 258 | vars_to_log_scale = ['pressg', 'layer_pressure', 'dz', 'layer_thickness', 'level_pressure'] 259 | self._layer_log_mask = torch.tensor([2, 5, 12]) 260 | for var in vars_to_log_scale: 261 | dtype = self._variables[var]['data_type'] 262 | s, e = self.feature_by_var[dtype][var]['start'], self.feature_by_var[dtype][var]['end'] 263 | # precomputed_stats[f'{dtype}{prefix}_mean'][..., s:e] = 0 264 | # precomputed_stats[f'{dtype}{prefix}_std'][..., s:e] = 1 265 | precomputed_stats[f'{dtype}{norma_type}_mean'][..., s:e] = post_log_vals[var][0] 266 | precomputed_stats[f'{dtype}{norma_type}_std'][..., s:e] = post_log_vals[var][1] 267 | 268 | def log_scaler(X: Dict[str, Tensor]) -> Dict[str, Tensor]: 269 | # layer_log_mask = torch.tensor([2, 5, 12]) 270 | X[GLOBALS][2] = torch.log(X[GLOBALS][2]) 271 | X[LEVELS][..., 2] = torch.log(X[LEVELS][..., 2]) 272 | X[LAYERS][..., self._layer_log_mask] = torch.log(X[LAYERS][..., self._layer_log_mask]) 273 | return X 274 | 275 | self._log_scaler_func = log_scaler 276 | else: 277 | self._log_scaler_func = identity 278 | 279 | for data_type in [GLOBALS, LEVELS, LAYERS]: 280 | if input_normalization is not None: 281 | normer_kwargs = dict( 282 | mean=precomputed_stats[data_type + f'{norma_type}_mean'], 283 | std=precomputed_stats[data_type + f'{norma_type}_std'], 284 | min=precomputed_stats[data_type + f'{norma_type}_min'], 285 | max=precomputed_stats[data_type + f'{norma_type}_max'], 286 | ) 287 | if data_type == LAYERS: 288 | for k, v in normer_kwargs.items(): 289 | normer_kwargs[k] = v[..., :self._layer_mask] 290 | normalizer = get_normalizer( 291 | input_normalization, 292 | **normer_kwargs 293 | ) 294 | else: 295 | normalizer = identity 296 | self._input_normalizer[data_type] = normalizer 297 | if output_normalization is not None: 298 | self._output_normalizer = dict() 299 | px = "spatial_" if spatial_normalization_out else "" 300 | precomputed_stats = get_statistics(data_dir) 301 | for tv in get_flux_output_variables(target_type): 302 | normer_kwargs = dict( 303 | mean=precomputed_stats[f"outputs_{exp_type}_{tv}_{px}mean"], 304 | std=precomputed_stats[f"outputs_{exp_type}_{tv}_{px}std"], 305 | min=precomputed_stats[f"outputs_{exp_type}_{tv}_{px}min"], 306 | max=precomputed_stats[f"outputs_{exp_type}_{tv}_{px}max"], 307 | ) 308 | self._output_normalizer[tv] = get_normalizer(**normer_kwargs, normalizer=output_normalization) 309 | 310 | def _recover_meta_info(self, data_dir: str): 311 | meta_info = get_metadata(data_dir) 312 | self._variables = meta_info['variables'] 313 | self._vars_used_or_not = list(self._variables.keys()) 314 | self._feature_by_var = meta_info['feature_by_var'] 315 | 316 | @property 317 | def feature_by_var(self): 318 | return self._feature_by_var 319 | 320 | def get_normalizer(self, data_type: str) -> Union[NP_ARRAY_MAPPING, NormalizationMethod]: 321 | return self._input_normalizer[data_type] 322 | 323 | def get_normalizers(self) -> Dict[str, Union[NP_ARRAY_MAPPING, NormalizationMethod]]: 324 | return { 325 | data_type: self.get_normalizer(data_type) 326 | for data_type in constants.INPUT_TYPES 327 | } 328 | 329 | def set_normalizer(self, data_type: str, new_normalizer: Optional[NP_ARRAY_MAPPING]): 330 | if new_normalizer is None: 331 | new_normalizer = 
identity 332 | if data_type in constants.INPUT_TYPES: 333 | self._input_normalizer[data_type] = new_normalizer 334 | else: 335 | log.info(f" Setting output normalizer, after calling set_normalizer with data_type={data_type}") 336 | self._output_normalizer = new_normalizer 337 | 338 | def set_input_normalizers(self, new_normalizer: Optional[NP_ARRAY_MAPPING]): 339 | for data_type in constants.INPUT_TYPES: 340 | self.set_normalizer(data_type, new_normalizer) 341 | 342 | @property 343 | def output_normalizer(self): 344 | return self._output_normalizer 345 | 346 | def normalize(self, X: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: 347 | for input_type, rawX in X.items(): 348 | X[input_type] = self._input_normalizer[input_type](rawX) 349 | 350 | X = self._log_scaler_func(X) 351 | return X 352 | 353 | def __call__(self, X: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: 354 | return self.normalize(X) -------------------------------------------------------------------------------- /climart/datamodules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RolnickLab/climart/6d12fa478de5db5209842e0369a2a24377079edd/climart/datamodules/__init__.py -------------------------------------------------------------------------------- /climart/datamodules/pl_climart_datamodule.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Optional, List, Callable 3 | 4 | from pytorch_lightning import LightningDataModule 5 | from pytorch_lightning.utilities.types import EVAL_DATALOADERS 6 | from torch.utils.data import DataLoader 7 | 8 | from climart.data_loading.constants import TRAIN_YEARS, TEST_YEARS, OOD_PRESENT_YEARS, OOD_HISTORIC_YEARS, \ 9 | OOD_FUTURE_YEARS, VAL_YEARS 10 | from climart.data_loading.h5_dataset import ClimART_HdF5_Dataset 11 | from climart.data_transform.normalization import Normalizer 12 | from climart.data_transform.transforms import AbstractTransform 13 | from climart.utils.utils import year_string_to_list 14 | 15 | log = logging.getLogger(__name__) 16 | 17 | 18 | class ClimartDataModule(LightningDataModule): 19 | """ 20 | ---------------------------------------------------------------------------------------------------------- 21 | A DataModule implements 5 key methods: 22 | - prepare_data (things to do on 1 GPU/TPU, not on every GPU/TPU in distributed mode) 23 | - setup (things to do on every accelerator in distributed mode) 24 | - train_dataloader (the training dataloader) 25 | - val_dataloader (the validation dataloader(s)) 26 | - test_dataloader (the test dataloader(s)) 27 | 28 | This allows you to share a full dataset without explaining how to download, 29 | split, transform and process the data 30 | 31 | Read the docs: 32 | https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html 33 | """ 34 | 35 | def __init__( 36 | self, 37 | exp_type: str, 38 | target_type: str, 39 | target_variable: str, 40 | data_dir: Optional[str] = None, 41 | train_years: str = "1999-2000", 42 | validation_years: str = "2005", 43 | predict_years: str = "2014", 44 | input_transform: Optional[AbstractTransform] = None, 45 | normalizer: Optional[Normalizer] = None, 46 | batch_size: int = 64, 47 | eval_batch_size: int = 512, 48 | num_workers: int = 0, 49 | pin_memory: bool = True, 50 | load_train_into_mem: bool = False, 51 | load_test_into_mem: bool = False, 52 | load_valid_into_mem: bool = True, 53 | test_main_dataset: bool = True, 54 | 
test_ood_1991: bool = True, 55 | test_ood_historic: bool = True, 56 | test_ood_future: bool = True, 57 | verbose: bool = True 58 | ): 59 | """ 60 | Args: 61 | exp_type (str): 'pristine' or 'clear-sky' 62 | target_type (str): 'longwave' or 'shortwave' 63 | target_variable (str): 'fluxes' or 'heating-rate' 64 | data_dir (str or None): If str: A path to the data folder, if None: constants.DATA_DIR will be used. 65 | batch_size (int): Batch size for the training dataloader 66 | eval_batch_size (int): Batch size for the test and validation dataloaders 67 | num_workers (int): Dataloader arg for higher efficiency 68 | pin_memory (bool): Dataloader arg for higher efficiency 69 | test_main_dataset (bool): Whether to test and compute metrics on main test dataset (2007-14). Default: True 70 | test_ood_1991 (bool): Whether to test and compute metrics on OOD/anomaly test year 1991. Default: True 71 | test_ood_historic (bool): Whether to test and compute metrics on historic test years 1850-52. Default: True 72 | test_ood_future (bool): Whether to test and compute metrics on future test years 2097-99. Default: True 73 | train_years, validation_years, predict_years (str): Which years to use, parsed by year_string_to_list (e.g. '1999-2000'). 74 | """ 75 | super().__init__() 76 | # The following makes all args available as, e.g., self.hparams.batch_size 77 | self.save_hyperparameters(ignore=["input_transform", "normalizer"]) 78 | self.input_transform = input_transform # self.hparams.input_transform 79 | self.normalizer = normalizer 80 | 81 | self._data_train: Optional[ClimART_HdF5_Dataset] = None 82 | self._data_val: Optional[ClimART_HdF5_Dataset] = None 83 | self._data_test: Optional[List[ClimART_HdF5_Dataset]] = None 84 | self._data_predict: Optional[List[ClimART_HdF5_Dataset]] = None 85 | self._test_set_names: Optional[List[str]] = None 86 | 87 | def prepare_data(self): 88 | """Download data if needed. This method is called only from a single GPU. 89 | Do not use it to assign state (self.x = y).""" 90 | pass 91 | 92 | def setup(self, stage: Optional[str] = None): 93 | """Load data. Set internal variables: self._data_train, self._data_val, self._data_test.""" 94 | dataset_kwargs = dict( 95 | data_dir=self.hparams.data_dir, 96 | exp_type=self.hparams.exp_type, 97 | target_type=self.hparams.target_type, 98 | target_variable=self.hparams.target_variable, 99 | verbose=self.hparams.verbose, 100 | input_transform=self.input_transform, 101 | normalizer=self.normalizer, 102 | ) 103 | 104 | # Training set: 105 | if stage == "fit" or stage is None: 106 | # Get & check list of training/validation years 107 | train_years = year_string_to_list(self.hparams.train_years) 108 | assert all([y in TRAIN_YEARS for y in train_years]), f"All years in --train_years must be in {TRAIN_YEARS}!" 109 | 110 | self._data_train = ClimART_HdF5_Dataset(years=train_years, name='Train', 111 | load_h5_into_mem=self.hparams.load_train_into_mem, 112 | **dataset_kwargs) 113 | # Validation set 114 | if stage in ['fit', 'validate', None] and self.hparams.validation_years is not None: 115 | val_yrs = year_string_to_list(self.hparams.validation_years) 116 | assert all([y in VAL_YEARS for y in val_yrs]), f'All years in --validation_years must be in {VAL_YEARS}!'
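# Unlike the test sets, the validation data is loaded into host memory by default (load_valid_into_mem=True).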
117 | self._data_val = ClimART_HdF5_Dataset(years=val_yrs, name='Val', 118 | load_h5_into_mem=self.hparams.load_valid_into_mem, 119 | **dataset_kwargs) 120 | # Test sets: 121 | # - Main Present-day Test Set(s): 122 | # To compute metrics for each test year -> use a separate dataloader for each of the test years (2007-14). 123 | if stage == "test" or stage is None: 124 | dataset_kwargs["load_h5_into_mem"] = self.hparams.load_test_into_mem 125 | if self.hparams.test_main_dataset: 126 | test_sets = [ 127 | ClimART_HdF5_Dataset(years=[test_year], name=f'Test_{test_year}', **dataset_kwargs) 128 | for test_year in TEST_YEARS 129 | ] 130 | else: 131 | test_sets = [] 132 | log.info(" Main test dataset (2007-14) will not be tested on in this run.") 133 | # - OOD Test Sets: 134 | ood_test_sets = [] 135 | if self.hparams.test_ood_1991: 136 | # 1991 OOD test year accounts for Mt. Pinatubo eruption: especially challenging for clear-sky conditions 137 | ood_test_sets += [ClimART_HdF5_Dataset(years=OOD_PRESENT_YEARS, name='OOD Test', **dataset_kwargs)] 138 | if self.hparams.test_ood_historic: 139 | ood_test_sets += [ 140 | ClimART_HdF5_Dataset(years=OOD_HISTORIC_YEARS, name='Historic Test', **dataset_kwargs)] 141 | if self.hparams.test_ood_future: 142 | ood_test_sets += [ClimART_HdF5_Dataset(years=OOD_FUTURE_YEARS, name='Future Test', **dataset_kwargs)] 143 | 144 | self._data_test = test_sets + ood_test_sets 145 | 146 | # Prediction set: 147 | if stage == "predict" and self.hparams.predict_years is not None: 148 | dataset_kwargs["load_h5_into_mem"] = self.hparams.load_test_into_mem 149 | predict_years = year_string_to_list(self.hparams.predict_years) 150 | self._data_predict = [ 151 | ClimART_HdF5_Dataset(years=[pred_year], name=f'Predict_{pred_year}', **dataset_kwargs) 152 | for pred_year in predict_years 153 | ] 154 | 155 | @property 156 | def test_set_names(self) -> List[str]: 157 | if self._test_set_names is None: 158 | test_names = [] 159 | if self.hparams.test_main_dataset: 160 | test_names += [f'{test_year}' for test_year in TEST_YEARS] 161 | if self.hparams.test_ood_1991: 162 | test_names += ['OOD'] 163 | if self.hparams.test_ood_historic: 164 | test_names += ['historic'] 165 | if self.hparams.test_ood_future: 166 | test_names += ['future'] 167 | self._test_set_names = test_names 168 | return self._test_set_names 169 | 170 | @property 171 | def predict_years(self) -> List[int]: 172 | return year_string_to_list(self.hparams.predict_years) 173 | 174 | @predict_years.setter 175 | def predict_years(self, predict_years: str): 176 | self.hparams.predict_years = predict_years 177 | 178 | def on_before_batch_transfer(self, batch, dataloader_idx): 179 | return batch 180 | 181 | def on_after_batch_transfer(self, batch, dataloader_idx): 182 | return batch 183 | 184 | def _shared_dataloader_kwargs(self) -> dict: 185 | shared_kwargs = dict(num_workers=int(self.hparams.num_workers), pin_memory=self.hparams.pin_memory) 186 | return shared_kwargs 187 | 188 | def _shared_eval_dataloader_kwargs(self) -> dict: 189 | return dict(**self._shared_dataloader_kwargs(), batch_size=self.hparams.eval_batch_size, shuffle=False) 190 | 191 | def train_dataloader(self): 192 | return DataLoader( 193 | dataset=self._data_train, 194 | batch_size=self.hparams.batch_size, 195 | shuffle=True, 196 | **self._shared_dataloader_kwargs(), 197 | ) 198 | 199 | def val_dataloader(self): 200 | return DataLoader( 201 | dataset=self._data_val, 202 | **self._shared_eval_dataloader_kwargs() 203 | ) if self._data_val is not None else None 204 | 
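# Note: test_dataloader() below returns one DataLoader per test subset (each main test year and each OOD split), in the same order as `test_set_names`, so that metrics can be reported per subset; predict_dataloader() likewise returns one DataLoader per prediction year.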
205 | def test_dataloader(self) -> List[DataLoader]: 206 | return [DataLoader( 207 | dataset=data_test_subset, 208 | **self._shared_eval_dataloader_kwargs() 209 | ) for data_test_subset in self._data_test] 210 | 211 | def predict_dataloader(self) -> EVAL_DATALOADERS: 212 | return [DataLoader( 213 | dataset=data_test_subset, 214 | **self._shared_eval_dataloader_kwargs() 215 | ) for data_test_subset in self._data_predict] -------------------------------------------------------------------------------- /climart/interface.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Optional 3 | 4 | import hydra 5 | import torch 6 | from omegaconf import DictConfig 7 | 8 | from climart.data_transform.normalization import Normalizer 9 | from climart.datamodules.pl_climart_datamodule import ClimartDataModule 10 | from climart.models.base_model import BaseModel 11 | 12 | """ 13 | This file contains helper functions that avoid boilerplate code for loading and reloading models and data. 14 | """ 15 | 16 | 17 | def get_model(config: DictConfig, **kwargs) -> BaseModel: 18 | """ 19 | Args: 20 | config (DictConfig): An OmegaConf config (e.g. produced by hydra yaml config file parsing) 21 | Returns: 22 | The model that you can directly use to train with pytorch-lightning 23 | """ 24 | if config.get('normalizer'): 25 | # This can be a bit redundant with get_datamodule (the normalizer is instantiated twice), but it ensures 26 | # that the output_normalizer is used by the model in cases where pytorch-lightning is not used. 27 | # When training with pytorch-lightning, the correct output_normalizer is passed to the model before training, 28 | # even without the code below. 29 | normalizer: Normalizer = hydra.utils.instantiate( 30 | config.normalizer, _recursive_=False, 31 | datamodule_config=config.datamodule, 32 | ) 33 | kwargs['output_normalizer'] = normalizer.output_normalizer 34 | model: BaseModel = hydra.utils.instantiate( 35 | config.model, _recursive_=False, 36 | datamodule_config=config.datamodule, 37 | **kwargs 38 | ) 39 | return model 40 | 41 | 42 | def get_datamodule(config: DictConfig) -> ClimartDataModule: 43 | """ 44 | Args: 45 | config (DictConfig): An OmegaConf config (e.g. produced by hydra yaml config file parsing) 46 | Returns: 47 | A datamodule that you can directly use to train pytorch-lightning models 48 | """ 49 | # First we instantiate our normalization preprocessor, then our datamodule 50 | normalizer: Normalizer = hydra.utils.instantiate( 51 | config.normalizer, 52 | datamodule_config=config.datamodule, 53 | _recursive_=False 54 | ) 55 | data_module: ClimartDataModule = hydra.utils.instantiate( 56 | config.datamodule, 57 | input_transform=config.model.get("input_transform"), 58 | normalizer=normalizer 59 | ) 60 | return data_module 61 | 62 | 63 | def get_model_and_data(config: DictConfig) -> (BaseModel, ClimartDataModule): 64 | """ 65 | Args: 66 | config (DictConfig): An OmegaConf config (e.g. produced by hydra yaml config file parsing) 67 | Returns: 68 | A tuple of (model, datamodule) that you can directly use to train with pytorch-lightning 69 | (e.g., checkpointing the best model w.r.t.
a small validation set with the ModelCheckpoint callback), 70 | with: 71 | trainer.fit(model=model, datamodule=datamodule) 72 | """ 73 | data_module = get_datamodule(config) 74 | model: BaseModel = hydra.utils.instantiate( 75 | config.model, _recursive_=False, 76 | datamodule_config=config.datamodule, 77 | output_normalizer=data_module.normalizer.output_normalizer, 78 | ) 79 | return model, data_module 80 | 81 | 82 | def reload_model_from_config_and_ckpt(config: DictConfig, model_path: str, load_datamodule: bool = False): 83 | model, data_module = get_model_and_data(config) 84 | # Reload model 85 | model_state = torch.load(model_path)['state_dict'] 86 | model.load_state_dict(model_state) 87 | if load_datamodule: 88 | return model, data_module 89 | return model 90 | 91 | 92 | -------------------------------------------------------------------------------- /climart/models/CNNs/CNN.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence, Dict, Union 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch import Tensor 7 | 8 | from einops import rearrange, repeat 9 | 10 | from climart.models.base_model import BaseModel 11 | from climart.utils.utils import get_activation_function, get_normalization_layer 12 | from climart.models.modules.additional_layers import Multiscale_Module, GAP, SE_Block 13 | 14 | 15 | class CNN_Net(BaseModel): 16 | def __init__(self, 17 | hidden_dims: Sequence[int], 18 | dilation: int = 1, 19 | net_normalization: str = 'none', 20 | kernels: Sequence[int] = (20, 10, 5), 21 | strides: Sequence[int] = (2, 1, 1), # 221 22 | gap: bool = False, 23 | se_block: bool = False, 24 | activation_function: str = 'relu', 25 | dropout: float = 0.0, 26 | *args, **kwargs): 27 | super().__init__(*args, **kwargs) 28 | self.save_hyperparameters() 29 | self.output_dim = self.raw_output_dim 30 | self.channel_list = list(hidden_dims) 31 | input_dim = self.input_transform.output_dim 32 | self.channel_list = [input_dim] + self.channel_list 33 | 34 | self.linear_in_shape = 10 35 | self.use_linear = gap 36 | self.ratio = 16 37 | self.kernel_list = list(kernels) 38 | self.stride_list = list(strides) 39 | self.global_average = GAP() 40 | 41 | feat_cnn_modules = [] 42 | for i in range(len(self.channel_list) - 1): 43 | out = self.channel_list[i + 1] 44 | feat_cnn_modules += [nn.Conv1d(in_channels=self.channel_list[i], 45 | out_channels=out, kernel_size=self.kernel_list[i], 46 | stride=self.stride_list[i], 47 | bias=self.hparams.net_normalization != 'batch_norm', 48 | dilation=self.hparams.dilation)] 49 | if se_block: 50 | feat_cnn_modules.append(SE_Block(out, self.ratio)) 51 | if self.hparams.net_normalization != 'none': 52 | feat_cnn_modules += [get_normalization_layer(self.hparams.net_normalization, out)] 53 | feat_cnn_modules += [get_activation_function(activation_function, functional=False)] 54 | # TODO: Need to add adaptive pooling with arguments 55 | feat_cnn_modules += [nn.Dropout(dropout)] 56 | 57 | self.feat_cnn = nn.Sequential(*feat_cnn_modules) 58 | 59 | # input_dim = [self.channel_list[0], self.linear_in_shape] # TODO: Need to pass input shape as argument 60 | # linear_input_shape = functools.reduce(operator.mul, list(self.feat_cnn(torch.rand(1, *input_dim)).shape)) 61 | # print(linear_input_shape) 62 | linear_layers = [] 63 | if not self.use_linear: 64 | linear_layers.append(nn.Linear(int(self.channel_list[-1] / 100) * 400, 256, bias=True)) 65 | 
linear_layers.append(get_activation_function(activation_function, functional=False)) 66 | linear_layers.append(nn.Dropout(dropout)) 67 | linear_layers.append(nn.Linear(256, self.output_dim, bias=True)) 68 | self.ll = nn.Sequential(*linear_layers) 69 | 70 | def forward(self, X: Union[Tensor, Dict[str, Tensor]]) -> Tensor: 71 | """ 72 | input: 73 | Dict with key-values {GLOBALS: x_glob, LEVELS: x_lev, LAYERS: x_lay}, 74 | where x_*** are the corresponding features. 75 | """ 76 | X = self.feat_cnn(X) 77 | 78 | if not self.use_linear: 79 | X = rearrange(X, 'b f c -> b (f c)') 80 | X = self.ll(X) 81 | else: 82 | X = self.global_average(X) 83 | 84 | return X.squeeze(2) 85 | 86 | 87 | class CNN_Multiscale(BaseModel): 88 | def __init__(self, 89 | hidden_dims: Sequence[int], 90 | out_dim: int, 91 | dilation: int = 1, 92 | gap: bool = False, 93 | se_block: bool = False, 94 | use_act: bool = False, 95 | net_normalization: str = 'none', 96 | activation_function: str = 'relu', 97 | dropout: float = 0.0, 98 | *args, **kwargs): 99 | # super().__init__(channels_list, out_dim, column_handler, projection, net_normalization, 100 | # gap, se_block, activation_function, dropout, *args, **kwargs) 101 | super().__init__(*args, **kwargs) 102 | self.save_hyperparameters() 103 | self.channels_per_layer = 200 104 | self.linear_in_shape = 10 105 | self.multiscale_in_shape = 10 106 | self.ratio = 16 107 | self.kernel_list = [6, 4, 4] 108 | self.stride_list = [2, 1, 1] 109 | self.stride = 1 110 | self.use_linear = gap 111 | self.out_size = out_dim 112 | self.channel_list = hidden_dims 113 | 114 | feat_cnn_modules = [] 115 | for i in range(len(self.channel_list) - 1): 116 | out = self.channel_list[i + 1] 117 | feat_cnn_modules.append(nn.Conv1d(in_channels=self.channel_list[i], 118 | out_channels=out, kernel_size=self.kernel_list[i], 119 | stride=self.stride_list[i], 120 | bias=self.hparams.net_normalization != 'batch_norm')) 121 | if se_block: 122 | feat_cnn_modules.append(SE_Block(out, self.ratio)) 123 | if self.hparams.net_normalization != 'none': 124 | feat_cnn_modules += [get_normalization_layer(self.hparams.net_normalization, self.channel_list[i + 1])] 125 | feat_cnn_modules.append(get_activation_function(activation_function, functional=False)) 126 | # TODO: Need to add adaptive pooling with arguments 127 | feat_cnn_modules.append(nn.Dropout(dropout)) 128 | 129 | self.feat_cnn = nn.Sequential(*feat_cnn_modules) 130 | kwargs = {'in_channels': self.channel_list[-1], 'channels_per_layer': self.channels_per_layer, 131 | 'out_shape': self.linear_in_shape, 'dil_rate': self.hparams.dilation, 'use_act': use_act} 132 | self.pyramid = Multiscale_Module(**kwargs) 133 | 134 | input_dim = [self.channel_list[0], self.linear_in_shape] 135 | # TODO: Need to pass input shape as argument 136 | # linear_input_shape = functools.reduce(operator.mul, list(self.feat_cnn(torch.rand(1, *input_dim)).shape)) 137 | linear_layers = [] 138 | # linear_layers.append(nn.Linear(int(self.channel_list[-1]/100)*1000, 300, bias=True)) 139 | linear_layers.append(nn.Linear(2800, 256, bias=True)) 140 | linear_layers.append(get_activation_function(activation_function, functional=False)) 141 | linear_layers.append(nn.Linear(256, self.out_size, bias=True)) 142 | self.ll = nn.Sequential(*linear_layers) 143 | 144 | def forward(self, X: Union[Tensor, Dict[str, Tensor]]) -> Tensor: 145 | """ 146 | input: 147 | Dict with key-values {GLOBALS: x_glob, LEVELS: x_lev, LAYERS: x_lay}, 148 | where x_*** are the corresponding features. 
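The three inputs are padded/broadcast to a common vertical length and concatenated along the feature dimension before being passed through the convolutional stack.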
149 | """ 150 | if isinstance(X, dict): 151 | X_levels = X['levels'] 152 | 153 | X_layers = rearrange(F.pad(rearrange(X['layers'], 'b c f -> () b c f'), (0, 0, 1, 0), \ 154 | mode='reflect'), '() b c f -> b c f') 155 | X_global = repeat(X['globals'], 'b f -> b c f', c=X_levels.shape[1]) 156 | 157 | X = torch.cat((X_levels, X_layers, X_global), -1) 158 | X = rearrange(X, 'b c f -> b f c') 159 | 160 | X = self.feat_cnn(X) 161 | X = rearrange(X, 'b f c -> b (f c)') 162 | X = self.ll(X) 163 | 164 | return X.squeeze(1) 165 | -------------------------------------------------------------------------------- /climart/models/CNNs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RolnickLab/climart/6d12fa478de5db5209842e0369a2a24377079edd/climart/models/CNNs/__init__.py -------------------------------------------------------------------------------- /climart/models/GraphNet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RolnickLab/climart/6d12fa478de5db5209842e0369a2a24377079edd/climart/models/GraphNet/__init__.py -------------------------------------------------------------------------------- /climart/models/GraphNet/constants.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | from typing import Dict 3 | 4 | from torch import Tensor 5 | 6 | 7 | GLOBALS = "globals" 8 | NODES = "nodes" 9 | EDGES = "edges" 10 | GRAPH_COMPONENTS = [GLOBALS, NODES, EDGES] 11 | 12 | GraphComponentToTensor = Dict[str, Tensor] 13 | 14 | 15 | class AggregationTypes(Enum): 16 | AGG_E_TO_N = "edge_to_node_aggregation" 17 | AGG_E_TO_U = "edge_to_global_aggregation" 18 | AGG_N_TO_U = "node_to_global_aggregation" 19 | AGGREGATORS = [AGG_E_TO_N, AGG_E_TO_U, AGG_N_TO_U] 20 | 21 | 22 | N_NODES = "n_nodes" 23 | N_EDGES = "n_edges" 24 | SPATIAL_DIM = 1 25 | -------------------------------------------------------------------------------- /climart/models/GraphNet/graph_network.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Sequence, Optional, Union 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | from torch import Tensor 6 | 7 | from climart.data_transform.transforms import AbstractGraphTransform 8 | from climart.models.GraphNet.constants import AggregationTypes, NODES, EDGES, GLOBALS, GRAPH_COMPONENTS, SPATIAL_DIM 9 | from climart.models.GraphNet.graph_network_block import GraphNetBlock 10 | from climart.models.base_model import BaseModel 11 | from climart.models.modules.mlp import MLP 12 | 13 | 14 | class GraphNetwork(BaseModel): 15 | def __init__(self, 16 | # input_dim: Dict[str, int], 17 | hidden_dims: Sequence[int], 18 | input_transform: AbstractGraphTransform, 19 | readout_which_output: Optional[str] = NODES, 20 | update_mlp_n_layers: int = 1, 21 | aggregator_funcs: Union[str, Dict[AggregationTypes, int]] = 'sum', 22 | net_normalization: str = 'layer_norm', 23 | residual: Union[bool, Dict[str, bool]] = True, 24 | activation_function: str = 'Gelu', 25 | output_activation_function: Optional[str] = None, 26 | output_net_normalization: bool = True, 27 | dropout: float = 0.0, 28 | *args, **kwargs): 29 | """ 30 | Args: 31 | readout_which_output: Which graph part to return (default: nodes), 32 | can be {EDGES, NODES, GLOBALS, 'graph', None} 33 | If None or 'graph', the whole graph is returned.
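hidden_dims: Output dimension of each stacked GraphNetBlock.
aggregator_funcs: Aggregation used inside each block; either a single name ('sum', 'mean' or 'max') applied everywhere, or a dict keyed by AggregationTypes.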
34 | """ 35 | super().__init__(input_transform=input_transform, *args, **kwargs) 36 | self.save_hyperparameters(ignore="verbose_mlp") 37 | assert len(self.hparams.hidden_dims) >= 1 38 | assert update_mlp_n_layers >= 1 39 | self.input_transform: AbstractGraphTransform = self.input_transform 40 | in_dim = self.input_transform.output_dim 41 | 42 | senders, receivers = self.input_transform.get_edge_idxs() 43 | gn_layers = [] 44 | dims = [in_dim] + list(hidden_dims) 45 | for i in range(1, len(dims)): 46 | out_activation_function = output_activation_function if i == len(dims) - 1 else activation_function 47 | out_net_norm = output_net_normalization if i == len(dims) - 1 else True 48 | gn_layers += [ 49 | GraphNetBlock( 50 | in_dims=in_dim, 51 | out_dims=dims[i], 52 | senders=senders, 53 | receivers=receivers, 54 | n_layers=update_mlp_n_layers, 55 | residual=residual, 56 | net_norm=net_normalization, 57 | activation=activation_function, 58 | dropout=dropout, 59 | output_normalization=out_net_norm, 60 | output_activation_function=out_activation_function, 61 | aggregator_funcs=aggregator_funcs, 62 | )] 63 | in_dim = dims[i] 64 | 65 | self.layers: nn.ModuleList[GraphNetBlock] = nn.ModuleList(gn_layers) 66 | self.output_type = readout_which_output 67 | if self.output_type not in [NODES, EDGES, GLOBALS, 'graph', None]: 68 | raise ValueError("Unsupported argument for GraphNetwork `output_type`", readout_which_output) 69 | if hasattr(self.input_transform, "n_edges"): 70 | err_msg = f"GN inferred {self.n_edges} edges, but input_transform refers to {self.input_transform.n_edges}" 71 | assert self.n_edges == self.input_transform.n_edges, err_msg 72 | 73 | @property 74 | def n_edges(self): 75 | return self.layers[0].n_edges 76 | 77 | def update_graph_structure(self, senders: Sequence[int], receivers: Sequence[int]) -> None: 78 | for layer in self.layers: 79 | layer.update_graph_structure(senders, receivers) 80 | 81 | def forward(self, input: Dict[str, Tensor]): 82 | """ 83 | input: 84 | Dict with key-values {GLOBALS: x_glob, LEVELS: x_lev, LAYERS: x_lay}, 85 | where x_*** are the corresponding features.
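Returns: the component selected by `readout_which_output`, reshaped to (batch-size, -1), or the full updated graph dict if the readout is None or 'graph'.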
86 | """ 87 | graph_new = self.update_graph(input) 88 | if self.output_type is not None and self.output_type != 'graph': 89 | graph_component = graph_new[self.output_type] 90 | return graph_component.reshape(graph_component.shape[0], -1) 91 | else: 92 | return graph_new 93 | 94 | def update_graph(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: 95 | graph_net_input = input 96 | graph_new = self.layers[0](graph_net_input) 97 | for graph_net_block in self.layers[1:]: 98 | graph_new = graph_net_block(graph_new) 99 | 100 | return graph_new 101 | 102 | 103 | class GN_withReadout(GraphNetwork): 104 | def __init__(self, 105 | # input_dim, 106 | # out_dim: int, 107 | readout_which_output=NODES, 108 | graph_pooling: str = 'mean', 109 | *args, **kwargs): 110 | super().__init__(output_activation_function=kwargs[ 111 | 'activation_function'] if 'activation_function' in kwargs else 'gelu', 112 | *args, **kwargs) 113 | assert readout_which_output in GRAPH_COMPONENTS 114 | self.readout_which_output = readout_which_output 115 | self.graph_pooling = graph_pooling.lower() 116 | 117 | self.mlp_input_dim = self.hparams.hidden_dims[-1] 118 | if readout_which_output in [NODES, EDGES] and self.graph_pooling in ['sum+mean', 'mean+sum', 'mean&sum', 119 | 'sum&mean']: 120 | self.mlp_input_dim = self.mlp_input_dim * 2 121 | 122 | self.readout_mlp = MLP( 123 | input_dim=self.mlp_input_dim, 124 | hidden_dims=[int((self.mlp_input_dim + self.raw_output_dim) / 2)], 125 | output_dim=self.raw_output_dim, 126 | dropout=kwargs['dropout'] if 'dropout' in kwargs else 0.0, 127 | activation_function=self.hparams.activation_function, 128 | net_normalization='layer_norm' if 'inst' in self.hparams.net_normalization else self.hparams.net_normalization, 129 | output_normalization=False, output_activation_function=None, out_layer_bias_init=self.out_layer_bias_init, 130 | name='GraphNet_Readout_MLP' 131 | ) 132 | 133 | def forward(self, input: Dict[str, torch.Tensor]): 134 | final_graph = self.update_graph(input) 135 | output_to_use = final_graph[self.readout_which_output] 136 | 137 | # Graph pooling, e.g. 
take the mean over all node embeddings (dimension=1) 138 | if self.readout_which_output == GLOBALS: 139 | g_emb = output_to_use 140 | else: 141 | if self.graph_pooling == 'sum': 142 | g_emb = torch.sum(output_to_use, dim=SPATIAL_DIM) 143 | elif self.graph_pooling == 'mean': 144 | g_emb = torch.mean(output_to_use, dim=SPATIAL_DIM) 145 | elif self.graph_pooling == 'max': 146 | g_emb, _ = torch.max(output_to_use, dim=SPATIAL_DIM) # returns (values, indices) 147 | elif self.graph_pooling in ['sum+mean', 'mean+sum', 'mean&sum', 'sum&mean']: 148 | xmean = torch.mean(output_to_use, dim=SPATIAL_DIM) 149 | xsum = torch.sum(output_to_use, dim=SPATIAL_DIM) # (batch-size, out-dim) 150 | g_emb = torch.cat((xmean, xsum), dim=SPATIAL_DIM) # (batch-size 2*out-dim) 151 | else: 152 | raise ValueError('Unsupported readout operation', self.graph_pooling) 153 | 154 | # After graph pooling: (batch-size, out-dim) 155 | # torch.Size([64, 100, 2]) 156 | # torch.Size([64, 100]) 157 | # After reshape: torch.Size([64, 200]) 158 | out = self.readout_mlp(g_emb) 159 | return out 160 | 161 | -------------------------------------------------------------------------------- /climart/models/GraphNet/graph_network_block.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from functools import partial 3 | from typing import Dict, Union, Optional, Sequence 4 | from einops import rearrange, repeat 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | import torch_scatter 10 | from climart.models.modules.mlp import MLP 11 | import climart.models.GraphNet.constants as gn_constants 12 | from climart.models.GraphNet.constants import GraphComponentToTensor, AggregationTypes, NODES, EDGES, GLOBALS 13 | 14 | 15 | class GraphNetBlock(nn.Module): 16 | def __init__(self, 17 | in_dims: Union[int, Dict[str, int]], 18 | out_dims: Union[int, Dict[str, int]], 19 | senders: Sequence[int], 20 | receivers: Sequence[int], 21 | n_layers: int = 1, 22 | use_edge_features: bool = True, 23 | use_global_features: bool = True, 24 | residual: Union[bool, Dict[str, bool]] = False, 25 | net_norm: str = 'none', 26 | activation: str = 'relu', 27 | dropout: float = 0, 28 | output_normalization: bool = True, 29 | output_activation_function: Optional[str] = None, 30 | aggregator_funcs: Union[str, Dict[AggregationTypes, int]] = 'sum', 31 | ): 32 | super().__init__() 33 | if isinstance(in_dims, int): 34 | in_dims: dict = {c: in_dims for c in gn_constants.GRAPH_COMPONENTS} 35 | if isinstance(out_dims, int): 36 | out_dims: dict = {c: out_dims for c in gn_constants.GRAPH_COMPONENTS} 37 | 38 | self.components = gn_constants.GRAPH_COMPONENTS 39 | self.use_edge_features = use_edge_features 40 | self.use_global_features = use_global_features 41 | 42 | if not use_edge_features: 43 | in_dims[EDGES] = out_dims[EDGES] = 0 44 | self.components.remove(EDGES) 45 | if not use_global_features: 46 | in_dims[GLOBALS] = out_dims[GLOBALS] = 0 47 | self.components.remove(GLOBALS) 48 | 49 | n_feats_e = in_dims[EDGES] 50 | n_feats_n = in_dims[NODES] 51 | n_feats_u = in_dims[GLOBALS] 52 | self._n_edges = None 53 | self.update_graph_structure(senders, receivers) 54 | self.residual = {c: residual for c in self.components} 55 | 56 | in_dims = { 57 | EDGES: 2 * n_feats_n + n_feats_e + n_feats_u, 58 | NODES: n_feats_n + out_dims[EDGES] + n_feats_u, 59 | GLOBALS: out_dims[NODES] + out_dims[EDGES] + n_feats_u 60 | } 61 | 62 | update_funcs = OrderedDict() # nn.ModuleDict() 63 | for component in 
self.components: 64 | c_in_dim = in_dims[component] 65 | out_dim = out_dims[component] 66 | if c_in_dim != out_dim: 67 | self.residual[component] = False 68 | 69 | in_dim = in_dims[component] 70 | hdim = int((in_dim + out_dim) / 2) 71 | 72 | update_funcs[component] = MLP( 73 | input_dim=in_dim, 74 | hidden_dims=[hdim for _ in range(n_layers)], 75 | output_dim=out_dim, 76 | activation_function=activation, 77 | net_normalization=net_norm, 78 | dropout=dropout, 79 | output_normalization=output_normalization, 80 | output_activation_function=output_activation_function, 81 | out_layer_bias_init=None, 82 | name=f'GN_{component}_update_MLP', 83 | ) 84 | self.update_funcs = nn.ModuleDict(update_funcs) 85 | 86 | self.aggregator_funcs = OrderedDict() # nn.ModuleDict() 87 | for aggregation in AggregationTypes: 88 | agg_func = aggregator_funcs[aggregation] if isinstance(aggregator_funcs, dict) else aggregator_funcs 89 | agg_func = agg_func.lower() 90 | agg_dim = gn_constants.SPATIAL_DIM 91 | if agg_func == 'sum': 92 | if aggregation == AggregationTypes.AGG_E_TO_N: 93 | # agg func returns a (batch-size, #nodes, #edge-feats) tensor 94 | agg_func = partial(torch_scatter.scatter_sum, dim=agg_dim) 95 | else: 96 | agg_func = partial(torch.sum, dim=agg_dim) 97 | elif agg_func == 'mean': 98 | if aggregation == AggregationTypes.AGG_E_TO_N: 99 | agg_func = partial(torch_scatter.scatter_mean, dim=agg_dim) 100 | else: 101 | agg_func = partial(torch.mean, dim=agg_dim) 102 | elif agg_func == 'max': 103 | if aggregation == AggregationTypes.AGG_E_TO_N: 104 | def max_scatter_partial(x, index): 105 | return torch_scatter.scatter_max(x, dim=agg_dim, index=index)[0] 106 | agg_func = max_scatter_partial 107 | else: 108 | agg_func = lambda x: torch.max(x, dim=agg_dim)[0] # returns (values, indices)[0] 109 | # elif agg_func in ['sum+mean', 'mean+sum', 'mean&sum', 'sum&mean']: 110 | # xmean = torch.mean(final_embs, dim=1) 111 | # xsum = torch.sum(final_embs, dim=1) # (batch-size, out-dim) 112 | # g_emb = torch.cat((xmean, xsum), dim=1) # (batch-size 2*out-dim) 113 | else: 114 | raise ValueError('Unsupported aggregation operation') 115 | self.aggregator_funcs[aggregation] = agg_func 116 | 117 | @property 118 | def n_edges(self): 119 | return self._n_edges 120 | 121 | def update_graph_structure(self, senders: Sequence[int], receivers: Sequence[int]) -> None: 122 | if not torch.is_tensor(senders): 123 | senders = torch.from_numpy(senders) if isinstance(senders, np.ndarray) else torch.tensor(senders) 124 | if not torch.is_tensor(receivers): 125 | receivers = torch.from_numpy(receivers) if isinstance(receivers, np.ndarray) else torch.tensor(receivers) 126 | self.register_buffer('_senders', senders.long()) # so that they are moved to correct device 127 | self.register_buffer('_receivers', receivers.long()) 128 | assert len(self._receivers) == len(self._senders), "Sender and receiver must both have #edges indices" 129 | self._n_edges = len(self._senders) 130 | 131 | # def _aggregate_edges_for_node(self, mode, edge_feats_old: Tensor) -> Tensor: 132 | # """ Return a (batch-size, #nodes, #edge-feats) tensor, K, where K_bij = Agg({e_j | receiver_j = i}) """ 133 | # indices = self._senders if mode == gn_constants.SENDERS else self._receivers 134 | # batch_size = edge_feats_old.shape[0] 135 | 136 | def forward(self, graph: GraphComponentToTensor) -> GraphComponentToTensor: 137 | if self.use_edge_features: 138 | edge_feats_old = graph[EDGES] # edge_feats_old have shape (b, #edges, #edge-feats) 139 | if 
edge_feats_old.shape[gn_constants.SPATIAL_DIM] != self._n_edges: 140 | raise ValueError(f"Edge features imply {edge_feats_old.shape[gn_constants.SPATIAL_DIM]} edges, " 141 | f"while sender and receiver lists imply {self._n_edges} edges.") 142 | node_feats_old = graph[NODES] # node_feats_old have shape (b, #nodes, #node-feats) 143 | n_nodes = node_feats_old.shape[gn_constants.SPATIAL_DIM] 144 | if self.use_global_features: 145 | global_feats_old = graph[GLOBALS] # global_feats_old have shape (b, #global-feats) 146 | batch_size, n_glob_feats = global_feats_old.shape 147 | 148 | out = {c: None for c in self.components} 149 | 150 | # ----------------------- Update edges 151 | # self.senders and self.receivers are a sequence of indices with #edges elements 152 | sender_feats = node_feats_old.index_select(index=self._senders, dim=gn_constants.SPATIAL_DIM) 153 | receiver_feats = node_feats_old.index_select(index=self._receivers, dim=gn_constants.SPATIAL_DIM) 154 | 155 | # print('E:', edge_feats_old.shape, 'V:', node_feats_old.shape, 'U:', global_feats_old.shape) 156 | # print('Senders', sender_feats.shape, 'Recvs', receiver_feats.shape, global_feats_unsqueezed.shape) 157 | mlp_input_e = torch.cat([ 158 | edge_feats_old, # (b, #edges, #edge-feats) 159 | sender_feats, # (b, #edges, #node-feats) 160 | receiver_feats, # (b, #edges, #node-feats) 161 | repeat(global_feats_old, 'b g -> b e g', e=self._n_edges) # (1, self.n_edges, 1) (b, 1, #global-feats) 162 | ], dim=-1) 163 | 164 | mlp_input_e = rearrange(mlp_input_e, 'b e d1 -> (b e) d1') 165 | 166 | out[EDGES] = self.update_funcs[EDGES](mlp_input_e) 167 | out[EDGES] = rearrange(out[EDGES], '(b e) d2 -> b e d2', b=batch_size, e=self._n_edges) 168 | if self.residual[EDGES]: 169 | out[EDGES] += edge_feats_old 170 | # ----------------------- Update nodes 171 | aggregated_edge_feats_for_node = self.aggregator_funcs[AggregationTypes.AGG_E_TO_N]( 172 | out[EDGES], index=self._receivers 173 | ) 174 | 175 | mlp_input_n = torch.cat([ 176 | aggregated_edge_feats_for_node, # (b, #nodes, #edge-feats) 177 | node_feats_old, # (b, #nodes, #node-feats) 178 | repeat(global_feats_old, 'b g -> b n g', n=n_nodes) # (b, #nodes, #global-feats) 179 | ], dim=-1) 180 | mlp_input_n = rearrange(mlp_input_n, 'b n d1 -> (b n) d1') 181 | 182 | # print('Agg E:', aggregated_edge_feats_for_node.shape, 'V:', node_feats_old.shape, mlp_input_n.shape) 183 | out[NODES] = self.update_funcs[NODES](mlp_input_n) 184 | out[NODES] = rearrange(out[NODES], '(b n) d2 -> b n d2', b=batch_size, n=n_nodes) 185 | if self.residual[NODES]: 186 | out[NODES] += node_feats_old 187 | # ----------------------- Update global features 188 | aggregated_edge_feats_for_global = self.aggregator_funcs[AggregationTypes.AGG_E_TO_U](out[EDGES]) 189 | aggregated_node_feats_for_global = self.aggregator_funcs[AggregationTypes.AGG_N_TO_U](out[NODES]) 190 | # print('Agg EU:', aggregated_edge_feats_for_global.shape, 'VU:', aggregated_node_feats_for_global.shape) 191 | 192 | mlp_input_u = torch.cat([ 193 | aggregated_edge_feats_for_global, # (b, #edge-feats) 194 | aggregated_node_feats_for_global, # (b, #node-feats) 195 | global_feats_old # (b, #global-feats) 196 | ], dim=-1) 197 | out[GLOBALS] = self.update_funcs[GLOBALS](mlp_input_u) 198 | if self.residual[GLOBALS]: 199 | out[GLOBALS] += global_feats_old 200 | 201 | # for c in self.components: 202 | # if self.residual[c]: 203 | # out[c] += graph[c] # residual connection 204 | 205 | return out 206 | 207 | 208 | 209 | if __name__ == '__main__': 210 | nedges = 50 211 | nnodes 
= 49 212 | nfe, nfn, nfu = 2, 22, 82 213 | b = 64 214 | E = torch.randn((b, nedges, nfe)) 215 | V = torch.randn((b, nnodes, nfn)) 216 | U = torch.randn((b, nfu)) 217 | 218 | senders = torch.arange(nedges) % nnodes 219 | receivers = (torch.arange(nedges) + 1) % nnodes 220 | 221 | gnl = GraphNetBlock( 222 | in_dims={NODES: nfn, EDGES: nfe, GLOBALS: nfu}, out_dims=128, senders=senders, receivers=receivers 223 | ) 224 | 225 | X_in = {NODES: V, EDGES: E, GLOBALS: U} 226 | x = gnl(X_in) 227 | print([k + str(xx.shape) for k, xx in x.items()]) 228 | -------------------------------------------------------------------------------- /climart/models/MLP.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence, Optional, Dict, Union 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | from omegaconf import DictConfig 7 | from torch import Tensor 8 | 9 | from climart.models.base_model import BaseModel 10 | from climart.models.modules.mlp import MLP 11 | 12 | 13 | class ClimartMLP(BaseModel): 14 | def __init__(self, 15 | hidden_dims: Sequence[int], 16 | datamodule_config: DictConfig = None, 17 | net_normalization: Optional[str] = None, 18 | activation_function: str = 'relu', 19 | dropout: float = 0.0, 20 | residual: bool = False, 21 | output_normalization: bool = False, 22 | output_activation_function: Optional[Union[str, bool]] = None, 23 | *args, **kwargs): 24 | """ 25 | Args: 26 | input_dim must either be an int, i.e. the expected 1D input tensor dim, or a dict s.t. 27 | input_dim and spatial_dim have the same keys to compute the flattened input shape. 28 | output_activation_function (str, bool, optional): By default no output activation function is used (None). 29 | If a string is passed, is must be the name of the desired output activation (e.g. 'softmax') 30 | If True, the same activation function is used as defined by the arg `activation_function`. 
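hidden_dims (Sequence[int]): Sizes of the hidden MLP layers.
residual (bool): If True, residual connections are used in hidden blocks whose input and output dimensions match.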
31 | """ 32 | super().__init__(datamodule_config=datamodule_config, *args, **kwargs) 33 | self.save_hyperparameters() 34 | 35 | if isinstance(self.raw_input_dim, dict): 36 | assert all([k in self.raw_spatial_dim.keys() for k in self.raw_input_dim.keys()]) 37 | self.input_dim = sum([self.raw_input_dim[k] * max(1, self.raw_spatial_dim[k]) for k in self.raw_input_dim.keys()]) # flattened 38 | self.log_text.info(f' Inferred a flattened input dim = {self.input_dim}') 39 | else: 40 | self.input_dim = self.raw_input_dim 41 | 42 | self.output_dim = self.raw_output_dim 43 | 44 | self.mlp = MLP( 45 | self.input_dim, hidden_dims, self.output_dim, 46 | net_normalization=net_normalization, activation_function=activation_function, dropout=dropout, 47 | residual=residual, output_normalization=output_normalization, 48 | output_activation_function=output_activation_function, out_layer_bias_init=self.out_layer_bias_init 49 | ) 50 | 51 | def forward(self, X: Tensor) -> Tensor: 52 | return self.mlp(X) 53 | -------------------------------------------------------------------------------- /climart/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RolnickLab/climart/6d12fa478de5db5209842e0369a2a24377079edd/climart/models/__init__.py -------------------------------------------------------------------------------- /climart/models/baseline.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class LW_abs(nn.Module): 7 | """ 8 | Inputs: 18 (Temp, Pres, and Gaseous concentrations) as inputs 9 | Outputs: 256 (g-points LW absorption cross section) 10 | 11 | """ 12 | def __init__(self, input_size): 13 | super(LW_abs, self).__init__() 14 | 15 | self.input_size = input_size 16 | self.fc1 = nn.Linear(self.input_size, 58) #Not sure about the input shape, depends on nlevels 17 | self.fc2 = nn.Linear(58, 58) 18 | self.fc3 = nn.Linear(58, 256) 19 | 20 | 21 | def forward(self, x): 22 | 23 | x = F.softsign(self.fc1(x)) 24 | x = F.softsign(self.fc2(x)) 25 | x = self.fc3(x) 26 | return x 27 | 28 | 29 | class LW_ems(nn.Module): 30 | """ 31 | Inputs: 18 (Temp, Pres, and Gaseous concentrations) as inputs 32 | Outputs: 256 (g-points LW emission) 33 | 34 | """ 35 | def __init__(self, input_size): 36 | super(LW_ems, self).__init__() 37 | 38 | self.input_size = input_size 39 | self.fc1 = nn.Linear(self.input_size, 16) 40 | self.fc2 = nn.Linear(16, 16) 41 | self.fc3 = nn.Linear(16, 256) 42 | 43 | 44 | def forward(self, x): 45 | 46 | x = F.softsign(self.fc1(x)) 47 | x = F.softsign(self.fc2(x)) 48 | x = self.fc3(x) 49 | return x 50 | 51 | 52 | class SW_abs(nn.Module): 53 | """ 54 | Inputs: 7 (Temp, Pres, and Gaseous concentrations) as inputs, SW takes fewer gases than LW 55 | Outputs: 224 (g-points SW absorption cross section) 56 | 57 | """ 58 | def __init__(self, input_size): 59 | super(SW_abs, self).__init__() 60 | 61 | self.input_size = input_size 62 | self.fc1 = nn.Linear(self.input_size, 48) 63 | self.fc2 = nn.Linear(48, 48) 64 | self.fc3 = nn.Linear(48, 224) 65 | 66 | 67 | def forward(self, x): 68 | 69 | x = F.softsign(self.fc1(x)) 70 | x = F.softsign(self.fc2(x)) 71 | x = self.fc3(x) 72 | return x 73 | 74 | 75 | 76 | class SW_rcs(nn.Module): 77 | """ 78 | Inputs: 7 (Temp, Pres, and Gaseous concentrations) as inputs, SW takes fewer gases than LW 79 | Outputs: 224 (g-points SW Rayleigh cross section) 80 | 81 | """ 82 | def
__init__(self, input_size): 83 | super(SW_rcs, self).__init__() 84 | 85 | self.input_size = input_size 86 | self.fc1 = nn.Linear(self.input_size, 16) 87 | self.fc2 = nn.Linear(16, 16) 88 | self.fc3 = nn.Linear(16, 224) 89 | 90 | 91 | def forward(self, x): 92 | 93 | x = F.softsign(self.fc1(x)) 94 | x = F.softsign(self.fc2(x)) 95 | x = self.fc3(x) 96 | return x 97 | 98 | 99 | if __name__ == "__main__": 100 | net = SW_rcs(180) 101 | print(net) 102 | 103 | 104 | input = torch.randn(180) 105 | out = net(input) 106 | print(out) 107 | # params = list(net.parameters()) 108 | # print(params) 109 | -------------------------------------------------------------------------------- /climart/models/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RolnickLab/climart/6d12fa478de5db5209842e0369a2a24377079edd/climart/models/modules/__init__.py -------------------------------------------------------------------------------- /climart/models/modules/additional_layers.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Dict, Optional, Union, Callable, Any 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from torch import Tensor 9 | from climart.models.modules.mlp import MLP 10 | 11 | 12 | class Multiscale_Module(nn.Module): 13 | 14 | def __init__(self, in_channels=None, channels_per_layer=None, out_shape=None, 15 | dil_rate=1, use_act=False, *args, **kwargs): 16 | super().__init__() 17 | self.out_shape = out_shape 18 | self.use_act = use_act 19 | 20 | self.multi_3 = nn.Conv1d(in_channels=in_channels, out_channels=channels_per_layer, 21 | kernel_size=5, stride=1, dilation=dil_rate) 22 | self.multi_5 = nn.Conv1d(in_channels=in_channels, out_channels=channels_per_layer, 23 | kernel_size=6, stride=1, dilation=dil_rate) 24 | self.multi_7 = nn.Conv1d(in_channels=in_channels, out_channels=channels_per_layer, 25 | kernel_size=9, stride=1, dilation=dil_rate) 26 | self.after_concat = nn.Conv1d(in_channels=int(channels_per_layer * 3), 27 | out_channels=int(channels_per_layer / 2), kernel_size=1, stride=1) 28 | self.gap = GAP() 29 | 30 | def forward(self, x): 31 | x_3 = self.multi_3(x) 32 | x_5 = self.multi_5(x) 33 | x_7 = self.multi_7(x) 34 | x_3 = F.adaptive_max_pool1d(x_3, self.out_shape) 35 | x_5 = F.adaptive_max_pool1d(x_5, self.out_shape) 36 | x_7 = F.adaptive_max_pool1d(x_7, self.out_shape) 37 | x_concat = torch.cat((x_3, x_5, x_7), 1) 38 | x_concat = self.after_concat(x_concat) 39 | 40 | if self.use_act: 41 | return torch.sigmoid(self.gap(x)) * x_concat 42 | else: 43 | return x_concat 44 | 45 | 46 | class GAP(): 47 | def __init__(self): 48 | pass 49 | 50 | def __call__(self, x): 51 | x = F.adaptive_avg_pool1d(x, 1) 52 | return x 53 | 54 | 55 | class SE_Block(nn.Module): 56 | def __init__(self, c, r=16): 57 | super().__init__() 58 | self.squeeze = GAP() 59 | self.excitation = nn.Sequential( 60 | nn.Linear(c, c // r, bias=False), 61 | nn.ReLU(inplace=True), 62 | nn.Linear(c // r, c, bias=False), 63 | nn.Sigmoid() 64 | ) 65 | 66 | def forward(self, x): 67 | b, c, f = x.shape 68 | y = self.squeeze(x).view(b, c) 69 | y = self.excitation(y).view(b, c, 1) 70 | return x * y.expand_as(x) 71 | 72 | 73 | class FeatureProjector(nn.Module): 74 | def __init__( 75 | self, 76 | input_name_to_feature_dim: Dict[str, int], 77 | projection_dim: int = 128, 78 | projector_n_layers: int = 1, 79 | 
projector_activation_func: str = 'Gelu', 80 | projector_net_normalization: str = 'none', 81 | projections_aggregation: Optional[Union[Callable[[Dict[str, Tensor]], Any], str]] = None, 82 | output_normalization: bool = True, 83 | output_activation_function: bool = False 84 | ): 85 | super().__init__() 86 | self.projection_dim = projection_dim 87 | self.input_name_to_feature_dim = input_name_to_feature_dim 88 | self.projections_aggregation = projections_aggregation 89 | 90 | input_name_to_projector = OrderedDict() 91 | for input_name, feature_dim in input_name_to_feature_dim.items(): 92 | projector_hidden_dim = int((feature_dim + projection_dim) / 2) 93 | projector = MLP( 94 | input_dim=feature_dim, 95 | hidden_dims=[projector_hidden_dim for _ in range(projector_n_layers)], 96 | output_dim=projection_dim, 97 | activation_function=projector_activation_func, 98 | net_normalization=projector_net_normalization, 99 | dropout=0, 100 | output_normalization=output_normalization, 101 | output_activation_function=output_activation_function, 102 | name=f"{input_name}_MLP_projector" 103 | ) 104 | input_name_to_projector[input_name] = projector 105 | 106 | self.input_name_to_projector = nn.ModuleDict(input_name_to_projector) 107 | 108 | def forward(self, 109 | inputs: Dict[str, Tensor] 110 | ) -> Union[Dict[str, Tensor], Any]: 111 | name_to_projection = dict() 112 | 113 | for name, projector in self.input_name_to_projector.items(): 114 | # Project (batch-size, ..arbitrary_dim(s).., in_feat_dim) 115 | # to (batch-size, ..arbitrary_dim(s).., projection_dim) 116 | in_feat_dim = self.input_name_to_feature_dim[name] 117 | input_tensor = inputs[name] 118 | shape_out = list(input_tensor.shape) 119 | shape_out[-1] = self.projection_dim 120 | 121 | projector_input = input_tensor.reshape((-1, in_feat_dim)) 122 | # projector_input has shape (batch-size * #spatial-dims, features) 123 | flattened_projection = projector(projector_input) 124 | name_to_projection[name] = flattened_projection.reshape(shape_out) 125 | 126 | if self.projections_aggregation is None: 127 | return name_to_projection # Dict[str, Tensor] 128 | else: 129 | return self.projections_aggregation(name_to_projection) # Any 130 | 131 | 132 | class PredictorHeads(nn.Module): 133 | """ 134 | Module to predict (with one or more MLPs) desired output based on a 1D hidden representation. 135 | Can be used to: 136 | - readout a hidden vector to a desired output dimensionality 137 | - predict multiple variables with separate MLP heads (e.g. 
one for rsuc and rsdc each) 138 | """ 139 | 140 | def __init__( 141 | self, 142 | input_dim: int, 143 | var_name_to_output_dim: Dict[str, int], 144 | separate_heads: bool = True, 145 | n_layers: int = 1, 146 | activation_func: str = 'Gelu', 147 | net_normalization: str = 'none', 148 | ): 149 | super().__init__() 150 | self.input_dim = input_dim 151 | self.separate_heads = separate_heads 152 | 153 | predictor_heads = OrderedDict() 154 | mlp_shared_params = { 155 | 'input_dim': input_dim, 156 | 'activation_function': activation_func, 157 | 'net_normalization': net_normalization, 158 | 'output_normalization': False, 159 | 'dropout': 0 160 | } 161 | if self.separate_heads: 162 | self.output_name_to_feature_dim = var_name_to_output_dim 163 | 164 | for output_name, var_out_dim in var_name_to_output_dim.items(): 165 | predictor_hidden_dim = int((input_dim + var_out_dim) / 2) 166 | predictor = MLP( 167 | hidden_dims=[predictor_hidden_dim for _ in range(n_layers)], 168 | output_dim=var_out_dim, 169 | **mlp_shared_params 170 | ) 171 | predictor_heads[output_name] = predictor 172 | else: 173 | joint_out_dim = sum([out_dim for _, out_dim in var_name_to_output_dim.items()]) 174 | self.output_name_to_feature_dim = {'joint_output': joint_out_dim} 175 | 176 | predictor_hidden_dim = int((input_dim + joint_out_dim) / 2) 177 | predictor = MLP( 178 | hidden_dims=[predictor_hidden_dim for _ in range(n_layers)], 179 | output_dim=joint_out_dim, 180 | **mlp_shared_params 181 | ) 182 | predictor_heads['joint_output'] = predictor 183 | 184 | self.predictor_heads = nn.ModuleDict(predictor_heads) 185 | 186 | def forward(self, 187 | hidden_input: Tensor, # (batch-size, hidden-dim) 1D tensor 188 | as_dict: bool = False, 189 | ) -> Union[Dict[str, Tensor], Tensor]: 190 | 191 | name_to_prediction = OrderedDict() 192 | for name, predictor in self.predictor_heads.items(): 193 | name_to_prediction[name] = predictor(hidden_input) 194 | 195 | if self.separate_heads: 196 | if as_dict: 197 | return name_to_prediction 198 | else: 199 | joint_output = torch.cat(list(name_to_prediction.values()), dim=-1) 200 | return joint_output 201 | else: 202 | return name_to_prediction if as_dict else name_to_prediction['joint_output'] 203 | -------------------------------------------------------------------------------- /climart/models/modules/mlp.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence, Optional, Dict, Union 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | from torch import Tensor 7 | from climart.utils.utils import get_activation_function, get_normalization_layer, get_logger 8 | 9 | log = get_logger(__name__) 10 | 11 | 12 | class MLP(nn.Module): 13 | def __init__(self, 14 | input_dim: int, 15 | hidden_dims: Sequence[int], 16 | output_dim: int, 17 | net_normalization: Optional[str] = None, 18 | activation_function: str = 'gelu', 19 | dropout: float = 0.0, 20 | residual: bool = False, 21 | output_normalization: bool = False, 22 | output_activation_function: Optional[Union[str, bool]] = None, 23 | out_layer_bias_init: Tensor = None, 24 | name: str = "" 25 | ): 26 | """ 27 | Args: 28 | input_dim (int): the expected 1D input tensor dim 29 | output_activation_function (str, bool, optional): By default no output activation function is used (None). 30 | If a string is passed, is must be the name of the desired output activation (e.g. 'softmax') 31 | If True, the same activation function is used as defined by the arg `activation_function`. 
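out_layer_bias_init (Tensor, optional): If given, used to initialize the bias of the final output layer.
Example (illustrative): MLP(input_dim=64, hidden_dims=[128, 128], output_dim=50) maps a (batch, 64) tensor to a (batch, 50) output.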
32 | """ 33 | 34 | super().__init__() 35 | self.name = name 36 | hidden_layers = [] 37 | dims = [input_dim] + list(hidden_dims) 38 | for i in range(1, len(dims)): 39 | hidden_layers += [MLP_Block( 40 | in_dim=dims[i - 1], 41 | out_dim=dims[i], 42 | net_norm=net_normalization.lower() if isinstance(net_normalization, str) else 'none', 43 | activation_function=activation_function, 44 | dropout=dropout, 45 | residual=residual 46 | )] 47 | self.hidden_layers = nn.ModuleList(hidden_layers) 48 | 49 | out_weight = nn.Linear(dims[-1], output_dim, bias=True) 50 | if out_layer_bias_init is not None: 51 | log.info(' Pre-initializing the MLP final/output layer bias.') 52 | out_weight.bias.data = out_layer_bias_init 53 | out_layer = [out_weight] 54 | if output_normalization and net_normalization != 'none': 55 | out_layer += [get_normalization_layer(net_normalization, output_dim)] 56 | if output_activation_function is not None and output_activation_function: 57 | if isinstance(output_activation_function, bool): 58 | output_activation_function = activation_function 59 | 60 | out_layer += [get_activation_function(output_activation_function, functional=False)] 61 | self.out_layer = nn.Sequential(*out_layer) 62 | 63 | def forward(self, X: Tensor) -> Tensor: 64 | for layer in self.hidden_layers: 65 | X = layer(X) 66 | 67 | Y = self.out_layer(X) 68 | return Y.squeeze(1) 69 | 70 | 71 | class MLP_Block(nn.Module): 72 | def __init__(self, 73 | in_dim: int, 74 | out_dim: int, 75 | net_norm: str = 'none', 76 | activation_function: str = 'Gelu', 77 | dropout: float = 0.0, 78 | residual: bool = False 79 | ): 80 | super().__init__() 81 | layer = [nn.Linear(in_dim, out_dim, bias=net_norm != 'batch_norm')] 82 | if net_norm != 'none': 83 | layer += [get_normalization_layer(net_norm, out_dim)] 84 | layer += [get_activation_function(activation_function, functional=False)] 85 | if dropout > 0: 86 | layer += [nn.Dropout(dropout)] 87 | self.layer = nn.Sequential(*layer) 88 | self.residual = residual 89 | if in_dim != out_dim: 90 | self.residual = False 91 | elif residual: 92 | print('MLP block with residual!') 93 | 94 | def forward(self, X: Tensor) -> Tensor: 95 | X_out = self.layer(X) 96 | if self.residual: 97 | X_out += X 98 | return X_out 99 | -------------------------------------------------------------------------------- /climart/train.py: -------------------------------------------------------------------------------- 1 | import wandb 2 | from hydra.utils import instantiate as hydra_instantiate 3 | from omegaconf import DictConfig 4 | 5 | import pytorch_lightning as pl 6 | from pytorch_lightning import seed_everything 7 | 8 | import climart.utils.config_utils as cfg_utils 9 | from climart.interface import get_model_and_data 10 | from climart.utils.utils import get_logger 11 | 12 | 13 | def run_model(config: DictConfig): 14 | seed_everything(config.seed, workers=True) 15 | log = get_logger(__name__) 16 | cfg_utils.extras(config) 17 | 18 | if config.get("print_config"): 19 | cfg_utils.print_config(config, fields='all') 20 | 21 | emulator_model, data_module = get_model_and_data(config) 22 | 23 | # Init Lightning callbacks and loggers 24 | callbacks = cfg_utils.get_all_instantiable_hydra_modules(config, 'callbacks') 25 | loggers = cfg_utils.get_all_instantiable_hydra_modules(config, 'logger') 26 | 27 | # Init Lightning trainer 28 | trainer: pl.Trainer = hydra_instantiate( 29 | config.trainer, callbacks=callbacks, logger=loggers, # , deterministic=True 30 | ) 31 | 32 | # Send some parameters from config to all lightning 
loggers 33 | log.info("Logging hyperparameters to the PyTorch Lightning loggers.") 34 | cfg_utils.log_hyperparameters(config=config, model=emulator_model, data_module=data_module, trainer=trainer, 35 | callbacks=callbacks) 36 | 37 | trainer.fit(model=emulator_model, datamodule=data_module) 38 | 39 | cfg_utils.save_hydra_config_to_wandb(config) 40 | 41 | # Testing: 42 | if config.get("test_after_training"): 43 | trainer.test(datamodule=data_module, ckpt_path='best') 44 | 45 | if config.get('logger') and config.logger.get("wandb"): 46 | wandb.finish() 47 | 48 | # log.info("Reloading model from checkpoint based on best validation stat.") 49 | # final_model = emulator_model.load_from_checkpoint(trainer.checkpoint_callback.best_model_path, 50 | # datamodule_config=config.datamodule, output_normalizer=data_module.normalizer.output_normalizer) 51 | # return final_model 52 | -------------------------------------------------------------------------------- /climart/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RolnickLab/climart/6d12fa478de5db5209842e0369a2a24377079edd/climart/utils/__init__.py -------------------------------------------------------------------------------- /climart/utils/callbacks.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Dict, List, Union, Sequence 2 | import numpy as np 3 | import torch 4 | 5 | 6 | class TestingScheduleCallback: 7 | def __init__(self, 8 | start_epoch: int = 1, 9 | test_at_most_every_n_epochs: int = 10, 10 | test_at_least_every_n_epochs: int = 20, 11 | test_on_new_best_validation: bool = True, 12 | ignore_first_n_epochs: int = 5, 13 | ): 14 | self.test_at_most_every_n_epochs = test_at_most_every_n_epochs 15 | self.test_at_least_every_n_epochs = test_at_least_every_n_epochs 16 | self.test_on_new_best_validation = test_on_new_best_validation 17 | self.untested_epochs = 1 18 | self.cur_epoch = start_epoch 19 | self.ignore_first_n_epochs = ignore_first_n_epochs 20 | 21 | def __call__(self, is_new_best_val_model: bool = False) -> bool: 22 | do_test = False 23 | if self.cur_epoch <= self.ignore_first_n_epochs: 24 | do_test = False 25 | elif self.untested_epochs >= self.test_at_least_every_n_epochs: 26 | do_test = True 27 | elif self.test_on_new_best_validation and is_new_best_val_model: 28 | if self.untested_epochs < self.test_at_most_every_n_epochs: 29 | do_test = False 30 | else: 31 | do_test = True 32 | 33 | self.cur_epoch += 1 34 | if do_test: 35 | self.untested_epochs = 1 36 | else: 37 | self.untested_epochs += 1 38 | return do_test 39 | 40 | 41 | class PredictionPostProcessCallback: 42 | def __init__(self, 43 | variables: List[str], 44 | sizes: Union[int, Sequence[int]] 45 | ): 46 | self.variable_to_channel = dict() 47 | cur = 0 48 | sizes = [sizes for _ in range(len(variables))] if isinstance(sizes, int) else sizes 49 | for var, size in zip(variables, sizes): 50 | self.variable_to_channel[var] = {'start': cur, 'end': cur + size} 51 | cur += size 52 | 53 | def split_vector_by_variable(self, 54 | vector: Union[np.ndarray, torch.Tensor] 55 | ) -> Dict[str, Union[np.ndarray, torch.Tensor]]: 56 | if isinstance(vector, dict): 57 | return vector 58 | splitted_vector = dict() 59 | for var_name, var_channel_limits in self.variable_to_channel.items(): 60 | splitted_vector[var_name] = vector[..., var_channel_limits['start']:var_channel_limits['end']] 61 | return splitted_vector 62 | 63 | def __call__(self, 
vector, *args, **kwargs): 64 | return self.split_vector_by_variable(vector) 65 | -------------------------------------------------------------------------------- /climart/utils/config_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import warnings 4 | from typing import Union, Sequence, List 5 | 6 | import omegaconf 7 | import pytorch_lightning as pl 8 | import wandb 9 | from omegaconf import DictConfig, OmegaConf, open_dict 10 | 11 | from climart.utils.naming import get_group_name, get_detailed_name 12 | from climart.utils.utils import no_op, get_logger 13 | 14 | log = get_logger(__name__) 15 | 16 | def print_config( 17 | config, 18 | fields: Union[str, Sequence[str]] = ( 19 | "datamodule", 20 | "model", 21 | "trainer", 22 | # "callbacks", 23 | # "logger", 24 | "seed", 25 | ), 26 | resolve: bool = True, 27 | ) -> None: 28 | """Prints content of DictConfig using Rich library and its tree structure. 29 | 30 | Credits go to: https://github.com/ashleve/lightning-hydra-template 31 | 32 | Args: 33 | config (ConfigDict): Configuration 34 | fields (Sequence[str], optional): Determines which main fields from config will 35 | be printed and in what order. 36 | resolve (bool, optional): Whether to resolve reference fields of DictConfig. 37 | """ 38 | import importlib 39 | if not importlib.util.find_spec("rich") or not importlib.util.find_spec("omegaconf"): 40 | # no pretty printing 41 | return 42 | import rich.syntax 43 | import rich.tree 44 | 45 | style = "dim" 46 | tree = rich.tree.Tree(":gear: CONFIG", style=style, guide_style=style) 47 | if isinstance(fields, str): 48 | if fields.lower() == 'all': 49 | fields = config.keys() 50 | else: 51 | fields = [fields] 52 | 53 | for field in fields: 54 | branch = tree.add(field, style=style, guide_style=style) 55 | 56 | config_section = config.get(field) 57 | branch_content = str(config_section) 58 | if isinstance(config_section, DictConfig): 59 | branch_content = OmegaConf.to_yaml(config_section, resolve=resolve) 60 | 61 | branch.add(rich.syntax.Syntax(branch_content, "yaml")) 62 | 63 | rich.print(tree) 64 | 65 | 66 | def extras(config: DictConfig) -> None: 67 | """A couple of optional utilities, controlled by main config file: 68 | - disabling warnings 69 | - easier access to debug mode 70 | - forcing debug friendly configuration 71 | - forcing multi-gpu friendly configuration 72 | 73 | Credits go to: https://github.com/ashleve/lightning-hydra-template 74 | 75 | Modifies DictConfig in place. 76 | """ 77 | log = get_logger() 78 | 79 | # Create working dir if it does not exist yet 80 | if config.get('work_dir'): 81 | os.makedirs(name=config.get("work_dir"), exist_ok=True) 82 | 83 | # disable python warnings if 84 | if config.get("ignore_warnings"): 85 | log.info("Disabling python warnings! ") 86 | warnings.filterwarnings("ignore") 87 | 88 | # set if 89 | if config.get("debug"): 90 | log.info("Running in debug mode! ") 91 | config.trainer.fast_dev_run = True 92 | 93 | # force debugger friendly configuration if 94 | if config.trainer.get("fast_dev_run"): 95 | log.info("Forcing debugger friendly configuration! 
") 96 | # Debuggers don't like GPUs or multiprocessing 97 | if config.trainer.get("gpus"): 98 | config.trainer.gpus = 0 99 | if config.datamodule.get("pin_memory"): 100 | config.datamodule.pin_memory = False 101 | if config.datamodule.get("num_workers"): 102 | config.datamodule.num_workers = 0 103 | 104 | # force multi-gpu friendly configuration if 105 | accelerator = config.trainer.get("accelerator") 106 | if accelerator in ["ddp", "ddp_spawn", "dp", "ddp2"]: 107 | log.info(f"Forcing ddp friendly configuration! ") 108 | if config.datamodule.get("num_workers"): 109 | config.datamodule.num_workers = 0 110 | if config.datamodule.get("pin_memory"): 111 | config.datamodule.pin_memory = False 112 | 113 | USE_WANDB = "logger" in config.keys() and config.logger.get("wandb") 114 | if USE_WANDB: 115 | if not config.logger.wandb.get('id'): # no wandb id has been assigned yet 116 | wandb_id = wandb.util.generate_id() 117 | config.logger.wandb.id = wandb_id 118 | else: 119 | log.info(f" This experiment config already has a wandb run ID = {config.logger.wandb.id}") 120 | if not config.logger.wandb.get('group'): # no wandb group has been assigned yet 121 | group_name = get_group_name(config) 122 | config.logger.wandb.group = group_name if len(group_name) < 128 else group_name[:128] 123 | config.logger.wandb.name = get_detailed_name(config) + '_' + time.strftime('%Hh%Mm_on_%b_%d') + '_' + config.logger.wandb.id 124 | 125 | check_config_values(config) 126 | if USE_WANDB: 127 | wandb_kwargs = { 128 | k: config.logger.wandb.get(k) for k in ['id', 'project', 'entity', 'name', 'group', 129 | 'tags', 'notes', 'reinit', 'mode', 'resume'] 130 | } 131 | wandb_kwargs['dir'] = config.logger.wandb.get('save_dir') 132 | wandb.init(**wandb_kwargs) 133 | log.info(f"Wandb kwargs: {wandb_kwargs}") 134 | save_hydra_config_to_wandb(config) 135 | 136 | 137 | def check_config_values(config: DictConfig): 138 | exp_type = config.datamodule.exp_type.lower() 139 | config.datamodule.exp_type = exp_type 140 | if exp_type not in ["clear_sky", "pristine"]: 141 | raise ValueError(f"Arg `exp_type` should be one of clear_sky or pristine, but got {exp_type}") 142 | 143 | if "net_normalization" in config.model.keys(): 144 | if config.model.net_normalization is None: 145 | config.model.net_normalization = "none" 146 | config.model.net_normalization = config.model.net_normalization.lower() 147 | 148 | if config.logger.get("wandb"): 149 | if 'callbacks' in config and config.callbacks.get('model_checkpoint'): 150 | id_mdl = config.logger.wandb.get('id') 151 | d = config.callbacks.model_checkpoint.dirpath 152 | if id_mdl is not None: 153 | with open_dict(config): 154 | new_dir = os.path.join(d, id_mdl) 155 | config.callbacks.model_checkpoint.dirpath = new_dir 156 | os.makedirs(new_dir, exist_ok=True) 157 | log.info(f" Model checkpoints will be saved in: {new_dir}") 158 | else: 159 | if config.save_config_to_wandb: 160 | log.warning(" `save_config_to_wandb`=True but no wandb logger was found.. 
config will not be saved!") 161 | config.save_config_to_wandb = False 162 | 163 | 164 | def get_all_instantiable_hydra_modules(config, module_name: str): 165 | from hydra.utils import instantiate as hydra_instantiate 166 | modules = [] 167 | if module_name in config: 168 | for _, module_config in config[module_name].items(): 169 | if module_config is not None and "_target_" in module_config: 170 | modules.append(hydra_instantiate(module_config)) 171 | return modules 172 | 173 | 174 | def log_hyperparameters( 175 | config, 176 | model: pl.LightningModule, 177 | data_module: pl.LightningDataModule, 178 | trainer: pl.Trainer, 179 | callbacks: List[pl.Callback], 180 | ) -> None: 181 | """This method controls which parameters from Hydra config are saved by Lightning loggers. 182 | Credits go to: https://github.com/ashleve/lightning-hydra-template 183 | 184 | Additionally saves: 185 | - number of {total, trainable, non-trainable} model parameters 186 | """ 187 | 188 | def copy_and_ignore_keys(dictionary, *keys_to_ignore): 189 | new_dict = dict() 190 | for k in dictionary.keys(): 191 | if k not in keys_to_ignore: 192 | new_dict[k] = dictionary[k] 193 | return new_dict 194 | 195 | params = dict() 196 | if 'seed' in config: 197 | params['seed'] = config['seed'] 198 | if 'model' in config: 199 | params['model'] = config['model'] 200 | 201 | # Remove redundant keys or those that are not important to know after training -- feel free to edit this! 202 | params["datamodule"] = copy_and_ignore_keys(config["datamodule"], 'pin_memory', 'num_workers') 203 | params['model'] = copy_and_ignore_keys(config['model'], 'optimizer', 'scheduler') 204 | params['normalizer'] = config['normalizer'] 205 | params["trainer"] = copy_and_ignore_keys(config["trainer"]) 206 | # encoder, optims, and scheduler as separate top-level key 207 | params['optim'] = config['model']['optimizer'] 208 | params['scheduler'] = config['model']['scheduler'] if 'scheduler' in config['model'] else None 209 | 210 | if "callbacks" in config: 211 | if 'model_checkpoint' in config['callbacks']: 212 | params["model_checkpoint"] = copy_and_ignore_keys( 213 | config["callbacks"]['model_checkpoint'], 'save_top_k' 214 | ) 215 | 216 | # save number of model parameters 217 | params["model/params_total"] = sum(p.numel() for p in model.parameters()) 218 | params["model/params_trainable"] = sum( 219 | p.numel() for p in model.parameters() if p.requires_grad 220 | ) 221 | params["model/params_not_trainable"] = sum( 222 | p.numel() for p in model.parameters() if not p.requires_grad 223 | ) 224 | params['dirs/work_dir'] = config.get('work_dir') 225 | params['dirs/ckpt_dir'] = config.get('ckpt_dir') 226 | params['dirs/wandb_save_dir'] = config.logger.wandb.save_dir if ( 227 | config.get('logger') and config.logger.get('wandb')) else None 228 | 229 | # send hparams to all loggers 230 | trainer.logger.log_hyperparams(params) 231 | 232 | # disable logging any more hyperparameters for all loggers 233 | # this is just a trick to prevent trainer from logging hparams of model, 234 | # since we already did that above 235 | trainer.logger.log_hyperparams = no_op 236 | 237 | 238 | def save_hydra_config_to_wandb(config: DictConfig): 239 | if config.get('save_config_to_wandb'): 240 | log.info(f"Hydra config will be saved to WandB as hydra_config.yaml and in wandb run_dir: {wandb.run.dir}") 241 | # files in wandb.run.dir folder get directly uploaded to wandb 242 | with open(os.path.join(wandb.run.dir, "hydra_config.yaml"), "w") as fp: 243 | OmegaConf.save(config, f=fp.name, 
resolve=True) 244 | wandb.save(os.path.join(wandb.run.dir, "hydra_config.yaml")) 245 | 246 | 247 | def get_config_from_hydra_compose_overrides(overrides: list) -> DictConfig: 248 | import hydra 249 | from hydra.core.global_hydra import GlobalHydra 250 | overrides = list(set(overrides)) 251 | if '-m' in overrides: 252 | overrides.remove('-m') # if multiruns flags are mistakenly in overrides 253 | hydra.initialize(config_path="../../configs") 254 | try: 255 | config = hydra.compose(config_name="main_config", overrides=overrides) 256 | finally: 257 | GlobalHydra.instance().clear() # always clean up global hydra 258 | return config 259 | 260 | 261 | def get_model_from_hydra_compose_overrides(overrides: list): 262 | from climart.interface import get_model 263 | cfg = get_config_from_hydra_compose_overrides(overrides) 264 | return get_model(cfg) 265 | -------------------------------------------------------------------------------- /climart/utils/evaluation.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | import numpy as np 4 | # from sklearn.metrics import mean_squared_error, mean_absolute_error 5 | from climart.data_loading.constants import get_statistics 6 | from climart.utils import utils 7 | 8 | log = utils.get_logger(__name__) 9 | 10 | 11 | def evaluate_preds(Ytrue: np.ndarray, preds: np.ndarray): 12 | MSE = np.mean((preds - Ytrue) ** 2) # mean_squared_error(preds, Ytrue) 13 | RMSE = np.sqrt(MSE) 14 | # MAE = mean_absolute_error(preds, Ytrue) 15 | MBE = np.mean(preds - Ytrue) 16 | stats = {'mbe': MBE, 17 | # 'mse': MSE, 18 | # 'mae': MAE, 19 | "rmse": RMSE} 20 | 21 | return stats 22 | 23 | 24 | def evaluate_per_target_variable(Ytrue: dict, 25 | preds: dict, 26 | data_split: str = None) -> Dict[str, float]: 27 | stats = dict() 28 | if not isinstance(Ytrue, dict): 29 | log.warning(f" Expected a dictionary var_name->Tensor/nd_array, but got {type(Ytrue)} for Ytrue!") 30 | return stats 31 | 32 | for var_name in Ytrue.keys(): 33 | # var_name stands for 'rsuc', 'hrsc', etc., i.e. shortwave upwelling flux, shortwave heating rate, etc. 
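        # The returned dict is flat and keyed by '<split>/<var>/<metric>'. For example, with
        # data_split='val' and var_name='rsuc' it contains 'val/rsuc/rmse' and 'val/rsuc/mbe',
        # per-level entries such as 'levelwise/val/rsuc_level0/rmse', plus the convenience keys
        # 'val/rsuc_toa/...' (level 0) and 'val/rsuc_surface/...' (last level).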
34 | var_stats = evaluate_preds(Ytrue[var_name], preds[var_name]) 35 | # pre-append the variable's name to its specific performance on the returned metrics dict 36 | for metric_name, metric_stat in var_stats.items(): 37 | stats[f"{data_split}/{var_name}/{metric_name}"] = metric_stat 38 | 39 | num_height_levels = Ytrue[var_name].shape[1] 40 | for lvl in range(0, num_height_levels): 41 | stats_lvl = evaluate_preds(Ytrue[var_name][:, lvl], preds[var_name][:, lvl]) 42 | for metric_name, metric_stat in stats_lvl.items(): 43 | stats[f"levelwise/{data_split}/{var_name}_level{lvl}/{metric_name}"] = metric_stat 44 | 45 | if lvl == num_height_levels - 1: 46 | for metric_name, metric_stat in stats_lvl.items(): 47 | stats[f"{data_split}/{var_name}_surface/{metric_name}"] = metric_stat 48 | if lvl == 0: 49 | for metric_name, metric_stat in stats_lvl.items(): 50 | stats[f"{data_split}/{var_name}_toa/{metric_name}"] = metric_stat 51 | 52 | return stats 53 | -------------------------------------------------------------------------------- /climart/utils/naming.py: -------------------------------------------------------------------------------- 1 | from typing import Union, List 2 | from omegaconf import DictConfig 3 | 4 | 5 | def _shared_prefix(config: DictConfig, init_prefix: str = "") -> str: 6 | s = init_prefix if isinstance(init_prefix, str) else "" 7 | if 'clear' in config.datamodule.get('exp_type'): 8 | s += '_CS' 9 | s += f"_{config.datamodule.get('train_years')}train" + f"_{config.datamodule.get('validation_years')}val" 10 | if config.normalizer.get('input_normalization'): 11 | s += f"_{config.normalizer.get('input_normalization').upper()}" 12 | if config.normalizer.get('output_normalization'): 13 | s += f"_{config.normalizer.get('output_normalization').upper()}" 14 | return s.lstrip('_') 15 | 16 | 17 | def get_name_for_hydra_config_class(config: DictConfig) -> str: 18 | if 'name' in config and config.get('name') is not None: 19 | return config.get('name') 20 | elif '_target_' in config: 21 | return config._target_.split('.')[-1] 22 | return "$" 23 | 24 | 25 | def get_detailed_name(config) -> str: 26 | """ This is a prefix for naming the runs for a more agreeable logging.""" 27 | s = config.get("name", '') 28 | s = _shared_prefix(config, init_prefix=s) + '_' 29 | if config.model.get('dropout') > 0: 30 | s += f"{config.model.get('dropout')}dout_" 31 | 32 | s += config.model.get('activation_function') + '_' 33 | s += get_name_for_hydra_config_class(config.model.optimizer) + '_' 34 | s += get_name_for_hydra_config_class(config.model.scheduler) + '_' 35 | 36 | s += f"{config.datamodule.get('batch_size')}bs_" 37 | s += f"{config.model.optimizer.get('lr')}lr_" 38 | if config.model.optimizer.get('weight_decay') > 0: 39 | s += f"{config.model.optimizer.get('weight_decay')}wd_" 40 | 41 | hdims = config.model.get('hidden_dims') 42 | if all([h == hdims[0] for h in hdims]): 43 | hdims = f"{hdims[0]}x{len(hdims)}" 44 | else: 45 | hdims = str(hdims) 46 | s += f"{hdims}h" # &{net_params['out_dim']}oDim" 47 | # if not params['shuffle']: 48 | # s += 'noShuffle_' 49 | s += f"{config.get('seed')}seed" 50 | 51 | return s.replace('None', '') 52 | 53 | 54 | def get_model_name(name: str) -> str: 55 | if 'CNN' in name: 56 | return 'CNN' 57 | elif 'MLP' in name: 58 | return 'MLP' 59 | elif 'GCN' in name: 60 | return 'GCN' 61 | elif 'GN' in name: 62 | return 'GraphNet' 63 | else: 64 | raise ValueError(name) 65 | 66 | 67 | def get_group_name(config) -> str: 68 | s = get_name_for_hydra_config_class(config.model) 69 | s = 
s.lower().replace('net', '').replace('_', '').replace("climart", "").replace("with", "+").upper() 70 | s = _shared_prefix(config, init_prefix=s) 71 | 72 | if config.normalizer.get('spatial_normalization_in') and config.normalizer.get('spatial_normalization_out'): 73 | s += '+spatialNormed' 74 | elif config.normalizer.get('spatial_normalization_in'): 75 | s += '+spatialInNormed' 76 | elif config.normalizer.get('spatial_normalization_out'): 77 | s += '+spatialOutNormed' 78 | 79 | return s 80 | 81 | 82 | def stem_word(word: str) -> str: 83 | return word.lower().strip().replace('-', '').replace('&', '').replace('+', '').replace('_', '') 84 | 85 | 86 | def get_exp_ID(exp_type: str, target_types: Union[str, List[str]], target_variables: Union[str, List[str]]): 87 | s = f"{exp_type.upper()} conditions, with {' '.join(target_types)} x {' '.join(target_variables)} targets" 88 | return s 89 | -------------------------------------------------------------------------------- /climart/utils/optimization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | def get_loss(name, reduction='mean'): 5 | name = name.lower().strip().replace('-', '_') 6 | if name in ['l1', 'mae', "mean_absolute_error"]: 7 | loss = nn.L1Loss(reduction=reduction) 8 | elif name in ['l2', 'mse', "mean_squared_error"]: 9 | loss = nn.MSELoss(reduction=reduction) 10 | elif name in ['smoothl1', 'smooth']: 11 | loss = nn.SmoothL1Loss(reduction=reduction) 12 | else: 13 | raise ValueError(f'Unknown loss function {name}') 14 | return loss 15 | 16 | 17 | def get_trainable_params(model): 18 | trainable_params = [] 19 | for name, param in model.named_parameters(): 20 | if param.requires_grad: 21 | trainable_params.append(param) 22 | return trainable_params 23 | -------------------------------------------------------------------------------- /climart/utils/plotting.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | from matplotlib.colors import TwoSlopeNorm 3 | import numpy as np 4 | import einops 5 | import xarray as xr 6 | 7 | from climart.data_loading.constants import get_coordinates 8 | 9 | 10 | def set_labels_and_ticks(ax, 11 | title: str = "", 12 | xlabel: str = "", ylabel: str = "", 13 | xlabel_fontsize: int = 10, ylabel_fontsize: int = 14, 14 | xlim=None, ylim=None, 15 | xticks=None, yticks=None, 16 | title_fontsize: int = None, 17 | xticks_fontsize: int = None, yticks_fontsize: int = None, 18 | xtick_labels=None, ytick_labels=None, 19 | logscale_y: bool = False, 20 | show: bool = True, 21 | grid: bool = True, 22 | legend: bool = True, legend_loc='best', legend_prop=10, 23 | full_screen: bool = False, 24 | tight_layout: bool = True, 25 | save_to: str = None 26 | ): 27 | ax.set_title(title, fontsize=title_fontsize) 28 | ax.set_xlabel(xlabel, fontsize=xlabel_fontsize) 29 | ax.set_ylabel(ylabel, fontsize=ylabel_fontsize) 30 | 31 | ax.set_xlim(xlim) 32 | ax.set_ylim(ylim) 33 | if xticks is not None: 34 | ax.set_xticks(xticks) 35 | if xtick_labels is not None: 36 | ax.set_xticklabels(xtick_labels) 37 | if xticks_fontsize: 38 | for tick in ax.xaxis.get_major_ticks(): 39 | tick.label.set_fontsize(xticks_fontsize) 40 | # tick.label.set_rotation('vertical') 41 | 42 | if logscale_y: 43 | ax.set_yscale('log') 44 | if yticks is not None: 45 | ax.set_yticks(yticks) 46 | if ytick_labels is not None: 47 | ax.set_yticklabels(ytick_labels) 48 | if yticks_fontsize: 49 | for tick in 
ax.yaxis.get_major_ticks(): 50 | tick.label.set_fontsize(yticks_fontsize) 51 | 52 | if grid: 53 | ax.grid() 54 | if legend: 55 | ax.legend(loc=legend_loc, prop={'size': legend_prop}) #if full_screen else ax.legend(loc=legend_loc) 56 | 57 | if tight_layout: 58 | plt.tight_layout() 59 | 60 | if save_to is not None: 61 | if full_screen: 62 | mng = plt.get_current_fig_manager() 63 | mng.full_screen_toggle() 64 | 65 | plt.savefig(save_to, bbox_inches='tight') 66 | if full_screen: 67 | mng.full_screen_toggle() 68 | 69 | if show: 70 | plt.show() 71 | 72 | 73 | class RollingCmaps: 74 | def __init__(self, 75 | unique_keys: list, 76 | pos_cmaps: list = None, 77 | max_key_occurence: int = 5): 78 | if pos_cmaps is None: 79 | pos_cmaps = ['Greens', 'Oranges', 'Blues', 'Greys', 'Purples'] 80 | pos_cmaps = [plt.get_cmap(cmap) for cmap in pos_cmaps] 81 | self.cmaps = {key: pos_cmaps[i] for i, key in enumerate(unique_keys)} 82 | self.pos_per_cmap = {key: 0.75 for key in unique_keys} # lower makes lines too white 83 | self.max_key_occurence = max_key_occurence 84 | 85 | def __getitem__(self, key): 86 | color = self.cmaps[key](self.pos_per_cmap[key] / self.max_key_occurence) # [self.pos_per_cmap[key]] 87 | self.pos_per_cmap[key] += 1 88 | return color 89 | 90 | 91 | class RollingLineFormats: 92 | def __init__(self, 93 | unique_keys: list, 94 | pos_markers: list = None, 95 | cmap = None, 96 | linewidth: float = 4 97 | ): 98 | print(unique_keys) 99 | if pos_markers is None: 100 | pos_markers = ['-', '--', ':', '-', '-.'] 101 | if cmap is None: 102 | cmap = plt.get_cmap('viridis') 103 | cs = ['#1f77b4', '#ff7f0e', '#2ca02c', '#9467bd', '#8c564b', 104 | '#e377c2', '#7f7f7f', '#d62728', '#bcbd22', '#17becf'] 105 | # cs = plt.rcParams['axes.prop_cycle'].by_key()['color'] 106 | 107 | self.pos_markers = pos_markers 108 | # self.cmaps = {key: cmap(i/len(unique_keys)) for i, key in enumerate(unique_keys)} 109 | self.cmaps = {key: cs[i] for i, key in enumerate(unique_keys)} 110 | self.pos_per_key = {key: 0 for key in unique_keys} # lower makes lines too white 111 | self.lws = {key: linewidth for key in unique_keys} 112 | 113 | def __getitem__(self, key): 114 | cur_i = self.pos_per_key[key] 115 | lw = self.lws[key] 116 | line_format = self.pos_markers[cur_i] # [self.pos_per_cmap[key]] 117 | self.pos_per_key[key] += 1 118 | self.lws[key] = max(1, lw - 1) 119 | return line_format, dict(c=self.cmaps[key], linewidth=lw) 120 | 121 | 122 | def plot_groups(xaxis_key, metric='Test/MAE', ax=None, show: bool = True, **kwargs): 123 | if not ax: 124 | fig, ax = plt.subplots() # 1 125 | 126 | for key, group in kwargs.items(): 127 | group.plot(xaxis_key, metric, yerr='std', label=key, ax=ax) 128 | 129 | set_labels_and_ticks( 130 | ax, xlabel='Used training points', ylabel=metric, show=show 131 | ) 132 | 133 | 134 | def height_errors(Ytrue: np.ndarray, preds: np.ndarray, height_ticks=None, 135 | xlabel='', ylabel='height', fill_between=True, show=True): 136 | """ 137 | Plot MAE and MBE as a function of the height/pressure 138 | :param Ytrue: 139 | :param preds: 140 | :param height_ticks: must have same shape as Ytrue.shape[1] 141 | :param show: 142 | :return: 143 | """ 144 | n_samples, n_levels = Ytrue.shape 145 | diff = Ytrue - preds 146 | abs_diff = np.abs(diff) 147 | levelwise_MBE = np.mean(diff, axis=0) 148 | levelwise_MAE = np.mean(abs_diff, axis=0) 149 | 150 | levelwise_MBE_std = np.std(diff, axis=0) 151 | levelwise_MAE_std = np.std(abs_diff, axis=0) 152 | 153 | # Plotting 154 | plotting_kwargs = {'yticks': height_ticks, 
'ylabel': ylabel, 'show': show, "fill_between": fill_between} 155 | yaxis = np.arange(n_levels) 156 | figMBE = height_plot(yaxis, levelwise_MBE, levelwise_MBE_std, xlabel=xlabel + ' MBE', **plotting_kwargs) 157 | figMAE = height_plot(yaxis, levelwise_MAE, levelwise_MAE_std, xlabel=xlabel + ' MAE', **plotting_kwargs) 158 | 159 | if show: 160 | plt.show() 161 | return figMAE, figMBE 162 | 163 | 164 | def height_plot(yaxis, line, std, yticks=None, ylabel=None, xlabel=None, show=False, fill_between=True): 165 | fig, ax = plt.subplots(1) 166 | if "mbe" in xlabel.lower(): 167 | # to better see the bias 168 | ax.plot(np.zeros(yaxis.shape), yaxis, '--', color='grey') 169 | 170 | p = ax.plot(line, yaxis, '-', linewidth=3) 171 | if fill_between: 172 | ax.fill_betweenx(yaxis, line - std, line + std, alpha=0.2) 173 | else: 174 | ax.plot(line - std, yaxis, '--', color=p[0].get_color(), linewidth=1.5) 175 | ax.plot(line + std, yaxis, '--', color=p[0].get_color(), linewidth=1.5) 176 | 177 | xlim = [0, ax.get_xlim()[1]] if 'mae' in xlabel.lower() or 'rmse' in xlabel.lower() else None 178 | set_labels_and_ticks(ax=ax, xlabel=xlabel, xlim=xlim, 179 | yticks=yaxis, ytick_labels=yticks, 180 | ylabel=ylabel, show=show) 181 | 182 | return fig 183 | 184 | 185 | def level_errors(Y_true, Y_preds, epoch): 186 | errors = np.mean((Y_true - Y_preds), axis=0) 187 | colours = ['red' if x < 0 else 'green' for x in errors] 188 | index = np.arange(0, len(colours), 1) 189 | 190 | # Draw plot 191 | lev_fig = plt.figure(figsize=(14, 14), dpi=80) 192 | plt.hlines(y=index, xmin=0, xmax=errors) 193 | for x, y, tex in zip(errors, index, errors): 194 | t = plt.text(x, y, round(tex, 2), horizontalalignment='right' if x < 0 else 'left', 195 | verticalalignment='center', fontdict={'color': 'red' if x < 0 else 'green', 'size': 10}) 196 | 197 | # Styling 198 | plt.yticks(index, ['Level: ' + str(z) for z in index], fontsize=12) 199 | plt.title(f'Average Level-wise error for epoch: {epoch}', fontdict={'size': 20}) 200 | plt.grid(linestyle='--', alpha=0.5) 201 | plt.xlim(-5, 5) 202 | 203 | return lev_fig 204 | 205 | 206 | def profile_errors(Y_true, Y_preds, plot_profiles=200, var_name=None, data_dir: str = None, 207 | error_type='mean', plot_type='scatter', set_seed=False, title=""): 208 | coords_data = get_coordinates(data_dir) 209 | lat = list(coords_data.get_index('lat')) 210 | lon = list(coords_data.get_index('lon')) 211 | 212 | total_profiles, n_levels = Y_true.shape 213 | 214 | if set_seed: # To get the same profiles everytime 215 | np.random.seed(7) 216 | 217 | errors = np.abs(Y_true - Y_preds) 218 | # print(errors.shape, Y_true.shape, total_profiles / 8192) 219 | 220 | if plot_type.lower() == 'scatter': 221 | latitude = [] 222 | longitude = [] 223 | 224 | for i in lat: 225 | for j in lon: 226 | latitude.append(i) 227 | longitude.append(j) 228 | 229 | lat_var = np.array(latitude) 230 | lon_var = np.array(longitude) 231 | 232 | n_times = int(total_profiles / 8192) 233 | indices = np.arange(0, total_profiles) 234 | indices_train = np.random.choice(total_profiles, total_profiles - plot_profiles, replace=False) 235 | indices_rest = np.setxor1d(indices_train, indices, assume_unique=True) 236 | 237 | lat_var = np.mean(np.vstack([np.expand_dims(lat_var, 1)] * n_times), axis=1) 238 | lon_var = np.mean(np.vstack([np.expand_dims(lon_var, 1)] * n_times), axis=1) 239 | lon_plot = lon_var[indices_rest] 240 | lat_plot = lat_var[indices_rest] 241 | errors_lev = errors[indices_rest] 242 | errors_lev = einops.rearrange(np.array(errors_lev), 'p l -> 
l p') # p = profile dim 243 | print(errors.shape, Y_true.shape) # (81920, 50) (81920, 50) 244 | else: 245 | errors_lev = errors.reshape(n_levels, 8192, -1) # level x spatial_dim x snapshot_dim 246 | errors_lev = np.mean(errors_lev, axis=2) # mean over all snapshots 247 | errors_lev = errors_lev.reshape((n_levels, len(lat), len(lon))) # reshape back to spatial grid 248 | lon_plot, lat_plot = np.meshgrid(lon, lat) 249 | 250 | if error_type.lower() == 'toa': 251 | err = errors_lev[0] 252 | elif error_type.lower() == 'surface': 253 | err = errors_lev[-1] 254 | elif error_type.lower() == 'mean': 255 | err = np.mean(errors_lev, axis=0) 256 | 257 | pp = profile_plot(lon_plot, lat_plot, err, var_name, plot_type=plot_type, title=title) 258 | 259 | return pp 260 | 261 | 262 | def profile_plot(lon_plot, lat_plot, errors, var_name=None, plot_type='scatter', dpi=70, title=""): 263 | """ 264 | 265 | :param lon_plot: 266 | :param lat_plot: 267 | :param errors: Note that if plot_type is 2D, i.e is in ['heatmap', 'contour'], it has to have shape lat x lon 268 | :param var_name: 269 | :param plot_type: 270 | :param dpi: 271 | :return: 272 | """ 273 | import cartopy.crs as ccrs 274 | import cartopy.feature as cfeature 275 | fig = plt.figure(figsize=(12, 8)) 276 | plt.rcParams['figure.dpi'] = dpi 277 | plot_type = plot_type.lower() 278 | 279 | ax = fig.add_subplot(1, 1, 1, projection=ccrs.Robinson()) 280 | ax.add_feature(cfeature.COASTLINE) 281 | ax.gridlines() 282 | # ax.stock_img() 283 | ax.set_global() 284 | 285 | # jet = plt.cm.get_cmap('RdBu_r') 286 | jet = plt.cm.get_cmap('plasma') 287 | 288 | # nightshade 289 | # current_time = datetime.now() Can work well if we can get the UTC time for the data 290 | # ax.add_feature(Nightshade(current_time, alpha=0.3)) 291 | 292 | # Circle params 293 | fs_text = 10 294 | padd = -0.18 295 | stroffset = -0.2 296 | circlesize = 100 297 | lw = 2.2 298 | 299 | if plot_type == 'contour': 300 | sc = ax.contourf(lon_plot, lat_plot, errors, transform=ccrs.PlateCarree(), alpha=0.85, cmap="Reds", 301 | levels=100) 302 | elif plot_type == 'heatmap': 303 | sc = ax.pcolormesh(lon_plot, lat_plot, errors, cmap="Reds", transform=ccrs.PlateCarree()) 304 | else: 305 | sc = ax.scatter(x=lon_plot, y=lat_plot, s=circlesize, 306 | c=errors, norm=TwoSlopeNorm(5, vmin=0, vmax=10), 307 | alpha=0.8, cmap=jet, linewidths=lw, 308 | transform=ccrs.PlateCarree()) 309 | 310 | ax.set_title(f'{title}{plot_type.upper()}-{var_name.upper()} error', fontsize=fs_text) 311 | # Colour Bar 312 | cbar = plt.colorbar(sc, ax=ax, aspect=30, pad=0.01, shrink=0.4, orientation='vertical') 313 | cbar.ax.set_ylabel('W m$^{-2}$', rotation=270, labelpad=10) 314 | return fig 315 | 316 | 317 | def prediction_hist(preds: dict, TOA=False, surface=False, 318 | title="", show=True, figsize=(16, 12), axes=None, 319 | label="", **kwargs): 320 | n_vars = len(preds.keys()) 321 | n_cols = 3 if TOA and surface else 2 if (TOA or surface) else 1 322 | 323 | surface_ax = 1 324 | TOA_ax = 2 if surface else 1 325 | 326 | if axes is None: 327 | fig, axs = plt.subplots(n_vars, n_cols, figsize=figsize) 328 | fig.suptitle("Prediction magnitudes" if title == "" else title) 329 | axs[0, 0].set_title('Mean') 330 | 331 | if surface: 332 | axs[0, surface_ax].set_title('Surface') 333 | if TOA: 334 | axs[0, TOA_ax].set_title('TOA') 335 | else: 336 | axs = axes 337 | 338 | def set_bar_colors(patches, upto=5): 339 | return 340 | jet = plt.get_cmap('jet', len(patches)) 341 | for i in range(len(patches)): 342 | if i > upto: 343 | return 344 | 
patches[i].set_facecolor(jet(i * 10)) 345 | 346 | for (var_name, var_preds), ax_row in zip(preds.items(), axs): 347 | # n_samples, n_levels = var_preds.shape 348 | N, bins, patches = ax_row[0].hist(np.mean(var_preds, axis=1), label=label, **kwargs) 349 | set_bar_colors(patches) 350 | ax_row[0].set_ylabel(f"{var_name.upper()}", fontsize=20) 351 | if surface: 352 | N, bins, patches = ax_row[surface_ax].hist(var_preds[:, -1], label=label, **kwargs) 353 | set_bar_colors(patches) 354 | if TOA: 355 | N, bins, patches = ax_row[TOA_ax].hist(var_preds[:, 0], label=label, **kwargs) 356 | set_bar_colors(patches) 357 | 358 | axs[0, 0].legend() 359 | if show: 360 | plt.show() 361 | 362 | return axs 363 | 364 | 365 | def prediction_bars(preds: dict, bins, TOA=False, surface=False, 366 | title="", show=True, figsize=(16, 12), axes=None, 367 | label="", **kwargs): 368 | n_vars = len(preds.keys()) 369 | n_cols = 3 if TOA and surface else 2 if (TOA or surface) else 1 370 | 371 | surface_ax = 1 372 | TOA_ax = 2 if surface else 1 373 | 374 | if axes is None: 375 | fig, axs = plt.subplots(n_vars, n_cols, figsize=figsize) 376 | fig.suptitle("Prediction magnitudes" if title == "" else title) 377 | axs[0, 0].set_title('Mean') 378 | 379 | if surface: 380 | axs[0, surface_ax].set_title('Surface') 381 | if TOA: 382 | axs[0, TOA_ax].set_title('TOA') 383 | else: 384 | axs = axes 385 | 386 | for i, ((var_name, var_preds), ax_row) in enumerate(zip(preds.items(), axs)): 387 | 388 | if False: # i == 1: 389 | kwargs['tick_label'] = ['{} - {}'.format(bins[i], bins[i + 1]) for i, j in enumerate(hist)] 390 | 391 | hist, bin_edges = np.histogram(np.mean(var_preds, axis=1), bins) 392 | ax_row[0].bar(range(len(hist)), hist, width=1, align='center', label=label, **kwargs) 393 | ax_row[0].set_ylabel(f"{var_name.upper()}", fontsize=20) 394 | if surface: 395 | hist, bin_edges = np.histogram(var_preds[:, -1], bins) 396 | ax_row[surface_ax].bar(range(len(hist)), hist, width=1, align='center', label=label, **kwargs) 397 | if TOA: 398 | hist, bin_edges = np.histogram(var_preds[:, 0], bins) 399 | ax_row[TOA_ax].bar(range(len(hist)), hist, width=1, align='center', label=label, **kwargs) 400 | 401 | axs[0, 0].legend() 402 | if show: 403 | plt.show() 404 | 405 | return axs 406 | -------------------------------------------------------------------------------- /climart/utils/postprocessing.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import wandb 4 | import os 5 | 6 | from functools import partial 7 | import torch 8 | import xarray as xr 9 | import numpy as np 10 | from climart.data_loading.constants import LEVELS, LAYERS, GLOBALS, get_data_dims, get_metadata, get_coordinates 11 | 12 | def get_lat_lon(data_dir: str = None): 13 | coords_data = get_coordinates(data_dir) 14 | lat = list(coords_data.get_index('lat')) 15 | lon = list(coords_data.get_index('lon')) 16 | 17 | latitude = [] 18 | longitude = [] 19 | for i in lat: 20 | for j in lon: 21 | latitude.append(i) 22 | longitude.append(j) 23 | lat_var = np.array(latitude) 24 | lon_var = np.array(longitude) 25 | return {'latitude': lat, 'longitude': lon, 'latitude_flattened': lat_var, 'longitude_flattened': lon_var} 26 | 27 | 28 | # %% 29 | 30 | def save_preds_to_netcdf(preds, 31 | targets, 32 | post_fix: str = '', 33 | save_path=None, 34 | exp_type='pristine', 35 | data_dir: str = None, 36 | model=None, 37 | **kwargs): 38 | lat_lon = get_lat_lon(data_dir) 39 | lat, lon = lat_lon['latitude'], lat_lon['longitude'] 40 | 
spatial_dim, _ = get_data_dims(exp_type) 41 | n_levels = spatial_dim[LEVELS] 42 | n_layers = spatial_dim[LAYERS] 43 | shape = ['snapshot', 'latitude', 'longitude', 'level'] 44 | shape_lay = ['snapshot', 'latitude', 'longitude', 'layer'] 45 | shape_glob = ['snapshot', 'latitude', 'longitude'] 46 | 47 | meta_info = get_metadata(data_dir) 48 | 49 | data_vars = dict() 50 | for k, v in preds.items(): 51 | data_vars[f"{k}_preds"] = (shape, v.reshape((-1, len(lat), len(lon), n_levels))) 52 | for k, v in targets.items(): 53 | data_vars[f"{k}_targets"] = (shape, v.reshape((-1, len(lat), len(lon), n_levels))) 54 | 55 | data_vars["pressure"] = (shape, kwargs['pressure'].reshape((-1, len(lat), len(lon), n_levels))) 56 | data_vars["layer_pressure"] = (shape_lay, kwargs['layer_pressure'].reshape((-1, len(lat), len(lon), n_layers))) 57 | data_vars["cszrow"] = (shape_glob, kwargs['cszrow'].reshape((-1, len(lat), len(lon)))) 58 | 59 | xr_dset = xr.Dataset( 60 | data_vars=data_vars, 61 | coords=dict( 62 | longitude=lon, 63 | latitude=lat, 64 | level=list(range(n_levels))[::-1], 65 | layer=list(range(n_layers))[::-1], 66 | ), 67 | attrs=dict(description="ML emulated RT outputs."), 68 | ) 69 | if save_path is not None: 70 | if not save_path.endswith('.nc'): 71 | save_path += '.nc' 72 | save_path.replace('.nc', post_fix + '.nc') 73 | 74 | elif model is not None: 75 | save_path = f'./example_{exp_type}_preds_{model}_{post_fix}.nc' 76 | else: 77 | print("Not saving to NC!") 78 | return xr_dset 79 | if not os.path.isfile(save_path): 80 | xr_dset.to_netcdf(save_path) 81 | print('saved to\n', save_path) 82 | return xr_dset 83 | 84 | 85 | def restore_run(run_id, 86 | run_path: str, 87 | api=None, 88 | ): 89 | if api is None: 90 | api = wandb.Api() 91 | run_path = f"{run_path}/{run_id}" 92 | run = api.run(run_path) 93 | return run 94 | 95 | 96 | def restore_ckpt_from_wandb_run(run, entity: str = 'ecc-mila7', run_path=None, load: bool = False, **kwargs): 97 | run_id = run.id 98 | ckpt = [f for f in run.files() if f"{run_id}.pkl" in str(f)] 99 | ckpt = str(ckpt[0].name) 100 | ckpt = wandb.restore(ckpt, run_path=f"{entity}/ClimART/{run_id}") 101 | ckpt_fname = ckpt.name 102 | if load: 103 | return torch.load(ckpt_fname, **kwargs) 104 | return ckpt_fname 105 | 106 | 107 | def restore_ckpt_from_wandb(run_id, 108 | run_path: str, 109 | api=None, 110 | load: bool = False, 111 | **kwargs): 112 | run = restore_run(run_id, run_path, api) 113 | return restore_ckpt_from_wandb(run, load=load, **kwargs) 114 | -------------------------------------------------------------------------------- /climart/utils/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Salva Rühling Cachay 3 | """ 4 | import functools 5 | import logging 6 | import math 7 | import os 8 | 9 | from types import SimpleNamespace 10 | from typing import Union, Sequence, List, Dict, Optional, Callable 11 | 12 | import numpy as np 13 | import torch 14 | import torch.nn as nn 15 | import torch.nn.functional as F 16 | from omegaconf import DictConfig, open_dict, OmegaConf 17 | from torch import Tensor 18 | from pytorch_lightning.utilities import rank_zero_only 19 | from climart.data_loading import constants, data_variables 20 | 21 | 22 | def no_op(*args, **kwargs): 23 | pass 24 | 25 | 26 | def get_identity_callable(*args, **kwargs) -> Callable: 27 | return identity 28 | 29 | 30 | def get_activation_function(name: str, functional: bool = False, num: int = 1): 31 | name = name.lower().strip() 32 | 33 | def 
get_functional(s: str) -> Optional[Callable]: 34 | return {"softmax": F.softmax, "relu": F.relu, "tanh": torch.tanh, "sigmoid": torch.sigmoid, 35 | "identity": nn.Identity(), 36 | None: None, 'swish': F.silu, 'silu': F.silu, 'elu': F.elu, 'gelu': F.gelu, 'prelu': nn.PReLU(), 37 | }[s] 38 | 39 | def get_nn(s: str) -> Optional[Callable]: 40 | return {"softmax": nn.Softmax(dim=1), "relu": nn.ReLU(), "tanh": nn.Tanh(), "sigmoid": nn.Sigmoid(), 41 | "identity": nn.Identity(), 'silu': nn.SiLU(), 'elu': nn.ELU(), 'prelu': nn.PReLU(), 42 | 'swish': nn.SiLU(), 'gelu': nn.GELU(), 43 | }[s] 44 | 45 | if num == 1: 46 | return get_functional(name) if functional else get_nn(name) 47 | else: 48 | return [get_nn(name) for _ in range(num)] 49 | 50 | 51 | def get_normalization_layer(name, dims, num_groups=None, *args, **kwargs): 52 | if not isinstance(name, str) or name.lower() == 'none': 53 | return None 54 | elif 'batch' in name: 55 | return nn.BatchNorm1d(num_features=dims, *args, **kwargs) 56 | elif 'layer' in name: 57 | return nn.LayerNorm(dims, *args, **kwargs) 58 | elif 'inst' in name: 59 | return nn.InstanceNorm1d(num_features=dims, *args, **kwargs) 60 | elif 'group' in name: 61 | if num_groups is None: 62 | num_groups = int(dims / 10) 63 | return nn.GroupNorm(num_groups=num_groups, num_channels=dims) 64 | else: 65 | raise ValueError("Unknown normalization name", name) 66 | 67 | 68 | def identity(X): 69 | return X 70 | 71 | 72 | def to_dict(obj: Optional[Union[dict, SimpleNamespace]]): 73 | if obj is None: 74 | return dict() 75 | elif isinstance(obj, dict): 76 | return obj 77 | else: 78 | return vars(obj) 79 | 80 | 81 | def to_DictConfig(obj: Optional[Union[List, Dict]]): 82 | if isinstance(obj, DictConfig): 83 | return obj 84 | 85 | if isinstance(obj, list): 86 | try: 87 | dict_config = OmegaConf.from_dotlist(obj) 88 | except ValueError as e: 89 | dict_config = OmegaConf.create(obj) 90 | 91 | elif isinstance(obj, dict): 92 | dict_config = OmegaConf.create(obj) 93 | 94 | else: 95 | dict_config = OmegaConf.create() # empty 96 | 97 | return dict_config 98 | 99 | 100 | def get_logger(name=__name__, level=logging.INFO) -> logging.Logger: 101 | """Initializes multi-GPU-friendly python logger.""" 102 | logger = logging.getLogger(name) 103 | logger.setLevel(level) 104 | 105 | # this ensures all logging levels get marked with the rank zero decorator 106 | # otherwise logs would get multiplied for each GPU process in multi-GPU setup 107 | for level in ("debug", "info", "warning", "error", "exception", "fatal", "critical"): 108 | setattr(logger, level, rank_zero_only(getattr(logger, level))) 109 | 110 | return logger 111 | 112 | 113 | ##### 114 | # The following two functions extend setattr and getattr to support chained objects, e.g. 
rsetattr(cfg, optim.lr, 1e-4) 115 | # From https://stackoverflow.com/questions/31174295/getattr-and-setattr-on-nested-subobjects-chained-properties 116 | def rsetattr(obj, attr, val): 117 | pre, _, post = attr.rpartition('.') 118 | return setattr(rgetattr(obj, pre) if pre else obj, post, val) 119 | 120 | 121 | def rgetattr(obj, attr, *args): 122 | def _getattr(obj, attr): 123 | return getattr(obj, attr, *args) 124 | 125 | return functools.reduce(_getattr, [obj] + attr.split('.')) 126 | 127 | 128 | def adj_to_edge_indices(adj: Union[torch.Tensor, np.ndarray]) -> Union[torch.Tensor, np.ndarray]: 129 | """ 130 | Args: 131 | adj: a (N, N) adjacency matrix, where N is the number of nodes 132 | Returns: 133 | A (2, E) array, edge_idxs, where E is the number of edges, 134 | and edge_idxs[0], edge_idxs[1] are the source & destination nodes, respectively. 135 | """ 136 | edge_tuples = torch.nonzero(adj, as_tuple=True) if torch.is_tensor(adj) else np.nonzero(adj) 137 | edge_src = edge_tuples[0].unsqueeze(0) if torch.is_tensor(adj) else np.expand_dims(edge_tuples[0], axis=0) 138 | edge_dest = edge_tuples[1].unsqueeze(0) if torch.is_tensor(adj) else np.expand_dims(edge_tuples[1], axis=0) 139 | if torch.is_tensor(adj): 140 | edge_idxs = torch.cat((edge_src, edge_dest), dim=0) 141 | else: 142 | edge_idxs = np.concatenate((edge_src, edge_dest), axis=0) 143 | return edge_idxs 144 | 145 | 146 | def normalize_adjacency_matrix_torch(adj: Tensor, improved: bool = True, add_self_loops: bool = False): 147 | if add_self_loops: 148 | fill_value = 2. if improved else 1. 149 | adj = adj.fill_diagonal_(fill_value) 150 | deg: Tensor = torch.sum(adj, dim=1) 151 | deg_inv_sqrt: Tensor = deg.pow_(-0.5) 152 | deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0.) 153 | adj_t = torch.mul(adj, deg_inv_sqrt.view(-1, 1)) 154 | adj_t = torch.mul(adj_t, deg_inv_sqrt.view(1, -1)) 155 | return adj_t 156 | 157 | 158 | def normalize_adjacency_matrix(adj: np.ndarray, improved: bool = True, add_self_loops: bool = True): 159 | if add_self_loops: 160 | fill_value = 2. if improved else 1. 161 | np.fill_diagonal(adj, fill_value) 162 | deg = np.sum(adj, axis=1) 163 | deg_inv_sqrt = np.power(deg, -0.5) 164 | deg_inv_sqrt[np.isinf(deg_inv_sqrt)] = 0. 165 | 166 | deg_inv_sqrt_matrix = np.diag(deg_inv_sqrt) 167 | adj_normed = deg_inv_sqrt_matrix @ adj @ deg_inv_sqrt_matrix 168 | return adj_normed 169 | 170 | 171 | def set_gpu(gpu_id): 172 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 173 | os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id) 174 | 175 | 176 | def set_seed(seed, device='cuda'): 177 | import random, torch 178 | # setting seeds 179 | random.seed(seed) 180 | np.random.seed(seed) 181 | torch.manual_seed(seed) 182 | if device != 'cpu': 183 | torch.backends.cudnn.deterministic = True 184 | torch.cuda.manual_seed(seed) 185 | torch.cuda.manual_seed_all(seed) 186 | 187 | 188 | def year_string_to_list(year_string: str) -> List[int]: 189 | """ 190 | Args: 191 | year_string (str): must only contain {digits, '-', '+'}. 192 | Examples: 193 | '1988-90' will return [1988, 1989, 1990] 194 | '1988-1990+2001-2004' will return [1988, 1989, 1990, 2001, 2002, 2003, 2004] 195 | """ 196 | if not isinstance(year_string, str): 197 | year_string = str(year_string) 198 | 199 | def year_string_to_full_year(year_string: str): 200 | if len(year_string) == 4: 201 | return int(year_string) 202 | assert len(year_string) == 2, f'Year {year_string} had an unexpected length.' 
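        # Two-digit years are disambiguated by their first digit: '00'-'29' are read as 20xx,
        # '30'-'99' as 19xx (e.g. '03' -> 2003, '88' -> 1988).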
203 | if int(year_string[0]) < 3: 204 | return int('20' + year_string) 205 | else: 206 | return int('19' + year_string) 207 | 208 | def update_years(year_list: List[int], year_start, year_end): 209 | if not isinstance(year_start, int): 210 | year_start = year_string_to_full_year(year_start) 211 | if year_end == '': 212 | year_end = year_start 213 | else: 214 | year_end = year_string_to_full_year(year_end) 215 | year_list += list(range(year_start, year_end + 1)) 216 | return year_list, '', '' 217 | 218 | years = [] 219 | cur_year_start = cur_year_end = '' 220 | for char in year_string: 221 | if char == '-': 222 | cur_year_start = year_string_to_full_year(cur_year_start) 223 | elif char == '+': 224 | years, cur_year_start, cur_year_end = update_years(years, cur_year_start, cur_year_end) 225 | else: 226 | if isinstance(cur_year_start, int): 227 | cur_year_end += char 228 | else: 229 | cur_year_start += char 230 | years, _, _ = update_years(years, cur_year_start, cur_year_end) 231 | return years 232 | 233 | 234 | def target_var_id_mapping(x, y): 235 | k = 'l' if y.lower().replace('_', '').replace('-', '') == 'longwave' else 's' 236 | x = x.lower().replace('-', '_').replace('rates', 'rate').replace('_fluxes', '').replace('_flux', '') 237 | if x == 'heating_rate': 238 | return f'hr{k}c' 239 | elif x == 'upwelling': 240 | return f"r{k}uc" 241 | elif x == 'downwelling': 242 | return f"r{k}dc" 243 | else: 244 | raise ValueError(f"Combination {x} {y} not understood!") 245 | 246 | 247 | def get_target_types(target_type: Union[str, List[str]]) -> List[str]: 248 | if isinstance(target_type, list): 249 | assert all([t in [constants.SHORTWAVE, constants.LONGWAVE] for t in target_type]) 250 | return target_type 251 | target_type2 = target_type.lower().replace('&', '+').replace('-', '') 252 | if target_type2 in ['sw+lw', 'lw+sw', 'shortwave+longwave', 'longwave+shortwave']: 253 | return [constants.SHORTWAVE, constants.LONGWAVE] 254 | elif target_type2 in ['sw', 'shortwave']: 255 | return [constants.SHORTWAVE] 256 | elif target_type2 in ['lw', 'longwave']: 257 | return [constants.LONGWAVE] 258 | else: 259 | raise ValueError(f"Target type `{target_type}` must be one of shortwave, longwave or shortwave+longwave") 260 | 261 | 262 | def get_target_variable_names(target_types: Union[str, List[str]], 263 | target_variable: Union[str, List[str]]) -> List[str]: 264 | out_vars = data_variables.OUT_SHORTWAVE_NOCLOUDS + data_variables.OUT_LONGWAVE_NOCLOUDS \ 265 | + data_variables.OUT_HEATING_RATE_NOCLOUDS 266 | if isinstance(target_variable, list): 267 | if len(target_variable) == 1: 268 | target_variable = target_variable[0] 269 | else: 270 | err_msg = f"Each target var must be in {out_vars}, but got {target_variable}" 271 | assert all([t.lower() in out_vars for t in target_variable]), err_msg 272 | return target_variable 273 | 274 | target_types = get_target_types(target_types) 275 | target_variable2 = target_variable.lower().replace('&', '+').replace('-', '').replace('_', '') 276 | target_variable2 = target_variable2.replace('fluxes', 'flux').replace('heatingrate', 'hr') 277 | target_vars: List[str] = [] 278 | if constants.LONGWAVE in target_types: 279 | target_vars += data_variables.OUT_LONGWAVE_NOCLOUDS + [data_variables.LW_HEATING_RATE] 280 | if constants.SHORTWAVE in target_types: 281 | target_vars += data_variables.OUT_SHORTWAVE_NOCLOUDS + [data_variables.SW_HEATING_RATE] 282 | assert len(target_vars) > 0, f"{target_types}, {target_variable} resulted in zero target_vars!" 
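    # At this point target_vars lists the output variable names (fluxes and heating rate)
    # for every requested target type (shortwave and/or longwave).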
283 | 284 | if len(target_vars) == 0: 285 | raise ValueError(f"Target var `{target_variable2}` must be one of fluxes, heating_rate.") 286 | return target_vars 287 | 288 | 289 | def get_target_variable(target_variable: Union[str, List[str]]) -> List[str]: 290 | if isinstance(target_variable, list): 291 | if len(target_variable) == 1 and 'flux' in target_variable[0]: 292 | target_variable = target_variable[0] 293 | else: 294 | return target_variable 295 | target_variable2 = target_variable.lower().replace('&', '+').replace('-', '').replace('_', '') 296 | target_variable2 = target_variable2.replace('fluxes', 'flux').replace('heatingrate', 'hr') 297 | target_vars: List[str] = [] 298 | if target_variable2 == 'hr': 299 | return [constants.SURFACE_FLUXES, constants.TOA_FLUXES, constants.HEATING_RATES] 300 | else: 301 | if 'flux' in target_variable2: 302 | target_vars += [constants.FLUXES] 303 | if 'hr' in target_variable2: 304 | target_vars += [constants.HEATING_RATES] 305 | 306 | if len(target_vars) == 0: 307 | raise ValueError(f"Target var `{target_variable2}` must be one of fluxes, heating_rate.") 308 | return target_vars 309 | 310 | 311 | def pressure_from_level_array(levels_array): 312 | PRESSURE_IDX = 2 313 | return levels_array[..., PRESSURE_IDX] 314 | 315 | 316 | def layer_thickness_from_layer_array(layers_array): 317 | THICKNESS_IDX = 12 318 | return layers_array[..., THICKNESS_IDX] 319 | 320 | 321 | def fluxes_to_heating_rates(upwelling_flux: Union[np.ndarray, Tensor], 322 | downwelling_flux: Union[np.ndarray, Tensor], 323 | pressure: Union[np.ndarray, Tensor], 324 | c: float = 9.761357e-03 325 | ) -> Union[np.ndarray, Tensor]: 326 | """ 327 | N - the batch/data dimension size 328 | L - the number of levels (= number of layers + 1) 329 | 330 | Args: 331 | upwelling_flux: a (N, L) array 332 | downwelling_flux: a (N, L) array 333 | pressure: a (N, L) array representing the levels pressure 334 | or a (N, L, D-lev) array containing *all* level variables (including pressure) 335 | 336 | Returns: 337 | A (N, L-1) array representing the heating rates at each of the L-1 layers 338 | """ 339 | err_msg = f"Upwelling/Downwelling arg should have same shape, but have {upwelling_flux.shape}, {downwelling_flux.shape}" 340 | assert upwelling_flux.shape == downwelling_flux.shape, err_msg 341 | if len(pressure.shape) <= 2: 342 | err_msg = f"pressure arg has not the expected shape (N, #levels), but has shape {pressure.shape}" 343 | assert downwelling_flux.shape == pressure.shape, err_msg 344 | else: 345 | err_msg = "pressure argument is not the expected levels array of shape (N, #levels, #level-vars)" 346 | assert len(pressure.shape) == 3, err_msg 347 | pressure = pressure_from_level_array(pressure) 348 | 349 | c = 9.761357e-03 # 9.76 * 1e-5 350 | # c_p = 1004.98322108, g = 9.81, c = g/c_p 351 | # 3D radiative effects paper uses 8.91/1004 = 0.00977091633 352 | net_level_flux = upwelling_flux - downwelling_flux 353 | net_layer_flux = net_level_flux[:, 1:] - net_level_flux[:, :-1] 354 | pressure_diff = pressure[:, 1:] - pressure[:, :-1] 355 | heating_rate = c * net_layer_flux / pressure_diff 356 | 357 | assert tuple(heating_rate.shape) == (pressure.shape[0], pressure.shape[1] - 1) 358 | 359 | return heating_rate 360 | -------------------------------------------------------------------------------- /climart/utils/wandb_callbacks.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | from pathlib import Path 3 | import wandb 4 | import pytorch_lightning 
as pl 5 | from pytorch_lightning import Callback, Trainer 6 | from pytorch_lightning.utilities import rank_zero_only 7 | from pytorch_lightning.loggers import LoggerCollection, WandbLogger 8 | from climart.utils.utils import get_logger 9 | 10 | log = get_logger(__name__) 11 | 12 | 13 | def get_wandb_logger(trainer: Trainer) -> WandbLogger: 14 | """Safely get Weights&Biases logger from Trainer.""" 15 | 16 | if trainer.fast_dev_run: 17 | raise Exception( 18 | "Cannot use wandb callbacks since pytorch lightning disables loggers in `fast_dev_run=true` mode." 19 | ) 20 | 21 | if isinstance(trainer.logger, WandbLogger): 22 | return trainer.logger 23 | 24 | if isinstance(trainer.logger, LoggerCollection): 25 | for logger in trainer.logger: 26 | if isinstance(logger, WandbLogger): 27 | return logger 28 | 29 | raise Exception( 30 | "You are using wandb related callback, but WandbLogger was not found for some reason..." 31 | ) 32 | 33 | 34 | class WatchModel(Callback): 35 | """Make wandb watch model at the beginning of the run.""" 36 | 37 | def __init__(self, log: str = "gradients", log_freq: int = 100): 38 | self.log = log 39 | self.log_freq = log_freq 40 | 41 | @rank_zero_only 42 | def on_train_start(self, trainer, pl_module): 43 | logger: WandbLogger = get_wandb_logger(trainer=trainer) 44 | try: 45 | logger.watch(model=trainer.model, log=self.log, log_freq=self.log_freq, log_graph=True) 46 | except TypeError as e: 47 | log.info( 48 | f" Pytorch-lightning/Wandb version seems to be too old to support 'log_graph' arg in wandb.watch(.)" 49 | f" Wandb version={wandb.__version__}") 50 | wandb.watch(models=trainer.model, log=self.log, log_freq=self.log_freq) # , log_graph=True) 51 | 52 | 53 | class SummarizeBestValMetric(Callback): 54 | """Make wandb log in run.summary the best achieved monitored val_metric as opposed to the last""" 55 | 56 | @rank_zero_only 57 | def on_train_start(self, trainer, pl_module): 58 | logger: WandbLogger = get_wandb_logger(trainer=trainer) 59 | experiment = logger.experiment 60 | experiment.define_metric(trainer.model.hparams.monitor, summary=trainer.model.hparams.mode) 61 | 62 | 63 | class UploadCodeAsArtifact(Callback): 64 | """Upload all code files to wandb as an artifact, at the beginning of the run.""" 65 | 66 | def __init__(self, code_dir: str, use_git: bool = True): 67 | """ 68 | Args: 69 | code_dir: the code directory 70 | use_git: if using git, then upload all files that are not ignored by git. 
71 | if not using git, then upload all '*.py' file 72 | """ 73 | self.code_dir = code_dir 74 | self.use_git = use_git 75 | 76 | @rank_zero_only 77 | def on_train_start(self, trainer, pl_module): 78 | logger = get_wandb_logger(trainer=trainer) 79 | experiment = logger.experiment 80 | 81 | code = wandb.Artifact("project-source", type="code") 82 | 83 | if self.use_git: 84 | # get .git folder path 85 | git_dir_path = Path( 86 | subprocess.check_output(["git", "rev-parse", "--git-dir"]).strip().decode("utf8") 87 | ).resolve() 88 | 89 | for path in Path(self.code_dir).resolve().rglob("*"): 90 | 91 | # don't upload files ignored by git 92 | # https://alexwlchan.net/2020/11/a-python-function-to-ignore-a-path-with-git-info-exclude/ 93 | command = ["git", "check-ignore", "-q", str(path)] 94 | not_ignored = subprocess.run(command).returncode == 1 95 | 96 | # don't upload files from .git folder 97 | not_git = not str(path).startswith(str(git_dir_path)) 98 | 99 | if path.is_file() and not_git and not_ignored: 100 | code.add_file(str(path), name=str(path.relative_to(self.code_dir))) 101 | 102 | else: 103 | for path in Path(self.code_dir).resolve().rglob("*.py"): 104 | code.add_file(str(path), name=str(path.relative_to(self.code_dir))) 105 | 106 | experiment.log_artifact(code) 107 | 108 | 109 | class UploadBestCheckpointAsFile(Callback): 110 | """Upload checkpoints to wandb as a file, at the end of run.""" 111 | @rank_zero_only 112 | def on_train_start(self, trainer, pl_module): 113 | if not hasattr(trainer, 'checkpoint_callback'): 114 | log.warning("pl.Trainer has no checkpoint_callback/ModelCheckpoint() callback even though you use" 115 | " UploadBestCheckpointAsFile - This callback will be ignored!") 116 | 117 | @rank_zero_only 118 | def on_exception(self, trainer: pl.Trainer, pl_module: pl.LightningModule, exception: BaseException) -> None: 119 | self.on_train_end(trainer, pl_module) 120 | 121 | @rank_zero_only 122 | def on_train_end(self, trainer, pl_module): 123 | if not hasattr(trainer, 'checkpoint_callback'): 124 | return 125 | logger = get_wandb_logger(trainer=trainer) 126 | path = trainer.checkpoint_callback.best_model_path 127 | if path is not None: 128 | log.info(f"Best checkpoint path will be saved to wandb from path: {path}") 129 | logger.experiment.log({'best_model_filepath': path}) 130 | logger.experiment.save(path) 131 | 132 | 133 | class UploadCheckpointsAsArtifact(Callback): 134 | """Upload checkpoints to wandb as an artifact, at the end of run.""" 135 | 136 | def __init__(self, ckpt_dir: str = "checkpoints/", upload_best_only: bool = True): 137 | self.ckpt_dir = ckpt_dir 138 | self.upload_best_only = upload_best_only 139 | 140 | @rank_zero_only 141 | def on_exception(self, trainer: pl.Trainer, pl_module: pl.LightningModule, exception: BaseException) -> None: 142 | self.on_train_end(trainer, pl_module) 143 | 144 | @rank_zero_only 145 | def on_train_end(self, trainer, pl_module): 146 | logger = get_wandb_logger(trainer=trainer) 147 | 148 | ckpts = wandb.Artifact("experiment-ckpts", type="checkpoints") 149 | 150 | if self.upload_best_only: 151 | ckpts.add_file(trainer.checkpoint_callback.best_model_path) 152 | else: 153 | for path in Path(self.ckpt_dir).rglob("*.ckpt"): 154 | ckpts.add_file(str(path)) 155 | 156 | logger.experiment.log_artifact(ckpts) 157 | -------------------------------------------------------------------------------- /configs/callbacks/default.yaml: -------------------------------------------------------------------------------- 1 | model_checkpoint: 2 | _target_: 
pytorch_lightning.callbacks.ModelCheckpoint 3 | monitor: ${val_metric} # name of the logged metric which determines when model is improving 4 | mode: "min" # "max" means higher metric value is better, can be also "min" 5 | save_top_k: 1 # save k best models (determined by above metric) 6 | save_last: True # additionally always save model from last epoch 7 | verbose: False 8 | dirpath: ${ckpt_dir} 9 | filename: "epoch{epoch:03d}_seed${seed}" 10 | auto_insert_metric_name: False 11 | 12 | early_stopping: 13 | _target_: pytorch_lightning.callbacks.EarlyStopping 14 | monitor: ${val_metric} # name of the logged metric which determines when model is improving 15 | mode: "min" # "max" means higher metric value is better, can be also "min" 16 | patience: 100 # how many validation epochs of not improving until training stops 17 | min_delta: 0 # minimum change in the monitored metric needed to qualify as an improvement 18 | 19 | learning_rate_logging: 20 | _target_: pytorch_lightning.callbacks.LearningRateMonitor 21 | 22 | model_summary: 23 | _target_: pytorch_lightning.callbacks.RichModelSummary 24 | max_depth: 1 25 | 26 | rich_progress_bar: 27 | _target_: pytorch_lightning.callbacks.RichProgressBar 28 | -------------------------------------------------------------------------------- /configs/callbacks/none.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RolnickLab/climart/6d12fa478de5db5209842e0369a2a24377079edd/configs/callbacks/none.yaml -------------------------------------------------------------------------------- /configs/callbacks/wandb.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default.yaml 3 | 4 | watch_model: 5 | _target_: climart.utils.wandb_callbacks.WatchModel 6 | log: "all" 7 | log_freq: 100 8 | 9 | summarize_best_val_metric: 10 | _target_: climart.utils.wandb_callbacks.SummarizeBestValMetric 11 | 12 | upload_best_ckpt_as_file: 13 | _target_: climart.utils.wandb_callbacks.UploadBestCheckpointAsFile 14 | -------------------------------------------------------------------------------- /configs/experiment/example.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python run.py experiment=example 5 | 6 | defaults: 7 | - override /mode: exp.yaml 8 | - override /trainer: default.yaml 9 | - override /model: mnist.yaml 10 | - override /callbacks: default.yaml 11 | - override /logger: null 12 | 13 | # all parameters below will be merged with parameters from default configurations set above 14 | # this allows you to overwrite only specified parameters 15 | 16 | # name of the run determines folder name in logs 17 | # can also be accessed by loggers 18 | name: "example" 19 | 20 | seed: 12345 21 | 22 | trainer: 23 | min_epochs: 1 24 | max_epochs: 10 25 | gradient_clip_val: 5 26 | 27 | model: 28 | lin1_size: 128 29 | lin2_size: 256 30 | lin3_size: 64 31 | lr: 0.002 32 | 33 | datamodule: 34 | batch_size: 64 35 | train_val_test_split: [55_000, 5_000, 10_000] 36 | 37 | logger: 38 | csv: 39 | name: csv/${name} 40 | wandb: 41 | tags: ["mnist", "simple_dense_net"] 42 | -------------------------------------------------------------------------------- /configs/experiment/reproduce_paper2021_cnn.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python run.py 
experiment=reproduce_paper2021_cnn 5 | 6 | defaults: 7 | - override /mode: exp.yaml 8 | - override /trainer: default.yaml 9 | - override /model: cnn.yaml 10 | - override /callbacks: wandb.yaml # feel free to change this to default.yaml or any other callback 11 | - override /logger: wandb.yaml # feel free to change this to your favorite logger 12 | 13 | # all parameters below will be merged with parameters from default configurations set above 14 | 15 | name: "ClimART-21-CNN" 16 | 17 | seed: 7 18 | 19 | trainer: 20 | max_epochs: 100 21 | gradient_clip_val: 1.0 22 | 23 | model: 24 | hidden_dims: [200, 400, 100] 25 | strides: [2, 1, 1] 26 | gap: True 27 | activation_function: "gelu" 28 | net_normalization: "none" 29 | 30 | datamodule: 31 | exp_type: "pristine" 32 | batch_size: 128 33 | train_years: "1990+1999+2003" 34 | validation_years: "2005" 35 | target_type: "shortwave" 36 | # target_variable: "fluxes+heating_rate" 37 | 38 | normalizer: 39 | input_normalization: "z" 40 | output_normalization: null 41 | 42 | logger: 43 | wandb: 44 | tags: ["reproduce-climart-2021", "cnn", "reproduce-cnn"] 45 | -------------------------------------------------------------------------------- /configs/hparams_search/mnist_optuna.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # example hyperparameter optimization of some experiment with Optuna: 4 | # python run.py -m hparams_search=mnist_optuna experiment=example_simple hydra.sweeper.n_trials=30 5 | 6 | defaults: 7 | - override /hydra/sweeper: optuna 8 | 9 | # choose metric which will be optimized by Optuna 10 | # make sure this is the correct name of some metric logged in lightning module! 11 | optimized_metric: "val/acc_best" 12 | 13 | hydra: 14 | # here we define Optuna hyperparameter search 15 | # it optimizes for value returned from function with @hydra.main decorator 16 | # learn more here: https://hydra.cc/docs/next/plugins/optuna_sweeper 17 | sweeper: 18 | _target_: hydra_plugins.hydra_optuna_sweeper.optuna_sweeper.OptunaSweeper 19 | storage: null 20 | study_name: null 21 | n_jobs: 1 22 | 23 | # 'minimize' or 'maximize' the objective 24 | direction: maximize 25 | 26 | # number of experiments that will be executed 27 | n_trials: 20 28 | 29 | # choose Optuna hyperparameter sampler 30 | # learn more here: https://optuna.readthedocs.io/en/stable/reference/samplers.html 31 | sampler: 32 | _target_: optuna.samplers.TPESampler 33 | seed: 12345 34 | consider_prior: true 35 | prior_weight: 1.0 36 | consider_magic_clip: true 37 | consider_endpoints: false 38 | n_startup_trials: 10 39 | n_ei_candidates: 24 40 | multivariate: false 41 | warn_independent_sampling: true 42 | 43 | # define range of hyperparameters 44 | search_space: 45 | datamodule.batch_size: 46 | type: categorical 47 | choices: [32, 64, 128] 48 | model.lr: 49 | type: float 50 | low: 0.0001 51 | high: 0.2 52 | model.lin1_size: 53 | type: categorical 54 | choices: [32, 64, 128, 256, 512] 55 | model.lin2_size: 56 | type: categorical 57 | choices: [32, 64, 128, 256, 512] 58 | model.lin3_size: 59 | type: categorical 60 | choices: [32, 64, 128, 256, 512] 61 | -------------------------------------------------------------------------------- /configs/input_transform/flatten.yaml: -------------------------------------------------------------------------------- 1 | _target_: climart.data_transform.transforms.FlattenTransform 2 | exp_type: ${datamodule.exp_type} 
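The input_transform configs above, like the model configs further below that pull them in through their defaults lists, are plain Hydra _target_ configs. The sketch below illustrates how such a config can be turned into a transform object; it is not a file from this repository and assumes that hydra-core is installed, that the climart package is importable, and that FlattenTransform accepts only the exp_type argument shown in flatten.yaml.

# Illustrative sketch (not a repository file): instantiating an input_transform
# config with Hydra. Assumes climart is installed and that FlattenTransform
# needs only the exp_type argument from flatten.yaml.
from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {"_target_": "climart.data_transform.transforms.FlattenTransform",
     "exp_type": "pristine"}
)
transform = instantiate(cfg)  # builds the object named by _target_
# transformed = transform(batch)  # hypothetical call; the real signature is defined in transforms.py

In a normal run these configs are composed automatically by Hydra (see main_config.yaml and the model configs), so manual instantiation as above is only useful for standalone experimentation.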
-------------------------------------------------------------------------------- /configs/input_transform/graphnet_level_nodes.yaml: -------------------------------------------------------------------------------- 1 | _target_: climart.data_transform.transforms.LayerEdgesAndLevelNodesGraph 2 | exp_type: ${datamodule.exp_type} -------------------------------------------------------------------------------- /configs/input_transform/none.yaml: -------------------------------------------------------------------------------- 1 | _target_: climart.data_transform.transforms.IdentityTransform 2 | exp_type: ${datamodule.exp_type} -------------------------------------------------------------------------------- /configs/input_transform/repeat_global_vars.yaml: -------------------------------------------------------------------------------- 1 | _target_: climart.data_transform.transforms.RepeatGlobalsTransform 2 | exp_type: ${datamodule.exp_type} -------------------------------------------------------------------------------- /configs/local/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RolnickLab/climart/6d12fa478de5db5209842e0369a2a24377079edd/configs/local/.gitkeep -------------------------------------------------------------------------------- /configs/logger/comet.yaml: -------------------------------------------------------------------------------- 1 | # https://www.comet.ml 2 | 3 | comet: 4 | _target_: pytorch_lightning.loggers.comet.CometLogger 5 | api_key: ${oc.env:COMET_API_TOKEN} # api key is loaded from environment variable 6 | project_name: "template-tests" 7 | experiment_name: ${name} 8 | -------------------------------------------------------------------------------- /configs/logger/csv.yaml: -------------------------------------------------------------------------------- 1 | # csv logger built into lightning 2 | 3 | csv: 4 | _target_: pytorch_lightning.loggers.csv_logs.CSVLogger 5 | save_dir: "."
6 | name: "csv/" 7 | version: ${name} 8 | prefix: "" 9 | -------------------------------------------------------------------------------- /configs/logger/many_loggers.yaml: -------------------------------------------------------------------------------- 1 | # train with many loggers at once 2 | 3 | defaults: 4 | # - comet.yaml 5 | - csv.yaml 6 | # - mlflow.yaml 7 | # - neptune.yaml 8 | - tensorboard.yaml 9 | - wandb.yaml 10 | -------------------------------------------------------------------------------- /configs/logger/mlflow.yaml: -------------------------------------------------------------------------------- 1 | # https://mlflow.org 2 | 3 | mlflow: 4 | _target_: pytorch_lightning.loggers.mlflow.MLFlowLogger 5 | experiment_name: ${name} 6 | tracking_uri: null 7 | tags: null 8 | save_dir: ./mlruns 9 | prefix: "" 10 | artifact_location: null 11 | -------------------------------------------------------------------------------- /configs/logger/neptune.yaml: -------------------------------------------------------------------------------- 1 | # https://neptune.ai 2 | 3 | neptune: 4 | _target_: pytorch_lightning.loggers.neptune.NeptuneLogger 5 | api_key: ${oc.env:NEPTUNE_API_TOKEN} # api key is loaded from environment variable 6 | project_name: your_name/template-tests 7 | close_after_fit: True 8 | offline_mode: False 9 | experiment_name: ${name} 10 | experiment_id: null 11 | prefix: "" 12 | -------------------------------------------------------------------------------- /configs/logger/none.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RolnickLab/climart/6d12fa478de5db5209842e0369a2a24377079edd/configs/logger/none.yaml -------------------------------------------------------------------------------- /configs/logger/tensorboard.yaml: -------------------------------------------------------------------------------- 1 | # https://www.tensorflow.org/tensorboard/ 2 | 3 | tensorboard: 4 | _target_: pytorch_lightning.loggers.tensorboard.TensorBoardLogger 5 | save_dir: "tensorboard/" 6 | name: null 7 | version: ${name} 8 | log_graph: False 9 | default_hp_metric: True 10 | prefix: "" 11 | -------------------------------------------------------------------------------- /configs/logger/wandb.yaml: -------------------------------------------------------------------------------- 1 | # https://wandb.ai 2 | 3 | wandb: 4 | _target_: pytorch_lightning.loggers.wandb.WandbLogger 5 | # entity: "some-name" # optionally set to name of your wandb team 6 | name: ${name} 7 | tags: [] 8 | notes: "..." 9 | project: "ClimART" 10 | group: "" 11 | resume: "allow" 12 | reinit: True 13 | mode: online # set to 'disabled' for no wandb logging 14 | save_dir: ${work_dir}/ 15 | offline: False # set True to store all logs only locally 16 | id: null # pass correct id to resume experiment! 17 | log_model: False 18 | prefix: "" 19 | job_type: "train" 20 | -------------------------------------------------------------------------------- /configs/main_config.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # specify here default training configuration 4 | defaults: 5 | - _self_ 6 | - trainer: default.yaml 7 | - model: mlp.yaml 8 | 9 | - callbacks: default.yaml # or use wandb.yaml for wandb support 10 | - logger: wandb # set logger here or use command line (e.g. `python run.py logger=wandb`) 11 | 12 | # modes are special collections of config options for different purposes, e.g.
debugging 13 | - mode: default.yaml 14 | 15 | # experiment configs allow for version control of specific configurations 16 | # for example, use them to store best hyperparameters for each combination of model and datamodule 17 | - experiment: null 18 | 19 | # config for hyperparameter optimization 20 | - hparams_search: null 21 | 22 | # optional local config for machine/user specific settings 23 | - optional local: default.yaml 24 | 25 | # enable color logging 26 | #- override hydra/hydra_logging: colorlog 27 | # - override hydra/job_logging: colorlog 28 | # default optimizer is AdamW 29 | - override optimizer@model.optimizer: adamw.yaml 30 | 31 | 32 | datamodule: 33 | _target_: climart.datamodules.pl_climart_datamodule.ClimartDataModule 34 | exp_type: "pristine" 35 | target_type: "shortwave" 36 | target_variable: "fluxes+heating_rate" 37 | train_years: "2000" 38 | validation_years: "2005" 39 | batch_size: 128 40 | eval_batch_size: 512 41 | num_workers: 0 42 | pin_memory: True 43 | load_train_into_mem: True 44 | load_valid_into_mem: True 45 | load_test_into_mem: False 46 | test_main_dataset: True 47 | test_ood_1991: False 48 | test_ood_historic: False 49 | test_ood_future: False 50 | verbose: ${verbose} 51 | # path to folder with data (optional, can also override constants.DATA_DIR to point to correct dir) 52 | # Make sure that it is an absolute path! hydra.runtime.cwd points to the original working dir. 53 | data_dir: "${hydra:runtime.cwd}/ClimART_DATA" # null 54 | 55 | normalizer: 56 | _target_: climart.data_transform.normalization.Normalizer 57 | input_normalization: "z" 58 | output_normalization: "z" 59 | spatial_normalization_in: False 60 | spatial_normalization_out: False 61 | log_scaling: False 62 | data_dir: ${datamodule.data_dir} 63 | verbose: ${verbose} 64 | # path to original working directory 65 | # hydra hijacks working directory by changing it to the new log directory 66 | # so it's useful to have this path as a special variable 67 | # https://hydra.cc/docs/next/tutorials/basic/running_your_app/working_directory 68 | work_dir: ${hydra:runtime.cwd} # {oc.env:ENV_VAR} allows to get environment variable ENV_VAR 69 | 70 | val_metric: "val/${target_var_id:heating_rate, ${datamodule.target_type}}/rmse" 71 | 72 | # path to checkpoints 73 | ckpt_dir: ${work_dir}/checkpoints/ 74 | 75 | # path for logging 76 | log_dir: ${work_dir}/logs/ 77 | 78 | # pretty print config at the start of the run using Rich library 79 | print_config: True 80 | 81 | # disable python warnings if they annoy you 82 | ignore_warnings: True 83 | 84 | # evaluate on test set, using best model weights achieved during training 85 | # lightning chooses best weights based on metric specified in checkpoint callback 86 | test_after_training: True 87 | 88 | # Upload config file to wandb cloud? 89 | save_config_to_wandb: True 90 | 91 | # Verbose? 92 | verbose: True 93 | 94 | # seed for random number generators in pytorch, numpy and python.random 95 | seed: 11 96 | 97 | # name of the run, should be used along with experiment mode 98 | name: null 99 | -------------------------------------------------------------------------------- /configs/mode/debug.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # debug mode sets name of the logging folder to 'logs/debug/...' 
4 | # enables trainer debug options 5 | # also sets the level of the command line logger to DEBUG 6 | # example usage: 7 | # `python run.py mode=debug` 8 | 9 | defaults: 10 | - override /trainer: debug.yaml 11 | - override /model: mlp.yaml 12 | - override /logger: none.yaml 13 | - override /callbacks: none.yaml 14 | 15 | debug_mode: True 16 | 17 | hydra: 18 | # sets level of all command line loggers to 'DEBUG' 19 | verbose: True 20 | 21 | # https://hydra.cc/docs/tutorials/basic/running_your_app/logging/ 22 | # use this to set level of only chosen command line loggers to 'DEBUG' 23 | # verbose: [src.train, src.utils] 24 | 25 | run: 26 | dir: ${log_dir}/debug/runs/${now:%Y-%m-%d}/${now:%H-%M-%S} 27 | sweep: 28 | dir: ${log_dir}/debug/multiruns/${now:%Y-%m-%d}/${now:%H-%M-%S} 29 | subdir: ${hydra.job.num} 30 | 31 | # disable rich config printing, since it will already be printed by hydra when `verbose: True` 32 | print_config: False 33 | -------------------------------------------------------------------------------- /configs/mode/default.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # default running mode 4 | 5 | default_mode: True 6 | 7 | hydra: 8 | # default output paths for all file logs 9 | run: 10 | dir: ${log_dir}/runs/${now:%Y-%m-%d}/${now:%H-%M-%S} 11 | sweep: 12 | dir: ${log_dir}/multiruns/${now:%Y-%m-%d}/${now:%H-%M-%S} 13 | subdir: ${hydra.job.num} 14 | -------------------------------------------------------------------------------- /configs/mode/exp.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # experiment mode sets name of the logging folder to the experiment name 4 | # can also be used to name the run in the logger 5 | # example usage: 6 | # `python run.py mode=exp name=some_name` 7 | 8 | experiment_mode: True 9 | 10 | name: ???
11 | 12 | hydra: 13 | run: 14 | dir: ${log_dir}/experiments/${name}/runs/${now:%Y-%m-%d}/${now:%H-%M-%S} 15 | sweep: 16 | dir: ${log_dir}/experiments/${name}/multiruns/${now:%Y-%m-%d}/${now:%H-%M-%S} 17 | subdir: ${hydra.job.num} 18 | -------------------------------------------------------------------------------- /configs/model/cnn.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - /input_transform: repeat_global_vars.yaml 3 | - /optimizer: adamw.yaml 4 | 5 | _target_: climart.models.CNNs.CNN.CNN_Net 6 | 7 | hidden_dims: [256, 256, 256] 8 | kernels: [20, 10, 5] 9 | strides: [2, 1, 1] 10 | net_normalization: null 11 | activation_function: "Gelu" 12 | dropout: 0.0 13 | gap: True 14 | 15 | 16 | downwelling_loss_contribution: 0.5 17 | upwelling_loss_contribution: 0.5 18 | heating_rate_loss_contribution: 0 19 | 20 | monitor: ${val_metric} 21 | scheduler: 22 | _target_: torch.optim.lr_scheduler.ExponentialLR 23 | gamma: 0.98 -------------------------------------------------------------------------------- /configs/model/graphnet.yaml: -------------------------------------------------------------------------------- 1 | _target_: climart.models.GraphNet.graph_network.GN_withReadout 2 | 3 | defaults: 4 | - /input_transform: graphnet_level_nodes.yaml 5 | - /optimizer: adamw.yaml 6 | 7 | hidden_dims: [128, 128, 128] 8 | net_normalization: null 9 | activation_function: "Gelu" 10 | dropout: 0.0 11 | residual: True 12 | 13 | update_mlp_n_layers: 1 14 | aggregator_funcs: "mean" 15 | graph_pooling: "mean" 16 | readout_which_output: "nodes" 17 | 18 | downwelling_loss_contribution: 0.5 19 | upwelling_loss_contribution: 0.5 20 | heating_rate_loss_contribution: 0 21 | 22 | monitor: ${val_metric} 23 | scheduler: 24 | _target_: torch.optim.lr_scheduler.ExponentialLR 25 | gamma: 0.98 26 | -------------------------------------------------------------------------------- /configs/model/mlp.yaml: -------------------------------------------------------------------------------- 1 | _target_: climart.models.MLP.ClimartMLP 2 | 3 | defaults: 4 | - /input_transform: flatten.yaml 5 | - /optimizer: adamw.yaml 6 | 7 | hidden_dims: [512, 256, 256] 8 | net_normalization: "layer_norm" 9 | activation_function: "Gelu" 10 | dropout: 0.0 11 | residual: False 12 | 13 | downwelling_loss_contribution: 0.5 14 | upwelling_loss_contribution: 0.5 15 | heating_rate_loss_contribution: 0 16 | 17 | monitor: ${val_metric} 18 | scheduler: 19 | _target_: torch.optim.lr_scheduler.ExponentialLR 20 | gamma: 0.98 21 | -------------------------------------------------------------------------------- /configs/optimizer/adam.yaml: -------------------------------------------------------------------------------- 1 | name: "adam" 2 | lr: 2e-4 3 | weight_decay: 1e-6 4 | eps: 1e-8 -------------------------------------------------------------------------------- /configs/optimizer/adamw.yaml: -------------------------------------------------------------------------------- 1 | name: "adamw" 2 | lr: 2e-4 3 | weight_decay: 1e-6 4 | eps: 1e-8 5 | -------------------------------------------------------------------------------- /configs/optimizer/sgd.yaml: -------------------------------------------------------------------------------- 1 | name: "sgd" 2 | lr: 5e-4 3 | weight_decay: 0.05 4 | momentum: 0.9 -------------------------------------------------------------------------------- /configs/trainer/ddp.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - 
default.yaml 3 | 4 | gpus: 4 5 | strategy: ddp 6 | -------------------------------------------------------------------------------- /configs/trainer/debug.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default.yaml 3 | 4 | gpus: 1 5 | 6 | min_epochs: 1 7 | max_epochs: 2 8 | 9 | # prints 10 | profiler: null 11 | 12 | # debugs 13 | fast_dev_run: true 14 | overfit_batches: 0 15 | limit_train_batches: 1.0 16 | limit_val_batches: 1.0 17 | limit_test_batches: 1.0 18 | track_grad_norm: -1 19 | detect_anomaly: true 20 | -------------------------------------------------------------------------------- /configs/trainer/default.yaml: -------------------------------------------------------------------------------- 1 | _target_: pytorch_lightning.Trainer 2 | 3 | gpus: 1 4 | 5 | min_epochs: 1 6 | max_epochs: 100 7 | 8 | gradient_clip_val: 1.0 9 | 10 | resume_from_checkpoint: null 11 | 12 | # number of validation steps to execute at the beginning of the training 13 | num_sanity_val_steps: 0 14 | -------------------------------------------------------------------------------- /configs/transform/default.yaml: -------------------------------------------------------------------------------- 1 | _target_: climart.data_transform.transforms.IdentityTransform 2 | exp_type: ${datamodule.exp_type} -------------------------------------------------------------------------------- /download_climart.sh: -------------------------------------------------------------------------------- 1 | # Change the directory where the data will be downloaded below 2 | data_dir="ClimART_DATA" 3 | mkdir -p ${data_dir}/inputs 4 | mkdir -p ${data_dir}/outputs_clear_sky 5 | mkdir -p ${data_dir}/outputs_pristine 6 | 7 | # Uncomment all lines to download all data (which will take time though :)). 8 | 9 | echo "Downloading metadata & statistics..." 10 | curl https://object-arbutus.cloud.computecanada.ca/rt-public/META_INFO.json --output ${data_dir}/META_INFO.json 11 | curl https://object-arbutus.cloud.computecanada.ca/rt-public/statistics.npz --output ${data_dir}/statistics.npz 12 | curl https://object-arbutus.cloud.computecanada.ca/rt-public/areacella_fx_CanESM5.nc --output ${data_dir}/areacella_fx_CanESM5.npz 13 | echo "Done." 14 | 15 | echo "Downloading input files..." 16 | for x in {1979..1991};do curl https://object-arbutus.cloud.computecanada.ca/rt-public/inputs/$x.h5 --output ${data_dir}/inputs/$x.h5; done 17 | for x in {1994..2014};do curl https://object-arbutus.cloud.computecanada.ca/rt-public/inputs/$x.h5 --output ${data_dir}/inputs/$x.h5; done 18 | for x in 1850 1851 1852 ;do curl https://object-arbutus.cloud.computecanada.ca/rt-public/inputs/$x.h5 --output ${data_dir}/inputs/$x.h5; done 19 | for x in 2097 2098 2099 ;do curl https://object-arbutus.cloud.computecanada.ca/rt-public/inputs/$x.h5 --output ${data_dir}/inputs/$x.h5; done 20 | echo "Done." 21 | 22 | echo "Downloading clear-sky targets..."
23 | for x in {1979..1991};do curl https://object-arbutus.cloud.computecanada.ca/rt-public/outputs_clear_sky/$x.h5 --output ${data_dir}/outputs_clear_sky/$x.h5; done 24 | for x in {1994..2014};do curl https://object-arbutus.cloud.computecanada.ca/rt-public/outputs_clear_sky/$x.h5 --output ${data_dir}/outputs_clear_sky/$x.h5; done 25 | for x in 1850 1851 1852 ;do curl https://object-arbutus.cloud.computecanada.ca/rt-public/outputs_clear_sky/$x.h5 --output ${data_dir}/outputs_clear_sky/$x.h5; done 26 | for x in 2097 2098 2099 ;do curl https://object-arbutus.cloud.computecanada.ca/rt-public/outputs_clear_sky/$x.h5 --output ${data_dir}/outputs_clear_sky/$x.h5; done 27 | echo "Done." 28 | 29 | echo "Downloading pristine-sky targets..." 30 | for x in {1979..1991};do curl https://object-arbutus.cloud.computecanada.ca/rt-public/outputs_pristine/$x.h5 --output ${data_dir}/outputs_pristine/$x.h5; done 31 | for x in {1994..2014};do curl https://object-arbutus.cloud.computecanada.ca/rt-public/outputs_pristine/$x.h5 --output ${data_dir}/outputs_pristine/$x.h5; done 32 | for x in 1850 1851 1852 ;do curl https://object-arbutus.cloud.computecanada.ca/rt-public/outputs_pristine/$x.h5 --output ${data_dir}/outputs_pristine/$x.h5; done 33 | for x in 2097 2098 2099 ;do curl https://object-arbutus.cloud.computecanada.ca/rt-public/outputs_pristine/$x.h5 --output ${data_dir}/outputs_pristine/$x.h5; done 34 | 35 | echo "Done. Finished downloading ClimART :)" 36 | -------------------------------------------------------------------------------- /download_data_subset.sh: -------------------------------------------------------------------------------- 1 | # Change the directory where the data will be downloaded below 2 | data_dir="ClimART_DATA" 3 | mkdir -p ${data_dir}/inputs 4 | mkdir -p ${data_dir}/outputs_clear_sky 5 | mkdir -p ${data_dir}/outputs_pristine 6 | 7 | # Uncomment all lines to download all data (which will take time though :)). 8 | 9 | echo "Downloading metadata & statistics..." 10 | curl https://object-arbutus.cloud.computecanada.ca/rt-public/META_INFO.json --output ${data_dir}/META_INFO.json 11 | curl https://object-arbutus.cloud.computecanada.ca/rt-public/statistics.npz --output ${data_dir}/statistics.npz 12 | curl https://object-arbutus.cloud.computecanada.ca/rt-public/areacella_fx_CanESM5.nc --output ${data_dir}/areacella_fx_CanESM5.npz 13 | echo "Done." 14 | 15 | echo "Downloading input files..." 16 | for x in 2000 2005 ;do curl https://object-arbutus.cloud.computecanada.ca/rt-public/inputs/$x.h5 --output ${data_dir}/inputs/$x.h5; done 17 | for x in {2007..2014} ;do curl https://object-arbutus.cloud.computecanada.ca/rt-public/inputs/$x.h5 --output ${data_dir}/inputs/$x.h5; done 18 | #for x in {1979..1991};do curl https://object-arbutus.cloud.computecanada.ca/rt-public/inputs/$x.h5 --output ${data_dir}/inputs/$x.h5; done 19 | #for x in {1994..2014};do curl https://object-arbutus.cloud.computecanada.ca/rt-public/inputs/$x.h5 --output ${data_dir}/inputs/$x.h5; done 20 | #for x in 1850 1851 1852 ;do curl https://object-arbutus.cloud.computecanada.ca/rt-public/inputs/$x.h5 --output ${data_dir}/inputs/$x.h5; done 21 | #for x in 2097 2098 2099 ;do curl https://object-arbutus.cloud.computecanada.ca/rt-public/inputs/$x.h5 --output ${data_dir}/inputs/$x.h5; done 22 | echo "Done." 23 | 24 | echo "Downloading clear-sky targets..." 
25 | #for x in {1979..1991};do curl https://object-arbutus.cloud.computecanada.ca/rt-public/outputs_clear_sky/$x.h5 --output ${data_dir}/outputs_clear_sky/$x.h5; done 26 | #for x in {1994..2014};do curl https://object-arbutus.cloud.computecanada.ca/rt-public/outputs_clear_sky/$x.h5 --output ${data_dir}/outputs_clear_sky/$x.h5; done 27 | #for x in 1850 1851 1852 ;do curl https://object-arbutus.cloud.computecanada.ca/rt-public/outputs_clear_sky/$x.h5 --output ${data_dir}/outputs_clear_sky/$x.h5; done 28 | #for x in 2097 2098 2099 ;do curl https://object-arbutus.cloud.computecanada.ca/rt-public/outputs_clear_sky/$x.h5 --output ${data_dir}/outputs_clear_sky/$x.h5; done 29 | echo "Done." 30 | 31 | echo "Downloading pristine-sky targets..." 32 | for x in 2000 2005 ;do curl https://object-arbutus.cloud.computecanada.ca/rt-public/outputs_pristine/$x.h5 --output ${data_dir}/outputs_pristine/$x.h5; done 33 | for x in {2007..2014} ;do curl https://object-arbutus.cloud.computecanada.ca/rt-public/outputs_pristine/$x.h5 --output ${data_dir}/outputs_pristine/$x.h5; done 34 | #for x in {1979..1991};do curl https://object-arbutus.cloud.computecanada.ca/rt-public/outputs_pristine/$x.h5 --output ${data_dir}/outputs_pristine/$x.h5; done 35 | #for x in {1994..2014};do curl https://object-arbutus.cloud.computecanada.ca/rt-public/outputs_pristine/$x.h5 --output ${data_dir}/outputs_pristine/$x.h5; done 36 | #for x in 1850 1851 1852 ;do curl https://object-arbutus.cloud.computecanada.ca/rt-public/outputs_pristine/$x.h5 --output ${data_dir}/outputs_pristine/$x.h5; done 37 | #for x in 2097 2098 2099 ;do curl https://object-arbutus.cloud.computecanada.ca/rt-public/outputs_pristine/$x.h5 --output ${data_dir}/outputs_pristine/$x.h5; done 38 | 39 | echo "Done. Finished downloading the data." 
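Both download scripts write one HDF5 file per year into ClimART_DATA/inputs and the matching outputs_* folders. A cheap way to check that a year downloaded completely is to open it with h5py, as in the sketch below; it is not part of the repository, assumes only that h5py is installed and that download_data_subset.sh was run with the default data_dir, and lists whatever datasets the file happens to contain rather than assuming their names.

# Sketch (not a repository file): sanity-check one downloaded ClimART year.
import h5py

with h5py.File("ClimART_DATA/inputs/2000.h5", "r") as f:
    # print every dataset's path, shape and dtype, whatever the internal layout is
    f.visititems(lambda name, obj: print(name, getattr(obj, "shape", ""), getattr(obj, "dtype", "")))

A truncated file (for example from an interrupted curl) fails to open at all, so this doubles as an integrity check before starting training.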
-------------------------------------------------------------------------------- /env.yml: -------------------------------------------------------------------------------- 1 | name: climart 2 | channels: 3 | - rusty1s 4 | - pytorch 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - cartopy 9 | - cftime 10 | - cudatoolkit 11 | - curl 12 | - cycler 13 | - dask 14 | - dask-core 15 | - einops>=0.3.0 16 | - geos>=3.9.1 17 | - glib>=2.66.3 18 | - h5py>=3.3.0 19 | - hdf5>=1.10.6 20 | - kiwisolver 21 | - matplotlib>=3.4.2 22 | - mkl 23 | - netcdf4>=1.5.7 24 | - ninja>=1.10.2 25 | - numpy>=1.21.1 26 | - pandas>=1.3.1 27 | - pip>=21.2.3 28 | - pytest 29 | - python>=3.8.0 30 | - python-dateutil>=2.8.2 31 | - pytorch=1.9.0 32 | - pytorch-lightning>=1.5.8 33 | - pytorch-scatter>=2.0.8 34 | - pyyaml>=5.4.1 35 | - qt>=5.12.9 36 | - scipy>=1.7.1 37 | - seaborn 38 | - setuptools=59.5.0 39 | - six>=1.16.0 40 | - tornado 41 | - xarray>=0.19.0 42 | - yaml 43 | - zlib 44 | - pip: 45 | - hydra-core 46 | - hydra_colorlog 47 | - networkx 48 | - pre-commit 49 | - python-dotenv 50 | - rich 51 | - timm 52 | - torchmetrics 53 | - tqdm>=4.62.0 54 | - wandb>=0.12.9 55 | 56 | -------------------------------------------------------------------------------- /images/variable_table.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RolnickLab/climart/6d12fa478de5db5209842e0369a2a24377079edd/images/variable_table.png -------------------------------------------------------------------------------- /notebooks/2022-06-06-get-predictions-pl.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 9, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stdout", 10 | "output_type": "stream", 11 | "text": [ 12 | "The autoreload extension is already loaded. 
To reload it, use:\n", 13 | " %reload_ext autoreload\n" 14 | ] 15 | } 16 | ], 17 | "source": [ 18 | "%load_ext autoreload\n", 19 | "%autoreload 2" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 3, 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "import os\n", 29 | "# Make sure we're in the right directory\n", 30 | "if os.path.basename(os.getcwd()) == \"notebooks\":\n", 31 | " os.chdir(\"..\")" 32 | ] 33 | }, 34 | { 35 | "cell_type": "markdown", 36 | "metadata": {}, 37 | "source": [ 38 | "## Inference with ClimART data and models" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": 8, 44 | "metadata": {}, 45 | "outputs": [], 46 | "source": [ 47 | "import sys\n", 48 | "import numpy as np\n", 49 | "import pytorch_lightning as pl\n", 50 | "from climart.utils.config_utils import get_config_from_hydra_compose_overrides\n", 51 | "from climart.interface import get_model_and_data" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": 18, 57 | "metadata": {}, 58 | "outputs": [], 59 | "source": [ 60 | "num_workers=2" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": 19, 66 | "metadata": {}, 67 | "outputs": [], 68 | "source": [ 69 | "os.environ['WANDB_SILENT']=\"true\"\n", 70 | "np.set_printoptions(suppress=True, threshold=sys.maxsize)" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": 22, 76 | "metadata": {}, 77 | "outputs": [ 78 | { 79 | "name": "stderr", 80 | "output_type": "stream", 81 | "text": [ 82 | "Multiprocessing is handled by SLURM.\n", 83 | "GPU available: True, used: True\n", 84 | "TPU available: False, using: 0 TPU cores\n", 85 | "IPU available: False, using: 0 IPUs\n", 86 | "HPU available: False, using: 0 HPUs\n" 87 | ] 88 | } 89 | ], 90 | "source": [ 91 | "exp_type='pristine'\n", 92 | "predict_years = \"2012-14\"\n", 93 | "overrides = [f'model=mlp',\n", 94 | " f'datamodule.exp_type={exp_type}',\n", 95 | " f'datamodule.num_workers={num_workers}', \n", 96 | " f'++datamodule.predict_years={predict_years}']\n", 97 | "cfg = get_config_from_hydra_compose_overrides(overrides)\n", 98 | "model, dm = get_model_and_data(cfg)\n", 99 | "trainer = pl.Trainer(gpus=-1)" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": 55, 105 | "metadata": {}, 106 | "outputs": [ 107 | { 108 | "name": "stderr", 109 | "output_type": "stream", 110 | "text": [ 111 | "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n" 112 | ] 113 | }, 114 | { 115 | "data": { 116 | "application/vnd.jupyter.widget-view+json": { 117 | "model_id": "90e0261b55f04cccaf82bcb90e3494f2", 118 | "version_major": 2, 119 | "version_minor": 0 120 | }, 121 | "text/plain": [ 122 | "Predicting: 0it [00:00, ?it/s]" 123 | ] 124 | }, 125 | "metadata": {}, 126 | "output_type": "display_data" 127 | } 128 | ], 129 | "source": [ 130 | "results1 = trainer.predict(model=model,datamodule=dm, return_predictions=True)" 131 | ] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "execution_count": 62, 136 | "metadata": {}, 137 | "outputs": [ 138 | { 139 | "data": { 140 | "text/plain": [ 141 | "(122880, 49)" 142 | ] 143 | }, 144 | "execution_count": 62, 145 | "metadata": {}, 146 | "output_type": "execute_result" 147 | } 148 | ], 149 | "source": [ 150 | "results1 = model.aggregate_predictions(results1)\n", 151 | "sw_hr_preds_2014 = results1[2014]['preds']['hrsc']\n", 152 | "sw_hr_preds_2014.shape" 153 | ] 154 | }, 155 | { 156 | "cell_type": "markdown", 157 | "metadata": {}, 158 | "source": [ 159 | "#### Alternatively to command 
line/config-directed changing of the predict_years, you can also do this programmatically to an existing datamodule as follows:" 160 | ] 161 | }, 162 | { 163 | "cell_type": "code", 164 | "execution_count": 79, 165 | "metadata": {}, 166 | "outputs": [ 167 | { 168 | "name": "stderr", 169 | "output_type": "stream", 170 | "text": [ 171 | "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n" 172 | ] 173 | }, 174 | { 175 | "data": { 176 | "application/vnd.jupyter.widget-view+json": { 177 | "model_id": "0c78d34a8ed94ced997186cdffb4d0e5", 178 | "version_major": 2, 179 | "version_minor": 0 180 | }, 181 | "text/plain": [ 182 | "Predicting: 0it [00:00, ?it/s]" 183 | ] 184 | }, 185 | "metadata": {}, 186 | "output_type": "display_data" 187 | } 188 | ], 189 | "source": [ 190 | "dm.predict_years = \"2010\"\n", 191 | "results2 = trainer.predict(model, dm)" 192 | ] 193 | }, 194 | { 195 | "cell_type": "code", 196 | "execution_count": 80, 197 | "metadata": {}, 198 | "outputs": [ 199 | { 200 | "data": { 201 | "text/plain": [ 202 | "torch.Size([512, 50])" 203 | ] 204 | }, 205 | "execution_count": 80, 206 | "metadata": {}, 207 | "output_type": "execute_result" 208 | } 209 | ], 210 | "source": [ 211 | "results2[0]['preds']['rsuc'].shape" 212 | ] 213 | }, 214 | { 215 | "cell_type": "code", 216 | "execution_count": 81, 217 | "metadata": {}, 218 | "outputs": [ 219 | { 220 | "data": { 221 | "text/plain": [ 222 | "(122880, 49)" 223 | ] 224 | }, 225 | "execution_count": 81, 226 | "metadata": {}, 227 | "output_type": "execute_result" 228 | } 229 | ], 230 | "source": [ 231 | "results2 = model.aggregate_predictions(results2)\n", 232 | "sw_hr_preds_2010 = results2[2010]['preds']['hrsc']\n", 233 | "sw_hr_preds_2010.shape" 234 | ] 235 | } 236 | ], 237 | "metadata": { 238 | "kernelspec": { 239 | "display_name": "Climart-GraphNet", 240 | "language": "python", 241 | "name": "climart_gn" 242 | }, 243 | "language_info": { 244 | "codemirror_mode": { 245 | "name": "ipython", 246 | "version": 3 247 | }, 248 | "file_extension": ".py", 249 | "mimetype": "text/x-python", 250 | "name": "python", 251 | "nbconvert_exporter": "python", 252 | "pygments_lexer": "ipython3", 253 | "version": "3.9.12" 254 | } 255 | }, 256 | "nbformat": 4, 257 | "nbformat_minor": 4 258 | } -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import dotenv 2 | import hydra 3 | from omegaconf import DictConfig, OmegaConf 4 | 5 | # load environment variables from `.env` file if it exists 6 | # recursively searches for `.env` in all folders starting from work dir 7 | from climart.utils.utils import target_var_id_mapping 8 | 9 | dotenv.load_dotenv(override=True) 10 | OmegaConf.register_new_resolver("target_var_id", target_var_id_mapping) 11 | 12 | 13 | @hydra.main(config_path="configs/", config_name="main_config.yaml", version_base=None) 14 | def main(config: DictConfig): 15 | from climart.train import run_model 16 | return run_model(config) 17 | 18 | 19 | if __name__ == "__main__": 20 | main() 21 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [isort] 2 | line_length = 99 3 | profile = black 4 | filter_files = True 5 | 6 | 7 | [flake8] 8 | max_line_length = 99 9 | show_source = True 10 | format = pylint 11 | ignore = 12 | F401 # Module imported but unused 13 | W504 # Line break occurred after a binary operator 14 | F841 # 
Local variable name is assigned to but never used 15 | F403 # from module import * 16 | E501 # Line too long 17 | exclude = 18 | .git 19 | __pycache__ 20 | data/* 21 | tests/* 22 | notebooks/* 23 | logs/* 24 | 25 | 26 | [tool:pytest] 27 | python_files = tests/* 28 | log_cli = True 29 | markers = 30 | slow 31 | addopts = 32 | --durations=0 33 | --strict-markers 34 | --doctest-modules 35 | filterwarnings = 36 | ignore::DeprecationWarning 37 | ignore::UserWarning 38 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | with open("README.md", "r") as fh: 4 | long_description = fh.read() 5 | keywords = ["radiative transfer", "emulation", "deep learning", "climart", "dataset", "pytorch", "mila", "eccc"] 6 | 7 | setup( 8 | name='climart', 9 | version='0.1.0', 10 | packages=find_packages(), 11 | url='https://github.com/RolnickLab/climart', 12 | license='CC BY 4.0', 13 | author='Salva Rühling Cachay, Venkatesh Ramesh', 14 | author_email='', 15 | keywords=keywords, 16 | description='A comprehensive, large-scale dataset for' 17 | ' benchmarking neural network emulators of the' 18 | ' radiation component in climate and weather models.', 19 | long_description=long_description, 20 | long_description_content_type="text/markdown", 21 | python_requires=">=3.7.0", 22 | install_requires=[ 23 | "torch>=1.7.1", 24 | "scikit-learn", 25 | ], 26 | classifiers=[ 27 | "Intended Audience :: Developers", 28 | "Intended Audience :: Science/Research", 29 | "Operating System :: OS Independent", 30 | "Programming Language :: Python :: 3", 31 | "Programming Language :: Python :: 3.7", 32 | "Programming Language :: Python :: 3.8", 33 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 34 | ], 35 | ) 36 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | def test_year_string_to_list_parsing(): 4 | from climart.utils.utils import year_string_to_list 5 | inputs = [ 6 | ('1990', [1990]), 7 | ('1990-1991', [1990, 1991]), 8 | ('1990-91', [1990, 1991]), 9 | ('90-91', [1990, 1991]), 10 | ('1989-1991', [1989, 1990, 1991]), 11 | ('1990-1991+2003-2005', [1990, 1991, 2003, 2004, 2005]), 12 | ('1990+1999+2005-06', [1990, 1999, 2005, 2006]), 13 | ('1990+1999', [1990, 1999]), 14 | ] 15 | for i, (string, expected) in enumerate(inputs): 16 | actual = year_string_to_list(string) 17 | err_msg = f"Input {i+1}: Expected {expected}, but {actual} was returned." 18 | assert list(actual) == expected, err_msg 19 | -------------------------------------------------------------------------------- /tests/test_variables.py: -------------------------------------------------------------------------------- 1 | from climart.data_loading.data_variables import INPUT_VARS_CLOUDS, INPUT_VARS_AEROSOLS, _ALL_INPUT_VARS 2 | 3 | 4 | def test_exp_type_subset_vars(): 5 | for k in INPUT_VARS_CLOUDS: 6 | assert k in _ALL_INPUT_VARS, f"Cloud var {k} was expected to be in _ALL_INPUT_VARS." 7 | for k in INPUT_VARS_AEROSOLS: 8 | assert k in _ALL_INPUT_VARS, f"Aerosol var {k} was expected to be in _ALL_INPUT_VARS." 9 | --------------------------------------------------------------------------------
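The cases in tests/test_utils.py above pin down how year strings such as train_years: "1990+1999+2003" in main_config.yaml are expected to be parsed. The real implementation lives in climart/utils/utils.py and is not included in this listing; the standalone sketch below is a hypothetical reconstruction that satisfies exactly those test cases (two-digit years inherit their century from the start of the range, or default to 19xx).

# Hypothetical sketch of the year-string parser exercised by tests/test_utils.py.
# The actual climart.utils.utils.year_string_to_list may differ; this version only
# reproduces the behaviour the test cases above require.
from typing import List


def year_string_to_list(years: str) -> List[int]:
    def expand(token: str, century_ref: str = "19") -> int:
        # '91' with century_ref '1990' becomes 1991; a bare '90' defaults to 1990
        return int(token) if len(token) == 4 else int(century_ref[: 4 - len(token)] + token)

    out: List[int] = []
    for part in years.split("+"):      # '+' separates independent years or ranges
        bounds = part.split("-")       # '-' marks an inclusive range
        start = expand(bounds[0])
        end = expand(bounds[-1], century_ref=str(start)) if len(bounds) > 1 else start
        out.extend(range(start, end + 1))
    return out


assert year_string_to_list("1990+1999+2005-06") == [1990, 1999, 2005, 2006]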