├── .github
│   └── workflows
│       └── binder.yml
├── .snoitulos
├── LICENSE
├── R
│   └── Causality_Tutorial_Exercises.ipynb
├── README.md
├── binder
│   └── Dockerfile
├── data
│   ├── Exercise-ANM.csv
│   ├── Exercise-ANM.png
│   └── Exercise-ICP.csv
├── install.R
├── python
│   ├── Causality_Tutorial_Exercises.ipynb
│   └── kerpy
│       ├── README.md
│       └── __init__.py
├── requirements.txt
└── runtime.txt

/.github/workflows/binder.yml:
--------------------------------------------------------------------------------
1 | on:
2 |   workflow_dispatch:
3 | 
4 | jobs:
5 |   binder:
6 |     runs-on: ubuntu-latest
7 |     steps:
8 |       - name: Checkout Code
9 |         uses: actions/checkout@v2
10 |         with:
11 |           ref: ${{ github.event.pull_request.head.sha }}
12 | 
13 |       - name: update jupyter dependencies with repo2docker
14 |         uses: jupyterhub/repo2docker-action@master
15 |         with:
16 |           DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }}
17 |           DOCKER_PASSWORD: ${{ secrets.DOCKER_PASSWORD }}
18 |           BINDER_CACHE: true
19 |           PUBLIC_REGISTRY_CHECK: true
20 | 
--------------------------------------------------------------------------------
/.snoitulos:
--------------------------------------------------------------------------------
1 | R: https://seafile.erda.dk/seafile/f/8832433ded/
2 | V2: https://seafile.erda.dk/seafile/f/7e1a032002/?dl=1
3 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c)
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/R/Causality_Tutorial_Exercises.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "id": "aaQIJ9pj4Rqv"
7 | },
8 | "source": [
9 | "# Causality Tutorial Exercises – R\n",
10 | "\n",
11 | "Contributors: Rune Christiansen, Jonas Peters, Niklas Pfister, Sorawit Saengkyongam, Sebastian Weichwald.\n",
12 | "The MIT License applies; copyright is with the authors.\n",
13 | "Some exercises are adapted from \"Elements of Causal Inference: Foundations and Learning Algorithms\" by J. Peters, D. Janzing and B. Schölkopf.\n"
14 | ]
15 | },
16 | {
17 | "cell_type": "markdown",
18 | "metadata": {
19 | "id": "bobON40QU5hk"
20 | },
21 | "source": [
22 | "# Exercise 1 – Structural Causal Model"
23 | ]
24 | },
25 | {
26 | "cell_type": "markdown",
27 | "metadata": {
28 | "id": "qhShyZbusKxG"
29 | },
30 | "source": [
31 | "Let's first draw a sample from an SCM"
32 | ]
33 | },
34 | {
35 | "cell_type": "code",
36 | "execution_count": null,
37 | "metadata": {
38 | "id": "UAD0hTOf9Sh1"
39 | },
40 | "outputs": [],
41 | "source": [
42 | "set.seed(1)\n",
43 | "\n",
44 | "n <- 200\n",
45 | "C <- rnorm(n)\n",
46 | "A <- 0.8*rnorm(n)\n",
47 | "K <- A + 0.1*rnorm(n)\n",
48 | "X <- C - 2*A + 0.2*rnorm(n)\n",
49 | "F <- 3*X + 0.8*rnorm(n)\n",
50 | "D <- -2*X + 0.5*rnorm(n)\n",
51 | "G <- D + 0.5*rnorm(n)\n",
52 | "Y <- 2*K - D + 0.2*rnorm(n)\n",
53 | "H <- 0.5*Y + 0.1*rnorm(n)\n",
54 | "\n",
55 | "data.obs <- cbind(C, A, K, X, F, D, G, Y, H)"
56 | ]
57 | },
58 | {
59 | "cell_type": "markdown",
60 | "metadata": {},
61 | "source": [
62 | "### (a)\n",
63 | "\n",
64 | "What are the parents and children of $X$ in the above SCM?\n",
65 | "\n",
66 | "Take a pair of variables and think about whether you expect this pair to be dependent\n",
67 | "(at this stage, you can only guess; later you will have tools to know). Check empirically."
68 | ]
69 | },
70 | {
71 | "cell_type": "markdown",
72 | "metadata": {
73 | "jp-MarkdownHeadingCollapsed": true
74 | },
75 | "source": [
76 | "#### Solution"
77 | ]
78 | },
79 | {
80 | "cell_type": "code",
81 | "execution_count": null,
82 | "metadata": {
83 | "id": "8PMvvEeIoKFN"
84 | },
85 | "outputs": [],
86 | "source": []
87 | },
88 | {
89 | "cell_type": "markdown",
90 | "metadata": {},
91 | "source": [
92 | "### (b)\n",
93 | "\n",
94 | "Generate a sample of size 300 from the interventional distribution $P_{\\mathrm{do}(X=\\mathcal{N}(2, 1))}$\n",
95 | "and store the data matrix as `data.int`."
96 | ]
97 | },
98 | {
99 | "cell_type": "markdown",
100 | "metadata": {
101 | "jp-MarkdownHeadingCollapsed": true
102 | },
103 | "source": [
104 | "#### Solution"
105 | ]
106 | },
107 | {
108 | "cell_type": "code",
109 | "execution_count": null,
110 | "metadata": {
111 | "id": "bokBGvsmVCQJ"
112 | },
113 | "outputs": [],
114 | "source": []
115 | },
116 | {
117 | "cell_type": "markdown",
118 | "metadata": {
119 | "id": "l3wOg_4vozpz"
120 | },
121 | "source": [
122 | "### (c)\n",
123 | "\n",
124 | "Do you expect the marginal distribution of $Y$ to be different in both samples?"
125 | ]
126 | },
127 | {
128 | "cell_type": "markdown",
129 | "metadata": {
130 | "jp-MarkdownHeadingCollapsed": true
131 | },
132 | "source": [
133 | "#### Solution"
134 | ]
135 | },
136 | {
137 | "cell_type": "markdown",
138 | "metadata": {
139 | "id": "3paV1bkro6lV"
140 | },
141 | "source": [
142 | "Double-click (or enter) to edit"
143 | ]
144 | },
145 | {
146 | "cell_type": "markdown",
147 | "metadata": {
148 | "id": "CH9Tt444o-RH"
149 | },
150 | "source": [
151 | "### (d)\n",
152 | "\n",
153 | "Do you expect the joint distribution of $(A, Y)$ to be different in both samples?"
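
One possible solution sketch for Exercise 1 (b) and (e) above. It reuses the SCM assignments from the sampling cell and assumes `data.obs` from that cell is in scope; under $\mathrm{do}(X=\mathcal{N}(2,1))$ only the assignment of `X` changes, since the intervention cuts `X` off from its parents.

```r
# Sketch for Exercise 1 (b): sample from the interventional SCM
set.seed(2)
n <- 300
C <- rnorm(n)
A <- 0.8*rnorm(n)
K <- A + 0.1*rnorm(n)
X <- rnorm(n, 2, 1)          # do(X = N(2, 1)): X no longer depends on C and A
F <- 3*X + 0.8*rnorm(n)
D <- -2*X + 0.5*rnorm(n)
G <- D + 0.5*rnorm(n)
Y <- 2*K - D + 0.2*rnorm(n)
H <- 0.5*Y + 0.1*rnorm(n)
data.int <- cbind(C, A, K, X, F, D, G, Y, H)

# Sketch for (e): compare the marginal of Y and the joint of (A, Y)
par(mfrow = c(1, 3))
hist(data.obs[, "Y"], main = "Y observational", xlab = "Y")
hist(data.int[, "Y"], main = "Y interventional", xlab = "Y")
plot(data.obs[, "A"], data.obs[, "Y"], xlab = "A", ylab = "Y")
points(data.int[, "A"], data.int[, "Y"], col = 2)
```
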
154 | ]
155 | },
156 | {
157 | "cell_type": "markdown",
158 | "metadata": {
159 | "jp-MarkdownHeadingCollapsed": true
160 | },
161 | "source": [
162 | "#### Solution"
163 | ]
164 | },
165 | {
166 | "cell_type": "markdown",
167 | "metadata": {
168 | "id": "FJz4fZKEpE4-"
169 | },
170 | "source": [
171 | "Double-click (or enter) to edit"
172 | ]
173 | },
174 | {
175 | "cell_type": "markdown",
176 | "metadata": {
177 | "id": "eZmh_AizpGp-"
178 | },
179 | "source": [
180 | "### (e)\n",
181 | "\n",
182 | "Check your answers to c) and d) empirically."
183 | ]
184 | },
185 | {
186 | "cell_type": "markdown",
187 | "metadata": {
188 | "jp-MarkdownHeadingCollapsed": true
189 | },
190 | "source": [
191 | "#### Solution"
192 | ]
193 | },
194 | {
195 | "cell_type": "code",
196 | "execution_count": null,
197 | "metadata": {
198 | "id": "ZMiVnsjeVC2-"
199 | },
200 | "outputs": [],
201 | "source": []
202 | },
203 | {
204 | "cell_type": "markdown",
205 | "metadata": {
206 | "id": "ZjECw2eiVFjC"
207 | },
208 | "source": [
209 | "# Exercise 2 – Adjusting"
210 | ]
211 | },
212 | {
213 | "cell_type": "markdown",
214 | "metadata": {
215 | "id": "il0b9fnVq-bz"
216 | },
217 | "source": [
218 | "\n",
219 | "![DAG](https://raw.githubusercontent.com/sweichwald/causality-tutorial-exercises/main/data/Exercise-ANM.png)\n",
220 | "\n",
221 | "Suppose we are given a fixed DAG (like the one above)."
222 | ]
223 | },
224 | {
225 | "cell_type": "markdown",
226 | "metadata": {},
227 | "source": [
228 | "### (a)\n",
229 | "What are valid adjustment sets (VAS) used for?"
230 | ]
231 | },
232 | {
233 | "cell_type": "markdown",
234 | "metadata": {
235 | "jp-MarkdownHeadingCollapsed": true
236 | },
237 | "source": [
238 | "#### Solution"
239 | ]
240 | },
241 | {
242 | "cell_type": "markdown",
243 | "metadata": {},
244 | "source": [
245 | "Double-click (or enter) to edit"
246 | ]
247 | },
248 | {
249 | "cell_type": "markdown",
250 | "metadata": {},
251 | "source": [
252 | "### (b)\n",
253 | "\n",
254 | "Assume we want to find a VAS for the causal effect from $X$ to $Y$.\n",
255 | "What are general recipes (plural 😉) for constructing VASs (no proof)?\n",
256 | "Which sets are VAS in the DAG above?"
257 | ]
258 | },
259 | {
260 | "cell_type": "markdown",
261 | "metadata": {
262 | "jp-MarkdownHeadingCollapsed": true
263 | },
264 | "source": [
265 | "#### Solution"
266 | ]
267 | },
268 | {
269 | "cell_type": "markdown",
270 | "metadata": {},
271 | "source": [
272 | "Double-click (or enter) to edit"
273 | ]
274 | },
275 | {
276 | "cell_type": "markdown",
277 | "metadata": {},
278 | "source": [
279 | "### (c)\n",
280 | "\n",
281 | "The following code samples from an SCM. Perform linear regressions using different VAS and compare the regression coefficient against the causal effect from $X$ to $Y$."
282 | ]
283 | },
284 | {
285 | "cell_type": "code",
286 | "execution_count": null,
287 | "metadata": {
288 | "id": "IaJjtEMUoO1I"
289 | },
290 | "outputs": [],
291 | "source": [
292 | "set.seed(1)\n",
293 | "\n",
294 | "n <- 200\n",
295 | "C <- rnorm(n)\n",
296 | "A <- 0.8*rnorm(n)\n",
297 | "K <- A + 1.1*rnorm(n)\n",
298 | "X <- C - 2*A + 0.2*rnorm(n)\n",
299 | "F <- 3*X + 0.8*rnorm(n)\n",
300 | "D <- -2*X + 0.5*rnorm(n)\n",
301 | "G <- D + 0.5*rnorm(n)\n",
302 | "Y <- 2*K - D + 0.2*rnorm(n)\n",
303 | "H <- 0.5*Y + 0.1*rnorm(n)\n",
304 | "\n",
305 | "data.obs <- cbind(C, A, K, X, F, D, G, Y, H)"
306 | ]
307 | },
308 | {
309 | "cell_type": "markdown",
310 | "metadata": {
311 | "jp-MarkdownHeadingCollapsed": true
312 | },
313 | "source": [
314 | "#### Solution"
315 | ]
316 | },
317 | {
318 | "cell_type": "code",
319 | "execution_count": null,
320 | "metadata": {},
321 | "outputs": [],
322 | "source": []
323 | },
324 | {
325 | "cell_type": "markdown",
326 | "metadata": {},
327 | "source": [
328 | "### (d)\n",
329 | "\n",
330 | "Why could it be interesting to have several options for choosing a VAS?"
331 | ]
332 | },
333 | {
334 | "cell_type": "markdown",
335 | "metadata": {
336 | "jp-MarkdownHeadingCollapsed": true
337 | },
338 | "source": [
339 | "#### Solution"
340 | ]
341 | },
342 | {
343 | "cell_type": "markdown",
344 | "metadata": {},
345 | "source": [
346 | "Double-click (or enter) to edit"
347 | ]
348 | },
349 | {
350 | "cell_type": "markdown",
351 | "metadata": {
352 | "id": "xC4_cF0XoQqN"
353 | },
354 | "source": [
355 | "### (e)\n",
356 | "\n",
357 | "If you indeed have access to several VASs, what would you do?"
358 | ]
359 | },
360 | {
361 | "cell_type": "markdown",
362 | "metadata": {
363 | "jp-MarkdownHeadingCollapsed": true
364 | },
365 | "source": [
366 | "#### Solution"
367 | ]
368 | },
369 | {
370 | "cell_type": "markdown",
371 | "metadata": {},
372 | "source": [
373 | "Double-click (or enter) to edit"
374 | ]
375 | },
376 | {
377 | "cell_type": "markdown",
378 | "metadata": {
379 | "id": "LQ7RuuF4rMD6",
380 | "jp-MarkdownHeadingCollapsed": true
381 | },
382 | "source": [
383 | "# Exercise 3 – Independence-based Causal Structure Learning"
384 | ]
385 | },
386 | {
387 | "cell_type": "markdown",
388 | "metadata": {},
389 | "source": [
390 | "### (a)\n",
391 | "\n",
392 | "Assume $P^{X,Y,Z}$ is Markov and faithful wrt. $G$. Assume all (!) conditional independences are\n",
393 | "\n",
394 | "$$\n",
395 | "\\newcommand{\\indep}{{\\,⫫\\,}}\n",
396 | "\\newcommand{\\dep}{\\not{}\\!\\!\\indep}\n",
397 | "$$\n",
398 | "\n",
399 | "$$X \\indep Z \\mid \\emptyset$$\n",
400 | "\n",
401 | "(plus symmetric statements). What is $G$?"
402 | ]
403 | },
404 | {
405 | "cell_type": "markdown",
406 | "metadata": {},
407 | "source": [
408 | "#### Solution"
409 | ]
410 | },
411 | {
412 | "cell_type": "markdown",
413 | "metadata": {
414 | "id": "p21N9AFBrB0o"
415 | },
416 | "source": [
417 | "Double-click (or enter) to edit"
418 | ]
419 | },
420 | {
421 | "cell_type": "markdown",
422 | "metadata": {},
423 | "source": [
424 | "### (b)\n",
425 | "\n",
426 | "Assume $P^{W,X,Y,Z}$ is Markov and faithful wrt. $G$. Assume all (!) conditional independences are\n",
427 | "\n",
428 | "$$\\begin{aligned}\n",
429 | "(Y,Z) &\\indep W \\mid \\emptyset \\\\\n",
430 | "W &\\indep Y \\mid (X,Z) \\\\\n",
431 | "(X,W) &\\indep Y \\mid Z\n",
432 | "\\end{aligned}\n",
433 | "$$\n",
434 | "\n",
435 | "(plus symmetric statements). What is $G$?"
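
Looking back at Exercise 2 (c) above: one possible regression sketch, assuming `data.obs` from the code cell in (c). In that SCM the total causal effect of $X$ on $Y$ is $(-2)\cdot(-1)=2$, transmitted along $X \to D \to Y$; the names `df` and the particular adjustment sets below are my choices for illustration, not the only valid ones.

```r
# Sketch for Exercise 2 (c): compare X coefficients across adjustment sets
df <- as.data.frame(data.obs)
coef(lm(Y ~ X, data = df))["X"]          # no adjustment: biased, backdoor X <- A -> K -> Y is open
coef(lm(Y ~ X + A, data = df))["X"]      # {A} blocks the backdoor path: close to 2
coef(lm(Y ~ X + K, data = df))["X"]      # {K} also blocks it: close to 2
coef(lm(Y ~ X + C + A, data = df))["X"]  # {C, A}, the parents of X: close to 2
coef(lm(Y ~ X + D, data = df))["X"]      # D is a mediator, not a VAS: badly biased
```
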
436 | ]
437 | },
438 | {
439 | "cell_type": "markdown",
440 | "metadata": {},
441 | "source": [
442 | "#### Solution"
443 | ]
444 | },
445 | {
446 | "cell_type": "markdown",
447 | "metadata": {},
448 | "source": [
449 | "Double-click (or enter) to edit"
450 | ]
451 | },
452 | {
453 | "cell_type": "markdown",
454 | "metadata": {
455 | "id": "DtBvbvkNWXSo",
456 | "jp-MarkdownHeadingCollapsed": true
457 | },
458 | "source": [
459 | "# Exercise 4 – Additive Noise Models"
460 | ]
461 | },
462 | {
463 | "cell_type": "markdown",
464 | "metadata": {
465 | "id": "4ONM6ulNwWpk"
466 | },
467 | "source": [
468 | "Set-up required packages:"
469 | ]
470 | },
471 | {
472 | "cell_type": "code",
473 | "execution_count": null,
474 | "metadata": {
475 | "id": "jqnELcs_pI32"
476 | },
477 | "outputs": [],
478 | "source": [
479 | "# set up – not needed when run on mybinder\n",
480 | "# if needed (colab), change FALSE to TRUE and run cell\n",
481 | "if (FALSE) {\n",
482 | " install.packages('dHSIC')\n",
483 | "}"
484 | ]
485 | },
486 | {
487 | "cell_type": "code",
488 | "execution_count": null,
489 | "metadata": {
490 | "id": "vcH5XTGNweBH"
491 | },
492 | "outputs": [],
493 | "source": [
494 | "library(mgcv)\n",
495 | "library(dHSIC)"
496 | ]
497 | },
498 | {
499 | "cell_type": "markdown",
500 | "metadata": {
501 | "id": "hXoqFwV0wiGT"
502 | },
503 | "source": [
504 | "Let's load and plot a real data set:"
505 | ]
506 | },
507 | {
508 | "cell_type": "code",
509 | "execution_count": null,
510 | "metadata": {
511 | "id": "X10gfbouWe7z"
512 | },
513 | "outputs": [],
514 | "source": [
515 | "# Load a real data set\n",
516 | "real.dat <- read.csv('https://raw.githubusercontent.com/sweichwald/causality-tutorial-exercises/main/data/Exercise-ANM.csv')\n",
517 | "Y <- real.dat[, \"Y\"]\n",
518 | "X <- real.dat[, \"X\"]\n",
519 | "\n",
520 | "# Let us plot the data\n",
521 | "par(mfrow=c(1,1))\n",
522 | "plot(X, Y, pch = 19, cex = .8)"
523 | ]
524 | },
525 | {
526 | "cell_type": "markdown",
527 | "metadata": {
528 | "id": "XV-Pvjsqx7Fz"
529 | },
530 | "source": [
531 | "### (a)\n",
532 | "\n",
533 | "Do you believe that $X \\to Y$ or that $X \\gets Y$? Why?"
534 | ]
535 | },
536 | {
537 | "cell_type": "markdown",
538 | "metadata": {},
539 | "source": [
540 | "#### Solution"
541 | ]
542 | },
543 | {
544 | "cell_type": "markdown",
545 | "metadata": {
546 | "id": "QUv9ROdOyCXB"
547 | },
548 | "source": [
549 | "Double-click (or enter) to edit"
550 | ]
551 | },
552 | {
553 | "cell_type": "markdown",
554 | "metadata": {},
555 | "source": [
556 | "### (b)\n",
557 | "\n",
558 | "$$\n",
559 | "\\newcommand{\\indep}{{\\,⫫\\,}}\n",
560 | "\\newcommand{\\dep}{\\not{}\\!\\!\\indep}\n",
561 | "$$\n",
562 | "\n",
563 | "Let us now try to get a more statistical answer. We have heard that we cannot \n",
564 | "have \n",
565 | "$$Y = f(X) + N_Y,\\ N_Y \\indep X$$\n",
566 | "and\n",
567 | "$$X = g(Y) + N_X,\\ N_X \\indep Y$$\n",
568 | "at the same time.\n",
569 | "\n",
570 | "Given a data set over $(X,Y)$,\n",
571 | "we now want to decide for one of the two models. \n",
572 | "\n",
573 | "Come up with a method to do so.\n",
574 | "\n",
575 | "Hints: \n",
576 | "* `gam(B ~ s(A))$residuals` provides residuals when regressing $B$ on $A$. \n",
577 | "* `dhsic.test` (with `method = \"gamma\"`) can be used as an independence test."
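
One possible method for Exercise 4 (b), following the hints above; it assumes `X`, `Y` from the data-loading cell and the `mgcv` and `dHSIC` packages loaded earlier. The residual names are mine for illustration.

```r
# Sketch for Exercise 4 (b): fit both ANM directions and test whether
# the residuals are independent of the putative cause.
residYonX <- gam(Y ~ s(X))$residuals   # candidate model X -> Y
residXonY <- gam(X ~ s(Y))$residuals   # candidate model Y -> X

# small p-value: reject independence, hence reject that direction
dhsic.test(residYonX, X, method = "gamma")$p.value
dhsic.test(residXonY, Y, method = "gamma")$p.value
# prefer the direction whose residuals look independent of the regressor
```
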
578 | ] 579 | }, 580 | { 581 | "cell_type": "markdown", 582 | "metadata": {}, 583 | "source": [ 584 | "#### Solution" 585 | ] 586 | }, 587 | { 588 | "cell_type": "code", 589 | "execution_count": null, 590 | "metadata": { 591 | "id": "gkvF9mjzW4tS" 592 | }, 593 | "outputs": [], 594 | "source": [] 595 | }, 596 | { 597 | "cell_type": "markdown", 598 | "metadata": { 599 | "id": "Ff7xkIzByx8X" 600 | }, 601 | "source": [ 602 | "### (c)\n", 603 | "\n", 604 | "Assume that the error terms are Gaussian with zero mean and variances \n", 605 | "$\\sigma_X^2$ and $\\sigma_Y^2$, respectively.\n", 606 | "The maximum likelihood for DAG G is \n", 607 | "then proportional to \n", 608 | "$-\\log(\\mathrm{var}(R^G_X)) - \\log(\\mathrm{var}(R^G_Y))$,\n", 609 | "where $R^G_X$ and $R^G_Y$ are the residuals obtained from regressing $X$ and $Y$ on \n", 610 | "their parents in $G$, respectively (no proof).\n", 611 | "\n", 612 | "Find the maximum likelihood solution." 613 | ] 614 | }, 615 | { 616 | "cell_type": "markdown", 617 | "metadata": {}, 618 | "source": [ 619 | "#### Solution" 620 | ] 621 | }, 622 | { 623 | "cell_type": "code", 624 | "execution_count": null, 625 | "metadata": { 626 | "id": "fUB-zlgwW6FS" 627 | }, 628 | "outputs": [], 629 | "source": [] 630 | }, 631 | { 632 | "cell_type": "markdown", 633 | "metadata": { 634 | "id": "iPgnXJvlXAc4" 635 | }, 636 | "source": [ 637 | "# Exercise 5 – Invariant Causal Prediction" 638 | ] 639 | }, 640 | { 641 | "cell_type": "markdown", 642 | "metadata": { 643 | "id": "-cuaTYbq09wN" 644 | }, 645 | "source": [ 646 | "### (a)\n", 647 | "\n", 648 | "Generate some observational and interventional data:" 649 | ] 650 | }, 651 | { 652 | "cell_type": "code", 653 | "execution_count": null, 654 | "metadata": { 655 | "id": "4OPBqL5jXHqZ" 656 | }, 657 | "outputs": [], 658 | "source": [ 659 | "# Generate n=1000 observations from the observational distribution\n", 660 | "na <- 1000\n", 661 | "Xa <- rnorm(na)\n", 662 | "Ya <- 1.5*Xa + rnorm(na)\n", 663 | "\n", 664 | "# Generate n=1000 observations from an interventional distribution\n", 665 | "nb <- 1000\n", 666 | "Xb <- rnorm(nb, 2, 1)\n", 667 | "Yb <- 1.5*Xb + rnorm(nb)\n", 668 | "red <- rgb(1,0,0,alpha=0.4)\n", 669 | "blue <- rgb(0,0,1,alpha=0.4)\n", 670 | "\n", 671 | "# plot Y vs X\n", 672 | "plot(Xa,Ya,pch=16,col=blue,xlim=range(c(Xa,Xb)),ylim=range(c(Ya,Yb)),xlab=\"X\",ylab=\"Y\")\n", 673 | "points(Xb,Yb,pch=17,col=red)\n", 674 | "legend(\"topright\",c(\"observational\",\"interventional\"),pch=c(16,17),col=c(blue,red),inset=0.02)" 675 | ] 676 | }, 677 | { 678 | "cell_type": "markdown", 679 | "metadata": { 680 | "id": "uZcSibWjypDR" 681 | }, 682 | "source": [ 683 | "Look at the above plot. Is the predictor $\\{X\\}$ an invariant set, that is (roughly speaking), does $Y \\mid X = x$ have the same distribution in the red and blue data?" 684 | ] 685 | }, 686 | { 687 | "cell_type": "markdown", 688 | "metadata": { 689 | "jp-MarkdownHeadingCollapsed": true 690 | }, 691 | "source": [ 692 | "#### Solution" 693 | ] 694 | }, 695 | { 696 | "cell_type": "markdown", 697 | "metadata": { 698 | "id": "rhnmzEIiyvmt" 699 | }, 700 | "source": [ 701 | "Double-click (or enter) to edit" 702 | ] 703 | }, 704 | { 705 | "cell_type": "markdown", 706 | "metadata": { 707 | "id": "z0PC0Vy01BNc" 708 | }, 709 | "source": [ 710 | "### (b)\n", 711 | "We now consider data over a response and three covariates $X1, X2$, and $X3$\n", 712 | "and try to infer $\\mathrm{pa}(Y)$. To do so, we need to find all sets for which this\n", 713 | "invariance is satisfied." 
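
Returning to Exercise 4 (c) above, before the invariance data is loaded below: a sketch of the maximum likelihood comparison, reusing `X`, `Y` and the `mgcv` fits from (b). For the DAG $X \to Y$, the residual of $X$ after regressing on its (empty) parent set is just the centered $X$ itself, and analogously for $Y \to X$.

```r
# Sketch for Exercise 4 (c): score both DAGs by
# -log(var(R_X)) - log(var(R_Y)) and pick the larger value.
score.XtoY <- -log(var(X)) - log(var(gam(Y ~ s(X))$residuals))  # G: X -> Y
score.YtoX <- -log(var(Y)) - log(var(gam(X ~ s(Y))$residuals))  # G: Y -> X
c(XtoY = score.XtoY, YtoX = score.YtoX)  # the larger score is the ML solution
```
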
714 | ] 715 | }, 716 | { 717 | "cell_type": "code", 718 | "execution_count": null, 719 | "metadata": { 720 | "id": "Al8FBjpXXSDV" 721 | }, 722 | "outputs": [], 723 | "source": [ 724 | "data <- as.matrix(read.csv('https://raw.githubusercontent.com/sweichwald/causality-tutorial-exercises/main/data/Exercise-ICP.csv'))\n", 725 | "pairs(data, col = c(rep(1,140), rep(2,80)))\n", 726 | "\n", 727 | "# The code below plots the residuals versus fitted values for all sets of \n", 728 | "# predictors. \n", 729 | "# extract response and predictors\n", 730 | "Y <- data[,1]\n", 731 | "Xmat <- data[,2:4]\n", 732 | "S <- list( c(1), c(2), c(3), c(1,2), c(1,3), c(2,3), c(1,2,3))\n", 733 | "resid <- fitted <- vector(\"list\", length(S))\n", 734 | "for(i in 1:length(S)){\n", 735 | " modelfit <- lm.fit(Xmat[,S[[i]],drop=FALSE], Y)\n", 736 | " resid[[i]] <- modelfit$residuals\n", 737 | " fitted[[i]] <- modelfit$fitted.values\n", 738 | "}\n", 739 | "env <- c(rep(0,140),rep(1,80))\n", 740 | "par(mfrow=c(2,2))\n", 741 | "red <- rgb(1,0,0,alpha=0.4)\n", 742 | "blue <- rgb(0,0,1,alpha=0.4)\n", 743 | "names <- c(\"X1\", \"X2\", \"X3\", \"X1, X2\", \"X1, X3\", \"X2, X3\", \"X1, X2, X3\")\n", 744 | "plot((1:length(Y))[env==0], Y[env==0], pch=16, col=blue, xlim=c(0,220), ylim=range(Y), xlab=\"index\", ylab=\"Y\", main=\"empty set\")\n", 745 | "points((1:length(Y))[env==1], Y[env==1], pch=17, col=red)\n", 746 | "legend(\"topleft\",c(\"observational\",\"interventional\"),pch=c(16,17),col=c(blue,red),inset=0.02)\n", 747 | "for(i in 1:length(S)){\n", 748 | " plot(fitted[[i]][env==0], resid[[i]][env==0], pch=16, col=blue, xlim=range(fitted[[i]]), ylim=range(resid[[i]]), xlab=\"fitted values\", ylab=\"residuals\", main=names[i])\n", 749 | " points(fitted[[i]][env==1], resid[[i]][env==1], pch=17, col=red)\n", 750 | " legend(\"topleft\",c(\"observational\",\"interventional\"),pch=c(16,17),col=c(blue,red),inset=0.02)\n", 751 | "}\n" 752 | ] 753 | }, 754 | { 755 | "cell_type": "markdown", 756 | "metadata": { 757 | "id": "1GfZKCL7zJve" 758 | }, 759 | "source": [ 760 | "Which of the sets are invariant? (There are two plots with four scatter plots each.)" 761 | ] 762 | }, 763 | { 764 | "cell_type": "markdown", 765 | "metadata": { 766 | "jp-MarkdownHeadingCollapsed": true 767 | }, 768 | "source": [ 769 | "#### Solution" 770 | ] 771 | }, 772 | { 773 | "cell_type": "markdown", 774 | "metadata": { 775 | "id": "j0sgjfRSzWEt" 776 | }, 777 | "source": [ 778 | "Double-click (or enter) to edit" 779 | ] 780 | }, 781 | { 782 | "cell_type": "markdown", 783 | "metadata": { 784 | "id": "AO7tZSjLzMr0" 785 | }, 786 | "source": [ 787 | "### (c)\n", 788 | "What is your best guess for $\\mathrm{pa}(Y)$?" 789 | ] 790 | }, 791 | { 792 | "cell_type": "markdown", 793 | "metadata": { 794 | "jp-MarkdownHeadingCollapsed": true 795 | }, 796 | "source": [ 797 | "#### Solution" 798 | ] 799 | }, 800 | { 801 | "cell_type": "markdown", 802 | "metadata": { 803 | "id": "B6QtA9p9zdD7" 804 | }, 805 | "source": [ 806 | "Double-click (or enter) to edit" 807 | ] 808 | }, 809 | { 810 | "cell_type": "markdown", 811 | "metadata": { 812 | "id": "qU-jIHvX1rRU" 813 | }, 814 | "source": [ 815 | "### (d) \n", 816 | "**(optional, and R only)**\n", 817 | "\n", 818 | "Use the function ICP to check your result." 
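
A sketch for Exercise 5 (d), assuming `Xmat`, `Y` and `env` from the code cell in (b) are in scope; `ICP` comes from the InvariantCausalPrediction package, which is also loaded in the setup cells that follow.

```r
# Sketch for Exercise 5 (d): run Invariant Causal Prediction.
library(InvariantCausalPrediction)  # also loaded in the setup cell below
icp <- ICP(Xmat, Y, ExpInd = env)   # env: 0 = observational, 1 = interventional
print(icp)  # the estimated parental set is the intersection of all accepted invariant sets
```
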
819 | ]
820 | },
821 | {
822 | "cell_type": "code",
823 | "execution_count": null,
824 | "metadata": {
825 | "id": "bc1nr0TgpNrb"
826 | },
827 | "outputs": [],
828 | "source": [
829 | "# set up – not needed when run on mybinder\n",
830 | "# if needed (colab), change FALSE to TRUE and run cell\n",
831 | "if (FALSE) {\n",
832 | " install.packages('InvariantCausalPrediction')\n",
833 | "}"
834 | ]
835 | },
836 | {
837 | "cell_type": "code",
838 | "execution_count": null,
839 | "metadata": {
840 | "id": "FGiVa_SDXTPc"
841 | },
842 | "outputs": [],
843 | "source": [
844 | "library(InvariantCausalPrediction)"
845 | ]
846 | },
847 | {
848 | "cell_type": "markdown",
849 | "metadata": {
850 | "jp-MarkdownHeadingCollapsed": true
851 | },
852 | "source": [
853 | "#### Solution"
854 | ]
855 | },
856 | {
857 | "cell_type": "code",
858 | "execution_count": null,
859 | "metadata": {},
860 | "outputs": [],
861 | "source": []
862 | },
863 | {
864 | "cell_type": "markdown",
865 | "metadata": {},
866 | "source": [
867 | "# Exercise 6 – Confounding and selection bias"
868 | ]
869 | },
870 | {
871 | "cell_type": "markdown",
872 | "metadata": {},
873 | "source": [
874 | "### Generate data\n",
875 | "\n",
876 | "We start by generating data from the following SCM"
877 | ]
878 | },
879 | {
880 | "cell_type": "code",
881 | "execution_count": null,
882 | "metadata": {},
883 | "outputs": [],
884 | "source": [
885 | "generate_data <- function(seed){\n",
886 | " set.seed(seed)\n",
887 | " n <- 200\n",
888 | " V <- rbinom(n, 1, 0.2)\n",
889 | " W <- 3*V + rnorm(n)\n",
890 | " X <- V + rnorm(n)\n",
891 | " Y <- X + W^2 + 1 + rnorm(n)\n",
892 | " Z <- X + Y + rnorm(n)\n",
893 | " data.obs <- data.frame(V=V, W=W, X=X, Y=Y, Z=Z)\n",
894 | " return(data.obs)\n",
895 | "}\n",
896 | "\n",
897 | "data.obs <- generate_data(1)\n",
898 | "\n",
899 | "# Visualize data set\n",
900 | "pairs(data.obs)"
901 | ]
902 | },
903 | {
904 | "cell_type": "markdown",
905 | "metadata": {},
906 | "source": [
907 | "Assume now that we know the causal ordering induced by the SCM and that\n",
908 | "* $X$ is a treatment variable,\n",
909 | "* $Y$ is the response and\n",
910 | "* $(V, W, Z)$ are additional covariates.\n",
911 | "\n",
912 | "Furthermore we will assume a partially linear outcome model, i.e., \n",
913 | "$$Y = \\theta X + g(V, W) + \\epsilon\\quad \\text{with}\\quad\\mathbb{E}[\\epsilon\\mid X, V, W]=0.$$\n",
914 | "\n",
915 | "We are interested in estimating the causal effect of $X$ on $Y$, corresponding to the parameter $\\theta$ in the partially linear model."
916 | ]
917 | },
918 | {
919 | "cell_type": "markdown",
920 | "metadata": {},
921 | "source": [
922 | "### Background: Confounding and selection bias"
923 | ]
924 | },
925 | {
926 | "cell_type": "markdown",
927 | "metadata": {},
928 | "source": [
929 | "Ignoring the causal structure can lead to wrong conclusions. In the following exercise, we will see the two most important types of bias that may occur:\n",
930 | "* **Confounding bias:** Bias arising because of unaccounted variables that have an effect on both treatment and response.\n",
931 | "* **Selection bias:** Bias arising due to conditioning on descendants of the response. This can occur either if we only observe a subset of the entire sample or if we mistakenly include a descendant of the response in the outcome model."
932 | ]
933 | },
934 | {
935 | "cell_type": "markdown",
936 | "metadata": {},
937 | "source": [
938 | "### Exercises"
939 | ]
940 | },
941 | {
942 | "cell_type": "markdown",
943 | "metadata": {},
944 | "source": [
945 | "### (a)\n",
946 | "\n",
947 | "In the code below we fit several different outcome models. Compare the resulting coefficients for $X$. Which regressions appear to lead to unbiased estimates of the causal effect?"
948 | ]
949 | },
950 | {
951 | "cell_type": "code",
952 | "execution_count": null,
953 | "metadata": {},
954 | "outputs": [],
955 | "source": [
956 | "library(gam)\n",
957 | "\n",
958 | "# linear model of Y on X\n",
959 | "lin_YX <- lm(Y ~ X, data=data.obs)\n",
960 | "# linear model of Y on X and V\n",
961 | "lin_YV <- lm(Y ~ X + V, data=data.obs)\n",
962 | "# linear model Y on X and W\n",
963 | "lin_YW <- lm(Y ~ X + W, data=data.obs)\n",
964 | "# gam model of Y on X and s(W)\n",
965 | "gam_YW <- gam(Y ~ X + s(W), data=data.obs)\n",
966 | "# gam model of Y on X, V and s(W)\n",
967 | "gam_YVW <- gam(Y ~ X + V + s(W), data=data.obs)\n",
968 | "# gam model of Y on X, s(W), s(Z)\n",
969 | "gam_YWZ <- gam(Y ~ X + s(W) + s(Z), data=data.obs)\n",
970 | "\n",
971 | "# Print each model\n",
972 | "results <- list(linear_X = unname(coefficients(lin_YX)['X']),\n",
973 | " linear_V = unname(coefficients(lin_YV)['X']),\n",
974 | " linear_W = unname(coefficients(lin_YW)['X']),\n",
975 | " gam_W = unname(coefficients(gam_YW)['X']),\n",
976 | " gam_VW = unname(coefficients(gam_YVW)['X']),\n",
977 | " gam_WZ = unname(coefficients(gam_YWZ)['X']))\n",
978 | "results"
979 | ]
980 | },
981 | {
982 | "cell_type": "markdown",
983 | "metadata": {
984 | "jp-MarkdownHeadingCollapsed": true
985 | },
986 | "source": [
987 | "#### Solution"
988 | ]
989 | },
990 | {
991 | "cell_type": "markdown",
992 | "metadata": {},
993 | "source": [
994 | "Double-click (or enter) to edit"
995 | ]
996 | },
997 | {
998 | "cell_type": "markdown",
999 | "metadata": {},
1000 | "source": [
1001 | "### (b)\n",
1002 | "List all valid adjustment sets for this causal structure."
1003 | ]
1004 | },
1005 | {
1006 | "cell_type": "markdown",
1007 | "metadata": {
1008 | "jp-MarkdownHeadingCollapsed": true
1009 | },
1010 | "source": [
1011 | "#### Solution"
1012 | ]
1013 | },
1014 | {
1015 | "cell_type": "markdown",
1016 | "metadata": {},
1017 | "source": [
1018 | "Double-click (or enter) to edit"
1019 | ]
1020 | },
1021 | {
1022 | "cell_type": "markdown",
1023 | "metadata": {},
1024 | "source": [
1025 | "### (c)\n",
1026 | "Assume now that you only have access to the subset $\\texttt{data.cond}$ constructed in the code snippet below. Use a gam regression Y ~ X + s(W) to estimate the causal effect. What do you observe?"
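
A possible solution sketch for Exercise 6 (c); it constructs `data.cond` exactly as the snippet below does and assumes the `gam` library loaded in (a). The object names `gam_cond` and `data.cond` follow the exercise's conventions.

```r
# Sketch for Exercise 6 (c): regression on the selected subsample.
data.cond <- data.obs[data.obs$Z < 1, ]         # conditioning on Z, a descendant of Y
gam_cond <- gam(Y ~ X + s(W), data = data.cond)
unname(coefficients(gam_cond)['X'])
# selecting on Z induces selection bias: the estimate moves away from theta = 1,
# even though {W} would be a valid adjustment set in the full sample
```
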
1027 | ]
1028 | },
1029 | {
1030 | "cell_type": "code",
1031 | "execution_count": null,
1032 | "metadata": {},
1033 | "outputs": [],
1034 | "source": [
1035 | "data.cond <- data.obs[data.obs$Z<1,]"
1036 | ]
1037 | },
1038 | {
1039 | "cell_type": "markdown",
1040 | "metadata": {
1041 | "jp-MarkdownHeadingCollapsed": true
1042 | },
1043 | "source": [
1044 | "#### Solution"
1045 | ]
1046 | },
1047 | {
1048 | "cell_type": "code",
1049 | "execution_count": null,
1050 | "metadata": {},
1051 | "outputs": [],
1052 | "source": []
1053 | },
1054 | {
1055 | "cell_type": "markdown",
1056 | "metadata": {},
1057 | "source": [
1058 | "Double-click (or enter) to edit"
1059 | ]
1060 | },
1061 | {
1062 | "cell_type": "markdown",
1063 | "metadata": {},
1064 | "source": [
1065 | "# Exercise 7 – Estimating causal effects"
1066 | ]
1067 | },
1068 | {
1069 | "cell_type": "markdown",
1070 | "metadata": {},
1071 | "source": [
1072 | "We use the same data $\\texttt{data.obs}$ as in Exercise 6 and make the same assumptions. In this exercise you will go over the following approaches for estimating the causal effect from $X$ to $Y$:\n",
1073 | "* **Covariate adjustment:** Directly estimate the outcome model based on a valid adjustment set and use it to estimate the causal effect.\n",
1074 | "* **Propensity score matching:** Estimate the propensity score, use it to match samples and then estimate the causal effect based on the matched data.\n",
1075 | "* **Inverse probability weighting:** Estimate the propensity score, use it to weight the samples and then estimate the causal effect based on the weighted sample.\n",
1076 | "* **Double machine learning:** Estimate a regression function for both the propensity model and the outcome model and combine them to estimate the causal effect."
1077 | ]
1078 | },
1079 | {
1080 | "cell_type": "markdown",
1081 | "metadata": {},
1082 | "source": [
1083 | "### (a)\n",
1084 | "In this part of the exercise, we will compute a covariate adjustment estimator using the library $\\texttt{gam}$."
1085 | ]
1086 | },
1087 | {
1088 | "cell_type": "markdown",
1089 | "metadata": {},
1090 | "source": [
1091 | "Implement a function $\\texttt{causal\\_effect\\_adjustment}$ that takes the data $\\texttt{data.obs}$ as input and computes the covariate adjustment estimator with the gam equation Y ~ X + s(W)."
1092 | ]
1093 | },
1094 | {
1095 | "cell_type": "markdown",
1096 | "metadata": {
1097 | "jp-MarkdownHeadingCollapsed": true
1098 | },
1099 | "source": [
1100 | "#### Solution"
1101 | ]
1102 | },
1103 | {
1104 | "cell_type": "code",
1105 | "execution_count": null,
1106 | "metadata": {},
1107 | "outputs": [],
1108 | "source": [
1109 | "library(gam)\n",
1110 | "\n",
1111 | "# Function to estimate causal effect based on covariate adjustment\n",
1112 | "causal_effect_adjustment <- function(data){\n",
1113 | " ### ADD CODE HERE ###\n",
1114 | "}\n",
1115 | "\n",
1116 | "# Estimate causal effect\n",
1117 | "ate_adjustment <- causal_effect_adjustment(data.obs)\n",
1118 | "ate_adjustment"
1119 | ]
1120 | },
1121 | {
1122 | "cell_type": "markdown",
1123 | "metadata": {},
1124 | "source": [
1125 | "### (b)\n",
1126 | "In this part of the exercise, we will compute a propensity matching estimator using the library $\\texttt{MatchIt}$."
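
One way to fill in the `causal_effect_adjustment` skeleton from (a) above, using the gam equation Y ~ X + s(W) as the exercise specifies; this is a sketch, and {W} is only one of the valid adjustment sets from Exercise 6 (b).

```r
# Sketch for Exercise 7 (a): covariate adjustment with gam on the VAS {W}.
causal_effect_adjustment <- function(data){
  fit <- gam(Y ~ X + s(W), data = data)
  unname(coefficients(fit)['X'])
}
causal_effect_adjustment(data.obs)  # should be close to theta = 1
```
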
1127 | ]
1128 | },
1129 | {
1130 | "cell_type": "markdown",
1131 | "metadata": {},
1132 | "source": [
1133 | "Since most matching methods apply only to binary treatments, we first discretize the treatment $X$, compute a binary treatment effect and then backtransform with an adjustment factor. The following code snippet explains this (you can later copy this to your propensity score estimation function)."
1134 | ]
1135 | },
1136 | {
1137 | "cell_type": "code",
1138 | "execution_count": null,
1139 | "metadata": {},
1140 | "outputs": [],
1141 | "source": [
1142 | "library(MatchIt)\n",
1143 | "\n",
1144 | "# Create binary treatment (more complicated matching procedures also exist for continuous treatments)\n",
1145 | "data.matching <- data.obs\n",
1146 | "T <- as.numeric(data.obs$X > median(data.obs$X))\n",
1147 | "upperT <- mean(data.obs$X[T == 1])\n",
1148 | "lowerT <- mean(data.obs$X[T == 0])\n",
1149 | "adjust_factor <- upperT-lowerT\n",
1150 | "data.matching$T <- T\n",
1151 | "print(adjust_factor)\n",
1152 | "\n",
1153 | "# Without confounding the following estimator would be unbiased\n",
1154 | "lmfit <- lm(Y ~ T, data = data.matching)\n",
1155 | "coefficients(lmfit)['T']/adjust_factor"
1156 | ]
1157 | },
1158 | {
1159 | "cell_type": "markdown",
1160 | "metadata": {},
1161 | "source": [
1162 | "We will use the $\\texttt{MatchIt}$ package. First, we need to select an appropriate matching procedure. Consider the following two options; which one is preferable?"
1163 | ]
1164 | },
1165 | {
1166 | "cell_type": "code",
1167 | "execution_count": null,
1168 | "metadata": {},
1169 | "outputs": [],
1170 | "source": [
1171 | "# Create a matching object without matching to check if confounding exists\n",
1172 | "match0 <- matchit(T ~ V + W, data = data.matching,\n",
1173 | " method = NULL, distance = \"glm\")\n",
1174 | "summary(match0)\n",
1175 | "plot(match0, type=\"density\")"
1176 | ]
1177 | },
1178 | {
1179 | "cell_type": "code",
1180 | "execution_count": null,
1181 | "metadata": {},
1182 | "outputs": [],
1183 | "source": [
1184 | "# Matching Option 1: coarsened exact matching (CEM) on W\n",
1185 | "match1 <- matchit(T ~ W, data = data.matching,\n",
1186 | " method = \"cem\", distance=\"glm\")\n",
1187 | "summary(match1)\n",
1188 | "plot(match1, type=\"density\")"
1189 | ]
1190 | },
1191 | {
1192 | "cell_type": "code",
1193 | "execution_count": null,
1194 | "metadata": {},
1195 | "outputs": [],
1196 | "source": [
1197 | "# Matching Option 2: exact matching (matches the covariate V directly)\n",
1198 | "match2 <- matchit(T ~ V, data = data.matching,\n",
1199 | " method = \"exact\")\n",
1200 | "summary(match2)\n",
1201 | "plot(match2, type=\"density\")"
1202 | ]
1203 | },
1204 | {
1205 | "cell_type": "markdown",
1206 | "metadata": {},
1207 | "source": [
1208 | "Use the selected matching procedure to implement a function $\\texttt{causal\\_effect\\_matching}$ that takes the data $\\texttt{data.obs}$ as input and computes the propensity score matching estimator of the causal effect of $X$ on $Y$.\n",
1209 | "\n",
1210 | "*Hint: Use the code at the beginning of the question as the first part of your function.*"
1211 | ]
1212 | },
1213 | {
1214 | "cell_type": "markdown",
1215 | "metadata": {
1216 | "jp-MarkdownHeadingCollapsed": true
1217 | },
1218 | "source": [
1219 | "#### Solution"
1220 | ]
1221 | },
1222 | {
1223 | "cell_type": "code",
1224 | "execution_count": null,
1225 | "metadata": {},
1226 | "outputs": [],
1227 | "source": [
1228 | "# Function to estimate causal effect based on matching\n",
1229 | "causal_effect_matching <- function(data){\n",
1230 | " # Discretize\n",
1231 | " data.matching <- data\n",
1232 | " T <- as.numeric(data$X > median(data$X))\n",
1233 | " upperT <- mean(data$X[T == 1])\n",
1234 | " lowerT <- mean(data$X[T == 0])\n",
1235 | " adjust_factor <- upperT-lowerT\n",
1236 | " data.matching$T <- T\n",
1237 | " \n",
1238 | " ### APPLY MATCHING HERE ###\n",
1239 | " \n",
1240 | " data.matched <- match.data(match2)\n",
1241 | " # Compute causal effect\n",
1242 | " fit_matched <- lm(Y ~ T, data = data.matched, weights=weights)\n",
1243 | " return(unname(coefficients(fit_matched)['T'])/adjust_factor)\n",
1244 | "}\n",
1245 | "\n",
1246 | "# Estimate causal effect\n",
1247 | "ate_matching <- causal_effect_matching(data.obs)\n",
1248 | "ate_matching"
1249 | ]
1250 | },
1251 | {
1252 | "cell_type": "markdown",
1253 | "metadata": {},
1254 | "source": [
1255 | "### (c)\n",
1256 | "In this part of the exercise, we will compute an inverse probability weighting estimator using the library $\\texttt{WeightIt}$."
1257 | ]
1258 | },
1259 | {
1260 | "cell_type": "code",
1261 | "execution_count": null,
1262 | "metadata": {},
1263 | "outputs": [],
1264 | "source": [
1265 | "library(WeightIt)\n",
1266 | "\n",
1267 | "# Fit a weightit object based on the covariates (V, W)\n",
1268 | "weight_obj <- weightit(X ~ V + W, data = data.obs, estimand = \"ATE\", method = \"glm\")\n",
1269 | "weight_obj\n",
1270 | "summary(weight_obj)\n",
1271 | "\n",
1272 | "# Plot a histogram of the weights\n",
1273 | "hist(weight_obj$weights)"
1274 | ]
1275 | },
1276 | {
1277 | "cell_type": "markdown",
1278 | "metadata": {},
1279 | "source": [
1280 | "Use the code block above to implement a function $\\texttt{causal\\_effect\\_weighting}$ that takes the data $\\texttt{data.obs}$ as input and computes the inverse probability weighting estimator."
1281 | ]
1282 | },
1283 | {
1284 | "cell_type": "markdown",
1285 | "metadata": {
1286 | "jp-MarkdownHeadingCollapsed": true
1287 | },
1288 | "source": [
1289 | "#### Solution"
1290 | ]
1291 | },
1292 | {
1293 | "cell_type": "code",
1294 | "execution_count": null,
1295 | "metadata": {},
1296 | "outputs": [],
1297 | "source": [
1298 | "# Function to estimate causal effect based on inverse probability weighting\n",
1299 | "causal_effect_weighting <- function(data){\n",
1300 | " ### ADD CODE HERE ###\n",
1301 | "}\n",
1302 | "\n",
1303 | "# Estimate causal effect\n",
1304 | "ate_weighting <- causal_effect_weighting(data.obs)\n",
1305 | "ate_weighting"
1306 | ]
1307 | },
1308 | {
1309 | "cell_type": "markdown",
1310 | "metadata": {},
1311 | "source": [
1312 | "### (d) \n",
1313 | "In this part of the exercise, we will compute a double machine learning estimator using the library $\\texttt{DoubleML}$."
1314 | ]
1315 | },
1316 | {
1317 | "cell_type": "markdown",
1318 | "metadata": {},
1319 | "source": [
1320 | "Go over the following code and try to understand the individual steps."
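
Before the DoubleML code that follows, here is a possible completion of `causal_effect_weighting` from (c), reusing the `weightit` call shown in that part. In the partially linear model, the X coefficient of a weighted linear regression of Y on X estimates the effect; the helper names are mine.

```r
# Sketch for Exercise 7 (c): inverse probability weighting.
causal_effect_weighting <- function(data){
  # generalized propensity score weights for the continuous treatment X
  w_obj <- weightit(X ~ V + W, data = data, estimand = "ATE", method = "glm")
  # weighted regression of Y on X
  fit <- lm(Y ~ X, data = data, weights = w_obj$weights)
  unname(coefficients(fit)['X'])
}
causal_effect_weighting(data.obs)
```
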
1321 | ]
1322 | },
1323 | {
1324 | "cell_type": "code",
1325 | "execution_count": null,
1326 | "metadata": {},
1327 | "outputs": [],
1328 | "source": [
1329 | "# Load packages (mlr3 packages are required to specify the ML learners)\n",
1330 | "library(DoubleML)\n",
1331 | "library(mlr3)\n",
1332 | "library(mlr3learners)\n",
1333 | "# Suppress output of mlr3 learners during estimation\n",
1334 | "lgr::get_logger(\"mlr3\")$set_threshold(\"warn\")\n",
1335 | "\n",
1336 | "\n",
1337 | "# Function to estimate causal effect based on double machine learning\n",
1338 | "causal_effect_dml <- function(data){\n",
1339 | " # Remove Z as all variables in data will be treated as valid adjustments by DoubleML package\n",
1340 | " data$Z <- NULL\n",
1341 | "\n",
1342 | " # Step 1:\n",
1343 | " # Format the data (this object encodes the causal structure)\n",
1344 | " obj_dml_data = DoubleMLData$new(data, y_col = \"Y\", d_cols = \"X\")\n",
1345 | "\n",
1346 | " # Step 2:\n",
1347 | " # Learner for Y given covariates (V, W) - using random forest from ranger (other learners are possible)\n",
1348 | " ml_l = lrn(\"regr.ranger\", num.trees = 100, mtry = 2, min.node.size = 2, max.depth = 5)\n",
1349 | " # Learner for X given covariates (V, W) - using random forest from ranger (other learners are possible)\n",
1350 | " ml_m = lrn(\"regr.ranger\", num.trees = 100, mtry = 2, min.node.size = 2, max.depth = 5)\n",
1351 | "\n",
1352 | " # Step 3:\n",
1353 | " # Setup DML task\n",
1354 | " doubleml_plr = DoubleMLPLR$new(obj_dml_data,\n",
1355 | " ml_l, ml_m,\n",
1356 | " n_folds = 2,\n",
1357 | " score = \"partialling out\")\n",
1358 | "\n",
1359 | " # Fit DML\n",
1360 | " doubleml_plr$fit()\n",
1361 | " # you can also look at: doubleml_plr$summary()\n",
1362 | "\n",
1363 | " return(unname(doubleml_plr$all_coef[1]))\n",
1364 | "}\n",
1365 | "\n",
1366 | "# Estimate causal effect\n",
1367 | "ate_dml <- causal_effect_dml(data.obs)"
1368 | ]
1369 | },
1370 | {
1371 | "cell_type": "markdown",
1372 | "metadata": {},
1373 | "source": [
1374 | "### (e)\n",
1375 | "In this part of the exercise, we will compare all estimators using a simulation study."
1376 | ]
1377 | },
1378 | {
1379 | "cell_type": "markdown",
1380 | "metadata": {},
1381 | "source": [
1382 | "Run $10$ repetitions of a simulation study, where in each step you create a new data set using the $\\texttt{generate\\_data}$ function and then apply all of the estimators constructed in the exercises above. Create a boxplot of the results."
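
A sketch for (e), assuming all four `causal_effect_*` functions above have been completed; `n_reps`, `res` and `data.sim` are names chosen here for illustration.

```r
# Sketch for Exercise 7 (e): small simulation study over 10 repetitions.
n_reps <- 10
res <- t(sapply(1:n_reps, function(seed){
  data.sim <- generate_data(seed)  # fresh data set per repetition
  c(adjustment = causal_effect_adjustment(data.sim),
    matching   = causal_effect_matching(data.sim),
    weighting  = causal_effect_weighting(data.sim),
    dml        = causal_effect_dml(data.sim))
}))
boxplot(res, main = "Causal effect estimates", ylab = "estimate")
abline(h = 1, lty = 2)             # true effect theta = 1
```
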
1383 | ] 1384 | }, 1385 | { 1386 | "cell_type": "markdown", 1387 | "metadata": { 1388 | "jp-MarkdownHeadingCollapsed": true 1389 | }, 1390 | "source": [ 1391 | "#### Solution" 1392 | ] 1393 | }, 1394 | { 1395 | "cell_type": "code", 1396 | "execution_count": null, 1397 | "metadata": {}, 1398 | "outputs": [], 1399 | "source": [] 1400 | } 1401 | ], 1402 | "metadata": { 1403 | "colab": { 1404 | "collapsed_sections": [], 1405 | "name": "Causality Tutorial Exercises – R", 1406 | "provenance": [], 1407 | "toc_visible": true 1408 | }, 1409 | "kernelspec": { 1410 | "display_name": "R", 1411 | "language": "R", 1412 | "name": "ir" 1413 | }, 1414 | "language_info": { 1415 | "codemirror_mode": "r", 1416 | "file_extension": ".r", 1417 | "mimetype": "text/x-r-source", 1418 | "name": "R", 1419 | "pygments_lexer": "r", 1420 | "version": "4.1.2" 1421 | } 1422 | }, 1423 | "nbformat": 4, 1424 | "nbformat_minor": 4 1425 | } 1426 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Causality Tutorial Exercises 2 | 3 | Contributors: 4 | Rune Christiansen, 5 | Jonas Peters, 6 | Niklas Pfister, 7 | Sorawit Saengkyongam, 8 | Sebastian Weichwald. 9 | The MIT License applies; copyright is with the authors. 10 | Some exercises are adapted from 11 | "Elements of Causal Inference: Foundations and Learning Algorithms" 12 | by J. Peters, D. Janzing and B. Schölkopf. 13 | 14 | 15 | ## Python 16 | 17 | __Launch a 18 | [![badge](https://img.shields.io/badge/python-binder-F5A252.svg?logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAFkAAABZCAMAAABi1XidAAAB8lBMVEX///9XmsrmZYH1olJXmsr1olJXmsrmZYH1olJXmsr1olJXmsrmZYH1olL1olJXmsr1olJXmsrmZYH1olL1olJXmsrmZYH1olJXmsr1olL1olJXmsrmZYH1olL1olJXmsrmZYH1olL1olL0nFf1olJXmsrmZYH1olJXmsq8dZb1olJXmsrmZYH1olJXmspXmspXmsr1olL1olJXmsrmZYH1olJXmsr1olL1olJXmsrmZYH1olL1olLeaIVXmsrmZYH1olL1olL1olJXmsrmZYH1olLna31Xmsr1olJXmsr1olJXmsrmZYH1olLqoVr1olJXmsr1olJXmsrmZYH1olL1olKkfaPobXvviGabgadXmsqThKuofKHmZ4Dobnr1olJXmsr1olJXmspXmsr1olJXmsrfZ4TuhWn1olL1olJXmsqBi7X1olJXmspZmslbmMhbmsdemsVfl8ZgmsNim8Jpk8F0m7R4m7F5nLB6jbh7jbiDirOEibOGnKaMhq+PnaCVg6qWg6qegKaff6WhnpKofKGtnomxeZy3noG6dZi+n3vCcpPDcpPGn3bLb4/Mb47UbIrVa4rYoGjdaIbeaIXhoWHmZYHobXvpcHjqdHXreHLroVrsfG/uhGnuh2bwj2Hxk17yl1vzmljzm1j0nlX1olL3AJXWAAAAbXRSTlMAEBAQHx8gICAuLjAwMDw9PUBAQEpQUFBXV1hgYGBkcHBwcXl8gICAgoiIkJCQlJicnJ2goKCmqK+wsLC4usDAwMjP0NDQ1NbW3Nzg4ODi5+3v8PDw8/T09PX29vb39/f5+fr7+/z8/Pz9/v7+zczCxgAABC5JREFUeAHN1ul3k0UUBvCb1CTVpmpaitAGSLSpSuKCLWpbTKNJFGlcSMAFF63iUmRccNG6gLbuxkXU66JAUef/9LSpmXnyLr3T5AO/rzl5zj137p136BISy44fKJXuGN/d19PUfYeO67Znqtf2KH33Id1psXoFdW30sPZ1sMvs2D060AHqws4FHeJojLZqnw53cmfvg+XR8mC0OEjuxrXEkX5ydeVJLVIlV0e10PXk5k7dYeHu7Cj1j+49uKg7uLU61tGLw1lq27ugQYlclHC4bgv7VQ+TAyj5Zc/UjsPvs1sd5cWryWObtvWT2EPa4rtnWW3JkpjggEpbOsPr7F7EyNewtpBIslA7p43HCsnwooXTEc3UmPmCNn5lrqTJxy6nRmcavGZVt/3Da2pD5NHvsOHJCrdc1G2r3DITpU7yic7w/7Rxnjc0kt5GC4djiv2Sz3Fb2iEZg41/ddsFDoyuYrIkmFehz0HR2thPgQqMyQYb2OtB0WxsZ3BeG3+wpRb1vzl2UYBog8FfGhttFKjtAclnZYrRo9ryG9uG/FZQU4AEg8ZE9LjGMzTmqKXPLnlWVnIlQQTvxJf8ip7VgjZjyVPrjw1te5otM7RmP7xm+sK2Gv9I8Gi++BRbEkR9EBw8zRUcKxwp73xkaLiqQb+kGduJTNHG72zcW9LoJgqQxpP3/Tj//c3yB0tqzaml05/+orHLksVO+95kX7/7qgJvnjlrfr2Ggsyx0eoy9uPzN5SPd86aXggOsEKW2Prz7du3VID3/tzs/sSRs2w7ovVHKtjrX2pd7ZMlTxAYfBAL9jiDwfLkq55Tm7ifhMlTGPyCAs7RFRhn47JnlcB9RM5T97ASuZXIcVNuUDIndpDbdsfrqsOppeXl5Y+XVKdjFCTh+zGaVuj0d9zy05PPK3QzBamxdwtTCrzyg/2Rvf2EstUjordGwa/kx9mSJLr8mLLtCW8HHGJc2R5hS219IiF6PnTusOqcMl57gm0Z8kanKMAQg0qSyuZfn7zItsbGyO9QlnxY0eCuD1XL
2ys/MsrQhltE7Ug0uFOzufJFE2PxBo/YAx8XPPdDwWN0MrDRYIZF0mSMKCNHgaIVFoBbNoLJ7tEQDKxGF0kcLQimojCZopv0OkNOyWCCg9XMVAi7ARJzQdM2QUh0gmBozjc3Skg6dSBRqDGYSUOu66Zg+I2fNZs/M3/f/Grl/XnyF1Gw3VKCez0PN5IUfFLqvgUN4C0qNqYs5YhPL+aVZYDE4IpUk57oSFnJm4FyCqqOE0jhY2SMyLFoo56zyo6becOS5UVDdj7Vih0zp+tcMhwRpBeLyqtIjlJKAIZSbI8SGSF3k0pA3mR5tHuwPFoa7N7reoq2bqCsAk1HqCu5uvI1n6JuRXI+S1Mco54YmYTwcn6Aeic+kssXi8XpXC4V3t7/ADuTNKaQJdScAAAAAElFTkSuQmCC)](https://mybinder.org/v2/gh/CoCaLa/causality-tutorial-exercises/HEAD?filepath=python%2FCausality_Tutorial_Exercises.ipynb) 19 | instance to work on the causality exercises in your browser.__ \ 20 | No installation/registration/account required. 21 | 22 | Alternatively, use one of the following options: 23 | 24 | 1. `podman run --rm -p 8888:8888 docker.io/learningbydoingdocker/causality-tutorial-exercises`, \ 25 | follow the shown link to open the notebook in your browser 26 | 2. clone repository, install [./requirements.txt](./requirements.txt), 27 | and run jupyter notebook 28 | [./python/Causality_Tutorial_Exercises.ipynb](./python/Causality_Tutorial_Exercises.ipynb) 29 | 3. [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/sweichwald/causality-tutorial-exercises/blob/master/python/Causality_Tutorial_Exercises.ipynb) 30 | 31 | 32 | ## R 33 | 34 | __Launch a 35 | [![badge](https://img.shields.io/badge/R-binder-F5A252.svg?logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAFkAAABZCAMAAABi1XidAAAB8lBMVEX///9XmsrmZYH1olJXmsr1olJXmsrmZYH1olJXmsr1olJXmsrmZYH1olL1olJXmsr1olJXmsrmZYH1olL1olJXmsrmZYH1olJXmsr1olL1olJXmsrmZYH1olL1olJXmsrmZYH1olL1olL0nFf1olJXmsrmZYH1olJXmsq8dZb1olJXmsrmZYH1olJXmspXmspXmsr1olL1olJXmsrmZYH1olJXmsr1olL1olJXmsrmZYH1olL1olLeaIVXmsrmZYH1olL1olL1olJXmsrmZYH1olLna31Xmsr1olJXmsr1olJXmsrmZYH1olLqoVr1olJXmsr1olJXmsrmZYH1olL1olKkfaPobXvviGabgadXmsqThKuofKHmZ4Dobnr1olJXmsr1olJXmspXmsr1olJXmsrfZ4TuhWn1olL1olJXmsqBi7X1olJXmspZmslbmMhbmsdemsVfl8ZgmsNim8Jpk8F0m7R4m7F5nLB6jbh7jbiDirOEibOGnKaMhq+PnaCVg6qWg6qegKaff6WhnpKofKGtnomxeZy3noG6dZi+n3vCcpPDcpPGn3bLb4/Mb47UbIrVa4rYoGjdaIbeaIXhoWHmZYHobXvpcHjqdHXreHLroVrsfG/uhGnuh2bwj2Hxk17yl1vzmljzm1j0nlX1olL3AJXWAAAAbXRSTlMAEBAQHx8gICAuLjAwMDw9PUBAQEpQUFBXV1hgYGBkcHBwcXl8gICAgoiIkJCQlJicnJ2goKCmqK+wsLC4usDAwMjP0NDQ1NbW3Nzg4ODi5+3v8PDw8/T09PX29vb39/f5+fr7+/z8/Pz9/v7+zczCxgAABC5JREFUeAHN1ul3k0UUBvCb1CTVpmpaitAGSLSpSuKCLWpbTKNJFGlcSMAFF63iUmRccNG6gLbuxkXU66JAUef/9LSpmXnyLr3T5AO/rzl5zj137p136BISy44fKJXuGN/d19PUfYeO67Znqtf2KH33Id1psXoFdW30sPZ1sMvs2D060AHqws4FHeJojLZqnw53cmfvg+XR8mC0OEjuxrXEkX5ydeVJLVIlV0e10PXk5k7dYeHu7Cj1j+49uKg7uLU61tGLw1lq27ugQYlclHC4bgv7VQ+TAyj5Zc/UjsPvs1sd5cWryWObtvWT2EPa4rtnWW3JkpjggEpbOsPr7F7EyNewtpBIslA7p43HCsnwooXTEc3UmPmCNn5lrqTJxy6nRmcavGZVt/3Da2pD5NHvsOHJCrdc1G2r3DITpU7yic7w/7Rxnjc0kt5GC4djiv2Sz3Fb2iEZg41/ddsFDoyuYrIkmFehz0HR2thPgQqMyQYb2OtB0WxsZ3BeG3+wpRb1vzl2UYBog8FfGhttFKjtAclnZYrRo9ryG9uG/FZQU4AEg8ZE9LjGMzTmqKXPLnlWVnIlQQTvxJf8ip7VgjZjyVPrjw1te5otM7RmP7xm+sK2Gv9I8Gi++BRbEkR9EBw8zRUcKxwp73xkaLiqQb+kGduJTNHG72zcW9LoJgqQxpP3/Tj//c3yB0tqzaml05/+orHLksVO+95kX7/7qgJvnjlrfr2Ggsyx0eoy9uPzN5SPd86aXggOsEKW2Prz7du3VID3/tzs/sSRs2w7ovVHKtjrX2pd7ZMlTxAYfBAL9jiDwfLkq55Tm7ifhMlTGPyCAs7RFRhn47JnlcB9RM5T97ASuZXIcVNuUDIndpDbdsfrqsOppeXl5Y+XVKdjFCTh+zGaVuj0d9zy05PPK3QzBamxdwtTCrzyg/2Rvf2EstUjordGwa/kx9mSJLr8mLLtCW8HHGJc2R5hS219IiF6PnTusOqcMl57gm0Z8kanKMAQg0qSyuZfn7zItsbGyO9QlnxY0eCuD1XL2ys/MsrQhltE7Ug0uFOzufJFE2PxBo/YAx8XPPdDwWN0MrDRYIZF0mSMKCNHgaIVFoBbNoLJ7tEQDKxGF0kcLQimojCZopv0OkNOyWCCg9XMVAi7ARJzQdM2QUh0gmBozjc3Skg6dSBRqDGYSUOu66Zg+I2fNZs/M3/f/Grl/XnyF1Gw3V
KCez0PN5IUfFLqvgUN4C0qNqYs5YhPL+aVZYDE4IpUk57oSFnJm4FyCqqOE0jhY2SMyLFoo56zyo6becOS5UVDdj7Vih0zp+tcMhwRpBeLyqtIjlJKAIZSbI8SGSF3k0pA3mR5tHuwPFoa7N7reoq2bqCsAk1HqCu5uvI1n6JuRXI+S1Mco54YmYTwcn6Aeic+kssXi8XpXC4V3t7/ADuTNKaQJdScAAAAAElFTkSuQmCC)](https://mybinder.org/v2/gh/CoCaLa/causality-tutorial-exercises/HEAD?filepath=R%2FCausality_Tutorial_Exercises.ipynb) 36 | instance to work on the causality exercises in your browser.__ \ 37 | No installation/registration/account required. 38 | 39 | Alternatively, use one of the following options: 40 | 41 | 1. `podman run --rm -p 8888:8888 docker.io/learningbydoingdocker/causality-tutorial-exercises`, \ 42 | follow the shown link to open the notebook in your browser 43 | 2. clone repository, install packages as per [./install.R](./install.R), 44 | and run jupyter notebook 45 | [./R/Causality_Tutorial_Exercises.ipynb](./R/Causality_Tutorial_Exercises.ipynb) 46 | 3. [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/sweichwald/causality-tutorial-exercises/blob/master/R/Causality_Tutorial_Exercises.ipynb) 47 | -------------------------------------------------------------------------------- /binder/Dockerfile: -------------------------------------------------------------------------------- 1 | ### DO NOT EDIT THIS FILE! This Is Automatically Generated And Will Be Overwritten ### 2 | FROM learningbydoingdocker/causality-tutorial-exercises:8638c889cc9f -------------------------------------------------------------------------------- /data/Exercise-ANM.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CoCaLa/causality-tutorial-exercises/38946e0b2f866f4be50bd6dd671d743c6738e72c/data/Exercise-ANM.png -------------------------------------------------------------------------------- /data/Exercise-ICP.csv: -------------------------------------------------------------------------------- 1 | "Y","X1","X2","X3" 2 | -0.127264734462173,-0.626453810742332,-1.00932569587833,0.270771691693335 3 | 0.0820728070930421,0.183643324222082,0.418959986625795,-0.15310992167803 4 | 0.0161920319132916,-0.835628612410047,-1.16862309965245,-0.02269878927885 5 | -0.285274895315157,1.59528080213779,1.50257472184331,0.109328021778564 6 | -0.224922620210731,0.329507771815361,0.106323750806791,0.281894917392641 7 | -0.108158659833954,-0.820468384118015,-0.970632184356705,0.269393339654132 8 | 0.134899219557214,0.487429052428485,0.904862361554155,-0.298627284267685 9 | 0.0227734991875767,0.738324705129217,0.741803829067868,-0.100730350507751 10 | -0.168563983862952,0.575781351653492,0.318521245566627,0.104446290487896 11 | -0.0658179036468895,-0.305388387156356,-0.633509494040072,-0.00229523571022609 12 | -0.136167531097734,1.51178116845085,1.60181858870538,-0.067161028464061 13 | -0.00357447438943485,0.389843236411431,0.386131269868503,0.0536708303142255 14 | 0.0483723455575684,-0.621240580541804,-0.684854255450573,-0.201552159554143 15 | -0.0326792032149482,-2.2146998871775,-2.40057231666824,0.0301794392871218 16 | -0.113145400283744,1.12493091814311,0.827438856114811,0.172443872386273 17 | -0.111084948598991,-0.0449336090152309,-0.259972068338367,0.091265406450527 18 | 0.198182382671351,-0.0161902630989461,0.183815497643837,-0.108981543423998 19 | -0.0734219567532505,0.943836210685299,0.819582871725935,0.0708504496614351 20 | -0.253309911340175,0.821221195098089,0.54433582562119,0.188543866281608 21 | 
0.134343200592382,0.593901321217509,0.967759445702225,-0.0697072590889177 22 | 0.0484816783664232,0.918977371608218,1.00399744708271,-0.0918649523693821 23 | -0.2115810970888,0.782136300731067,0.734406880548461,0.38884221558665 24 | 0.316695206130979,0.0745649833651906,0.286261593106995,-0.31852117724728 25 | 0.266942677122494,-1.98935169586337,-1.81206716558839,-0.18166117776244 26 | 0.0291227896502904,0.61982574789471,0.495977138248481,-0.00860649932588773 27 | 0.421566439092819,-0.0561287395290008,0.385091753379093,-0.722371298984867 28 | -0.0067271195877975,-0.155795506705329,-0.206800912733532,-0.129884073545101 29 | 0.0328579718495402,-1.47075238389927,-1.75565131394184,-0.0752681978640214 30 | -0.0719677946770903,-0.47815005510862,-0.507029975499464,0.0956481610516461 31 | 0.0154110949901548,0.417941560199702,0.459449228046171,-0.249683407026051 32 | 0.245450698569837,1.35867955152904,1.82027523134092,-0.149281035231766 33 | 0.0328829055712647,-0.102787727342996,-0.0816272537642532,-0.0933254789567088 34 | -0.0293409954205633,0.387671611559369,0.479071372644052,-0.045946732515728 35 | -0.0694560334295372,-0.0538050405829051,-0.0692356276542113,-0.0861051258248874 36 | 0.0940336123366243,-1.37705955682861,-1.44385972530192,-0.23942298613673 37 | 0.144248479000294,-0.41499456329968,-0.421939768961935,-0.138615295394853 38 | 0.0855482550165267,-0.394289953710349,-0.236762032584317,-0.0346113143491374 39 | 0.242859729576648,-0.0593133967111857,0.355735605019271,-0.452648025555078 40 | -0.11612944492882,1.10002537198388,1.30550385973664,0.0156932469838867 41 | 0.118062716562136,0.763175748457544,1.00475742813489,-0.0644855443368766 42 | -0.000516298950191979,-0.164523596253587,-0.410788280565196,-0.0447874095147998 43 | 0.293107737360335,-0.253361680136508,-0.0565825661258317,-0.0765708871786082 44 | 0.0381649119885401,0.696963375404737,0.740948336136868,0.0864097552862705 45 | -0.418715202378696,0.556663198673657,0.263213192855209,0.478265005785656 46 | 0.179601248985105,-0.68875569454952,-0.584551146019892,-0.179112804031286 47 | 0.0973125234604102,-0.70749515696212,-0.739246077905323,-0.0693764452548842 48 | 0.103952252639702,0.36458196213683,0.65749942453079,-0.174542865166464 49 | -0.151734185309303,0.768532924515416,0.615316524594483,0.214535900605028 50 | -0.126794384669093,-0.112346212150228,-0.198388562935937,0.274815780677664 51 | -0.131320834929148,0.881107726454215,0.695885826978727,0.239663825937477 52 | -0.093773164874402,0.398105880367068,0.36268508807976,0.0124487392079929 53 | -0.0474641656509576,-0.612026393250771,-0.531624037353504,-0.114423519271401 54 | -0.158666825533843,0.341119691424425,0.194770056800503,0.147701255585019 55 | 0.349024582672865,-1.12936309608079,-0.963288462484458,-0.304935645581167 56 | -0.321700440999756,1.43302370170104,1.19140714444014,0.456799838980616 57 | -0.250523115166605,1.98039989850586,1.77080301594431,0.118662166713248 58 | 0.304319636487743,-0.367221476466509,-0.0789899350976528,-0.267881177226999 59 | -0.0170483627722674,-1.04413462631653,-1.24730411937746,0.0403983462811192 60 | -0.0427672278211244,0.569719627442413,0.652114569905916,0.16216275394703 61 | -0.0851932166583383,-0.135054603880824,-0.211269814102608,0.0824022194453806 62 | -0.117074632734886,2.40161776050478,2.48349812843496,0.0813447472844607 63 | 0.100243053069521,-0.0392400027331692,0.29853465450764,-0.214924466681358 64 | 0.146037760121599,0.689739362450777,1.00705704913917,-0.197779808480961 65 | 
-0.0714590886164297,0.0280021587806661,-0.0381794013558872,0.0352467113586025 66 | -0.42638987899565,-0.743273208882405,-1.2003203159409,0.66144531157456 67 | 0.139955115221346,0.188792299514343,0.688324617481176,0.10469802234865 68 | 0.35214573578022,-1.80495862889104,-1.67154539553794,-0.368816063728648 69 | -0.100724100894164,1.46555486156289,1.57382032875563,-0.00364264301202429 70 | 0.0633950450146659,0.153253338211898,0.150573433582716,-0.260688538422133 71 | 0.0326992900489964,2.17261167036215,2.27463335495274,0.0187678732948065 72 | 0.0801120654085875,0.475509528899663,0.442634362545729,-0.189169423782073 73 | 0.189204849514023,-0.709946430921815,-0.625807502270912,0.0392610830404136 74 | -0.0711059759695483,0.610726353489055,0.530677004693526,-0.0174557812962476 75 | -0.0902950247870035,-0.934097631644252,-1.20813920715374,0.101405667761576 76 | 0.401693111608936,-1.2536334002391,-1.05606574674813,-0.0206654435378691 77 | 0.212848190440045,0.291446235517463,0.595395240617372,-0.32373919023437 78 | -0.110077375072366,-0.443291873218433,-0.505039987063556,0.140834037514512 79 | -0.166069559325117,0.00110535163162413,-0.249552599489914,0.0553801121026009 80 | -0.122256157760812,0.0743413241516641,0.202789585287229,0.157021522643008 81 | 0.0340611135804661,-0.589520946188072,-0.598462773566868,-0.12138756708667 82 | -0.410352102531729,-0.568668732818502,-0.91531241418347,0.418082414805421 83 | 0.145173901393821,-0.135178615123832,-0.134752243187778,-0.174860765609483 84 | -0.256999039831912,1.1780869965732,1.05202692978758,0.138674815788645 85 | 0.0684425665903087,-1.52356680042976,-1.59176051640184,-0.0673132977479187 86 | -0.215115135509341,0.593946187628422,0.362631715101251,0.314275239115277 87 | 0.244303809227657,0.332950371213518,0.693578752797012,-0.0849070638372834 88 | -0.0782118103723128,1.06309983727636,0.996873429998118,-0.0590593167206042 89 | -0.105448019859772,-0.304183923634301,-0.625286606084916,0.0804869265555401 90 | -0.0705929287355037,0.370018809916288,0.409457497664184,0.186535381421059 91 | -0.1314579271364,0.267098790772231,0.319733920053326,0.0200356923520406 92 | -0.10291942538374,-0.54252003099165,-0.739685371073509,-0.149930643501299 93 | -0.439665847954808,1.20786780598317,0.630083671647263,0.346075592103477 94 | -0.275206178034512,1.16040261569495,1.03230627518193,0.178482232240955 95 | -0.00844454208888811,0.700213649514998,0.814315176699095,0.0131934013216677 96 | -0.282616371192178,1.58683345454085,1.57488879933232,0.242242691842374 97 | -0.0684609932587192,0.558486425565304,0.538850676764257,0.0916106060831824 98 | 0.207843248439958,-1.27659220845804,-1.16442806273401,-0.250080489232633 99 | -0.0996360580519485,-0.573265414236886,-0.810557141952781,0.137047897521478 100 | 0.237683411129138,-1.22461261489836,-1.00525720604351,-0.274283988628839 101 | 0.223053980529043,-0.473400636439312,-0.474469442094945,-0.104043835876937 102 | 0.223172599052014,-0.620366677224124,-0.478904543744508,-0.296905351577737 103 | 0.231024448927135,0.0421158731442352,0.248937420091728,-0.201957784383209 104 | 0.025589119362098,-0.910921648552446,-0.866225565569385,-0.114074076190544 105 | -0.104813606941534,0.158028772404075,-0.0177127501691288,0.125614254828617 106 | 0.320496729817689,-0.654584643918818,-0.421991732725352,-0.325269747065541 107 | -0.422400662764171,1.76728726937265,1.36725428041555,0.253948598150311 108 | -0.349981701225393,0.716707476017206,0.607749328016861,0.335559045563498 109 | -0.0872133318384159,0.910174229495227,0.859040087663829,0.205234698388681 
110 | -0.248848604752988,0.384185357826345,0.350961150473343,0.316988597072939 111 | -0.12687895431187,1.68217608051942,1.88626886227624,0.141203717400626 112 | 0.212320704697191,-0.635736453948977,-0.608492075328422,-0.331552349068483 113 | 0.156588270377184,-0.461644730360566,-0.380211209675799,-0.0396654050982756 114 | -0.0424199058621788,1.43228223854166,1.41835127593908,0.050340076756716 115 | 0.0660104004880802,-0.650696353310367,-0.700229221634233,-0.111187775764948 116 | 0.0931882949062677,-0.207380743601965,-0.0682705822780374,0.0710145263738095 117 | 0.0843969184973661,-0.392807929441984,-0.163562257998824,-0.161356150657318 118 | -0.0970808835602544,-0.319992868548507,-0.800612111526881,0.127416979636039 119 | 0.101141136739345,-0.279113302976559,-0.164565391927391,0.0270326053790067 120 | -0.0759647443801324,0.494188331267827,0.569133212623558,0.136187023897175 121 | 0.0532232316119453,-0.177330482269606,-0.262384026580822,-0.0839254580656288 122 | 0.272161378948434,-0.505957462114257,-0.315754900598894,-0.314003189290698 123 | 0.00855313309661135,1.34303882517041,1.26519138882674,0.0269604199079971 124 | -0.0729614689217714,-0.214579408546869,-0.271445540906783,0.124309580381727 125 | 0.0817580443016014,-0.179556530043387,-0.00807457442742643,-0.0798973042687274 126 | 0.174752146862809,-0.100190741213562,0.243734718610558,-0.0429072496371413 127 | -0.0764257848747122,0.712666307051405,0.766677287238851,0.0698425848908969 128 | -0.07996873533212,-0.0735644041263263,-0.158001206083854,0.00993912750034963 129 | -0.168497923506534,-0.0376341714670479,-0.275456830438965,0.222230536658668 130 | 0.172624131430072,-0.681660478755657,-0.747867074531459,-0.392802363665463 131 | -0.150106321147875,-0.324270272246319,-0.512236137548323,0.189303695527716 132 | -0.0759047046098417,0.0601604404345152,0.00837392381075819,0.125600799852214 133 | 0.171468594029707,-0.588894486259664,-0.510018652615349,-0.193956065572697 134 | -0.0428952256526079,0.531496192632572,0.3611247742278,-0.0688190908805806 135 | 0.392528353607671,-1.51839408178679,-0.988560705567811,-0.432027813939059 136 | -0.0627430066129957,0.306557860789766,0.337760195922781,0.217726040836626 137 | 0.341631913498153,-1.53644982353759,-1.3104083700466,-0.41598336147802 138 | -0.142821842243355,-0.300976126836611,-0.758800922804632,-0.0903493695296027 139 | 0.116631670499155,-0.528279904445006,-0.380079673005918,-0.0354071262782339 140 | -0.235739285924949,-0.652094780680999,-0.915343812771311,0.185608220204813 141 | 0.286978222307924,-0.0568967778473925,0.127063957674436,-0.33806687890316 142 | 0.607772155054219,-1.21536404115501,-0.486527949081133,-0.108690881697743 143 | -0.102958180906067,-0.022558628347222,-0.0680969853404291,-0.182608301264641 144 | -0.230947591625832,0.701239300429707,0.142606638970047,0.0995281807292567 145 | 0.324314203001057,-0.587482025589799,-0.131807477297826,-0.00118617814307389 146 | 0.32278586503292,-0.60672794141879,-0.00728052451329706,-0.0599628394751609 147 | -0.951776524209577,1.09664021500029,-0.318657260305381,-0.0177947986602105 148 | 0.12313195674362,-0.247509677080296,0.169558559449554,-0.0425981341802302 149 | -0.0237375592338688,-0.159901713323247,-0.370077769732123,0.0996658776455071 150 | 0.449363538248151,-0.625778250735075,-0.0647301263380628,0.0727660708501531 151 | -0.696412151385246,0.900434635600238,-0.0510496225006345,-0.172663059586351 152 | 0.654403199004049,-0.994193629266225,0.01218424546135,0.035339849564175 153 | 
-0.599664870869249,0.849250385880362,-0.164698325803776,0.0726813665899557 154 | -0.450664782745396,0.805702288976437,0.365946096967207,0.0668260975976705 155 | 0.0573486221296348,-0.46760093599122,-0.2859832431563,-0.242431730928421 156 | -0.520972974743782,0.848420313723343,0.0508274286009171,-0.0235357425015276 157 | -1.08864309601279,0.986769863596017,-0.587954739067197,0.197963332091796 158 | -0.310135825217945,0.57562028851936,0.000483161788695027,0.0796794538639531 159 | -1.37609164422252,2.02484204541817,0.101933114252387,-0.170927618064586 160 | 1.36296593541389,-1.96235319122504,-0.216944000222792,-0.166366871188344 161 | 0.949579079189634,-1.16492093063759,0.140966595375513,0.049110955231555 162 | 0.778765354201592,-1.37651921373083,0.0661952699988424,-0.0174055485572285 163 | -0.133753793413494,0.167679934469767,0.195265494595274,0.0961290563877483 164 | -1.08216362044188,1.58462907915972,-0.168667976079747,0.0293826661677307 165 | -1.22191226524955,1.67788895297625,-0.194115980972583,0.00809993635091707 166 | -0.651097717499451,0.488296698394955,-0.354306269728586,0.0183661842799281 167 | -0.788347093291752,0.878673262560324,-0.0644940683291897,0.016625503539133 168 | 0.0441228616699499,-0.144874874029881,-0.2677601484709,-0.126959906621425 169 | -0.326879154640724,0.468971760074208,0.137631205629211,0.234949332061946 170 | -0.0746386078587518,0.376235477129484,0.0142561304501158,-0.141200540742533 171 | 0.972652671690978,-0.761040275056481,0.437950471855836,-0.00169614927822389 172 | -0.0790875952708116,-0.293294933750864,-0.231541519860296,-0.0544319352621665 173 | 0.151626021518779,-0.134841264407518,0.236337612818467,0.180011233335541 174 | -1.16402424064651,1.39384581617302,-0.105473672321636,0.101144017617923 175 | 0.617825528221931,-1.03698868969829,-0.291325602262151,-0.0563716555906629 176 | 1.41971371915798,-2.11433514780366,0.114593474098634,0.0205420795375023 177 | -0.913300431150634,0.768278218204398,-0.286675540934441,0.116546195019875 178 | 0.646824931753512,-0.816160620753485,-0.2110370037053,0.22363228395069 179 | 0.317898769682985,-0.436106923160574,-0.14662237548439,0.0302265076168852 180 | -0.52627229006531,0.904705031282809,0.0421814528502597,-0.104250660223103 181 | 0.347891069578776,-0.763086264548317,-0.199784145412202,-0.0983542313495811 182 | 0.366960801406098,-0.341066979637804,0.215570064641125,0.200571858043419 183 | -1.13360637424418,1.50242453423459,-0.239794876501575,-0.207057148385748 184 | -0.471942828595735,0.5283077123164,0.0433274070369421,0.305574236888186 185 | -0.374789637887949,0.542191355464308,0.0286174059577855,-0.02613505940873 186 | -0.0146444964736668,-0.136673355717877,-0.213150018210596,-0.045439325911372 187 | 0.91355626747497,-1.13673385330273,-0.0857246821403559,0.0157560555463951 188 | 1.03311074255719,-1.49662715435579,-0.131235895354269,0.0933388727865257 189 | 0.399720572947451,-0.223385643556595,0.191878865371523,0.0302828275834352 190 | -1.20042244575401,2.00171922777288,0.311210527220444,-0.195615022247248 191 | -0.391338511722182,0.221703816213622,-0.208159286757649,0.0353536709469366 192 | -0.0373600446438317,0.164372909246388,0.186114481708049,0.0450424513742423 193 | -0.408366501842717,0.33262360885001,-0.0150891862024382,0.0659550870534498 194 | 0.126467333649913,-0.385207999091431,-0.393439069810192,-0.103142072760739 195 | 1.03010206494468,-1.39875402655789,-0.151180728529039,-0.237102288136148 196 | -1.82395273559816,2.67574079585182,0.0922298321013524,-0.0324576307824521 197 | 
0.215902824193346,-0.423686088681983,0.0290213262095744,-0.094429875090974 198 | 0.0246587249470748,-0.298601511933064,-0.488462264216285,-0.0765889998443101 199 | 1.33821015370734,-1.79234172685298,0.116063737090315,-0.0953779266719918 200 | 0.2135847875987,-0.248008225068098,0.131010399697349,-0.0398004449854522 201 | 0.248930224709095,-0.247303918374605,-0.0609017674518837,-0.0311217062923606 202 | 0.0179645113186073,-0.255510378526346,-0.141513646523459,0.0796092713239834 203 | 1.60234122212509,-1.78693810009257,0.394314402895646,0.098642834387578 204 | -1.34431143878874,1.78466281602476,-0.0179997361295498,-0.0794531662417426 205 | -1.19705117961467,1.76358634774588,-0.00280345038638484,-0.0308817971515212 206 | -0.528397263575222,0.689600221934096,-0.224691387422031,0.0361444766025588 207 | 0.475696965036168,-1.10074064442392,-0.268826024675241,0.139879110483801 208 | -0.643122894572425,0.714509356897811,-0.304631154212418,-0.00560704196424014 209 | 0.110734356313072,-0.246470316934021,-0.0843936420764348,-0.169887349155411 210 | 0.45473564350299,-0.319786165927205,0.272184892816197,0.0231852545141301 211 | -0.822255417166027,1.3626442929547,0.350758969066974,-0.0119090670970556 212 | 1.03902294770446,-1.22788258998268,0.313672946870558,0.177249285346043 213 | 0.65169253209464,-0.511219232750473,0.25935111398955,0.0343422165443603 214 | 0.500173965553044,-0.731194999064964,-0.0475192502887223,-0.0623049782354728 215 | -0.0784053262679263,0.0197520068767907,-0.244830027136141,-0.0439522293828527 216 | 1.03957775891854,-1.57286391470999,-0.0655625360280143,-0.0505296789545481 217 | 0.0999000902556607,-0.703333269828288,-0.48249005525445,0.0186035137157474 218 | -0.539900176133285,0.715932088907665,-0.062758574688678,0.0176417797599477 219 | -0.248964105482767,0.465214906423109,0.331975740929143,0.0915848206824781 220 | 0.437834847710909,-0.973902306404292,0.0261906201215034,0.0320176726180939 221 | -0.143033510205934,0.559217730468333,0.219177736537917,-0.036668729712132 222 | -------------------------------------------------------------------------------- /install.R: -------------------------------------------------------------------------------- 1 | install.packages("gam") 2 | install.packages("DoubleML") 3 | install.packages("WeightIt") 4 | install.packages("MatchIt") 5 | install.packages("ranger") 6 | install.packages('dHSIC') 7 | install.packages('InvariantCausalPrediction') 8 | -------------------------------------------------------------------------------- /python/Causality_Tutorial_Exercises.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "3In88vKrq4WB" 7 | }, 8 | "source": [ 9 | "# Causality Tutorial Exercises – Python\n", 10 | "\n", 11 | "Contributors: Rune Christiansen, Jonas Peters, Niklas Pfister, Sorawit Saengkyongam, Sebastian Weichwald.\n", 12 | "The MIT License applies; copyright is with the authors.\n", 13 | "Some exercises are adapted from \"Elements of Causal Inference: Foundations and Learning Algorithms\" by J. Peters, D. Janzing and B. 
Schölkopf.\n"
14 | ]
15 | },
16 | {
17 | "cell_type": "markdown",
18 | "metadata": {
19 | "id": "KnsIE8yWlVIQ"
20 | },
21 | "source": [
22 | "# Exercise 1 – Structural Causal Model\n"
23 | ]
24 | },
25 | {
26 | "cell_type": "markdown",
27 | "metadata": {
28 | "id": "FSNemB3GrBIE"
29 | },
30 | "source": [
31 | "\n",
32 | "Let's first draw a sample from an SCM"
33 | ]
34 | },
35 | {
36 | "cell_type": "code",
37 | "execution_count": null,
38 | "metadata": {
39 | "id": "5Cy58Ut1liKd"
40 | },
41 | "outputs": [],
42 | "source": [
43 | "import numpy as np\n",
44 | "\n",
45 | "# set seed\n",
46 | "np.random.seed(1)\n",
47 | "\n",
48 | "rnorm = lambda n: np.random.normal(size=n)\n",
49 | "\n",
50 | "n = 200\n",
51 | "C = rnorm(n)\n",
52 | "A = .8 * rnorm(n)\n",
53 | "K = A + .1 * rnorm(n)\n",
54 | "X = C - 2 * A + .2 * rnorm(n)\n",
55 | "F = 3 * X + .8 * rnorm(n)\n",
56 | "D = -2 * X + .5 * rnorm(n)\n",
57 | "G = D + .5 * rnorm(n)\n",
58 | "Y = 2 * K - D + .2 * rnorm(n)\n",
59 | "H = .5 * Y + .1 * rnorm(n)\n",
60 | "\n",
61 | "data = np.c_[C, A, K, X, F, D, G, Y, H]"
62 | ]
63 | },
64 | {
65 | "cell_type": "markdown",
66 | "metadata": {
67 | "id": "8PMvvEeIoKFN"
68 | },
69 | "source": [
70 | "__a)__\n",
71 | "\n",
72 | "What are the parents and children of $X$ in the above SCM?\n",
73 | "\n",
74 | "Take a pair of variables and think about whether you expect this pair to be dependent\n",
75 | "(at this stage you can only guess; later you will have tools to know). Check empirically.\n",
76 | "\n",
77 | "__b)__\n",
78 | "\n",
79 | "Generate a sample of size 300 from the interventional distribution $P_{\mathrm{do}(X=\mathcal{N}(2, 1))}$\n",
80 | "and store the data matrix as `data_int`."
81 | ]
82 | },
83 | {
84 | "cell_type": "code",
85 | "execution_count": null,
86 | "metadata": {
87 | "id": "FtbA6c2Ron5f"
88 | },
89 | "outputs": [],
90 | "source": []
91 | },
92 | {
93 | "cell_type": "markdown",
94 | "metadata": {
95 | "id": "l3wOg_4vozpz"
96 | },
97 | "source": [
98 | "__c)__\n",
99 | "\n",
100 | "Do you expect the marginal distribution of $Y$ to be different in both samples?"
101 | ]
102 | },
103 | {
104 | "cell_type": "markdown",
105 | "metadata": {
106 | "id": "3paV1bkro6lV"
107 | },
108 | "source": [
109 | "Double-click (or enter) to edit"
110 | ]
111 | },
112 | {
113 | "cell_type": "markdown",
114 | "metadata": {
115 | "id": "CH9Tt444o-RH"
116 | },
117 | "source": [
118 | "__d)__\n",
119 | "\n",
120 | "Do you expect the joint distribution of $(A, Y)$ to be different in both samples?\n"
121 | ]
122 | },
123 | {
124 | "cell_type": "markdown",
125 | "metadata": {
126 | "id": "FJz4fZKEpE4-"
127 | },
128 | "source": [
129 | "Double-click (or enter) to edit"
130 | ]
131 | },
132 | {
133 | "cell_type": "markdown",
134 | "metadata": {
135 | "id": "eZmh_AizpGp-"
136 | },
137 | "source": [
138 | "__e)__\n",
139 | "\n",
140 | "Check your answers to c) and d) empirically.\n",
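"\n",
"For example (a minimal sketch added for illustration, assuming `data` and `data_int` are the matrices from the cells above; by the construction of `data`, $A$ is column 1 and $Y$ is column 7):\n",
"```python\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# marginal distribution of Y in the two samples\n",
"plt.hist(data[:, 7], bins=30, alpha=.5, density=True, label='observational')\n",
"plt.hist(data_int[:, 7], bins=30, alpha=.5, density=True, label='interventional')\n",
"plt.legend();\n",
"\n",
"# joint distribution of (A, Y) in the two samples\n",
"plt.figure()\n",
"plt.scatter(data[:, 1], data[:, 7], s=4, label='observational')\n",
"plt.scatter(data_int[:, 1], data_int[:, 7], s=4, label='interventional')\n",
"plt.legend();\n",
"```"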
141 | ]
142 | },
143 | {
144 | "cell_type": "code",
145 | "execution_count": null,
146 | "metadata": {
147 | "id": "q2PMSXqKpLpH"
148 | },
149 | "outputs": [],
150 | "source": []
151 | },
152 | {
153 | "cell_type": "markdown",
154 | "metadata": {
155 | "id": "1Idk_ElwrEht"
156 | },
157 | "source": [
158 | "# Exercise 2 – Adjusting"
159 | ]
160 | },
161 | {
162 | "cell_type": "markdown",
163 | "metadata": {
164 | "id": "il0b9fnVq-bz"
165 | },
166 | "source": [
167 | "\n",
168 | "![DAG](https://raw.githubusercontent.com/sweichwald/causality-tutorial-exercises/main/data/Exercise-ANM.png)\n",
169 | "\n",
170 | "Suppose we are given a fixed DAG (like the one above).\n",
171 | "\n",
172 | "a) What are valid adjustment sets (VAS) used for?\n",
173 | "\n",
174 | "b) Assume we want to find a VAS for the causal effect from $X$ to $Y$.\n",
175 | "What are general recipes (plural 😉) for constructing VASs (no proof)?\n",
176 | "Which sets are VASs in the DAG above?\n",
177 | "\n",
178 | "c) The following code samples from an SCM. Perform linear regressions using different VASs and compare the regression coefficients against the causal effect from $X$ to $Y$.\n"
179 | ]
180 | },
181 | {
182 | "cell_type": "code",
183 | "execution_count": null,
184 | "metadata": {
185 | "id": "R3y5ckYKJHiJ"
186 | },
187 | "outputs": [],
188 | "source": [
189 | "import numpy as np\n",
190 | "\n",
191 | "# set seed\n",
192 | "np.random.seed(1)\n",
193 | "\n",
194 | "rnorm = lambda n: np.random.normal(size=n)\n",
195 | "\n",
196 | "n = 200\n",
197 | "C = rnorm(n)\n",
198 | "A = .8 * rnorm(n)\n",
199 | "K = A + .1 * rnorm(n)\n",
200 | "X = C - 2 * A + .2 * rnorm(n)\n",
201 | "F = 3 * X + .8 * rnorm(n)\n",
202 | "D = -2 * X + .5 * rnorm(n)\n",
203 | "G = D + .5 * rnorm(n)\n",
204 | "Y = 2 * K - D + .2 * rnorm(n)\n",
205 | "H = .5 * Y + .1 * rnorm(n)\n",
206 | "\n",
207 | "data = np.c_[C, A, K, X, F, D, G, Y, H]"
208 | ]
209 | },
210 | {
211 | "cell_type": "markdown",
212 | "metadata": {
213 | "id": "UqFFtwP5JQVw"
214 | },
215 | "source": [
216 | "d) Why could it be interesting to have several options for choosing a VAS?\n",
217 | "\n",
218 | "e) If you indeed have access to several VASs, what would you do?"
219 | ]
220 | },
221 | {
222 | "cell_type": "markdown",
223 | "metadata": {
224 | "id": "LQ7RuuF4rMD6"
225 | },
226 | "source": [
227 | "# Exercise 3 – Independence-based Causal Structure Learning"
228 | ]
229 | },
230 | {
231 | "cell_type": "markdown",
232 | "metadata": {
233 | "id": "p21N9AFBrB0o"
234 | },
235 | "source": [
236 | "__a)__\n",
237 | "\n",
238 | "Assume $P^{X,Y,Z}$ is Markov and faithful wrt. $G$. Assume all (!) conditional independences are\n",
239 | "\n",
240 | "$$\n",
241 | "\newcommand{\indep}{{\,⫫\,}}\n",
242 | "\newcommand{\dep}{\not{}\!\!\indep}\n",
243 | "$$\n",
244 | "\n",
245 | "$$X \dep Z \mid \emptyset$$\n",
246 | "\n",
247 | "(plus symmetric statements). What is $G$?\n",
248 | "\n",
249 | "__b)__\n",
250 | "\n",
251 | "Assume $P^{W,X,Y,Z}$ is Markov and faithful wrt. $G$. Assume all (!) conditional independences are\n",
252 | "\n",
253 | "$$\begin{aligned}\n",
254 | "(Y,Z) &\indep W \mid \emptyset \\\n",
255 | "W &\indep Y \mid (X,Z) \\\n",
256 | "(X,W) &\indep Y \mid Z\n",
257 | "\end{aligned}\n",
258 | "$$\n",
259 | "\n",
260 | "(plus symmetric statements). What is $G$?"
261 | ]
262 | },
263 | {
264 | "cell_type": "markdown",
265 | "metadata": {
266 | "id": "craCADN8rKd3"
267 | },
268 | "source": [
269 | "# Exercise 4 – Additive Noise Models"
270 | ]
271 | },
272 | {
273 | "cell_type": "markdown",
274 | "metadata": {
275 | "id": "OlFh1Zk50_z7"
276 | },
277 | "source": [
278 | "Set up required packages:"
279 | ]
280 | },
281 | {
282 | "cell_type": "code",
283 | "execution_count": null,
284 | "metadata": {
285 | "id": "qk3IE7jvvUxG"
286 | },
287 | "outputs": [],
288 | "source": [
289 | "# set up – not needed when run on mybinder\n",
290 | "# if needed (colab), change False to True and run cell\n",
291 | "if False:\n",
292 | " !mkdir ../data/\n",
293 | " !wget https://raw.githubusercontent.com/sweichwald/causality-tutorial-exercises/main/data/Exercise-ANM.csv -q -O ../data/Exercise-ANM.csv\n",
294 | " !wget https://raw.githubusercontent.com/sweichwald/causality-tutorial-exercises/main/python/kerpy/__init__.py -q -O kerpy.py\n",
295 | " !pip install pygam"
296 | ]
297 | },
298 | {
299 | "cell_type": "code",
300 | "execution_count": null,
301 | "metadata": {
302 | "id": "GNsEcFUJ1P4I"
303 | },
304 | "outputs": [],
305 | "source": [
306 | "from kerpy import hsic\n",
307 | "import matplotlib.pyplot as plt\n",
308 | "import numpy as np\n",
309 | "import pandas as pd\n",
310 | "from pygam import GAM, s"
311 | ]
312 | },
313 | {
314 | "cell_type": "markdown",
315 | "metadata": {
316 | "id": "pmh91goS1DCT"
317 | },
318 | "source": [
319 | "Let's load and plot a real data set:"
320 | ]
321 | },
322 | {
323 | "cell_type": "code",
324 | "execution_count": null,
325 | "metadata": {
326 | "id": "2hwvlkYX1EPW"
327 | },
328 | "outputs": [],
329 | "source": [
330 | "data = pd.read_csv('../data/Exercise-ANM.csv')\n",
331 | "\n",
332 | "plt.scatter(data[\"X\"].values, data[\"Y\"].values, s=2.);"
333 | ]
334 | },
335 | {
336 | "cell_type": "markdown",
337 | "metadata": {
338 | "id": "-uDnv5eD2pGd"
339 | },
340 | "source": [
341 | "__a)__\n",
342 | "\n",
343 | "Do you believe that $X \to Y$ or that $X \gets Y$? Why?"
344 | ]
345 | },
346 | {
347 | "cell_type": "markdown",
348 | "metadata": {
349 | "id": "4owvM1J_2rcM"
350 | },
351 | "source": [
352 | "Double-click (or enter) to edit"
353 | ]
354 | },
355 | {
356 | "cell_type": "markdown",
357 | "metadata": {
358 | "id": "mYdffpZN2uDc"
359 | },
360 | "source": [
361 | "$$\n",
362 | "\newcommand{\indep}{{\,⫫\,}}\n",
363 | "\newcommand{\dep}{\not{}\!\!\indep}\n",
364 | "$$\n",
365 | "\n",
366 | "__b)__\n",
367 | "Let us now try to get a more statistical answer. We have heard that we cannot \n",
368 | "have \n",
369 | "$$Y = f(X) + N_Y,\ N_Y \indep X$$\n",
370 | "and\n",
371 | "$$X = g(Y) + N_X,\ N_X \indep Y$$\n",
372 | "at the same time.\n",
373 | "\n",
374 | "Given a data set over $(X,Y)$,\n",
375 | "we now want to decide for one of the two models. \n",
376 | "\n",
377 | "Come up with a method to do so.\n",
378 | "\n",
379 | "Hints: \n",
380 | "* `GAM(s(0)).fit(A, B).deviance_residuals(A, B)` provides residuals when regressing $B$ on $A$.\n",
381 | "* `hsic(a, b)` can be used as an independence test (here, `a` and `b` are $n \times 1$ numpy arrays).\n",
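"\n",
"Using these hints, one possible approach is the following (a minimal sketch added for illustration; it assumes `data`, `GAM`, `s`, and `hsic` from the cells above, and the reshaping to $n \times 1$ arrays is our addition to match the `hsic` hint):\n",
"```python\n",
"X = data['X'].values.reshape(-1, 1)\n",
"Y = data['Y'].values.reshape(-1, 1)\n",
"\n",
"# residuals of Y regressed on X (model X -> Y), and vice versa (model Y -> X)\n",
"res_xy = GAM(s(0)).fit(X, Y).deviance_residuals(X, Y).reshape(-1, 1)\n",
"res_yx = GAM(s(0)).fit(Y, X).deviance_residuals(Y, X).reshape(-1, 1)\n",
"\n",
"# the direction whose residuals are independent of the regressor\n",
"# (larger HSIC p-value) is the direction favoured by the ANM\n",
"print('p-value for X -> Y:', hsic(res_xy, X))\n",
"print('p-value for Y -> X:', hsic(res_yx, Y))\n",
"```"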
382 | ]
383 | },
384 | {
385 | "cell_type": "code",
386 | "execution_count": null,
387 | "metadata": {
388 | "id": "llz5Eeck2xz5"
389 | },
390 | "outputs": [],
391 | "source": []
392 | },
393 | {
394 | "cell_type": "markdown",
395 | "metadata": {
396 | "id": "o8SBfEFi6oqH"
397 | },
398 | "source": [
399 | "__c)__\n",
400 | "\n",
401 | "Assume that the error terms are Gaussian with zero mean and variances \n",
402 | "$\sigma_X^2$ and $\sigma_Y^2$, respectively.\n",
403 | "The maximized log-likelihood for a DAG $G$ is \n",
404 | "then (up to additive constants) proportional to \n",
405 | "$-\log(\mathrm{var}(R^G_X)) - \log(\mathrm{var}(R^G_Y))$,\n",
406 | "where $R^G_X$ and $R^G_Y$ are the residuals obtained from regressing $X$ and $Y$ on \n",
407 | "their parents in $G$, respectively (no proof).\n",
408 | "\n",
409 | "Find the maximum likelihood solution."
410 | ]
411 | },
412 | {
413 | "cell_type": "code",
414 | "execution_count": null,
415 | "metadata": {
416 | "id": "pASFG1DC6sQA"
417 | },
418 | "outputs": [],
419 | "source": []
420 | },
421 | {
422 | "cell_type": "markdown",
423 | "metadata": {
424 | "id": "d4JPYnSHrOfW"
425 | },
426 | "source": [
427 | "# Exercise 5 – Invariant Causal Prediction"
428 | ]
429 | },
430 | {
431 | "cell_type": "markdown",
432 | "metadata": {
433 | "id": "Fb5CwEUEAaOp"
434 | },
435 | "source": [
436 | "Set up required packages and data:"
437 | ]
438 | },
439 | {
440 | "cell_type": "code",
441 | "execution_count": null,
442 | "metadata": {
443 | "id": "wudBwtYswFeo"
444 | },
445 | "outputs": [],
446 | "source": [
447 | "# set up – not needed when run on mybinder\n",
448 | "# if needed (colab), change False to True and run cell\n",
449 | "if False:\n",
450 | " !mkdir ../data/\n",
451 | " !wget https://raw.githubusercontent.com/sweichwald/causality-tutorial-exercises/main/data/Exercise-ICP.csv -q -O ../data/Exercise-ICP.csv"
452 | ]
453 | },
454 | {
455 | "cell_type": "code",
456 | "execution_count": null,
457 | "metadata": {
458 | "id": "DIosbymbbkhg"
459 | },
460 | "outputs": [],
461 | "source": [
462 | "import matplotlib.pyplot as plt\n",
463 | "import numpy as np\n",
464 | "import pandas as pd\n",
465 | "import seaborn as sns\n",
466 | "import statsmodels.api as sm"
467 | ]
468 | },
469 | {
470 | "cell_type": "markdown",
471 | "metadata": {
472 | "id": "gy8eUIaDdmrz"
473 | },
474 | "source": [
475 | "__a)__\n",
476 | "\n",
477 | "Generate some observational and interventional data:"
478 | ]
479 | },
480 | {
481 | "cell_type": "code",
482 | "execution_count": null,
483 | "metadata": {
484 | "id": "IGBQGfetbYPj"
485 | },
486 | "outputs": [],
487 | "source": [
488 | "# Generate n=1000 observations from the observational distribution\n",
489 | "na = 1000\n",
490 | "Xa = np.random.normal(size=na)\n",
491 | "Ya = 1.5*Xa + np.random.normal(size=na)\n",
492 | "\n",
493 | "# Generate n=1000 observations from an interventional distribution\n",
494 | "nb = 1000\n",
495 | "Xb = np.random.normal(loc=2, scale=1, size=nb)\n",
496 | "Yb = 1.5*Xb + np.random.normal(size=nb)\n",
497 | "\n",
498 | "# plot Y vs X\n",
499 | "fig, ax = plt.subplots(figsize=(7,5))\n",
500 | "ax.scatter(Xa, Ya, label='observational', marker='o', alpha=0.6)\n",
501 | "ax.scatter(Xb, Yb, label='interventional', marker ='^', alpha=0.6)\n",
502 | "ax.legend();"
503 | ]
504 | },
505 | {
506 | "cell_type": "markdown",
507 | "metadata": {
508 | "id": "uZcSibWjypDR"
509 | },
510 | "source": [
511 | "Look at the above plot. 
Is the predictor $\\{X\\}$ an invariant set, that is (roughly speaking), does $Y \\mid X = x$ have the same distribution in the orange and blue data?" 512 | ] 513 | }, 514 | { 515 | "cell_type": "markdown", 516 | "metadata": { 517 | "id": "rhnmzEIiyvmt" 518 | }, 519 | "source": [ 520 | "Double-click (or enter) to edit" 521 | ] 522 | }, 523 | { 524 | "cell_type": "markdown", 525 | "metadata": { 526 | "id": "DnDdgV_QeEFH" 527 | }, 528 | "source": [ 529 | "__b)__" 530 | ] 531 | }, 532 | { 533 | "cell_type": "markdown", 534 | "metadata": { 535 | "id": "BqcN5gRdeRoi" 536 | }, 537 | "source": [ 538 | "We now consider data over a response and three covariates $X1, X2$, and $X3$\n", 539 | "and try to infer $\\mathrm{pa}(Y)$. To do so, we need to find all sets for which this\n", 540 | "invariance is satisfied." 541 | ] 542 | }, 543 | { 544 | "cell_type": "code", 545 | "execution_count": null, 546 | "metadata": { 547 | "id": "i4vMv59_wjKG" 548 | }, 549 | "outputs": [], 550 | "source": [ 551 | "# load data\n", 552 | "data = pd.read_csv('../data/Exercise-ICP.csv')\n", 553 | "data['env'] = np.concatenate([np.repeat('observational', 140), np.repeat('interventional', 80)])\n", 554 | "# pairplot\n", 555 | "sns.pairplot(data, hue='env', height=2, plot_kws={'alpha':0.6});" 556 | ] 557 | }, 558 | { 559 | "cell_type": "code", 560 | "execution_count": null, 561 | "metadata": { 562 | "id": "2yF7KhYZe7g9" 563 | }, 564 | "outputs": [], 565 | "source": [ 566 | "# The code below plots the residuals versus fitted values for all sets of \n", 567 | "# predictors. \n", 568 | "# extract response and predictors\n", 569 | "\n", 570 | "Y = data['Y'].to_numpy()\n", 571 | "X = data[['X1','X2','X3']].to_numpy()\n", 572 | "# get environment indicator\n", 573 | "obs_ind = data[data['env'] == 'observational'].index\n", 574 | "int_ind = data[data['env'] == 'interventional'].index\n", 575 | "# create all sets\n", 576 | "all_sets = [(0,), (1,), (2,), (0,1), (0,2), (1,2), (0,1,2)]\n", 577 | "# label each set\n", 578 | "set_labels = ['X1', 'X2', 'X3', 'X1,X2', 'X1,X3', 'X2,X3', 'X1,X2,X3']\n", 579 | "\n", 580 | "# fit OLS and store fitted values and residuals for each set\n", 581 | "fitted = []\n", 582 | "resid = []\n", 583 | "for s in all_sets:\n", 584 | " model = sm.OLS(Y, X[:, s]).fit()\n", 585 | " fitted += [model.fittedvalues]\n", 586 | " resid += [model.resid]\n", 587 | "\n", 588 | "# plotting function\n", 589 | "def plot_fitted_resid(fv, res, ax, title):\n", 590 | " ax.scatter(fv[obs_ind], res[obs_ind], label='observational', marker='o', alpha=0.6)\n", 591 | " ax.scatter(fv[int_ind], res[int_ind], label='interventional', marker ='^', alpha=0.6)\n", 592 | " ax.legend()\n", 593 | " ax.set_xlabel('fitted values')\n", 594 | " ax.set_ylabel('residuals')\n", 595 | " ax.set_title(title)\n", 596 | "\n", 597 | "# creating plots\n", 598 | "fig, axes = plt.subplots(4, 2, figsize=(7,14))\n", 599 | "\n", 600 | "# plot result for the empty set predictor\n", 601 | "ax0 = axes[0,0]\n", 602 | "ax0.scatter(obs_ind, Y[obs_ind], label='observational', marker='o', alpha=0.6)\n", 603 | "ax0.scatter(int_ind, Y[int_ind], label='interventional', marker ='^', alpha=0.6)\n", 604 | "ax0.legend()\n", 605 | "ax0.set_xlabel('index')\n", 606 | "ax0.set_ylabel('Y')\n", 607 | "ax0.set_title('empty set')\n", 608 | "\n", 609 | "# plot result for the other sets\n", 610 | "for i, ax in enumerate(axes.flatten()[1:]):\n", 611 | " plot_fitted_resid(fitted[i], resid[i], ax, set_labels[i])\n", 612 | "\n", 613 | "# make tight layout\n", 614 | "plt.tight_layout()" 615 | 
]
616 | },
617 | {
618 | "cell_type": "markdown",
619 | "metadata": {
620 | "id": "1GfZKCL7zJve"
621 | },
622 | "source": [
623 | "Which of the sets are invariant? (There are two plots with four scatter plots each.)"
624 | ]
625 | },
626 | {
627 | "cell_type": "markdown",
628 | "metadata": {
629 | "id": "j0sgjfRSzWEt"
630 | },
631 | "source": [
632 | "Double-click (or enter) to edit"
633 | ]
634 | },
635 | {
636 | "cell_type": "markdown",
637 | "metadata": {
638 | "id": "AO7tZSjLzMr0"
639 | },
640 | "source": [
641 | "__c)__\n",
642 | "What is your best guess for $\mathrm{pa}(Y)$?"
643 | ]
644 | },
645 | {
646 | "cell_type": "markdown",
647 | "metadata": {
648 | "id": "B6QtA9p9zdD7"
649 | },
650 | "source": [
651 | "Double-click (or enter) to edit"
652 | ]
653 | },
654 | {
655 | "cell_type": "markdown",
656 | "metadata": {
657 | "id": "AZGGVS8lP0Ly"
658 | },
659 | "source": [
660 | "__d) (optional)__\n",
661 | "\n",
662 | "Use the function ICP to check your result."
663 | ]
664 | },
665 | {
666 | "cell_type": "code",
667 | "execution_count": null,
668 | "metadata": {
669 | "id": "1Qi2_GCnQmEG"
670 | },
671 | "outputs": [],
672 | "source": [
673 | "# set up – not needed when run on mybinder\n",
674 | "# if needed (colab), change False to True and run cell\n",
675 | "if False:\n",
676 | " !pip install causalicp"
677 | ]
678 | },
679 | {
680 | "cell_type": "code",
681 | "execution_count": null,
682 | "metadata": {
683 | "id": "fqUzMXw5QLva"
684 | },
685 | "outputs": [],
686 | "source": [
687 | "import causalicp as icp"
688 | ]
689 | }
690 | ],
691 | "metadata": {
692 | "colab": {
693 | "collapsed_sections": [],
694 | "name": "Causality Tutorial Exercises – Python",
695 | "provenance": [],
696 | "toc_visible": true
697 | },
698 | "kernelspec": {
699 | "display_name": "Python 3",
700 | "language": "python",
701 | "name": "python3"
702 | },
703 | "language_info": {
704 | "codemirror_mode": {
705 | "name": "ipython",
706 | "version": 3
707 | },
708 | "file_extension": ".py",
709 | "mimetype": "text/x-python",
710 | "name": "python",
711 | "nbconvert_exporter": "python",
712 | "pygments_lexer": "ipython3",
713 | "version": "3.7.6"
714 | }
715 | },
716 | "nbformat": 4,
717 | "nbformat_minor": 1
718 | }
719 | 
--------------------------------------------------------------------------------
/python/kerpy/README.md:
--------------------------------------------------------------------------------
1 | Plain HSIC independence test, adapted from https://github.com/oxcsml/kerpy (MIT licensed)
2 | 
--------------------------------------------------------------------------------
/python/kerpy/__init__.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod
2 | import matplotlib.cm as cm
3 | import matplotlib.pyplot as plt
4 | import numpy as np
5 | from numpy import (arange, cos, dot, exp, fill_diagonal, mean,
6 |                    shape, sin, sqrt, zeros)
7 | from numpy.random import permutation, randn
8 | from scipy import linalg
9 | from scipy.linalg import sqrtm, inv
10 | from scipy.spatial.distance import squareform, pdist, cdist
11 | from scipy.stats import norm as normaldist
12 | import time
13 | 
14 | 
15 | def hsic(x, y):
16 |     if x.shape[0] > 222:
17 |         inds = np.random.choice(x.shape[0],
18 |                                 size=222,
19 |                                 replace=False)
20 |         X = x[inds, :]
21 |         Y = y[inds, :]
22 |     else:
23 |         X = x
24 |         Y = y
25 |     hsic = HSICPermutationTestObject(X.shape[0],
26 |                                      kernelX=GaussianKernel(),
27 |                                      kernelY=GaussianKernel())
28 |     return hsic.compute_pvalue(X, Y)
29 | 
30 | 
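# Usage sketch (an illustrative addition, not part of the upstream kerpy code):
# `hsic(x, y)` returns a permutation-test p-value for the independence of x
# and y, where x and y are n-by-1 numpy arrays; small p-values indicate
# dependence. The subsampling to 222 points above keeps the O(n^2) kernel
# matrices and the default 500 permutations tractable. For example:
#
#     x = np.random.randn(300, 1)
#     y = x + .1 * np.random.randn(300, 1)
#     p = hsic(x, y)  # expected to be small, since y depends on x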
class TestObject(object): 32 | 33 | def __init__(self, test_type, streaming=False, freeze_data=False): 34 | self.test_type = test_type 35 | self.streaming = streaming 36 | self.freeze_data = freeze_data 37 | if self.freeze_data: 38 | self.generate_data() 39 | assert not self.streaming 40 | 41 | @abstractmethod 42 | def compute_Zscore(self): 43 | raise NotImplementedError 44 | 45 | @abstractmethod 46 | def generate_data(self): 47 | raise NotImplementedError 48 | 49 | def compute_pvalue(self): 50 | Z_score = self.compute_Zscore() 51 | pvalue = normaldist.sf(Z_score) 52 | return pvalue 53 | 54 | def perform_test(self, alpha): 55 | pvalue = self.compute_pvalue() 56 | return pvalue < alpha 57 | 58 | 59 | class HSICTestObject(TestObject): 60 | def __init__(self, 61 | num_samples, 62 | data_generator=None, 63 | kernelX=None, 64 | kernelY=None, 65 | kernelZ=None, 66 | kernelX_use_median=False, 67 | kernelY_use_median=False, 68 | kernelZ_use_median=False, 69 | rff=False, 70 | num_rfx=None, 71 | num_rfy=None, 72 | induce_set=False, 73 | num_inducex=None, 74 | num_inducey=None, 75 | streaming=False, 76 | freeze_data=False): 77 | TestObject.__init__(self, 78 | self.__class__.__name__, 79 | streaming=streaming, 80 | freeze_data=freeze_data) 81 | # We have same number of samples from X and Y in independence testing 82 | self.num_samples = num_samples 83 | self.data_generator = data_generator 84 | self.kernelX = kernelX 85 | self.kernelY = kernelY 86 | self.kernelZ = kernelZ 87 | # indicate if median heuristic for Gaussian Kernel should be used 88 | self.kernelX_use_median = kernelX_use_median 89 | self.kernelY_use_median = kernelY_use_median 90 | self.kernelZ_use_median = kernelZ_use_median 91 | self.rff = rff 92 | self.num_rfx = num_rfx 93 | self.num_rfy = num_rfy 94 | self.induce_set = induce_set 95 | self.num_inducex = num_inducex 96 | self.num_inducey = num_inducey 97 | if self.rff | self.induce_set: 98 | self.HSICmethod = self.HSIC_with_shuffles_rff 99 | else: 100 | self.HSICmethod = self.HSIC_with_shuffles 101 | 102 | def generate_data(self, isConditionalTesting=False): 103 | if not isConditionalTesting: 104 | self.data_x, self.data_y = self.data_generator(self.num_samples) 105 | return self.data_x, self.data_y 106 | else: 107 | self.data_x, self.data_y, self.data_z = self.data_generator( 108 | self.num_samples) 109 | return self.data_x, self.data_y, self.data_z 110 | 111 | @staticmethod 112 | def HSIC_U_statistic(Kx, Ky): 113 | m = shape(Kx)[0] 114 | fill_diagonal(Kx, 0.) 115 | fill_diagonal(Ky, 0.) 
116 | K = np.dot(Kx, Ky) 117 | first_term = np.trace(K)/float(m*(m-3.)) 118 | second_term = np.sum(Kx)*np.sum(Ky)/float(m*(m-3.)*(m-1.)*(m-2.)) 119 | third_term = 2.*np.sum(K)/float(m*(m-3.)*(m-2.)) 120 | return first_term+second_term-third_term 121 | 122 | @staticmethod 123 | def HSIC_V_statistic(Kx, Ky): 124 | Kxc = Kernel.center_kernel_matrix(Kx) 125 | Kyc = Kernel.center_kernel_matrix(Ky) 126 | return np.sum(Kxc*Kyc) 127 | 128 | @staticmethod 129 | def HSIC_V_statistic_rff(phix, phiy): 130 | m = shape(phix)[0] 131 | phix_c = phix-mean(phix, axis=0) 132 | phiy_c = phiy-mean(phiy, axis=0) 133 | featCov = (phix_c.T).dot(phiy_c)/float(m) 134 | return np.linalg.norm(featCov)**2 135 | 136 | # generalise distance correlation ---- a kernel interpretation 137 | @staticmethod 138 | def dCor_HSIC_statistic(Kx, Ky, unbiased=False): 139 | if unbiased: 140 | first_term = HSICTestObject.HSIC_U_statistic(Kx, Ky) 141 | second_term = HSICTestObject.HSIC_U_statistic(Kx, Kx) \ 142 | * HSICTestObject.HSIC_U_statistic(Ky, Ky) 143 | dCor = first_term/float(sqrt(second_term)) 144 | else: 145 | first_term = HSICTestObject.HSIC_V_statistic(Kx, Ky) 146 | second_term = HSICTestObject.HSIC_V_statistic(Kx, Kx) \ 147 | * HSICTestObject.HSIC_V_statistic(Ky, Ky) 148 | dCor = first_term/float(sqrt(second_term)) 149 | return dCor 150 | 151 | # approximated dCor using rff/Nystrom 152 | @staticmethod 153 | def dCor_HSIC_statistic_rff(phix, phiy): 154 | first_term = HSICTestObject.HSIC_V_statistic_rff(phix, phiy) 155 | second_term = HSICTestObject.HSIC_V_statistic_rff(phix, phix) \ 156 | * HSICTestObject.HSIC_V_statistic_rff(phiy, phiy) 157 | approx_dCor = first_term/float(sqrt(second_term)) 158 | return approx_dCor 159 | 160 | def SubdCor_HSIC_statistic(self, data_x=None, data_y=None, unbiased=True): 161 | if data_x is None: 162 | data_x = self.data_x 163 | if data_y is None: 164 | data_y = self.data_y 165 | dx = shape(data_x)[1] 166 | stats_value = zeros(dx) 167 | for dd in range(dx): 168 | Kx, Ky = self.compute_kernel_matrix_on_data( 169 | data_x[:, [dd]], data_y) 170 | stats_value[dd] = HSICTestObject.dCor_HSIC_statistic( 171 | Kx, Ky, unbiased) 172 | SubdCor = sum(stats_value)/float(dx) 173 | return SubdCor 174 | 175 | def SubHSIC_statistic(self, data_x=None, data_y=None, unbiased=True): 176 | if data_x is None: 177 | data_x = self.data_x 178 | if data_y is None: 179 | data_y = self.data_y 180 | dx = shape(data_x)[1] 181 | stats_value = zeros(dx) 182 | for dd in range(dx): 183 | Kx, Ky = self.compute_kernel_matrix_on_data( 184 | data_x[:, [dd]], data_y) 185 | if unbiased: 186 | stats_value[dd] = HSICTestObject.HSIC_U_statistic(Kx, Ky) 187 | else: 188 | stats_value[dd] = HSICTestObject.HSIC_V_statistic(Kx, Ky) 189 | SubHSIC = sum(stats_value)/float(dx) 190 | return SubHSIC 191 | 192 | def HSIC_with_shuffles(self, 193 | data_x=None, 194 | data_y=None, 195 | unbiased=True, 196 | num_shuffles=0, 197 | estimate_nullvar=False, 198 | isBlockHSIC=False): 199 | start = time.time() 200 | if data_x is None: 201 | data_x = self.data_x 202 | if data_y is None: 203 | data_y = self.data_y 204 | time_passed = time.time() - start 205 | if isBlockHSIC: 206 | Kx, Ky = self.compute_kernel_matrix_on_dataB(data_x, data_y) 207 | else: 208 | Kx, Ky = self.compute_kernel_matrix_on_data(data_x, data_y) 209 | ny = shape(data_y)[0] 210 | if unbiased: 211 | test_statistic = HSICTestObject.HSIC_U_statistic(Kx, Ky) 212 | else: 213 | test_statistic = HSICTestObject.HSIC_V_statistic(Kx, Ky) 214 | null_samples = zeros(num_shuffles) 215 | for jj in 
range(num_shuffles):
216 |             pp = permutation(ny)
217 |             Kpp = Ky[pp, :][:, pp]
218 |             if unbiased:
219 |                 null_samples[jj] = HSICTestObject.HSIC_U_statistic(Kx, Kpp)
220 |             else:
221 |                 null_samples[jj] = HSICTestObject.HSIC_V_statistic(Kx, Kpp)
222 |         if estimate_nullvar:
223 |             nullvarx, nullvary = \
224 |                 self.unbiased_HSnorm_estimate_of_centred_operator(Kx, Ky)
225 |             nullvarx = 2. * nullvarx
226 |             nullvary = 2. * nullvary
227 |         else:
228 |             nullvarx, nullvary = None, None
229 |         return (test_statistic,
230 |                 null_samples,
231 |                 nullvarx,
232 |                 nullvary,
233 |                 Kx,
234 |                 Ky,
235 |                 time_passed)
236 | 
237 |     def HSIC_with_shuffles_rff(self,
238 |                                data_x=None,
239 |                                data_y=None,
240 |                                unbiased=True,
241 |                                num_shuffles=0,
242 |                                estimate_nullvar=False):
243 |         start = time.time()
244 |         if data_x is None:
245 |             data_x = self.data_x
246 |         if data_y is None:
247 |             data_y = self.data_y
248 |         time_passed = time.time()-start
249 |         if self.rff:
250 |             phix, phiy = self.compute_rff_on_data(data_x, data_y)
251 |         else:
252 |             phix, phiy = self.compute_induced_kernel_matrix_on_data(
253 |                 data_x, data_y)
254 |         ny = shape(data_y)[0]
255 |         if unbiased:
256 |             test_statistic = HSICTestObject.HSIC_U_statistic_rff(phix, phiy)  # note: HSIC_U_statistic_rff is not implemented in this module
257 |         else:
258 |             test_statistic = HSICTestObject.HSIC_V_statistic_rff(phix, phiy)
259 |         null_samples = zeros(num_shuffles)
260 |         for jj in range(num_shuffles):
261 |             pp = permutation(ny)
262 |             if unbiased:
263 |                 null_samples[jj] = HSICTestObject.HSIC_U_statistic_rff(
264 |                     phix, phiy[pp])
265 |             else:
266 |                 null_samples[jj] = HSICTestObject.HSIC_V_statistic_rff(
267 |                     phix, phiy[pp])
268 |         if estimate_nullvar:
269 |             raise NotImplementedError()
270 |         else:
271 |             nullvarx, nullvary = None, None
272 |         return (test_statistic,
273 |                 null_samples,
274 |                 nullvarx,
275 |                 nullvary,
276 |                 phix,
277 |                 phiy,
278 |                 time_passed)
279 | 
280 |     def get_spectrum_on_data(self, Mx, My):
281 |         '''Mx and My are Kx Ky when rff =False
282 |         Mx and My are phix, phiy when rff =True'''
283 |         if self.rff | self.induce_set:
284 |             Cx = np.cov(Mx.T)
285 |             Cy = np.cov(My.T)
286 |             lambdax = np.linalg.eigvalsh(Cx)
287 |             lambday = np.linalg.eigvalsh(Cy)
288 |         else:
289 |             Kxc = Kernel.center_kernel_matrix(Mx)
290 |             Kyc = Kernel.center_kernel_matrix(My)
291 |             lambdax = np.linalg.eigvalsh(Kxc)
292 |             lambday = np.linalg.eigvalsh(Kyc)
293 |         return lambdax, lambday
294 | 
295 |     @abstractmethod
296 |     def compute_kernel_matrix_on_data(self, data_x, data_y):
297 |         if self.kernelX_use_median:
298 |             sigmax = self.kernelX.get_sigma_median_heuristic(data_x)
299 |             self.kernelX.set_width(float(sigmax))
300 |         if self.kernelY_use_median:
301 |             sigmay = self.kernelY.get_sigma_median_heuristic(data_y)
302 |             self.kernelY.set_width(float(sigmay))
303 |         Kx = self.kernelX.kernel(data_x)
304 |         Ky = self.kernelY.kernel(data_y)
305 |         return Kx, Ky
306 | 
307 |     @abstractmethod
308 |     def compute_kernel_matrix_on_dataB(self, data_x, data_y):
309 |         Kx = self.kernelX.kernel(data_x)
310 |         Ky = self.kernelY.kernel(data_y)
311 |         return Kx, Ky
312 | 
313 |     @abstractmethod
314 |     def compute_kernel_matrix_on_data_CI(self, data_x, data_y, data_z):
315 |         if self.kernelX_use_median:
316 |             sigmax = self.kernelX.get_sigma_median_heuristic(data_x)
317 |             self.kernelX.set_width(float(sigmax))
318 |         if self.kernelY_use_median:
319 |             sigmay = self.kernelY.get_sigma_median_heuristic(data_y)
320 |             self.kernelY.set_width(float(sigmay))
321 |         if self.kernelZ_use_median:
322 |             sigmaz = self.kernelZ.get_sigma_median_heuristic(data_z)
323 |             self.kernelZ.set_width(float(sigmaz))
324 |         Kx = 
self.kernelX.kernel(data_x) 325 | Ky = self.kernelY.kernel(data_y) 326 | Kz = self.kernelZ.kernel(data_z) 327 | return Kx, Ky, Kz 328 | 329 | def unbiased_HSnorm_estimate_of_centred_operator(self, Kx, Ky): 330 | '''returns an unbiased estimate of 2*Sum_p Sum_q lambda^2_p theta^2_q 331 | where lambda and theta are the eigenvalues 332 | of the centered matrices for X and Y respectively''' 333 | varx = HSICTestObject.HSIC_U_statistic(Kx, Kx) 334 | vary = HSICTestObject.HSIC_U_statistic(Ky, Ky) 335 | return varx, vary 336 | 337 | @abstractmethod 338 | def compute_rff_on_data(self, data_x, data_y): 339 | self.kernelX.rff_generate(self.num_rfx, dim=shape(data_x)[1]) 340 | self.kernelY.rff_generate(self.num_rfy, dim=shape(data_y)[1]) 341 | if self.kernelX_use_median: 342 | sigmax = self.kernelX.get_sigma_median_heuristic(data_x) 343 | self.kernelX.set_width(float(sigmax)) 344 | if self.kernelY_use_median: 345 | sigmay = self.kernelY.get_sigma_median_heuristic(data_y) 346 | self.kernelY.set_width(float(sigmay)) 347 | phix = self.kernelX.rff_expand(data_x) 348 | phiy = self.kernelY.rff_expand(data_y) 349 | return phix, phiy 350 | 351 | @abstractmethod 352 | def compute_induced_kernel_matrix_on_data(self, data_x, data_y): 353 | '''Z follows the same distribution as X; W follows that of Y. 354 | The current data generating methods we use 355 | generate X and Y at the same time. ''' 356 | size_induced_set = max(self.num_inducex, self.num_inducey) 357 | if self.data_generator is None: 358 | subsample_idx = np.random.randint( 359 | self.num_samples, size=size_induced_set) 360 | self.data_z = data_x[subsample_idx, :] 361 | self.data_w = data_y[subsample_idx, :] 362 | else: 363 | self.data_z, self.data_w = self.data_generator(size_induced_set) 364 | self.data_z[[range(self.num_inducex)], :] 365 | self.data_w[[range(self.num_inducey)], :] 366 | if self.kernelX_use_median: 367 | sigmax = self.kernelX.get_sigma_median_heuristic(data_x) 368 | self.kernelX.set_width(float(sigmax)) 369 | if self.kernelY_use_median: 370 | sigmay = self.kernelY.get_sigma_median_heuristic(data_y) 371 | self.kernelY.set_width(float(sigmay)) 372 | Kxz = self.kernelX.kernel(data_x, self.data_z) 373 | Kzz = self.kernelX.kernel(self.data_z) 374 | # R = inv(sqrtm(Kzz)) 375 | R = inv(sqrtm(Kzz + np.eye(np.shape(Kzz)[0])*10**(-6))) 376 | phix = Kxz.dot(R) 377 | Kyw = self.kernelY.kernel(data_y, self.data_w) 378 | Kww = self.kernelY.kernel(self.data_w) 379 | # S = inv(sqrtm(Kww)) 380 | S = inv(sqrtm(Kww + np.eye(np.shape(Kww)[0])*10**(-6))) 381 | phiy = Kyw.dot(S) 382 | return phix, phiy 383 | 384 | def compute_pvalue(self, data_x=None, data_y=None): 385 | pvalue, _ = self.compute_pvalue_with_time_tracking(data_x, data_y) 386 | return pvalue 387 | 388 | 389 | class HSICPermutationTestObject(HSICTestObject): 390 | 391 | def __init__(self, 392 | num_samples, 393 | data_generator=None, 394 | kernelX=None, 395 | kernelY=None, 396 | kernelX_use_median=False, 397 | kernelY_use_median=False, 398 | num_rfx=None, 399 | num_rfy=None, 400 | rff=False, 401 | induce_set=False, 402 | num_inducex=None, 403 | num_inducey=None, 404 | num_shuffles=500, 405 | unbiased=True): 406 | HSICTestObject.__init__(self, 407 | num_samples, 408 | data_generator=data_generator, 409 | kernelX=kernelX, 410 | kernelY=kernelY, 411 | kernelX_use_median=kernelX_use_median, 412 | kernelY_use_median=kernelY_use_median, 413 | num_rfx=num_rfx, 414 | num_rfy=num_rfy, 415 | rff=rff, 416 | induce_set=induce_set, 417 | num_inducex=num_inducex, 418 | num_inducey=num_inducey) 419 | 
self.num_shuffles = num_shuffles
420 |         self.unbiased = unbiased
421 | 
422 |     def compute_pvalue_with_time_tracking(self, data_x=None, data_y=None):
423 |         if data_x is None and data_y is None:
424 |             if not self.streaming and not self.freeze_data:
425 |                 start = time.time()
426 |                 self.generate_data()
427 |                 data_generating_time = time.time()-start
428 |                 data_x = self.data_x
429 |                 data_y = self.data_y
430 |             else:
431 |                 data_generating_time = 0.
432 |         else:
433 |             data_generating_time = 0.
434 |         hsic_statistic, null_samples = self.HSICmethod(
435 |             unbiased=self.unbiased,
436 |             num_shuffles=self.num_shuffles,
437 |             data_x=data_x,
438 |             data_y=data_y)[:2]
439 |         pvalue = (1 + sum(null_samples > hsic_statistic)) \
440 |             / float(1 + self.num_shuffles)
441 | 
442 |         return pvalue, data_generating_time
443 | 
444 | 
445 | class GenericTests():
446 |     @staticmethod
447 |     def check_type(varvalue, varname, vartype, required_shapelen=None):
448 |         if type(varvalue) is not vartype:
449 |             raise TypeError(
450 |                 "Variable " + varname
451 |                 + " must be of type " + vartype.__name__ +
452 |                 ". Given is " + str(type(varvalue)))
453 |         if required_shapelen is not None:
454 |             if len(varvalue.shape) != required_shapelen:
455 |                 raise ValueError(
456 |                     "Variable " + varname
457 |                     + " must be " + str(required_shapelen) + "-dimensional")
458 |         return 0
459 | 
460 | 
461 | class Kernel(object):
462 |     def __init__(self):
463 |         self.rff_num = None
464 |         self.rff_freq = None
465 |         pass
466 | 
467 |     def __str__(self):
468 |         s = ""
469 |         return s
470 | 
471 |     @abstractmethod
472 |     def kernel(self, X, Y=None):
473 |         raise NotImplementedError()
474 | 
475 |     @abstractmethod
476 |     def set_kerpar(self, kerpar):
477 |         self.set_width(kerpar)
478 | 
479 |     @abstractmethod
480 |     def set_width(self, width):
481 |         if hasattr(self, 'width'):
482 |             if self.rff_freq is not None:
483 |                 self.rff_freq = self.unit_rff_freq / width
484 |             self.width = width
485 |         else:
486 |             raise ValueError("Senseless: kernel has no 'width' attribute!")
487 | 
488 |     @abstractmethod
489 |     def rff_generate(self, m, dim=1):
490 |         raise NotImplementedError()
491 | 
492 |     @abstractmethod
493 |     def rff_expand(self, X):
494 |         if self.rff_freq is None:
495 |             raise ValueError(
496 |                 "rff_freq has not been set. use rff_generate first")
497 |         """
498 |         Computes the random Fourier features for the input dataset X
499 |         for a set of frequencies in rff_freq.
500 |         This set of frequencies has to be precomputed
501 |         X - 2d numpy.ndarray, first set of samples:
502 |             number of rows: number of samples
503 |             number of columns: dimensionality
504 |         """
505 |         GenericTests.check_type(X, 'X', np.ndarray)
506 |         xdotw = dot(X, (self.rff_freq).T)
507 |         return sqrt(2. 
/ self.rff_num) * np.concatenate(
508 |             (cos(xdotw), sin(xdotw)), axis=1)
509 | 
510 |     @abstractmethod
511 |     def gradient(self, x, Y):
512 | 
513 |         # ensure this in every implementation
514 |         assert(len(shape(x)) == 1)
515 |         assert(len(shape(Y)) == 2)
516 |         assert(len(x) == shape(Y)[1])
517 | 
518 |         raise NotImplementedError()
519 | 
520 |     @staticmethod
521 |     def centering_matrix(n):
522 |         """
523 |         Returns the centering matrix eye(n) - 1.0 / n
524 |         """
525 |         return np.eye(n) - 1.0 / n
526 | 
527 |     @staticmethod
528 |     def center_kernel_matrix(K):
529 |         """
530 |         Centers the kernel matrix via a centering matrix H=I-1/n
531 |         and returns HKH
532 |         """
533 |         n = shape(K)[0]
534 |         H = np.eye(n) - 1.0 / n
535 |         return 1.0 / n * H.dot(K.dot(H))
536 | 
537 |     @abstractmethod
538 |     def svc(self, X, y, lmbda=1.0, Xtst=None, ytst=None):
539 |         from sklearn import svm
540 |         svc = svm.SVC(kernel=self.kernel, C=lmbda)
541 |         svc.fit(X, y)
542 |         if Xtst is None:
543 |             return svc
544 |         else:
545 |             ypre = svc.predict(Xtst)
546 |             if ytst is None:
547 |                 return svc, ypre
548 |             else:
549 |                 return svc, ypre, 1-svc.score(Xtst, ytst)
550 | 
551 |     @abstractmethod
552 |     def svc_rff(self, X, y, lmbda=1.0, Xtst=None, ytst=None):
553 |         from sklearn import svm
554 |         phi = self.rff_expand(X)
555 |         svc = svm.LinearSVC(C=lmbda, dual=True)
556 |         svc.fit(phi, y)
557 |         if Xtst is None:
558 |             return svc
559 |         else:
560 |             phitst = self.rff_expand(Xtst)
561 |             ypre = svc.predict(phitst)
562 |             if ytst is None:
563 |                 return svc, ypre
564 |             else:
565 |                 return svc, ypre, 1-svc.score(phitst, ytst)
566 | 
567 |     @abstractmethod
568 |     def ridge_regress(self, X, y, lmbda=0.01, Xtst=None, ytst=None):
569 |         K = self.kernel(X)
570 |         n = shape(K)[0]
571 |         aa = linalg.solve(K + lmbda * np.eye(n), y)
572 |         if Xtst is None:
573 |             return aa
574 |         else:
575 |             ypre = dot(aa.T, self.kernel(X, Xtst)).T
576 |             if ytst is None:
577 |                 return aa, ypre
578 |             else:
579 |                 return aa, ypre, (linalg.norm(ytst-ypre)**2)/np.shape(ytst)[0]
580 | 
581 |     @abstractmethod
582 |     def ridge_regress_rff(self, X, y, lmbda=0.01, Xtst=None, ytst=None):
583 |         phi = self.rff_expand(X)
584 |         bb = linalg.solve(
585 |             dot(phi.T, phi)+lmbda*np.eye(self.rff_num), dot(phi.T, y))
586 |         if Xtst is None:
587 |             return bb
588 |         else:
589 |             phitst = self.rff_expand(Xtst)
590 |             ypre = dot(phitst, bb)
591 |             if ytst is None:
592 |                 return bb, ypre
593 |             else:
594 |                 return bb, ypre, (linalg.norm(ytst-ypre)**2)/np.shape(ytst)[0]
595 | 
596 |     @abstractmethod
597 |     def xvalidate(self,
598 |                   X,
599 |                   y,
600 |                   method='ridge_regress',
601 |                   regpar_grid=(1+arange(25))/200.0,
602 |                   kerpar_grid=exp(-13+arange(25)),
603 |                   numFolds=10,
604 |                   verbose=False,
605 |                   visualise=False):
606 |         from sklearn.model_selection import KFold
607 |         which_method = getattr(self, method)
608 |         n = len(X)
609 |         kf = KFold(n_splits=numFolds)
610 |         xvalerr = zeros((len(regpar_grid), len(kerpar_grid)))
611 |         width_idx = 0
612 |         for width in kerpar_grid:
613 |             try:
614 |                 self.set_kerpar(width)
615 |             except ValueError:
616 |                 xvalerr[:, width_idx] = np.inf
617 |                 width_idx += 1
618 |                 continue
619 |             else:
620 |                 lmbda_idx = 0
621 |                 for lmbda in regpar_grid:
622 |                     fold = 0
623 |                     prederr = zeros(numFolds)
624 |                     for train_index, test_index in kf.split(X):
625 |                         if type(X) == list:
626 |                             X_train = [X[i] for i in train_index]
627 |                             X_test = [X[i] for i in test_index]
628 |                         else:
629 |                             X_train, X_test = X[train_index], X[test_index]
630 |                         if type(y) == list:
631 |                             y_train = [y[i] for i in train_index]
632 |                             y_test = [y[i] for i in test_index]
633 |                         else:
634 |                             y_train, y_test = y[train_index], y[test_index]
635 |                         _, _, prederr[fold] = which_method(X_train,
636 |                                                            y_train,
637 |                                                            lmbda=lmbda,
638 |                                                            Xtst=X_test,
639 |                                                            ytst=y_test)
640 |                         fold += 1
641 |                     xvalerr[lmbda_idx, width_idx] = mean(prederr)
642 |                     lmbda_idx += 1
643 |                 width_idx += 1
644 |         min_idx = np.unravel_index(np.argmin(xvalerr), shape(xvalerr))
645 |         if visualise:
646 |             plt.imshow(xvalerr,
647 |                        interpolation='none',
648 |                        origin='lower',
649 |                        cmap=cm.pink)
650 |             plt.colorbar()
651 |             plt.title("cross-validated loss")
652 |             plt.ylabel("regularisation parameter")
653 |             plt.xlabel("kernel parameter")
654 |             plt.show()
655 |         return regpar_grid[min_idx[0]], kerpar_grid[min_idx[1]]
656 | 
657 |     @abstractmethod
658 |     def estimateMMD(self, sample1, sample2, unbiased=False):
659 |         """
660 |         Compute the MMD between two samples
661 |         """
662 |         K11 = self.kernel(sample1)
663 |         K22 = self.kernel(sample2)
664 |         K12 = self.kernel(sample1, sample2)
665 |         if unbiased:
666 |             fill_diagonal(K11, 0.0)
667 |             fill_diagonal(K22, 0.0)
668 |             n = float(shape(K11)[0])
669 |             m = float(shape(K22)[0])
670 |             return sum(sum(K11))/(pow(n, 2)-n) \
671 |                 + sum(sum(K22))/(pow(m, 2)-m) - 2*mean(K12[:])
672 |         else:
673 |             return mean(K11[:])+mean(K22[:])-2*mean(K12[:])
674 | 
675 |     @abstractmethod
676 |     def estimateMMD_rff(self, sample1, sample2, unbiased=False):
677 |         phi1 = self.rff_expand(sample1)
678 |         phi2 = self.rff_expand(sample2)
679 |         featuremean1 = mean(phi1, axis=0)
680 |         featuremean2 = mean(phi2, axis=0)
681 |         if unbiased:
682 |             nx = shape(phi1)[0]
683 |             ny = shape(phi2)[0]
684 |             first_term = nx/(nx-1.0)*(dot(featuremean1, featuremean1)
685 |                                       - mean(linalg.norm(phi1, axis=1)**2)/nx)
686 |             second_term = ny/(ny-1.0)*(dot(featuremean2, featuremean2)
687 |                                        - mean(linalg.norm(phi2, axis=1)**2)/ny)
688 |             third_term = -2*dot(featuremean1, featuremean2)
689 |             return first_term+second_term+third_term
690 |         else:
691 |             return linalg.norm(featuremean1-featuremean2)**2
692 | 
693 | 
694 | class GaussianKernel(Kernel):
695 |     def __init__(self, sigma=1.0, is_sparse=False):
696 |         Kernel.__init__(self)
697 |         self.width = sigma
698 |         self.is_sparse = is_sparse
699 | 
700 |     def __str__(self):
701 |         s = self.__class__.__name__ + "["
702 |         s += "width=" + str(self.width)
703 |         s += "]"
704 |         return s
705 | 
706 |     def kernel(self, X, Y=None):
707 |         """
708 |         Computes the standard Gaussian kernel
709 |         k(x,y)=exp(-0.5* ||x-y||**2 / sigma**2)
710 | 
711 |         X - 2d numpy.ndarray, first set of samples:
712 |             number of rows: number of samples
713 |             number of columns: dimensionality
714 |         Y - 2d numpy.ndarray, second set of samples,
715 |             can be None in which case it is replaced by X
716 |         """
717 |         if self.is_sparse:
718 |             X = X.todense()
719 |             Y = Y.todense() if Y is not None else None
720 |         GenericTests.check_type(X, 'X', np.ndarray)
721 |         assert(len(shape(X)) == 2)
722 | 
723 |         # if X=Y, use more efficient pdist call which exploits symmetry
724 |         if Y is None:
725 |             sq_dists = squareform(pdist(X, 'sqeuclidean'))
726 |         else:
727 |             GenericTests.check_type(Y, 'Y', np.ndarray)
728 |             assert(len(shape(Y)) == 2)
729 |             assert(shape(X)[1] == shape(Y)[1])
730 |             sq_dists = cdist(X, Y, 'sqeuclidean')
731 | 
732 |         K = exp(-0.5 * (sq_dists) / self.width ** 2)
733 |         return K
734 | 
735 |     def gradient(self, x, Y):
736 |         """
737 |         Computes the gradient of the Gaussian kernel
738 |         wrt. the left argument, i.e. 
739 | k(x,y)=exp(-0.5* ||x-y||**2 / sigma**2), which is 740 | \nabla_x k(x,y)=1.0/sigma**2 k(x,y)(y-x) 741 | Given a set of row vectors Y, this computes the 742 | gradient for every pair (x,y) for y in Y. 743 | """ 744 | if self.is_sparse: 745 | x = x.todense() 746 | Y = Y.todense() 747 | assert(len(shape(x)) == 1) 748 | assert(len(shape(Y)) == 2) 749 | assert(len(x) == shape(Y)[1]) 750 | 751 | x_2d = np.reshape(x, (1, len(x))) 752 | k = self.kernel(x_2d, Y) 753 | differences = Y - x 754 | G = (1.0 / self.width ** 2) * (k.T * differences) 755 | return G 756 | 757 | def rff_generate(self, m, dim=1): 758 | self.rff_num = m 759 | self.unit_rff_freq = randn(int(m/2), dim) 760 | self.rff_freq = self.unit_rff_freq/self.width 761 | 762 | @staticmethod 763 | def get_sigma_median_heuristic(X, is_sparse=False): 764 | if is_sparse: 765 | X = X.todense() 766 | n = shape(X)[0] 767 | if n > 1000: 768 | X = X[permutation(n)[:1000], :] 769 | dists = squareform(pdist(X, 'euclidean')) 770 | median_dist = np.median(dists[dists > 0]) 771 | sigma = median_dist/sqrt(2.) 772 | return sigma 773 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | causalicp 2 | matplotlib 3 | numpy 4 | pandas 5 | pygam 6 | seaborn 7 | statsmodels 8 | -------------------------------------------------------------------------------- /runtime.txt: -------------------------------------------------------------------------------- 1 | r-4.1.2-2023-06-01 2 | --------------------------------------------------------------------------------