├── Assets
├── LSTM-master
│ ├── data
│ │ └── .Rhistory
│ ├── .gitattributes
│ ├── output_15_0.png
│ ├── output_15_1.png
│ ├── output_17_0.png
│ ├── output_3_1.png
│ ├── LICENSE
│ ├── .gitignore
│ └── README.md
├── Images
│ └── 研究区域.png
└── Ref
│ ├── sp_ncdf.pro
│ ├── gldas_tws_eg.py
│ └── sp_ncdf_lunwen.pro
├── .gitignore
├── History
└── IDL.rar
├── requirements.txt
├── utils
├── __pycache__
│ ├── models.cpython-38.pyc
│ ├── utils.cpython-38.pyc
│ └── __init__.cpython-38.pyc
├── __init__.py
├── utils.py
└── models.py
├── .idea
├── .gitignore
├── misc.xml
├── inspectionProfiles
│ ├── profiles_settings.xml
│ └── Project_Default.xml
├── modules.xml
└── VEG.iml
├── Core
├── dead_code.py
├── process_Rs.py
├── check_datasets.py
├── process_gldas.py
├── Plot
│ ├── density_box_plot.py
│ ├── line_bar_distribution_plot.py
│ └── plot.ipynb
├── uniform_datasets.py
├── feature_engineering.py
├── model_train_dynamic.py
├── model_train.py
└── process_modis.py
└── README.md
/Assets/LSTM-master/data/.Rhistory:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/Assets/LSTM-master/.gitattributes:
--------------------------------------------------------------------------------
1 | # Auto detect text files and perform LF normalization
2 | * text=auto
3 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 |
2 | Assets/LSTM-master/data/prcp.tif
3 | Assets/LSTM-master/data/prcp.tif.aux.xml
4 | VEG.zip
5 |
--------------------------------------------------------------------------------
/History/IDL.rar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChaoQiezi/CNN-LSTM-model-is-used-to-predict-NDVI/HEAD/History/IDL.rar
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChaoQiezi/CNN-LSTM-model-is-used-to-predict-NDVI/HEAD/requirements.txt
--------------------------------------------------------------------------------
/Assets/Images/研究区域.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChaoQiezi/CNN-LSTM-model-is-used-to-predict-NDVI/HEAD/Assets/Images/研究区域.png
--------------------------------------------------------------------------------
/Assets/LSTM-master/output_15_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChaoQiezi/CNN-LSTM-model-is-used-to-predict-NDVI/HEAD/Assets/LSTM-master/output_15_0.png
--------------------------------------------------------------------------------
/Assets/LSTM-master/output_15_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChaoQiezi/CNN-LSTM-model-is-used-to-predict-NDVI/HEAD/Assets/LSTM-master/output_15_1.png
--------------------------------------------------------------------------------
/Assets/LSTM-master/output_17_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChaoQiezi/CNN-LSTM-model-is-used-to-predict-NDVI/HEAD/Assets/LSTM-master/output_17_0.png
--------------------------------------------------------------------------------
/Assets/LSTM-master/output_3_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChaoQiezi/CNN-LSTM-model-is-used-to-predict-NDVI/HEAD/Assets/LSTM-master/output_3_1.png
--------------------------------------------------------------------------------
/utils/__pycache__/models.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChaoQiezi/CNN-LSTM-model-is-used-to-predict-NDVI/HEAD/utils/__pycache__/models.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChaoQiezi/CNN-LSTM-model-is-used-to-predict-NDVI/HEAD/utils/__pycache__/utils.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChaoQiezi/CNN-LSTM-model-is-used-to-predict-NDVI/HEAD/utils/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # @Author : ChaoQiezi
2 | # @Time : 2024/1/14 19:06
3 | # @Email : chaoqiezi.one@qq.com
4 |
5 | """
6 | This script is used to ...
7 | """
8 |
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # 默认忽略的文件
2 | /shelf/
3 | /workspace.xml
4 | # 基于编辑器的 HTTP 客户端请求
5 | /httpRequests/
6 | # Datasource local storage ignored files
7 | /dataSources/
8 | /dataSources.local.xml
9 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/VEG.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/Project_Default.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
10 |
11 |
12 |
--------------------------------------------------------------------------------
/Core/dead_code.py:
--------------------------------------------------------------------------------
1 | # @Author : ChaoQiezi
2 | # @Time : 2023/12/30 20:35
3 | # @Email : chaoqiezi.one@qq.com
4 |
5 | """
6 | This script is used to ...
7 | """
8 | import os
9 |
# Folder whose file names are normalised
path = r'H:\Datasets\Objects\Veg\LST_Max'

# Rename every file whose name contains the token "Max" so that every
# occurrence of that token becomes "MAX"; other files are left untouched.
for entry_name in os.listdir(path):
    if "Max" not in entry_name:
        continue
    source_path = os.path.join(path, entry_name)
    target_path = os.path.join(path, entry_name.replace("Max", "MAX"))
    os.rename(source_path, target_path)
26 |
--------------------------------------------------------------------------------
/Assets/LSTM-master/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Longhao Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/Assets/Ref/sp_ncdf.pro:
--------------------------------------------------------------------------------
; Read variable SDS_NAME from the netCDF file FILE_NAME and return it as a
; floating-point array unpacked with the variable's own attributes:
;   value = raw * scale_factor + add_offset
function era5_readarr, file_name, sds_name
  nc_id = ncdf_open(file_name)
  var_id = ncdf_varid(nc_id, sds_name)
  ncdf_varget, nc_id, var_id, raw
  ncdf_attget, nc_id, var_id, 'scale_factor', scale
  ncdf_attget, nc_id, var_id, 'add_offset', offset
  ncdf_close, nc_id
  return, float(raw) * scale + offset
end
12 |
; Read variable SDS_NAME from the netCDF file FILE_NAME and return its raw
; values converted to float (no scale_factor / add_offset is applied).
function era5_readarr11, file_name, sds_name
  nc_id = ncdf_open(file_name)
  var_id = ncdf_varid(nc_id, sds_name)
  ncdf_varget, nc_id, var_id, raw
  ncdf_close, nc_id
  return, float(raw)
end
; Convert the GLDAS variable Tveg_tavg of every *.nc4 file under PATH into a
; GeoTIFF (WGS84, EPSG:4326) with the same base name under OUT_PATH.
pro sp_ncdf
  ; Input folder (GLDAS *.nc4) and output folder for the GeoTIFFs
  path = 'F:\new_lunwen\Global GLDAS\'
  out_path = 'F:\new_lunwen\jg\Tveg\'
  if ~file_test(out_path, /directory) then file_mkdir, out_path

  res = 0.25 ; grid spacing in degrees, assumed to be 0.25 on both axes
  file_list = file_search(path, '*.nc4', count = file_n)
  for file_i = 0, file_n - 1 do begin
    data_ciwc = era5_readarr11(file_list[file_i], 'Tveg_tavg')
    ; Keep values strictly inside (-9999, 9999); everything else (e.g. the
    ; -9999 fill value) becomes 0
    data_ciwc = (data_ciwc gt -9999 and data_ciwc lt 9999) * data_ciwc

    lon = era5_readarr11(file_list[file_i], 'lon')
    lon_min = min(lon)
    lat = era5_readarr11(file_list[file_i], 'lat')
    lat_max = max(lat)
    ; ROTATE direction 7 keeps x and reverses y, so the last latitude row becomes
    ; the first; assumes lat is stored ascending (south to north) - TODO confirm
    data_ciwc = rotate(data_ciwc, 7)

    ; GTRASTERTYPEGEOKEY = 1 (PixelIsArea): the tie point is the OUTER upper-left
    ; corner of the first pixel. lon/lat are assumed to be pixel centres (GLDAS
    ; convention - NOTE(review): confirm), so shift outward by half a pixel;
    ; using lon_min/lat_max directly would offset the raster by res/2.
    geo_info = {$
      MODELPIXELSCALETAG: [res, res, 0.0], $
      MODELTIEPOINTTAG: [0.0, 0.0, 0.0, lon_min - res / 2.0, lat_max + res / 2.0, 0.0], $
      GTMODELTYPEGEOKEY: 2, $
      GTRASTERTYPEGEOKEY: 1, $
      GEOGRAPHICTYPEGEOKEY: 4326, $
      GEOGCITATIONGEOKEY: 'GCS_WGS_1984', $
      GEOGANGULARUNITSGEOKEY: 9102}
    ; Write <out_path><input base name>.tif
    write_tiff, out_path + file_basename(file_list[file_i], '.nc4') + '.tif', data_ciwc, /float, geotiff = geo_info
    print, 'down!!'
  endfor
end
--------------------------------------------------------------------------------
/Core/process_Rs.py:
--------------------------------------------------------------------------------
1 | # @Author : ChaoQiezi
2 | # @Time : 2024/5/9 19:25
3 | # @FileName : process_Rs.py
4 | # @Email : chaoqiezi.one@qq.com
5 |
6 | """
7 | This script is used to 处理RS地表太阳辐射并提取经纬度数据集,通过裁剪、掩膜、重采样等处理输出为tiff文件
8 | """
9 |
10 | import os.path
11 | import matplotlib.pyplot as plt
12 | import netCDF4 as nc
13 | import numpy as np
14 | from osgeo import gdal, osr
15 | import pandas as pd
16 |
17 | # 准备
18 | Rs_path = r'H:\Datasets\Objects\Veg\GWRHXG_Rs1.nc'
19 | out_Rs_dir = r'E:\FeaturesTargets\uniform\Rs'
20 | out_dir = r'E:\FeaturesTargets\uniform'
21 | mask_path = r'E:\Basic\Region\sw5f\sw5_mask.shp'
22 | out_res = 0.1 # 度(°)
23 | if not os.path.exists(out_Rs_dir): os.makedirs(out_Rs_dir)
24 | if not os.path.exists(out_dir): os.makedirs(out_dir)
25 |
26 | # 读取
27 | with nc.Dataset(Rs_path) as f:
28 | lon, lat = f['longitude'][:].filled(-9999), f['latitude'][:].filled(-9999)
29 | Rs = f['Rs'][:].filled(-9999)
30 | years, months = f['year'][:].filled(np.nan), f['month'][:].filled(np.nan)
31 | for ix, (year, month) in enumerate(zip(years, months)):
32 | cur_Rs = Rs[ix, :, :] # 当前时间点的Rs地表太阳辐射
33 | lon_min, lon_max, lat_min, lat_max = lon.min(), lon.max(), lat.min(), lat.max()
34 | lon_res = (lon_max - lon_min) / len(lon)
35 | lat_res = (lat_max - lat_min) / len(lat)
36 | geo_transform = [lon_min, lon_res, 0, lat_max, 0, -lat_res]
37 |
38 | # 输出
39 | out_file_name = 'Rs_{:4.0f}{:02.0f}.tiff'.format(year, month)
40 | out_path = os.path.join(out_Rs_dir, out_file_name)
41 | mem_driver = gdal.GetDriverByName('MEM')
42 | mem_ds = mem_driver.Create('', len(lon), len(lat), 1, gdal.GDT_Float32)
43 | srs = osr.SpatialReference()
44 | srs.ImportFromEPSG(4326)
45 | mem_ds.SetProjection(srs.ExportToWkt())
46 | mem_ds.SetGeoTransform(geo_transform)
47 | mem_ds.GetRasterBand(1).WriteArray(cur_Rs)
48 | mem_ds.GetRasterBand(1).SetNoDataValue(-9999) # 设置无效值
49 | out_ds = gdal.Warp(out_path, mem_ds, cropToCutline=True, cutlineDSName=mask_path, xRes=out_res, yRes=out_res,
50 | resampleAlg=gdal.GRA_Cubic, srcNodata=-9999, dstNodata=-9999)
51 | mem_ds.FlushCache()
52 | print('processing: {}'.format(out_file_name))
53 |
54 | masked_geo_transform = out_ds.GetGeoTransform()
55 | rows, cols = out_ds.RasterYSize, out_ds.RasterXSize
56 | lat = np.array([masked_geo_transform[3] + _ix * masked_geo_transform[-1] + masked_geo_transform[-1] / 2 for _ix in range(rows)])
57 | lon = np.array([masked_geo_transform[0] + _ix * masked_geo_transform[1] + masked_geo_transform[1] / 2 for _ix in range(cols)])
58 | lon_2d, lat_2d = np.meshgrid(lon, lat)
59 | driver = gdal.GetDriverByName('GTiff')
60 | lon_ds = driver.Create(os.path.join(out_dir, 'Lon.tiff'), len(lon), len(lat), 1, gdal.GDT_Float32)
61 | lat_ds = driver.Create(os.path.join(out_dir, 'Lat.tiff'), len(lon), len(lat), 1, gdal.GDT_Float32)
62 | srs.ImportFromEPSG(4326)
63 | lon_ds.SetProjection(srs.ExportToWkt())
64 | lon_ds.SetGeoTransform(masked_geo_transform)
65 | lon_ds.GetRasterBand(1).WriteArray(lon_2d)
66 | lon_ds.GetRasterBand(1).SetNoDataValue(-9999) # 设置无效值
67 | lat_ds.SetProjection(srs.ExportToWkt())
68 | lat_ds.SetGeoTransform(masked_geo_transform)
69 | lat_ds.GetRasterBand(1).WriteArray(lat_2d)
70 | lat_ds.GetRasterBand(1).SetNoDataValue(-9999) # 设置无效值
71 | lon_ds.FlushCache()
72 | lat_ds.FlushCache()
73 |
74 | print('Done!')
--------------------------------------------------------------------------------
/Assets/LSTM-master/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
105 | __pypackages__/
106 |
107 | # Celery stuff
108 | celerybeat-schedule
109 | celerybeat.pid
110 |
111 | # SageMath parsed files
112 | *.sage.py
113 |
114 | # Environments
115 | .env
116 | .venv
117 | env/
118 | venv/
119 | ENV/
120 | env.bak/
121 | venv.bak/
122 |
123 | # Spyder project settings
124 | .spyderproject
125 | .spyproject
126 |
127 | # Rope project settings
128 | .ropeproject
129 |
130 | # mkdocs documentation
131 | /site
132 |
133 | # mypy
134 | .mypy_cache/
135 | .dmypy.json
136 | dmypy.json
137 |
138 | # Pyre type checker
139 | .pyre/
140 |
141 | # pytype static type analyzer
142 | .pytype/
143 |
144 | # Cython debug symbols
145 | cython_debug/
146 |
147 | # PyCharm
148 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
149 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
150 | # and can be added to the global gitignore or merged into this file. For a more nuclear
151 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
152 | #.idea/
153 |
--------------------------------------------------------------------------------
/Core/check_datasets.py:
--------------------------------------------------------------------------------
1 | # @Author : ChaoQiezi
2 | # @Time : 2023/12/7 15:07
3 | # @FileName : check_datasets.py
4 | # @Email : chaoqiezi.one@qq.com
5 |
6 | """
7 | This script is used to 用于检查数据完整性, 包括MCD12Q1、MOD11A2、MOD13A2
8 | -·-·-·-·-·-·-·-·-·-·-·-·-·-·-·-·-·-·-·-·-·-·-·-·-·-·-·-·-·-·-·-·-·-·-
9 | 拓展: MYD\MOD\MCD
10 | MOD标识Terra卫星
11 | MYD标识Aqua卫星
12 | MCD标识Terra和Aqua卫星的结合
13 | -·-·-·-·-·-·-·-·-·-·-·-·-·-·-·-·-·-·-·-·-·-·-·-·-·-·-·-·-·-·-·-·-·-·-
14 | 拓展: MCD12Q1\MOD11A2\MOD13A2
15 | MCD12Q1为土地利用数据
16 | MOD11A2为地表温度数据
17 | MOD13A2为植被指数数据(包括NDVI\EVI)
18 | """
19 |
20 | import os.path
21 | import glob
22 | from datetime import datetime, timedelta
23 |
# Preparation: root folder that is searched recursively for the MODIS HDF files
in_dir = r'F:\Cy_modis'
searching_ds_wildcard = ['MCD12Q1', 'MOD11A2', 'MOD13A2']
# MODIS tiles (hXXvYY) that must each have exactly one file per time step
region_wildcard = ['h26v05', 'h26v06', 'h27v05', 'h27v06']


def iter_composite_dates(start_date, end_date, step_days):
    """Yield composite start dates from ``start_date`` to ``end_date`` (inclusive).

    Dates advance by ``step_days``.  When the next step would cross into a new
    calendar year, the sequence restarts on 1 January of that year.  This
    matches the MODIS convention that 8-/16-day composite periods restart at
    day-of-year 001.
    """
    cur_date = start_date
    while cur_date <= end_date:
        yield cur_date
        next_date = cur_date + timedelta(days=step_days)
        if next_date.year != cur_date.year:
            next_date = datetime(cur_date.year + 1, 1, 1)
        cur_date = next_date


def check_tiles(error_txt, ds_name_wildcard, time_tokens, ok_message, root_dir, regions):
    """Check that exactly one HDF file exists per (time step, tile).

    For every token in ``time_tokens`` (a year or a 'YYYYDDD' date) and every
    tile in ``regions``, this recursively globs
    ``<root_dir>/**/<ds_name_wildcard>A<token>*<region>*.hdf``.  Each
    combination that does not match exactly one file is logged as a line in
    ``error_txt``.  If none is abnormal, ``ok_message`` is written instead.

    :return: number of (time step, tile) combinations with an abnormal file count
    """
    n_errors = 0
    with open(error_txt, 'w', encoding='utf-8') as f:
        for token in time_tokens:
            for region in regions:
                cur_ds_name_wildcard = ds_name_wildcard + 'A{}*'.format(token) + region + '*.hdf'
                ds_path_wildcard = os.path.join(root_dir, '**', cur_ds_name_wildcard)
                hdf_paths = glob.glob(ds_path_wildcard, recursive=True)
                if len(hdf_paths) != 1:
                    f.write('{}: 文件数目(为: {})不正常\n'.format(cur_ds_name_wildcard, len(hdf_paths)))
                    n_errors += 1
        # Count errors explicitly: re-reading the file here would always return ''
        # because the write position is already at EOF.
        if n_errors == 0:
            f.write(ok_message)
    return n_errors


def main():
    """Run the file-count checks for MCD12Q1, MOD11A2 and MOD13A2."""
    # MCD12Q1: one file per tile and year, 2001-2020
    check_tiles(os.path.join(in_dir, 'MCD12Q1_check_error.txt'), 'MCD12Q1*',
                range(2001, 2021), 'MCD12Q1数据集文件数正常', in_dir, region_wildcard)

    # MOD11A2: 8-day composites from 2000 DOY 049 to 2022 DOY 297
    mod11_dates = iter_composite_dates(datetime(2000, 1, 1) + timedelta(days=48),
                                       datetime(2022, 1, 1) + timedelta(days=296), 8)
    check_tiles(os.path.join(in_dir, 'MOD11A2_check_error.txt'), 'MOD11A2*',
                (d.strftime('%Y%j') for d in mod11_dates),
                'MOD11A2数据集文件数正常', in_dir, region_wildcard)

    # MOD13A2: 16-day composites from 2000 DOY 049 to 2020 DOY 353
    mod13_dates = iter_composite_dates(datetime(2000, 1, 1) + timedelta(days=48),
                                       datetime(2020, 1, 1) + timedelta(days=352), 16)
    check_tiles(os.path.join(in_dir, 'MOD13A2_check_error.txt'), 'MOD13A2*',
                (d.strftime('%Y%j') for d in mod13_dates),
                'MOD13A2数据集文件数正常', in_dir, region_wildcard)


if __name__ == '__main__':
    main()
88 |
89 |
90 |
91 |
--------------------------------------------------------------------------------
/Assets/Ref/gldas_tws_eg.py:
--------------------------------------------------------------------------------
1 | import os
2 | from netCDF4 import Dataset
3 | import numpy as np
4 |
# ===========================================================================
# Terrestrial water storage change (TWSC) from GLDAS monthly files.
# For every month:
#   TWSC = (P - ET) * seconds_in_month - (Qs + Qsb) * n_3hour_steps_in_month
# P = Rainf_f_tavg and ET = Evap_tavg are multiplied by the seconds in the
# month; Qs = Qs_acc and Qsb = Qsb_acc by the number of 3-hour steps.

GLDAS_ROOT = 'F:/My_Postdoctor/GlobalGLDAS/'
# Month length used when the date cannot be parsed from a file name
# (the original script assumed 30 days for every month).
DEFAULT_DAYS_PER_MONTH = 30
month2seconds = DEFAULT_DAYS_PER_MONTH * 24 * 3600  # seconds in a 30-day month
month2threehours = DEFAULT_DAYS_PER_MONTH * 8  # number of 3-hour steps in a 30-day month


def gldas_year_month(nc_path):
    """Return ``(year, month)`` parsed from a GLDAS monthly file name.

    The date is embedded as ``.AYYYYMM.``, e.g.
    ``GLDAS_NOAH025_M.A200001.021.nc4``.  ``None`` is returned when no such
    token is present.
    """
    import re

    match = re.search(r'\.A(\d{4})(\d{2})\.', os.path.basename(nc_path))
    if match is None:
        return None
    return int(match.group(1)), int(match.group(2))


def month_conversion_factors(year_month=None):
    """Return ``(seconds, three_hour_steps)`` for one calendar month.

    :param year_month: ``(year, month)`` tuple; ``None`` falls back to a
        30-day month.
    """
    import calendar

    if year_month is None:
        n_days = DEFAULT_DAYS_PER_MONTH
    else:
        n_days = calendar.monthrange(*year_month)[1]
    return n_days * 24 * 3600, n_days * 8


def list_gldas_files(root_dir):
    """Return every ``*.nc4`` file below ``root_dir`` sorted by base name.

    For GLDAS file names, sorting by base name orders the files in time
    regardless of which sub-folder they are in.
    """
    paths = [os.path.join(dirname, filename)
             for dirname, _, filenames in os.walk(root_dir)
             for filename in filenames
             if filename.endswith('.nc4')]
    return sorted(paths, key=os.path.basename)


def monthly_twsc(nc_path):
    """Compute the TWSC grid of one GLDAS monthly file.

    GLDAS has no valid values over the oceans and Antarctica (filled with
    -9999); those cells are set to 0.
    """
    with Dataset(nc_path, 'r') as data:
        # [0] drops the leading time axis of length 1, e.g. (1, 600, 1440) -> (600, 1440)
        precipitation_flux = data.variables['Rainf_f_tavg'][0]
        water_evaporation_flux = data.variables['Evap_tavg'][0]
        surface_runoff_amount = data.variables['Qs_acc'][0]
        subsurface_runoff_amount = data.variables['Qsb_acc'][0]

    # Use the real length of this calendar month instead of a fixed 30 days
    seconds, n_steps = month_conversion_factors(gldas_year_month(nc_path))
    twsc = ((precipitation_flux - water_evaporation_flux) * seconds
            - (surface_runoff_amount + subsurface_runoff_amount) * n_steps)
    return np.ma.filled(twsc, 0)


def twsc_anomalies(monthly_twscs):
    """Turn monthly TWSC grids into de-meaned cumulative storage changes.

    The storage change of month k relative to month 0 is the cumulative sum
    of the monthly changes.  The temporal mean of these cumulative values is
    then subtracted from every month.

    :param monthly_twscs: sequence of equally shaped grids, one per month, in time order
    :return: array with the months stacked on axis 0
    :raises ValueError: if no grids are given
    """
    twscs = np.asarray(monthly_twscs)
    if twscs.size == 0:
        raise ValueError('no monthly TWSC grids given')
    twscs_acc = np.cumsum(twscs, axis=0)
    return twscs_acc - np.average(twscs_acc, axis=0)


if __name__ == '__main__':
    filelist = list_gldas_files(GLDAS_ROOT)
    num_files = len(filelist)  # e.g. 201 files of *.nc4 from 2002.4-2018.12
    TWSCs = twsc_anomalies([monthly_twsc(path) for path in filelist])
    # Earlier revisions also saved P, ET, Qs, Qsb and the regional subset
    # [372:468, 900:1072] with np.savetxt; here TWSCs is only kept in memory.
--------------------------------------------------------------------------------
/Assets/Ref/sp_ncdf_lunwen.pro:
--------------------------------------------------------------------------------
; Read variable SDS_NAME from the netCDF file FILE_NAME and return it as a
; floating-point array unpacked with the variable's own attributes:
;   value = raw * scale_factor + add_offset
function era5_readarr, file_name, sds_name
  nc_id = ncdf_open(file_name)
  var_id = ncdf_varid(nc_id, sds_name)
  ncdf_varget, nc_id, var_id, raw
  ncdf_attget, nc_id, var_id, 'scale_factor', scale
  ncdf_attget, nc_id, var_id, 'add_offset', offset
  ncdf_close, nc_id
  return, float(raw) * scale + offset
end
12 |
; Read variable SDS_NAME from the netCDF file FILE_NAME and return its raw
; values converted to float (no scale_factor / add_offset is applied).
function era5_readarr11, file_name, sds_name
  nc_id = ncdf_open(file_name)
  var_id = ncdf_varid(nc_id, sds_name)
  ncdf_varget, nc_id, var_id, raw
  ncdf_close, nc_id
  return, float(raw)
