├── LICENSE ├── README.md ├── blacklist.json ├── config.py ├── hypercomplex ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-37.pyc │ ├── __init__.cpython-39.pyc │ ├── helpers.cpython-36.pyc │ ├── helpers.cpython-39.pyc │ ├── hypercomplex_layers.cpython-36.pyc │ ├── hypercomplex_layers.cpython-37.pyc │ ├── hypercomplex_layers.cpython-39.pyc │ ├── hypercomplex_layers2.cpython-36.pyc │ ├── hypercomplex_ops.cpython-36.pyc │ ├── hypercomplex_ops.cpython-37.pyc │ ├── hypercomplex_ops.cpython-39.pyc │ ├── hypercomplex_utils.cpython-36.pyc │ ├── hypercomplex_utils.cpython-37.pyc │ └── hypercomplex_utils.cpython-39.pyc ├── helpers.py ├── hypercomplex_layers.py ├── hypercomplex_layers2.py ├── hypercomplex_ops.py └── hypercomplex_utils.py ├── inference.py ├── layers.py ├── logs └── ALL_real_swinencoder3d_688080 │ └── 20210630T224355 │ ├── checkpoints │ └── epoch=58-val_loss=0.029748.ckpt │ ├── hparams.yaml │ ├── metrics.csv │ └── options.csv ├── main.py ├── model_pl.py ├── models.py ├── parametrize.py ├── requirements.txt ├── splits.csv ├── swin_transformer3d.py ├── t.npy ├── test_split.json ├── train_utils.py ├── utils ├── 1. Onboarding.ipynb ├── 2. Submission_UNet.ipynb ├── 3-train-UNet-example.py ├── blacklist.json ├── context_variables.py ├── data_utils.py ├── environment.yml ├── h5shape.py ├── splits.csv ├── test_split.json └── w4c_dataloader.py └── validation_metrics.py /LICENSE: -------------------------------------------------------------------------------- 1 | Open Source License 2 | Software 3 | 4 | Copyright (c) 2021 KHALIFA UIVERSITY FOR SCIENCE & TECHNOLOGY (KU). All rights reserved. 5 | 6 | Redistribution and use of the SOFTWARE, with or without modification, are permitted for Academic and Scholarly purposes only and not intended for commercial or business purposes; 7 | provided that the following conditions are met: 8 | 9 | 1. Redistributions of SOFTWARE must retain the above copyright notice, this list of conditions and the following disclaimer. 10 | 11 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials 12 | provided with the distribution. 13 | 14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDER AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 15 | IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 16 | PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 17 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 18 | 19 | The views and conclusions contained in the software/database and documentation are those of the authors and should not be interpreted as representing official policies, either expressed or implied, 20 | of KU. 21 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Weather4cast2021-SwinEncoderDecoder (AI4EX Team) 2 | 3 | ## Table of Content 4 | * [General Info](#general-info) 5 | * [Requirements](#requirements) 6 | * [Installation](#installation) 7 | * [Usage](#usage) 8 | * [Inference](#inference) 9 | 10 | ## General Info 11 | The resipository contains the code and learned model parameters for our submision in Weather4cast2021 stage-1 competition. 12 | 13 | ## Requirements 14 | This resipository depends on the following packages availability 15 | - Pytorch Lightning 16 | - timm 17 | - torch_optimizer 18 | - pytorch_model_summary 19 | - einops 20 | 21 | ## Installation: 22 | ``` 23 | unzip folder.zip 24 | cd folder 25 | conda create --name swinencoder_env python=3.6 26 | conda activate swinencoder_env 27 | conda install pytorch=1.9.0 cudatoolkit=10.2 -c pytorch 28 | pip install -r requirements.txt 29 | ``` 30 | 31 | ## Usage 32 | - a.1) train from scratch 33 | ``` 34 | python main.py --gpus 0 --use_all_region 35 | ``` 36 | - a.2) fine tune a model from a checkpoint 37 | ``` 38 | python main.py --gpu_id 1 --use_all_region --mode train --name ALL_real_swinencoder3d_688080 --time-code 20210630T224355 --initial-epoch 58``` 39 | 40 | - b.1) evaluate an untrained model (with random weights) 41 | ``` 42 | python main.py --gpus 0 --use_all_region --mode test 43 | ``` 44 | - b.2) evaluate a trained model from a checkpoint (submitted inference) 45 | ``` 46 | python main.py --gpu_id 1 --use_all_region --mode test --name ALL_real_swinencoder3d_688080 --time-code 20210630T224355 --initial-epoch 58 47 | ``` 48 | 49 | ## Inference 50 | To generate predictions using our trained model 51 | ``` 52 | R=R1 53 | INPUT_PATH=../data 54 | WEIGHTS=logs/ALL_real_swinencoder3d_688080 55 | OUT_PATH=. 56 | python inference.py -d $INPUT_PATH -r $R -w $WEIGHTS -o $OUT_PATH -g 1 57 | ``` 58 | -------------------------------------------------------------------------------- /blacklist.json: -------------------------------------------------------------------------------- 1 | { 2 | "2019056": [ 3 | 51, 4 | 52, 5 | 53, 6 | 54, 7 | 58 8 | ], 9 | "2019071": [ 10 | 52, 11 | 53, 12 | 54, 13 | 55 14 | ], 15 | "2019095": [ 16 | 51, 17 | 95 18 | ], 19 | "2019096": [ 20 | 95 21 | ], 22 | "2019100": [ 23 | 95 24 | ], 25 | "2019102": [ 26 | 48, 27 | 41, 28 | 95 29 | ], 30 | "2019103": [ 31 | 19, 32 | 95 33 | ], 34 | "2019105": [ 35 | 95 36 | ], 37 | "2019108": [ 38 | 95 39 | ], 40 | "2019109": [ 41 | 95 42 | ], 43 | "2019111": [ 44 | 95 45 | ], 46 | "2019126": [ 47 | 63 48 | ], 49 | "2019134": [ 50 | 41 51 | ], 52 | "2019144": [ 53 | 56 54 | ], 55 | "2019151": [ 56 | 88 57 | ], 58 | "2019182": [ 59 | 59, 60 | 61 61 | ], 62 | "2019183": [ 63 | 4, 64 | 5, 65 | 6, 66 | 7, 67 | 8, 68 | 9, 69 | 10, 70 | 11, 71 | 12, 72 | 13 73 | ], 74 | "2019229": [ 75 | 29 76 | ], 77 | "2019232": [ 78 | 88, 79 | 89, 80 | 90 81 | ], 82 | "2019236": [ 83 | 63 84 | ], 85 | "2019238": [ 86 | 69, 87 | 70, 88 | 71, 89 | 73, 90 | 74 91 | ], 92 | "2019239": [ 93 | 74, 94 | 75, 95 | 78 96 | ], 97 | "2019242": [ 98 | 69 99 | ], 100 | "2019256": [ 101 | 47 102 | ], 103 | "2019263": [ 104 | 77 105 | ], 106 | "2019281": [ 107 | 42 108 | ], 109 | "2019288": [ 110 | 48 111 | ], 112 | "2019289": [ 113 | 49 114 | ], 115 | "2019315": [ 116 | 43, 117 | 44, 118 | 56, 119 | 57, 120 | 58 121 | ], 122 | "2019180": [ 123 | 79, 124 | 80 125 | ], 126 | "2019299": [ 127 | 1, 128 | 2, 129 | 3, 130 | 4, 131 | 5, 132 | 6, 133 | 7, 134 | 8, 135 | 9, 136 | 10, 137 | 11, 138 | 12, 139 | 13, 140 | 14, 141 | 15, 142 | 16, 143 | 17, 144 | 18, 145 | 19, 146 | 20 147 | ], 148 | "2019352": [ 149 | 47 150 | ], 151 | "2020029": [ 152 | 9 153 | ] 154 | } -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | 4 | def prepare_crop(regions, region_id): 5 | """ this function prepares the expected parameters to crop images per region 6 | e.g., to crop latitudes to the region of interest 7 | """ 8 | x, y = regions[region_id]['up_left'] 9 | crop = {'x_start': x, 'y_start': y, 'size': regions[region_id]['size']} 10 | return crop 11 | 12 | def n_extra_vars(string_vars): 13 | """ computes how many extra variables will be used """ 14 | if string_vars=='': 15 | len_extra = 0 16 | else: 17 | len_extra = len(string_vars.split('-')) 18 | if 'l' in string_vars: 19 | len_extra += 1 # 'l' loads both lat/lon, so 2 vars (not 1) 20 | return len_extra 21 | 22 | def get_prod_name(product): 23 | """ get the folder name for each product. Note that only the folder containing ASII 24 | have a slightly different name 25 | """ 26 | if product=='ASII': 27 | product = 'ASII-TF' 28 | return product 29 | 30 | 31 | def get_params(region_id='R1', competition='stage-1', 32 | use_static=True, use_all_variables=False, use_cloud_type=False, use_time_slot=True, 33 | data_path='D:/KU_Works/Datasets/Weather4cast/2021/', # os.path.join(os.getcwd(), '../data'), 34 | splits_path=os.path.join(os.getcwd()),#'D:/PycharmProjects/HypercomplexNetwork/Weather4cast2021', # os.path.join(os.getcwd()), 35 | static_data_path='D:/KU_Works/Datasets/Weather4cast/2021/statics', # os.path.join(os.getcwd(), '../data/static'), 36 | size=256, 37 | collapse_time=False): 38 | """ Set paths & parameters to load/transform/save data and models. 39 | 40 | Args: 41 | region_id (str, optional): Region to load data from]. Defaults to 'R1'. 42 | competition (str, optional): competition name [stage-1, ieee-bd] (Default: stage-1). 43 | use_static (bool, optional): use static variable (Default: True) 44 | use_all_variables (bool, optional): use available variables for the variable types used (Default: False) 45 | use_cloud_type (bool, optional): use cloud type variables. [Only when all variables are used] (Default: False) 46 | use_time_slot (bool, optional): use time slots (Default: True) 47 | data_path (str, optional): path to the parent folder containing folders 48 | for the core competition (*/w4c-core-stage-1) and/or 49 | transfer learning comptition (*/w4c-transfer-learning-stage-1'). 50 | Defaults to 'data'. 51 | splits_path (str, optional): Path to the folder containing the csv and json files defining 52 | the data splits. 53 | Defaults to 'utils'. 54 | static_data_path (str, optional): Path to the folder containing the static channels. 55 | Defaults to 'data/static'. 56 | size (int, optional): Size of the region. Default to 256. 57 | collapse_time (bool, optional): collapses the time dimension 58 | Returns: 59 | dict: Contains the params 60 | """ 61 | competitions = ['ieee-bd', 'stage-1'] 62 | assert competition in competitions, f"competition name [{competition}] must be in {competitions}" 63 | 64 | data_params = {} 65 | model_params = {} 66 | training_params = {} 67 | optimization_params = {} 68 | 69 | regions = {'R3': {'up_left': (935, 400), 'split': 'train', 'desc': 'South West\nEurope', 'size': size}, 70 | 'R6': {'up_left': (1270, 250), 'split': 'test', 'desc': 'Central\nEurope', 'size': size}, 71 | 'R2': {'up_left': (1550, 200), 'split': 'train', 'desc': 'Eastern\nEurope', 'size': size}, 72 | 'R1': {'up_left': (1850, 760), 'split': 'train', 'desc': 'Nile Region', 'size': size}, 73 | 'R5': {'up_left': (1300, 550), 'split': 'test', 'desc': 'South\nMediterranean', 'size': size}, 74 | 'R4': {'up_left': (1020, 670), 'split': 'test', 'desc': 'Central\nMaghreb', 'size': size}, 75 | 'R7': {'up_left': (1700, 470), 'split': 'train', 'desc': 'Bosphorus', 'size': size}, 76 | 'R8': {'up_left': (750, 670), 'split': 'train', 'desc': 'East\nMaghreb', 'size': size}, 77 | 'R9': {'up_left': (450, 760), 'split': 'test', 'desc': 'Canarian Islands', 'size': size}, 78 | 'R10': {'up_left': (250, 500), 'split': 'test', 'desc': 'Azores Islands', 'size': size}, 79 | 'R11': {'up_left': (1000, 130), 'split': 'test', 'desc': 'North West\nEurope', 'size': size} 80 | } 81 | print(f'Using data for region {region_id} | size: {size} | {regions[region_id]["desc"]}') 82 | 83 | # ------ 84 | # control variables 85 | # -------- 86 | use_cloud_type = use_cloud_type if use_all_variables else False 87 | control_params = {} 88 | control_params['use_all_variables'] = use_all_variables 89 | control_params['use_cloud_type'] = use_cloud_type 90 | control_params['use_time_slot'] = use_time_slot 91 | control_params['use_static'] = use_static 92 | 93 | # ------------ 94 | # 1. Files to load 95 | # ------------ 96 | if competition == 'stage-1': 97 | if region_id in ['R1', 'R2', 'R3']: 98 | track = 'w4c-core-stage-1' 99 | else: 100 | track = 'w4c-transfer-learning-stage-1' 101 | else: # competition is now ieee-bd 102 | if region_id in ['R1', 'R2', 'R3', 'R7', 'R8']: 103 | track = 'ieee-bd-core' 104 | else: 105 | track = 'ieee-bd-transfer-learning' 106 | 107 | data_params['data_path'] = os.path.join(data_path, track, region_id) 108 | data_params['static_paths'] = {} 109 | data_params['static_paths']['l'] = os.path.join(static_data_path, 'Navigation_of_S_NWC_CT_MSG4_Europe-VISIR_20201106T120000Z.nc') 110 | data_params['static_paths']['e'] = os.path.join(static_data_path, 'S_NWC_TOPO_MSG4_+000.0_Europe-VISIR.raw') 111 | 112 | data_params['train_splits'] = os.path.join(splits_path, 'splits.csv') 113 | data_params['test_splits'] = os.path.join(splits_path, 'test_split.json') 114 | data_params['black_list_path'] = os.path.join(splits_path, 'blacklist.json') 115 | 116 | # ------------ 117 | # 2. Data params 118 | # ------------ 119 | data_params['collapse_time'] = collapse_time 120 | data_params['extra_data'] = 'l-e' if use_static else '' # use '' to not use static features 121 | data_params['target_vars'] = ['temperature', 'crr_intensity', 'asii_turb_trop_prob', 'cma'] 122 | 123 | data_params['products'] = {'CTTH': ['temperature'], 124 | 'CRR': ['crr_intensity'], 125 | 'ASII': ['asii_turb_trop_prob'], 126 | 'CMA': ['cma']} 127 | 128 | if use_all_variables: 129 | data_params['products'] = {'CTTH': ['ctth_pres', 'ctth_alti', 'ctth_tempe', 'ctth_effectiv', 'ctth_method', 130 | 'ctth_quality', 'ishai_skt', 'ishai_quality', 'temperature'], 131 | 'CRR': ['crr', 'crr_intensity', 'crr_accum', 'crr_quality'], 132 | 'ASII': ['asii_turb_trop_prob', 'asiitf_quality'], 133 | 'CMA': ['cma_cloudsnow', 'cma', 'cma_dust', 'cma_volcanic', 'cma_smoke', 134 | 'cma_quality'], 135 | } 136 | if use_cloud_type: 137 | data_params['products']['CT'] = ['ct', 'ct_cumuliform', 'ct_multilayer', 'ct_quality'] 138 | data_params['input_vars'] = [item for sublist in [value for key, value in data_params['products'].items()] for item 139 | in sublist] 140 | 141 | data_params['weigths'] = {'temperature': .25, 142 | 'crr_intensity': .25, 143 | 'asii_turb_trop_prob': .25, 144 | 'cma': .25} # to use by the metric 145 | 146 | # data_params['depth'] = len(data_params['target_vars']) + n_extra_vars(data_params['extra_data']) + 1 # lead time is added 147 | # data_params['depth'] = len(data_params['input_vars']) + n_extra_vars( 148 | # data_params['extra_data']) + 1 # lead time is added 149 | data_params['depth'] = len(data_params['input_vars']) + n_extra_vars( 150 | data_params['extra_data']) # lead time is added 151 | if use_time_slot: 152 | data_params['depth'] += 1 153 | 154 | data_params['spatial_dim'] = (size, size) 155 | data_params['crop_static'] = prepare_crop(regions, region_id) 156 | data_params['crop_in'] = None 157 | data_params['crop_out'] = None 158 | data_params['train_region_id'] = region_id+'_mse'*1 # this is actually used by the model, not the data ?????? 159 | data_params['region_id'] = region_id 160 | data_params['len_seq_in'] = 4 # time-bins of 15 minutes 161 | data_params['bins_to_predict'] = 8*4 # hours x (time-bins per hour =4) # not used 162 | # data_params['len_seq_out'] = 1 # time-bins 163 | data_params['len_seq_out'] = 8*4 # time-bins 164 | data_params['day_bins'] = 96 165 | data_params['seq_mode'] = 'sliding_window' # not used 166 | data_params['width'] = 256 # not used 167 | data_params['height'] = 256 # not used 168 | 169 | data_params['control_params'] = control_params 170 | # preprocessing: 171 | # a. fill_value: value to replace NaNs (currently temperature is the one that has more) 172 | # b. max_value: maximum value of the variable when it's saved on disk as integer 173 | # c. scale_factor: netCDF automatically uses this value to re-scale the value 174 | # d. add_offset: netCDF automatically uses this value to shift a variable 175 | # 176 | # c. and d. together mean that once loaded, the data is in the scale [add_offset, max_value*scale_factor + add_offset] 177 | # Hence, to normalize the data between [0, 1] we must use: 178 | # data = (data-add_offset)/(max_value*scale_factor - add_offset) 179 | 180 | preprocess = {'cma': {'fill_value': 0, 'max_value': 1, 'add_offset': 0, 'scale_factor': 1}, 181 | 'temperature': {'fill_value': 0, 'max_value': 35000, 'add_offset': 130, 182 | 'scale_factor': np.float32(0.01)}, 183 | 'crr_intensity': {'fill_value': 0, 'max_value': 500, 'add_offset': 0, 184 | 'scale_factor': np.float32(0.1)}, 185 | 'asii_turb_trop_prob': {'fill_value': 0, 'max_value': 100, 'add_offset': 0, 'scale_factor': 1}} 186 | preprocess_tgt = {'cma': {'fill_value': np.nan, 'max_value': 1, 'add_offset': 0, 'scale_factor': 1}, 187 | 'temperature': {'fill_value': np.nan, 'max_value': 35000, 'add_offset': 130, 188 | 'scale_factor': np.float32(0.01)}, 189 | 'crr_intensity': {'fill_value': np.nan, 'max_value': 500, 'add_offset': 0, 190 | 'scale_factor': np.float32(0.1)}, 191 | 'asii_turb_trop_prob': {'fill_value': np.nan, 'max_value': 100, 'add_offset': 0, 192 | 'scale_factor': 1}} 193 | 194 | data_params['preprocess'] = {'source': preprocess, 'target': preprocess_tgt} 195 | 196 | 197 | # ------------ 198 | # 3. Model params 199 | # ------------ 200 | if data_params['collapse_time']: 201 | model_params['in_channels'] = data_params['depth'] * data_params['len_seq_in'] 202 | else: 203 | 204 | model_params['in_channels'] = data_params['depth'] 205 | model_params['n_classes'] = len(data_params['target_vars']) * data_params['len_seq_out'] 206 | model_params['depth'] = 5 207 | model_params['wf'] = 6 208 | model_params['padding'] = True 209 | model_params['batch_norm'] = False 210 | model_params['up_mode'] = 'upconv' 211 | 212 | # # ------------ 213 | # # 4. Training params 214 | # # ------------ 215 | # training_params['batch_size'] = 16 # 64 216 | # training_params['n_workers'] = 8 217 | 218 | params = { 219 | 'data_params': data_params, 220 | 'model_params': model_params, 221 | # 'training_params': training_params, 222 | 'optimization_params': optimization_params, 223 | } 224 | 225 | return params 226 | 227 | if __name__ == '__main__': 228 | # this is only executed when the module is run directly. 229 | print(get_params()) 230 | -------------------------------------------------------------------------------- /hypercomplex/__init__.py: -------------------------------------------------------------------------------- 1 | from .hypercomplex_layers import (ComplexConv1D, ComplexConv2D, ComplexConv3D, ComplexLinear, 2 | QuaternionConv1D, QuaternionConv2D, QuaternionConv3D, QuaternionLinear, 3 | OctonionConv1D, OctonionConv2D, OctonionConv3D, OctonionLinear, 4 | SedanionConv1D, SedanionConv2D, SedanionConv3D, SedanionLinear, 5 | ComplexTransposeConv1D, QuaternionTransposeConv1D, OctonionTransposeConv1D, SedanionTransposeConv1D, 6 | ComplexTransposeConv2D, QuaternionTransposeConv2D, OctonionTransposeConv2D, SedanionTransposeConv2D, 7 | ComplexTransposeConv3D, QuaternionTransposeConv3D, OctonionTransposeConv3D, SedanionTransposeConv3D, 8 | HyperConv1D, HyperConv2D, HyperConv3D, HyperLinear) 9 | 10 | from .hypercomplex_ops import get_c 11 | 12 | from .hypercomplex_utils import get_comp_mat, get_hmat 13 | -------------------------------------------------------------------------------- /hypercomplex/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bojesomo/Weather4cast2021-SwinEncoderDecoder/8c158890628be2a28d47aa082fd96eb43c6cbb4a/hypercomplex/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /hypercomplex/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bojesomo/Weather4cast2021-SwinEncoderDecoder/8c158890628be2a28d47aa082fd96eb43c6cbb4a/hypercomplex/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /hypercomplex/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bojesomo/Weather4cast2021-SwinEncoderDecoder/8c158890628be2a28d47aa082fd96eb43c6cbb4a/hypercomplex/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /hypercomplex/__pycache__/helpers.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bojesomo/Weather4cast2021-SwinEncoderDecoder/8c158890628be2a28d47aa082fd96eb43c6cbb4a/hypercomplex/__pycache__/helpers.cpython-36.pyc -------------------------------------------------------------------------------- /hypercomplex/__pycache__/helpers.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bojesomo/Weather4cast2021-SwinEncoderDecoder/8c158890628be2a28d47aa082fd96eb43c6cbb4a/hypercomplex/__pycache__/helpers.cpython-39.pyc -------------------------------------------------------------------------------- /hypercomplex/__pycache__/hypercomplex_layers.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bojesomo/Weather4cast2021-SwinEncoderDecoder/8c158890628be2a28d47aa082fd96eb43c6cbb4a/hypercomplex/__pycache__/hypercomplex_layers.cpython-36.pyc -------------------------------------------------------------------------------- /hypercomplex/__pycache__/hypercomplex_layers.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bojesomo/Weather4cast2021-SwinEncoderDecoder/8c158890628be2a28d47aa082fd96eb43c6cbb4a/hypercomplex/__pycache__/hypercomplex_layers.cpython-37.pyc -------------------------------------------------------------------------------- /hypercomplex/__pycache__/hypercomplex_layers.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bojesomo/Weather4cast2021-SwinEncoderDecoder/8c158890628be2a28d47aa082fd96eb43c6cbb4a/hypercomplex/__pycache__/hypercomplex_layers.cpython-39.pyc -------------------------------------------------------------------------------- /hypercomplex/__pycache__/hypercomplex_layers2.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bojesomo/Weather4cast2021-SwinEncoderDecoder/8c158890628be2a28d47aa082fd96eb43c6cbb4a/hypercomplex/__pycache__/hypercomplex_layers2.cpython-36.pyc -------------------------------------------------------------------------------- /hypercomplex/__pycache__/hypercomplex_ops.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bojesomo/Weather4cast2021-SwinEncoderDecoder/8c158890628be2a28d47aa082fd96eb43c6cbb4a/hypercomplex/__pycache__/hypercomplex_ops.cpython-36.pyc -------------------------------------------------------------------------------- /hypercomplex/__pycache__/hypercomplex_ops.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bojesomo/Weather4cast2021-SwinEncoderDecoder/8c158890628be2a28d47aa082fd96eb43c6cbb4a/hypercomplex/__pycache__/hypercomplex_ops.cpython-37.pyc -------------------------------------------------------------------------------- /hypercomplex/__pycache__/hypercomplex_ops.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bojesomo/Weather4cast2021-SwinEncoderDecoder/8c158890628be2a28d47aa082fd96eb43c6cbb4a/hypercomplex/__pycache__/hypercomplex_ops.cpython-39.pyc -------------------------------------------------------------------------------- /hypercomplex/__pycache__/hypercomplex_utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bojesomo/Weather4cast2021-SwinEncoderDecoder/8c158890628be2a28d47aa082fd96eb43c6cbb4a/hypercomplex/__pycache__/hypercomplex_utils.cpython-36.pyc -------------------------------------------------------------------------------- /hypercomplex/__pycache__/hypercomplex_utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bojesomo/Weather4cast2021-SwinEncoderDecoder/8c158890628be2a28d47aa082fd96eb43c6cbb4a/hypercomplex/__pycache__/hypercomplex_utils.cpython-37.pyc -------------------------------------------------------------------------------- /hypercomplex/__pycache__/hypercomplex_utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bojesomo/Weather4cast2021-SwinEncoderDecoder/8c158890628be2a28d47aa082fd96eb43c6cbb4a/hypercomplex/__pycache__/hypercomplex_utils.cpython-39.pyc -------------------------------------------------------------------------------- /hypercomplex/helpers.py: -------------------------------------------------------------------------------- 1 | from itertools import repeat 2 | # from torch._six import container_abcs 3 | import collections.abc as container_abcs 4 | 5 | 6 | # From PyTorch internals 7 | def _ntuple(n): 8 | def parse(x): 9 | if isinstance(x, container_abcs.Iterable): 10 | return x 11 | return tuple(repeat(x, n)) 12 | return parse 13 | 14 | 15 | to_1tuple = _ntuple(1) 16 | to_2tuple = _ntuple(2) 17 | to_3tuple = _ntuple(3) 18 | to_4tuple = _ntuple(4) 19 | to_ntuple = _ntuple -------------------------------------------------------------------------------- /hypercomplex/hypercomplex_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | """Using symbolic algorithm to get the multiplication component matrix""" 3 | 4 | 5 | def hstar(h): 6 | h_out = [h[0]] 7 | for h_ in h[1:]: 8 | h_ = '-' + h_ 9 | h_out.append(h_) 10 | return h_out 11 | 12 | 13 | def zmult(h1, h2): 14 | ha, hb = h1[:len(h1)//2], h1[len(h1)//2:] 15 | hc, hd = h2[:len(h1)//2], h2[len(h1)//2:] 16 | 17 | # (a, b) (c, d) = (ac – db*, a*d + cb) 18 | ac = [x_ + y_ for (x_, y_) in zip(ha, hc)] 19 | db = [x_ + y_ for (x_, y_) in zip(hstar(hd), hb)] 20 | 21 | da = [x_ + y_ for (x_, y_) in zip(hd, ha)] 22 | bc = [x_ + y_ for (x_, y_) in zip(hb, hstar(hc))] 23 | 24 | hm_a = [f"{x_} -{y_}" for (x_,y_) in zip(ac, db)] 25 | hm_a.extend([f"{x_} {y_}" for (x_,y_) in zip(da, bc)]) 26 | return hm_a 27 | 28 | 29 | def qmult(h1, h2): 30 | ha, hb = h1[:len(h1)//2], h1[len(h1)//2:] 31 | hc, hd = h2[:len(h1)//2], h2[len(h1)//2:] 32 | 33 | # (a, b) (c, d) = (ac – d*b, da + bc*) 34 | ac = zmult(ha, hc) 35 | db = zmult(hstar(hd), hb) 36 | 37 | da = zmult(hd, ha) 38 | bc = zmult(hb, hstar(hc)) 39 | 40 | hm_a = [f"{x_} {' '.join([f'-{t}' for t in y_.split()])}" for (x_,y_) in zip(ac, db)] 41 | hm_a.extend([f"{x_} {y_}" for (x_,y_) in zip(da, bc)]) 42 | return hm_a 43 | 44 | 45 | def omult(h1, h2): 46 | ha, hb = h1[:len(h1)//2], h1[len(h1)//2:] 47 | hc, hd = h2[:len(h1)//2], h2[len(h1)//2:] 48 | 49 | # (a, b) (c, d) = (ac – d*b, da + bc*) 50 | ac = qmult(ha, hc) 51 | db = qmult(hstar(hd), hb) 52 | 53 | da = qmult(hd, ha) 54 | bc = qmult(hb, hstar(hc)) 55 | 56 | hm_a = [f"{x_} {' '.join([f'-{t}' for t in y_.split()])}" for (x_,y_) in zip(ac, db)] 57 | hm_a.extend([f"{x_} {y_}" for (x_,y_) in zip(da, bc)]) 58 | return hm_a 59 | 60 | 61 | def smult(h1, h2): 62 | ha, hb = h1[:len(h1)//2], h1[len(h1)//2:] 63 | hc, hd = h2[:len(h1)//2], h2[len(h1)//2:] 64 | 65 | # (a, b) (c, d) = (ac – d*b, da + bc*) 66 | ac = omult(ha, hc) 67 | db = omult(hstar(hd), hb) 68 | 69 | da = omult(hd, ha) 70 | bc = omult(hb, hstar(hc)) 71 | 72 | hm_a = [f"{x_} {' '.join([f'-{t}' for t in y_.split()])}" for (x_,y_) in zip(ac, db)] 73 | hm_a.extend([f"{x_} {y_}" for (x_,y_) in zip(da, bc)]) 74 | return hm_a 75 | 76 | 77 | # used recursion to cater for hypercomplex mukti 78 | def hmult(h1, h2): 79 | assert len(h1) == len(h2) 80 | n = len(h1) 81 | ha, hb = h1[:len(h1) // 2], h1[len(h1) // 2:] 82 | hc, hd = h2[:len(h1) // 2], h2[len(h1) // 2:] 83 | if n > 2: # do recursion 84 | ac = hmult(ha, hc) 85 | db = hmult(hstar(hd), hb) 86 | 87 | da = hmult(hd, ha) 88 | bc = hmult(hb, hstar(hc)) 89 | else: # end recursion 90 | ac = [x_ + y_ for (x_, y_) in zip(ha, hc)] 91 | db = [x_ + y_ for (x_, y_) in zip(hstar(hd), hb)] 92 | 93 | da = [x_ + y_ for (x_, y_) in zip(hd, ha)] 94 | bc = [x_ + y_ for (x_, y_) in zip(hb, hstar(hc))] 95 | 96 | hm_a = [f"{x_} {' '.join([f'-{t}' for t in y_.split()])}" for (x_, y_) in zip(ac, db)] 97 | hm_a.extend([f"{x_} {y_}" for (x_, y_) in zip(da, bc)]) 98 | return hm_a 99 | 100 | 101 | def hmat(h1, h2): 102 | # if len(h1) == 2: 103 | # hm = zmult(h1, h2) 104 | # elif len(h1) == 4: 105 | # hm = qmult(h1, h2) 106 | # elif len(h1) == 8: 107 | # hm = omult(h1, h2) 108 | # elif len(h1) == 16: 109 | # hm = smult(h1, h2) 110 | hm = hmult(h1, h2) 111 | 112 | m_out = [] 113 | for h_ in hm: 114 | m_temp = [] 115 | for x_ in h2: 116 | for h1_ in h_.split(): 117 | if x_ in h1_: 118 | h2_ = h1_.replace(x_, '') 119 | c_ = h2_.count('-') 120 | if c_ % 2: 121 | h2_ = f"-{h2_.replace('-','')}" 122 | else: 123 | h2_ = f"{h2_.replace('-','')}" 124 | m_temp.append(h2_) 125 | m_out.append(m_temp) 126 | return m_out 127 | 128 | 129 | def cmat(h1, h2): 130 | # if len(h1) == 2: 131 | # hm = zmult(h1, h2) 132 | # elif len(h1) == 4: 133 | # hm = qmult(h1, h2) 134 | # elif len(h1) == 8: 135 | # hm = omult(h1, h2) 136 | # elif len(h1) == 16: 137 | # hm = smult(h1, h2) 138 | hm = hmult(h1, h2) 139 | 140 | m_out = [] 141 | for h_ in hm: 142 | m_temp = [] 143 | for x_ in h2: 144 | for h1_ in h_.split(): 145 | if x_ in h1_: 146 | h2_ = h1_.replace(x_, '') 147 | c_ = h2_.count('-') 148 | if c_ % 2: 149 | h2_ = -int(f"{h2_.replace('-','').replace('w','')}") 150 | else: 151 | h2_ = int(f"{h2_.replace('-','').replace('w','')}") 152 | m_temp.append(h2_) 153 | m_out.append(m_temp) 154 | return m_out 155 | 156 | 157 | def get_comp_mat(num_components=8): 158 | h1 = [f'w{component}' for component in range(num_components)] 159 | h2 = [f'f{component}f' for component in range(num_components)] 160 | return np.array(cmat(h1, h2)) 161 | 162 | 163 | def get_hmat(num_components=8): 164 | h1 = [f'w{component}' for component in range(num_components)] 165 | h2 = [f'f{component}f' for component in range(num_components)] 166 | return hmat(h1, h2) 167 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | # Author: Pedro Herruzo 2 | # Copyright 2021 Institute of Advanced Research in Artificial Intelligence (IARAI) GmbH. 3 | # IARAI licenses this file to You under the Apache License, Version 2.0 4 | # (the "License"); you may not use this file except in compliance with 5 | # the License. You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | ################################ 15 | # modified: Alabi Bojesomo 16 | # Date: 17 August 2021 17 | # Additional Copyright: Khalifa University Abu Dhabi 18 | 19 | import argparse 20 | import datetime 21 | 22 | from torch.utils.data import DataLoader 23 | 24 | #from w4c_dataloader import create_dataset 25 | from utils.w4c_dataloader import create_dataset 26 | 27 | import numpy as np 28 | import pandas as pd 29 | import pathlib 30 | import sys 31 | import os 32 | import json 33 | module_dir = str(pathlib.Path(os.getcwd()).parent) 34 | sys.path.append(module_dir) 35 | 36 | import config as cf 37 | import utils.data_utils as data_utils 38 | #from benchmarks.FeaturesSysUNet import FeaturesSysUNet as Model 39 | from model_pl import Model 40 | 41 | import glob 42 | 43 | # ------------ 44 | # 1. Create output folders 45 | # ------------ 46 | def create_directory_structure(root, region, folder_name='inference'): 47 | """ 48 | Create the inference output directory structure at given root path: root/folder_name 49 | """ 50 | 51 | # create the main fo 52 | metadata_path = os.path.join(root, folder_name) 53 | out_path = os.path.join(metadata_path, region, 'test') 54 | try: 55 | # os.makedirs(r_path) 56 | os.makedirs(out_path) 57 | print(f'created path: {out_path}') 58 | 59 | except: 60 | print(f'failed to create directory structure, maybe they already exist: {out_path}') 61 | return metadata_path, out_path 62 | 63 | # ------------ 64 | # 2. Prepare metadata needed by the weather4cast dataloader: `w4c_dataloader` 65 | # ------------ 66 | def get_bin_labels(): 67 | return ['{}{}{}{}00'.format('0'*bool(i<10), i, '0'*bool(j<10), j) for i in np.arange(0, 24, 1) for j in np.arange(0, 60, 15)] 68 | 69 | def fn_time_2_timebin(): 70 | times = get_bin_labels() 71 | bins = {t_str: tbin for tbin, t_str in enumerate(times) } 72 | return bins 73 | 74 | def get_out_bins(start, end, id_date, n_bins, time_bin_labels=get_bin_labels()): 75 | """ Creates the meta-data for the time intervals to be predicted """ 76 | i = 0 77 | bins_holder = {} 78 | 79 | for idx_bin in range(start, end): 80 | 81 | if idx_bin%n_bins == 0: # jump to next day 82 | day = int(id_date[-3:]) + 1 # ToDo %365 83 | 84 | zeros_before = '0'*(3 - len(str(day))%4) 85 | id_day = zeros_before + str(day) 86 | id_date = id_date[:-3]+id_day 87 | 88 | bins_holder[i] = {'id_day': id_date[-3:], 'id_bin': idx_bin%n_bins, 'time_bin': time_bin_labels[idx_bin%n_bins], 89 | 'date': datetime.datetime.strptime(id_date[:-3]+' '+id_date[-3:], '%Y %j').strftime('%Y%m%d')} 90 | i += 1 91 | 92 | return bins_holder 93 | 94 | def create_test_csv_json(data_p, region_id, metadata_path, product='CMA', 95 | n_bins=96, n_preds=32, n_files=4, competition='stage-1'): 96 | """ Creates a metadata filling input/output time intervals for a given folder. 97 | It uses the files inside the folder of product `product` to inform the time intervals.¡ 98 | """ 99 | #if region_id in ['R1', 'R2', 'R3', 'R7', 'R8']: 100 | # track = 'w4c-core-stage-1' 101 | #else: 102 | # track = 'w4c-transfer-learning-stage-1' 103 | if competition == 'stage-1': 104 | if region_id in ['R1', 'R2', 'R3']: 105 | track = 'w4c-core-stage-1' 106 | else: 107 | track = 'w4c-transfer-learning-stage-1' 108 | else: # competition is now ieee-bd 109 | if region_id in ['R1', 'R2', 'R3', 'R7', 'R8']: 110 | track = 'ieee-bd-core' 111 | else: 112 | track = 'ieee-bd-transfer-learning' 113 | # 1. get the dates to make inference from 114 | root = os.path.join(data_p, track, region_id, 'test') 115 | dates = [name for name in os.listdir(root) if os.path.isdir(root)] 116 | dates.sort() 117 | 118 | cols = ['id_date', 'split_id', 'split', 'id_day', 'date'] 119 | date_split = [] 120 | date_timebins = {} 121 | time_2_timebin = fn_time_2_timebin() 122 | 123 | for date in dates: 124 | # get the 4 input time intervals & sort them 125 | tmp_p = os.path.join(root, date, product, '*.nc') 126 | files = glob.glob(tmp_p) 127 | files.sort() 128 | assert len(files) == n_files, f'Number of files must be {n_files}, check your input folders' 129 | 130 | # get day and time from the files 131 | bins_day = {'bins_in': {}, 'bins_out': {}} 132 | for i, f in enumerate(files): 133 | f = f.split('_')[-1].split('Z')[0].split('T') 134 | day, time = f[0], f[-1] 135 | idx_timebin = time_2_timebin[time] 136 | 137 | # data to add to the json 138 | tmp = {'id_day': date[-3:], 'id_bin': idx_timebin, 'time_bin': time, 'date': day} 139 | bins_day['bins_in'][str(i)] = tmp 140 | # print(tmp) 141 | # print(day, time) 142 | 143 | if i == 0: 144 | # data to add to the csv 145 | date_split.append([date, 2, 'test', date[-3:], day]) 146 | 147 | idx_timebin += 1 # set the next time bin (the one to start predicting) 148 | bins_day['bins_out'] = get_out_bins(idx_timebin, idx_timebin+n_preds, date, n_bins) 149 | date_timebins[date[-3:]] = bins_day 150 | df = pd.DataFrame(date_split, columns=cols) 151 | 152 | # safe the files 153 | df.to_csv(os.path.join(metadata_path, 'splits.csv')) 154 | with open(os.path.join(metadata_path, 'test_split.json'), 'w', encoding='utf-8') as f: 155 | json.dump(date_timebins, f, ensure_ascii=False, indent=4) 156 | with open(os.path.join(metadata_path, 'blacklist.json'), 'w', encoding='utf-8') as f: 157 | json.dump({}, f, ensure_ascii=False, indent=4) 158 | 159 | # ------------ 160 | # 3. load data & model 161 | # ------------ 162 | def get_data_iterator(region_id, data_path, splits_path, data_split='test', collapse_time=False, 163 | batch_size=32, shuffle=False, num_workers=0): 164 | """ Creates an iterator for data in region 'region_id' for the days in `splits_path` 165 | """ 166 | params = cf.get_params(region_id=region_id, competition='stage-1', 167 | collapse_time=collapse_time, use_static=False, 168 | use_all_variables=False,use_cloud_type=False, use_time_slot=False) 169 | 170 | #params = cf.get_params(region_id=region_id, data_path=data_path, splits_path=splits_path) 171 | #params['data_params']['collapse_time'] = collapse_time 172 | 173 | ds = create_dataset('test', params['data_params'], precision=32, populate_mask=True) 174 | #ds = create_dataset(data_split, params['data_params']) 175 | dataloader = DataLoader(ds, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers) 176 | 177 | data_splits, test_sequences = data_utils.read_splits(params['data_params']['train_splits'], params['data_params']['test_splits']) 178 | test_dates = data_splits[data_splits.split=='test'].id_date.sort_values().values 179 | 180 | return iter(dataloader), test_dates, params 181 | 182 | def load_model_and_weights(root_to_ckp, device=None): 183 | """ Loads a model as we basically have a single model trained on all the available data 184 | """ 185 | # regions where we have trained models 186 | #\logs\ALL_real_swinencoder3d_688080\20210630T224355\checkpoints\epoch=58-val_loss=0.029748.ckpt 187 | checkpoint_path = f"{root_to_ckp}/20210630T224355/checkpoints/epoch=58-val_loss=0.029748.ckpt" 188 | model = Model.load_from_checkpoint(checkpoint_path) 189 | if device is not None: 190 | model = model.cuda(device) 191 | 192 | return model 193 | 194 | # ------------ 195 | # 4. Make predictions 196 | # ------------ 197 | def get_preds(model, batch, device=None): 198 | """ Computes the output of the model on the next iterator's batch and 199 | returns the prediction and the date of it. 200 | """ 201 | 202 | in_seq, out, metadata = batch 203 | #day_in_year = metadata['in']['day_in_year'][0][0].item() 204 | day_in_year = metadata['out']['day_in_year'][0][0].item() 205 | 206 | if device is not None: 207 | in_seq = in_seq.cuda(device=device) 208 | y_hat = model(in_seq) 209 | #y_hat = y_hat.data.cpu().numpy() 210 | 211 | y_hat = torch.reshape(y_hat, (-1, 32, 4, 256, 256)) 212 | y_hat = y_hat.data.cpu().numpy() 213 | 214 | return y_hat, day_in_year 215 | 216 | def predictions_per_day(test_dates, model, ds_iterator, device, file_path, data_params): 217 | """ Computes predictions of all dates and saves them to disk """ 218 | model = model.eval() 219 | for target_date in test_dates: 220 | print(f'generating submission for date: {target_date}...') 221 | batch = next(ds_iterator) 222 | y_hat, predicted_day = get_preds(model, batch, device) 223 | 224 | # force data to be in the valid range 225 | y_hat[y_hat>1] = 1 226 | y_hat[y_hat<0] = 0 227 | 228 | # batches are sorted by date for the dataloader, that's why they coincide 229 | assert predicted_day==target_date, f"Error, the loaded date {predicted_day} is different than the target: {target_date}" 230 | 231 | f_path = os.path.join(file_path, f'{predicted_day}.h5') 232 | y_hat = data_utils.postprocess_fn(y_hat, data_params['target_vars'], data_params['preprocess']['source']) 233 | data_utils.write_data(y_hat, f_path) 234 | 235 | print(f'--> saved in: {f_path}') 236 | 237 | def predictions_per_day_ensamble(test_dates, models, ds_iterator, device, file_path, data_params): 238 | """ Computes predictions of all dates and saves them to disk. It uses the average of predictions across all models provided 239 | models (list): list of models to be used in the ensample 240 | """ 241 | for target_date in test_dates: 242 | print(f'generating submission for date: {target_date}...') 243 | batch = next(ds_iterator) 244 | 245 | ensamble = [] 246 | for model in models: 247 | y_hat, predicted_day = get_preds(model, batch, device) 248 | 249 | # force data to be in the valid range 250 | y_hat[y_hat>1] = 1 251 | y_hat[y_hat<0] = 0 252 | 253 | # batches are sorted by date for the dataloader, that's why they coincide 254 | assert predicted_day==target_date, f"Error, the loaded date {predicted_day} is different than the target: {target_date}" 255 | 256 | ensamble.append(y_hat) 257 | 258 | ensamble = np.asarray(ensamble) 259 | y_hat = np.mean(ensamble, axis=0) 260 | 261 | f_path = os.path.join(file_path, f'{predicted_day}.h5') 262 | y_hat = data_utils.postprocess_fn(y_hat, data_params['target_vars'], data_params['preprocess']['source']) 263 | data_utils.write_data(y_hat, f_path) 264 | print(f'--> saved in: {f_path}') 265 | 266 | # ------------ 267 | # 5. Main program 268 | # ------------ 269 | def inference(data_p, region, weights, output, gpu_id): 270 | """ Computes predictions using inputs from the `test` folder in: `data_p/-w4c/region_id` 271 | This script must load all needed weigths from folder: `weights` 272 | and save predictions in folder `outputs` 273 | """ 274 | # ------------ 275 | # A. input/output preparation 276 | # ------------ 277 | # 1. create a folder to save the predictions per day 278 | metadata_path, out_path = create_directory_structure(output, region, folder_name='inference') 279 | 280 | # 2. create the csv and json needed by the class `dataset` to provide single sequences per batch 281 | # so we can save to disk single predictions per day of shape (32, 4, 256, 256) 282 | create_test_csv_json(data_p, region, metadata_path) 283 | 284 | # ------------ 285 | # B. model & data loading: 286 | # 287 | # This part of the code must load the data and models. If you used the same `dataset` class we provided 288 | # you probably only need to modify loading the models. Adapt the code so it loads the learned weights from 289 | # the folder `weights` you provided for them 290 | # ------------ 291 | ds_iterator, test_dates, params = get_data_iterator(region, data_p, metadata_path) 292 | model = load_model_and_weights(weights, device=gpu_id) 293 | 294 | # ------------ 295 | # C. Predict and save the predictions 296 | # ------------ 297 | predictions_per_day(test_dates, model, ds_iterator, gpu_id, out_path, params['data_params']) 298 | 299 | def set_parser(): 300 | """ set custom parser """ 301 | 302 | parser = argparse.ArgumentParser(description="") 303 | parser.add_argument("-d", "--data", type=str, required=True, 304 | help='path to a folder containing days to be predicted (e.g. the test folder of the test dataset)') 305 | parser.add_argument("-r", "--region", type=str, required=False, default='R1', 306 | help='Region where the data belongs.') 307 | parser.add_argument("-w", "--weights", type=str, required=True, 308 | help='path to a folder containing all required weights of the model') 309 | parser.add_argument("-o", "--output", type=str, required=True, 310 | help='path to save the outputs of the model for each day.') 311 | parser.add_argument("-g", "--gpu_id", type=int, required=False, default=1, 312 | help="specify a gpu ID. 1 as default, -1 for CPU.") 313 | 314 | return parser 315 | 316 | def main(): 317 | 318 | parser = set_parser() 319 | options = parser.parse_args() 320 | if options.gpu_id < 0: 321 | options.gpu_id = None 322 | inference(options.data, options.region, options.weights, options.output, options.gpu_id) 323 | 324 | if __name__ == "__main__": 325 | main() 326 | 327 | """ examples of usage: 328 | - inference for Region R1, using a single model 329 | R=R1 330 | INPUT_PATH=../data 331 | WEIGHTS=~/projects/weather4cast/ligh_logs_old/old 332 | OUT_PATH=. 333 | python inference.py -d $INPUT_PATH -r $R -w $WEIGHTS -o $OUT_PATH -g 1 334 | """ -------------------------------------------------------------------------------- /logs/ALL_real_swinencoder3d_688080/20210630T224355/checkpoints/epoch=58-val_loss=0.029748.ckpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bojesomo/Weather4cast2021-SwinEncoderDecoder/8c158890628be2a28d47aa082fd96eb43c6cbb4a/logs/ALL_real_swinencoder3d_688080/20210630T224355/checkpoints/epoch=58-val_loss=0.029748.ckpt -------------------------------------------------------------------------------- /logs/ALL_real_swinencoder3d_688080/20210630T224355/hparams.yaml: -------------------------------------------------------------------------------- 1 | args: !!python/object:argparse.Namespace 2 | augment_data: true 3 | batch_norm: false 4 | batch_size: 1 5 | beta_1: 0.9 6 | beta_2: 0.999 7 | bins_to_predict: 32 8 | black_list_path: /home/farhanakram/PycharmProjects/HypercomplexNetwork/Weather4cast2021/code_base/blacklist.json 9 | blk_type: swinencoder3d 10 | checkpoint: '' 11 | classifier_activation: sigmoid 12 | collapse_time: false 13 | control_params: 14 | use_all_variables: false 15 | use_cloud_type: false 16 | use_static: false 17 | use_time_slot: false 18 | crop_in: null 19 | crop_out: null 20 | crop_static: 21 | size: 256 22 | x_start: 1850 23 | y_start: 760 24 | data_path: /home/farhanakram/Alabi/Datasets/Weather4cast2021/w4c-core-stage-1/R1 25 | day_bins: 96 26 | dense_type: D 27 | depth: 4 28 | dropout: 0.0 29 | epochs: 100 30 | epsilon: 1.0e-08 31 | extra_data: '' 32 | filename: ALL_real_swinencoder3d_688080 33 | gpus: 0,1 34 | growth_rate: 64 35 | height: 256 36 | hidden_activation: elu 37 | in_channels: 4 38 | inplace_activation: true 39 | input_vars: 40 | - temperature 41 | - crr_intensity 42 | - asii_turb_trop_prob 43 | - cma 44 | len_seq_in: 4 45 | len_seq_out: 32 46 | log_dir: logs 47 | lr: 0.0001 48 | manual_seed: 0 49 | memory_efficient: true 50 | mode: train 51 | model_dir: '' 52 | modify_activation: true 53 | momentum: 0.9 54 | n_classes: 128 55 | name: '' 56 | nb_layers: 4 57 | net_type: real 58 | optimizer: adam 59 | padding: true 60 | patch_size: 2 61 | populate_mask: true 62 | precision: 32 63 | preprocess: 64 | source: 65 | asii_turb_trop_prob: 66 | add_offset: 0 67 | fill_value: 0 68 | max_value: 100 69 | scale_factor: 1 70 | cma: 71 | add_offset: 0 72 | fill_value: 0 73 | max_value: 1 74 | scale_factor: 1 75 | crr_intensity: 76 | add_offset: 0 77 | fill_value: 0 78 | max_value: 500 79 | scale_factor: &id002 !!python/object/apply:numpy.core.multiarray.scalar 80 | - &id001 !!python/object/apply:numpy.dtype 81 | args: 82 | - f4 83 | - 0 84 | - 1 85 | state: !!python/tuple 86 | - 3 87 | - < 88 | - null 89 | - null 90 | - null 91 | - -1 92 | - -1 93 | - 0 94 | - !!binary | 95 | zczMPQ== 96 | temperature: 97 | add_offset: 130 98 | fill_value: 0 99 | max_value: 35000 100 | scale_factor: &id003 !!python/object/apply:numpy.core.multiarray.scalar 101 | - *id001 102 | - !!binary | 103 | CtcjPA== 104 | target: 105 | asii_turb_trop_prob: 106 | add_offset: 0 107 | fill_value: .nan 108 | max_value: 100 109 | scale_factor: 1 110 | cma: 111 | add_offset: 0 112 | fill_value: .nan 113 | max_value: 1 114 | scale_factor: 1 115 | crr_intensity: 116 | add_offset: 0 117 | fill_value: .nan 118 | max_value: 500 119 | scale_factor: &id004 !!python/object/apply:numpy.core.multiarray.scalar 120 | - *id001 121 | - !!binary | 122 | zczMPQ== 123 | temperature: 124 | add_offset: 130 125 | fill_value: .nan 126 | max_value: 35000 127 | scale_factor: &id005 !!python/object/apply:numpy.core.multiarray.scalar 128 | - *id001 129 | - !!binary | 130 | CtcjPA== 131 | products: 132 | ASII: 133 | - asii_turb_trop_prob 134 | CMA: 135 | - cma 136 | CRR: 137 | - crr_intensity 138 | CTTH: 139 | - temperature 140 | region: R1 141 | region_id: R1 142 | seq_mode: sliding_window 143 | sf: 16 144 | sf_grp: 1 145 | spatial_dim: !!python/tuple 146 | - 256 147 | - 256 148 | stages: 3 149 | static_paths: 150 | e: /home/farhanakram/Alabi/Datasets/Weather4cast2021/statics/S_NWC_TOPO_MSG4_+000.0_Europe-VISIR.raw 151 | l: /home/farhanakram/Alabi/Datasets/Weather4cast2021/statics/Navigation_of_S_NWC_CT_MSG4_Europe-VISIR_20201106T120000Z.nc 152 | target_vars: 153 | - temperature 154 | - crr_intensity 155 | - asii_turb_trop_prob 156 | - cma 157 | test_splits: /home/farhanakram/PycharmProjects/HypercomplexNetwork/Weather4cast2021/code_base/test_split.json 158 | time_code: 20210630T224355 159 | train_dims: 49794 160 | train_region_id: R1_mse 161 | train_splits: /home/farhanakram/PycharmProjects/HypercomplexNetwork/Weather4cast2021/code_base/splits.csv 162 | up_mode: upconv 163 | use_all_region: true 164 | use_all_variables: false 165 | use_cloud_type: false 166 | use_group_norm: false 167 | use_static: false 168 | use_time_slot: false 169 | versiondir: logs/ALL_real_swinencoder3d_688080/20210630T224355 170 | weight_decay: 1.0e-06 171 | weigths: 172 | asii_turb_trop_prob: 0.25 173 | cma: 0.25 174 | crr_intensity: 0.25 175 | temperature: 0.25 176 | wf: 6 177 | width: 256 178 | workers: 8 179 | bins_to_predict: 32 180 | black_list_path: /home/farhanakram/PycharmProjects/HypercomplexNetwork/Weather4cast2021/code_base/blacklist.json 181 | collapse_time: false 182 | control_params: 183 | use_all_variables: false 184 | use_cloud_type: false 185 | use_static: false 186 | use_time_slot: false 187 | crop_in: null 188 | crop_out: null 189 | crop_static: 190 | size: 256 191 | x_start: 1850 192 | y_start: 760 193 | data_path: /home/farhanakram/Alabi/Datasets/Weather4cast2021/w4c-core-stage-1/R1 194 | day_bins: 96 195 | depth: 4 196 | extra_data: '' 197 | height: 256 198 | input_vars: 199 | - temperature 200 | - crr_intensity 201 | - asii_turb_trop_prob 202 | - cma 203 | len_seq_in: 4 204 | len_seq_out: 32 205 | preprocess: 206 | source: 207 | asii_turb_trop_prob: 208 | add_offset: 0 209 | fill_value: 0 210 | max_value: 100 211 | scale_factor: 1 212 | cma: 213 | add_offset: 0 214 | fill_value: 0 215 | max_value: 1 216 | scale_factor: 1 217 | crr_intensity: 218 | add_offset: 0 219 | fill_value: 0 220 | max_value: 500 221 | scale_factor: *id002 222 | temperature: 223 | add_offset: 130 224 | fill_value: 0 225 | max_value: 35000 226 | scale_factor: *id003 227 | target: 228 | asii_turb_trop_prob: 229 | add_offset: 0 230 | fill_value: .nan 231 | max_value: 100 232 | scale_factor: 1 233 | cma: 234 | add_offset: 0 235 | fill_value: .nan 236 | max_value: 1 237 | scale_factor: 1 238 | crr_intensity: 239 | add_offset: 0 240 | fill_value: .nan 241 | max_value: 500 242 | scale_factor: *id004 243 | temperature: 244 | add_offset: 130 245 | fill_value: .nan 246 | max_value: 35000 247 | scale_factor: *id005 248 | products: 249 | ASII: 250 | - asii_turb_trop_prob 251 | CMA: 252 | - cma 253 | CRR: 254 | - crr_intensity 255 | CTTH: 256 | - temperature 257 | region_id: R1 258 | seq_mode: sliding_window 259 | spatial_dim: !!python/tuple 260 | - 256 261 | - 256 262 | static_paths: 263 | e: /home/farhanakram/Alabi/Datasets/Weather4cast2021/statics/S_NWC_TOPO_MSG4_+000.0_Europe-VISIR.raw 264 | l: /home/farhanakram/Alabi/Datasets/Weather4cast2021/statics/Navigation_of_S_NWC_CT_MSG4_Europe-VISIR_20201106T120000Z.nc 265 | target_vars: 266 | - temperature 267 | - crr_intensity 268 | - asii_turb_trop_prob 269 | - cma 270 | test_splits: /home/farhanakram/PycharmProjects/HypercomplexNetwork/Weather4cast2021/code_base/test_split.json 271 | train_region_id: R1_mse 272 | train_splits: /home/farhanakram/PycharmProjects/HypercomplexNetwork/Weather4cast2021/code_base/splits.csv 273 | use_all_region: true 274 | weigths: 275 | asii_turb_trop_prob: 0.25 276 | cma: 0.25 277 | crr_intensity: 0.25 278 | temperature: 0.25 279 | width: 256 280 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | import sys 3 | import os 4 | 5 | module_dir = str(pathlib.Path(os.getcwd())) 6 | sys.path.append(module_dir) 7 | 8 | import re 9 | import argparse 10 | import warnings 11 | 12 | import numpy as np 13 | import pandas as pd 14 | import time 15 | import pytorch_lightning as pl 16 | from torch.utils.data import DataLoader, ConcatDataset 17 | import torch 18 | 19 | from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor 20 | from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping 21 | from pytorch_lightning.plugins import DDPPlugin 22 | from pytorch_lightning.loggers import CSVLogger 23 | 24 | import config as cf 25 | from model_pl import Model 26 | from models import HyperSwinEncoderDecoder3D 27 | 28 | from utils.w4c_dataloader import create_dataset 29 | from train_utils import model_summary 30 | 31 | 32 | def get_held_out_params(params): 33 | held_out_params = params 34 | # print(params.keys()) 35 | old_path = held_out_params['data_path'] 36 | paths = re.split('/|\\\\', old_path) 37 | paths[-2] += '-heldout' 38 | new_path = f'{os.sep}'.join(paths) 39 | held_out_params['data_path'] = new_path 40 | return held_out_params 41 | 42 | 43 | class DataModule(pl.LightningDataModule): 44 | """ Class to handle training/validation splits in a single object 45 | """ 46 | 47 | def __init__(self, params, training_params, args): 48 | super().__init__() 49 | self.params = params 50 | self.training_params = training_params 51 | self.args = args 52 | self.train_dims = None 53 | self.precision = args.precision if hasattr(args, 'precision') else 16 54 | self.populate_mask = args.populate_mask if hasattr(args, 'populate_mask') else False 55 | 56 | self.train = self.val = self.predict = self.held_out = None 57 | self.all_regions = self.core_regions = None 58 | 59 | def setup(self): 60 | if self.params['use_all_region']: 61 | train_datasets = [] 62 | val_datasets = [] 63 | predict_datasets = [] 64 | held_out_datasets = [] 65 | core_regions = ['R1', 'R2', 'R3'] 66 | all_regions = [f"R{i + 1}" for i in range(6)] 67 | if self.args.competition == 'ieee-bd': 68 | core_regions.extend(['R7', 'R8']) 69 | all_regions.extend([f"R{i + 1}" for i in range(6, 11)]) 70 | self.all_regions = all_regions 71 | self.core_regions = core_regions 72 | for region_id in core_regions: 73 | params_i = cf.get_params(region_id=region_id, competition=self.args.competition, 74 | collapse_time=self.args.collapse_time, use_static=self.args.use_static, 75 | use_all_variables=self.args.use_all_variables, 76 | use_cloud_type=self.args.use_cloud_type, use_time_slot=self.args.use_time_slot) 77 | train_datasets.append(create_dataset('training', params_i['data_params'], precision=self.precision, 78 | populate_mask=self.populate_mask)) 79 | val_datasets.append(create_dataset('validation', params_i['data_params'], precision=self.precision, 80 | populate_mask=self.populate_mask)) 81 | 82 | for region_id in all_regions: 83 | params_i = cf.get_params(region_id=region_id, competition=self.args.competition, 84 | collapse_time=self.args.collapse_time, use_static=self.args.use_static, 85 | use_all_variables=self.args.use_all_variables, 86 | use_cloud_type=self.args.use_cloud_type, use_time_slot=self.args.use_time_slot) 87 | predict_datasets.append(create_dataset('test', params_i['data_params'], precision=self.precision, 88 | populate_mask=self.populate_mask)) 89 | if self.args.held_out: # using held-out data 90 | held_out_params = get_held_out_params(params_i['data_params']) 91 | held_out_datasets.append(create_dataset('test', held_out_params, precision=self.precision, 92 | populate_mask=self.populate_mask)) 93 | 94 | self.train = ConcatDataset(train_datasets) 95 | self.val = ConcatDataset(val_datasets) 96 | self.predict = ConcatDataset(predict_datasets) 97 | if self.args.held_out: # using held-out data 98 | self.held_out = ConcatDataset(held_out_datasets) 99 | else: 100 | self.train = create_dataset('training', self.params, precision=self.precision, 101 | populate_mask=self.populate_mask) 102 | self.val = create_dataset('validation', self.params, precision=self.precision, 103 | populate_mask=self.populate_mask) 104 | self.predict = create_dataset('test', self.params, precision=self.precision, 105 | populate_mask=self.populate_mask) 106 | if self.args.held_out: # using held-out data 107 | held_out_params = get_held_out_params(self.params) 108 | self.held_out = create_dataset('test', held_out_params, precision=self.precision, 109 | populate_mask=self.populate_mask) 110 | 111 | 112 | self.train_dims = self.train.__len__() 113 | 114 | def __load_dataloader(self, dataset, shuffle=True, pin=True): 115 | dl = DataLoader(dataset, 116 | batch_size=self.training_params['batch_size'], num_workers=self.training_params['n_workers'], 117 | shuffle=shuffle, pin_memory=pin) 118 | return dl 119 | 120 | def train_dataloader(self): 121 | ds = self.train # create_dataset('training', self.params) 122 | return self.__load_dataloader(ds, shuffle=True, pin=True) 123 | 124 | def val_dataloader(self): 125 | val_loader = self.__load_dataloader(self.val, shuffle=False, pin=True) 126 | predict_loader = self.__load_dataloader(self.predict, shuffle=False, pin=True) 127 | return [val_loader, predict_loader] 128 | 129 | def test_dataloader(self): 130 | if self.args.held_out: # using held-out data 131 | predict_loader = self.__load_dataloader(self.held_out, shuffle=False, pin=True) 132 | else: 133 | predict_loader = self.__load_dataloader(self.predict, shuffle=False, pin=True) 134 | return predict_loader 135 | 136 | 137 | def print_training(params): 138 | """ print pre-training info """ 139 | 140 | print(f'Extra variables: {params["extra_data"]} | spatial_dim: {params["spatial_dim"]} ', 141 | f'| collapse_time: {params["collapse_time"]} | in channels depth: {params["depth"]} | len_seq_in: {params["len_seq_in"]}') 142 | 143 | 144 | def modify_options(options, n_params): 145 | filename = '_'.join( 146 | [f"{item}" for item in ('ALL' if options.use_all_region else options.region, options.net_type, 'swinencoder3d', 147 | int(n_params))]) 148 | options.filename = options.name or filename # to account for resuming from a previous state 149 | 150 | options.versiondir = os.path.join(options.log_dir, options.filename, options.time_code) 151 | os.makedirs(options.versiondir, exist_ok=True) 152 | readme_file = os.path.join(options.versiondir, 'options.csv') 153 | args_dict = vars(argparse.Namespace(**{'modelname': options.filename, 'num_params': n_params}, **vars(options))) 154 | args_df = pd.DataFrame([args_dict]) 155 | if os.path.exists(readme_file): 156 | args_df.to_csv(readme_file, mode='a', index=False, header=False) 157 | else: 158 | args_df.to_csv(readme_file, mode='a', index=False) 159 | 160 | return options 161 | 162 | 163 | def save_options(options, n_params): 164 | options.versiondir = os.path.join(options.log_dir, options.filename, options.time_code) 165 | os.makedirs(options.versiondir, exist_ok=True) 166 | readme_file = os.path.join(options.versiondir, 'options.csv') 167 | args_dict = vars(argparse.Namespace(**{'modelname': options.filename, 'num_params': n_params}, **vars(options))) 168 | args_df = pd.DataFrame([args_dict]) 169 | if os.path.exists(readme_file): 170 | args_df.to_csv(readme_file, mode='a', index=False, header=False) 171 | else: 172 | args_df.to_csv(readme_file, mode='a', index=False) 173 | return options 174 | 175 | 176 | def get_trainer(options): 177 | """ get the trainer, modify here it's options: 178 | - save_top_k 179 | - max_epochs 180 | """ 181 | lr_monitor = LearningRateMonitor(logging_interval='step') 182 | 183 | early_stop_callback = EarlyStopping( 184 | monitor='val_loss', # should be found in logs 185 | patience=20, 186 | strict=False, # will act as disabled if monitor not found 187 | verbose=False, 188 | mode='min' 189 | ) 190 | 191 | logger = CSVLogger(save_dir=options.log_dir, 192 | name=options.filename, 193 | version=options.time_code, 194 | ) 195 | 196 | resume_from_checkpoint = None 197 | if options.name and options.time_code: 198 | checkpoint_dir = os.path.join(options.versiondir, 'checkpoints') 199 | if options.initial_epoch == -1: 200 | checkpoint_name = 'last.ckpt' 201 | else: 202 | format_str = f"epoch={options.initial_epoch:02g}" 203 | checkpoint_names = os.listdir(checkpoint_dir) 204 | checkpoint_name = checkpoint_names[[t.startswith(format_str) for t in checkpoint_names].index(True)] 205 | resume_from_checkpoint = os.path.join(checkpoint_dir, checkpoint_name) 206 | 207 | 208 | checkpoint_callback = ModelCheckpoint(monitor='val_loss', mode='min', save_top_k=3, 209 | save_last=True, verbose=False, 210 | filename='{epoch:02d}-{val_loss:.6f}') 211 | 212 | callbacks = [lr_monitor, checkpoint_callback, early_stop_callback] 213 | 214 | trainer = pl.Trainer(gpus=options.gpus, 215 | max_epochs=options.epochs, 216 | progress_bar_refresh_rate=10, 217 | deterministic=True, 218 | gradient_clip_val=1, # to clip gradient value and prevent exploding gradient 219 | gradient_clip_algorithm='value', 220 | default_root_dir=os.path.dirname(options.log_dir), 221 | callbacks=callbacks, 222 | profiler='simple', 223 | sync_batchnorm=True, 224 | num_sanity_val_steps=0, 225 | # accelerator='ddp', 226 | logger=logger, 227 | resume_from_checkpoint=resume_from_checkpoint, 228 | num_nodes=1, 229 | precision=options.precision if hasattr(options, 'precision') else 16, 230 | ) 231 | 232 | return trainer 233 | 234 | 235 | def train(region_id, mode, options=None): 236 | """ main training/evaluation method 237 | """ 238 | 239 | # some needed stuffs 240 | warnings.filterwarnings("ignore") 241 | 242 | params = cf.get_params(region_id=region_id, competition=options.competition, collapse_time=options.collapse_time, 243 | use_static=options.use_static, use_all_variables=options.use_all_variables, 244 | use_cloud_type=options.use_cloud_type, use_time_slot=options.use_time_slot) 245 | # print(params['data_params']) 246 | params['data_params']['use_all_region'] = options.use_all_region 247 | options = argparse.Namespace(**{**vars(options), **params['model_params'], **params['data_params']}) \ 248 | if options else argparse.Namespace(**params) 249 | 250 | pl.seed_everything(options.manual_seed, workers=True) 251 | torch.manual_seed(options.manual_seed) 252 | torch.cuda.manual_seed_all(options.manual_seed) 253 | 254 | training_params = {'batch_size': options.batch_size, 255 | 'n_workers': options.workers # 8 256 | } 257 | 258 | # ------------ 259 | # Data and model params 260 | # ------------ 261 | data = DataModule(params['data_params'], training_params, options) 262 | data.setup() 263 | 264 | # add other depending args 265 | options.train_dims = data.train_dims 266 | options.core_regions = data.core_regions 267 | options.all_regions = data.all_regions 268 | 269 | # let's load model for printing structure 270 | model = HyperSwinEncoderDecoder3D(options) 271 | n_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 272 | x_all = torch.rand(1, params['data_params']['len_seq_in'], params['data_params']['depth'], 256, 256) 273 | _ = model_summary(model, x_all, print_summary=True, max_depth=1) 274 | del model, x_all 275 | 276 | # ------------ 277 | # trainer 278 | # ------------ 279 | options = modify_options(options, n_params) 280 | trainer = get_trainer(options) 281 | print_training(params['data_params']) 282 | 283 | # ------ 284 | # Model 285 | # ----- 286 | checkpoint_path = trainer.resume_from_checkpoint 287 | if checkpoint_path is not None: 288 | model = Model.load_from_checkpoint(checkpoint_path) 289 | else: 290 | model = Model(options, **params['data_params']) 291 | 292 | 293 | 294 | print(options) 295 | # ------------ 296 | # train & final validation 297 | # ------------ 298 | if mode == 'train': 299 | print("-----------------") 300 | print("-- TRAIN MODE ---") 301 | print("-----------------") 302 | trainer.fit(model, datamodule=data) 303 | # elif mode == 'val': 304 | # print("-----------------") 305 | # print("-- Validation only for metric collection---") 306 | # print("-----------------") 307 | # trainer.validate(model, datamodule=data) 308 | else: 309 | print("-----------------") 310 | print("--- TEST MODE ---") 311 | print("-----------------") 312 | trainer.test(model, datamodule=data) 313 | 314 | 315 | def set_parser(parent_parser): 316 | """ set custom parser """ 317 | parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False) 318 | parser.add_argument("-g", "--gpus", type=str, required=False, default='0', 319 | help="specify a gpu ID. 0 as default") 320 | parser.add_argument("-r", "--region", type=str, required=False, default='R1', 321 | help="region_id to load data from. R1 as default") 322 | parser.add_argument("-a", "--use_all_region", type=bool, required=False, default=True, 323 | help="use all region") 324 | parser.add_argument("-m", "--mode", type=str, required=False, default='test', 325 | help="choose mode: train (default) / test") 326 | parser.add_argument("-ho", "--held-out", type=bool, required=False, default=False, 327 | help="are we using held-out dataset for the 'test'") 328 | 329 | return parser 330 | 331 | 332 | def add_main_args(parent_parser): 333 | parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False) 334 | parser.add_argument('--competition', default='stage-1', help='competition name', choices=['stage-1', 'ieee-bd']) 335 | parser.add_argument('--collapse-time', type=bool, default=False, help='collapse time axis') 336 | parser.add_argument('--use_static', type=bool, default=False, help='use static variable (Default: True)') 337 | parser.add_argument('--use_time_slot', type=bool, default=False, help='use time slots (Default: True)') 338 | parser.add_argument('--use_cloud_type', type=bool, default=False, 339 | help='use cloud type variables. [Only when all variables are used] (Default: False)') 340 | parser.add_argument('--use_all_variables', type=bool, default=False, 341 | help='use available variables for the variable types used (Default: False)') 342 | parser.add_argument('--populate_mask', type=bool, default=True, help='use mask to work only on unmasked data') 343 | 344 | parser.add_argument('--precision', type=int, default=32, help='precision to use for training', choices=[16, 32]) 345 | parser.add_argument('--epochs', type=int, default=100, help='number of epochs to train for') 346 | parser.add_argument('--batch-size', type=int, default=1, help='batch size') 347 | 348 | parser.add_argument('--manual-seed', default=0, type=int, help='manual global seed') 349 | parser.add_argument('--log-dir', default='logs', help='base directory to save logs') 350 | parser.add_argument('--model-dir', default='', help='base directory to save logs') 351 | parser.add_argument('--name', default='ALL_real_swinencoder3d_688080', 352 | help='identifier for model if already exist') 353 | parser.add_argument('--time-code', default='20210630T224355', 354 | help='identifier for model if already exist') 355 | parser.add_argument('--initial-epoch', type=int, default=58, 356 | help='number of epochs done (-1 == last)') 357 | parser.add_argument('--memory_efficient', type=bool, default=True, help='memory_efficient') 358 | 359 | return parser 360 | 361 | 362 | def get_time_code(): 363 | time_now = [f"{'0' if len(x) < 2 else ''}{x}" for x in np.array(time.localtime(), dtype=str)][:6] 364 | if os.path.exists('t.npy'): 365 | time_before = np.load('t.npy') # .astype(np.int) 366 | if abs(int(''.join(time_before)) - int(''.join(time_now))) < 70: 367 | time_now = time_before 368 | else: 369 | np.save('t.npy', time_now) 370 | else: 371 | np.save('t.npy', time_now) 372 | time_now = ''.join(time_now[:3]) + 'T' + ''.join(time_now[3:]) 373 | return time_now 374 | 375 | 376 | def main(): 377 | parser = argparse.ArgumentParser(description="Weather4Cast Arguments") 378 | parser = set_parser(parser) 379 | parser = add_main_args(parser) 380 | parser = Model.add_model_specific_args(parser) 381 | options = parser.parse_args() 382 | 383 | options.region = options.region.upper() 384 | options.workers = 6 385 | 386 | time_code = get_time_code() 387 | options.time_code = options.time_code or time_code # to account for resuming from a previous state 388 | 389 | train(options.region, options.mode, options) 390 | 391 | 392 | if __name__ == "__main__": 393 | main() 394 | """ examples of usage: 395 | 396 | - a.1) train from scratch 397 | python main.py --gpus 0 --region R1 398 | 399 | - a.2) fine tune a model from a checkpoint 400 | python main.py --gpu_id 1 --region R1 --mode train --name ALL_real_swinencoder3d_688080 --time-code 20210630T224355 --initial-epoch 58 401 | 402 | - b.1) evaluate an untrained model (with random weights) 403 | python main.py --gpus 0 --region R1 --mode test 404 | 405 | - b.2) evaluate a trained model from a checkpoint 406 | python main.py --gpu_id 1 --region R1 --mode test --name ALL_real_swinencoder3d_688080 --time-code 20210630T224355 --initial-epoch 58 407 | 408 | """ 409 | -------------------------------------------------------------------------------- /model_pl.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from torch import nn 4 | from torch.nn import functional as F 5 | import pytorch_lightning as pl 6 | import torch 7 | import os 8 | 9 | import argparse 10 | from argparse import ArgumentParser 11 | from models import optimizer_dict, HyperSwinEncoderDecoder3D 12 | from torch.optim.lr_scheduler import ReduceLROnPlateau 13 | import math 14 | import utils.data_utils as data_utils 15 | 16 | 17 | class Model(pl.LightningModule): 18 | def __init__(self, args, 19 | extra_data: str, depth: int, height: int, 20 | width: int, len_seq_in: int, len_seq_out: int, bins_to_predict: int, 21 | seq_mode: str, **kwargs): 22 | super().__init__() 23 | 24 | self.save_hyperparameters() 25 | 26 | self.args = args 27 | 28 | self.net = HyperSwinEncoderDecoder3D(args) 29 | self.extra_data = extra_data 30 | self.depth = depth 31 | self.height = height 32 | self.width = width 33 | self.len_seq_in = len_seq_in 34 | self.len_seq_out = len_seq_out 35 | self.bins_to_predict = bins_to_predict 36 | self.seq_mode = seq_mode 37 | 38 | self.loss_fn = F.mse_loss 39 | 40 | self.core_dir = '' 41 | self.transfer_dir = '' 42 | 43 | def forward(self, x, mask=None): 44 | return self.net(x) 45 | 46 | def _compute_loss(self, y_hat, y, agg=True, mask=None): 47 | if mask is not None: 48 | y_hat = y_hat.flatten()[~mask.flatten()] 49 | y = y.flatten()[~mask.flatten()] 50 | 51 | if agg: 52 | loss = self.loss_fn(y_hat, y) 53 | else: 54 | loss = self.loss_fn(y_hat, y, reduction='none') 55 | return loss 56 | 57 | @staticmethod 58 | def process_batch(batch): 59 | # in_seq, out_seq, metadata = batch 60 | return batch 61 | 62 | def training_step(self, batch, batch_idx, phase='train'): 63 | 64 | x, y, metadata = self.process_batch(batch) 65 | y_hat = self.forward(x) 66 | loss = self._compute_loss(y_hat, y, mask=metadata['out'].get('masks')) 67 | self.log(f'{phase}_loss', loss, on_epoch=True) 68 | return loss 69 | 70 | def create_inference_dirs(self): 71 | for region_id in ['R1', 'R2', 'R3']: 72 | os.makedirs(os.path.join(self.core_dir, region_id, 'test'), exist_ok=True) 73 | for region_id in ['R4', 'R5', 'R6']: 74 | os.makedirs(os.path.join(self.transfer_dir, region_id, 'test'), exist_ok=True) 75 | 76 | def on_validation_epoch_start(self): 77 | epoch_dir = os.path.join(self.args.versiondir, 'inference', f"epoch={self.current_epoch}") 78 | self.core_dir = os.path.join(epoch_dir, f'core_{self.current_epoch}') 79 | self.transfer_dir = os.path.join(epoch_dir, f'transfer_{self.current_epoch}') 80 | self.create_inference_dirs() 81 | 82 | def on_test_epoch_start(self): 83 | ckpt_name = str(os.path.basename(self.trainer.resume_from_checkpoint)) 84 | folder_name = ckpt_name.split('.')[0].split('-')[0].split('=')[-1] 85 | epoch_dir = os.path.join(self.args.versiondir, 'test', f"epoch={folder_name}") 86 | self.core_dir = os.path.join(epoch_dir, f'core_{folder_name}') 87 | self.transfer_dir = os.path.join(epoch_dir, f'transfer_{folder_name}') 88 | self.create_inference_dirs() 89 | 90 | def save_prediction(self, y_hat, metadata, batch_idx, loader_idx): 91 | y_hat = torch.reshape(y_hat, (-1, self.len_seq_out, len(self.args.target_vars), self.height, self.width)) 92 | y_hat = y_hat.data.cpu().numpy() 93 | 94 | for idx, (region_id, day_in_year) in enumerate(zip(metadata['out']['region_id'], 95 | metadata['out']['day_in_year'][0])): 96 | if region_id in ['R1', 'R2', 'R3']: # 'w4c-core-stage-1' 97 | save_path = os.path.join(self.core_dir, region_id, 'test', f"{day_in_year}.h5") 98 | else: # 'w4c-transfer-learning-stage-1' 99 | save_path = os.path.join(self.transfer_dir, region_id, 'test', f"{day_in_year}.h5") 100 | y = data_utils.postprocess_fn(y_hat[idx], self.args.target_vars, 101 | self.args.preprocess['source']) 102 | data_utils.write_data(y, save_path) 103 | 104 | def validation_step(self, batch, batch_idx, loader_idx, phase='val'): 105 | 106 | x, y, metadata = self.process_batch(batch) 107 | y_hat = self.forward(x) 108 | if loader_idx == 0: # for validation loader only 109 | loss = self._compute_loss(y_hat, y, mask=metadata['out'].get('masks')) 110 | self.log(f'{phase}_loss', loss, prog_bar=True) # , logger=True) 111 | else: # for prediction 112 | self.save_prediction(y_hat, metadata, batch_idx, loader_idx) 113 | 114 | def test_step(self, batch, batch_idx): # , phase='test'): 115 | 116 | x, y, metadata = self.process_batch(batch) 117 | y_hat = self.forward(x) 118 | self.save_prediction(y_hat, metadata, batch_idx, loader_idx=0) 119 | 120 | def configure_optimizers(self): 121 | other_args = {} 122 | print(self.args) 123 | if self.args.optimizer == 'sgd': 124 | other_args = {'lr': self.args.lr, 'momentum': self.args.momentum, 125 | 'weight_decay': self.args.weight_decay, 'nesterov': True} 126 | optimizer = optimizer_dict[self.args.optimizer](self.net.parameters(), **other_args) 127 | t_max = math.ceil(self.args.epochs * self.args.train_dims / 128 | (self.args.batch_size * len( 129 | self.args.gpus.split(',')))) # self.args.epochs * self.args.train_dims 130 | scheduler = {'scheduler': torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, t_max, eta_min=0), 131 | 'interval': 'step', # or 'epoch' 132 | } 133 | elif self.args.optimizer == 'adam': 134 | other_args = {'lr': self.args.lr, 'eps': self.args.epsilon, 135 | 'betas': (self.args.beta_1, self.args.beta_2), 136 | 'weight_decay': self.args.weight_decay} 137 | optimizer = optimizer_dict[self.args.optimizer](self.net.parameters(), **other_args) 138 | scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3, cooldown=0, min_lr=1e-7) 139 | elif self.args.optimizer == 'swats': 140 | other_args = {'lr': self.args.lr, 141 | 'weight_decay': self.args.weight_decay, 142 | 'nesterov': True 143 | } 144 | optimizer = optimizer_dict[self.args.optimizer](self.net.parameters(), **other_args) 145 | scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3, cooldown=0, min_lr=1e-7) 146 | 147 | return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "val_loss"} 148 | 149 | @staticmethod 150 | def add_model_specific_args(parent_parser): 151 | parser = ArgumentParser(parents=[parent_parser], add_help=False) 152 | parser.add_argument('--net_type', default='real', help='type of network', 153 | choices=['sedenion', 'real', 'complex', 'quaternion', 'octonion']) 154 | parser.add_argument('--patch_size', type=int, default=2, help='patch size to use in swin transfer') 155 | parser.add_argument('--nb_layers', type=int, default=4, help='depth of resnet blocks (default: 1)') 156 | parser.add_argument('--sf', type=int, default=16 * 1, 157 | help='number of feature maps/embedding dimension (default: 16*1)') 158 | parser.add_argument('--stages', type=int, default=3, 159 | help='number of encoder stages (<1 means infer) (default:0)') 160 | parser.add_argument('--classifier_activation', default='sigmoid', 161 | help='hidden layer activation (default: hardtanh)') # sigmoid? 162 | parser.add_argument('--modify_activation', type=bool, default=True, 163 | help='modify the range of hardtanh activation') 164 | parser.add_argument('--inplace_activation', type=bool, default=True, help='inplace activation') 165 | parser.add_argument('--dropout', type=float, default=0.0, help='dropout probability') 166 | 167 | parser.add_argument('--optimizer', default='adam', help='optimizer to train with', 168 | choices=['sgd', 'adam', 'swats']) 169 | parser.add_argument('--lr', default=1e-4, type=float, help='learning rate') 170 | parser.add_argument('--momentum', default=0.9, type=float, help='momentum term for sgd') 171 | parser.add_argument('--beta_1', default=0.9, type=float, help='beta_1 term for adam') 172 | parser.add_argument('--beta_2', default=0.999, type=float, help='beta_2 term for adam') 173 | parser.add_argument('--epsilon', default=1e-8, type=float, help='epsilon term for adam') 174 | parser.add_argument('--weight_decay', default=1e-6, type=float, 175 | help='weight decay for regularization (default: 1e-6)') 176 | 177 | return parser 178 | 179 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | # coding:utf-8 2 | ########################################################## 3 | # pytorch v1.0 4 | # Alabi Bojesomo 5 | # Khalifa University 6 | # Abu Dhabi, UAE 7 | # April 2021 8 | ########################################################## 9 | 10 | import os 11 | import sys 12 | sys.path.extend(os.getcwd()) 13 | 14 | from torch import optim 15 | import numpy as np 16 | import warnings 17 | from swin_transformer3d import SwinEncoderDecoderTransformer3D 18 | import torch_optimizer as extra_optim 19 | from functools import partial 20 | 21 | optimizer_dict = {'adadelta': optim.Adadelta, 22 | 'adagrad': optim.Adagrad, 23 | 'adam': optim.Adam, 24 | 'adamw': optim.AdamW, 25 | 'swats': extra_optim.SWATS, 26 | 'sparse_adam': optim.SparseAdam, 27 | 'adamax': optim.Adamax, 28 | 'asgd': optim.ASGD, 29 | 'sgd': optim.SGD, 30 | 'rprop': optim.Rprop, 31 | 'rmsprop': optim.RMSprop, 32 | 'lbfgs': optim.LBFGS} 33 | n_div_dict = {'sedenion': 16, 34 | 'octonion': 8, 35 | 'quaternion': 4, 36 | 'complex': 2, 37 | 'real': 1} 38 | 39 | 40 | class HyperSwinEncoderDecoder3D(SwinEncoderDecoderTransformer3D): 41 | def __init__(self, args): 42 | if hasattr(args, 'n_divs'): 43 | n_divs = args.n_divs 44 | else: 45 | n_divs = n_div_dict[args.net_type.lower()] 46 | 47 | heads_ = 8 48 | n_multiples_in = int(n_divs * np.ceil(args.len_seq_in / n_divs)) 49 | embed_dim = int(n_divs * heads_ * np.ceil(n_multiples_in / (n_divs * heads_))) 50 | if args.sf < embed_dim: 51 | warnings.warn(f"args.sf = {args.sf} < embed_dim used [{embed_dim}]") 52 | if args.sf > embed_dim: 53 | embed_dim = int(embed_dim * np.ceil(args.sf / embed_dim)) 54 | 55 | super().__init__(depths=tuple([args.nb_layers] * args.stages), 56 | num_heads=tuple([8] * args.stages), 57 | out_chans=args.len_seq_out, 58 | in_chans=args.len_seq_in, 59 | embed_dim=embed_dim, # args.sf, 60 | img_size=(args.height, args.width), 61 | in_depth=args.depth, out_depth=len(args.target_vars), 62 | n_divs=n_divs, 63 | drop_rate=args.dropout, 64 | patch_size=(1, *([args.patch_size] * 2)) 65 | ) 66 | -------------------------------------------------------------------------------- /parametrize.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn.modules.container import ModuleList, ModuleDict, Module 3 | from torch.nn.parameter import Parameter 4 | from torch import Tensor 5 | from typing import Union, Optional, Iterable, Dict, Tuple 6 | from contextlib import contextmanager 7 | 8 | 9 | _cache_enabled = 0 10 | _cache: Dict[Tuple[int, str], Optional[Tensor]] = {} 11 | 12 | 13 | @contextmanager 14 | def cached(): 15 | r"""Context manager that enables the caching system within parametrizations 16 | registered with :func:`register_parametrization`. 17 | The value of the parametrized objects is computed and cached the first time 18 | they are required when this context manager is active. The cached values are 19 | discarded when leaving the context manager. 20 | This is useful when using a parametrized parameter more than once in the forward pass. 21 | An example of this is when parametrizing the recurrent kernel of an RNN or when 22 | sharing weights. 23 | The simplest way to activate the cache is by wrapping the forward pass of the neural network 24 | .. code-block:: python 25 | import torch.nn.utils.parametrize as P 26 | ... 27 | with P.cached(): 28 | output = model(inputs) 29 | in training and evaluation. One may also wrap the parts of the modules that use 30 | several times the parametrized tensors. For example, the loop of an RNN with a 31 | parametrized recurrent kernel: 32 | .. code-block:: python 33 | with P.cached(): 34 | for x in xs: 35 | out_rnn = self.rnn_cell(x, out_rnn) 36 | """ 37 | global _cache 38 | global _cache_enabled 39 | _cache_enabled += 1 40 | try: 41 | yield 42 | finally: 43 | _cache_enabled -= 1 44 | if not _cache_enabled: 45 | _cache = {} 46 | 47 | 48 | class ParametrizationList(ModuleList): 49 | r"""A sequential container that holds and manages the ``original`` parameter or buffer of 50 | a parametrized :class:`torch.nn.Module`. It is the type of 51 | ``module.parametrizations[tensor_name]`` when ``module[tensor_name]`` has been parametrized 52 | with :func:`register_parametrization`. 53 | .. note :: 54 | This class is used internally by :func:`register_parametrization`. It is documented 55 | here for completeness. It should not be instantiated by the user. 56 | Args: 57 | modules (iterable): an iterable of modules representing the parametrizations 58 | original (Parameter or Tensor): parameter or buffer that is parametrized 59 | """ 60 | original: Tensor 61 | 62 | def __init__( 63 | self, modules: Iterable[Module], original: Union[Tensor, Parameter] 64 | ) -> None: 65 | super().__init__(modules) 66 | if isinstance(original, Parameter): 67 | self.register_parameter("original", original) 68 | else: 69 | self.register_buffer("original", original) 70 | 71 | def set_original_(self, value: Tensor) -> None: 72 | r"""This method is called when assigning to a parametrized tensor. 73 | It calls the methods ``right_inverse`` (see :func:`register_parametrization`) 74 | of the parametrizations in the inverse order that they have been registered. 75 | Then, it assigns the result to ``self.original``. 76 | Args: 77 | value (Tensor): Value to which initialize the module 78 | Raises: 79 | RuntimeError: if any of the parametrizations do not implement a ``right_inverse`` method 80 | """ 81 | with torch.no_grad(): 82 | # See https://github.com/pytorch/pytorch/issues/53103 83 | for module in reversed(self): # type: ignore[call-overload] 84 | if hasattr(module, "right_inverse"): 85 | value = module.right_inverse(value) 86 | else: 87 | raise RuntimeError( 88 | "The parametrization '{}' does not implement a 'right_inverse' method. " 89 | "Assigning to a parametrized tensor is only possible when all the parametrizations " 90 | "implement a 'right_inverse' method.".format(module.__class__.__name__) 91 | ) 92 | self.original.copy_(value) 93 | 94 | def forward(self) -> Tensor: 95 | x = self.original 96 | for module in self: 97 | x = module(x) 98 | # if x.size() != self.original.size(): 99 | # raise RuntimeError( 100 | # "The parametrization may not change the size of the parametrized tensor. " 101 | # "Size of original tensor: {} " 102 | # "Size of parametrized tensor: {}".format(self.original.size(), x.size()) 103 | # ) 104 | return x 105 | 106 | 107 | def _inject_new_class(module: Module) -> None: 108 | r"""Sets up the parametrization mechanism used by parametrizations. 109 | This works by substituting the class of the module by a class 110 | that extends it to be able to inject a property 111 | Args: 112 | module (nn.Module): module into which to inject the property 113 | """ 114 | cls = module.__class__ 115 | 116 | def getstate(self): 117 | raise RuntimeError( 118 | "Serialization of parametrized modules is only " 119 | "supported through state_dict(). See:\n" 120 | "https://pytorch.org/tutorials/beginner/saving_loading_models.html" 121 | "#saving-loading-a-general-checkpoint-for-inference-and-or-resuming-training" 122 | ) 123 | 124 | param_cls = type( 125 | "Parametrized{}".format(cls.__name__), 126 | (cls,), 127 | { 128 | "__getstate__": getstate, 129 | }, 130 | ) 131 | 132 | module.__class__ = param_cls 133 | 134 | 135 | def _inject_property(module: Module, tensor_name: str) -> None: 136 | r"""Injects a property into module[tensor_name]. 137 | It assumes that the class in the module has already been modified from its 138 | original one using _inject_new_class and that the tensor under :attr:`tensor_name` 139 | has already been moved out 140 | Args: 141 | module (nn.Module): module into which to inject the property 142 | tensor_name (str): name of the name of the property to create 143 | """ 144 | # We check the precondition. 145 | # This should never fire if register_parametrization is correctly implemented 146 | assert not hasattr(module, tensor_name) 147 | 148 | def get_parametrized(self) -> Tensor: 149 | global _cache 150 | 151 | parametrization = self.parametrizations[tensor_name] 152 | if _cache_enabled: 153 | key = (id(module), tensor_name) 154 | tensor = _cache.get(key) 155 | if tensor is None: 156 | tensor = parametrization() 157 | _cache[key] = tensor 158 | return tensor 159 | else: 160 | # If caching is not active, this function just evaluates the parametrization 161 | return parametrization() 162 | 163 | def set_original(self, value: Tensor) -> None: 164 | self.parametrizations[tensor_name].set_original_(value) 165 | 166 | setattr(module.__class__, tensor_name, property(get_parametrized, set_original)) 167 | 168 | 169 | def register_parametrization( 170 | module: Module, tensor_name: str, parametrization: Module 171 | ) -> Module: 172 | r"""Adds a parametrization to a tensor in a module. 173 | Assume that ``tensor_name="weight"`` for simplicity. When accessing ``module.weight``, 174 | the module will return the parametrized version ``parametrization(module.weight)``. 175 | If the original tensor requires a gradient, the backward pass will differentiate 176 | through the :attr:`parametrization`, and the optimizer will update the tensor accordingly. 177 | The first time that a module registers a parametrization, this function will add an attribute 178 | ``parametrizations`` to the module of type :class:`~ParametrizationList`. 179 | The list of parametrizations on a tensor will be accessible under 180 | ``module.parametrizations.weight``. 181 | The original tensor will be accessible under 182 | ``module.parametrizations.weight.original``. 183 | Parametrizations may be concatenated by registering several parametrizations 184 | on the same attribute. 185 | The training mode of the registered parametrizations are updated on registration 186 | if necessary to match the training mode of the host module 187 | Parametrized parameters and buffers have an inbuilt caching system that can be activated 188 | using the context manager :func:`cached`. 189 | A :attr:`parametrization` may optionally implement a method with signature 190 | .. code-block:: python 191 | def right_inverse(self, X: Tensor) -> Tensor 192 | If :attr:`parametrization` implements this method, it will be possible to assign 193 | to the parametrized tensor. This may be used to initialize the tensor, as shown in the example. 194 | In most situations, ``right_inverse`` will be a function such that 195 | ``forward(right_inverse(X)) == X`` (see 196 | `right inverse `_). 197 | Sometimes, when the parametrization is not surjective, it may be reasonable 198 | to relax this, as shown in the example below. 199 | Args: 200 | module (nn.Module): module on which to register the parametrization 201 | tensor_name (str): name of the parameter or buffer on which to register 202 | the parametrization 203 | parametrization (nn.Module): the parametrization to register 204 | Returns: 205 | Module: module 206 | Raises: 207 | ValueError: if the module does not have a parameter or a buffer named :attr:`tensor_name` 208 | Examples: 209 | >>> import torch 210 | >>> import torch.nn.utils.parametrize as P 211 | >>> 212 | >>> class Symmetric(torch.nn.Module): 213 | >>> def forward(self, X): 214 | >>> return X.triu() + X.triu(1).T # Return a symmetric matrix 215 | >>> 216 | >>> def right_inverse(self, A): 217 | >>> return A.triu() 218 | >>> 219 | >>> m = torch.nn.Linear(5, 5) 220 | >>> P.register_parametrization(m, "weight", Symmetric()) 221 | >>> print(torch.allclose(m.weight, m.weight.T)) # m.weight is now symmetric 222 | True 223 | >>> A = torch.rand(5, 5) 224 | >>> A = A + A.T # A is now symmetric 225 | >>> m.weight = A # Initialize the weight to be the symmetric matrix A 226 | >>> print(torch.allclose(m.weight, A)) 227 | True 228 | """ 229 | parametrization.train(module.training) 230 | if is_parametrized(module, tensor_name): 231 | # Just add the new parametrization to the parametrization list 232 | module.parametrizations[tensor_name].append(parametrization) # type: ignore[index, union-attr] 233 | elif tensor_name in module._buffers or tensor_name in module._parameters: 234 | # Set the parametrization mechanism 235 | # Fetch the original buffer or parameter 236 | original = getattr(module, tensor_name) 237 | # Delete the previous parameter or buffer 238 | delattr(module, tensor_name) 239 | # If this is the first parametrization registered on the module, 240 | # we prepare the module to inject the property 241 | if not is_parametrized(module): 242 | # Change the class 243 | _inject_new_class(module) 244 | # Inject the a ``ModuleDict`` into the instance under module.parametrizations 245 | module.parametrizations = ModuleDict() 246 | # Add a property into the class 247 | _inject_property(module, tensor_name) 248 | # Add a ParametrizationList 249 | module.parametrizations[tensor_name] = ParametrizationList( # type: ignore[assignment, index, operator] 250 | [parametrization], original 251 | ) 252 | else: 253 | raise ValueError( 254 | "Module '{}' does not have a parameter, a buffer, or a " 255 | "parametrized element with name '{}'".format(module, tensor_name) 256 | ) 257 | return module 258 | 259 | 260 | def is_parametrized(module: Module, tensor_name: Optional[str] = None) -> bool: 261 | r"""Returns ``True`` if module has an active parametrization. 262 | If the argument :attr:`tensor_name` is specified, returns ``True`` if 263 | ``module[tensor_name]`` is parametrized. 264 | Args: 265 | module (nn.Module): module to query 266 | name (str, optional): attribute in the module to query 267 | Default: ``None`` 268 | """ 269 | parametrizations = getattr(module, "parametrizations", None) 270 | if parametrizations is None or not isinstance(parametrizations, ModuleDict): 271 | return False 272 | if tensor_name is None: 273 | # Check that there is at least one parametrized buffer or Parameter 274 | return len(parametrizations) > 0 275 | else: 276 | return tensor_name in parametrizations 277 | 278 | 279 | def remove_parametrizations( 280 | module: Module, tensor_name: str, leave_parametrized: bool = True 281 | ) -> Module: 282 | r"""Removes the parametrizations on a tensor in a module. 283 | - If ``leave_parametrized=True``, ``module[tensor_name]`` will be set to 284 | its current output. In this case, the parametrization shall not change the ``dtype`` 285 | of the tensor. 286 | - If ``leave_parametrized=False``, ``module[tensor_name]`` will be set to 287 | the unparametrised tensor in ``module.parametrizations[tensor_name].original``. 288 | Args: 289 | module (nn.Module): module from which remove the parametrization 290 | tensor_name (str): name of the parametrization to be removed 291 | leave_parametrized (bool, optional): leave the attribute :attr:`tensor_name` parametrized. 292 | Default: ``True`` 293 | Returns: 294 | Module: module 295 | Raises: 296 | ValueError: if ``module[tensor_name]`` is not parametrized 297 | ValueError: if ``leave_parametrized=True`` and the parametrization changes the size or dtype 298 | of the tensor 299 | """ 300 | 301 | if not is_parametrized(module, tensor_name): 302 | raise ValueError( 303 | "Module {} does not have a parametrization on {}".format( 304 | module, tensor_name 305 | ) 306 | ) 307 | 308 | # Fetch the original tensor 309 | original = module.parametrizations[tensor_name].original # type: ignore[index, union-attr] 310 | if leave_parametrized: 311 | with torch.no_grad(): 312 | t = getattr(module, tensor_name) 313 | # If they have the same dtype, we reuse the original tensor. 314 | # We do this so that the parameter does not to change the id() 315 | # This way the user does not need to update the optimizer 316 | if t.dtype == original.dtype: 317 | with torch.no_grad(): 318 | original.set_(t) 319 | else: 320 | raise ValueError( 321 | "The parametrization changes the dtype of the tensor from {} to {}. " 322 | "It is not supported to leave the tensor parametrized (`leave_parametrized=True`) " 323 | "in this case.".format(original.dtype, t.dtype) 324 | ) 325 | # Delete the property that manages the parametrization 326 | delattr(module.__class__, tensor_name) 327 | # Delete the ParametrizationList 328 | del module.parametrizations[tensor_name] # type: ignore[operator, union-attr] 329 | 330 | # Restore the parameter / buffer into the main class 331 | if isinstance(original, Parameter): 332 | module.register_parameter(tensor_name, original) 333 | else: 334 | module.register_buffer(tensor_name, original) 335 | 336 | # Roll back the parametrized class if no other buffer or parameter 337 | # is currently parametrized in this class 338 | if not is_parametrized(module): 339 | delattr(module, "parametrizations") 340 | # Restore class 341 | orig_cls = module.__class__.__bases__[0] 342 | module.__class__ = orig_cls 343 | return module -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | timm 2 | torch-optimizer 3 | einops 4 | pytorch-model-summary 5 | pytorch-lightning 6 | -------------------------------------------------------------------------------- /splits.csv: -------------------------------------------------------------------------------- 1 | ,id_date,split_id,split,id_day,date 2 | 0,2019046,0,training,046,20190215 3 | 1,2019047,2,test,047,20190216 4 | 2,2019048,2,test-next,048,20190217 5 | 3,2019049,0,training,049,20190218 6 | 4,2019050,0,training,050,20190219 7 | 5,2019051,0,training,051,20190220 8 | 6,2019052,0,training,052,20190221 9 | 7,2019053,0,training,053,20190222 10 | 8,2019054,0,training,054,20190223 11 | 9,2019055,0,training,055,20190224 12 | 10,2019056,0,training,056,20190225 13 | 11,2019057,0,training,057,20190226 14 | 12,2019058,0,training,058,20190227 15 | 13,2019059,0,training,059,20190228 16 | 14,2019060,0,training,060,20190301 17 | 15,2019061,0,training,061,20190302 18 | 16,2019062,0,training,062,20190303 19 | 17,2019063,0,training,063,20190304 20 | 18,2019064,0,training,064,20190305 21 | 19,2019065,0,training,065,20190306 22 | 20,2019066,0,training,066,20190307 23 | 21,2019067,0,training,067,20190308 24 | 22,2019068,0,training,068,20190309 25 | 23,2019069,0,training,069,20190310 26 | 24,2019070,0,training,070,20190311 27 | 25,2019071,0,training,071,20190312 28 | 26,2019072,0,training,072,20190313 29 | 27,2019073,2,test,073,20190314 30 | 28,2019074,2,test-next,074,20190315 31 | 29,2019075,0,training,075,20190316 32 | 30,2019076,1,validation,076,20190317 33 | 31,2019077,0,training,077,20190318 34 | 32,2019078,0,training,078,20190319 35 | 33,2019079,0,training,079,20190320 36 | 34,2019080,0,training,080,20190321 37 | 35,2019081,0,training,081,20190322 38 | 36,2019082,2,test,082,20190323 39 | 37,2019083,0,training,083,20190324 40 | 38,2019084,0,training,084,20190325 41 | 39,2019085,0,training,085,20190326 42 | 40,2019086,0,training,086,20190327 43 | 41,2019087,0,training,087,20190328 44 | 42,2019088,0,training,088,20190329 45 | 43,2019089,1,validation,089,20190330 46 | 44,2019090,0,training,090,20190331 47 | 45,2019091,0,training,091,20190401 48 | 46,2019092,2,test,092,20190402 49 | 47,2019093,0,training,093,20190403 50 | 48,2019094,0,training,094,20190404 51 | 49,2019095,0,training,095,20190405 52 | 50,2019096,0,training,096,20190406 53 | 51,2019097,0,training,097,20190407 54 | 52,2019098,2,test,098,20190408 55 | 53,2019099,0,training,099,20190409 56 | 54,2019100,0,training,100,20190410 57 | 55,2019101,1,validation,101,20190411 58 | 56,2019102,1,validation,102,20190412 59 | 57,2019103,0,training,103,20190413 60 | 58,2019104,2,test,104,20190414 61 | 59,2019105,2,test-next,105,20190415 62 | 60,2019106,0,training,106,20190416 63 | 61,2019107,0,training,107,20190417 64 | 62,2019108,0,training,108,20190418 65 | 63,2019109,0,training,109,20190419 66 | 64,2019110,2,test,110,20190420 67 | 65,2019111,0,training,111,20190421 68 | 66,2019112,0,training,112,20190422 69 | 67,2019113,0,training,113,20190423 70 | 68,2019114,2,test,114,20190424 71 | 69,2019115,0,training,115,20190425 72 | 70,2019116,0,training,116,20190426 73 | 71,2019117,0,training,117,20190427 74 | 72,2019118,0,training,118,20190428 75 | 73,2019119,0,training,119,20190429 76 | 74,2019120,1,validation,120,20190430 77 | 75,2019121,0,training,121,20190501 78 | 76,2019122,0,training,122,20190502 79 | 77,2019123,0,training,123,20190503 80 | 78,2019124,0,training,124,20190504 81 | 79,2019125,0,training,125,20190505 82 | 80,2019126,1,validation,126,20190506 83 | 81,2019127,0,training,127,20190507 84 | 82,2019128,0,training,128,20190508 85 | 83,2019129,0,training,129,20190509 86 | 84,2019130,0,training,130,20190510 87 | 85,2019131,0,training,131,20190511 88 | 86,2019132,0,training,132,20190512 89 | 87,2019133,0,training,133,20190513 90 | 88,2019134,0,training,134,20190514 91 | 89,2019135,0,training,135,20190515 92 | 90,2019136,0,training,136,20190516 93 | 91,2019137,0,training,137,20190517 94 | 92,2019138,0,training,138,20190518 95 | 93,2019139,1,validation,139,20190519 96 | 94,2019140,0,training,140,20190520 97 | 95,2019141,0,training,141,20190521 98 | 96,2019142,0,training,142,20190522 99 | 97,2019143,0,training,143,20190523 100 | 98,2019144,0,training,144,20190524 101 | 99,2019145,2,test,145,20190525 102 | 100,2019146,0,training,146,20190526 103 | 101,2019147,2,test,147,20190527 104 | 102,2019148,2,test-next,148,20190528 105 | 103,2019149,0,training,149,20190529 106 | 104,2019150,1,validation,150,20190530 107 | 105,2019151,0,training,151,20190531 108 | 106,2019152,0,training,152,20190601 109 | 107,2019153,0,training,153,20190602 110 | 108,2019154,0,training,154,20190603 111 | 109,2019155,0,training,155,20190604 112 | 110,2019156,0,training,156,20190605 113 | 111,2019157,0,training,157,20190606 114 | 112,2019158,0,training,158,20190607 115 | 113,2019159,0,training,159,20190608 116 | 114,2019160,0,training,160,20190609 117 | 115,2019161,0,training,161,20190610 118 | 116,2019162,0,training,162,20190611 119 | 117,2019163,1,validation,163,20190612 120 | 118,2019164,0,training,164,20190613 121 | 119,2019165,0,training,165,20190614 122 | 120,2019166,0,training,166,20190615 123 | 121,2019167,1,validation,167,20190616 124 | 122,2019168,0,training,168,20190617 125 | 123,2019169,0,training,169,20190618 126 | 124,2019170,0,training,170,20190619 127 | 125,2019171,0,training,171,20190620 128 | 126,2019172,0,training,172,20190621 129 | 127,2019173,0,training,173,20190622 130 | 128,2019174,2,test,174,20190623 131 | 129,2019175,0,training,175,20190624 132 | 130,2019176,0,training,176,20190625 133 | 131,2019177,0,training,177,20190626 134 | 132,2019178,2,test,178,20190627 135 | 133,2019179,0,training,179,20190628 136 | 134,2019180,0,training,180,20190629 137 | 135,2019181,0,training,181,20190630 138 | 136,2019182,0,training,182,20190701 139 | 137,2019183,0,training,183,20190702 140 | 138,2019184,0,training,184,20190703 141 | 139,2019185,0,training,185,20190704 142 | 140,2019186,0,training,186,20190705 143 | 141,2019187,0,training,187,20190706 144 | 142,2019188,2,test,188,20190707 145 | 143,2019189,1,validation,189,20190708 146 | 144,2019190,0,training,190,20190709 147 | 145,2019191,0,training,191,20190710 148 | 146,2019192,0,training,192,20190711 149 | 147,2019193,1,validation,193,20190712 150 | 148,2019194,1,validation,194,20190713 151 | 149,2019195,0,training,195,20190714 152 | 150,2019196,0,training,196,20190715 153 | 151,2019197,0,training,197,20190716 154 | 152,2019198,0,training,198,20190717 155 | 153,2019199,0,training,199,20190718 156 | 154,2019200,0,training,200,20190719 157 | 155,2019201,0,training,201,20190720 158 | 156,2019202,0,training,202,20190721 159 | 157,2019203,0,training,203,20190722 160 | 158,2019204,0,training,204,20190723 161 | 159,2019205,1,validation,205,20190724 162 | 160,2019206,0,training,206,20190725 163 | 161,2019207,0,training,207,20190726 164 | 162,2019208,0,training,208,20190727 165 | 163,2019209,2,test,209,20190728 166 | 164,2019210,2,test-next,210,20190729 167 | 165,2019211,0,training,211,20190730 168 | 166,2019212,0,training,212,20190731 169 | 167,2019213,1,validation,213,20190801 170 | 168,2019214,0,training,214,20190802 171 | 169,2019215,0,training,215,20190803 172 | 170,2019216,0,training,216,20190804 173 | 171,2019217,0,training,217,20190805 174 | 172,2019218,0,training,218,20190806 175 | 173,2019219,0,training,219,20190807 176 | 174,2019220,0,training,220,20190808 177 | 175,2019221,1,validation,221,20190809 178 | 176,2019222,0,training,222,20190810 179 | 177,2019223,0,training,223,20190811 180 | 178,2019224,0,training,224,20190812 181 | 179,2019225,0,training,225,20190813 182 | 180,2019226,0,training,226,20190814 183 | 181,2019227,0,training,227,20190815 184 | 182,2019228,0,training,228,20190816 185 | 183,2019229,0,training,229,20190817 186 | 184,2019230,0,training,230,20190818 187 | 185,2019231,0,training,231,20190819 188 | 186,2019232,0,training,232,20190820 189 | 187,2019233,1,validation,233,20190821 190 | 188,2019234,0,training,234,20190822 191 | 189,2019235,0,training,235,20190823 192 | 190,2019236,2,test,236,20190824 193 | 191,2019237,2,test-next,237,20190825 194 | 192,2019238,1,validation,238,20190826 195 | 193,2019239,0,training,239,20190827 196 | 194,2019240,1,validation,240,20190828 197 | 195,2019241,0,training,241,20190829 198 | 196,2019242,0,training,242,20190830 199 | 197,2019243,2,test,243,20190831 200 | 198,2019244,1,validation,244,20190901 201 | 199,2019245,0,training,245,20190902 202 | 200,2019246,2,test,246,20190903 203 | 201,2019247,2,test-next,247,20190904 204 | 202,2019248,0,training,248,20190905 205 | 203,2019249,0,training,249,20190906 206 | 204,2019250,0,training,250,20190907 207 | 205,2019251,0,training,251,20190908 208 | 206,2019252,0,training,252,20190909 209 | 207,2019253,0,training,253,20190910 210 | 208,2019254,0,training,254,20190911 211 | 209,2019255,0,training,255,20190912 212 | 210,2019256,1,validation,256,20190913 213 | 211,2019257,0,training,257,20190914 214 | 212,2019258,1,validation,258,20190915 215 | 213,2019259,1,validation,259,20190916 216 | 214,2019260,0,training,260,20190917 217 | 215,2019261,0,training,261,20190918 218 | 216,2019262,0,training,262,20190919 219 | 217,2019263,0,training,263,20190920 220 | 218,2019264,0,training,264,20190921 221 | 219,2019265,0,training,265,20190922 222 | 220,2019266,0,training,266,20190923 223 | 221,2019267,0,training,267,20190924 224 | 222,2019268,0,training,268,20190925 225 | 223,2019269,2,test,269,20190926 226 | 224,2019270,2,test-next,270,20190927 227 | 225,2019271,0,training,271,20190928 228 | 226,2019272,0,training,272,20190929 229 | 227,2019273,2,test,273,20190930 230 | 228,2019274,0,training,274,20191001 231 | 229,2019275,0,training,275,20191002 232 | 230,2019276,0,training,276,20191003 233 | 231,2019277,0,training,277,20191004 234 | 232,2019278,1,validation,278,20191005 235 | 233,2019279,1,validation,279,20191006 236 | 234,2019280,0,training,280,20191007 237 | 235,2019281,0,training,281,20191008 238 | 236,2019282,0,training,282,20191009 239 | 237,2019283,1,validation,283,20191010 240 | 238,2019284,0,training,284,20191011 241 | 239,2019285,2,test,285,20191012 242 | 240,2019286,2,test-next,286,20191013 243 | 241,2019287,1,validation,287,20191014 244 | 242,2019288,0,training,288,20191015 245 | 243,2019289,0,training,289,20191016 246 | 244,2019290,0,training,290,20191017 247 | 245,2019291,0,training,291,20191018 248 | 246,2019292,0,training,292,20191019 249 | 247,2019293,0,training,293,20191020 250 | 248,2019294,0,training,294,20191021 251 | 249,2019295,0,training,295,20191022 252 | 250,2019296,1,validation,296,20191023 253 | 251,2019297,0,training,297,20191024 254 | 252,2019298,0,training,298,20191025 255 | 253,2019299,1,validation,299,20191026 256 | 254,2019300,0,training,300,20191027 257 | 255,2019301,0,training,301,20191028 258 | 256,2019302,0,training,302,20191029 259 | 257,2019303,0,training,303,20191030 260 | 258,2019304,0,training,304,20191031 261 | 259,2019305,0,training,305,20191101 262 | 260,2019306,2,test,306,20191102 263 | 261,2019307,2,test-next,307,20191103 264 | 262,2019308,0,training,308,20191104 265 | 263,2019309,0,training,309,20191105 266 | 264,2019310,2,test,310,20191106 267 | 265,2019311,0,training,311,20191107 268 | 266,2019312,0,training,312,20191108 269 | 267,2019313,0,training,313,20191109 270 | 268,2019314,0,training,314,20191110 271 | 269,2019315,0,training,315,20191111 272 | 270,2019316,0,training,316,20191112 273 | 271,2019317,0,training,317,20191113 274 | 272,2019318,0,training,318,20191114 275 | 273,2019319,2,test,319,20191115 276 | 274,2019320,0,training,320,20191116 277 | 275,2019321,0,training,321,20191117 278 | 276,2019322,0,training,322,20191118 279 | 277,2019323,1,validation,323,20191119 280 | 278,2019324,2,test,324,20191120 281 | 279,2019325,0,training,325,20191121 282 | 280,2019326,0,training,326,20191122 283 | 281,2019327,0,training,327,20191123 284 | 282,2019328,0,training,328,20191124 285 | 283,2019329,0,training,329,20191125 286 | 284,2019330,0,training,330,20191126 287 | 285,2019331,0,training,331,20191127 288 | 286,2019332,0,training,332,20191128 289 | 287,2019333,0,training,333,20191129 290 | 288,2019334,0,training,334,20191130 291 | 289,2019335,0,training,335,20191201 292 | 290,2019336,1,validation,336,20191202 293 | 291,2019337,0,training,337,20191203 294 | 292,2019338,1,validation,338,20191204 295 | 293,2019339,0,training,339,20191205 296 | 294,2019340,0,training,340,20191206 297 | 295,2019341,0,training,341,20191207 298 | 296,2019342,0,training,342,20191208 299 | 297,2019343,0,training,343,20191209 300 | 298,2019344,0,training,344,20191210 301 | 299,2019345,0,training,345,20191211 302 | 300,2019346,0,training,346,20191212 303 | 301,2019347,0,training,347,20191213 304 | 302,2019348,0,training,348,20191214 305 | 303,2019349,2,test,349,20191215 306 | 304,2019350,0,training,350,20191216 307 | 305,2019351,0,training,351,20191217 308 | 306,2019352,0,training,352,20191218 309 | 307,2019353,2,test,353,20191219 310 | 308,2019354,0,training,354,20191220 311 | 309,2019355,0,training,355,20191221 312 | 310,2019356,0,training,356,20191222 313 | 311,2019357,0,training,357,20191223 314 | 312,2019358,0,training,358,20191224 315 | 313,2019359,1,validation,359,20191225 316 | 314,2019360,0,training,360,20191226 317 | 315,2019361,0,training,361,20191227 318 | 316,2019362,0,training,362,20191228 319 | 317,2019363,2,test,363,20191229 320 | 318,2019364,2,test-next,364,20191230 321 | 319,2019365,2,test,365,20191231 322 | 320,2020001,0,training,001,20200101 323 | 321,2020002,0,training,002,20200102 324 | 322,2020003,0,training,003,20200103 325 | 323,2020004,0,training,004,20200104 326 | 324,2020005,0,training,005,20200105 327 | 325,2020006,1,validation,006,20200106 328 | 326,2020007,0,training,007,20200107 329 | 327,2020008,0,training,008,20200108 330 | 328,2020009,2,test,009,20200109 331 | 329,2020010,0,training,010,20200110 332 | 330,2020011,2,test,011,20200111 333 | 331,2020012,0,training,012,20200112 334 | 332,2020013,0,training,013,20200113 335 | 333,2020014,0,training,014,20200114 336 | 334,2020015,0,training,015,20200115 337 | 335,2020016,0,training,016,20200116 338 | 336,2020017,0,training,017,20200117 339 | 337,2020018,0,training,018,20200118 340 | 338,2020019,0,training,019,20200119 341 | 339,2020020,2,test,020,20200120 342 | 340,2020021,0,training,021,20200121 343 | 341,2020022,0,training,022,20200122 344 | 342,2020023,2,test,023,20200123 345 | 343,2020024,0,training,024,20200124 346 | 344,2020025,0,training,025,20200125 347 | 345,2020026,0,training,026,20200126 348 | 346,2020027,0,training,027,20200127 349 | 347,2020028,0,training,028,20200128 350 | 348,2020029,0,training,029,20200129 351 | 349,2020030,2,test,030,20200130 352 | 350,2020031,0,training,031,20200131 353 | 351,2020032,0,training,032,20200201 354 | 352,2020033,0,training,033,20200202 355 | 353,2020034,2,test,034,20200203 356 | 354,2020035,2,test-next,035,20200204 357 | 355,2020036,0,training,036,20200205 358 | 356,2020037,0,training,037,20200206 359 | 357,2020038,2,test,038,20200207 360 | 358,2020039,2,test-next,039,20200208 361 | 359,2020040,0,training,040,20200209 362 | 360,2020041,0,training,041,20200210 363 | 361,2020042,2,test,042,20200211 364 | 362,2020043,0,training,043,20200212 365 | 363,2020044,0,training,044,20200213 366 | 364,2020045,0,training,045,20200214 -------------------------------------------------------------------------------- /t.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bojesomo/Weather4cast2021-SwinEncoderDecoder/8c158890628be2a28d47aa082fd96eb43c6cbb4a/t.npy -------------------------------------------------------------------------------- /train_utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | from pytorch_model_summary import summary 3 | 4 | 5 | def model_summary(model, inputs, print_summary=False, max_depth=1, show_parent_layers=False): 6 | # _ = summary(model, x_in, print_summary=True) 7 | kwargs = {'max_depth': max_depth, 8 | 'show_parent_layers': show_parent_layers} 9 | sT = summary(model, inputs, show_input=True, print_summary=False, **kwargs) 10 | sF = summary(model, inputs, show_input=False, print_summary=False, **kwargs) 11 | 12 | st = sT.split('\n') 13 | sf = sF.split('\n') 14 | 15 | sf1 = re.split(r'\s{2,}', sf[1]) 16 | out_i = sf1.index('Output Shape') 17 | 18 | ss = [] 19 | i_esc = [] 20 | for i in range(0, len(st)): 21 | if len(re.split(r'\s{2,}', st[i])) == 1: 22 | ssi = st[i] 23 | if len(set(st[i])) == 1: 24 | i_esc.append(i) 25 | else: 26 | sfi = re.split(r'\s{2,}', sf[i]) 27 | sti = re.split(r'\s{2,}', st[i]) 28 | 29 | ptr = st[i].index(sti[out_i]) + len(sti[out_i]) 30 | in_1 = sf[i].index(sfi[out_i-1]) + len(sfi[out_i-1]) 31 | in_2 = sf[i].index(sfi[out_i]) + len(sfi[out_i]) 32 | ssi = st[i][:ptr] + sf[i][in_1:in_2] + st[i][ptr:] 33 | ss.append(ssi) 34 | 35 | n_str = max([len(s) for s in ss]) 36 | for i in i_esc: 37 | ss[i] = ss[i][-1] * n_str 38 | 39 | ss = '\n'.join(ss) 40 | if print_summary: 41 | print(ss) 42 | 43 | return ss 44 | -------------------------------------------------------------------------------- /utils/2. Submission_UNet.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "# Author: Pedro Herruzo\n", 10 | "# Copyright 2021 Institute of Advanced Research in Artificial Intelligence (IARAI) GmbH.\n", 11 | "# IARAI licenses this file to You under the Apache License, Version 2.0\n", 12 | "# (the \"License\"); you may not use this file except in compliance with\n", 13 | "# the License. You may obtain a copy of the License at\n", 14 | "#\n", 15 | "# http://www.apache.org/licenses/LICENSE-2.0\n", 16 | "#\n", 17 | "# Unless required by applicable law or agreed to in writing, software\n", 18 | "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", 19 | "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", 20 | "# See the License for the specific language governing permissions and\n", 21 | "# limitations under the License." 22 | ] 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "metadata": {}, 27 | "source": [ 28 | "# Create a submission using pretrained UNet models\n", 29 | "\n", 30 | "In this notebook we will perform the following actions:\n", 31 | "* Create a valid submission for the core-competition (R1-3) using pretrained UNets per region i.e. individual models per region\n", 32 | "* Create a valid submission for the transfer-learning-competition (R4-6) using a single UNet trained on region R1\n", 33 | "* Use the ensamble of models trained in regions R1-3 to generate a valid submission transfer-learning-competition (R4-6) by averaging their predictions\n", 34 | "\n", 35 | "[Download the weights for the pre-trained models here](https://www.iarai.ac.at/weather4cast/forums/topic/weights-of-unet-baselines-trained-on-regions-r1-r2-and-r3/)\n", 36 | "\n", 37 | "Dependencies required:\n", 38 | "* torch \n", 39 | "* pytorch_lightning\n", 40 | "* numpy\n", 41 | "\n", 42 | "The model is defined in weather4cast/benchmarks:\n", 43 | "* unet.py: architecture definition \n", 44 | "* FeaturesSysUNet.py: Inherits from pytorch_lightning.LightningModule. In this notebook we only use it for the forward pass\n", 45 | "\n", 46 | "Please, refer to those files if you want to know more about the architecture. You can also read the [docs of pytorch_lightning](https://pytorch-lightning.readthedocs.io/en/latest/), but it is not necessary in order to understand how to produce predictions, which is the main topic of this notebook.\n", 47 | "\n", 48 | "In this notebook, we will use the folder paths and parameters defined in **weather4cast/config.py**. Please, set there the paths to the data before starting with the notebook (check the `Start here` section in the README.md if you want to see how).\n" 49 | ] 50 | }, 51 | { 52 | "cell_type": "markdown", 53 | "metadata": {}, 54 | "source": [ 55 | "## Let us fisrt define the functions that will perform the main tasks:\n", 56 | "* create the submission directory structure\n", 57 | "* load the data & the model \n", 58 | "* compute predictions per day in the test split for a given region" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": 1, 64 | "metadata": {}, 65 | "outputs": [], 66 | "source": [ 67 | "%load_ext autoreload\n", 68 | "%autoreload 2\n", 69 | "\n", 70 | "from torch.utils.data import DataLoader\n", 71 | "\n", 72 | "import pathlib\n", 73 | "import sys\n", 74 | "import os\n", 75 | "module_dir = str(pathlib.Path(os.getcwd()).parent)\n", 76 | "sys.path.append(module_dir)\n", 77 | "\n", 78 | "import data_utils\n", 79 | "import config as cf\n", 80 | "from w4c_dataloader import create_dataset\n", 81 | "from benchmarks.FeaturesSysUNet import FeaturesSysUNet as Model\n", 82 | "\n", 83 | "# ------------\n", 84 | "# 1. create folder structures for the submission\n", 85 | "# ------------\n", 86 | "def create_directory_structure(root, folder_name='submission'):\n", 87 | " \"\"\"\n", 88 | " create competition output directory structure at given root path. \n", 89 | " \"\"\"\n", 90 | " challenges = {'w4c-core-stage-1': ['R1', 'R2', 'R3'], 'w4c-transfer-learning-stage-1': ['R4', 'R5', 'R6']}\n", 91 | " \n", 92 | " for f_name, regions in challenges.items():\n", 93 | " for region in regions:\n", 94 | " r_path = os.path.join(root, folder_name, f_name, region, 'test')\n", 95 | " try:\n", 96 | " os.makedirs(r_path)\n", 97 | " print(f'created path: {r_path}')\n", 98 | " except:\n", 99 | " print(f'failed to create directory structure, maybe they already exist: {r_path}')\n", 100 | "\n", 101 | "# ------------\n", 102 | "# 2. load data & model\n", 103 | "# ------------\n", 104 | "def get_data_iterator(region_id='R1', data_split= 'test', collapse_time=True, \n", 105 | " batch_size=32, shuffle=False, num_workers=0):\n", 106 | " \"\"\" creates an iterator for data in region 'region_id' for the 'data_split' data partition \"\"\"\n", 107 | " \n", 108 | " params = cf.get_params(region_id=region_id)\n", 109 | " params['data_params']['collapse_time'] = collapse_time\n", 110 | "\n", 111 | " ds = create_dataset(data_split, params['data_params'])\n", 112 | " dataloader = DataLoader(ds, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)\n", 113 | " \n", 114 | " data_splits, test_sequences = data_utils.read_splits(params['data_params']['train_splits'], params['data_params']['test_splits'])\n", 115 | " test_dates = data_splits[data_splits.split=='test'].id_date.sort_values().values\n", 116 | "\n", 117 | " return iter(dataloader), test_dates, params\n", 118 | "\n", 119 | "def load_model(Model, params, checkpoint_path='', device=None):\n", 120 | " \"\"\" loads a model from a checkpoint or from scratch if checkpoint_path='' \"\"\"\n", 121 | " \n", 122 | " if checkpoint_path == '':\n", 123 | " model = Model(params['model_params'], **params['data_params']) \n", 124 | " else:\n", 125 | " print(\"model:\", Model)\n", 126 | " print(f'-> Loading model checkpoint: {checkpoint_path}')\n", 127 | " model = Model.load_from_checkpoint(checkpoint_path)\n", 128 | " \n", 129 | " if device is not None:\n", 130 | " model = model.eval().cuda(device)\n", 131 | " \n", 132 | " return model\n", 133 | "\n", 134 | "# ------------\n", 135 | "# 3. make predictions & loop over regions\n", 136 | "# ------------\n", 137 | "def get_preds(model, batch, device=None):\n", 138 | " \"\"\" computes the output of the model on the next iterator's batch \n", 139 | " returns the prediction and the date of it\n", 140 | " \"\"\"\n", 141 | " \n", 142 | " in_seq, out, metadata = batch\n", 143 | " day_in_year = metadata['in']['day_in_year'][0][0].item()\n", 144 | " \n", 145 | " if device is not None:\n", 146 | " in_seq = in_seq.cuda(device=device)\n", 147 | " y_hat = model(in_seq)\n", 148 | " y_hat = y_hat.data.cpu().numpy() \n", 149 | " \n", 150 | " return y_hat, day_in_year\n", 151 | "\n", 152 | "def predictions_per_day(test_dates, model, ds_iterator, device, file_path, data_params):\n", 153 | " \"\"\" computes predictions of all dates and saves them to disk \"\"\"\n", 154 | " \n", 155 | " for target_date in test_dates:\n", 156 | " print(f'generating submission for date: {target_date}...')\n", 157 | " batch = next(ds_iterator)\n", 158 | " y_hat, predicted_day = get_preds(model, batch, device)\n", 159 | " \n", 160 | " # force data to be in the valid range\n", 161 | " y_hat[y_hat>1] = 1\n", 162 | " y_hat[y_hat<0] = 0\n", 163 | " \n", 164 | " # batches are sorted by date for the dataloader, that's why they coincide\n", 165 | " assert predicted_day==target_date, f\"Error, the loaded date {predicted_day} is different than the target: {target_date}\"\n", 166 | "\n", 167 | " f_path = os.path.join(file_path, f'{predicted_day}.h5')\n", 168 | " y_hat = data_utils.postprocess_fn(y_hat, data_params['target_vars'], data_params['preprocess']['source'])\n", 169 | " data_utils.write_data(y_hat, f_path)\n", 170 | " print(f'--> saved in: {f_path}')" 171 | ] 172 | }, 173 | { 174 | "cell_type": "markdown", 175 | "metadata": {}, 176 | "source": [ 177 | "## Let us now generate and save the predictions for the core-competition:\n", 178 | "\n", 179 | "Note that you need to specify the path to save your predictions (root) and the path to the weights for the already trained model, You can find these files in the following section of our competition website: https://www.iarai.ac.at/weather4cast/forums/topic/weights-of-unet-baselines-trained-on-regions-r1-r2-and-r3/\n" 180 | ] 181 | }, 182 | { 183 | "cell_type": "code", 184 | "execution_count": null, 185 | "metadata": {}, 186 | "outputs": [], 187 | "source": [ 188 | "# 1. Define model's checkpoints, regions per task & gpu id to use\n", 189 | "# Attention, if you work on Windows OS, modify the following paths accordingly\n", 190 | "root_to_ckps = '~/projects/weather4cast/lightning_logs'\n", 191 | "checkpoint_paths = {'R1': f'{root_to_ckps}/version_21/checkpoints/epoch=03-val_loss_epoch=0.027697.ckpt', \n", 192 | " 'R2': f'{root_to_ckps}/version_19/checkpoints/epoch=01-val_loss_epoch=0.042129.ckpt', \n", 193 | " 'R3': f'{root_to_ckps}/version_20/checkpoints/epoch=06-val_loss_epoch=0.058147.ckpt'}\n", 194 | "challenges = {'w4c-core-stage-1': ['R1', 'R2', 'R3'], 'w4c-transfer-learning-stage-1': ['R4', 'R5', 'R6']}\n", 195 | "device = 0 # gpu id - SET THE ID OF THE GPU YOU WANT TO USE ---> Use `None` for CPU\n", 196 | "\n", 197 | "# 2. define root and name of the submission to create the folders' structure\n", 198 | "root = '/iarai/home/pedro.herruzo/projects/Weather4cast2021/utils/submission_examples'\n", 199 | "folder_name = 'UNet_submission'\n", 200 | "create_directory_structure(root, folder_name=folder_name)" 201 | ] 202 | }, 203 | { 204 | "cell_type": "code", 205 | "execution_count": null, 206 | "metadata": {}, 207 | "outputs": [], 208 | "source": [ 209 | "# 3. compute and save predictions for each reagion for all dates in the test split\n", 210 | "task_name = 'w4c-core-stage-1'\n", 211 | "for region in challenges[task_name]:\n", 212 | " # load data and model\n", 213 | " ds_iterator, test_dates, params = get_data_iterator(region_id=region)\n", 214 | " model = load_model(Model, params, checkpoint_path=checkpoint_paths[region], device=device)\n", 215 | "\n", 216 | " r_path = os.path.join(root, folder_name, task_name, region, 'test')\n", 217 | " predictions_per_day(test_dates, model, ds_iterator, device, r_path, params['data_params']) " 218 | ] 219 | }, 220 | { 221 | "cell_type": "markdown", 222 | "metadata": {}, 223 | "source": [ 224 | "We have just computed a valid submission for all regions using a pretrained UNet model. \n", 225 | "\n", 226 | "Now we should submit a zip containing all regions and we are done. Please, follow the instructions in weather4cast/README.md to generate the zip file." 227 | ] 228 | }, 229 | { 230 | "cell_type": "markdown", 231 | "metadata": {}, 232 | "source": [ 233 | "## Transfer Learning competition submission\n", 234 | "\n", 235 | "### Let us first generate the predictions using only a single model from the core-competition" 236 | ] 237 | }, 238 | { 239 | "cell_type": "code", 240 | "execution_count": null, 241 | "metadata": {}, 242 | "outputs": [], 243 | "source": [ 244 | "# 3. compute and save predictions for each reagion for all dates in the test split\n", 245 | "task_name = 'w4c-core-stage-1'\n", 246 | "for region in challenges[task_name]:\n", 247 | " # load data and model\n", 248 | " ds_iterator, test_dates, params = get_data_iterator(region_id=region)\n", 249 | " model = load_model(Model, params, checkpoint_path=checkpoint_paths[region], device=device)\n", 250 | "\n", 251 | " r_path = os.path.join(root, folder_name, task_name, region, 'test')\n", 252 | " predictions_per_day(test_dates, model, ds_iterator, device, r_path, params['data_params']) " 253 | ] 254 | }, 255 | { 256 | "cell_type": "markdown", 257 | "metadata": {}, 258 | "source": [ 259 | "### Let use an ensemble of the models learned in regions 1-3 by getting individual predictions and just averaging across them" 260 | ] 261 | }, 262 | { 263 | "cell_type": "code", 264 | "execution_count": null, 265 | "metadata": {}, 266 | "outputs": [], 267 | "source": [ 268 | "# create a new folder structure (note that we will only use R4-6)\n", 269 | "folder_name = 'transfer-ensample'\n", 270 | "create_directory_structure(root, folder_name=folder_name)" 271 | ] 272 | }, 273 | { 274 | "cell_type": "code", 275 | "execution_count": 11, 276 | "metadata": {}, 277 | "outputs": [], 278 | "source": [ 279 | "import numpy as np\n", 280 | "\n", 281 | "def predictions_per_day_ensamble(test_dates, models, ds_iterator, device, file_path, data_params):\n", 282 | " \"\"\" computes predictions of all dates and saves them to disk. It uses the average of predictions across all models provided\n", 283 | " models (list): list of models to be used in the ensample\n", 284 | " \"\"\"\n", 285 | " \n", 286 | " for target_date in test_dates:\n", 287 | " print(f'generating submission for date: {target_date}...')\n", 288 | " batch = next(ds_iterator)\n", 289 | " \n", 290 | " ensamble = []\n", 291 | " for model in models:\n", 292 | " y_hat, predicted_day = get_preds(model, batch, device)\n", 293 | "\n", 294 | " # force data to be in the valid range\n", 295 | " y_hat[y_hat>1] = 1\n", 296 | " y_hat[y_hat<0] = 0\n", 297 | "\n", 298 | " # batches are sorted by date for the dataloader, that's why they coincide\n", 299 | " assert predicted_day==target_date, f\"Error, the loaded date {predicted_day} is different than the target: {target_date}\"\n", 300 | " \n", 301 | " ensamble.append(y_hat)\n", 302 | " \n", 303 | " ensamble = np.asarray(ensamble)\n", 304 | " y_hat = np.mean(ensamble, axis=0)\n", 305 | "\n", 306 | " f_path = os.path.join(file_path, f'{predicted_day}.h5')\n", 307 | " y_hat = data_utils.postprocess_fn(y_hat, data_params['target_vars'], data_params['preprocess']['source'])\n", 308 | " data_utils.write_data(y_hat, f_path)\n", 309 | " print(f'--> saved in: {f_path}')\n", 310 | " \n", 311 | "# load all 3 models into a list\n", 312 | "models = [load_model(Model, params, checkpoint_path=checkpoint_paths[reg_id], device=device) for reg_id in challenges['w4c-core-stage-1']]\n", 313 | "\n", 314 | "# compute and save averaged predictions with the ensamble of models\n", 315 | "task_name = 'w4c-transfer-learning-stage-1'\n", 316 | "for region in challenges[task_name]:\n", 317 | " # load data\n", 318 | " ds_iterator, test_dates, params = get_data_iterator(region_id=region)\n", 319 | "\n", 320 | " # compute predictions in the ensamble of models and save them to disk\n", 321 | " r_path = os.path.join(root, folder_name, task_name, region, 'test')\n", 322 | " predictions_per_day_ensamble(test_dates, models, ds_iterator, device, r_path, params['data_params']) " 323 | ] 324 | }, 325 | { 326 | "cell_type": "code", 327 | "execution_count": null, 328 | "metadata": {}, 329 | "outputs": [], 330 | "source": [] 331 | }, 332 | { 333 | "cell_type": "code", 334 | "execution_count": null, 335 | "metadata": {}, 336 | "outputs": [], 337 | "source": [] 338 | } 339 | ], 340 | "metadata": { 341 | "kernelspec": { 342 | "display_name": "Python [conda env:.conda-weather2]", 343 | "language": "python", 344 | "name": "conda-env-.conda-weather2-py" 345 | }, 346 | "language_info": { 347 | "codemirror_mode": { 348 | "name": "ipython", 349 | "version": 3 350 | }, 351 | "file_extension": ".py", 352 | "mimetype": "text/x-python", 353 | "name": "python", 354 | "nbconvert_exporter": "python", 355 | "pygments_lexer": "ipython3", 356 | "version": "3.7.8" 357 | } 358 | }, 359 | "nbformat": 4, 360 | "nbformat_minor": 4 361 | } 362 | -------------------------------------------------------------------------------- /utils/3-train-UNet-example.py: -------------------------------------------------------------------------------- 1 | # Author: Pedro Herruzo 2 | # Copyright 2021 Institute of Advanced Research in Artificial Intelligence (IARAI) GmbH. 3 | # IARAI licenses this file to You under the Apache License, Version 2.0 4 | # (the "License"); you may not use this file except in compliance with 5 | # the License. You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import argparse 15 | 16 | import pytorch_lightning as pl 17 | from pytorch_lightning.callbacks import ModelCheckpoint 18 | from torch.utils.data import DataLoader 19 | import torch 20 | 21 | from w4c_dataloader import create_dataset 22 | 23 | import pathlib 24 | import sys 25 | import os 26 | module_dir = str(pathlib.Path(os.getcwd()).parent) 27 | sys.path.append(module_dir) 28 | 29 | import config as cf 30 | from benchmarks.FeaturesSysUNet import FeaturesSysUNet as Model 31 | 32 | class DataModule(pl.LightningDataModule): 33 | """ Class to handle training/validation splits in a single object 34 | """ 35 | def __init__(self, params, training_params): 36 | super().__init__() 37 | self.params = params 38 | self.training_params = training_params 39 | 40 | def __load_dataloader(self, dataset, shuffle=True, pin=True): 41 | dl = DataLoader(dataset, 42 | batch_size=self.training_params['batch_size'], num_workers=self.training_params['n_workers'], 43 | shuffle=shuffle, pin_memory=pin) 44 | # prefetch_factor=2, 45 | # persistent_workers=False) 46 | return dl 47 | 48 | def train_dataloader(self): 49 | ds = create_dataset('training', self.params) 50 | return self.__load_dataloader(ds, shuffle=True, pin=True) 51 | 52 | def val_dataloader(self): 53 | ds = create_dataset('validation', self.params) 54 | return self.__load_dataloader(ds, shuffle=False, pin=True) 55 | 56 | def print_training(params): 57 | """ print pre-training info """ 58 | 59 | print(f'Extra variables: {params["extra_data"]} | spatial_dim: {params["spatial_dim"]} ', 60 | f'| collapse_time: {params["collapse_time"]} | in channels depth: {params["depth"]} | len_seq_in: {params["len_seq_in"]}') 61 | 62 | def load_model(Model, params, checkpoint_path=''): 63 | """ loads a model from a checkpoint or from scratch if checkpoint_path='' """ 64 | if checkpoint_path == '': 65 | print('-> model from scratch!') 66 | model = Model(params['model_params'], **params['data_params']) 67 | else: 68 | print(f'-> Loading model checkpoint: {checkpoint_path}') 69 | model = Model.load_from_checkpoint(checkpoint_path) 70 | return model 71 | 72 | def get_trainer(gpu): 73 | """ get the trainer, modify here it's options: 74 | - save_top_k 75 | - max_epochs 76 | """ 77 | checkpoint_callback = ModelCheckpoint(monitor='val_loss_epoch', save_top_k=3, 78 | filename='{epoch:02d}-{val_loss_epoch:.6f}') 79 | 80 | trainer = pl.Trainer(gpus=gpu, max_epochs=20, 81 | progress_bar_refresh_rate=10, # 80, 82 | callbacks=[checkpoint_callback], 83 | profiler='simple', 84 | sync_batchnorm=True, 85 | num_sanity_val_steps=0, 86 | accelerator='ddp', 87 | ) 88 | return trainer 89 | 90 | def do_test(trainer, model, test_data): 91 | print("-----------------") 92 | print("--- TEST MODE ---") 93 | print("-----------------") 94 | scores = trainer.test(model, test_dataloaders=test_data) 95 | 96 | def train(gpu, region_id, mode, checkpoint_path): 97 | """ main training/evaluation method 98 | """ 99 | pl.seed_everything(2021) 100 | torch.manual_seed(2021) 101 | torch.cuda.manual_seed_all(2021) 102 | 103 | # ------------ 104 | # model & data 105 | # ------------ 106 | params = cf.get_params(region_id=region_id, collapse_time=True) 107 | data = DataModule(params['data_params'], params['training_params']) 108 | 109 | model = load_model(Model, params, checkpoint_path) 110 | print(model) 111 | # Exception('Just stopping here for a while') 112 | # ------------ 113 | # trainer 114 | # ------------ 115 | trainer = get_trainer(gpu) 116 | print_training(params['data_params']) 117 | 118 | # ------------ 119 | # train & final validation 120 | # ------------ 121 | if mode == 'train': 122 | print("-----------------") 123 | print("-- TRAIN MODE ---") 124 | print("-----------------") 125 | trainer.fit(model, data) 126 | 127 | # validate 128 | do_test(trainer, model, data.val_dataloader()) 129 | 130 | def set_parser(): 131 | """ set custom parser """ 132 | 133 | parser = argparse.ArgumentParser(description="") 134 | parser.add_argument("-g", "--gpu_id", type=str, required=False, default='0,1', 135 | help="specify a gpu ID. 1 as default") 136 | parser.add_argument("-r", "--region", type=str, required=False, default='R1', 137 | help="region_id to load data from. R1 as default") 138 | parser.add_argument("-m", "--mode", type=str, required=False, default='train', 139 | help="choose mode: train (default) / val") 140 | parser.add_argument("-c", "--checkpoint", type=str, required=False, default='', 141 | help="init a model from a checkpoint path. '' as default (random weights)") 142 | 143 | return parser 144 | 145 | def main(): 146 | 147 | parser = set_parser() 148 | options = parser.parse_args() 149 | 150 | train(options.gpu_id, options.region, options.mode, options.checkpoint) 151 | 152 | if __name__ == "__main__": 153 | main() 154 | """ examples of usage: 155 | 156 | cd utils 157 | 158 | - a.1) train from scratch 159 | python 3-train-UNet-example.py --gpu_id 1 --region R1 160 | 161 | - a.2) fine tune a model from a checkpoint 162 | python 3-train-UNet-example.py --gpu_id 1 --region R1 -c '~/projects/weather4cast/lightning_logs/version_21/checkpoints/epoch=03-val_loss_epoch=0.027697.ckpt' 163 | 164 | - b.1) evaluate an untrained model (with random weights) 165 | python 3-train-UNet-example.py --gpu_id 1 --region R1 --mode val 166 | 167 | - b.2) evaluate a trained model from a checkpoint 168 | python 3-train-UNet-example.py --gpu_id 1 --region R1 --mode val -c '~/projects/weather4cast/lightning_logs/version_21/checkpoints/epoch=03-val_loss_epoch=0.027697.ckpt' 169 | """ -------------------------------------------------------------------------------- /utils/blacklist.json: -------------------------------------------------------------------------------- 1 | { 2 | "2019056": [ 3 | 51, 4 | 52, 5 | 53, 6 | 54, 7 | 58 8 | ], 9 | "2019071": [ 10 | 52, 11 | 53, 12 | 54, 13 | 55 14 | ], 15 | "2019095": [ 16 | 51, 17 | 95 18 | ], 19 | "2019096": [ 20 | 95 21 | ], 22 | "2019100": [ 23 | 95 24 | ], 25 | "2019102": [ 26 | 48, 27 | 41, 28 | 95 29 | ], 30 | "2019103": [ 31 | 19, 32 | 95 33 | ], 34 | "2019105": [ 35 | 95 36 | ], 37 | "2019108": [ 38 | 95 39 | ], 40 | "2019109": [ 41 | 95 42 | ], 43 | "2019111": [ 44 | 95 45 | ], 46 | "2019126": [ 47 | 63 48 | ], 49 | "2019134": [ 50 | 41 51 | ], 52 | "2019144": [ 53 | 56 54 | ], 55 | "2019151": [ 56 | 88 57 | ], 58 | "2019182": [ 59 | 59, 60 | 61 61 | ], 62 | "2019183": [ 63 | 4, 64 | 5, 65 | 6, 66 | 7, 67 | 8, 68 | 9, 69 | 10, 70 | 11, 71 | 12, 72 | 13 73 | ], 74 | "2019229": [ 75 | 29 76 | ], 77 | "2019232": [ 78 | 88, 79 | 89, 80 | 90 81 | ], 82 | "2019236": [ 83 | 63 84 | ], 85 | "2019238": [ 86 | 69, 87 | 70, 88 | 71, 89 | 73, 90 | 74 91 | ], 92 | "2019239": [ 93 | 74, 94 | 75, 95 | 78 96 | ], 97 | "2019242": [ 98 | 69 99 | ], 100 | "2019256": [ 101 | 47 102 | ], 103 | "2019263": [ 104 | 77 105 | ], 106 | "2019281": [ 107 | 42 108 | ], 109 | "2019288": [ 110 | 48 111 | ], 112 | "2019289": [ 113 | 49 114 | ], 115 | "2019315": [ 116 | 43, 117 | 44, 118 | 56, 119 | 57, 120 | 58 121 | ], 122 | "2019180": [ 123 | 79, 124 | 80 125 | ], 126 | "2019299": [ 127 | 1, 128 | 2, 129 | 3, 130 | 4, 131 | 5, 132 | 6, 133 | 7, 134 | 8, 135 | 9, 136 | 10, 137 | 11, 138 | 12, 139 | 13, 140 | 14, 141 | 15, 142 | 16, 143 | 17, 144 | 18, 145 | 19, 146 | 20 147 | ], 148 | "2019352": [ 149 | 47 150 | ], 151 | "2020029": [ 152 | 9 153 | ] 154 | } -------------------------------------------------------------------------------- /utils/context_variables.py: -------------------------------------------------------------------------------- 1 | # Author: Pedro Herruzo 2 | # Copyright 2021 Institute of Advanced Research in Artificial Intelligence (IARAI) GmbH. 3 | # IARAI licenses this file to You under the Apache License, Version 2.0 4 | # (the "License"); you may not use this file except in compliance with 5 | # the License. You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import numpy as np 16 | import xarray as xr 17 | 18 | # ---------------------------------- 19 | # preprocess - static features 20 | # ---------------------------------- 21 | def get_copies(l, n_copies): 22 | """ return the array duplicated for each sample in the sequence """ 23 | arr = np.asarray([l for i in range(n_copies)]) 24 | return arr 25 | 26 | def _norm(x, max_v, min_v): 27 | """ we assume max_v > 0 & max_v > min_v """ 28 | return (x-min_v)/(max_v-min_v) 29 | 30 | def normalize_latlon(latlons): 31 | norm_latlon = {'lat': {'max_v': 86, 'min_v': 23}, # it does not start from the equator & does not reah the pole 32 | 'lon': {'max_v': 76, 'min_v': -76}} # it does not consider full earth 33 | 34 | latlons[0] = _norm(latlons[0], **norm_latlon['lat']) 35 | latlons[1] = _norm(latlons[1], **norm_latlon['lon']) 36 | 37 | return latlons 38 | 39 | def crop_Dataset(product, x_start, y_start, size=256): 40 | """ crop a squared region size 41 | provide upper-left corner with (x_start, y_start) 42 | """ 43 | return product.isel(nx=slice(x_start, x_start+size), 44 | ny=slice(y_start, y_start+size)) 45 | 46 | def mk_crop_np(product, x_start, y_start, size=256): 47 | """ crop a squared region size^2 48 | provide upper-left corner with (x_start, y_start) 49 | """ 50 | return product[y_start:y_start+size, x_start:x_start+size] 51 | 52 | # ---------------------------------- 53 | # load extra information - static features 54 | # ---------------------------------- 55 | def get_elevation(n_copies, crop=None, path='', shape=[1019, 2200], norm=True): 56 | 57 | altitudes = np.fromfile(path, dtype=np.float32) 58 | altitudes = altitudes.reshape(shape[0], shape[1]) 59 | max_alt = altitudes.max() 60 | 61 | if crop is not None: 62 | altitudes = mk_crop_np(altitudes, **crop) 63 | 64 | if norm: 65 | # make under see level 0 66 | altitudes[altitudes<0] = 0 67 | 68 | # normalize 69 | altitudes = altitudes/max_alt 70 | 71 | return np.expand_dims(get_copies(altitudes, n_copies), axis=1), ['altitudes'] 72 | 73 | def get_lat_lon(n_copies, crop=None, path='', atts = ['latitude', 'longitude'], norm=True): 74 | 75 | latlons = xr.open_dataset(path) 76 | 77 | if crop is not None: 78 | latlons = crop_Dataset(latlons, **crop) 79 | 80 | # get only the values form the netcdf4 file 81 | latlons = [latlons[att][0].values for att in atts] 82 | 83 | if norm: 84 | latlons = normalize_latlon(latlons) 85 | 86 | return get_copies(latlons, n_copies), atts.copy() 87 | 88 | def get_static(attributes, n_copies, paths, crop=None, channel_dim=1, norm=True): 89 | 90 | statics, descriptions = [], [] 91 | funcs = {'l': get_lat_lon, 'e': get_elevation} 92 | 93 | for feature in attributes: 94 | if feature in 'le': 95 | data, channels = funcs[feature](n_copies, crop=crop, path=paths[feature], norm=norm) 96 | statics.append(data) 97 | descriptions += channels 98 | 99 | if len(statics)!=0: 100 | statics = np.concatenate(statics, axis=channel_dim) 101 | 102 | return statics, descriptions -------------------------------------------------------------------------------- /utils/data_utils.py: -------------------------------------------------------------------------------- 1 | # Author: Pedro Herruzo 2 | # Copyright 2021 Institute of Advanced Research in Artificial Intelligence (IARAI) GmbH. 3 | # IARAI licenses this file to You under the Apache License, Version 2.0 4 | # (the "License"); you may not use this file except in compliance with 5 | # the License. You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import glob 16 | import numpy as np 17 | import netCDF4 18 | import pandas as pd 19 | import json 20 | import h5py 21 | 22 | 23 | # ---------------------------------- 24 | # Loading utils: blacklist & data splits 25 | # ---------------------------------- 26 | def get_double_idxs_w_blacklist(days, bins_to_predict=32, day_bins=96, len_seq_in=4, black_list_path='blacklist.json', 27 | verbose=False): 28 | doubles = [] 29 | with open(black_list_path) as data_file: 30 | # convert key 'dates' to integer 31 | black_list = json.load(data_file) 32 | black_list = {int(k): v for k, v in black_list.items()} 33 | 34 | # range for the input sequence 35 | in_range = day_bins - (len_seq_in + bins_to_predict - 1) 36 | 37 | # for each day 'd', AND each starting timebin index 'i' 38 | for d in days: 39 | for i in range(in_range): 40 | seq_in_black = False 41 | 42 | if d in black_list.keys(): 43 | 44 | seq_bins = [i + j for j in range(len_seq_in + bins_to_predict)] 45 | timebins_in_black = [idx for idx in seq_bins if idx in black_list[d]] 46 | 47 | if len(timebins_in_black) > 0: 48 | seq_in_black = True 49 | if verbose: 50 | print(i, seq_bins, timebins_in_black, black_list[d]) 51 | else: 52 | doubles.append((d, i)) 53 | 54 | else: 55 | doubles.append((d, i)) 56 | 57 | return doubles 58 | 59 | 60 | def get_test_doubles(days, test_sequences, bins_to_predict=32): 61 | doubles = [] 62 | 63 | for d in days: 64 | day_id = str(d)[-3:] 65 | i = test_sequences[day_id]['bins_in']['0']['id_bin'] 66 | doubles.append((d, i)) 67 | return doubles 68 | 69 | 70 | def get_triple_idxs_w_blacklist(days, bins_to_predict=32, day_bins=96, len_seq_in=4, black_list_path='blacklist.json', 71 | verbose=False): 72 | 73 | triples = [] 74 | with open(black_list_path) as data_file: 75 | # convert key 'dates' to integer 76 | black_list = json.load(data_file) 77 | black_list = {int(k): v for k, v in black_list.items()} 78 | 79 | # range for the input sequence 80 | in_range = day_bins-(len_seq_in+bins_to_predict-1) 81 | 82 | # for each day 'd', AND each starting timebin index 'i' 83 | for d in days: 84 | for i in range(in_range): 85 | seq_in_black = False 86 | 87 | if d in black_list.keys(): 88 | 89 | seq_bins = [i + j for j in range(len_seq_in + len_seq_in)] 90 | timebins_in_black = [idx for idx in seq_bins if idx in black_list[d]] 91 | 92 | if len(timebins_in_black) > 0: 93 | seq_in_black = True 94 | if verbose: 95 | print(i, seq_bins, timebins_in_black, black_list[d]) 96 | else: 97 | triples.append((d, i, o)) 98 | 99 | else: 100 | triples.append((d, i, o)) 101 | 102 | # if day is in the blacklist AND any input frame in the sequence is missing 103 | # --> we will block this sequence 104 | if d in black_list.keys(): 105 | 106 | seq_bins = [i+j for j in range(len_seq_in)] 107 | timebins_in_black = [idx for idx in seq_bins if idx in black_list[d]] 108 | 109 | if len(timebins_in_black)>0: 110 | seq_in_black = True 111 | if verbose: 112 | print(i, seq_bins, timebins_in_black, black_list[d]) 113 | 114 | # only consider sequences with input frames not in the black_list 115 | if not seq_in_black: 116 | 117 | # even if input sequence is not in the black_list 118 | # we will only consider the triplet if the output is not in the black_list 119 | for o in range(bins_to_predict): 120 | out_in_black = False 121 | if d in black_list.keys(): 122 | if i+len_seq_in+o in black_list[d]: 123 | out_in_black = True 124 | 125 | if not out_in_black: 126 | triples.append((d, i, o)) 127 | else: 128 | if verbose: 129 | print("--> output") 130 | print((d, i, i+o, black_list[d])) 131 | else: 132 | if verbose: 133 | print("input: ") 134 | print(d, i) 135 | print() 136 | 137 | return triples 138 | 139 | 140 | def get_test_triplets(days, test_sequences, bins_to_predict=32): 141 | triples = [] 142 | 143 | for d in days: 144 | day_id = str(d)[-3:] 145 | i = test_sequences[day_id]['bins_in']['0']['id_bin'] 146 | for o in range(bins_to_predict): 147 | triples.append((d, i, o)) 148 | return triples 149 | 150 | 151 | def get_time(): 152 | return ['{}{}{}{}00'.format('0'*bool(i<10), i, '0'*bool(j<10), j) for i in np.arange(0, 24, 1) 153 | for j in np.arange(0, 60, 15)] 154 | 155 | 156 | def read_splits(path_splits, path_test_split=''): 157 | """ read dates splits with pandas and test_splits with json """ 158 | df = pd.read_csv(path_splits, index_col=0) 159 | with open(path_test_split) as data_file: 160 | test_splits = json.load(data_file) 161 | return df, test_splits 162 | 163 | def get_next_day(sorted_days, current_day): 164 | """ get the next day """ 165 | pos_date = np.argwhere(sorted_days==current_day) 166 | assert len(pos_date)==1, f" Error: date {current_day} not in the list" 167 | 168 | pos_date = pos_date[0, 0] 169 | assert pos_date+1 all these files were found: {file_path}" 336 | file_path = file_path[0] 337 | ds = netCDF4.Dataset(file_path, 'r') 338 | 339 | return ds 340 | 341 | def get_file_netcdf4(product, day_in_year, time, attributes, 342 | root='', crop=None, params_process=None, 343 | ct_1hot=None, target=False, use_mask=False): 344 | """ open a *.nc file and return only the specified attributes 345 | - one_hot_ct allows to 1-hot encoding 'ct' 346 | """ 347 | ds = read_netcdf4(product, day_in_year, time, root) 348 | ds_vars, ds_masks = {}, {} 349 | # ds_vars = {} 350 | for attr in attributes: 351 | #ds.variables[attr].set_auto_scale(False) # uncomment & don't scale in preprocess_fn to see raw data 352 | v = ds.variables[attr][...] 353 | 354 | try: 355 | v_max = np.finfo(v.dtype).max 356 | except: 357 | v_max = np.iinfo(v.dtype).max 358 | 359 | if crop is not None: 360 | v = mk_crop_Dataset_netcdf4(v, **crop) 361 | 362 | if isinstance(v.mask, np.bool_): 363 | # v.mask = np.zeros(v.shape) 364 | v.mask = np.zeros(v.shape, dtype=np.bool_) 365 | ds_masks[attr] = v.mask 366 | 367 | if params_process is not None: 368 | if attr == 'ct' and ct_1hot is not None: 369 | v = preprocess_1hot_ct(v, **params_process[attr]) 370 | else: # fill_value, max_value, add_offset, scale_factor 371 | # print(v, v.dtype, v_max, attr, params_process.get(attr)) 372 | # v = preprocess_fn(v, **params_process[attr]) 373 | # TODO - check the best and right scaling factors for unlisted variables 374 | v = preprocess_fn(v, **params_process.get(attr, {'fill_value': 0, 'max_value': v_max, 375 | 'add_offset': 0, 'scale_factor': 1})) 376 | 377 | else: # return raw value with NaNs where a mask is found 378 | v = v.filled(np.nan) 379 | 380 | ds_vars[attr] = v 381 | 382 | ds.close() 383 | # return ds_vars # , ds_masks 384 | return ds_vars, ds_masks if use_mask else None 385 | 386 | def get_products_netcdf4(day_in_year, time, products, path, 387 | attrs_order=['ct', 'ctth_pres', 'crr_intensity'], 388 | crop=None, preprocess=None, ct_1hot=None, debug=False, 389 | target=False, use_mask=False): 390 | """ loads all products and attributes into a single tensor 391 | sorted by attrs_order 392 | 393 | returns 394 | - numpy tensor with shape (attributes, ny, nx) 395 | - attrs_order 396 | """ 397 | 398 | prods, masks = {}, {} 399 | # prods = {} 400 | for product, attributes_ in products.items(): 401 | attributes = [attr for attr in attributes_ if attr in attrs_order] # just added to increase loading speed 402 | prod, mask = get_file_netcdf4(product, day_in_year, time, attributes, 403 | root=path, crop=crop, params_process=preprocess, 404 | ct_1hot=ct_1hot, target=target, use_mask=use_mask) 405 | # prod = get_file_netcdf4(product, day_in_year, time, attributes, 406 | # root=path, crop=crop, params_process=preprocess, 407 | # ct_1hot=ct_1hot, target=target) 408 | prods = dict(prods, **prod) 409 | if use_mask: 410 | masks = dict(masks, **mask) 411 | 412 | if debug: 413 | for sorted_var in attrs_order: 414 | print(sorted_var, prods[sorted_var].shape) 415 | 416 | # to numpy 417 | if ct_1hot is not None: 418 | v = 'ct' 419 | prods, masks = prods[v], masks[v] 420 | # prods = prods[v] 421 | attrs_order = [ct_1hot[key] for key in sorted(ct_1hot.keys(), reverse=False)] 422 | else: 423 | prods = np.asarray([prods[sorted_var] for sorted_var in attrs_order]) 424 | if use_mask: 425 | masks = np.asarray([masks[sorted_var] for sorted_var in attrs_order]) 426 | 427 | return prods, masks, attrs_order.copy() 428 | # return prods, attrs_order.copy() 429 | 430 | def get_sequence_netcdf4(len_seq, in_start_id, day_id, products, path, target_vars, 431 | hhmmss=get_time(), crop=None, preprocess=None, ct_1hot=None, 432 | day_bins=96, sorted_dates=None, populate_mask=False): 433 | """ input doesn't need the mask """ 434 | 435 | sequence = [] 436 | seq_info = {'day_in_year': [], 'time_bins': [], 'masks': []} 437 | already_next_day = True 438 | 439 | for time_bin in range(in_start_id, in_start_id+len_seq): 440 | if time_bin>=day_bins: # this is to load next days' timebin in the test split 441 | time_bin = time_bin % day_bins 442 | if not already_next_day: 443 | print(f'Since input sequence goes from {in_start_id} to {in_start_id+len_seq}') 444 | print(f'we should need to update day {day_id}...') 445 | next_day_id = get_next_day(sorted_dates, day_id) 446 | print(f'to {next_day_id}, but files are all in the folder of the former day.\n') 447 | #already_next_day = True 448 | 449 | prod, masks, channels = get_products_netcdf4(day_id, hhmmss[time_bin], products, path, 450 | target_vars, crop, preprocess, ct_1hot, use_mask=populate_mask) 451 | # prod, channels = get_products_netcdf4(day_id, hhmmss[time_bin], products, path, 452 | # target_vars, crop, preprocess, ct_1hot) 453 | sequence.append(prod) 454 | seq_info['day_in_year'].append(day_id) 455 | seq_info['time_bins'].append(time_bin) 456 | if populate_mask: 457 | seq_info['masks'].append(masks) 458 | 459 | if populate_mask: 460 | seq_info['masks'] = np.asarray(seq_info['masks']) 461 | seq_info['channels'] = channels 462 | 463 | # to numpy 464 | sequence = np.asarray(sequence) 465 | 466 | return sequence, seq_info -------------------------------------------------------------------------------- /utils/environment.yml: -------------------------------------------------------------------------------- 1 | name: w4c 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=conda_forge 8 | - _openmp_mutex=4.5=1_llvm 9 | - absl-py=0.12.0=pyhd8ed1ab_0 10 | - aiohttp=3.7.4=py39h3811e60_0 11 | - anyio=2.2.0=py39hf3d152e_0 12 | - argon2-cffi=20.1.0=py39h3811e60_2 13 | - async-timeout=3.0.1=py_1000 14 | - async_generator=1.10=py_0 15 | - attrs=20.3.0=pyhd3deb0d_0 16 | - babel=2.9.0=pyhd3deb0d_0 17 | - backcall=0.2.0=pyh9f0ad1d_0 18 | - backports=1.0=py_2 19 | - backports.functools_lru_cache=1.6.3=pyhd8ed1ab_0 20 | - blas=2.108=mkl 21 | - blas-devel=3.9.0=8_mkl 22 | - bleach=3.3.0=pyh44b312d_0 23 | - blinker=1.4=py_1 24 | - brotlipy=0.7.0=py39h3811e60_1001 25 | - bzip2=1.0.8=h7f98852_4 26 | - c-ares=1.17.1=h7f98852_1 27 | - ca-certificates=2020.12.5=ha878542_0 28 | - cached-property=1.5.2=hd8ed1ab_1 29 | - cached_property=1.5.2=pyha770c72_1 30 | - cachetools=4.2.1=pyhd8ed1ab_0 31 | - certifi=2020.12.5=py39hf3d152e_1 32 | - cffi=1.14.5=py39he32792d_0 33 | - cftime=1.4.1=py39hce5d2b2_0 34 | - chardet=4.0.0=py39hf3d152e_1 35 | - click=7.1.2=pyh9f0ad1d_0 36 | - cryptography=3.4.7=py39hbca0aa6_0 37 | - cudatoolkit=11.1.1=h6406543_8 38 | - curl=7.75.0=h979ede3_0 39 | - cycler=0.10.0=py_2 40 | - dbus=1.13.6=h48d8840_2 41 | - decorator=4.4.2=py_0 42 | - defusedxml=0.7.1=pyhd8ed1ab_0 43 | - entrypoints=0.3=pyhd8ed1ab_1003 44 | - expat=2.3.0=h9c3ff4c_0 45 | - fontconfig=2.13.1=hba837de_1004 46 | - freetype=2.10.4=h0708190_1 47 | - fsspec=0.8.7=pyhd8ed1ab_0 48 | - future=0.18.2=py39hf3d152e_3 49 | - gettext=0.19.8.1=h0b5b191_1005 50 | - glib=2.68.0=h9c3ff4c_2 51 | - glib-tools=2.68.0=h9c3ff4c_2 52 | - google-auth=1.28.0=pyh44b312d_0 53 | - google-auth-oauthlib=0.4.1=py_2 54 | - grpcio=1.36.1=py39hff7568b_0 55 | - gst-plugins-base=1.18.4=h29181c9_0 56 | - gstreamer=1.18.4=h76c114f_0 57 | - h5py=3.1.0=nompi_py39h25020de_100 58 | - hdf4=4.2.13=h10796ff_1004 59 | - hdf5=1.10.6=nompi_h6a2412b_1114 60 | - icu=68.1=h58526e2_0 61 | - idna=2.10=pyh9f0ad1d_0 62 | - importlib-metadata=3.10.0=py39hf3d152e_0 63 | - ipykernel=5.5.0=py39hef51801_1 64 | - ipython=7.22.0=py39hef51801_0 65 | - ipython_genutils=0.2.0=py_1 66 | - jedi=0.18.0=py39hf3d152e_2 67 | - jinja2=2.11.3=pyh44b312d_0 68 | - jpeg=9d=h36c2ea0_0 69 | - json5=0.9.5=pyh9f0ad1d_0 70 | - jsonschema=3.2.0=pyhd8ed1ab_3 71 | - jupyter-packaging=0.7.12=pyhd8ed1ab_0 72 | - jupyter_client=6.1.12=pyhd8ed1ab_0 73 | - jupyter_core=4.7.1=py39hf3d152e_0 74 | - jupyter_server=1.5.1=py39hf3d152e_0 75 | - jupyterlab=3.0.12=pyhd8ed1ab_0 76 | - jupyterlab_pygments=0.1.2=pyh9f0ad1d_0 77 | - jupyterlab_server=2.3.0=pyhd8ed1ab_0 78 | - kiwisolver=1.3.1=py39h1a9c180_1 79 | - krb5=1.17.2=h926e7f8_0 80 | - lcms2=2.12=hddcbb42_0 81 | - ld_impl_linux-64=2.35.1=hea4e1c9_2 82 | - libblas=3.9.0=8_mkl 83 | - libcblas=3.9.0=8_mkl 84 | - libclang=11.1.0=default_ha53f305_0 85 | - libcurl=7.75.0=hc4aaa36_0 86 | - libedit=3.1.20191231=he28a2e2_2 87 | - libev=4.33=h516909a_1 88 | - libevent=2.1.10=hcdb4288_3 89 | - libffi=3.3=h58526e2_2 90 | - libgcc-ng=9.3.0=h2828fa1_18 91 | - libgfortran-ng=9.3.0=hff62375_18 92 | - libgfortran5=9.3.0=hff62375_18 93 | - libglib=2.68.0=h3e27bee_2 94 | - libgomp=9.3.0=h2828fa1_18 95 | - libiconv=1.16=h516909a_0 96 | - liblapack=3.9.0=8_mkl 97 | - liblapacke=3.9.0=8_mkl 98 | - libllvm11=11.1.0=hf817b99_0 99 | - libnetcdf=4.7.4=nompi_h56d31a8_107 100 | - libnghttp2=1.43.0=h812cca2_0 101 | - libopenblas=0.3.12=pthreads_h4812303_1 102 | - libpng=1.6.37=h21135ba_2 103 | - libpq=13.1=hfd2b0eb_2 104 | - libprotobuf=3.15.6=h780b84a_0 105 | - libsodium=1.0.18=h36c2ea0_1 106 | - libssh2=1.9.0=ha56f1ee_6 107 | - libstdcxx-ng=9.3.0=h6de172a_18 108 | - libtiff=4.2.0=hdc55705_0 109 | - libuuid=2.32.1=h7f98852_1000 110 | - libuv=1.41.0=h7f98852_0 111 | - libwebp-base=1.2.0=h7f98852_2 112 | - libxcb=1.13=h7f98852_1003 113 | - libxkbcommon=1.0.3=he3ba5ed_0 114 | - libxml2=2.9.10=h72842e0_3 115 | - llvm-openmp=11.1.0=h4bd325d_0 116 | - lz4-c=1.9.3=h9c3ff4c_0 117 | - markdown=3.3.4=pyhd8ed1ab_0 118 | - markupsafe=1.1.1=py39h3811e60_3 119 | - matplotlib=3.3.4=py39hf3d152e_0 120 | - matplotlib-base=3.3.4=py39h2fa2bec_0 121 | - mistune=0.8.4=py39h3811e60_1003 122 | - mkl=2020.4=h726a3e6_304 123 | - mkl-devel=2020.4=ha770c72_305 124 | - mkl-include=2020.4=h726a3e6_304 125 | - multidict=5.1.0=py39h3811e60_1 126 | - mysql-common=8.0.23=ha770c72_1 127 | - mysql-libs=8.0.23=h935591d_1 128 | - nbclassic=0.2.6=pyhd8ed1ab_0 129 | - nbclient=0.5.3=pyhd8ed1ab_0 130 | - nbconvert=6.0.7=py39hf3d152e_3 131 | - nbformat=5.1.2=pyhd8ed1ab_1 132 | - ncurses=6.2=h58526e2_4 133 | - nest-asyncio=1.4.3=pyhd8ed1ab_0 134 | - netcdf4=1.5.6=nompi_py39h36800e2_100 135 | - ninja=1.10.2=h4bd325d_0 136 | - notebook=6.3.0=py39hf3d152e_0 137 | - nspr=4.30=h9c3ff4c_0 138 | - nss=3.63=hb5efdd6_0 139 | - numpy=1.20.2=py39hdbf815f_0 140 | - oauthlib=3.0.1=py_0 141 | - olefile=0.46=pyh9f0ad1d_1 142 | - openssl=1.1.1k=h7f98852_0 143 | - packaging=20.9=pyh44b312d_0 144 | - pandas=1.2.3=py39hde0f152_0 145 | - pandoc=2.12=h7f98852_0 146 | - pandocfilters=1.4.2=py_1 147 | - parso=0.8.1=pyhd8ed1ab_0 148 | - pcre=8.44=he1b5a44_0 149 | - pexpect=4.8.0=pyh9f0ad1d_2 150 | - pickleshare=0.7.5=py_1003 151 | - pillow=8.1.2=py39hf95b381_0 152 | - pip=21.0.1=pyhd8ed1ab_0 153 | - prometheus_client=0.9.0=pyhd3deb0d_0 154 | - prompt-toolkit=3.0.18=pyha770c72_0 155 | - protobuf=3.15.6=py39he80948d_0 156 | - pthread-stubs=0.4=h36c2ea0_1001 157 | - ptyprocess=0.7.0=pyhd3deb0d_0 158 | - pyasn1=0.4.8=py_0 159 | - pyasn1-modules=0.2.7=py_0 160 | - pycparser=2.20=pyh9f0ad1d_2 161 | - pygments=2.8.1=pyhd8ed1ab_0 162 | - pyjwt=2.0.1=pyhd8ed1ab_0 163 | - pyopenssl=20.0.1=pyhd8ed1ab_0 164 | - pyparsing=2.4.7=pyh9f0ad1d_0 165 | - pyqt=5.12.3=py39hf3d152e_7 166 | - pyqt-impl=5.12.3=py39h0fcd23e_7 167 | - pyqt5-sip=4.19.18=py39he80948d_7 168 | - pyqtchart=5.12=py39h0fcd23e_7 169 | - pyqtwebengine=5.12.1=py39h0fcd23e_7 170 | - pyrsistent=0.17.3=py39h3811e60_2 171 | - pysocks=1.7.1=py39hf3d152e_3 172 | - python=3.9.2=hffdb5ce_0_cpython 173 | - python-dateutil=2.8.1=py_0 174 | - python_abi=3.9=1_cp39 175 | - pytorch=1.8.1=py3.9_cuda11.1_cudnn8.0.5_0 176 | - pytorch-lightning=1.2.6=pyhd8ed1ab_0 177 | - pytz=2021.1=pyhd8ed1ab_0 178 | - pyyaml=5.4.1=py39h3811e60_0 179 | - pyzmq=22.0.3=py39h37b5a0c_1 180 | - qt=5.12.9=hda022c4_4 181 | - readline=8.0=he28a2e2_2 182 | - requests=2.25.1=pyhd3deb0d_0 183 | - requests-oauthlib=1.3.0=pyh9f0ad1d_0 184 | - rsa=4.7.2=pyh44b312d_0 185 | - send2trash=1.5.0=py_0 186 | - setuptools=49.6.0=py39hf3d152e_3 187 | - six=1.15.0=pyh9f0ad1d_0 188 | - sleef=3.5.1=h7f98852_1 189 | - sniffio=1.2.0=py39hf3d152e_1 190 | - sqlite=3.35.3=h74cdb3f_0 191 | - tensorboard=2.4.1=pyhd8ed1ab_0 192 | - tensorboard-plugin-wit=1.8.0=pyh44b312d_0 193 | - terminado=0.9.4=py39hf3d152e_0 194 | - testpath=0.4.4=py_0 195 | - tk=8.6.10=h21135ba_1 196 | - torchaudio=0.8.1=py39 197 | - torchmetrics=0.2.0=pyhd8ed1ab_0 198 | - torchvision=0.2.2=py_3 199 | - tornado=6.1=py39h3811e60_1 200 | - tqdm=4.59.0=pyhd8ed1ab_0 201 | - traitlets=5.0.5=py_0 202 | - typing-extensions=3.7.4.3=0 203 | - typing_extensions=3.7.4.3=py_0 204 | - tzdata=2021a=he74cb21_0 205 | - urllib3=1.26.4=pyhd8ed1ab_0 206 | - wcwidth=0.2.5=pyh9f0ad1d_2 207 | - webencodings=0.5.1=py_1 208 | - werkzeug=1.0.1=pyh9f0ad1d_0 209 | - wheel=0.36.2=pyhd3deb0d_0 210 | - xarray=0.17.0=pyhd8ed1ab_0 211 | - xorg-libxau=1.0.9=h7f98852_0 212 | - xorg-libxdmcp=1.1.3=h7f98852_0 213 | - xz=5.2.5=h516909a_1 214 | - yaml=0.2.5=h516909a_0 215 | - yarl=1.6.3=py39h3811e60_1 216 | - zeromq=4.3.4=h9c3ff4c_0 217 | - zipp=3.4.1=pyhd8ed1ab_0 218 | - zlib=1.2.11=h516909a_1010 219 | - zstd=1.4.9=ha95c52a_0 220 | -------------------------------------------------------------------------------- /utils/h5shape.py: -------------------------------------------------------------------------------- 1 | # Author: Pedro Herruzo 2 | # Copyright 2021 Institute of Advanced Research in Artificial Intelligence (IARAI) GmbH. 3 | # IARAI licenses this file to You under the Apache License, Version 2.0 4 | # (the "License"); you may not use this file except in compliance with 5 | # the License. You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import numpy as np 16 | import sys, getopt 17 | import h5py 18 | 19 | 20 | def load_test_file(file_path): 21 | """ 22 | Given a file path, loads test file (in h5 format). 23 | Returns: tensor of shape (number_of_test_cases = 5, 3, 3, 496, 435) 24 | """ 25 | # load h5 file 26 | fr = h5py.File(file_path, 'r') 27 | a_group_key = list(fr.keys())[0] 28 | data = list(fr[a_group_key]) 29 | 30 | # get relevant test cases 31 | data = data[0:] 32 | data = np.stack(data,axis=0) 33 | # transpose 34 | return data 35 | 36 | def print_shape(data): 37 | print(data.shape) 38 | 39 | 40 | if __name__ == '__main__': 41 | 42 | # gather command line arguments. 43 | infile = '' 44 | try: 45 | opts, args = getopt.getopt(sys.argv[1:], "hi:", ["infile="]) 46 | except getopt.GetoptError: 47 | print('usage: h5shape -i ') 48 | sys.exit(2) 49 | for opt, arg in opts: 50 | if opt == '-h': 51 | print('usage: h5shape -i ') 52 | sys.exit() 53 | elif opt in ("-i","--infile"): 54 | infile = arg 55 | 56 | data = load_test_file(infile) 57 | print_shape(data) -------------------------------------------------------------------------------- /utils/splits.csv: -------------------------------------------------------------------------------- 1 | ,id_date,split_id,split,id_day,date 2 | 0,2019046,0,training,046,20190215 3 | 1,2019047,2,test,047,20190216 4 | 2,2019048,2,test-next,048,20190217 5 | 3,2019049,0,training,049,20190218 6 | 4,2019050,0,training,050,20190219 7 | 5,2019051,0,training,051,20190220 8 | 6,2019052,0,training,052,20190221 9 | 7,2019053,0,training,053,20190222 10 | 8,2019054,0,training,054,20190223 11 | 9,2019055,0,training,055,20190224 12 | 10,2019056,0,training,056,20190225 13 | 11,2019057,0,training,057,20190226 14 | 12,2019058,0,training,058,20190227 15 | 13,2019059,0,training,059,20190228 16 | 14,2019060,0,training,060,20190301 17 | 15,2019061,0,training,061,20190302 18 | 16,2019062,0,training,062,20190303 19 | 17,2019063,0,training,063,20190304 20 | 18,2019064,0,training,064,20190305 21 | 19,2019065,0,training,065,20190306 22 | 20,2019066,0,training,066,20190307 23 | 21,2019067,0,training,067,20190308 24 | 22,2019068,0,training,068,20190309 25 | 23,2019069,0,training,069,20190310 26 | 24,2019070,0,training,070,20190311 27 | 25,2019071,0,training,071,20190312 28 | 26,2019072,0,training,072,20190313 29 | 27,2019073,2,test,073,20190314 30 | 28,2019074,2,test-next,074,20190315 31 | 29,2019075,0,training,075,20190316 32 | 30,2019076,1,validation,076,20190317 33 | 31,2019077,0,training,077,20190318 34 | 32,2019078,0,training,078,20190319 35 | 33,2019079,0,training,079,20190320 36 | 34,2019080,0,training,080,20190321 37 | 35,2019081,0,training,081,20190322 38 | 36,2019082,2,test,082,20190323 39 | 37,2019083,0,training,083,20190324 40 | 38,2019084,0,training,084,20190325 41 | 39,2019085,0,training,085,20190326 42 | 40,2019086,0,training,086,20190327 43 | 41,2019087,0,training,087,20190328 44 | 42,2019088,0,training,088,20190329 45 | 43,2019089,1,validation,089,20190330 46 | 44,2019090,0,training,090,20190331 47 | 45,2019091,0,training,091,20190401 48 | 46,2019092,2,test,092,20190402 49 | 47,2019093,0,training,093,20190403 50 | 48,2019094,0,training,094,20190404 51 | 49,2019095,0,training,095,20190405 52 | 50,2019096,0,training,096,20190406 53 | 51,2019097,0,training,097,20190407 54 | 52,2019098,2,test,098,20190408 55 | 53,2019099,0,training,099,20190409 56 | 54,2019100,0,training,100,20190410 57 | 55,2019101,1,validation,101,20190411 58 | 56,2019102,1,validation,102,20190412 59 | 57,2019103,0,training,103,20190413 60 | 58,2019104,2,test,104,20190414 61 | 59,2019105,2,test-next,105,20190415 62 | 60,2019106,0,training,106,20190416 63 | 61,2019107,0,training,107,20190417 64 | 62,2019108,0,training,108,20190418 65 | 63,2019109,0,training,109,20190419 66 | 64,2019110,2,test,110,20190420 67 | 65,2019111,0,training,111,20190421 68 | 66,2019112,0,training,112,20190422 69 | 67,2019113,0,training,113,20190423 70 | 68,2019114,2,test,114,20190424 71 | 69,2019115,0,training,115,20190425 72 | 70,2019116,0,training,116,20190426 73 | 71,2019117,0,training,117,20190427 74 | 72,2019118,0,training,118,20190428 75 | 73,2019119,0,training,119,20190429 76 | 74,2019120,1,validation,120,20190430 77 | 75,2019121,0,training,121,20190501 78 | 76,2019122,0,training,122,20190502 79 | 77,2019123,0,training,123,20190503 80 | 78,2019124,0,training,124,20190504 81 | 79,2019125,0,training,125,20190505 82 | 80,2019126,1,validation,126,20190506 83 | 81,2019127,0,training,127,20190507 84 | 82,2019128,0,training,128,20190508 85 | 83,2019129,0,training,129,20190509 86 | 84,2019130,0,training,130,20190510 87 | 85,2019131,0,training,131,20190511 88 | 86,2019132,0,training,132,20190512 89 | 87,2019133,0,training,133,20190513 90 | 88,2019134,0,training,134,20190514 91 | 89,2019135,0,training,135,20190515 92 | 90,2019136,0,training,136,20190516 93 | 91,2019137,0,training,137,20190517 94 | 92,2019138,0,training,138,20190518 95 | 93,2019139,1,validation,139,20190519 96 | 94,2019140,0,training,140,20190520 97 | 95,2019141,0,training,141,20190521 98 | 96,2019142,0,training,142,20190522 99 | 97,2019143,0,training,143,20190523 100 | 98,2019144,0,training,144,20190524 101 | 99,2019145,2,test,145,20190525 102 | 100,2019146,0,training,146,20190526 103 | 101,2019147,2,test,147,20190527 104 | 102,2019148,2,test-next,148,20190528 105 | 103,2019149,0,training,149,20190529 106 | 104,2019150,1,validation,150,20190530 107 | 105,2019151,0,training,151,20190531 108 | 106,2019152,0,training,152,20190601 109 | 107,2019153,0,training,153,20190602 110 | 108,2019154,0,training,154,20190603 111 | 109,2019155,0,training,155,20190604 112 | 110,2019156,0,training,156,20190605 113 | 111,2019157,0,training,157,20190606 114 | 112,2019158,0,training,158,20190607 115 | 113,2019159,0,training,159,20190608 116 | 114,2019160,0,training,160,20190609 117 | 115,2019161,0,training,161,20190610 118 | 116,2019162,0,training,162,20190611 119 | 117,2019163,1,validation,163,20190612 120 | 118,2019164,0,training,164,20190613 121 | 119,2019165,0,training,165,20190614 122 | 120,2019166,0,training,166,20190615 123 | 121,2019167,1,validation,167,20190616 124 | 122,2019168,0,training,168,20190617 125 | 123,2019169,0,training,169,20190618 126 | 124,2019170,0,training,170,20190619 127 | 125,2019171,0,training,171,20190620 128 | 126,2019172,0,training,172,20190621 129 | 127,2019173,0,training,173,20190622 130 | 128,2019174,2,test,174,20190623 131 | 129,2019175,0,training,175,20190624 132 | 130,2019176,0,training,176,20190625 133 | 131,2019177,0,training,177,20190626 134 | 132,2019178,2,test,178,20190627 135 | 133,2019179,0,training,179,20190628 136 | 134,2019180,0,training,180,20190629 137 | 135,2019181,0,training,181,20190630 138 | 136,2019182,0,training,182,20190701 139 | 137,2019183,0,training,183,20190702 140 | 138,2019184,0,training,184,20190703 141 | 139,2019185,0,training,185,20190704 142 | 140,2019186,0,training,186,20190705 143 | 141,2019187,0,training,187,20190706 144 | 142,2019188,2,test,188,20190707 145 | 143,2019189,1,validation,189,20190708 146 | 144,2019190,0,training,190,20190709 147 | 145,2019191,0,training,191,20190710 148 | 146,2019192,0,training,192,20190711 149 | 147,2019193,1,validation,193,20190712 150 | 148,2019194,1,validation,194,20190713 151 | 149,2019195,0,training,195,20190714 152 | 150,2019196,0,training,196,20190715 153 | 151,2019197,0,training,197,20190716 154 | 152,2019198,0,training,198,20190717 155 | 153,2019199,0,training,199,20190718 156 | 154,2019200,0,training,200,20190719 157 | 155,2019201,0,training,201,20190720 158 | 156,2019202,0,training,202,20190721 159 | 157,2019203,0,training,203,20190722 160 | 158,2019204,0,training,204,20190723 161 | 159,2019205,1,validation,205,20190724 162 | 160,2019206,0,training,206,20190725 163 | 161,2019207,0,training,207,20190726 164 | 162,2019208,0,training,208,20190727 165 | 163,2019209,2,test,209,20190728 166 | 164,2019210,2,test-next,210,20190729 167 | 165,2019211,0,training,211,20190730 168 | 166,2019212,0,training,212,20190731 169 | 167,2019213,1,validation,213,20190801 170 | 168,2019214,0,training,214,20190802 171 | 169,2019215,0,training,215,20190803 172 | 170,2019216,0,training,216,20190804 173 | 171,2019217,0,training,217,20190805 174 | 172,2019218,0,training,218,20190806 175 | 173,2019219,0,training,219,20190807 176 | 174,2019220,0,training,220,20190808 177 | 175,2019221,1,validation,221,20190809 178 | 176,2019222,0,training,222,20190810 179 | 177,2019223,0,training,223,20190811 180 | 178,2019224,0,training,224,20190812 181 | 179,2019225,0,training,225,20190813 182 | 180,2019226,0,training,226,20190814 183 | 181,2019227,0,training,227,20190815 184 | 182,2019228,0,training,228,20190816 185 | 183,2019229,0,training,229,20190817 186 | 184,2019230,0,training,230,20190818 187 | 185,2019231,0,training,231,20190819 188 | 186,2019232,0,training,232,20190820 189 | 187,2019233,1,validation,233,20190821 190 | 188,2019234,0,training,234,20190822 191 | 189,2019235,0,training,235,20190823 192 | 190,2019236,2,test,236,20190824 193 | 191,2019237,2,test-next,237,20190825 194 | 192,2019238,1,validation,238,20190826 195 | 193,2019239,0,training,239,20190827 196 | 194,2019240,1,validation,240,20190828 197 | 195,2019241,0,training,241,20190829 198 | 196,2019242,0,training,242,20190830 199 | 197,2019243,2,test,243,20190831 200 | 198,2019244,1,validation,244,20190901 201 | 199,2019245,0,training,245,20190902 202 | 200,2019246,2,test,246,20190903 203 | 201,2019247,2,test-next,247,20190904 204 | 202,2019248,0,training,248,20190905 205 | 203,2019249,0,training,249,20190906 206 | 204,2019250,0,training,250,20190907 207 | 205,2019251,0,training,251,20190908 208 | 206,2019252,0,training,252,20190909 209 | 207,2019253,0,training,253,20190910 210 | 208,2019254,0,training,254,20190911 211 | 209,2019255,0,training,255,20190912 212 | 210,2019256,1,validation,256,20190913 213 | 211,2019257,0,training,257,20190914 214 | 212,2019258,1,validation,258,20190915 215 | 213,2019259,1,validation,259,20190916 216 | 214,2019260,0,training,260,20190917 217 | 215,2019261,0,training,261,20190918 218 | 216,2019262,0,training,262,20190919 219 | 217,2019263,0,training,263,20190920 220 | 218,2019264,0,training,264,20190921 221 | 219,2019265,0,training,265,20190922 222 | 220,2019266,0,training,266,20190923 223 | 221,2019267,0,training,267,20190924 224 | 222,2019268,0,training,268,20190925 225 | 223,2019269,2,test,269,20190926 226 | 224,2019270,2,test-next,270,20190927 227 | 225,2019271,0,training,271,20190928 228 | 226,2019272,0,training,272,20190929 229 | 227,2019273,2,test,273,20190930 230 | 228,2019274,0,training,274,20191001 231 | 229,2019275,0,training,275,20191002 232 | 230,2019276,0,training,276,20191003 233 | 231,2019277,0,training,277,20191004 234 | 232,2019278,1,validation,278,20191005 235 | 233,2019279,1,validation,279,20191006 236 | 234,2019280,0,training,280,20191007 237 | 235,2019281,0,training,281,20191008 238 | 236,2019282,0,training,282,20191009 239 | 237,2019283,1,validation,283,20191010 240 | 238,2019284,0,training,284,20191011 241 | 239,2019285,2,test,285,20191012 242 | 240,2019286,2,test-next,286,20191013 243 | 241,2019287,1,validation,287,20191014 244 | 242,2019288,0,training,288,20191015 245 | 243,2019289,0,training,289,20191016 246 | 244,2019290,0,training,290,20191017 247 | 245,2019291,0,training,291,20191018 248 | 246,2019292,0,training,292,20191019 249 | 247,2019293,0,training,293,20191020 250 | 248,2019294,0,training,294,20191021 251 | 249,2019295,0,training,295,20191022 252 | 250,2019296,1,validation,296,20191023 253 | 251,2019297,0,training,297,20191024 254 | 252,2019298,0,training,298,20191025 255 | 253,2019299,1,validation,299,20191026 256 | 254,2019300,0,training,300,20191027 257 | 255,2019301,0,training,301,20191028 258 | 256,2019302,0,training,302,20191029 259 | 257,2019303,0,training,303,20191030 260 | 258,2019304,0,training,304,20191031 261 | 259,2019305,0,training,305,20191101 262 | 260,2019306,2,test,306,20191102 263 | 261,2019307,2,test-next,307,20191103 264 | 262,2019308,0,training,308,20191104 265 | 263,2019309,0,training,309,20191105 266 | 264,2019310,2,test,310,20191106 267 | 265,2019311,0,training,311,20191107 268 | 266,2019312,0,training,312,20191108 269 | 267,2019313,0,training,313,20191109 270 | 268,2019314,0,training,314,20191110 271 | 269,2019315,0,training,315,20191111 272 | 270,2019316,0,training,316,20191112 273 | 271,2019317,0,training,317,20191113 274 | 272,2019318,0,training,318,20191114 275 | 273,2019319,2,test,319,20191115 276 | 274,2019320,0,training,320,20191116 277 | 275,2019321,0,training,321,20191117 278 | 276,2019322,0,training,322,20191118 279 | 277,2019323,1,validation,323,20191119 280 | 278,2019324,2,test,324,20191120 281 | 279,2019325,0,training,325,20191121 282 | 280,2019326,0,training,326,20191122 283 | 281,2019327,0,training,327,20191123 284 | 282,2019328,0,training,328,20191124 285 | 283,2019329,0,training,329,20191125 286 | 284,2019330,0,training,330,20191126 287 | 285,2019331,0,training,331,20191127 288 | 286,2019332,0,training,332,20191128 289 | 287,2019333,0,training,333,20191129 290 | 288,2019334,0,training,334,20191130 291 | 289,2019335,0,training,335,20191201 292 | 290,2019336,1,validation,336,20191202 293 | 291,2019337,0,training,337,20191203 294 | 292,2019338,1,validation,338,20191204 295 | 293,2019339,0,training,339,20191205 296 | 294,2019340,0,training,340,20191206 297 | 295,2019341,0,training,341,20191207 298 | 296,2019342,0,training,342,20191208 299 | 297,2019343,0,training,343,20191209 300 | 298,2019344,0,training,344,20191210 301 | 299,2019345,0,training,345,20191211 302 | 300,2019346,0,training,346,20191212 303 | 301,2019347,0,training,347,20191213 304 | 302,2019348,0,training,348,20191214 305 | 303,2019349,2,test,349,20191215 306 | 304,2019350,0,training,350,20191216 307 | 305,2019351,0,training,351,20191217 308 | 306,2019352,0,training,352,20191218 309 | 307,2019353,2,test,353,20191219 310 | 308,2019354,0,training,354,20191220 311 | 309,2019355,0,training,355,20191221 312 | 310,2019356,0,training,356,20191222 313 | 311,2019357,0,training,357,20191223 314 | 312,2019358,0,training,358,20191224 315 | 313,2019359,1,validation,359,20191225 316 | 314,2019360,0,training,360,20191226 317 | 315,2019361,0,training,361,20191227 318 | 316,2019362,0,training,362,20191228 319 | 317,2019363,2,test,363,20191229 320 | 318,2019364,2,test-next,364,20191230 321 | 319,2019365,2,test,365,20191231 322 | 320,2020001,0,training,001,20200101 323 | 321,2020002,0,training,002,20200102 324 | 322,2020003,0,training,003,20200103 325 | 323,2020004,0,training,004,20200104 326 | 324,2020005,0,training,005,20200105 327 | 325,2020006,1,validation,006,20200106 328 | 326,2020007,0,training,007,20200107 329 | 327,2020008,0,training,008,20200108 330 | 328,2020009,2,test,009,20200109 331 | 329,2020010,0,training,010,20200110 332 | 330,2020011,2,test,011,20200111 333 | 331,2020012,0,training,012,20200112 334 | 332,2020013,0,training,013,20200113 335 | 333,2020014,0,training,014,20200114 336 | 334,2020015,0,training,015,20200115 337 | 335,2020016,0,training,016,20200116 338 | 336,2020017,0,training,017,20200117 339 | 337,2020018,0,training,018,20200118 340 | 338,2020019,0,training,019,20200119 341 | 339,2020020,2,test,020,20200120 342 | 340,2020021,0,training,021,20200121 343 | 341,2020022,0,training,022,20200122 344 | 342,2020023,2,test,023,20200123 345 | 343,2020024,0,training,024,20200124 346 | 344,2020025,0,training,025,20200125 347 | 345,2020026,0,training,026,20200126 348 | 346,2020027,0,training,027,20200127 349 | 347,2020028,0,training,028,20200128 350 | 348,2020029,0,training,029,20200129 351 | 349,2020030,2,test,030,20200130 352 | 350,2020031,0,training,031,20200131 353 | 351,2020032,0,training,032,20200201 354 | 352,2020033,0,training,033,20200202 355 | 353,2020034,2,test,034,20200203 356 | 354,2020035,2,test-next,035,20200204 357 | 355,2020036,0,training,036,20200205 358 | 356,2020037,0,training,037,20200206 359 | 357,2020038,2,test,038,20200207 360 | 358,2020039,2,test-next,039,20200208 361 | 359,2020040,0,training,040,20200209 362 | 360,2020041,0,training,041,20200210 363 | 361,2020042,2,test,042,20200211 364 | 362,2020043,0,training,043,20200212 365 | 363,2020044,0,training,044,20200213 366 | 364,2020045,0,training,045,20200214 -------------------------------------------------------------------------------- /utils/w4c_dataloader.py: -------------------------------------------------------------------------------- 1 | # Author: Pedro Herruzo 2 | # Copyright 2021 Institute of Advanced Research in Artificial Intelligence (IARAI) GmbH. 3 | # IARAI licenses this file to You under the Apache License, Version 2.0 4 | # (the "License"); you may not use this file except in compliance with 5 | # the License. You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import numpy as np 16 | from torch.utils.data import Dataset 17 | import utils.data_utils as data_utils 18 | from utils.context_variables import get_static 19 | import os 20 | 21 | class NWCSAF(Dataset): 22 | 23 | def __init__(self, data_split, products, input_vars, target_vars, 24 | spatial_dim, collapse_time=True, 25 | len_seq_in=4, len_seq_out=32, bins_to_predict=32, day_bins=96, 26 | region_id=None, preprocess=None, 27 | crop_in=None, crop_out=None, 28 | extra_data='', crop_static=None, static_paths=None, 29 | data_path='', control_params=None, 30 | train_splits='splits.csv', 31 | test_splits='test_split.json', 32 | black_list_path='blacklist.json', precision=16, populate_mask=False, **kwargs): 33 | self.precision = {16: np.float16, 32: np.float32}[precision] 34 | self.channel_dim = 1 # specifies the dimension to concat multiple channels/variables 35 | 36 | # data dimensions 37 | self.spatial_dim = spatial_dim 38 | self.collapse_time = collapse_time 39 | self.len_seq_in = len_seq_in 40 | self.len_seq_out = len_seq_out 41 | self.bins_to_predict = bins_to_predict 42 | self.day_bins = day_bins 43 | self.day_strings = ['{}{}{}{}00'.format('0'*bool(i<10), i, '0'*bool(j<10), j) for i in np.arange(0, 24, 1) for j in np.arange(0, 60, 15)] 44 | 45 | # type of data & processing variables 46 | self.products = products 47 | self.input_vars = input_vars 48 | self.target_vars = target_vars 49 | self.region_id = region_id 50 | self.preprocess = preprocess 51 | self.populate_mask = populate_mask 52 | self.crop_in, self.crop_out = crop_in, crop_out 53 | self.control_params = control_params 54 | 55 | # load extra variables if any 56 | self.extra_data, self.static_tensor, self.static_desc = [], [], [] 57 | if extra_data != '': 58 | self.extra_data = extra_data.split('-') 59 | self.static_tensor, self.static_desc = get_static(self.extra_data, self.len_seq_in, static_paths, 60 | crop=crop_static, channel_dim=self.channel_dim) 61 | 62 | # data splits to load (training/validation/test) 63 | self.data_path = data_path + f'/{data_split}' 64 | self.data_split = data_split 65 | self.day_paths, self.test_splits = data_utils.read_splits(train_splits, test_splits) 66 | 67 | 68 | # prepare all elements to load - batch idx will use the object 'self.idx' 69 | if self.data_split != 'test': 70 | self.day_paths = self.day_paths[self.day_paths.split == self.data_split].reset_index() 71 | # self.idxs = data_utils.get_triple_idxs_w_blacklist(self.day_paths['id_date'].values, self.bins_to_predict, 72 | # self.day_bins, self.len_seq_in, 73 | # black_list_path=black_list_path) 74 | self.idxs = data_utils.get_double_idxs_w_blacklist(self.day_paths['id_date'].values, self.bins_to_predict, 75 | self.day_bins, self.len_seq_in, 76 | black_list_path=black_list_path) 77 | else: 78 | test_dates = self.day_paths[self.day_paths.split==self.data_split].reset_index() 79 | # self.idxs = data_utils.get_test_triplets(test_dates['id_date'].sort_values().values, 80 | # self.test_splits, 81 | # self.bins_to_predict) 82 | self.idxs = data_utils.get_test_doubles(test_dates['id_date'].sort_values().values, 83 | self.test_splits, 84 | self.bins_to_predict) 85 | self.day_paths = self.day_paths[self.day_paths.split.isin(['test', 'test-next'])].reset_index() 86 | 87 | 88 | 89 | def __len__(self): 90 | """ total number of samples (sequences of in:4-out:1 in our case) to train """ 91 | return len(self.idxs) 92 | 93 | def load_in_seq(self, day_id, in_start_id): # , lead_time): 94 | """ load the input sequence """ 95 | 96 | # 1. load nwcsaf products & metadata 97 | in_seq, in_info = data_utils.get_sequence_netcdf4(self.len_seq_in, in_start_id, day_id, 98 | self.products, self.data_path, self.input_vars, 99 | # self.target_vars, 100 | crop=self.crop_in, preprocess=self.preprocess['source'], 101 | day_bins=self.day_bins, 102 | sorted_dates=self.day_paths.id_date.sort_values().values, 103 | populate_mask=self.populate_mask) 104 | 105 | # 2. Load extra features 106 | if len(self.static_tensor) != 0: # 2.1 static features 107 | in_seq = np.concatenate((in_seq, self.static_tensor), axis=self.channel_dim) 108 | in_info['channels'] += self.static_desc 109 | 110 | # # 3. Load lead time to predict and normalize it 111 | # data = np.ones(shape=(self.len_seq_in, 1, self.spatial_dim[0], self.spatial_dim[1])) 112 | # data[...] = (lead_time+1)/self.bins_to_predict 113 | # in_seq = np.concatenate((in_seq, data), axis=self.channel_dim) 114 | # in_info['channels'] += ['lead_time'] 115 | 116 | # 3. Load time_slot 117 | if self.control_params['use_time_slot']: 118 | data = np.stack([np.ones(shape=(1, *self.spatial_dim)) * ((i + in_start_id) % self.day_bins) for i in 119 | range(self.len_seq_in)]) / self.day_bins 120 | in_seq = np.concatenate((in_seq, data), axis=self.channel_dim) 121 | in_info['channels'] += ['time_slot'] 122 | 123 | 124 | 125 | # 4. Collapse time if needed and set the appropriate data type for learning 126 | if self.collapse_time: 127 | in_seq = data_utils.time_2_channels(in_seq, *self.spatial_dim) 128 | 129 | in_seq = in_seq.astype(self.precision) # np.float16) # np.float32) 130 | 131 | return in_seq, in_info 132 | 133 | def load_in_out(self, day_id, in_start_id): #, lead_time): 134 | """ load input/output data """ 135 | 136 | # load input sequence 137 | in_seq, in_info = self.load_in_seq(day_id, in_start_id) # , lead_time) 138 | 139 | # load ground truth 140 | if self.data_split != 'test': 141 | target_time = in_start_id + self.len_seq_in # + lead_time 142 | # out, masks, channels = data_utils.get_products_netcdf4(day_id, self.day_strings[target_time], 143 | # self.products, self.data_path, self.target_vars, 144 | # self.crop_out, self.preprocess['source']) 145 | 146 | out, out_info = data_utils.get_sequence_netcdf4(self.len_seq_out, target_time, day_id, 147 | self.products, self.data_path, self.target_vars, 148 | crop=self.crop_out, preprocess=self.preprocess['source'], 149 | day_bins=self.day_bins, 150 | sorted_dates=self.day_paths.id_date.sort_values().values, 151 | populate_mask=self.populate_mask) 152 | if self.collapse_time: 153 | out = data_utils.time_2_channels(out, *self.spatial_dim) 154 | 155 | metadata = {'in': in_info, 156 | 'out': {'day_in_year': [day_id], # 'lead_time': [lead_time], 157 | 'time_bins': [target_time], 'region_id': self.region_id}} 158 | # 'masks': out_info['masks']}} 159 | if self.populate_mask: 160 | metadata['out']['masks'] = out_info['masks'] 161 | else: 162 | out = np.asarray([]) # we don't have the ground truth for the test split 163 | metadata = {'in': in_info, 164 | 'out': {'day_in_year': [day_id], # 'lead_time': [lead_time] 165 | 'region_id': self.region_id}} 166 | out = out.astype(self.precision) # np.float16) # np.float32) 167 | 168 | return in_seq, out, metadata 169 | 170 | def __getitem__(self, idx): 171 | """ load 1 sequence (1 sample) """ 172 | # day_id, in_start_id, lead_time = self.idxs[idx] 173 | # return self.load_in_out(day_id, in_start_id, lead_time) 174 | day_id, in_start_id = self.idxs[idx] 175 | if 'heldout' in self.data_path: 176 | day_id += 1000 177 | return self.load_in_out(day_id, in_start_id) 178 | 179 | def get_date(self, id_day): 180 | """ get date from day_in_year id """ 181 | return str(self.day_paths[self.day_paths.id_date==id_day]['date'].values[0]) 182 | 183 | def geti(self, idx=0): 184 | """ this function allows you to get 1 sample for debugging 185 | Note that the batch dimension is missing, so it is added 186 | 187 | example: 188 | ds = create_dataset(data_split, params) 189 | in_seq, out, metadata = ds.geti(0) 190 | """ 191 | in_seq, out, metadata = self.__getitem__(idx) 192 | in_seq = np.expand_dims(in_seq, axis=0) 193 | out = np.expand_dims(out, axis=0) 194 | metadata = np.expand_dims(metadata, axis=0) 195 | 196 | return in_seq, out, metadata 197 | 198 | def create_dataset(data_split, params, precision=16, populate_mask=False): 199 | return NWCSAF(data_split, precision=precision, populate_mask=populate_mask, **params) -------------------------------------------------------------------------------- /validation_metrics.py: -------------------------------------------------------------------------------- 1 | # Author: Pedro Herruzo 2 | # Copyright 2021 Institute of Advanced Research in Artificial Intelligence (IARAI) GmbH. 3 | # IARAI licenses this file to You under the Apache License, Version 2.0 4 | # (the "License"); you may not use this file except in compliance with 5 | # the License. You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import numpy as np 16 | 17 | class LeadTimeEval(): 18 | """ This class helps to evaluate how a model performs across the prediction time horizon. 19 | It will save the metrics per time lead predicted and create a plot with all of them. 20 | """ 21 | 22 | def __init__(self, len_seq_in=4, bins_to_predict=32, n_channels=3): 23 | 24 | self.len_seq = len_seq_in 25 | self.n_bins = bins_to_predict 26 | self.n_channels = n_channels 27 | self.errors = {} 28 | 29 | self.index = ['day_in_year', 'in_start_id', 'channel'] 30 | self.cols = self.index + [j for j in range(self.n_bins)] 31 | 32 | def get_numpy(self, x): 33 | return x.detach().cpu().numpy() 34 | 35 | def update_errors(self, err, metadata): 36 | """ Updates errors per channel for a particular 'date', 'starting time bin' and the 'lead time' predicted 37 | 38 | Note: 39 | - err.shape = (batch_size, channels_size) 40 | - metadata has to contain (day_in_year, lead_time, time_bins) 41 | """ 42 | 43 | days = self.get_numpy(metadata['out']['day_in_year'][0]) 44 | lead_times = self.get_numpy(metadata['out']['lead_time'][0]) 45 | target_times = self.get_numpy(metadata['out']['time_bins'][0]) 46 | 47 | j = 0 48 | for d, lead_t, tgt_t, e in zip(days, lead_times, target_times, err): 49 | start_t = tgt_t - self.len_seq - lead_t 50 | #str_print = f'{j}- day={d}, lead_t={lead_t}, tgt_t={tgt_t}, start_t={start_t}, err={e}' 51 | #print(str_print) 52 | 53 | if d not in self.errors.keys(): 54 | self.errors[d] = {} 55 | 56 | if start_t not in self.errors[d].keys(): 57 | self.errors[d][start_t] = {} 58 | 59 | if lead_t not in self.errors[d][start_t].keys(): 60 | self.errors[d][start_t][lead_t] = e # e.shape = channels_size 61 | else: 62 | print(f"Error, this lead_time={lead_t} was already updated in day={d} start_t={start_t}") 63 | 64 | j += 1 65 | 66 | def __update_channel_errors(self, errors, row_channels): 67 | """ Updates the errors per channel 68 | 69 | Args: 70 | errors (dict): errros of each variable 71 | row_channels (list): placeholder to save the errors per variable 72 | 73 | Returns: 74 | list: filled placeholder to save the errors per variable 75 | """ 76 | for id_chn in range(self.n_channels): 77 | row_channels[id_chn].append(errors[id_chn]) 78 | return row_channels 79 | 80 | def __get_lead_time_array(self, data, n_bins): 81 | """ builds a list of lists containing all predictions per 'date', 'starting time bin', 'variable', and 'lead time' 82 | """ 83 | rows = [] 84 | 85 | for id_date in data.keys(): 86 | for id_start in data[id_date].keys(): 87 | row_channels = [] 88 | for id_chn in range(self.n_channels): 89 | row_channels.append([id_date, id_start, id_chn]) 90 | 91 | for j in range(n_bins): 92 | # if the lead time is informed 93 | if j in data[id_date][id_start].keys(): 94 | errors = data[id_date][id_start][j] 95 | 96 | row_channels = self.__update_channel_errors(errors, row_channels) 97 | else: 98 | row_channels = self.__update_channel_errors([np.NaN]*self.n_channels, row_channels) 99 | 100 | for row in row_channels: 101 | rows.append(row) 102 | 103 | return rows 104 | 105 | def get_lead_time_errors_df(self): 106 | """ Builds a spreadsheet containing all 32 predictions for all 'dates', 'starting time bins', and 'varriables' 107 | 108 | Returns: 109 | pandas.DataFrame: each rows contains the 32 predictions for a particular 'date' and 'starting time bin' 110 | """ 111 | import pandas as pd 112 | 113 | rows = self.__get_lead_time_array(self.errors, self.n_bins) 114 | 115 | df = pd.DataFrame(rows, columns=self.cols) 116 | df = df.set_index(self.index).sort_index() 117 | 118 | return df 119 | 120 | def get_lead_time_metrics(self, root, title, region='', y_label='mse', x_label='lead times'): 121 | """ creates a plot for the prediction horizon 122 | 123 | Args: 124 | root (str): path to save the plots 125 | title (str): title of the plot 126 | region (str, optional): Region where the errors belong to. Defaults to ''. 127 | 128 | Returns: 129 | list: errors, standard deviations 130 | """ 131 | import matplotlib.pyplot as plt 132 | fname = f'{root}/lead_times_mse_{region}.csv' 133 | fname_fig = f'{root}/lead_times_mse_fig_{region}' 134 | 135 | df = self.get_lead_time_errors_df() 136 | df.to_csv(fname, encoding='utf-8') 137 | print("saved errors to disk:", fname) 138 | 139 | errs, std = df.mean(), df.std() 140 | 141 | fig = plt.figure(figsize=(20,10)) 142 | plt.errorbar(np.arange(len(errs)), errs, std, fmt='ok', lw=3) 143 | plt.ylabel(y_label) 144 | plt.xlabel(x_label) 145 | plt.xticks(np.arange(self.n_bins), np.arange(self.n_bins)) 146 | plt.title(title) 147 | 148 | fig.savefig(fname_fig) 149 | plt.show() 150 | plt.close(fig) 151 | 152 | return list(errs), list(std) 153 | 154 | 155 | --------------------------------------------------------------------------------