├── LICENSE
├── README.md
├── config.py
├── experiment_configs
│   ├── t2m_GlobalStandardScaling_Monthly.json
│   ├── t2m_GlobalStandardScaling_NotMonthly.json
│   ├── t2m_GlobalStandardScaling_NotMonthly_longer.json
│   ├── t2m_LocalStandardScaling_Monthly.json
│   ├── t2m_LocalStandardScaling_NotMonthly.json
│   ├── test_single_image_t2m_GlobalStandardScaling_NotMonthly.json
│   └── test_t2m_GlobalStandardScaling_NotMonthly.json
├── inference.py
├── inference_on_single_image.py
├── model
│   ├── __init__.py
│   ├── base_model.py
│   ├── ema.py
│   ├── model.py
│   ├── modules
│   │   ├── diffusion.py
│   │   └── unet.py
│   └── networks.py
├── report.pdf
├── requirements.txt
├── results
│   └── reverse_diffusion_steps.jpg
├── train.py
├── utils.py
└── weatherbench_data
    ├── __init__.py
    ├── datasets.py
    ├── datastorage.py
    ├── fileconverter.py
    ├── transforms.py
    └── utils.py

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2022 Davit Papikyan

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Probabilistic Downscaling of Climate Variables

Project with Colloquium (MA8114) at TUM: Probabilistic Downscaling of Climate Variables Using Denoising Diffusion Probabilistic Models

Supervisor: Prof. Dr. Rüdiger Westermann (Chair of Computer Graphics and Visualization)\
Advisor: Kevin Höhlein (Chair of Computer Graphics and Visualization)

---

Downscaling comprises methods used to infer high-resolution information from
low-resolution climate variables. We approach this problem as an image super-resolution
task and employ a Denoising Diffusion Probabilistic Model (DDPM) to generate finer-scale variables
conditioned on coarse-scale information. Experiments are conducted on the WeatherBench dataset
by analysing the temperature at 2 m height above the surface (t2m). See the final report [here](https://github.com/davitpapikyan/Probabilistic-Downscaling-of-Climate-Variables/blob/main/report.pdf).

![](results/reverse_diffusion_steps.jpg?raw=true)

---

## References

- Liangwei Jiang (2021) Image-Super-Resolution-via-Iterative-Refinement [[Source code](https://github.com/Janspiry/Image-Super-Resolution-via-Iterative-Refinement#readme)]
- Song et al. (2021) Score-Based Generative Modeling through Stochastic Differential Equations [[Source code](https://github.com/yang-song/score_sde_pytorch)]
- Stephan Rasp, Peter D. Dueben, Sebastian Scher, Jonathan A. Weyn, Soukayna Mouatadid, and Nils Thuerey, 2020. WeatherBench: A benchmark dataset for data-driven weather forecasting. arXiv: [WeatherBench: A benchmark dataset for data-driven weather forecasting](https://arxiv.org/abs/2002.00469)

--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
"""Defines configuration parameters for the whole model and dataset.
"""
import argparse
import json
import os
from collections import OrderedDict
from datetime import datetime


def get_current_datetime() -> str:
    """Converts the current datetime to string.

    Returns:
        String version of current datetime of the form: %y%m%d_%H%M%S.
    """
    return datetime.now().strftime("%y%m%d_%H%M%S")


def mkdirs(paths) -> None:
    """Creates directories represented by paths argument.

    Args:
        paths: Either a list of paths or a single path.
    """
    if isinstance(paths, str):
        os.makedirs(paths, exist_ok=True)
    else:
        for path in paths:
            os.makedirs(path, exist_ok=True)


class Config:
    """Configuration class.

    Attributes:
        args: Command line arguments.
        root: Path to the configuration json file.
        gpu_ids: A list of GPU IDs.
        params: A dictionary containing configuration parameters stored in a json file.
        name: Name of the experiment.
        phase: Either train or val.
        distributed: Whether the computation will be distributed among multiple GPUs or not.
        log: Path to logs.
        tb_logger: Tensorboard logging directory.
        results: Validation results directory.
        checkpoint: Model checkpoints directory.
        resume_state: The path to load the network.
        dataset_name: The name of dataset.
        dataroot: The path to dataset.
        batch_size: Batch size.
        num_workers: The number of processes for multi-process data loading.
        use_shuffle: Whether to shuffle the training data or not.
        train_min_date: Minimum date starting from which to read the data for training.
        train_max_date: Maximum date until which to read the data for training.
        val_min_date: Minimum date starting from which to read the data for validation.
        val_max_date: Maximum date until which to read the data for validation.
        train_subset_min_date: Minimum date starting from which to read the data for model evaluation on train subset.
        train_subset_max_date: Maximum date until which to read the data for model evaluation on train subset.
        variables: A list of WeatherBench variables.
        finetune_norm: Whether to fine-tune or train from scratch.
        in_channel: The number of channels of input tensor of U-Net.
        out_channel: The number of channels of output tensor of U-Net.
        inner_channel: Timestep embedding dimension.
        norm_groups: The number of groups for group normalization.
        channel_multiplier: A tuple specifying the scaling factors of channels.
        attn_res: A tuple of spatial dimensions indicating in which resolutions to use self-attention layer.
        res_blocks: The number of residual blocks.
        dropout: Dropout probability.
        init_method: NN weight initialization method. One of normal, kaiming or orthogonal initializations.
        train_schedule: Defines the type of beta schedule for training.
        train_n_timestep: Number of diffusion timesteps for training.
        train_linear_start: Minimum value of the linear schedule for training.
        train_linear_end: Maximum value of the linear schedule for training.
        val_schedule: Defines the type of beta schedule for validation.
        val_n_timestep: Number of diffusion timesteps for validation.
        val_linear_start: Minimum value of the linear schedule for validation.
        val_linear_end: Maximum value of the linear schedule for validation.
        test_schedule: Defines the type of beta schedule for inference.
        test_n_timestep: Number of diffusion timesteps for inference.
        test_linear_start: Minimum value of the linear schedule for inference.
        test_linear_end: Maximum value of the linear schedule for inference.
        conditional: Whether to condition on INTERPOLATED image or not.
        diffusion_loss: Either 'l1' or 'l2'.
        n_iter: Number of iterations to train.
        val_freq: Validation frequency.
        save_checkpoint_freq: Model checkpoint frequency.
        print_freq: The frequency of displaying training information.
        n_val_vis: Number of data points to visualize.
        val_vis_freq: Validation data points visualization frequency.
        sample_size: Number of SR images to generate to calculate metrics.
        optimizer_type: The name of optimization algorithm. Supported values are 'adam', 'adamw'.
        amsgrad: Whether to use the AMSGrad variant of optimizer.
        lr: The learning rate.
        experiments_root: The path to the experiment directory.
        tranform_monthly: Whether to apply the transformation monthly or on the whole dataset.
        height: U-Net input tensor height value.
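
    Example:
        A minimal usage sketch, assuming ``args`` is the namespace produced by
        the argparse setup in train.py/inference.py (flags -c/--config,
        -p/--phase, -gpu/--gpu_ids):

            configs = Config(args)
            print(configs.name, configs.batch_size, configs.lr)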
    """

    def __init__(self, args: argparse.Namespace):
        self.args = args
        self.root = self.args.config
        self.gpu_ids = self.args.gpu_ids
        self.params = {}
        self.experiments_root = None
        self.__parse_configs()
        self.name = self.params["name"]
        self.phase = self.params["phase"]
        self.gpu_ids = self.params["gpu_ids"]
        self.distributed = self.params["distributed"]
        self.log = self.params["path"]["log"]
        self.tb_logger = self.params["path"]["tb_logger"]
        self.results = self.params["path"]["results"]
        self.checkpoint = self.params["path"]["checkpoint"]
        self.resume_state = self.params["path"]["resume_state"]
        self.dataset_name = self.params["data"]["name"]
        self.dataroot = self.params["data"]["dataroot"]
        self.batch_size = self.params["data"]["batch_size"]
        self.num_workers = self.params["data"]["num_workers"]
        self.use_shuffle = self.params["data"]["use_shuffle"]
        self.train_min_date = self.params["data"]["train_min_date"]
        self.train_max_date = self.params["data"]["train_max_date"]
        self.train_subset_min_date = self.params["data"]["train_subset_min_date"]
        self.train_subset_max_date = self.params["data"]["train_subset_max_date"]
        self.tranform_monthly = self.params["data"]["apply_tranform_monthly"]
        self.transformation = self.params["data"]["transformation"]
        self.val_min_date = self.params["data"]["val_min_date"]
        self.val_max_date = self.params["data"]["val_max_date"]
        self.variables = self.params["data"]["variables"]
        self.height = self.params["data"]["height"]
        self.finetune_norm = self.params["model"]["finetune_norm"]
        self.in_channel = self.params["model"]["unet"]["in_channel"]
        self.out_channel = self.params["model"]["unet"]["out_channel"]
        self.inner_channel = self.params["model"]["unet"]["inner_channel"]
        self.norm_groups = self.params["model"]["unet"]["norm_groups"]
        self.channel_multiplier = self.params["model"]["unet"]["channel_multiplier"]
        self.attn_res = self.params["model"]["unet"]["attn_res"]
        self.res_blocks = self.params["model"]["unet"]["res_blocks"]
        self.dropout = self.params["model"]["unet"]["dropout"]
        self.init_method = self.params["model"]["unet"]["init_method"]
        self.train_schedule = self.params["model"]["beta_schedule"]["train"]["schedule"]
        self.train_n_timestep = self.params["model"]["beta_schedule"]["train"]["n_timestep"]
        self.train_linear_start = self.params["model"]["beta_schedule"]["train"]["linear_start"]
        self.train_linear_end = self.params["model"]["beta_schedule"]["train"]["linear_end"]
        self.val_schedule = self.params["model"]["beta_schedule"]["val"]["schedule"]
        self.val_n_timestep = self.params["model"]["beta_schedule"]["val"]["n_timestep"]
        self.val_linear_start = self.params["model"]["beta_schedule"]["val"]["linear_start"]
        self.val_linear_end = self.params["model"]["beta_schedule"]["val"]["linear_end"]
        self.test_schedule = self.params["model"]["beta_schedule"]["test"]["schedule"]
        self.test_n_timestep = self.params["model"]["beta_schedule"]["test"]["n_timestep"]
        self.test_linear_start = self.params["model"]["beta_schedule"]["test"]["linear_start"]
        self.test_linear_end = self.params["model"]["beta_schedule"]["test"]["linear_end"]
        self.conditional = self.params["model"]["diffusion"]["conditional"]
        self.diffusion_loss = self.params["model"]["diffusion"]["loss"]
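        # Training-loop hyperparameters, read from the "training" section of the config.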
        self.n_iter = self.params["training"]["epoch_n_iter"]
        self.val_freq = self.params["training"]["val_freq"]
        self.save_checkpoint_freq = self.params["training"]["save_checkpoint_freq"]
        self.print_freq = self.params["training"]["print_freq"]
        self.n_val_vis = self.params["training"]["n_val_vis"]
        self.val_vis_freq = self.params["training"]["val_vis_freq"]
        self.sample_size = self.params["training"]["sample_size"]
        self.optimizer_type = self.params["training"]["optimizer"]["type"]
        self.amsgrad = self.params["training"]["optimizer"]["amsgrad"]
        self.lr = self.params["training"]["optimizer"]["lr"]

    def __parse_configs(self):
        """Reads the configuration JSON file and stores it in the params attribute."""
        json_str = ""
        with open(self.root, "r") as f:
            for line in f:
                json_str = f"{json_str}{line.split('//')[0]}\n"

        self.params = json.loads(json_str, object_pairs_hook=OrderedDict)

        if not self.params["path"]["resume_state"]:
            self.experiments_root = os.path.join("experiments", f"{self.params['name']}_{get_current_datetime()}")
        else:
            self.experiments_root = "/".join(self.params["path"]["resume_state"].split("/")[:-2])

        for key, path in self.params["path"].items():
            if not key.startswith("resume"):
                self.params["path"][key] = os.path.join(self.experiments_root, path)
                mkdirs(self.params["path"][key])

        if self.gpu_ids:
            self.params["gpu_ids"] = [int(gpu_id) for gpu_id in self.gpu_ids.split(",")]
            gpu_list = self.gpu_ids
        else:
            gpu_list = ",".join(str(x) for x in self.params["gpu_ids"])

        os.environ["CUDA_VISIBLE_DEVICES"] = gpu_list
        self.params["distributed"] = True if len(gpu_list) > 1 else False

    def __getattr__(self, item):
        """Returns None when attribute doesn't exist.

        Args:
            item: Attribute to retrieve.

        Returns:
            None
        """
        return None

    def get_hyperparameters_as_dict(self):
        """Returns a dictionary containing the parsed configuration JSON file."""
        return self.params

--------------------------------------------------------------------------------
/experiment_configs/t2m_GlobalStandardScaling_Monthly.json:
--------------------------------------------------------------------------------
{
    "name": "t2m_GlobalSS_Monthly",
    "phase": "train",
    "gpu_ids": [0],
    "path": {
        "log": "logs",
        "tb_logger": "tb_logger",
        "results": "results",
        "checkpoint": "checkpoint",
        "resume_state": null
    },
    "data": {
        "name": "WeatherBench",
        "dataroot": "/mnt/data/papikyan/WeatherBench/numpy",
        "batch_size": 16,
        "num_workers": 4,
        "use_shuffle": true,
        "train_min_date": "1979-01-01-00",
        "train_max_date": "2016-01-01-00",
        "train_subset_min_date": "2014-01-01-00",
        "train_subset_max_date": "2016-01-01-00",
        "transformation": "GlobalStandardScaling",
        "apply_tranform_monthly": true,
        "val_min_date": "2016-01-01-00",
        "val_max_date": "2018-01-01-00",
        "variables": ["t2m"],
        "height": 128
    },
    "model": {
        "finetune_norm": false,
        "unet": {
            "in_channel": 2, // This should be equal to the number of variables * 2. Used only in networks.py, line 121.
            "out_channel": 1, // This should be equal to the number of variables.
            "inner_channel": 64,
            "norm_groups": 32, // 16
            "channel_multiplier": [1, 2, 4, 8],
            "attn_res": [16], // Possible values are 128, 64, 32, 16 and depend on channel_multiplier.
            "res_blocks": 1,
            "dropout": 0.7,
            "init_method": "kaiming"
        },
        "beta_schedule": {
            "train": {
                "schedule": "cosine",
                "n_timestep": 2000,
                "linear_start": 1e-6,
                "linear_end": 1e-2
            },
            "val": {
                "schedule": "cosine",
                "n_timestep": 100,
                "linear_start": 1e-6,
                "linear_end": 1e-2
            },
            "test": {
                "schedule": "cosine",
                "n_timestep": 1000,
                "linear_start": 1e-6,
                "linear_end": 1e-2
            }
        },
        "diffusion": {
            "conditional": true,
            "loss": "l2"
        }
    },
    "training": {
        "epoch_n_iter": 20000,
        "val_freq": 2000,
        "save_checkpoint_freq": 2000,
        "print_freq": 100,
        "n_val_vis": 1,
        "val_vis_freq": 500,
        "sample_size": 5,
        "optimizer": {
            "type": "adam", // Possible types are ['adam', 'adamw']
            "amsgrad": false,
            "lr": 5e-5
        }
    }
}

--------------------------------------------------------------------------------
/experiment_configs/t2m_GlobalStandardScaling_NotMonthly.json:
--------------------------------------------------------------------------------
{
    "name": "t2m_GlobalSS_NotMonthly_lr_1e06_dropout_04",
    "phase": "train",
    "gpu_ids": [0],
    "path": {
        "log": "logs",
        "tb_logger": "tb_logger",
        "results": "results",
        "checkpoint": "checkpoint",
        "resume_state": null
    },
    "data": {
        "name": "WeatherBench",
        "dataroot": "/mnt/data/papikyan/WeatherBench/numpy",
        "batch_size": 16,
        "num_workers": 4,
        "use_shuffle": true,
        "train_min_date": "1979-01-01-00",
        "train_max_date": "2016-01-01-00",
        "train_subset_min_date": "2014-01-01-00",
        "train_subset_max_date": "2016-01-01-00",
        "transformation": "GlobalStandardScaling",
        "apply_tranform_monthly": false,
        "val_min_date": "2016-01-01-00",
        "val_max_date": "2018-01-01-00",
        "variables": ["t2m"],
        "height": 128
    },
    "model": {
        "finetune_norm": false,
        "unet": {
            "in_channel": 2, // This should be equal to the number of variables * 2. Used only in networks.py, line 121.
            "out_channel": 1, // This should be equal to the number of variables.
            "inner_channel": 32,
            "norm_groups": 32, // 16
            "channel_multiplier": [1, 2, 4],
            "attn_res": [16, 32], // Possible values are 128, 64, 32, 16 and depend on channel_multiplier.
            "res_blocks": 1,
            "dropout": 0.3,
            "init_method": "kaiming"
        },
        "beta_schedule": {
            "train": {
                "schedule": "cosine",
                "n_timestep": 2000,
                "linear_start": 1e-6,
                "linear_end": 1e-2
            },
            "val": {
                "schedule": "cosine",
                "n_timestep": 2000, // 80
                "linear_start": 1e-6,
                "linear_end": 1e-2
            },
            "test": {
                "schedule": "cosine",
                "n_timestep": 2000,
                "linear_start": 1e-6,
                "linear_end": 1e-2
            }
        },
        "diffusion": {
            "conditional": true,
            "loss": "l2"
        }
    },
    "training": {
        "epoch_n_iter": 1000000, // 100000
        "val_freq": 1000000, // 20000
        "save_checkpoint_freq": 1000000, // 20000
        "print_freq": 200,
        "n_val_vis": 1,
        "val_vis_freq": 500,
        "sample_size": 5,
        "optimizer": {
            "type": "adam", // Possible types are ['adam', 'adamw']
            "amsgrad": false,
            "lr": 1e-6
        }
    }
}

--------------------------------------------------------------------------------
/experiment_configs/t2m_GlobalStandardScaling_NotMonthly_longer.json:
--------------------------------------------------------------------------------
{
    "name": "t2m_GlobalSS_NotMonthly_Longer",
    "phase": "train",
    "gpu_ids": [0],
    "path": {
        "log": "logs",
        "tb_logger": "tb_logger",
        "results": "results",
        "checkpoint": "checkpoint",
        "resume_state": null
    },
    "data": {
        "name": "WeatherBench",
        "dataroot": "/mnt/data/papikyan/WeatherBench/numpy",
        "batch_size": 16,
        "num_workers": 4,
        "use_shuffle": true,
        "train_min_date": "1979-01-01-00",
        "train_max_date": "2016-01-01-00",
        "train_subset_min_date": "2014-01-01-00",
        "train_subset_max_date": "2016-01-01-00",
        "transformation": "GlobalStandardScaling",
        "apply_tranform_monthly": false,
        "val_min_date": "2016-01-01-00",
        "val_max_date": "2018-01-01-00",
        "variables": ["t2m"],
        "height": 128
    },
    "model": {
        "finetune_norm": false,
        "unet": {
            "in_channel": 2, // This should be equal to the number of variables * 2. Used only in networks.py, line 121.
            "out_channel": 1, // This should be equal to the number of variables.
            "inner_channel": 64,
            "norm_groups": 32, // 16
            "channel_multiplier": [1, 2, 4, 8],
            "attn_res": [16], // Possible values are 128, 64, 32, 16 and depend on channel_multiplier.
            "res_blocks": 1,
            "dropout": 0.7,
            "init_method": "kaiming"
        },
        "beta_schedule": {
            "train": {
                "schedule": "cosine",
                "n_timestep": 2000,
                "linear_start": 1e-6,
                "linear_end": 1e-2
            },
            "val": {
                "schedule": "cosine",
                "n_timestep": 100,
                "linear_start": 1e-6,
                "linear_end": 1e-2
            },
            "test": {
                "schedule": "cosine",
                "n_timestep": 2000,
                "linear_start": 1e-6,
                "linear_end": 1e-2
            }
        },
        "diffusion": {
            "conditional": true,
            "loss": "l2"
        }
    },
    "training": {
        "epoch_n_iter": 200000,
        "val_freq": 40000,
        "save_checkpoint_freq": 40000,
        "print_freq": 200,
        "n_val_vis": 1,
        "val_vis_freq": 500,
        "sample_size": 10,
        "optimizer": {
            "type": "adam", // Possible types are ['adam', 'adamw']
            "amsgrad": false,
            "lr": 5e-5
        }
    }
}

--------------------------------------------------------------------------------
/experiment_configs/t2m_LocalStandardScaling_Monthly.json:
--------------------------------------------------------------------------------
{
    "name": "t2m_LocalSS_Monthly",
    "phase": "train",
    "gpu_ids": [0],
    "path": {
        "log": "logs",
        "tb_logger": "tb_logger",
        "results": "results",
        "checkpoint": "checkpoint",
        "resume_state": null
    },
    "data": {
        "name": "WeatherBench",
        "dataroot": "/mnt/data/papikyan/WeatherBench/numpy",
        "batch_size": 16,
        "num_workers": 4,
        "use_shuffle": true,
        "train_min_date": "1979-01-01-00",
        "train_max_date": "2016-01-01-00",
        "train_subset_min_date": "2014-01-01-00",
        "train_subset_max_date": "2016-01-01-00",
        "transformation": "LocalStandardScaling",
        "apply_tranform_monthly": true,
        "val_min_date": "2016-01-01-00",
        "val_max_date": "2018-01-01-00",
        "variables": ["t2m"],
        "height": 128
    },
    "model": {
        "finetune_norm": false,
        "unet": {
            "in_channel": 2, // This should be equal to the number of variables * 2. Used only in networks.py, line 121.
            "out_channel": 1, // This should be equal to the number of variables.
            "inner_channel": 64,
            "norm_groups": 32, // 16
            "channel_multiplier": [1, 2, 4, 8],
            "attn_res": [16], // Possible values are 128, 64, 32, 16 and depend on channel_multiplier.
            "res_blocks": 1,
            "dropout": 0.7,
            "init_method": "kaiming"
        },
        "beta_schedule": {
            "train": {
                "schedule": "cosine",
                "n_timestep": 2000,
                "linear_start": 1e-6,
                "linear_end": 1e-2
            },
            "val": {
                "schedule": "cosine",
                "n_timestep": 100,
                "linear_start": 1e-6,
                "linear_end": 1e-2
            },
            "test": {
                "schedule": "cosine",
                "n_timestep": 1000,
                "linear_start": 1e-6,
                "linear_end": 1e-2
            }
        },
        "diffusion": {
            "conditional": true,
            "loss": "l2"
        }
    },
    "training": {
        "epoch_n_iter": 20000,
        "val_freq": 2000,
        "save_checkpoint_freq": 2000,
        "print_freq": 100,
        "n_val_vis": 1,
        "val_vis_freq": 500,
        "sample_size": 5,
        "optimizer": {
            "type": "adam", // Possible types are ['adam', 'adamw']
            "amsgrad": false,
            "lr": 5e-5
        }
    }
}

--------------------------------------------------------------------------------
/experiment_configs/t2m_LocalStandardScaling_NotMonthly.json:
--------------------------------------------------------------------------------
{
    "name": "t2m_LocalSS_NotMonthly",
    "phase": "train",
    "gpu_ids": [0],
    "path": {
        "log": "logs",
        "tb_logger": "tb_logger",
        "results": "results",
        "checkpoint": "checkpoint",
        "resume_state": null
    },
    "data": {
        "name": "WeatherBench",
        "dataroot": "/mnt/data/papikyan/WeatherBench/numpy",
        "batch_size": 16,
        "num_workers": 4,
        "use_shuffle": true,
        "train_min_date": "1979-01-01-00",
        "train_max_date": "2016-01-01-00",
        "train_subset_min_date": "2014-01-01-00",
        "train_subset_max_date": "2016-01-01-00",
        "transformation": "LocalStandardScaling",
        "apply_tranform_monthly": false,
        "val_min_date": "2016-01-01-00",
        "val_max_date": "2018-01-01-00",
        "variables": ["t2m"],
        "height": 128
    },
    "model": {
        "finetune_norm": false,
        "unet": {
            "in_channel": 2, // This should be equal to the number of variables * 2. Used only in networks.py, line 121.
            "out_channel": 1, // This should be equal to the number of variables.
            "inner_channel": 64,
            "norm_groups": 32, // 16
            "channel_multiplier": [1, 2, 4, 8],
            "attn_res": [16], // Possible values are 128, 64, 32, 16 and depend on channel_multiplier.
            "res_blocks": 1,
            "dropout": 0.7,
            "init_method": "kaiming"
        },
        "beta_schedule": {
            "train": {
                "schedule": "cosine",
                "n_timestep": 2000,
                "linear_start": 1e-6,
                "linear_end": 1e-2
            },
            "val": {
                "schedule": "cosine",
                "n_timestep": 100,
                "linear_start": 1e-6,
                "linear_end": 1e-2
            },
            "test": {
                "schedule": "cosine",
                "n_timestep": 1000,
                "linear_start": 1e-6,
                "linear_end": 1e-2
            }
        },
        "diffusion": {
            "conditional": true,
            "loss": "l2"
        }
    },
    "training": {
        "epoch_n_iter": 20000,
        "val_freq": 2000,
        "save_checkpoint_freq": 2000,
        "print_freq": 100,
        "n_val_vis": 1,
        "val_vis_freq": 500,
        "sample_size": 5,
        "optimizer": {
            "type": "adam", // Possible types are ['adam', 'adamw']
            "amsgrad": false,
            "lr": 5e-5
        }
    }
}

--------------------------------------------------------------------------------
/experiment_configs/test_single_image_t2m_GlobalStandardScaling_NotMonthly.json:
--------------------------------------------------------------------------------
{
    "name": "t2m_GlobalSS_NotMonthly_dropout_03_const_lr",
    "phase": "val",
    "gpu_ids": [0],
    "path": {
        "log": "logs",
        "tb_logger": "tb_logger",
        "results": "results",
        "checkpoint": "checkpoint",
        "resume_state": "/home/papikyan/guided_research/development/experiments/experiments_wbd/experiments/t2m_GlobalSS_NotMonthly_220216_103355/checkpoint/I60000_E3"
    },
    "data": {
        "name": "WeatherBench",
        "dataroot": "/mnt/data/papikyan/WeatherBench/numpy",
        "batch_size": 16,
        "num_workers": 4,
        "use_shuffle": true,
        "train_min_date": "1970-01-01-00",
        "train_max_date": "2000-01-01-00",
        "train_subset_min_date": "2014-01-01-00",
        "train_subset_max_date": "2016-01-01-00",
        "transformation": "GlobalStandardScaling",
        "apply_tranform_monthly": false,
        "val_min_date": "2016-01-01-00",
        "val_max_date": "2018-01-01-00",
        "variables": ["t2m"],
        "height": 128
    },
    "model": {
        "finetune_norm": false,
        "unet": {
            "in_channel": 2, // This should be equal to the number of variables * 2. Used only in networks.py, line 121.
            "out_channel": 1, // This should be equal to the number of variables.
            "inner_channel": 32,
            "norm_groups": 32, // 16
            "channel_multiplier": [1, 2, 4],
            "attn_res": [16, 32], // Possible values are 128, 64, 32, 16 and depend on channel_multiplier.
            "res_blocks": 1,
            "dropout": 0.3,
            "init_method": "kaiming"
        },
        "beta_schedule": {
            "train": {
                "schedule": "cosine",
                "n_timestep": 2000,
                "linear_start": 1e-6,
                "linear_end": 1e-2
            },
            "val": {
                "schedule": "cosine",
                "n_timestep": 2000,
                "linear_start": 1e-6,
                "linear_end": 1e-2
            },
            "test": {
                "schedule": "cosine",
                "n_timestep": 2000,
                "linear_start": 1e-6,
                "linear_end": 1e-2
            }
        },
        "diffusion": {
            "conditional": true,
            "loss": "l2"
        }
    },
    "training": {
        "epoch_n_iter": 20000,
        "val_freq": 2000,
        "save_checkpoint_freq": 2000,
        "print_freq": 100,
        "n_val_vis": 1,
        "val_vis_freq": 14680, // 16000
        "sample_size": 500,
        "optimizer": {
            "type": "adam", // Possible types are ['adam', 'adamw']
            "amsgrad": false,
            "lr": 5e-5
        }
    }
}

--------------------------------------------------------------------------------
/experiment_configs/test_t2m_GlobalStandardScaling_NotMonthly.json:
--------------------------------------------------------------------------------
{
    "name": "t2m_GlobalSS_NotMonthly",
    "phase": "val",
    "gpu_ids": [0],
    "path": {
        "log": "logs",
        "tb_logger": "tb_logger",
        "results": "results",
        "checkpoint": "checkpoint",
        "resume_state": "/home/papikyan/guided_research/development/experiments/experiments_wbd/experiments/t2m_GlobalSS_NotMonthly_220216_103355/checkpoint/I60000_E3"
    },
    "data": {
        "name": "WeatherBench",
        "dataroot": "/mnt/data/papikyan/WeatherBench/numpy",
        "batch_size": 16,
        "num_workers": 4,
        "use_shuffle": true,
        "train_min_date": "1970-01-01-00",
        "train_max_date": "2000-01-01-00",
        "train_subset_min_date": "2014-01-01-00",
        "train_subset_max_date": "2016-01-01-00",
        "transformation": "GlobalStandardScaling",
        "apply_tranform_monthly": false,
        "val_min_date": "2016-01-01-00",
        "val_max_date": "2018-01-01-00",
        "variables": ["t2m"],
        "height": 128
    },
    "model": {
        "finetune_norm": false,
        "unet": {
            "in_channel": 2, // This should be equal to the number of variables * 2. Used only in networks.py, line 121.
            "out_channel": 1, // This should be equal to the number of variables.
            "inner_channel": 32,
            "norm_groups": 32, // 16
            "channel_multiplier": [1, 2, 4],
            "attn_res": [16, 32], // Possible values are 128, 64, 32, 16 and depend on channel_multiplier.
            "res_blocks": 1,
            "dropout": 0.3,
            "init_method": "kaiming"
        },
        "beta_schedule": {
            "train": {
                "schedule": "cosine",
                "n_timestep": 2000,
                "linear_start": 1e-6,
                "linear_end": 1e-2
            },
            "val": {
                "schedule": "cosine",
                "n_timestep": 2000,
                "linear_start": 1e-6,
                "linear_end": 1e-2
            },
            "test": {
                "schedule": "cosine",
                "n_timestep": 2000,
                "linear_start": 1e-6,
                "linear_end": 1e-2
            }
        },
        "diffusion": {
            "conditional": true,
            "loss": "l2"
        }
    },
    "training": {
        "epoch_n_iter": 20000,
        "val_freq": 2000,
        "save_checkpoint_freq": 2000,
        "print_freq": 100,
        "n_val_vis": 1,
        "val_vis_freq": 100000000, // 500
        "sample_size": 2, // 5
        "optimizer": {
            "type": "adam", // Possible types are ['adam', 'adamw']
            "amsgrad": false,
            "lr": 5e-5
        }
    }
}

--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
"""The inference script for DDPM model.
"""
import argparse
import logging
import os
import pickle
import warnings
from collections import OrderedDict, defaultdict

import torch
from torch.nn.functional import mse_loss, l1_loss
from torch.utils.data import DataLoader

import model
from config import Config, get_current_datetime
from utils import dict2str, setup_logger, construct_and_save_wbd_plots, \
    construct_mask, reverse_transform_candidates, set_seeds
from weatherbench_data import collate_wb_batch
from weatherbench_data.utils import reverse_transform, reverse_transform_tensor, load_object, prepare_test_data

warnings.filterwarnings("ignore")


if __name__ == "__main__":
    set_seeds()  # For reproducibility.

    parser = argparse.ArgumentParser()
    parser.add_argument("-c", "--config", type=str, help="JSON file for configuration")
    parser.add_argument("-p", "--phase", type=str, choices=["train", "val"],
                        help="Run either training or validation (inference).", default="train")
    parser.add_argument("-gpu", "--gpu_ids", type=str, default=None)
    args = parser.parse_args()
    configs = Config(args)

    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True

    test_root = f"{configs.experiments_root}/test_{get_current_datetime()}"
    os.makedirs(test_root, exist_ok=True)
    setup_logger("test", test_root, "test", screen=True)
    val_logger = logging.getLogger("test")
    val_logger.info(dict2str(configs.get_hyperparameters_as_dict()))

    # Preparing testing data.
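    # Loading the per-variable transformations and grid metadata that are assumed
    # to have been pickled into the experiment root during training; the check
    # below treats missing objects as falsy values.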
    transformations = load_object(configs.experiments_root, "transformations")
    metadata = load_object(configs.experiments_root, "metadata")
    if transformations and metadata:
        val_logger.info("Transformations and Metadata are successfully loaded.")

    val_dataset = prepare_test_data(variables=configs.variables, val_min_date=configs.val_min_date,
                                    val_max_date=configs.val_max_date, dataroot=configs.dataroot,
                                    transformations=transformations)

    val_logger.info(f"Dataset [{val_dataset.__class__.__name__} - Testing] is created.")
    val_logger.info(f"""Created {val_dataset.__class__.__name__} dataset of length {len(val_dataset)},
    containing data from {val_dataset.min_date} until {val_dataset.max_date}""")
    val_logger.info(f"Group structure: {val_dataset.get_data_names()}")
    val_logger.info(f"Channel count: {val_dataset.get_channel_count()}\n")

    val_loader = DataLoader(val_dataset, batch_size=32,
                            collate_fn=collate_wb_batch,
                            pin_memory=True, drop_last=True,
                            num_workers=configs.num_workers)
    val_logger.info("Testing dataset is ready.")

    # Defining the model.
    diffusion = model.create_model(in_channel=configs.in_channel, out_channel=configs.out_channel,
                                   norm_groups=configs.norm_groups, inner_channel=configs.inner_channel,
                                   channel_multiplier=configs.channel_multiplier, attn_res=configs.attn_res,
                                   res_blocks=configs.res_blocks, dropout=configs.dropout,
                                   diffusion_loss=configs.diffusion_loss, conditional=configs.conditional,
                                   gpu_ids=configs.gpu_ids, distributed=configs.distributed,
                                   init_method=configs.init_method, train_schedule=configs.train_schedule,
                                   train_n_timestep=configs.train_n_timestep,
                                   train_linear_start=configs.train_linear_start,
                                   train_linear_end=configs.train_linear_end,
                                   val_schedule=configs.val_schedule, val_n_timestep=configs.val_n_timestep,
                                   val_linear_start=configs.val_linear_start, val_linear_end=configs.val_linear_end,
                                   finetune_norm=configs.finetune_norm, optimizer=None, amsgrad=configs.amsgrad,
                                   learning_rate=configs.lr, checkpoint=configs.checkpoint,
                                   resume_state=configs.resume_state, phase=configs.phase, height=configs.height)
    val_logger.info("Model initialization is finished.")

    current_step, current_epoch = diffusion.begin_step, diffusion.begin_epoch
    val_logger.info(f"Testing the model at epoch: {current_epoch}, iter: {current_step}.")

    diffusion.set_new_noise_schedule(schedule=configs.test_schedule,
                                     n_timestep=configs.test_n_timestep,
                                     linear_start=configs.test_linear_start,
                                     linear_end=configs.test_linear_end)
    accumulated_statistics = OrderedDict()

    # Creating placeholder for storing validation metrics: Mean Squared Error, Root MSE, Mean Residual.
    val_metrics = OrderedDict({"MSE": 0.0, "RMSE": 0.0, "MAE": 0.0, "MR": 0.0,
                               "Mean_Temp_MSE": 0.0, "Std_Temp_MSE": 0.0,
                               "Mean_Temp_MAE": 0.0, "Std_Temp_MAE": 0.0})
    idx = 0

    result_path = f"{test_root}/results"
    os.makedirs(result_path, exist_ok=True)

    # A dictionary for storing a list of mean temperatures for each month.
    month2mean_temperature = defaultdict(list)

    with torch.no_grad():
        for val_data in val_loader:
            idx += 1
            diffusion.feed_data(val_data)
            diffusion.test(continuous=False)  # continuous=False returns only the last timestep's outcome.

            # Computing metrics on validation data.
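            # visuals is expected to hold batch tensors keyed by "HR", "SR", "LR"
            # and "INTERPOLATED"; reverse_transform maps them back to physical
            # temperature units before metrics are computed.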
            visuals = diffusion.get_current_visuals()

            inv_visuals = reverse_transform(visuals, transformations,
                                            configs.variables, diffusion.get_months(),
                                            configs.tranform_monthly)

            # Computing MSE and RMSE on original data.
            mse_value = mse_loss(inv_visuals["HR"], inv_visuals["SR"])
            val_metrics["MSE"] += mse_value
            val_metrics["RMSE"] += torch.sqrt(mse_value)
            val_metrics["MAE"] += l1_loss(inv_visuals["HR"], inv_visuals["SR"])

            # How well does the model estimate the mean and standard deviation of temperature?
            mean_temp = inv_visuals["HR"].mean(axis=[1, 2, 3])
            mean_temp_pred = inv_visuals["SR"].mean(axis=[1, 2, 3])
            std_temp = inv_visuals["HR"].std(axis=[1, 2, 3])
            std_temp_pred = inv_visuals["SR"].std(axis=[1, 2, 3])
            val_metrics["Mean_Temp_MSE"] += mse_loss(mean_temp, mean_temp_pred)
            val_metrics["Std_Temp_MSE"] += mse_loss(std_temp, std_temp_pred)
            val_metrics["Mean_Temp_MAE"] += l1_loss(mean_temp, mean_temp_pred)
            val_metrics["Std_Temp_MAE"] += l1_loss(std_temp, std_temp_pred)

            for m, t in zip(diffusion.get_months(), mean_temp_pred):
                month2mean_temperature[int(m)].append(t)

            # Computing residuals for visualization.
            residuals = inv_visuals["SR"] - inv_visuals["HR"]
            val_metrics["MR"] += residuals.mean()

            if idx % configs.val_vis_freq == 0:
                path = f"{result_path}/{current_epoch}_{current_step}_{idx}"
                val_logger.info(f"[{idx//configs.val_vis_freq}] Visualizing and storing some examples.")

                sr_candidates = diffusion.generate_multiple_candidates(n=configs.sample_size)
                reverse_transform_candidates(sr_candidates, reverse_transform_tensor,
                                             transformations, configs.variables,
                                             "hr", diffusion.get_months(),
                                             configs.tranform_monthly)
                mean_candidate = sr_candidates.mean(dim=0)  # [B, C, H, W]
                std_candidate = sr_candidates.std(dim=0)  # [B, C, H, W]
                bias = mean_candidate - inv_visuals["HR"]
                mean_bias_over_pixels = bias.mean()  # Scalar.
                std_bias_over_pixels = bias.std()  # Scalar.

                # Computing min and max measures to set a fixed colorbar for all visualizations.
                vmin = min(inv_visuals["HR"][:configs.n_val_vis].min(),
                           inv_visuals["SR"][:configs.n_val_vis].min(),
                           inv_visuals["LR"][:configs.n_val_vis].min(),
                           inv_visuals["INTERPOLATED"][:configs.n_val_vis].min(),
                           mean_candidate[:configs.n_val_vis].min())
                vmax = max(inv_visuals["HR"][:configs.n_val_vis].max(),
                           inv_visuals["SR"][:configs.n_val_vis].max(),
                           inv_visuals["LR"][:configs.n_val_vis].max(),
                           inv_visuals["INTERPOLATED"][:configs.n_val_vis].max(),
                           mean_candidate[:configs.n_val_vis].max())

                # Choosing the first n_val_vis number of samples to visualize.
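                # The calls below plot HR, SR, LR, the interpolated input, a signed
                # residual mask and the candidate mean/std on the lat/lon grid,
                # sharing one colorbar range (vmin, vmax) where applicable.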
                construct_and_save_wbd_plots(latitude=metadata.hr_lat, longitude=metadata.hr_lon,
                                             data=inv_visuals["HR"][:configs.n_val_vis],
                                             path=f"{path}_hr.png", vmin=vmin, vmax=vmax)
                construct_and_save_wbd_plots(latitude=metadata.hr_lat, longitude=metadata.hr_lon,
                                             data=inv_visuals["SR"][:configs.n_val_vis],
                                             path=f"{path}_sr.png", vmin=vmin, vmax=vmax)
                construct_and_save_wbd_plots(latitude=metadata.lr_lat, longitude=metadata.lr_lon,
                                             data=inv_visuals["LR"][:configs.n_val_vis],
                                             path=f"{path}_lr.png", vmin=vmin, vmax=vmax)
                construct_and_save_wbd_plots(latitude=metadata.hr_lat, longitude=metadata.hr_lon,
                                             data=inv_visuals["INTERPOLATED"][:configs.n_val_vis],
                                             path=f"{path}_interpolated.png", vmin=vmin, vmax=vmax)
                construct_and_save_wbd_plots(latitude=metadata.hr_lat, longitude=metadata.hr_lon,
                                             data=construct_mask(residuals[:configs.n_val_vis]),
                                             path=f"{path}_residual.png", vmin=-1, vmax=1,
                                             costline_color="red", cmap="binary",
                                             label="Signum(SR - HR)")
                construct_and_save_wbd_plots(latitude=metadata.hr_lat, longitude=metadata.hr_lon,
                                             data=mean_candidate[:configs.n_val_vis],
                                             path=f"{path}_mean_sr.png", vmin=vmin, vmax=vmax)
                construct_and_save_wbd_plots(latitude=metadata.hr_lat, longitude=metadata.hr_lon,
                                             data=std_candidate[:configs.n_val_vis],
                                             path=f"{path}_std_sr.png", vmin=0, cmap="Greens")

    # Validation is finished.
    val_metrics["MSE"] /= idx
    val_metrics["RMSE"] /= idx
    val_metrics["MR"] /= idx
    val_metrics["MAE"] /= idx
    val_metrics["Mean_Temp_MSE"] /= idx
    val_metrics["Std_Temp_MSE"] /= idx
    val_metrics["Mean_Temp_MAE"] /= idx
    val_metrics["Std_Temp_MAE"] /= idx

    message = f"Epoch: {current_epoch:5} | Iteration: {current_step:8}"
    for metric, value in val_metrics.items():
        message = f"{message} | {metric:s}: {value:.5f}"
    val_logger.info(message)

    with open(f"{test_root}/month2mean_temperature.pickle", 'wb') as handle:
        pickle.dump(month2mean_temperature, handle, protocol=pickle.HIGHEST_PROTOCOL)

    val_logger.info("End of testing.")

--------------------------------------------------------------------------------
/inference_on_single_image.py:
--------------------------------------------------------------------------------
"""The inference script for DDPM model for a single image.

This script is not a part of the pipeline of the project. It
was used to generate plots and statistics for a single
data sample case.
"""
import argparse
import logging
import os
import warnings
from collections import OrderedDict

import numpy as np
import torch
from torch.nn.functional import mse_loss, l1_loss
from torch.utils.data import DataLoader

import model
from config import Config, get_current_datetime
from utils import dict2str, setup_logger, construct_and_save_wbd_plots, \
    construct_mask, reverse_transform_candidates, set_seeds
from weatherbench_data import collate_wb_batch
from weatherbench_data.utils import reverse_transform, reverse_transform_tensor, load_object, prepare_test_data

warnings.filterwarnings("ignore")


if __name__ == "__main__":
    set_seeds()  # For reproducibility.

    parser = argparse.ArgumentParser()
    parser.add_argument("-c", "--config", type=str, help="JSON file for configuration")
    parser.add_argument("-p", "--phase", type=str, choices=["train", "val"],
                        help="Run either training or validation (inference).", default="train")
    parser.add_argument("-gpu", "--gpu_ids", type=str, default=None)
    args = parser.parse_args()
    configs = Config(args)

    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True

    test_root = f"{configs.experiments_root}/test_on_single_image_{get_current_datetime()}"
    os.makedirs(test_root, exist_ok=True)
    setup_logger("test", test_root, "test", screen=True)
    val_logger = logging.getLogger("test")
    val_logger.info(dict2str(configs.get_hyperparameters_as_dict()))

    # Preparing testing data.
    transformations = load_object(configs.experiments_root, "transformations")
    metadata = load_object(configs.experiments_root, "metadata")
    if transformations and metadata:
        val_logger.info("Transformations and Metadata are successfully loaded.")

    val_dataset = prepare_test_data(variables=configs.variables, val_min_date=configs.val_min_date,
                                    val_max_date=configs.val_max_date, dataroot=configs.dataroot,
                                    transformations=transformations)

    val_logger.info(f"Dataset [{val_dataset.__class__.__name__} - Testing] is created.")
    val_logger.info(f"""Created {val_dataset.__class__.__name__} dataset of length {len(val_dataset)},
    containing data from {val_dataset.min_date} until {val_dataset.max_date}""")
    val_logger.info(f"Group structure: {val_dataset.get_data_names()}")
    val_logger.info(f"Channel count: {val_dataset.get_channel_count()}\n")

    # The batch size is 1.
    val_loader = DataLoader(val_dataset,
                            collate_fn=collate_wb_batch,
                            pin_memory=True, drop_last=True,
                            num_workers=configs.num_workers)
    val_logger.info("Testing dataset is ready.")

    # Defining the model.
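    # create_model (see model/__init__.py) builds the U-Net noise predictor via
    # define_network and wraps it in the DDPM class together with the beta schedules.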
    diffusion = model.create_model(in_channel=configs.in_channel, out_channel=configs.out_channel,
                                   norm_groups=configs.norm_groups, inner_channel=configs.inner_channel,
                                   channel_multiplier=configs.channel_multiplier, attn_res=configs.attn_res,
                                   res_blocks=configs.res_blocks, dropout=configs.dropout,
                                   diffusion_loss=configs.diffusion_loss, conditional=configs.conditional,
                                   gpu_ids=configs.gpu_ids, distributed=configs.distributed,
                                   init_method=configs.init_method, train_schedule=configs.train_schedule,
                                   train_n_timestep=configs.train_n_timestep,
                                   train_linear_start=configs.train_linear_start,
                                   train_linear_end=configs.train_linear_end,
                                   val_schedule=configs.val_schedule, val_n_timestep=configs.val_n_timestep,
                                   val_linear_start=configs.val_linear_start, val_linear_end=configs.val_linear_end,
                                   finetune_norm=configs.finetune_norm, optimizer=None, amsgrad=configs.amsgrad,
                                   learning_rate=configs.lr, checkpoint=configs.checkpoint,
                                   resume_state=configs.resume_state, phase=configs.phase, height=configs.height)
    val_logger.info("Model initialization is finished.")

    current_step, current_epoch = diffusion.begin_step, diffusion.begin_epoch
    val_logger.info(f"Testing the model at epoch: {current_epoch}, iter: {current_step}.")

    diffusion.set_new_noise_schedule(schedule=configs.test_schedule,
                                     n_timestep=configs.test_n_timestep,
                                     linear_start=configs.test_linear_start,
                                     linear_end=configs.test_linear_end)
    accumulated_statistics = OrderedDict()

    # Creating placeholder for storing validation metrics: Mean Squared Error, Root MSE, Mean Residual.
    val_metrics = OrderedDict({"MSE": 0.0, "RMSE": 0.0, "MAE": 0.0, "MR": 0.0,
                               "mean_bias_over_pixels": 0.0, "std_bias_over_pixels": 0.0})
    idx = 0

    result_path = f"{test_root}/results"
    os.makedirs(result_path, exist_ok=True)

    with torch.no_grad():
        for val_data in val_loader:
            idx += 1

            # Works only for one image.
            if idx % configs.val_vis_freq == 0:
                diffusion.feed_data(val_data)
                val_logger.info("Starting to generate SR images.")
                diffusion.test(continuous=True)  # continuous=True returns the outcomes of all diffusion timesteps.
                val_logger.info("Finished generating SR images.")

                # Computing metrics on validation data.
                visuals = diffusion.get_current_visuals()
                # When continuous is True, visuals["SR"] has [T, C, H, W] dimension
                # where T is the number of diffusion timesteps.

                inv_visuals = reverse_transform(visuals, transformations,
                                                configs.variables, diffusion.get_months(),
                                                configs.tranform_monthly)

                # Computing MSE and RMSE on original data.
                mse_value = mse_loss(inv_visuals["HR"], inv_visuals["SR"][-1])
                val_metrics["MSE"] += mse_value
                val_metrics["RMSE"] += torch.sqrt(mse_value)
                val_metrics["MAE"] += l1_loss(inv_visuals["HR"], inv_visuals["SR"][-1])

                # Computing residuals for visualization.
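                # With continuous=True, inv_visuals["SR"] stacks the intermediate
                # denoising outputs along the first dimension, so index [-1] below
                # selects the final super-resolved sample.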
                residuals = inv_visuals["SR"][-1] - inv_visuals["HR"]
                val_metrics["MR"] += residuals.mean()

                path = f"{result_path}/{idx}/"
                os.makedirs(path, exist_ok=True)
                path = f"{path}{idx}"

                val_logger.info("Started generating multiple SR candidates.")
                sr_candidates = diffusion.generate_multiple_candidates(n=configs.sample_size)
                reverse_transform_candidates(sr_candidates, reverse_transform_tensor,
                                             transformations, configs.variables,
                                             "hr", diffusion.get_months(),
                                             configs.tranform_monthly)
                val_logger.info("Finished generating multiple SR candidates.")

                mean_candidate = sr_candidates.mean(dim=0)
                std_candidate = sr_candidates.std(dim=0)
                bias = mean_candidate - inv_visuals["HR"]
                mean_bias_over_pixels = bias.mean()
                std_bias_over_pixels = bias.std()
                val_metrics["mean_bias_over_pixels"] += mean_bias_over_pixels
                val_metrics["std_bias_over_pixels"] += std_bias_over_pixels

                # Computing min and max measures to set a fixed colorbar for all visualizations.
                vmin = min(inv_visuals["HR"].min(),
                           inv_visuals["LR"].min(),
                           inv_visuals["INTERPOLATED"].min(),
                           mean_candidate.min())
                vmax = max(inv_visuals["HR"].max(),
                           inv_visuals["LR"].max(),
                           inv_visuals["INTERPOLATED"].max(),
                           mean_candidate.max())
                vmin, vmax = np.floor(vmin), np.ceil(vmax)

                val_logger.info(f"[{idx // configs.val_vis_freq}] Visualizing and storing some examples.")
                # Choosing the first n_val_vis number of samples to visualize.
                construct_and_save_wbd_plots(latitude=metadata.hr_lat, longitude=metadata.hr_lon,
                                             data=inv_visuals["HR"],
                                             path=f"{path}_hr.png", vmin=vmin, vmax=vmax)
                construct_and_save_wbd_plots(latitude=metadata.hr_lat, longitude=metadata.hr_lon,
                                             data=inv_visuals["SR"],
                                             path=f"{path}_sr.png", vmin=vmin, vmax=vmax)
                construct_and_save_wbd_plots(latitude=metadata.lr_lat, longitude=metadata.lr_lon,
                                             data=inv_visuals["LR"],
                                             path=f"{path}_lr.png", vmin=vmin, vmax=vmax)
                construct_and_save_wbd_plots(latitude=metadata.hr_lat, longitude=metadata.hr_lon,
                                             data=inv_visuals["INTERPOLATED"],
                                             path=f"{path}_interpolated.png", vmin=vmin, vmax=vmax)
                construct_and_save_wbd_plots(latitude=metadata.hr_lat, longitude=metadata.hr_lon,
                                             data=construct_mask(residuals),
                                             path=f"{path}_residual.png", vmin=-1, vmax=1,
                                             costline_color="red", cmap="binary",
                                             label="Signum(SR - HR)")
                construct_and_save_wbd_plots(latitude=metadata.hr_lat, longitude=metadata.hr_lon,
                                             data=mean_candidate,
                                             path=f"{path}_mean_sr.png", vmin=vmin, vmax=vmax)
                construct_and_save_wbd_plots(latitude=metadata.hr_lat, longitude=metadata.hr_lon,
                                             data=std_candidate,
                                             path=f"{path}_std_sr.png", vmin=0.0, cmap="Greens")

    normalizing_constant = idx // configs.val_vis_freq
    val_metrics["MSE"] /= normalizing_constant
    val_metrics["RMSE"] /= normalizing_constant
    val_metrics["MR"] /= normalizing_constant
    val_metrics["MAE"] /= normalizing_constant
    val_metrics["mean_bias_over_pixels"] /= normalizing_constant
    val_metrics["std_bias_over_pixels"] /= normalizing_constant

    message = f"Epoch: {current_epoch:5} | Iteration: {current_step:8}"
    for metric, value in val_metrics.items():
        message = f"{message} | {metric:s}: {value:.5f}"
    val_logger.info(message)

    torch.save(inv_visuals["LR"], f"{path}_LR.pt")
    torch.save(inv_visuals["HR"], f"{path}_HR.pt")
    torch.save(inv_visuals["INTERPOLATED"], f"{path}_INTERPOLATED.pt")
    torch.save(inv_visuals["SR"], f"{path}_SR.pt")
    torch.save(mean_candidate, f"{path}_mean_sr.pt")
    torch.save(std_candidate, f"{path}_std.pt")
    val_logger.info("End of testing.")

--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
"""Module for creating end-to-end network for
Single Image Super-Resolution task with DDPM.

The work is based on https://github.com/Janspiry/Image-Super-Resolution-via-Iterative-Refinement.
"""
import logging

from .model import DDPM

logger = logging.getLogger(name="base")


def create_model(in_channel, out_channel, norm_groups, inner_channel,
                 channel_multiplier, attn_res, res_blocks, dropout,
                 diffusion_loss, conditional, gpu_ids, distributed, init_method,
                 train_schedule, train_n_timestep, train_linear_start, train_linear_end,
                 val_schedule, val_n_timestep, val_linear_start, val_linear_end,
                 finetune_norm, optimizer, amsgrad, learning_rate, checkpoint, resume_state,
                 phase, height):
    """Creates DDPM model.

    Args:
        in_channel: The number of channels of input tensor of U-Net.
        out_channel: The number of channels of output tensor of U-Net.
        norm_groups: The number of groups for group normalization.
        inner_channel: Timestep embedding dimension.
        channel_multiplier: A tuple specifying the scaling factors of channels.
        attn_res: A tuple of spatial dimensions indicating in which resolutions to use self-attention layer.
        res_blocks: The number of residual blocks.
        dropout: Dropout probability.
        diffusion_loss: Either l1 or l2.
        conditional: Whether to condition on INTERPOLATED image or not.
        gpu_ids: IDs of gpus.
        distributed: Whether the computation will be distributed among multiple GPUs or not.
        init_method: NN weight initialization method. One of normal, kaiming or orthogonal initializations.
        train_schedule: Defines the type of beta schedule for training.
        train_n_timestep: Number of diffusion timesteps for training.
        train_linear_start: Minimum value of the linear schedule for training.
        train_linear_end: Maximum value of the linear schedule for training.
        val_schedule: Defines the type of beta schedule for validation.
        val_n_timestep: Number of diffusion timesteps for validation.
        val_linear_start: Minimum value of the linear schedule for validation.
        val_linear_end: Maximum value of the linear schedule for validation.
        finetune_norm: Whether to fine-tune or train from scratch.
        optimizer: The optimization algorithm.
        amsgrad: Whether to use the AMSGrad variant of optimizer.
        learning_rate: The learning rate.
        checkpoint: Path to the checkpoint file.
        resume_state: The path to load the network.
        phase: Either train or val.
        height: U-Net input tensor height value.

    Returns:
        The created DDPM model.
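
    Example:
        A minimal sketch of how the inference scripts call this factory; the
        values mirror the t2m experiment configs and the trailing arguments are
        elided for brevity (all parameters above must be supplied):

            diffusion = create_model(in_channel=2, out_channel=1, norm_groups=32,
                                     inner_channel=64, channel_multiplier=[1, 2, 4, 8],
                                     attn_res=[16], res_blocks=1, dropout=0.7,
                                     diffusion_loss="l2", conditional=True, ...)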
    """
    diffusion_model = DDPM(in_channel=in_channel, out_channel=out_channel, norm_groups=norm_groups,
                           inner_channel=inner_channel, channel_multiplier=channel_multiplier,
                           attn_res=attn_res, res_blocks=res_blocks, dropout=dropout,
                           diffusion_loss=diffusion_loss, conditional=conditional,
                           gpu_ids=gpu_ids, distributed=distributed, init_method=init_method,
                           train_schedule=train_schedule, train_n_timestep=train_n_timestep,
                           train_linear_start=train_linear_start, train_linear_end=train_linear_end,
                           val_schedule=val_schedule, val_n_timestep=val_n_timestep,
                           val_linear_start=val_linear_start, val_linear_end=val_linear_end,
                           finetune_norm=finetune_norm, optimizer=optimizer, amsgrad=amsgrad,
                           learning_rate=learning_rate, checkpoint=checkpoint,
                           resume_state=resume_state, phase=phase, height=height)
    logger.info("Model [{:s}] is created.".format(diffusion_model.__class__.__name__))
    return diffusion_model

--------------------------------------------------------------------------------
/model/base_model.py:
--------------------------------------------------------------------------------
"""Defines a base class for DDPM model.

The work is based on https://github.com/Janspiry/Image-Super-Resolution-via-Iterative-Refinement.
"""
import typing

import torch
import torch.nn as nn


class BaseModel:
    """A skeleton for DDPM models.

    Attributes:
        gpu_ids: IDs of gpus.
    """

    def __init__(self, gpu_ids):
        self.gpu_ids = gpu_ids
        self.device = torch.device("cuda" if torch.cuda.is_available() and self.gpu_ids else "cpu")
        self.begin_step, self.begin_epoch = 0, 0

    def feed_data(self, data) -> None:
        """Provides model with data.

        Args:
            data: A batch of data.
        """
        pass

    def optimize_parameters(self) -> None:
        """Computes loss and performs GD step on learnable parameters."""
        pass

    def get_current_visuals(self) -> dict:
        """Returns reconstructed data points."""
        pass

    def print_network(self) -> None:
        """Prints the network architecture."""
        pass

    def set_device(self, x):
        """Sets values of x onto device specified by an attribute of the same name.

        Args:
            x: Value storage.

        Returns:
            x set on self.device.
        """
        if isinstance(x, dict):
            x = {key: (item.to(self.device) if item.numel() else item) for key, item in x.items()}
        elif isinstance(x, list):
            x = [item.to(self.device) if item else item for item in x]
        else:
            x = x.to(self.device)
        return x

    @staticmethod
    def get_network_description(network: nn.Module) -> typing.Tuple[str, int]:
        """Gets the network name and parameters.

        Args:
            network: The neural network.

        Returns:
            Name of the network and the number of parameters.
        """
        if isinstance(network, nn.DataParallel):
            network = network.module
        n_params = sum(map(lambda x: x.numel(), network.parameters()))
        return str(network), n_params

--------------------------------------------------------------------------------
/model/ema.py:
--------------------------------------------------------------------------------
"""Defines Exponential Moving Average class for
model parameters.

The work is based on https://github.com/ermongroup/ddim/blob/main/models/ema.py.
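
At every update the shadow parameters follow
shadow = mu * shadow + (1 - mu) * param,
i.e. an exponential moving average of the live parameters with decay rate mu.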
5 | """ 6 | 7 | import torch.nn as nn 8 | 9 | 10 | class EMA(object): 11 | """An Exponential Moving Average class. 12 | 13 | Attributes: 14 | mu: IDs of gpus. 15 | shadow: The storage for parameter values. 16 | """ 17 | 18 | def __init__(self, mu=0.999): 19 | self.mu = mu 20 | self.shadow = {} 21 | 22 | def register(self, module): 23 | """Registers network parameters. 24 | 25 | Args: 26 | module: A parameter module, typically a neural network. 27 | """ 28 | if isinstance(module, nn.DataParallel): 29 | module = module.module 30 | for name, param in module.named_parameters(): 31 | if param.requires_grad: 32 | self.shadow[name] = param.data.clone() 33 | 34 | def update(self, module): 35 | """Updates parameters with a decay rate mu and stores in a storage. 36 | 37 | Args: 38 | module: A parameter module, typically a neural network. 39 | """ 40 | if isinstance(module, nn.DataParallel): 41 | module = module.module 42 | for name, param in module.named_parameters(): 43 | if param.requires_grad: 44 | self.shadow[name].data = (1. - self.mu) * param.data + self.mu * self.shadow[name].data 45 | 46 | def ema(self, module): 47 | """Updates network parameters from the storage. 48 | 49 | Args: 50 | module: A parameter module, typically a neural network. 51 | """ 52 | if isinstance(module, nn.DataParallel): 53 | module = module.module 54 | for name, param in module.named_parameters(): 55 | if param.requires_grad: 56 | param.data.copy_(self.shadow[name].data) 57 | 58 | def ema_copy(self, module): 59 | """Updates network parameters from the storage and returns a copy of it. 60 | 61 | Args: 62 | module: A parameter module, typically a neural network. 63 | 64 | Returns: 65 | A copy of network parameters. 66 | """ 67 | if isinstance(module, nn.DataParallel): 68 | inner_module = module.module 69 | module_copy = type(inner_module)( 70 | inner_module.config).to(inner_module.config.device) 71 | module_copy.load_state_dict(inner_module.state_dict()) 72 | module_copy = nn.DataParallel(module_copy) 73 | else: 74 | module_copy = type(module)(module.config).to(module.config.device) 75 | module_copy.load_state_dict(module.state_dict()) 76 | 77 | self.ema(module_copy) 78 | return module_copy 79 | 80 | def state_dict(self): 81 | """Returns current state of model parameters. 82 | 83 | Returns: 84 | Current state of model parameters stored in a local storage. 85 | """ 86 | return self.shadow 87 | 88 | def load_state_dict(self, state_dict): 89 | """Update local storage of parameters. 90 | 91 | Args: 92 | state_dict: A state of network parameters for updating local storage. 93 | """ 94 | self.shadow = state_dict 95 | -------------------------------------------------------------------------------- /model/model.py: -------------------------------------------------------------------------------- 1 | """Denoising Diffusion Probabilistic Model. 2 | 3 | Combines U-Net network with Denoising Diffusion Model and 4 | creates single image super-resolution solver architecture. 5 | 6 | The work is based on https://github.com/Janspiry/Image-Super-Resolution-via-Iterative-Refinement. 7 | """ 8 | import logging 9 | import os 10 | import typing 11 | from collections import OrderedDict 12 | 13 | import torch 14 | import torch.nn as nn 15 | from torch.optim.lr_scheduler import MultiStepLR 16 | 17 | from .base_model import BaseModel 18 | from .ema import EMA 19 | from .networks import define_network 20 | 21 | logger = logging.getLogger("base") 22 | 23 | 24 | class DDPM(BaseModel): 25 | """Denoising Diffusion Probabilistic Model. 
26 | 
27 |     Attributes:
28 |         in_channel: The number of channels of input tensor of U-Net.
29 |         out_channel: The number of channels of output tensor of U-Net.
30 |         norm_groups: The number of groups for group normalization.
31 |         inner_channel: Timestep embedding dimension.
32 |         channel_multiplier: A tuple specifying the scaling factors of channels.
33 |         attn_res: A tuple of spatial dimensions indicating in which resolutions to use self-attention layer.
34 |         res_blocks: The number of residual blocks.
35 |         dropout: Dropout probability.
36 |         diffusion_loss: Either l1 or l2.
37 |         conditional: Whether to condition on INTERPOLATED image or not.
38 |         gpu_ids: IDs of gpus.
39 |         distributed: Whether the computation will be distributed among multiple GPUs or not.
40 |         init_method: NN weight initialization method. One of normal, kaiming or orthogonal initializations.
41 |         train_schedule: Defines the type of beta schedule for training.
42 |         train_n_timestep: Number of diffusion timesteps for training.
43 |         train_linear_start: Minimum value of the linear schedule for training.
44 |         train_linear_end: Maximum value of the linear schedule for training.
45 |         val_schedule: Defines the type of beta schedule for validation.
46 |         val_n_timestep: Number of diffusion timesteps for validation.
47 |         val_linear_start: Minimum value of the linear schedule for validation.
48 |         val_linear_end: Maximum value of the linear schedule for validation.
49 |         finetune_norm: Whether to fine-tune or train from scratch.
50 |         optimizer: The optimization algorithm.
51 |         amsgrad: Whether to use the AMSGrad variant of optimizer.
52 |         learning_rate: The learning rate.
53 |         checkpoint: Path to the checkpoint file.
54 |         resume_state: The path to load the network.
55 |         phase: Either train or val.
56 |         height: U-Net input tensor height value.
57 |     """
58 | 
59 |     def __init__(self, in_channel, out_channel, norm_groups, inner_channel,
60 |                  channel_multiplier, attn_res, res_blocks, dropout,
61 |                  diffusion_loss, conditional, gpu_ids, distributed, init_method,
62 |                  train_schedule, train_n_timestep, train_linear_start, train_linear_end,
63 |                  val_schedule, val_n_timestep, val_linear_start, val_linear_end,
64 |                  finetune_norm, optimizer, amsgrad, learning_rate, checkpoint, resume_state,
65 |                  phase, height):
66 | 
67 |         super(DDPM, self).__init__(gpu_ids)
68 |         noise_predictor = define_network(in_channel, out_channel, norm_groups, inner_channel,
69 |                                          channel_multiplier, attn_res, res_blocks, dropout,
70 |                                          diffusion_loss, conditional, gpu_ids, distributed,
71 |                                          init_method, height)
72 |         self.SR_net = self.set_device(noise_predictor)
73 |         self.loss_type = diffusion_loss
74 |         self.data, self.SR = None, None
75 |         self.checkpoint = checkpoint
76 |         self.resume_state = resume_state
77 |         self.finetune_norm = finetune_norm
78 |         self.phase = phase
79 |         self.set_loss()
80 |         self.months = []  # A list of months of current data given by feed_data.
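        # Note: the noise schedule set below is phase-dependent; training and
        # validation may use different schedule types and step counts. The
        # concrete values (e.g. on the order of 1000-2000 training steps, a
        # common DDPM choice) come from the experiment configuration, so the
        # counts given here are only illustrative.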
81 | 
82 |         if self.phase == "train":
83 |             self.set_new_noise_schedule(schedule=train_schedule, n_timestep=train_n_timestep,
84 |                                         linear_start=train_linear_start, linear_end=train_linear_end)
85 |         else:
86 |             self.set_new_noise_schedule(schedule=val_schedule, n_timestep=val_n_timestep,
87 |                                         linear_start=val_linear_start, linear_end=val_linear_end)
88 | 
89 |         if self.phase == "train":
90 |             self.SR_net.train()
91 |             if self.finetune_norm:
92 |                 optim_params = []
93 |                 for k, v in self.SR_net.named_parameters():
94 |                     v.requires_grad = False
95 |                     if k.find("transformer") >= 0:
96 |                         v.requires_grad = True
97 |                         v.data.zero_()
98 |                         optim_params.append(v)
99 |                         logger.info(f"Params [{k:s}] initialized to 0 and will be fine-tuned.")
100 |             else:
101 |                 optim_params = list(self.SR_net.parameters())
102 | 
103 |             self.optimizer = optimizer(optim_params, lr=learning_rate, amsgrad=amsgrad)
104 | 
105 |             # Learning rate schedulers.
106 |             # self.scheduler = CosineAnnealingWarmRestarts(self.optimizer, T_0=40000, eta_min=1e-6)
107 |             self.scheduler = MultiStepLR(self.optimizer, milestones=[40000], gamma=0.5)
108 | 
109 |             self.ema = EMA(mu=0.9999)
110 |             self.ema.register(self.SR_net)
111 | 
112 |         self.log_dict = OrderedDict()
113 | 
114 |         self.load_network()
115 |         self.print_network()
116 | 
117 |     def feed_data(self, data: tuple) -> None:
118 |         """Stores data for feeding into the model and month indices for each tensor in a batch.
119 | 
120 |         Args:
121 |             data: A tuple containing a dictionary with the following keys:
122 |                 HR: a batch of high-resolution images [B, C, H, W],
123 |                 LR: a batch of low-resolution images [B, C, H, W],
124 |                 INTERPOLATED: a batch of upsampled (via interpolation) images [B, C, H, W]
125 |                 and a list of corresponding months of samples in a batch.
126 |         """
127 |         self.data, self.months = self.set_device(data[0]), data[1]
128 | 
129 |     def optimize_parameters(self) -> None:
130 |         """Computes loss and performs GD step on learnable parameters.
131 |         """
132 |         self.optimizer.zero_grad()
133 |         loss = self.SR_net(self.data)
134 |         loss = loss.sum() / self.data["HR"].numel()
135 |         loss.backward()
136 |         self.optimizer.step()
137 |         self.ema.update(self.SR_net)  # Exponential Moving Average step of parameters.
138 |         self.log_dict[self.loss_type] = loss.item()  # Setting the log.
139 | 
140 |     def lr_scheduler_step(self):
141 |         """Learning rate scheduler step.
142 |         """
143 |         # self.scheduler.step()
144 | 
145 |     def get_lr(self) -> float:
146 |         """Fetches current learning rate.
147 | 
148 |         Returns:
149 |             Current learning rate value.
150 |         """
151 |         return self.optimizer.param_groups[0]['lr']
152 | 
153 |     def get_named_parameters(self) -> dict:
154 |         """Fetches U-Net's parameters.
155 | 
156 |         Returns:
157 |             U-Net's parameters with their names.
158 |         """
159 |         return self.SR_net.named_parameters()
160 | 
161 |     def test(self, continuous: bool = False) -> None:
162 |         """Constructs the super-resolution image and assigns it to the SR attribute.
163 | 
164 |         Args:
165 |             continuous: Whether to return the SR images for all denoising timesteps or only the final one.
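
        Note:
            The network is temporarily switched to eval mode for sampling and
            back to train mode afterwards; the result is stored in the SR
            attribute rather than returned.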
166 | """ 167 | self.SR_net.eval() 168 | with torch.no_grad(): 169 | if isinstance(self.SR_net, nn.DataParallel): 170 | self.SR = self.SR_net.module.super_resolution(self.data["INTERPOLATED"], continuous) 171 | else: 172 | self.SR = self.SR_net.super_resolution(self.data["INTERPOLATED"], continuous) 173 | self.SR = self.SR.unsqueeze(0) if len(self.SR.size()) == 3 else self.SR 174 | 175 | self.SR_net.train() 176 | 177 | def generate_multiple_candidates(self, n: int = 10) -> torch.tensor: 178 | """Generates n super-resolution tesnors. 179 | 180 | Args: 181 | n: The number of candidates. 182 | 183 | Returns: 184 | n super-resolution tensors of shape [n, B, C, H, W] corresponding 185 | to data fed into the model. 186 | """ 187 | self.SR_net.eval() 188 | batch_size, c, h, w = self.data["INTERPOLATED"].size() 189 | sr_candidates = torch.empty(size=(n, batch_size, c, h, w)) 190 | with torch.no_grad(): 191 | for i in range(n): 192 | if isinstance(self.SR_net, nn.DataParallel): 193 | x_sr = self.SR_net.module.super_resolution(self.data["INTERPOLATED"], False).detach().float().cpu() 194 | else: 195 | x_sr = self.SR_net.super_resolution(self.data["INTERPOLATED"], False).detach().float().cpu() 196 | sr_candidates[i] = x_sr.unsqueeze(0) if len(x_sr.size()) == 3 else x_sr 197 | 198 | self.SR_net.train() 199 | return sr_candidates 200 | 201 | def set_loss(self) -> None: 202 | """Sets loss to a device. 203 | """ 204 | if isinstance(self.SR_net, nn.DataParallel): 205 | self.SR_net.module.set_loss(self.device) 206 | else: 207 | self.SR_net.set_loss(self.device) 208 | 209 | def set_new_noise_schedule(self, schedule, n_timestep, linear_start, linear_end) -> None: 210 | """Creates new noise scheduler. 211 | 212 | Args: 213 | schedule: Defines the type of beta schedule. 214 | n_timestep: Number of diffusion timesteps. 215 | linear_start: Minimum value of the linear schedule. 216 | linear_end: Maximum value of the linear schedule. 217 | """ 218 | if isinstance(self.SR_net, nn.DataParallel): 219 | self.SR_net.module.set_new_noise_schedule(schedule, n_timestep, linear_start, linear_end, self.device) 220 | else: 221 | self.SR_net.set_new_noise_schedule(schedule, n_timestep, linear_start, linear_end, self.device) 222 | 223 | def get_current_log(self) -> OrderedDict: 224 | """Returns the logs. 225 | 226 | Returns: 227 | log_dict: Current logs of the model. 228 | """ 229 | return self.log_dict 230 | 231 | def get_months(self) -> list: 232 | """Returns the list of month indices corresponding to batch of samples 233 | fed into the model with feed_data. 234 | 235 | Returns: 236 | months: Current list of months. 237 | """ 238 | return self.months 239 | 240 | def get_current_visuals(self, need_LR: bool = True, only_rec: bool = False) -> typing.OrderedDict: 241 | """Returns only reconstructed super-resolution image if only_rec is True (with "SAM" key), 242 | otherwise returns super-resolution image (with "SR" key), interpolated LR image 243 | (with "interpolated" key), HR image (with "HR" key), LR image (with "LR" key). 244 | 245 | Args: 246 | need_LR: Whether to return LR image or not. 247 | only_rec: Whether to return only reconstructed super-resolution image or not. 248 | 249 | Returns: 250 | Dict containing desired images. 
251 | """ 252 | out_dict = OrderedDict() 253 | if only_rec: 254 | out_dict["SR"] = self.SR.detach().float().cpu() 255 | else: 256 | out_dict["SR"] = self.SR.detach().float().cpu() 257 | out_dict["INTERPOLATED"] = self.data["INTERPOLATED"].detach().float().cpu() 258 | out_dict["HR"] = self.data["HR"].detach().float().cpu() 259 | if need_LR and "LR" in self.data: 260 | out_dict["LR"] = self.data["LR"].detach().float().cpu() 261 | return out_dict 262 | 263 | def print_network(self) -> None: 264 | """Prints the network architecture. 265 | """ 266 | s, n = self.get_network_description(self.SR_net) 267 | if isinstance(self.SR_net, nn.DataParallel): 268 | net_struc_str = "{} - {}".format(self.SR_net.__class__.__name__, self.SR_net.module.__class__.__name__) 269 | else: 270 | net_struc_str = "{}".format(self.SR_net.__class__.__name__) 271 | 272 | logger.info(f"U-Net structure: {net_struc_str}, with parameters: {n:,d}") 273 | logger.info(f"Architecture:\n{s}\n") 274 | 275 | def save_network(self, epoch: int, iter_step: int) -> None: 276 | """Saves the network checkpoint. 277 | 278 | Args: 279 | epoch: How many epochs has the model been trained. 280 | iter_step: How many iteration steps has the model been trained. 281 | """ 282 | gen_path = os.path.join(self.checkpoint, f"I{iter_step}_E{epoch}_gen.pth") 283 | opt_path = os.path.join(self.checkpoint, f"I{iter_step}_E{epoch}_opt.pth") 284 | 285 | network = self.SR_net.module if isinstance(self.SR_net, nn.DataParallel) else self.SR_net 286 | 287 | state_dict = network.state_dict() 288 | for key, param in state_dict.items(): 289 | state_dict[key] = param.cpu() 290 | torch.save(state_dict, gen_path) 291 | 292 | opt_state = {"epoch": epoch, "iter": iter_step, 293 | "scheduler": self.scheduler.state_dict(), 294 | "optimizer": self.optimizer.state_dict()} 295 | torch.save(opt_state, opt_path) 296 | logger.info("Saved model in [{:s}] ...".format(gen_path)) 297 | 298 | def load_network(self) -> None: 299 | """Loads the netowrk parameters. 300 | """ 301 | if self.resume_state is not None: 302 | logger.info(f"Loading pretrained model for G [{self.resume_state:s}] ...") 303 | gen_path, opt_path = f"{self.resume_state}_gen.pth", f"{self.resume_state}_opt.pth" 304 | 305 | network = self.SR_net.module if isinstance(self.SR_net, nn.DataParallel) else self.SR_net 306 | network.load_state_dict(torch.load(gen_path), strict=(not self.finetune_norm)) 307 | 308 | if self.phase == "train": 309 | opt = torch.load(opt_path) 310 | self.optimizer.load_state_dict(opt["optimizer"]) 311 | self.scheduler.load_state_dict(opt["scheduler"]) 312 | self.begin_step = opt["iter"] 313 | self.begin_epoch = opt["epoch"] 314 | -------------------------------------------------------------------------------- /model/modules/diffusion.py: -------------------------------------------------------------------------------- 1 | """Gaussian denoising model. 2 | 3 | Model gets an image from data and adds noise step by step. Then the 4 | model is trained to predict that noise at each step. Later, it 5 | can be used to denoise images. 6 | 7 | The work is based on https://github.com/Janspiry/Image-Super-Resolution-via-Iterative-Refinement. 8 | """ 9 | import math 10 | from functools import partial 11 | from typing import Union, Tuple 12 | 13 | import numpy as np 14 | import torch 15 | from torch import nn 16 | 17 | 18 | def _warmup_beta(linear_start: float, linear_end: float, 19 | n_timestep: int, warmup_frac: float) -> np.ndarray: 20 | """Computes linear beta schedule using warmup fraction. 
21 | 
22 |     Args:
23 |         linear_start: Minimum value of the schedule.
24 |         linear_end: Maximum value of the schedule.
25 |         n_timestep: Number of diffusion timesteps.
26 |         warmup_frac: The portion of timesteps that the scheduler requires to go from start to end.
27 |     Returns:
28 |         Beta values for each timestep starting from 1 to n_timestep.
29 |     """
30 |     betas = linear_end * np.ones(n_timestep, dtype=np.float64)
31 |     warmup_time = int(n_timestep * warmup_frac)
32 |     betas[:warmup_time] = np.linspace(linear_start, linear_end, warmup_time, dtype=np.float64)
33 |     return betas
34 | 
35 | 
36 | def make_beta_schedule(schedule: str, n_timestep: int, linear_start: float = 1e-4,
37 |                        linear_end: float = 2e-2, cosine_s: float = 8e-3) -> \
38 |         Union[np.ndarray, torch.Tensor]:
39 |     """Defines the Gaussian noise variance beta schedule that is gradually added
40 |     to the data during the diffusion process.
41 | 
42 |     Args:
43 |         schedule: Defines the type of beta schedule. Possible types are const,
44 |             linear, warmup10, warmup50, quad, jsd and cosine.
45 |         n_timestep: Number of diffusion timesteps.
46 |         linear_start: Minimum value of the linear schedule.
47 |         linear_end: Maximum value of the linear schedule.
48 |         cosine_s: An offset to prevent beta from being too small at timestep 0.
49 | 
50 |     Returns:
51 |         Beta values for each timestep starting from 1 to n_timestep.
52 |     """
53 |     if schedule == "const":  # Constant beta schedule.
54 |         betas = linear_end * np.ones(n_timestep, dtype=np.float64)
55 |     elif schedule == "linear":  # Linear beta schedule.
56 |         betas = np.linspace(linear_start, linear_end,
57 |                             n_timestep, dtype=np.float64)
58 |     elif schedule == "warmup10":  # Linear beta schedule with warmup fraction of 0.10.
59 |         betas = _warmup_beta(linear_start, linear_end,
60 |                              n_timestep, 0.1)
61 |     elif schedule == "warmup50":  # Linear beta schedule with warmup fraction of 0.50.
62 |         betas = _warmup_beta(linear_start, linear_end,
63 |                              n_timestep, 0.5)
64 |     elif schedule == "quad":  # Quadratic beta schedule.
65 |         betas = np.linspace(linear_start ** 0.5, linear_end ** 0.5,
66 |                             n_timestep, dtype=np.float64) ** 2
67 |     elif schedule == "jsd":  # Multiplicative inverse beta schedule: 1/T, 1/(T-1), 1/(T-2), ..., 1.
68 |         betas = 1. / np.linspace(n_timestep,
69 |                                  1, n_timestep, dtype=np.float64)
70 |     elif schedule == "cosine":  # Cosine beta schedule [formula 17, arXiv:2102.09672].
71 |         timesteps = (torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s)
72 |         alphas = timesteps / (1 + cosine_s) * math.pi / 2
73 |         alphas = torch.cos(alphas).pow(2)
74 |         alphas = alphas / alphas[0]
75 |         betas = 1 - alphas[1:] / alphas[:-1]
76 |         betas = betas.clamp(max=0.999)
77 |     else:
78 |         raise NotImplementedError(schedule)
79 |     return betas
80 | 
81 | 
82 | class GaussianDiffusion(nn.Module):
83 |     """Gaussian Diffusion Probabilistic model.
84 | 
85 |     Attributes:
86 |         denoise_net: U-Net.
87 |         loss_type: Loss function, either l1 or l2.
88 |         conditional: Whether to condition the model on something or not (typically the INTERPOLATED image).
89 |     """
90 |     def __init__(self, denoise_net: nn.Module,
91 |                  loss_type: str = "l2", conditional: bool = True):
92 |         super().__init__()
93 |         self.denoise_net = denoise_net
94 |         self.loss_type = loss_type
95 |         self.conditional = conditional
96 |         self.loss_func = None
97 |         self.sqrt_alphas_cumprod_prev = None
98 |         self.num_timesteps = None
99 | 
100 |     def set_loss(self, device: torch.device):
101 |         """Sets a loss function.
102 | 
103 |         Args:
104 |             device: A torch.device object.
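
        Note:
            Both losses use reduction="sum", i.e. the loss is summed over all
            elements; the training loop later normalizes it by the number of
            elements of the HR batch (see DDPM.optimize_parameters).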
105 | """ 106 | if self.loss_type == "l1": 107 | self.loss_func = nn.L1Loss(reduction="sum").to(device) # L1 loss. 108 | elif self.loss_type == "l2": 109 | self.loss_func = nn.MSELoss(reduction="sum").to(device) # Squared L2 loss. 110 | else: 111 | raise NotImplementedError("Specify loss_type attribute to be either \'l1\' or \'l2\'.") 112 | 113 | def set_new_noise_schedule(self, schedule, n_timestep, linear_start, linear_end, device): 114 | """Sets a new beta schedule. 115 | 116 | Args: 117 | schedule: Defines the type of beta schedule. Possible types are const, linear, warmup10, warmup50, quad, 118 | jsd and cosine. 119 | n_timestep: Number of diffusion timesteps. 120 | linear_start: Minimum value of the linear schedule. 121 | linear_end: Maximum value of the linear schedule. 122 | device: A torch.device object. 123 | """ 124 | # Defining a partial fundtion that converts data into type of float32 and moves it onto the specified device. 125 | to_torch = partial(torch.tensor, dtype=torch.float32, device=device) 126 | 127 | betas = make_beta_schedule(schedule=schedule, n_timestep=n_timestep, 128 | linear_start=linear_start, linear_end=linear_end) 129 | 130 | betas = betas.detach().cpu().numpy() if isinstance(betas, torch.Tensor) else betas 131 | alphas = 1. - betas 132 | alphas_cumprod = np.cumprod(alphas, axis=0) 133 | alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) 134 | self.sqrt_alphas_cumprod_prev = np.sqrt(np.append(1., alphas_cumprod)) 135 | 136 | timesteps, = betas.shape 137 | self.num_timesteps = int(timesteps) 138 | 139 | # Storing parameters into state dict of model. 140 | self.register_buffer("betas", to_torch(betas)) 141 | self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) 142 | self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev)) 143 | 144 | # Calculating constants for reverse conditional posterior distribution q(x_{t-1} | x_t, x_0). 145 | self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod))) 146 | self.register_buffer("sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1. - alphas_cumprod))) 147 | self.register_buffer("log_one_minus_alphas_cumprod", to_torch(np.log(1. - alphas_cumprod))) 148 | self.register_buffer("sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1. / alphas_cumprod))) 149 | self.register_buffer("sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1. / alphas_cumprod - 1))) 150 | 151 | posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) # Formula 7, arXiv:2006.11239. 152 | self.register_buffer("posterior_variance", to_torch(posterior_variance)) 153 | 154 | # Clipping the minimum log value of posterior variance to be 1e-20 as posterior variance is 0 at timestep 0. 155 | self.register_buffer("posterior_log_variance_clipped", to_torch(np.log(np.maximum(posterior_variance, 1e-20)))) 156 | 157 | # Calculating the coefficients of the mean q(x_{t-1} | x_t, x_0) [formula 7, arXiv:2006.11239]. 158 | self.register_buffer("posterior_mean_coef1", 159 | to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))) 160 | self.register_buffer("posterior_mean_coef2", 161 | to_torch((1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod))) 162 | 163 | def predict_start_from_noise(self, x_t: torch.Tensor, t: int, noise: torch.Tensor) -> torch.Tensor: 164 | """Calculates x_0 from x_t and Gaussian standard noise by applying reparametrization 165 | trick to the formula 4 [arXiv:2006.11239]. 
166 | 
167 |         Args:
168 |             x_t: Data point of size [B, C, H, W] after t diffusion steps.
169 |             t: The diffusion timestep.
170 |             noise: Gaussian standard noise of size [B, C, H, W].
171 | 
172 |         Returns:
173 |             Starting data point x_0 of size [B, C, H, W].
174 |         """
175 |         return self.sqrt_recip_alphas_cumprod[t] * x_t - self.sqrt_recipm1_alphas_cumprod[t] * noise
176 | 
177 |     def q_posterior(self, x_start: torch.Tensor, x_t: torch.Tensor, t: int) -> Tuple[torch.Tensor, torch.Tensor]:
178 |         """Computes mean and log variance of q(x_{t-1} | x_t, x_0) using formula 7 [arXiv:2006.11239].
179 | 
180 |         Args:
181 |             x_start: Starting data point of size [B, C, H, W].
182 |             x_t: Data point of size [B, C, H, W] after t diffusion steps.
183 |             t: The diffusion timestep.
184 | 
185 |         Returns:
186 |             Mean and log variance of the reverse conditional posterior distribution.
187 |             posterior_mean: Size of [B, C, H, W].
188 |             posterior_log_variance_clipped: Scalar.
189 |         """
190 |         posterior_mean = self.posterior_mean_coef1[t] * x_start + self.posterior_mean_coef2[t] * x_t
191 |         posterior_log_variance_clipped = self.posterior_log_variance_clipped[t]
192 |         return posterior_mean, posterior_log_variance_clipped
193 | 
194 |     def p_mean_variance(self, x: torch.Tensor, t: int, clip_denoised: bool,
195 |                         condition_x: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:
196 |         """Computes mean and log variance of q(x_{t-1} | x_t, x_0) from an arbitrary noisy point x at timestep t.
197 | 
198 |         Args:
199 |             x: Noisy data point at timestep t of size [B, C, H, W].
200 |             t: The diffusion timestep.
201 |             clip_denoised: Whether to clip the reconstructed starting data point to [-1, 1].
202 |             condition_x: The conditioning tensor of size [B, C, H, W], typically the upscaled LR image.
203 | 
204 |         Returns:
205 |             Mean and log variance of the reverse conditional posterior distribution.
206 |             model_mean: Size of [B, C, H, W].
207 |             posterior_log_variance: Scalar.
208 |         """
209 |         batch_size = x.shape[0]
210 |         noise_level = torch.FloatTensor([self.sqrt_alphas_cumprod_prev[t+1]]).repeat(batch_size, 1).to(x.device)
211 | 
212 |         if condition_x is not None:
213 |             x_recon = self.predict_start_from_noise(
214 |                 x, t=t, noise=self.denoise_net(torch.cat([condition_x, x], dim=1), noise_level))
215 |         else:
216 |             x_recon = self.predict_start_from_noise(x, t=t, noise=self.denoise_net(x, noise_level))
217 | 
218 |         if clip_denoised:
219 |             x_recon.clamp_(-1., 1.)
220 | 
221 |         model_mean, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
222 |         return model_mean, posterior_log_variance
223 | 
224 |     @torch.no_grad()
225 |     def p_sample(self, x: torch.tensor, t: int, clip_denoised: bool = True,
226 |                  condition_x: torch.Tensor = None) -> torch.Tensor:
227 |         """Defines a single sampling step, i.e. sampling from p(x_{t-1} | x_t).
228 | 
229 |         Args:
230 |             x: Noisy data point at timestep t of size [B, C, H, W].
231 |             t: The diffusion timestep.
232 |             clip_denoised: Whether to clip the reconstructed starting data point to [-1, 1].
233 |             condition_x: The conditioning tensor of size [B, C, H, W]. Typically the upscaled LR image.
234 | 
235 |         Returns:
236 |             Sampled denoised data point at timestep t-1 of size [B, C, H, W].
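
        Note:
            The sample is drawn as mean + exp(0.5 * log_variance) * z with
            z ~ N(0, I) for t > 0; at t = 0 no noise is added.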
237 | """ 238 | model_mean, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised, 239 | condition_x=condition_x) 240 | noise = torch.randn_like(x) if t > 0 else torch.zeros_like(x) 241 | return model_mean + noise * (0.5 * model_log_variance).exp() 242 | 243 | @torch.no_grad() 244 | def p_sample_loop(self, x_in: torch.Tensor, continuous: bool = False) -> torch.Tensor: 245 | """Implements the sampling algorithms [algorithm 2, arXiv:2006.11239]. 246 | 247 | Args: 248 | x_in: Input noisy data point of size [B, C, H, W]. 249 | continuous: Either to return all the SR images for each denoising timestep or not. 250 | 251 | Returns: 252 | Sampled denoised data point of size [C, H, W]. 253 | """ 254 | sample_inter = 10 # Frequency of keeping denoised images during reverse 255 | # diffusion process. 256 | batch_size = x_in.size(0) 257 | if not self.conditional: 258 | shape = x_in 259 | img = torch.randn(shape, device=self.betas.device) 260 | ret_img = img 261 | 262 | for i in reversed(range(0, self.num_timesteps)): # self.num_timesteps-1, self.num_timesteps-2, ..., 0 263 | img = self.p_sample(img, i) 264 | if i % sample_inter == 0: 265 | ret_img = torch.cat([ret_img, img], dim=0) 266 | else: 267 | x = x_in 268 | shape = x.shape 269 | img = torch.randn(shape, device=self.betas.device) # 1st step of the algorithm. 270 | ret_img = img 271 | for t in reversed(range(0, self.num_timesteps)): 272 | # By specifying condition_x argument to be input image x, U-Net input 273 | # is constructed by concatenating upsampled LR image with the noisy 274 | # high resolution reconstructed image at current step t. 275 | img = self.p_sample(img, t, condition_x=x) # 3rd and 4th steps. 276 | if t % sample_inter == 0: 277 | ret_img = torch.cat([ret_img, img], dim=0) 278 | 279 | if continuous: 280 | return ret_img 281 | else: 282 | return ret_img[-batch_size:] 283 | 284 | @torch.no_grad() 285 | def super_resolution(self, x_in: torch.Tensor, continuous: bool = False) -> torch.Tensor: 286 | """Denoises the given input data x_in. 287 | 288 | Args: 289 | x_in: A noisy data point of size [B, C, H, W]. Typically upscaled LR image. 290 | continuous: Either to return all the SR images for each denoising timestep or not. 291 | 292 | Returns: 293 | Denoised data point of size [B, C, H, W]. 294 | """ 295 | return self.p_sample_loop(x_in, continuous) 296 | 297 | @staticmethod 298 | def q_sample(x_start: torch.Tensor, continuous_sqrt_alpha_cumprod: torch.Tensor, 299 | noise: torch.Tensor = None) -> torch.Tensor: 300 | """Sampling from q(x_t | x_0) [formula 4, arXiv:2006.11239]. 301 | 302 | Args: 303 | x_start: Starting data point x_0 of size [B, C, H, W]. Often HR image. 304 | continuous_sqrt_alpha_cumprod: Square root of the product of alphas of size [B, 1, 1, 1]. 305 | noise: Gaussian standard noise of the same size as x_start. 306 | 307 | Returns: 308 | Sampled noisy point of size [B, C, H, W]. 309 | """ 310 | if noise is None: 311 | noise = torch.randn_like(x_start) 312 | return continuous_sqrt_alpha_cumprod * x_start + (1 - continuous_sqrt_alpha_cumprod**2).sqrt() * noise 313 | 314 | def p_losses(self, x_in: dict, noise: torch.Tensor = None) -> torch.Tensor: 315 | """Computes loss function. 316 | 317 | Args: 318 | x_in: A dictionary containing the following keys: 319 | HR: a batch of high-resolution images [B, C, H, W]. 320 | SR: a batch of upsampled (via interpolation) images [B, C, H, W]. 321 | Index: indices of samples of a batch in the dataset [B]. 
322 |             noise: Gaussian standard noise of size [B, C, H, W].
323 | 
324 |         Returns:
325 |             Loss function value.
326 |         """
327 |         x_start = x_in["HR"]
328 |         b = x_start.shape[0]  # Dimension of x_start is (B, C, H, W).
329 | 
330 |         # Using a piecewise Uniform distribution to sample gammas.
331 |         # See the definition of gamma in formula 3 of paper [arXiv:2104.07636] and section 2.4 for
332 |         # its sampling strategy p(gamma).
333 |         # continuous_sqrt_alpha_cumprod is equal to the square root of gamma.
334 |         t = np.random.randint(1, self.num_timesteps + 1)  # Randomly sampling a diffusion timestep.
335 |         continuous_sqrt_alpha_cumprod = torch.FloatTensor(np.random.uniform(self.sqrt_alphas_cumprod_prev[t-1],
336 |                                                                             self.sqrt_alphas_cumprod_prev[t],
337 |                                                                             size=b)
338 |                                                           ).to(x_start.device)
339 |         continuous_sqrt_alpha_cumprod = continuous_sqrt_alpha_cumprod.view(b, -1)
340 | 
341 |         if noise is None:
342 |             noise = torch.randn_like(x_start)
343 | 
344 |         # Diffusion process: the HR image is corrupted to get the noisy image.
345 |         x_noisy = self.q_sample(x_start=x_start,
346 |                                 continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod.view(-1, 1, 1, 1),
347 |                                 noise=noise)
348 | 
349 |         # U-Net predicts the Gaussian noise used to corrupt the HR image in the diffusion process.
350 |         if not self.conditional:
351 |             noise_reconstructed = self.denoise_net(x_noisy, continuous_sqrt_alpha_cumprod)
352 |         else:
353 |             # Conditioning on the interpolated LR image called INTERPOLATED.
354 |             noise_reconstructed = self.denoise_net(torch.cat([x_in["INTERPOLATED"], x_noisy], dim=1),
355 |                                                    continuous_sqrt_alpha_cumprod)
356 | 
357 |         loss = self.loss_func(noise, noise_reconstructed)  # Penalizing the deviation of the predicted noise from the true one.
358 |         return loss
359 | 
360 |     def forward(self, x: dict, *args, **kwargs) -> torch.Tensor:
361 |         """Forward pass.
362 | 
363 |         Args:
364 |             x: A dictionary containing the following keys:
365 |                 HR: a batch of high-resolution images [B, C, H, W],
366 |                 INTERPOLATED: a batch of upsampled (via interpolation) images [B, C, H, W],
367 |                 Index: indices of samples of a batch in the dataset [B].
368 | 
369 |         Returns:
370 |             Loss function value.
371 |         """
372 |         return self.p_losses(x, *args, **kwargs)
--------------------------------------------------------------------------------
/model/modules/unet.py:
--------------------------------------------------------------------------------
1 | """U-Net model for Denoising Diffusion Probabilistic Model.
2 | 
3 | This implementation contains a number of modifications to
4 | the original U-Net (residual blocks, multi-head attention)
5 | and also adds diffusion timestep embeddings t.
6 | 
7 | The work is based on https://github.com/Janspiry/Image-Super-Resolution-via-Iterative-Refinement.
8 | """
9 | import math
10 | 
11 | import torch
12 | from torch import nn
13 | 
14 | 
15 | class PositionalEncoding(nn.Module):
16 |     """Sinusoidal Positional Encoding component.
17 | 
18 |     Attributes:
19 |         dim: Embedding dimension.
20 |     """
21 | 
22 |     def __init__(self, dim):
23 |         super().__init__()
24 |         self.dim = dim
25 | 
26 |     def forward(self, noise_level):
27 |         """Computes the sinusoidal positional encodings.
28 | 
29 |         Args:
30 |             noise_level: An array of size [B, 1] representing the diffusion timesteps.
31 | 
32 |         Returns:
33 |             Positional encodings of size [B, 1, D].
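
        Note:
            For an embedding dimension D, entry k of the first half equals
            sin(noise_level * 1e4 ** (-k / (D / 2))) and the second half holds
            the matching cosines, following the Transformer positional encoding.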
34 | """ 35 | half_dim = self.dim // 2 36 | step = torch.arange(half_dim, dtype=noise_level.dtype, device=noise_level.device) / half_dim 37 | encoding = noise_level.unsqueeze(1) * torch.exp(-math.log(1e4) * step.unsqueeze(0)) 38 | encoding = torch.cat([torch.sin(encoding), torch.cos(encoding)], dim=-1) 39 | return encoding 40 | 41 | 42 | class FeatureWiseAffine(nn.Module): 43 | """Transformes timestep embeddings and injects it into input tensor. 44 | 45 | Attributes: 46 | in_channels: Input tensor channels. 47 | out_channels: Output tensor channels. 48 | use_affine_level: Whether to apply an affine transformation on input or add a noise. 49 | """ 50 | 51 | def __init__(self, in_channels: int, out_channels: int, use_affine_level: bool = False): 52 | super().__init__() 53 | self.use_affine_level = use_affine_level 54 | self.noise_func = nn.Linear(in_channels, out_channels * (1+self.use_affine_level)) 55 | 56 | def forward(self, x, time_emb): 57 | """Forward pass. 58 | 59 | Args: 60 | x: Input tensor of size [B, D, H, W]. 61 | time_emb: Timestep embeddings of size [B, 1, D] where D is the dimension of embedding. 62 | 63 | Returns: 64 | Transformed tensor of size [B, D, H, W]. 65 | """ 66 | batch_size = x.shape[0] 67 | if self.use_affine_level: 68 | gamma, beta = self.noise_func(time_emb).view(batch_size, -1, 1, 1).chunk(2, dim=1) 69 | # The size of gamma and beta is (batch_size, out_channels, 1, 1). 70 | x = (1 + gamma) * x + beta 71 | else: 72 | x = x + self.noise_func(time_emb).view(batch_size, -1, 1, 1) 73 | return x 74 | 75 | 76 | class Upsample(nn.Module): 77 | """Scales the feature map by a factor of 2, i.e. upscale the feature map. 78 | 79 | Attributes: 80 | dim: Input/output tensor channels. 81 | """ 82 | 83 | def __init__(self, dim: int): 84 | super().__init__() 85 | self.up = nn.Upsample(scale_factor=2, mode="bicubic") 86 | self.conv = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=3, padding=1) 87 | 88 | def forward(self, x): 89 | """Upscales the spatial dimensions of the input tensor two times. 90 | 91 | Args: 92 | x: Input tensor of size [B, 8*D, H, W]. 93 | 94 | Returns: 95 | Upscaled tensor of size [B, 8*D, 2*H, 2*W]. 96 | """ 97 | return self.conv(self.up(x)) 98 | 99 | 100 | class Downsample(nn.Module): 101 | """Scale the feature map by a factor of 1/2, i.e. downscale the feature map. 102 | 103 | Attributes: 104 | dim: Input/output tensor channels. 105 | """ 106 | 107 | def __init__(self, dim: int): 108 | super().__init__() 109 | self.conv = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=3, stride=2, padding=1) 110 | 111 | def forward(self, x): 112 | """Downscales the spatial dimensions of the input tensor two times. 113 | 114 | Args: 115 | x: Input tensor of size [B, D, H, W]. 116 | 117 | Returns: 118 | Downscaled tensor of size [B, D, H/2, W/2]. 119 | """ 120 | return self.conv(x) 121 | 122 | 123 | class Block(nn.Module): 124 | """A building component of Residual block. 125 | 126 | Attributes: 127 | dim: Input tensor channels. 128 | dim_out: Output tensor channels. 129 | groups: Number of groups to separate the channels into. 130 | dropout: Dropout probability. 
131 | """ 132 | 133 | def __init__(self, dim: int, dim_out: int, groups: int = 32, dropout: float = 0): 134 | super().__init__() 135 | self.block = nn.Sequential(nn.GroupNorm(num_groups=groups, num_channels=dim), 136 | nn.SiLU(), 137 | nn.Dropout2d(dropout) if dropout != 0 else nn.Identity(), 138 | nn.Conv2d(in_channels=dim, out_channels=dim_out, kernel_size=3, padding=1)) 139 | 140 | def forward(self, x): 141 | """Applies block transformations on input tensor. 142 | 143 | Args: 144 | x: Input tensor of size [B, D, H, W]. 145 | 146 | Returns: 147 | Transformed tensor of size [B, D, H, W]. 148 | """ 149 | return self.block(x) 150 | 151 | 152 | class ResnetBlock(nn.Module): 153 | """Residual block. 154 | 155 | Attributes: 156 | dim: Input tensor channels. 157 | dim_out: Output tensor channels. 158 | noise_level_emb_dim: Timestep embedding dimension. 159 | dropout: Dropout probability. 160 | use_affine_level: Whether to apply an affine transformation on input or add a noise. 161 | norm_groups: The number of groups for group normalization. 162 | """ 163 | 164 | def __init__(self, dim: int, dim_out: int, noise_level_emb_dim: int = None, dropout: float = 0, 165 | use_affine_level: bool = False, norm_groups: int = 32): 166 | super().__init__() 167 | self.noise_func = FeatureWiseAffine(in_channels=noise_level_emb_dim, out_channels=dim_out, 168 | use_affine_level=use_affine_level) 169 | self.block1 = Block(dim, dim_out, groups=norm_groups) 170 | self.block2 = Block(dim_out, dim_out, groups=norm_groups, dropout=dropout) 171 | self.res_conv = nn.Conv2d(in_channels=dim, out_channels=dim_out, kernel_size=1) \ 172 | if dim != dim_out else nn.Identity() 173 | 174 | def forward(self, x, time_emb): 175 | """Applied Residual block on input tensors. 176 | 177 | Args: 178 | x: Input tensor of size [B, D, H, W]. 179 | time_emb: Timestep embeddings of size [B, 1, D] where D is the dimension of embedding. 180 | 181 | Returns: 182 | Transformed tensor of size [B, D, H, W]. 183 | """ 184 | h = self.block1(x) 185 | h = self.noise_func(h, time_emb) 186 | h = self.block2(h) 187 | return h + self.res_conv(x) 188 | 189 | 190 | class SelfAttention(nn.Module): 191 | """Multi-head attention. 192 | 193 | Attributes: 194 | in_channel: Input tensor channels. 195 | n_head: The number of heads in multi-head attention. 196 | norm_groups: The number of groups for group normalization. 197 | """ 198 | 199 | def __init__(self, in_channel: int, n_head: int = 1, norm_groups: int = 32): 200 | super().__init__() 201 | 202 | self.n_head = n_head 203 | self.norm = nn.GroupNorm(norm_groups, in_channel) 204 | self.qkv = nn.Conv2d(in_channels=in_channel, out_channels=3*in_channel, kernel_size=1, bias=False) 205 | self.out = nn.Conv2d(in_channels=in_channel, out_channels=in_channel, kernel_size=1) 206 | 207 | def forward(self, x): 208 | """Applies self-attention to input tensor. 209 | 210 | Args: 211 | x: Input tensor of size [B, 8*D, H, W]. 212 | 213 | Returns: 214 | Transformed tensor of size [B, 8*D, H, W]. 
215 | """ 216 | batch_size, channel, height, width = x.shape 217 | head_dim = channel // self.n_head 218 | 219 | norm = self.norm(x) 220 | qkv = self.qkv(norm).view(batch_size, self.n_head, head_dim * 3, height, width) 221 | query, key, value = qkv.chunk(3, dim=2) 222 | 223 | attn = torch.einsum("bnchw, bncyx -> bnhwyx", query, key).contiguous() / math.sqrt(channel) 224 | attn = attn.view(batch_size, self.n_head, height, width, -1) 225 | attn = torch.softmax(attn, -1) 226 | attn = attn.view(batch_size, self.n_head, height, width, height, width) 227 | 228 | out = torch.einsum("bnhwyx, bncyx -> bnchw", attn, value).contiguous() 229 | out = self.out(out.view(batch_size, channel, height, width)) 230 | 231 | return out + x 232 | 233 | 234 | class ResnetBlocWithAttn(nn.Module): 235 | """ResnetBlock combined with sefl-attention layer. 236 | 237 | Attributes: 238 | dim: Input tensor channels. 239 | dim_out: Output tensor channels. 240 | noise_level_emb_dim: Timestep embedding dimension. 241 | norm_groups: The number of groups for group normalization. 242 | dropout: Dropout probability. 243 | with_attn: Whether to add self-attention layer or not. 244 | """ 245 | 246 | def __init__(self, dim: int, dim_out: int, *, noise_level_emb_dim: int = None, 247 | norm_groups: int = 32, dropout: float = 0, with_attn: bool = True): 248 | super().__init__() 249 | self.res_block = ResnetBlock(dim, dim_out, noise_level_emb_dim, norm_groups=norm_groups, dropout=dropout) 250 | self.attn = SelfAttention(dim_out, norm_groups=norm_groups) if with_attn else nn.Identity() 251 | 252 | def forward(self, x, time_emb): 253 | """Forward pass. 254 | 255 | Args: 256 | x: Input tensor of size [B, D, H, W]. 257 | time_emb: Timestep embeddings of size [B, 1, D] where D is the dimension of embedding. 258 | 259 | Returns: 260 | Transformed tensor of size [B, D, H, W]. 261 | """ 262 | x = self.res_block(x, time_emb) 263 | x = self.attn(x) 264 | return x 265 | 266 | 267 | class UNet(nn.Module): 268 | """Defines U-Net network. 269 | 270 | Attributes: 271 | in_channel: Input tensor channels. 272 | out_channel: Output tensor channels. 273 | inner_channel: Timestep embedding dimension. 274 | norm_groups: The number of groups for group normalization. 275 | channel_mults: A tuple specifying the scaling factors of channels. 276 | attn_res: A tuple of spatial dimensions indicating in which resolutions to use self-attention layer. 277 | res_blocks: The number of residual blocks. 278 | dropout: Dropout probability. 279 | with_noise_level_emb: Whether to apply timestep encodings or not. 280 | height: Height of input tensor. 
281 | """ 282 | 283 | def __init__(self, in_channel: int, out_channel: int, inner_channel: int, 284 | norm_groups: int, channel_mults: tuple, attn_res: tuple, 285 | res_blocks: int, dropout: float, with_noise_level_emb: bool = True, height: int = 128): 286 | super().__init__() 287 | 288 | if with_noise_level_emb: 289 | noise_level_channel = inner_channel 290 | 291 | # Time embedding layer that returns 292 | self.time_embedding = nn.Sequential(PositionalEncoding(inner_channel), 293 | nn.Linear(inner_channel, 4*inner_channel), 294 | nn.SiLU(), 295 | nn.Linear(4*inner_channel, inner_channel)) 296 | else: 297 | noise_level_channel, self.time_embedding = None, None 298 | 299 | num_mults = len(channel_mults) 300 | pre_channel = inner_channel 301 | feat_channels = [pre_channel] 302 | current_height = height 303 | downs = [nn.Conv2d(in_channel, inner_channel, kernel_size=3, padding=1)] 304 | 305 | for ind in range(num_mults): # For each channel growing factor. 306 | is_last = (ind == num_mults - 1) 307 | 308 | use_attn = current_height in attn_res 309 | channel_mult = inner_channel * channel_mults[ind] 310 | 311 | for _ in range(res_blocks): # Add res_blocks number of ResnetBlocWithAttn layer. 312 | downs.append(ResnetBlocWithAttn(pre_channel, channel_mult, noise_level_emb_dim=noise_level_channel, 313 | norm_groups=norm_groups, dropout=dropout, with_attn=use_attn)) 314 | feat_channels.append(channel_mult) 315 | pre_channel = channel_mult 316 | 317 | # If the newly added ResnetBlocWithAttn layer to downs list is not the last one, 318 | # then add a Downsampling layer. 319 | if not is_last: 320 | downs.append(Downsample(pre_channel)) 321 | feat_channels.append(pre_channel) 322 | current_height //= 2 323 | 324 | self.downs = nn.ModuleList(downs) 325 | self.mid = nn.ModuleList([ResnetBlocWithAttn(pre_channel, pre_channel, noise_level_emb_dim=noise_level_channel, 326 | norm_groups=norm_groups, dropout=dropout), 327 | ResnetBlocWithAttn(pre_channel, pre_channel, noise_level_emb_dim=noise_level_channel, 328 | norm_groups=norm_groups, dropout=dropout, with_attn=False)]) 329 | 330 | ups = [] 331 | for ind in reversed(range(num_mults)): # For each channel growing factor (in decreasing order). 332 | is_last = (ind < 1) 333 | use_attn = (current_height in attn_res) 334 | channel_mult = inner_channel * channel_mults[ind] 335 | 336 | for _ in range(res_blocks+1): # Add res_blocks+1 number of ResnetBlocWithAttn layer. 337 | ups.append(ResnetBlocWithAttn(pre_channel+feat_channels.pop(), channel_mult, 338 | noise_level_emb_dim=noise_level_channel, norm_groups=norm_groups, 339 | dropout=dropout, with_attn=use_attn)) 340 | pre_channel = channel_mult 341 | 342 | # If the newly added ResnetBlocWithAttn layer to ups list is not the last one, 343 | # then add an Upsample layer. 344 | if not is_last: 345 | ups.append(Upsample(pre_channel)) 346 | current_height *= 2 347 | 348 | self.ups = nn.ModuleList(ups) 349 | 350 | # Final convolution layer to transform the spatial dimensions to the desired shapes. 351 | self.final_conv = Block(pre_channel, out_channel if out_channel else in_channel, groups=norm_groups) 352 | 353 | def forward(self, x, time): 354 | """Forward pass. 355 | 356 | Args: 357 | x: Input tensor of size: [B, C, H, W], for WeatherBench C=2. 358 | time: Diffusion timesteps of size: [B, 1]. 359 | 360 | Returns: 361 | Estimation of Gaussian noise. 
362 | """ 363 | t = self.time_embedding(time) if self.time_embedding else None # [B, 1, D] 364 | feats = [] 365 | 366 | for layer in self.downs: 367 | x = layer(x, t) if isinstance(layer, ResnetBlocWithAttn) else layer(x) 368 | feats.append(x) 369 | 370 | for layer in self.mid: 371 | x = layer(x, t) if isinstance(layer, ResnetBlocWithAttn) else layer(x) 372 | 373 | for layer in self.ups: 374 | x = layer(torch.cat((x, feats.pop()), dim=1), t) if isinstance(layer, ResnetBlocWithAttn) else layer(x) 375 | 376 | return self.final_conv(x) 377 | -------------------------------------------------------------------------------- /model/networks.py: -------------------------------------------------------------------------------- 1 | """Declares network weight initialization functions and a function 2 | to define final single image super-resolution solver architecture. 3 | 4 | Implements neural netowrk weight initialization methods such as 5 | normal, kaiming and orthogonal. Defines a function that 6 | creates a returns a network to train on single image 7 | super-resolution task. 8 | 9 | The work is based on https://github.com/Janspiry/Image-Super-Resolution-via-Iterative-Refinement. 10 | """ 11 | import functools 12 | import logging 13 | 14 | import torch 15 | import torch.nn as nn 16 | from torch.nn import init 17 | 18 | from .modules.diffusion import GaussianDiffusion 19 | from .modules.unet import UNet 20 | 21 | logger = logging.getLogger("base") 22 | 23 | 24 | def weights_init_normal(model: nn.Module, std: float = 0.02) -> None: 25 | """Initializes model weights from Gaussian distribution. 26 | 27 | Args: 28 | model: The network. 29 | std: Standard deviation of Gaussian distrbiution. 30 | """ 31 | classname = model.__class__.__name__ 32 | if classname.find("Conv") != -1: 33 | init.normal_(model.weight.data, 0.0, std) 34 | if model.bias is not None: 35 | model.bias.data.zero_() 36 | elif classname.find("Linear") != -1: 37 | init.normal_(model.weight.data, 0.0, std) 38 | if model.bias is not None: 39 | model.bias.data.zero_() 40 | elif classname.find("BatchNorm2d") != -1: 41 | init.normal_(model.weight.data, 1.0, std) 42 | init.constant_(model.bias.data, 0.0) 43 | 44 | 45 | def weights_init_kaiming(model: nn.Module, scale: float = 1) -> None: 46 | """He initialization of model weights. 47 | 48 | Args: 49 | model: The network. 50 | scale: Scaling factor of weights. 51 | """ 52 | classname = model.__class__.__name__ 53 | if classname.find("Conv2d") != -1: 54 | init.kaiming_normal_(model.weight.data) 55 | model.weight.data *= scale 56 | if model.bias is not None: 57 | model.bias.data.zero_() 58 | elif classname.find("Linear") != -1: 59 | init.kaiming_normal_(model.weight.data) 60 | model.weight.data *= scale 61 | if model.bias is not None: 62 | model.bias.data.zero_() 63 | elif classname.find("BatchNorm2d") != -1: 64 | init.constant_(model.weight.data, 1.0) 65 | init.constant_(model.bias.data, 0.0) 66 | 67 | 68 | def weights_init_orthogonal(model: nn.Module) -> None: 69 | """Fills the model weights to be orthogonal matrices. 70 | 71 | Args: 72 | model: The network. 
73 | """ 74 | classname = model.__class__.__name__ 75 | if classname.find("Conv") != -1: 76 | init.orthogonal_(model.weight.data) 77 | if model.bias is not None: 78 | model.bias.data.zero_() 79 | elif classname.find("Linear") != -1: 80 | init.orthogonal_(model.weight.data) 81 | if model.bias is not None: 82 | model.bias.data.zero_() 83 | elif classname.find("BatchNorm2d") != -1: 84 | init.constant_(model.weight.data, 1.0) 85 | init.constant_(model.bias.data, 0.0) 86 | 87 | 88 | def init_weights(net: nn.Module, init_type: str = "kaiming", scale: float = 1, std: float = 0.02) -> None: 89 | """Initializes network weights. 90 | 91 | Args: 92 | net: The neural network. 93 | init_type: One of "normal", "kaiming" or "orthogonal". 94 | scale: Scaling factor of weights used in kaiming initialization. 95 | std: Standard deviation of Gaussian distrbiution used in normal initialization. 96 | """ 97 | logger.info("Initialization method [{:s}]".format(init_type)) 98 | if init_type == "normal": 99 | weights_init_normal_ = functools.partial(weights_init_normal, std=std) 100 | net.apply(weights_init_normal_) 101 | elif init_type == "kaiming": 102 | weights_init_kaiming_ = functools.partial( 103 | weights_init_kaiming, scale=scale) 104 | net.apply(weights_init_kaiming_) 105 | elif init_type == "orthogonal": 106 | net.apply(weights_init_orthogonal) 107 | else: 108 | raise NotImplementedError("Initialization method [{:s}] not implemented".format(init_type)) 109 | 110 | 111 | def define_network(in_channel, out_channel, norm_groups, inner_channel, 112 | channel_multiplier, attn_res, res_blocks, dropout, 113 | diffusion_loss, conditional, gpu_ids, distributed, init_method, height) -> nn.Module: 114 | """Defines Gaussian Diffusion model for single image super-resolution task. 115 | 116 | Args: 117 | in_channel: The number of channels of input tensor of U-Net. 118 | out_channel: The number of channels of output tensor of U-Net. 119 | norm_groups: The number of groups for group normalization. 120 | inner_channel: Timestep embedding dimension. 121 | channel_multiplier: A tuple specifying the scaling factors of channels. 122 | attn_res: A tuple of spatial dimensions indicating in which resolutions to use self-attention layer. 123 | res_blocks: The number of residual blocks. 124 | dropout: Dropout probability. 125 | diffusion_loss: Either l1 or l2. 126 | conditional: Whether to condition on INTERPOLATED image or not. 127 | gpu_ids: IDs of gpus. 128 | distributed: Whether the computation will be distributed among multiple GPUs or not. 129 | init_method: NN weight initialization method. One of normal, kaiming or orthogonal inisializations. 130 | height: U-Net input tensor height value. 131 | 132 | Returns: 133 | A Gaussian Diffusion model. 
134 | """ 135 | 136 | network = UNet(in_channel=in_channel, 137 | out_channel=out_channel, 138 | norm_groups=norm_groups if norm_groups else 32, 139 | inner_channel=inner_channel, 140 | channel_mults=channel_multiplier, 141 | attn_res=attn_res, 142 | res_blocks=res_blocks, 143 | dropout=dropout, 144 | height=height) 145 | 146 | model = GaussianDiffusion(network, loss_type=diffusion_loss, conditional=conditional) 147 | init_weights(model, init_type=init_method) 148 | 149 | if gpu_ids and distributed: 150 | assert torch.cuda.is_available() 151 | model = nn.DataParallel(model) 152 | 153 | return model 154 | -------------------------------------------------------------------------------- /report.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davitpapikyan/Probabilistic-Downscaling-of-Climate-Variables/2f916c8fa990779fae21c4465a95e08bce317599/report.pdf -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.21.4 2 | torch==1.10.0 3 | tensorboardX==2.4.1 4 | tensorboard==2.7.0 5 | aim==3.3.3 6 | Shapely==1.8.0 7 | pyproj==3.3.0 8 | cartopy==0.19.0.post1 9 | matplotlib==3.5.0 10 | python-dateutil==2.8.2 11 | xarray==0.20.2 -------------------------------------------------------------------------------- /results/reverse_diffusion_steps.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davitpapikyan/Probabilistic-Downscaling-of-Climate-Variables/2f916c8fa990779fae21c4465a95e08bce317599/results/reverse_diffusion_steps.jpg -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """The training script for DDPM model. 2 | """ 3 | import argparse 4 | import logging 5 | import os 6 | import pickle 7 | import warnings 8 | from collections import OrderedDict, defaultdict 9 | 10 | import numpy as np 11 | import torch 12 | from aim import Run, num_utils 13 | from tensorboardX import SummaryWriter 14 | from torch.nn.functional import mse_loss, l1_loss 15 | from torch.utils.data import DataLoader 16 | 17 | import model 18 | from config import Config 19 | from utils import dict2str, setup_logger, construct_and_save_wbd_plots, \ 20 | accumulate_statistics, get_transformation, \ 21 | get_optimizer, construct_mask, reverse_transform_candidates, set_seeds 22 | from weatherbench_data import collate_wb_batch, create_datasets, create_dataloaders 23 | from weatherbench_data.utils import reverse_transform, reverse_transform_tensor, prepare_test_data 24 | 25 | warnings.filterwarnings("ignore") 26 | 27 | 28 | if __name__ == "__main__": 29 | set_seeds() # For reproducability. 
30 | 
31 |     parser = argparse.ArgumentParser()
32 |     parser.add_argument("-c", "--config", type=str, help="JSON file for configuration")
33 |     parser.add_argument("-p", "--phase", type=str, choices=["train", "val"],
34 |                         help="Run either training or validation (inference).", default="train")
35 |     parser.add_argument("-gpu", "--gpu_ids", type=str, default=None)
36 |     args = parser.parse_args()
37 |     configs = Config(args)
38 | 
39 |     torch.backends.cudnn.enabled = True
40 |     torch.backends.cudnn.benchmark = True
41 | 
42 |     setup_logger(None, configs.log, "train", screen=True)
43 |     setup_logger("val", configs.log, "val")
44 |     logger = logging.getLogger("base")
45 |     val_logger = logging.getLogger("val")
46 |     logger.info(dict2str(configs.get_hyperparameters_as_dict()))
47 |     tb_logger = SummaryWriter(log_dir=configs.tb_logger)
48 | 
49 |     aim_logger = Run(run_hash=configs.name, repo='./experiments/aim/', experiment=configs.name)
50 |     aim_logger["hparams"] = {"train_min_date": configs.train_min_date, "train_max_date": configs.train_max_date,
51 |                              "val_min_date": configs.val_min_date, "val_max_date": configs.val_max_date,
52 |                              "variables": configs.variables, "transformation": configs.transformation,
53 |                              "tranform_monthly": configs.tranform_monthly, "batch_size": configs.batch_size,
54 |                              "norm_groups": configs.norm_groups, "dropout": configs.dropout,
55 |                              "diffusion_loss": configs.diffusion_loss, "init_method": configs.init_method,
56 |                              "train_schedule": configs.train_schedule, "val_schedule": configs.val_schedule,
57 |                              "optimizer": configs.optimizer_type, "learning_rate": configs.lr}
58 | 
59 |     transformation = get_transformation(configs.transformation)
60 |     train_data, val_data, metadata, transformations = create_datasets(dataroot=configs.dataroot,
61 |                                                                       name=configs.name,
62 |                                                                       train_min_date=configs.train_min_date,
63 |                                                                       train_max_date=configs.train_max_date,
64 |                                                                       val_min_date=configs.val_min_date,
65 |                                                                       val_max_date=configs.val_max_date,
66 |                                                                       variables=configs.variables,
67 |                                                                       transformation=transformation,
68 |                                                                       storage_root=configs.experiments_root,
69 |                                                                       apply_tranform_monthly=configs.tranform_monthly)
70 |     logger.info(f"Train size: {len(train_data)}, Val size: {len(val_data)}.")
71 |     train_loader, val_loader = create_dataloaders(train_data, val_data, batch_size=configs.batch_size,
72 |                                                   use_shuffle=configs.use_shuffle, num_workers=configs.num_workers)
73 |     logger.info("Training and Validation dataloaders are ready.")
74 | 
75 |     # Defining the model.
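    # get_optimizer (from utils.py) returns an optimizer class rather than an
    # instance: DDPM.__init__ later instantiates it as optimizer(params, lr=...,
    # amsgrad=...). A plausible sketch of such a factory, as an assumption rather
    # than the exact helper:
    #
    #   def get_optimizer(name: str):
    #       return {"adam": torch.optim.Adam, "adamw": torch.optim.AdamW}[name.lower()]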
76 |     optimizer = get_optimizer(configs.optimizer_type)
77 |     diffusion = model.create_model(in_channel=configs.in_channel, out_channel=configs.out_channel,
78 |                                    norm_groups=configs.norm_groups, inner_channel=configs.inner_channel,
79 |                                    channel_multiplier=configs.channel_multiplier, attn_res=configs.attn_res,
80 |                                    res_blocks=configs.res_blocks, dropout=configs.dropout,
81 |                                    diffusion_loss=configs.diffusion_loss, conditional=configs.conditional,
82 |                                    gpu_ids=configs.gpu_ids, distributed=configs.distributed,
83 |                                    init_method=configs.init_method, train_schedule=configs.train_schedule,
84 |                                    train_n_timestep=configs.train_n_timestep,
85 |                                    train_linear_start=configs.train_linear_start,
86 |                                    train_linear_end=configs.train_linear_end,
87 |                                    val_schedule=configs.val_schedule, val_n_timestep=configs.val_n_timestep,
88 |                                    val_linear_start=configs.val_linear_start, val_linear_end=configs.val_linear_end,
89 |                                    finetune_norm=configs.finetune_norm, optimizer=optimizer, amsgrad=configs.amsgrad,
90 |                                    learning_rate=configs.lr, checkpoint=configs.checkpoint,
91 |                                    resume_state=configs.resume_state, phase=configs.phase, height=configs.height)
92 |     logger.info("Model initialization is finished.")
93 | 
94 |     current_step, current_epoch = diffusion.begin_step, diffusion.begin_epoch
95 |     if configs.resume_state:
96 |         logger.info(f"Resuming training from epoch: {current_epoch}, iter: {current_step}.")
97 | 
98 |     logger.info("Starting the training.")
99 |     diffusion.set_new_noise_schedule(schedule=configs.train_schedule, n_timestep=configs.train_n_timestep,
100 |                                      linear_start=configs.train_linear_start, linear_end=configs.train_linear_end)
101 | 
102 |     accumulated_statistics = OrderedDict()
103 | 
104 |     # Creating placeholders for the validation metrics: Mean Squared Error (MSE), Root MSE (RMSE), Mean Absolute Error (MAE) and Mean Residual (MR).
105 |     val_metrics = OrderedDict({"MSE": 0.0, "RMSE": 0.0, "MAE": 0.0, "MR": 0.0})
106 | 
107 |     # Training.
108 |     while current_step < configs.n_iter:
109 |         current_epoch += 1
110 | 
111 |         for train_data in train_loader:
112 |             current_step += 1
113 | 
114 |             if current_step > configs.n_iter:
115 |                 break
116 | 
117 |             # Training.
118 |             diffusion.feed_data(train_data)
119 |             diffusion.optimize_parameters()
120 |             # diffusion.lr_scheduler_step()  # For lr scheduler updates per iteration.
121 |             accumulate_statistics(diffusion.get_current_log(), accumulated_statistics)
122 | 
123 |             # Logging the training information.
124 |             if current_step % configs.print_freq == 0:
125 |                 message = f"Epoch: {current_epoch:5} | Iteration: {current_step:8}"
126 | 
127 |                 for metric, values in accumulated_statistics.items():
128 |                     mean_value = np.mean(values)
129 |                     message = f"{message} | {metric:s}: {mean_value:.5f}"
130 |                     tb_logger.add_scalar(f"{metric}/train", mean_value, current_step)
131 |                     aim_logger.track(num_utils.convert_to_py_number(mean_value), name=metric, step=current_step,
132 |                                      epoch=current_epoch, context={"subset": "train"})
133 | 
134 |                 logger.info(message)
135 |                 # tb_logger.add_scalar(f"learning_rate", diffusion.get_lr(), current_step)
136 | 
137 |                 # Visualizing distributions of parameters.
138 |                 for name, param in diffusion.get_named_parameters():
139 |                     tb_logger.add_histogram(name, param.clone().cpu().data.numpy(), current_step)
140 | 
141 |                 accumulated_statistics = OrderedDict()
142 | 
143 |             # Validation.
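            # Each metric below is accumulated per batch and divided by the number
            # of validation batches at the end, so e.g. the reported RMSE is the
            # mean of per-batch RMSE values rather than the RMSE of all pooled errors.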
144 |         if current_step % configs.val_freq == 0:
145 |             logger.info("Starting validation.")
146 |             idx = 0
147 |             result_path = f"{configs.results}/{current_epoch}"
148 |             os.makedirs(result_path, exist_ok=True)
149 |             diffusion.set_new_noise_schedule(schedule=configs.val_schedule,
150 |                                              n_timestep=configs.val_n_timestep,
151 |                                              linear_start=configs.val_linear_start,
152 |                                              linear_end=configs.val_linear_end)
153 | 
154 |             # A dictionary for storing a list of mean temperatures for each month.
155 |             month2mean_temperature = defaultdict(list)
156 | 
157 |             for val_data in val_loader:
158 |                 idx += 1
159 |                 diffusion.feed_data(val_data)
160 |                 diffusion.test(continuous=False)  # continuous=False returns only the last timestep's outcome.
161 | 
162 |                 # Computing metrics on validation data.
163 |                 visuals = diffusion.get_current_visuals()
164 | 
165 |                 inv_visuals = reverse_transform(visuals, transformations,
166 |                                                 configs.variables, diffusion.get_months(),
167 |                                                 configs.tranform_monthly)
168 | 
169 |                 # Computing MSE and RMSE on original data.
170 |                 mse_value = mse_loss(inv_visuals["HR"], inv_visuals["SR"])
171 |                 val_metrics["MSE"] += mse_value
172 |                 val_metrics["RMSE"] += torch.sqrt(mse_value)
173 |                 val_metrics["MAE"] += l1_loss(inv_visuals["HR"], inv_visuals["SR"])
174 | 
175 |                 mean_temp_pred = inv_visuals["SR"].mean(axis=[1, 2, 3])
176 |                 for m, t in zip(diffusion.get_months(), mean_temp_pred):
177 |                     month2mean_temperature[int(m)].append(t)
178 | 
179 |                 # Computing residuals for visualization.
180 |                 residuals = inv_visuals["SR"] - inv_visuals["HR"]
181 |                 val_metrics["MR"] += residuals.mean()
182 | 
183 |                 if idx % configs.val_vis_freq == 0:
184 |                     path = f"{result_path}/{current_epoch}_{current_step}_{idx}"
185 |                     logger.info(f"[{idx//configs.val_vis_freq}] Visualizing and storing some examples.")
186 | 
187 |                     sr_candidates = diffusion.generate_multiple_candidates(n=configs.sample_size)
188 |                     reverse_transform_candidates(sr_candidates, reverse_transform_tensor,
189 |                                                  transformations, configs.variables,
190 |                                                  "hr", diffusion.get_months(),
191 |                                                  configs.tranform_monthly)
192 |                     mean_candidate = sr_candidates.mean(dim=0)  # [B, C, H, W]
193 |                     std_candidate = sr_candidates.std(dim=0)  # [B, C, H, W]
194 |                     bias = mean_candidate - inv_visuals["HR"]
195 |                     mean_bias_over_pixels = bias.mean()  # Scalar.
196 |                     std_bias_over_pixels = bias.std()  # Scalar.
197 | 
198 |                     # Computing min and max measures to set a fixed colorbar for all visualizations.
199 |                     vmin = min(inv_visuals["HR"][:configs.n_val_vis].min(),
200 |                                inv_visuals["SR"][:configs.n_val_vis].min(),
201 |                                inv_visuals["LR"][:configs.n_val_vis].min(),
202 |                                inv_visuals["INTERPOLATED"][:configs.n_val_vis].min(),
203 |                                mean_candidate[:configs.n_val_vis].min())
204 |                     vmax = max(inv_visuals["HR"][:configs.n_val_vis].max(),
205 |                                inv_visuals["SR"][:configs.n_val_vis].max(),
206 |                                inv_visuals["LR"][:configs.n_val_vis].max(),
207 |                                inv_visuals["INTERPOLATED"][:configs.n_val_vis].max(),
208 |                                mean_candidate[:configs.n_val_vis].max())
209 | 
210 |                     # Choosing the first n_val_vis number of samples to visualize.
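                    # Each construct_and_save_wbd_plots call below writes one PNG per
                    # sample: for batched input the helper squeezes the batch dimension
                    # and appends the batch index to the filename (see add_batch_index
                    # in utils.py).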
211 | construct_and_save_wbd_plots(latitude=metadata.hr_lat, longitude=metadata.hr_lon, 212 | data=inv_visuals["HR"][:configs.n_val_vis], 213 | path=f"{path}_hr.png", vmin=vmin, vmax=vmax) 214 | construct_and_save_wbd_plots(latitude=metadata.hr_lat, longitude=metadata.hr_lon, 215 | data=inv_visuals["SR"][:configs.n_val_vis], 216 | path=f"{path}_sr.png", vmin=vmin, vmax=vmax) 217 | construct_and_save_wbd_plots(latitude=metadata.lr_lat, longitude=metadata.lr_lon, 218 | data=inv_visuals["LR"][:configs.n_val_vis], 219 | path=f"{path}_lr.png", vmin=vmin, vmax=vmax) 220 | construct_and_save_wbd_plots(latitude=metadata.hr_lat, longitude=metadata.hr_lon, 221 | data=inv_visuals["INTERPOLATED"][:configs.n_val_vis], 222 | path=f"{path}_interpolated.png", vmin=vmin, vmax=vmax) 223 | construct_and_save_wbd_plots(latitude=metadata.hr_lat, longitude=metadata.hr_lon, 224 | data=construct_mask(residuals[:configs.n_val_vis]), 225 | path=f"{path}_residual.png", vmin=-1, vmax=1, 226 | costline_color="red", cmap="binary", 227 | label="Signum(SR - HR)") 228 | construct_and_save_wbd_plots(latitude=metadata.hr_lat, longitude=metadata.hr_lon, 229 | data=mean_candidate[:configs.n_val_vis], 230 | path=f"{path}_mean_sr.png", vmin=vmin, vmax=vmax) 231 | construct_and_save_wbd_plots(latitude=metadata.hr_lat, longitude=metadata.hr_lon, 232 | data=std_candidate[:configs.n_val_vis], 233 | path=f"{path}_std_sr.png", vmin=0.0, cmap="Greens") 234 | 235 | # tb_logger.add_scalar(f"mean_bias_over_pixels/val", mean_bias_over_pixels, current_step) 236 | # tb_logger.add_scalar(f"std_bias_over_pixels/val", std_bias_over_pixels, current_step) 237 | 238 | aim_logger.track(num_utils.convert_to_py_number(mean_bias_over_pixels), 239 | name="mean_bias_over_pixels", step=current_step, 240 | epoch=current_epoch, context={"subset": "val"}) 241 | aim_logger.track(num_utils.convert_to_py_number(std_bias_over_pixels), 242 | name="std_bias_over_pixels", step=current_step, 243 | epoch=current_epoch, context={"subset": "val"}) 244 | 245 | # tb_figure = construct_tb_visualization(latitude=metadata.hr_lat, longitude=metadata.hr_lon, 246 | # data=(inv_visuals["INTERPOLATED"][-1].squeeze(), 247 | # inv_visuals["SR"][-1].squeeze(), 248 | # inv_visuals["HR"][-1].squeeze())) 249 | # tb_logger.add_figure(tag=f"Iter_{current_epoch}_{current_step}", 250 | # figure=tb_figure, global_step=idx) 251 | 252 | # Validation is finished. 253 | val_metrics["MSE"] /= idx 254 | val_metrics["RMSE"] /= idx 255 | val_metrics["MR"] /= idx 256 | val_metrics["MAE"] /= idx 257 | 258 | diffusion.set_new_noise_schedule(schedule=configs.train_schedule, 259 | n_timestep=configs.train_n_timestep, 260 | linear_start=configs.train_linear_start, 261 | linear_end=configs.train_linear_end) 262 | 263 | message = f"Epoch: {current_epoch:5} | Iteration: {current_step:8}" 264 | for metric, value in val_metrics.items(): 265 | message = f"{message} | {metric:s}: {value:.5f}" 266 | # tb_logger.add_scalar(f"{metric}/val", value, current_step) 267 | aim_logger.track(num_utils.convert_to_py_number(value), name=metric, step=current_step, 268 | epoch=current_epoch, context={"subset": "val"}) 269 | val_logger.info(message) 270 | 271 | val_metrics = val_metrics.fromkeys(val_metrics, 0.0) # Sets all metrics to zero. 272 | 273 | if current_step % configs.save_checkpoint_freq == 0: 274 | logger.info("Saving models and training states.") 275 | diffusion.save_network(current_epoch, current_step) 276 | 277 | # Learning rate scheduler step per iteration. 
278 | # diffusion.lr_scheduler_step() # For lr scheduler updates per epoch. 279 | 280 | tb_logger.close() 281 | aim_logger.close() 282 | logger.info("End of training.") 283 | 284 | logger.info("Starting final evaluation on training set.") 285 | train_subset = prepare_test_data(variables=configs.variables, val_min_date=configs.train_subset_min_date, 286 | val_max_date=configs.train_subset_max_date, dataroot=configs.dataroot, 287 | transformations=transformations) 288 | train_loader = DataLoader(train_subset, batch_size=32, 289 | collate_fn=collate_wb_batch, 290 | pin_memory=True, num_workers=2) 291 | 292 | diffusion.set_new_noise_schedule(schedule=configs.val_schedule, 293 | n_timestep=configs.val_n_timestep, 294 | linear_start=configs.val_linear_start, 295 | linear_end=configs.val_linear_end) 296 | with torch.no_grad(): 297 | idx = 0 298 | train_metrics = OrderedDict({"MSE": 0.0, "RMSE": 0.0, "MAE": 0.0, "MR": 0.0}) 299 | for train_data in train_loader: 300 | idx += 1 301 | diffusion.feed_data(train_data) 302 | diffusion.test(continuous=False) 303 | visuals = diffusion.get_current_visuals() 304 | inv_visuals = reverse_transform(visuals, transformations, 305 | configs.variables, diffusion.get_months(), 306 | configs.tranform_monthly) 307 | mse_value = mse_loss(inv_visuals["HR"], inv_visuals["SR"]) 308 | train_metrics["MSE"] += mse_value 309 | train_metrics["RMSE"] += torch.sqrt(mse_value) 310 | train_metrics["MR"] += (inv_visuals["SR"] - inv_visuals["HR"]).mean() 311 | train_metrics["MAE"] += l1_loss(inv_visuals["HR"], inv_visuals["SR"]) 312 | 313 | train_metrics["MSE"] /= idx 314 | train_metrics["RMSE"] /= idx 315 | train_metrics["MR"] /= idx 316 | train_metrics["MAE"] /= idx 317 | 318 | message = f"Final evaluation on train set" 319 | for metric, value in train_metrics.items(): 320 | message = f"{message} | {metric:s}: {value:.5f}" 321 | logger.info(message) 322 | 323 | with open(f"{result_path}/month2mean_temperature.pickle", 'wb') as handle: 324 | pickle.dump(month2mean_temperature, handle, protocol=pickle.HIGHEST_PROTOCOL) 325 | 326 | logger.info("End of evaluation.") 327 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | """Defines auxiliary functions for fixing the seeds, setting 2 | a logger and visualizing WeatherBench data.""" 3 | import logging 4 | import os 5 | import random 6 | 7 | import cartopy.crs as ccrs 8 | import matplotlib 9 | import matplotlib.pyplot as plt 10 | import numpy as np 11 | import torch 12 | from cartopy.mpl.geoaxes import GeoAxes 13 | from cartopy.mpl.ticker import LongitudeFormatter, LatitudeFormatter 14 | from cartopy.util import add_cyclic_point 15 | from matplotlib.figure import Figure 16 | from mpl_toolkits.axes_grid1 import AxesGrid 17 | 18 | from weatherbench_data.transforms import Transform 19 | 20 | # Tensorboard visualization titles. 21 | TITLES = ("Upsampled with interpolation", 22 | "Super-resolution reconstruction", 23 | "High-resolution original") 24 | 25 | 26 | def set_seeds(seed: int = 0): 27 | """Sets random seeds of Python, NumPy and PyTorch. 28 | 29 | Args: 30 | seed: Seed value. 
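
    Example:
        >>> set_seeds(42)  # Python, NumPy and PyTorch RNGs now start from the same state on every run.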
31 | """ 32 | random.seed(seed) 33 | os.environ["PYTHONHASHSEED"] = str(seed) 34 | np.random.seed(seed) 35 | torch.manual_seed(seed) 36 | torch.cuda.manual_seed(seed) 37 | torch.backends.cudnn.deterministic = True 38 | 39 | 40 | def dict2str(dict_obj: dict, indent_l: int = 4) -> str: 41 | """Converts dictionary to string for printing out. 42 | 43 | Args: 44 | dict_obj: Dictionary or OrderedDict. 45 | indent_l: Left indentation level. 46 | 47 | Returns: 48 | Returns string version of opt. 49 | """ 50 | msg = "" 51 | for k, v in dict_obj.items(): 52 | if isinstance(v, dict): 53 | msg = f"{msg}{' '*(indent_l*2)}{k}:[\n{dict2str(v, indent_l+1)}{' '*(indent_l*2)}]\n" 54 | else: 55 | msg = f"{msg}{' '*(indent_l*2)}{k}: {v}\n" 56 | return msg 57 | 58 | 59 | def setup_logger(logger_name, root, phase, level=logging.INFO, screen=False): 60 | """Sets up the logger. 61 | 62 | Args: 63 | logger_name: The logger name. 64 | root: The directory of logger. 65 | phase: Either train or val. 66 | level: The level of logging. 67 | screen: If True then write logging records to a stream. 68 | """ 69 | logger = logging.getLogger(logger_name) 70 | formatter = logging.Formatter("%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s", datefmt="%y-%m-%d %H:%M:%S") 71 | log_file = os.path.join(root, "{}.log".format(phase)) 72 | fh = logging.FileHandler(log_file, mode="w") 73 | fh.setFormatter(formatter) 74 | logger.setLevel(level) 75 | logger.addHandler(fh) 76 | if screen: 77 | sh = logging.StreamHandler() 78 | sh.setFormatter(formatter) 79 | logger.addHandler(sh) 80 | 81 | 82 | def construct_and_save_wbd_plot(latitude: np.array, longitude: np.array, single_variable: torch.tensor, 83 | path: str, title: str = None, label: str = None, dpi: int = 200, 84 | figsize: tuple = (11, 8.5), cmap: str = "coolwarm", vmin=None, 85 | vmax=None, costline_color="black"): 86 | """Creates and saves WeatherBench data visualization for a single variable. 87 | 88 | Args: 89 | latitude: An array of latitudes. 90 | longitude: An array of longitudes. 91 | single_variable: A tensor to visualize. 92 | path: Path of a directory to save visualization. 93 | title: Title of the figure. 94 | label: Label of the colorbar. 95 | dpi: Resolution of the figure. 96 | figsize: Tuple of (width, height) in inches. 97 | cmap: A matplotlib colormap. 98 | vmin: Minimum value for colormap. 99 | vmax: Maximum value for colormap. 100 | costline_color: Matplotlib color. 101 | """ 102 | single_variable, longitude = add_cyclic_point(single_variable, coord=np.array(longitude)) 103 | plt.figure(dpi=dpi, figsize=figsize) 104 | projection = ccrs.PlateCarree() 105 | ax = plt.axes(projection=projection) 106 | 107 | if cmap == "binary": 108 | # For mask visualization. 109 | p = plt.contourf(longitude, latitude, single_variable, 60, transform=projection, 110 | cmap=(matplotlib.colors.ListedColormap(["white", "gray", "black"]) 111 | .with_extremes(over="0.25", under="0.75")), 112 | vmin=-1, vmax=1) 113 | boundaries, ticks = [-1, -0.33, 0.33, 1], [-1, 0, 1] 114 | elif cmap == "coolwarm": 115 | # For temperature visualization. 116 | p = plt.contourf(longitude, latitude, single_variable, 60, transform=projection, cmap=cmap, 117 | levels=np.linspace(vmin, vmax, max(int(np.abs(vmax-vmin))//2, 3))) 118 | boundaries, ticks = None, np.round(np.linspace(vmin, vmax, 7), 2) 119 | 120 | elif cmap == "Greens": 121 | # For visualization of standard deviation. 
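        # Note: only the "binary", "coolwarm" and "Greens" branches define `p`,
        # `boundaries` and `ticks`; with any other cmap they stay unbound and the
        # colorbar call below raises a NameError.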
122 | p = plt.contourf(longitude, latitude, single_variable, 60, transform=projection, cmap=cmap, 123 | extend="max") 124 | boundaries, ticks = None, np.linspace(single_variable.min(), single_variable.max(), 5) 125 | 126 | ax.set_xticks(np.linspace(-180, 180, 5), crs=projection) 127 | ax.set_yticks(np.linspace(-90, 90, 5), crs=projection) 128 | lon_formatter = LongitudeFormatter(zero_direction_label=True) 129 | lat_formatter = LatitudeFormatter() 130 | ax.xaxis.set_major_formatter(lon_formatter) 131 | ax.yaxis.set_major_formatter(lat_formatter) 132 | ax.coastlines(color=costline_color) 133 | 134 | plt.colorbar(p, pad=0.06, label=label, orientation="horizontal", shrink=0.75, 135 | boundaries=boundaries, ticks=ticks) 136 | 137 | plt.title(title) 138 | plt.savefig(path, bbox_inches="tight") 139 | plt.close("all") 140 | 141 | 142 | def add_batch_index(path: str, index: int): 143 | """Adds the number of batch gotten from data loader to path. 144 | 145 | Args: 146 | path: The path to which the function needs to add batch index. 147 | index: The batch index. 148 | 149 | Returns: 150 | The path with the index appended to the filename. 151 | """ 152 | try: 153 | filename, extension = path.split(".") 154 | except ValueError: 155 | splitted_parts = path.split(".") 156 | filename, extension = ".".join(splitted_parts[:-1]), splitted_parts[-1] 157 | return f"{filename}_{index}.{extension}" 158 | 159 | 160 | def construct_and_save_wbd_plots(latitude: np.array, longitude: np.array, data: torch.tensor, 161 | path: str, title: str = None, label: str = None, 162 | dpi: int = 200, figsize: tuple = (11, 8.5), cmap: str = "coolwarm", 163 | vmin=None, vmax=None, costline_color="black"): 164 | """Creates and saves WeatherBench data visualization. 165 | 166 | Args: 167 | latitude: An array of latitudes. 168 | longitude: An array of longitudes. 169 | data: A batch of variables to visualize. 170 | path: Path of a directory to save visualization. 171 | title: Title of the figure. 172 | label: Label of the colorbar. 173 | dpi: Resolution of the figure. 174 | figsize: Tuple of (width, height) in inches. 175 | cmap: A matplotlib colormap. 176 | vmin: Minimum value for colormap. 177 | vmax: Maximum value for colormap. 178 | costline_color: Matplotlib color. 179 | """ 180 | if len(data.shape) > 2: 181 | data = data.squeeze() 182 | 183 | if len(data.shape) > 2: 184 | for batch_index in range(data.shape[0]): 185 | path_for_sample = add_batch_index(path, batch_index) 186 | construct_and_save_wbd_plot(latitude, longitude, data[batch_index], path_for_sample, 187 | title, label, dpi, figsize, cmap, vmin, vmax, costline_color) 188 | else: 189 | construct_and_save_wbd_plot(latitude, longitude, data, path, title, label, dpi, figsize, cmap, 190 | vmin, vmax, costline_color) 191 | 192 | 193 | def construct_tb_visualization(latitude: np.array, longitude: np.array, data: tuple, label=None, 194 | dpi: int = 300, figsize: tuple = (22, 6), cmap: str = "coolwarm") -> Figure: 195 | """Construct tensorboard visualization figure. 196 | 197 | Args: 198 | latitude: An array of latitudes. 199 | longitude: An array of longitudes. 200 | data: A batch of variables to visualize. 201 | label: Label of the colorbar. 202 | dpi: Resolution of the figure. 203 | figsize: Tuple of (width, height) in inches. 204 | cmap: A matplotlib colormap. 205 | 206 | Returns: 207 | Matplotlib Figure. 
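
    Note:
        The three panels follow TITLES: the interpolated LR input, the
        super-resolved reconstruction and the HR original, drawn side by
        side with a single shared horizontal colorbar.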
208 |     """
209 |     max_value = max((tensor.max() for tensor in data))
210 |     min_value = min((tensor.min() for tensor in data))
211 |     projection = ccrs.PlateCarree()
212 |     axes_class = (GeoAxes, dict(map_projection=projection))
213 |     fig = plt.figure(figsize=figsize, dpi=dpi)
214 |     axgr = AxesGrid(fig, 111, axes_class=axes_class, nrows_ncols=(1, 3), axes_pad=0.95, cbar_location="bottom",
215 |                     cbar_mode="single", cbar_pad=0.01, cbar_size="2%", label_mode='')
216 |     lon_formatter = LongitudeFormatter(zero_direction_label=True)
217 |     lat_formatter = LatitudeFormatter()
218 | 
219 |     for i, ax in enumerate(axgr):
220 |         single_variable, lon = add_cyclic_point(data[i], coord=np.array(longitude))
221 |         ax.set_title(TITLES[i])
222 |         ax.gridlines(draw_labels=True, xformatter=lon_formatter, yformatter=lat_formatter,
223 |                      xlocs=np.linspace(-180, 180, 5), ylocs=np.linspace(-90, 90, 5))
224 |         p = ax.contourf(lon, latitude, single_variable, transform=projection, cmap=cmap,
225 |                         vmin=min_value, vmax=max_value)
226 |         ax.coastlines()
227 | 
228 |     axgr.cbar_axes[0].colorbar(p, pad=0.01, label=label, shrink=0.95)
229 |     fig.tight_layout()
230 |     plt.close("all")
231 |     return fig
232 | 
233 | 
234 | def accumulate_statistics(new_info: dict, storage: dict):
235 |     """Accumulates statistics provided with new_info into storage.
236 | 
237 |     Args:
238 |         new_info: A dictionary containing new information.
239 |         storage: A dictionary where to accumulate new information.
240 |     """
241 |     for key, value in new_info.items():
242 |         if key in storage:
243 |             storage[key].append(value)
244 |         else:
245 |             storage[key] = [value]
246 | 
247 | 
248 | def get_transformation(name: str) -> type:
249 |     """Returns the data transformation class corresponding to name.
250 | 
251 |     Args:
252 |         name: The name of the transformation.
253 | 
254 |     Returns:
255 |         A data transformation class (a subclass of Transform).
256 |     """
257 |     if name == "LocalStandardScaling":
258 |         from weatherbench_data.transforms import LocalStandardScaling as Transformation
259 |     elif name == "GlobalStandardScaling":
260 |         from weatherbench_data.transforms import GlobalStandardScaling as Transformation
261 |     return Transformation
262 | 
263 | 
264 | def get_optimizer(name: str) -> type:
265 |     """Returns the optimization algorithm class corresponding to name.
266 | 
267 |     Args:
268 |         name: The name of the optimizer.
269 | 
270 |     Returns:
271 |         A torch optimizer class.
272 |     """
273 |     if name == "adam":
274 |         from torch.optim import Adam as Optimizer
275 |     elif name == "adamw":
276 |         from torch.optim import AdamW as Optimizer
277 |     return Optimizer
278 | 
279 | 
280 | def construct_mask(x: torch.tensor) -> torch.tensor:
281 |     """Constructs signum(x) tensor with tolerance around 0 specified
282 |     by torch.isclose function.
283 | 
284 |     Args:
285 |         x: The input tensor.
286 | 
287 |     Returns:
288 |         Signum(x) with slight tolerance around 0.
289 |     """
290 |     values = torch.ones_like(x)
291 |     zero_mask = torch.isclose(x, torch.zeros_like(x))
292 |     neg_mask = x < 0
293 |     values[neg_mask] = -1
294 |     values[zero_mask] = 0
295 |     return values
296 | 
297 | 
298 | def reverse_transform_candidates(candidates: torch.tensor, reverse_transform: Transform,
299 |                                  transformations: dict, variables: list, data_type: str,
300 |                                  months: list, tranform_monthly: bool):
301 |     """Reverse transforms a tensor of SR candidates in place.
302 | 
303 |     Args:
304 |         candidates: A tensor of shape [n, B, C, H, W].
305 |         reverse_transform: A function that reverse transforms a tensor (e.g. reverse_transform_tensor).
306 |         transformations: A dictionary of transformations.
307 |         variables: Weatherbench data variables.
308 |         data_type: Either 'lr' or 'hr'.
309 |         months: A list of months for each batch sample of length (B, ).
310 |         tranform_monthly: Either to apply transformation month-wise or not.
311 | 
312 |     Returns:
313 |         None. The candidates tensor is reverse transformed in place.
314 |     """
315 |     for i in range(candidates.shape[0]):
316 |         candidates[i] = reverse_transform(candidates[i], transformations, variables,
317 |                                           data_type, months, tranform_monthly)
318 | 
--------------------------------------------------------------------------------
/weatherbench_data/__init__.py:
--------------------------------------------------------------------------------
1 | """Defines dataset and dataloader creation functionalities.
2 | """
3 | import logging
4 | 
5 | import torch
6 | from torch.nn.functional import interpolate
7 | from torch.utils.data import Dataset, DataLoader
8 | 
9 | from .datasets import WeatherBenchData, TimeVariateData
10 | from .datastorage import WeatherBenchNPYStorage
11 | from .fileconverter import DATETIME_FORMAT
12 | from .transforms import Transform
13 | from .utils import log_dataset_info, prepare_datasets
14 | 
15 | 
16 | def create_datasets(dataroot: str, name: str, train_min_date: str, train_max_date: str,
17 |                     val_min_date: str, val_max_date: str, variables: list, transformation: Transform,
18 |                     storage_root: str, apply_tranform_monthly: bool = True):
19 |     """Creates transformed datasets.
20 | 
21 |     Args:
22 |         dataroot: Path to the dataset.
23 |         name: The name of the dataset.
24 |         train_min_date: Minimum date starting from which to read the data for training.
25 |         train_max_date: Maximum date until which to read the data for training.
26 |         val_min_date: Minimum date starting from which to read the data for validation.
27 |         val_max_date: Maximum date until which to read the data for validation.
28 |         variables: A list of WeatherBench variables.
29 |         transformation: A transformation to fit.
30 |         storage_root: A path to save metadata and fitted transformations.
31 |         apply_tranform_monthly: Whether to apply transformation monthly or on the whole dataset.
32 | 
33 |     Returns:
34 |         Training and validation datasets (already transformed), metadata containing longitude and
35 |         latitude information for LR/HR data and monthly fitted transformations for each variable.
36 |     """
37 |     train_dataset, val_dataset, metadata, transformations = prepare_datasets(variables, train_min_date,
38 |                                                                              train_max_date, val_min_date,
39 |                                                                              val_max_date, dataroot,
40 |                                                                              transformation, storage_root,
41 |                                                                              apply_tranform_monthly)
42 |     logger = logging.getLogger("base")
43 |     log_dataset_info(train_dataset, f"Train {name}", logger)
44 |     log_dataset_info(val_dataset, f"Validation {name}", logger)
45 |     logger.info("Finished.\n")
46 |     return train_dataset, val_dataset, metadata, transformations
47 | 
48 | 
49 | def collate_wb_batch(samples: list):
50 |     """Processes a list of samples to form a batch.
51 | 
52 |     Args:
53 |         samples: A list of data points.
54 | 
55 |     Returns:
56 |         A tuple of a dictionary with the following items:
57 |             LR – a low-resolution tensor,
58 |             HR – a high-resolution tensor,
59 |             INTERPOLATED – an upsampled low-resolution tensor with bicubic interpolation,
60 |         and a list of month indices corresponding to each sample.
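
    Note:
        INTERPOLATED is computed below with torch.nn.functional.interpolate
        (mode="bicubic", scale_factor=4), so its spatial dimensions are four
        times those of LR, i.e. it lies on the HR grid.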
61 | """ 62 | lr_tensors, hr_tensors, months = [], [], [] 63 | for lr, hr in samples: 64 | lr_tensors.append(torch.cat([variable[0] for variable in lr], dim=1)) 65 | hr_tensors.append(torch.cat([variable[0] for variable in hr], dim=1)) 66 | months.append(lr[0][2]) 67 | fake_tensors = [interpolate(tensor, scale_factor=4, mode="bicubic") for tensor in lr_tensors] 68 | return {"LR": torch.cat(lr_tensors), 69 | "HR": torch.cat(hr_tensors), 70 | "INTERPOLATED": torch.cat(fake_tensors)}, months 71 | 72 | 73 | def create_dataloaders(train_dataset: Dataset, val_dataset: Dataset, batch_size: int, 74 | use_shuffle: bool = True, num_workers: int = None): 75 | """Creates train/val dataloaders. 76 | 77 | Args: 78 | train_dataset: The training dataset. 79 | val_dataset: The validation dataset. 80 | batch_size: The size of a batch. 81 | use_shuffle: Either shuffle the training data or not. 82 | num_workers: The number of processes for multi-process data loading. 83 | 84 | Returns: 85 | Training and validations dataloaders. 86 | """ 87 | train_loader = DataLoader(train_dataset, 88 | batch_size=batch_size, 89 | collate_fn=collate_wb_batch, 90 | shuffle=use_shuffle, 91 | pin_memory=True, 92 | drop_last=True, 93 | num_workers=num_workers) 94 | 95 | validation_loader = DataLoader(val_dataset, 96 | batch_size=32, 97 | collate_fn=collate_wb_batch, 98 | pin_memory=True, 99 | drop_last=True, 100 | num_workers=num_workers) 101 | 102 | return train_loader, validation_loader 103 | -------------------------------------------------------------------------------- /weatherbench_data/datasets.py: -------------------------------------------------------------------------------- 1 | """Defines core classes to read WeatherBench data. 2 | """ 3 | 4 | from collections import OrderedDict 5 | from datetime import datetime 6 | 7 | import numpy as np 8 | import torch 9 | from torch.utils.data import Dataset 10 | 11 | from .datastorage import WeatherBenchNPYStorage 12 | from .fileconverter import TEMPORAL_RESOLUTION, DATETIME_FORMAT 13 | 14 | 15 | def _parse_date_input(date_input, datetime_format=None): 16 | if date_input is None: 17 | return None 18 | input_type = type(date_input) 19 | if input_type == np.datetime64: 20 | return date_input 21 | elif input_type == datetime: 22 | return np.datetime64(date_input) 23 | elif input_type == str: 24 | if datetime_format is None: 25 | datetime_format = DATETIME_FORMAT 26 | try: 27 | date = datetime.strptime(date_input, datetime_format) 28 | except Exception: 29 | raise Exception( 30 | "[ERROR] Encountered invalid date string input (input: {}, datetime format: {}).".format( 31 | date_input, datetime_format 32 | ) 33 | ) 34 | return np.datetime64(date) 35 | else: 36 | raise Exception("[ERROR] Encountered invalid date input.") 37 | 38 | 39 | def _verify_date_bounds(min_date, max_date): 40 | assert (isinstance(min_date, np.datetime64) or min_date is None) and (isinstance(max_date, np.datetime64) or max_date is None), \ 41 | "[ERROR] Date bounds must be given as numpy.datetime64 objects." 
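    # The modulo checks below verify that each bound lies on the data's hourly
    # grid: the offset from an arbitrary whole-hour reference (2020-01-01T00)
    # must be an exact multiple of TEMPORAL_RESOLUTION.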
42 | if min_date is not None: 43 | assert (min_date - np.datetime64("2020-01-01T00")) % TEMPORAL_RESOLUTION == np.timedelta64(0, "ms"), \ 44 | "[ERROR] Date bounds must be consistent with the temporal resolution of the data set ({}).".format( 45 | TEMPORAL_RESOLUTION 46 | ) 47 | if max_date is not None: 48 | assert (max_date - np.datetime64("2020-01-01T00")) % TEMPORAL_RESOLUTION == np.timedelta64(0, "ms"), \ 49 | "[ERROR] Date bounds must be consistent with the temporal resolution of the data set ({}).".format( 50 | TEMPORAL_RESOLUTION 51 | ) 52 | if min_date is not None and max_date is not None: 53 | assert max_date > min_date, "[ERROR] Lower date bound ({}) must be earlier than upper ({}).".format(min_date, max_date) 54 | 55 | 56 | class DefaultIdentityMapping(dict): 57 | def __missing__(self, key): 58 | return lambda x: x 59 | 60 | 61 | class TimeVariateData(Dataset): 62 | 63 | def __init__(self, source, name=None, lead_time=None, delays=None, min_date=None, 64 | max_date=None, transform: dict = None): 65 | assert isinstance(source, WeatherBenchNPYStorage) 66 | assert source.is_time_variate() 67 | self.name = name if name is not None else source.name 68 | if name is not None: 69 | assert isinstance(name, str) 70 | self.source = source 71 | self._lead_time = TEMPORAL_RESOLUTION * lead_time if lead_time is not None else None 72 | if delays is not None: 73 | assert isinstance(delays, list), "[ERROR] Delay parameter must be given as list." 74 | for d in delays: 75 | assert isinstance(d, int), "[ERROR] Delay parameter must be given as list of ints." 76 | if 0 not in delays: 77 | delays = [0] + delays 78 | delays = np.array(delays) 79 | assert len(delays) == len(np.unique(delays)), "[ERROR] Delays must be unique." 80 | self._delays = TEMPORAL_RESOLUTION * delays 81 | else: 82 | self._delays = None 83 | self.min_date = None 84 | self.max_date = None 85 | self._sample_index = None 86 | self.set_date_range(min_date, max_date) 87 | self._fitting_mode = False 88 | self._transform = transform if transform else DefaultIdentityMapping() 89 | 90 | def set_date_range(self, min_date=None, max_date=None, datetime_format=None): 91 | min_date = _parse_date_input(min_date, datetime_format) 92 | max_date = _parse_date_input(max_date, datetime_format) 93 | _verify_date_bounds(min_date, max_date) 94 | source_time_stamps = self.source.get_valid_time_stamps() 95 | source_min_date = np.min(source_time_stamps) 96 | source_max_date = np.max(source_time_stamps) + TEMPORAL_RESOLUTION 97 | admissible_min_date = source_min_date 98 | admissible_max_date = source_max_date 99 | if self._lead_time is not None: 100 | admissible_min_date = admissible_min_date - self._lead_time 101 | admissible_max_date = admissible_max_date - self._lead_time 102 | if self._delays is not None: 103 | admissible_min_date = admissible_min_date - np.min(self._delays) 104 | admissible_max_date = admissible_max_date - np.max(self._delays) 105 | if min_date is None: 106 | self.min_date = admissible_min_date 107 | else: 108 | assert min_date >= admissible_min_date, \ 109 | "[ERROR] Requested minimum date ({}) is beyond the range of admissible dates ([{}] – [{}]).".format( 110 | min_date, admissible_min_date, admissible_max_date 111 | ) 112 | self.min_date = min_date 113 | if max_date is None: 114 | self.max_date = admissible_max_date 115 | else: 116 | assert max_date <= admissible_max_date, \ 117 | "[ERROR] Requested maximum date ({}) is beyond the range of admissible dates ([{}] – [{}]).".format( 118 | max_date, admissible_min_date, 
admissible_max_date 119 | ) 120 | self.max_date = max_date 121 | self._build_sample_index() 122 | return self 123 | 124 | def _build_sample_index(self): 125 | valid_samples = np.arange(self.min_date, self.max_date, TEMPORAL_RESOLUTION) 126 | self._sample_index = {i: time_stamp for i, time_stamp in enumerate(valid_samples)} 127 | 128 | def set_transform(self, transform: dict): 129 | self._transform = transform 130 | 131 | def __getitem__(self, item): 132 | time_stamp = self._sample_index[item] 133 | month = int(time_stamp.astype("datetime64[M]").astype(int) % 12 + 1) 134 | 135 | if month not in self._transform: 136 | month = 0 137 | 138 | if self._lead_time is not None: 139 | time_stamp = time_stamp + self._lead_time 140 | if self._fitting_mode or self._delays is None: 141 | return self._transform[month](self.source[time_stamp]), self.name, month 142 | else: 143 | return tuple((self._transform[month](self.source[delayed_time]), self.name, month) 144 | for delayed_time in (time_stamp + self._delays)) 145 | 146 | def __len__(self): 147 | return len(self._sample_index) 148 | 149 | def get_channel_count(self): 150 | source_channels = self.source.get_channel_count() 151 | if self._delays is not None: 152 | return len(self._delays) * source_channels 153 | else: 154 | return source_channels 155 | 156 | @staticmethod 157 | def _generate_batches(data, data_length: int, chunk_size: int = 50000): 158 | for start in range(0, data_length, chunk_size): 159 | yield [next(data) for _ in range(start, min(data_length, start+chunk_size))] 160 | 161 | def get_batch(self, indices): 162 | data = (self.__getitem__(i) for i in indices) 163 | for chunk_of_data in self._generate_batches(data, len(indices)): 164 | if self._delays is not None: 165 | yield tuple(torch.cat(d[0], dim=0) for d in chunk_of_data) 166 | else: 167 | yield torch.cat([tup[0] for tup in chunk_of_data], dim=0) 168 | 169 | def enable_fitting_mode(self): 170 | return self.set_fitting_mode(True) 171 | 172 | def disable_fitting_mode(self): 173 | return self.set_fitting_mode(False) 174 | 175 | def set_fitting_mode(self, mode): 176 | assert isinstance(mode, bool) 177 | self._fitting_mode = mode 178 | return self 179 | 180 | def get_fitting_mode(self): 181 | return self._fitting_mode 182 | 183 | @staticmethod 184 | def is_time_variate(): 185 | return True 186 | 187 | def summarize(self): 188 | return { 189 | "data_type": "TimeVariateData", 190 | "path": self.source.path, 191 | "date_range": [ 192 | self._numpy_date_to_datetime(self.min_date).strftime(DATETIME_FORMAT), 193 | self._numpy_date_to_datetime(self.max_date).strftime(DATETIME_FORMAT) 194 | ], 195 | "lead_time": self._lead_time, 196 | "delays": self._delays, 197 | } 198 | 199 | @staticmethod 200 | def _numpy_date_to_datetime(time_stamp): 201 | total_seconds = (time_stamp - np.datetime64("1970-01-01T00:00:00Z")) / np.timedelta64(1, "s") 202 | return datetime.utcfromtimestamp(total_seconds) 203 | 204 | def get_valid_time_stamps(self): 205 | return sorted(self._sample_index.values()) 206 | 207 | 208 | class ConstantData(Dataset): 209 | def __init__(self, source, name=None, min_date=None, max_date=None, datetime_format=None): 210 | assert isinstance(source, WeatherBenchNPYStorage) 211 | assert not source.is_time_variate() 212 | if name is not None: 213 | assert isinstance(name, str) 214 | self.name = name if name is not None else source.name 215 | self.source = source 216 | min_date = _parse_date_input(min_date, datetime_format) 217 | max_date = _parse_date_input(max_date, datetime_format) 218 | 
_verify_date_bounds(min_date, max_date) 219 | self.min_date = min_date 220 | self.max_date = max_date 221 | self._num_samples = None 222 | self.set_date_range(min_date, max_date) 223 | self._fitting_mode = False 224 | 225 | def __getitem__(self, item): 226 | if item < self._num_samples: 227 | return self.source[item] 228 | else: 229 | raise Exception("[ERROR] Requested item ({}) is beyond the range of valid item numbers ([0, {}]).".format(item, self._num_samples)) 230 | 231 | def __len__(self): 232 | return self._num_samples 233 | 234 | def set_date_range(self, min_date=None, max_date=None, datetime_format=None): 235 | min_date = _parse_date_input(min_date, datetime_format) 236 | max_date = _parse_date_input(max_date, datetime_format) 237 | _verify_date_bounds(min_date, max_date) 238 | self.min_date = min_date 239 | self.max_date = max_date 240 | if min_date is None or max_date is None: 241 | self._num_samples = 1 242 | else: 243 | self._num_samples = int((max_date - min_date) / TEMPORAL_RESOLUTION) 244 | return self 245 | 246 | def get_channel_count(self): 247 | return self.source.get_channel_count() 248 | 249 | def enable_fitting_mode(self): 250 | return self.set_fitting_mode(True) 251 | 252 | def disable_fitting_mode(self): 253 | return self.set_fitting_mode(False) 254 | 255 | def set_fitting_mode(self, mode): 256 | assert isinstance(mode, bool) 257 | self._fitting_mode = mode 258 | return self 259 | 260 | def get_fitting_mode(self): 261 | return self._fitting_mode 262 | 263 | @staticmethod 264 | def is_time_variate(): 265 | return False 266 | 267 | def summarize(self): 268 | return { 269 | "data_type": "ConstantData", 270 | "path": self.source.path 271 | } 272 | 273 | 274 | class WeatherBenchData(Dataset): 275 | def __init__(self, min_date=None, max_date=None, datetime_format=None): 276 | min_date = _parse_date_input(min_date, datetime_format) 277 | max_date = _parse_date_input(max_date, datetime_format) 278 | _verify_date_bounds(min_date, max_date) 279 | self.min_date = min_date 280 | self.max_date = max_date 281 | self.data_groups = OrderedDict({}) 282 | 283 | def add_data_group(self, group_key, datasets, _except_on_changing_date_bounds=False): 284 | self._verify_data_group_inputs(group_key, datasets) 285 | min_dates = [dataset.min_date for dataset in datasets if dataset.min_date is not None] 286 | if len(min_dates) > 0: 287 | common_min_date = np.max(min_dates) 288 | else: 289 | common_min_date = None 290 | if _except_on_changing_date_bounds: 291 | assert common_min_date == self.min_date, "[ERROR] Encountered missing time stamps." 292 | else: 293 | if (common_min_date is not None) and (self.min_date is None or common_min_date > self.min_date): 294 | self.min_date = common_min_date 295 | max_dates = [dataset.max_date for dataset in datasets if dataset.max_date is not None] 296 | if len(max_dates) > 0: 297 | common_max_date = np.min(max_dates) 298 | else: 299 | common_max_date = None 300 | if _except_on_changing_date_bounds: 301 | assert common_max_date == self.max_date, "[ERROR] Encountered missing time stamps." 302 | else: 303 | if (common_max_date is not None) and (self.max_date is None or common_max_date < self.max_date): 304 | self.max_date = common_max_date 305 | self.data_groups.update({group_key: {dataset.name: dataset for dataset in datasets}}) 306 | self._update_date_bounds() 307 | return self 308 | 309 | def _verify_data_group_inputs(self, group_key, datasets): 310 | assert isinstance(group_key, str), "[ERROR] Group keys must be of type string." 
311 |         assert group_key not in self.data_groups, "[ERROR] Group keys must be unique. Key <{}> already exists.".format(
312 |             group_key)
313 |         if not isinstance(datasets, list):
314 |             datasets = [datasets]
315 |         for dataset in datasets:
316 |             assert isinstance(dataset, (ConstantData, TimeVariateData)), \
317 |                 "[ERROR] Datasets must be given as TimeVariateData or ConstantData objects or a list thereof."
318 |         data_names = [dataset.name for dataset in datasets]
319 |         assert len(data_names) == len(
320 |             np.unique(data_names)), "[ERROR] Dataset names must be unique within a common parameter group."
321 | 
322 |     def remove_data_group(self, group_key):
323 |         if group_key in self.data_groups:
324 |             del self.data_groups[group_key]
325 |         return self
326 | 
327 |     def _update_date_bounds(self):
328 |         for group in self.data_groups.values():
329 |             for dataset in group.values():
330 |                 dataset.set_date_range(self.min_date, self.max_date)
331 | 
332 |     def set_date_range(self, min_date=None, max_date=None, datetime_format=None):
333 |         min_date = _parse_date_input(min_date, datetime_format)
334 |         max_date = _parse_date_input(max_date, datetime_format)
335 |         _verify_date_bounds(min_date, max_date)
336 |         self.min_date = min_date
337 |         self.max_date = max_date
338 |         self._update_date_bounds()
339 |         return self
340 | 
341 |     def __len__(self):
342 |         if self.min_date is None or self.max_date is None:
343 |             return 1 if len(self.data_groups) > 0 else 0
344 |         else:
345 |             return int((self.max_date - self.min_date) / TEMPORAL_RESOLUTION) if len(self.data_groups) else 0
346 | 
347 |     def __getitem__(self, item):
348 |         # dataset[item][0] is the tensor data.
349 |         # dataset[item][1] is the name of the variable.
350 |         # dataset[item][2] is the integer month index (from 1 for January to 12 for December).
351 |         return tuple(tuple(dataset[item] for dataset in group.values()) for group in self.data_groups.values())
352 | 
353 |     def get_data_names(self):
354 |         return {
355 |             group_key: tuple(dataset.name for dataset in group.values())
356 |             for group_key, group in self.data_groups.items()
357 |         }
358 | 
359 |     def get_named_item(self, item):
360 |         return {
361 |             group_key: {dataset.name: dataset[item] for dataset in group.values()}
362 |             for group_key, group in self.data_groups.items()
363 |         }
364 | 
365 |     def get_channel_count(self, group_key=None):
366 |         if group_key is None:
367 |             return {gk: self.get_channel_count(group_key=gk) for gk in self.data_groups}
368 |         elif group_key in self.data_groups:
369 |             return np.sum([dataset.get_channel_count() for dataset in self.data_groups[group_key].values()])
370 |         else:
371 |             raise Exception("[ERROR] Dataset does not contain a data group named <{}>.".format(group_key))
372 | 
373 |     def get_valid_time_stamps(self):
374 |         return np.arange(self.min_date, self.max_date, TEMPORAL_RESOLUTION)
375 | 
--------------------------------------------------------------------------------
/weatherbench_data/datastorage.py:
--------------------------------------------------------------------------------
1 | """Defines WeatherBenchNPYStorage class to read data in npy format.
2 | 
3 | # TODO: Test the script.
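
Expected per-variable directory layout (as produced by fileconverter.py):
    <variable>/meta/metadata.json
    <variable>/samples/<year>/<YYYY-mm-dd-HH>.npy   (time-variate data)
    <variable>/samples/constant.npy                 (constant data)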
4 | """ 5 | import json 6 | import os 7 | from datetime import datetime 8 | from itertools import chain 9 | 10 | import numpy as np 11 | import torch 12 | 13 | from .fileconverter import DATETIME_FORMAT, TEMPORAL_RESOLUTION, \ 14 | DIRECTORY_NAME_META_DATA, DIRECTORY_NAME_SAMPLE_DATA, FILE_NAME_META_DATA 15 | 16 | 17 | class WeatherBenchNPYStorage(object): 18 | 19 | def __init__(self, path, domain_dimension=2): 20 | self._verify_path(path) 21 | self.path = os.path.abspath(path) 22 | self.domain_dimension=domain_dimension 23 | self.meta_data = None 24 | self._load_meta_data() 25 | assert len(self.meta_data["dims"]) >= domain_dimension 26 | self.name = self.meta_data["name"] 27 | self._is_time_variate = self.meta_data["time_variate"] 28 | self._samples = None 29 | self._read_sample_directory() 30 | 31 | @staticmethod 32 | def _verify_path(path): 33 | assert os.path.isdir(path), "[ERROR] <{}> is not a valid directory path.".format(path) 34 | contents = os.listdir(path) 35 | assert len(contents) == 2 and os.path.isdir(os.path.join(path, DIRECTORY_NAME_META_DATA)) and os.path.isdir(os.path.join(path, DIRECTORY_NAME_SAMPLE_DATA)),\ 36 | "[ERROR] <{}> does not follow the expected folder structure of a WeatherBench parameter directory.".format(path) 37 | 38 | def _load_meta_data(self): 39 | # load meta data file 40 | with open(os.path.join(self.path, DIRECTORY_NAME_META_DATA, FILE_NAME_META_DATA + ".json"), "r") as f: 41 | self.meta_data = json.load(f) 42 | coordinates = self.meta_data["coords"] 43 | # convert coordinate value lists to numpy arrays 44 | for c in coordinates: 45 | c.update({"values": np.array(c["values"])}) 46 | 47 | def _read_sample_directory(self): 48 | sample_directory = os.path.join(self.path, DIRECTORY_NAME_SAMPLE_DATA) 49 | if self._is_time_variate: 50 | sample_time_stamps = self._build_sample_index(sample_directory) 51 | self._verify_data_completeness(sample_time_stamps) 52 | else: 53 | self._load_constant_data(sample_directory) 54 | 55 | def _build_sample_index(self, sample_directory): 56 | sub_directories = [ 57 | d for d in sorted(os.listdir(sample_directory)) 58 | if os.path.isdir(os.path.join(sample_directory, d)) 59 | ] 60 | samples = [] 61 | time_stamps = [] 62 | for sub_directory in sub_directories: 63 | contents_s = [] 64 | contents_t = [] 65 | for f in sorted(os.listdir(os.path.join(sample_directory, sub_directory))): 66 | if self._matches_sample_file_convention(f): 67 | contents_s.append(os.path.join(sample_directory, sub_directory, f)) 68 | contents_t.append(self._file_name_to_datetime(f)) 69 | samples.append(contents_s) 70 | time_stamps.append(contents_t) 71 | samples = np.array(list(chain.from_iterable(samples))) 72 | time_stamps = np.array(list(chain.from_iterable(time_stamps))) 73 | sorting_index = np.argsort(time_stamps) 74 | self._samples = (time_stamps[0], samples[sorting_index]) 75 | return time_stamps[sorting_index] 76 | 77 | @staticmethod 78 | def _verify_data_completeness(sample_time_stamps): 79 | # verify that the data covers a comprehensive range of time stamps 80 | min_date = sample_time_stamps[0] 81 | max_date = sample_time_stamps[-1] 82 | assert len(sample_time_stamps) == int((max_date - min_date) / TEMPORAL_RESOLUTION) + 1, \ 83 | "[ERROR] encountered missing data values." 
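        # Together with the check below, this guarantees a gapless hourly series:
        # the stamp count matches the covered span and consecutive stamps differ
        # by exactly TEMPORAL_RESOLUTION.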
84 | assert np.all(np.diff(sample_time_stamps) == TEMPORAL_RESOLUTION) 85 | 86 | def _matches_sample_file_convention(self, f): 87 | if not f.endswith(".npy"): 88 | return False 89 | f_split = f.split(".") 90 | if len(f_split) > 2: 91 | return False 92 | try: 93 | date = self._file_name_to_datetime(f) 94 | except: 95 | return False 96 | return True 97 | 98 | @staticmethod 99 | def _file_name_to_datetime(f): 100 | return np.datetime64(datetime.strptime(f.split(".")[0], DATETIME_FORMAT)) 101 | 102 | def _load_constant_data(self, sample_directory): 103 | data = torch.tensor(np.load(os.path.join(sample_directory, "constant.npy"))) 104 | self._samples = self._to_pytorch_standard_shape(data) 105 | 106 | def _to_pytorch_standard_shape(self, data): 107 | dim = len(data.shape) 108 | domain_dim = self.domain_dimension 109 | # care for channel dimensions 110 | if dim == domain_dim: 111 | data = data.unsqueeze(dim=0) 112 | elif dim > domain_dim + 1: 113 | data = torch.flatten(data, start_dim=0, end_dim=-(domain_dim + 1)) 114 | # add batch (time) dimension 115 | return data.unsqueeze(dim=0) 116 | 117 | def __len__(self): 118 | if self._is_time_variate: 119 | return len(self._samples[1]) 120 | else: 121 | return 1 122 | 123 | def __getitem__(self, item): 124 | if self._is_time_variate: 125 | idx = int((item - self._samples[0]) / TEMPORAL_RESOLUTION) 126 | data = torch.tensor(np.load(self._samples[1][idx])) 127 | return self._to_pytorch_standard_shape(data) 128 | else: 129 | return self._samples 130 | 131 | def get_valid_time_stamps(self): 132 | if self._is_time_variate: 133 | min_date = self._samples[0] 134 | return np.arange(min_date, min_date + len(self._samples[1]) * TEMPORAL_RESOLUTION, TEMPORAL_RESOLUTION) 135 | else: 136 | return None 137 | 138 | def is_time_variate(self): 139 | return self._is_time_variate 140 | 141 | def get_channel_count(self): 142 | count = 1 143 | for axis_length in self.meta_data["shape"][0:-self.domain_dimension]: 144 | count = count * axis_length 145 | return int(count) 146 | -------------------------------------------------------------------------------- /weatherbench_data/fileconverter.py: -------------------------------------------------------------------------------- 1 | """Defines NetCDF to NumPy convertor for WeatherBench data. 
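
Each NetCDF variable is converted into one .npy file per time stamp plus a JSON
metadata file, mirroring the directory layout expected by WeatherBenchNPYStorage.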
2 | """
3 | import json
4 | import os
5 | import shutil
6 | from datetime import datetime
7 | 
8 | import numpy as np
9 | import xarray as xr
10 | 
11 | DATETIME_FORMAT = "%Y-%m-%d-%H"
12 | TEMPORAL_RESOLUTION = np.timedelta64(1, "h")
13 | DIRECTORY_NAME_META_DATA = "meta"
14 | FILE_NAME_META_DATA = "metadata"
15 | FILE_NAME_CONSTANT_DATA = "constant"
16 | DIRECTORY_NAME_SAMPLE_DATA = "samples"
17 | 
18 | 
19 | class NumpyEncoder(json.JSONEncoder):
20 |     def default(self, obj):
21 |         if isinstance(obj, np.integer):
22 |             return int(obj)
23 |         elif isinstance(obj, np.floating):
24 |             return float(obj)
25 |         elif isinstance(obj, np.ndarray):
26 |             return obj.tolist()
27 |         else:
28 |             return super(NumpyEncoder, self).default(obj)
29 | 
30 | 
31 | class NetCDFNumpyConverter(object):
32 | 
33 |     def __init__(self, netcdf_extension=".nc", numpy_extension=".npy", datetime_format=DATETIME_FORMAT):
34 |         self.NETCDF_EXTENSION = netcdf_extension
35 |         self.NUMPY_EXTENSION = numpy_extension
36 |         self.DATETIME_FORMAT = datetime_format
37 |         self.source_directory = None
38 |         self.data = None
39 | 
40 |     def set_source_directory(self, source_directory):
41 |         self.source_directory = os.path.abspath(source_directory)
42 |         return self
43 | 
44 |     def read_source_directory(self, source_directory=None, chunks=None, parallel=True):
45 |         if self.source_directory is None and source_directory is None:
46 |             raise Exception("[ERROR] Source directory must be set or given before it can be read.")
47 |         elif source_directory is not None:
48 |             self.set_source_directory(source_directory)
49 |         data = xr.open_mfdataset(
50 |             os.path.join(self.source_directory, "*" + self.NETCDF_EXTENSION), parallel=parallel, chunks=chunks
51 |         )
52 |         if chunks is None and "time" in data.dims:
53 |             data = xr.open_mfdataset(
54 |                 os.path.join(self.source_directory, "*" + self.NETCDF_EXTENSION), parallel=parallel,
55 |                 chunks={"time": 12}
56 |             )
57 |         self.data = data
58 |         return self
59 | 
60 |     def convert_to_pytorch_samples(self, target_directory, enable_batch_processing=False, batch_size=None, rename_vars=None, overwrite_previous=False):
61 |         if self.source_directory is None:
62 |             raise Exception("[ERROR] Source directory must be set and read before running the conversion.")
63 |         if self.data is None:
64 |             raise Exception("[ERROR] Source directory must be read before running the conversion.")
65 |         print("[INFO] Converting NetCDF dataset at <{}> to PyTorch sample files.".format(self.source_directory))
66 |         if enable_batch_processing:
67 |             assert batch_size is not None, "[ERROR] If batch-processing is enabled, a batch size must be given."
68 |         else:
69 |             batch_size = 0
70 |         if rename_vars is None:
71 |             rename_vars = {}  # Use rename_vars for renaming the variable folders upon conversion.
72 |         else:
73 |             assert isinstance(rename_vars, dict)
74 |         data_vars = self.data.data_vars
75 |         if len(data_vars) == 0:
76 |             print("[INFO] Selected data set did not contain any variables.
No further actions required.") 77 | return 78 | target_directory = os.path.abspath(target_directory) 79 | if not os.path.isdir(target_directory): 80 | os.makedirs(target_directory) 81 | print("[INFO] Created target directory at <{}>.".format(target_directory)) 82 | for var_key in data_vars: 83 | print("[INFO] Processing data variable <{}>.".format(var_key)) 84 | data_var = data_vars[var_key] 85 | meta_folder, samples_folder = self._create_new_var_directory(target_directory, var_key, rename_vars, overwrite_previous) 86 | self._convert_meta_data(data_var, meta_folder) 87 | self._convert_sample_data(data_var, samples_folder, batch_size) 88 | 89 | def _create_new_var_directory(self, target_directory, var_key, rename_vars, overwrite_previous): 90 | directory_name = var_key if var_key not in rename_vars else rename_vars[var_key] 91 | var_directory = os.path.join(target_directory, directory_name) 92 | if os.path.isdir(var_directory): 93 | if len(os.listdir(var_directory)) > 0 and not overwrite_previous: 94 | raise Exception("[ERROR] Tried to create variable directory at <{}> but directory existed and was found to be not empty.") 95 | else: 96 | print("[INFO] Removing previously existing variable directory.") 97 | shutil.rmtree(var_directory, ignore_errors=True) 98 | os.makedirs(var_directory) 99 | print("[INFO] Created new variable directory at <{}>.".format(var_directory)) 100 | sub_directories = [] 101 | for folder_name in ["meta", "samples"]: 102 | sub_dir = os.path.join(var_directory, folder_name) 103 | os.makedirs(sub_dir) 104 | sub_directories.append(sub_dir) 105 | return tuple(sub_directories) 106 | 107 | def _convert_meta_data(self, data_var, meta_folder): 108 | print("[INFO] Reading meta data.") 109 | meta_data = {} 110 | meta_data.update({"name": data_var.name}) 111 | meta_data.update({"time_variate": "time" in list(data_var.dims)}) 112 | meta_data.update({"dims": [dim_name for dim_name in data_var.dims if dim_name != "time"]}) 113 | meta_data.update({"shape": [dim_length for dim_name, dim_length in zip(data_var.dims, data_var.data.shape) if dim_name != "time"]}) 114 | meta_data.update({"coords": []}) 115 | data_coords = self.data.coords 116 | for coord_key in data_coords: 117 | if coord_key != "time": 118 | axis = data_coords[coord_key] 119 | meta_data["coords"].append({ 120 | "name": axis.name, 121 | "values": axis.values.tolist(), 122 | "dims": list(axis.dims) 123 | }) 124 | meta_data.update({"attrs": {**self.data.attrs, **data_var.attrs}}) 125 | meta_data_file = os.path.join(meta_folder, FILE_NAME_META_DATA + ".json") 126 | with open(meta_data_file, "w") as f: 127 | json.dump(meta_data, f) 128 | print("[INFO] Stored meta data in <{}>.".format(meta_data_file)) 129 | 130 | def _convert_sample_data(self, data_var, samples_folder, batch_size): 131 | if "time" in data_var.dims: 132 | self._convert_temporal_samples(data_var, samples_folder, batch_size) 133 | else: 134 | self._convert_constant(data_var, samples_folder) 135 | 136 | def _convert_temporal_samples(self, data_var, samples_folder, batch_size): 137 | print("[INFO] Converting temporal samples.") 138 | time_stamps = data_var["time"].values 139 | time_axis = tuple(data_var.dims).index("time") 140 | assert len(time_stamps) == len(np.unique(time_stamps)), "[ERROR] Encountered data variable with non-unique time stamps." 
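        # The time axis is split into chunks of roughly batch_size stamps to bound
        # memory use; each stamp is then saved as its own .npy file, grouped into
        # one sub-directory per year.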
141 | batches = np.array_split(time_stamps, np.ceil(len(time_stamps) / batch_size)) 142 | current_year = None 143 | storage_folder = None 144 | for sample_batch in batches: 145 | batch_data = np.array_split(data_var.sel(time=sample_batch).values, len(sample_batch), axis=time_axis) 146 | for time_stamp, data in zip(sample_batch, batch_data): 147 | time_stamp = self._numpy_date_to_datetime(time_stamp) 148 | if time_stamp.year != current_year: 149 | current_year = time_stamp.year 150 | storage_folder = os.path.join(samples_folder, "{}".format(current_year)) 151 | if not os.path.isdir(storage_folder): 152 | os.makedirs(storage_folder) 153 | np.save( 154 | os.path.join(storage_folder, self._file_name_from_time_stamp(time_stamp)), 155 | np.squeeze(data, axis=time_axis) 156 | ) 157 | 158 | def _numpy_date_to_datetime(self, time_stamp): 159 | total_seconds = (time_stamp - np.datetime64("1970-01-01T00:00:00Z")) / np.timedelta64(1, "s") 160 | return datetime.utcfromtimestamp(total_seconds) 161 | 162 | def _file_name_from_time_stamp(self, time_stamp): 163 | return time_stamp.strftime(self.DATETIME_FORMAT) + self.NUMPY_EXTENSION 164 | 165 | def _convert_constant(self, data_var, samples_folder): 166 | data = data_var.values 167 | np.save( 168 | os.path.join(samples_folder, FILE_NAME_CONSTANT_DATA + self.NUMPY_EXTENSION), 169 | data 170 | ) 171 | -------------------------------------------------------------------------------- /weatherbench_data/transforms.py: -------------------------------------------------------------------------------- 1 | """Defines various transformations for WeatherBench data. 2 | """ 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | class Transform(nn.Module): 9 | def __init__(self, requires_fit, exclude_at_evaluation=False): 10 | super(Transform, self).__init__() 11 | self.requires_fit = requires_fit 12 | self.exclude_at_evaluation = exclude_at_evaluation 13 | 14 | def transform(self, data): 15 | raise NotImplementedError() 16 | 17 | def out_channels(self, in_channels): 18 | return in_channels 19 | 20 | def forward(self, data): 21 | return self.transform(data) 22 | 23 | def is_data_adaptive(self): 24 | return self.requires_fit 25 | 26 | def summarize(self): 27 | return {"transform_type": self.__class__.__name__} 28 | 29 | 30 | class ReversibleTransform(Transform): 31 | def __init__(self, requires_fit, exclude_at_evaluation=False): 32 | super(ReversibleTransform, self).__init__( 33 | requires_fit=requires_fit, exclude_at_evaluation=exclude_at_evaluation 34 | ) 35 | 36 | def transform(self, data): 37 | raise NotImplementedError() 38 | 39 | def revert(self, data): 40 | raise NotImplementedError() 41 | 42 | 43 | class AdaptiveReversibleTransform(ReversibleTransform): 44 | def __init__(self, exclude_at_evaluation=False): 45 | super(AdaptiveReversibleTransform, self).__init__( 46 | requires_fit=True, exclude_at_evaluation=exclude_at_evaluation 47 | ) 48 | self._data_source = None 49 | 50 | def fit(self, dataset, batch_size=None, previous_transforms=None, disable_fitting_mode=False): 51 | if self._data_source is not None: 52 | raise Exception("[ERROR] Fit should only be called once on adaptive transform objects.") 53 | if previous_transforms is not None: 54 | assert isinstance(previous_transforms, list) 55 | for t in previous_transforms: 56 | assert isinstance(t, Transform) 57 | if not dataset.is_time_variate(): 58 | self._fit_to_batch(dataset, [0], previous_transforms) 59 | else: 60 | in_fitting_mode = dataset.get_fitting_mode() 61 | if in_fitting_mode != 
disable_fitting_mode: 62 | dataset.set_fitting_mode(disable_fitting_mode) 63 | if batch_size is None: 64 | self._fit_to_batch(dataset, np.arange(len(dataset)), previous_transforms) 65 | else: 66 | assert isinstance(batch_size, int) 67 | idx = np.arange(len(dataset)) 68 | batches = np.array_split(idx, np.ceil(len(idx) / batch_size)) 69 | for idx_batch in batches: 70 | self._fit_to_batch(dataset, idx_batch, previous_transforms) 71 | dataset.set_fitting_mode(in_fitting_mode) 72 | 73 | self._fill_data_source(dataset, previous_transforms) 74 | return self 75 | 76 | def _fit_to_batch(self, dataset, batch, previous_transforms): 77 | for data in dataset.get_batch(batch): 78 | if previous_transforms is not None: 79 | for t in previous_transforms: 80 | data = t.transform(data) 81 | self._update_parameters(data) 82 | 83 | def _fill_data_source(self, dataset, previous_transforms): 84 | self._data_source = dataset.summarize() 85 | if previous_transforms is not None: 86 | self._data_source.update({ 87 | "previous_transforms": [ 88 | t.summarize() for t in reversed(previous_transforms) 89 | ] 90 | }) 91 | 92 | def clear_data_source(self): 93 | self._data_source = None 94 | 95 | def _update_parameters(self, data): 96 | raise NotImplementedError() 97 | 98 | def transform(self, data): 99 | raise NotImplementedError() 100 | 101 | def revert(self, data): 102 | raise NotImplementedError() 103 | 104 | def summarize(self): 105 | summary = super(AdaptiveReversibleTransform, self).summarize() 106 | summary.update({"data_source": self._data_source}) 107 | return summary 108 | 109 | 110 | class StandardScaling(AdaptiveReversibleTransform): 111 | def __init__(self, unbiased=True): 112 | super(StandardScaling, self).__init__(exclude_at_evaluation=False) 113 | self._count = 0 114 | self._bias_correction = int(unbiased) 115 | self.register_buffer("_mean", None) 116 | self.register_buffer("_squared_differences", None) 117 | 118 | def _std(self): 119 | return torch.sqrt(self._squared_differences / (self._count - self._bias_correction)) 120 | 121 | def transform(self, data): 122 | return (data - self._mean) / self._std() 123 | 124 | def revert(self, data): 125 | return (self._std() * data) + self._mean 126 | 127 | def _update_parameters(self, data): 128 | data_stats = self._compute_stats(data) 129 | if self._mean is None: 130 | self._count, self._mean, self._squared_differences = data_stats 131 | return self 132 | return self._update_stats(*data_stats) 133 | 134 | def _compute_stats(self, data): 135 | raise NotImplementedError() 136 | 137 | def _update_stats(self, data_count, data_mean, data_squared_differences): 138 | new_count = self._count + data_count 139 | self._squared_differences += data_squared_differences 140 | self._squared_differences += (data_mean - self._mean)**2 * ((data_count * self._count) / new_count) 141 | self._mean = ((self._count * self._mean) + (data_count * data_mean)) / new_count 142 | self._count = new_count 143 | return self 144 | 145 | 146 | class LocalStandardScaling(StandardScaling): 147 | 148 | def _compute_stats(self, data): 149 | data_count = data.shape[0] 150 | data_mean = torch.mean(data, dim=0, keepdim=True) 151 | return data_count, data_mean, torch.sum(torch.square(data - data_mean), dim=0, keepdim=True) 152 | 153 | 154 | class LatitudeStandardScaling(StandardScaling): 155 | 156 | def _compute_stats(self, data): 157 | shape = data.shape 158 | data_count = shape[0] * shape[3] 159 | data_mean = torch.mean(data, dim=(0, 3), keepdim=True) 160 | return data_count, data_mean, 
146 | class LocalStandardScaling(StandardScaling):
147 | 
148 |     def _compute_stats(self, data):
149 |         data_count = data.shape[0]
150 |         data_mean = torch.mean(data, dim=0, keepdim=True)
151 |         return data_count, data_mean, torch.sum(torch.square(data - data_mean), dim=0, keepdim=True)
152 | 
153 | 
154 | class LatitudeStandardScaling(StandardScaling):
155 | 
156 |     def _compute_stats(self, data):
157 |         shape = data.shape
158 |         data_count = shape[0] * shape[3]
159 |         data_mean = torch.mean(data, dim=(0, 3), keepdim=True)
160 |         return data_count, data_mean, torch.sum(torch.square(data - data_mean), dim=(0, 3), keepdim=True)
161 | 
162 | 
163 | class GlobalStandardScaling(StandardScaling):
164 | 
165 |     def _compute_stats(self, data):
166 |         shape = data.shape
167 |         data_count = shape[0] * shape[2] * shape[3]
168 |         data_mean = torch.mean(data, dim=(0, 2, 3), keepdim=True)
169 |         return data_count, data_mean, torch.sum(torch.square(data - data_mean), dim=(0, 2, 3), keepdim=True)
170 | 
171 | 
172 | class AngularTransform(ReversibleTransform):
173 |     def __init__(self, mode="deg", clamp=True):
174 |         super(AngularTransform, self).__init__(requires_fit=False, exclude_at_evaluation=False)
175 |         assert mode in ["deg", "rad"], "[ERROR] Mode of angular transform must be \"deg\" or \"rad\"."
176 |         self._deg = (mode == "deg")
177 |         self._clamp = clamp
178 | 
179 |     def transform(self, data):
180 |         output = data
181 |         if self._deg:
182 |             output = torch.deg2rad(output)
183 |         return self._transform(output)
184 | 
185 |     @staticmethod
186 |     def _transform(data):
187 |         raise NotImplementedError()
188 | 
189 |     def revert(self, data):
190 |         output = data
191 |         if self._clamp:
192 |             output = torch.clamp(output, min=-1, max=1)
193 |         output = self._revert(output)
194 |         if self._deg:
195 |             output = torch.rad2deg(output)
196 |         return output
197 | 
198 |     @staticmethod
199 |     def _revert(data):
200 |         raise NotImplementedError()
201 | 
202 | 
203 | class Cosine(AngularTransform):
204 | 
205 |     @staticmethod
206 |     def _transform(data):
207 |         return torch.cos(data)
208 | 
209 |     @staticmethod
210 |     def _revert(data):
211 |         return torch.acos(data)
212 | 
213 | 
214 | class Sine(AngularTransform):
215 | 
216 |     @staticmethod
217 |     def _transform(data):
218 |         return torch.sin(data)
219 | 
220 |     @staticmethod
221 |     def _revert(data):
222 |         return torch.asin(data)
223 | 
224 | 
225 | class PolarCoordinates(AngularTransform):
226 | 
227 |     @staticmethod
228 |     def _transform(data):
229 |         return torch.cat([torch.cos(data), torch.sin(data)], dim=1)
230 | 
231 |     @staticmethod
232 |     def _revert(data):
233 |         data = torch.chunk(data, 2, dim=1)
234 |         return torch.angle(data[0] + 1j * data[1])
235 | 
236 |     def out_channels(self, in_channels):
237 |         return 2 * in_channels
238 | 
239 | 
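# [Editor's example - not repository code] A quick check of the PolarCoordinates
# round trip, assuming the class above is in scope. Unlike Cosine/Sine alone,
# whose acos/asin reverts can only recover angles in a restricted range, the
# (cos, sin) channel pair recovers the full signed angle via torch.angle:
import torch

theta = torch.linspace(-179.0, 179.0, 8).reshape(1, 1, 2, 4)  # degrees
pc = PolarCoordinates(mode="deg")
encoded = pc.transform(theta)  # doubles the channel axis: (1, 2, 2, 4)
assert encoded.shape[1] == pc.out_channels(theta.shape[1])
assert torch.allclose(pc.revert(encoded), theta, atol=1e-3)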
240 | class GlobalRandomOffset(Transform):
241 |     def __init__(self, minimum=0, maximum=1):
242 |         assert maximum > minimum
243 |         super(GlobalRandomOffset, self).__init__(requires_fit=False, exclude_at_evaluation=True)
244 |         self._min = minimum
245 |         self._max = maximum
246 | 
247 |     def transform(self, data):
248 |         offset = torch.rand(data.shape[0], 1, 1, 1, device=data.device, dtype=data.dtype)
249 |         return data + ((self._max - self._min) * offset + self._min)
250 | 
--------------------------------------------------------------------------------
/weatherbench_data/utils.py:
--------------------------------------------------------------------------------
1 | """Defines auxiliary functionalities for data transformation."""
2 | import os
3 | import pickle
4 | from collections import OrderedDict
5 | from datetime import datetime
6 | from types import SimpleNamespace
7 | 
8 | import torch
9 | from dateutil.relativedelta import relativedelta
10 | from torch.utils.data import Dataset, ConcatDataset
11 | 
12 | from .datasets import WeatherBenchData, TimeVariateData
13 | from .datastorage import WeatherBenchNPYStorage
14 | from .fileconverter import DATETIME_FORMAT
15 | from .transforms import Transform
16 | 
17 | ONE_MONTH = relativedelta(months=1)
18 | 
19 | 
20 | def reverse_transform_variable(transformations: dict, variable: str, data_type: str, months: list,
21 |                                tensor: torch.Tensor, apply_tranform_monthly: bool) -> torch.Tensor:
22 |     """Inverse transforms a single variable.
23 | 
24 |     Args:
25 |         transformations: A dictionary of monthly fitted LR/HR transformations for each variable used for training.
26 |         variable: Variable name from WeatherBench dataset.
27 |         data_type: Either lr or hr.
28 |         months: A list of month indices.
29 |         tensor: Tensor data of shape (batch_size, 1, H, W).
30 |         apply_tranform_monthly: Whether to apply the transformation monthly or on the whole dataset.
31 | 
32 |     Returns:
33 |         The inverse-transformed tensor with dimensionality preserved.
34 |     """
35 |     batch_size = tensor.shape[0]
36 |     if apply_tranform_monthly:
37 |         return torch.cat([transformations[variable][data_type][months[idx]].revert(tensor[idx])
38 |                           for idx in range(batch_size)])
39 |     else:
40 |         return torch.cat([transformations[variable][data_type][0].revert(tensor[idx])
41 |                           for idx in range(batch_size)])
42 | 
43 | 
44 | def reverse_transform_tensor(tensor: torch.Tensor, transformations: dict,
45 |                              variables: list, data_type: str, months: list,
46 |                              apply_tranform_monthly: bool) -> torch.Tensor:
47 |     """Inverse transforms a tensor.
48 | 
49 |     Args:
50 |         tensor: Tensor data of shape (batch_size, number of variables, H, W).
51 |         transformations: A dictionary of monthly fitted LR/HR transformations for each variable used for training.
52 |         variables: A list of WeatherBench variables.
53 |         data_type: Either lr or hr.
54 |         months: A list of month indices.
55 |         apply_tranform_monthly: Whether to apply the transformation monthly or on the whole dataset.
56 | 
57 |     Returns:
58 |         Inverse transformed single data point.
59 |     """
60 |     reverse_transformed_tensors = []
61 |     for index, variable in enumerate(variables):
62 |         tensor_of_variable = tensor[:, index].unsqueeze(1)
63 |         reverse_transformed_tensors.append(reverse_transform_variable(transformations, variable,
64 |                                                                       data_type, months,
65 |                                                                       tensor_of_variable,
66 |                                                                       apply_tranform_monthly))
67 |     return torch.cat(reverse_transformed_tensors, dim=1)
68 | 
69 | 
70 | def reverse_transform(data: dict, transformations: dict,
71 |                       variables: list, months: list, apply_tranform_monthly: bool) -> dict:
72 |     """Inverse transforms data stored in a dictionary.
73 | 
74 |     Args:
75 |         data: Dictionary of data points.
76 |         transformations: A dictionary of monthly fitted LR/HR transformations for each variable used for training.
77 |         variables: A list of WeatherBench variables.
78 |         months: A list of month indices.
79 |         apply_tranform_monthly: Whether to apply the transformation monthly or on the whole dataset.
80 | 
81 |     Returns:
82 |         Inverse transformed data.
83 |     """
84 |     reverse_transformed_batch = OrderedDict({})
85 |     for key, tensor in data.items():
86 |         if key == "LR":
87 |             reverse_transformed_batch[key] = reverse_transform_tensor(tensor, transformations, variables,
88 |                                                                       "lr", months, apply_tranform_monthly)
89 |         else:
90 |             reverse_transformed_batch[key] = reverse_transform_tensor(tensor, transformations, variables,
91 |                                                                       "hr", months, apply_tranform_monthly)
92 |     return reverse_transformed_batch
93 | 
94 | 
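# [Editor's example - not repository code] The `transformations` mapping consumed
# above is keyed as transformations[variable][data_type][month]. The stand-in
# below is hypothetical and mimics only the .revert interface of a fitted
# StandardScaling; broadcasting against (1, 1, 1, 1) statistics restores the
# channel axis exactly as a fitted 4-d _mean/_std would.
import torch
from collections import OrderedDict

class FittedStandIn:
    def revert(self, data):
        return data * torch.ones(1, 1, 1, 1) + torch.zeros(1, 1, 1, 1)

transformations = {"t2m": {dt: {m: FittedStandIn() for m in range(13)} for dt in ("lr", "hr")}}
batch = OrderedDict(LR=torch.randn(2, 1, 32, 64), HR=torch.randn(2, 1, 128, 256))
restored = reverse_transform(batch, transformations, variables=["t2m"],
                             months=[1, 7], apply_tranform_monthly=True)
assert restored["HR"].shape == batch["HR"].shape  # dimensionality is preserved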
95 | def get_start_of_next_month(datetime_object: datetime) -> datetime:
96 |     """Computes the start of the month following datetime_object.
97 | 
98 |     Args:
99 |         datetime_object: A datetime object representing a particular datetime.
100 | 
101 |     Returns:
102 |         The start date of the next month of datetime_object.
103 |     """
104 |     return (datetime_object + ONE_MONTH).replace(day=1)
105 | 
106 | 
107 | def get_str_date(datetime_object: datetime) -> str:
108 |     """Converts a datetime object into a string.
109 | 
110 |     Args:
111 |         datetime_object: A datetime object representing a particular datetime.
112 | 
113 |     Returns:
114 |         datetime_object converted into a string according to DATETIME_FORMAT.
115 |     """
116 |     return datetime.strftime(datetime_object, DATETIME_FORMAT)
117 | 
118 | 
119 | def read_variable(dataroot: str, data_type: str, variable: str, min_date: str, max_date: str) -> TimeVariateData:
120 |     """Reads the data of a single variable.
121 | 
122 |     Args:
123 |         dataroot: Path to the dataset.
124 |         data_type: Either lr or hr.
125 |         variable: Variable name from WeatherBench dataset.
126 |         min_date: Minimum date starting from which to read the data.
127 |         max_date: Maximum date until which to read the data.
128 | 
129 |     Returns:
130 |         TimeVariateData of variable.
131 |     """
132 |     return TimeVariateData(WeatherBenchNPYStorage(os.path.join(dataroot, data_type, variable)),
133 |                            name=f"{variable}_{data_type}{min_date}",
134 |                            lead_time=0, min_date=min_date, max_date=max_date)
135 | 
136 | 
137 | def add_monthly_data(storage: dict, new_data: TimeVariateData, month: int) -> None:
138 |     """Adds new_data to storage under the key month.
139 | 
140 |     Args:
141 |         storage: A dictionary to add monthly data to. Keys are indices of months.
142 |         new_data: Data to add.
143 |         month: The month to which the data belongs.
144 |     """
145 |     storage[month] = ConcatDataset([storage[month], new_data]) if month in storage else new_data
146 | 
147 | 
148 | def create_global_dataset(min_date: str, max_date: str, dataroot: str, data_type: str, variable: str) -> dict:
149 |     """Reads the data in its entirety and constructs a dictionary mapping 0 to the dataset.
150 | 
151 |     Args:
152 |         min_date: Minimum date starting from which to read the data.
153 |         max_date: Maximum date until which to read the data.
154 |         dataroot: Path to the dataset.
155 |         data_type: Either lr or hr.
156 |         variable: Variable name from WeatherBench dataset.
157 | 
158 |     Returns:
159 |         Dictionary mapping 0 to the dataset.
160 |     """
161 |     return {0: read_variable(dataroot=dataroot, data_type=data_type,
162 |                              variable=variable, min_date=min_date,
163 |                              max_date=max_date)}
164 | 
165 | 
166 | def create_monthly_datasets(min_date: str, max_date: str, dataroot: str, data_type: str, variable: str) -> dict:
167 |     """Reads data month by month and concatenates datasets of the same month. Constructs
168 |     a dictionary mapping each month index to its corresponding dataset.
169 | 
170 |     Args:
171 |         min_date: Minimum date starting from which to read the data.
172 |         max_date: Maximum date until which to read the data.
173 |         dataroot: Path to the dataset.
174 |         data_type: Either lr or hr.
175 |         variable: Variable name from WeatherBench dataset.
176 | 
177 |     Returns:
178 |         Month to data mapping.
179 |     """
180 |     month2data = {}
181 |     max_date_datetime = datetime.strptime(max_date, DATETIME_FORMAT)
182 |     start = datetime.strptime(min_date, DATETIME_FORMAT)
183 |     start_of_next_month = get_start_of_next_month(start)
184 | 
185 |     while start_of_next_month < max_date_datetime:
186 |         current_month = start.month
187 |         data = read_variable(dataroot=dataroot, data_type=data_type,
188 |                              variable=variable,
189 |                              min_date=get_str_date(start),
190 |                              max_date=get_str_date(start_of_next_month))
191 |         add_monthly_data(month2data, data, current_month)
192 |         start = start_of_next_month
193 |         start_of_next_month = get_start_of_next_month(start_of_next_month)
194 | 
195 |     data = read_variable(dataroot=dataroot, data_type=data_type,
196 |                          variable=variable,
197 |                          min_date=get_str_date(start),
198 |                          max_date=get_str_date(max_date_datetime))
199 |     add_monthly_data(month2data, data, start.month)
200 | 
201 |     if not all(month in month2data for month in range(1, 13)):
202 |         month2data[0] = ConcatDataset([data for data in month2data.values()])
203 | 
204 |     return month2data
205 | 
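# [Editor's example - not repository code] The month-walking logic above, run on
# bare datetimes: starting mid-month, the first chunk extends only to the first
# of the next month, and every later chunk covers exactly one calendar month.
from datetime import datetime
from dateutil.relativedelta import relativedelta

one_month = relativedelta(months=1)
start, end = datetime(1990, 12, 15), datetime(1991, 3, 1)
next_start = (start + one_month).replace(day=1)
chunks = []
while next_start < end:
    chunks.append((start.month, start, next_start))
    start, next_start = next_start, (next_start + one_month).replace(day=1)
chunks.append((start.month, start, end))
# chunks: (12, Dec 15 -> Jan 1), (1, Jan 1 -> Feb 1), (2, Feb 1 -> Mar 1)
assert [c[0] for c in chunks] == [12, 1, 2]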
206 | 
207 | def unpack_datasets(datasets) -> list:
208 |     """Unpacks concatenated datasets and creates a list of those datasets.
209 | 
210 |     Args:
211 |         datasets: Either a ConcatDataset object or TimeVariateData.
212 | 
213 |     Returns:
214 |         A (possibly nested) list of TimeVariateData datasets. If datasets is a
215 |         TimeVariateData object, returns that object NOT wrapped in a list.
216 |     """
217 |     return [unpack_datasets(dataset) for dataset in datasets.datasets] \
218 |         if isinstance(datasets, ConcatDataset) else datasets
219 | 
220 | 
221 | def flatten(list_of_lists):
222 |     """Flattens a nested-list structure.
223 | 
224 |     Args:
225 |         list_of_lists: A list of nested lists.
226 | 
227 |     Returns:
228 |         Flattened 1-dimensional list.
229 |     """
230 |     if not isinstance(list_of_lists, list):
231 |         return [list_of_lists]
232 |     if len(list_of_lists) == 0:
233 |         return list_of_lists
234 |     if isinstance(list_of_lists[0], list):
235 |         return flatten(list_of_lists[0]) + flatten(list_of_lists[1:])
236 |     return list_of_lists[:1] + flatten(list_of_lists[1:])
237 | 
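# [Editor's example - not repository code] unpack_datasets and flatten together
# recover the leaf datasets of a ConcatDataset tree; plain TensorDatasets stand
# in for TimeVariateData here.
import torch
from torch.utils.data import ConcatDataset, TensorDataset

leaves = [TensorDataset(torch.arange(3)) for _ in range(3)]
tree = ConcatDataset([ConcatDataset(leaves[:2]), leaves[2]])
assert flatten(unpack_datasets(tree)) == leaves
assert flatten(unpack_datasets(leaves[0])) == [leaves[0]]  # a bare dataset gets wrapped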
238 | 
239 | def fit_monthly_transformations(datasets: list, transformation: type) -> Transform:
240 |     """Fits a transformation to a list of datasets.
241 | 
242 |     Args:
243 |         datasets: A list of datasets corresponding to the same month.
244 |         transformation: The transformation class to instantiate and fit.
245 | 
246 |     Returns:
247 |         A fitted transformation.
248 |     """
249 |     transform = transformation()
250 |     for data in flatten(unpack_datasets(datasets)):
251 |         transform.fit(data)
252 |         transform.clear_data_source()
253 |     return transform
254 | 
255 | 
256 | def store_monthly_transformations(data: dict, transformation: type) -> dict:
257 |     """Creates a month to transformation mapping.
258 | 
259 |     Args:
260 |         data: A dictionary of datasets for each month. Keys are indices of months.
261 |         transformation: The transformation class to instantiate and fit.
262 | 
263 |     Returns:
264 |         A dictionary containing a fitted transformation for each monthly dataset.
265 |     """
266 |     return {month: fit_monthly_transformations(datasets, transformation) for month, datasets in data.items()}
267 | 
268 | 
269 | def fit_and_return_transformations(min_date: str, max_date: str, dataroot: str, data_type: str,
270 |                                    variable: str, transformation: type,
271 |                                    apply_tranform_monthly: bool = True) -> dict:
272 |     """Creates monthly transformations.
273 | 
274 |     Args:
275 |         min_date: Minimum date starting from which to read the data.
276 |         max_date: Maximum date until which to read the data.
277 |         dataroot: Path to the dataset.
278 |         data_type: Either lr or hr.
279 |         variable: Variable name from WeatherBench dataset.
280 |         transformation: The transformation class to instantiate and fit.
281 |         apply_tranform_monthly: Whether to apply the transformation monthly or on the whole dataset.
282 | 
283 |     Returns:
284 |         A dictionary mapping each month to its fitted transformation.
285 |     """
286 | 
287 |     if apply_tranform_monthly:
288 |         data = create_monthly_datasets(min_date, max_date, dataroot, data_type, variable)
289 |     else:
290 |         data = create_global_dataset(min_date, max_date, dataroot, data_type, variable)
291 |     return store_monthly_transformations(data, transformation)
292 | 
293 | 
294 | def save_object(obj, path: str, filename: str) -> None:
295 |     """Saves a python object with pickle.
296 | 
297 |     Args:
298 |         obj: Object to save.
299 |         path: A directory where to save.
300 |         filename: The name of the file in which to write.
301 |     """
302 |     if not filename.endswith(".pkl"):
303 |         filename = f"{filename}.pkl"
304 | 
305 |     with open(os.path.join(path, filename), "wb") as file:
306 |         pickle.dump(obj, file, pickle.HIGHEST_PROTOCOL)
307 | 
308 | 
309 | def load_object(path: str, filename: str):
310 |     """Loads a python object.
311 | 
312 |     Args:
313 |         path: A directory where an object is saved.
314 |         filename: The name of the file in which an object is saved.
315 | 
316 |     Returns:
317 |         Loaded object if path and filename are correct, otherwise None.
318 |     """
319 |     try:
320 |         if not filename.endswith(".pkl"):
321 |             filename = f"{filename}.pkl"
322 | 
323 |         with open(os.path.join(path, filename), "rb") as file:
324 |             return pickle.load(file)
325 | 
326 |     except FileNotFoundError:
327 |         return None
328 | 
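# [Editor's example - not repository code] save_object/load_object round trip in
# a temporary directory; the ".pkl" suffix is appended automatically and a
# missing file yields None instead of raising.
import tempfile

with tempfile.TemporaryDirectory() as tmp:
    save_object({"month": 1}, tmp, "transformations")
    assert load_object(tmp, "transformations") == {"month": 1}
    assert load_object(tmp, "missing") is None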
329 | 
330 | def prepare_datasets(variables: list, train_min_date: str, train_max_date: str, val_min_date: str, val_max_date: str,
331 |                      dataroot: str, transformation: type, storage_root: str = None,
332 |                      apply_tranform_monthly: bool = True):
333 |     """Reads the datasets and fits transformations.
334 | 
335 |     Args:
336 |         variables: A list of WeatherBench variables.
337 |         train_min_date: Minimum date starting from which to read the data for training.
338 |         train_max_date: Maximum date until which to read the data for training.
339 |         val_min_date: Minimum date starting from which to read the data for validation.
340 |         val_max_date: Maximum date until which to read the data for validation.
341 |         dataroot: Path to the dataset.
342 |         transformation: The transformation class to instantiate and fit.
343 |         storage_root: A path to save metadata and fitted transformations.
344 |         apply_tranform_monthly: Whether to apply the transformation monthly or on the whole dataset.
345 | 
346 |     Returns:
347 |         Training and validation datasets, metadata and fitted transformations.
348 |     """
349 |     train_datasets, val_datasets = {"lr": [], "hr": []}, {"lr": [], "hr": []}
350 |     transformations, metadata = {}, {}
351 | 
352 |     for idx, variable in enumerate(variables):
353 | 
354 |         transformations[variable] = {}
355 |         for data_type in ("lr", "hr"):
356 |             month2transform = fit_and_return_transformations(train_min_date, train_max_date, dataroot,
357 |                                                              data_type, variable, transformation,
358 |                                                              apply_tranform_monthly)
359 |             wbd_storage = WeatherBenchNPYStorage(os.path.join(dataroot, data_type, variable))
360 |             train_data = TimeVariateData(wbd_storage, name=f"train_{data_type}_{variable}",
361 |                                          lead_time=0, min_date=train_min_date,
362 |                                          max_date=train_max_date, transform=month2transform)
363 |             # Updates the metadata information for the first variable only; the other
364 |             # variables share the same latitudes and longitudes.
365 |             if idx == 0:
366 |                 metadata.update({f"{data_type}_{dimension['name']}": dimension["values"]
367 |                                  for dimension in wbd_storage.meta_data["coords"]})
368 |             transformations[variable][data_type] = month2transform
369 |             train_datasets[data_type].append(train_data)
370 |             val_data = TimeVariateData(WeatherBenchNPYStorage(os.path.join(dataroot, data_type, variable)),
371 |                                        name=f"{data_type}_{variable}", lead_time=0,
372 |                                        min_date=val_min_date, max_date=val_max_date,
373 |                                        transform=month2transform)
374 |             val_datasets[data_type].append(val_data)
375 | 
376 |     train_dataset = WeatherBenchData(min_date=train_min_date, max_date=train_max_date)
377 |     train_dataset.add_data_group("lr", train_datasets["lr"])
378 |     train_dataset.add_data_group("hr", train_datasets["hr"])
379 | 
380 |     val_dataset = WeatherBenchData(min_date=val_min_date, max_date=val_max_date)
381 |     val_dataset.add_data_group("lr", val_datasets["lr"])
382 |     val_dataset.add_data_group("hr", val_datasets["hr"])
383 | 
384 |     metadata = SimpleNamespace(**metadata)
385 | 
386 |     if storage_root:
387 |         save_object(metadata, storage_root, "metadata")
388 |         save_object(transformations, storage_root, "transformations")
389 | 
390 |     return train_dataset, val_dataset, metadata, transformations
391 | 
392 | 
393 | def prepare_test_data(variables: list, val_min_date: str, val_max_date: str,
394 |                       dataroot: str, transformations: dict):
395 |     """Creates testing data with already fitted transformations.
396 | 
397 |     Args:
398 |         variables: A list of WeatherBench variables.
399 |         val_min_date: Minimum date starting from which to read the data for validation.
400 |         val_max_date: Maximum date until which to read the data for validation.
401 |         dataroot: Path to the dataset.
402 |         transformations: A dict of month to transformation mappings.
403 | 
404 |     Returns:
405 |         Test data.
406 |     """
407 |     val_datasets = {"lr": [], "hr": []}
408 | 
409 |     for variable in variables:
410 |         for data_type in ("lr", "hr"):
411 |             val_data = TimeVariateData(WeatherBenchNPYStorage(os.path.join(dataroot, data_type, variable)),
412 |                                        name=f"{data_type}_{variable}", lead_time=0,
413 |                                        min_date=val_min_date, max_date=val_max_date,
414 |                                        transform=transformations[variable][data_type])
415 |             val_datasets[data_type].append(val_data)
416 | 
417 |     val_dataset = WeatherBenchData(min_date=val_min_date, max_date=val_max_date)
418 |     val_dataset.add_data_group("lr", val_datasets["lr"])
419 |     val_dataset.add_data_group("hr", val_datasets["hr"])
420 | 
421 |     return val_dataset
422 | 
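# [Editor's example - not repository code] A hypothetical call sequence tying the
# pieces together; the paths are placeholders and the date strings (TRAIN_MIN_DATE
# etc.) must follow fileconverter.DATETIME_FORMAT. Note that `transformation` is
# the class itself (e.g. GlobalStandardScaling), instantiated anew for every month
# during fitting.
from weatherbench_data.transforms import GlobalStandardScaling

train_ds, val_ds, metadata, transformations = prepare_datasets(
    variables=["t2m"],
    train_min_date=TRAIN_MIN_DATE, train_max_date=TRAIN_MAX_DATE,  # placeholders
    val_min_date=VAL_MIN_DATE, val_max_date=VAL_MAX_DATE,          # placeholders
    dataroot="/path/to/weatherbench", transformation=GlobalStandardScaling,
    storage_root=None, apply_tranform_monthly=True)

test_ds = prepare_test_data(variables=["t2m"], val_min_date=TEST_MIN_DATE,
                            val_max_date=TEST_MAX_DATE,            # placeholders
                            dataroot="/path/to/weatherbench",
                            transformations=transformations)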
423 | 
424 | def log_dataset_info(dataset: Dataset, dataset_name: str, logger) -> None:
425 |     """Logs dataset information.
426 | 
427 |     Args:
428 |         dataset: A pytorch dataset.
429 |         dataset_name: The name of the dataset.
430 |         logger: Logging object.
431 |     """
432 |     logger.info(f"Dataset [{dataset.__class__.__name__} - {dataset_name}] is created.")
433 |     logger.info(f"Created {dataset.__class__.__name__} dataset of length {len(dataset)}, "
434 |                 f"containing data from {dataset.min_date} until {dataset.max_date}")
435 |     logger.info(f"Group structure: {dataset.get_data_names()}")
436 |     logger.info(f"Channel count: {dataset.get_channel_count()}\n")
--------------------------------------------------------------------------------