├── assets
│   ├── frame.png
│   └── framework.png
├── data
│   ├── stats.csv
│   └── README.md
├── utils
│   ├── logging_utils.py
│   ├── YParams.py
│   └── data_loader_multifiles.py
├── LICENSE
├── config
│   └── experiment.yaml
├── SECURITY.md
├── README.md
├── .gitignore
├── inference.py
├── train.py
└── models
    └── encdec.py
/assets/frame.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/ADAF/main/assets/frame.png
--------------------------------------------------------------------------------
/assets/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/ADAF/main/assets/framework.png
--------------------------------------------------------------------------------
/data/stats.csv:
--------------------------------------------------------------------------------
1 | ,variable,max,min
2 | 0,hrrr_q,0.025,0.0
3 | 1,hrrr_sp,1050,600
4 | 2,hrrr_u_10,25,-25
5 | 3,hrrr_v_10,25,-25
6 | 4,hrrr_t,50,-40
7 | 5,rtma_t,50,-40
8 | 6,rtma_q,0.025,0.0
9 | 7,rtma_u10,25,-25
10 | 8,rtma_v10,25,-25
11 | 9,rtma_sp,1050,600
12 | 10,z,3190,-65
13 | 11,sta_t,50,-40
14 | 12,sta_q,0.025,0.0
15 | 13,sta_u10,25,-25
16 | 14,sta_v10,25,-25
17 | 15,sta_p,1200,600
18 |
--------------------------------------------------------------------------------
/utils/logging_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 |
4 | _format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
5 |
6 |
7 | def config_logger(log_level=logging.INFO):
8 | logging.basicConfig(format=_format, level=log_level)
9 |
10 |
11 | def log_to_file(
12 | logger_name=None,
13 | log_level=logging.INFO,
14 | log_filename="tensorflow.log"
15 | ):
16 |     if os.path.dirname(log_filename):
17 |         os.makedirs(os.path.dirname(log_filename), exist_ok=True)
18 |
19 | if logger_name is not None:
20 | log = logging.getLogger(logger_name)
21 | else:
22 | log = logging.getLogger()
23 |
24 | fh = logging.FileHandler(log_filename)
25 | fh.setLevel(log_level)
26 | fh.setFormatter(logging.Formatter(_format))
27 | log.addHandler(fh)
28 |
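29 | # Example usage (a minimal sketch): configure console logging, then mirror logs to a file.
30 | # The file path below is hypothetical; parent directories are created if they do not exist.
31 | #     config_logger(logging.INFO)
32 | #     log_to_file(logger_name=None, log_filename="./exp/out.log")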
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) Microsoft Corporation.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------
1 | Variables included in a sample file:
2 |
3 | | Variable | Description | Dimension |
4 | | ----------- | ----------- | ----------- |
5 | | z | Topography, normalized | [lat, lon] |
6 | | rtma_t | T2M from RTMA, normalized | [lat, lon] |
7 | | rtma_q | Q from RTMA, normalized | [lat, lon] |
8 | | rtma_u10 | U10 from RTMA, normalized | [lat, lon] |
9 | | rtma_v10 | V10 from RTMA, normalized | [lat, lon] |
10 | | sta_t | T2M from station observations (0 = no station), normalized | [obs_time_window, lat, lon] |
11 | | sta_q | Q from station observations (0 = no station), normalized | [obs_time_window, lat, lon] |
12 | | sta_u10 | U10 from station observations (0 = no station), normalized | [obs_time_window, lat, lon] |
13 | | sta_v10 | V10 from station observations (0 = no station), normalized | [obs_time_window, lat, lon] |
14 | | CMI02 | ABI Band 2: visible (red), normalized | [obs_time_window, lat, lon] |
15 | | CMI07 | ABI Band 7: shortwave infrared, normalized | [obs_time_window, lat, lon] |
16 | | CMI10 | ABI Band 10: low-level water vapor, normalized | [obs_time_window, lat, lon] |
17 | | CMI14 | ABI Band 14: longwave infrared, normalized | [obs_time_window, lat, lon] |
18 | | hrrr_t | T2M from HRRR 1-hour forecast | [lat, lon] |
19 | | hrrr_q | Q from HRRR 1-hour forecast | [lat, lon] |
20 | | hrrr_u_10 | U10 from HRRR 1-hour forecast | [lat, lon] |
21 | | hrrr_v_10 | V10 from HRRR 1-hour forecast | [lat, lon] |
22 |
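23 | A minimal sketch of how a sample file can be inspected with `xarray`; the path below assumes the data has been downloaded into `./data/test/` and uses one of the test sample names from the top-level README:
24 | 
25 | ```python
26 | import xarray as xr
27 | 
28 | # Open one pre-processed sample (adjust the path to where the data was downloaded).
29 | ds = xr.open_dataset("./data/test/2022-10-01_00.nc", engine="netcdf4")
30 | 
31 | # Gridded fields are [lat, lon]; station and satellite variables carry an extra obs_time_window axis.
32 | print(ds["hrrr_t"].shape)   # (lat, lon)
33 | print(ds["sta_t"].shape)    # (obs_time_window, lat, lon)
34 | print(list(ds.data_vars))   # all variables in the table above
35 | ```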
--------------------------------------------------------------------------------
/utils/YParams.py:
--------------------------------------------------------------------------------
1 | # import os
2 | # import sys
3 | import logging
4 | from ruamel.yaml import YAML
5 |
6 |
7 | class YParams:
8 | """Yaml file parser"""
9 |
10 | def __init__(self, yaml_filename, config_name, print_params=False):
11 | self._yaml_filename = yaml_filename
12 | self._config_name = config_name
13 | self.params = {}
14 |
15 | # if print_params:
16 | # print(os.system('hostname'))
17 | # print("Configuration:", yaml_filename)
18 |
19 | with open(yaml_filename, "rb") as _file:
20 | yaml = YAML().load(_file)
21 | for key, val in yaml[config_name].items():
22 | if print_params:
23 | print(key, val)
24 | if val == "None":
25 | val = None
26 |
27 | self.params[key] = val
28 | self.__setattr__(key, val)
29 |
30 | def __getitem__(self, key):
31 | return self.params[key]
32 |
33 | def __setitem__(self, key, val):
34 | self.params[key] = val
35 | self.__setattr__(key, val)
36 |
37 | def __contains__(self, key):
38 | return key in self.params
39 |
40 | def update_params(self, config):
41 | for key, val in config.items():
42 | self.params[key] = val
43 | self.__setattr__(key, val)
44 |
45 | def log(self):
46 | logging.info("------------------ Configuration ------------------")
47 | logging.info("Configuration file: " + str(self._yaml_filename))
48 | logging.info("Configuration name: " + str(self._config_name))
49 | for key, val in self.params.items():
50 | logging.info(str(key) + ": " + str(val))
51 | logging.info("---------------------------------------------------")
52 |
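53 | # Example usage (a minimal sketch; assumes the repository's config/experiment.yaml and its "EncDec" section):
54 | #     params = YParams("./config/experiment.yaml", "EncDec", print_params=True)
55 | #     print(params["lr"], params.max_epochs)  # entries are exposed both dict-style and as attributes
56 | #     params.log()  # write the full configuration to the logger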
--------------------------------------------------------------------------------
/config/experiment.yaml:
--------------------------------------------------------------------------------
1 | ### base config ###
2 | # -*- coding: utf-8 -*-
3 | full_field: &FULL_FIELD
4 |     lr: 1E-3 # 1e-4 causes NaN loss for VA
5 | max_epochs: 1200
6 | valid_frequency: 5
7 |
8 | optimizer_type: "AdamW" # Adam, FusedAdam, SWA
9 |
10 | scheduler: "ReduceLROnPlateau" # ReduceLROnPlateau, MultiplicativeLR
11 | lr_reduce_factor: 0.65
12 |
13 | num_data_workers: 8 # 0
14 | # gridtype: 'sinusoidal' # options 'sinusoidal' or 'linear'
15 | enable_nhwc: !!bool False
16 |
17 | # directory path to store training checkpoints and other output
18 | exp_dir: "./exp"
19 |
20 | # directory path to store dataset for train, valid, and test
21 | data_path: "./data/"
22 | train_data_path: "./data/train"
23 | valid_data_path: "./data/valid"
24 | test_data_path: "./data/test"
25 |
26 | # normalization
27 | norm_type: "variable_wise_ignore_extreme" # options: channel_wise, variable_wise, variable_wise_ignore_extreme
28 | normalization: "minmax_ignore_extreme" # options: minmax, zscore, minmax_ignore_extreme, scale
29 |
30 | add_noise: False
31 |
32 | N_in_channels: 21
33 | N_out_channels: 5
34 |
35 | bg_ensemble_num: 1 # 3
36 | obs_time_window: 3 # if 3, use observation at analysis time
37 |
38 | inp_hrrr_vars: ['hrrr_q', 'hrrr_t', 'hrrr_u_10', 'hrrr_v_10'] # 'hrrr_sp'
39 | inp_satelite_vars: ['CMI02', 'CMI07', 'CMI14', 'CMI10']
40 | inp_obs_vars: ["sta_q", "sta_t", "sta_u10", "sta_v10"] # "sta_p"
41 | hold_out_obs: True
42 | field_tar_vars: ["rtma_q", "rtma_t", "rtma_u10", "rtma_v10"] # "rtma_sp"
43 | target_vars: ["q", "t", "u10", "v10"] # "sp"
44 |
45 | learn_residual: True
46 |
47 |     input_time_feature: False # use inp_auxiliary_vars if True
48 | inp_auxiliary_vars: ["hour"]
49 |
50 | stack_channel_by_var: False
51 |
52 | save_model_freq: 5
53 | log_to_screen: !!bool True
54 | log_to_wandb: !!bool True
55 | save_checkpoint: !!bool True
56 |
57 | EncDec: &EncDec
58 | <<: *FULL_FIELD
59 | nettype: "EncDec"
60 | lr: 2E-3
61 | upscale: 1
62 | in_chans: 29 # 33
63 | out_chans: 4
64 | img_size_x: 1280
65 | img_size_y: 512
66 | window_size: 4
67 |     patch_size: 4 # 1 # img_size must be divisible by patch_size
68 | num_feat: 64
69 | drop_rate: 0.1
70 | drop_path_rate: 0.1
71 | attn_drop_rate: 0.1
72 | ape: False
73 | patch_norm: True
74 | use_checkpoint: False
75 | resi_connection: "1conv"
76 | qkv_bias: True
77 | qk_scale: None
78 | img_range: 1.
79 | depths: [3] # [3]
80 |     embed_dim: 64 # must be divisible by num_heads
81 | num_heads: [4] # [16]
82 | mlp_ratio: 2 # 2
83 | upsampler: "pixelshuffle"
84 |
--------------------------------------------------------------------------------
/SECURITY.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## Security
4 |
5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet) and [Xamarin](https://github.com/xamarin).
6 |
7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/security.md/definition), please report it to us as described below.
8 |
9 | ## Reporting Security Issues
10 |
11 | **Please do not report security vulnerabilities through public GitHub issues.**
12 |
13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/security.md/msrc/create-report).
14 |
15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/security.md/msrc/pgp).
16 |
17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc).
18 |
19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
20 |
21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
22 | * Full paths of source file(s) related to the manifestation of the issue
23 | * The location of the affected source code (tag/branch/commit or direct URL)
24 | * Any special configuration required to reproduce the issue
25 | * Step-by-step instructions to reproduce the issue
26 | * Proof-of-concept or exploit code (if possible)
27 | * Impact of the issue, including how an attacker might exploit the issue
28 |
29 | This information will help us triage your report more quickly.
30 |
31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/security.md/msrc/bounty) page for more details about our active programs.
32 |
33 | ## Preferred Languages
34 |
35 | We prefer all communications to be in English.
36 |
37 | ## Policy
38 |
39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/security.md/cvd).
40 |
41 |
42 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ADAF
2 |
3 | This repository contains the code used for "ADAF: An Artificial Intelligence Data Assimilation Framework for Weather Forecasting".
4 |
5 | ## Abstract
6 | The forecasting skill of numerical weather prediction (NWP) models critically depends on accurate initial conditions, also known as the analysis, provided by data assimilation (DA).
7 | Traditional DA methods often face a trade-off between computational cost and accuracy due to complex linear algebra computations and the high dimensionality of the model, especially in nonlinear systems. Moreover, processing massive amounts of data in real time requires substantial computational resources. To address this, we introduce an artificial intelligence-based data assimilation framework (ADAF) to generate high-quality kilometer-scale analysis. This study is a pioneering effort to verify the efficacy of AI methods in DA using real-world observations from varied locations and multiple sources, including sparse surface weather observations and satellite imagery. We implemented ADAF for four near-surface variables in the Contiguous United States (CONUS). The results demonstrate that ADAF outperforms the High Resolution Rapid Refresh Data Assimilation System (HRRRDAS) in accuracy by 16% to 33%, and is able to reconstruct extreme events such as the wind field of tropical cyclones. Sensitivity experiments reveal that ADAF can generate high-quality analysis even with low-accuracy backgrounds and extremely sparse surface observations. ADAF can assimilate a massive volume of observations within a three-hour window at low computational cost, taking about two seconds on an AMD MI200 graphics processing unit (GPU). ADAF has been shown to be efficient and effective in real-world DA, underscoring its potential role in operational weather forecasting.
8 |
9 | 
10 |
11 |
12 | ## Data
13 | - Pre-processed data
14 |
15 | [Link for Pre-processed Data - Zenodo Download Link](https://zenodo.org/records/14020879)
16 |
17 | The pre-processed data consists of input-target pairs. The inputs include surface weather observations within a 3-hour window, GOES-16 satellite imagery within a 3-hour window, the HRRR forecast, and topography. The target is a combination of RTMA and surface weather observations. The table below summarizes the input and target datasets used in this study. All data were regridded to grids of size 512 $\times$ 1280 with a spatial resolution of 0.05$^\circ$ $\times$ 0.05$^\circ$.
18 |
19 |
20 | | | Dataset | Source | Time window | Variables/Bands |
21 | | --- | --- | --- | --- | --- |
22 | | Input | Surface weather observations | WeatherReal-Synoptic (Jin et al., 2024) | 3 hours | Q, T2M, U10, V10 |
23 | | | Satellite imagery | GOES-16 (Tan et al., 2019) | 3 hours | 0.64, 3.9, 7.3, 11.2 $\mu m$ |
24 | | | Background | HRRR forecast (Dowell et al., 2022) | N/A | Q, T2M, U10, V10 |
25 | | | Topography | ERA5 (Hersbach et al., 2019) | N/A | Geopotential |
26 | | Target | Analysis | RTMA (Pondeca et al., 2011) | N/A | Q, T2M, U10, V10 |
27 | | | Surface weather observations | WeatherReal-Synoptic (Jin et al., 2024) | N/A | Q, T2M, U10, V10 |
63 |
64 |
65 |
66 |
67 | - Pre-computed normalization statistics
68 |
69 | [Link for pre-computed normalization statistics - Zenodo Download Link](https://zenodo.org/records/14020879).
70 |
71 | If you are using the pre-trained model weights that we provide, it is crucial to use the given statistics, as these were used during model training. The trained model weights correspond exactly to these normalization statistics.
72 |
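The snippet below is a minimal sketch of how the `min`/`max` columns of `stats.csv` are applied (mirroring `min_max_norm_ignore_extreme_fill_nan` in `inference.py`): each variable is min-max scaled to [-1, 1], clipped, and NaNs are filled with 0.

```python
import numpy as np
import pandas as pd

stats = pd.read_csv("./data/stats.csv", index_col=0)

def normalize(field, variable):
    """Min-max scale one field to [-1, 1] using the pre-computed statistics."""
    row = stats[stats["variable"] == variable].iloc[0]
    vmin, vmax = row["min"], row["max"]
    scaled = (field - vmin) * 2.0 / (vmax - vmin) - 1.0
    scaled = np.clip(scaled, -1.0, 1.0)    # ignore values outside the stored range
    return np.nan_to_num(scaled, nan=0.0)  # missing / out-of-domain points become 0
```
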
73 | The data directory of pre-processed data and pre-computed normalization statistics is organized as follows:
74 | ```
75 | data
76 | ├── README.md
77 | ├── test
78 | │   ├── 2022-10-01_00.nc
79 | │   ├── 2022-10-02_06.nc
80 | │   ├── 2022-10-03_12.nc
81 | │   ├── ...
82 | │   └── 2023-10-31_00.nc
83 | └── stats.csv
84 | ```
85 |
86 | - Trained model weights
87 |
88 | [Link for trained model weights - Zenodo Download Link](https://zenodo.org/records/14020879)
89 |
90 | ```
91 | model_weights/
92 | └── best_ckpt.tar
93 | ```
94 |
95 | ## Train
96 |
97 | Training configurations can be set in config/experiment.yaml. Note that the following paths need to be set by the user.
98 |
99 | ```
100 | exp_dir # directory path to store training checkpoints and other output
101 | train_data_path # directory path to store dataset for train
102 | valid_data_path # directory path to store dataset for valid
103 | test_data_path # directory path to store dataset for test
104 | ```
105 |
106 | An example launch script for distributed data parallel training is provided:
107 |
108 |
109 | ```shell
110 | run_num=$(date "+%Y%m%d-%H%M%S")
111 | resume=False
112 |
113 | exp_dir='./exp/'
114 | wandb_group='ADAF'
115 |
116 | net_config='EncDec'
117 | export CUDA_VISIBLE_DEVICES='4,5,6,7'
118 |
119 | nohup python -m torch.distributed.launch \
120 | --master_port=26500 \
121 | --standalone \
122 | --nproc_per_node=4 \
123 | --nnodes=1 \
124 |     train.py \
125 | --hold_out_obs_ratio=0.5 \
126 | --lr=0.0002 \
127 | --lr_reduce_factor=0.8 \
128 | --target='analysis_obs' \
129 | --max_epochs=2000 \
130 | --exp_dir='./exp/' \
131 | --yaml_config='./config/experiment.yaml' \
132 | --net_config='EncDec' \
133 | --resume=False \
134 | --run_num=${run_num} \
135 | --batch_size=16 \
136 | --wandb_group='ADAF' \
137 | --device='GPU' \
138 | --wandb_api_key='your wandb api key' \
139 | > logs/train_${net_config}_${run_num}.log 2>&1 &
140 |
141 | ```
142 |
143 | ## Inference
144 | In order to run ADAF in inference mode you will need to have the following files on hand.
145 |
146 | 1. The path to the test sample file. (./data/test/)
147 |
148 | 2. The inference script (inference.py)
149 |
150 | 3. The trained model weights downloaded from the link above (./model_weights/best_ckpt.tar)
151 |
152 | 4. The pre-computed normalization statistics (./data/stats.csv)
153 |
154 | 5. The configuration file (./config/experiment.yaml)
155 |
156 | Once you have all the files listed above, you should be ready to go.
157 |
158 | An example launch script for inference is provided below. Here `--exp_dir` sets the directory where predictions are saved, `--test_data_path` points to the test data, `--net_config` selects the network configuration, and `--hold_out_obs_ratio` is the fraction of surface observations withheld from the model input (the remainder is fed to the model).
159 | ```shell
160 | export CUDA_VISIBLE_DEVICES='0'
161 |
162 | nohup python -u inference.py \
163 | --seed=0 \
164 |     --exp_dir='./exp/' \
165 |     --test_data_path='./data/test' \
166 |     --net_config='EncDec' \
167 |     --hold_out_obs_ratio=0.3 \
168 | > inference.log 2>&1 &
169 |
170 | ```
171 |
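The inference script writes one NetCDF file per analysis time under the experiment directory (see `save_output` in `inference.py`). Below is a minimal sketch for reading a result file; the folder name follows the `inference_ensemble_<ensemble_num>_hold_<ratio>` pattern used in the code, and the exact path is an assumption for illustration.

```python
import xarray as xr

# Hypothetical output path; the folder name depends on ensemble_num and hold_out_obs_ratio.
out = xr.open_dataset("./exp/inference_ensemble_1_hold_0.3/2022-10-01_00.nc")

ai_t = out["ai_gen_t"]      # AI analysis, dims (ensemble_num, lat, lon)
bg_t = out["bg_hrrr_t"]     # HRRR background, dims (lat, lon)
rtma_t = out["rtma_t"]      # RTMA target, dims (lat, lon)
print(float((ai_t.isel(ensemble_num=0) - rtma_t).mean()))
```
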
172 |
173 |
174 | ## References
175 |
176 | ```
177 | 1. Jin, W. et al. WeatherReal: A Benchmark Based on In-Situ Observations for Evaluating Weather Models. (2024).
178 | 2. Dowell, D. et al. The High-Resolution Rapid Refresh (HRRR): An Hourly Updating Convection-Allowing Forecast Model. Part I: Motivation and System Description. Weather and Forecasting 37, (2022).
179 | 3. Tan, B., Dellomo, J., Wolfe, R. & Reth, A. GOES-16 and GOES-17 ABI INR assessment. in Earth Observing Systems XXIV vol. 11127 290–301 (SPIE, 2019).
180 | 4. Hersbach, H. et al. ERA5 monthly averaged data on single levels from 1979 to present. Copernicus Climate Change Service (C3S) Climate Data Store (CDS) 10, 252–266 (2019).
181 | 5. Pondeca, M. S. F. V. D. et al. The Real-Time Mesoscale Analysis at NOAA’s National Centers for Environmental Prediction: Current Status and Development. Weather and Forecasting 26, 593–612 (2011).
182 | ```
183 |
184 | If you find this work useful, cite it using:
185 |
186 | ```
187 | @article{xiang2024ADAF,
188 | title={ADAF: An Artificial Intelligence Data Assimilation Framework for Weather Forecasting},
189 | author={Yanfei Xiang and Weixin Jin and Haiyu Dong and Mingliang Bai and Zuliang Fang and Pengcheng Zhao and Hongyu Sun and Kit Thambiratnam and Qi Zhang and Xiaomeng Huang},
190 | year={2024},
191 | journal={arXiv preprint arXiv:2411.16807},
192 | url={https://arxiv.org/abs/2411.16807},
193 | }
194 | ```
195 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | ## Ignore Visual Studio temporary files, build results, and
2 | ## files generated by popular Visual Studio add-ons.
3 | ##
4 | ## Get latest from https://github.com/github/gitignore/blob/main/VisualStudio.gitignore
5 |
6 | # User-specific files
7 | *.rsuser
8 | *.suo
9 | *.user
10 | *.userosscache
11 | *.sln.docstates
12 |
13 | # User-specific files (MonoDevelop/Xamarin Studio)
14 | *.userprefs
15 |
16 | # Mono auto generated files
17 | mono_crash.*
18 |
19 | # Build results
20 | [Dd]ebug/
21 | [Dd]ebugPublic/
22 | [Rr]elease/
23 | [Rr]eleases/
24 | x64/
25 | x86/
26 | [Ww][Ii][Nn]32/
27 | [Aa][Rr][Mm]/
28 | [Aa][Rr][Mm]64/
29 | bld/
30 | [Bb]in/
31 | [Oo]bj/
32 | [Ll]og/
33 | [Ll]ogs/
34 |
35 | # Visual Studio 2015/2017 cache/options directory
36 | .vs/
37 | # Uncomment if you have tasks that create the project's static files in wwwroot
38 | #wwwroot/
39 |
40 | # Visual Studio 2017 auto generated files
41 | Generated\ Files/
42 |
43 | # MSTest test Results
44 | [Tt]est[Rr]esult*/
45 | [Bb]uild[Ll]og.*
46 |
47 | # NUnit
48 | *.VisualState.xml
49 | TestResult.xml
50 | nunit-*.xml
51 |
52 | # Build Results of an ATL Project
53 | [Dd]ebugPS/
54 | [Rr]eleasePS/
55 | dlldata.c
56 |
57 | # Benchmark Results
58 | BenchmarkDotNet.Artifacts/
59 |
60 | # .NET Core
61 | project.lock.json
62 | project.fragment.lock.json
63 | artifacts/
64 |
65 | # ASP.NET Scaffolding
66 | ScaffoldingReadMe.txt
67 |
68 | # StyleCop
69 | StyleCopReport.xml
70 |
71 | # Files built by Visual Studio
72 | *_i.c
73 | *_p.c
74 | *_h.h
75 | *.ilk
76 | *.meta
77 | *.obj
78 | *.iobj
79 | *.pch
80 | *.pdb
81 | *.ipdb
82 | *.pgc
83 | *.pgd
84 | *.rsp
85 | # but not Directory.Build.rsp, as it configures directory-level build defaults
86 | !Directory.Build.rsp
87 | *.sbr
88 | *.tlb
89 | *.tli
90 | *.tlh
91 | *.tmp
92 | *.tmp_proj
93 | *_wpftmp.csproj
94 | *.log
95 | *.tlog
96 | *.vspscc
97 | *.vssscc
98 | .builds
99 | *.pidb
100 | *.svclog
101 | *.scc
102 |
103 | # Chutzpah Test files
104 | _Chutzpah*
105 |
106 | # Visual C++ cache files
107 | ipch/
108 | *.aps
109 | *.ncb
110 | *.opendb
111 | *.opensdf
112 | *.sdf
113 | *.cachefile
114 | *.VC.db
115 | *.VC.VC.opendb
116 |
117 | # Visual Studio profiler
118 | *.psess
119 | *.vsp
120 | *.vspx
121 | *.sap
122 |
123 | # Visual Studio Trace Files
124 | *.e2e
125 |
126 | # TFS 2012 Local Workspace
127 | $tf/
128 |
129 | # Guidance Automation Toolkit
130 | *.gpState
131 |
132 | # ReSharper is a .NET coding add-in
133 | _ReSharper*/
134 | *.[Rr]e[Ss]harper
135 | *.DotSettings.user
136 |
137 | # TeamCity is a build add-in
138 | _TeamCity*
139 |
140 | # DotCover is a Code Coverage Tool
141 | *.dotCover
142 |
143 | # AxoCover is a Code Coverage Tool
144 | .axoCover/*
145 | !.axoCover/settings.json
146 |
147 | # Coverlet is a free, cross platform Code Coverage Tool
148 | coverage*.json
149 | coverage*.xml
150 | coverage*.info
151 |
152 | # Visual Studio code coverage results
153 | *.coverage
154 | *.coveragexml
155 |
156 | # NCrunch
157 | _NCrunch_*
158 | .*crunch*.local.xml
159 | nCrunchTemp_*
160 |
161 | # MightyMoose
162 | *.mm.*
163 | AutoTest.Net/
164 |
165 | # Web workbench (sass)
166 | .sass-cache/
167 |
168 | # Installshield output folder
169 | [Ee]xpress/
170 |
171 | # DocProject is a documentation generator add-in
172 | DocProject/buildhelp/
173 | DocProject/Help/*.HxT
174 | DocProject/Help/*.HxC
175 | DocProject/Help/*.hhc
176 | DocProject/Help/*.hhk
177 | DocProject/Help/*.hhp
178 | DocProject/Help/Html2
179 | DocProject/Help/html
180 |
181 | # Click-Once directory
182 | publish/
183 |
184 | # Publish Web Output
185 | *.[Pp]ublish.xml
186 | *.azurePubxml
187 | # Note: Comment the next line if you want to checkin your web deploy settings,
188 | # but database connection strings (with potential passwords) will be unencrypted
189 | *.pubxml
190 | *.publishproj
191 |
192 | # Microsoft Azure Web App publish settings. Comment the next line if you want to
193 | # checkin your Azure Web App publish settings, but sensitive information contained
194 | # in these scripts will be unencrypted
195 | PublishScripts/
196 |
197 | # NuGet Packages
198 | *.nupkg
199 | # NuGet Symbol Packages
200 | *.snupkg
201 | # The packages folder can be ignored because of Package Restore
202 | **/[Pp]ackages/*
203 | # except build/, which is used as an MSBuild target.
204 | !**/[Pp]ackages/build/
205 | # Uncomment if necessary however generally it will be regenerated when needed
206 | #!**/[Pp]ackages/repositories.config
207 | # NuGet v3's project.json files produces more ignorable files
208 | *.nuget.props
209 | *.nuget.targets
210 |
211 | # Microsoft Azure Build Output
212 | csx/
213 | *.build.csdef
214 |
215 | # Microsoft Azure Emulator
216 | ecf/
217 | rcf/
218 |
219 | # Windows Store app package directories and files
220 | AppPackages/
221 | BundleArtifacts/
222 | Package.StoreAssociation.xml
223 | _pkginfo.txt
224 | *.appx
225 | *.appxbundle
226 | *.appxupload
227 |
228 | # Visual Studio cache files
229 | # files ending in .cache can be ignored
230 | *.[Cc]ache
231 | # but keep track of directories ending in .cache
232 | !?*.[Cc]ache/
233 |
234 | # Others
235 | ClientBin/
236 | ~$*
237 | *~
238 | *.dbmdl
239 | *.dbproj.schemaview
240 | *.jfm
241 | *.pfx
242 | *.publishsettings
243 | orleans.codegen.cs
244 |
245 | # Including strong name files can present a security risk
246 | # (https://github.com/github/gitignore/pull/2483#issue-259490424)
247 | #*.snk
248 |
249 | # Since there are multiple workflows, uncomment next line to ignore bower_components
250 | # (https://github.com/github/gitignore/pull/1529#issuecomment-104372622)
251 | #bower_components/
252 |
253 | # RIA/Silverlight projects
254 | Generated_Code/
255 |
256 | # Backup & report files from converting an old project file
257 | # to a newer Visual Studio version. Backup files are not needed,
258 | # because we have git ;-)
259 | _UpgradeReport_Files/
260 | Backup*/
261 | UpgradeLog*.XML
262 | UpgradeLog*.htm
263 | ServiceFabricBackup/
264 | *.rptproj.bak
265 |
266 | # SQL Server files
267 | *.mdf
268 | *.ldf
269 | *.ndf
270 |
271 | # Business Intelligence projects
272 | *.rdl.data
273 | *.bim.layout
274 | *.bim_*.settings
275 | *.rptproj.rsuser
276 | *- [Bb]ackup.rdl
277 | *- [Bb]ackup ([0-9]).rdl
278 | *- [Bb]ackup ([0-9][0-9]).rdl
279 |
280 | # Microsoft Fakes
281 | FakesAssemblies/
282 |
283 | # GhostDoc plugin setting file
284 | *.GhostDoc.xml
285 |
286 | # Node.js Tools for Visual Studio
287 | .ntvs_analysis.dat
288 | node_modules/
289 |
290 | # Visual Studio 6 build log
291 | *.plg
292 |
293 | # Visual Studio 6 workspace options file
294 | *.opt
295 |
296 | # Visual Studio 6 auto-generated workspace file (contains which files were open etc.)
297 | *.vbw
298 |
299 | # Visual Studio 6 auto-generated project file (contains which files were open etc.)
300 | *.vbp
301 |
302 | # Visual Studio 6 workspace and project file (working project files containing files to include in project)
303 | *.dsw
304 | *.dsp
305 |
306 | # Visual Studio 6 technical files
307 | *.ncb
308 | *.aps
309 |
310 | # Visual Studio LightSwitch build output
311 | **/*.HTMLClient/GeneratedArtifacts
312 | **/*.DesktopClient/GeneratedArtifacts
313 | **/*.DesktopClient/ModelManifest.xml
314 | **/*.Server/GeneratedArtifacts
315 | **/*.Server/ModelManifest.xml
316 | _Pvt_Extensions
317 |
318 | # Paket dependency manager
319 | .paket/paket.exe
320 | paket-files/
321 |
322 | # FAKE - F# Make
323 | .fake/
324 |
325 | # CodeRush personal settings
326 | .cr/personal
327 |
328 | # Python Tools for Visual Studio (PTVS)
329 | __pycache__/
330 | *.pyc
331 |
332 | # Cake - Uncomment if you are using it
333 | # tools/**
334 | # !tools/packages.config
335 |
336 | # Tabs Studio
337 | *.tss
338 |
339 | # Telerik's JustMock configuration file
340 | *.jmconfig
341 |
342 | # BizTalk build output
343 | *.btp.cs
344 | *.btm.cs
345 | *.odx.cs
346 | *.xsd.cs
347 |
348 | # OpenCover UI analysis results
349 | OpenCover/
350 |
351 | # Azure Stream Analytics local run output
352 | ASALocalRun/
353 |
354 | # MSBuild Binary and Structured Log
355 | *.binlog
356 |
357 | # NVidia Nsight GPU debugger configuration file
358 | *.nvuser
359 |
360 | # MFractors (Xamarin productivity tool) working folder
361 | .mfractor/
362 |
363 | # Local History for Visual Studio
364 | .localhistory/
365 |
366 | # Visual Studio History (VSHistory) files
367 | .vshistory/
368 |
369 | # BeatPulse healthcheck temp database
370 | healthchecksdb
371 |
372 | # Backup folder for Package Reference Convert tool in Visual Studio 2017
373 | MigrationBackup/
374 |
375 | # Ionide (cross platform F# VS Code tools) working folder
376 | .ionide/
377 |
378 | # Fody - auto-generated XML schema
379 | FodyWeavers.xsd
380 |
381 | # VS Code files for those working on multiple tools
382 | .vscode/*
383 | !.vscode/settings.json
384 | !.vscode/tasks.json
385 | !.vscode/launch.json
386 | !.vscode/extensions.json
387 | *.code-workspace
388 |
389 | # Local History for Visual Studio Code
390 | .history/
391 |
392 | # Windows Installer files from build outputs
393 | *.cab
394 | *.msi
395 | *.msix
396 | *.msm
397 | *.msp
398 |
399 | # JetBrains Rider
400 | *.sln.iml
401 |
--------------------------------------------------------------------------------
/utils/data_loader_multifiles.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import torch
4 | import logging
5 | import numpy as np
6 | import pandas as pd
7 | import xarray as xr
8 | from icecream import ic
9 | from torch.utils.data import DataLoader, Dataset
10 | from torch.utils.data.distributed import DistributedSampler
11 |
12 |
13 | def get_data_loader(params, files_pattern, distributed, train):
14 |     dataset = GetDataset(params, files_pattern, train)
15 | 
16 |     if distributed:
17 |         sampler = DistributedSampler(dataset, shuffle=train)
18 |     else:
19 |         sampler = None
20 | 
21 |     dataloader = DataLoader(
22 |         dataset,
23 |         batch_size=int(params.batch_size),
24 |         num_workers=params.num_data_workers,
25 |         shuffle=False,  # (sampler is None)
26 |         sampler=sampler if train else None,
27 |         drop_last=True,
28 |         pin_memory=torch.cuda.is_available(),
29 |     )
30 | 
31 |     if train:
32 |         return dataloader, dataset, sampler
33 |     else:
34 |         return dataloader, dataset
35 | 
36 | 
37 | class GetDataset(Dataset):
38 | def __init__(self, params, location, train):
39 | self.params = params
40 | self.train = train
41 | self.location = location
42 |         self.n_in_channels = params.N_in_channels
43 |         self.n_out_channels = params.N_out_channels
44 |         # self.add_noise = params.add_noise if train else False
45 | self._get_files_stats()
46 |
47 | def _get_files_stats(self):
48 | self.files_paths = glob.glob(self.location + "/*.nc")
49 | self.files_paths.sort()
50 | self.n_samples_total = len(self.files_paths)
51 |
52 | logging.info("getting file stats from {}".format(self.files_paths[0]))
53 | ds = xr.open_dataset(self.files_paths[0], engine="netcdf4")
54 |
55 | # original image shape (before padding)
56 | self.org_img_shape_x = ds["hrrr_t"].shape[0]
57 | self.org_img_shape_y = ds["hrrr_t"].shape[1]
58 |
59 |         self.files = [None for _ in range(self.n_samples_total)]
60 |
61 | logging.info("number of samples: {}".format(self.n_samples_total))
62 | logging.info(
63 | "found data at path {}. number of examples: {}. \
64 | original image shape: {} x {} x {}".format(
65 | self.location,
66 | self.n_samples_total,
67 | self.org_img_shape_x,
68 | self.org_img_shape_y,
69 | self.n_in_channels,
70 | )
71 | )
72 |
73 | def _open_file(self, hour_idx):
74 | _file = xr.open_dataset(self.files_paths[hour_idx], engine="netcdf4")
75 | self.files[hour_idx] = _file
76 |
77 | def _min_max_norm_ignore_extreme_fill_nan(self, data, vmin, vmax):
78 | # ic(vmin.shape, vmax.shape, data.shape)
79 | if data.ndim == 4:
80 | vmax = np.array(vmax)[:, np.newaxis, np.newaxis, np.newaxis]
81 | vmin = np.array(vmin)[:, np.newaxis, np.newaxis, np.newaxis]
82 | vmax = np.repeat(vmax, data.shape[1], axis=1)
83 | vmax = np.repeat(vmax, data.shape[2], axis=2)
84 | vmax = np.repeat(vmax, data.shape[3], axis=3)
85 | vmin = np.repeat(vmin, data.shape[1], axis=1)
86 | vmin = np.repeat(vmin, data.shape[2], axis=2)
87 | vmin = np.repeat(vmin, data.shape[3], axis=3)
88 |
89 | elif data.ndim == 3:
90 | vmax = np.array(vmax)[:, np.newaxis, np.newaxis]
91 | vmin = np.array(vmin)[:, np.newaxis, np.newaxis]
92 | vmax = np.repeat(vmax, data.shape[1], axis=1)
93 | vmax = np.repeat(vmax, data.shape[2], axis=2)
94 | vmin = np.repeat(vmin, data.shape[1], axis=1)
95 | vmin = np.repeat(vmin, data.shape[2], axis=2)
96 |
97 | data -= vmin
98 | data *= 2.0 / (vmax - vmin)
99 | data -= 1.0
100 |
101 | data = np.where(data > 1, 1, data)
102 | data = np.where(data < -1, -1, data)
103 |
104 | data = np.nan_to_num(data, nan=0)
105 |
106 | return data
107 |
108 | def __len__(self):
109 | return self.n_samples_total
110 |
111 | def __getitem__(self, hour_idx):
112 |         if self.files[hour_idx] is None:
113 | self._open_file(hour_idx)
114 |
115 | # %% get statistic value for normalization
116 | stats_file = os.path.join(
117 | self.params.data_path, f"stats_{self.params.norm_type}.csv"
118 | )
119 | stats = pd.read_csv(stats_file, index_col=0)
120 |
121 | for vi, var in enumerate(self.params.inp_hrrr_vars):
122 | if vi == 0:
123 | inp_hrrr_stats = stats[stats["variable"].isin([var])]
124 | else:
125 | inp_hrrr_stats = pd.concat(
126 | [inp_hrrr_stats, stats[stats["variable"].isin([var])]]
127 | )
128 |
129 | # %% read data in file
130 | if len(self.params.inp_hrrr_vars) != 0:
131 | inp_hrrr = np.array(
132 | self.files[hour_idx][self.params.inp_hrrr_vars].to_array()
133 | )[:, : self.params.img_size_y, : self.params.img_size_x]
134 | inp_hrrr = np.squeeze(inp_hrrr)
135 | # ic(inp_hrrr.shape)
136 |
137 | field_mask = inp_hrrr.copy()
138 |             field_mask[field_mask != 0] = 1  # 1 where the HRRR background is valid (0 = outside the domain)
139 |
140 | # normalization
141 | inp_hrrr = self._min_max_norm_ignore_extreme_fill_nan(
142 | inp_hrrr, inp_hrrr_stats["min"], inp_hrrr_stats["max"]
143 | )
144 |
145 | if len(self.params.inp_obs_vars) != 0:
146 | obs = np.array(
147 | self.files[hour_idx][self.params.inp_obs_vars].to_array())[
148 | :,
149 | -self.params.obs_time_window:,
150 | : self.params.img_size_y,
151 | : self.params.img_size_x,
152 | ]
153 |
154 |             # inp_obs contains no NaNs; 0 marks unobserved locations
155 | # use all observation as target
156 | obs_tar = obs[:, -1]
157 |
158 | # quality control
159 | obs_tar[(obs_tar <= -1) | (obs_tar >= 1)] = 0
160 |
161 | # this is for label, which is a combination of obs and analysis
162 | obs_tar_mask = obs_tar.copy()
163 | # 1 means observed, 0 means un-observed
164 | obs_tar_mask[obs_tar_mask != 0] = 1
165 | # print(f'obs_tar: {obs_tar.shape}')
166 |
167 | # print(f'inp_obs: {inp_obs.shape}')
168 | if self.params.hold_out_obs:
169 | # [lat, lon]
170 | # obs_mask = np.array(self.files[hour_idx]["obs_mask"])
171 | # # print(f'obs_mask: {obs_mask.shape}')
172 | # inp_obs = inp_obs * (1 - obs_mask)
173 |
174 | if self.params.obs_mask_seed != 0:
175 | np.random.seed(self.params.obs_mask_seed)
176 | logging.info(
177 | f"using random seed {self.params.obs_mask_seed}")
178 |
179 | lat_num = obs[0, 0].shape[0]
180 | lon_num = obs[0, 0].shape[1]
181 |
182 | # [lat, lon] -> [lat * lon]
183 | obs_tw_begin = obs[0, 0].reshape(-1)
184 | obs_index = np.where(~np.isnan(obs_tw_begin))[
185 | 0
186 | ] # find station's indices
187 |
188 | obs_num = len(obs_index)
189 | hold_out_num = int(obs_num * self.params.hold_out_obs_ratio)
190 | ic(obs_num, hold_out_num)
191 |
192 | np.random.shuffle(obs_index) # generate mask randomly
193 | hold_out_obs_index = obs_index[:hold_out_num]
194 | # input_obs_index = obs_index[hold_out_num:]
195 | # ic(len(hold_out_obs_index), hold_out_obs_index)
196 | # ic(len(input_obs_index), input_obs_index)
197 |
198 | # mask (lat, lon), hold_out obs=1, input obs = 0
199 | obs_mask = np.zeros(obs_tw_begin.shape)
200 | obs_mask[hold_out_obs_index] = 1
201 | obs_mask = obs_mask.reshape([lat_num, lon_num])
202 |
203 | # observation for input
204 | inp_obs = obs * (1 - obs_mask)
205 | # observation excluding the input
206 | # hold_out_obs = obs * obs_mask
207 |
208 | # ic(inp_obs.shape)
209 | inp_obs = inp_obs.reshape(
210 | (-1, self.params.img_size_y, self.params.img_size_x)
211 | )
212 |
213 | if len(self.params.inp_satelite_vars) != 0:
214 | inp_sate = np.array(
215 | self.files[hour_idx][self.params.inp_satelite_vars].to_array()
216 | )[
217 | :,
218 | -self.params.obs_time_window:,
219 | : self.params.img_size_y,
220 | : self.params.img_size_x,
221 | ]
222 |
223 | lon = np.array(self.files[hour_idx].coords["lon"].values)[
224 | : self.params.img_size_x
225 | ]
226 | lat = np.array(self.files[hour_idx].coords["lat"].values)[
227 | : self.params.img_size_y
228 | ]
229 | topo = np.array(self.files[hour_idx][["z"]].to_array())[
230 | :, : self.params.img_size_y, : self.params.img_size_x
231 | ]
232 | field_tar = np.array(
233 | self.files[hour_idx][self.params.field_tar_vars].to_array()
234 | )[:, : self.params.img_size_y, : self.params.img_size_x]
235 |
236 | # a combination of observation and target analysis field
237 | # use observed value to replace the analysis value
238 | # at observed locations
239 | field_obs_tar = field_tar.copy()
240 | # use 0 replace the value at observed location
241 | field_obs_tar[obs_tar_mask == 1] = 0
242 | field_obs_tar += obs_tar
243 |
244 | # norm(field_tar) - norm(inp_hrrr)
245 | # norm(obs_tar) - norm(inp_hrrr)
246 | # norm(field_obs_tar) - norm(inp_hrrr)
247 | if self.params.learn_residual:
248 | field_tar = field_tar - inp_hrrr
249 | obs_tar = obs_tar - inp_hrrr
250 | field_obs_tar = field_obs_tar - inp_hrrr
251 |
252 | inp = np.concatenate((inp_hrrr, inp_obs, topo), axis=0)
253 |
254 | if len(self.params.inp_satelite_vars) != 0:
255 | inp_sate = inp_sate.reshape(
256 | (-1, self.params.img_size_y, self.params.img_size_x)
257 | )
258 | inp = np.concatenate((inp, inp_sate))
259 |
260 | return (
261 | inp,
262 | field_tar,
263 | obs_tar,
264 | field_obs_tar,
265 | inp_hrrr,
266 | lat,
267 | lon,
268 | field_mask,
269 | obs_tar_mask,
270 | )
271 | else:
272 | return (
273 | inp,
274 | field_tar,
275 | obs_tar,
276 | field_obs_tar,
277 | inp_hrrr,
278 | lat,
279 | lon,
280 | field_mask,
281 | obs_tar_mask,
282 | )
283 |
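284 | # Example usage (a minimal sketch; assumes the "EncDec" section of config/experiment.yaml
285 | # and that batch_size, normally passed on the training command line, is set on params):
286 | #     from utils.YParams import YParams
287 | #     params = YParams("./config/experiment.yaml", "EncDec")
288 | #     params["batch_size"] = 16
289 | #     loader, dset, sampler = get_data_loader(
290 | #         params, params.train_data_path, distributed=False, train=True)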
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import time
4 | import torch
5 | import logging
6 | import datetime
7 | import argparse
8 | import numpy as np
9 | import pandas as pd
10 | import xarray as xr
11 |
12 | from math import sqrt
13 | from icecream import ic
14 | from str2bool import str2bool
15 | from collections import OrderedDict
16 | from sklearn.metrics import mean_squared_error
17 |
18 | from utils.logging_utils import config_logger, log_to_file
19 | from utils.YParams import YParams
20 | from utils.read_txt import read_lines_from_file
21 |
22 | config_logger()
23 |
24 |
25 | def gaussian_perturb(x, level=0.01, device=0):
26 | noise = level * torch.randn(x.shape).to(device, dtype=torch.float)
27 | return x + noise
28 |
29 |
30 | def enable_dropout(model):
31 | """Function to enable the dropout layers during test-time"""
32 | for m in model.modules():
33 | if m.__class__.__name__.startswith("Dropout"):
34 | m.train()
35 |
36 |
37 | def load_model(model, params, checkpoint_file):
38 | model.zero_grad()
39 | checkpoint_fname = checkpoint_file
40 | checkpoint = torch.load(checkpoint_fname)
41 | try:
42 | new_state_dict = OrderedDict()
43 | for key, val in checkpoint["model_state"].items():
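            # key[7:] strips the "module." prefix that DistributedDataParallel adds to parameter names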
44 | name = key[7:]
45 | if name != "ged":
46 | new_state_dict[name] = val
47 | model.load_state_dict(new_state_dict)
48 | except ValueError:
49 | model.load_state_dict(checkpoint["model_state"])
50 | model.eval()
51 | return model
52 |
53 |
54 | def setup(params):
55 |
56 | # device init
57 | if torch.cuda.is_available():
58 | device = torch.cuda.current_device()
59 | else:
60 | device = "cpu"
61 |
62 | if params.nettype == "EncDec":
63 | from models.encdec import EncDec as model
64 | else:
65 | raise Exception("not implemented")
66 |
67 | checkpoint_file = params["best_checkpoint_path"]
68 | logging.info("Loading model checkpoint from {}".format(checkpoint_file))
69 | model = model(params).to(device)
70 | model = load_model(model, params, checkpoint_file)
71 | model = model.to(device)
72 |
73 |     files_paths = glob.glob(params.test_data_path + "/*.nc")
74 |     files_paths.sort()
75 |     inference_times = [os.path.basename(f)[:-3] for f in files_paths]  # e.g. "2022-10-01_00"
76 |     return files_paths, inference_times, model
77 |
78 |
79 | def min_max_norm(data, vmin, vmax):
80 | if data.ndim == 4:
81 | vmax = np.array(vmax)[:, np.newaxis, np.newaxis, np.newaxis]
82 | vmin = np.array(vmin)[:, np.newaxis, np.newaxis, np.newaxis]
83 | vmax = np.repeat(vmax, data.shape[1], axis=1)
84 | vmax = np.repeat(vmax, data.shape[2], axis=2)
85 | vmax = np.repeat(vmax, data.shape[3], axis=3)
86 | vmin = np.repeat(vmin, data.shape[1], axis=1)
87 | vmin = np.repeat(vmin, data.shape[2], axis=2)
88 | vmin = np.repeat(vmin, data.shape[3], axis=3)
89 |
90 | elif data.ndim == 3:
91 | vmax = np.array(vmax)[:, np.newaxis, np.newaxis]
92 | vmin = np.array(vmin)[:, np.newaxis, np.newaxis]
93 | vmax = np.repeat(vmax, data.shape[1], axis=1)
94 | vmax = np.repeat(vmax, data.shape[2], axis=2)
95 | vmin = np.repeat(vmin, data.shape[1], axis=1)
96 | vmin = np.repeat(vmin, data.shape[2], axis=2)
97 |
98 | data = (data - vmin) / (vmax - vmin)
99 | return data
100 |
101 |
102 | def min_max_norm_ignore_extreme_fill_nan(data, vmin, vmax):
103 | if data.ndim == 4:
104 | vmax = np.array(vmax)[:, np.newaxis, np.newaxis, np.newaxis]
105 | vmin = np.array(vmin)[:, np.newaxis, np.newaxis, np.newaxis]
106 | vmax = np.repeat(vmax, data.shape[1], axis=1)
107 | vmax = np.repeat(vmax, data.shape[2], axis=2)
108 | vmax = np.repeat(vmax, data.shape[3], axis=3)
109 | vmin = np.repeat(vmin, data.shape[1], axis=1)
110 | vmin = np.repeat(vmin, data.shape[2], axis=2)
111 | vmin = np.repeat(vmin, data.shape[3], axis=3)
112 | elif data.ndim == 3:
113 | vmax = np.array(vmax)[:, np.newaxis, np.newaxis]
114 | vmin = np.array(vmin)[:, np.newaxis, np.newaxis]
115 | vmax = np.repeat(vmax, data.shape[1], axis=1)
116 | vmax = np.repeat(vmax, data.shape[2], axis=2)
117 | vmin = np.repeat(vmin, data.shape[1], axis=1)
118 | vmin = np.repeat(vmin, data.shape[2], axis=2)
119 |
120 | data -= vmin
121 | data *= 2.0 / (vmax - vmin)
122 | data -= 1.0
123 |
124 | data = np.where(data > 1, 1, data)
125 | data = np.where(data < -1, -1, data)
126 | data = np.nan_to_num(data, nan=0)
127 |
128 | return data
129 |
130 | def reverse_norm(params, data, variable_names):
131 |
132 | stats_file = os.path.join(
133 | params.data_path,
134 | f"stats_{params.norm_type}.csv")
135 | stats = pd.read_csv(stats_file, index_col=0)
136 |
137 | # for vi, var in enumerate(params.field_tar_vars):
138 | for vi, var in enumerate(variable_names):
139 | if vi == 0:
140 | field_tar_stats = stats[stats["variable"].isin([var])]
141 | else:
142 | field_tar_stats = pd.concat(
143 | [field_tar_stats, stats[stats["variable"].isin([var])]]
144 | )
145 | ic(field_tar_stats)
146 |
147 | if params.normalization == "minmax_ignore_extreme":
148 | vmin = field_tar_stats["min"]
149 | vmax = field_tar_stats["max"]
150 |
151 | if len(data.shape) == 4:
152 | vmin = np.array(vmin)[:, np.newaxis, np.newaxis, np.newaxis]
153 | vmin = np.repeat(vmin, data.shape[1], axis=1)
154 | vmin = np.repeat(vmin, data.shape[2], axis=2)
155 | vmin = np.repeat(vmin, data.shape[3], axis=3)
156 | vmin = np.squeeze(vmin)
157 |
158 | vmax = np.array(vmax)[:, np.newaxis, np.newaxis, np.newaxis]
159 | vmax = np.repeat(vmax, data.shape[1], axis=1)
160 | vmax = np.repeat(vmax, data.shape[2], axis=2)
161 | vmax = np.repeat(vmax, data.shape[3], axis=3)
162 | vmax = np.squeeze(vmax)
163 |
164 | if len(data.shape) == 3:
165 | vmin = np.array(vmin)[:, np.newaxis, np.newaxis]
166 | vmin = np.repeat(vmin, data.shape[1], axis=1)
167 | vmin = np.repeat(vmin, data.shape[2], axis=2)
168 | vmin = np.squeeze(vmin)
169 |
170 | vmax = np.array(vmax)[:, np.newaxis, np.newaxis]
171 | vmax = np.repeat(vmax, data.shape[1], axis=1)
172 | vmax = np.repeat(vmax, data.shape[2], axis=2)
173 | vmax = np.squeeze(vmax)
174 |
175 | if len(data.shape) == 2:
176 | vmin = np.array(vmin)[np.newaxis, np.newaxis]
177 | vmin = np.repeat(vmin, data.shape[0], axis=0)
178 | vmin = np.repeat(vmin, data.shape[1], axis=1)
179 | vmin = np.squeeze(vmin)
180 |
181 | vmax = np.array(vmax)[np.newaxis, np.newaxis]
182 | vmax = np.repeat(vmax, data.shape[0], axis=0)
183 | vmax = np.repeat(vmax, data.shape[1], axis=1)
184 | vmax = np.squeeze(vmax)
185 |
186 | data = (data + 1) * (vmax - vmin) / 2 + vmin
187 |
188 | else:
189 | raise Exception("not implemented")
190 |
191 | return data
192 |
193 |
194 | def inference(
195 | params,
196 | target_variable,
197 | test_data_file_paths,
198 | inference_times,
199 | hold_out_obs_ratio,
200 | model,
201 | ):
202 | if torch.cuda.is_available():
203 | device = torch.cuda.current_device()
204 | else:
205 | device = "cpu"
206 |
207 | out_dir = os.path.join(
208 | params["experiment_dir"],
209 | f"inference_ensemble_{params.ensemble_num}_hold_{hold_out_obs_ratio}",
210 | )
211 | os.makedirs(out_dir, exist_ok=True)
212 |
213 | with torch.no_grad():
214 | for f, analysis_time_str in zip(test_data_file_paths, inference_times):
215 | analysis_time = datetime.datetime.strptime(
216 | analysis_time_str, "%Y-%m-%d_%H")
217 | logging.info("-----------------------------------------")
218 | logging.info(f"Analysis time: {analysis_time_str}")
219 | logging.info(f"Reading {f}")
220 |
221 | out_file = os.path.join(
222 | out_dir, analysis_time_str + ".nc")
223 |
224 | if not os.path.exists(f):
225 | logging.info(f"{f} not exists, skip!")
226 | continue
227 |
228 | data = read_sample_file_and_norm_input(
229 | params, f, hold_out_obs_ratio)
230 | (
231 | inp, # normed
232 | inp_sate_norm, # normed
233 | inp_hrrr_norm, # normed
234 | field_target, # normed
235 | hold_out_obs, # normed
236 | inp_obs_for_eval, # normed
237 | bg_hrrr, # un-norm
238 | mask,
239 | lat,
240 | lon,
241 | ) = data
242 |
243 | field_target = reverse_norm(
244 | params, field_target, params.field_tar_vars)
245 |                 # mask the region outside the domain, after reverse normalization
246 |                 field_target = np.where(
247 |                     mask, field_target, np.nan
248 |                 )  # keep data where mask is True, NaN elsewhere
249 |                 bg_hrrr = np.where(
250 |                     mask, bg_hrrr, np.nan
251 |                 )  # keep data where mask is True, NaN elsewhere
252 |
253 | # Reverse normalization
254 | hold_out_obs = np.where(
255 | hold_out_obs == 0, np.nan, hold_out_obs
256 |                 )  # replace 0 with NaN before reversing the normalization
257 | hold_out_obs = reverse_norm(
258 | params, hold_out_obs, params.inp_obs_vars)
259 |
260 | inp_obs_for_eval = np.where(
261 | inp_obs_for_eval == 0, np.nan, inp_obs_for_eval
262 |                 )  # replace 0 with NaN before reversing the normalization
263 | inp_obs_for_eval = reverse_norm(
264 | params, inp_obs_for_eval, params.inp_obs_vars
265 | )
266 |
267 |                 # Unit conversion for specific humidity: kg/kg -> g/kg
268 | q_idx = target_variable.index("q")
269 | field_target[q_idx] = field_target[q_idx] * 1000
270 | hold_out_obs[q_idx] = hold_out_obs[q_idx] * 1000
271 | inp_obs_for_eval[q_idx] = inp_obs_for_eval[q_idx] * 1000
272 | bg_hrrr[q_idx] = bg_hrrr[q_idx] * 1000
273 |
274 | inp = torch.tensor(inp[np.newaxis, :, :, :]).to(
275 | device, dtype=torch.float)
276 | inp_sate_norm = torch.tensor(
277 | inp_sate_norm[np.newaxis, :, :, :, :]).to(
278 | device, dtype=torch.float
279 | )
280 |
281 | gen_ensembles = []
282 | for i in range(params.ensemble_num):
283 | model.eval()
284 | enable_dropout(model)
285 |
286 | start = time.time()
287 |
288 | if params.nettype == "EncDec":
289 | inp_sate_norm = torch.reshape(
290 | inp_sate_norm,
291 | (1, -1, params.img_size_y, params.img_size_x)
292 | )
293 | inp = torch.concat((inp, inp_sate_norm), 1)
294 | gen = model(inp)
295 | else:
296 | raise Exception("not implemented")
297 |
298 | print(f"inference time: {time.time() - start}")
299 |
300 | if params.learn_residual:
301 | gen = np.squeeze(
302 | gen.detach().cpu().numpy()) + inp_hrrr_norm
303 |
304 | # reverse normalization
305 | gen = reverse_norm(params, gen, params.field_tar_vars)
306 |
307 |                     # for specific humidity, kg/kg -> g/kg
308 | gen[q_idx] = gen[q_idx] * 1000
309 |
310 |                     # mask the region outside the domain with NaN,
311 |                     # after reverse normalization
312 | gen = np.where(mask, gen, np.nan)
313 |
314 | gen_ensembles.append(gen)
315 |
316 | gen_ensembles = np.array(gen_ensembles)
317 |
318 | for vi, tar_var in enumerate(target_variable):
319 | logging.info(f"{tar_var}:")
320 |
321 | # %% compare with hold-out observation
322 | obs_hold_obs_var = hold_out_obs[vi][
323 | ~np.isnan(hold_out_obs[vi])]
324 | if len(obs_hold_obs_var) == 0:
325 | print("No hold out obs, continue!")
326 | continue
327 | bg_hrrr_hold_obs = bg_hrrr[vi][
328 | ~np.isnan(hold_out_obs[vi])]
329 | ai_gen_hold_obs = gen_ensembles[0, vi][
330 | ~np.isnan(hold_out_obs[vi])]
331 | obs_hold_obs_var = np.nan_to_num(
332 | obs_hold_obs_var, nan=0)
333 | bg_hrrr_hold_obs = np.nan_to_num(
334 | bg_hrrr_hold_obs, nan=0)
335 | ai_gen_hold_obs = np.nan_to_num(
336 | ai_gen_hold_obs, nan=0)
337 | rmse_ai_hold_obs = round(
338 | sqrt(mean_squared_error(
339 | obs_hold_obs_var, ai_gen_hold_obs)), 3
340 | )
341 | rmse_bg_hold_obs = round(
342 | sqrt(mean_squared_error(
343 | obs_hold_obs_var, bg_hrrr_hold_obs)), 3
344 | )
345 | logging.info(f"rmse_ai_hold_obs={rmse_ai_hold_obs}")
346 | logging.info(f"rmse_bg_hold_obs={rmse_bg_hold_obs}")
347 |
348 | # %% compare with input observation
349 | inp_obs_for_eval_var = inp_obs_for_eval[vi][
350 | ~np.isnan(inp_obs_for_eval[vi])
351 | ]
352 | bg_hrrr_inp_obs = bg_hrrr[vi][
353 | ~np.isnan(inp_obs_for_eval[vi])]
354 | ai_gen_inp_obs = gen_ensembles[0, vi][
355 | ~np.isnan(inp_obs_for_eval[vi])]
356 | inp_obs_for_eval_var = np.nan_to_num(
357 | inp_obs_for_eval_var, nan=0)
358 | bg_hrrr_inp_obs = np.nan_to_num(
359 | bg_hrrr_inp_obs, nan=0)
360 | ai_gen_inp_obs = np.nan_to_num(
361 | ai_gen_inp_obs, nan=0)
362 | rmse_ai_inp_obs = round(
363 | sqrt(mean_squared_error(
364 | inp_obs_for_eval_var, ai_gen_inp_obs)), 3
365 | )
366 | rmse_bg_inp_obs = round(
367 | sqrt(mean_squared_error(
368 | inp_obs_for_eval_var, bg_hrrr_inp_obs)), 3
369 | )
370 | logging.info(f"rmse_ai_inp_obs={rmse_ai_inp_obs}")
371 | logging.info(f"rmse_bg_inp_obs={rmse_bg_inp_obs}")
372 |
373 | rmse_ai_field = round(
374 | sqrt(
375 | mean_squared_error(
376 | np.nan_to_num(gen_ensembles[0, vi], nan=0),
377 | np.nan_to_num(field_target[vi], nan=0),
378 | )
379 | ),
380 | 3,
381 | )
382 | logging.info(f"rmse_ai_field={rmse_ai_field}")
383 |
384 | rmse_bg_field = round(
385 | sqrt(
386 | mean_squared_error(
387 | np.nan_to_num(bg_hrrr[vi], nan=0),
388 | np.nan_to_num(field_target[vi], nan=0),
389 | )
390 | ),
391 | 3,
392 | )
393 | logging.info(f"rmse_bg_field={rmse_bg_field}")
394 |
395 | logging.info(
396 | "AI generation :"
397 | + f"{round(np.nanmin(gen_ensembles[0,vi]), 3)}"
398 | + f"~ {round(np.nanmax(gen_ensembles[0,vi]), 3)}"
399 | )
400 | logging.info(
401 | "Background (hrrr):"
402 | + f"{round(np.nanmin(bg_hrrr[vi]), 3)}"
403 | + f"~ {round(np.nanmax(bg_hrrr[vi]), 3)}"
404 | )
405 | logging.info(
406 | "hold out obs :"
407 | + f"{round(np.nanmin(hold_out_obs[vi]), 3)}"
408 | + f"~ {round(np.nanmax(hold_out_obs[vi]), 3)}"
409 | )
410 | logging.info(
411 | "field_target :"
412 | + f"{round(np.nanmin(field_target[vi]), 3)}"
413 | + f"~ {round(np.nanmax(field_target[vi]), 3)}"
414 | )
415 |
416 | variable_names = [s.split("_")[1] for s in params.field_tar_vars]
417 | ic(variable_names)
418 | save_output(
419 | save_file_path=out_file,
420 | analysis_time=analysis_time,
421 | variable_names=variable_names,
422 | AI_gen_ensembles=gen_ensembles, # un-normed
423 | bg_hrrr=bg_hrrr, # un-normed
424 | field_target=field_target, # un-normed
425 | inp_obs_for_eval=inp_obs_for_eval, # un-normed
426 | hold_out_obs=hold_out_obs, # un-normed
427 | lon=lon,
428 | lat=lat,
429 | mask=mask,
430 | )
431 |
432 |
433 | def save_output(
434 | save_file_path: str,
435 | variable_names: list,
436 | AI_gen_ensembles: np.array,
437 | bg_hrrr: np.array,
438 | field_target: np.array,
439 | hold_out_obs: np.array,
440 | inp_obs_for_eval: np.array,
441 | lon: np.array,
442 | lat: np.array,
443 | mask: np.array,
444 | analysis_time: str,
445 | ):
446 |
447 | ic(
448 | lon.shape,
449 | lat.shape,
450 | AI_gen_ensembles.shape,
451 | bg_hrrr.shape,
452 | field_target.shape,
453 | hold_out_obs.shape,
454 | )
455 | ds = xr.Dataset(
456 | {
457 | f"ai_gen_{variable_names[0]}": (
458 | ("ensemble_num", "lat", "lon"),
459 | AI_gen_ensembles[:, 0],
460 | ),
461 | f"ai_gen_{variable_names[1]}": (
462 | ("ensemble_num", "lat", "lon"),
463 | AI_gen_ensembles[:, 1],
464 | ),
465 | f"ai_gen_{variable_names[2]}": (
466 | ("ensemble_num", "lat", "lon"),
467 | AI_gen_ensembles[:, 2],
468 | ),
469 | f"ai_gen_{variable_names[3]}": (
470 | ("ensemble_num", "lat", "lon"),
471 | AI_gen_ensembles[:, 3],
472 | ),
473 | # f"ai_gen_{variable_names[4]}": (
474 | # ("ensemble_num", "lat", "lon"),
475 | # AI_gen_ensembles[:, 4],
476 | # ),
477 | f"hold_out_obs_{variable_names[0]}": (
478 | ("lat", "lon"), hold_out_obs[0]),
479 | f"hold_out_obs_{variable_names[1]}": (
480 | ("lat", "lon"), hold_out_obs[1]),
481 | f"hold_out_obs_{variable_names[2]}": (
482 | ("lat", "lon"), hold_out_obs[2]),
483 | f"hold_out_obs_{variable_names[3]}": (
484 | ("lat", "lon"), hold_out_obs[3]),
485 | # f"hold_out_obs_{variable_names[4]}": (
486 | # ("lat", "lon"), hold_out_obs[4]),
487 | f"inp_obs_for_eval_{variable_names[0]}": (
488 | ("lat", "lon"),
489 | inp_obs_for_eval[0],
490 | ),
491 | f"inp_obs_for_eval_{variable_names[1]}": (
492 | ("lat", "lon"),
493 | inp_obs_for_eval[1],
494 | ),
495 | f"inp_obs_for_eval_{variable_names[2]}": (
496 | ("lat", "lon"),
497 | inp_obs_for_eval[2],
498 | ),
499 | f"inp_obs_for_eval_{variable_names[3]}": (
500 | ("lat", "lon"),
501 | inp_obs_for_eval[3],
502 | ),
503 | f"rtma_{variable_names[0]}": (
504 | ("lat", "lon"), field_target[0]),
505 | f"rtma_{variable_names[1]}": (
506 | ("lat", "lon"), field_target[1]),
507 | f"rtma_{variable_names[2]}": (
508 | ("lat", "lon"), field_target[2]),
509 | f"rtma_{variable_names[3]}": (
510 | ("lat", "lon"), field_target[3]),
511 | # f"rtma_{variable_names[4]}": (
512 | # ("lat", "lon"), field_target[4]),
513 | f"bg_hrrr_{variable_names[0]}": (
514 | ("lat", "lon"), bg_hrrr[0]),
515 | f"bg_hrrr_{variable_names[1]}": (
516 | ("lat", "lon"), bg_hrrr[1]),
517 | f"bg_hrrr_{variable_names[2]}": (
518 | ("lat", "lon"), bg_hrrr[2]),
519 | f"bg_hrrr_{variable_names[3]}": (
520 | ("lat", "lon"), bg_hrrr[3]),
521 | # f"bg_hrrr_{variable_names[4]}": (
522 | # ("lat", "lon"), bg_hrrr[4]),
523 | "variable_names": (
524 | ("variable_num"), variable_names),
525 | "mask": (
526 | ("lat", "lon"), mask[0]),
527 | },
528 | coords={
529 | "lat": lat,
530 | "lon": lon,
531 | "time": analysis_time,
532 | "time_window": np.arange(0, params.obs_time_window),
533 | "ensemble_num": np.arange(0, params.ensemble_num),
534 | },
535 | )
536 | logging.info(f"Saving result to {save_file_path}")
537 | ds.to_netcdf(save_file_path)
538 | ds.close()
539 |
540 |
541 | def read_sample_file_and_norm_input(
542 | params,
543 | file_path,
544 | hold_out_obs_ratio=0.2
545 | ):
546 |
547 | # %% get statistic
548 | stats_file = os.path.join(
549 |         params.data_path, "stats.csv")
550 |
551 | stats = pd.read_csv(stats_file, index_col=0)
552 |
553 | for vi, var in enumerate(params.inp_hrrr_vars):
554 | if vi == 0:
555 | inp_hrrr_stats = stats[stats["variable"].isin([var])]
556 | else:
557 | inp_hrrr_stats = pd.concat(
558 | [inp_hrrr_stats, stats[stats["variable"].isin([var])]]
559 | )
560 |
561 | for vi, var in enumerate(params.inp_obs_vars):
562 | if vi == 0:
563 | inp_obs_stats = stats[stats["variable"].isin([var])]
564 | else:
565 | inp_obs_stats = pd.concat(
566 | [inp_obs_stats, stats[stats["variable"].isin([var])]]
567 | )
568 |
569 | # %% get sample
570 | ds = xr.open_dataset(file_path, engine="netcdf4")
571 |
572 | lat = np.array(ds.coords["lat"].values)[: params.img_size_y]
573 | lon = np.array(ds.coords["lon"].values)[: params.img_size_x]
574 |
575 | # background
576 | inp_hrrr = np.array(ds[params.inp_hrrr_vars].to_array())[
577 | :, : params.img_size_y, : params.img_size_x
578 | ]
579 | inp_hrrr = np.squeeze(inp_hrrr)
580 | mask = inp_hrrr.copy()
581 | mask[mask != 0] = 1 # set 1 where out of range
582 | mask = mask.astype(bool) # True: out of range
583 |
584 | # baseline for evaluation
585 | bg_hrrr = inp_hrrr.copy()
586 |
587 | # normalization
588 | inp_hrrr = min_max_norm_ignore_extreme_fill_nan(
589 | inp_hrrr, inp_hrrr_stats["min"], inp_hrrr_stats["max"]
590 | )
591 |
592 | # topography (normed)
593 | topo = np.array(ds[["z"]].to_array())[
594 | :, :params.img_size_y, :params.img_size_x]
595 |
596 | # satellite
597 |     inp_sate = np.array(ds[params.inp_satelite_vars].to_array())[
598 | :, -params.obs_time_window:, :params.img_size_y, :params.img_size_x
599 | ]
600 |
601 | # Observation (normed)
602 | obs = np.array(ds[params.inp_obs_vars].to_array())[
603 | :, -params.obs_time_window:, :params.img_size_y, :params.img_size_x
604 | ]
605 | # quality control
606 | obs[(obs <= -1) | (obs >= 1)] = 0
607 |
608 | if params.hold_out_obs:
609 |
610 | if params.seed != 0:
611 | np.random.seed(params.seed)
612 | logging.info(f"Using random seed {params.seed}")
613 |
614 | lat_num = obs[0, 0].shape[0]
615 | lon_num = obs[0, 0].shape[1]
616 |
617 | # [lat, lon] -> [lat * lon]
618 | obs_tw_begin = obs[0, 0].reshape(-1)
619 | # find station's indices
620 | obs_index = np.where(~np.isnan(obs_tw_begin))[0]
621 |
622 | obs_num = len(obs_index)
623 | hold_out_num = int(obs_num * hold_out_obs_ratio)
624 | ic(obs_num, hold_out_num)
625 |
626 | # generate mask randomly
627 | np.random.shuffle(obs_index)
628 | hold_out_obs_index = obs_index[:hold_out_num]
629 | input_obs_index = obs_index[hold_out_num:]
630 | ic(len(hold_out_obs_index), hold_out_obs_index)
631 | ic(len(input_obs_index), input_obs_index)
632 |
633 | # Mask (lat, lon), hold_out obs=1, input obs = 0
634 | obs_mask = np.zeros(obs_tw_begin.shape)
635 | obs_mask[hold_out_obs_index] = 1
636 | obs_mask = obs_mask.reshape([lat_num, lon_num])
637 |
638 | inp_obs = obs * (1 - obs_mask) # observation for input
639 | hold_out_obs = obs * obs_mask # observation excluding the input
640 |
641 | inp_obs_for_eval = inp_obs[:, -1] # -1: analysis time
642 | print(f"inp_obs_for_eval: {inp_obs_for_eval.shape}")
643 |
644 | inp_obs = inp_obs.reshape((-1, params.img_size_y, params.img_size_x))
645 |
646 | # %% target (normed)
647 | field_target = np.array(ds[params.field_tar_vars].to_array())[
648 | :, : params.img_size_y, : params.img_size_x
649 | ]
650 | ic(field_target.shape)
651 |
652 | inp = np.concatenate((inp_hrrr, inp_obs, topo), axis=0)
653 |
654 | return (
655 | inp, # normed
656 | inp_sate, # normed
657 | inp_hrrr, # normed
658 | field_target, # normed
659 |         hold_out_obs[:, -1],  # normed, -1: analysis time
660 | inp_obs_for_eval, # normed
661 | bg_hrrr, # un-normed
662 | mask, # out_of_range mask
663 | lat,
664 | lon,
665 | )
666 |
667 |
668 | if __name__ == "__main__":
669 | parser = argparse.ArgumentParser()
670 | parser.add_argument("--seed", default=0, type=int)
671 | parser.add_argument("--exp_dir", default="", type=str)
672 | parser.add_argument("--test_data_path", default="", type=str)
673 | parser.add_argument("--net_config", default="EncDec", type=str)
674 | parser.add_argument("--hold_out_obs_ratio", type=float, default=0.2)
675 |
676 | args = parser.parse_args()
677 |
678 | config_path = os.path.join(args.exp_dir, "config.yaml")
679 |
680 | params = YParams(config_path, args.net_config)
681 | params["resuming"] = False
682 | params["seed"] = args.seed
683 | params["experiment_dir"] = args.exp_dir
684 | params["test_data_path"] = args.test_data_path
685 | params["best_checkpoint_path"] = os.path.join(
686 | params["experiment_dir"], "training_checkpoints", "best_ckpt.tar"
687 | )
688 |
689 | # set up logging
690 | log_to_file(
691 | logger_name=None,
692 | log_filename=os.path.join(params["experiment_dir"], "inference.log"),
693 | )
694 | params.log()
695 |
696 | # get data files and model
697 | test_data_file_paths, inference_times, model = setup(params)
698 |
699 | target_variable = [var.split("_")[1] for var in params.field_tar_vars]
700 |
701 | inference(
702 | params,
703 | target_variable,
704 | test_data_file_paths,
705 | inference_times,
706 | args.hold_out_obs_ratio,
707 | model,
708 | )
709 |
710 | logging.info("Done")
711 |
712 |
713 |
--------------------------------------------------------------------------------
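
How the hold-out split in read_sample_file_and_norm_input above works, as a minimal self-contained sketch: the observation grid is flattened, station indices are located, and a random fraction of them is withheld for evaluation. The toy 4x5 array, seed, and ratio below are illustrative values only, not ones used by this repository.

import numpy as np

# toy observation grid: NaN = no station, 1.0 = a station reading
obs = np.full((4, 5), np.nan)
obs[np.unravel_index([1, 3, 7, 12, 18], obs.shape)] = 1.0

hold_out_obs_ratio = 0.2                      # same default as inference.py
flat = obs.reshape(-1)                        # [lat, lon] -> [lat * lon]
obs_index = np.where(~np.isnan(flat))[0]      # indices of stations

np.random.seed(0)
np.random.shuffle(obs_index)
hold_out_num = int(len(obs_index) * hold_out_obs_ratio)
hold_out_index = obs_index[:hold_out_num]     # withheld for evaluation

# mask over (lat, lon): held-out stations = 1, everything else = 0
obs_mask = np.zeros(flat.shape)
obs_mask[hold_out_index] = 1
obs_mask = obs_mask.reshape(obs.shape)

inp_obs = obs * (1 - obs_mask)                # goes into the model input
hold_out_obs = obs * obs_mask                 # kept back for verification
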
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import wandb
4 | import random
5 | import datetime
6 | import argparse
7 | import numpy as np
8 |
9 | from str2bool import str2bool
10 | from icecream import ic
11 | from shutil import copyfile
12 | from apex import optimizers
13 | from collections import OrderedDict
14 |
15 | import torch
16 | import torch.cuda.amp as amp
17 | import torch.distributed as dist
18 | from torch.nn import functional as F
19 | from torch.nn.parallel import DistributedDataParallel
20 |
21 | from ruamel.yaml import YAML
22 | from ruamel.yaml.comments import CommentedMap as ruamelDict
23 |
24 | from utils.data_loader_multifiles import get_data_loader
25 | from utils.logging_utils import log_to_file
26 | from utils.YParams import YParams
27 |
28 |
29 | class Trainer:
30 | def count_parameters(self):
31 |         count_params = sum(
32 |             p.numel() for p in self.model.parameters()
33 |             if p.requires_grad)
34 |         return count_params
35 |
36 | def set_device(self):
37 | if torch.cuda.is_available():
38 | self.device = torch.cuda.current_device()
39 | else:
40 | self.device = "cpu"
41 |
42 | def __init__(self, params, world_rank):
43 | self.params = params
44 | self.world_rank = world_rank
45 | self.set_device()
46 |
47 | # %% init wandb
48 | if params.log_to_wandb:
49 | wandb.init(
50 | config=params,
51 | name=params.name,
52 | group=params.group,
53 | project=params.project,
54 | entity=params.entity,
55 | settings={"_service_wait": 600, "init_timeout": 600},
56 | )
57 |
58 | # %% init gpu
59 | local_rank = int(os.environ["LOCAL_RANK"])
60 | torch.cuda.set_device(local_rank)
61 | self.device = torch.device("cuda", local_rank)
62 | print("device: %s" % self.device)
63 |
64 | # %% model init
65 | if params.nettype == "EncDec":
66 | from models.encdec import EncDec as model
67 | elif params.nettype == "EncDec_two_encoder":
68 | from models.encdec import EncDec_two_encoder as model
69 | else:
70 | raise Exception("not implemented")
71 | self.model = model(params).to(self.device)
72 | # self.model = model(params).to(local_rank) # for torchrun
73 |
74 | # %% Load data
75 | print("rank %d, begin data loader init" % world_rank)
76 | (
77 | self.train_data_loader,
78 | self.train_dataset,
79 | self.train_sampler,
80 | ) = get_data_loader(
81 | params,
82 | params.train_data_path,
83 | dist.is_initialized(),
84 | train=True,
85 | )
86 | (
87 | self.valid_data_loader,
88 | self.valid_dataset,
89 | self.valid_sampler,
90 | ) = get_data_loader(
91 | params,
92 | params.valid_data_path,
93 | dist.is_initialized(),
94 | train=True,
95 | )
96 |
97 | # %% optimizer
98 | if params.optimizer_type == "FusedAdam":
99 | self.optimizer = optimizers.FusedAdam(
100 | self.model.parameters(), lr=params.lr)
101 | elif params.optimizer_type == "Adam":
102 | self.optimizer = torch.optim.Adam(
103 | self.model.parameters(), lr=params.lr)
104 | elif params.optimizer_type == "AdamW":
105 | self.optimizer = torch.optim.AdamW(
106 | self.model.parameters(), lr=params.lr)
107 | else:
108 | raise Exception("not implemented")
109 |
110 | if params.enable_amp:
111 | self.gscaler = amp.GradScaler()
112 |
113 | # %% DDP
114 | if dist.is_initialized():
115 | ic(local_rank)
116 | self.model = DistributedDataParallel(
117 | self.model,
118 | device_ids=[params.local_rank],
119 |                 output_device=params.local_rank,
120 | find_unused_parameters=True,
121 | )
122 | self.iters = 0
123 | self.startEpoch = 0
124 | self.plot = False
125 | self.plot_img_path = None
126 |
127 | # %% Dynamical Learning rate
128 | if params.scheduler == "ReduceLROnPlateau":
129 | self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
130 | self.optimizer,
131 | factor=params.lr_reduce_factor,
132 | patience=20,
133 | mode="min",
134 | )
135 | elif params.scheduler == "CosineAnnealingLR":
136 | self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
137 | self.optimizer,
138 | T_max=params.max_epochs,
139 | last_epoch=self.startEpoch - 1,
140 | )
141 | else:
142 | self.scheduler = None
143 |
144 | # %% Resume train
145 | if params.resuming:
146 | print(f"Loading checkpoint from {params.best_checkpoint_path}")
147 | self.restore_checkpoint(params.best_checkpoint_path)
148 |
149 | self.epoch = self.startEpoch
150 |
151 | if params.log_to_screen:
152 | print(
153 | f"Number of trainable model parameters: \
154 | {self.count_parameters()}"
155 | )
156 |
157 | if params.log_to_wandb:
158 | wandb.watch(self.model)
159 |
160 | def train(self):
161 | if self.params.log_to_screen:
162 | print("Starting Training Loop...")
163 |
164 | # best_valid_obs_loss = 1.0e6
165 | best_train_loss = 1.0e6
166 |
167 | for epoch in range(self.startEpoch, self.params.max_epochs):
168 | if dist.is_initialized():
169 | # different batch on each GPU
170 | self.train_sampler.set_epoch(epoch)
171 | self.valid_sampler.set_epoch(epoch)
172 | start = time.time()
173 |
174 | # train one epoch
175 | tr_time, data_time, step_time, train_logs = self.train_one_epoch()
176 | self.plot = False
177 | self.plot_img_path = None
178 | current_lr = self.optimizer.param_groups[0]["lr"]
179 |
180 | if self.params.log_to_screen:
181 | print(f"Epoch: {epoch + 1}")
182 | print(f"train data time={data_time}")
183 | print(f"train per step time={step_time}")
184 | print(f"train loss: {train_logs['loss_field']}")
185 | print(f"learning rate: {current_lr}")
186 |
187 | # valid one epoch
188 | if (epoch != 0) and (epoch % self.params.valid_frequency == 0):
189 | valid_time, valid_logs = self.validate_one_epoch()
190 |
191 | if self.params.log_to_screen:
192 | print(f"Epoch: {epoch + 1}")
193 | print(f"Valid time={valid_time}")
194 | print(f"Valid loss={valid_logs['valid_loss_field']}")
195 |
196 | # LR scheduler
197 | if self.params.scheduler == "ReduceLROnPlateau":
198 | self.scheduler.step(valid_logs["valid_loss_field"])
199 |
200 | if self.params.log_to_wandb:
201 | wandb.log({"lr": current_lr})
202 |
203 | # save model
204 | if (
205 | self.world_rank == 0
206 | and epoch % self.params.save_model_freq == 0
207 | and self.params.save_checkpoint
208 | ):
209 | self.save_checkpoint(self.params.checkpoint_path)
210 |
211 | if self.world_rank == 0 and self.params.save_checkpoint:
212 | if train_logs["loss_field"] <= best_train_loss:
213 | print(
214 | "Loss improved from {} to {}".format(
215 | best_train_loss, train_logs["loss_field"]
216 | )
217 | )
218 | best_train_loss = train_logs["loss_field"]
219 |
220 | start = time.time()
221 | self.save_checkpoint(self.params.best_checkpoint_path)
222 | print(f"save model time: {time.time() - start}")
223 |
224 | def loss_function(
225 | self,
226 | pre_field,
227 | tar_field,
228 | tar_obs,
229 | tar_field_obs,
230 | field_mask=None,
231 | obs_tar_mask=None,
232 | mask_out_of_range=True,
233 | ):
234 | """
235 | pre_field: model's output
236 | tar_field: label, after normalization
237 | """
238 |
239 | if mask_out_of_range:
240 |             pre_field = torch.masked_fill(
241 |                 input=pre_field, mask=~field_mask, value=0
242 |             )  # zero out entries where field_mask is False
243 |             tar_field = torch.masked_fill(
244 |                 input=tar_field, mask=~field_mask, value=0
245 |             )  # zero out entries where field_mask is False
246 |             tar_field_obs = torch.masked_fill(
247 |                 input=tar_field_obs, mask=~field_mask, value=0
248 |             )  # zero out entries where field_mask is False
249 |
250 | # type 1 loss
251 | loss_field = F.mse_loss(
252 | pre_field, tar_field)
253 | loss_field_channel_wise = F.mse_loss(
254 | pre_field, tar_field, reduction="none")
255 | loss_field_channel_wise = torch.mean(
256 | loss_field_channel_wise, dim=(0, 2, 3))
257 |
258 | # type 2 loss
259 | loss_field_obs = F.mse_loss(
260 | pre_field, tar_field_obs)
261 |
262 | # type 3 loss
263 | pre_field = torch.masked_fill(
264 | input=pre_field, mask=~obs_tar_mask, value=0
265 |         )  # zero out entries where obs_tar_mask is False
266 | tar_obs = torch.masked_fill(
267 | input=tar_obs, mask=~obs_tar_mask, value=0)
268 | loss_obs = F.mse_loss(
269 | pre_field, tar_obs)
270 | loss_obs_channel_wise = F.mse_loss(
271 | pre_field, tar_obs, reduction="none")
272 | loss_obs_channel_wise = torch.mean(
273 | loss_obs_channel_wise, dim=(0, 2, 3))
274 |
275 | return {
276 | "loss_field": loss_field,
277 | "loss_field_channel_wise": loss_field_channel_wise,
278 | "loss_obs": loss_obs,
279 | "loss_obs_channel_wise": loss_obs_channel_wise,
280 | "loss_field_obs": loss_field_obs,
281 | }
282 |
283 | def train_one_epoch(self):
284 | print("Training...")
285 | self.epoch += 1
286 | if self.params.resuming:
287 | self.resumeEpoch += 1
288 | tr_time = 0
289 | data_time = 0
290 | steps_in_one_epoch = 0
291 | loss_field = 0
292 | loss_obs = 0
293 | loss_field_obs = 0
294 | loss_field_channel_wise = torch.zeros(
295 | len(self.params.target_vars), device=self.device, dtype=float
296 | )
297 | loss_obs_channel_wise = torch.zeros(
298 | len(self.params.target_vars), device=self.device, dtype=float
299 | )
300 |
301 | self.model.train()
302 | for i, data in enumerate(self.train_data_loader, 0):
303 | self.iters += 1
304 | steps_in_one_epoch += 1
305 | data_start = time.time()
306 |
307 | if self.params.nettype == "EncDec_two_encoder":
308 | (
309 | inp,
310 | inp_sate,
311 | target_field,
312 | target_obs,
313 | target_field_obs,
314 | inp_hrrr,
315 | _,
316 | _,
317 | field_mask,
318 | obs_tar_mask,
319 | ) = data
320 | if self.params.nettype == "EncDec":
321 | (
322 | inp,
323 | target_field,
324 | target_obs,
325 | target_field_obs,
326 | inp_hrrr,
327 | _,
328 | _,
329 | field_mask,
330 | obs_tar_mask,
331 | ) = data
332 |
333 | data_time += time.time() - data_start
334 | tr_start = time.time()
335 |
336 | self.model.zero_grad()
337 | with amp.autocast(self.params.enable_amp):
338 | inp = inp.to(self.device, dtype=torch.float)
339 | inp_hrrr = inp_hrrr.to(self.device, dtype=torch.float)
340 | target_field = target_field.to(self.device, dtype=torch.float)
341 | target_obs = target_obs.to(
342 | self.device, dtype=torch.float)
343 | target_field_obs = target_field_obs.to(
344 | self.device, dtype=torch.float)
345 | field_mask = torch.as_tensor(
346 | field_mask, dtype=torch.bool, device=self.device
347 | )
348 | obs_tar_mask = torch.as_tensor(
349 | obs_tar_mask, dtype=torch.bool, device=self.device
350 | )
351 |
352 | if self.params.nettype == "EncDec":
353 | gen = self.model(inp)
354 | if self.params.nettype == "EncDec_two_encoder":
355 | inp_sate = inp_sate.to(self.device, dtype=torch.float)
356 | gen = self.model(inp, inp_sate)
357 |                 gen = gen.to(self.device, dtype=torch.float)
358 |
359 | loss = self.loss_function(
360 | pre_field=gen,
361 | tar_field=target_field,
362 | tar_obs=target_obs,
363 | tar_field_obs=target_field_obs,
364 | field_mask=field_mask,
365 | obs_tar_mask=obs_tar_mask,
366 | )
367 |
368 | loss_field += loss["loss_field"]
369 | loss_obs += loss["loss_obs"]
370 | loss_field_obs += loss["loss_field_obs"]
371 | loss_field_channel_wise += loss["loss_field_channel_wise"]
372 | loss_obs_channel_wise += loss["loss_obs_channel_wise"]
373 |
374 | self.optimizer.zero_grad()
375 | if self.params.target == "obs":
376 | # target: sparse observations
377 | if self.params.enable_amp:
378 | self.gscaler.scale(loss["loss_obs"]).backward()
379 | self.gscaler.step(self.optimizer)
380 | else:
381 | loss["loss_obs"].backward()
382 | self.optimizer.step()
383 | if self.params.target == "analysis":
384 |                 # target: gridded fields
385 | if self.params.enable_amp:
386 | self.gscaler.scale(loss["loss_field"]).backward()
387 | self.gscaler.step(self.optimizer)
388 | else:
389 | loss["loss_field"].backward()
390 | self.optimizer.step()
391 | if self.params.target == "analysis_obs":
392 |                 # target: gridded fields + sparse observations
393 | if self.params.enable_amp:
394 | self.gscaler.scale(loss["loss_field_obs"]).backward()
395 | self.gscaler.step(self.optimizer)
396 | else:
397 | loss["loss_field_obs"].backward()
398 | self.optimizer.step()
399 |
400 | if self.params.enable_amp:
401 | self.gscaler.update()
402 |
403 | tr_time += time.time() - tr_start
404 |
405 | logs = {
406 | "loss_field": loss_field / steps_in_one_epoch,
407 | "loss_obs": loss_obs / steps_in_one_epoch,
408 | "loss_field_obs": loss_field_obs / steps_in_one_epoch,
409 | }
410 | for i_, var_ in enumerate(self.params.target_vars):
411 | tmp_var_1 = loss_obs_channel_wise[i_] / steps_in_one_epoch
412 | tmp_var_2 = loss_field_channel_wise[i_] / steps_in_one_epoch
413 | logs[f"loss_obs_{var_}"] = tmp_var_1
414 | logs[f"loss_field_{var_}"] = tmp_var_2
415 |
416 | if dist.is_initialized():
417 | for key in sorted(logs.keys()):
418 | dist.all_reduce(logs[key].detach())
419 | logs[key] = float(logs[key] / dist.get_world_size())
420 |
421 | if self.params.log_to_wandb:
422 | wandb.log(logs, step=self.epoch)
423 |
424 | # time of one step in epoch
425 | step_time = tr_time / steps_in_one_epoch
426 |
427 | return tr_time, data_time, step_time, logs
428 |
429 | def validate_one_epoch(self):
430 | print("validating...")
431 | self.model.eval()
432 |
433 | valid_buff = torch.zeros((4), dtype=torch.float32, device=self.device)
434 | valid_loss_field = valid_buff[0].view(-1)
435 | valid_loss_obs = valid_buff[1].view(-1)
436 | valid_loss_field_obs = valid_buff[2].view(-1)
437 | valid_steps = valid_buff[3].view(-1)
438 |
439 | valid_start = time.time()
440 | with torch.no_grad():
441 | for i, data in enumerate(self.valid_data_loader, 0):
442 | self.plot = False
443 | self.plot_img_path = False
444 |
445 | if self.params.nettype == "EncDec_two_encoder":
446 | (
447 | inp,
448 | inp_sate,
449 | target_field,
450 | target_obs,
451 | target_field_obs,
452 | inp_hrrr,
453 | _,
454 | _,
455 | field_mask,
456 | obs_tar_mask,
457 | ) = data
458 | if self.params.nettype == "EncDec":
459 | (
460 | inp,
461 | target_field,
462 | target_obs,
463 | target_field_obs,
464 | inp_hrrr,
465 | _,
466 | _,
467 | field_mask,
468 | obs_tar_mask,
469 | ) = data
470 |
471 | inp = inp.to(
472 | self.device, dtype=torch.float)
473 | inp_hrrr = inp_hrrr.to(
474 | self.device, dtype=torch.float)
475 | target_field = target_field.to(
476 | self.device, dtype=torch.float)
477 | target_obs = target_obs.to(
478 | self.device, dtype=torch.float)
479 | target_field_obs = target_field_obs.to(
480 | self.device, dtype=torch.float)
481 | field_mask = field_mask.to(
482 | self.device, dtype=torch.bool)
483 | obs_tar_mask = obs_tar_mask.to(
484 | self.device, dtype=torch.bool)
485 |
486 | if self.params.nettype == "EncDec":
487 | gen = self.model(inp)
488 | if self.params.nettype == "EncDec_two_encoder":
489 | inp_sate = inp_sate.to(
490 | self.device, dtype=torch.float)
491 | gen = self.model(inp, inp_sate)
492 |                 gen = gen.to(self.device, dtype=torch.float)
493 |
494 | loss = self.loss_function(
495 | pre_field=gen,
496 | tar_field=target_field,
497 | tar_obs=target_obs,
498 | tar_field_obs=target_field_obs,
499 | field_mask=field_mask,
500 | obs_tar_mask=obs_tar_mask,
501 | )
502 |
503 | valid_steps += 1.0
504 | valid_loss_field += loss["loss_field"]
505 | valid_loss_obs += loss["loss_obs"]
506 | valid_loss_field_obs += loss["loss_field_obs"]
507 |
508 | if dist.is_initialized():
509 | dist.all_reduce(valid_buff)
510 |
511 | # divide by number of steps
512 | valid_buff[0:3] = valid_buff[0:3] / valid_buff[3]
513 | valid_buff_cpu = valid_buff.detach().cpu().numpy()
514 | logs = {
515 | "valid_loss_field": valid_buff_cpu[0],
516 | "valid_loss_obs": valid_buff_cpu[1],
517 | "valid_loss_field_obs": valid_buff_cpu[2],
518 | }
519 |
520 | valid_time = time.time() - valid_start
521 |
522 | if self.params.log_to_wandb:
523 | wandb.log(logs, step=self.epoch)
524 |
525 | return valid_time, logs
526 |
527 | def load_model(self, model_path):
528 | if self.params.log_to_screen:
529 | print("Loading the model weights from {}".format(model_path))
530 |
531 | checkpoint = torch.load(
532 | model_path, map_location="cuda:{}".format(self.params.local_rank)
533 | )
534 |
535 | if dist.is_initialized():
536 | self.model.load_state_dict(checkpoint["model_state"])
537 | else:
538 | new_model_state = OrderedDict()
539 | if "model_state" in checkpoint:
540 | model_key = "model_state"
541 | else:
542 | model_key = "state_dict"
543 |
544 | for key in checkpoint[model_key].keys():
545 | if "module." in key:
546 | # model was stored using ddp which prepends module
547 | name = str(key[7:])
548 | new_model_state[name] = checkpoint[model_key][key]
549 | else:
550 | new_model_state[key] = checkpoint[model_key][key]
551 | self.model.load_state_dict(new_model_state)
552 | self.model.eval()
553 |
554 | def save_checkpoint(self, checkpoint_path, model=None):
555 |         """We intentionally require a checkpoint_path to be passed
556 | in order to allow Ray Tune to use this function"""
557 |
558 | if not model:
559 | model = self.model
560 |
561 | print("Saving model to {}".format(checkpoint_path))
562 | torch.save(
563 | {
564 | "iters": self.iters,
565 | "epoch": self.epoch,
566 | "model_state": model.state_dict(),
567 | "optimizer_state_dict": self.optimizer.state_dict(),
568 | },
569 | checkpoint_path,
570 | )
571 |
572 | def restore_checkpoint(self, checkpoint_path):
573 | checkpoint = torch.load(
574 | checkpoint_path,
575 | map_location="cuda:{}".format(self.params.local_rank)
576 | )
577 | try:
578 | self.model.load_state_dict(checkpoint["model_state"])
579 |         except (RuntimeError, ValueError):  # key mismatch from DDP prefix
580 | new_state_dict = OrderedDict()
581 | for key, val in checkpoint["model_state"].items():
582 | name = key[7:]
583 | new_state_dict[name] = val
584 | self.model.load_state_dict(new_state_dict)
585 | self.iters = checkpoint["iters"]
586 | self.startEpoch = checkpoint["epoch"]
587 | self.resumeEpoch = 0
588 | if self.params.resuming:
589 | # restore checkpoint is used for finetuning as well as resuming.
590 | self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
591 | # uses config specified lr.
592 | for g in self.optimizer.param_groups:
593 | g["lr"] = self.params.lr
594 |
595 |
596 | def set_random_seed(seed):
597 | random.seed(seed)
598 | np.random.seed(seed)
599 | torch.manual_seed(seed)
600 | torch.cuda.manual_seed_all(seed)
601 |
602 |
603 | if __name__ == "__main__":
604 | parser = argparse.ArgumentParser()
605 | parser.add_argument(
606 | "--yaml_config",
607 | default="./config/experiment.yaml",
608 | type=str,
609 | )
610 | parser.add_argument("--exp_dir", default="./exp_us_t2m", type=str)
611 | parser.add_argument("--run_num", default="00", type=str)
612 | parser.add_argument("--resume", default=False, type=str2bool)
613 | parser.add_argument("--device", default="GPU", type=str)
614 | parser.add_argument("--seed", default=42, type=int)
615 | parser.add_argument("--max_epochs", default=1200, type=int)
616 | parser.add_argument("--lr", default=0.001, type=float)
617 | parser.add_argument("--lr_reduce_factor", default=0.9, type=float)
618 | parser.add_argument("--target", default="obs", type=str)
619 | parser.add_argument("--hold_out_obs_ratio", default=0.1, type=float)
620 | parser.add_argument("--obs_mask_seed", default=1, type=int)
621 | parser.add_argument("--wandb_api_key", type=str)
622 | parser.add_argument("--batch_size", default=8, type=int)
623 | parser.add_argument("--wandb_group", default="us_t2m", type=str)
624 | parser.add_argument("--net_config", default="VAE-AFNO", type=str)
625 | parser.add_argument("--enable_amp", action="store_true")
626 | parser.add_argument("--epsilon_factor", default=0, type=float)
627 | parser.add_argument("--local-rank", default=-1, type=int)
628 | args = parser.parse_args()
629 |
630 | os.environ["WANDB_API_KEY"] = args.wandb_api_key
631 | os.environ["WANDB_MODE"] = "online"
632 |
633 | if args.resume:
634 | params = YParams(
635 | os.path.join(
636 | args.exp_dir,
637 | args.net_config,
638 | args.run_num,
639 | "config.yaml"),
640 | args.net_config,
641 | False,
642 | )
643 | else:
644 | params = YParams(
645 | os.path.abspath(args.yaml_config),
646 | args.net_config,
647 | False)
648 |
649 | params["target"] = args.target
650 | params["hold_out_obs_ratio"] = args.hold_out_obs_ratio
651 | params["obs_mask_seed"] = args.obs_mask_seed
652 | params["lr_reduce_factor"] = args.lr_reduce_factor
653 | params["max_epochs"] = args.max_epochs
654 | params["world_size"] = 1
655 | params["lr"] = args.lr
656 |
657 | if "WORLD_SIZE" in os.environ:
658 | params["world_size"] = int(os.environ["WORLD_SIZE"])
659 | print("world_size :", params["world_size"])
660 |
661 | if args.device == "GPU":
662 | print("Initialize distributed process group...")
663 | torch.distributed.init_process_group(
664 | backend="nccl",
665 | timeout=datetime.timedelta(seconds=5400)
666 | )
667 | local_rank = int(os.environ["LOCAL_RANK"])
668 | torch.cuda.set_device(local_rank)
669 |
670 | # device = torch.device('cuda', args.local_rank)
671 | params["local_rank"] = local_rank
672 | torch.backends.cudnn.benchmark = True
673 |
674 | world_rank = dist.get_rank() # get current process's ID
675 | print(f"world_rank: {world_rank}")
676 |
677 | set_random_seed(args.seed)
678 | params["nettype"] = args.net_config
679 | params["global_batch_size"] = args.batch_size
680 | params["batch_size"] = int(
681 | args.batch_size // params["world_size"]
682 |     )  # batch size must be divisible by the number of GPUs
683 | # Automatic Mixed Precision Training
684 | params["enable_amp"] = args.enable_amp
685 |
686 | # Set up directory
687 | expDir = os.path.join(
688 | args.exp_dir,
689 | args.net_config,
690 | str(args.run_num))
691 |
692 | # start training
693 | if (not args.resume) and (
694 | (world_rank == 0 and args.device == "GPU") or args.device == "CPU"
695 | ):
696 | os.makedirs(expDir, exist_ok=True)
697 | os.makedirs(
698 | os.path.join(expDir, "training_checkpoints"),
699 | exist_ok=True)
700 | copyfile(
701 | os.path.abspath(args.yaml_config),
702 | os.path.join(expDir, "config.yaml"))
703 |
704 | params["experiment_dir"] = os.path.abspath(expDir)
705 | params["checkpoint_path"] = os.path.join(
706 | expDir, "training_checkpoints", "ckpt.tar")
707 | params["best_checkpoint_path"] = os.path.join(
708 | expDir, "training_checkpoints", "best_ckpt.tar")
709 |
710 | # Do not comment this line out please:
711 | args.resuming = True if os.path.isfile(params.checkpoint_path) else False
712 | params["resuming"] = args.resuming
713 |
714 | # experiment name
715 | params["name"] = str(args.run_num)
716 |
717 | # wandb setting
718 | params["entity"] = "your entity" # team name
719 | params["project"] = "your project" # project name
720 | params["group"] = args.wandb_group + "_" + args.net_config
721 |
722 | # if world_rank == 0:
723 | log_to_file(
724 | logger_name=None,
725 | log_filename=os.path.join(expDir, "train.log"))
726 | params.log()
727 |
728 | params["log_to_wandb"] = (world_rank == 0) and params["log_to_wandb"]
729 | params["log_to_screen"] = (world_rank == 0) and params["log_to_screen"]
730 |
731 | if world_rank == 0:
732 | hparams = ruamelDict()
733 | yaml = YAML()
734 | for key, value in params.params.items():
735 | hparams[str(key)] = str(value)
736 | with open(os.path.join(expDir, "hyperparams.yaml"), "w") as hpfile:
737 | yaml.dump(hparams, hpfile)
738 |
739 | trainer = Trainer(params, world_rank)
740 | trainer.train()
741 | print("DONE ---- rank %d" % world_rank)
742 |
--------------------------------------------------------------------------------
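
The three loss terms combined in Trainer.loss_function above are mean-squared-error terms computed on masked tensors: masked_fill zeroes both prediction and target outside the region of interest before F.mse_loss is applied. A minimal sketch of that pattern on random tensors; the shapes and mask densities below are purely illustrative, not values from this repository.

import torch
import torch.nn.functional as F

B, C, H, W = 2, 4, 16, 16                      # illustrative sizes
pre_field = torch.randn(B, C, H, W)            # model output
tar_field = torch.randn(B, C, H, W)            # gridded analysis target
tar_obs = torch.randn(B, C, H, W)              # sparse observation target
field_mask = torch.rand(B, C, H, W) > 0.1      # True where the field is usable
obs_tar_mask = torch.rand(B, C, H, W) > 0.9    # True where a station exists

# field loss: zero both tensors where field_mask is False, then MSE
pre_f = pre_field.masked_fill(~field_mask, 0)
tar_f = tar_field.masked_fill(~field_mask, 0)
loss_field = F.mse_loss(pre_f, tar_f)
# per-variable version, averaged over batch and spatial dimensions only
loss_field_channel_wise = F.mse_loss(
    pre_f, tar_f, reduction="none").mean(dim=(0, 2, 3))

# observation loss: compare only at station locations
pre_o = pre_field.masked_fill(~obs_tar_mask, 0)
tar_o = tar_obs.masked_fill(~obs_tar_mask, 0)
loss_obs = F.mse_loss(pre_o, tar_o)
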
/models/encdec.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torch.utils.checkpoint as checkpoint
6 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_
7 |
8 |
9 | class Mlp(nn.Module):
10 | def __init__(
11 | self,
12 | in_features,
13 | hidden_features=None,
14 | out_features=None,
15 | act_layer=nn.GELU,
16 | drop=0.0,
17 | ):
18 | super().__init__()
19 | out_features = out_features or in_features
20 | hidden_features = hidden_features or in_features
21 | self.fc1 = nn.Linear(in_features, hidden_features)
22 | self.act = act_layer()
23 | self.fc2 = nn.Linear(hidden_features, out_features)
24 | self.drop = nn.Dropout(drop)
25 |
26 | def forward(self, x):
27 | x = self.fc1(x)
28 | x = self.act(x)
29 | x = self.drop(x)
30 | x = self.fc2(x)
31 | x = self.drop(x)
32 | return x
33 |
34 |
35 | def window_partition(x, window_size):
36 | """
37 | Args:
38 | x: (B, H, W, C)
39 | window_size (int): window size
40 |
41 | Returns:
42 | windows: (num_windows*B, window_size, window_size, C)
43 | """
44 | B, H, W, C = x.shape
45 | # ic(x.shape)
46 | x = x.view(
47 | B, H // window_size, window_size,
48 | W // window_size, window_size, C)
49 | windows = (
50 | x.permute(0, 1, 3, 2, 4, 5).contiguous().view(
51 | -1, window_size, window_size, C)
52 | )
53 | return windows
54 |
55 |
56 | def window_reverse(windows, window_size, H, W):
57 | """
58 | Args:
59 | windows: (num_windows*B, window_size, window_size, C)
60 | window_size (int): Window size
61 | H (int): Height of image
62 | W (int): Width of image
63 |
64 | Returns:
65 | x: (B, H, W, C)
66 | """
67 | B = int(windows.shape[0] / (H * W / window_size / window_size))
68 | x = windows.view(
69 | B, H // window_size,
70 | W // window_size, window_size,
71 | window_size, -1
72 | )
73 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
74 | return x
75 |
76 |
77 | class WindowAttention(nn.Module):
78 | r"""Window based multi-head self attention (W-MSA) module
79 | with relative position bias.
80 | It supports both of shifted and non-shifted window.
81 |
82 | Args:
83 | dim (int): Number of input channels.
84 | window_size (tuple[int]):
85 | The height and width of the window.
86 | num_heads (int):
87 | Number of attention heads.
88 | qkv_bias (bool, optional):
89 | If True, add a learnable bias to query, key, value.
90 | Default: True
91 | qk_scale (float | None, optional):
92 | Override default qk scale of head_dim ** -0.5 if set
93 | attn_drop (float, optional):
94 | Dropout ratio of attention weight. Default: 0.0
95 | proj_drop (float, optional):
96 | Dropout ratio of output. Default: 0.0
97 | """
98 |
99 | def __init__(
100 | self,
101 | dim,
102 | window_size,
103 | num_heads,
104 | qkv_bias=True,
105 | qk_scale=None,
106 | attn_drop=0.0,
107 | proj_drop=0.0,
108 | ):
109 |
110 | super().__init__()
111 | self.dim = dim
112 | self.window_size = window_size # Wh, Ww
113 | self.num_heads = num_heads
114 | head_dim = dim // num_heads
115 | self.scale = qk_scale or head_dim**-0.5
116 |
117 | # define a parameter table of relative position bias
118 | self.relative_position_bias_table = nn.Parameter(
119 | torch.zeros(
120 | (2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
121 | ) # 2*Wh-1 * 2*Ww-1, nH
122 |
123 | # get pair-wise relative position index
124 | # for each token inside the window
125 | coords_h = torch.arange(self.window_size[0])
126 | coords_w = torch.arange(self.window_size[1])
127 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
128 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
129 | relative_coords = (
130 | coords_flatten[:, :, None] - coords_flatten[:, None, :]
131 | ) # 2, Wh*Ww, Wh*Ww
132 | relative_coords = relative_coords.permute(
133 | 1, 2, 0
134 | ).contiguous() # Wh*Ww, Wh*Ww, 2
135 | # shift to start from 0
136 | relative_coords[:, :, 0] += self.window_size[0] - 1
137 | relative_coords[:, :, 1] += self.window_size[1] - 1
138 | relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
139 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
140 | self.register_buffer(
141 | "relative_position_index", relative_position_index)
142 |
143 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
144 | self.attn_drop = nn.Dropout(attn_drop)
145 | self.proj = nn.Linear(dim, dim)
146 |
147 | self.proj_drop = nn.Dropout(proj_drop)
148 |
149 | trunc_normal_(self.relative_position_bias_table, std=0.02)
150 | self.softmax = nn.Softmax(dim=-1)
151 |
152 | def forward(self, x, mask=None):
153 | """
154 | Args:
155 | x: input features with shape of (num_windows*B, N, C)
156 | mask: (0/-inf) mask with shape of
157 | (num_windows, Wh*Ww, Wh*Ww) or None
158 | """
159 | B_, N, C = x.shape
160 | qkv = (
161 | self.qkv(x)
162 | .reshape(B_, N, 3, self.num_heads, C // self.num_heads)
163 | .permute(2, 0, 3, 1, 4)
164 | )
165 | q, k, v = (
166 | qkv[0],
167 | qkv[1],
168 | qkv[2],
169 | ) # make torchscript happy (cannot use tensor as tuple)
170 |
171 | q = q * self.scale
172 | attn = q @ k.transpose(-2, -1)
173 |
174 | relative_position_bias = self.relative_position_bias_table[
175 | self.relative_position_index.view(-1)
176 | ].view(
177 | self.window_size[0] * self.window_size[1],
178 | self.window_size[0] * self.window_size[1],
179 | -1,
180 | ) # Wh*Ww,Wh*Ww,nH
181 | relative_position_bias = relative_position_bias.permute(
182 | 2, 0, 1
183 | ).contiguous() # nH, Wh*Ww, Wh*Ww
184 | attn = attn + relative_position_bias.unsqueeze(0)
185 |
186 | if mask is not None:
187 | nW = mask.shape[0]
188 | attn = attn.view(
189 | B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(
190 | 1
191 | ).unsqueeze(0)
192 | attn = attn.view(-1, self.num_heads, N, N)
193 | attn = self.softmax(attn)
194 | else:
195 | attn = self.softmax(attn)
196 |
197 | attn = self.attn_drop(attn)
198 |
199 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
200 | x = self.proj(x)
201 | x = self.proj_drop(x)
202 | return x
203 |
204 | def extra_repr(self) -> str:
205 | return f"dim={self.dim}, \
206 | window_size={self.window_size}, \
207 | num_heads={self.num_heads}"
208 |
209 | def flops(self, N):
210 | # calculate flops for 1 window with token length of N
211 | flops = 0
212 | # qkv = self.qkv(x)
213 | flops += N * self.dim * 3 * self.dim
214 | # attn = (q @ k.transpose(-2, -1))
215 | flops += self.num_heads * N * (self.dim // self.num_heads) * N
216 | # x = (attn @ v)
217 | flops += self.num_heads * N * N * (self.dim // self.num_heads)
218 | # x = self.proj(x)
219 | flops += N * self.dim * self.dim
220 | return flops
221 |
222 |
223 | class SwinTransformerBlock(nn.Module):
224 | r"""Swin Transformer Block.
225 |
226 | Args:
227 | dim (int): Number of input channels.
228 | input_resolution (tuple[int]):
229 |             Input resolution.
230 | num_heads (int):
231 | Number of attention heads.
232 | window_size (int):
233 | Window size.
234 | shift_size (int):
235 | Shift size for SW-MSA.
236 | mlp_ratio (float):
237 | Ratio of mlp hidden dim to embedding dim.
238 | qkv_bias (bool, optional):
239 | If True, add a learnable bias to query, key, value.
240 | Default: True
241 | qk_scale (float | None, optional):
242 | Override default qk scale of head_dim ** -0.5 if set.
243 | drop (float, optional):
244 | Dropout rate.
245 | Default: 0.0
246 | attn_drop (float, optional):
247 | Attention dropout rate.
248 | Default: 0.0
249 | drop_path (float, optional):
250 | Stochastic depth rate.
251 | Default: 0.0
252 | act_layer (nn.Module, optional):
253 | Activation layer.
254 | Default: nn.GELU
255 | norm_layer (nn.Module, optional):
256 | Normalization layer.
257 | Default: nn.LayerNorm
258 | """
259 |
260 | def __init__(
261 | self,
262 | dim,
263 | input_resolution,
264 | num_heads,
265 | window_size=7,
266 | shift_size=0,
267 | mlp_ratio=4.0,
268 | qkv_bias=True,
269 | qk_scale=None,
270 | drop=0.0,
271 | attn_drop=0.0,
272 | drop_path=0.0,
273 | act_layer=nn.GELU,
274 | norm_layer=nn.LayerNorm,
275 | ):
276 | super().__init__()
277 | self.dim = dim
278 | self.input_resolution = input_resolution
279 | self.num_heads = num_heads
280 | self.window_size = window_size
281 | self.shift_size = shift_size
282 | self.mlp_ratio = mlp_ratio
283 | if min(self.input_resolution) <= self.window_size:
284 | # if window size is larger than input resolution,
285 | # we don't partition windows
286 | self.shift_size = 0
287 | self.window_size = min(self.input_resolution)
288 | assert (
289 | 0 <= self.shift_size < self.window_size
290 |         ), "shift_size must be in [0, window_size)"
291 |
292 | self.norm1 = norm_layer(dim)
293 | self.attn = WindowAttention(
294 | dim,
295 | window_size=to_2tuple(self.window_size),
296 | num_heads=num_heads,
297 | qkv_bias=qkv_bias,
298 | qk_scale=qk_scale,
299 | attn_drop=attn_drop,
300 | proj_drop=drop,
301 | )
302 |
303 | self.drop_path = DropPath(
304 | drop_path) if drop_path > 0.0 else nn.Identity()
305 | self.norm2 = norm_layer(dim)
306 | mlp_hidden_dim = int(dim * mlp_ratio)
307 | self.mlp = Mlp(
308 | in_features=dim,
309 | hidden_features=mlp_hidden_dim,
310 | act_layer=act_layer,
311 | drop=drop,
312 | )
313 |
314 | if self.shift_size > 0:
315 | attn_mask = self.calculate_mask(self.input_resolution)
316 | else:
317 | attn_mask = None
318 |
319 | self.register_buffer("attn_mask", attn_mask)
320 |
321 | def calculate_mask(self, x_size):
322 | # calculate attention mask for SW-MSA
323 | H, W = x_size
324 | img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
325 | h_slices = (
326 | slice(0, -self.window_size),
327 | slice(-self.window_size, -self.shift_size),
328 | slice(-self.shift_size, None),
329 | )
330 | w_slices = (
331 | slice(0, -self.window_size),
332 | slice(-self.window_size, -self.shift_size),
333 | slice(-self.shift_size, None),
334 | )
335 | cnt = 0
336 | for h in h_slices:
337 | for w in w_slices:
338 | img_mask[:, h, w, :] = cnt
339 | cnt += 1
340 |
341 | mask_windows = window_partition(
342 | img_mask, self.window_size
343 | ) # nW, window_size, window_size, 1
344 | mask_windows = mask_windows.view(
345 | -1, self.window_size * self.window_size)
346 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
347 | attn_mask = attn_mask.masked_fill(
348 | attn_mask != 0, float(-100.0)).masked_fill(
349 | attn_mask == 0, float(0.0)
350 | )
351 |
352 | return attn_mask
353 |
354 | def forward(self, x, x_size):
355 | H, W = x_size
356 | B, L, C = x.shape
357 | # assert L == H * W, "input feature has wrong size"
358 |
359 | shortcut = x
360 | x = self.norm1(x)
361 | x = x.view(B, H, W, C)
362 |
363 | # cyclic shift
364 | if self.shift_size > 0:
365 | shifted_x = torch.roll(
366 | x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)
367 | )
368 | else:
369 | shifted_x = x
370 |
371 | # partition windows
372 | x_windows = window_partition(
373 | shifted_x, self.window_size
374 | ) # nW*B, window_size, window_size, C
375 | x_windows = x_windows.view(
376 | -1, self.window_size * self.window_size, C
377 | ) # nW*B, window_size*window_size, C
378 |
379 |         # W-MSA/SW-MSA (recompute the attention mask when testing on
380 |         # images whose size differs from the training input resolution)
381 | if self.input_resolution == x_size:
382 | attn_windows = self.attn(
383 | x_windows, mask=self.attn_mask
384 | ) # nW*B, window_size*window_size, C
385 | else:
386 | attn_windows = self.attn(
387 | x_windows, mask=self.calculate_mask(x_size).to(x.device)
388 | )
389 |
390 | # merge windows
391 | attn_windows = attn_windows.view(
392 | -1, self.window_size, self.window_size, C)
393 | shifted_x = window_reverse(
394 | attn_windows, self.window_size, H, W) # B H' W' C
395 |
396 | # reverse cyclic shift
397 | if self.shift_size > 0:
398 | x = torch.roll(
399 | shifted_x,
400 | shifts=(self.shift_size, self.shift_size),
401 | dims=(1, 2)
402 | )
403 | else:
404 | x = shifted_x
405 | x = x.view(B, H * W, C)
406 |
407 | # FFN
408 | x = shortcut + self.drop_path(x)
409 | x = x + self.drop_path(self.mlp(self.norm2(x)))
410 |
411 | return x
412 |
413 | def extra_repr(self) -> str:
414 | return (
415 | f"dim={self.dim}, \
416 | input_resolution={self.input_resolution}, \
417 | num_heads={self.num_heads}, "
418 | f"window_size={self.window_size}, \
419 | shift_size={self.shift_size}, \
420 | mlp_ratio={self.mlp_ratio}"
421 | )
422 |
423 | def flops(self):
424 | flops = 0
425 | H, W = self.input_resolution
426 | # norm1
427 | flops += self.dim * H * W
428 | # W-MSA/SW-MSA
429 | nW = H * W / self.window_size / self.window_size
430 | flops += nW * self.attn.flops(self.window_size * self.window_size)
431 | # mlp
432 | flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
433 | # norm2
434 | flops += self.dim * H * W
435 | return flops
436 |
437 |
438 | class PatchMerging(nn.Module):
439 | r"""Patch Merging Layer.
440 |
441 | Args:
442 | input_resolution (tuple[int]):
443 | Resolution of input feature.
444 | dim (int):
445 | Number of input channels.
446 | norm_layer (nn.Module, optional):
447 | Normalization layer.
448 | Default: nn.LayerNorm
449 | """
450 |
451 | def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
452 | super().__init__()
453 | self.input_resolution = input_resolution
454 | self.dim = dim
455 | self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
456 | self.norm = norm_layer(4 * dim)
457 |
458 | def forward(self, x):
459 | """
460 | x: B, H*W, C
461 | """
462 | H, W = self.input_resolution
463 | B, L, C = x.shape
464 | assert L == H * W, "input feature has wrong size"
465 |         assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) is not even."
466 |
467 | x = x.view(B, H, W, C)
468 |
469 | x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
470 | x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
471 | x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
472 | x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
473 | x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
474 | x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
475 |
476 | x = self.norm(x)
477 | x = self.reduction(x)
478 |
479 | return x
480 |
481 | def extra_repr(self) -> str:
482 | return f"input_resolution={self.input_resolution}, dim={self.dim}"
483 |
484 | def flops(self):
485 | H, W = self.input_resolution
486 | flops = H * W * self.dim
487 | flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
488 | return flops
489 |
490 |
491 | class BasicLayer(nn.Module):
492 | """A basic Swin Transformer layer for one stage.
493 |
494 | Args:
495 | dim (int): Number of input channels.
496 | input_resolution (tuple[int]): Input resolution.
497 | depth (int): Number of blocks.
498 | num_heads (int):
499 | Number of attention heads.
500 | window_size (int):
501 | Local window size.
502 | mlp_ratio (float):
503 | Ratio of mlp hidden dim to embedding dim.
504 | qkv_bias (bool, optional):
505 | If True, add a learnable bias to query, key, value.
506 | Default: True
507 | qk_scale (float | None, optional):
508 | Override default qk scale of head_dim ** -0.5 if set.
509 | drop (float, optional):
510 | Dropout rate. Default: 0.0
511 | attn_drop (float, optional):
512 | Attention dropout rate. Default: 0.0
513 | drop_path (float | tuple[float], optional):
514 | Stochastic depth rate. Default: 0.0
515 | norm_layer (nn.Module, optional):
516 | Normalization layer. Default: nn.LayerNorm
517 | downsample (nn.Module | None, optional):
518 | Downsample layer at the end of the layer.
519 | Default: None
520 | use_checkpoint (bool):
521 | Whether to use checkpointing to save memory.
522 | Default: False.
523 | """
524 |
525 | def __init__(
526 | self,
527 | dim,
528 | input_resolution,
529 | depth,
530 | num_heads,
531 | window_size,
532 | mlp_ratio=4.0,
533 | qkv_bias=True,
534 | qk_scale=None,
535 | drop=0.0,
536 | attn_drop=0.0,
537 | drop_path=0.0,
538 | norm_layer=nn.LayerNorm,
539 | downsample=None,
540 | use_checkpoint=False,
541 | ):
542 |
543 | super().__init__()
544 | self.dim = dim
545 | self.input_resolution = input_resolution
546 | self.depth = depth
547 | self.use_checkpoint = use_checkpoint
548 |
549 | # build blocks
550 | self.blocks = nn.ModuleList(
551 | [
552 | SwinTransformerBlock(
553 | dim=dim,
554 | input_resolution=input_resolution,
555 | num_heads=num_heads,
556 | window_size=window_size,
557 | shift_size=0 if (i % 2 == 0) else window_size // 2,
558 | mlp_ratio=mlp_ratio,
559 | qkv_bias=qkv_bias,
560 | qk_scale=qk_scale,
561 | drop=drop,
562 | attn_drop=attn_drop,
563 | drop_path=(
564 | drop_path[i] if isinstance(
565 | drop_path, list) else drop_path
566 | ),
567 | norm_layer=norm_layer,
568 | )
569 | for i in range(depth)
570 | ]
571 | )
572 |
573 | # patch merging layer
574 | if downsample is not None:
575 | self.downsample = downsample(
576 | input_resolution, dim=dim, norm_layer=norm_layer
577 | )
578 | else:
579 | self.downsample = None
580 |
581 | def forward(self, x, x_size):
582 | for blk in self.blocks:
583 | if self.use_checkpoint:
584 | x = checkpoint.checkpoint(blk, x, x_size)
585 | else:
586 | x = blk(x, x_size)
587 | if self.downsample is not None:
588 | x = self.downsample(x)
589 | return x
590 |
591 | def extra_repr(self) -> str:
592 | return f"dim={self.dim}, \
593 | input_resolution={self.input_resolution}, \
594 | depth={self.depth}"
595 |
596 | def flops(self):
597 | flops = 0
598 | for blk in self.blocks:
599 | flops += blk.flops()
600 | if self.downsample is not None:
601 | flops += self.downsample.flops()
602 | return flops
603 |
604 |
605 | class RSTB(nn.Module):
606 | """Residual Swin Transformer Block (RSTB).
607 |
608 | Args:
609 | dim (int): Number of input channels.
610 | input_resolution (tuple[int]): Input resolution.
611 | depth (int):
612 | Number of blocks.
613 | num_heads (int):
614 | Number of attention heads.
615 | window_size (int):
616 | Local window size.
617 | mlp_ratio (float):
618 | Ratio of mlp hidden dim to embedding dim.
619 | qkv_bias (bool, optional):
620 | If True, add a learnable bias to query, key, value.
621 | Default: True
622 | qk_scale (float | None, optional):
623 | Override default qk scale of head_dim ** -0.5 if set.
624 | drop (float, optional):
625 | Dropout rate.
626 | Default: 0.0
627 | attn_drop (float, optional):
628 | Attention dropout rate.
629 | Default: 0.0
630 | drop_path (float | tuple[float], optional):
631 | Stochastic depth rate.
632 | Default: 0.0
633 | norm_layer (nn.Module, optional):
634 | Normalization layer.
635 | Default: nn.LayerNorm
636 | downsample (nn.Module | None, optional):
637 | Downsample layer at the end of the layer.
638 | Default: None
639 | use_checkpoint (bool):
640 | Whether to use checkpointing to save memory.
641 | Default: False.
642 | img_size:
643 | Input image size.
644 | patch_size:
645 | Patch size.
646 | resi_connection:
647 | The convolutional block before residual connection.
648 | """
649 |
650 | def __init__(
651 | self,
652 | dim,
653 | input_resolution,
654 | depth,
655 | num_heads,
656 | window_size,
657 | mlp_ratio=4.0,
658 | qkv_bias=True,
659 | qk_scale=None,
660 | drop=0.0,
661 | attn_drop=0.0,
662 | drop_path=0.0,
663 | norm_layer=nn.LayerNorm,
664 | downsample=None,
665 | use_checkpoint=False,
666 | img_size=224,
667 | patch_size=4,
668 | resi_connection="1conv",
669 | ):
670 | super(RSTB, self).__init__()
671 |
672 | self.dim = dim
673 | self.input_resolution = input_resolution
674 |
675 | self.residual_group = BasicLayer(
676 | dim=dim,
677 | input_resolution=input_resolution,
678 | depth=depth,
679 | num_heads=num_heads,
680 | window_size=window_size,
681 | mlp_ratio=mlp_ratio,
682 | qkv_bias=qkv_bias,
683 | qk_scale=qk_scale,
684 | drop=drop,
685 | attn_drop=attn_drop,
686 | drop_path=drop_path,
687 | norm_layer=norm_layer,
688 | downsample=downsample,
689 | use_checkpoint=use_checkpoint,
690 | )
691 |
692 | if resi_connection == "1conv":
693 | self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
694 | elif resi_connection == "3conv":
695 | # to save parameters and memory
696 | self.conv = nn.Sequential(
697 | nn.Conv2d(dim, dim // 4, 3, 1, 1),
698 | nn.LeakyReLU(negative_slope=0.2, inplace=True),
699 | nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
700 | nn.LeakyReLU(negative_slope=0.2, inplace=True),
701 | nn.Conv2d(dim // 4, dim, 3, 1, 1),
702 | )
703 |
704 | self.patch_embed = PatchEmbed(
705 | img_size=img_size,
706 | patch_size=patch_size,
707 | in_chans=0,
708 | embed_dim=dim,
709 | norm_layer=None,
710 | )
711 |
712 | self.patch_unembed = PatchUnEmbed(
713 | img_size=img_size,
714 | patch_size=patch_size,
715 | in_chans=0,
716 | embed_dim=dim,
717 | norm_layer=None,
718 | )
719 |
720 | def forward(self, x, x_size):
721 | return (
722 | self.patch_embed(
723 | self.conv(self.patch_unembed(
724 | self.residual_group(x, x_size),
725 | x_size))
726 | )
727 | + x
728 | )
729 |
730 | def flops(self):
731 | flops = 0
732 | flops += self.residual_group.flops()
733 | H, W = self.input_resolution
734 | flops += H * W * self.dim * self.dim * 9
735 | flops += self.patch_embed.flops()
736 | flops += self.patch_unembed.flops()
737 |
738 | return flops
739 |
740 |
741 | class PatchEmbed(nn.Module):
742 | r"""Image to Patch Embedding
743 |
744 | Args:
745 | img_size (int):
746 | Image size.
747 | Default: 224.
748 | patch_size (int):
749 | Patch token size.
750 | Default: 4.
751 | in_chans (int):
752 | Number of input image channels.
753 | Default: 3.
754 | embed_dim (int):
755 | Number of linear projection output channels.
756 | Default: 96.
757 | norm_layer (nn.Module, optional):
758 | Normalization layer. Default: None
759 | """
760 |
761 | def __init__(
762 | self, img_size=224, patch_size=4,
763 | in_chans=3, embed_dim=96, norm_layer=None
764 | ):
765 | super().__init__()
766 | img_size = to_2tuple(img_size)
767 | patch_size = to_2tuple(patch_size)
768 | patches_resolution = [
769 | img_size[0] // patch_size[0],
770 | img_size[1] // patch_size[1],
771 | ]
772 | self.img_size = img_size
773 | self.patch_size = patch_size
774 | self.patches_resolution = patches_resolution
775 | self.num_patches = patches_resolution[0] * patches_resolution[1]
776 |
777 | self.in_chans = in_chans
778 | self.embed_dim = embed_dim
779 |
780 | if norm_layer is not None:
781 | self.norm = norm_layer(embed_dim)
782 | else:
783 | self.norm = None
784 |
785 | def forward(self, x):
786 | x = x.flatten(2).transpose(1, 2) # B Ph*Pw C
787 | if self.norm is not None:
788 | x = self.norm(x)
789 | return x
790 |
791 | def flops(self):
792 | flops = 0
793 | H, W = self.img_size
794 | if self.norm is not None:
795 | flops += H * W * self.embed_dim
796 | return flops
797 |
798 |
799 | class PatchUnEmbed(nn.Module):
800 | r"""Image to Patch Unembedding
801 |
802 | Args:
803 | img_size (int):
804 | Image size.
805 | Default: 224.
806 | patch_size (int):
807 | Patch token size.
808 | Default: 4.
809 | in_chans (int):
810 | Number of input image channels.
811 | Default: 3.
812 | embed_dim (int):
813 | Number of linear projection output channels.
814 | Default: 96.
815 | norm_layer (nn.Module, optional):
816 | Normalization layer.
817 | Default: None
818 | """
819 |
820 | def __init__(
821 | self, img_size=224,
822 | patch_size=4, in_chans=3,
823 | embed_dim=96, norm_layer=None
824 | ):
825 | super().__init__()
826 | img_size = to_2tuple(img_size)
827 | patch_size = to_2tuple(patch_size)
828 | patches_resolution = [
829 | img_size[0] // patch_size[0],
830 | img_size[1] // patch_size[1],
831 | ]
832 | self.img_size = img_size
833 | self.patch_size = patch_size
834 | self.patches_resolution = patches_resolution
835 | self.num_patches = patches_resolution[0] * patches_resolution[1]
836 |
837 | self.in_chans = in_chans
838 | self.embed_dim = embed_dim
839 |
840 | def forward(self, x, x_size):
841 | B, HW, C = x.shape
842 | # B Ph*Pw C
843 | x = x.transpose(1, 2).view(
844 | B, self.embed_dim, x_size[0], x_size[1])
845 | return x
846 |
847 | def flops(self):
848 | flops = 0
849 | return flops
850 |
851 |
852 | class Upsample(nn.Sequential):
853 | """Upsample module.
854 |
855 | Args:
856 | scale (int): Scale factor. Supported scales: 2^n and 3.
857 | num_feat (int): Channel number of intermediate features.
858 | """
859 |
860 | def __init__(self, scale, num_feat):
861 | m = []
862 | if (scale & (scale - 1)) == 0: # scale = 2^n
863 | for _ in range(int(math.log(scale, 2))):
864 | m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
865 | m.append(nn.PixelShuffle(2))
866 | elif scale == 3:
867 | m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
868 | m.append(nn.PixelShuffle(3))
869 | else:
870 | raise ValueError(
871 | f"scale {scale} is not supported. " +
872 | "Supported scales: 2^n and 3."
873 | )
874 | super(Upsample, self).__init__(*m)
875 |
876 |
877 | class UpsampleOneStep(nn.Sequential):
878 | """UpsampleOneStep module
879 | (the difference with Upsample is that it
880 | always only has 1conv + 1pixelshuffle)
881 | Used in lightweight SR to save parameters.
882 |
883 | Args:
884 | scale (int):
885 | Scale factor. Supported scales: 2^n and 3.
886 | num_feat (int):
887 | Channel number of intermediate features.
888 |
889 | """
890 |
891 | def __init__(
892 | self, scale, num_feat, num_out_ch, input_resolution=None
893 | ):
894 | self.num_feat = num_feat
895 | self.input_resolution = input_resolution
896 | m = []
897 | m.append(nn.Conv2d(
898 | num_feat, (scale**2) * num_out_ch, 3, 1, 1))
899 | m.append(nn.PixelShuffle(scale))
900 | super(UpsampleOneStep, self).__init__(*m)
901 |
902 | def flops(self):
903 | H, W = self.input_resolution
904 | flops = H * W * self.num_feat * 3 * 9
905 | return flops
906 |
907 |
908 | class EncDec(nn.Module):
909 | r"""EncDec
910 |
911 | Args:
912 | img_size (int | tuple(int)):
913 | Input image size.
914 | Default 64
915 | patch_size (int | tuple(int)):
916 | Patch size.
917 | Default: 1
918 | in_chans (int):
919 | Number of input image channels.
920 | Default: 3
921 | embed_dim (int):
922 | Patch embedding dimension.
923 | Default: 96
924 | depths (tuple(int)):
925 | Depth of each Swin Transformer layer.
926 | num_heads (tuple(int)):
927 | Number of attention heads in different layers.
928 | window_size (int):
929 | Window size.
930 | Default: 7
931 | mlp_ratio (float):
932 | Ratio of mlp hidden dim to embedding dim.
933 | Default: 4
934 | qkv_bias (bool):
935 | If True, add a learnable bias to query, key, value.
936 | Default: True
937 | qk_scale (float):
938 | Override default qk scale of head_dim ** -0.5 if set.
939 | Default: None
940 | drop_rate (float):
941 | Dropout rate. Default: 0
942 | attn_drop_rate (float):
943 | Attention dropout rate. Default: 0
944 | drop_path_rate (float):
945 | Stochastic depth rate. Default: 0.1
946 | norm_layer (nn.Module):
947 | Normalization layer.
948 | Default: nn.LayerNorm.
949 | ape (bool):
950 | If True,
951 | add absolute position embedding to the patch embedding.
952 | Default: False
953 | patch_norm (bool):
954 | If True, add normalization after patch embedding.
955 | Default: True
956 | use_checkpoint (bool):
957 | Whether to use checkpointing to save memory.
958 | Default: False
959 | upscale:
960 | Upscale factor.
961 | 2/3/4/8 for image SR,
962 |             1 for denoising and compression artifact reduction
963 | img_range:
964 | Image range. 1. or 255.
965 | resi_connection:
966 | The convolutional block before residual connection.
967 | '1conv'/'3conv'
968 | """
969 |
970 | def __init__(
971 | self,
972 | params,
973 | norm_layer=nn.LayerNorm,
974 | **kwargs,
975 | ):
976 | super(EncDec, self).__init__()
977 | self.img_range = params.img_range
978 | if params.in_chans == 3:
979 | rgb_mean = (0.4488, 0.4371, 0.4040)
980 | self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
981 | else:
982 | self.mean = torch.zeros(1, 1, 1, 1)
983 | self.upscale = params.upscale
984 | self.window_size = params.window_size
985 |
986 | # 1, encoder
987 | self.conv_first = nn.Conv2d(
988 | params.in_chans, params.embed_dim, 3, 1, 1)
989 |
990 | # 2, decoder
991 | self.num_layers = len(params.depths)
992 | self.embed_dim = params.embed_dim
993 | self.ape = params.ape
994 | self.patch_norm = params.patch_norm
995 | self.num_features = params.embed_dim
996 | self.mlp_ratio = params.mlp_ratio
997 |
998 | # split image into non-overlapping patches
999 | self.patch_embed = PatchEmbed(
1000 | img_size=(params.img_size_x, params.img_size_y),
1001 | patch_size=params.patch_size,
1002 | in_chans=params.embed_dim,
1003 | embed_dim=params.embed_dim,
1004 | norm_layer=norm_layer if self.patch_norm else None,
1005 | )
1006 | num_patches = self.patch_embed.num_patches
1007 | patches_resolution = self.patch_embed.patches_resolution
1008 | self.patches_resolution = patches_resolution
1009 |
1010 | # merge non-overlapping patches into image
1011 | self.patch_unembed = PatchUnEmbed(
1012 | img_size=(params.img_size_x, params.img_size_y),
1013 | patch_size=params.patch_size,
1014 | in_chans=params.embed_dim,
1015 | embed_dim=params.embed_dim,
1016 | norm_layer=norm_layer if self.patch_norm else None,
1017 | )
1018 |
1019 | # absolute position embedding
1020 | if self.ape:
1021 | self.absolute_pos_embed = nn.Parameter(
1022 | torch.zeros(1, num_patches, params.embed_dim)
1023 | )
1024 | trunc_normal_(self.absolute_pos_embed, std=0.02)
1025 |
1026 | self.pos_drop = nn.Dropout(p=params.drop_rate)
1027 |
1028 | # stochastic depth
1029 | dpr = [
1030 | x.item()
1031 | for x in torch.linspace(
1032 | 0, params.drop_path_rate, sum(params.depths))
1033 | ] # stochastic depth decay rule
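        # For example, with depths=[3] and drop_path_rate=0.1 (as in the demo
        # config below), torch.linspace gives dpr = [0.0, 0.05, 0.1]: the
        # drop-path probability grows linearly with block depth.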
1034 |
1035 | # build Residual Swin Transformer blocks (RSTB)
1036 | self.layers = nn.ModuleList()
1037 | for i_layer in range(self.num_layers):
1038 | dp_1 = sum(params.depths[:i_layer])
1039 | dp_2 = sum(params.depths[: i_layer + 1])
1040 | layer = RSTB(
1041 | dim=params.embed_dim,
1042 | input_resolution=(
1043 | patches_resolution[0], patches_resolution[1]),
1044 | depth=params.depths[i_layer],
1045 | num_heads=params.num_heads[i_layer],
1046 | window_size=params.window_size,
1047 | mlp_ratio=self.mlp_ratio,
1048 | qkv_bias=params.qkv_bias,
1049 | qk_scale=params.qk_scale,
1050 | drop=params.drop_rate,
1051 | attn_drop=params.attn_drop_rate,
1052 | drop_path=dpr[dp_1:dp_2],
1053 | norm_layer=norm_layer,
1054 | downsample=None,
1055 | use_checkpoint=params.use_checkpoint,
1056 | img_size=(params.img_size_x, params.img_size_y),
1057 | patch_size=params.patch_size,
1058 | resi_connection=params.resi_connection,
1059 | )
1060 | self.layers.append(layer)
1061 | self.norm = norm_layer(self.num_features)
1062 |
1063 | # build the last conv layer in deep feature extraction
1064 | if params.resi_connection == "1conv":
1065 | self.conv_after_body = nn.Conv2d(
1066 | params.embed_dim, params.embed_dim, 3, 1, 1
1067 | )
1068 | elif params.resi_connection == "3conv":
1069 | # to save parameters and memory
1070 | self.conv_after_body = nn.Sequential(
1071 | nn.Conv2d(
1072 | params.embed_dim, params.embed_dim // 4, 3, 1, 1),
1073 | nn.LeakyReLU(
1074 | negative_slope=0.2, inplace=True),
1075 | nn.Conv2d(
1076 | params.embed_dim // 4, params.embed_dim // 4, 1, 1, 0),
1077 | nn.LeakyReLU(
1078 | negative_slope=0.2, inplace=True),
1079 | nn.Conv2d(
1080 | params.embed_dim // 4, params.embed_dim, 3, 1, 1),
1081 | )
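            # Rough parameter count (C = embed_dim): this bottleneck uses
            # 9*C*(C/4) + (C/4)**2 + 9*(C/4)*C ≈ 4.56*C**2 weights versus
            # 9*C**2 for the single 3x3 conv in the '1conv' branch, i.e.
            # roughly half the parameters.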
1082 |
1083 | # 3, reconstruction
1084 | self.conv_before_upsample = nn.Sequential(
1085 | nn.Conv2d(params.embed_dim, params.num_feat, 3, 1, 1),
1086 | nn.LeakyReLU(inplace=True),
1087 | )
1088 | self.upsample = Upsample(
1089 | params.upscale, params.num_feat)
1090 | self.conv_last = nn.Conv2d(
1091 | params.num_feat, params.out_chans, 3, 1, 1)
1092 |
1093 | self.apply(self._init_weights)
1094 |
1095 | def _init_weights(self, m):
1096 | if isinstance(m, nn.Linear):
1097 | trunc_normal_(m.weight, std=0.02)
1098 | if isinstance(m, nn.Linear) and m.bias is not None:
1099 | nn.init.constant_(m.bias, 0)
1100 | elif isinstance(m, nn.LayerNorm):
1101 | nn.init.constant_(m.bias, 0)
1102 | nn.init.constant_(m.weight, 1.0)
1103 |
1104 | @torch.jit.ignore
1105 | def no_weight_decay(self):
1106 | return {"absolute_pos_embed"}
1107 |
1108 | @torch.jit.ignore
1109 | def no_weight_decay_keywords(self):
1110 | return {"relative_position_bias_table"}
1111 |
1112 | def check_image_size(self, x):
1113 | _, _, h, w = x.size()
1114 | mod_pad_h = (
1115 | self.window_size - h % self.window_size) % self.window_size
1116 | mod_pad_w = (
1117 | self.window_size - w % self.window_size) % self.window_size
1118 | x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), "reflect")
1119 | return x
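    # check_image_size pads H and W up to the next multiple of window_size so
    # that attention windows tile the feature map exactly. For example, with
    # window_size=7 and h=30: mod_pad_h = (7 - 30 % 7) % 7 = 5, so the input
    # is reflect-padded from 30 to 35 rows; forward() crops this padding off
    # again before returning.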
1120 |
1121 | def forward_features(self, x):
1122 | x_size = (x.shape[2], x.shape[3])
1123 | x = self.patch_embed(x)
1124 | if self.ape:
1125 | x = x + self.absolute_pos_embed
1126 | x = self.pos_drop(x)
1127 |
1128 | for layer in self.layers:
1129 | x = layer(x, x_size)
1130 |
1131 | x = self.norm(x) # B L C
1132 | x = self.patch_unembed(x, x_size)
1133 |
1134 | return x
1135 |
1136 | def forward(self, x):
1137 | H, W = x.shape[2:]
1138 | x = self.check_image_size(x)
1139 |
1140 | self.mean = self.mean.type_as(x)
1141 | x = (x - self.mean) * self.img_range
1142 |
1143 | x = self.conv_first(x)
1144 | x = self.conv_after_body(self.forward_features(x)) + x
1145 | x = self.conv_before_upsample(x)
1146 | x = self.upsample(x)
1147 | x = self.conv_last(x)
1148 |
1149 | x = x / self.img_range + self.mean
1150 |
1151 |         return x[:, :, : H * self.upscale, : W * self.upscale]  # crop padding added in check_image_size
1152 |
1153 | def flops(self):
1154 | flops = 0
1155 | H, W = self.patches_resolution
1156 | flops += H * W * 3 * self.embed_dim * 9
1157 | flops += self.patch_embed.flops()
1158 | for i, layer in enumerate(self.layers):
1159 | flops += layer.flops()
1160 | flops += H * W * 3 * self.embed_dim * self.embed_dim
1161 | flops += self.upsample.flops()
1162 | return flops
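    # Note (assumption): this is a rough estimate. The first term models
    # conv_first with a 3-channel (RGB) input, whereas the demo configuration
    # below uses in_chans=8, so the reported flops undercount that layer.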
1163 |
1164 |
1165 | if __name__ == "__main__":
1166 |
1167 | params = {
1168 | "upscale": 1,
1169 | "in_chans": 8,
1170 | "out_chans": 4,
1171 | "img_size_x": 960,
1172 | "img_size_y": 480,
1173 | "window_size": 4,
1174 | "patch_size": 5,
1175 | "num_feat": 64,
1176 | "drop_rate": 0.1,
1177 | "drop_path_rate": 0.1,
1178 | "attn_drop_rate": 0.1,
1179 | "ape": False,
1180 | "patch_norm": True,
1181 | "use_checkpoint": False,
1182 | "resi_connection": "1conv",
1183 | "qkv_bias": True,
1184 | "qk_scale": None,
1185 | "img_range": 1.0,
1186 |         "depths": [3],
1187 |         "embed_dim": 64,  # must be divisible by each entry of num_heads
1188 | "num_heads": [4],
1189 | "mlp_ratio": 2,
1190 | }
1191 | import argparse
1192 |
1193 | params = argparse.Namespace(**params)
1194 | model = EncDec(
1195 | params,
1196 | )
1197 |
1198 | print(model)
1199 |
1200 | x = torch.randn((1, 8, 960, 480))
1201 | x = model(x)
1202 | print(x.shape)
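    # With upscale=1 the Upsample stage is expected to act as an identity, so
    # this should print torch.Size([1, 4, 960, 480]): out_chans=4 output
    # channels at the original 960x480 resolution.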
1203 |
--------------------------------------------------------------------------------