├── Assets ├── LSTM-master │ ├── data │ │ └── .Rhistory │ ├── .gitattributes │ ├── output_15_0.png │ ├── output_15_1.png │ ├── output_17_0.png │ ├── output_3_1.png │ ├── LICENSE │ ├── .gitignore │ └── README.md ├── Images │ └── 研究区域.png └── Ref │ ├── sp_ncdf.pro │ ├── gldas_tws_eg.py │ └── sp_ncdf_lunwen.pro ├── .gitignore ├── History └── IDL.rar ├── requirements.txt ├── utils ├── __pycache__ │ ├── models.cpython-38.pyc │ ├── utils.cpython-38.pyc │ └── __init__.cpython-38.pyc ├── __init__.py ├── utils.py └── models.py ├── .idea ├── .gitignore ├── misc.xml ├── inspectionProfiles │ ├── profiles_settings.xml │ └── Project_Default.xml ├── modules.xml └── VEG.iml ├── Core ├── dead_code.py ├── process_Rs.py ├── check_datasets.py ├── process_gldas.py ├── Plot │ ├── density_box_plot.py │ ├── line_bar_distribution_plot.py │ └── plot.ipynb ├── uniform_datasets.py ├── feature_engineering.py ├── model_train_dynamic.py ├── model_train.py └── process_modis.py └── README.md /Assets/LSTM-master/data/.Rhistory: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /Assets/LSTM-master/.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | Assets/LSTM-master/data/prcp.tif 3 | Assets/LSTM-master/data/prcp.tif.aux.xml 4 | VEG.zip 5 | -------------------------------------------------------------------------------- /History/IDL.rar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChaoQiezi/CNN-LSTM-model-is-used-to-predict-NDVI/HEAD/History/IDL.rar -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChaoQiezi/CNN-LSTM-model-is-used-to-predict-NDVI/HEAD/requirements.txt -------------------------------------------------------------------------------- /Assets/Images/研究区域.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChaoQiezi/CNN-LSTM-model-is-used-to-predict-NDVI/HEAD/Assets/Images/研究区域.png -------------------------------------------------------------------------------- /Assets/LSTM-master/output_15_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChaoQiezi/CNN-LSTM-model-is-used-to-predict-NDVI/HEAD/Assets/LSTM-master/output_15_0.png -------------------------------------------------------------------------------- /Assets/LSTM-master/output_15_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChaoQiezi/CNN-LSTM-model-is-used-to-predict-NDVI/HEAD/Assets/LSTM-master/output_15_1.png -------------------------------------------------------------------------------- /Assets/LSTM-master/output_17_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChaoQiezi/CNN-LSTM-model-is-used-to-predict-NDVI/HEAD/Assets/LSTM-master/output_17_0.png 
-------------------------------------------------------------------------------- /Assets/LSTM-master/output_3_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChaoQiezi/CNN-LSTM-model-is-used-to-predict-NDVI/HEAD/Assets/LSTM-master/output_3_1.png -------------------------------------------------------------------------------- /utils/__pycache__/models.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChaoQiezi/CNN-LSTM-model-is-used-to-predict-NDVI/HEAD/utils/__pycache__/models.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChaoQiezi/CNN-LSTM-model-is-used-to-predict-NDVI/HEAD/utils/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChaoQiezi/CNN-LSTM-model-is-used-to-predict-NDVI/HEAD/utils/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | # @Author : ChaoQiezi 2 | # @Time : 2024/1/14 19:06 3 | # @Email : chaoqiezi.one@qq.com 4 | 5 | """ 6 | This script is used to ... 7 | """ 8 | -------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # 默认忽略的文件 2 | /shelf/ 3 | /workspace.xml 4 | # 基于编辑器的 HTTP 客户端请求 5 | /httpRequests/ 6 | # Datasource local storage ignored files 7 | /dataSources/ 8 | /dataSources.local.xml 9 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/VEG.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/Project_Default.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 12 | -------------------------------------------------------------------------------- /Core/dead_code.py: -------------------------------------------------------------------------------- 1 | # @Author : ChaoQiezi 2 | # @Time : 2023/12/30 20:35 3 | # @Email : chaoqiezi.one@qq.com 4 | 5 | """ 6 | This script is used to ... 
7 | """ 8 | import os 9 | 10 | # 指定要更改的目录 11 | path = r'H:\Datasets\Objects\Veg\LST_Max' 12 | 13 | # 遍历目录下的所有文件 14 | for filename in os.listdir(path): 15 | # 检查文件名是否包含"max" 16 | if "Max" in filename: 17 | # 创建新的文件名,将"max"替换为"MAX" 18 | new_filename = filename.replace("Max", "MAX") 19 | 20 | # 获取文件的原始路径和新路径 21 | old_path = os.path.join(path, filename) 22 | new_path = os.path.join(path, new_filename) 23 | 24 | # 重命名文件 25 | os.rename(old_path, new_path) 26 | -------------------------------------------------------------------------------- /Assets/LSTM-master/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Longhao Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /Assets/Ref/sp_ncdf.pro: -------------------------------------------------------------------------------- 1 | ;首先我们需要一个读取函数,这里我直接复制过来讲解 2 | function era5_readarr,file_name,sds_name ;这里只需要给函数传递文件名,和数据集名字 3 | file_id=ncdf_open(file_name) ;打开文件 4 | data_id=ncdf_varid(file_id,sds_name);读取数据集-我们这里是2m温度 5 | ncdf_varget,file_id,data_id,data;获取2m温度 存储为data 6 | ncdf_attget,file_id,data_id,'scale_factor',sc_data ;获取数据需要的预处理乘法因子,每个数据需要先乘以这个因子 7 | ncdf_attget,file_id,data_id,'add_offset',ao_data;相应的获取加法因子,每个数据需要乘以上面因子之后再加 8 | ncdf_close,file_id 9 | data=float(data)*sc_data+ao_data ;先乘后加 10 | return,data ;得到的最后数据返回 11 | end 12 | 13 | function era5_readarr11,file_name,sds_name ;这里只需要给函数传递文件名,和数据集名字 14 | file_id=ncdf_open(file_name) ;打开文件 15 | data_id=ncdf_varid(file_id,sds_name);读取数据集-我们这里是2m温度 16 | ncdf_varget,file_id,data_id,data;获取2m温度 存储为data 17 | ;ncdf_attget,file_id,data_id,'scale_factor',sc_data ;获取数据需要的预处理乘法因子,每个数据需要先乘以这个因子 18 | ;ncdf_attget,file_id,data_id,'add_offset',ao_data;相应的获取加法因子,每个数据需要乘以上面因子之后再加 19 | ncdf_close,file_id 20 | data=float(data);*sc_data+ao_data ;先乘后加 21 | return,data ;得到的最后数据返回 22 | end 23 | ;接着我们对应我们的数据处理 24 | pro sp_ncdf 25 | ;输入文件名字直接复制过来也可以直接拖过来。 26 | path = 'F:\new_lunwen\Global GLDAS\' 27 | out_path = 'F:\new_lunwen\jg\Tveg\' 28 | file_dir = file_test(out_path,/directory) 29 | if file_dir eq 0 then begin 30 | file_mkdir,out_path 31 | endif 32 | ;接着我们读取数据-t2m 33 | file_list = file_search(path,'*.nc4',count = file_n) 34 | for file_i=0,file_n-1 do begin 35 | 36 | data_ciwc = era5_readarr11(file_list[file_i],'Tveg_tavg') 37 | data_ciwc = (data_ciwc gt -9999 and data_ciwc lt 9999)*data_ciwc 38 | 39 | lon = era5_readarr11(file_list[file_i],'lon') 40 | lon_min = min(lon) 41 | lat = era5_readarr11(file_list[file_i],'lat') 42 | lat_max = max(lat) 43 | data_ciwc = rotate(data_ciwc,7) 44 | 45 | res = 0.25 46 | geo_info={$ 47 | MODELPIXELSCALETAG:[res,res,0.0],$ ;还是直接复制过来这是地理信息直接复制即可,这一行需要加入经纬度分辨率,都是0.25所以不用改 48 | MODELTIEPOINTTAG:[0.0,0.0,0.0,lon_min,lat_max,0.0],$ 49 | GTMODELTYPEGEOKEY:2,$ 50 | GTRASTERTYPEGEOKEY:1,$ 51 | GEOGRAPHICTYPEGEOKEY:4326,$ 52 | GEOGCITATIONGEOKEY:'GCS_WGS_1984',$ 53 | GEOGANGULARUNITSGEOKEY:9102} 54 | ;写成tif 55 | write_tiff,out_path+file_basename(file_list[file_i],'.nc4')+'.tif',data_ciwc,/float,geotiff=geo_info ;输出路径out_path,文件名字2021——t2m.tif 56 | ;完成读取 57 | print,'down!!' 
58 | 59 | endfor 60 | end -------------------------------------------------------------------------------- /Core/process_Rs.py: -------------------------------------------------------------------------------- 1 | # @Author : ChaoQiezi 2 | # @Time : 2024/5/9 19:25 3 | # @FileName : process_Rs.py 4 | # @Email : chaoqiezi.one@qq.com 5 | 6 | """ 7 | This script is used to 处理RS地表太阳辐射并提取经纬度数据集,通过裁剪、掩膜、重采样等处理输出为tiff文件 8 | """ 9 | 10 | import os.path 11 | import matplotlib.pyplot as plt 12 | import netCDF4 as nc 13 | import numpy as np 14 | from osgeo import gdal, osr 15 | import pandas as pd 16 | 17 | # 准备 18 | Rs_path = r'H:\Datasets\Objects\Veg\GWRHXG_Rs1.nc' 19 | out_Rs_dir = r'E:\FeaturesTargets\uniform\Rs' 20 | out_dir = r'E:\FeaturesTargets\uniform' 21 | mask_path = r'E:\Basic\Region\sw5f\sw5_mask.shp' 22 | out_res = 0.1 # 度(°) 23 | if not os.path.exists(out_Rs_dir): os.makedirs(out_Rs_dir) 24 | if not os.path.exists(out_dir): os.makedirs(out_dir) 25 | 26 | # 读取 27 | with nc.Dataset(Rs_path) as f: 28 | lon, lat = f['longitude'][:].filled(-9999), f['latitude'][:].filled(-9999) 29 | Rs = f['Rs'][:].filled(-9999) 30 | years, months = f['year'][:].filled(np.nan), f['month'][:].filled(np.nan) 31 | for ix, (year, month) in enumerate(zip(years, months)): 32 | cur_Rs = Rs[ix, :, :] # 当前时间点的Rs地表太阳辐射 33 | lon_min, lon_max, lat_min, lat_max = lon.min(), lon.max(), lat.min(), lat.max() 34 | lon_res = (lon_max - lon_min) / len(lon) 35 | lat_res = (lat_max - lat_min) / len(lat) 36 | geo_transform = [lon_min, lon_res, 0, lat_max, 0, -lat_res] 37 | 38 | # 输出 39 | out_file_name = 'Rs_{:4.0f}{:02.0f}.tiff'.format(year, month) 40 | out_path = os.path.join(out_Rs_dir, out_file_name) 41 | mem_driver = gdal.GetDriverByName('MEM') 42 | mem_ds = mem_driver.Create('', len(lon), len(lat), 1, gdal.GDT_Float32) 43 | srs = osr.SpatialReference() 44 | srs.ImportFromEPSG(4326) 45 | mem_ds.SetProjection(srs.ExportToWkt()) 46 | mem_ds.SetGeoTransform(geo_transform) 47 | mem_ds.GetRasterBand(1).WriteArray(cur_Rs) 48 | mem_ds.GetRasterBand(1).SetNoDataValue(-9999) # 设置无效值 49 | out_ds = gdal.Warp(out_path, mem_ds, cropToCutline=True, cutlineDSName=mask_path, xRes=out_res, yRes=out_res, 50 | resampleAlg=gdal.GRA_Cubic, srcNodata=-9999, dstNodata=-9999) 51 | mem_ds.FlushCache() 52 | print('processing: {}'.format(out_file_name)) 53 | 54 | masked_geo_transform = out_ds.GetGeoTransform() 55 | rows, cols = out_ds.RasterYSize, out_ds.RasterXSize 56 | lat = np.array([masked_geo_transform[3] + _ix * masked_geo_transform[-1] + masked_geo_transform[-1] / 2 for _ix in range(rows)]) 57 | lon = np.array([masked_geo_transform[0] + _ix * masked_geo_transform[1] + masked_geo_transform[1] / 2 for _ix in range(cols)]) 58 | lon_2d, lat_2d = np.meshgrid(lon, lat) 59 | driver = gdal.GetDriverByName('GTiff') 60 | lon_ds = driver.Create(os.path.join(out_dir, 'Lon.tiff'), len(lon), len(lat), 1, gdal.GDT_Float32) 61 | lat_ds = driver.Create(os.path.join(out_dir, 'Lat.tiff'), len(lon), len(lat), 1, gdal.GDT_Float32) 62 | srs.ImportFromEPSG(4326) 63 | lon_ds.SetProjection(srs.ExportToWkt()) 64 | lon_ds.SetGeoTransform(masked_geo_transform) 65 | lon_ds.GetRasterBand(1).WriteArray(lon_2d) 66 | lon_ds.GetRasterBand(1).SetNoDataValue(-9999) # 设置无效值 67 | lat_ds.SetProjection(srs.ExportToWkt()) 68 | lat_ds.SetGeoTransform(masked_geo_transform) 69 | lat_ds.GetRasterBand(1).WriteArray(lat_2d) 70 | lat_ds.GetRasterBand(1).SetNoDataValue(-9999) # 设置无效值 71 | lon_ds.FlushCache() 72 | lat_ds.FlushCache() 73 | 74 | print('Done!') 
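# --- Illustrative sanity check (added example, not part of the original workflow) ---
# Re-open one of the GeoTIFFs written above and confirm that the cutline/resampling step
# kept the expected georeferencing. 'Rs_200204.tiff' is an assumed example file name;
# substitute any month that was actually produced.
check_path = os.path.join(out_Rs_dir, 'Rs_200204.tiff')
if os.path.exists(check_path):
    check_ds = gdal.Open(check_path)
    print('geotransform:', check_ds.GetGeoTransform())
    print('size (cols, rows):', check_ds.RasterXSize, check_ds.RasterYSize)
    print('nodata:', check_ds.GetRasterBand(1).GetNoDataValue())
    check_ds = None  # release the GDAL dataset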
-------------------------------------------------------------------------------- /Assets/LSTM-master/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 105 | __pypackages__/ 106 | 107 | # Celery stuff 108 | celerybeat-schedule 109 | celerybeat.pid 110 | 111 | # SageMath parsed files 112 | *.sage.py 113 | 114 | # Environments 115 | .env 116 | .venv 117 | env/ 118 | venv/ 119 | ENV/ 120 | env.bak/ 121 | venv.bak/ 122 | 123 | # Spyder project settings 124 | .spyderproject 125 | .spyproject 126 | 127 | # Rope project settings 128 | .ropeproject 129 | 130 | # mkdocs documentation 131 | /site 132 | 133 | # mypy 134 | .mypy_cache/ 135 | .dmypy.json 136 | dmypy.json 137 | 138 | # Pyre type checker 139 | .pyre/ 140 | 141 | # pytype static type analyzer 142 | .pytype/ 143 | 144 | # Cython debug symbols 145 | cython_debug/ 146 | 147 | # PyCharm 148 | # JetBrains specific template is maintainted in a separate JetBrains.gitignore that can 149 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 150 | # and can be added to the global gitignore or merged into this file. For a more nuclear 151 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 152 | #.idea/ 153 | -------------------------------------------------------------------------------- /Core/check_datasets.py: -------------------------------------------------------------------------------- 1 | # @Author : ChaoQiezi 2 | # @Time : 2023/12/7 15:07 3 | # @FileName : check_datasets.py 4 | # @Email : chaoqiezi.one@qq.com 5 | 6 | """ 7 | This script is used to 用于检查数据完整性, 包括MCD12Q1、MOD11A2、MOD13A2 8 | -·-·-·-·-·-·-·-·-·-·-·-·-·-·-·-·-·-·-·-·-·-·-·-·-·-·-·-·-·-·-·-·-·-·- 9 | 拓展: MYD\MOD\MCD 10 | MOD标识Terra卫星 11 | MYD标识Aqua卫星 12 | MCD标识Terra和Aqua卫星的结合 13 | -·-·-·-·-·-·-·-·-·-·-·-·-·-·-·-·-·-·-·-·-·-·-·-·-·-·-·-·-·-·-·-·-·-·- 14 | 拓展: MCD12Q1\MOD11A2\MOD13A2 15 | MCD12Q1为土地利用数据 16 | MOD11A2为地表温度数据 17 | MOD13A2为植被指数数据(包括NDVI\EVI) 18 | """ 19 | 20 | import os.path 21 | import glob 22 | from datetime import datetime, timedelta 23 | 24 | # 准备 25 | in_dir = r'F:\Cy_modis' 26 | searching_ds_wildcard = ['MCD12Q1', 'MOD11A2', 'MOD13A2'] 27 | 28 | # 检查MCD12Q1数据集 29 | error_txt = os.path.join(in_dir, 'MCD12Q1_check_error.txt') 30 | ds_name_wildcard = 'MCD12Q1*' 31 | region_wildcard = ['h26v05', 'h26v06', 'h27v05', 'h27v06'] 32 | with open(error_txt, 'w+') as f: 33 | for year in range(2001, 2021): 34 | for region in region_wildcard: 35 | cur_ds_name_wildcard = ds_name_wildcard + 'A{}*'.format(year) + region + '*.hdf' 36 | ds_path_wildcard = os.path.join(in_dir, '**', cur_ds_name_wildcard) 37 | hdf_paths = glob.glob(ds_path_wildcard, recursive=True) 38 | if len(hdf_paths) != 1: 39 | f.write('{}: 文件数目(为: {})不正常\n'.format(cur_ds_name_wildcard, len(hdf_paths))) 40 | if not f.read(): 41 | f.write('MCD12Q1数据集文件数正常') 42 | 43 | # 检查MOD11A2数据集 44 | error_txt = os.path.join(in_dir, 'MOD11A2_check_error.txt') 45 | ds_name_wildcard = 'MOD11A2*' 46 | region_wildcard = ['h26v05', 'h26v06', 'h27v05', 'h27v06'] 47 | start_date = datetime(2000, 1, 1) + timedelta(days=48) 48 | end_date = datetime(2022, 1, 1) + timedelta(days=296) 49 | with open(error_txt, 'w+') as f: 50 | cur_date = start_date 51 | while cur_date <= end_date: 52 | cur_date_str = cur_date.strftime('%Y%j') 53 | for region in region_wildcard: 54 | cur_ds_name_wildcard = ds_name_wildcard + 'A{}*'.format(cur_date_str) + region + '*.hdf' 55 | ds_path_wildcard = os.path.join(in_dir, '**', cur_ds_name_wildcard) 56 | hdf_paths = glob.glob(ds_path_wildcard, recursive=True) 57 | if len(hdf_paths) != 1: 58 | f.write('{}: 文件数目(为: {})不正常\n'.format(cur_ds_name_wildcard, 
len(hdf_paths))) 59 | if (cur_date + timedelta(days=8)).year != cur_date.year: 60 | cur_date = datetime(cur_date.year + 1, 1, 1) 61 | else: 62 | cur_date += timedelta(days=8) 63 | if not f.read(): 64 | f.write('MOD11A2数据集文件数正常') 65 | 66 | # 检查MOD13A2数据集 67 | error_txt = os.path.join(in_dir, 'MOD13A2_check_error.txt') 68 | ds_name_wildcard = 'MOD13A2*' 69 | region_wildcard = ['h26v05', 'h26v06', 'h27v05', 'h27v06'] 70 | start_date = datetime(2000, 1, 1) + timedelta(days=48) 71 | end_date = datetime(2020, 1, 1) + timedelta(days=352) 72 | with open(error_txt, 'w+') as f: 73 | cur_date = start_date 74 | while cur_date <= end_date: 75 | cur_date_str = cur_date.strftime('%Y%j') 76 | for region in region_wildcard: 77 | cur_ds_name_wildcard = ds_name_wildcard + 'A{}*'.format(cur_date_str) + region + '*.hdf' 78 | ds_path_wildcard = os.path.join(in_dir, '**', cur_ds_name_wildcard) 79 | hdf_paths = glob.glob(ds_path_wildcard, recursive=True) 80 | if len(hdf_paths) != 1: 81 | f.write('{}: 文件数目(为: {})不正常\n'.format(cur_ds_name_wildcard, len(hdf_paths))) 82 | if (cur_date + timedelta(days=16)).year != cur_date.year: 83 | cur_date = datetime(cur_date.year + 1, 1, 1) 84 | else: 85 | cur_date += timedelta(days=16) 86 | if not f.read(): 87 | f.write('MOD13A2数据集文件数正常') 88 | 89 | 90 | 91 | -------------------------------------------------------------------------------- /Assets/Ref/gldas_tws_eg.py: -------------------------------------------------------------------------------- 1 | import os 2 | from netCDF4 import Dataset 3 | import numpy as np 4 | 5 | #=========================================================================== 6 | 7 | # import time 8 | # from datetime import datetime, timedelta 9 | # from netCDF4 import num2date, date2num 10 | # nc_file = 'F:/My_Postdoctor/GLDAS/GLDAS_NOAH025_M.A200001.021.nc4.SUB.nc4' 11 | 12 | month2seconds = 30*24*3600 # seconds in a month 13 | month2threehours = 30*8 # numbers of 3 hours in a month 14 | filelist = [] # create a file list for GLDAS data 15 | TWSCs = [] # create TWSCs (changes of Terrestrial Water Storage) 16 | PRCPs = [] 17 | ETs = [] 18 | Qss = [] 19 | Qsbs = [] 20 | # read GLDAS files of *.nc4 (0.25 x 0.25),and put them into filelist 21 | for (dirname, dirs, files) in os.walk('F:/My_Postdoctor/GlobalGLDAS/'): 22 | for filename in files: 23 | if filename.endswith('.nc4'): 24 | filelist.append(os.path.join(dirname,filename)) 25 | 26 | filelist = np.sort(filelist) # order in time 27 | num_files = len(filelist) # 201 files of *.nc4 from 2002.4-2018.12 28 | 29 | # read each file of *.nc4 30 | n=1000 31 | for files in filelist: 32 | data = Dataset(files, 'r',format='netCDF4') 33 | 34 | lons_gldas = data.variables['lon'][:] #len: 190 35 | lats_gldas = data.variables['lat'][:] #len: 103 36 | 37 | precipitation_flux = data.variables['Rainf_f_tavg'][:] 38 | water_evaporation_flux = data.variables['Evap_tavg'][:] 39 | surface_runoff_amount = data.variables['Qs_acc'][:] 40 | subsurface_runoff_amount = data.variables['Qsb_acc'][:] 41 | 42 | data.close() # close files of GLDAS 43 | 44 | # get monthly P, ET, Surface runoff, underground runoff 45 | precipitation_flux = precipitation_flux[0] # from (1,600,1440) to (600,1440) to reduce dim 46 | water_evaporation_flux = water_evaporation_flux[0] 47 | surface_runoff_amount = surface_runoff_amount[0] 48 | subsurface_runoff_amount = subsurface_runoff_amount[0] 49 | 50 | # calculate change of (k-1)-month TWSC in term of k-month (k>=1) TWSC 51 | TWSC = (precipitation_flux - water_evaporation_flux) * month2seconds - 
(surface_runoff_amount + subsurface_runoff_amount) * month2threehours 52 | PRCP = precipitation_flux * month2seconds 53 | ET = water_evaporation_flux * month2seconds 54 | Qs = surface_runoff_amount * month2threehours 55 | Qsb = subsurface_runoff_amount * month2threehours 56 | 57 | # #save 58 | # lons_new = lons_gldas[900:1072] 59 | # lats_new = lats_gldas[372:468] 60 | # PRCP = precipitation_flux[372:468, 900:1072]*month2seconds 61 | # ET = water_evaporation_flux[372:468, 900:1072]*month2seconds 62 | # Qs = surface_runoff_amount[372:468, 900:1072]* month2threehours 63 | # Qsb = subsurface_runoff_amount[372:468, 900:1072]* month2threehours 64 | # TWSC = TWSC[372:468, 900:1072] 65 | # n+= 1 66 | # np.savetxt('F:/My_Postdoctor/TWS_project/TWS_GLDAS/GLDAS_TWSC/PRCP_' + str(n), PRCP) 67 | # np.savetxt('F:/My_Postdoctor/TWS_project/TWS_GLDAS/GLDAS_TWSC/ET_' + str(n), ET) 68 | # np.savetxt('F:/My_Postdoctor/TWS_project/TWS_GLDAS/GLDAS_TWSC/Qs_' + str(n), Qs ) 69 | # np.savetxt('F:/My_Postdoctor/TWS_project/TWS_GLDAS/GLDAS_TWSC/Qsb_' + str(n), Qsb) 70 | # np.savetxt('F:/My_Postdoctor/TWS_project/TWS_GLDAS/GLDAS_TWSC/lon_'+ str(n), lons_new) 71 | # np.savetxt('F:/My_Postdoctor/TWS_project/TWS_GLDAS/GLDAS_TWSC/lat_'+ str(n), lats_new) 72 | 73 | 74 | # GLDAS 月文件中的海洋区域和南极洲地区均无有效测量值(默认填充为 −9999.0),这里将填充值重设为零。 75 | TWSCs.append(TWSC.filled(0)) 76 | 77 | 78 | TWSCs = np.array(TWSCs) 79 | # 首先计算第 k(k>=1) 个月相对于第 0 个月的陆地水储量变化,即每个月的陆地水储量变化,然后计算陆地水储量变化的月平均 80 | TWSCs_acc = np.cumsum(TWSCs, axis=0) 81 | TWSCs_acc_average = np.average(TWSCs_acc, axis=0) 82 | # 对每个月的陆地是储量变化进行去平均化,and get the final TWSCs 83 | TWSCs = TWSCs_acc - TWSCs_acc_average 84 | 85 | #save ewt 86 | # n=1000 87 | # for da in TWSCs: 88 | # n+=1 89 | # np.savetxt('F:/My_Postdoctor/TWS_project/TWS_GLDAS/GLDAS_TWSC/TWSC_'+str(n),da) 90 | 91 | 92 | -------------------------------------------------------------------------------- /Assets/Ref/sp_ncdf_lunwen.pro: -------------------------------------------------------------------------------- 1 | ;首先我们需要一个读取函数,这里我直接复制过来讲解 2 | function era5_readarr,file_name,sds_name ;这里只需要给函数传递文件名,和数据集名字 3 | file_id=ncdf_open(file_name) ;打开文件 4 | data_id=ncdf_varid(file_id,sds_name);读取数据集-我们这里是2m温度 5 | ncdf_varget,file_id,data_id,data;获取2m温度 存储为data 6 | ncdf_attget,file_id,data_id,'scale_factor',sc_data ;获取数据需要的预处理乘法因子,每个数据需要先乘以这个因子 7 | ncdf_attget,file_id,data_id,'add_offset',ao_data;相应的获取加法因子,每个数据需要乘以上面因子之后再加 8 | ncdf_close,file_id 9 | data=float(data)*sc_data+ao_data ;先乘后加 10 | return,data ;得到的最后数据返回 11 | end 12 | 13 | function era5_readarr11,file_name,sds_name ;这里只需要给函数传递文件名,和数据集名字 14 | file_id=ncdf_open(file_name) ;打开文件 15 | data_id=ncdf_varid(file_id,sds_name);读取数据集-我们这里是2m温度 16 | ncdf_varget,file_id,data_id,data;获取2m温度 存储为data 17 | ;ncdf_attget,file_id,data_id,'scale_factor',sc_data ;获取数据需要的预处理乘法因子,每个数据需要先乘以这个因子 18 | ;ncdf_attget,file_id,data_id,'add_offset',ao_data;相应的获取加法因子,每个数据需要乘以上面因子之后再加 19 | ncdf_close,file_id 20 | data=float(data);*sc_data+ao_data ;先乘后加 21 | return,data ;得到的最后数据返回 22 | end 23 | ;接着我们对应我们的数据处理 24 | pro sp_ncdf_lunwen 25 | ;输入文件名字直接复制过来也可以直接拖过来。 26 | path = 'F:\new_lunwen\Global GLDAS\' 27 | 28 | out_path = 'F:\new_lunwen\Rainf_f_tavg\' 29 | out_path1 = 'F:\new_lunwen\Evap_tavg_tavg\' 30 | out_path2 = 'F:\new_lunwen\Qs_acc\' 31 | out_path3 = 'F:\new_lunwen\Qsb_acc\' 32 | if ~file_test(out_path1, /directory) then file_mkdir,out_path1 33 | if ~file_test(out_path2, /directory) then file_mkdir,out_path2 34 | if ~file_test(out_path3, /directory) then file_mkdir,out_path3 35 | file_dir = 
file_test(out_path,/directory) 36 | if file_dir eq 0 then begin 37 | file_mkdir,out_path 38 | endif 39 | ;接着我们读取数据-t2m 40 | file_list = file_search(path,'*.nc4',count = file_n) 41 | for file_i=0,file_n-1 do begin 42 | 43 | year = fix(strmid(file_basename(file_list[file_i]),17,4)) 44 | month = fix(strmid(file_basename(file_list[file_i]),21,2)) 45 | data_ciwc = era5_readarr11(file_list[file_i],'Rainf_f_tavg') 46 | data_ciwc1 = era5_readarr11(file_list[file_i],'Evap_tavg') 47 | data_ciwc2 = era5_readarr11(file_list[file_i],'Qs_acc') 48 | data_ciwc3 = era5_readarr11(file_list[file_i],'Qsb_acc') 49 | 50 | 51 | 52 | data_ciwc = (data_ciwc gt -9999 and data_ciwc lt 9999)*data_ciwc 53 | data_ciwc1 = (data_ciwc1 gt -9999 and data_ciwc1 lt 9999)*data_ciwc1 54 | data_ciwc2 = (data_ciwc2 gt -9999 and data_ciwc2 lt 9999)*data_ciwc2 55 | data_ciwc3 = (data_ciwc3 gt -9999 and data_ciwc3 lt 9999)*data_ciwc3 56 | 57 | lon = era5_readarr11(file_list[file_i],'lon') 58 | lon_min = min(lon) 59 | lat = era5_readarr11(file_list[file_i],'lat') 60 | lat_max = max(lat) 61 | data_ciwc = rotate(data_ciwc,7) 62 | data_ciwc1 = rotate(data_ciwc1,7) 63 | data_ciwc2 = rotate(data_ciwc2,7) 64 | data_ciwc3 = rotate(data_ciwc3,7) 65 | if month eq 1 or month eq 3 or month eq 5 or month eq 7 or month eq 8 or month eq 10 or month eq 12 then begin 66 | data_ciwc = data_ciwc * 31 * 24 * 3600 67 | data_ciwc1 = data_ciwc1 * 31 * 24 * 3600 68 | data_ciwc2 = data_ciwc2 * 31 * 8 69 | data_ciwc3 = data_ciwc3 * 31 * 8 70 | 71 | endif 72 | if month eq 4 or month eq 6 or month eq 9 or month eq 11 then begin 73 | data_ciwc = data_ciwc * 30 * 24 * 3600 74 | data_ciwc1 = data_ciwc1 * 30 * 24 * 3600 75 | data_ciwc2 = data_ciwc2 * 30 * 8 76 | data_ciwc3 = data_ciwc3 * 30 * 8 77 | 78 | endif 79 | m = year mod 4 80 | if month eq 2 and m eq 0 then begin 81 | data_ciwc = data_ciwc * 29 * 24 * 3600 82 | data_ciwc1 = data_ciwc1 * 29 * 24 * 3600 83 | data_ciwc2 = data_ciwc2 * 29 * 8 84 | data_ciwc3 = data_ciwc3 * 29 * 8 85 | 86 | endif 87 | if month eq 2 and m ne 0 then begin 88 | data_ciwc = data_ciwc * 28 * 24 * 3600 89 | data_ciwc1 = data_ciwc1 * 28 * 24 * 3600 90 | data_ciwc2 = data_ciwc2 * 28 * 8 91 | data_ciwc3 = data_ciwc3 * 28 * 8 92 | 93 | endif 94 | 95 | res = 0.25 96 | geo_info={$ 97 | MODELPIXELSCALETAG:[res,res,0.0],$ ;还是直接复制过来这是地理信息直接复制即可,这一行需要加入经纬度分辨率,都是0.25所以不用改 98 | MODELTIEPOINTTAG:[0.0,0.0,0.0,lon_min,lat_max,0.0],$ ;这里需要提供最小经度和最大纬度,在第4、5个位置 99 | GTMODELTYPEGEOKEY:2,$ 100 | GTRASTERTYPEGEOKEY:1,$ 101 | GEOGRAPHICTYPEGEOKEY:4326,$ 102 | GEOGCITATIONGEOKEY:'GCS_WGS_1984',$ 103 | GEOGANGULARUNITSGEOKEY:9102} 104 | ;写成tif 105 | write_tiff,out_path+file_basename(file_list[file_i],'.nc4')+'.tif',data_ciwc,/float,geotiff=geo_info ;输出路径out_path,文件名字2021——t2m.tif 106 | write_tiff,out_path1+file_basename(file_list[file_i],'.nc4')+'.tif',data_ciwc1,/float,geotiff=geo_info 107 | write_tiff,out_path2+file_basename(file_list[file_i],'.nc4')+'.tif',data_ciwc2,/float,geotiff=geo_info 108 | write_tiff,out_path3+file_basename(file_list[file_i],'.nc4')+'.tif',data_ciwc3,/float,geotiff=geo_info 109 | ;完成读取 110 | print,'down!!' 
111 | 112 | endfor 113 | end -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | # @Author : ChaoQiezi 2 | # @Time : 2023/12/14 6:33 3 | # @Email : chaoqiezi.one@qq.com 4 | 5 | """ 6 | This script is used to 存放常用工具 7 | """ 8 | 9 | import h5py 10 | import torch 11 | from torch.utils.data import Dataset 12 | 13 | torch.manual_seed(42) # 固定种子 14 | 15 | 16 | class H5DynamicDatasetDecoder(Dataset): 17 | """ 18 | 对存储动态特征项和目标的HDF5文件进行加载和解析, 用于后续的数据集的加载训练 19 | """ 20 | 21 | def __init__(self, file_path, shuffle_feature_ix=None, dynamic=True): 22 | self.file_path = file_path 23 | self.shuffle_feature_ix = shuffle_feature_ix 24 | self.dynamic = dynamic 25 | 26 | # 获取数据集样本数 27 | with h5py.File(file_path, mode='r') as h5: 28 | self.length = h5['dynamic_features'].shape[1] 29 | self.targets = h5['targets'][:] # (12, 138488) 30 | self.dynamic_features = h5['dynamic_features'][:] # (12, 138488, 6) 31 | 32 | if self.shuffle_feature_ix is not None: 33 | shuffled_indices = torch.randperm(self.length) 34 | if self.dynamic: 35 | # 乱序索引 36 | self.dynamic_features[:, :, self.shuffle_feature_ix] = \ 37 | self.dynamic_features[:, shuffled_indices, self.shuffle_feature_ix] 38 | 39 | def __len__(self): 40 | """ 41 | 返回数据集的总样本数 42 | :return: 43 | """ 44 | 45 | return self.length 46 | 47 | def __getitem__(self, index): 48 | """ 49 | 依据索引索引返回一个样本 50 | :param index: 51 | :return: 52 | """ 53 | 54 | dynamic_feature = self.dynamic_features[:, index, :] 55 | target = self.targets[:, index] 56 | return torch.tensor(dynamic_feature, dtype=torch.float32), torch.tensor(target, dtype=torch.float32) 57 | 58 | 59 | class H5DatasetDecoder(Dataset): 60 | """ 61 | 对存储特征项和目标项的HDF5文件进行解析,用于后续的数据集加载训练 62 | """ 63 | 64 | def __init__(self, file_path, shuffle_feature_ix=None, dynamic=True): 65 | self.file_path = file_path 66 | self.shuffle_feature_ix = shuffle_feature_ix 67 | self.dynamic = dynamic 68 | 69 | # 获取数据集样本数 70 | with h5py.File(file_path, mode='r') as h5: 71 | self.length = h5['static_features1'].shape[0] 72 | self.targets = h5['targets'][:] # (12, 138488) 73 | self.dynamic_features = h5['dynamic_features'][:] # (12, 138488, 6) 74 | self.static_features1 = h5['static_features1'][:] # (138488,) 75 | self.static_features2 = h5['static_features2'][:] # (138488,) 76 | self.static_features3 = h5['static_features3'][:] # (138488,) 77 | self.static_features4 = h5['static_features4'][:] # (138488,) 78 | 79 | if self.shuffle_feature_ix is not None: 80 | shuffled_indices = torch.randperm(self.length) 81 | if self.dynamic: 82 | # 乱序索引 83 | self.dynamic_features[:, :, self.shuffle_feature_ix] = \ 84 | self.dynamic_features[:, shuffled_indices, self.shuffle_feature_ix] 85 | elif self.shuffle_feature_ix == 0: # 静态的 86 | self.static_features1 = self.static_features1[shuffled_indices] 87 | elif self.shuffle_feature_ix == 1: 88 | self.static_features2 = self.static_features2[shuffled_indices] 89 | elif self.shuffle_feature_ix == 2: 90 | self.static_features3 = self.static_features3[shuffled_indices] 91 | elif self.shuffle_feature_ix == 3: 92 | self.static_features4 = self.static_features4[shuffled_indices] 93 | 94 | def __len__(self): 95 | """ 96 | 返回数据集的总样本数 97 | :return: 98 | """ 99 | 100 | return self.length 101 | 102 | def __getitem__(self, index): 103 | """ 104 | 依据索引索引返回一个样本 105 | :param index: 106 | :return: 107 | """ 108 | 109 | dynamic_feature = self.dynamic_features[:, index, :] 110 | static_features1 = 
self.static_features1[index] 111 | static_features2 = self.static_features2[index] 112 | static_features3 = self.static_features3[index] 113 | static_features4 = self.static_features4[index] 114 | target = self.targets[:, index] 115 | 116 | static_feature = (static_features1, static_features2, static_features3, static_features4) 117 | return torch.tensor(dynamic_feature, dtype=torch.float32), \ 118 | torch.tensor(static_feature, dtype=torch.float32), torch.tensor(target, dtype=torch.float32) 119 | 120 | 121 | def cal_r2(outputs, targets): 122 | """ 123 | 计算R2决定系数 124 | :param outputs: 125 | :param targets: 126 | :return: 127 | """ 128 | mean_predictions = torch.mean(outputs, dim=0, keepdim=True) 129 | mean_targets = torch.mean(targets, dim=0, keepdim=True) 130 | predictions_centered = outputs - mean_predictions 131 | targets_centered = targets - mean_targets 132 | corr = torch.sum(predictions_centered * targets_centered, dim=0) / \ 133 | (torch.sqrt(torch.sum(predictions_centered ** 2, dim=0)) * torch.sqrt( 134 | torch.sum(targets_centered ** 2, dim=0))) 135 | 136 | return torch.mean(corr) 137 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 2023-12-17更新处理过程记录 2 | 处理过程包含两部分代码: 3 | 4 | - check_datasets.py 5 | - aver_cal.py 6 | 7 | ## 01 检查数据集完整性 8 | 9 | `check_datasets.py`用于检查数据集的完整性, 包括`MCD12Q1`为土地利用数据、 `MOD11A2`为地表温度数据、 `MOD13A2`为植被指数数据。 10 | 11 | 1. MCD12Q1数据集(土地利用|每年)检查至2001年~2021年: 12 | 13 | 正常 14 | 15 | 2. MOD11A2数据集(地表温度|8天周期)检查至2000年第48日~2022年第296日: 16 | 17 | MOD11A2*A2001169*h26v05*.hdf: 文件数目(为: 0)不正常 18 | MOD11A2*A2001169*h26v06*.hdf: 文件数目(为: 0)不正常 19 | MOD11A2*A2001169*h27v05*.hdf: 文件数目(为: 0)不正常 20 | MOD11A2*A2001169*h27v06*.hdf: 文件数目(为: 0)不正常 21 | MOD11A2*A2001177*h26v05*.hdf: 文件数目(为: 0)不正常 22 | MOD11A2*A2001177*h26v06*.hdf: 文件数目(为: 0)不正常 23 | MOD11A2*A2001177*h27v05*.hdf: 文件数目(为: 0)不正常 24 | MOD11A2*A2001177*h27v06*.hdf: 文件数目(为: 0)不正常 25 | MOD11A2*A2010121*h26v05*.hdf: 文件数目(为: 0)不正常(已下载) 26 | MOD11A2*A2010121*h26v06*.hdf: 文件数目(为: 0)不正常(已下载) 27 | MOD11A2*A2010121*h27v05*.hdf: 文件数目(为: 0)不正常(已下载) 28 | MOD11A2*A2010121*h27v06*.hdf: 文件数目(为: 0)不正常(已下载) 29 | MOD11A2*A2022289*h26v05*.hdf: 文件数目(为: 0)不正常 30 | MOD11A2*A2022289*h26v06*.hdf: 文件数目(为: 0)不正常 31 | MOD11A2*A2022289*h27v05*.hdf: 文件数目(为: 0)不正常 32 | MOD11A2*A2022289*h27v06*.hdf: 文件数目(为: 0)不正常 33 | 其余未标注`已下载`的数据集未官网缺失,目前未解决。 34 | 35 | 3. MOD13A2数据集(NDVI|16天周期)检查至2000年第48天~2020年第352天: 36 | 37 | 正常 38 | 39 | ## 02 对三大数据集进行全流程预处理 40 | 41 | `aver_cal.py`主要对三个数据集(土地利用数据集、NDVI数据集、地表温度数据集)进行镶嵌、重投影并最终输出为GeoTIFF文件。 42 | 43 | 1. MCD12Q1(土地利用)数据集具体包括读取LC_Type1数据集(IGBP分类标准)、无效值去除(无效值设定为255)、镶嵌(Last模式, 年尺度)、 44 | 重投影、(sinu ==> WGS84, 重采样为最近邻<因为土地利用数据类型为整型>), 输出分辨率为0.045°(500m), 无效值为255. 45 | 2. MOD11A2(地表温度)数据集具体包括读取LST_Day_1km、无效值去除(无效值设定为-65535)、单位换算(最终单位为摄氏度)、 46 | 镶嵌(MAX模式, 月尺度)、重投影(sinu ==> WGS84, 重采样为三次卷积), 输出分辨率为0.009°(1000m), 无效值为-65535. 47 | 3. MOD13A2(NDVI)数据集具体包括读取1 km 16 days NDVI、无效值去除(无效值设定为-65535)、单位换算、镶嵌(MAX模式, 月尺度)、 48 | 重投影(sinu ==> WGS84, 重采样为三次卷积), 输出分辨率为0.009°(1000m), 无效值为-65535. 49 | 4. MOD13A2(EVI)数据集具体包括读取1 km 16 days EVI、无效值去除(无效值设定为-65535)、单位换算、镶嵌(MAX模式, 月尺度)、 50 | 重投影(sinu ==> WGS84, 重采样为三次卷积), 输出分辨率为0.009°(1000m), 无效值为-65535. 
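
To make the reprojection step above concrete, here is a minimal illustrative sketch (an assumption-laden example, not the actual code in `process_modis.py`): it opens the `LST_Day_1km` subdataset of a single MOD11A2 granule and warps it from the native sinusoidal grid to WGS84 with cubic resampling. The HDF4 subdataset string, file name and nodata values are assumed examples; the 0.009° output resolution follows the description above.

```python
from osgeo import gdal

# Assumed example granule and HDF4-EOS subdataset path (MOD11A2 daytime LST).
src = gdal.Open('HDF4_EOS:EOS_GRID:"MOD11A2.A2010121.h26v05.061.hdf":MODIS_Grid_8Day_1km_LST:LST_Day_1km')

# Sinusoidal -> WGS84 (EPSG:4326), cubic resampling at ~0.009 degrees (~1 km).
# Landuse (MCD12Q1) would use gdal.GRA_NearestNeighbour instead, as noted above.
gdal.Warp('LST_Day_1km_wgs84.tiff', src,
          dstSRS='EPSG:4326',
          xRes=0.009, yRes=0.009,
          resampleAlg=gdal.GRA_Cubic,
          srcNodata=0, dstNodata=-65535)
src = None

# Unit conversion (scale factor 0.02 to Kelvin, then -273.15 to degrees Celsius) is a
# separate step in the actual workflow and is omitted here.
```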
51 | 52 | # 2024年1月18日更新处理过程记录 53 | 处理主要包括: 54 | 55 | - 文件名更改(将aver_cal.py更改为process_modis.py) 56 | - 编写process_gldas.py代码 57 | 58 | ## process_gldas.py 59 | 主要包括对gldas数据集(nc文件)中的`Rain_f_tavg`(降水通量), `Evap_tavg`(蒸散发通量), `Qs_acc`(表面径流量), `Qsb_acc`(地下径流量) 60 | 进行月累加值的计算分别的都月降水量、月蒸散发量、月表面径流量、月地下径流量,依据`TWSC` = 降水量 - 蒸散发量 - (表面径流量 + 地下径流量)。 61 | 需要注意,此时栅格矩阵的范围为-180~180, -60~90.另外进行了无效值的去除(设置为nan)、南北极颠倒、重采样以及重采样后范围偏移的限定。 62 | 最后输出为tiff文件,WGS84坐标系. 63 | 64 | # 2024年1月19日更新处理过程记录 65 | 66 | 处理包括: 67 | - uniform_datasets.py 68 | - 修复process_gldas.py Bug 69 | 70 | ## uniform_datasets.py 71 | 主要是进行各个数据集的统一,统一包括空间范围的限定,研究区域范围如下: 72 | ![研究区域](Assets/Images/研究区域.png) 73 | 具体是进行掩膜和裁剪至掩膜形状、以及重采样0.1° 74 | 处理的数据集包括: 75 | - Landuse 76 | - LST 77 | - NDVI 78 | - 降水(PRCP) 79 | - 蒸散发量(ET) 80 | - 地表径流量(Qs) 81 | - 地下径流量(Qsb) 82 | - TWSC 83 | 84 | ## 修复process_gldas.py Bug 85 | 86 | 主要是在gldas数据集的处理中,将所有缺失值均赋值为np.nan,虽然可以正确写入无效值。但是在uniform_datasets.py 87 | 的处理中存在无法正确识别无效值np.nan的情况,因此出现了非研究区域的像元值为0.0而非无效值nan。 88 | 因此在gldas数据集的处理中将所有无效值设置为-65535,这一操作则与前期处理NDVI、LST、Landuse的操作一致,并没有 89 | 直接赋值np.nan,这也就是为什么在uniform_datasets.py中这个三个数据集没有发生意外情况的原因。 90 | 91 | # 2024年1月22日更新处理过程记录 92 | 93 | 处理包括: 94 | - feature_engineering.py 95 | - model_train.py 96 | - utils.models.py 97 | - 安装环境requirements.txt 98 | 99 | --- 100 | ## feature_engineering.py 101 | 主要是进行各个特征项和目标项数据集的模型预输入的处理,方便后续的模型的数据加载和训练, 102 | 涉及的数据集包括: 103 | Landuse: 2001 - 2020 104 | LST: 200002 - 202210 105 | NDVI: 200002 - 202010 106 | ET: 200204 - 202309 107 | PRCP: 200204 - 202309 108 | Qs: 200204 - 202309 109 | Qsb: 200204 - 202309 110 | TWSC: 200204 - 202309 111 | dem: single 112 | 由于数据集的某些特殊例如landuse的时间分辨率(年)与其它数据集(月)不同,dem所有训练样本都固定不随时间变化,仅体现地理 113 | 位置上的差异,因此对于landuse和dem进行单独的存储(后续将根据需求进行改写或许, 可能不太方便数据加载或者说数据加载出来 114 | 到可用于模型训练还需要进行一定的数据变换,这需要一定的时间和成本) 115 | 116 | 处理仅仅将各个数据集整理如下HDF5文件格式: 117 | - group(2002) 118 | - features1 119 | 120 | _存储月分辨率的数据集(shape为(行数 * 列数<即样本数>, 时间步, 特征数)), 特征项具体为: LST、PRCP、ET、Qs、Qsb、TWSC_ 121 | - features2 122 | 123 | _存储年分辨率的数据集(shape为(行数 * 列数)), 特征项具体为: Landuse_ 124 | - targets 125 | 126 | _存储月分辨率的数据集(shape为(行数 * 列数, 时间步)), 目标项具体为: NDVI_ 127 | - group(2003) 128 | - features1 129 | - features2 130 | - targets 131 | 132 | - ······ 133 | - dem 134 | 135 | _存储DEM的数据集, shape为(行数 * 列数)_ 136 | 137 | --- 138 | ## model_train.py 139 | 140 | 主要是模型的构建的训练以及评估(处于优化中, 依据后续任务进行框架的完善) 141 | 模型这里CNN-LSTM模型, cnn为一维卷积且在时间维度上卷积, lstm为常规模型. 
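
For orientation, a minimal sketch of this architecture (1-D convolution along the time axis, then an LSTM, then a fully connected head that outputs 12 monthly values) is given below. Layer sizes and the feature count are illustrative assumptions and do not reproduce the exact `LSTMModel` in `utils/models.py`:

```python
import torch
from torch import nn

class TinyCNNLSTM(nn.Module):
    def __init__(self, n_features=6, hidden_size=64, out_steps=12):
        super().__init__()
        # Conv1d expects (batch, channels, time); the feature dimension acts as channels.
        self.conv = nn.Conv1d(n_features, 32, kernel_size=3, padding=1)
        self.lstm = nn.LSTM(32, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, out_steps)

    def forward(self, x):                       # x: (batch, time, features)
        x = self.conv(x.transpose(1, 2))        # -> (batch, 32, time)
        x, _ = self.lstm(x.transpose(1, 2))     # -> (batch, time, hidden)
        return self.fc(x[:, -1, :])             # last time step -> (batch, 12)

# Shape check with a dummy batch: 4 samples, 12 months, 6 dynamic features.
print(TinyCNNLSTM()(torch.randn(4, 12, 6)).shape)  # torch.Size([4, 12])
```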
142 | 目前完成某年部分样本的训练、评估。 143 | 这里暂未完全定型,只算草稿版本,后续将完善。 144 | 145 | --- 146 | ## utils.models.py 147 | 148 | 这里存储定义的模型,目前定义的编码解码lstm模型实际效果不如前面的cnn-lstm模型,暂时搁置。 149 | 150 | --- 151 | ## requirements.txt 152 | 153 | 考虑的后续项目的维护和迁移,这里增加环境配置的相关信息,配置代码: 154 | 155 | ```shell 156 | pip install -r I:\PyProJect\Veg\VEG\requirements.txt 157 | ``` 158 | 159 | # 2024年02月29日处理记录 160 | 161 | ## 增加NDVI、LST的MEAN、MIN处理 162 | 163 | 修改process_modis.py相关参数,进行NDVI、LST的MEAN、MIN计算 164 | 165 | ## uniform_datasets.py, feature_engineering.py调整 166 | 167 | # 2024年05月09日处理 168 | 169 | 新增process_Rs.py文件, 用于处理Rs地表太阳辐射数据, 主要是各个月份的影像nc转tiff,顺便做了一下掩膜、裁剪和重采样 170 | 并输出了lon和lat数据集为tiff文件, 方便后续作为变量输入到模型中 171 | 172 | # 2024/5/11处理 173 | 完善feature_engineering.py文件, 新增关于Rs(dynamic)、lon(static)、lat(static)的特征输入 -------------------------------------------------------------------------------- /Core/process_gldas.py: -------------------------------------------------------------------------------- 1 | # @Author : ChaoQiezi 2 | # @Time : 2024/1/17 12:41 3 | # @FileName : process_gldas.py 4 | # @Email : chaoqiezi.one@qq.com 5 | 6 | """ 7 | This script is used to 预处理global gldas数据集 8 | 9 | 说明: 10 | 为确保简洁性和便捷性, 今后读取HDF5文件和NC文件均使用xarray模块而非h5py和NetCDF4模块 11 | 数据集介绍: 12 | TWSC = 降水量(PRCP) - 蒸散发量(ET) - 径流量(即表面径流量Qs + 地下径流量Qsb) ==> 给定时间间隔内, 例如月 13 | 在gldas数据集中: 14 | Rainf_f_tavg表示降水通量,即单位时间单位面积上的降水量(本数据集单位为kg/m2/s) 15 | Evap_tavg表示蒸散发通量,即单位时间单位面积上的水蒸发量(本数据集单位为kg/m2/s) 16 | Qs_acc表示表面径流量,即一定时间内通过地表流动进入河流、湖泊和水库的水量(本数据集单位为kg/m2) 17 | Qsb_acc表示地下径流量,即一定时间内通过土壤层流动的水量,最终进入河流的水量,最终进入河流的水量(本数据集单位为kg/m2) 18 | TWSC计算了由降水和蒸发引起的净水量变化,再减去地表和地下径流,其评估给定时间段内区域水资源变化的重要指标 19 | 20 | 存疑: 21 | 01 对于Qs和Qsb的计算, 由于数据集单位未包含/s, 是否已经是月累加值? --2024/01/18(已解决) 22 | ==> 由gldas_tws_eg.py知是: numbers of 3 hours in a month, 23 | 另外nc文件全局属性也提及: 24 | :tavg_definision: = "past 3-hour average"; 25 | :acc_definision: = "past 3-hour accumulation"; 26 | 27 | """ 28 | 29 | import os.path 30 | from glob import glob 31 | from calendar import monthrange 32 | from datetime import datetime 33 | 34 | import numpy as np 35 | import xarray as xr 36 | from osgeo import gdal, osr 37 | 38 | # 准备 39 | in_dir = r'E:\Global GLDAS' # 检索该文件夹及迭代其所有子文件夹满足要求的文件 40 | out_dir = r'E:\FeaturesTargets\non_uniform' 41 | target_names = ['Rainf_f_tavg', 'Evap_tavg', 'Qs_acc', 'Qsb_acc'] 42 | out_names = ['PRCP', 'ET', 'Qs', 'Qsb', 'TWSC'] 43 | out_res = 0.1 # default: 0.25°, base on default res of gldas 44 | no_data_value = -65535.0 # 缺失值或者无效值的设置 45 | # 预准备 46 | [os.makedirs(os.path.join(out_dir, _name)) for _name in out_names if not os.path.exists(os.path.join(out_dir, _name))] 47 | 48 | # 检索和循环 49 | nc_paths = glob(os.path.join(in_dir, '**', 'GLDAS_NOAH025_M*.nc4'), recursive=True) 50 | for nc_path in nc_paths: 51 | # 获取当前月天数 52 | cur_time = datetime.strptime(nc_path.split('.')[1], 'A%Y%m') # eg. 
200204 53 | _, cur_month_days = monthrange(cur_time.year, cur_time.month) 54 | 55 | ds = xr.open_dataset(nc_path) 56 | # 读取经纬度数据集和地理参数 57 | lon = ds['lon'].values # (1440, ) 58 | lat = ds['lat'].values # (600, ) 59 | lon_res = ds.attrs['DX'] 60 | lat_res = ds.attrs['DY'] 61 | lon_min = min(lon) - lon_res / 2.0 62 | lon_max = max(lon) + lon_res / 2.0 63 | lat_min = min(lat) - lat_res / 2.0 64 | lat_max = max(lat) + lat_res / 2.0 65 | """ 66 | 注意: 经纬度数据集中的所有值均指代对应地理位置的像元的中心处的经纬度, 因此经纬度范围需要往外扩充0.5个分辨率 67 | """ 68 | geo_transform = [lon_min, lon_res, 0, lat_max, 0, -lat_res] # gdal要求样式 69 | srs = osr.SpatialReference() 70 | srs.ImportFromEPSG(4326) # WGS84 71 | 72 | fluxs = {} 73 | # 获取Rain_f_tavg, Evap_tavg, Qs_acc, Qsb_acc四个数据集 74 | for target_name, out_name in zip(target_names, out_names): # 仅循环前四次 75 | # 计算月累加值 76 | flux = ds[target_name].values 77 | vmin = ds[target_name].attrs['vmin'] 78 | vmax = ds[target_name].attrs['vmax'] 79 | flux[(flux < vmin) | (flux > vmax)] = np.nan # 将不在规定范围内的值设置为nan 80 | flux = np.squeeze(flux) # 去掉多余维度 81 | flux = np.flipud(flux) # 南北极颠倒(使之正常: 北极在上) 82 | if target_name.endswith('acc'): # :acc_definision: = "past 3-hour accumulation"; 83 | flux *= cur_month_days * 8 84 | elif target_name.endswith('tavg'): # :tavg_definision: = "past 3-hour average"; 85 | flux *= cur_month_days * 24 * 3600 86 | fluxs[out_name] = flux 87 | 88 | fluxs['TWSC'] = fluxs['PRCP'] - fluxs['ET'] - (fluxs['Qs'] + fluxs['Qsb']) # 计算TWSC 89 | for out_name, flux in fluxs.items(): 90 | # 输出路径 91 | cur_out_name = 'GLDAS_{}_{:04}{:02}.tiff'.format(out_name, cur_time.year, cur_time.month) 92 | cur_out_path = os.path.join(out_dir, out_name, cur_out_name) 93 | 94 | driver = gdal.GetDriverByName('MEM') # 在内存/TIFF中创建 95 | temp_img = driver.Create('', flux.shape[1], flux.shape[0], 1, gdal.GDT_Float32) 96 | temp_img.SetProjection(srs.ExportToWkt()) # 设置坐标系 97 | temp_img.SetGeoTransform(geo_transform) # 设置仿射参数 98 | flux = np.nan_to_num(flux, nan=no_data_value) 99 | temp_img.GetRasterBand(1).WriteArray(flux) # 写入数据集 100 | temp_img.GetRasterBand(1).SetNoDataValue(no_data_value) # 设置无效值 101 | resample_img = gdal.Warp(cur_out_path, temp_img, xRes=out_res, yRes=out_res, resampleAlg=gdal.GRA_Cubic) # 重采样 102 | # 去除由于重采样造成的数据集不符合实际意义例如降水为负值等情况 103 | vmin = np.nanmin(flux) 104 | vmax = np.nanmax(flux) 105 | flux = resample_img.GetRasterBand(1).ReadAsArray() 106 | resample_img_srs = resample_img.GetProjection() 107 | resample_img_transform = resample_img.GetGeoTransform() 108 | temp_img, resample_img = None, None # 释放资源 109 | flux[flux < vmin] = vmin 110 | flux[flux > vmax] = vmax 111 | driver = gdal.GetDriverByName('GTiff') 112 | final_img = driver.Create(cur_out_path, flux.shape[1], flux.shape[0], 1, gdal.GDT_Float32) 113 | final_img.SetProjection(resample_img_srs) 114 | final_img.SetGeoTransform(resample_img_transform) 115 | final_img.GetRasterBand(1).WriteArray(flux) 116 | final_img.GetRasterBand(1).SetNoDataValue(no_data_value) 117 | final_img.FlushCache() 118 | temp_img, final_img = None, None 119 | 120 | print('当前处理: {}-{}'.format(out_name, cur_time.strftime('%Y%m'))) 121 | 122 | ds.close() # 关闭当前nc文件,释放资源 123 | print('处理完成') -------------------------------------------------------------------------------- /Core/Plot/density_box_plot.py: -------------------------------------------------------------------------------- 1 | # @Author : ChaoQiezi 2 | # @Time : 2024/3/11 18:58 3 | # @Email : chaoqiezi.one@qq.com 4 | 5 | """ 6 | This script is used to 是用来绘图滴,主要是箱线图和核密度散点图 7 | """ 8 | 9 | import os 10 | import 
numpy as np 11 | import pandas as pd 12 | import matplotlib.pyplot as plt 13 | from scipy.stats import gaussian_kde 14 | import seaborn as sns 15 | from osgeo import gdal 16 | from matplotlib.colors import LinearSegmentedColormap 17 | 18 | # 准备 19 | in_path = r'H:\Datasets\Objects\Veg\Plot\cor_by_st.csv' 20 | dem_path = r'H:\Datasets\Objects\Veg\DEM\dem_1km.tif' 21 | out_dir =r'H:\Datasets\Objects\Veg\Plot' 22 | sns.set_style('darkgrid') # 设置风格 23 | plt.rcParams['font.sans-serif'] = ['Times New Roman'] 24 | plt.rcParams['axes.unicode_minus'] = False # 允许负号正常显示 25 | 26 | # 加载数据 27 | df = pd.read_csv(in_path) 28 | dem = gdal.Open(dem_path) 29 | dem_raster = dem.GetRasterBand(1).ReadAsArray() # 获取dem栅格矩阵 30 | dem_nodata_value = dem.GetRasterBand(1).GetNoDataValue() # 获取无效值 31 | lon_ul, lon_res, _, lat_ul, _, lat_res_negative = dem.GetGeoTransform() # [左上角经度, 经度分辨率, 旋转角度, 左上角纬度, 旋转角度, -纬度分辨率] 32 | lat_res = -lat_res_negative 33 | # 删除TWSC列, 将TWSC_SH列标签换为TWSC 34 | df.drop(['TWSC', 'TWSC_1', 'TWSC_2', 'TWSC_3'], axis=1, inplace=True) 35 | df.rename(columns={'TWSC_SH': 'TWSC', 'TWSC_SH_1': 'TWSC_1', 'TWSC_SH_2': 'TWSC_2', 'TWSC_SH_3': 'TWSC_3'}, inplace=True) 36 | iter_columns_name = df.columns[4:] 37 | # 色带 38 | colors = ['#ff0000', '#ff6f00', '#fbb700', '#cdff00', '#a1ff6e', '#52ffc7', '#00ffff', '#15acff', '#4261ff', '#3100fe'] 39 | colors.reverse() 40 | cm = LinearSegmentedColormap.from_list('common', colors, 100) 41 | 42 | # 添加DEM列 43 | cols = np.floor((df['Lon'] - lon_ul) / lon_res).astype(int) 44 | rows = np.floor((lat_ul - df['Lat']) / lat_res).astype(int) 45 | df['DEM'] = dem_raster[rows, cols] 46 | df[df['DEM'] == dem_nodata_value] = np.nan 47 | # 绘制散点核密度图 48 | for column_name in iter_columns_name: 49 | plt.figure(dpi=200) 50 | cur_ds = df[['DEM', column_name]].dropna(how='any') 51 | cur_ds['Density'] = gaussian_kde(cur_ds[column_name])(cur_ds[column_name]) 52 | 53 | scatter = plt.scatter(x='DEM', y=column_name, c='Density', cmap=cm, linewidth=0, data=cur_ds, s=20) 54 | clb = plt.colorbar(scatter) 55 | clb.ax.set_title('Density', fontsize=8) # 为色带添加标题 56 | # sns.kdeplot(x='DEM', y=column_name, fill=True, data=cur_ds, alpha=0.6) 57 | sns.kdeplot(x='DEM', y=column_name, fill=False, color='gray', data=cur_ds, alpha=0.6) 58 | title_name = 'Scatter kernel density map of $R^2$ \n between NDVI and {} under DEM'.format(column_name) 59 | plt.title(title_name, fontsize=16) 60 | plt.xlabel('DEM(m)', fontsize=14) 61 | plt.ylabel('$R^2$ between NDVI and {}'.format(column_name), fontsize=14) 62 | plt.xticks(fontsize=12) 63 | plt.yticks(fontsize=12) 64 | # 设置XY轴起始值 65 | plt.xlim(left=0) 66 | plt.ylim(bottom=0) 67 | plt.savefig(os.path.join(out_dir, 'R2_{}.png'.format(column_name)), dpi=200) 68 | # plt.show() 69 | print('处理: {}'.format(column_name)) 70 | # 绘制箱线图 71 | meanprops = {"marker":"o", "markerfacecolor":"white", "markeredgecolor":"black", "markersize":"10"} 72 | fig, axs = plt.subplots(4, 1, figsize=(13, 18), dpi=432) 73 | axs = axs.flatten() 74 | fig.suptitle('Box plot of NDVI and correlation coefficients of each variable', fontsize=30, va='top') 75 | for ix, ax in enumerate(axs): 76 | # print(iter_columns_name[(ix * 9):((ix + 1) * 9)]) 77 | # ax.figure(figsize=(26, 9), dpi=321) 78 | df_melt = pd.melt(df, value_vars=iter_columns_name[(ix * 8):((ix + 1) * 8)]).dropna(how='any') 79 | sns.boxplot(data=df_melt, x='variable', y='value', palette=cm(np.linspace(0, 1, 9)), ax=ax, linewidth=3, 80 | showmeans=True, meanprops=meanprops) 81 | ax.set_xlabel('', fontsize=25) 82 | ax.set_ylabel('$R^2$', 
fontsize=25) 83 | ax.tick_params(axis='x', labelsize=18) # x轴标签旋转90度 84 | ax.tick_params(axis='y', labelsize=18) 85 | ax.grid(True) 86 | plt.tight_layout(pad=2) 87 | fig.savefig(os.path.join(out_dir, 'Box_R2.png')) 88 | # plt.show() 89 | 90 | 91 | # 用于看 92 | fig, axs = plt.subplots(4, 1, figsize=(13, 18), dpi=432) 93 | axs = axs.flatten() 94 | fig.suptitle('Box plot of NDVI and correlation coefficients of each variable', fontsize=30, va='top') 95 | 96 | for ix, ax in enumerate(axs): 97 | columns_slice = iter_columns_name[(ix * 9):((ix + 1) * 9)] 98 | df_melt = pd.melt(df, value_vars=columns_slice).dropna(how='any') 99 | # sns.boxplot(data=df_melt, x='variable', y='value', palette=cm(np.linspace(0, 1, 9)), ax=ax, linewidth=3, 100 | # showmeans=True, meanprops=meanprops) 101 | sns.boxplot(data=df_melt, x='variable', y='value', palette=cm(np.linspace(0, 1, 9)), ax=ax, linewidth=3) 102 | 103 | # 循环每个变量,计算最大值、最小值和平均值,然后在图上标注 104 | for i, column in enumerate(columns_slice): 105 | subset = df[column].dropna() 106 | max_val = subset.max() 107 | min_val = subset.min() 108 | mean_val = subset.mean() 109 | 110 | # 标注最大值、最小值和平均值 111 | ax.text(i, max_val, f'{max_val:.2f}', ha='center', va='bottom', fontsize=16, rotation=45) 112 | ax.text(i, min_val, f'{min_val:.2f}', ha='center', va='top', fontsize=16, rotation=45) 113 | ax.text(i, mean_val, f'{mean_val:.2f}', ha='center', va='center', fontsize=16, color='white', rotation=45) 114 | 115 | ax.set_xlabel('', fontsize=25) 116 | ax.set_ylabel('$R^2$', fontsize=25) 117 | ax.tick_params(axis='x', labelsize=18) 118 | ax.tick_params(axis='y', labelsize=18) 119 | ax.grid(True) 120 | 121 | plt.tight_layout(pad=2) 122 | fig.savefig(os.path.join(out_dir, 'Box_R2_quick.png')) 123 | # plt.show() -------------------------------------------------------------------------------- /utils/models.py: -------------------------------------------------------------------------------- 1 | # @Author : ChaoQiezi 2 | # @Time : 2024/1/15 16:01 3 | # @Email : chaoqiezi.one@qq.com 4 | 5 | """ 6 | This script is used to 定义模型""" 7 | 8 | import torch 9 | from torch import nn 10 | import torch.nn.functional as F 11 | from torchsummary import summary 12 | 13 | 14 | class Encoder(nn.Module): 15 | def __init__(self, 16 | input_size=6, 17 | embedding_size=128, 18 | hidden_size=256, 19 | lstm_layers=3, 20 | dropout=0.5): 21 | super().__init__() 22 | self.fc = nn.Linear(input_size, embedding_size) 23 | self.rnn = nn.LSTM(embedding_size, hidden_size, lstm_layers, batch_first=True, dropout=dropout) 24 | self.dropout = nn.Dropout(dropout) 25 | 26 | def forward(self, x): 27 | embedded = self.dropout(F.relu(self.fc(x))) 28 | output, (hidden, cell) = self.rnn(embedded) 29 | 30 | return hidden, cell 31 | 32 | 33 | class Decoder(nn.Module): 34 | def __init__(self, 35 | input_size=6, 36 | embedding_size=128, 37 | hidden_size=256, 38 | output_size=1, 39 | lstm_layers=3, 40 | dropout=0.5): 41 | super().__init__() 42 | self.embedding = nn.Linear(input_size, embedding_size) 43 | self.rnn = nn.LSTM(embedding_size, hidden_size, lstm_layers, dropout=dropout) 44 | self.fc = nn.Linear(hidden_size, output_size) 45 | self.dropout = nn.Dropout(dropout) 46 | 47 | def forward(self, x, hidden, cell): 48 | # x = x.unsqueeze(0) # add dimension, original shape=(batch_size, feature_size), ==> (1, batch_size, feature) 49 | 50 | embedded = self.dropout(F.relu(self.embedding(x))) 51 | output, (hidden, cell) = self.rnn(embedded, (hidden, cell)) 52 | prediction = self.fc(output[-1]).squeeze() 53 | 54 | return prediction, 
hidden, cell 55 | 56 | 57 | class Seq2Seq(nn.Module): 58 | def __init__(self, encoder, decoder, device): 59 | super().__init__() 60 | self.encoder = encoder 61 | self.decoder = decoder 62 | self.device = device 63 | 64 | def forward(self, x, teacher_forcing_ratio=0.5): 65 | # time_step = y.shape[1] 66 | time_step = 12 67 | hidden, cell = self.encoder(x) 68 | 69 | outputs = torch.zeros((x.shape[0], time_step)).to(self.device) 70 | # decoder_input = x[:, -1, :] 71 | 72 | for time_ix in range(time_step): 73 | decoder_input = x[:, time_ix:time_ix+1, :] 74 | decoder_input = torch.transpose(decoder_input, 0, 1) 75 | output, hidden, cell = self.decoder(decoder_input, hidden, cell) 76 | outputs[:, time_ix] = output 77 | 78 | # teacher_forcing = random.random() < teacher_forcing_ratio 79 | 80 | # decoder_input = y[:, time_ix] if teacher_forcing else output 81 | 82 | return outputs 83 | 84 | # 创建编码解码的lstm模型 85 | encoder = Encoder(6, 128, 256) 86 | decoder = Decoder(6, 128, 256) 87 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 88 | model = Seq2Seq(encoder, decoder, device) 89 | summary(encoder, (12, 6)) 90 | summary(model, (12, 6)) 91 | 92 | 93 | class LSTMModel(nn.Module): 94 | def __init__(self, input_size, hidden_size, num_layers, output_size): 95 | """ 96 | :param input_size: 输入动态特征项数 97 | :param hidden_size: LSTM的隐藏层大小 98 | :param num_layers: LSTM层数 99 | :param output_size: 输出时间序列长度(default: 12, 12个月份) 100 | """ 101 | super().__init__() 102 | self.causal_conv1d = nn.Conv1d(input_size, 128, 5) 103 | self.fc1 = nn.Linear(4, 128) 104 | self.rnn = nn.LSTM(128, hidden_size, num_layers, batch_first=True) 105 | self.fc2 = nn.Linear(384, output_size) 106 | 107 | def forward(self, dynamic_x, static_x): 108 | # 因果卷积 109 | conv1d_out = (self.causal_conv1d(F.pad(torch.transpose(dynamic_x, 1, 2), (2, 0)))) 110 | # conv1d_out = self.causal_conv1d(F.pad(torch.transpose(dynamic_x, 1, 2), (2, 0))) 111 | # conv1d_out = self.causal_conv1d(torch.transpose(dynamic_x, 1, 2)) 112 | # LSTM层 113 | lstm_out, _ = self.rnn(torch.transpose(conv1d_out, 1, 2)) 114 | # 只使用最后一个时间步的输出 115 | lstm_out = lstm_out[:, -1, :] # (-1, 256) 116 | static_out = self.fc1(static_x) # (-1, 2) ==> (-1, 128) 117 | # static_out = self.fc1(static_x) # (-1, 2) ==> (-1, 128) 2024/5/11: 静态特征由2变为4, 新增Lon、Lat 118 | merged_out = torch.cat([lstm_out, static_out], dim=1) # (-1, 256 + 128) 119 | # 全连接层 120 | out = self.fc2(merged_out) # (-1, 12) 121 | 122 | return out 123 | 124 | 125 | class LSTMModelDynamic(nn.Module): 126 | def __init__(self, input_size, hidden_size, num_layers, output_size): 127 | """ 128 | :param input_size: 输入动态特征项数 129 | :param hidden_size: LSTM的隐藏层大小 130 | :param num_layers: LSTM层数 131 | :param output_size: 输出时间序列长度(default: 12, 12个月份) 132 | """ 133 | super().__init__() 134 | self.causal_conv1d = nn.Conv1d(input_size, 128, 5) 135 | # self.fc1 = nn.Linear(4, 128) 136 | self.rnn = nn.LSTM(128, hidden_size, num_layers, batch_first=True) 137 | self.fc2 = nn.Linear(hidden_size, output_size) 138 | 139 | def forward(self, dynamic_x): 140 | # 因果卷积 141 | conv1d_out = (self.causal_conv1d(F.pad(torch.transpose(dynamic_x, 1, 2), (2, 0)))) 142 | # conv1d_out = self.causal_conv1d(F.pad(torch.transpose(dynamic_x, 1, 2), (2, 0))) 143 | # conv1d_out = self.causal_conv1d(torch.transpose(dynamic_x, 1, 2)) 144 | # LSTM层 145 | lstm_out, _ = self.rnn(torch.transpose(conv1d_out, 1, 2)) 146 | # 只使用最后一个时间步的输出 147 | lstm_out = lstm_out[:, -1, :] # (-1, 256) 148 | # static_out = self.fc1(static_x) # (-1, 2) ==> (-1, 128) 149 | # static_out = 
self.fc1(static_x) # (-1, 2) ==> (-1, 128) 2024/5/11: 静态特征由2变为4, 新增Lon、Lat 150 | # merged_out = torch.cat([lstm_out, static_out], dim=1) # (-1, 256 + 128) 151 | # 全连接层 152 | out = self.fc2(lstm_out) # (-1, 12) 153 | 154 | return out -------------------------------------------------------------------------------- /Core/uniform_datasets.py: -------------------------------------------------------------------------------- 1 | # @Author : ChaoQiezi 2 | # @Time : 2024/1/3 16:51 3 | # @Email : chaoqiezi.one@qq.com 4 | 5 | """ 6 | This script is used to 对各个数据集进行统一,例如空间范围() 7 | 8 | 主要包括: 对modis(土地利用、ndvi、地表温度)、geo(DEM等)、gldas数据集进行重采样, 范围限定(裁剪至掩膜形状) 9 | """ 10 | 11 | import os.path 12 | from glob import glob 13 | from concurrent.futures import ThreadPoolExecutor # 线程池 14 | 15 | from osgeo import gdal 16 | 17 | # 准备 18 | in_dir = r'E:\FeaturesTargets\non_uniform' 19 | out_dir = r'E:\FeaturesTargets\uniform' 20 | shp_path = r'E:\Basic\Region\sw5f\sw5_mask.shp' 21 | dem_path = r'E:\GEO\cndem01.tif' 22 | out_res = 0.1 23 | 24 | 25 | def resample_clip_mask(in_dir: str, out_dir: str, shp_path: str, wildcard: str, out_res: float = 0.1, 26 | resampleAlg=gdal.GRA_Cubic): 27 | """ 28 | 该函数用于对指定文件夹内的影像进行批量重采样和裁剪、掩膜 29 | :param in_dir: 待处理文件所在文件夹目录 30 | :param out_dir: 输出文件的文件夹目录 31 | :param shp_path: 掩膜裁剪的shp文件 32 | :param wildcard: 检索输入文件夹内指定文件的通配符 33 | :param out_res: 输出分辨率 34 | :param resampleAlg: 重采样方法 35 | :return: None 36 | """ 37 | 38 | if not os.path.exists(out_dir): os.makedirs(out_dir) 39 | 40 | target_paths = glob(os.path.join(in_dir, wildcard)) 41 | for target_path in target_paths: 42 | out_path = os.path.join(out_dir, os.path.basename(target_path)) 43 | 44 | img = gdal.Warp( 45 | out_path, # 输出位置 46 | target_path, # 源文件位置 47 | cutlineDSName=shp_path, # 掩膜裁剪所需文件 48 | cropToCutline=True, # 裁剪至掩膜形状 49 | xRes=out_res, # X方向分辨率 50 | yRes=out_res, # Y方向分辨率 51 | resampleAlg=resampleAlg # 重采样方法 52 | ) 53 | img = None 54 | 55 | print('目前已处理: {}'.format(os.path.splitext(os.path.basename(target_path))[0])) 56 | 57 | 58 | # # 处理土地利用数据集 59 | # in_landuse_dir = os.path.join(in_dir, 'Landuse') 60 | # out_landuse_dir = os.path.join(out_dir, 'Landuse') 61 | # resample_clip_mask(in_landuse_dir, out_landuse_dir, shp_path, 'Landuse*.tiff', resampleAlg=gdal.GRA_NearestNeighbour) 62 | # # 处理地表温度数据集 63 | # in_lst_dir = os.path.join(in_dir, 'LST') 64 | # out_lst_dir = os.path.join(out_dir, 'LST') 65 | # resample_clip_mask(in_lst_dir, out_lst_dir, shp_path, 'LST*.tiff') 66 | # # 处理NDVI数据集 67 | # in_ndvi_dir = os.path.join(in_dir, 'NDVI') 68 | # out_ndvi_dir = os.path.join(out_dir, 'NDVI') 69 | # resample_clip_mask(in_ndvi_dir, out_ndvi_dir, shp_path, 'NDVI*.tiff') 70 | # # 处理ET(蒸散发量)数据集 71 | # in_et_dir = os.path.join(in_dir, 'ET') 72 | # out_et_dir = os.path.join(out_dir, 'ET') 73 | # resample_clip_mask(in_et_dir, out_et_dir, shp_path, 'GLDAS_ET*.tiff') 74 | # # 处理降水数据集 75 | # in_prcp_dir = os.path.join(in_dir, 'PRCP') 76 | # out_prcp_dir = os.path.join(out_dir, 'PRCP') 77 | # resample_clip_mask(in_prcp_dir, out_prcp_dir, shp_path, 'GLDAS_PRCP*.tiff') 78 | # # 处理Qs(表面径流量)数据集 79 | # in_qs_dir = os.path.join(in_dir, 'Qs') 80 | # out_qs_dir = os.path.join(out_dir, 'Qs') 81 | # resample_clip_mask(in_qs_dir, out_qs_dir, shp_path, 'GLDAS_Qs*.tiff') 82 | # # 处理Qsb(地下径流量)数据集 83 | # in_qsb_dir = os.path.join(in_dir, 'Qsb') 84 | # out_qsb_dir = os.path.join(out_dir, 'Qsb') 85 | # resample_clip_mask(in_qsb_dir, out_qsb_dir, shp_path, 'GLDAS_Qsb*.tiff') 86 | # # 处理TWSC数据集 87 | # in_twsc_dir = os.path.join(in_dir, 'TWSC') 88 | # 
out_twsc_dir = os.path.join(out_dir, 'TWSC') 89 | # resample_clip_mask(in_twsc_dir, out_twsc_dir, shp_path, 'GLDAS_TWSC*.tiff') 90 | # 处理DEM数据集 91 | # out_dem_path = os.path.join(out_dir, 'dem.tiff') 92 | # img = gdal.Warp( 93 | # out_dem_path, 94 | # dem_path, 95 | # cutlineDSName=shp_path, 96 | # cropToCutline=True, 97 | # xRes=out_res, 98 | # yRes=out_res, 99 | # resampleAlg=gdal.GRA_Cubic 100 | # ) 101 | # img = None 102 | 103 | # 并行处理(加快处理速度) 104 | datasets_param = { 105 | 'Landuse': 'Landuse*.tiff', 106 | 'LST_MEAN': 'LST_MEAN*.tiff', 107 | 'LST_MAX': 'LST_MAX*.tiff', 108 | 'LST_MIN': 'LST_MIN*.tiff', 109 | 'NDVI_MEAN': 'NDVI_MEAN*.tiff', 110 | 'NDVI_MAX': 'NDVI_MAX*.tiff', 111 | 'NDVI_MIN': 'NDVI_MIN*.tiff', 112 | 'ET': 'GLDAS_ET*.tiff', 113 | 'PRCP': 'GLDAS_PRCP*.tiff', 114 | 'Qs': 'GLDAS_Qs*.tiff', 115 | 'Qsb': 'GLDAS_Qsb*.tiff', 116 | 'TWSC': 'GLDAS_TWSC*.tiff', 117 | 118 | } 119 | 120 | if __name__ == '__main__': 121 | with ThreadPoolExecutor() as executor: 122 | futures = [] 123 | for dataset_name, wildcard in datasets_param.items(): 124 | in_dataset_dir = os.path.join(in_dir, dataset_name) 125 | out_dataset_dir = os.path.join(out_dir, dataset_name) 126 | resampleAlg = gdal.GRA_NearestNeighbour if dataset_name == 'Landuse' else gdal.GRA_Cubic 127 | futures.append(executor.submit(resample_clip_mask, in_dataset_dir, out_dataset_dir, shp_path, 128 | wildcard, resampleAlg=resampleAlg)) 129 | # 处理DEM 130 | out_dem_path = os.path.join(out_dir, 'dem.tiff') 131 | futures.append(executor.submit(gdal.Warp, out_dem_path, dem_path, cutlineDSName=shp_path, 132 | cropToCutline=True, xRes=out_res, yRes=out_res, resampleAlg=gdal.GRA_Cubic)) 133 | # 等待所有数据集处理完成 134 | for future in futures: 135 | future.result() 136 | 137 | # 处理DEM数据集 138 | """ 139 | 下述代码比较冗余, 简化为resample_clip_mask函数 140 | ---------------------------------------------------------------------- 141 | # 处理地表温度数据 142 | lst_paths = glob(os.path.join(lst_dir, 'LST*.tiff')) 143 | out_lst_dir = os.path.join(out_dir, lst_dir.split('\\')[-1]) 144 | if not os.path.exists(out_lst_dir): os.makedirs(out_lst_dir) 145 | for lst_path in lst_paths: 146 | out_path = os.path.join(out_lst_dir, os.path.basename(lst_path)) 147 | 148 | # 重采样、掩膜和裁剪 149 | gdal.Warp( 150 | out_path, 151 | lst_path, 152 | xRes=out_res, 153 | yRes=out_res, 154 | cutlineDSName=shp_path, # 设置掩膜 shp文件 155 | cropToCutline=True, # 裁剪至掩膜形状 156 | resampleAlg=gdal.GRA_Cubic # 重采样方法: 三次卷积 157 | ) 158 | print('目前已处理: {}'.format(os.path.splitext(os.path.basename(lst_path))[0])) 159 | 160 | # 处理ndvi数据集 161 | ndvi_paths = glob(os.path.join(ndvi_dir, 'NDVI*.tiff')) 162 | out_ndvi_dir = os.path.join(out_dir, ndvi_dir.split('\\')[-1]) 163 | if not os.path.exists(out_ndvi_dir): os.makedirs(out_ndvi_dir) 164 | for ndvi_path in ndvi_paths: 165 | out_path = os.path.join(out_ndvi_dir, os.path.basename(ndvi_path)) 166 | out_path = os.path.join(out_ndvi_dir, 'NDVI_temp.tiff') 167 | gdal.Warp( 168 | out_path, 169 | ndvi_path, 170 | cutlineDSName=shp_path, # 设置掩膜 shp文件 171 | cropToCutline=True, # 是否裁剪至掩膜形状 172 | xRes=out_res, 173 | yRes=out_res, 174 | resampleAlg=gdal.GRA_Cubic # 重采样方法: 三次卷积 175 | ) 176 | """ 177 | -------------------------------------------------------------------------------- /Core/feature_engineering.py: -------------------------------------------------------------------------------- 1 | # @Author : ChaoQiezi 2 | # @Time : 2024/1/19 3:12 3 | # @Email : chaoqiezi.one@qq.com 4 | 5 | """ 6 | This script is used to 包括数据集的整合以支持输入到模型中训练,以及特征工程 7 | 8 | 各个数据集的时间范围: 9 | 10 | Landuse: 
2001 - 2020 11 | LST(MEAN/MIN/MAX): 200002 - 202210 12 | NDVI(MEAN/MIN/MAX): 200002 - 202010 13 | ET: 200204 - 202309 14 | PRCP: 200204 - 202309 15 | Qs: 200204 - 202309 16 | Qsb: 200204 - 202309 17 | TWSC: 200204 - 202309 18 | dem: single 19 | 20 | 输出的nc文件的数据格式: 21 | - group(year) 22 | - features1 -> (None, time_step, features_count), e.g. (184, 139, 12 or other, 6) 23 | 7: LST, PRCP, ET, Qs, Qsb, TWSC 24 | - features2 -> (None, ), Landuse, (184 * 139) 25 | - targets -> (None, time_step), NDVI, (184 * 139, 12) 26 | - features3 -> dem 27 | 28 | 2024/5/11 新增关于Rs地表太阳辐射和经纬度数据集的添加 29 | 由于Rs时间范围为1983-2017年6月份, 因此此处公共部分的使用日期缩短到2016年12月份. 30 | 31 | 2024/5/11 仅使用LST,PRCP,ET,RS四个变量进行特征构建 32 | """ 33 | 34 | from datetime import datetime 35 | import os 36 | import re 37 | from glob import glob 38 | 39 | import netCDF4 as nc 40 | import numpy as np 41 | from osgeo import gdal 42 | import h5py 43 | import torch 44 | from sklearn.preprocessing import MinMaxScaler, StandardScaler, scale 45 | 46 | 47 | def read_img(img_path): 48 | """ 49 | 读取栅格文件的波段数据集 50 | :param img_path: 待读取栅格文件的路径 51 | :return: 波段数据集 52 | """ 53 | img = gdal.Open(img_path) 54 | band = np.float32(img.GetRasterBand(1).ReadAsArray()) 55 | no_data_value = img.GetRasterBand(1).GetNoDataValue() 56 | band[band == no_data_value] = np.nan 57 | 58 | return band 59 | 60 | 61 | # 准备 62 | in_dir = r'E:\FeaturesTargets\uniform' 63 | h5_path = r'E:\FeaturesTargets\features_targets.h5' 64 | dem_path = r'E:\FeaturesTargets\uniform\dem.tiff' 65 | slope_path = r'E:\FeaturesTargets\uniform\slope.tif' 66 | lon_path = r'E:\FeaturesTargets\uniform\Lon.tiff' 67 | lat_path = r'E:\FeaturesTargets\uniform\Lat.tiff' 68 | start_date = datetime(2003, 1, 1) 69 | end_date = datetime(2016, 12, 1) 70 | features1_params = { 71 | 'LST_MAX': 'LST_MAX_', 72 | # 'LST_MIN': 'LST_MIN_', 73 | # 'LST_MEAN': 'LST_MEAN_', 74 | 'PRCP': 'GLDAS_PRCP_', 75 | 'ET': 'GLDAS_ET_', 76 | # 'Qs': 'GLDAS_Qs_', 77 | # 'Qsb': 'GLDAS_Qsb_', 78 | # 'TWSC': 'GLDAS_TWSC_', 79 | 'Rs': 'Rs_' 80 | } 81 | rows = 132 82 | cols = 193 83 | features1_size = len(features1_params) 84 | 85 | # 特征处理和写入 86 | h5 = h5py.File(h5_path, mode='w') 87 | for year in range(start_date.year, end_date.year + 1): 88 | start_month = start_date.month if year == start_date.year else 1 89 | end_month = end_date.month if year == end_date.year else 12 90 | 91 | features1 = [] # 存储动态特征 92 | targets = [] 93 | cur_group = h5.create_group(str(year)) 94 | for month in range(start_month, end_month + 1): 95 | # 当前月份特征项的读取 96 | cur_features = np.empty((rows, cols, features1_size)) 97 | for ix, (parent_folder_name, feature_wildcard) in enumerate(features1_params.items()): 98 | cur_in_dir = os.path.join(in_dir, parent_folder_name) 99 | pattern = re.compile(feature_wildcard + r'{:04}_?{:02}\.tiff'.format(year, month)) 100 | feature_paths = [_path for _path in os.listdir(cur_in_dir) if pattern.match(_path)] 101 | if len(feature_paths) != 1: 102 | raise NameError('文件名错误, 文件不存在或者指定文件存在多个') 103 | feature_path = os.path.join(cur_in_dir, feature_paths[0]) 104 | cur_features[:, :, ix] = read_img(feature_path) 105 | features1.append(cur_features.reshape(-1, features1_size)) 106 | # 当前月份目标项的读取 107 | ndvi_paths = glob(os.path.join(in_dir, 'NDVI_MAX', 'NDVI_MAX_{:04}_{:02}.tiff'.format(year, month))) 108 | if len(ndvi_paths) != 1: 109 | raise NameError('文件名错误, 文件不存在或者指定文件存在多个') 110 | ndvi_path = ndvi_paths[0] 111 | cur_ndvi = read_img(ndvi_path) 112 | targets.append(cur_ndvi.reshape(-1)) 113 | features1 = np.array(features1) 114 | targets = 
np.array(targets) 115 | 116 | """这里不使用土地利用数据,改用slope数据""" 117 | # landuse_paths = glob(os.path.join(in_dir, 'Landuse', 'Landuse_{}.tiff'.format(year))) 118 | # if len(landuse_paths) != 1: 119 | # raise NameError('文件名错误, 文件不存在或者指定文件存在多个') 120 | # landuse_path = landuse_paths[0] 121 | # features2 = read_img(landuse_path).reshape(-1) 122 | 123 | cur_group['features1'] = features1 124 | # cur_group['features2'] = features2 125 | cur_group['targets'] = targets 126 | print('目前已处理: {}'.format(year)) 127 | 128 | h5['dem'] = read_img(dem_path).reshape(-1) 129 | h5['slope'] = read_img(slope_path).reshape(-1) # 添加slope数据作为特征项 130 | h5['lon'] = read_img(lon_path).reshape(-1) 131 | h5['lat'] = read_img(lat_path).reshape(-1) 132 | if np.isnan(h5['lon']).any() or np.isnan(h5['lat']).any(): 133 | raise RuntimeWarning("Lon/Lat 存在无效值!") 134 | h5.flush() 135 | h5.close() 136 | h5 = None 137 | 138 | # 进一步处理,混合所有年份的数据(无需分组) 139 | with h5py.File(h5_path, mode='a') as h5: 140 | year_dem = h5['dem'] 141 | year_slope = h5['slope'] 142 | year_lon = h5['lon'] 143 | year_lat = h5['lat'] 144 | for year in range(start_date.year, end_date.year + 1): 145 | year_features1 = h5[r'{}/features1'.format(year)] # 这里导致的重大错误: year_features1 = h5[r'2003/features1'] 146 | # year_features2 = h5[r'2003/features2'] 147 | year_targets = h5[r'{}/targets'.format(year)] # Here too 148 | 149 | mask = np.all(~np.isnan(year_features1), axis=(0, 2)) & \ 150 | ~np.isnan(year_slope) & \ 151 | np.all(~np.isnan(year_targets), axis=0) & \ 152 | ~np.isnan(year_dem) 153 | h5['{}/mask'.format(year)] = mask 154 | if year == 2003: 155 | features1 = year_features1[:, mask, :] 156 | slope = year_slope[mask] 157 | targets = year_targets[:, mask] 158 | dem = year_dem[mask] 159 | lon = year_lon[mask] 160 | lat = year_lat[mask] 161 | else: 162 | features1 = np.concatenate((features1, year_features1[:, mask, :]), axis=1) 163 | slope = np.concatenate((slope, year_slope[mask]), axis=0) 164 | targets = np.concatenate((targets, year_targets[:, mask]), axis=1) 165 | dem = np.concatenate((dem, year_dem[mask]), axis=0) 166 | lon = np.concatenate((lon, year_lon[mask]), axis=0) 167 | lat = np.concatenate((lat, year_lat[mask]), axis=0) 168 | 169 | # 归一化 170 | scaler = StandardScaler() 171 | for month in range(12): 172 | features1[month, :, :] = scaler.fit_transform(features1[month, :, :]) 173 | dem = scaler.fit_transform(dem.reshape(-1, 1)).ravel() 174 | slope = scaler.fit_transform(slope.reshape(-1, 1)).ravel() 175 | lon = scaler.fit_transform(lon.reshape(-1, 1)).ravel() 176 | lat = scaler.fit_transform(lat.reshape(-1, 1)).ravel() 177 | 178 | sample_size = dem.shape[0] 179 | train_amount = int(sample_size * 0.8) 180 | eval_amount = sample_size - train_amount 181 | # 创建数据集并存储训练数据 182 | with h5py.File(r'E:\FeaturesTargets\train.h5', mode='w') as h5: 183 | h5.create_dataset('dynamic_features', data=features1[:, :train_amount, :]) 184 | h5.create_dataset('static_features1', data=slope[:train_amount]) # 静态变量 185 | h5.create_dataset('static_features2', data=dem[:train_amount]) # 静态变量 186 | h5.create_dataset('static_features3', data=lon[:train_amount]) # 静态变量 187 | h5.create_dataset('static_features4', data=lat[:train_amount]) # 静态变量 188 | h5.create_dataset('targets', data=targets[:, :train_amount]) 189 | with h5py.File(r'E:\FeaturesTargets\eval.h5', mode='w') as h5: 190 | # 创建数据集并存储评估数据 191 | h5.create_dataset('dynamic_features', data=features1[:, train_amount:, :]) 192 | h5.create_dataset('static_features1', data=slope[train_amount:]) # 静态变量 193 | 
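# For reference, train.h5 and eval.h5 written here share one layout (an 80/20 split of the masked samples;
# N denotes the retained pixel-year samples in each file): dynamic_features is (12, N, 4) holding the monthly
# LST_MAX, PRCP, ET and Rs values, static_features1-4 are each (N,) holding slope, dem, lon and lat, and
# targets is (12, N) holding the monthly NDVI_MAX series.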
h5.create_dataset('static_features2', data=dem[train_amount:]) # 静态变量 194 | h5.create_dataset('static_features3', data=lon[train_amount:]) # 静态变量 195 | h5.create_dataset('static_features4', data=lat[train_amount:]) # 静态变量 196 | h5.create_dataset('targets', data=targets[:, train_amount:]) 197 | -------------------------------------------------------------------------------- /Core/Plot/line_bar_distribution_plot.py: -------------------------------------------------------------------------------- 1 | # @Author : ChaoQiezi 2 | # @Time : 2024/3/29 9:47 3 | # @FileName : line_bar_distribution_plot.py 4 | # @Email : chaoqiezi.one@qq.com 5 | 6 | """ 7 | This script is used to 绘制折线图、柱状图、插值分布图 8 | 9 | EWTC: 包括人类活动导致的(地表水地下使用加工运输到别的地方等), 自然变化的(蒸腾蒸发降水等)引起的储水量变化 10 | TWSC: 指单独自然变化导致的储水量变化 11 | AWC: (EWTC - TWSC即可得到)人类活动导致引发的储水量变化 12 | """ 13 | 14 | import glob 15 | import os.path 16 | import pandas as pd 17 | import numpy as np 18 | import matplotlib.pyplot as plt 19 | import seaborn as sns 20 | from osgeo import gdal 21 | 22 | # 准备 23 | ewtc_path = r'H:\Datasets\Objects\Veg\LXB_plot\Data\sw_EWTC_T.csv' 24 | twsc_path = r'H:\Datasets\Objects\Veg\LXB_plot\Data\sw_TWSC_SH_T.csv' 25 | ndvi_path = r'H:\Datasets\Objects\Veg\LXB_plot\Data\sw_NDVI_T.csv' 26 | in_img_dir= r'E:\FeaturesTargets\uniform' 27 | out_dir = r'H:\Datasets\Objects\Veg\LXB_plot' 28 | sns.set_style('darkgrid') 29 | plt.rcParams['font.sans-serif'] = ['Times New Roman'] # 新罗马字体 30 | 31 | 32 | # 绘制 33 | ewtc = pd.read_csv(ewtc_path) 34 | twsc = pd.read_csv(twsc_path) 35 | ndvi = pd.read_csv(ndvi_path) 36 | awc = pd.merge(ewtc, twsc, left_on='ProvinceNa', right_on='ProvinceNa', suffixes=('_ewtc', '_twsc')) 37 | awc['AWC'] = awc['EWTC'] - awc['TWSC_SH'] 38 | 39 | # # 年月均值(ewtc twsc ndvi) 40 | # var_str = ['EWTC', 'TWSC_SH', 'AWC'] 41 | # for ix, var in enumerate([ewtc, twsc, awc]): 42 | # var_name = var_str[ix] 43 | # if var_name == 'AWC': 44 | # var_monthly = var[['Year_ewtc', 'Month_ewtc', var_name]].groupby(['Year_ewtc', 'Month_ewtc']).mean() 45 | # var_monthly.to_csv(os.path.join(out_dir, '{}.csv'.format(var_name))) 46 | # else: 47 | # var_monthly = var[['Year', 'Month', var_name]].groupby(['Year', 'Month']).mean() 48 | # var_monthly['Date'] = var_monthly.reset_index().apply(lambda x: '{:04.0f}/{:02.0f}'.format(x.iloc[0], x.iloc[1]), axis=1).values 49 | # var_monthly.to_csv(os.path.join(out_dir, '{}.csv'.format(var_name))) 50 | # # 月均值 51 | # var_str = ['EWTC', 'TWSC_SH', 'AWC'] 52 | # for ix, var in enumerate([ewtc, twsc, awc]): 53 | # var_name = var_str[ix] 54 | # if var_name == 'AWC': 55 | # var_monthly = var[['Month_ewtc', var_name]].groupby(['Month_ewtc']).mean() 56 | # # var_monthly.to_csv(os.path.join(out_dir, '{}_monthly.csv'.format(var_name))) 57 | # else: 58 | # var_monthly = var[['Month', var_name]].groupby(['Month']).mean() 59 | # var_monthly['Date'] = var_monthly.index 60 | # # var_monthly['Date'] = var_monthly.reset_index().apply(lambda x: '{:04.0f}/{:02.0f}'.format(x.iloc[0], x.iloc[1]), axis=1).values 61 | # var_monthly.to_csv(os.path.join(out_dir, '{}_monthly.csv'.format(var_name)), index=False) 62 | # 年均值 63 | var_str = ['EWTC', 'TWSC_SH', 'AWC'] 64 | for ix, var in enumerate([ewtc, twsc, awc]): 65 | var_name = var_str[ix] 66 | if var_name == 'AWC': 67 | var_yearly = var[['Year_ewtc', var_name]].groupby(['Year_ewtc']).mean() 68 | # var_monthly.to_csv(os.path.join(out_dir, '{}_monthly.csv'.format(var_name))) 69 | else: 70 | var_yearly = var[['Year', var_name]].groupby(['Year']).mean() 71 | var_yearly['Date'] = 
var_yearly.index 72 | # var_monthly['Date'] = var_monthly.reset_index().apply(lambda x: '{:04.0f}/{:02.0f}'.format(x.iloc[0], x.iloc[1]), axis=1).values 73 | var_yearly.to_csv(os.path.join(out_dir, '{}_yearly.csv'.format(var_name)), index=False) 74 | 75 | # PRCP, Qs, Qsb, ET年均值的计算 76 | prcp_month_path = r'H:\Datasets\Objects\Veg\LXB_plot\PRCP\PRCP.xlsx' 77 | et_month_path = r'H:\Datasets\Objects\Veg\LXB_plot\ET\ET.xlsx' 78 | qs_month_path = r'H:\Datasets\Objects\Veg\LXB_plot\Qs\Qs.xlsx' 79 | qsb_month_path = r'H:\Datasets\Objects\Veg\LXB_plot\Qsb\Qsb.xlsx' 80 | prcp = pd.read_excel(prcp_month_path) 81 | et = pd.read_excel(et_month_path) 82 | qs = pd.read_excel(qs_month_path) 83 | qsb = pd.read_excel(qsb_month_path) 84 | prcp['Year'] = prcp.date.apply(lambda x: x.year) 85 | prcp_yearly = prcp[['Year', 'PRCP']].groupby(['Year']).mean().reset_index(drop=False) 86 | prcp_yearly.to_excel(os.path.join('H:\Datasets\Objects\Veg\LXB_plot\PRCP', 'prcp_yearly.xlsx'), index=False) 87 | 88 | et['Year'] = et.date.apply(lambda x: x.year) 89 | et_yearly = et[['Year', 'ET']].groupby(['Year']).mean().reset_index(drop=False) 90 | et_yearly.to_excel(os.path.join('H:\Datasets\Objects\Veg\LXB_plot\ET', 'et_yearly.xlsx'), index=False) 91 | 92 | qs['Year'] = qs.date.apply(lambda x: x.year) 93 | qs_yearly = qs[['Year', 'Qs']].groupby(['Year']).mean().reset_index(drop=False) 94 | qs_yearly.to_excel(os.path.join('H:\Datasets\Objects\Veg\LXB_plot\Qs', 'Qs_yearly.xlsx'), index=False) 95 | 96 | qsb['Year'] = qsb.date.apply(lambda x: x.year) 97 | qsb_yearly = qsb[['Year', 'Qsb']].groupby(['Year']).mean().reset_index(drop=False) 98 | qsb_yearly.to_excel(os.path.join('H:\Datasets\Objects\Veg\LXB_plot\Qsb', 'Qsb_yearly.xlsx'), index=False) 99 | 100 | 101 | 102 | # 绘制NDVI年变化折线图和月变化柱状图 103 | ndvi_monthly = ndvi[['Month', 'NDVI']].groupby('Month').mean() 104 | ndvi_yearly = ndvi[['Year', 'NDVI']].groupby('Year').mean() 105 | ndvi_yearly['Year'] = ndvi_yearly.index 106 | ndvi_monthly['Month'] = ['{:02}'.format(_x) for _x in ndvi_monthly.index] 107 | ndvi_monthly.to_csv(os.path.join(out_dir, 'ndvi_monthly.csv'), index=False) 108 | ndvi_yearly.to_csv(os.path.join(out_dir, 'ndvi_yearly.csv'), index=False) 109 | # # 绘制折线图 110 | # plt.figure(figsize=(13, 9), dpi=222) 111 | # sns.lineplot(data=ndvi_yearly, x='Year', y='NDVI', linestyle='-', color='#1f77b4', linewidth=7, legend=True) 112 | # plt.scatter(ndvi_yearly['Year'], ndvi_yearly['NDVI'], s=100, facecolors='none', edgecolors='#bcbd22', linewidths=5, zorder=5) 113 | # plt.xlabel('Year', size=26) # 设置x轴标签 114 | # plt.ylabel('NDVI', size=26) # 设置y轴标签 115 | # plt.xticks(ndvi_yearly['Year'], rotation=45, fontsize=18) 116 | # plt.yticks(fontsize=22) 117 | # plt.savefig(os.path.join(out_dir, 'ndvi_line_yearly.png')) 118 | # plt.show() 119 | # # 绘制柱状图 120 | # plt.figure(figsize=(13, 9), dpi=222) 121 | # sns.barplot(data=ndvi_monthly, x='Month', y='NDVI', linestyle='-', color='#1f77b4') 122 | # plt.xlabel('Month', size=26) # 设置x轴标签 123 | # plt.ylabel('NDVI', size=26) # 设置y轴标签 124 | # x_labels = ndvi_monthly['Month'].apply(lambda x: '{:02}'.format(x)) 125 | # plt.xticks(ticks=range(len(x_labels)), labels=x_labels, fontsize=18) 126 | # plt.yticks(fontsize=22) 127 | # plt.savefig(os.path.join(out_dir, 'ndvi_line_monthly.png')) 128 | # plt.show() 129 | 130 | # 提取降水、蒸散、地表和地下径流 131 | station = ndvi.drop_duplicates(['Lon', 'Lat'])[['Lon', 'Lat']] 132 | var_names = ['PRCP', 'ET', 'Qs', 'Qsb'] 133 | for var_name in var_names: 134 | var = [] 135 | cur_dir = os.path.join(in_img_dir, var_name) 136 
| var_paths = glob.glob(os.path.join(cur_dir, 'GLDAS_{}*.tiff'.format(var_name))) 137 | for var_path in var_paths: 138 | ds = gdal.Open(var_path) 139 | lon_min, lon_res, _, lat_max, _, lat_res_negative = ds.GetGeoTransform() 140 | ds_band = np.float32(ds.GetRasterBand(1).ReadAsArray()) 141 | nodata_value = ds.GetRasterBand(1).GetNoDataValue() 142 | ds_band[ds_band == nodata_value] = np.nan 143 | station['row'] = np.floor((lat_max - station['Lat']) / (-lat_res_negative)).astype(int) 144 | station['col'] = np.floor((station['Lon'] - lon_min) / lon_res).astype(int) 145 | station[var_name] = ds_band[station['row'], station['col']] 146 | station['date'] = os.path.basename(var_path).split('_')[2][:6] 147 | var.append(station.copy()) 148 | var = pd.concat(var, ignore_index=True) 149 | var['date'] = var['date'].apply(lambda x: x[:4] + '/' + x[4:]) 150 | var = var[['date', var_name]].groupby(['date']).mean() 151 | out_path = os.path.join(out_dir, '{}.csv'.format(var_name)) 152 | var.to_csv(out_path) 153 | 154 | # TWSC/EWTC/AWC均值计算 155 | ewtc_by_station = ewtc[['Lon', 'Lat', 'EWTC']].groupby(['Lon', 'Lat']).mean().reset_index() 156 | twsc_by_station = twsc[['Lon', 'Lat', 'TWSC_SH']].groupby(['Lon', 'Lat']).mean().reset_index() 157 | awc_by_station = awc[['Lat_ewtc', 'Lon_ewtc', 'AWC']].groupby(['Lat_ewtc', 'Lon_ewtc']).mean().reset_index() 158 | ndvi_by_station = ndvi[['Lon', 'Lat', 'NDVI']].groupby(['Lon', 'Lat']).mean().reset_index() 159 | for var in [ewtc_by_station, twsc_by_station, awc_by_station, ndvi_by_station]: 160 | out_path = os.path.join(out_dir, 'distribution_{}.csv'.format(var.columns[-1])) 161 | var.to_csv(out_path, index=False) -------------------------------------------------------------------------------- /Core/model_train_dynamic.py: -------------------------------------------------------------------------------- 1 | # @Author : ChaoQiezi 2 | # @Time : 2024/1/3 16:54 3 | # @Email : chaoqiezi.one@qq.com 4 | 5 | """ 6 | This script is used to 构建lstm模型并训练 7 | 8 | 2024/5/11 增加Rs、Lon、Lat输入到模型中训练 9 | 10 | 2024/5/11 仅使用动态特征中的LST, PRCP, ET, RS进行训练 11 | """ 12 | 13 | import random 14 | import glob 15 | import os.path 16 | import numpy as np 17 | import pandas as pd 18 | import torch 19 | from torchsummary import summary 20 | from torch.utils.data import DataLoader, random_split 21 | from VEG.utils.utils import cal_r2, H5DynamicDatasetDecoder 22 | from VEG.utils.models import LSTMModel, LSTMModelDynamic 23 | from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error 24 | 25 | 26 | def set_seed(seed=42): 27 | random.seed(seed) 28 | np.random.seed(seed) 29 | os.environ['PYTHONHASHSEED'] = str(seed) 30 | torch.manual_seed(seed) 31 | torch.cuda.manual_seed(seed) 32 | torch.cuda.manual_seed_all(seed) # 如果使用多GPU 33 | torch.backends.cudnn.deterministic = True 34 | torch.backends.cudnn.benchmark = False 35 | 36 | 37 | set_seed(42) 38 | 39 | # 准备 40 | train_path = r'E:\FeaturesTargets\train.h5' 41 | eval_path = r'E:\FeaturesTargets\eval.h5' 42 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 43 | out_model_dir = r'E:\Models' 44 | dynamic_features_name = [ 45 | 'LST_MAX', 46 | 'PRCP', 47 | 'ET', 48 | # 'Qs', 49 | # 'Qsb', 50 | # 'TWSC', 51 | 'Rs' 52 | ] 53 | static_feature_name = [ 54 | 'Slope', 55 | 'DEM', 56 | 'Lon', 57 | 'Lat' 58 | ] 59 | # 创建LSTM模型实例并移至GPU 60 | # model = LSTMModel(4, 256, 4, 12).to('cuda' if torch.cuda.is_available() else 'cpu') 61 | # summary(model, input_data=[(12, 7), (4,)]) 62 | model = LSTMModelDynamic(4, 256, 4, 12).to('cuda' if 
torch.cuda.is_available() else 'cpu') 63 | summary(model, input_data=[(12, 4)]) 64 | batch_size = 256 65 | 66 | # generator = torch.Generator().manual_seed(42) # 指定随机种子 67 | # train_dataset, eval_dataset, sample_dataset = random_split(dataset, (0.8, 0.195, 0.005), generator=generator) 68 | # train_dataset, eval_dataset = random_split(dataset, (0.8, 0.2), generator=generator) 69 | # 创建数据加载器 70 | train_dataset = H5DynamicDatasetDecoder(train_path) # 创建自定义数据集实例 71 | eval_dataset = H5DynamicDatasetDecoder(eval_path) 72 | train_data_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) 73 | eval_data_loader = DataLoader(eval_dataset, batch_size=batch_size, shuffle=True) 74 | # 训练参数 75 | criterion = torch.nn.MSELoss() 76 | optimizer = torch.optim.Adam(model.parameters(), lr=0.002) # 初始学习率设置为0.001 77 | epochs_num = 30 78 | model.train() # 切换为训练模式 79 | 80 | 81 | def model_train(data_loader, epochs_num: int = 25,save_path: str = None, device='cuda'): 82 | # 创建新的模型实例 83 | model = LSTMModelDynamic(4, 256, 4, 12).to(device) 84 | optimizer = torch.optim.Adam(model.parameters(), lr=0.0005) # 初始学习率设置为0.001 85 | epochs_loss = [] 86 | for epoch in range(epochs_num): 87 | train_loss = [] 88 | for dynamic_inputs, targets in data_loader: 89 | dynamic_inputs, targets = dynamic_inputs.to(device), targets.to(device) 90 | 91 | """正常""" 92 | # 前向传播 93 | outputs = model(dynamic_inputs) 94 | # 计算损失 95 | loss = criterion(outputs, targets) 96 | # 反向传播和优化 97 | loss.backward() 98 | optimizer.step() 99 | # scheduler.step() # 更新学习率 100 | 101 | optimizer.zero_grad() # 清除梯度 102 | train_loss.append(loss.item()) 103 | print(f'Epoch {epoch + 1}/{epochs_num}, Loss: {np.mean(train_loss)}') 104 | epochs_loss.append(np.mean(train_loss)) 105 | 106 | if save_path: 107 | torch.save(model.state_dict(), save_path) 108 | 109 | return epochs_loss 110 | 111 | 112 | def model_eval_whole(model_path: str, data_loader, device='cuda'): 113 | # 加载模型 114 | model = LSTMModelDynamic(4, 256, 4, 12).to(device) 115 | model.load_state_dict(torch.load(model_path)) 116 | 117 | # 评估 118 | model.eval() # 评估模式 119 | all_outputs = [] 120 | all_targets = [] 121 | with torch.no_grad(): 122 | for dynamic_inputs, targets in data_loader: 123 | dynamic_inputs, targets = dynamic_inputs.to(device), targets.to(device) 124 | outputs = model(dynamic_inputs) 125 | all_outputs.append(outputs.cpu()) # outputs/targets: (batch_size, time_steps) 126 | all_targets.append(targets.cpu()) 127 | 128 | all_outputs = np.concatenate(all_outputs, axis=0) 129 | all_targets = np.concatenate(all_targets, axis=0) 130 | 131 | # mse_per_step = [] 132 | # mae_per_step = [] 133 | # r2_per_step = [] 134 | # rmse_per_step = [] 135 | # for time_step in range(12): 136 | # mse_step = mean_squared_error(all_targets[:, time_step], all_outputs[:, time_step]) 137 | # mae_step = mean_absolute_error(all_targets[:, time_step], all_outputs[:, time_step]) 138 | # r2_step = r2_score(all_targets[:, time_step], all_outputs[:, time_step]) 139 | # rmse_step = np.sqrt(mse_step) 140 | # 141 | # mse_per_step.append(mse_step) 142 | # mae_per_step.append(mae_step) 143 | # r2_per_step.append(r2_step) 144 | # rmse_per_step.append(rmse_step) 145 | 146 | # mse = np.mean(mse_per_step) 147 | # mae = np.mean(mae_per_step) 148 | # r2 = np.mean(r2_per_step) 149 | # rmse = np.mean(rmse_per_step) 150 | 151 | # 不区分月份求取指标(视为整体) 152 | mse_step = mean_squared_error(all_targets.reshape(-1), all_outputs.reshape(-1)) 153 | mae_step = mean_absolute_error(all_targets.reshape(-1), all_outputs.reshape(-1)) 154 | 
r2_step = r2_score(all_targets.reshape(-1), all_outputs.reshape(-1)) 155 | rmse_step = np.sqrt(mse_step) 156 | return mse_step, mae_step, r2_step, rmse_step 157 | 158 | # return mse_per_step, mae_per_step, r2_per_step, rmse_per_step, all_outputs, all_targets 159 | 160 | 161 | 162 | def model_eval(model_path: str, data_loader, device='cuda'): 163 | # 加载模型 164 | model = LSTMModelDynamic(4, 256, 4, 12).to(device) 165 | model.load_state_dict(torch.load(model_path)) 166 | 167 | # 评估 168 | model.eval() # 评估模式 169 | all_outputs = [] 170 | all_targets = [] 171 | with torch.no_grad(): 172 | for dynamic_inputs, targets in data_loader: 173 | dynamic_inputs, targets = dynamic_inputs.to(device), targets.to(device) 174 | outputs = model(dynamic_inputs) 175 | all_outputs.append(outputs.cpu()) # outputs/targets: (batch_size, time_steps) 176 | all_targets.append(targets.cpu()) 177 | 178 | all_outputs = np.concatenate(all_outputs, axis=0) 179 | all_targets = np.concatenate(all_targets, axis=0) 180 | 181 | mse_per_step = [] 182 | mae_per_step = [] 183 | r2_per_step = [] 184 | rmse_per_step = [] 185 | for time_step in range(12): 186 | mse_step = mean_squared_error(all_targets[:, time_step], all_outputs[:, time_step]) 187 | mae_step = mean_absolute_error(all_targets[:, time_step], all_outputs[:, time_step]) 188 | r2_step = r2_score(all_targets[:, time_step], all_outputs[:, time_step]) 189 | rmse_step = np.sqrt(mse_step) 190 | 191 | mse_per_step.append(mse_step) 192 | mae_per_step.append(mae_step) 193 | r2_per_step.append(r2_step) 194 | rmse_per_step.append(rmse_step) 195 | 196 | return mse_per_step, mae_per_step, r2_per_step, rmse_per_step, all_outputs, all_targets 197 | 198 | if __name__ == '__main__': 199 | df = pd.DataFrame() 200 | # 常规训练 201 | df['normal_epochs_loss'] = model_train(train_data_loader, save_path=os.path.join(out_model_dir, 202 | 'normal_model_dynamic.pth')) 203 | print('>>> 常规训练结束') 204 | # 特征重要性训练 205 | # 动态特征 206 | for feature_ix in range(4): 207 | train_dataset = H5DynamicDatasetDecoder(train_path, shuffle_feature_ix=feature_ix, dynamic=True) # 创建自定义数据集实例 208 | train_data_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) 209 | 210 | cur_feature_name = dynamic_features_name[feature_ix] 211 | save_path = os.path.join(out_model_dir, cur_feature_name + '_model_dynamic.pth') 212 | df[cur_feature_name + '_epochs_loss'] = \ 213 | model_train(train_data_loader, save_path=save_path) 214 | print('>>> {}乱序排列 训练结束'.format(cur_feature_name)) 215 | # # 静态特征 216 | # for feature_ix in range(4): 217 | # train_dataset = H5DatasetDecoder(train_path, shuffle_feature_ix=feature_ix, dynamic=False) # 创建自定义数据集实例 218 | # train_data_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) 219 | # 220 | # cur_feature_name = static_feature_name[feature_ix] 221 | # save_path = os.path.join(out_model_dir, cur_feature_name + '_model.pth') 222 | # df[cur_feature_name + '_epochs_loss'] = \ 223 | # model_train(train_data_loader, save_path=save_path) 224 | # print('>>> {}乱序排列 训练结束'.format(cur_feature_name)) 225 | df.to_excel(r'E:\Models\training_eval_results\training_loss_dynamic.xlsx') 226 | 227 | # 评估 228 | indicator_whole = pd.DataFrame() 229 | indicator = pd.DataFrame() 230 | model_paths = glob.glob(os.path.join(out_model_dir, '*_model_dynamic.pth')) 231 | for model_path in model_paths: 232 | cur_model_name = os.path.basename(model_path).split('_model')[0] 233 | mse_step, mae_step, r2_step, rmse_step = model_eval_whole(model_path, eval_data_loader) 234 | 
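# This evaluation loop covers the normally trained model plus the models trained with one dynamic feature
# shuffled across samples, i.e. a permutation-style feature-importance test: the larger the drop in R2
# relative to 'normal', the more that feature mattered. A rough, hypothetical sketch of how the per-feature
# drop could be summarised once indicator_whole has been filled by the assignments below:
# r2_normal = indicator_whole['normal_evaluate_r2'].iloc[0]
# importance = {name: r2_normal - indicator_whole[name + '_evaluate_r2'].iloc[0] for name in dynamic_features_name}
# print(sorted(importance.items(), key=lambda kv: kv[1], reverse=True))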
indicator_whole[cur_model_name + '_evaluate_mse'] = [mse_step] 235 | indicator_whole[cur_model_name + '_evaluate_mae'] = [mae_step] 236 | indicator_whole[cur_model_name + '_evaluate_r2'] = [r2_step] 237 | indicator_whole[cur_model_name + '_evaluate_rmse'] = [rmse_step] 238 | 239 | mse_per_step, mae_per_step, r2_per_step, rmse_per_step, all_outputs, all_targets = model_eval(model_path, eval_data_loader) 240 | 241 | all_outputs_targets = np.concatenate((all_outputs, all_targets), axis=1) 242 | columns = [*['outputs_{:02}'.format(month) for month in range(1, 13)], *['targets_{:02}'.format(month) for month in range(1, 13)]] 243 | outputs_targets = pd.DataFrame(all_outputs_targets, columns=columns) 244 | indicator[cur_model_name + '_evaluate_mse'] = mse_per_step 245 | indicator[cur_model_name + '_evaluate_mae'] = mae_per_step 246 | indicator[cur_model_name + '_evaluate_r2'] = r2_per_step 247 | indicator[cur_model_name + '_evaluate_rmse'] = rmse_per_step 248 | outputs_targets.to_excel(r'E:\Models\training_eval_results\{}_outputs_targets_dynamic.xlsx'.format(cur_model_name)) 249 | print('>>> {} 重要性评估完毕'.format(cur_model_name)) 250 | indicator.loc['均值指标'] = np.mean(indicator, axis=0) 251 | indicator.to_excel(r'E:\Models\training_eval_results\eval_indicators_dynamic.xlsx') 252 | indicator_whole.to_excel(r'E:\Models\training_eval_results\eval_indicators_整体_dynamic.xlsx') 253 | # model.eval() 254 | # eval_loss = [] 255 | # with torch.no_grad(): 256 | # for dynamic_inputs, static_inputs, targets in data_loader: 257 | # dynamic_inputs = dynamic_inputs.to('cuda' if torch.cuda.is_available() else 'cpu') 258 | # static_inputs = static_inputs.to('cuda' if torch.cuda.is_available() else 'cpu') 259 | # targets = targets.to('cuda' if torch.cuda.is_available() else 'cpu') 260 | # # 前向传播 261 | # outputs = model(dynamic_inputs, static_inputs) 262 | # # 计算损失 263 | # loss = criterion(outputs, targets) 264 | # r2 = cal_r2(outputs, targets) 265 | # print('预测项:', outputs) 266 | # print('目标项:', targets) 267 | # print(f'MSE Loss: {loss.item()}') 268 | # break 269 | # eval_loss.append(loss.item()) 270 | # print(f'Loss: {np.mean(eval_loss)}') 271 | # print(f'R2:', r2) 272 | 273 | 274 | 275 | # # 取 276 | # with h5py.File(r'E:\FeaturesTargets\features_targets.h5', 'r') as h5: 277 | # features = np.transpose(h5['2003/features1'][:], (1, 0, 2)) # shape=(样本数, 时间步, 特征项) 278 | # targets = np.transpose(h5['2003/targets'][:], (1, 0)) # shape=(样本数, 时间步) 279 | # static_features = np.column_stack((h5['2003/features2'][:], h5['dem'][:])) 280 | # mask1 = ~np.any(np.isnan(features), axis=(1, 2)) 281 | # mask2 = ~np.any(np.isnan(targets), axis=(1,)) 282 | # mask3 = ~np.any(np.isnan(static_features), axis=(1, )) 283 | # mask = (mask1 & mask2 & mask3) 284 | # features = features[mask, :, :] 285 | # targets = targets[mask, :] 286 | # static_features = static_features[mask, :] 287 | # print(features.shape) 288 | # print(targets.shape) 289 | # for ix in range(6): 290 | # feature = features[:, :, ix] 291 | # features[:, :, ix] = (feature - feature.mean()) / feature.std() 292 | # if ix <= 1: 293 | # feature = static_features[:, ix] 294 | # static_features[:, ix] = (feature - feature.mean()) / feature.std() 295 | # 296 | # features_tensor = torch.tensor(features, dtype=torch.float32) 297 | # targets_tensor = torch.tensor(targets, dtype=torch.float32) 298 | # static_features_tensor = torch.tensor(static_features, dtype=torch.float32) 299 | # 300 | # # 创建包含动态特征、静态特征和目标的数据集 301 | # dataset = TensorDataset(features_tensor, static_features_tensor, 
targets_tensor) 302 | # train_dataset, eval_dataset = random_split(dataset, [8000, 10238 - 8000]) -------------------------------------------------------------------------------- /Core/Plot/plot.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import os.path\n", 12 | "import pandas as pd\n", 13 | "import numpy as np\n", 14 | "import matplotlib.pyplot as plt\n", 15 | "import seaborn as sns\n", 16 | "\n", 17 | "# 准备\n", 18 | "ewtc_path = r'H:\\Datasets\\Objects\\Veg\\LXB_plot\\Data\\sw_EWTC_T.csv'\n", 19 | "twsc_path = r'H:\\Datasets\\Objects\\Veg\\LXB_plot\\Data\\sw_TWSC_SH_T.csv'\n", 20 | "ndvi_path = r'H:\\Datasets\\Objects\\Veg\\LXB_plot\\Data\\sw_NDVI_T.csv'\n", 21 | "out_dir = r'H:\\Datasets\\Objects\\Veg\\LXB_plot'\n", 22 | "sns.set_style('darkgrid')\n", 23 | "plt.rcParams['font.sans-serif'] = ['Times New Roman'] # 新罗马字体\n", 24 | "\n", 25 | "\n", 26 | "# 绘制\n", 27 | "ewtc = pd.read_csv(ewtc_path)\n", 28 | "twsc = pd.read_csv(twsc_path)\n", 29 | "ndvi = pd.read_csv(ndvi_path)\n", 30 | "awc = pd.merge(ewtc, twsc, left_on='ProvinceNa', right_on='ProvinceNa', suffixes=('_ewtc', '_twsc'))\n", 31 | "awc['AWC'] = awc['EWTC'] - awc['TWSC_SH']\n", 32 | "\n", 33 | "# 绘制NDVI年变化折线图和月变化柱状图\n", 34 | "ndvi_monthly = ndvi[['Month', 'NDVI']].groupby('Month').mean()\n", 35 | "ndvi_yearly = ndvi[['Year', 'NDVI']].groupby('Year').mean()\n", 36 | "ndvi_yearly['Year'] = ndvi_yearly.index\n", 37 | "ndvi_monthly['Month'] = ndvi_monthly.index" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": 2, 43 | "outputs": [ 44 | { 45 | "ename": "NoMatchingVersions", 46 | "evalue": "No matches for version='5.16.3' among ['4.0.2', '4.8.1', '4.17.0'].\nOften this can be fixed by updating altair_viewer:\n pip install -U altair_viewer", 47 | "output_type": "error", 48 | "traceback": [ 49 | "\u001B[1;31m---------------------------------------------------------------------------\u001B[0m", 50 | "\u001B[1;31mNoMatchingVersions\u001B[0m Traceback (most recent call last)", 51 | "Cell \u001B[1;32mIn[2], line 8\u001B[0m\n\u001B[0;32m 2\u001B[0m chart \u001B[38;5;241m=\u001B[39m alt\u001B[38;5;241m.\u001B[39mChart(ndvi_monthly)\u001B[38;5;241m.\u001B[39mmark_bar()\u001B[38;5;241m.\u001B[39mencode(\n\u001B[0;32m 3\u001B[0m x\u001B[38;5;241m=\u001B[39m\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mMonth\u001B[39m\u001B[38;5;124m'\u001B[39m,\n\u001B[0;32m 4\u001B[0m y\u001B[38;5;241m=\u001B[39m\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mNDVI\u001B[39m\u001B[38;5;124m'\u001B[39m\n\u001B[0;32m 5\u001B[0m )\n\u001B[0;32m 7\u001B[0m \u001B[38;5;66;03m# 显示图表\u001B[39;00m\n\u001B[1;32m----> 8\u001B[0m \u001B[43mchart\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mshow\u001B[49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\n", 52 | "File \u001B[1;32md:\\softwares\\python38\\lib\\site-packages\\altair\\vegalite\\v5\\api.py:2691\u001B[0m, in \u001B[0;36mTopLevelMixin.show\u001B[1;34m(self, embed_opt, open_browser)\u001B[0m\n\u001B[0;32m 2686\u001B[0m \u001B[38;5;28;01mexcept\u001B[39;00m \u001B[38;5;167;01mImportError\u001B[39;00m \u001B[38;5;28;01mas\u001B[39;00m err:\n\u001B[0;32m 2687\u001B[0m \u001B[38;5;28;01mraise\u001B[39;00m \u001B[38;5;167;01mValueError\u001B[39;00m(\n\u001B[0;32m 2688\u001B[0m 
\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mshow\u001B[39m\u001B[38;5;124m'\u001B[39m\u001B[38;5;124m method requires the altair_viewer package. \u001B[39m\u001B[38;5;124m\"\u001B[39m\n\u001B[0;32m 2689\u001B[0m \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mSee http://github.com/altair-viz/altair_viewer\u001B[39m\u001B[38;5;124m\"\u001B[39m\n\u001B[0;32m 2690\u001B[0m ) \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01merr\u001B[39;00m\n\u001B[1;32m-> 2691\u001B[0m \u001B[43maltair_viewer\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mshow\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43membed_opt\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43membed_opt\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mopen_browser\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mopen_browser\u001B[49m\u001B[43m)\u001B[49m\n", 53 | "File \u001B[1;32md:\\softwares\\python38\\lib\\site-packages\\altair_viewer\\_viewer.py:355\u001B[0m, in \u001B[0;36mChartViewer.show\u001B[1;34m(self, chart, embed_opt, open_browser)\u001B[0m\n\u001B[0;32m 328\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mshow\u001B[39m(\n\u001B[0;32m 329\u001B[0m \u001B[38;5;28mself\u001B[39m,\n\u001B[0;32m 330\u001B[0m chart: Union[\u001B[38;5;28mdict\u001B[39m, alt\u001B[38;5;241m.\u001B[39mTopLevelMixin],\n\u001B[0;32m 331\u001B[0m embed_opt: Optional[\u001B[38;5;28mdict\u001B[39m] \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;01mNone\u001B[39;00m,\n\u001B[0;32m 332\u001B[0m open_browser: Optional[\u001B[38;5;28mbool\u001B[39m] \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;01mNone\u001B[39;00m,\n\u001B[0;32m 333\u001B[0m ) \u001B[38;5;241m-\u001B[39m\u001B[38;5;241m>\u001B[39m \u001B[38;5;28;01mNone\u001B[39;00m:\n\u001B[0;32m 334\u001B[0m \u001B[38;5;250m \u001B[39m\u001B[38;5;124;03m\"\"\"Show chart and prompt to pause execution.\u001B[39;00m\n\u001B[0;32m 335\u001B[0m \n\u001B[0;32m 336\u001B[0m \u001B[38;5;124;03m Use this to show a chart within a stand-alone script, to prevent the Python process\u001B[39;00m\n\u001B[1;32m (...)\u001B[0m\n\u001B[0;32m 353\u001B[0m \u001B[38;5;124;03m render : Jupyter renderer for chart.\u001B[39;00m\n\u001B[0;32m 354\u001B[0m \u001B[38;5;124;03m \"\"\"\u001B[39;00m\n\u001B[1;32m--> 355\u001B[0m msg \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mdisplay\u001B[49m\u001B[43m(\u001B[49m\u001B[43mchart\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43membed_opt\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43membed_opt\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mopen_browser\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mopen_browser\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m 356\u001B[0m \u001B[38;5;28mprint\u001B[39m(msg)\n\u001B[0;32m 357\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_provider \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m \u001B[38;5;28;01mNone\u001B[39;00m:\n", 54 | "File \u001B[1;32md:\\softwares\\python38\\lib\\site-packages\\altair_viewer\\_viewer.py:266\u001B[0m, in \u001B[0;36mChartViewer.display\u001B[1;34m(self, chart, inline, embed_opt, open_browser)\u001B[0m\n\u001B[0;32m 264\u001B[0m chart \u001B[38;5;241m=\u001B[39m chart\u001B[38;5;241m.\u001B[39mto_dict()\n\u001B[0;32m 265\u001B[0m \u001B[38;5;28;01massert\u001B[39;00m \u001B[38;5;28misinstance\u001B[39m(chart, 
\u001B[38;5;28mdict\u001B[39m)\n\u001B[1;32m--> 266\u001B[0m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43m_initialize\u001B[49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m 267\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_stream \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;28;01mNone\u001B[39;00m:\n\u001B[0;32m 268\u001B[0m \u001B[38;5;28;01mraise\u001B[39;00m \u001B[38;5;167;01mRuntimeError\u001B[39;00m(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mInternal: _stream is not defined.\u001B[39m\u001B[38;5;124m\"\u001B[39m)\n", 55 | "File \u001B[1;32md:\\softwares\\python38\\lib\\site-packages\\altair_viewer\\_viewer.py:183\u001B[0m, in \u001B[0;36mChartViewer._initialize\u001B[1;34m(self)\u001B[0m\n\u001B[0;32m 180\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_use_bundled_js:\n\u001B[0;32m 181\u001B[0m \u001B[38;5;28;01mfor\u001B[39;00m package \u001B[38;5;129;01min\u001B[39;00m [\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mvega\u001B[39m\u001B[38;5;124m\"\u001B[39m, \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mvega-lite\u001B[39m\u001B[38;5;124m\"\u001B[39m, \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mvega-embed\u001B[39m\u001B[38;5;124m\"\u001B[39m]:\n\u001B[0;32m 182\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_resources[package] \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_provider\u001B[38;5;241m.\u001B[39mcreate(\n\u001B[1;32m--> 183\u001B[0m content\u001B[38;5;241m=\u001B[39m\u001B[43mget_bundled_script\u001B[49m\u001B[43m(\u001B[49m\n\u001B[0;32m 184\u001B[0m \u001B[43m \u001B[49m\u001B[43mpackage\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43m_versions\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mget\u001B[49m\u001B[43m(\u001B[49m\u001B[43mpackage\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m 185\u001B[0m \u001B[43m \u001B[49m\u001B[43m)\u001B[49m,\n\u001B[0;32m 186\u001B[0m route\u001B[38;5;241m=\u001B[39m\u001B[38;5;124mf\u001B[39m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mscripts/\u001B[39m\u001B[38;5;132;01m{\u001B[39;00mpackage\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m.js\u001B[39m\u001B[38;5;124m\"\u001B[39m,\n\u001B[0;32m 187\u001B[0m )\n\u001B[0;32m 189\u001B[0m favicon \u001B[38;5;241m=\u001B[39m pkgutil\u001B[38;5;241m.\u001B[39mget_data(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124maltair_viewer\u001B[39m\u001B[38;5;124m\"\u001B[39m, \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mstatic/favicon.ico\u001B[39m\u001B[38;5;124m\"\u001B[39m)\n\u001B[0;32m 190\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m favicon \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m \u001B[38;5;28;01mNone\u001B[39;00m:\n", 56 | "File \u001B[1;32md:\\softwares\\python38\\lib\\site-packages\\altair_viewer\\_scripts.py:40\u001B[0m, in \u001B[0;36mget_bundled_script\u001B[1;34m(package, version)\u001B[0m\n\u001B[0;32m 36\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m package \u001B[38;5;129;01mnot\u001B[39;00m \u001B[38;5;129;01min\u001B[39;00m listing:\n\u001B[0;32m 37\u001B[0m \u001B[38;5;28;01mraise\u001B[39;00m \u001B[38;5;167;01mValueError\u001B[39;00m(\n\u001B[0;32m 38\u001B[0m \u001B[38;5;124mf\u001B[39m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mpackage \u001B[39m\u001B[38;5;132;01m{\u001B[39;00mpackage\u001B[38;5;132;01m!r}\u001B[39;00m\u001B[38;5;124m not recognized. 
Available: \u001B[39m\u001B[38;5;132;01m{\u001B[39;00m\u001B[38;5;28mlist\u001B[39m(listing)\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m\"\u001B[39m\n\u001B[0;32m 39\u001B[0m )\n\u001B[1;32m---> 40\u001B[0m version_str \u001B[38;5;241m=\u001B[39m \u001B[43mfind_version\u001B[49m\u001B[43m(\u001B[49m\u001B[43mversion\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mlisting\u001B[49m\u001B[43m[\u001B[49m\u001B[43mpackage\u001B[49m\u001B[43m]\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m 41\u001B[0m path \u001B[38;5;241m=\u001B[39m \u001B[38;5;124mf\u001B[39m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mscripts/\u001B[39m\u001B[38;5;132;01m{\u001B[39;00mpackage\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m-\u001B[39m\u001B[38;5;132;01m{\u001B[39;00mversion_str\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m.js\u001B[39m\u001B[38;5;124m\"\u001B[39m\n\u001B[0;32m 42\u001B[0m content \u001B[38;5;241m=\u001B[39m pkgutil\u001B[38;5;241m.\u001B[39mget_data(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124maltair_viewer\u001B[39m\u001B[38;5;124m\"\u001B[39m, path)\n", 57 | "File \u001B[1;32md:\\softwares\\python38\\lib\\site-packages\\altair_viewer\\_utils.py:212\u001B[0m, in \u001B[0;36mfind_version\u001B[1;34m(version, candidates, strict_micro)\u001B[0m\n\u001B[0;32m 210\u001B[0m matches \u001B[38;5;241m=\u001B[39m [c \u001B[38;5;28;01mfor\u001B[39;00m c \u001B[38;5;129;01min\u001B[39;00m cand \u001B[38;5;28;01mif\u001B[39;00m v\u001B[38;5;241m.\u001B[39mmatches(c)]\n\u001B[0;32m 211\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m matches:\n\u001B[1;32m--> 212\u001B[0m \u001B[38;5;28;01mraise\u001B[39;00m NoMatchingVersions(\n\u001B[0;32m 213\u001B[0m \u001B[38;5;124mf\u001B[39m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mNo matches for version=\u001B[39m\u001B[38;5;132;01m{\u001B[39;00mversion\u001B[38;5;132;01m!r}\u001B[39;00m\u001B[38;5;124m among \u001B[39m\u001B[38;5;132;01m{\u001B[39;00mcandidates\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m.\u001B[39m\u001B[38;5;130;01m\\n\u001B[39;00m\u001B[38;5;124m\"\u001B[39m\n\u001B[0;32m 214\u001B[0m \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mOften this can be fixed by updating altair_viewer:\u001B[39m\u001B[38;5;130;01m\\n\u001B[39;00m\u001B[38;5;124m\"\u001B[39m\n\u001B[0;32m 215\u001B[0m \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124m pip install -U altair_viewer\u001B[39m\u001B[38;5;124m\"\u001B[39m\n\u001B[0;32m 216\u001B[0m )\n\u001B[0;32m 217\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28mstr\u001B[39m(matches[\u001B[38;5;241m-\u001B[39m\u001B[38;5;241m1\u001B[39m])\n", 58 | "\u001B[1;31mNoMatchingVersions\u001B[0m: No matches for version='5.16.3' among ['4.0.2', '4.8.1', '4.17.0'].\nOften this can be fixed by updating altair_viewer:\n pip install -U altair_viewer" 59 | ] 60 | } 61 | ], 62 | "source": [ 63 | "import altair as alt\n", 64 | "chart = alt.Chart(ndvi_monthly).mark_bar().encode(\n", 65 | " x='Month',\n", 66 | " y='NDVI'\n", 67 | ")\n", 68 | "\n", 69 | "# 显示图表\n", 70 | "chart.show()" 71 | ], 72 | "metadata": { 73 | "collapsed": false 74 | } 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": null, 79 | "outputs": [], 80 | "source": [], 81 | "metadata": { 82 | "collapsed": false 83 | } 84 | } 85 | ], 86 | "metadata": { 87 | "kernelspec": { 88 | "display_name": "Python 3", 89 | "language": "python", 90 | "name": "python3" 91 | }, 92 | "language_info": { 93 | "codemirror_mode": { 94 | "name": "ipython", 95 | "version": 2 96 | }, 97 | "file_extension": ".py", 98 | 
"mimetype": "text/x-python", 99 | "name": "python", 100 | "nbconvert_exporter": "python", 101 | "pygments_lexer": "ipython2", 102 | "version": "2.7.6" 103 | } 104 | }, 105 | "nbformat": 4, 106 | "nbformat_minor": 0 107 | } 108 | -------------------------------------------------------------------------------- /Core/model_train.py: -------------------------------------------------------------------------------- 1 | # @Author : ChaoQiezi 2 | # @Time : 2024/1/3 16:54 3 | # @Email : chaoqiezi.one@qq.com 4 | 5 | """ 6 | This script is used to 构建lstm模型并训练 7 | 8 | 2024/5/11 增加Rs、Lon、Lat输入到模型中训练 9 | 10 | 2024/5/11 仅使用动态特征中的LST, PRCP, ET, RS进行训练 11 | """ 12 | 13 | import random 14 | import glob 15 | import os.path 16 | import numpy as np 17 | import pandas as pd 18 | import torch 19 | from torchsummary import summary 20 | from torch.utils.data import DataLoader, random_split 21 | from VEG.utils.utils import H5DatasetDecoder, cal_r2 22 | from VEG.utils.models import LSTMModel, LSTMModelDynamic 23 | from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error 24 | 25 | 26 | def set_seed(seed=42): 27 | random.seed(seed) 28 | np.random.seed(seed) 29 | os.environ['PYTHONHASHSEED'] = str(seed) 30 | torch.manual_seed(seed) 31 | torch.cuda.manual_seed(seed) 32 | torch.cuda.manual_seed_all(seed) # 如果使用多GPU 33 | torch.backends.cudnn.deterministic = True 34 | torch.backends.cudnn.benchmark = False 35 | 36 | 37 | set_seed(42) 38 | 39 | # 准备 40 | train_path = r'E:\FeaturesTargets\train.h5' 41 | eval_path = r'E:\FeaturesTargets\eval.h5' 42 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 43 | out_model_dir = r'E:\Models' 44 | dynamic_features_name = [ 45 | 'LST_MAX', 46 | 'PRCP', 47 | 'ET', 48 | # 'Qs', 49 | # 'Qsb', 50 | # 'TWSC', 51 | 'Rs' 52 | ] 53 | static_feature_name = [ 54 | 'Slope', 55 | 'DEM', 56 | 'Lon', 57 | 'Lat' 58 | ] 59 | # 创建LSTM模型实例并移至GPU 60 | model = LSTMModel(4, 256, 4, 12).to('cuda' if torch.cuda.is_available() else 'cpu') 61 | summary(model, input_data=[(12, 4), (4,)]) 62 | # model = LSTMModelDynamic(4, 256, 4, 12).to('cuda' if torch.cuda.is_available() else 'cpu') 63 | # summary(model, input_data=[(12, 4)]) 64 | batch_size = 256 65 | 66 | # generator = torch.Generator().manual_seed(42) # 指定随机种子 67 | # train_dataset, eval_dataset, sample_dataset = random_split(dataset, (0.8, 0.195, 0.005), generator=generator) 68 | # train_dataset, eval_dataset = random_split(dataset, (0.8, 0.2), generator=generator) 69 | # 创建数据加载器 70 | train_dataset = H5DatasetDecoder(train_path) # 创建自定义数据集实例 71 | eval_dataset = H5DatasetDecoder(eval_path) 72 | train_data_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) 73 | eval_data_loader = DataLoader(eval_dataset, batch_size=batch_size, shuffle=True) 74 | # 训练参数 75 | criterion = torch.nn.MSELoss() 76 | optimizer = torch.optim.Adam(model.parameters(), lr=0.002) # 初始学习率设置为0.001 77 | epochs_num = 30 78 | model.train() # 切换为训练模式 79 | 80 | 81 | def model_train(data_loader, epochs_num: int = 25, save_path: str = None, device='cuda'): 82 | # 创建新的模型实例 83 | model = LSTMModel(4, 256, 4, 12).to(device) 84 | optimizer = torch.optim.Adam(model.parameters(), lr=0.0005) # 初始学习率设置为0.001 85 | epochs_loss = [] 86 | for epoch in range(epochs_num): 87 | train_loss = [] 88 | for dynamic_inputs, static_inputs, targets in data_loader: 89 | # if feature_ix is not None: 90 | # if dynamic: 91 | # batch_size, _, _ = dynamic_inputs.shape 92 | # shuffled_indices = torch.randperm(batch_size) 93 | # # dynamic_inputs[:, :, feature_ix] = 
torch.tensor(np.random.permutation(dynamic_inputs[:, :, feature_ix])) 94 | # dynamic_inputs[:, :, feature_ix] = torch.tensor(dynamic_inputs[shuffled_indices, :, feature_ix]) 95 | # else: 96 | # batch_size, _ = static_inputs.shape 97 | # shuffled_indices = torch.randperm(batch_size) 98 | # # static_inputs[:, feature_ix] = torch.tensor(np.random.permutation(static_inputs[shuffled_indices, feature_ix])) 99 | # static_inputs[:, feature_ix] = torch.tensor(static_inputs[shuffled_indices, feature_ix]) 100 | dynamic_inputs, static_inputs, targets = dynamic_inputs.to(device), static_inputs.to(device), targets.to( 101 | device) 102 | 103 | """正常""" 104 | # 前向传播 105 | outputs = model(dynamic_inputs, static_inputs) 106 | # 计算损失 107 | loss = criterion(outputs, targets) 108 | # 反向传播和优化 109 | loss.backward() 110 | optimizer.step() 111 | # scheduler.step() # 更新学习率 112 | 113 | optimizer.zero_grad() # 清除梯度 114 | train_loss.append(loss.item()) 115 | print(f'Epoch {epoch + 1}/{epochs_num}, Loss: {np.mean(train_loss)}') 116 | epochs_loss.append(np.mean(train_loss)) 117 | 118 | if save_path: 119 | torch.save(model.state_dict(), save_path) 120 | 121 | return epochs_loss 122 | 123 | 124 | def model_eval_whole(model_path: str, data_loader, device='cuda'): 125 | # 加载模型 126 | model = LSTMModel(4, 256, 4, 12).to(device) 127 | model.load_state_dict(torch.load(model_path)) 128 | 129 | # 评估 130 | model.eval() # 评估模式 131 | all_outputs = [] 132 | all_targets = [] 133 | with torch.no_grad(): 134 | for dynamic_inputs, static_inputs, targets in data_loader: 135 | dynamic_inputs, static_inputs, targets = dynamic_inputs.to(device), static_inputs.to(device), targets.to( 136 | device) 137 | outputs = model(dynamic_inputs, static_inputs) 138 | all_outputs.append(outputs.cpu()) # outputs/targets: (batch_size, time_steps) 139 | all_targets.append(targets.cpu()) 140 | 141 | all_outputs = np.concatenate(all_outputs, axis=0) 142 | all_targets = np.concatenate(all_targets, axis=0) 143 | 144 | # mse_per_step = [] 145 | # mae_per_step = [] 146 | # r2_per_step = [] 147 | # rmse_per_step = [] 148 | # for time_step in range(12): 149 | # mse_step = mean_squared_error(all_targets[:, time_step], all_outputs[:, time_step]) 150 | # mae_step = mean_absolute_error(all_targets[:, time_step], all_outputs[:, time_step]) 151 | # r2_step = r2_score(all_targets[:, time_step], all_outputs[:, time_step]) 152 | # rmse_step = np.sqrt(mse_step) 153 | # 154 | # mse_per_step.append(mse_step) 155 | # mae_per_step.append(mae_step) 156 | # r2_per_step.append(r2_step) 157 | # rmse_per_step.append(rmse_step) 158 | 159 | # mse = np.mean(mse_per_step) 160 | # mae = np.mean(mae_per_step) 161 | # r2 = np.mean(r2_per_step) 162 | # rmse = np.mean(rmse_per_step) 163 | 164 | # 不区分月份求取指标(视为整体) 165 | mse_step = mean_squared_error(all_targets.reshape(-1), all_outputs.reshape(-1)) 166 | mae_step = mean_absolute_error(all_targets.reshape(-1), all_outputs.reshape(-1)) 167 | r2_step = r2_score(all_targets.reshape(-1), all_outputs.reshape(-1)) 168 | rmse_step = np.sqrt(mse_step) 169 | return mse_step, mae_step, r2_step, rmse_step 170 | 171 | # return mse_per_step, mae_per_step, r2_per_step, rmse_per_step, all_outputs, all_targets 172 | 173 | 174 | 175 | def model_eval(model_path: str, data_loader, device='cuda'): 176 | # 加载模型 177 | model = LSTMModel(4, 256, 4, 12).to(device) 178 | model.load_state_dict(torch.load(model_path)) 179 | 180 | # 评估 181 | model.eval() # 评估模式 182 | all_outputs = [] 183 | all_targets = [] 184 | with torch.no_grad(): 185 | for dynamic_inputs, static_inputs, 
targets in data_loader: 186 | dynamic_inputs, static_inputs, targets = dynamic_inputs.to(device), static_inputs.to(device), targets.to( 187 | device) 188 | outputs = model(dynamic_inputs, static_inputs) 189 | all_outputs.append(outputs.cpu()) # outputs/targets: (batch_size, time_steps) 190 | all_targets.append(targets.cpu()) 191 | 192 | all_outputs = np.concatenate(all_outputs, axis=0) 193 | all_targets = np.concatenate(all_targets, axis=0) 194 | 195 | mse_per_step = [] 196 | mae_per_step = [] 197 | r2_per_step = [] 198 | rmse_per_step = [] 199 | for time_step in range(12): 200 | mse_step = mean_squared_error(all_targets[:, time_step], all_outputs[:, time_step]) 201 | mae_step = mean_absolute_error(all_targets[:, time_step], all_outputs[:, time_step]) 202 | r2_step = r2_score(all_targets[:, time_step], all_outputs[:, time_step]) 203 | rmse_step = np.sqrt(mse_step) 204 | 205 | mse_per_step.append(mse_step) 206 | mae_per_step.append(mae_step) 207 | r2_per_step.append(r2_step) 208 | rmse_per_step.append(rmse_step) 209 | 210 | return mse_per_step, mae_per_step, r2_per_step, rmse_per_step, all_outputs, all_targets 211 | 212 | if __name__ == '__main__': 213 | df = pd.DataFrame() 214 | # 常规训练 215 | df['normal_epochs_loss'] = model_train(train_data_loader, save_path=os.path.join(out_model_dir, 'normal_model_V07.pth')) 216 | print('>>> 常规训练结束') 217 | # 特征重要性训练 218 | # 动态特征 219 | for feature_ix in range(4): 220 | train_dataset = H5DatasetDecoder(train_path, shuffle_feature_ix=feature_ix, dynamic=True) # 创建自定义数据集实例 221 | train_data_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) 222 | 223 | cur_feature_name = dynamic_features_name[feature_ix] 224 | save_path = os.path.join(out_model_dir, cur_feature_name + '_model_V07.pth') 225 | df[cur_feature_name + '_epochs_loss'] = \ 226 | model_train(train_data_loader, save_path=save_path) 227 | print('>>> {}乱序排列 训练结束'.format(cur_feature_name)) 228 | # 静态特征 229 | for feature_ix in range(4): 230 | train_dataset = H5DatasetDecoder(train_path, shuffle_feature_ix=feature_ix, dynamic=False) # 创建自定义数据集实例 231 | train_data_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) 232 | 233 | cur_feature_name = static_feature_name[feature_ix] 234 | save_path = os.path.join(out_model_dir, cur_feature_name + '_model_V07.pth') 235 | df[cur_feature_name + '_epochs_loss'] = \ 236 | model_train(train_data_loader, save_path=save_path) 237 | print('>>> {}乱序排列 训练结束'.format(cur_feature_name)) 238 | df.to_excel(r'E:\Models\training_eval_results\training_loss_V07.xlsx') 239 | 240 | # 评估 241 | indicator_whole = pd.DataFrame() 242 | indicator = pd.DataFrame() 243 | model_paths = glob.glob(os.path.join(out_model_dir, '*_model_V07.pth')) 244 | for model_path in model_paths: 245 | cur_model_name = os.path.basename(model_path).rsplit('_model')[0] 246 | mse_step, mae_step, r2_step, rmse_step = model_eval_whole(model_path, eval_data_loader) 247 | indicator_whole[cur_model_name + '_evaluate_mse'] = [mse_step] 248 | indicator_whole[cur_model_name + '_evaluate_mae'] = [mae_step] 249 | indicator_whole[cur_model_name + '_evaluate_r2'] = [r2_step] 250 | indicator_whole[cur_model_name + '_evaluate_rmse'] = [rmse_step] 251 | 252 | mse_per_step, mae_per_step, r2_per_step, rmse_per_step, all_outputs, all_targets = model_eval(model_path, eval_data_loader) 253 | 254 | all_outputs_targets = np.concatenate((all_outputs, all_targets), axis=1) 255 | columns = [*['outputs_{:02}'.format(month) for month in range(1, 13)], *['targets_{:02}'.format(month) for month in 
range(1, 13)]] 256 | outputs_targets = pd.DataFrame(all_outputs_targets, columns=columns) 257 | indicator[cur_model_name + '_evaluate_mse'] = mse_per_step 258 | indicator[cur_model_name + '_evaluate_mae'] = mae_per_step 259 | indicator[cur_model_name + '_evaluate_r2'] = r2_per_step 260 | indicator[cur_model_name + '_evaluate_rmse'] = rmse_per_step 261 | outputs_targets.to_excel(r'E:\Models\training_eval_results\{}_outputs_targets_V07.xlsx'.format(cur_model_name)) 262 | print('>>> {} 重要性评估完毕'.format(cur_model_name)) 263 | indicator.loc['均值指标'] = np.mean(indicator, axis=0) 264 | indicator.to_excel(r'E:\Models\training_eval_results\eval_indicators_V07.xlsx') 265 | indicator_whole.to_excel(r'E:\Models\training_eval_results\eval_indicators_整体_V07.xlsx') 266 | # model.eval() 267 | # eval_loss = [] 268 | # with torch.no_grad(): 269 | # for dynamic_inputs, static_inputs, targets in data_loader: 270 | # dynamic_inputs = dynamic_inputs.to('cuda' if torch.cuda.is_available() else 'cpu') 271 | # static_inputs = static_inputs.to('cuda' if torch.cuda.is_available() else 'cpu') 272 | # targets = targets.to('cuda' if torch.cuda.is_available() else 'cpu') 273 | # # 前向传播 274 | # outputs = model(dynamic_inputs, static_inputs) 275 | # # 计算损失 276 | # loss = criterion(outputs, targets) 277 | # r2 = cal_r2(outputs, targets) 278 | # print('预测项:', outputs) 279 | # print('目标项:', targets) 280 | # print(f'MSE Loss: {loss.item()}') 281 | # break 282 | # eval_loss.append(loss.item()) 283 | # print(f'Loss: {np.mean(eval_loss)}') 284 | # print(f'R2:', r2) 285 | 286 | 287 | 288 | # # 取 289 | # with h5py.File(r'E:\FeaturesTargets\features_targets.h5', 'r') as h5: 290 | # features = np.transpose(h5['2003/features1'][:], (1, 0, 2)) # shape=(样本数, 时间步, 特征项) 291 | # targets = np.transpose(h5['2003/targets'][:], (1, 0)) # shape=(样本数, 时间步) 292 | # static_features = np.column_stack((h5['2003/features2'][:], h5['dem'][:])) 293 | # mask1 = ~np.any(np.isnan(features), axis=(1, 2)) 294 | # mask2 = ~np.any(np.isnan(targets), axis=(1,)) 295 | # mask3 = ~np.any(np.isnan(static_features), axis=(1, )) 296 | # mask = (mask1 & mask2 & mask3) 297 | # features = features[mask, :, :] 298 | # targets = targets[mask, :] 299 | # static_features = static_features[mask, :] 300 | # print(features.shape) 301 | # print(targets.shape) 302 | # for ix in range(6): 303 | # feature = features[:, :, ix] 304 | # features[:, :, ix] = (feature - feature.mean()) / feature.std() 305 | # if ix <= 1: 306 | # feature = static_features[:, ix] 307 | # static_features[:, ix] = (feature - feature.mean()) / feature.std() 308 | # 309 | # features_tensor = torch.tensor(features, dtype=torch.float32) 310 | # targets_tensor = torch.tensor(targets, dtype=torch.float32) 311 | # static_features_tensor = torch.tensor(static_features, dtype=torch.float32) 312 | # 313 | # # 创建包含动态特征、静态特征和目标的数据集 314 | # dataset = TensorDataset(features_tensor, static_features_tensor, targets_tensor) 315 | # train_dataset, eval_dataset = random_split(dataset, [8000, 10238 - 8000]) -------------------------------------------------------------------------------- /Core/process_modis.py: -------------------------------------------------------------------------------- 1 | # @Author : ChaoQiezi 2 | # @Time : 2023/12/14 6:31 3 | # @FileName : process_modis.py 4 | # @Email : chaoqiezi.one@qq.com 5 | 6 | """ 7 | This script is used to 对MODIS GRID产品(hdf4文件)进行批量镶嵌和重投影并输出为GeoTIFF文件 8 | 9 | <说明> 10 | # pyhdf模块相关 11 | 对于读取HDF4文件的pyhdf模块需要依据python版本安装指定的whl文件才可正常运行, 12 | 下载wheel文件见: 
https://www.lfd.uci.edu/~gohlke/pythonlibs/ 13 | 安装: cmd ==> where python ==> 跳转指定python路径 ==> cd Scripts ==> pip install wheel文件的绝对路径 14 | 15 | # 数据集 16 | MCD12Q1为土地利用数据 17 | MOD11A2为地表温度数据 18 | MOD13A2为植被指数数据(包括NDVI\EVI) 19 | 20 | # 相关链接 21 | CSDN博客: https://blog.csdn.net/m0_63001937/article/details/134995867 22 | 微信博文: https://mp.weixin.qq.com/s/6oeUEdazz8FL1pRnQQFhMA 23 | 24 | """ 25 | 26 | import os 27 | import re 28 | import time 29 | from glob import glob 30 | from typing import Union 31 | from datetime import datetime 32 | from math import ceil, floor 33 | from threading import Lock 34 | from concurrent.futures import ThreadPoolExecutor # 线程池 35 | 36 | import numpy as np 37 | from pyhdf.SD import SD 38 | from osgeo import gdal, osr 39 | from scipy import stats 40 | 41 | 42 | def img_mosaic(mosaic_paths: list, mosaic_ds_name: str, return_all: bool = True, img_nodata: Union[int, float] = -1, 43 | img_type: Union[np.int32, np.float32, None] = None, unit_conversion: bool = False, 44 | scale_factor_op: str = 'multiply', mosaic_mode: str = 'last'): 45 | """ 46 | 该函数用于对列表中的所有HDF4文件进行镶嵌 47 | :param mosaic_mode: 镶嵌模式, 默认是Last(即如果有存在像元重叠, mosaic_paths中靠后影像的像元将覆盖其), 48 | 可选: last, mean, max, min 49 | :param scale_factor_op: 比例因子的运算符, 默认是乘以(可选: multiply, divide), 该参数尽在unit_conversion为True时生效 50 | :param unit_conversion: 是否进行单位换算 51 | :param mosaic_ds_name: 待镶嵌的数据集名称 52 | :param mosaic_paths: 多个HDF4文件路径组成的字符串列表 53 | :param return_all: 是否一同返回仿射变换、镶嵌数据集的坐标系等参数 54 | :return: 默认返回镶嵌好的数据集 55 | :param img_type: 待镶嵌影像的数据类型 56 | :param img_nodata: 影像中的无效值设置 57 | 58 | 镶嵌策略是last模式, 59 | """ 60 | 61 | # 获取镶嵌范围 62 | x_mins, x_maxs, y_mins, y_maxs = [], [], [], [] 63 | for mosaic_path in mosaic_paths: 64 | hdf = SD(mosaic_path) # 默认只读 65 | # 获取元数据 66 | metadata = hdf.__getattr__('StructMetadata.0') 67 | # 获取角点信息 68 | ul_pt = [float(x) for x in re.findall(r'UpperLeftPointMtrs=\((.*)\)', metadata)[0].split(',')] 69 | lr_pt = [float(x) for x in re.findall(r'LowerRightMtrs=\((.*)\)', metadata)[0].split(',')] 70 | x_mins.append(ul_pt[0]) 71 | x_maxs.append(lr_pt[0]) 72 | y_mins.append(lr_pt[1]) 73 | y_maxs.append(ul_pt[1]) 74 | else: 75 | # 计算分辨率 76 | col = int(re.findall(r'XDim=(.*?)\n', metadata)[0]) 77 | row = int(re.findall(r'YDim=(.*?)\n', metadata)[0]) 78 | x_res = (lr_pt[0] - ul_pt[0]) / col 79 | y_res = (ul_pt[1] - lr_pt[1]) / row 80 | # 如果img_type没有指定, 那么数据类型默认为与输入相同 81 | if img_type is None: 82 | img_type = hdf.select(mosaic_ds_name)[:].dtype 83 | # 获取数据集的坐标系参数并转化为proj4字符串格式 84 | projection_param = [float(_param) for _param in re.findall(r'ProjParams=\((.*?)\)', metadata)[0].split(',')] 85 | mosaic_img_proj4 = "+proj={} +R={:0.4f} +lon_0={:0.4f} +lat_0={:0.4f} +x_0={:0.4f} " \ 86 | "+y_0={:0.4f} ".format('sinu', projection_param[0], projection_param[4], projection_param[5], 87 | projection_param[6], projection_param[7]) 88 | # 关闭文件, 释放资源 89 | hdf.end() 90 | x_min, x_max, y_min, y_max = min(x_mins), max(x_maxs), min(y_mins), max(y_maxs) 91 | 92 | # 镶嵌 93 | col = ceil((x_max - x_min) / x_res) 94 | row = ceil((y_max - y_min) / y_res) 95 | mosaic_imgs = [] # 用于存储各个影像 96 | for ix, mosaic_path in enumerate(mosaic_paths): 97 | mosaic_img = np.full((row, col), img_nodata, dtype=img_type) # 初始化 98 | hdf = SD(mosaic_path) 99 | target_ds = hdf.select(mosaic_ds_name) 100 | # 读取数据集和预处理 101 | target = target_ds.get().astype(img_type) 102 | valid_range = target_ds.attributes()['valid_range'] 103 | target[(target < valid_range[0]) | (target > valid_range[1])] = img_nodata # 限定有效范围 104 | if unit_conversion: # 进行单位换算 
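            # Note on the conversion that follows (assumed from the MODIS product user guides;
            # verify against the HDF attributes of your own files): MOD11A2 LST stores a
            # multiplicative scale_factor (0.02), whereas MOD13A2 NDVI/EVI stores
            # scale_factor = 10000 that is conventionally applied as a divisor; this is why
            # scale_factor_op supports both 'multiply' and 'divide'.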
105 | scale_factor = target_ds.attributes()['scale_factor'] 106 | add_offset = target_ds.attributes()['add_offset'] 107 | # 判断比例因子的运算符 108 | if scale_factor_op == 'multiply': 109 | target[target != img_nodata] = target[target != img_nodata] * scale_factor + add_offset 110 | elif scale_factor_op == 'divide': 111 | target[target != img_nodata] = target[target != img_nodata] / scale_factor + add_offset 112 | # 计算当前镶嵌范围 113 | start_row = floor((y_max - (y_maxs[ix] - x_res / 2)) / y_res) 114 | start_col = floor(((x_mins[ix] + x_res / 2) - x_min) / x_res) 115 | end_row = start_row + target.shape[0] 116 | end_col = start_col + target.shape[1] 117 | mosaic_img[start_row:end_row, start_col:end_col] = target 118 | mosaic_imgs.append(mosaic_img) 119 | 120 | # 释放资源 121 | target_ds.endaccess() 122 | hdf.end() 123 | 124 | # 判断镶嵌模式 125 | if mosaic_mode == 'last': 126 | mosaic_img = mosaic_imgs[0].copy() 127 | for img in mosaic_imgs: 128 | mosaic_img[img != img_nodata] = img[img != img_nodata] 129 | elif mosaic_mode == 'mean': 130 | mosaic_imgs = np.asarray(mosaic_imgs) # mosaic_img.shape = (mosaic_num, rows, cols) 131 | mask = mosaic_imgs == img_nodata 132 | mosaic_img = np.ma.array(mosaic_imgs, mask=mask).mean(axis=0).filled(img_nodata) 133 | elif mosaic_mode == 'max': 134 | mosaic_imgs = np.asarray(mosaic_imgs) # mosaic_img.shape = (mosaic_num, rows, cols) 135 | mask = mosaic_imgs == img_nodata 136 | mosaic_img = np.ma.array(mosaic_imgs, mask=mask).max(axis=0).filled(img_nodata) 137 | elif mosaic_mode == 'min': 138 | mosaic_imgs = np.asarray(mosaic_imgs) # mosaic_img.shape = (mosaic_num, rows, cols) 139 | mask = mosaic_imgs == img_nodata 140 | mosaic_img = np.ma.array(mosaic_imgs, mask=mask).min(axis=0).filled(img_nodata) 141 | else: 142 | raise ValueError('不支持的镶嵌模式: {}'.format(mosaic_mode)) 143 | 144 | if return_all: 145 | return mosaic_img, [x_min, x_res, 0, y_max, 0, -y_res], mosaic_img_proj4 146 | 147 | return mosaic_img 148 | 149 | 150 | def img_warp(src_img: np.ndarray, out_path: str, transform: list, src_proj4: str, out_res: float, 151 | nodata: Union[int, float] = None, resample: str = 'nearest') -> None: 152 | """ 153 | 该函数用于对正弦投影下的栅格矩阵进行重投影(GLT校正), 得到WGS84坐标系下的栅格矩阵并输出为TIFF文件 154 | :param src_img: 待重投影的栅格矩阵 155 | :param out_path: 输出路径 156 | :param transform: 仿射变换参数([x_min, x_res, 0, y_max, 0, -y_res], 旋转参数为0是常规选项) 157 | :param out_res: 输出的分辨率(栅格方形) 158 | :param nodata: 设置为NoData的数值 159 | :param out_type: 输出的数据类型 160 | :param resample: 重采样方法(默认是最近邻, ['nearest', 'bilinear', 'cubic']) 161 | :param src_proj4: 表达源数据集(src_img)的坐标系参数(以proj4字符串形式) 162 | :return: None 163 | """ 164 | 165 | # 输出数据类型 166 | if np.issubdtype(src_img.dtype, np.integer): 167 | out_type = gdal.GDT_Int32 168 | elif np.issubdtype(src_img.dtype, np.floating): 169 | out_type = gdal.GDT_Float32 170 | else: 171 | raise ValueError("当前待校正数组类型为不支持的数据类型") 172 | resamples = {'nearest': gdal.GRA_NearestNeighbour, 'bilinear': gdal.GRA_Bilinear, 'cubic': gdal.GRA_Cubic} 173 | # 原始数据集创建(正弦投影) 174 | driver = gdal.GetDriverByName('MEM') # 在内存中临时创建 175 | src_ds = driver.Create("", src_img.shape[1], src_img.shape[0], 1, out_type) # 注意: 先传列数再传行数, 1表示单波段 176 | srs = osr.SpatialReference() 177 | srs.ImportFromProj4(src_proj4) 178 | """ 179 | 对于src_proj4, 依据元数据StructMetadata.0知: 180 | Projection=GCTP_SNSOID; ProjParams=(6371007.181000,0,0,0,0,0,0,0,0,0,0,0,0) 181 | 或数据集属性(MODIS_Grid_8Day_1km_LST/Data_Fields/Projection)知: 182 | :grid_mapping_name = "sinusoidal"; 183 | :longitude_of_central_meridian = 0.0; // double 184 | :earth_radius = 6371007.181; 
// double 185 | """ 186 | src_ds.SetProjection(srs.ExportToWkt()) # 设置投影信息 187 | src_ds.SetGeoTransform(transform) # 设置仿射参数 188 | src_ds.GetRasterBand(1).WriteArray(src_img) # 写入数据 189 | src_ds.GetRasterBand(1).SetNoDataValue(nodata) 190 | # 重投影信息(WGS84) 191 | dst_srs = osr.SpatialReference() 192 | dst_srs.ImportFromEPSG(4326) 193 | # 重投影 194 | dst_ds = gdal.Warp(out_path, src_ds, dstSRS=dst_srs, xRes=out_res, yRes=out_res, dstNodata=nodata, 195 | outputType=out_type, multithread=True, format='GTiff', resampleAlg=resamples[resample]) 196 | if dst_ds: # 释放缓存和资源 197 | dst_ds.FlushCache() 198 | src_ds, dst_ds = None, None 199 | 200 | 201 | def ydays2ym(file_path: str) -> str: 202 | """ 203 | 获取路径中的年积日并转化为年月日 204 | :param file_path: 文件路径 205 | :return: 返回表达年月日的字符串 206 | """ 207 | 208 | file_name = os.path.basename(file_path) 209 | ydays = file_name[9:16] 210 | date = datetime.strptime(ydays, "%Y%j") 211 | 212 | return date.strftime("%Y_%m") 213 | 214 | 215 | # 闭包 216 | def process_task(union_id, process_paths, ds_name, out_dir, description, nodata, out_res, resamlpe='nearest', 217 | temperature=False, img_type=np.float32, unit_conversion=True, scale_factor_op='multiply', 218 | mosaic_mode='last'): 219 | print_lock = Lock() # 线程锁 220 | 221 | # 处理 222 | def process_id(id: any = None): 223 | start_time = time.time() 224 | cur_mosaic_ixs = [_ix for _ix, _id in enumerate(union_id) if _id == id] 225 | # 镶嵌 226 | mosaic_paths = [process_paths[_ix] for _ix in cur_mosaic_ixs] 227 | mosaic_img, transform, mosaic_img_proj4 = img_mosaic(mosaic_paths, ds_name, img_nodata=nodata, 228 | img_type=img_type, unit_conversion=unit_conversion, 229 | scale_factor_op=scale_factor_op, mosaic_mode=mosaic_mode) 230 | if temperature: # 若设置temperature, 则说明当前处理数据集为地表温度, 需要开尔文 ==> 摄氏度 231 | mosaic_img[mosaic_img != nodata] -= 273.15 232 | # 重投影 233 | reproj_path = os.path.join(out_dir, description + '_' + id + '.tiff') 234 | img_warp(mosaic_img, reproj_path, transform, mosaic_img_proj4, out_res, nodata, resample=resamlpe) 235 | end_time = time.time() 236 | 237 | with print_lock: # 避免打印混乱 238 | print("{}-{} 处理完毕: {:0.2f}s".format(description, id, end_time - start_time)) 239 | 240 | return process_id 241 | 242 | 243 | # 准备 244 | in_dir = 'F:\DATA\Cy_modis' # F:\Cy_modis\MCD12Q1_2001_2020、F:\Cy_modis\MOD11A2_2000_2022、F:\Cy_modis\MOD13A2_2001_2020 245 | out_dir = 'H:\Datasets\Objects\Veg' 246 | landuse_name = 'LC_Type1' # Land Cover Type 1: Annual International Geosphere-Biosphere Programme (IGBP) classification 247 | lst_name = 'LST_Day_1km' 248 | ndvi_name = '1 km 16 days NDVI' # 注意panoply上显示为: 1_km_16_days_NDVI, 实际上是做了显示上的优化, 原始名称为当前 249 | evi_name = '1 km 16 days EVI' # 注意panoply上显示为: 1_km_16_days_NDVI, 实际上是做了显示上的优化, 原始名称为当前 250 | out_landuse_res = 0.0045 # 500m 251 | out_lst_res = 0.009 # 1000m 252 | out_ndvi_res = 0.009 253 | out_evi_res = 0.009 254 | # 预准备 255 | out_landuse_dir = os.path.join(out_dir, 'Landuse') 256 | out_lst_dir = os.path.join(out_dir, 'LST_MIN') 257 | out_ndvi_dir = os.path.join(out_dir, 'NDVI_MIN') 258 | out_evi_dir = os.path.join(out_dir, 'evi') 259 | _ = [os.makedirs(_dir, exist_ok=True) for _dir in [out_landuse_dir, out_lst_dir, out_ndvi_dir, out_evi_dir]] 260 | 261 | # # 对MCD12Q1数据集(土地利用数据集)进行镶嵌和重投影(GLT校正) 262 | # landuse_paths = glob(os.path.join(in_dir, '**', 'MCD12Q1*.hdf'), recursive=True) # 迭代 263 | # union_id = [os.path.basename(_path)[9:13] for _path in landuse_paths] # 基于年份进行合并镶嵌的字段(年份-此处) 264 | # unique_id = set(union_id) # unique_id = np.unique(np.asarray(union_id)) # 不使用set是为保证原始顺序 
265 | # # 多线程处理 266 | # with ThreadPoolExecutor() as executer: 267 | # start_time = time.time() 268 | # process_id = process_task(union_id, landuse_paths, landuse_name, out_landuse_dir, 'Landuse', 255, out_landuse_res, 269 | # img_type=np.int32, unit_conversion=False) 270 | # executer.map(process_id, unique_id) 271 | # end_time = time.time() 272 | # print('MCD12Q1(土地利用数据集)预处理完毕: {:0.2f}s '.format(end_time - start_time)) 273 | # # 常规处理 274 | # for id in unique_id: 275 | # start_time = time.time() 276 | # cur_mosaic_ixs = [_ix for _ix, _id in enumerate(union_id) if _id == id] 277 | # # 镶嵌 278 | # mosaic_paths = [landuse_paths[_ix] for _ix in cur_mosaic_ixs] 279 | # mosaic_img, transform, mosaic_img_proj4 = img_mosaic(mosaic_paths, landuse_name, img_nodata=255, img_type=np.int32) 280 | # # 重投影 281 | # reproj_path = os.path.join(out_landuse_dir, 'landuse_' + id + '.tiff') 282 | # img_warp(mosaic_img, reproj_path, transform, mosaic_img_proj4, out_landuse_res, 255, resample='nearest') 283 | # 284 | # # 打印输出 285 | # end_time = time.time() 286 | # print("Landuse-{} 处理完毕: {:0.2f}s".format(id, end_time - start_time)) 287 | 288 | # 对MOD12A2数据集(地表温度数据集)进行镶嵌和重投影(GLT校正) 289 | lst_paths = glob(os.path.join(in_dir, '**', 'MOD11A2*.hdf'), recursive=True) 290 | union_id = [ydays2ym(_path) for _path in lst_paths] 291 | unique_id = set(union_id) 292 | # 多线程处理 293 | with ThreadPoolExecutor() as executer: 294 | start_time = time.time() 295 | process_id = process_task(union_id, lst_paths, lst_name, out_lst_dir, 'LST_MIN', -65535, out_lst_res, resamlpe='cubic', 296 | temperature=True, unit_conversion=True, mosaic_mode='min') 297 | executer.map(process_id, unique_id) 298 | end_time = time.time() 299 | print('MOD11A2(地表温度数据集)预处理完毕: {:0.2f}s'.format(end_time - start_time)) 300 | # # 常规处理 301 | # for id in unique_id: 302 | # start_time = time.time() 303 | # cur_mosaic_ixs = [_ix for _ix, _id in enumerate(union_id) if _id == id] 304 | # # 镶嵌 305 | # mosaic_paths = [lst_paths[_ix] for _ix in cur_mosaic_ixs] 306 | # mosaic_img, transform, mosaic_img_proj4 = img_mosaic(mosaic_paths, lst_name, img_nodata=-65535, 307 | # img_type=np.float32, unit_conversion=True) 308 | # # 开尔文 ==> 摄氏度 309 | # mosaic_img -= 273.15 310 | # # 重投影 311 | # reproj_path = os.path.join(out_lst_dir, 'lst_' + id + '.tiff') 312 | # img_warp(mosaic_img, reproj_path, transform, mosaic_img_proj4, out_lst_res, -65535, resample='cubic') 313 | # 314 | # # 打印输出 315 | # end_time = time.time() 316 | # print("LST-{} 处理完毕: {:0.2f}s".format(id, end_time - start_time)) 317 | 318 | # 对MOD13A2数据集(NDVI数据集)进行镶嵌和重投影(GLT校正) 319 | ndvi_paths = glob(os.path.join(in_dir, '**', 'MOD13A2*.hdf'), recursive=True) 320 | union_id = [ydays2ym(_path) for _path in ndvi_paths] 321 | unique_id = np.unique(np.asarray(union_id)) 322 | # 多线程处理 323 | with ThreadPoolExecutor() as executer: 324 | start_time = time.time() 325 | process_id = process_task(union_id, ndvi_paths, ndvi_name, out_ndvi_dir, 'NDVI_MIN', -65535, out_ndvi_res, 326 | resamlpe='cubic', unit_conversion=True, scale_factor_op='divide', mosaic_mode='min') 327 | executer.map(process_id, unique_id) 328 | # end_time = time.time() 329 | # print('MCD13A2(NDVI数据集)预处理完毕: {:0.2f}s'.format(end_time - start_time)) 330 | # 常规处理 331 | # for id in unique_id: 332 | # start_time = time.time() 333 | # cur_mosaic_ixs = [_ix for _ix, _id in enumerate(union_id) if _id == id] 334 | # # 镶嵌 335 | # mosaic_paths = [ndvi_paths[_ix] for _ix in cur_mosaic_ixs] 336 | # mosaic_img, transform, mosaic_img_proj4 = img_mosaic(mosaic_paths, ndvi_name, 
img_nodata=-65535, img_type=np.float32, 337 | # unit_conversion=True, scale_factor_op='divide') 338 | # # 重投影 339 | # reproj_path = os.path.join(out_ndvi_dir, 'ndvi_' + id + '.tiff') 340 | # img_warp(mosaic_img, reproj_path, transform, mosaic_img_proj4, out_ndvi_res, -65535, resample='cubic') 341 | # 342 | # # 打印输出 343 | # end_time = time.time() 344 | # print("NDVI-{} 处理完毕: {:0.2f}s".format(id, end_time - start_time)) 345 | 346 | 347 | # 对MOD13A2数据集(EVI数据集)进行镶嵌和重投影(GLT校正) 348 | evi_paths = glob(os.path.join(in_dir, '**', 'MOD13A2*.hdf'), recursive=True) 349 | union_id = [ydays2ym(_path) for _path in evi_paths] 350 | unique_id = np.unique(np.asarray(union_id)) 351 | # 多线程处理 352 | with ThreadPoolExecutor() as executer: 353 | start_time = time.time() 354 | process_id = process_task(union_id, evi_paths, evi_name, out_evi_dir, 'EVI', -65535, out_evi_res, 355 | resamlpe='cubic', unit_conversion=True, scale_factor_op='divide', mosaic_mode='max') 356 | executer.map(process_id, unique_id) 357 | end_time = time.time() 358 | print('MOD13A2(EVI数据集)预处理完毕: {:0.2f}s '.format(end_time - start_time)) -------------------------------------------------------------------------------- /Assets/LSTM-master/README.md: -------------------------------------------------------------------------------- 1 | ```python 2 | import torch 3 | import torch as t 4 | import numpy as np 5 | import torch.nn as nn 6 | import torch.optim as optim 7 | from torchnet import meter 8 | import xarray as xr 9 | import rioxarray as rxr 10 | ``` 11 | 12 | 13 | ```python 14 | torch.cuda.is_available() 15 | ``` 16 | 17 | 18 | 19 | 20 | True 21 | 22 | 23 | 24 | 25 | ```python 26 | precipitation_data = rxr.open_rasterio('data/prcp.tif').values 27 | 28 | # 将数据转换为 PyTorch 张量 29 | precipitation_data = torch.tensor(precipitation_data, dtype=torch.float32) 30 | 31 | precipitation_mean = torch.mean(precipitation_data, 0) 32 | precipitation_std = torch.std(precipitation_data, 0) 33 | precipitation = (precipitation_data - precipitation_mean) / precipitation_std 34 | 35 | precipitation_re = precipitation.reshape(183,-1).transpose(0,1) 36 | ``` 37 | 38 | 39 | ```python 40 | from utils import plot 41 | import matplotlib.pyplot as plt 42 | import cartopy.crs as ccrs 43 | import numpy as np 44 | import xarray as xr 45 | file_name='data/train/prcp.tif' 46 | ds=xr.open_dataset(file_name) 47 | data = ds['band_data'][7] 48 | 49 | fig = plt.figure() 50 | proj = ccrs.Robinson() #ccrs.Robinson()ccrs.Mollweide()Mollweide() 51 | ax = fig.add_subplot(111, projection=proj) 52 | levels = np.linspace(0, 0.25, num=9) 53 | plot.one_map_flat(data, ax, levels=levels, cmap="RdBu", mask_ocean=False, add_coastlines=True, add_land=True, colorbar=True, plotfunc="pcolormesh") 54 | ``` 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | ![png](output_3_1.png) 66 | 67 | 68 | 69 | 70 | ```python 71 | # 创建二维矩阵 72 | import random 73 | matrix = torch.mean(torch.stack([torch.mean(precipitation_re, 1)], 1), 1).flatten() 74 | # 将矩阵中值为NaN的元素置为0 75 | matrix[torch.isnan(matrix)] = 0 76 | 77 | # 获取所有不为NaN的元素的索引 78 | non_negative_indices = torch.nonzero(matrix) 79 | precipitation_re = precipitation_re[non_negative_indices.flatten(), :] 80 | ``` 81 | 82 | 83 | ```python 84 | class Config(object): 85 | t0 = 155 #155 86 | t1 = 12 87 | t = t0 + t1 88 | train_num = 8000 #8 89 | validation_num = 1000 #1 90 | test_num = 1000 #1 91 | in_channels = 1 92 | batch_size = 500 #500 NSE 0.75 93 | lr = .0005 # learning rate 94 | epochs = 100 95 | ``` 96 | 97 | 98 | ```python 99 | import torch 100 | import 
matplotlib.pyplot as plt 101 | import numpy as np 102 | from torch.utils.data import Dataset 103 | 104 | class time_series_decoder_paper(Dataset): 105 | """synthetic time series dataset from section 5.1""" 106 | 107 | def __init__(self,t0=120,N=4500,dx=None,dy=None,transform=None): 108 | """ 109 | Args: 110 | t0: previous t0 data points to predict from 111 | N: number of data points 112 | transform: any transformations to be applied to time series 113 | """ 114 | self.t0 = t0 115 | self.N = N 116 | self.dx = dx 117 | self.dy = dy 118 | self.transform = None 119 | 120 | 121 | # time points 122 | #self.x = torch.cat(N*[torch.arange(0,t0+24).type(torch.float).unsqueeze(0)]) 123 | self.x = dx 124 | self.fx = dy 125 | # self.fx = torch.cat([A1.unsqueeze(1)*torch.sin(np.pi*self.x[0,0:12]/6)+72 , 126 | # A2.unsqueeze(1)*torch.sin(np.pi*self.x[0,12:24]/6)+72 , 127 | # A3.unsqueeze(1)*torch.sin(np.pi*self.x[0,24:t0]/6)+72, 128 | # A4.unsqueeze(1)*torch.sin(np.pi*self.x[0,t0:t0+24]/12)+72],1) 129 | 130 | # add noise 131 | # self.fx = self.fx + torch.randn(self.fx.shape) 132 | 133 | self.masks = self._generate_square_subsequent_mask(t0) 134 | 135 | 136 | # print out shapes to confirm desired output 137 | print("x: ",self.x.shape, 138 | "fx: ",self.fx.shape) 139 | 140 | def __len__(self): 141 | return len(self.fx) 142 | 143 | def __getitem__(self,idx): 144 | if torch.is_tensor(idx): 145 | idx = idx.tolist() 146 | 147 | 148 | sample = (self.x[idx,:,:], #self.x[idx,:] 149 | self.fx[idx,:], 150 | self.masks) 151 | 152 | if self.transform: 153 | sample=self.transform(sample) 154 | 155 | return sample 156 | 157 | def _generate_square_subsequent_mask(self,t0): 158 | mask = torch.zeros(Config.t,Config.t) 159 | for i in range(0,Config.t0): 160 | mask[i,Config.t0:] = 1 161 | for i in range(Config.t0,Config.t): 162 | mask[i,i+1:] = 1 163 | mask = mask.float().masked_fill(mask == 1, float('-inf'))#.masked_fill(mask == 1, float(0.0)) 164 | return mask 165 | ``` 166 | 167 | 168 | ```python 169 | class TransformerTimeSeries(torch.nn.Module): 170 | """ 171 | Time Series application of transformers based on paper 172 | 173 | causal_convolution_layer parameters: 174 | in_channels: the number of features per time point 175 | out_channels: the number of features outputted per time point 176 | kernel_size: k is the width of the 1-D sliding kernel 177 | 178 | nn.Transformer parameters: 179 | d_model: the size of the embedding vector (input) 180 | 181 | PositionalEncoding parameters: 182 | d_model: the size of the embedding vector (positional vector) 183 | dropout: the dropout to be used on the sum of positional+embedding vector 184 | 185 | """ 186 | def __init__(self): 187 | super(TransformerTimeSeries,self).__init__() 188 | self.input_embedding = context_embedding(Config.in_channels+1,256,5) 189 | self.positional_embedding = torch.nn.Embedding(512,256) 190 | 191 | 192 | self.decode_layer = torch.nn.TransformerEncoderLayer(d_model=256,nhead=8) 193 | self.transformer_decoder = torch.nn.TransformerEncoder(self.decode_layer, num_layers=3) 194 | 195 | self.fc1 = torch.nn.Linear(256,1) 196 | 197 | def forward(self,x,y,attention_masks): 198 | 199 | # concatenate observed points and time covariate 200 | # (B*feature_size*n_time_points) 201 | #re z = torch.cat((y.unsqueeze(1),x.unsqueeze(1)),1) 202 | z = torch.cat((y,x),1) 203 | # input_embedding returns shape (Batch size,embedding size,sequence len) -> need (sequence len,Batch size,embedding_size) 204 | #re z_embedding = self.input_embedding(z).permute(2,0,1) 205 | z_embedding = 
self.input_embedding(z).unsqueeze(1).permute(3, 1, 0, 2) 206 | # get my positional embeddings (Batch size, sequence_len, embedding_size) -> need (sequence len,Batch size,embedding_size) 207 | x1 = x.type(torch.long) 208 | x1[x1 < 0] = 0 209 | positional_embeddings = self.positional_embedding(x1).permute(2, 1, 0, 3) 210 | #re #positional_embeddings = self.positional_embedding(x.type(torch.long)).permute(1,0,2) 211 | 212 | input_embedding = z_embedding+positional_embeddings 213 | input_embedding1 = torch.mean(input_embedding, 1) 214 | transformer_embedding = self.transformer_decoder(input_embedding1,attention_masks) 215 | 216 | output = self.fc1(transformer_embedding.permute(1,0,2)) 217 | 218 | return output 219 | import torch 220 | import numpy as np 221 | import matplotlib.pyplot as plt 222 | import torch.nn.functional as F 223 | 224 | class CausalConv1d(torch.nn.Conv1d): 225 | def __init__(self, 226 | in_channels, 227 | out_channels, 228 | kernel_size, 229 | stride=1, 230 | dilation=1, 231 | groups=1, 232 | bias=True): 233 | 234 | super(CausalConv1d, self).__init__( 235 | in_channels, 236 | out_channels, 237 | kernel_size=kernel_size, 238 | stride=stride, 239 | padding=0, 240 | dilation=dilation, 241 | groups=groups, 242 | bias=bias) 243 | 244 | self.__padding = (kernel_size - 1) * dilation 245 | 246 | def forward(self, input): 247 | return super(CausalConv1d, self).forward(F.pad(input, (self.__padding, 0))) 248 | 249 | 250 | class context_embedding(torch.nn.Module): 251 | def __init__(self,in_channels=Config.in_channels,embedding_size=256,k=5): 252 | super(context_embedding,self).__init__() 253 | self.causal_convolution = CausalConv1d(in_channels,embedding_size,kernel_size=k) 254 | 255 | def forward(self,x): 256 | x = self.causal_convolution(x) 257 | return torch.tanh(x) 258 | ``` 259 | 260 | 261 | ```python 262 | class LSTM_Time_Series(torch.nn.Module): 263 | def __init__(self,input_size=2,embedding_size=256,kernel_width=9,hidden_size=512): 264 | super(LSTM_Time_Series,self).__init__() 265 | 266 | self.input_embedding = context_embedding(input_size,embedding_size,kernel_width) 267 | 268 | self.lstm = torch.nn.LSTM(embedding_size,hidden_size,batch_first=True) 269 | 270 | self.fc1 = torch.nn.Linear(hidden_size,1) 271 | 272 | def forward(self,x,y): 273 | """ 274 | x: the time covariate 275 | y: the observed target 276 | """ 277 | # concatenate observed points and time covariate 278 | # (B,input size + covariate size,sequence length) 279 | # z = torch.cat((y.unsqueeze(1),x.unsqueeze(1)),1) 280 | z_obs = torch.cat((y.unsqueeze(1),x.unsqueeze(1)),1) 281 | if isLSTM: 282 | z_obs = torch.cat((y, x),1) 283 | # input_embedding returns shape (B,embedding size,sequence length) 284 | z_obs_embedding = self.input_embedding(z_obs) 285 | 286 | # permute axes (B,sequence length, embedding size) 287 | z_obs_embedding = self.input_embedding(z_obs).permute(0,2,1) 288 | 289 | # all hidden states from lstm 290 | # (B,sequence length,num_directions * hidden size) 291 | lstm_out,_ = self.lstm(z_obs_embedding) 292 | 293 | # input to nn.Linear: (N,*,Hin) 294 | # output (N,*,Hout) 295 | return self.fc1(lstm_out) 296 | ``` 297 | 298 | 299 | ```python 300 | from torch.utils.data import DataLoader 301 | import random 302 | random.seed(0) 303 | 304 | random_indices = random.sample(range(non_negative_indices.shape[0]), Config.train_num) 305 | random_indices1 = random.sample(range(non_negative_indices.shape[0]), Config.validation_num) 306 | random_indices2 = random.sample(range(non_negative_indices.shape[0]), 
Config.test_num) 307 | dx = torch.stack([torch.cat(Config.train_num*[torch.arange(0,Config.t).type(torch.float).unsqueeze(0)]).cuda()], 1) 308 | dx1 = torch.stack([torch.cat(Config.validation_num*[torch.arange(0,Config.t).type(torch.float).unsqueeze(0)]).cuda()], 1) 309 | dx2 = torch.stack([torch.cat(Config.test_num*[torch.arange(0,Config.t).type(torch.float).unsqueeze(0)]).cuda()], 1) 310 | train_dataset = time_series_decoder_paper(t0=Config.t0,N=Config.train_num,dx=dx ,dy=precipitation_re[np.array([random_indices]).flatten(),0:Config.t].unsqueeze(1)) 311 | validation_dataset = time_series_decoder_paper(t0=Config.t0,N=Config.validation_num,dx=dx1,dy=precipitation_re[np.array([random_indices1]).flatten(),0:Config.t].unsqueeze(1)) 312 | test_dataset = time_series_decoder_paper(t0=Config.t0,N=Config.test_num,dx=dx2,dy=precipitation_re[np.array([random_indices2]).flatten(),0:Config.t].unsqueeze(1)) 313 | ``` 314 | 315 | x: torch.Size([8000, 1, 167]) fx: torch.Size([8000, 1, 167]) 316 | x: torch.Size([1000, 1, 167]) fx: torch.Size([1000, 1, 167]) 317 | x: torch.Size([1000, 1, 167]) fx: torch.Size([1000, 1, 167]) 318 | 319 | 320 | 321 | ```python 322 | criterion = torch.nn.MSELoss() 323 | train_dl = DataLoader(train_dataset,batch_size=Config.batch_size,shuffle=True, generator=torch.Generator(device='cpu')) 324 | validation_dl = DataLoader(validation_dataset,batch_size=Config.batch_size, generator=torch.Generator(device='cpu')) 325 | test_dl = DataLoader(test_dataset,batch_size=Config.batch_size, generator=torch.Generator(device='cpu')) 326 | ``` 327 | 328 | 329 | ```python 330 | criterion_LSTM = torch.nn.MSELoss() 331 | ``` 332 | 333 | 334 | ```python 335 | LSTM = LSTM_Time_Series().cuda() 336 | ``` 337 | 338 | 339 | ```python 340 | def Dp(y_pred,y_true,q): 341 | return max([q*(y_pred-y_true),(q-1)*(y_pred-y_true)]) 342 | 343 | def Rp_num_den(y_preds,y_trues,q): 344 | numerator = np.sum([Dp(y_pred,y_true,q) for y_pred,y_true in zip(y_preds,y_trues)]) 345 | denominator = np.sum([np.abs(y_true) for y_true in y_trues]) 346 | return numerator,denominator 347 | def train_epoch(LSTM,train_dl,t0=Config.t0): 348 | LSTM.train() 349 | train_loss = 0 350 | n = 0 351 | for step,(x,y,_) in enumerate(train_dl): 352 | x = x.cuda() 353 | y = y.cuda() 354 | 355 | optimizer.zero_grad() 356 | output = LSTM(x,y) 357 | 358 | loss = criterion(output.squeeze()[:,(Config.t0-1):(Config.t0+Config.t1-1)],y.cuda()[:,0,Config.t0:]) 359 | loss.backward() 360 | optimizer.step() 361 | 362 | train_loss += (loss.detach().cpu().item() * x.shape[0]) 363 | n += x.shape[0] 364 | return train_loss/n 365 | def eval_epoch(LSTM,validation_dl,t0=Config.t0): 366 | LSTM.eval() 367 | eval_loss = 0 368 | n = 0 369 | with torch.no_grad(): 370 | for step,(x,y,_) in enumerate(train_dl): 371 | x = x.cuda() 372 | y = y.cuda() 373 | 374 | output = LSTM(x,y) 375 | loss = criterion(output.squeeze()[:,(Config.t0-1):(Config.t0+Config.t1-1)],y.cuda()[:,0,Config.t0:]) 376 | 377 | eval_loss += (loss.detach().cpu().item() * x.shape[0]) 378 | n += x.shape[0] 379 | 380 | return eval_loss/n 381 | def test_epoch(LSTM,test_dl,t0=Config.t0): 382 | with torch.no_grad(): 383 | predictions = [] 384 | observations = [] 385 | 386 | LSTM.eval() 387 | for step,(x,y,_) in enumerate(train_dl): 388 | x = x.cuda() 389 | y = y.cuda() 390 | 391 | output = LSTM(x,y) 392 | 393 | for p,o in zip(output.squeeze()[:,(Config.t0-1):(Config.t0+Config.t1-1)].cpu().numpy().tolist(),y.cuda()[:,0,Config.t0:].cpu().numpy().tolist()): 394 | 395 | predictions.append(p) 396 | 
observations.append(o) 397 | 398 | num = 0 399 | den = 0 400 | for y_preds,y_trues in zip(predictions,observations): 401 | num_i,den_i = Rp_num_den(y_preds,y_trues,.5) 402 | num+=num_i 403 | den+=den_i 404 | Rp = (2*num)/den 405 | 406 | return Rp 407 | ``` 408 | 409 | 410 | ```python 411 | train_epoch_loss = [] 412 | eval_epoch_loss = [] 413 | Rp_best = 30 414 | isLSTM = True 415 | optimizer = torch.optim.Adam(LSTM.parameters(), lr=Config.lr) 416 | 417 | for e,epoch in enumerate(range(Config.epochs)): 418 | train_loss = [] 419 | eval_loss = [] 420 | 421 | l_train = train_epoch(LSTM,train_dl) 422 | train_loss.append(l_train) 423 | 424 | l_eval = eval_epoch(LSTM,validation_dl) 425 | eval_loss.append(l_eval) 426 | 427 | Rp = test_epoch(LSTM,test_dl) 428 | 429 | if Rp_best > Rp: 430 | Rp_best = Rp 431 | 432 | with torch.no_grad(): 433 | print("Epoch {}: Train loss={} \t Eval loss = {} \t Rp={}".format(e,np.mean(train_loss),np.mean(eval_loss),Rp)) 434 | 435 | train_epoch_loss.append(np.mean(train_loss)) 436 | eval_epoch_loss.append(np.mean(eval_loss)) 437 | ``` 438 | 439 | Epoch 0: Train loss=1.169178232550621 Eval loss = 1.0225972533226013 Rp=0.9754564549482763 440 | Epoch 1: Train loss=1.0208212696015835 Eval loss = 1.0133976340293884 Rp=0.992397172460293 441 | Epoch 2: Train loss=1.0135348699986935 Eval loss = 1.0102826319634914 Rp=1.0010193148701694 442 | Epoch 3: Train loss=1.0075771994888783 Eval loss = 1.003737311810255 Rp=0.9874602462009079 443 | Epoch 4: Train loss=0.9986509047448635 Eval loss = 0.991193663328886 Rp=0.9851476850727866 444 | Epoch 5: Train loss=0.9815127141773701 Eval loss = 0.9676672779023647 Rp=0.9776791701859356 445 | Epoch 6: Train loss=0.9377330988645554 Eval loss = 0.8851083293557167 Rp=0.9021427715861259 446 | Epoch 7: Train loss=0.8124373629689217 Eval loss = 0.776776347309351 Rp=0.8164312307396333 447 | Epoch 8: Train loss=0.7808051072061062 Eval loss = 0.7724409475922585 Rp=0.8044285279275872 448 | Epoch 9: Train loss=0.7723440378904343 Eval loss = 0.7691570967435837 Rp=0.8276665115435272 449 | Epoch 10: Train loss=0.7680666074156761 Eval loss = 0.7604397684335709 Rp=0.8172812960763582 450 | Epoch 11: Train loss=0.7637499608099461 Eval loss = 0.7642757333815098 Rp=0.7958800149846623 451 | Epoch 12: Train loss=0.7604391016066074 Eval loss = 0.7545832060277462 Rp=0.7959258052455023 452 | Epoch 13: Train loss=0.7542793937027454 Eval loss = 0.758263424038887 Rp=0.8029712542553213 453 | Epoch 14: Train loss=0.7513296827673912 Eval loss = 0.74464987590909 Rp=0.7928307957450957 454 | Epoch 15: Train loss=0.7609197050333023 Eval loss = 0.7561161443591118 Rp=0.7961394733618599 455 | Epoch 16: Train loss=0.7611901015043259 Eval loss = 0.754481915384531 Rp=0.8042048513261087 456 | Epoch 17: Train loss=0.7494660168886185 Eval loss = 0.7436127960681915 Rp=0.808727372545216 457 | Epoch 18: Train loss=0.7624928876757622 Eval loss = 0.7601931132376194 Rp=0.8278642644450237 458 | Epoch 19: Train loss=0.7445684559643269 Eval loss = 0.7404011972248554 Rp=0.7988691222148593 459 | Epoch 20: Train loss=0.7364756055176258 Eval loss = 0.7331099547445774 Rp=0.7900263164250347 460 | Epoch 21: Train loss=0.7366516776382923 Eval loss = 0.7335694879293442 Rp=0.7858690993104261 461 | Epoch 22: Train loss=0.7310461960732937 Eval loss = 0.7295852825045586 Rp=0.7947732982354974 462 | Epoch 23: Train loss=0.7325067669153214 Eval loss = 0.7321035303175449 Rp=0.7994415259947144 463 | Epoch 24: Train loss=0.7324810847640038 Eval loss = 0.7215580977499485 Rp=0.7793204264401973 464 | Epoch 25: 
Train loss=0.7343184538185596 Eval loss = 0.7716133445501328 Rp=0.8319286623223218 465 | Epoch 26: Train loss=0.7366975098848343 Eval loss = 0.7249130606651306 Rp=0.77453832290699 466 | Epoch 27: Train loss=0.7278863601386547 Eval loss = 0.720306035131216 Rp=0.7780187017935781 467 | Epoch 28: Train loss=0.7243384085595608 Eval loss = 0.715414222329855 Rp=0.7660441378673354 468 | Epoch 29: Train loss=0.7391963303089142 Eval loss = 0.8104664944112301 Rp=0.848384596737658 469 | Epoch 30: Train loss=0.7501446716487408 Eval loss = 0.7330859526991844 Rp=0.7958694240958736 470 | Epoch 31: Train loss=0.7319861426949501 Eval loss = 0.7288344763219357 Rp=0.7764215311271584 471 | Epoch 32: Train loss=0.7289896085858345 Eval loss = 0.7186561860144138 Rp=0.7719541082002876 472 | Epoch 33: Train loss=0.7208635434508324 Eval loss = 0.7130853533744812 Rp=0.7693201078171342 473 | Epoch 34: Train loss=0.7188350297510624 Eval loss = 0.7184220626950264 Rp=0.7790605390063141 474 | Epoch 35: Train loss=0.7278616651892662 Eval loss = 0.7226458676159382 Rp=0.7934061207417414 475 | Epoch 36: Train loss=0.7257222011685371 Eval loss = 0.7461872175335884 Rp=0.8043810986938649 476 | Epoch 37: Train loss=0.722360011190176 Eval loss = 0.7184372805058956 Rp=0.7733680838057659 477 | Epoch 38: Train loss=0.7406770437955856 Eval loss = 0.7328226044774055 Rp=0.7792520400962948 478 | Epoch 39: Train loss=0.7231648564338684 Eval loss = 0.7170744873583317 Rp=0.7745043319879358 479 | Epoch 40: Train loss=0.7179877758026123 Eval loss = 0.7121099643409252 Rp=0.7665604319856079 480 | Epoch 41: Train loss=0.7204379811882973 Eval loss = 0.7137661874294281 Rp=0.7697876723151911 481 | Epoch 42: Train loss=0.7189657613635063 Eval loss = 0.7145831622183323 Rp=0.7683013184528533 482 | Epoch 43: Train loss=0.7236194014549255 Eval loss = 0.7213072367012501 Rp=0.7794286898969346 483 | Epoch 44: Train loss=0.7113666497170925 Eval loss = 0.7064446061849594 Rp=0.7730508504910318 484 | Epoch 45: Train loss=0.7169638313353062 Eval loss = 0.7072659730911255 Rp=0.7818726340368028 485 | Epoch 46: Train loss=0.7137239314615726 Eval loss = 0.7547547481954098 Rp=0.8419246011202801 486 | Epoch 47: Train loss=0.7260657027363777 Eval loss = 0.7107443884015083 Rp=0.7655332919045247 487 | Epoch 48: Train loss=0.7079816907644272 Eval loss = 0.7071417346596718 Rp=0.773454930517809 488 | Epoch 49: Train loss=0.7094167955219746 Eval loss = 0.7058326154947281 Rp=0.7727462602735332 489 | Epoch 50: Train loss=0.7380903214216232 Eval loss = 0.7227592132985592 Rp=0.769177665270142 490 | Epoch 51: Train loss=0.7130068391561508 Eval loss = 0.7105456776916981 Rp=0.7592860561025371 491 | Epoch 52: Train loss=0.7084374688565731 Eval loss = 0.7031594552099705 Rp=0.7660899650703171 492 | Epoch 53: Train loss=0.7042888924479485 Eval loss = 0.7040572166442871 Rp=0.7721988413128251 493 | Epoch 54: Train loss=0.7063969075679779 Eval loss = 0.6986251175403595 Rp=0.7695131487850577 494 | Epoch 55: Train loss=0.7053375691175461 Eval loss = 0.7212032824754715 Rp=0.7735872477697779 495 | Epoch 56: Train loss=0.7035926096141338 Eval loss = 0.7097134478390217 Rp=0.784710594937848 496 | Epoch 57: Train loss=0.735156461596489 Eval loss = 0.8654714487493038 Rp=0.911015246823642 497 | Epoch 58: Train loss=0.8107807412743568 Eval loss = 0.7581384815275669 Rp=0.8157151369086734 498 | Epoch 59: Train loss=0.7544550113379955 Eval loss = 0.7499602921307087 Rp=0.7949385155087615 499 | Epoch 60: Train loss=0.746234655380249 Eval loss = 0.7389007620513439 Rp=0.7862202786182774 500 | Epoch 61: 
Train loss=0.7351461201906204 Eval loss = 0.7316578552126884 Rp=0.7699839364702437 501 | Epoch 62: Train loss=0.7276912368834019 Eval loss = 0.7233806289732456 Rp=0.7752237105789992 502 | Epoch 63: Train loss=0.7192910276353359 Eval loss = 0.7333457358181477 Rp=0.7724518923905164 503 | Epoch 64: Train loss=0.7260234951972961 Eval loss = 0.7129928097128868 Rp=0.7660325583116143 504 | Epoch 65: Train loss=0.7277122773230076 Eval loss = 0.7464463748037815 Rp=0.7795688062338593 505 | Epoch 66: Train loss=0.7259447425603867 Eval loss = 0.7074765078723431 Rp=0.7646423879285662 506 | Epoch 67: Train loss=0.7121074683964252 Eval loss = 0.7206148467957973 Rp=0.7670712925578708 507 | Epoch 68: Train loss=0.7051395028829575 Eval loss = 0.7368365041911602 Rp=0.8174581520496608 508 | Epoch 69: Train loss=0.7655579410493374 Eval loss = 0.7538384310901165 Rp=0.7762011659936345 509 | Epoch 70: Train loss=0.7304071560502052 Eval loss = 0.7248818911612034 Rp=0.791938228314532 510 | Epoch 71: Train loss=0.7145950980484486 Eval loss = 0.7085471898317337 Rp=0.771219627305778 511 | Epoch 72: Train loss=0.705636128783226 Eval loss = 0.7026742734014988 Rp=0.7582436097165333 512 | Epoch 73: Train loss=0.7039311081171036 Eval loss = 0.701056282967329 Rp=0.7621456124101622 513 | Epoch 74: Train loss=0.7022229805588722 Eval loss = 0.7022544406354427 Rp=0.7572908772835294 514 | Epoch 75: Train loss=0.7077537477016449 Eval loss = 0.7068974897265434 Rp=0.7689672806983037 515 | Epoch 76: Train loss=0.70463952049613 Eval loss = 0.7016653120517731 Rp=0.7620379826248179 516 | Epoch 77: Train loss=0.6936824433505535 Eval loss = 0.6882451064884663 Rp=0.7536649577052977 517 | Epoch 78: Train loss=0.7085927426815033 Eval loss = 0.7006802186369896 Rp=0.7573836578425687 518 | Epoch 79: Train loss=0.6964434124529362 Eval loss = 0.6970510184764862 Rp=0.7566034081453576 519 | Epoch 80: Train loss=0.7041287049651146 Eval loss = 0.7041221261024475 Rp=0.7517282641309433 520 | Epoch 81: Train loss=0.7040938474237919 Eval loss = 0.694716889411211 Rp=0.7518976134462692 521 | Epoch 82: Train loss=0.6934744939208031 Eval loss = 0.6876317113637924 Rp=0.7490530317847554 522 | Epoch 83: Train loss=0.6924876533448696 Eval loss = 0.7114526480436325 Rp=0.7712103875056127 523 | Epoch 84: Train loss=0.70367331802845 Eval loss = 0.6974217854440212 Rp=0.7550885913055979 524 | Epoch 85: Train loss=0.7047922983765602 Eval loss = 0.6882399097084999 Rp=0.7489761214479039 525 | Epoch 86: Train loss=0.6913500241935253 Eval loss = 0.6827207766473293 Rp=0.744314955284494 526 | Epoch 87: Train loss=0.6916158571839333 Eval loss = 0.6919064372777939 Rp=0.7509365031084853 527 | Epoch 88: Train loss=0.6971654705703259 Eval loss = 0.6914562620222569 Rp=0.7441346646542655 528 | Epoch 89: Train loss=0.6963543370366096 Eval loss = 0.691758755594492 Rp=0.7572971481054963 529 | Epoch 90: Train loss=0.6917684748768806 Eval loss = 0.6856491975486279 Rp=0.7448740185296933 530 | Epoch 91: Train loss=0.6941814571619034 Eval loss = 0.6949076354503632 Rp=0.7525349231139533 531 | Epoch 92: Train loss=0.69602020829916 Eval loss = 0.7147903628647327 Rp=0.7746622771469637 532 | Epoch 93: Train loss=0.6934036538004875 Eval loss = 0.689013235270977 Rp=0.7674142639047323 533 | Epoch 94: Train loss=0.6828178651630878 Eval loss = 0.6839329451322556 Rp=0.749017388184114 534 | Epoch 95: Train loss=0.6820085123181343 Eval loss = 0.679760005325079 Rp=0.7496224481291669 535 | Epoch 96: Train loss=0.6940162815153599 Eval loss = 0.6897397376596928 Rp=0.7505375270808133 536 | Epoch 97: Train 
loss=0.6879976131021976 Eval loss = 0.7038531377911568 Rp=0.7740689357717414 537 | Epoch 98: Train loss=0.6902556456625462 Eval loss = 0.6736902967095375 Rp=0.7441236415340688 538 | Epoch 99: Train loss=0.6990400142967701 Eval loss = 0.7069363221526146 Rp=0.7699857686313047 539 | 540 | 541 | 542 | ```python 543 | import os 544 | os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" 545 | n_plots = 5 546 | 547 | t0=120 548 | with torch.no_grad(): 549 | LSTM.eval() 550 | for step,(x,y,_) in enumerate(test_dl): 551 | x = x.cuda() 552 | y = y.cuda() 553 | 554 | output = LSTM(x,y) 555 | 556 | 557 | if step > n_plots: 558 | break 559 | 560 | with torch.no_grad(): 561 | plt.figure(figsize=(10,10)) 562 | plt.plot(x[1, 0].cpu().detach().squeeze().numpy(),y[1].cpu().detach().squeeze().numpy(),'g--',linewidth=3) 563 | plt.plot(x[1, 0, Config.t0:].cpu().detach().squeeze().numpy(),output[1,(Config.t0-1):(Config.t0+Config.t1-1),0].cpu().detach().squeeze().numpy(),'b--',linewidth=3) 564 | 565 | plt.xlabel("x",fontsize=20) 566 | plt.legend(["$[0,t_0+24)_{obs}$","$[t_0,t_0+24)_{predicted}$"]) 567 | plt.show() 568 | ``` 569 | 570 | 571 | 572 | ![png](output_15_0.png) 573 | 574 | 575 | 576 | 577 | 578 | ![png](output_15_1.png) 579 | 580 | 581 | 582 | 583 | ```python 584 | matrix = torch.empty(0).cuda() 585 | obsmat = torch.empty(0).cuda() 586 | 587 | with torch.no_grad(): 588 | LSTM.eval() 589 | predictions = [] 590 | observations = [] 591 | for step,(x,y,attention_masks) in enumerate(test_dl): 592 | # if step == 8: 593 | # break 594 | output = LSTM(x.cuda(),y.cuda()) 595 | matrix = torch.cat((matrix, output.cuda())) 596 | obsmat = torch.cat((obsmat, y.cuda())) 597 | 598 | pre = matrix.cpu().detach().numpy() 599 | obs = obsmat.cpu().detach().numpy() 600 | # libraries 601 | import matplotlib.pyplot as plt 602 | import numpy as np 603 | import pandas as pd 604 | 605 | # data 606 | df = pd.DataFrame({ 607 | 'obs': obs[:, 0, Config.t0:Config.t].flatten(), 608 | 'pre': pre[:, Config.t0:Config.t, 0].flatten() 609 | }) 610 | df 611 | ``` 612 | 613 | 614 | 615 | 616 |
|       | obs       | pre       |
|-------|-----------|-----------|
| 0     | 1.195594  | 1.645856  |
| 1     | 1.649769  | 0.247605  |
| 2     | -0.608017 | -0.466307 |
| 3     | -0.471923 | -0.499275 |
| 4     | -1.097827 | -0.693097 |
| ...   | ...       | ...       |
| 11995 | 3.649456  | 0.124768  |
| 11996 | -0.162433 | -0.615067 |
| 11997 | -0.589042 | -0.838524 |
| 11998 | -0.578971 | -0.815273 |
| 11999 | -0.744391 | -0.398436 |

12000 rows × 2 columns
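The cell below scores these observation-prediction pairs. For reference, the RMSE and NSE it reports correspond to the standard definitions given here (a reference formulation, with $o_i$ and $p_i$ the observed and predicted standardized values and $N$ the number of pairs; the code's `NSE = 1 - RMSE**2 / np.var(x)` is algebraically identical):

$$
\mathrm{RMSE}=\sqrt{\frac{1}{N}\sum_{i=1}^{N}\left(o_i-p_i\right)^2},\qquad
\mathrm{NSE}=1-\frac{\sum_{i=1}^{N}\left(o_i-p_i\right)^2}{\sum_{i=1}^{N}\left(o_i-\bar{o}\right)^2}
$$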
698 | 699 | 700 | 701 | 702 | ```python 703 | import numpy as np 704 | import pandas as pd 705 | import matplotlib.pyplot as plt 706 | from scipy import stats 707 | from matplotlib import rcParams 708 | from statistics import mean 709 | from sklearn.metrics import explained_variance_score,r2_score,median_absolute_error,mean_squared_error,mean_absolute_error 710 | from scipy.stats import pearsonr 711 | # 加载数据(PS:原始数据太多,采样10000) 712 | # 默认是读取csv/xlsx的列成DataFrame 713 | 714 | 715 | config = {"font.family":'Times New Roman',"font.size": 16,"mathtext.fontset":'stix'} 716 | #df = df.sample(5000) 717 | # 用于计算指标 718 | x = df['obs']; y = df['pre'] 719 | rcParams.update(config) 720 | BIAS = mean(x - y) 721 | MSE = mean_squared_error(x, y) 722 | RMSE = np.power(MSE, 0.5) 723 | R2 = pearsonr(x, y).statistic 724 | adjR2 = 1-((1-r2_score(x,y))*(len(x)-1))/(len(x)-Config.in_channels-1) 725 | MAE = mean_absolute_error(x, y) 726 | EV = explained_variance_score(x, y) 727 | NSE = 1 - (RMSE ** 2 / np.var(x)) 728 | # 计算散点密度 729 | xy = np.vstack([x, y]) 730 | z = stats.gaussian_kde(xy)(xy) 731 | idx = z.argsort() 732 | x, y, z = x.iloc[idx], y.iloc[idx], z[idx] 733 | 734 | # 拟合(若换MK,自行操作)最小二乘 735 | def slope(xs, ys): 736 | m = (((mean(xs) * mean(ys)) - mean(xs * ys)) / ((mean(xs) * mean(xs)) - mean(xs * xs))) 737 | b = mean(ys) - m * mean(xs) 738 | return m, b 739 | k, b = slope(x, y) 740 | regression_line = [] 741 | for a in x: 742 | regression_line.append((k * a) + b) 743 | 744 | # 绘图,可自行调整颜色等等 745 | import os 746 | os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" 747 | 748 | fig,ax=plt.subplots(figsize=(8,6),dpi=300) 749 | scatter=ax.scatter(x, y, marker='o', c=z*100, edgecolors=None ,s=15, label='LST',cmap='Spectral_r') 750 | cbar=plt.colorbar(scatter,shrink=1,orientation='vertical',extend='both',pad=0.015,aspect=30,label='frequency') 751 | plt.plot([-30,30],[-30,30],'black',lw=1.5) # 画的1:1线,线的颜色为black,线宽为0.8 752 | plt.plot(x,regression_line,'red',lw=1.5) # 预测与实测数据之间的回归线 753 | plt.axis([-30,30,-30,30]) # 设置线的范围 754 | plt.xlabel('OBS',family = 'Times New Roman') 755 | plt.ylabel('PRE',family = 'Times New Roman') 756 | plt.xticks(fontproperties='Times New Roman') 757 | plt.yticks(fontproperties='Times New Roman') 758 | plt.text(-1.8,1.75, '$N=%.f$' % len(y), family = 'Times New Roman') # text的位置需要根据x,y的大小范围进行调整。 759 | plt.text(-1.8,1.50, '$R^2=%.3f$' % R2, family = 'Times New Roman') 760 | plt.text(-1.8,1.25, '$NSE=%.3f$' % NSE, family = 'Times New Roman') 761 | 762 | plt.text(-1.8,1, '$RMSE=%.3f$' % RMSE, family = 'Times New Roman') 763 | plt.xlim(-2,2) # 设置x坐标轴的显示范围 764 | plt.ylim(-2,2) # 设置y坐标轴的显示范围 765 | plt.show() 766 | ``` 767 | 768 | 769 | 770 | ![png](output_17_0.png) 771 | 772 | 773 | 774 | 775 | ```python 776 | 777 | ``` 778 | --------------------------------------------------------------------------------