end
; Convert the GLDAS water-balance variables of every *.nc4 file under PATH to
; monthly totals and write one GeoTIFF (WGS84, EPSG:4326) per variable and file:
;   Rainf_f_tavg, Evap_tavg : value * seconds in the month
;   Qs_acc, Qsb_acc         : value * number of 3-hour steps in the month
pro sp_ncdf_lunwen
  ; Input folder and one output folder per variable
  path = 'F:\new_lunwen\Global GLDAS\'

  out_path = 'F:\new_lunwen\Rainf_f_tavg\'
  out_path1 = 'F:\new_lunwen\Evap_tavg_tavg\'
  out_path2 = 'F:\new_lunwen\Qs_acc\'
  out_path3 = 'F:\new_lunwen\Qsb_acc\'
  if ~file_test(out_path, /directory) then file_mkdir, out_path
  if ~file_test(out_path1, /directory) then file_mkdir, out_path1
  if ~file_test(out_path2, /directory) then file_mkdir, out_path2
  if ~file_test(out_path3, /directory) then file_mkdir, out_path3

  days_in_month = [31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31]
  res = 0.25 ; grid spacing in degrees, assumed to be 0.25 on both axes

  file_list = file_search(path, '*.nc4', count = file_n)
  for file_i = 0, file_n - 1 do begin
    ; File names look like GLDAS_NOAH025_M.A200001...: the year starts at
    ; character 17 and the month at character 21 (0-based)
    year = fix(strmid(file_basename(file_list[file_i]), 17, 4))
    month = fix(strmid(file_basename(file_list[file_i]), 21, 2))
    data_ciwc = era5_readarr11(file_list[file_i], 'Rainf_f_tavg')
    data_ciwc1 = era5_readarr11(file_list[file_i], 'Evap_tavg')
    data_ciwc2 = era5_readarr11(file_list[file_i], 'Qs_acc')
    data_ciwc3 = era5_readarr11(file_list[file_i], 'Qsb_acc')

    ; Keep values strictly inside (-9999, 9999); everything else (e.g. the
    ; -9999 fill value) becomes 0
    data_ciwc = (data_ciwc gt -9999 and data_ciwc lt 9999) * data_ciwc
    data_ciwc1 = (data_ciwc1 gt -9999 and data_ciwc1 lt 9999) * data_ciwc1
    data_ciwc2 = (data_ciwc2 gt -9999 and data_ciwc2 lt 9999) * data_ciwc2
    data_ciwc3 = (data_ciwc3 gt -9999 and data_ciwc3 lt 9999) * data_ciwc3

    lon = era5_readarr11(file_list[file_i], 'lon')
    lon_min = min(lon)
    lat = era5_readarr11(file_list[file_i], 'lat')
    lat_max = max(lat)
    ; ROTATE direction 7 keeps x and reverses y, so the last latitude row becomes
    ; the first; assumes lat is stored ascending (south to north) - TODO confirm
    data_ciwc = rotate(data_ciwc, 7)
    data_ciwc1 = rotate(data_ciwc1, 7)
    data_ciwc2 = rotate(data_ciwc2, 7)
    data_ciwc3 = rotate(data_ciwc3, 7)

    ; Days in this month, using the full Gregorian leap-year rule
    n_days = days_in_month[month - 1]
    is_leap = ((year mod 4) eq 0 and (year mod 100) ne 0) or ((year mod 400) eq 0)
    if month eq 2 and is_leap then n_days = 29
    ; Float factors: n_days * 24 * 3600 in IDL's default 16-bit integers would overflow
    seconds = n_days * 86400.0
    n_steps = n_days * 8.0
    data_ciwc = data_ciwc * seconds
    data_ciwc1 = data_ciwc1 * seconds
    data_ciwc2 = data_ciwc2 * n_steps
    data_ciwc3 = data_ciwc3 * n_steps

    ; GTRASTERTYPEGEOKEY = 1 (PixelIsArea): the tie point is the OUTER upper-left
    ; corner of the first pixel. lon/lat are assumed to be pixel centres (GLDAS
    ; convention - NOTE(review): confirm), so shift outward by half a pixel;
    ; using lon_min/lat_max directly would offset the raster by res/2.
    geo_info = {$
      MODELPIXELSCALETAG: [res, res, 0.0], $
      MODELTIEPOINTTAG: [0.0, 0.0, 0.0, lon_min - res / 2.0, lat_max + res / 2.0, 0.0], $
      GTMODELTYPEGEOKEY: 2, $
      GTRASTERTYPEGEOKEY: 1, $
      GEOGRAPHICTYPEGEOKEY: 4326, $
      GEOGCITATIONGEOKEY: 'GCS_WGS_1984', $
      GEOGANGULARUNITSGEOKEY: 9102}
    ; Write <out_pathN><input base name>.tif for every variable
    write_tiff, out_path + file_basename(file_list[file_i], '.nc4') + '.tif', data_ciwc, /float, geotiff = geo_info
    write_tiff, out_path1 + file_basename(file_list[file_i], '.nc4') + '.tif', data_ciwc1, /float, geotiff = geo_info
    write_tiff, out_path2 + file_basename(file_list[file_i], '.nc4') + '.tif', data_ciwc2, /float, geotiff = geo_info
    write_tiff, out_path3 + file_basename(file_list[file_i], '.nc4') + '.tif', data_ciwc3, /float, geotiff = geo_info
    print, 'down!!'
  endfor
end
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | # @Author : ChaoQiezi
2 | # @Time : 2023/12/14 6:33
3 | # @Email : chaoqiezi.one@qq.com
4 |
5 | """
6 | This script is used to 存放常用工具
7 | """
8 |
9 | import h5py
10 | import torch
11 | from torch.utils.data import Dataset
12 |
13 | torch.manual_seed(42) # 固定种子
14 |
15 |
class H5DynamicDatasetDecoder(Dataset):
    """Torch ``Dataset`` over an HDF5 file with dynamic features and targets.

    Both arrays are read fully into memory when the object is created.  The
    shape comments in this module give the layout: ``dynamic_features`` is
    (time, sample, feature) and ``targets`` is (time, sample).  Samples are
    indexed along axis 1.

    :param file_path: path of the HDF5 file
    :param shuffle_feature_ix: if not None, one ``torch.randperm`` permutation
        of the samples is drawn; when ``dynamic`` is also True, dynamic feature
        ``shuffle_feature_ix`` is reordered across samples with it
    :param dynamic: whether the permutation is applied to the dynamic features
    """

    def __init__(self, file_path, shuffle_feature_ix=None, dynamic=True):
        self.file_path = file_path
        self.shuffle_feature_ix = shuffle_feature_ix
        self.dynamic = dynamic

        # Load everything eagerly so the file handle is released right away
        with h5py.File(file_path, mode='r') as h5_file:
            feature_ds = h5_file['dynamic_features']
            self.length = feature_ds.shape[1]
            self.targets = h5_file['targets'][:]
            self.dynamic_features = feature_ds[:]

        if self.shuffle_feature_ix is None:
            return
        # The permutation is drawn even when it is not applied, so the global
        # torch RNG advances identically in both cases
        permutation = torch.randperm(self.length)
        if not self.dynamic:
            return
        column = self.shuffle_feature_ix
        self.dynamic_features[:, :, column] = self.dynamic_features[:, permutation, column]

    def __len__(self):
        """Number of samples (size of axis 1 of ``dynamic_features``)."""
        return self.length

    def __getitem__(self, index):
        """Return ``(features, target)`` float32 tensors for sample ``index``.

        ``features`` has shape (time, feature) and ``target`` has shape (time,).
        """
        sample = self.dynamic_features[:, index, :]
        label = self.targets[:, index]
        features = torch.tensor(sample, dtype=torch.float32)
        target = torch.tensor(label, dtype=torch.float32)
        return features, target
57 |
58 |
class H5DatasetDecoder(Dataset):
    """
    Dataset over an HDF5 file that stores dynamic features, four static features and targets.

    The whole file is loaded into memory on construction. ``dynamic_features`` is indexed as
    (time_step, sample, feature), ``targets`` as (time_step, sample) and each
    ``static_featuresN`` as (sample,).

    :param file_path: path of the HDF5 file
    :param shuffle_feature_ix: index of one feature to permute across samples (permutation
        feature importance); None disables shuffling. With ``dynamic=True`` it indexes the
        dynamic features, otherwise it must be 0-3 and selects ``static_features1..4``.
    :param dynamic: whether ``shuffle_feature_ix`` refers to a dynamic or a static feature
    :raises ValueError: if a static ``shuffle_feature_ix`` is outside 0-3
    """

    # Attribute / HDF5 dataset names of the static features, in index order
    _STATIC_NAMES = ('static_features1', 'static_features2', 'static_features3', 'static_features4')

    def __init__(self, file_path, shuffle_feature_ix=None, dynamic=True):
        self.file_path = file_path
        self.shuffle_feature_ix = shuffle_feature_ix
        self.dynamic = dynamic

        with h5py.File(file_path, mode='r') as h5:
            self.length = h5['static_features1'].shape[0]  # number of samples
            self.targets = h5['targets'][:]
            self.dynamic_features = h5['dynamic_features'][:]
            for name in self._STATIC_NAMES:
                setattr(self, name, h5[name][:])

        if self.shuffle_feature_ix is not None:
            # numpy index array so the numpy feature arrays are indexed with a native index type
            self._shuffle_feature(self.shuffle_feature_ix, torch.randperm(self.length).numpy())

    def _shuffle_feature(self, feature_ix, permutation):
        """
        Permute one feature across samples in place, leaving all other features aligned.

        :param feature_ix: dynamic feature index (``self.dynamic``) or static index 0-3
        :param permutation: integer array of length ``self.length``
        :raises ValueError: if a static ``feature_ix`` is outside 0-3
        """
        if self.dynamic:
            self.dynamic_features[:, :, feature_ix] = self.dynamic_features[:, permutation, feature_ix]
            return
        if not 0 <= feature_ix < len(self._STATIC_NAMES):
            # Previously an unknown index was silently ignored, yielding an unshuffled "baseline"
            raise ValueError('static shuffle_feature_ix must be in 0..{}, got {!r}'.format(
                len(self._STATIC_NAMES) - 1, feature_ix))
        name = self._STATIC_NAMES[feature_ix]
        setattr(self, name, getattr(self, name)[permutation])

    def __len__(self):
        """
        Return the total number of samples.
        :return: sample count
        """
        return self.length

    def __getitem__(self, index):
        """
        Return one sample.

        :param index: sample index
        :return: (dynamic features (time_step, feature), static features (4,), targets (time_step,)),
            all float32 tensors
        """
        dynamic_feature = self.dynamic_features[:, index, :]
        static_feature = tuple(getattr(self, name)[index] for name in self._STATIC_NAMES)
        target = self.targets[:, index]
        return torch.tensor(dynamic_feature, dtype=torch.float32), \
            torch.tensor(static_feature, dtype=torch.float32), torch.tensor(target, dtype=torch.float32)
119 |
120 |
def cal_r2(outputs, targets):
    """
    Compute the coefficient of determination (R^2), averaged over columns.

    R^2 = 1 - SS_res / SS_tot is evaluated independently for every column along dim 0
    (e.g. every predicted time step) and the column scores are averaged, matching
    scikit-learn's ``r2_score(..., multioutput='uniform_average')``. The previous version
    returned the mean Pearson correlation r, which is not R^2 despite the name.

    :param outputs: predictions, same shape as ``targets``
    :param targets: ground-truth values
    :return: 0-dim tensor with the averaged R^2
    """
    ss_res = torch.sum((targets - outputs) ** 2, dim=0)
    ss_tot = torch.sum((targets - torch.mean(targets, dim=0, keepdim=True)) ** 2, dim=0)
    # Constant target column: R^2 is 1 for a perfect fit and 0 otherwise (sklearn convention)
    constant = ss_tot == 0
    safe_tot = torch.where(constant, torch.ones_like(ss_tot), ss_tot)
    r2 = torch.where(constant, (ss_res == 0).to(ss_res.dtype), 1 - ss_res / safe_tot)
    return torch.mean(r2)
137 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # 2023-12-17更新处理过程记录
2 | 处理过程包含两部分代码:
3 |
4 | - check_datasets.py
5 | - aver_cal.py
6 |
7 | ## 01 检查数据集完整性
8 |
9 | `check_datasets.py`用于检查数据集的完整性, 包括`MCD12Q1`为土地利用数据、 `MOD11A2`为地表温度数据、 `MOD13A2`为植被指数数据。
10 |
11 | 1. MCD12Q1数据集(土地利用|每年)检查至2001年~2021年:
12 |
13 | 正常
14 |
15 | 2. MOD11A2数据集(地表温度|8天周期)检查至2000年第48日~2022年第296日:
16 |
17 | MOD11A2*A2001169*h26v05*.hdf: 文件数目(为: 0)不正常
18 | MOD11A2*A2001169*h26v06*.hdf: 文件数目(为: 0)不正常
19 | MOD11A2*A2001169*h27v05*.hdf: 文件数目(为: 0)不正常
20 | MOD11A2*A2001169*h27v06*.hdf: 文件数目(为: 0)不正常
21 | MOD11A2*A2001177*h26v05*.hdf: 文件数目(为: 0)不正常
22 | MOD11A2*A2001177*h26v06*.hdf: 文件数目(为: 0)不正常
23 | MOD11A2*A2001177*h27v05*.hdf: 文件数目(为: 0)不正常
24 | MOD11A2*A2001177*h27v06*.hdf: 文件数目(为: 0)不正常
25 | MOD11A2*A2010121*h26v05*.hdf: 文件数目(为: 0)不正常(已下载)
26 | MOD11A2*A2010121*h26v06*.hdf: 文件数目(为: 0)不正常(已下载)
27 | MOD11A2*A2010121*h27v05*.hdf: 文件数目(为: 0)不正常(已下载)
28 | MOD11A2*A2010121*h27v06*.hdf: 文件数目(为: 0)不正常(已下载)
29 | MOD11A2*A2022289*h26v05*.hdf: 文件数目(为: 0)不正常
30 | MOD11A2*A2022289*h26v06*.hdf: 文件数目(为: 0)不正常
31 | MOD11A2*A2022289*h27v05*.hdf: 文件数目(为: 0)不正常
32 | MOD11A2*A2022289*h27v06*.hdf: 文件数目(为: 0)不正常
33 | 其余未标注`已下载`的数据集为官网缺失,目前未解决。
34 |
35 | 3. MOD13A2数据集(NDVI|16天周期)检查至2000年第48天~2020年第352天:
36 |
37 | 正常
38 |
39 | ## 02 对三大数据集进行全流程预处理
40 |
41 | `aver_cal.py`主要对三个数据集(土地利用数据集、NDVI数据集、地表温度数据集)进行镶嵌、重投影并最终输出为GeoTIFF文件。
42 |
43 | 1. MCD12Q1(土地利用)数据集具体包括读取LC_Type1数据集(IGBP分类标准)、无效值去除(无效值设定为255)、镶嵌(Last模式, 年尺度)、
44 | 重投影、(sinu ==> WGS84, 重采样为最近邻<因为土地利用数据类型为整型>), 输出分辨率为0.045°(500m), 无效值为255.
45 | 2. MOD11A2(地表温度)数据集具体包括读取LST_Day_1km、无效值去除(无效值设定为-65535)、单位换算(最终单位为摄氏度)、
46 | 镶嵌(MAX模式, 月尺度)、重投影(sinu ==> WGS84, 重采样为三次卷积), 输出分辨率为0.009°(1000m), 无效值为-65535.
47 | 3. MOD13A2(NDVI)数据集具体包括读取1 km 16 days NDVI、无效值去除(无效值设定为-65535)、单位换算、镶嵌(MAX模式, 月尺度)、
48 | 重投影(sinu ==> WGS84, 重采样为三次卷积), 输出分辨率为0.009°(1000m), 无效值为-65535.
49 | 4. MOD13A2(EVI)数据集具体包括读取1 km 16 days EVI、无效值去除(无效值设定为-65535)、单位换算、镶嵌(MAX模式, 月尺度)、
50 | 重投影(sinu ==> WGS84, 重采样为三次卷积), 输出分辨率为0.009°(1000m), 无效值为-65535.
51 |
52 | # 2024年1月18日更新处理过程记录
53 | 处理主要包括:
54 |
55 | - 文件名更改(将aver_cal.py更改为process_modis.py)
56 | - 编写process_gldas.py代码
57 |
58 | ## process_gldas.py
59 | 主要包括对gldas数据集(nc文件)中的`Rainf_f_tavg`(降水通量), `Evap_tavg`(蒸散发通量), `Qs_acc`(表面径流量), `Qsb_acc`(地下径流量)
60 | 进行月累加值的计算,分别得到月降水量、月蒸散发量、月表面径流量、月地下径流量,并依据`TWSC` = 降水量 - 蒸散发量 - (表面径流量 + 地下径流量)计算TWSC。
61 | 需要注意,此时栅格矩阵的范围为-180~180, -60~90.另外进行了无效值的去除(设置为nan)、南北极颠倒、重采样以及重采样后范围偏移的限定。
62 | 最后输出为tiff文件,WGS84坐标系.
63 |
64 | # 2024年1月19日更新处理过程记录
65 |
66 | 处理包括:
67 | - uniform_datasets.py
68 | - 修复process_gldas.py Bug
69 |
70 | ## uniform_datasets.py
71 | 主要是进行各个数据集的统一,统一包括空间范围的限定,研究区域范围如下:
72 | 
73 | 具体是进行掩膜和裁剪至掩膜形状、以及重采样0.1°
74 | 处理的数据集包括:
75 | - Landuse
76 | - LST
77 | - NDVI
78 | - 降水(PRCP)
79 | - 蒸散发量(ET)
80 | - 地表径流量(Qs)
81 | - 地下径流量(Qsb)
82 | - TWSC
83 |
84 | ## 修复process_gldas.py Bug
85 |
86 | 主要是在gldas数据集的处理中,将所有缺失值均赋值为np.nan,虽然可以正确写入无效值。但是在uniform_datasets.py
87 | 的处理中存在无法正确识别无效值np.nan的情况,因此出现了非研究区域的像元值为0.0而非无效值nan。
88 | 因此在gldas数据集的处理中将所有无效值设置为-65535,这一操作则与前期处理NDVI、LST、Landuse的操作一致,并没有
89 | 直接赋值np.nan,这也就是为什么在uniform_datasets.py中这三个数据集没有出现异常情况的原因。
90 |
91 | # 2024年1月22日更新处理过程记录
92 |
93 | 处理包括:
94 | - feature_engineering.py
95 | - model_train.py
96 | - utils.models.py
97 | - 安装环境requirements.txt
98 |
99 | ---
100 | ## feature_engineering.py
101 | 主要是进行各个特征项和目标项数据集的模型预输入的处理,方便后续的模型的数据加载和训练,
102 | 涉及的数据集包括:
103 | Landuse: 2001 - 2020
104 | LST: 200002 - 202210
105 | NDVI: 200002 - 202010
106 | ET: 200204 - 202309
107 | PRCP: 200204 - 202309
108 | Qs: 200204 - 202309
109 | Qsb: 200204 - 202309
110 | TWSC: 200204 - 202309
111 | dem: single
112 | 由于数据集的某些特殊例如landuse的时间分辨率(年)与其它数据集(月)不同,dem所有训练样本都固定不随时间变化,仅体现地理
113 | 位置上的差异,因此对于landuse和dem进行单独的存储(后续或许会根据需求进行改写, 因为这种存储方式可能不太方便数据加载, 数据加载出来
114 | 后到可用于模型训练还需要进行一定的数据变换,这需要一定的时间和成本)
115 |
116 | 处理仅仅将各个数据集整理如下HDF5文件格式:
117 | - group(2002)
118 | - features1
119 |
120 | _存储月分辨率的数据集(shape为(行数 * 列数<即样本数>, 时间步, 特征数)), 特征项具体为: LST、PRCP、ET、Qs、Qsb、TWSC_
121 | - features2
122 |
123 | _存储年分辨率的数据集(shape为(行数 * 列数)), 特征项具体为: Landuse_
124 | - targets
125 |
126 | _存储月分辨率的数据集(shape为(行数 * 列数, 时间步)), 目标项具体为: NDVI_
127 | - group(2003)
128 | - features1
129 | - features2
130 | - targets
131 |
132 | - ······
133 | - dem
134 |
135 | _存储DEM的数据集, shape为(行数 * 列数)_
136 |
137 | ---
138 | ## model_train.py
139 |
140 | 主要是模型的构建、训练以及评估(处于优化中, 后续将依据任务完善框架)
141 | 模型采用CNN-LSTM模型, CNN为一维卷积且在时间维度上卷积, LSTM为常规模型.
142 | 目前完成某年部分样本的训练、评估。
143 | 这里暂未完全定型,只算草稿版本,后续将完善。
144 |
145 | ---
146 | ## utils.models.py
147 |
148 | 这里存储定义的模型,目前定义的编码解码lstm模型实际效果不如前面的cnn-lstm模型,暂时搁置。
149 |
150 | ---
151 | ## requirements.txt
152 |
153 | 考虑到后续项目的维护和迁移,这里增加环境配置的相关信息,在项目根目录下执行以下命令安装依赖:
154 |
155 | ```shell
156 | pip install -r requirements.txt
157 | ```
158 |
159 | # 2024年02月29日处理记录
160 |
161 | ## 增加NDVI、LST的MEAN、MIN处理
162 |
163 | 修改process_modis.py相关参数,进行NDVI、LST的MEAN、MIN计算
164 |
165 | ## uniform_datasets.py, feature_engineering.py调整
166 |
167 | # 2024年05月09日处理
168 |
169 | 新增process_Rs.py文件, 用于处理Rs地表太阳辐射数据, 主要是各个月份的影像nc转tiff,顺便做了一下掩膜、裁剪和重采样
170 | 并输出了lon和lat数据集为tiff文件, 方便后续作为变量输入到模型中
171 |
172 | # 2024/5/11处理
173 | 完善feature_engineering.py文件, 新增关于Rs(dynamic)、lon(static)、lat(static)的特征输入
--------------------------------------------------------------------------------
/Core/process_gldas.py:
--------------------------------------------------------------------------------
1 | # @Author : ChaoQiezi
2 | # @Time : 2024/1/17 12:41
3 | # @FileName : process_gldas.py
4 | # @Email : chaoqiezi.one@qq.com
5 |
6 | """
7 | This script is used to 预处理global gldas数据集
8 |
9 | 说明:
10 | 为确保简洁性和便捷性, 今后读取HDF5文件和NC文件均使用xarray模块而非h5py和NetCDF4模块
11 | 数据集介绍:
12 | TWSC = 降水量(PRCP) - 蒸散发量(ET) - 径流量(即表面径流量Qs + 地下径流量Qsb) ==> 给定时间间隔内, 例如月
13 | 在gldas数据集中:
14 | Rainf_f_tavg表示降水通量,即单位时间单位面积上的降水量(本数据集单位为kg/m2/s)
15 | Evap_tavg表示蒸散发通量,即单位时间单位面积上的水蒸发量(本数据集单位为kg/m2/s)
16 | Qs_acc表示表面径流量,即一定时间内通过地表流动进入河流、湖泊和水库的水量(本数据集单位为kg/m2)
17 | Qsb_acc表示地下径流量,即一定时间内通过土壤层流动、最终进入河流的水量(本数据集单位为kg/m2)
18 | TWSC计算了由降水和蒸发引起的净水量变化,再减去地表和地下径流,其评估给定时间段内区域水资源变化的重要指标
19 |
20 | 存疑:
21 | 01 对于Qs和Qsb的计算, 由于数据集单位未包含/s, 是否已经是月累加值? --2024/01/18(已解决)
22 | ==> 由gldas_tws_eg.py知是: numbers of 3 hours in a month,
23 | 另外nc文件全局属性也提及:
24 | :tavg_definision: = "past 3-hour average";
25 | :acc_definision: = "past 3-hour accumulation";
26 |
27 | """
28 |
29 | import os.path
30 | from glob import glob
31 | from calendar import monthrange
32 | from datetime import datetime
33 |
34 | import numpy as np
35 | import xarray as xr
36 | from osgeo import gdal, osr
37 |
# 准备 (configuration)
in_dir = r'E:\Global GLDAS'  # searched recursively (all sub-folders) for GLDAS monthly files
out_dir = r'E:\FeaturesTargets\non_uniform'
target_names = ['Rainf_f_tavg', 'Evap_tavg', 'Qs_acc', 'Qsb_acc']  # variables read from each nc file
out_names = ['PRCP', 'ET', 'Qs', 'Qsb', 'TWSC']  # output products; the first four pair with target_names
out_res = 0.1  # output resolution in degrees (native GLDAS resolution is 0.25°)
no_data_value = -65535.0  # value written for missing / invalid pixels
# 预准备: one output sub-folder per product
for _name in out_names:
    os.makedirs(os.path.join(out_dir, _name), exist_ok=True)


def parse_month(nc_path):
    """
    Parse the month from a GLDAS file name such as ``GLDAS_NOAH025_M.A200204.021.nc4``.

    Only the base name is split, so dots in parent directory names cannot break parsing.

    :param nc_path: path of the nc4 file
    :return: datetime of the first day of that month
    """
    return datetime.strptime(os.path.basename(nc_path).split('.')[1], 'A%Y%m')


def build_geo_transform(ds):
    """
    Build the north-up GDAL geotransform of a GLDAS grid.

    The lon/lat coordinate values are pixel centres, so the outer edges lie half a pixel
    beyond the extreme coordinates.

    :param ds: opened GLDAS dataset with ``lon``/``lat`` coordinates and ``DX``/``DY`` attributes
    :return: [ul_lon, lon_res, 0, ul_lat, 0, -lat_res] in GDAL order
    """
    lon_res = float(ds.attrs['DX'])
    lat_res = float(ds.attrs['DY'])
    lon_min = float(ds['lon'].values.min()) - lon_res / 2.0
    lat_max = float(ds['lat'].values.max()) + lat_res / 2.0
    return [lon_min, lon_res, 0, lat_max, 0, -lat_res]


def read_monthly_fluxes(ds, month_days):
    """
    Convert the GLDAS variables of one monthly file into monthly totals and derive TWSC.

    Values outside each variable's ``vmin``/``vmax`` attribute range become NaN and the
    arrays are flipped vertically so that north is up. ``*_tavg`` variables (past 3-hour
    average rates, kg/m2/s) are multiplied by the seconds in the month; ``*_acc`` variables
    (past 3-hour accumulations, kg/m2) by the number of 3-hour periods (8 per day).

    :param ds: opened GLDAS monthly dataset
    :param month_days: number of days in the file's month
    :return: dict keyed by 'PRCP', 'ET', 'Qs', 'Qsb', 'TWSC' of 2-D arrays (NaN = invalid)
    """
    fluxs = {}
    for target_name, out_name in zip(target_names, out_names):  # zip stops after the four inputs
        variable = ds[target_name]
        flux = np.array(variable.values, dtype=np.float64)  # private, writable copy
        flux[(flux < variable.attrs['vmin']) | (flux > variable.attrs['vmax'])] = np.nan
        flux = np.flipud(np.squeeze(flux))  # drop the time axis; put north at the top
        if target_name.endswith('acc'):
            flux = flux * (month_days * 8)
        elif target_name.endswith('tavg'):
            flux = flux * (month_days * 24 * 3600)
        fluxs[out_name] = flux
    # TWSC = precipitation - evapotranspiration - (surface runoff + subsurface runoff)
    fluxs['TWSC'] = fluxs['PRCP'] - fluxs['ET'] - (fluxs['Qs'] + fluxs['Qsb'])
    return fluxs


