├── README.md ├── arnold ├── evaluation ├── __init__.py ├── learning_curve.py ├── calibrate_model.py ├── predict_tf.py ├── prediction_plot.py ├── performance_diagrams.py └── generate_performance_stats.py ├── utils ├── __init__.py ├── change_model_number.py ├── plotting_utils.py ├── change_model_properties.py ├── timestep_front_count.py ├── misc.py └── settings.py ├── .gitignore ├── plots.py ├── custom_activations.py ├── requirements.txt ├── convert_front_gml_to_xml.py ├── LICENSE ├── custom_metrics.py ├── custom_losses.py ├── download_grib_files.py ├── convert_front_xml_to_netcdf.py ├── create_era5_netcdf.py ├── convert_grib_to_netcdf.py └── convert_netcdf_to_tf.py /README.md: -------------------------------------------------------------------------------- 1 | # fronts -------------------------------------------------------------------------------- /arnold: -------------------------------------------------------------------------------- 1 | test 2 | -------------------------------------------------------------------------------- /evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /utils/change_model_number.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script that changes the number of a model and its data files. 3 | 4 | Author: Andrew Justin (andrewjustinwx@gmail.com) 5 | Script version: 2023.6.26 6 | """ 7 | import os 8 | from glob import glob 9 | import argparse 10 | import pandas as pd 11 | import pickle 12 | 13 | 14 | if __name__ == '__main__': 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('--model_dir', type=str, help='Directory for the models.') 17 | parser.add_argument('--model_numbers', type=int, nargs=2, help='The original and new model numbers.') 18 | args = vars(parser.parse_args()) 19 | 20 | assert not os.path.isdir('%s/model_%d' % (args['model_dir'], args['model_numbers'][1])) # make sure the new model number is not already assigned to a model 21 | 22 | ### Change the model number in the model properties dictionary ### 23 | model_properties_file = '%s/model_%d/model_%d_properties.pkl' % (args['model_dir'], args['model_numbers'][0], args['model_numbers'][0]) 24 | model_properties = pd.read_pickle(model_properties_file) 25 | model_properties['model_number'] = args['model_numbers'][1] 26 | 27 | with open(model_properties_file, 'wb') as f: 28 | pickle.dump(model_properties, f) 29 | 30 | os.rename('%s/model_%d' % (args['model_dir'], args['model_numbers'][0]), '%s/model_%d' % (args['model_dir'], args['model_numbers'][1])) # rename the model number directory 31 | files_to_rename = list(sorted(glob('%s/model_%d/**/*' % (args['model_dir'], args['model_numbers'][1]), recursive=True))) # files within the subdirectories to rename 32 | 33 | print("Renaming %d files" % len(files_to_rename)) 34 | for file in files_to_rename: 35 | os.rename(file, file.replace(str(args['model_numbers'][0]), str(args['model_numbers'][1]))) 36 | 37 | print("Successfully changed model number: %d -------> %d" % (args['model_numbers'][0], args['model_numbers'][1])) 38 | -------------------------------------------------------------------------------- /utils/plotting_utils.py: 
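A minimal usage sketch for the renumbering script above; the model directory and numbers are hypothetical:

    python utils/change_model_number.py --model_dir models --model_numbers 13 14

This rewrites the 'model_number' entry in the properties dictionary, renames models/model_13 to models/model_14, and then renames every file under the new directory that still contains the old number.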
-------------------------------------------------------------------------------- 1 | """ 2 | Plotting tools. 3 | 4 | Author: Andrew Justin (andrewjustinwx@gmail.com) 5 | Script version: 2023.9.15 6 | """ 7 | 8 | import cartopy.feature as cfeature 9 | import cartopy.crs as ccrs 10 | import matplotlib.pyplot as plt 11 | import matplotlib as mpl 12 | import numpy as np 13 | 14 | 15 | def plot_background(extent, ax=None, linewidth=0.5): 16 | """ 17 | Returns new background for the plot. 18 | 19 | Parameters 20 | ---------- 21 | extent: Iterable object with 4 integers 22 | Iterable containing the extent/boundaries of the plot in the format of [min lon, max lon, min lat, max lat] expressed 23 | in degrees. 24 | ax: matplotlib.axes.Axes instance or None 25 | Axis on which the background will be plotted. 26 | linewidth: float or int 27 | Thickness of coastlines and the borders of states and countries. 28 | 29 | Returns 30 | ------- 31 | ax: matplotlib.axes.Axes instance 32 | New plot background. 33 | """ 34 | if ax is None: 35 | crs = ccrs.LambertConformal(central_longitude=250) 36 | ax = plt.axes(projection=crs) 37 | else: 38 | ax.add_feature(cfeature.COASTLINE.with_scale('50m'), linewidth=linewidth) 39 | ax.add_feature(cfeature.BORDERS, linewidth=linewidth) 40 | ax.add_feature(cfeature.STATES, linewidth=linewidth) 41 | ax.set_extent(extent, crs=ccrs.PlateCarree()) 42 | return ax 43 | 44 | 45 | def truncated_colormap(cmap, minval=0.0, maxval=1.0, n=100): 46 | """ 47 | Get an instance of a truncated matplotlib.colors.Colormap object. 48 | 49 | Parameters 50 | ---------- 51 | cmap: str 52 | Matplotlib colormap to truncate. 53 | minval: float 54 | Starting point of the colormap, represented by a float of 0 <= minval < 1. 55 | maxval: float 56 | End point of the colormap, represented by a float of 0 < maxval <= 1. 57 | n: int 58 | Number of colors for the colormap. 59 | 60 | Returns 61 | ------- 62 | new_cmap: matplotlib.colors.Colormap instance 63 | Truncated colormap. 64 | """ 65 | cmap = plt.get_cmap(cmap) 66 | new_cmap = mpl.colors.LinearSegmentedColormap.from_list( 67 | 'trunc({n},{a:.2f},{b:.2f})'.format(n=cmap.name, a=minval, b=maxval), 68 | cmap(np.linspace(minval, maxval, n))) 69 | return new_cmap 70 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
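A short sketch of how the plotting utilities above can be combined; the extent and colormap choices are arbitrary examples.

import cartopy.crs as ccrs
import matplotlib.pyplot as plt
from utils.plotting_utils import plot_background, truncated_colormap

extent = [228, 300, 25, 57]  # hypothetical CONUS-like extent: [min lon, max lon, min lat, max lat] in degrees
ax = plt.axes(projection=ccrs.LambertConformal(central_longitude=250))
ax = plot_background(extent, ax=ax, linewidth=0.5)  # adds coastlines, borders, and states, then sets the extent
cmap = truncated_colormap('Blues', minval=0.35, maxval=1.0, n=10)  # drops the near-white low end of 'Blues'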
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # For PyCharm 132 | .idea 133 | -------------------------------------------------------------------------------- /utils/change_model_properties.py: -------------------------------------------------------------------------------- 1 | """ 2 | Changes values of dictionary keys in a model_properties.pkl file. 3 | This script is mainly used to address bugs in train_model.py, where the dictionaries are created. 4 | 5 | Author: Andrew Justin (andrewjustinwx@gmail.com) 6 | Script version: 2023.8.12 7 | """ 8 | import argparse 9 | import pandas as pd 10 | import pickle 11 | from utils.misc import string_arg_to_dict 12 | 13 | 14 | if __name__ == '__main__': 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('--model_dir', type=str, help='Directory for the model.') 17 | parser.add_argument('--model_number', type=int, help='Model number.') 18 | parser.add_argument('--changes', type=str, 19 | help="Changes to make to the model properties dictionary. See utils.misc.string_arg_to_dict for more information.") 20 | parser.add_argument('--permission_override', action='store_true', 21 | help="WARNING: Read the description for this argument CAREFULLY. This is a boolean flag that overrides permission " 22 | "errors when attempting to modify critical model information. Changing properties that raise a PermissionError " 23 | "can render a model unusable with this module. 
ALWAYS create a backup of the model_*_properties.pkl file if " 24 | "you plan to modify critical model information.") 25 | 26 | args = vars(parser.parse_args()) 27 | 28 | model_properties_file = '%s/model_%d/model_%d_properties.pkl' % (args['model_dir'], args['model_number'], args['model_number']) 29 | model_properties = pd.read_pickle(model_properties_file) 30 | 31 | changes = string_arg_to_dict(args['changes']) 32 | 33 | critical_args = ['dataset_properties', 'normalization_parameters', 'training_years', 'validation_years', 'test_years', 'model_number'] 34 | critical_args_passed = list([arg for arg in critical_args if arg in changes]) 35 | 36 | if len(critical_args_passed) > 0: 37 | if not args['permission_override']: 38 | raise PermissionError( 39 | f"The following critical model properties were attempted to be modified: --{', --'.join(critical_args_passed)}. " 40 | "Changing these properties can render the model properties file incompatible with other scripts. " 41 | "If you would like to modify these properties, pass the --permission_override flag. ALWAYS CREATE A BACKUP " 42 | "model_*_properties.pkl file before proceeding.") 43 | 44 | for arg in changes: 45 | model_properties[arg] = changes[arg] 46 | 47 | # Rewrite the human-readable model properties text file 48 | with open(model_properties_file.replace('.pkl', '.txt'), 'w') as f: 49 | for key in model_properties.keys(): 50 | f.write(f"{key}: {model_properties[key]}\n") 51 | 52 | # Save the model properties dictionary with the new changes. 53 | with open(model_properties_file, 'wb') as f: 54 | pickle.dump(model_properties, f) 55 | -------------------------------------------------------------------------------- /utils/timestep_front_count.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tool for counting the number of fronts in each timestep so that tensorflow datasets can be quickly generated. This script 3 | effectively prevents empty timesteps from being analyzed by 'convert_netcdf_to_tf.py', saving potentially large amounts 4 | of time when generating tensorflow datasets. 5 | 6 | A CSV file containing front counts for each timestep across a given domain will be saved to the output directory for 7 | tensorflow datasets. 8 | 9 | Author: Andrew Justin (andrewjustinwx@gmail.com) 10 | Script version: 2023.6.13 11 | """ 12 | import os 13 | import sys 14 | import csv 15 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))) # this line allows us to import scripts outside the current directory 16 | import file_manager as fm 17 | from utils.settings import DEFAULT_DOMAIN_INDICES 18 | import argparse 19 | import numpy as np 20 | import xarray as xr 21 | 22 | 23 | if __name__ == '__main__': 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument('--fronts_netcdf_indir', type=str, required=True, 26 | help="Input directory for the netCDF files containing frontal boundary data.") 27 | parser.add_argument('--tf_outdir', type=str, required=True, 28 | help="Output directory for future tensorflow datasets.
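A hypothetical invocation of the property-editing script above; the directory, model number, and changes are examples only:

    python utils/change_model_properties.py --model_dir models --model_number 13 --changes "batch_size=32,learning_rate=0.0001"

Keys listed in critical_args (for example, model_number) are rejected unless --permission_override is also passed.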
This is where the CSV file containing frontal counts will " 29 | "also be saved.") 30 | parser.add_argument('--domain', type=str, default='conus', help='Domain from which to pull the images.') 31 | 32 | args = vars(parser.parse_args()) 33 | 34 | front_files_obj = fm.DataFileLoader(args['fronts_netcdf_indir'], data_file_type='fronts-netcdf') 35 | front_files = front_files_obj.front_files 36 | 37 | isel_kwargs = {'longitude': slice(DEFAULT_DOMAIN_INDICES[args['domain']][0], DEFAULT_DOMAIN_INDICES[args['domain']][1]), 38 | 'latitude': slice(DEFAULT_DOMAIN_INDICES[args['domain']][2], DEFAULT_DOMAIN_INDICES[args['domain']][3])} 39 | 40 | fieldnames = ['File', 'CF', 'CF-F', 'CF-D', 'WF', 'WF-F', 'WF-D', 'SF', 'SF-F', 'SF-D', 'OF', 'OF-F', 'OF-D', 'INST', 41 | 'TROF', 'TT', 'DL'] 42 | 43 | front_count_csv_file = '%s/timestep_front_counts_%s.csv' % (args['tf_outdir'], args['domain']) 44 | 45 | with open(front_count_csv_file, 'w', newline='') as f: 46 | csvwriter = csv.writer(f) 47 | csvwriter.writerow(fieldnames) 48 | 49 | for file_no, front_file in enumerate(front_files): 50 | print(front_file, end='\r') 51 | front_dataset = xr.open_dataset(front_file, engine='netcdf4').isel(**isel_kwargs).expand_dims('time', axis=0).astype('float16') 52 | front_bins = np.bincount(front_dataset['identifier'].values.astype('int64').flatten(), minlength=17)[1:] # counts for each front type ('no front' type removed) 53 | 54 | row = [os.path.basename(front_file), *front_bins] 55 | 56 | csvwriter.writerow(row) 57 | -------------------------------------------------------------------------------- /plots.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code for generating various plots. 3 | 4 | Author: Andrew Justin (andrewjustinwx@gmail.com) 5 | Script version: 2023.7.24 6 | 7 | TODO: Add functions for plotting ERA5, GDAS, and GFS data.
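A sketch of how the resulting front-count table might be used to skip empty timesteps before building tensorflow datasets; the output directory and domain are hypothetical.

import pandas as pd

counts = pd.read_csv('tensorflow_datasets/timestep_front_counts_conus.csv')
front_columns = [col for col in counts.columns if col != 'File']
nonempty_files = counts.loc[counts[front_columns].sum(axis=1) > 0, 'File']  # timesteps with at least one front label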
8 | """ 9 | 10 | import argparse 11 | import xarray as xr 12 | import matplotlib.pyplot as plt 13 | import matplotlib as mpl 14 | from utils import settings 15 | from utils.data_utils import reformat_fronts, expand_fronts 16 | from utils.plotting_utils import plot_background 17 | import numpy as np 18 | import cartopy.crs as ccrs 19 | 20 | 21 | def plot_fronts(netcdf_indir, plot_outdir, timestep, front_types, domain, extent=(-180, 180, -90, 90)): 22 | 23 | year, month, day, hour = timestep 24 | 25 | fronts_ds = xr.open_dataset('%s/%d%02d/FrontObjects_%d%02d%02d%02d_%s.nc' % (netcdf_indir, year, month, year, month, day, hour, domain)) 26 | 27 | if front_types is not None: 28 | fronts_ds = reformat_fronts(fronts_ds, front_types) 29 | labels = fronts_ds.attrs['labels'] 30 | 31 | fronts_ds = expand_fronts(fronts_ds, iterations=1) 32 | fronts_ds = xr.where(fronts_ds == 0, np.nan, fronts_ds) 33 | 34 | front_colors_by_type = [settings.DEFAULT_FRONT_COLORS[label] for label in labels] 35 | front_names_by_type = [settings.DEFAULT_FRONT_NAMES[label] for label in labels] 36 | cmap_front = mpl.colors.ListedColormap(front_colors_by_type, name='from_list', N=len(front_colors_by_type)) 37 | norm_front = mpl.colors.Normalize(vmin=1, vmax=len(front_colors_by_type) + 1) 38 | 39 | fig, ax = plt.subplots(1, 1, figsize=(16, 8), subplot_kw={'projection': ccrs.PlateCarree(central_longitude=np.mean(extent[:2]))}) 40 | plot_background(extent, ax=ax, linewidth=0.25) 41 | 42 | cbar_front = plt.colorbar(mpl.cm.ScalarMappable(norm=norm_front, cmap=cmap_front), ax=ax, alpha=0.75, shrink=0.8, pad=0.02) 43 | cbar_front.set_ticks(np.arange(1, len(front_colors_by_type) + 1) + 0.5) 44 | cbar_front.set_ticklabels(front_names_by_type) 45 | cbar_front.set_label('Front Type') 46 | 47 | fronts_ds['identifier'].plot(ax=ax, x='longitude', y='latitude', cmap=cmap_front, norm=norm_front, transform=ccrs.PlateCarree(), 48 | add_colorbar=False) 49 | ax.gridlines(alpha=0.5) 50 | 51 | plt.tight_layout() 52 | plt.savefig(f"%s/fronts_%d%02d%02d%02d_{domain}.png" % (plot_outdir, year, month, day, hour), dpi=300, bbox_inches='tight') 53 | plt.close() 54 | 55 | 56 | if __name__ == '__main__': 57 | parser = argparse.ArgumentParser() 58 | parser.add_argument('--timestep', type=int, nargs=4, help='Year, month, day, and hour of the data.') 59 | parser.add_argument('--netcdf_indir', type=str, help='Directory for the netcdf files.') 60 | parser.add_argument('--plot_outdir', type=str, help='Directory for the plots.') 61 | parser.add_argument('--front_types', type=str, nargs='+', help='Directory for the netcdf files.') 62 | parser.add_argument('--domain', type=str, default='full', help="Domain for which the fronts will be plotted.") 63 | parser.add_argument('--extent', type=float, nargs=4, default=[-180., 180., -90., 90.], help="Extent of the plot [min lon, max lon, min lat, max lat]") 64 | args = vars(parser.parse_args()) 65 | 66 | plot_fronts(args['netcdf_indir'], args['plot_outdir'], args['timestep'], args['front_types'], args['domain'], args['extent']) 67 | -------------------------------------------------------------------------------- /custom_activations.py: -------------------------------------------------------------------------------- 1 | """ 2 | Custom activation functions: 3 | - Gaussian 4 | - GCU (Growing Cosine Unit) 5 | - SmeLU (Smooth ReLU) 6 | - Snake 7 | 8 | Author: Andrew Justin (andrewjustinwx@gmail.com) 9 | Script version: 2023.3.3 10 | """ 11 | from tensorflow.keras.layers import Layer 12 | import tensorflow as tf 13 | 14 | 15 | 
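A hypothetical invocation of the front-plotting script above; the directories, timestep, front types, and extent are examples only:

    python plots.py --timestep 2019 5 20 12 --netcdf_indir netcdf --plot_outdir plots --front_types CF WF SF OF --domain full --extent 130 370 0 80

This looks for netcdf/201905/FrontObjects_2019052012_full.nc and writes fronts_2019052012_full.png to the plot directory.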
class Gaussian(Layer): 16 | """ 17 | Gaussian function activation layer. 18 | """ 19 | def __init__(self, name=None): 20 | super(Gaussian, self).__init__(name=name) 21 | 22 | def build(self, input_shape): 23 | """ Build the Gaussian layer """ 24 | 25 | def call(self, inputs): 26 | """ Call the Gaussian activation function """ 27 | inputs = tf.cast(inputs, 'float32') 28 | square_tensor = tf.constant(2.0, shape=inputs.shape[1:]) 29 | y = tf.math.exp(tf.math.negative(tf.math.pow(inputs, square_tensor))) 30 | 31 | return y 32 | 33 | 34 | class GCU(Layer): 35 | """ 36 | Growing Cosine Unit (GCU) activation layer. 37 | """ 38 | def __init__(self, name=None): 39 | super(GCU, self).__init__(name=name) 40 | 41 | def build(self, input_shape): 42 | """ Build the GCU layer """ 43 | 44 | def call(self, inputs): 45 | """ Call the GCU activation function """ 46 | inputs = tf.cast(inputs, 'float32') 47 | y = tf.multiply(inputs, tf.math.cos(inputs)) 48 | 49 | return y 50 | 51 | 52 | class SmeLU(Layer): 53 | """ 54 | SmeLU (Smooth ReLU) activation function layer for deep learning models. 55 | 56 | References 57 | ---------- 58 | https://arxiv.org/pdf/2202.06499.pdf 59 | """ 60 | def __init__(self, name=None): 61 | super(SmeLU, self).__init__(name=name) 62 | 63 | def build(self, input_shape): 64 | """ Build the SmeLU layer """ 65 | self.beta = self.add_weight(name='beta', dtype='float32', shape=input_shape[1:]) # Learnable parameter (see Eq. 7 in the linked paper above) 66 | 67 | def call(self, inputs): 68 | """ Call the SmeLU activation function """ 69 | inputs = tf.cast(inputs, 'float32') 70 | y = tf.where(inputs <= -self.beta, 0.0, # Condition 1 71 | tf.where(tf.abs(inputs) <= self.beta, tf.math.divide(tf.math.pow(inputs + self.beta, 2.0), tf.math.multiply(4.0, self.beta)), # Condition 2 72 | inputs)) # Condition 3 (if x >= beta) 73 | 74 | return y 75 | 76 | 77 | class Snake(Layer): 78 | """ 79 | Snake activation function layer for deep learning models. 80 | 81 | References 82 | ---------- 83 | https://arxiv.org/pdf/2006.08195.pdf 84 | """ 85 | def __init__(self, name=None): 86 | super(Snake, self).__init__(name=name) 87 | 88 | def build(self, input_shape): 89 | """ Build the Snake layer """ 90 | self.alpha = self.add_weight(name='alpha', dtype='float32', shape=input_shape[1:]) # Learnable parameter (see Eq. 
3 in the linked paper above) 91 | self.square_tensor = tf.constant(2.0, shape=input_shape[1:]) 92 | 93 | def call(self, inputs): 94 | """ Call the Snake activation function """ 95 | inputs = tf.cast(inputs, 'float32') 96 | y = inputs + tf.multiply(tf.divide(tf.constant(1.0, shape=inputs.shape[1:]), self.alpha), tf.math.pow(tf.math.sin(tf.multiply(self.alpha, inputs)), self.square_tensor)) 97 | 98 | return y 99 | 100 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==2.1.0 2 | aiobotocore==2.13.0 3 | aiohttp==3.9.5 4 | aioitertools==0.11.0 5 | aiosignal==1.3.1 6 | archspec==0.2.3 7 | asttokens==2.4.1 8 | astunparse==1.6.3 9 | async-timeout==4.0.3 10 | attrs==23.2.0 11 | bokeh==3.4.1 12 | boltons==24.0.0 13 | botocore==1.34.106 14 | Brotli==1.1.0 15 | cached-property==1.5.2 16 | cachetools==5.3.3 17 | Cartopy==0.23.0 18 | certifi==2024.7.4 19 | cffi==1.16.0 20 | cfgrib==0.9.11.0 21 | cftime==1.6.3 22 | charset-normalizer==3.3.2 23 | click==8.1.7 24 | cloudpickle==3.0.0 25 | colorama==0.4.6 26 | conda==24.5.0 27 | conda-libmamba-solver==24.1.0 28 | conda-package-handling==2.2.0 29 | conda_package_streaming==0.9.0 30 | contourpy==1.2.1 31 | cycler==0.12.1 32 | cytoolz==0.12.3 33 | dask==2024.5.0 34 | dask-expr==1.1.0 35 | decorator==5.1.1 36 | distributed==2024.5.0 37 | distro==1.9.0 38 | docker-pycreds==0.4.0 39 | eccodes==1.7.0 40 | exceptiongroup==1.2.0 41 | executing==2.0.1 42 | findlibs==0.0.5 43 | flatbuffers==24.3.25 44 | fonttools==4.51.0 45 | frozendict==2.4.4 46 | frozenlist==1.4.1 47 | fsspec==2024.6.0 48 | gast==0.4.0 49 | gitdb==4.0.11 50 | GitPython==3.1.43 51 | google-auth==2.29.0 52 | google-auth-oauthlib==0.4.6 53 | google-pasta==0.2.0 54 | graphviz==0.20.3 55 | grpcio==1.63.0 56 | h5netcdf==1.3.0 57 | h5py==3.11.0 58 | idna==3.7 59 | imagecodecs==2024.1.1 60 | imageio==2.34.1 61 | importlib_metadata==7.1.0 62 | ipython==8.24.0 63 | jedi==0.19.1 64 | Jinja2==3.1.4 65 | jmespath==1.0.1 66 | joblib==1.4.2 67 | jsonpatch==1.33 68 | jsonpointer==2.4 69 | keras==2.10.0 70 | Keras-Preprocessing==1.1.2 71 | kiwisolver==1.4.5 72 | lazy_loader==0.4 73 | libclang==18.1.1 74 | libmambapy==1.5.8 75 | locket==1.0.0 76 | lz4==4.3.3 77 | mamba==1.5.8 78 | Markdown==3.6 79 | MarkupSafe==2.1.5 80 | matplotlib==3.8.4 81 | matplotlib-inline==0.1.7 82 | menuinst==2.0.2 83 | mplcursors==0.5.3 84 | msgpack==1.0.8 85 | multidict==6.0.5 86 | munkres==1.1.4 87 | netCDF4==1.5.8 88 | networkx==3.3 89 | numpy==1.26.4 90 | oauthlib==3.2.2 91 | opt-einsum==3.3.0 92 | packaging==24.0 93 | pandas==2.2.2 94 | parso==0.8.4 95 | partd==1.4.2 96 | pickleshare==0.7.5 97 | pillow==10.3.0 98 | pip==24.0 99 | platformdirs==4.2.1 100 | pluggy==1.5.0 101 | ply==3.11 102 | prompt-toolkit==3.0.42 103 | protobuf==3.19.6 104 | psutil==5.9.8 105 | pure-eval==0.2.2 106 | pyarrow==15.0.2 107 | pyarrow-hotfix==0.6 108 | pyasn1==0.6.0 109 | pyasn1_modules==0.4.0 110 | pycosat==0.6.6 111 | pycparser==2.22 112 | pydot==2.0.0 113 | Pygments==2.18.0 114 | pyparsing==3.1.2 115 | pyproj==3.6.1 116 | PyQt5==5.15.9 117 | PyQt5-sip==12.12.2 118 | pyshp==2.3.1 119 | PySocks==1.7.1 120 | python-dateutil==2.9.0 121 | pytz==2024.1 122 | PyWavelets==1.4.1 123 | PyYAML==6.0.1 124 | requests==2.31.0 125 | requests-oauthlib==2.0.0 126 | rsa==4.9 127 | ruamel.yaml==0.18.6 128 | ruamel.yaml.clib==0.2.8 129 | s3fs==2024.6.0 130 | scikit-image==0.22.0 131 | scikit-learn==1.4.2 132 | scipy==1.13.0 133 | 
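A small sketch showing how the activation layers above might be dropped into a Keras model; the architecture and layer names are arbitrary.

import tensorflow as tf
from custom_activations import GCU, SmeLU

inputs = tf.keras.Input(shape=(128, 128, 8))
x = tf.keras.layers.Conv2D(16, 3, padding='same')(inputs)
x = GCU(name='gcu_1')(x)        # growing cosine unit after the first convolution
x = tf.keras.layers.Conv2D(16, 3, padding='same')(x)
x = SmeLU(name='smelu_1')(x)    # smooth ReLU with a learnable beta per feature
model = tf.keras.Model(inputs, x)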
sentry-sdk==2.5.1 134 | setproctitle==1.3.3 135 | setuptools==69.5.1 136 | shapely==2.0.4 137 | sip==6.7.12 138 | six==1.16.0 139 | smmap==5.0.1 140 | sortedcontainers==2.4.0 141 | stack-data==0.6.2 142 | tblib==3.0.0 143 | tensorboard==2.10.1 144 | tensorboard-data-server==0.6.1 145 | tensorboard-plugin-wit==1.8.1 146 | tensorflow==2.10.0 147 | tensorflow-estimator==2.10.0 148 | tensorflow-intel==2.11.0 149 | tensorflow-io-gcs-filesystem==0.31.0 150 | termcolor==2.4.0 151 | threadpoolctl==3.5.0 152 | tifffile==2024.5.10 153 | toml==0.10.2 154 | tomli==2.0.1 155 | toolz==0.12.1 156 | tornado==6.4 157 | tqdm==4.66.4 158 | traitlets==5.14.3 159 | truststore==0.8.0 160 | typing_extensions==4.11.0 161 | tzdata==2024.1 162 | unicodedata2==15.1.0 163 | urllib3==2.2.1 164 | wandb==0.17.1 165 | wcwidth==0.2.13 166 | Werkzeug==3.0.3 167 | wget==3.2 168 | wheel==0.43.0 169 | win-inet-pton==1.1.0 170 | wrapt==1.16.0 171 | xarray==2024.5.0 172 | xyzservices==2024.4.0 173 | yarl==1.9.4 174 | zict==3.0.0 175 | zipp==3.17.0 176 | zstandard==0.19.0 177 | -------------------------------------------------------------------------------- /evaluation/learning_curve.py: -------------------------------------------------------------------------------- 1 | """ 2 | Plot the learning curve for a model. 3 | 4 | Author: Andrew Justin (andrewjustinwx@gmail.com) 5 | Script version: 2023.6.12 6 | """ 7 | import argparse 8 | import matplotlib.pyplot as plt 9 | import numpy as np 10 | import pandas as pd 11 | 12 | 13 | if __name__ == '__main__': 14 | """ 15 | All arguments listed in the examples are listed via argparse in alphabetical order below this comment block. 16 | """ 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('--model_dir', type=str, required=True, help='Directory for the models.') 19 | parser.add_argument('--model_number', type=int, required=True, help='Model number.') 20 | 21 | args = vars(parser.parse_args()) 22 | 23 | with open("%s/model_%d/model_%d_history.csv" % (args['model_dir'], args['model_number'], args['model_number']), 'rb') as f: 24 | history = pd.read_csv(f) 25 | 26 | model_properties = pd.read_pickle(f"{args['model_dir']}/model_{args['model_number']}/model_{args['model_number']}_properties.pkl") 27 | 28 | # Model properties 29 | try: 30 | loss = model_properties['loss'] 31 | except KeyError: 32 | loss = model_properties['loss_string'] 33 | 34 | try: 35 | metric_string = model_properties['metric'] 36 | except KeyError: 37 | metric_string = model_properties['metric_string'] 38 | 39 | if model_properties['deep_supervision']: 40 | train_metric = history['sup1_Softmax_%s' % metric_string] 41 | val_metric = history['val_sup1_Softmax_%s' % metric_string] 42 | else: 43 | train_metric = history[metric_string] 44 | val_metric = history['val_%s' % metric_string] 45 | 46 | if 'fss' in loss.lower(): 47 | loss_title = 'Fractions Skill Score (loss)' 48 | elif 'bss' in loss.lower(): 49 | loss_title = 'Brier Skill Score (loss)' 50 | elif 'csi' in loss.lower(): 51 | loss_title = 'Categorical Cross-Entropy' 52 | else: 53 | loss_title = None 54 | 55 | if 'fss' in metric_string: 56 | metric_title = 'Fractions Skill Score' 57 | elif 'bss' in metric_string: 58 | metric_title = 'Brier Skill Score' 59 | elif 'csi' in metric_string: 60 | metric_title = 'Critical Success Index' 61 | else: 62 | metric_title = None 63 | 64 | min_val_loss_epoch = np.where(history['val_loss'] == np.min(history['val_loss']))[0][0] + 1 65 | 66 | num_epochs = len(history['val_loss']) 67 | 68 | fig, axs = plt.subplots(1, 2, 
figsize=(12, 6), dpi=300) 69 | axarr = axs.flatten() 70 | 71 | annotate_kwargs = dict(color='black', va='center', xycoords='axes fraction', fontsize=11) 72 | axarr[0].annotate('Epoch %d' % min_val_loss_epoch, xy=(0, -0.2), fontweight='bold', **annotate_kwargs) 73 | axarr[0].annotate('Training/Validation loss: %.4e, %.4e' % (history['loss'][min_val_loss_epoch - 1], history['val_loss'][min_val_loss_epoch - 1]), xy=(0, -0.25), **annotate_kwargs) 74 | axarr[0].annotate('Training/Validation metric: %.4f, %.4f' % (train_metric[min_val_loss_epoch - 1], val_metric[min_val_loss_epoch - 1]), xy=(0, -0.3), **annotate_kwargs) 75 | 76 | axarr[0].set_title(loss_title) 77 | axarr[0].plot(np.arange(1, num_epochs + 1), history['loss'], color='blue', label='Training loss') 78 | axarr[0].plot(np.arange(1, num_epochs + 1), history['val_loss'], color='red', label='Validation loss') 79 | axarr[0].set_xlim(xmin=0, xmax=num_epochs + 1) 80 | axarr[0].set_xlabel('Epochs') 81 | axarr[0].legend(loc='best') 82 | axarr[0].grid() 83 | axarr[0].set_yscale('log') # Turns y-axis into a logarithmic scale. Useful if loss functions appear as very sharp curves. 84 | 85 | axarr[1].set_title(metric_title) 86 | axarr[1].plot(np.arange(1, num_epochs + 1), train_metric, color='blue', label='Training') 87 | axarr[1].plot(np.arange(1, num_epochs + 1), val_metric, color='red', label='Validation') 88 | axarr[1].set_xlim(xmin=0, xmax=num_epochs + 1) 89 | axarr[1].set_ylim(ymin=0) 90 | axarr[1].set_xlabel('Epochs') 91 | axarr[1].legend(loc='best') 92 | axarr[1].grid() 93 | 94 | plt.tight_layout() 95 | plt.savefig("%s/model_%d/model_%d_learning_curve.png" % (args['model_dir'], args['model_number'], args['model_number']), bbox_inches='tight') 96 | plt.close() 97 | -------------------------------------------------------------------------------- /convert_front_gml_to_xml.py: -------------------------------------------------------------------------------- 1 | """ 2 | Convert GML files containing IBM/TWC fronts into XML files. 
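A hypothetical invocation of the learning-curve script above; the directory and model number are placeholders:

    python evaluation/learning_curve.py --model_dir models --model_number 13

This reads models/model_13/model_13_history.csv and models/model_13/model_13_properties.pkl, then writes model_13_learning_curve.png to the same directory.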
3 | 4 | Author: Andrew Justin (andrewjustinwx@gmail.com) 5 | Script version: 2023.9.2 6 | """ 7 | 8 | import argparse 9 | from lxml import etree as ET 10 | from glob import glob 11 | import os 12 | import numpy as np 13 | 14 | XML_FRONT_TYPE = {'Cold Front': 'COLD_FRONT', 'Dissipating Cold Front': 'COLD_FRONT_DISS', 15 | 'Warm Front': 'WARM_FRONT', 'Stationary Front': 'STATIONARY_FRONT', 16 | 'Occluded Front': 'OCCLUDED_FRONT', 'Dissipating Occluded Front': 'OCCLUDED_FRONT_DISS', 17 | 'Dry Line': 'DRY_LINE', 'Trough': 'TROF', 'Squall Line': 'INSTABILITY'} 18 | 19 | XML_FRONT_COLORS = {'Cold Front': dict(red="0", green="0", blue="255"), 'Dissipating Cold Front': dict(red="0", green="0", blue="255"), 20 | 'Warm Front': dict(red="255", green="0", blue="0"), 'Dissipating Warm Front': dict(red="255", green="0", blue="0"), 21 | 'Occluded Front': dict(red="145", green="44", blue="238"), 'Dissipating Occluded Front': dict(red="145", green="44", blue="238"), 22 | 'Dry Line': dict(red="255", green="130", blue="71"), 'Trough': dict(red="255", green="130", blue="71"), 23 | 'Squall Line': dict(red="255", green="0", blue="0")} 24 | 25 | LINE_KWARGS = dict(pgenCategory="Front", lineWidth="4", sizeScale=" 1.0", smoothFactor="2", closed="false", filled="false", 26 | fillPattern="SOLID") 27 | 28 | 29 | if __name__ == "__main__": 30 | parser = argparse.ArgumentParser() 31 | parser.add_argument('--gml_indir', type=str, required=True, help="Input directory for IBM/TWC front GML files.") 32 | parser.add_argument('--xml_outdir', type=str, required=True, help="Output directory for front XML files.") 33 | parser.add_argument('--date', type=int, nargs=3, required=True, help="Date for the data to be read in. (year, month, day)") 34 | args = vars(parser.parse_args()) 35 | 36 | year, month, day = args['date'] 37 | 38 | gml_files = sorted(glob('%s/%d%02d%02d/*/*%d%02d%02d*.gml' % (args['gml_indir'], year, month, day, year, month, day))) 39 | 40 | for gml_file in gml_files: 41 | 42 | valid_time_str = os.path.basename(gml_file).split('.')[2] 43 | valid_time_str = valid_time_str[:4] + '-' + valid_time_str[4:6] + '-' + valid_time_str[6:8] + 'T' + valid_time_str[9:11] 44 | valid_time = np.datetime64(valid_time_str, 'ns') 45 | 46 | init_time_str = os.path.basename(gml_file).split('.')[3] 47 | if init_time_str != 'NIL': # an init time of 'NIL' is used to indicate forecast hour 0 (i.e. 
valid time is same as init time) 48 | init_time_str = init_time_str[:4] + '-' + init_time_str[4:6] + '-' + init_time_str[6:8] + 'T' + init_time_str[9:11] 49 | init_time = np.datetime64(init_time_str, 'ns') 50 | else: 51 | init_time_str = valid_time_str 52 | init_time = valid_time 53 | 54 | forecast_hour = int((valid_time - init_time) / np.timedelta64(1, 'h')) 55 | 56 | root_xml = ET.Element("Product", name="IBM_global_fronts", init_time=init_time_str, valid_time=valid_time_str, forecast_hour=str(forecast_hour)) 57 | tree = ET.parse(gml_file, parser=ET.XMLPullParser(encoding='utf-8')) 58 | root_gml = tree.getroot() 59 | 60 | Layer = ET.SubElement(root_xml, "Layer", name="Default", onOff="true", monoColor="false", filled="false") 61 | ET.SubElement(Layer, "Color", red="255", green="255", blue="0", alpha="255") 62 | DrawableElement = ET.SubElement(Layer, "DrawableElement") 63 | 64 | front_elements = [element[0] for element in root_gml if element[0].tag == 'FRONT'] 65 | 66 | for element in front_elements: 67 | front_type = [subelement.text for subelement in element if subelement.tag == 'FRONT_TYPE'][0] 68 | coords = [subelement for subelement in element if 'lineString' in subelement.tag][0][0][0].text 69 | 70 | Line = ET.SubElement(DrawableElement, "Line", pgenType=XML_FRONT_TYPE[front_type], **LINE_KWARGS) 71 | if front_type == 'Stationary Front': 72 | ET.SubElement(Line, "Color", red="255", green="0", blue="0", alpha="255") 73 | ET.SubElement(Line, "Color", red="0", green="0", blue="255", alpha="255") 74 | else: 75 | ET.SubElement(Line, "Color", **XML_FRONT_COLORS[front_type], alpha="255") 76 | 77 | coords = coords.replace('\n', '').split(' ') # generate coordinate strings 78 | coords = list(coord_pair.split(',') for coord_pair in coords) # generate coordinate pairs from the strings 79 | 80 | for coord_pair in coords: 81 | ET.SubElement(Line, "Point", Lat="%.6f" % float(coord_pair[1]), Lon="%.6f" % float(coord_pair[0])) 82 | 83 | save_path_file = "%s/IBM_fronts_%sf%03d.xml" % (args['xml_outdir'], init_time_str.replace('-', '').replace('T', ''), forecast_hour) 84 | 85 | print(save_path_file) 86 | 87 | ET.indent(root_xml) 88 | mydata = ET.tostring(root_xml) 89 | xmlFile = open(save_path_file, "wb") 90 | xmlFile.write(mydata) 91 | xmlFile.close() 92 | -------------------------------------------------------------------------------- /utils/misc.py: -------------------------------------------------------------------------------- 1 | """ 2 | Miscellaneous tools. 3 | 4 | Author: Andrew Justin (andrewjustinwx@gmail.com) 5 | Script version: 2023.7.7.D1 6 | """ 7 | 8 | 9 | def string_arg_to_dict(arg_str: str): 10 | """ 11 | Function that converts a string argument into a dictionary. Dictionaries cannot be passed through a command line, so 12 | this function takes a special string argument and converts it to a dictionary so arguments within a function can be 13 | explicitly called. 14 | 15 | arg_str: str 16 | arg_dict: dict 17 | """ 18 | 19 | arg_str = arg_str.replace(' ', '') # Remove all spaces from the string. 
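A hypothetical invocation of the GML-to-XML converter above; the directories and date are placeholders:

    python convert_front_gml_to_xml.py --gml_indir gml --xml_outdir xml --date 2023 9 2

For each GML file found under gml/20230902/, an XML file named IBM_fronts_<init time>f<forecast hour>.xml is written to the output directory.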
20 | arg_dict = {} # Dictionary that will contain the arguments and their respective values 21 | 22 | # Iterate through all the arguments within the string 23 | while True: 24 | 25 | equals_index = arg_str.find('=') # Index representing where an equals sign is located, marking the end of the argument name 26 | 27 | ################################# Attempt to see if a tuple or list was passed ################################# 28 | open_parenthesis_index = arg_str.find('(') 29 | close_parenthesis_index = arg_str.find(')') 30 | open_bracket_index = arg_str.find('[') 31 | close_bracket_index = arg_str.find(']') 32 | 33 | if open_parenthesis_index == close_parenthesis_index and open_bracket_index == close_bracket_index: # These will only be equal when there are no parentheses/brackets in the argument string (i.e. there is no tuple/list) 34 | comma_index = arg_str.find(',') # Index representing where the first comma is located within the 'arg' string, essentially representing the end of a argument 35 | elif open_parenthesis_index == -1 and close_parenthesis_index != -1: 36 | raise TypeError("An open parenthesis appears to be missing. Check the argument string.") 37 | elif open_parenthesis_index != -1 and close_parenthesis_index == -1: 38 | raise TypeError("A closed parenthesis appears to be missing. Check the argument string.") 39 | elif open_bracket_index == -1 and close_bracket_index != -1: 40 | raise TypeError("An open bracket appears to be missing. Check the argument string.") 41 | elif open_bracket_index != -1 and close_bracket_index == -1: 42 | raise TypeError("A closed bracket appears to be missing. Check the argument string.") 43 | elif open_parenthesis_index != close_parenthesis_index: 44 | comma_index = close_parenthesis_index + 1 45 | else: 46 | comma_index = close_bracket_index + 1 47 | 48 | current_arg_name = arg_str[:equals_index] 49 | 50 | if comma_index == -1: # When the final argument is being added to the dictionary, this index will become -1 51 | current_arg_value = arg_str[equals_index + 1:] 52 | else: 53 | current_arg_value = arg_str[equals_index + 1:comma_index] 54 | 55 | ######################################## Convert the argument to a tuple ####################################### 56 | if open_parenthesis_index != close_parenthesis_index: 57 | 58 | arg_dict[current_arg_name] = current_arg_value.replace('(', '').replace(')', '').split(',') 59 | 60 | if '.' in arg_dict[current_arg_name]: # If the tuple appears to contain a float 61 | arg_dict[current_arg_name] = tuple([float(val) for val in arg_dict[current_arg_name]]) 62 | else: 63 | arg_dict[current_arg_name] = tuple([int(val) for val in arg_dict[current_arg_name]]) 64 | ################################################################################################################ 65 | 66 | ######################################## Convert the argument to a list ######################################## 67 | elif open_bracket_index != close_bracket_index: 68 | 69 | arg_dict[current_arg_name] = current_arg_value.replace('[', '').replace(']', '').split(',') 70 | 71 | list_values = [] 72 | for val in arg_dict[current_arg_name]: 73 | if '.' in val: 74 | list_values.append(float(val)) 75 | else: 76 | try: 77 | list_values.append(int(val)) 78 | except ValueError: 79 | list_values.append(val) 80 | 81 | arg_dict[current_arg_name] = list_values 82 | ################################################################################################################ 83 | 84 | else: 85 | 86 | if '.' 
in current_arg_value: 87 | arg_dict[current_arg_name] = float(current_arg_value) 88 | else: 89 | try: 90 | arg_dict[current_arg_name] = int(current_arg_value) 91 | except ValueError: 92 | if current_arg_value == 'True': 93 | arg_dict[current_arg_name] = True 94 | elif current_arg_value == 'False': 95 | arg_dict[current_arg_name] = False 96 | else: 97 | arg_dict[current_arg_name] = current_arg_value.replace("'", '') 98 | 99 | arg_str = arg_str[comma_index + 1:] # After the current argument has been added to the dictionary, remove it from the argument string 100 | 101 | if comma_index == -1 or len(arg_str) == 0: 102 | break 103 | 104 | return arg_dict 105 | -------------------------------------------------------------------------------- /utils/settings.py: -------------------------------------------------------------------------------- 1 | """ 2 | Default settings 3 | 4 | Author: Andrew Justin (andrewjustinwx@gmail.com) 5 | Script version: 2023.7.24 6 | """ 7 | DEFAULT_DOMAIN_EXTENTS = {'global': [0, 359.75, -89.75, 90], 8 | 'full': [130, 369.75, 0.25, 80], 9 | 'conus': [228, 299.75, 25, 56.75]} # default values for extents of domains [start lon, end lon, start lat, end lat] 10 | DEFAULT_DOMAIN_INDICES = {'global': [0, 1440, 0, 720], 11 | 'full': [0, 960, 0, 320], 12 | 'conus': [392, 680, 93, 221]} # indices corresponding to default extents of domains [start lon, end lon, start lat, end lat] 13 | DEFAULT_DOMAIN_IMAGES = {'global': [17, 9], 14 | 'full': [8, 3], 15 | 'conus': [3, 1]} # default values for the number of images to use when making predictions [lon, lat] 16 | 17 | # colors for plotted ground truth fronts 18 | DEFAULT_FRONT_COLORS = {'CF': 'blue', 'WF': 'red', 'SF': 'limegreen', 'OF': 'darkviolet', 'CF-F': 'darkblue', 'WF-F': 'darkred', 19 | 'SF-F': 'darkgreen', 'OF-F': 'darkmagenta', 'CF-D': 'lightskyblue', 'WF-D': 'lightcoral', 'SF-D': 'lightgreen', 20 | 'OF-D': 'violet', 'INST': 'gold', 'TROF': 'goldenrod', 'TT': 'orange', 'DL': 'chocolate', 21 | 'MERGED-CF': 'blue', 'MERGED-WF': 'red', 'MERGED-SF': 'limegreen', 'MERGED-OF': 'darkviolet', 'MERGED-F': 'gray', 22 | 'MERGED-T': 'brown', 'F_BIN': 'tab:red', 'MERGED-F_BIN': 'tab:red'} 23 | 24 | # colormaps of probability contours for front predictions 25 | DEFAULT_CONTOUR_CMAPS = {'CF': 'Blues', 'WF': 'Reds', 'SF': 'Greens', 'OF': 'Purples', 'CF-F': 'Blues', 'WF-F': 'Reds', 26 | 'SF-F': 'Greens', 'OF-F': 'Purples', 'CF-D': 'Blues', 'WF-D': 'Reds', 'SF-D': 'Greens', 'OF-D': 'Purples', 27 | 'INST': 'YlOrBr', 'TROF': 'YlOrRd', 'TT': 'Oranges', 'DL': 'copper_r', 'MERGED-CF': 'Blues', 28 | 'MERGED-WF': 'Reds', 'MERGED-SF': 'Greens', 'MERGED-OF': 'Purples', 'MERGED-F': 'Greys', 'MERGED-T': 'YlOrBr', 29 | 'F_BIN': 'Greys', 'MERGED-F_BIN': 'Greys'} 30 | 31 | # names of front types 32 | DEFAULT_FRONT_NAMES = {'CF': 'Cold front', 'WF': 'Warm front', 'SF': 'Stationary front', 'OF': 'Occluded front', 'CF-F': 'Cold front (forming)', 33 | 'WF-F': 'Warm front (forming)', 'SF-F': 'Stationary front (forming)', 'OF-F': 'Occluded front (forming)', 34 | 'CF-D': 'Cold front (dying)', 'WF-D': 'Warm front (dying)', 'SF-D': 'Stationary front (dying)', 'OF-D': 'Occluded front (dying)', 35 | 'INST': 'Outflow boundary', 'TROF': 'Trough', 'TT': 'Tropical trough', 'DL': 'Dryline', 36 | 'MERGED-CF': 'Cold front (any)', 'MERGED-WF': 'Warm front (any)', 'MERGED-SF': 'Stationary front (any)', 'MERGED-OF': 'Occluded front (any)', 37 | 'MERGED-F': 'CF, WF, SF, OF (any)', 'MERGED-T': 'Trough (any)', 'F_BIN': 'Binary front', 'MERGED-F_BIN': 'Binary front (any)'} 38 | 39 
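A worked example of the argument-string format accepted by string_arg_to_dict above; the keys and values are arbitrary.

from utils.misc import string_arg_to_dict

changes = string_arg_to_dict("front_types=[CF,WF],batch_size=32,learning_rate=0.0001,use_bias=True")
# -> {'front_types': ['CF', 'WF'], 'batch_size': 32, 'learning_rate': 0.0001, 'use_bias': True}

As written, the parser locates the first bracket or parenthesis anywhere in the remaining string, so list and tuple values are safest placed at the front of the argument string.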
| """ 40 | TIMESTEP_PREDICT_SIZE is the number of timesteps for which predictions will be processed at the same time. In other words, 41 | if this parameter is 10, then up to 10 maps will be generated at the same time for 10 timesteps. 42 | 43 | Typically, raising this parameter will result in faster predictions, but the memory requirements increase as well. The size 44 | of the domain for which the predictions are being generated greatly affects the limits of this parameter. 45 | 46 | NOTES: 47 | - Setting the values at a lower threshold may result in slower predictions but will not have negative effects on hardware. 48 | - Increasing the parameters above the default values is STRONGLY discouraged. Greatly exceeding the allocated RAM 49 | will force your operating system to resort to virtual memory usage, which can cause major slowdowns, OS crashes, and GPU failure. 50 | 51 | In the case of any hardware failures (including OOM errors from the GPU) or major slowdowns, reduce the parameter(s) for the 52 | domain(s) until your system becomes stable. 53 | """ 54 | TIMESTEP_PREDICT_SIZE = {'conus': 128, 'full': 64, 'global': 16} # All values here are adjusted for 16 GB of system RAM and 10 GB of GPU VRAM 55 | 56 | """ 57 | GPU_PREDICT_BATCH_SIZE is the number of images that the GPU will process at one time when generating predictions. 58 | If predictions are being generated on the same GPU that the model was trained on, then this value should be equal to or greater than 59 | the batch size used when training the model. 60 | 61 | NOTES: 62 | - This value should ideally be a value of 2^n, where n is any integer. Using values not equal to 2^n may have negative effects 63 | on performance. 64 | - Decreasing this parameter will result in overall slower performance, but can help prevent OOM errors on the GPU. 65 | - Increasing this parameter can speed up predictions on high-performance GPUs, but a value too large can cause OOM errors 66 | and GPU failure. 67 | """ 68 | GPU_PREDICT_BATCH_SIZE = 8 69 | 70 | """ 71 | MAX_FILE_CHUNK_SIZE is the maximum number of ERA5, GDAS, and/or GFS netCDF files that will be loaded into an xarray dataset at one 72 | time. Loading too many files / too much data into one xarray dataset can take a very long time and may lead to segmentation errors. 73 | If segmentation errors are occurring, consider lowering this parameter until the error disappears. 74 | """ 75 | MAX_FILE_CHUNK_SIZE = 2500 76 | 77 | """ 78 | MAX_TRAIN_BUFFER_SIZE is the maximum number of elements within the training dataset that can be shuffled at one time. Tensorflow 79 | does not efficiently use RAM during shuffling on Windows machines and can lead to system crashes, so the buffer size should be 80 | relatively small. It is important to monitor RAM usage if you are training a model on Windows. Linux is able to shuffle much 81 | larger datasets than Windows, but crashes can still occur if the maximum buffer size is too large. 82 | """ 83 | MAX_TRAIN_BUFFER_SIZE = 5000 84 | -------------------------------------------------------------------------------- /evaluation/calibrate_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Calibrate a trained model. 
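A minimal sketch of adjusting the memory-related settings above for a smaller machine; the values are illustrative only and should stay within the guidance in the comments.

from utils import settings

settings.TIMESTEP_PREDICT_SIZE['global'] = 8   # fewer timesteps processed at once
settings.GPU_PREDICT_BATCH_SIZE = 4            # keep this a power of two
settings.MAX_TRAIN_BUFFER_SIZE = 2000          # smaller shuffle buffer for limited RAM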
3 | 4 | Author: Andrew Justin (andrewjustinwx@gmail.com) 5 | Script version: 2023.6.24.D1 6 | """ 7 | import argparse 8 | import pandas as pd 9 | import os 10 | import sys 11 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))) # this line allows us to import scripts outside the current directory 12 | from utils.settings import DEFAULT_FRONT_NAMES 13 | import matplotlib.pyplot as plt 14 | import pickle 15 | import xarray as xr 16 | import numpy as np 17 | from sklearn.isotonic import IsotonicRegression 18 | from sklearn.metrics import r2_score 19 | 20 | 21 | if __name__ == '__main__': 22 | """ 23 | All arguments listed in the examples are listed via argparse in alphabetical order below this comment block. 24 | """ 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument('--dataset', type=str, help="Dataset for which to make predictions if prediction_method is 'random' or 'all'. Options are:" 27 | "'training', 'validation', 'test'") 28 | parser.add_argument('--domain', type=str, help='Domain of the data.') 29 | parser.add_argument('--model_dir', type=str, help='Directory for the models.') 30 | parser.add_argument('--model_number', type=int, help='Model number.') 31 | parser.add_argument('--data_source', type=str, default='era5', help='Data source for variables') 32 | 33 | args = vars(parser.parse_args()) 34 | 35 | model_properties = pd.read_pickle('%s/model_%d/model_%d_properties.pkl' % (args['model_dir'], args['model_number'], args['model_number'])) 36 | 37 | ### front_types argument is being moved into the dataset_properties dictionary within model_properties ### 38 | try: 39 | front_types = model_properties['front_types'] 40 | except KeyError: 41 | front_types = model_properties['dataset_properties']['front_types'] 42 | 43 | if type(front_types) == str: 44 | front_types = [front_types, ] 45 | 46 | try: 47 | _ = model_properties['calibration_models'] # Check to see if the model has already been calibrated before 48 | except KeyError: 49 | model_properties['calibration_models'] = dict() 50 | 51 | model_properties['calibration_models'][args['domain']] = dict() 52 | 53 | stats_ds = xr.open_dataset('%s/model_%d/statistics/model_%d_statistics_%s_%s.nc' % (args['model_dir'], args['model_number'], args['model_number'], args['domain'], args['dataset'])) 54 | 55 | axis_ticks = np.arange(0.1, 1.1, 0.1) 56 | 57 | for front_label in front_types: 58 | 59 | model_properties['calibration_models'][args['domain']][front_label] = dict() 60 | 61 | true_positives = stats_ds[f'tp_temporal_{front_label}'].values 62 | false_positives = stats_ds[f'fp_temporal_{front_label}'].values 63 | 64 | thresholds = stats_ds['threshold'].values 65 | 66 | ### Sum the true positives along the 'time' axis ### 67 | true_positives_sum = np.sum(true_positives, axis=0) 68 | false_positives_sum = np.sum(false_positives, axis=0) 69 | 70 | ### Find the number of true positives and false positives in each probability bin ### 71 | true_positives_diff = np.abs(np.diff(true_positives_sum)) 72 | false_positives_diff = np.abs(np.diff(false_positives_sum)) 73 | observed_relative_frequency = np.divide(true_positives_diff, true_positives_diff + false_positives_diff) 74 | 75 | boundary_colors = ['red', 'purple', 'brown', 'darkorange', 'darkgreen'] 76 | 77 | calibrated_probabilities = [] 78 | 79 | fig, axs = plt.subplots(1, 2, figsize=(14, 6)) 80 | axs[0].plot(thresholds, thresholds, color='black', linestyle='--', linewidth=0.5, label='Perfect Reliability') 81 | 82 | for boundary, color in 
enumerate(boundary_colors): 83 | 84 | ####################### Test different calibration methods to see which performs best ###################### 85 | 86 | x = [threshold for threshold, frequency in zip(thresholds[1:], observed_relative_frequency[boundary]) if not np.isnan(frequency)] 87 | y = [frequency for threshold, frequency in zip(thresholds[1:], observed_relative_frequency[boundary]) if not np.isnan(frequency)] 88 | 89 | ### Isotonic Regression ### 90 | ir = IsotonicRegression(out_of_bounds='clip') 91 | ir.fit_transform(x, y) 92 | calibrated_probabilities.append(ir.predict(x)) 93 | r_squared = r2_score(y, calibrated_probabilities[boundary]) 94 | 95 | axs[0].plot(x, y, color=color, linewidth=1, label='%d km' % ((boundary + 1) * 50)) 96 | axs[1].plot(x, calibrated_probabilities[boundary], color=color, linestyle='--', linewidth=1, label=r'%d km ($R^2$ = %.3f)' % ((boundary + 1) * 50, r_squared)) 97 | model_properties['calibration_models'][args['domain']][front_label]['%d km' % ((boundary + 1) * 50)] = ir 98 | 99 | for ax in axs: 100 | 101 | axs[0].set_xlabel("Forecast Probability (uncalibrated)") 102 | ax.set_xticks(axis_ticks) 103 | ax.set_yticks(axis_ticks) 104 | ax.set_xlim(0, 1) 105 | ax.set_ylim(0, 1) 106 | ax.grid() 107 | ax.legend() 108 | 109 | axs[0].set_title('Reliability Diagram') 110 | axs[1].set_title('Calibration (isotonic regression)') 111 | axs[0].set_ylabel("Observed Relative Frequency") 112 | axs[1].set_ylabel("Forecast Probability (calibrated)") 113 | 114 | with open('%s/model_%d/model_%d_properties.pkl' % (args['model_dir'], args['model_number'], args['model_number']), 'wb') as f: 115 | pickle.dump(model_properties, f) 116 | 117 | plt.suptitle(f"Model {args['model_number']} reliability/calibration: {DEFAULT_FRONT_NAMES[front_label]}") 118 | plt.savefig(f'%s/model_%d/model_%d_calibration_%s_%s.png' % (args['model_dir'], args['model_number'], args['model_number'], args['domain'], front_label), 119 | bbox_inches='tight', dpi=300) 120 | plt.close() 121 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Creative Commons Legal Code 2 | 3 | CC0 1.0 Universal 4 | 5 | CREATIVE COMMONS CORPORATION IS NOT A LAW FIRM AND DOES NOT PROVIDE 6 | LEGAL SERVICES. DISTRIBUTION OF THIS DOCUMENT DOES NOT CREATE AN 7 | ATTORNEY-CLIENT RELATIONSHIP. CREATIVE COMMONS PROVIDES THIS 8 | INFORMATION ON AN "AS-IS" BASIS. CREATIVE COMMONS MAKES NO WARRANTIES 9 | REGARDING THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS 10 | PROVIDED HEREUNDER, AND DISCLAIMS LIABILITY FOR DAMAGES RESULTING FROM 11 | THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS PROVIDED 12 | HEREUNDER. 13 | 14 | Statement of Purpose 15 | 16 | The laws of most jurisdictions throughout the world automatically confer 17 | exclusive Copyright and Related Rights (defined below) upon the creator 18 | and subsequent owner(s) (each and all, an "owner") of an original work of 19 | authorship and/or a database (each, a "Work"). 20 | 21 | Certain owners wish to permanently relinquish those rights to a Work for 22 | the purpose of contributing to a commons of creative, cultural and 23 | scientific works ("Commons") that the public can reliably and without fear 24 | of later claims of infringement build upon, modify, incorporate in other 25 | works, reuse and redistribute as freely as possible in any form whatsoever 26 | and for any purposes, including without limitation commercial purposes. 
27 | These owners may contribute to the Commons to promote the ideal of a free 28 | culture and the further production of creative, cultural and scientific 29 | works, or to gain reputation or greater distribution for their Work in 30 | part through the use and efforts of others. 31 | 32 | For these and/or other purposes and motivations, and without any 33 | expectation of additional consideration or compensation, the person 34 | associating CC0 with a Work (the "Affirmer"), to the extent that he or she 35 | is an owner of Copyright and Related Rights in the Work, voluntarily 36 | elects to apply CC0 to the Work and publicly distribute the Work under its 37 | terms, with knowledge of his or her Copyright and Related Rights in the 38 | Work and the meaning and intended legal effect of CC0 on those rights. 39 | 40 | 1. Copyright and Related Rights. A Work made available under CC0 may be 41 | protected by copyright and related or neighboring rights ("Copyright and 42 | Related Rights"). Copyright and Related Rights include, but are not 43 | limited to, the following: 44 | 45 | i. the right to reproduce, adapt, distribute, perform, display, 46 | communicate, and translate a Work; 47 | ii. moral rights retained by the original author(s) and/or performer(s); 48 | iii. publicity and privacy rights pertaining to a person's image or 49 | likeness depicted in a Work; 50 | iv. rights protecting against unfair competition in regards to a Work, 51 | subject to the limitations in paragraph 4(a), below; 52 | v. rights protecting the extraction, dissemination, use and reuse of data 53 | in a Work; 54 | vi. database rights (such as those arising under Directive 96/9/EC of the 55 | European Parliament and of the Council of 11 March 1996 on the legal 56 | protection of databases, and under any national implementation 57 | thereof, including any amended or successor version of such 58 | directive); and 59 | vii. other similar, equivalent or corresponding rights throughout the 60 | world based on applicable law or treaty, and any national 61 | implementations thereof. 62 | 63 | 2. Waiver. To the greatest extent permitted by, but not in contravention 64 | of, applicable law, Affirmer hereby overtly, fully, permanently, 65 | irrevocably and unconditionally waives, abandons, and surrenders all of 66 | Affirmer's Copyright and Related Rights and associated claims and causes 67 | of action, whether now known or unknown (including existing as well as 68 | future claims and causes of action), in the Work (i) in all territories 69 | worldwide, (ii) for the maximum duration provided by applicable law or 70 | treaty (including future time extensions), (iii) in any current or future 71 | medium and for any number of copies, and (iv) for any purpose whatsoever, 72 | including without limitation commercial, advertising or promotional 73 | purposes (the "Waiver"). Affirmer makes the Waiver for the benefit of each 74 | member of the public at large and to the detriment of Affirmer's heirs and 75 | successors, fully intending that such Waiver shall not be subject to 76 | revocation, rescission, cancellation, termination, or any other legal or 77 | equitable action to disrupt the quiet enjoyment of the Work by the public 78 | as contemplated by Affirmer's express Statement of Purpose. 79 | 80 | 3. Public License Fallback. 
Should any part of the Waiver for any reason 81 | be judged legally invalid or ineffective under applicable law, then the 82 | Waiver shall be preserved to the maximum extent permitted taking into 83 | account Affirmer's express Statement of Purpose. In addition, to the 84 | extent the Waiver is so judged Affirmer hereby grants to each affected 85 | person a royalty-free, non transferable, non sublicensable, non exclusive, 86 | irrevocable and unconditional license to exercise Affirmer's Copyright and 87 | Related Rights in the Work (i) in all territories worldwide, (ii) for the 88 | maximum duration provided by applicable law or treaty (including future 89 | time extensions), (iii) in any current or future medium and for any number 90 | of copies, and (iv) for any purpose whatsoever, including without 91 | limitation commercial, advertising or promotional purposes (the 92 | "License"). The License shall be deemed effective as of the date CC0 was 93 | applied by Affirmer to the Work. Should any part of the License for any 94 | reason be judged legally invalid or ineffective under applicable law, such 95 | partial invalidity or ineffectiveness shall not invalidate the remainder 96 | of the License, and in such case Affirmer hereby affirms that he or she 97 | will not (i) exercise any of his or her remaining Copyright and Related 98 | Rights in the Work or (ii) assert any associated claims and causes of 99 | action with respect to the Work, in either case contrary to Affirmer's 100 | express Statement of Purpose. 101 | 102 | 4. Limitations and Disclaimers. 103 | 104 | a. No trademark or patent rights held by Affirmer are waived, abandoned, 105 | surrendered, licensed or otherwise affected by this document. 106 | b. Affirmer offers the Work as-is and makes no representations or 107 | warranties of any kind concerning the Work, express, implied, 108 | statutory or otherwise, including without limitation warranties of 109 | title, merchantability, fitness for a particular purpose, non 110 | infringement, or the absence of latent or other defects, accuracy, or 111 | the present or absence of errors, whether or not discoverable, all to 112 | the greatest extent permissible under applicable law. 113 | c. Affirmer disclaims responsibility for clearing rights of other persons 114 | that may apply to the Work or any use thereof, including without 115 | limitation any person's Copyright and Related Rights in the Work. 116 | Further, Affirmer disclaims responsibility for obtaining any necessary 117 | consents, permissions or other rights required for any use of the 118 | Work. 119 | d. Affirmer understands and acknowledges that Creative Commons is not a 120 | party to this document and has no duty or obligation with respect to 121 | this CC0 or use of the Work. 122 | -------------------------------------------------------------------------------- /custom_metrics.py: -------------------------------------------------------------------------------- 1 | """ 2 | Custom metrics for U-Net models. 3 | - Brier Skill Score (BSS) 4 | - Critical Success Index (CSI) 5 | - Fractions Skill Score (FSS) 6 | 7 | Author: Andrew Justin (andrewjustinwx@gmail.com) 8 | Script version: 2023.9.21 9 | """ 10 | import tensorflow as tf 11 | 12 | 13 | def brier_skill_score(class_weights: list[int | float] = None): 14 | """ 15 | Brier skill score (BSS). 16 | 17 | class_weights: list of values or None 18 | List of weights to apply to each class. The length must be equal to the number of classes in y_pred and y_true. 
19 |     """
20 | 
21 |     @tf.function
22 |     def bss(y_true, y_pred):
23 |         """
24 |         y_true: tf.Tensor
25 |             One-hot encoded tensor containing labels.
26 |         y_pred: tf.Tensor
27 |             Tensor containing model predictions.
28 |         """
29 | 
30 |         squared_errors = tf.math.square(tf.subtract(y_true, y_pred))
31 | 
32 |         if class_weights is not None:
33 |             relative_class_weights = tf.cast(class_weights / tf.math.reduce_sum(class_weights), tf.float32)
34 |             squared_errors *= relative_class_weights
35 | 
36 |         return 1 - tf.math.reduce_sum(squared_errors) / tf.cast(tf.size(squared_errors), tf.float32)  # cast the element count to float so the division does not raise a dtype error
37 | 
38 |     return bss
39 | 
40 | 
41 | def critical_success_index(threshold: float = None, class_weights: list[int | float] = None):
42 |     """
43 |     Critical success index (CSI).
44 | 
45 |     y_true: tf.Tensor
46 |         One-hot encoded tensor containing labels.
47 |     y_pred: tf.Tensor
48 |         Tensor containing model predictions.
49 |     threshold: float or None
50 |         Optional probability threshold that binarizes y_pred. Values in y_pred greater than or equal to the threshold are
51 |         set to 1, and 0 otherwise.
52 |         If the threshold is set, it must be greater than 0 and less than 1.
53 |     class_weights: list of values or None
54 |         List of weights to apply to each class. The length must be equal to the number of classes in y_pred and y_true.
55 |     """
56 | 
57 |     @tf.function
58 |     def csi(y_true, y_pred):
59 |         """
60 |         y_true: tf.Tensor
61 |             One-hot encoded tensor containing labels.
62 |         y_pred: tf.Tensor
63 |             Tensor containing model predictions.
64 |         """
65 | 
66 |         if threshold is not None:
67 |             y_pred = tf.where(y_pred >= threshold, 1.0, 0.0)
68 | 
69 |         y_pred_neg = 1 - y_pred
70 |         y_true_neg = 1 - y_true
71 | 
72 |         sum_over_axes = tf.range(tf.rank(y_pred) - 1)  # Indices for axes to sum over. Excludes the final (class) dimension.
73 | 
74 |         true_positives = tf.math.reduce_sum(y_pred * y_true, axis=sum_over_axes)
75 |         false_negatives = tf.math.reduce_sum(y_pred_neg * y_true, axis=sum_over_axes)
76 |         false_positives = tf.math.reduce_sum(y_pred * y_true_neg, axis=sum_over_axes)
77 | 
78 |         if class_weights is not None:
79 |             relative_class_weights = tf.cast(class_weights / tf.math.reduce_sum(class_weights), tf.float32)
80 |             csi = tf.math.reduce_sum(tf.math.divide_no_nan(true_positives, true_positives + false_positives + false_negatives) * relative_class_weights)
81 |         else:
82 |             csi = tf.math.divide(tf.math.reduce_sum(true_positives), tf.math.reduce_sum(true_positives) + tf.math.reduce_sum(false_negatives) + tf.math.reduce_sum(false_positives))
83 | 
84 |         return csi
85 | 
86 |     return csi
87 | 
88 | 
89 | def fractions_skill_score(
90 |         num_dims: int,
91 |         mask_size: int = 3,
92 |         c: float = 1.0,
93 |         cutoff: float = 0.5,
94 |         want_hard_discretization: bool = False,
95 |         class_weights: list[int | float] = None):
96 |     """
97 |     Fractions skill score loss function. Visit https://github.com/CIRA-ML/custom_loss_functions for documentation.
98 | 
99 |     Parameters
100 |     ----------
101 |     num_dims: int
102 |         Number of dimensions for the mask.
103 |     mask_size: int or tuple
104 |         Size of the mask/pool in the AveragePooling layers.
105 |     c: int or float
106 |         C parameter in the sigmoid function. This will only be used if 'want_hard_discretization' is False.
107 |     cutoff: float
108 |         If 'want_hard_discretization' is True, y_true and y_pred will be discretized to only have binary values (0/1)
109 |     want_hard_discretization: bool
110 |         If True, y_true and y_pred will be discretized to only have binary values (0/1).
111 |         If False, y_true and y_pred will be discretized using a sigmoid function.
112 | class_weights: list of values or None 113 | List of weights to apply to each class. The length must be equal to the number of classes in y_pred and y_true. 114 | 115 | Returns 116 | ------- 117 | fractions_skill_score: float 118 | Fractions skill score. 119 | """ 120 | 121 | pool_kwargs = {'pool_size': (mask_size, ) * num_dims, 122 | 'strides': (1, ) * num_dims, 123 | 'padding': 'valid'} 124 | 125 | if num_dims == 2: 126 | pool1 = tf.keras.layers.AveragePooling2D(**pool_kwargs) 127 | pool2 = tf.keras.layers.AveragePooling2D(**pool_kwargs) 128 | else: 129 | pool1 = tf.keras.layers.AveragePooling3D(**pool_kwargs) 130 | pool2 = tf.keras.layers.AveragePooling3D(**pool_kwargs) 131 | 132 | @tf.function 133 | def fss(y_true, y_pred): 134 | """ 135 | y_true: tf.Tensor 136 | One-hot encoded tensor containing labels. 137 | y_pred: tf.Tensor 138 | Tensor containing model predictions. 139 | """ 140 | 141 | if want_hard_discretization: 142 | y_true_binary = tf.where(y_true > cutoff, 1.0, 0.0) 143 | y_pred_binary = tf.where(y_pred > cutoff, 1.0, 0.0) 144 | else: 145 | y_true_binary = tf.math.sigmoid(c * (y_true - cutoff)) 146 | y_pred_binary = tf.math.sigmoid(c * (y_pred - cutoff)) 147 | 148 | y_true_density = pool1(y_true_binary) 149 | n_density_pixels = tf.cast((tf.shape(y_true_density)[1] * tf.shape(y_true_density)[2]), tf.float32) 150 | 151 | y_pred_density = pool2(y_pred_binary) 152 | 153 | if class_weights is None: 154 | MSE_n = tf.keras.metrics.mean_squared_error(y_true_density, y_pred_density) 155 | else: 156 | relative_class_weights = tf.cast(class_weights / tf.math.reduce_sum(class_weights), tf.float32) 157 | MSE_n = tf.reduce_mean(tf.math.square(y_true_density - y_pred_density) * relative_class_weights, axis=-1) 158 | 159 | O_n_squared_image = tf.keras.layers.Multiply()([y_true_density, y_true_density]) 160 | O_n_squared_vector = tf.keras.layers.Flatten()(O_n_squared_image) 161 | O_n_squared_sum = tf.reduce_sum(O_n_squared_vector) 162 | 163 | M_n_squared_image = tf.keras.layers.Multiply()([y_pred_density, y_pred_density]) 164 | M_n_squared_vector = tf.keras.layers.Flatten()(M_n_squared_image) 165 | M_n_squared_sum = tf.reduce_sum(M_n_squared_vector) 166 | 167 | MSE_n_ref = (O_n_squared_sum + M_n_squared_sum) / n_density_pixels 168 | 169 | my_epsilon = tf.keras.backend.epsilon() # this is 10^(-7) 170 | 171 | if want_hard_discretization: 172 | if MSE_n_ref == 0: 173 | return 1 - MSE_n 174 | else: 175 | return 1 - (MSE_n / MSE_n_ref) 176 | else: 177 | return 1 - (MSE_n / (MSE_n_ref + my_epsilon)) 178 | 179 | return fss 180 | -------------------------------------------------------------------------------- /custom_losses.py: -------------------------------------------------------------------------------- 1 | """ 2 | Custom loss functions for U-Net models. 3 | - Brier Skill Score (BSS) 4 | - Critical Success Index (CSI) 5 | - Fractions Skill Score (FSS) 6 | 7 | Author: Andrew Justin (andrewjustinwx@gmail.com) 8 | Script version: 2023.5.20.D1 9 | """ 10 | import tensorflow as tf 11 | 12 | 13 | def brier_skill_score(class_weights: list = None): 14 | """ 15 | Brier skill score (BSS) loss function. 16 | 17 | class_weights: list of values or None 18 | List of weights to apply to each class. The length must be equal to the number of classes in y_pred and y_true. 19 | """ 20 | 21 | @tf.function 22 | def bss_loss(y_true, y_pred): 23 | """ 24 | y_true: tf.Tensor 25 | One-hot encoded tensor containing labels. 26 | y_pred: tf.Tensor 27 | Tensor containing model predictions. 
28 |         """
29 | 
30 |         losses = tf.math.square(tf.subtract(y_true, y_pred))
31 | 
32 |         if class_weights is not None:
33 |             relative_class_weights = tf.cast(class_weights / tf.math.reduce_sum(class_weights), tf.float32)
34 |             losses *= relative_class_weights
35 | 
36 |         brier_score_loss = tf.math.reduce_sum(losses) / tf.cast(tf.size(losses), tf.float32)  # cast the element count to float so the division does not raise a dtype error
37 |         return brier_score_loss
38 | 
39 |     return bss_loss
40 | 
41 | 
42 | def critical_success_index(threshold: float = None,
43 |                            class_weights: list[int | float] = None):
44 |     """
45 |     Critical Success Index (CSI) loss function.
46 | 
47 |     y_true: tf.Tensor
48 |         One-hot encoded tensor containing labels.
49 |     y_pred: tf.Tensor
50 |         Tensor containing model predictions.
51 |     threshold: float or None
52 |         Optional probability threshold that binarizes y_pred. Values in y_pred greater than or equal to the threshold are
53 |         set to 1, and 0 otherwise.
54 |         If the threshold is set, it must be greater than 0 and less than 1.
55 |     class_weights: list of values or None
56 |         List of weights to apply to each class. The length must be equal to the number of classes in y_pred and y_true.
57 |     """
58 | 
59 |     @tf.function
60 |     def csi_loss(y_true, y_pred):
61 |         """
62 |         y_true: tf.Tensor
63 |             One-hot encoded tensor containing labels.
64 |         y_pred: tf.Tensor
65 |             Tensor containing model predictions.
66 |         """
67 | 
68 |         if threshold is not None:
69 |             y_pred = tf.where(y_pred >= threshold, 1.0, 0.0)  # use float values so y_pred keeps the same dtype as y_true
70 | 
71 |         y_pred_neg = 1 - y_pred
72 |         y_true_neg = 1 - y_true
73 | 
74 |         sum_over_axes = tf.range(tf.rank(y_pred) - 1)  # Indices for axes to sum over. Excludes the final (class) dimension.
75 | 
76 |         true_positives = tf.math.reduce_sum(y_pred * y_true, axis=sum_over_axes)
77 |         false_negatives = tf.math.reduce_sum(y_pred_neg * y_true, axis=sum_over_axes)
78 |         false_positives = tf.math.reduce_sum(y_pred * y_true_neg, axis=sum_over_axes)
79 | 
80 |         if class_weights is not None:
81 |             relative_class_weights = tf.cast(class_weights / tf.math.reduce_sum(class_weights), tf.float32)
82 |             csi = tf.math.reduce_sum(tf.math.divide_no_nan(true_positives, true_positives + false_positives + false_negatives) * relative_class_weights)
83 |         else:
84 |             csi = tf.math.divide(tf.math.reduce_sum(true_positives), tf.math.reduce_sum(true_positives) + tf.math.reduce_sum(false_negatives) + tf.math.reduce_sum(false_positives))
85 | 
86 |         return 1 - csi
87 | 
88 |     return csi_loss
89 | 
90 | 
91 | def fractions_skill_score(
92 |         num_dims: int,
93 |         mask_size: int = 3,
94 |         c: float = 1.0,
95 |         cutoff: float = 0.5,
96 |         want_hard_discretization: bool = False,
97 |         class_weights: list[int | float] = None):
98 |     """
99 |     Fractions skill score loss function. Visit https://github.com/CIRA-ML/custom_loss_functions for documentation.
100 | 
101 |     Parameters
102 |     ----------
103 |     num_dims: int
104 |         Number of dimensions for the mask.
105 |     mask_size: int or tuple
106 |         Size of the mask/pool in the AveragePooling layers.
107 |     c: int or float
108 |         C parameter in the sigmoid function. This will only be used if 'want_hard_discretization' is False.
109 |     cutoff: float
110 |         If 'want_hard_discretization' is True, y_true and y_pred will be discretized to only have binary values (0/1)
111 |     want_hard_discretization: bool
112 |         If True, y_true and y_pred will be discretized to only have binary values (0/1).
113 |         If False, y_true and y_pred will be discretized using a sigmoid function.
114 |     class_weights: list of values or None
115 |         List of weights to apply to each class. The length must be equal to the number of classes in y_pred and y_true.
116 | 117 | Returns 118 | ------- 119 | fractions_skill_score: float 120 | Fractions skill score. 121 | """ 122 | 123 | pool_kwargs = {'pool_size': (mask_size, ) * num_dims, 124 | 'strides': (1, ) * num_dims, 125 | 'padding': 'valid'} 126 | 127 | if num_dims == 2: 128 | pool1 = tf.keras.layers.AveragePooling2D(**pool_kwargs) 129 | pool2 = tf.keras.layers.AveragePooling2D(**pool_kwargs) 130 | else: 131 | pool1 = tf.keras.layers.AveragePooling3D(**pool_kwargs) 132 | pool2 = tf.keras.layers.AveragePooling3D(**pool_kwargs) 133 | 134 | @tf.function 135 | def fss_loss(y_true, y_pred): 136 | """ 137 | y_true: tf.Tensor 138 | One-hot encoded tensor containing labels. 139 | y_pred: tf.Tensor 140 | Tensor containing model predictions. 141 | """ 142 | 143 | if want_hard_discretization: 144 | y_true_binary = tf.where(y_true > cutoff, 1.0, 0.0) 145 | y_pred_binary = tf.where(y_pred > cutoff, 1.0, 0.0) 146 | else: 147 | y_true_binary = tf.math.sigmoid(c * (y_true - cutoff)) 148 | y_pred_binary = tf.math.sigmoid(c * (y_pred - cutoff)) 149 | 150 | y_true_density = pool1(y_true_binary) 151 | n_density_pixels = tf.cast((tf.shape(y_true_density)[1] * tf.shape(y_true_density)[2]), tf.float32) 152 | 153 | y_pred_density = pool2(y_pred_binary) 154 | 155 | if class_weights is None: 156 | MSE_n = tf.keras.metrics.mean_squared_error(y_true_density, y_pred_density) 157 | else: 158 | relative_class_weights = tf.cast(class_weights / tf.math.reduce_sum(class_weights), tf.float32) 159 | MSE_n = tf.reduce_mean(tf.math.square(y_true_density - y_pred_density) * relative_class_weights, axis=-1) 160 | 161 | O_n_squared_image = tf.keras.layers.Multiply()([y_true_density, y_true_density]) 162 | O_n_squared_vector = tf.keras.layers.Flatten()(O_n_squared_image) 163 | O_n_squared_sum = tf.reduce_sum(O_n_squared_vector) 164 | 165 | M_n_squared_image = tf.keras.layers.Multiply()([y_pred_density, y_pred_density]) 166 | M_n_squared_vector = tf.keras.layers.Flatten()(M_n_squared_image) 167 | M_n_squared_sum = tf.reduce_sum(M_n_squared_vector) 168 | 169 | MSE_n_ref = (O_n_squared_sum + M_n_squared_sum) / n_density_pixels 170 | 171 | my_epsilon = tf.keras.backend.epsilon() # this is 10^(-7) 172 | 173 | if want_hard_discretization: 174 | if MSE_n_ref == 0: 175 | return MSE_n 176 | else: 177 | return MSE_n / MSE_n_ref 178 | else: 179 | return MSE_n / (MSE_n_ref + my_epsilon) 180 | 181 | return fss_loss 182 | -------------------------------------------------------------------------------- /download_grib_files.py: -------------------------------------------------------------------------------- 1 | """ 2 | Download grib files for GDAS and/or GFS data. 3 | 4 | Author: Andrew Justin (andrewjustinwx@gmail.com) 5 | Script version: 2023.8.23 6 | """ 7 | 8 | import argparse 9 | import os 10 | import pandas as pd 11 | import requests 12 | import urllib.error 13 | import wget 14 | import sys 15 | import datetime 16 | 17 | 18 | def bar_progress(current, total, width=None): 19 | progress_message = "Downloading %s: %d%% [%d/%d] MB " % (local_filename, current / total * 100, current / 1e6, total / 1e6) 20 | sys.stdout.write("\r" + progress_message) 21 | 22 | 23 | if __name__ == "__main__": 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument('--grib_outdir', type=str, required=True, help='Output directory for GDAS grib files downloaded from NCEP.') 26 | parser.add_argument('--model', type=str, required=True, help="NWP model to use as the data source.") 27 | parser.add_argument('--init_time', type=str, help="Initialization time of the model. 
Format: YYYY-MM-DD-HH.") 28 | parser.add_argument('--range', type=str, nargs=3, 29 | help="Download model data between a range of dates. Three arguments must be passed, with the first two arguments " 30 | "marking the bounds of the date range in the format YYYY-MM-DD-HH. The third argument is the frequency (e.g. 6H), " 31 | "which has the same formatting as the 'freq' keyword argument in pandas.date_range().") 32 | parser.add_argument('--forecast_hours', type=int, nargs="+", required=True, help="List of forecast hours to download for the given day.") 33 | parser.add_argument('--verbose', action='store_true', help="Include a progress bar for download progress.") 34 | 35 | args = vars(parser.parse_args()) 36 | 37 | args['model'] = args['model'].lower() 38 | 39 | # If --verbose is passed, include a progress bar to show the download progress 40 | bar = bar_progress if args['verbose'] else None 41 | 42 | if args['init_time'] is not None and args['range'] is not None: 43 | raise ValueError("Only one of the following arguments can be passed: --init_time, --range") 44 | elif args['init_time'] is None and args['range'] is None: 45 | raise ValueError("One of the following arguments must be passed: --init_time, --range") 46 | 47 | init_times = pd.date_range(args['init_time'], args['init_time']) if args['init_time'] is not None else pd.date_range(*args['range'][:2], freq=args['range'][-1]) 48 | 49 | files = [] # complete urls for the files to pull from AWS 50 | local_filenames = [] # filenames for the local files after downloading 51 | 52 | for init_time in init_times: 53 | if args['model'] == 'gdas': 54 | if datetime.datetime(init_time.year, init_time.month, init_time.day, init_time.hour) < datetime.datetime(2015, 6, 23, 0): 55 | raise ConnectionAbortedError("Cannot download GDAS data prior to June 23, 2015.") 56 | elif datetime.datetime(init_time.year, init_time.month, init_time.day, init_time.hour) < datetime.datetime(2017, 7, 20, 0): 57 | [files.append(f'https://noaa-gfs-bdp-pds.s3.amazonaws.com/gdas.%d%02d%02d/%02d/gdas1.t%02dz.pgrb2.0p25.f%03d' % (init_time.year, init_time.month, init_time.day, init_time.hour, init_time.hour, forecast_hour)) 58 | for forecast_hour in args['forecast_hours']] 59 | elif init_time.year < 2021: 60 | [files.append(f'https://noaa-gfs-bdp-pds.s3.amazonaws.com/gdas.%d%02d%02d/%02d/gdas.t%02dz.pgrb2.0p25.f%03d' % (init_time.year, init_time.month, init_time.day, init_time.hour, init_time.hour, forecast_hour)) 61 | for forecast_hour in args['forecast_hours']] 62 | else: 63 | [files.append(f"https://noaa-gfs-bdp-pds.s3.amazonaws.com/gdas.%d%02d%02d/%02d/atmos/gdas.t%02dz.pgrb2.0p25.f%03d" % (init_time.year, init_time.month, init_time.day, init_time.hour, init_time.hour, forecast_hour)) 64 | for forecast_hour in args['forecast_hours']] 65 | elif args['model'] == 'gfs': 66 | if datetime.datetime(init_time.year, init_time.month, init_time.day, init_time.hour) < datetime.datetime(2021, 2, 26, 0): 67 | raise ConnectionAbortedError("Cannot download GFS data prior to February 26, 2021.") 68 | elif datetime.datetime(init_time.year, init_time.month, init_time.day, init_time.hour) < datetime.datetime(2021, 3, 22, 0): 69 | [files.append(f"https://noaa-gfs-bdp-pds.s3.amazonaws.com/gfs.%d%02d%02d/%02d/gfs.t%02dz.pgrb2.0p25.f%03d" % (init_time.year, init_time.month, init_time.day, init_time.hour, init_time.hour, forecast_hour)) 70 | for forecast_hour in args['forecast_hours']] 71 | else: 72 | 
[files.append(f"https://noaa-gfs-bdp-pds.s3.amazonaws.com/gfs.%d%02d%02d/%02d/atmos/gfs.t%02dz.pgrb2.0p25.f%03d" % (init_time.year, init_time.month, init_time.day, init_time.hour, init_time.hour, forecast_hour)) 73 | for forecast_hour in args['forecast_hours']] 74 | elif args['model'] == 'hrrr': 75 | [files.append(f"https://noaa-hrrr-bdp-pds.s3.amazonaws.com/hrrr.%d%02d%02d/conus/hrrr.t%02dz.wrfprsf%02d.grib2" % (init_time.year, init_time.month, init_time.day, init_time.hour, forecast_hour)) 76 | for forecast_hour in args['forecast_hours']] 77 | elif args['model'] == 'rap': 78 | [files.append(f"https://noaa-rap-pds.s3.amazonaws.com/rap.%d%02d%02d/rap.t%02dz.wrfprsf%02d.grib2" % (init_time.year, init_time.month, init_time.day, init_time.hour, forecast_hour)) 79 | for forecast_hour in args['forecast_hours']] 80 | elif 'namnest' in args['model']: 81 | nest = args['model'].split('_')[-1] 82 | [files.append(f"https://nomads.ncep.noaa.gov/pub/data/nccf/com/nam/prod/nam.%d%02d%02d/nam.t%02dz.%snest.hiresf%02d.tm00.grib2" % (init_time.year, init_time.month, init_time.day, init_time.hour, nest, forecast_hour)) 83 | for forecast_hour in args['forecast_hours']] 84 | elif args['model'] == 'nam_12km': 85 | for forecast_hour in args['forecast_hours']: 86 | if forecast_hour in [0, 1, 2, 3, 6]: 87 | folder = 'analysis' # use the analysis folder as it contains more accurate data 88 | else: 89 | folder = 'forecast' # forecast hours other than 0, 1, 2, 3, 6 do not have analysis data 90 | files.append(f"https://www.ncei.noaa.gov/data/north-american-mesoscale-model/access/%s/%d%02d/%d%02d%02d/nam_218_%d%02d%02d_%02d00_%03d.grb2" % 91 | (folder, init_time.year, init_time.month, init_time.year, init_time.month, init_time.day, init_time.year, init_time.month, init_time.day, init_time.hour, forecast_hour)) 92 | [local_filenames.append("%s_%d%02d%02d%02d_f%03d.grib" % (args['model'], init_time.year, init_time.month, init_time.day, init_time.hour, forecast_hour)) for forecast_hour in args['forecast_hours']] 93 | 94 | for file, local_filename in zip(files, local_filenames): 95 | 96 | init_time = local_filename.split('_')[1] if 'nam' not in args['model'] else local_filename.split('_')[2] 97 | init_time = pd.to_datetime(f'{init_time[:4]}-{init_time[4:6]}-{init_time[6:8]}-{init_time[8:10]}') 98 | 99 | monthly_directory = '%s/%d%02d' % (args['grib_outdir'], init_time.year, init_time.month) # Directory for the grib files for the given days 100 | 101 | ### If the directory does not exist, check to see if the file link is valid. If the file link is NOT valid, then the directory will not be created since it will be empty. ### 102 | if not os.path.isdir(monthly_directory): 103 | if requests.head(file).status_code == requests.codes.ok or requests.head(file.replace('/atmos', '')).status_code == requests.codes.ok: 104 | os.mkdir(monthly_directory) 105 | 106 | full_file_path = f'{monthly_directory}/{local_filename}' 107 | 108 | if not os.path.isfile(full_file_path): 109 | try: 110 | wget.download(file, out=full_file_path, bar=bar) 111 | except urllib.error.HTTPError: 112 | print(f"Error downloading {file}") 113 | else: 114 | print(f"{full_file_path} already exists, skipping file....") 115 | -------------------------------------------------------------------------------- /evaluation/predict_tf.py: -------------------------------------------------------------------------------- 1 | """ 2 | **** EXPERIMENTAL SCRIPT TO REPLACE 'predict.py' IN THE NEAR FUTURE **** 3 | 4 | Generate predictions using a model with tensorflow datasets. 
5 | 6 | Author: Andrew Justin (andrewjustinwx@gmail.com) 7 | Script version: 2023.7.24.D1 8 | """ 9 | import argparse 10 | import sys 11 | import os 12 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))) # this line allows us to import scripts outside the current directory 13 | import file_manager as fm 14 | import numpy as np 15 | import pandas as pd 16 | from utils.settings import * 17 | import xarray as xr 18 | import tensorflow as tf 19 | 20 | 21 | if __name__ == '__main__': 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument('--dataset', type=str, help="Dataset for which to make predictions. Options are: 'training', 'validation', 'test'") 24 | parser.add_argument('--year_and_month', type=int, nargs=2, help="Year and month for which to make predictions.") 25 | parser.add_argument('--model_dir', type=str, help='Directory for the models.') 26 | parser.add_argument('--model_number', type=int, help='Model number.') 27 | parser.add_argument('--tf_indir', type=str, help='Directory for the tensorflow dataset that will be used when generating predictions.') 28 | parser.add_argument('--data_source', type=str, default='era5', help='Data source for variables') 29 | parser.add_argument('--gpu_device', type=int, nargs='+', help='GPU device numbers.') 30 | parser.add_argument('--memory_growth', action='store_true', help='Use memory growth on the GPU') 31 | parser.add_argument('--overwrite', action='store_true', help="Overwrite any existing prediction files.") 32 | 33 | args = vars(parser.parse_args()) 34 | 35 | model_properties = pd.read_pickle('%s/model_%d/model_%d_properties.pkl' % (args['model_dir'], args['model_number'], args['model_number'])) 36 | dataset_properties = pd.read_pickle('%s/dataset_properties.pkl' % args['tf_indir']) 37 | 38 | domain = dataset_properties['domain'] 39 | 40 | if domain == 'conus': 41 | hour_interval = 3 42 | else: 43 | hour_interval = 6 44 | 45 | # Some older models do not have the 'dataset_properties' dictionary 46 | try: 47 | front_types = model_properties['dataset_properties']['front_types'] 48 | num_dims = model_properties['dataset_properties']['num_dims'] 49 | except KeyError: 50 | front_types = model_properties['front_types'] 51 | if args['model_number'] in [6846496, 7236500, 7507525]: 52 | num_dims = (3, 3) 53 | 54 | if args['dataset'] is not None and args['year_and_month'] is not None: 55 | raise ValueError("--dataset and --year_and_month cannot be passed together.") 56 | elif args['dataset'] is None and args['year_and_month'] is None: 57 | raise ValueError("At least one of [--dataset, --year_and_month] must be passed.") 58 | elif args['year_and_month'] is not None: 59 | years, months = [args['year_and_month'][0]], [args['year_and_month'][1]] 60 | else: 61 | years, months = model_properties['%s_years' % args['dataset']], range(1, 13) 62 | 63 | ### Make sure that the dataset has the same attributes as the model ### 64 | if model_properties['normalization_parameters'] != dataset_properties['normalization_parameters']: 65 | raise ValueError("Cannot evaluate model with the selected dataset. Reason: normalization parameters do not match") 66 | if model_properties['dataset_properties']['front_types'] != dataset_properties['front_types']: 67 | raise ValueError("Cannot evaluate model with the selected dataset. 
Reason: front types do not match " 68 | f"(model: {model_properties['dataset_properties']['front_types']}, dataset: {dataset_properties['front_types']})") 69 | if model_properties['dataset_properties']['variables'] != dataset_properties['variables']: 70 | raise ValueError("Cannot evaluate model with the selected dataset. Reason: variables do not match " 71 | f"(model: {model_properties['dataset_properties']['variables']}, dataset: {dataset_properties['variables']})") 72 | if model_properties['dataset_properties']['pressure_levels'] != dataset_properties['pressure_levels']: 73 | raise ValueError("Cannot evaluate model with the selected dataset. Reason: pressure levels do not match " 74 | f"(model: {model_properties['dataset_properties']['pressure_levels']}, dataset: {dataset_properties['pressure_levels']})") 75 | 76 | gpus = tf.config.list_physical_devices(device_type='GPU') # Find available GPUs 77 | if len(gpus) > 0: 78 | 79 | print("Number of GPUs available: %d" % len(gpus)) 80 | 81 | # Only make the selected GPU(s) visible to TensorFlow 82 | if args['gpu_device'] is not None: 83 | tf.config.set_visible_devices(devices=[gpus[gpu] for gpu in args['gpu_device']], device_type='GPU') 84 | gpus = tf.config.get_visible_devices(device_type='GPU') # List of selected GPUs 85 | print("Using %d GPU(s):" % len(gpus), gpus) 86 | 87 | # Allow for memory growth on the GPU. This will only use the GPU memory that is required rather than allocating all the GPU's memory. 88 | if args['memory_growth']: 89 | tf.config.experimental.set_memory_growth(device=[gpu for gpu in gpus][0], enable=True) 90 | 91 | else: 92 | print('WARNING: No GPUs found, all computations will be performed on CPUs.') 93 | tf.config.set_visible_devices([], 'GPU') 94 | 95 | # The axis that the predicts will be concatenated on depends on the shape of the output, which is determined by deep supervision 96 | if model_properties['deep_supervision']: 97 | concat_axis = 1 98 | else: 99 | concat_axis = 0 100 | 101 | tf_ds_obj = fm.DataFileLoader(args['tf_indir'], data_file_type='%s-tensorflow' % args['data_source']) 102 | 103 | lons = np.arange(DEFAULT_DOMAIN_EXTENTS[domain][0], DEFAULT_DOMAIN_EXTENTS[domain][1] + 0.25, 0.25) 104 | lats = np.arange(DEFAULT_DOMAIN_EXTENTS[domain][2], DEFAULT_DOMAIN_EXTENTS[domain][3] + 0.25, 0.25)[::-1] 105 | 106 | model = fm.load_model(args['model_number'], args['model_dir']) 107 | 108 | for year in years: 109 | 110 | tf_ds_obj.test_years = [year, ] 111 | files_for_year = tf_ds_obj.data_files_test 112 | 113 | for month in months: 114 | 115 | prediction_dataset_path = '%s/model_%d/probabilities/model_%d_pred_%s_%d%02d.nc' % (args['model_dir'], args['model_number'], args['model_number'], domain, year, month) 116 | if os.path.isfile(prediction_dataset_path) and not args['overwrite']: 117 | print("WARNING: %s exists, pass the --overwrite argument to overwrite existing data." 
% prediction_dataset_path) 118 | continue 119 | 120 | input_file = [file for file in files_for_year if '_%d%02d' % (year, month) in file][0] 121 | tf_ds = tf.data.Dataset.load(input_file) 122 | time_array = np.arange(np.datetime64(f"{input_file[-9:-5]}-{input_file[-5:-3]}"), 123 | np.datetime64(f"{input_file[-9:-5]}-{input_file[-5:-3]}") + np.timedelta64(1, "M"), 124 | np.timedelta64(hour_interval, "h")) 125 | 126 | assert len(tf_ds) == len(time_array) # make sure tensorflow dataset has all timesteps 127 | 128 | tf_ds = tf_ds.batch(GPU_PREDICT_BATCH_SIZE) 129 | prediction = np.array(model.predict(tf_ds)).astype(np.float16) 130 | 131 | if model_properties['deep_supervision']: 132 | prediction = prediction[0, ...] # select the top output of the model, since it is the only one we care about 133 | 134 | if num_dims[1] == 3: 135 | # Take the maxmimum probability for each front type over the vertical dimension (pressure levels) 136 | prediction = np.amax(prediction, axis=3) # shape: (time, longitude, latitude, front type) 137 | 138 | prediction = prediction[..., 1:] # remove the 'no front' type from the array 139 | prediction = np.transpose(prediction, (0, 2, 1, 3)) # shape: (time, latitude, longitude, front type) 140 | 141 | xr.Dataset(data_vars={front_type: (('time', 'latitude', 'longitude'), prediction[:, :, :, front_type_no]) 142 | for front_type_no, front_type in enumerate(front_types)}, 143 | coords={'time': time_array, 'longitude': lons, 'latitude': lats}).astype('float32').\ 144 | to_netcdf(path=prediction_dataset_path, mode='w', engine='netcdf4') 145 | 146 | del prediction # Delete the prediction variable so it can be recreated for the next year 147 | -------------------------------------------------------------------------------- /convert_front_xml_to_netcdf.py: -------------------------------------------------------------------------------- 1 | """ 2 | Convert front XML files to netCDF files. 3 | 4 | Author: Andrew Justin (andrewjustinwx@gmail.com) 5 | Script version: 2023.9.21 6 | """ 7 | 8 | import argparse 9 | import glob 10 | import numpy as np 11 | import os 12 | from utils import data_utils 13 | import xarray as xr 14 | import xml.etree.ElementTree as ET 15 | 16 | 17 | pgenType_identifiers = {'COLD_FRONT': 1, 'WARM_FRONT': 2, 'STATIONARY_FRONT': 3, 'OCCLUDED_FRONT': 4, 'COLD_FRONT_FORM': 5, 18 | 'WARM_FRONT_FORM': 6, 'STATIONARY_FRONT_FORM': 7, 'OCCLUDED_FRONT_FORM': 8, 'COLD_FRONT_DISS': 9, 19 | 'WARM_FRONT_DISS': 10, 'STATIONARY_FRONT_DISS': 11, 'OCCLUDED_FRONT_DISS': 12, 'INSTABILITY': 13, 20 | 'TROF': 14, 'TROPICAL_TROF': 15, 'DRY_LINE': 16} 21 | 22 | """ 23 | conus: 132 W to 60.25 W, 57 N to 26.25 N 24 | full: 130 E pointing eastward to 10 E, 80 N to 0.25 N 25 | global: 179.75 W to 180 E, 90 N to 89.75 N 26 | """ 27 | domain_coords = {'conus': {'lons': np.arange(-132, -60, 0.25), 'lats': np.arange(57, 25, -0.25)}, 28 | 'full': {'lons': np.append(np.arange(-179.75, 10, 0.25), np.arange(130, 180.25, 0.25)), 'lats': np.arange(80, 0, -0.25)}, 29 | 'global': {'lons': np.arange(-179.75, 180.25, 0.25), 'lats': np.arange(90, -90, -0.25)}} 30 | 31 | 32 | if __name__ == "__main__": 33 | parser = argparse.ArgumentParser() 34 | parser.add_argument('--xml_indir', type=str, required=True, help="Input directory for front XML files.") 35 | parser.add_argument('--netcdf_outdir', type=str, required=True, help="Output directory for front netCDF files.") 36 | parser.add_argument('--date', type=int, nargs=3, required=True, help="Date for the data to be read in. 
(year, month, day)") 37 | parser.add_argument('--distance', type=float, default=1., help="Interpolation distance in kilometers.") 38 | parser.add_argument('--domain', type=str, default='full', help="Domain for which to generate fronts.") 39 | 40 | args = vars(parser.parse_args()) 41 | 42 | year, month, day = args['date'] 43 | 44 | if args['domain'] == 'global': 45 | files = sorted(glob.glob("%s/IBM*_%04d%02d%02d*f*.xml" % (args['xml_indir'], year, month, day))) 46 | else: 47 | files = sorted(glob.glob("%s/pres*_%04d%02d%02d*f000.xml" % (args['xml_indir'], year, month, day))) 48 | 49 | domain_from_model = args['domain'] not in ['conus', 'full', 'global'] 50 | 51 | if domain_from_model: 52 | 53 | transform_args = {'hrrr': dict(std_parallels=(38.5, 38.5), lon_ref=262.5, lat_ref=38.5), 54 | 'nam_12km': dict(std_parallels=(25, 25), lon_ref=265, lat_ref=40), 55 | 'namnest_conus': dict(std_parallels=(38.5, 38.5), lon_ref=262.5, lat_ref=38.5)} 56 | 57 | if args['domain'] == 'hrrr': 58 | model_dataset = xr.open_dataset('hrrr_2023040100_f000.grib', backend_kwargs=dict(filter_by_keys={'typeOfLevel': 'isobaricInhPa'})) 59 | elif args['domain'] == 'nam_12km': 60 | model_dataset = xr.open_dataset('nam_12km_2021032300_f006.grib', backend_kwargs=dict(filter_by_keys={'typeOfLevel': 'isobaricInhPa'})) 61 | elif args['domain'] == 'rap': 62 | model_dataset = xr.open_dataset('rap_2021032300_f006.grib', backend_kwargs=dict(filter_by_keys={'typeOfLevel': 'isobaricInhPa'})) 63 | elif args['domain'] == 'namnest_conus': 64 | model_dataset = xr.open_dataset('namnest_conus_2023090800_f000.grib', backend_kwargs=dict(filter_by_keys={'typeOfLevel': 'isobaricInhPa'})) 65 | 66 | gridded_lons = model_dataset['longitude'].values.astype('float32') 67 | gridded_lats = model_dataset['latitude'].values.astype('float32') 68 | 69 | model_x_transform, model_y_transform = data_utils.lambert_conformal_to_cartesian(gridded_lons, gridded_lats, **transform_args[args['domain']]) 70 | gridded_x = model_x_transform[0, :] 71 | gridded_y = model_y_transform[:, 0] 72 | 73 | identifier = np.zeros(np.shape(gridded_lons)).astype('float32') 74 | 75 | else: 76 | 77 | gridded_lons = domain_coords[args['domain']]['lons'].astype('float32') 78 | gridded_lats = domain_coords[args['domain']]['lats'].astype('float32') 79 | identifier = np.zeros([len(gridded_lons), len(gridded_lats)]).astype('float32') 80 | 81 | for filename in files[-1:]: 82 | 83 | tree = ET.parse(filename, parser=ET.XMLParser(encoding='utf-8')) 84 | root = tree.getroot() 85 | date = os.path.basename(filename).split('_')[-1].split('.')[0].split('f')[0] # YYYYMMDDhh 86 | forecast_hour = int(filename.split('f')[-1].split('.')[0]) 87 | 88 | hour = date[-2:] 89 | 90 | ### Iterate through the individual fronts ### 91 | for line in root.iter('Line'): 92 | 93 | type_of_front = line.get("pgenType") # front type 94 | 95 | print(type_of_front) 96 | 97 | lons, lats = zip(*[[float(point.get("Lon")), float(point.get("Lat"))] for point in line.iter('Point')]) 98 | lons, lats = np.array(lons), np.array(lats) 99 | 100 | # If the front crosses the dateline or the 180th meridian, its coordinates must be modified for proper interpolation 101 | front_needs_modification = np.max(np.abs(np.diff(lons))) > 180 102 | 103 | if front_needs_modification or domain_from_model: 104 | lons = np.where(lons < 0, lons + 360, lons) # convert coordinates to a 360 degree system 105 | 106 | xs, ys = data_utils.haversine(lons, lats) # x/y coordinates in kilometers 107 | xy_linestring = data_utils.geometric(xs, ys) # convert 
coordinates to a LineString object 108 | x_new, y_new = data_utils.redistribute_vertices(xy_linestring, args['distance']).xy # interpolate x/y coordinates 109 | x_new, y_new = np.array(x_new), np.array(y_new) 110 | lon_new, lat_new = data_utils.reverse_haversine(x_new, y_new) # convert interpolated x/y coordinates to lat/lon 111 | 112 | date_and_time = np.datetime64('%04d-%02d-%02dT%02d' % (year, month, day, int(hour)), 'ns') 113 | 114 | expand_dims_args = {'time': np.atleast_1d(date_and_time)} 115 | 116 | if args['domain'] == 'global': 117 | filename_netcdf = "FrontObjects_%s_f%03d_%s.nc" % (date, forecast_hour, args['domain']) 118 | expand_dims_args['forecast_hour'] = np.atleast_1d(forecast_hour) 119 | else: 120 | filename_netcdf = "FrontObjects_%s_%s.nc" % (date, args['domain']) 121 | 122 | if domain_from_model: 123 | x_new *= 1000; y_new *= 1000 # convert front's points to meters 124 | x_transform, y_transform = data_utils.lambert_conformal_to_cartesian(lon_new, lat_new, **transform_args[args['domain']]) 125 | 126 | gridded_indices = np.dstack((np.digitize(y_transform, gridded_y), np.digitize(x_transform, gridded_x)))[0] # translate coordinate indices to grid 127 | gridded_indices_unique = np.unique(gridded_indices, axis=0) # remove duplicate coordinate indices 128 | 129 | # Remove points outside the domain 130 | gridded_indices_unique = gridded_indices_unique[np.where(gridded_indices_unique[:, 0] != len(gridded_y))] 131 | gridded_indices_unique = gridded_indices_unique[np.where(gridded_indices_unique[:, 1] != len(gridded_x))] 132 | 133 | identifier[gridded_indices_unique[:, 0], gridded_indices_unique[:, 1]] = pgenType_identifiers[type_of_front] # assign labels to the gridded points based on the front type 134 | 135 | fronts_ds = xr.Dataset({"identifier": (('y', 'x'), identifier)}, 136 | coords={"longitude": (('y', 'x'), gridded_lons), "latitude": (('y', 'x'), gridded_lats)}).expand_dims(**expand_dims_args) 137 | 138 | else: 139 | 140 | if front_needs_modification: 141 | lon_new = np.where(lon_new > 180, lon_new - 360, lon_new) # convert new longitudes to standard -180 to 180 range 142 | 143 | gridded_indices = np.dstack((np.digitize(lon_new, gridded_lons), np.digitize(lat_new, gridded_lats)))[0] # translate coordinate indices to grid 144 | gridded_indices_unique = np.unique(gridded_indices, axis=0) # remove duplicate coordinate indices 145 | 146 | # Remove points outside the domain 147 | gridded_indices_unique = gridded_indices_unique[np.where(gridded_indices_unique[:, 0] != len(gridded_lons))] 148 | gridded_indices_unique = gridded_indices_unique[np.where(gridded_indices_unique[:, 1] != len(gridded_lats))] 149 | 150 | identifier[gridded_indices_unique[:, 0], gridded_indices_unique[:, 1]] = pgenType_identifiers[type_of_front] # assign labels to the gridded points based on the front type 151 | 152 | fronts_ds = xr.Dataset({"identifier": (('longitude', 'latitude'), identifier)}, 153 | coords={"longitude": gridded_lons, "latitude": gridded_lats}).expand_dims(**expand_dims_args) 154 | 155 | if not os.path.isdir("%s/%d%02d" % (args['netcdf_outdir'], year, month)): 156 | os.mkdir("%s/%d%02d" % (args['netcdf_outdir'], year, month)) 157 | fronts_ds.to_netcdf(path="%s/%d%02d/%s" % (args['netcdf_outdir'], year, month, filename_netcdf), engine='netcdf4', mode='w') 158 | -------------------------------------------------------------------------------- /evaluation/prediction_plot.py: -------------------------------------------------------------------------------- 1 | """ 2 | Plot model predictions. 
3 | 4 | Author: Andrew Justin (andrewjustinwx@gmail.com) 5 | Script version: 2023.9.10 6 | """ 7 | import itertools 8 | import argparse 9 | import pandas as pd 10 | import cartopy.crs as ccrs 11 | import matplotlib.pyplot as plt 12 | import numpy as np 13 | import xarray as xr 14 | from matplotlib import cm, colors # Here we explicitly import the cm and color modules to suppress a PyCharm bug 15 | import os 16 | import sys 17 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))) # this line allows us to import scripts outside the current directory 18 | from utils import data_utils, settings 19 | from utils.plotting_utils import plot_background 20 | from skimage.morphology import skeletonize 21 | 22 | 23 | if __name__ == '__main__': 24 | """ 25 | All arguments listed in the examples are listed via argparse in alphabetical order below this comment block. 26 | """ 27 | parser = argparse.ArgumentParser() 28 | parser.add_argument('--datetime', type=int, nargs=4, help='Date and time of the data. Pass 4 ints in the following order: year, month, day, hour') 29 | parser.add_argument('--domain', type=str, required=True, help='Domain of the data.') 30 | parser.add_argument('--domain_images', type=int, nargs=2, help='Number of images for each dimension the final stitched map for predictions: lon, lat') 31 | parser.add_argument('--forecast_hour', type=int, help='Forecast hour for the GDAS data') 32 | parser.add_argument('--model_dir', type=str, required=True, help='Directory for the models.') 33 | parser.add_argument('--model_number', type=int, required=True, help='Model number.') 34 | parser.add_argument('--fronts_netcdf_indir', type=str, help='Main directory for the netcdf files containing frontal objects.') 35 | parser.add_argument('--data_source', type=str, default='era5', help="Source of the variable data (ERA5, GDAS, etc.)") 36 | parser.add_argument('--prob_mask', type=float, nargs=2, default=[0.1, 0.1], 37 | help="Probability mask and the step/interval for the probability contours. Probabilities smaller than the mask will not be plotted.") 38 | parser.add_argument('--calibration', type=int, 39 | help="Neighborhood calibration distance in kilometers. Possible neighborhoods are 50, 100, 150, 200, and 250 km.") 40 | parser.add_argument('--deterministic', action='store_true', help="Plot deterministic splines.") 41 | parser.add_argument('--targets', action='store_true', help="Plot ground truth targets/labels.") 42 | parser.add_argument('--contours', action='store_true', help="Plot probability contours.") 43 | 44 | args = vars(parser.parse_args()) 45 | 46 | if args['deterministic'] and args['targets']: 47 | raise TypeError("Cannot plot deterministic splines and ground truth targets at the same time. 
Only one of --deterministic, --targets may be passed")
48 | 
49 |     if args['domain_images'] is None:
50 |         args['domain_images'] = [1, 1]
51 | 
52 |     DEFAULT_COLORBAR_POSITION = {'conus': 0.74, 'full': 0.84, 'global': 0.74}
53 |     cbar_position = DEFAULT_COLORBAR_POSITION['conus']
54 | 
55 |     model_properties = pd.read_pickle(f"{args['model_dir']}/model_{args['model_number']}/model_{args['model_number']}_properties.pkl")
56 | 
57 |     args['data_source'] = args['data_source'].lower()
58 | 
59 |     extent = settings.DEFAULT_DOMAIN_EXTENTS[args['domain']]
60 | 
61 |     year, month, day, hour = args['datetime'][0], args['datetime'][1], args['datetime'][2], args['datetime'][3]
62 | 
63 |     ### Attempt to pull predictions from a yearly netcdf file generated with tensorflow datasets, otherwise try to pull a single netcdf file ###
64 |     try:
65 |         probs_file = f"{args['model_dir']}/model_{args['model_number']}/probabilities/model_{args['model_number']}_pred_{args['domain']}_{year}%02d.nc" % month
66 |         fronts_file = '%s/%d%02d/FrontObjects_%d%02d%02d%02d_full.nc' % (args['fronts_netcdf_indir'], year, month, year, month, day, hour)
67 |         plot_filename = '%s/model_%d/maps/model_%d_%d%02d%02d%02d_%s.png' % (args['model_dir'], args['model_number'], args['model_number'], year, month, day, hour, args['domain'])
68 |         probs_ds = xr.open_mfdataset(probs_file).sel(time=['%d-%02d-%02dT%02d' % (year, month, day, hour), ])
69 |     except OSError:
70 |         subdir_base = '%s_%dx%d' % (args['domain'], args['domain_images'][0], args['domain_images'][1])
71 |         probs_dir = f"{args['model_dir']}/model_{args['model_number']}/probabilities/{subdir_base}"
72 | 
73 |         if args['forecast_hour'] is not None:
74 |             timestep = np.datetime64('%d-%02d-%02dT%02d' % (year, month, day, hour)).astype(object)
75 |             forecast_timestep = timestep if args['forecast_hour'] == 0 else timestep + np.timedelta64(args['forecast_hour'], 'h').astype(object)
76 |             new_year, new_month, new_day, new_hour = forecast_timestep.year, forecast_timestep.month, forecast_timestep.day, forecast_timestep.hour - (forecast_timestep.hour % 3)
77 |             fronts_file = '%s/%d%02d/FrontObjects_%d%02d%02d%02d_full.nc' % (args['fronts_netcdf_indir'], new_year, new_month, new_year, new_month, new_day, new_hour)  # zero-pad the month and day so the path matches the other FrontObjects paths
78 |             filename_base = f'model_%d_{year}%02d%02d%02d_%s_%s_f%03d_%dx%d' % (args['model_number'], month, day, hour, args['domain'], args['data_source'], args['forecast_hour'], args['domain_images'][0], args['domain_images'][1])
79 |         else:
80 |             fronts_file = '%s/%d%02d/FrontObjects_%d%02d%02d%02d_full.nc' % (args['fronts_netcdf_indir'], year, month, year, month, day, hour)
81 |             filename_base = f'model_%d_{year}%02d%02d%02d_%s_%dx%d' % (args['model_number'], month, day, hour, args['domain'], args['domain_images'][0], args['domain_images'][1])
82 |             args['data_source'] = 'era5'
83 | 
84 |         plot_filename = '%s/model_%d/maps/%s/%s-same.png' % (args['model_dir'], args['model_number'], subdir_base, filename_base)
85 |         probs_file = f'{probs_dir}/{filename_base}_probabilities.nc'
86 |         probs_ds = xr.open_mfdataset(probs_file)
87 | 
88 |     try:
89 |         front_types = model_properties['dataset_properties']['front_types']
90 |     except KeyError:
91 |         front_types = model_properties['front_types']
92 | 
93 |     labels = front_types
94 |     fronts_found = False
95 | 
96 |     if args['targets']:
97 |         right_title = 'Splines: NOAA fronts'
98 |         try:
99 |             fronts = xr.open_dataset(fronts_file).sel(longitude=slice(extent[0], extent[1]), latitude=slice(extent[3], extent[2]))
100 |             fronts = data_utils.reformat_fronts(fronts, front_types=front_types)
101 |
labels = fronts.attrs['labels'] 102 | fronts = xr.where(fronts == 0, float('NaN'), fronts) 103 | fronts_found = True 104 | except FileNotFoundError: 105 | print("No ground truth fronts found") 106 | 107 | if type(front_types) == str: 108 | front_types = [front_types, ] 109 | 110 | mask, prob_int = args['prob_mask'][0], args['prob_mask'][1] # Probability mask, contour interval for probabilities 111 | vmax, cbar_tick_adjust, cbar_label_adjust, n_colors = 1, prob_int, 10, 11 112 | levels = np.around(np.arange(0, 1 + prob_int, prob_int), 2) 113 | cbar_ticks = np.around(np.arange(mask, 1 + prob_int, prob_int), 2) 114 | 115 | contour_maps_by_type = [settings.DEFAULT_CONTOUR_CMAPS[label] for label in labels] 116 | front_colors_by_type = [settings.DEFAULT_FRONT_COLORS[label] for label in labels] 117 | front_names_by_type = [settings.DEFAULT_FRONT_NAMES[label] for label in labels] 118 | 119 | cmap_front = colors.ListedColormap(front_colors_by_type, name='from_list', N=len(front_colors_by_type)) 120 | norm_front = colors.Normalize(vmin=1, vmax=len(front_colors_by_type) + 1) 121 | 122 | probs_ds = probs_ds.isel(time=0) if args['data_source'] == 'era5' else probs_ds.isel(time=0, forecast_hour=0) 123 | probs_ds = probs_ds.transpose('latitude', 'longitude') 124 | 125 | for key in list(probs_ds.keys()): 126 | 127 | if args['deterministic']: 128 | spline_threshold = model_properties['front_obj_thresholds'][args['domain']][key]['100'] 129 | probs_ds[f'{key}_obj'] = (('latitude', 'longitude'), skeletonize(xr.where(probs_ds[key] > spline_threshold, 1, 0).values.copy(order='C'))) 130 | 131 | if args['calibration'] is not None: 132 | try: 133 | ir_model = model_properties['calibration_models'][args['domain']][key]['%d km' % args['calibration']] 134 | except KeyError: 135 | ir_model = model_properties['calibration_models']['conus'][key]['%d km' % args['calibration']] 136 | original_shape = np.shape(probs_ds[key].values) 137 | probs_ds[key].values = ir_model.predict(probs_ds[key].values.flatten()).reshape(original_shape) 138 | cbar_label = 'Probability (calibrated - %d km)' % args['calibration'] 139 | else: 140 | cbar_label = 'Probability (uncalibrated)' 141 | 142 | if len(front_types) > 1: 143 | all_possible_front_combinations = itertools.permutations(front_types, r=2) 144 | for combination in all_possible_front_combinations: 145 | probs_ds[combination[0]].values = np.where(probs_ds[combination[0]].values > probs_ds[combination[1]].values - 0.02, probs_ds[combination[0]].values, 0) 146 | 147 | probs_ds = xr.where(probs_ds > mask, probs_ds, float("NaN")) 148 | if args['data_source'] != 'era5': 149 | valid_time = timestep + np.timedelta64(args['forecast_hour'], 'h').astype(object) 150 | data_title = f"Run: {args['data_source'].upper()} {year}-%02d-%02d-%02dz F%03d \nPredictions valid: %d-%02d-%02d-%02dz" % (month, day, hour, args['forecast_hour'], valid_time.year, valid_time.month, valid_time.day, valid_time.hour) 151 | else: 152 | data_title = 'Data: ERA5 reanalysis %d-%02d-%02d-%02dz\n' \ 153 | 'Predictions valid: %d-%02d-%02d-%02dz' % (year, month, day, hour, year, month, day, hour) 154 | 155 | fig, ax = plt.subplots(1, 1, figsize=(22, 8), subplot_kw={'projection': ccrs.PlateCarree(central_longitude=np.mean(extent[:2]))}) 156 | plot_background(extent, ax=ax, linewidth=0.5) 157 | # ax.gridlines(draw_labels=True, zorder=0) 158 | 159 | cbar_front_labels = [] 160 | cbar_front_ticks = [] 161 | 162 | for front_no, front_key, front_name, front_label, cmap in zip(range(1, len(front_names_by_type) + 1), 
list(probs_ds.keys()), front_names_by_type, front_types, contour_maps_by_type): 163 | 164 | if args['contours']: 165 | cmap_probs, norm_probs = cm.get_cmap(cmap, n_colors), colors.Normalize(vmin=0, vmax=vmax) 166 | probs_ds[front_key].plot.contourf(ax=ax, x='longitude', y='latitude', norm=norm_probs, levels=levels, cmap=cmap_probs, transform=ccrs.PlateCarree(), alpha=0.75, add_colorbar=False) 167 | 168 | cbar_ax = fig.add_axes([cbar_position + (front_no * 0.015), 0.24, 0.015, 0.64]) 169 | cbar = plt.colorbar(cm.ScalarMappable(norm=norm_probs, cmap=cmap_probs), cax=cbar_ax, boundaries=levels[1:], alpha=0.75) 170 | cbar.set_ticklabels([]) 171 | 172 | if args['deterministic']: 173 | right_title = 'Splines: Deterministic first-guess fronts' 174 | cmap_deterministic = colors.ListedColormap(['None', front_colors_by_type[front_no - 1]], name='from_list', N=2) 175 | norm_deterministic = colors.Normalize(vmin=0, vmax=1) 176 | probs_ds[f'{front_key}_obj'].plot(ax=ax, x='longitude', y='latitude', cmap=cmap_deterministic, norm=norm_deterministic, 177 | transform=ccrs.PlateCarree(), alpha=0.9, add_colorbar=False) 178 | 179 | if fronts_found: 180 | fronts['identifier'].plot(ax=ax, x='longitude', y='latitude', cmap=cmap_front, norm=norm_front, transform=ccrs.PlateCarree(), add_colorbar=False) 181 | 182 | cbar_front_labels.append(front_name) 183 | cbar_front_ticks.append(front_no + 0.5) 184 | 185 | if args['contours']: 186 | cbar.set_label(cbar_label, rotation=90) 187 | cbar.set_ticks(cbar_ticks) 188 | cbar.set_ticklabels(cbar_ticks) 189 | 190 | cbar_front = plt.colorbar(cm.ScalarMappable(norm=norm_front, cmap=cmap_front), ax=ax, alpha=0.75, orientation='horizontal', shrink=0.5, pad=0.02) 191 | cbar_front.set_ticks(cbar_front_ticks) 192 | cbar_front.set_ticklabels(cbar_front_labels) 193 | cbar_front.set_label(r'$\bf{Front}$ $\bf{type}$') 194 | 195 | if fronts_found or args['deterministic']: 196 | ax.set_title(right_title, loc='right') 197 | 198 | ax.set_title('') 199 | ax.set_title(data_title, loc='left') 200 | 201 | plt.savefig(plot_filename, bbox_inches='tight', dpi=300) 202 | plt.close() 203 | -------------------------------------------------------------------------------- /create_era5_netcdf.py: -------------------------------------------------------------------------------- 1 | """ 2 | Create ERA5 netCDF datasets. 3 | 4 | Author: Andrew Justin (andrewjustinwx@gmail.com) 5 | Script version: 2023.6.7 6 | """ 7 | 8 | import argparse 9 | import numpy as np 10 | import os 11 | from utils import variables 12 | import xarray as xr 13 | 14 | 15 | if __name__ == "__main__": 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument('--netcdf_era5_indir', type=str, required=True, help="Input directory for the global ERA5 netCDF files.") 18 | parser.add_argument('--netcdf_outdir', type=str, required=True, help="Output directory for front netCDF files.") 19 | parser.add_argument('--date', type=int, nargs=3, required=True, help="Date for the data to be read in. 
(year, month, day)") 20 | 21 | args = vars(parser.parse_args()) 22 | 23 | year, month, day = args['date'][0], args['date'][1], args['date'][2] 24 | 25 | era5_T_sfc_file = 'ERA5Global_%d_3hrly_2mT.nc' % year 26 | era5_Td_sfc_file = 'ERA5Global_%d_3hrly_2mTd.nc' % year 27 | era5_sp_file = 'ERA5Global_%d_3hrly_sp.nc' % year 28 | era5_u_sfc_file = 'ERA5Global_%d_3hrly_U10m.nc' % year 29 | era5_v_sfc_file = 'ERA5Global_%d_3hrly_V10m.nc' % year 30 | 31 | timestring = "%d-%02d-%02d" % (year, month, day) 32 | 33 | lons = np.append(np.arange(130, 360, 0.25), np.arange(0, 10.25, 0.25)) 34 | lats = np.arange(0, 80.25, 0.25)[::-1] 35 | lons360 = np.arange(130, 370.25, 0.25) 36 | 37 | T_sfc_full_day = xr.open_mfdataset("%s/Surface/%s" % (args['netcdf_era5_indir'], era5_T_sfc_file), chunks={'latitude': 721, 'longitude': 1440, 'time': 4}).sel(time=('%s' % timestring), longitude=lons, latitude=lats) 38 | Td_sfc_full_day = xr.open_mfdataset("%s/Surface/%s" % (args['netcdf_era5_indir'], era5_Td_sfc_file), chunks={'latitude': 721, 'longitude': 1440, 'time': 4}).sel(time=('%s' % timestring), longitude=lons, latitude=lats) 39 | sp_full_day = xr.open_mfdataset("%s/Surface/%s" % (args['netcdf_era5_indir'], era5_sp_file), chunks={'latitude': 721, 'longitude': 1440, 'time': 4}).sel(time=('%s' % timestring), longitude=lons, latitude=lats) 40 | u_sfc_full_day = xr.open_mfdataset("%s/Surface/%s" % (args['netcdf_era5_indir'], era5_u_sfc_file), chunks={'latitude': 721, 'longitude': 1440, 'time': 4}).sel(time=('%s' % timestring), longitude=lons, latitude=lats) 41 | v_sfc_full_day = xr.open_mfdataset("%s/Surface/%s" % (args['netcdf_era5_indir'], era5_v_sfc_file), chunks={'latitude': 721, 'longitude': 1440, 'time': 4}).sel(time=('%s' % timestring), longitude=lons, latitude=lats) 42 | 43 | PL_data = xr.open_mfdataset( 44 | paths=('%s/Pressure_Level/ERA5Global_PL_%s_3hrly_Q.nc' % (args['netcdf_era5_indir'], year), 45 | '%s/Pressure_Level/ERA5Global_PL_%s_3hrly_T.nc' % (args['netcdf_era5_indir'], year), 46 | '%s/Pressure_Level/ERA5Global_PL_%s_3hrly_U.nc' % (args['netcdf_era5_indir'], year), 47 | '%s/Pressure_Level/ERA5Global_PL_%s_3hrly_V.nc' % (args['netcdf_era5_indir'], year), 48 | '%s/Pressure_Level/ERA5Global_PL_%s_3hrly_Z.nc' % (args['netcdf_era5_indir'], year)), 49 | chunks={'latitude': 721, 'longitude': 1440, 'time': 4}).sel(time=('%s' % timestring), longitude=lons, latitude=lats) 50 | 51 | if not os.path.isdir('%s/%d%02d' % (args['netcdf_outdir'], year, month)): 52 | os.mkdir('%s/%d%02d' % (args['netcdf_outdir'], year, month)) 53 | 54 | for hour in range(0, 24, 3): 55 | 56 | print(f"saving ERA5 data for {year}-%02d-%02d-%02dz" % (month, day, hour)) 57 | 58 | timestep = '%d-%02d-%02dT%02d:00:00' % (year, month, day, hour) 59 | 60 | PL_850 = PL_data.sel(level=850, time=timestep) 61 | PL_900 = PL_data.sel(level=900, time=timestep) 62 | PL_950 = PL_data.sel(level=950, time=timestep) 63 | PL_1000 = PL_data.sel(level=1000, time=timestep) 64 | 65 | T_sfc = T_sfc_full_day.sel(time=timestep)['t2m'].values 66 | Td_sfc = Td_sfc_full_day.sel(time=timestep)['d2m'].values 67 | sp = sp_full_day.sel(time=timestep)['sp'].values 68 | u_sfc = u_sfc_full_day.sel(time=timestep)['u10'].values 69 | v_sfc = v_sfc_full_day.sel(time=timestep)['v10'].values 70 | 71 | theta_sfc = variables.potential_temperature(T_sfc, sp) # Potential temperature 72 | theta_e_sfc = variables.equivalent_potential_temperature(T_sfc, Td_sfc, sp) # Equivalent potential temperature 73 | theta_v_sfc = variables.virtual_temperature_from_dewpoint(T_sfc, Td_sfc, sp) # 
Virtual potential temperature 74 | theta_w_sfc = variables.wet_bulb_potential_temperature(T_sfc, Td_sfc, sp) # Wet-bulb potential temperature 75 | r_sfc = variables.mixing_ratio_from_dewpoint(Td_sfc, sp) # Mixing ratio 76 | q_sfc = variables.specific_humidity_from_dewpoint(Td_sfc, sp) # Specific humidity 77 | RH_sfc = variables.relative_humidity(T_sfc, Td_sfc) # Relative humidity 78 | Tv_sfc = variables.virtual_temperature_from_dewpoint(T_sfc, Td_sfc, sp) # Virtual temperature 79 | Tw_sfc = variables.wet_bulb_temperature(T_sfc, Td_sfc) # Wet-bulb temperature 80 | 81 | q_850 = PL_850['q'].values 82 | q_900 = PL_900['q'].values 83 | q_950 = PL_950['q'].values 84 | q_1000 = PL_1000['q'].values 85 | T_850 = PL_850['t'].values 86 | T_900 = PL_900['t'].values 87 | T_950 = PL_950['t'].values 88 | T_1000 = PL_1000['t'].values 89 | u_850 = PL_850['u'].values 90 | u_900 = PL_900['u'].values 91 | u_950 = PL_950['u'].values 92 | u_1000 = PL_1000['u'].values 93 | v_850 = PL_850['v'].values 94 | v_900 = PL_900['v'].values 95 | v_950 = PL_950['v'].values 96 | v_1000 = PL_1000['v'].values 97 | z_850 = PL_850['z'].values 98 | z_900 = PL_900['z'].values 99 | z_950 = PL_950['z'].values 100 | z_1000 = PL_1000['z'].values 101 | 102 | Td_850 = variables.dewpoint_from_specific_humidity(85000, T_850, q_850) 103 | Td_900 = variables.dewpoint_from_specific_humidity(90000, T_900, q_900) 104 | Td_950 = variables.dewpoint_from_specific_humidity(95000, T_950, q_950) 105 | Td_1000 = variables.dewpoint_from_specific_humidity(100000, T_1000, q_1000) 106 | r_850 = variables.mixing_ratio_from_dewpoint(Td_850, 85000) 107 | r_900 = variables.mixing_ratio_from_dewpoint(Td_900, 90000) 108 | r_950 = variables.mixing_ratio_from_dewpoint(Td_950, 95000) 109 | r_1000 = variables.mixing_ratio_from_dewpoint(Td_1000, 100000) 110 | RH_850 = variables.relative_humidity(T_850, Td_850) 111 | RH_900 = variables.relative_humidity(T_900, Td_900) 112 | RH_950 = variables.relative_humidity(T_950, Td_950) 113 | RH_1000 = variables.relative_humidity(T_1000, Td_1000) 114 | theta_850 = variables.potential_temperature(T_850, 85000) 115 | theta_900 = variables.potential_temperature(T_900, 90000) 116 | theta_950 = variables.potential_temperature(T_950, 95000) 117 | theta_1000 = variables.potential_temperature(T_1000, 100000) 118 | theta_e_850 = variables.equivalent_potential_temperature(T_850, Td_850, 85000) 119 | theta_e_900 = variables.equivalent_potential_temperature(T_900, Td_900, 90000) 120 | theta_e_950 = variables.equivalent_potential_temperature(T_950, Td_950, 95000) 121 | theta_e_1000 = variables.equivalent_potential_temperature(T_1000, Td_1000, 100000) 122 | theta_v_850 = variables.virtual_temperature_from_dewpoint(T_850, Td_850, 85000) 123 | theta_v_900 = variables.virtual_temperature_from_dewpoint(T_900, Td_900, 90000) 124 | theta_v_950 = variables.virtual_temperature_from_dewpoint(T_950, Td_950, 95000) 125 | theta_v_1000 = variables.virtual_temperature_from_dewpoint(T_1000, Td_1000, 100000) 126 | theta_w_850 = variables.wet_bulb_potential_temperature(T_850, Td_850, 85000) 127 | theta_w_900 = variables.wet_bulb_potential_temperature(T_900, Td_900, 90000) 128 | theta_w_950 = variables.wet_bulb_potential_temperature(T_950, Td_950, 95000) 129 | theta_w_1000 = variables.wet_bulb_potential_temperature(T_1000, Td_1000, 100000) 130 | Tv_850 = variables.virtual_temperature_from_dewpoint(T_850, Td_850, 85000) 131 | Tv_900 = variables.virtual_temperature_from_dewpoint(T_900, Td_900, 90000) 132 | Tv_950 = 
variables.virtual_temperature_from_dewpoint(T_950, Td_950, 95000) 133 | Tv_1000 = variables.virtual_temperature_from_dewpoint(T_1000, Td_1000, 100000) 134 | Tw_850 = variables.wet_bulb_temperature(T_850, Td_850) 135 | Tw_900 = variables.wet_bulb_temperature(T_900, Td_900) 136 | Tw_950 = variables.wet_bulb_temperature(T_950, Td_950) 137 | Tw_1000 = variables.wet_bulb_temperature(T_1000, Td_1000) 138 | 139 | pressure_levels = ['surface', 1000, 950, 900, 850] 140 | 141 | T = np.empty(shape=(len(pressure_levels), len(lats), len(lons360))) 142 | Td = np.empty(shape=(len(pressure_levels), len(lats), len(lons360))) 143 | Tv = np.empty(shape=(len(pressure_levels), len(lats), len(lons360))) 144 | Tw = np.empty(shape=(len(pressure_levels), len(lats), len(lons360))) 145 | theta = np.empty(shape=(len(pressure_levels), len(lats), len(lons360))) 146 | theta_e = np.empty(shape=(len(pressure_levels), len(lats), len(lons360))) 147 | theta_v = np.empty(shape=(len(pressure_levels), len(lats), len(lons360))) 148 | theta_w = np.empty(shape=(len(pressure_levels), len(lats), len(lons360))) 149 | RH = np.empty(shape=(len(pressure_levels), len(lats), len(lons360))) 150 | r = np.empty(shape=(len(pressure_levels), len(lats), len(lons360))) 151 | q = np.empty(shape=(len(pressure_levels), len(lats), len(lons360))) 152 | u = np.empty(shape=(len(pressure_levels), len(lats), len(lons360))) 153 | v = np.empty(shape=(len(pressure_levels), len(lats), len(lons360))) 154 | sp_z = np.empty(shape=(len(pressure_levels), len(lats), len(lons360))) 155 | 156 | T[0, :, :], T[1, :, :], T[2, :, :], T[3, :, :], T[4, :, :] = T_sfc, T_1000, T_950, T_900, T_850 157 | Td[0, :, :], Td[1, :, :], Td[2, :, :], Td[3, :, :], Td[4, :, :] = Td_sfc, Td_1000, Td_950, Td_900, Td_850 158 | Tv[0, :, :], Tv[1, :, :], Tv[2, :, :], Tv[3, :, :], Tv[4, :, :] = Tv_sfc, Tv_1000, Tv_950, Tv_900, Tv_850 159 | Tw[0, :, :], Tw[1, :, :], Tw[2, :, :], Tw[3, :, :], Tw[4, :, :] = Tw_sfc, Tw_1000, Tw_950, Tw_900, Tw_850 160 | theta[0, :, :], theta[1, :, :], theta[2, :, :], theta[3, :, :], theta[4, :, :] = theta_sfc, theta_1000, theta_950, theta_900, theta_850 161 | theta_e[0, :, :], theta_e[1, :, :], theta_e[2, :, :], theta_e[3, :, :], theta_e[4, :, :] = theta_e_sfc, theta_e_1000, theta_e_950, theta_e_900, theta_e_850 162 | theta_v[0, :, :], theta_v[1, :, :], theta_v[2, :, :], theta_v[3, :, :], theta_v[4, :, :] = theta_v_sfc, theta_v_1000, theta_v_950, theta_v_900, theta_v_850 163 | theta_w[0, :, :], theta_w[1, :, :], theta_w[2, :, :], theta_w[3, :, :], theta_w[4, :, :] = theta_w_sfc, theta_w_1000, theta_w_950, theta_w_900, theta_w_850 164 | RH[0, :, :], RH[1, :, :], RH[2, :, :], RH[3, :, :], RH[4, :, :] = RH_sfc, RH_1000, RH_950, RH_900, RH_850 165 | r[0, :, :], r[1, :, :], r[2, :, :], r[3, :, :], r[4, :, :] = r_sfc, r_1000, r_950, r_900, r_850 166 | q[0, :, :], q[1, :, :], q[2, :, :], q[3, :, :], q[4, :, :] = q_sfc, q_1000, q_950, q_900, q_850 167 | u[0, :, :], u[1, :, :], u[2, :, :], u[3, :, :], u[4, :, :] = u_sfc, u_1000, u_950, u_900, u_850 168 | v[0, :, :], v[1, :, :], v[2, :, :], v[3, :, :], v[4, :, :] = v_sfc, v_1000, v_950, v_900, v_850 169 | sp_z[0, :, :], sp_z[1, :, :], sp_z[2, :, :], sp_z[3, :, :], sp_z[4, :, :] = sp/100, z_1000/98.0665, z_950/98.0665, z_900/98.0665, z_850/98.0665 170 | 171 | full_era5_dataset = xr.Dataset(data_vars=dict(T=(('pressure_level', 'latitude', 'longitude'), T), 172 | Td=(('pressure_level', 'latitude', 'longitude'), Td), 173 | Tv=(('pressure_level', 'latitude', 'longitude'), Tv), 174 | Tw=(('pressure_level', 'latitude', 
'longitude'), Tw), 175 | theta=(('pressure_level', 'latitude', 'longitude'), theta), 176 | theta_e=(('pressure_level', 'latitude', 'longitude'), theta_e), 177 | theta_v=(('pressure_level', 'latitude', 'longitude'), theta_v), 178 | theta_w=(('pressure_level', 'latitude', 'longitude'), theta_w), 179 | RH=(('pressure_level', 'latitude', 'longitude'), RH), 180 | r=(('pressure_level', 'latitude', 'longitude'), r * 1000), 181 | q=(('pressure_level', 'latitude', 'longitude'), q * 1000), 182 | u=(('pressure_level', 'latitude', 'longitude'), u), 183 | v=(('pressure_level', 'latitude', 'longitude'), v), 184 | sp_z=(('pressure_level', 'latitude', 'longitude'), sp_z)), 185 | coords=dict(pressure_level=pressure_levels, latitude=lats, longitude=lons360)).astype('float32') 186 | 187 | full_era5_dataset = full_era5_dataset.expand_dims({'time': np.atleast_1d(timestep)}) 188 | 189 | full_era5_dataset.to_netcdf(path='%s/%d%02d/era5_%d%02d%02d%02d_full.nc' % (args['netcdf_outdir'], year, month, year, month, day, hour), mode='w', engine='netcdf4') 190 | -------------------------------------------------------------------------------- /evaluation/performance_diagrams.py: -------------------------------------------------------------------------------- 1 | """ 2 | Plot performance diagrams for a model. 3 | 4 | Author: Andrew Justin (andrewjustinwx@gmail.com) 5 | Script version: 2023.9.18 6 | """ 7 | import argparse 8 | import cartopy.crs as ccrs 9 | from matplotlib import colors 10 | from matplotlib.font_manager import FontProperties 11 | import matplotlib.pyplot as plt 12 | from matplotlib.ticker import FixedLocator 13 | import numpy as np 14 | import pandas as pd 15 | import pickle 16 | import xarray as xr 17 | import os 18 | import sys 19 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))) # this line allows us to import scripts outside the current directory 20 | from utils import settings, plotting_utils 21 | 22 | 23 | if __name__ == '__main__': 24 | """ 25 | All arguments listed in the examples are listed via argparse in alphabetical order below this comment block. 26 | """ 27 | parser = argparse.ArgumentParser() 28 | parser.add_argument('--confidence_level', type=int, default=95, help="Confidence interval. Options are: 90, 95, 99.") 29 | parser.add_argument('--dataset', type=str, help="'training', 'validation', or 'test'") 30 | parser.add_argument('--data_source', type=str, default='era5', help="Source of the variable data (ERA5, GDAS, etc.)") 31 | parser.add_argument('--domain_images', type=int, nargs=2, help='Number of images for each dimension the final stitched map for predictions: lon, lat') 32 | parser.add_argument('--domain', type=str, required=True, help='Domain of the data.') 33 | parser.add_argument('--forecast_hour', type=int, help='Forecast hour for the GDAS or GFS data.') 34 | parser.add_argument('--map_neighborhood', type=int, default=250, 35 | help="Neighborhood for the CSI map in kilometers. 
Options are: 50, 100, 150, 200, 250") 36 | parser.add_argument('--model_dir', type=str, required=True, help='Directory for the models.') 37 | parser.add_argument('--model_number', type=int, required=True, help='Model number.') 38 | 39 | args = vars(parser.parse_args()) 40 | 41 | model_properties_filepath = f"{args['model_dir']}/model_{args['model_number']}/model_{args['model_number']}_properties.pkl" 42 | model_properties = pd.read_pickle(model_properties_filepath) 43 | 44 | # Some older models do not have the 'dataset_properties' dictionary 45 | try: 46 | front_types = model_properties['dataset_properties']['front_types'] 47 | except KeyError: 48 | front_types = model_properties['front_types'] 49 | 50 | domain_extent_indices = settings.DEFAULT_DOMAIN_INDICES[args['domain']] 51 | 52 | stats_ds = xr.open_dataset('%s/model_%d/statistics/model_%d_statistics_%s_%s.nc' % (args['model_dir'], args['model_number'], args['model_number'], args['domain'], args['dataset'])) 53 | 54 | if type(front_types) == str: 55 | front_types = [front_types, ] 56 | 57 | # Probability threshold where CSI is maximized for each front type and domain 58 | max_csi_thresholds = dict() 59 | 60 | if args['domain'] not in list(max_csi_thresholds.keys()): 61 | max_csi_thresholds[args['domain']] = dict() 62 | 63 | for front_no, front_label in enumerate(front_types): 64 | 65 | if front_label not in list(max_csi_thresholds[args['domain']].keys()): 66 | max_csi_thresholds[args['domain']][front_label] = dict() 67 | 68 | ################################ CSI and reliability diagrams (panels a and b) ################################# 69 | true_positives_temporal = stats_ds[f'tp_temporal_{front_label}'].values 70 | false_positives_temporal = stats_ds[f'fp_temporal_{front_label}'].values 71 | false_negatives_temporal = stats_ds[f'fn_temporal_{front_label}'].values 72 | spatial_csi_ds = (stats_ds[f'tp_spatial_{front_label}'] / (stats_ds[f'tp_spatial_{front_label}'] + stats_ds[f'fp_spatial_{front_label}'] + stats_ds[f'fn_spatial_{front_label}'])).max('threshold') 73 | thresholds = stats_ds['threshold'].values 74 | 75 | if args['confidence_level'] != 90: 76 | CI_low, CI_high = (100 - args['confidence_level']) / 2, 50 + (args['confidence_level'] / 2) 77 | CI_low, CI_high = '%.1f' % CI_low, '%.1f' % CI_high 78 | else: 79 | CI_low, CI_high = 5, 95 80 | 81 | # Confidence intervals for POD and SR 82 | CI_POD = np.stack((stats_ds[f"POD_{CI_low}_{front_label}"].values, stats_ds[f"POD_{CI_high}_{front_label}"].values), axis=0) 83 | CI_SR = np.stack((stats_ds[f"SR_{CI_low}_{front_label}"].values, stats_ds[f"SR_{CI_high}_{front_label}"].values), axis=0) 84 | CI_CSI = np.stack((CI_SR ** -1 + CI_POD ** -1 - 1.) 
** -1, axis=0) 85 | CI_FB = np.stack(CI_POD * (CI_SR ** -1), axis=0) 86 | 87 | # Remove the zeros 88 | try: 89 | polygon_stop_index = np.min(np.where(CI_POD == 0)[2]) 90 | except IndexError: 91 | polygon_stop_index = 100 92 | 93 | ### Statistics with shape (boundary, threshold) after taking the sum along the time axis (axis=0) ### 94 | true_positives_temporal_sum = np.sum(true_positives_temporal, axis=0) 95 | false_positives_temporal_sum = np.sum(false_positives_temporal, axis=0) 96 | false_negatives_temporal_sum = np.sum(false_negatives_temporal, axis=0) 97 | 98 | ### Find the number of true positives and false positives in each probability bin ### 99 | true_positives_diff = np.abs(np.diff(true_positives_temporal_sum)) 100 | false_positives_diff = np.abs(np.diff(false_positives_temporal_sum)) 101 | observed_relative_frequency = np.divide(true_positives_diff, true_positives_diff + false_positives_diff) 102 | 103 | pod = np.divide(true_positives_temporal_sum, true_positives_temporal_sum + false_negatives_temporal_sum) # Probability of detection 104 | sr = np.divide(true_positives_temporal_sum, true_positives_temporal_sum + false_positives_temporal_sum) # Success ratio 105 | 106 | fig, axs = plt.subplots(1, 2, figsize=(15, 6)) 107 | axarr = axs.flatten() 108 | 109 | sr_matrix, pod_matrix = np.meshgrid(np.linspace(0, 1, 101), np.linspace(0, 1, 101)) 110 | csi_matrix = 1 / ((1/sr_matrix) + (1/pod_matrix) - 1) # CSI coordinates 111 | fb_matrix = pod_matrix * (sr_matrix ** -1) # Frequency Bias coordinates 112 | CSI_LEVELS = np.linspace(0, 1, 11) # CSI contour levels 113 | FB_LEVELS = [0.25, 0.5, 0.75, 1, 1.25, 1.5, 2, 3] # Frequency Bias levels 114 | cmap = 'Blues' # Colormap for the CSI contours 115 | axis_ticks = np.arange(0, 1.01, 0.1) 116 | axis_ticklabels = np.arange(0, 100.1, 10).astype(int) 117 | 118 | cs = axarr[0].contour(sr_matrix, pod_matrix, fb_matrix, FB_LEVELS, colors='black', linewidths=0.5, linestyles='--') # Plot FB levels 119 | axarr[0].clabel(cs, FB_LEVELS, fontsize=8) 120 | 121 | csi_contour = axarr[0].contourf(sr_matrix, pod_matrix, csi_matrix, CSI_LEVELS, cmap=cmap) # Plot CSI contours in 0.1 increments 122 | cbar = fig.colorbar(csi_contour, ax=axarr[0], pad=0.02, label='Critical Success Index (CSI)') 123 | cbar.set_ticks(axis_ticks) 124 | 125 | axarr[1].plot(thresholds, thresholds, color='black', linestyle='--', linewidth=0.5, label='Perfect Reliability') 126 | 127 | cell_text = [] # List of strings that will be used in the table near the bottom of this function 128 | 129 | ### CSI and reliability lines for each boundary ### 130 | boundary_colors = ['red', 'purple', 'brown', 'darkorange', 'darkgreen'] 131 | max_CSI_scores_by_boundary = np.zeros(shape=(5,)) 132 | for boundary, color in enumerate(boundary_colors): 133 | csi = np.power((1/sr[boundary]) + (1/pod[boundary]) - 1, -1) 134 | max_CSI_scores_by_boundary[boundary] = np.nanmax(csi) 135 | max_CSI_index = np.where(csi == max_CSI_scores_by_boundary[boundary])[0] 136 | max_CSI_threshold = thresholds[max_CSI_index][0] # Probability threshold where CSI is maximized 137 | max_csi_thresholds[args['domain']][front_label]['%s' % int((boundary + 1) * 50)] = np.round(max_CSI_threshold, 2) 138 | max_CSI_pod = pod[boundary][max_CSI_index][0] # POD where CSI is maximized 139 | max_CSI_sr = sr[boundary][max_CSI_index][0] # SR where CSI is maximized 140 | max_CSI_fb = max_CSI_pod / max_CSI_sr # Frequency bias 141 | 142 | cell_text.append([r'$\bf{%.2f}$' % max_CSI_threshold, 143 | r'$\bf{%.3f}$' % max_CSI_scores_by_boundary[boundary] + 
r'$^{%.3f}_{%.3f}$' % (CI_CSI[1, boundary, max_CSI_index][0], CI_CSI[0, boundary, max_CSI_index][0]), 144 | r'$\bf{%.1f}$' % (max_CSI_pod * 100) + r'$^{%.1f}_{%.1f}$' % (CI_POD[1, boundary, max_CSI_index][0] * 100, CI_POD[0, boundary, max_CSI_index][0] * 100), 145 | r'$\bf{%.1f}$' % (max_CSI_sr * 100) + r'$^{%.1f}_{%.1f}$' % (CI_SR[1, boundary, max_CSI_index][0] * 100, CI_SR[0, boundary, max_CSI_index][0] * 100), 146 | r'$\bf{%.1f}$' % ((1 - max_CSI_sr) * 100) + r'$^{%.1f}_{%.1f}$' % ((1 - CI_SR[1, boundary, max_CSI_index][0]) * 100, (1 - CI_SR[0, boundary, max_CSI_index][0]) * 100), 147 | r'$\bf{%.3f}$' % max_CSI_fb + r'$^{%.3f}_{%.3f}$' % (CI_FB[1, boundary, max_CSI_index][0], CI_FB[0, boundary, max_CSI_index][0])]) 148 | 149 | # Plot CSI lines 150 | axarr[0].plot(max_CSI_sr, max_CSI_pod, color=color, marker='*', markersize=10) 151 | axarr[0].plot(sr[boundary], pod[boundary], color=color, linewidth=1) 152 | 153 | # Plot reliability curve 154 | axarr[1].plot(thresholds[1:], observed_relative_frequency[boundary], color=color, linewidth=1) 155 | 156 | # Confidence interval 157 | xs = np.concatenate([CI_SR[0, boundary, :polygon_stop_index], CI_SR[1, boundary, :polygon_stop_index][::-1]]) 158 | ys = np.concatenate([CI_POD[0, boundary, :polygon_stop_index], CI_POD[1, boundary, :polygon_stop_index][::-1]]) 159 | axarr[0].fill(xs, ys, alpha=0.3, color=color) # Shade the confidence interval 160 | 161 | axarr[0].set_xticklabels(axis_ticklabels[::-1]) # False alarm rate on x-axis means values are reversed 162 | axarr[0].set_xlabel("False Alarm Rate (FAR; %)") 163 | axarr[0].set_ylabel("Probability of Detection (POD; %)") 164 | axarr[0].set_title(r'$\bf{a)}$ $\bf{CSI}$ $\bf{diagram}$ [confidence level = %d%%]' % args['confidence_level']) 165 | 166 | axarr[1].set_xticklabels(axis_ticklabels) 167 | axarr[1].set_xlabel("Forecast Probability (uncalibrated; %)") 168 | axarr[1].set_ylabel("Observed Relative Frequency (%)") 169 | axarr[1].set_title(r'$\bf{b)}$ $\bf{Reliability}$ $\bf{diagram}$') 170 | 171 | for ax in axarr: 172 | ax.set_xticks(axis_ticks) 173 | ax.set_yticks(axis_ticks) 174 | ax.set_yticklabels(axis_ticklabels) 175 | ax.grid(color='black', alpha=0.1) 176 | ax.set_xlim(0, 1) 177 | ax.set_ylim(0, 1) 178 | ################################################################################################################ 179 | 180 | ############################################# Data table (panel c) ############################################# 181 | columns = ['Threshold*', 'CSI', 'POD %', 'SR %', 'FAR %', 'FB'] # Column names 182 | rows = ['50 km', '100 km', '150 km', '200 km', '250 km'] # Row names 183 | 184 | table_axis = plt.axes([0.063, -0.06, 0.4, 0.2]) 185 | table_axis.set_title(r'$\bf{c)}$ $\bf{Data}$ $\bf{table}$ [confidence level = %d%%]' % args['confidence_level'], x=0.5, y=0.135, pad=-4) 186 | table_axis.axis('off') 187 | table_axis.text(0.16, -2.7, '* probability threshold where CSI is maximized') # Add disclaimer for probability threshold column 188 | stats_table = table_axis.table(cellText=cell_text, rowLabels=rows, rowColours=boundary_colors, colLabels=columns, cellLoc='center') 189 | stats_table.scale(1, 3) # Make the table larger 190 | 191 | ### Shade the cells and make the cell text larger ### 192 | for cell in stats_table._cells: 193 | stats_table._cells[cell].set_alpha(.7) 194 | stats_table._cells[cell].set_text_props(fontproperties=FontProperties(size='xx-large', stretch='expanded')) 195 | 
################################################################################################################ 196 | 197 | ########################################## Spatial CSI map (panel d) ########################################### 198 | # Colorbar keyword arguments 199 | cbar_kwargs = {'label': 'CSI', 'pad': 0} 200 | 201 | # Adjust the spatial CSI plot based on the domain 202 | if args['domain'] == 'conus': 203 | spatial_axis_extent = [0.52, -0.582, 0.512, 0.544] 204 | cbar_kwargs['shrink'] = 0.919 205 | spatial_plot_xlabels = [-140, -105, -70] 206 | spatial_plot_ylabels = [30, 40, 50] 207 | else: 208 | spatial_axis_extent = [0.538, -0.6, 0.48, 0.577] 209 | cbar_kwargs['shrink'] = 0.862 210 | spatial_plot_xlabels = [-150, -120, -90, -60, -30, 0, 120, 150, 180] 211 | spatial_plot_ylabels = [0, 20, 40, 60, 80] 212 | 213 | right_labels = False # Disable latitude labels on the right side of the subplot 214 | top_labels = False # Disable longitude labels on top of the subplot 215 | left_labels = True # Latitude labels on the left side of the subplot 216 | bottom_labels = True # Longitude labels on the bottom of the subplot 217 | 218 | ## Set up the spatial CSI plot ### 219 | csi_cmap = plotting_utils.truncated_colormap('gnuplot2', maxval=0.9, n=10) 220 | extent = settings.DEFAULT_DOMAIN_EXTENTS[args['domain']] 221 | spatial_axis = plt.axes(spatial_axis_extent, projection=ccrs.Miller(central_longitude=250)) 222 | spatial_axis_title_text = r'$\bf{d)}$ $\bf{%d}$ $\bf{km}$ $\bf{CSI}$ $\bf{map}$' % args['map_neighborhood'] 223 | plotting_utils.plot_background(extent=extent, ax=spatial_axis) 224 | norm_probs = colors.Normalize(vmin=0.1, vmax=1) 225 | spatial_csi_ds = xr.where(spatial_csi_ds >= 0.1, spatial_csi_ds, float("NaN")) 226 | spatial_csi_ds.sel(boundary=args['map_neighborhood']).plot(ax=spatial_axis, x='longitude', y='latitude', norm=norm_probs, 227 | cmap=csi_cmap, transform=ccrs.PlateCarree(), alpha=0.6, cbar_kwargs=cbar_kwargs) 228 | spatial_axis.set_title(spatial_axis_title_text) 229 | gl = spatial_axis.gridlines(draw_labels=True, zorder=0, dms=True, x_inline=False, y_inline=False) 230 | gl.right_labels = right_labels 231 | gl.top_labels = top_labels 232 | gl.left_labels = left_labels 233 | gl.bottom_labels = bottom_labels 234 | gl.xlocator = FixedLocator(spatial_plot_xlabels) 235 | gl.ylocator = FixedLocator(spatial_plot_ylabels) 236 | gl.xlabel_style = {'size': 7} 237 | gl.ylabel_style = {'size': 8} 238 | ################################################################################################################ 239 | 240 | if args['domain'] == 'conus': 241 | domain_text = args['domain'].upper() 242 | else: 243 | domain_text = args['domain'] 244 | plt.suptitle(f'Model %d: %ss over %s domain' % (args['model_number'], settings.DEFAULT_FRONT_NAMES[front_label], domain_text), fontsize=20) # Create and plot the main title 245 | 246 | filename = f"%s/model_%d/performance_%s_%s_%s_{args['data_source']}.png" % (args['model_dir'], args['model_number'], front_label, args['dataset'], args['domain']) 247 | if args['data_source'] != 'era5': 248 | filename = filename.replace('.png', '_f%03d.png' % args['forecast_hour']) # Add forecast hour to the end of the filename 249 | 250 | plt.tight_layout() 251 | plt.savefig(filename, bbox_inches='tight', dpi=500) 252 | plt.close() 253 | 254 | # Thresholds for creating deterministic splines with different front types and neighborhoods 255 | model_properties['front_obj_thresholds'] = max_csi_thresholds 256 | 257 | with open(model_properties_filepath, 
'wb') as f: 258 | pickle.dump(model_properties, f) 259 | -------------------------------------------------------------------------------- /convert_grib_to_netcdf.py: -------------------------------------------------------------------------------- 1 | """ 2 | Convert GDAS and/or GFS grib files to netCDF files. 3 | 4 | Author: Andrew Justin (andrewjustinwx@gmail.com) 5 | Script version: 2023.7.24 6 | """ 7 | 8 | import argparse 9 | import time 10 | import xarray as xr 11 | from utils import variables 12 | import glob 13 | import numpy as np 14 | import os 15 | import tensorflow as tf 16 | 17 | 18 | if __name__ == "__main__": 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument('--grib_indir', type=str, required=True, help="Input directory for GDAS grib files.") 21 | parser.add_argument('--model', required=True, type=str, help="GDAS or GFS") 22 | parser.add_argument('--netcdf_outdir', type=str, required=True, help="Output directory for the netCDF files.") 23 | parser.add_argument('--init_time', type=int, nargs=4, required=True, help="Date and time for the data to be read in. (year, month, day, hour)") 24 | parser.add_argument('--overwrite_grib', action='store_true', help="Overwrite the split grib files if they exist.") 25 | parser.add_argument('--delete_original_grib', action='store_true', help="Delete the original grib files after they are split.") 26 | parser.add_argument('--delete_split_grib', action='store_true', help="Delete the split grib files after they have been opened.") 27 | parser.add_argument('--gpu', action='store_true', 28 | help="Use a GPU to perform calculations of additional variables. This can provide enormous speedups when generating " 29 | "very large amounts of data.") 30 | 31 | args = vars(parser.parse_args()) 32 | 33 | gpus = tf.config.list_physical_devices(device_type='GPU') 34 | if len(gpus) > 0 and args['gpu']: 35 | print("Using GPU for variable derivations") 36 | tf.config.set_visible_devices(devices=gpus[0], device_type='GPU') 37 | gpus = tf.config.get_visible_devices(device_type='GPU') 38 | tf.config.experimental.set_memory_growth(device=gpus[0], enable=True) 39 | else: 40 | print("Using CPUs for variable derivations") 41 | tf.config.set_visible_devices([], 'GPU') 42 | 43 | args['model'] = args['model'].lower() 44 | year, month, day, hour = args['init_time'] 45 | 46 | resolution = 0.25 47 | 48 | keys_to_extract = ['gh', 'mslet', 'r', 'sp', 't', 'u', 'v'] 49 | 50 | pressure_level_file_indices = [0, 2, 4, 5, 6] 51 | surface_data_file_indices = [2, 4, 5, 6] 52 | raw_pressure_data_file_index = 3 53 | mslp_data_file_index = 1 54 | 55 | # all lon/lat values in degrees 56 | start_lon, end_lon = 0, 360 # western boundary, eastern boundary 57 | start_lat, end_lat = 90, -90 # northern boundary, southern boundary 58 | unified_longitude_indices = np.arange(0, 360 / resolution) 59 | unified_latitude_indices = np.arange(0, 180 / resolution + 1).astype(int) 60 | lon_coords_360 = np.arange(start_lon, end_lon + resolution, resolution) 61 | 62 | domain_indices_isel = {'longitude': unified_longitude_indices, 63 | 'latitude': unified_latitude_indices} 64 | 65 | chunk_sizes = {'latitude': 721, 'longitude': 1440} 66 | 67 | dataset_dimensions = ('forecast_hour', 'pressure_level', 'latitude', 'longitude') 68 | 69 | grib_filename_format = f"%s/%d%02d/%s_%d%02d%02d%02d_f*.grib" % (args['grib_indir'], year, month, args['model'], year, month, day, hour) 70 | individual_variable_filename_format = f"%s/%d%02d/%s_*_%d%02d%02d%02d.grib" % (args['grib_indir'], year, month, 
args['model'], year, month, day, hour) 71 | 72 | ### Split grib files into one file per variable ### 73 | grib_files = list(glob.glob(grib_filename_format)) 74 | grib_files = [file for file in grib_files if 'idx' not in file] 75 | 76 | for key in keys_to_extract: 77 | output_file = f"%s/%d%02d/%s_%s_%d%02d%02d%02d.grib" % (args['grib_indir'], year, month, args['model'], key, year, month, day, hour) 78 | if (os.path.isfile(output_file) and args['overwrite_grib']) or not os.path.isfile(output_file): 79 | os.system(f'grib_copy -w shortName={key} {" ".join(grib_files)} {output_file}') 80 | 81 | if args['delete_original_grib']: 82 | [os.remove(file) for file in grib_files] 83 | 84 | time.sleep(5) # Pause the code for 5 seconds to ensure that all contents of the individual files are preserved 85 | 86 | # grib files by variable 87 | grib_files = sorted(glob.glob(individual_variable_filename_format)) 88 | 89 | pressure_level_files = [grib_files[index] for index in pressure_level_file_indices] 90 | surface_data_files = [grib_files[index] for index in surface_data_file_indices] 91 | 92 | raw_pressure_data_file = grib_files[raw_pressure_data_file_index] 93 | if 'mslp_data_file_index' in locals(): 94 | mslp_data_file = grib_files[mslp_data_file_index] 95 | mslp_data = xr.open_dataset(mslp_data_file, engine='cfgrib', backend_kwargs={'filter_by_keys': {'typeOfLevel': 'meanSea'}}, chunks=chunk_sizes).drop_vars(['step']) 96 | 97 | pressure_levels = [1000, 950, 900, 850, 700, 500] 98 | 99 | # Open the datasets 100 | pressure_level_data = xr.open_mfdataset(pressure_level_files, engine='cfgrib', backend_kwargs={'filter_by_keys': {'typeOfLevel': 'isobaricInhPa'}}, chunks=chunk_sizes, combine='nested').sel(isobaricInhPa=pressure_levels).drop_vars(['step']) 101 | surface_data = xr.open_mfdataset(surface_data_files, engine='cfgrib', backend_kwargs={'filter_by_keys': {'typeOfLevel': 'sigma'}}, chunks=chunk_sizes).drop_vars(['step']) 102 | raw_pressure_data = xr.open_dataset(raw_pressure_data_file, engine='cfgrib', backend_kwargs={'filter_by_keys': {'typeOfLevel': 'surface', 'stepType': 'instant'}}, chunks=chunk_sizes).drop_vars(['step']) 103 | 104 | # Calculate the forecast hours using the surface_data dataset 105 | try: 106 | run_time = surface_data['time'].values.astype('int64') 107 | except KeyError: 108 | run_time = surface_data['run_time'].values.astype('int64') 109 | 110 | valid_time = surface_data['valid_time'].values.astype('int64') 111 | forecast_hours = np.array((valid_time - int(run_time)) / 3.6e12, dtype='int32') 112 | 113 | try: 114 | num_forecast_hours = len(forecast_hours) 115 | except TypeError: 116 | num_forecast_hours = 1 117 | forecast_hours = [forecast_hours, ] 118 | 119 | if args['model'] in ['gdas', 'gfs']: 120 | mslp = mslp_data['mslet'].values # mean sea level pressure (eta model reduction) 121 | mslp_z = np.empty(shape=(num_forecast_hours, len(pressure_levels) + 1, chunk_sizes['latitude'], chunk_sizes['longitude'])) 122 | mslp_z[:, 0, :, :] = mslp / 100 # convert to hectopascals 123 | 124 | P = np.empty(shape=(num_forecast_hours, len(pressure_levels), chunk_sizes['latitude'], chunk_sizes['longitude']), dtype=np.float32) # create 3D array of pressure levels to match the shape of variable arrays 125 | for pressure_level_index, pressure_level in enumerate(pressure_levels): 126 | P[:, pressure_level_index, :, :] = pressure_level * 100 127 | 128 | print("Retrieving downloaded variables") 129 | ### Pressure level variables provided in the grib files ### 130 | T_pl = 
pressure_level_data['t'].values 131 | RH_pl = pressure_level_data['r'].values / 100 132 | u_pl = pressure_level_data['u'].values 133 | v_pl = pressure_level_data['v'].values 134 | z = pressure_level_data['gh'].values / 10 # Convert to dam 135 | if 'mslp_data_file_index' in locals(): 136 | mslp_z[:, 1:, :, :] = z 137 | 138 | ### Surface variables provided in the grib files ### 139 | sp = raw_pressure_data['sp'].values 140 | T_sigma = surface_data['t'].values 141 | RH_sigma = surface_data['r'].values / 100 142 | u_sigma = surface_data['u'].values 143 | v_sigma = surface_data['v'].values 144 | surface_data_latitudes = pressure_level_data['latitude'].values 145 | 146 | if len(gpus) > 0: 147 | T_pl = tf.convert_to_tensor(T_pl) 148 | RH_pl = tf.convert_to_tensor(RH_pl) 149 | P = tf.convert_to_tensor(P) 150 | sp = tf.convert_to_tensor(sp) 151 | T_sigma = tf.convert_to_tensor(T_sigma) 152 | RH_sigma = tf.convert_to_tensor(RH_sigma) 153 | 154 | print("Deriving additional variables") 155 | vap_pres_pl = RH_pl * variables.vapor_pressure(T_pl) 156 | Td_pl = variables.dewpoint_from_vapor_pressure(vap_pres_pl) 157 | Tv_pl = variables.virtual_temperature_from_dewpoint(T_pl, Td_pl, P) 158 | Tw_pl = variables.wet_bulb_temperature(T_pl, Td_pl) 159 | r_pl = variables.mixing_ratio_from_dewpoint(Td_pl, P) * 1000 # Convert to g/kg 160 | q_pl = variables.specific_humidity_from_dewpoint(Td_pl, P) * 1000 # Convert to g/kg 161 | theta_pl = variables.potential_temperature(T_pl, P) 162 | theta_e_pl = variables.equivalent_potential_temperature(T_pl, Td_pl, P) 163 | theta_v_pl = variables.virtual_potential_temperature(T_pl, Td_pl, P) 164 | theta_w_pl = variables.wet_bulb_potential_temperature(T_pl, Td_pl, P) 165 | 166 | # Create arrays of coordinates for the surface data 167 | vap_pres_sigma = RH_sigma * variables.vapor_pressure(T_sigma) 168 | Td_sigma = variables.dewpoint_from_vapor_pressure(vap_pres_sigma) 169 | Tv_sigma = variables.virtual_temperature_from_dewpoint(T_sigma, Td_sigma, sp) 170 | Tw_sigma = variables.wet_bulb_temperature(T_sigma, Td_sigma) 171 | r_sigma = variables.mixing_ratio_from_dewpoint(Td_sigma, sp) * 1000 # Convert to g/kg 172 | q_sigma = variables.specific_humidity_from_dewpoint(Td_sigma, sp) * 1000 # Convert to g/kg 173 | theta_sigma = variables.potential_temperature(T_sigma, sp) 174 | theta_e_sigma = variables.equivalent_potential_temperature(T_sigma, Td_sigma, sp) 175 | theta_v_sigma = variables.virtual_potential_temperature(T_sigma, Td_sigma, sp) 176 | theta_w_sigma = variables.wet_bulb_potential_temperature(T_sigma, Td_sigma, sp) 177 | 178 | T = np.empty(shape=(num_forecast_hours, len(pressure_levels) + 1, chunk_sizes['latitude'], chunk_sizes['longitude'])) 179 | Td = np.empty(shape=(num_forecast_hours, len(pressure_levels) + 1, chunk_sizes['latitude'], chunk_sizes['longitude'])) 180 | Tv = np.empty(shape=(num_forecast_hours, len(pressure_levels) + 1, chunk_sizes['latitude'], chunk_sizes['longitude'])) 181 | Tw = np.empty(shape=(num_forecast_hours, len(pressure_levels) + 1, chunk_sizes['latitude'], chunk_sizes['longitude'])) 182 | theta = np.empty(shape=(num_forecast_hours, len(pressure_levels) + 1, chunk_sizes['latitude'], chunk_sizes['longitude'])) 183 | theta_e = np.empty(shape=(num_forecast_hours, len(pressure_levels) + 1, chunk_sizes['latitude'], chunk_sizes['longitude'])) 184 | theta_v = np.empty(shape=(num_forecast_hours, len(pressure_levels) + 1, chunk_sizes['latitude'], chunk_sizes['longitude'])) 185 | theta_w = np.empty(shape=(num_forecast_hours, len(pressure_levels) + 1, 
chunk_sizes['latitude'], chunk_sizes['longitude'])) 186 | RH = np.empty(shape=(num_forecast_hours, len(pressure_levels) + 1, chunk_sizes['latitude'], chunk_sizes['longitude'])) 187 | r = np.empty(shape=(num_forecast_hours, len(pressure_levels) + 1, chunk_sizes['latitude'], chunk_sizes['longitude'])) 188 | q = np.empty(shape=(num_forecast_hours, len(pressure_levels) + 1, chunk_sizes['latitude'], chunk_sizes['longitude'])) 189 | u = np.empty(shape=(num_forecast_hours, len(pressure_levels) + 1, chunk_sizes['latitude'], chunk_sizes['longitude'])) 190 | v = np.empty(shape=(num_forecast_hours, len(pressure_levels) + 1, chunk_sizes['latitude'], chunk_sizes['longitude'])) 191 | sp_z = np.empty(shape=(num_forecast_hours, len(pressure_levels) + 1, chunk_sizes['latitude'], chunk_sizes['longitude'])) 192 | 193 | sp /= 100 # pascals (Pa) --> hectopascals (hPa) 194 | if len(gpus) > 0: 195 | T[:, 0, :, :] = T_sigma.numpy() 196 | T[:, 1:, :, :] = T_pl.numpy() 197 | Td[:, 0, :, :] = Td_sigma.numpy() 198 | Td[:, 1:, :, :] = Td_pl.numpy() 199 | Tv[:, 0, :, :] = Tv_sigma.numpy() 200 | Tv[:, 1:, :, :] = Tv_pl.numpy() 201 | Tw[:, 0, :, :] = Tw_sigma.numpy() 202 | Tw[:, 1:, :, :] = Tw_pl.numpy() 203 | theta[:, 0, :, :] = theta_sigma.numpy() 204 | theta[:, 1:, :, :] = theta_pl.numpy() 205 | theta_e[:, 0, :, :] = theta_e_sigma.numpy() 206 | theta_e[:, 1:, :, :] = theta_e_pl.numpy() 207 | theta_v[:, 0, :, :] = theta_v_sigma.numpy() 208 | theta_v[:, 1:, :, :] = theta_v_pl.numpy() 209 | theta_w[:, 0, :, :] = theta_w_sigma.numpy() 210 | theta_w[:, 1:, :, :] = theta_w_pl.numpy() 211 | RH[:, 0, :, :] = RH_sigma.numpy() 212 | RH[:, 1:, :, :] = RH_pl.numpy() 213 | r[:, 0, :, :] = r_sigma.numpy() 214 | r[:, 1:, :, :] = r_pl.numpy() 215 | q[:, 0, :, :] = q_sigma.numpy() 216 | q[:, 1:, :, :] = q_pl.numpy() 217 | sp_z[:, 0, :, :] = sp.numpy() 218 | else: 219 | T[:, 0, :, :] = T_sigma 220 | T[:, 1:, :, :] = T_pl 221 | Td[:, 0, :, :] = Td_sigma 222 | Td[:, 1:, :, :] = Td_pl 223 | Tv[:, 0, :, :] = Tv_sigma 224 | Tv[:, 1:, :, :] = Tv_pl 225 | Tw[:, 0, :, :] = Tw_sigma 226 | Tw[:, 1:, :, :] = Tw_pl 227 | theta[:, 0, :, :] = theta_sigma 228 | theta[:, 1:, :, :] = theta_pl 229 | theta_e[:, 0, :, :] = theta_e_sigma 230 | theta_e[:, 1:, :, :] = theta_e_pl 231 | theta_v[:, 0, :, :] = theta_v_sigma 232 | theta_v[:, 1:, :, :] = theta_v_pl 233 | theta_w[:, 0, :, :] = theta_w_sigma 234 | theta_w[:, 1:, :, :] = theta_w_pl 235 | RH[:, 0, :, :] = RH_sigma 236 | RH[:, 1:, :, :] = RH_pl 237 | r[:, 0, :, :] = r_sigma 238 | r[:, 1:, :, :] = r_pl 239 | q[:, 0, :, :] = q_sigma 240 | q[:, 1:, :, :] = q_pl 241 | sp_z[:, 0, :, :] = sp 242 | 243 | u[:, 0, :, :] = u_sigma 244 | u[:, 1:, :, :] = u_pl 245 | v[:, 0, :, :] = v_sigma 246 | v[:, 1:, :, :] = v_pl 247 | sp_z[:, 1:, :, :] = z 248 | 249 | pressure_levels = ['surface', '1000', '950', '900', '850', '700', '500'] 250 | 251 | print("Building final dataset") 252 | 253 | full_dataset_coordinates = dict(forecast_hour=forecast_hours, pressure_level=pressure_levels) 254 | full_dataset_variables = dict(T=(dataset_dimensions, T), 255 | Td=(dataset_dimensions, Td), 256 | Tv=(dataset_dimensions, Tv), 257 | Tw=(dataset_dimensions, Tw), 258 | theta=(dataset_dimensions, theta), 259 | theta_e=(dataset_dimensions, theta_e), 260 | theta_v=(dataset_dimensions, theta_v), 261 | theta_w=(dataset_dimensions, theta_w), 262 | RH=(dataset_dimensions, RH), 263 | r=(dataset_dimensions, r), 264 | q=(dataset_dimensions, q), 265 | u=(dataset_dimensions, u), 266 | v=(dataset_dimensions, v), 267 | sp_z=(dataset_dimensions, sp_z)) 
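    # Usage sketch (not part of the original script; the path below is hypothetical): every variable collected in
    # full_dataset_variables shares the ('forecast_hour', 'pressure_level', 'latitude', 'longitude') dimension order,
    # and the loop at the bottom of this script writes one netCDF file per forecast hour. A single level can then be
    # pulled from a saved file, e.g.:
    #
    #     ds = xr.open_dataset('/path/to/netcdf/202107/gfs_2021070100_f006_global.nc')
    #     theta_w_850 = ds['theta_w'].sel(pressure_level='850').isel(time=0, forecast_hour=0)  # 2-D (latitude, longitude) field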
268 | 269 | if 'mslp_data_file_index' in locals(): 270 | full_dataset_variables['mslp_z'] = (('forecast_hour', 'pressure_level', 'latitude', 'longitude'), mslp_z) 271 | 272 | full_dataset_coordinates['latitude'] = pressure_level_data['latitude'] 273 | full_dataset_coordinates['longitude'] = pressure_level_data['longitude'] 274 | 275 | full_grib_dataset = xr.Dataset(data_vars=full_dataset_variables, 276 | coords=full_dataset_coordinates).astype('float32') 277 | 278 | full_grib_dataset = full_grib_dataset.expand_dims({'time': np.atleast_1d(pressure_level_data['time'].values)}) 279 | 280 | monthly_dir = '%s/%d%02d' % (args['netcdf_outdir'], year, month) 281 | 282 | if not os.path.isdir(monthly_dir): 283 | os.mkdir(monthly_dir) 284 | 285 | for fcst_hr_index, forecast_hour in enumerate(forecast_hours): 286 | full_grib_dataset.isel(forecast_hour=np.atleast_1d(fcst_hr_index)).to_netcdf(path=f"%s/{args['model'].lower()}_%d%02d%02d%02d_f%03d_global.nc" % (monthly_dir, year, month, day, hour, forecast_hour), mode='w', engine='netcdf4') 287 | 288 | if args['delete_split_grib']: 289 | grib_files = sorted(glob.glob(individual_variable_filename_format + "*")) 290 | [os.remove(file) for file in grib_files] 291 | -------------------------------------------------------------------------------- /evaluation/generate_performance_stats.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generate performance statistics for a model. 3 | 4 | Author: Andrew Justin (andrewjustinwx@gmail.com) 5 | Script version: 2023.9.2.D1 6 | """ 7 | import argparse 8 | import glob 9 | import numpy as np 10 | import pandas as pd 11 | import random 12 | import tensorflow as tf 13 | import xarray as xr 14 | import os 15 | import sys 16 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))) # this line allows us to import scripts outside the current directory 17 | import file_manager as fm 18 | from utils import data_utils 19 | from utils.settings import DEFAULT_DOMAIN_EXTENTS 20 | 21 | 22 | def combine_statistics_for_dataset(): 23 | 24 | statistics_files = [] 25 | 26 | for year in years: 27 | statistics_files += list(sorted(glob.glob('%s/model_%d/statistics/model_%d_statistics_%s_%d*.nc' % 28 | (args['model_dir'], args['model_number'], args['model_number'], args['domain'], year)))) 29 | 30 | datasets_by_front_type = [] 31 | 32 | for front_no, front_type in enumerate(front_types): 33 | 34 | ### Temporal and spatial datasets need to be loaded separately because of differing dimensions (xarray bugs) ### 35 | dataset_performance_ds_temporal = xr.open_dataset(statistics_files[0], chunks={'time': 16})[['%s_temporal_%s' % (stat, front_type) for stat in ['tp', 'fp', 'tn', 'fn']]] 36 | dataset_performance_ds_spatial = xr.open_dataset(statistics_files[0], chunks={'time': 16})[['%s_spatial_%s' % (stat, front_type) for stat in ['tp', 'fp', 'tn', 'fn']]] 37 | for stats_file in statistics_files[1:]: 38 | dataset_performance_ds_spatial += xr.open_dataset(stats_file, chunks={'time': 16})[['%s_spatial_%s' % (stat, front_type) for stat in ['tp', 'fp', 'tn', 'fn']]] 39 | dataset_performance_ds_temporal = xr.merge([dataset_performance_ds_temporal, xr.open_dataset(stats_file, chunks={'time': 16})[['%s_temporal_%s' % (stat, front_type) for stat in ['tp', 'fp', 'tn', 'fn']]]]) 40 | dataset_performance_ds = xr.merge([dataset_performance_ds_spatial, dataset_performance_ds_temporal]) # Combine spatial and temporal data into one dataset 41 | 42 | tp_array_temporal = 
dataset_performance_ds['tp_temporal_%s' % front_type].values 43 | fp_array_temporal = dataset_performance_ds['fp_temporal_%s' % front_type].values 44 | fn_array_temporal = dataset_performance_ds['fn_temporal_%s' % front_type].values 45 | 46 | time_array = dataset_performance_ds['time'].values 47 | 48 | ### Bootstrap the temporal statistics to find confidence intervals ### 49 | POD_array = np.zeros([num_front_types, args['num_iterations'], 5, 100]) # probability of detection = TP / (TP + FN) 50 | SR_array = np.zeros([num_front_types, args['num_iterations'], 5, 100]) # success ratio = 1 - False Alarm Ratio = TP / (TP + FP) 51 | 52 | # 3 confidence intervals: 90, 95, and 99% 53 | CI_lower_POD = np.zeros([num_front_types, 3, 5, 100]) 54 | CI_lower_SR = np.zeros([num_front_types, 3, 5, 100]) 55 | CI_upper_POD = np.zeros([num_front_types, 3, 5, 100]) 56 | CI_upper_SR = np.zeros([num_front_types, 3, 5, 100]) 57 | 58 | num_timesteps = len(time_array) 59 | selectable_indices = range(num_timesteps) 60 | 61 | for iteration in range(args['num_iterations']): 62 | print(f"Iteration {iteration}/{args['num_iterations']}", end='\r') 63 | indices = random.choices(selectable_indices, k=num_timesteps) # Select a sample equal to the total number of timesteps 64 | 65 | POD_array[front_no, iteration, :, :] = np.divide(np.sum(tp_array_temporal[indices, :, :], axis=0), 66 | np.add(np.sum(tp_array_temporal[indices, :, :], axis=0), 67 | np.sum(fn_array_temporal[indices, :, :], axis=0))) 68 | SR_array[front_no, iteration, :, :] = np.divide(np.sum(tp_array_temporal[indices, :, :], axis=0), 69 | np.add(np.sum(tp_array_temporal[indices, :, :], axis=0), 70 | np.sum(fp_array_temporal[indices, :, :], axis=0))) 71 | print(f"Iteration {args['num_iterations']}/{args['num_iterations']}") 72 | 73 | ## Turn NaNs to zeros 74 | POD_array = np.nan_to_num(POD_array) 75 | SR_array = np.nan_to_num(SR_array) 76 | 77 | # Calculate confidence intervals at each probability bin 78 | for percent in np.arange(0, 100): 79 | CI_lower_POD[front_no, 0, :, percent] = np.percentile(POD_array[front_no, :, :, percent], q=5, axis=0) # lower bound for 90% confidence interval 80 | CI_lower_POD[front_no, 1, :, percent] = np.percentile(POD_array[front_no, :, :, percent], q=2.5, axis=0) # lower bound for 95% confidence interval 81 | CI_lower_POD[front_no, 2, :, percent] = np.percentile(POD_array[front_no, :, :, percent], q=0.5, axis=0) # lower bound for 99% confidence interval 82 | CI_upper_POD[front_no, 0, :, percent] = np.percentile(POD_array[front_no, :, :, percent], q=95, axis=0) # upper bound for 90% confidence interval 83 | CI_upper_POD[front_no, 1, :, percent] = np.percentile(POD_array[front_no, :, :, percent], q=97.5, axis=0) # upper bound for 95% confidence interval 84 | CI_upper_POD[front_no, 2, :, percent] = np.percentile(POD_array[front_no, :, :, percent], q=99.5, axis=0) # upper bound for 99% confidence interval 85 | 86 | CI_lower_SR[front_no, 0, :, percent] = np.percentile(SR_array[front_no, :, :, percent], q=5, axis=0) # lower bound for 90% confidence interval 87 | CI_lower_SR[front_no, 1, :, percent] = np.percentile(SR_array[front_no, :, :, percent], q=2.5, axis=0) # lower bound for 95% confidence interval 88 | CI_lower_SR[front_no, 2, :, percent] = np.percentile(SR_array[front_no, :, :, percent], q=0.5, axis=0) # lower bound for 99% confidence interval 89 | CI_upper_SR[front_no, 0, :, percent] = np.percentile(SR_array[front_no, :, :, percent], q=95, axis=0) # upper bound for 90% confidence interval 90 | CI_upper_SR[front_no, 1, :, 
percent] = np.percentile(SR_array[front_no, :, :, percent], q=97.5, axis=0) # upper bound for 95% confidence interval 91 | CI_upper_SR[front_no, 2, :, percent] = np.percentile(SR_array[front_no, :, :, percent], q=99.5, axis=0) # upper bound for 99% confidence interval 92 | 93 | dataset_performance_ds["POD_0.5_%s" % front_type] = (('boundary', 'threshold'), CI_lower_POD[front_no, 2, :, :]) 94 | dataset_performance_ds["POD_2.5_%s" % front_type] = (('boundary', 'threshold'), CI_lower_POD[front_no, 1, :, :]) 95 | dataset_performance_ds["POD_5_%s" % front_type] = (('boundary', 'threshold'), CI_lower_POD[front_no, 0, :, :]) 96 | dataset_performance_ds["POD_99.5_%s" % front_type] = (('boundary', 'threshold'), CI_upper_POD[front_no, 2, :, :]) 97 | dataset_performance_ds["POD_97.5_%s" % front_type] = (('boundary', 'threshold'), CI_upper_POD[front_no, 1, :, :]) 98 | dataset_performance_ds["POD_95_%s" % front_type] = (('boundary', 'threshold'), CI_upper_POD[front_no, 0, :, :]) 99 | dataset_performance_ds["SR_0.5_%s" % front_type] = (('boundary', 'threshold'), CI_lower_SR[front_no, 2, :, :]) 100 | dataset_performance_ds["SR_2.5_%s" % front_type] = (('boundary', 'threshold'), CI_lower_SR[front_no, 1, :, :]) 101 | dataset_performance_ds["SR_5_%s" % front_type] = (('boundary', 'threshold'), CI_lower_SR[front_no, 0, :, :]) 102 | dataset_performance_ds["SR_99.5_%s" % front_type] = (('boundary', 'threshold'), CI_upper_SR[front_no, 2, :, :]) 103 | dataset_performance_ds["SR_97.5_%s" % front_type] = (('boundary', 'threshold'), CI_upper_SR[front_no, 1, :, :]) 104 | dataset_performance_ds["SR_95_%s" % front_type] = (('boundary', 'threshold'), CI_upper_SR[front_no, 0, :, :]) 105 | 106 | datasets_by_front_type.append(dataset_performance_ds) 107 | 108 | final_performance_ds = xr.merge(datasets_by_front_type) 109 | final_performance_ds.to_netcdf(path='%s/model_%d/statistics/model_%d_statistics_%s_%s.nc' % (args['model_dir'], args['model_number'], args['model_number'], args['domain'], args['dataset']), mode='w', engine='netcdf4') 110 | 111 | 112 | if __name__ == '__main__': 113 | """ 114 | All arguments listed in the examples are listed via argparse in alphabetical order below this comment block. 115 | """ 116 | parser = argparse.ArgumentParser() 117 | parser.add_argument('--dataset', type=str, help="Dataset for which to make predictions. 
Options are: 'training', 'validation', 'test'") 118 | parser.add_argument('--year_and_month', type=int, nargs=2, help="Year and month for which to make predictions.") 119 | parser.add_argument('--combine', action='store_true', help="Combine calculated statistics for a dataset.") 120 | parser.add_argument('--domain', type=str, help='Domain of the data.') 121 | parser.add_argument('--forecast_hour', type=int, help='Forecast hour for the GDAS data') 122 | parser.add_argument('--gpu_device', type=int, nargs='+', help='GPU device number.') 123 | parser.add_argument('--memory_growth', action='store_true', help='Use memory growth on the GPU') 124 | parser.add_argument('--model_dir', type=str, required=True, help='Directory for the models.') 125 | parser.add_argument('--model_number', type=int, required=True, help='Model number.') 126 | parser.add_argument('--num_iterations', type=int, default=10000, help='Number of iterations to perform when bootstrapping the data.') 127 | parser.add_argument('--fronts_netcdf_indir', type=str, help='Main directory for the netcdf files containing frontal objects.') 128 | parser.add_argument('--data_source', type=str, default='era5', help='Data source for variables') 129 | parser.add_argument('--overwrite', action='store_true', help="Overwrite any existing statistics files.") 130 | 131 | args = vars(parser.parse_args()) 132 | 133 | model_properties = pd.read_pickle('%s/model_%d/model_%d_properties.pkl' % (args['model_dir'], args['model_number'], args['model_number'])) 134 | domain = args['domain'] 135 | 136 | # Some older models do not have the 'dataset_properties' dictionary 137 | try: 138 | front_types = model_properties['dataset_properties']['front_types'] 139 | num_dims = model_properties['dataset_properties']['num_dims'] 140 | except KeyError: 141 | front_types = model_properties['front_types'] 142 | if args['model_number'] in [6846496, 7236500, 7507525]: 143 | num_dims = (3, 3) 144 | 145 | num_front_types = model_properties['classes'] - 1 146 | 147 | if args['dataset'] is not None and args['year_and_month'] is not None: 148 | raise ValueError("--dataset and --year_and_month cannot be passed together.") 149 | elif args['dataset'] is None and args['year_and_month'] is None: 150 | raise ValueError("At least one of [--dataset, --year_and_month] must be passed.") 151 | elif args['year_and_month'] is not None: 152 | years, months = [args['year_and_month'][0]], [args['year_and_month'][1]] 153 | else: 154 | years, months = model_properties['%s_years' % args['dataset']], range(1, 13) 155 | 156 | if args['dataset'] is not None and args['combine']: 157 | combine_statistics_for_dataset() 158 | exit() 159 | 160 | if args['gpu_device'] is not None: 161 | gpus = tf.config.list_physical_devices(device_type='GPU') 162 | tf.config.set_visible_devices(devices=[gpus[gpu] for gpu in args['gpu_device']], device_type='GPU') 163 | 164 | # Allow for memory growth on the GPU. This will only use the GPU memory that is required rather than allocating all the GPU's memory. 
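        # Note: the call below only enables memory growth on the first GPU passed via --gpu_device. A sketch (using the
        # same tf.config API, not part of the original script) of enabling growth on every selected GPU would be:
        #
        #     for gpu in args['gpu_device']:
        #         tf.config.experimental.set_memory_growth(device=gpus[gpu], enable=True)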
165 | if args['memory_growth']: 166 | tf.config.experimental.set_memory_growth(device=[gpus[gpu] for gpu in args['gpu_device']][0], enable=True) 167 | 168 | for year in years: 169 | 170 | era5_files_obj = fm.DataFileLoader(args['fronts_netcdf_indir'], data_file_type='fronts-netcdf') 171 | era5_files_obj.test_years = [year, ] # does not matter which year attribute we set the years to 172 | front_files = era5_files_obj.front_files_test 173 | 174 | for month in months: 175 | 176 | front_files_month = [file for file in front_files if '_%d%02d' % (year, month) in file] 177 | 178 | if args['domain'] == 'full': 179 | print("full") 180 | for front_file in front_files_month: 181 | if any(['%02d_full.nc' % hour in front_file for hour in np.arange(3, 27, 6)]): 182 | front_files_month.pop(front_files_month.index(front_file)) 183 | 184 | prediction_file = f'%s/model_%d/probabilities/model_%d_pred_%s_%d%02d.nc' % \ 185 | (args['model_dir'], args['model_number'], args['model_number'], args['domain'], year, month) 186 | 187 | stats_dataset_path = '%s/model_%d/statistics/model_%d_statistics_%s_%d%02d.nc' % (args['model_dir'], args['model_number'], args['model_number'], args['domain'], year, month) 188 | if os.path.isfile(stats_dataset_path) and not args['overwrite']: 189 | print("WARNING: %s exists, pass the --overwrite argument to overwrite existing data." % stats_dataset_path) 190 | continue 191 | 192 | probs_ds = xr.open_dataset(prediction_file) 193 | lons = probs_ds['longitude'].values 194 | lats = probs_ds['latitude'].values 195 | 196 | fronts_ds = xr.open_mfdataset(front_files_month, combine='nested', concat_dim='time')\ 197 | .sel(longitude=slice(DEFAULT_DOMAIN_EXTENTS[args['domain']][0], DEFAULT_DOMAIN_EXTENTS[args['domain']][1]), 198 | latitude=slice(DEFAULT_DOMAIN_EXTENTS[args['domain']][3], DEFAULT_DOMAIN_EXTENTS[args['domain']][2])) 199 | 200 | fronts_ds_month = data_utils.reformat_fronts(fronts_ds.sel(time='%d-%02d' % (year, month)), front_types) 201 | 202 | time_array = probs_ds['time'].values 203 | num_timesteps = len(time_array) 204 | 205 | tp_array_spatial = np.zeros(shape=[num_front_types, len(lats), len(lons), 5, 100]).astype('int64') 206 | fp_array_spatial = np.zeros(shape=[num_front_types, len(lats), len(lons), 5, 100]).astype('int64') 207 | tn_array_spatial = np.zeros(shape=[num_front_types, len(lats), len(lons), 5, 100]).astype('int64') 208 | fn_array_spatial = np.zeros(shape=[num_front_types, len(lats), len(lons), 5, 100]).astype('int64') 209 | 210 | tp_array_temporal = np.zeros(shape=[num_front_types, num_timesteps, 5, 100]).astype('int64') 211 | fp_array_temporal = np.zeros(shape=[num_front_types, num_timesteps, 5, 100]).astype('int64') 212 | tn_array_temporal = np.zeros(shape=[num_front_types, num_timesteps, 5, 100]).astype('int64') 213 | fn_array_temporal = np.zeros(shape=[num_front_types, num_timesteps, 5, 100]).astype('int64') 214 | 215 | thresholds = np.linspace(0.01, 1, 100) # Probability thresholds for calculating performance statistics 216 | boundaries = np.array([50, 100, 150, 200, 250]) # Boundaries for checking whether a front is present (kilometers) 217 | 218 | bool_tn_fn_dss = dict({front: tf.convert_to_tensor(xr.where(fronts_ds_month == front_no + 1, 1, 0)['identifier'].values) for front_no, front in enumerate(front_types)}) 219 | bool_tp_fp_dss = dict({front: None for front in front_types}) 220 | probs_dss = dict({front: tf.convert_to_tensor(probs_ds[front].values) for front in front_types}) 221 | 222 | performance_ds = xr.Dataset(coords={'time': time_array, 
'longitude': lons, 'latitude': lats, 'boundary': boundaries, 'threshold': thresholds}) 223 | 224 | for front_no, front_type in enumerate(front_types): 225 | fronts_ds_month = data_utils.reformat_fronts(fronts_ds.sel(time='%d-%02d' % (year, month)), front_types) 226 | print("%d-%02d: %s (TN/FN)" % (year, month, front_type)) 227 | ### Calculate true/false negatives ### 228 | for i in range(100): 229 | """ 230 | True negative ==> model correctly predicts the lack of a front at a given point 231 | False negative ==> model does not predict a front, but a front exists 232 | 233 | The numbers of true negatives and false negatives are the same for all neighborhoods and are calculated WITHOUT expanding the fronts. 234 | If we were to calculate the negatives separately for each neighborhood, the number of misses would be artificially inflated, lowering the 235 | final CSI scores and making the neighborhood method effectively useless. 236 | """ 237 | tn = tf.where((probs_dss[front_type] < thresholds[i]) & (bool_tn_fn_dss[front_type] == 0), 1, 0) 238 | fn = tf.where((probs_dss[front_type] < thresholds[i]) & (bool_tn_fn_dss[front_type] == 1), 1, 0) 239 | 240 | tn_array_spatial[front_no, :, :, :, i] = tf.tile(tf.expand_dims(tf.reduce_sum(tn, axis=0), axis=-1), (1, 1, 5)) 241 | fn_array_spatial[front_no, :, :, :, i] = tf.tile(tf.expand_dims(tf.reduce_sum(fn, axis=0), axis=-1), (1, 1, 5)) 242 | tn_array_temporal[front_no, :, :, i] = tf.tile(tf.expand_dims(tf.reduce_sum(tn, axis=(1, 2)), axis=-1), (1, 5)) 243 | fn_array_temporal[front_no, :, :, i] = tf.tile(tf.expand_dims(tf.reduce_sum(fn, axis=(1, 2)), axis=-1), (1, 5)) 244 | 245 | ### Calculate true/false positives ### 246 | for boundary in range(5): 247 | fronts_ds_month = data_utils.expand_fronts(fronts_ds_month, iterations=2) # Expand fronts 248 | bool_tp_fp_dss[front_type] = tf.convert_to_tensor(xr.where(fronts_ds_month == front_no + 1, 1, 0)['identifier'].values) # 1 = cold front, 0 = not a cold front 249 | print("%d-%02d: %s (%d km)" % (year, month, front_type, (boundary + 1) * 50)) 250 | for i in range(100): 251 | """ 252 | True positive ==> model correctly identifies a front 253 | False positive ==> model predicts a front, but no front is present within the given neighborhood 254 | """ 255 | tp = tf.where((probs_dss[front_type] > thresholds[i]) & (bool_tp_fp_dss[front_type] == 1), 1, 0) 256 | fp = tf.where((probs_dss[front_type] > thresholds[i]) & (bool_tp_fp_dss[front_type] == 0), 1, 0) 257 | 258 | tp_array_spatial[front_no, :, :, boundary, i] = tf.reduce_sum(tp, axis=0) 259 | fp_array_spatial[front_no, :, :, boundary, i] = tf.reduce_sum(fp, axis=0) 260 | tp_array_temporal[front_no, :, boundary, i] = tf.reduce_sum(tp, axis=(1, 2)) 261 | fp_array_temporal[front_no, :, boundary, i] = tf.reduce_sum(fp, axis=(1, 2)) 262 | 263 | performance_ds["tp_spatial_%s" % front_type] = (('latitude', 'longitude', 'boundary', 'threshold'), tp_array_spatial[front_no]) 264 | performance_ds["fp_spatial_%s" % front_type] = (('latitude', 'longitude', 'boundary', 'threshold'), fp_array_spatial[front_no]) 265 | performance_ds["tn_spatial_%s" % front_type] = (('latitude', 'longitude', 'boundary', 'threshold'), tn_array_spatial[front_no]) 266 | performance_ds["fn_spatial_%s" % front_type] = (('latitude', 'longitude', 'boundary', 'threshold'), fn_array_spatial[front_no]) 267 | performance_ds["tp_temporal_%s" % front_type] = (('time', 'boundary', 'threshold'), tp_array_temporal[front_no]) 268 | performance_ds["fp_temporal_%s" % front_type] = (('time', 'boundary', 
'threshold'), fp_array_temporal[front_no]) 269 | performance_ds["tn_temporal_%s" % front_type] = (('time', 'boundary', 'threshold'), tn_array_temporal[front_no]) 270 | performance_ds["fn_temporal_%s" % front_type] = (('time', 'boundary', 'threshold'), fn_array_temporal[front_no]) 271 | 272 | performance_ds.to_netcdf(path=stats_dataset_path, mode='w', engine='netcdf4') 273 | -------------------------------------------------------------------------------- /convert_netcdf_to_tf.py: -------------------------------------------------------------------------------- 1 | """ 2 | Convert netCDF files containing variable and frontal boundary data into tensorflow datasets for model training. 3 | 4 | Author: Andrew Justin (andrewjustinwx@gmail.com) 5 | Script version: 2023.8.13 6 | """ 7 | import argparse 8 | import itertools 9 | import numpy as np 10 | import os 11 | import pandas as pd 12 | import pickle 13 | import tensorflow as tf 14 | import file_manager as fm 15 | from utils import data_utils, settings 16 | import xarray as xr 17 | 18 | 19 | if __name__ == '__main__': 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument('--year_and_month', type=int, nargs=2, required=True, 22 | help="Year and month for the netcdf data to be converted to tensorflow datasets.") 23 | parser.add_argument('--variable_data_source', type=str, default='era5', help="Data source or model containing the variable data.") 24 | parser.add_argument('--variables_netcdf_indir', type=str, required=True, 25 | help="Input directory for the netCDF files containing variable data.") 26 | parser.add_argument('--fronts_netcdf_indir', type=str, required=True, 27 | help="Input directory for the netCDF files containing frontal boundary data.") 28 | parser.add_argument('--tf_outdir', type=str, required=True, 29 | help="Output directory for the generated tensorflow datasets.") 30 | parser.add_argument('--front_types', type=str, nargs='+', required=True, 31 | help="Code(s) for the front types that will be generated in the tensorflow datasets. Refer to documentation in 'utils.data_utils.reformat_fronts' " 32 | "for more information on these codes.") 33 | parser.add_argument('--variables', type=str, nargs='+', help='Variables to select') 34 | parser.add_argument('--pressure_levels', type=str, nargs='+', help='Variables pressure levels to select') 35 | parser.add_argument('--num_dims', type=int, nargs=2, default=[3, 3], help='Number of dimensions in the variables and front object images, repsectively.') 36 | parser.add_argument('--domain', type=str, default='conus', help='Domain from which to pull the images.') 37 | parser.add_argument('--override_extent', type=float, nargs=4, 38 | help='Override the default domain extent by selecting a custom extent. [min lon, max lon, min lat, max lat]') 39 | parser.add_argument('--evaluation_dataset', action='store_true', 40 | help=''' 41 | Boolean flag that determines if the dataset being generated will be used for evaluating a model. 42 | If this flag is True, all of the following keyword arguments will be set and any values provided to 'netcdf_to_tf' 43 | by the user will be overriden: 44 | * num_dims = (_, 2) <=== NOTE: The first value of this tuple will NOT be overriden. 45 | * images = (1, 1) 46 | * image_size will be set to the size of the domain. 
47 | * keep_fraction will have no effect 48 | * shuffle_timesteps = False 49 | * shuffle_images = False 50 | * noise_fraction = 0.0 51 | * rotate_chance = 0.0 52 | * flip_chance_lon = 0.0 53 | * flip_chance_lat = 0.0 54 | ''') 55 | parser.add_argument('--images', type=int, nargs=2, default=[9, 1], 56 | help='Number of variables/front images along the longitude and latitude dimensions to generate for each timestep. The product of the 2 integers ' 57 | 'will be the total number of images generated per timestep.') 58 | parser.add_argument('--image_size', type=int, nargs=2, default=[128, 128], help='Size of the longitude and latitude dimensions of the images.') 59 | parser.add_argument('--shuffle_timesteps', action='store_true', 60 | help='Shuffle the timesteps when generating the dataset. This is particularly useful when generating very large ' 61 | 'datasets that cannot be shuffled on the fly during training.') 62 | parser.add_argument('--shuffle_images', action='store_true', 63 | help='Shuffle the order of the images in each timestep. This does NOT shuffle the entire dataset for the provided ' 64 | 'month, but rather only the images in each respective timestep. This is particularly useful when generating ' 65 | 'very large datasets that cannot be shuffled on the fly during training.') 66 | parser.add_argument('--add_previous_fronts', type=str, nargs='+', 67 | help='Optional front types from previous timesteps to include as predictors. If the dataset is over conus, the fronts ' 68 | 'will be pulled from the last 3-hour timestep. If the dataset is over the full domain, the fronts will be pulled ' 69 | 'from the last 6-hour timestep.') 70 | parser.add_argument('--front_dilation', type=int, default=0, help='Number of pixels to expand the fronts by in all directions.') 71 | parser.add_argument('--keep_fraction', type=float, default=0.0, 72 | help='The fraction of timesteps WITHOUT all necessary front types that will be retained in the dataset. Can be any float 0 <= x <= 1.') 73 | parser.add_argument('--noise_fraction', type=float, default=0.0, 74 | help='The fraction of pixels in each image that will contain noise. Can be any float 0 <= x < 1.') 75 | parser.add_argument('--rotate_chance', type=float, default=0.0, 76 | help='The probability that the current image will be rotated (in any direction, up to 270 degrees). Can be any float 0 <= x <= 1.') 77 | parser.add_argument('--flip_chance_lon', type=float, default=0.0, 78 | help='The probability that the current image will have its longitude dimension reversed. Can be any float 0 <= x <= 1.') 79 | parser.add_argument('--flip_chance_lat', type=float, default=0.0, 80 | help='The probability that the current image will have its latitude dimension reversed. Can be any float 0 <= x <= 1.') 81 | parser.add_argument('--overwrite', action='store_true', help='Overwrite the contents of any existing variables and fronts data.') 82 | parser.add_argument('--verbose', action='store_true', help='Print out the progress of the dataset generation.') 83 | parser.add_argument('--gpu_device', type=int, nargs='+', help='GPU device numbers.') 84 | parser.add_argument('--memory_growth', action='store_true', help='Use memory growth on the GPU(s).') 85 | 86 | args = vars(parser.parse_args()) 87 | 88 | if args['gpu_device'] is not None: 89 | gpus = tf.config.list_physical_devices(device_type='GPU') 90 | tf.config.set_visible_devices(devices=[gpus[gpu] for gpu in args['gpu_device']], device_type='GPU') 91 | 92 | # Allow for memory growth on the GPU. 
This will only use the GPU memory that is required rather than allocating all of the GPU's memory. 93 | if args['memory_growth']: 94 | for gpu in args['gpu_device']: 95 | tf.config.experimental.set_memory_growth(device=gpus[gpu], enable=True) 96 | year, month = args['year_and_month'][0], args['year_and_month'][1] 97 | 98 | tf_dataset_folder_variables = '%s/%s_%d%02d_tf' % (args['tf_outdir'], args['variable_data_source'], year, month) 99 | tf_dataset_folder_fronts = "%s/fronts_%d%02d_tf" % (args['tf_outdir'], year, month) 100 | 101 | if os.path.isdir(tf_dataset_folder_variables) or os.path.isdir(tf_dataset_folder_fronts): 102 | if args['overwrite']: 103 | print("WARNING: Tensorflow dataset(s) already exist for the provided year and month and will be overwritten.") 104 | else: 105 | raise FileExistsError("Tensorflow dataset(s) already exist for the provided year and month. If you would like to " 106 | "overwrite the existing datasets, pass the --overwrite flag into the command line.") 107 | 108 | if not os.path.isdir(args['tf_outdir']): 109 | try: 110 | os.mkdir(args['tf_outdir']) 111 | except FileExistsError: # When running in parallel, sometimes multiple instances will try to create this directory at once, resulting in a FileExistsError 112 | pass 113 | 114 | dataset_props_file = '%s/dataset_properties.pkl' % args['tf_outdir'] 115 | 116 | if not os.path.isfile(dataset_props_file): 117 | """ 118 | Save critical dataset information to a pickle file so it can be referenced later when generating data for other months. 119 | """ 120 | 121 | if args['evaluation_dataset']: 122 | """ 123 | Override all keyword arguments so the dataset will be prepared for model evaluation. 124 | """ 125 | print("WARNING: This dataset will be used for model evaluation, so the following arguments will be set and " 126 | "any provided values for these arguments will be overridden:") 127 | args['num_dims'] = tuple(args['num_dims']) 128 | args['images'] = (1, 1) 129 | 130 | if args['override_extent'] is None: 131 | args['image_size'] = (settings.DEFAULT_DOMAIN_INDICES[args['domain']][1] - settings.DEFAULT_DOMAIN_INDICES[args['domain']][0], 132 | settings.DEFAULT_DOMAIN_INDICES[args['domain']][3] - settings.DEFAULT_DOMAIN_INDICES[args['domain']][2]) 133 | else: 134 | args['image_size'] = (int((args['override_extent'][1] - args['override_extent'][0]) / 0.25 + 1), 135 | int((args['override_extent'][3] - args['override_extent'][2]) / 0.25 + 1)) 136 | 137 | args['shuffle_timesteps'] = False 138 | args['shuffle_images'] = False 139 | args['noise_fraction'] = 0.0 140 | args['rotate_chance'] = 0.0 141 | args['flip_chance_lon'] = 0.0 142 | args['flip_chance_lat'] = 0.0 143 | 144 | print(f"images = {args['images']}\n" 145 | f"image_size = {args['image_size']}\n" 146 | f"shuffle_timesteps = False\n" 147 | f"shuffle_images = False\n" 148 | f"noise_fraction = 0.0\n" 149 | f"rotate_chance = 0.0\n" 150 | f"flip_chance_lon = 0.0\n" 151 | f"flip_chance_lat = 0.0\n") 152 | 153 | dataset_props = dict({}) 154 | dataset_props['normalization_parameters'] = data_utils.normalization_parameters 155 | for key in sorted(['front_types', 'variables', 'pressure_levels', 'num_dims', 'images', 'image_size', 'front_dilation', 156 | 'noise_fraction', 'rotate_chance', 'flip_chance_lon', 'flip_chance_lat', 'shuffle_images', 'shuffle_timesteps', 157 | 'domain', 'evaluation_dataset', 'add_previous_fronts', 'keep_fraction', 'override_extent']): 158 | dataset_props[key] = args[key] 159 | 160 | with open(dataset_props_file, 'wb') as f: 161 |
pickle.dump(dataset_props, f) 162 | 163 | with open('%s/dataset_properties.txt' % args['tf_outdir'], 'w') as f: 164 | for key in sorted(dataset_props.keys()): 165 | f.write(f"{key}: {dataset_props[key]}\n") 166 | 167 | else: 168 | 169 | print("WARNING: Dataset properties file was found in %s. The following settings will be used from the file." % args['tf_outdir']) 170 | dataset_props = pd.read_pickle(dataset_props_file) 171 | 172 | for key in sorted(['front_types', 'variables', 'pressure_levels', 'num_dims', 'images', 'image_size', 'front_dilation', 173 | 'noise_fraction', 'rotate_chance', 'flip_chance_lon', 'flip_chance_lat', 'shuffle_images', 'shuffle_timesteps', 174 | 'domain', 'evaluation_dataset', 'add_previous_fronts', 'keep_fraction']): 175 | args[key] = dataset_props[key] 176 | print(f"{key}: {args[key]}") 177 | 178 | all_variables = ['T', 'Td', 'sp_z', 'u', 'v', 'theta_w', 'r', 'RH', 'Tv', 'Tw', 'theta_e', 'q', 'theta', 'theta_v'] 179 | all_pressure_levels = ['surface', '1000', '950', '900', '850'] if args['variable_data_source'] == 'era5' else ['surface', '1000', '950', '900', '850', '700', '500'] 180 | 181 | synoptic_only = args['domain'] == 'full' 182 | 183 | file_loader = fm.DataFileLoader(args['variables_netcdf_indir'], '%s-netcdf' % args['variable_data_source'], synoptic_only) 184 | file_loader.pair_with_fronts(args['fronts_netcdf_indir']) 185 | 186 | variables_netcdf_files = file_loader.data_files 187 | fronts_netcdf_files = file_loader.front_files 188 | 189 | 190 | 191 | ### Grab front files from previous timesteps so previous fronts can be used as predictors ### 192 | if args['add_previous_fronts'] is not None: 193 | files_to_remove = [] # variables and front files that will be removed from the dataset 194 | previous_fronts_netcdf_files = [] 195 | for file in fronts_netcdf_files: 196 | current_timestep = np.datetime64(f'{file[-18:-14]}-{file[-14:-12]}-{file[-12:-10]}T{file[-10:-8]}') 197 | previous_timestep = (current_timestep - np.timedelta64(3, "h")).astype(object) 198 | prev_year, prev_month, prev_day, prev_hour = previous_timestep.year, previous_timestep.month, previous_timestep.day, previous_timestep.hour 199 | previous_fronts_file = '%s/%d%02d/FrontObjects_%d%02d%02d%02d_full.nc' % (args['fronts_netcdf_indir'], prev_year, prev_month, prev_year, prev_month, prev_day, prev_hour) 200 | if os.path.isfile(previous_fronts_file): 201 | previous_fronts_netcdf_files.append(previous_fronts_file) # Add the previous fronts to the dataset 202 | else: 203 | files_to_remove.append(file) 204 | 205 | ### Remove files from the dataset if previous fronts are not available ### 206 | if len(files_to_remove) > 0: 207 | for file in files_to_remove: 208 | index_to_pop = fronts_netcdf_files.index(file) 209 | variables_netcdf_files.pop(index_to_pop), fronts_netcdf_files.pop(index_to_pop) 210 | 211 | if args['shuffle_timesteps']: 212 | zipped_list = list(zip(variables_netcdf_files, fronts_netcdf_files)) 213 | np.random.shuffle(zipped_list) 214 | variables_netcdf_files, fronts_netcdf_files = zip(*zipped_list) 215 | 216 | # check that the dates of the variables and fronts files match (the flag is verified below) 217 | files_match_flag = all(os.path.basename(variables_file).split('_')[1] == os.path.basename(fronts_file).split('_')[1] for variables_file, fronts_file in zip(variables_netcdf_files, fronts_netcdf_files)) 218 |
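# Illustrative note with hypothetical numbers (NOT taken from utils/settings.py): a worked example of the
# extent-to-index arithmetic in the override_extent branch below. Suppose the parent domain extent were
# [130.0, 370.0, 0.0, 80.0] ([min lon, max lon, min lat, max lat]) on a 0.25-degree grid, with
# override_extent = [140.0, 160.0, 30.0, 50.0]. The slices would then be:
#   longitude: slice((140.0 - 130.0) // 0.25, (160.0 - 130.0) // 0.25 + 1) = slice(40, 121)
#   latitude:  slice((80.0 - 50.0) // 0.25, (80.0 - 30.0) // 0.25 + 1) = slice(120, 201)  # latitude index increases southward
# and domain_size = ((160.0 - 140.0) // 0.25, (50.0 - 30.0) // 0.25) = (80, 80) grid points.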
219 | if args['override_extent'] is None: 220 | isel_kwargs = {'longitude': slice(settings.DEFAULT_DOMAIN_INDICES[args['domain']][0], settings.DEFAULT_DOMAIN_INDICES[args['domain']][1]), 221 | 'latitude': slice(settings.DEFAULT_DOMAIN_INDICES[args['domain']][2], settings.DEFAULT_DOMAIN_INDICES[args['domain']][3])} 222 | domain_size = (int(settings.DEFAULT_DOMAIN_INDICES[args['domain']][1] - settings.DEFAULT_DOMAIN_INDICES[args['domain']][0]), 223 | int(settings.DEFAULT_DOMAIN_INDICES[args['domain']][3] - settings.DEFAULT_DOMAIN_INDICES[args['domain']][2])) 224 | else: 225 | isel_kwargs = {'longitude': slice(int((args['override_extent'][0] - settings.DEFAULT_DOMAIN_EXTENTS[args['domain']][0]) // 0.25), 226 | int((args['override_extent'][1] - settings.DEFAULT_DOMAIN_EXTENTS[args['domain']][0]) // 0.25) + 1), 227 | 'latitude': slice(int((settings.DEFAULT_DOMAIN_EXTENTS[args['domain']][3] - args['override_extent'][3]) // 0.25), 228 | int((settings.DEFAULT_DOMAIN_EXTENTS[args['domain']][3] - args['override_extent'][2]) // 0.25) + 1)} 229 | domain_size = (int((args['override_extent'][1] - args['override_extent'][0]) // 0.25), 230 | int((args['override_extent'][3] - args['override_extent'][2]) // 0.25)) 231 | 232 | if not files_match_flag: 233 | raise OSError("%s/fronts files do not match" % args['variable_data_source']) 234 | 235 | variables_to_use = all_variables if args['variables'] is None else args['variables'] 236 | args['pressure_levels'] = all_pressure_levels if args['pressure_levels'] is None else [lvl for lvl in all_pressure_levels if lvl in args['pressure_levels']] 237 | 238 | num_timesteps = len(variables_netcdf_files) 239 | timesteps_kept = 0 240 | timesteps_discarded = 0 241 | 242 | for timestep_no in range(num_timesteps): 243 | 244 | front_dataset = xr.open_dataset(fronts_netcdf_files[timestep_no], engine='netcdf4').isel(**isel_kwargs).astype('float16') 245 | 246 | ### Reformat the fronts in the current timestep ### 247 | if args['front_types'] is not None: 248 | front_dataset = data_utils.reformat_fronts(front_dataset, args['front_types']) 249 | num_front_types = front_dataset.attrs['num_types'] + 1 250 | else: 251 | num_front_types = 17 252 | 253 | if args['front_dilation'] > 0: 254 | front_dataset = data_utils.expand_fronts(front_dataset, iterations=args['front_dilation']) # expand the front labels 255 | 256 | keep_timestep = np.random.random() <= args['keep_fraction'] # boolean flag for keeping timesteps without all front types 257 | 258 | front_dataset = front_dataset.isel(time=0).to_array().transpose('longitude', 'latitude', 'variable') 259 | front_bins = np.bincount(front_dataset.values.astype('int64').flatten(), minlength=num_front_types) # counts for each front type 260 | all_fronts_present = all(front_count > 0 for front_count in front_bins) # boolean flag that says if all front types are present in the current timestep 261 | 262 | if all_fronts_present or keep_timestep or args['evaluation_dataset']: 263 | 264 | if args['variable_data_source'] != 'era5': 265 | isel_kwargs['forecast_hour'] = 0 266 | 267 | variables_dataset = xr.open_dataset(variables_netcdf_files[timestep_no], engine='netcdf4')[variables_to_use].isel(**isel_kwargs).sel(pressure_level=args['pressure_levels']).transpose('time', 'longitude', 'latitude', 'pressure_level').astype('float16') 268 | variables_dataset = data_utils.normalize_variables(variables_dataset).isel(time=0).transpose('longitude', 'latitude', 'pressure_level').astype('float16') 269 | 270 | ### Reformat the fronts from the previous timestep ### 271 | if args['add_previous_fronts'] is not None: 272 | previous_front_dataset = xr.open_dataset(previous_fronts_netcdf_files[timestep_no], engine='netcdf4').isel(**isel_kwargs).astype('float16') 273 |
previous_front_dataset = data_utils.reformat_fronts(previous_front_dataset, args['add_previous_fronts']) 274 | 275 | if args['front_dilation'] > 0: 276 | previous_front_dataset = data_utils.expand_fronts(previous_front_dataset, iterations=args['front_dilation']) 277 | 278 | previous_front_dataset = previous_front_dataset.transpose('longitude', 'latitude') 279 | 280 | previous_fronts = np.zeros([len(previous_front_dataset['longitude'].values), 281 | len(previous_front_dataset['latitude'].values), 282 | len(args['pressure_levels'])], dtype=np.float16) 283 | 284 | for front_type_no, previous_front_type in enumerate(args['add_previous_fronts']): 285 | previous_fronts[..., 0] = np.where(previous_front_dataset['identifier'].values == front_type_no + 1, 1, 0) # Place previous front labels at the surface level 286 | variables_dataset[previous_front_type] = (('longitude', 'latitude', 'pressure_level'), previous_fronts) # Add previous fronts to the predictor dataset 287 | 288 | variables_dataset = variables_dataset.to_array().transpose('longitude', 'latitude', 'pressure_level', 'variable') 289 | 290 | if args['override_extent'] is None: 291 | if args['images'][0] > 1 and domain_size[0] > args['image_size'][0] + args['images'][0]: 292 | start_indices_lon = np.linspace(0, settings.DEFAULT_DOMAIN_INDICES[args['domain']][1] - settings.DEFAULT_DOMAIN_INDICES[args['domain']][0] - args['image_size'][0], 293 | args['images'][0]).astype(int) 294 | else: 295 | start_indices_lon = np.zeros((args['images'][0], ), dtype=int) 296 | 297 | if args['images'][1] > 1 and domain_size[1] > args['image_size'][1] + args['images'][1]: 298 | start_indices_lat = np.linspace(0, settings.DEFAULT_DOMAIN_INDICES[args['domain']][3] - settings.DEFAULT_DOMAIN_INDICES[args['domain']][2] - args['image_size'][1], 299 | args['images'][1]).astype(int) 300 | else: 301 | start_indices_lat = np.zeros((args['images'][1], ), dtype=int) 302 | 303 | else: 304 | if args['images'][0] > 1 and domain_size[0] > args['image_size'][0] + args['images'][0]: 305 | start_indices_lon = np.linspace(0, domain_size[0] - args['image_size'][0], args['images'][0]).astype(int) 306 | else: 307 | start_indices_lon = np.zeros((args['images'][0], ), dtype=int) 308 | 309 | if args['images'][1] > 1 and domain_size[1] > args['image_size'][1] + args['images'][1]: 310 | start_indices_lat = np.linspace(0, domain_size[1] - args['image_size'][1], args['images'][1]).astype(int) 311 | else: 312 | start_indices_lat = np.zeros((args['images'][1], ), dtype=int) 313 | 314 | image_order = list(itertools.product(start_indices_lon, start_indices_lat)) # Every possible combination of longitude and latitude starting points 315 | 316 | if args['shuffle_images']: 317 | np.random.shuffle(image_order) 318 | 319 | for image_start_indices in image_order: 320 | 321 | start_index_lon = image_start_indices[0] 322 | end_index_lon = start_index_lon + args['image_size'][0] 323 | start_index_lat = image_start_indices[1] 324 | end_index_lat = start_index_lat + args['image_size'][1] 325 | 326 | # boolean flags for rotating and flipping images 327 | rotate_image = np.random.random() <= args['rotate_chance'] 328 | flip_lon = np.random.random() <= args['flip_chance_lon'] 329 | flip_lat = np.random.random() <= args['flip_chance_lat'] 330 | 331 | if rotate_image: 332 | rotation_direction = np.random.randint(0, 2) # 0 = clockwise, 1 = counter-clockwise 333 | num_rotations = np.random.randint(1, 4) # n * 90 degrees 334 | 335 | variables_tensor = 
tf.convert_to_tensor(variables_dataset[start_index_lon:end_index_lon, start_index_lat:end_index_lat, :, :], dtype=tf.float16) 336 | if flip_lon: 337 | variables_tensor = tf.reverse(variables_tensor, axis=[0]) # Reverse values along the longitude dimension 338 | if flip_lat: 339 | variables_tensor = tf.reverse(variables_tensor, axis=[1]) # Reverse values along the latitude dimension 340 | if rotate_image: 341 | for rotation in range(num_rotations): 342 | variables_tensor = tf.reverse(tf.transpose(variables_tensor, perm=[1, 0, 2, 3]), axis=[rotation_direction]) # Rotate image 90 degrees 343 | 344 | if args['noise_fraction'] > 0: 345 | ### Add noise to image ### 346 | random_values = tf.random.uniform(shape=variables_tensor.shape) 347 | variables_tensor = tf.where(random_values < args['noise_fraction'] / 2, 0.0, variables_tensor) # add 0s to image 348 | variables_tensor = tf.where(random_values > 1.0 - (args['noise_fraction'] / 2), 1.0, variables_tensor) # add 1s to image 349 | 350 | if args['num_dims'][0] == 2: 351 | variables_tensor_shape_3d = variables_tensor.shape 352 | # Combine pressure level and variables dimensions, making the images 2D (excluding the final dimension) 353 | variables_tensor = tf.reshape(variables_tensor, [variables_tensor_shape_3d[0], variables_tensor_shape_3d[1], variables_tensor_shape_3d[2] * variables_tensor_shape_3d[3]]) 354 | 355 | variables_tensor_for_timestep = tf.data.Dataset.from_tensors(variables_tensor) 356 | if 'variables_tensors_for_month' not in locals(): 357 | variables_tensors_for_month = variables_tensor_for_timestep 358 | else: 359 | variables_tensors_for_month = variables_tensors_for_month.concatenate(variables_tensor_for_timestep) 360 | 361 | front_tensor = tf.convert_to_tensor(front_dataset[start_index_lon:end_index_lon, start_index_lat:end_index_lat, :], dtype=tf.int32) 362 | 363 | if flip_lon: 364 | front_tensor = tf.reverse(front_tensor, axis=[0]) # Reverse values along the longitude dimension 365 | if flip_lat: 366 | front_tensor = tf.reverse(front_tensor, axis=[1]) # Reverse values along the latitude dimension 367 | if rotate_image: 368 | for rotation in range(num_rotations): 369 | front_tensor = tf.reverse(tf.transpose(front_tensor, perm=[1, 0, 2]), axis=[rotation_direction]) # Rotate image 90 degrees 370 | 371 | if args['num_dims'][1] == 3: 372 | # Make the front object images 3D, with the size of the 3rd dimension equal to the number of pressure levels 373 | front_tensor = tf.tile(front_tensor, (1, 1, len(args['pressure_levels']))) 374 | else: 375 | front_tensor = front_tensor[:, :, 0] 376 | 377 | front_tensor = tf.cast(tf.one_hot(front_tensor, num_front_types), tf.float16) # One-hot encode the labels 378 | front_tensor_for_timestep = tf.data.Dataset.from_tensors(front_tensor) 379 | if 'front_tensors_for_month' not in locals(): 380 | front_tensors_for_month = front_tensor_for_timestep 381 | else: 382 | front_tensors_for_month = front_tensors_for_month.concatenate(front_tensor_for_timestep) 383 | 384 | timesteps_kept += 1 385 | else: 386 | timesteps_discarded += 1 387 | 388 | if args['verbose']: 389 | print("Timesteps complete: %d/%d (Retained/discarded: %d/%d)" % (timesteps_kept + timesteps_discarded, num_timesteps, timesteps_kept, timesteps_discarded), end='\r') 390 | 391 | print("Timesteps complete: %d/%d (Retained/discarded: %d/%d)" % (timesteps_kept + timesteps_discarded, num_timesteps, timesteps_kept, timesteps_discarded)) 392 | 393 | if args['overwrite']: 394 | if os.path.isdir(tf_dataset_folder_variables): 395 | 
shutil.rmtree(tf_dataset_folder_variables) 396 | if os.path.isdir(tf_dataset_folder_fronts): 397 | shutil.rmtree(tf_dataset_folder_fronts) 398 | 399 | try: 400 | tf.data.Dataset.save(variables_tensors_for_month, path=tf_dataset_folder_variables) 401 | tf.data.Dataset.save(front_tensors_for_month, path=tf_dataset_folder_fronts) 402 | print("Tensorflow datasets for %d-%02d saved to %s." % (year, month, args['tf_outdir'])) 403 | except NameError: 404 | print("No images could be retained with the provided arguments.") 405 | --------------------------------------------------------------------------------
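A minimal sketch of how the datasets written by convert_netcdf_to_tf.py might be read back for training, assuming a TensorFlow version that provides tf.data.Dataset.load (the counterpart of the tf.data.Dataset.save calls above). The output directory, the 2008-06 example month, the batch size, and the model are placeholders, not values taken from this repository:

import tensorflow as tf

tf_outdir = '/path/to/tf_datasets'  # same directory passed as --tf_outdir (placeholder path)

# Folder names follow the patterns used in convert_netcdf_to_tf.py: '<source>_<year><month>_tf' and 'fronts_<year><month>_tf'.
variables_ds = tf.data.Dataset.load('%s/era5_200806_tf' % tf_outdir)  # predictor images
fronts_ds = tf.data.Dataset.load('%s/fronts_200806_tf' % tf_outdir)   # one-hot front labels

# Pair each predictor image with its matching label image, then batch and prefetch for training.
training_ds = tf.data.Dataset.zip((variables_ds, fronts_ds)).batch(8).prefetch(tf.data.AUTOTUNE)

# model.fit(training_ds, ...)  # 'model' would be a compiled Keras model (hypothetical)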