├── requirements_cpu.txt ├── requirements_gpu.txt ├── constant_masks ├── land_mask.npy ├── soil_type.npy └── topography.npy ├── inference_cpu.py ├── inference_gpu.py ├── inference_iterative.py ├── README.md └── pseudocode.py /requirements_cpu.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | onnx==1.13.1 3 | onnxruntime==1.14.0 -------------------------------------------------------------------------------- /requirements_gpu.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | onnx==1.12.0 3 | onnxruntime-gpu==1.14.0 -------------------------------------------------------------------------------- /constant_masks/land_mask.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/198808xc/Pangu-Weather/HEAD/constant_masks/land_mask.npy -------------------------------------------------------------------------------- /constant_masks/soil_type.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/198808xc/Pangu-Weather/HEAD/constant_masks/soil_type.npy -------------------------------------------------------------------------------- /constant_masks/topography.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/198808xc/Pangu-Weather/HEAD/constant_masks/topography.npy -------------------------------------------------------------------------------- /inference_cpu.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import onnx 4 | import onnxruntime as ort 5 | 6 | 7 | # The directory of your input and output data 8 | input_data_dir = 'input_data' 9 | output_data_dir = 'output_data' 10 | model_24 = onnx.load('pangu_weather_24.onnx') 11 | 12 | # Set the behavier of onnxruntime 13 | options = ort.SessionOptions() 14 
| options.enable_cpu_mem_arena=False 15 | options.enable_mem_pattern = False 16 | options.enable_mem_reuse = False 17 | # Increase the number for faster inference and more memory consumption 18 | options.intra_op_num_threads = 1 19 | 20 | # Set the behavier of cuda provider 21 | cuda_provider_options = {'arena_extend_strategy':'kSameAsRequested',} 22 | 23 | # Initialize onnxruntime session for Pangu-Weather Models 24 | ort_session_24 = ort.InferenceSession('pangu_weather_24.onnx', sess_options=options, providers=['CPUExecutionProvider']) 25 | 26 | # Load the upper-air numpy arrays 27 | input = np.load(os.path.join(input_data_dir, 'input_upper.npy')).astype(np.float32) 28 | # Load the surface numpy arrays 29 | input_surface = np.load(os.path.join(input_data_dir, 'input_surface.npy')).astype(np.float32) 30 | 31 | # Run the inference session 32 | output, output_surface = ort_session_24.run(None, {'input':input, 'input_surface':input_surface}) 33 | 34 | # Save the results 35 | np.save(os.path.join(output_data_dir, 'output_upper'), output) 36 | np.save(os.path.join(output_data_dir, 'output_surface'), output_surface) -------------------------------------------------------------------------------- /inference_gpu.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import onnx 4 | import onnxruntime as ort 5 | 6 | 7 | # The directory of your input and output data 8 | input_data_dir = 'input_data' 9 | output_data_dir = 'output_data' 10 | model_24 = onnx.load('pangu_weather_24.onnx') 11 | 12 | # Set the behavier of onnxruntime 13 | options = ort.SessionOptions() 14 | options.enable_cpu_mem_arena=False 15 | options.enable_mem_pattern = False 16 | options.enable_mem_reuse = False 17 | # Increase the number for faster inference and more memory consumption 18 | options.intra_op_num_threads = 1 19 | 20 | # Set the behavier of cuda provider 21 | cuda_provider_options = 
{'arena_extend_strategy':'kSameAsRequested',} 22 | 23 | # Initialize onnxruntime session for Pangu-Weather Models 24 | ort_session_24 = ort.InferenceSession('pangu_weather_24.onnx', sess_options=options, providers=[('CUDAExecutionProvider', cuda_provider_options)]) 25 | 26 | # Load the upper-air numpy arrays 27 | input = np.load(os.path.join(input_data_dir, 'input_upper.npy')).astype(np.float32) 28 | # Load the surface numpy arrays 29 | input_surface = np.load(os.path.join(input_data_dir, 'input_surface.npy')).astype(np.float32) 30 | 31 | # Run the inference session 32 | output, output_surface = ort_session_24.run(None, {'input':input, 'input_surface':input_surface}) 33 | # Save the results 34 | np.save(os.path.join(output_data_dir, 'output_upper'), output) 35 | np.save(os.path.join(output_data_dir, 'output_surface'), output_surface) -------------------------------------------------------------------------------- /inference_iterative.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import onnx 4 | import onnxruntime as ort 5 | 6 | 7 | # The directory of your input and output data 8 | input_data_dir = 'input_data' 9 | output_data_dir = 'output_data' 10 | model_24 = onnx.load('pangu_weather_24.onnx') 11 | model_6 = onnx.load('pangu_weather_6.onnx') 12 | 13 | # Set the behavier of onnxruntime 14 | options = ort.SessionOptions() 15 | options.enable_cpu_mem_arena=False 16 | options.enable_mem_pattern = False 17 | options.enable_mem_reuse = False 18 | # Increase the number for faster inference and more memory consumption 19 | options.intra_op_num_threads = 1 20 | 21 | # Set the behavier of cuda provider 22 | cuda_provider_options = {'arena_extend_strategy':'kSameAsRequested',} 23 | 24 | # Initialize onnxruntime session for Pangu-Weather Models 25 | ort_session_24 = ort.InferenceSession('pangu_weather_24.onnx', sess_options=options, providers=[('CUDAExecutionProvider', cuda_provider_options)]) 26 
| ort_session_6 = ort.InferenceSession('pangu_weather_6.onnx', sess_options=options, providers=[('CUDAExecutionProvider', cuda_provider_options)]) 27 | 28 | # Load the upper-air numpy arrays 29 | input = np.load(os.path.join(input_data_dir, 'input_upper.npy')).astype(np.float32) 30 | # Load the surface numpy arrays 31 | input_surface = np.load(os.path.join(input_data_dir, 'input_surface.npy')).astype(np.float32) 32 | 33 | # Run the inference session: 28 steps of 6 hours (7 days); every 4th step uses the 24-hour model on the chain of 24-hour outputs (input_24), the other steps use the 6-hour model 34 | input_24, input_surface_24 = input, input_surface 35 | for i in range(28): 36 | if (i+1) % 4 == 0: 37 | output, output_surface = ort_session_24.run(None, {'input':input_24, 'input_surface':input_surface_24}) 38 | input_24, input_surface_24 = output, output_surface 39 | else: 40 | output, output_surface = ort_session_6.run(None, {'input':input, 'input_surface':input_surface}) 41 | input, input_surface = output, output_surface 42 | # You can save the results (output, output_surface) of each 6-hour step here -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Pangu-Weather 2 | 3 | This is the official repository for the Pangu-Weather papers. 4 | 5 | [Accurate medium-range global weather forecasting with 3D neural networks](https://www.nature.com/articles/s41586-023-06185-3), Nature, Volume 619, Pages 533–538, 2023. 6 | 7 | [Pangu-Weather: A 3D High-Resolution Model for Fast and Accurate Global Weather Forecast](https://arxiv.org/abs/2211.02556), arXiv preprint: 2211.02556, 2022. 8 | 9 | *by Kaifeng Bi, Lingxi Xie, Hengheng Zhang, Xin Chen, Xiaotao Gu and Qi Tian* 10 | 11 | **Note: the arXiv version offers more technical details, and the Nature paper contains some new figures.** 12 | 13 | Resources including pseudocode, pre-trained models, and inference code are released here. 14 | 15 | The slides used in a series of recent talks are attached here. 
[Baidu Netdisk](https://pan.baidu.com/s/1eYkmf0QehEYdR1dk12idNA?pwd=zjj4), extraction code: zjj4 16 | 17 | 中文版PPT请参见链接: [百度网盘](https://pan.baidu.com/s/1oINuxuBstFEbfcouX6qHMw?pwd=7nzb), 提取码: 7nzb 18 | 19 | ## News and Updates 20 | 21 | * [Jul 31 2023] We released the details of training the lite version of Pangu-Weather. 22 | * [Jul 19 2023] ECMWF released an official [technical report](https://arxiv.org/abs/2307.10128) on "the rise of data-driven weather forecasting". Pangu-Weather was mentioned and tested thoroughly in the report. We thank ECMWF for testing our models in real-world scenarios. 23 | * [Jul 17 2023] Pangu-Weather went online as part of ECMWF's operational suite! Everyone can see 10-day global weather forecasts **without running code**. ECMWF has made use of the models released in this repository! [Please search the ECMWF charts website with the query "PANGU".](https://charts.ecmwf.int/?query=pangu) 24 | * [Jul 05 2023] Pangu-Weather was published in [Nature](https://www.nature.com/articles/s41586-023-06185-3). It was made **Open Access**! We recommend that researchers cite our Nature paper in the future. 25 | * [Jun 27 2023] Pangu-Weather was presented at [PASC 2023](https://pasc23.pasc-conference.org/program/schedule/). 26 | * [Jun 12 2023] Pangu-Weather was presented at [VALSE 2023](http://valser.org/2023/#/workshopde?id=15). 27 | * [May 27 2023] Pangu-Weather was presented at [the WMO Early Warning for All (EW4ALL) conference](https://community.wmo.int/en/news/exploring-possibilities-artificial-intelligence-areas-water-weather-and-climate). 28 | * [May 12 2023] ECMWF released a [repository](https://github.com/ecmwf-lab/ai-models-panguweather), offering a toolkit for running Pangu-Weather. We thank ECMWF for their efforts in making it easier for everyone to test Pangu-Weather. 29 | * [May 09 2023] Pangu-Weather was accepted by Nature!
30 | 31 | ## Installation 32 | 33 | The downloaded files shall be organized as the following hierarchy: 34 | 35 | ```plain 36 | ├── root 37 | │ ├── input_data 38 | │ │ ├── input_surface.npy 39 | │ │ ├── input_upper.npy 40 | │ ├── output_data 41 | │ ├── pangu_weather_1.onnx 42 | │ ├── pangu_weather_3.onnx 43 | │ ├── pangu_weather_6.onnx 44 | │ ├── pangu_weather_24.onnx 45 | │ ├── inference_cpu.py 46 | │ ├── inference_gpu.py 47 | │ ├── inference_iterative.py 48 | ``` 49 | 50 | If you use a CPU environment, please run: 51 | ``` 52 | pip install -r requirements_cpu.txt 53 | ``` 54 | 55 | If you use a GPU environment, please first confirm that the cuda version is 11.6 and the cudnn version is the 8.2.4 for Linux and 8.5.0.96 for Windows (please see [this page](https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html) for details). Then, please run: 56 | ``` 57 | pip install -r requirements_gpu.txt 58 | ``` 59 | 60 | ## Global weather forecasting (inference) using the trained models 61 | 62 | #### Downloading trained models 63 | 64 | Please download the four pre-trained models (~1.1GB each) from Google drive or Baidu netdisk: 65 | 66 | The 1-hour model (pangu_weather_1.onnx): [Google drive](https://drive.google.com/file/d/1fg5jkiN_5dHzKb-5H9Aw4MOmfILmeY-S/view?usp=share_link)/[Baidu netdisk](https://pan.baidu.com/s/1M7SAigVsCSH8hpw6DE8TDQ?pwd=ie0h) 67 | 68 | The 3-hour model (pangu_weather_3.onnx): [Google drive](https://drive.google.com/file/d/1EdoLlAXqE9iZLt9Ej9i-JW9LTJ9Jtewt/view?usp=share_link)/[Baidu netdisk](https://pan.baidu.com/s/197fZsoiCqZYzKwM7tyRrfg?pwd=gmcl) 69 | 70 | The 6-hour model (pangu_weather_6.onnx): [Google drive](https://drive.google.com/file/d/1a4XTktkZa5GCtjQxDJb_fNaqTAUiEJu4/view?usp=share_link)/[Baidu netdisk](https://pan.baidu.com/s/1q7IB7tNjqIwoGC7KVMPn4w?pwd=vxq3) 71 | 72 | The 24-hour model (pangu_weather_24.onnx): [Google drive](https://drive.google.com/file/d/1lweQlxcn9fG0zKNW8ne1Khr9ehRTI6HP/view?usp=share_link)/[Baidu 
netdisk](https://pan.baidu.com/s/179q2gkz2BrsOR6g3yfTVQg?pwd=eajy) 73 | 74 | These models are stored in the ONNX format, and thus can be used from different languages such as Python, C++, C#, Java, etc. 75 | 76 | #### Input data preparation using Python 77 | 78 | Please prepare the input data using [numpy](https://numpy.org/). There are two files that should be placed in the `input_data` folder, namely, `input_surface.npy` and `input_upper.npy`. 79 | 80 | `input_surface.npy` stores the input surface variables. It is a numpy array shaped (4,721,1440) where the first dimension represents the 4 surface variables (MSLP, U10, V10, T2M **in the exact order**). 81 | 82 | `input_upper.npy` stores the upper-air variables. It is a numpy array shaped (5,13,721,1440) where the first dimension represents the 5 upper-air variables (Z, Q, T, U and V **in the exact order**), and the second dimension represents the 13 pressure levels (1000hPa, 925hPa, 850hPa, 700hPa, 600hPa, 500hPa, 400hPa, 300hPa, 250hPa, 200hPa, 150hPa, 100hPa and 50hPa **in the exact order**). 83 | 84 | In both cases, the dimensions of 721 and 1440 represent the size along the latitude and longitude, where the numerical range is [90,-90] degrees and [0,359.75] degrees, respectively, and the spacing is 0.25 degrees. For each 721x1440 slice, the data format is exactly the same as that of the `.nc` file downloaded from the official ERA5 website. 85 | 86 | Note that the numpy arrays should be in single precision (`.astype(np.float32)`), not in double precision. 87 | 88 | We support ERA5 initial fields and ECMWF initial fields (e.g., the initial fields of the HRES forecast), where the latter often leads to a slight accuracy drop (mainly for T2M because the two fields are quite different in temperature). A `.nc` file of ERA5 can be transformed into a `.npy` file using the netCDF4 package, and a `.grib` file of the ECMWF initial fields can be transformed into a `.npy` file using the pygrib package.
Note that Z represents geopotential, not geopotential height, so if the raw data contains geopotential height it must be multiplied by a factor of 9.80665. 89 | 90 | We temporarily do not support other kinds of initial fields due to the possibly dramatic differences in the fields when Z<0. 91 | 92 | We provide an example of converted input files, `input_surface.npy` and `input_upper.npy`, which correspond to the ERA5 initial fields at 12:00UTC, 2018/09/27. Please download them from Google drive or Baidu netdisk: 93 | 94 | `input_surface.npy`: [Google drive](https://drive.google.com/file/d/1pj8QEVNpC1FyJfUabDpV4oU3NpSe0BkD/view?usp=share_link)/[Baidu netdisk](https://pan.baidu.com/s/1i4o5i8guAqmOus6PWncAlA?pwd=4z9s) 95 | 96 | `input_upper.npy`: [Google drive](https://drive.google.com/file/d/1--7xEBJt79E3oixizr8oFmK_haDE77SS/view?usp=share_link)/[Baidu netdisk](https://pan.baidu.com/s/1mS8X5MqEdbVfF2u2Us62FQ?pwd=sgx6) 97 | 98 | #### Inference 99 | 100 | After the above steps are finished, please check `inference_cpu.py` for an example of making a 24-hour weather forecast on CPU with the 24-hour model, and `inference_gpu.py` for the GPU version. 101 | 102 | For example, by running the following command, one can get the 24-hour forecast in the `output_data` folder: 103 | ``` 104 | python inference_cpu.py # python inference_gpu.py for gpu environment 105 | ``` 106 | 107 | Also, `inference_iterative.py` shows an example of generating 6-hourly forecasts for up to one week. 108 | 109 | ## Pseudocode and how to use 110 | 111 | `pseudocode.py` contains the pseudocode that elaborates our main algorithm. It is written in Python and can be implemented using any deep learning library, e.g. PyTorch and TensorFlow. 112 | 113 | Note that one needs to download about 60TB of ERA5 data and prepare computational resources of 3000 GPU-days (on V100) to train each model.
114 | 115 | ## Training a lite version 116 | 117 | Recently, we found that Pangu-Weather can be trained efficiently using only 1% of the data and GPU computation. We call this version Pangu-Weather-lite. Note that the lite models cannot rival the full models, but the lite version offers opportunities for researchers with limited resources to explore AI methods for weather forecasting. 118 | 119 | Here are the key implementation details. 120 | 121 | * Data. We reduced the training data to 11 years (2007-2017) and only used the 00UTC time point (the full version used all 24 time points throughout the day). Also, only 00UTC data is used in the testing phase. The total amount of downloaded data is less than 1TB. 122 | * Model. We adjusted the down-sampling rate in the first stage from 2x4x4 to 2x8x8. 123 | * Training epochs. One can keep using 100 epochs or reduce the number to 50 (half); note that the cosine annealing schedule is adjusted accordingly. 124 | * Model set. We only trained one model (lead time is 24 hours), which means that the lite version can only perform daily weather forecasting. 125 | 126 | Here are the results.
127 | 128 | | Method | RMSE, Z500 | RMSE, T850 | RMSE, T2M | RMSE, U10 | Years | Down-sampling | Epochs | GPU x days | 129 | | ------------------- | ---------------------- | -------------------- | -------------------- | -------------------- | ----- | ------------- | -- | ---------- | 130 | | Operational IFS | 152.8 (3d), 333.7 (5d) | 1.37 (3d), 2.06 (5d) | 1.34 (3d), 1.75 (5d) | 1.94 (3d), 2.90 (5d) | -- | -- | -- | -- | 131 | | Pangu-Weather | 134.5 (3d), 296.7 (5d) | 1.14 (3d), 1.79 (5d) | 1.05 (3d), 1.53 (5d) | 1.61 (3d), 2.53 (5d) | 39 | 2 x 4 x 4 | 100 | 192 x 16 | 132 | | Pangu-Weather-Lite1 | 163.1 (3d), 338.2 (5d) | 1.29 (3d), 1.96 (5d) | 1.16 (3d), 1.64 (5d) | 1.80 (3d), 2.74 (5d) | 11 | 2 x 8 x 8 | 100 | 8 x 6 | 133 | | Pangu-Weather-Lite2 | 177.9 (3d), 357.5 (5d) | 1.36 (3d), 2.05 (5d) | 1.24 (3d), 1.71 (5d) | 1.90 (3d), 2.84 (5d) | 11 | 2 x 8 x 8 | 50 | 8 x 3 | 134 | 135 | One can observe that the lite version can surpass the operational IFS (*when tested only at 00UTC time points*) in T850 (850hPa temperature), T2M (2m temperature) and U10 (u-component of 10m wind speed), while requiring less than 1% of the computational cost of the full version. 136 | 137 | Please note that the lite version was only trained and tested on 00UTC data. This means that its performance at other time points is not guaranteed. Since weather variables are closely correlated with the time of day, it is difficult to directly use the lite version for daily weather forecasting. Again, the lite version is intended to make it easier for researchers to explore the properties of AI models. 138 | 139 | ## License 140 | 141 | Pangu-Weather was released by Huawei Cloud. 142 | 143 | The trained parameters of Pangu-Weather were made available under the terms of the BY-NC-SA 4.0 license. You can find details [here](https://creativecommons.org/licenses/by-nc-sa/4.0/).
144 | 145 | **The commercial use of these models is forbidden.** 146 | 147 | Also, please note that all models were trained using the ERA5 dataset provided by ECMWF. Please do follow [their policy](https://apps.ecmwf.int/datasets/licences/copernicus/). 148 | 149 | ## References 150 | 151 | If you use the resource in your research, please cite our paper: 152 | 153 | ``` 154 | @article{bi2023accurate, 155 | title={Accurate medium-range global weather forecasting with 3D neural networks}, 156 | author={Bi, Kaifeng and Xie, Lingxi and Zhang, Hengheng and Chen, Xin and Gu, Xiaotao and Tian, Qi}, 157 | journal={Nature}, 158 | volume={619}, 159 | number={7970}, 160 | pages={533--538}, 161 | year={2023}, 162 | publisher={Nature Publishing Group} 163 | } 164 | ``` 165 | 166 | We also offer the bibliography of the arXiv preprint version for your information. 167 | 168 | ``` 169 | @article{bi2022pangu, 170 | title={Pangu-Weather: A 3D High-Resolution Model for Fast and Accurate Global Weather Forecast}, 171 | author={Bi, Kaifeng and Xie, Lingxi and Zhang, Hengheng and Chen, Xin and Gu, Xiaotao and Tian, Qi}, 172 | journal={arXiv preprint arXiv:2211.02556}, 173 | year={2022} 174 | } 175 | ``` 176 | -------------------------------------------------------------------------------- /pseudocode.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Pseudocode of Pangu-Weather 3 | ''' 4 | # The pseudocode can be implemented using deep learning libraries, e.g., Pytorch and Tensorflow or other high-level APIs 5 | 6 | # Basic operations used in our model, namely Linear, Conv3d, Conv2d, ConvTranspose3d and ConvTranspose2d 7 | # Linear: Linear transformation, available in all deep learning libraries 8 | # Conv3d and Con2d: Convolution with 2 or 3 dimensions, available in all deep learning libraries 9 | # ConvTranspose3d, ConvTranspose2d: transposed convolution with 2 or 3 dimensions, see Pytorch API or Tensorflow API 10 | from Your_AI_Library 
import Linear, Conv3d, Conv2d, ConvTranspose3d, ConvTranspose2d 11 | 12 | # Functions in the networks, namely GeLU, DropOut, DropPath, LayerNorm, and SoftMax 13 | # GeLU: the GeLU activation function, see Pytorch API or Tensorflow API 14 | # DropOut: the dropout function, available in all deep learning libraries 15 | # DropPath: the DropPath function, see the implementation of vision-transformer, see timm pakage of Pytorch 16 | # A possible implementation of DropPath: from timm.models.layers import DropPath 17 | # LayerNorm: the layer normalization function, see Pytorch API or Tensorflow API 18 | # Softmax: softmax function, see Pytorch API or Tensorflow API 19 | from Your_AI_Library import GeLU, DropOut, DropPath, LayerNorm, SoftMax 20 | 21 | # Common functions for roll, pad, and crop, depends on the data structure of your software environment 22 | from Your_AI_Library import roll3D, pad3D, pad2D, Crop3D, Crop2D 23 | 24 | # Common functions for reshaping and changing the order of dimensions 25 | # reshape: change the shape of the data with the order unchanged, see Pytorch API or Tensorflow API 26 | # TransposeDimensions: change the order of the dimensions, see Pytorch API or Tensorflow API 27 | from Your_AI_Library import reshape, TransposeDimensions 28 | 29 | # Common functions for creating new tensors 30 | # ConstructTensor: create a new tensor with an arbitrary shape 31 | # TruncatedNormalInit: Initialize the tensor with Truncate Normalization distribution 32 | # RangeTensor: create a new tensor like range(a, b) 33 | from Your_AI_Library import ConstructTensor, TruncatedNormalInit, RangeTensor 34 | 35 | # Common operations for the data, you may design it or simply use deep learning APIs default operations 36 | # LinearSpace: a tensor version of numpy.linspace 37 | # MeshGrid: a tensor version of numpy.meshgrid 38 | # Stack: a tensor version of numpy.stack 39 | # Flatten: a tensor version of numpy.ndarray.flatten 40 | # TensorSum: a tensor version of numpy.sum 
41 | # TensorAbs: a tensor version of numpy.abs 42 | # Concatenate: a tensor version of numpy.concatenate 43 | from Your_AI_Library import LinearSpace, MeshGrid, Stack, Flatten, TensorSum, TensorAbs, Concatenate 44 | 45 | # Common functions for training models 46 | # LoadModel and SaveModel: Load and save the model, some APIs may require further adaptation to hardwares 47 | # Backward: Gradient backward to calculate the gratitude of each parameters 48 | # UpdateModelParametersWithAdam: Use Adam to update parameters, e.g., torch.optim.Adam 49 | from Your_AI_Library import LoadModel, Backward, UpdateModelParametersWithAdam, SaveModel 50 | 51 | # Custom functions to read your data from the disc 52 | # LoadData: Load the ERA5 data 53 | # LoadConstantMask: Load constant masks, e.g., soil type 54 | # LoadStatic: Load mean and std of the ERA5 training data, every fields such as T850 is treated as an image and calculate the mean and std 55 | from Your_Data_Code import LoadData, LoadConstantMask, LoadStatic 56 | 57 | 58 | def Inference(input, input_surface, forecast_range): 59 | '''Inference code, describing the algorithm of inference using models with different lead times. 60 | PanguModel24, PanguModel6, PanguModel3 and PanguModel1 share the same training algorithm but differ in lead times. 
61 | Args: 62 | input: input tensor, need to be normalized to N(0, 1) in practice 63 | input_surface: target tensor, need to be normalized to N(0, 1) in practice 64 | forecast_range: iteration numbers when roll out the forecast model 65 | ''' 66 | 67 | # Load 4 pre-trained models with different lead times 68 | PanguModel24 = LoadModel(ModelPath24) 69 | PanguModel6 = LoadModel(ModelPath6) 70 | PanguModel3 = LoadModel(ModelPath3) 71 | PanguModel1 = LoadModel(ModelPath1) 72 | 73 | # Load mean and std of the weather data 74 | weather_mean, weather_std, weather_surface_mean, weather_surface_std = LoadStatic() 75 | 76 | # Store initial input for different models 77 | input_24, input_surface_24 = input, input_surface 78 | input_6, input_surface_6 = input, input_surface 79 | input_3, input_surface_3 = input, input_surface 80 | 81 | # Using a list to store output 82 | output_list = [] 83 | 84 | # Note: the following code is implemented for fast inference of [1,forecast_range]-hour forecasts -- if only one lead time is requested, the inference can be much faster. 
85 | for i in range(forecast_range): 86 | # switch to the 24-hour model if the forecast time is 24 hours, 48 hours, ..., 24*N hours 87 | if (i+1) % 24 == 0: 88 | # Switch the input back to the stored input 89 | input, input_surface = input_24, input_surface_24 90 | 91 | # Call the model pretrained for 24 hours forecast 92 | output, output_surface = PanguModel24(input, input_surface) 93 | 94 | # Restore from uniformed output 95 | output = output * weather_std + weather_mean 96 | output_surface = output_surface * weather_surface_std + weather_surface_mean 97 | 98 | # Stored the output for next round forecast 99 | input_24, input_surface_24 = output, output_surface 100 | input_6, input_surface_6 = output, output_surface 101 | input_3, input_surface_3 = output, output_surface 102 | 103 | # switch to the 6-hour model if the forecast time is 30 hours, 36 hours, ..., 24*N + 6/12/18 hours 104 | elif (i+1) % 6 == 0: 105 | # Switch the input back to the stored input 106 | input, input_surface = input_6, input_surface_6 107 | 108 | # Call the model pretrained for 6 hours forecast 109 | output, output_surface = PanguModel6(input, input_surface) 110 | 111 | # Restore from uniformed output 112 | output = output * weather_std + weather_mean 113 | output_surface = output_surface * weather_surface_std + weather_surface_mean 114 | 115 | # Stored the output for next round forecast 116 | input_6, input_surface_6 = output, output_surface 117 | input_3, input_surface_3 = output, output_surface 118 | 119 | # switch to the 3-hour model if the forecast time is 3 hours, 9 hours, ..., 6*N + 3 hours 120 | elif (i+1) % 3 ==0: 121 | # Switch the input back to the stored input 122 | input, input_surface = input_3, input_surface_3 123 | 124 | # Call the model pretrained for 3 hours forecast 125 | output, output_surface = PanguModel3(input, input_surface) 126 | 127 | # Restore from uniformed output 128 | output = output * weather_std + weather_mean 129 | output_surface = output_surface * 
weather_surface_std + weather_surface_mean 130 | 131 | # Stored the output for next round forecast 132 | input_3, input_surface_3 = output, output_surface 133 | 134 | # switch to the 1-hour model 135 | else: 136 | # Call the model pretrained for 1 hours forecast 137 | output, output_surface = PanguModel1(input, input_surface) 138 | 139 | # Restore from uniformed output 140 | output = output * weather_std + weather_mean 141 | output_surface = output_surface * weather_surface_std + weather_surface_mean 142 | 143 | # Stored the output for next round forecast 144 | input, input_surface = output, output_surface 145 | 146 | # Save the output 147 | output_list.append((output, output_surface)) 148 | return output_list 149 | 150 | 151 | def Train(): 152 | '''Training code''' 153 | # Initialize the model, for some APIs some adaptation is needed to fit hardwares 154 | model = PanguModel() 155 | 156 | # Train single Pangu-Weather model 157 | epochs = 100 158 | for i in range(epochs): 159 | # For each epoch, we iterate from 1979 to 2017 160 | # dataset_length is the length of your training data, e.g., the sample between 1979 and 2017 161 | for step in range(dataset_length): 162 | # Load weather data at time t as the input; load weather data at time t+1/3/6/24 as the output 163 | # Note the data need to be randomly shuffled 164 | # Note the input and target need to be normalized, see Inference() for details 165 | input, input_surface, target, target_surface = LoadData(step) 166 | 167 | # Call the model and get the output 168 | output, output_surface = model(input, input_surface) 169 | 170 | # We use the MAE loss to train the model 171 | # The weight of surface loss is 0.25 172 | # Different weight can be applied for differen fields if needed 173 | loss = TensorAbs(output-target) + TensorAbs(output_surface-target_surface) * 0.25 174 | 175 | # Call the backward algorithm and calculate the gratitude of parameters 176 | Backward(loss) 177 | 178 | # Update model parameters with Adam 
optimizer 179 | # The learning rate is 5e-4 as in the paper, while the weight decay is 3e-6 180 | # A example solution is using torch.optim.adam 181 | UpdateModelParametersWithAdam() 182 | 183 | # Save the model at the end of the training stage 184 | SaveModel() 185 | 186 | class PanguModel: 187 | def __init__(self): 188 | # Drop path rate is linearly increased as the depth increases 189 | drop_path_list = LinearSpace(0, 0.2, 8) 190 | 191 | # Patch embedding 192 | self._input_layer = PatchEmbedding((2, 4, 4), 192) 193 | 194 | # Four basic layers 195 | self.layer1 = EarthSpecificLayer(2, 192, drop_list[:2], 6) 196 | self.layer2 = EarthSpecificLayer(6, 384, drop_list[6:], 12) 197 | self.layer3 = EarthSpecificLayer(6, 384, drop_list[6:], 12) 198 | self.layer4 = EarthSpecificLayer(2, 192, drop_list[:2], 6) 199 | 200 | # Upsample and downsample 201 | self.upsample = UpSample(384, 192) 202 | self.downsample = DownSample(192) 203 | 204 | # Patch Recovery 205 | self._output_layer = PatchRecovery(384) 206 | 207 | def forward(self, input, input_surface): 208 | '''Backbone architecture''' 209 | # Embed the input fields into patches 210 | x = self._input_layer(input, input_surface) 211 | 212 | # Encoder, composed of two layers 213 | # Layer 1, shape (8, 360, 181, C), C = 192 as in the original paper 214 | x = self.layer1(x, 8, 360, 181) 215 | 216 | # Store the tensor for skip-connection 217 | skip = x 218 | 219 | # Downsample from (8, 360, 181) to (8, 180, 91) 220 | x = self.downsample(x, 8, 360, 181) 221 | 222 | # Layer 2, shape (8, 180, 91, 2C), C = 192 as in the original paper 223 | x = self.layer2(x, 8, 180, 91) 224 | 225 | # Decoder, composed of two layers 226 | # Layer 3, shape (8, 180, 91, 2C), C = 192 as in the original paper 227 | x = self.layer3(x, 8, 180, 91) 228 | 229 | # Upsample from (8, 180, 91) to (8, 360, 181) 230 | x = self.upsample(x) 231 | 232 | # Layer 4, shape (8, 360, 181, 2C), C = 192 as in the original paper 233 | x = self.layer4(x, 8, 360, 181) 234 | 
235 | # Skip connect, in last dimension(C from 192 to 384) 236 | x = Concatenate(skip, x) 237 | 238 | # Recover the output fields from patches 239 | output, output_surface = self._output_layer(x) 240 | return output, output_surface 241 | 242 | class PatchEmbedding: 243 | def __init__(self, patch_size, dim): 244 | '''Patch embedding operation''' 245 | # Here we use convolution to partition data into cubes 246 | self.conv = Conv3d(input_dims=5, output_dims=dim, kernel_size=patch_size, stride=patch_size) 247 | self.conv_surface = Conv2d(input_dims=7, output_dims=dim, kernel_size=patch_size[1:], stride=patch_size[1:]) 248 | 249 | # Load constant masks from the disc 250 | self.land_mask, self.soil_type, self.topography = LoadConstantMask() 251 | 252 | def forward(self, input, input_surface): 253 | # Zero-pad the input 254 | input = Pad3D(input) 255 | input_surface = Pad2D(input_surface) 256 | 257 | # Apply a linear projection for patch_size[0]*patch_size[1]*patch_size[2] patches, patch_size = (2, 4, 4) as in the original paper 258 | input = self.conv(input) 259 | 260 | # Add three constant fields to the surface fields 261 | input_surface = Concatenate(input_surface, self.land_mask, self.soil_type, self.topography) 262 | 263 | # Apply a linear projection for patch_size[1]*patch_size[2] patches 264 | input_surface = self.conv_surface(input_surface) 265 | 266 | # Concatenate the input in the pressure level, i.e., in Z dimension 267 | x = Concatenate(input, input_surface) 268 | 269 | # Reshape x for calculation of linear projections 270 | x = TransposeDimensions(x, (0, 2, 3, 4, 1)) 271 | x = reshape(x, target_shape=(x.shape[0], 8*360*181, x.shape[-1])) 272 | return x 273 | 274 | class PatchRecovery: 275 | def __init__(self, dim): 276 | '''Patch recovery operation''' 277 | # Hear we use two transposed convolutions to recover data 278 | self.conv = ConvTranspose3d(input_dims=dim, output_dims=5, kernel_size=patch_size, stride=patch_size) 279 | self.conv_surface = 
ConvTranspose2d(input_dims=dim, output_dims=4, kernel_size=patch_size[1:], stride=patch_size[1:]) 280 | 281 | def forward(self, x, Z, H, W): 282 | # The inverse operation of the patch embedding operation, patch_size = (2, 4, 4) as in the original paper 283 | # Reshape x back to three dimensions 284 | x = TransposeDimensions(x, (0, 2, 1)) 285 | x = reshape(x, target_shape=(x.shape[0], x.shape[1], Z, H, W)) 286 | 287 | # Call the transposed convolution 288 | output = self.conv(x[:, :, 1:, :, :]) 289 | output_surface = self.conv_surface(x[:, :, 0, :, :]) 290 | 291 | # Crop the output to remove zero-paddings 292 | output = Crop3D(output) 293 | output_surface = Crop2D(output_surface) 294 | return output, output_surface 295 | 296 | class DownSample: 297 | def __init__(self, dim): 298 | '''Down-sampling operation''' 299 | # A linear function and a layer normalization 300 | self.linear = Linear(4*dim, 2*dim, bias=Fasle) 301 | self.norm = LayerNorm(4*dim) 302 | 303 | def forward(self, x, Z, H, W): 304 | # Reshape x to three dimensions for downsampling 305 | x = reshape(x, target_shape=(x.shape[0], Z, H, W, x.shape[-1])) 306 | 307 | # Padding the input to facilitate downsampling 308 | x = Pad3D(x) 309 | 310 | # Reorganize x to reduce the resolution: simply change the order and downsample from (8, 360, 182) to (8, 180, 91) 311 | Z, H, W = x.shape 312 | # Reshape x to facilitate downsampling 313 | x = reshape(x, target_shape=(x.shape[0], Z, H//2, 2, W//2, 2, x.shape[-1])) 314 | # Change the order of x 315 | x = TransposeDimensions(x, (0,1,2,4,3,5,6)) 316 | # Reshape to get a tensor of resolution (8, 180, 91) 317 | x = reshape(x, target_shape=(x.shape[0], Z*(H//2)*(W//2), 4 * x.shape[-1])) 318 | 319 | # Call the layer normalization 320 | x = self.norm(x) 321 | 322 | # Decrease the channels of the data to reduce computation cost 323 | x = self.linear(x) 324 | return x 325 | 326 | class UpSample: 327 | def __init__(self, input_dim, output_dim): 328 | '''Up-sampling operation''' 
329 | # Linear layers without bias to increase channels of the data 330 | self.linear1 = Linear(input_dim, output_dim*4, bias=False) 331 | 332 | # Linear layers without bias to mix the data up 333 | self.linear2 = Linear(output_dim, output_dim, bias=False) 334 | 335 | # Normalization 336 | self.norm = LayerNorm(output_dim) 337 | 338 | def forward(self, x): 339 | # Call the linear functions to increase channels of the data 340 | x = self.linear1(x) 341 | 342 | # Reorganize x to increase the resolution: simply change the order and upsample from (8, 180, 91) to (8, 360, 182) 343 | # Reshape x to facilitate upsampling. 344 | x = reshape(x, target_shape=(x.shape[0], 8, 180, 91, 2, 2, x.shape[-1]//4)) 345 | # Change the order of x 346 | x = TransposeDimensions(x, (0,1,2,4,3,5,6)) 347 | # Reshape to get Tensor with a resolution of (8, 360, 182) 348 | x = reshape(x, target_shape=(x.shape[0], 8, 360, 182, x.shape[-1])) 349 | 350 | # Crop the output to the input shape of the network 351 | x = Crop3D(x) 352 | 353 | # Reshape x back 354 | x = reshape(x, target_shape=(x.shape[0], x.shape[1]*x.shape[2]*x.shape[3], x.shape[-1])) 355 | 356 | # Call the layer normalization 357 | x = self.norm(x) 358 | 359 | # Mixup normalized tensors 360 | x = self.linear2(x) 361 | return x 362 | 363 | class EarthSpecificLayer: 364 | def __init__(self, depth, dim, drop_path_ratio_list, heads): 365 | '''Basic layer of our network, contains 2 or 6 blocks''' 366 | self.depth = depth 367 | self.blocks = [] 368 | 369 | # Construct basic blocks 370 | for i in range(depth): 371 | self.blocks.append(EarthSpecificBlock(dim, drop_path_ratio_list[i], heads)) 372 | 373 | def forward(self, x, Z, H, W): 374 | for i in range(self.depth): 375 | # Roll the input every two blocks 376 | if i % 2 == 0: 377 | self.blocks[i](x, Z, H, W, roll=False) 378 | else: 379 | self.blocks[i](x, Z, H, W, roll=True) 380 | return x 381 | 382 | class EarthSpecificBlock: 383 | def __init__(self, dim, drop_path_ratio, heads): 384 | ''' 
385 | 3D transformer block with Earth-Specific bias and window attention, 386 | see https://github.com/microsoft/Swin-Transformer for the official implementation of 2D window attention. 387 | The major difference is that we expand the dimensions to 3 and replace the relative position bias with Earth-Specific bias. 388 | ''' 389 | # Define the window size of the neural network 390 | self.window_size = (2, 6, 12) 391 | 392 | # Initialize serveral operations 393 | self.drop_path = DropPath(drop_rate=drop_path_ratio) 394 | self.norm1 = LayerNorm(dim) 395 | self.norm2 = LayerNorm(dim) 396 | self.linear = MLP(dim, 0) 397 | self.attention = EarthAttention3D(dim, heads, 0, self.window_size) 398 | 399 | def forward(self, x, Z, H, W, roll): 400 | # Save the shortcut for skip-connection 401 | shortcut = x 402 | 403 | # Reshape input to three dimensions to calculate window attention 404 | reshape(x, target_shape=(x.shape[0], Z, H, W, x.shape[2])) 405 | 406 | # Zero-pad input if needed 407 | x = pad3D(x) 408 | 409 | # Store the shape of the input for restoration 410 | ori_shape = x.shape 411 | 412 | if roll: 413 | # Roll x for half of the window for 3 dimensions 414 | x = roll3D(x, shift=[self.window_size[0]//2, self.window_size[1]//2, self.window_size[2]//2]) 415 | # Generate mask of attention masks 416 | # If two pixels are not adjacent, then mask the attention between them 417 | # Your can set the matrix element to -1000 when it is not adjacent, then add it to the attention 418 | mask = gen_mask(x) 419 | else: 420 | # e.g., zero matrix when you add mask to attention 421 | mask = no_mask 422 | 423 | # Reorganize data to calculate window attention 424 | x_window = reshape(x, target_shape=(x.shape[0], Z//window_size[0], window_size[0], H // window_size[1], window_size[1], W // window_size[2], window_size[2], x.shape[-1])) 425 | x_window = TransposeDimensions(x_window, (0, 1, 3, 5, 2, 4, 6, 7)) 426 | 427 | # Get data stacked in 3D cubes, which will further be used to calculated 
attention among each cube 428 | x_window = reshape(x_window, target_shape=(-1, window_size[0]* window_size[1]*window_size[2], x.shape[-1])) 429 | 430 | # Apply 3D window attention with Earth-Specific bias 431 | x_window = self.attention(x, mask) 432 | 433 | # Reorganize data to original shapes 434 | x = reshape(x_window, target_shape=((-1, Z // window_size[0], H // window_size[1], W // window_size[2], window_size[0], window_size[1], window_size[2], x_window.shape[-1]))) 435 | x = TransposeDimensions(x, (0, 1, 4, 2, 5, 3, 6, 7)) 436 | 437 | # Reshape the tensor back to its original shape 438 | x = reshape(x_window, target_shape=ori_shape) 439 | 440 | if roll: 441 | # Roll x back for half of the window 442 | x = roll3D(x, shift=[-self.window_size[0]//2, -self.window_size[1]//2, -self.window_size[2]//2]) 443 | 444 | # Crop the zero-padding 445 | x = Crop3D(x) 446 | 447 | # Reshape the tensor back to the input shape 448 | x = reshape(x, target_shape=(x.shape[0], x.shape[1]*x.shape[2]*x.shape[3], x.shape[4])) 449 | 450 | # Main calculation stages 451 | x = shortcut + self.drop_path(self.norm1(x)) 452 | x = x + self.drop_path(self.norm2(self.linear(x))) 453 | return x 454 | 455 | class EarthAttention3D: 456 | def __init__(self, dim, heads, dropout_rate, window_size): 457 | ''' 458 | 3D window attention with the Earth-Specific bias, 459 | see https://github.com/microsoft/Swin-Transformer for the official implementation of 2D window attention. 
460 | ''' 461 | # Initialize several operations 462 | self.linear1 = Linear(dim, dim=3, bias=True) 463 | self.linear2 = Linear(dim, dim) 464 | self.softmax = SoftMax(dim=-1) 465 | self.dropout = DropOut(dropout_rate) 466 | 467 | # Store several attributes 468 | self.head_number = heads 469 | self.dim = dim 470 | self.scale = (dim//heads)**-0.5 471 | self.window_size = window_size 472 | 473 | # input_shape is current shape of the self.forward function 474 | # You can run your code to record it, modify the code and rerun it 475 | # Record the number of different window types 476 | self.type_of_windows = (input_shape[0]//window_size[0])*(input_shape[1]//window_size[1]) 477 | 478 | # For each type of window, we will construct a set of parameters according to the paper 479 | self.earth_specific_bias = ConstructTensor(shape=((2 * window_size[2] - 1) * window_size[1] * window_size[1] * window_size[0] * window_size[0], self.type_of_windows, heads)) 480 | 481 | # Making these tensors to be learnable parameters 482 | self.earth_specific_bias = Parameters(self.earth_specific_bias) 483 | 484 | # Initialize the tensors using Truncated normal distribution 485 | TruncatedNormalInit(self.earth_specific_bias, std=0.02) 486 | 487 | # Construct position index to reuse self.earth_specific_bias 488 | self.position_index = self._construct_index() 489 | 490 | def _construct_index(self): 491 | ''' This function construct the position index to reuse symmetrical parameters of the position bias''' 492 | # Index in the pressure level of query matrix 493 | coords_zi = RangeTensor(self.window_size[0]) 494 | # Index in the pressure level of key matrix 495 | coords_zj = -RangeTensor(self.window_size[0])*self.window_size[0] 496 | 497 | # Index in the latitude of query matrix 498 | coords_hi = RangeTensor(self.window_size[1]) 499 | # Index in the latitude of key matrix 500 | coords_hj = -RangeTensor(self.window_size[1])*self.window_size[1] 501 | 502 | # Index in the longitude of the key-value pair 
503 | coords_w = RangeTensor(self.window_size[2]) 504 | 505 | # Change the order of the index to calculate the index in total 506 | coords_1 = Stack(MeshGrid([coords_zi, coords_hi, coords_w])) 507 | coords_2 = Stack(MeshGrid([coords_zj, coords_hj, coords_w])) 508 | coords_flatten_1 = Flatten(coords_1, start_dimension=1) 509 | coords_flatten_2 = Flatten(coords_2, start_dimension=1) 510 | coords = coords_flatten_1[:, :, None] - coords_flatten_2[:, None, :] 511 | coords = TransposeDimensions(coords, (1, 2, 0)) 512 | 513 | # Shift the index for each dimension to start from 0 514 | coords[:, :, 2] += self.window_size[2] - 1 515 | coords[:, :, 1] *= 2 * self.window_size[2] - 1 516 | coords[:, :, 0] *= (2 * self.window_size[2] - 1)*self.window_size[1]*self.window_size[1] 517 | 518 | # Sum up the indexes in three dimensions 519 | self.position_index = TensorSum(coords, dim=-1) 520 | 521 | # Flatten the position index to facilitate further indexing 522 | self.position_index = Flatten(self.position_index) 523 | 524 | def forward(self, x, mask): 525 | # Linear layer to create query, key and value 526 | x = self.linear1(x) 527 | 528 | # Record the original shape of the input 529 | original_shape = x.shape 530 | 531 | # reshape the data to calculate multi-head attention 532 | qkv = reshape(x, target_shape=(x.shape[0], x.shape[1], 3, self.head_number, self.dim // self.head_number)) 533 | query, key, value = TransposeDimensions(qkv, (2, 0, 3, 1, 4)) 534 | 535 | # Scale the attention 536 | query = query * self.scale 537 | 538 | # Calculated the attention, a learnable bias is added to fix the nonuniformity of the grid. 539 | attention = query @ key.T # @ denotes matrix multiplication 540 | 541 | # self.earth_specific_bias is a set of neural network parameters to optimize. 
542 | EarthSpecificBias = self.earth_specific_bias[self.position_index] 543 | 544 | # Reshape the learnable bias to the same shape as the attention matrix 545 | EarthSpecificBias = reshape(EarthSpecificBias, target_shape=(self.window_size[0]*self.window_size[1]*self.window_size[2], self.window_size[0]*self.window_size[1]*self.window_size[2], self.type_of_windows, self.head_number)) 546 | EarthSpecificBias = TransposeDimensions(EarthSpecificBias, (2, 3, 0, 1)) 547 | EarthSpecificBias = reshape(EarthSpecificBias, target_shape = [1]+EarthSpecificBias.shape) 548 | 549 | # Add the Earth-Specific bias to the attention matrix 550 | attention = attention + EarthSpecificBias 551 | 552 | # Mask the attention between non-adjacent pixels, e.g., simply add -100 to the masked element. 553 | attention = self.mask_attention(attention, mask) 554 | attention = self.softmax(attention) 555 | attention = self.dropout(attention) 556 | 557 | # Calculated the tensor after spatial mixing. 558 | x = attention @ value.T # @ denote matrix multiplication 559 | 560 | # Reshape tensor to the original shape 561 | x = TransposeDimensions(x, (0, 2, 1)) 562 | x = reshape(x, target_shape = original_shape) 563 | 564 | # Linear layer to post-process operated tensor 565 | x = self.linear2(x) 566 | x = self.dropout(x) 567 | return x 568 | 569 | class Mlp: 570 | def __init__(self, dim, dropout_rate): 571 | '''MLP layers, same as most vision transformer architectures.''' 572 | self.linear1 = Linear(dim, dim * 4) 573 | self.linear2 = Linear(dim * 4, dim) 574 | self.activation = GeLU() 575 | self.drop = DropOut(drop_rate=dropout_rate) 576 | 577 | def forward(self, x): 578 | x = self.linear(x) 579 | x = self.activation(x) 580 | x = self.drop(x) 581 | x = self.linear(x) 582 | x = self.drop(x) 583 | return x 584 | 585 | def PerlinNoise(): 586 | '''Generate random Perlin noise: we follow https://github.com/pvigier/perlin-numpy/ to calculate the perlin noise.''' 587 | # Define number of noise 588 | octaves = 3 
589 | # Define the scaling factor of noise 590 | noise_scale = 0.2 591 | # Define the number of periods of noise along the axis 592 | period_number = 12 593 | # The size of an input slice 594 | H, W = 721, 1440 595 | # Scaling factor between two octaves 596 | persistence = 0.5 597 | # see https://github.com/pvigier/perlin-numpy/ for the implementation of GenerateFractalNoise (e.g., from perlin_numpy import generate_fractal_noise_3d) 598 | perlin_noise = noise_scale*GenerateFractalNoise((H, W), (period_number, period_number), octaves, persistence) 599 | return perlin_noise 600 | --------------------------------------------------------------------------------