def clamp_to_valid_range(arr, vmin, vmax, nodata):
    """
    Clip the valid pixels of ``arr`` into [vmin, vmax] in place; nodata pixels are left untouched.

    :param arr: array to clamp (modified in place)
    :param vmin: lower bound
    :param vmax: upper bound
    :param nodata: value marking invalid pixels
    :return: ``arr``
    """
    valid = arr != nodata
    arr[valid] = np.clip(arr[valid], vmin, vmax)
    return arr


def write_resampled_tiff(flux, geo_transform, srs_wkt, out_path):
    """
    Resample one product to ``out_res`` (cubic) and write it as a float32 GeoTIFF.

    Cubic resampling can overshoot (e.g. negative precipitation), so resampled valid pixels
    are clipped to the value range of the valid input pixels. Nodata pixels keep ``no_data_value``.

    :param flux: 2-D array, NaN for invalid pixels
    :param geo_transform: GDAL geotransform of ``flux``
    :param srs_wkt: WKT of the spatial reference
    :param out_path: output GeoTIFF path
    """
    # The range must come from valid pixels only: after filling NaN with no_data_value,
    # the minimum would always be no_data_value and the lower clamp would never act.
    valid_values = flux[~np.isnan(flux)]

    src = gdal.GetDriverByName('MEM').Create('', flux.shape[1], flux.shape[0], 1, gdal.GDT_Float32)
    src.SetProjection(srs_wkt)
    src.SetGeoTransform(geo_transform)
    src.GetRasterBand(1).WriteArray(np.nan_to_num(flux, nan=no_data_value))
    src.GetRasterBand(1).SetNoDataValue(no_data_value)
    # Resample in memory so that out_path is written exactly once
    resampled = gdal.Warp('', src, format='MEM', xRes=out_res, yRes=out_res,
                          srcNodata=no_data_value, dstNodata=no_data_value, resampleAlg=gdal.GRA_Cubic)
    out = resampled.GetRasterBand(1).ReadAsArray()
    out_wkt = resampled.GetProjection()
    out_transform = resampled.GetGeoTransform()
    src, resampled = None, None  # release the in-memory datasets

    if valid_values.size:
        clamp_to_valid_range(out, valid_values.min(), valid_values.max(), no_data_value)

    dst = gdal.GetDriverByName('GTiff').Create(out_path, out.shape[1], out.shape[0], 1, gdal.GDT_Float32)
    dst.SetProjection(out_wkt)
    dst.SetGeoTransform(out_transform)
    dst.GetRasterBand(1).WriteArray(out)
    dst.GetRasterBand(1).SetNoDataValue(no_data_value)
    dst.FlushCache()
    dst = None  # close the file


# 检索和循环 (find every monthly file and process it)
nc_paths = glob(os.path.join(in_dir, '**', 'GLDAS_NOAH025_M*.nc4'), recursive=True)
for nc_path in nc_paths:
    cur_time = parse_month(nc_path)  # e.g. 200204
    _, cur_month_days = monthrange(cur_time.year, cur_time.month)

    with xr.open_dataset(nc_path) as ds:  # closed as soon as the arrays are in memory
        geo_transform = build_geo_transform(ds)
        fluxs = read_monthly_fluxes(ds, cur_month_days)

    srs = osr.SpatialReference()
    srs.ImportFromEPSG(4326)  # WGS84
    srs_wkt = srs.ExportToWkt()

    for out_name, flux in fluxs.items():
        cur_out_name = 'GLDAS_{}_{:04}{:02}.tiff'.format(out_name, cur_time.year, cur_time.month)
        write_resampled_tiff(flux, geo_transform, srs_wkt, os.path.join(out_dir, out_name, cur_out_name))
        print('当前处理: {}-{}'.format(out_name, cur_time.strftime('%Y%m')))

print('处理完成')
--------------------------------------------------------------------------------
/Core/Plot/density_box_plot.py:
--------------------------------------------------------------------------------
1 | # @Author : ChaoQiezi
2 | # @Time : 2024/3/11 18:58
3 | # @Email : chaoqiezi.one@qq.com
4 |
5 | """
6 | This script is used to 绘制箱线图和核密度散点图
7 | """
8 |
9 | import os
10 | import numpy as np
11 | import pandas as pd
12 | import matplotlib.pyplot as plt
13 | from scipy.stats import gaussian_kde
14 | import seaborn as sns
15 | from osgeo import gdal
16 | from matplotlib.colors import LinearSegmentedColormap
17 |
# 准备 (configuration)
in_path = r'H:\Datasets\Objects\Veg\Plot\cor_by_st.csv'
dem_path = r'H:\Datasets\Objects\Veg\DEM\dem_1km.tif'
out_dir = r'H:\Datasets\Objects\Veg\Plot'
sns.set_style('darkgrid')  # seaborn theme
plt.rcParams['font.sans-serif'] = ['Times New Roman']
plt.rcParams['axes.unicode_minus'] = False  # render the minus sign correctly


def sample_raster_at(lons, lats, raster, geo_transform, nodata=None):
    """
    Look up raster values at lon/lat points (nearest pixel, north-up raster without rotation).

    :param lons: point longitudes
    :param lats: point latitudes
    :param raster: 2-D array of band values indexed [row, col]
    :param geo_transform: GDAL geotransform (ul_lon, lon_res, 0, ul_lat, 0, -lat_res)
    :param nodata: raster nodata value, or None
    :return: float array; NaN for points outside the raster or on nodata pixels
    """
    lon_ul, lon_res, _, lat_ul, _, lat_res_negative = geo_transform
    lons = np.asarray(lons, dtype=float)
    lats = np.asarray(lats, dtype=float)
    cols = np.floor((lons - lon_ul) / lon_res).astype(int)
    rows = np.floor((lat_ul - lats) / -lat_res_negative).astype(int)
    # Negative indices would silently wrap around, so off-raster points are masked explicitly
    inside = (rows >= 0) & (rows < raster.shape[0]) & (cols >= 0) & (cols < raster.shape[1])
    values = np.full(rows.shape, np.nan)
    values[inside] = raster[rows[inside], cols[inside]]
    if nodata is not None:
        values[values == nodata] = np.nan
    return values


# 加载数据 (load the station table and the DEM)
df = pd.read_csv(in_path)
dem = gdal.Open(dem_path)
dem_band = dem.GetRasterBand(1)
dem_raster = dem_band.ReadAsArray()  # DEM pixel matrix
dem_nodata_value = dem_band.GetNoDataValue()
dem_transform = dem.GetGeoTransform()  # [ul lon, lon res, rotation, ul lat, rotation, -lat res]
dem_band, dem = None, None  # close the DEM; its pixels are already in memory
# Drop the TWSC columns and use the TWSC_SH columns as TWSC instead
df.drop(['TWSC', 'TWSC_1', 'TWSC_2', 'TWSC_3'], axis=1, inplace=True)
df.rename(columns={'TWSC_SH': 'TWSC', 'TWSC_SH_1': 'TWSC_1', 'TWSC_SH_2': 'TWSC_2', 'TWSC_SH_3': 'TWSC_3'}, inplace=True)
# NOTE(review): assumes the first four CSV columns are station metadata (incl. Lon/Lat) — confirm
iter_columns_name = df.columns[4:]
# 色带 (colormap; after reversing, low values are blue and high values red)
colors = ['#ff0000', '#ff6f00', '#fbb700', '#cdff00', '#a1ff6e', '#52ffc7', '#00ffff', '#15acff', '#4261ff', '#3100fe']
colors.reverse()
cm = LinearSegmentedColormap.from_list('common', colors, 100)

# 添加DEM列: only the DEM cell becomes NaN for nodata/off-raster stations; blanking the whole
# row would also discard those stations' valid R^2 values from the box plots below
df['DEM'] = sample_raster_at(df['Lon'], df['Lat'], dem_raster, dem_transform, dem_nodata_value)

# 绘制散点核密度图 (DEM vs R^2 scatter, coloured by density)
for column_name in iter_columns_name:
    fig = plt.figure(dpi=200)
    cur_ds = df[['DEM', column_name]].dropna(how='any').copy()  # copy: a column is added below
    # NOTE(review): this is a 1-D KDE of the R^2 values only, not of the (DEM, R^2) pairs
    cur_ds['Density'] = gaussian_kde(cur_ds[column_name])(cur_ds[column_name])

    scatter = plt.scatter(x='DEM', y=column_name, c='Density', cmap=cm, linewidth=0, data=cur_ds, s=20)
    clb = plt.colorbar(scatter)
    clb.ax.set_title('Density', fontsize=8)  # colorbar title
    sns.kdeplot(x='DEM', y=column_name, fill=False, color='gray', data=cur_ds, alpha=0.6)
    title_name = 'Scatter kernel density map of $R^2$ \n between NDVI and {} under DEM'.format(column_name)
    plt.title(title_name, fontsize=16)
    plt.xlabel('DEM(m)', fontsize=14)
    plt.ylabel('$R^2$ between NDVI and {}'.format(column_name), fontsize=14)
    plt.xticks(fontsize=12)
    plt.yticks(fontsize=12)
    # Start both axes at 0
    plt.xlim(left=0)
    plt.ylim(bottom=0)
    plt.savefig(os.path.join(out_dir, 'R2_{}.png'.format(column_name)), dpi=200)
    plt.close(fig)  # do not accumulate one open figure per column
    print('处理: {}'.format(column_name))

# 绘制箱线图 (box plots, 8 variables per panel, means marked with a white dot)
meanprops = {"marker": "o", "markerfacecolor": "white", "markeredgecolor": "black", "markersize": "10"}
fig, axs = plt.subplots(4, 1, figsize=(13, 18), dpi=432)
axs = axs.flatten()
fig.suptitle('Box plot of NDVI and correlation coefficients of each variable', fontsize=30, va='top')
for ix, ax in enumerate(axs):
    df_melt = pd.melt(df, value_vars=iter_columns_name[(ix * 8):((ix + 1) * 8)]).dropna(how='any')
    sns.boxplot(data=df_melt, x='variable', y='value', palette=cm(np.linspace(0, 1, 9)), ax=ax, linewidth=3,
                showmeans=True, meanprops=meanprops)
    ax.set_xlabel('', fontsize=25)
    ax.set_ylabel('$R^2$', fontsize=25)
    ax.tick_params(axis='x', labelsize=18)
    ax.tick_params(axis='y', labelsize=18)
    ax.grid(True)
fig.tight_layout(pad=2)
fig.savefig(os.path.join(out_dir, 'Box_R2.png'))
plt.close(fig)

# 用于看 (quick-look version: 9 variables per panel, max/min/mean annotated)
fig, axs = plt.subplots(4, 1, figsize=(13, 18), dpi=432)
axs = axs.flatten()
fig.suptitle('Box plot of NDVI and correlation coefficients of each variable', fontsize=30, va='top')
for ix, ax in enumerate(axs):
    columns_slice = iter_columns_name[(ix * 9):((ix + 1) * 9)]
    df_melt = pd.melt(df, value_vars=columns_slice).dropna(how='any')
    sns.boxplot(data=df_melt, x='variable', y='value', palette=cm(np.linspace(0, 1, 9)), ax=ax, linewidth=3)

    # Annotate max, min and mean of every variable at its box position
    for i, column in enumerate(columns_slice):
        subset = df[column].dropna()
        max_val = subset.max()
        min_val = subset.min()
        mean_val = subset.mean()
        ax.text(i, max_val, f'{max_val:.2f}', ha='center', va='bottom', fontsize=16, rotation=45)
        ax.text(i, min_val, f'{min_val:.2f}', ha='center', va='top', fontsize=16, rotation=45)
        ax.text(i, mean_val, f'{mean_val:.2f}', ha='center', va='center', fontsize=16, color='white', rotation=45)

    ax.set_xlabel('', fontsize=25)
    ax.set_ylabel('$R^2$', fontsize=25)
    ax.tick_params(axis='x', labelsize=18)
    ax.tick_params(axis='y', labelsize=18)
    ax.grid(True)
