├── .DS_Store
├── assets
│   └── images
│       └── syllabus.png
├── README.md
├── week1
│   └── interpreting_image_classifier.ipynb
└── week4
    └── counterfactual_explanations.ipynb

/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nazneenrajani/interpreting-ml-models-course/HEAD/.DS_Store
--------------------------------------------------------------------------------

/assets/images/syllabus.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nazneenrajani/interpreting-ml-models-course/HEAD/assets/images/syllabus.png
--------------------------------------------------------------------------------

/README.md:
--------------------------------------------------------------------------------
# Interpreting ML Models

This repo contains the projects for the course on [Interpreting ML Models](https://corise.com/course/interpreting-machine-learning-models?utm_source=nazneen).

It is a fast-paced, hands-on, four-week course that covers a breadth of interpretability approaches for deep learning models.

Each week focuses on a different data modality, including vision, text, and tabular data. We also discuss evaluation metrics for interpretability methods and their trade-offs in Week 3.

I am grateful to [Praneeth](https://github.com/praneethd7) for helping with the project content and adding the nice storyline that makes each project super engaging and 10x more fun!

Here's the detailed syllabus for the course.
![course syllabus](/assets/images/syllabus.png)
--------------------------------------------------------------------------------

/week1/interpreting_image_classifier.ipynb:
--------------------------------------------------------------------------------
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": { "id": "7BWQ2MzN6WYe" },
   "source": [
    "# Week 1 - Interpreting Image Classifiers"
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "Welcome to the week 1 project of the Interpreting Machine Learning Models course! We are excited to help you unravel the mysteries behind machine learning algorithms."
   ],
   "metadata": { "id": "9R2KW2e0gIKw" }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Introduction - Week 1 Challenge"
   ],
   "metadata": { "id": "QyVWpAmGlF1A" }
  },
  {
   "cell_type": "markdown",
   "source": [
    "It's 2050, and a mysterious virus has caused the global cat population to become hilariously clumsy and forgetful. People can often be seen watching in amusement as their feline companions stumble into walls, accidentally headbutt their own tails, and knock over everything in their way!\n",
\n", 36 | "\n", 37 | "The situation quickly becomes frustrating for cat owners. Cats can no longer be left alone, as they are prone to forgetting where they put their toys and treats, and are even known to accidentally lock themselves in closets and bathrooms.\n", 38 | "\n", 39 | "To address this problem, cat owners decide to deploy cameras equipped with machine vision to detect and track the activities of these forgetful felines. However, they want to make sure the algorithm can accurately identify cats and doesn't raise false alarms, especially while the owners are napping. To achieve this, they hire a machine learning expert(you, yes you) to interpret the algorithm." 40 | ], 41 | "metadata": { 42 | "id": "Rs_gFaLAizxl" 43 | } 44 | }, 45 | { 46 | "cell_type": "markdown", 47 | "source": [ 48 | "## We need you! [TODO]" 49 | ], 50 | "metadata": { 51 | "id": "e0VxLjMIlJwU" 52 | } 53 | }, 54 | { 55 | "cell_type": "markdown", 56 | "source": [ 57 | "You are given a pre-trained ResNet model that is trained on Imagenet 1k dataset. Your task is to interpret \"Why the ResNet model detects cats?\"\n", 58 | "\n", 59 | "For interpreting a classification task, there are multiple dimensions to choose from (Global vs Local, Model agnostic vs. specific, Inherent vs. post hoc). We will be using a Model agnostic post hoc method and deploy it at a local scale\n", 60 | "\n", 61 | "Specifically, we will use LIME, SHAP, and integrated-gradient in this project. For each of these algorithms, you will be documenting the compute time and visualizing their explanations. At the end of the project, you'll be comparing the three evaluation approaches and assessing which you agree with most. So let's dive in!" 62 | ], 63 | "metadata": { 64 | "id": "OWqyx_kHi_IJ" 65 | } 66 | }, 67 | { 68 | "cell_type": "markdown", 69 | "metadata": { 70 | "id": "dMc1B4o1urim" 71 | }, 72 | "source": [ 73 | "## Setup\n", 74 | "Before we start our mission, lets gets some gear set up. Firstly, lets install the missing packages and import the necessary libraries" 75 | ] 76 | }, 77 | { 78 | "cell_type": "markdown", 79 | "metadata": { 80 | "id": "A1dFx3pWqiXs" 81 | }, 82 | "source": [ 83 | "### Installation of Libraries" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": null, 89 | "metadata": { 90 | "id": "7caVWltC6hUv" 91 | }, 92 | "outputs": [], 93 | "source": [ 94 | "!pip install omnixai\n", 95 | "!pip install dash\n", 96 | "!pip install dash-bootstrap-components\n", 97 | "## For local tunnel to a proxy server \n", 98 | "!npm install localtunnel" 99 | ] 100 | }, 101 | { 102 | "cell_type": "markdown", 103 | "metadata": { 104 | "id": "AOmqoiNrqqv7" 105 | }, 106 | "source": [ 107 | "### Imports" 108 | ] 109 | }, 110 | { 111 | "cell_type": "markdown", 112 | "source": [ 113 | "First, we will import some usual suspects. We will use Pillow Image library to laod/create images. Finally, let us import our main weapon. Let us use [OmniXAI](https://opensource.salesforce.com/OmniXAI/latest/index.html) (Omni eXplainable AI), a Python library for explainable AI (XAI)." 
   ],
   "metadata": { "id": "PhOmsFd1ddZ8" }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": { "id": "7_ZCAeu_6dYH" },
   "outputs": [],
   "source": [
    "## The usual suspects\n",
    "import json\n",
    "import numpy as np\n",
    "import requests\n",
    "import pickle\n",
    "\n",
    "## To build our classifier\n",
    "import torch\n",
    "from torchvision import models, transforms\n",
    "\n",
    "## Pillow library Image function, aliased as PilImage\n",
    "from PIL import Image as PilImage\n",
    "\n",
    "## OmniXAI library to build our explainer\n",
    "from omnixai.preprocessing.image import Resize\n",
    "from omnixai.data.image import Image\n",
    "from omnixai.explainers.vision import VisionExplainer\n",
    "from omnixai.visualization.dashboard import Dashboard"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": { "id": "xIIFbVqpukVN" },
   "source": [
    "## Image Data and Classifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": { "id": "B1tgSQQE6nsH" },
   "outputs": [],
   "source": [
    "## Let's start by loading the image that we want to explain\n",
    "url = \"http://images.cocodataset.org/val2017/000000039769.jpg\"\n",
    "download = requests.get(url, stream=True).raw\n",
    "\n",
    "## TODO: Read the image using Pillow and convert the image into RGB\n",
    "### Hint: Use PilImage to read and convert\n",
    "\n",
    "# image = Image(...)"
   ]
  },
  {
   "cell_type": "code",
   "source": [
    "## TODO: Print the image shape and view the image\n",
    "\n",
    "## Print the image shape\n",
    "# print(...)\n",
    "\n",
    "# Now, let's view it\n",
    "image.to_pil()\n",
    "# Shh! They are napping..."
   ],
   "metadata": { "id": "ddqyZv2mjFYP" },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": { "id": "IEq2_rZY6qj5" },
   "outputs": [],
   "source": [
    "## Before we build our classifier, let's make sure to set up the device.\n",
    "## To run this notebook via GPU: Edit -> Notebook settings -> Hardware accelerator -> GPU\n",
    "## If your GPU is working, device is \"cuda\"\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "device"
   ]
  },
  {
   "cell_type": "code",
   "source": [
    "## TODO: Let's build our classification model. We will use a pre-trained ResNet34 model from the PyTorch torchvision models.\n",
    "## Make sure to load the model onto the device for GPU\n",
    "\n",
    "# model = ..."
   ],
   "metadata": { "id": "K67JdQNgijpU" },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "# Let's get a summary of our model using torchsummary\n",
    "from torchsummary import summary\n",
    "## TODO: Print the model summary\n",
    "### Hint: Use the image shape for input_size\n",
    "# summary(...)"
   ],
   "metadata": { "id": "mk2KQ6e5uQHy" },
   "execution_count": null,
   "outputs": []
  },
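  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One plausible way to fill in the image-loading, model, and summary TODOs above (a sketch, not the only valid answer; the 256x256 target size is our choice, following the OmniXAI tutorials):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Possible solution sketch for the TODOs above (assumptions noted inline)\n",
    "## Read the downloaded bytes with Pillow, convert to RGB, and wrap in an OmniXAI Image\n",
    "image = Resize((256, 256)).transform(Image(PilImage.open(download).convert('RGB')))\n",
    "\n",
    "## Print the image shape and view the image\n",
    "print(image.shape)\n",
    "image.to_pil()\n",
    "\n",
    "## A pre-trained ResNet34 from torchvision, moved to the device\n",
    "model = models.resnet34(pretrained=True).to(device)\n",
    "\n",
    "## Model summary; input_size uses channels-first (C, H, W)\n",
    "# summary(model, input_size=(3, 256, 256))"
   ]
  },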
\n", 240 | "## We will later pass this to our explainer\n", 241 | "classes_url = 'https://gist.githubusercontent.com/DaniFojo/dad37f5bf00ddeb56ed36daf561dbf69/raw/bd006b86300a5886ac7f897a44b0525b75a4b5a1/imagenet_labels.json'\n", 242 | "imagenet_classes = json.loads(requests.get(classes_url).text)\n", 243 | "idx2label = {int(k):v for k,v in imagenet_classes.items()}\n", 244 | "\n", 245 | "first_label = idx2label[next(iter(idx2label))]\n", 246 | "print(f\"The first class label from the ImageNet dataset is: '{first_label}'\")" 247 | ], 248 | "metadata": { 249 | "id": "hLTbu2GHLdT3" 250 | }, 251 | "execution_count": null, 252 | "outputs": [] 253 | }, 254 | { 255 | "cell_type": "markdown", 256 | "metadata": { 257 | "id": "zmyv-bo3u9c8" 258 | }, 259 | "source": [ 260 | "## Buiding our Explainer\n", 261 | "\n", 262 | "To build our Explainer for our model, we will use [Vision Explainer](https://opensource.salesforce.com/OmniXAI/v1.2.3/omnixai.explainers.vision.html) by OmniXAI. The explainer needs some pre-processing and post-processing." 263 | ] 264 | }, 265 | { 266 | "cell_type": "markdown", 267 | "metadata": { 268 | "id": "WnrmP6R1txc0" 269 | }, 270 | "source": [ 271 | "### Pre-processor" 272 | ] 273 | }, 274 | { 275 | "cell_type": "code", 276 | "execution_count": null, 277 | "metadata": { 278 | "id": "wGpN6fIttv1O" 279 | }, 280 | "outputs": [], 281 | "source": [ 282 | "## TODO: Build the pre-processor pipeline for the explainer\n", 283 | "\n", 284 | "# The preprocessing function should convert the image to a Tensor \n", 285 | "# and then Normalise it\n", 286 | "\n", 287 | "# 1. Compose the transformations\n", 288 | "# transform = transforms.Compose([\n", 289 | " ## 1a. write code to convert the image to tensor\n", 290 | " #\n", 291 | " ## 1b. write code to normalize the image\n", 292 | " # \n", 293 | "# ])" 294 | ] 295 | }, 296 | { 297 | "cell_type": "code", 298 | "source": [ 299 | "## TODO: Create the preprocess logic using the transformation built in previous cell\n", 300 | "### Hint: Use torch.stack and load the images to the device\n", 301 | "\n", 302 | "# def preprocess(images):\n", 303 | "# \"\"\"\n", 304 | "# Args:\n", 305 | "# images: Sequence of images to preprocess using the composed \n", 306 | "# transformations created above\n", 307 | "\n", 308 | "# Returns: \n", 309 | "# preprocessed_images: Sequence of preprocessed images\n", 310 | "# \"\"\"\n", 311 | "# preprocessed_images = ...\n", 312 | "# return preprocessed_images" 313 | ], 314 | "metadata": { 315 | "id": "zoVmrOM1vdEA" 316 | }, 317 | "execution_count": null, 318 | "outputs": [] 319 | }, 320 | { 321 | "cell_type": "markdown", 322 | "metadata": { 323 | "id": "ZuVriqWcuVSy" 324 | }, 325 | "source": [ 326 | "### Post-processor\n", 327 | "\n", 328 | "Next, we need to define our post-processing function:" 329 | ] 330 | }, 331 | { 332 | "cell_type": "code", 333 | "execution_count": null, 334 | "metadata": { 335 | "id": "ea4FZctmvow4" 336 | }, 337 | "outputs": [], 338 | "source": [ 339 | "## TODO: Build the post-processor function for the explainer\n", 340 | "# We will apply a softmax function to the logits obtained in the last layer\n", 341 | "# in order to convert the prediction scores to probabilities\n", 342 | "\n", 343 | "# def postprocess(logits):\n", 344 | "# \"\"\"\n", 345 | "# Args:\n", 346 | "# logits: Logits from the last layer of the model\n", 347 | " \n", 348 | "# Returns:\n", 349 | "# postprocessed_outputs: Output from the Softmax layer applied to the logits\n", 350 | "# \"\"\"" 351 | ] 352 | }, 353 | { 354 | 
"cell_type": "markdown", 355 | "metadata": { 356 | "id": "K3Z03RZ2vzVv" 357 | }, 358 | "source": [ 359 | "### Vision Explainer\n", 360 | "Now, construct the explainer using the VisionExplainer class. You'll want to provide it a list of the three explainer types you'd like to try: LIME, SHAP, and integrated gradient. Be sure to check the documentation for the appropriate arguments! See the sample code for VisionExplainer [here](https://opensource.salesforce.com/OmniXAI/v1.2.3/tutorials/vision.html)." 361 | ] 362 | }, 363 | { 364 | "cell_type": "code", 365 | "execution_count": null, 366 | "metadata": { 367 | "id": "-R8gBAgJ6xK8" 368 | }, 369 | "outputs": [], 370 | "source": [ 371 | "#TODO: Build the VisionExplainer by filling in the blanks\n", 372 | "# explainer = VisionExplainer(\n", 373 | "# explainers=[ ...],\n", 374 | "# mode=\"...\",\n", 375 | "# model=...,\n", 376 | "# preprocess=...,\n", 377 | "# postprocess=...,\n", 378 | "\n", 379 | "# )" 380 | ] 381 | }, 382 | { 383 | "cell_type": "markdown", 384 | "metadata": { 385 | "id": "fuRVgBBCwvZX" 386 | }, 387 | "source": [ 388 | "Now, we can generate some explanations for each of the explainers using the explainer.explain() method:" 389 | ] 390 | }, 391 | { 392 | "cell_type": "code", 393 | "execution_count": null, 394 | "metadata": { 395 | "id": "IFR55IOvw5aT" 396 | }, 397 | "outputs": [], 398 | "source": [ 399 | "## Time to generate the explanations\n", 400 | "local_explanations = explainer.explain(Image(\n", 401 | " data=np.concatenate([\n", 402 | " image.to_numpy()]),\n", 403 | " batched=True\n", 404 | "))" 405 | ] 406 | }, 407 | { 408 | "cell_type": "code", 409 | "source": [ 410 | "## Lets write the local_explantions to a pickle file. We will use this in our dashboard\n", 411 | "with open('file.pkl', 'wb') as file:\n", 412 | " # A new file will be created\n", 413 | " pickle.dump(local_explanations, file)" 414 | ], 415 | "metadata": { 416 | "id": "33JVHA0COHCo" 417 | }, 418 | "execution_count": null, 419 | "outputs": [] 420 | }, 421 | { 422 | "cell_type": "markdown", 423 | "metadata": { 424 | "id": "QL66926z7TA1" 425 | }, 426 | "source": [ 427 | "## Dashboard\n", 428 | "Now let's create a Dashboard to visualize our different explainers that we just built" 429 | ] 430 | }, 431 | { 432 | "cell_type": "code", 433 | "source": [ 434 | "### Google Colab hosts the server on remote local. Therefore, localhost on your machine will not lead you to the dashboard\n", 435 | "\n", 436 | "## Open `output.log` from files and use the link to get redirected. \n", 437 | "## : It might take a minute for the log file to show up. 
  {
   "cell_type": "markdown",
   "metadata": { "id": "fuRVgBBCwvZX" },
   "source": [
    "Now, we can generate explanations from each of the explainers using the explainer.explain() method:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": { "id": "IFR55IOvw5aT" },
   "outputs": [],
   "source": [
    "## Time to generate the explanations\n",
    "local_explanations = explainer.explain(Image(\n",
    "    data=np.concatenate([image.to_numpy()]),\n",
    "    batched=True\n",
    "))"
   ]
  },
  {
   "cell_type": "code",
   "source": [
    "## Let's write the local_explanations to a pickle file. We will use this in our dashboard\n",
    "with open('file.pkl', 'wb') as file:\n",
    "    # A new file will be created\n",
    "    pickle.dump(local_explanations, file)"
   ],
   "metadata": { "id": "33JVHA0COHCo" },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": { "id": "QL66926z7TA1" },
   "source": [
    "## Dashboard\n",
    "Now let's create a Dashboard to visualize the different explainers that we just built."
   ]
  },
  {
   "cell_type": "code",
   "source": [
    "### Google Colab hosts the server on a remote machine, so localhost on your machine will not lead you to the dashboard\n",
    "\n",
    "## Open `output.log` from Files and use the link to get redirected.\n",
    "## Note: It might take a minute for the log file to show up. Hit refresh if need be.\n",
    "!nohup npx localtunnel --port 8000 > output.log &"
   ],
   "metadata": { "id": "MJhl4wR77IxI" },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": { "id": "ZKdAir1y7hXD" },
   "outputs": [],
   "source": [
    "##########################################################\n",
    "###### Use the link from previous cell once running ######\n",
    "##########################################################\n",
    "\n",
    "\n",
    "## TODO: Fill in the Dashboard parameters\n",
    "\n",
    "# dashboard = Dashboard(\n",
    "#     instances=...,\n",
    "#     local_explanations=...,\n",
    "#     class_names=...\n",
    "# )\n",
    "\n",
    "\n",
    "## Do not change the port number\n",
    "## Once you open the link, it might take a minute or two for the website to load fully. Be patient :)\n",
    "dashboard.show(port=8000)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": { "id": "LC_0U5yY7jVk" },
   "source": [
    "## Outro\n",
    "\n",
    "🎉 Yay, you did it! Now that we've seen the explanations, you are ready to answer some questions about them!\n",
    "\n",
    "1. What are your thoughts on interpretable AI?\n",
    "2. Compare the various explanations. Which method do you agree with most, and why?\n",
    "3. Do you think the ResNet model is good enough for cat owners?"
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Bonus (Extension)\n",
    "Document the computation time for each explainer: LIME, SHAP, and Integrated Gradients."
   ],
   "metadata": { "id": "IALK072KVawY" }
  },
  {
   "cell_type": "code",
   "source": [
    "## Let's use the Hugging Face cats-vs-dogs dataset\n",
    "!pip install datasets"
   ],
   "metadata": { "id": "T2zBHxZv__wz" },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "## Now we will load 5 cat images from the dataset\n",
    "from datasets import load_dataset\n",
    "\n",
    "## Feel free to change this number. We use 5 images in order not to run out of RAM\n",
    "NUM_IMAGES = 5\n",
    "dataset = load_dataset(\"cats_vs_dogs\")\n",
    "cats_data = dataset['train'][0:NUM_IMAGES]['image']\n",
    "cats_data"
   ],
   "metadata": { "id": "ZQXzI_tpjVeG" },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "## Notice that the image sizes are different.\n",
    "## TODO: Convert them to the same size using transforms.Resize\n",
    "\n",
    "# transform_resize = transforms.Compose([\n",
    "#     transforms.Resize(...)\n",
    "# ])"
   ],
   "metadata": { "id": "ooQTgt7lRa_3" },
   "execution_count": null,
   "outputs": []
  },
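  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One possible way to fill in the resize TODO above and the stacking TODO below (a sketch; the (256, 256) target size is our choice):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Possible solution sketch: resize every cat image and stack them into one array\n",
    "transform_resize = transforms.Compose([\n",
    "    transforms.Resize((256, 256))\n",
    "])\n",
    "\n",
    "cats = np.stack([np.array(transform_resize(cat.convert('RGB'))) for cat in cats_data])"
   ]
  },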
  {
   "cell_type": "code",
   "source": [
    "## Let's apply the resize transform and stack the images\n",
    "# TODO: Use `transform_resize` and `np.stack`\n",
    "\n",
    "# cats = ([... for cat in cats_data])"
   ],
   "metadata": { "id": "p8aaCZLuRufe" },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "## We will use this factory function to create independent explainers\n",
    "def explainer(name):\n",
    "    return VisionExplainer(\n",
    "        explainers=[name],\n",
    "        mode=\"classification\",\n",
    "        model=model,\n",
    "        preprocess=preprocess,\n",
    "        postprocess=postprocess,\n",
    "    )"
   ],
   "metadata": { "id": "WpcEMXkq_yGu" },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "### TODO: Initialize the explainers for LIME, SHAP, and Integrated Gradients\n",
    "# lime = explainer(...)\n",
    "# shap = explainer(...)\n",
    "# ig = explainer(...)"
   ],
   "metadata": { "id": "xOxjBo3Zlpj5" },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "## Let us time the runs. We will use the built-in magic commands in Jupyter\n",
    "%time lime_results = lime.explain(cats)"
   ],
   "metadata": { "id": "I2NN4qhXY8oF" },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "%time shap_results = shap.explain(cats)"
   ],
   "metadata": { "id": "GbXgRgrimWED" },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "%time ig_results = ig.explain(cats)"
   ],
   "metadata": { "id": "RXj-xYLQmZ1M" },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "### Google Colab hosts the server on a remote machine, so localhost on your machine will not lead you to the dashboard\n",
    "\n",
    "## Open `output.log` from Files and use the link to get redirected.\n",
    "## Note: It might take a minute for the log file to show up. Hit refresh if need be.\n",
    "!nohup npx localtunnel --port 8000 > output.log &"
   ],
   "metadata": { "id": "ml_kvswYkOgp" },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "## Combine all results\n",
    "combine_results = lime_results\n",
    "combine_results['shap'] = shap_results['shap']\n",
    "combine_results['ig'] = ig_results['ig']\n",
    "\n",
    "## Let's visualize the results on the Dashboard\n",
    "dashboard = Dashboard(\n",
    "    instances=Image(cats, batched=True),\n",
    "    local_explanations=combine_results,\n",
    "    class_names=idx2label\n",
    ")\n",
    "## Do not change the port\n",
    "## Once you open the link, it might take a minute or two for the website to load fully. Be patient :)\n",
    "dashboard.show(port=8000)"
   ],
   "metadata": { "id": "AsEFIWoCK3KO" },
   "execution_count": null,
   "outputs": []
  },
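  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you prefer programmatic timings over the %time magics above, here is a sketch using time.perf_counter (the timings dict and loop are our own scaffolding):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Optional: collect the explainer timings programmatically\n",
    "import time\n",
    "\n",
    "timings = {}\n",
    "for explainer_name, expl in [('lime', lime), ('shap', shap), ('ig', ig)]:\n",
    "    start = time.perf_counter()\n",
    "    expl.explain(cats)\n",
    "    timings[explainer_name] = time.perf_counter() - start\n",
    "print(timings)"
   ]
  },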
  {
   "cell_type": "markdown",
   "source": [
    "## Final Thoughts🎉\n",
    "\n",
    "Congratulations on finishing the bonus sections. It is an impressive feat!\n",
    "\n",
    "---\n",
    "Please share your observations about the computation time of each explainer and recommend a method based on this and any other relevant factors, such as effectiveness or accuracy. If your recommendation differs from a previous suggestion, please explain the reason for the change."
   ],
   "metadata": { "id": "yz1qeEXWH2Yk" }
  }
 ],
 "metadata": {
  "colab": { "provenance": [] },
  "kernelspec": { "display_name": "Python 3", "name": "python3" },
  "language_info": { "name": "python" },
  "gpuClass": "standard"
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
--------------------------------------------------------------------------------

/week4/counterfactual_explanations.ipynb:
--------------------------------------------------------------------------------
{
 "nbformat": 4,
 "nbformat_minor": 0,
 "metadata": {
  "colab": { "provenance": [], "toc_visible": true },
  "kernelspec": { "name": "python3", "display_name": "Python 3" },
  "language_info": { "name": "python" },
  "accelerator": "GPU",
  "gpuClass": "standard"
 },
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "# 📅 Week 4 - Counterfactual Explanations using Tabular Data\n",
    "#### 🚨 **First things first! Make a copy of this notebook. Your changes will not save unless you create your own copy!** 🚨\n"
   ],
   "metadata": { "id": "ZsEtuxqZFGnk" }
  },
  {
   "cell_type": "markdown",
   "source": [
    "You are a data scientist and risk modeler at a new fintech company that provides personal loans to individuals. As a startup, the company has limited staff. For this reason, the company decides to create a machine learning algorithm that can assess risk and categorize potential borrowers into low- and high-risk categories.\n",
    "\n",
    "However, the company has only a small pool of data from the manual decisions it made. Despite these challenges, your team is determined to develop a state-of-the-art loan default/risk prediction model to differentiate the company from its competitors.\n",
\n", 38 | "\n", 39 | "The stakes are high as there is a risk of losing out on potential customers and damaging the reputation of the company in case of false rejections. Also, you want to inform your customers why exactly were they rejected and suggest suitable actionables they can take to lower their risk and hypothetically get the credit loan. Your team decides that the obvious choice would be to generate counterfactuals using XAI principles and present them to customers. \n", 40 | "\n", 41 | "Therefore, your job as the data scientist is to build a machine learning model that predicts risk on the manual data. To generate and verify if the counterfactuals are reasonable and are plausible." 42 | ], 43 | "metadata": { 44 | "id": "dWXumWi8VgHZ" 45 | } 46 | }, 47 | { 48 | "cell_type": "markdown", 49 | "source": [ 50 | "#📦 Installation and Imports\n", 51 | "We will use same [OmnixAI](https://github.com/salesforce/OmniXAI) python package for generating counterfactuals" 52 | ], 53 | "metadata": { 54 | "id": "nvvLfjevt72K" 55 | } 56 | }, 57 | { 58 | "cell_type": "code", 59 | "source": [ 60 | "!pip install omnixai" 61 | ], 62 | "metadata": { 63 | "id": "SbqfNp19N5k_" 64 | }, 65 | "execution_count": null, 66 | "outputs": [] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "source": [ 71 | "## The Usual Suspects\n", 72 | "import tensorflow as tf\n", 73 | "import itertools\n", 74 | "import numpy as np\n", 75 | "import pandas as pd\n", 76 | "from typing import Any\n", 77 | "from sklearn.model_selection import train_test_split\n", 78 | "from sklearn.preprocessing import StandardScaler,OneHotEncoder, OrdinalEncoder, MinMaxScaler\n", 79 | "from sklearn.compose import ColumnTransformer\n", 80 | "\n", 81 | "\n", 82 | "## For visualization\n", 83 | "import seaborn as sns\n", 84 | "import matplotlib.pyplot as plt\n", 85 | "plt.rc('font', size=14)\n", 86 | "\n", 87 | "## Training pytorch tabular model\n", 88 | "\n", 89 | "import torch\n", 90 | "import torch.nn as nn\n", 91 | "import torch.nn.functional as F\n", 92 | "from torch.utils.data import DataLoader, TensorDataset\n", 93 | "\n", 94 | "## OmniXAI Counterfactual Explainer\n", 95 | "from omnixai.data.tabular import Tabular\n", 96 | "from omnixai.explainers.tabular import CounterfactualExplainer\n", 97 | "from omnixai.explainers.tabular.specific.decision_tree import TreeClassifier" 98 | ], 99 | "metadata": { 100 | "id": "7gqugzHcNJQv" 101 | }, 102 | "execution_count": null, 103 | "outputs": [] 104 | }, 105 | { 106 | "cell_type": "markdown", 107 | "source": [ 108 | "## 💻Dataset - German Credit Risk\n", 109 | "For developing a machine learning model, you decide to use a [real data](https://archive.ics.uci.edu/ml/datasets/South+German+Credit) and understand the factors that influenced the decision of credit risk. For this project, we will be using the German Credit Risk dataset. Download the dataset from [Kaggle here](https://www.kaggle.com/datasets/kabure/german-credit-data-with-risk) or just run the cell below. 
" 110 | ], 111 | "metadata": { 112 | "id": "ckMILS-kNJ7Q" 113 | } 114 | }, 115 | { 116 | "cell_type": "code", 117 | "source": [ 118 | "## Read CSV file and import to dataframe from the url\n", 119 | "url = 'https://drive.google.com/file/d/13vAvup3zgmkPOJ9P4ulRkQ3BQCN7nqe_/view?usp=sharing'\n", 120 | "path = 'https://drive.google.com/uc?export=download&id='+url.split('/')[-2]\n", 121 | "df = pd.read_csv(path, index_col=0)\n", 122 | "df.head()" 123 | ], 124 | "metadata": { 125 | "id": "HRdxCo0ZMGDH" 126 | }, 127 | "execution_count": null, 128 | "outputs": [] 129 | }, 130 | { 131 | "cell_type": "markdown", 132 | "source": [ 133 | "#🔎 Exploratory Data Analysis\n", 134 | "Let us understand this dataset and make it ready for our ML model. The attributes are pretty much self-explanatory. However, there are some `NaN` in the data. Let us check how many columns have missing information by exploring this dataset" 135 | ], 136 | "metadata": { 137 | "id": "IPygv7EavYfk" 138 | } 139 | }, 140 | { 141 | "cell_type": "markdown", 142 | "source": [ 143 | "## Data Cleaning" 144 | ], 145 | "metadata": { 146 | "id": "k_92P7rV-v7M" 147 | } 148 | }, 149 | { 150 | "cell_type": "code", 151 | "source": [ 152 | "df.info()" 153 | ], 154 | "metadata": { 155 | "id": "OoV7mTpsvPq1" 156 | }, 157 | "execution_count": null, 158 | "outputs": [] 159 | }, 160 | { 161 | "cell_type": "markdown", 162 | "source": [ 163 | "Only categorical variables (Dtype - `object`) have missing entries. For these are categorical variables, we will fill the NaN with an additional category called `other`" 164 | ], 165 | "metadata": { 166 | "id": "a9bx0rCJw1nz" 167 | } 168 | }, 169 | { 170 | "cell_type": "code", 171 | "source": [ 172 | "## TODO: For the columns with missing information, fill the NaN with the variable 'other'\n", 173 | "# df.fillna(..., inplace = True)\n", 174 | "# df.head()" 175 | ], 176 | "metadata": { 177 | "id": "sSEIh5G_wAYd" 178 | }, 179 | "execution_count": null, 180 | "outputs": [] 181 | }, 182 | { 183 | "cell_type": "code", 184 | "source": [ 185 | "df.info()" 186 | ], 187 | "metadata": { 188 | "id": "V6FH2yQBan-z" 189 | }, 190 | "execution_count": null, 191 | "outputs": [] 192 | }, 193 | { 194 | "cell_type": "markdown", 195 | "source": [ 196 | "Our dataset now has 1000 rows that are valid with no NaNs. The column `Risk` is the predictor of interest that catergorizes the individual based on different attributes into `good` or `bad`. 
  {
   "cell_type": "markdown",
   "source": [
    "Our dataset now has 1000 valid rows with no NaNs. The column `Risk` is the target of interest: it categorizes the individual, based on the different attributes, as `good` or `bad`. Let us see the number of people with `good` and `bad` risk profiles."
   ],
   "metadata": { "id": "kxo31b8GxcRP" }
  },
  {
   "cell_type": "code",
   "source": [
    "df['Risk'].value_counts().plot(kind='bar')"
   ],
   "metadata": { "id": "at7cpROdpUkf" },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "### Explore Categorical Variables"
   ],
   "metadata": { "id": "ARjbZqlNbWyV" }
  },
  {
   "cell_type": "markdown",
   "source": [
    "#### Convert categorical variables with two classes to binary format\n",
    "In binary format, we assign one class the label `0` and the other class the label `1`."
   ],
   "metadata": { "id": "R02hzJMgAGlz" }
  },
  {
   "cell_type": "code",
   "source": [
    "## TODO: Convert `Risk` and `Sex` to binary variables\n",
    "## Use the .map() method and convert bad -> 0 and good -> 1, and male -> 1 and female -> 0\n",
    "\n",
    "# df['Risk'] = df['Risk'].map(...)\n",
    "# df['Sex'] = df['Sex'].map(...)"
   ],
   "metadata": { "id": "h4haytcZ1ytC" },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "#### Convert categorical variables with >2 classes to ordinal format\n",
    "In ordinal format, we assign a number to each class, starting from 0. The columns are `Job`, `Housing`, `Saving accounts`, `Checking account`, and `Purpose`. Let us explore each of them."
   ],
   "metadata": { "id": "I8dvFHc9AKsW" }
  },
  {
   "cell_type": "code",
   "source": [
    "cat_variables = ['Job', 'Housing', 'Saving accounts', 'Checking account', 'Purpose']\n",
    "fig, ax = plt.subplots(nrows=2, ncols=3, figsize=(14, 10))\n",
    "for i, category in enumerate(cat_variables):\n",
    "    df[category].value_counts().plot(kind='bar', ax=ax[i // 3, i % 3], title=category)\n",
    "fig.delaxes(ax[1, 2])\n",
    "fig.tight_layout()"
   ],
   "metadata": { "id": "UHL9BcimCING" },
   "execution_count": null,
   "outputs": []
  },
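  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One possible way to fill in the earlier .map() TODO (a sketch, following the mapping given in the comment):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Possible solution sketch for the binary conversions above\n",
    "df['Risk'] = df['Risk'].map({'bad': 0, 'good': 1})\n",
    "df['Sex'] = df['Sex'].map({'male': 1, 'female': 0})"
   ]
  },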
  {
   "cell_type": "markdown",
   "source": [
    "Notice that the `Job` column is already represented in an ordinal format. Let us convert the rest of them too. To keep track of the relationship between classes and labels, let us store the class dictionaries for all categories in the `class_to_labels` variable."
   ],
   "metadata": { "id": "ivGhNviccH_G" }
  },
  {
   "cell_type": "code",
   "source": [
    "class_to_labels = {}\n",
    "cat_to_ordinal = ['Housing', 'Saving accounts', 'Checking account', 'Purpose']\n",
    "for category in cat_to_ordinal:\n",
    "    values = df[category].unique()\n",
    "    ids = range(len(values))\n",
    "    cat_dict = dict(zip(values, ids))\n",
    "    df[category] = df[category].map(cat_dict)\n",
    "    class_to_labels[category] = cat_dict"
   ],
   "metadata": { "id": "EbsvcXIRZbg3" },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "df"
   ],
   "metadata": { "id": "cosS-dzXbwXT" },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "class_to_labels"
   ],
   "metadata": { "id": "HTE6Mb3mbLOg" },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Explore Distributions"
   ],
   "metadata": { "id": "a0JBtc77bqtC" }
  },
  {
   "cell_type": "code",
   "source": [
    "sns.displot(\n",
    "    df, x=\"Age\", col=\"Risk\", row=\"Sex\",\n",
    "    binwidth=3, height=5, facet_kws=dict(margin_titles=True)\n",
    ")"
   ],
   "metadata": { "id": "vV_RcFVKYKKB" },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "fig, ax = plt.subplots(figsize=(12, 10))\n",
    "sns.violinplot(data=df, x=\"Risk\", y=\"Credit amount\", hue=\"Housing\", inner=\"box\", linewidth=1, palette='tab10')\n",
    "sns.despine(left=True)"
   ],
   "metadata": { "id": "ch2sRdCeaCeb" },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "## View Correlations"
   ],
   "metadata": { "id": "kXcZH3lDb0_p" }
  },
  {
   "cell_type": "markdown",
   "source": [
    "Now, let us see which attributes are positively or negatively correlated with `Risk`. This works because we already converted `Risk` to a binary variable above."
   ],
   "metadata": { "id": "jOP3h56e12m8" }
  },
  {
   "cell_type": "code",
   "source": [
    "## TODO: Use a `seaborn` heatmap to display the correlations.
 Fill in the `visualize_corr` function\n",
    "## Make sure to display the annotations and a colormap\n",
    "def visualize_corr(df, figsize=(12, 12)):\n",
    "    plt.figure(figsize=figsize)\n",
    "    sns.heatmap(df.corr(), annot=True)\n",
    "    plt.show()"
   ],
   "metadata": { "id": "rsX--rz68cIA" },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "visualize_corr(df)"
   ],
   "metadata": { "id": "ForTmwH44px-" },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 🚨 TODO: Let's build some intuition 🤔\n",
    "\n",
    "Based on the exploratory data analysis above, list the attributes/features that are most influential in deciding whether the `Risk` is good or bad.\n",
    "\n",
    "1. \"List item here\"\n",
    "2. \"List item here\"\n"
   ],
   "metadata": { "id": "2oL7k0xx8USR" }
  },
  {
   "cell_type": "markdown",
   "source": [
    "Answer the following questions:\n",
    "1. List the attributes that make a profile less risky (good)\n",
    "2. List the attributes that make a profile more risky (bad)\n",
    "3. Did you notice any trends? Do they sound reasonable?\n"
   ],
   "metadata": { "id": "mSSVevF46lbJ" }
  },
  {
   "cell_type": "markdown",
   "source": [
    "# 🤖 Machine Learning Model\n",
    "Using the manual data as input, let us build our machine learning model. We will build a neural-network-based classifier in this section. Before we proceed, we need to normalize our numerical variables and split the dataset into `train` and `test` with an 80:20 split."
   ],
   "metadata": { "id": "hmFMSVz15YNa" }
  },
  {
   "cell_type": "code",
   "source": [
    "df"
   ],
   "metadata": { "id": "24ubZ-7JAsZX" },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "## TODO: Scale the numerical variables to [0, 1] using `MinMaxScaler()` from sklearn\n",
    "\n",
    "numeric_columns = ['Age', 'Credit amount', 'Duration']\n",
    "# Scaler = ...\n",
    "# df[numeric_columns] = ...\n"
   ],
   "metadata": { "id": "CZ9qMXLf_Wkj" },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "## Note: You can convert back to the original values using `inverse_transform`\n",
    "Scaler.inverse_transform(df.head()[numeric_columns])"
   ],
   "metadata": { "id": "Fq5ehiS3A-KM" },
   "execution_count": null,
   "outputs": []
  },
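  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Possible ways to fill in the scaling TODO above and the split TODO below (a sketch; the test_size and random_state values are our choices):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Possible solution sketch for the scaling and splitting TODOs\n",
    "Scaler = MinMaxScaler()\n",
    "df[numeric_columns] = Scaler.fit_transform(df[numeric_columns])\n",
    "\n",
    "X = df.drop(columns=['Risk'])\n",
    "Y = df['Risk']\n",
    "x_train, x_test, y_train, y_test = train_test_split(\n",
    "    X, Y, test_size=0.2, random_state=42, stratify=Y\n",
    ")"
   ]
  },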
  {
   "cell_type": "code",
   "source": [
    "## TODO: Split the data into train and test. Initialize the variables X & Y and pass them into the `train_test_split` function\n",
    "## Make sure to fill in the remaining blanks and use a random_state\n",
    "\n",
    "# X = ...\n",
    "# Y = ...\n",
    "# x_train, x_test, y_train, y_test = train_test_split(X, Y, test_size=..., random_state=..., stratify=Y)\n"
   ],
   "metadata": { "id": "ywqe8naAyzIQ" },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "### Here is the function to build and train the neural network model\n",
    "def train_tf_model(x_train, y_train, x_test, y_test):\n",
    "    y_train = tf.keras.utils.to_categorical(y_train, 2)\n",
    "    y_test = tf.keras.utils.to_categorical(y_test, 2)\n",
    "\n",
    "    model = tf.keras.models.Sequential()\n",
    "    ### The input shape matches the number of input variables\n",
    "    model.add(tf.keras.layers.Input(shape=(9,)))\n",
    "    model.add(tf.keras.layers.Dense(units=32, activation=tf.keras.activations.relu))\n",
    "    model.add(tf.keras.layers.Dense(units=32, activation=tf.keras.activations.relu))\n",
    "    model.add(tf.keras.layers.Dense(units=2, activation=tf.keras.activations.softmax))\n",
    "\n",
    "    learning_rate = tf.keras.optimizers.schedules.ExponentialDecay(\n",
    "        initial_learning_rate=0.1,\n",
    "        decay_steps=1,\n",
    "        decay_rate=0.99,\n",
    "        staircase=True\n",
    "    )\n",
    "    optimizer = tf.keras.optimizers.SGD(\n",
    "        learning_rate=learning_rate,\n",
    "        momentum=0.9,\n",
    "        nesterov=True\n",
    "    )\n",
    "    loss = tf.keras.losses.CategoricalCrossentropy()\n",
    "    model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])\n",
    "    model.fit(x_train, y_train, batch_size=64, epochs=10, verbose=0)\n",
    "    train_loss, train_accuracy = model.evaluate(x_train, y_train, batch_size=51, verbose=0)\n",
    "    test_loss, test_accuracy = model.evaluate(x_test, y_test, batch_size=51, verbose=0)\n",
    "\n",
    "    print('Train loss: {:.4f}, train accuracy: {:.4f}'.format(train_loss, train_accuracy))\n",
    "    print('Test loss: {:.4f}, test accuracy: {:.4f}'.format(test_loss, test_accuracy))\n",
    "    return model"
   ],
   "metadata": { "id": "ZDXzxtOQypIg" },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "## Build and train the neural network\n",
    "model_nn = train_tf_model(x_train, y_train, x_test, y_test)"
   ],
   "metadata": { "id": "KdJo_ghb3CS0" },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "model_nn.predict(x_test[:5])"
   ],
   "metadata": { "id": "lLteCoRP0U7Z" },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "# 🔄️ Generate Counterfactuals\n",
    "Now that we have our ML model ready, let us build a CounterfactualExplainer using the OmniXAI library. Read this [paper](https://arxiv.org/ftp/arxiv/papers/1711/1711.00399.pdf) for more information on counterfactuals."
   ],
   "metadata": { "id": "TPmeiveekDsZ" }
  },
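  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One possible way to fill in the Tabular and explainer TODOs in the next two cells (a sketch; we assume x_train/x_test are pandas DataFrames and that the explainer accepts a predict function mapping a Tabular batch to class probabilities):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Possible solution sketch for the two TODO cells below\n",
    "feature_names = x_train.columns.tolist()\n",
    "\n",
    "tabular_data_train = Tabular(data=x_train, feature_columns=feature_names)\n",
    "tabular_data_test = Tabular(data=x_test, feature_columns=feature_names)\n",
    "\n",
    "explainer_nn = CounterfactualExplainer(\n",
    "    training_data=tabular_data_train,\n",
    "    predict_function=lambda z: model_nn.predict(z.to_pd().values)\n",
    ")"
   ]
  },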
  {
   "cell_type": "code",
   "source": [
    "## OmniXAI requires the data to be in a Tabular, Image, or Text format. Let's convert our dataframe to the Tabular format\n",
    "## TODO: Pass the training data and features\n",
    "\n",
    "# feature_names = ...\n",
    "\n",
    "## Used for initializing the explainer\n",
    "# tabular_data_train = Tabular(\n",
    "#     data=...,\n",
    "#     feature_columns=feature_names\n",
    "# )\n",
    "\n",
    "# tabular_data_test = Tabular(\n",
    "#     data=...,\n",
    "#     feature_columns=feature_names\n",
    "# )"
   ],
   "metadata": { "id": "rwWi19shygwf" },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "## TODO: Pass the training data and model to the counterfactual explainer\n",
    "# explainer_nn = CounterfactualExplainer(training_data=...,\n",
    "#                                        predict_function=...)\n"
   ],
   "metadata": { "id": "W4B2rpPO3RXo" },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "## Time to generate some counterfactuals!!!\n",
    "explanations = explainer_nn.explain(tabular_data_test[1])"
   ],
   "metadata": { "id": "hiPu0nkj4qdT" },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "explanations.ipython_plot(index=0)"
   ],
   "metadata": { "id": "47GV6LZ3OcRr" },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "## TODO: Sample a few records in the test data where `Risk` is predicted as `bad` (0) and generate counterfactuals\n",
    "## NOTE: Select a small sample size N if the explainer takes long\n",
    "\n",
    "# N = ...\n",
    "# RANDOM_STATE = ...\n",
    "# tab_indices = y_test[y_test == ...]\\\n",
    "#     .sample(n=..., random_state=...)\\\n",
    "#     .index\\\n",
    "#     .tolist()\n",
    "\n",
    "N = 1\n",
    "RANDOM_STATE = 42\n",
    "tab_indices = y_test[y_test == 0].sample(n=N, random_state=RANDOM_STATE).index.tolist()"
   ],
   "metadata": { "id": "yGRdIezJonGF" },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "explanations_risk_0 = explainer_nn.explain(tabular_data_test.to_pd().loc[tab_indices])"
   ],
   "metadata": { "id": "RtHU7pCxPZbQ" },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "explanations_risk_0.ipython_plot()"
   ],
   "metadata": { "id": "smvj_IG0PwgM" },
   "execution_count": null,
   "outputs": []
  },
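  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For the TODO function below, one plausible way to fill in the blanks: with `difference = actual - predictions`, a false positive (actual 0, predicted 1) gives -1 and a false negative (actual 1, predicted 0) gives +1."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Plausible fill-ins for the blanks in the function below (a sketch)\n",
    "# false_positives = np.where(difference == -1)[0]\n",
    "# false_negatives = np.where(difference == 1)[0]\n",
    "# random_sample_false_positives = np.random.choice(false_positives, size=size)\n",
    "# random_sample_false_negatives = np.random.choice(false_negatives, size=size)"
   ]
  },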
  {
   "cell_type": "code",
   "source": [
    "## TODO: Find sample records in the train & test data where the model made an incorrect prediction, and generate counterfactuals for those predictions\n",
    "## Fill out the missing entries in the function\n",
    "\n",
    "# def find_incorrect_pred_indices(\n",
    "#     model: Any, X: pd.DataFrame, y: pd.Series, size: int = N\n",
    "# ) -> list:\n",
    "#     \"\"\"\n",
    "#     Find the indices of incorrect predictions from a model, to be used for generating\n",
    "#     counterfactual explanations on the dataset the model was trained/evaluated on.\n",
    "#\n",
    "#     Args:\n",
    "#         model: The trained model\n",
    "#         X: The dataset used for model training/evaluation, excluding the target column\n",
    "#         y: The target column values of X\n",
    "#         size: The number of random samples to select from each of the false positives and false negatives.\n",
    "#               The bigger the size, the longer the computation of the counterfactuals.\n",
    "#     \"\"\"\n",
    "#     predictions = model.predict(X).argmax(axis=-1)\n",
    "#     actual = y.values\n",
    "#     difference = actual - predictions\n",
    "#     false_positives = np.where(difference == ...)[0]\n",
    "#     false_negatives = np.where(difference == ...)[0]\n",
    "#     assert size < len(false_positives)\n",
    "#     assert size < len(false_negatives)\n",
    "#     random_sample_false_positives = np.random.choice(..., size=...)\n",
    "#     random_sample_false_negatives = np.random.choice(..., size=...)\n",
    "#     flattened_indices = list(\n",
    "#         itertools.chain.from_iterable((random_sample_false_positives, random_sample_false_negatives))\n",
    "#     )\n",
    "#     tab_indices = X.iloc[flattened_indices].index.tolist()\n",
    "#     return tab_indices"
   ],
   "metadata": { "id": "OtUEhXNS3ToY" },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "tab_indices_train = find_incorrect_pred_indices(model=model_nn, X=x_train, y=y_train, size=1)\n",
    "tab_indices_test = find_incorrect_pred_indices(model=model_nn, X=x_test, y=y_test, size=1)"
   ],
   "metadata": { "id": "yldGHAEKTESU" },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "explanations_train = explainer_nn.explain(tabular_data_train.to_pd().loc[tab_indices_train])\n",
    "explanations_test = explainer_nn.explain(tabular_data_test.to_pd().loc[tab_indices_test])"
   ],
   "metadata": { "id": "3l2UeZiNTb7L" },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "explanations_train.ipython_plot()"
   ],
   "metadata": { "id": "bzvbFEtVTnv_" },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "explanations_test.ipython_plot()"
   ],
   "metadata": { "id": "icXS4jHWTvt9" },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "# 🗝️ Outro\n",
    "Awesome! You made it to the end 👏 Answer the following questions to deepen your understanding.\n",
    "\n",
    "1. Do you think the ML model's predictions are similar to the original manual decisions? Do the predictions align well with your initial intuition?\n",
    "2. We humans have biases in our decision making. These biases might seep into the ML model as it tries to minimize the loss and improve accuracy. With the help of the counterfactuals, did you find any inherent bias in the model?\n",
    "3. Do you think counterfactuals serve as a good tool to explain to customers what they could do to achieve a better risk profile? Give your reasoning."
   ],
   "metadata": { "id": "jwf9Dt2X4S_1" }
  },
  {
   "cell_type": "markdown",
   "source": [
    "# 💰 Bonus - Counterfactuals for text classification\n",
    "Let's apply counterfactuals in the context of text classification. Specifically, we will generate counterfactuals for the sentiment analysis we did in Week 2 on movie reviews.
 We will use the same pretrained `cardiffnlp/twitter-xlm-roberta-base-sentiment` model from Hugging Face. OmniXAI provides the [Polyjuice](https://github.com/tongshuangwu/polyjuice) explainer for counterfactuals. You can find a demonstration of `polyjuice` from Hugging Face [here](https://media1.giphy.com/media/v7yls1pusVAyo2JfPu/giphy.gif?cid=ecf05e47yq0khgmvfuggu2tukla3eavittk2oyxagnd51llz&rid=giphy.gif&ct=g)\n",
    "\n",
    "
" 798 | ], 799 | "metadata": { 800 | "id": "wYudf5tMkGxE" 801 | } 802 | }, 803 | { 804 | "cell_type": "code", 805 | "source": [ 806 | "!pip install transformers==4.6.0 polyjuice_nlp torch omnixai" 807 | ], 808 | "metadata": { 809 | "id": "LIMa1n4ggVXK" 810 | }, 811 | "execution_count": null, 812 | "outputs": [] 813 | }, 814 | { 815 | "cell_type": "code", 816 | "source": [ 817 | "# NLTk is a NLP toolkit that provides helpful lexical resources such as the wordnet (https://wordnet.princeton.edu/) which can be used to find synsets of words. Eg. car <--> automobile\n", 818 | "import nltk\n", 819 | "nltk.download('omw-1.4')" 820 | ], 821 | "metadata": { 822 | "id": "pSLhj9X7dVvG" 823 | }, 824 | "execution_count": null, 825 | "outputs": [] 826 | }, 827 | { 828 | "cell_type": "code", 829 | "source": [ 830 | "import torch\n", 831 | "import transformers\n", 832 | "from polyjuice import Polyjuice\n", 833 | "from transformers import AutoModelForSequenceClassification\n", 834 | "from omnixai.data.text import Text\n", 835 | "from omnixai.explainers.nlp import NLPExplainer" 836 | ], 837 | "metadata": { 838 | "id": "FGjSoWAhBsxI" 839 | }, 840 | "execution_count": null, 841 | "outputs": [] 842 | }, 843 | { 844 | "cell_type": "code", 845 | "source": [ 846 | "## Before we build our transformer, lets make sure to setup the device.\n", 847 | "## To run this notbeook via GPU: Edit -> Notebook settings -> Hardware accelerator -> GPU\n", 848 | "## If your GPU is working, device is \"cuda\"\n", 849 | "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", 850 | "device" 851 | ], 852 | "metadata": { 853 | "id": "p02WE4Yqhc7U" 854 | }, 855 | "execution_count": null, 856 | "outputs": [] 857 | }, 858 | { 859 | "cell_type": "code", 860 | "source": [ 861 | "name = \"cardiffnlp/twitter-xlm-roberta-base-sentiment\" \n", 862 | "# The pre- and post-processing functions\n", 863 | "preprocess = lambda x: x.values\n", 864 | "postprocess = lambda outputs: np.array([[s[\"score\"] for s in ss] for ss in outputs])\n", 865 | "\n", 866 | "\n", 867 | "##TODO: Build pre-trained model for sentiment analysis\n", 868 | "\n", 869 | "# model = transformers.pipeline(\n", 870 | "# 'sentiment-analysis',\n", 871 | "# model=...,\n", 872 | "# return_all_scores=...\n", 873 | "# )" 874 | ], 875 | "metadata": { 876 | "id": "poK9H_P8hXw-" 877 | }, 878 | "execution_count": null, 879 | "outputs": [] 880 | }, 881 | { 882 | "cell_type": "code", 883 | "source": [ 884 | "\n", 885 | "# Build explainer using the NLPExplainer. Use \"polyjuice\" for explainer\n", 886 | "\n", 887 | "# explainer = NLPExplainer(\n", 888 | "# explainers=[...],\n", 889 | "# mode=\"...\",\n", 890 | "# model=...,\n", 891 | "# preprocess=...,\n", 892 | "# postprocess=...\n", 893 | "# )\n" 894 | ], 895 | "metadata": { 896 | "id": "oUkDIRTH8Bqe" 897 | }, 898 | "execution_count": null, 899 | "outputs": [] 900 | }, 901 | { 902 | "cell_type": "code", 903 | "source": [ 904 | "## Remember our classes for sentiment analysis are as follows\n", 905 | "model.model.config.label2id" 906 | ], 907 | "metadata": { 908 | "id": "nHMzJcjv5WBK" 909 | }, 910 | "execution_count": null, 911 | "outputs": [] 912 | }, 913 | { 914 | "cell_type": "markdown", 915 | "source": [ 916 | "Let us use some phrases from movie reviews. 
  {
   "cell_type": "markdown",
   "source": [
    "Let us use some phrases from movie reviews. Feel free to add your own or ask ChatGPT to generate some for you 😉"
   ],
   "metadata": { "id": "e8mr841N9gUG" }
  },
  {
   "cell_type": "code",
   "source": [
    "x = Text([\n",
    "    \"What a great movie!\",\n",
    "    \"The movie had great narration and visuals despite a boring storyline\"\n",
    "])"
   ],
   "metadata": { "id": "oyM6_zZ4kcsf" },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "# Generate the explanations\n",
    "local_explanations = explainer.explain(x)"
   ],
   "metadata": { "id": "yYvOhcohjASj" },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "## View the explanations\n",
    "local_explanations['polyjuice'].ipython_plot(index=1)"
   ],
   "metadata": { "id": "P0Q7UmWKqH0s" },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "This is the final project of the four-week `Interpreting Machine Learning Models` course. We hope you had a fun learning experience. Stay engaged with the community and keep building it stronger 💪"
   ],
   "metadata": { "id": "dfo85YO54-1x" }
  }
 ]
}
--------------------------------------------------------------------------------