├── .gitignore ├── LICENSE ├── README.md ├── README_en.md ├── data_factory ├── __init__.py ├── convert_era5_hourly.py ├── datasets.py ├── download_era5.py ├── graph_tools.py └── update_scaler.py ├── img ├── precipitation_small.gif └── wind_small.gif ├── model ├── afnonet.py └── graphcast_sequential.py ├── train_fourcastnet.py ├── train_graphcast.py └── utils ├── eval.py ├── params.py └── tools.py /.gitignore: -------------------------------------------------------------------------------- 1 | /output/ 2 | *.csv 3 | *.pt 4 | .ipynb_checkpoints/ 5 | .idea/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 HFAiLab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # OpenCastKit: an open-source solution for global data-driven high-resolution weather forecasting 2 | 3 | 简体中文 | [English](README_en.md) 4 | 5 | 本项目是由幻方AI团队复现优化,并开源的全球AI气象预报模型工具库。基于 [FourCastNet](https://arxiv.org/abs/2202.11214) 和 [GraphCast](https://arxiv.org/abs/2212.12794) 的论文,我们构建了一个新的全球AI气象预报项目——**OpenCastKit**,它能够与欧洲中期天气预报中心(ECMWF)的传统物理模型——高分辨率综合预测系统(IFS),进行直接比较。 6 | 7 | 我们将基于1979年1月到2022年12月的ERA5数据训练出来的模型参数开源到 [Hugging Face 仓库](https://huggingface.co/hf-ai/OpenCastKit)中,并上线了一个每日更新的 [HF-Earth](https://www.high-flyer.cn/hf-earth/),展示模型的预测效果。 8 | 9 | 下面是一些预测案例: 10 | 11 | ![台风路径预测与真实路径比较](./img/wind_small.gif) 12 | 13 | ![水汽浓度预测与真实情况比较](./img/precipitation_small.gif) 14 | 15 | 16 | ## 依赖 17 | 18 | - [hfai](https://doc.hfai.high-flyer.cn/index.html) >= 7.9.5 19 | - torch >= 1.8 20 | 21 | 22 | ## 训练 23 | 原始数据来自欧洲中期天气预报中心(ECMWF)提供的一个公开可用的综合数据集 [ERA5](https://www.ecmwf.int/en/forecasts/datasets/reanalysis-datasets/era5),需要通过 `data_factory/convert_era5_hourly.py` 脚本提取数据特征,转化为[高性能训练样本格式 FFRecord](https://www.high-flyer.cn/blog/ffrecord/) 下的样本数据。 24 | 25 | 26 | ### 训练 FourCastNet 27 | 28 | 本地运行: 29 | ```shell 30 | python train_fourcastnet.py --pretrain-epochs 100 --fintune-epochs 40 --batch-size 4 31 | ``` 32 | 33 | 也可以提交任务至幻方萤火集群,使用96张A100进行数据并行训练: 34 | ```shell 35 | hfai python train_fourcastnet.py --pretrain-epochs 100 --fintune-epochs 40 --batch-size 4 -- -n 12 --name train_fourcastnet 36 | ``` 37 | 38 | ### 训练 GraphCast 39 | 40 | 本地运行: 41 | ```shell 42 | python train_graphcast.py --epochs 200 --batch-size 2 43 | ``` 44 | 45 | 也可以提交任务至幻方萤火集群,使用256张A100进行流水线并行训练: 46 | ```shell 47 | hfai python train_graphcast.py --epochs 200 --batch-size 2 -- -n 32 --name train_graphcast 48 | ``` 49 | 50 | 51 | ## 引用 52 | 53 | ```bibtex 54 | @article{pathak2022fourcastnet, 55 | title={Fourcastnet: A
global data-driven high-resolution weather model using adaptive fourier neural operators}, 56 | author={Pathak, Jaideep and Subramanian, Shashank and Harrington, Peter and Raja, Sanjeev and Chattopadhyay, Ashesh and Mardani, Morteza and Kurth, Thorsten and Hall, David and Li, Zongyi and Azizzadenesheli, Kamyar and others}, 57 | journal={arXiv preprint arXiv:2202.11214}, 58 | year={2022} 59 | } 60 | ``` 61 | 62 | ```bibtex 63 | @article{remi2022graphcast, 64 | title={GraphCast: Learning skillful medium-range global weather forecasting}, 65 | author={Lam, Remi and Sanchez-Gonzalez, Alvaro and Willson, Matthew and Wirnsberger, Peter and Fortunato, Meire and Pritzel, Alexander and Ravuri, Suman and Ewalds, Timo and Alet, Ferran and Eaton-Rosen, Zach and Hu, Weihua and Merose, Alexander and Hoyer, Stephan and Holland, George and Stott, Jacklynn and Vinyals, Oriol and Mohamed, Shakir and Battaglia, Peter}, 66 | journal={arXiv preprint arXiv:2212.12794}, 67 | year={2022} 68 | } 69 | ``` -------------------------------------------------------------------------------- /README_en.md: -------------------------------------------------------------------------------- 1 | # OpenCastKit: an open-source solution for global data-driven high-resolution weather forecasting 2 | 3 | English | [简体中文](README.md) 4 | 5 | This is an open-source solution for global data-driven high-resolution weather forecasting, implemented and optimized by [High-Flyer AI](https://www.high-flyer.cn/). Its forecasts can be directly compared with those of the ECMWF Integrated Forecasting System (IFS). 6 | 7 | The model weights trained on the ERA5 data from 1979-01 to 2022-12 are released in the [Hugging Face repository](https://huggingface.co/hf-ai/OpenCastKit). You can also have a look at [HF-Earth](https://www.high-flyer.cn/hf-earth/), a daily updated demo of weather prediction. 
8 | 9 | Some prediction examples: 10 | 11 | ![Typhoon track comparison](./img/wind_small.gif) 12 | 13 | ![Water vapour comparison](./img/precipitation_small.gif) 14 | 15 | 16 | ## Requirements 17 | 18 | - [hfai](https://doc.hfai.high-flyer.cn/index.html) >= 7.9.5 19 | - torch >= 1.8 20 | 21 | 22 | ## Training 23 | The raw data comes from the public [ERA5](https://www.ecmwf.int/en/forecasts/datasets/reanalysis-datasets/era5) dataset provided by ECMWF. We can use the script `data_factory/convert_era5_hourly.py` to extract features and convert them into high-performance training samples in the [FFRecord](https://www.high-flyer.cn/blog/ffrecord/) format. 24 | 25 | ### FourCastNet training 26 | 27 | Run locally: 28 | ```shell 29 | python train_fourcastnet.py --pretrain-epochs 100 --fintune-epochs 40 --batch-size 4 30 | ``` 31 | 32 | We can also submit the task to the Yinghuo HPC and conduct data-parallel training on 96 A100 GPUs: 33 | ```shell 34 | hfai python train_fourcastnet.py --pretrain-epochs 100 --fintune-epochs 40 --batch-size 4 -- -n 12 --name train_fourcastnet 35 | ``` 36 | 37 | ### GraphCast training 38 | 39 | Run locally: 40 | ```shell 41 | python train_graphcast.py --epochs 200 --batch-size 2 42 | ``` 43 | 44 | We can also submit the task to the Yinghuo HPC and conduct pipeline-parallel training on 256 A100 GPUs: 45 | ```shell 46 | hfai python train_graphcast.py --epochs 200 --batch-size 2 -- -n 32 --name train_graphcast 47 | ``` 48 | 49 | 50 | ## Reference 51 | 52 | ```bibtex 53 | @article{pathak2022fourcastnet, 54 | title={Fourcastnet: A global data-driven high-resolution weather model using adaptive fourier neural operators}, 55 | author={Pathak, Jaideep and Subramanian, Shashank and Harrington, Peter and Raja, Sanjeev and Chattopadhyay, Ashesh and Mardani, Morteza and Kurth, Thorsten and Hall, David and Li, Zongyi and Azizzadenesheli, Kamyar and others}, 56 | journal={arXiv preprint arXiv:2202.11214}, 57 | year={2022} 58 | } 59 | ``` 60 | 61 | ```bibtex 62 | @article{remi2022graphcast, 63 | title={GraphCast: Learning skillful medium-range global weather
forecasting}, 64 | author={Lam, Remi and Sanchez-Gonzalez, Alvaro and Willson, Matthew and Wirnsberger, Peter and Fortunato, Meire and Pritzel, Alexander and Ravuri, Suman and Ewalds, Timo and Alet, Ferran and Eaton-Rosen, Zach and Hu, Weihua and Merose, Alexander and Hoyer, Stephan and Holland, George and Stott, Jacklynn and Vinyals, Oriol and Mohamed, Shakir and Battaglia, Peter}, 65 | journal={arXiv preprint arXiv:2212.12794}, 66 | year={2022} 67 | } 68 | ``` -------------------------------------------------------------------------------- /data_factory/__init__.py: -------------------------------------------------------------------------------- 1 | from .datasets import ERA5, StandardScaler 2 | -------------------------------------------------------------------------------- /data_factory/convert_era5_hourly.py: -------------------------------------------------------------------------------- 1 | import dask 2 | import numpy as np 3 | from datetime import datetime, timedelta 4 | from dateutil.relativedelta import relativedelta 5 | import xarray as xr 6 | import pickle 7 | from pathlib import Path 8 | 9 | from ffrecord import FileWriter 10 | from data_factory.graph_tools import fetch_time_features 11 | 12 | np.random.seed(2022) 13 | 14 | DATADIR = './output/rawdata/era5_6_hourly/' 15 | DATANAMES = ['10m_u_component_of_wind', '10m_v_component_of_wind', '2m_temperature', 16 | 'geopotential@1000', 'geopotential@50', 'geopotential@500', 'geopotential@850', 17 | 'mean_sea_level_pressure', 'relative_humidity@500', 'relative_humidity@850', 18 | 'surface_pressure', 'temperature@500', 'temperature@850', 'total_column_water_vapour', 19 | 'u_component_of_wind@1000', 'u_component_of_wind@500', 'u_component_of_wind@850', 20 | 'v_component_of_wind@1000', 'v_component_of_wind@500', 'v_component_of_wind@850', 21 | 'total_precipitation'] 22 | DATAMAP = { 23 | 'geopotential': 'z', 24 | 'relative_humidity': 'r', 25 | 'temperature': 't', 26 | 'u_component_of_wind': 'u', 27 | 'v_component_of_wind': 'v' 28 | } 29 | 30 | 31 | def 
dataset_to_sample(raw_data, mean, std): 32 | tmpdata = (raw_data - mean) / std 33 | 34 | xt0 = tmpdata[:-2] 35 | xt1 = tmpdata[1:-1] 36 | yt = tmpdata[2:] 37 | 38 | return xt0, xt1, yt 39 | 40 | 41 | def write_dataset(x0, x1, y, out_file): 42 | n_sample = x0.shape[0] 43 | 44 | # initialize the FFRecord writer 45 | writer = FileWriter(out_file, n_sample) 46 | 47 | for item in zip(x0, x1, y): 48 | bytes_ = pickle.dumps(item) 49 | writer.write_one(bytes_) 50 | writer.close() 51 | 52 | 53 | def load_ndf(time_scale): 54 | datas = [] 55 | for file in DATANAMES: 56 | tmp = xr.open_mfdataset(f'{DATADIR}/{file}/*.nc', combine='by_coords').sel(time=time_scale) 57 | if '@' in file: 58 | k, v = file.split('@') 59 | tmp = tmp.rename_vars({DATAMAP[k]: f'{DATAMAP[k]}@{v}'}) 60 | datas.append(tmp) 61 | with dask.config.set(**{'array.slicing.split_large_chunks': False}): 62 | valid_data = xr.merge(datas, compat="identical", join="inner") 63 | 64 | return valid_data 65 | 66 | 67 | def fetch_dataset(cursor_time, out_dir): 68 | # load weather data 69 | 70 | step = (cursor_time.year - 1979) * 12 + (cursor_time.month - 1) + 1 71 | start = cursor_time.strftime('%Y-%m-%d %H:%M:%S') 72 | end = (cursor_time + relativedelta(months=1, hours=7)).strftime('%Y-%m-%d %H:%M:%S') 73 | print(f'Step {step} | from {start} to {end}') 74 | 75 | time_scale = slice(start, end) 76 | 77 | with open("./output/data/scaler.pkl", "rb") as f: 78 | pkl = pickle.load(f) 79 | Mean = pkl["mean"] 80 | Std = pkl["std"] 81 | 82 | valid_data = load_ndf(time_scale) 83 | 84 | # era5 features 85 | Xt0, Xt1, Yt = [], [], [] 86 | for i, name in enumerate(['u10', 'v10', 't2m', 'z@1000', 'z@50', 'z@500', 'z@850', 'msl', 'r@500', 'r@850', 'sp', 't@500', 't@850', 'tcwv', 'u@1000', 'u@500', 'u@850', 'v@1000', 'v@500', 'v@850']): 87 | raw = valid_data[name] 88 | 89 | # split sample data 90 | xt0, xt1, yt = dataset_to_sample(raw, Mean[i], Std[i]) 91 | 92 | Xt0.append(xt0) 93 | Xt1.append(xt1) 94 | Yt.append(yt) 95 | 96 | Xt0 = np.stack(Xt0, axis=-1) 
97 | Xt1 = np.stack(Xt1, axis=-1) 98 | Yt = np.stack(Yt, axis=-1) 99 | 100 | # time-dependent features 101 | time_features = [] 102 | for i in range(len(valid_data['time'])): 103 | step_time = cursor_time + timedelta(hours=6) * i  # do not mutate cursor_time, or the offsets compound across iterations 104 | tmp_feats = fetch_time_features(step_time) 105 | time_features.append(tmp_feats) 106 | time_features = np.asarray(time_features) 107 | 108 | Xt0 = np.concatenate([Xt0[:, 1:], time_features[:-2]], axis=-1) 109 | Xt1 = np.concatenate([Xt1[:, 1:], time_features[1:-1]], axis=-1) 110 | Yt = np.concatenate([Yt[:, 1:], time_features[2:]], axis=-1) 111 | print(f"Xt0.shape: {Xt0.shape}, Xt1.shape: {Xt1.shape}, Yt.shape: {Yt.shape}\n") 112 | 113 | write_dataset(Xt0, Xt1, Yt, out_dir / f"{step:03d}.ffr") 114 | 115 | 116 | def dump_era5(out_dir): 117 | out_dir.mkdir(exist_ok=True, parents=True) 118 | 119 | start_time = datetime(1979, 1, 1, 0, 0) 120 | end_time = datetime(2023, 2, 1, 0, 0) 121 | 122 | cursor_time = start_time 123 | while True: 124 | if cursor_time >= end_time: 125 | break 126 | 127 | fetch_dataset(cursor_time, out_dir) 128 | cursor_time += relativedelta(months=1) 129 | 130 | 131 | if __name__ == "__main__": 132 | out_dir = Path("./output/data/train.ffr") 133 | dump_era5(out_dir) 134 | -------------------------------------------------------------------------------- /data_factory/datasets.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | from pathlib import Path 3 | 4 | import numpy as np 5 | import torch 6 | from torch_geometric.data import Data 7 | from ffrecord import FileReader 8 | from ffrecord.torch import Dataset, DataLoader 9 | import data_factory.graph_tools as gg 10 | 11 | 12 | class StandardScaler: 13 | def __init__(self): 14 | self.mean = 0.0 15 | self.std = 1.0 16 | 17 | def load(self, scaler_dir): 18 | with open(scaler_dir, "rb") as f: 19 | pkl = pickle.load(f) 20 | self.mean = pkl["mean"] 21 | self.std = pkl["std"] 22 | 23 | def inverse_transform(self, 
data): 24 | mean = torch.from_numpy(self.mean).type_as(data).to(data.device) if torch.is_tensor(data) else self.mean 25 | std = torch.from_numpy(self.std).type_as(data).to(data.device) if torch.is_tensor(data) else self.std 26 | return (data * std) + mean 27 | 28 | 29 | class ERA5(Dataset): 30 | 31 | def __init__(self, split: str, check_data: bool = True, modelname: str = 'fourcastnet') -> None: 32 | 33 | self.data_dir = Path("./output/data/") 34 | 35 | assert split in ["train", "val"] 36 | assert modelname in ["fourcastnet", "graphcast"] 37 | self.split = split 38 | self.modelname = modelname 39 | self.fname = str(self.data_dir / f"{split}.ffr") 40 | self.reader = FileReader(self.fname, check_data) 41 | self.scaler = StandardScaler() 42 | self.scaler.load("./output/data/scaler.pkl") 43 | 44 | if self.modelname == 'graphcast': 45 | self.constant_features = gg.fetch_constant_features() 46 | else: 47 | self.constant_features = None 48 | 49 | def __len__(self): 50 | return self.reader.n 51 | 52 | def __getitem__(self, indices): 53 | seqs_bytes = self.reader.read(indices) 54 | samples = [] 55 | for bytes_ in seqs_bytes: 56 | x0, x1, y = pickle.loads(bytes_) 57 | 58 | if self.modelname == 'fourcastnet': 59 | x0 = np.nan_to_num(x0[:, :, :-2]) 60 | x1 = np.nan_to_num(x1[:, :, :-2]) 61 | y = np.nan_to_num(y[:, :, :-2]) 62 | samples.append((x0, x1, y)) 63 | else: 64 | x = np.nan_to_num(np.reshape(np.concatenate([x0, x1, y[:, :, -2:]], axis=-1), [-1, 49])) 65 | y = np.nan_to_num(np.reshape(y[:, :, :-2], [-1, 20])) 66 | samples.append((x, y)) 67 | return samples 68 | 69 | def get_scaler(self): 70 | return self.scaler 71 | 72 | def loader(self, *args, **kwargs) -> DataLoader: 73 | return DataLoader(self, *args, **kwargs) 74 | 75 | 76 | class EarthGraph(object): 77 | def __init__(self): 78 | self.mesh_data = None 79 | self.grid2mesh_data = None 80 | self.mesh2grid_data = None 81 | 82 | def generate_graph(self): 83 | mesh_nodes = gg.fetch_mesh_nodes() 84 | 85 | mesh_6_edges, 
mesh_6_edges_attrs = gg.fetch_mesh_edges(6) 86 | mesh_5_edges, mesh_5_edges_attrs = gg.fetch_mesh_edges(5) 87 | mesh_4_edges, mesh_4_edges_attrs = gg.fetch_mesh_edges(4) 88 | mesh_3_edges, mesh_3_edges_attrs = gg.fetch_mesh_edges(3) 89 | mesh_2_edges, mesh_2_edges_attrs = gg.fetch_mesh_edges(2) 90 | mesh_1_edges, mesh_1_edges_attrs = gg.fetch_mesh_edges(1) 91 | mesh_0_edges, mesh_0_edges_attrs = gg.fetch_mesh_edges(0) 92 | 93 | mesh_edges = mesh_6_edges + mesh_5_edges + mesh_4_edges + mesh_3_edges + mesh_2_edges + mesh_1_edges + mesh_0_edges 94 | mesh_edges_attrs = mesh_6_edges_attrs + mesh_5_edges_attrs + mesh_4_edges_attrs + mesh_3_edges_attrs + mesh_2_edges_attrs + mesh_1_edges_attrs + mesh_0_edges_attrs 95 | 96 | self.mesh_data = Data(x=torch.tensor(mesh_nodes, dtype=torch.float), 97 | edge_index=torch.tensor(mesh_edges, dtype=torch.long).T.contiguous(), 98 | edge_attr=torch.tensor(mesh_edges_attrs, dtype=torch.float)) 99 | 100 | grid2mesh_edges, grid2mesh_edge_attrs = gg.fetch_grid2mesh_edges() 101 | self.grid2mesh_data = Data(x=None, 102 | edge_index=torch.tensor(grid2mesh_edges, dtype=torch.long).T.contiguous(), 103 | edge_attr=torch.tensor(grid2mesh_edge_attrs, dtype=torch.float)) 104 | 105 | mesh2grid_edges, mesh2grid_edge_attrs = gg.fetch_mesh2grid_edges() 106 | self.mesh2grid_data = Data(x=None, 107 | edge_index=torch.tensor(mesh2grid_edges, dtype=torch.long).T.contiguous(), 108 | edge_attr=torch.tensor(mesh2grid_edge_attrs, dtype=torch.float)) -------------------------------------------------------------------------------- /data_factory/download_era5.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ['http_proxy'] = "http://vdi-proxy.high-flyer.cn:3128" 3 | os.environ['https_proxy'] = "http://vdi-proxy.high-flyer.cn:3128" 4 | 5 | from datetime import datetime, timedelta 6 | from pathlib import Path 7 | import cdsapi 8 | 9 | DATADIR = Path('./output/rawdata/') 10 | all_days = [ 11 | 
'01','02','03','04','05','06','07','08','09','10','11','12','13','14','15', 12 | '16','17','18','19','20','21','22','23','24','25','26','27','28','29','30','31' 13 | ] 14 | all_times = [ 15 | '00:00','06:00','12:00','18:00' 16 | ] 17 | all_variables = [ 18 | '10m_u_component_of_wind', '10m_v_component_of_wind', '2m_temperature', 19 | 'geopotential@1000', 'geopotential@50', 'geopotential@500', 'geopotential@850', 20 | 'mean_sea_level_pressure', 'relative_humidity@500', 'relative_humidity@850', 21 | 'surface_pressure', 'temperature@500', 'temperature@850', 'total_column_water_vapour', 22 | 'u_component_of_wind@1000', 'u_component_of_wind@500', 'u_component_of_wind@850', 23 | 'v_component_of_wind@1000', 'v_component_of_wind@500', 'v_component_of_wind@850', 24 | 'total_precipitation' 25 | ] 26 | 27 | 28 | def download_single_file( 29 | variable, 30 | level_type, 31 | year, 32 | pressure_level='1000', 33 | month='01', 34 | day='01', 35 | output_dir=None, 36 | ): 37 | """ 38 | Download a single file from the ERA5 archive. 39 | 40 | :param variable: Name of variable in archive 41 | :param level_type: 'single' or 'pressure' 42 | :param year: Year(s) to download data 43 | :param pressure_level: Pressure levels to download. None for 'single' output type. 44 | :param month: Month(s) to download data 45 | :param day: Day(s) to download data 46 | Note: download hours are fixed to the 6-hourly steps in `all_times`. 47 | """ 48 | 49 | if level_type == 'pressure': 50 | var_name = f'{variable}@{pressure_level}' 51 | else: 52 | var_name = variable 53 | 54 | if type(day) is list: 55 | fn = f'era5_{var_name}_{year}_{month}_6_hourly.nc' 56 | else: 57 | fn = f'era5_{var_name}_{year}_{month}_{day}_6_hourly.nc' 58 | 59 | c = cdsapi.Client(progress=False) 60 | 61 | request_parameters = { 62 | 'product_type': 'reanalysis', 63 | 'expver': '1', 64 | 'format': 'netcdf', 65 | 'variable': variable, 66 | 'year': year, 67 | 'month': month, 68 | 'day': day, 69 | 'time': all_times, 70 | } 71 | request_parameters.update({'pressure_level': pressure_level} if level_type == 'pressure' else {}) 72 | 73 | c.retrieve( 74 | f'reanalysis-era5-{level_type}-levels', 75 | request_parameters, 76 | str(output_dir / fn) 77 | ) 78 | 79 | print(f"Saved file: {output_dir / fn}") 80 | 81 | 82 | def main( 83 | variable='geopotential', 84 | level_type='pressure', 85 | years='1979', 86 | pressure_level='1000', 87 | month='01', 88 | day='01' 89 | ): 90 | """ 91 | Command line script to download single or several files from the ERA5 archive. 92 | 93 | :param variable: Name of variable in archive 94 | :param level_type: 'single' or 'pressure' 95 | :param years: Years to download data. Each year is saved separately 96 | :param pressure_level: Pressure levels to download. None for 'single' output type. 97 | :param month: Month(s) to download data 98 | :param day: Day(s) to download data 99 | Note: download hours are fixed to the 6-hourly steps in `all_times`. 100 | """ 101 | # Make sure output directory exists 102 | output_dir = DATADIR / 'new_files' 103 | output_dir.mkdir(parents=True, exist_ok=True) 104 | 105 | if level_type == 'pressure': 106 | assert pressure_level is not None, 'Pressure level must be defined.' 
107 | 108 | download_single_file( 109 | variable=variable, 110 | level_type=level_type, 111 | year=years, 112 | pressure_level=pressure_level, 113 | month=month, 114 | day=day, 115 | output_dir=output_dir 116 | ) 117 | 118 | 119 | def fetch_a_day(date): 120 | for var in all_variables: 121 | if '@' in var: 122 | level_type = 'pressure' 123 | var, pressure_level = var.split('@') 124 | else: 125 | level_type = 'single' 126 | pressure_level = '' 127 | 128 | main( 129 | variable=var, 130 | level_type=level_type, 131 | years=f'{date.year}', 132 | pressure_level=pressure_level, 133 | month=f'{date.month:02d}', 134 | day=f'{date.day:02d}' 135 | ) 136 | 137 | # move data 138 | for item in all_variables: 139 | os.system(f'mv {DATADIR}/new_files/era5_{item}_* {DATADIR}/era5_6_hourly/{item}/') 140 | 141 | 142 | def fetch_a_month(date): 143 | for var in all_variables: 144 | if '@' in var: 145 | level_type = 'pressure' 146 | var, pressure_level = var.split('@') 147 | else: 148 | level_type = 'single' 149 | pressure_level = '' 150 | 151 | main( 152 | variable=var, 153 | level_type=level_type, 154 | years=f'{date.year}', 155 | pressure_level=pressure_level, 156 | month=f'{date.month:02d}', 157 | day=all_days 158 | ) 159 | 160 | # move data 161 | for item in all_variables: 162 | os.system(f'mv {DATADIR}/new_files/era5_{item}_* {DATADIR}/era5_6_hourly/{item}/') 163 | 164 | 165 | if __name__ == '__main__': 166 | 167 | fetch_date = datetime.today() - timedelta(days=6) 168 | # fetch_date = datetime(2022, 11, 23) 169 | 170 | fetch_a_day(fetch_date) 171 | # fetch_a_month(fetch_date) 172 | -------------------------------------------------------------------------------- /data_factory/graph_tools.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import datetime 3 | 4 | 5 | def id2position(node_id, lat_len, lon_len): 6 | lat = node_id // lon_len 7 | lon = node_id % lon_len 8 | cos_lat = np.cos((0.5 - (lat + 1) / (lat_len + 1)) * 
np.pi) 9 | sin_lon = np.sin(lon / lon_len * np.pi) 10 | cos_lon = np.cos(lon / lon_len * np.pi) 11 | return cos_lat, sin_lon, cos_lon 12 | 13 | 14 | def fetch_mesh_nodes(): 15 | nodes = [] 16 | for i in range(128): 17 | cos_lat = np.cos((0.5 - (i + 1) / 129) * np.pi) 18 | for j in range(320): 19 | sin_lon = np.sin(j / 320 * np.pi) 20 | cos_lon = np.cos(j / 320 * np.pi) 21 | nodes.append([cos_lat, sin_lon, cos_lon]) 22 | return nodes 23 | 24 | 25 | def fetch_mesh_edges(r): 26 | assert 6 >= r >= 0 27 | 28 | step = 2 ** (6 - r) 29 | edges = [] 30 | edge_attrs = [] 31 | for i in range(0, 128, step): 32 | for j in range(0, 320, step): 33 | cur_node = id2position(i * 320 + j, 128, 320) 34 | if i - step >= 0: 35 | edges.append([(i - step) * 320 + j, i * 320 + j]) 36 | target_node = id2position((i - step) * 320 + j, 128, 320) 37 | tmp_attr = [target_node[k] - cur_node[k] for k in range(3)] 38 | edge_attrs.append([np.sqrt(np.sum(np.square(tmp_attr)))] + tmp_attr) 39 | if i + step < 128: 40 | edges.append([(i + step) * 320 + j, i * 320 + j]) 41 | target_node = id2position((i + step) * 320 + j, 128, 320) 42 | tmp_attr = [target_node[k] - cur_node[k] for k in range(3)] 43 | edge_attrs.append([np.sqrt(np.sum(np.square(tmp_attr)))] + tmp_attr) 44 | if j - step >= 0: 45 | edges.append([i * 320 + j - step, i * 320 + j]) 46 | target_node = id2position(i * 320 + j - step, 128, 320) 47 | tmp_attr = [target_node[k] - cur_node[k] for k in range(3)] 48 | edge_attrs.append([np.sqrt(np.sum(np.square(tmp_attr)))] + tmp_attr) 49 | else: 50 | edges.append([i * 320 + 320 - step, i * 320 + j]) 51 | target_node = id2position(i * 320 + 320 - step, 128, 320) 52 | tmp_attr = [target_node[k] - cur_node[k] for k in range(3)] 53 | edge_attrs.append([np.sqrt(np.sum(np.square(tmp_attr)))] + tmp_attr) 54 | if j + step < 320: 55 | edges.append([i * 320 + j + step, i * 320 + j]) 56 | target_node = id2position(i * 320 + j + step, 128, 320) 57 | tmp_attr = [target_node[k] - cur_node[k] for k in range(3)] 58 
| edge_attrs.append([np.sqrt(np.sum(np.square(tmp_attr)))] + tmp_attr) 59 | else: 60 | edges.append([i * 320, i * 320 + j]) 61 | target_node = id2position(i * 320, 128, 320) 62 | tmp_attr = [target_node[k] - cur_node[k] for k in range(3)] 63 | edge_attrs.append([np.sqrt(np.sum(np.square(tmp_attr)))] + tmp_attr) 64 | 65 | return edges, edge_attrs 66 | 67 | 68 | def fetch_grid2mesh_edges(): 69 | lat_span = 720 / 128 70 | lon_span = 1440 / 320 71 | edges = [] 72 | edge_attrs = [] 73 | for i in range(720): 74 | for j in range(1440): 75 | target_mesh_i = int(i / lat_span) 76 | target_mesh_j = int(j / lon_span) 77 | edges.append([i * 1440 + j, target_mesh_i * 320 + target_mesh_j]) 78 | cur_node = id2position(i * 1440 + j, 720, 1440) 79 | target_node = id2position(target_mesh_i * 320 + target_mesh_j, 128, 320) 80 | tmp_attr = [target_node[k] - cur_node[k] for k in range(3)] 81 | edge_attrs.append([np.sqrt(np.sum(np.square(tmp_attr)))] + tmp_attr) 82 | 83 | over_mesh_i = int(i / lat_span - 0.1) 84 | if i / lat_span - 0.1 > 0 and over_mesh_i != target_mesh_i: 85 | edges.append([i * 1440 + j, over_mesh_i * 320 + target_mesh_j]) 86 | target_node = id2position(over_mesh_i * 320 + target_mesh_j, 128, 320) 87 | tmp_attr = [target_node[k] - cur_node[k] for k in range(3)] 88 | edge_attrs.append([np.sqrt(np.sum(np.square(tmp_attr)))] + tmp_attr) 89 | 90 | over_mesh_i = int(i / lat_span + 0.1) 91 | if i / lat_span + 0.1 < 128 and over_mesh_i != target_mesh_i: 92 | edges.append([i * 1440 + j, over_mesh_i * 320 + target_mesh_j]) 93 | target_node = id2position(over_mesh_i * 320 + target_mesh_j, 128, 320) 94 | tmp_attr = [target_node[k] - cur_node[k] for k in range(3)] 95 | edge_attrs.append([np.sqrt(np.sum(np.square(tmp_attr)))] + tmp_attr) 96 | 97 | over_mesh_j = int(j / lon_span - 0.1) 98 | if j / lon_span - 0.1 < 0: 99 | edges.append([i * 1440 + 1439, target_mesh_i * 320 + target_mesh_j]) 100 | target_node = id2position(target_mesh_i * 320 + target_mesh_j, 128, 320) 101 | tmp_attr = 
[target_node[k] - cur_node[k] for k in range(3)] 102 | edge_attrs.append([np.sqrt(np.sum(np.square(tmp_attr)))] + tmp_attr) 103 | elif over_mesh_j != target_mesh_j: 104 | edges.append([i * 1440 + j, target_mesh_i * 320 + over_mesh_j]) 105 | target_node = id2position(target_mesh_i * 320 + over_mesh_j, 128, 320) 106 | tmp_attr = [target_node[k] - cur_node[k] for k in range(3)] 107 | edge_attrs.append([np.sqrt(np.sum(np.square(tmp_attr)))] + tmp_attr) 108 | 109 | over_mesh_j = int(j / lon_span + 0.1) 110 | if j / lon_span + 0.1 > 320: 111 | edges.append([i * 1440, target_mesh_i * 320 + target_mesh_j]) 112 | target_node = id2position(target_mesh_i * 320 + target_mesh_j, 128, 320) 113 | tmp_attr = [target_node[k] - cur_node[k] for k in range(3)] 114 | edge_attrs.append([np.sqrt(np.sum(np.square(tmp_attr)))] + tmp_attr) 115 | elif over_mesh_j != target_mesh_j: 116 | edges.append([i * 1440 + j, target_mesh_i * 320 + over_mesh_j]) 117 | target_node = id2position(target_mesh_i * 320 + over_mesh_j, 128, 320) 118 | tmp_attr = [target_node[k] - cur_node[k] for k in range(3)] 119 | edge_attrs.append([np.sqrt(np.sum(np.square(tmp_attr)))] + tmp_attr) 120 | 121 | return edges, edge_attrs 122 | 123 | 124 | def fetch_mesh2grid_edges(): 125 | lat_span = 720 / 128 126 | lon_span = 1440 / 320 127 | edges = [] 128 | edge_attrs = [] 129 | for i in range(720): 130 | for j in range(1440): 131 | target_mesh_i = int(i / lat_span) 132 | target_mesh_j = int(j / lon_span) 133 | edges.append([target_mesh_i * 320 + target_mesh_j, i * 1440 + j]) 134 | target_node = id2position(i * 1440 + j, 720, 1440) 135 | cur_node = id2position(target_mesh_i * 320 + target_mesh_j, 128, 320) 136 | tmp_attr = [target_node[k] - cur_node[k] for k in range(3)] 137 | edge_attrs.append([np.sqrt(np.sum(np.square(tmp_attr)))] + tmp_attr) 138 | 139 | over_mesh_i = int(i / lat_span - 0.3) 140 | if i / lat_span - 0.3 > 0 and over_mesh_i != target_mesh_i: 141 | edges.append([over_mesh_i * 320 + target_mesh_j, i * 1440 + j]) 
142 | cur_node = id2position(over_mesh_i * 320 + target_mesh_j, 128, 320) 143 | tmp_attr = [target_node[k] - cur_node[k] for k in range(3)] 144 | edge_attrs.append([np.sqrt(np.sum(np.square(tmp_attr)))] + tmp_attr) 145 | 146 | over_mesh_i = int(i / lat_span + 0.3) 147 | if i / lat_span + 0.3 < 128 and over_mesh_i != target_mesh_i: 148 | edges.append([over_mesh_i * 320 + target_mesh_j, i * 1440 + j]) 149 | cur_node = id2position(over_mesh_i * 320 + target_mesh_j, 128, 320) 150 | tmp_attr = [target_node[k] - cur_node[k] for k in range(3)] 151 | edge_attrs.append([np.sqrt(np.sum(np.square(tmp_attr)))] + tmp_attr) 152 | 153 | over_mesh_j = int(j / lon_span - 0.3) 154 | if j / lon_span - 0.3 < 0: 155 | edges.append([target_mesh_i * 320 + 319, i * 1440 + j]) 156 | cur_node = id2position(target_mesh_i * 320 + 319, 128, 320) 157 | tmp_attr = [target_node[k] - cur_node[k] for k in range(3)] 158 | edge_attrs.append([np.sqrt(np.sum(np.square(tmp_attr)))] + tmp_attr) 159 | elif over_mesh_j != target_mesh_j: 160 | edges.append([target_mesh_i * 320 + over_mesh_j, i * 1440 + j]) 161 | cur_node = id2position(target_mesh_i * 320 + over_mesh_j, 128, 320) 162 | tmp_attr = [target_node[k] - cur_node[k] for k in range(3)] 163 | edge_attrs.append([np.sqrt(np.sum(np.square(tmp_attr)))] + tmp_attr) 164 | 165 | over_mesh_j = int(j / lon_span + 0.3) 166 | if j / lon_span + 0.3 > 320: 167 | edges.append([target_mesh_i * 320, i * 1440 + j]) 168 | cur_node = id2position(target_mesh_i * 320, 128, 320) 169 | tmp_attr = [target_node[k] - cur_node[k] for k in range(3)] 170 | edge_attrs.append([np.sqrt(np.sum(np.square(tmp_attr)))] + tmp_attr) 171 | elif over_mesh_j != target_mesh_j: 172 | edges.append([target_mesh_i * 320 + over_mesh_j, i * 1440 + j]) 173 | cur_node = id2position(target_mesh_i * 320 + over_mesh_j, 128, 320) 174 | tmp_attr = [target_node[k] - cur_node[k] for k in range(3)] 175 | edge_attrs.append([np.sqrt(np.sum(np.square(tmp_attr)))] + tmp_attr) 176 | 177 | return edges, 
edge_attrs 178 | 179 | 180 | ## feature extraction 181 | def fetch_time_features(cursor_time): 182 | 183 | year_hours = (datetime.date(cursor_time.year + 1, 1, 1) - datetime.date(cursor_time.year, 1, 1)).days * 24 184 | next_year_hours = (datetime.date(cursor_time.year + 2, 1, 1) - datetime.date(cursor_time.year + 1, 1, 1)).days * 24 185 | 186 | cur_hour = (cursor_time - datetime.datetime(cursor_time.year, 1, 1)) / datetime.timedelta(hours=1) 187 | time_features = [] 188 | for j in range(1440): 189 | # local time 190 | local_hour = cur_hour + j * 24 / 1440 191 | if local_hour > year_hours: 192 | tr = (local_hour - year_hours) / next_year_hours 193 | else: 194 | tr = local_hour / year_hours 195 | 196 | time_features.append([[np.sin(2 * np.pi * tr), np.cos(2 * np.pi * tr)]] * 720) 197 | 198 | return np.transpose(np.asarray(time_features), [1, 0, 2]) 199 | 200 | 201 | def fetch_constant_features(): 202 | constant_features = [] 203 | for i in range(720): 204 | tmp = [] 205 | for j in range(1440): 206 | tmp.append(id2position(i * 1440 + j, 720, 1440)) 207 | constant_features.append(tmp) 208 | return np.asarray(constant_features) -------------------------------------------------------------------------------- /data_factory/update_scaler.py: -------------------------------------------------------------------------------- 1 | import hfai_env 2 | hfai_env.set_env("weather") 3 | 4 | import xarray as xr 5 | import pickle 6 | import numpy as np 7 | 8 | DATADIR = './output/rawdata/era5_6_hourly/' 9 | DATANAMES = { 10 | '10m_u_component_of_wind': 'u10', 11 | '10m_v_component_of_wind': 'v10', 12 | '2m_temperature': 't2m', 13 | 'geopotential@1000': 'z', 14 | 'geopotential@50': 'z', 15 | 'geopotential@500': 'z', 16 | 'geopotential@850': 'z', 17 | 'mean_sea_level_pressure': 'msl', 18 | 'relative_humidity@500': 'r', 19 | 'relative_humidity@850': 'r', 20 | 'surface_pressure': 'sp', 21 | 'temperature@500': 't', 22 | 'temperature@850': 't', 23 | 'total_column_water_vapour': 'tcwv', 24 | 
'u_component_of_wind@1000': 'u', 25 | 'u_component_of_wind@500': 'u', 26 | 'u_component_of_wind@850': 'u', 27 | 'v_component_of_wind@1000': 'v', 28 | 'v_component_of_wind@500': 'v', 29 | 'v_component_of_wind@850': 'v', 30 | 'total_precipitation': 'tp'} 31 | DATAVARS = ['u10', 'v10', 't2m', 'z@1000', 'z@50', 'z@500', 'z@850', 'msl', 'r@500', 'r@850', 'sp', 't@500', 't@850', 32 | 'tcwv', 'u@1000', 'u@500', 'u@850', 'v@1000', 'v@500', 'v@850', 'tp'] 33 | 34 | 35 | if __name__ == '__main__': 36 | 37 | Mean, Std = [], [] 38 | 39 | for k, v in DATANAMES.items(): 40 | 41 | raw = xr.open_mfdataset((f'{DATADIR}/{k}/*.nc'), combine='by_coords') 42 | 43 | if k == 'total_precipitation': 44 | data = raw[v].values[:, :, :, 0] * 1e9 45 | np.nan_to_num(data, copy=False) 46 | meanv = np.mean(data) / 1e9 47 | stdv = np.std(data) / 1e9 48 | else: 49 | data = raw[v].values 50 | np.nan_to_num(data, copy=False) 51 | meanv = np.mean(data) 52 | stdv = np.std(data) 53 | 54 | print(f'Var: {k} | mean: {meanv}, std: {stdv}') 55 | 56 | Mean.append(meanv) 57 | Std.append(stdv) 58 | 59 | with open("./output/data/scaler.pkl", "wb") as f: 60 | pickle.dump({ 61 | "mean": np.asarray(Mean), 62 | "std": np.asarray(Std) 63 | }, f) -------------------------------------------------------------------------------- /img/precipitation_small.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HFAiLab/OpenCastKit/6e3f934533e1aeb6446058a19eb3057b1a304ad2/img/precipitation_small.gif -------------------------------------------------------------------------------- /img/wind_small.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HFAiLab/OpenCastKit/6e3f934533e1aeb6446058a19eb3057b1a304ad2/img/wind_small.gif -------------------------------------------------------------------------------- /model/afnonet.py: 
-------------------------------------------------------------------------------- 1 | from functools import partial 2 | from collections import OrderedDict 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 8 | import torch.fft 9 | from utils.params import get_args 10 | from torch.utils.checkpoint import checkpoint_sequential 11 | 12 | 13 | class Mlp(nn.Module): 14 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 15 | super().__init__() 16 | out_features = out_features or in_features 17 | hidden_features = hidden_features or in_features 18 | self.fc1 = nn.Linear(in_features, hidden_features) 19 | self.act = act_layer() 20 | # self.fc2 = nn.Linear(hidden_features, out_features) 21 | self.fc2 = nn.AdaptiveAvgPool1d(out_features) 22 | self.drop = nn.Dropout(drop) 23 | 24 | def forward(self, x): 25 | x = self.fc1(x) 26 | x = self.act(x) 27 | x = self.drop(x) 28 | x = self.fc2(x) 29 | x = self.drop(x) 30 | return x 31 | 32 | 33 | class AdaptiveFourierNeuralOperator(nn.Module): 34 | def __init__(self, dim, h=14, w=8): 35 | super().__init__() 36 | args = get_args() 37 | self.hidden_size = dim 38 | self.h = h 39 | self.w = w 40 | 41 | self.num_blocks = args.fno_blocks 42 | self.block_size = self.hidden_size // self.num_blocks 43 | assert self.hidden_size % self.num_blocks == 0 44 | 45 | self.scale = 0.02 46 | self.w1 = torch.nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size, self.block_size)) 47 | self.b1 = torch.nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size)) 48 | self.w2 = torch.nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size, self.block_size)) 49 | self.b2 = torch.nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size)) 50 | self.relu = nn.ReLU() 51 | 52 | if args.fno_bias: 53 | self.bias = nn.Conv1d(self.hidden_size, 
self.hidden_size, 1) 54 | else: 55 | self.bias = None 56 | 57 | self.softshrink = args.fno_softshrink 58 | 59 | def multiply(self, input, weights): 60 | return torch.einsum('...bd,bdk->...bk', input, weights) 61 | 62 | def forward(self, x): 63 | B, N, C = x.shape 64 | 65 | if self.bias is not None: 66 | bias = self.bias(x.permute(0, 2, 1)).permute(0, 2, 1) 67 | else: 68 | bias = torch.zeros(x.shape, device=x.device) 69 | 70 | x = x.reshape(B, self.h, self.w, C) 71 | x = torch.fft.rfft2(x, dim=(1, 2), norm='ortho') 72 | x = x.reshape(B, x.shape[1], x.shape[2], self.num_blocks, self.block_size) 73 | 74 | x_real = F.relu(self.multiply(x.real, self.w1[0]) - self.multiply(x.imag, self.w1[1]) + self.b1[0], inplace=True) 75 | x_imag = F.relu(self.multiply(x.real, self.w1[1]) + self.multiply(x.imag, self.w1[0]) + self.b1[1], inplace=True) 76 | o2_real = self.multiply(x_real, self.w2[0]) - self.multiply(x_imag, self.w2[1]) + self.b2[0] # keep first-layer outputs intact: 77 | o2_imag = self.multiply(x_real, self.w2[1]) + self.multiply(x_imag, self.w2[0]) + self.b2[1] # both halves of the complex product must read the same inputs 78 | 79 | x = torch.stack([o2_real, o2_imag], dim=-1) 80 | x = F.softshrink(x, lambd=self.softshrink) if self.softshrink else x 81 | 82 | x = torch.view_as_complex(x) 83 | x = x.reshape(B, x.shape[1], x.shape[2], self.hidden_size) 84 | x = torch.fft.irfft2(x, s=(self.h, self.w), dim=(1, 2), norm='ortho') 85 | x = x.reshape(B, N, C) 86 | 87 | return x + bias 88 | 89 | 90 | class Block(nn.Module): 91 | def __init__(self, dim, mlp_ratio=4., drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, h=14, w=8): 92 | super().__init__() 93 | args = get_args() 94 | self.norm1 = norm_layer(dim) 95 | self.filter = AdaptiveFourierNeuralOperator(dim, h=h, w=w) 96 | 97 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 98 | self.norm2 = norm_layer(dim) 99 | mlp_hidden_dim = int(dim * mlp_ratio) 100 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 101 | 102 | self.double_skip = args.double_skip 103 | 104 | def forward(self, x): 105 | residual = x 106 | x = self.norm1(x) 107 | x = self.filter(x) 108 | 109 | if self.double_skip: 110 | x += residual 111 | residual = x 112 | 113 | x = self.norm2(x) 114 | x = self.mlp(x) 115 | x = self.drop_path(x) 116 | x += residual 117 | return x 118 | 119 | 120 | class PatchEmbed(nn.Module): 121 | def __init__(self, img_size=None, patch_size=8, in_chans=13, embed_dim=768): 122 | super().__init__() 123 | 124 | if img_size is None: 125 | raise KeyError('img is None') 126 | 127 | patch_size = to_2tuple(patch_size) 128 | 129 | num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) 130 | self.img_size = img_size 131 | self.patch_size = patch_size 132 | self.num_patches = num_patches 133 | 134 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 135 | 136 | def forward(self, x): 137 | B, C, H, W = x.shape 138 | # FIXME look at relaxing size constraints 139 | assert H == self.img_size[0] and W == self.img_size[1], f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
140 | x = self.proj(x).flatten(2).transpose(1, 2) 141 | return x 142 | 143 | 144 | class AFNONet(nn.Module): 145 | def __init__(self, img_size=None, patch_size=8, in_chans=20, out_chans=20, embed_dim=768, depth=12, mlp_ratio=4., 146 | uniform_drop=False, drop_rate=0., drop_path_rate=0., norm_layer=None, dropcls=0): 147 | super().__init__() 148 | 149 | if img_size is None: 150 | img_size = [720, 1440] 151 | 152 | self.embed_dim = embed_dim 153 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) 154 | 155 | self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 156 | num_patches = self.patch_embed.num_patches 157 | 158 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) 159 | self.pos_drop = nn.Dropout(p=drop_rate) 160 | 161 | self.h = img_size[0] // patch_size 162 | self.w = img_size[1] // patch_size 163 | 164 | if uniform_drop: 165 | dpr = [drop_path_rate for _ in range(depth)] # stochastic depth decay rule 166 | else: 167 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 168 | 169 | self.blocks = nn.ModuleList([Block(dim=embed_dim, mlp_ratio=mlp_ratio, drop=drop_rate, drop_path=dpr[i], norm_layer=norm_layer, h=self.h, w=self.w) for i in range(depth)]) 170 | self.norm = norm_layer(embed_dim) 171 | 172 | # Representation layer 173 | # self.num_features = out_chans * img_size[0] * img_size[1] 174 | # self.representation_size = self.num_features * 8 175 | # self.pre_logits = nn.Sequential(OrderedDict([ 176 | # ('fc', nn.Linear(embed_dim, self.representation_size)), 177 | # ('act', nn.Tanh()) 178 | # ])) 179 | self.pre_logits = nn.Sequential(OrderedDict([ 180 | ('conv1', nn.ConvTranspose2d(embed_dim, out_chans*16, kernel_size=(2, 2), stride=(2, 2))), 181 | ('act1', nn.Tanh()), 182 | ('conv2', nn.ConvTranspose2d(out_chans*16, out_chans*4, kernel_size=(2, 2), stride=(2, 2))), 183 | ('act2', nn.Tanh()) 184 | ])) 185 | 186 | # Generator 
head 187 | # self.head = nn.Linear(self.representation_size, self.num_features) 188 | self.head = nn.ConvTranspose2d(out_chans*4, out_chans, kernel_size=(2, 2), stride=(2, 2)) 189 | 190 | if dropcls > 0: 191 | print('dropout %.2f before classifier' % dropcls) 192 | self.final_dropout = nn.Dropout(p=dropcls) 193 | else: 194 | self.final_dropout = nn.Identity() 195 | 196 | trunc_normal_(self.pos_embed, std=.02) 197 | self.apply(self._init_weights) 198 | 199 | def _init_weights(self, m): 200 | if isinstance(m, nn.Linear): 201 | trunc_normal_(m.weight, std=.02) 202 | if isinstance(m, nn.Linear) and m.bias is not None: 203 | nn.init.constant_(m.bias, 0) 204 | elif isinstance(m, nn.LayerNorm): 205 | nn.init.constant_(m.bias, 0) 206 | nn.init.constant_(m.weight, 1.0) 207 | 208 | @torch.jit.ignore 209 | def no_weight_decay(self): 210 | return {'pos_embed', 'cls_token'} 211 | 212 | def forward_features(self, x): 213 | B = x.shape[0] 214 | x = self.patch_embed(x) 215 | x += self.pos_embed 216 | x = self.pos_drop(x) 217 | 218 | if not get_args().checkpoint_activations: 219 | for blk in self.blocks: 220 | x = blk(x) 221 | else: 222 | x = checkpoint_sequential(self.blocks, 4, x) 223 | 224 | x = self.norm(x).transpose(1, 2) 225 | x = torch.reshape(x, [-1, self.embed_dim, self.h, self.w]) 226 | return x 227 | 228 | def forward(self, x): 229 | x = self.forward_features(x) 230 | x = self.final_dropout(x) 231 | x = self.pre_logits(x) 232 | x = self.head(x) 233 | return x 234 | -------------------------------------------------------------------------------- /model/graphcast_sequential.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch_scatter import scatter 4 | from haiscale.pipeline import SequentialModel 5 | 6 | 7 | class FeatEmbedding(torch.nn.Module): 8 | def __init__(self, args): 9 | super(FeatEmbedding, self).__init__() 10 | 11 | gdim, mdim, edim = args.grid_node_dim, args.mesh_node_dim, 
args.edge_dim 12 | gemb, memb, eemb = args.grid_node_embed_dim, args.mesh_node_embed_dim, args.edge_embed_dim 13 | 14 | # Embedding the input features 15 | self.grid_feat_embedding = nn.Sequential( 16 | nn.Linear(gdim, gemb, bias=True) 17 | ) 18 | self.mesh_feat_embedding = nn.Sequential( 19 | nn.Linear(mdim, memb, bias=True) 20 | ) 21 | self.mesh_edge_feat_embedding = nn.Sequential( 22 | nn.Linear(edim, eemb, bias=True) 23 | ) 24 | self.grid2mesh_edge_feat_embedding = nn.Sequential( 25 | nn.Linear(edim, eemb, bias=True) 26 | ) 27 | self.mesh2grid_edge_feat_embedding = nn.Sequential( 28 | nn.Linear(edim, eemb, bias=True) 29 | ) 30 | 31 | def forward(self, gx, mx, me_i, me_x, g2me_i, g2me_x, m2ge_i, m2ge_x): 32 | 33 | bs = gx.size(0) 34 | 35 | gx = self.grid_feat_embedding(gx) 36 | mx = self.mesh_feat_embedding(mx).repeat(bs, 1, 1) 37 | me_x = self.mesh_edge_feat_embedding(me_x).repeat(bs, 1, 1) 38 | g2me_x = self.grid2mesh_edge_feat_embedding(g2me_x).repeat(bs, 1, 1) 39 | m2ge_x = self.mesh2grid_edge_feat_embedding(m2ge_x).repeat(bs, 1, 1) 40 | 41 | return gx, mx, me_i, me_x, g2me_i, g2me_x, m2ge_i, m2ge_x 42 | 43 | 44 | class Grid2MeshEdgeUpdate(torch.nn.Module): 45 | def __init__(self, args): 46 | super(Grid2MeshEdgeUpdate, self).__init__() 47 | 48 | g2m_enum = args.grid2mesh_edge_num 49 | gemb, memb, eemb = args.grid_node_embed_dim, args.mesh_node_embed_dim, args.edge_embed_dim 50 | 51 | # Grid2Mesh GNN 52 | self.grid2mesh_edge_update = nn.Sequential( 53 | nn.Linear(gemb + memb + eemb, 512, bias=True), 54 | nn.SiLU(), 55 | nn.Linear(512, 64, bias=True), 56 | nn.SiLU(), 57 | nn.Linear(64, eemb, bias=True), 58 | nn.LayerNorm([g2m_enum, eemb]) 59 | ) 60 | 61 | def forward(self, gx, mx, me_i, me_x, g2me_i, g2me_x, m2ge_i, m2ge_x): 62 | 63 | row, col = g2me_i 64 | 65 | # edge update 66 | edge_attr_updated = torch.cat([gx[:, row], mx[:, col], g2me_x], dim=-1) 67 | edge_attr_updated = self.grid2mesh_edge_update(edge_attr_updated) 68 | 69 | # residual 70 | g2me_x = 
g2me_x + edge_attr_updated 71 | 72 | return gx, mx, me_i, me_x, g2me_i, g2me_x, m2ge_i, m2ge_x 73 | 74 | 75 | class Grid2MeshNodeUpdate(torch.nn.Module): 76 | def __init__(self, args): 77 | super(Grid2MeshNodeUpdate, self).__init__() 78 | 79 | gnum, mnum = args.grid_node_num, args.mesh_node_num 80 | gemb, memb, eemb = args.grid_node_embed_dim, args.mesh_node_embed_dim, args.edge_embed_dim 81 | 82 | # Grid2Mesh GNN 83 | self.grid2mesh_node_aggregate = nn.Sequential( 84 | nn.Linear(memb + eemb, 512, bias=True), 85 | nn.SiLU(), 86 | nn.Linear(512, 256, bias=True), 87 | nn.SiLU(), 88 | nn.Linear(256, memb, bias=True), 89 | nn.LayerNorm([mnum, memb]) 90 | ) 91 | self.grid2mesh_grid_update = nn.Sequential( 92 | nn.Linear(gemb, 256, bias=True), 93 | nn.SiLU(), 94 | nn.Linear(256, gemb, bias=True), 95 | nn.LayerNorm([gnum, gemb]) 96 | ) 97 | 98 | def forward(self, gx, mx, me_i, me_x, g2me_i, g2me_x, m2ge_i, m2ge_x): 99 | 100 | row, col = g2me_i 101 | 102 | # mesh node update 103 | edge_agg = scatter(g2me_x, col, dim=-2, reduce='sum') 104 | mesh_node_updated = torch.cat([mx, edge_agg], dim=-1) 105 | mesh_node_updated = self.grid2mesh_node_aggregate(mesh_node_updated) 106 | 107 | # grid node update 108 | grid_node_updated = self.grid2mesh_grid_update(gx) 109 | 110 | # residual 111 | gx = gx + grid_node_updated 112 | mx = mx + mesh_node_updated 113 | 114 | return gx, mx, me_i, me_x, g2me_i, g2me_x, m2ge_i, m2ge_x 115 | 116 | 117 | class MeshEdgeUpdate(torch.nn.Module): 118 | def __init__(self, args): 119 | super(MeshEdgeUpdate, self).__init__() 120 | 121 | m_enum = args.mesh_edge_num 122 | memb, eemb = args.mesh_node_embed_dim, args.edge_embed_dim 123 | 124 | # Multi-mesh GNN 125 | self.mesh_edge_update = nn.Sequential( 126 | nn.Linear(memb + memb + eemb, 512, bias=True), 127 | nn.SiLU(), 128 | nn.Linear(512, 64, bias=True), 129 | nn.SiLU(), 130 | nn.Linear(64, eemb, bias=True), 131 | nn.LayerNorm([m_enum, eemb]) 132 | ) 133 | 134 | def forward(self, gx, mx, me_i, me_x, 
g2me_i, g2me_x, m2ge_i, m2ge_x): 135 | 136 | row, col = me_i 137 | 138 | # edge update 139 | edge_attr_updated = torch.cat([mx[:, row], mx[:, col], me_x], dim=-1) 140 | edge_attr_updated = self.mesh_edge_update(edge_attr_updated) 141 | 142 | # residual 143 | me_x = me_x + edge_attr_updated 144 | 145 | return gx, mx, me_i, me_x, g2me_i, g2me_x, m2ge_i, m2ge_x 146 | 147 | 148 | class MeshNodeUpdate(torch.nn.Module): 149 | def __init__(self, args): 150 | super(MeshNodeUpdate, self).__init__() 151 | 152 | mnum = args.mesh_node_num 153 | memb, eemb = args.mesh_node_embed_dim, args.edge_embed_dim 154 | 155 | # Grid2Mesh GNN 156 | self.mesh_node_aggregate = nn.Sequential( 157 | nn.Linear(memb + eemb, 512, bias=True), 158 | nn.SiLU(), 159 | nn.Linear(512, 256, bias=True), 160 | nn.SiLU(), 161 | nn.Linear(256, memb, bias=True), 162 | nn.LayerNorm([mnum, memb]) 163 | ) 164 | 165 | def forward(self, gx, mx, me_i, me_x, g2me_i, g2me_x, m2ge_i, m2ge_x): 166 | 167 | row, col = me_i 168 | 169 | # mesh node update 170 | edge_agg = scatter(me_x, col, dim=-2, reduce='sum') 171 | mesh_node_updated = torch.cat([mx, edge_agg], dim=-1) 172 | mesh_node_updated = self.mesh_node_aggregate(mesh_node_updated) 173 | 174 | # residual 175 | mx = mx + mesh_node_updated 176 | 177 | return gx, mx, me_i, me_x, g2me_i, g2me_x, m2ge_i, m2ge_x 178 | 179 | 180 | class Mesh2GridEdgeUpdate(torch.nn.Module): 181 | def __init__(self, args): 182 | super(Mesh2GridEdgeUpdate, self).__init__() 183 | 184 | m2g_enum = args.mesh2grid_edge_num 185 | gemb, memb, eemb = args.grid_node_embed_dim, args.mesh_node_embed_dim, args.edge_embed_dim 186 | 187 | # Mesh2grid GNN 188 | self.mesh2grid_edge_update = nn.Sequential( 189 | nn.Linear(gemb + memb + eemb, 512, bias=True), 190 | nn.SiLU(), 191 | nn.Linear(512, 64, bias=True), 192 | nn.SiLU(), 193 | nn.Linear(64, eemb, bias=True), 194 | nn.LayerNorm([m2g_enum, eemb]) 195 | ) 196 | 197 | def forward(self, gx, mx, me_i, me_x, g2me_i, g2me_x, m2ge_i, m2ge_x): 198 | 199 | 
row, col = m2ge_i 200 | 201 | # edge update 202 | edge_attr_updated = torch.cat([mx[:, row], gx[:, col], m2ge_x], dim=-1) 203 | edge_attr_updated = self.mesh2grid_edge_update(edge_attr_updated) 204 | 205 | # residual 206 | m2ge_x = m2ge_x + edge_attr_updated 207 | 208 | return gx, mx, me_i, me_x, g2me_i, g2me_x, m2ge_i, m2ge_x 209 | 210 | 211 | class Mesh2GridNodeUpdate(torch.nn.Module): 212 | def __init__(self, args): 213 | super(Mesh2GridNodeUpdate, self).__init__() 214 | 215 | gnum = args.grid_node_num 216 | gemb, eemb = args.grid_node_embed_dim, args.edge_embed_dim 217 | 218 | # Mesh2grid GNN 219 | self.mesh2grid_node_aggregate = nn.Sequential( 220 | nn.Linear(gemb + eemb, 512, bias=True), 221 | nn.SiLU(), 222 | nn.Linear(512, 256, bias=True), 223 | nn.SiLU(), 224 | nn.Linear(256, gemb, bias=True), 225 | nn.LayerNorm([gnum, gemb]) 226 | ) 227 | 228 | def forward(self, gx, mx, me_i, me_x, g2me_i, g2me_x, m2ge_i, m2ge_x): 229 | 230 | row, col = m2ge_i 231 | 232 | # mesh node update 233 | edge_agg = scatter(m2ge_x, col, dim=-2, reduce='sum') 234 | grid_node_updated = torch.cat([gx, edge_agg], dim=-1) 235 | grid_node_updated = self.mesh2grid_node_aggregate(grid_node_updated) 236 | 237 | # residual 238 | gx = gx + grid_node_updated 239 | 240 | return gx, mx, me_i, me_x, g2me_i, g2me_x, m2ge_i, m2ge_x 241 | 242 | 243 | class PredictNet(torch.nn.Module): 244 | def __init__(self, args): 245 | super(PredictNet, self).__init__() 246 | 247 | gemb = args.grid_node_embed_dim 248 | pred_dim = args.grid_node_pred_dim 249 | 250 | # prediction 251 | self.predict_nn = nn.Sequential( 252 | nn.Linear(gemb, 256, bias=True), 253 | nn.SiLU(), 254 | nn.Linear(256, 64, bias=True), 255 | nn.SiLU(), 256 | nn.Linear(64, pred_dim, bias=True) 257 | ) 258 | 259 | def forward(self, gx, mx, me_i, me_x, g2me_i, g2me_x, m2ge_i, m2ge_x): 260 | 261 | # output 262 | gx = self.predict_nn(gx) 263 | 264 | return gx 265 | 266 | 267 | def get_graphcast_module(args): 268 | embed = FeatEmbedding(args) 269 
| gnn_blocks = [ 270 | Grid2MeshEdgeUpdate(args), 271 | Grid2MeshNodeUpdate(args), 272 | MeshEdgeUpdate(args), 273 | MeshNodeUpdate(args), 274 | Mesh2GridEdgeUpdate(args), 275 | Mesh2GridNodeUpdate(args), 276 | ] 277 | head = PredictNet(args) 278 | layers = [embed] + gnn_blocks + [head] 279 | 280 | return SequentialModel(*layers) 281 | -------------------------------------------------------------------------------- /train_fourcastnet.py: -------------------------------------------------------------------------------- 1 | import hfai 2 | hfai.set_watchdog_time(21600) 3 | 4 | import os 5 | from pathlib import Path 6 | import numpy as np 7 | import torch 8 | import torch.backends.cudnn as cudnn 9 | from torch.utils.data.distributed import DistributedSampler 10 | from functools import partial 11 | import hfai.nccl.distributed as dist 12 | from torch.nn.parallel import DistributedDataParallel 13 | import timm.optim 14 | from timm.scheduler import create_scheduler 15 | 16 | torch.backends.cuda.matmul.allow_tf32 = True 17 | torch.backends.cudnn.allow_tf32 = True 18 | 19 | from data_factory.datasets import ERA5 20 | from model.afnonet import AFNONet 21 | from utils.params import get_fourcastnet_args 22 | from utils.tools import getModelSize, load_model, save_model 23 | from utils.eval import fourcastnet_pretrain_evaluate, fourcastnet_finetune_evaluate 24 | 25 | SAVE_PATH = Path('./output/model/fourcastnet/') 26 | SAVE_PATH.mkdir(parents=True, exist_ok=True) 27 | 28 | 29 | def pretrain_one_epoch(epoch, start_step, model, criterion, data_loader, optimizer, loss_scaler, lr_scheduler, min_loss): 30 | loss_val = torch.tensor(0., device="cuda") 31 | count = torch.tensor(1e-5, device="cuda") 32 | 33 | model.train() 34 | 35 | for step, batch in enumerate(data_loader): 36 | if step < start_step: 37 | continue 38 | 39 | _, x, y = [x.half().cuda(non_blocking=True) for x in batch] 40 | x = x.transpose(3, 2).transpose(2, 1) 41 | y = y.transpose(3, 2).transpose(2, 1) 42 | 43 | with 
torch.cuda.amp.autocast(): 44 | out = model(x) 45 | loss = criterion(out, y) 46 | if torch.isnan(loss).int().sum() == 0: 47 | count += 1 48 | loss_val += loss 49 | 50 | loss_scaler.scale(loss).backward() 51 | loss_scaler.step(optimizer) 52 | loss_scaler.update() 53 | optimizer.zero_grad() 54 | 55 | if dist.get_rank() == 0 and hfai.client.receive_suspend_command(): 56 | save_model(model.module, epoch, step+1, optimizer, lr_scheduler, loss_scaler, min_loss, SAVE_PATH/'pretrain_latest.pt') 57 | hfai.client.go_suspend() 58 | 59 | return loss_val.item() / count.item() 60 | 61 | 62 | def finetune_one_epoch(epoch, start_step, model, criterion, data_loader, optimizer, loss_scaler, lr_scheduler, min_loss): 63 | loss_val = torch.tensor(0., device="cuda") 64 | count = torch.tensor(1e-5, device="cuda") 65 | 66 | model.train() 67 | 68 | for step, batch in enumerate(data_loader): 69 | if step < start_step: 70 | continue 71 | 72 | xt0, xt1, xt2 = [x.half().cuda(non_blocking=True) for x in batch] 73 | xt0 = xt0.transpose(3, 2).transpose(2, 1) 74 | xt1 = xt1.transpose(3, 2).transpose(2, 1) 75 | xt2 = xt2.transpose(3, 2).transpose(2, 1) 76 | 77 | with torch.cuda.amp.autocast(): 78 | out = model(xt0) 79 | loss = criterion(out, xt1) 80 | out = model(out) 81 | loss += criterion(out, xt2) 82 | if torch.isnan(loss).int().sum() == 0: 83 | count += 1 84 | loss_val += loss 85 | 86 | loss_scaler.scale(loss).backward() 87 | loss_scaler.step(optimizer) 88 | loss_scaler.update() 89 | optimizer.zero_grad() 90 | 91 | if dist.get_rank() == 0 and hfai.client.receive_suspend_command(): 92 | save_model(model.module, epoch, step + 1, optimizer, lr_scheduler, loss_scaler, min_loss, SAVE_PATH / 'finetune_latest.pt') 93 | hfai.go_suspend() 94 | 95 | return loss_val.item() / count.item() 96 | 97 | 98 | def train(local_rank, rank, args): 99 | # input size 100 | h, w = 720, 1440 101 | x_c, y_c = 24, 20 102 | 103 | model = AFNONet(img_size=[h, w], in_chans=x_c, out_chans=y_c, 
norm_layer=partial(torch.nn.LayerNorm, eps=1e-6)) 104 | model = hfai.nn.to_hfai(model) 105 | if local_rank == 0: 106 | param_sum, buffer_sum, all_size = getModelSize(model) 107 | print(f"Number of Parameters: {param_sum}, Number of Buffers: {buffer_sum}, Size of Model: {all_size:.4f} MB") 108 | model = DistributedDataParallel(model.cuda(), device_ids=[local_rank]) 109 | 110 | param_groups = timm.optim.optim_factory.add_weight_decay(model, args.weight_decay) 111 | optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95)) 112 | loss_scaler = torch.cuda.amp.GradScaler(enabled=True) 113 | lr_scheduler, _ = create_scheduler(args, optimizer) 114 | criterion = torch.nn.MSELoss() 115 | 116 | train_dataset = ERA5(split="train", check_data=True, modelname='fourcastnet') 117 | train_datasampler = DistributedSampler(train_dataset, shuffle=True) 118 | train_dataloader = train_dataset.loader(args.batch_size, sampler=train_datasampler, num_workers=8, pin_memory=True, drop_last=False) 119 | val_dataset = ERA5(split="val", check_data=True, modelname='fourcastnet') 120 | val_datasampler = DistributedSampler(val_dataset) 121 | val_dataloader = val_dataset.loader(args.batch_size, sampler=val_datasampler, num_workers=8, pin_memory=True, drop_last=False) 122 | 123 | # load 124 | start_epoch, start_step, min_loss = load_model(model.module, optimizer, lr_scheduler, loss_scaler, SAVE_PATH / 'pretrain_latest.pt') 125 | if local_rank == 0: 126 | print(f"Start pretrain for {args.pretrain_epochs} epochs") 127 | 128 | for epoch in range(start_epoch, args.pretrain_epochs): 129 | 130 | train_loss = pretrain_one_epoch(epoch, start_step, model, criterion, train_dataloader, optimizer, loss_scaler, lr_scheduler, min_loss) 131 | start_step = 0 132 | lr_scheduler.step(epoch) 133 | 134 | val_loss = fourcastnet_pretrain_evaluate(val_dataloader, model, criterion) 135 | 136 | if rank == 0 and local_rank == 0: 137 | print(f"Epoch {epoch} | Train loss: {train_loss:.6f}, Val loss: 
{val_loss:.6f}") 138 | if val_loss < min_loss: 139 | min_loss = val_loss 140 | save_model(model.module, path=SAVE_PATH / 'backbone.pt', only_model=True) 141 | save_model(model.module, epoch + 1, 0, optimizer, lr_scheduler, loss_scaler, min_loss, SAVE_PATH / 'pretrain_latest.pt') 142 | 143 | 144 | # load 145 | start_epoch, start_step, min_loss = load_model(model.module, optimizer, lr_scheduler, loss_scaler, SAVE_PATH / 'finetune_latest.pt') 146 | if local_rank == 0: 147 | print(f"Start finetune for {args.finetune_epochs} epochs") 148 | 149 | for epoch in range(start_epoch, args.finetune_epochs): 150 | 151 | train_loss = finetune_one_epoch(epoch, start_step, model, criterion, train_dataloader, optimizer, loss_scaler, lr_scheduler, min_loss) 152 | start_step = 0 153 | lr_scheduler.step(epoch) 154 | 155 | val_loss = fourcastnet_finetune_evaluate(val_dataloader, model, criterion) 156 | 157 | if rank == 0 and local_rank == 0: 158 | print(f"Epoch {epoch} | Train loss: {train_loss:.6f}, Val loss: {val_loss:.6f}") 159 | if val_loss < min_loss: 160 | min_loss = val_loss 161 | save_model(model.module, path=SAVE_PATH / 'backbone.pt', only_model=True) 162 | save_model(model.module, epoch + 1, 0, optimizer, lr_scheduler, loss_scaler, min_loss, SAVE_PATH / 'finetune_latest.pt') 163 | 164 | 165 | def main(local_rank, args): 166 | # fix the seed for reproducibility 167 | torch.manual_seed(42) 168 | np.random.seed(42) 169 | cudnn.benchmark = True 170 | 171 | # init dist 172 | ip = os.environ.get("MASTER_ADDR", "127.0.0.1") 173 | port = os.environ.get("MASTER_PORT", "54247") 174 | hosts = int(os.environ.get("WORLD_SIZE", "1")) # number of nodes 175 | rank = int(os.environ.get("RANK", "0")) # node id 176 | gpus = torch.cuda.device_count() # gpus per node 177 | 178 | dist.init_process_group(backend="nccl", init_method=f"tcp://{ip}:{port}", world_size=hosts * gpus, rank=rank * gpus + local_rank) 179 | torch.cuda.set_device(local_rank) 180 | 181 | train(local_rank, rank, args) 182 | 183 
| 184 | if __name__ == '__main__': 185 | args = get_fourcastnet_args() 186 | ngpus = torch.cuda.device_count() 187 | hfai.multiprocessing.spawn(main, args=(args,), nprocs=ngpus, bind_numa=True) 188 | 189 | 190 | 191 | -------------------------------------------------------------------------------- /train_graphcast.py: -------------------------------------------------------------------------------- 1 | import hfai 2 | hfai.set_watchdog_time(21600) 3 | 4 | import os 5 | from pathlib import Path 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | import torch.backends.cudnn as cudnn 10 | import hfai.nccl.distributed as dist 11 | from haiscale.ddp import DistributedDataParallel 12 | from haiscale.pipeline import PipeDream, make_subgroups, partition 13 | from torch.utils.data.distributed import DistributedSampler 14 | import timm.optim 15 | from timm.scheduler import create_scheduler 16 | 17 | torch.backends.cuda.matmul.allow_tf32 = True 18 | torch.backends.cudnn.allow_tf32 = True 19 | 20 | from data_factory.datasets import ERA5, EarthGraph 21 | from model.graphcast_sequential import get_graphcast_module 22 | from utils.params import get_graphcast_args 23 | from utils.tools import load_model, save_model 24 | from utils.eval import graphcast_evaluate 25 | 26 | SAVE_PATH = Path('./output/graphcast/') 27 | SAVE_PATH.mkdir(parents=True, exist_ok=True) 28 | 29 | 30 | def train_one_epoch(epoch, model, criterion, data_loader, graph, optimizer, lr_scheduler, min_loss, dp_group, pp_group): 31 | is_last_pipeline_stage = (pp_group.rank() == pp_group.size() - 1) 32 | loss = torch.tensor(0., device="cuda") 33 | count = torch.tensor(0., device="cuda") 34 | model.train() 35 | 36 | input_x = [ 37 | None, 38 | graph.mesh_data.x.half().cuda(non_blocking=True), 39 | graph.mesh_data.edge_index.cuda(non_blocking=True), 40 | graph.mesh_data.edge_attr.half().cuda(non_blocking=True), 41 | graph.grid2mesh_data.edge_index.cuda(non_blocking=True), 42 | 
graph.grid2mesh_data.edge_attr.half().cuda(non_blocking=True), 43 | graph.mesh2grid_data.edge_index.cuda(non_blocking=True), 44 | graph.mesh2grid_data.edge_attr.half().cuda(non_blocking=True) 45 | ] 46 | 47 | for step, batch in enumerate(data_loader): 48 | 49 | x, y = [x.half().cuda(non_blocking=True) for x in batch] 50 | input_x[0] = x 51 | 52 | with torch.cuda.amp.autocast(): 53 | # out = model(input_x) 54 | step_loss, _ = model.forward_backward(*input_x, criterion=criterion, labels=(y,)) 55 | 56 | optimizer.step() 57 | optimizer.zero_grad() 58 | 59 | if is_last_pipeline_stage: 60 | loss += step_loss.sum().item() 61 | count += 1 62 | 63 | if dp_group.rank() == 0 and is_last_pipeline_stage and hfai.client.receive_suspend_command(): 64 | save_model(model.module.module, epoch, step + 1, optimizer, lr_scheduler, min_loss, SAVE_PATH / 'latest.pt') 65 | hfai.client.go_suspend() 66 | 67 | # all-reduce in data parallel group (every data-parallel rank of the last stage must participate) 68 | if is_last_pipeline_stage: 69 | dist.all_reduce(loss, group=dp_group) 70 | dist.all_reduce(count, group=dp_group) 71 | loss = loss / count 72 | 73 | # broadcast from the last stage to other pipeline stages 74 | dist.all_reduce(loss, group=pp_group) 75 | 76 | return loss.item() 77 | 78 | 79 | def train(local_rank, args): 80 | rank, world_size = dist.get_rank(), dist.get_world_size() 81 | 82 | # data parallel + pipeline parallel 83 | dp_group, pp_group = make_subgroups(pp_size=args.pp_size) 84 | dp_rank, dp_size = dp_group.rank(), dp_group.size() 85 | pp_rank, pp_size = pp_group.rank(), pp_group.size() 86 | is_last_pipeline_stage = (pp_group.rank() == pp_group.size() - 1) 87 | print(f"RANK {rank}: data parallel {dp_rank}/{dp_size}, pipeline parallel {pp_rank}/{pp_size}", flush=True) 88 | 89 | # model & criterion & optimizer 90 | model = get_graphcast_module(args) 91 | # model = hfai.nn.to_hfai(model) 92 | balance = [1, 1, 1, 1, 1, 1, 1, 1] 93 | model = partition(model, pp_group.rank(), pp_group.size(), balance=balance) 94 | 95 | 
model = DistributedDataParallel(model.cuda(), process_group=dp_group) 96 | model = PipeDream(model, args.chunks, process_group=pp_group) 97 | 98 | # args.lr = args.lr * args.batch_size * dist.get_world_size() / 512.0 99 | param_groups = timm.optim.optim_factory.add_weight_decay(model, args.weight_decay) 100 | optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95)) 101 | lr_scheduler, _ = create_scheduler(args, optimizer) 102 | criterion = nn.MSELoss() 103 | 104 | # generate graph 105 | graph = EarthGraph() 106 | graph.generate_graph() 107 | 108 | # load grid data 109 | train_dataset = ERA5(split="train", check_data=True, modelname='graphcast') 110 | train_datasampler = DistributedSampler(train_dataset, num_replicas=dp_size, rank=dp_rank, shuffle=True) 111 | train_dataloader = train_dataset.loader(args.batch_size, sampler=train_datasampler, num_workers=8, pin_memory=True, drop_last=True) 112 | val_dataset = ERA5(split="val", check_data=True, modelname='graphcast') 113 | val_datasampler = DistributedSampler(val_dataset, num_replicas=dp_size, rank=dp_rank, shuffle=True) 114 | val_dataloader = val_dataset.loader(args.batch_size, sampler=val_datasampler, num_workers=8, pin_memory=True, drop_last=False) 115 | 116 | # load 117 | start_epoch, min_loss = load_model(model.module.module, optimizer, lr_scheduler, SAVE_PATH / 'latest.pt') 118 | if local_rank == 0: 119 | print(f"Start training for {args.epochs} epochs") 120 | 121 | for epoch in range(start_epoch, args.epochs): 122 | 123 | train_loss = train_one_epoch(epoch, model, criterion, train_dataloader, graph, optimizer, lr_scheduler, min_loss, dp_group, pp_group) 124 | lr_scheduler.step(epoch) 125 | 126 | val_loss = graphcast_evaluate(val_dataloader, graph, model, criterion, dp_group, pp_group) 127 | 128 | if is_last_pipeline_stage: 129 | print(f"Epoch {epoch} | Train loss: {train_loss:.6f}, Val loss: {val_loss:.6f}") 130 | 131 | if dp_rank == 0: 132 | save_model(model.module.module, epoch + 1, 
optimizer, lr_scheduler, min_loss, SAVE_PATH / 'latest.pt') 133 | if val_loss < min_loss: 134 | min_loss = val_loss 135 | save_model(model.module.module, path=SAVE_PATH / 'best.pt', only_model=True) 136 | 137 | # synchronize all processes 138 | model.module.reducer.stop() 139 | dist.barrier() 140 | 141 | 142 | def main(local_rank, args): 143 | # fix the seed for reproducibility 144 | torch.manual_seed(2023) 145 | np.random.seed(2023) 146 | cudnn.benchmark = True 147 | 148 | # init dist 149 | ip = os.environ.get("MASTER_ADDR", "127.0.0.1") 150 | port = os.environ.get("MASTER_PORT", "54247") 151 | hosts = int(os.environ.get("WORLD_SIZE", "1")) # number of nodes 152 | rank = int(os.environ.get("RANK", "0")) # node id 153 | gpus = torch.cuda.device_count() # gpus per node 154 | 155 | dist.init_process_group(backend="nccl", init_method=f"tcp://{ip}:{port}", world_size=hosts * gpus, rank=rank * gpus + local_rank) 156 | torch.cuda.set_device(local_rank) 157 | 158 | train(local_rank, args) 159 | 160 | 161 | if __name__ == '__main__': 162 | args = get_graphcast_args() 163 | ngpus = torch.cuda.device_count() 164 | hfai.multiprocessing.spawn(main, args=(args,), nprocs=ngpus, bind_numa=True) 165 | -------------------------------------------------------------------------------- /utils/eval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import hfai.nccl.distributed as dist 3 | 4 | 5 | @torch.no_grad() 6 | def fourcastnet_pretrain_evaluate(data_loader, model, criterion): 7 | loss = torch.tensor(0., device="cuda") 8 | count = torch.tensor(1e-5, device="cuda") 9 | 10 | # switch to evaluation mode 11 | model.eval() 12 | for batch in data_loader: 13 | _, x, y = [x.half().cuda(non_blocking=True) for x in batch] 14 | x = x.transpose(3, 2).transpose(2, 1) 15 | y = y.transpose(3, 2).transpose(2, 1) 16 | 17 | with torch.cuda.amp.autocast(): 18 | out = model(x) 19 | tmp_loss = criterion(out, y) 20 | if 
torch.isnan(tmp_loss).int().sum() == 0: 21 | count += 1 22 | loss += tmp_loss 23 | 24 | dist.reduce(loss, 0) 25 | dist.reduce(count, 0) 26 | 27 | loss_val = 0 28 | if dist.get_rank() == 0: 29 | loss_val = loss.item() / count.item() 30 | return loss_val 31 | 32 | 33 | @torch.no_grad() 34 | def fourcastnet_finetune_evaluate(data_loader, model, criterion): 35 | loss = torch.tensor(0., device="cuda") 36 | count = torch.tensor(1e-5, device="cuda") 37 | 38 | # switch to evaluation mode 39 | model.eval() 40 | for batch in data_loader: 41 | xt0, xt1, xt2 = [x.half().cuda(non_blocking=True) for x in batch] 42 | xt0 = xt0.transpose(3, 2).transpose(2, 1) 43 | xt1 = xt1.transpose(3, 2).transpose(2, 1) 44 | xt2 = xt2.transpose(3, 2).transpose(2, 1) 45 | 46 | with torch.cuda.amp.autocast(): 47 | out = model(xt0) 48 | loss += criterion(out, xt1) 49 | out = model(out) 50 | loss += criterion(out, xt2) 51 | count += 1 52 | 53 | dist.reduce(loss, 0) 54 | dist.reduce(count, 0) 55 | 56 | loss_val = 0 57 | if dist.get_rank() == 0: 58 | loss_val = loss.item() / count.item() 59 | return loss_val 60 | 61 | 62 | @torch.no_grad() 63 | def graphcast_evaluate(data_loader, graph, model, criterion, dp_group, pp_group): 64 | is_last_pipeline_stage = (pp_group.rank() == pp_group.size() - 1) 65 | loss = torch.tensor(0., device="cuda") 66 | count = torch.tensor(0., device="cuda") 67 | 68 | input_x = [ 69 | None, 70 | graph.mesh_data.x.half().cuda(non_blocking=True), 71 | graph.mesh_data.edge_index.cuda(non_blocking=True), 72 | graph.mesh_data.edge_attr.half().cuda(non_blocking=True), 73 | graph.grid2mesh_data.edge_index.cuda(non_blocking=True), 74 | graph.grid2mesh_data.edge_attr.half().cuda(non_blocking=True), 75 | graph.mesh2grid_data.edge_index.cuda(non_blocking=True), 76 | graph.mesh2grid_data.edge_attr.half().cuda(non_blocking=True) 77 | ] 78 | 79 | # switch to evaluation mode 80 | model.eval() 81 | for batch in data_loader: 82 | x, y = [x.half().cuda(non_blocking=True) for x in batch] 83 | 
input_x[0] = x 84 | 85 | with torch.cuda.amp.autocast(): 86 | out = model(*input_x) 87 | 88 | if is_last_pipeline_stage: 89 | loss += criterion(out, y) 90 | count += 1 91 | 92 | # all-reduce in data paralel group 93 | if is_last_pipeline_stage: 94 | dist.all_reduce(loss, group=dp_group) 95 | dist.all_reduce(count, group=dp_group) 96 | loss = loss / count 97 | else: 98 | loss = torch.tensor(0., device="cuda") 99 | 100 | # broadcast from the last stage to other pipeline stages 101 | dist.all_reduce(loss, group=pp_group) 102 | 103 | return loss.item() 104 | -------------------------------------------------------------------------------- /utils/params.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def get_fourcastnet_args(): 5 | parser = argparse.ArgumentParser('FourCastNet training and evaluation script', add_help=False) 6 | parser.add_argument('--batch-size', default=4, type=int) 7 | parser.add_argument('--pretrain-epochs', default=80, type=int) 8 | parser.add_argument('--fintune-epochs', default=25, type=int) 9 | 10 | # Model parameters 11 | parser.add_argument('--arch', default='deit_small', type=str, help='Name of model to train') 12 | 13 | parser.add_argument('--drop', type=float, default=0.0, metavar='PCT', help='Dropout rate (default: 0.)') 14 | parser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT', help='Drop path rate (default: 0.1)') 15 | 16 | # Optimizer parameters 17 | parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER', help='Optimizer (default: "adamw"') 18 | parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON', help='Optimizer Epsilon (default: 1e-8)') 19 | parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA', help='Optimizer Betas (default: None, use opt default)') 20 | parser.add_argument('--clip-grad', type=float, default=1, metavar='NORM', help='Clip gradient norm (default: None, 
no clipping)') 21 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', help='SGD momentum (default: 0.9)') 22 | parser.add_argument('--weight-decay', type=float, default=0.05, help='weight decay (default: 0.05)') 23 | # Learning rate schedule parameters 24 | parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER', help='LR scheduler (default: "cosine"') 25 | parser.add_argument('--lr', type=float, default=5e-4, metavar='LR', help='learning rate (default: 5e-4)') 26 | parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct', help='learning rate noise on/off epoch percentages') 27 | parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT', help='learning rate noise limit percent (default: 0.67)') 28 | parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV', help='learning rate noise std-dev (default: 1.0)') 29 | parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR', help='warmup learning rate (default: 1e-6)') 30 | parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR', help='lower lr bound for cyclic schedulers that hit 0 (1e-5)') 31 | 32 | parser.add_argument('--decay-epochs', type=float, default=30, metavar='N', help='epoch interval to decay LR') 33 | parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N', help='epochs to warmup LR, if scheduler supports') 34 | parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N', help='epochs to cooldown LR at min_lr, after cyclic schedule ends') 35 | parser.add_argument('--patience-epochs', type=int, default=10, metavar='N', help='patience epochs for Plateau LR scheduler (default: 10') 36 | parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE', help='LR decay rate (default: 0.1)') 37 | 38 | # Augmentation parameters 39 | parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT', 
help='Color jitter factor (default: 0.4)') 40 | parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME', help='Use AutoAugment policy. "v0" or "original". "(default: rand-m9-mstd0.5-inc1)'), 41 | parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)') 42 | parser.add_argument('--train-interpolation', type=str, default='bicubic', help='Training interpolation (random, bilinear, bicubic default: "bicubic")') 43 | 44 | parser.add_argument('--repeated-aug', action='store_true') 45 | parser.set_defaults(repeated_aug=False) 46 | 47 | # * Random Erase params 48 | parser.add_argument('--reprob', type=float, default=0, metavar='PCT', help='Random erase prob (default: 0.25)') 49 | parser.add_argument('--remode', type=str, default='pixel', help='Random erase mode (default: "pixel")') 50 | parser.add_argument('--recount', type=int, default=1, help='Random erase count (default: 1)') 51 | parser.add_argument('--resplit', action='store_true', default=False, help='Do not random erase first (clean) augmentation split') 52 | 53 | # fno parameters 54 | parser.add_argument('--fno-bias', action='store_true') 55 | parser.add_argument('--fno-blocks', type=int, default=4) 56 | parser.add_argument('--fno-softshrink', type=float, default=0.00) 57 | parser.add_argument('--double-skip', action='store_true') 58 | parser.add_argument('--tensorboard-dir', type=str, default=None) 59 | parser.add_argument('--hidden-size', type=int, default=256) 60 | parser.add_argument('--num-layers', type=int, default=12) 61 | parser.add_argument('--checkpoint-activations', action='store_true') 62 | parser.add_argument('--autoresume', action='store_true') 63 | 64 | # attention parameters 65 | parser.add_argument('--num-attention-heads', type=int, default=1) 66 | 67 | # long short parameters 68 | parser.add_argument('--ls-w', type=int, default=4) 69 | parser.add_argument('--ls-dp-rank', type=int, default=16) 70 | 71 | return parser.parse_args() 
72 | 
73 | 
74 | def get_graphcast_args():
75 |     parser = argparse.ArgumentParser('Graphcast training and evaluation script', add_help=False)
76 |     parser.add_argument('--batch-size', default=2, type=int)
77 |     parser.add_argument('--epochs', default=200, type=int)
78 | 
79 |     # Model parameters
80 |     parser.add_argument('--grid-node-num', default=720 * 1440, type=int, help='The number of grid nodes')
81 |     parser.add_argument('--mesh-node-num', default=128 * 320, type=int, help='The number of mesh nodes')
82 |     parser.add_argument('--mesh-edge-num', default=217170, type=int, help='The number of mesh edges')
83 |     parser.add_argument('--grid2mesh-edge-num', default=1357920, type=int, help='The number of grid-to-mesh edges')
84 |     parser.add_argument('--mesh2grid-edge-num', default=2230560, type=int, help='The number of mesh-to-grid edges')
85 |     parser.add_argument('--grid-node-dim', default=49, type=int, help='The input dim of grid nodes')
86 |     parser.add_argument('--grid-node-pred-dim', default=20, type=int, help='The output dim of grid-node prediction')
87 |     parser.add_argument('--mesh-node-dim', default=3, type=int, help='The input dim of mesh nodes')
88 |     parser.add_argument('--edge-dim', default=4, type=int, help='The input dim of all edges')
89 |     parser.add_argument('--grid-node-embed-dim', default=64, type=int, help='The embedding dim of grid nodes')
90 |     parser.add_argument('--mesh-node-embed-dim', default=64, type=int, help='The embedding dim of mesh nodes')
91 |     parser.add_argument('--edge-embed-dim', default=8, type=int, help='The embedding dim of edges')
92 | 
93 |     # Optimizer parameters
94 |     parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER', help='Optimizer (default: "adamw")')
95 |     parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON', help='Optimizer Epsilon (default: 1e-8)')
96 |     parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA', help='Optimizer Betas (default: None, use opt default)')
97 |     parser.add_argument('--clip-grad', type=float, default=1, metavar='NORM', help='Clip gradient norm (default: 1)')
98 |     parser.add_argument('--momentum', type=float, default=0.9, metavar='M', help='SGD momentum (default: 0.9)')
99 |     parser.add_argument('--weight-decay', type=float, default=0.05, help='weight decay (default: 0.05)')
100 | 
101 |     # Learning rate schedule parameters
102 |     parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER', help='LR scheduler (default: "cosine")')
103 |     parser.add_argument('--lr', type=float, default=5e-4, metavar='LR', help='learning rate (default: 5e-4)')
104 |     parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct', help='learning rate noise on/off epoch percentages')
105 |     parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT', help='learning rate noise limit percent (default: 0.67)')
106 |     parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV', help='learning rate noise std-dev (default: 1.0)')
107 |     parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR', help='warmup learning rate (default: 1e-6)')
108 |     parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR', help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
109 | 
110 |     parser.add_argument('--decay-epochs', type=float, default=30, metavar='N', help='epoch interval to decay LR')
111 |     parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N', help='epochs to warmup LR, if scheduler supports')
112 |     parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N', help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
113 |     parser.add_argument('--patience-epochs', type=int, default=10, metavar='N', help='patience epochs for Plateau LR scheduler (default: 10)')
114 |     parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE', help='LR decay rate (default: 0.1)')
115 | 
116 |     # Pipeline training parameters
117 |     parser.add_argument('--pp_size', type=int, default=8, help='pipeline parallel size')
118 |     parser.add_argument('--chunks', type=int, default=1, help='number of pipeline micro-batch chunks')
119 | 
120 |     return parser.parse_args()
--------------------------------------------------------------------------------
/utils/tools.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | 
4 | 
5 | def getModelSize(model):
6 |     param_size = 0
7 |     param_sum = 0
8 |     for param in model.parameters():
9 |         param_size += param.nelement() * param.element_size()
10 |         param_sum += param.nelement()
11 | 
12 |     buffer_size = 0
13 |     buffer_sum = 0
14 |     for buffer in model.buffers():
15 |         buffer_size += buffer.nelement() * buffer.element_size()
16 |         buffer_sum += buffer.nelement()
17 | 
18 |     all_size = (param_size + buffer_size) / 1024 / 1024  # total size in MiB
19 |     return param_sum, buffer_sum, all_size
20 | 
21 | 
22 | def load_model(model, optimizer=None, lr_scheduler=None, loss_scaler=None, path=None, only_model=False):
23 | 
24 |     start_epoch, start_step = 0, 0
25 |     min_loss = np.inf
26 |     if path.exists():
27 |         ckpt = torch.load(path, map_location="cpu")
28 | 
29 |         if only_model:
30 |             model.load_state_dict(ckpt['model'])
31 |         else:
32 |             model.load_state_dict(ckpt['model'])
33 |             optimizer.load_state_dict(ckpt['optimizer'])
34 |             lr_scheduler.load_state_dict(ckpt['lr_scheduler'])
35 |             if loss_scaler is not None and ckpt['loss_scaler'] is not None:
36 |                 loss_scaler.load_state_dict(ckpt['loss_scaler'])
37 |             start_epoch = ckpt["epoch"]
38 |             start_step = ckpt["step"]
39 |             min_loss = ckpt["min_loss"]
40 | 
41 |     return start_epoch, start_step, min_loss
42 | 
43 | 
44 | def save_model(model, epoch=0, step=0, optimizer=None, lr_scheduler=None, loss_scaler=None, min_loss=0, path=None, only_model=False):
45 | 
46 |     if only_model:
47 |         states = {
48 |             'model': model.state_dict(),
49 |         }
50 |     else:
51 |         states = {
52 |             'model': model.state_dict(),
53 |             'optimizer': optimizer.state_dict(),
54 |             'lr_scheduler': lr_scheduler.state_dict(),
55 |             'loss_scaler': loss_scaler.state_dict() if loss_scaler is not None else None,
56 |             'epoch': epoch,
57 |             'step': step,
58 |             'min_loss': min_loss
59 |         }
60 | 
61 |     torch.save(states, path)
62 | 
63 | 
--------------------------------------------------------------------------------
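As a sanity check on the defaults in `get_graphcast_args`, the node/edge counts (720×1440 grid nodes, 128×320 mesh nodes, the three edge counts) and embedding dims (64/64/8) imply the following per-tensor footprints. This is a back-of-the-envelope sketch, not measured numbers: `embed_bytes` is an illustrative helper that is not part of the repo, and the 2-byte element size follows the `.half()` (fp16) casts used throughout training.

```python
# Default graph sizes from utils/params.py (get_graphcast_args).
GRID_NODES = 720 * 1440   # 0.25-degree lat/lon grid -> 1,036,800 grid nodes
MESH_NODES = 128 * 320    # refined mesh -> 40,960 mesh nodes
EDGE_COUNTS = {
    "mesh": 217170,
    "grid2mesh": 1357920,
    "mesh2grid": 2230560,
}

def embed_bytes(count, dim, bytes_per_elem=2):
    """Size of one embedding tensor: `count` rows of `dim` fp16 values."""
    return count * dim * bytes_per_elem

# Embedding dims also come from the argparse defaults (64 / 64 / 8).
grid_mb = embed_bytes(GRID_NODES, 64) / 2**20
mesh_mb = embed_bytes(MESH_NODES, 64) / 2**20
edge_mb = sum(embed_bytes(n, 8) for n in EDGE_COUNTS.values()) / 2**20

print(f"grid-node embeddings: {grid_mb:.1f} MiB")   # ~126.6 MiB
print(f"mesh-node embeddings: {mesh_mb:.1f} MiB")   # 5.0 MiB
print(f"edge embeddings:      {edge_mb:.1f} MiB")   # ~58.1 MiB
```

These count a single copy of each embedding tensor per sample; actual training memory is of course higher once MLP weights, activations for backward, and optimizer state are included, which is why the README runs GraphCast with pipeline parallelism.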