├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE
├── README.md
├── analysis
└── wandb_api.py
├── climart
├── META_INFO.json
├── __init__.py
├── data_loading
│ ├── __init__.py
│ ├── constants.py
│ ├── data_variables.py
│ └── h5_dataset.py
├── data_transform
│ ├── __init__.py
│ ├── normalization.py
│ └── transforms.py
├── datamodules
│ ├── __init__.py
│ └── pl_climart_datamodule.py
├── interface.py
├── models
│ ├── CNNs
│ │ ├── CNN.py
│ │ └── __init__.py
│ ├── GraphNet
│ │ ├── __init__.py
│ │ ├── constants.py
│ │ ├── graph_network.py
│ │ └── graph_network_block.py
│ ├── MLP.py
│ ├── __init__.py
│ ├── base_model.py
│ ├── baseline.py
│ └── modules
│ │ ├── __init__.py
│ │ ├── additional_layers.py
│ │ └── mlp.py
├── train.py
└── utils
│ ├── __init__.py
│ ├── callbacks.py
│ ├── config_utils.py
│ ├── evaluation.py
│ ├── naming.py
│ ├── optimization.py
│ ├── plotting.py
│ ├── postprocessing.py
│ ├── utils.py
│ └── wandb_callbacks.py
├── configs
├── callbacks
│ ├── default.yaml
│ ├── none.yaml
│ └── wandb.yaml
├── experiment
│ ├── example.yaml
│ └── reproduce_paper2021_cnn.yaml
├── hparams_search
│ └── mnist_optuna.yaml
├── input_transform
│ ├── flatten.yaml
│ ├── graphnet_level_nodes.yaml
│ ├── none.yaml
│ └── repeat_global_vars.yaml
├── local
│ └── .gitkeep
├── logger
│ ├── comet.yaml
│ ├── csv.yaml
│ ├── many_loggers.yaml
│ ├── mlflow.yaml
│ ├── neptune.yaml
│ ├── none.yaml
│ ├── tensorboard.yaml
│ └── wandb.yaml
├── main_config.yaml
├── mode
│ ├── debug.yaml
│ ├── default.yaml
│ └── exp.yaml
├── model
│ ├── cnn.yaml
│ ├── graphnet.yaml
│ └── mlp.yaml
├── optimizer
│ ├── adam.yaml
│ ├── adamw.yaml
│ └── sgd.yaml
├── trainer
│ ├── ddp.yaml
│ ├── debug.yaml
│ └── default.yaml
└── transform
│ └── default.yaml
├── download_climart.sh
├── download_data_subset.sh
├── env.yml
├── images
└── variable_table.png
├── notebooks
└── 2022-06-06-get-predictions-pl.ipynb
├── run.py
├── setup.cfg
├── setup.py
└── tests
├── test_utils.py
└── test_variables.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .venv
106 | env/
107 | venv/
108 | ENV/
109 | env.bak/
110 | venv.bak/
111 |
112 | # Spyder project settings
113 | .spyderproject
114 | .spyproject
115 |
116 | # Rope project settings
117 | .ropeproject
118 |
119 | # mkdocs documentation
120 | /site
121 |
122 | # mypy
123 | .mypy_cache/
124 | .dmypy.json
125 | dmypy.json
126 |
127 | # Pyre type checker
128 | .pyre/
129 |
130 | ### VisualStudioCode
131 | .vscode/*
132 | !.vscode/settings.json
133 | !.vscode/tasks.json
134 | !.vscode/launch.json
135 | !.vscode/extensions.json
136 | *.code-workspace
137 | **/.vscode
138 |
139 | # JetBrains
140 | .idea/
141 |
142 | # Lightning-Hydra-Template
143 | configs/local/default.yaml
144 | data/
145 | logs/
146 | wandb/
147 | .env
148 | .autoenv
149 |
150 | out/*
151 | outputs/*
152 | *.out
153 | *.txt
154 | *__pycache__*
155 | .idea
156 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/pre-commit/pre-commit-hooks
3 | rev: v3.4.0
4 | hooks:
5 | # list of supported hooks: https://pre-commit.com/hooks.html
6 | - id: trailing-whitespace
7 | - id: end-of-file-fixer
8 | - id: check-yaml
9 | - id: check-added-large-files
10 | - id: debug-statements
11 | - id: detect-private-key
12 |
13 | # python code formatting
14 | - repo: https://github.com/psf/black
15 | rev: 20.8b1
16 | hooks:
17 | - id: black
18 | args: [--line-length, "99"]
19 |
20 | # python import sorting
21 | - repo: https://github.com/PyCQA/isort
22 | rev: 5.8.0
23 | hooks:
24 | - id: isort
25 | args: ["--profile", "black", "--filter-files"]
26 |
27 | # yaml formatting
28 | - repo: https://github.com/pre-commit/mirrors-prettier
29 | rev: v2.3.0
30 | hooks:
31 | - id: prettier
32 | types: [yaml]
33 |
34 | # python code analysis
35 | - repo: https://github.com/PyCQA/flake8
36 | rev: 3.9.2
37 | hooks:
38 | - id: flake8
39 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ***ClimART*** - A Benchmark Dataset for Emulating Atmospheric Radiative Transfer in Weather and Climate Models
2 |
3 |
4 |
5 |
6 | ![CC BY 4.0][cc-by-image]
7 |
8 | [cc-by-image]: https://i.creativecommons.org/l/by/4.0/88x31.png
9 | [cc-by-shield]: https://img.shields.io/badge/License-CC%20BY%204.0-lightgrey.svg
10 |
11 | ## Official PyTorch Implementation
12 |
13 | ### Using deep learning to optimise radiative transfer calculations.
14 |
15 | Our NeurIPS 2021 Datasets Track paper: https://arxiv.org/abs/2111.14671
16 |
17 | Abstract: *Numerical simulations of Earth's weather and climate require substantial amounts of computation. This has led to a growing interest in replacing subroutines that explicitly compute physical processes with approximate machine learning (ML) methods that are fast at inference time. Within weather and climate models, atmospheric radiative transfer (RT) calculations are especially expensive. This has made them a popular target for neural network-based emulators. However, prior work is hard to compare due to the lack of a comprehensive dataset and standardized best practices for ML benchmarking. To fill this gap, we build a large dataset, ClimART, with more than **10 million** samples from present, pre-industrial, and future climate conditions, based on the Canadian Earth System Model.
18 | ClimART poses several methodological challenges for the ML community, such as multiple out-of-distribution test sets, underlying domain physics, and a trade-off between accuracy and inference speed. We also present several novel baselines that indicate shortcomings of datasets and network architectures used in prior work.*
19 |
20 | **Contact:** Venkatesh Ramesh [(venka97 at gmail)](mailto:venka97@gmail.com) or Salva Rühling Cachay [(salvaruehling at gmail)](mailto:salvaruehling@gmail.com).
21 |
22 | ## Overview:
23 |
24 | * ``climart/``: Package with the main code, baselines and ML training logic.
25 | * ``analysis/``: Scripts to create visualization of the results (requires logging).
26 | * ``configs/``: Yaml configuration files for Hydra that define in a modular way (hyper-)parameters.
27 |
28 | ## Getting Started
29 |
30 | Requirements
31 |
32 |
33 | - Linux and Windows are supported, but we recommend Linux for performance and compatibility reasons.
34 | - NVIDIA GPUs with at least 8 GB of memory and system with 12 GB RAM (More RAM is required if training with --load_train_into_mem option which allows for faster training). We have done all testing and development using NVIDIA V100 GPUs.
35 | - 64-bit Python >=3.7 and PyTorch >=1.8.1. See https://pytorch.org/ for PyTorch install instructions.
36 | - Python libraries mentioned in ``env.yml`` file, see Getting Started (Need to have miniconda/conda installed).
37 |
38 |
39 |
40 |
41 | Downloading the ClimART Dataset
42 |
43 | By default, only a subset of ClimART is downloaded.
44 | To download the train/val/test years you want, please change the loop in ``download_data_subset.sh`` appropriately.
45 | To download the whole ClimART dataset, you can simply run
46 |
47 | sudo bash download_climart.sh
48 |
49 |
50 |
51 | **Note:** If you have issues with downloading the data please let us know to help you.
52 |
53 | conda env create -f env.yml # create a new environment with all dependencies
54 | conda activate climart # activate the environment called 'climart'
55 | sudo bash download_data_subset.sh # download the dataset (or a subset of it, see above)
56 | python run.py trainer.gpus=0 datamodule.train_years="2000" # train an MLP emulator on the year 2000
57 |
58 | ## Data Structure
59 |
60 | To avoid storage redundancy, we store one single input array for both pristine- and clear-sky conditions. The dimensions of ClimART’s input arrays are:
61 |
62 | - layers: (N, 49, D-lay)
63 | - levels: (N, 50, 4)
64 | - globals: (N, 82)
65 |
66 |
67 | where N is the data dimension (i.e. the number of examples of a specific year, or, during training, of a batch),
68 | 49 and 50 are the numbers of layers and levels in a column, respectively. Dlay, 4, and 82 are the numbers of features/channels for layers, levels, and globals, respectively.
69 |
70 | For pristine-sky Dlay = 14, while for clear-sky Dlay = 45, since it contains extra aerosol related variables. The array for pristine-sky conditions can be easily accessed by slicing the first 14 features out of the stored array, e.g.:
71 | ``` pristine_array = layers_array[:, :, : 14] ```. This is automatically done for you when you set the atmospheric
72 | condition type via ```datamodule.exp_type=pristine``` or ```datamodule.exp_type=clear_sky```.
73 |
74 |
75 | ## Baselines
76 |
77 | To reproduce our paper results (for seed = 7), you may choose any of our pre-defined configs in the
78 | [configs/model](configs/model) folder and train it as follows
79 |
80 | ```
81 | # You can replace mlp with "graphnet", "gcn", or "cnn" to run a different ML model
82 | # To train on the CPU, choose trainer.gpus=0
83 | # Specify the directory where the ClimART data is saved with datamodule.data_dir=""
84 | # Test on the OOD subsets by setting arg datamodule.{test_ood_historic, test_ood_1991, test_ood_future}=True
85 | python run.py seed=7 model=mlp trainer.gpus=1
86 | ```
87 |
88 | To reproduce the exact CNN model used in the paper, you can use the following command:
89 | ```
90 | python run.py experiment=reproduce_paper2021_cnn seed=7 # feel free to run for more/other seeds
91 | ```
92 | Note: You can also take a look at
93 | [this WandB report](https://wandb.ai/salv47/ClimART-public-runs/reports/ClimART-paper-CNN-runs--VmlldzozMDUyOTUy)
94 | which shows the results of three runs of the CNN model from the paper.
95 |
96 | ### Inference
97 | Check out [this notebook](notebooks/2022-06-06-get-predictions-pl.ipynb) for simple code on how to extract the predictions
98 | for each target variable from a trained model (for arbitrary years of the ClimART dataset).
99 |
100 | ## Tips
101 |
102 |
103 | Reproducibility & Data Generation code
104 |
105 | To best reproduce our baselines and experiments and/or look into how the ClimART dataset was created/designed,
106 | have a look at our `research_code` branch. It operates on pure PyTorch and has a less clean interface/code
107 | than our main branch -- if you have any questions, let us know!
108 |
109 |
110 |
111 | Testing on OOD data subsets
112 |
113 | By default, tests run on the main test dataset only (2007-14). To test on the
114 | historic, future or anomaly test subsets, you need to pass/change the arg
115 | `datamodule.test_ood_historic=True` (and/or `test_ood_future=True`, `test_ood_1991=True`),
116 | besides downloading those data files, e.g. via the `download_climart.sh` script.
117 |
118 |
119 |
120 |
121 | Overriding nested Hydra config groups
122 |
123 | Nested config groups need to be overridden with a different notation - not with a dot, since it would be interpreted as a string otherwise.
124 | For example, if you want to change the optimizer in the model you want to train, you should run:
125 | python run.py model=graphnet optimizer@model.optimizer=sgd
126 |
127 |
128 |
129 |
130 | Local configurations
131 |
132 | You can easily use a local config file (that, e.g., overrides data paths, working dir, etc.) by putting such a YAML config
133 | in the configs/local subdirectory (by default, Hydra searches for and uses the file configs/local/default.yaml, if it exists).
134 |
135 |
136 |
137 | Wandb
138 |
139 | If you use Wandb, make sure to select the "Group first prefix" option in the panel settings of the web app.
140 | This will make it easier to browse through the logged metrics.
141 |
142 |
143 |
144 | Credits & Resources
145 |
146 | The following template was extremely useful for getting started with the PL+Hydra implementation:
147 | [ashleve/lightning-hydra-template](https://github.com/ashleve/lightning-hydra-template)
148 |
149 |
150 |
151 |
152 | ## License:
153 | This work is made available under [Attribution 4.0 International (CC BY 4.0)](https://creativecommons.org/licenses/by/4.0/legalcode) license. ![CC BY 4.0][cc-by-shield]
154 |
155 | ## Development
156 |
157 | This repository is currently under active development and you may encounter bugs with some functionality.
158 | Any feedback, extensions & suggestions are welcome!
159 |
160 |
161 | ## Citation
162 | If you find ClimART or this repository helpful, feel free to cite our publication:
163 |
164 | @inproceedings{cachay2021climart,
165 | title={{ClimART}: A Benchmark Dataset for Emulating Atmospheric Radiative Transfer in Weather and Climate Models},
166 | author={Salva R{\"u}hling Cachay and Venkatesh Ramesh and Jason N. S. Cole and Howard Barker and David Rolnick},
167 | booktitle={Thirty-fifth Conference on Neural Information Processing Systems Datasets and Benchmarks Track},
168 | year={2021},
169 | url={https://arxiv.org/abs/2111.14671}
170 | }
--------------------------------------------------------------------------------
/analysis/wandb_api.py:
--------------------------------------------------------------------------------
1 | from typing import Union, Callable, List
2 |
3 | import wandb
4 | import pandas as pd
5 |
# Type alias for post-filters: callables mapping a runs DataFrame to a new DataFrame.
DF_MAPPING = Callable[[pd.DataFrame], pd.DataFrame]

# Experiment type (atmospheric condition) -> name of the W&B project that stores its runs.
exp_to_wandb_project = dict(
    pristine="ClimART",
    clear_sky="ClimART",
)
12 |
13 |
14 | # Pre-filters
def has_finished(run) -> bool:
    """Pre-filter: keep only runs whose W&B state is ``"finished"``."""
    state = run.state
    return state == "finished"
17 |
18 |
def has_final_metric(run) -> bool:
    """Pre-filter: keep only runs whose summary contains the ``test/hrsc/rmse`` metric."""
    logged_keys = run.summary.keys()
    return 'test/hrsc/rmse' in logged_keys
21 |
22 |
def has_max_metric_value(metric: str = 'test/hrsc/rmse', max_metric_value: float = 1.0) -> Callable:
    """Build a pre-filter keeping runs whose summary ``metric`` is at most ``max_metric_value``.

    Runs that never logged ``metric`` (or logged ``None`` for it) are filtered out instead of
    raising ``KeyError``/``TypeError``, so the filter is safe on unfinished or crashed runs.

    Args:
        metric: Summary key to compare.
        max_metric_value: Inclusive upper bound for the metric.

    Returns:
        A callable ``run -> bool``.
    """
    def _below_max(run) -> bool:
        if metric not in run.summary.keys():
            return False
        value = run.summary[metric]
        return value is not None and value <= max_metric_value

    return _below_max
25 |
26 |
def has_tags(tags: Union[str, List[str]]) -> Callable:
    """Build a pre-filter keeping runs that carry *every* tag in ``tags``.

    A single tag may be given as a plain string.
    """
    required = [tags] if isinstance(tags, str) else tags

    def _has_all(run) -> bool:
        return all(tag in run.tags for tag in required)

    return _has_all
31 |
32 |
def hasnt_tags(tags: Union[str, List[str]]) -> Callable:
    """Build a pre-filter keeping runs that carry *none* of the tags in ``tags``.

    A single tag may be given as a plain string.
    """
    forbidden = [tags] if isinstance(tags, str) else tags

    def _has_none(run) -> bool:
        return not any(tag in run.tags for tag in forbidden)

    return _has_none
37 |
38 |
def has_hyperparam_values(**kwargs) -> Callable:
    """Build a pre-filter keeping runs whose config equals every given ``hyperparam=value`` pair.

    ``run.config`` is used as a mapping in this module (``run.config['model']``,
    ``run.config.keys()``), so presence is checked with ``in``. ``hasattr`` would look for an
    attribute and is always False for a dict config. Runs missing any of the hyperparameters
    are rejected.

    Returns:
        A callable ``run -> bool``.
    """
    def _matches(run) -> bool:
        return all(name in run.config and value == run.config[name]
                   for name, value in kwargs.items())

    return _matches
42 |
43 |
def larger_than(**kwargs) -> Callable:
    """Build a pre-filter keeping runs whose config values are strictly larger than the given bounds.

    Example: ``larger_than(lr=1e-4)`` keeps runs with ``run.config['lr'] > 1e-4``, which is the
    same direction as ``df_larger_than``. ``run.config`` is a mapping, so presence is checked
    with ``in``. Runs missing any of the hyperparameters are rejected.

    Returns:
        A callable ``run -> bool``.
    """
    def _exceeds(run) -> bool:
        return all(name in run.config and run.config[name] > bound
                   for name, bound in kwargs.items())

    return _exceeds
47 |
48 |
def lower_than(**kwargs) -> Callable:
    """Build a pre-filter keeping runs whose config values are strictly lower than the given bounds.

    Example: ``lower_than(lr=1e-2)`` keeps runs with ``run.config['lr'] < 1e-2``, which is the
    same direction as ``df_lower_than``. ``run.config`` is a mapping, so presence is checked
    with ``in``. Runs missing any of the hyperparameters are rejected.

    Returns:
        A callable ``run -> bool``.
    """
    def _below(run) -> bool:
        return all(name in run.config and run.config[name] < bound
                   for name, bound in kwargs.items())

    return _below
52 |
53 |
def df_larger_than(**kwargs) -> DF_MAPPING:
    """Build a post-filter keeping rows where every given column is strictly larger than its bound.

    Columns are selected with ``df[k]`` rather than ``getattr(df, k)``. Attribute access fails for
    W&B metric names such as ``'test/hrsc/rmse'`` and would silently pick a DataFrame attribute
    (e.g. ``df.size``) when a column shares its name.

    Returns:
        A callable ``DataFrame -> DataFrame``.
    """
    def f(df) -> pd.DataFrame:
        for column, bound in kwargs.items():
            df = df.loc[df[column] > bound]
        return df

    return f
61 |
62 |
def df_lower_than(**kwargs) -> DF_MAPPING:
    """Build a post-filter keeping rows where every given column is strictly lower than its bound.

    Columns are selected with ``df[k]`` rather than ``getattr(df, k)``. Attribute access fails for
    W&B metric names such as ``'test/hrsc/rmse'`` and would silently pick a DataFrame attribute
    (e.g. ``df.size``) when a column shares its name.

    Returns:
        A callable ``DataFrame -> DataFrame``.
    """
    def f(df) -> pd.DataFrame:
        for column, bound in kwargs.items():
            df = df.loc[df[column] < bound]
        return df

    return f
70 |
71 |
def is_model_type(model: str) -> Callable:
    """Build a pre-filter keeping runs whose ``config['model']`` contains ``model`` (case-insensitive)."""
    needle = model.lower()
    return lambda run: needle in run.config['model'].lower()
74 |
75 |
# Run-level pre-filters that can be referenced by (lower-case) name in ``get_runs_df``.
str_to_run_pre_filter = dict(
    has_finished=has_finished,
    has_final_metric=has_final_metric,
)
80 |
81 |
82 | # Post-filters
def topk_runs(k: int = 5,
              metric: str = 'test/hrsc/rmse',
              lower_is_better: bool = True) -> DF_MAPPING:
    """Build a post-filter keeping the ``k`` best rows of a runs DataFrame according to ``metric``.

    Args:
        k: Number of rows to keep.
        metric: Column used for ranking.
        lower_is_better: If True, keep the ``k`` smallest values; otherwise the ``k`` largest.
    """
    selector = 'nsmallest' if lower_is_better else 'nlargest'
    return lambda df: getattr(df, selector)(k, metric)
90 |
91 |
def topk_run_of_each_model_type(k: int = 1,
                                metric: str = 'test/hrsc/rmse',
                                lower_is_better: bool = True) -> DF_MAPPING:
    """Build a post-filter keeping the top-``k`` runs (by ``metric``) for every distinct ``model``.

    Args:
        k: Number of rows to keep per model.
        metric: Column used for ranking.
        lower_is_better: If True, rank ascending (smallest is best); otherwise descending.

    Returns:
        A callable mapping a runs DataFrame with a ``model`` column to the concatenation of the
        per-model top-``k`` rows. An empty input is returned unchanged, because ``pd.concat``
        of an empty list raises ``ValueError``.
    """
    topk_filter = topk_runs(k, metric, lower_is_better)

    def topk_runs_per_model(df: pd.DataFrame) -> pd.DataFrame:
        per_model = [topk_filter(df[df['model'] == model]) for model in df['model'].unique()]
        if not per_model:
            return df
        return pd.concat(per_model)

    return topk_runs_per_model
105 |
106 |
def flatten_column_dicts(df: pd.DataFrame) -> pd.DataFrame:
    """Flatten known dict/list-valued config columns of a runs DataFrame (post-filter).

    In order, ``preprocessing_dict``, ``spatial_dim``, ``input_dim``, ``target_variable`` and
    ``target_type`` are each expanded with ``Series.apply(pd.Series)``: the original column is
    dropped and one column per key is appended at the right. ``hidden_dims`` is converted to
    tuples (hashable, unlike lists) and moved to the right. ``channels_list`` is dropped.
    Columns absent from ``df`` are skipped instead of raising ``KeyError``.

    Args:
        df: Runs DataFrame, e.g. as returned by ``get_runs_df``.

    Returns:
        A new DataFrame with the columns above flattened or converted.
    """
    def _replace_column(frame: pd.DataFrame, col: str, replacement) -> pd.DataFrame:
        # Drop ``col`` and append its transformed version as the right-most column(s).
        return pd.concat([frame.drop([col], axis=1), replacement], axis=1)

    for col in ('preprocessing_dict', 'spatial_dim', 'input_dim', 'target_variable', 'target_type'):
        if col in df.columns:
            df = _replace_column(df, col, df[col].apply(pd.Series))
    if 'hidden_dims' in df.columns:
        df = _replace_column(df, 'hidden_dims', df['hidden_dims'].apply(tuple))
    if 'channels_list' in df.columns:
        df = df.drop('channels_list', axis=1)
    return df
125 |
126 |
def non_unique_cols_dropper(df: pd.DataFrame) -> pd.DataFrame:
    """Drop columns that hold exactly one distinct non-null value (post-filter).

    Distinct values are counted with ``DataFrame.nunique()`` (NaNs ignored). When some cells
    are unhashable (e.g. list/dict values from W&B configs), ``nunique`` raises ``TypeError``.
    In that case each column's non-null values are compared through their ``repr`` instead.

    Args:
        df: Runs DataFrame.

    Returns:
        ``df`` without its constant columns.
    """
    def _n_distinct(column: pd.Series) -> int:
        try:
            return column.nunique()
        except TypeError:
            return column.dropna().map(repr).nunique()

    try:
        nunique = df.nunique()
    except TypeError:
        # Count per position so duplicated column labels are handled like DataFrame.nunique does.
        nunique = pd.Series([_n_distinct(df.iloc[:, i]) for i in range(df.shape[1])],
                            index=df.columns, dtype='int64')
    cols_to_drop = nunique[nunique == 1].index
    return df.drop(cols_to_drop, axis=1)
132 |
133 |
def groupby(df: pd.DataFrame, group_by='seed', metric='Test/MAE'):
    """Aggregate ``metric`` per group into its mean and sample standard deviation.

    Only the metric column is aggregated. Averaging the string ``name`` column raises
    ``TypeError`` on pandas >= 2, which no longer drops non-numeric columns silently, and its
    ``std()`` produced a multi-column DataFrame that cannot be assigned to the single
    ``'std'`` column.

    Args:
        df: Runs DataFrame containing the ``group_by`` and ``metric`` columns.
        group_by: Column label (or list of labels) to group by.
        metric: Numeric metric column to aggregate.

    Returns:
        DataFrame indexed by the group key(s), with the mean in column ``metric`` and the
        sample standard deviation (ddof=1) in column ``'std'``.
    """
    stats = df.groupby(group_by)[metric].agg(['mean', 'std'])
    return stats.rename(columns={'mean': metric})
139 |
140 |
# DataFrame-level post-filters that can be referenced by (lower-case) name in ``get_runs_df``.
# Keys are inserted in the order: top1..top20, best_per_model, top1_per_model..top5_per_model,
# unique_columns, flatten_dicts.
str_to_run_post_filter = {f"top{k}": topk_runs(k=k) for k in range(1, 21)}
str_to_run_post_filter['best_per_model'] = topk_run_of_each_model_type(k=1)
str_to_run_post_filter.update(
    {f'top{k}_per_model': topk_run_of_each_model_type(k=k) for k in range(1, 6)}
)
str_to_run_post_filter['unique_columns'] = non_unique_cols_dropper
str_to_run_post_filter['flatten_dicts'] = flatten_column_dicts
154 |
155 |
def _resolve_filters(filters, registry: dict) -> list:
    """Normalize ``None`` / one filter / a list of filters into a list of callables.

    String entries are looked up case-insensitively in ``registry`` (``KeyError`` if unknown).
    """
    if filters is None:
        return []
    if not isinstance(filters, list):
        filters = [filters]
    return [f if callable(f) else registry[f.lower()] for f in filters]


def _run_display_name(run) -> str:
    """Printable identifier of ``run``: its ``wandb_name`` config entry if present, else ``str(run)``."""
    return str(run.config['wandb_name'] if 'wandb_name' in run.config else run)


def get_runs_df(
        get_metrics: bool = True,
        run_pre_filters: Union[str, List[Union[Callable, str]], None] = 'has_finished',
        run_post_filters: Union[str, List[Union[DF_MAPPING, str]], None] = None,
        exp_type: str = 'pristine', verbose: bool = False,
        entity: str = "ecc-mila7",
) -> pd.DataFrame:
    """Fetch the W&B runs of an experiment type into a single DataFrame.

    Args:
        get_metrics: If True, include each run's summary (logged output metrics) as columns.
        run_pre_filters: Run-level filter(s) applied before a run is collected. Each entry is
            either a callable ``run -> bool`` or a name registered in ``str_to_run_pre_filter``.
            ``None`` disables pre-filtering. Runs without a ``model`` config entry are always
            skipped.
        run_post_filters: DataFrame-level filter(s)/transform(s) applied in order to the result.
            Each entry is either a callable or a name registered in ``str_to_run_post_filter``.
        exp_type: Experiment type, used to look up the W&B project in ``exp_to_wandb_project``.
        verbose: If True, print why each run was skipped.
        entity: W&B entity that owns the project.

    Returns:
        One row per kept run with name/id/tags, config, (optionally) summary, and group columns.
        Columns starting with ``gradients/`` and the ``graph_0`` column are excluded, and the
        post-filters are applied to the result.

    Raises:
        ValueError: If the collected DataFrame is empty (e.g. no run survived the pre-filters).
        KeyError: If ``exp_type`` or a filter name is not registered.
    """
    pre_filters = _resolve_filters(run_pre_filters, str_to_run_pre_filter)
    post_filters = _resolve_filters(run_post_filters, str_to_run_post_filter)

    api = wandb.Api()
    # Runs are read from the project "<entity>/<project>", where the project depends on exp_type.
    runs = api.runs(f"{entity}/{exp_to_wandb_project[exp_type]}")
    summary_list, config_list, group_list = [], [], []
    name_list, tag_list, id_list = [], [], []
    for run in runs:
        if 'model' not in run.config.keys():
            if verbose:
                print(f"Run {_run_display_name(run)} filtered out, I.")
            continue

        # First pre-filter that rejects the run (short-circuits like a sequential check), if any.
        rejected_by = next((filtr for filtr in pre_filters if not filtr(run)), None)
        if rejected_by is not None:
            if verbose:
                filter_name = getattr(rejected_by, '__qualname__', repr(rejected_by))
                print(f"Run {_run_display_name(run)} filtered out, by {filter_name}.")
            continue

        id_list.append(str(run.id))
        tag_list.append(str(run.tags))
        if get_metrics:
            # run.summary holds the output key/values; ._json_dict is used to omit large files.
            summary_list.append(run.summary._json_dict)
        # run.config holds the run's input configuration (hyperparameters).
        config_list.append(run.config)
        name_list.append(run.name)
        group_list.append(run.group)

    summary_df = pd.DataFrame.from_records(summary_list)
    config_df = pd.DataFrame.from_records(config_list)
    name_df = pd.DataFrame({'name': name_list, 'id': id_list, 'tags': tag_list})
    group_df = pd.DataFrame({'group': group_list})
    all_df = pd.concat([name_df, config_df, summary_df, group_df], axis=1)

    cols = [c for c in all_df.columns if not c.startswith('gradients/') and c != 'graph_0']
    all_df = all_df[cols]
    if all_df.empty:
        raise ValueError('Empty DF!')
    for post_filter in post_filters:
        all_df = post_filter(all_df)
    return all_df
226 |
--------------------------------------------------------------------------------
/climart/META_INFO.json:
--------------------------------------------------------------------------------
1 | {
2 | "feature_by_var": {
3 | "globals": {
4 | "cszrow": {
5 | "start": 0,
6 | "end": 1
7 | },
8 | "gtrow": {
9 | "start": 1,
10 | "end": 2
11 | },
12 | "pressg": {
13 | "start": 2,
14 | "end": 3
15 | },
16 | "oztop": {
17 | "start": 3,
18 | "end": 4
19 | },
20 | "emisrow": {
21 | "start": 4,
22 | "end": 5
23 | },
24 | "salbrol": {
25 | "start": 5,
26 | "end": 9
27 | },
28 | "csalrol": {
29 | "start": 9,
30 | "end": 13
31 | },
32 | "emisrot": {
33 | "start": 13,
34 | "end": 19
35 | },
36 | "gtrot": {
37 | "start": 19,
38 | "end": 25
39 | },
40 | "farerot": {
41 | "start": 25,
42 | "end": 31
43 | },
44 | "salbrot": {
45 | "start": 31,
46 | "end": 55
47 | },
48 | "csalrot": {
49 | "start": 55,
50 | "end": 79
51 | },
52 | "x_cord": {
53 | "start": 79,
54 | "end": 80
55 | },
56 | "y_cord": {
57 | "start": 80,
58 | "end": 81
59 | },
60 | "z_cord": {
61 | "start": 81,
62 | "end": 82
63 | }
64 | },
65 | "levels": {
66 | "shtj": {
67 | "start": 0,
68 | "end": 1
69 | },
70 | "tfrow": {
71 | "start": 1,
72 | "end": 2
73 | },
74 | "level_pressure": {
75 | "start": 2,
76 | "end": 3
77 | },
78 | "height": {
79 | "start": 3,
80 | "end": 4
81 | }
82 | },
83 | "layers": {
84 | "shj": {
85 | "start": 0,
86 | "end": 1
87 | },
88 | "tlayer": {
89 | "start": 1,
90 | "end": 2
91 | },
92 | "layer_pressure": {
93 | "start": 2,
94 | "end": 3
95 | },
96 | "ozphs": {
97 | "start": 3,
98 | "end": 4
99 | },
100 | "qc": {
101 | "start": 4,
102 | "end": 5
103 | },
104 | "dz": {
105 | "start": 5,
106 | "end": 6
107 | },
108 | "dshj": {
109 | "start": 6,
110 | "end": 7
111 | },
112 | "co2rox": {
113 | "start": 7,
114 | "end": 8
115 | },
116 | "ch4rox": {
117 | "start": 8,
118 | "end": 9
119 | },
120 | "n2orox": {
121 | "start": 9,
122 | "end": 10
123 | },
124 | "f11rox": {
125 | "start": 10,
126 | "end": 11
127 | },
128 | "f12rox": {
129 | "start": 11,
130 | "end": 12
131 | },
132 | "layer_thickness": {
133 | "start": 12,
134 | "end": 13
135 | },
136 | "temp_diff": {
137 | "start": 13,
138 | "end": 14
139 | },
140 | "rhc": {
141 | "start": 14,
142 | "end": 15
143 | },
144 | "aerin": {
145 | "start": 15,
146 | "end": 24
147 | },
148 | "sw_ext_sa": {
149 | "start": 24,
150 | "end": 28
151 | },
152 | "sw_ssa_sa": {
153 | "start": 28,
154 | "end": 32
155 | },
156 | "sw_g_sa": {
157 | "start": 32,
158 | "end": 36
159 | },
160 | "lw_abs_sa": {
161 | "start": 36,
162 | "end": 45
163 | }
164 | },
165 | "outputs_pristine": {
166 | "rldc": {
167 | "start": 0,
168 | "end": 50
169 | },
170 | "rluc": {
171 | "start": 0,
172 | "end": 50
173 | },
174 | "rsdc": {
175 | "start": 0,
176 | "end": 50
177 | },
178 | "rsuc": {
179 | "start": 0,
180 | "end": 50
181 | },
182 | "hrlc": {
183 | "start": 0,
184 | "end": 49
185 | },
186 | "hrsc": {
187 | "start": 0,
188 | "end": 49
189 | }
190 | },
191 | "outputs_clear_sky": {
192 | "rldc": {
193 | "start": 0,
194 | "end": 50
195 | },
196 | "rluc": {
197 | "start": 0,
198 | "end": 50
199 | },
200 | "rsdc": {
201 | "start": 0,
202 | "end": 50
203 | },
204 | "rsuc": {
205 | "start": 0,
206 | "end": 50
207 | },
208 | "hrlc": {
209 | "start": 0,
210 | "end": 49
211 | },
212 | "hrsc": {
213 | "start": 0,
214 | "end": 49
215 | }
216 | },
217 | "outputs_all_sky": {
218 | "rld": {
219 | "start": 0,
220 | "end": 50
221 | },
222 | "rlu": {
223 | "start": 0,
224 | "end": 50
225 | },
226 | "rsd": {
227 | "start": 0,
228 | "end": 50
229 | },
230 | "rsu": {
231 | "start": 0,
232 | "end": 50
233 | },
234 | "hrl": {
235 | "start": 0,
236 | "end": 49
237 | },
238 | "hrs": {
239 | "start": 0,
240 | "end": 49
241 | }
242 | }
243 | },
244 | "input_dims": {
245 | "pristine": {
246 | "globals": 82,
247 | "layers": 14,
248 | "levels": 4
249 | },
250 | "clear_sky": {
251 | "globals": 82,
252 | "layers": 45,
253 | "levels": 4
254 | }
255 | },
256 | "variables": {
257 | "cszrow": {
258 | "data_type": "globals",
259 | "num_features": 1,
260 | "only_clear_sky": false
261 | },
262 | "shj": {
263 | "data_type": "layers",
264 | "num_features": 1,
265 | "only_clear_sky": false
266 | },
267 | "shtj": {
268 | "data_type": "levels",
269 | "num_features": 1,
270 | "only_clear_sky": false
271 | },
272 | "gtrow": {
273 | "data_type": "globals",
274 | "num_features": 1,
275 | "only_clear_sky": false
276 | },
277 | "tlayer": {
278 | "data_type": "layers",
279 | "num_features": 1,
280 | "only_clear_sky": false
281 | },
282 | "tfrow": {
283 | "data_type": "levels",
284 | "num_features": 1,
285 | "only_clear_sky": false
286 | },
287 | "pressg": {
288 | "data_type": "globals",
289 | "num_features": 1,
290 | "only_clear_sky": false
291 | },
292 | "layer_pressure": {
293 | "data_type": "layers",
294 | "num_features": 1,
295 | "only_clear_sky": false
296 | },
297 | "level_pressure": {
298 | "data_type": "levels",
299 | "num_features": 1,
300 | "only_clear_sky": false
301 | },
302 | "oztop": {
303 | "data_type": "globals",
304 | "num_features": 1,
305 | "only_clear_sky": false
306 | },
307 | "ozphs": {
308 | "data_type": "layers",
309 | "num_features": 1,
310 | "only_clear_sky": false
311 | },
312 | "qc": {
313 | "data_type": "layers",
314 | "num_features": 1,
315 | "only_clear_sky": false
316 | },
317 | "dz": {
318 | "data_type": "layers",
319 | "num_features": 1,
320 | "only_clear_sky": false
321 | },
322 | "dshj": {
323 | "data_type": "layers",
324 | "num_features": 1,
325 | "only_clear_sky": false
326 | },
327 | "co2rox": {
328 | "data_type": "layers",
329 | "num_features": 1,
330 | "only_clear_sky": false
331 | },
332 | "ch4rox": {
333 | "data_type": "layers",
334 | "num_features": 1,
335 | "only_clear_sky": false
336 | },
337 | "n2orox": {
338 | "data_type": "layers",
339 | "num_features": 1,
340 | "only_clear_sky": false
341 | },
342 | "f11rox": {
343 | "data_type": "layers",
344 | "num_features": 1,
345 | "only_clear_sky": false
346 | },
347 | "f12rox": {
348 | "data_type": "layers",
349 | "num_features": 1,
350 | "only_clear_sky": false
351 | },
352 | "f113rox": true,
353 | "f114rox": true,
354 | "emisrow": {
355 | "data_type": "globals",
356 | "num_features": 1,
357 | "only_clear_sky": false
358 | },
359 | "salbrol": {
360 | "data_type": "globals",
361 | "num_features": 4,
362 | "only_clear_sky": false
363 | },
364 | "csalrol": {
365 | "data_type": "globals",
366 | "num_features": 4,
367 | "only_clear_sky": false
368 | },
369 | "emisrot": {
370 | "data_type": "globals",
371 | "num_features": 6,
372 | "only_clear_sky": false
373 | },
374 | "gtrot": {
375 | "data_type": "globals",
376 | "num_features": 6,
377 | "only_clear_sky": false
378 | },
379 | "farerot": {
380 | "data_type": "globals",
381 | "num_features": 6,
382 | "only_clear_sky": false
383 | },
384 | "salbrot": {
385 | "data_type": "globals",
386 | "num_features": 24,
387 | "only_clear_sky": false
388 | },
389 | "csalrot": {
390 | "data_type": "globals",
391 | "num_features": 24,
392 | "only_clear_sky": false
393 | },
394 | "layer_thickness": {
395 | "data_type": "layers",
396 | "num_features": 1,
397 | "only_clear_sky": false
398 | },
399 | "temp_diff": {
400 | "data_type": "layers",
401 | "num_features": 1,
402 | "only_clear_sky": false
403 | },
404 | "height": {
405 | "data_type": "levels",
406 | "num_features": 1,
407 | "only_clear_sky": false
408 | },
409 | "x_cord": {
410 | "data_type": "globals",
411 | "num_features": 1,
412 | "only_clear_sky": false
413 | },
414 | "y_cord": {
415 | "data_type": "globals",
416 | "num_features": 1,
417 | "only_clear_sky": false
418 | },
419 | "z_cord": {
420 | "data_type": "globals",
421 | "num_features": 1,
422 | "only_clear_sky": false
423 | },
424 | "rhc": {
425 | "data_type": "layers",
426 | "num_features": 1,
427 | "only_clear_sky": true
428 | },
429 | "aerin": {
430 | "data_type": "layers",
431 | "num_features": 9,
432 | "only_clear_sky": true
433 | },
434 | "sw_ext_sa": {
435 | "data_type": "layers",
436 | "num_features": 4,
437 | "only_clear_sky": true
438 | },
439 | "sw_ssa_sa": {
440 | "data_type": "layers",
441 | "num_features": 4,
442 | "only_clear_sky": true
443 | },
444 | "sw_g_sa": {
445 | "data_type": "layers",
446 | "num_features": 4,
447 | "only_clear_sky": true
448 | },
449 | "lw_abs_sa": {
450 | "data_type": "layers",
451 | "num_features": 9,
452 | "only_clear_sky": true
453 | },
454 | "rldc": {
455 | "data_type": "outputs",
456 | "num_features": 50,
457 | "only_clear_sky": true
458 | },
459 | "rluc": {
460 | "data_type": "outputs",
461 | "num_features": 50,
462 | "only_clear_sky": true
463 | },
464 | "rsdc": {
465 | "data_type": "outputs",
466 | "num_features": 50,
467 | "only_clear_sky": true
468 | },
469 | "rsuc": {
470 | "data_type": "outputs",
471 | "num_features": 50,
472 | "only_clear_sky": true
473 | },
474 | "hrlc": {
475 | "data_type": "outputs",
476 | "num_features": 49,
477 | "only_clear_sky": true
478 | },
479 | "hrsc": {
480 | "data_type": "outputs",
481 | "num_features": 49,
482 | "only_clear_sky": true
483 | },
484 | "rld": {
485 | "data_type": "outputs",
486 | "num_features": 50,
487 | "only_clear_sky": false
488 | },
489 | "rlu": {
490 | "data_type": "outputs",
491 | "num_features": 50,
492 | "only_clear_sky": false
493 | },
494 | "rsd": {
495 | "data_type": "outputs",
496 | "num_features": 50,
497 | "only_clear_sky": false
498 | },
499 | "rsu": {
500 | "data_type": "outputs",
501 | "num_features": 50,
502 | "only_clear_sky": false
503 | },
504 | "hrl": {
505 | "data_type": "outputs",
506 | "num_features": 49,
507 | "only_clear_sky": false
508 | },
509 | "hrs": {
510 | "data_type": "outputs",
511 | "num_features": 49,
512 | "only_clear_sky": false
513 | }
514 | }
515 | }
--------------------------------------------------------------------------------
/climart/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RolnickLab/climart/6d12fa478de5db5209842e0369a2a24377079edd/climart/__init__.py
--------------------------------------------------------------------------------
/climart/data_loading/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RolnickLab/climart/6d12fa478de5db5209842e0369a2a24377079edd/climart/data_loading/__init__.py
--------------------------------------------------------------------------------
/climart/data_loading/constants.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | import os
4 | import numpy as np
5 | from typing import Optional, Dict
6 |
7 | import xarray
8 | from hydra.utils import get_original_cwd, to_absolute_path
9 |
# ---- Input data groups and the name of the inputs sub-folder
GLOBALS = "globals"
LAYERS = "layers"
LEVELS = "levels"
INPUTS = 'inputs'

# ---- Experiment types and their output sub-folders ("outputs_<exp_type>")
PRISTINE = 'pristine'
CLEAR_SKY = 'clear_sky'
OUTPUTS_PRISTINE = "outputs_" + PRISTINE
OUTPUTS_CLEARSKY = "outputs_" + CLEAR_SKY

# ---- Target names
SHORTWAVE = 'shortwave'
LONGWAVE = 'longwave'
HEATING_RATES = 'heating_rates'
FLUXES = 'fluxes'
TOA_FLUXES = 'toa_fluxes'
SURFACE_FLUXES = 'surface_fluxes'

INPUT_TYPES = [GLOBALS, LEVELS, LAYERS]
DATA_TYPES = [GLOBALS, LAYERS, LEVELS, OUTPUTS_PRISTINE, OUTPUTS_CLEARSKY]
EXP_TYPES = [PRISTINE, CLEAR_SKY]

# Every data type except GLOBALS.
SPATIAL_DATA_TYPES = [LAYERS, LEVELS, OUTPUTS_PRISTINE, OUTPUTS_CLEARSKY]
DATA_TYPE_DIMS = {
    GLOBALS: 82,
    LEVELS: 4,
    LAYERS: 45,
    OUTPUTS_CLEARSKY: 1,
    OUTPUTS_PRISTINE: 1,
}

# ---- Year splits (1992 and 1993 belong to none of them)
TRAIN_YEARS = [*range(1979, 1991), *range(1994, 2005)]
VAL_YEARS = [2005, 2006]
TEST_YEARS = [*range(2007, 2015)]
OOD_PRESENT_YEARS = [1991]
OOD_FUTURE_YEARS = [2097, 2098, 2099]
OOD_HISTORIC_YEARS = [1850, 1851, 1852]
ALL_YEARS = [
    *TRAIN_YEARS, *VAL_YEARS, *TEST_YEARS,
    *OOD_PRESENT_YEARS, *OOD_FUTURE_YEARS, *OOD_HISTORIC_YEARS,
]

# DATA_DIR should be an absolute path, since Hydra changes the working dir... to_absolute_path() solves this.
DATA_DIR = to_absolute_path("ClimART_DATA/")
42 |
43 |
def get_data_subdirs(data_dir: str) -> Dict[str, str]:
    """Map each data sub-folder name (INPUTS, OUTPUTS_PRISTINE, OUTPUTS_CLEARSKY) to its path in ``data_dir``.

    The paths are only joined, not checked for existence.
    """
    subfolder_names = (INPUTS, OUTPUTS_PRISTINE, OUTPUTS_CLEARSKY)
    return {name: os.path.join(data_dir, name) for name in subfolder_names}
50 |
51 |
def get_metadata(data_dir: Optional[str] = None):
    """Load and return the parsed content of ``META_INFO.json`` stored in ``data_dir``.

    Args:
        data_dir: Folder containing ``META_INFO.json``; ``DATA_DIR`` is used when None. If the file is not
            found at ``<data_dir>/META_INFO.json``, the path is retried after resolving it with Hydra's
            ``to_absolute_path``.

    Raises:
        ValueError: if the file exists at neither location.
    """
    if data_dir is None:
        data_dir = DATA_DIR
    path = os.path.join(data_dir, 'META_INFO.json')

    if not os.path.isfile(path):
        absolute_path = to_absolute_path(path)
        if not os.path.isfile(absolute_path):
            err_msg = f' Not able to recover meta information from {path}, as it is not a file!'
            raise ValueError(err_msg)
        path = absolute_path
    # Explicit encoding: JSON is UTF-8, and the platform default encoding may differ.
    with open(path, 'r', encoding='utf-8') as fp:
        return json.load(fp)
66 |
67 |
def get_statistics(data_dir: str = None):
    """Open ``statistics.npz`` from ``data_dir`` (``DATA_DIR`` if None) with ``np.load`` and return the result.

    If the file is missing at ``<data_dir>/statistics.npz``, the path resolved by Hydra's ``to_absolute_path``
    is tried instead. The returned object keeps the file open; callers should close it (or use it in a
    ``with`` block).

    Raises:
        ValueError: if the file exists at neither location.
    """
    folder = DATA_DIR if data_dir is None else data_dir
    stats_path = os.path.join(folder, 'statistics.npz')
    if os.path.isfile(stats_path):
        return np.load(stats_path)
    fallback_path = to_absolute_path(stats_path)
    if os.path.isfile(fallback_path):
        return np.load(fallback_path)
    raise ValueError(f' Not able to recover statistics file from {stats_path}')
80 |
81 |
def get_coordinates(data_dir: str = None):
    """Open ``areacella_fx_CanESM5.nc`` from ``data_dir`` (``DATA_DIR`` if None) via ``xarray.open_dataset``.

    Unlike ``get_metadata``/``get_statistics``, there is no ``to_absolute_path`` fallback here.

    Raises:
        ValueError: if the file does not exist.
    """
    nc_path = os.path.join(DATA_DIR if data_dir is None else data_dir, 'areacella_fx_CanESM5.nc')
    if os.path.isfile(nc_path):
        return xarray.open_dataset(nc_path)
    raise ValueError(f' Not able to recover coordinates/latitudes/longitudes from {nc_path}')
90 |
91 |
def get_data_dims(exp_type: str) -> dict:
    """Return the data dimensions used for the experiment type ``exp_type``.

    :param exp_type: Experiment type. It is compared case-insensitively with PRISTINE; any other value gets
        the 45-feature layer input.
    :return: A single dict (not a tuple) with keys
        - 'spatial_dim': spatial size per input type ({GLOBALS: 0, LEVELS: 50, LAYERS: 49}),
        - 'input_dim': feature count per input type (LAYERS: 14 for pristine, 45 otherwise),
        - 'output_dim': 100.
    """
    spatial_dim = {GLOBALS: 0, LEVELS: 50, LAYERS: 49}
    in_dim = {GLOBALS: 82, LEVELS: 4, LAYERS: 14 if exp_type.lower() == PRISTINE else 45}
    out_dim = 100
    return {'spatial_dim': spatial_dim, 'input_dim': in_dim, 'output_dim': out_dim}
97 |
98 |
def get_flux_mean():
    """Return a hard-coded float64 vector of 100 flux mean values.

    The vector is made of two consecutive 50-value segments, kept separate below as in the original listing.
    Presumably there is one segment per flux variable over the 50 levels. TODO: confirm which variables and
    units.
    """
    first_segment = [
        296.68795572, 295.59927749, 295.1482046, 294.64596736, 294.11837758,
        293.58230163, 293.04465391, 292.50202651, 291.93796111, 291.40537197,
        290.73892157, 290.04539078, 289.39613274, 288.70448136, 288.01931166,
        287.30607635, 286.62727984, 285.93099547, 285.23425452, 284.56900363,
        283.87613833, 283.12472721, 282.30417958, 281.35523027, 280.2217935,
        278.83756356, 277.10443911, 274.98218953, 272.46624459, 269.56133595,
        266.19161612, 262.41764801, 258.2832474, 254.0212223, 250.09734285,
        246.61271747, 243.56633538, 240.91876951, 238.62519398, 236.65407597,
        234.96650554, 233.53089125, 232.3124631, 231.20520435, 230.09577872,
        228.99557943, 227.9148505, 226.85982104, 226.01209087, 225.32944844,
    ]
    second_segment = [
        59.0339687, 59.04093876, 59.03547336, 59.02984351, 59.0244958,
        59.02004596, 59.01671905, 59.0139414, 59.01162252, 59.00892333,
        59.00462615, 58.99430663, 58.97929168, 58.94997202, 58.90542686,
        58.83393261, 58.73876127, 58.60566996, 58.42574877, 58.2087681,
        57.94035763, 57.6132036, 57.23794982, 56.81170941, 56.32479326,
        55.77942549, 55.16748318, 54.49360675, 53.77000922, 52.99446721,
        52.16208255, 51.29397739, 50.39621889, 49.53103161, 48.77409974,
        48.12633895, 47.58119014, 47.12617189, 46.74971375, 46.44106945,
        46.19017345, 45.98759973, 45.82464207, 45.68331371, 45.54797404,
        45.41889672, 45.29627039, 45.18034025, 45.09040981, 45.01957376,
    ]
    return np.array(first_segment + second_segment, dtype=np.float64)
124 |
--------------------------------------------------------------------------------
/climart/data_loading/data_variables.py:
--------------------------------------------------------------------------------
# -------------------------------------- INPUT VARIABLES
_ALL_INPUT_VARS = [
    'shtj', 'tfrow', 'shj', 'dshj', 'dz', 'tlayer', 'ozphs', 'qc',
    'co2rox', 'ch4rox', 'n2orox', 'f11rox', 'f12rox',
    'ccld', 'rhc', 'anu', 'eta',
    'aerin', 'sw_ext_sa', 'sw_ssa_sa', 'sw_g_sa', 'lw_abs_sa',
    'pressg', 'gtrow', 'oztop', 'cszrow', 'vtaurow', 'troprow', 'emisrow',
    'cldtrol', 'ncldy', 'salbrol', 'csalrol', 'emisrot', 'gtrot', 'farerot', 'salbrot', 'csalrot',
    'rel_sub', 'rei_sub', 'clw_sub', 'cic_sub',
    'layer_pressure', 'level_pressure', 'layer_thickness',
    'x_cord', 'y_cord', 'z_cord', 'temp_diff', 'height',
]

# -------------------------------------- OUTPUTS/TARGET VARIABLES
LW_HEATING_RATE = 'hrlc'
SW_HEATING_RATE = 'hrsc'

# The *_NOCLOUDS names carry a trailing 'c'; META_INFO.json flags them as only_clear_sky.
OUT_SHORTWAVE_CLOUDS = ['rsd', 'rsu']  # solar flux down / up
OUT_SHORTWAVE_NOCLOUDS = ['rsdc', 'rsuc']  # solar flux down / up
OUT_LONGWAVE_CLOUDS = ['rld', 'rlu']  # thermal flux down / up
OUT_LONGWAVE_NOCLOUDS = ['rldc', 'rluc']  # thermal flux down / up
OUT_HEATING_RATE_CLOUDS = ['hrl', 'hrs']  # heating rate (long-wave) / (short-wave)
OUT_HEATING_RATE_NOCLOUDS = [LW_HEATING_RATE, SW_HEATING_RATE]  # heating rate (long-wave) / (short-wave)

# Only the *_NOCLOUDS targets are part of the "all variables" lists.
_ALL_OUTPUT_VARS = OUT_SHORTWAVE_NOCLOUDS + OUT_LONGWAVE_NOCLOUDS + OUT_HEATING_RATE_NOCLOUDS
_ALL_VARS = _ALL_INPUT_VARS + _ALL_OUTPUT_VARS
90 |
INPUT_VARS_CLOUDS = [
    'ccld',  # Cloud amount profile
    'anu',  # Cloud water content horizontal variability parameter
    'eta',  # Fraction black carbon in liquid cloud droplets
    'vtaurow',  # Vertically integrated optical thickness at 550 nm for stratospheric aerosols
    'troprow',  # Layer index of the tropopause
    'cldtrol',  # Total vertically projected cloud fraction
    'ncldy',  # Number of cloudy subcolumns in CanAM grid
    'rel_sub',  # Liquid cloud effective radius for subcolumns in CanAM grid
    'rei_sub',  # Ice cloud effective radius for subcolumns in CanAM grid
    'clw_sub',  # Liquid cloud water path for subcolumns in CanAM grid
    'cic_sub',  # presumably the ice counterpart of clw_sub (ice cloud water path for subcolumns) -- TODO confirm
]

# Except for 'rhc', the descriptions below are inferred from the variable names and must be confirmed.
# META_INFO.json lists all six as 'layers' variables with only_clear_sky = true.
INPUT_VARS_AEROSOLS = [
    'rhc',  # Relative humidity
    'aerin',  # presumably an aerosol input (9 features per layer in META_INFO.json) -- TODO confirm
    'sw_ext_sa',  # presumably shortwave extinction of stratospheric aerosols -- TODO confirm
    'sw_ssa_sa',  # presumably shortwave single-scattering albedo of stratospheric aerosols -- TODO confirm
    'sw_g_sa',  # presumably shortwave asymmetry factor (g) of stratospheric aerosols -- TODO confirm
    'lw_abs_sa',  # presumably longwave absorption of stratospheric aerosols -- TODO confirm
]
113 |
114 |
115 | # -----------------------------------------------------
116 |
def no_clouds_exp(name):
    """Return True if ``name`` is one of the accepted (case-sensitive) spellings of a cloud-free experiment."""
    cloud_free_names = ('pristine', 'clear-sky', 'clearsky', 'clear_sky', 'aerosols')
    return name in cloud_free_names
119 |
120 |
def get_flux_output_variables(target_type: str):
    """Return the clear-sky flux output variable names for ``target_type`` (matched case-insensitively).

    'shortwave' and 'longwave' return the module-level lists themselves; 'shortwave+longwave' returns a new
    concatenated list.

    Raises:
        ValueError: for any other ``target_type``.
    """
    variables_by_target = {
        "shortwave": OUT_SHORTWAVE_NOCLOUDS,
        "longwave": OUT_LONGWAVE_NOCLOUDS,
        "shortwave+longwave": OUT_SHORTWAVE_NOCLOUDS + OUT_LONGWAVE_NOCLOUDS,
    }
    selected = variables_by_target.get(target_type.lower())
    if selected is None:
        raise ValueError(f" Unexpected arg {target_type} for target_type")
    return selected
129 |
130 |
# Canonical experiment-type names (see no_clouds_exp for the accepted alternative spellings).
EXP_TYPES = [
    'pristine',
    'clear_sky',
]
132 |
--------------------------------------------------------------------------------
/climart/data_transform/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RolnickLab/climart/6d12fa478de5db5209842e0369a2a24377079edd/climart/data_transform/__init__.py
--------------------------------------------------------------------------------
/climart/data_transform/normalization.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from abc import ABC
3 | from typing import Optional, Union, Dict, Iterable, Sequence, List, Callable
4 | from omegaconf import DictConfig
5 | import numpy as np
6 | import torch
7 | from torch import Tensor
8 | from climart.data_loading.constants import LEVELS, LAYERS, GLOBALS, get_metadata, get_statistics
9 | from climart.data_loading import constants
10 | from climart.data_loading.data_variables import get_flux_output_variables
11 | from climart.utils.utils import get_logger, get_target_variable_names, get_identity_callable, identity
12 |
# Type of a plain array -> array transform (e.g. the `identity` fallback normalizer used by Normalizer).
NP_ARRAY_MAPPING = Callable[[np.ndarray], np.ndarray]

# Module-level logger.
log = get_logger(__name__)
16 |
17 |
def in_degree_normalization(adj: np.ndarray):
    """
    For Graph networks: row-normalize an adjacency matrix, i.e. compute ``diag(1 / adj.sum(1)) @ adj``.

    Rows that sum to zero stay all-zero instead of producing inf/NaN.

    :param adj: A N x N adjacency matrix (integer, boolean or floating entries)
    :return: In-degree normalized matrix (Row-normalized matrix). Floating inputs keep their dtype;
        integer/boolean inputs are promoted to float64.
    """
    rowsum = np.array(adj.sum(1))
    # np.power(int_array, -1) raises "Integers to negative integer powers are not allowed", so promote
    # non-floating row sums to float64 (floating inputs keep their precision).
    if not np.issubdtype(rowsum.dtype, np.floating):
        rowsum = rowsum.astype(np.float64)
    # Zero rows give 1/0 = inf; that case is handled just below, so silence the divide warning.
    with np.errstate(divide='ignore'):
        r_inv = np.power(rowsum, -1).flatten()
    r_inv[np.isinf(r_inv)] = 0.0
    r_mat_inv = np.diag(r_inv)
    return r_mat_inv @ adj
30 |
31 |
32 |
class NormalizationMethod(ABC):
    """
    Base class for invertible data normalizations. Used on its own it is the identity ("no normalization").

    Subclasses override ``normalize``/``inverse_normalize`` (and usually ``__call__``). If they hold
    statistics, they also override ``stored_values`` so that ``copy()`` can rebuild an equivalent instance.
    A subclass whose ``normalize`` calls ``self(...)`` must also override ``__call__``; otherwise the two
    methods would recurse into each other.
    """

    def __init__(self, *args, **kwargs):
        # Extra arguments (e.g. unused statistics forwarded by get_normalizer) are accepted and ignored.
        pass

    def normalize(self, data: np.ndarray, axis=0, *args, **kwargs):
        return data

    def inverse_normalize(self, normalized_data: np.ndarray):
        return normalized_data

    def __call__(self, data, *args, **kwargs):
        """Apply ``normalize``, so that every NormalizationMethod (including 'none') is callable."""
        return self.normalize(data, *args, **kwargs)

    def stored_values(self):
        return dict()

    def __copy__(self):
        return type(self)(**self.stored_values())

    def copy(self):
        return self.__copy__()

    def change_input_type(self, new_type):
        """
        Convert the instance attributes to ``new_type``.

        ``torch.Tensor`` converts numpy-array attributes and ``np.ndarray`` converts tensor attributes; other
        attributes are left unchanged. Any other ``new_type`` is called on every attribute.
        """
        # Iterate over a snapshot, since attributes are reassigned inside the loop.
        for attribute, value in list(self.__dict__.items()):
            if new_type in [torch.Tensor, torch.TensorType]:
                if isinstance(value, np.ndarray):
                    setattr(self, attribute, torch.from_numpy(value))
            elif new_type == np.ndarray:
                if torch.is_tensor(value):
                    # Move to CPU *before* converting: ndarrays have no .cpu(), and device tensors cannot be
                    # converted directly. detach() avoids the error raised for tensors that require grad.
                    setattr(self, attribute, value.detach().cpu().numpy())
            else:
                setattr(self, attribute, new_type(value))

    def apply_torch_func(self, fn):
        """
        Function to be called to apply a torch function to all tensors of this class, e.g. apply .to(), .cuda(), ...,
        Just call this function from within the model's nn.Module._apply()
        """
        for attribute, value in list(self.__dict__.items()):
            if torch.is_tensor(value):
                setattr(self, attribute, fn(value))
71 |
72 |
class Z_Normalizer(NormalizationMethod):
    """
    Z-score normalization with fixed statistics: ``(data - mean) / std``.

    ``mean`` and ``std`` must support broadcasting arithmetic with the data (e.g. scalars, numpy arrays or,
    after ``change_input_type``, torch tensors).
    """

    def __init__(self, mean, std, **kwargs):
        super().__init__(**kwargs)
        self.mean = mean
        self.std = std

    def __call__(self, data):
        centered = data - self.mean
        return centered / self.std

    def normalize(self, data, axis=None, *args, **kwargs):
        # ``axis`` and any extra arguments are accepted for API compatibility and ignored.
        return self(data)

    def inverse_normalize(self, normalized_data):
        rescaled = normalized_data * self.std
        return rescaled + self.mean

    def stored_values(self):
        return dict(mean=self.mean, std=self.std)
91 |
92 |
class MinMax_Normalizer(NormalizationMethod):
    """
    Min-max scaling: ``(data - min) / max_minus_min``.

    The scale can be given directly as ``max_minus_min`` or derived from ``max``; if both are given,
    ``max_minus_min`` wins. Statistics may be scalars or arrays.
    """

    def __init__(self, min=None, max_minus_min=None, max=None, **kwargs):
        super().__init__(**kwargs)
        self.min = min
        # Use explicit ``is None`` checks. ``if min:`` raises for multi-element arrays and skips a scalar
        # minimum of 0. ``max_minus_min or ...`` raises for array scales.
        if max_minus_min is None and max is not None and min is not None:
            max_minus_min = max - min
        if min is not None and max_minus_min is None:
            raise ValueError("MinMax_Normalizer needs `max_minus_min` or `max` when `min` is given.")
        self.max_minus_min = max_minus_min

    def normalize(self, data, axis=None, *args, **kwargs):
        # Uses the stored statistics; ``axis`` is accepted for API compatibility and ignored.
        return self(data)

    def inverse_normalize(self, normalized_data):
        """
        Undo the scaling. Inputs with >= 2 dims are flattened to ``(shape[0], -1)`` for the rescaling and
        reshaped back afterwards, so the statistics must broadcast against that 2-D view.
        """
        shapes = normalized_data.shape
        if len(shapes) >= 2:
            normalized_data = normalized_data.reshape(normalized_data.shape[0], -1)
        data = normalized_data * self.max_minus_min + self.min
        if len(shapes) >= 2:
            data = data.reshape(shapes)
        return data

    def stored_values(self):
        return {'min': self.min, 'max_minus_min': self.max_minus_min}

    def __call__(self, data):
        return (data - self.min) / self.max_minus_min
120 |
121 |
class LogNormalizer(NormalizationMethod):
    """Natural-log transform ``np.log(data)``, inverted with ``np.exp``. It holds no statistics."""

    def __call__(self, data: np.ndarray, *args, **kwargs):
        return np.log(data)

    def normalize(self, data, *args, **kwargs):
        return self(data)

    def inverse_normalize(self, normalized_data):
        return np.exp(normalized_data)
133 |
134 |
class LogZ_Normalizer(NormalizationMethod):
    """
    Apply ``np.log(data + 1e-5)``, then z-scale the result with the given ``mean``/``std``.

    The 1e-5 offset keeps zeros finite under the log, and ``inverse_normalize`` subtracts it again. The
    statistics live in a wrapped Z_Normalizer, so type and device changes are delegated to it.
    """

    def __init__(self, mean=None, std=None, **kwargs):
        super().__init__(**kwargs)
        self.z_normalizer = Z_Normalizer(mean, std)

    def __call__(self, data, *args, **kwargs):
        logged = np.log(data + 1e-5)
        return self.z_normalizer(logged)

    def normalize(self, data, *args, **kwargs):
        return self(data)

    def inverse_normalize(self, normalized_data):
        logged = self.z_normalizer.inverse_normalize(normalized_data)
        return np.exp(logged) - 1e-5

    def stored_values(self):
        return self.z_normalizer.stored_values()

    def change_input_type(self, new_type):
        self.z_normalizer.change_input_type(new_type)

    def apply_torch_func(self, fn):
        self.z_normalizer.apply_torch_func(fn)
162 |
163 |
class MinMax_LogNormalizer(NormalizationMethod):
    """
    Min-max scaling followed by a natural log: ``np.log((data - min) / max_minus_min)``.

    NOTE(review): inputs equal to ``min`` map to ``log(0) = -inf``. Confirm this is acceptable for the data.
    """

    def __init__(self, min=None, max_minus_min=None, max=None, **kwargs):
        super().__init__(**kwargs)
        # Forward ``max`` as well: get_normalizer/Normalizer pass (min, max) statistics. Previously ``max``
        # was swallowed by **kwargs and the inner normalizer had no scale.
        self.min_max_normalizer = MinMax_Normalizer(min, max_minus_min, max=max)

    def __call__(self, data, *args, **kwargs):
        # Normalizer.normalize calls normalizers directly, so this class must be callable.
        return self.normalize(data)

    def normalize(self, data, *args, **kwargs):
        normalized_data = self.min_max_normalizer.normalize(data)
        normalized_data = np.log(normalized_data)
        return normalized_data

    def inverse_normalize(self, normalized_data):
        data = np.exp(normalized_data)
        data = self.min_max_normalizer.inverse_normalize(data)
        return data

    def stored_values(self):
        return self.min_max_normalizer.stored_values()

    # The statistics live in the wrapped MinMax_Normalizer, so type and device changes are delegated to it.
    def change_input_type(self, new_type):
        self.min_max_normalizer.change_input_type(new_type)

    def apply_torch_func(self, fn):
        self.min_max_normalizer.apply_torch_func(fn)
187 |
188 |
def get_normalizer(normalizer='z', *args, **kwargs) -> NormalizationMethod:
    """
    Instantiate a normalization method by name.

    :param normalizer: Name of the method. It is matched case-insensitively after stripping whitespace,
        mapping '-' to '_' and '&' to '+'. Supported: 'z', 'min_max', 'min_max+log'/'min_max_log',
        'log_z'/'logz', 'log', and 'none' (identity).
    :param args, kwargs: Forwarded to the normalizer's constructor (e.g. mean/std/min/max statistics).
    :raises ValueError: for an unsupported name.
    """
    normalizer = normalizer.lower().strip().replace('-', '_').replace('&', '+')
    normalizer_classes = {
        'z': Z_Normalizer,
        'min_max': MinMax_Normalizer,
        'min_max+log': MinMax_LogNormalizer,
        'min_max_log': MinMax_LogNormalizer,
        'log_z': LogZ_Normalizer,
        'logz': LogZ_Normalizer,  # was handled below but rejected by the old supported-name check
        'log': LogNormalizer,
        'none': NormalizationMethod,  # like no normalizer
    }
    # Raise instead of assert: asserts are stripped under `python -O`, which made bad names silently
    # fall through to the identity normalizer.
    if normalizer not in normalizer_classes:
        supported_normalizers = list(normalizer_classes)
        raise ValueError(f"Unsupported Normalization {normalizer} not in {str(supported_normalizers)}")
    return normalizer_classes[normalizer](*args, **kwargs)
210 |
211 |
class Normalizer:
    """
    Normalizes input dicts (keys GLOBALS, LEVELS, LAYERS) with pre-computed statistics, and holds one
    normalizer per flux output variable.

    Statistics are read from ``statistics.npz`` (``get_statistics``) and variable metadata from
    ``META_INFO.json`` (``get_metadata``), both located in ``data_dir``.
    """

    def __init__(
            self,
            datamodule_config: DictConfig,
            input_normalization: Optional[str] = None,
            output_normalization: Optional[str] = None,
            spatial_normalization_in: bool = False,
            spatial_normalization_out: bool = False,
            log_scaling: Union[bool, List[str]] = False,
            data_dir: Optional[str] = None,
            verbose: bool = True
    ):
        """
        Args:
            datamodule_config: Config read via ``.get()`` for ``data_dir``, ``exp_type``, ``target_type`` and
                ``target_variable``.
            input_normalization (str): Name passed to ``get_normalizer`` for every input type, e.g. "z" for
                z-scaling (zero mean and unit standard deviation). None leaves the inputs unchanged.
            output_normalization (str): Name of the normalizer built for each flux output variable, or None.
            spatial_normalization_in (bool): Also use input statistics whose key contains 'spatial'.
            spatial_normalization_out (bool): Use the ``spatial_`` output statistics.
            log_scaling (bool or list): If truthy, or any list (its contents are not inspected), log-scale the
                variables in ``vars_to_log_scale`` below. This only takes effect when ``input_normalization``
                is not None.
            data_dir (str): Folder with the statistics and metadata files. Defaults to the config's
                ``data_dir``, then ``constants.DATA_DIR``.
            verbose (bool): If False, this module's logger level is raised to WARNING.
        """
        if not verbose:
            log.setLevel(logging.WARNING)

        if data_dir is None:
            data_dir = datamodule_config.get("data_dir") or constants.DATA_DIR
        exp_type = datamodule_config.get("exp_type")
        target_type = datamodule_config.get("target_type")
        target_variable = datamodule_config.get("target_variable")

        # Number of leading layer features kept: 45 for clear-sky, 14 otherwise.
        self._layer_mask = 45 if exp_type == constants.CLEAR_SKY else 14
        self._recover_meta_info(data_dir)
        self._input_normalizer: Dict[str, NormalizationMethod] = dict()
        self._output_normalizer: Optional[Dict[str, NormalizationMethod]] = None
        # Default: no log scaling. This used to be set only when input_normalization was given, so
        # normalize() raised AttributeError otherwise.
        self._log_scaler_func = identity

        self._target_variables = get_target_variable_names(target_type, target_variable)
        if input_normalization is not None:
            norma_type = '_spatial' if spatial_normalization_in else ''
            info_msg = f" Applying {norma_type.lstrip('_')} {input_normalization} normalization to input data," \
                       f" based on pre-computed stats."
            log.info(info_msg)

            # Materialize the needed arrays, then close the .npz file handle.
            with get_statistics(data_dir) as stats_file:
                precomputed_stats = {k: stats_file[k] for k in stats_file.keys() if
                                     (('spatial' in k and spatial_normalization_in) or ('spatial' not in k))}
            if isinstance(log_scaling, list) or log_scaling:
                log.info(' Log scaling pressure and height variables! (no other normalization is applied to them)')
                # (mean, std) per variable. As the name says, these are values *after* the log, so the log
                # must be applied before the input normalizers (see normalize()).
                post_log_vals = dict(pressg=(11.473797, 0.10938317),
                                     layer_pressure=(9.29207, 2.097411),
                                     dz=(6.5363674, 1.044927),
                                     layer_thickness=(6.953938313568889, 1.3751644503732554),
                                     level_pressure=(9.252319, 2.1721559))
                vars_to_log_scale = ['pressg', 'layer_pressure', 'dz', 'layer_thickness', 'level_pressure']
                self._layer_log_mask = torch.tensor([2, 5, 12])
                for var in vars_to_log_scale:
                    dtype = self._variables[var]['data_type']
                    s, e = self.feature_by_var[dtype][var]['start'], self.feature_by_var[dtype][var]['end']
                    precomputed_stats[f'{dtype}{norma_type}_mean'][..., s:e] = post_log_vals[var][0]
                    precomputed_stats[f'{dtype}{norma_type}_std'][..., s:e] = post_log_vals[var][1]

                def log_scaler(X: Dict[str, Tensor]) -> Dict[str, Tensor]:
                    # NOTE(review): the feature indices are hard-coded here (globals 2, levels 2, layers 2/5/12),
                    # while the statistics above use META_INFO's feature_by_var ranges. Confirm they agree.
                    # NOTE(review): X[GLOBALS][2] indexes the first axis, which selects feature 2 only for
                    # unbatched 1-D globals. Confirm.
                    X[GLOBALS][2] = torch.log(X[GLOBALS][2])
                    X[LEVELS][..., 2] = torch.log(X[LEVELS][..., 2])
                    X[LAYERS][..., self._layer_log_mask] = torch.log(X[LAYERS][..., self._layer_log_mask])
                    return X

                self._log_scaler_func = log_scaler

        for data_type in [GLOBALS, LEVELS, LAYERS]:
            if input_normalization is not None:
                normer_kwargs = dict(
                    mean=precomputed_stats[data_type + f'{norma_type}_mean'],
                    std=precomputed_stats[data_type + f'{norma_type}_std'],
                    min=precomputed_stats[data_type + f'{norma_type}_min'],
                    max=precomputed_stats[data_type + f'{norma_type}_max'],
                )
                if data_type == LAYERS:
                    # Keep only the layer features used by this experiment type.
                    normer_kwargs = {k: v[..., :self._layer_mask] for k, v in normer_kwargs.items()}
                normalizer = get_normalizer(
                    input_normalization,
                    **normer_kwargs
                )
            else:
                normalizer = identity
            self._input_normalizer[data_type] = normalizer
        if output_normalization is not None:
            self._output_normalizer = dict()
            px = "spatial_" if spatial_normalization_out else ""
            with get_statistics(data_dir) as output_stats:
                for tv in get_flux_output_variables(target_type):
                    normer_kwargs = dict(
                        mean=output_stats[f"outputs_{exp_type}_{tv}_{px}mean"],
                        std=output_stats[f"outputs_{exp_type}_{tv}_{px}std"],
                        min=output_stats[f"outputs_{exp_type}_{tv}_{px}min"],
                        max=output_stats[f"outputs_{exp_type}_{tv}_{px}max"],
                    )
                    self._output_normalizer[tv] = get_normalizer(**normer_kwargs, normalizer=output_normalization)

    def _recover_meta_info(self, data_dir: str):
        """Read META_INFO.json from ``data_dir`` and cache its 'variables' and 'feature_by_var' entries."""
        meta_info = get_metadata(data_dir)
        self._variables = meta_info['variables']
        self._vars_used_or_not = list(self._variables.keys())
        self._feature_by_var = meta_info['feature_by_var']

    @property
    def feature_by_var(self):
        return self._feature_by_var

    def get_normalizer(self, data_type: str) -> Union[NP_ARRAY_MAPPING, NormalizationMethod]:
        """Return the input normalizer used for ``data_type`` (one of GLOBALS, LEVELS, LAYERS)."""
        return self._input_normalizer[data_type]

    def get_normalizers(self) -> Dict[str, Union[NP_ARRAY_MAPPING, NormalizationMethod]]:
        return {
            data_type: self.get_normalizer(data_type)
            for data_type in constants.INPUT_TYPES
        }

    def set_normalizer(self, data_type: str, new_normalizer: Optional[NP_ARRAY_MAPPING]):
        """
        Replace the normalizer of an input type (None means identity). For any non-input ``data_type``,
        the whole output normalizer is replaced by ``new_normalizer``.
        """
        if new_normalizer is None:
            new_normalizer = identity
        if data_type in constants.INPUT_TYPES:
            self._input_normalizer[data_type] = new_normalizer
        else:
            log.info(f" Setting output normalizer, after calling set_normalizer with data_type={data_type}")
            self._output_normalizer = new_normalizer

    def set_input_normalizers(self, new_normalizer: Optional[NP_ARRAY_MAPPING]):
        for data_type in constants.INPUT_TYPES:
            self.set_normalizer(data_type, new_normalizer)

    @property
    def output_normalizer(self):
        return self._output_normalizer

    def normalize(self, X: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
        """
        Normalize the input dict in place and return it.

        Log scaling (if enabled) runs first. The statistics of the log-scaled variables were replaced by
        post-log mean/std in __init__, so the input normalizers must receive log values. Previously the log
        was applied to the already z-scaled data.
        """
        X = self._log_scaler_func(X)
        for input_type, rawX in X.items():
            X[input_type] = self._input_normalizer[input_type](rawX)
        return X

    def __call__(self, X: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
        return self.normalize(X)
--------------------------------------------------------------------------------
/climart/datamodules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RolnickLab/climart/6d12fa478de5db5209842e0369a2a24377079edd/climart/datamodules/__init__.py
--------------------------------------------------------------------------------
/climart/datamodules/pl_climart_datamodule.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from typing import Optional, List, Callable
3 |
4 | from pytorch_lightning import LightningDataModule
5 | from pytorch_lightning.utilities.types import EVAL_DATALOADERS
6 | from torch.utils.data import DataLoader
7 |
8 | from climart.data_loading.constants import TRAIN_YEARS, TEST_YEARS, OOD_PRESENT_YEARS, OOD_HISTORIC_YEARS, \
9 | OOD_FUTURE_YEARS, VAL_YEARS
10 | from climart.data_loading.h5_dataset import ClimART_HdF5_Dataset
11 | from climart.data_transform.normalization import Normalizer
12 | from climart.data_transform.transforms import AbstractTransform
13 | from climart.utils.utils import year_string_to_list
14 |
# Logger named after this module.
log = logging.getLogger(name=__name__)
16 |
17 |
18 | class ClimartDataModule(LightningDataModule):
19 | """
20 | ----------------------------------------------------------------------------------------------------------
21 | A DataModule implements 5 key methods:
22 | - prepare_data (things to do on 1 GPU/TPU, not on every GPU/TPU in distributed mode)
23 | - setup (things to do on every accelerator in distributed mode)
24 | - train_dataloader (the training dataloader)
25 | - val_dataloader (the validation dataloader(s))
26 | - test_dataloader (the test dataloader(s))
27 |
28 | This allows you to share a full dataset without explaining how to download,
29 | split, transform and process the data
30 |
31 | Read the docs:
32 | https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html
33 | """
34 |
35 | def __init__(
36 | self,
37 | exp_type: str,
38 | target_type: str,
39 | target_variable: str,
40 | data_dir: Optional[str] = None,
41 | train_years: str = "1999-2000",
42 | validation_years: str = "2005",
43 | predict_years: str = "2014",
44 | input_transform: Optional[AbstractTransform] = None,
45 | normalizer: Optional[Normalizer] = None,
46 | batch_size: int = 64,
47 | eval_batch_size: int = 512,
48 | num_workers: int = 0,
49 | pin_memory: bool = True,
50 | load_train_into_mem: bool = False,
51 | load_test_into_mem: bool = False,
52 | load_valid_into_mem: bool = True,
53 | test_main_dataset: bool = True,
54 | test_ood_1991: bool = True,
55 | test_ood_historic: bool = True,
56 | test_ood_future: bool = True,
57 | verbose: bool = True
58 | ):
59 | """
60 | Args:
61 | exp_type (str): 'pristine' or 'clear-sky'
62 | target_type (str): 'longwave" or 'shortwave'
63 | target_variable (str): 'fluxes' or 'heating-rate'
64 | data_dir (str or None): If str: A path to the data folder, if None: constants.DATA_DIR will be used.
65 | batch_size (int): Batch size for the training dataloader
66 | eval_batch_size (int): Batch size for the test and validation dataloader's
67 | num_workers (int): Dataloader arg for higher efficiency
68 | pin_memory (bool): Dataloader arg for higher efficiency
69 | test_main_dataset (bool): Whether to test and compute metrics on main test dataset (2007-14). Default: True
70 | test_ood_1991 (bool): Whether to test and compute metrics on OOD/anomaly test year 1991. Default: True
71 | test_ood_historic (bool): Whether to test and compute metrics on historic test years 1850-52. Default: True
72 | test_ood_future (bool): Whether to test and compute metrics on future test years 2097-99. Default: True
73 | seed (int): Used to seed the validation-test set split, such that the split will always be the same.
74 | """
75 | super().__init__()
76 | # The following makes all args available as, e.g., self.hparams.batch_size
77 | self.save_hyperparameters(ignore=["input_transform", "normalizer"])
78 | self.input_transform = input_transform # self.hparams.input_transform
79 | self.normalizer = normalizer
80 |
81 | self._data_train: Optional[ClimART_HdF5_Dataset] = None
82 | self._data_val: Optional[ClimART_HdF5_Dataset] = None
83 | self._data_test: Optional[List[ClimART_HdF5_Dataset]] = None
84 | self._data_predict: Optional[List[ClimART_HdF5_Dataset]] = None
85 | self._test_set_names: Optional[List[str]] = None
86 |
87 | def prepare_data(self):
88 | """Download data if needed. This method is called only from a single GPU.
89 | Do not use it to assign state (self.x = y)."""
90 | pass
91 |
92 | def setup(self, stage: Optional[str] = None):
93 | """Load data. Set internal variables: self._data_train, self._data_val, self._data_test."""
94 | dataset_kwargs = dict(
95 | data_dir=self.hparams.data_dir,
96 | exp_type=self.hparams.exp_type,
97 | target_type=self.hparams.target_type,
98 | target_variable=self.hparams.target_variable,
99 | verbose=self.hparams.verbose,
100 | input_transform=self.input_transform,
101 | normalizer=self.normalizer,
102 | )
103 |
104 | # Training set:
105 | if stage == "fit" or stage is None:
106 | # Get & check list of training/validation years
107 | train_years = year_string_to_list(self.hparams.train_years)
108 | assert all([y in TRAIN_YEARS for y in train_years]), f"All years in --train_years must be in {TRAIN_YEARS}!"
109 |
110 | self._data_train = ClimART_HdF5_Dataset(years=train_years, name='Train',
111 | load_h5_into_mem=self.hparams.load_train_into_mem,
112 | **dataset_kwargs)
113 | # Validation set
114 | if stage in ['fit', 'validate', None] and self.hparams.validation_years is not None:
115 | val_yrs = year_string_to_list(self.hparams.validation_years)
116 | assert all([y in VAL_YEARS for y in val_yrs]), f'All years in --validation_years must be in {VAL_YEARS}!'
117 | self._data_val = ClimART_HdF5_Dataset(years=val_yrs, name='Val',
118 | load_h5_into_mem=self.hparams.load_valid_into_mem,
119 | **dataset_kwargs)
120 | # Test sets:
121 | # - Main Present-day Test Set(s):
122 | # To compute metrics for each test year -> use a separate dataloader for each of the test years (2007-14).
123 | if stage == "test" or stage is None:
124 | dataset_kwargs["load_h5_into_mem"] = self.hparams.load_test_into_mem
125 | if self.hparams.test_main_dataset:
126 | test_sets = [
127 | ClimART_HdF5_Dataset(years=[test_year], name=f'Test_{test_year}', **dataset_kwargs)
128 | for test_year in TEST_YEARS
129 | ]
130 | else:
131 | test_sets = []
132 | log.info(" Main test dataset (2007-14) will not be tested on in this run.")
133 | # - OOD Test Sets:
134 | ood_test_sets = []
135 | if self.hparams.test_ood_1991:
136 | # 1991 OOD test year accounts for Mt. Pinatubo eruption: especially challenging for clear-sky conditions
137 | ood_test_sets += [ClimART_HdF5_Dataset(years=OOD_PRESENT_YEARS, name='OOD Test', **dataset_kwargs)]
138 | if self.hparams.test_ood_historic:
139 | ood_test_sets += [
140 | ClimART_HdF5_Dataset(years=OOD_HISTORIC_YEARS, name='Historic Test', **dataset_kwargs)]
141 | if self.hparams.test_ood_future:
142 | ood_test_sets += [ClimART_HdF5_Dataset(years=OOD_FUTURE_YEARS, name='Future Test', **dataset_kwargs)]
143 |
144 | self._data_test = test_sets + ood_test_sets
145 |
146 | # Prediction set:
147 | if stage == "predict" and self.hparams.predict_years is not None:
148 | dataset_kwargs["load_h5_into_mem"] = self.hparams.load_test_into_mem
149 | predict_years = year_string_to_list(self.hparams.predict_years)
150 | self._data_predict = [
151 | ClimART_HdF5_Dataset(years=[pred_year], name=f'Predict_{pred_year}', **dataset_kwargs)
152 | for pred_year in predict_years
153 | ]
154 |
@property
def test_set_names(self) -> List[str]:
    """Names of the test splits, in the same order as the test dataloaders built in ``setup``.

    Computed lazily on first access and cached in ``self._test_set_names``.
    """
    if self._test_set_names is not None:
        return self._test_set_names
    hp = self.hparams
    names: List[str] = [f'{test_year}' for test_year in TEST_YEARS] if hp.test_main_dataset else []
    optional_ood_sets = (
        (hp.test_ood_1991, 'OOD'),
        (hp.test_ood_historic, 'historic'),
        (hp.test_ood_future, 'future'),
    )
    for is_enabled, ood_name in optional_ood_sets:
        if is_enabled:
            names.append(ood_name)
    self._test_set_names = names
    return self._test_set_names
169 |
@property
def predict_years(self) -> List[int]:
    """Prediction years, parsed from the ``predict_years`` hyper-parameter string by ``year_string_to_list``."""
    raw_years_spec = self.hparams.predict_years
    return year_string_to_list(raw_years_spec)


@predict_years.setter
def predict_years(self, predict_years: str):
    """Store the raw year-specification string; it is parsed again on every read of the property."""
    self.hparams.predict_years = predict_years
177 |
def on_before_batch_transfer(self, batch, dataloader_idx):
    """No-op hook: hand the batch back unchanged (``dataloader_idx`` is ignored)."""
    unchanged_batch = batch
    return unchanged_batch
180 |
def on_after_batch_transfer(self, batch, dataloader_idx):
    """No-op hook: hand the batch back unchanged (``dataloader_idx`` is ignored)."""
    unchanged_batch = batch
    return unchanged_batch
183 |
184 | def _shared_dataloader_kwargs(self) -> dict:
185 | shared_kwargs = dict(num_workers=int(self.hparams.num_workers), pin_memory=self.hparams.pin_memory)
186 | return shared_kwargs
187 |
188 | def _shared_eval_dataloader_kwargs(self) -> dict:
189 | return dict(**self._shared_dataloader_kwargs(), batch_size=self.hparams.eval_batch_size, shuffle=False)
190 |
def train_dataloader(self):
    """Shuffled DataLoader over the training set using ``hparams.batch_size``."""
    loader_kwargs = self._shared_dataloader_kwargs()
    return DataLoader(self._data_train, batch_size=self.hparams.batch_size, shuffle=True, **loader_kwargs)
198 |
def val_dataloader(self):
    """Evaluation DataLoader over the validation set, or None when no validation set was built."""
    if self._data_val is None:
        return None
    return DataLoader(dataset=self._data_val, **self._shared_eval_dataloader_kwargs())
204 |
205 | def test_dataloader(self) -> List[DataLoader]:
206 | return [DataLoader(
207 | dataset=data_test_subset,
208 | **self._shared_eval_dataloader_kwargs()
209 | ) for data_test_subset in self._data_test]
210 |
def predict_dataloader(self) -> EVAL_DATALOADERS:
    """One evaluation DataLoader per prediction dataset built in ``setup(stage='predict')`` (same order)."""
    loaders = []
    for year_dataset in self._data_predict:
        eval_kwargs = self._shared_eval_dataloader_kwargs()
        loaders.append(DataLoader(dataset=year_dataset, **eval_kwargs))
    return loaders
--------------------------------------------------------------------------------
/climart/interface.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Optional
3 |
4 | import hydra
5 | import torch
6 | from omegaconf import DictConfig
7 |
8 | from climart.data_transform.normalization import Normalizer
9 | from climart.datamodules.pl_climart_datamodule import ClimartDataModule
10 | from climart.models.base_model import BaseModel
11 |
12 | """
13 | In this file you can find helper functions to avoid model/data loading and reloading boilerplate code
14 | """
15 |
16 |
def get_model(config: DictConfig, **kwargs) -> BaseModel:
    """Instantiate the model described by ``config.model``.

    When ``config`` has a (truthy) ``normalizer`` entry, a normalizer is instantiated as well and its
    ``output_normalizer`` is passed to the model as the ``output_normalizer`` keyword argument. This
    duplicates the normalizer built by ``get_datamodule``, but ensures that the model also has it when
    pytorch-lightning is not used. With pytorch-lightning the correct output normalizer is passed to the
    model before training anyway.

    Args:
        config (DictConfig): A OmegaConf config (e.g. produced by hydra yaml config file parsing)
        **kwargs: Extra keyword arguments forwarded to the model constructor.
    Returns:
        The model that you can directly use to train with pytorch-lightning
    """
    if config.get('normalizer'):
        normalizer: Normalizer = hydra.utils.instantiate(
            config.normalizer,
            datamodule_config=config.datamodule,
            _recursive_=False,
        )
        kwargs['output_normalizer'] = normalizer.output_normalizer
    return hydra.utils.instantiate(
        config.model,
        datamodule_config=config.datamodule,
        _recursive_=False,
        **kwargs
    )
40 |
41 |
def get_datamodule(config: DictConfig) -> ClimartDataModule:
    """Build the datamodule described by ``config.datamodule``.

    The normalization preprocessor (``config.normalizer``) is instantiated first. It is handed to the
    datamodule together with the model's ``input_transform`` config (None if the model config has none).

    Args:
        config (DictConfig): A OmegaConf config (e.g. produced by hydra yaml config file parsing)
    Returns:
        A datamodule that you can directly use to train pytorch-lightning models
    """
    normalizer: Normalizer = hydra.utils.instantiate(
        config.normalizer, _recursive_=False, datamodule_config=config.datamodule,
    )
    model_input_transform = config.model.get("input_transform")
    data_module: ClimartDataModule = hydra.utils.instantiate(
        config.datamodule, normalizer=normalizer, input_transform=model_input_transform,
    )
    return data_module
61 |
62 |
def get_model_and_data(config: DictConfig) -> (BaseModel, ClimartDataModule):
    """Build the datamodule, then a model that shares the datamodule's output normalizer.

    Args:
        config (DictConfig): A OmegaConf config (e.g. produced by hydra yaml config file parsing)
    Returns:
        A tuple of (model, datamodule) that you can directly use to train with pytorch-lightning
        (e.g. checkpointing the best model w.r.t. a small validation set with the ModelCheckpoint
        callback), via:
            trainer.fit(model=model, datamodule=datamodule)
    """
    datamodule = get_datamodule(config)
    shared_output_normalizer = datamodule.normalizer.output_normalizer
    model: BaseModel = hydra.utils.instantiate(
        config.model,
        output_normalizer=shared_output_normalizer,
        datamodule_config=config.datamodule,
        _recursive_=False,
    )
    return model, datamodule
80 |
81 |
def reload_model_from_config_and_ckpt(config: DictConfig, model_path: str, load_datamodule: bool = False,
                                      map_location='cpu'):
    """Rebuild model (and datamodule) from ``config`` and load the weights stored in a checkpoint.

    Args:
        config (DictConfig): Config used to build the model/datamodule (see ``get_model_and_data``).
        model_path: Path to a checkpoint file holding a dict with a ``'state_dict'`` entry.
        load_datamodule: If True, also return the datamodule built from ``config``.
        map_location: Forwarded to ``torch.load``. It is 'cpu' by default so that checkpoints saved on a
            GPU can also be reloaded on CPU-only machines. ``load_state_dict`` copies the values into the
            freshly built model's parameters, so the resulting model is the same as without remapping.
            Pass None to restore ``torch.load``'s default device placement.
    Returns:
        The model, or a ``(model, datamodule)`` tuple when ``load_datamodule`` is True.
    """
    model, data_module = get_model_and_data(config)
    checkpoint = torch.load(model_path, map_location=map_location)
    model.load_state_dict(checkpoint['state_dict'])
    if load_datamodule:
        return model, data_module
    return model
90 |
91 |
92 |
--------------------------------------------------------------------------------
/climart/models/CNNs/CNN.py:
--------------------------------------------------------------------------------
1 | from typing import Sequence, Dict, Union
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torch import Tensor
7 |
8 | from einops import rearrange, repeat
9 |
10 | from climart.models.base_model import BaseModel
11 | from climart.utils.utils import get_activation_function, get_normalization_layer
12 | from climart.models.modules.additional_layers import Multiscale_Module, GAP, SE_Block
13 |
14 |
class CNN_Net(BaseModel):
    """1D convolutional network with either an MLP head or a global-average-pooling head.

    The feature extractor is a stack of blocks, one per entry of ``hidden_dims``:
    Conv1d -> [SE_Block] -> [normalization] -> activation -> Dropout.

    Heads:
      - ``gap=False`` (default): flatten the conv features and apply the MLP head ``self.ll``.
      - ``gap=True``: apply ``self.global_average`` to the conv features. ``self.ll`` is still built (so
        that checkpoints keep the same parameter keys), but it is not used in ``forward``.

    Args:
        hidden_dims: Output channels of each Conv1d layer. The first layer's input channels are
            ``self.input_transform.output_dim``.
        dilation: Dilation of every Conv1d layer.
        net_normalization: Name passed to ``get_normalization_layer``; ``'none'`` disables normalization.
            Conv biases are disabled when it is ``'batch_norm'``.
        kernels: Kernel size of each Conv1d layer (needs at least ``len(hidden_dims)`` entries).
        strides: Stride of each Conv1d layer (needs at least ``len(hidden_dims)`` entries).
        gap: Use the global-average-pooling head instead of the MLP head.
        se_block: Insert an ``SE_Block`` after every Conv1d layer.
        activation_function: Name passed to ``get_activation_function``.
        dropout: Dropout probability after every conv block and inside the MLP head.
    """

    def __init__(self,
                 hidden_dims: Sequence[int],
                 dilation: int = 1,
                 net_normalization: str = 'none',
                 kernels: Sequence[int] = (20, 10, 5),
                 strides: Sequence[int] = (2, 1, 1),  # 221
                 gap: bool = False,
                 se_block: bool = False,
                 activation_function: str = 'relu',
                 dropout: float = 0.0,
                 *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.save_hyperparameters()
        self.output_dim = self.raw_output_dim
        input_dim = self.input_transform.output_dim
        # Channel sizes of the conv stack, including the input channels as the first entry.
        self.channel_list = [input_dim] + list(hidden_dims)

        self.linear_in_shape = 10
        self.use_linear = gap  # NB: True selects the GAP head (the MLP head is used when False).
        self.ratio = 16  # SE_Block reduction ratio
        self.kernel_list = list(kernels)
        self.stride_list = list(strides)
        self.global_average = GAP()

        feat_cnn_modules = []
        for i in range(len(self.channel_list) - 1):
            out = self.channel_list[i + 1]
            feat_cnn_modules += [nn.Conv1d(in_channels=self.channel_list[i],
                                           out_channels=out, kernel_size=self.kernel_list[i],
                                           stride=self.stride_list[i],
                                           # batch norm has its own shift, so the conv bias is redundant there
                                           bias=self.hparams.net_normalization != 'batch_norm',
                                           dilation=self.hparams.dilation)]
            if se_block:
                feat_cnn_modules.append(SE_Block(out, self.ratio))
            if self.hparams.net_normalization != 'none':
                feat_cnn_modules += [get_normalization_layer(self.hparams.net_normalization, out)]
            feat_cnn_modules += [get_activation_function(activation_function, functional=False)]
            # TODO: Need to add adaptive pooling with arguments
            feat_cnn_modules += [nn.Dropout(dropout)]

        self.feat_cnn = nn.Sequential(*feat_cnn_modules)

        linear_layers = []
        if not self.use_linear:
            # NOTE(review): hard-coded in_features. This assumes the flattened conv output has
            # int(C_last / 100) * 400 features, i.e. a fixed spatial length for the input.
            # TODO: derive from the actual input shape.
            linear_layers.append(nn.Linear(int(self.channel_list[-1] / 100) * 400, 256, bias=True))
        linear_layers.append(get_activation_function(activation_function, functional=False))
        linear_layers.append(nn.Dropout(dropout))
        linear_layers.append(nn.Linear(256, self.output_dim, bias=True))
        self.ll = nn.Sequential(*linear_layers)

    def forward(self, X: Union[Tensor, Dict[str, Tensor]]) -> Tensor:
        """
        Args:
            X: Tensor fed directly to the first Conv1d layer. Presumably this is the output of
               ``self.input_transform`` with shape (batch, channels, length) — TODO confirm.

        Returns:
            MLP head: a (batch, output_dim) tensor.
            GAP head: the pooled features with dimension 2 squeezed.
        """
        X = self.feat_cnn(X)
        if self.use_linear:
            # GAP head: presumably GAP keeps a size-1 trailing dim, which is dropped here.
            return self.global_average(X).squeeze(2)
        # MLP head: (b, channels, length) -> (b, channels * length) -> (b, output_dim).
        # The result is already 2-D; squeezing dim 2 (as before) raises an IndexError.
        X = rearrange(X, 'b f c -> b (f c)')
        return self.ll(X)
85 |
86 |
class CNN_Multiscale(BaseModel):
    """1D-CNN variant that can build its own conv input from a dict of 'levels', 'layers' and 'globals'.

    NOTE(review): ``self.pyramid`` (a Multiscale_Module) is constructed in ``__init__`` but never called
    in ``forward``. ``gap`` (stored as ``self.use_linear``), ``self.stride``, ``self.multiscale_in_shape``
    and ``self.channels_per_layer`` (beyond building the pyramid) are not used in ``forward`` either.
    """
    def __init__(self,
                 hidden_dims: Sequence[int],
                 out_dim: int,
                 dilation: int = 1,
                 gap: bool = False,
                 se_block: bool = False,
                 use_act: bool = False,
                 net_normalization: str = 'none',
                 activation_function: str = 'relu',
                 dropout: float = 0.0,
                 *args, **kwargs):
        """
        Args:
            hidden_dims: Channel sizes of the conv stack *including* the input channels as the first entry
                (one Conv1d per consecutive pair). At most 4 entries are supported, because the hard-coded
                kernel/stride lists have 3 entries.
            out_dim: Output size of the final linear layer.
            dilation: Only passed to the (unused) Multiscale_Module as ``dil_rate``.
            gap: Stored as ``self.use_linear``; not used in ``forward``.
            se_block: Insert an ``SE_Block`` after every Conv1d layer.
            use_act: Only passed to the (unused) Multiscale_Module.
            net_normalization: Name passed to ``get_normalization_layer``; ``'none'`` disables normalization.
            activation_function: Name passed to ``get_activation_function``.
            dropout: Dropout probability after every conv block.
        """
        # super().__init__(channels_list, out_dim, column_handler, projection, net_normalization,
        #                  gap, se_block, activation_function, dropout, *args, **kwargs)
        super().__init__(*args, **kwargs)
        self.save_hyperparameters()
        self.channels_per_layer = 200
        self.linear_in_shape = 10
        self.multiscale_in_shape = 10
        self.ratio = 16  # SE_Block reduction ratio
        # Hard-coded per-layer kernel sizes/strides (index i is used for the i-th Conv1d).
        self.kernel_list = [6, 4, 4]
        self.stride_list = [2, 1, 1]
        self.stride = 1
        self.use_linear = gap
        self.out_size = out_dim
        self.channel_list = hidden_dims

        feat_cnn_modules = []
        for i in range(len(self.channel_list) - 1):
            out = self.channel_list[i + 1]
            # Conv bias is disabled when batch norm follows.
            feat_cnn_modules.append(nn.Conv1d(in_channels=self.channel_list[i],
                                              out_channels=out, kernel_size=self.kernel_list[i],
                                              stride=self.stride_list[i],
                                              bias=self.hparams.net_normalization != 'batch_norm'))
            if se_block:
                feat_cnn_modules.append(SE_Block(out, self.ratio))
            if self.hparams.net_normalization != 'none':
                feat_cnn_modules += [get_normalization_layer(self.hparams.net_normalization, self.channel_list[i + 1])]
            feat_cnn_modules.append(get_activation_function(activation_function, functional=False))
            # TODO: Need to add adaptive pooling with arguments
            feat_cnn_modules.append(nn.Dropout(dropout))

        self.feat_cnn = nn.Sequential(*feat_cnn_modules)
        # NOTE(review): rebinds the local `kwargs`; the pyramid built here is never used in forward().
        kwargs = {'in_channels': self.channel_list[-1], 'channels_per_layer': self.channels_per_layer,
                  'out_shape': self.linear_in_shape, 'dil_rate': self.hparams.dilation, 'use_act': use_act}
        self.pyramid = Multiscale_Module(**kwargs)

        # NOTE(review): `input_dim` is not used below (left over from the commented-out shape inference).
        input_dim = [self.channel_list[0], self.linear_in_shape]
        # TODO: Need to pass input shape as argument
        # linear_input_shape = functools.reduce(operator.mul, list(self.feat_cnn(torch.rand(1, *input_dim)).shape))
        linear_layers = []
        # linear_layers.append(nn.Linear(int(self.channel_list[-1]/100)*1000, 300, bias=True))
        # NOTE(review): in_features hard-coded to 2800. This assumes the flattened feat_cnn output
        # has exactly 2800 features — TODO confirm / derive from the input shape.
        linear_layers.append(nn.Linear(2800, 256, bias=True))
        linear_layers.append(get_activation_function(activation_function, functional=False))
        linear_layers.append(nn.Linear(256, self.out_size, bias=True))
        self.ll = nn.Sequential(*linear_layers)

    def forward(self, X: Union[Tensor, Dict[str, Tensor]]) -> Tensor:
        """
        input:
            Either a ready conv input tensor, or a dict with keys 'levels', 'layers' and 'globals':
            'levels' and 'layers' are 3-D (b, c, f) tensors and 'globals' is a 2-D (b, f) tensor.
            'layers' is reflect-padded by one entry at the front of dim 1, presumably so that it has as
            many entries as 'levels' (TODO confirm). 'globals' is repeated along that dim. All three are
            concatenated on the feature axis and transposed to (b, f, c) for Conv1d.
        Returns:
            The output of the linear head, with dim 1 squeezed (only has an effect when out_dim == 1).
        """
        if isinstance(X, dict):
            X_levels = X['levels']

            X_layers = rearrange(F.pad(rearrange(X['layers'], 'b c f -> () b c f'), (0, 0, 1, 0), \
                                       mode='reflect'), '() b c f -> b c f')
            X_global = repeat(X['globals'], 'b f -> b c f', c=X_levels.shape[1])

            X = torch.cat((X_levels, X_layers, X_global), -1)
            X = rearrange(X, 'b c f -> b f c')

        X = self.feat_cnn(X)
        X = rearrange(X, 'b f c -> b (f c)')
        X = self.ll(X)

        return X.squeeze(1)
165 |
--------------------------------------------------------------------------------
/climart/models/CNNs/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RolnickLab/climart/6d12fa478de5db5209842e0369a2a24377079edd/climart/models/CNNs/__init__.py
--------------------------------------------------------------------------------
/climart/models/GraphNet/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RolnickLab/climart/6d12fa478de5db5209842e0369a2a24377079edd/climart/models/GraphNet/__init__.py
--------------------------------------------------------------------------------
/climart/models/GraphNet/constants.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 | from typing import Dict
3 |
4 | from torch import Tensor
5 |
6 |
# Keys of the graph components in a graph dict.
GLOBALS = "globals"
NODES = "nodes"
EDGES = "edges"
GRAPH_COMPONENTS = [GLOBALS, NODES, EDGES]

GraphComponentToTensor = Dict[str, Tensor]


class AggregationTypes(Enum):
    """The three aggregations performed by a Graph Network block."""
    AGG_E_TO_N = "edge_to_node_aggregation"
    AGG_E_TO_U = "edge_to_global_aggregation"
    AGG_N_TO_U = "node_to_global_aggregation"


# Defined outside the Enum body on purpose: any plain assignment inside an Enum class becomes a member.
# It would then show up when iterating over AggregationTypes as a bogus fourth aggregation.
AGGREGATORS = [aggregation.value for aggregation in AggregationTypes]


N_NODES = "n_nodes"
N_EDGES = "n_edges"
SPATIAL_DIM = 1  # axis of node/edge entries in (batch, #nodes-or-#edges, #features) tensors
25 |
--------------------------------------------------------------------------------
/climart/models/GraphNet/graph_network.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Sequence, Optional, Union
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | from torch import Tensor
6 |
7 | from climart.data_transform.transforms import AbstractGraphTransform
8 | from climart.models.GraphNet.constants import AggregationTypes, NODES, EDGES, GLOBALS, GRAPH_COMPONENTS, SPATIAL_DIM
9 | from climart.models.GraphNet.graph_network_block import GraphNetBlock
10 | from climart.models.base_model import BaseModel
11 | from climart.models.modules.mlp import MLP
12 |
13 |
class GraphNetwork(BaseModel):
    """A stack of GraphNetBlocks sharing one fixed graph structure (senders/receivers)."""

    def __init__(self,
                 hidden_dims: Sequence[int],
                 input_transform: AbstractGraphTransform,
                 readout_which_output: Optional[str] = NODES,
                 update_mlp_n_layers: int = 1,
                 aggregator_funcs: Union[str, Dict[AggregationTypes, int]] = 'sum',
                 net_normalization: str = 'layer_norm',
                 residual: Union[bool, Dict[str, bool]] = True,
                 activation_function: str = 'Gelu',
                 output_activation_function: Optional[str] = None,
                 output_net_normalization: bool = True,
                 dropout: float = 0.0,
                 *args, **kwargs):
        """
        Args:
            hidden_dims: Output dim of each GraphNetBlock (one block per entry; at least one entry).
            input_transform: Graph input transform. Its ``output_dim`` is the input dim of the first block,
                and its ``get_edge_idxs()`` provides the senders/receivers shared by all blocks.
            readout_which_output: Which graph part ``forward`` returns (default: NODES),
                one of {EDGES, NODES, GLOBALS, 'graph', None}.
                If None or 'graph', the whole graph dict is returned.
            update_mlp_n_layers: Number of hidden layers of every update MLP (>= 1).
            aggregator_funcs, net_normalization, residual, activation_function, dropout:
                Passed on to every GraphNetBlock.
            output_activation_function, output_net_normalization: Output activation / output normalization
                of the *last* block's update MLPs. Earlier blocks use ``activation_function`` and True.
        Raises:
            ValueError: If ``readout_which_output`` is not one of the supported values.
        """
        super().__init__(input_transform=input_transform, *args, **kwargs)
        self.save_hyperparameters(ignore="verbose_mlp")
        assert len(self.hparams.hidden_dims) >= 1
        assert update_mlp_n_layers >= 1
        self.input_transform: AbstractGraphTransform = self.input_transform
        in_dim = self.input_transform.output_dim

        senders, receivers = self.input_transform.get_edge_idxs()
        dims = [in_dim] + list(hidden_dims)
        last_block_idx = len(dims) - 1
        gn_layers = []
        for i in range(1, len(dims)):
            is_last_block = i == last_block_idx
            gn_layers.append(GraphNetBlock(
                in_dims=dims[i - 1],
                out_dims=dims[i],
                senders=senders,
                receivers=receivers,
                n_layers=update_mlp_n_layers,
                residual=residual,
                net_norm=net_normalization,
                activation=activation_function,
                dropout=dropout,
                output_normalization=output_net_normalization if is_last_block else True,
                output_activation_function=output_activation_function if is_last_block else activation_function,
                aggregator_funcs=aggregator_funcs,
            ))

        self.layers = nn.ModuleList(gn_layers)  # of GraphNetBlock
        self.output_type = readout_which_output
        if self.output_type not in [NODES, EDGES, GLOBALS, 'graph', None]:
            raise ValueError("Unsupported argument for GraphNetwork `output_type`", readout_which_output)
        if hasattr(self.input_transform, "n_edges"):
            err_msg = f"GN inferred {self.n_edges} edges, but input_transform refers to {self.input_transform.n_edges}"
            assert self.n_edges == self.input_transform.n_edges, err_msg

    @property
    def n_edges(self):
        """Number of edges of the (shared) graph structure, as stored by the first block."""
        return self.layers[0].n_edges

    def update_graph_structure(self, senders: Sequence[int], receivers: Sequence[int]) -> None:
        """Replace the sender/receiver edge indices of every block."""
        for layer in self.layers:
            layer.update_graph_structure(senders, receivers)

    def forward(self, input: Dict[str, Tensor]):
        """
        input:
            Graph dict with NODES, EDGES and GLOBALS tensors (the keys read by GraphNetBlock).
        Returns:
            The selected component flattened to (batch, -1), or the whole updated graph dict when
            ``self.output_type`` is None or 'graph'.
        """
        graph_new = self.update_graph(input)
        if self.output_type is None or self.output_type == 'graph':
            return graph_new
        graph_component = graph_new[self.output_type]
        return graph_component.reshape(graph_component.shape[0], -1)

    def update_graph(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """Run the input graph through all blocks in order and return the final graph dict."""
        graph = input
        for graph_net_block in self.layers:
            graph = graph_net_block(graph)
        return graph
101 |
102 |
class GN_withReadout(GraphNetwork):
    """GraphNetwork followed by graph-level pooling and an MLP readout head.

    Args:
        readout_which_output: Graph component fed to the readout head (one of GRAPH_COMPONENTS, default NODES).
        graph_pooling: How node/edge embeddings are pooled over the node/edge axis before the readout MLP:
            'sum', 'mean', 'max', or the concatenation of mean and sum ('sum+mean', 'mean+sum', 'mean&sum',
            'sum&mean'). It is ignored for a GLOBALS readout.
        *args, **kwargs: Forwarded to GraphNetwork. Unless ``output_activation_function`` is given explicitly,
            the last GN block uses ``activation_function`` (or 'gelu' if that is not given either).
    Raises:
        ValueError: If ``graph_pooling`` is unsupported for a NODES/EDGES readout.
    """

    _MEAN_AND_SUM_POOLINGS = ('sum+mean', 'mean+sum', 'mean&sum', 'sum&mean')
    _SUPPORTED_POOLINGS = ('sum', 'mean', 'max') + _MEAN_AND_SUM_POOLINGS

    def __init__(self,
                 readout_which_output=NODES,
                 graph_pooling: str = 'mean',
                 *args, **kwargs):
        # Popping (instead of always passing a second value) lets callers override the output activation
        # without a "got multiple values for keyword argument" TypeError.
        output_activation_function = kwargs.pop(
            'output_activation_function', kwargs.get('activation_function', 'gelu'))
        super().__init__(output_activation_function=output_activation_function, *args, **kwargs)
        assert readout_which_output in GRAPH_COMPONENTS
        self.readout_which_output = readout_which_output
        self.graph_pooling = graph_pooling.lower()

        pools_over_entities = readout_which_output in [NODES, EDGES]
        if pools_over_entities and self.graph_pooling not in self._SUPPORTED_POOLINGS:
            raise ValueError('Unsupported readout operation', self.graph_pooling)

        self.mlp_input_dim = self.hparams.hidden_dims[-1]
        if pools_over_entities and self.graph_pooling in self._MEAN_AND_SUM_POOLINGS:
            self.mlp_input_dim = self.mlp_input_dim * 2  # mean and sum embeddings are concatenated

        net_norm = self.hparams.net_normalization
        # Instance norm is swapped for layer norm in the readout head; guard against a None normalization.
        readout_net_norm = 'layer_norm' if net_norm and 'inst' in net_norm else net_norm
        self.readout_mlp = MLP(
            input_dim=self.mlp_input_dim,
            hidden_dims=[int((self.mlp_input_dim + self.raw_output_dim) / 2)],
            output_dim=self.raw_output_dim,
            dropout=kwargs.get('dropout', 0.0),
            activation_function=self.hparams.activation_function,
            net_normalization=readout_net_norm,
            output_normalization=False, output_activation_function=None, out_layer_bias_init=self.out_layer_bias_init,
            name='GraphNet_Readout_MLP'
        )

    def forward(self, input: Dict[str, torch.Tensor]):
        """Update the graph, pool the chosen component to one vector per graph and apply the readout MLP."""
        final_graph = self.update_graph(input)
        embeddings = final_graph[self.readout_which_output]

        if self.readout_which_output == GLOBALS:
            g_emb = embeddings
        elif self.graph_pooling == 'sum':
            g_emb = torch.sum(embeddings, dim=SPATIAL_DIM)
        elif self.graph_pooling == 'mean':
            g_emb = torch.mean(embeddings, dim=SPATIAL_DIM)
        elif self.graph_pooling == 'max':
            g_emb, _ = torch.max(embeddings, dim=SPATIAL_DIM)  # returns (values, indices)
        elif self.graph_pooling in self._MEAN_AND_SUM_POOLINGS:
            xmean = torch.mean(embeddings, dim=SPATIAL_DIM)
            xsum = torch.sum(embeddings, dim=SPATIAL_DIM)  # (batch-size, out-dim)
            g_emb = torch.cat((xmean, xsum), dim=-1)  # (batch-size, 2 * out-dim)
        else:
            raise ValueError('Unsupported readout operation', self.graph_pooling)

        # After graph pooling: (batch-size, mlp_input_dim)
        return self.readout_mlp(g_emb)
160 |
161 |
--------------------------------------------------------------------------------
/climart/models/GraphNet/graph_network_block.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | from functools import partial
3 | from typing import Dict, Union, Optional, Sequence
4 | from einops import rearrange, repeat
5 |
6 | import numpy as np
7 | import torch
8 | import torch.nn as nn
9 | import torch_scatter
10 | from climart.models.modules.mlp import MLP
11 | import climart.models.GraphNet.constants as gn_constants
12 | from climart.models.GraphNet.constants import GraphComponentToTensor, AggregationTypes, NODES, EDGES, GLOBALS
13 |
14 |
class GraphNetBlock(nn.Module):
    """A single Graph Network block that updates edge, node and global features, in that order.

    Each component has its own update MLP:
      - edges:   MLP([edge, sender node, receiver node, globals]) for every edge
      - nodes:   MLP([aggregated edges received by the node, node, globals]) for every node
      - globals: MLP([aggregated edges, aggregated nodes, globals])

    Args:
        in_dims, out_dims: Feature dims per component (an int applies to all components).
            Dicts passed by the caller are copied, never modified.
        senders, receivers: Per-edge node indices (equal lengths = number of edges).
        n_layers: Number of hidden layers of every update MLP.
        use_edge_features, use_global_features: Drop edges / globals from the block.
            NOTE(review): forward() still uses edge and global features unconditionally, so setting
            either flag to False currently fails at forward time.
        residual: Add the old features to the updated ones. Either a bool for all components, or a dict
            per component (missing components default to False).
        net_norm, activation, dropout, output_normalization, output_activation_function:
            Passed to every update MLP.
        aggregator_funcs: 'sum', 'mean' or 'max' for all aggregations, or a dict keyed by AggregationTypes.
    Raises:
        ValueError: For an unsupported aggregation name.
    """

    def __init__(self,
                 in_dims: Union[int, Dict[str, int]],
                 out_dims: Union[int, Dict[str, int]],
                 senders: Sequence[int],
                 receivers: Sequence[int],
                 n_layers: int = 1,
                 use_edge_features: bool = True,
                 use_global_features: bool = True,
                 residual: Union[bool, Dict[str, bool]] = False,
                 net_norm: str = 'none',
                 activation: str = 'relu',
                 dropout: float = 0,
                 output_normalization: bool = True,
                 output_activation_function: Optional[str] = None,
                 aggregator_funcs: Union[str, Dict[AggregationTypes, int]] = 'sum',
                 ):
        super().__init__()
        # Work on private copies. Previously `self.components.remove(...)` modified the module-level
        # GRAPH_COMPONENTS list, and caller-provided dim dicts were overwritten in place.
        if isinstance(in_dims, int):
            in_dims = {c: in_dims for c in gn_constants.GRAPH_COMPONENTS}
        else:
            in_dims = dict(in_dims)
        if isinstance(out_dims, int):
            out_dims = {c: out_dims for c in gn_constants.GRAPH_COMPONENTS}
        else:
            out_dims = dict(out_dims)

        self.components = list(gn_constants.GRAPH_COMPONENTS)
        self.use_edge_features = use_edge_features
        self.use_global_features = use_global_features

        if not use_edge_features:
            in_dims[EDGES] = out_dims[EDGES] = 0
            self.components.remove(EDGES)
        if not use_global_features:
            in_dims[GLOBALS] = out_dims[GLOBALS] = 0
            self.components.remove(GLOBALS)

        n_feats_e = in_dims[EDGES]
        n_feats_n = in_dims[NODES]
        n_feats_u = in_dims[GLOBALS]
        self._n_edges = None
        self.update_graph_structure(senders, receivers)
        if isinstance(residual, dict):
            self.residual = {c: bool(residual.get(c, False)) for c in self.components}
        else:
            self.residual = {c: residual for c in self.components}

        # Input dim of each update MLP = size of the concatenation it receives in forward().
        mlp_in_dims = {
            EDGES: 2 * n_feats_n + n_feats_e + n_feats_u,
            NODES: n_feats_n + out_dims[EDGES] + n_feats_u,
            GLOBALS: out_dims[NODES] + out_dims[EDGES] + n_feats_u
        }

        update_funcs = OrderedDict()
        for component in self.components:
            in_dim = mlp_in_dims[component]
            out_dim = out_dims[component]
            # NOTE(review): this compares the *concatenated* MLP input dim (not the component's own feature
            # dim) with the output dim, so residual connections are effectively never enabled.
            # Kept as is so that existing checkpoints reproduce their outputs.
            if in_dim != out_dim:
                self.residual[component] = False

            hdim = int((in_dim + out_dim) / 2)
            update_funcs[component] = MLP(
                input_dim=in_dim,
                hidden_dims=[hdim for _ in range(n_layers)],
                output_dim=out_dim,
                activation_function=activation,
                net_normalization=net_norm,
                dropout=dropout,
                output_normalization=output_normalization,
                output_activation_function=output_activation_function,
                out_layer_bias_init=None,
                name=f'GN_{component}_update_MLP',
            )
        self.update_funcs = nn.ModuleDict(update_funcs)

        self.aggregator_funcs = OrderedDict()
        # Iterate over the three aggregations explicitly, so that only these are registered even if
        # AggregationTypes defines further members.
        for aggregation in (AggregationTypes.AGG_E_TO_N, AggregationTypes.AGG_E_TO_U, AggregationTypes.AGG_N_TO_U):
            agg_name = aggregator_funcs[aggregation] if isinstance(aggregator_funcs, dict) else aggregator_funcs
            self.aggregator_funcs[aggregation] = self._make_aggregator(
                agg_name.lower(), aggregation, gn_constants.SPATIAL_DIM
            )

    @staticmethod
    def _make_aggregator(agg_name: str, aggregation: AggregationTypes, agg_dim: int):
        """Return the reduction callable for one aggregation (scatter-based for edge->node)."""
        edge_to_node = aggregation == AggregationTypes.AGG_E_TO_N
        if agg_name == 'sum':
            # edge->node returns a (batch-size, #nodes, #edge-feats) tensor
            if edge_to_node:
                return partial(torch_scatter.scatter_sum, dim=agg_dim)
            return partial(torch.sum, dim=agg_dim)
        if agg_name == 'mean':
            if edge_to_node:
                return partial(torch_scatter.scatter_mean, dim=agg_dim)
            return partial(torch.mean, dim=agg_dim)
        if agg_name == 'max':
            if edge_to_node:
                def max_scatter_partial(x, index):
                    return torch_scatter.scatter_max(x, dim=agg_dim, index=index)[0]
                return max_scatter_partial
            return lambda x: torch.max(x, dim=agg_dim)[0]  # returns (values, indices)[0]
        raise ValueError('Unsupported aggregation operation')

    @property
    def n_edges(self):
        """Number of edges implied by the sender/receiver index lists."""
        return self._n_edges

    def update_graph_structure(self, senders: Sequence[int], receivers: Sequence[int]) -> None:
        """Set the per-edge sender/receiver node indices (stored as long buffers)."""
        if not torch.is_tensor(senders):
            senders = torch.from_numpy(senders) if isinstance(senders, np.ndarray) else torch.tensor(senders)
        if not torch.is_tensor(receivers):
            receivers = torch.from_numpy(receivers) if isinstance(receivers, np.ndarray) else torch.tensor(receivers)
        # Registered as buffers so that they are moved to the correct device together with the module.
        self.register_buffer('_senders', senders.long())
        self.register_buffer('_receivers', receivers.long())
        assert len(self._receivers) == len(self._senders), "Sender and receiver must both have #edges indices"
        self._n_edges = len(self._senders)

    def forward(self, graph: GraphComponentToTensor) -> GraphComponentToTensor:
        """Update edges, then nodes, then globals.

        Args:
            graph: Dict with EDGES (b, #edges, #edge-feats), NODES (b, #nodes, #node-feats)
                and GLOBALS (b, #global-feats) tensors.
        Returns:
            Dict with the updated tensor of every component in ``self.components``.
        Raises:
            ValueError: If the edge features do not match the number of edges of the graph structure.
        """
        if self.use_edge_features:
            edge_feats_old = graph[EDGES]  # (b, #edges, #edge-feats)
            if edge_feats_old.shape[gn_constants.SPATIAL_DIM] != self._n_edges:
                raise ValueError(f"Edge features imply {edge_feats_old.shape[gn_constants.SPATIAL_DIM]} edges, "
                                 f"while sender and receiver lists imply {self._n_edges} edges.")
        node_feats_old = graph[NODES]  # (b, #nodes, #node-feats)
        n_nodes = node_feats_old.shape[gn_constants.SPATIAL_DIM]
        if self.use_global_features:
            global_feats_old = graph[GLOBALS]  # (b, #global-feats)
            batch_size, _ = global_feats_old.shape

        out = {c: None for c in self.components}

        # ----------------------- Update edges
        sender_feats = node_feats_old.index_select(index=self._senders, dim=gn_constants.SPATIAL_DIM)
        receiver_feats = node_feats_old.index_select(index=self._receivers, dim=gn_constants.SPATIAL_DIM)

        mlp_input_e = torch.cat([
            edge_feats_old,  # (b, #edges, #edge-feats)
            sender_feats,  # (b, #edges, #node-feats)
            receiver_feats,  # (b, #edges, #node-feats)
            repeat(global_feats_old, 'b g -> b e g', e=self._n_edges)  # (b, #edges, #global-feats)
        ], dim=-1)
        mlp_input_e = rearrange(mlp_input_e, 'b e d1 -> (b e) d1')

        out[EDGES] = self.update_funcs[EDGES](mlp_input_e)
        out[EDGES] = rearrange(out[EDGES], '(b e) d2 -> b e d2', b=batch_size, e=self._n_edges)
        if self.residual[EDGES]:
            out[EDGES] += edge_feats_old

        # ----------------------- Update nodes
        # Aggregate the updated features of the edges received by each node.
        aggregated_edge_feats_for_node = self.aggregator_funcs[AggregationTypes.AGG_E_TO_N](
            out[EDGES], index=self._receivers
        )

        mlp_input_n = torch.cat([
            aggregated_edge_feats_for_node,  # (b, #nodes, #edge-feats)
            node_feats_old,  # (b, #nodes, #node-feats)
            repeat(global_feats_old, 'b g -> b n g', n=n_nodes)  # (b, #nodes, #global-feats)
        ], dim=-1)
        mlp_input_n = rearrange(mlp_input_n, 'b n d1 -> (b n) d1')

        out[NODES] = self.update_funcs[NODES](mlp_input_n)
        out[NODES] = rearrange(out[NODES], '(b n) d2 -> b n d2', b=batch_size, n=n_nodes)
        if self.residual[NODES]:
            out[NODES] += node_feats_old

        # ----------------------- Update global features
        aggregated_edge_feats_for_global = self.aggregator_funcs[AggregationTypes.AGG_E_TO_U](out[EDGES])
        aggregated_node_feats_for_global = self.aggregator_funcs[AggregationTypes.AGG_N_TO_U](out[NODES])

        mlp_input_u = torch.cat([
            aggregated_edge_feats_for_global,  # (b, #edge-feats)
            aggregated_node_feats_for_global,  # (b, #node-feats)
            global_feats_old  # (b, #global-feats)
        ], dim=-1)
        out[GLOBALS] = self.update_funcs[GLOBALS](mlp_input_u)
        if self.residual[GLOBALS]:
            out[GLOBALS] += global_feats_old

        return out
206 |
207 |
208 |
if __name__ == '__main__':
    # Smoke test: one GraphNetBlock on a random batch.
    # Edge i connects node i % num_nodes to node (i + 1) % num_nodes.
    batch_size = 64
    num_edges, num_nodes = 50, 49
    edge_dim, node_dim, glob_dim = 2, 22, 82
    edge_feats = torch.randn((batch_size, num_edges, edge_dim))
    node_feats = torch.randn((batch_size, num_nodes, node_dim))
    glob_feats = torch.randn((batch_size, glob_dim))

    edge_ids = torch.arange(num_edges)
    senders = edge_ids % num_nodes
    receivers = (edge_ids + 1) % num_nodes

    block = GraphNetBlock(
        in_dims={NODES: node_dim, EDGES: edge_dim, GLOBALS: glob_dim},
        out_dims=128,
        senders=senders,
        receivers=receivers,
    )

    updated = block({NODES: node_feats, EDGES: edge_feats, GLOBALS: glob_feats})
    print([name + str(tensor.shape) for name, tensor in updated.items()])
228 |
--------------------------------------------------------------------------------
/climart/models/MLP.py:
--------------------------------------------------------------------------------
1 | from typing import Sequence, Optional, Dict, Union
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | from omegaconf import DictConfig
7 | from torch import Tensor
8 |
9 | from climart.models.base_model import BaseModel
10 | from climart.models.modules.mlp import MLP
11 |
12 |
class ClimartMLP(BaseModel):
    """MLP baseline operating on flattened inputs.

    The MLP input dim is ``self.raw_input_dim`` when that is not a dict. When it is a dict, every key must
    also be a key of ``self.raw_spatial_dim``. The flattened dim is then the sum over keys of
    ``raw_input_dim[k] * max(1, raw_spatial_dim[k])``.

    Args:
        hidden_dims: Widths of the hidden layers of the inner MLP.
        datamodule_config: Forwarded to BaseModel.
        net_normalization, activation_function, dropout, residual, output_normalization:
            Forwarded unchanged to the inner MLP.
        output_activation_function (str, bool, optional): None (default) means no output activation.
            A string names the desired output activation (e.g. 'softmax'). True re-uses the activation
            given by ``activation_function``.
    """

    def __init__(self,
                 hidden_dims: Sequence[int],
                 datamodule_config: DictConfig = None,
                 net_normalization: Optional[str] = None,
                 activation_function: str = 'relu',
                 dropout: float = 0.0,
                 residual: bool = False,
                 output_normalization: bool = False,
                 output_activation_function: Optional[Union[str, bool]] = None,
                 *args, **kwargs):
        super().__init__(datamodule_config=datamodule_config, *args, **kwargs)
        self.save_hyperparameters()

        if not isinstance(self.raw_input_dim, dict):
            self.input_dim = self.raw_input_dim
        else:
            input_keys = self.raw_input_dim.keys()
            assert all(key in self.raw_spatial_dim.keys() for key in input_keys)
            # #features times spatial extent per input group (extent 0 counts as 1), summed
            self.input_dim = sum(
                self.raw_input_dim[key] * max(1, self.raw_spatial_dim[key]) for key in input_keys
            )
            self.log_text.info(f' Inferred a flattened input dim = {self.input_dim}')

        self.output_dim = self.raw_output_dim

        self.mlp = MLP(
            input_dim=self.input_dim,
            hidden_dims=hidden_dims,
            output_dim=self.output_dim,
            net_normalization=net_normalization,
            activation_function=activation_function,
            dropout=dropout,
            residual=residual,
            output_normalization=output_normalization,
            output_activation_function=output_activation_function,
            out_layer_bias_init=self.out_layer_bias_init,
        )

    def forward(self, X: Tensor) -> Tensor:
        """Apply the inner MLP to the (already flattened) input."""
        return self.mlp(X)
53 |
--------------------------------------------------------------------------------
/climart/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RolnickLab/climart/6d12fa478de5db5209842e0369a2a24377079edd/climart/models/__init__.py
--------------------------------------------------------------------------------
/climart/models/baseline.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
class _SoftsignMLP(nn.Module):
    """Shared 3-layer MLP used by the radiation baselines below.

    Architecture: Linear(input_size, hidden) -> softsign -> Linear(hidden, hidden) -> softsign
    -> Linear(hidden, output). The layer attribute names (`fc1`, `fc2`, `fc3`) and `input_size`
    are kept so that state_dict keys match the original per-class implementations.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.input_size = input_size
        self.fc1 = nn.Linear(input_size, hidden_size)  # input shape may depend on nlevels (unverified)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = F.softsign(self.fc1(x))
        x = F.softsign(self.fc2(x))
        return self.fc3(x)


class LW_abs(_SoftsignMLP):
    """
    Inputs: 18 (Temp, Pres, and Gaseous concentrations) as inputs
    Outputs: 256 (g-points LW absorption cross section)

    Hidden width 58; the actual input width is the `input_size` argument.
    """

    def __init__(self, input_size):
        super().__init__(input_size, hidden_size=58, output_size=256)


class LW_ems(_SoftsignMLP):
    """
    Inputs: 18 (Temp, Pres, and Gaseous concentrations) as inputs
    Outputs: 256 (g-points LW emission)

    Hidden width 16; the actual input width is the `input_size` argument.
    """

    def __init__(self, input_size):
        super().__init__(input_size, hidden_size=16, output_size=256)


class SW_abs(_SoftsignMLP):
    """
    Inputs: 7 (Temp, Pres, and Gaseous concentrations) as inputs, SW takes fewer gases than LW
    Outputs: 224 (g-points SW absorption cross section)

    Hidden width 48; the actual input width is the `input_size` argument.
    """

    def __init__(self, input_size):
        super().__init__(input_size, hidden_size=48, output_size=224)


class SW_rcs(_SoftsignMLP):
    """
    Inputs: 7 (Temp, Pres, and Gaseous concentrations) as inputs, SW takes fewer gases than LW
    Outputs: 224 (g-points SW rayleigh cross section)

    Hidden width 16; the actual input width is the `input_size` argument.
    """

    def __init__(self, input_size):
        super().__init__(input_size, hidden_size=16, output_size=224)
97 |
98 |
if __name__ == "__main__":
    # Smoke test: build an SW_rcs baseline for 180 input features and run it on one random vector.
    net = SW_rcs(180)
    print(net)

    x = torch.randn(180)  # named `x` rather than `input` to avoid shadowing the builtin
    out = net(x)
    print(out)
109 |
--------------------------------------------------------------------------------
/climart/models/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RolnickLab/climart/6d12fa478de5db5209842e0369a2a24377079edd/climart/models/modules/__init__.py
--------------------------------------------------------------------------------
/climart/models/modules/additional_layers.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | from typing import Dict, Optional, Union, Callable, Any
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 | from torch import Tensor
9 | from climart.models.modules.mlp import MLP
10 |
11 |
class Multiscale_Module(nn.Module):
    """Parallel 1D convolutions with different kernel sizes, fused by a 1x1 convolution.

    Three dilated `Conv1d` branches (kernel sizes 5, 6 and 9, stride 1) process the same input.
    Each branch output is adaptively max-pooled to length `out_shape`. The branches are concatenated
    along the channel dim and reduced by a kernel-size-1 conv to int(channels_per_layer / 2) channels.
    If `use_act` is True, the result is additionally gated by sigmoid(global-average-pool(x)).
    """

    def __init__(self, in_channels=None, channels_per_layer=None, out_shape=None,
                 dil_rate=1, use_act=False, *args, **kwargs):
        super().__init__()
        self.out_shape = out_shape
        self.use_act = use_act

        # NOTE(review): the attribute names suggest kernel sizes 3/5/7, but the kernels are 5/6/9.
        # The names are kept unchanged so existing checkpoints (state_dict keys) still load.
        self.multi_3 = nn.Conv1d(in_channels=in_channels, out_channels=channels_per_layer,
                                 kernel_size=5, stride=1, dilation=dil_rate)
        self.multi_5 = nn.Conv1d(in_channels=in_channels, out_channels=channels_per_layer,
                                 kernel_size=6, stride=1, dilation=dil_rate)
        self.multi_7 = nn.Conv1d(in_channels=in_channels, out_channels=channels_per_layer,
                                 kernel_size=9, stride=1, dilation=dil_rate)
        self.after_concat = nn.Conv1d(in_channels=int(channels_per_layer * 3),
                                      out_channels=int(channels_per_layer / 2), kernel_size=1, stride=1)
        self.gap = GAP()

    def forward(self, x):
        # Pool every branch to the same length so they can be concatenated channel-wise.
        branches = [
            F.adaptive_max_pool1d(conv(x), self.out_shape)
            for conv in (self.multi_3, self.multi_5, self.multi_7)
        ]
        x_concat = self.after_concat(torch.cat(branches, 1))

        if not self.use_act:
            return x_concat
        # NOTE(review): the gate is computed from the *input* x, shape (batch, in_channels, 1), so it only
        # broadcasts against x_concat when in_channels == int(channels_per_layer / 2) (or in_channels == 1).
        # Confirm whether gating with gap(x_concat) was intended.
        return torch.sigmoid(self.gap(x)) * x_concat


class GAP:
    """Global average pooling over the last dimension (adaptive average pool to length 1)."""

    def __call__(self, x):
        return F.adaptive_avg_pool1d(x, 1)


class SE_Block(nn.Module):
    """Squeeze-and-excitation gate for (batch, channels, length) inputs.

    The channel means (squeeze) go through a bottleneck MLP with a sigmoid output (excitation).
    The resulting per-channel weights in (0, 1) rescale the input.

    Args:
        c: number of channels of the input.
        r: reduction ratio of the bottleneck (hidden width c // r, at least 1).
    """

    def __init__(self, c, r=16):
        super().__init__()
        self.squeeze = GAP()
        # With c < r, c // r would be 0: a zero-width bottleneck always outputs zeros, so the block
        # would just scale its input by sigmoid(0) = 0.5. Keep at least one hidden unit.
        bottleneck = max(1, c // r)
        self.excitation = nn.Sequential(
            nn.Linear(c, bottleneck, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(bottleneck, c, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _ = x.shape
        y = self.squeeze(x).view(b, c)
        y = self.excitation(y).view(b, c, 1)
        return x * y.expand_as(x)
71 |
72 |
class FeatureProjector(nn.Module):
    """Projects each named input tensor to a common feature size with its own MLP.

    For every (name, feature_dim) in `input_name_to_feature_dim`, an MLP maps the last dim of
    `inputs[name]` from feature_dim to `projection_dim`. Its hidden width is
    int((feature_dim + projection_dim) / 2), repeated `projector_n_layers` times. All leading
    dims are preserved. If `projections_aggregation` is None, the dict of projections is
    returned; otherwise it is called on that dict and its result returned.
    """

    def __init__(
            self,
            input_name_to_feature_dim: Dict[str, int],
            projection_dim: int = 128,
            projector_n_layers: int = 1,
            projector_activation_func: str = 'Gelu',
            projector_net_normalization: str = 'none',
            projections_aggregation: Optional[Union[Callable[[Dict[str, Tensor]], Any], str]] = None,
            output_normalization: bool = True,
            output_activation_function: bool = False
    ):
        super().__init__()
        self.projection_dim = projection_dim
        self.input_name_to_feature_dim = input_name_to_feature_dim
        # NOTE(review): the type hint allows a str, but nothing here resolves a string to a function,
        # so forward only works with None or a callable. Confirm how string values are meant to be handled.
        self.projections_aggregation = projections_aggregation

        projectors = OrderedDict()
        for name, feature_dim in input_name_to_feature_dim.items():
            hidden_width = int((feature_dim + projection_dim) / 2)
            projectors[name] = MLP(
                input_dim=feature_dim,
                hidden_dims=[hidden_width] * projector_n_layers,
                output_dim=projection_dim,
                activation_function=projector_activation_func,
                net_normalization=projector_net_normalization,
                dropout=0,
                output_normalization=output_normalization,
                output_activation_function=output_activation_function,
                name=f"{name}_MLP_projector"
            )
        self.input_name_to_projector = nn.ModuleDict(projectors)

    def forward(self,
                inputs: Dict[str, Tensor]
                ) -> Union[Dict[str, Tensor], Any]:
        projections = dict()
        for name, projector in self.input_name_to_projector.items():
            tensor = inputs[name]
            # (batch, ..., feature_dim) -> (batch, ..., projection_dim)
            target_shape = list(tensor.shape)
            target_shape[-1] = self.projection_dim
            # The MLP sees a 2D (everything-but-features, feature_dim) matrix.
            flat = tensor.reshape((-1, self.input_name_to_feature_dim[name]))
            projections[name] = projector(flat).reshape(target_shape)

        aggregate = self.projections_aggregation
        return projections if aggregate is None else aggregate(projections)
130 |
131 |
class PredictorHeads(nn.Module):
    """
    Reads out a 1D hidden representation into the desired outputs with one or more MLP heads.
    Typical uses:
        - read out a hidden vector to a desired output dimensionality
        - predict multiple variables with separate MLP heads (e.g. one for rsuc and rsdc each)

    With `separate_heads=True` there is one head per entry of `var_name_to_output_dim`. Otherwise a
    single head called 'joint_output' predicts all variables at once; its output dim is the sum of
    all output dims. Each head has `n_layers` hidden layers of width
    int((input_dim + head_output_dim) / 2), no dropout and no output normalization.
    """

    def __init__(
            self,
            input_dim: int,
            var_name_to_output_dim: Dict[str, int],
            separate_heads: bool = True,
            n_layers: int = 1,
            activation_func: str = 'Gelu',
            net_normalization: str = 'none',
    ):
        super().__init__()
        self.input_dim = input_dim
        self.separate_heads = separate_heads

        if self.separate_heads:
            self.output_name_to_feature_dim = var_name_to_output_dim
        else:
            self.output_name_to_feature_dim = {'joint_output': sum(var_name_to_output_dim.values())}

        heads = OrderedDict()
        for head_name, head_dim in self.output_name_to_feature_dim.items():
            heads[head_name] = MLP(
                input_dim=input_dim,
                hidden_dims=[int((input_dim + head_dim) / 2)] * n_layers,
                output_dim=head_dim,
                activation_function=activation_func,
                net_normalization=net_normalization,
                output_normalization=False,
                dropout=0,
            )
        self.predictor_heads = nn.ModuleDict(heads)

    def forward(self,
                hidden_input: Tensor,  # (batch-size, hidden-dim) 1D tensor
                as_dict: bool = False,
                ) -> Union[Dict[str, Tensor], Tensor]:
        """Apply every head; return the per-head dict if `as_dict`, otherwise a single tensor.

        With separate heads, the single tensor is the concatenation of all heads along the last dim.
        """
        predictions = OrderedDict(
            (head_name, head(hidden_input)) for head_name, head in self.predictor_heads.items()
        )
        if as_dict:
            return predictions
        if self.separate_heads:
            return torch.cat(list(predictions.values()), dim=-1)
        return predictions['joint_output']
203 |
--------------------------------------------------------------------------------
/climart/models/modules/mlp.py:
--------------------------------------------------------------------------------
1 | from typing import Sequence, Optional, Dict, Union
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | from torch import Tensor
7 | from climart.utils.utils import get_activation_function, get_normalization_layer, get_logger
8 |
9 | log = get_logger(__name__)
10 |
11 |
class MLP(nn.Module):
    def __init__(self,
                 input_dim: int,
                 hidden_dims: Sequence[int],
                 output_dim: int,
                 net_normalization: Optional[str] = None,
                 activation_function: str = 'gelu',
                 dropout: float = 0.0,
                 residual: bool = False,
                 output_normalization: bool = False,
                 output_activation_function: Optional[Union[str, bool]] = None,
                 out_layer_bias_init: Tensor = None,
                 name: str = ""
                 ):
        """
        Args:
            input_dim (int): the expected 1D input tensor dim (size of the input's last dimension).
            hidden_dims: widths of the hidden layers; one MLP_Block is created per entry.
            output_dim (int): size of the output layer.
            net_normalization (str, optional): normalization after each hidden linear layer.
                Case-insensitive; None means 'none'. Also used for the output layer when
                `output_normalization` is True.
            activation_function (str): activation of the hidden blocks.
            dropout (float): dropout probability in the hidden blocks (no dropout layer when 0).
            residual (bool): skip connections in hidden blocks whose input and output widths match.
            output_normalization (bool): append a `net_normalization` layer after the output layer
                (ignored when the normalization is 'none').
            output_activation_function (str, bool, optional): By default no output activation function is used (None).
                If a string is passed, it must be the name of the desired output activation (e.g. 'softmax').
                If True, the same activation function is used as defined by the arg `activation_function`.
            out_layer_bias_init (Tensor, optional): values copied into the output layer's bias at init.
            name (str): stored as `self.name`.
        """
        super().__init__()
        self.name = name
        # Normalize the normalization name once so that hidden and output layers agree on it
        # (None -> 'none', case-insensitive). Previously the output layer compared the raw value,
        # so `None` (the default) was treated as a real normalization.
        net_norm = net_normalization.lower() if isinstance(net_normalization, str) else 'none'

        dims = [input_dim] + list(hidden_dims)
        self.hidden_layers = nn.ModuleList([
            MLP_Block(
                in_dim=d_in,
                out_dim=d_out,
                net_norm=net_norm,
                activation_function=activation_function,
                dropout=dropout,
                residual=residual
            )
            for d_in, d_out in zip(dims[:-1], dims[1:])
        ])

        out_weight = nn.Linear(dims[-1], output_dim, bias=True)
        if out_layer_bias_init is not None:
            log.info(' Pre-initializing the MLP final/output layer bias.')
            # Copy (rather than rebinding `.data`) so that a shape mismatch fails loudly here and
            # the bias keeps the layer's own dtype/device.
            with torch.no_grad():
                out_weight.bias.copy_(torch.as_tensor(out_layer_bias_init))
        out_layer = [out_weight]
        if output_normalization and net_norm != 'none':
            out_layer.append(get_normalization_layer(net_norm, output_dim))
        if output_activation_function:
            if isinstance(output_activation_function, bool):
                output_activation_function = activation_function
            out_layer.append(get_activation_function(output_activation_function, functional=False))
        self.out_layer = nn.Sequential(*out_layer)

    def forward(self, X: Tensor) -> Tensor:
        for layer in self.hidden_layers:
            X = layer(X)
        # squeeze(1) drops dim 1 when it has size 1 (e.g. output_dim == 1 for a 2D input batch).
        return self.out_layer(X).squeeze(1)
69 |
70 |
class MLP_Block(nn.Module):
    """Linear -> [normalization] -> activation -> [dropout], with an optional skip connection.

    Args:
        in_dim, out_dim: widths of the linear layer.
        net_norm: normalization name passed to `get_normalization_layer`; 'none' disables it.
            With 'batch_norm' the linear layer has no bias.
        activation_function: activation name passed to `get_activation_function`.
        dropout: dropout probability (no dropout layer when 0).
        residual: add the input to the output; silently disabled when in_dim != out_dim.
    """

    def __init__(self,
                 in_dim: int,
                 out_dim: int,
                 net_norm: str = 'none',
                 activation_function: str = 'Gelu',
                 dropout: float = 0.0,
                 residual: bool = False
                 ):
        super().__init__()
        # The bias is dropped in front of batch norm (presumably because its mean-centering makes it redundant).
        layer = [nn.Linear(in_dim, out_dim, bias=net_norm != 'batch_norm')]
        if net_norm != 'none':
            layer += [get_normalization_layer(net_norm, out_dim)]
        layer += [get_activation_function(activation_function, functional=False)]
        if dropout > 0:
            layer += [nn.Dropout(dropout)]
        self.layer = nn.Sequential(*layer)
        # A skip connection is only possible when input and output widths match.
        self.residual = bool(residual) and in_dim == out_dim
        if self.residual:
            log.info('MLP block with residual!')

    def forward(self, X: Tensor) -> Tensor:
        X_out = self.layer(X)
        if self.residual:
            # Out-of-place add: an in-place `+=` on the activation output breaks autograd for
            # activations whose backward pass reuses their output (e.g. ReLU, sigmoid, tanh).
            X_out = X_out + X
        return X_out
99 |
--------------------------------------------------------------------------------
/climart/train.py:
--------------------------------------------------------------------------------
1 | import wandb
2 | from hydra.utils import instantiate as hydra_instantiate
3 | from omegaconf import DictConfig
4 |
5 | import pytorch_lightning as pl
6 | from pytorch_lightning import seed_everything
7 |
8 | import climart.utils.config_utils as cfg_utils
9 | from climart.interface import get_model_and_data
10 | from climart.utils.utils import get_logger
11 |
12 |
def run_model(config: DictConfig):
    """Train the emulator described by a Hydra `config` with PyTorch Lightning, optionally testing it.

    Steps:
        1. seed everything and apply `cfg_utils.extras` (may modify `config` in place);
        2. build the model and datamodule;
        3. instantiate callbacks, loggers and the trainer from the config and log hyperparameters;
        4. fit, then save the config to wandb (if enabled);
        5. if `test_after_training` is set, test the 'best' checkpoint;
        6. finish the wandb run if a wandb logger is configured.
    """
    seed_everything(config.seed, workers=True)
    log = get_logger(__name__)
    cfg_utils.extras(config)

    if config.get("print_config"):
        cfg_utils.print_config(config, fields='all')

    emulator_model, data_module = get_model_and_data(config)

    # Lightning callbacks and loggers: every sub-config that declares a `_target_`
    callbacks, loggers = (
        cfg_utils.get_all_instantiable_hydra_modules(config, section) for section in ('callbacks', 'logger')
    )

    trainer: pl.Trainer = hydra_instantiate(config.trainer, callbacks=callbacks, logger=loggers)

    log.info("Logging hyperparameters to the PyTorch Lightning loggers.")
    cfg_utils.log_hyperparameters(
        config=config,
        model=emulator_model,
        data_module=data_module,
        trainer=trainer,
        callbacks=callbacks,
    )

    trainer.fit(model=emulator_model, datamodule=data_module)
    cfg_utils.save_hydra_config_to_wandb(config)

    if config.get("test_after_training"):
        trainer.test(datamodule=data_module, ckpt_path='best')

    uses_wandb = config.get('logger') and config.logger.get("wandb")
    if uses_wandb:
        wandb.finish()
47 |
48 | # log.info("Reloading model from checkpoint based on best validation stat.")
49 | # final_model = emulator_model.load_from_checkpoint(trainer.checkpoint_callback.best_model_path,
50 | # datamodule_config=config.datamodule, output_normalizer=data_module.normalizer.output_normalizer)
51 | # return final_model
52 |
--------------------------------------------------------------------------------
/climart/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RolnickLab/climart/6d12fa478de5db5209842e0369a2a24377079edd/climart/utils/__init__.py
--------------------------------------------------------------------------------
/climart/utils/callbacks.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Dict, List, Union, Sequence
2 | import numpy as np
3 | import torch
4 |
5 |
class TestingScheduleCallback:
    """Decides, once per epoch, whether the model should be evaluated on the test set.

    Call the instance once per epoch; it returns True when a test run should happen:
        - never while the epoch counter is <= `ignore_first_n_epochs`;
        - always once `test_at_least_every_n_epochs` epochs have passed without a test;
        - on a new best validation model (if `test_on_new_best_validation`), but only when at least
          `test_at_most_every_n_epochs` epochs have passed without a test.
    The epoch counter starts at `start_epoch`. The untested-epoch counter starts at 1 and resets to 1
    after each test.
    """

    def __init__(self,
                 start_epoch: int = 1,
                 test_at_most_every_n_epochs: int = 10,
                 test_at_least_every_n_epochs: int = 20,
                 test_on_new_best_validation: bool = True,
                 ignore_first_n_epochs: int = 5,
                 ):
        self.cur_epoch = start_epoch
        self.untested_epochs = 1
        self.ignore_first_n_epochs = ignore_first_n_epochs
        self.test_at_most_every_n_epochs = test_at_most_every_n_epochs
        self.test_at_least_every_n_epochs = test_at_least_every_n_epochs
        self.test_on_new_best_validation = test_on_new_best_validation

    def __call__(self, is_new_best_val_model: bool = False) -> bool:
        if self.cur_epoch <= self.ignore_first_n_epochs:
            should_test = False
        elif self.untested_epochs >= self.test_at_least_every_n_epochs:
            should_test = True
        else:
            should_test = bool(
                self.test_on_new_best_validation
                and is_new_best_val_model
                and self.untested_epochs >= self.test_at_most_every_n_epochs
            )

        self.cur_epoch += 1
        self.untested_epochs = 1 if should_test else self.untested_epochs + 1
        return should_test
39 |
40 |
class PredictionPostProcessCallback:
    """Splits a concatenated prediction vector back into one slice per variable.

    Variables occupy consecutive channel ranges along the last axis, in the given order. `sizes` is
    either one width shared by all variables or one width per variable (zipped with `variables`).
    """

    def __init__(self,
                 variables: List[str],
                 sizes: Union[int, Sequence[int]]
                 ):
        widths = [sizes] * len(variables) if isinstance(sizes, int) else sizes
        self.variable_to_channel = dict()
        end = 0
        for var_name, width in zip(variables, widths):
            start, end = end, end + width
            self.variable_to_channel[var_name] = {'start': start, 'end': end}

    def split_vector_by_variable(self,
                                 vector: Union[np.ndarray, torch.Tensor]
                                 ) -> Dict[str, Union[np.ndarray, torch.Tensor]]:
        """Return {variable: vector[..., start:end]}; a dict input is returned unchanged."""
        if isinstance(vector, dict):
            return vector
        return {
            var_name: vector[..., limits['start']:limits['end']]
            for var_name, limits in self.variable_to_channel.items()
        }

    def __call__(self, vector, *args, **kwargs):
        return self.split_vector_by_variable(vector)
65 |
--------------------------------------------------------------------------------
/climart/utils/config_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import warnings
4 | from typing import Union, Sequence, List
5 |
6 | import omegaconf
7 | import pytorch_lightning as pl
8 | import wandb
9 | from omegaconf import DictConfig, OmegaConf, open_dict
10 |
11 | from climart.utils.naming import get_group_name, get_detailed_name
12 | from climart.utils.utils import no_op, get_logger
13 |
14 | log = get_logger(__name__)
15 |
def print_config(
        config,
        fields: Union[str, Sequence[str]] = (
                "datamodule",
                "model",
                "trainer",
                # "callbacks",
                # "logger",
                "seed",
        ),
        resolve: bool = True,
) -> None:
    """Prints content of DictConfig using Rich library and its tree structure.

    Credits go to: https://github.com/ashleve/lightning-hydra-template

    Args:
        config (ConfigDict): Configuration
        fields (Sequence[str], optional): Determines which main fields from config will
        be printed and in what order. A single string is treated as one field, except 'all'
        (case-insensitive), which prints every top-level key.
        resolve (bool, optional): Whether to resolve reference fields of DictConfig.

    Does nothing if `rich` or `omegaconf` is not installed.
    """
    # `import importlib` alone does not guarantee that the `importlib.util` submodule is loaded,
    # so import it explicitly before using `importlib.util.find_spec`.
    import importlib.util
    if not importlib.util.find_spec("rich") or not importlib.util.find_spec("omegaconf"):
        # no pretty printing
        return
    import rich
    import rich.syntax
    import rich.tree

    style = "dim"
    tree = rich.tree.Tree(":gear: CONFIG", style=style, guide_style=style)
    if isinstance(fields, str):
        fields = config.keys() if fields.lower() == 'all' else [fields]

    for field in fields:
        branch = tree.add(field, style=style, guide_style=style)

        config_section = config.get(field)
        branch_content = str(config_section)
        if isinstance(config_section, DictConfig):
            branch_content = OmegaConf.to_yaml(config_section, resolve=resolve)

        branch.add(rich.syntax.Syntax(branch_content, "yaml"))

    rich.print(tree)
64 |
65 |
def extras(config: DictConfig) -> None:
    """A couple of optional utilities, controlled by main config file:
    - creating the working directory
    - disabling warnings
    - easier access to debug mode
    - forcing debug friendly configuration
    - forcing multi-gpu friendly configuration
    - with a wandb logger: assigning a run ID, group and name, and initializing the wandb run

    Credits go to: https://github.com/ashleve/lightning-hydra-template

    Modifies DictConfig in place.
    """
    log = get_logger()

    # Create working dir if it does not exist yet
    if config.get('work_dir'):
        os.makedirs(name=config.get("work_dir"), exist_ok=True)

    # disable python warnings if <config.ignore_warnings=True>
    if config.get("ignore_warnings"):
        log.info("Disabling python warnings! ")
        warnings.filterwarnings("ignore")

    # set <config.trainer.fast_dev_run=True> if <config.debug=True>
    if config.get("debug"):
        log.info("Running in debug mode! ")
        config.trainer.fast_dev_run = True

    # force debugger friendly configuration if <config.trainer.fast_dev_run=True>
    if config.trainer.get("fast_dev_run"):
        log.info("Forcing debugger friendly configuration! ")
        # Debuggers don't like GPUs or multiprocessing
        if config.trainer.get("gpus"):
            config.trainer.gpus = 0
        if config.datamodule.get("pin_memory"):
            config.datamodule.pin_memory = False
        if config.datamodule.get("num_workers"):
            config.datamodule.num_workers = 0

    # force multi-gpu friendly configuration if <config.trainer.accelerator> is a data-parallel one
    accelerator = config.trainer.get("accelerator")
    if accelerator in ["ddp", "ddp_spawn", "dp", "ddp2"]:
        log.info("Forcing ddp friendly configuration! ")
        if config.datamodule.get("num_workers"):
            config.datamodule.num_workers = 0
        if config.datamodule.get("pin_memory"):
            config.datamodule.pin_memory = False

    # `logger` may be absent or explicitly null (no loggers), so check it before calling `.get` on it.
    USE_WANDB = bool(config.get("logger")) and bool(config.logger.get("wandb"))
    if USE_WANDB:
        if not config.logger.wandb.get('id'):  # no wandb id has been assigned yet
            config.logger.wandb.id = wandb.util.generate_id()
        else:
            log.info(f" This experiment config already has a wandb run ID = {config.logger.wandb.id}")
        if not config.logger.wandb.get('group'):  # no wandb group has been assigned yet
            # Group names are truncated to at most 128 characters.
            config.logger.wandb.group = get_group_name(config)[:128]
        config.logger.wandb.name = (
                get_detailed_name(config) + '_' + time.strftime('%Hh%Mm_on_%b_%d') + '_' + config.logger.wandb.id
        )

    check_config_values(config)
    if USE_WANDB:
        wandb_kwargs = {
            k: config.logger.wandb.get(k) for k in ['id', 'project', 'entity', 'name', 'group',
                                                    'tags', 'notes', 'reinit', 'mode', 'resume']
        }
        wandb_kwargs['dir'] = config.logger.wandb.get('save_dir')
        wandb.init(**wandb_kwargs)
        log.info(f"Wandb kwargs: {wandb_kwargs}")
        save_hydra_config_to_wandb(config)
135 |
136 |
def check_config_values(config: DictConfig):
    """Validate and normalize a few config values in place.

    - Lower-cases `datamodule.exp_type`; raises ValueError unless it is 'clear_sky' or 'pristine'.
    - Maps a null `model.net_normalization` to 'none' and lower-cases it.
    - With a wandb logger: moves the model-checkpoint `dirpath` into a sub-directory named after the
      wandb run ID (and creates it).
    - Without a wandb logger: disables `save_config_to_wandb` (with a warning).
    """
    exp_type = config.datamodule.exp_type.lower()
    config.datamodule.exp_type = exp_type
    if exp_type not in ["clear_sky", "pristine"]:
        raise ValueError(f"Arg `exp_type` should be one of clear_sky or pristine, but got {exp_type}")

    if "net_normalization" in config.model.keys():
        if config.model.net_normalization is None:
            config.model.net_normalization = "none"
        config.model.net_normalization = config.model.net_normalization.lower()

    # `logger` and `callbacks` may be absent or explicitly null, so check them before calling `.get`.
    if config.get("logger") and config.logger.get("wandb"):
        if config.get('callbacks') and config.callbacks.get('model_checkpoint'):
            id_mdl = config.logger.wandb.get('id')
            d = config.callbacks.model_checkpoint.get('dirpath')
            # Without a base directory there is nothing to nest the run-specific directory under.
            if id_mdl is not None and d is not None:
                with open_dict(config):
                    new_dir = os.path.join(d, id_mdl)
                    config.callbacks.model_checkpoint.dirpath = new_dir
                    os.makedirs(new_dir, exist_ok=True)
                    log.info(f" Model checkpoints will be saved in: {new_dir}")
    elif config.get('save_config_to_wandb'):
        log.warning(" `save_config_to_wandb`=True but no wandb logger was found.. config will not be saved!")
        config.save_config_to_wandb = False
162 |
163 |
def get_all_instantiable_hydra_modules(config, module_name: str):
    """Instantiate (with Hydra) every sub-config of `config[module_name]` that declares a `_target_`.

    Returns a list of the instantiated objects. The list is empty when the section is missing or
    explicitly null (e.g. `logger: null` to disable all loggers).
    """
    from hydra.utils import instantiate as hydra_instantiate
    modules = []
    section = config[module_name] if module_name in config else None
    if section is None:
        return modules
    for module_config in section.values():
        if module_config is not None and "_target_" in module_config:
            modules.append(hydra_instantiate(module_config))
    return modules
172 |
173 |
def log_hyperparameters(
        config,
        model: pl.LightningModule,
        data_module: pl.LightningDataModule,
        trainer: pl.Trainer,
        callbacks: List[pl.Callback],
) -> None:
    """This method controls which parameters from Hydra config are saved by Lightning loggers.
    Credits go to: https://github.com/ashleve/lightning-hydra-template

    Additionally saves:
        - number of {total, trainable, non-trainable} model parameters
        - the work/checkpoint/wandb directories

    After logging, `trainer.logger.log_hyperparams` is replaced by a no-op so Lightning does not log
    the model's hparams a second time. Does nothing if the trainer has no logger.
    """
    if trainer.logger is None:
        # e.g. the trainer was built with an empty list of loggers: there is nowhere to log to.
        log.info(" No Lightning logger found, skipping hyperparameter logging.")
        return

    def copy_and_ignore_keys(dictionary, *keys_to_ignore):
        return {k: dictionary[k] for k in dictionary.keys() if k not in keys_to_ignore}

    params = dict()
    if 'seed' in config:
        params['seed'] = config['seed']

    # Remove redundant keys or those that are not important to know after training -- feel free to edit this!
    params['model'] = copy_and_ignore_keys(config['model'], 'optimizer', 'scheduler')
    params["datamodule"] = copy_and_ignore_keys(config["datamodule"], 'pin_memory', 'num_workers')
    params['normalizer'] = config['normalizer']
    params["trainer"] = copy_and_ignore_keys(config["trainer"])
    # encoder, optims, and scheduler as separate top-level key
    params['optim'] = config['model']['optimizer']
    params['scheduler'] = config['model']['scheduler'] if 'scheduler' in config['model'] else None

    if "callbacks" in config and 'model_checkpoint' in config['callbacks']:
        params["model_checkpoint"] = copy_and_ignore_keys(
            config["callbacks"]['model_checkpoint'], 'save_top_k'
        )

    # save number of model parameters
    params["model/params_total"] = sum(p.numel() for p in model.parameters())
    params["model/params_trainable"] = sum(
        p.numel() for p in model.parameters() if p.requires_grad
    )
    params["model/params_not_trainable"] = sum(
        p.numel() for p in model.parameters() if not p.requires_grad
    )
    params['dirs/work_dir'] = config.get('work_dir')
    params['dirs/ckpt_dir'] = config.get('ckpt_dir')
    params['dirs/wandb_save_dir'] = config.logger.wandb.save_dir if (
            config.get('logger') and config.logger.get('wandb')) else None

    # send hparams to all loggers
    trainer.logger.log_hyperparams(params)

    # disable logging any more hyperparameters for all loggers
    # this is just a trick to prevent trainer from logging hparams of model,
    # since we already did that above
    trainer.logger.log_hyperparams = no_op
236 |
237 |
def save_hydra_config_to_wandb(config: DictConfig):
    """Write the resolved config to `<wandb run dir>/hydra_config.yaml` and register it with wandb.

    Only runs when `config.save_config_to_wandb` is truthy; assumes an active wandb run (`wandb.run`).
    """
    if not config.get('save_config_to_wandb'):
        return
    log.info(f"Hydra config will be saved to WandB as hydra_config.yaml and in wandb run_dir: {wandb.run.dir}")
    # files in wandb.run.dir folder get directly uploaded to wandb
    config_path = os.path.join(wandb.run.dir, "hydra_config.yaml")
    # OmegaConf.save opens and writes the file itself, so no separate file handle is needed.
    OmegaConf.save(config, f=config_path, resolve=True)
    wandb.save(config_path)
245 |
246 |
def get_config_from_hydra_compose_overrides(overrides: list) -> DictConfig:
    """Compose `main_config` from the `configs` directory with the given Hydra override strings.

    Duplicate overrides are dropped (keeping their first-seen order), as are multirun flags
    ('-m' / '--multirun') that may have been copied from a command line by mistake.
    The global Hydra state is always cleared afterwards.
    """
    import hydra
    from hydra.core.global_hydra import GlobalHydra
    # dict.fromkeys de-duplicates while preserving order; a set would reorder the overrides
    # nondeterministically (str hashes are randomized per process).
    overrides = [o for o in dict.fromkeys(overrides) if o not in ('-m', '--multirun')]
    hydra.initialize(config_path="../../configs")
    try:
        config = hydra.compose(config_name="main_config", overrides=overrides)
    finally:
        GlobalHydra.instance().clear()  # always clean up global hydra
    return config
259 |
260 |
def get_model_from_hydra_compose_overrides(overrides: list):
    """Compose the Hydra config from `overrides` and build the model it describes."""
    from climart.interface import get_model
    return get_model(get_config_from_hydra_compose_overrides(overrides))
265 |
--------------------------------------------------------------------------------
/climart/utils/evaluation.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 |
3 | import numpy as np
4 | # from sklearn.metrics import mean_squared_error, mean_absolute_error
5 | from climart.data_loading.constants import get_statistics
6 | from climart.utils import utils
7 |
8 | log = utils.get_logger(__name__)
9 |
10 |
def evaluate_preds(Ytrue: np.ndarray, preds: np.ndarray):
    """Return {'mbe': mean bias error, 'rmse': root-mean-squared error} of `preds` w.r.t. `Ytrue`."""
    errors = preds - Ytrue
    MSE = np.mean(errors ** 2)  # mean_squared_error(preds, Ytrue)
    RMSE = np.sqrt(MSE)
    MBE = np.mean(errors)
    stats = {'mbe': MBE,
             "rmse": RMSE}
    return stats


def evaluate_per_target_variable(Ytrue: dict,
                                 preds: dict,
                                 data_split: str = None) -> Dict[str, float]:
    """Compute MBE/RMSE per target variable, plus per height level for targets with a level axis.

    Args:
        Ytrue: variable name -> ground-truth array (e.g. 'rsuc', 'hrsc').
        preds: variable name -> predictions with the same shape as the matching `Ytrue` entry.
        data_split: prefix for the metric keys (e.g. 'val', 'test').

    Returns:
        A flat dict with keys:
            '{split}/{var}/{metric}'                       over all entries,
        and, when the arrays have at least 2 dims (axis 1 being the level axis):
            'levelwise/{split}/{var}_level{i}/{metric}'    for every level i,
            '{split}/{var}_toa/{metric}'                    for level 0,
            '{split}/{var}_surface/{metric}'                for the last level.
        An empty dict (and a warning) if `Ytrue` is not a dict.
    """
    stats = dict()
    if not isinstance(Ytrue, dict):
        log.warning(f" Expected a dictionary var_name->Tensor/nd_array, but got {type(Ytrue)} for Ytrue!")
        return stats

    for var_name, y_true in Ytrue.items():
        y_pred = preds[var_name]
        # pre-append the variable's name to its specific performance on the returned metrics dict
        for metric_name, metric_stat in evaluate_preds(y_true, y_pred).items():
            stats[f"{data_split}/{var_name}/{metric_name}"] = metric_stat

        if np.ndim(y_true) < 2:
            # No level axis (e.g. one scalar per sample): there are no level-wise stats to report.
            continue

        num_height_levels = y_true.shape[1]
        for lvl in range(num_height_levels):
            stats_lvl = evaluate_preds(y_true[:, lvl], y_pred[:, lvl])
            for metric_name, metric_stat in stats_lvl.items():
                stats[f"levelwise/{data_split}/{var_name}_level{lvl}/{metric_name}"] = metric_stat

            if lvl == num_height_levels - 1:
                for metric_name, metric_stat in stats_lvl.items():
                    stats[f"{data_split}/{var_name}_surface/{metric_name}"] = metric_stat
            if lvl == 0:
                for metric_name, metric_stat in stats_lvl.items():
                    stats[f"{data_split}/{var_name}_toa/{metric_name}"] = metric_stat

    return stats
53 |
--------------------------------------------------------------------------------
/climart/utils/naming.py:
--------------------------------------------------------------------------------
1 | from typing import Union, List
2 | from omegaconf import DictConfig
3 |
4 |
5 | def _shared_prefix(config: DictConfig, init_prefix: str = "") -> str:
6 | s = init_prefix if isinstance(init_prefix, str) else ""
7 | if 'clear' in config.datamodule.get('exp_type'):
8 | s += '_CS'
9 | s += f"_{config.datamodule.get('train_years')}train" + f"_{config.datamodule.get('validation_years')}val"
10 | if config.normalizer.get('input_normalization'):
11 | s += f"_{config.normalizer.get('input_normalization').upper()}"
12 | if config.normalizer.get('output_normalization'):
13 | s += f"_{config.normalizer.get('output_normalization').upper()}"
14 | return s.lstrip('_')
15 |
16 |
def get_name_for_hydra_config_class(config: DictConfig) -> str:
    """Return a short display name for a (sub-)config.

    The name is ``config.name`` when that key exists and is not None. Otherwise it is the
    last dotted component of ``config._target_`` (the class name). If neither exists,
    "$" is returned.
    """
    explicit_name = config.get('name') if 'name' in config else None
    if explicit_name is not None:
        return explicit_name
    if '_target_' not in config:
        return "$"
    return config._target_.split('.')[-1]
23 |
24 |
def get_detailed_name(config) -> str:
    """ This is a prefix for naming the runs for a more agreeable logging.

    The name concatenates:
      - the shared prefix (name, experiment type, years, normalizations);
      - dropout (if > 0), activation, optimizer and scheduler names;
      - batch size, learning rate and weight decay (if > 0);
      - a compact hidden-dims summary ("64x3" when all dims are equal);
      - the seed.

    Missing (None) dropout / weight_decay values are skipped instead of raising. Any
    remaining 'None' text is removed from the final string.
    """
    model_cfg = config.model
    optim_cfg = model_cfg.optimizer
    s = _shared_prefix(config, init_prefix=config.get("name", '')) + '_'

    dropout = model_cfg.get('dropout')
    if dropout is not None and dropout > 0:
        s += f"{dropout}dout_"

    # An f-string tolerates a missing activation; the resulting 'None' is stripped below.
    s += f"{model_cfg.get('activation_function')}_"
    s += get_name_for_hydra_config_class(optim_cfg) + '_'
    s += get_name_for_hydra_config_class(model_cfg.scheduler) + '_'

    s += f"{config.datamodule.get('batch_size')}bs_"
    s += f"{optim_cfg.get('lr')}lr_"
    weight_decay = optim_cfg.get('weight_decay')
    if weight_decay is not None and weight_decay > 0:
        s += f"{weight_decay}wd_"

    hdims = model_cfg.get('hidden_dims')
    if hdims and all(h == hdims[0] for h in hdims):
        s += f"{hdims[0]}x{len(hdims)}h"
    else:
        # Heterogeneous, empty or missing hidden dims: use their plain string form.
        s += f"{hdims}h"
    s += f"{config.get('seed')}seed"

    return s.replace('None', '')
52 |
53 |
def get_model_name(name: str) -> str:
    """Map a run/config name to its model family by substring matching.

    The substrings are checked in priority order: 'CNN', 'MLP', 'GCN', then 'GN'
    (which maps to 'GraphNet').

    Raises:
        ValueError: if none of the substrings occurs in ``name``.
    """
    families = (('CNN', 'CNN'), ('MLP', 'MLP'), ('GCN', 'GCN'), ('GN', 'GraphNet'))
    for token, family in families:
        if token in name:
            return family
    raise ValueError(name)
65 |
66 |
def get_group_name(config) -> str:
    """Build a run-group name from the model name, the shared prefix and spatial-normalization flags.

    The model name is lower-cased. 'net', '_' and 'climart' are removed, 'with' becomes
    '+', and the result is upper-cased before the shared prefix is added.
    """
    cleaned = get_name_for_hydra_config_class(config.model).lower()
    for old, new in (('net', ''), ('_', ''), ("climart", ""), ("with", "+")):
        cleaned = cleaned.replace(old, new)
    group = _shared_prefix(config, init_prefix=cleaned.upper())

    norm_in = config.normalizer.get('spatial_normalization_in')
    norm_out = config.normalizer.get('spatial_normalization_out')
    if norm_in and norm_out:
        suffix = '+spatialNormed'
    elif norm_in:
        suffix = '+spatialInNormed'
    elif norm_out:
        suffix = '+spatialOutNormed'
    else:
        suffix = ''
    return group + suffix
80 |
81 |
def stem_word(word: str) -> str:
    """Lower-case and strip ``word``, then delete every '-', '&', '+' and '_' character."""
    normalized = word.lower().strip()
    return normalized.translate(str.maketrans('', '', '-&+_'))
84 |
85 |
def get_exp_ID(exp_type: str, target_types: Union[str, List[str]], target_variables: Union[str, List[str]]):
    """Human-readable experiment description.

    ``target_types`` and ``target_variables`` may each be a single string or a list of
    strings. Lists are space-joined; a single string is used as-is rather than having its
    characters space-joined.
    """
    def _as_words(value: Union[str, List[str]]) -> str:
        return value if isinstance(value, str) else ' '.join(value)

    return f"{exp_type.upper()} conditions, with {_as_words(target_types)} x {_as_words(target_variables)} targets"
89 |
--------------------------------------------------------------------------------
/climart/utils/optimization.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
def get_loss(name, reduction='mean'):
    """Return a ``torch.nn`` loss module by name.

    The name is case-insensitive, surrounding whitespace is ignored, and '-' is treated
    like '_'. Supported names:
      - 'l1' / 'mae' / 'mean_absolute_error'      -> nn.L1Loss
      - 'l2' / 'mse' / 'mean_squared_error'       -> nn.MSELoss
      - 'smoothl1' / 'smooth' / 'smooth_l1'       -> nn.SmoothL1Loss

    Args:
        reduction: forwarded to the loss constructor ('mean', 'sum' or 'none').

    Raises:
        ValueError: for an unknown loss name.
    """
    key = name.lower().strip().replace('-', '_')
    if key in ('l1', 'mae', "mean_absolute_error"):
        return nn.L1Loss(reduction=reduction)
    if key in ('l2', 'mse', "mean_squared_error"):
        return nn.MSELoss(reduction=reduction)
    # 'smooth_l1' is what 'smooth-l1' normalizes to, so accept it too.
    if key in ('smoothl1', 'smooth', 'smooth_l1'):
        return nn.SmoothL1Loss(reduction=reduction)
    raise ValueError(f'Unknown loss function {key}')
15 |
16 |
def get_trainable_params(model):
    """Return, in registration order, the parameters of ``model`` that have ``requires_grad=True``."""
    return [param for _, param in model.named_parameters() if param.requires_grad]
23 |
--------------------------------------------------------------------------------
/climart/utils/plotting.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | from matplotlib.colors import TwoSlopeNorm
3 | import numpy as np
4 | import einops
5 | import xarray as xr
6 |
7 | from climart.data_loading.constants import get_coordinates
8 |
9 |
def set_labels_and_ticks(ax,
                         title: str = "",
                         xlabel: str = "", ylabel: str = "",
                         xlabel_fontsize: int = 10, ylabel_fontsize: int = 14,
                         xlim=None, ylim=None,
                         xticks=None, yticks=None,
                         title_fontsize: int = None,
                         xticks_fontsize: int = None, yticks_fontsize: int = None,
                         xtick_labels=None, ytick_labels=None,
                         logscale_y: bool = False,
                         show: bool = True,
                         grid: bool = True,
                         legend: bool = True, legend_loc='best', legend_prop=10,
                         full_screen: bool = False,
                         tight_layout: bool = True,
                         save_to: str = None
                         ):
    """Apply title, labels, limits, ticks, grid and legend to ``ax``, then optionally save and/or show.

    Args:
        ax: the Matplotlib axes to decorate.
        xlim / ylim: passed to ``set_xlim`` / ``set_ylim``. None leaves a limit unchanged.
        xticks_fontsize / yticks_fontsize: font size for the major tick labels, if truthy.
        legend: add a legend, but only if the axes has labelled artists.
        full_screen: toggle full-screen around ``savefig``. Only used when ``save_to`` is set.
        save_to: file path for ``plt.savefig(..., bbox_inches='tight')``.
        show: call ``plt.show()`` at the end.
    """
    ax.set_title(title, fontsize=title_fontsize)
    ax.set_xlabel(xlabel, fontsize=xlabel_fontsize)
    ax.set_ylabel(ylabel, fontsize=ylabel_fontsize)

    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    if xticks is not None:
        ax.set_xticks(xticks)
    if xtick_labels is not None:
        ax.set_xticklabels(xtick_labels)
    if xticks_fontsize:
        # `Tick.label` was removed in Matplotlib 3.8; tick_params works across versions
        # and also applies to ticks created later.
        ax.tick_params(axis='x', which='major', labelsize=xticks_fontsize)

    if logscale_y:
        ax.set_yscale('log')
    if yticks is not None:
        ax.set_yticks(yticks)
    if ytick_labels is not None:
        ax.set_yticklabels(ytick_labels)
    if yticks_fontsize:
        ax.tick_params(axis='y', which='major', labelsize=yticks_fontsize)

    if grid:
        ax.grid()
    if legend:
        handles, _ = ax.get_legend_handles_labels()
        if handles:  # avoid an empty legend (and Matplotlib's "No artists with labels" warning)
            ax.legend(loc=legend_loc, prop={'size': legend_prop})

    if tight_layout:
        plt.tight_layout()

    if save_to is not None:
        if full_screen:
            mng = plt.get_current_fig_manager()
            mng.full_screen_toggle()

        plt.savefig(save_to, bbox_inches='tight')
        if full_screen:
            mng.full_screen_toggle()

    if show:
        plt.show()
71 |
72 |
class RollingCmaps:
    """Hand out progressively darker shades of one colormap per key.

    Each key in ``unique_keys`` gets its own colormap. Keys are assigned to ``pos_cmaps``
    in order, cycling through the colormaps when there are more keys than colormaps.
    Each ``self[key]`` call returns ``cmap(pos / max_key_occurence)``, where ``pos``
    starts at 0.75 and grows by 1 per call.
    """

    def __init__(self,
                 unique_keys: list,
                 pos_cmaps: list = None,
                 max_key_occurence: int = 5):
        if pos_cmaps is None:
            pos_cmaps = ['Greens', 'Oranges', 'Blues', 'Greys', 'Purples']
        cmaps = [plt.get_cmap(cmap) for cmap in pos_cmaps]
        # Modulo avoids an IndexError when there are more keys than colormaps.
        self.cmaps = {key: cmaps[i % len(cmaps)] for i, key in enumerate(unique_keys)}
        self.pos_per_cmap = {key: 0.75 for key in unique_keys}  # lower makes lines too white
        self.max_key_occurence = max_key_occurence

    def __getitem__(self, key):
        color = self.cmaps[key](self.pos_per_cmap[key] / self.max_key_occurence)
        self.pos_per_cmap[key] += 1
        return color
89 |
90 |
class RollingLineFormats:
    """Hand out (line style, line kwargs) pairs per key for multi-line plots.

    Each key gets a fixed color from a 10-color palette, cycling when there are more
    than 10 keys. Every ``self[key]`` call returns the next line style from
    ``pos_markers`` (cycling) plus ``dict(c=color, linewidth=lw)``. ``lw`` starts at
    ``linewidth`` and decreases by 1 per call, down to a minimum of 1.

    ``cmap`` is accepted for backward compatibility but is currently unused.
    """

    def __init__(self,
                 unique_keys: list,
                 pos_markers: list = None,
                 cmap=None,
                 linewidth: float = 4
                 ):
        if pos_markers is None:
            pos_markers = ['-', '--', ':', '-', '-.']
        cs = ['#1f77b4', '#ff7f0e', '#2ca02c', '#9467bd', '#8c564b',
              '#e377c2', '#7f7f7f', '#d62728', '#bcbd22', '#17becf']

        self.pos_markers = pos_markers
        self.cmaps = {key: cs[i % len(cs)] for i, key in enumerate(unique_keys)}
        self.pos_per_key = {key: 0 for key in unique_keys}
        self.lws = {key: linewidth for key in unique_keys}

    def __getitem__(self, key):
        cur_i = self.pos_per_key[key]
        lw = self.lws[key]
        # Cycle through the styles instead of raising IndexError after len(pos_markers) calls.
        line_format = self.pos_markers[cur_i % len(self.pos_markers)]
        self.pos_per_key[key] += 1
        self.lws[key] = max(1, lw - 1)
        return line_format, dict(c=self.cmaps[key], linewidth=lw)
120 |
121 |
def plot_groups(xaxis_key, metric='Test/MAE', ax=None, show: bool = True, **kwargs):
    """Plot one error-bar line per keyword group.

    Each keyword argument maps a label to an object exposing ``.plot(x, y, yerr=..., label=..., ax=...)``,
    which presumably is a pandas DataFrame with a 'std' column (TODO confirm against callers).
    A new figure is created when ``ax`` is falsy.
    """
    if not ax:
        _, ax = plt.subplots()

    for group_label, group_data in kwargs.items():
        group_data.plot(xaxis_key, metric, yerr='std', label=group_label, ax=ax)

    set_labels_and_ticks(ax, xlabel='Used training points', ylabel=metric, show=show)
132 |
133 |
def height_errors(Ytrue: np.ndarray, preds: np.ndarray, height_ticks=None,
                  xlabel='', ylabel='height', fill_between=True, show=True):
    """
    Plot MBE and MAE (each with a +/- std band) as a function of the level index.

    Note: here the residual is ``Ytrue - preds``. This sign convention is the opposite
    of ``evaluate_preds``'s 'mbe'.

    :param Ytrue: 2D array (n_samples, n_levels)
    :param preds: same shape as Ytrue
    :param height_ticks: tick labels for the levels; must have length Ytrue.shape[1]
    :param show: whether to show the figures
    :return: (MAE figure, MBE figure)
    """
    _, n_levels = Ytrue.shape
    residuals = Ytrue - preds
    abs_residuals = np.abs(residuals)

    mbe_line, mae_line = np.mean(residuals, axis=0), np.mean(abs_residuals, axis=0)
    mbe_std, mae_std = np.std(residuals, axis=0), np.std(abs_residuals, axis=0)

    shared_kwargs = dict(yticks=height_ticks, ylabel=ylabel, show=show, fill_between=fill_between)
    level_axis = np.arange(n_levels)
    fig_mbe = height_plot(level_axis, mbe_line, mbe_std, xlabel=xlabel + ' MBE', **shared_kwargs)
    fig_mae = height_plot(level_axis, mae_line, mae_std, xlabel=xlabel + ' MAE', **shared_kwargs)

    if show:
        plt.show()
    return fig_mae, fig_mbe
162 |
163 |
def height_plot(yaxis, line, std, yticks=None, ylabel=None, xlabel=None, show=False, fill_between=True):
    """Plot ``line`` (x) against ``yaxis`` (y) with a +/- ``std`` band, in a new figure.

    If ``xlabel`` mentions 'MBE', a dashed zero line is drawn to make the bias visible.
    If it mentions 'MAE' or 'RMSE', the x-axis starts at 0. A missing ``xlabel``
    (None) is treated as ''.

    Returns:
        The created figure.
    """
    xlabel_lc = (xlabel or '').lower()
    fig, ax = plt.subplots(1)
    if "mbe" in xlabel_lc:
        # to better see the bias
        ax.plot(np.zeros(yaxis.shape), yaxis, '--', color='grey')

    p = ax.plot(line, yaxis, '-', linewidth=3)
    if fill_between:
        ax.fill_betweenx(yaxis, line - std, line + std, alpha=0.2)
    else:
        ax.plot(line - std, yaxis, '--', color=p[0].get_color(), linewidth=1.5)
        ax.plot(line + std, yaxis, '--', color=p[0].get_color(), linewidth=1.5)

    xlim = [0, ax.get_xlim()[1]] if 'mae' in xlabel_lc or 'rmse' in xlabel_lc else None
    set_labels_and_ticks(ax=ax, xlabel=xlabel or '', xlim=xlim,
                         yticks=yaxis, ytick_labels=yticks,
                         ylabel=ylabel, show=show)

    return fig
183 |
184 |
def level_errors(Y_true, Y_preds, epoch):
    """Horizontal bar-style plot of the mean error ``Y_true - Y_preds`` per level.

    Each level gets a line from 0 to its mean error, annotated with the value rounded
    to 2 decimals. Negative annotations are red, others green. The x-axis is fixed to
    [-5, 5].

    Returns:
        The created figure.
    """
    errors = np.mean((Y_true - Y_preds), axis=0)
    index = np.arange(len(errors))

    # Draw plot
    lev_fig = plt.figure(figsize=(14, 14), dpi=80)
    plt.hlines(y=index, xmin=0, xmax=errors)
    for x, y in zip(errors, index):
        plt.text(x, y, round(x, 2), horizontalalignment='right' if x < 0 else 'left',
                 verticalalignment='center', fontdict={'color': 'red' if x < 0 else 'green', 'size': 10})

    # Styling
    plt.yticks(index, ['Level: ' + str(z) for z in index], fontsize=12)
    plt.title(f'Average Level-wise error for epoch: {epoch}', fontdict={'size': 20})
    plt.grid(linestyle='--', alpha=0.5)
    plt.xlim(-5, 5)

    return lev_fig
204 |
205 |
def profile_errors(Y_true, Y_preds, plot_profiles=200, var_name=None, data_dir: str = None,
                   error_type='mean', plot_type='scatter', set_seed=False, title=""):
    """Plot a global map of absolute prediction errors.

    Assumes the rows of ``Y_true`` / ``Y_preds`` (shape (n_profiles, n_levels)) are
    ordered snapshot-major over the lat x lon grid: row ``i`` is grid column
    ``i % (n_lat * n_lon)``, with latitude as the outer loop. This matches the
    ``reshape(-1, n_lat, n_lon, n_levels)`` used when writing predictions to NetCDF.
    NOTE(review): confirm for new data pipelines.

    Args:
        plot_profiles: number of randomly chosen profiles drawn for ``plot_type='scatter'``.
        var_name: variable name used in the title.
        data_dir: directory passed to ``get_coordinates`` to read the lat/lon grid.
        error_type: 'toa' (level 0), 'surface' (last level) or 'mean' (average over levels).
        plot_type: 'scatter' plots individual profiles. Other values ('heatmap', 'contour')
            plot the per-column error averaged over all snapshots on the lat x lon grid.
        set_seed: seed NumPy's global RNG with 7, so the same profiles are drawn each time.

    Raises:
        ValueError: for an unknown ``error_type``.
    """
    error_type = error_type.lower()
    if error_type not in ('toa', 'surface', 'mean'):
        raise ValueError(f"Unknown error_type '{error_type}', expected one of 'toa', 'surface', 'mean'.")

    coords_data = get_coordinates(data_dir)
    lat = list(coords_data.get_index('lat'))
    lon = list(coords_data.get_index('lon'))
    n_spatial = len(lat) * len(lon)

    total_profiles, n_levels = Y_true.shape

    if set_seed:  # To get the same profiles everytime
        np.random.seed(7)

    errors = np.abs(Y_true - Y_preds)

    if plot_type.lower() == 'scatter':
        # Coordinates of each grid column (latitude is the outer loop)...
        lat_flat = np.repeat(lat, len(lon))
        lon_flat = np.tile(lon, len(lat))
        # ...repeated over all snapshots so that there is one coordinate pair per row.
        lat_var = np.resize(lat_flat, total_profiles)
        lon_var = np.resize(lon_flat, total_profiles)

        n_plot = min(plot_profiles, total_profiles)
        indices = np.arange(total_profiles)
        # Sample the profiles to skip (same RNG usage as before, to keep seeded selections stable).
        indices_skipped = np.random.choice(total_profiles, total_profiles - n_plot, replace=False)
        indices_rest = np.setxor1d(indices_skipped, indices, assume_unique=True)

        lon_plot = lon_var[indices_rest]
        lat_plot = lat_var[indices_rest]
        errors_lev = errors[indices_rest].T  # (n_levels, n_plot)
    else:
        # (snapshot, spatial, level) -> mean over snapshots -> (spatial, level)
        errors_lev = errors.reshape(-1, n_spatial, n_levels).mean(axis=0)
        # -> (level, lat, lon) grid
        errors_lev = errors_lev.T.reshape((n_levels, len(lat), len(lon)))
        lon_plot, lat_plot = np.meshgrid(lon, lat)

    if error_type == 'toa':
        err = errors_lev[0]
    elif error_type == 'surface':
        err = errors_lev[-1]
    else:
        err = np.mean(errors_lev, axis=0)

    return profile_plot(lon_plot, lat_plot, err, var_name, plot_type=plot_type, title=title)
260 |
261 |
def profile_plot(lon_plot, lat_plot, errors, var_name=None, plot_type='scatter', dpi=70, title=""):
    """
    Draw ``errors`` on a global Robinson-projection map (requires cartopy).

    :param lon_plot: longitudes of the points / grid
    :param lat_plot: latitudes of the points / grid
    :param errors: error values used for coloring. Note that if plot_type is 2D, i.e. is in
        ['heatmap', 'contour'], it has to have shape lat x lon
    :param var_name: variable name shown in the title (None shows nothing)
    :param plot_type: 'scatter' (default), 'contour' or 'heatmap' (case-insensitive)
    :param dpi: written to the global ``plt.rcParams['figure.dpi']``
    :param title: prefix of the axes title
    :return: the created figure
    """
    import cartopy.crs as ccrs
    import cartopy.feature as cfeature
    fig = plt.figure(figsize=(12, 8))
    plt.rcParams['figure.dpi'] = dpi
    plot_type = plot_type.lower()

    ax = fig.add_subplot(1, 1, 1, projection=ccrs.Robinson())
    ax.add_feature(cfeature.COASTLINE)
    ax.gridlines()
    ax.set_global()

    fs_text = 10  # title font size
    circlesize = 100  # scatter marker size
    lw = 2.2  # scatter marker edge width

    if plot_type == 'contour':
        sc = ax.contourf(lon_plot, lat_plot, errors, transform=ccrs.PlateCarree(), alpha=0.85, cmap="Reds",
                         levels=100)
    elif plot_type == 'heatmap':
        sc = ax.pcolormesh(lon_plot, lat_plot, errors, cmap="Reds", transform=ccrs.PlateCarree())
    else:
        # plt.get_cmap instead of plt.cm.get_cmap, which was removed in Matplotlib 3.9.
        sc = ax.scatter(x=lon_plot, y=lat_plot, s=circlesize,
                        c=errors, norm=TwoSlopeNorm(5, vmin=0, vmax=10),
                        alpha=0.8, cmap=plt.get_cmap('plasma'), linewidths=lw,
                        transform=ccrs.PlateCarree())

    var_label = '' if var_name is None else str(var_name).upper()
    ax.set_title(f'{title}{plot_type.upper()}-{var_label} error', fontsize=fs_text)
    # Colour Bar
    cbar = plt.colorbar(sc, ax=ax, aspect=30, pad=0.01, shrink=0.4, orientation='vertical')
    cbar.ax.set_ylabel('W m$^{-2}$', rotation=270, labelpad=10)
    return fig
315 |
316 |
def prediction_hist(preds: dict, TOA=False, surface=False,
                    title="", show=True, figsize=(16, 12), axes=None,
                    label="", **kwargs):
    """Histogram the predictions of each variable, one row per variable.

    Column 0 shows the per-sample mean over levels. Optional further columns show the
    surface (last level) and TOA (level 0) values.

    Args:
        preds: mapping var_name -> array of shape (n_samples, n_levels).
        axes: an existing 2D axes grid (e.g. a previous return value) to draw into.
            When None, a new (n_vars, n_cols) grid is created.
        label: legend label of this histogram set.
        **kwargs: forwarded to ``Axes.hist``.

    Returns:
        The 2D array of axes.
    """
    n_vars = len(preds.keys())
    n_cols = 3 if TOA and surface else 2 if (TOA or surface) else 1

    surface_ax = 1
    TOA_ax = 2 if surface else 1

    if axes is None:
        # squeeze=False keeps a 2D grid even when n_vars or n_cols is 1, so axs[row, col] always works.
        fig, axs = plt.subplots(n_vars, n_cols, figsize=figsize, squeeze=False)
        fig.suptitle("Prediction magnitudes" if title == "" else title)
        axs[0, 0].set_title('Mean')

        if surface:
            axs[0, surface_ax].set_title('Surface')
        if TOA:
            axs[0, TOA_ax].set_title('TOA')
    else:
        axs = axes

    for (var_name, var_preds), ax_row in zip(preds.items(), axs):
        ax_row[0].hist(np.mean(var_preds, axis=1), label=label, **kwargs)
        ax_row[0].set_ylabel(f"{var_name.upper()}", fontsize=20)
        if surface:
            ax_row[surface_ax].hist(var_preds[:, -1], label=label, **kwargs)
        if TOA:
            ax_row[TOA_ax].hist(var_preds[:, 0], label=label, **kwargs)

    axs[0, 0].legend()
    if show:
        plt.show()

    return axs
363 |
364 |
def prediction_bars(preds: dict, bins, TOA=False, surface=False,
                    title="", show=True, figsize=(16, 12), axes=None,
                    label="", **kwargs):
    """Like ``prediction_hist``, but bins with ``np.histogram(..., bins)`` and draws unit-width bars.

    Bars are placed at positions 0..len(hist)-1 rather than at the bin edges.

    Args:
        preds: mapping var_name -> array of shape (n_samples, n_levels).
        bins: forwarded to ``np.histogram``.
        axes: an existing 2D axes grid to draw into. When None, a new grid is created.
        **kwargs: forwarded to ``Axes.bar``.

    Returns:
        The 2D array of axes.
    """
    n_vars = len(preds.keys())
    n_cols = 3 if TOA and surface else 2 if (TOA or surface) else 1

    surface_ax = 1
    TOA_ax = 2 if surface else 1

    if axes is None:
        # squeeze=False keeps a 2D grid even when n_vars or n_cols is 1, so axs[row, col] always works.
        fig, axs = plt.subplots(n_vars, n_cols, figsize=figsize, squeeze=False)
        fig.suptitle("Prediction magnitudes" if title == "" else title)
        axs[0, 0].set_title('Mean')

        if surface:
            axs[0, surface_ax].set_title('Surface')
        if TOA:
            axs[0, TOA_ax].set_title('TOA')
    else:
        axs = axes

    for (var_name, var_preds), ax_row in zip(preds.items(), axs):
        hist, _ = np.histogram(np.mean(var_preds, axis=1), bins)
        ax_row[0].bar(range(len(hist)), hist, width=1, align='center', label=label, **kwargs)
        ax_row[0].set_ylabel(f"{var_name.upper()}", fontsize=20)
        if surface:
            hist, _ = np.histogram(var_preds[:, -1], bins)
            ax_row[surface_ax].bar(range(len(hist)), hist, width=1, align='center', label=label, **kwargs)
        if TOA:
            hist, _ = np.histogram(var_preds[:, 0], bins)
            ax_row[TOA_ax].bar(range(len(hist)), hist, width=1, align='center', label=label, **kwargs)

    axs[0, 0].legend()
    if show:
        plt.show()

    return axs
406 |
--------------------------------------------------------------------------------
/climart/utils/postprocessing.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | import wandb
4 | import os
5 |
6 | from functools import partial
7 | import torch
8 | import xarray as xr
9 | import numpy as np
10 | from climart.data_loading.constants import LEVELS, LAYERS, GLOBALS, get_data_dims, get_metadata, get_coordinates
11 |
def get_lat_lon(data_dir: str = None):
    """Read the lat/lon grid via ``get_coordinates`` and also return per-grid-column flattened copies.

    Returns:
        dict with 'latitude' / 'longitude' (lists of unique values) and
        'latitude_flattened' / 'longitude_flattened' (arrays of length n_lat * n_lon,
        latitude as the outer loop).
    """
    coords_data = get_coordinates(data_dir)
    lat = list(coords_data.get_index('lat'))
    lon = list(coords_data.get_index('lon'))

    lat_var = np.array([lat_value for lat_value in lat for _ in lon])
    lon_var = np.array([lon_value for _ in lat for lon_value in lon])
    return {'latitude': lat, 'longitude': lon, 'latitude_flattened': lat_var, 'longitude_flattened': lon_var}
26 |
27 |
28 | # %%
29 |
def save_preds_to_netcdf(preds,
                         targets,
                         post_fix: str = '',
                         save_path=None,
                         exp_type='pristine',
                         data_dir: str = None,
                         model=None,
                         **kwargs):
    """Bundle predictions, targets and pressure inputs into an ``xr.Dataset`` and optionally write it.

    Every array is un-flattened with ``reshape(-1, n_lat, n_lon, ...)``, so rows are
    expected to be ordered (snapshot, lat, lon).

    Args:
        preds / targets: mapping var_name -> array with n_levels columns. These are stored
            as '<var>_preds' / '<var>_targets'.
        post_fix: inserted before the '.nc' extension of ``save_path``, or used as the
            file-name suffix when the path is derived from ``model``.
        save_path: output file ('.nc' is appended if missing).
        model: used to derive a default file name when ``save_path`` is None. When both
            are None, nothing is written.
        **kwargs: must contain 'pressure' (levels), 'layer_pressure' (layers) and 'cszrow' (globals).

    Returns:
        The dataset. An existing file at the target path is never overwritten.
    """
    lat_lon = get_lat_lon(data_dir)
    lat, lon = lat_lon['latitude'], lat_lon['longitude']
    n_lat, n_lon = len(lat), len(lon)
    spatial_dim, _ = get_data_dims(exp_type)
    n_levels = spatial_dim[LEVELS]
    n_layers = spatial_dim[LAYERS]
    shape = ['snapshot', 'latitude', 'longitude', 'level']
    shape_lay = ['snapshot', 'latitude', 'longitude', 'layer']
    shape_glob = ['snapshot', 'latitude', 'longitude']

    data_vars = dict()
    for k, v in preds.items():
        data_vars[f"{k}_preds"] = (shape, v.reshape((-1, n_lat, n_lon, n_levels)))
    for k, v in targets.items():
        data_vars[f"{k}_targets"] = (shape, v.reshape((-1, n_lat, n_lon, n_levels)))

    data_vars["pressure"] = (shape, kwargs['pressure'].reshape((-1, n_lat, n_lon, n_levels)))
    data_vars["layer_pressure"] = (shape_lay, kwargs['layer_pressure'].reshape((-1, n_lat, n_lon, n_layers)))
    data_vars["cszrow"] = (shape_glob, kwargs['cszrow'].reshape((-1, n_lat, n_lon)))

    xr_dset = xr.Dataset(
        data_vars=data_vars,
        coords=dict(
            longitude=lon,
            latitude=lat,
            level=list(range(n_levels))[::-1],
            layer=list(range(n_layers))[::-1],
        ),
        attrs=dict(description="ML emulated RT outputs."),
    )
    if save_path is not None:
        if not save_path.endswith('.nc'):
            save_path += '.nc'
        # Insert post_fix right before the extension (str.replace returns a new string, so assign it).
        save_path = save_path[:-len('.nc')] + post_fix + '.nc'
    elif model is not None:
        save_path = f'./example_{exp_type}_preds_{model}_{post_fix}.nc'
    else:
        print("Not saving to NC!")
        return xr_dset

    if os.path.isfile(save_path):
        print('File already exists, not overwriting:\n', save_path)
    else:
        xr_dset.to_netcdf(save_path)
        print('saved to\n', save_path)
    return xr_dset
83 |
84 |
def restore_run(run_id,
                run_path: str,
                api=None,
                ):
    """Fetch the wandb run ``"<run_path>/<run_id>"`` through the public API.

    A fresh ``wandb.Api()`` is created when ``api`` is None.
    """
    wandb_api = wandb.Api() if api is None else api
    full_run_path = f"{run_path}/{run_id}"
    return wandb_api.run(full_run_path)
94 |
95 |
def restore_ckpt_from_wandb_run(run, entity: str = 'ecc-mila7', run_path=None, load: bool = False, **kwargs):
    """Download the ``<run_id>.pkl`` checkpoint of a wandb ``run``.

    The file is restored from ``"<entity>/ClimART/<run_id>"``. ``run_path`` is accepted
    but currently unused.

    Returns:
        The local checkpoint file name, or the ``torch.load``-ed object when ``load`` is
        True. ``kwargs`` are forwarded to ``torch.load``. An IndexError is raised if the
        run has no matching file.
    """
    run_id = run.id
    ckpt_file = [f for f in run.files() if f"{run_id}.pkl" in str(f)][0]
    restored = wandb.restore(str(ckpt_file.name), run_path=f"{entity}/ClimART/{run_id}")
    ckpt_fname = restored.name
    if not load:
        return ckpt_fname
    return torch.load(ckpt_fname, **kwargs)
105 |
106 |
def restore_ckpt_from_wandb(run_id,
                            run_path: str,
                            api=None,
                            load: bool = False,
                            **kwargs):
    """Look up the wandb run ``"<run_path>/<run_id>"`` and restore its checkpoint.

    ``load`` and ``kwargs`` are forwarded to ``restore_ckpt_from_wandb_run``, so kwargs
    may contain ``entity`` or ``torch.load`` options.
    """
    run = restore_run(run_id, run_path, api)
    # Delegate to the run-based helper. Calling this function again would recurse
    # with the run object as `run_id` and no `run_path`.
    return restore_ckpt_from_wandb_run(run, load=load, **kwargs)
114 |
--------------------------------------------------------------------------------
/climart/utils/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: Salva Rühling Cachay
3 | """
4 | import functools
5 | import logging
6 | import math
7 | import os
8 |
9 | from types import SimpleNamespace
10 | from typing import Union, Sequence, List, Dict, Optional, Callable
11 |
12 | import numpy as np
13 | import torch
14 | import torch.nn as nn
15 | import torch.nn.functional as F
16 | from omegaconf import DictConfig, open_dict, OmegaConf
17 | from torch import Tensor
18 | from pytorch_lightning.utilities import rank_zero_only
19 | from climart.data_loading import constants, data_variables
20 |
21 |
def no_op(*args, **kwargs):
    """Accept any arguments, do nothing, and return None."""
    return None
24 |
25 |
def get_identity_callable(*args, **kwargs) -> Callable:
    """Ignore all arguments and return the module-level ``identity`` function."""
    del args, kwargs  # accepted only so this can be used as a generic factory
    return identity
28 |
29 |
def get_activation_function(name: str, functional: bool = False, num: int = 1):
    """Look up an activation function by (case-insensitive, whitespace-stripped) name.

    Args:
        name: one of 'softmax', 'relu', 'tanh', 'sigmoid', 'identity', 'silu' / 'swish',
            'elu', 'gelu', 'prelu'. ``None`` returns ``None``.
        functional: if True (and ``num == 1``), return the functional form (e.g. ``F.relu``)
            instead of a new ``nn.Module``. 'identity' and 'prelu' return a new module either way.
        num: if > 1, return a list of ``num`` independent ``nn.Module`` instances
            (``functional`` is ignored in that case).

    Raises:
        KeyError: for an unknown activation name.
    """
    if name is None:
        return None
    name = name.lower().strip()

    functional_acts = {
        "softmax": F.softmax, "relu": F.relu, "tanh": torch.tanh, "sigmoid": torch.sigmoid,
        'swish': F.silu, 'silu': F.silu, 'elu': F.elu, 'gelu': F.gelu,
    }
    # Factories rather than instances: only the requested module is built, and every call gets fresh modules.
    module_factories = {
        "softmax": lambda: nn.Softmax(dim=1), "relu": nn.ReLU, "tanh": nn.Tanh, "sigmoid": nn.Sigmoid,
        "identity": nn.Identity, 'silu': nn.SiLU, 'elu': nn.ELU, 'prelu': nn.PReLU,
        'swish': nn.SiLU, 'gelu': nn.GELU,
    }

    def get_nn(s: str) -> nn.Module:
        if s not in module_factories:
            raise KeyError(f"Unknown activation function '{s}'")
        return module_factories[s]()

    def get_functional(s: str) -> Optional[Callable]:
        return functional_acts[s] if s in functional_acts else get_nn(s)

    if num == 1:
        return get_functional(name) if functional else get_nn(name)
    return [get_nn(name) for _ in range(num)]
49 |
50 |
def get_normalization_layer(name, dims, num_groups=None, *args, **kwargs):
    """Build a normalization layer by (case-insensitive) name.

    Args:
        name: a string containing 'batch' (BatchNorm1d), 'layer' (LayerNorm), 'inst'
            (InstanceNorm1d) or 'group' (GroupNorm), checked in that order. A non-string
            or 'none' returns None.
        dims: number of features / channels (LayerNorm: normalized shape).
        num_groups: GroupNorm only. Defaults to ``max(1, int(dims / 10))``.
        *args, **kwargs: forwarded to the batch/layer/instance norm constructors.

    Raises:
        ValueError: for an unknown normalization name.
    """
    if not isinstance(name, str) or name.lower() == 'none':
        return None
    key = name.lower()
    # dims is passed positionally so that extra *args follow it (`num_features=dims, *args` raised TypeError).
    if 'batch' in key:
        return nn.BatchNorm1d(dims, *args, **kwargs)
    elif 'layer' in key:
        return nn.LayerNorm(dims, *args, **kwargs)
    elif 'inst' in key:
        return nn.InstanceNorm1d(dims, *args, **kwargs)
    elif 'group' in key:
        if num_groups is None:
            num_groups = max(1, int(dims / 10))  # at least one group, even for dims < 10
        return nn.GroupNorm(num_groups=num_groups, num_channels=dims)
    else:
        raise ValueError("Unknown normalization name", name)
66 |
67 |
def identity(X):
    """Return ``X`` itself, unchanged."""
    result = X
    return result
70 |
71 |
def to_dict(obj: Optional[Union[dict, SimpleNamespace]]):
    """Return ``obj`` as a dict.

    None becomes ``{}``, a dict is returned as-is (same object), and anything else
    gives its ``vars(obj)`` attribute dict.
    """
    if isinstance(obj, dict):
        return obj
    return dict() if obj is None else vars(obj)
79 |
80 |
def to_DictConfig(obj: Optional[Union[List, Dict]]):
    """Convert ``obj`` to an OmegaConf config.

    The conversion rules are:
      - a DictConfig is returned unchanged;
      - a dict is passed to ``OmegaConf.create``;
      - a list is first parsed as a dotlist (``["a.b=1", ...]``); if that raises
        ValueError, it falls back to ``OmegaConf.create(obj)``;
      - anything else gives an empty config.
    """
    if isinstance(obj, DictConfig):
        return obj
    if isinstance(obj, dict):
        return OmegaConf.create(obj)
    if isinstance(obj, list):
        try:
            return OmegaConf.from_dotlist(obj)
        except ValueError:
            return OmegaConf.create(obj)
    return OmegaConf.create()  # empty
98 |
99 |
def get_logger(name=__name__, level=logging.INFO) -> logging.Logger:
    """Initializes multi-GPU-friendly python logger.

    Each logging method is wrapped with ``rank_zero_only``, so that messages are not
    repeated once per GPU process in a multi-GPU setup.
    """
    logger = logging.getLogger(name)
    logger.setLevel(level)

    wrapped_methods = ("debug", "info", "warning", "error", "exception", "fatal", "critical")
    for method_name in wrapped_methods:
        original_method = getattr(logger, method_name)
        setattr(logger, method_name, rank_zero_only(original_method))

    return logger
111 |
112 |
113 | #####
# The following two functions extend setattr and getattr to support chained objects, e.g. rsetattr(cfg, 'optim.lr', 1e-4)
115 | # From https://stackoverflow.com/questions/31174295/getattr-and-setattr-on-nested-subobjects-chained-properties
def rsetattr(obj, attr, val):
    """Nested ``setattr``: ``rsetattr(cfg, 'optim.lr', 1e-4)`` sets ``cfg.optim.lr = 1e-4``."""
    parent_path, _, leaf_name = attr.rpartition('.')
    owner = rgetattr(obj, parent_path) if parent_path else obj
    return setattr(owner, leaf_name, val)
119 |
120 |
def rgetattr(obj, attr, *args):
    """Nested ``getattr``: ``rgetattr(cfg, 'optim.lr')`` returns ``cfg.optim.lr``.

    An optional default (``*args``) is passed to every intermediate ``getattr`` call.
    """
    current = obj
    for part in attr.split('.'):
        current = getattr(current, part, *args)
    return current
126 |
127 |
def adj_to_edge_indices(adj: Union[torch.Tensor, np.ndarray]) -> Union[torch.Tensor, np.ndarray]:
    """
    Args:
        adj: a (N, N) adjacency matrix, where N is the number of nodes
    Returns:
        A (2, E) array, edge_idxs, where E is the number of edges (non-zero entries),
        and edge_idxs[0], edge_idxs[1] are the source & destination nodes, respectively.
        The result is a torch tensor for tensor input and a NumPy array otherwise.
    """
    if torch.is_tensor(adj):
        nonzero = torch.nonzero(adj, as_tuple=True)
        return torch.stack((nonzero[0], nonzero[1]), dim=0)
    nonzero = np.nonzero(adj)
    return np.stack((nonzero[0], nonzero[1]), axis=0)
144 |
145 |
def normalize_adjacency_matrix_torch(adj: Tensor, improved: bool = True, add_self_loops: bool = False):
    """Symmetrically normalize a dense adjacency matrix: ``D^{-1/2} A D^{-1/2}``.

    Args:
        adj: (N, N) float tensor. It is not modified.
        improved: with self-loops, set the diagonal to 2 instead of 1.
        add_self_loops: overwrite the diagonal with the self-loop weight before normalizing.

    Returns:
        The normalized (N, N) tensor. Rows/columns of nodes with zero degree are all zero.
    """
    if add_self_loops:
        fill_value = 2. if improved else 1.
        # Clone first: fill_diagonal_ is in-place and would otherwise modify the caller's tensor.
        adj = adj.clone().fill_diagonal_(fill_value)
    deg: Tensor = torch.sum(adj, dim=1)
    deg_inv_sqrt: Tensor = deg.pow(-0.5)
    deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0.)  # isolated nodes
    adj_t = torch.mul(adj, deg_inv_sqrt.view(-1, 1))
    adj_t = torch.mul(adj_t, deg_inv_sqrt.view(1, -1))
    return adj_t
156 |
157 |
def normalize_adjacency_matrix(adj: np.ndarray, improved: bool = True, add_self_loops: bool = True):
    """Symmetrically normalize a dense adjacency matrix: ``D^{-1/2} A D^{-1/2}`` (NumPy version).

    Args:
        adj: (N, N) array-like. It is copied, so the caller's array is not modified.
        improved: with self-loops, set the diagonal to 2 instead of 1.
        add_self_loops: overwrite the diagonal with the self-loop weight before normalizing.

    Returns:
        The normalized (N, N) array. Rows/columns of nodes with zero degree are all zero.
    """
    adj = np.array(adj)  # copy; np.fill_diagonal below works in place
    if add_self_loops:
        fill_value = 2. if improved else 1.
        np.fill_diagonal(adj, fill_value)
    deg = np.sum(adj, axis=1)
    with np.errstate(divide='ignore'):  # zero degree -> inf, handled just below
        deg_inv_sqrt = np.power(deg, -0.5)
    deg_inv_sqrt[np.isinf(deg_inv_sqrt)] = 0.

    # Broadcasting is equivalent to diag(d) @ adj @ diag(d), without the O(N^3) matmuls.
    return adj * deg_inv_sqrt[:, None] * deg_inv_sqrt[None, :]
169 |
170 |
def set_gpu(gpu_id):
    """Expose only ``gpu_id`` to CUDA, numbering devices by PCI bus id."""
    os.environ.update({
        "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
        "CUDA_VISIBLE_DEVICES": str(gpu_id),
    })
174 |
175 |
def set_seed(seed, device='cuda'):
    """Seed the ``random``, NumPy and PyTorch RNGs.

    Unless ``device == 'cpu'``, it also makes cuDNN deterministic and seeds all CUDA devices.
    """
    import random
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if device == 'cpu':
        return
    torch.backends.cudnn.deterministic = True
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
186 |
187 |
def year_string_to_list(year_string: str) -> List[int]:
    """
    Args:
        year_string (str): must only contain {digits, '-', '+'} (whitespace around the
            parts is ignored). Non-strings such as the int 1990 are converted with ``str``.
            '+' separates parts; each part is a single year or a 'start-end' range
            (inclusive). A part with an empty end ('1990-') is just the start year.
            Two-digit years starting with 0-2 map to 20xx, all others to 19xx.
    Examples:
        '1988-90' will return [1988, 1989, 1990]
        '1988-1990+2001-2004' will return [1988, 1989, 1990, 2001, 2002, 2003, 2004]
    Raises:
        ValueError: if a year is not a 2- or 4-digit number (this includes empty parts and
            ranges with more than one '-').
    """
    year_string = str(year_string)

    def to_full_year(token: str) -> int:
        token = token.strip()
        if not token.isdigit() or len(token) not in (2, 4):
            raise ValueError(f"Year '{token}' is not a 2- or 4-digit year (in '{year_string}').")
        if len(token) == 4:
            return int(token)
        return int(('20' if int(token[0]) < 3 else '19') + token)

    years: List[int] = []
    for segment in year_string.split('+'):
        start_token, _, end_token = segment.partition('-')
        start = to_full_year(start_token)
        end = to_full_year(end_token) if end_token.strip() else start
        years.extend(range(start, end + 1))
    return years
232 |
233 |
def target_var_id_mapping(x, y):
    """Map a target-variable name ``x`` and a wave type ``y`` to its short variable id.

    Args:
        x: target variable, e.g. 'heating_rate', 'upwelling_flux' or 'downwelling_fluxes'.
            Case-insensitive; '-' and '_' are interchangeable and plural forms are accepted.
        y: wave type. 'longwave' or 'lw' (case-insensitive, '-'/'_' and surrounding
            whitespace ignored) selects longwave; anything else selects shortwave.

    Returns:
        'hr{k}c', 'r{k}uc' or 'r{k}dc', where k is 'l' for longwave and 's' for shortwave
        (e.g. 'hrsc' for the shortwave heating rate).

    Raises:
        ValueError: if ``x`` is not a heating-rate, upwelling or downwelling variable.
    """
    wave_type = str(y).strip().lower().replace('_', '').replace('-', '')
    # Accept the same 'lw' alias that get_target_types accepts for longwave, instead of
    # silently treating it as shortwave.
    k = 'l' if wave_type in ('longwave', 'lw') else 's'
    variable = str(x).strip().lower().replace('-', '_').replace('rates', 'rate')
    variable = variable.replace('_fluxes', '').replace('_flux', '')
    ids = {
        'heating_rate': f'hr{k}c',
        'upwelling': f"r{k}uc",
        'downwelling': f"r{k}dc",
    }
    if variable not in ids:
        raise ValueError(f"Combination {x} {y} not understood!")
    return ids[variable]
245 |
246 |
def get_target_types(target_type: Union[str, List[str]]) -> List[str]:
    """Resolve ``target_type`` into a list of wave types.

    Args:
        target_type: either a list (which must only contain ``constants.SHORTWAVE`` /
            ``constants.LONGWAVE`` and is returned unchanged), or a string alias such as
            'sw', 'longwave' or 'sw+lw'. String aliases are case-insensitive, '&' is treated
            like '+' and '-' is ignored.

    Returns:
        The list of wave-type constants.

    Raises:
        ValueError: if a string alias is not recognised.
    """
    shortwave, longwave = constants.SHORTWAVE, constants.LONGWAVE
    if isinstance(target_type, list):
        assert all(t in [shortwave, longwave] for t in target_type)
        return target_type

    alias = target_type.lower().replace('&', '+').replace('-', '')
    if alias in ('sw+lw', 'lw+sw', 'shortwave+longwave', 'longwave+shortwave'):
        return [shortwave, longwave]
    if alias in ('sw', 'shortwave'):
        return [shortwave]
    if alias in ('lw', 'longwave'):
        return [longwave]
    raise ValueError(f"Target type `{target_type}` must be one of shortwave, longwave or shortwave+longwave")
260 |
261 |
def get_target_variable_names(target_types: Union[str, List[str]],
                              target_variable: Union[str, List[str]]) -> List[str]:
    """Return the names of the output variables to predict.

    Args:
        target_types: wave type(s), resolved with ``get_target_types``.
        target_variable: if a list with anything other than exactly one entry, every entry
            is validated (case-insensitively) against the known no-cloud output variables
            and the list is returned unchanged. Otherwise (a string or a single-element
            list) it does not influence the result.

    Returns:
        For each requested wave type, its no-cloud output variables plus its heating
        rate (longwave first, then shortwave).

    Raises:
        AssertionError: if a multi-entry ``target_variable`` contains unknown names.
        ValueError: if ``target_types`` resolves to neither shortwave nor longwave.
    """
    if isinstance(target_variable, list) and len(target_variable) != 1:
        out_vars = data_variables.OUT_SHORTWAVE_NOCLOUDS + data_variables.OUT_LONGWAVE_NOCLOUDS \
                   + data_variables.OUT_HEATING_RATE_NOCLOUDS
        err_msg = f"Each target var must be in {out_vars}, but got {target_variable}"
        assert all(t.lower() in out_vars for t in target_variable), err_msg
        return target_variable

    target_types = get_target_types(target_types)
    target_vars: List[str] = []
    if constants.LONGWAVE in target_types:
        target_vars += data_variables.OUT_LONGWAVE_NOCLOUDS + [data_variables.LW_HEATING_RATE]
    if constants.SHORTWAVE in target_types:
        target_vars += data_variables.OUT_SHORTWAVE_NOCLOUDS + [data_variables.SW_HEATING_RATE]

    # Only reachable when get_target_types returned no wave type (e.g. an empty list was passed).
    if not target_vars:
        raise ValueError(f"{target_types}, {target_variable} resulted in zero target_vars!")
    return target_vars
287 |
288 |
def get_target_variable(target_variable: Union[str, List[str]]) -> List[str]:
    """Resolve ``target_variable`` into a list of target-variable group constants.

    Args:
        target_variable: a string such as 'fluxes', 'heating_rate' or 'fluxes+heating_rate'
            (case-insensitive; '-' and '_' are ignored, '&' is treated like '+').
            A list is returned unchanged unless it has exactly one entry containing 'flux',
            in which case that entry is parsed as a string.

    Returns:
        ``[SURFACE_FLUXES, TOA_FLUXES, HEATING_RATES]`` for a pure heating-rate request;
        otherwise ``FLUXES`` and/or ``HEATING_RATES`` depending on which keywords occur.

    Raises:
        ValueError: if neither fluxes nor heating rates are requested.
    """
    if isinstance(target_variable, list):
        if len(target_variable) != 1 or 'flux' not in target_variable[0]:
            return target_variable
        target_variable = target_variable[0]

    key = target_variable.lower()
    for old, new in (('&', '+'), ('-', ''), ('_', ''), ('fluxes', 'flux'), ('heatingrate', 'hr')):
        key = key.replace(old, new)

    if key == 'hr':
        return [constants.SURFACE_FLUXES, constants.TOA_FLUXES, constants.HEATING_RATES]

    selected: List[str] = []
    if 'flux' in key:
        selected.append(constants.FLUXES)
    if 'hr' in key:
        selected.append(constants.HEATING_RATES)
    if not selected:
        raise ValueError(f"Target var `{key}` must be one of fluxes, heating_rate.")
    return selected
309 |
310 |
def pressure_from_level_array(levels_array):
    """Select the pressure feature from an array of level variables.

    Pressure is stored at position 2 of the last (feature) axis. All leading axes are kept,
    so e.g. a (N, L, D-lev) input yields a (N, L) output.
    """
    pressure_feature = 2
    return levels_array[..., pressure_feature]
314 |
315 |
def layer_thickness_from_layer_array(layers_array):
    """Select the layer-thickness feature from an array of layer variables.

    The thickness is stored at position 12 of the last (feature) axis. All leading axes are
    kept, so the result has one dimension fewer than the input.
    """
    thickness_feature = 12
    return layers_array[..., thickness_feature]
319 |
320 |
def fluxes_to_heating_rates(upwelling_flux: Union[np.ndarray, Tensor],
                            downwelling_flux: Union[np.ndarray, Tensor],
                            pressure: Union[np.ndarray, Tensor],
                            c: float = 9.761357e-03
                            ) -> Union[np.ndarray, Tensor]:
    """
    Convert level fluxes into layer heating rates.

    N - the batch/data dimension size
    L - the number of levels (= number of layers + 1)

    For layer i the heating rate is ``c * (net[:, i+1] - net[:, i]) / (p[:, i+1] - p[:, i])``
    with ``net = upwelling_flux - downwelling_flux``.

    Args:
        upwelling_flux: a (N, L) array
        downwelling_flux: a (N, L) array
        pressure: a (N, L) array representing the levels pressure
            or a (N, L, D-lev) array containing *all* level variables (including pressure)
        c: scaling constant. The default 9.761357e-03 equals g / c_p with g = 9.81 and
            c_p = 1004.98322108 (the 3D radiative effects paper uses 9.81 / 1004 = 0.00977091633).

    Returns:
        A (N, L-1) array representing the heating rates at each of the L-1 layers

    Raises:
        AssertionError: if the flux and pressure shapes are inconsistent.
    """
    err_msg = f"Upwelling/Downwelling arg should have same shape, but have {upwelling_flux.shape}, {downwelling_flux.shape}"
    assert upwelling_flux.shape == downwelling_flux.shape, err_msg
    if len(pressure.shape) <= 2:
        err_msg = f"pressure arg has not the expected shape (N, #levels), but has shape {pressure.shape}"
        assert downwelling_flux.shape == pressure.shape, err_msg
    else:
        err_msg = "pressure argument is not the expected levels array of shape (N, #levels, #level-vars)"
        assert len(pressure.shape) == 3, err_msg
        pressure = pressure_from_level_array(pressure)

    # ``c`` is taken from the argument (it used to be silently overwritten by the default here).
    net_level_flux = upwelling_flux - downwelling_flux
    net_layer_flux = net_level_flux[:, 1:] - net_level_flux[:, :-1]
    pressure_diff = pressure[:, 1:] - pressure[:, :-1]
    heating_rate = c * net_layer_flux / pressure_diff

    assert tuple(heating_rate.shape) == (pressure.shape[0], pressure.shape[1] - 1)

    return heating_rate
360 |
--------------------------------------------------------------------------------
/climart/utils/wandb_callbacks.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 | from pathlib import Path
3 | import wandb
4 | import pytorch_lightning as pl
5 | from pytorch_lightning import Callback, Trainer
6 | from pytorch_lightning.utilities import rank_zero_only
7 | from pytorch_lightning.loggers import LoggerCollection, WandbLogger
8 | from climart.utils.utils import get_logger
9 |
# Module-level logger shared by the callbacks below, created via the project's get_logger helper.
log = get_logger(__name__)
11 |
12 |
def get_wandb_logger(trainer: Trainer) -> WandbLogger:
    """Return the WandbLogger attached to ``trainer``.

    ``trainer.logger`` is used directly if it is a WandbLogger; if it is a LoggerCollection,
    its first WandbLogger member is returned.

    Raises:
        Exception: if ``trainer.fast_dev_run`` is set (loggers are disabled in that mode),
            or if no WandbLogger can be found.
    """
    if trainer.fast_dev_run:
        raise Exception(
            "Cannot use wandb callbacks since pytorch lightning disables loggers in `fast_dev_run=true` mode."
        )

    attached = trainer.logger
    if isinstance(attached, WandbLogger):
        return attached

    members = attached if isinstance(attached, LoggerCollection) else ()
    found = next((member for member in members if isinstance(member, WandbLogger)), None)
    if found is not None:
        return found

    raise Exception(
        "You are using wandb related callback, but WandbLogger was not found for some reason..."
    )
32 |
33 |
class WatchModel(Callback):
    """Make wandb watch the trainer's model when training starts.

    Args:
        log: passed through as ``log=`` to wandb's watch (e.g. "gradients" or "all").
        log_freq: passed through as ``log_freq=``.
    """

    def __init__(self, log: str = "gradients", log_freq: int = 100):
        self.log, self.log_freq = log, log_freq

    @rank_zero_only
    def on_train_start(self, trainer, pl_module):
        wandb_logger: WandbLogger = get_wandb_logger(trainer=trainer)
        watch_kwargs = dict(log=self.log, log_freq=self.log_freq)
        try:
            wandb_logger.watch(model=trainer.model, log_graph=True, **watch_kwargs)
        except TypeError:
            # Older versions reject the ``log_graph`` kwarg: fall back to wandb.watch without it.
            log.info(
                f" Pytorch-lightning/Wandb version seems to be too old to support 'log_graph' arg in wandb.watch(.)"
                f" Wandb version={wandb.__version__}")
            wandb.watch(models=trainer.model, **watch_kwargs)
51 |
52 |
class SummarizeBestValMetric(Callback):
    """Make wandb log in run.summary the best achieved monitored val_metric as opposed to the last.

    The metric name and the summary mode are read from the LightningModule's
    ``hparams.monitor`` and ``hparams.mode`` and forwarded to wandb's ``define_metric``.
    NOTE(review): assumes the module's hparams define ``mode`` - confirm for every model.
    """

    @rank_zero_only
    def on_train_start(self, trainer, pl_module):
        logger: WandbLogger = get_wandb_logger(trainer=trainer)
        experiment = logger.experiment
        # Use ``pl_module`` rather than ``trainer.model``: the latter may be the module wrapped
        # by the training strategy (e.g. DistributedDataParallel), which has no ``hparams``.
        hparams = pl_module.hparams
        experiment.define_metric(hparams.monitor, summary=hparams.mode)
61 |
62 |
class UploadCodeAsArtifact(Callback):
    """Upload all code files to wandb as an artifact, at the beginning of the run."""

    def __init__(self, code_dir: str, use_git: bool = True):
        """
        Args:
            code_dir: the code directory (absolute or relative to the current working directory)
            use_git: if using git, then upload all files that are not ignored by git.
                if not using git, then upload all '*.py' file
        """
        self.code_dir = code_dir
        self.use_git = use_git

    @rank_zero_only
    def on_train_start(self, trainer, pl_module):
        logger = get_wandb_logger(trainer=trainer)
        experiment = logger.experiment

        code = wandb.Artifact("project-source", type="code")
        # Resolve once: rglob() on the resolved root yields absolute paths, and relative_to()
        # must be computed against that same root (a relative code_dir would raise ValueError).
        code_root = Path(self.code_dir).resolve()

        if self.use_git:
            # get .git folder path
            git_dir_path = Path(
                subprocess.check_output(["git", "rev-parse", "--git-dir"]).strip().decode("utf8")
            ).resolve()

            for path in code_root.rglob("*"):
                if not path.is_file():
                    continue
                # don't upload files from .git folder; compare path components rather than string
                # prefixes so that e.g. '.github/' is not mistaken for '.git/'
                if path == git_dir_path or git_dir_path in path.parents:
                    continue
                # don't upload files ignored by git ('git check-ignore' exits with 1 if NOT ignored)
                # https://alexwlchan.net/2020/11/a-python-function-to-ignore-a-path-with-git-info-exclude/
                command = ["git", "check-ignore", "-q", str(path)]
                if subprocess.run(command).returncode != 1:
                    continue
                code.add_file(str(path), name=str(path.relative_to(code_root)))

        else:
            for path in code_root.rglob("*.py"):
                code.add_file(str(path), name=str(path.relative_to(code_root)))

        experiment.log_artifact(code)
107 |
108 |
class UploadBestCheckpointAsFile(Callback):
    """Upload the best checkpoint to wandb as a file, at the end of the run (also on exceptions)."""

    @rank_zero_only
    def on_train_start(self, trainer, pl_module):
        # The trainer attribute exists but may be None when no ModelCheckpoint is configured,
        # so hasattr() alone is not a sufficient check.
        if getattr(trainer, 'checkpoint_callback', None) is None:
            log.warning("pl.Trainer has no checkpoint_callback/ModelCheckpoint() callback even though you use"
                        " UploadBestCheckpointAsFile - This callback will be ignored!")

    @rank_zero_only
    def on_exception(self, trainer: pl.Trainer, pl_module: pl.LightningModule, exception: BaseException) -> None:
        self.on_train_end(trainer, pl_module)

    @rank_zero_only
    def on_train_end(self, trainer, pl_module):
        checkpoint_callback = getattr(trainer, 'checkpoint_callback', None)
        if checkpoint_callback is None:
            return
        logger = get_wandb_logger(trainer=trainer)
        path = checkpoint_callback.best_model_path
        # best_model_path is an empty string until a checkpoint has been saved.
        if path:
            log.info(f"Best checkpoint path will be saved to wandb from path: {path}")
            logger.experiment.log({'best_model_filepath': path})
            logger.experiment.save(path)
131 |
132 |
class UploadCheckpointsAsArtifact(Callback):
    """Upload checkpoints to wandb as an artifact, at the end of run (also on exceptions).

    Args:
        ckpt_dir: directory searched recursively for '*.ckpt' files when ``upload_best_only`` is False.
        upload_best_only: if True, only the ModelCheckpoint's best checkpoint is uploaded.
    """

    def __init__(self, ckpt_dir: str = "checkpoints/", upload_best_only: bool = True):
        self.ckpt_dir = ckpt_dir
        self.upload_best_only = upload_best_only

    @rank_zero_only
    def on_exception(self, trainer: pl.Trainer, pl_module: pl.LightningModule, exception: BaseException) -> None:
        self.on_train_end(trainer, pl_module)

    @rank_zero_only
    def on_train_end(self, trainer, pl_module):
        logger = get_wandb_logger(trainer=trainer)

        ckpts = wandb.Artifact("experiment-ckpts", type="checkpoints")

        if self.upload_best_only:
            # No ModelCheckpoint (attribute is None) or no checkpoint saved yet (empty path):
            # skip instead of crashing, which matters especially when called from on_exception.
            checkpoint_callback = getattr(trainer, 'checkpoint_callback', None)
            best_path = getattr(checkpoint_callback, 'best_model_path', '') if checkpoint_callback is not None else ''
            if not best_path:
                log.warning("UploadCheckpointsAsArtifact: no best checkpoint available - nothing will be uploaded.")
                return
            ckpts.add_file(best_path)
        else:
            for path in Path(self.ckpt_dir).rglob("*.ckpt"):
                ckpts.add_file(str(path))

        logger.experiment.log_artifact(ckpts)
157 |
--------------------------------------------------------------------------------
/configs/callbacks/default.yaml:
--------------------------------------------------------------------------------
1 | model_checkpoint:
2 | _target_: pytorch_lightning.callbacks.ModelCheckpoint
3 | monitor: ${val_metric} # name of the logged metric which determines when model is improving
4 | mode: "min" # "min" means lower metric value is better, can also be "max"
5 | save_top_k: 1 # save k best models (determined by above metric)
6 | save_last: True # additionally always save model from last epoch
7 | verbose: False
8 | dirpath: ${ckpt_dir}
9 | filename: "epoch{epoch:03d}_seed${seed}"
10 | auto_insert_metric_name: False
11 |
12 | early_stopping:
13 | _target_: pytorch_lightning.callbacks.EarlyStopping
14 | monitor: ${val_metric} # name of the logged metric which determines when model is improving
15 | mode: "min" # "min" means lower metric value is better, can also be "max"
16 | patience: 100 # how many validation epochs of not improving until training stops
17 | min_delta: 0 # minimum change in the monitored metric needed to qualify as an improvement
18 |
19 | learning_rate_logging:
20 | _target_: pytorch_lightning.callbacks.LearningRateMonitor
21 |
22 | model_summary:
23 | _target_: pytorch_lightning.callbacks.RichModelSummary
24 | max_depth: 1
25 |
26 | rich_progress_bar:
27 | _target_: pytorch_lightning.callbacks.RichProgressBar
28 |
--------------------------------------------------------------------------------
/configs/callbacks/none.yaml:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RolnickLab/climart/6d12fa478de5db5209842e0369a2a24377079edd/configs/callbacks/none.yaml
--------------------------------------------------------------------------------
/configs/callbacks/wandb.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - default.yaml
3 |
4 | watch_model:
5 | _target_: climart.utils.wandb_callbacks.WatchModel
6 | log: "all"
7 | log_freq: 100
8 |
9 | summarize_best_val_metric:
10 | _target_: climart.utils.wandb_callbacks.SummarizeBestValMetric
11 |
12 | upload_best_ckpt_as_file:
13 | _target_: climart.utils.wandb_callbacks.UploadBestCheckpointAsFile
14 |
--------------------------------------------------------------------------------
/configs/experiment/example.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # to execute this experiment run:
4 | # python run.py experiment=example
5 |
6 | defaults:
7 | - override /mode: exp.yaml
8 | - override /trainer: default.yaml
9 | - override /model: mnist.yaml
10 | - override /callbacks: default.yaml
11 | - override /logger: null
12 |
13 | # all parameters below will be merged with parameters from default configurations set above
14 | # this allows you to overwrite only specified parameters
15 |
16 | # name of the run determines folder name in logs
17 | # can also be accessed by loggers
18 | name: "example"
19 |
20 | seed: 12345
21 |
22 | trainer:
23 | min_epochs: 1
24 | max_epochs: 10
25 | gradient_clip_val: 5
26 |
27 | model:
28 | lin1_size: 128
29 | lin2_size: 256
30 | lin3_size: 64
31 | lr: 0.002
32 |
33 | datamodule:
34 | batch_size: 64
35 | train_val_test_split: [55_000, 5_000, 10_000]
36 |
37 | logger:
38 | csv:
39 | name: csv/${name}
40 | wandb:
41 | tags: ["mnist", "simple_dense_net"]
42 |
--------------------------------------------------------------------------------
/configs/experiment/reproduce_paper2021_cnn.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # to execute this experiment run:
4 | # python run.py experiment=reproduce_paper2021_cnn
5 |
6 | defaults:
7 | - override /mode: exp.yaml
8 | - override /trainer: default.yaml
9 | - override /model: cnn.yaml
10 | - override /callbacks: wandb.yaml # feel free to change this to default.yaml or any other callback
11 | - override /logger: wandb.yaml # feel free to change this to your favorite logger
12 |
13 | # all parameters below will be merged with parameters from default configurations set above
14 |
15 | name: "ClimART-21-CNN"
16 |
17 | seed: 7
18 |
19 | trainer:
20 | max_epochs: 100
21 | gradient_clip_val: 1.0
22 |
23 | model:
24 | hidden_dims: [200, 400, 100]
25 | strides: [2, 1, 1]
26 | gap: True
27 | activation_function: "gelu"
28 | net_normalization: "none"
29 |
30 | datamodule:
31 | exp_type: "pristine"
32 | batch_size: 128
33 | train_years: "1990+1999+2003"
34 | validation_years: "2005"
35 | target_type: "shortwave"
36 | # target_variable: "fluxes+heating_rate"
37 |
38 | normalizer:
39 | input_normalization: "z"
40 | output_normalization: null
41 |
42 | logger:
43 | wandb:
44 | tags: ["reproduce-climart-2021", "cnn", "reproduce-cnn"]
45 |
--------------------------------------------------------------------------------
/configs/hparams_search/mnist_optuna.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # example hyperparameter optimization of some experiment with Optuna:
4 | # python run.py -m hparams_search=mnist_optuna experiment=example hydra.sweeper.n_trials=30
5 |
6 | defaults:
7 | - override /hydra/sweeper: optuna
8 |
9 | # choose metric which will be optimized by Optuna
10 | # make sure this is the correct name of some metric logged in lightning module!
11 | optimized_metric: "val/acc_best"
12 |
13 | hydra:
14 | # here we define Optuna hyperparameter search
15 | # it optimizes for value returned from function with @hydra.main decorator
16 | # learn more here: https://hydra.cc/docs/next/plugins/optuna_sweeper
17 | sweeper:
18 | _target_: hydra_plugins.hydra_optuna_sweeper.optuna_sweeper.OptunaSweeper
19 | storage: null
20 | study_name: null
21 | n_jobs: 1
22 |
23 | # 'minimize' or 'maximize' the objective
24 | direction: maximize
25 |
26 | # number of experiments that will be executed
27 | n_trials: 20
28 |
29 | # choose Optuna hyperparameter sampler
30 | # learn more here: https://optuna.readthedocs.io/en/stable/reference/samplers.html
31 | sampler:
32 | _target_: optuna.samplers.TPESampler
33 | seed: 12345
34 | consider_prior: true
35 | prior_weight: 1.0
36 | consider_magic_clip: true
37 | consider_endpoints: false
38 | n_startup_trials: 10
39 | n_ei_candidates: 24
40 | multivariate: false
41 | warn_independent_sampling: true
42 |
43 | # define range of hyperparameters
44 | search_space:
45 | datamodule.batch_size:
46 | type: categorical
47 | choices: [32, 64, 128]
48 | model.lr:
49 | type: float
50 | low: 0.0001
51 | high: 0.2
52 | model.lin1_size:
53 | type: categorical
54 | choices: [32, 64, 128, 256, 512]
55 | model.lin2_size:
56 | type: categorical
57 | choices: [32, 64, 128, 256, 512]
58 | model.lin3_size:
59 | type: categorical
60 | choices: [32, 64, 128, 256, 512]
61 |
--------------------------------------------------------------------------------
/configs/input_transform/flatten.yaml:
--------------------------------------------------------------------------------
1 | _target_: climart.data_transform.transforms.FlattenTransform
2 | exp_type: ${datamodule.exp_type}
--------------------------------------------------------------------------------
/configs/input_transform/graphnet_level_nodes.yaml:
--------------------------------------------------------------------------------
1 | _target_: climart.data_transform.transforms.LayerEdgesAndLevelNodesGraph
2 | exp_type: ${datamodule.exp_type}
--------------------------------------------------------------------------------
/configs/input_transform/none.yaml:
--------------------------------------------------------------------------------
1 | _target_: climart.data_transform.transforms.IdentityTranform
2 | exp_type: ${datamodule.exp_type}
--------------------------------------------------------------------------------
/configs/input_transform/repeat_global_vars.yaml:
--------------------------------------------------------------------------------
1 | _target_: climart.data_transform.transforms.RepeatGlobalsTransform
2 | exp_type: ${datamodule.exp_type}
--------------------------------------------------------------------------------
/configs/local/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RolnickLab/climart/6d12fa478de5db5209842e0369a2a24377079edd/configs/local/.gitkeep
--------------------------------------------------------------------------------
/configs/logger/comet.yaml:
--------------------------------------------------------------------------------
1 | # https://www.comet.ml
2 |
3 | comet:
4 | _target_: pytorch_lightning.loggers.comet.CometLogger
5 | api_key: ${oc.env:COMET_API_TOKEN} # api key is loaded from environment variable
6 | project_name: "template-tests"
7 | experiment_name: ${name}
8 |
--------------------------------------------------------------------------------
/configs/logger/csv.yaml:
--------------------------------------------------------------------------------
1 | # csv logger built in lightning
2 |
3 | csv:
4 | _target_: pytorch_lightning.loggers.csv_logs.CSVLogger
5 | save_dir: "."
6 | name: "csv/"
7 | version: ${name}
8 | prefix: ""
9 |
--------------------------------------------------------------------------------
/configs/logger/many_loggers.yaml:
--------------------------------------------------------------------------------
1 | # train with many loggers at once
2 |
3 | defaults:
4 | # - comet.yaml
5 | - csv.yaml
6 | # - mlflow.yaml
7 | # - neptune.yaml
8 | - tensorboard.yaml
9 | - wandb.yaml
10 |
--------------------------------------------------------------------------------
/configs/logger/mlflow.yaml:
--------------------------------------------------------------------------------
1 | # https://mlflow.org
2 |
3 | mlflow:
4 | _target_: pytorch_lightning.loggers.mlflow.MLFlowLogger
5 | experiment_name: ${name}
6 | tracking_uri: null
7 | tags: null
8 | save_dir: ./mlruns
9 | prefix: ""
10 | artifact_location: null
11 |
--------------------------------------------------------------------------------
/configs/logger/neptune.yaml:
--------------------------------------------------------------------------------
1 | # https://neptune.ai
2 |
3 | neptune:
4 | _target_: pytorch_lightning.loggers.neptune.NeptuneLogger
5 | api_key: ${oc.env:NEPTUNE_API_TOKEN} # api key is loaded from environment variable
6 | project_name: your_name/template-tests
7 | close_after_fit: True
8 | offline_mode: False
9 | experiment_name: ${name}
10 | experiment_id: null
11 | prefix: ""
12 |
--------------------------------------------------------------------------------
/configs/logger/none.yaml:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RolnickLab/climart/6d12fa478de5db5209842e0369a2a24377079edd/configs/logger/none.yaml
--------------------------------------------------------------------------------
/configs/logger/tensorboard.yaml:
--------------------------------------------------------------------------------
1 | # https://www.tensorflow.org/tensorboard/
2 |
3 | tensorboard:
4 | _target_: pytorch_lightning.loggers.tensorboard.TensorBoardLogger
5 | save_dir: "tensorboard/"
6 | name: null
7 | version: ${name}
8 | log_graph: False
9 | default_hp_metric: True
10 | prefix: ""
11 |
--------------------------------------------------------------------------------
/configs/logger/wandb.yaml:
--------------------------------------------------------------------------------
1 | # https://wandb.ai
2 |
3 | wandb:
4 | _target_: pytorch_lightning.loggers.wandb.WandbLogger
5 | # entity: "some-name" # optionally set to name of your wandb team
6 | name: ${name}
7 | tags: []
8 | notes: "..."
9 | project: "ClimART"
10 | group: ""
11 | resume: "allow"
12 | reinit: True
13 | mode: online # set to "disabled" for no wandb logging
14 | save_dir: ${work_dir}/
15 | offline: False # set True to store all logs only locally
16 | id: null # pass correct id to resume experiment!
17 | log_model: False
18 | prefix: ""
19 | job_type: "train"
20 |
--------------------------------------------------------------------------------
/configs/main_config.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # specify here default training configuration
4 | defaults:
5 | - _self_
6 | - trainer: default.yaml
7 | - model: mlp.yaml
8 |
9 | - callbacks: default.yaml # or use wandb.yaml for wandb support
10 | - logger: wandb # set logger here or use command line (e.g. `python run.py logger=wandb`)
11 |
12 | # modes are special collections of config options for different purposes, e.g. debugging
13 | - mode: default.yaml
14 |
15 | # experiment configs allow for version control of specific configurations
16 | # for example, use them to store best hyperparameters for each combination of model and datamodule
17 | - experiment: null
18 |
19 | # config for hyperparameter optimization
20 | - hparams_search: null
21 |
22 | # optional local config for machine/user specific settings
23 | - optional local: default.yaml
24 |
25 | # enable color logging
26 | #- override hydra/hydra_logging: colorlog
27 | # - override hydra/job_logging: colorlog
28 | # default optimizer is AdamW
29 | - override optimizer@model.optimizer: adamw.yaml
30 |
31 |
32 | datamodule:
33 | _target_: climart.datamodules.pl_climart_datamodule.ClimartDataModule
34 | exp_type: "pristine"
35 | target_type: "shortwave"
36 | target_variable: "fluxes+heating_rate"
37 | train_years: "2000"
38 | validation_years: "2005"
39 | batch_size: 128
40 | eval_batch_size: 512
41 | num_workers: 0
42 | pin_memory: True
43 | load_train_into_mem: True
44 | load_valid_into_mem: True
45 | load_test_into_mem: False
46 | test_main_dataset: True
47 | test_ood_1991: False
48 | test_ood_historic: False
49 | test_ood_future: False
50 | verbose: ${verbose}
51 | # path to folder with data (optional, can also override constants.DATA_DIR to point to correct dir)
52 | # Make sure that it is an absolute path! hydra.runtime.cwd points to the original working dir.
53 | data_dir: "${hydra:runtime.cwd}/ClimART_DATA" # null
54 |
55 | normalizer:
56 | _target_: climart.data_transform.normalization.Normalizer
57 | input_normalization: "z"
58 | output_normalization: "z"
59 | spatial_normalization_in: False
60 | spatial_normalization_out: False
61 | log_scaling: False
62 | data_dir: ${datamodule.data_dir}
63 | verbose: ${verbose}
64 | # path to original working directory
65 | # hydra hijacks working directory by changing it to the new log directory
66 | # so it's useful to have this path as a special variable
67 | # https://hydra.cc/docs/next/tutorials/basic/running_your_app/working_directory
68 | work_dir: ${hydra:runtime.cwd} # ${oc.env:ENV_VAR} can be used to read the environment variable ENV_VAR
69 |
70 | val_metric: "val/${target_var_id:heating_rate, ${datamodule.target_type}}/rmse"
71 |
72 | # path to checkpoints
73 | ckpt_dir: ${work_dir}/checkpoints/
74 |
75 | # path for logging
76 | log_dir: ${work_dir}/logs/
77 |
78 | # pretty print config at the start of the run using Rich library
79 | print_config: True
80 |
81 | # disable python warnings if they annoy you
82 | ignore_warnings: True
83 |
84 | # evaluate on test set, using best model weights achieved during training
85 | # lightning chooses best weights based on metric specified in checkpoint callback
86 | test_after_training: True
87 |
88 | # Upload config file to wandb cloud?
89 | save_config_to_wandb: True
90 |
91 | # Verbose?
92 | verbose: True
93 |
94 | # seed for random number generators in pytorch, numpy and python.random
95 | seed: 11
96 |
97 | # name of the run, should be used along with experiment mode
98 | name: null
99 |
--------------------------------------------------------------------------------
/configs/mode/debug.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # debug mode sets name of the logging folder to 'logs/debug/...'
4 | # enables trainer debug options
5 | # also sets level of command line loggers to DEBUG
6 | # example usage:
7 | # `python run.py mode=debug`
8 |
9 | defaults:
10 | - override /trainer: debug.yaml
11 | - override /model: mlp.yaml
12 | - override /logger: none.yaml
13 | - override /callbacks: none.yaml
14 |
15 | debug_mode: True
16 |
17 | hydra:
18 | # sets level of all command line loggers to 'DEBUG'
19 | verbose: True
20 |
21 | # https://hydra.cc/docs/tutorials/basic/running_your_app/logging/
22 | # use this to set level of only chosen command line loggers to 'DEBUG'
23 | # verbose: [src.train, src.utils]
24 |
25 | run:
26 | dir: ${log_dir}/debug/runs/${now:%Y-%m-%d}/${now:%H-%M-%S}
27 | sweep:
28 | dir: ${log_dir}/debug/multiruns/${now:%Y-%m-%d}/${now:%H-%M-%S}
29 | subdir: ${hydra.job.num}
30 |
31 | # disable rich config printing, since it will be already printed by hydra when `verbose: True`
32 | print_config: False
33 |
--------------------------------------------------------------------------------
/configs/mode/default.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # default running mode
4 |
5 | default_mode: True
6 |
7 | hydra:
8 | # default output paths for all file logs
9 | run:
10 | dir: ${log_dir}/runs/${now:%Y-%m-%d}/${now:%H-%M-%S}
11 | sweep:
12 | dir: ${log_dir}/multiruns/${now:%Y-%m-%d}/${now:%H-%M-%S}
13 | subdir: ${hydra.job.num}
14 |
--------------------------------------------------------------------------------
/configs/mode/exp.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # experiment mode sets name of the logging folder to the experiment name
4 | # can also be used to name the run in the logger
5 | # example usage:
6 | # `python run.py mode=exp name=some_name`
7 |
8 | experiment_mode: True
9 |
10 | name: ???
11 |
12 | hydra:
13 | run:
14 | dir: ${log_dir}/experiments/${name}/runs/${now:%Y-%m-%d}/${now:%H-%M-%S}
15 | sweep:
16 | dir: ${log_dir}/experiments/${name}/multiruns/${now:%Y-%m-%d}/${now:%H-%M-%S}
17 | subdir: ${hydra.job.num}
18 |
--------------------------------------------------------------------------------
/configs/model/cnn.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - /input_transform: repeat_global_vars.yaml
3 | - /optimizer: adamw.yaml
4 |
5 | _target_: climart.models.CNNs.CNN.CNN_Net
6 |
7 | hidden_dims: [256, 256, 256]
8 | kernels: [20, 10, 5]
9 | strides: [2, 1, 1]
10 | net_normalization: null
11 | activation_function: "Gelu"
12 | dropout: 0.0
13 | gap: True
14 |
15 |
16 | downwelling_loss_contribution: 0.5
17 | upwelling_loss_contribution: 0.5
18 | heating_rate_loss_contribution: 0
19 |
20 | monitor: ${val_metric}
21 | scheduler:
22 | _target_: torch.optim.lr_scheduler.ExponentialLR
23 | gamma: 0.98
--------------------------------------------------------------------------------
/configs/model/graphnet.yaml:
--------------------------------------------------------------------------------
1 | _target_: climart.models.GraphNet.graph_network.GN_withReadout
2 |
3 | defaults:
4 | - /input_transform: graphnet_level_nodes.yaml
5 | - /optimizer: adamw.yaml
6 |
7 | hidden_dims: [128, 128, 128]
8 | net_normalization: null
9 | activation_function: "Gelu"
10 | dropout: 0.0
11 | residual: True
12 |
13 | update_mlp_n_layers: 1
14 | aggregator_funcs: "mean"
15 | graph_pooling: "mean"
16 | readout_which_output: "nodes"
17 |
18 | downwelling_loss_contribution: 0.5
19 | upwelling_loss_contribution: 0.5
20 | heating_rate_loss_contribution: 0
21 |
22 | monitor: ${val_metric}
23 | scheduler:
24 | _target_: torch.optim.lr_scheduler.ExponentialLR
25 | gamma: 0.98
26 |
--------------------------------------------------------------------------------
/configs/model/mlp.yaml:
--------------------------------------------------------------------------------
1 | _target_: climart.models.MLP.ClimartMLP
2 |
3 | defaults:
4 | - /input_transform: flatten.yaml
5 | - /optimizer: adamw.yaml
6 |
7 | hidden_dims: [512, 256, 256]
8 | net_normalization: "layer_norm"
9 | activation_function: "Gelu"
10 | dropout: 0.0
11 | residual: False
12 |
13 | downwelling_loss_contribution: 0.5
14 | upwelling_loss_contribution: 0.5
15 | heating_rate_loss_contribution: 0
16 |
17 | monitor: ${val_metric}
18 | scheduler:
19 | _target_: torch.optim.lr_scheduler.ExponentialLR
20 | gamma: 0.98
21 |
--------------------------------------------------------------------------------
/configs/optimizer/adam.yaml:
--------------------------------------------------------------------------------
1 | name: "adam"
2 | lr: 2e-4
3 | weight_decay: 1e-6
4 | eps: 1e-8
--------------------------------------------------------------------------------
/configs/optimizer/adamw.yaml:
--------------------------------------------------------------------------------
1 | name: "adamw"
2 | lr: 2e-4
3 | weight_decay: 1e-6
4 | eps: 1e-8
5 |
--------------------------------------------------------------------------------
/configs/optimizer/sgd.yaml:
--------------------------------------------------------------------------------
1 | name: "sgd"
2 | lr: 5e-4
3 | weight_decay: 0.05
4 | momentum: 0.9
--------------------------------------------------------------------------------
/configs/trainer/ddp.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - default.yaml
3 |
4 | gpus: 4
5 | strategy: ddp
6 |
--------------------------------------------------------------------------------
/configs/trainer/debug.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - default.yaml
3 |
4 | gpus: 1
5 |
6 | min_epochs: 1
7 | max_epochs: 2
8 |
9 | # prints
10 | profiler: null
11 |
12 | # debugs
13 | fast_dev_run: true
14 | overfit_batches: 0
15 | limit_train_batches: 1.0
16 | limit_val_batches: 1.0
17 | limit_test_batches: 1.0
18 | track_grad_norm: -1
19 | detect_anomaly: true
20 |
--------------------------------------------------------------------------------
/configs/trainer/default.yaml:
--------------------------------------------------------------------------------
1 | _target_: pytorch_lightning.Trainer
2 |
3 | gpus: 1
4 |
5 | min_epochs: 1
6 | max_epochs: 100
7 |
8 | gradient_clip_val: 1.0
9 |
10 | resume_from_checkpoint: null
11 |
12 | # number of validation steps to execute at the beginning of the training
13 | num_sanity_val_steps: 0
14 |
--------------------------------------------------------------------------------
/configs/transform/default.yaml:
--------------------------------------------------------------------------------
1 | _target_: climart.data_transform.transforms.IdentityTranform
2 | exp_type: ${datamodule.exp_type}
--------------------------------------------------------------------------------
/download_climart.sh:
--------------------------------------------------------------------------------
1 | # Change the directory where the data will be downloaded below
2 | data_dir="ClimART_DATA"
3 | mkdir -p ${data_dir}/inputs
4 | mkdir -p ${data_dir}/outputs_clear_sky
5 | mkdir -p ${data_dir}/outputs_pristine
6 |
7 | # Uncomment all lines to download all data (which will take time though :)).
8 |
9 | echo "Downloading metadata & statistics..."
10 | curl https://object-arbutus.cloud.computecanada.ca/rt-public/META_INFO.json --output ${data_dir}/META_INFO.json
11 | curl https://object-arbutus.cloud.computecanada.ca/rt-public/statistics.npz --output ${data_dir}/statistics.npz
12 | curl https://object-arbutus.cloud.computecanada.ca/rt-public/areacella_fx_CanESM5.nc --output ${data_dir}/areacella_fx_CanESM5.npz
13 | echo "Done."
14 |
15 | echo "Downloading input files..."
16 | for x in {1979..1991};do curl https://object-arbutus.cloud.computecanada.ca/rt-public/inputs/$x.h5 --output ${data_dir}/inputs/$x.h5; done
17 | for x in {1994..2014};do curl https://object-arbutus.cloud.computecanada.ca/rt-public/inputs/$x.h5 --output ${data_dir}/inputs/$x.h5; done
18 | for x in 1850 1851 1852 ;do curl https://object-arbutus.cloud.computecanada.ca/rt-public/inputs/$x.h5 --output ${data_dir}/inputs/$x.h5; done
19 | for x in 2097 2098 2099 ;do curl https://object-arbutus.cloud.computecanada.ca/rt-public/inputs/$x.h5 --output ${data_dir}/inputs/$x.h5; done
20 | echo "Done."
21 |
22 | echo "Downloading clear-sky targets..."
23 | for x in {1979..1991};do curl https://object-arbutus.cloud.computecanada.ca/rt-public/outputs_clear_sky/$x.h5 --output ${data_dir}/outputs_clear_sky/$x.h5; done
24 | for x in {1994..2014};do curl https://object-arbutus.cloud.computecanada.ca/rt-public/outputs_clear_sky/$x.h5 --output ${data_dir}/outputs_clear_sky/$x.h5; done
25 | for x in 1850 1851 1852 ;do curl https://object-arbutus.cloud.computecanada.ca/rt-public/outputs_clear_sky/$x.h5 --output ${data_dir}/outputs_clear_sky/$x.h5; done
26 | for x in 2097 2098 2099 ;do curl https://object-arbutus.cloud.computecanada.ca/rt-public/outputs_clear_sky/$x.h5 --output ${data_dir}/outputs_clear_sky/$x.h5; done
27 | echo "Done."
28 |
29 | echo "Downloading pristine-sky targets..."
30 | for x in {1979..1991};do curl https://object-arbutus.cloud.computecanada.ca/rt-public/outputs_pristine/$x.h5 --output ${data_dir}/outputs_pristine/$x.h5; done
31 | for x in {1994..2014};do curl https://object-arbutus.cloud.computecanada.ca/rt-public/outputs_pristine/$x.h5 --output ${data_dir}/outputs_pristine/$x.h5; done
32 | for x in 1850 1851 1852 ;do curl https://object-arbutus.cloud.computecanada.ca/rt-public/outputs_pristine/$x.h5 --output ${data_dir}/outputs_pristine/$x.h5; done
33 | for x in 2097 2098 2099 ;do curl https://object-arbutus.cloud.computecanada.ca/rt-public/outputs_pristine/$x.h5 --output ${data_dir}/outputs_pristine/$x.h5; done
34 |
35 | echo "Done. Finished downloading ClimART :)"
36 |
--------------------------------------------------------------------------------
/download_data_subset.sh:
--------------------------------------------------------------------------------
1 | # Change the directory where the data will be downloaded below
2 | data_dir="ClimART_DATA"
3 | mkdir -p ${data_dir}/inputs
4 | mkdir -p ${data_dir}/outputs_clear_sky
5 | mkdir -p ${data_dir}/outputs_pristine
6 |
7 | # Uncomment all lines to download all data (which will take time though :)).
8 |
9 | echo "Downloading metadata & statistics..."
10 | curl https://object-arbutus.cloud.computecanada.ca/rt-public/META_INFO.json --output ${data_dir}/META_INFO.json
11 | curl https://object-arbutus.cloud.computecanada.ca/rt-public/statistics.npz --output ${data_dir}/statistics.npz
12 | curl https://object-arbutus.cloud.computecanada.ca/rt-public/areacella_fx_CanESM5.nc --output ${data_dir}/areacella_fx_CanESM5.npz
13 | echo "Done."
14 |
15 | echo "Downloading input files..."
16 | for x in 2000 2005 ;do curl https://object-arbutus.cloud.computecanada.ca/rt-public/inputs/$x.h5 --output ${data_dir}/inputs/$x.h5; done
17 | for x in {2007..2014} ;do curl https://object-arbutus.cloud.computecanada.ca/rt-public/inputs/$x.h5 --output ${data_dir}/inputs/$x.h5; done
18 | #for x in {1979..1991};do curl https://object-arbutus.cloud.computecanada.ca/rt-public/inputs/$x.h5 --output ${data_dir}/inputs/$x.h5; done
19 | #for x in {1994..2014};do curl https://object-arbutus.cloud.computecanada.ca/rt-public/inputs/$x.h5 --output ${data_dir}/inputs/$x.h5; done
20 | #for x in 1850 1851 1852 ;do curl https://object-arbutus.cloud.computecanada.ca/rt-public/inputs/$x.h5 --output ${data_dir}/inputs/$x.h5; done
21 | #for x in 2097 2098 2099 ;do curl https://object-arbutus.cloud.computecanada.ca/rt-public/inputs/$x.h5 --output ${data_dir}/inputs/$x.h5; done
22 | echo "Done."
23 |
24 | echo "Downloading clear-sky targets..."
25 | #for x in {1979..1991};do curl https://object-arbutus.cloud.computecanada.ca/rt-public/outputs_clear_sky/$x.h5 --output ${data_dir}/outputs_clear_sky/$x.h5; done
26 | #for x in {1994..2014};do curl https://object-arbutus.cloud.computecanada.ca/rt-public/outputs_clear_sky/$x.h5 --output ${data_dir}/outputs_clear_sky/$x.h5; done
27 | #for x in 1850 1851 1852 ;do curl https://object-arbutus.cloud.computecanada.ca/rt-public/outputs_clear_sky/$x.h5 --output ${data_dir}/outputs_clear_sky/$x.h5; done
28 | #for x in 2097 2098 2099 ;do curl https://object-arbutus.cloud.computecanada.ca/rt-public/outputs_clear_sky/$x.h5 --output ${data_dir}/outputs_clear_sky/$x.h5; done
29 | echo "Done."
30 |
31 | echo "Downloading pristine-sky targets..."
32 | for x in 2000 2005 ;do curl https://object-arbutus.cloud.computecanada.ca/rt-public/outputs_pristine/$x.h5 --output ${data_dir}/outputs_pristine/$x.h5; done
33 | for x in {2007..2014} ;do curl https://object-arbutus.cloud.computecanada.ca/rt-public/outputs_pristine/$x.h5 --output ${data_dir}/outputs_pristine/$x.h5; done
34 | #for x in {1979..1991};do curl https://object-arbutus.cloud.computecanada.ca/rt-public/outputs_pristine/$x.h5 --output ${data_dir}/outputs_pristine/$x.h5; done
35 | #for x in {1994..2014};do curl https://object-arbutus.cloud.computecanada.ca/rt-public/outputs_pristine/$x.h5 --output ${data_dir}/outputs_pristine/$x.h5; done
36 | #for x in 1850 1851 1852 ;do curl https://object-arbutus.cloud.computecanada.ca/rt-public/outputs_pristine/$x.h5 --output ${data_dir}/outputs_pristine/$x.h5; done
37 | #for x in 2097 2098 2099 ;do curl https://object-arbutus.cloud.computecanada.ca/rt-public/outputs_pristine/$x.h5 --output ${data_dir}/outputs_pristine/$x.h5; done
38 |
39 | echo "Done. Finished downloading the data."
--------------------------------------------------------------------------------
/env.yml:
--------------------------------------------------------------------------------
1 | name: climart
2 | channels:
3 | - rusty1s
4 | - pytorch
5 | - conda-forge
6 | - defaults
7 | dependencies:
8 | - cartopy
9 | - cftime
10 | - cudatoolkit
11 | - curl
12 | - cycler
13 | - dask
14 | - dask-core
15 | - einops>=0.3.0
16 | - geos>=3.9.1
17 | - glib>=2.66.3
18 | - h5py>=3.3.0
19 | - hdf5>=1.10.6
20 | - kiwisolver
21 | - matplotlib>=3.4.2
22 | - mkl
23 | - netcdf4>=1.5.7
24 | - ninja>=1.10.2
25 | - numpy>=1.21.1
26 | - pandas>=1.3.1
27 | - pip>=21.2.3
28 | - pytest
29 | - python>=3.8.0
30 | - python-dateutil>=2.8.2
31 | - pytorch=1.9.0
32 | - pytorch-lightning>=1.5.8
33 | - pytorch-scatter>=2.0.8
34 | - pyyaml>=5.4.1
35 | - qt>=5.12.9
36 | - scipy>=1.7.1
37 | - seaborn
38 | - setuptools=59.5.0
39 | - six>=1.16.0
40 | - tornado
41 | - xarray>=0.19.0
42 | - yaml
43 | - zlib
44 | - pip:
45 | - hydra-core
46 | - hydra_colorlog
47 | - networkx
48 | - pre-commit
49 | - python-dotenv
50 | - rich
51 | - timm
52 | - torchmetrics
53 | - tqdm>=4.62.0
54 | - wandb>=0.12.9
55 |
56 |
--------------------------------------------------------------------------------
/images/variable_table.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RolnickLab/climart/6d12fa478de5db5209842e0369a2a24377079edd/images/variable_table.png
--------------------------------------------------------------------------------
/notebooks/2022-06-06-get-predictions-pl.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 9,
6 | "metadata": {},
7 | "outputs": [
8 | {
9 | "name": "stdout",
10 | "output_type": "stream",
11 | "text": [
12 | "The autoreload extension is already loaded. To reload it, use:\n",
13 | " %reload_ext autoreload\n"
14 | ]
15 | }
16 | ],
17 | "source": [
18 | "%load_ext autoreload\n",
19 | "%autoreload 2"
20 | ]
21 | },
22 | {
23 | "cell_type": "code",
24 | "execution_count": 3,
25 | "metadata": {},
26 | "outputs": [],
27 | "source": [
28 | "import os\n",
29 | "# Make sure we're in the right directory\n",
30 | "if os.path.basename(os.getcwd()) == \"notebooks\":\n",
31 | " os.chdir(\"..\")"
32 | ]
33 | },
34 | {
35 | "cell_type": "markdown",
36 | "metadata": {},
37 | "source": [
38 | "## Inference with ClimART data and models"
39 | ]
40 | },
41 | {
42 | "cell_type": "code",
43 | "execution_count": 8,
44 | "metadata": {},
45 | "outputs": [],
46 | "source": [
47 | "import sys\n",
48 | "import numpy as np\n",
49 | "import pytorch_lightning as pl\n",
50 | "from climart.utils.config_utils import get_config_from_hydra_compose_overrides\n",
51 | "from climart.interface import get_model_and_data"
52 | ]
53 | },
54 | {
55 | "cell_type": "code",
56 | "execution_count": 18,
57 | "metadata": {},
58 | "outputs": [],
59 | "source": [
60 | "num_workers=2"
61 | ]
62 | },
63 | {
64 | "cell_type": "code",
65 | "execution_count": 19,
66 | "metadata": {},
67 | "outputs": [],
68 | "source": [
69 | "os.environ['WANDB_SILENT']=\"true\"\n",
70 | "np.set_printoptions(suppress=True, threshold=sys.maxsize)"
71 | ]
72 | },
73 | {
74 | "cell_type": "code",
75 | "execution_count": 22,
76 | "metadata": {},
77 | "outputs": [
78 | {
79 | "name": "stderr",
80 | "output_type": "stream",
81 | "text": [
82 | "Multiprocessing is handled by SLURM.\n",
83 | "GPU available: True, used: True\n",
84 | "TPU available: False, using: 0 TPU cores\n",
85 | "IPU available: False, using: 0 IPUs\n",
86 | "HPU available: False, using: 0 HPUs\n"
87 | ]
88 | }
89 | ],
90 | "source": [
91 | "exp_type='pristine'\n",
92 | "predict_years = \"2012-14\"\n",
93 | "overrides = [f'model=mlp',\n",
94 | " f'datamodule.exp_type={exp_type}',\n",
95 | " f'datamodule.num_workers={num_workers}', \n",
96 | " f'++datamodule.predict_years={predict_years}']\n",
97 | "cfg = get_config_from_hydra_compose_overrides(overrides)\n",
98 | "model, dm = get_model_and_data(cfg)\n",
99 | "trainer = pl.Trainer(gpus=-1)"
100 | ]
101 | },
102 | {
103 | "cell_type": "code",
104 | "execution_count": 55,
105 | "metadata": {},
106 | "outputs": [
107 | {
108 | "name": "stderr",
109 | "output_type": "stream",
110 | "text": [
111 | "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n"
112 | ]
113 | },
114 | {
115 | "data": {
116 | "application/vnd.jupyter.widget-view+json": {
117 | "model_id": "90e0261b55f04cccaf82bcb90e3494f2",
118 | "version_major": 2,
119 | "version_minor": 0
120 | },
121 | "text/plain": [
122 | "Predicting: 0it [00:00, ?it/s]"
123 | ]
124 | },
125 | "metadata": {},
126 | "output_type": "display_data"
127 | }
128 | ],
129 | "source": [
130 | "results1 = trainer.predict(model=model,datamodule=dm, return_predictions=True)"
131 | ]
132 | },
133 | {
134 | "cell_type": "code",
135 | "execution_count": 62,
136 | "metadata": {},
137 | "outputs": [
138 | {
139 | "data": {
140 | "text/plain": [
141 | "(122880, 49)"
142 | ]
143 | },
144 | "execution_count": 62,
145 | "metadata": {},
146 | "output_type": "execute_result"
147 | }
148 | ],
149 | "source": [
150 | "results1 = model.aggregate_predictions(results1)\n",
151 | "sw_hr_preds_2014 = results1[2014]['preds']['hrsc']\n",
152 | "sw_hr_preds_2014.shape"
153 | ]
154 | },
155 | {
156 | "cell_type": "markdown",
157 | "metadata": {},
158 | "source": [
159 | "#### Instead of changing `predict_years` via the command line/config, you can also set it programmatically on an existing datamodule, as follows:"
160 | ]
161 | },
162 | {
163 | "cell_type": "code",
164 | "execution_count": 79,
165 | "metadata": {},
166 | "outputs": [
167 | {
168 | "name": "stderr",
169 | "output_type": "stream",
170 | "text": [
171 | "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n"
172 | ]
173 | },
174 | {
175 | "data": {
176 | "application/vnd.jupyter.widget-view+json": {
177 | "model_id": "0c78d34a8ed94ced997186cdffb4d0e5",
178 | "version_major": 2,
179 | "version_minor": 0
180 | },
181 | "text/plain": [
182 | "Predicting: 0it [00:00, ?it/s]"
183 | ]
184 | },
185 | "metadata": {},
186 | "output_type": "display_data"
187 | }
188 | ],
189 | "source": [
190 | "dm.predict_years = \"2010\"\n",
191 | "results2 = trainer.predict(model, dm)"
192 | ]
193 | },
194 | {
195 | "cell_type": "code",
196 | "execution_count": 80,
197 | "metadata": {},
198 | "outputs": [
199 | {
200 | "data": {
201 | "text/plain": [
202 | "torch.Size([512, 50])"
203 | ]
204 | },
205 | "execution_count": 80,
206 | "metadata": {},
207 | "output_type": "execute_result"
208 | }
209 | ],
210 | "source": [
211 | "results2[0]['preds']['rsuc'].shape"
212 | ]
213 | },
214 | {
215 | "cell_type": "code",
216 | "execution_count": 81,
217 | "metadata": {},
218 | "outputs": [
219 | {
220 | "data": {
221 | "text/plain": [
222 | "(122880, 49)"
223 | ]
224 | },
225 | "execution_count": 81,
226 | "metadata": {},
227 | "output_type": "execute_result"
228 | }
229 | ],
230 | "source": [
231 | "results2 = model.aggregate_predictions(results2)\n",
232 | "sw_hr_preds_2010 = results2[2010]['preds']['hrsc']\n",
233 | "sw_hr_preds_2010.shape"
234 | ]
235 | }
236 | ],
237 | "metadata": {
238 | "kernelspec": {
239 | "display_name": "Climart-GraphNet",
240 | "language": "python",
241 | "name": "climart_gn"
242 | },
243 | "language_info": {
244 | "codemirror_mode": {
245 | "name": "ipython",
246 | "version": 3
247 | },
248 | "file_extension": ".py",
249 | "mimetype": "text/x-python",
250 | "name": "python",
251 | "nbconvert_exporter": "python",
252 | "pygments_lexer": "ipython3",
253 | "version": "3.9.12"
254 | }
255 | },
256 | "nbformat": 4,
257 | "nbformat_minor": 4
258 | }
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
1 | import dotenv
2 | import hydra
3 | from omegaconf import DictConfig, OmegaConf
4 |
5 | # load environment variables from `.env` file if it exists
6 | # recursively searches for `.env` in all folders starting from work dir
7 | from climart.utils.utils import target_var_id_mapping
8 |
9 | dotenv.load_dotenv(override=True)
10 | OmegaConf.register_new_resolver("target_var_id", target_var_id_mapping)
11 |
12 |
13 | @hydra.main(config_path="configs/", config_name="main_config.yaml", version_base=None)
14 | def main(config: DictConfig):
15 | from climart.train import run_model
16 | return run_model(config)
17 |
18 |
19 | if __name__ == "__main__":
20 | main()
21 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [isort]
2 | line_length = 99
3 | profile = black
4 | filter_files = True
5 |
6 |
7 | [flake8]
8 | max_line_length = 99
9 | show_source = True
10 | format = pylint
11 | ignore =
12 | F401 # Module imported but unused
13 | W504 # Line break occurred after a binary operator
14 | F841 # Local variable name is assigned to but never used
15 | F403 # from module import *
16 | E501 # Line too long
17 | exclude =
18 | .git
19 | __pycache__
20 | data/*
21 | tests/*
22 | notebooks/*
23 | logs/*
24 |
25 |
26 | [tool:pytest]
27 | python_files = tests/*
28 | log_cli = True
29 | markers =
30 | slow
31 | addopts =
32 | --durations=0
33 | --strict-markers
34 | --doctest-modules
35 | filterwarnings =
36 | ignore::DeprecationWarning
37 | ignore::UserWarning
38 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | with open("README.md", "r") as fh:
4 | long_description = fh.read()
5 | keywords = ["radiative transfer", "emulation", "deep learning", "climart", "dataset", "pytorch", "mila", "eccc"]
6 |
7 | setup(
8 | name='climart',
9 | version='0.1.0',
10 | packages=find_packages(),
11 | url='https://github.com/RolnickLab/climart',
12 | license='CC BY 4.0',
13 | author='Salva Rühling Cachay, Venkatesh Ramesh',
14 | author_email='',
15 | keywords=keywords,
16 | description='A comprehensive, large-scale dataset for'
17 | ' benchmarking neural network emulators of the'
18 | ' radiation component in climate and weather models.',
19 | long_description=long_description,
20 | long_description_content_type="text/markdown",
21 | python_requires=">=3.7.0",
22 | install_requires=[
23 | "torch>=1.7.1",
24 | "scikit-learn",
25 | ],
26 | classifiers=[
27 | "Intended Audience :: Developers",
28 | "Intended Audience :: Science/Research",
29 | "Operating System :: OS Independent",
30 | "Programming Language :: Python :: 3",
31 | "Programming Language :: Python :: 3.7",
32 | "Programming Language :: Python :: 3.8",
33 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
34 | ],
35 | )
36 |
--------------------------------------------------------------------------------
/tests/test_utils.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | def test_year_string_to_list_parsing():
4 | from rtml.utils.utils import year_string_to_list
5 | inputs = [
6 | ('1990', [1990]),
7 | ('1990-1991', [1990, 1991]),
8 | ('1990-91', [1990, 1991]),
9 | ('90-91', [1990, 1991]),
10 | ('1989-1991', [1989, 1990, 1991]),
11 | ('1990-1991+2003-2005', [1990, 1991, 2003, 2004, 2005]),
12 | ('1990+1999+2005-06', [1990, 1999, 2005, 2006]),
13 | ('1990+1999', [1990, 1999]),
14 | ]
15 | for i, (string, expected) in enumerate(inputs):
16 | actual = year_string_to_list(string)
17 | err_msg = f"Input {i+1}: Expected {expected}, but {actual} was returned."
18 | assert all([a == b for a, b in zip(actual, expected)]), err_msg
19 |
--------------------------------------------------------------------------------
/tests/test_variables.py:
--------------------------------------------------------------------------------
1 | from climart.data_wrangling.data_variables import INPUT_VARS_CLOUDS, INPUT_VARS_AEROSOLS, _ALL_INPUT_VARS
2 |
3 |
4 | def exp_type_subset_vars_test():
5 | for k in INPUT_VARS_CLOUDS:
6 | assert k in _ALL_INPUT_VARS, f"Cloud var {k} was expected to be in _ALL_INPUT_VARS."
7 | for k in INPUT_VARS_AEROSOLS:
8 | assert k in _ALL_INPUT_VARS, f"Aerosol var {k} was expected to be in _ALL_INPUT_VARS."
9 |
--------------------------------------------------------------------------------