fig.tight_layout(pad=2)
fig.savefig(os.path.join(out_dir, 'Box_R2_quick.png'))
plt.close(fig)
--------------------------------------------------------------------------------
/utils/models.py:
--------------------------------------------------------------------------------
1 | # @Author : ChaoQiezi
2 | # @Time : 2024/1/15 16:01
3 | # @Email : chaoqiezi.one@qq.com
4 |
5 | """
6 | This script is used to 定义模型"""
7 |
8 | import torch
9 | from torch import nn
10 | import torch.nn.functional as F
11 | from torchsummary import summary
12 |
13 |
class Encoder(nn.Module):
    """
    Batch-first LSTM encoder that embeds every step with a ReLU linear layer and returns the final LSTM states.

    :param input_size: features per time step
    :param embedding_size: width of the linear embedding fed to the LSTM
    :param hidden_size: LSTM hidden size
    :param lstm_layers: number of stacked LSTM layers
    :param dropout: dropout probability (on the embedding and between LSTM layers)
    """

    def __init__(self,
                 input_size=6,
                 embedding_size=128,
                 hidden_size=256,
                 lstm_layers=3,
                 dropout=0.5):
        super().__init__()
        # Registration order (fc, rnn, dropout) is kept so parameter initialisation is unchanged
        self.fc = nn.Linear(input_size, embedding_size)
        self.rnn = nn.LSTM(embedding_size, hidden_size, lstm_layers, batch_first=True, dropout=dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        embedded = F.relu(self.fc(x))
        _, states = self.rnn(self.dropout(embedded))
        hidden, cell = states
        return hidden, cell
31 |
32 |
class Decoder(nn.Module):
    """
    Sequence-first LSTM decoder that predicts from the last step of its input sequence.

    :param input_size: features per input step
    :param embedding_size: width of the linear embedding fed to the LSTM
    :param hidden_size: LSTM hidden size
    :param output_size: width of the prediction head
    :param lstm_layers: number of stacked LSTM layers
    :param dropout: dropout probability (on the embedding and between LSTM layers)
    """

    def __init__(self,
                 input_size=6,
                 embedding_size=128,
                 hidden_size=256,
                 output_size=1,
                 lstm_layers=3,
                 dropout=0.5):
        super().__init__()
        # Registration order (embedding, rnn, fc, dropout) is kept so initialisation is unchanged
        self.embedding = nn.Linear(input_size, embedding_size)
        self.rnn = nn.LSTM(embedding_size, hidden_size, lstm_layers, dropout=dropout)
        self.fc = nn.Linear(hidden_size, output_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, hidden, cell):
        # The LSTM is not batch_first, so x is (seq, batch, feature), e.g. (1, batch, input_size)
        features = self.dropout(F.relu(self.embedding(x)))
        rnn_out, (next_hidden, next_cell) = self.rnn(features, (hidden, cell))
        last_step = rnn_out[-1]
        return self.fc(last_step).squeeze(), next_hidden, next_cell
55 |
56 |
class Seq2Seq(nn.Module):
    """
    Encoder-decoder LSTM that predicts one value per month for 12 months.

    :param encoder: module mapping x -> (hidden, cell)
    :param decoder: module mapping (step_input, hidden, cell) -> (prediction, hidden, cell)
    :param device: kept for backward compatibility; outputs are allocated on the input's device
    """

    def __init__(self, encoder, decoder, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(self, x, teacher_forcing_ratio=0.5):
        """
        :param x: batch-first input indexed as (batch, time_step, feature); the first 12 steps are decoded
        :param teacher_forcing_ratio: currently unused (teacher forcing is not implemented)
        :return: tensor of shape (batch, 12)
        """
        time_step = 12  # number of predicted steps (12 months)
        hidden, cell = self.encoder(x)

        # Allocate on x's device/dtype: a fixed self.device breaks when the input lives elsewhere
        outputs = x.new_zeros((x.shape[0], time_step))
        for time_ix in range(time_step):
            # Step time_ix as a length-1, sequence-first input for the decoder: (1, batch, feature)
            decoder_input = torch.transpose(x[:, time_ix:time_ix + 1, :], 0, 1)
            output, hidden, cell = self.decoder(decoder_input, hidden, cell)
            outputs[:, time_ix] = output

        return outputs
83 |
# 创建编码解码的lstm模型 (demo instance of the currently shelved encoder-decoder LSTM)
encoder = Encoder(6, 128, 256)
decoder = Decoder(6, 128, 256)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = Seq2Seq(encoder, decoder, device)

if __name__ == '__main__':
    # Only print summaries when run directly, not on every import. torchsummary builds its dummy
    # input on `device` (default 'cuda'), so the weights must be moved there as well; with CPU
    # weights this crashed on CUDA machines.
    model.to(device)
    summary(encoder, (12, 6), device=device)
    summary(model, (12, 6), device=device)
91 |
92 |
class LSTMModel(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, output_size, static_size=4):
        """
        Conv1d + LSTM over the dynamic features, merged with an embedding of the static features.

        :param input_size: number of dynamic features per time step
        :param hidden_size: LSTM hidden size
        :param num_layers: number of LSTM layers
        :param output_size: length of the output series (default use: 12, one per month)
        :param static_size: number of static features (2024/5/11: 4, incl. Lon/Lat)
        """
        super().__init__()
        self.causal_conv1d = nn.Conv1d(input_size, 128, 5)
        self.fc1 = nn.Linear(static_size, 128)
        self.rnn = nn.LSTM(128, hidden_size, num_layers, batch_first=True)
        # Head input = last LSTM state + 128-wide static embedding (was hard-coded to 384,
        # which only worked for hidden_size == 256)
        self.fc2 = nn.Linear(hidden_size + 128, output_size)

    def forward(self, dynamic_x, static_x):
        """
        :param dynamic_x: batch-first dynamic input (batch, time_step, input_size)
        :param static_x: static input (batch, static_size)
        :return: (batch, output_size)
        """
        # Conv1d expects (batch, channel, time); left-pad 2 steps, so with kernel 5 the
        # convolved sequence is 2 steps shorter than the input
        conv1d_out = self.causal_conv1d(F.pad(torch.transpose(dynamic_x, 1, 2), (2, 0)))
        lstm_out, _ = self.rnn(torch.transpose(conv1d_out, 1, 2))
        lstm_out = lstm_out[:, -1, :]  # only the last time step: (batch, hidden_size)
        static_out = self.fc1(static_x)  # (batch, static_size) -> (batch, 128)
        merged_out = torch.cat([lstm_out, static_out], dim=1)  # (batch, hidden_size + 128)
        return self.fc2(merged_out)  # (batch, output_size)
123 |
124 |
class LSTMModelDynamic(nn.Module):
    """
    Conv1d + LSTM regressor that uses the dynamic (time-varying) features only.

    :param input_size: number of dynamic features per time step
    :param hidden_size: LSTM hidden size
    :param num_layers: number of stacked LSTM layers
    :param output_size: length of the predicted series (default use: 12 months)
    """

    def __init__(self, input_size, hidden_size, num_layers, output_size):
        super().__init__()
        self.causal_conv1d = nn.Conv1d(input_size, 128, 5)
        self.rnn = nn.LSTM(128, hidden_size, num_layers, batch_first=True)
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, dynamic_x):
        # (batch, time, feature) -> (batch, feature, time), left-padded by 2 steps;
        # with kernel 5 the convolved sequence is 2 steps shorter than the input
        padded = F.pad(dynamic_x.transpose(1, 2), (2, 0))
        conv_seq = self.causal_conv1d(padded).transpose(1, 2)  # back to (batch, time', 128)
        rnn_seq, _ = self.rnn(conv_seq)
        return self.fc2(rnn_seq[:, -1])  # predict from the last LSTM step only
--------------------------------------------------------------------------------
/Core/uniform_datasets.py:
--------------------------------------------------------------------------------
1 | # @Author : ChaoQiezi
2 | # @Time : 2024/1/3 16:51
3 | # @Email : chaoqiezi.one@qq.com
4 |
5 | """
6 | This script is used to 对各个数据集进行统一,例如空间范围和空间分辨率
7 |
8 | 主要包括: 对modis(土地利用、ndvi、地表温度)、geo(DEM等)、gldas数据集进行重采样, 范围限定(裁剪至掩膜形状)
9 | """
10 |
11 | import os.path
12 | from glob import glob
13 | from concurrent.futures import ThreadPoolExecutor # 线程池
14 |
15 | from osgeo import gdal
16 |
# 准备 (configuration)
in_dir = r'E:\FeaturesTargets\non_uniform'  # inputs, one sub-folder per dataset
out_dir = r'E:\FeaturesTargets\uniform'  # outputs mirror the input sub-folder layout
shp_path = r'E:\Basic\Region\sw5f\sw5_mask.shp'  # study-area polygon used as the cutline
dem_path = r'E:\GEO\cndem01.tif'  # source DEM
out_res = 0.1  # output pixel size passed to gdal.Warp as xRes/yRes
23 |
24 |
def resample_clip_mask(in_dir: str, out_dir: str, shp_path: str, wildcard: str, out_res: float = 0.1,
                       resampleAlg=gdal.GRA_Cubic):
    """
    Batch resample every raster in a folder, then mask and crop it to a shapefile.

    :param in_dir: folder holding the rasters to process
    :param out_dir: output folder (created if missing); file names are kept
    :param shp_path: shapefile used as cutline for masking and cropping
    :param wildcard: glob pattern selecting the input files inside ``in_dir``
    :param out_res: output pixel size (xRes and yRes)
    :param resampleAlg: GDAL resampling algorithm
    :return: None
    :raises RuntimeError: if gdal.Warp fails for a file
    """
    os.makedirs(out_dir, exist_ok=True)

    for target_path in glob(os.path.join(in_dir, wildcard)):
        out_path = os.path.join(out_dir, os.path.basename(target_path))

        img = gdal.Warp(
            out_path,  # output location
            target_path,  # source file
            cutlineDSName=shp_path,  # mask / crop geometry
            cropToCutline=True,  # crop the extent to the cutline
            xRes=out_res,
            yRes=out_res,
            resampleAlg=resampleAlg
        )
        if img is None:  # without gdal.UseExceptions() a failed warp just returns None
            raise RuntimeError('gdal.Warp failed for {}'.format(target_path))
        img = None  # close the dataset so it is flushed to disk

        print('目前已处理: {}'.format(os.path.splitext(os.path.basename(target_path))[0]))
56 |
57 |
58 | # # 处理土地利用数据集
59 | # in_landuse_dir = os.path.join(in_dir, 'Landuse')
60 | # out_landuse_dir = os.path.join(out_dir, 'Landuse')
61 | # resample_clip_mask(in_landuse_dir, out_landuse_dir, shp_path, 'Landuse*.tiff', resampleAlg=gdal.GRA_NearestNeighbour)
62 | # # 处理地表温度数据集
63 | # in_lst_dir = os.path.join(in_dir, 'LST')
64 | # out_lst_dir = os.path.join(out_dir, 'LST')
65 | # resample_clip_mask(in_lst_dir, out_lst_dir, shp_path, 'LST*.tiff')
66 | # # 处理NDVI数据集
67 | # in_ndvi_dir = os.path.join(in_dir, 'NDVI')
68 | # out_ndvi_dir = os.path.join(out_dir, 'NDVI')
69 | # resample_clip_mask(in_ndvi_dir, out_ndvi_dir, shp_path, 'NDVI*.tiff')
70 | # # 处理ET(蒸散发量)数据集
71 | # in_et_dir = os.path.join(in_dir, 'ET')
72 | # out_et_dir = os.path.join(out_dir, 'ET')
73 | # resample_clip_mask(in_et_dir, out_et_dir, shp_path, 'GLDAS_ET*.tiff')
74 | # # 处理降水数据集
75 | # in_prcp_dir = os.path.join(in_dir, 'PRCP')
76 | # out_prcp_dir = os.path.join(out_dir, 'PRCP')
77 | # resample_clip_mask(in_prcp_dir, out_prcp_dir, shp_path, 'GLDAS_PRCP*.tiff')
78 | # # 处理Qs(表面径流量)数据集
79 | # in_qs_dir = os.path.join(in_dir, 'Qs')
80 | # out_qs_dir = os.path.join(out_dir, 'Qs')
81 | # resample_clip_mask(in_qs_dir, out_qs_dir, shp_path, 'GLDAS_Qs*.tiff')
82 | # # 处理Qsb(地下径流量)数据集
83 | # in_qsb_dir = os.path.join(in_dir, 'Qsb')
84 | # out_qsb_dir = os.path.join(out_dir, 'Qsb')
85 | # resample_clip_mask(in_qsb_dir, out_qsb_dir, shp_path, 'GLDAS_Qsb*.tiff')
86 | # # 处理TWSC数据集
87 | # in_twsc_dir = os.path.join(in_dir, 'TWSC')
88 | # out_twsc_dir = os.path.join(out_dir, 'TWSC')
89 | # resample_clip_mask(in_twsc_dir, out_twsc_dir, shp_path, 'GLDAS_TWSC*.tiff')
90 | # 处理DEM数据集
91 | # out_dem_path = os.path.join(out_dir, 'dem.tiff')
92 | # img = gdal.Warp(
93 | # out_dem_path,
94 | # dem_path,
95 | # cutlineDSName=shp_path,
96 | # cropToCutline=True,
97 | # xRes=out_res,
98 | # yRes=out_res,
99 | # resampleAlg=gdal.GRA_Cubic
100 | # )
101 | # img = None
102 |
# 并行处理(加快处理速度): dataset folder name -> glob pattern of its files (processed concurrently below)
datasets_param = dict(
    Landuse='Landuse*.tiff',
    LST_MEAN='LST_MEAN*.tiff',
    LST_MAX='LST_MAX*.tiff',
    LST_MIN='LST_MIN*.tiff',
    NDVI_MEAN='NDVI_MEAN*.tiff',
    NDVI_MAX='NDVI_MAX*.tiff',
    NDVI_MIN='NDVI_MIN*.tiff',
    ET='GLDAS_ET*.tiff',
    PRCP='GLDAS_PRCP*.tiff',
    Qs='GLDAS_Qs*.tiff',
    Qsb='GLDAS_Qsb*.tiff',
    TWSC='GLDAS_TWSC*.tiff',
)
119 |
if __name__ == '__main__':
    # The DEM is written straight into out_dir, which may not exist yet when its task starts
    os.makedirs(out_dir, exist_ok=True)
    out_dem_path = os.path.join(out_dir, 'dem.tiff')

    def _warp_dem():
        """Resample, mask and crop the DEM, then close the result so it is flushed to disk."""
        dem_ds = gdal.Warp(out_dem_path, dem_path, cutlineDSName=shp_path, cropToCutline=True,
                           xRes=out_res, yRes=out_res, resampleAlg=gdal.GRA_Cubic)
        if dem_ds is None:
            raise RuntimeError('gdal.Warp failed for {}'.format(dem_path))
        dem_ds.FlushCache()
        dem_ds = None

    with ThreadPoolExecutor() as executor:
        futures = []
        for dataset_name, wildcard in datasets_param.items():
            in_dataset_dir = os.path.join(in_dir, dataset_name)
            out_dataset_dir = os.path.join(out_dir, dataset_name)
            # Land use is categorical, so it must not be interpolated
            resampleAlg = gdal.GRA_NearestNeighbour if dataset_name == 'Landuse' else gdal.GRA_Cubic
            futures.append(executor.submit(resample_clip_mask, in_dataset_dir, out_dataset_dir, shp_path,
                                           wildcard, resampleAlg=resampleAlg))
        # 处理DEM
        futures.append(executor.submit(_warp_dem))
        # 等待所有数据集处理完成; result() re-raises any exception raised in a worker
        for future in futures:
            future.result()
136 |
# Legacy per-dataset processing code (LST and NDVI), superseded by resample_clip_mask();
# it is kept inside a string literal for reference only and is never executed.
"""
下述代码比较冗余, 简化为resample_clip_mask函数
----------------------------------------------------------------------
# 处理地表温度数据
lst_paths = glob(os.path.join(lst_dir, 'LST*.tiff'))
out_lst_dir = os.path.join(out_dir, lst_dir.split('\\')[-1])
if not os.path.exists(out_lst_dir): os.makedirs(out_lst_dir)
for lst_path in lst_paths:
out_path = os.path.join(out_lst_dir, os.path.basename(lst_path))

# 重采样、掩膜和裁剪
gdal.Warp(
out_path,
lst_path,
xRes=out_res,
yRes=out_res,
cutlineDSName=shp_path, # 设置掩膜 shp文件
cropToCutline=True, # 裁剪至掩膜形状
resampleAlg=gdal.GRA_Cubic # 重采样方法: 三次卷积
)
print('目前已处理: {}'.format(os.path.splitext(os.path.basename(lst_path))[0]))

# 处理ndvi数据集
ndvi_paths = glob(os.path.join(ndvi_dir, 'NDVI*.tiff'))
out_ndvi_dir = os.path.join(out_dir, ndvi_dir.split('\\')[-1])
if not os.path.exists(out_ndvi_dir): os.makedirs(out_ndvi_dir)
for ndvi_path in ndvi_paths:
out_path = os.path.join(out_ndvi_dir, os.path.basename(ndvi_path))
out_path = os.path.join(out_ndvi_dir, 'NDVI_temp.tiff')
gdal.Warp(
out_path,
ndvi_path,
cutlineDSName=shp_path, # 设置掩膜 shp文件
cropToCutline=True, # 是否裁剪至掩膜形状
xRes=out_res,
yRes=out_res,
resampleAlg=gdal.GRA_Cubic # 重采样方法: 三次卷积
)
"""
177 |
--------------------------------------------------------------------------------
/Core/feature_engineering.py:
--------------------------------------------------------------------------------
1 | # @Author : ChaoQiezi
2 | # @Time : 2024/1/19 3:12
3 | # @Email : chaoqiezi.one@qq.com
4 |
5 | """
6 | This script is used to 包括数据集的整合以支持输入到模型中训练,以及特征工程
7 |
8 | 各个数据集的时间范围:
9 |
10 | Landuse: 2001 - 2020
11 | LST(MEAN/MIN/MAX): 200002 - 202210
12 | NDVI(MEAN/MIN/MAX): 200002 - 202010
13 | ET: 200204 - 202309
14 | PRCP: 200204 - 202309
15 | Qs: 200204 - 202309
16 | Qsb: 200204 - 202309
17 | TWSC: 200204 - 202309
18 | dem: single
19 |
20 | 输出的nc文件的数据格式:
21 | - group(year)
22 | - features1 -> (None, time_step, features_count) , eg. (184, 139, 12 or other, 6)
23 | 7: LST, PRCP, ET, Qs, Qsb, TWSC
24 | - features2 -> (None, ), Landuse, (184 * 139)
25 | - targets-> (Noner, time_step), NDVI, (184 * 139, 12)
26 | - features3 -> dem
27 |
28 | 2024/5/11 新增关于Rs地表太阳辐射和经纬度数据集的添加
29 | 由于Rs时间范围为1983-2017年6月份, 因此此处公共部分的使用日期缩短到2016年12月份.
30 |
31 | 2024/5/11 仅使用LST,PRCP,ET,RS四个变量进行特征构建
32 | """
33 |
34 | from datetime import datetime
35 | import os
36 | import re
37 | from glob import glob
38 |
39 | import netCDF4 as nc
40 | import numpy as np
41 | from osgeo import gdal
42 | import h5py
43 | import torch
44 | from sklearn.preprocessing import MinMaxScaler, StandardScaler, scale
45 |
46 |
def read_img(img_path):
    """
    Read band 1 of a raster file as a float32 array, with NoData set to NaN.

    :param img_path: Path of the raster file (opened with GDAL).
    :return: float32 numpy array holding band 1; pixels equal to the band's
             NoData value are replaced by NaN.
    """
    # Keep the dataset bound to a name while the band is in use: a GDAL band
    # must not outlive its parent dataset.
    dataset = gdal.Open(img_path)
    first_band = dataset.GetRasterBand(1)
    nodata = first_band.GetNoDataValue()
    values = np.float32(first_band.ReadAsArray())
    values[values == nodata] = np.nan
    return values
59 |
60 |
# Preparation: directory of the uniformed rasters, the intermediate HDF5 file
# and the static rasters.
in_dir = r'E:\FeaturesTargets\uniform'
h5_path = r'E:\FeaturesTargets\features_targets.h5'
dem_path = r'E:\FeaturesTargets\uniform\dem.tiff'
slope_path = r'E:\FeaturesTargets\uniform\slope.tif'
# Longitude raster. This previously pointed at Lat.tiff, so the "lon" feature
# silently duplicated latitude.
lon_path = r'E:\FeaturesTargets\uniform\Lon.tiff'
lat_path = r'E:\FeaturesTargets\uniform\Lat.tiff'
# Common time window of all inputs (Rs ends in 2017-06, see module docstring).
start_date = datetime(2003, 1, 1)
end_date = datetime(2016, 12, 1)
# Dynamic features: sub-folder of `in_dir` -> file-name prefix inside it.
features1_params = {
    'LST_MAX': 'LST_MAX_',
    # 'LST_MIN': 'LST_MIN_',
    # 'LST_MEAN': 'LST_MEAN_',
    'PRCP': 'GLDAS_PRCP_',
    'ET': 'GLDAS_ET_',
    # 'Qs': 'GLDAS_Qs_',
    # 'Qsb': 'GLDAS_Qsb_',
    # 'TWSC': 'GLDAS_TWSC_',
    'Rs': 'Rs_',
}
rows = 132
cols = 193
features1_size = len(features1_params)

# Read and write the features. Each year gets one HDF5 group holding
# features1 (months, pixels, features) and targets (months, pixels).
with h5py.File(h5_path, mode='w') as h5:
    for year in range(start_date.year, end_date.year + 1):
        start_month = start_date.month if year == start_date.year else 1
        end_month = end_date.month if year == end_date.year else 12

        features1 = []  # one (pixels, features) array per month
        targets = []  # one (pixels,) NDVI_MAX array per month
        cur_group = h5.create_group(str(year))
        for month in range(start_month, end_month + 1):
            # Read every dynamic feature of the current month.
            cur_features = np.empty((rows, cols, features1_size))
            for ix, (parent_folder_name, feature_wildcard) in enumerate(features1_params.items()):
                cur_in_dir = os.path.join(in_dir, parent_folder_name)
                # Accept both "<prefix>YYYYMM.tiff" and "<prefix>YYYY_MM.tiff".
                pattern = re.compile(re.escape(feature_wildcard) + r'{:04}_?{:02}\.tiff'.format(year, month))
                feature_paths = [_path for _path in os.listdir(cur_in_dir) if pattern.match(_path)]
                if len(feature_paths) != 1:
                    raise NameError('文件名错误, 文件不存在或者指定文件存在多个')
                feature_path = os.path.join(cur_in_dir, feature_paths[0])
                cur_features[:, :, ix] = read_img(feature_path)
            features1.append(cur_features.reshape(-1, features1_size))
            # Read the target (NDVI_MAX) of the current month.
            ndvi_paths = glob(os.path.join(in_dir, 'NDVI_MAX', 'NDVI_MAX_{:04}_{:02}.tiff'.format(year, month)))
            if len(ndvi_paths) != 1:
                raise NameError('文件名错误, 文件不存在或者指定文件存在多个')
            targets.append(read_img(ndvi_paths[0]).reshape(-1))
        features1 = np.array(features1)
        targets = np.array(targets)

        # Land use is no longer used as a static feature; slope replaces it.
        cur_group['features1'] = features1
        cur_group['targets'] = targets
        print('目前已处理: {}'.format(year))

    # Static features: one value per pixel, shared by all years.
    h5['dem'] = read_img(dem_path).reshape(-1)
    h5['slope'] = read_img(slope_path).reshape(-1)
    h5['lon'] = read_img(lon_path).reshape(-1)
    h5['lat'] = read_img(lat_path).reshape(-1)
    if np.isnan(h5['lon'][:]).any() or np.isnan(h5['lat'][:]).any():
        raise RuntimeWarning("Lon/Lat 存在无效值!")
137 |
# Merge all years into one flat sample set. Pixels of every year are stacked
# along the sample axis, keeping only pixels that are valid in every input.
with h5py.File(h5_path, mode='a') as h5:
    year_dem = h5['dem'][:]
    year_slope = h5['slope'][:]
    year_lon = h5['lon'][:]
    year_lat = h5['lat'][:]
    # Per-year pieces are collected and concatenated once at the end instead
    # of re-copying the growing arrays every year.
    features1_parts, targets_parts = [], []
    slope_parts, dem_parts, lon_parts, lat_parts = [], [], [], []
    for year in range(start_date.year, end_date.year + 1):
        # Index by the loop year. Hard-coding '2003' here once reused a
        # single year's data for every year.
        year_features1 = h5['{}/features1'.format(year)][:]
        year_targets = h5['{}/targets'.format(year)][:]

        # Keep a pixel only if all its monthly features, all its monthly
        # targets, its slope and its DEM are non-NaN.
        mask = (np.all(~np.isnan(year_features1), axis=(0, 2))
                & ~np.isnan(year_slope)
                & np.all(~np.isnan(year_targets), axis=0)
                & ~np.isnan(year_dem))
        h5['{}/mask'.format(year)] = mask

        features1_parts.append(year_features1[:, mask, :])
        targets_parts.append(year_targets[:, mask])
        slope_parts.append(year_slope[mask])
        dem_parts.append(year_dem[mask])
        lon_parts.append(year_lon[mask])
        lat_parts.append(year_lat[mask])  # previously appended onto `dem` by mistake

features1 = np.concatenate(features1_parts, axis=1)  # (time_steps, samples, features)
targets = np.concatenate(targets_parts, axis=1)  # (time_steps, samples)
slope = np.concatenate(slope_parts, axis=0)
dem = np.concatenate(dem_parts, axis=0)
lon = np.concatenate(lon_parts, axis=0)
lat = np.concatenate(lat_parts, axis=0)

# Standardise each time step's features across samples, then each static
# feature.
# NOTE(review): statistics are fitted on train + eval samples together
# before the split below; confirm this is acceptable.
scaler = StandardScaler()
for step in range(features1.shape[0]):
    features1[step, :, :] = scaler.fit_transform(features1[step, :, :])
dem = scaler.fit_transform(dem.reshape(-1, 1)).ravel()
slope = scaler.fit_transform(slope.reshape(-1, 1)).ravel()
lon = scaler.fit_transform(lon.reshape(-1, 1)).ravel()
lat = scaler.fit_transform(lat.reshape(-1, 1)).ravel()

# Sequential 80/20 split. Samples are stacked year by year, so the evaluation
# part holds the pixels of the latest years.
sample_size = dem.shape[0]
train_amount = int(sample_size * 0.8)
splits = (
    (r'E:\FeaturesTargets\train.h5', slice(None, train_amount)),
    (r'E:\FeaturesTargets\eval.h5', slice(train_amount, None)),
)
for out_path, part in splits:
    with h5py.File(out_path, mode='w') as h5:
        h5.create_dataset('dynamic_features', data=features1[:, part, :])
        h5.create_dataset('static_features1', data=slope[part])  # slope
        h5.create_dataset('static_features2', data=dem[part])  # DEM
        h5.create_dataset('static_features3', data=lon[part])  # longitude
        h5.create_dataset('static_features4', data=lat[part])  # latitude
        h5.create_dataset('targets', data=targets[:, part])
197 |
--------------------------------------------------------------------------------
/Core/Plot/line_bar_distribution_plot.py:
--------------------------------------------------------------------------------
1 | # @Author : ChaoQiezi
2 | # @Time : 2024/3/29 9:47
3 | # @FileName : line_bar_distribution_plot.py
4 | # @Email : chaoqiezi.one@qq.com
5 |
6 | """
7 | This script is used to 绘制折线图、柱状图、插值分布图
8 |
9 | EWTC: 包括人类活动导致的(地表水地下使用加工运输到别的地方等), 自然变化的(蒸腾蒸发降水等)引起的储水量变化
10 | TWSC: 指单独自然变化导致的储水量变化
11 | AWC: (EWTC - TWSC即可得到)人类活动导致引发的储水量变化
12 | """
13 |
14 | import glob
15 | import os.path
16 | import pandas as pd
17 | import numpy as np
18 | import matplotlib.pyplot as plt
19 | import seaborn as sns
20 | from osgeo import gdal
21 |
# Preparation: input tables, the uniform raster directory and the output dir.
data_dir = r'H:\Datasets\Objects\Veg\LXB_plot\Data'
ewtc_path = data_dir + r'\sw_EWTC_T.csv'
twsc_path = data_dir + r'\sw_TWSC_SH_T.csv'
ndvi_path = data_dir + r'\sw_NDVI_T.csv'
in_img_dir = r'E:\FeaturesTargets\uniform'
out_dir = r'H:\Datasets\Objects\Veg\LXB_plot'
# The seaborn style must be applied before the font override: set_style
# rewrites the sans-serif font list.
sns.set_style('darkgrid')
plt.rcParams['font.sans-serif'] = ['Times New Roman']  # Times New Roman font

# Load the tables.
ewtc, twsc, ndvi = (pd.read_csv(path) for path in (ewtc_path, twsc_path, ndvi_path))
# AWC (human-induced storage change) = EWTC - TWSC_SH.
# NOTE(review): the merge key is ProvinceNa only, so every EWTC row is paired
# with every TWSC row of the same province. Year/Month/Lon/Lat receive
# _ewtc/_twsc suffixes instead of being matched. Confirm this is intended.
awc = ewtc.merge(twsc, left_on='ProvinceNa', right_on='ProvinceNa', suffixes=('_ewtc', '_twsc'))
awc['AWC'] = awc['EWTC'].sub(awc['TWSC_SH'])
38 |
# # Year-month means (ewtc twsc ndvi) -- earlier variant, disabled
# var_str = ['EWTC', 'TWSC_SH', 'AWC']
# for ix, var in enumerate([ewtc, twsc, awc]):
#     var_name = var_str[ix]
#     if var_name == 'AWC':
#         var_monthly = var[['Year_ewtc', 'Month_ewtc', var_name]].groupby(['Year_ewtc', 'Month_ewtc']).mean()
#         var_monthly.to_csv(os.path.join(out_dir, '{}.csv'.format(var_name)))
#     else:
#         var_monthly = var[['Year', 'Month', var_name]].groupby(['Year', 'Month']).mean()
#         var_monthly['Date'] = var_monthly.reset_index().apply(lambda x: '{:04.0f}/{:02.0f}'.format(x.iloc[0], x.iloc[1]), axis=1).values
#         var_monthly.to_csv(os.path.join(out_dir, '{}.csv'.format(var_name)))
# # Monthly (calendar-month) means -- earlier variant, disabled
# var_str = ['EWTC', 'TWSC_SH', 'AWC']
# for ix, var in enumerate([ewtc, twsc, awc]):
#     var_name = var_str[ix]
#     if var_name == 'AWC':
#         var_monthly = var[['Month_ewtc', var_name]].groupby(['Month_ewtc']).mean()
#         # var_monthly.to_csv(os.path.join(out_dir, '{}_monthly.csv'.format(var_name)))
#     else:
#         var_monthly = var[['Month', var_name]].groupby(['Month']).mean()
#         var_monthly['Date'] = var_monthly.index
#         # var_monthly['Date'] = var_monthly.reset_index().apply(lambda x: '{:04.0f}/{:02.0f}'.format(x.iloc[0], x.iloc[1]), axis=1).values
#         var_monthly.to_csv(os.path.join(out_dir, '{}_monthly.csv'.format(var_name)), index=False)
# Yearly means of EWTC, TWSC_SH and AWC, one <name>_yearly.csv per variable.
# NOTE(review): the indentation of this loop was reconstructed from a dump that
# lost it. The 'Date'/to_csv lines are assumed to belong to the else branch,
# mirroring the variants above, which means the AWC yearly means are computed
# but never written. Dedent those two lines if AWC should be exported too.
var_str = ['EWTC', 'TWSC_SH', 'AWC']
for ix, var in enumerate([ewtc, twsc, awc]):
    var_name = var_str[ix]
    if var_name == 'AWC':
        # AWC comes from the merged table, so its year column has the _ewtc suffix.
        var_yearly = var[['Year_ewtc', var_name]].groupby(['Year_ewtc']).mean()
        # var_monthly.to_csv(os.path.join(out_dir, '{}_monthly.csv'.format(var_name)))
    else:
        var_yearly = var[['Year', var_name]].groupby(['Year']).mean()
        var_yearly['Date'] = var_yearly.index
        # var_monthly['Date'] = var_monthly.reset_index().apply(lambda x: '{:04.0f}/{:02.0f}'.format(x.iloc[0], x.iloc[1]), axis=1).values
        var_yearly.to_csv(os.path.join(out_dir, '{}_yearly.csv'.format(var_name)), index=False)
74 |
# Yearly means of the GLDAS monthly tables (PRCP, ET, Qs, Qsb).
prcp_month_path = r'H:\Datasets\Objects\Veg\LXB_plot\PRCP\PRCP.xlsx'
et_month_path = r'H:\Datasets\Objects\Veg\LXB_plot\ET\ET.xlsx'
qs_month_path = r'H:\Datasets\Objects\Veg\LXB_plot\Qs\Qs.xlsx'
qsb_month_path = r'H:\Datasets\Objects\Veg\LXB_plot\Qsb\Qsb.xlsx'
# (monthly table, value column, output directory, output file name).
# Raw strings are used for the directories: the previous plain literals
# contained invalid escape sequences such as '\D' and '\P', which are a
# SyntaxWarning on Python 3.12+.
gldas_yearly_jobs = [
    (prcp_month_path, 'PRCP', r'H:\Datasets\Objects\Veg\LXB_plot\PRCP', 'prcp_yearly.xlsx'),
    (et_month_path, 'ET', r'H:\Datasets\Objects\Veg\LXB_plot\ET', 'et_yearly.xlsx'),
    (qs_month_path, 'Qs', r'H:\Datasets\Objects\Veg\LXB_plot\Qs', 'Qs_yearly.xlsx'),
    (qsb_month_path, 'Qsb', r'H:\Datasets\Objects\Veg\LXB_plot\Qsb', 'Qsb_yearly.xlsx'),
]
for month_path, column, yearly_dir, yearly_name in gldas_yearly_jobs:
    monthly = pd.read_excel(month_path)
    # Each 'date' value is expected to expose `.year` (datetime-like).
    monthly['Year'] = monthly.date.apply(lambda x: x.year)
    yearly = monthly[['Year', column]].groupby(['Year']).mean().reset_index(drop=False)
    yearly.to_excel(os.path.join(yearly_dir, yearly_name), index=False)
99 |
100 |
101 |
# NDVI means per calendar month (for the bar chart) and per year (for the
# line chart), exported as CSV.
ndvi_monthly = ndvi.groupby('Month')[['NDVI']].mean()
ndvi_yearly = ndvi.groupby('Year')[['NDVI']].mean()
ndvi_yearly = ndvi_yearly.assign(Year=ndvi_yearly.index)
ndvi_monthly = ndvi_monthly.assign(Month=ndvi_monthly.index.map('{:02}'.format))
for frame, file_name in ((ndvi_monthly, 'ndvi_monthly.csv'), (ndvi_yearly, 'ndvi_yearly.csv')):
    frame.to_csv(os.path.join(out_dir, file_name), index=False)
# # Line chart of yearly NDVI (disabled)
# plt.figure(figsize=(13, 9), dpi=222)
# sns.lineplot(data=ndvi_yearly, x='Year', y='NDVI', linestyle='-', color='#1f77b4', linewidth=7, legend=True)
# plt.scatter(ndvi_yearly['Year'], ndvi_yearly['NDVI'], s=100, facecolors='none', edgecolors='#bcbd22', linewidths=5, zorder=5)
# plt.xlabel('Year', size=26)
# plt.ylabel('NDVI', size=26)
# plt.xticks(ndvi_yearly['Year'], rotation=45, fontsize=18)
# plt.yticks(fontsize=22)
# plt.savefig(os.path.join(out_dir, 'ndvi_line_yearly.png'))
# plt.show()
# # Bar chart of monthly NDVI (disabled)
# plt.figure(figsize=(13, 9), dpi=222)
# sns.barplot(data=ndvi_monthly, x='Month', y='NDVI', linestyle='-', color='#1f77b4')
# plt.xlabel('Month', size=26)
# plt.ylabel('NDVI', size=26)
# x_labels = ndvi_monthly['Month'].apply(lambda x: '{:02}'.format(x))
# plt.xticks(ticks=range(len(x_labels)), labels=x_labels, fontsize=18)
# plt.yticks(fontsize=22)
# plt.savefig(os.path.join(out_dir, 'ndvi_line_monthly.png'))
# plt.show()
129 |
# Sample the PRCP/ET/Qs/Qsb rasters at every NDVI station, average over the
# stations for each month and write one <var>.csv per variable.
# Explicit copy: columns are added to `station` below, and it is derived from
# a slice of `ndvi`.
station = ndvi.drop_duplicates(['Lon', 'Lat'])[['Lon', 'Lat']].copy()
var_names = ['PRCP', 'ET', 'Qs', 'Qsb']
for var_name in var_names:
    var = []
    cur_dir = os.path.join(in_img_dir, var_name)
    # The '_' after the variable name stops 'GLDAS_Qs' from also matching
    # 'GLDAS_Qsb_*' files.
    var_paths = glob.glob(os.path.join(cur_dir, 'GLDAS_{}_*.tiff'.format(var_name)))
    if not var_paths:
        raise FileNotFoundError('No GLDAS_{}_*.tiff files found in {}'.format(var_name, cur_dir))
    for var_path in var_paths:
        ds = gdal.Open(var_path)
        lon_min, lon_res, _, lat_max, _, lat_res_negative = ds.GetGeoTransform()
        band = ds.GetRasterBand(1)
        ds_band = np.float32(band.ReadAsArray())
        nodata_value = band.GetNoDataValue()
        ds_band[ds_band == nodata_value] = np.nan
        # Pixel indices of each station. The y resolution is negative,
        # hence the sign flip.
        station['row'] = np.floor((lat_max - station['Lat']) / (-lat_res_negative)).astype(int)
        station['col'] = np.floor((station['Lon'] - lon_min) / lon_res).astype(int)
        row_ix = station['row'].to_numpy()
        col_ix = station['col'].to_numpy()
        # Stations outside the raster get NaN. Otherwise they would raise
        # IndexError, or (for negative indices) silently read the opposite edge.
        inside = (row_ix >= 0) & (row_ix < ds_band.shape[0]) & (col_ix >= 0) & (col_ix < ds_band.shape[1])
        values = np.full(row_ix.shape, np.nan, dtype=np.float32)
        values[inside] = ds_band[row_ix[inside], col_ix[inside]]
        station[var_name] = values
        # File names look like GLDAS_<VAR>_<YYYYMM>...; keep the YYYYMM part.
        station['date'] = os.path.basename(var_path).split('_')[2][:6]
        var.append(station.copy())
    var = pd.concat(var, ignore_index=True)
    var['date'] = var['date'].apply(lambda x: x[:4] + '/' + x[4:])
    var = var[['date', var_name]].groupby(['date']).mean()
    out_path = os.path.join(out_dir, '{}.csv'.format(var_name))
    var.to_csv(out_path)
153 |
# Per-station (Lon/Lat) means of EWTC, TWSC_SH, AWC and NDVI. Each table is
# written to distribution_<value column>.csv.
ewtc_by_station = ewtc.groupby(['Lon', 'Lat'])['EWTC'].mean().reset_index()
twsc_by_station = twsc.groupby(['Lon', 'Lat'])['TWSC_SH'].mean().reset_index()
# AWC comes from the merged table, so its coordinates carry the _ewtc suffix.
awc_by_station = awc.groupby(['Lat_ewtc', 'Lon_ewtc'])['AWC'].mean().reset_index()
ndvi_by_station = ndvi.groupby(['Lon', 'Lat'])['NDVI'].mean().reset_index()
for table in (ewtc_by_station, twsc_by_station, awc_by_station, ndvi_by_station):
    value_column = table.columns[-1]
    table.to_csv(os.path.join(out_dir, f'distribution_{value_column}.csv'), index=False)
--------------------------------------------------------------------------------
/Core/model_train_dynamic.py:
--------------------------------------------------------------------------------
1 | # @Author : ChaoQiezi
2 | # @Time : 2024/1/3 16:54
3 | # @Email : chaoqiezi.one@qq.com
4 |
5 | """
6 | This script is used to 构建lstm模型并训练
7 |
8 | 2024/5/11 增加Rs、Lon、Lat输入到模型中训练
9 |
10 | 2024/5/11 仅使用动态特征中的LST, PRCP, ET, RS进行训练
11 | """
12 |
13 | import random
14 | import glob
15 | import os.path
16 | import numpy as np
17 | import pandas as pd
18 | import torch
19 | from torchsummary import summary
20 | from torch.utils.data import DataLoader, random_split
21 | from VEG.utils.utils import cal_r2, H5DynamicDatasetDecoder
22 | from VEG.utils.models import LSTMModel, LSTMModelDynamic
23 | from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
24 |
25 |
def set_seed(seed=42):
    """
    Seed every random number generator used during training.

    Seeds Python's ``random``, NumPy and PyTorch (CPU, the current CUDA device
    and all CUDA devices). It also exports ``PYTHONHASHSEED`` and switches
    cuDNN to deterministic, non-benchmarking mode.

    :param seed: Seed value applied to all generators.
    """
    # NOTE: PYTHONHASHSEED only takes effect in processes started after this
    # point; the current interpreter's hash seed is fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed,
                   torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
35 |
36 |
set_seed(42)

# Preparation
train_path = r'E:\FeaturesTargets\train.h5'
eval_path = r'E:\FeaturesTargets\eval.h5'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
out_model_dir = r'E:\Models'
dynamic_features_name = [
    'LST_MAX',
    'PRCP',
    'ET',
    # 'Qs',
    # 'Qsb',
    # 'TWSC',
    'Rs',
]
static_feature_name = [
    'Slope',
    'DEM',
    'Lon',
    'Lat',
]
# Build one model to print its summary; model_train() creates its own fresh
# instances.
model = LSTMModelDynamic(4, 256, 4, 12).to(device)
summary(model, input_data=[(12, 4)])
batch_size = 256

# Data loaders over the pre-split HDF5 files (train.h5 / eval.h5).
train_dataset = H5DynamicDatasetDecoder(train_path)
eval_dataset = H5DynamicDatasetDecoder(eval_path)
train_data_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# Evaluation order is kept fixed so exported predictions follow the dataset's
# sample order. The MSE/MAE/R2 metrics do not depend on order.
eval_data_loader = DataLoader(eval_dataset, batch_size=batch_size, shuffle=False)
# Training parameters
criterion = torch.nn.MSELoss()  # loss used by model_train()
# Optimizer for the module-level `model` (initial learning rate 0.002).
# model_train() builds its own optimizer.
optimizer = torch.optim.Adam(model.parameters(), lr=0.002)
# NOTE(review): epochs_num is not passed to model_train(), which defaults to 25.
epochs_num = 30
model.train()  # switch to training mode
79 |
80 |
def model_train(data_loader, epochs_num: int = 25, save_path: str = None, device='cuda'):
    """
    Train a fresh LSTMModelDynamic(4, 256, 4, 12) and optionally save it.

    The loss is the module-level ``criterion``; the optimizer is Adam with
    lr=0.0005.

    :param data_loader: Iterable yielding (dynamic_inputs, targets) batches.
    :param epochs_num: Number of passes over ``data_loader``.
    :param save_path: If truthy, the trained ``state_dict`` is saved there.
    :param device: Device the model and every batch are moved to.
    :return: List with the mean training loss of each epoch.
    """
    net = LSTMModelDynamic(4, 256, 4, 12).to(device)
    adam = torch.optim.Adam(net.parameters(), lr=0.0005)

    history = []
    for epoch_no in range(1, epochs_num + 1):
        batch_losses = []
        for inputs, labels in data_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            predictions = net(inputs)
            loss = criterion(predictions, labels)
            loss.backward()
            adam.step()
            adam.zero_grad()  # clear gradients for the next batch

            batch_losses.append(loss.item())
        epoch_loss = np.mean(batch_losses)
        print(f'Epoch {epoch_no}/{epochs_num}, Loss: {epoch_loss}')
        history.append(epoch_loss)

    if save_path:
        torch.save(net.state_dict(), save_path)
    return history
110 |
111 |
def _load_model(model_path, device):
    """Build LSTMModelDynamic(4, 256, 4, 12) on ``device`` and load its weights."""
    model = LSTMModelDynamic(4, 256, 4, 12).to(device)
    # map_location lets a checkpoint saved on another device (e.g. a GPU) be
    # loaded onto ``device``.
    model.load_state_dict(torch.load(model_path, map_location=device))
    return model


def _collect_predictions(model, data_loader, device):
    """
    Run ``model`` in eval mode over ``data_loader`` and stack the results.

    :param model: torch module mapping a batch of dynamic inputs to predictions.
    :param data_loader: Iterable yielding (dynamic_inputs, targets) batches.
    :param device: Device the inputs are moved to before the forward pass.
    :return: (outputs, targets) numpy arrays concatenated along the batch axis.
    """
    model.eval()
    all_outputs = []
    all_targets = []
    with torch.no_grad():
        for dynamic_inputs, targets in data_loader:
            outputs = model(dynamic_inputs.to(device))
            all_outputs.append(outputs.cpu().numpy())  # (batch_size, time_steps)
            all_targets.append(targets.cpu().numpy())
    return np.concatenate(all_outputs, axis=0), np.concatenate(all_targets, axis=0)


def model_eval_whole(model_path: str, data_loader, device='cuda'):
    """
    Evaluate a saved model with all time steps pooled together.

    :param model_path: Path of a saved LSTMModelDynamic ``state_dict``.
    :param data_loader: Iterable yielding (dynamic_inputs, targets) batches.
    :param device: Device used for inference.
    :return: (mse, mae, r2, rmse) over the flattened targets and predictions.
    """
    all_outputs, all_targets = _collect_predictions(_load_model(model_path, device), data_loader, device)
    flat_targets = all_targets.reshape(-1)
    flat_outputs = all_outputs.reshape(-1)
    mse = mean_squared_error(flat_targets, flat_outputs)
    mae = mean_absolute_error(flat_targets, flat_outputs)
    r2 = r2_score(flat_targets, flat_outputs)
    rmse = np.sqrt(mse)
    return mse, mae, r2, rmse


def model_eval(model_path: str, data_loader, device='cuda'):
    """
    Evaluate a saved model separately for every output time step.

    :param model_path: Path of a saved LSTMModelDynamic ``state_dict``.
    :param data_loader: Iterable yielding (dynamic_inputs, targets) batches.
    :param device: Device used for inference.
    :return: (mse_per_step, mae_per_step, r2_per_step, rmse_per_step,
             all_outputs, all_targets). The per-step lists have one entry per
             time step (column of the outputs); the arrays are the stacked
             predictions and targets.
    """
    all_outputs, all_targets = _collect_predictions(_load_model(model_path, device), data_loader, device)

    mse_per_step = []
    mae_per_step = []
    r2_per_step = []
    rmse_per_step = []
    for time_step in range(all_targets.shape[1]):
        step_targets = all_targets[:, time_step]
        step_outputs = all_outputs[:, time_step]
        mse_step = mean_squared_error(step_targets, step_outputs)
        mse_per_step.append(mse_step)
        mae_per_step.append(mean_absolute_error(step_targets, step_outputs))
        r2_per_step.append(r2_score(step_targets, step_outputs))
        rmse_per_step.append(np.sqrt(mse_step))

    return mse_per_step, mae_per_step, r2_per_step, rmse_per_step, all_outputs, all_targets
197 |
if __name__ == '__main__':
    df = pd.DataFrame()
    # Baseline training on the unmodified features. The detected `device` is
    # passed explicitly; the function default ('cuda') fails on CPU-only hosts.
    df['normal_epochs_loss'] = model_train(
        train_data_loader,
        save_path=os.path.join(out_model_dir, 'normal_model_dynamic.pth'),
        device=device,
    )
    print('>>> 常规训练结束')
    # Permutation feature importance: retrain with one dynamic feature
    # shuffled at a time.
    for feature_ix, cur_feature_name in enumerate(dynamic_features_name):
        train_dataset = H5DynamicDatasetDecoder(train_path, shuffle_feature_ix=feature_ix, dynamic=True)
        train_data_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

        save_path = os.path.join(out_model_dir, cur_feature_name + '_model_dynamic.pth')
        df[cur_feature_name + '_epochs_loss'] = model_train(train_data_loader, save_path=save_path, device=device)
        print('>>> {}乱序排列 训练结束'.format(cur_feature_name))
    # Static-feature permutation runs (disabled, kept for reference):
    # for feature_ix in range(4):
    #     train_dataset = H5DatasetDecoder(train_path, shuffle_feature_ix=feature_ix, dynamic=False)
    #     train_data_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    #     cur_feature_name = static_feature_name[feature_ix]
    #     save_path = os.path.join(out_model_dir, cur_feature_name + '_model.pth')
    #     df[cur_feature_name + '_epochs_loss'] = model_train(train_data_loader, save_path=save_path)
    #     print('>>> {}乱序排列 训练结束'.format(cur_feature_name))
    df.to_excel(r'E:\Models\training_eval_results\training_loss_dynamic.xlsx')

    # Evaluate every saved dynamic model, both pooled and per time step.
    indicator_whole = pd.DataFrame()
    indicator = pd.DataFrame()
    model_paths = glob.glob(os.path.join(out_model_dir, '*_model_dynamic.pth'))
    for model_path in model_paths:
        cur_model_name = os.path.basename(model_path).split('_model')[0]
        mse_step, mae_step, r2_step, rmse_step = model_eval_whole(model_path, eval_data_loader, device=device)
        indicator_whole[cur_model_name + '_evaluate_mse'] = [mse_step]
        indicator_whole[cur_model_name + '_evaluate_mae'] = [mae_step]
        indicator_whole[cur_model_name + '_evaluate_r2'] = [r2_step]
        indicator_whole[cur_model_name + '_evaluate_rmse'] = [rmse_step]

        mse_per_step, mae_per_step, r2_per_step, rmse_per_step, all_outputs, all_targets = \
            model_eval(model_path, eval_data_loader, device=device)

        # One column per predicted and per true time step (outputs_01.., targets_01..).
        steps = range(1, all_outputs.shape[1] + 1)
        columns = ['outputs_{:02}'.format(step) for step in steps] + \
                  ['targets_{:02}'.format(step) for step in steps]
        outputs_targets = pd.DataFrame(np.concatenate((all_outputs, all_targets), axis=1), columns=columns)
        indicator[cur_model_name + '_evaluate_mse'] = mse_per_step
        indicator[cur_model_name + '_evaluate_mae'] = mae_per_step
        indicator[cur_model_name + '_evaluate_r2'] = r2_per_step
        indicator[cur_model_name + '_evaluate_rmse'] = rmse_per_step
        outputs_targets.to_excel(r'E:\Models\training_eval_results\{}_outputs_targets_dynamic.xlsx'.format(cur_model_name))
        print('>>> {} 重要性评估完毕'.format(cur_model_name))
    # Extra row with the mean of each metric over all time steps.
    indicator.loc['均值指标'] = np.mean(indicator, axis=0)
    indicator.to_excel(r'E:\Models\training_eval_results\eval_indicators_dynamic.xlsx')
    indicator_whole.to_excel(r'E:\Models\training_eval_results\eval_indicators_整体_dynamic.xlsx')
253 | # model.eval()
254 | # eval_loss = []
255 | # with torch.no_grad():
256 | # for dynamic_inputs, static_inputs, targets in data_loader:
257 | # dynamic_inputs = dynamic_inputs.to('cuda' if torch.cuda.is_available() else 'cpu')
258 | # static_inputs = static_inputs.to('cuda' if torch.cuda.is_available() else 'cpu')
259 | # targets = targets.to('cuda' if torch.cuda.is_available() else 'cpu')
260 | # # 前向传播
261 | # outputs = model(dynamic_inputs, static_inputs)
262 | # # 计算损失
263 | # loss = criterion(outputs, targets)
264 | # r2 = cal_r2(outputs, targets)
265 | # print('预测项:', outputs)
266 | # print('目标项:', targets)
267 | # print(f'MSE Loss: {loss.item()}')
268 | # break
269 | # eval_loss.append(loss.item())
270 | # print(f'Loss: {np.mean(eval_loss)}')
271 | # print(f'R2:', r2)
272 |
273 |
274 |
275 | # # 取
276 | # with h5py.File(r'E:\FeaturesTargets\features_targets.h5', 'r') as h5:
277 | # features = np.transpose(h5['2003/features1'][:], (1, 0, 2)) # shape=(样本数, 时间步, 特征项)
278 | # targets = np.transpose(h5['2003/targets'][:], (1, 0)) # shape=(样本数, 时间步)
279 | # static_features = np.column_stack((h5['2003/features2'][:], h5['dem'][:]))
280 | # mask1 = ~np.any(np.isnan(features), axis=(1, 2))
281 | # mask2 = ~np.any(np.isnan(targets), axis=(1,))
282 | # mask3 = ~np.any(np.isnan(static_features), axis=(1, ))
283 | # mask = (mask1 & mask2 & mask3)
284 | # features = features[mask, :, :]
285 | # targets = targets[mask, :]
286 | # static_features = static_features[mask, :]
287 | # print(features.shape)
288 | # print(targets.shape)
289 | # for ix in range(6):
290 | # feature = features[:, :, ix]
291 | # features[:, :, ix] = (feature - feature.mean()) / feature.std()
292 | # if ix <= 1:
293 | # feature = static_features[:, ix]
294 | # static_features[:, ix] = (feature - feature.mean()) / feature.std()
295 | #
296 | # features_tensor = torch.tensor(features, dtype=torch.float32)
297 | # targets_tensor = torch.tensor(targets, dtype=torch.float32)
298 | # static_features_tensor = torch.tensor(static_features, dtype=torch.float32)
299 | #
300 | # # 创建包含动态特征、静态特征和目标的数据集
301 | # dataset = TensorDataset(features_tensor, static_features_tensor, targets_tensor)
302 | # train_dataset, eval_dataset = random_split(dataset, [8000, 10238 - 8000])
--------------------------------------------------------------------------------
/Core/Plot/plot.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {
7 | "collapsed": true
8 | },
9 | "outputs": [],
10 | "source": [
11 | "import os.path\n",
12 | "import pandas as pd\n",
13 | "import numpy as np\n",
14 | "import matplotlib.pyplot as plt\n",
15 | "import seaborn as sns\n",
16 | "\n",
17 | "# 准备\n",
18 | "ewtc_path = r'H:\\Datasets\\Objects\\Veg\\LXB_plot\\Data\\sw_EWTC_T.csv'\n",
19 | "twsc_path = r'H:\\Datasets\\Objects\\Veg\\LXB_plot\\Data\\sw_TWSC_SH_T.csv'\n",
20 | "ndvi_path = r'H:\\Datasets\\Objects\\Veg\\LXB_plot\\Data\\sw_NDVI_T.csv'\n",
21 | "out_dir = r'H:\\Datasets\\Objects\\Veg\\LXB_plot'\n",
22 | "sns.set_style('darkgrid')\n",
23 | "plt.rcParams['font.sans-serif'] = ['Times New Roman'] # 新罗马字体\n",
24 | "\n",
25 | "\n",
26 | "# 绘制\n",
27 | "ewtc = pd.read_csv(ewtc_path)\n",
28 | "twsc = pd.read_csv(twsc_path)\n",
29 | "ndvi = pd.read_csv(ndvi_path)\n",
30 | "awc = pd.merge(ewtc, twsc, left_on='ProvinceNa', right_on='ProvinceNa', suffixes=('_ewtc', '_twsc'))\n",
31 | "awc['AWC'] = awc['EWTC'] - awc['TWSC_SH']\n",
32 | "\n",
33 | "# 绘制NDVI年变化折线图和月变化柱状图\n",
34 | "ndvi_monthly = ndvi[['Month', 'NDVI']].groupby('Month').mean()\n",
35 | "ndvi_yearly = ndvi[['Year', 'NDVI']].groupby('Year').mean()\n",
36 | "ndvi_yearly['Year'] = ndvi_yearly.index\n",
37 | "ndvi_monthly['Month'] = ndvi_monthly.index"
38 | ]
39 | },
40 | {
41 | "cell_type": "code",
42 | "execution_count": 2,
43 | "outputs": [
44 | {
45 | "ename": "NoMatchingVersions",
46 | "evalue": "No matches for version='5.16.3' among ['4.0.2', '4.8.1', '4.17.0'].\nOften this can be fixed by updating altair_viewer:\n pip install -U altair_viewer",
47 | "output_type": "error",
48 | "traceback": [
49 | "\u001B[1;31m---------------------------------------------------------------------------\u001B[0m",
50 | "\u001B[1;31mNoMatchingVersions\u001B[0m Traceback (most recent call last)",
51 | "Cell \u001B[1;32mIn[2], line 8\u001B[0m\n\u001B[0;32m 2\u001B[0m chart \u001B[38;5;241m=\u001B[39m alt\u001B[38;5;241m.\u001B[39mChart(ndvi_monthly)\u001B[38;5;241m.\u001B[39mmark_bar()\u001B[38;5;241m.\u001B[39mencode(\n\u001B[0;32m 3\u001B[0m x\u001B[38;5;241m=\u001B[39m\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mMonth\u001B[39m\u001B[38;5;124m'\u001B[39m,\n\u001B[0;32m 4\u001B[0m y\u001B[38;5;241m=\u001B[39m\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mNDVI\u001B[39m\u001B[38;5;124m'\u001B[39m\n\u001B[0;32m 5\u001B[0m )\n\u001B[0;32m 7\u001B[0m \u001B[38;5;66;03m# 显示图表\u001B[39;00m\n\u001B[1;32m----> 8\u001B[0m \u001B[43mchart\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mshow\u001B[49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\n",
52 | "File \u001B[1;32md:\\softwares\\python38\\lib\\site-packages\\altair\\vegalite\\v5\\api.py:2691\u001B[0m, in \u001B[0;36mTopLevelMixin.show\u001B[1;34m(self, embed_opt, open_browser)\u001B[0m\n\u001B[0;32m 2686\u001B[0m \u001B[38;5;28;01mexcept\u001B[39;00m \u001B[38;5;167;01mImportError\u001B[39;00m \u001B[38;5;28;01mas\u001B[39;00m err:\n\u001B[0;32m 2687\u001B[0m \u001B[38;5;28;01mraise\u001B[39;00m \u001B[38;5;167;01mValueError\u001B[39;00m(\n\u001B[0;32m 2688\u001B[0m \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mshow\u001B[39m\u001B[38;5;124m'\u001B[39m\u001B[38;5;124m method requires the altair_viewer package. \u001B[39m\u001B[38;5;124m\"\u001B[39m\n\u001B[0;32m 2689\u001B[0m \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mSee http://github.com/altair-viz/altair_viewer\u001B[39m\u001B[38;5;124m\"\u001B[39m\n\u001B[0;32m 2690\u001B[0m ) \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01merr\u001B[39;00m\n\u001B[1;32m-> 2691\u001B[0m \u001B[43maltair_viewer\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mshow\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43membed_opt\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43membed_opt\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mopen_browser\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mopen_browser\u001B[49m\u001B[43m)\u001B[49m\n",
53 | "File \u001B[1;32md:\\softwares\\python38\\lib\\site-packages\\altair_viewer\\_viewer.py:355\u001B[0m, in \u001B[0;36mChartViewer.show\u001B[1;34m(self, chart, embed_opt, open_browser)\u001B[0m\n\u001B[0;32m 328\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mshow\u001B[39m(\n\u001B[0;32m 329\u001B[0m \u001B[38;5;28mself\u001B[39m,\n\u001B[0;32m 330\u001B[0m chart: Union[\u001B[38;5;28mdict\u001B[39m, alt\u001B[38;5;241m.\u001B[39mTopLevelMixin],\n\u001B[0;32m 331\u001B[0m embed_opt: Optional[\u001B[38;5;28mdict\u001B[39m] \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;01mNone\u001B[39;00m,\n\u001B[0;32m 332\u001B[0m open_browser: Optional[\u001B[38;5;28mbool\u001B[39m] \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;01mNone\u001B[39;00m,\n\u001B[0;32m 333\u001B[0m ) \u001B[38;5;241m-\u001B[39m\u001B[38;5;241m>\u001B[39m \u001B[38;5;28;01mNone\u001B[39;00m:\n\u001B[0;32m 334\u001B[0m \u001B[38;5;250m \u001B[39m\u001B[38;5;124;03m\"\"\"Show chart and prompt to pause execution.\u001B[39;00m\n\u001B[0;32m 335\u001B[0m \n\u001B[0;32m 336\u001B[0m \u001B[38;5;124;03m Use this to show a chart within a stand-alone script, to prevent the Python process\u001B[39;00m\n\u001B[1;32m (...)\u001B[0m\n\u001B[0;32m 353\u001B[0m \u001B[38;5;124;03m render : Jupyter renderer for chart.\u001B[39;00m\n\u001B[0;32m 354\u001B[0m \u001B[38;5;124;03m \"\"\"\u001B[39;00m\n\u001B[1;32m--> 355\u001B[0m msg \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mdisplay\u001B[49m\u001B[43m(\u001B[49m\u001B[43mchart\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43membed_opt\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43membed_opt\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mopen_browser\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mopen_browser\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m 356\u001B[0m \u001B[38;5;28mprint\u001B[39m(msg)\n\u001B[0;32m 357\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m 
\u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_provider \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m \u001B[38;5;28;01mNone\u001B[39;00m:\n",
54 | "File \u001B[1;32md:\\softwares\\python38\\lib\\site-packages\\altair_viewer\\_viewer.py:266\u001B[0m, in \u001B[0;36mChartViewer.display\u001B[1;34m(self, chart, inline, embed_opt, open_browser)\u001B[0m\n\u001B[0;32m 264\u001B[0m chart \u001B[38;5;241m=\u001B[39m chart\u001B[38;5;241m.\u001B[39mto_dict()\n\u001B[0;32m 265\u001B[0m \u001B[38;5;28;01massert\u001B[39;00m \u001B[38;5;28misinstance\u001B[39m(chart, \u001B[38;5;28mdict\u001B[39m)\n\u001B[1;32m--> 266\u001B[0m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43m_initialize\u001B[49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m 267\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_stream \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;28;01mNone\u001B[39;00m:\n\u001B[0;32m 268\u001B[0m \u001B[38;5;28;01mraise\u001B[39;00m \u001B[38;5;167;01mRuntimeError\u001B[39;00m(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mInternal: _stream is not defined.\u001B[39m\u001B[38;5;124m\"\u001B[39m)\n",
55 | "File \u001B[1;32md:\\softwares\\python38\\lib\\site-packages\\altair_viewer\\_viewer.py:183\u001B[0m, in \u001B[0;36mChartViewer._initialize\u001B[1;34m(self)\u001B[0m\n\u001B[0;32m 180\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_use_bundled_js:\n\u001B[0;32m 181\u001B[0m \u001B[38;5;28;01mfor\u001B[39;00m package \u001B[38;5;129;01min\u001B[39;00m [\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mvega\u001B[39m\u001B[38;5;124m\"\u001B[39m, \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mvega-lite\u001B[39m\u001B[38;5;124m\"\u001B[39m, \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mvega-embed\u001B[39m\u001B[38;5;124m\"\u001B[39m]:\n\u001B[0;32m 182\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_resources[package] \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_provider\u001B[38;5;241m.\u001B[39mcreate(\n\u001B[1;32m--> 183\u001B[0m content\u001B[38;5;241m=\u001B[39m\u001B[43mget_bundled_script\u001B[49m\u001B[43m(\u001B[49m\n\u001B[0;32m 184\u001B[0m \u001B[43m \u001B[49m\u001B[43mpackage\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43m_versions\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mget\u001B[49m\u001B[43m(\u001B[49m\u001B[43mpackage\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m 185\u001B[0m \u001B[43m \u001B[49m\u001B[43m)\u001B[49m,\n\u001B[0;32m 186\u001B[0m route\u001B[38;5;241m=\u001B[39m\u001B[38;5;124mf\u001B[39m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mscripts/\u001B[39m\u001B[38;5;132;01m{\u001B[39;00mpackage\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m.js\u001B[39m\u001B[38;5;124m\"\u001B[39m,\n\u001B[0;32m 187\u001B[0m )\n\u001B[0;32m 189\u001B[0m favicon \u001B[38;5;241m=\u001B[39m pkgutil\u001B[38;5;241m.\u001B[39mget_data(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124maltair_viewer\u001B[39m\u001B[38;5;124m\"\u001B[39m, 
\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mstatic/favicon.ico\u001B[39m\u001B[38;5;124m\"\u001B[39m)\n\u001B[0;32m 190\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m favicon \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m \u001B[38;5;28;01mNone\u001B[39;00m:\n",
56 | "File \u001B[1;32md:\\softwares\\python38\\lib\\site-packages\\altair_viewer\\_scripts.py:40\u001B[0m, in \u001B[0;36mget_bundled_script\u001B[1;34m(package, version)\u001B[0m\n\u001B[0;32m 36\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m package \u001B[38;5;129;01mnot\u001B[39;00m \u001B[38;5;129;01min\u001B[39;00m listing:\n\u001B[0;32m 37\u001B[0m \u001B[38;5;28;01mraise\u001B[39;00m \u001B[38;5;167;01mValueError\u001B[39;00m(\n\u001B[0;32m 38\u001B[0m \u001B[38;5;124mf\u001B[39m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mpackage \u001B[39m\u001B[38;5;132;01m{\u001B[39;00mpackage\u001B[38;5;132;01m!r}\u001B[39;00m\u001B[38;5;124m not recognized. Available: \u001B[39m\u001B[38;5;132;01m{\u001B[39;00m\u001B[38;5;28mlist\u001B[39m(listing)\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m\"\u001B[39m\n\u001B[0;32m 39\u001B[0m )\n\u001B[1;32m---> 40\u001B[0m version_str \u001B[38;5;241m=\u001B[39m \u001B[43mfind_version\u001B[49m\u001B[43m(\u001B[49m\u001B[43mversion\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mlisting\u001B[49m\u001B[43m[\u001B[49m\u001B[43mpackage\u001B[49m\u001B[43m]\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m 41\u001B[0m path \u001B[38;5;241m=\u001B[39m \u001B[38;5;124mf\u001B[39m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mscripts/\u001B[39m\u001B[38;5;132;01m{\u001B[39;00mpackage\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m-\u001B[39m\u001B[38;5;132;01m{\u001B[39;00mversion_str\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m.js\u001B[39m\u001B[38;5;124m\"\u001B[39m\n\u001B[0;32m 42\u001B[0m content \u001B[38;5;241m=\u001B[39m pkgutil\u001B[38;5;241m.\u001B[39mget_data(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124maltair_viewer\u001B[39m\u001B[38;5;124m\"\u001B[39m, path)\n",
57 | "File \u001B[1;32md:\\softwares\\python38\\lib\\site-packages\\altair_viewer\\_utils.py:212\u001B[0m, in \u001B[0;36mfind_version\u001B[1;34m(version, candidates, strict_micro)\u001B[0m\n\u001B[0;32m 210\u001B[0m matches \u001B[38;5;241m=\u001B[39m [c \u001B[38;5;28;01mfor\u001B[39;00m c \u001B[38;5;129;01min\u001B[39;00m cand \u001B[38;5;28;01mif\u001B[39;00m v\u001B[38;5;241m.\u001B[39mmatches(c)]\n\u001B[0;32m 211\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m matches:\n\u001B[1;32m--> 212\u001B[0m \u001B[38;5;28;01mraise\u001B[39;00m NoMatchingVersions(\n\u001B[0;32m 213\u001B[0m \u001B[38;5;124mf\u001B[39m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mNo matches for version=\u001B[39m\u001B[38;5;132;01m{\u001B[39;00mversion\u001B[38;5;132;01m!r}\u001B[39;00m\u001B[38;5;124m among \u001B[39m\u001B[38;5;132;01m{\u001B[39;00mcandidates\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m.\u001B[39m\u001B[38;5;130;01m\\n\u001B[39;00m\u001B[38;5;124m\"\u001B[39m\n\u001B[0;32m 214\u001B[0m \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mOften this can be fixed by updating altair_viewer:\u001B[39m\u001B[38;5;130;01m\\n\u001B[39;00m\u001B[38;5;124m\"\u001B[39m\n\u001B[0;32m 215\u001B[0m \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124m pip install -U altair_viewer\u001B[39m\u001B[38;5;124m\"\u001B[39m\n\u001B[0;32m 216\u001B[0m )\n\u001B[0;32m 217\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28mstr\u001B[39m(matches[\u001B[38;5;241m-\u001B[39m\u001B[38;5;241m1\u001B[39m])\n",
58 | "\u001B[1;31mNoMatchingVersions\u001B[0m: No matches for version='5.16.3' among ['4.0.2', '4.8.1', '4.17.0'].\nOften this can be fixed by updating altair_viewer:\n pip install -U altair_viewer"
59 | ]
60 | }
61 | ],
62 | "source": [
63 | "import altair as alt\n",
64 | "chart = alt.Chart(ndvi_monthly).mark_bar().encode(\n",
65 | " x='Month',\n",
66 | " y='NDVI'\n",
67 | ")\n",
68 | "\n",
69 | "# 显示图表\n",
70 | "chart.show()"
71 | ],
72 | "metadata": {
73 | "collapsed": false
74 | }
75 | },
76 | {
77 | "cell_type": "code",
78 | "execution_count": null,
79 | "outputs": [],
80 | "source": [],
81 | "metadata": {
82 | "collapsed": false
83 | }
84 | }
85 | ],
86 | "metadata": {
87 | "kernelspec": {
88 | "display_name": "Python 3",
89 | "language": "python",
90 | "name": "python3"
91 | },
92 | "language_info": {
93 | "codemirror_mode": {
94 | "name": "ipython",
95 | "version": 2
96 | },
97 | "file_extension": ".py",
98 | "mimetype": "text/x-python",
99 | "name": "python",
100 | "nbconvert_exporter": "python",
101 | "pygments_lexer": "ipython2",
102 | "version": "2.7.6"
103 | }
104 | },
105 | "nbformat": 4,
106 | "nbformat_minor": 0
107 | }
108 |
--------------------------------------------------------------------------------
/Core/model_train.py:
--------------------------------------------------------------------------------
1 | # @Author : ChaoQiezi
2 | # @Time : 2024/1/3 16:54
3 | # @Email : chaoqiezi.one@qq.com
4 |
5 | """
6 | This script is used to 构建lstm模型并训练
7 |
8 | 2024/5/11 增加Rs、Lon、Lat输入到模型中训练
9 |
10 | 2024/5/11 仅使用动态特征中的LST, PRCP, ET, RS进行训练
11 | """
12 |
13 | import random
14 | import glob
15 | import os.path
16 | import numpy as np
17 | import pandas as pd
18 | import torch
19 | from torchsummary import summary
20 | from torch.utils.data import DataLoader, random_split
21 | from VEG.utils.utils import H5DatasetDecoder, cal_r2
22 | from VEG.utils.models import LSTMModel, LSTMModelDynamic
23 | from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
24 |
25 |
def set_seed(seed=42):
    """Seed every random number generator used by this script for reproducible runs.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU, current CUDA device and all CUDA devices), stores the seed in the
    ``PYTHONHASHSEED`` environment variable and puts cuDNN into deterministic,
    non-benchmarking mode.

    :param seed: integer seed applied to every generator (default 42)
    """
    # Python and NumPy generators
    random.seed(seed)
    np.random.seed(seed)
    # NOTE: changing PYTHONHASHSEED at runtime only affects child processes;
    # this interpreter's str hashing was fixed when it started.
    os.environ['PYTHONHASHSEED'] = str(seed)

    # PyTorch generators: CPU, current GPU, then every GPU (multi-GPU setups)
    for seeder in (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

    # Give up cuDNN auto-tuning in exchange for run-to-run determinism
    cudnn = torch.backends.cudnn
    cudnn.deterministic, cudnn.benchmark = True, False
35 |
36 |
set_seed(42)

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
train_path = r'E:\FeaturesTargets\train.h5'
eval_path = r'E:\FeaturesTargets\eval.h5'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
out_model_dir = r'E:\Models'
# Dynamic (time-series) input features. The permutation-importance loop in
# ``__main__`` maps a feature index to a name through this list, so its order
# is assumed to match the feature order stored in the HDF5 files -- TODO confirm.
dynamic_features_name = [
    'LST_MAX',
    'PRCP',
    'ET',
    # 'Qs',
    # 'Qsb',
    # 'TWSC',
    'Rs'
]
# Static (time-invariant) input features; same ordering assumption as above.
static_feature_name = [
    'Slope',
    'DEM',
    'Lon',
    'Lat'
]
# Build an LSTM instance on the selected device and print a layer summary for
# inputs of shape (12, 4) (time steps x dynamic features) and (4,) (static features).
model = LSTMModel(4, 256, 4, 12).to(device)
summary(model, input_data=[(12, 4), (4,)])
# model = LSTMModelDynamic(4, 256, 4, 12).to(device)
# summary(model, input_data=[(12, 4)])
batch_size = 256

# Data loaders. Training batches are shuffled; evaluation batches are NOT, so the
# per-sample outputs/targets exported in ``__main__`` keep the on-disk sample order
# (the aggregate metrics do not depend on the order).
train_dataset = H5DatasetDecoder(train_path)  # custom HDF5-backed dataset
eval_dataset = H5DatasetDecoder(eval_path)
train_data_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
eval_data_loader = DataLoader(eval_dataset, batch_size=batch_size, shuffle=False)

# Training settings. ``criterion`` is read as a module-level global by ``model_train``.
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.002)  # initial learning rate 0.002
# NOTE(review): ``model_train`` defaults to 25 epochs and ``__main__`` does not pass
# this value, so it is currently unused -- confirm which epoch count is intended.
epochs_num = 30
model.train()  # switch to training mode
79 |
80 |
def model_train(data_loader, epochs_num: int = 25, save_path: str = None, device=None):
    """Train a freshly initialised ``LSTMModel`` and optionally save its weights.

    The loss function is the module-level ``criterion``.

    :param data_loader: iterable yielding ``(dynamic_inputs, static_inputs, targets)``
        tensor batches
    :param epochs_num: number of passes over ``data_loader``
    :param save_path: if given, the trained ``state_dict`` is saved to this path
    :param device: torch device used for the model and batches; ``None`` (default)
        selects ``'cuda'`` when available, otherwise ``'cpu'``
    :return: list holding the mean batch loss of every epoch
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # A new model on every call, so each (feature-permuted) run starts from scratch
    model = LSTMModel(4, 256, 4, 12).to(device)
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)  # initial learning rate 0.0005
    epochs_loss = []
    for epoch in range(epochs_num):
        train_loss = []
        for dynamic_inputs, static_inputs, targets in data_loader:
            dynamic_inputs = dynamic_inputs.to(device)
            static_inputs = static_inputs.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()  # clear gradients from the previous step
            outputs = model(dynamic_inputs, static_inputs)  # forward pass
            loss = criterion(outputs, targets)
            loss.backward()  # backward pass
            optimizer.step()  # parameter update

            train_loss.append(loss.item())
        mean_loss = np.mean(train_loss)
        print(f'Epoch {epoch + 1}/{epochs_num}, Loss: {mean_loss}')
        epochs_loss.append(mean_loss)

    if save_path:
        torch.save(model.state_dict(), save_path)

    return epochs_loss
122 |
123 |
def _predict(model_path: str, data_loader, device=None):
    """Load an ``LSTMModel`` checkpoint and run it over every batch of ``data_loader``.

    :param model_path: path of a saved ``state_dict``
    :param data_loader: iterable yielding ``(dynamic_inputs, static_inputs, targets)`` batches
    :param device: torch device; ``None`` selects ``'cuda'`` when available, otherwise ``'cpu'``
    :return: ``(all_outputs, all_targets)`` NumPy arrays concatenated along the sample axis
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    model = LSTMModel(4, 256, 4, 12).to(device)
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only host
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()  # evaluation mode

    all_outputs = []
    all_targets = []
    with torch.no_grad():
        for dynamic_inputs, static_inputs, targets in data_loader:
            outputs = model(dynamic_inputs.to(device), static_inputs.to(device))
            all_outputs.append(outputs.cpu().numpy())  # outputs/targets: (batch_size, time_steps)
            all_targets.append(targets.cpu().numpy())

    return np.concatenate(all_outputs, axis=0), np.concatenate(all_targets, axis=0)


def model_eval_whole(model_path: str, data_loader, device=None):
    """Evaluate a saved model with all samples and all time steps pooled together.

    :param model_path: path of a saved ``LSTMModel`` ``state_dict``
    :param data_loader: evaluation batches ``(dynamic_inputs, static_inputs, targets)``
    :param device: torch device; ``None`` selects ``'cuda'`` when available, otherwise ``'cpu'``
    :return: ``(mse, mae, r2, rmse)`` computed on the flattened outputs and targets
    """
    all_outputs, all_targets = _predict(model_path, data_loader, device)

    # Metrics without distinguishing months (everything treated as one population)
    y_true = all_targets.reshape(-1)
    y_pred = all_outputs.reshape(-1)
    mse = mean_squared_error(y_true, y_pred)
    mae = mean_absolute_error(y_true, y_pred)
    r2 = r2_score(y_true, y_pred)
    return mse, mae, r2, np.sqrt(mse)


def model_eval(model_path: str, data_loader, device=None):
    """Evaluate a saved model separately for every time step (month).

    :param model_path: path of a saved ``LSTMModel`` ``state_dict``
    :param data_loader: evaluation batches ``(dynamic_inputs, static_inputs, targets)``
    :param device: torch device; ``None`` selects ``'cuda'`` when available, otherwise ``'cpu'``
    :return: ``(mse_per_step, mae_per_step, r2_per_step, rmse_per_step, all_outputs, all_targets)``;
        the four metric lists have one entry per time step (column of the targets)
    """
    all_outputs, all_targets = _predict(model_path, data_loader, device)

    mse_per_step = []
    mae_per_step = []
    r2_per_step = []
    rmse_per_step = []
    for time_step in range(all_targets.shape[1]):
        y_true = all_targets[:, time_step]
        y_pred = all_outputs[:, time_step]
        mse_step = mean_squared_error(y_true, y_pred)
        mse_per_step.append(mse_step)
        mae_per_step.append(mean_absolute_error(y_true, y_pred))
        r2_per_step.append(r2_score(y_true, y_pred))
        rmse_per_step.append(np.sqrt(mse_step))

    return mse_per_step, mae_per_step, r2_per_step, rmse_per_step, all_outputs, all_targets
211 |
if __name__ == '__main__':
    # Every Excel report is written here; ``to_excel`` does not create missing folders.
    results_dir = os.path.join(out_model_dir, 'training_eval_results')
    os.makedirs(results_dir, exist_ok=True)

    df = pd.DataFrame()
    # Baseline training on the unmodified features
    df['normal_epochs_loss'] = model_train(train_data_loader,
                                           save_path=os.path.join(out_model_dir, 'normal_model_V07.pth'),
                                           device=device)
    print('>>> 常规训练结束')
    # Permutation feature importance: retrain once per feature with that feature
    # shuffled, first for every dynamic feature, then for every static feature.
    for feature_names, is_dynamic in ((dynamic_features_name, True), (static_feature_name, False)):
        for feature_ix, cur_feature_name in enumerate(feature_names):
            train_dataset = H5DatasetDecoder(train_path, shuffle_feature_ix=feature_ix, dynamic=is_dynamic)
            train_data_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
            save_path = os.path.join(out_model_dir, cur_feature_name + '_model_V07.pth')
            df[cur_feature_name + '_epochs_loss'] = model_train(train_data_loader, save_path=save_path,
                                                                device=device)
            print('>>> {}乱序排列 训练结束'.format(cur_feature_name))
    df.to_excel(os.path.join(results_dir, 'training_loss_V07.xlsx'))

    # Evaluate every saved model on the evaluation set
    indicator_whole = pd.DataFrame()  # metrics pooled over all time steps
    indicator = pd.DataFrame()  # metrics per time step (one row per month)
    columns = (['outputs_{:02}'.format(month) for month in range(1, 13)]
               + ['targets_{:02}'.format(month) for month in range(1, 13)])
    for model_path in glob.glob(os.path.join(out_model_dir, '*_model_V07.pth')):
        # e.g. 'LST_MAX_model_V07.pth' -> 'LST_MAX' (maxsplit=1 only strips the suffix)
        cur_model_name = os.path.basename(model_path).rsplit('_model', 1)[0]

        mse, mae, r2, rmse = model_eval_whole(model_path, eval_data_loader, device=device)
        for metric, value in (('mse', mse), ('mae', mae), ('r2', r2), ('rmse', rmse)):
            indicator_whole[cur_model_name + '_evaluate_' + metric] = [value]

        mse_per_step, mae_per_step, r2_per_step, rmse_per_step, all_outputs, all_targets = \
            model_eval(model_path, eval_data_loader, device=device)
        for metric, values in (('mse', mse_per_step), ('mae', mae_per_step),
                               ('r2', r2_per_step), ('rmse', rmse_per_step)):
            indicator[cur_model_name + '_evaluate_' + metric] = values

        outputs_targets = pd.DataFrame(np.concatenate((all_outputs, all_targets), axis=1), columns=columns)
        outputs_targets.to_excel(os.path.join(results_dir, '{}_outputs_targets_V07.xlsx'.format(cur_model_name)))
        print('>>> {} 重要性评估完毕'.format(cur_model_name))
    indicator.loc['均值指标'] = indicator.mean(axis=0)  # extra row: mean over all time steps
    indicator.to_excel(os.path.join(results_dir, 'eval_indicators_V07.xlsx'))
    indicator_whole.to_excel(os.path.join(results_dir, 'eval_indicators_整体_V07.xlsx'))
266 | # model.eval()
267 | # eval_loss = []
268 | # with torch.no_grad():
269 | # for dynamic_inputs, static_inputs, targets in data_loader:
270 | # dynamic_inputs = dynamic_inputs.to('cuda' if torch.cuda.is_available() else 'cpu')
271 | # static_inputs = static_inputs.to('cuda' if torch.cuda.is_available() else 'cpu')
272 | # targets = targets.to('cuda' if torch.cuda.is_available() else 'cpu')
273 | # # 前向传播
274 | # outputs = model(dynamic_inputs, static_inputs)
275 | # # 计算损失
276 | # loss = criterion(outputs, targets)
277 | # r2 = cal_r2(outputs, targets)
278 | # print('预测项:', outputs)
279 | # print('目标项:', targets)
280 | # print(f'MSE Loss: {loss.item()}')
281 | # break
282 | # eval_loss.append(loss.item())
283 | # print(f'Loss: {np.mean(eval_loss)}')
284 | # print(f'R2:', r2)
285 |
286 |
287 |
288 | # # 取
289 | # with h5py.File(r'E:\FeaturesTargets\features_targets.h5', 'r') as h5:
290 | # features = np.transpose(h5['2003/features1'][:], (1, 0, 2)) # shape=(样本数, 时间步, 特征项)
291 | # targets = np.transpose(h5['2003/targets'][:], (1, 0)) # shape=(样本数, 时间步)
292 | # static_features = np.column_stack((h5['2003/features2'][:], h5['dem'][:]))
293 | # mask1 = ~np.any(np.isnan(features), axis=(1, 2))
294 | # mask2 = ~np.any(np.isnan(targets), axis=(1,))
295 | # mask3 = ~np.any(np.isnan(static_features), axis=(1, ))
296 | # mask = (mask1 & mask2 & mask3)
297 | # features = features[mask, :, :]
298 | # targets = targets[mask, :]
299 | # static_features = static_features[mask, :]
300 | # print(features.shape)
301 | # print(targets.shape)
302 | # for ix in range(6):
303 | # feature = features[:, :, ix]
304 | # features[:, :, ix] = (feature - feature.mean()) / feature.std()
305 | # if ix <= 1:
306 | # feature = static_features[:, ix]
307 | # static_features[:, ix] = (feature - feature.mean()) / feature.std()
308 | #
309 | # features_tensor = torch.tensor(features, dtype=torch.float32)
310 | # targets_tensor = torch.tensor(targets, dtype=torch.float32)
311 | # static_features_tensor = torch.tensor(static_features, dtype=torch.float32)
312 | #
313 | # # 创建包含动态特征、静态特征和目标的数据集
314 | # dataset = TensorDataset(features_tensor, static_features_tensor, targets_tensor)
315 | # train_dataset, eval_dataset = random_split(dataset, [8000, 10238 - 8000])
--------------------------------------------------------------------------------
/Core/process_modis.py:
--------------------------------------------------------------------------------
1 | # @Author : ChaoQiezi
2 | # @Time : 2023/12/14 6:31
3 | # @FileName : process_modis.py
4 | # @Email : chaoqiezi.one@qq.com
5 |
6 | """
7 | This script is used to 对MODIS GRID产品(hdf4文件)进行批量镶嵌和重投影并输出为GeoTIFF文件
8 |
9 | <说明>
10 | # pyhdf模块相关
11 | 对于读取HDF4文件的pyhdf模块需要依据python版本安装指定的whl文件才可正常运行,
12 | 下载wheel文件见: https://www.lfd.uci.edu/~gohlke/pythonlibs/
13 | 安装: cmd ==> where python ==> 跳转指定python路径 ==> cd Scripts ==> pip install wheel文件的绝对路径
14 |
15 | # 数据集
16 | MCD12Q1为土地利用数据
17 | MOD11A2为地表温度数据
18 | MOD13A2为植被指数数据(包括NDVI\EVI)
19 |
20 | # 相关链接
21 | CSDN博客: https://blog.csdn.net/m0_63001937/article/details/134995867
22 | 微信博文: https://mp.weixin.qq.com/s/6oeUEdazz8FL1pRnQQFhMA
23 |
24 | """
25 |
26 | import os
27 | import re
28 | import time
29 | from glob import glob
30 | from typing import Union
31 | from datetime import datetime
32 | from math import ceil, floor
33 | from threading import Lock
34 | from concurrent.futures import ThreadPoolExecutor # 线程池
35 |
36 | import numpy as np
37 | from pyhdf.SD import SD
38 | from osgeo import gdal, osr
39 | from scipy import stats
40 |
41 |
def img_mosaic(mosaic_paths: list, mosaic_ds_name: str, return_all: bool = True, img_nodata: Union[int, float] = -1,
               img_type: Union[np.int32, np.float32, None] = None, unit_conversion: bool = False,
               scale_factor_op: str = 'multiply', mosaic_mode: str = 'last'):
    """
    Mosaic one dataset from several MODIS grid HDF4 files into a single array.

    :param mosaic_paths: list of HDF4 file paths to mosaic
    :param mosaic_ds_name: name of the scientific dataset to read from every file
    :param return_all: if True, also return the geotransform and the proj4 string
    :param img_nodata: value used for pixels outside every tile and for out-of-range pixels
    :param img_type: dtype of the mosaic; None keeps the dtype of the dataset in the last file
    :param unit_conversion: if True, apply the dataset's ``scale_factor``/``add_offset`` to valid pixels
    :param scale_factor_op: how ``scale_factor`` is applied when ``unit_conversion`` is True:
        'multiply' (value * scale_factor + add_offset) or 'divide' (value / scale_factor + add_offset)
    :param mosaic_mode: how overlapping valid pixels are combined: 'last' (files later in
        ``mosaic_paths`` overwrite earlier ones), 'mean', 'max' or 'min'
    :return: the mosaic array; with ``return_all`` a tuple
        ``(mosaic, [x_min, x_res, 0, y_max, 0, -y_res], proj4_string)``
    :raises ValueError: if ``mosaic_paths`` is empty or ``mosaic_mode`` / ``scale_factor_op`` is unsupported
    """
    # Validate arguments before touching any file
    if not mosaic_paths:
        raise ValueError('mosaic_paths is empty: nothing to mosaic')
    if mosaic_mode not in ('last', 'mean', 'max', 'min'):
        raise ValueError('不支持的镶嵌模式: {}'.format(mosaic_mode))
    if unit_conversion and scale_factor_op not in ('multiply', 'divide'):
        raise ValueError('Unsupported scale_factor_op: {}'.format(scale_factor_op))

    # --- Pass 1: read each file's corner coordinates to get the mosaic extent ---
    x_mins, x_maxs, y_mins, y_maxs = [], [], [], []
    last_ix = len(mosaic_paths) - 1
    for ix, mosaic_path in enumerate(mosaic_paths):
        hdf = SD(mosaic_path)  # read-only by default
        try:
            metadata = hdf.__getattr__('StructMetadata.0')
            ul_pt = [float(x) for x in re.findall(r'UpperLeftPointMtrs=\((.*)\)', metadata)[0].split(',')]
            lr_pt = [float(x) for x in re.findall(r'LowerRightMtrs=\((.*)\)', metadata)[0].split(',')]
            x_mins.append(ul_pt[0])
            x_maxs.append(lr_pt[0])
            y_mins.append(lr_pt[1])
            y_maxs.append(ul_pt[1])
            if ix == last_ix:
                # Pixel size, dtype and projection come from the last file only,
                # i.e. all files are assumed to share them.
                col = int(re.findall(r'XDim=(.*?)\n', metadata)[0])
                row = int(re.findall(r'YDim=(.*?)\n', metadata)[0])
                x_res = (lr_pt[0] - ul_pt[0]) / col
                y_res = (ul_pt[1] - lr_pt[1]) / row
                if img_type is None:  # default: same dtype as the input dataset
                    img_type = hdf.select(mosaic_ds_name)[:].dtype
                # Projection parameters of the grid, converted to a proj4 string
                projection_param = [float(_param) for _param in
                                    re.findall(r'ProjParams=\((.*?)\)', metadata)[0].split(',')]
                mosaic_img_proj4 = "+proj={} +R={:0.4f} +lon_0={:0.4f} +lat_0={:0.4f} +x_0={:0.4f} " \
                                   "+y_0={:0.4f} ".format('sinu', projection_param[0], projection_param[4],
                                                          projection_param[5], projection_param[6],
                                                          projection_param[7])
        finally:
            hdf.end()  # release every file handle, not only the last one
    x_min, x_max, y_min, y_max = min(x_mins), max(x_maxs), min(y_mins), max(y_maxs)

    # --- Pass 2: place each file's data into its own full-extent grid ---
    col = ceil((x_max - x_min) / x_res)
    row = ceil((y_max - y_min) / y_res)
    mosaic_imgs = []  # one full-extent array per input file
    for ix, mosaic_path in enumerate(mosaic_paths):
        mosaic_img = np.full((row, col), img_nodata, dtype=img_type)
        hdf = SD(mosaic_path)
        try:
            target_ds = hdf.select(mosaic_ds_name)
            try:
                attrs = target_ds.attributes()
                target = target_ds.get().astype(img_type)
                valid_range = attrs['valid_range']
                target[(target < valid_range[0]) | (target > valid_range[1])] = img_nodata  # keep valid range only
                if unit_conversion:
                    scale_factor = attrs['scale_factor']
                    add_offset = attrs['add_offset']
                    valid = target != img_nodata
                    if scale_factor_op == 'multiply':
                        target[valid] = target[valid] * scale_factor + add_offset
                    else:  # 'divide' (validated above)
                        target[valid] = target[valid] / scale_factor + add_offset
            finally:
                target_ds.endaccess()
        finally:
            hdf.end()
        # Grid offsets of this file; the half-pixel shift makes floor() round to the nearest index
        start_row = floor((y_max - (y_maxs[ix] - y_res / 2)) / y_res)
        start_col = floor(((x_mins[ix] + x_res / 2) - x_min) / x_res)
        mosaic_img[start_row:start_row + target.shape[0], start_col:start_col + target.shape[1]] = target
        mosaic_imgs.append(mosaic_img)

    # --- Combine the per-file grids according to mosaic_mode ---
    if mosaic_mode == 'last':
        mosaic_img = mosaic_imgs[0].copy()
        for img in mosaic_imgs[1:]:
            valid = img != img_nodata
            mosaic_img[valid] = img[valid]
    else:
        stack = np.asarray(mosaic_imgs)  # shape = (mosaic_num, rows, cols)
        masked = np.ma.array(stack, mask=stack == img_nodata)
        reducer = {'mean': masked.mean, 'max': masked.max, 'min': masked.min}[mosaic_mode]
        mosaic_img = reducer(axis=0).filled(img_nodata)

    if return_all:
        return mosaic_img, [x_min, x_res, 0, y_max, 0, -y_res], mosaic_img_proj4

    return mosaic_img
148 |
149 |
def img_warp(src_img: np.ndarray, out_path: str, transform: list, src_proj4: str, out_res: float,
             nodata: Union[int, float] = None, resample: str = 'nearest') -> None:
    """
    Reproject a raster array to WGS84 (EPSG:4326) and write it as a GeoTIFF.

    :param src_img: 2-D raster array; integer arrays are written as Int32, floating arrays as Float32
    :param out_path: output GeoTIFF path
    :param transform: GDAL geotransform ``[x_min, x_res, 0, y_max, 0, -y_res]``
        (rotation terms 0 for a north-up grid)
    :param src_proj4: proj4 string describing the coordinate system of ``src_img``
    :param out_res: output pixel size in the target CRS units (degrees), square pixels
    :param nodata: NoData value for input and output; None leaves NoData unset
    :param resample: resampling method: 'nearest' (default), 'bilinear' or 'cubic'
    :return: None
    :raises ValueError: for an unsupported resampling method or array dtype
    """
    # Validate the resampling method up front instead of failing later with a KeyError
    if resample not in ('nearest', 'bilinear', 'cubic'):
        raise ValueError('Unsupported resample method: {} (expected nearest, bilinear or cubic)'.format(resample))

    # Output GDAL data type follows the dtype family of the array
    if np.issubdtype(src_img.dtype, np.integer):
        out_type = gdal.GDT_Int32
    elif np.issubdtype(src_img.dtype, np.floating):
        out_type = gdal.GDT_Float32
    else:
        raise ValueError("当前待校正数组类型为不支持的数据类型")
    resamples = {'nearest': gdal.GRA_NearestNeighbour, 'bilinear': gdal.GRA_Bilinear, 'cubic': gdal.GRA_Cubic}

    # Source dataset, created in memory in the source coordinate system
    driver = gdal.GetDriverByName('MEM')
    src_ds = driver.Create("", src_img.shape[1], src_img.shape[0], 1, out_type)  # columns before rows; 1 band
    srs = osr.SpatialReference()
    srs.ImportFromProj4(src_proj4)
    # For MODIS tiles, StructMetadata.0 gives Projection=GCTP_SNSOID;
    # ProjParams=(6371007.181000,0,0,...), i.e. a sinusoidal grid on a sphere of
    # radius 6371007.181 with central meridian 0.
    src_ds.SetProjection(srs.ExportToWkt())
    src_ds.SetGeoTransform(transform)
    band = src_ds.GetRasterBand(1)
    band.WriteArray(src_img)
    if nodata is not None:  # SetNoDataValue expects a number; the default None means "no NoData"
        band.SetNoDataValue(nodata)

    # Target coordinate system: WGS84
    dst_srs = osr.SpatialReference()
    dst_srs.ImportFromEPSG(4326)
    dst_ds = gdal.Warp(out_path, src_ds, dstSRS=dst_srs, xRes=out_res, yRes=out_res, dstNodata=nodata,
                       outputType=out_type, multithread=True, format='GTiff', resampleAlg=resamples[resample])
    if dst_ds:
        dst_ds.FlushCache()
    # Dropping the references closes both datasets
    band, src_ds, dst_ds = None, None, None
199 |
200 |
def ydays2ym(file_path: str) -> str:
    """
    Convert the acquisition date embedded in a MODIS file name to ``"YYYY_MM"``.

    Characters 9-15 of the base name are parsed as ``YYYYDDD`` (year followed by
    day of year), e.g. ``MOD11A2.A2000049.h26v05.006.hdf`` -> ``"2000_02"``.

    :param file_path: path (or bare file name) of the MODIS file
    :return: year and month of the acquisition date, formatted as ``"%Y_%m"``
    :raises ValueError: if those characters do not form a valid ``%Y%j`` date
    """
    year_doy = os.path.basename(file_path)[9:16]
    return datetime.strptime(year_doy, "%Y%j").strftime("%Y_%m")
213 |
214 |
# Closure factory: binds the processing settings once and returns a per-id worker
def process_task(union_id, process_paths, ds_name, out_dir, description, nodata, out_res, resamlpe='nearest',
                 temperature=False, img_type=np.float32, unit_conversion=True, scale_factor_op='multiply',
                 mosaic_mode='last'):
    """
    Build a worker that mosaics and reprojects every file sharing one id.

    :param union_id: id of each entry of ``process_paths`` (same order); files with equal ids are mosaicked together
    :param process_paths: HDF4 file paths
    :param ds_name: name of the dataset to mosaic
    :param out_dir: output directory
    :param description: output file-name prefix; results are written as ``<description>_<id>.tiff``
    :param nodata: NoData value used for mosaicking and reprojection
    :param out_res: output resolution forwarded to ``img_warp``
    :param resamlpe: resampling method forwarded to ``img_warp`` (parameter name kept for compatibility)
    :param temperature: if True, subtract 273.15 from valid pixels (Kelvin -> Celsius)
    :param img_type: forwarded to ``img_mosaic``
    :param unit_conversion: forwarded to ``img_mosaic``
    :param scale_factor_op: forwarded to ``img_mosaic``
    :param mosaic_mode: forwarded to ``img_mosaic``
    :return: ``process_id(id)`` callable; its timing prints share one lock
    """
    print_lock = Lock()  # keeps prints from concurrent worker calls from interleaving

    def process_id(id: any = None):
        t0 = time.time()
        # Mosaic every file whose id matches
        mosaic_paths = [process_paths[ix] for ix, cur_id in enumerate(union_id) if cur_id == id]
        mosaic_img, transform, mosaic_img_proj4 = img_mosaic(
            mosaic_paths, ds_name, img_nodata=nodata, img_type=img_type, unit_conversion=unit_conversion,
            scale_factor_op=scale_factor_op, mosaic_mode=mosaic_mode)
        if temperature:
            # Land-surface temperature: Kelvin -> Celsius on valid pixels only
            mosaic_img[mosaic_img != nodata] -= 273.15
        # Reproject and write the GeoTIFF
        out_name = description + '_' + id + '.tiff'
        img_warp(mosaic_img, os.path.join(out_dir, out_name), transform, mosaic_img_proj4, out_res, nodata,
                 resample=resamlpe)
        elapsed = time.time() - t0

        with print_lock:
            print("{}-{} 处理完毕: {:0.2f}s".format(description, id, elapsed))

    return process_id
241 |
242 |
def _run_parallel(process_id, ids, done_msg):
    """
    Run ``process_id`` for every id on a thread pool, then print ``done_msg`` with the elapsed time.

    The iterator returned by ``Executor.map`` is consumed. This waits for every task before the
    clock is stopped and re-raises the first exception raised by a worker. Previously the iterator
    was discarded, so failed groups were dropped silently.

    :param process_id: worker built by ``process_task``
    :param ids: group ids to process
    :param done_msg: format string with a single ``{:0.2f}`` field for the elapsed seconds
    """
    start_time = time.time()
    with ThreadPoolExecutor() as executor:
        list(executor.map(process_id, ids))
    end_time = time.time()
    print(done_msg.format(end_time - start_time))


# Configuration. Raw strings: '\D', '\O', '\V' are invalid escape sequences (SyntaxWarning since Python 3.12).
in_dir = r'F:\DATA\Cy_modis'  # e.g. F:\Cy_modis\MCD12Q1_2001_2020, F:\Cy_modis\MOD11A2_2000_2022, F:\Cy_modis\MOD13A2_2001_2020
out_dir = r'H:\Datasets\Objects\Veg'
landuse_name = 'LC_Type1'  # Land Cover Type 1: Annual International Geosphere-Biosphere Programme (IGBP) classification
lst_name = 'LST_Day_1km'
ndvi_name = '1 km 16 days NDVI'  # Panoply shows 1_km_16_days_NDVI (display only); the raw dataset name is this one
evi_name = '1 km 16 days EVI'  # Panoply shows 1_km_16_days_EVI (display only); the raw dataset name is this one
out_landuse_res = 0.0045  # ~500 m
out_lst_res = 0.009  # ~1000 m
out_ndvi_res = 0.009
out_evi_res = 0.009
# Output directories
out_landuse_dir = os.path.join(out_dir, 'Landuse')
out_lst_dir = os.path.join(out_dir, 'LST_MIN')
out_ndvi_dir = os.path.join(out_dir, 'NDVI_MIN')
out_evi_dir = os.path.join(out_dir, 'evi')
for _dir in (out_landuse_dir, out_lst_dir, out_ndvi_dir, out_evi_dir):
    os.makedirs(_dir, exist_ok=True)

# # MCD12Q1 (land use): mosaic + reprojection (GLT correction) -- currently disabled
# landuse_paths = glob(os.path.join(in_dir, '**', 'MCD12Q1*.hdf'), recursive=True)
# union_id = [os.path.basename(_path)[9:13] for _path in landuse_paths]  # group by year
# process_id = process_task(union_id, landuse_paths, landuse_name, out_landuse_dir, 'Landuse', 255, out_landuse_res,
#                           img_type=np.int32, unit_conversion=False)
# _run_parallel(process_id, set(union_id), 'MCD12Q1(土地利用数据集)预处理完毕: {:0.2f}s ')

# MOD11A2 (land surface temperature): monthly minimum mosaic + reprojection (GLT correction)
lst_paths = glob(os.path.join(in_dir, '**', 'MOD11A2*.hdf'), recursive=True)
union_id = [ydays2ym(_path) for _path in lst_paths]  # group granules by acquisition year_month
unique_id = set(union_id)
process_id = process_task(union_id, lst_paths, lst_name, out_lst_dir, 'LST_MIN', -65535, out_lst_res, resamlpe='cubic',
                          temperature=True, unit_conversion=True, mosaic_mode='min')
_run_parallel(process_id, unique_id, 'MOD11A2(地表温度数据集)预处理完毕: {:0.2f}s')

# MOD13A2 (NDVI): monthly minimum mosaic + reprojection (GLT correction)
ndvi_paths = glob(os.path.join(in_dir, '**', 'MOD13A2*.hdf'), recursive=True)
union_id = [ydays2ym(_path) for _path in ndvi_paths]
unique_id = np.unique(np.asarray(union_id))
process_id = process_task(union_id, ndvi_paths, ndvi_name, out_ndvi_dir, 'NDVI_MIN', -65535, out_ndvi_res,
                          resamlpe='cubic', unit_conversion=True, scale_factor_op='divide', mosaic_mode='min')
_run_parallel(process_id, unique_id, 'MOD13A2(NDVI数据集)预处理完毕: {:0.2f}s')

# MOD13A2 (EVI): monthly maximum mosaic + reprojection (GLT correction).
# EVI is read from the same MOD13A2 files as NDVI, so the file list is reused instead of globbing again.
evi_paths = ndvi_paths
union_id = [ydays2ym(_path) for _path in evi_paths]
unique_id = np.unique(np.asarray(union_id))
process_id = process_task(union_id, evi_paths, evi_name, out_evi_dir, 'EVI', -65535, out_evi_res,
                          resamlpe='cubic', unit_conversion=True, scale_factor_op='divide', mosaic_mode='max')
_run_parallel(process_id, unique_id, 'MOD13A2(EVI数据集)预处理完毕: {:0.2f}s ')
--------------------------------------------------------------------------------
/Assets/LSTM-master/README.md:
--------------------------------------------------------------------------------
1 | ```python
2 | import torch
3 | import torch as t
4 | import numpy as np
5 | import torch.nn as nn
6 | import torch.optim as optim
7 | from torchnet import meter
8 | import xarray as xr
9 | import rioxarray as rxr
10 | ```
11 |
12 |
13 | ```python
14 | torch.cuda.is_available()
15 | ```
16 |
17 |
18 |
19 |
20 | True
21 |
22 |
23 |
24 |
25 | ```python
# Load the precipitation raster. rioxarray's open_rasterio yields a (band, y, x) array;
# presumably one band per time step -- TODO confirm.
precipitation_data = rxr.open_rasterio('data/prcp.tif').values

# Convert the data to a PyTorch tensor
precipitation_data = torch.tensor(precipitation_data, dtype=torch.float32)

# Standardise every pixel over the first (band) axis. Pixels with zero std compute 0/0 and
# become NaN; NaN rows are filtered out in a later cell.
precipitation_mean = torch.mean(precipitation_data, 0)
precipitation_std = torch.std(precipitation_data, 0)
precipitation = (precipitation_data - precipitation_mean) / precipitation_std

# One row per pixel, one column per band: (n_bands, y, x) -> (n_bands, y*x) -> (y*x, n_bands).
# The band count is read from the data instead of being hard-coded to 183.
precipitation_re = precipitation.reshape(precipitation.shape[0], -1).transpose(0, 1)
36 | ```
37 |
38 |
39 | ```python
40 | from utils import plot
41 | import matplotlib.pyplot as plt
42 | import cartopy.crs as ccrs
43 | import numpy as np
44 | import xarray as xr
# Quick-look map of one raster band (index 7 along the first axis) on a Robinson projection.
# NOTE(review): this reads 'data/train/prcp.tif' while the loading cell reads 'data/prcp.tif' -- confirm the path.
file_name='data/train/prcp.tif'
# NOTE(review): the 'band_data' variable suggests xarray opens the GeoTIFF through the rioxarray engine -- confirm it is installed.
ds=xr.open_dataset(file_name)
data = ds['band_data'][7]

fig = plt.figure()
proj = ccrs.Robinson() # alternative projection: ccrs.Mollweide()
ax = fig.add_subplot(111, projection=proj)
# 9 colour levels evenly spaced over [0, 0.25]
levels = np.linspace(0, 0.25, num=9)
# one_map_flat comes from the project's utils.plot module (imported at the top of this cell)
plot.one_map_flat(data, ax, levels=levels, cmap="RdBu", mask_ocean=False, add_coastlines=True, add_land=True, colorbar=True, plotfunc="pcolormesh")
54 | ```
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 | 
66 |
67 |
68 |
69 |
70 | ```python
# Keep only the pixels (rows of precipitation_re) whose series contains no NaN.
import random
# True for every row that contains at least one NaN (e.g. zero-std pixels after standardisation)
matrix = torch.isnan(precipitation_re).any(dim=1)
# Indices of the valid rows, shape (n_valid, 1) like torch.nonzero; later cells sample from them.
# The previous test (temporal mean != 0) was fragile: after per-pixel standardisation the temporal
# mean of every valid row is ~0, so valid pixels could be dropped depending on float round-off.
non_negative_indices = torch.nonzero(~matrix)
precipitation_re = precipitation_re[non_negative_indices.flatten(), :]
80 | ```
81 |
82 |
83 | ```python
class Config:
    """Experiment hyper-parameters, read as class attributes (``Config.t0``, ``Config.lr``, ...)."""

    # Sequence layout: t0 observed steps followed by t1 steps to forecast
    t0: int = 155
    t1: int = 12
    t: int = t0 + t1  # total length of every sequence

    # Number of sampled pixel series per split
    train_num: int = 8000
    validation_num: int = 1000
    test_num: int = 1000

    in_channels: int = 1  # features per time step
    batch_size: int = 500  # original note: 500 -> NSE 0.75
    lr: float = 5e-4  # learning rate
    epochs: int = 100
95 | ```
96 |
97 |
98 | ```python
99 | import torch
100 | import matplotlib.pyplot as plt
101 | import numpy as np
102 | from torch.utils.data import Dataset
103 |
class time_series_decoder_paper(Dataset):
    """Dataset of fixed-length series plus the additive attention mask shared by all items.

    Each item is ``(x[idx, :, :], fx[idx, :], mask)``.
    """

    def __init__(self, t0=120, N=4500, dx=None, dy=None, transform=None):
        """
        Args:
            t0: number of leading (observed) time steps; the remaining steps are forecast
            N: nominal number of samples (kept for reference; the length comes from ``dy``)
            dx: time covariate tensor, indexed as ``dx[idx, :, :]``
            dy: observed series tensor, indexed as ``dy[idx, :]``
            transform: optional callable applied to every sample tuple
        """
        self.t0 = t0
        self.N = N
        self.dx = dx
        self.dy = dy
        # Previously hard-coded to None, silently ignoring the ``transform`` argument.
        self.transform = transform

        # time points and observed values
        self.x = dx
        self.fx = dy

        # The mask size follows the sequence length of dx (falls back to Config.t when dx is None).
        seq_len = dx.shape[-1] if dx is not None else None
        self.masks = self._generate_square_subsequent_mask(t0, seq_len)

        # print out shapes to confirm desired output
        print("x: ", self.x.shape,
              "fx: ", self.fx.shape)

    def __len__(self):
        return len(self.fx)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        sample = (self.x[idx, :, :],
                  self.fx[idx, :],
                  self.masks)

        if self.transform:
            sample = self.transform(sample)

        return sample

    def _generate_square_subsequent_mask(self, t0, t=None):
        """Build the (t, t) additive attention mask: 0 = may attend, -inf = blocked.

        Rows ``i < t0`` (observed steps) may attend to every observed step but not to steps
        ``>= t0``; rows ``i >= t0`` are causal and may only attend to steps ``<= i``.

        Args:
            t0: number of observed steps (previously ignored in favour of ``Config.t0``)
            t: total sequence length; defaults to ``Config.t``
        """
        if t is None:
            t = Config.t
        rows = torch.arange(t).unsqueeze(1)
        cols = torch.arange(t).unsqueeze(0)
        blocked = torch.where(rows < t0, cols >= t0, cols > rows)
        return torch.zeros(t, t, dtype=torch.float32).masked_fill(blocked, float('-inf'))
165 | ```
166 |
167 |
168 | ```python
class TransformerTimeSeries(torch.nn.Module):
    """
    Time Series application of transformers based on paper

    Layers built in __init__ (sizes are hard-coded):
        input_embedding: causal convolution (context_embedding) from Config.in_channels + 1
            input channels (observed series + time covariate) to 256 channels, kernel width 5
        positional_embedding: learned nn.Embedding(512, 256) indexed by the integer time
            covariate, so time values presumably must lie in [0, 512) -- TODO confirm
        transformer_decoder: 3 stacked nn.TransformerEncoderLayer(d_model=256, nhead=8); despite
            the name these are encoder layers, so any causal masking comes from attention_masks
        fc1: linear head mapping each 256-d state to one value

    forward(x, y, attention_masks) returns one prediction per time step.
    """
    def __init__(self):
        super(TransformerTimeSeries,self).__init__()
        self.input_embedding = context_embedding(Config.in_channels+1,256,5)
        self.positional_embedding = torch.nn.Embedding(512,256)


        self.decode_layer = torch.nn.TransformerEncoderLayer(d_model=256,nhead=8)
        self.transformer_decoder = torch.nn.TransformerEncoder(self.decode_layer, num_layers=3)

        self.fc1 = torch.nn.Linear(256,1)

    def forward(self,x,y,attention_masks):

        # Shape comments below assume x and y are (B, 1, T), as built by time_series_decoder_paper -- TODO confirm.
        # concatenate observed points and time covariate along dim 1 -> (B, 2, T)
        #re z = torch.cat((y.unsqueeze(1),x.unsqueeze(1)),1)
        z = torch.cat((y,x),1)
        # input_embedding returns (B, 256, T); unsqueeze(1) -> (B, 1, 256, T); permute(3, 1, 0, 2) -> (T, 1, B, 256)
        #re z_embedding = self.input_embedding(z).permute(2,0,1)
        z_embedding = self.input_embedding(z).unsqueeze(1).permute(3, 1, 0, 2)
        # positional embeddings are looked up with the time covariate cast to integer indices;
        # negative values are clamped to 0 (presumably because nn.Embedding needs non-negative indices)
        x1 = x.type(torch.long)
        x1[x1 < 0] = 0
        # (B, 1, T) indices -> (B, 1, T, 256) -> permute(2, 1, 0, 3) -> (T, 1, B, 256)
        positional_embeddings = self.positional_embedding(x1).permute(2, 1, 0, 3)
        #re #positional_embeddings = self.positional_embedding(x.type(torch.long)).permute(1,0,2)

        input_embedding = z_embedding+positional_embeddings
        # average over the singleton dim 1 -> (T, B, 256), the sequence-first layout the encoder uses by default
        input_embedding1 = torch.mean(input_embedding, 1)
        transformer_embedding = self.transformer_decoder(input_embedding1,attention_masks)

        # back to batch-first (B, T, 256), then fc1 -> (B, T, 1)
        output = self.fc1(transformer_embedding.permute(1,0,2))

        return output
219 | import torch
220 | import numpy as np
221 | import matplotlib.pyplot as plt
222 | import torch.nn.functional as F
223 |
class CausalConv1d(torch.nn.Conv1d):
    """1-D convolution in which output step ``i`` only sees input steps ``<= i``.

    The input is left-padded with ``(kernel_size - 1) * dilation`` zeros in ``forward``, so with
    ``stride=1`` the output has the same length as the input.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=True):
        # The base layer gets no padding; the left-only padding is applied manually in forward.
        super().__init__(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
                         padding=0, dilation=dilation, groups=groups, bias=bias)
        left_pad = dilation * (kernel_size - 1)
        self.__padding = left_pad

    def forward(self, input):
        padded = F.pad(input, (self.__padding, 0))  # pad the time axis on the left only
        return super().forward(padded)
248 |
249 |
class context_embedding(torch.nn.Module):
    """Causal 1-D convolution from ``in_channels`` to ``embedding_size`` channels, followed by tanh."""

    def __init__(self, in_channels=Config.in_channels, embedding_size=256, k=5):
        super().__init__()
        self.causal_convolution = CausalConv1d(in_channels, embedding_size, kernel_size=k)

    def forward(self, x):
        return torch.tanh(self.causal_convolution(x))
258 | ```
259 |
260 |
261 | ```python
class LSTM_Time_Series(torch.nn.Module):
    """Causal-convolution embedding + single-layer LSTM + linear head, one output per time step.

    Args:
        input_size: channels of the concatenated (observed, covariate) input
        embedding_size: channels produced by the causal convolution embedding
        kernel_width: width of the causal convolution kernel
        hidden_size: LSTM hidden size
    """

    def __init__(self, input_size=2, embedding_size=256, kernel_width=9, hidden_size=512):
        super(LSTM_Time_Series, self).__init__()

        self.input_embedding = context_embedding(input_size, embedding_size, kernel_width)

        self.lstm = torch.nn.LSTM(embedding_size, hidden_size, batch_first=True)

        self.fc1 = torch.nn.Linear(hidden_size, 1)

    def forward(self, x, y):
        """
        x: the time covariate
        y: the observed target

        Returns ``fc1`` applied to every LSTM output step, i.e. (B, sequence length, 1).
        """
        # Concatenate observed points and time covariate along dim 1. The notebook-level flag
        # ``isLSTM`` selects inputs that already have a channel axis (presumably (B, 1, T) from
        # time_series_decoder_paper); otherwise one is added to each input.
        if isLSTM:
            z_obs = torch.cat((y, x), 1)
        else:
            z_obs = torch.cat((y.unsqueeze(1), x.unsqueeze(1)), 1)

        # Embed once: (B, embedding size, T) -> permute -> (B, T, embedding size) for the
        # batch_first LSTM. The previous version ran the convolution twice and discarded one result.
        z_obs_embedding = self.input_embedding(z_obs).permute(0, 2, 1)

        # all hidden states from lstm: (B, T, hidden size)
        lstm_out, _ = self.lstm(z_obs_embedding)

        # input to nn.Linear: (N, *, Hin); output (N, *, Hout)
        return self.fc1(lstm_out)
296 | ```
297 |
298 |
299 | ```python
300 | from torch.utils.data import DataLoader
301 | import random
random.seed(0)

# Draw ONE sample without replacement and split it, so the training, validation and test pixels
# are disjoint. (Three independent random.sample calls could put the same pixel in several
# splits, leaking evaluation data into training.)
n_pixels = non_negative_indices.shape[0]
n_total = Config.train_num + Config.validation_num + Config.test_num
if n_total > n_pixels:
    raise ValueError('Only {} valid pixels, but {} are needed for disjoint train/validation/test splits'
                     .format(n_pixels, n_total))
shuffled = random.sample(range(n_pixels), n_total)
random_indices = shuffled[:Config.train_num]
random_indices1 = shuffled[Config.train_num:Config.train_num + Config.validation_num]
random_indices2 = shuffled[Config.train_num + Config.validation_num:]


def _time_covariate(n):
    """Return the time index 0..Config.t-1 repeated for n series, shape (n, 1, Config.t), on the GPU."""
    return torch.arange(0, Config.t).type(torch.float).repeat(n, 1).unsqueeze(1).cuda()


dx = _time_covariate(Config.train_num)
dx1 = _time_covariate(Config.validation_num)
dx2 = _time_covariate(Config.test_num)
train_dataset = time_series_decoder_paper(t0=Config.t0, N=Config.train_num, dx=dx,
                                          dy=precipitation_re[random_indices, 0:Config.t].unsqueeze(1))
validation_dataset = time_series_decoder_paper(t0=Config.t0, N=Config.validation_num, dx=dx1,
                                               dy=precipitation_re[random_indices1, 0:Config.t].unsqueeze(1))
test_dataset = time_series_decoder_paper(t0=Config.t0, N=Config.test_num, dx=dx2,
                                         dy=precipitation_re[random_indices2, 0:Config.t].unsqueeze(1))
313 | ```
314 |
315 | x: torch.Size([8000, 1, 167]) fx: torch.Size([8000, 1, 167])
316 | x: torch.Size([1000, 1, 167]) fx: torch.Size([1000, 1, 167])
317 | x: torch.Size([1000, 1, 167]) fx: torch.Size([1000, 1, 167])
318 |
319 |
320 |
321 | ```python
# Loss used by train_epoch / eval_epoch (they reference the global name ``criterion``).
criterion = torch.nn.MSELoss()
# Only the training loader shuffles. Each loader gets an explicit CPU generator
# (NOTE(review): presumably needed because a CUDA default device/tensor type is set elsewhere -- confirm).
train_dl = DataLoader(train_dataset,batch_size=Config.batch_size,shuffle=True, generator=torch.Generator(device='cpu'))
validation_dl = DataLoader(validation_dataset,batch_size=Config.batch_size, generator=torch.Generator(device='cpu'))
test_dl = DataLoader(test_dataset,batch_size=Config.batch_size, generator=torch.Generator(device='cpu'))
326 | ```
327 |
328 |
329 | ```python
# NOTE(review): not referenced by train_epoch / eval_epoch, which use ``criterion`` -- candidate for removal.
criterion_LSTM = torch.nn.MSELoss()
331 | ```
332 |
333 |
334 | ```python
# Instantiate the LSTM model with its default sizes and move it to the GPU.
LSTM = LSTM_Time_Series().cuda()
336 | ```
337 |
338 |
339 | ```python
def Dp(y_pred, y_true, q):
    """Asymmetric absolute error of one prediction.

    With ``e = y_pred - y_true`` this is the larger of ``q * e`` and ``(q - 1) * e``;
    for ``q = 0.5`` it equals ``0.5 * |e|``.
    """
    error = y_pred - y_true
    over = q * error
    under = (q - 1) * error
    return over if over >= under else under
342 |
def Rp_num_den(y_preds, y_trues, q):
    """Return the numerator and denominator of the Rp metric for one series.

    numerator   = sum of Dp(pred, true, q) over the paired values
    denominator = sum of |true|
    """
    losses = [Dp(pred, true, q) for pred, true in zip(y_preds, y_trues)]
    magnitudes = [np.abs(true) for true in y_trues]
    return np.sum(losses), np.sum(magnitudes)
def train_epoch(LSTM, train_dl, t0=Config.t0):
    """Run one optimisation epoch and return the sample-weighted mean training loss.

    The model outputs at steps ``t0-1 .. T-2`` are compared with the observations at steps
    ``t0 .. T-1`` (a one-step shift, presumably one-step-ahead prediction). Uses the
    notebook-level ``optimizer`` and ``criterion``. Returns NaN for an empty loader.
    """
    LSTM.train()
    train_loss = 0
    n = 0
    for x, y, _ in train_dl:
        x = x.cuda()
        y = y.cuda()

        optimizer.zero_grad()
        output = LSTM(x, y)  # (B, T, 1)

        # ``output[:, :, 0]`` instead of ``squeeze()`` keeps the batch axis for a batch of size 1,
        # and the ``t0`` argument is honoured (it used to be ignored in favour of Config.t0).
        loss = criterion(output[:, t0 - 1:-1, 0], y[:, 0, t0:])
        loss.backward()
        optimizer.step()

        # weight by batch size so the result is a mean over samples, not over batches
        train_loss += loss.detach().cpu().item() * x.shape[0]
        n += x.shape[0]
    return train_loss / n if n else float('nan')
def eval_epoch(LSTM, validation_dl, t0=Config.t0):
    """Return the sample-weighted mean loss of the model on ``validation_dl`` (no parameter updates).

    The model outputs at steps ``t0-1 .. T-2`` are compared with the observations at steps
    ``t0 .. T-1`` using the notebook-level ``criterion``. Returns NaN for an empty loader.
    """
    LSTM.eval()
    eval_loss = 0
    n = 0
    with torch.no_grad():
        # Iterate the loader that was passed in. This used to read the global ``train_dl``,
        # so the reported "Eval loss" was measured on the training data.
        for x, y, _ in validation_dl:
            x = x.cuda()
            y = y.cuda()

            output = LSTM(x, y)  # (B, T, 1)
            loss = criterion(output[:, t0 - 1:-1, 0], y[:, 0, t0:])

            eval_loss += loss.detach().cpu().item() * x.shape[0]
            n += x.shape[0]

    return eval_loss / n if n else float('nan')
def test_epoch(LSTM, test_dl, t0=Config.t0, q=.5):
    """Return the Rp metric of the model on ``test_dl``.

    Rp = 2 * sum(Dp(pred, obs, q)) / sum(|obs|) over every forecast value of every series,
    where predictions are the outputs at steps ``t0-1 .. T-2`` and observations the values at
    ``t0 .. T-1``. With the default ``q = 0.5`` this is sum(|pred - obs|) / sum(|obs|).
    """
    predictions = []
    observations = []
    with torch.no_grad():
        LSTM.eval()
        # Iterate the loader that was passed in (this used to read the global ``train_dl``).
        for x, y, _ in test_dl:
            x = x.cuda()
            y = y.cuda()

            output = LSTM(x, y)  # (B, T, 1)

            predictions.extend(output[:, t0 - 1:-1, 0].cpu().numpy().tolist())
            observations.extend(y[:, 0, t0:].cpu().numpy().tolist())

    num = 0
    den = 0
    for y_preds, y_trues in zip(predictions, observations):
        num_i, den_i = Rp_num_den(y_preds, y_trues, q)
        num += num_i
        den += den_i
    Rp = (2 * num) / den

    return Rp
407 | ```
408 |
409 |
410 | ```python
train_epoch_loss = []
eval_epoch_loss = []
Rp_best = 30  # lowest Rp seen so far
# NOTE(review): Rp is computed on the test split; selecting a model by Rp_best would select on test data.
isLSTM = True  # read by LSTM_Time_Series.forward: inputs already carry a channel axis
optimizer = torch.optim.Adam(LSTM.parameters(), lr=Config.lr)

for e in range(Config.epochs):
    l_train = train_epoch(LSTM, train_dl)
    l_eval = eval_epoch(LSTM, validation_dl)
    Rp = test_epoch(LSTM, test_dl)

    if Rp_best > Rp:
        Rp_best = Rp

    print("Epoch {}: Train loss={} \t Eval loss = {} \t Rp={}".format(e, l_train, l_eval, Rp))

    train_epoch_loss.append(l_train)
    eval_epoch_loss.append(l_eval)
437 | ```
438 |
439 | Epoch 0: Train loss=1.169178232550621 Eval loss = 1.0225972533226013 Rp=0.9754564549482763
440 | Epoch 1: Train loss=1.0208212696015835 Eval loss = 1.0133976340293884 Rp=0.992397172460293
441 | Epoch 2: Train loss=1.0135348699986935 Eval loss = 1.0102826319634914 Rp=1.0010193148701694
442 | Epoch 3: Train loss=1.0075771994888783 Eval loss = 1.003737311810255 Rp=0.9874602462009079
443 | Epoch 4: Train loss=0.9986509047448635 Eval loss = 0.991193663328886 Rp=0.9851476850727866
444 | Epoch 5: Train loss=0.9815127141773701 Eval loss = 0.9676672779023647 Rp=0.9776791701859356
445 | Epoch 6: Train loss=0.9377330988645554 Eval loss = 0.8851083293557167 Rp=0.9021427715861259
446 | Epoch 7: Train loss=0.8124373629689217 Eval loss = 0.776776347309351 Rp=0.8164312307396333
447 | Epoch 8: Train loss=0.7808051072061062 Eval loss = 0.7724409475922585 Rp=0.8044285279275872
448 | Epoch 9: Train loss=0.7723440378904343 Eval loss = 0.7691570967435837 Rp=0.8276665115435272
449 | Epoch 10: Train loss=0.7680666074156761 Eval loss = 0.7604397684335709 Rp=0.8172812960763582
450 | Epoch 11: Train loss=0.7637499608099461 Eval loss = 0.7642757333815098 Rp=0.7958800149846623
451 | Epoch 12: Train loss=0.7604391016066074 Eval loss = 0.7545832060277462 Rp=0.7959258052455023
452 | Epoch 13: Train loss=0.7542793937027454 Eval loss = 0.758263424038887 Rp=0.8029712542553213
453 | Epoch 14: Train loss=0.7513296827673912 Eval loss = 0.74464987590909 Rp=0.7928307957450957
454 | Epoch 15: Train loss=0.7609197050333023 Eval loss = 0.7561161443591118 Rp=0.7961394733618599
455 | Epoch 16: Train loss=0.7611901015043259 Eval loss = 0.754481915384531 Rp=0.8042048513261087
456 | Epoch 17: Train loss=0.7494660168886185 Eval loss = 0.7436127960681915 Rp=0.808727372545216
457 | Epoch 18: Train loss=0.7624928876757622 Eval loss = 0.7601931132376194 Rp=0.8278642644450237
458 | Epoch 19: Train loss=0.7445684559643269 Eval loss = 0.7404011972248554 Rp=0.7988691222148593
459 | Epoch 20: Train loss=0.7364756055176258 Eval loss = 0.7331099547445774 Rp=0.7900263164250347
460 | Epoch 21: Train loss=0.7366516776382923 Eval loss = 0.7335694879293442 Rp=0.7858690993104261
461 | Epoch 22: Train loss=0.7310461960732937 Eval loss = 0.7295852825045586 Rp=0.7947732982354974
462 | Epoch 23: Train loss=0.7325067669153214 Eval loss = 0.7321035303175449 Rp=0.7994415259947144
463 | Epoch 24: Train loss=0.7324810847640038 Eval loss = 0.7215580977499485 Rp=0.7793204264401973
464 | Epoch 25: Train loss=0.7343184538185596 Eval loss = 0.7716133445501328 Rp=0.8319286623223218
465 | Epoch 26: Train loss=0.7366975098848343 Eval loss = 0.7249130606651306 Rp=0.77453832290699
466 | Epoch 27: Train loss=0.7278863601386547 Eval loss = 0.720306035131216 Rp=0.7780187017935781
467 | Epoch 28: Train loss=0.7243384085595608 Eval loss = 0.715414222329855 Rp=0.7660441378673354
468 | Epoch 29: Train loss=0.7391963303089142 Eval loss = 0.8104664944112301 Rp=0.848384596737658
469 | Epoch 30: Train loss=0.7501446716487408 Eval loss = 0.7330859526991844 Rp=0.7958694240958736
470 | Epoch 31: Train loss=0.7319861426949501 Eval loss = 0.7288344763219357 Rp=0.7764215311271584
471 | Epoch 32: Train loss=0.7289896085858345 Eval loss = 0.7186561860144138 Rp=0.7719541082002876
472 | Epoch 33: Train loss=0.7208635434508324 Eval loss = 0.7130853533744812 Rp=0.7693201078171342
473 | Epoch 34: Train loss=0.7188350297510624 Eval loss = 0.7184220626950264 Rp=0.7790605390063141
474 | Epoch 35: Train loss=0.7278616651892662 Eval loss = 0.7226458676159382 Rp=0.7934061207417414
475 | Epoch 36: Train loss=0.7257222011685371 Eval loss = 0.7461872175335884 Rp=0.8043810986938649
476 | Epoch 37: Train loss=0.722360011190176 Eval loss = 0.7184372805058956 Rp=0.7733680838057659
477 | Epoch 38: Train loss=0.7406770437955856 Eval loss = 0.7328226044774055 Rp=0.7792520400962948
478 | Epoch 39: Train loss=0.7231648564338684 Eval loss = 0.7170744873583317 Rp=0.7745043319879358
479 | Epoch 40: Train loss=0.7179877758026123 Eval loss = 0.7121099643409252 Rp=0.7665604319856079
480 | Epoch 41: Train loss=0.7204379811882973 Eval loss = 0.7137661874294281 Rp=0.7697876723151911
481 | Epoch 42: Train loss=0.7189657613635063 Eval loss = 0.7145831622183323 Rp=0.7683013184528533
482 | Epoch 43: Train loss=0.7236194014549255 Eval loss = 0.7213072367012501 Rp=0.7794286898969346
483 | Epoch 44: Train loss=0.7113666497170925 Eval loss = 0.7064446061849594 Rp=0.7730508504910318
484 | Epoch 45: Train loss=0.7169638313353062 Eval loss = 0.7072659730911255 Rp=0.7818726340368028
485 | Epoch 46: Train loss=0.7137239314615726 Eval loss = 0.7547547481954098 Rp=0.8419246011202801
486 | Epoch 47: Train loss=0.7260657027363777 Eval loss = 0.7107443884015083 Rp=0.7655332919045247
487 | Epoch 48: Train loss=0.7079816907644272 Eval loss = 0.7071417346596718 Rp=0.773454930517809
488 | Epoch 49: Train loss=0.7094167955219746 Eval loss = 0.7058326154947281 Rp=0.7727462602735332
489 | Epoch 50: Train loss=0.7380903214216232 Eval loss = 0.7227592132985592 Rp=0.769177665270142
490 | Epoch 51: Train loss=0.7130068391561508 Eval loss = 0.7105456776916981 Rp=0.7592860561025371
491 | Epoch 52: Train loss=0.7084374688565731 Eval loss = 0.7031594552099705 Rp=0.7660899650703171
492 | Epoch 53: Train loss=0.7042888924479485 Eval loss = 0.7040572166442871 Rp=0.7721988413128251
493 | Epoch 54: Train loss=0.7063969075679779 Eval loss = 0.6986251175403595 Rp=0.7695131487850577
494 | Epoch 55: Train loss=0.7053375691175461 Eval loss = 0.7212032824754715 Rp=0.7735872477697779
495 | Epoch 56: Train loss=0.7035926096141338 Eval loss = 0.7097134478390217 Rp=0.784710594937848
496 | Epoch 57: Train loss=0.735156461596489 Eval loss = 0.8654714487493038 Rp=0.911015246823642
497 | Epoch 58: Train loss=0.8107807412743568 Eval loss = 0.7581384815275669 Rp=0.8157151369086734
498 | Epoch 59: Train loss=0.7544550113379955 Eval loss = 0.7499602921307087 Rp=0.7949385155087615
499 | Epoch 60: Train loss=0.746234655380249 Eval loss = 0.7389007620513439 Rp=0.7862202786182774
500 | Epoch 61: Train loss=0.7351461201906204 Eval loss = 0.7316578552126884 Rp=0.7699839364702437
501 | Epoch 62: Train loss=0.7276912368834019 Eval loss = 0.7233806289732456 Rp=0.7752237105789992
502 | Epoch 63: Train loss=0.7192910276353359 Eval loss = 0.7333457358181477 Rp=0.7724518923905164
503 | Epoch 64: Train loss=0.7260234951972961 Eval loss = 0.7129928097128868 Rp=0.7660325583116143
504 | Epoch 65: Train loss=0.7277122773230076 Eval loss = 0.7464463748037815 Rp=0.7795688062338593
505 | Epoch 66: Train loss=0.7259447425603867 Eval loss = 0.7074765078723431 Rp=0.7646423879285662
506 | Epoch 67: Train loss=0.7121074683964252 Eval loss = 0.7206148467957973 Rp=0.7670712925578708
507 | Epoch 68: Train loss=0.7051395028829575 Eval loss = 0.7368365041911602 Rp=0.8174581520496608
508 | Epoch 69: Train loss=0.7655579410493374 Eval loss = 0.7538384310901165 Rp=0.7762011659936345
509 | Epoch 70: Train loss=0.7304071560502052 Eval loss = 0.7248818911612034 Rp=0.791938228314532
510 | Epoch 71: Train loss=0.7145950980484486 Eval loss = 0.7085471898317337 Rp=0.771219627305778
511 | Epoch 72: Train loss=0.705636128783226 Eval loss = 0.7026742734014988 Rp=0.7582436097165333
512 | Epoch 73: Train loss=0.7039311081171036 Eval loss = 0.701056282967329 Rp=0.7621456124101622
513 | Epoch 74: Train loss=0.7022229805588722 Eval loss = 0.7022544406354427 Rp=0.7572908772835294
514 | Epoch 75: Train loss=0.7077537477016449 Eval loss = 0.7068974897265434 Rp=0.7689672806983037
515 | Epoch 76: Train loss=0.70463952049613 Eval loss = 0.7016653120517731 Rp=0.7620379826248179
516 | Epoch 77: Train loss=0.6936824433505535 Eval loss = 0.6882451064884663 Rp=0.7536649577052977
517 | Epoch 78: Train loss=0.7085927426815033 Eval loss = 0.7006802186369896 Rp=0.7573836578425687
518 | Epoch 79: Train loss=0.6964434124529362 Eval loss = 0.6970510184764862 Rp=0.7566034081453576
519 | Epoch 80: Train loss=0.7041287049651146 Eval loss = 0.7041221261024475 Rp=0.7517282641309433
520 | Epoch 81: Train loss=0.7040938474237919 Eval loss = 0.694716889411211 Rp=0.7518976134462692
521 | Epoch 82: Train loss=0.6934744939208031 Eval loss = 0.6876317113637924 Rp=0.7490530317847554
522 | Epoch 83: Train loss=0.6924876533448696 Eval loss = 0.7114526480436325 Rp=0.7712103875056127
523 | Epoch 84: Train loss=0.70367331802845 Eval loss = 0.6974217854440212 Rp=0.7550885913055979
524 | Epoch 85: Train loss=0.7047922983765602 Eval loss = 0.6882399097084999 Rp=0.7489761214479039
525 | Epoch 86: Train loss=0.6913500241935253 Eval loss = 0.6827207766473293 Rp=0.744314955284494
526 | Epoch 87: Train loss=0.6916158571839333 Eval loss = 0.6919064372777939 Rp=0.7509365031084853
527 | Epoch 88: Train loss=0.6971654705703259 Eval loss = 0.6914562620222569 Rp=0.7441346646542655
528 | Epoch 89: Train loss=0.6963543370366096 Eval loss = 0.691758755594492 Rp=0.7572971481054963
529 | Epoch 90: Train loss=0.6917684748768806 Eval loss = 0.6856491975486279 Rp=0.7448740185296933
530 | Epoch 91: Train loss=0.6941814571619034 Eval loss = 0.6949076354503632 Rp=0.7525349231139533
531 | Epoch 92: Train loss=0.69602020829916 Eval loss = 0.7147903628647327 Rp=0.7746622771469637
532 | Epoch 93: Train loss=0.6934036538004875 Eval loss = 0.689013235270977 Rp=0.7674142639047323
533 | Epoch 94: Train loss=0.6828178651630878 Eval loss = 0.6839329451322556 Rp=0.749017388184114
534 | Epoch 95: Train loss=0.6820085123181343 Eval loss = 0.679760005325079 Rp=0.7496224481291669
535 | Epoch 96: Train loss=0.6940162815153599 Eval loss = 0.6897397376596928 Rp=0.7505375270808133
536 | Epoch 97: Train loss=0.6879976131021976 Eval loss = 0.7038531377911568 Rp=0.7740689357717414
537 | Epoch 98: Train loss=0.6902556456625462 Eval loss = 0.6736902967095375 Rp=0.7441236415340688
538 | Epoch 99: Train loss=0.6990400142967701 Eval loss = 0.7069363221526146 Rp=0.7699857686313047
539 |
540 |
541 |
542 | ```python
543 | import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"  # allow duplicate Intel OpenMP runtimes to be loaded together
n_plots = 5  # maximum number of test batches to plot (one figure per batch)

t0 = 120  # NOTE(review): not used below; the forecast window is defined by Config.t0 / Config.t1
with torch.no_grad():
    LSTM.eval()
    for step, (x, y, _) in enumerate(test_dl):
        # Stop before running the model; the old ``step > n_plots`` test drew n_plots + 1 figures.
        if step >= n_plots:
            break
        x = x.cuda()
        y = y.cuda()

        output = LSTM(x, y)  # (B, T, 1)

        # Sample 1 of the batch: full observed series vs. the forecast for [t0, t0 + t1)
        plt.figure(figsize=(10, 10))
        plt.plot(x[1, 0].cpu().detach().squeeze().numpy(), y[1].cpu().detach().squeeze().numpy(), 'g--', linewidth=3)
        plt.plot(x[1, 0, Config.t0:].cpu().detach().squeeze().numpy(),
                 output[1, (Config.t0 - 1):(Config.t0 + Config.t1 - 1), 0].cpu().detach().squeeze().numpy(),
                 'b--', linewidth=3)

        plt.xlabel("x", fontsize=20)
        # The forecast horizon is Config.t1 steps (the legend used to claim 24).
        plt.legend(["$[0,t_0+%d)_{obs}$" % Config.t1, "$[t_0,t_0+%d)_{predicted}$" % Config.t1])
        plt.show()
568 | ```
569 |
570 |
571 |
572 | 
573 |
574 |
575 |
576 |
577 |
578 | 
579 |
580 |
581 |
582 |
583 | ```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Run the model over the whole test set and collect per-batch outputs/targets.
# Results are moved to the CPU per batch so GPU memory does not grow with the
# size of the test set.
pred_batches = []
obs_batches = []
with torch.no_grad():
    LSTM.eval()
    for x, y, _ in test_dl:
        output = LSTM(x.cuda(), y.cuda())
        pred_batches.append(output.cpu())
        obs_batches.append(y.cpu())

if not pred_batches:
    raise ValueError("test_dl yielded no batches; nothing to evaluate")

# Concatenate once at the end: calling torch.cat inside the loop re-copies
# everything collected so far on every step (quadratic in the number of batches).
matrix = torch.cat(pred_batches)
obsmat = torch.cat(obs_batches)

pre = matrix.numpy()
obs = obsmat.numpy()

# Flatten the window [Config.t0, Config.t) of every test sample into one column
# each. Observations are indexed (sample, 0, time), predictions (sample, time, 0).
# NOTE(review): the plotting cell pairs output[t0 - 1 : t0 + t1 - 1] with x[t0:]
# (a one-step shift), whereas this pairs pre[:, t0:t] with obs[:, 0, t0:t]
# unshifted; confirm which alignment is correct -- every metric below depends on it.
df = pd.DataFrame({
    'obs': obs[:, 0, Config.t0:Config.t].flatten(),
    'pre': pre[:, Config.t0:Config.t, 0].flatten(),
})
df
611 | ```
612 |
613 |
614 |
615 |
616 |
617 |
|       | obs       | pre       |
|-------|-----------|-----------|
| 0     | 1.195594  | 1.645856  |
| 1     | 1.649769  | 0.247605  |
| 2     | -0.608017 | -0.466307 |
| 3     | -0.471923 | -0.499275 |
| 4     | -1.097827 | -0.693097 |
| ...   | ...       | ...       |
| 11995 | 3.649456  | 0.124768  |
| 11996 | -0.162433 | -0.615067 |
| 11997 | -0.589042 | -0.838524 |
| 11998 | -0.578971 | -0.815273 |
| 11999 | -0.744391 | -0.398436 |

12000 rows × 2 columns
697 |
698 |
699 |
700 |
701 |
702 | ```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy import stats
from matplotlib import rcParams
from statistics import mean
from sklearn.metrics import (
    explained_variance_score,
    r2_score,
    median_absolute_error,
    mean_squared_error,
    mean_absolute_error,
)
from scipy.stats import pearsonr

# `df` (columns 'obs' and 'pre') comes from the previous cell. If the data are
# too large (the KDE below is slow for many points), subsample it first:
# df = df.sample(5000)

config = {"font.family": 'Times New Roman', "font.size": 16, "mathtext.fontset": 'stix'}
rcParams.update(config)

# Observed and predicted values used for all metrics.
x = df['obs']
y = df['pre']
n = len(x)

BIAS = float(np.mean(x - y))          # mean(obs - pre)
MSE = mean_squared_error(x, y)
RMSE = np.sqrt(MSE)
r, _ = pearsonr(x, y)                 # Pearson correlation coefficient
R2 = r ** 2                           # squared so it matches the R^2 label used in the plot
# Adjusted R^2 with Config.in_channels treated as the number of predictors.
adjR2 = 1 - (1 - r2_score(x, y)) * (n - 1) / (n - Config.in_channels - 1)
MAE = mean_absolute_error(x, y)
EV = explained_variance_score(x, y)
# Nash-Sutcliffe efficiency: 1 - MSE / Var(obs) (population variance, ddof=0).
NSE = 1 - MSE / np.var(x)

# Point density for colouring the scatter plot. Evaluating the KDE at every data
# point costs O(n^2) kernel evaluations, hence the optional subsampling above.
xy = np.vstack([x, y])
z = stats.gaussian_kde(xy)(xy)
# Sort by density (ascending) so the densest points are drawn last, on top.
idx = z.argsort()
x, y, z = x.iloc[idx], y.iloc[idx], z[idx]
733 |
# Ordinary least-squares fit for the regression line. If an "MK" fit (presumably a
# Mann-Kendall / Sen's-slope trend estimate) is wanted instead, swap it in here.
def slope(xs, ys):
    """Return slope ``m`` and intercept ``b`` of the least-squares line ``ys ≈ m * xs + b``.

    Uses the centred form ``m = Σ(x - x̄)(y - ȳ) / Σ(x - x̄)²``. This is algebraically
    identical to ``(x̄·ȳ - mean(x·y)) / (x̄² - mean(x²))``, but it avoids catastrophic
    cancellation when the values are large compared with their spread. It also runs
    as vectorised NumPy rather than pure-Python ``statistics.mean``.

    Parameters
    ----------
    xs, ys : array-like of numbers
        Paired samples of equal length (lists, NumPy arrays and pandas Series all
        work; multi-dimensional input is flattened).

    Returns
    -------
    tuple of float
        ``(m, b)``.

    Raises
    ------
    ValueError
        If the inputs are empty, differ in length, or ``xs`` is constant (the
        slope is then undefined).
    """
    xs = np.asarray(xs, dtype=float).ravel()
    ys = np.asarray(ys, dtype=float).ravel()
    if xs.size == 0 or xs.size != ys.size:
        raise ValueError("xs and ys must be non-empty and of equal length")

    x_mean = xs.mean()
    y_mean = ys.mean()
    dx = xs - x_mean
    sxx = np.dot(dx, dx)
    if sxx == 0:
        raise ValueError("xs is constant; the regression slope is undefined")

    m = np.dot(dx, ys - y_mean) / sxx
    b = y_mean - m * x_mean
    return float(m), float(b)
# Least-squares regression of PRE on OBS. A straight line is fully described by
# its end points, so evaluate it only at the extremes of the observed range.
k, b = slope(x, y)
x_fit = np.array([x.min(), x.max()])
regression_line = k * x_fit + b

# Plotting; colours etc. can be adjusted freely.
import os

# Allow duplicate OpenMP runtimes to coexist; without this, torch together with
# numpy/matplotlib can abort with "OMP: Error #15" on some (typically conda) setups.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

# Symmetric axis limit in data units. The 1:1 line spans exactly this range and
# the annotations are placed in axes coordinates, so changing `lim` keeps the
# whole figure consistent.
lim = 2

fig, ax = plt.subplots(figsize=(8, 6), dpi=300)
# Points coloured by their KDE density value (scaled x100).
scatter = ax.scatter(x, y, marker='o', c=z * 100, edgecolors=None, s=15, label='LST', cmap='Spectral_r')
cbar = plt.colorbar(scatter, ax=ax, shrink=1, orientation='vertical', extend='both', pad=0.015, aspect=30, label='frequency')
ax.plot([-lim, lim], [-lim, lim], 'black', lw=1.5)   # 1:1 reference line
ax.plot(x_fit, regression_line, 'red', lw=1.5)       # fitted regression line (PRE vs OBS)
ax.set_xlim(-lim, lim)
ax.set_ylim(-lim, lim)
ax.set_xlabel('OBS', family='Times New Roman')
ax.set_ylabel('PRE', family='Times New Roman')
plt.xticks(fontproperties='Times New Roman')
plt.yticks(fontproperties='Times New Roman')

# Metric annotations stacked in the upper-left corner, positioned as axes fractions
# so they stay in place regardless of the data range.
annotations = [
    f'$N={len(y)}$',
    f'$R^2={R2:.3f}$',
    f'$NSE={NSE:.3f}$',
    f'$RMSE={RMSE:.3f}$',
]
for row, text in enumerate(annotations):
    ax.text(0.05, 0.9375 - 0.0625 * row, text, family='Times New Roman', transform=ax.transAxes)
plt.show()
766 | ```
767 |
768 |
769 |
770 | 
771 |
772 |
773 |
774 |
775 | ```python
776 |
777 | ```
778 |
--------------------------------------------------------------------------------