├── .gitignore ├── LICENSE ├── README.md ├── data ├── p2d_sei_10k.csv └── p2d_sei_parameters.csv ├── notebooks ├── 1_introduction.ipynb ├── 2_dataset_exploration.ipynb ├── 3_optional_bayesian_linear_regression.ipynb ├── 4_optional_gpr_basics.ipynb ├── 5_train_gpr.ipynb ├── 6_sensitivity_analysis_of_gpr.ipynb └── 7_interactive_sensitivity_analysis_of_gpr.ipynb └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_store 2 | .ipynb_checkpoints 3 | *.p 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) [year] [fullname] 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Sensitivity Analysis Tutorial 2 | 3 | This tutorial was prepared for the BIG-MAP AI school January 2022 by Jonas Busk (jbusk@dtu.dk). 4 | 5 | The repository contains a series of exercises demonstrating a sensitivity analysis of a battery degradation model in a step-by-step manner. 6 | After completing these exercises, you will be able to perform a similar analysis of your own dataset and you will have some code to get started. 7 | The exercises are provided as a series of Python notebooks that let you write and run code. 8 | Some basic understanding of Python programming and machine learning is assumed. 9 | 10 | The easiest way to get started with the exercises is to open them using [Google Colab](https://colab.research.google.com/), but you can also download and run them locally if you prefer. 11 | 12 | 13 | ## Run exercises using Google Colab 14 | 15 | You can run these exercises using Google Colab without the need to install anything on your computer. 16 | This requires a Google account. 17 | 18 | * Go to [https://colab.research.google.com/](https://colab.research.google.com/). 19 | * Select 'File' > 'Open notebook'. 20 | * Select 'GitHub'. 21 | * Enter the GitHub URL to this repository and hit return (a list of the exercise notebooks should show up). 22 | * Select the exercise notebook you would like to open. 23 | 24 | 25 | ## Run exercises locally 26 | 27 | You can also download the exercises and run them locally on your own computer. 
28 | This requires that you have Python 3 installed and know how to run Python notebooks. 29 | 30 | * Download this repository. 31 | * Create and activate a new virtual environment. 32 | * Install requirements listed in `requirements.txt`. 33 | * Run the exercises. 34 | 35 | 36 | ## Citation 37 | 38 | We are currently preparing a publication based on the methods and data presented in this tutorial which will be available in the near future. 39 | -------------------------------------------------------------------------------- /data/p2d_sei_parameters.csv: -------------------------------------------------------------------------------- 1 | ,name,description,unit,low,high,nominal 2 | 0,i_1C,Applied current,A,0.13,6.5,1.3 3 | 1,rp_neg,Particle radius,um,1e-06,1.1e-05,5.5e-06 4 | 2,Eeq_side,Equilibrium potential (SEI),V,0.0,0.4,0.4 5 | 3,kappa_film,SEI film conductivity,S/m,1e-06,0.00024,0.00024 6 | 4,epsl_neg,Porosity of anode,,0.23,0.4,0.3 7 | 5,Dl,Electrolyte diffusion coefficient,m2/s,1.5e-10,7.5e-10,3.75e-10 8 | 6,Ds_neg,Diffusion coefficient of Li in Anode,m2/s,1e-14,1e-13,3.6e-14 9 | 7,i0ref_neg,Anode exchange current density,A/m2,0.36,3.6,0.96 10 | 8,E_min,Minimum cut-off voltage,V,0.0,0.1,0.05 11 | 9,i0_SEI,SEI exchange current density,A/m2,8e-08,1.5e-06,4.5e-07 12 | 10,csmax_neg,Maximum Li ion concentration in anode,mol/m3,29000.0,33000.0,31500.0 13 | 11,cl_0,Initial electrolyte concentration,mol/m3,1000.0,1200.0,1150.0 14 | 12,t_plus,Transference number,,0.25,0.43,0.363 15 | 13,i0ref_metal,Li metal exchange current density,A/m2,50.0,100.0,100.0 16 | 14,sigma_neg,Anode conductivity,S/m,50.0,100.0,100.0 17 | -------------------------------------------------------------------------------- /notebooks/1_introduction.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "7ecdc317-abf3-45c8-80d8-df00f30534d8", 6 | "metadata": {}, 7 | "source": [ 8 | "# Introduction\n", 9 | "\n", 10 | "These exercises demonstrate a sensitivity analysis of a battery degradation model in a step-by-step manner.\n", 11 | "After doing these exercises you will be able to perform a similar analysis of your own dataset and you will have some code to get you started.\n", 12 | "\n", 13 | "The exercises are provided as a series of Python notebooks (like the one you are reading now) that let you write and run code.\n", 14 | "Some basic understanding of Python programming and machine learning is assumed.\n", 15 | "\n", 16 | "Please note that some of the exercises require you to have completed the earlier exercises to function so you should preferably do them in the indicated order. \n", 17 | "The exercises marked as [optional] are not essential to the method but provide some additional useful information. \n", 18 | "You can skip these or revisit them later if you like.\n", 19 | "\n", 20 | "List of exercises:\n", 21 | "\n", 22 | "1. Introduction (you are here)\n", 23 | "1. Dataset exploration\n", 24 | "1. Bayesian linear regression model [optional]\n", 25 | "1. Gaussian process regression basics [optional]\n", 26 | "1. Train Gaussian process regression model\n", 27 | "1. Sensitivity analysis with GP regression model\n", 28 | "1. Interactive sensitivity analysis with GP regression model\n", 29 | "\n", 30 | "\n", 31 | "In each exercise we encourage you to familiarise yourself with the code and think about the results. 
You are also encouraged to experiment with modifying the code and changing different input values to see how they affect the output. You will likely learn much more this way than if you just run the code and read the instructions. \n", 32 | "\n", 33 | "Finally, the exercise series is setup to analyse a single output of the battery degradation model, namely the `SEI_thickness(m)`.\n", 34 | "If at the end you want to do some additional work, you can repeat the analysis for the `Capacity loss (%)` output. \n", 35 | "Interestingly, the analysis yields different results for the two outputs." 36 | ] 37 | }, 38 | { 39 | "cell_type": "markdown", 40 | "id": "a91d41b8-1665-4887-a194-5fe651fba792", 41 | "metadata": { 42 | "tags": [] 43 | }, 44 | "source": [ 45 | "## A very brief introduction to Python notebooks\n", 46 | "\n", 47 | "In case you are new to Python notebooks here is a *very* brief introduction:\n", 48 | "\n", 49 | "Notebooks (like the one you are reading now) allow you to combine executable code and rich text in a single document.\n", 50 | "A notebook is made up of a sequence of cells that can contain either text in markdown format (like this cell) or code like the cell below." 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": null, 56 | "id": "62633322-d5ac-423d-a33e-3b3e114eb3f7", 57 | "metadata": {}, 58 | "outputs": [], 59 | "source": [ 60 | "seconds_in_a_day = 24 * 60 * 60\n", 61 | "print(f\"There are {seconds_in_a_day} seconds in a day.\")" 62 | ] 63 | }, 64 | { 65 | "cell_type": "markdown", 66 | "id": "beda41b5-ed1c-449f-9594-87e6ddecb173", 67 | "metadata": {}, 68 | "source": [ 69 | "To execute the code in the above cell, select it with a click and then either press the play button, or use the keyboard shortcut 'Shift/Command/Ctrl+Enter'.\n", 70 | "The output of the code (if any) is shown directly underneath the cell.\n", 71 | "To edit the code, just click the cell and start editing.\n", 72 | "\n", 73 | "Variables that you define in one cell can later be used in other cells:" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": null, 79 | "id": "94bb7de7-7cbb-4794-b8fb-77880e0f9456", 80 | "metadata": {}, 81 | "outputs": [], 82 | "source": [ 83 | "print(f\"There are {seconds_in_a_day} seconds in a day.\")" 84 | ] 85 | }, 86 | { 87 | "cell_type": "markdown", 88 | "id": "1f173220-d541-45fc-919b-874e1753a0fe", 89 | "metadata": {}, 90 | "source": [ 91 | "And that is basically all there is to it..." 
92 | ] 93 | } 94 | ], 95 | "metadata": { 96 | "kernelspec": { 97 | "display_name": "Python 3 (ipykernel)", 98 | "language": "python", 99 | "name": "python3" 100 | }, 101 | "language_info": { 102 | "codemirror_mode": { 103 | "name": "ipython", 104 | "version": 3 105 | }, 106 | "file_extension": ".py", 107 | "mimetype": "text/x-python", 108 | "name": "python", 109 | "nbconvert_exporter": "python", 110 | "pygments_lexer": "ipython3", 111 | "version": "3.9.6" 112 | } 113 | }, 114 | "nbformat": 4, 115 | "nbformat_minor": 5 116 | } 117 | -------------------------------------------------------------------------------- /notebooks/2_dataset_exploration.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "3XExE3pz6vhG" 7 | }, 8 | "source": [ 9 | "# Dataset exploration\n", 10 | "\n", 11 | "In this notebook we will get to know the dataset we will use throughout the tutorial series.\n", 12 | "\n", 13 | "The dataset consists of inputs and outputs of a battery degradation model.\n", 14 | "More specifically, a pseudo-two-dimensional (P2D) model configured to simulate the formation of the solid electrolyte interphase (SEI) in a battery based on the reduction of the solvent near the surface of the negative electrode during charging.\n", 15 | "\n", 16 | "The electrolyte considered in the model is a mixture of ethylene carbonate/ethyl methyl carbonate (EC/EMC) with LiPF$_6$ salt. Hence, we assume that the main product forming the SEI layer is Li$_2$CO$_3$ and it is formed according to the reaction:\n", 17 | "\n", 18 | "$$\n", 19 | "\\text{S} + 2\\text{Li}^+ + 2e^- \\rightarrow \\text{P}\n", 20 | "$$\n", 21 | "\n", 22 | "where $\\text{S}$ is the solvent species and $\\text{P}$ is the product of the reaction between the solvent and the Li ions.\n", 23 | "The growth of the SEI layer is assumed to be one-dimensional and to be controlled by the kinetics of the reaction occurring at the interphase.\n", 24 | "\n", 25 | "The inputs and outputs are explored in more detail below." 26 | ] 27 | }, 28 | { 29 | "cell_type": "markdown", 30 | "metadata": {}, 31 | "source": [ 32 | "## Dependencies" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": null, 38 | "metadata": { 39 | "id": "Vq7thp398aVr" 40 | }, 41 | "outputs": [], 42 | "source": [ 43 | "# imports\n", 44 | "import matplotlib.pyplot as plt\n", 45 | "import numpy as np\n", 46 | "import pandas as pd" 47 | ] 48 | }, 49 | { 50 | "cell_type": "markdown", 51 | "metadata": { 52 | "id": "L6w0L8uY65N3" 53 | }, 54 | "source": [ 55 | "## Load dataset\n", 56 | "\n", 57 | "We can load the dataset directly from the GitHub URL.\n", 58 | "Alternatively, the dataset can be loaded from a local file."
59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": null, 64 | "metadata": { 65 | "id": "IX4x2ioy6V2C" 66 | }, 67 | "outputs": [], 68 | "source": [ 69 | "# load parameter table\n", 70 | "parameters_path = \"https://raw.githubusercontent.com/BIG-MAP/sensitivity_analysis_tutorial/main/data/p2d_sei_parameters.csv\"\n", 71 | "# parameters_path = \"./../data/p2d_sei_parameters.csv\" # local\n", 72 | "pt = pd.read_csv(parameters_path, index_col=0)\n", 73 | "pt.unit.replace(np.nan, \"-\", inplace=True)\n", 74 | "\n", 75 | "# load dataset\n", 76 | "dataset_path = \"https://raw.githubusercontent.com/BIG-MAP/sensitivity_analysis_tutorial/main/data/p2d_sei_10k.csv\"\n", 77 | "# dataset_path = \"./../data/p2d_sei_10k.csv\" # local\n", 78 | "df = pd.read_csv(dataset_path, index_col=0)" 79 | ] 80 | }, 81 | { 82 | "cell_type": "markdown", 83 | "metadata": { 84 | "id": "1PNvGZ9q6_94" 85 | }, 86 | "source": [ 87 | "## Parameter table\n", 88 | "\n", 89 | "Let us first have a look at the parameter table that was used to generate the dataset.\n", 90 | "The table shows the input parameters of the P2D model along with the selected input ranges (low, high) and nominal values." 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": null, 96 | "metadata": { 97 | "id": "HREV_DtD-1aM" 98 | }, 99 | "outputs": [], 100 | "source": [ 101 | "pt" 102 | ] 103 | }, 104 | { 105 | "cell_type": "markdown", 106 | "metadata": { 107 | "id": "FblXBqRe9xcX" 108 | }, 109 | "source": [ 110 | "## Dataset\n", 111 | "\n", 112 | "The dataset was generated by sampling the input parameters uniformly at random within the ranges given in the parameter table above.\n", 113 | "Then the P2D model was used to label each row.\n", 114 | "The outputs of the P2D model are stored in the last 2 columns of the dataset:\n", 115 | "\n", 116 | " - SEI_thickness(m): Thickness of the solid electrolyte interphase (SEI).\n", 117 | " - Capacity loss (%): Loss of capacity due to SEI formation.\n", 118 | "\n", 119 | "In this analysis we will primarily focus on `SEI_thickness(m)` and optionally on `Capacity loss (%)`." 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": null, 125 | "metadata": { 126 | "id": "9q7o6aDB-3BH" 127 | }, 128 | "outputs": [], 129 | "source": [ 130 | "# show dataset statistics\n", 131 | "df.describe()" 132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": null, 137 | "metadata": { 138 | "id": "xM9h7z35_qg0" 139 | }, 140 | "outputs": [], 141 | "source": [ 142 | "# show first rows of the dataset\n", 143 | "df.head()" 144 | ] 145 | }, 146 | { 147 | "cell_type": "markdown", 148 | "metadata": { 149 | "id": "2Qt4PfIZAAGX" 150 | }, 151 | "source": [ 152 | "As you might notice, the first row in the dataset corresponds to the nominal values given in the table. " 153 | ] 154 | }, 155 | { 156 | "cell_type": "markdown", 157 | "metadata": { 158 | "id": "I_ZgmwU-AYkS" 159 | }, 160 | "source": [ 161 | "## Data visualisation\n", 162 | "\n", 163 | "Let us explore the dataset with some visualisations. 
" 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": null, 169 | "metadata": { 170 | "id": "W8H_W4vd_u-N" 171 | }, 172 | "outputs": [], 173 | "source": [ 174 | "# plot histograms\n", 175 | "_ = df.hist(figsize=(20,20))" 176 | ] 177 | }, 178 | { 179 | "cell_type": "code", 180 | "execution_count": null, 181 | "metadata": { 182 | "id": "rFZEyCfmAjlk" 183 | }, 184 | "outputs": [], 185 | "source": [ 186 | "# plot scatter matrix\n", 187 | "n = 1000\n", 188 | "_ = pd.plotting.scatter_matrix(df.iloc[:n], figsize=(20,20))" 189 | ] 190 | }, 191 | { 192 | "cell_type": "markdown", 193 | "metadata": { 194 | "id": "0vqvDZ2eBl6k" 195 | }, 196 | "source": [ 197 | "We plot only the first `n = 1000` points in the scatter matrix above as to not crowd the visualisation. Feel free to try and change this number. \n", 198 | "\n", 199 | "Since the data is sampled at random, you should see that the data covers the entire input range, as defined in the parameter table, and that the input parameters are not correlated.\n", 200 | "\n", 201 | "Perhaps you can also spot some interesting correlations or patterns between the inputs and outputs? It might not be super clear from this figure, but you can try to make a note of the input variables that look interesting to you and follow up on them later in our analysis of this dataset." 202 | ] 203 | }, 204 | { 205 | "cell_type": "markdown", 206 | "metadata": { 207 | "id": "M2i_RRMLER9F" 208 | }, 209 | "source": [ 210 | "## Outputs\n", 211 | "\n", 212 | "Let us have a closer look at the two outputs of interest: `SEI_thickness(m)` and `Capacity loss (%)`." 213 | ] 214 | }, 215 | { 216 | "cell_type": "code", 217 | "execution_count": null, 218 | "metadata": { 219 | "id": "irvqa279AqGT" 220 | }, 221 | "outputs": [], 222 | "source": [ 223 | "_ = df[[\"SEI_thickness(m)\", \"Capacity loss (%)\"]].hist(bins=50, figsize=(10, 4))" 224 | ] 225 | }, 226 | { 227 | "cell_type": "markdown", 228 | "metadata": { 229 | "id": "YavAp8TAFIJi" 230 | }, 231 | "source": [ 232 | "Both of these outputs are strictly positive and have long tails.\n", 233 | "In our analysis with machine learning models, it might be useful to instead consider the log transformed outputs to account for these properties." 234 | ] 235 | }, 236 | { 237 | "cell_type": "code", 238 | "execution_count": null, 239 | "metadata": { 240 | "id": "yGHkK_KiEpUK" 241 | }, 242 | "outputs": [], 243 | "source": [ 244 | "_ = df[[\"SEI_thickness(m)\", \"Capacity loss (%)\"]].transform(np.log).add_prefix(\"log \").hist(bins=50, figsize=(10, 4))" 245 | ] 246 | }, 247 | { 248 | "cell_type": "markdown", 249 | "metadata": { 250 | "id": "WV0LfwSjGTbh" 251 | }, 252 | "source": [ 253 | "After applying the log transformation, the outputs are on an unbounded continuous scale, which will be simpler to model. 
" 254 | ] 255 | }, 256 | { 257 | "cell_type": "markdown", 258 | "metadata": {}, 259 | "source": [ 260 | "We can also look at the relationship between the two outputs of interest" 261 | ] 262 | }, 263 | { 264 | "cell_type": "code", 265 | "execution_count": null, 266 | "metadata": {}, 267 | "outputs": [], 268 | "source": [ 269 | "plt.figure(figsize=(5,5))\n", 270 | "plt.plot(df[\"SEI_thickness(m)\"], df[\"Capacity loss (%)\"], \".\", alpha=.5)\n", 271 | "plt.xlabel(\"SEI_thickness (m)\"); plt.ylabel(\"Capacity loss (%)\")\n", 272 | "plt.grid()\n", 273 | "plt.show()" 274 | ] 275 | }, 276 | { 277 | "cell_type": "markdown", 278 | "metadata": {}, 279 | "source": [ 280 | "It looks like an increase in `SEI_thickness(m)` correlates with an increase in `Capacity loss (%)`. \n", 281 | "However, it looks like other factors might also lead to high `Capacity loss (%)`.\n", 282 | "Perhaps that is something we can also see in the sensitivity analysis. " 283 | ] 284 | }, 285 | { 286 | "cell_type": "markdown", 287 | "metadata": { 288 | "id": "xt4YX_NgJp4e" 289 | }, 290 | "source": [ 291 | "## Additional exploration\n", 292 | "\n", 293 | "If there is anything else you are curious to know about the dataset, go ahead and create your own plots and statistics below (or wherever you like). \n", 294 | "Do not forget to save a copy of the notebook if you want to keep your changes. " 295 | ] 296 | }, 297 | { 298 | "cell_type": "code", 299 | "execution_count": null, 300 | "metadata": { 301 | "id": "0sAWcwQLFU2l" 302 | }, 303 | "outputs": [], 304 | "source": [ 305 | "# my additional data exploration\n" 306 | ] 307 | } 308 | ], 309 | "metadata": { 310 | "colab": { 311 | "authorship_tag": "ABX9TyPij3o7gRl+CYLJWRa6aGts", 312 | "collapsed_sections": [], 313 | "name": "Copy of dataset_exploration.ipynb", 314 | "provenance": [ 315 | { 316 | "file_id": "https://github.com/BIG-MAP/sensitivity_analysis_tutorial/blob/main/notebooks/2_dataset_exploration.ipynb", 317 | "timestamp": 1642409492380 318 | } 319 | ] 320 | }, 321 | "kernelspec": { 322 | "display_name": "Python 3 (ipykernel)", 323 | "language": "python", 324 | "name": "python3" 325 | }, 326 | "language_info": { 327 | "codemirror_mode": { 328 | "name": "ipython", 329 | "version": 3 330 | }, 331 | "file_extension": ".py", 332 | "mimetype": "text/x-python", 333 | "name": "python", 334 | "nbconvert_exporter": "python", 335 | "pygments_lexer": "ipython3", 336 | "version": "3.9.6" 337 | } 338 | }, 339 | "nbformat": 4, 340 | "nbformat_minor": 4 341 | } 342 | -------------------------------------------------------------------------------- /notebooks/3_optional_bayesian_linear_regression.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "6125089d-1d04-4a51-b371-f140b9570806", 6 | "metadata": { 7 | "tags": [] 8 | }, 9 | "source": [ 10 | "# Bayesian linear regression model [optional]\n", 11 | "\n", 12 | "In this notebook we fit a [Bayesian linear regression](https://en.wikipedia.org/wiki/Bayesian_linear_regression) model to the data. \n", 13 | "This serves mainly as a useful baseline and tells us if there is a strong linear relationship between the inputs and output. 
\n", 14 | "If the linear model fits the data well, then perhaps there is no reason to apply a more complicated model!\n", 15 | "\n", 16 | "The advantage of using a Bayesian approach over a classical linear regression model in this case is that the Bayesian inference provides us with distributions of all parameters instead of just point estimates, which allows us to perform some additional analyses.\n", 17 | "\n", 18 | "Note that this step is optional and not strictly necessary to complete the rest of the tutorial. " 19 | ] 20 | }, 21 | { 22 | "cell_type": "markdown", 23 | "id": "fb08b88b-5f03-4e02-b3d8-50bd7b738d50", 24 | "metadata": {}, 25 | "source": [ 26 | "## Dependencies\n", 27 | "\n", 28 | "First we import the required dependencies.\n", 29 | "\n", 30 | "If you are in Colab, you need to install the [pyro](https://pyro.ai/) package by uncommenting and running the line `!pip3 install pyro-ppl` below before proceeding." 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": null, 36 | "id": "896d1a72-6e06-45e7-acf2-599b975f6773", 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "# install dependencies\n", 41 | "# !pip3 install pyro-ppl" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": null, 47 | "id": "7ae65a64-a095-4fd3-b7a0-90230c3ca1ae", 48 | "metadata": {}, 49 | "outputs": [], 50 | "source": [ 51 | "# imports\n", 52 | "from collections import defaultdict\n", 53 | "from pathlib import Path\n", 54 | "\n", 55 | "import matplotlib.pyplot as plt\n", 56 | "import numpy as np\n", 57 | "import pandas as pd\n", 58 | "import torch\n", 59 | "import pyro\n", 60 | "\n", 61 | "pyro.set_rng_seed(0)\n", 62 | "print(f\"torch version: {torch.__version__}\")\n", 63 | "print(f\"pyro version: {pyro.__version__}\")" 64 | ] 65 | }, 66 | { 67 | "cell_type": "markdown", 68 | "id": "a8a38a1a-e0b6-48de-8bf8-cb4be4e2f431", 69 | "metadata": { 70 | "tags": [] 71 | }, 72 | "source": [ 73 | "## Load dataset\n", 74 | "\n", 75 | "We can load the dataset directly from the GitHub URL.\n", 76 | "Alternatively, the dataset can be loaded from a local file." 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": null, 82 | "id": "58c72af5-e986-4a91-a140-d8736a4ff765", 83 | "metadata": {}, 84 | "outputs": [], 85 | "source": [ 86 | "# load dataset\n", 87 | "dataset_path = \"https://raw.githubusercontent.com/BIG-MAP/sensitivity_analysis_tutorial/main/data/p2d_sei_10k.csv\"\n", 88 | "# dataset_path = \"data/p2d_sei_10k.csv\" # local\n", 89 | "df = pd.read_csv(dataset_path, index_col=0)\n", 90 | "\n", 91 | "# store the names of the features and the name of the target variable\n", 92 | "features = df.columns[:15].tolist() # use input parameters as features\n", 93 | "target = \"SEI_thickness(m)\" # primary target\n", 94 | "# target = \"Capacity loss (%)\" # secondary target" 95 | ] 96 | }, 97 | { 98 | "cell_type": "markdown", 99 | "id": "72255706-693b-4042-96c6-d0ec2084e1a2", 100 | "metadata": {}, 101 | "source": [ 102 | "## Prepare training and validation data\n", 103 | "\n", 104 | "In preparation for training the machine learning model we do a few data transformations:\n", 105 | "\n", 106 | "* The target variable is log transformed and normalised to zero mean and unit variance.\n", 107 | "* The input features are normalised to zero mean and unit variance to make the model parameters easier to learn and to put the inputs on the same scale and thus make results for each input directly comparable. 
\n", 108 | "\n", 109 | "Finally, the data is split into a training and a validation set. " 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": null, 115 | "id": "80e50d7f-8a98-4673-8999-957346610bd3", 116 | "metadata": {}, 117 | "outputs": [], 118 | "source": [ 119 | "# helper functions\n", 120 | "\n", 121 | "def create_data_split_index(n_data, n_train, n_valid=None, shuffle=False):\n", 122 | " \"\"\"Create data split index.\"\"\"\n", 123 | " n_valid = n_data - n_train if n_valid is None else n_valid \n", 124 | " index = torch.randperm(n_data) if shuffle else torch.arange(n_data)\n", 125 | " split = {\n", 126 | " \"train\": index[:n_train],\n", 127 | " \"valid\": index[n_train:n_train + n_valid],\n", 128 | " \"rest\": index[n_train + n_valid:],\n", 129 | " }\n", 130 | " return split\n", 131 | "\n", 132 | "def create_normaliser(x, y):\n", 133 | " \"\"\"Create data normalisation function\"\"\"\n", 134 | " x_mean, x_std = x.mean(axis=0), x.std(axis=0)\n", 135 | " y_mean, y_std = y.mean(axis=0), y.std(axis=0)\n", 136 | " def normaliser(x, y):\n", 137 | " return (x - x_mean) / x_std, (y - y_mean) / y_std\n", 138 | " normaliser_params = {\"x_mean\": x_mean, \"x_std\": x_std, \"y_mean\": y_mean, \"y_std\": y_std}\n", 139 | " return normaliser, normaliser_params" 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "execution_count": null, 145 | "id": "44f73ddd-2ac5-4dc4-9080-6213a7e49f00", 146 | "metadata": {}, 147 | "outputs": [], 148 | "source": [ 149 | "# settings\n", 150 | "shuffle = False\n", 151 | "n_data = len(df)\n", 152 | "n_train = 5000\n", 153 | "n_valid = 5000\n", 154 | "\n", 155 | "assert n_train + n_valid <= n_data\n", 156 | "\n", 157 | "# create data tensors\n", 158 | "x_data_orig = torch.tensor(df[features].values, dtype=torch.float)\n", 159 | "y_data_orig = torch.tensor(df[target].values, dtype=torch.float)\n", 160 | "\n", 161 | "# log transform y\n", 162 | "y_data_orig = torch.log(y_data_orig)\n", 163 | "\n", 164 | "# create data split index\n", 165 | "split = create_data_split_index(n_data, n_train, n_valid)\n", 166 | "\n", 167 | "# create normalisation function from training split\n", 168 | "normaliser, normaliser_params = create_normaliser(x_data_orig[split[\"train\"]], y_data_orig[split[\"train\"]])\n", 169 | "\n", 170 | "# normalize data\n", 171 | "x_data, y_data = normaliser(x_data_orig, y_data_orig)\n", 172 | "\n", 173 | "# create data splits \n", 174 | "x_train, y_train = x_data[split[\"train\"]], y_data[split[\"train\"]]\n", 175 | "x_valid, y_valid = x_data[split[\"valid\"]], y_data[split[\"valid\"]]\n", 176 | "\n", 177 | "assert len(x_train) == len(y_train) == n_train\n", 178 | "assert len(x_valid) == len(y_valid) == n_valid\n", 179 | "\n", 180 | "n_bins = 50\n", 181 | "plt.figure(figsize=(8,3))\n", 182 | "plt.subplot(121)\n", 183 | "plt.hist(y_train.numpy(), bins=n_bins)\n", 184 | "plt.xlabel(\"y_train\")\n", 185 | "plt.subplot(122)\n", 186 | "plt.hist(y_valid.numpy(), bins=n_bins)\n", 187 | "plt.xlabel(\"y_valid\")\n", 188 | "plt.show()" 189 | ] 190 | }, 191 | { 192 | "cell_type": "markdown", 193 | "id": "db7db671-5efc-4912-b76e-914710006da3", 194 | "metadata": {}, 195 | "source": [ 196 | "## Fit Bayesian linear regression model\n", 197 | "\n", 198 | "Here we first define a [Bayesian linear regression](https://en.wikipedia.org/wiki/Bayesian_linear_regression) model with normal priors on the parameters.\n", 199 | "Then we train it on the training data we prepared above." 
200 | ] 201 | }, 202 | { 203 | "cell_type": "code", 204 | "execution_count": null, 205 | "id": "aa106a96-1958-44b9-ab46-86064ed5be45", 206 | "metadata": {}, 207 | "outputs": [], 208 | "source": [ 209 | "def bayesian_linear_regression_model(x, y=None):\n", 210 | " \"\"\"Bayesian linear regression with normal priors.\"\"\"\n", 211 | " # priors\n", 212 | " n_features = x.shape[1]\n", 213 | " w = pyro.sample(\"weight\", pyro.distributions.Normal(0., 1.).expand([n_features]).to_event(1))\n", 214 | " b = pyro.sample(\"bias\", pyro.distributions.Normal(0., 1.))\n", 215 | " sigma = pyro.sample(\"sigma\", pyro.distributions.HalfNormal(1.))\n", 216 | " # likelihood\n", 217 | " mu = (x @ w + b)\n", 218 | " with pyro.plate(\"data\", len(x)):\n", 219 | " pyro.sample(\"obs\", pyro.distributions.Normal(mu, sigma), obs=y)\n", 220 | " return mu" 221 | ] 222 | }, 223 | { 224 | "cell_type": "code", 225 | "execution_count": null, 226 | "id": "aa2a1bb0-c787-4644-baac-7b18c7fcbc90", 227 | "metadata": {}, 228 | "outputs": [], 229 | "source": [ 230 | "# train the model\n", 231 | "def train(\n", 232 | " model,\n", 233 | " x_train,\n", 234 | " y_train,\n", 235 | " x_valid,\n", 236 | " y_valid,\n", 237 | " n_steps=1000,\n", 238 | " eval_freq=100,\n", 239 | "):\n", 240 | " pyro.clear_param_store()\n", 241 | " guide = pyro.infer.autoguide.AutoDiagonalNormal(model)\n", 242 | " optimiser = pyro.optim.Adam({\"lr\": 0.01})\n", 243 | " svi = pyro.infer.SVI(model, guide, optimiser, loss=pyro.infer.Trace_ELBO())\n", 244 | " errors = defaultdict(list)\n", 245 | " for step in range(n_steps):\n", 246 | " elbo = svi.step(x_train, y_train)\n", 247 | " if step == 0 or (step + 1) % eval_freq == 0:\n", 248 | " train_loss = svi.evaluate_loss(x_train, y_train) / len(x_train)\n", 249 | " valid_loss = svi.evaluate_loss(x_valid, y_valid) / len(x_valid)\n", 250 | " errors[\"train_step\"].append(step + 1)\n", 251 | " errors[\"train_loss\"].append(train_loss)\n", 252 | " errors[\"valid_loss\"].append(valid_loss)\n", 253 | " print(f\"[{step + 1:5d}] train loss: {train_loss:7.4f}, valid loss: {valid_loss:7.4f}\")\n", 254 | " return guide, errors\n", 255 | "\n", 256 | "guide, errors = train(bayesian_linear_regression_model, x_train, y_train, x_valid, y_valid)" 257 | ] 258 | }, 259 | { 260 | "cell_type": "code", 261 | "execution_count": null, 262 | "id": "686785e6-496c-45fc-995c-c6452d03aeb4", 263 | "metadata": {}, 264 | "outputs": [], 265 | "source": [ 266 | "# plot training curve\n", 267 | "plt.figure()\n", 268 | "plt.plot(errors[\"train_step\"], errors[\"train_loss\"], label=\"train loss\")\n", 269 | "plt.plot(errors[\"train_step\"], errors[\"valid_loss\"], label=\"valid loss\")\n", 270 | "plt.xlabel(\"training step\"); plt.ylabel(\"loss\")\n", 271 | "plt.legend()\n", 272 | "plt.grid()\n", 273 | "plt.show()" 274 | ] 275 | }, 276 | { 277 | "cell_type": "markdown", 278 | "id": "a8076aec-b9ff-45c4-b41c-b4f2a0d0115e", 279 | "metadata": {}, 280 | "source": [ 281 | "## Sample posterior distribution\n", 282 | "\n", 283 | "Now that we have a trained model, we can draw samples of the trained model parameters and predictions from the posterior distribution so we can analyse them below. 
" 284 | ] 285 | }, 286 | { 287 | "cell_type": "code", 288 | "execution_count": null, 289 | "id": "1daaf847-54fd-40b5-9fee-24ab005f1acc", 290 | "metadata": {}, 291 | "outputs": [], 292 | "source": [ 293 | "def sample_posterior(model, guide, x, n_samples=1000):\n", 294 | " predictive = pyro.infer.Predictive(model, guide=guide, num_samples=n_samples, return_sites=(\"weight\", \"bias\", \"sigma\", \"obs\", \"_RETURN\",))\n", 295 | " raw_samples = predictive(x)\n", 296 | " samples = {\"raw_samples\": raw_samples}\n", 297 | " for i in range(len(features)):\n", 298 | " samples[f\"weight{i}\"] = raw_samples[\"weight\"][:,:,i].squeeze()\n", 299 | " samples[\"bias\"] = raw_samples[\"bias\"].squeeze()#.numpy()\n", 300 | " samples[\"sigma\"] = raw_samples[\"sigma\"].squeeze()#.numpy()\n", 301 | " samples[\"mu\"] = raw_samples[\"_RETURN\"].squeeze()#.numpy()\n", 302 | " return samples\n", 303 | "\n", 304 | "samples_train = sample_posterior(bayesian_linear_regression_model, guide, x_train)\n", 305 | "samples_valid = sample_posterior(bayesian_linear_regression_model, guide, x_valid)" 306 | ] 307 | }, 308 | { 309 | "cell_type": "markdown", 310 | "id": "39b9510b-380e-4af6-9125-330c0c42c0ac", 311 | "metadata": {}, 312 | "source": [ 313 | "## Posterior predictive distribution\n", 314 | "\n", 315 | "First, we check the predictions on the training and validation data to see how well the model fits the data." 316 | ] 317 | }, 318 | { 319 | "cell_type": "code", 320 | "execution_count": null, 321 | "id": "52fb404a-c101-46b6-bd96-483d564ea816", 322 | "metadata": {}, 323 | "outputs": [], 324 | "source": [ 325 | "def r2(y_true, y_pred):\n", 326 | " \"\"\"Compute coefficient of determination.\"\"\"\n", 327 | " ssr = torch.sum((y_true - y_pred)**2)\n", 328 | " sst = torch.sum((y_true - torch.mean(y_true))**2)\n", 329 | " return 1 - (ssr / sst)\n", 330 | "\n", 331 | "def mae(y_true, y_pred):\n", 332 | " \"\"\"Compute mean absolute error.\"\"\"\n", 333 | " return torch.mean(torch.abs(y_true - y_pred))\n", 334 | "\n", 335 | "def evaluate_predictions(y_true, y_pred, lim=(-3,3), figsize=(8,4)):\n", 336 | " _r2 = r2(y_true, y_pred) # coefficient of determination\n", 337 | " _mae = mae(y_true, y_pred) # mean absolute error\n", 338 | " print(f\"r2: {_r2:.4f}, mae: {_mae:.4f}\\n\")\n", 339 | " # plot\n", 340 | " plt.figure(figsize=figsize)\n", 341 | " plt.subplot(121)\n", 342 | " plt.plot(lim, lim, color=\"k\", linestyle=\"--\", linewidth=1)\n", 343 | " plt.plot(y_true.numpy(), y_pred.numpy(), \".\", alpha=.1)\n", 344 | " plt.xlim(lim); plt.ylim(lim)\n", 345 | " plt.xlabel(\"y_true\"); plt.ylabel(\"y_pred\")\n", 346 | " plt.grid()\n", 347 | " plt.subplot(122)\n", 348 | " plt.hist(y_true.numpy() - y_pred.numpy(), bins=20)\n", 349 | " plt.xlim(lim)\n", 350 | " plt.xlabel(\"y_true - y_pred\")\n", 351 | " plt.tight_layout()\n", 352 | " plt.show()" 353 | ] 354 | }, 355 | { 356 | "cell_type": "code", 357 | "execution_count": null, 358 | "id": "a800f107-0575-4454-a676-4e1b3636b59b", 359 | "metadata": {}, 360 | "outputs": [], 361 | "source": [ 362 | "# evaluate on training data\n", 363 | "evaluate_predictions(y_train, samples_train[\"mu\"].mean(axis=0).detach())" 364 | ] 365 | }, 366 | { 367 | "cell_type": "code", 368 | "execution_count": null, 369 | "id": "8d60a366-7fe9-45a0-8194-30b02d9d264e", 370 | "metadata": {}, 371 | "outputs": [], 372 | "source": [ 373 | "# evaluate on validation data\n", 374 | "evaluate_predictions(y_valid, samples_valid[\"mu\"].mean(axis=0))" 375 | ] 376 | }, 377 | { 378 | "cell_type": "markdown", 379 | "id": 
"4815e439-0292-4578-8d38-278cfd55b48f", 380 | "metadata": {}, 381 | "source": [ 382 | "On the log transformed `SEI_thickness(m)` output, the linear model performs rather well.\n", 383 | "But there is still room for improvement." 384 | ] 385 | }, 386 | { 387 | "cell_type": "markdown", 388 | "id": "dbbaaeb2-07bb-47aa-a002-06e96cba44fe", 389 | "metadata": {}, 390 | "source": [ 391 | "## Plot posterior distribution\n", 392 | "\n", 393 | "Below we first plot the distributions of the model parameters.\n", 394 | "Parameters that are close to zero have a small effect on the output.\n", 395 | "\n", 396 | "We can actually compute the effects by multiplying the (mean) parameters with the corresponding input data and plot the resulting distributions. \n", 397 | "As expected, the close-to-zero parameters on average have very little effect in the output.\n", 398 | "\n", 399 | "Finally we can quantify the feature importances by computing the absolute [t-statistic](https://en.wikipedia.org/wiki/T-statistic) of each parameter distribution:\n", 400 | "\n", 401 | "$$\n", 402 | "t_d = \\frac{\\text{mean}(w_d)}{\\text{std}(w_d)}\n", 403 | "$$\n", 404 | "\n", 405 | "where $d$ denotes the input dimension.\n", 406 | "\n", 407 | "If you want to know more about these methods, this chapter on linear regression from the Interpretable Machine Learning book explains them well: https://christophm.github.io/interpretable-ml-book/limo.html" 408 | ] 409 | }, 410 | { 411 | "cell_type": "code", 412 | "execution_count": null, 413 | "id": "838e5046-2a98-4af5-a307-4fdb765ae729", 414 | "metadata": {}, 415 | "outputs": [], 416 | "source": [ 417 | "def evaluate_posterior(samples, x, features):\n", 418 | " # prepare parameter table\n", 419 | " posterior_df = pd.DataFrame.from_dict({k:v.numpy() for k,v in samples.items() if k not in [\"raw_samples\", \"mu\"]})\n", 420 | " \n", 421 | " # parameter box plot\n", 422 | " _ = posterior_df.boxplot(figsize=(posterior_df.shape[1],4), rot=90)\n", 423 | " plt.title(\"Posterior distribution of model parameters\")\n", 424 | " plt.show()\n", 425 | " \n", 426 | " # effects box plot (w*x)\n", 427 | " effects = (samples[\"raw_samples\"][\"weight\"].mean(axis=0) * x)\n", 428 | " effects_df = pd.DataFrame(effects.numpy(), columns=[f\"x{i}: {f}\" for i,f in enumerate(features)])\n", 429 | " _ = effects_df.boxplot(figsize=(effects_df.shape[1], 4), rot=90)\n", 430 | " plt.title(\"Effects\")\n", 431 | " plt.show()\n", 432 | " \n", 433 | " # feature importance computed as absolute t-statistic |t|\n", 434 | " feature_importance = (posterior_df.mean() / posterior_df.std())[[c for c in posterior_df.columns if \"weight\" in c]].abs()\n", 435 | " plt.figure(figsize=(6,3))\n", 436 | " plt.bar(range(len(feature_importance)), feature_importance.values)\n", 437 | " plt.xticks(range(len(features)), [f\"x{i}: {f}\" for i,f in enumerate(features)], rotation=90)\n", 438 | " plt.xlabel(\"Feature\"); plt.ylabel(\"feature importance\")\n", 439 | " plt.show()" 440 | ] 441 | }, 442 | { 443 | "cell_type": "code", 444 | "execution_count": null, 445 | "id": "3463d21f-1cc3-46e0-af5d-64fe03e20d05", 446 | "metadata": {}, 447 | "outputs": [], 448 | "source": [ 449 | "# evaluate on validation data\n", 450 | "evaluate_posterior(samples_valid, x_valid, features)" 451 | ] 452 | }, 453 | { 454 | "cell_type": "markdown", 455 | "id": "d84200ea-3be5-4d6f-8dcc-4f24fae9edd1", 456 | "metadata": {}, 457 | "source": [ 458 | "From this analysis, it looks like there are a few important inputs with a large effect on the output, which can already 
be really useful information. \n", 459 | "These insights can help us simplify the model by ignoring some of the least important inputs and focusing our efforts on the more important ones.\n", 460 | "\n", 461 | "However, there are some important assumptions to keep in mind. \n", 462 | "* The validity of these results of course depends on how well the linear model fits the data.\n", 463 | "* In this example we performed the analysis on the log-transformed output, and care should be taken if we back-transform the results to the original scale, since this is a nonlinear transformation and the predictive distribution would no longer be Gaussian.\n", 464 | "* Since this is a linear model, it does not reveal whether there are any nonlinear effects, and thus whether the important inputs have the same effect along their entire range of variation.\n", 465 | "Perhaps a nonlinear model could provide some additional insights? We will look at that in the next steps of the tutorial." 466 | ] 467 | } 468 | ], 469 | "metadata": { 470 | "kernelspec": { 471 | "display_name": "Python 3 (ipykernel)", 472 | "language": "python", 473 | "name": "python3" 474 | }, 475 | "language_info": { 476 | "codemirror_mode": { 477 | "name": "ipython", 478 | "version": 3 479 | }, 480 | "file_extension": ".py", 481 | "mimetype": "text/x-python", 482 | "name": "python", 483 | "nbconvert_exporter": "python", 484 | "pygments_lexer": "ipython3", 485 | "version": "3.9.6" 486 | } 487 | }, 488 | "nbformat": 4, 489 | "nbformat_minor": 5 490 | } 491 | -------------------------------------------------------------------------------- /notebooks/4_optional_gpr_basics.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "5f98f0c2-a35f-4ae7-b2a8-5ff29a7a4e11", 6 | "metadata": {}, 7 | "source": [ 8 | "# Gaussian process regression basics [optional]\n", 9 | "\n", 10 | "This notebook presents basic Gaussian process (GP) regression using only the `numpy` package to keep the example as simple as possible. \n", 11 | "Note that you do not need to know Gaussian processes in detail to be able to complete the rest of the tutorial, so this step is optional.\n", 12 | "\n", 13 | "While a complete introduction to Gaussian processes is beyond the scope of this tutorial, we try to describe the very basics. 
For more information and theoretical background, we recommend taking a look at, for example, the excellent [Gaussian Processes for Machine Learning](http://www.gaussianprocess.org/gpml/) book.\n", 14 | "The aim here is just to build some intuition about how the model works in practice.\n", 15 | "\n", 16 | "In supervised learning we observe input-output pairs $(\\mathbf{x}, y)$ and we assume $y = f(\\mathbf{x})$ for some unknown function $f$, possibly corrupted by noise.\n", 17 | "The goal of learning is to estimate $f$ as closely as possible from the observed data.\n", 18 | "The optimal approach would be to estimate a distribution over functions given the data $p(f|\\mathbf{X},\\mathbf{y})$ and use it to make predictions given new inputs $p(y_*|\\mathbf{x}_*,\\mathbf{X},\\mathbf{y}) = \\int p(y_*|f,\\mathbf{x}_*)p(f|\\mathbf{X},\\mathbf{y})df$.\n", 19 | "\n", 20 | "A Gaussian process is a generalisation of the Gaussian distribution that describes a distribution over functions and is fully specified by its mean and covariance function:\n", 21 | "\n", 22 | "$$\n", 23 | "f \\sim GP(m(\\mathbf{x}),k(\\mathbf{x},\\mathbf{x'}))\n", 24 | "$$\n", 25 | "\n", 26 | "where $m(\\mathbf{x})$ is the mean function and $k(\\mathbf{x},\\mathbf{x'})$ is the covariance function, also known as the kernel." 27 | ] 28 | }, 29 | { 30 | "cell_type": "markdown", 31 | "id": "f2a74fef-1cfd-4a68-963e-8de7c7d94816", 32 | "metadata": {}, 33 | "source": [ 34 | "## Dependencies" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": null, 40 | "id": "2d3ccc50-beb4-4d2a-abde-098473b6b4bd", 41 | "metadata": {}, 42 | "outputs": [], 43 | "source": [ 44 | "# imports\n", 45 | "import matplotlib.pyplot as plt\n", 46 | "import numpy as np" 47 | ] 48 | }, 49 | { 50 | "cell_type": "markdown", 51 | "id": "42f65067-abb1-48b7-9502-6e1363d82d18", 52 | "metadata": {}, 53 | "source": [ 54 | "## Example data\n", 55 | "\n", 56 | "We start by generating some example data. To keep it simple, we have just one input variable `x` and one output variable `y`. We make `y` a nonlinear function of `x` to better illustrate the flexibility of the GP regression model. " 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": null, 62 | "id": "29fd72fd-a25e-482a-9268-457cf6f7ee90", 63 | "metadata": {}, 64 | "outputs": [], 65 | "source": [ 66 | "np.random.seed(0)\n", 67 | "\n", 68 | "N = 20 # number of data points\n", 69 | "f = lambda x: np.sin(4 * x) * np.sin(5 * x) # true function\n", 70 | "x = np.random.uniform(0.1, 1.5, N) # random values of x\n", 71 | "y = f(x) + np.random.normal(0, .05, N) # noisy observations\n", 72 | "x1 = np.linspace(0, 1.6, 200) # grid of x values for plotting\n", 73 | "\n", 74 | "plt.figure(figsize=(12,4))\n", 75 | "plt.plot(x1, f(x1), color='k', label=\"true function\")\n", 76 | "plt.plot(x, y, 'o', color='k', label=\"noisy observations\")\n", 77 | "plt.xlabel(\"x\"); plt.ylabel(\"y\")\n", 78 | "plt.legend()\n", 79 | "plt.show()" 80 | ] 81 | }, 82 | { 83 | "cell_type": "markdown", 84 | "id": "b187a4d9-eeba-4576-aa81-33d51cab2d68", 85 | "metadata": {}, 86 | "source": [ 87 | "## GP prior distribution\n", 88 | "\n", 89 | "Here we define the GP prior distribution. 
That is the distribution of functions before we observe any data.\n", 90 | "\n", 91 | "We will use a very simple zero mean function:\n", 92 | "\n", 93 | "$$\n", 94 | "m(\\mathbf{x}) = 0\n", 95 | "$$\n", 96 | "\n", 97 | "Even with a prior with a zero-mean function, the GP is usually flexible enough to fit a wide variety of functions.\n", 98 | "\n", 99 | "We also define the widely used squared exponential kernel (also known as RBF):\n", 100 | "\n", 101 | "$$\n", 102 | "k_{SE}(x, x') = \\sigma_f^2 \\exp \\big( -\\frac{(x - x')^2}{2l^2} \\big)\n", 103 | "$$\n", 104 | "\n", 105 | "where $\\sigma_f^2$ is the variance parameter and $l$ is the length scale parameter.\n", 106 | "\n", 107 | "Finally, we define a simple noise kernel in the code below." 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": null, 113 | "id": "d5c107e6-ddef-4cc6-a1d7-55864ed9a5b7", 114 | "metadata": {}, 115 | "outputs": [], 116 | "source": [ 117 | "class SquaredExponentialKernel:\n", 118 | " \n", 119 | " def __init__(self, variance, length_scale):\n", 120 | " self.variance = variance\n", 121 | " self.length_scale = length_scale\n", 122 | " \n", 123 | " def kernel_function(self, x1, x2):\n", 124 | " z = (x1 - x2)**2 / (2 * self.length_scale**2)\n", 125 | " return self.variance * np.exp(-z)\n", 126 | " \n", 127 | " def __call__(self, X, Z=None):\n", 128 | " \"\"\"Compute covaraince matrix.\"\"\"\n", 129 | " if Z is None:\n", 130 | " Z = X\n", 131 | " N, M = len(X), len(Z)\n", 132 | " K = np.zeros((N, M))\n", 133 | " # naive\n", 134 | " for i in range(N):\n", 135 | " for j in range(M):\n", 136 | " K[i, j] = self.kernel_function(X[i], Z[j])\n", 137 | " return K\n", 138 | " \n", 139 | "\n", 140 | "class NoiseKernel:\n", 141 | " \n", 142 | " def __init__(self, variance):\n", 143 | " self.variance = variance\n", 144 | " \n", 145 | " def __call__(self, X):\n", 146 | " \"\"\"Compute covaraince matrix.\"\"\"\n", 147 | " return self.variance * np.eye(len(X))" 148 | ] 149 | }, 150 | { 151 | "cell_type": "markdown", 152 | "id": "bc0e3fb2-6e1c-4882-91a4-15057b0ce0d0", 153 | "metadata": {}, 154 | "source": [ 155 | "Now we can draw some sample functions from the GP prior distribution and plot them. You can try to change the `variance` and `length_scale` parameters in the code below to see how they affect the sampled functions. Try for example to change the `length_scale` in factors of ten." 
156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": null, 161 | "id": "adb40430-f870-4890-89d9-588d816c3fb6", 162 | "metadata": {}, 163 | "outputs": [], 164 | "source": [ 165 | "def draw_samples_from_gp_prior(n_samples=5, variance=1.0, length_scale=0.1):\n", 166 | " # define kernel\n", 167 | " kernel = SquaredExponentialKernel(variance=variance, length_scale=length_scale)\n", 168 | " # draw samples\n", 169 | " samples = []\n", 170 | " for _ in range(n_samples):\n", 171 | " samples.append(np.random.multivariate_normal(np.zeros(len(x1)), kernel(x1)))\n", 172 | " std = np.sqrt(variance)\n", 173 | " # plot\n", 174 | " plt.figure(figsize=(12,4))\n", 175 | " plt.fill_between((0.0, 1.6), (2*std, 2*std), (-2*std, -2*std), color=\"C0\", alpha=0.3, label=\"uncertainty (2*std)\")\n", 176 | " for i,fx1 in enumerate(samples):\n", 177 | " plt.plot(x1, fx1, color=\"C0\", linestyle=\"--\", label=\"prior samples\" if i == 0 else \"\")\n", 178 | " plt.xlabel(\"x\"); plt.ylabel(\"y\")\n", 179 | " plt.legend()\n", 180 | " plt.show()\n", 181 | " \n", 182 | "draw_samples_from_gp_prior(n_samples=5, variance=1.0, length_scale=0.1)" 183 | ] 184 | }, 185 | { 186 | "cell_type": "markdown", 187 | "id": "d0759340-6bff-4ad8-be45-3d3593d30207", 188 | "metadata": { 189 | "tags": [] 190 | }, 191 | "source": [ 192 | "## GP posterior distribution (with noise)\n", 193 | "\n", 194 | "We now condition on the observed data to compute the GP posterior and make predictions of the mean $\\mu$ and the (co)variance $\\Sigma$.\n", 195 | "For the simple GP regression model, this step can be computed in closed form with the following expressions:\n", 196 | "\n", 197 | "\\begin{align}\n", 198 | "p(\\mathbf{f}_*|\\mathbf{X}_*,\\mathbf{X},\\mathbf{y}) &= \\mathcal{N}(\\mathbf{f}_*|\\mathbf{\\mu},\\mathbf{\\Sigma}) \\\\\n", 199 | "\\mathbf{\\mu} &= \\mu(\\mathbf{X}_*) + \\mathbf{K}^T_* \\mathbf{K}^{-1} (\\mathbf{y} - \\mu(\\mathbf{X})) = \\mathbf{K}^T_* \\mathbf{K}^{-1} \\mathbf{y} \\\\\n", 200 | "\\Sigma &= \\mathbf{K}_{**} - \\mathbf{K}^T_* \\mathbf{K}^{-1} \\mathbf{K}_*\n", 201 | "\\end{align}\n", 202 | "\n", 203 | "where $\\mathbf{K} = k(\\mathbf{X},\\mathbf{X}) + \\sigma^2 \\mathbf{I}$, $\\mathbf{K}_* = k(\\mathbf{X},\\mathbf{X}_*)$ and $\\mathbf{K}_{**}=k(\\mathbf{X}_*,\\mathbf{X}_*)$ with covariance function (kernel) $k$ and noise level $\\sigma^2$ and assuming a zero mean function $\\mu(\\mathbf{X})=\\mathbf{0}$.\n", 204 | "\n", 205 | "(In practice the Cholesky decomposition can be used instead of the computationally expensive matrix inversion.)" 206 | ] 207 | }, 208 | { 209 | "cell_type": "code", 210 | "execution_count": null, 211 | "id": "5ea7b0a0-c1d4-4a83-82c8-646022537953", 212 | "metadata": {}, 213 | "outputs": [], 214 | "source": [ 215 | "def predict(variance, length_scale, noise_level, n_samples=5):\n", 216 | " # kernels\n", 217 | " kernel = SquaredExponentialKernel(variance=variance, length_scale=length_scale)\n", 218 | " noise = NoiseKernel(variance=noise_level)\n", 219 | " # inference\n", 220 | " K = kernel(x) + noise(x)\n", 221 | " K_inv = np.linalg.inv(K)\n", 222 | " K1 = kernel(x, x1)\n", 223 | " K11 = kernel(x1, x1)\n", 224 | " mu = K1.T @ K_inv @ y\n", 225 | " Sigma = K11 - K1.T @ K_inv @ K1\n", 226 | " # sample posterior\n", 227 | " samples = []\n", 228 | " for _ in range(n_samples):\n", 229 | " fx1 = np.random.multivariate_normal(mu, Sigma)\n", 230 | " samples.append(fx1)\n", 231 | " sigma = np.sqrt(Sigma.diagonal())\n", 232 | " return mu, sigma, samples" 233 | ] 234 | }, 235 | { 236 | 
"cell_type": "markdown", 237 | "id": "02503689-52cb-41f9-9af3-d4a3e3e4859c", 238 | "metadata": {}, 239 | "source": [ 240 | "Let us plot the model predictions for various values of the length scale parameter to see how it behaves. \n", 241 | "\n", 242 | "You can try and change the values of the `variance`, `length_scale` and `noise_level` parameters in the code below to see they affect the predictions of the model." 243 | ] 244 | }, 245 | { 246 | "cell_type": "code", 247 | "execution_count": null, 248 | "id": "45b47929-1ca7-4cdb-b172-3c7d0fb65acb", 249 | "metadata": {}, 250 | "outputs": [], 251 | "source": [ 252 | "for ls in [0.1, 0.2, 0.3, 0.4, 0.5]:\n", 253 | " mu, sigma, samples = predict(variance=1.0, length_scale=ls, noise_level=0.01)\n", 254 | " # plot\n", 255 | " plt.figure(figsize=(12,4))\n", 256 | " plt.title(f\"length_scale={ls}\")\n", 257 | " plt.plot(x1, f(x1), color=\"k\", label=\"true function\")\n", 258 | " plt.plot(x, y, 'o', color=\"k\", label=\"noisy observations\")\n", 259 | " plt.plot(x1, mu, color=\"C0\", label=\"prediction\")\n", 260 | " plt.fill_between(x1, mu + 2 * sigma, mu - 2 * sigma, color=\"C0\", alpha=0.3, label=\"uncertainty (2*std)\")\n", 261 | " for i,fx1 in enumerate(samples):\n", 262 | " plt.plot(x1, fx1, color=\"C0\", linestyle=\"--\", alpha=0.5, label=\"posterior samples\" if i==0 else \"\")\n", 263 | " plt.legend(loc=1)\n", 264 | " plt.show()" 265 | ] 266 | }, 267 | { 268 | "cell_type": "markdown", 269 | "id": "615635d1-e212-4e2b-b0d5-5cafaf729db2", 270 | "metadata": {}, 271 | "source": [ 272 | "As you can see, some parameter values make the model fit the data better than others. Thus, the problem of learning a good model of the data consists of finding suitable kernel parameters for the GP model.\n", 273 | "In practice the kernel parameters can be optimised automatically using gradient based methods. \n", 274 | "Fortunately there are many GP packages available so we do not have to implement this ourselves. 
" 275 | ] 276 | } 277 | ], 278 | "metadata": { 279 | "kernelspec": { 280 | "display_name": "Python 3 (ipykernel)", 281 | "language": "python", 282 | "name": "python3" 283 | }, 284 | "language_info": { 285 | "codemirror_mode": { 286 | "name": "ipython", 287 | "version": 3 288 | }, 289 | "file_extension": ".py", 290 | "mimetype": "text/x-python", 291 | "name": "python", 292 | "nbconvert_exporter": "python", 293 | "pygments_lexer": "ipython3", 294 | "version": "3.9.6" 295 | } 296 | }, 297 | "nbformat": 4, 298 | "nbformat_minor": 5 299 | } 300 | -------------------------------------------------------------------------------- /notebooks/5_train_gpr.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "rk8JW-uTF0XC" 7 | }, 8 | "source": [ 9 | "# Train Gaussian process regression model\n", 10 | "\n", 11 | "In this notebook we will train the Gaussian process (GP) regression model that we will later use for the sensitivity analysis.\n", 12 | "\n", 13 | "We will go through the following steps:\n", 14 | "\n", 15 | "* Load the dataset.\n", 16 | "* Prepare the training and validation data.\n", 17 | "* Train a GP regression model.\n", 18 | "* Check the model predictions.\n", 19 | "* Save the trained model parameters to a file.\n" 20 | ] 21 | }, 22 | { 23 | "cell_type": "markdown", 24 | "metadata": { 25 | "id": "KgOGQl_AHEf_" 26 | }, 27 | "source": [ 28 | "## Dependencies\n", 29 | "\n", 30 | "First we import the dependencies.\n", 31 | "\n", 32 | "If you are in Colab, you need to install the [pyro](https://pyro.ai/) package by uncommenting and running the line `!pip3 install pyro-ppl` below before proceeding." 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": null, 38 | "metadata": { 39 | "executionInfo": { 40 | "elapsed": 218, 41 | "status": "ok", 42 | "timestamp": 1642348080942, 43 | "user": { 44 | "displayName": "jonas busk", 45 | "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gj7-1gAF7PppPBq1jWtOrRLj_kiVnCZpQDWsCTO4g=s64", 46 | "userId": "13756499934799797810" 47 | }, 48 | "user_tz": -60 49 | }, 50 | "id": "42A0W46fFhOl" 51 | }, 52 | "outputs": [], 53 | "source": [ 54 | "# install dependencies\n", 55 | "# !pip3 install pyro-ppl" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": null, 61 | "metadata": { 62 | "colab": { 63 | "base_uri": "https://localhost:8080/" 64 | }, 65 | "executionInfo": { 66 | "elapsed": 3382, 67 | "status": "ok", 68 | "timestamp": 1642348084631, 69 | "user": { 70 | "displayName": "jonas busk", 71 | "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gj7-1gAF7PppPBq1jWtOrRLj_kiVnCZpQDWsCTO4g=s64", 72 | "userId": "13756499934799797810" 73 | }, 74 | "user_tz": -60 75 | }, 76 | "id": "I84YwXleGSY-", 77 | "outputId": "37246823-38b1-4e13-e529-bf7418925c6e" 78 | }, 79 | "outputs": [], 80 | "source": [ 81 | "# imports\n", 82 | "from collections import defaultdict\n", 83 | "from pathlib import Path\n", 84 | "\n", 85 | "import matplotlib.pyplot as plt\n", 86 | "import numpy as np\n", 87 | "import pandas as pd\n", 88 | "import torch\n", 89 | "import pyro\n", 90 | "import pyro.contrib.gp as gp\n", 91 | "\n", 92 | "pyro.set_rng_seed(0)\n", 93 | "print(f\"torch version: {torch.__version__}\")\n", 94 | "print(f\"pyro version: {pyro.__version__}\")" 95 | ] 96 | }, 97 | { 98 | "cell_type": "markdown", 99 | "metadata": { 100 | "id": "xtDnIX49G0Ww" 101 | }, 102 | "source": [ 103 | "## Load dataset\n", 104 | "\n", 105 | "We can 
load the dataset directly from the GitHub URL.\n", 106 | "Alternatively, the dataset can be loaded from a local file." 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": null, 112 | "metadata": { 113 | "executionInfo": { 114 | "elapsed": 331, 115 | "status": "ok", 116 | "timestamp": 1642348084956, 117 | "user": { 118 | "displayName": "jonas busk", 119 | "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gj7-1gAF7PppPBq1jWtOrRLj_kiVnCZpQDWsCTO4g=s64", 120 | "userId": "13756499934799797810" 121 | }, 122 | "user_tz": -60 123 | }, 124 | "id": "hnz2yDrqyuv0" 125 | }, 126 | "outputs": [], 127 | "source": [ 128 | "# load dataset\n", 129 | "dataset_path = \"https://raw.githubusercontent.com/BIG-MAP/sensitivity_analysis_tutorial/main/data/p2d_sei_10k.csv\"\n", 130 | "# dataset_path = \"data/p2d_sei_10k.csv\" # local\n", 131 | "df = pd.read_csv(dataset_path, index_col=0)\n", 132 | "\n", 133 | "# store the names of the features and the name of the target variable\n", 134 | "features = df.columns[:15].tolist() # use input parameters as features\n", 135 | "target = \"SEI_thickness(m)\" # primary target\n", 136 | "# target = \"Capacity loss (%)\" # secondary target" 137 | ] 138 | }, 139 | { 140 | "cell_type": "markdown", 141 | "metadata": { 142 | "id": "bmnY6vfXG4bO" 143 | }, 144 | "source": [ 145 | "## Prepare training and validation data\n", 146 | "\n", 147 | "In preparation for training the GP regression model we do a few data transformations:\n", 148 | "\n", 149 | "* The target variable is log transformed and normalised to zero mean and unit variance.\n", 150 | "* The input features are normalised to zero mean and unit variance to make the kernel parameters easier to learn and to put the inputs on the same scale and thus make results for each input directly comparable. \n", 151 | "\n", 152 | "Finally, the data is split into a training and a validation set. 
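Since the target is both log-transformed and standardised, it can be handy to have a helper that maps predictions back to the original units later on. The function below is a hypothetical addition (not part of the original notebook); it assumes the `normaliser_params` dictionary created by the cells below and the `torch` import from above.

```python
def denormalise_y(y_norm, params):
    """Map a normalised, log-scale prediction back to the original target units."""
    y_log = y_norm * params["y_std"] + params["y_mean"]  # undo standardisation
    return torch.exp(y_log)                              # undo the log transform
```

Note the caveat from the linear regression notebook: applying `exp` to a Gaussian predictive mean gives the median of the back-transformed distribution, not its mean.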
" 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": null, 158 | "metadata": { 159 | "executionInfo": { 160 | "elapsed": 5, 161 | "status": "ok", 162 | "timestamp": 1642348084957, 163 | "user": { 164 | "displayName": "jonas busk", 165 | "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gj7-1gAF7PppPBq1jWtOrRLj_kiVnCZpQDWsCTO4g=s64", 166 | "userId": "13756499934799797810" 167 | }, 168 | "user_tz": -60 169 | }, 170 | "id": "PE8A-MTp0gLi" 171 | }, 172 | "outputs": [], 173 | "source": [ 174 | "# helper functions\n", 175 | "\n", 176 | "def create_data_split_index(n_data, n_train, n_valid=None, shuffle=False):\n", 177 | " \"\"\"Create data split index.\"\"\"\n", 178 | " n_valid = n_data - n_train if n_valid is None else n_valid \n", 179 | " index = torch.randperm(n_data) if shuffle else torch.arange(n_data)\n", 180 | " split = {\n", 181 | " \"train\": index[:n_train],\n", 182 | " \"valid\": index[n_train:n_train + n_valid],\n", 183 | " \"rest\": index[n_train + n_valid:],\n", 184 | " }\n", 185 | " return split\n", 186 | "\n", 187 | "def create_normaliser(x, y):\n", 188 | " \"\"\"Create data normalisation function\"\"\"\n", 189 | " x_mean, x_std = x.mean(axis=0), x.std(axis=0)\n", 190 | " y_mean, y_std = y.mean(axis=0), y.std(axis=0)\n", 191 | " def normaliser(x, y):\n", 192 | " return (x - x_mean) / x_std, (y - y_mean) / y_std\n", 193 | " normaliser_params = {\"x_mean\": x_mean, \"x_std\": x_std, \"y_mean\": y_mean, \"y_std\": y_std}\n", 194 | " return normaliser, normaliser_params" 195 | ] 196 | }, 197 | { 198 | "cell_type": "code", 199 | "execution_count": null, 200 | "metadata": { 201 | "colab": { 202 | "base_uri": "https://localhost:8080/", 203 | "height": 228 204 | }, 205 | "executionInfo": { 206 | "elapsed": 545, 207 | "status": "ok", 208 | "timestamp": 1642348085498, 209 | "user": { 210 | "displayName": "jonas busk", 211 | "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gj7-1gAF7PppPBq1jWtOrRLj_kiVnCZpQDWsCTO4g=s64", 212 | "userId": "13756499934799797810" 213 | }, 214 | "user_tz": -60 215 | }, 216 | "id": "Ad--pHWb1FhX", 217 | "outputId": "2588172a-b058-4d6a-ad31-7f5dfde2a239" 218 | }, 219 | "outputs": [], 220 | "source": [ 221 | "# settings\n", 222 | "shuffle = False\n", 223 | "n_data = len(df)\n", 224 | "n_train = 5000\n", 225 | "n_valid = 5000\n", 226 | "\n", 227 | "assert n_train + n_valid <= n_data\n", 228 | "\n", 229 | "# create data tensors\n", 230 | "x_data_orig = torch.tensor(df[features].values, dtype=torch.float)\n", 231 | "y_data_orig = torch.tensor(df[target].values, dtype=torch.float)\n", 232 | "\n", 233 | "# log transform y\n", 234 | "y_data_orig = torch.log(y_data_orig)\n", 235 | "\n", 236 | "# create data split index\n", 237 | "split = create_data_split_index(n_data, n_train, n_valid)\n", 238 | "\n", 239 | "# create normalisation function from training split\n", 240 | "normaliser, normaliser_params = create_normaliser(x_data_orig[split[\"train\"]], y_data_orig[split[\"train\"]])\n", 241 | "\n", 242 | "# normalise data\n", 243 | "x_data, y_data = normaliser(x_data_orig, y_data_orig)\n", 244 | "\n", 245 | "# create data splits \n", 246 | "x_train, y_train = x_data[split[\"train\"]], y_data[split[\"train\"]]\n", 247 | "x_valid, y_valid = x_data[split[\"valid\"]], y_data[split[\"valid\"]]\n", 248 | "\n", 249 | "assert len(x_train) == len(y_train) == n_train\n", 250 | "assert len(x_valid) == len(y_valid) == n_valid\n", 251 | "\n", 252 | "n_bins = 50\n", 253 | "plt.figure(figsize=(8,3))\n", 254 | "plt.subplot(121)\n", 255 | 
"plt.hist(y_train.numpy(), bins=n_bins)\n", 256 | "plt.xlabel(\"y_train\")\n", 257 | "plt.subplot(122)\n", 258 | "plt.hist(y_valid.numpy(), bins=n_bins)\n", 259 | "plt.xlabel(\"y_valid\")\n", 260 | "plt.show()" 261 | ] 262 | }, 263 | { 264 | "cell_type": "markdown", 265 | "metadata": { 266 | "id": "Axt5oeUCMJCP" 267 | }, 268 | "source": [ 269 | "## Train sparse GP regression model\n", 270 | "\n", 271 | "Now we train the GP regression model that we will later use in the sensitivity analysis.\n", 272 | "Specifically, we use the [SparseGPRegression](https://docs.pyro.ai/en/stable/contrib.gp.html#module-pyro.contrib.gp.models.sgpr) model from the [pyro](https://pyro.ai/) package because we have found it can handle rather large datasets while still being quite fast to train, and it is easy to use with automatic differentiation as we will see later.\n", 273 | "Please refer to the [pyro documentation](https://docs.pyro.ai/en/stable/contrib.gp.html#module-pyro.contrib.gp.models.sgpr) for details about the model.\n", 274 | "\n", 275 | "If at some point you want to apply this method on a small dataset, perhaps you do not need a sparse mode and you can use the simpler [GPRegression](https://docs.pyro.ai/en/stable/contrib.gp.html#module-pyro.contrib.gp.models.gpr) model instead.\n", 276 | "\n", 277 | "The model training might take a minute to run. " 278 | ] 279 | }, 280 | { 281 | "cell_type": "code", 282 | "execution_count": null, 283 | "metadata": { 284 | "executionInfo": { 285 | "elapsed": 4, 286 | "status": "ok", 287 | "timestamp": 1642348085499, 288 | "user": { 289 | "displayName": "jonas busk", 290 | "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gj7-1gAF7PppPBq1jWtOrRLj_kiVnCZpQDWsCTO4g=s64", 291 | "userId": "13756499934799797810" 292 | }, 293 | "user_tz": -60 294 | }, 295 | "id": "LPF-hk4OLJRQ" 296 | }, 297 | "outputs": [], 298 | "source": [ 299 | "# helper functions\n", 300 | "\n", 301 | "def mnll(loc, scale, targets):\n", 302 | " \"\"\"Compute mean negative log likelihood.\"\"\"\n", 303 | " log2pi = np.log(2 * np.pi)\n", 304 | " loglik = -0.5 * (torch.log(scale) + log2pi + (targets - loc)**2 / scale)\n", 305 | " return torch.mean(-loglik)\n", 306 | "\n", 307 | "def rmse(y_true, y_pred):\n", 308 | " \"\"\"Compute root mean squared error.\"\"\"\n", 309 | " return torch.sqrt(torch.mean((y_true - y_pred)**2))\n", 310 | "\n", 311 | "def mae(y_true, y_pred):\n", 312 | " \"\"\"Compute mean absolute error.\"\"\"\n", 313 | " return torch.mean(torch.abs(y_true - y_pred))\n", 314 | "\n", 315 | "def r2(y_true, y_pred):\n", 316 | " \"\"\"Compute coefficient of determination.\"\"\"\n", 317 | " ssr = torch.sum((y_true - y_pred)**2)\n", 318 | " sst = torch.sum((y_true - torch.mean(y_true))**2)\n", 319 | " return 1 - (ssr / sst)\n", 320 | "\n", 321 | "@torch.no_grad()\n", 322 | "def evaluate(model, x, y):\n", 323 | " \"\"\"Evaluate model.\"\"\"\n", 324 | " mean, var = model(x, full_cov=False, noiseless=False)\n", 325 | " errors = dict()\n", 326 | " errors[\"mnll\"] = mnll(mean, var, y).detach().item()\n", 327 | " errors[\"rmse\"] = rmse(y, mean).detach().item()\n", 328 | " errors[\"mae\"] = mae(y, mean).detach().item()\n", 329 | " errors[\"r2\"] = r2(y, mean).detach().item()\n", 330 | " return errors" 331 | ] 332 | }, 333 | { 334 | "cell_type": "code", 335 | "execution_count": null, 336 | "metadata": { 337 | "colab": { 338 | "base_uri": "https://localhost:8080/" 339 | }, 340 | "executionInfo": { 341 | "elapsed": 39851, 342 | "status": "ok", 343 | "timestamp": 1642348125347, 344 | "user": { 345 | 
"displayName": "jonas busk", 346 | "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gj7-1gAF7PppPBq1jWtOrRLj_kiVnCZpQDWsCTO4g=s64", 347 | "userId": "13756499934799797810" 348 | }, 349 | "user_tz": -60 350 | }, 351 | "id": "ATZ9iYww298w", 352 | "outputId": "828b807f-48f5-48e6-c87d-99dfe81b1f36" 353 | }, 354 | "outputs": [], 355 | "source": [ 356 | "# train model\n", 357 | "\n", 358 | "def train(\n", 359 | " x_train,\n", 360 | " y_train,\n", 361 | " x_valid,\n", 362 | " y_valid,\n", 363 | " n_inducing_points=100,\n", 364 | " n_steps=1000,\n", 365 | " eval_freq=100,\n", 366 | " jitter=1.0e-5\n", 367 | "):\n", 368 | " pyro.clear_param_store()\n", 369 | " n_features = x_train.shape[1]\n", 370 | "\n", 371 | " # select the first n training points as the inducing inputs\n", 372 | " x_inducing = x_train[:n_inducing_points].clone()\n", 373 | " \n", 374 | " # initialise the kernel and model\n", 375 | " kernel = gp.kernels.RBF(input_dim=n_features, variance=torch.tensor(5.), lengthscale=torch.tensor(n_features * [10.]))\n", 376 | " model = gp.models.SparseGPRegression(x_train, y_train, kernel, Xu=x_inducing, jitter=jitter)\n", 377 | "\n", 378 | " # setup optimiser and loss function \n", 379 | " optimiser = torch.optim.Adam(model.parameters(), lr=0.01)\n", 380 | " loss_fn = pyro.infer.Trace_ELBO().differentiable_loss\n", 381 | "\n", 382 | " errors = defaultdict(list)\n", 383 | " for step in range(n_steps):\n", 384 | " # train\n", 385 | " optimiser.zero_grad()\n", 386 | " loss = loss_fn(model.model, model.guide)\n", 387 | " loss.backward()\n", 388 | " optimiser.step()\n", 389 | " # evaluate\n", 390 | " if step == 0 or (step + 1) % eval_freq == 0:\n", 391 | " with torch.no_grad():\n", 392 | " errors[\"train_step\"].append(step + 1)\n", 393 | " errors[\"train_loss\"].append(loss.item() / len(x_train))\n", 394 | " for k,v in evaluate(model, x_train, y_train).items():\n", 395 | " errors[\"train_\" + k].append(v)\n", 396 | " for k,v in evaluate(model, x_valid, y_valid).items():\n", 397 | " errors[\"valid_\" + k].append(v)\n", 398 | " print(f\"[{step + 1:5d}] train loss: {errors['train_loss'][-1]:7.4f} train mnll: {errors['train_mnll'][-1]:7.4f} valid mnll: {errors['valid_mnll'][-1]:7.4f}\") \n", 399 | " return model, errors\n", 400 | " \n", 401 | "\n", 402 | "model, errors = train(x_train, y_train, x_valid, y_valid, n_steps=800, jitter=1.0e-4)" 403 | ] 404 | }, 405 | { 406 | "cell_type": "code", 407 | "execution_count": null, 408 | "metadata": { 409 | "colab": { 410 | "base_uri": "https://localhost:8080/", 411 | "height": 279 412 | }, 413 | "executionInfo": { 414 | "elapsed": 22, 415 | "status": "ok", 416 | "timestamp": 1642348125347, 417 | "user": { 418 | "displayName": "jonas busk", 419 | "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gj7-1gAF7PppPBq1jWtOrRLj_kiVnCZpQDWsCTO4g=s64", 420 | "userId": "13756499934799797810" 421 | }, 422 | "user_tz": -60 423 | }, 424 | "id": "kNfqwOoOMiI6", 425 | "outputId": "680b06c0-fc40-49ec-fb28-17386bee019a" 426 | }, 427 | "outputs": [], 428 | "source": [ 429 | "# plot training curve\n", 430 | "plt.figure()\n", 431 | "plt.plot(errors[\"train_step\"], errors[\"train_mnll\"], label=\"train mnll\")\n", 432 | "plt.plot(errors[\"train_step\"], errors[\"valid_mnll\"], label=\"valid mnll\")\n", 433 | "plt.xlabel(\"training step\"); plt.ylabel(\"error\")\n", 434 | "plt.legend()\n", 435 | "plt.grid()\n", 436 | "plt.show()" 437 | ] 438 | }, 439 | { 440 | "cell_type": "markdown", 441 | "metadata": { 442 | "id": "6GF12D715T4Y" 443 | }, 444 | "source": [ 445 | "We should 
see the training and validation errors go down with the number of training steps.\n", 446 | "Go ahead and plot some of the other errors stored in the `errors` dictionary if you like." 447 | ] 448 | }, 449 | { 450 | "cell_type": "markdown", 451 | "metadata": { 452 | "id": "oX8wpStIlfHj" 453 | }, 454 | "source": [ 455 | "## Check model predictions\n", 456 | "\n", 457 | "Before we do any further analyses, we want to verify that the model fits the training data and makes good predictions on the held-out validation data. " 458 | ] 459 | }, 460 | { 461 | "cell_type": "code", 462 | "execution_count": null, 463 | "metadata": { 464 | "executionInfo": { 465 | "elapsed": 21, 466 | "status": "ok", 467 | "timestamp": 1642348125348, 468 | "user": { 469 | "displayName": "jonas busk", 470 | "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gj7-1gAF7PppPBq1jWtOrRLj_kiVnCZpQDWsCTO4g=s64", 471 | "userId": "13756499934799797810" 472 | }, 473 | "user_tz": -60 474 | }, 475 | "id": "X5JHDtS4zbbo" 476 | }, 477 | "outputs": [], 478 | "source": [ 479 | "def evaluate_predictions(y_true, y_pred, lim=(-3,3), figsize=(5,5)):\n", 480 | " _r2 = r2(y_true, y_pred) # coefficient of determination\n", 481 | " _mae = mae(y_true, y_pred) # mean absolute error\n", 482 | " print(f\"r2: {_r2:.4f}, mae: {_mae:.4f}\\n\")\n", 483 | " # plot y_true against y_pred\n", 484 | " plt.figure(figsize=figsize)\n", 485 | " plt.plot(lim, lim, color=\"k\", linestyle=\"--\", linewidth=1)\n", 486 | " plt.plot(y_true, y_pred, \".\", alpha=0.1)\n", 487 | " plt.xlabel(\"y_true\"); plt.ylabel(\"y_pred\")\n", 488 | " plt.xlim(lim); plt.ylim(lim)\n", 489 | " plt.grid()\n", 490 | " plt.show()" 491 | ] 492 | }, 493 | { 494 | "cell_type": "code", 495 | "execution_count": null, 496 | "metadata": { 497 | "colab": { 498 | "base_uri": "https://localhost:8080/", 499 | "height": 373 500 | }, 501 | "executionInfo": { 502 | "elapsed": 393, 503 | "status": "ok", 504 | "timestamp": 1642348125721, 505 | "user": { 506 | "displayName": "jonas busk", 507 | "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gj7-1gAF7PppPBq1jWtOrRLj_kiVnCZpQDWsCTO4g=s64", 508 | "userId": "13756499934799797810" 509 | }, 510 | "user_tz": -60 511 | }, 512 | "id": "klrdsztQllau", 513 | "outputId": "077bbb33-f178-40bf-b18c-682abd0551c2" 514 | }, 515 | "outputs": [], 516 | "source": [ 517 | "# evaluate on training data\n", 518 | "y_pred, y_var = model(x_train, full_cov=False, noiseless=False)\n", 519 | "evaluate_predictions(y_train.detach(), y_pred.detach())" 520 | ] 521 | }, 522 | { 523 | "cell_type": "code", 524 | "execution_count": null, 525 | "metadata": { 526 | "colab": { 527 | "base_uri": "https://localhost:8080/", 528 | "height": 373 529 | }, 530 | "executionInfo": { 531 | "elapsed": 366, 532 | "status": "ok", 533 | "timestamp": 1642348126084, 534 | "user": { 535 | "displayName": "jonas busk", 536 | "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gj7-1gAF7PppPBq1jWtOrRLj_kiVnCZpQDWsCTO4g=s64", 537 | "userId": "13756499934799797810" 538 | }, 539 | "user_tz": -60 540 | }, 541 | "id": "7dO9HdiZyHS6", 542 | "outputId": "4fdf3507-e98f-4882-af44-41b0ec9d0c34" 543 | }, 544 | "outputs": [], 545 | "source": [ 546 | "# evaluate on validation data\n", 547 | "y_pred, y_var = model(x_valid, full_cov=False, noiseless=False)\n", 548 | "evaluate_predictions(y_valid.detach(), y_pred.detach())" 549 | ] 550 | }, 551 | { 552 | "cell_type": "markdown", 553 | "metadata": { 554 | "id": "z2MlHHbuFama" 555 | }, 556 | "source": [ 557 | "We should see that the model achieves a r2 value close to 1, 
indicating the model is able to explain most of the variation in the data, and that the predictions generally correlate with the true target values on both the training and validation data splits." 558 | ] 559 | }, 560 | { 561 | "cell_type": "markdown", 562 | "metadata": { 563 | "id": "4l7Dk00yPIi-" 564 | }, 565 | "source": [ 566 | "## Save trained model\n", 567 | "\n", 568 | "Finally, we save the trained model parameters so we can use the model for analysis later.\n", 569 | "We additionally save some data parameters that will be useful later.\n", 570 | "\n", 571 | "IMPORTANT: If you are running this notebook in Colab, you should make sure to download the saved file as we will need it later in the tutorial series. \n", 572 | "You can find it in the Files section to the left (the small folder icon) after running the code below." 573 | ] 574 | }, 575 | { 576 | "cell_type": "code", 577 | "execution_count": null, 578 | "metadata": { 579 | "colab": { 580 | "base_uri": "https://localhost:8080/" 581 | }, 582 | "executionInfo": { 583 | "elapsed": 5, 584 | "status": "ok", 585 | "timestamp": 1642348126085, 586 | "user": { 587 | "displayName": "jonas busk", 588 | "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gj7-1gAF7PppPBq1jWtOrRLj_kiVnCZpQDWsCTO4g=s64", 589 | "userId": "13756499934799797810" 590 | }, 591 | "user_tz": -60 592 | }, 593 | "id": "ZuoGiqab8H9S", 594 | "outputId": "ed40bbf6-0de6-4474-ba38-b0457b6a6733" 595 | }, 596 | "outputs": [], 597 | "source": [ 598 | "# store data normalisation parameters\n", 599 | "pyro.param(\"data.x_mean\", normaliser_params[\"x_mean\"])\n", 600 | "pyro.param(\"data.x_std\", normaliser_params[\"x_std\"])\n", 601 | "pyro.param(\"data.y_mean\", normaliser_params[\"y_mean\"])\n", 602 | "pyro.param(\"data.y_std\", normaliser_params[\"y_std\"])\n", 603 | "\n", 604 | "# store data range parameters\n", 605 | "pyro.param(\"data.x_min\", x_data.min(dim=0)[0])\n", 606 | "pyro.param(\"data.x_max\", x_data.max(dim=0)[0])\n", 607 | "pyro.param(\"data.y_min\", y_data.min())\n", 608 | "pyro.param(\"data.y_max\", y_data.max())\n", 609 | "\n", 610 | "# store training and validation data\n", 611 | "pyro.param(\"data.x_train\", x_train)\n", 612 | "pyro.param(\"data.y_train\", y_train)\n", 613 | "pyro.param(\"data.x_valid\", x_valid)\n", 614 | "pyro.param(\"data.y_valid\", y_valid)\n", 615 | "\n", 616 | "# save model parameters in a file\n", 617 | "print(pyro.get_param_store().keys())\n", 618 | "if target == \"SEI_thickness(m)\":\n", 619 | " pyro.get_param_store().save(\"sgpr_params_sei.p\")\n", 620 | "if target == \"Capacity loss (%)\":\n", 621 | " pyro.get_param_store().save(\"sgpr_params_cap.p\")\n", 622 | " \n", 623 | "# !!! remember to download the saved file !!!" 
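If you prefer not to use the Files panel, the download can also be triggered from code. A small sketch, not part of the original notebook (it only works inside Colab, and the filename assumes the default SEI thickness target):

```python
# trigger a browser download of the saved parameter file from the Colab runtime
try:
    from google.colab import files
    files.download("sgpr_params_sei.p")  # use "sgpr_params_cap.p" for the capacity loss target
except ImportError:
    print("Not running in Colab: the file is already in the local working directory.")
```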
624 | ] 625 | } 626 | ], 627 | "metadata": { 628 | "colab": { 629 | "authorship_tag": "ABX9TyNYghml2cyAIF05V4NSOkWt", 630 | "collapsed_sections": [], 631 | "name": "train_sgpr.ipynb", 632 | "provenance": [] 633 | }, 634 | "kernelspec": { 635 | "display_name": "Python 3 (ipykernel)", 636 | "language": "python", 637 | "name": "python3" 638 | }, 639 | "language_info": { 640 | "codemirror_mode": { 641 | "name": "ipython", 642 | "version": 3 643 | }, 644 | "file_extension": ".py", 645 | "mimetype": "text/x-python", 646 | "name": "python", 647 | "nbconvert_exporter": "python", 648 | "pygments_lexer": "ipython3", 649 | "version": "3.9.6" 650 | } 651 | }, 652 | "nbformat": 4, 653 | "nbformat_minor": 4 654 | } 655 | -------------------------------------------------------------------------------- /notebooks/6_sensitivity_analysis_of_gpr.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "_tom9RzVKkrV", 7 | "tags": [] 8 | }, 9 | "source": [ 10 | "# Sensitivity analysis with GP regression model\n", 11 | "\n", 12 | "Now that we are familiar with the data and have a trained GP regression model, we can proceed to the actual sensitivity analysis." 13 | ] 14 | }, 15 | { 16 | "cell_type": "markdown", 17 | "metadata": { 18 | "id": "lk0Uot9GK4OU" 19 | }, 20 | "source": [ 21 | "## Dependencies\n", 22 | "\n", 23 | "As in the previous notebooks, we start by importing all dependencies.\n", 24 | "\n", 25 | "If you are in Colab, you need to install the [pyro](https://pyro.ai/) package by uncommenting and running the line `!pip3 install pyro-ppl` below before proceeding." 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": null, 31 | "metadata": { 32 | "executionInfo": { 33 | "elapsed": 515, 34 | "status": "ok", 35 | "timestamp": 1642348244803, 36 | "user": { 37 | "displayName": "jonas busk", 38 | "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gj7-1gAF7PppPBq1jWtOrRLj_kiVnCZpQDWsCTO4g=s64", 39 | "userId": "13756499934799797810" 40 | }, 41 | "user_tz": -60 42 | }, 43 | "id": "-752nQQfK3Xu" 44 | }, 45 | "outputs": [], 46 | "source": [ 47 | "# install dependencies\n", 48 | "# !pip3 install pyro-ppl" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": null, 54 | "metadata": { 55 | "colab": { 56 | "base_uri": "https://localhost:8080/" 57 | }, 58 | "executionInfo": { 59 | "elapsed": 1394, 60 | "status": "ok", 61 | "timestamp": 1642348246658, 62 | "user": { 63 | "displayName": "jonas busk", 64 | "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gj7-1gAF7PppPBq1jWtOrRLj_kiVnCZpQDWsCTO4g=s64", 65 | "userId": "13756499934799797810" 66 | }, 67 | "user_tz": -60 68 | }, 69 | "id": "AH752hkAGieC", 70 | "outputId": "a4165bc3-8268-4c5b-8b8b-968b57ec0d01" 71 | }, 72 | "outputs": [], 73 | "source": [ 74 | "# imports\n", 75 | "from collections import defaultdict\n", 76 | "from pathlib import Path\n", 77 | "\n", 78 | "import matplotlib.pyplot as plt\n", 79 | "from matplotlib.ticker import FormatStrFormatter\n", 80 | "import numpy as np\n", 81 | "import pandas as pd\n", 82 | "import torch\n", 83 | "import pyro\n", 84 | "import pyro.contrib.gp as gp\n", 85 | "\n", 86 | "pyro.set_rng_seed(0)\n", 87 | "print(f\"torch version: {torch.__version__}\")\n", 88 | "print(f\"pyro version: {pyro.__version__}\")" 89 | ] 90 | }, 91 | { 92 | "cell_type": "markdown", 93 | "metadata": { 94 | "id": "aOX52GOBL7U8" 95 | }, 96 | "source": [ 97 | "## Load the dataset and model parameters\n",
98 | "\n", 99 | "We can load the dataset directly from the GitHub URL.\n", 100 | "Alternatively, the dataset can be loaded from a local file." 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": null, 106 | "metadata": { 107 | "executionInfo": { 108 | "elapsed": 1010, 109 | "status": "ok", 110 | "timestamp": 1642348247664, 111 | "user": { 112 | "displayName": "jonas busk", 113 | "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gj7-1gAF7PppPBq1jWtOrRLj_kiVnCZpQDWsCTO4g=s64", 114 | "userId": "13756499934799797810" 115 | }, 116 | "user_tz": -60 117 | }, 118 | "id": "LiOZx8pWMIA5" 119 | }, 120 | "outputs": [], 121 | "source": [ 122 | "# load dataset\n", 123 | "dataset_path = \"https://raw.githubusercontent.com/BIG-MAP/sensitivity_analysis_tutorial/main/data/p2d_sei_10k.csv\"\n", 124 | "# dataset_path = \"data/p2d_sei_10k.csv\" # local\n", 125 | "df = pd.read_csv(dataset_path, index_col=0)\n", 126 | "\n", 127 | "# store the names of the features and the name of the target variable\n", 128 | "features = df.columns[:15].tolist() # use input parameters as features\n", 129 | "target = \"SEI_thickness(m)\" # primary target\n", 130 | "# target = \"Capacity loss (%)\" # secondary target" 131 | ] 132 | }, 133 | { 134 | "cell_type": "markdown", 135 | "metadata": { 136 | "id": "J6Vqgb4eMgYE" 137 | }, 138 | "source": [ 139 | "We also need to load the trained model parameters that we saved in the previous notebook. \n", 140 | "\n", 141 | "If you are running this notebook in Colab, you need to make the parameter file available in the working directory by uploading it to the Files section to the left." 142 | ] 143 | }, 144 | { 145 | "cell_type": "code", 146 | "execution_count": null, 147 | "metadata": { 148 | "colab": { 149 | "base_uri": "https://localhost:8080/" 150 | }, 151 | "executionInfo": { 152 | "elapsed": 42, 153 | "status": "ok", 154 | "timestamp": 1642348247665, 155 | "user": { 156 | "displayName": "jonas busk", 157 | "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gj7-1gAF7PppPBq1jWtOrRLj_kiVnCZpQDWsCTO4g=s64", 158 | "userId": "13756499934799797810" 159 | }, 160 | "user_tz": -60 161 | }, 162 | "id": "56DDLaOlMhow", 163 | "outputId": "bb5e88b1-ba87-4e05-beef-c69ed81dcb65" 164 | }, 165 | "outputs": [], 166 | "source": [ 167 | "pyro.clear_param_store()\n", 168 | "\n", 169 | "if target == \"SEI_thickness(m)\":\n", 170 | " pyro.get_param_store().load(\"sgpr_params_sei.p\")\n", 171 | "if target == \"Capacity loss (%)\":\n", 172 | " pyro.get_param_store().load(\"sgpr_params_cap.p\")\n", 173 | "\n", 174 | "params = pyro.get_param_store()\n", 175 | "params.keys()" 176 | ] 177 | }, 178 | { 179 | "cell_type": "markdown", 180 | "metadata": { 181 | "id": "RW2HoGqfO-hq" 182 | }, 183 | "source": [ 184 | "## Setup model\n", 185 | "\n", 186 | "Setup the model with the trained parameters." 
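As an optional sanity check (not part of the original notebook), once the setup cell below has been run you can confirm that the reconstructed model reproduces roughly the same validation r2 as at the end of the previous notebook, using the stored validation split:

```python
# sanity check: evaluate the reloaded model on the stored validation data
with torch.no_grad():
    mean, _ = model(params["data.x_valid"], full_cov=False, noiseless=False)
    y_valid = params["data.y_valid"]
    ssr = torch.sum((y_valid - mean) ** 2)  # sum of squared residuals
    sst = torch.sum((y_valid - torch.mean(y_valid)) ** 2)  # total sum of squares
    print(f"validation r2: {(1 - ssr / sst).item():.4f}")
```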
187 | ] 188 | }, 189 | { 190 | "cell_type": "code", 191 | "execution_count": null, 192 | "metadata": { 193 | "executionInfo": { 194 | "elapsed": 39, 195 | "status": "ok", 196 | "timestamp": 1642348247666, 197 | "user": { 198 | "displayName": "jonas busk", 199 | "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gj7-1gAF7PppPBq1jWtOrRLj_kiVnCZpQDWsCTO4g=s64", 200 | "userId": "13756499934799797810" 201 | }, 202 | "user_tz": -60 203 | }, 204 | "id": "q_FnFDO8PAXl" 205 | }, 206 | "outputs": [], 207 | "source": [ 208 | "kernel = gp.kernels.RBF(input_dim=params[\"data.x_train\"].shape[1], variance=params[\"kernel.variance\"], lengthscale=params[\"kernel.lengthscale\"])\n", 209 | "model = gp.models.SparseGPRegression(params[\"data.x_train\"], params[\"data.y_train\"], kernel, Xu=params[\"Xu\"], noise=params[\"noise\"])" 210 | ] 211 | }, 212 | { 213 | "cell_type": "markdown", 214 | "metadata": { 215 | "id": "4AUv4iOePQZV" 216 | }, 217 | "source": [ 218 | "## Global sensitivity analysis\n", 219 | "\n", 220 | "Here we compute the average sensitivity of each input parameter $j$ using the validation dataset.\n", 221 | "The sensitivities are computed by taking the gradient of the predicted output $f(\\mathbf{x}_n)$ with respect to each input $x_{n,j}$ averaged over the data:\n", 222 | "\n", 223 | "$$\n", 224 | "s_j^f = \\sqrt{ \\frac{1}{N} \\sum_{n=1}^N \\Big( \\frac{\\partial f(\\mathbf{x}_n)}{\\partial x_{n,j}} \\Big)^2 }\n", 225 | "$$" 226 | ] 227 | }, 228 | { 229 | "cell_type": "code", 230 | "execution_count": null, 231 | "metadata": { 232 | "executionInfo": { 233 | "elapsed": 38, 234 | "status": "ok", 235 | "timestamp": 1642348247666, 236 | "user": { 237 | "displayName": "jonas busk", 238 | "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gj7-1gAF7PppPBq1jWtOrRLj_kiVnCZpQDWsCTO4g=s64", 239 | "userId": "13756499934799797810" 240 | }, 241 | "user_tz": -60 242 | }, 243 | "id": "AXV9iKu8PTaA" 244 | }, 245 | "outputs": [], 246 | "source": [ 247 | "def sa_autograd(model, X, reduce=None): \n", 248 | " \"\"\"Sensitivity analysis of GP regression model with automatic differentiation.\n", 249 | " \n", 250 | " Args:\n", 251 | " model: Gaussian process regression model\n", 252 | " X (tensor): Input data (design matrix)\n", 253 | " reduce (string): method used to reduce the sensitivity result: sum, mean, none.\n", 254 | " \"\"\"\n", 255 | " X.requires_grad = True\n", 256 | " # compute gradient of the mean prediction\n", 257 | " model.zero_grad()\n", 258 | " mean, _ = model(X, full_cov=False, noiseless=False)\n", 259 | " gmean = torch.autograd.grad(mean.sum(), X)[0]\n", 260 | " # compute gradient of the variance prediction\n", 261 | " model.zero_grad()\n", 262 | " _, var = model(X, full_cov=False, noiseless=False)\n", 263 | " gvar = torch.autograd.grad(var.sum(), X)[0]\n", 264 | " X.requires_grad = False\n", 265 | " if reduce == \"sum\":\n", 266 | " return mean, var, torch.sqrt(torch.sum(gmean**2, dim=0)), torch.sqrt(torch.sum(gvar**2, dim=0))\n", 267 | " elif reduce == \"mean\":\n", 268 | " return mean, var, torch.sqrt(torch.mean(gmean**2, dim=0)), torch.sqrt(torch.mean(gvar**2, dim=0))\n", 269 | " else:\n", 270 | " return mean, var, torch.sqrt(gmean**2), torch.sqrt(gvar**2)" 271 | ] 272 | }, 273 | { 274 | "cell_type": "code", 275 | "execution_count": null, 276 | "metadata": { 277 | "executionInfo": { 278 | "elapsed": 39, 279 | "status": "ok", 280 | "timestamp": 1642348247667, 281 | "user": { 282 | "displayName": "jonas busk", 283 | "photoUrl": 
"https://lh3.googleusercontent.com/a-/AOh14Gj7-1gAF7PppPBq1jWtOrRLj_kiVnCZpQDWsCTO4g=s64", 284 | "userId": "13756499934799797810" 285 | }, 286 | "user_tz": -60 287 | }, 288 | "id": "jFBUGhMgQQvx" 289 | }, 290 | "outputs": [], 291 | "source": [ 292 | "def plot_sensitivity_bar(s_mean, s_var, features=None, normalise=False):\n", 293 | " features = list(range(len(s_mean))) if features is None else features\n", 294 | " \n", 295 | " # normalise\n", 296 | " if normalise:\n", 297 | " s_mean = s_mean / s_mean.sum()\n", 298 | " s_var = s_var / s_var.sum()\n", 299 | "\n", 300 | " plt.figure(figsize=(6,3))\n", 301 | " plt.title(\"average sensitivities of the mean prediction\")\n", 302 | " plt.bar(range(len(features)), s_mean)\n", 303 | " plt.xticks(range(len(features)), [f\"x{i}: {f}\" for i,f in enumerate(features)], rotation=90)\n", 304 | " plt.xlabel(\"Feature\"); plt.ylabel(\"Sensitivity\")\n", 305 | " plt.show()\n", 306 | "\n", 307 | " plt.figure(figsize=(6,3))\n", 308 | " plt.title(\"average sensitivities of the variance prediction\")\n", 309 | " plt.bar(range(len(features)), s_var, color=\"C1\")\n", 310 | " plt.xticks(range(len(features)), [f\"x{i}: {f}\" for i,f in enumerate(features)], rotation=90)\n", 311 | " plt.xlabel(\"Feature\"); plt.ylabel(\"Sensitivity\")\n", 312 | " plt.show()" 313 | ] 314 | }, 315 | { 316 | "cell_type": "code", 317 | "execution_count": null, 318 | "metadata": { 319 | "colab": { 320 | "base_uri": "https://localhost:8080/", 321 | "height": 715 322 | }, 323 | "executionInfo": { 324 | "elapsed": 879, 325 | "status": "ok", 326 | "timestamp": 1642348248507, 327 | "user": { 328 | "displayName": "jonas busk", 329 | "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gj7-1gAF7PppPBq1jWtOrRLj_kiVnCZpQDWsCTO4g=s64", 330 | "userId": "13756499934799797810" 331 | }, 332 | "user_tz": -60 333 | }, 334 | "id": "yD9wrxecQalZ", 335 | "outputId": "3b57a57d-a28d-474e-c846-2d3f7e5eb798" 336 | }, 337 | "outputs": [], 338 | "source": [ 339 | "_, _, s_mean, s_var = sa_autograd(model, params[\"data.x_valid\"], reduce=\"mean\")\n", 340 | "\n", 341 | "plot_sensitivity_bar(s_mean, s_var, features, normalise=True)" 342 | ] 343 | }, 344 | { 345 | "cell_type": "markdown", 346 | "metadata": { 347 | "id": "5AZeEkVTSzEj" 348 | }, 349 | "source": [ 350 | "The sensitivities are normalised so they sum to 1 as we are mainly interested in the relative sensitivities.\n", 351 | "\n", 352 | "Notice how only a few of the input parameters seem to have high average sensitivity and thus be important.\n", 353 | "\n", 354 | "If you made note of any particular input parameters while doing the initial data exploration, how does it compare to the sensitivities? \n", 355 | "Do the inputs you noticed correspond to the most important inputs found by the sensitivity analysis?\n", 356 | "\n", 357 | "If you did the optional analysis of the Bayesian linear model, how does the results compare?\n", 358 | "\n", 359 | "Here we used the validation dataset to compute the sensitivities. \n", 360 | "We could also have sampled new inputs in the appropriate range and used that for the sensitivity analysis (since we do not need to know the true outputs in this analysis). 
\n", 361 | "However, since we know the validation data is sampled at random, we would expect to get very similar results.\n", 362 | "\n", 363 | "If you are familiar with automatic relevance determination (ARD), you can try to compute feature importances based on ARD defined as the inverse of the kernel length scale parameters (available in `params[\"kernel.lengthscale\"]`) and compare the result with the global sensitivity analysis above.\n", 364 | "Note that [ARD has been shown to overestimate the importance of nonlinear features](http://proceedings.mlr.press/v89/paananen19a/paananen19a.pdf)." 365 | ] 366 | }, 367 | { 368 | "cell_type": "markdown", 369 | "metadata": { 370 | "id": "RlbcpERZUBe-" 371 | }, 372 | "source": [ 373 | "## Local sensitivity analysis\n", 374 | "\n", 375 | "Looking at the sensitivities averaged over the data is useful for identifying the most important inputs.\n", 376 | "But we might get a better understanding of the data by considering the predictions and sensitivities along the entire range of variation of each input (while keeping all other inputs fixed at their nominal values)." 377 | ] 378 | }, 379 | { 380 | "cell_type": "code", 381 | "execution_count": null, 382 | "metadata": { 383 | "executionInfo": { 384 | "elapsed": 4, 385 | "status": "ok", 386 | "timestamp": 1642348248508, 387 | "user": { 388 | "displayName": "jonas busk", 389 | "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gj7-1gAF7PppPBq1jWtOrRLj_kiVnCZpQDWsCTO4g=s64", 390 | "userId": "13756499934799797810" 391 | }, 392 | "user_tz": -60 393 | }, 394 | "id": "OQdJigPYRlEa" 395 | }, 396 | "outputs": [], 397 | "source": [ 398 | "# helper functions\n", 399 | "def predict_sa(x):\n", 400 | " return sa_autograd(model, x, reduce=None)\n", 401 | "\n", 402 | "def predict_and_plot_1d(d, predict_sa, features, target, x_min, x_max, x_nominal, y_lim=None, n_points=100, figsize=(12,3)):\n", 403 | " # create inputs\n", 404 | " x = x_nominal\n", 405 | " X = x.repeat(n_points, 1)\n", 406 | " xd = torch.linspace(x_min[d], x_max[d], n_points)\n", 407 | " X[:,d] = xd\n", 408 | " # predict point\n", 409 | " mean0, var0, s_mean0, s_var0 = predict_sa(x.unsqueeze(0))\n", 410 | " mean0, var0, s_mean0, s_var0 = mean0.detach(), var0.detach(), s_mean0.detach(), s_var0.detach()\n", 411 | " std0 = var0.sqrt()\n", 412 | " # predict grid\n", 413 | " mean, var, s_mean, s_var = predict_sa(X)\n", 414 | " mean, var, s_mean, s_var = mean.detach(), var.detach(), s_mean.detach(), s_var.detach()\n", 415 | " std = var.sqrt().detach()\n", 416 | " # plot\n", 417 | " plt.figure(figsize=figsize)\n", 418 | " # plot mean prediction with uncertainty\n", 419 | " plt.subplot(121)\n", 420 | " plt.title(\"mean prediction with uncertainty (2*std)\")\n", 421 | " plt.plot(xd.numpy(), mean.numpy())\n", 422 | " plt.fill_between(xd.numpy(), (mean.numpy() - 2.0 * std.numpy()), (mean.numpy() + 2.0 * std.numpy()), color='C0', alpha=0.3)\n", 423 | " plt.axvline(x[d].numpy(), color=\"k\", linewidth=1, label=f\"{mean0.item():.4f} ({std0.item():.4f})\")\n", 424 | " plt.xlim((x_min[d], x_max[d]))\n", 425 | " if y_lim is not None:\n", 426 | " plt.ylim(y_lim)\n", 427 | " plt.xlabel(f\"x{d}: {features[d]}\")\n", 428 | " plt.ylabel(f\"log y: {target}\")\n", 429 | " plt.grid()\n", 430 | " plt.legend(loc=4)\n", 431 | " # plot sensitivity of mean prediction\n", 432 | " plt.subplot(122)\n", 433 | " plt.title(\"sensitivity of mean prediction\")\n", 434 | " plt.plot(xd.numpy(), s_mean[:, d].numpy())\n", 435 | " plt.axvline(x[d].numpy(), color=\"k\", linewidth=1, 
label=f\"{s_mean0[:,d].item():.4f}\")\n", 436 | " plt.xlim((x_min[d], x_max[d]))\n", 437 | " plt.ylim((0,5))\n", 438 | " plt.xlabel(f\"x{d}: {features[d]}\")\n", 439 | " plt.ylabel(\"sensitivity\")\n", 440 | " plt.grid()\n", 441 | " plt.legend(loc=4)" 442 | ] 443 | }, 444 | { 445 | "cell_type": "code", 446 | "execution_count": null, 447 | "metadata": { 448 | "colab": { 449 | "base_uri": "https://localhost:8080/", 450 | "height": 1000 451 | }, 452 | "executionInfo": { 453 | "elapsed": 7433, 454 | "status": "ok", 455 | "timestamp": 1642348255937, 456 | "user": { 457 | "displayName": "jonas busk", 458 | "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gj7-1gAF7PppPBq1jWtOrRLj_kiVnCZpQDWsCTO4g=s64", 459 | "userId": "13756499934799797810" 460 | }, 461 | "user_tz": -60 462 | }, 463 | "id": "ds60GTH-Vqln", 464 | "outputId": "1595491d-f887-4877-d8cf-ef4d2abcf456" 465 | }, 466 | "outputs": [], 467 | "source": [ 468 | "for d in range(len(features)):\n", 469 | " predict_and_plot_1d(\n", 470 | " d,\n", 471 | " predict_sa,\n", 472 | " features,\n", 473 | " target,\n", 474 | " x_min=params[\"data.x_min\"].detach().numpy(),\n", 475 | " x_max=params[\"data.x_max\"].detach().numpy(),\n", 476 | " x_nominal=params[\"data.x_train\"][0].detach(), # the first training point correponds to the nominal values\n", 477 | " y_lim=(params[\"data.y_min\"].item(), params[\"data.y_max\"].item()),\n", 478 | " )" 479 | ] 480 | }, 481 | { 482 | "cell_type": "markdown", 483 | "metadata": { 484 | "id": "_L72a4XKXTXv" 485 | }, 486 | "source": [ 487 | "Some of the prediction curves are almost entirely flat because changing their value does not change the output.\n", 488 | "These correspond to the inputs with low average sensitivity that we identified above.\n", 489 | "\n", 490 | "Maybe you also notice that some inputs seem to affect the output along their entire range while some other inputs only seem to affect the output at some specific range of values (for example only high or low values). \n", 491 | "\n", 492 | "For each of the important inputs, try to characterise the effect they have on the output:\n", 493 | " * Is it linear or nonlinear?\n", 494 | " * Is it sensitive along its entire range of values or not?" 495 | ] 496 | }, 497 | { 498 | "cell_type": "markdown", 499 | "metadata": { 500 | "id": "euGO-YNCZHTt" 501 | }, 502 | "source": [ 503 | "Rather than looking at the inputs in one dimension, we can also plot two inputs against each other in two dimensions." 
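Before moving on to the two-dimensional plots, here is one possible way to try the ARD comparison suggested above. This is a sketch rather than part of the original notebook: it simply inverts the learned RBF length scales stored in the parameter file and normalises them so they can be compared with the normalised gradient-based sensitivities.

```python
# ARD-style feature importance: inverse of the learned kernel length scales,
# normalised to sum to 1 for comparison with the sensitivities above
ard_importance = 1.0 / params["kernel.lengthscale"].detach()
ard_importance = ard_importance / ard_importance.sum()
for name, value in zip(features, ard_importance.tolist()):
    print(f"{name:>15s}: {value:.3f}")
```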
504 | ] 505 | }, 506 | { 507 | "cell_type": "code", 508 | "execution_count": null, 509 | "metadata": { 510 | "executionInfo": { 511 | "elapsed": 11, 512 | "status": "ok", 513 | "timestamp": 1642348255937, 514 | "user": { 515 | "displayName": "jonas busk", 516 | "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gj7-1gAF7PppPBq1jWtOrRLj_kiVnCZpQDWsCTO4g=s64", 517 | "userId": "13756499934799797810" 518 | }, 519 | "user_tz": -60 520 | }, 521 | "id": "wT6Oo8nPV9DY" 522 | }, 523 | "outputs": [], 524 | "source": [ 525 | "def predict_and_plot_2d(d0, d1, predict_sa, features, target, x_min, x_max, x_nominal, y_lim=None, n_points=100, n_levels=21, figsize=(12,10)):\n", 526 | " # create inputs\n", 527 | " x = x_nominal\n", 528 | " X = x.repeat(n_points**2, 1)\n", 529 | " # setup grid\n", 530 | " xd0 = torch.linspace(x_min[d0], x_max[d0], n_points)\n", 531 | " xd1 = torch.linspace(x_min[d1], x_max[d1], n_points)\n", 532 | " grid_xd0, grid_xd1 = torch.meshgrid(xd0, xd1) \n", 533 | " X[:,d0] = grid_xd0.reshape(len(X))\n", 534 | " X[:,d1] = grid_xd1.reshape(len(X))\n", 535 | " # predict point\n", 536 | " mean0, var0, s_mean0, s_var0 = predict_sa(x.unsqueeze(0))\n", 537 | " mean0, var0, s_mean0, s_var0 = mean0.detach(), var0.detach(), s_mean0.detach(), s_var0.detach()\n", 538 | " std0 = var0.sqrt()\n", 539 | " # predict grid\n", 540 | " mean, var, s_mean, s_var = predict_sa(X)\n", 541 | " mean, var, s_mean, s_var = mean.detach(), var.detach(), s_mean.detach(), s_var.detach()\n", 542 | " std = var.sqrt()\n", 543 | "\n", 544 | " s_mean0_d = (s_mean0[:, d0] + s_mean0[:, d1]).item()\n", 545 | " s_var0_d = (s_var0[:, d0] + s_var0[:, d1]).item()\n", 546 | "\n", 547 | " s_mean_d = (s_mean[:, d0] + s_mean[:, d1]).reshape(n_points, n_points)\n", 548 | " s_var_d = (s_var[:, d0] + s_var[:, d1]).reshape(n_points, n_points)\n", 549 | "\n", 550 | " plt.figure(figsize=figsize)\n", 551 | " # plot mean prediction\n", 552 | " ax = plt.subplot(221)\n", 553 | " plt.title(\"mean prediction of log y\")\n", 554 | " if y_lim is None:\n", 555 | " levels = torch.linspace(mean.min().item(), mean.max().item(), n_levels).numpy()\n", 556 | " else:\n", 557 | " levels = torch.linspace(y_lim[0], y_lim[1], n_levels).numpy()\n", 558 | " plt.contourf(grid_xd0.numpy(), grid_xd1.numpy(), mean.reshape(n_points, n_points).numpy(), levels=levels, cmap=\"plasma\")\n", 559 | " plt.axvline(x[d0].numpy(), color=\"k\", linewidth=1, label=f\"{mean0.item():.4f} ({std0.item():.4f})\")\n", 560 | " plt.axhline(x[d1].numpy(), color=\"k\", linewidth=1)\n", 561 | " plt.xlabel(f\"x{d0}: {features[d0]}\"); plt.ylabel(f\"x{d1}: {features[d1]}\")\n", 562 | " plt.colorbar(shrink=0.9)\n", 563 | " ax.yaxis.set_major_formatter(FormatStrFormatter('%6.2f'))\n", 564 | " plt.legend(loc=4)\n", 565 | " # plot uncertainty\n", 566 | " ax = plt.subplot(222)\n", 567 | " plt.title(\"uncertainty (2*std)\")\n", 568 | " levels = torch.linspace(0, 1.0, 21).numpy()\n", 569 | " plt.contourf(grid_xd0.numpy(), grid_xd1.numpy(), 2*std.reshape(n_points, n_points).numpy(), levels=levels, cmap=\"plasma\")\n", 570 | " plt.axvline(x[d0].numpy(), color=\"k\", linewidth=1, label=f\"{std0.item()*2:.4f}\")\n", 571 | " plt.axhline(x[d1].numpy(), color=\"k\", linewidth=1)\n", 572 | " plt.xlabel(f\"x{d0}: {features[d0]}\"); plt.ylabel(f\"x{d1}: {features[d1]}\")\n", 573 | " plt.colorbar(shrink=0.9)\n", 574 | " ax.yaxis.set_major_formatter(FormatStrFormatter('%6.2f'))\n", 575 | " plt.legend(loc=4)\n", 576 | " # plot sensitivity of mean prediction\n", 577 | " ax = plt.subplot(223)\n", 578 | " 
plt.title(\"sensitivity of mean prediction\")\n", 579 | " levels = torch.linspace(0, 5.0, 21).numpy()\n", 580 | " plt.contourf(grid_xd0.numpy(), grid_xd1.numpy(), s_mean_d.numpy(), levels=levels, cmap=\"plasma\")\n", 581 | " plt.axvline(x[d0].numpy(), color=\"k\", linewidth=1, label=f\"{s_mean0_d:.4f}\")\n", 582 | " plt.axhline(x[d1].numpy(), color=\"k\", linewidth=1)\n", 583 | " plt.xlabel(f\"x{d0}: {features[d0]}\"); plt.ylabel(f\"x{d1}: {features[d1]}\")\n", 584 | " plt.colorbar(shrink=0.9)\n", 585 | " ax.yaxis.set_major_formatter(FormatStrFormatter('%6.2f'))\n", 586 | " plt.legend(loc=4)\n", 587 | " # plot sensitivity of uncertainty prediction\n", 588 | " ax = plt.subplot(224)\n", 589 | " plt.title(\"sensitivity of uncertainty prediction\")\n", 590 | " levels = torch.linspace(0, 0.25, 21).numpy()\n", 591 | " plt.contourf(grid_xd0.numpy(), grid_xd1.numpy(), s_var_d.numpy(), levels=levels, cmap=\"plasma\")\n", 592 | " plt.axvline(x[d0].numpy(), color=\"k\", linewidth=1, label=f\"{s_var0_d:.4f}\")\n", 593 | " plt.axhline(x[d1].numpy(), color=\"k\", linewidth=1)\n", 594 | " plt.xlabel(f\"x{d0}: {features[d0]}\"); plt.ylabel(f\"x{d1}: {features[d1]}\")\n", 595 | " plt.colorbar(shrink=0.9)\n", 596 | " ax.yaxis.set_major_formatter(FormatStrFormatter('%6.2f'))\n", 597 | " plt.legend(loc=4)\n", 598 | " plt.tight_layout()\n", 599 | " plt.show()" 600 | ] 601 | }, 602 | { 603 | "cell_type": "code", 604 | "execution_count": null, 605 | "metadata": { 606 | "colab": { 607 | "base_uri": "https://localhost:8080/", 608 | "height": 784 609 | }, 610 | "executionInfo": { 611 | "elapsed": 2241, 612 | "status": "ok", 613 | "timestamp": 1642348258168, 614 | "user": { 615 | "displayName": "jonas busk", 616 | "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gj7-1gAF7PppPBq1jWtOrRLj_kiVnCZpQDWsCTO4g=s64", 617 | "userId": "13756499934799797810" 618 | }, 619 | "user_tz": -60 620 | }, 621 | "id": "qTbdKI4PaVak", 622 | "outputId": "474b0699-59c8-4ec3-eb2a-674fbd685b76" 623 | }, 624 | "outputs": [], 625 | "source": [ 626 | "predict_and_plot_2d(\n", 627 | " 0, 2, # <-- change the input dimensions that are plotted here\n", 628 | " predict_sa,\n", 629 | " features,\n", 630 | " target,\n", 631 | " x_min=params[\"data.x_min\"].detach().numpy(),\n", 632 | " x_max=params[\"data.x_max\"].detach().numpy(),\n", 633 | " x_nominal=params[\"data.x_train\"][0].detach(), # the first training point correponds to the nominal values\n", 634 | " y_lim=(params[\"data.y_min\"].item(), params[\"data.y_max\"].item()),\n", 635 | ")" 636 | ] 637 | }, 638 | { 639 | "cell_type": "markdown", 640 | "metadata": { 641 | "id": "8e0-4h0rbNWN" 642 | }, 643 | "source": [ 644 | "Here we plotted input 0 against input 2.\n", 645 | "You can change the inputs that are plotted in the code above.\n", 646 | "How about for example inputs 8 and 9?\n", 647 | "\n", 648 | "These figures can reveal interesting properties of the data.\n", 649 | "However, even when plotting two inputs against each other along their entire ranges of values, we still need to assume fixed values for all the other inputs.\n", 650 | "But changing the value of some sensitive input could potentially interact with other sensitive inputs.\n", 651 | "Unfortunately, it is difficult to visualize such effects for high dimensional problems like this one.\n", 652 | "In the next notebooks we will try to mitigate this and make exploring the results of the sensitivity analysis more intuitive by creating interactive plots.\n", 653 | "\n", 654 | "As always, we should be aware of the assumptions we 
made in the analysis and keep them in mind when interpreting the results.\n", 655 | "* The validity of the results depends on how well the model fits the data.\n", 656 | "* In this example we made the analysis with regards to the log transformed output and care should be taken if we were to back-transform the results to the original scale since this is a nonlinear transformation and the predictive distribution would no longer be Gaussian." 657 | ] 658 | } 659 | ], 660 | "metadata": { 661 | "colab": { 662 | "authorship_tag": "ABX9TyMs6SXpRqgByQQ8tGM66L2M", 663 | "collapsed_sections": [], 664 | "name": "sensitivity_analysis_of_sgpr.ipynb", 665 | "provenance": [] 666 | }, 667 | "kernelspec": { 668 | "display_name": "Python 3 (ipykernel)", 669 | "language": "python", 670 | "name": "python3" 671 | }, 672 | "language_info": { 673 | "codemirror_mode": { 674 | "name": "ipython", 675 | "version": 3 676 | }, 677 | "file_extension": ".py", 678 | "mimetype": "text/x-python", 679 | "name": "python", 680 | "nbconvert_exporter": "python", 681 | "pygments_lexer": "ipython3", 682 | "version": "3.9.6" 683 | } 684 | }, 685 | "nbformat": 4, 686 | "nbformat_minor": 4 687 | } 688 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | matplotlib 3 | pandas 4 | torch 5 | pyro-ppl 6 | --------------------------------------------------------------------------------