├── assets ├── frame.png └── framework.png ├── data ├── stats.csv └── README.md ├── utils ├── logging_utils.py ├── YParams.py └── data_loader_multifiles.py ├── LICENSE ├── config └── experiment.yaml ├── SECURITY.md ├── README.md ├── .gitignore ├── inference.py ├── train.py └── models └── encdec.py /assets/frame.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/ADAF/main/assets/frame.png -------------------------------------------------------------------------------- /assets/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/ADAF/main/assets/framework.png -------------------------------------------------------------------------------- /data/stats.csv: -------------------------------------------------------------------------------- 1 | ,variable,max,min 2 | 0,hrrr_q,0.025,0.0 3 | 1,hrrr_sp,1050,600 4 | 2,hrrr_u_10,25,-25 5 | 3,hrrr_v_10,25,-25 6 | 4,hrrr_t,50,-40 7 | 5,rtma_t,50,-40 8 | 6,rtma_q,0.025,0.0 9 | 7,rtma_u10,25,-25 10 | 8,rtma_v10,25,-25 11 | 9,rtma_sp,1050,600 12 | 10,z,3190,-65 13 | 11,sta_t,50,-40 14 | 12,sta_q,0.025,0.0 15 | 13,sta_u10,25,-25 16 | 14,sta_v10,25,-25 17 | 15,sta_p,1200,600 18 | -------------------------------------------------------------------------------- /utils/logging_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | 4 | _format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" 5 | 6 | 7 | def config_logger(log_level=logging.INFO): 8 | logging.basicConfig(format=_format, level=log_level) 9 | 10 | 11 | def log_to_file( 12 | logger_name=None, 13 | log_level=logging.INFO, 14 | log_filename="tensorflow.log" 15 | ): 16 | if not os.path.exists(os.path.dirname(log_filename)): 17 | os.makedirs(os.path.dirname(log_filename)) 18 | 19 | if logger_name is not None: 20 | log = logging.getLogger(logger_name) 21 | else: 22 | log = logging.getLogger() 23 | 24 | fh = logging.FileHandler(log_filename) 25 | fh.setLevel(log_level) 26 | fh.setFormatter(logging.Formatter(_format)) 27 | log.addHandler(fh) 28 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | Variables included in a sample file: 2 | 3 | | Variable | Description | Dimension | 4 | | ----------- | ----------- | ----------- | 5 | | z | Topography, normalized | [lat, lon] | 6 | | rtma_t | T2M from RTMA, normalized | [lat, lon] | 7 | | rtma_q | Q from RTMA, normalized | [lat, lon] | 8 | | rtma_u10 | U10 from RTMA, normalized | [lat, lon] | 9 | | rtma_v10 | V10 from RTMA, normalized | [lat, lon] | 10 | | sta_t | T2M from station observations, 0 means no station, normalized | [obs_time_window, lat, lon] | 11 | | sta_q | Q from station observations, 0 means no station, normalized | [obs_time_window, lat, lon] | 12 | | sta_u10 | U10 from station observations, 0 means no station, normalized | [obs_time_window, lat, lon] | 13 | | sta_v10 | V10 from station observations, 0 means no station, normalized | [obs_time_window, lat, lon] | 14 | | CMI02 | ABI Band 2: visible (red), normalized | [obs_time_window, lat, lon] | 15 | | CMI07 | ABI Band 7: shortwave infrared, normalized | [obs_time_window, lat, lon] | 16 | | CMI10 | ABI Band 10: low-level water vapor, normalized | [obs_time_window, lat, lon] | 17 | | CMI14 | ABI Band 14: longwave infrared, normalized | [obs_time_window, lat, lon] | 18 | | hrrr_t | T2M from HRRR 1-hour forecast | [lat, lon] | 19 | | hrrr_q | Q from HRRR 1-hour forecast | [lat, lon] | 20 | | hrrr_u_10 | U10 from HRRR 1-hour forecast | [lat, lon] | 21 | | hrrr_v_10 | V10 from HRRR 1-hour forecast | [lat, lon] | 22 | -------------------------------------------------------------------------------- /utils/YParams.py: -------------------------------------------------------------------------------- 1 | # import os 2 | # import sys 3 | import logging 4 | from ruamel.yaml import YAML 5 | 6 | 7 | class YParams: 8 | """Yaml file parser""" 9 | 10 | def __init__(self, yaml_filename, config_name, print_params=False): 11 | self._yaml_filename = yaml_filename 12 | self._config_name = config_name 13 | self.params = {} 14 | 15 | # if print_params: 16 | # print(os.system('hostname')) 17 | # print("Configuration:", yaml_filename) 18 | 19 | with open(yaml_filename, "rb") as _file: 20 | yaml = YAML().load(_file) 21 | for key, val in yaml[config_name].items(): 22 | if print_params: 23 | print(key, val) 24 | if val == "None": 25 | val = None 26 | 27 | self.params[key] = val 28 | self.__setattr__(key, val) 29 | 30 | def __getitem__(self, key): 31 | return self.params[key] 32 | 33 | def __setitem__(self, key, val): 34 | self.params[key] = val 35 | self.__setattr__(key, val) 36 | 37 | def __contains__(self, key): 38 | return key in self.params 39 | 40 | def update_params(self, config): 41 | for key, val in config.items(): 42 | self.params[key] = val 43 | self.__setattr__(key, val) 44 | 45 | def log(self): 46 | logging.info("------------------ Configuration ------------------") 47 | logging.info("Configuration file: " + str(self._yaml_filename)) 48 | logging.info("Configuration name: " + str(self._config_name)) 49 | for key, val in self.params.items(): 50 |
logging.info(str(key) + ": " + str(val)) 51 | logging.info("---------------------------------------------------") 52 | -------------------------------------------------------------------------------- /config/experiment.yaml: -------------------------------------------------------------------------------- 1 | ### base config ### 2 | # -*- coding: utf-8 -*- 3 | full_field: &FULL_FIELD 4 | lr: 1E-3 # 1e-4 cause loss nan for VA 5 | max_epochs: 1200 6 | valid_frequency: 5 7 | 8 | optimizer_type: "AdamW" # Adam, FusedAdam, SWA 9 | 10 | scheduler: "ReduceLROnPlateau" # ReduceLROnPlateau, MultiplicativeLR 11 | lr_reduce_factor: 0.65 12 | 13 | num_data_workers: 8 # 0 14 | # gridtype: 'sinusoidal' # options 'sinusoidal' or 'linear' 15 | enable_nhwc: !!bool False 16 | 17 | # directory path to store training checkpoints and other output 18 | exp_dir: "./exp" 19 | 20 | # directory path to store dataset for train, valid, and test 21 | data_path: "./data/" 22 | train_data_path: "./data/train" 23 | valid_data_path: "./data/valid" 24 | test_data_path: "./data/test" 25 | 26 | # normalization 27 | norm_type: "variable_wise_ignore_extreme" # options: channel_wise, variable_wise, variable_wise_ignore_extreme 28 | normalization: "minmax_ignore_extreme" # options: minmax, zscore, minmax_ignore_extreme, scale 29 | 30 | add_noise: False 31 | 32 | N_in_channels: 21 33 | N_out_channels: 5 34 | 35 | bg_ensemble_num: 1 # 3 36 | obs_time_window: 3 # if 3, use observation at analysis time 37 | 38 | inp_hrrr_vars: ['hrrr_q', 'hrrr_t', 'hrrr_u_10', 'hrrr_v_10'] # 'hrrr_sp' 39 | inp_satelite_vars: ['CMI02', 'CMI07', 'CMI14', 'CMI10'] 40 | inp_obs_vars: ["sta_q", "sta_t", "sta_u10", "sta_v10"] # "sta_p" 41 | hold_out_obs: True 42 | field_tar_vars: ["rtma_q", "rtma_t", "rtma_u10", "rtma_v10"] # "rtma_sp" 43 | target_vars: ["q", "t", "u10", "v10"] # "sp" 44 | 45 | learn_residual: True 46 | 47 | input_time_feature: False # use inp_ausiliary_vars if True 48 | inp_auxiliary_vars: ["hour"] 49 | 50 | stack_channel_by_var: False 51 | 52 | save_model_freq: 5 53 | log_to_screen: !!bool True 54 | log_to_wandb: !!bool True 55 | save_checkpoint: !!bool True 56 | 57 | EncDec: &EncDec 58 | <<: *FULL_FIELD 59 | nettype: "EncDec" 60 | lr: 2E-3 61 | upscale: 1 62 | in_chans: 29 # 33 63 | out_chans: 4 64 | img_size_x: 1280 65 | img_size_y: 512 66 | window_size: 4 67 | patch_size: 4 # 1 # need be divisible by img_size 68 | num_feat: 64 69 | drop_rate: 0.1 70 | drop_path_rate: 0.1 71 | attn_drop_rate: 0.1 72 | ape: False 73 | patch_norm: True 74 | use_checkpoint: False 75 | resi_connection: "1conv" 76 | qkv_bias: True 77 | qk_scale: None 78 | img_range: 1. 79 | depths: [3] # [3] 80 | embed_dim: 64 # need be divisible by num_heads 81 | num_heads: [4] # [16] 82 | mlp_ratio: 2 # 2 83 | upsampler: "pixelshuffle" 84 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet) and [Xamarin](https://github.com/xamarin). 
6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/security.md/definition), please report it to us as described below. 8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/security.md/msrc/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/security.md/msrc/pgp). 16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/security.md/msrc/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/security.md/cvd). 40 | 41 | 42 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ADAF 2 | 3 | This repository contains the code used for "ADAF: An Artificial Intelligence Data Assimilation Framework for Weather Forecasting" 4 | 5 | ## Abstract 6 | The forecasting skill of numerical weather prediction (NWP) models critically depends on the accurate initial conditions, also known as analysis, provided by data assimilation (DA). 7 | Traditional DA methods often face a trade-off between computational cost and accuracy due to complex linear algebra computations and the high dimensionality of the model, especially in nonlinear systems. Moreover, processing massive data in real-time requires substantial computational resources. To address this, we introduce an artificial intelligence-based data assimilation framework (ADAF) to generate high-quality kilometer-scale analysis. This study is the pioneering work using real-world observations from varied locations and multiple sources to verify the AI method's efficacy in DA, including sparse surface weather observations and satellite imagery. 
We implemented ADAF for four near-surface variables in the Contiguous United States (CONUS). The results demonstrate that ADAF outperforms the High-Resolution Rapid Refresh Data Assimilation System (HRRRDAS) in accuracy by 16\% to 33\%, and is able to reconstruct extreme events, such as the wind field of tropical cyclones. Sensitivity experiments reveal that ADAF can generate high-quality analysis even with low-accuracy backgrounds and extremely sparse surface observations. ADAF can assimilate massive amounts of observations within a three-hour window at low computational cost, taking about two seconds on an AMD MI200 graphics processing unit (GPU). ADAF has been shown to be efficient and effective in real-world DA, underscoring its potential role in operational weather forecasting. 8 | 9 | ![Figure: Overall framework](/assets/framework.png) 10 | 11 | 12 | ## Data 13 | - Pre-processed data 14 | 15 | [Link for Pre-processed Data - Zenodo Download Link](https://zenodo.org/records/14020879) 16 | 17 | The pre-processed data consists of input-target pairs. The inputs include surface weather observations within a 3-hour window, GOES-16 satellite imagery within a 3-hour window, the HRRR forecast, and topography. The target is a combination of RTMA and surface weather observations. The table below summarizes the input and target datasets utilized in this study. All data were regridded to grids of size 512 $\times$ 1280 with a spatial resolution of 0.05 $\times$ 0.05 $^\circ$. 18 |
| | Dataset | Source | Time window | Variables/Bands |
| --- | --- | --- | --- | --- |
| Input | Surface weather observations | WeatherReal-Synoptic (Jin et al., 2024) | 3 hours | Q, T2M, U10, V10 |
| | Satellite imagery | GOES-16 (Tan et al., 2019) | 3 hours | 0.64, 3.9, 7.3, 11.2 $\mu m$ |
| | Background | HRRR forecast (Dowell et al., 2022) | N/A | Q, T2M, U10, V10 |
| | Topography | ERA5 (Hersbach et al., 2019) | N/A | Geopotential |
| Target | Analysis | RTMA (Pondeca et al., 2011) | N/A | Q, T2M, U10, V10 |
| | Surface weather observations | WeatherReal-Synoptic (Jin et al., 2024) | N/A | Q, T2M, U10, V10 |
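Most variables in a sample file are stored already normalized (see data/README.md); the HRRR background fields are kept in physical units and are scaled to [-1, 1] at load time with the per-variable min/max from stats.csv (described under the next bullet), clipping values outside that range and filling NaNs with 0. The snippet below is a minimal sketch, not part of the repository, of opening one sample and reproducing that scaling; the file name is only an example of the YYYY-MM-DD_HH.nc naming convention of the test set.

```python
# Minimal sketch (assumes the Zenodo archive has been unpacked to ./data; the sample
# file name below is only an example of the naming convention, not a guaranteed file).
import numpy as np
import pandas as pd
import xarray as xr

stats = pd.read_csv("./data/stats.csv", index_col=0)        # columns: variable, max, min
ds = xr.open_dataset("./data/test/2022-10-01_00.nc", engine="netcdf4")

def to_minus1_1(field, var):
    """Scale one field to [-1, 1] with the per-variable min/max from stats.csv,
    clip extremes, and replace NaNs with 0, as the data loader does."""
    row = stats[stats["variable"] == var].iloc[0]
    scaled = 2.0 * (field - row["min"]) / (row["max"] - row["min"]) - 1.0
    return np.nan_to_num(np.clip(scaled, -1.0, 1.0), nan=0.0)

print(list(ds.data_vars))                    # variables described in data/README.md
hrrr_t_norm = to_minus1_1(ds["hrrr_t"].values, "hrrr_t")
print(float(hrrr_t_norm.min()), float(hrrr_t_norm.max()))
```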
65 | 66 | 67 | - Pre-computed normalization statistics 68 | 69 | [Link for pre-computed normalization statistics - Zenodo Download Link](https://zenodo.org/records/14020879). 70 | 71 | If you are using the pre-trained model weights that we provide, it is crucial that you use the given statistics, as these were used during model training. The learned model weights match these normalization statistics exactly. 72 | 73 | The data directory of pre-processed data and pre-computed normalization statistics is organized as follows: 74 | ``` 75 | data 76 | │ README.md 77 | └───test 78 | │ │ 2022-10-01_00.nc 79 | │ │ 2022-10-02_06.nc 80 | │ │ 2022-10-03_12.nc 81 | │ │ ... 82 | │ │ 2023-10-31_00.nc 83 | └───stats.csv 84 | ``` 85 | 86 | - Trained model weights 87 | 88 | [Link for trained model weights - Zenodo Download Link](https://zenodo.org/records/14020879) 89 | 90 | ``` 91 | model_weights/ 92 | │ best_ckpt.tar 93 | ``` 94 | 95 | ## Train 96 | 97 | Training configurations can be set in config/experiment.yaml. Note that the following paths need to be set by the user. 98 | 99 | ``` 100 | exp_dir # directory path to store training checkpoints and other output 101 | train_data_path # directory path to store dataset for train 102 | valid_data_path # directory path to store dataset for valid 103 | test_data_path # directory path to store dataset for test 104 | ``` 105 | 106 | An example launch script for distributed data parallel training is provided: 107 | 108 | 109 | ```shell 110 | run_num=$(date "+%Y%m%d-%H%M%S") 111 | resume=False 112 | 113 | exp_dir='./exp/' 114 | wandb_group='ADAF' 115 | 116 | net_config='EncDec' 117 | export CUDA_VISIBLE_DEVICES='4,5,6,7' 118 | 119 | nohup python -m torch.distributed.launch \ 120 | --master_port=26500 \ 121 | --standalone \ 122 | --nproc_per_node=4 \ 123 | --nnodes=1 \ 124 | train.py \ 125 | --hold_out_obs_ratio=0.5 \ 126 | --lr=0.0002 \ 127 | --lr_reduce_factor=0.8 \ 128 | --target='analysis_obs' \ 129 | --max_epochs=2000 \ 130 | --exp_dir='./exp/' \ 131 | --yaml_config='./config/experiment.yaml' \ 132 | --net_config='EncDec' \ 133 | --resume=False \ 134 | --run_num=${run_num} \ 135 | --batch_size=16 \ 136 | --wandb_group='ADAF' \ 137 | --device='GPU' \ 138 | --wandb_api_key='your wandb api key' \ 139 | > logs/train_${net_config}_${run_num}.log 2>&1 & 140 | 141 | ``` 142 | 143 | ## Inference 144 | In order to run ADAF in inference mode, you will need to have the following files on hand. 145 | 146 | 1. The path to the test sample files (./data/test/) 147 | 148 | 2. The inference script (inference.py) 149 | 150 | 3. The trained model weights (./model_weights/best_ckpt.tar) 151 | 152 | 4. The pre-computed normalization statistics (./data/stats.csv) 153 | 154 | 5. The configuration file (./config/experiment.yaml) 155 | 156 | Once you have all the files listed above, you should be ready to go. 157 | 158 | An example launch script for inference is provided. 159 | ```shell 160 | export CUDA_VISIBLE_DEVICES='0' 161 | 162 | # exp_dir: directory to save the predictions; test_data_path: path to the test data; 163 | # net_config: network configuration; hold_out_obs_ratio: fraction of surface observations held out from the model input (used for evaluation) 164 | nohup python -u inference.py \ 165 | --seed=0 \ 166 | --exp_dir='./exp/' \ 167 | --test_data_path='./data/test' \ 168 | --net_config='EncDec' \ 169 | --hold_out_obs_ratio=0.3 \ 170 | > inference.log 2>&1 & 171 | ``` 172 | 173 | 174 | ## References 175 | 176 | ``` 177 | 1. Jin, W. et al.
WeatherReal: A Benchmark Based on In-Situ Observations for Evaluating Weather Models. (2024). 178 | 2. Dowell, D. et al. The High-Resolution Rapid Refresh (HRRR): An Hourly Updating Convection-Allowing Forecast Model. Part I: Motivation and System Description. Weather and Forecasting 37, (2022). 179 | 3. Tan, B., Dellomo, J., Wolfe, R. & Reth, A. GOES-16 and GOES-17 ABI INR assessment. in Earth Observing Systems XXIV vol. 11127 290–301 (SPIE, 2019). 180 | 4. Hersbach, H. et al. ERA5 monthly averaged data on single levels from 1979 to present. Copernicus Climate Change Service (C3S) Climate Data Store (CDS) 10, 252–266 (2019). 181 | 5. Pondeca, M. S. F. V. D. et al. The Real-Time Mesoscale Analysis at NOAA’s National Centers for Environmental Prediction: Current Status and Development. Weather and Forecasting 26, 593–612 (2011). 182 | ``` 183 | 184 | If you find this work useful, cite it using: 185 | 186 | ``` 187 | @article{xiang2024ADAF, 188 | title={ADAF: An Artificial Intelligence Data Assimilation Framework for Weather Forecasting}, 189 | author={Yanfei Xiang and Weixin Jin and Haiyu Dong and Mingliang Bai and Zuliang Fang and Pengcheng Zhao and Hongyu Sun and Kit Thambiratnam and Qi Zhang and Xiaomeng Huang}, 190 | year={2024}, 191 | journal={arXiv preprint arXiv:2411.16807}, 192 | url={https://arxiv.org/abs/2411.16807}, 193 | } 194 | ``` 195 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | ## Ignore Visual Studio temporary files, build results, and 2 | ## files generated by popular Visual Studio add-ons. 3 | ## 4 | ## Get latest from https://github.com/github/gitignore/blob/main/VisualStudio.gitignore 5 | 6 | # User-specific files 7 | *.rsuser 8 | *.suo 9 | *.user 10 | *.userosscache 11 | *.sln.docstates 12 | 13 | # User-specific files (MonoDevelop/Xamarin Studio) 14 | *.userprefs 15 | 16 | # Mono auto generated files 17 | mono_crash.* 18 | 19 | # Build results 20 | [Dd]ebug/ 21 | [Dd]ebugPublic/ 22 | [Rr]elease/ 23 | [Rr]eleases/ 24 | x64/ 25 | x86/ 26 | [Ww][Ii][Nn]32/ 27 | [Aa][Rr][Mm]/ 28 | [Aa][Rr][Mm]64/ 29 | bld/ 30 | [Bb]in/ 31 | [Oo]bj/ 32 | [Ll]og/ 33 | [Ll]ogs/ 34 | 35 | # Visual Studio 2015/2017 cache/options directory 36 | .vs/ 37 | # Uncomment if you have tasks that create the project's static files in wwwroot 38 | #wwwroot/ 39 | 40 | # Visual Studio 2017 auto generated files 41 | Generated\ Files/ 42 | 43 | # MSTest test Results 44 | [Tt]est[Rr]esult*/ 45 | [Bb]uild[Ll]og.* 46 | 47 | # NUnit 48 | *.VisualState.xml 49 | TestResult.xml 50 | nunit-*.xml 51 | 52 | # Build Results of an ATL Project 53 | [Dd]ebugPS/ 54 | [Rr]eleasePS/ 55 | dlldata.c 56 | 57 | # Benchmark Results 58 | BenchmarkDotNet.Artifacts/ 59 | 60 | # .NET Core 61 | project.lock.json 62 | project.fragment.lock.json 63 | artifacts/ 64 | 65 | # ASP.NET Scaffolding 66 | ScaffoldingReadMe.txt 67 | 68 | # StyleCop 69 | StyleCopReport.xml 70 | 71 | # Files built by Visual Studio 72 | *_i.c 73 | *_p.c 74 | *_h.h 75 | *.ilk 76 | *.meta 77 | *.obj 78 | *.iobj 79 | *.pch 80 | *.pdb 81 | *.ipdb 82 | *.pgc 83 | *.pgd 84 | *.rsp 85 | # but not Directory.Build.rsp, as it configures directory-level build defaults 86 | !Directory.Build.rsp 87 | *.sbr 88 | *.tlb 89 | *.tli 90 | *.tlh 91 | *.tmp 92 | *.tmp_proj 93 | *_wpftmp.csproj 94 | *.log 95 | *.tlog 96 | *.vspscc 97 | *.vssscc 98 | .builds 99 | *.pidb 100 | *.svclog 101 | *.scc 102 | 103 | # Chutzpah Test files 104 | _Chutzpah* 105 | 106 
| # Visual C++ cache files 107 | ipch/ 108 | *.aps 109 | *.ncb 110 | *.opendb 111 | *.opensdf 112 | *.sdf 113 | *.cachefile 114 | *.VC.db 115 | *.VC.VC.opendb 116 | 117 | # Visual Studio profiler 118 | *.psess 119 | *.vsp 120 | *.vspx 121 | *.sap 122 | 123 | # Visual Studio Trace Files 124 | *.e2e 125 | 126 | # TFS 2012 Local Workspace 127 | $tf/ 128 | 129 | # Guidance Automation Toolkit 130 | *.gpState 131 | 132 | # ReSharper is a .NET coding add-in 133 | _ReSharper*/ 134 | *.[Rr]e[Ss]harper 135 | *.DotSettings.user 136 | 137 | # TeamCity is a build add-in 138 | _TeamCity* 139 | 140 | # DotCover is a Code Coverage Tool 141 | *.dotCover 142 | 143 | # AxoCover is a Code Coverage Tool 144 | .axoCover/* 145 | !.axoCover/settings.json 146 | 147 | # Coverlet is a free, cross platform Code Coverage Tool 148 | coverage*.json 149 | coverage*.xml 150 | coverage*.info 151 | 152 | # Visual Studio code coverage results 153 | *.coverage 154 | *.coveragexml 155 | 156 | # NCrunch 157 | _NCrunch_* 158 | .*crunch*.local.xml 159 | nCrunchTemp_* 160 | 161 | # MightyMoose 162 | *.mm.* 163 | AutoTest.Net/ 164 | 165 | # Web workbench (sass) 166 | .sass-cache/ 167 | 168 | # Installshield output folder 169 | [Ee]xpress/ 170 | 171 | # DocProject is a documentation generator add-in 172 | DocProject/buildhelp/ 173 | DocProject/Help/*.HxT 174 | DocProject/Help/*.HxC 175 | DocProject/Help/*.hhc 176 | DocProject/Help/*.hhk 177 | DocProject/Help/*.hhp 178 | DocProject/Help/Html2 179 | DocProject/Help/html 180 | 181 | # Click-Once directory 182 | publish/ 183 | 184 | # Publish Web Output 185 | *.[Pp]ublish.xml 186 | *.azurePubxml 187 | # Note: Comment the next line if you want to checkin your web deploy settings, 188 | # but database connection strings (with potential passwords) will be unencrypted 189 | *.pubxml 190 | *.publishproj 191 | 192 | # Microsoft Azure Web App publish settings. Comment the next line if you want to 193 | # checkin your Azure Web App publish settings, but sensitive information contained 194 | # in these scripts will be unencrypted 195 | PublishScripts/ 196 | 197 | # NuGet Packages 198 | *.nupkg 199 | # NuGet Symbol Packages 200 | *.snupkg 201 | # The packages folder can be ignored because of Package Restore 202 | **/[Pp]ackages/* 203 | # except build/, which is used as an MSBuild target. 
204 | !**/[Pp]ackages/build/ 205 | # Uncomment if necessary however generally it will be regenerated when needed 206 | #!**/[Pp]ackages/repositories.config 207 | # NuGet v3's project.json files produces more ignorable files 208 | *.nuget.props 209 | *.nuget.targets 210 | 211 | # Microsoft Azure Build Output 212 | csx/ 213 | *.build.csdef 214 | 215 | # Microsoft Azure Emulator 216 | ecf/ 217 | rcf/ 218 | 219 | # Windows Store app package directories and files 220 | AppPackages/ 221 | BundleArtifacts/ 222 | Package.StoreAssociation.xml 223 | _pkginfo.txt 224 | *.appx 225 | *.appxbundle 226 | *.appxupload 227 | 228 | # Visual Studio cache files 229 | # files ending in .cache can be ignored 230 | *.[Cc]ache 231 | # but keep track of directories ending in .cache 232 | !?*.[Cc]ache/ 233 | 234 | # Others 235 | ClientBin/ 236 | ~$* 237 | *~ 238 | *.dbmdl 239 | *.dbproj.schemaview 240 | *.jfm 241 | *.pfx 242 | *.publishsettings 243 | orleans.codegen.cs 244 | 245 | # Including strong name files can present a security risk 246 | # (https://github.com/github/gitignore/pull/2483#issue-259490424) 247 | #*.snk 248 | 249 | # Since there are multiple workflows, uncomment next line to ignore bower_components 250 | # (https://github.com/github/gitignore/pull/1529#issuecomment-104372622) 251 | #bower_components/ 252 | 253 | # RIA/Silverlight projects 254 | Generated_Code/ 255 | 256 | # Backup & report files from converting an old project file 257 | # to a newer Visual Studio version. Backup files are not needed, 258 | # because we have git ;-) 259 | _UpgradeReport_Files/ 260 | Backup*/ 261 | UpgradeLog*.XML 262 | UpgradeLog*.htm 263 | ServiceFabricBackup/ 264 | *.rptproj.bak 265 | 266 | # SQL Server files 267 | *.mdf 268 | *.ldf 269 | *.ndf 270 | 271 | # Business Intelligence projects 272 | *.rdl.data 273 | *.bim.layout 274 | *.bim_*.settings 275 | *.rptproj.rsuser 276 | *- [Bb]ackup.rdl 277 | *- [Bb]ackup ([0-9]).rdl 278 | *- [Bb]ackup ([0-9][0-9]).rdl 279 | 280 | # Microsoft Fakes 281 | FakesAssemblies/ 282 | 283 | # GhostDoc plugin setting file 284 | *.GhostDoc.xml 285 | 286 | # Node.js Tools for Visual Studio 287 | .ntvs_analysis.dat 288 | node_modules/ 289 | 290 | # Visual Studio 6 build log 291 | *.plg 292 | 293 | # Visual Studio 6 workspace options file 294 | *.opt 295 | 296 | # Visual Studio 6 auto-generated workspace file (contains which files were open etc.) 297 | *.vbw 298 | 299 | # Visual Studio 6 auto-generated project file (contains which files were open etc.) 
300 | *.vbp 301 | 302 | # Visual Studio 6 workspace and project file (working project files containing files to include in project) 303 | *.dsw 304 | *.dsp 305 | 306 | # Visual Studio 6 technical files 307 | *.ncb 308 | *.aps 309 | 310 | # Visual Studio LightSwitch build output 311 | **/*.HTMLClient/GeneratedArtifacts 312 | **/*.DesktopClient/GeneratedArtifacts 313 | **/*.DesktopClient/ModelManifest.xml 314 | **/*.Server/GeneratedArtifacts 315 | **/*.Server/ModelManifest.xml 316 | _Pvt_Extensions 317 | 318 | # Paket dependency manager 319 | .paket/paket.exe 320 | paket-files/ 321 | 322 | # FAKE - F# Make 323 | .fake/ 324 | 325 | # CodeRush personal settings 326 | .cr/personal 327 | 328 | # Python Tools for Visual Studio (PTVS) 329 | __pycache__/ 330 | *.pyc 331 | 332 | # Cake - Uncomment if you are using it 333 | # tools/** 334 | # !tools/packages.config 335 | 336 | # Tabs Studio 337 | *.tss 338 | 339 | # Telerik's JustMock configuration file 340 | *.jmconfig 341 | 342 | # BizTalk build output 343 | *.btp.cs 344 | *.btm.cs 345 | *.odx.cs 346 | *.xsd.cs 347 | 348 | # OpenCover UI analysis results 349 | OpenCover/ 350 | 351 | # Azure Stream Analytics local run output 352 | ASALocalRun/ 353 | 354 | # MSBuild Binary and Structured Log 355 | *.binlog 356 | 357 | # NVidia Nsight GPU debugger configuration file 358 | *.nvuser 359 | 360 | # MFractors (Xamarin productivity tool) working folder 361 | .mfractor/ 362 | 363 | # Local History for Visual Studio 364 | .localhistory/ 365 | 366 | # Visual Studio History (VSHistory) files 367 | .vshistory/ 368 | 369 | # BeatPulse healthcheck temp database 370 | healthchecksdb 371 | 372 | # Backup folder for Package Reference Convert tool in Visual Studio 2017 373 | MigrationBackup/ 374 | 375 | # Ionide (cross platform F# VS Code tools) working folder 376 | .ionide/ 377 | 378 | # Fody - auto-generated XML schema 379 | FodyWeavers.xsd 380 | 381 | # VS Code files for those working on multiple tools 382 | .vscode/* 383 | !.vscode/settings.json 384 | !.vscode/tasks.json 385 | !.vscode/launch.json 386 | !.vscode/extensions.json 387 | *.code-workspace 388 | 389 | # Local History for Visual Studio Code 390 | .history/ 391 | 392 | # Windows Installer files from build outputs 393 | *.cab 394 | *.msi 395 | *.msix 396 | *.msm 397 | *.msp 398 | 399 | # JetBrains Rider 400 | *.sln.iml 401 | -------------------------------------------------------------------------------- /utils/data_loader_multifiles.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import torch 4 | import logging 5 | import numpy as np 6 | import pandas as pd 7 | import xarray as xr 8 | from icecream import ic 9 | from torch.utils.data import DataLoader, Dataset 10 | from torch.utils.data.distributed import DistributedSampler 11 | 12 | 13 | def get_data_loader(params, files_pattern, distributed, train): 14 | dataset = GetDataset(params, files_pattern, train) 15 | 16 | if distributed: 17 | sampler = DistributedSampler(dataset, shuffle=train) 18 | else: 19 | sampler = None 20 | 21 | dataloader = DataLoader( 22 | dataset, 23 | batch_size=int(params.batch_size), 24 | num_workers=params.num_data_workers, 25 | shuffle=False, # (sampler is None), 26 | sampler=sampler if train else None, 27 | drop_last=True, 28 | pin_memory=torch.cuda.is_available(), 29 | ) 30 | 31 | if train: 32 | return dataloader, dataset, sampler 33 | else: 34 | return dataloader, dataset 35 | 36 | 37 | class GetDataset(Dataset): 38 | def __init__(self, params, location, train):
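"""Map-style dataset over the per-hour NetCDF samples found in `location`.
Each item stacks the normalized HRRR background, station observations,
topography and (optionally) satellite imagery into the input channels, and
returns the RTMA analysis, the observations and their combination as targets
(expressed as residuals w.r.t. the background when `learn_residual` is set)."""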
self.params = params 40 | self.train = train 41 | self.location = location 42 | self.n_in_channels = params.n_in_channels 43 | self.n_out_channels = params.n_out_channels 44 | # self.add_noise = params.add_noise if train else false 45 | self._get_files_stats() 46 | 47 | def _get_files_stats(self): 48 | self.files_paths = glob.glob(self.location + "/*.nc") 49 | self.files_paths.sort() 50 | self.n_samples_total = len(self.files_paths) 51 | 52 | logging.info("getting file stats from {}".format(self.files_paths[0])) 53 | ds = xr.open_dataset(self.files_paths[0], engine="netcdf4") 54 | 55 | # original image shape (before padding) 56 | self.org_img_shape_x = ds["hrrr_t"].shape[0] 57 | self.org_img_shape_y = ds["hrrr_t"].shape[1] 58 | 59 | self.files = [none for _ in range(self.n_samples_total)] 60 | 61 | logging.info("number of samples: {}".format(self.n_samples_total)) 62 | logging.info( 63 | "found data at path {}. number of examples: {}. \ 64 | original image shape: {} x {} x {}".format( 65 | self.location, 66 | self.n_samples_total, 67 | self.org_img_shape_x, 68 | self.org_img_shape_y, 69 | self.n_in_channels, 70 | ) 71 | ) 72 | 73 | def _open_file(self, hour_idx): 74 | _file = xr.open_dataset(self.files_paths[hour_idx], engine="netcdf4") 75 | self.files[hour_idx] = _file 76 | 77 | def _min_max_norm_ignore_extreme_fill_nan(self, data, vmin, vmax): 78 | # ic(vmin.shape, vmax.shape, data.shape) 79 | if data.ndim == 4: 80 | vmax = np.array(vmax)[:, np.newaxis, np.newaxis, np.newaxis] 81 | vmin = np.array(vmin)[:, np.newaxis, np.newaxis, np.newaxis] 82 | vmax = np.repeat(vmax, data.shape[1], axis=1) 83 | vmax = np.repeat(vmax, data.shape[2], axis=2) 84 | vmax = np.repeat(vmax, data.shape[3], axis=3) 85 | vmin = np.repeat(vmin, data.shape[1], axis=1) 86 | vmin = np.repeat(vmin, data.shape[2], axis=2) 87 | vmin = np.repeat(vmin, data.shape[3], axis=3) 88 | 89 | elif data.ndim == 3: 90 | vmax = np.array(vmax)[:, np.newaxis, np.newaxis] 91 | vmin = np.array(vmin)[:, np.newaxis, np.newaxis] 92 | vmax = np.repeat(vmax, data.shape[1], axis=1) 93 | vmax = np.repeat(vmax, data.shape[2], axis=2) 94 | vmin = np.repeat(vmin, data.shape[1], axis=1) 95 | vmin = np.repeat(vmin, data.shape[2], axis=2) 96 | 97 | data -= vmin 98 | data *= 2.0 / (vmax - vmin) 99 | data -= 1.0 100 | 101 | data = np.where(data > 1, 1, data) 102 | data = np.where(data < -1, -1, data) 103 | 104 | data = np.nan_to_num(data, nan=0) 105 | 106 | return data 107 | 108 | def __len__(self): 109 | return self.n_samples_total 110 | 111 | def __getitem__(self, hour_idx): 112 | if self.files[hour_idx] is none: 113 | self._open_file(hour_idx) 114 | 115 | # %% get statistic value for normalization 116 | stats_file = os.path.join( 117 | self.params.data_path, f"stats_{self.params.norm_type}.csv" 118 | ) 119 | stats = pd.read_csv(stats_file, index_col=0) 120 | 121 | for vi, var in enumerate(self.params.inp_hrrr_vars): 122 | if vi == 0: 123 | inp_hrrr_stats = stats[stats["variable"].isin([var])] 124 | else: 125 | inp_hrrr_stats = pd.concat( 126 | [inp_hrrr_stats, stats[stats["variable"].isin([var])]] 127 | ) 128 | 129 | # %% read data in file 130 | if len(self.params.inp_hrrr_vars) != 0: 131 | inp_hrrr = np.array( 132 | self.files[hour_idx][self.params.inp_hrrr_vars].to_array() 133 | )[:, : self.params.img_size_y, : self.params.img_size_x] 134 | inp_hrrr = np.squeeze(inp_hrrr) 135 | # ic(inp_hrrr.shape) 136 | 137 | field_mask = inp_hrrr.copy() 138 | field_mask[field_mask != 0] = 1 # set 1 where out of range 139 | 140 | # normalization 141 | inp_hrrr 
= self._min_max_norm_ignore_extreme_fill_nan( 142 | inp_hrrr, inp_hrrr_stats["min"], inp_hrrr_stats["max"] 143 | ) 144 | 145 | if len(self.params.inp_obs_vars) != 0: 146 | obs = np.array( 147 | self.files[hour_idx][self.params.inp_obs_vars].to_array())[ 148 | :, 149 | -self.params.obs_time_window:, 150 | : self.params.img_size_y, 151 | : self.params.img_size_x, 152 | ] 153 | 154 | # inp_obs not includes nan, 0 means un-observaed location 155 | # use all observation as target 156 | obs_tar = obs[:, -1] 157 | 158 | # quality control 159 | obs_tar[(obs_tar <= -1) | (obs_tar >= 1)] = 0 160 | 161 | # this is for label, which is a combination of obs and analysis 162 | obs_tar_mask = obs_tar.copy() 163 | # 1 means observed, 0 means un-observed 164 | obs_tar_mask[obs_tar_mask != 0] = 1 165 | # print(f'obs_tar: {obs_tar.shape}') 166 | 167 | # print(f'inp_obs: {inp_obs.shape}') 168 | if self.params.hold_out_obs: 169 | # [lat, lon] 170 | # obs_mask = np.array(self.files[hour_idx]["obs_mask"]) 171 | # # print(f'obs_mask: {obs_mask.shape}') 172 | # inp_obs = inp_obs * (1 - obs_mask) 173 | 174 | if self.params.obs_mask_seed != 0: 175 | np.random.seed(self.params.obs_mask_seed) 176 | logging.info( 177 | f"using random seed {self.params.obs_mask_seed}") 178 | 179 | lat_num = obs[0, 0].shape[0] 180 | lon_num = obs[0, 0].shape[1] 181 | 182 | # [lat, lon] -> [lat * lon] 183 | obs_tw_begin = obs[0, 0].reshape(-1) 184 | obs_index = np.where(~np.isnan(obs_tw_begin))[ 185 | 0 186 | ] # find station's indices 187 | 188 | obs_num = len(obs_index) 189 | hold_out_num = int(obs_num * self.params.hold_out_obs_ratio) 190 | ic(obs_num, hold_out_num) 191 | 192 | np.random.shuffle(obs_index) # generate mask randomly 193 | hold_out_obs_index = obs_index[:hold_out_num] 194 | # input_obs_index = obs_index[hold_out_num:] 195 | # ic(len(hold_out_obs_index), hold_out_obs_index) 196 | # ic(len(input_obs_index), input_obs_index) 197 | 198 | # mask (lat, lon), hold_out obs=1, input obs = 0 199 | obs_mask = np.zeros(obs_tw_begin.shape) 200 | obs_mask[hold_out_obs_index] = 1 201 | obs_mask = obs_mask.reshape([lat_num, lon_num]) 202 | 203 | # observation for input 204 | inp_obs = obs * (1 - obs_mask) 205 | # observation excluding the input 206 | # hold_out_obs = obs * obs_mask 207 | 208 | # ic(inp_obs.shape) 209 | inp_obs = inp_obs.reshape( 210 | (-1, self.params.img_size_y, self.params.img_size_x) 211 | ) 212 | 213 | if len(self.params.inp_satelite_vars) != 0: 214 | inp_sate = np.array( 215 | self.files[hour_idx][self.params.inp_satelite_vars].to_array() 216 | )[ 217 | :, 218 | -self.params.obs_time_window:, 219 | : self.params.img_size_y, 220 | : self.params.img_size_x, 221 | ] 222 | 223 | lon = np.array(self.files[hour_idx].coords["lon"].values)[ 224 | : self.params.img_size_x 225 | ] 226 | lat = np.array(self.files[hour_idx].coords["lat"].values)[ 227 | : self.params.img_size_y 228 | ] 229 | topo = np.array(self.files[hour_idx][["z"]].to_array())[ 230 | :, : self.params.img_size_y, : self.params.img_size_x 231 | ] 232 | field_tar = np.array( 233 | self.files[hour_idx][self.params.field_tar_vars].to_array() 234 | )[:, : self.params.img_size_y, : self.params.img_size_x] 235 | 236 | # a combination of observation and target analysis field 237 | # use observed value to replace the analysis value 238 | # at observed locations 239 | field_obs_tar = field_tar.copy() 240 | # use 0 replace the value at observed location 241 | field_obs_tar[obs_tar_mask == 1] = 0 242 | field_obs_tar += obs_tar 243 | 244 | # norm(field_tar) - norm(inp_hrrr) 
245 | # norm(obs_tar) - norm(inp_hrrr) 246 | # norm(field_obs_tar) - norm(inp_hrrr) 247 | if self.params.learn_residual: 248 | field_tar = field_tar - inp_hrrr 249 | obs_tar = obs_tar - inp_hrrr 250 | field_obs_tar = field_obs_tar - inp_hrrr 251 | 252 | inp = np.concatenate((inp_hrrr, inp_obs, topo), axis=0) 253 | 254 | if len(self.params.inp_satelite_vars) != 0: 255 | inp_sate = inp_sate.reshape( 256 | (-1, self.params.img_size_y, self.params.img_size_x) 257 | ) 258 | inp = np.concatenate((inp, inp_sate)) 259 | 260 | return ( 261 | inp, 262 | field_tar, 263 | obs_tar, 264 | field_obs_tar, 265 | inp_hrrr, 266 | lat, 267 | lon, 268 | field_mask, 269 | obs_tar_mask, 270 | ) 271 | else: 272 | return ( 273 | inp, 274 | field_tar, 275 | obs_tar, 276 | field_obs_tar, 277 | inp_hrrr, 278 | lat, 279 | lon, 280 | field_mask, 281 | obs_tar_mask, 282 | ) 283 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import time 4 | import torch 5 | import logging 6 | import datetime 7 | import argparse 8 | import numpy as np 9 | import pandas as pd 10 | import xarray as xr 11 | 12 | from math import sqrt 13 | from icecream import ic 14 | from str2bool import str2bool 15 | from collections import OrderedDict 16 | from sklearn.metrics import mean_squared_error 17 | 18 | from utils.logging_utils import config_logger, log_to_file 19 | from utils.YParams import YParams 20 | from utils.read_txt import read_lines_from_file 21 | 22 | config_logger() 23 | 24 | 25 | def gaussian_perturb(x, level=0.01, device=0): 26 | noise = level * torch.randn(x.shape).to(device, dtype=torch.float) 27 | return x + noise 28 | 29 | 30 | def enable_dropout(model): 31 | """Function to enable the dropout layers during test-time""" 32 | for m in model.modules(): 33 | if m.__class__.__name__.startswith("Dropout"): 34 | m.train() 35 | 36 | 37 | def load_model(model, params, checkpoint_file): 38 | model.zero_grad() 39 | checkpoint_fname = checkpoint_file 40 | checkpoint = torch.load(checkpoint_fname) 41 | try: 42 | new_state_dict = OrderedDict() 43 | for key, val in checkpoint["model_state"].items(): 44 | name = key[7:] 45 | if name != "ged": 46 | new_state_dict[name] = val 47 | model.load_state_dict(new_state_dict) 48 | except ValueError: 49 | model.load_state_dict(checkpoint["model_state"]) 50 | model.eval() 51 | return model 52 | 53 | 54 | def setup(params): 55 | 56 | # device init 57 | if torch.cuda.is_available(): 58 | device = torch.cuda.current_device() 59 | else: 60 | device = "cpu" 61 | 62 | if params.nettype == "EncDec": 63 | from models.encdec import EncDec as model 64 | else: 65 | raise Exception("not implemented") 66 | 67 | checkpoint_file = params["best_checkpoint_path"] 68 | logging.info("Loading model checkpoint from {}".format(checkpoint_file)) 69 | model = model(params).to(device) 70 | model = load_model(model, params, checkpoint_file) 71 | model = model.to(device) 72 | 73 | files_paths = glob.glob(params.test_data_path + "/*.nc") 74 | files_paths.sort() 75 | 76 | return files_paths, inference_times, model 77 | 78 | 79 | def min_max_norm(data, vmin, vmax): 80 | if data.ndim == 4: 81 | vmax = np.array(vmax)[:, np.newaxis, np.newaxis, np.newaxis] 82 | vmin = np.array(vmin)[:, np.newaxis, np.newaxis, np.newaxis] 83 | vmax = np.repeat(vmax, data.shape[1], axis=1) 84 | vmax = np.repeat(vmax, data.shape[2], axis=2) 85 | vmax = np.repeat(vmax, data.shape[3], axis=3) 86 | 
vmin = np.repeat(vmin, data.shape[1], axis=1) 87 | vmin = np.repeat(vmin, data.shape[2], axis=2) 88 | vmin = np.repeat(vmin, data.shape[3], axis=3) 89 | 90 | elif data.ndim == 3: 91 | vmax = np.array(vmax)[:, np.newaxis, np.newaxis] 92 | vmin = np.array(vmin)[:, np.newaxis, np.newaxis] 93 | vmax = np.repeat(vmax, data.shape[1], axis=1) 94 | vmax = np.repeat(vmax, data.shape[2], axis=2) 95 | vmin = np.repeat(vmin, data.shape[1], axis=1) 96 | vmin = np.repeat(vmin, data.shape[2], axis=2) 97 | 98 | data = (data - vmin) / (vmax - vmin) 99 | return data 100 | 101 | 102 | def min_max_norm_ignore_extreme_fill_nan(data, vmin, vmax): 103 | if data.ndim == 4: 104 | vmax = np.array(vmax)[:, np.newaxis, np.newaxis, np.newaxis] 105 | vmin = np.array(vmin)[:, np.newaxis, np.newaxis, np.newaxis] 106 | vmax = np.repeat(vmax, data.shape[1], axis=1) 107 | vmax = np.repeat(vmax, data.shape[2], axis=2) 108 | vmax = np.repeat(vmax, data.shape[3], axis=3) 109 | vmin = np.repeat(vmin, data.shape[1], axis=1) 110 | vmin = np.repeat(vmin, data.shape[2], axis=2) 111 | vmin = np.repeat(vmin, data.shape[3], axis=3) 112 | elif data.ndim == 3: 113 | vmax = np.array(vmax)[:, np.newaxis, np.newaxis] 114 | vmin = np.array(vmin)[:, np.newaxis, np.newaxis] 115 | vmax = np.repeat(vmax, data.shape[1], axis=1) 116 | vmax = np.repeat(vmax, data.shape[2], axis=2) 117 | vmin = np.repeat(vmin, data.shape[1], axis=1) 118 | vmin = np.repeat(vmin, data.shape[2], axis=2) 119 | 120 | data -= vmin 121 | data *= 2.0 / (vmax - vmin) 122 | data -= 1.0 123 | 124 | data = np.where(data > 1, 1, data) 125 | data = np.where(data < -1, -1, data) 126 | data = np.nan_to_num(data, nan=0) 127 | 128 | return data 129 | 130 | def reverse_norm(params, data, variable_names): 131 | 132 | stats_file = os.path.join( 133 | params.data_path, 134 | f"stats_{params.norm_type}.csv") 135 | stats = pd.read_csv(stats_file, index_col=0) 136 | 137 | # for vi, var in enumerate(params.field_tar_vars): 138 | for vi, var in enumerate(variable_names): 139 | if vi == 0: 140 | field_tar_stats = stats[stats["variable"].isin([var])] 141 | else: 142 | field_tar_stats = pd.concat( 143 | [field_tar_stats, stats[stats["variable"].isin([var])]] 144 | ) 145 | ic(field_tar_stats) 146 | 147 | if params.normalization == "minmax_ignore_extreme": 148 | vmin = field_tar_stats["min"] 149 | vmax = field_tar_stats["max"] 150 | 151 | if len(data.shape) == 4: 152 | vmin = np.array(vmin)[:, np.newaxis, np.newaxis, np.newaxis] 153 | vmin = np.repeat(vmin, data.shape[1], axis=1) 154 | vmin = np.repeat(vmin, data.shape[2], axis=2) 155 | vmin = np.repeat(vmin, data.shape[3], axis=3) 156 | vmin = np.squeeze(vmin) 157 | 158 | vmax = np.array(vmax)[:, np.newaxis, np.newaxis, np.newaxis] 159 | vmax = np.repeat(vmax, data.shape[1], axis=1) 160 | vmax = np.repeat(vmax, data.shape[2], axis=2) 161 | vmax = np.repeat(vmax, data.shape[3], axis=3) 162 | vmax = np.squeeze(vmax) 163 | 164 | if len(data.shape) == 3: 165 | vmin = np.array(vmin)[:, np.newaxis, np.newaxis] 166 | vmin = np.repeat(vmin, data.shape[1], axis=1) 167 | vmin = np.repeat(vmin, data.shape[2], axis=2) 168 | vmin = np.squeeze(vmin) 169 | 170 | vmax = np.array(vmax)[:, np.newaxis, np.newaxis] 171 | vmax = np.repeat(vmax, data.shape[1], axis=1) 172 | vmax = np.repeat(vmax, data.shape[2], axis=2) 173 | vmax = np.squeeze(vmax) 174 | 175 | if len(data.shape) == 2: 176 | vmin = np.array(vmin)[np.newaxis, np.newaxis] 177 | vmin = np.repeat(vmin, data.shape[0], axis=0) 178 | vmin = np.repeat(vmin, data.shape[1], axis=1) 179 | vmin = np.squeeze(vmin) 
180 | 181 | vmax = np.array(vmax)[np.newaxis, np.newaxis] 182 | vmax = np.repeat(vmax, data.shape[0], axis=0) 183 | vmax = np.repeat(vmax, data.shape[1], axis=1) 184 | vmax = np.squeeze(vmax) 185 | 186 | data = (data + 1) * (vmax - vmin) / 2 + vmin 187 | 188 | else: 189 | raise Exception("not implemented") 190 | 191 | return data 192 | 193 | 194 | def inference( 195 | params, 196 | target_variable, 197 | test_data_file_paths, 198 | inference_times, 199 | hold_out_obs_ratio, 200 | model, 201 | ): 202 | if torch.cuda.is_available(): 203 | device = torch.cuda.current_device() 204 | else: 205 | device = "cpu" 206 | 207 | out_dir = os.path.join( 208 | params["experiment_dir"], 209 | f"inference_ensemble_{params.ensemble_num}_hold_{hold_out_obs_ratio}", 210 | ) 211 | os.makedirs(out_dir, exist_ok=True) 212 | 213 | with torch.no_grad(): 214 | for f, analysis_time_str in zip(test_data_file_paths, inference_times): 215 | analysis_time = datetime.datetime.strptime( 216 | analysis_time_str, "%Y-%m-%d_%H") 217 | logging.info("-----------------------------------------") 218 | logging.info(f"Analysis time: {analysis_time_str}") 219 | logging.info(f"Reading {f}") 220 | 221 | out_file = os.path.join( 222 | out_dir, analysis_time_str + ".nc") 223 | 224 | if not os.path.exists(f): 225 | logging.info(f"{f} not exists, skip!") 226 | continue 227 | 228 | data = read_sample_file_and_norm_input( 229 | params, f, hold_out_obs_ratio) 230 | ( 231 | inp, # normed 232 | inp_sate_norm, # normed 233 | inp_hrrr_norm, # normed 234 | field_target, # normed 235 | hold_out_obs, # normed 236 | inp_obs_for_eval, # normed 237 | bg_hrrr, # un-norm 238 | mask, 239 | lat, 240 | lon, 241 | ) = data 242 | 243 | field_target = reverse_norm( 244 | params, field_target, params.field_tar_vars) 245 | # mask the region out of range, after reverse normalization 246 | field_target = np.where( 247 | mask, field_target, np.nan 248 | ) # fill data with nan where mask is True. 249 | bg_hrrr = np.where( 250 | mask, bg_hrrr, np.nan 251 | ) # fill data with nan where mask is True. 
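# From here on, everything compared below is in physical units: the held-out and
# input station observations have their zero fill values turned into NaN and are
# reverse-normalized, specific humidity is rescaled, and the RMSE scores of the AI
# analysis and the HRRR background are then computed only at observed locations.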
252 | 253 | # Reverse normalization 254 | hold_out_obs = np.where( 255 | hold_out_obs == 0, np.nan, hold_out_obs 256 | ) # fill 0 with nan before normalization reversing 257 | hold_out_obs = reverse_norm( 258 | params, hold_out_obs, params.inp_obs_vars) 259 | 260 | inp_obs_for_eval = np.where( 261 | inp_obs_for_eval == 0, np.nan, inp_obs_for_eval 262 | ) # fill 0 with nan before normalization reversing 263 | inp_obs_for_eval = reverse_norm( 264 | params, inp_obs_for_eval, params.inp_obs_vars 265 | ) 266 | 267 | # Unit convert: g/kg -> kg/kg 268 | q_idx = target_variable.index("q") 269 | field_target[q_idx] = field_target[q_idx] * 1000 270 | hold_out_obs[q_idx] = hold_out_obs[q_idx] * 1000 271 | inp_obs_for_eval[q_idx] = inp_obs_for_eval[q_idx] * 1000 272 | bg_hrrr[q_idx] = bg_hrrr[q_idx] * 1000 273 | 274 | inp = torch.tensor(inp[np.newaxis, :, :, :]).to( 275 | device, dtype=torch.float) 276 | inp_sate_norm = torch.tensor( 277 | inp_sate_norm[np.newaxis, :, :, :, :]).to( 278 | device, dtype=torch.float 279 | ) 280 | 281 | gen_ensembles = [] 282 | for i in range(params.ensemble_num): 283 | model.eval() 284 | enable_dropout(model) 285 | 286 | start = time.time() 287 | 288 | if params.nettype == "EncDec": 289 | inp_sate_norm = torch.reshape( 290 | inp_sate_norm, 291 | (1, -1, params.img_size_y, params.img_size_x) 292 | ) 293 | inp = torch.concat((inp, inp_sate_norm), 1) 294 | gen = model(inp) 295 | else: 296 | raise Exception("not implemented") 297 | 298 | print(f"inference time: {time.time() - start}") 299 | 300 | if params.learn_residual: 301 | gen = np.squeeze( 302 | gen.detach().cpu().numpy()) + inp_hrrr_norm 303 | 304 | # reverse normalization 305 | gen = reverse_norm(params, gen, params.field_tar_vars) 306 | 307 | # for specific humidity, g/kg -> kg/kg 308 | gen[q_idx] = gen[q_idx] * 1000 309 | 310 | # mask the region out of range with 0, 311 | # after reverse normalization 312 | gen = np.where(mask, gen, np.nan) 313 | 314 | gen_ensembles.append(gen) 315 | 316 | gen_ensembles = np.array(gen_ensembles) 317 | 318 | for vi, tar_var in enumerate(target_variable): 319 | logging.info(f"{tar_var}:") 320 | 321 | # %% compare with hold-out observation 322 | obs_hold_obs_var = hold_out_obs[vi][ 323 | ~np.isnan(hold_out_obs[vi])] 324 | if len(obs_hold_obs_var) == 0: 325 | print("No hold out obs, continue!") 326 | continue 327 | bg_hrrr_hold_obs = bg_hrrr[vi][ 328 | ~np.isnan(hold_out_obs[vi])] 329 | ai_gen_hold_obs = gen_ensembles[0, vi][ 330 | ~np.isnan(hold_out_obs[vi])] 331 | obs_hold_obs_var = np.nan_to_num( 332 | obs_hold_obs_var, nan=0) 333 | bg_hrrr_hold_obs = np.nan_to_num( 334 | bg_hrrr_hold_obs, nan=0) 335 | ai_gen_hold_obs = np.nan_to_num( 336 | ai_gen_hold_obs, nan=0) 337 | rmse_ai_hold_obs = round( 338 | sqrt(mean_squared_error( 339 | obs_hold_obs_var, ai_gen_hold_obs)), 3 340 | ) 341 | rmse_bg_hold_obs = round( 342 | sqrt(mean_squared_error( 343 | obs_hold_obs_var, bg_hrrr_hold_obs)), 3 344 | ) 345 | logging.info(f"rmse_ai_hold_obs={rmse_ai_hold_obs}") 346 | logging.info(f"rmse_bg_hold_obs={rmse_bg_hold_obs}") 347 | 348 | # %% compare with input observation 349 | inp_obs_for_eval_var = inp_obs_for_eval[vi][ 350 | ~np.isnan(inp_obs_for_eval[vi]) 351 | ] 352 | bg_hrrr_inp_obs = bg_hrrr[vi][ 353 | ~np.isnan(inp_obs_for_eval[vi])] 354 | ai_gen_inp_obs = gen_ensembles[0, vi][ 355 | ~np.isnan(inp_obs_for_eval[vi])] 356 | inp_obs_for_eval_var = np.nan_to_num( 357 | inp_obs_for_eval_var, nan=0) 358 | bg_hrrr_inp_obs = np.nan_to_num( 359 | bg_hrrr_inp_obs, nan=0) 360 | ai_gen_inp_obs = 
np.nan_to_num( 361 | ai_gen_inp_obs, nan=0) 362 | rmse_ai_inp_obs = round( 363 | sqrt(mean_squared_error( 364 | inp_obs_for_eval_var, ai_gen_inp_obs)), 3 365 | ) 366 | rmse_bg_inp_obs = round( 367 | sqrt(mean_squared_error( 368 | inp_obs_for_eval_var, bg_hrrr_inp_obs)), 3 369 | ) 370 | logging.info(f"rmse_ai_inp_obs={rmse_ai_inp_obs}") 371 | logging.info(f"rmse_bg_inp_obs={rmse_bg_inp_obs}") 372 | 373 | rmse_ai_field = round( 374 | sqrt( 375 | mean_squared_error( 376 | np.nan_to_num(gen_ensembles[0, vi], nan=0), 377 | np.nan_to_num(field_target[vi], nan=0), 378 | ) 379 | ), 380 | 3, 381 | ) 382 | logging.info(f"rmse_ai_field={rmse_ai_field}") 383 | 384 | rmse_bg_field = round( 385 | sqrt( 386 | mean_squared_error( 387 | np.nan_to_num(bg_hrrr[vi], nan=0), 388 | np.nan_to_num(field_target[vi], nan=0), 389 | ) 390 | ), 391 | 3, 392 | ) 393 | logging.info(f"rmse_bg_field={rmse_bg_field}") 394 | 395 | logging.info( 396 | "AI generation :" 397 | + f"{round(np.nanmin(gen_ensembles[0,vi]), 3)}" 398 | + f"~ {round(np.nanmax(gen_ensembles[0,vi]), 3)}" 399 | ) 400 | logging.info( 401 | "Background (hrrr):" 402 | + f"{round(np.nanmin(bg_hrrr[vi]), 3)}" 403 | + f"~ {round(np.nanmax(bg_hrrr[vi]), 3)}" 404 | ) 405 | logging.info( 406 | "hold out obs :" 407 | + f"{round(np.nanmin(hold_out_obs[vi]), 3)}" 408 | + f"~ {round(np.nanmax(hold_out_obs[vi]), 3)}" 409 | ) 410 | logging.info( 411 | "field_target :" 412 | + f"{round(np.nanmin(field_target[vi]), 3)}" 413 | + f"~ {round(np.nanmax(field_target[vi]), 3)}" 414 | ) 415 | 416 | variable_names = [s.split("_")[1] for s in params.field_tar_vars] 417 | ic(variable_names) 418 | save_output( 419 | save_file_path=out_file, 420 | analysis_time=analysis_time, 421 | variable_names=variable_names, 422 | AI_gen_ensembles=gen_ensembles, # un-normed 423 | bg_hrrr=bg_hrrr, # un-normed 424 | field_target=field_target, # un-normed 425 | inp_obs_for_eval=inp_obs_for_eval, # un-normed 426 | hold_out_obs=hold_out_obs, # un-normed 427 | lon=lon, 428 | lat=lat, 429 | mask=mask, 430 | ) 431 | 432 | 433 | def save_output( 434 | save_file_path: str, 435 | variable_names: list, 436 | AI_gen_ensembles: np.array, 437 | bg_hrrr: np.array, 438 | field_target: np.array, 439 | hold_out_obs: np.array, 440 | inp_obs_for_eval: np.array, 441 | lon: np.array, 442 | lat: np.array, 443 | mask: np.array, 444 | analysis_time: str, 445 | ): 446 | 447 | ic( 448 | lon.shape, 449 | lat.shape, 450 | AI_gen_ensembles.shape, 451 | bg_hrrr.shape, 452 | field_target.shape, 453 | hold_out_obs.shape, 454 | ) 455 | ds = xr.Dataset( 456 | { 457 | f"ai_gen_{variable_names[0]}": ( 458 | ("ensemble_num", "lat", "lon"), 459 | AI_gen_ensembles[:, 0], 460 | ), 461 | f"ai_gen_{variable_names[1]}": ( 462 | ("ensemble_num", "lat", "lon"), 463 | AI_gen_ensembles[:, 1], 464 | ), 465 | f"ai_gen_{variable_names[2]}": ( 466 | ("ensemble_num", "lat", "lon"), 467 | AI_gen_ensembles[:, 2], 468 | ), 469 | f"ai_gen_{variable_names[3]}": ( 470 | ("ensemble_num", "lat", "lon"), 471 | AI_gen_ensembles[:, 3], 472 | ), 473 | # f"ai_gen_{variable_names[4]}": ( 474 | # ("ensemble_num", "lat", "lon"), 475 | # AI_gen_ensembles[:, 4], 476 | # ), 477 | f"hold_out_obs_{variable_names[0]}": ( 478 | ("lat", "lon"), hold_out_obs[0]), 479 | f"hold_out_obs_{variable_names[1]}": ( 480 | ("lat", "lon"), hold_out_obs[1]), 481 | f"hold_out_obs_{variable_names[2]}": ( 482 | ("lat", "lon"), hold_out_obs[2]), 483 | f"hold_out_obs_{variable_names[3]}": ( 484 | ("lat", "lon"), hold_out_obs[3]), 485 | # f"hold_out_obs_{variable_names[4]}": ( 486 | # 
("lat", "lon"), hold_out_obs[4]), 487 | f"inp_obs_for_eval_{variable_names[0]}": ( 488 | ("lat", "lon"), 489 | inp_obs_for_eval[0], 490 | ), 491 | f"inp_obs_for_eval_{variable_names[1]}": ( 492 | ("lat", "lon"), 493 | inp_obs_for_eval[1], 494 | ), 495 | f"inp_obs_for_eval_{variable_names[2]}": ( 496 | ("lat", "lon"), 497 | inp_obs_for_eval[2], 498 | ), 499 | f"inp_obs_for_eval_{variable_names[3]}": ( 500 | ("lat", "lon"), 501 | inp_obs_for_eval[3], 502 | ), 503 | f"rtma_{variable_names[0]}": ( 504 | ("lat", "lon"), field_target[0]), 505 | f"rtma_{variable_names[1]}": ( 506 | ("lat", "lon"), field_target[1]), 507 | f"rtma_{variable_names[2]}": ( 508 | ("lat", "lon"), field_target[2]), 509 | f"rtma_{variable_names[3]}": ( 510 | ("lat", "lon"), field_target[3]), 511 | # f"rtma_{variable_names[4]}": ( 512 | # ("lat", "lon"), field_target[4]), 513 | f"bg_hrrr_{variable_names[0]}": ( 514 | ("lat", "lon"), bg_hrrr[0]), 515 | f"bg_hrrr_{variable_names[1]}": ( 516 | ("lat", "lon"), bg_hrrr[1]), 517 | f"bg_hrrr_{variable_names[2]}": ( 518 | ("lat", "lon"), bg_hrrr[2]), 519 | f"bg_hrrr_{variable_names[3]}": ( 520 | ("lat", "lon"), bg_hrrr[3]), 521 | # f"bg_hrrr_{variable_names[4]}": ( 522 | # ("lat", "lon"), bg_hrrr[4]), 523 | "variable_names": ( 524 | ("variable_num"), variable_names), 525 | "mask": ( 526 | ("lat", "lon"), mask[0]), 527 | }, 528 | coords={ 529 | "lat": lat, 530 | "lon": lon, 531 | "time": analysis_time, 532 | "time_window": np.arange(0, params.obs_time_window), 533 | "ensemble_num": np.arange(0, params.ensemble_num), 534 | }, 535 | ) 536 | logging.info(f"Saving result to {save_file_path}") 537 | ds.to_netcdf(save_file_path) 538 | ds.close() 539 | 540 | 541 | def read_sample_file_and_norm_input( 542 | params, 543 | file_path, 544 | hold_out_obs_ratio=0.2 545 | ): 546 | 547 | # %% get statistic 548 | stats_file = os.path.join( 549 | params.data_path, f"stats.csv") 550 | 551 | stats = pd.read_csv(stats_file, index_col=0) 552 | 553 | for vi, var in enumerate(params.inp_hrrr_vars): 554 | if vi == 0: 555 | inp_hrrr_stats = stats[stats["variable"].isin([var])] 556 | else: 557 | inp_hrrr_stats = pd.concat( 558 | [inp_hrrr_stats, stats[stats["variable"].isin([var])]] 559 | ) 560 | 561 | for vi, var in enumerate(params.inp_obs_vars): 562 | if vi == 0: 563 | inp_obs_stats = stats[stats["variable"].isin([var])] 564 | else: 565 | inp_obs_stats = pd.concat( 566 | [inp_obs_stats, stats[stats["variable"].isin([var])]] 567 | ) 568 | 569 | # %% get sample 570 | ds = xr.open_dataset(file_path, engine="netcdf4") 571 | 572 | lat = np.array(ds.coords["lat"].values)[: params.img_size_y] 573 | lon = np.array(ds.coords["lon"].values)[: params.img_size_x] 574 | 575 | # background 576 | inp_hrrr = np.array(ds[params.inp_hrrr_vars].to_array())[ 577 | :, : params.img_size_y, : params.img_size_x 578 | ] 579 | inp_hrrr = np.squeeze(inp_hrrr) 580 | mask = inp_hrrr.copy() 581 | mask[mask != 0] = 1 # set 1 where out of range 582 | mask = mask.astype(bool) # True: out of range 583 | 584 | # baseline for evaluation 585 | bg_hrrr = inp_hrrr.copy() 586 | 587 | # normalization 588 | inp_hrrr = min_max_norm_ignore_extreme_fill_nan( 589 | inp_hrrr, inp_hrrr_stats["min"], inp_hrrr_stats["max"] 590 | ) 591 | 592 | # topography (normed) 593 | topo = np.array(ds[["z"]].to_array())[ 594 | :, :params.img_size_y, :params.img_size_x] 595 | 596 | # satellite 597 | inp_sate = np.array [params.inp_satelite_vars].to_array())[ 598 | :, -params.obs_time_window:, :params.img_size_y, :params.img_size_x 599 | ] 600 | 601 | # Observation 
(normed) 602 | obs = np.array(ds[params.inp_obs_vars].to_array())[ 603 | :, -params.obs_time_window:, :params.img_size_y, :params.img_size_x 604 | ] 605 | # quality control 606 | obs[(obs <= -1) | (obs >= 1)] = 0 607 | 608 | if params.hold_out_obs: 609 | 610 | if params.seed != 0: 611 | np.random.seed(params.seed) 612 | logging.info(f"Using random seed {params.seed}") 613 | 614 | lat_num = obs[0, 0].shape[0] 615 | lon_num = obs[0, 0].shape[1] 616 | 617 | # [lat, lon] -> [lat * lon] 618 | obs_tw_begin = obs[0, 0].reshape(-1) 619 | # find station's indices 620 | obs_index = np.where(~np.isnan(obs_tw_begin))[0] 621 | 622 | obs_num = len(obs_index) 623 | hold_out_num = int(obs_num * hold_out_obs_ratio) 624 | ic(obs_num, hold_out_num) 625 | 626 | # generate mask randomly 627 | np.random.shuffle(obs_index) 628 | hold_out_obs_index = obs_index[:hold_out_num] 629 | input_obs_index = obs_index[hold_out_num:] 630 | ic(len(hold_out_obs_index), hold_out_obs_index) 631 | ic(len(input_obs_index), input_obs_index) 632 | 633 | # Mask (lat, lon), hold_out obs=1, input obs = 0 634 | obs_mask = np.zeros(obs_tw_begin.shape) 635 | obs_mask[hold_out_obs_index] = 1 636 | obs_mask = obs_mask.reshape([lat_num, lon_num]) 637 | 638 | inp_obs = obs * (1 - obs_mask) # observation for input 639 | hold_out_obs = obs * obs_mask # observation excluding the input 640 | 641 | inp_obs_for_eval = inp_obs[:, -1] # -1: analysis time 642 | print(f"inp_obs_for_eval: {inp_obs_for_eval.shape}") 643 | 644 | inp_obs = inp_obs.reshape((-1, params.img_size_y, params.img_size_x)) 645 | 646 | # %% target (normed) 647 | field_target = np.array(ds[params.field_tar_vars].to_array())[ 648 | :, : params.img_size_y, : params.img_size_x 649 | ] 650 | ic(field_target.shape) 651 | 652 | inp = np.concatenate((inp_hrrr, inp_obs, topo), axis=0) 653 | 654 | return ( 655 | inp, # normed 656 | inp_sate, # normed 657 | inp_hrrr, # normed 658 | field_target, # normed 659 | hold_out_obs[:, -1], # normed, -1: analaysis time 660 | inp_obs_for_eval, # normed 661 | bg_hrrr, # un-normed 662 | mask, # out_of_range mask 663 | lat, 664 | lon, 665 | ) 666 | 667 | 668 | if __name__ == "__main__": 669 | parser = argparse.ArgumentParser() 670 | parser.add_argument("--seed", default=0, type=int) 671 | parser.add_argument("--exp_dir", default="", type=str) 672 | parser.add_argument("--test_data_path", default="", type=str) 673 | parser.add_argument("--net_config", default="EncDec", type=str) 674 | parser.add_argument("--hold_out_obs_ratio", type=float, default=0.2) 675 | 676 | args = parser.parse_args() 677 | 678 | config_path = os.path.join(args.exp_dir, "config.yaml") 679 | 680 | params = YParams(config_path, args.net_config) 681 | params["resuming"] = False 682 | params["seed"] = args.seed 683 | params["experiment_dir"] = args.exp_dir 684 | params["test_data_path"] = args.test_data_path 685 | params["best_checkpoint_path"] = os.path.join( 686 | params["experiment_dir"], "training_checkpoints", "best_ckpt.tar" 687 | ) 688 | 689 | # set up logging 690 | log_to_file( 691 | logger_name=None, 692 | log_filename=os.path.join(params["experiment_dir"], "inference.log"), 693 | ) 694 | params.log() 695 | 696 | # get data files and model 697 | test_data_file_paths, inference_times, model = setup(params) 698 | 699 | target_variable = [var.split("_")[1] for var in params.field_tar_vars] 700 | 701 | inference( 702 | params, 703 | target_variable, 704 | test_data_file_paths, 705 | inference_times, 706 | args.hold_out_obs_ratio, 707 | model, 708 | ) 709 | 710 | logging.info("Done") 
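# For reference, the hold-out masking performed in read_sample_file_and_norm_input
# boils down to the standalone sketch below. It is illustrative only: the toy
# observation field, the random seed and the 20% ratio are assumptions, and this
# helper is not called anywhere in the script.
def _hold_out_split_demo(hold_out_obs_ratio=0.2):
    import numpy as np

    rng = np.random.default_rng(0)

    # Toy 2-D observation field: NaN marks grid points without a station.
    obs = np.full((8, 8), np.nan)
    rows = rng.integers(0, 8, size=12)
    cols = rng.integers(0, 8, size=12)
    obs[rows, cols] = rng.normal(size=12)

    flat = obs.reshape(-1)                    # [lat, lon] -> [lat * lon]
    obs_index = np.where(~np.isnan(flat))[0]  # grid points that hold a station
    hold_out_num = int(len(obs_index) * hold_out_obs_ratio)

    # split stations randomly into "fed to the model" and "withheld for scoring"
    rng.shuffle(obs_index)
    hold_out_index = obs_index[:hold_out_num]

    # 1 where an observation is withheld, 0 elsewhere
    obs_mask = np.zeros(flat.shape)
    obs_mask[hold_out_index] = 1
    obs_mask = obs_mask.reshape(obs.shape)

    inp_obs = obs * (1 - obs_mask)   # observations the network is allowed to see
    hold_out_obs = obs * obs_mask    # observations kept only for evaluation
    return inp_obs, hold_out_obs, obs_mask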
711 | 712 | 713 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import wandb 4 | import random 5 | import datetime 6 | import argparse 7 | import numpy as np 8 | 9 | from str2bool import str2bool 10 | from icecream import ic 11 | from shutil import copyfile 12 | from apex import optimizers 13 | from collections import OrderedDict 14 | 15 | import torch 16 | import torch.cuda.amp as amp 17 | import torch.distributed as dist 18 | from torch.nn import functional as F 19 | from torch.nn.parallel import DistributedDataParallel 20 | 21 | from ruamel.yaml import YAML 22 | from ruamel.yaml.comments import CommentedMap as ruamelDict 23 | 24 | from utils.data_loader_multifiles import get_data_loader 25 | from utils.logging_utils import log_to_file 26 | from utils.YParams import YParams 27 | 28 | 29 | class Trainer: 30 | def count_parameters(self): 31 | count_params = 0 32 | for p in self.model.parameters(): 33 | if p.requires_grad: 34 | count_params += p.numel() 35 | 36 | def set_device(self): 37 | if torch.cuda.is_available(): 38 | self.device = torch.cuda.current_device() 39 | else: 40 | self.device = "cpu" 41 | 42 | def __init__(self, params, world_rank): 43 | self.params = params 44 | self.world_rank = world_rank 45 | self.set_device() 46 | 47 | # %% init wandb 48 | if params.log_to_wandb: 49 | wandb.init( 50 | config=params, 51 | name=params.name, 52 | group=params.group, 53 | project=params.project, 54 | entity=params.entity, 55 | settings={"_service_wait": 600, "init_timeout": 600}, 56 | ) 57 | 58 | # %% init gpu 59 | local_rank = int(os.environ["LOCAL_RANK"]) 60 | torch.cuda.set_device(local_rank) 61 | self.device = torch.device("cuda", local_rank) 62 | print("device: %s" % self.device) 63 | 64 | # %% model init 65 | if params.nettype == "EncDec": 66 | from models.encdec import EncDec as model 67 | elif params.nettype == "EncDec_two_encoder": 68 | from models.encdec import EncDec_two_encoder as model 69 | else: 70 | raise Exception("not implemented") 71 | self.model = model(params).to(self.device) 72 | # self.model = model(params).to(local_rank) # for torchrun 73 | 74 | # %% Load data 75 | print("rank %d, begin data loader init" % world_rank) 76 | ( 77 | self.train_data_loader, 78 | self.train_dataset, 79 | self.train_sampler, 80 | ) = get_data_loader( 81 | params, 82 | params.train_data_path, 83 | dist.is_initialized(), 84 | train=True, 85 | ) 86 | ( 87 | self.valid_data_loader, 88 | self.valid_dataset, 89 | self.valid_sampler, 90 | ) = get_data_loader( 91 | params, 92 | params.valid_data_path, 93 | dist.is_initialized(), 94 | train=True, 95 | ) 96 | 97 | # %% optimizer 98 | if params.optimizer_type == "FusedAdam": 99 | self.optimizer = optimizers.FusedAdam( 100 | self.model.parameters(), lr=params.lr) 101 | elif params.optimizer_type == "Adam": 102 | self.optimizer = torch.optim.Adam( 103 | self.model.parameters(), lr=params.lr) 104 | elif params.optimizer_type == "AdamW": 105 | self.optimizer = torch.optim.AdamW( 106 | self.model.parameters(), lr=params.lr) 107 | else: 108 | raise Exception("not implemented") 109 | 110 | if params.enable_amp: 111 | self.gscaler = amp.GradScaler() 112 | 113 | # %% DDP 114 | if dist.is_initialized(): 115 | ic(local_rank) 116 | self.model = DistributedDataParallel( 117 | self.model, 118 | device_ids=[params.local_rank], 119 | output_device=[params.local_rank], 120 | find_unused_parameters=True, 121 | 
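# NOTE: DistributedDataParallel averages gradients across ranks during
# backward(); find_unused_parameters=True additionally tolerates parameters
# that receive no gradient in a given forward pass (e.g. an unused branch),
# at the cost of an extra traversal of the autograd graph.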
) 122 | self.iters = 0 123 | self.startEpoch = 0 124 | self.plot = False 125 | self.plot_img_path = None 126 | 127 | # %% Dynamical Learning rate 128 | if params.scheduler == "ReduceLROnPlateau": 129 | self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( 130 | self.optimizer, 131 | factor=params.lr_reduce_factor, 132 | patience=20, 133 | mode="min", 134 | ) 135 | elif params.scheduler == "CosineAnnealingLR": 136 | self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 137 | self.optimizer, 138 | T_max=params.max_epochs, 139 | last_epoch=self.startEpoch - 1, 140 | ) 141 | else: 142 | self.scheduler = None 143 | 144 | # %% Resume train 145 | if params.resuming: 146 | print(f"Loading checkpoint from {params.best_checkpoint_path}") 147 | self.restore_checkpoint(params.best_checkpoint_path) 148 | 149 | self.epoch = self.startEpoch 150 | 151 | if params.log_to_screen: 152 | print( 153 | f"Number of trainable model parameters: \ 154 | {self.count_parameters()}" 155 | ) 156 | 157 | if params.log_to_wandb: 158 | wandb.watch(self.model) 159 | 160 | def train(self): 161 | if self.params.log_to_screen: 162 | print("Starting Training Loop...") 163 | 164 | # best_valid_obs_loss = 1.0e6 165 | best_train_loss = 1.0e6 166 | 167 | for epoch in range(self.startEpoch, self.params.max_epochs): 168 | if dist.is_initialized(): 169 | # different batch on each GPU 170 | self.train_sampler.set_epoch(epoch) 171 | self.valid_sampler.set_epoch(epoch) 172 | start = time.time() 173 | 174 | # train one epoch 175 | tr_time, data_time, step_time, train_logs = self.train_one_epoch() 176 | self.plot = False 177 | self.plot_img_path = None 178 | current_lr = self.optimizer.param_groups[0]["lr"] 179 | 180 | if self.params.log_to_screen: 181 | print(f"Epoch: {epoch + 1}") 182 | print(f"train data time={data_time}") 183 | print(f"train per step time={step_time}") 184 | print(f"train loss: {train_logs['loss_field']}") 185 | print(f"learning rate: {current_lr}") 186 | 187 | # valid one epoch 188 | if (epoch != 0) and (epoch % self.params.valid_frequency == 0): 189 | valid_time, valid_logs = self.validate_one_epoch() 190 | 191 | if self.params.log_to_screen: 192 | print(f"Epoch: {epoch + 1}") 193 | print(f"Valid time={valid_time}") 194 | print(f"Valid loss={valid_logs['valid_loss_field']}") 195 | 196 | # LR scheduler 197 | if self.params.scheduler == "ReduceLROnPlateau": 198 | self.scheduler.step(valid_logs["valid_loss_field"]) 199 | 200 | if self.params.log_to_wandb: 201 | wandb.log({"lr": current_lr}) 202 | 203 | # save model 204 | if ( 205 | self.world_rank == 0 206 | and epoch % self.params.save_model_freq == 0 207 | and self.params.save_checkpoint 208 | ): 209 | self.save_checkpoint(self.params.checkpoint_path) 210 | 211 | if self.world_rank == 0 and self.params.save_checkpoint: 212 | if train_logs["loss_field"] <= best_train_loss: 213 | print( 214 | "Loss improved from {} to {}".format( 215 | best_train_loss, train_logs["loss_field"] 216 | ) 217 | ) 218 | best_train_loss = train_logs["loss_field"] 219 | 220 | start = time.time() 221 | self.save_checkpoint(self.params.best_checkpoint_path) 222 | print(f"save model time: {time.time() - start}") 223 | 224 | def loss_function( 225 | self, 226 | pre_field, 227 | tar_field, 228 | tar_obs, 229 | tar_field_obs, 230 | field_mask=None, 231 | obs_tar_mask=None, 232 | mask_out_of_range=True, 233 | ): 234 | """ 235 | pre_field: model's output 236 | tar_field: label, after normalization 237 | """ 238 | 239 | if mask_out_of_range: 240 | pre_field = torch.masked_fill( 241 | 
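# With mask_out_of_range=True, the prediction and both targets are zeroed at
# grid points excluded by field_mask before any MSE term is evaluated, so
# those points contribute nothing to the loss.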
input=pre_field, mask=~field_mask, value=0 242 | ) # fill input with 0 where field_mask is True. 243 | tar_field = torch.masked_fill( 244 | input=tar_field, mask=~field_mask, value=0 245 | ) # fill input with 0 where field_mask is True. 246 | tar_field_obs = torch.masked_fill( 247 | input=tar_field_obs, mask=~field_mask, value=0 248 | ) # fill input with 0 where field_mask is True. 249 | 250 | # type 1 loss 251 | loss_field = F.mse_loss( 252 | pre_field, tar_field) 253 | loss_field_channel_wise = F.mse_loss( 254 | pre_field, tar_field, reduction="none") 255 | loss_field_channel_wise = torch.mean( 256 | loss_field_channel_wise, dim=(0, 2, 3)) 257 | 258 | # type 2 loss 259 | loss_field_obs = F.mse_loss( 260 | pre_field, tar_field_obs) 261 | 262 | # type 3 loss 263 | pre_field = torch.masked_fill( 264 | input=pre_field, mask=~obs_tar_mask, value=0 265 | ) # fill input with 0 where mask is True. 266 | tar_obs = torch.masked_fill( 267 | input=tar_obs, mask=~obs_tar_mask, value=0) 268 | loss_obs = F.mse_loss( 269 | pre_field, tar_obs) 270 | loss_obs_channel_wise = F.mse_loss( 271 | pre_field, tar_obs, reduction="none") 272 | loss_obs_channel_wise = torch.mean( 273 | loss_obs_channel_wise, dim=(0, 2, 3)) 274 | 275 | return { 276 | "loss_field": loss_field, 277 | "loss_field_channel_wise": loss_field_channel_wise, 278 | "loss_obs": loss_obs, 279 | "loss_obs_channel_wise": loss_obs_channel_wise, 280 | "loss_field_obs": loss_field_obs, 281 | } 282 | 283 | def train_one_epoch(self): 284 | print("Training...") 285 | self.epoch += 1 286 | if self.params.resuming: 287 | self.resumeEpoch += 1 288 | tr_time = 0 289 | data_time = 0 290 | steps_in_one_epoch = 0 291 | loss_field = 0 292 | loss_obs = 0 293 | loss_field_obs = 0 294 | loss_field_channel_wise = torch.zeros( 295 | len(self.params.target_vars), device=self.device, dtype=float 296 | ) 297 | loss_obs_channel_wise = torch.zeros( 298 | len(self.params.target_vars), device=self.device, dtype=float 299 | ) 300 | 301 | self.model.train() 302 | for i, data in enumerate(self.train_data_loader, 0): 303 | self.iters += 1 304 | steps_in_one_epoch += 1 305 | data_start = time.time() 306 | 307 | if self.params.nettype == "EncDec_two_encoder": 308 | ( 309 | inp, 310 | inp_sate, 311 | target_field, 312 | target_obs, 313 | target_field_obs, 314 | inp_hrrr, 315 | _, 316 | _, 317 | field_mask, 318 | obs_tar_mask, 319 | ) = data 320 | if self.params.nettype == "EncDec": 321 | ( 322 | inp, 323 | target_field, 324 | target_obs, 325 | target_field_obs, 326 | inp_hrrr, 327 | _, 328 | _, 329 | field_mask, 330 | obs_tar_mask, 331 | ) = data 332 | 333 | data_time += time.time() - data_start 334 | tr_start = time.time() 335 | 336 | self.model.zero_grad() 337 | with amp.autocast(self.params.enable_amp): 338 | inp = inp.to(self.device, dtype=torch.float) 339 | inp_hrrr = inp_hrrr.to(self.device, dtype=torch.float) 340 | target_field = target_field.to(self.device, dtype=torch.float) 341 | target_obs = target_obs.to( 342 | self.device, dtype=torch.float) 343 | target_field_obs = target_field_obs.to( 344 | self.device, dtype=torch.float) 345 | field_mask = torch.as_tensor( 346 | field_mask, dtype=torch.bool, device=self.device 347 | ) 348 | obs_tar_mask = torch.as_tensor( 349 | obs_tar_mask, dtype=torch.bool, device=self.device 350 | ) 351 | 352 | if self.params.nettype == "EncDec": 353 | gen = self.model(inp) 354 | if self.params.nettype == "EncDec_two_encoder": 355 | inp_sate = inp_sate.to(self.device, dtype=torch.float) 356 | gen = self.model(inp, inp_sate) 357 | 
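# Note: Tensor.to() is not in-place; the call below returns a new tensor that
# is not kept. gen already lives on self.device, so the loss terms are
# computed on the tensor produced by the forward pass above.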
gen.to(self.device, dtype=torch.float) 358 | 359 | loss = self.loss_function( 360 | pre_field=gen, 361 | tar_field=target_field, 362 | tar_obs=target_obs, 363 | tar_field_obs=target_field_obs, 364 | field_mask=field_mask, 365 | obs_tar_mask=obs_tar_mask, 366 | ) 367 | 368 | loss_field += loss["loss_field"] 369 | loss_obs += loss["loss_obs"] 370 | loss_field_obs += loss["loss_field_obs"] 371 | loss_field_channel_wise += loss["loss_field_channel_wise"] 372 | loss_obs_channel_wise += loss["loss_obs_channel_wise"] 373 | 374 | self.optimizer.zero_grad() 375 | if self.params.target == "obs": 376 | # target: sparse observations 377 | if self.params.enable_amp: 378 | self.gscaler.scale(loss["loss_obs"]).backward() 379 | self.gscaler.step(self.optimizer) 380 | else: 381 | loss["loss_obs"].backward() 382 | self.optimizer.step() 383 | if self.params.target == "analysis": 384 | # target: grided fields 385 | if self.params.enable_amp: 386 | self.gscaler.scale(loss["loss_field"]).backward() 387 | self.gscaler.step(self.optimizer) 388 | else: 389 | loss["loss_field"].backward() 390 | self.optimizer.step() 391 | if self.params.target == "analysis_obs": 392 | # target: grided fields + sparse observations 393 | if self.params.enable_amp: 394 | self.gscaler.scale(loss["loss_field_obs"]).backward() 395 | self.gscaler.step(self.optimizer) 396 | else: 397 | loss["loss_field_obs"].backward() 398 | self.optimizer.step() 399 | 400 | if self.params.enable_amp: 401 | self.gscaler.update() 402 | 403 | tr_time += time.time() - tr_start 404 | 405 | logs = { 406 | "loss_field": loss_field / steps_in_one_epoch, 407 | "loss_obs": loss_obs / steps_in_one_epoch, 408 | "loss_field_obs": loss_field_obs / steps_in_one_epoch, 409 | } 410 | for i_, var_ in enumerate(self.params.target_vars): 411 | tmp_var_1 = loss_obs_channel_wise[i_] / steps_in_one_epoch 412 | tmp_var_2 = loss_field_channel_wise[i_] / steps_in_one_epoch 413 | logs[f"loss_obs_{var_}"] = tmp_var_1 414 | logs[f"loss_field_{var_}"] = tmp_var_2 415 | 416 | if dist.is_initialized(): 417 | for key in sorted(logs.keys()): 418 | dist.all_reduce(logs[key].detach()) 419 | logs[key] = float(logs[key] / dist.get_world_size()) 420 | 421 | if self.params.log_to_wandb: 422 | wandb.log(logs, step=self.epoch) 423 | 424 | # time of one step in epoch 425 | step_time = tr_time / steps_in_one_epoch 426 | 427 | return tr_time, data_time, step_time, logs 428 | 429 | def validate_one_epoch(self): 430 | print("validating...") 431 | self.model.eval() 432 | 433 | valid_buff = torch.zeros((4), dtype=torch.float32, device=self.device) 434 | valid_loss_field = valid_buff[0].view(-1) 435 | valid_loss_obs = valid_buff[1].view(-1) 436 | valid_loss_field_obs = valid_buff[2].view(-1) 437 | valid_steps = valid_buff[3].view(-1) 438 | 439 | valid_start = time.time() 440 | with torch.no_grad(): 441 | for i, data in enumerate(self.valid_data_loader, 0): 442 | self.plot = False 443 | self.plot_img_path = False 444 | 445 | if self.params.nettype == "EncDec_two_encoder": 446 | ( 447 | inp, 448 | inp_sate, 449 | target_field, 450 | target_obs, 451 | target_field_obs, 452 | inp_hrrr, 453 | _, 454 | _, 455 | field_mask, 456 | obs_tar_mask, 457 | ) = data 458 | if self.params.nettype == "EncDec": 459 | ( 460 | inp, 461 | target_field, 462 | target_obs, 463 | target_field_obs, 464 | inp_hrrr, 465 | _, 466 | _, 467 | field_mask, 468 | obs_tar_mask, 469 | ) = data 470 | 471 | inp = inp.to( 472 | self.device, dtype=torch.float) 473 | inp_hrrr = inp_hrrr.to( 474 | self.device, dtype=torch.float) 475 | 
target_field = target_field.to( 476 | self.device, dtype=torch.float) 477 | target_obs = target_obs.to( 478 | self.device, dtype=torch.float) 479 | target_field_obs = target_field_obs.to( 480 | self.device, dtype=torch.float) 481 | field_mask = field_mask.to( 482 | self.device, dtype=torch.bool) 483 | obs_tar_mask = obs_tar_mask.to( 484 | self.device, dtype=torch.bool) 485 | 486 | if self.params.nettype == "EncDec": 487 | gen = self.model(inp) 488 | if self.params.nettype == "EncDec_two_encoder": 489 | inp_sate = inp_sate.to( 490 | self.device, dtype=torch.float) 491 | gen = self.model(inp, inp_sate) 492 | gen.to(self.device, dtype=torch.float) 493 | 494 | loss = self.loss_function( 495 | pre_field=gen, 496 | tar_field=target_field, 497 | tar_obs=target_obs, 498 | tar_field_obs=target_field_obs, 499 | field_mask=field_mask, 500 | obs_tar_mask=obs_tar_mask, 501 | ) 502 | 503 | valid_steps += 1.0 504 | valid_loss_field += loss["loss_field"] 505 | valid_loss_obs += loss["loss_obs"] 506 | valid_loss_field_obs += loss["loss_field_obs"] 507 | 508 | if dist.is_initialized(): 509 | dist.all_reduce(valid_buff) 510 | 511 | # divide by number of steps 512 | valid_buff[0:3] = valid_buff[0:3] / valid_buff[3] 513 | valid_buff_cpu = valid_buff.detach().cpu().numpy() 514 | logs = { 515 | "valid_loss_field": valid_buff_cpu[0], 516 | "valid_loss_obs": valid_buff_cpu[1], 517 | "valid_loss_field_obs": valid_buff_cpu[2], 518 | } 519 | 520 | valid_time = time.time() - valid_start 521 | 522 | if self.params.log_to_wandb: 523 | wandb.log(logs, step=self.epoch) 524 | 525 | return valid_time, logs 526 | 527 | def load_model(self, model_path): 528 | if self.params.log_to_screen: 529 | print("Loading the model weights from {}".format(model_path)) 530 | 531 | checkpoint = torch.load( 532 | model_path, map_location="cuda:{}".format(self.params.local_rank) 533 | ) 534 | 535 | if dist.is_initialized(): 536 | self.model.load_state_dict(checkpoint["model_state"]) 537 | else: 538 | new_model_state = OrderedDict() 539 | if "model_state" in checkpoint: 540 | model_key = "model_state" 541 | else: 542 | model_key = "state_dict" 543 | 544 | for key in checkpoint[model_key].keys(): 545 | if "module." 
in key: 546 | # model was stored using ddp which prepends module 547 | name = str(key[7:]) 548 | new_model_state[name] = checkpoint[model_key][key] 549 | else: 550 | new_model_state[key] = checkpoint[model_key][key] 551 | self.model.load_state_dict(new_model_state) 552 | self.model.eval() 553 | 554 | def save_checkpoint(self, checkpoint_path, model=None): 555 | """We intentionally require a checkpoint_dir to be passed 556 | in order to allow Ray Tune to use this function""" 557 | 558 | if not model: 559 | model = self.model 560 | 561 | print("Saving model to {}".format(checkpoint_path)) 562 | torch.save( 563 | { 564 | "iters": self.iters, 565 | "epoch": self.epoch, 566 | "model_state": model.state_dict(), 567 | "optimizer_state_dict": self.optimizer.state_dict(), 568 | }, 569 | checkpoint_path, 570 | ) 571 | 572 | def restore_checkpoint(self, checkpoint_path): 573 | checkpoint = torch.load( 574 | checkpoint_path, 575 | map_location="cuda:{}".format(self.params.local_rank) 576 | ) 577 | try: 578 | self.model.load_state_dict(checkpoint["model_state"]) 579 | except ValueError: 580 | new_state_dict = OrderedDict() 581 | for key, val in checkpoint["model_state"].items(): 582 | name = key[7:] 583 | new_state_dict[name] = val 584 | self.model.load_state_dict(new_state_dict) 585 | self.iters = checkpoint["iters"] 586 | self.startEpoch = checkpoint["epoch"] 587 | self.resumeEpoch = 0 588 | if self.params.resuming: 589 | # restore checkpoint is used for finetuning as well as resuming. 590 | self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) 591 | # uses config specified lr. 592 | for g in self.optimizer.param_groups: 593 | g["lr"] = self.params.lr 594 | 595 | 596 | def set_random_seed(seed): 597 | random.seed(seed) 598 | np.random.seed(seed) 599 | torch.manual_seed(seed) 600 | torch.cuda.manual_seed_all(seed) 601 | 602 | 603 | if __name__ == "__main__": 604 | parser = argparse.ArgumentParser() 605 | parser.add_argument( 606 | "--yaml_config", 607 | default="./config/experiment.yaml", 608 | type=str, 609 | ) 610 | parser.add_argument("--exp_dir", default="./exp_us_t2m", type=str) 611 | parser.add_argument("--run_num", default="00", type=str) 612 | parser.add_argument("--resume", default=False, type=str2bool) 613 | parser.add_argument("--device", default="GPU", type=str) 614 | parser.add_argument("--seed", default=42, type=int) 615 | parser.add_argument("--max_epochs", default=1200, type=int) 616 | parser.add_argument("--lr", default=0.001, type=float) 617 | parser.add_argument("--lr_reduce_factor", default=0.9, type=float) 618 | parser.add_argument("--target", default="obs", type=str) 619 | parser.add_argument("--hold_out_obs_ratio", default=0.1, type=float) 620 | parser.add_argument("--obs_mask_seed", default=1, type=int) 621 | parser.add_argument("--wandb_api_key", type=str) 622 | parser.add_argument("--batch_size", default=8, type=int) 623 | parser.add_argument("--wandb_group", default="us_t2m", type=str) 624 | parser.add_argument("--net_config", default="VAE-AFNO", type=str) 625 | parser.add_argument("--enable_amp", action="store_true") 626 | parser.add_argument("--epsilon_factor", default=0, type=float) 627 | parser.add_argument("--local-rank", default=-1, type=int) 628 | args = parser.parse_args() 629 | 630 | os.environ["WANDB_API_KEY"] = args.wandb_api_key 631 | os.environ["WANDB_MODE"] = "online" 632 | 633 | if args.resume: 634 | params = YParams( 635 | os.path.join( 636 | args.exp_dir, 637 | args.net_config, 638 | args.run_num, 639 | "config.yaml"), 640 | args.net_config, 641 
| False, 642 | ) 643 | else: 644 | params = YParams( 645 | os.path.abspath(args.yaml_config), 646 | args.net_config, 647 | False) 648 | 649 | params["target"] = args.target 650 | params["hold_out_obs_ratio"] = args.hold_out_obs_ratio 651 | params["obs_mask_seed"] = args.obs_mask_seed 652 | params["lr_reduce_factor"] = args.lr_reduce_factor 653 | params["max_epochs"] = args.max_epochs 654 | params["world_size"] = 1 655 | params["lr"] = args.lr 656 | 657 | if "WORLD_SIZE" in os.environ: 658 | params["world_size"] = int(os.environ["WORLD_SIZE"]) 659 | print("world_size :", params["world_size"]) 660 | 661 | if args.device == "GPU": 662 | print("Initialize distributed process group...") 663 | torch.distributed.init_process_group( 664 | backend="nccl", 665 | timeout=datetime.timedelta(seconds=5400) 666 | ) 667 | local_rank = int(os.environ["LOCAL_RANK"]) 668 | torch.cuda.set_device(local_rank) 669 | 670 | # device = torch.device('cuda', args.local_rank) 671 | params["local_rank"] = local_rank 672 | torch.backends.cudnn.benchmark = True 673 | 674 | world_rank = dist.get_rank() # get current process's ID 675 | print(f"world_rank: {world_rank}") 676 | 677 | set_random_seed(args.seed) 678 | params["nettype"] = args.net_config 679 | params["global_batch_size"] = args.batch_size 680 | params["batch_size"] = int( 681 | args.batch_size // params["world_size"] 682 | ) # batch size must be divisible by the number of gpu's 683 | # Automatic Mixed Precision Training 684 | params["enable_amp"] = args.enable_amp 685 | 686 | # Set up directory 687 | expDir = os.path.join( 688 | args.exp_dir, 689 | args.net_config, 690 | str(args.run_num)) 691 | 692 | # start training 693 | if (not args.resume) and ( 694 | (world_rank == 0 and args.device == "GPU") or args.device == "CPU" 695 | ): 696 | os.makedirs(expDir, exist_ok=True) 697 | os.makedirs( 698 | os.path.join(expDir, "training_checkpoints"), 699 | exist_ok=True) 700 | copyfile( 701 | os.path.abspath(args.yaml_config), 702 | os.path.join(expDir, "config.yaml")) 703 | 704 | params["experiment_dir"] = os.path.abspath(expDir) 705 | params["checkpoint_path"] = os.path.join( 706 | expDir, "training_checkpoints", "ckpt.tar") 707 | params["best_checkpoint_path"] = os.path.join( 708 | expDir, "training_checkpoints", "best_ckpt.tar") 709 | 710 | # Do not comment this line out please: 711 | args.resuming = True if os.path.isfile(params.checkpoint_path) else False 712 | params["resuming"] = args.resuming 713 | 714 | # experiment name 715 | params["name"] = str(args.run_num) 716 | 717 | # wandb setting 718 | params["entity"] = "your entity" # team name 719 | params["project"] = "your project" # project name 720 | params["group"] = args.wandb_group + "_" + args.net_config 721 | 722 | # if world_rank == 0: 723 | log_to_file( 724 | logger_name=None, 725 | log_filename=os.path.join(expDir, "train.log")) 726 | params.log() 727 | 728 | params["log_to_wandb"] = (world_rank == 0) and params["log_to_wandb"] 729 | params["log_to_screen"] = (world_rank == 0) and params["log_to_screen"] 730 | 731 | if world_rank == 0: 732 | hparams = ruamelDict() 733 | yaml = YAML() 734 | for key, value in params.params.items(): 735 | hparams[str(key)] = str(value) 736 | with open(os.path.join(expDir, "hyperparams.yaml"), "w") as hpfile: 737 | yaml.dump(hparams, hpfile) 738 | 739 | trainer = Trainer(params, world_rank) 740 | trainer.train() 741 | print("DONE ---- rank %d" % world_rank) 742 | -------------------------------------------------------------------------------- /models/encdec.py: 
-------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.utils.checkpoint as checkpoint 6 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 7 | 8 | 9 | class Mlp(nn.Module): 10 | def __init__( 11 | self, 12 | in_features, 13 | hidden_features=None, 14 | out_features=None, 15 | act_layer=nn.GELU, 16 | drop=0.0, 17 | ): 18 | super().__init__() 19 | out_features = out_features or in_features 20 | hidden_features = hidden_features or in_features 21 | self.fc1 = nn.Linear(in_features, hidden_features) 22 | self.act = act_layer() 23 | self.fc2 = nn.Linear(hidden_features, out_features) 24 | self.drop = nn.Dropout(drop) 25 | 26 | def forward(self, x): 27 | x = self.fc1(x) 28 | x = self.act(x) 29 | x = self.drop(x) 30 | x = self.fc2(x) 31 | x = self.drop(x) 32 | return x 33 | 34 | 35 | def window_partition(x, window_size): 36 | """ 37 | Args: 38 | x: (B, H, W, C) 39 | window_size (int): window size 40 | 41 | Returns: 42 | windows: (num_windows*B, window_size, window_size, C) 43 | """ 44 | B, H, W, C = x.shape 45 | # ic(x.shape) 46 | x = x.view( 47 | B, H // window_size, window_size, 48 | W // window_size, window_size, C) 49 | windows = ( 50 | x.permute(0, 1, 3, 2, 4, 5).contiguous().view( 51 | -1, window_size, window_size, C) 52 | ) 53 | return windows 54 | 55 | 56 | def window_reverse(windows, window_size, H, W): 57 | """ 58 | Args: 59 | windows: (num_windows*B, window_size, window_size, C) 60 | window_size (int): Window size 61 | H (int): Height of image 62 | W (int): Width of image 63 | 64 | Returns: 65 | x: (B, H, W, C) 66 | """ 67 | B = int(windows.shape[0] / (H * W / window_size / window_size)) 68 | x = windows.view( 69 | B, H // window_size, 70 | W // window_size, window_size, 71 | window_size, -1 72 | ) 73 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) 74 | return x 75 | 76 | 77 | class WindowAttention(nn.Module): 78 | r"""Window based multi-head self attention (W-MSA) module 79 | with relative position bias. 80 | It supports both of shifted and non-shifted window. 81 | 82 | Args: 83 | dim (int): Number of input channels. 84 | window_size (tuple[int]): 85 | The height and width of the window. 86 | num_heads (int): 87 | Number of attention heads. 88 | qkv_bias (bool, optional): 89 | If True, add a learnable bias to query, key, value. 90 | Default: True 91 | qk_scale (float | None, optional): 92 | Override default qk scale of head_dim ** -0.5 if set 93 | attn_drop (float, optional): 94 | Dropout ratio of attention weight. Default: 0.0 95 | proj_drop (float, optional): 96 | Dropout ratio of output. 
Default: 0.0 97 | """ 98 | 99 | def __init__( 100 | self, 101 | dim, 102 | window_size, 103 | num_heads, 104 | qkv_bias=True, 105 | qk_scale=None, 106 | attn_drop=0.0, 107 | proj_drop=0.0, 108 | ): 109 | 110 | super().__init__() 111 | self.dim = dim 112 | self.window_size = window_size # Wh, Ww 113 | self.num_heads = num_heads 114 | head_dim = dim // num_heads 115 | self.scale = qk_scale or head_dim**-0.5 116 | 117 | # define a parameter table of relative position bias 118 | self.relative_position_bias_table = nn.Parameter( 119 | torch.zeros( 120 | (2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads) 121 | ) # 2*Wh-1 * 2*Ww-1, nH 122 | 123 | # get pair-wise relative position index 124 | # for each token inside the window 125 | coords_h = torch.arange(self.window_size[0]) 126 | coords_w = torch.arange(self.window_size[1]) 127 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww 128 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww 129 | relative_coords = ( 130 | coords_flatten[:, :, None] - coords_flatten[:, None, :] 131 | ) # 2, Wh*Ww, Wh*Ww 132 | relative_coords = relative_coords.permute( 133 | 1, 2, 0 134 | ).contiguous() # Wh*Ww, Wh*Ww, 2 135 | # shift to start from 0 136 | relative_coords[:, :, 0] += self.window_size[0] - 1 137 | relative_coords[:, :, 1] += self.window_size[1] - 1 138 | relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 139 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww 140 | self.register_buffer( 141 | "relative_position_index", relative_position_index) 142 | 143 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 144 | self.attn_drop = nn.Dropout(attn_drop) 145 | self.proj = nn.Linear(dim, dim) 146 | 147 | self.proj_drop = nn.Dropout(proj_drop) 148 | 149 | trunc_normal_(self.relative_position_bias_table, std=0.02) 150 | self.softmax = nn.Softmax(dim=-1) 151 | 152 | def forward(self, x, mask=None): 153 | """ 154 | Args: 155 | x: input features with shape of (num_windows*B, N, C) 156 | mask: (0/-inf) mask with shape of 157 | (num_windows, Wh*Ww, Wh*Ww) or None 158 | """ 159 | B_, N, C = x.shape 160 | qkv = ( 161 | self.qkv(x) 162 | .reshape(B_, N, 3, self.num_heads, C // self.num_heads) 163 | .permute(2, 0, 3, 1, 4) 164 | ) 165 | q, k, v = ( 166 | qkv[0], 167 | qkv[1], 168 | qkv[2], 169 | ) # make torchscript happy (cannot use tensor as tuple) 170 | 171 | q = q * self.scale 172 | attn = q @ k.transpose(-2, -1) 173 | 174 | relative_position_bias = self.relative_position_bias_table[ 175 | self.relative_position_index.view(-1) 176 | ].view( 177 | self.window_size[0] * self.window_size[1], 178 | self.window_size[0] * self.window_size[1], 179 | -1, 180 | ) # Wh*Ww,Wh*Ww,nH 181 | relative_position_bias = relative_position_bias.permute( 182 | 2, 0, 1 183 | ).contiguous() # nH, Wh*Ww, Wh*Ww 184 | attn = attn + relative_position_bias.unsqueeze(0) 185 | 186 | if mask is not None: 187 | nW = mask.shape[0] 188 | attn = attn.view( 189 | B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze( 190 | 1 191 | ).unsqueeze(0) 192 | attn = attn.view(-1, self.num_heads, N, N) 193 | attn = self.softmax(attn) 194 | else: 195 | attn = self.softmax(attn) 196 | 197 | attn = self.attn_drop(attn) 198 | 199 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C) 200 | x = self.proj(x) 201 | x = self.proj_drop(x) 202 | return x 203 | 204 | def extra_repr(self) -> str: 205 | return f"dim={self.dim}, \ 206 | window_size={self.window_size}, \ 207 | num_heads={self.num_heads}" 208 | 209 | def flops(self, N): 210 | # calculate flops for 1 window with 
token length of N 211 | flops = 0 212 | # qkv = self.qkv(x) 213 | flops += N * self.dim * 3 * self.dim 214 | # attn = (q @ k.transpose(-2, -1)) 215 | flops += self.num_heads * N * (self.dim // self.num_heads) * N 216 | # x = (attn @ v) 217 | flops += self.num_heads * N * N * (self.dim // self.num_heads) 218 | # x = self.proj(x) 219 | flops += N * self.dim * self.dim 220 | return flops 221 | 222 | 223 | class SwinTransformerBlock(nn.Module): 224 | r"""Swin Transformer Block. 225 | 226 | Args: 227 | dim (int): Number of input channels. 228 | input_resolution (tuple[int]): 229 | Input resulotion. 230 | num_heads (int): 231 | Number of attention heads. 232 | window_size (int): 233 | Window size. 234 | shift_size (int): 235 | Shift size for SW-MSA. 236 | mlp_ratio (float): 237 | Ratio of mlp hidden dim to embedding dim. 238 | qkv_bias (bool, optional): 239 | If True, add a learnable bias to query, key, value. 240 | Default: True 241 | qk_scale (float | None, optional): 242 | Override default qk scale of head_dim ** -0.5 if set. 243 | drop (float, optional): 244 | Dropout rate. 245 | Default: 0.0 246 | attn_drop (float, optional): 247 | Attention dropout rate. 248 | Default: 0.0 249 | drop_path (float, optional): 250 | Stochastic depth rate. 251 | Default: 0.0 252 | act_layer (nn.Module, optional): 253 | Activation layer. 254 | Default: nn.GELU 255 | norm_layer (nn.Module, optional): 256 | Normalization layer. 257 | Default: nn.LayerNorm 258 | """ 259 | 260 | def __init__( 261 | self, 262 | dim, 263 | input_resolution, 264 | num_heads, 265 | window_size=7, 266 | shift_size=0, 267 | mlp_ratio=4.0, 268 | qkv_bias=True, 269 | qk_scale=None, 270 | drop=0.0, 271 | attn_drop=0.0, 272 | drop_path=0.0, 273 | act_layer=nn.GELU, 274 | norm_layer=nn.LayerNorm, 275 | ): 276 | super().__init__() 277 | self.dim = dim 278 | self.input_resolution = input_resolution 279 | self.num_heads = num_heads 280 | self.window_size = window_size 281 | self.shift_size = shift_size 282 | self.mlp_ratio = mlp_ratio 283 | if min(self.input_resolution) <= self.window_size: 284 | # if window size is larger than input resolution, 285 | # we don't partition windows 286 | self.shift_size = 0 287 | self.window_size = min(self.input_resolution) 288 | assert ( 289 | 0 <= self.shift_size < self.window_size 290 | ), "shift_size must in 0-window_size" 291 | 292 | self.norm1 = norm_layer(dim) 293 | self.attn = WindowAttention( 294 | dim, 295 | window_size=to_2tuple(self.window_size), 296 | num_heads=num_heads, 297 | qkv_bias=qkv_bias, 298 | qk_scale=qk_scale, 299 | attn_drop=attn_drop, 300 | proj_drop=drop, 301 | ) 302 | 303 | self.drop_path = DropPath( 304 | drop_path) if drop_path > 0.0 else nn.Identity() 305 | self.norm2 = norm_layer(dim) 306 | mlp_hidden_dim = int(dim * mlp_ratio) 307 | self.mlp = Mlp( 308 | in_features=dim, 309 | hidden_features=mlp_hidden_dim, 310 | act_layer=act_layer, 311 | drop=drop, 312 | ) 313 | 314 | if self.shift_size > 0: 315 | attn_mask = self.calculate_mask(self.input_resolution) 316 | else: 317 | attn_mask = None 318 | 319 | self.register_buffer("attn_mask", attn_mask) 320 | 321 | def calculate_mask(self, x_size): 322 | # calculate attention mask for SW-MSA 323 | H, W = x_size 324 | img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 325 | h_slices = ( 326 | slice(0, -self.window_size), 327 | slice(-self.window_size, -self.shift_size), 328 | slice(-self.shift_size, None), 329 | ) 330 | w_slices = ( 331 | slice(0, -self.window_size), 332 | slice(-self.window_size, -self.shift_size), 333 | 
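# Each axis is cut into three bands: the region untouched by the cyclic
# shift, the band that now borders wrapped-around content, and the shifted
# remainder. Every (h, w) band combination gets its own label below, and
# token pairs inside a window that carry different labels are masked out
# with -100 so they cannot attend to one another.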
slice(-self.shift_size, None), 334 | ) 335 | cnt = 0 336 | for h in h_slices: 337 | for w in w_slices: 338 | img_mask[:, h, w, :] = cnt 339 | cnt += 1 340 | 341 | mask_windows = window_partition( 342 | img_mask, self.window_size 343 | ) # nW, window_size, window_size, 1 344 | mask_windows = mask_windows.view( 345 | -1, self.window_size * self.window_size) 346 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) 347 | attn_mask = attn_mask.masked_fill( 348 | attn_mask != 0, float(-100.0)).masked_fill( 349 | attn_mask == 0, float(0.0) 350 | ) 351 | 352 | return attn_mask 353 | 354 | def forward(self, x, x_size): 355 | H, W = x_size 356 | B, L, C = x.shape 357 | # assert L == H * W, "input feature has wrong size" 358 | 359 | shortcut = x 360 | x = self.norm1(x) 361 | x = x.view(B, H, W, C) 362 | 363 | # cyclic shift 364 | if self.shift_size > 0: 365 | shifted_x = torch.roll( 366 | x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2) 367 | ) 368 | else: 369 | shifted_x = x 370 | 371 | # partition windows 372 | x_windows = window_partition( 373 | shifted_x, self.window_size 374 | ) # nW*B, window_size, window_size, C 375 | x_windows = x_windows.view( 376 | -1, self.window_size * self.window_size, C 377 | ) # nW*B, window_size*window_size, C 378 | 379 | # W-MSA/SW-MSA (to be compatible for testing 380 | # on images whose shapes are the multiple of window size 381 | if self.input_resolution == x_size: 382 | attn_windows = self.attn( 383 | x_windows, mask=self.attn_mask 384 | ) # nW*B, window_size*window_size, C 385 | else: 386 | attn_windows = self.attn( 387 | x_windows, mask=self.calculate_mask(x_size).to(x.device) 388 | ) 389 | 390 | # merge windows 391 | attn_windows = attn_windows.view( 392 | -1, self.window_size, self.window_size, C) 393 | shifted_x = window_reverse( 394 | attn_windows, self.window_size, H, W) # B H' W' C 395 | 396 | # reverse cyclic shift 397 | if self.shift_size > 0: 398 | x = torch.roll( 399 | shifted_x, 400 | shifts=(self.shift_size, self.shift_size), 401 | dims=(1, 2) 402 | ) 403 | else: 404 | x = shifted_x 405 | x = x.view(B, H * W, C) 406 | 407 | # FFN 408 | x = shortcut + self.drop_path(x) 409 | x = x + self.drop_path(self.mlp(self.norm2(x))) 410 | 411 | return x 412 | 413 | def extra_repr(self) -> str: 414 | return ( 415 | f"dim={self.dim}, \ 416 | input_resolution={self.input_resolution}, \ 417 | num_heads={self.num_heads}, " 418 | f"window_size={self.window_size}, \ 419 | shift_size={self.shift_size}, \ 420 | mlp_ratio={self.mlp_ratio}" 421 | ) 422 | 423 | def flops(self): 424 | flops = 0 425 | H, W = self.input_resolution 426 | # norm1 427 | flops += self.dim * H * W 428 | # W-MSA/SW-MSA 429 | nW = H * W / self.window_size / self.window_size 430 | flops += nW * self.attn.flops(self.window_size * self.window_size) 431 | # mlp 432 | flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio 433 | # norm2 434 | flops += self.dim * H * W 435 | return flops 436 | 437 | 438 | class PatchMerging(nn.Module): 439 | r"""Patch Merging Layer. 440 | 441 | Args: 442 | input_resolution (tuple[int]): 443 | Resolution of input feature. 444 | dim (int): 445 | Number of input channels. 446 | norm_layer (nn.Module, optional): 447 | Normalization layer. 
448 | Default: nn.LayerNorm 449 | """ 450 | 451 | def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): 452 | super().__init__() 453 | self.input_resolution = input_resolution 454 | self.dim = dim 455 | self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) 456 | self.norm = norm_layer(4 * dim) 457 | 458 | def forward(self, x): 459 | """ 460 | x: B, H*W, C 461 | """ 462 | H, W = self.input_resolution 463 | B, L, C = x.shape 464 | assert L == H * W, "input feature has wrong size" 465 | assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." 466 | 467 | x = x.view(B, H, W, C) 468 | 469 | x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C 470 | x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C 471 | x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C 472 | x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C 473 | x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C 474 | x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C 475 | 476 | x = self.norm(x) 477 | x = self.reduction(x) 478 | 479 | return x 480 | 481 | def extra_repr(self) -> str: 482 | return f"input_resolution={self.input_resolution}, dim={self.dim}" 483 | 484 | def flops(self): 485 | H, W = self.input_resolution 486 | flops = H * W * self.dim 487 | flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim 488 | return flops 489 | 490 | 491 | class BasicLayer(nn.Module): 492 | """A basic Swin Transformer layer for one stage. 493 | 494 | Args: 495 | dim (int): Number of input channels. 496 | input_resolution (tuple[int]): Input resolution. 497 | depth (int): Number of blocks. 498 | num_heads (int): 499 | Number of attention heads. 500 | window_size (int): 501 | Local window size. 502 | mlp_ratio (float): 503 | Ratio of mlp hidden dim to embedding dim. 504 | qkv_bias (bool, optional): 505 | If True, add a learnable bias to query, key, value. 506 | Default: True 507 | qk_scale (float | None, optional): 508 | Override default qk scale of head_dim ** -0.5 if set. 509 | drop (float, optional): 510 | Dropout rate. Default: 0.0 511 | attn_drop (float, optional): 512 | Attention dropout rate. Default: 0.0 513 | drop_path (float | tuple[float], optional): 514 | Stochastic depth rate. Default: 0.0 515 | norm_layer (nn.Module, optional): 516 | Normalization layer. Default: nn.LayerNorm 517 | downsample (nn.Module | None, optional): 518 | Downsample layer at the end of the layer. 519 | Default: None 520 | use_checkpoint (bool): 521 | Whether to use checkpointing to save memory. 522 | Default: False. 
523 | """ 524 | 525 | def __init__( 526 | self, 527 | dim, 528 | input_resolution, 529 | depth, 530 | num_heads, 531 | window_size, 532 | mlp_ratio=4.0, 533 | qkv_bias=True, 534 | qk_scale=None, 535 | drop=0.0, 536 | attn_drop=0.0, 537 | drop_path=0.0, 538 | norm_layer=nn.LayerNorm, 539 | downsample=None, 540 | use_checkpoint=False, 541 | ): 542 | 543 | super().__init__() 544 | self.dim = dim 545 | self.input_resolution = input_resolution 546 | self.depth = depth 547 | self.use_checkpoint = use_checkpoint 548 | 549 | # build blocks 550 | self.blocks = nn.ModuleList( 551 | [ 552 | SwinTransformerBlock( 553 | dim=dim, 554 | input_resolution=input_resolution, 555 | num_heads=num_heads, 556 | window_size=window_size, 557 | shift_size=0 if (i % 2 == 0) else window_size // 2, 558 | mlp_ratio=mlp_ratio, 559 | qkv_bias=qkv_bias, 560 | qk_scale=qk_scale, 561 | drop=drop, 562 | attn_drop=attn_drop, 563 | drop_path=( 564 | drop_path[i] if isinstance( 565 | drop_path, list) else drop_path 566 | ), 567 | norm_layer=norm_layer, 568 | ) 569 | for i in range(depth) 570 | ] 571 | ) 572 | 573 | # patch merging layer 574 | if downsample is not None: 575 | self.downsample = downsample( 576 | input_resolution, dim=dim, norm_layer=norm_layer 577 | ) 578 | else: 579 | self.downsample = None 580 | 581 | def forward(self, x, x_size): 582 | for blk in self.blocks: 583 | if self.use_checkpoint: 584 | x = checkpoint.checkpoint(blk, x, x_size) 585 | else: 586 | x = blk(x, x_size) 587 | if self.downsample is not None: 588 | x = self.downsample(x) 589 | return x 590 | 591 | def extra_repr(self) -> str: 592 | return f"dim={self.dim}, \ 593 | input_resolution={self.input_resolution}, \ 594 | depth={self.depth}" 595 | 596 | def flops(self): 597 | flops = 0 598 | for blk in self.blocks: 599 | flops += blk.flops() 600 | if self.downsample is not None: 601 | flops += self.downsample.flops() 602 | return flops 603 | 604 | 605 | class RSTB(nn.Module): 606 | """Residual Swin Transformer Block (RSTB). 607 | 608 | Args: 609 | dim (int): Number of input channels. 610 | input_resolution (tuple[int]): Input resolution. 611 | depth (int): 612 | Number of blocks. 613 | num_heads (int): 614 | Number of attention heads. 615 | window_size (int): 616 | Local window size. 617 | mlp_ratio (float): 618 | Ratio of mlp hidden dim to embedding dim. 619 | qkv_bias (bool, optional): 620 | If True, add a learnable bias to query, key, value. 621 | Default: True 622 | qk_scale (float | None, optional): 623 | Override default qk scale of head_dim ** -0.5 if set. 624 | drop (float, optional): 625 | Dropout rate. 626 | Default: 0.0 627 | attn_drop (float, optional): 628 | Attention dropout rate. 629 | Default: 0.0 630 | drop_path (float | tuple[float], optional): 631 | Stochastic depth rate. 632 | Default: 0.0 633 | norm_layer (nn.Module, optional): 634 | Normalization layer. 635 | Default: nn.LayerNorm 636 | downsample (nn.Module | None, optional): 637 | Downsample layer at the end of the layer. 638 | Default: None 639 | use_checkpoint (bool): 640 | Whether to use checkpointing to save memory. 641 | Default: False. 642 | img_size: 643 | Input image size. 644 | patch_size: 645 | Patch size. 646 | resi_connection: 647 | The convolutional block before residual connection. 
648 | """ 649 | 650 | def __init__( 651 | self, 652 | dim, 653 | input_resolution, 654 | depth, 655 | num_heads, 656 | window_size, 657 | mlp_ratio=4.0, 658 | qkv_bias=True, 659 | qk_scale=None, 660 | drop=0.0, 661 | attn_drop=0.0, 662 | drop_path=0.0, 663 | norm_layer=nn.LayerNorm, 664 | downsample=None, 665 | use_checkpoint=False, 666 | img_size=224, 667 | patch_size=4, 668 | resi_connection="1conv", 669 | ): 670 | super(RSTB, self).__init__() 671 | 672 | self.dim = dim 673 | self.input_resolution = input_resolution 674 | 675 | self.residual_group = BasicLayer( 676 | dim=dim, 677 | input_resolution=input_resolution, 678 | depth=depth, 679 | num_heads=num_heads, 680 | window_size=window_size, 681 | mlp_ratio=mlp_ratio, 682 | qkv_bias=qkv_bias, 683 | qk_scale=qk_scale, 684 | drop=drop, 685 | attn_drop=attn_drop, 686 | drop_path=drop_path, 687 | norm_layer=norm_layer, 688 | downsample=downsample, 689 | use_checkpoint=use_checkpoint, 690 | ) 691 | 692 | if resi_connection == "1conv": 693 | self.conv = nn.Conv2d(dim, dim, 3, 1, 1) 694 | elif resi_connection == "3conv": 695 | # to save parameters and memory 696 | self.conv = nn.Sequential( 697 | nn.Conv2d(dim, dim // 4, 3, 1, 1), 698 | nn.LeakyReLU(negative_slope=0.2, inplace=True), 699 | nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), 700 | nn.LeakyReLU(negative_slope=0.2, inplace=True), 701 | nn.Conv2d(dim // 4, dim, 3, 1, 1), 702 | ) 703 | 704 | self.patch_embed = PatchEmbed( 705 | img_size=img_size, 706 | patch_size=patch_size, 707 | in_chans=0, 708 | embed_dim=dim, 709 | norm_layer=None, 710 | ) 711 | 712 | self.patch_unembed = PatchUnEmbed( 713 | img_size=img_size, 714 | patch_size=patch_size, 715 | in_chans=0, 716 | embed_dim=dim, 717 | norm_layer=None, 718 | ) 719 | 720 | def forward(self, x, x_size): 721 | return ( 722 | self.patch_embed( 723 | self.conv(self.patch_unembed( 724 | self.residual_group(x, x_size), 725 | x_size)) 726 | ) 727 | + x 728 | ) 729 | 730 | def flops(self): 731 | flops = 0 732 | flops += self.residual_group.flops() 733 | H, W = self.input_resolution 734 | flops += H * W * self.dim * self.dim * 9 735 | flops += self.patch_embed.flops() 736 | flops += self.patch_unembed.flops() 737 | 738 | return flops 739 | 740 | 741 | class PatchEmbed(nn.Module): 742 | r"""Image to Patch Embedding 743 | 744 | Args: 745 | img_size (int): 746 | Image size. 747 | Default: 224. 748 | patch_size (int): 749 | Patch token size. 750 | Default: 4. 751 | in_chans (int): 752 | Number of input image channels. 753 | Default: 3. 754 | embed_dim (int): 755 | Number of linear projection output channels. 756 | Default: 96. 757 | norm_layer (nn.Module, optional): 758 | Normalization layer. 
Default: None 759 | """ 760 | 761 | def __init__( 762 | self, img_size=224, patch_size=4, 763 | in_chans=3, embed_dim=96, norm_layer=None 764 | ): 765 | super().__init__() 766 | img_size = to_2tuple(img_size) 767 | patch_size = to_2tuple(patch_size) 768 | patches_resolution = [ 769 | img_size[0] // patch_size[0], 770 | img_size[1] // patch_size[1], 771 | ] 772 | self.img_size = img_size 773 | self.patch_size = patch_size 774 | self.patches_resolution = patches_resolution 775 | self.num_patches = patches_resolution[0] * patches_resolution[1] 776 | 777 | self.in_chans = in_chans 778 | self.embed_dim = embed_dim 779 | 780 | if norm_layer is not None: 781 | self.norm = norm_layer(embed_dim) 782 | else: 783 | self.norm = None 784 | 785 | def forward(self, x): 786 | x = x.flatten(2).transpose(1, 2) # B Ph*Pw C 787 | if self.norm is not None: 788 | x = self.norm(x) 789 | return x 790 | 791 | def flops(self): 792 | flops = 0 793 | H, W = self.img_size 794 | if self.norm is not None: 795 | flops += H * W * self.embed_dim 796 | return flops 797 | 798 | 799 | class PatchUnEmbed(nn.Module): 800 | r"""Image to Patch Unembedding 801 | 802 | Args: 803 | img_size (int): 804 | Image size. 805 | Default: 224. 806 | patch_size (int): 807 | Patch token size. 808 | Default: 4. 809 | in_chans (int): 810 | Number of input image channels. 811 | Default: 3. 812 | embed_dim (int): 813 | Number of linear projection output channels. 814 | Default: 96. 815 | norm_layer (nn.Module, optional): 816 | Normalization layer. 817 | Default: None 818 | """ 819 | 820 | def __init__( 821 | self, img_size=224, 822 | patch_size=4, in_chans=3, 823 | embed_dim=96, norm_layer=None 824 | ): 825 | super().__init__() 826 | img_size = to_2tuple(img_size) 827 | patch_size = to_2tuple(patch_size) 828 | patches_resolution = [ 829 | img_size[0] // patch_size[0], 830 | img_size[1] // patch_size[1], 831 | ] 832 | self.img_size = img_size 833 | self.patch_size = patch_size 834 | self.patches_resolution = patches_resolution 835 | self.num_patches = patches_resolution[0] * patches_resolution[1] 836 | 837 | self.in_chans = in_chans 838 | self.embed_dim = embed_dim 839 | 840 | def forward(self, x, x_size): 841 | B, HW, C = x.shape 842 | # B Ph*Pw C 843 | x = x.transpose(1, 2).view( 844 | B, self.embed_dim, x_size[0], x_size[1]) 845 | return x 846 | 847 | def flops(self): 848 | flops = 0 849 | return flops 850 | 851 | 852 | class Upsample(nn.Sequential): 853 | """Upsample module. 854 | 855 | Args: 856 | scale (int): Scale factor. Supported scales: 2^n and 3. 857 | num_feat (int): Channel number of intermediate features. 858 | """ 859 | 860 | def __init__(self, scale, num_feat): 861 | m = [] 862 | if (scale & (scale - 1)) == 0: # scale = 2^n 863 | for _ in range(int(math.log(scale, 2))): 864 | m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) 865 | m.append(nn.PixelShuffle(2)) 866 | elif scale == 3: 867 | m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) 868 | m.append(nn.PixelShuffle(3)) 869 | else: 870 | raise ValueError( 871 | f"scale {scale} is not supported. " + 872 | "Supported scales: 2^n and 3." 873 | ) 874 | super(Upsample, self).__init__(*m) 875 | 876 | 877 | class UpsampleOneStep(nn.Sequential): 878 | """UpsampleOneStep module 879 | (the difference with Upsample is that it 880 | always only has 1conv + 1pixelshuffle) 881 | Used in lightweight SR to save parameters. 882 | 883 | Args: 884 | scale (int): 885 | Scale factor. Supported scales: 2^n and 3. 886 | num_feat (int): 887 | Channel number of intermediate features. 
888 | 889 | """ 890 | 891 | def __init__( 892 | self, scale, num_feat, num_out_ch, input_resolution=None 893 | ): 894 | self.num_feat = num_feat 895 | self.input_resolution = input_resolution 896 | m = [] 897 | m.append(nn.Conv2d( 898 | num_feat, (scale**2) * num_out_ch, 3, 1, 1)) 899 | m.append(nn.PixelShuffle(scale)) 900 | super(UpsampleOneStep, self).__init__(*m) 901 | 902 | def flops(self): 903 | H, W = self.input_resolution 904 | flops = H * W * self.num_feat * 3 * 9 905 | return flops 906 | 907 | 908 | class EncDec(nn.Module): 909 | r"""EncDec 910 | 911 | Args: 912 | img_size (int | tuple(int)): 913 | Input image size. 914 | Default 64 915 | patch_size (int | tuple(int)): 916 | Patch size. 917 | Default: 1 918 | in_chans (int): 919 | Number of input image channels. 920 | Default: 3 921 | embed_dim (int): 922 | Patch embedding dimension. 923 | Default: 96 924 | depths (tuple(int)): 925 | Depth of each Swin Transformer layer. 926 | num_heads (tuple(int)): 927 | Number of attention heads in different layers. 928 | window_size (int): 929 | Window size. 930 | Default: 7 931 | mlp_ratio (float): 932 | Ratio of mlp hidden dim to embedding dim. 933 | Default: 4 934 | qkv_bias (bool): 935 | If True, add a learnable bias to query, key, value. 936 | Default: True 937 | qk_scale (float): 938 | Override default qk scale of head_dim ** -0.5 if set. 939 | Default: None 940 | drop_rate (float): 941 | Dropout rate. Default: 0 942 | attn_drop_rate (float): 943 | Attention dropout rate. Default: 0 944 | drop_path_rate (float): 945 | Stochastic depth rate. Default: 0.1 946 | norm_layer (nn.Module): 947 | Normalization layer. 948 | Default: nn.LayerNorm. 949 | ape (bool): 950 | If True, 951 | add absolute position embedding to the patch embedding. 952 | Default: False 953 | patch_norm (bool): 954 | If True, add normalization after patch embedding. 955 | Default: True 956 | use_checkpoint (bool): 957 | Whether to use checkpointing to save memory. 958 | Default: False 959 | upscale: 960 | Upscale factor. 961 | 2/3/4/8 for image SR, 962 | 1 for denoising and compress artifact reduction 963 | img_range: 964 | Image range. 1. or 255. 965 | resi_connection: 966 | The convolutional block before residual connection. 
967 | '1conv'/'3conv' 968 | """ 969 | 970 | def __init__( 971 | self, 972 | params, 973 | norm_layer=nn.LayerNorm, 974 | **kwargs, 975 | ): 976 | super(EncDec, self).__init__() 977 | self.img_range = params.img_range 978 | if params.in_chans == 3: 979 | rgb_mean = (0.4488, 0.4371, 0.4040) 980 | self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) 981 | else: 982 | self.mean = torch.zeros(1, 1, 1, 1) 983 | self.upscale = params.upscale 984 | self.window_size = params.window_size 985 | 986 | # 1, encoder 987 | self.conv_first = nn.Conv2d( 988 | params.in_chans, params.embed_dim, 3, 1, 1) 989 | 990 | # 2, decoder 991 | self.num_layers = len(params.depths) 992 | self.embed_dim = params.embed_dim 993 | self.ape = params.ape 994 | self.patch_norm = params.patch_norm 995 | self.num_features = params.embed_dim 996 | self.mlp_ratio = params.mlp_ratio 997 | 998 | # split image into non-overlapping patches 999 | self.patch_embed = PatchEmbed( 1000 | img_size=(params.img_size_x, params.img_size_y), 1001 | patch_size=params.patch_size, 1002 | in_chans=params.embed_dim, 1003 | embed_dim=params.embed_dim, 1004 | norm_layer=norm_layer if self.patch_norm else None, 1005 | ) 1006 | num_patches = self.patch_embed.num_patches 1007 | patches_resolution = self.patch_embed.patches_resolution 1008 | self.patches_resolution = patches_resolution 1009 | 1010 | # merge non-overlapping patches into image 1011 | self.patch_unembed = PatchUnEmbed( 1012 | img_size=(params.img_size_x, params.img_size_y), 1013 | patch_size=params.patch_size, 1014 | in_chans=params.embed_dim, 1015 | embed_dim=params.embed_dim, 1016 | norm_layer=norm_layer if self.patch_norm else None, 1017 | ) 1018 | 1019 | # absolute position embedding 1020 | if self.ape: 1021 | self.absolute_pos_embed = nn.Parameter( 1022 | torch.zeros(1, num_patches, params.embed_dim) 1023 | ) 1024 | trunc_normal_(self.absolute_pos_embed, std=0.02) 1025 | 1026 | self.pos_drop = nn.Dropout(p=params.drop_rate) 1027 | 1028 | # stochastic depth 1029 | dpr = [ 1030 | x.item() 1031 | for x in torch.linspace( 1032 | 0, params.drop_path_rate, sum(params.depths)) 1033 | ] # stochastic depth decay rule 1034 | 1035 | # build Residual Swin Transformer blocks (RSTB) 1036 | self.layers = nn.ModuleList() 1037 | for i_layer in range(self.num_layers): 1038 | dp_1 = sum(params.depths[:i_layer]) 1039 | dp_2 = sum(params.depths[: i_layer + 1]) 1040 | layer = RSTB( 1041 | dim=params.embed_dim, 1042 | input_resolution=( 1043 | patches_resolution[0], patches_resolution[1]), 1044 | depth=params.depths[i_layer], 1045 | num_heads=params.num_heads[i_layer], 1046 | window_size=params.window_size, 1047 | mlp_ratio=self.mlp_ratio, 1048 | qkv_bias=params.qkv_bias, 1049 | qk_scale=params.qk_scale, 1050 | drop=params.drop_rate, 1051 | attn_drop=params.attn_drop_rate, 1052 | drop_path=dpr[dp_1:dp_2], 1053 | norm_layer=norm_layer, 1054 | downsample=None, 1055 | use_checkpoint=params.use_checkpoint, 1056 | img_size=(params.img_size_x, params.img_size_y), 1057 | patch_size=params.patch_size, 1058 | resi_connection=params.resi_connection, 1059 | ) 1060 | self.layers.append(layer) 1061 | self.norm = norm_layer(self.num_features) 1062 | 1063 | # build the last conv layer in deep feature extraction 1064 | if params.resi_connection == "1conv": 1065 | self.conv_after_body = nn.Conv2d( 1066 | params.embed_dim, params.embed_dim, 3, 1, 1 1067 | ) 1068 | elif params.resi_connection == "3conv": 1069 | # to save parameters and memory 1070 | self.conv_after_body = nn.Sequential( 1071 | nn.Conv2d( 1072 | params.embed_dim, 
1073 |                 nn.LeakyReLU(
1074 |                     negative_slope=0.2, inplace=True),
1075 |                 nn.Conv2d(
1076 |                     params.embed_dim // 4, params.embed_dim // 4, 1, 1, 0),
1077 |                 nn.LeakyReLU(
1078 |                     negative_slope=0.2, inplace=True),
1079 |                 nn.Conv2d(
1080 |                     params.embed_dim // 4, params.embed_dim, 3, 1, 1),
1081 |             )
1082 | 
1083 |         # 3, reconstruction
1084 |         self.conv_before_upsample = nn.Sequential(
1085 |             nn.Conv2d(params.embed_dim, params.num_feat, 3, 1, 1),
1086 |             nn.LeakyReLU(inplace=True),
1087 |         )
1088 |         self.upsample = Upsample(
1089 |             params.upscale, params.num_feat)
1090 |         self.conv_last = nn.Conv2d(
1091 |             params.num_feat, params.out_chans, 3, 1, 1)
1092 | 
1093 |         self.apply(self._init_weights)
1094 | 
1095 |     def _init_weights(self, m):
1096 |         if isinstance(m, nn.Linear):
1097 |             trunc_normal_(m.weight, std=0.02)
1098 |             if isinstance(m, nn.Linear) and m.bias is not None:
1099 |                 nn.init.constant_(m.bias, 0)
1100 |         elif isinstance(m, nn.LayerNorm):
1101 |             nn.init.constant_(m.bias, 0)
1102 |             nn.init.constant_(m.weight, 1.0)
1103 | 
1104 |     @torch.jit.ignore
1105 |     def no_weight_decay(self):
1106 |         return {"absolute_pos_embed"}
1107 | 
1108 |     @torch.jit.ignore
1109 |     def no_weight_decay_keywords(self):
1110 |         return {"relative_position_bias_table"}
1111 | 
1112 |     def check_image_size(self, x):
1113 |         _, _, h, w = x.size()
1114 |         mod_pad_h = (
1115 |             self.window_size - h % self.window_size) % self.window_size
1116 |         mod_pad_w = (
1117 |             self.window_size - w % self.window_size) % self.window_size
1118 |         x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), "reflect")
1119 |         return x
1120 | 
1121 |     def forward_features(self, x):
1122 |         x_size = (x.shape[2], x.shape[3])
1123 |         x = self.patch_embed(x)
1124 |         if self.ape:
1125 |             x = x + self.absolute_pos_embed
1126 |         x = self.pos_drop(x)
1127 | 
1128 |         for layer in self.layers:
1129 |             x = layer(x, x_size)
1130 | 
1131 |         x = self.norm(x)  # B L C
1132 |         x = self.patch_unembed(x, x_size)
1133 | 
1134 |         return x
1135 | 
1136 |     def forward(self, x):
1137 |         H, W = x.shape[2:]
1138 |         x = self.check_image_size(x)
1139 | 
1140 |         self.mean = self.mean.type_as(x)
1141 |         x = (x - self.mean) * self.img_range
1142 | 
1143 |         x = self.conv_first(x)
1144 |         x = self.conv_after_body(self.forward_features(x)) + x
1145 |         x = self.conv_before_upsample(x)
1146 |         x = self.upsample(x)
1147 |         x = self.conv_last(x)
1148 | 
1149 |         x = x / self.img_range + self.mean
1150 | 
1151 |         return x[:, :, : H * self.upscale, : W * self.upscale]
1152 | 
1153 |     def flops(self):
1154 |         flops = 0
1155 |         H, W = self.patches_resolution
1156 |         flops += H * W * 3 * self.embed_dim * 9
1157 |         flops += self.patch_embed.flops()
1158 |         for i, layer in enumerate(self.layers):
1159 |             flops += layer.flops()
1160 |         flops += H * W * 3 * self.embed_dim * self.embed_dim
1161 |         flops += self.upsample.flops()
1162 |         return flops
1163 | 
1164 | 
1165 | if __name__ == "__main__":
1166 | 
1167 |     params = {
1168 |         "upscale": 1,
1169 |         "in_chans": 8,
1170 |         "out_chans": 4,
1171 |         "img_size_x": 960,
1172 |         "img_size_y": 480,
1173 |         "window_size": 4,
1174 |         "patch_size": 5,
1175 |         "num_feat": 64,
1176 |         "drop_rate": 0.1,
1177 |         "drop_path_rate": 0.1,
1178 |         "attn_drop_rate": 0.1,
1179 |         "ape": False,
1180 |         "patch_norm": True,
1181 |         "use_checkpoint": False,
1182 |         "resi_connection": "1conv",
1183 |         "qkv_bias": True,
1184 |         "qk_scale": None,
1185 |         "img_range": 1.0,
1186 |         "depths": [3],  # [3]
1187 |         "embed_dim": 64,  # needs to be divisible by num_heads
1188 |         "num_heads": [4],
1189 |         "mlp_ratio": 2,
1190 |     }
1191 |     import argparse
1192 | 
1193 |     params = argparse.Namespace(**params)
1194 |     model = EncDec(
1195 |         params,
1196 |     )
1197 | 
1198 |     print(model)
1199 | 
1200 |     x = torch.randn((1, 8, 960, 480))
1201 |     x = model(x)
1202 |     print(x.shape)
1203 | 
--------------------------------------------------------------------------------
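
Usage note (illustrative, not part of the repository): the `__main__` block above drives EncDec with a hard-coded parameter dict, while the rest of the repo reads hyperparameters from config/experiment.yaml through utils/YParams.py. The sketch below shows how a config-driven instantiation might look; the config section name "encdec" and the assumption that the YAML section contains every field EncDec expects (in_chans, depths, embed_dim, ...) are hypothetical, and it assumes the script is run from the repository root.

# Minimal sketch, assuming a YAML section named "encdec" with the same fields
# as the demo dict in encdec.py's __main__ block.
import torch

from utils.YParams import YParams
from models.encdec import EncDec

params = YParams("config/experiment.yaml", "encdec")  # section name is an assumption
model = EncDec(params)

# check_image_size pads H and W to multiples of window_size inside forward,
# so the spatial size only needs to be reasonable, not exactly divisible.
x = torch.randn(1, params.in_chans, params.img_size_x, params.img_size_y)
with torch.no_grad():
    y = model(x)
print(y.shape)  # (1, out_chans, img_size_x * upscale, img_size_y * upscale)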