├── .gitignore
├── LICENSE
├── README.md
├── _analyze_models_nohigh.ipynb
├── _analyze_models_v1.0.ipynb
├── _analyze_outOfSample.ipynb
├── _compare_obs.ipynb
├── _compare_random_seeds.ipynb
├── _plot_temps.ipynb
├── _train_model_v4.0_loopA.ipynb
├── _train_model_v4.0_rev2.ipynb
├── _visualize_xai_v1.0.ipynb
├── custom_metrics.py
├── data_processing.py
├── experiment_settings.py
├── file_methods.py
├── initial_processing
│   ├── _process_data_step1.ipynb
│   ├── _process_data_step2.ipynb
│   ├── _process_data_step2_archived.ipynb
│   └── curl-ucar.cgd.cesm2le.atm.proc.monthly_ave.TREFHT-20220228T0728Z.sh
├── network.py
├── plots.py
└── xai.py

/.gitignore:
--------------------------------------------------------------------------------
1 | # Specific
2 | archive/
3 | standby/
4 | figures/
5 | data/
6 | docs/
7 | saved_models/
8 | saved_predictions/
9 | from_noah/
10 | model_diagnostics/
11 | zip_files/
12 | 
13 | # General
14 | .ipynb_checkpoints/
15 | *.ipynb_checkpoints/
16 | __pycache__/
17 | *.ipynb_checkpoints
18 | *.pyc
19 | *.pyo
20 | checkpoints/
21 | 
22 | # Apple stuff
23 | .DS_Store
24 | .DS_*
25 | Icon
26 | 
27 | # PyCharm
28 | .idea/
29 | 
30 | 
31 | 
32 | 

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2022 Elizabeth Barnes
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Detecting temperature targets
2 | ***
3 | Neural networks are trained on CMIP6 data to predict the number of years remaining until specific temperature targets are reached.
4 | 
5 | ## TensorFlow Code
6 | ***
7 | This code was written with Python 3.9.7, TensorFlow 2.7.0, TensorFlow Probability 0.15.0, and NumPy 1.22.2.
8 | 
9 | ### Python Environment
10 | The following Python environment was used to implement this code.
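To confirm the pins took effect, the version printout that each notebook runs at startup can also be run standalone (a minimal check; it uses only packages installed by the commands below):

```
import sys
import numpy as np
import xarray as xr
import tensorflow as tf
import tensorflow_probability as tfp

print(f"python version = {sys.version}")
print(f"numpy version = {np.__version__}")
print(f"xarray version = {xr.__version__}")
print(f"tensorflow version = {tf.__version__}")
print(f"tensorflow-probability version = {tfp.__version__}")
```

The environment itself was created with: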
11 | ```
12 | conda create --name env-noah python=3.9
13 | conda activate env-noah
14 | pip install tensorflow==2.7.0
15 | pip install tensorflow-probability==0.15.0
16 | pip install --upgrade numpy scipy pandas statsmodels matplotlib seaborn palettable progressbar2 tabulate icecream flake8 keras-tuner scikit-learn jupyterlab black isort jupyterlab_code_formatter
17 | pip install -U scikit-learn
18 | pip install silence-tensorflow tqdm
19 | conda install -c conda-forge cmocean cartopy
20 | conda install -c conda-forge xarray dask netCDF4 bottleneck
21 | conda install -c conda-forge nc-time-axis
22 | ```
23 | 
24 | ## Credits
25 | ***
26 | This work is a collaborative effort between [Dr. Noah Diffenbaugh](https://earth.stanford.edu/people/noah-diffenbaugh#gs.runods) and [Dr. Elizabeth A. Barnes](https://barnes.atmos.colostate.edu).
27 | 
28 | ### References
29 | [1] None.
30 | 
31 | ### License
32 | This project is licensed under the MIT License.
33 | 
34 | MIT © [Elizabeth A. Barnes](https://github.com/eabarnes1010)
35 | 
36 | 
37 | 
38 | 
39 | 

--------------------------------------------------------------------------------
/_analyze_outOfSample.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "markdown",
5 |    "id": "6da79c16-eb43-4664-a883-7a31f3af00da",
6 |    "metadata": {
7 |     "id": "4a650402-4774-49cb-9b72-9c8f1dd02f1d",
8 |     "tags": []
9 |    },
10 |    "source": [
11 |     "# Analyze models\n",
12 |     "##### authors: Elizabeth A. Barnes and Noah Diffenbaugh\n",
13 |     "##### date: March 20, 2022\n"
14 |    ]
15 |   },
16 |   {
17 |    "cell_type": "markdown",
18 |    "id": "7ccff821-b304-4009-8fe8-75a213b3f421",
19 |    "metadata": {
20 |     "tags": []
21 |    },
22 |    "source": [
23 |     "## Python stuff"
24 |    ]
25 |   },
26 |   {
27 |    "cell_type": "code",
28 |    "execution_count": null,
29 |    "id": "fb968382-4186-466e-a85b-b00caa5fc9be",
30 |    "metadata": {
31 |     "colab": {
32 |      "base_uri": "https://localhost:8080/"
33 |     },
34 |     "executionInfo": {
35 |      "elapsed": 17642,
36 |      "status": "ok",
37 |      "timestamp": 1646449680995,
38 |      "user": {
39 |       "displayName": "Elizabeth Barnes",
40 |       "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GiNPVVIWP6XAkP_hwu-8rAxoeeNuk2BMkX5-yuA=s64",
41 |       "userId": "07585723222468022011"
42 |      },
43 |      "user_tz": 420
44 |     },
45 |     "id": "fb968382-4186-466e-a85b-b00caa5fc9be",
46 |     "outputId": "d7964af9-2d52-4466-902d-9b85faba9a91",
47 |     "tags": []
48 |    },
49 |    "outputs": [],
50 |    "source": [
51 |     "import sys, os, copy\n",
52 |     "import importlib as imp\n",
53 |     "import xarray as xr\n",
54 |     "import numpy as np\n",
55 |     "import pandas as pd\n",
56 |     "import matplotlib.pyplot as plt\n",
57 |     "import scipy.stats as stats\n",
58 |     "import tensorflow as tf\n",
59 |     "import tensorflow_probability as tfp\n",
60 |     "\n",
61 |     "import scipy.stats as stats\n",
62 |     "import seaborn as sns\n",
63 |     "\n",
64 |     "import experiment_settings\n",
65 |     "import file_methods, plots, data_processing\n",
66 |     "\n",
67 |     "from scipy.signal import savgol_filter\n",
68 |     "\n",
69 |     "import matplotlib as mpl\n",
70 |     "mpl.rcParams[\"figure.facecolor\"] = \"white\"\n",
71 |     "mpl.rcParams[\"figure.dpi\"] = 150\n",
72 |     "savefig_dpi = 300\n",
73 |     "np.warnings.filterwarnings(\"ignore\", category=np.VisibleDeprecationWarning)"
74 |    ]
75 |   },
76 |   {
77 |    "cell_type": "code",
78 |    "execution_count": null,
79 |    "id": "29a5cee3-6f4f-4818-92e1-1351eeeb565a",
80 |    "metadata": {
81 |     "colab": {
82 |      "base_uri": "https://localhost:8080/"
83 |     },
84 |     "executionInfo": {
85 |      "elapsed": 30,
86 |      "status": "ok",
87 |      "timestamp": 1646449681009,
88 |      "user": {
89 |       "displayName": "Elizabeth Barnes",
90 |       "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GiNPVVIWP6XAkP_hwu-8rAxoeeNuk2BMkX5-yuA=s64",
91 |       "userId": "07585723222468022011"
92 |      },
93 |      "user_tz": 420
94 |     },
95 |     "id": "29a5cee3-6f4f-4818-92e1-1351eeeb565a",
96 |     "outputId": "e5f5b0ac-82b8-4147-bf44-4bc3b49466a2",
97 |     "tags": []
98 |    },
99 |    "outputs": [],
100 |    "source": [
101 |     "print(f\"python version = {sys.version}\")\n",
102 |     "print(f\"numpy version = {np.__version__}\")\n",
103 |     "print(f\"xarray version = {xr.__version__}\")\n",
104 |     "print(f\"tensorflow version = {tf.__version__}\")\n",
105 |     "print(f\"tensorflow-probability version = {tfp.__version__}\")"
106 |    ]
107 |   },
108 |   {
109 |    "cell_type": "markdown",
110 |    "id": "651315ce-eecc-4d30-8b90-c97d08936315",
111 |    "metadata": {
112 |     "tags": []
113 |    },
114 |    "source": [
115 |     "## User Choices"
116 |    ]
117 |   },
118 |   {
119 |    "cell_type": "code",
120 |    "execution_count": null,
121 |    "id": "c83a544f-ef35-417f-bec4-62225d885014",
122 |    "metadata": {},
123 |    "outputs": [],
124 |    "source": [
125 |     "EXP_NAME = 'exp20C_126'  # other options include 'exp20C_126_all7', 'exp15C_370_uniform', 'exp15C_126_uniform'\n",
126 |     "#-------------------------------------------------------\n",
127 |     "\n",
128 |     "settings = experiment_settings.get_settings(EXP_NAME)\n",
129 |     "# display(settings)\n",
130 |     "\n",
131 |     "MODEL_DIRECTORY = 'saved_models/'\n",
132 |     "PREDICTIONS_DIRECTORY = 'saved_predictions/'\n",
133 |     "DATA_DIRECTORY = 'data/'\n",
134 |     "DIAGNOSTICS_DIRECTORY = 'model_diagnostics/'\n",
135 |     "FIGURE_DIRECTORY = 'figures/'"
136 |    ]
137 |   },
138 |   {
139 |    "cell_type": "markdown",
140 |    "id": "d73e25b5-ca78-4984-a318-12e47643aaca",
141 |    "metadata": {},
142 |    "source": [
143 |     "## Get seed to show in plot\n",
144 |     "Run _compare_random_seeds.ipynb first so that the per-seed statistics for your experiments are saved in the df_random_seed.pickle file.\n",
145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": null, 150 | "id": "a63880c0-517b-40b8-9685-4c1b47e56494", 151 | "metadata": {}, 152 | "outputs": [], 153 | "source": [ 154 | "df_metrics = pd.read_pickle(PREDICTIONS_DIRECTORY + \"df_random_seed.pickle\")\n", 155 | "df = df_metrics[df_metrics[\"exp_name\"]==EXP_NAME]\n", 156 | "PLOT_SEED = df_metrics.iloc[df['loss_val'].idxmin()][\"seed\"]\n", 157 | "PLOT_SEED = 2247\n", 158 | "print(PLOT_SEED)\n", 159 | "display(df)" 160 | ] 161 | }, 162 | { 163 | "cell_type": "markdown", 164 | "id": "47e6f742-a1e9-423c-9048-450208f54ca9", 165 | "metadata": { 166 | "tags": [] 167 | }, 168 | "source": [ 169 | "## Plotting Functions" 170 | ] 171 | }, 172 | { 173 | "cell_type": "code", 174 | "execution_count": null, 175 | "id": "9d29d831-aa15-46fd-89be-07628cc0f8b2", 176 | "metadata": {}, 177 | "outputs": [], 178 | "source": [ 179 | "FS = 10\n", 180 | "\n", 181 | "### for white background...\n", 182 | "# plt.rc('text',usetex=True)\n", 183 | "plt.rc('text',usetex=False)\n", 184 | "# plt.rc('font',**{'family':'sans-serif','sans-serif':['Avant Garde']}) \n", 185 | "plt.rc('font',**{'family':'sans-serif','sans-serif':['Helvetica']}) \n", 186 | "plt.rc('savefig',facecolor='white')\n", 187 | "plt.rc('axes',facecolor='white')\n", 188 | "plt.rc('axes',labelcolor='dimgrey')\n", 189 | "plt.rc('axes',labelcolor='dimgrey')\n", 190 | "plt.rc('xtick',color='dimgrey')\n", 191 | "plt.rc('ytick',color='dimgrey')\n", 192 | "################################ \n", 193 | "################################ \n", 194 | "def adjust_spines(ax, spines):\n", 195 | " for loc, spine in ax.spines.items():\n", 196 | " if loc in spines:\n", 197 | " spine.set_position(('outward', 5))\n", 198 | " else:\n", 199 | " spine.set_color('none') \n", 200 | " if 'left' in spines:\n", 201 | " ax.yaxis.set_ticks_position('left')\n", 202 | " else:\n", 203 | " ax.yaxis.set_ticks([])\n", 204 | " if 'bottom' in spines:\n", 205 | " ax.xaxis.set_ticks_position('bottom')\n", 206 | " else:\n", 207 | " ax.xaxis.set_ticks([]) \n", 208 | "\n", 209 | "def format_spines(ax):\n", 210 | " adjust_spines(ax, ['left', 'bottom'])\n", 211 | " ax.spines['top'].set_color('none')\n", 212 | " ax.spines['right'].set_color('none')\n", 213 | " ax.spines['left'].set_color('dimgrey')\n", 214 | " ax.spines['bottom'].set_color('dimgrey')\n", 215 | " ax.spines['left'].set_linewidth(2)\n", 216 | " ax.spines['bottom'].set_linewidth(2)\n", 217 | " ax.tick_params('both',length=4,width=2,which='major',color='dimgrey')\n", 218 | "# ax.yaxis.grid(zorder=1,color='dimgrey',alpha=0.35) \n", 219 | " " 220 | ] 221 | }, 222 | { 223 | "cell_type": "code", 224 | "execution_count": null, 225 | "id": "c5dd051e-9e5e-4f8a-b0dd-de1b9285f952", 226 | "metadata": {}, 227 | "outputs": [], 228 | "source": [ 229 | "model_name_plot = EXP_NAME + '_' + str(PLOT_SEED)" 230 | ] 231 | }, 232 | { 233 | "cell_type": "code", 234 | "execution_count": null, 235 | "id": "6073c047-9c02-47ab-a8d9-27bd3580b216", 236 | "metadata": {}, 237 | "outputs": [], 238 | "source": [ 239 | "imp.reload(file_methods)\n", 240 | "imp.reload(data_processing)\n", 241 | "\n", 242 | "rng = np.random.default_rng(settings[\"rng_seed\"])\n", 243 | "settings[\"seed\"] = PLOT_SEED\n", 244 | "\n", 245 | "# get model name\n", 246 | "model_name = file_methods.get_model_name(settings)\n", 247 | "\n", 248 | "# load the model\n", 249 | "model = file_methods.load_tf_model(model_name, MODEL_DIRECTORY)\n", 250 | "\n", 251 | "settings_new = settings\n", 252 | "settings_new[\"gcmsub\"] 
= \"OOS\"\n", 253 | "settings_new[\"n_train_val_test\"] = (0, 0, 5)\n", 254 | "# get the data\n", 255 | "(x_train, \n", 256 | " x_val, \n", 257 | " x_test, \n", 258 | " y_train, \n", 259 | " y_val, \n", 260 | " y_test, \n", 261 | " onehot_train, \n", 262 | " onehot_val, \n", 263 | " onehot_test, \n", 264 | " y_yrs_train, \n", 265 | " y_yrs_val, \n", 266 | " y_yrs_test, \n", 267 | " target_years, \n", 268 | " map_shape,\n", 269 | " settings) = data_processing.get_cmip_data(DATA_DIRECTORY, settings) \n", 270 | "\n", 271 | " " 272 | ] 273 | }, 274 | { 275 | "cell_type": "code", 276 | "execution_count": null, 277 | "id": "83031aa9-9d62-468f-9dee-304416a74dd6", 278 | "metadata": {}, 279 | "outputs": [], 280 | "source": [ 281 | "filenames = file_methods.get_cmip_filenames(settings_new, verbose=0)\n", 282 | "N_GCMS = len(filenames)\n", 283 | "N_MEMBERS = settings[\"n_train_val_test\"][-1]\n", 284 | "target_list = []\n", 285 | "\n", 286 | "# loop through the models and plot\n", 287 | "clr = ('lawngreen','tab:green','chocolate')\n", 288 | "fig,axs = plt.subplots(1,2,figsize=(3*2.5,2.25))\n", 289 | "\n", 290 | "\n", 291 | "#---------------------------------------------\n", 292 | "plt.subplot(1,2,1)\n", 293 | "for imodel in np.arange(0,3):\n", 294 | " f = filenames[imodel]\n", 295 | " print(f)\n", 296 | " da = file_methods.get_netcdf_da(DATA_DIRECTORY + f)\n", 297 | " f_labels, f_years, f_target_year = data_processing.get_labels(da, settings_new,)\n", 298 | "\n", 299 | " # compute global mean\n", 300 | " global_mean = data_processing.compute_global_mean(da)\n", 301 | " baseline_mean = global_mean.sel(time=slice(str(settings[\"baseline_yr_bounds\"][0]),str(settings[\"baseline_yr_bounds\"][1]))).mean('time')\n", 302 | " global_mean_anomalies = global_mean - baseline_mean\n", 303 | " if settings[\"smooth\"] == True:\n", 304 | " mean_curve = savgol_filter(np.mean(global_mean_anomalies,axis=0), 15, 3)\n", 305 | " else:\n", 306 | " mean_curve = np.mean(global_mean_anomalies,axis=0)\n", 307 | " \n", 308 | " \n", 309 | " \n", 310 | " # plot the members\n", 311 | " plt.plot(f_years, \n", 312 | " np.swapaxes(global_mean_anomalies.to_numpy(),1,0), \n", 313 | " color='gray',\n", 314 | " linewidth=.5,\n", 315 | " alpha=.3,\n", 316 | " zorder=1,\n", 317 | " )\n", 318 | " # plot ensemble mean\n", 319 | " plt.plot(f_years, \n", 320 | " mean_curve, \n", 321 | " color=clr[imodel],\n", 322 | " linewidth=1.,\n", 323 | " alpha=1.,\n", 324 | " zorder=4,\n", 325 | " )\n", 326 | " \n", 327 | " #plot the year\n", 328 | " target_list.append(f_target_year)\n", 329 | " if(f_target_year != 2100):\n", 330 | " plt.axvline(x=f_target_year,\n", 331 | " color=clr[imodel],\n", 332 | " linewidth=1.,\n", 333 | " alpha=1.,\n", 334 | " linestyle='--',\n", 335 | " ) \n", 336 | " \n", 337 | "# plt.title('Global Mean Temperatures for SSP'+ str(settings[\"ssp\"]),fontsize=12)\n", 338 | "plt.xlabel('year',fontsize=FS)\n", 339 | "plt.ylabel('temperature anomaly',fontsize=FS)\n", 340 | "plt.xticks(np.arange(1850,2150,50),np.arange(1850,2150,50))\n", 341 | "\n", 342 | "plt.ylim(-.4,2.9)\n", 343 | "plt.axhline(y=0, color='black', linewidth=0.5)\n", 344 | "plt.axhline(y=1.1, color='gray', linewidth=1.0, linestyle='--')\n", 345 | "plt.axhline(y=1.5, color='gray', linewidth=1.0, linestyle='--')\n", 346 | "plt.axhline(y=2.0, color='gray', linewidth=1.0, linestyle='--')\n", 347 | "\n", 348 | "plt.text(1850,\n", 349 | " 2.0,\n", 350 | " str(settings[\"target_temp\"]) + \"C\\nSSP\" + settings[\"ssp\"][0] + '-' + settings[\"ssp\"][1] + '.' 
+ settings[\"ssp\"][-1],\n", 351 | " fontsize=FS,\n", 352 | " horizontalalignment=\"left\",\n", 353 | " verticalalignment=\"bottom\",\n", 354 | " color='k', \n", 355 | " weight=\"bold\",\n", 356 | " )\n", 357 | "\n", 358 | "format_spines(plt.gca())\n", 359 | "\n", 360 | "\n", 361 | "#--------------------------------------\n", 362 | "plt.subplot(1,2,2)\n", 363 | "# plot the predictions for 8 members\n", 364 | "YEARS_UNIQUE = np.unique(y_yrs_test)\n", 365 | "\n", 366 | "miroc_pred = model.predict(x_test)\n", 367 | "mu_pred = miroc_pred[:,0].reshape(N_GCMS,N_MEMBERS,len(YEARS_UNIQUE))\n", 368 | "sigma_pred = miroc_pred[:,1].reshape(N_GCMS,N_MEMBERS,len(YEARS_UNIQUE))\n", 369 | "\n", 370 | "for imodel in np.arange(0,3):\n", 371 | " print(filenames[imodel])\n", 372 | " iy = np.where(YEARS_UNIQUE==2021)[0]\n", 373 | " print(np.mean(mu_pred[imodel,:,:].swapaxes(1,0),axis=1)[iy],np.mean(sigma_pred[imodel,:,:].swapaxes(1,0),axis=1)[iy])\n", 374 | " plt.plot(YEARS_UNIQUE,mu_pred[imodel,:,:].swapaxes(1,0),color=clr[imodel],linewidth=1.,alpha=.25)\n", 375 | " plt.plot(YEARS_UNIQUE,np.mean(mu_pred[imodel,:,:].swapaxes(1,0),axis=1),\n", 376 | " color=clr[imodel],\n", 377 | " linewidth=2.,\n", 378 | " alpha=.75,\n", 379 | " # label=label1,\n", 380 | " )\n", 381 | "\n", 382 | " if(target_list[imodel] != 2100):\n", 383 | " plt.axvline(x=target_list[imodel],\n", 384 | " color=clr[imodel],\n", 385 | " linewidth=1.,\n", 386 | " alpha=1.,\n", 387 | " linestyle='--',\n", 388 | " ) \n", 389 | "\n", 390 | "plt.legend(frameon=False) \n", 391 | "# plt.ylim(-17,30) \n", 392 | "plt.ylim(-27,65) \n", 393 | "plt.xlim(2020,2100)\n", 394 | "ax = plt.gca()\n", 395 | "format_spines(ax)\n", 396 | "plt.ylabel('predicted years\\nuntil ' + str(settings[\"target_temp\"]) + 'C threshold')\n", 397 | "# plt.title(model_name_plot)\n", 398 | "\n", 399 | "plt.axhline(y=0, color='black', linewidth=0.5)\n", 400 | "\n", 401 | "plt.tight_layout()\n", 402 | "plots.savefig(FIGURE_DIRECTORY + model_name_plot + '_OOS_inference', dpi=savefig_dpi)\n", 403 | "plt.show() \n", 404 | " " 405 | ] 406 | }, 407 | { 408 | "cell_type": "code", 409 | "execution_count": null, 410 | "id": "3cac55b3-b625-4f77-bca3-fcd9aab97480", 411 | "metadata": {}, 412 | "outputs": [], 413 | "source": [ 414 | "# from scipy.signal import savgol_filter\n", 415 | "\n", 416 | "# raw_values = np.mean(global_mean,axis=0)\n", 417 | "# baseline_mean = raw_values.sel(time=slice(str(settings[\"baseline_yr_bounds\"][0]),str(settings[\"baseline_yr_bounds\"][1]))).mean('time')\n", 418 | "# iwarmer = np.where(raw_values > baseline_mean.values+settings[\"target_temp\"])[0]\n", 419 | "# target_year = raw_values[\"time\"].values[iwarmer[0]].year\n", 420 | "# plt.plot(raw_values[\"time.year\"],raw_values)\n", 421 | "\n", 422 | "# smoothed_values = np.mean(global_mean,axis=0)\n", 423 | "# smoothed_values = savgol_filter(smoothed_values, 15, 3) # window size 51, polynomial order 3\n", 424 | "\n", 425 | "# # poly = np.poly1d(np.polyfit(raw_values[\"time.year\"],smoothed_values,deg=10))\n", 426 | "# # smoothed_values = poly(raw_values[\"time.year\"])\n", 427 | "# baseline_mean = raw_values.sel(time=slice(str(settings[\"baseline_yr_bounds\"][0]),str(settings[\"baseline_yr_bounds\"][1]))).mean('time')\n", 428 | "# iwarmer = np.where(smoothed_values > baseline_mean.values+settings[\"target_temp\"])[0]\n", 429 | "# target_year = raw_values[\"time\"].values[iwarmer[0]].year\n", 430 | "# plt.plot(raw_values[\"time.year\"],smoothed_values)\n", 431 | "\n", 432 | "# print(target_year)" 433 | ] 434 | 
}, 435 | { 436 | "cell_type": "code", 437 | "execution_count": null, 438 | "id": "d01387eb-b1ea-4480-b552-12b03d50534a", 439 | "metadata": {}, 440 | "outputs": [], 441 | "source": [ 442 | "# filenames = file_methods.get_cmip_filenames(settings_new, verbose=0)\n", 443 | "# N_GCMS = len(filenames)\n", 444 | "# N_MEMBERS = settings[\"n_train_val_test\"][-1]\n", 445 | "# target_list = []\n", 446 | "\n", 447 | "# # loop through the models and plot\n", 448 | "# clr = ('fuchsia','tab:green','tab:blue','gold','tab:purple','tab:orange','k')\n", 449 | "# fig,axs = plt.subplots(1,2,figsize=(3*2.5,2.25))\n", 450 | "\n", 451 | "\n", 452 | "# #---------------------------------------------\n", 453 | "# plt.subplot(1,2,1)\n", 454 | "# for imodel in np.arange(0,3):\n", 455 | "# f = filenames[imodel]\n", 456 | "# print(f)\n", 457 | "# da = file_methods.get_netcdf_da(DATA_DIRECTORY + f)\n", 458 | "# f_labels, f_years, f_target_year = data_processing.get_labels(da, settings_new,)\n", 459 | "\n", 460 | "# # compute global mean\n", 461 | "# global_mean = data_processing.compute_global_mean(da)\n", 462 | "# baseline_mean = global_mean.sel(time=slice(str(settings[\"baseline_yr_bounds\"][0]),str(settings[\"baseline_yr_bounds\"][1]))).mean('time')\n", 463 | "# global_mean_anomalies = global_mean - baseline_mean\n", 464 | " \n", 465 | "# # plot the members\n", 466 | "# plt.plot(f_years, \n", 467 | "# np.swapaxes(global_mean_anomalies.to_numpy(),1,0), \n", 468 | "# color='gray',\n", 469 | "# linewidth=.5,\n", 470 | "# alpha=.3,\n", 471 | "# zorder=1,\n", 472 | "# )\n", 473 | "# # plot ensemble mean\n", 474 | "# print('max temp = ' + str(np.round(np.max(np.mean(global_mean_anomalies,axis=0)).values,2)))\n", 475 | "# print(np.round((np.mean(global_mean_anomalies,axis=0)).values,2))\n", 476 | "# year_list = da[\"time.year\"].values\n", 477 | "# print('argmax = ' + str(year_list[np.argmax(np.mean(global_mean_anomalies,axis=0).values)]))\n", 478 | "# plt.plot(f_years, \n", 479 | "# np.mean(global_mean_anomalies,axis=0), \n", 480 | "# color=clr[imodel],\n", 481 | "# linewidth=1.,\n", 482 | "# alpha=1.,\n", 483 | "# zorder=4,\n", 484 | "# )\n", 485 | " \n", 486 | "# #plot the year\n", 487 | "# target_list.append(f_target_year)\n", 488 | "# if(f_target_year != 2100):\n", 489 | "# plt.axvline(x=f_target_year,\n", 490 | "# color=clr[imodel],\n", 491 | "# linewidth=1.,\n", 492 | "# alpha=1.,\n", 493 | "# linestyle='--',\n", 494 | "# ) \n", 495 | " \n", 496 | "# # plt.title('Global Mean Temperatures for SSP'+ str(settings[\"ssp\"]),fontsize=12)\n", 497 | "# plt.xlabel('year',fontsize=FS)\n", 498 | "# plt.ylabel('temperature anomaly',fontsize=FS)\n", 499 | "# plt.xticks(np.arange(1850,2150,50),np.arange(1850,2150,50))\n", 500 | "\n", 501 | "# plt.ylim(-.4,2.9)\n", 502 | "# plt.axhline(y=0, color='black', linewidth=0.5)\n", 503 | "# plt.axhline(y=1.1, color='gray', linewidth=1.0, linestyle='--')\n", 504 | "# plt.axhline(y=1.5, color='gray', linewidth=1.0, linestyle='--')\n", 505 | "# plt.axhline(y=2.0, color='gray', linewidth=1.0, linestyle='--')\n", 506 | "\n", 507 | "# plt.text(1850,\n", 508 | "# 2.0,\n", 509 | "# str(settings[\"target_temp\"]) + \"C\\nSSP\" + settings[\"ssp\"][0] + '-' + settings[\"ssp\"][1] + '.' 
+ settings[\"ssp\"][-1],\n", 510 | "# fontsize=FS,\n", 511 | "# horizontalalignment=\"left\",\n", 512 | "# verticalalignment=\"bottom\",\n", 513 | "# color='k', \n", 514 | "# weight=\"bold\",\n", 515 | "# )\n", 516 | "\n", 517 | "# format_spines(plt.gca())\n", 518 | "\n", 519 | "\n", 520 | "# #--------------------------------------\n", 521 | "# plt.subplot(1,2,2)\n", 522 | "# # plot the predictions for 8 members\n", 523 | "# YEARS_UNIQUE = np.unique(y_yrs_test)\n", 524 | "\n", 525 | "# miroc_pred = model.predict(x_test)\n", 526 | "# mu_pred = miroc_pred[:,0].reshape(N_GCMS,N_MEMBERS,len(YEARS_UNIQUE))\n", 527 | "# sigma_pred = miroc_pred[:,1].reshape(N_GCMS,N_MEMBERS,len(YEARS_UNIQUE))\n", 528 | "\n", 529 | "# for imodel in np.arange(0,3):\n", 530 | "# print(filenames[imodel])\n", 531 | "# # plt.plot(YEARS_UNIQUE,mu_pred[imodel,:,:].swapaxes(1,0),color=clr[imodel],linewidth=1.,alpha=.25)\n", 532 | "# plt.errorbar(YEARS_UNIQUE,np.mean(mu_pred[imodel,:,:].swapaxes(1,0),axis=1),yerr=np.mean(sigma_pred[imodel,:,:].swapaxes(1,0),axis=1),\n", 533 | "# color=clr[imodel],\n", 534 | "# linewidth=2.,\n", 535 | "# alpha=.75,\n", 536 | "# # label=label1,\n", 537 | "# )\n", 538 | "\n", 539 | "# if(target_list[imodel] != 2100):\n", 540 | "# plt.axvline(x=target_list[imodel],\n", 541 | "# color=clr[imodel],\n", 542 | "# linewidth=1.,\n", 543 | "# alpha=1.,\n", 544 | "# linestyle='--',\n", 545 | "# ) \n", 546 | "\n", 547 | "# plt.legend(frameon=False) \n", 548 | "# # plt.ylim(-17,30) \n", 549 | "# plt.ylim(-27,65) \n", 550 | "# plt.xlim(2020,2100)\n", 551 | "# ax = plt.gca()\n", 552 | "# format_spines(ax)\n", 553 | "# plt.ylabel('predicted years\\nuntil ' + str(settings[\"target_temp\"]) + 'C threshold')\n", 554 | "# # plt.title(model_name_plot)\n", 555 | "\n", 556 | "# plt.axhline(y=0, color='black', linewidth=0.5)\n", 557 | "\n", 558 | "# plt.tight_layout()\n", 559 | "# # plots.savefig(FIGURE_DIRECTORY + model_name_plot + '_OOS_inference', dpi=savefig_dpi)\n", 560 | "# plt.show() \n", 561 | " " 562 | ] 563 | }, 564 | { 565 | "cell_type": "code", 566 | "execution_count": null, 567 | "id": "507a3fe3-cfa1-4193-84d2-0319d40be334", 568 | "metadata": {}, 569 | "outputs": [], 570 | "source": [ 571 | "year_plot = 2021\n", 572 | "i = np.where(YEARS_UNIQUE==year_plot)[0]\n", 573 | "YEARS_UNIQUE[i]\n", 574 | "np.mean(mu_pred[:,:,i],axis=1)+year_plot" 575 | ] 576 | }, 577 | { 578 | "cell_type": "code", 579 | "execution_count": null, 580 | "id": "3c4d1d6e-7ad6-4d52-875a-fa18776eda00", 581 | "metadata": {}, 582 | "outputs": [], 583 | "source": [ 584 | "# load BEST observations for diagnostics plotting\n", 585 | "da_obs, x_obs, global_mean_obs = data_processing.get_observations(DATA_DIRECTORY, settings)\n", 586 | "\n", 587 | "# load GISS observations\n", 588 | "settings[\"obsdata\"] = 'GISS'\n", 589 | "da_obs_giss, x_obs_giss, global_mean_obs_giss = data_processing.get_observations(DATA_DIRECTORY, settings)\n", 590 | "\n", 591 | "\n", 592 | "pred_obs = model.predict(x_obs)" 593 | ] 594 | }, 595 | { 596 | "cell_type": "markdown", 597 | "id": "79e3520f-b456-4f07-9efd-32cfdf7ef71a", 598 | "metadata": { 599 | "tags": [] 600 | }, 601 | "source": [ 602 | "#### " 603 | ] 604 | }, 605 | { 606 | "cell_type": "code", 607 | "execution_count": null, 608 | "id": "3f58990f-89cb-4864-976f-ff1d9a2e2c17", 609 | "metadata": {}, 610 | "outputs": [], 611 | "source": [ 612 | "PLOT_YEAR = 2021\n", 613 | "iyear = np.where(YEARS_UNIQUE==PLOT_YEAR)[0]\n", 614 | "norm_incs = np.arange(-80,80,1)\n", 615 | "\n", 616 | 
"plt.figure(figsize=(3,2))\n", 617 | "norm_dist = tfp.distributions.Normal(pred_obs[-1,0],pred_obs[-1,1])\n", 618 | "norm_cpd = norm_dist.prob(norm_incs)\n", 619 | "plt.plot(norm_incs+PLOT_YEAR,\n", 620 | " norm_cpd,\n", 621 | " color='tab:orange',\n", 622 | " linewidth=2.,\n", 623 | " alpha=1.,\n", 624 | " zorder=50,\n", 625 | " )\n", 626 | "\n", 627 | "for imodel in np.arange(0,N_GCMS):\n", 628 | " norm_cpd_mean = np.zeros(len(norm_incs))\n", 629 | " \n", 630 | " for ens in np.arange(0,mu_pred.shape[1]):\n", 631 | " norm_dist = tfp.distributions.Normal(mu_pred[imodel,ens,iyear],sigma_pred[imodel,ens,iyear])\n", 632 | " norm_cpd = norm_dist.prob(norm_incs)\n", 633 | " plt.plot(norm_incs+PLOT_YEAR,\n", 634 | " norm_cpd,\n", 635 | " color=clr[imodel],\n", 636 | " linewidth=.75,\n", 637 | " alpha=.25,\n", 638 | " )\n", 639 | " \n", 640 | "plt.text(2021,\n", 641 | " .15,\n", 642 | " str(settings[\"target_temp\"]) + \"C\\nSSP\" + settings[\"ssp\"][0] + '-' + settings[\"ssp\"][1] + '.' + settings[\"ssp\"][-1],\n", 643 | " fontsize=FS,\n", 644 | " horizontalalignment=\"left\",\n", 645 | " verticalalignment=\"top\",\n", 646 | " color='k', \n", 647 | " # weight=\"bold\",\n", 648 | " ) \n", 649 | "\n", 650 | "plt.xlim(2020,2100)\n", 651 | "plt.yticks(np.arange(0,.25,.02),np.arange(0,.25,.02).round(2))\n", 652 | "plt.ylim(-0.001,.15)\n", 653 | "format_spines(plt.gca())\n", 654 | " " 655 | ] 656 | }, 657 | { 658 | "cell_type": "code", 659 | "execution_count": null, 660 | "id": "63ff1ffb-6806-43b5-a7c9-2f3d07391a32", 661 | "metadata": {}, 662 | "outputs": [], 663 | "source": [ 664 | "PLOT_YEAR = 2021\n", 665 | "iyear = np.where(YEARS_UNIQUE==PLOT_YEAR)[0]\n", 666 | "norm_incs = np.arange(-80,80,1)\n", 667 | "\n", 668 | "plt.figure(figsize=(3,2))\n", 669 | "norm_dist = tfp.distributions.Normal(pred_obs[-1,0],pred_obs[-1,1])\n", 670 | "norm_cpd = norm_dist.prob(norm_incs)\n", 671 | "plt.plot(norm_incs+PLOT_YEAR,\n", 672 | " norm_cpd,\n", 673 | " color='tab:orange',\n", 674 | " linewidth=2.,\n", 675 | " alpha=1.,\n", 676 | " zorder=100,\n", 677 | " )\n", 678 | "\n", 679 | "for imodel in (0,1):\n", 680 | " ens = 7\n", 681 | " norm_dist = tfp.distributions.Normal(mu_pred[imodel,ens,iyear],sigma_pred[imodel,ens,iyear])\n", 682 | " norm_cpd = norm_dist.prob(norm_incs)\n", 683 | " plt.plot(norm_incs+PLOT_YEAR,\n", 684 | " norm_cpd,\n", 685 | " color=clr[imodel],\n", 686 | " linewidth=2.,\n", 687 | " alpha=1.,\n", 688 | " zorder=100,\n", 689 | " )\n", 690 | " \n", 691 | "plt.text(2021,\n", 692 | " .15,\n", 693 | " str(settings[\"target_temp\"]) + \"C\\nSSP\" + settings[\"ssp\"][0] + '-' + settings[\"ssp\"][1] + '.' 
+ settings[\"ssp\"][-1],\n", 694 | " fontsize=FS,\n", 695 | " horizontalalignment=\"left\",\n", 696 | " verticalalignment=\"top\",\n", 697 | " color='k', \n", 698 | " # weight=\"bold\",\n", 699 | " ) \n", 700 | "\n", 701 | "plt.xlim(2020,2100)\n", 702 | "plt.yticks(np.arange(0,.25,.05),np.arange(0,.25,.05).round(2))\n", 703 | "plt.ylim(-0.001,.15)\n", 704 | "format_spines(plt.gca())\n", 705 | " " 706 | ] 707 | }, 708 | { 709 | "cell_type": "code", 710 | "execution_count": null, 711 | "id": "9526934d-0e6f-4f73-9377-263d9193ee2b", 712 | "metadata": {}, 713 | "outputs": [], 714 | "source": [] 715 | } 716 | ], 717 | "metadata": { 718 | "colab": { 719 | "collapsed_sections": [], 720 | "name": "_main.ipynb", 721 | "provenance": [] 722 | }, 723 | "kernelspec": { 724 | "display_name": "Python 3 (ipykernel)", 725 | "language": "python", 726 | "name": "python3" 727 | }, 728 | "language_info": { 729 | "codemirror_mode": { 730 | "name": "ipython", 731 | "version": 3 732 | }, 733 | "file_extension": ".py", 734 | "mimetype": "text/x-python", 735 | "name": "python", 736 | "nbconvert_exporter": "python", 737 | "pygments_lexer": "ipython3", 738 | "version": "3.9.7" 739 | } 740 | }, 741 | "nbformat": 4, 742 | "nbformat_minor": 5 743 | } 744 | -------------------------------------------------------------------------------- /_compare_random_seeds.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "6da79c16-eb43-4664-a883-7a31f3af00da", 6 | "metadata": { 7 | "id": "4a650402-4774-49cb-9b72-9c8f1dd02f1d", 8 | "pycharm": { 9 | "name": "#%% md\n" 10 | }, 11 | "tags": [] 12 | }, 13 | "source": [ 14 | "# Compare across random seeds\n", 15 | "##### authors: Elizabeth A. Barnes and Noah Diffenbaugh\n", 16 | "##### date: March 25, 2022\n" 17 | ] 18 | }, 19 | { 20 | "cell_type": "markdown", 21 | "id": "7ccff821-b304-4009-8fe8-75a213b3f421", 22 | "metadata": { 23 | "pycharm": { 24 | "name": "#%% md\n" 25 | }, 26 | "tags": [] 27 | }, 28 | "source": [ 29 | "## Python stuff" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": null, 35 | "id": "b9b3c60f", 36 | "metadata": { 37 | "pycharm": { 38 | "name": "#%%\n" 39 | } 40 | }, 41 | "outputs": [], 42 | "source": [ 43 | "%%javascript\n", 44 | "require(\n", 45 | " [\"notebook/js/outputarea\"],\n", 46 | " function (oa) {\n", 47 | " oa.OutputArea.auto_scroll_threshold = -1;\n", 48 | " console.log(\"Setting auto_scroll_threshold to -1\");\n", 49 | " });\n", 50 | "\n", 51 | "%%javascript\n", 52 | "require(\"notebook/js/notebook\").Notebook.prototype.scroll_to_bottom = function () {}" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": null, 58 | "id": "fb968382-4186-466e-a85b-b00caa5fc9be", 59 | "metadata": { 60 | "colab": { 61 | "base_uri": "https://localhost:8080/" 62 | }, 63 | "executionInfo": { 64 | "elapsed": 17642, 65 | "status": "ok", 66 | "timestamp": 1646449680995, 67 | "user": { 68 | "displayName": "Elizabeth Barnes", 69 | "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GiNPVVIWP6XAkP_hwu-8rAxoeeNuk2BMkX5-yuA=s64", 70 | "userId": "07585723222468022011" 71 | }, 72 | "user_tz": 420 73 | }, 74 | "id": "fb968382-4186-466e-a85b-b00caa5fc9be", 75 | "outputId": "d7964af9-2d52-4466-902d-9b85faba9a91", 76 | "pycharm": { 77 | "name": "#%%\n" 78 | }, 79 | "tags": [] 80 | }, 81 | "outputs": [], 82 | "source": [ 83 | "import sys, os, copy, tqdm\n", 84 | "import importlib as imp\n", 85 | "\n", 86 | "import xarray as xr\n", 87 | "import 
numpy as np\n",
88 |     "import pandas as pd\n",
89 |     "import matplotlib.pyplot as plt\n",
90 |     "import scipy.stats as stats\n",
91 |     "import tensorflow as tf\n",
92 |     "import tensorflow_probability as tfp\n",
93 |     "from silence_tensorflow import silence_tensorflow\n",
94 |     "silence_tensorflow()  # mute TensorFlow info/warning logging\n",
95 |     "\n",
96 |     "import scipy.stats as stats\n",
97 |     "import seaborn as sns\n",
98 |     "from tqdm import tqdm\n",
99 |     "\n",
100 |     "import experiment_settings\n",
101 |     "import file_methods, plots, data_processing, custom_metrics, network\n",
102 |     "\n",
103 |     "import matplotlib as mpl\n",
104 |     "mpl.rcParams[\"figure.facecolor\"] = \"white\"\n",
105 |     "mpl.rcParams[\"figure.dpi\"] = 150\n",
106 |     "savefig_dpi = 300\n",
107 |     "np.warnings.filterwarnings(\"ignore\", category=np.VisibleDeprecationWarning)\n",
108 |     "\n",
109 |     "import warnings\n",
110 |     "warnings.filterwarnings(\"ignore\")"
111 |    ]
112 |   },
113 |   {
114 |    "cell_type": "code",
115 |    "execution_count": null,
116 |    "id": "29a5cee3-6f4f-4818-92e1-1351eeeb565a",
117 |    "metadata": {
118 |     "colab": {
119 |      "base_uri": "https://localhost:8080/"
120 |     },
121 |     "executionInfo": {
122 |      "elapsed": 30,
123 |      "status": "ok",
124 |      "timestamp": 1646449681009,
125 |      "user": {
126 |       "displayName": "Elizabeth Barnes",
127 |       "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GiNPVVIWP6XAkP_hwu-8rAxoeeNuk2BMkX5-yuA=s64",
128 |       "userId": "07585723222468022011"
129 |      },
130 |      "user_tz": 420
131 |     },
132 |     "id": "29a5cee3-6f4f-4818-92e1-1351eeeb565a",
133 |     "outputId": "e5f5b0ac-82b8-4147-bf44-4bc3b49466a2",
134 |     "pycharm": {
135 |      "name": "#%%\n"
136 |     },
137 |     "tags": []
138 |    },
139 |    "outputs": [],
140 |    "source": [
141 |     "print(f\"python version = {sys.version}\")\n",
142 |     "print(f\"numpy version = {np.__version__}\")\n",
143 |     "print(f\"xarray version = {xr.__version__}\")\n",
144 |     "print(f\"tensorflow version = {tf.__version__}\")\n",
145 |     "print(f\"tensorflow-probability version = {tfp.__version__}\")"
146 |    ]
147 |   },
148 |   {
149 |    "cell_type": "markdown",
150 |    "id": "651315ce-eecc-4d30-8b90-c97d08936315",
151 |    "metadata": {
152 |     "pycharm": {
153 |      "name": "#%% md\n"
154 |     },
155 |     "tags": []
156 |    },
157 |    "source": [
158 |     "## User Choices"
159 |    ]
160 |   },
161 |   {
162 |    "cell_type": "code",
163 |    "execution_count": null,
164 |    "id": "c83a544f-ef35-417f-bec4-62225d885014",
165 |    "metadata": {
166 |     "pycharm": {
167 |      "name": "#%%\n"
168 |     }
169 |    },
170 |    "outputs": [],
171 |    "source": [
172 |     "EXP_NAME_VEC = (\n",
173 |     "    'exp11C_370','exp15C_370','exp20C_370','exp15C_126','exp20C_126','exp11C_126',\n",
174 |     "    'exp11C_245','exp15C_245','exp20C_245',\n",
175 |     "    'exp0','exp1','exp2','exp3','exp4','exp5', # ridge sweep w/ hiddens = [10,10]\n",
176 |     "    'exp10','exp11','exp12','exp13','exp14', # hiddens sweep w/ ridge = 10.0\n",
177 |     "    'exp20','exp21','exp22','exp23', # hiddens sweep w/ ridge = 5.0\n",
178 |     "    'exp30','exp31','exp32','exp33', # ridge sweep w/ hiddens = [25,25]\n",
179 |     "    'exp15C_370_uniform','exp20C_370_uniform','exp15C_126_uniform','exp20C_126_uniform',\n",
180 |     "    'exp20C_126_force','exp20C_126_extended','exp20C_126_max','exp20C_126_all7','exp20C_126_all7_b',\n",
181 |     "    'exp15C_126_all10',\n",
182 |     "    'exp20C_126_all7_baseAnoms',\n",
183 |     "    'exp19C_126_all7','exp19C_126_all7_smooth','exp20C_126_smooth',\n",
184 |     "    'exp15C_126_noM6','exp15C_126_test','exp15C_126_noSH',\n",
185 |     "    'exp15C_370_smooth','exp13C_126','exp13C_370',\n",
186 |     "    'exp15C_126_nohigh10','exp15C_126_nohigh7','exp15C_126_nohigh5',\n",
187 |     "
'exp15C_126_smooth_nohigh10','exp15C_126_smooth_nohigh7','exp15C_126_smooth_nohigh5',\n", 188 | ")\n", 189 | "\n", 190 | "LOOP_THROUGH_EXP = True\n", 191 | "SAVE_FILE = True\n", 192 | "LOAD_METRICS = True\n", 193 | "OVERWRITE = False\n", 194 | "\n", 195 | "#-------------------------------------------------------\n", 196 | "\n", 197 | "MODEL_DIRECTORY = 'saved_models/' \n", 198 | "PREDICTIONS_DIRECTORY = 'saved_predictions/'\n", 199 | "DATA_DIRECTORY = 'data/'\n", 200 | "DIAGNOSTICS_DIRECTORY = 'model_diagnostics/'\n", 201 | "FIGURE_DIRECTORY = 'figures/'" 202 | ] 203 | }, 204 | { 205 | "cell_type": "markdown", 206 | "id": "47e6f742-a1e9-423c-9048-450208f54ca9", 207 | "metadata": { 208 | "pycharm": { 209 | "name": "#%% md\n" 210 | }, 211 | "tags": [] 212 | }, 213 | "source": [ 214 | "## Plotting Functions" 215 | ] 216 | }, 217 | { 218 | "cell_type": "code", 219 | "execution_count": null, 220 | "id": "9d29d831-aa15-46fd-89be-07628cc0f8b2", 221 | "metadata": { 222 | "pycharm": { 223 | "name": "#%%\n" 224 | } 225 | }, 226 | "outputs": [], 227 | "source": [ 228 | "FS = 10\n", 229 | "palette=(\"tab:gray\",\"tab:purple\",\"tab:orange\",\"tab:blue\",\"tab:red\",\"tab:green\",\"tab:pink\",\"tab:brown\",\"tab:olive\")\n", 230 | "\n", 231 | "### for white background...\n", 232 | "# plt.rc('text',usetex=True)\n", 233 | "plt.rc('text',usetex=False)\n", 234 | "# plt.rc('font',**{'family':'sans-serif','sans-serif':['Avant Garde']}) \n", 235 | "plt.rc('font',**{'family':'sans-serif','sans-serif':['Helvetica']}) \n", 236 | "plt.rc('savefig',facecolor='white')\n", 237 | "plt.rc('axes',facecolor='white')\n", 238 | "plt.rc('axes',labelcolor='dimgrey')\n", 239 | "plt.rc('axes',labelcolor='dimgrey')\n", 240 | "plt.rc('xtick',color='dimgrey')\n", 241 | "plt.rc('ytick',color='dimgrey')\n", 242 | "################################ \n", 243 | "################################ \n", 244 | "def adjust_spines(ax, spines):\n", 245 | " for loc, spine in ax.spines.items():\n", 246 | " if loc in spines:\n", 247 | " spine.set_position(('outward', 5))\n", 248 | " else:\n", 249 | " spine.set_color('none') \n", 250 | " if 'left' in spines:\n", 251 | " ax.yaxis.set_ticks_position('left')\n", 252 | " else:\n", 253 | " ax.yaxis.set_ticks([])\n", 254 | " if 'bottom' in spines:\n", 255 | " ax.xaxis.set_ticks_position('bottom')\n", 256 | " else:\n", 257 | " ax.xaxis.set_ticks([]) \n", 258 | "\n", 259 | "def format_spines(ax):\n", 260 | " adjust_spines(ax, ['left', 'bottom'])\n", 261 | " ax.spines['top'].set_color('none')\n", 262 | " ax.spines['right'].set_color('none')\n", 263 | " ax.spines['left'].set_color('dimgrey')\n", 264 | " ax.spines['bottom'].set_color('dimgrey')\n", 265 | " ax.spines['left'].set_linewidth(2)\n", 266 | " ax.spines['bottom'].set_linewidth(2)\n", 267 | " ax.tick_params('both',length=4,width=2,which='major',color='dimgrey')\n", 268 | "# ax.yaxis.grid(zorder=1,color='dimgrey',alpha=0.35) \n", 269 | " " 270 | ] 271 | }, 272 | { 273 | "cell_type": "markdown", 274 | "id": "2daafc6b-3cbe-4cc3-8c6a-fcda10c75248", 275 | "metadata": { 276 | "pycharm": { 277 | "name": "#%% md\n" 278 | }, 279 | "tags": [] 280 | }, 281 | "source": [ 282 | "## Get obs" 283 | ] 284 | }, 285 | { 286 | "cell_type": "code", 287 | "execution_count": null, 288 | "id": "29638693-2c37-4773-a80b-90f74728aaf8", 289 | "metadata": { 290 | "pycharm": { 291 | "name": "#%%\n" 292 | } 293 | }, 294 | "outputs": [], 295 | "source": [ 296 | "# load observations for diagnostics plotting\n", 297 | "settings = experiment_settings.get_settings(\"exp0\")\n", 298 | 
"da_obs, x_obs, global_mean_obs = data_processing.get_observations(DATA_DIRECTORY, settings)\n", 299 | "\n", 300 | "settings[\"obsdata\"] = 'GISS'\n", 301 | "da_obs_giss, x_obs_giss, global_mean_obs_giss = data_processing.get_observations(DATA_DIRECTORY, settings)" 302 | ] 303 | }, 304 | { 305 | "cell_type": "markdown", 306 | "id": "a4ca07ef-c91c-4bdf-87f0-0fdcd20e2c45", 307 | "metadata": { 308 | "pycharm": { 309 | "name": "#%% md\n" 310 | }, 311 | "tags": [] 312 | }, 313 | "source": [ 314 | "## Analyze CMIP results across random seeds" 315 | ] 316 | }, 317 | { 318 | "cell_type": "code", 319 | "execution_count": null, 320 | "id": "88b9c407-2a94-42d1-900b-bca27a0f4802", 321 | "metadata": { 322 | "pycharm": { 323 | "name": "#%%\n" 324 | }, 325 | "tags": [] 326 | }, 327 | "outputs": [], 328 | "source": [ 329 | "df_metrics = pd.DataFrame()\n", 330 | "if LOAD_METRICS == True:\n", 331 | " df_metrics = pd.read_pickle(PREDICTIONS_DIRECTORY + \"df_random_seed.pickle\")\n", 332 | "\n", 333 | " \n", 334 | "if LOOP_THROUGH_EXP == True:\n", 335 | " for exp_name in EXP_NAME_VEC:\n", 336 | " settings = experiment_settings.get_settings(exp_name)\n", 337 | " rng = np.random.default_rng(settings[\"rng_seed\"]) \n", 338 | " print(exp_name)\n", 339 | "\n", 340 | " for iloop in np.arange(settings['n_models']):\n", 341 | " seed = rng.integers(low=1_000,high=10_000,size=1)[0]\n", 342 | " settings[\"seed\"] = int(seed)\n", 343 | " tf.random.set_seed(settings[\"seed\"])\n", 344 | " np.random.seed(settings[\"seed\"])\n", 345 | "\n", 346 | " # check if entry exists\n", 347 | " if LOAD_METRICS == True:\n", 348 | " entry = df_metrics[(df_metrics[\"exp_name\"]==settings[\"exp_name\"]) & (df_metrics[\"seed\"]==settings[\"seed\"])]\n", 349 | " if OVERWRITE == True:\n", 350 | " print('removing entry: ')\n", 351 | " display(entry)\n", 352 | " df_metrics=df_metrics.drop(index=entry.index,) \n", 353 | " elif (len(entry) > 0):\n", 354 | " continue\n", 355 | " \n", 356 | " # get model name\n", 357 | " model_name = file_methods.get_model_name(settings)\n", 358 | " if os.path.exists(MODEL_DIRECTORY + model_name + \"_model\") == False: \n", 359 | " continue\n", 360 | " model = file_methods.load_tf_model(model_name, MODEL_DIRECTORY)\n", 361 | " # get the data\n", 362 | " (x_train, \n", 363 | " x_val, \n", 364 | " x_test, \n", 365 | " y_train, \n", 366 | " y_val, \n", 367 | " y_test, \n", 368 | " onehot_train, \n", 369 | " onehot_val, \n", 370 | " onehot_test, \n", 371 | " y_yrs_train, \n", 372 | " y_yrs_val, \n", 373 | " y_yrs_test, \n", 374 | " target_years, \n", 375 | " map_shape,\n", 376 | " settings) = data_processing.get_cmip_data(DATA_DIRECTORY, settings, verbose=0)\n", 377 | "\n", 378 | " #---------------------------------------- \n", 379 | " # make predictions for observations and cmip results\n", 380 | " pred_train = model.predict(x_train)\n", 381 | " pred_val = model.predict(x_val)\n", 382 | " pred_test = model.predict(x_test) \n", 383 | " pred_obs = model.predict(x_obs)\n", 384 | " pred_obs_giss = model.predict(x_obs_giss)\n", 385 | "\n", 386 | " #---------------------------------------- \n", 387 | " # compute metrics to compare\n", 388 | " error_val = np.mean(np.abs(pred_val[:,0] - onehot_val[:,0]))\n", 389 | " error_test = np.mean(np.abs(pred_test[:,0] - onehot_test[:,0])) \n", 390 | " __, __, d_val, __ = custom_metrics.compute_pit(onehot_val, x_data=x_val, model_shash = model)\n", 391 | " __, __, d_test, __ = custom_metrics.compute_pit(onehot_test, x_data=x_test, model_shash = model) \n", 392 | " __, __, d_valtest, 
__ = custom_metrics.compute_pit(np.append(onehot_val,onehot_test,axis=0), x_data=np.append(x_val,x_test,axis=0), model_shash = model) \n", 393 | " loss_val = network.RegressLossExpSigma(onehot_val,pred_val).numpy()\n", 394 | " loss_test = network.RegressLossExpSigma(onehot_test,pred_test).numpy()\n", 395 | " \n", 396 | " d = {}\n", 397 | " d[\"exp_name\"] = settings[\"exp_name\"]\n", 398 | " d[\"seed\"] = settings[\"seed\"]\n", 399 | " d[\"hiddens\"] = str(settings[\"hiddens\"])\n", 400 | " d[\"ridge_param\"] = settings[\"ridge_param\"][0] \n", 401 | " d[\"error_val\"] = error_val\n", 402 | " d[\"error_test\"] = error_test\n", 403 | " d[\"loss_val\"] = loss_val\n", 404 | " d[\"loss_test\"] = loss_test \n", 405 | " d[\"d_val\"] = d_val\n", 406 | " d[\"d_test\"] = d_test\n", 407 | " d[\"d_valtest\"] = d_valtest\n", 408 | " d[\"best_2021_mu\"] = pred_obs[-1][0]\n", 409 | " d[\"best_2021_sigma\"] = pred_obs[-1][1] \n", 410 | " d[\"giss_2021_mu\"] = pred_obs_giss[-1][0]\n", 411 | " d[\"giss_2021_sigma\"] = pred_obs_giss[-1][1] \n", 412 | "\n", 413 | " df = pd.DataFrame(d, index=[0])\n", 414 | " df_metrics = pd.concat([df_metrics,df])\n", 415 | "\n", 416 | " # there should NOT be any duplicates \n", 417 | " df_duplicated = df_metrics.duplicated(subset=(\"exp_name\",\"seed\"))\n", 418 | " if(len(df_duplicated[df_duplicated==True]) > 0):\n", 419 | " display(df_duplicated)\n", 420 | " raise ValueError('there are duplicated entries!')\n", 421 | " df_metrics = df_metrics.drop_duplicates(ignore_index=True, keep=\"last\", subset=(\"exp_name\",\"seed\")) \n", 422 | " \n", 423 | " if SAVE_FILE:\n", 424 | " df_metrics.to_pickle(PREDICTIONS_DIRECTORY + \"df_random_seed_rev2.pickle\")\n", 425 | " \n", 426 | " display(df_metrics)" 427 | ] 428 | }, 429 | { 430 | "cell_type": "code", 431 | "execution_count": null, 432 | "id": "a3908458-5186-46e2-8758-db008096a564", 433 | "metadata": { 434 | "pycharm": { 435 | "name": "#%%\n" 436 | } 437 | }, 438 | "outputs": [], 439 | "source": [ 440 | "error('here')" 441 | ] 442 | }, 443 | { 444 | "cell_type": "markdown", 445 | "id": "a640a189-c126-4ed1-89e5-ad620b82d4c4", 446 | "metadata": { 447 | "pycharm": { 448 | "name": "#%% md\n" 449 | }, 450 | "tags": [] 451 | }, 452 | "source": [ 453 | "## Random seeds for obs" 454 | ] 455 | }, 456 | { 457 | "cell_type": "code", 458 | "execution_count": null, 459 | "id": "0001e72a-849a-46a1-bf8a-988961fd0248", 460 | "metadata": { 461 | "pycharm": { 462 | "name": "#%%\n" 463 | } 464 | }, 465 | "outputs": [], 466 | "source": [ 467 | "# PLOT ACROSS SSPs and TARGETS\n", 468 | "EXP_FOR_PLOTTING = ('exp11C_370','exp11C_245','exp11C_126','exp15C_370','exp15C_245','exp15C_126','exp20C_370','exp20C_245','exp20C_126')\n", 469 | "clr_order = [0,2,1,3,4,5,6,7,8,]\n", 470 | "x_labels = EXP_FOR_PLOTTING\n", 471 | "#------------------------------------------------------------\n", 472 | "fig, ax = plt.subplots(1,1,figsize=(6.5,2.75))\n", 473 | "\n", 474 | "for obs_type in ('best',):\n", 475 | " for iexp,exp_name in enumerate(EXP_FOR_PLOTTING):\n", 476 | " iplot = np.where(df_metrics[\"exp_name\"]==exp_name)[0]\n", 477 | "\n", 478 | " if obs_type=='giss':\n", 479 | " alpha = 0.3\n", 480 | " shift_extra = .05\n", 481 | " clr = np.array(palette)[clr_order][iexp]\n", 482 | " else:\n", 483 | " alpha = 1.0\n", 484 | " shift_extra = 0.\n", 485 | " clr = np.array(palette)[clr_order][iexp]\n", 486 | " \n", 487 | " ax.errorbar(np.ones(iplot.shape)*iexp+np.linspace(-.4,.4,len(iplot))+shift_extra,\n", 488 | " df_metrics.iloc[iplot][obs_type + 
\"_2021_mu\"]+2021,\n", 489 | " yerr=df_metrics.iloc[iplot][obs_type + \"_2021_sigma\"],\n", 490 | " color=clr,\n", 491 | " marker='.',\n", 492 | " linestyle='',\n", 493 | " elinewidth=.25,\n", 494 | " markersize=2,\n", 495 | " alpha=alpha,\n", 496 | " )\n", 497 | " \n", 498 | " # plot the text above the bars\n", 499 | " if obs_type=='best':\n", 500 | " max_y_value = np.max(df_metrics.iloc[iplot][obs_type + \"_2021_mu\"]+df_metrics.iloc[iplot][obs_type + \"_2021_sigma\"])\n", 501 | " if exp_name=='exp15C_370':\n", 502 | " add_val = 3\n", 503 | " elif exp_name=='exp20C_126' or exp_name=='exp15C_126':\n", 504 | " add_val = -1\n", 505 | " else:\n", 506 | " add_val = 2\n", 507 | " text_name = 'SSP'+exp_name[7]+'-'+exp_name[8] + '.' + exp_name[9] + '\\n' +exp_name[3] + '.' + exp_name[4]+'C'\n", 508 | "\n", 509 | " plt.text(iexp,\n", 510 | " max_y_value+add_val+2021,\n", 511 | " text_name,\n", 512 | " fontsize=FS*0.8,\n", 513 | " color=np.array(palette)[clr_order][iexp],\n", 514 | " horizontalalignment='center',\n", 515 | " )\n", 516 | "\n", 517 | "\n", 518 | "ax.set_ylabel('year threshold is reached')\n", 519 | "ax.set_title('Observations 2021 Predicted $\\mu \\pm \\sigma$')\n", 520 | "ax.set_xlabel(None)\n", 521 | "format_spines(ax)\n", 522 | "ax.set_xticks(np.arange(0,len(x_labels)),'', fontsize=FS*0.8,rotation=45)\n", 523 | "ax.set_yticks(np.arange(1950,2100,10),np.arange(1950,2100,10).round())\n", 524 | "ax.set_ylim(-10+2021,60+2021)\n", 525 | "plt.grid(which='major',axis='y',linewidth=.25,linestyle='--',alpha=.5)\n", 526 | "\n", 527 | "plt.tight_layout()\n", 528 | "# plots.savefig(FIGURE_DIRECTORY + 'obs_BEST_GISS' + '_params_ssp_target_comparison',dpi=savefig_dpi)\n", 529 | "plots.savefig(FIGURE_DIRECTORY + 'obs_BEST' + '_params_ssp_target_comparison',dpi=savefig_dpi)\n", 530 | "plt.show()\n" 531 | ] 532 | }, 533 | { 534 | "cell_type": "code", 535 | "execution_count": null, 536 | "id": "7495392a-21ed-4eb6-815e-708afd209f9d", 537 | "metadata": { 538 | "pycharm": { 539 | "name": "#%%\n" 540 | } 541 | }, 542 | "outputs": [], 543 | "source": [ 544 | "\n", 545 | "norm_incs = np.arange(-80,80,1)\n", 546 | "#------------------------------------------------------------\n", 547 | "fig, axs = plt.subplots(3,1,figsize=(4.9,2.75*3))\n", 548 | "\n", 549 | "PLOT_SEED = 2247\n", 550 | "obs_type = 'best'\n", 551 | "\n", 552 | "\n", 553 | "for thresh in (1.1, 1.5, 2.0):\n", 554 | " if thresh==1.1:\n", 555 | " EXP_FOR_PLOTTING = ('exp11C_370','exp11C_245','exp11C_126') \n", 556 | " ax = axs[0]\n", 557 | " thresh_text = '1.1C'\n", 558 | " text_x = 2040\n", 559 | " elif thresh==1.5:\n", 560 | " EXP_FOR_PLOTTING = ('exp15C_370','exp15C_245','exp15C_126') \n", 561 | " ax = axs[1]\n", 562 | " thresh_text = '1.5C' \n", 563 | " text_x = 2050 \n", 564 | " elif thresh==2.0:\n", 565 | " EXP_FOR_PLOTTING = ('exp20C_370','exp20C_245','exp20C_126') \n", 566 | " ax = axs[2] \n", 567 | " thresh_text = '2.0C' \n", 568 | " text_x = 2068 \n", 569 | " else:\n", 570 | " raise ValueError('no such threshold')\n", 571 | " for iexp,exp_name in enumerate(EXP_FOR_PLOTTING):\n", 572 | " iplot = np.where((df_metrics[\"exp_name\"]==exp_name) & (df_metrics[\"seed\"]==PLOT_SEED))[0]\n", 573 | "\n", 574 | " mu_pred = df_metrics.iloc[iplot][obs_type + \"_2021_mu\"].values[0]\n", 575 | " sigma_pred = df_metrics.iloc[iplot][obs_type + \"_2021_sigma\"].values[0]\n", 576 | " norm_dist = tfp.distributions.Normal(mu_pred,sigma_pred)\n", 577 | " norm_perc_low = norm_dist.quantile(.25).numpy() \n", 578 | " norm_perc_high = 
norm_dist.quantile(.75).numpy() \n", 579 | " norm_perc_med = norm_dist.quantile(.5).numpy() \n", 580 | " norm_cpd = norm_dist.prob(norm_incs)\n", 581 | "\n", 582 | " if(df_metrics.iloc[iplot][\"exp_name\"].values[0][-3:]=='370'):\n", 583 | " clr = \"tab:red\"\n", 584 | " # text_x = 2040\n", 585 | " text_y = .05\n", 586 | " ssp_text = 'SSP3-7.0'\n", 587 | " elif(df_metrics.iloc[iplot][\"exp_name\"].values[0][-3:]=='245'):\n", 588 | " clr = \"tab:purple\"\n", 589 | " # text_x = 2050\n", 590 | " text_y = .035 \n", 591 | " ssp_text = 'SSP2-4.5' \n", 592 | " else:\n", 593 | " clr = \"tab:blue\"\n", 594 | " # text_x = 2070\n", 595 | " text_y = .02 \n", 596 | " ssp_text = 'SSP1-2.6' \n", 597 | "\n", 598 | " ax.plot(norm_incs+2021,\n", 599 | " norm_cpd,\n", 600 | " color=clr,\n", 601 | " linewidth=2.5,\n", 602 | " )\n", 603 | "\n", 604 | " # # plot the text above the bars\n", 605 | " text_name = ssp_text + '\\n' + str(int(np.round(mu_pred+2021))) + ' (' + str(int(np.round(mu_pred+2021-sigma_pred))) + ' to ' + str(int(np.round(mu_pred+2021+sigma_pred))) + ')'\n", 606 | " ax.text(text_x,\n", 607 | " text_y,\n", 608 | " text_name,\n", 609 | " fontsize=FS*0.8,\n", 610 | " color=clr,\n", 611 | " horizontalalignment='left',\n", 612 | " )\n", 613 | " ax.text(1998,\n", 614 | " .1,\n", 615 | " thresh_text + ' threshold',\n", 616 | " fontsize=FS,\n", 617 | " horizontalalignment=\"left\",\n", 618 | " verticalalignment=\"top\",\n", 619 | " color='k', \n", 620 | " # weight=\"bold\",\n", 621 | " ) \n", 622 | "\n", 623 | " ax.set_xlabel('year')\n", 624 | " format_spines(ax)\n", 625 | " ax.set_xlim(-25+2021,70+2021)\n", 626 | " ax.set_ylim(-0.001,.1)\n", 627 | "\n", 628 | "plt.tight_layout()\n", 629 | "plots.savefig(FIGURE_DIRECTORY + 'obs_2021PDF_allSSPs',dpi=savefig_dpi)\n", 630 | "plt.show()\n" 631 | ] 632 | }, 633 | { 634 | "cell_type": "markdown", 635 | "id": "a7ca0a86-5fe4-4d86-b0fc-5dfb74b0a3fd", 636 | "metadata": { 637 | "pycharm": { 638 | "name": "#%% md\n" 639 | }, 640 | "tags": [] 641 | }, 642 | "source": [ 643 | "## Plots across hyperparameters" 644 | ] 645 | }, 646 | { 647 | "cell_type": "markdown", 648 | "id": "b0ccdbf7-e9e6-41f8-b5dd-9a4359730c53", 649 | "metadata": { 650 | "pycharm": { 651 | "name": "#%% md\n" 652 | }, 653 | "tags": [] 654 | }, 655 | "source": [ 656 | "### Plots across obs predictions" 657 | ] 658 | }, 659 | { 660 | "cell_type": "code", 661 | "execution_count": null, 662 | "id": "a7d5f973-8118-4df7-985f-3e21b53a6102", 663 | "metadata": { 664 | "pycharm": { 665 | "name": "#%%\n" 666 | } 667 | }, 668 | "outputs": [], 669 | "source": [ 670 | "# PLOT ACROSS RIDGE CHOICES\n", 671 | "EXP_FOR_PLOTTING = ('exp12','exp30','exp31','exp32','exp33')\n", 672 | "df_metrics_plot = df_metrics[df_metrics[\"exp_name\"].isin(EXP_FOR_PLOTTING)]\n", 673 | "df_metrics_plot = df_metrics_plot.sort_values(\"ridge_param\")\n", 674 | "clr_order = [0,0,0,2,0,0,0,0,]\n", 675 | "print('PARAMETER CHECK: ' + str(df_metrics_plot[\"hiddens\"].unique()))\n", 676 | "#------------------------------------------------------------\n", 677 | "fig, axs = plt.subplots(1,2,figsize=(7,2.75))\n", 678 | "\n", 679 | "ax = axs[0]\n", 680 | "sns.swarmplot(x=\"exp_name\",\n", 681 | " y=\"best_2021_mu\",\n", 682 | " data=df_metrics_plot,\n", 683 | " palette=np.array(palette)[clr_order],\n", 684 | " size=2.5,\n", 685 | " ax = ax,\n", 686 | " )\n", 687 | "ax.set_ylabel('$\\mu$')\n", 688 | "ax.set_title('Obs. 
2021 predicted $\\mu$')\n", 689 | "ax.set_ylim(0.0,25.0)\n", 690 | "format_spines(ax)\n", 691 | "ax.set_xlabel('ridge parameter',fontsize=FS)\n", 692 | "x_labels = df_metrics_plot[\"ridge_param\"].unique()\n", 693 | "ax.set_xticks(np.arange(0,len(x_labels)),x_labels, fontsize=FS*0.8,rotation=45)\n", 694 | "\n", 695 | "ax = axs[1]\n", 696 | "sns.swarmplot(x=\"exp_name\",\n", 697 | " y=\"best_2021_sigma\",\n", 698 | " data=df_metrics_plot,\n", 699 | " palette=np.array(palette)[clr_order],\n", 700 | " size=2.5,\n", 701 | " ax=ax,\n", 702 | " )\n", 703 | "ax.set_title('Obs. 2021 predicted $\\sigma$',fontsize=FS*1.2)\n", 704 | "ax.set_ylim(0,None)\n", 705 | "format_spines(ax)\n", 706 | "ax.set_ylabel('$\\sigma$')\n", 707 | "ax.set_xlabel('ridge parameter',fontsize=FS)\n", 708 | "x_labels = df_metrics_plot[\"ridge_param\"].unique()\n", 709 | "ax.set_xticks(np.arange(0,len(x_labels)),x_labels, fontsize=FS*0.8,rotation=45)\n", 710 | "\n", 711 | "\n", 712 | "plt.tight_layout()\n", 713 | "plots.savefig(FIGURE_DIRECTORY + 'obsBEST_params_ridge_comparison',dpi=savefig_dpi)\n", 714 | "plt.show()\n" 715 | ] 716 | }, 717 | { 718 | "cell_type": "code", 719 | "execution_count": null, 720 | "id": "6d684280-1d84-4257-8d08-c9a50e640ec9", 721 | "metadata": { 722 | "pycharm": { 723 | "name": "#%%\n" 724 | } 725 | }, 726 | "outputs": [], 727 | "source": [ 728 | "# PLOT ACROSS HIDDEN CHOICES\n", 729 | "\n", 730 | "EXP_FOR_PLOTTING = ('exp5','exp10','exp11','exp12','exp13','exp14')\n", 731 | "# EXP_FOR_PLOTTING = ('exp4','exp20','exp21','exp22','exp23',)\n", 732 | "df_metrics_plot = df_metrics[df_metrics[\"exp_name\"].isin(EXP_FOR_PLOTTING)]\n", 733 | "df_metrics_plot = df_metrics_plot.sort_values(\"hiddens\")\n", 734 | "x_labels = df_metrics_plot[\"hiddens\"].unique()\n", 735 | "x_labels[x_labels=='[2]'] = '[2]\\nlinear'\n", 736 | "clr_order = [0,0,0,2,0,0,0,0,]\n", 737 | "print('PARAMETER CHECK: ' + str(df_metrics_plot[\"ridge_param\"].unique()))\n", 738 | "#------------------------------------------------------------\n", 739 | "fig, axs = plt.subplots(1,2,figsize=(7,2.75))\n", 740 | "\n", 741 | "ax = axs[0]\n", 742 | "sns.swarmplot(x=\"exp_name\",\n", 743 | " y=\"best_2021_mu\",\n", 744 | " data=df_metrics_plot,\n", 745 | " palette=np.array(palette)[clr_order],\n", 746 | " size=2.5,\n", 747 | " ax = ax,\n", 748 | " )\n", 749 | "ax.set_ylabel('$\\mu$')\n", 750 | "ax.set_title('Obs. 2021 predicted $\\mu$')\n", 751 | "ax.set_ylim(0.0,20.0)\n", 752 | "format_spines(ax)\n", 753 | "ax.set_xlabel('hidden layers x nodes',fontsize=FS)\n", 754 | "ax.set_xticks(np.arange(0,len(x_labels)),x_labels, fontsize=FS*0.8,rotation=45)\n", 755 | "\n", 756 | "ax = axs[1]\n", 757 | "sns.swarmplot(x=\"exp_name\",\n", 758 | " y=\"best_2021_sigma\",\n", 759 | " data=df_metrics_plot,\n", 760 | " palette=np.array(palette)[clr_order],\n", 761 | " size=2.5,\n", 762 | " ax=ax,\n", 763 | " )\n", 764 | "ax.set_title('Obs. 
2021 predicted $\\sigma$',fontsize=FS*1.2)\n", 765 | "ax.set_ylim(2,None)\n", 766 | "format_spines(ax)\n", 767 | "ax.set_ylabel('$\\sigma$')\n", 768 | "ax.set_xlabel('hidden layers x nodes',fontsize=FS)\n", 769 | "ax.set_xticks(np.arange(0,len(x_labels)),x_labels, fontsize=FS*0.8,rotation=45)\n", 770 | "\n", 771 | "\n", 772 | "plt.tight_layout()\n", 773 | "plots.savefig(FIGURE_DIRECTORY + 'obsBEST_params_hiddens_comparison',dpi=savefig_dpi)\n", 774 | "plt.show()\n" 775 | ] 776 | }, 777 | { 778 | "cell_type": "markdown", 779 | "id": "80cd3517-f8e2-4d6c-b51a-81a83dee4a0f", 780 | "metadata": { 781 | "pycharm": { 782 | "name": "#%% md\n" 783 | } 784 | }, 785 | "source": [ 786 | "### Error and PIT" 787 | ] 788 | }, 789 | { 790 | "cell_type": "code", 791 | "execution_count": null, 792 | "id": "cb0ec9be-99ce-4f71-b473-490436078910", 793 | "metadata": { 794 | "pycharm": { 795 | "name": "#%%\n" 796 | } 797 | }, 798 | "outputs": [], 799 | "source": [ 800 | "# PLOT ACROSS RIDGE CHOICES\n", 801 | "EXP_FOR_PLOTTING = ('exp12','exp30','exp31','exp32','exp33')\n", 802 | "df_metrics_plot = df_metrics[df_metrics[\"exp_name\"].isin(EXP_FOR_PLOTTING)]\n", 803 | "df_metrics_plot = df_metrics_plot.sort_values(\"ridge_param\")\n", 804 | "clr_order = [0,0,0,2,0,0,0,0,]\n", 805 | "\n", 806 | "print('PARAMETER CHECK: ' + str(df_metrics_plot[\"hiddens\"].unique()))\n", 807 | "#------------------------------------------------------------\n", 808 | "fig, axs = plt.subplots(1,3,figsize=(8.5,2.5))\n", 809 | "\n", 810 | "ax = axs[0]\n", 811 | "sns.boxplot(x=\"exp_name\",\n", 812 | " y=\"error_val\",\n", 813 | " palette=np.array(palette)[clr_order],\n", 814 | " data=df_metrics_plot,\n", 815 | " boxprops=dict(alpha=.3, edgecolor='gray',linewidth=1.),\n", 816 | " whiskerprops=dict(color='gray',linewidth=1.),\n", 817 | " medianprops=dict(color='gray',linewidth=1.),\n", 818 | " capprops=dict(color='gray',linewidth=1.), \n", 819 | " whis=100., \n", 820 | " ax = ax,\n", 821 | " )\n", 822 | "sns.swarmplot(x=\"exp_name\",\n", 823 | " y=\"error_test\",\n", 824 | " palette=np.array(palette)[clr_order],\n", 825 | " data=df_metrics_plot,\n", 826 | " size=2.5,\n", 827 | " ax = ax,\n", 828 | " )\n", 829 | "ax.set_ylabel('error (years)')\n", 830 | "ax.set_title('Mean Absolute Error')\n", 831 | "ax.set_ylim(2.0,5.0)\n", 832 | "format_spines(ax)\n", 833 | "ax.set_xlabel('ridge parameter',fontsize=FS)\n", 834 | "x_labels = df_metrics_plot[\"ridge_param\"].unique()\n", 835 | "ax.set_xticks(np.arange(0,len(x_labels)),x_labels, fontsize=FS*0.8,rotation=45)\n", 836 | "\n", 837 | "ax = axs[1]\n", 838 | "sns.boxplot(x=\"exp_name\",\n", 839 | " y=\"d_val\",\n", 840 | " data=df_metrics_plot,\n", 841 | " palette=np.array(palette)[clr_order],\n", 842 | " whis=100., \n", 843 | " boxprops=dict(alpha=.3, edgecolor='gray',linewidth=1.),\n", 844 | " whiskerprops=dict(color='gray',linewidth=1.),\n", 845 | " medianprops=dict(color='gray',linewidth=1.),\n", 846 | " capprops=dict(color='gray',linewidth=1.), \n", 847 | " ax = ax,\n", 848 | " )\n", 849 | "sns.swarmplot(x=\"exp_name\",\n", 850 | " y=\"d_test\",\n", 851 | " data=df_metrics_plot,\n", 852 | " palette=np.array(palette)[clr_order],\n", 853 | " size=2.5,\n", 854 | " ax=ax,\n", 855 | " )\n", 856 | "ax.set_title('PIT D Metric',fontsize=FS*1.2)\n", 857 | "ax.set_ylim(0,.055)\n", 858 | "format_spines(ax)\n", 859 | "ax.set_xlabel('ridge parameter',fontsize=FS)\n", 860 | "x_labels = df_metrics_plot[\"ridge_param\"].unique()\n", 861 | "ax.set_xticks(np.arange(0,len(x_labels)),x_labels, 
fontsize=FS*0.8,rotation=45)\n", 862 | "\n", 863 | "ax = axs[2]\n", 864 | "sns.boxplot(x=\"exp_name\",\n", 865 | " y=\"loss_val\",\n", 866 | " data=df_metrics_plot,\n", 867 | " palette=np.array(palette)[clr_order],\n", 868 | " whis=100., \n", 869 | " boxprops=dict(alpha=.3, edgecolor='gray',linewidth=1.),\n", 870 | " whiskerprops=dict(color='gray',linewidth=1.),\n", 871 | " medianprops=dict(color='gray',linewidth=1.),\n", 872 | " capprops=dict(color='gray',linewidth=1.), \n", 873 | " ax = ax,\n", 874 | " )\n", 875 | "sns.swarmplot(x=\"exp_name\",\n", 876 | " y=\"loss_test\",\n", 877 | " data=df_metrics_plot,\n", 878 | " palette=np.array(palette)[clr_order],\n", 879 | " size=2.5,\n", 880 | " ax=ax,\n", 881 | " )\n", 882 | "ax.set_title('Loss',fontsize=FS*1.2)\n", 883 | "ax.set_ylim(2.,5.0)\n", 884 | "format_spines(ax)\n", 885 | "ax.set_xlabel('ridge parameter',fontsize=FS)\n", 886 | "x_labels = df_metrics_plot[\"ridge_param\"].unique()\n", 887 | "ax.set_xticks(np.arange(0,len(x_labels)),x_labels, fontsize=FS*0.8,rotation=45)\n", 888 | "\n", 889 | "\n", 890 | "\n", 891 | "plt.tight_layout()\n", 892 | "plots.savefig(FIGURE_DIRECTORY + 'cmip6_metrics_ridge_comparison',dpi=savefig_dpi)\n", 893 | "plt.show()\n" 894 | ] 895 | }, 896 | { 897 | "cell_type": "code", 898 | "execution_count": null, 899 | "id": "695cb70d-15c2-41b8-ada2-55fadacb146b", 900 | "metadata": { 901 | "pycharm": { 902 | "name": "#%%\n" 903 | } 904 | }, 905 | "outputs": [], 906 | "source": [ 907 | "# PLOT ACROSS HIDDEN CHOICES\n", 908 | "EXP_FOR_PLOTTING = ('exp5','exp10','exp11','exp12','exp13','exp14')\n", 909 | "df_metrics_plot = df_metrics[df_metrics[\"exp_name\"].isin(EXP_FOR_PLOTTING)]\n", 910 | "df_metrics_plot = df_metrics_plot.sort_values(\"hiddens\")\n", 911 | "x_labels = df_metrics_plot[\"hiddens\"].unique()\n", 912 | "x_labels[x_labels=='[2]'] = '[2]\\nlinear'\n", 913 | "clr_order = [0,0,0,2,0,0,0,0,]\n", 914 | "\n", 915 | "print('PARAMETER CHECK: ' + str(df_metrics_plot[\"ridge_param\"].unique()))\n", 916 | "#------------------------------------------------------------\n", 917 | "fig, axs = plt.subplots(1,3,figsize=(8.5,2.5))\n", 918 | "\n", 919 | "ax = axs[0]\n", 920 | "sns.boxplot(x=\"exp_name\",\n", 921 | " y=\"error_val\",\n", 922 | " data=df_metrics_plot,\n", 923 | " palette=np.array(palette)[clr_order],\n", 924 | " boxprops=dict(alpha=.3, edgecolor='gray',linewidth=1.),\n", 925 | " whiskerprops=dict(color='gray',linewidth=1.),\n", 926 | " medianprops=dict(color='gray',linewidth=1.),\n", 927 | " capprops=dict(color='gray',linewidth=1.), \n", 928 | " whis=100., \n", 929 | " ax = ax,\n", 930 | " )\n", 931 | "sns.swarmplot(x=\"exp_name\",\n", 932 | " y=\"error_test\",\n", 933 | " data=df_metrics_plot,\n", 934 | " palette=np.array(palette)[clr_order],\n", 935 | " size=2.5,\n", 936 | " ax = ax,\n", 937 | " )\n", 938 | "ax.set_ylabel('error (years)')\n", 939 | "ax.set_title('Mean Absolute Error')\n", 940 | "ax.set_ylim(2.0,6.0)\n", 941 | "format_spines(ax)\n", 942 | "ax.set_xlabel('hidden layers x nodes',fontsize=FS)\n", 943 | "ax.set_xticks(np.arange(0,len(x_labels)),x_labels, fontsize=FS*0.8,rotation=45)\n", 944 | "\n", 945 | "ax = axs[1]\n", 946 | "sns.boxplot(x=\"exp_name\",\n", 947 | " y=\"d_val\",\n", 948 | " data=df_metrics_plot,\n", 949 | " palette=np.array(palette)[clr_order],\n", 950 | " whis=100., \n", 951 | " boxprops=dict(alpha=.3, edgecolor='gray',linewidth=1.),\n", 952 | " whiskerprops=dict(color='gray',linewidth=1.),\n", 953 | " medianprops=dict(color='gray',linewidth=1.),\n", 954 | " 
capprops=dict(color='gray',linewidth=1.), \n", 955 | " ax = ax,\n", 956 | " )\n", 957 | "sns.swarmplot(x=\"exp_name\",\n", 958 | " y=\"d_test\",\n", 959 | " data=df_metrics_plot,\n", 960 | " palette=np.array(palette)[clr_order],\n", 961 | " size=2.5,\n", 962 | " ax=ax,\n", 963 | " )\n", 964 | "ax.set_title('PIT D Metric',fontsize=FS*1.2)\n", 965 | "ax.set_ylim(0,.08)\n", 966 | "format_spines(ax)\n", 967 | "ax.set_xlabel('hidden layers x nodes',fontsize=FS)\n", 968 | "ax.set_xticks(np.arange(0,len(x_labels)),x_labels, fontsize=FS*0.8,rotation=45)\n", 969 | "\n", 970 | "ax = axs[2]\n", 971 | "sns.boxplot(x=\"exp_name\",\n", 972 | " y=\"loss_val\",\n", 973 | " data=df_metrics_plot,\n", 974 | " palette=np.array(palette)[clr_order],\n", 975 | " whis=100., \n", 976 | " boxprops=dict(alpha=.3, edgecolor='gray',linewidth=1.),\n", 977 | " whiskerprops=dict(color='gray',linewidth=1.),\n", 978 | " medianprops=dict(color='gray',linewidth=1.),\n", 979 | " capprops=dict(color='gray',linewidth=1.), \n", 980 | " ax = ax,\n", 981 | " )\n", 982 | "sns.swarmplot(x=\"exp_name\",\n", 983 | " y=\"loss_test\",\n", 984 | " data=df_metrics_plot,\n", 985 | " palette=np.array(palette)[clr_order],\n", 986 | " size=2.5,\n", 987 | " ax=ax,\n", 988 | " )\n", 989 | "ax.set_title('Loss',fontsize=FS*1.2)\n", 990 | "ax.set_ylim(2.0,5.0)\n", 991 | "format_spines(ax)\n", 992 | "ax.set_xlabel('hidden layers x nodes',fontsize=FS)\n", 993 | "ax.set_xticks(np.arange(0,len(x_labels)),x_labels, fontsize=FS*0.8,rotation=45)\n", 994 | "\n", 995 | "\n", 996 | "plt.tight_layout()\n", 997 | "plots.savefig(FIGURE_DIRECTORY + 'cmip6_metrics_hiddens_comparison',dpi=savefig_dpi)\n", 998 | "plt.show()\n" 999 | ] 1000 | }, 1001 | { 1002 | "cell_type": "markdown", 1003 | "id": "4a3e1346-4912-4b63-bf85-84a3d10c893c", 1004 | "metadata": { 1005 | "pycharm": { 1006 | "name": "#%% md\n" 1007 | } 1008 | }, 1009 | "source": [ 1010 | "## Plot all hyperparameter experiments" 1011 | ] 1012 | }, 1013 | { 1014 | "cell_type": "code", 1015 | "execution_count": null, 1016 | "id": "0d07335b-b85d-40a9-b5fc-dcddf834d21d", 1017 | "metadata": { 1018 | "pycharm": { 1019 | "name": "#%%\n" 1020 | } 1021 | }, 1022 | "outputs": [], 1023 | "source": [ 1024 | "# PLOT ACROSS ALL EXPERIMENTS\n", 1025 | "EXP_FOR_PLOTTING = EXP_NAME_VEC\n", 1026 | "df_metrics_plot = df_metrics[df_metrics[\"exp_name\"].isin(EXP_FOR_PLOTTING)]\n", 1027 | "x_labels = df_metrics_plot[\"exp_name\"].unique()\n", 1028 | "clr_order = [0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,]\n", 1029 | "#------------------------------------------------------------\n", 1030 | "fig, axs = plt.subplots(1,3,figsize=(11,3.))\n", 1031 | "\n", 1032 | "ax = axs[0]\n", 1033 | "sns.boxplot(x=\"exp_name\",\n", 1034 | " y=\"error_val\",\n", 1035 | " data=df_metrics_plot,\n", 1036 | " # palette=np.array(palette)[clr_order],\n", 1037 | " boxprops=dict(alpha=.3, edgecolor='gray',linewidth=1.),\n", 1038 | " whiskerprops=dict(color='gray',linewidth=1.),\n", 1039 | " medianprops=dict(color='gray',linewidth=1.),\n", 1040 | " capprops=dict(color='gray',linewidth=1.), \n", 1041 | " whis=100., \n", 1042 | " ax = ax,\n", 1043 | " )\n", 1044 | "sns.swarmplot(x=\"exp_name\",\n", 1045 | " y=\"error_test\",\n", 1046 | " data=df_metrics_plot,\n", 1047 | " # palette=np.array(palette)[clr_order],\n", 1048 | " size=2.5,\n", 1049 | " ax = ax,\n", 1050 | " )\n", 1051 | "ax.set_yticks(np.arange(0,10,.5))\n", 1052 | "ax.set_ylabel('error (years)')\n", 1053 | "ax.set_title('Mean Absolute Error')\n", 1054 | "ax.set_ylim(2.5,5.5)\n", 1055 | 
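"# NOTE (added annotation): in each panel of these comparison figures, the gray boxes\n", "# summarize the *validation* metric across random seeds (whis=100., so whiskers span the\n", "# full seed range), while the colored swarm dots overlay each seed's *test-set* value.\n",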
"format_spines(ax)\n", 1056 | "# ax.set_xlabel('hidden layers x nodes',fontsize=FS)\n", 1057 | "ax.set_xticks(np.arange(0,len(x_labels)),x_labels, fontsize=FS*0.8,rotation=45)\n", 1058 | "ax.grid(alpha=.3)\n", 1059 | "\n", 1060 | "ax = axs[1]\n", 1061 | "sns.boxplot(x=\"exp_name\",\n", 1062 | " y=\"d_val\",\n", 1063 | " data=df_metrics_plot,\n", 1064 | " # palette=np.array(palette)[clr_order],\n", 1065 | " whis=100., \n", 1066 | " boxprops=dict(alpha=.3, edgecolor='gray',linewidth=1.),\n", 1067 | " whiskerprops=dict(color='gray',linewidth=1.),\n", 1068 | " medianprops=dict(color='gray',linewidth=1.),\n", 1069 | " capprops=dict(color='gray',linewidth=1.), \n", 1070 | " ax = ax,\n", 1071 | " )\n", 1072 | "sns.swarmplot(x=\"exp_name\",\n", 1073 | " y=\"d_test\",\n", 1074 | " data=df_metrics_plot,\n", 1075 | " # palette=np.array(palette)[clr_order],\n", 1076 | " size=2.5,\n", 1077 | " ax=ax,\n", 1078 | " )\n", 1079 | "ax.set_title('PIT D Metric',fontsize=FS*1.2)\n", 1080 | "ax.set_ylim(0,.06)\n", 1081 | "format_spines(ax)\n", 1082 | "# ax.set_xlabel('hidden layers x nodes',fontsize=FS)\n", 1083 | "ax.set_xticks(np.arange(0,len(x_labels)),x_labels, fontsize=FS*0.8,rotation=45)\n", 1084 | "ax.grid(alpha=.3)\n", 1085 | "\n", 1086 | "ax = axs[2]\n", 1087 | "sns.boxplot(x=\"exp_name\",\n", 1088 | " y=\"loss_val\",\n", 1089 | " data=df_metrics_plot,\n", 1090 | " # palette=np.array(palette)[clr_order],\n", 1091 | " whis=100., \n", 1092 | " boxprops=dict(alpha=.3, edgecolor='gray',linewidth=1.),\n", 1093 | " whiskerprops=dict(color='gray',linewidth=1.),\n", 1094 | " medianprops=dict(color='gray',linewidth=1.),\n", 1095 | " capprops=dict(color='gray',linewidth=1.), \n", 1096 | " ax = ax,\n", 1097 | " )\n", 1098 | "sns.swarmplot(x=\"exp_name\",\n", 1099 | " y=\"loss_test\",\n", 1100 | " data=df_metrics_plot,\n", 1101 | " # palette=np.array(palette)[clr_order],\n", 1102 | " size=2.5,\n", 1103 | " ax=ax,\n", 1104 | " )\n", 1105 | "ax.set_title('Loss',fontsize=FS*1.2)\n", 1106 | "ax.set_ylim(2.0,5.0)\n", 1107 | "format_spines(ax)\n", 1108 | "# ax.set_xlabel('hidden layers x nodes',fontsize=FS)\n", 1109 | "ax.set_xticks(np.arange(0,len(x_labels)),x_labels, fontsize=FS*0.8,rotation=45)\n", 1110 | "ax.grid(alpha=.3)\n", 1111 | "\n", 1112 | "plt.tight_layout()\n", 1113 | "plots.savefig(FIGURE_DIRECTORY + 'cmip6_metrics_all_comparison',dpi=savefig_dpi)\n", 1114 | "plt.show()\n" 1115 | ] 1116 | }, 1117 | { 1118 | "cell_type": "markdown", 1119 | "id": "a2937969-2ec2-472b-93ea-e35506432d83", 1120 | "metadata": { 1121 | "pycharm": { 1122 | "name": "#%% md\n" 1123 | }, 1124 | "tags": [] 1125 | }, 1126 | "source": [ 1127 | "## Explore the dataframe" 1128 | ] 1129 | }, 1130 | { 1131 | "cell_type": "code", 1132 | "execution_count": null, 1133 | "id": "79bc12a8-4bd2-46bb-814c-b1bd62059d3a", 1134 | "metadata": { 1135 | "pycharm": { 1136 | "name": "#%%\n" 1137 | } 1138 | }, 1139 | "outputs": [], 1140 | "source": [ 1141 | "EXP_NAME = 'exp12'\n", 1142 | "df = df_metrics[df_metrics[\"exp_name\"]==EXP_NAME]\n", 1143 | "PLOT_SEED = df_metrics.iloc[df['loss_test'].idxmin()][\"seed\"]\n", 1144 | "# display(df_metrics.iloc[df['loss_test'].idxmin()])\n", 1145 | "display(df.sort_values(\"loss_val\"))\n", 1146 | "# df['loss_val'].idxmax()\n", 1147 | "# display(df_metrics[df_metrics[\"exp_name\"]==\"exp4\"].sort_values(\"error_val\").head())\n", 1148 | "# PLOT_SEED = 1257\n" 1149 | ] 1150 | }, 1151 | { 1152 | "cell_type": "code", 1153 | "execution_count": null, 1154 | "id": "a3f18794-668c-4805-86fe-611027aff407", 1155 | "metadata": { 
1156 | "pycharm": { 1157 | "name": "#%%\n" 1158 | } 1159 | }, 1160 | "outputs": [], 1161 | "source": [ 1162 | "# display(df_metrics[df_metrics[\"exp_name\"]==\"exp4\"].sort_values(\"error_val\").head())\n", 1163 | "# display(df_metrics[df_metrics[\"exp_name\"]==\"exp4\"].sort_values(\"d_val\").head())\n", 1164 | "# display(df_metrics[df_metrics[\"exp_name\"]==\"exp4\"].sort_values(\"error_test\").head())" 1165 | ] 1166 | }, 1167 | { 1168 | "cell_type": "raw", 1169 | "id": "3df53353-1469-44a8-88c2-15c4c00bf714", 1170 | "metadata": { 1171 | "pycharm": { 1172 | "name": "#%% raw\n" 1173 | } 1174 | }, 1175 | "source": [] 1176 | } 1177 | ], 1178 | "metadata": { 1179 | "colab": { 1180 | "collapsed_sections": [], 1181 | "name": "_main.ipynb", 1182 | "provenance": [] 1183 | }, 1184 | "kernelspec": { 1185 | "display_name": "Python 3 (ipykernel)", 1186 | "language": "python", 1187 | "name": "python3" 1188 | }, 1189 | "language_info": { 1190 | "codemirror_mode": { 1191 | "name": "ipython", 1192 | "version": 3 1193 | }, 1194 | "file_extension": ".py", 1195 | "mimetype": "text/x-python", 1196 | "name": "python", 1197 | "nbconvert_exporter": "python", 1198 | "pygments_lexer": "ipython3", 1199 | "version": "3.9.13" 1200 | }, 1201 | "vscode": { 1202 | "interpreter": { 1203 | "hash": "7b327a0708e120f225110d3517c433f9fd296de7c73bb21c34c6d8e40603491f" 1204 | } 1205 | } 1206 | }, 1207 | "nbformat": 4, 1208 | "nbformat_minor": 5 1209 | } 1210 | -------------------------------------------------------------------------------- /_train_model_v4.0_loopA.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "6da79c16-eb43-4664-a883-7a31f3af00da", 6 | "metadata": { 7 | "id": "4a650402-4774-49cb-9b72-9c8f1dd02f1d", 8 | "tags": [] 9 | }, 10 | "source": [ 11 | "# Detecting temperature targets\n", 12 | "##### authors: Elizabeth A. Barnes and Noah Diffenbaugh\n", 13 | "##### date: March 20, 2022\n", 14 | "\n", 15 | "README:\n", 16 | "This is the main training script for all TF models. Here are some tips for using this new re-factored code.\n", 17 | "\n", 18 | "* ```experiment_settings.py``` is now your go-to place. It is something like a research log. You want to continue to copy and paste new experimental designs (with unique names e.g. ```exp23```) and this way you can always refer back to an experiment you ran before without having to change a bunch of parameters again. \n", 19 | "\n", 20 | "* If all goes well and we don't need more data, you should only be modifying the file called ```experiment_settings.py``` and this notebook (although plots.py might be changed too). \n", 21 | "\n", 22 | "* To train a set of models, you go into ```experiment_settings.py``` and make a new experiment (with a new name, e.g. ```exp1```) and then you specify that same name here in Cell 3 for ```EXP_NAME```.\n", 23 | "\n", 24 | "* The parameter in settings called ```n_models``` will be more useful now. If you set this to a larger number, e.g. 20, it will train 20 models with the same experimental design but with different random training/validation/testing sets etc. You will then be able to analyze these models in another notebook.\n", 25 | "\n", 26 | "* Other choices you have here (outside of the usual experiment settings) include whether to overwrite existing models with the same experiment name. 
Typically, you want ```OVERWRITE_MODEL = False``` so that the code will continue training new random seeds where you left off (rather than starting over again).\n", 27 | "\n", 28 | "* Plots for model diagnostics are saved in the ```model_diagnostics``` directory. \n", 29 | "\n", 30 | "* Predictions for observations are saved in the ```saved_predictions``` directory, although you can always re-load the TF model and re-make the predictions in another notebook. But I thought this might be faster/easier.\n", 31 | "\n", 32 | "* TF models and their meta data are saved in the ```saved_models``` directory.\n", 33 | "\n", 34 | "* Once training is done, you can run the following to perform analysis and make/save plots for the paper. \n", 35 | "** ```compare_random_seeds.ipynb```\n", 36 | "** ```_analyze_models_vX.X.ipynb```\n", 37 | "** ```_visualize_xai_vX.X.ipynb```" 38 | ] 39 | }, 40 | { 41 | "cell_type": "markdown", 42 | "id": "7ccff821-b304-4009-8fe8-75a213b3f421", 43 | "metadata": { 44 | "tags": [] 45 | }, 46 | "source": [ 47 | "## Python stuff" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": null, 53 | "id": "fb968382-4186-466e-a85b-b00caa5fc9be", 54 | "metadata": { 55 | "colab": { 56 | "base_uri": "https://localhost:8080/" 57 | }, 58 | "executionInfo": { 59 | "elapsed": 17642, 60 | "status": "ok", 61 | "timestamp": 1646449680995, 62 | "user": { 63 | "displayName": "Elizabeth Barnes", 64 | "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GiNPVVIWP6XAkP_hwu-8rAxoeeNuk2BMkX5-yuA=s64", 65 | "userId": "07585723222468022011" 66 | }, 67 | "user_tz": 420 68 | }, 69 | "id": "fb968382-4186-466e-a85b-b00caa5fc9be", 70 | "outputId": "d7964af9-2d52-4466-902d-9b85faba9a91", 71 | "tags": [] 72 | }, 73 | "outputs": [], 74 | "source": [ 75 | "import sys, os\n", 76 | "import importlib as imp\n", 77 | "\n", 78 | "import xarray as xr\n", 79 | "import numpy as np\n", 80 | "import matplotlib.pyplot as plt\n", 81 | "import scipy.stats as stats\n", 82 | "import tensorflow as tf\n", 83 | "import tensorflow_probability as tfp\n", 84 | "\n", 85 | "import experiment_settings\n", 86 | "import file_methods, plots, custom_metrics, network, data_processing\n", 87 | "\n", 88 | "import matplotlib as mpl\n", 89 | "mpl.rcParams[\"figure.facecolor\"] = \"white\"\n", 90 | "mpl.rcParams[\"figure.dpi\"] = 150\n", 91 | "savefig_dpi = 300\n", 92 | "np.warnings.filterwarnings(\"ignore\", category=np.VisibleDeprecationWarning)" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": null, 98 | "id": "29a5cee3-6f4f-4818-92e1-1351eeeb565a", 99 | "metadata": { 100 | "colab": { 101 | "base_uri": "https://localhost:8080/" 102 | }, 103 | "executionInfo": { 104 | "elapsed": 30, 105 | "status": "ok", 106 | "timestamp": 1646449681009, 107 | "user": { 108 | "displayName": "Elizabeth Barnes", 109 | "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GiNPVVIWP6XAkP_hwu-8rAxoeeNuk2BMkX5-yuA=s64", 110 | "userId": "07585723222468022011" 111 | }, 112 | "user_tz": 420 113 | }, 114 | "id": "29a5cee3-6f4f-4818-92e1-1351eeeb565a", 115 | "outputId": "e5f5b0ac-82b8-4147-bf44-4bc3b49466a2", 116 | "tags": [] 117 | }, 118 | "outputs": [], 119 | "source": [ 120 | "print(f\"python version = {sys.version}\")\n", 121 | "print(f\"numpy version = {np.__version__}\")\n", 122 | "print(f\"xarray version = {xr.__version__}\") \n", 123 | "print(f\"tensorflow version = {tf.__version__}\") \n", 124 | "print(f\"tensorflow-probability version = {tfp.__version__}\") " 125 | ] 126 | }, 127 | { 128 | "cell_type": "markdown", 129 
| "id": "651315ce-eecc-4d30-8b90-c97d08936315", 130 | "metadata": { 131 | "tags": [] 132 | }, 133 | "source": [ 134 | "## User Choices" 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": null, 140 | "id": "c83a544f-ef35-417f-bec4-62225d885014", 141 | "metadata": {}, 142 | "outputs": [], 143 | "source": [ 144 | "EXP_NAME_LIST = ('exp13C_126','exp13C_370',)\n", 145 | "OVERWRITE_MODEL = True\n", 146 | "\n", 147 | "MODEL_DIRECTORY = 'saved_models/' \n", 148 | "PREDICTIONS_DIRECTORY = 'saved_predictions/'\n", 149 | "DATA_DIRECTORY = 'data/'\n", 150 | "DIAGNOSTICS_DIRECTORY = 'model_diagnostics/'\n", 151 | "FIGURE_DIRECTORY = 'figures/'" 152 | ] 153 | }, 154 | { 155 | "cell_type": "markdown", 156 | "id": "30ea5755-e624-4b29-b88a-fd35d85ddb66", 157 | "metadata": { 158 | "tags": [] 159 | }, 160 | "source": [ 161 | "## Plotting functions" 162 | ] 163 | }, 164 | { 165 | "cell_type": "code", 166 | "execution_count": null, 167 | "id": "c1b7e014-e289-4fdc-9d82-3c976f8db8c7", 168 | "metadata": {}, 169 | "outputs": [], 170 | "source": [ 171 | "def plot_one_to_one_diagnostic():\n", 172 | " if settings['network_type'] == \"shash2\":\n", 173 | " top_pred_idx = 0\n", 174 | " else:\n", 175 | " top_pred_idx = None\n", 176 | "\n", 177 | " YEARS_UNIQUE = np.unique(y_yrs_train)\n", 178 | " predict_train = model.predict(x_train)[:,top_pred_idx].flatten()\n", 179 | " predict_val = model.predict(x_val)[:,top_pred_idx].flatten()\n", 180 | " predict_test = model.predict(x_test)[:,top_pred_idx].flatten()\n", 181 | " mae = np.mean(np.abs(predict_test-y_test[:]))\n", 182 | " \n", 183 | " #--------------------------------\n", 184 | " clr = ('tab:purple','tab:orange', 'tab:blue', 'tab:green', 'gold', 'brown','black','darkorange','fuchsia','cornflowerblue','lime')\n", 185 | " plt.subplots(1,2,figsize=(15,6))\n", 186 | "\n", 187 | " plt.subplot(1,2,1)\n", 188 | " plt.plot(y_train, predict_train,'.',color='gray',alpha=.25, label='training')\n", 189 | " plt.plot(y_val, predict_val,'.', label='validation',color='gray',alpha=.75,)\n", 190 | " plt.plot(y_test, predict_test,'.', label='testing') \n", 191 | " plt.plot(y_train,y_train,'--',color='fuchsia')\n", 192 | " plt.axvline(x=0,color='gray',linewidth=1)\n", 193 | " plt.axhline(y=0,color='gray',linewidth=1)\n", 194 | " plt.title('Testing MAE = ' + str(mae.round(2)) + ' years')\n", 195 | " plt.xlabel('true number of years until target is reached')\n", 196 | " plt.ylabel('predicted number of years until target is reached')\n", 197 | " plt.legend()\n", 198 | "\n", 199 | "\n", 200 | " plt.subplot(1,2,2)\n", 201 | " plt.plot(y_yrs_train, predict_train,'.',color='gray',alpha=.5, label='training')\n", 202 | " plt.title('Time to Target Year for ' + str(settings['target_temp']) + 'C using ssp' + str(settings['ssp']))\n", 203 | " plt.xlabel('year of map')\n", 204 | " plt.ylabel('predicted number of years until target is reached')\n", 205 | " plt.axhline(y=0, color='gray', linewidth=1)\n", 206 | "\n", 207 | " predict_val_mat = predict_val.reshape(N_GCMS,N_VAL,len(YEARS_UNIQUE))\n", 208 | " for i in np.arange(0,predict_val_mat.shape[0]):\n", 209 | " plt.plot(YEARS_UNIQUE, predict_val_mat[i,:,:].swapaxes(1,0),'.', label='validation', color=clr[i])\n", 210 | " plt.axvline(x=target_years[i],linestyle='--',color=clr[i])" 211 | ] 212 | }, 213 | { 214 | "cell_type": "markdown", 215 | "id": "c807abd7-832a-484b-98cd-7e6c3a9f60c0", 216 | "metadata": { 217 | "id": "c807abd7-832a-484b-98cd-7e6c3a9f60c0", 218 | "tags": [] 219 | }, 220 | "source": [ 221 | "## Train the 
network" 222 | ] 223 | }, 224 | { 225 | "cell_type": "code", 226 | "execution_count": null, 227 | "id": "7becb266-c9fd-4098-a2ba-e6c52804b8bd", 228 | "metadata": { 229 | "colab": { 230 | "base_uri": "https://localhost:8080/", 231 | "height": 962 232 | }, 233 | "executionInfo": { 234 | "elapsed": 105064, 235 | "status": "ok", 236 | "timestamp": 1646449809976, 237 | "user": { 238 | "displayName": "Elizabeth Barnes", 239 | "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GiNPVVIWP6XAkP_hwu-8rAxoeeNuk2BMkX5-yuA=s64", 240 | "userId": "07585723222468022011" 241 | }, 242 | "user_tz": 420 243 | }, 244 | "id": "7becb266-c9fd-4098-a2ba-e6c52804b8bd", 245 | "outputId": "5f2d4b54-fb88-418f-95a2-3c5e281cc2e4", 246 | "tags": [] 247 | }, 248 | "outputs": [], 249 | "source": [ 250 | "imp.reload(data_processing)\n", 251 | "for EXP_NAME in EXP_NAME_LIST:\n", 252 | "\n", 253 | " settings = experiment_settings.get_settings(EXP_NAME)\n", 254 | " display(settings)\n", 255 | "\n", 256 | " # define random number generator\n", 257 | " rng = np.random.default_rng(settings[\"rng_seed\"])\n", 258 | " \n", 259 | " # define early stopping callback (cannot be done elsewhere)\n", 260 | " early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss',\n", 261 | " patience=settings['patience'],\n", 262 | " verbose=1,\n", 263 | " mode='auto',\n", 264 | " restore_best_weights=True) \n", 265 | "\n", 266 | " for iloop in np.arange(settings['n_models']):\n", 267 | " seed = rng.integers(low=1_000,high=10_000,size=1)[0]\n", 268 | " settings[\"seed\"] = int(seed)\n", 269 | " tf.random.set_seed(settings[\"seed\"])\n", 270 | " np.random.seed(settings[\"seed\"])\n", 271 | "\n", 272 | " # get model name\n", 273 | " model_name = file_methods.get_model_name(settings)\n", 274 | " if os.path.exists(MODEL_DIRECTORY + model_name + \"_model\") and OVERWRITE_MODEL==False:\n", 275 | " print(model_name + 'exists. 
Skipping...')\n", 276 | " print(\"================================\\n\")\n", 277 | " continue \n", 278 | " \n", 279 | " # load observations for diagnostics plotting and saving predictions\n", 280 | " da_obs, x_obs, global_mean_obs = data_processing.get_observations(DATA_DIRECTORY, settings)\n", 281 | " N_TRAIN, N_VAL, N_TEST, ALL_MEMBERS = data_processing.get_members(settings) \n", 282 | "\n", 283 | " # get the data\n", 284 | " (x_train, \n", 285 | " x_val, \n", 286 | " x_test, \n", 287 | " y_train, \n", 288 | " y_val, \n", 289 | " y_test, \n", 290 | " onehot_train, \n", 291 | " onehot_val, \n", 292 | " onehot_test, \n", 293 | " y_yrs_train, \n", 294 | " y_yrs_val, \n", 295 | " y_yrs_test, \n", 296 | " target_years, \n", 297 | " map_shape,\n", 298 | " settings) = data_processing.get_cmip_data(DATA_DIRECTORY, settings)\n", 299 | "\n", 300 | " ## determine how many GCMs are being used for later re-shaping\n", 301 | " N_GCMS = len(file_methods.get_cmip_filenames(settings, verbose=0))\n", 302 | "\n", 303 | " #---------------------------------------- \n", 304 | " tf.keras.backend.clear_session() \n", 305 | " model = network.compile_model(x_train, y_train, settings)\n", 306 | " history = model.fit(x_train, onehot_train, \n", 307 | " epochs=settings['n_epochs'], \n", 308 | " batch_size = settings['batch_size'], \n", 309 | " shuffle=True,\n", 310 | " validation_data=[x_val, onehot_val],\n", 311 | " callbacks=[early_stopping,],\n", 312 | " verbose=0, \n", 313 | " )\n", 314 | " #----------------------------------------\n", 315 | " # create predictions for observations with this model\n", 316 | " pred_obs = model.predict(x_obs)\n", 317 | "\n", 318 | " #----------------------------------------\n", 319 | " # save the tensorflow model and obs predictions\n", 320 | " if settings[\"save_model\"]:\n", 321 | " file_methods.save_tf_model(model, model_name, MODEL_DIRECTORY, settings)\n", 322 | " file_methods.save_pred_obs(pred_obs, \n", 323 | " PREDICTIONS_DIRECTORY+model_name + '_obs_predictions',\n", 324 | " )\n", 325 | "\n", 326 | " #----------------------------------------\n", 327 | " # create and save diagnostics plots\n", 328 | " plots.plot_metrics_panels(history,settings)\n", 329 | " plt.savefig(DIAGNOSTICS_DIRECTORY + model_name + '_metrics_diagnostic' + '.png', dpi=savefig_dpi)\n", 330 | " plt.show() \n", 331 | "\n", 332 | " plot_one_to_one_diagnostic()\n", 333 | " plt.savefig(DIAGNOSTICS_DIRECTORY + model_name + '_one_to_one_diagnostic' + '.png', dpi=savefig_dpi)\n", 334 | " plt.show() \n" 335 | ] 336 | }, 337 | { 338 | "cell_type": "code", 339 | "execution_count": null, 340 | "id": "dbc38634-dd38-4389-99e2-11d1c441844d", 341 | "metadata": {}, 342 | "outputs": [], 343 | "source": [ 344 | "2+2" 345 | ] 346 | }, 347 | { 348 | "cell_type": "code", 349 | "execution_count": null, 350 | "id": "1ae0a234-2d96-4d89-93f6-ff3a714fd1a5", 351 | "metadata": {}, 352 | "outputs": [], 353 | "source": [] 354 | } 355 | ], 356 | "metadata": { 357 | "colab": { 358 | "collapsed_sections": [], 359 | "name": "_main.ipynb", 360 | "provenance": [] 361 | }, 362 | "kernelspec": { 363 | "display_name": "Python 3 (ipykernel)", 364 | "language": "python", 365 | "name": "python3" 366 | }, 367 | "language_info": { 368 | "codemirror_mode": { 369 | "name": "ipython", 370 | "version": 3 371 | }, 372 | "file_extension": ".py", 373 | "mimetype": "text/x-python", 374 | "name": "python", 375 | "nbconvert_exporter": "python", 376 | "pygments_lexer": "ipython3", 377 | "version": "3.9.13" 378 | } 379 | }, 380 | "nbformat": 4, 381 | 
"nbformat_minor": 5 382 | } 383 | -------------------------------------------------------------------------------- /custom_metrics.py: -------------------------------------------------------------------------------- 1 | """Metrics for TF training. 2 | 3 | Classes 4 | --------- 5 | InterquartileCapture(tf.keras.metrics.Metric) 6 | SignTest(tf.keras.metrics.Metric) 7 | CustomMAE(tf.keras.metrics.Metric) 8 | 9 | 10 | """ 11 | import tensorflow as tf 12 | import tensorflow_probability as tfp 13 | import numpy as np 14 | 15 | class InterquartileCapture(tf.keras.metrics.Metric): 16 | """Compute the fraction of true values between the 25 and 75 percentiles. 17 | """ 18 | def __init__(self, **kwargs): 19 | super().__init__(**kwargs) 20 | self.count = self.add_weight("count", initializer="zeros") 21 | self.total = self.add_weight("total", initializer="zeros") 22 | 23 | def update_state(self, y_true, pred, sample_weight=None): 24 | mu = pred[:, 0] 25 | sigma = pred[:, 1] 26 | norm_dist = tfp.distributions.Normal(mu,sigma) 27 | lower = norm_dist.quantile(.25) 28 | upper = norm_dist.quantile(.75) 29 | 30 | batch_count = tf.reduce_sum( 31 | tf.cast( 32 | tf.math.logical_and( 33 | tf.math.greater(y_true[:, 0], lower), 34 | tf.math.less(y_true[:, 0], upper) 35 | ), 36 | tf.float32 37 | ) 38 | 39 | ) 40 | batch_total = len(y_true[:, 0]) 41 | 42 | self.count.assign_add(tf.cast(batch_count, tf.float32)) 43 | self.total.assign_add(tf.cast(batch_total, tf.float32)) 44 | 45 | def result(self): 46 | return self.count / self.total 47 | 48 | def get_config(self): 49 | base_config = super().get_config() 50 | return {**base_config} 51 | 52 | 53 | class SignTest(tf.keras.metrics.Metric): 54 | """Compute the fraction of true values above the median. 55 | 56 | """ 57 | def __init__(self, **kwargs): 58 | super().__init__(**kwargs) 59 | self.count = self.add_weight("count", initializer="zeros") 60 | self.total = self.add_weight("total", initializer="zeros") 61 | 62 | def update_state(self, y_true, pred, sample_weight=None): 63 | mu = pred[:, 0] 64 | sigma = pred[:, 1] 65 | norm_dist = tfp.distributions.Normal(mu,sigma) 66 | median = norm_dist.quantile(.50) 67 | 68 | batch_count = tf.reduce_sum( 69 | tf.cast(tf.math.greater(y_true[:, 0], median), tf.float32) 70 | ) 71 | batch_total = len(y_true[:, 0]) 72 | 73 | self.count.assign_add(tf.cast(batch_count, tf.float32)) 74 | self.total.assign_add(tf.cast(batch_total, tf.float32)) 75 | 76 | def result(self): 77 | return self.count / self.total 78 | 79 | def get_config(self): 80 | base_config = super().get_config() 81 | return {**base_config} 82 | 83 | 84 | class CustomMAE(tf.keras.metrics.Metric): 85 | """Compute the prediction mean absolute error. 86 | 87 | The "predicted value" is the median of the conditional distribution. 88 | 89 | Notes 90 | ----- 91 | * The computation is done by maintaining running sums of total predictions 92 | and correct predictions made across all batches in an epoch. The 93 | running sums are reset at the end of each epoch. 
94 | 95 | """ 96 | def __init__(self, **kwargs): 97 | super().__init__(**kwargs) 98 | self.error = self.add_weight("error", initializer="zeros") 99 | self.total = self.add_weight("total", initializer="zeros") 100 | 101 | def update_state(self, y_true, pred, sample_weight=None): 102 | mu = pred[:, 0] 103 | sigma = pred[:, 1] 104 | norm_dist = tfp.distributions.Normal(mu,sigma) 105 | predictions = norm_dist.quantile(.50) 106 | 107 | error = tf.math.abs(y_true[:, 0] - predictions) 108 | batch_error = tf.reduce_sum(error) 109 | batch_total = tf.math.count_nonzero(error) 110 | 111 | self.error.assign_add(tf.cast(batch_error, tf.float32)) 112 | self.total.assign_add(tf.cast(batch_total, tf.float32)) 113 | 114 | def result(self): 115 | return self.error / self.total 116 | 117 | def get_config(self): 118 | base_config = super().get_config() 119 | return {**base_config} 120 | 121 | def compute_iqr(uncertainty_type, onehot_data, bnn_cpd=None, x_data=None, model_shash = None): 122 | 123 | if(uncertainty_type in ("shash","shash2","shash3","shash4")): 124 | shash_pred = model_shash.predict(x_data) 125 | mu = shash_pred[:,0] 126 | sigma = shash_pred[:,1] 127 | gamma = shash_pred[:,2] 128 | tau = np.ones(np.shape(mu)) 129 | 130 | lower = shash.quantile(0.25, mu, sigma, gamma, tau) 131 | upper = shash.quantile(0.75, mu, sigma, gamma, tau) 132 | else: 133 | lower = np.percentile(bnn_cpd,25,axis=1) 134 | upper = np.percentile(bnn_cpd,75,axis=1) 135 | 136 | return lower, upper 137 | 138 | def compute_interquartile_capture(uncertainty_type, onehot_data, bnn_cpd=None, x_data=None, model_shash = None): 139 | 140 | bins = np.linspace(0, 1, 11) 141 | bins_inc = bins[1]-bins[0] 142 | 143 | if(uncertainty_type in ("shash","shash2","shash3","shash4")): 144 | lower, upper = compute_iqr(uncertainty_type, onehot_data, x_data=x_data, model_shash=model_shash) 145 | else: 146 | lower, upper = compute_iqr(uncertainty_type, onehot_data, bnn_cpd=bnn_cpd) 147 | 148 | iqr_capture = np.logical_and(onehot_data[:,0]>lower,onehot_data[:,0] baseline_mean.values+settings["target_temp"])[0] 190 | else: 191 | smoothed_values = savgol_filter(global_mean.values, 15, 3) 192 | iwarmer = np.where(smoothed_values > baseline_mean.values+settings["target_temp"])[0] 193 | target_year = global_mean["time"].values[iwarmer[0]].year 194 | except: 195 | if settings["gcmsub"] == 'FORCE' or settings["gcmsub"] == 'OOS': 196 | target_year = global_mean["time"].values[-1].year 197 | elif settings["gcmsub"] == 'EXTEND': 198 | target_year = 2150 199 | else: 200 | raise ValueError('****no such target****') 201 | 202 | # plot the calculation to make sure things make sense 203 | if plot == True: 204 | for ens in np.arange(0,global_mean_ens.shape[0]): 205 | global_mean_ens[ens,:].plot(linewidth=1.0,color="gray",alpha=.5) 206 | global_mean.plot(linewidth=2,label='data',color="aqua") 207 | plt.axhline(y=baseline_mean, color='k', linestyle='-', label='baseline temp') 208 | plt.axhline(y=baseline_mean+settings["target_temp"], color='tab:blue',linewidth=1., linestyle='--', label='target temp') 209 | plt.axvline(x=target_year,color='tab:blue',linewidth=1., linestyle='--', label='target year') 210 | global_mean_obs.plot(linewidth=2,label='data',color="tab:orange") 211 | plt.xlabel('year') 212 | plt.ylabel('temp (K)') 213 | plt.title(f + '\ntargets [' + str(target_year.year) + ', ' + str(settings["target_temp"]) + 'C]', 214 | fontsize = 8, 215 | ) 216 | plt.show() 217 | 218 | # define the labels 219 | if verbose == 1: 220 | print('TARGET_YEAR = ' + str(target_year) + 
', TARGET_TEMP = ' + str(temp_reached)) 221 | labels = target_year - da['time.year'].values 222 | 223 | return labels, da['time.year'].values, target_year 224 | 225 | def preprocess_data(da, MEMBERS, settings): 226 | 227 | if MEMBERS is None: 228 | new_data = da 229 | else: 230 | new_data = da[MEMBERS,:,:,:] 231 | 232 | if settings["anomalies"] is True: 233 | new_data = new_data - new_data.sel(time=slice(str(settings["anomaly_yr_bounds"][0]),str(settings["anomaly_yr_bounds"][1]))).mean('time') 234 | if settings["anomalies"] == 'Baseline': 235 | new_data = new_data - new_data.sel(time=slice(str(settings["baseline_yr_bounds"][0]),str(settings["baseline_yr_bounds"][1]))).mean('time') 236 | new_data = new_data - new_data.sel(time=slice(str(settings["anomaly_yr_bounds"][0]),str(settings["anomaly_yr_bounds"][1]))).mean('time') 237 | 238 | if settings["remove_map_mean"] == 'raw': 239 | new_data = new_data - new_data.mean(("lon","lat")) 240 | elif settings["remove_map_mean"] == 'weighted': 241 | weights = np.cos(np.deg2rad(new_data.lat)) 242 | weights.name = "weights" 243 | new_data_weighted = new_data.weighted(weights) 244 | new_data = new_data - new_data_weighted.mean(("lon","lat")) 245 | 246 | if settings["remove_sh"] == True: 247 | # print('removing SH') 248 | i = np.where(new_data["lat"]<=-50)[0] 249 | if(len(new_data.shape)==3): 250 | new_data[:,i,:] = 0.0 251 | else: 252 | new_data[:,:,i,:] = 0.0 253 | 254 | return new_data 255 | 256 | def make_data_split(da, data, f_labels, f_years, labels, years, MEMBERS, settings): 257 | 258 | # process the data, i.e. compute anomalies, subtract the mean, etc. 259 | new_data = preprocess_data(da, MEMBERS, settings) 260 | 261 | # only train on certain samples 262 | iyears = np.where((f_years >= settings["training_yr_bounds"][0]) & (f_years <= settings["training_yr_bounds"][1]))[0] 263 | f_years = f_years[iyears] 264 | f_labels = f_labels[iyears] 265 | new_data = new_data[:,iyears,:,:] 266 | 267 | if data is None: 268 | data = new_data.values 269 | labels = np.tile(f_labels,(len(MEMBERS),1)) 270 | years = np.tile(f_years,(len(MEMBERS),1)) 271 | else: 272 | data = np.concatenate((data,new_data.values),axis=0) 273 | labels = np.concatenate((labels,np.tile(f_labels,(len(MEMBERS),1))),axis=0) 274 | years = np.concatenate((years,np.tile(f_years,(len(MEMBERS),1))),axis=0) 275 | 276 | return data, labels, years -------------------------------------------------------------------------------- /file_methods.py: -------------------------------------------------------------------------------- 1 | """Functions for working with generic files. 2 | 3 | Functions 4 | --------- 5 | get_model_name(settings) 6 | get_netcdf_da(filename) 7 | save_pred_obs(pred_vector, filename) 8 | save_tf_model(model, model_name, directory, settings) 9 | get_cmip_filenames(settings, verbose=0) 10 | """ 11 | 12 | import xarray as xr 13 | import json 14 | import pickle 15 | import tensorflow as tf 16 | import custom_metrics 17 | 18 | __author__ = "Elizabeth A. 
Barnes and Noah Diffenbaugh" 19 | __version__ = "20 March 2022" 20 | 21 | 22 | def get_model_name(settings): 23 | # model_name = (settings["exp_name"] + '_' + 24 | # 'ssp' + settings["ssp"] + '_' + 25 | # str(settings["target_temp"]) + '_' + 26 | # 'gcmsub' + settings["gcmsub"] + '_' + 27 | # settings["network_type"] + 28 | # '_rng' + str(settings["rng_seed"]) + 29 | # '_seed' + str(settings["seed"]) 30 | # ) 31 | model_name = (settings["exp_name"] + 32 | '_seed' + str(settings["seed"]) 33 | ) 34 | 35 | return model_name 36 | 37 | 38 | def get_netcdf_da(filename): 39 | da = xr.open_dataarray(filename) 40 | return da 41 | 42 | 43 | def save_pred_obs(pred_vector, filename): 44 | with open(filename + '.pickle', 'wb') as f: 45 | pickle.dump(pred_vector, f) 46 | 47 | 48 | def load_tf_model(model_name, directory): 49 | # loading a tf model 50 | model = tf.keras.models.load_model( 51 | directory + model_name + "_model", 52 | compile=False, 53 | custom_objects={ 54 | "InterquartileCapture": custom_metrics.InterquartileCapture(), 55 | "SignTest": custom_metrics.SignTest(), 56 | "CustomMAE": custom_metrics.CustomMAE() 57 | }, 58 | ) 59 | return model 60 | 61 | 62 | def save_tf_model(model, model_name, directory, settings): 63 | 64 | # save the tf model 65 | tf.keras.models.save_model(model, directory + model_name + "_model", overwrite=True) 66 | 67 | # save the meta data 68 | with open(directory + model_name + '_metadata.json', 'w') as json_file: 69 | json_file.write(json.dumps(settings)) 70 | 71 | 72 | def get_cmip_filenames(settings, verbose=0): 73 | if settings["ssp"] == '370' and settings["gcmsub"] == 'ALL': 74 | filenames = ('tas_Amon_historical_ssp370_CanESM5_r1-10_ncecat_ann_mean_2pt5degree.nc', 75 | 'tas_Amon_historical_ssp370_ACCESS-ESM1-5_r1-10_ncecat_ann_mean_2pt5degree.nc', 76 | 'tas_Amon_historical_ssp370_UKESM1-0-LL_r1-10_ncecat_ann_mean_2pt5degree.nc', 77 | 'tas_Amon_historical_ssp370_MIROC-ES2L_r1-10_ncecat_ann_mean_2pt5degree.nc', 78 | 'tas_Amon_historical_ssp370_GISS-E2-1-G_r1-10_ncecat_ann_mean_2pt5degree.nc', 79 | 'tas_Amon_historical_ssp370_IPSL-CM6A-LR_r1-10_ncecat_ann_mean_2pt5degree.nc', 80 | 'tas_Amon_historical_ssp370_CESM2-LE2-smbb_r1-10_ncecat_ann_mean_2pt5degree.nc', 81 | ) 82 | elif settings["ssp"] == '245' and settings["gcmsub"] == 'ALL': 83 | filenames = ( 84 | 'tas_Amon_historical_ssp245_CanESM5_r1-10_ncecat_ann_mean_2pt5degree.nc', 85 | 'tas_Amon_historical_ssp245_ACCESS-ESM1-5_r1-10_ncecat_ann_mean_2pt5degree.nc', 86 | 'tas_Amon_historical_ssp245_CNRM-ESM2-1_r1-10_ncecat_ann_mean_2pt5degree.nc', 87 | 'tas_Amon_historical_ssp245_MIROC-ES2L_r1-10_ncecat_ann_mean_2pt5degree.nc', 88 | 'tas_Amon_historical_ssp245_GISS-E2-1-G_r1-10_ncecat_ann_mean_2pt5degree.nc', 89 | 'tas_Amon_historical_ssp245_IPSL-CM6A-LR_r1-10_ncecat_ann_mean_2pt5degree.nc', 90 | ) 91 | elif settings["ssp"] == '370' and settings["gcmsub"] == 'UNIFORM': 92 | filenames = ('tas_Amon_historical_ssp370_CanESM5_r1-10_ncecat_ann_mean_2pt5degree.nc', 93 | 'tas_Amon_historical_ssp370_ACCESS-ESM1-5_r1-10_ncecat_ann_mean_2pt5degree.nc', 94 | 'tas_Amon_historical_ssp370_UKESM1-0-LL_r1-10_ncecat_ann_mean_2pt5degree.nc', 95 | # 'tas_Amon_historical_ssp370_MIROC-ES2L_r1-10_ncecat_ann_mean_2pt5degree.nc', 96 | ) 97 | elif ((settings["ssp"] == '126' and settings["gcmsub"] == 'ALL') and (settings["target_temp"] == 2.0)): 98 | filenames = ('tas_Amon_historical_ssp126_CanESM5_r1-10_ncecat_ann_mean_2pt5degree.nc', 99 | 'tas_Amon_historical_ssp126_ACCESS-ESM1-5_r1-10_ncecat_ann_mean_2pt5degree.nc', 100 | 
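                     # (added annotation) the filenames encode the preprocessing: rM-N = ensemble members M..N
                     # concatenated along a member dimension with ncecat, annual means, regridded to 2.5 degrees;
                     # the scripts under initial_processing/ appear to produce files of this form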
'tas_Amon_historical_ssp126_UKESM1-0-LL_r1-10_ncecat_ann_mean_2pt5degree.nc', 101 | ) 102 | 103 | elif ((settings["ssp"] == '126' and settings["gcmsub"] == 'ALL7')): 104 | filenames = ('tas_Amon_historical_ssp126_CanESM5_r1-10_ncecat_ann_mean_2pt5degree.nc', 105 | 'tas_Amon_historical_ssp126_ACCESS-ESM1-5_r1-10_ncecat_ann_mean_2pt5degree.nc', 106 | 'tas_Amon_historical_ssp126_UKESM1-0-LL_r1-10_ncecat_ann_mean_2pt5degree.nc', 107 | 'tas_Amon_historical_ssp126_CNRM-CM6-1_r1-5_ncecat_ann_mean_2pt5degree.nc', 108 | 'tas_Amon_historical_ssp126_CNRM-ESM2-1_r1-5_ncecat_ann_mean_2pt5degree.nc', 109 | 'tas_Amon_historical_ssp126_GISS-E2-1-G_r1-5_ncecat_ann_mean_2pt5degree.nc', 110 | 'tas_Amon_historical_ssp126_IPSL-CM6A-LR_r1-5_ncecat_ann_mean_2pt5degree.nc', 111 | ) 112 | elif ((settings["ssp"] == '126' and settings["gcmsub"] == 'ALL10')): 113 | filenames = ( 114 | 'tas_Amon_historical_ssp126_MIROC6_r1-10_ncecat_ann_mean_2pt5degree.nc', 115 | 'tas_Amon_historical_ssp126_MIROC-ES2L_r1-10_ncecat_ann_mean_2pt5degree.nc', 116 | 'tas_Amon_historical_ssp126_MRI-ESM2-0_r1-5_ncecat_ann_mean_2pt5degree.nc', 117 | 'tas_Amon_historical_ssp126_CanESM5_r1-10_ncecat_ann_mean_2pt5degree.nc', 118 | 'tas_Amon_historical_ssp126_ACCESS-ESM1-5_r1-10_ncecat_ann_mean_2pt5degree.nc', 119 | 'tas_Amon_historical_ssp126_UKESM1-0-LL_r1-10_ncecat_ann_mean_2pt5degree.nc', 120 | 'tas_Amon_historical_ssp126_CNRM-CM6-1_r1-5_ncecat_ann_mean_2pt5degree.nc', 121 | 'tas_Amon_historical_ssp126_CNRM-ESM2-1_r1-5_ncecat_ann_mean_2pt5degree.nc', 122 | 'tas_Amon_historical_ssp126_GISS-E2-1-G_r1-5_ncecat_ann_mean_2pt5degree.nc', 123 | 'tas_Amon_historical_ssp126_IPSL-CM6A-LR_r1-5_ncecat_ann_mean_2pt5degree.nc', 124 | ) 125 | elif ((settings["ssp"] == '126' and settings["gcmsub"] == 'noHIGH10')): 126 | filenames = ( 127 | 'tas_Amon_historical_ssp126_MIROC6_r1-10_ncecat_ann_mean_2pt5degree.nc', 128 | 'tas_Amon_historical_ssp126_MIROC-ES2L_r1-10_ncecat_ann_mean_2pt5degree.nc', 129 | 'tas_Amon_historical_ssp126_MRI-ESM2-0_r1-5_ncecat_ann_mean_2pt5degree.nc', 130 | 'tas_Amon_historical_ssp126_CanESM5_r1-10_ncecat_ann_mean_2pt5degree.nc', 131 | 'tas_Amon_historical_ssp126_ACCESS-ESM1-5_r1-10_ncecat_ann_mean_2pt5degree.nc', 132 | 'tas_Amon_historical_ssp126_UKESM1-0-LL_r1-10_ncecat_ann_mean_2pt5degree.nc', 133 | 'tas_Amon_historical_ssp126_CNRM-CM6-1_r1-5_ncecat_ann_mean_2pt5degree.nc', 134 | 'tas_Amon_historical_ssp126_CNRM-ESM2-1_r1-5_ncecat_ann_mean_2pt5degree.nc', 135 | 'tas_Amon_historical_ssp126_GISS-E2-1-G_r1-5_ncecat_ann_mean_2pt5degree.nc', 136 | 'tas_Amon_historical_ssp126_IPSL-CM6A-LR_r1-5_ncecat_ann_mean_2pt5degree.nc', 137 | ) 138 | elif ((settings["ssp"] == '126' and settings["gcmsub"] == 'noHIGH7')): 139 | filenames = ( 140 | 'tas_Amon_historical_ssp126_MIROC6_r1-10_ncecat_ann_mean_2pt5degree.nc', 141 | 'tas_Amon_historical_ssp126_MIROC-ES2L_r1-10_ncecat_ann_mean_2pt5degree.nc', 142 | 'tas_Amon_historical_ssp126_MRI-ESM2-0_r1-5_ncecat_ann_mean_2pt5degree.nc', 143 | # 'tas_Amon_historical_ssp126_CanESM5_r1-10_ncecat_ann_mean_2pt5degree.nc', 144 | 'tas_Amon_historical_ssp126_ACCESS-ESM1-5_r1-10_ncecat_ann_mean_2pt5degree.nc', 145 | # 'tas_Amon_historical_ssp126_UKESM1-0-LL_r1-10_ncecat_ann_mean_2pt5degree.nc', 146 | 'tas_Amon_historical_ssp126_CNRM-CM6-1_r1-5_ncecat_ann_mean_2pt5degree.nc', 147 | 'tas_Amon_historical_ssp126_CNRM-ESM2-1_r1-5_ncecat_ann_mean_2pt5degree.nc', 148 | 'tas_Amon_historical_ssp126_GISS-E2-1-G_r1-5_ncecat_ann_mean_2pt5degree.nc', 149 | # 
'tas_Amon_historical_ssp126_IPSL-CM6A-LR_r1-5_ncecat_ann_mean_2pt5degree.nc', 150 | ) 151 | elif ((settings["ssp"] == '126' and settings["gcmsub"] == 'noHIGH5')): 152 | filenames = ( 153 | 'tas_Amon_historical_ssp126_MIROC6_r1-10_ncecat_ann_mean_2pt5degree.nc', 154 | 'tas_Amon_historical_ssp126_MIROC-ES2L_r1-10_ncecat_ann_mean_2pt5degree.nc', 155 | 'tas_Amon_historical_ssp126_MRI-ESM2-0_r1-5_ncecat_ann_mean_2pt5degree.nc', 156 | # 'tas_Amon_historical_ssp126_CanESM5_r1-10_ncecat_ann_mean_2pt5degree.nc', 157 | # 'tas_Amon_historical_ssp126_ACCESS-ESM1-5_r1-10_ncecat_ann_mean_2pt5degree.nc', 158 | # 'tas_Amon_historical_ssp126_UKESM1-0-LL_r1-10_ncecat_ann_mean_2pt5degree.nc', 159 | # 'tas_Amon_historical_ssp126_CNRM-CM6-1_r1-5_ncecat_ann_mean_2pt5degree.nc', 160 | 'tas_Amon_historical_ssp126_CNRM-ESM2-1_r1-5_ncecat_ann_mean_2pt5degree.nc', 161 | 'tas_Amon_historical_ssp126_GISS-E2-1-G_r1-5_ncecat_ann_mean_2pt5degree.nc', 162 | # 'tas_Amon_historical_ssp126_IPSL-CM6A-LR_r1-5_ncecat_ann_mean_2pt5degree.nc', 163 | ) 164 | 165 | elif settings["ssp"] == '126' and settings["gcmsub"] == 'ALL': 166 | filenames = ('tas_Amon_historical_ssp126_CanESM5_r1-10_ncecat_ann_mean_2pt5degree.nc', 167 | 'tas_Amon_historical_ssp126_ACCESS-ESM1-5_r1-10_ncecat_ann_mean_2pt5degree.nc', 168 | 'tas_Amon_historical_ssp126_UKESM1-0-LL_r1-10_ncecat_ann_mean_2pt5degree.nc', 169 | 'tas_Amon_historical_ssp126_MIROC6_r1-10_ncecat_ann_mean_2pt5degree.nc', 170 | 'tas_Amon_historical_ssp126_MIROC-ES2L_r1-10_ncecat_ann_mean_2pt5degree.nc', 171 | ) 172 | elif settings["ssp"] == '126' and settings["gcmsub"] == 'noM6': 173 | filenames = ('tas_Amon_historical_ssp126_CanESM5_r1-10_ncecat_ann_mean_2pt5degree.nc', 174 | 'tas_Amon_historical_ssp126_ACCESS-ESM1-5_r1-10_ncecat_ann_mean_2pt5degree.nc', 175 | 'tas_Amon_historical_ssp126_UKESM1-0-LL_r1-10_ncecat_ann_mean_2pt5degree.nc', 176 | # 'tas_Amon_historical_ssp126_MIROC6_r1-10_ncecat_ann_mean_2pt5degree.nc', 177 | 'tas_Amon_historical_ssp126_MIROC-ES2L_r1-10_ncecat_ann_mean_2pt5degree.nc', 178 | ) 179 | elif settings["ssp"] == '126' and settings["gcmsub"] == 'FORCE': 180 | filenames = ('tas_Amon_historical_ssp126_CanESM5_r1-10_ncecat_ann_mean_2pt5degree.nc', 181 | 'tas_Amon_historical_ssp126_ACCESS-ESM1-5_r1-10_ncecat_ann_mean_2pt5degree.nc', 182 | 'tas_Amon_historical_ssp126_UKESM1-0-LL_r1-10_ncecat_ann_mean_2pt5degree.nc', 183 | 'tas_Amon_historical_ssp126_MIROC-ES2L_r1-10_ncecat_ann_mean_2pt5degree.nc', 184 | ) 185 | elif settings["ssp"] == '126' and settings["gcmsub"] == 'EXTEND': 186 | filenames = ('tas_Amon_historical_ssp126_CanESM5_r1-10_ncecat_ann_mean_2pt5degree.nc', 187 | 'tas_Amon_historical_ssp126_ACCESS-ESM1-5_r1-10_ncecat_ann_mean_2pt5degree.nc', 188 | 'tas_Amon_historical_ssp126_UKESM1-0-LL_r1-10_ncecat_ann_mean_2pt5degree.nc', 189 | 'tas_Amon_historical_ssp126_MIROC-ES2L_r1-10_ncecat_ann_mean_2pt5degree.nc', 190 | ) 191 | 192 | elif settings["ssp"] == '126' and settings["gcmsub"] == 'UNIFORM': 193 | filenames = ('tas_Amon_historical_ssp126_CanESM5_r1-10_ncecat_ann_mean_2pt5degree.nc', 194 | 'tas_Amon_historical_ssp126_ACCESS-ESM1-5_r1-10_ncecat_ann_mean_2pt5degree.nc', 195 | 'tas_Amon_historical_ssp126_UKESM1-0-LL_r1-10_ncecat_ann_mean_2pt5degree.nc', 196 | ) 197 | elif settings["gcmsub"] == 'OOS': 198 | filenames = ( 199 | 'tas_Amon_historical_ssp126_MIROC6_r1-10_ncecat_ann_mean_2pt5degree.nc', 200 | 'tas_Amon_historical_ssp126_MIROC-ES2L_r1-10_ncecat_ann_mean_2pt5degree.nc', 201 | 'tas_Amon_historical_ssp126_MRI-ESM2-0_r1-5_ncecat_ann_mean_2pt5degree.nc', 202 | 
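            # (added annotation) 'OOS' = the out-of-sample GCM subset, presumably the models held out of
            # training and analyzed in _analyze_outOfSample.ipynb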
'tas_Amon_historical_ssp126_CNRM-CM6-1_r1-5_ncecat_ann_mean_2pt5degree.nc', 203 | 'tas_Amon_historical_ssp126_CNRM-ESM2-1_r1-5_ncecat_ann_mean_2pt5degree.nc', 204 | 'tas_Amon_historical_ssp126_GISS-E2-1-G_r1-5_ncecat_ann_mean_2pt5degree.nc', 205 | 'tas_Amon_historical_ssp126_IPSL-CM6A-LR_r1-5_ncecat_ann_mean_2pt5degree.nc', 206 | ) 207 | 208 | # elif settings["ssp"] == '126' and settings["gcmsub"] == 'MIROC': 209 | # filenames = ( 210 | # 'tas_Amon_historical_ssp126_MIROC6_r1-10_ncecat_ann_mean_2pt5degree.nc', 211 | # 'tas_Amon_historical_ssp126_MIROC-ES2L_r1-10_ncecat_ann_mean_2pt5degree.nc', 212 | # ) 213 | # elif settings["ssp"] == '370' and settings["gcmsub"] == 'MIROC': 214 | # filenames = ( 215 | # # 'tas_Amon_historical_ssp370_MIROC6_r1-10_ncecat_ann_mean_2pt5degree.nc', 216 | # 'tas_Amon_historical_ssp370_MIROC-ES2L_r1-10_ncecat_ann_mean_2pt5degree.nc', 217 | # ) 218 | elif settings["ssp"] == '126' and settings["gcmsub"] == 'MAX': 219 | filenames = ('tas_Amon_historical_ssp126_CanESM5_r1-10_ncecat_ann_mean_2pt5degree.nc', 220 | 'tas_Amon_historical_ssp126_ACCESS-ESM1-5_r1-10_ncecat_ann_mean_2pt5degree.nc', 221 | 'tas_Amon_historical_ssp126_UKESM1-0-LL_r1-10_ncecat_ann_mean_2pt5degree.nc', 222 | 'tas_Amon_historical_ssp126_MIROC-ES2L_r1-10_ncecat_ann_mean_2pt5degree.nc', 223 | 'tas_Amon_historical_ssp126_MIROC6_r1-10_ncecat_ann_mean_2pt5degree.nc', 224 | ) 225 | 226 | 227 | else: 228 | raise NotImplementedError('no such SSP') 229 | 230 | if verbose != 0: 231 | print(filenames) 232 | 233 | return filenames 234 | -------------------------------------------------------------------------------- /initial_processing/_process_data_step2_archived.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "d1302a83-244c-4f2f-9ac2-cd465eb4e9fe", 6 | "metadata": {}, 7 | "source": [ 8 | "# Detecting temperature targets\n", 9 | "##### author: Elizabeth A. Barnes and Noah Diffenbaugh\n", 10 | "##### version: v0.1.0\n", 11 | "\n", 12 | "```\n", 13 | "conda create --name env-noah python=3.9\n", 14 | "conda activate env-noah\n", 15 | "pip install tensorflow==2.7.0\n", 16 | "pip install tensorflow-probability==0.15.0\n", 17 | "pip install --upgrade numpy scipy pandas statsmodels matplotlib seaborn palettable progressbar2 tabulate icecream flake8 keras-tuner sklearn jupyterlab black isort jupyterlab_code_formatter\n", 18 | "pip install -U scikit-learn\n", 19 | "pip install silence-tensorflow tqdm\n", 20 | "conda install -c conda-forge cmocean cartopy\n", 21 | "conda install -c conda-forge xarray dask netCDF4 bottleneck\n", 22 | "conda install -c conda-forge nc-time-axis\n", 23 | "```\n", 24 | "\n", 25 | "Use the command\n", 26 | "```python -m pip freeze > requirements.txt```\n", 27 | "to make a pip installation list." 
28 | ] 29 | }, 30 | { 31 | "cell_type": "markdown", 32 | "id": "2ae65312-198a-4b69-aa5e-fd101bb54b9f", 33 | "metadata": {}, 34 | "source": [ 35 | "Data can be found here:\n", 36 | "* https://www.earthsystemgrid.org/dataset/ucar.cgd.cesm2le.atm.proc.monthly_ave.TREFHT/file.html\n", 37 | "* https://www.cesm.ucar.edu/projects/community-projects/LENS2/" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": null, 43 | "id": "e6711c2a-2c1c-4e66-b8ee-ea8d5e9da0c8", 44 | "metadata": {}, 45 | "outputs": [], 46 | "source": [ 47 | "import xarray as xr\n", 48 | "import numpy as np\n", 49 | "# import pandas as pd\n", 50 | "import matplotlib.pyplot as plt\n", 51 | "# import cartopy.crs as ccrs\n", 52 | "\n", 53 | "import os.path\n", 54 | "from os import path\n", 55 | "import subprocess\n" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": null, 61 | "id": "2ad192dd-db23-4589-8515-cf785e0221aa", 62 | "metadata": {}, 63 | "outputs": [], 64 | "source": [ 65 | "dirName = '/Users/eabarnes/Desktop/big_data/orig_grid/'\n", 66 | "dirReMapName = '/Users/eabarnes/Desktop/big_data/remap_grid/'\n", 67 | "dirMergeName = '/Users/eabarnes/Desktop/big_data/merge_all/'\n", 68 | "dirAnnualName = '/Users/eabarnes/Desktop/big_data/annual_mean/'" 69 | ] 70 | }, 71 | { 72 | "cell_type": "markdown", 73 | "id": "b4e29afb-7ac9-4fb5-9d97-900f88143c56", 74 | "metadata": {}, 75 | "source": [ 76 | "## Remap to coarser grid" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": null, 82 | "id": "471e8b63-94e4-4074-8c4b-0def2778ade7", 83 | "metadata": {}, 84 | "outputs": [], 85 | "source": [ 86 | "var = 'TREFHT'\n", 87 | "\n", 88 | "for exp in ('BHISTsmbb','BSSP370smbb'):\n", 89 | " print('-----' + exp + '-----')\n", 90 | " for control_decade in (1231, 1251, 1281, 1301): \n", 91 | " for member in np.arange(10,21):\n", 92 | " print('-----' + str(member) + '-----')\n", 93 | "\n", 94 | " #---------------------------------------------------------\n", 95 | " if(exp=='BHISTsmbb'):\n", 96 | " timetext = '185001-201412'\n", 97 | " elif(exp=='BSSP370smbb'):\n", 98 | " timetext = '201501-210012'\n", 99 | " else:\n", 100 | " raise ValueError('no such timetext')\n", 101 | " filename_merge = 'b.e21.' + str(exp) + '.f09_g17.LE2-' + str(control_decade) + '.0' + str(member) + '.cam.h0.' + str(var) + '.' 
+ timetext + '.nc'\n", 102 | "            filename_remap = filename_merge[:-3] + '.r180x90.nc'\n", 103 | "            #---------------------------------------------------------\n", 104 | "\n", 105 | "            if(os.path.isfile(dirReMapName+filename_remap)==True):\n", 106 | "                print('remap file exists, continue')\n", 107 | "            else:\n", 108 | "                # grab variable only\n", 109 | "                runText = \"cdo -select,name=\" + var + \" \" + dirName+filename_merge + \" \" + dirReMapName + \"outfile.nc\"\n", 110 | "                print(runText)\n", 111 | "                process = subprocess.Popen(runText.split(), stdout=subprocess.PIPE)\n", 112 | "                output, error = process.communicate() \n", 113 | "\n", 114 | "                # remap to 2x2\n", 115 | "                runText = 'cdo remapcon,r180x90 ' + dirReMapName + 'outfile.nc ' + dirReMapName+filename_remap \n", 116 | "                print(runText)\n", 117 | "                process = subprocess.Popen(runText.split(), stdout=subprocess.PIPE)\n", 118 | "                output, error = process.communicate() \n", 119 | "\n", 120 | "                # remove outfile.nc\n", 121 | "                runText = 'rm ' + dirReMapName + 'outfile.nc'\n", 122 | "                print(runText)\n", 123 | "                process = subprocess.Popen(runText.split(), stdout=subprocess.PIPE)\n", 125 | "                output, error = process.communicate() \n", 126 | "                " 127 | ] 128 | }, 129 | { 130 | "cell_type": "markdown", 131 | "id": "33bf70fe-771d-42f3-b369-e30eb7b0c4e1", 132 | "metadata": {}, 133 | "source": [ 134 | "## Merge historical and future" 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": null, 140 | "id": "3b622bc8-98ee-4d40-93bc-76800666f1db", 141 | "metadata": {}, 142 | "outputs": [], 143 | "source": [ 144 | "var = 'TREFHT'\n", 145 | "\n", 146 | "for control_decade in (1231, 1251, 1281, 1301): \n", 147 | "    for member in np.arange(10,21):\n", 148 | "        print('-----' + str(member) + '-----')\n", 149 | "\n", 150 | "        filename_hist = 'b.e21.BHISTsmbb.f09_g17.LE2-' + str(control_decade) + '.0' + str(member) + '.cam.h0.' + str(var) + '.' + '185001-201412' + '.r180x90.nc'\n", 151 | "        filename_ssp = 'b.e21.BSSP370smbb.f09_g17.LE2-' + str(control_decade) + '.0' + str(member) + '.cam.h0.' + str(var) + '.' + '201501-210012' + '.r180x90.nc' \n", 152 | "        filename_all = 'b.e21.BHISTsmbb-BSSP370smbb.f09_g17.LE2-' + str(control_decade) + '.0' + str(member) + '.cam.h0.' + str(var) + '.' + '185001-210012' + '.r180x90.nc'\n", 153 | "        filename_short = 'b.e21.BHISTsmbb-BSSP370smbb.f09_g17.LE2-' + str(control_decade) + '.0' + str(member) + '.cam.h0.' + str(var) + '.' 
+ '195001-210012' + '.r180x90.nc'\n", 154 | " #---------------------------------------------------------\n", 155 | "\n", 156 | " if(os.path.isfile(dirMergeName+filename_all)==True):\n", 157 | " print('remap file exists, continue')\n", 158 | " else:\n", 159 | " # mergetime\n", 160 | " # runText = 'cdo mergetime ' + dirReMapName+filename_hist + ' ' + dirReMapName+filename_ssp + ' ' + dirMergeName+'timewrong.nc'\n", 161 | " runText = 'cdo mergetime ' + dirReMapName+filename_hist + ' ' + dirReMapName+filename_ssp + ' ' + dirMergeName+filename_all\n", 162 | " print(runText)\n", 163 | " process = subprocess.Popen(runText.split(), stdout=subprocess.PIPE)\n", 164 | " output, error = process.communicate()\n", 165 | " \n", 166 | " if(os.path.isfile(dirMergeName+filename_short)==True):\n", 167 | " print('remap file exists, continue')\n", 168 | " else:\n", 169 | " # mergetime\n", 170 | " runText = 'cdo selyear,1950/2101 ' + dirMergeName+filename_all + ' ' + dirMergeName+filename_short\n", 171 | " print(runText)\n", 172 | " process = subprocess.Popen(runText.split(), stdout=subprocess.PIPE)\n", 173 | " output, error = process.communicate()\n", 174 | " \n", 175 | "# if(os.path.isfile(dirMergeName+filename_short)==True):\n", 176 | "# print('remap file exists, continue')\n", 177 | "# else:\n", 178 | "# # mergetime\n", 179 | "# runText = 'cdo selyear,1950/2101 ' + dirMergeName+filename_all + ' ' + dirMergeName+filename_short\n", 180 | "# print(runText)\n", 181 | "# process = subprocess.Popen(runText.split(), stdout=subprocess.PIPE)\n", 182 | "# output, error = process.communicate()\n", 183 | " \n", 184 | " \n", 185 | " # shifttime\n", 186 | " # runText = 'cdo settaxis,1850-01-01,00:00:00,1month ' + dirMergeName+'timewrong.nc' + ' ' + dirMergeName+filename_all\n", 187 | " # print(runText)\n", 188 | " # process = subprocess.Popen(runText.split(), stdout=subprocess.PIPE)\n", 189 | " # output, error = process.communicate()\n", 190 | " \n", 191 | " # cdo settaxis,1920-01-15,00:00:00,1month tmp.nc filout.nc\n", 192 | " # remove timewrong.nc\n", 193 | " # runText = 'rm ' + dirMergeName + 'timewrong.nc'\n", 194 | " # print(runText)\n", 195 | " # process = subprocess.Popen(runText.split(), stdout=subprocess.PIPE)\n", 196 | " # output, error = process.communicate() \n", 197 | " \n", 198 | " " 199 | ] 200 | }, 201 | { 202 | "cell_type": "markdown", 203 | "id": "98204e70-7ee5-4a0e-979d-4b922baa8ad6", 204 | "metadata": {}, 205 | "source": [ 206 | "## Take mean over all members" 207 | ] 208 | }, 209 | { 210 | "cell_type": "code", 211 | "execution_count": null, 212 | "id": "4c4bc306-cd99-432b-803c-f73427095516", 213 | "metadata": {}, 214 | "outputs": [], 215 | "source": [ 216 | "runText = 'cdo ensmean ' + dirMergeName+'*185001-210012.r180x90.nc ' + dirMergeName + '/b.e21.BHISTsmbb-BSSP370smbb.f09_g17.LE2.cam.h0.TREFHT.185001-210012.r180x90.nc'\n", 217 | "print(runText)\n", 218 | "process = subprocess.Popen(runText.split(), stdout=subprocess.PIPE)\n", 219 | "output, error = process.communicate()\n" 220 | ] 221 | }, 222 | { 223 | "cell_type": "code", 224 | "execution_count": null, 225 | "id": "1962100b-aaae-4f61-96dd-cae1f436145a", 226 | "metadata": {}, 227 | "outputs": [], 228 | "source": [ 229 | "runText = 'cdo ensmean ' + dirMergeName+'*195001-210012.r180x90.nc ' + dirMergeName + '/b.e21.BHISTsmbb-BSSP370smbb.f09_g17.LE2.cam.h0.TREFHT.195001-210012.r180x90.nc'\n", 230 | "print(runText)\n", 231 | "process = subprocess.Popen(runText.split(), stdout=subprocess.PIPE)\n", 232 | "output, error = process.communicate()\n" 233 | 
234 |   },
235 |   {
236 |    "cell_type": "markdown",
237 |    "id": "a51ce3e7-bd06-464e-8198-2eea6ad17ce6",
238 |    "metadata": {},
239 |    "source": [
240 |     "## Rename files and take annual mean"
241 |    ]
242 |   },
243 |   {
244 |    "cell_type": "code",
245 |    "execution_count": 46,
246 |    "id": "646a99bf-a235-4b2b-b696-6e1f39945917",
247 |    "metadata": {},
248 |    "outputs": [],
249 |    "source": [
250 |     "# https://ncar.github.io/esds/posts/2021/yearly-averages-xarray/\n",
251 |     "\n",
252 |     "def weighted_temporal_mean(ds, var):\n",
253 |     "    \"\"\"\n",
254 |     "    Annual mean weighted by the number of days in each month.\n",
255 |     "    \"\"\"\n",
256 |     "    # Determine the month length\n",
257 |     "    month_length = ds.time.dt.days_in_month\n",
258 |     "\n",
259 |     "    # Calculate the weights\n",
260 |     "    wgts = month_length.groupby(\"time.year\") / month_length.groupby(\"time.year\").sum()\n",
261 |     "\n",
262 |     "    # Make sure the weights in each year add up to 1\n",
263 |     "    np.testing.assert_allclose(wgts.groupby(\"time.year\").sum(xr.ALL_DIMS), 1.0)\n",
264 |     "\n",
265 |     "    # Subset our dataset for our variable\n",
266 |     "    obs = ds[var]\n",
267 |     "\n",
268 |     "    # Setup our masking for nan values\n",
269 |     "    cond = obs.isnull()\n",
270 |     "    ones = xr.where(cond, 0.0, 1.0)\n",
271 |     "\n",
272 |     "    # Calculate the numerator\n",
273 |     "    obs_sum = (obs * wgts).resample(time=\"AS\").sum(dim=\"time\")\n",
274 |     "\n",
275 |     "    # Calculate the denominator\n",
276 |     "    ones_out = (ones * wgts).resample(time=\"AS\").sum(dim=\"time\")\n",
277 |     "\n",
278 |     "    # Return the weighted average\n",
279 |     "    return obs_sum / ones_out"
280 |    ]
281 |   },
282 |   {
283 |    "cell_type": "code",
284 |    "execution_count": 52,
285 |    "id": "d07a70fc-ab48-494e-be46-2dc79cc0d9f6",
286 |    "metadata": {},
287 |    "outputs": [
288 |     {
289 |      "name": "stdout",
290 |      "output_type": "stream",
291 |      "text": [
292 |       "cdo merge /Users/eabarnes/Desktop/big_data/annual_mean/*.nc /Users/eabarnes/Desktop/big_data/annual_mean/../b.e21.BHISTsmbb-BSSP370smbb.f09_g17.LE2.cam.h0.TREFHT.185001-210012.r180x90.annual.nc\n"
293 |      ]
294 |     },
295 |     {
296 |      "name": "stderr",
297 |      "output_type": "stream",
298 |      "text": [
299 |       "cdo merge (Warning): Duplicate entry of parameter name __xarray_dataarray_variable__ in /Users/eabarnes/Desktop/big_data/annual_mean/b.e21.BHISTsmbb-BSSP370smbb.f09_g17.LE2-1231.012.cam.h0.TREFHT.185001-210012.r180x90.annual.nc!\n"
300 |      ]
301 |     }
302 |    ],
303 |    "source": [
304 |     "import datetime\n",
305 |     "var = 'TREFHT'\n",
306 |     "X = np.zeros((30,251,90,180))\n",
307 |     "count = 0\n",
308 |     "\n",
309 |     "for control_decade in (1231, 1251, 1281, 1301): \n",
310 |     "    for member in np.arange(10,21):\n",
311 |     "\n",
312 |     "        filename_all = 'b.e21.BHISTsmbb-BSSP370smbb.f09_g17.LE2-' + str(control_decade) + '.0' + str(member) + '.cam.h0.' + str(var) + '.' + '185001-210012' + '.r180x90.nc'\n",
313 |     "        filename_ann = filename_all[:-2] + 'annual.nc'\n",
314 |     "        \n",
315 |     "        if(os.path.isfile(dirMergeName + filename_all)==False):\n",
316 |     "            continue\n",
317 |     "        da = xr.open_dataset(dirMergeName + filename_all)\n",
318 |     "        #---------------------------------------------------------\n",
319 |     "        # fix the time stamp\n",
320 |     "        dates = da[\"time\"]\n",
321 |     "        delta_time = datetime.timedelta(1)\n",
322 |     "        new_dates = dates - delta_time\n",
323 |     "        da[\"time\"] = new_dates\n",
324 |     "\n",
325 |     "        #---------------------------------------------------------\n",
326 |     "        # take the annual mean\n",
327 |     "        da_annual = weighted_temporal_mean(da,'TREFHT')\n",
328 |     "        da_annual.to_netcdf(dirAnnualName + filename_ann)\n",
329 |     "\n",
330 |     "        #---------------------------------------------------------\n",
331 |     "        # concatenate all together\n",
332 |     "        X[count,:,:,:] = da_annual.values\n",
333 |     "        count = count + 1\n",
334 |     "# runText = 'cdo merge ' + dirAnnualName+'*.nc' + ' ' + dirAnnualName+'../b.e21.BHISTsmbb-BSSP370smbb.f09_g17.LE2.cam.h0.TREFHT.185001-210012.r180x90.annual.nc'\n",
335 |     "# print(runText)\n",
336 |     "# process = subprocess.Popen(runText.split(), stdout=subprocess.PIPE)\n",
337 |     "# output, error = process.communicate()\n",
338 |     "    \n",
339 |     "    "
340 |    ]
341 |   },
342 |   {
343 |    "cell_type": "code",
344 |    "execution_count": 58,
345 |    "id": "4610aa91-8ef0-4b1c-973f-d56ecaaf2cfc",
346 |    "metadata": {},
347 |    "outputs": [
348 |     {
349 |      "data": {
350 |       "text/plain": [
351 |        "<bound method Mapping.keys of <xarray.Dataset>\n",
352 |        "Dimensions:                        (time: 251, lon: 180, lat: 90)\n",
353 |        "Coordinates:\n",
354 |        "  * time                           (time) object 1850-01-01 00:00:00 ... 2100...\n",
355 |        "  * lon                            (lon) float64 0.0 2.0 4.0 ... 356.0 358.0\n",
356 |        "  * lat                            (lat) float64 -89.0 -87.0 -85.0 ... 87.0 89.0\n",
357 |        "Data variables:\n",
358 |        "    __xarray_dataarray_variable__  (time, lat, lon) float64 ...>"
359 |       ]
360 |      },
361 |      "execution_count": 58,
362 |      "metadata": {},
363 |      "output_type": "execute_result"
364 |     }
365 |    ],
366 |    "source": [
367 |     "da_annual.keys"
368 |    ]
369 |   },
370 |   {
371 |    "cell_type": "code",
372 |    "execution_count": 86,
373 |    "id": "0b891a4e-5ac5-4d99-88fa-feec7951f508",
374 |    "metadata": {},
375 |    "outputs": [
376 |     {
377 |      "name": "stdout",
378 |      "output_type": "stream",
379 |      "text": [
380 |       "b.e21.BHISTsmbb-BSSP370smbb.f09_g17.LE2-1231.011.cam.h0.TREFHT.185001-210012.r180x90.annual.nc\n",
381 |       "b.e21.BHISTsmbb-BSSP370smbb.f09_g17.LE2-1231.012.cam.h0.TREFHT.185001-210012.r180x90.annual.nc\n",
382 |       "b.e21.BHISTsmbb-BSSP370smbb.f09_g17.LE2-1231.013.cam.h0.TREFHT.185001-210012.r180x90.annual.nc\n",
383 |       "b.e21.BHISTsmbb-BSSP370smbb.f09_g17.LE2-1231.014.cam.h0.TREFHT.185001-210012.r180x90.annual.nc\n",
384 |       "b.e21.BHISTsmbb-BSSP370smbb.f09_g17.LE2-1231.015.cam.h0.TREFHT.185001-210012.r180x90.annual.nc\n",
385 |       "b.e21.BHISTsmbb-BSSP370smbb.f09_g17.LE2-1231.016.cam.h0.TREFHT.185001-210012.r180x90.annual.nc\n",
386 |       "b.e21.BHISTsmbb-BSSP370smbb.f09_g17.LE2-1231.017.cam.h0.TREFHT.185001-210012.r180x90.annual.nc\n",
387 |       "b.e21.BHISTsmbb-BSSP370smbb.f09_g17.LE2-1231.018.cam.h0.TREFHT.185001-210012.r180x90.annual.nc\n",
388 |       "b.e21.BHISTsmbb-BSSP370smbb.f09_g17.LE2-1231.019.cam.h0.TREFHT.185001-210012.r180x90.annual.nc\n",
389 |       "b.e21.BHISTsmbb-BSSP370smbb.f09_g17.LE2-1231.020.cam.h0.TREFHT.185001-210012.r180x90.annual.nc\n",
390 |       "b.e21.BHISTsmbb-BSSP370smbb.f09_g17.LE2-1281.011.cam.h0.TREFHT.185001-210012.r180x90.annual.nc\n",
391 |       "b.e21.BHISTsmbb-BSSP370smbb.f09_g17.LE2-1281.012.cam.h0.TREFHT.185001-210012.r180x90.annual.nc\n",
392 |       "b.e21.BHISTsmbb-BSSP370smbb.f09_g17.LE2-1281.013.cam.h0.TREFHT.185001-210012.r180x90.annual.nc\n",
393 |       "b.e21.BHISTsmbb-BSSP370smbb.f09_g17.LE2-1281.014.cam.h0.TREFHT.185001-210012.r180x90.annual.nc\n",
394 |       "b.e21.BHISTsmbb-BSSP370smbb.f09_g17.LE2-1281.015.cam.h0.TREFHT.185001-210012.r180x90.annual.nc\n",
395 |       "b.e21.BHISTsmbb-BSSP370smbb.f09_g17.LE2-1281.016.cam.h0.TREFHT.185001-210012.r180x90.annual.nc\n",
396 |       "b.e21.BHISTsmbb-BSSP370smbb.f09_g17.LE2-1281.017.cam.h0.TREFHT.185001-210012.r180x90.annual.nc\n",
397 |       "b.e21.BHISTsmbb-BSSP370smbb.f09_g17.LE2-1281.018.cam.h0.TREFHT.185001-210012.r180x90.annual.nc\n",
398 |       "b.e21.BHISTsmbb-BSSP370smbb.f09_g17.LE2-1281.019.cam.h0.TREFHT.185001-210012.r180x90.annual.nc\n",
399 |       "b.e21.BHISTsmbb-BSSP370smbb.f09_g17.LE2-1281.020.cam.h0.TREFHT.185001-210012.r180x90.annual.nc\n",
400 |       "b.e21.BHISTsmbb-BSSP370smbb.f09_g17.LE2-1301.011.cam.h0.TREFHT.185001-210012.r180x90.annual.nc\n",
401 |       "b.e21.BHISTsmbb-BSSP370smbb.f09_g17.LE2-1301.012.cam.h0.TREFHT.185001-210012.r180x90.annual.nc\n",
402 |       "b.e21.BHISTsmbb-BSSP370smbb.f09_g17.LE2-1301.013.cam.h0.TREFHT.185001-210012.r180x90.annual.nc\n",
403 |       "b.e21.BHISTsmbb-BSSP370smbb.f09_g17.LE2-1301.014.cam.h0.TREFHT.185001-210012.r180x90.annual.nc\n",
404 |       "b.e21.BHISTsmbb-BSSP370smbb.f09_g17.LE2-1301.015.cam.h0.TREFHT.185001-210012.r180x90.annual.nc\n",
405 |       "b.e21.BHISTsmbb-BSSP370smbb.f09_g17.LE2-1301.016.cam.h0.TREFHT.185001-210012.r180x90.annual.nc\n",
406 |       "b.e21.BHISTsmbb-BSSP370smbb.f09_g17.LE2-1301.017.cam.h0.TREFHT.185001-210012.r180x90.annual.nc\n",
407 |       "b.e21.BHISTsmbb-BSSP370smbb.f09_g17.LE2-1301.018.cam.h0.TREFHT.185001-210012.r180x90.annual.nc\n",
"b.e21.BHISTsmbb-BSSP370smbb.f09_g17.LE2-1301.019.cam.h0.TREFHT.185001-210012.r180x90.annual.nc\n", 409 | "b.e21.BHISTsmbb-BSSP370smbb.f09_g17.LE2-1301.020.cam.h0.TREFHT.185001-210012.r180x90.annual.nc\n" 410 | ] 411 | } 412 | ], 413 | "source": [ 414 | "import datetime\n", 415 | "var = 'TREFHT'\n", 416 | "X = np.zeros((30,251,90,180))\n", 417 | "count = 0\n", 418 | "\n", 419 | "for control_decade in (1231, 1251, 1281, 1301): \n", 420 | " for member in np.arange(10,21):\n", 421 | " filename_all = 'b.e21.BHISTsmbb-BSSP370smbb.f09_g17.LE2-' + str(control_decade) + '.0' + str(member) + '.cam.h0.' + str(var) + '.' + '185001-210012' + '.r180x90.nc'\n", 422 | " filename_ann = filename_all[:-2] + 'annual.nc'\n", 423 | " \n", 424 | " if(os.path.isfile(dirAnnualName + filename_ann)==False):\n", 425 | " continue\n", 426 | " da_annual = xr.open_dataset(dirAnnualName + filename_ann)\n", 427 | " print(filename_ann)\n", 428 | " #---------------------------------------------------------\n", 429 | " # concatenate all together\n", 430 | " X[count,:,:,:] = da_annual['__xarray_dataarray_variable__'].values\n", 431 | " count = count + 1\n", 432 | " \n", 433 | "da = xr.DataArray(\n", 434 | " data=X,\n", 435 | " dims=[\"member\",\"time\",\"lat\", \"lon\"],\n", 436 | " coords=dict(\n", 437 | " time=da_annual['time'].data,\n", 438 | " member=np.arange(0,30),\n", 439 | " ),\n", 440 | ")\n", 441 | "da.assign_coords(lon=da_annual['lon'])\n", 442 | "da.assign_coords(lat=da_annual['lat'])\n", 443 | "da.rename_vars({'__xarray_dataarray_variable__':'TREFHT'})\n", 444 | "\n", 445 | "da.to_netcdf(dirAnnualName + 'b.e21.BHISTsmbb-BSSP370smbb.f09_g17.LE2.cam.h0.TREFHT.185001-210012.r180x90.annual.nc')" 446 | ] 447 | }, 448 | { 449 | "cell_type": "code", 450 | "execution_count": null, 451 | "id": "51535160-b581-4ab6-b22f-8e2d94de5f7f", 452 | "metadata": {}, 453 | "outputs": [], 454 | "source": [] 455 | } 456 | ], 457 | "metadata": { 458 | "kernelspec": { 459 | "display_name": "Python 3 (ipykernel)", 460 | "language": "python", 461 | "name": "python3" 462 | }, 463 | "language_info": { 464 | "codemirror_mode": { 465 | "name": "ipython", 466 | "version": 3 467 | }, 468 | "file_extension": ".py", 469 | "mimetype": "text/x-python", 470 | "name": "python", 471 | "nbconvert_exporter": "python", 472 | "pygments_lexer": "ipython3", 473 | "version": "3.9.7" 474 | } 475 | }, 476 | "nbformat": 4, 477 | "nbformat_minor": 5 478 | } 479 | -------------------------------------------------------------------------------- /network.py: -------------------------------------------------------------------------------- 1 | """Network functions. 
2 | 
3 | Classes
4 | ---------
5 | Exponentiate(keras.layers.Layer)
6 | 
7 | 
8 | Functions
9 | ---------
10 | RegressLossExpSigma(y_true, y_pred)
11 | compile_model(x_train, y_train, settings)
12 | 
13 | 
14 | """
15 | import tensorflow as tf
16 | import tensorflow_probability as tfp
17 | from tensorflow.keras import Model
18 | from tensorflow.keras.layers import Dense, Input, Dropout, Softmax
19 | from tensorflow.keras import optimizers
20 | from tensorflow.keras import regularizers
21 | from tensorflow import keras
22 | import numpy as np
23 | 
24 | import custom_metrics
25 | 
26 | 
27 | class Exponentiate(keras.layers.Layer):
28 |     """Custom layer to exponentiate the sigma and tau estimates inline."""
29 | 
30 |     def __init__(self, **kwargs):
31 |         super(Exponentiate, self).__init__(**kwargs)
32 | 
33 |     def call(self, inputs):
34 |         return tf.math.exp(inputs)
35 | 
36 | def RegressLossExpSigma(y_true, y_pred):
37 |     # network predictions
38 |     mu = y_pred[:,0]
39 |     sigma = y_pred[:,1]
40 | 
41 |     # normal distribution defined by N(mu,sigma)
42 |     norm_dist = tfp.distributions.Normal(mu,sigma)
43 | 
44 |     # compute the loss as -log(p) of y_true under N(mu, sigma)
45 |     loss = -norm_dist.log_prob(y_true[:,0])
46 | 
47 |     return tf.reduce_mean(loss, axis=-1)
48 | 
49 | def compile_model(x_train, y_train, settings):
50 | 
51 |     # First we start with an input layer
52 |     inputs = Input(shape=x_train.shape[1:])
53 | 
54 |     normalizer = tf.keras.layers.Normalization()
55 |     normalizer.adapt(x_train)
56 |     layers = normalizer(inputs)
57 | 
58 |     layers = Dropout(rate=settings["dropout_rate"],
59 |                      seed=settings["seed"])(layers)
60 | 
61 |     for hidden, activation, ridge in zip(settings["hiddens"], settings["act_fun"], settings["ridge_param"]):
62 |         layers = Dense(hidden, activation=activation,
63 |                        kernel_regularizer=tf.keras.regularizers.l1_l2(l1=0.00, l2=ridge),
64 |                        bias_initializer=tf.keras.initializers.RandomNormal(seed=settings["seed"]),
65 |                        kernel_initializer=tf.keras.initializers.RandomNormal(seed=settings["seed"]))(layers)
66 | 
67 | 
68 |     if settings['network_type'] == 'reg':
69 |         LOSS = 'mae'
70 |         metrics = ['mse',]
71 | 
72 |         output_layer = Dense(1, activation='linear',
73 |                              bias_initializer=tf.keras.initializers.RandomNormal(seed=settings["seed"]),
74 |                              kernel_initializer=tf.keras.initializers.RandomNormal(seed=settings["seed"]))(layers)
75 | 
76 |     elif settings['network_type'] == 'shash2':
77 |         LOSS = RegressLossExpSigma
78 |         metrics = [
79 |             custom_metrics.CustomMAE(name="custom_mae"),
80 |             custom_metrics.InterquartileCapture(name="interquartile_capture"),
81 |             custom_metrics.SignTest(name="sign_test"),
82 |         ]
83 | 
84 |         y_avg = np.mean(y_train)
85 |         y_std = np.std(y_train)
86 | 
87 |         mu_z_unit = tf.keras.layers.Dense(
88 |             units=1,
89 |             activation="linear",
90 |             use_bias=True,
91 |             bias_initializer=tf.keras.initializers.RandomNormal(seed=settings["seed"]+100),
92 |             kernel_initializer=tf.keras.initializers.RandomNormal(seed=settings["seed"]+100),
93 |             name="mu_z_unit",
94 |         )(layers)
95 | 
96 |         mu_unit = tf.keras.layers.Rescaling(
97 |             scale=y_std,
98 |             offset=y_avg,
99 |             name="mu_unit",
100 |         )(mu_z_unit)
101 | 
102 |         # sigma_unit. The network predicts the log of the scaled sigma_z, then
103 |         # the rescaling layer scales it up to log of sigma_y, and the custom
104 |         # Exponentiate layer converts it to sigma_y.
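        # In symbols, combining the three layers below:
        #     sigma_y = exp(log(sigma_z) + log(y_std)) = y_std * sigma_z
        # Since the kernel and bias of log_sigma_z_unit are zero-initialized,
        # log_sigma_z = 0 at the start of training, so the network initially
        # predicts sigma_y = y_std.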
105 |         log_sigma_z_unit = tf.keras.layers.Dense(
106 |             units=1,
107 |             activation="linear",
108 |             use_bias=True,
109 |             bias_initializer=tf.keras.initializers.Zeros(),
110 |             kernel_initializer=tf.keras.initializers.Zeros(),
111 |             name="log_sigma_z_unit",
112 |         )(layers)
113 | 
114 |         log_sigma_unit = tf.keras.layers.Rescaling(
115 |             scale=1.0,
116 |             offset=np.log(y_std),
117 |             name="log_sigma_unit",
118 |         )(log_sigma_z_unit)
119 | 
120 |         sigma_unit = Exponentiate(
121 |             name="sigma_unit",
122 |         )(log_sigma_unit)
123 | 
124 |         output_layer = tf.keras.layers.concatenate([mu_unit, sigma_unit], axis=1)
125 | 
126 |     elif settings['network_type'] == 'shash3':
127 |         raise NotImplementedError("network_type 'shash3' is not implemented yet")
128 | 
129 |     else:
130 |         raise NotImplementedError('no such network_type')
131 | 
132 |     # Constructing the model
133 |     model = Model(inputs, output_layer)
134 |     model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=settings["learning_rate"]),
135 |                   loss=LOSS,
136 |                   metrics=metrics,
137 |                   )
138 | 
139 | 
140 |     model.summary()
141 | 
142 |     return model
143 | 
--------------------------------------------------------------------------------
/plots.py:
--------------------------------------------------------------------------------
1 | """Metrics for generic plotting.
2 | 
3 | Functions
4 | ---------
5 | plot_metrics(history,metric)
6 | plot_metrics_panels(history, settings)
7 | plot_map(x, clim=None, title=None, text=None, cmap='RdGy')
8 | """
9 | 
10 | import numpy as np
11 | import matplotlib.pyplot as plt
12 | import matplotlib as mpl
13 | import cartopy as ct
14 | import numpy.ma as ma
15 | import cartopy.crs as ccrs
16 | import cartopy.feature as cfeature
17 | import custom_metrics
18 | 
19 | mpl.rcParams["figure.facecolor"] = "white"
20 | mpl.rcParams["figure.dpi"] = 150
21 | 
22 | def savefig(filename,dpi=300):
23 |     for fig_format in (".png",".pdf"):
24 |         plt.savefig(filename + fig_format,
25 |                     bbox_inches="tight",
26 |                     dpi=dpi)
27 | 
28 | 
29 | def plot_metrics(history,metric):
30 | 
31 |     imin = np.argmin(history.history['val_loss'])
32 | 
33 |     plt.plot(history.history[metric], label='training')
34 |     plt.plot(history.history['val_' + metric], label='validation')
35 |     plt.title(metric)
36 |     plt.axvline(x=imin, linewidth=.5, color='gray',alpha=.5)
37 |     plt.legend()
38 | 
39 | def plot_metrics_panels(history, settings):
40 | 
41 |     if(settings["network_type"]=="reg"):
42 |         error_name = "mae"
43 |     elif settings['network_type'] == 'shash2':
44 |         error_name = "custom_mae"
45 |     else:
46 |         raise NotImplementedError('no such network_type')
47 | 
48 |     imin = len(history.history[error_name])  # number of epochs run (currently unused)
49 |     plt.subplots(figsize=(20,4))
50 | 
51 |     plt.subplot(1,4,1)
52 |     plot_metrics(history,'loss')
53 |     plt.ylim(0,10.)
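    # the y-axis is clipped so that large early-epoch losses do not set the scale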
54 | 
55 |     plt.subplot(1,4,2)
56 |     plot_metrics(history,error_name)
57 |     plt.ylim(0,10)
58 | 
59 |     try:
60 |         plt.subplot(1,4,3)
61 |         plot_metrics(history,'interquartile_capture')
62 | 
63 |         plt.subplot(1,4,4)
64 |         plot_metrics(history,'sign_test')
65 |     except KeyError:  # these metrics are only recorded for 'shash2' models
66 |         pass
67 | 
68 | 
69 | def plot_map(x, clim=None, title=None, text=None, cmap='RdGy'):
70 |     plt.pcolor(x,
71 |                cmap=cmap,
72 |                )
73 |     plt.clim(clim)
74 |     plt.colorbar()
75 |     plt.title(title,fontsize=15,loc='right')
76 |     plt.yticks([])
77 |     plt.xticks([])
78 | 
79 |     plt.text(0.01, 1.0, text, fontfamily='monospace', fontsize='small', va='bottom',transform=plt.gca().transAxes)
80 | 
81 | def drawOnGlobe(ax, map_proj, data, lats, lons, cmap='coolwarm', vmin=None, vmax=None, inc=None, cbarBool=True, contourMap=[], contourVals = [], fastBool=False, extent='both'):
82 | 
83 |     data_crs = ct.crs.PlateCarree()
84 |     data_cyc, lons_cyc = add_cyclic_point(data, coord=lons)  # fixes the white line at the dateline by adding a cyclic point
85 |     data_cyc = data    # NOTE: these two lines override the cyclic-point fix above;
86 |     lons_cyc = lons    # delete them to restore it
87 | 
88 | 
89 |     # ax.set_global()
90 |     # ax.coastlines(linewidth = 1.2, color='black')
91 |     # ax.add_feature(cartopy.feature.LAND, zorder=0, scale = '50m', edgecolor='black', facecolor='black')
92 |     land_feature = cfeature.NaturalEarthFeature(
93 |         category='physical',
94 |         name='land',
95 |         scale='50m',
96 |         facecolor='None',
97 |         edgecolor = 'k',
98 |         linewidth=.5,
99 |     )
100 |     ax.add_feature(land_feature)
101 |     # ax.GeoAxes.patch.set_facecolor('black')
102 | 
103 |     if(fastBool):
104 |         image = ax.pcolormesh(lons_cyc, lats, data_cyc, transform=data_crs, cmap=cmap)
105 |         # image = ax.contourf(lons_cyc, lats, data_cyc, np.linspace(0,vmax,20),transform=data_crs, cmap=cmap)
106 |     else:
107 |         image = ax.pcolor(lons_cyc, lats, data_cyc, transform=data_crs, cmap=cmap,shading='auto')
108 | 
109 |     if(np.size(contourMap) !=0 ):
110 |         contourMap_cyc, __ = add_cyclic_point(contourMap, coord=lons)  # fixes white line by adding point
111 |         ax.contour(lons_cyc,lats,contourMap_cyc,contourVals, transform=data_crs, colors='fuchsia')
112 | 
113 |     if(cbarBool):
114 |         cb = plt.colorbar(image, shrink=.45, orientation="horizontal", pad=.02, extend=extent)
115 |         cb.ax.tick_params(labelsize=6)
116 |     else:
117 |         cb = None
118 | 
119 |     image.set_clim(vmin,vmax)
120 | 
121 |     return cb, image
122 | 
123 | def add_cyclic_point(data, coord=None, axis=-1):
124 | 
125 |     # had issues with cartopy finding utils so copied for myself
126 | 
127 |     if coord is not None:
128 |         if coord.ndim != 1:
129 |             raise ValueError('The coordinate must be 1-dimensional.')
130 |         if len(coord) != data.shape[axis]:
131 |             raise ValueError('The length of the coordinate does not match '
132 |                              'the size of the corresponding dimension of '
133 |                              'the data array: len(coord) = {}, '
134 |                              'data.shape[{}] = {}.'.format(
135 |                                  len(coord), axis, data.shape[axis]))
136 |         delta_coord = np.diff(coord)
137 |         if not np.allclose(delta_coord, delta_coord[0]):
138 |             raise ValueError('The coordinate must be equally spaced.')
139 |         new_coord = ma.concatenate((coord, coord[-1:] + delta_coord[0]))
140 |     slicer = [slice(None)] * data.ndim
141 |     try:
142 |         slicer[axis] = slice(0, 1)
143 |     except IndexError:
144 |         raise ValueError('The specified axis does not correspond to an '
145 |                          'array dimension.')
146 |     new_data = ma.concatenate((data, data[tuple(slicer)]), axis=axis)
147 |     if coord is None:
148 |         return_value = new_data
149 |     else:
150 |         return_value = new_data, new_coord
151 |     return return_value
152 | 
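# Minimal usage sketch for drawOnGlobe (illustrative only: the synthetic 90 x 180
# field below is a hypothetical stand-in for a real temperature map, and the
# projection choice is arbitrary). Guarded so it only runs when this module is
# executed directly.
if __name__ == "__main__":
    demo_lats = np.arange(-89.0, 90.0, 2.0)   # 90 latitudes
    demo_lons = np.arange(0.0, 360.0, 2.0)    # 180 longitudes
    demo_field = np.cos(np.deg2rad(demo_lats))[:, None] * np.ones((1, demo_lons.size))
    demo_proj = ccrs.Robinson(central_longitude=180.0)
    demo_ax = plt.figure(figsize=(6, 3)).add_subplot(1, 1, 1, projection=demo_proj)
    drawOnGlobe(demo_ax, demo_proj, demo_field, demo_lats, demo_lons,
                cmap="coolwarm", vmin=0.0, vmax=1.0, fastBool=True)
    plt.show()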
153 | def plot_pits(ax, x_val, onehot_val, model_shash):
154 |     plt.sca(ax)
155 |     clr_shash = 'tab:blue'
156 | 
157 |     # shash pit
158 |     bins, hist_shash, D_shash, EDp_shash = custom_metrics.compute_pit(onehot_val, x_data=x_val,model_shash=model_shash)
159 |     bins_inc = bins[1]-bins[0]
160 | 
161 |     bin_add = bins_inc/2
162 |     bin_width = bins_inc*.98
163 |     ax.bar(hist_shash[1][:-1] + bin_add,
164 |            hist_shash[0],
165 |            width=bin_width,
166 |            color=clr_shash,
167 |            label='SHASH',
168 |            )
169 | 
170 |     # make the figure pretty
171 |     ax.axhline(y=.1,
172 |                linestyle='--',
173 |                color='k',
174 |                linewidth=2.,
175 |                )
176 |     # ax = plt.gca()
177 |     yticks = np.around(np.arange(0,.55,.05),2)
178 |     plt.yticks(yticks,yticks)
179 |     ax.set_ylim(0,.25)
180 |     ax.set_xticks(bins,np.around(bins,1))
181 | 
182 |     plt.text(0.,np.max(ax.get_ylim())*.99,
183 |              'SHASH D: ' + str(np.round(D_shash,4)) + ' (' + str(np.round(EDp_shash,3)) + ')',
184 |              color=clr_shash,
185 |              verticalalignment='top',
186 |              fontsize=12)
187 | 
188 | 
189 |     ax.set_xlabel('probability integral transform')
190 |     ax.set_ylabel('probability')
191 |     # plt.legend(loc=1)
192 |     # plt.title('PIT histogram comparison', fontsize=FS, color='k')
--------------------------------------------------------------------------------
/xai.py:
--------------------------------------------------------------------------------
1 | """Functions for XAI analysis.
2 | 
3 | Functions
4 | ---------
5 | get_gradients(model, inputs, top_pred_idx=None)
6 | get_integrated_gradients(model, inputs, baseline=None, num_steps=50, top_pred_idx=None)
7 | random_baseline_integrated_gradients(model, inputs, num_steps=50, num_runs=5, top_pred_idx=None)
8 | 
9 | """
10 | 
11 | import tensorflow as tf
12 | import numpy as np
13 | 
14 | def get_gradients(model, inputs, top_pred_idx=None):
15 |     """Computes the gradients of the model predictions w.r.t. the inputs.
16 | 
17 |     Args:
18 |         inputs: 2D/3D/4D matrix of samples
19 |         top_pred_idx: (optional) Predicted label for the x_data
20 |             if classification problem. If regression,
21 |             do not include.
22 | 
23 |     Returns:
24 |         Gradients of the predictions w.r.t. the inputs
25 |     """
26 |     inputs = tf.cast(inputs, tf.float32)
27 | 
28 |     with tf.GradientTape() as tape:
29 |         tape.watch(inputs)
30 | 
31 |         # Run the forward pass of the layer and record operations
32 |         # on GradientTape.
33 |         preds = model(inputs, training=False)
34 | 
35 |         # For classification, grab the top class
36 |         if top_pred_idx is not None:
37 |             preds = preds[:, top_pred_idx]
38 | 
39 |     # Use the gradient tape to automatically retrieve
40 |     # the gradients of the predictions with respect to the inputs.
41 |     grads = tape.gradient(preds, inputs)
42 |     return grads
43 | 
44 | def get_integrated_gradients(model, inputs, baseline=None, num_steps=50, top_pred_idx=None):
45 |     """Computes Integrated Gradients for a prediction.
46 | 
47 |     Args:
48 |         inputs (ndarray): 2D/3D/4D matrix of samples
49 |         baseline (ndarray): The baseline image to start with for interpolation
50 |         num_steps: Number of interpolation steps between the baseline
51 |             and the input used in the computation of integrated gradients. These
52 |             steps alone determine the integral approximation error. By default,
53 |             num_steps is set to 50.
54 |         top_pred_idx: (optional) Predicted label for the x_data
55 |             if classification problem. If regression,
56 |             do not include.
57 | 
58 |     Returns:
59 |         Integrated gradients w.r.t. the inputs
60 |     """
61 |     # If baseline is not provided, start with zeros
62 |     # having same size as the input image.
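    # (An all-zero baseline is the common default for integrated gradients;
    # random_baseline_integrated_gradients below instead averages IG over
    # baselines assembled from randomly drawn input samples.)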
63 |     if baseline is None:
64 |         input_size = np.shape(inputs)[1:]
65 |         baseline = np.zeros(input_size).astype(np.float32)
66 |     else:
67 |         baseline = baseline.astype(np.float32)
68 | 
69 |     # 1. Do interpolation.
70 |     inputs = inputs.astype(np.float32)
71 |     interpolated_inputs = [
72 |         baseline + (step / num_steps) * (inputs - baseline)
73 |         for step in range(num_steps + 1)
74 |     ]
75 |     interpolated_inputs = np.array(interpolated_inputs).astype(np.float32)
76 | 
77 |     # 2. Get the gradients
78 |     grads = []
79 |     for i, x_data in enumerate(interpolated_inputs):
80 |         grad = get_gradients(model, x_data, top_pred_idx=top_pred_idx)
81 |         # grad keeps the full sample dimension; appending grad[0] here would be wrong
82 |         grads.append(grad)
83 |     grads = tf.convert_to_tensor(grads, dtype=tf.float32)
84 | 
85 |     # 3. Approximate the integral using the trapezoidal rule
86 |     grads = (grads[:-1] + grads[1:]) / 2.0
87 |     avg_grads = tf.reduce_mean(grads, axis=0)
88 | 
89 |     # 4. Calculate integrated gradients and return
90 |     integrated_grads = (inputs - baseline) * avg_grads
91 |     return integrated_grads
92 | 
93 | def random_baseline_integrated_gradients(model, inputs, num_steps=50, num_runs=5, top_pred_idx=None):
94 |     """Computes integrated gradients averaged over `num_runs` random baselines.
95 | 
96 |     Args:
97 |         inputs (ndarray): 2D/3D/4D matrix of samples
98 |         num_steps: Number of interpolation steps between the baseline
99 |             and the input used in the computation of integrated gradients. These
100 |             steps alone determine the integral approximation error. By default,
101 |             num_steps is set to 50.
102 |         num_runs: number of baseline images to generate
103 |         top_pred_idx: (optional) Predicted label for the x_data
104 |             if classification problem. If regression,
105 |             do not include.
106 | 
107 |     Returns:
108 |         Averaged integrated gradients for `num_runs` baseline images
109 |     """
110 |     # 1. List to keep track of Integrated Gradients (IG) for all the images
111 |     integrated_grads = []
112 | 
113 |     # 2. Get the integrated gradients for all the baselines
114 |     for run in range(num_runs):
115 |         baseline = np.zeros(np.shape(inputs)[1:])
116 |         for i in np.arange(0, np.shape(baseline)[0]):  # fill each baseline entry from a randomly drawn sample
117 |             j = np.random.choice(np.arange(0, np.shape(inputs)[0]))
118 |             baseline[i] = inputs[j, i]
119 | 
120 |         igrads = get_integrated_gradients(model,
121 |             inputs=inputs,
122 |             baseline=baseline,
123 |             num_steps=num_steps, top_pred_idx=top_pred_idx,
124 |         )
125 |         integrated_grads.append(igrads)
126 | 
127 |     # 3. Return the average integrated gradients for the image
128 |     integrated_grads = tf.convert_to_tensor(integrated_grads)
129 |     return tf.reduce_mean(integrated_grads, axis=0)
130 | 
131 | 
--------------------------------------------------------------------------------