├── .gitattributes
├── .gitignore
├── Model_chpt
│   ├── diffusion.pt
│   └── unet.pt
├── README.md
├── data
│   └── ERA5_const_sfc_variables.nc
├── download_ERA5
│   ├── ERA5_download_my_dates_sfc.sh
│   ├── ERA5_download_project.sh
│   ├── ERA5_download_single_date.sh
│   ├── preprocessing_concat_year.py
│   └── preprocessing_subsample.py
├── environment.yml
├── example.png
├── examples
│   ├── inference.ipynb
│   └── train_minimal.ipynb
├── inference
│   ├── CRPS.py
│   ├── compute_spectrum.py
│   ├── plot_error_metrics.py
│   ├── plot_spectrum.py
│   ├── plot_timestep_examples.py
│   ├── plot_timestep_std.py
│   ├── save_test_preds.py
│   └── save_test_truth.py
└── src
    ├── DatasetUS.py
    ├── Inference.py
    ├── Network.py
    ├── TrainDiffusion.py
    └── TrainUnet.py

/.gitattributes: -------------------------------------------------------------------------------- 1 | .pt filter=lfs diff=lfs merge=lfs -text 2 | *.pt filter=lfs diff=lfs merge=lfs -text 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | Models* 2 | results* 3 | runs* 4 | .idea/ 5 | *.pyc 6 | *.DS_Store 7 | *.ipynb_checkpoints* 8 | 9 | # Data 10 | data/samples_*.nc 11 | data/e5.*.nc 12 | 13 | 14 | -------------------------------------------------------------------------------- /Model_chpt/diffusion.pt: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:dc0f95b826f6a024fa90376d5dbe28dbb92b4eb541032c45a9ac29caea2dd4c1 3 | size 389695738 4 | -------------------------------------------------------------------------------- /Model_chpt/unet.pt: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:eab9192c6a76b4b01de25b9dfc3d25327c7c6a5d6c28ce6dd8e4278394e247ce 3 | size 389657415 4 | -------------------------------------------------------------------------------- /README.md:
-------------------------------------------------------------------------------- 1 | ## Generative diffusion-based downscaling for climate 2 | ### Robbie A. Watt & Laura A. Mansfield 3 | 4 | ![plot](./example.png) 5 | 6 | This repo contains code to accompany the Watt & Mansfield (2024) preprint. In this preprint, we apply a diffusion-based model and a U-Net to a downscaling problem with climate data. The diffusion model is based on the implementation by T. Karras et al. () and the code is adapted from . 7 | 8 | 9 | ## File structure 10 | * src: contains code used to train the models 11 | * inference: contains inference and plotting scripts 12 | * Model_chpt: contains model checkpoints 13 | * download_ERA5: contains scripts for downloading ERA5 data and processing it into netCDF files. 14 | 15 | ## Usage 16 | ### Download ERA5 data 17 | The script `download_ERA5/ERA5_download_my_dates_sfc.sh` downloads the variables (temperature at 2 m and zonal and meridional winds at 10 m) for the years 1953 to 2022 and saves the files into a directory named `data/`. You may need to edit the data directories. To reduce the size of the dataset, we randomly subsample 30 time steps per month (see `preprocessing_subsample.py`). The monthly files are then concatenated into yearly files and saved as `samples_{year}.nc` (see `preprocessing_concat_year.py`). 18 | 19 | We also use two variables that are constant in time: the land-sea mask and the topography. These are stored in `data/ERA5_const_sfc_variables.nc`, or can be downloaded manually from the Copernicus Climate Data Store (https://cds.climate.copernicus.eu/cdsapp#!/dataset/reanalysis-era5-single-levels?tab=form) by checking `geopotential` (z) and `land-sea mask` (lsm) (found under `Other`) and saving them in `netcdf` format.
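As an illustration of the time-subsampling step, here is a minimal Python sketch. It is not part of the repository, and the helper names `last_day_of_month` and `subsample_time_indices` are hypothetical. It assumes each monthly ERA5 file holds one field per hour, from 00 UTC on day 1 to 23 UTC on the last day, as the file names requested by `ERA5_download_project.sh` suggest. It follows the per-month seed `year*12 + month` and the selection of 30 random time indices used in `preprocessing_subsample.py`.

```python
import calendar
import random


def last_day_of_month(year: int, month: int) -> int:
    """Number of days in the month (full Gregorian leap-year rule)."""
    return calendar.monthrange(year, month)[1]


def subsample_time_indices(year: int, month: int, n_keep: int = 30) -> list[int]:
    """Hypothetical helper: pick n_keep random hourly time indices for one month.

    Assumption: the monthly file holds one field per hour, so it has
    24 * (days in month) time steps.
    """
    n_hours = 24 * last_day_of_month(year, month)
    time_inds = list(range(n_hours))
    # Seed is reproducible but differs for each year/month
    rng = random.Random(year * 12 + month)
    rng.shuffle(time_inds)
    return time_inds[:n_keep]


if __name__ == "__main__":
    inds = subsample_time_indices(2018, 1)
    print(len(inds), min(inds) >= 0, max(inds) < 24 * 31)  # prints: 30 True True
```

The download scripts compute the last day of February with a simple `year % 4` test, whereas `calendar.monthrange` applies the full Gregorian rule. The two agree for 1953 to 2022, because 2000 is a leap year under both rules.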
20 | 21 | ### Dependencies 22 | python>=3.9, torch, tensorboard, xarray, netcdf4, cartopy, matplotlib, scipy, numpy 23 | 24 | ### Training 25 | To train either the diffusion model or the U-Net from scratch, run the `src/TrainDiffusion.py` or `src/TrainUnet.py` script from the project root directory. 26 | 27 | ### Inference 28 | After training, run the inference scripts in the following order: 29 | 1. `save_test_truth.py`: processes the true test data and saves it into a single file for easier comparison with the model outputs. 30 | 2. `save_test_preds.py`: runs through all test data and saves the output into a single file. Run it once per model: `modelname=UNet` for the standard U-Net, `modelname=LinearInterpolation` for linear interpolation of the coarse-resolution variables onto the high-resolution grid (i.e., the inputs to the model) and `modelname=Diffusion` for the diffusion model. For the diffusion model, we generate many possible samples in a loop, each seeded with a different random number; currently we loop over `rngs=range(0, 30)`. 31 | 32 | After running the above scripts, you should have files saved as `output/{modelname}/samples_2018-2023.nc` (for diffusion, these are saved as `output/Diffusion/samples{i}_2018-2023.nc`, the file name pattern read by the plotting scripts, where `i` indexes the different generated samples). 33 | 34 | Plotting scripts: 35 | * `plot_timestep_examples.py` plots maps of each method for every 60th timestep (used for Fig. 1). 36 | * `plot_error_metrics.py` plots maps of error metrics across the full test dataset (Fig. 2) and prints the area-weighted mean across the domain. 37 | * `plot_spectrum.py` plots the power spectrum for all methods (Fig. 3). 38 | 39 | 40 | ## Citation 41 | ``` 42 | @misc{watt2024generative, 43 | title={Generative Diffusion-based Downscaling for Climate}, 44 | author={Robbie A. Watt and Laura A.
Mansfield}, 45 | year={2024}, 46 | eprint={2404.17752}, 47 | archivePrefix={arXiv}, 48 | primaryClass={physics.ao-ph} 49 | } 50 | ``` 51 | -------------------------------------------------------------------------------- /data/ERA5_const_sfc_variables.nc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/robbiewatt1/ClimateDiffuse/7b8ee181fa3baeacf08a2bf9dc275989048b4030/data/ERA5_const_sfc_variables.nc -------------------------------------------------------------------------------- /download_ERA5/ERA5_download_my_dates_sfc.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Download each month, then randomly keep 30 time steps per month (preprocessing_subsample.py) 3 | # so that samples are less correlated and the total dataset stays manageable in size 4 | dir="../data/" 5 | opts="-N -c -P ${dir}" 6 | 7 | n_days=1 8 | 9 | year_start=1953 10 | month_start=1 11 | 12 | year_end=2022 13 | month_end=12 14 | 15 | 16 | for year in $(seq ${year_start} 1 ${year_end}); do 17 | for m in $(seq ${month_start} 1 ${month_end}); do 18 | if [ "${m}" = 4 ] || [ "${m}" = 6 ] || [ "${m}" = 9 ] ||[ "${m}" = 11 ] ; then 19 | last_day=30 20 | elif [ "${m}" = 2 ]; then 21 | if [ "$(((${year}) % 4))" = 0 ]; then # divisible-by-4 rule is exact for 1901-2099 22 | last_day=29 23 | else last_day=28 24 | fi 25 | else last_day=31 26 | fi 27 | month=$(printf "%02d" ${m}) 28 | echo "Getting surface files for $year $month" 29 | source ERA5_download_project.sh ${year} ${month} ${last_day} 30 | echo "Subsample in time" 31 | python preprocessing_subsample.py --year ${year} --month ${month} --last_day ${last_day} --remove_files 32 | echo "Done for ${year} ${month}" 33 | done 34 | echo "Concatenate all months for ${year}" 35 | python preprocessing_concat_year.py --year ${year} --remove_files 36 | echo "Done for ${year}" 37 | done 38 | echo DONE 39 | 40 | -------------------------------------------------------------------------------- /download_ERA5/ERA5_download_project.sh:
-------------------------------------------------------------------------------- 1 | #!/bin/bash -f 2 | # 3 | # Original c-shell script converted to bash by Laura Mansfield 4 | # Bash script to download selected files from rda.ucar.edu using Wget 5 | # NOTE: this script uses bash syntax; if you want to run it under a different 6 | # shell, adapt the variable assignments to that shell's syntax 7 | # after you save the file, don't forget to make it executable 8 | # i.e. - "chmod 755 " 9 | # 10 | # Experienced Wget Users: add additional command-line flags here 11 | # Use the -r (--recursive) option with care 12 | dir="../data/" 13 | opts="-N -c -P ${dir}" 14 | # 15 | cert_opt="" 16 | # If you get a certificate verification error (Wget version 1.10 or higher), 17 | # uncomment the following line: 18 | #cert_opt="--no-check-certificate" 19 | # 20 | # get year, month and last day of the month from the command-line arguments 21 | year=${1?Error: year?} 22 | month=${2?Error: month?} 23 | last_day=${3?Error: last day of month?} 24 | 25 | # download the file(s) 26 | 27 | # temperature at 2 m 28 | wget $cert_opt $opts https://data.rda.ucar.edu/ds633.0/e5.oper.an.sfc/${year}${month}/e5.oper.an.sfc.128_167_2t.ll025sc.${year}${month}0100_${year}${month}${last_day}23.nc 29 | 30 | # u-component of wind at 10 m 31 | wget $cert_opt $opts https://data.rda.ucar.edu/ds633.0/e5.oper.an.sfc/${year}${month}/e5.oper.an.sfc.128_165_10u.ll025sc.${year}${month}0100_${year}${month}${last_day}23.nc 32 | 33 | # v-component of wind at 10 m 34 | wget $cert_opt $opts https://data.rda.ucar.edu/ds633.0/e5.oper.an.sfc/${year}${month}/e5.oper.an.sfc.128_166_10v.ll025sc.${year}${month}0100_${year}${month}${last_day}23.nc 35 | 36 | 37 | -------------------------------------------------------------------------------- /download_ERA5/ERA5_download_single_date.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # We will download every few days so they aren't too correlated 3 | # over N years so that this turns out to
be 1Tb? 4 | dir="../data/" 5 | opts="-N -c -P ${dir}" 6 | 7 | year=1987 8 | m=2 9 | 10 | 11 | 12 | if [ "${m}" = 4 ] || [ "${m}" = 6 ] || [ "${m}" = 9 ] ||[ "${m}" = 11 ] ; then 13 | last_day=30 14 | elif [ "${m}" = 2 ]; then 15 | if [ "$(((${year}) % 4))" = 0 ]; then 16 | last_day=29 17 | else last_day=28 18 | fi 19 | else last_day=31 20 | fi 21 | 22 | month=$(printf "%02d" ${m}) 23 | echo "Getting surface files for $year $month" 24 | source ERA5_download_project.sh ${year} ${month} ${last_day} 25 | echo "Subsample in time" 26 | python preprocessing_subsample.py --year ${year} --month ${month} --last_day ${last_day} --remove_files 27 | echo "Done for ${year} ${month}" 28 | 29 | 30 | echo "Concatenate all months for ${year}" 31 | python preprocessing_concat_year.py --year ${year} --remove_files 32 | echo DONE 33 | 34 | -------------------------------------------------------------------------------- /download_ERA5/preprocessing_concat_year.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import xarray as xr 4 | import matplotlib.pyplot as plt 5 | import random 6 | import argparse 7 | 8 | # Get arguments from argparser 9 | parser = argparse.ArgumentParser() 10 | ## Arguments: year 11 | parser.add_argument('--year', metavar='year', type=int) 12 | parser.add_argument('--remove_files', metavar='remove_files', action=argparse.BooleanOptionalAction) 13 | 14 | datadir = "../data/" 15 | 16 | ## Provide year and month as input to this file using args 17 | args = parser.parse_args() 18 | year = args.year 19 | remove_files = args.remove_files 20 | 21 | # Open first month 22 | month="01" 23 | filename = f"samples_{year}{month}.nc" 24 | path_to_file = f"{datadir}{filename}" 25 | ds = xr.open_dataset(path_to_file, engine="netcdf4") 26 | 27 | 28 | for m in range(2,13): 29 | month = f"{m:02d}" 30 | filename = f"samples_{year}{month}.nc" 31 | path_to_file = f"{datadir}{filename}" 32 | ds2 = 
xr.open_dataset(path_to_file, engine="netcdf4") 33 | 34 | # Concatenate along time axis 35 | ds = xr.concat((ds, ds2), dim="time") 36 | 37 | # Save 38 | save_file = f"samples_{year}.nc" 39 | ds.to_netcdf(f"{datadir}{save_file}") 40 | 41 | if remove_files: 42 | print("Removing intermediate files") 43 | for m in range(1,13): 44 | month = f"{m:02d}" 45 | filename = f"samples_{year}{month}.nc" 46 | path_to_file = f"{datadir}{filename}" 47 | os.remove(f"{datadir}{filename}") 48 | 49 | -------------------------------------------------------------------------------- /download_ERA5/preprocessing_subsample.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import xarray as xr 4 | import matplotlib.pyplot as plt 5 | import random 6 | import argparse 7 | 8 | datadir = "../data/" 9 | 10 | # Get arguments from argparser 11 | parser = argparse.ArgumentParser() 12 | ## Arguments: year 13 | parser.add_argument('--year', metavar='year', type=int) 14 | parser.add_argument('--month', metavar='month', type=int) 15 | parser.add_argument('--last_day', metavar='last_day', type=int) 16 | parser.add_argument('--remove_files', metavar='remove_files', action=argparse.BooleanOptionalAction) 17 | 18 | ## Provide year and month as input to this file using args 19 | args = parser.parse_args() 20 | year = args.year 21 | month = args.month 22 | last_day = args.last_day 23 | remove_files = args.remove_files 24 | 25 | first_day=1 26 | 27 | # Variable names of surface files 28 | varnames= {"VAR_2T":"128_167_2t", 29 | "VAR_10U":"128_165_10u", 30 | "VAR_10V":"128_166_10v"} 31 | 32 | var_keys = list(varnames.keys()) 33 | var_key = var_keys[0] 34 | 35 | # Set random seed for reproducibility, but different for each year/month 36 | seed = year*12 + month 37 | print(seed) 38 | random.seed(seed) 39 | 40 | ## First variable for setting up 41 | varname = varnames[var_key] 42 | filename = 
f"e5.oper.an.sfc.{varname}.ll025sc.{year}{month:02d}{first_day:02d}00_{year}{month:02d}{last_day}23.nc" 43 | 44 | # Open file 45 | path_to_file = f"{datadir}{filename}" 46 | ds = xr.open_dataset(path_to_file, engine="netcdf4") 47 | 48 | # Select time inds randomly 49 | time_inds = np.arange(len(ds.time), dtype=int) 50 | random.shuffle(time_inds) 51 | ## Select 30 time inds from this month 52 | time_inds = time_inds[0:30] 53 | 54 | # Pre-processed dataset 55 | ds_proc = ds.isel(time=time_inds) 56 | 57 | ## Open next vars and add them to the dataset. 58 | for var_key in var_keys[1:]: 59 | varname = varnames[var_key] 60 | filename = f"e5.oper.an.sfc.{varname}.ll025sc.{year}{month:02d}{first_day:02d}00_{year}{month:02d}{last_day}23.nc" 61 | 62 | # Open file 63 | path_to_file = f"{datadir}{filename}" 64 | ds = xr.open_dataset(path_to_file, engine="netcdf4") 65 | 66 | # Pre-processed dataset and add to existing 67 | ds_proc2 = ds.isel(time=time_inds) 68 | ds_proc = xr.merge((ds_proc, ds_proc2)) 69 | 70 | 71 | save_file = f"samples_{year}{month:02d}.nc" 72 | ds_proc.to_netcdf(f"{datadir}{save_file}") 73 | 74 | if remove_files: 75 | print("Removing intermediate files") 76 | for var_key in var_keys: 77 | varname = varnames[var_key] 78 | filename = f"e5.oper.an.sfc.{varname}.ll025sc.{year}{month:02d}{first_day:02d}00_{year}{month:02d}{last_day}23.nc" 79 | os.remove(f"{datadir}{filename}") 80 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: ClimateDiffuse 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - defaults 6 | dependencies: 7 | - python=3.10 8 | - pytorch 9 | - torchvision 10 | - pytorch-cuda=11.8 11 | - xarray 12 | - tensorboard 13 | - netcdf4 14 | - cartopy 15 | - matplotlib 16 | - scipy 17 | - jupyter 18 | - tqdm 19 | -------------------------------------------------------------------------------- /example.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/robbiewatt1/ClimateDiffuse/7b8ee181fa3baeacf08a2bf9dc275989048b4030/example.png -------------------------------------------------------------------------------- /inference/CRPS.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def crps(y_true, y_pred, sample_weight=None): 4 | """Calculate the Continuous Ranked Probability Score (CRPS) of an ensemble 5 | Inputs have shape (N_samples, N_features); y_pred has an extra leading ensemble axis 6 | Args: 7 | * y_true : np.array (N_samples, N_features) ground truth 8 | * y_pred : np.array (N_ensemble, N_samples, N_features) predictions from N_ensemble members 9 | * sample_weight : np.array (N_samples) weighting for samples e.g., area weighting 10 | Returns: 11 | * CRPS : np.array (N_features) 12 | """ 13 | n_ens = y_pred.shape[0] # number of ensemble members 14 | y_pred = np.sort(y_pred, axis=0) 15 | diff = y_pred[1:] - y_pred[:-1] # gaps between consecutive sorted members 16 | weight = np.arange(1, n_ens) * np.arange(n_ens - 1, 0, -1) # k * (n_ens - k) for k = 1..n_ens-1 17 | weight = np.expand_dims(weight, (-2,-1)) 18 | y_true = np.expand_dims(y_true, 0) 19 | absolute_error = np.mean(np.abs(y_pred - y_true), axis=(0)) 20 | per_obs_crps = absolute_error - np.sum(diff * weight, axis=0) / n_ens**2 # E|X - y| - 0.5 E|X - X'| 21 | return np.average(per_obs_crps, axis=0, weights=sample_weight) -------------------------------------------------------------------------------- /inference/compute_spectrum.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import scipy.stats as stats 4 | 5 | def compute_spectrum2d(data): 6 | if data.ndim == 2: 7 | data = data[np.newaxis, ...]
8 | 9 | N, N_y, N_x = data.shape 10 | if N_x == 2 * N_y: # domain twice as wide as tall: split into two square halves 11 | data1 = data[:, :, :N_y] 12 | data2 = data[:, :, N_y:] 13 | data = np.concatenate((data1, data2), axis=0) 14 | N, N_y, N_x = data.shape 15 | 16 | # Take the 2D FFT over (y, x) and compute the power |F|^2 17 | fourier_image = np.fft.fftn(data, axes=(1,2)) 18 | fourier_amplitudes = np.abs(fourier_image)**2 19 | 20 | # Get kx and ky 21 | kfreq_x = np.fft.fftfreq(N_x) * N_x 22 | kfreq_y = np.fft.fftfreq(N_y) * N_y 23 | 24 | # Combine into one wavenumber for both directions 25 | kfreq2D = np.meshgrid(kfreq_x, kfreq_y) 26 | knrm = np.sqrt(kfreq2D[0]**2 + kfreq2D[1]**2) 27 | knrm = np.repeat(knrm[np.newaxis, ...], repeats=N, axis=0) 28 | 29 | # Flatten arrays 30 | knrm = knrm.flatten() 31 | fourier_amplitudes = fourier_amplitudes.flatten() 32 | 33 | # Get k-bins and mean power within each bin 34 | kbins = np.arange(0.5, N_x//2, 1) 35 | Abins, _, _ = stats.binned_statistic(knrm, fourier_amplitudes, 36 | statistic="mean", 37 | bins=kbins) 38 | 39 | # Multiply by the area of each annular bin, pi * (k_outer^2 - k_inner^2) 40 | Abins *= np.pi * (kbins[1:] ** 2 - kbins[:-1] ** 2) 41 | 42 | # Get center of k-bin for plotting 43 | kvals = 0.5 * (kbins[1:] + kbins[:-1]) 44 | 45 | return (kvals, Abins) -------------------------------------------------------------------------------- /inference/plot_error_metrics.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import numpy as np 3 | import xarray as xr 4 | import matplotlib.pyplot as plt 5 | import cartopy 6 | import cartopy.crs as ccrs 7 | from CRPS import crps 8 | 9 | from mpl_toolkits.axes_grid1.inset_locator import inset_axes 10 | 11 | from compute_spectrum import compute_spectrum2d 12 | 13 | # Models 14 | models = ["Truth", "Diffusion", "UNet", "LinearInterpolation"] 15 | 16 | # Variables - defines three separate subplots 17 | varnames = ["VAR_2T", "VAR_10U", "VAR_10V"] 18 | vmax = [1, 1, 1] 19 | 20 | # For diffusion 21 | rngs = range(0, 30) 22 | n_ens = len(rngs) 23 |
24 | plot_dir = "../output/plots/" 25 | year_start = 2018 26 | year_end = 2023 27 | 28 | 29 | data = {} 30 | 31 | # Loop over data 32 | for m, model in enumerate(models): 33 | print(model) 34 | # Loop over ensembles for diffusion (probabilistic) 35 | if model == "Diffusion": # assumes "Truth" was loaded first (shapes are taken from it) 36 | ntime, nlat, nlon = data["Truth"]["VAR_2T"].shape 37 | diffusion = np.zeros((3, n_ens, ntime, nlat, nlon)) 38 | 39 | # Loop over data 40 | for r, rng in enumerate(rngs): 41 | filename = f"../output/{model}/samples{rng}_{year_start}-{year_end}.nc" 42 | ds = xr.open_dataset(filename, engine="netcdf4") 43 | 44 | for i, varname in enumerate(varnames): 45 | diffusion[i, r] = ds[varname].to_numpy() 46 | 47 | data["DiffusionEns"] = {} 48 | data["DiffusionMean"] = {} 49 | for i, varname in enumerate(varnames): 50 | data["DiffusionEns"][varname] = diffusion[i] 51 | data["DiffusionMean"][varname] = diffusion[i].mean(axis=0) 52 | else: 53 | filename = f"../output/{model}/samples_{year_start}-{year_end}.nc" 54 | ds = xr.open_dataset(filename, engine="netcdf4") 55 | 56 | data[model] = {} 57 | for i, varname in enumerate(varnames): 58 | data[model][varname] = ds[varname].to_numpy() 59 | 60 | 61 | lon = ds.longitude 62 | lat = ds.latitude 63 | nlat, nlon = len(lat), len(lon) 64 | 65 | # Get area weights (proportional to cos(latitude)) 66 | area_weights = np.cos(np.deg2rad(lat.to_numpy())) 67 | area_weights = np.repeat(area_weights[:, None], nlon, axis=1) 68 | 69 | 70 | ## Error metrics (one map per variable on the nlat x nlon grid): 71 | MAE_diffusion = np.zeros((3, nlat, nlon)) 72 | MAE_UNet = np.zeros((3, nlat, nlon)) 73 | MAE_linearinterp = np.zeros((3, nlat, nlon)) 74 | 75 | RMSE_diffusion = np.zeros((3, nlat, nlon)) 76 | RMSE_UNet = np.zeros((3, nlat, nlon)) 77 | RMSE_linearinterp = np.zeros((3, nlat, nlon)) 78 | 79 | CRPS_diffusion = np.zeros((3, nlat, nlon)) 80 | 81 | print("Calc errors") 82 | for i, varname in enumerate(varnames): 83 | print(varname) 84 | MAE_diffusion[i] = np.mean(np.abs(
data["UNet"][varname] - data["Truth"][varname] ), axis=0) 86 | MAE_linearinterp[i] = np.mean(np.abs( data["LinearInterpolation"][varname] - data["Truth"][varname] ), axis=0) 87 | RMSE_diffusion[i] = np.sqrt(np.mean( ( data["DiffusionMean"][varname] - data["Truth"][varname] )**2, axis=0)) 88 | RMSE_UNet[i] = np.sqrt(np.mean(( data["UNet"][varname] - data["Truth"][varname] )**2, axis=0)) 89 | RMSE_linearinterp[i] = np.sqrt(np.mean(( data["LinearInterpolation"][varname] - data["Truth"][varname] )**2, 90 | axis=0)) 91 | 92 | diffusion_i = data["DiffusionEns"][varname] 93 | diffusion_i = diffusion_i.reshape((n_ens, ntime, nlat*nlon)) # flatten x-y axis 94 | CRPS_diffusion_flat = crps(data["Truth"][varname].reshape((ntime, nlat*nlon)), diffusion_i ) 95 | CRPS_diffusion[i] = CRPS_diffusion_flat.reshape((nlat, nlon)) 96 | 97 | print(MAE_diffusion.mean(), CRPS_diffusion.mean()) 98 | 99 | # Set up plots 100 | print("Plot") 101 | # plot MAE for temp, u, v 102 | plot_varnames = ["Temperature", "Zonal wind", "Meridional wind"] 103 | plot_var_labels = ["K", "m/s", "m/s"] 104 | 105 | plt.clf() 106 | plt.rcParams.update({'font.size': 18}) 107 | 108 | fig, axs = plt.subplots(3, 3, figsize=(16, 9), 109 | subplot_kw={'projection': ccrs.PlateCarree()}, 110 | gridspec_kw={'wspace': 0.1, 111 | 'hspace': 0.08}) 112 | 113 | for i in range(3): 114 | ax = axs[0, i] 115 | plt.sca(ax) 116 | ax.coastlines() 117 | ax.add_feature(cartopy.feature.LAKES, edgecolor='black', facecolor='none') 118 | pcm = plt.pcolormesh(lon, lat, MAE_UNet[i], 119 | vmin=0, vmax=vmax[i], 120 | cmap="YlOrRd", 121 | shading='nearest') 122 | plt.title(plot_varnames[i]) 123 | if i==0: 124 | plt.text(lon[0]-2, lat[len(lat)//2], f"U-Net MAE", transform=ccrs.PlateCarree(), 125 | rotation='vertical', ha='right', va='center', zorder=10) 126 | 127 | ax = axs[1, i] 128 | plt.sca(ax) 129 | ax.coastlines() 130 | ax.add_feature(cartopy.feature.LAKES, edgecolor='black', facecolor='none') 131 | pcm = plt.pcolormesh(lon, lat, 
MAE_diffusion[i], 132 | vmin=0, vmax=vmax[i], 133 | cmap = "YlOrRd", 134 | shading='nearest') 135 | if i==0: 136 | plt.text(lon[0]-2, lat[len(lat)//2], f"Diffusion MAE", transform=ccrs.PlateCarree(), 137 | rotation='vertical', ha='right', va='center', zorder=10) 138 | 139 | ax = axs[2, i] 140 | plt.sca(ax) 141 | ax.coastlines() 142 | ax.add_feature(cartopy.feature.LAKES, edgecolor='black', facecolor='none') 143 | pcm = plt.pcolormesh(lon, lat, CRPS_diffusion[i], 144 | vmin=0, vmax=vmax[i], 145 | cmap="YlOrRd", 146 | shading='nearest') 147 | if i==0: 148 | plt.text(lon[0]-2, lat[len(lat)//2], f"Diffusion CRPS", transform=ccrs.PlateCarree(), 149 | rotation='vertical', ha='right', va='center', 150 | zorder=10) 151 | 152 | cax = axs[2, i].inset_axes([0., -0.25, 1, 0.1]) 153 | plt.colorbar(pcm, cax=cax, orientation="horizontal", label=plot_var_labels[i]) 154 | 155 | # add labels 156 | axs_flat = axs.flatten() 157 | labels = ["a)", "b)", "c)", 158 | "d)", "e)", "f)", 159 | "g)", "h)", " i)"] 160 | for i in range(len(axs_flat)): 161 | plt.text(x=-0.07, y=1.03, s=labels[i], 162 | fontsize=16, transform=axs_flat[i].transAxes) 163 | 164 | plt.tight_layout() 165 | save_filename = f"{plot_dir}/error_maps.png" 166 | plt.savefig(save_filename, bbox_inches="tight") 167 | print(f"Saved as {save_filename}") 168 | 169 | # Area weighted means 170 | print(MAE_diffusion.shape,area_weights.shape) 171 | # Repeat area weights 172 | area_weights = np.repeat(area_weights[None, :, :], 3, axis=0) 173 | print(MAE_diffusion.shape,area_weights.shape) 174 | 175 | MAE_diffusion = np.average(MAE_diffusion, axis=(1,2), weights=area_weights) 176 | MAE_UNet = np.average(MAE_UNet, axis=(1,2), weights=area_weights) 177 | MAE_linearinterp = np.average(MAE_linearinterp, axis=(1,2), weights=area_weights) 178 | 179 | RMSE_diffusion = np.sqrt(np.average(RMSE_diffusion**2, axis=(1,2), weights=area_weights)) 180 | RMSE_UNet = np.sqrt(np.average(RMSE_UNet**2, axis=(1,2), weights=area_weights)) 181 | 
RMSE_linearinterp = np.sqrt(np.average(RMSE_linearinterp**2, axis=(1,2), weights=area_weights)) 182 | 183 | CRPS_diffusion = np.average(CRPS_diffusion, axis=(1,2), weights=area_weights) 184 | 185 | print(f" Diffusion mean abs: {MAE_diffusion}") 186 | print(f" UNet mean abs: {MAE_UNet}") 187 | print(f" Linear Interp mean abs: {MAE_linearinterp}") 188 | 189 | print(f" Diffusion RMSE : {RMSE_diffusion}") 190 | print(f" UNet RMSE : {RMSE_UNet}") 191 | print(f" Linear Interp RMSE : {RMSE_linearinterp}") 192 | 193 | print(f" Diffusion CRPS: {CRPS_diffusion}") 194 | 195 | -------------------------------------------------------------------------------- /inference/plot_spectrum.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import numpy as np 3 | import xarray as xr 4 | import matplotlib.pyplot as plt 5 | 6 | from compute_spectrum import compute_spectrum2d 7 | 8 | plt.rcParams.update({'font.size': 18}) 9 | 10 | 11 | # Models - lines shown in different colors 12 | models = ["Truth", "Diffusion", "UNet", "LinearInterpolation"] 13 | colors = ["black", "red", "blue", "yellowgreen"] 14 | 15 | plot_dir = "../output/plots/" 16 | year_start = 2018 17 | year_end = 2023 18 | 19 | # For diffusion 20 | rngs = range(0, 30) 21 | n_ens = len(rngs) 22 | 23 | # Variables - defines three separate subplots 24 | varnames = ["VAR_2T", "VAR_10U", "VAR_10V"] 25 | plot_varnames = ["Temperature", "Zonal wind", "Meridional wind"] 26 | 27 | # Set up plot 28 | plt.clf() 29 | fig, axs = plt.subplots(1, 3, figsize=(9, 3), 30 | sharex=True, sharey=True) 31 | 32 | spectrum_all = {} 33 | 34 | # Loop over data 35 | for m, model in enumerate(models): 36 | print(model) 37 | if model != "Diffusion": 38 | spectrum_all[model] = {} 39 | 40 | filename = f"../output/{model}/samples_{year_start}-{year_end}.nc" 41 | ds = xr.open_dataset(filename, engine="netcdf4") 42 | 43 | for i, varname in enumerate(varnames): 44 | data = ds[varname].to_numpy() 45 | 46 | # 
Compute spectrum 47 | kvals, Abins = compute_spectrum2d(data) 48 | spectrum_all[model][varname] = Abins 49 | 50 | plt.sca(axs[i]) 51 | axs[i].loglog(kvals, Abins, 52 | color=colors[m], 53 | label=models[m], 54 | alpha=0.5) 55 | plt.title(plot_varnames[i]) 56 | plt.xlabel("$k$") 57 | plt.ylabel("$P(k)$") 58 | elif model == "Diffusion": 59 | 60 | spectrum_all["Diffusion_mean"] = {} 61 | spectrum_all["Diffusion_std"] = {} 62 | 63 | # Loop over data 64 | Abins_diffusion = np.zeros((3, n_ens, len(kvals))) 65 | for r, rng in enumerate(rngs): 66 | filename = f"../output/{model}/samples{rng}_{year_start}-{year_end}.nc" 67 | ds = xr.open_dataset(filename, engine="netcdf4") 68 | 69 | for i, varname in enumerate(varnames): 70 | data = ds[varname].to_numpy() 71 | kvals, Abins = compute_spectrum2d(data) 72 | Abins_diffusion[i, r] = Abins 73 | print("Calculated for all ensembles in Diffusion") 74 | 75 | # Plot spectrum 76 | for i, varname in enumerate(varnames): 77 | spectrum_all["Diffusion_mean"][varname] = Abins_diffusion[i].mean(axis=0) 78 | spectrum_all["Diffusion_std"][varname] = Abins_diffusion[i].std(axis=0) 79 | 80 | Abins_mean = spectrum_all["Diffusion_mean"][varname] 81 | Abins_std = spectrum_all["Diffusion_std"][varname] 82 | 83 | print(Abins_mean) 84 | print(Abins_std) 85 | 86 | plt.sca(axs[i]) 87 | axs[i].loglog(kvals, Abins_mean, 88 | color=colors[m], 89 | label=models[m], 90 | alpha=0.5) 91 | axs[i].fill_between(kvals, 92 | Abins_mean - Abins_std, 93 | Abins_mean + Abins_std, 94 | color= colors[m], alpha=0.3) 95 | 96 | 97 | #plt.legend() 98 | fig_filename = f"{plot_dir}/spectrum.png" 99 | plt.tight_layout() 100 | plt.savefig(fig_filename, bbox_inches="tight") 101 | print(f"Saved as {fig_filename}") 102 | 103 | # Plot the differences 104 | # Set up plot 105 | plt.clf() 106 | fig, axs = plt.subplots(2, 3, figsize=(16, 8), 107 | sharex=True, sharey="row") 108 | for m, model in enumerate(models): 109 | print(model, colors[m]) 110 | if model != "Diffusion": 111 | for 
i, varname in enumerate(varnames): 112 | # Already computed spectrum and saved into dictionary 113 | Abins = spectrum_all[model][varname] 114 | 115 | plt.sca(axs[0, i]) 116 | axs[0, i].loglog(kvals, Abins, 117 | color=colors[m], 118 | label=models[m], 119 | alpha=0.5) 120 | plt.title(plot_varnames[i]) 121 | plt.ylabel("$P(k)$") 122 | 123 | if model != "Truth": 124 | diff = np.abs(spectrum_all["Truth"][varname] - Abins) 125 | plt.sca(axs[1, i]) 126 | axs[1, i].loglog(kvals, diff, 127 | color=colors[m], 128 | label=models[m], 129 | alpha=0.5) 130 | plt.ylabel("$P(k)$") 131 | plt.xlabel("$k$") 132 | 133 | elif model == "Diffusion": 134 | # Plot spectrum 135 | for i, varname in enumerate(varnames): 136 | Abins_mean = spectrum_all["Diffusion_mean"][varname] 137 | diff = np.abs(spectrum_all["Truth"][varname] - Abins_mean) 138 | Abins_std = spectrum_all["Diffusion_std"][varname] 139 | 140 | plt.sca(axs[0, i]) 141 | axs[0, i].loglog(kvals, Abins_mean, 142 | color=colors[m], 143 | label=models[m], 144 | alpha=0.5) 145 | 146 | plt.sca(axs[1, i]) 147 | axs[1, i].loglog(kvals, diff, 148 | color=colors[m], 149 | label=models[m], 150 | alpha=0.5) 151 | #axs[1, i].fill_between(kvals, 152 | # diff - Abins_std, 153 | # diff + Abins_std, 154 | # color= colors[m], alpha=0.3) 155 | # add labels 156 | axs_flat = axs.flatten() 157 | labels = ["a)", "b)", "c)", 158 | "d)", "e)", "f)"] 159 | for i in range(len(axs_flat)): 160 | plt.text(x=-0.15, y=1.02, s=labels[i], 161 | fontsize=16, transform=axs_flat[i].transAxes) 162 | 163 | # add legend at the bottom 164 | leg_ax = axs[0, 1] 165 | # Put a legend below current axis 166 | leg_ax.legend(loc='lower center', 167 | bbox_to_anchor=(0.5, -1.7), 168 | ncol=4) 169 | 170 | fig_filename = f"{plot_dir}/spectrum_incl_differences.png" 171 | plt.tight_layout() 172 | plt.savefig(fig_filename, bbox_inches="tight") 173 | print(f"Saved as {fig_filename}") 174 | 175 | 176 | -------------------------------------------------------------------------------- 
/inference/plot_timestep_examples.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import numpy as np 3 | import xarray as xr 4 | import matplotlib.pyplot as plt 5 | import cartopy 6 | import cartopy.crs as ccrs 7 | from mpl_toolkits.axes_grid1.inset_locator import inset_axes 8 | 9 | from compute_spectrum import compute_spectrum2d 10 | 11 | # Models 12 | models = ["Truth", "Diffusion", "UNet", "LinearInterpolation"] 13 | 14 | year_start = 2018 15 | year_end = 2023 16 | 17 | # Variables - defines three separate subplots 18 | varnames = ["VAR_2T", "VAR_10U", "VAR_10V"] 19 | vmin = [250, -10, -10] 20 | vmax = [300, 10, 10] 21 | vmax_stds = [3, 1, 1] 22 | cmaps = ["rainbow", "BrBG_r", "BrBG_r"] 23 | 24 | model = "Diffusion" 25 | rngs = range(0, 30) 26 | n_samples = len(rngs) 27 | 28 | plot_dir = "../output/plots/diffusion_pred/" 29 | # First, get the truth 30 | filename = f"../output/Truth/samples_{year_start}-{year_end}.nc" 31 | ds = xr.open_dataset(filename, engine="netcdf4") 32 | truth = xr.concat([ds[varname] for varname in varnames], dim="var") 33 | lat = ds.latitude 34 | lon = ds.longitude 35 | time = ds.time 36 | 37 | ntime, nlat, nlon = len(time), len(lat), len(lon) 38 | print(truth.shape) 39 | # Get coarse version / Linear interp 40 | filename = f"../output/LinearInterpolation/samples_{year_start}-{year_end}.nc" 41 | ds = xr.open_dataset(filename, engine="netcdf4") 42 | coarse = xr.concat([ds[varname] for varname in varnames], dim="var") 43 | 44 | # Get UNet 45 | filename = f"../output/UNet/samples_{year_start}-{year_end}.nc" 46 | ds = xr.open_dataset(filename, engine="netcdf4") 47 | unet = xr.concat([ds[varname] for varname in varnames], dim="var") 48 | 49 | # Get Diffusion 50 | rng=0 51 | filename = f"../output/Diffusion/samples{rng}_{year_start}-{year_end}.nc" 52 | ds = xr.open_dataset(filename, engine="netcdf4") 53 | diffusion = xr.concat([ds[varname] for varname in varnames], dim="var") 54 | 55 | # 
Plot 56 | plot_varnames = ["Temperature", "Zonal wind", "Meridional wind"] 57 | plot_var_labels = ["K", "m/s", "m/s"] 58 | plt.rcParams.update({'font.size': 18}) 59 | 60 | # Compare coarse, truth, U-Net and one diffusion sample at every 60th timestep 61 | for t, timestep in zip(range(0, ntime, 60), time[::60]): # t indexes the full time axis 62 | plt.clf() 63 | fig, axs = plt.subplots(4,3, figsize=(16, 10.2), 64 | subplot_kw={'projection': ccrs.PlateCarree()}, 65 | gridspec_kw={'wspace': 0.1, 66 | 'hspace': 0.1}) 67 | for i, varname in enumerate(varnames): 68 | # Row 0: coarse input 69 | ax = axs[0, i] 70 | plt.sca(ax) 71 | ax.coastlines() 72 | ax.add_feature(cartopy.feature.LAKES, edgecolor='black', facecolor='none') 73 | pcm = plt.pcolormesh(lon, lat, coarse[i, t], 74 | vmin=vmin[i], vmax=vmax[i], 75 | shading='nearest', 76 | cmap=cmaps[i]) 77 | plt.title(f"{plot_varnames[i]}") 78 | if i == 0: 79 | plt.text(lon[0]-2, lat[len(lat) // 2], f"Coarse", transform=ccrs.PlateCarree(), 80 | rotation='vertical', ha='right', va='center', zorder=10) 81 | #plt.colorbar(pcm, orientation="horizontal", label=f"{varname}") 82 | # Row 1: fine-resolution truth 83 | ax = axs[1, i] 84 | plt.sca(ax) 85 | ax.coastlines() 86 | ax.add_feature(cartopy.feature.LAKES, edgecolor='black', facecolor='none') 87 | pcm = plt.pcolormesh(lon, lat, truth[i, t], 88 | vmin=vmin[i], vmax=vmax[i], 89 | shading='nearest', 90 | cmap=cmaps[i]) 91 | if i == 0: 92 | plt.text(lon[0]-2, lat[len(lat) // 2], f"Truth", transform=ccrs.PlateCarree(), 93 | rotation='vertical', ha='right', va='center', zorder=10) 94 | #plt.title(f"Truth {varname}") 95 | #plt.colorbar(pcm, orientation="horizontal", label=f"{varname}") 96 | # Row 2: U-Net prediction 97 | ax = axs[2, i] 98 | plt.sca(ax) 99 | ax.coastlines() 100 | ax.add_feature(cartopy.feature.LAKES, edgecolor='black', facecolor='none') 101 | pcm = plt.pcolormesh(lon, lat, unet[i, t], 102 | vmin=vmin[i], vmax=vmax[i], 103 | shading='nearest', 104 | cmap=cmaps[i]) 105 | if i == 0: 106 | plt.text(lon[0]-2, lat[len(lat) // 2], f"U-Net", transform=ccrs.PlateCarree(), 107 | rotation='vertical', 
ha='right', va='center', zorder=10) 108 | #plt.title(f"UNet {varname}") 109 | #plt.colorbar(pcm, orientation="horizontal", label=f"{varname}") 110 | 111 | 112 | ax = axs[3, i] 113 | plt.sca(ax) 114 | ax.coastlines() 115 | ax.add_feature(cartopy.feature.LAKES, edgecolor='black', facecolor='none') 116 | pcm = plt.pcolormesh(lon, lat, diffusion[i, t], 117 | vmin=vmin[i], vmax=vmax[i], 118 | shading='nearest', 119 | cmap=cmaps[i]) 120 | if i == 0: 121 | plt.text(lon[0]-2, lat[len(lat) // 2], f"Diffusion", transform=ccrs.PlateCarree(), 122 | rotation='vertical', ha='right', va='center', zorder=10) 123 | cax = axs[3, i].inset_axes([0., -0.25, 1, 0.1]) 124 | plt.colorbar(pcm, cax = cax, orientation="horizontal", label=f"{plot_var_labels[i]}") 125 | 126 | # add labels 127 | axs_flat = axs.flatten() 128 | labels = ["a)", "b)", "c)", 129 | "d)", "e)", "f)", 130 | "g)", "h)", " i)", 131 | "j)", "k)", "l)"] 132 | 133 | for i in range(len(axs_flat)): 134 | plt.text(x=-0.08, y=1.02, s=labels[i], 135 | fontsize=16, transform=axs_flat[i].transAxes) 136 | 137 | plt.suptitle(f"Time: {timestep.values}") 138 | plt.tight_layout() 139 | save_filename = f"{plot_dir}/compare_all_{timestep.values}.png" 140 | plt.savefig(save_filename, bbox_inches="tight") 141 | plt.close() 142 | print(f"Saved as {save_filename}") 143 | 144 | 145 | 146 | 147 | 148 | -------------------------------------------------------------------------------- /inference/plot_timestep_std.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import numpy as np 3 | import xarray as xr 4 | import matplotlib.pyplot as plt 5 | import cartopy 6 | import cartopy.crs as ccrs 7 | from mpl_toolkits.axes_grid1.inset_locator import inset_axes 8 | 9 | from compute_spectrum import compute_spectrum2d 10 | 11 | # Models 12 | models = ["Truth", "Diffusion", "UNet", "LinearInterpolation"] 13 | 14 | year_start = 2018 15 | year_end = 2023 16 | 17 | # Variables - defines three separate 
subplots 18 | varnames = ["VAR_2T", "VAR_10U", "VAR_10V"] 19 | vmin = [250, -10, -10] 20 | vmax = [300, 10, 10] 21 | vmax_stds = [1, 1, 1] 22 | cmaps = ["rainbow", "BrBG_r", "BrBG_r"] 23 | 24 | model = "Diffusion" 25 | rngs = range(0, 30) 26 | n_samples = len(rngs) 27 | 28 | plot_dir = "../output/plots/diffusion_pred/" 29 | # First, get the truth 30 | filename = f"../output/Truth/samples_{year_start}-{year_end}.nc" 31 | ds = xr.open_dataset(filename, engine="netcdf4") 32 | truth = xr.concat([ds[varname] for varname in varnames], dim="var") 33 | lat = ds.latitude 34 | lon = ds.longitude 35 | time = ds.time 36 | 37 | ntime, nlat, nlon = len(time), len(lat), len(lon) 38 | print(truth.shape) 39 | # Get coarse version / Linear interp 40 | filename = f"../output/LinearInterpolation/samples_{year_start}-{year_end}.nc" 41 | ds = xr.open_dataset(filename, engine="netcdf4") 42 | coarse = xr.concat([ds[varname] for varname in varnames], dim="var") 43 | 44 | # Get UNet 45 | filename = f"../output/UNet/samples_{year_start}-{year_end}.nc" 46 | ds = xr.open_dataset(filename, engine="netcdf4") 47 | unet = xr.concat([ds[varname] for varname in varnames], dim="var") 48 | 49 | # Get Diffusion ensemble 50 | # Each saved seed is one member; stack the first 20 along a new "ens" dimension 51 | diffusion_members = [] 52 | for rng in range(20): 53 | filename = f"../output/Diffusion/samples{rng}_{year_start}-{year_end}.nc" 54 | ds = xr.open_dataset(filename, engine="netcdf4") 55 | diffusion = xr.concat([ds[varname] for varname in varnames], dim="var") 56 | diffusion_members.append(diffusion) 57 | 58 | # Concatenate once, rather than growing the ensemble array inside the loop 59 | diffusion_all = xr.concat(diffusion_members, dim="ens") 60 | 61 | # Ensemble mean and standard deviation at every grid point and timestep 62 | diffusion_mean = diffusion_all.mean(dim="ens") 63 | diffusion_std = diffusion_all.std(dim="ens") 64 | 65 | 66 | 67 | 68 | # Plot 69 | plot_varnames = ["Temperature", "Zonal wind", "Meridional wind"] 70 |
plot_var_labels = ["K", "m/s", "m/s"] 71 | plt.rcParams.update({'font.size': 18}) 72 | 73 | # Plot the diffusion ensemble standard deviation at every 60th timestep 74 | for t, timestep in zip(range(0, ntime, 60), time[::60]): # t indexes the full time axis 75 | plt.clf() 76 | fig, axs = plt.subplots(1,3, figsize=(16, 2.4), # 16, 10.2 77 | subplot_kw={'projection': ccrs.PlateCarree()}, 78 | gridspec_kw={'wspace': 0.1, 79 | 'hspace': 0.1}) 80 | for i, varname in enumerate(varnames): 81 | # Ensemble standard deviation for variable i 82 | ax = axs[i] 83 | plt.sca(ax) 84 | ax.coastlines() 85 | ax.add_feature(cartopy.feature.LAKES, edgecolor='black', facecolor='none') 86 | pcm = plt.pcolormesh(lon, lat, diffusion_std[i, t], 87 | vmin=0., vmax=vmax_stds[i], 88 | shading='nearest', 89 | cmap="YlOrRd",) 90 | if i==0: 91 | plt.text(lon[0]-2, lat[len(lat) // 2], f"Diffusion Std", transform=ccrs.PlateCarree(), 92 | rotation='vertical', ha='right', va='center', zorder=10) 93 | 94 | cax = axs[i].inset_axes([0., -0.25, 1, 0.1]) 95 | plt.colorbar(pcm, cax = cax, orientation="horizontal", label=f"{plot_var_labels[i]}") 96 | 97 | # add labels 98 | axs_flat = axs.flatten() 99 | labels = ["m)", "n)", "o)"] 100 | 101 | for i in range(len(axs_flat)): 102 | plt.text(x=-0.08, y=1.02, s=labels[i], 103 | fontsize=16, transform=axs_flat[i].transAxes) 104 | 105 | plt.tight_layout() 106 | save_filename = f"{plot_dir}/diffusion_std_{timestep.values}.png" 107 | plt.savefig(save_filename, bbox_inches="tight") 108 | plt.close() 109 | print(f"Saved as {save_filename}") 110 | 111 | 112 | 113 | 114 | 115 | -------------------------------------------------------------------------------- /inference/save_test_preds.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import xarray as xr 3 | import torch 4 | from torch.utils.data.dataloader import DataLoader 5 | 6 | sys.path.append('../src') 7 | from Inference import * 8 | 9 | # Test years: 2018-2022 (year_end is exclusive) 10 | year_start = 2018 11 | year_end = 2023 12 | 13 | # Choose which model to run
14 | modelname = "UNet" # "Diffusion" or "UNet" or "LinearInterpolation" 15 | 16 | # Device 17 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 18 | print(device) 19 | 20 | # Dirs 21 | data_dir="/home/Everyone/ERA5/data/" 22 | model_dir="/home/Everyone/Model_Inference/" 23 | save_dir = f"../output/{modelname}/" 24 | 25 | # Load model, sample function 26 | if modelname == "Diffusion": 27 | model = EDMPrecond((256, 128), 8, 28 | 3).to(device) 29 | model.load_state_dict(torch.load(f"{model_dir}/Model_chpt/diffusion.pt", map_location=device)) 30 | sample_function = sample_model_EDS 31 | num_steps = 100 32 | # One output file per random seed, i.e. 30 ensemble members 33 | rngs = range(0, 30) 34 | elif modelname == "UNet": 35 | sample_function = lambda test_batch, model, device, dataset_test, num_steps: sample_unet(test_batch, model, device, dataset_test) 36 | # Load model 37 | model = UNet((256, 128), 5, 3, label_dim=2, use_diffuse=False).to(device) 38 | model.load_state_dict(torch.load(f"{model_dir}/Model_chpt/unet.pt", map_location=device)) 39 | num_steps = None 40 | rngs = [""] 41 | 42 | elif modelname == "LinearInterpolation": 43 | def sample_function(test_batch, model, device, dataset_test, num_steps=None): 44 | coarse, fine = test_batch["coarse"], test_batch["fine"] 45 | return coarse, fine, coarse 46 | model = None 47 | num_steps = None 48 | rngs = [""] 49 | else: 50 | raise ValueError(f"Choose modelname from Diffusion, UNet or LinearInterpolation. You chose {modelname}") 51 | 52 | 53 | print(f"Running model {modelname} with sample function {sample_function}.") 54 | 55 | # Load dataset 56 | dataset_test = UpscaleDataset(data_dir, 57 | year_start=year_start, 58 | year_end=year_end, 59 | constant_variables=["lsm", "z"]) 60 | 61 | nlat = dataset_test.nlat 62 | nlon = dataset_test.nlon 63 | ntime = dataset_test.ntime 64 | print(ntime, nlat, nlon) 65 | 66 | 67 | 68 | for rng in rngs: 69 | if modelname == "Diffusion": 70 | torch.manual_seed(rng) # the sampler draws its noise with torch.randn, so seed torch rather than NumPy 71 | # Set up dataloader. Make sure shuffle=False so we go through timesteps in order.
72 | BATCH_SIZE = 32 73 | dataloader = DataLoader(dataset_test, 74 | batch_size=BATCH_SIZE, 75 | shuffle=False) 76 | 77 | 78 | # Set up xarray for saving: we will base this on the truth samples saved so all arrays have same 79 | # format, with same dimensions for time, lat and lon 80 | truth_filename = f"../output/Truth/samples_{year_start}-{year_end}.nc" 81 | truth_ds = xr.open_dataset(truth_filename, engine="netcdf4") 82 | print(truth_ds) 83 | 84 | # Create new arrays for saving predictions 85 | var_2T = xr.zeros_like(truth_ds.VAR_2T) 86 | var_10U = xr.zeros_like(truth_ds.VAR_10U) 87 | var_10V = xr.zeros_like(truth_ds.VAR_10V) 88 | 89 | t = 0 # time index 90 | for test_batch in dataloader: 91 | # Run model 92 | coarse, fine, predicted = sample_function(test_batch, model, 93 | device, dataset_test, 94 | num_steps=num_steps) 95 | all_pred_variables = predicted.detach().numpy() 96 | 97 | 98 | # This batch (may not be exactly = BATCH_SIZE for all batches) 99 | n_batch = all_pred_variables.shape[0] 100 | # Fill xarrays with predictions from t to t_end=t+n_batch 101 | t_end = t + n_batch 102 | var_2T[t:t_end] = all_pred_variables[:, 0] 103 | var_10U[t:t_end] = all_pred_variables[:, 1] 104 | var_10V[t:t_end] = all_pred_variables[:, 2] 105 | 106 | # Reset time index t for next iteration 107 | t = t_end 108 | 109 | 110 | # Create dataset 111 | ds_US = var_2T.to_dataset(name="VAR_2T") 112 | ds_US["VAR_10U"] = var_10U 113 | ds_US["VAR_10V"] = var_10V 114 | 115 | print(ds_US) 116 | 117 | 118 | # Save to netcdf4 119 | save_filename = f"{save_dir}/samples{rng}_{year_start}-{year_end}.nc" 120 | ds_US.to_netcdf(save_filename) 121 | print(f"Saved as {save_filename}") 122 | 123 | -------------------------------------------------------------------------------- /inference/save_test_truth.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import matplotlib.pyplot as plt 3 | import cartopy 4 | import torch 5 | from 
torch.utils.data.dataloader import DataLoader 6 | 7 | sys.path.append('../src/') 8 | from DatasetUS import * 9 | from Inference import * 10 | 11 | ## Saves US dataset over test years for quicker analysis 12 | ## Test years: 2018-2022 (year_end is exclusive) 13 | year_start = 2018 14 | year_end = 2023 15 | 16 | ## Dirs 17 | data_dir="/home/Everyone/ERA5/data/" 18 | model_dir="/home/Everyone/Model_Inference/" 19 | plot_dir="../plots/" 20 | save_dir="../output/Truth/" 21 | 22 | # Same file loading and domain selection as UpscaleDataset, but we keep things in xarray format 23 | filenames = [f"samples_{year}.nc" for year in range(year_start, year_end)] 24 | 25 | filename0 = filenames[0] 26 | path_to_file = data_dir + filename0 27 | ds = xr.open_dataset(path_to_file, engine="netcdf4") 28 | 29 | 30 | varnames = ["temp", "u-comp wind", "v-comp wind"] 31 | n_var = len(varnames) 32 | 33 | # Select domain with size 256 x 128 (W x H) 34 | ds_US = ds.sel(latitude=slice(54.5, 22.6), # latitude is ordered N to S 35 | longitude=slice(233.6, 297.5)) # longitude is ordered W to E 36 | lon = ds_US.longitude # len 256 37 | lat = ds_US.latitude # len 128 38 | 39 | # Concatenate other files 40 | for filename in filenames[1:]: 41 | path_to_file = data_dir + filename 42 | ds = xr.open_dataset(path_to_file, engine="netcdf4") 43 | # Select domain 44 | ds_US = xr.concat((ds_US, 45 | ds.sel(latitude=slice(54.5, 22.6), # latitude is ordered N to S 46 | longitude=slice(233.6, 297.5))), # longitude is ordered W to E 47 | dim="time") 48 | 49 | 50 | print(ds_US) 51 | 52 | ## Save to netcdf4 53 | save_filename = f"{save_dir}/samples_{year_start}-{year_end}.nc" 54 | ds_US.to_netcdf(save_filename) 55 | print(f"Saved as {save_filename}") 56 | -------------------------------------------------------------------------------- /src/DatasetUS.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | 4 | import numpy as np 5 | import xarray as xr 6 | import matplotlib.pyplot as plt 7 | import
cartopy.crs as ccrs 8 | 9 | 10 | class UpscaleDataset(torch.utils.data.Dataset): 11 | """ 12 | Dataset class of images with a low resolution and a high resolution counterpart 13 | over the US continent. 14 | """ 15 | 16 | def __init__(self, data_dir, 17 | in_shape=(16, 32), out_shape=(128, 256), 18 | year_start=1950, year_end=2001, 19 | normalize_rawdata_mean=torch.Tensor([2.8504e+02, 4.4536e-01, -1.1892e-01]), 20 | normalize_rawdata_std=torch.Tensor([12.7438, 3.4649, 3.742]), 21 | normalize_residual_mean=torch.Tensor([-9.4627e-05, -1.3833e-03, -1.5548e-03]), 22 | normalize_residual_std=torch.Tensor([1.6042, 1.0221, 1.0384]), 23 | constant_variables=None, 24 | constant_variables_filename="ERA5_const_sfc_variables.nc" 25 | ): 26 | """ 27 | :param data_dir: path to the dataset directory, including the trailing slash 28 | :param in_shape: (H, W) shape of the low resolution images 29 | :param out_shape: (H, W) shape of the high resolution images 30 | :param year_start: first year loaded, from the file named samples_{year_start}.nc 31 | :param year_end: end year (exclusive); the last file loaded is samples_{year_end - 1}.nc 32 | :param normalize_rawdata_mean, normalize_rawdata_std: channel-wise mean and standard deviation 33 | of the raw (T, u, v) fields, used to normalize the coarse input 34 | :param normalize_residual_mean, normalize_residual_std: channel-wise mean and standard deviation 35 | of the residual (fine - coarse), used to normalize the target 36 | :param constant_variables: optional list of time-invariant variables (e.g. ["lsm", "z"]) read from constant_variables_filename in data_dir and appended as extra input channels 37 | """ 38 | print("Opening files") 39 | self.filenames = [f"samples_{year}.nc" for year in range(year_start, year_end)] 40 | 41 | # Open first file for saving dimension info 42 | filename0 = self.filenames[0] 43 | path_to_file = data_dir + filename0 44 | ds = xr.open_dataset(path_to_file, engine="netcdf4") 45 | 46 | # Dimensions: lon, lat (global domain) 47 | self.lon_glob = ds.longitude 48 | self.lat_glob = ds.latitude 49 | self.varnames = ["temp", "u-comp wind", "v-comp wind"] 50 | self.n_var = len(self.varnames) 51 | 52 | 53 | # Select domain with size 256 x 128 (W x H) 54 | ds_US = ds.sel(latitude=slice(54.5, 22.6), # latitude is ordered N to S 55 | longitude=slice(233.6, 297.5)) # longitude
is ordered W to E 56 | self.lon = ds_US.longitude # len 256 57 | self.lat = ds_US.latitude # len 128 58 | self.nlon = self.W = len(self.lon) # Width 59 | self.nlat = self.H = len(self.lat) # Height 60 | 61 | # Concatenate other files 62 | for filename in self.filenames[1:]: 63 | path_to_file = data_dir + filename 64 | ds = xr.open_dataset(path_to_file, engine="netcdf4") 65 | # Select domain with size 256 x 128 (W x H) 66 | ds_US = xr.concat((ds_US, 67 | ds.sel(latitude=slice(54.5, 22.6), # latitude is ordered N to S 68 | longitude=slice(233.6, 297.5))), # longitude is ordered W to E 69 | dim="time") 70 | 71 | print("All files accessed. Creating tensors") 72 | self.ntime = len(ds_US.time) 73 | 74 | # Convert xarray dataarrays into torch Tensor (loads into memory) 75 | t = torch.from_numpy(ds_US.VAR_2T.to_numpy()).float() 76 | u = torch.from_numpy(ds_US.VAR_10U.to_numpy()).float() 77 | v = torch.from_numpy(ds_US.VAR_10V.to_numpy()).float() 78 | 79 | # Stack into (ntime, 3, 128, 256), creating the fine resolution image. 80 | fine = torch.stack((t, u, v), dim=1) 81 | 82 | # Transforms 83 | # Coarsen 84 | coarsen_transform = torchvision.transforms.Resize(in_shape, 85 | interpolation=torchvision.transforms.InterpolationMode.BILINEAR, 86 | antialias=True) 87 | interp_transform = torchvision.transforms.Resize(out_shape, 88 | interpolation=torchvision.transforms.InterpolationMode.BILINEAR, 89 | antialias=True) 90 | # Coarsen fine into coarse image, but interp to keep it on same grid 91 | # This will be our input into NN 92 | coarse = interp_transform(coarsen_transform(fine)) 93 | # Calculate residual = fine - coarse.
this will be our target 94 | residual = fine - coarse 95 | 96 | # Save unnormalized coarse and fine images for plotting 97 | self.coarse = coarse 98 | self.fine = fine 99 | 100 | # Normalize: use raw-data mean/std for the coarse image 101 | normalize_rawdata_transform = torchvision.transforms.Normalize(normalize_rawdata_mean, normalize_rawdata_std) 102 | coarse_norm = normalize_rawdata_transform(coarse) 103 | 104 | # Use residual mean/std for the difference between them 105 | normalize_residual_transform = torchvision.transforms.Normalize(normalize_residual_mean, normalize_residual_std) 106 | residual_norm = normalize_residual_transform(residual) 107 | 108 | # Inverse of the residual normalization, broadcast over (lat, lon) 109 | self.inverse_normalize_residual = lambda residual_norm: ((residual_norm * 110 | normalize_residual_std[:, np.newaxis, np.newaxis]) + 111 | normalize_residual_mean[:, np.newaxis, np.newaxis]) 112 | 113 | # Save 114 | self.targets = residual_norm # targets = normalized residual 115 | self.inputs = coarse_norm # inputs = normalized coarse 116 | 117 | # Define limits for plotting (mean plus/minus 2 sigma, per channel) 118 | self.vmin = normalize_rawdata_mean - 2 * normalize_rawdata_std 119 | self.vmax = normalize_rawdata_mean + 2 * normalize_rawdata_std 120 | 121 | print(self.vmin, self.vmax) 122 | 123 | 124 | 125 | # Additional channels for constant variables 126 | self.constant_variables = constant_variables 127 | if constant_variables is not None: 128 | print("Opening constant variables file (e.g.
land-sea mask, topography)") 129 | # Open file 130 | ds_const = xr.open_dataset(data_dir + constant_variables_filename, 131 | engine="netcdf4") 132 | ds_const = ds_const.sel(latitude=slice(54.5, 22.6), # latitude is ordered N to S 133 | longitude=slice(233.6, 297.5)) 134 | 135 | # Get torch tensors and concatenate 136 | self.const_var = torch.zeros((self.ntime, 137 | len(constant_variables), 138 | self.nlat, 139 | self.nlon), 140 | dtype=torch.float) 141 | 142 | for i, const_varname in enumerate(constant_variables): 143 | const_var = ds_const[const_varname] 144 | # Normalize every constant field except the binary land-sea mask, using cos(latitude) area weights 145 | if const_varname != "lsm": 146 | print(f"Normalize {const_varname}") 147 | weighted_var = const_var.weighted(np.cos(np.radians(ds_const.latitude))) 148 | mean_var = weighted_var.mean() # 2270.3596 149 | std_var = weighted_var.std() # 6149.4727 150 | print(f"Mean: {float(mean_var)}, Std: {float(std_var)}") 151 | const_var = (const_var - mean_var) / std_var 152 | self.const_var[:, i] = torch.from_numpy(const_var.to_numpy()).float() # channel i, broadcast over all timesteps 153 | self.inputs = torch.concatenate((self.inputs, self.const_var), dim=1) 154 | 155 | # Dimensions from orig to coarse 156 | lat_coarse_inds = np.arange(0, len(self.lat), 8, dtype=int) 157 | lon_coarse_inds = np.arange(0, len(self.lon), 8, dtype=int) 158 | 159 | self.lon_coarse = self.lon.isel(longitude=lon_coarse_inds) # len 32 160 | self.lat_coarse = self.lat.isel(latitude=lat_coarse_inds) # len 16 161 | 162 | # Time embeddings 163 | self.time = ds_US.time.dt # in datetime format 164 | self.year = self.time.year 165 | self.month = self.time.month 166 | self.day = self.time.day 167 | self.hour = self.time.hour 168 | # Approximate day of year (0 to 360), assuming 30-day months 169 | self.doy = ((self.month - 1.) * 30 + (self.day - 1.)) 170 | 171 | # Normalize and convert to numpy (load into mem) 172 | self.year_norm = (self.year.to_numpy() - 1940.)/100 173 | self.doy_norm = self.doy.to_numpy()/360. 174 | self.hour_norm = self.hour.to_numpy()/24.
175 | 176 | # Torch arrays and float 177 | self.year_norm = torch.from_numpy(self.year_norm).float() 178 | self.doy_norm = torch.from_numpy(self.doy_norm).float() 179 | self.hour_norm = torch.from_numpy(self.hour_norm).float() 180 | 181 | print("Dataset initialized.") 182 | 183 | def __len__(self): 184 | """ 185 | :return: length of the dataset 186 | """ 187 | return self.inputs.shape[0] 188 | 189 | def __getitem__(self, index): 190 | """ 191 | :param index: index of the dataset 192 | :return: dict with the normalized "inputs" and "targets", the unnormalized "fine" and "coarse" images, and the normalized time features "year", "doy" and "hour" 193 | """ 194 | return {"inputs": self.inputs[index], 195 | "targets": self.targets[index], 196 | "fine": self.fine[index], 197 | "coarse": self.coarse[index], 198 | "year": self.year_norm[index], 199 | "doy": self.doy_norm[index], 200 | "hour": self.hour_norm[index]} 201 | 202 | def residual_to_fine_image(self, residual, coarse_image): 203 | return coarse_image + self.inverse_normalize_residual(residual) 204 | 205 | def plot_fine(self, image_fine, ax, vmin=-2, vmax=2): 206 | plt.sca(ax) 207 | ax.coastlines() 208 | plt.pcolormesh(self.lon, self.lat, image_fine, 209 | vmin=vmin, vmax=vmax, shading='nearest') 210 | 211 | def plot_all_channels(self, X, Y): 212 | """Plots T, u, v for a single image (no batch dimension)""" 213 | fig, axs = plt.subplots(self.n_var, 2, figsize=(8, 2 * self.n_var), 214 | subplot_kw={'projection': ccrs.PlateCarree()}) 215 | for i in range(self.n_var): 216 | self.plot_fine(X[i], axs[i, 0]) 217 | plt.title(self.varnames[i] + " coarse-res") 218 | self.plot_fine(Y[i], axs[i, 1]) 219 | plt.title(self.varnames[i] + " fine-res") 220 | 221 | plt.tight_layout() 222 | return fig, axs 223 | 224 | def plot_batch(self, coarse_image, fine_image, fine_image_pred, N=3): 225 | """Plots T, u, v for N samples out of the batch, with separate 226 | columns for coarse, predicted fine and truth fine""" 227 | fig, axs = plt.subplots(self.n_var * N, 3, figsize=(8, N * 5), 228 | subplot_kw={'projection': ccrs.PlateCarree()}) 229 | 230 | for j in range(N): 231 | #
Plot batch 232 | for i in range(self.n_var): 233 | # Plot channel i of sample j in row j * n_var + i 234 | self.plot_fine(coarse_image[j, i], axs[(j * self.n_var) + i, 0], 235 | vmin=self.vmin[i], vmax=self.vmax[i]) 236 | self.plot_fine(fine_image_pred[j, i], axs[(j * self.n_var) + i, 1], 237 | vmin=self.vmin[i], vmax=self.vmax[i]) 238 | self.plot_fine(fine_image[j, i], axs[(j * self.n_var) + i, 2], 239 | vmin=self.vmin[i], vmax=self.vmax[i]) 240 | 241 | plt.tight_layout() 242 | return fig, axs 243 | -------------------------------------------------------------------------------- /src/Inference.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from Network import UNet, EDMPrecond 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | from DatasetUS import UpscaleDataset 6 | 7 | 8 | @torch.no_grad() 9 | def sample_unet(input_batch, model, device, dataset): 10 | """Deterministic U-Net prediction: fine image = coarse image + predicted residual.""" 11 | images_input = input_batch["inputs"].to(device) 12 | coarse, fine = input_batch["coarse"], input_batch["fine"] 13 | condition_params = torch.stack( 14 | (input_batch["doy"].to(device), 15 | input_batch["hour"].to(device)), dim=1) 16 | residual = model(images_input, class_labels=condition_params) 17 | predicted = dataset.residual_to_fine_image(residual.detach().cpu(), coarse) 18 | return coarse, fine, predicted 19 | 20 | 21 | @torch.no_grad() 22 | def sample_model_EDS(input_batch, model, device, dataset, num_steps=40, 23 | sigma_min=0.002, sigma_max=80, rho=7, S_churn=40, 24 | S_min=0, S_max=float('inf'), S_noise=1): 25 | """EDM stochastic sampler (Karras et al. 2022, Algorithm 2) for the normalized residual, conditioned on the inputs and on day of year and hour.""" 26 | images_input = input_batch["inputs"].to(device) 27 | coarse, fine = input_batch["coarse"], input_batch["fine"] 28 | condition_params = torch.stack( 29 | (input_batch["doy"].to(device), 30 | input_batch["hour"].to(device)), dim=1) 31 | 32 | sigma_min = max(sigma_min, model.sigma_min) 33 | sigma_max = min(sigma_max, model.sigma_max) 34 | 35 | init_noise = torch.randn((images_input.shape[0], 3, images_input.shape[2], 36 | images_input.shape[3]), 37 | dtype=torch.float64,
device=device) 38 | 39 | # Time step discretization. 40 | step_indices = torch.arange(num_steps, dtype=torch.float64, 41 | device=init_noise.device) 42 | t_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) 43 | * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho 44 | t_steps = torch.cat([model.round_sigma(t_steps), 45 | torch.zeros_like(t_steps[:1])]) # t_N = 0 46 | 47 | # Main sampling loop. 48 | x_next = init_noise.to(torch.float64) * t_steps[0] 49 | for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1 50 | x_cur = x_next 51 | 52 | # Increase noise temporarily. 53 | gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0 54 | t_hat = model.round_sigma(t_cur + gamma * t_cur) 55 | x_hat = (x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * 56 | torch.randn_like(x_cur)) 57 | 58 | # Euler step. 59 | denoised = model(x_hat, t_hat, images_input, condition_params).to( 60 | torch.float64) 61 | d_cur = (x_hat - denoised) / t_hat 62 | x_next = x_hat + (t_next - t_hat) * d_cur 63 | 64 | # Apply 2nd order correction. 
65 | if i < num_steps - 1: 66 | denoised = model(x_next, t_next, images_input, 67 | condition_params).to(torch.float64) 68 | d_prime = (x_next - denoised) / t_next 69 | x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime) 70 | 71 | predicted = dataset.residual_to_fine_image( 72 | x_next.detach().cpu(), coarse) 73 | 74 | 75 | return coarse, fine, predicted 76 | 77 | 78 | if __name__ == "__main__": 79 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 80 | diff_model = EDMPrecond((256, 128), 8, 81 | 3).to(device) 82 | diff_model.load_state_dict(torch.load("./Model_chpt/diffusion.pt")) 83 | 84 | unet_model = UNet((256, 128), 5, 3, 85 | label_dim=2, use_diffuse=False).to(device) 86 | unet_model.load_state_dict(torch.load("./Model_chpt/unet.pt")) 87 | 88 | datadir = "/home/Everyone/ERA5/data/" 89 | dataset_test = UpscaleDataset(datadir, year_start=2017, year_end=2022, 90 | constant_variables=["lsm", "z"]) 91 | 92 | # Try diffusion model 93 | coarse, fine, predicted = sample_model_EDS(dataset_test[0:4], diff_model, 94 | device, dataset_test) 95 | fig, ax = plt.subplots(1, 3, figsize=(12, 4)) 96 | ax[0].pcolormesh(coarse[0, 0]) 97 | ax[0].set_title("Coarse") 98 | ax[1].pcolormesh(fine[0, 0]) 99 | ax[1].set_title("Fine") 100 | ax[2].pcolormesh(predicted[0, 0]) 101 | ax[2].set_title("Predicted") 102 | 103 | # Try unet model 104 | coarse, fine, predicted = sample_unet(dataset_test[0:4], unet_model, 105 | device, dataset_test) 106 | fig, ax = plt.subplots(1, 3, figsize=(12, 4)) 107 | ax[0].pcolormesh(coarse[0, 0]) 108 | ax[0].set_title("Coarse") 109 | ax[1].pcolormesh(fine[0, 0]) 110 | ax[1].set_title("Fine") 111 | ax[2].pcolormesh(predicted[0, 0]) 112 | ax[2].set_title("Predicted") 113 | 114 | plt.show() -------------------------------------------------------------------------------- /src/Network.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | # This file was modified by Robbie Watt (2024) for the purpose of downscaling 9 | # climate data 10 | 11 | """Model architectures and preconditioning schemes used in the paper 12 | "Elucidating the Design Space of Diffusion-Based Generative Models".""" 13 | 14 | import numpy as np 15 | import torch 16 | from torch.nn.functional import silu 17 | 18 | #---------------------------------------------------------------------------- 19 | # Unified routine for initializing weights and biases. 20 | 21 | def weight_init(shape, mode, fan_in, fan_out): 22 | if mode == 'xavier_uniform': return np.sqrt(6 / (fan_in + fan_out)) * (torch.rand(*shape) * 2 - 1) 23 | if mode == 'xavier_normal': return np.sqrt(2 / (fan_in + fan_out)) * torch.randn(*shape) 24 | if mode == 'kaiming_uniform': return np.sqrt(3 / fan_in) * (torch.rand(*shape) * 2 - 1) 25 | if mode == 'kaiming_normal': return np.sqrt(1 / fan_in) * torch.randn(*shape) 26 | raise ValueError(f'Invalid init mode "{mode}"') 27 | 28 | #---------------------------------------------------------------------------- 29 | # Fully-connected layer. 
30 | 31 | class Linear(torch.nn.Module): 32 | def __init__(self, in_features, out_features, bias=True, init_mode='kaiming_normal', init_weight=1, init_bias=0): 33 | super().__init__() 34 | self.in_features = in_features 35 | self.out_features = out_features 36 | init_kwargs = dict(mode=init_mode, fan_in=in_features, fan_out=out_features) 37 | self.weight = torch.nn.Parameter(weight_init([out_features, in_features], **init_kwargs) * init_weight) 38 | self.bias = torch.nn.Parameter(weight_init([out_features], **init_kwargs) * init_bias) if bias else None 39 | 40 | def forward(self, x): 41 | x = x @ self.weight.to(x.dtype).t() 42 | if self.bias is not None: 43 | x = x.add_(self.bias.to(x.dtype)) 44 | return x 45 | 46 | #---------------------------------------------------------------------------- 47 | # Convolutional layer with optional up/downsampling. 48 | 49 | class Conv2d(torch.nn.Module): 50 | def __init__(self, 51 | in_channels, out_channels, kernel, bias=True, up=False, down=False, 52 | resample_filter=[1,1], fused_resample=False, init_mode='kaiming_normal', init_weight=1, init_bias=0, 53 | ): 54 | assert not (up and down) 55 | super().__init__() 56 | self.in_channels = in_channels 57 | self.out_channels = out_channels 58 | self.up = up 59 | self.down = down 60 | self.fused_resample = fused_resample 61 | init_kwargs = dict(mode=init_mode, fan_in=in_channels*kernel*kernel, fan_out=out_channels*kernel*kernel) 62 | self.weight = torch.nn.Parameter(weight_init([out_channels, in_channels, kernel, kernel], **init_kwargs) * init_weight) if kernel else None 63 | self.bias = torch.nn.Parameter(weight_init([out_channels], **init_kwargs) * init_bias) if kernel and bias else None 64 | f = torch.as_tensor(resample_filter, dtype=torch.float32) 65 | f = f.ger(f).unsqueeze(0).unsqueeze(1) / f.sum().square() 66 | self.register_buffer('resample_filter', f if up or down else None) 67 | 68 | def forward(self, x): 69 | w = self.weight.to(x.dtype) if self.weight is not None else None 
70 | b = self.bias.to(x.dtype) if self.bias is not None else None 71 | f = self.resample_filter.to(x.dtype) if self.resample_filter is not None else None 72 | w_pad = w.shape[-1] // 2 if w is not None else 0 73 | f_pad = (f.shape[-1] - 1) // 2 if f is not None else 0 74 | 75 | if self.fused_resample and self.up and w is not None: 76 | x = torch.nn.functional.conv_transpose2d(x, f.mul(4).tile([self.in_channels, 1, 1, 1]), groups=self.in_channels, stride=2, padding=max(f_pad - w_pad, 0)) 77 | x = torch.nn.functional.conv2d(x, w, padding=max(w_pad - f_pad, 0)) 78 | elif self.fused_resample and self.down and w is not None: 79 | x = torch.nn.functional.conv2d(x, w, padding=w_pad+f_pad) 80 | x = torch.nn.functional.conv2d(x, f.tile([self.out_channels, 1, 1, 1]), groups=self.out_channels, stride=2) 81 | else: 82 | if self.up: 83 | x = torch.nn.functional.conv_transpose2d(x, f.mul(4).tile([self.in_channels, 1, 1, 1]), groups=self.in_channels, stride=2, padding=f_pad) 84 | if self.down: 85 | x = torch.nn.functional.conv2d(x, f.tile([self.in_channels, 1, 1, 1]), groups=self.in_channels, stride=2, padding=f_pad) 86 | if w is not None: 87 | x = torch.nn.functional.conv2d(x, w, padding=w_pad) 88 | if b is not None: 89 | x = x.add_(b.reshape(1, -1, 1, 1)) 90 | return x 91 | 92 | #---------------------------------------------------------------------------- 93 | # Group normalization. 
94 | 95 | class GroupNorm(torch.nn.Module): 96 | def __init__(self, num_channels, num_groups=32, min_channels_per_group=4, eps=1e-5): 97 | super().__init__() 98 | self.num_groups = min(num_groups, num_channels // min_channels_per_group) 99 | self.eps = eps 100 | self.weight = torch.nn.Parameter(torch.ones(num_channels)) 101 | self.bias = torch.nn.Parameter(torch.zeros(num_channels)) 102 | 103 | def forward(self, x): 104 | x = torch.nn.functional.group_norm(x, num_groups=self.num_groups, weight=self.weight.to(x.dtype), bias=self.bias.to(x.dtype), eps=self.eps) 105 | return x 106 | 107 | #---------------------------------------------------------------------------- 108 | # Attention weight computation, i.e., softmax(Q^T * K). 109 | # Performs all computation using FP32, but uses the original datatype for 110 | # inputs/outputs/gradients to conserve memory. 111 | 112 | class AttentionOp(torch.autograd.Function): 113 | @staticmethod 114 | def forward(ctx, q, k): 115 | w = torch.einsum('ncq,nck->nqk', q.to(torch.float32), (k / np.sqrt(k.shape[1])).to(torch.float32)).softmax(dim=2).to(q.dtype) 116 | ctx.save_for_backward(q, k, w) 117 | return w 118 | 119 | @staticmethod 120 | def backward(ctx, dw): 121 | q, k, w = ctx.saved_tensors 122 | db = torch._softmax_backward_data(grad_output=dw.to(torch.float32), output=w.to(torch.float32), dim=2, input_dtype=torch.float32) 123 | dq = torch.einsum('nck,nqk->ncq', k.to(torch.float32), db).to(q.dtype) / np.sqrt(k.shape[1]) 124 | dk = torch.einsum('ncq,nqk->nck', q.to(torch.float32), db).to(k.dtype) / np.sqrt(k.shape[1]) 125 | return dq, dk 126 | 127 | #---------------------------------------------------------------------------- 128 | # Unified U-Net block with optional up/downsampling and self-attention. 129 | # Represents the union of all features employed by the DDPM++, NCSN++, and 130 | # ADM architectures. 
131 | 132 | class UNetBlock(torch.nn.Module): 133 | def __init__(self, 134 | in_channels, out_channels, emb_channels, up=False, down=False, attention=False, 135 | num_heads=None, channels_per_head=64, dropout=0, skip_scale=1, eps=1e-5, 136 | resample_filter=[1,1], resample_proj=False, adaptive_scale=True, 137 | init=dict(), init_zero=dict(init_weight=0), init_attn=None, 138 | ): 139 | super().__init__() 140 | self.in_channels = in_channels 141 | self.out_channels = out_channels 142 | self.emb_channels = emb_channels 143 | self.num_heads = 0 if not attention else num_heads if num_heads is not None else out_channels // channels_per_head 144 | self.dropout = dropout 145 | self.skip_scale = skip_scale 146 | self.adaptive_scale = adaptive_scale 147 | 148 | self.norm0 = GroupNorm(num_channels=in_channels, eps=eps) 149 | self.conv0 = Conv2d(in_channels=in_channels, out_channels=out_channels, kernel=3, up=up, down=down, resample_filter=resample_filter, **init) 150 | self.affine = Linear(in_features=emb_channels, out_features=out_channels*(2 if adaptive_scale else 1), **init) 151 | self.norm1 = GroupNorm(num_channels=out_channels, eps=eps) 152 | self.conv1 = Conv2d(in_channels=out_channels, out_channels=out_channels, kernel=3, **init_zero) 153 | 154 | self.skip = None 155 | if out_channels != in_channels or up or down: 156 | kernel = 1 if resample_proj or out_channels!= in_channels else 0 157 | self.skip = Conv2d(in_channels=in_channels, out_channels=out_channels, kernel=kernel, up=up, down=down, resample_filter=resample_filter, **init) 158 | 159 | if self.num_heads: 160 | self.norm2 = GroupNorm(num_channels=out_channels, eps=eps) 161 | self.qkv = Conv2d(in_channels=out_channels, out_channels=out_channels*3, kernel=1, **(init_attn if init_attn is not None else init)) 162 | self.proj = Conv2d(in_channels=out_channels, out_channels=out_channels, kernel=1, **init_zero) 163 | 164 | def forward(self, x, emb): 165 | orig = x 166 | x = self.conv0(silu(self.norm0(x))) 167 | 168 | 
params = self.affine(emb).unsqueeze(2).unsqueeze(3).to(x.dtype) 169 | if self.adaptive_scale: 170 | scale, shift = params.chunk(chunks=2, dim=1) 171 | x = silu(torch.addcmul(shift, self.norm1(x), scale + 1)) 172 | else: 173 | x = silu(self.norm1(x.add_(params))) 174 | 175 | x = self.conv1(torch.nn.functional.dropout(x, p=self.dropout, training=self.training)) 176 | x = x.add_(self.skip(orig) if self.skip is not None else orig) 177 | x = x * self.skip_scale 178 | 179 | if self.num_heads: 180 | q, k, v = self.qkv(self.norm2(x)).reshape(x.shape[0] * self.num_heads, x.shape[1] // self.num_heads, 3, -1).unbind(2) 181 | w = AttentionOp.apply(q, k) 182 | a = torch.einsum('nqk,nck->ncq', w, v) 183 | x = self.proj(a.reshape(*x.shape)).add_(x) 184 | x = x * self.skip_scale 185 | return x 186 | 187 | #---------------------------------------------------------------------------- 188 | # Timestep embedding used in the DDPM++ and ADM architectures. 189 | 190 | class PositionalEmbedding(torch.nn.Module): 191 | def __init__(self, num_channels, max_positions=10000, endpoint=False): 192 | super().__init__() 193 | self.num_channels = num_channels 194 | self.max_positions = max_positions 195 | self.endpoint = endpoint 196 | 197 | def forward(self, x): 198 | freqs = torch.arange(start=0, end=self.num_channels//2, dtype=torch.float32, device=x.device) 199 | freqs = freqs / (self.num_channels // 2 - (1 if self.endpoint else 0)) 200 | freqs = (1 / self.max_positions) ** freqs 201 | x = x.ger(freqs.to(x.dtype)) 202 | x = torch.cat([x.cos(), x.sin()], dim=1) 203 | return x 204 | 205 | #---------------------------------------------------------------------------- 206 | # Timestep embedding used in the NCSN++ architecture. 
207 | 208 | class FourierEmbedding(torch.nn.Module): 209 | def __init__(self, num_channels, scale=16): 210 | super().__init__() 211 | self.register_buffer('freqs', torch.randn(num_channels // 2) * scale) 212 | 213 | def forward(self, x): 214 | x = x.ger((2 * np.pi * self.freqs).to(x.dtype)) 215 | x = torch.cat([x.cos(), x.sin()], dim=1) 216 | return x 217 | 218 | #---------------------------------------------------------------------------- 219 | # Reimplementation of the ADM architecture from the paper 220 | # "Diffusion Models Beat GANs on Image Synthesis". Equivalent to the 221 | # original implementation by Dhariwal and Nichol, available at 222 | # https://github.com/openai/guided-diffusion 223 | 224 | class UNet(torch.nn.Module): 225 | def __init__(self, 226 | img_resolution, # Image resolution (2-tuple) at input/output. 227 | in_channels, # Number of color channels at input. 228 | out_channels, # Number of color channels at output. 229 | label_dim = 0, # Number of class labels, 0 = unconditional. 230 | augment_dim = 0, # Augmentation label dimensionality, 0 = no augmentation. 231 | 232 | model_channels = 128, # Base multiplier for the number of 233 | # channels. 234 | channel_mult = [1,2,3,4], # Per-resolution multipliers for the number of channels. 235 | channel_mult_emb = 4, # Multiplier for the dimensionality of the embedding vector. 236 | num_blocks = 2, # Number of residual blocks per resolution. 237 | attn_resolutions = [32,16,8], # List of resolutions with self-attention. 238 | dropout = 0.10, # Dropout probability of intermediate activations. 239 | label_dropout = 0, # Dropout probability of class labels for classifier-free guidance.
240 | use_diffuse = True # Add the noise-level embedding (True for diffusion, False for the deterministic UNet). 241 | ): 242 | super().__init__() 243 | self.label_dropout = label_dropout 244 | emb_channels = model_channels * channel_mult_emb 245 | init = dict(init_mode='kaiming_uniform', init_weight=np.sqrt(1/3), init_bias=np.sqrt(1/3)) 246 | init_zero = dict(init_mode='kaiming_uniform', init_weight=0, init_bias=0) 247 | block_kwargs = dict(emb_channels=emb_channels, channels_per_head=64, dropout=dropout, init=init, init_zero=init_zero) 248 | 249 | # Mapping. 250 | self.map_noise = PositionalEmbedding(num_channels=model_channels) if use_diffuse else None 251 | self.map_augment = Linear(in_features=augment_dim, out_features=emb_channels, bias=False, **init_zero) if augment_dim else None 252 | self.map_layer0 = Linear(in_features=model_channels, out_features=emb_channels, **init) 253 | self.map_layer1 = Linear(in_features=emb_channels, out_features=emb_channels, **init) 254 | self.map_label = Linear(in_features=label_dim, out_features=emb_channels, bias=False, init_mode='kaiming_normal', init_weight=np.sqrt(label_dim)) if label_dim else None 255 | 256 | assert len(img_resolution) == 2 257 | 258 | # Encoder. 259 | self.enc = torch.nn.ModuleDict() 260 | cout = in_channels 261 | for level, mult in enumerate(channel_mult): 262 | resx = img_resolution[0] >> level 263 | resy = img_resolution[1] >> level 264 | if level == 0: 265 | cin = cout 266 | cout = model_channels * mult 267 | self.enc[f'{resx}x{resy}_conv'] = Conv2d(in_channels=cin, out_channels=cout, kernel=3, **init) 268 | else: 269 | self.enc[f'{resx}x{resy}_down'] = UNetBlock(in_channels=cout, out_channels=cout, down=True, **block_kwargs) 270 | for idx in range(num_blocks): 271 | cin = cout 272 | cout = model_channels * mult 273 | self.enc[f'{resx}x{resy}_block{idx}'] = UNetBlock( 274 | in_channels=cin, out_channels=cout, attention=(resx in 275 | attn_resolutions), **block_kwargs) 276 | skips = [block.out_channels for block in self.enc.values()] 277 | 278 | # Decoder.
279 | self.dec = torch.nn.ModuleDict() 280 | for level, mult in reversed(list(enumerate(channel_mult))): 281 | resx = img_resolution[0] >> level 282 | resy = img_resolution[1] >> level 283 | if level == len(channel_mult) - 1: 284 | self.dec[f'{resx}x{resy}_in0'] = UNetBlock(in_channels=cout, 285 | out_channels=cout, attention=True, **block_kwargs) 286 | self.dec[f'{resx}x{resy}_in1'] = UNetBlock(in_channels=cout, 287 | out_channels=cout, **block_kwargs) 288 | else: 289 | self.dec[f'{resx}x{resy}_up'] = UNetBlock(in_channels=cout, 290 | out_channels=cout, up=True, **block_kwargs) 291 | for idx in range(num_blocks + 1): 292 | cin = cout + skips.pop() 293 | cout = model_channels * mult 294 | self.dec[f'{resx}x{resy}_block{idx}'] = UNetBlock( 295 | in_channels=cin, out_channels=cout, attention=(resx in 296 | attn_resolutions), **block_kwargs) 297 | self.out_norm = GroupNorm(num_channels=cout) 298 | self.out_conv = Conv2d(in_channels=cout, out_channels=out_channels, kernel=3, **init_zero) 299 | 300 | def forward(self, x, noise_labels=None, class_labels=None, 301 | augment_labels=None): 302 | # Mapping. 303 | emb = torch.zeros([1, self.map_layer1.in_features], device=x.device) 304 | if self.map_label is not None: 305 | tmp = class_labels 306 | if self.training and self.label_dropout: 307 | tmp = tmp * (torch.rand([x.shape[0], 1], 308 | device=x.device) >= self.label_dropout).to( 309 | tmp.dtype) 310 | emb = self.map_label(tmp) 311 | if self.map_noise is not None: 312 | emb_n = self.map_noise(noise_labels) 313 | emb_n = silu(self.map_layer0(emb_n)) 314 | emb_n = self.map_layer1(emb_n) 315 | emb = emb + emb_n 316 | if self.map_augment is not None and augment_labels is not None: 317 | emb = emb + self.map_augment(augment_labels) 318 | 319 | emb = silu(emb) 320 | 321 | # Encoder. 322 | skips = [] 323 | for block in self.enc.values(): 324 | x = block(x, emb) if isinstance(block, UNetBlock) else block(x) 325 | skips.append(x) 326 | 327 | # Decoder. 
328 | for block in self.dec.values(): 329 | if x.shape[1] != block.in_channels: 330 | x = torch.cat([x, skips.pop()], dim=1) 331 | x = block(x, emb) 332 | x = self.out_conv(silu(self.out_norm(x))) 333 | return x 334 | 335 | #---------------------------------------------------------------------------- 336 | # Improved preconditioning proposed in the paper "Elucidating the Design 337 | # Space of Diffusion-Based Generative Models" (EDM). 338 | 339 | class EDMPrecond(torch.nn.Module): 340 | def __init__(self, 341 | img_resolution, # Image resolution. 342 | in_channels, # Number of color channels. 343 | out_channels, # Number of color channels. 344 | label_dim = 0, # Number of class labels, 0 = unconditional. 345 | use_fp16 = False, # Execute the underlying model at FP16 precision? 346 | sigma_min = 0, # Minimum supported noise level. 347 | sigma_max = float('inf'), # Maximum supported noise level. 348 | sigma_data = 1.0, # Expected standard deviation of 349 | # the training data. 350 | model_type = 'UNet', # Class name of the underlying model. 351 | **model_kwargs, # Keyword arguments for the underlying model. 
352 | ): 353 | super().__init__() 354 | self.img_resolution = img_resolution 355 | self.in_channels = in_channels 356 | self.out_channels = out_channels 357 | self.label_dim = label_dim 358 | self.use_fp16 = use_fp16 359 | self.sigma_min = sigma_min 360 | self.sigma_max = sigma_max 361 | self.sigma_data = sigma_data 362 | self.model = globals()[model_type]( 363 | img_resolution=img_resolution, in_channels=in_channels, 364 | out_channels=out_channels, label_dim=label_dim, **model_kwargs) 365 | 366 | def forward(self, x, sigma, condition_img=None, class_labels=None, 367 | force_fp32=True, **model_kwargs): 368 | if condition_img is not None: 369 | in_img = torch.cat([x, condition_img], dim=1) 370 | else: 371 | in_img = x 372 | sigma = sigma.reshape(-1, 1, 1, 1) 373 | class_labels = None if self.label_dim == 0 else torch.zeros([1, self.label_dim], device=in_img.device) if class_labels is None else class_labels.to(torch.float32).reshape(-1, self.label_dim) 374 | dtype = torch.float16 if (self.use_fp16 and not force_fp32 and in_img.device.type == 'cuda') else torch.float32 375 | 376 | c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) 377 | c_out = sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2).sqrt() 378 | c_in = 1 / (self.sigma_data ** 2 + sigma ** 2).sqrt() 379 | c_noise = sigma.log() / 4 380 | 381 | F_x = self.model((c_in * in_img).to(dtype), 382 | noise_labels=c_noise.flatten(), 383 | class_labels=class_labels, **model_kwargs).to(dtype) 384 | assert F_x.dtype == dtype 385 | D_x = c_skip * x + c_out * F_x 386 | return D_x 387 | 388 | def round_sigma(self, sigma): 389 | return torch.as_tensor(sigma) 390 | 391 | #---------------------------------------------------------------------------- 392 | 393 | -------------------------------------------------------------------------------- /src/TrainDiffusion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import Network 3 | from tqdm 
import tqdm 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | from DatasetUS import UpscaleDataset 7 | from torch.utils.tensorboard import SummaryWriter 8 | 9 | 10 | # Loss class taken from EDS_Diffusion/loss.py 11 | class EDMLoss: 12 | def __init__(self, P_mean=-1.2, P_std=1.2, sigma_data=1.0): 13 | self.P_mean = P_mean 14 | self.P_std = P_std 15 | self.sigma_data = sigma_data 16 | 17 | def __call__(self, net, images, conditional_img=None, labels=None, 18 | augment_pipe=None): 19 | rnd_normal = torch.randn([images.shape[0], 1, 1, 1], device=images.device) 20 | sigma = (rnd_normal * self.P_std + self.P_mean).exp() 21 | weight = (sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data)**2 22 | y, augment_labels = augment_pipe(images) if augment_pipe is not None else (images, None) 23 | n = torch.randn_like(y) * sigma 24 | D_yn = net(y + n, sigma, conditional_img, labels, 25 | augment_labels=augment_labels) 26 | loss = weight * ((D_yn - y) ** 2) 27 | return loss 28 | 29 | 30 | def training_step(model, loss_fn, optimiser, data_loader, scaler, step, 31 | accum=4, writer=None, device="cuda"): 32 | """ 33 | Function for a single training step. 
34 | :param model: Instance of the EDMPrecond class 35 | :param loss_fn: Loss function 36 | :param optimiser: Optimiser to use 37 | :param data_loader: Data loader 38 | :param scaler: Scaler for mixed precision training 39 | :param step: Current step 40 | :param accum: Number of steps to accumulate gradients over 41 | :param writer: Tensorboard writer 42 | :param device: Device to use 43 | :return: Loss value 44 | """ 45 | 46 | model.train() 47 | with tqdm(total=len(data_loader), dynamic_ncols=True) as tq: 48 | tq.set_description(f"Train :: Epoch: {step}") 49 | 50 | epoch_losses = [] 51 | step_loss = 0 52 | for i, batch in enumerate(data_loader): 53 | tq.update(1) 54 | 55 | image_input = batch["inputs"].to(device) 56 | image_output = batch["targets"].to(device) 57 | day = batch["doy"].to(device) 58 | hour = batch["hour"].to(device) 59 | condition_params = torch.stack((day, hour), dim=1) 60 | 61 | # forward unet 62 | with torch.cuda.amp.autocast(): 63 | loss = loss_fn(net=model, images=image_output, 64 | conditional_img=image_input, 65 | labels=condition_params) 66 | loss = torch.mean(loss) 67 | 68 | # backpropagation 69 | scaler.scale(loss).backward() 70 | step_loss += loss.item() 71 | 72 | if (i + 1) % accum == 0: 73 | scaler.step(optimiser) 74 | scaler.update() 75 | optimiser.zero_grad(set_to_none=True) 76 | 77 | if writer is not None: 78 | writer.add_scalar("Loss/train", step_loss / accum, 79 | step * len(data_loader) + i) 80 | step_loss = 0 81 | 82 | epoch_losses.append(loss.item()) 83 | tq.set_postfix_str(s=f"Loss: {loss.item():.4f}") 84 | mean_loss = sum(epoch_losses) / len(epoch_losses) 85 | tq.set_postfix_str(s=f"Loss: {mean_loss:.4f}") 86 | return mean_loss 87 | 88 | 89 | @torch.no_grad() 90 | def sample_model(model, dataloader, num_steps=40, sigma_min=0.002, 91 | sigma_max=80, rho=7, S_churn=40, S_min=0, 92 | S_max=float('inf'), S_noise=1, device="cuda"): 93 | model.eval() 94 | batch = next(iter(dataloader)) 95 | images_input = batch["inputs"].to(device) 96 | coarse, fine =
batch["coarse"], batch["fine"] 97 | 98 | condition_params = torch.stack( 99 | (batch["doy"].to(device), 100 | batch["hour"].to(device)), dim=1) 101 | 102 | sigma_min = max(sigma_min, model.sigma_min) 103 | sigma_max = min(sigma_max, model.sigma_max) 104 | 105 | init_noise = torch.randn((images_input.shape[0], 3, images_input.shape[2], 106 | images_input.shape[3]), 107 | dtype=torch.float64, device=device) 108 | 109 | # Time step discretization. 110 | step_indices = torch.arange(num_steps, dtype=torch.float64, 111 | device=init_noise.device) 112 | t_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) 113 | * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho 114 | t_steps = torch.cat([model.round_sigma(t_steps), 115 | torch.zeros_like(t_steps[:1])]) # t_N = 0 116 | 117 | # Main sampling loop. 118 | x_next = init_noise.to(torch.float64) * t_steps[0] 119 | for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1 120 | x_cur = x_next 121 | 122 | # Increase noise temporarily. 123 | gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0 124 | t_hat = model.round_sigma(t_cur + gamma * t_cur) 125 | x_hat = (x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * 126 | torch.randn_like(x_cur)) 127 | 128 | # Euler step. 129 | denoised = model(x_hat, t_hat, images_input, condition_params).to( 130 | torch.float64) 131 | d_cur = (x_hat - denoised) / t_hat 132 | x_next = x_hat + (t_next - t_hat) * d_cur 133 | 134 | # Apply 2nd order correction. 
135 | if i < num_steps - 1: 136 | denoised = model(x_next, t_next, images_input, 137 | condition_params).to(torch.float64) 138 | d_prime = (x_next - denoised) / t_next 139 | x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime) 140 | 141 | predicted = dataloader.dataset.residual_to_fine_image( 142 | x_next.detach().cpu(), coarse) 143 | 144 | fig, ax = dataloader.dataset.plot_batch(coarse, fine, predicted) 145 | 146 | plt.subplots_adjust(wspace=0, hspace=0) 147 | base_error = torch.mean(torch.abs(fine - coarse)) 148 | pred_error = torch.mean(torch.abs(fine - predicted)) 149 | 150 | return (fig, ax), (base_error.item(), pred_error.item()) 151 | 152 | 153 | def main(): 154 | batch_size = 8 155 | learning_rate = 1e-4 156 | num_epochs = 10000 157 | accum = 8 158 | 159 | # Define device 160 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 161 | network = Network.EDMPrecond((256, 128), 8, 3, label_dim=2) 162 | network.to(device) 163 | 164 | # define the datasets 165 | datadir = "/home/Everyone/ERA5/data/" 166 | dataset_train = UpscaleDataset(datadir, year_start=1950, year_end=2017, 167 | constant_variables=["lsm", "z"]) 168 | 169 | dataset_test = UpscaleDataset(datadir, year_start=2017, year_end=2018, 170 | constant_variables=["lsm", "z"]) 171 | 172 | dataloader_train = torch.utils.data.DataLoader( 173 | dataset_train, batch_size=batch_size, shuffle=True, num_workers=4) 174 | dataloader_test = torch.utils.data.DataLoader( 175 | dataset_test, batch_size=batch_size, shuffle=True, num_workers=4) 176 | 177 | scaler = torch.cuda.amp.GradScaler() 178 | 179 | # define the optimiser 180 | optimiser = torch.optim.AdamW(network.parameters(), lr=learning_rate) 181 | 182 | # Define the tensorboard writer 183 | writer = SummaryWriter("./runs") 184 | 185 | # define loss function 186 | loss_fn = EDMLoss() 187 | 188 | # train the model 189 | losses = [] 190 | for step in range(0, num_epochs): 191 | epoch_loss = training_step(network, loss_fn, optimiser, 192 | 
dataloader_train, scaler, step, 193 | accum, writer) 194 | losses.append(epoch_loss) 195 | 196 | if (step + 0) % 5 == 0: 197 | (fig, ax), (base_error, pred_error) = sample_model( 198 | network, dataloader_test) 199 | fig.savefig(f"./results/{step}.png", dpi=300) 200 | plt.close(fig) 201 | 202 | writer.add_scalar("Error/base", base_error, step) 203 | writer.add_scalar("Error/pred", pred_error, step) 204 | 205 | # save the model 206 | if losses[-1] == min(losses): 207 | torch.save(network.state_dict(), f"./Model/{step}.pt") 208 | 209 | 210 | if __name__ == "__main__": 211 | main() 212 | 213 | -------------------------------------------------------------------------------- /src/TrainUnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | from DatasetUS import UpscaleDataset 4 | import matplotlib.pyplot as plt 5 | from Network import UNet 6 | from torch.utils.tensorboard import SummaryWriter 7 | from tqdm import tqdm 8 | 9 | 10 | def train_step(model, loss_fn, data_loader, optimiser, scaler, step, accum=4, 11 | writer=None, device="cuda"): 12 | """ 13 | Function for a single training step. 
14 | :param model: instance of the Unet class 15 | :param loss_fn: loss function 16 | :param data_loader: data loader 17 | :param optimiser: optimiser to use 18 | :param scaler: scaler for mixed precision training 19 | :param step: current step 20 | :param accum: number of steps to accumulate gradients over 21 | :param writer: tensorboard writer 22 | :param device: device to use 23 | :return: loss value 24 | """ 25 | 26 | model.train() 27 | 28 | with tqdm(total=len(data_loader), dynamic_ncols=True) as tq: 29 | tq.set_description(f"Train :: Epoch: {step}") 30 | 31 | epoch_losses = [] 32 | for i, batch in enumerate(data_loader): 33 | tq.update(1) 34 | 35 | image_input = batch["inputs"].to(device) 36 | image_output = batch["targets"].to(device) 37 | day = batch["doy"].to(device) 38 | hour = batch["hour"].to(device) 39 | condition_params = torch.stack((day, hour), dim=1) 40 | 41 | # forward unet 42 | with torch.cuda.amp.autocast(): 43 | model_out = model(image_input, 44 | class_labels=condition_params) 45 | loss = loss_fn(model_out, image_output) 46 | 47 | # backpropagation 48 | scaler.scale(loss).backward() 49 | 50 | if (i + 1) % accum == 0: 51 | scaler.step(optimiser) 52 | scaler.update() 53 | optimiser.zero_grad(set_to_none=True) 54 | 55 | epoch_losses.append(loss.item()) 56 | tq.set_postfix_str(s=f"Loss: {loss.item():.4f}") 57 | 58 | if writer is not None: 59 | writer.add_scalar("Loss/train", loss.item(), 60 | step * len(data_loader) + i) 61 | 62 | mean_loss = sum(epoch_losses) / len(epoch_losses) 63 | tq.set_postfix_str(s=f"Loss: {mean_loss:.4f}") 64 | 65 | return mean_loss 66 | 67 | 68 | @torch.no_grad() 69 | def sample_model(model, dataloader, device="cuda"): 70 | """ 71 | Function for sampling the model. 
72 | :param model: instance of the Unet class 73 | :param dataloader: data loader 74 | """ 75 | 76 | model.eval() 77 | 78 | # Take a single batch from the dataloader 79 | batch = next(iter(dataloader)) 80 | images_input = batch["inputs"].to(device) 81 | coarse, fine = batch["coarse"], batch["fine"] 82 | condition_params = torch.stack( 83 | (batch["doy"].to(device), 84 | batch["hour"].to(device)), dim=1) 85 | residual = model(images_input, class_labels=condition_params) 86 | 87 | predicted = dataloader.dataset.residual_to_fine_image( 88 | residual.detach().cpu(), coarse) 89 | 90 | fig, ax = dataloader.dataset.plot_batch(coarse, fine, predicted) 91 | 92 | plt.subplots_adjust(wspace=0, hspace=0) 93 | base_error = torch.mean(torch.abs(fine - coarse)) 94 | pred_error = torch.mean(torch.abs(fine - predicted)) 95 | 96 | return (fig, ax), (base_error.item(), pred_error.item()) 97 | 98 | 99 | def main(): 100 | batch_size = 8 101 | learning_rate = 3e-5 102 | num_epochs = 10000 103 | accum = 8 104 | 105 | # Define device 106 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 107 | 108 | # define the ml model 109 | unet_model = UNet((256, 128), 5, 3, label_dim=2, use_diffuse=False) 110 | unet_model.to(device) 111 | 112 | # define the datasets 113 | datadir = "/home/Everyone/ERA5/data/" 114 | dataset_train = UpscaleDataset(datadir, year_start=1950, year_end=2017, 115 | constant_variables=["lsm", "z"]) 116 | 117 | dataset_test = UpscaleDataset(datadir, year_start=2017, year_end=2018, 118 | constant_variables=["lsm", "z"]) 119 | 120 | dataloader_train = torch.utils.data.DataLoader( 121 | dataset_train, batch_size=batch_size, shuffle=True, num_workers=4) 122 | dataloader_test = torch.utils.data.DataLoader( 123 | dataset_test, batch_size=batch_size, shuffle=True, num_workers=4) 124 | 125 | scaler = torch.cuda.amp.GradScaler() 126 | 127 | # define the optimiser 128 | optimiser = torch.optim.AdamW(unet_model.parameters(), lr=learning_rate) 129 | 130 | # Define the tensorboard writer
131 | writer = SummaryWriter("./runs_unet") 132 | 133 | loss_fn = torch.nn.MSELoss() 134 | 135 | # train the model 136 | losses = [] 137 | for step in range(6, num_epochs): 138 | epoch_loss = train_step( 139 | unet_model, loss_fn, dataloader_train, optimiser, 140 | scaler, step, accum, writer) 141 | losses.append(epoch_loss) 142 | 143 | if (step + 0) % 5 == 0: 144 | (fig, ax), (base_error, pred_error) = sample_model( 145 | unet_model, dataloader_test) 146 | fig.savefig(f"./results_unet/{step}.png", dpi=300) 147 | plt.close(fig) 148 | 149 | writer.add_scalar("Error/base", base_error, step) 150 | writer.add_scalar("Error/pred", pred_error, step) 151 | 152 | # save the model 153 | if losses[-1] == min(losses): 154 | torch.save(unet_model.state_dict(), f"./Models_unet/{step}.pt") 155 | 156 | if __name__ == "__main__": 157 | main() 158 | --------------------------------------------------------------------------------
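The preconditioning in `EDMPrecond.forward` and the per-sample weight in `EDMLoss.__call__` are designed to work together. With c_out = sigma*sigma_data/sqrt(sigma^2 + sigma_data^2), the loss weight lambda(sigma) = (sigma^2 + sigma_data^2)/(sigma*sigma_data)^2 equals 1/c_out^2. So lambda*(D_x - y)^2 reduces to an unweighted squared error between the raw network output F_x and the target (y - c_skip*(y + n))/c_out.

The block below is a minimal, framework-free check of that identity. It uses plain Python floats instead of tensors. The helper names `edm_precond_coeffs` and `edm_loss_weight` are hypothetical and do not exist in the repository; they only restate the formulas above.

```python
import math


def edm_precond_coeffs(sigma: float, sigma_data: float = 1.0):
    """Hypothetical helper: (c_skip, c_out, c_in, c_noise) as written in EDMPrecond.forward."""
    denom = sigma ** 2 + sigma_data ** 2
    c_skip = sigma_data ** 2 / denom                 # weight on the noisy input x
    c_out = sigma * sigma_data / math.sqrt(denom)    # weight on the network output F_x
    c_in = 1.0 / math.sqrt(denom)                    # input scaling so the UNet sees ~unit variance
    c_noise = math.log(sigma) / 4                    # noise label passed to the positional embedding
    return c_skip, c_out, c_in, c_noise


def edm_loss_weight(sigma: float, sigma_data: float = 1.0) -> float:
    """Hypothetical helper: per-sample weight lambda(sigma) as written in EDMLoss.__call__."""
    return (sigma ** 2 + sigma_data ** 2) / (sigma * sigma_data) ** 2


# Print the coefficients at a few noise levels; weight * c_out^2 should be 1 everywhere.
for sigma in [0.002, 0.5, 1.0, 80.0]:
    c_skip, c_out, c_in, c_noise = edm_precond_coeffs(sigma)
    w = edm_loss_weight(sigma)
    print(f"sigma={sigma:<6} c_skip={c_skip:.4f} c_out={c_out:.4f} "
          f"c_in={c_in:.4f} weight*c_out^2={w * c_out ** 2:.4f}")
```

At sigma = sigma_data = 1, the skip path and the network path are weighted equally (c_skip = 0.5). As sigma approaches 0, c_skip approaches 1 and c_out approaches 0, so the denoiser approaches the identity on its noisy input.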
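The sampler in `TrainDiffusion.sample_model` follows the EDM stochastic sampler:

* Noise levels come from the rho-spaced schedule t_i = (sigma_max^(1/rho) + i/(N-1) * (sigma_min^(1/rho) - sigma_max^(1/rho)))^rho, with t_N = 0 appended.
* An optional "churn" step raises t_cur to t_hat = (1 + gamma) * t_cur.
* An Euler step then follows, plus a Heun correction on every step except the last.

With the defaults in `sample_model` (S_churn=40, num_steps=40, S_min=0, S_max=inf), gamma = min(1, sqrt(2) - 1), which is about 0.414 on every step.

Below is a minimal NumPy sketch of the same update rule. It assumes a generic callable `denoise(x, sigma)` in place of the conditioned `EDMPrecond` network. The names `karras_sigmas`, `heun_sampler` and `toy_denoiser` are hypothetical. The toy Gaussian denoiser exists only to check the integrator against a closed-form answer.

```python
import numpy as np


def karras_sigmas(num_steps=40, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    """Noise levels t_0 > ... > t_{N-1}, followed by t_N = 0 (same formula as sample_model)."""
    i = np.arange(num_steps, dtype=np.float64)
    t = (sigma_max ** (1 / rho) + i / (num_steps - 1)
         * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
    return np.concatenate([t, [0.0]])


def heun_sampler(denoise, x_init, num_steps=40, sigma_min=0.002, sigma_max=80.0,
                 rho=7.0, S_churn=0.0, S_min=0.0, S_max=float("inf"), S_noise=1.0,
                 rng=None):
    """Hypothetical EDM-style stochastic Heun sampler; denoise(x, sigma) estimates the clean x."""
    rng = np.random.default_rng() if rng is None else rng
    t_steps = karras_sigmas(num_steps, sigma_min, sigma_max, rho)
    x_next = x_init.astype(np.float64) * t_steps[0]
    for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])):
        x_cur = x_next
        # Temporarily raise the noise level from t_cur to t_hat ("churn").
        gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0.0
        t_hat = t_cur + gamma * t_cur
        x_hat = x_cur + np.sqrt(t_hat ** 2 - t_cur ** 2) * S_noise * rng.standard_normal(x_cur.shape)
        # Euler step from t_hat to t_next.
        d_cur = (x_hat - denoise(x_hat, t_hat)) / t_hat
        x_next = x_hat + (t_next - t_hat) * d_cur
        # Second-order (Heun) correction, skipped on the final step to t_N = 0.
        if i < num_steps - 1:
            d_prime = (x_next - denoise(x_next, t_next)) / t_next
            x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)
    return x_next


# Toy check: for data ~ N(0, s^2) the ideal denoiser is D(x, sigma) = x * s^2 / (s^2 + sigma^2),
# and the probability-flow ODE maps x(sigma_max) to x(sigma_max) * s / sqrt(s^2 + sigma_max^2).
s = 1.0


def toy_denoiser(x, sigma):
    return x * s ** 2 / (s ** 2 + sigma ** 2)


z = np.random.default_rng(0).standard_normal(5)
x0 = heun_sampler(toy_denoiser, z)  # deterministic path (S_churn=0)
expected = z * 80.0 * s / np.sqrt(s ** 2 + 80.0 ** 2)
print("max relative error:", np.max(np.abs(x0 - expected) / np.abs(expected)))
```

Because the toy problem is linear, the relative error is the same for every entry of `z`. That makes it a convenient check that the schedule and the Heun correction are wired correctly. It says nothing about the quality of the trained downscaling model.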