├── .gitignore ├── Chapter_01.ipynb ├── Chapter_02.ipynb ├── Chapter_03.ipynb ├── Chapter_04.ipynb ├── Chapter_05.ipynb ├── Chapter_06.ipynb ├── Chapter_07.ipynb ├── Chapter_08.ipynb ├── Chapter_09.ipynb ├── Chapter_10.ipynb ├── Chapter_11.1.ipynb ├── Chapter_11.2.ipynb ├── Chapter_13.ipynb ├── Chapter_14.ipynb ├── EXP -- Chapter_09.ipynb ├── Extras_01__DGPs.ipynb ├── Extras_02__Additional_computations.ipynb ├── LICENSE ├── README.md ├── causal-pymc.yml ├── causal_book_py39_apple_m1_(experimental-by-ferrari-leo).txt ├── causal_book_py39_cuda117.yml ├── causal_model.png ├── data ├── ch_01_drug_data.csv ├── data_11_no_interaction_test.csv ├── data_11_no_interaction_train.csv ├── data_11_with_interaction_test.csv ├── data_11_with_interaction_train.csv ├── gt_social_media_data.csv ├── hillstrom_clean.csv ├── hillstrom_clean_label_mapping.json ├── hillstrom_original.csv ├── manga.csv ├── manga_processed.csv ├── ml_earnings.csv ├── ml_earnings_interaction_test.csv ├── ml_earnings_interaction_train.csv └── shpitser_thesis1.grapl ├── docs └── himsolt-gml-technical-report.pdf ├── errata ├── Errata - Early Print (ordered before June 13 2023).ipynb ├── Errata - Non-Early Print (ordered after June 13 2023).ipynb ├── img │ ├── ch_04__fig_4_1.png │ ├── ch_06__fig_6_3.png │ ├── ch_06__fig_6_4.png │ ├── ch_06__fig_6_5.png │ └── ch_07__fig_7_6.png └── minor-errors.txt ├── img ├── ch_03_graph_01 ├── ch_03_graph_01.png ├── ch_03_graph_02 ├── ch_03_graph_02.png ├── ch_04_graph_DAG ├── ch_04_graph_DAG.png ├── ch_04_graph_DCG ├── ch_04_graph_DCG.png ├── ch_04_graph_Fully connected ├── ch_04_graph_Fully connected.png ├── ch_04_graph_Partially connected ├── ch_04_graph_Partially connected.png ├── ch_04_graph_Undirected ├── ch_04_graph_Undirected.png ├── ch_04_graph_adj_00 ├── ch_04_graph_adj_00.png ├── ch_04_graph_adj_01 ├── ch_04_graph_adj_01.png ├── ch_04_graph_adj_02 ├── ch_04_graph_adj_02.png ├── ch_05_chain_00 ├── ch_05_chain_00.png ├── ch_05_collider_00 ├── ch_05_collider_00.png ├── 
ch_05_fork_00 ├── ch_05_fork_00.png ├── ch_05_markov_01 ├── ch_05_markov_01.png ├── ch_05_markov_02 ├── ch_05_markov_02.png ├── ch_06_confounding_00 ├── ch_06_confounding_00.png ├── ch_06_d_sep_00 ├── ch_06_d_sep_00.png ├── ch_06_d_sep_01 ├── ch_06_d_sep_01.png ├── ch_06_d_sep_02 ├── ch_06_d_sep_02.png ├── ch_06_d_sep_03 ├── ch_06_d_sep_03.png ├── ch_06_d_sep_04 ├── ch_06_d_sep_04.png ├── ch_06_equivalent_estimands_00 ├── ch_06_equivalent_estimands_00.png ├── ch_06_equivalent_estimands_01 ├── ch_06_equivalent_estimands_01.png ├── ch_06_gps_00 ├── ch_06_gps_00.png ├── ch_06_gps_01 ├── ch_06_gps_01.png ├── ch_06_gps_02 ├── ch_06_gps_02.png ├── ch_06_gps_03 ├── ch_06_gps_03.png ├── ch_06_icecream ├── ch_06_icecream.png ├── ch_06_instrumental_00 ├── ch_06_instrumental_00.png ├── ch_07_full_example ├── ch_07_full_example.png ├── ch_08_modularity ├── ch_08_modularity.png ├── ch_08_modularity_mod ├── ch_08_modularity_mod.png ├── ch_08_selection ├── ch_08_selection.png ├── ch_08_selection_02 ├── ch_08_selection_02.png ├── ch_08_selection_03 └── ch_08_selection_03.png └── models └── causal_bert_pytorch ├── CausalBert.py ├── README.md ├── __pycache__ ├── CausalBert.cpython-38.pyc └── CausalBert.cpython-39.pyc └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | .ipynb_checkpoints -------------------------------------------------------------------------------- /Chapter_01.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "d885824b", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "from itertools import combinations\n", 11 | "\n", 12 | "import numpy as np\n", 13 | "from scipy import stats\n", 14 | "\n", 15 | "import networkx as nx\n", 16 | "\n", 17 | "import matplotlib.pyplot as plt\n", 18 | "plt.style.use('fivethirtyeight')" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 
| "execution_count": 2, 24 | "id": "7a2494ed", 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "COLORS = [\n", 29 | " '#00B0F0',\n", 30 | " '#FF0000'\n", 31 | "]" 32 | ] 33 | }, 34 | { 35 | "cell_type": "markdown", 36 | "id": "a0fd06fd", 37 | "metadata": {}, 38 | "source": [ 39 | "# Chapter 01\n", 40 | "\n", 41 | "This chapter introduces the concept of causality and highlights the similarities and differences between causal inference and statistical learning. A brief historical outline of the concept of causality gives the reader the broader context. Finally, three motivating examples (medicine, marketing, and social policy) demonstrate the importance of causal inference from technical, practical, and business perspectives." 42 | ] 43 | }, 44 | { 45 | "cell_type": "markdown", 46 | "id": "c6c49b0c", 47 | "metadata": {}, 48 | "source": [ 49 | "## Confounding" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": 33, 55 | "id": "9db62153", 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "# Set the random seed for reproducibility\n", 60 | "np.random.seed(45)\n", 61 | "\n", 62 | "# `b` represents our confounder\n", 63 | "b = np.random.rand(100)\n", 64 | "\n", 65 | "# `a` and `c` do not cause each other, but both are children of `b`\n", 66 | "a = b + .1 * np.random.rand(100)\n", 67 | "c = b + .3 * np.random.rand(100)" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": 34, 73 | "id": "25652f3a", 74 | "metadata": {}, 75 | "outputs": [ 76 | { 77 | "name": "stdout", 78 | "output_type": "stream", 79 | "text": [ 80 | "0.9627497625297509\n" 81 | ] 82 | } 83 | ], 84 | "source": [ 85 | "# Check the correlation between `a` and `c`\n", 86 | "coef, p_val = stats.pearsonr(a, c)\n", 87 | "\n", 88 | "print(coef)" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": 35, 94 | "id": "ac319993", 95 | "metadata": {}, 96 | "outputs": [ 97 | { 
98 | "data": { 99 | "image/png": "\n", 100 | "text/plain": [ 101 | "
" 102 | ] 103 | }, 104 | "metadata": {}, 105 | "output_type": "display_data" 106 | } 107 | ], 108 | "source": [ 109 | "variables = {\n", 110 | " 'a': a,\n", 111 | " 'b': b,\n", 112 | " 'c': c\n", 113 | "}\n", 114 | "\n", 115 | "plt.figure(figsize=(12, 7))\n", 116 | "\n", 117 | "for i, (var_1, var_2) in enumerate([('b', 'a'), ('b', 'c'), ('c', 'a')]):\n", 118 | " \n", 119 | " color = COLORS[1]\n", 120 | " \n", 121 | " if 'b' in [var_1, var_2]:\n", 122 | " color = COLORS[0]\n", 123 | " \n", 124 | " plt.subplot(2, 2, i + 1)\n", 125 | " plt.scatter(variables[var_1], variables[var_2], alpha=.8, color=color)\n", 126 | " \n", 127 | " plt.xlabel(f'${var_1}$', fontsize=16)\n", 128 | " plt.ylabel(f'${var_2}$', fontsize=16)\n", 129 | "\n", 130 | "plt.suptitle('Pairwise relationships between $a$, $b$ and $c$')\n", 131 | "plt.subplots_adjust(hspace=.25, wspace=.25)\n", 132 | "plt.show()" 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": null, 138 | "id": "a1f0699b", 139 | "metadata": {}, 140 | "outputs": [], 141 | "source": [] 142 | } 143 | ], 144 | "metadata": { 145 | "kernelspec": { 146 | "display_name": "Python [conda env:causal_book_py38]", 147 | "language": "python", 148 | "name": "conda-env-causal_book_py38-py" 149 | }, 150 | "language_info": { 151 | "codemirror_mode": { 152 | "name": "ipython", 153 | "version": 3 154 | }, 155 | "file_extension": ".py", 156 | "mimetype": "text/x-python", 157 | "name": "python", 158 | "nbconvert_exporter": "python", 159 | "pygments_lexer": "ipython3", 160 | "version": "3.8.13" 161 | } 162 | }, 163 | "nbformat": 4, 164 | "nbformat_minor": 5 165 | } 166 | -------------------------------------------------------------------------------- /Extras_01__DGPs.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 387, 6 | "id": "de02baff", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import numpy as 
np\n", 11 | "import pandas as pd\n", 12 | "from scipy import stats\n", 13 | "\n", 14 | "from sklearn.model_selection import train_test_split\n", 15 | "\n", 16 | "import matplotlib.pyplot as plt\n", 17 | "plt.style.use('fivethirtyeight')" 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": 2, 23 | "id": "702ababf", 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "COLORS = [\n", 28 | " '#00B0F0',\n", 29 | " '#FF0000'\n", 30 | "]" 31 | ] 32 | }, 33 | { 34 | "cell_type": "markdown", 35 | "id": "4d417229", 36 | "metadata": {}, 37 | "source": [ 38 | "# Additional code for data generating processes" 39 | ] 40 | }, 41 | { 42 | "cell_type": "markdown", 43 | "id": "ac23d7d9", 44 | "metadata": {}, 45 | "source": [ 46 | "## Chapter 09" 47 | ] 48 | }, 49 | { 50 | "cell_type": "markdown", 51 | "id": "7058fff7", 52 | "metadata": {}, 53 | "source": [ 54 | "### Post-training earnings data (simple)" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 39, 60 | "id": "865c3739", 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "SAMPLE_SIZE = 200\n", 65 | "MAX_AGE = 50\n", 66 | "\n", 67 | "age = stats.halfnorm.rvs(loc=19, scale=10, size=SAMPLE_SIZE).astype(int)\n", 68 | "age = np.where(age > MAX_AGE, np.random.choice(np.arange(20, MAX_AGE)), age)\n", 69 | "\n", 70 | "took_a_course = stats.bernoulli(p=10/age).rvs().astype(bool)\n", 71 | "\n", 72 | "earnings = 75000 + took_a_course * 10000 + age * 1000 + age**2 * 50 + np.random.randn(SAMPLE_SIZE) * 2000\n", 73 | "earnings = earnings.round()\n", 74 | "\n", 75 | "earnings = pd.DataFrame(dict(\n", 76 | " age=age,\n", 77 | " took_a_course=took_a_course,\n", 78 | " earnings=earnings\n", 79 | "))" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": 40, 85 | "id": "ca46d6ce", 86 | "metadata": {}, 87 | "outputs": [ 88 | { 89 | "data": { 90 | "text/plain": [ 91 | "" 92 | ] 93 | }, 94 | "execution_count": 40, 95 | "metadata": {}, 96 | "output_type": 
"execute_result" 97 | }, 98 | { 99 | "data": { 100 | "image/png": "\n", 101 | "text/plain": [ 102 | "
" 103 | ] 104 | }, 105 | "metadata": {}, 106 | "output_type": "display_data" 107 | } 108 | ], 109 | "source": [ 110 | "plt.scatter(earnings.age, earnings.earnings)" 111 | ] 112 | }, 113 | { 114 | "cell_type": "code", 115 | "execution_count": 41, 116 | "id": "100a985c", 117 | "metadata": {}, 118 | "outputs": [], 119 | "source": [ 120 | "earnings.to_csv('data/ml_earnings.csv', index=False)" 121 | ] 122 | }, 123 | { 124 | "cell_type": "markdown", 125 | "id": "567bbbbb", 126 | "metadata": {}, 127 | "source": [ 128 | "### Post-training earnings data (enhanced)" 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": 139, 134 | "id": "2d26efeb", 135 | "metadata": {}, 136 | "outputs": [], 137 | "source": [ 138 | "# Train set large\n", 139 | "SAMPLE_SIZE = 5000\n", 140 | "MAX_AGE = 50\n", 141 | "\n", 142 | "age = stats.halfnorm.rvs(loc=19, scale=10, size=SAMPLE_SIZE).astype(int)\n", 143 | "age = np.where(age > MAX_AGE, np.random.choice(np.arange(20, MAX_AGE)), age)\n", 144 | "\n", 145 | "took_a_course = stats.bernoulli(p=10/age).rvs().astype(bool)\n", 146 | "python_proficiency = np.random.uniform(0, 1, SAMPLE_SIZE)\n", 147 | "\n", 148 | "noise = np.random.randn(SAMPLE_SIZE)\n", 149 | "\n", 150 | "earnings = 75000 + took_a_course * 10000 + took_a_course * python_proficiency * 5000 + age * 1000 + age**2 * 50 + noise * 2000\n", 151 | "earnings = earnings.round()\n", 152 | "\n", 153 | "earnings = pd.DataFrame(dict(\n", 154 | " age=age,\n", 155 | " python_proficiency = python_proficiency,\n", 156 | " took_a_course=took_a_course,\n", 157 | " earnings=earnings\n", 158 | "))" 159 | ] 160 | }, 161 | { 162 | "cell_type": "code", 163 | "execution_count": 140, 164 | "id": "99788d91", 165 | "metadata": {}, 166 | "outputs": [], 167 | "source": [ 168 | "earnings.to_csv('data/ml_earnings_interaction_train.csv', index=False)" 169 | ] 170 | }, 171 | { 172 | "cell_type": "code", 173 | "execution_count": 143, 174 | "id": "555e6273", 175 | "metadata": {}, 176 | "outputs": 
[], 177 | "source": [ 178 | "# Test set\n", 179 | "SAMPLE_SIZE = 100\n", 180 | "MAX_AGE = 50\n", 181 | "\n", 182 | "age = stats.halfnorm.rvs(loc=19, scale=10, size=SAMPLE_SIZE).astype(int)\n", 183 | "age = np.where(age > MAX_AGE, np.random.choice(np.arange(20, MAX_AGE)), age)\n", 184 | "\n", 185 | "python_proficiency = np.random.uniform(0, 1, SAMPLE_SIZE)\n", 186 | "\n", 187 | "noise = np.random.randn(SAMPLE_SIZE)\n", 188 | "\n", 189 | "earnings_0 = (75000 + 0 * 10000 + 0 * python_proficiency * 5000 + age * 5000 + age**2 * 50 + noise * 2000).round()\n", 190 | "earnings_1 = (75000 + 1 * 10000 + 1 * python_proficiency * 5000 + age * 5000 + age**2 * 50 + noise * 2000).round()\n", 191 | "true_effect = earnings_1 - earnings_0\n", 192 | "\n", 193 | "earnings_test = pd.DataFrame(dict(\n", 194 | " age=age,\n", 195 | " python_proficiency=python_proficiency,\n", 196 | " took_a_course=True,\n", 197 | " true_effect=true_effect\n", 198 | "))" 199 | ] 200 | }, 201 | { 202 | "cell_type": "code", 203 | "execution_count": 144, 204 | "id": "2763703c", 205 | "metadata": {}, 206 | "outputs": [], 207 | "source": [ 208 | "earnings_test.to_csv('data/ml_earnings_interaction_test.csv', index=False)" 209 | ] 210 | }, 211 | { 212 | "cell_type": "markdown", 213 | "id": "74afa358", 214 | "metadata": {}, 215 | "source": [ 216 | "## Chapter 11" 217 | ] 218 | }, 219 | { 220 | "cell_type": "markdown", 221 | "id": "283685f2", 222 | "metadata": {}, 223 | "source": [ 224 | "### Simulated data" 225 | ] 226 | }, 227 | { 228 | "cell_type": "code", 229 | "execution_count": 467, 230 | "id": "9db54f09", 231 | "metadata": {}, 232 | "outputs": [], 233 | "source": [ 234 | "SAMPLE_SIZE = 2000\n", 235 | "\n", 236 | "def get_data_dgp_11(sample_size=1000, with_interaction=True):\n", 237 | " chi = np.random.chisquare(12, (sample_size, 5)) / 30\n", 238 | " norm = np.random.normal(0, 1, (sample_size, 5))\n", 239 | " binom = np.random.binomial(1, [.3, .5, .7, .12, .9], (sample_size, 5))\n", 240 | " gumbel = 
np.random.gumbel(-2, 1, (sample_size, 5)) / 6\n", 241 | "\n", 242 | " X = np.concatenate([chi, norm, binom, gumbel], axis=1)\n", 243 | "\n", 244 | " T0 = np.zeros(sample_size)\n", 245 | " T1 = np.ones(sample_size)\n", 246 | " \n", 247 | " noise = np.random.randn(sample_size)\n", 248 | " \n", 249 | " coefs = np.random.gumbel(0, 10, X.shape[1] + 1)\n", 250 | " \n", 251 | " if with_interaction:\n", 252 | " interaction_term = X[:, 0] + X[:, 7]\n", 253 | " y0 = (coefs * np.concatenate([(T0 * interaction_term).reshape(-1, 1), X], axis=1)).sum(axis=1) + noise\n", 254 | " y1 = (coefs * np.concatenate([(T1 * interaction_term).reshape(-1, 1), X], axis=1)).sum(axis=1) + noise\n", 255 | " else:\n", 256 | " interaction_term = 1\n", 257 | " y0 = (coefs * np.concatenate([\n", 258 | " T0.reshape(-1, 1), \n", 259 | " X], \n", 260 | " axis=1))\\\n", 261 | " .sum(axis=1) + noise\n", 262 | " \n", 263 | " y1 = (coefs * np.concatenate([\n", 264 | " 10 * np.exp(\n", 265 | " T1 + X[:, 7:12]\\\n", 266 | " .sum(axis=1) / 1000).reshape(-1, 1), \n", 267 | " X], \n", 268 | " axis=1)).sum(axis=1) + noise\n", 269 | " \n", 270 | " return X, y0, y1, coefs\n", 271 | "\n", 272 | "\n", 273 | "def filter_outcomes(y, t):\n", 274 | " result = np.zeros_like(t, dtype=float) \n", 275 | " result[t == 0] = y[t == 0, 0] \n", 276 | " result[t == 1] = y[t == 1, 1] \n", 277 | " return result" 278 | ] 279 | }, 280 | { 281 | "cell_type": "markdown", 282 | "id": "51807226", 283 | "metadata": {}, 284 | "source": [ 285 | "#### No interaction" 286 | ] 287 | }, 288 | { 289 | "cell_type": "code", 290 | "execution_count": 468, 291 | "id": "08f5832e", 292 | "metadata": {}, 293 | "outputs": [], 294 | "source": [ 295 | "# No-interaction data\n", 296 | "X, y0, y1, coefs = get_data_dgp_11(SAMPLE_SIZE, False)\n", 297 | "\n", 298 | "# Generate treatments\n", 299 | "T = np.random.binomial(1, 0.5, SAMPLE_SIZE)\n", 300 | "\n", 301 | "# Train-test split\n", 302 | "X_train, X_test, y_train, y_test, T_train, T_test = 
train_test_split(X, np.stack([y0, y1]).T, T, test_size=.1)\n", 303 | "\n", 304 | "# Filter actual outcomes\n", 305 | "y_train_actual = filter_outcomes(y_train, T_train)\n", 306 | "y_test_actual = filter_outcomes(y_test, T_test)\n", 307 | "\n", 308 | "# To DF & update cols\n", 309 | "data_no_interaction_train = pd.DataFrame(\n", 310 | " np.concatenate(\n", 311 | " [\n", 312 | " X_train, \n", 313 | " T_train.reshape(-1, 1), \n", 314 | " y_train_actual.reshape(-1, 1)],\n", 315 | " axis=1\n", 316 | " )\n", 317 | ")\n", 318 | "\n", 319 | "data_no_interaction_train.columns = [f'x{i}' for i in range(20)] + ['treatment', 'outcome']\n", 320 | "\n", 321 | "\n", 322 | "data_no_interaction_test = pd.DataFrame(\n", 323 | " np.concatenate(\n", 324 | " [\n", 325 | " X_test, \n", 326 | " T_test.reshape(-1, 1), \n", 327 | " (y_test[:, 1] - y_test[:, 0]).reshape(-1, 1)],\n", 328 | " axis=1\n", 329 | " )\n", 330 | ")\n", 331 | "\n", 332 | "\n", 333 | "data_no_interaction_test.columns = [f'x{i}' for i in range(20)] + ['treatment', 'true_effect']\n", 334 | "\n", 335 | "# Store\n", 336 | "data_no_interaction_train.to_csv('data/data_11_no_interaction_train.csv', index=False)\n", 337 | "data_no_interaction_test.to_csv('data/data_11_no_interaction_test.csv', index=False)" 338 | ] 339 | }, 340 | { 341 | "cell_type": "code", 342 | "execution_count": 469, 343 | "id": "e87ab637", 344 | "metadata": {}, 345 | "outputs": [ 346 | { 347 | "data": { 348 | "text/plain": [ 349 | "(array([ 34.89751834, 8.61226375, -12.63565299, ..., 2.99937547,\n", 350 | " 14.48281498, 32.13379007]),\n", 351 | " array([ -9.92622448, -36.0636879 , -57.4257863 , ..., -41.66052399,\n", 352 | " -30.30810687, -12.62264189]))" 353 | ] 354 | }, 355 | "execution_count": 469, 356 | "metadata": {}, 357 | "output_type": "execute_result" 358 | } 359 | ], 360 | "source": [ 361 | "y0, y1" 362 | ] 363 | }, 364 | { 365 | "cell_type": "code", 366 | "execution_count": 470, 367 | "id": "9ad1f4f9", 368 | "metadata": {}, 369 | "outputs": [ 
370 | { 371 | "data": { 372 | "text/plain": [ 373 | "array([-44.82374282, -44.67595165, -44.79013331, ..., -44.65989946,\n", 374 | " -44.79092185, -44.75643196])" 375 | ] 376 | }, 377 | "execution_count": 470, 378 | "metadata": {}, 379 | "output_type": "execute_result" 380 | } 381 | ], 382 | "source": [ 383 | "y1 - y0" 384 | ] 385 | }, 386 | { 387 | "cell_type": "markdown", 388 | "id": "31ce7420", 389 | "metadata": {}, 390 | "source": [ 391 | "#### With interaction" 392 | ] 393 | }, 394 | { 395 | "cell_type": "code", 396 | "execution_count": 475, 397 | "id": "71909b39", 398 | "metadata": {}, 399 | "outputs": [], 400 | "source": [ 401 | "# With-interaction data\n", 402 | "X, y0, y1, coefs = get_data_dgp_11(SAMPLE_SIZE, True)\n", 403 | "\n", 404 | "# Generate treatments\n", 405 | "T = np.random.binomial(1, 0.5, SAMPLE_SIZE)\n", 406 | "\n", 407 | "# Train-test split\n", 408 | "X_train, X_test, y_train, y_test, T_train, T_test = train_test_split(X, np.stack([y0, y1]).T, T, test_size=.1)\n", 409 | "\n", 410 | "# Filter actual outcomes\n", 411 | "y_train_actual = filter_outcomes(y_train, T_train)\n", 412 | "y_test_actual = filter_outcomes(y_test, T_test)\n", 413 | "\n", 414 | "# To DF & update cols\n", 415 | "data_with_interaction_train = pd.DataFrame(\n", 416 | " np.concatenate(\n", 417 | " [\n", 418 | " X_train, \n", 419 | " T_train.reshape(-1, 1), \n", 420 | " y_train_actual.reshape(-1, 1)],\n", 421 | " axis=1\n", 422 | " )\n", 423 | ")\n", 424 | "\n", 425 | "data_with_interaction_train.columns = [f'x{i}' for i in range(20)] + ['treatment', 'outcome']\n", 426 | "\n", 427 | "\n", 428 | "data_with_interaction_test = pd.DataFrame(\n", 429 | " np.concatenate(\n", 430 | " [\n", 431 | " X_test, \n", 432 | " T_test.reshape(-1, 1), \n", 433 | " (y_test[:, 1] - y_test[:, 0]).reshape(-1, 1)],\n", 434 | " axis=1\n", 435 | " )\n", 436 | ")\n", 437 | "\n", 438 | "\n", 439 | "data_with_interaction_test.columns = [f'x{i}' for i in range(20)] + ['treatment', 'true_effect']\n", 440 | 
"\n", 441 | "# Store\n", 442 | "data_with_interaction_train.to_csv('data/data_11_with_interaction_train.csv', index=False)\n", 443 | "data_with_interaction_test.to_csv('data/data_11_with_interaction_test.csv', index=False)" 444 | ] 445 | }, 446 | { 447 | "cell_type": "code", 448 | "execution_count": 474, 449 | "id": "72d62da4", 450 | "metadata": {}, 451 | "outputs": [ 452 | { 453 | "data": { 454 | "text/plain": [ 455 | "array([[ 1.63371979e+01, 1.59045881e+01],\n", 456 | " [-3.06609145e+00, -1.84422244e+00],\n", 457 | " [-1.03359389e+01, -7.87196605e+00],\n", 458 | " [ 4.46875521e+00, 4.96359667e+00],\n", 459 | " [ 3.18576088e+01, 3.37932190e+01],\n", 460 | " [ 1.19932029e+01, 2.98884169e+00],\n", 461 | " [ 9.69525909e+00, 7.93796477e+00],\n", 462 | " [-7.38440149e+00, -8.60629373e+00],\n", 463 | " [ 2.41254919e+01, 1.35643255e+01],\n", 464 | " [ 1.97529380e+01, 1.63711534e+01],\n", 465 | " [ 2.43068304e+01, 2.23882172e+01],\n", 466 | " [-7.98322104e+00, -1.16848962e+01],\n", 467 | " [-2.54805435e+00, -9.10154055e+00],\n", 468 | " [-2.12343129e+01, -2.11751544e+01],\n", 469 | " [ 1.99864813e+00, -3.96052447e+00],\n", 470 | " [ 1.54031350e+01, 1.24315489e+01],\n", 471 | " [ 7.02887405e+00, 5.32079794e+00],\n", 472 | " [ 2.07225787e+00, 3.98963467e+00],\n", 473 | " [ 1.05768464e+01, 1.16281929e+01],\n", 474 | " [ 2.58577213e+01, 2.11922848e+01],\n", 475 | " [ 1.11960625e+01, 6.92346165e+00],\n", 476 | " [ 8.63504905e+00, 7.25811536e+00],\n", 477 | " [ 1.45579537e+01, 1.14446180e+01],\n", 478 | " [-1.05771942e+01, -1.20518737e+01],\n", 479 | " [ 1.50904679e+01, 1.10366070e+01],\n", 480 | " [ 1.18356618e+01, 4.89932210e+00],\n", 481 | " [ 1.18242024e+01, 9.95327571e+00],\n", 482 | " [-4.42150929e+00, -8.56758507e-01],\n", 483 | " [ 1.90045567e+01, 2.25398094e+01],\n", 484 | " [-9.05976959e+00, -1.25219855e+01],\n", 485 | " [ 1.83118169e+01, 1.50464956e+01],\n", 486 | " [-9.71122700e+00, -9.93741894e+00],\n", 487 | " [-4.60208952e+00, 2.71032818e+00],\n", 488 | " 
[-2.15014674e+01, -2.04324464e+01],\n", 489 | " [ 1.31560593e+01, 1.00683814e+01],\n", 490 | " [-1.79174245e+01, -1.94410151e+01],\n", 491 | " [ 1.93439684e+01, 2.07561603e+01],\n", 492 | " [-1.58168095e+01, -1.03663941e+01],\n", 493 | " [-7.75661603e+00, -5.33883483e+00],\n", 494 | " [ 3.22820694e+01, 2.83894343e+01],\n", 495 | " [-2.01817097e+01, -1.76334410e+01],\n", 496 | " [ 3.17113259e+01, 2.81813370e+01],\n", 497 | " [-1.09985805e+01, -7.91712686e+00],\n", 498 | " [ 2.13866154e+01, 1.66148733e+01],\n", 499 | " [-1.48789697e+01, -8.17279022e+00],\n", 500 | " [ 9.17563124e+00, 6.02526627e+00],\n", 501 | " [ 2.58140460e+00, -1.12501476e-01],\n", 502 | " [-1.21727684e+01, -1.07267230e+01],\n", 503 | " [-1.90409191e+01, -1.65927276e+01],\n", 504 | " [ 3.72328902e+01, 3.79919913e+01],\n", 505 | " [-9.47953164e-01, 1.21423457e+00],\n", 506 | " [-1.45674899e+00, -1.87900795e+00],\n", 507 | " [-1.53365385e+01, -1.20907620e+01],\n", 508 | " [ 1.32460485e+01, 8.05235665e+00],\n", 509 | " [ 3.63131086e+00, 4.44464150e+00],\n", 510 | " [-3.49429150e+00, -4.56076985e+00],\n", 511 | " [ 3.10834849e+00, -8.10915916e-01],\n", 512 | " [ 2.77119072e+01, 2.41466627e+01],\n", 513 | " [-1.70212069e+01, -1.68258705e+01],\n", 514 | " [ 5.04745524e+01, 4.66274134e+01],\n", 515 | " [ 1.48456545e+00, 2.31260440e+00],\n", 516 | " [-1.18330778e+01, -8.95333768e+00],\n", 517 | " [ 2.81931054e+01, 2.69260371e+01],\n", 518 | " [-4.21986388e+00, -1.41826508e+00],\n", 519 | " [ 1.46745912e+01, 1.20520962e+01],\n", 520 | " [-1.11813974e+01, -1.13292031e+01],\n", 521 | " [-6.02535429e+00, -5.27047080e+00],\n", 522 | " [ 3.07129641e+01, 2.91119306e+01],\n", 523 | " [ 5.46496851e+00, 4.28464672e-02],\n", 524 | " [ 2.59619746e+01, 2.00919997e+01],\n", 525 | " [ 3.49669991e+00, 1.86626270e-01],\n", 526 | " [ 1.47507462e+01, 1.03204859e+01],\n", 527 | " [ 3.67698077e+01, 3.48365664e+01],\n", 528 | " [ 5.50672610e+00, 4.63873391e+00],\n", 529 | " [ 2.72345918e+01, 1.94359379e+01],\n", 530 | " [ 
1.89310805e+01, 1.55254734e+01],\n", 531 | " [ 1.24083762e+01, 6.71134780e+00],\n", 532 | " [-6.70397101e+00, -2.09137370e-01],\n", 533 | " [-2.06016817e+01, -1.86689633e+01],\n", 534 | " [-1.27942259e+01, -9.91056303e+00],\n", 535 | " [ 1.77919844e+01, 1.51883790e+01],\n", 536 | " [ 9.71692034e+00, 4.06146703e+00],\n", 537 | " [-2.39694407e+00, -1.79774750e+00],\n", 538 | " [ 1.05421436e+01, 9.23487406e+00],\n", 539 | " [ 1.22045593e+00, -1.07875268e+00],\n", 540 | " [-1.07546228e+01, -4.74439361e+00],\n", 541 | " [-4.81063247e+00, -8.78817245e+00],\n", 542 | " [ 2.31450768e+01, 1.91748792e+01],\n", 543 | " [-4.18048254e-02, 4.66992569e+00],\n", 544 | " [-4.84710941e+01, -4.25253861e+01],\n", 545 | " [ 1.08864775e+01, 9.17476425e+00],\n", 546 | " [-3.27788678e+01, -2.98826665e+01],\n", 547 | " [-3.54064472e+00, -4.09085475e+00],\n", 548 | " [ 1.12191320e+01, 1.10812952e+01],\n", 549 | " [ 1.43260242e+01, 1.37793678e+01],\n", 550 | " [ 6.55637768e+00, 5.68611171e+00],\n", 551 | " [-1.11058627e+01, -6.48561277e+00],\n", 552 | " [ 7.85651786e+00, 8.00218141e+00],\n", 553 | " [ 4.46217704e+00, 1.14376013e+00],\n", 554 | " [ 3.29265255e+01, 3.49590882e+01],\n", 555 | " [ 4.85949912e+01, 4.11315565e+01],\n", 556 | " [ 1.19354415e+01, 9.37292040e+00],\n", 557 | " [ 3.73828496e+01, 2.93836980e+01],\n", 558 | " [ 3.30241330e+01, 3.16625772e+01],\n", 559 | " [ 1.03462408e+01, -2.98447923e-02],\n", 560 | " [ 1.71716305e+01, 1.27985762e+01],\n", 561 | " [ 2.39301432e+01, 1.67916835e+01],\n", 562 | " [-6.80227926e+00, -4.66061176e+00],\n", 563 | " [ 1.11362317e+01, 9.57382673e+00],\n", 564 | " [-1.40885447e+01, -1.69857739e+01],\n", 565 | " [ 7.84001035e+00, -1.13391887e+00],\n", 566 | " [ 1.29158828e+01, 1.19344714e+01],\n", 567 | " [-1.06241773e+01, -1.44117050e+01],\n", 568 | " [ 3.34716804e+01, 2.83302128e+01],\n", 569 | " [-1.55024300e+01, -4.42891211e+00],\n", 570 | " [-3.48955183e+01, -3.03897814e+01],\n", 571 | " [ 4.33012619e+00, -1.58450917e+00],\n", 572 | " 
[-6.73299617e+00, -3.80308113e+00],\n", 573 | " [-2.35171223e+00, -3.40258619e+00],\n", 574 | " [ 2.16014984e+01, 2.05111530e+01],\n", 575 | " [-1.14591606e+01, -7.76235530e+00],\n", 576 | " [-7.04712496e+00, -6.03325693e+00],\n", 577 | " [-9.36561770e+00, -1.05784975e+01],\n", 578 | " [ 1.51976132e+01, 1.75157483e+01],\n", 579 | " [-5.27283770e+00, -7.32338430e+00],\n", 580 | " [ 8.84887775e+00, 2.55961507e+00],\n", 581 | " [ 2.90433598e+01, 3.13938059e+01],\n", 582 | " [-4.87162731e+00, -1.53659872e+00],\n", 583 | " [ 7.03312019e+00, 1.35123136e+01],\n", 584 | " [ 6.21913815e+00, 7.58283931e+00],\n", 585 | " [ 1.76485937e+01, 1.86315969e+01],\n", 586 | " [ 2.07231446e+01, 1.73708759e+01],\n", 587 | " [-1.72123582e+01, -1.38098575e+01],\n", 588 | " [ 1.30620030e+01, 9.51586713e+00],\n", 589 | " [ 2.05136987e+01, 1.57194550e+01],\n", 590 | " [-1.11707578e+01, -1.03694898e+01],\n", 591 | " [ 1.24416031e+01, 1.70158000e+01],\n", 592 | " [-1.14513667e+01, -1.00988180e+01],\n", 593 | " [ 3.96631849e-01, 1.83286619e-01],\n", 594 | " [-1.13900816e+01, -1.51945068e+01],\n", 595 | " [ 2.78301418e+01, 2.35278376e+01],\n", 596 | " [-2.68684473e-01, -2.30994163e+00],\n", 597 | " [-5.79886629e+00, -6.99356703e+00],\n", 598 | " [ 2.65129618e+01, 2.79829307e+01],\n", 599 | " [-3.11016922e+01, -2.83476668e+01],\n", 600 | " [-3.29337223e+01, -3.42038410e+01],\n", 601 | " [-7.92670029e+00, -8.73191611e+00],\n", 602 | " [-2.13186348e+01, -1.76976884e+01],\n", 603 | " [ 7.22327839e+00, 9.56145219e+00],\n", 604 | " [ 2.51860980e+01, 2.35128088e+01],\n", 605 | " [-9.75730761e+00, -8.19911689e+00],\n", 606 | " [ 1.87096854e+01, 2.08531999e+01],\n", 607 | " [ 2.07770950e+01, 1.73692478e+01],\n", 608 | " [ 9.11028756e+00, 1.49462636e+01],\n", 609 | " [ 1.68582071e+01, 1.34435252e+01],\n", 610 | " [ 2.84208726e+01, 2.80328979e+01],\n", 611 | " [-1.52401652e+01, -1.42386803e+01],\n", 612 | " [-9.38910802e+00, -9.95319600e+00],\n", 613 | " [ 5.15012713e+00, 6.99167391e-01],\n", 614 | " 
[-3.92883724e+01, -3.68248844e+01],\n", 615 | " [ 1.87737433e+01, 1.72472071e+01],\n", 616 | " [ 3.44181087e+01, 3.48890706e+01],\n", 617 | " [-3.26906582e+01, -3.06290054e+01],\n", 618 | " [-2.40423520e+01, -1.45808145e+01],\n", 619 | " [ 2.13756242e+00, 5.27183576e+00],\n", 620 | " [ 4.90009592e+00, 5.61565090e+00],\n", 621 | " [ 4.54859419e+01, 3.76886129e+01],\n", 622 | " [ 2.31324656e+01, 2.85480940e+01],\n", 623 | " [ 3.24232398e+01, 2.99045715e+01],\n", 624 | " [ 4.88183641e+00, 3.50497753e+00],\n", 625 | " [ 9.65033120e+00, 2.89810642e+00],\n", 626 | " [ 3.12306314e+01, 2.94777631e+01],\n", 627 | " [ 1.59728878e+01, 1.06858864e+01],\n", 628 | " [ 1.20997705e+00, 5.95213548e-01],\n", 629 | " [ 1.33539038e+01, 1.75907486e+01],\n", 630 | " [ 2.62493437e+00, 6.53876886e-01],\n", 631 | " [-7.84884410e-01, 3.13612900e+00],\n", 632 | " [-2.36139030e+00, -1.67070765e+00],\n", 633 | " [ 3.12325596e+00, 3.24636605e-02],\n", 634 | " [ 2.39675710e+01, 2.50891188e+01],\n", 635 | " [-3.65625038e+00, -4.68218168e+00],\n", 636 | " [ 6.93704714e+00, 8.36365197e+00],\n", 637 | " [-1.56590115e+01, -1.38351988e+01],\n", 638 | " [-3.86054744e+00, -7.50929234e+00],\n", 639 | " [ 3.56766317e+00, 4.16903836e+00],\n", 640 | " [ 6.94816559e+00, 4.98694702e-01],\n", 641 | " [ 1.98748626e+01, 1.86915632e+01],\n", 642 | " [ 9.63694712e+00, 5.44322195e+00],\n", 643 | " [ 6.35608252e+00, 7.74252966e+00],\n", 644 | " [ 2.29898145e+01, 1.74060334e+01],\n", 645 | " [ 2.21190507e+01, 1.57930569e+01],\n", 646 | " [-7.53326886e+00, -1.15775530e+01],\n", 647 | " [-1.41301319e+01, -1.06411886e+01],\n", 648 | " [-1.40032833e+01, -1.28802892e+01],\n", 649 | " [ 2.14106292e+01, 1.91283181e+01],\n", 650 | " [ 5.16093055e-01, 3.61359753e+00],\n", 651 | " [-1.72316863e+00, 3.20182049e+00],\n", 652 | " [ 3.89969272e+00, 6.56362203e+00],\n", 653 | " [ 2.11949427e+01, 1.98445588e+01],\n", 654 | " [ 4.79663284e+00, 7.20981922e+00]])" 655 | ] 656 | }, 657 | "execution_count": 474, 658 | "metadata": {}, 659 
| "output_type": "execute_result" 660 | } 661 | ], 662 | "source": [ 663 | "y_test" 664 | ] 665 | }, 666 | { 667 | "cell_type": "markdown", 668 | "id": "628499f2", 669 | "metadata": {}, 670 | "source": [ 671 | "### Chapter 11 - NLP data" 672 | ] 673 | }, 674 | { 675 | "cell_type": "code", 676 | "execution_count": null, 677 | "id": "f2ac5b5a", 678 | "metadata": {}, 679 | "outputs": [], 680 | "source": [ 681 | "df = pd.read_csv('data/manga.csv') " 682 | ] 683 | }, 684 | { 685 | "cell_type": "code", 686 | "execution_count": null, 687 | "id": "e3f61a77", 688 | "metadata": {}, 689 | "outputs": [], 690 | "source": [ 691 | "# Check gender indicator\n", 692 | "df['gender_indicator'].unique()" 693 | ] 694 | }, 695 | { 696 | "cell_type": "code", 697 | "execution_count": null, 698 | "id": "2df35ef5", 699 | "metadata": {}, 700 | "outputs": [], 701 | "source": [ 702 | "# Produce female avatar indicator\n", 703 | "avatar = []\n", 704 | "\n", 705 | "for i in df['gender_indicator']:\n", 706 | " if i:\n", 707 | " avatar.append(np.random.choice([0, 1], p=[.15, .85]))\n", 708 | " else:\n", 709 | " avatar.append(np.random.choice([0, 1], p=[.97, .03]))\n", 710 | " \n", 711 | "df['female_avatar'] = avatar\n", 712 | "\n", 713 | "# Sanity\n", 714 | "df.groupby('gender_indicator')['female_avatar'].mean()" 715 | ] 716 | }, 717 | { 718 | "cell_type": "code", 719 | "execution_count": null, 720 | "id": "d33d6722", 721 | "metadata": {}, 722 | "outputs": [], 723 | "source": [ 724 | "# Produce stereotype indicator\n", 725 | "df['love_indicator'] = df['text'].str.contains('love|roman').astype('int')" 726 | ] 727 | }, 728 | { 729 | "cell_type": "code", 730 | "execution_count": null, 731 | "id": "1ba8a129", 732 | "metadata": {}, 733 | "outputs": [], 734 | "source": [ 735 | "# Add photo to the post at random\n", 736 | "df['has_photo'] = np.random.choice([0, 1], size=df.shape[0])" 737 | ] 738 | }, 739 | { 740 | "cell_type": "code", 741 | "execution_count": null, 742 | "id": "e2dcbd14", 743 | 
"metadata": {}, 744 | "outputs": [], 745 | "source": [ 746 | "# Produce the outcome\n", 747 | "noise = np.random.normal(0.01, .05, size=df.shape[0])\n", 748 | "probas = .8 + 0.1 * df['love_indicator'] * 1.2 * (0.1 + df['has_photo']) - 0.7 * df['female_avatar'] + noise\n", 749 | "probas = np.clip(probas, 0, 1)\n", 750 | "\n", 751 | "upvote = []\n", 752 | "\n", 753 | "for p in probas:\n", 754 | " upvote_ = np.random.choice([0, 1], p=[1 - p, p])\n", 755 | " upvote.append(upvote_)\n", 756 | " \n", 757 | "df['upvote'] = likes" 758 | ] 759 | }, 760 | { 761 | "cell_type": "code", 762 | "execution_count": null, 763 | "id": "02c2eee9", 764 | "metadata": {}, 765 | "outputs": [], 766 | "source": [ 767 | "df.drop(['gender_indicator', 'topic', 'love_indicator', 'author'], axis=1).to_csv('data/manga_processed.csv', index=False)" 768 | ] 769 | } 770 | ], 771 | "metadata": { 772 | "kernelspec": { 773 | "display_name": "Python [conda env:causal_book_py38]", 774 | "language": "python", 775 | "name": "conda-env-causal_book_py38-py" 776 | }, 777 | "language_info": { 778 | "codemirror_mode": { 779 | "name": "ipython", 780 | "version": 3 781 | }, 782 | "file_extension": ".py", 783 | "mimetype": "text/x-python", 784 | "name": "python", 785 | "nbconvert_exporter": "python", 786 | "pygments_lexer": "ipython3", 787 | "version": "3.8.13" 788 | } 789 | }, 790 | "nbformat": 4, 791 | "nbformat_minor": 5 792 | } 793 | -------------------------------------------------------------------------------- /Extras_02__Additional_computations.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "1eaa963c", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import numpy as np\n", 11 | "import pandas as pd\n", 12 | "from scipy import stats\n", 13 | "\n", 14 | "import dowhy\n", 15 | "from dowhy import CausalModel\n", 16 | "\n", 17 | "import matplotlib.pyplot as plt\n", 18 | 
"plt.style.use('fivethirtyeight')" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 2, 24 | "id": "fc74dab6", 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "COLORS = [\n", 29 | " '#00B0F0',\n", 30 | " '#FF0000'\n", 31 | "]" 32 | ] 33 | }, 34 | { 35 | "cell_type": "markdown", 36 | "id": "df565036", 37 | "metadata": {}, 38 | "source": [ 39 | "# Additional computations" 40 | ] 41 | }, 42 | { 43 | "cell_type": "markdown", 44 | "id": "e42ea37a", 45 | "metadata": {}, 46 | "source": [ 47 | "## Chapter 01 / Chapter 09" 48 | ] 49 | }, 50 | { 51 | "cell_type": "markdown", 52 | "id": "d28e80e8", 53 | "metadata": {}, 54 | "source": [ 55 | "### Solving Simpson's paradox with IPW" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 3, 61 | "id": "865c3739", 62 | "metadata": {}, 63 | "outputs": [ 64 | { 65 | "name": "stdout", 66 | "output_type": "stream", 67 | "text": [ 68 | "propensity_score_weighting\n", 69 | "0.04811580602068774\n" 70 | ] 71 | }, 72 | { 73 | "name": "stderr", 74 | "output_type": "stream", 75 | "text": [ 76 | "C:\\Users\\aleks\\anaconda3\\envs\\causal_book_py38\\lib\\site-packages\\sklearn\\utils\\validation.py:1111: DataConversionWarning: A column-vector y was passed when a 1d array was expected. 
Please change the shape of y to (n_samples, ), for example using ravel().\n", 77 | " y = column_or_1d(y, warn=True)\n" 78 | ] 79 | } 80 | ], 81 | "source": [ 82 | "pd.read_csv('data/ch_01_drug_data.csv')\n", 83 | "\n", 84 | "gender = [1] * (24 + 56 + 17 + 25) + [0] * (3 + 39 + 6 +74)\n", 85 | "clot = [1] * 24 + [0] * 56 + [1] * 17 + [0] * 25 + [1] * 3 + [0] * 39 + [1] * 6 + [0] * 74\n", 86 | "drug = [0] * (24 + 56) + [1] * (17 + 25) + [0] * 42 + [1] * 80\n", 87 | "\n", 88 | "drug_data = pd.DataFrame(dict(\n", 89 | " gender=gender,\n", 90 | " clot=clot,\n", 91 | " drug=drug\n", 92 | "))\n", 93 | "\n", 94 | "# Construct the graph (the graph is constant for all iterations)\n", 95 | "nodes_drug = ['drug', 'clot', 'gender']\n", 96 | "edges_drug = [\n", 97 | " ('drug', 'clot'),\n", 98 | " ('gender', 'drug'),\n", 99 | " ('gender', 'clot')\n", 100 | "]\n", 101 | "\n", 102 | "# Generate the GML graph\n", 103 | "gml_string_drug = 'graph [directed 1\\n'\n", 104 | "\n", 105 | "for node in nodes_drug:\n", 106 | " gml_string_drug += f'\\tnode [id \"{node}\" label \"{node}\"]\\n'\n", 107 | "\n", 108 | "for edge in edges_drug:\n", 109 | " gml_string_drug += f'\\tedge [source \"{edge[0]}\" target \"{edge[1]}\"]\\n'\n", 110 | " \n", 111 | "gml_string_drug += ']'\n", 112 | "\n", 113 | "# Instantiate the CausalModel\n", 114 | "model_drug = CausalModel(\n", 115 | " data=drug_data,\n", 116 | " treatment='drug',\n", 117 | " outcome='clot',\n", 118 | " graph=gml_string_drug\n", 119 | ")\n", 120 | "\n", 121 | "# Identify effect\n", 122 | "estimand_drug = model_drug.identify_effect()\n", 123 | "\n", 124 | "# Get estimate (IPW weighting)\n", 125 | "estimate_drug = model_drug.estimate_effect(\n", 126 | " identified_estimand=estimand_drug,\n", 127 | " method_name='backdoor.propensity_score_weighting',\n", 128 | " target_units='ate'\n", 129 | ")\n", 130 | "\n", 131 | "print(estimate_drug.value)" 132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": null, 137 | "id": 
"f85982a5", 138 | "metadata": {}, 139 | "outputs": [], 140 | "source": [] 141 | } 142 | ], 143 | "metadata": { 144 | "kernelspec": { 145 | "display_name": "Python [conda env:causal_book_py38]", 146 | "language": "python", 147 | "name": "conda-env-causal_book_py38-py" 148 | }, 149 | "language_info": { 150 | "codemirror_mode": { 151 | "name": "ipython", 152 | "version": 3 153 | }, 154 | "file_extension": ".py", 155 | "mimetype": "text/x-python", 156 | "name": "python", 157 | "nbconvert_exporter": "python", 158 | "pygments_lexer": "ipython3", 159 | "version": "3.8.13" 160 | } 161 | }, 162 | "nbformat": 4, 163 | "nbformat_minor": 5 164 | } 165 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Packt 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |


2 | 3 | ## Machine Learning Summit 2025 4 | **Bridging Theory and Practice: ML Solutions for Today’s Challenges** 5 | 6 | 3 days, 20+ experts, and 25+ tech sessions and talks covering critical aspects of: 7 | - **Agentic and Generative AI** 8 | - **Applied Machine Learning in the Real World** 9 | - **ML Engineering and Optimization** 10 | 11 | 👉 [Book your ticket now >>](https://packt.link/mlsumgh) 12 | 13 | --- 14 | 15 | ## Join Our Newsletters 📬 16 | 17 | ### DataPro 18 | *The future of AI is unfolding. Don’t fall behind.* 19 | 20 |


21 | 22 | Stay ahead with [**DataPro**](https://landing.packtpub.com/subscribe-datapronewsletter/?link_from_packtlink=yes), the free weekly newsletter for data scientists, AI/ML researchers, and data engineers. 23 | From trending tools like **PyTorch**, **scikit-learn**, **XGBoost**, and **BentoML** to hands-on insights on **database optimization** and real-world **ML workflows**, you’ll get what matters, fast. 24 | 25 | > Stay sharp with [DataPro](https://landing.packtpub.com/subscribe-datapronewsletter/?link_from_packtlink=yes). Join **115K+ data professionals** who never miss a beat. 26 | 27 | --- 28 | 29 | ### BIPro 30 | *Business runs on data. Make sure yours tells the right story.* 31 | 32 |


33 | 34 | [**BIPro**](https://landing.packtpub.com/subscribe-bipro-newsletter/?link_from_packtlink=yes) is your free weekly newsletter for BI professionals, analysts, and data leaders. 35 | Get practical tips on **dashboarding**, **data visualization**, and **analytics strategy** with tools like **Power BI**, **Tableau**, **Looker**, **SQL**, and **dbt**. 36 | 37 | > Get smarter with [BIPro](https://landing.packtpub.com/subscribe-bipro-newsletter/?link_from_packtlink=yes). Trusted by **35K+ BI professionals**, see what you’re missing. 38 | 39 | 40 | 41 | # Causal Inference and Discovery in Python 42 | 43 | Causal Inference and Discovery in Python 44 | 45 | This is the code repository for [Causal Inference and Discovery in Python](https://www.packtpub.com/product/causal-inference-and-discovery-in-python/9781804612989), published by Packt. 46 | 47 | **Unlock the secrets of modern causal machine learning with DoWhy, EconML, PyTorch and more** 48 | 49 | ## What is this book about? 50 | 51 | Causal methods present unique challenges compared to traditional machine learning and statistics. Learning causality can be challenging, but it offers distinct advantages that elude a purely statistical mindset. Causal Inference and Discovery in Python helps you unlock the potential of causality. 52 | 53 | You’ll start with the basic motivations behind causal thinking and a comprehensive introduction to Pearlian causal concepts, such as structural causal models, interventions, counterfactuals, and more. Each concept is accompanied by a theoretical explanation and a set of practical exercises with Python code. 54 | 55 | Next, you’ll dive into the world of causal effect estimation, consistently progressing towards modern machine learning methods. Step-by-step, you’ll discover the Python causal ecosystem and harness the power of cutting-edge algorithms. You’ll further explore the mechanics of how “causes leave traces” and compare the main families of causal discovery algorithms. 
56 | 57 | The final chapter gives you a broad outlook into the future of causal AI where we examine challenges and opportunities and provide you with a comprehensive list of resources to learn more. 58 | 59 | This book covers the following exciting features: 60 | * Master the fundamental concepts of causal inference 61 | * Decipher the mysteries of structural causal models 62 | * Unleash the power of the 4-step causal inference process in Python 63 | * Explore advanced uplift modeling techniques 64 | * Unlock the secrets of modern causal discovery using Python 65 | * Use causal inference for social impact and community benefit 66 | 67 | If you feel this book is for you, get your [copy](https://www.amazon.com/Causal-Inference-Discovery-Python-learning/dp/1804612987/ref=sr_1_1?keywords=Causal+Inference+and+Discovery+in+Python&s=books&sr=1-1) today! 68 | 69 | 70 | ## Instructions and Navigations 71 | All of the code is organized into folders. 72 | 73 | The code will look like the following: 74 | ``` 75 | preds = causal_bert.inference( 76 | texts=df['text'], 77 | confounds=df['has_photo'], 78 | )[0] 79 | ``` 80 | 81 | **Following is what you need for this book:** 82 | 83 | This book is for machine learning engineers, data scientists, and machine learning researchers looking to extend their data science toolkit and explore causal machine learning. It will also help developers familiar with causality who have worked in another technology and want to switch to Python, and data scientists with a history of working with traditional causality who want to learn causal machine learning. It’s also a must-read for tech-savvy entrepreneurs looking to build a competitive edge for their products and go beyond the limitations of traditional machine learning. 84 | 85 | With the following software and hardware list you can run all code files present in the book (Chapter 1-15). 
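Before running the notebooks in an existing environment, it can help to compare the installed package versions with the pins in the software list. The snippet below is a minimal sketch, not part of the repository: `check_pins` is a hypothetical helper, and the distribution names are assumed to match the pip names used in the environment files.

```python
from importlib import metadata

# Versions copied from the README's software list.
# Distribution names are an assumption (taken to match the pip names in the .yml files).
PINNED = {
    "dowhy": "0.8",
    "econml": "0.12.0",
    "catenets": "0.2.3",
    "gcastle": "1.0.3",
    "causica": "0.2.0",
    "causal-learn": "0.1.3.3",
    "transformers": "4.24.0",
}


def check_pins(pinned, get_version=metadata.version):
    """Return {name: (expected, installed_or_None)} for every package that deviates from its pin."""
    mismatches = {}
    for name, expected in pinned.items():
        try:
            installed = get_version(name)
        except metadata.PackageNotFoundError:
            installed = None  # package is not installed at all
        if installed != expected:
            mismatches[name] = (expected, installed)
    return mismatches


if __name__ == "__main__":
    for name, (expected, installed) in check_pins(PINNED).items():
        print(f"{name}: expected {expected}, found {installed or 'not installed'}")
```

An empty result means every listed package matches its pin. A mismatch is not necessarily a problem, since the conda environment files leave some packages unpinned.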
86 | 87 | ### Software and Hardware List 88 | 89 | | Chapter | Software required | OS required | 90 | | -------- | -------------------------------------------------------------------------------------| -----------------------------------| 91 | | 1-15 | Python 3.9 | Windows, macOS, or Linux | 92 | | 1-15 | DoWhy 0.8 | Windows, macOS, or Linux | 93 | | 1-15 | EconML 0.12.0 | Windows, macOS, or Linux | 94 | | 1-15 | CATENets 0.2.3 | Windows, macOS, or Linux | 95 | | 1-15 | gCastle 1.0.3 | Windows, macOS, or Linux | 96 | | 1-15 | Causica 0.2.0 | Windows, macOS, or Linux | 97 | | 1-15 | Causal-learn 0.1.3.3 | Windows, macOS, or Linux | 98 | | 1-15 | Transformers 4.24.0 | Windows, macOS, or Linux | 99 | 100 | 101 | ## Join our Discord server 102 | 103 | Join our Discord community to meet like-minded people and learn alongside more than 2000 members at [Discord](https://packt.link/infer) 104 | 105 | 106 | ### Related products 107 | * Hands-On Graph Neural Networks Using Python [[Packt]](https://www.packtpub.com/product/hands-on-graph-neural-networks-using-python/9781804617526) [[Amazon]](https://www.amazon.com/Hands-Graph-Neural-Networks-Python/dp/1804617520/ref=sr_1_1?keywords=Hands-On+Graph+Neural+Networks+Using+Python&s=books&sr=1-1) 108 | 109 | * Applying Math with Python - Second Edition [[Packt]](https://www.packtpub.com/product/applying-math-with-python-second-edition/9781804618370) [[Amazon]](https://www.amazon.com/Applying-Math-Python-real-world-computational/dp/1804618373/ref=sr_1_1?keywords=Applying+Math+with+Python+-+Second+Edition&s=books&sr=1-1) 110 | 111 | ## Get to Know the Author 112 | [**Aleksander Molak**](https://www.linkedin.com/in/aleksandermolak/) is a Machine Learning Researcher and Consultant who gained experience working with Fortune 100, Fortune 500, and Inc. 5000 companies across Europe, the USA, and Israel, designing and building large-scale machine learning systems. 
On a mission to democratize causality for businesses and machine learning practitioners, Aleksander is a prolific writer, creator, and international speaker. As a co-founder of Lespire, an innovative provider of AI and machine learning training for corporate teams, Aleksander is committed to empowering businesses to harness the full potential of cutting-edge technologies that allow them to stay ahead of the curve. 113 | He's the host of the Causal AI-centered [Causal Bandits Podcast](https://causalbanditspodcast.com/). 114 | 115 | 116 | 117 | 118 | # Note from the Author: 119 | 120 | ## Environment installation 121 | 1. See the section **Using `graphviz` and GPU** below 122 | 123 | 2. To install the basic environment run: `conda env create -f causal_book_py39_cuda117.yml` 124 | 125 | 3. To install the environment for notebook `Chapter_11.2.ipynb` run: `conda env create -f causal-pymc.yml` 126 | 127 | **NOTE:** We added an experimental environment for Apple M1 as suggested by @ferrari-leo [here](https://github.com/PacktPublishing/Causal-Inference-and-Discovery-in-Python/issues/8). This environment hasn't been thoroughly tested so please use it at your own risk. 128 | 129 | ## Selecting the kernel 130 | After a successful installation of the environment, open your notebook and select the kernel `causal_book_py39_cuda117` 131 | 132 | For notebook `Chapter_11.2.ipynb` change kernel to `causal-pymc` 133 | 134 | ## Using `graphviz` and GPU 135 | 136 | **Note**: Depending on your system settings, you might need to install `graphviz` manually in order to recreate the graph plots in the code. 137 | Check https://pypi.org/project/graphviz/ for instructions 138 | specific to your operating system. 139 | 140 | **Note 2**: To use GPU you'll need to install CUDA 11.7 drivers. 
141 | This can be done here: https://developer.nvidia.com/cuda-11-7-0-download-archive 142 | 143 | ## Citation 144 | 145 | ### BibTeX 146 | ```{bibtex} 147 | @book{Molak2023, 148 | title={Causal Inference and Discovery in Python: Unlock the secrets of modern causal machine learning with DoWhy, EconML, PyTorch and more}, 149 | author={Molak, Aleksander}, 150 | publisher={Packt Publishing}, 151 | address={Birmingham}, 152 | edition={1.}, 153 | year={2023}, 154 | isbn={1804612987}, 155 | note={\url{https://amzn.to/3RebWzn}} 156 | } 157 | ``` 158 | 159 | ### APA 160 | ``` 161 | Molak, A. (2023). Causal Inference and Discovery in Python: Unlock the secrets of modern causal machine learning with DoWhy, EconML, PyTorch and more. Packt Publishing. 162 | ``` 163 | 164 | ## ‼️ Known mistakes // errata 165 | For known errors and corrections check: 166 | 167 | * [Books purchased before ~12:00 PM on June 13, 2023](https://github.com/PacktPublishing/Causal-Inference-and-Discovery-in-Python/blob/main/errata/Errata%20-%20Early%20Print%20(ordered%20before%20June%2013%202023).ipynb) 168 | 169 | * [Books purchased after ~12:00 PM on June 13, 2023](https://github.com/PacktPublishing/Causal-Inference-and-Discovery-in-Python/blob/main/errata/Errata%20-%20Non-Early%20Print%20(ordered%20after%20June%2013%202023).ipynb) 170 | 171 | If you spotted a mistake, let us know at book(at)causalpython.io or just open an **issue** in this repo. 
Thank you 🙏🏼 172 | -------------------------------------------------------------------------------- /causal-pymc.yml: -------------------------------------------------------------------------------- 1 | name: causal-pymc 2 | channels: 3 | - conda-forge 4 | - defaults 5 | dependencies: 6 | - python=3.9 7 | - numba 8 | - pip 9 | - nb_conda 10 | - arviz 11 | - matplotlib 12 | - pandas 13 | - scipy 14 | - numpy 15 | - statsmodels 16 | - pydot 17 | - tqdm 18 | - ipywidgets 19 | - pip: 20 | - pymc>=4 21 | - CausalPy==0.0.8 -------------------------------------------------------------------------------- /causal_book_py39_apple_m1_(experimental-by-ferrari-leo).txt: -------------------------------------------------------------------------------- 1 | name: causal_book_py39 2 | channels: 3 | - defaults 4 | - conda-forge 5 | dependencies: 6 | - abseil-cpp=20230802.0=h313beb8_2 7 | - aiohttp=3.9.0=py39h80987f9_0 8 | - aiosignal=1.2.0=pyhd3eb1b0_0 9 | - anyio=3.5.0=py39hca03da5_0 10 | - appnope=0.1.2=py39hca03da5_1001 11 | - argon2-cffi=21.3.0=pyhd3eb1b0_0 12 | - argon2-cffi-bindings=21.2.0=py39h1a28f6b_0 13 | - arrow-cpp=11.0.0=hc7aafb3_2 14 | - asttokens=2.0.5=pyhd3eb1b0_0 15 | - async-timeout=4.0.3=py39hca03da5_0 16 | - attrs=23.1.0=py39hca03da5_0 17 | - aws-c-common=0.6.8=h80987f9_1 18 | - aws-c-event-stream=0.1.6=h313beb8_6 19 | - aws-checksums=0.1.11=h80987f9_2 20 | - aws-sdk-cpp=1.8.185=ha71a6ea_1 21 | - backcall=0.2.0=pyhd3eb1b0_0 22 | - beautifulsoup4=4.12.2=py39hca03da5_0 23 | - blas=1.0=openblas 24 | - bleach=4.1.0=pyhd3eb1b0_0 25 | - boost-cpp=1.82.0=h48ca7d4_2 26 | - bottleneck=1.3.5=py39heec5a64_0 27 | - brotli=1.0.9=h1a28f6b_7 28 | - brotli-bin=1.0.9=h1a28f6b_7 29 | - brotli-python=1.0.9=py39hc377ac9_7 30 | - bzip2=1.0.8=h620ffc9_4 31 | - c-ares=1.19.1=h80987f9_0 32 | - ca-certificates=2023.12.12=hca03da5_0 33 | - certifi=2023.11.17=py39hca03da5_0 34 | - cffi=1.16.0=py39h80987f9_0 35 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 36 | - comm=0.1.2=py39hca03da5_0 37 | - 
contourpy=1.2.0=py39h48ca7d4_0 38 | - cryptography=41.0.7=py39hd4332d6_0 39 | - cycler=0.11.0=pyhd3eb1b0_0 40 | - datasets=2.12.0=py39hca03da5_0 41 | - debugpy=1.6.7=py39h313beb8_0 42 | - decorator=5.1.1=pyhd3eb1b0_0 43 | - defusedxml=0.7.1=pyhd3eb1b0_0 44 | - dill=0.3.6=py39hca03da5_0 45 | - entrypoints=0.4=py39hca03da5_0 46 | - exceptiongroup=1.2.0=py39hca03da5_0 47 | - executing=0.8.3=pyhd3eb1b0_0 48 | - filelock=3.13.1=py39hca03da5_0 49 | - fonttools=4.25.0=pyhd3eb1b0_0 50 | - freetype=2.12.1=h1192e45_0 51 | - frozenlist=1.4.0=py39h80987f9_0 52 | - fsspec=2023.10.0=py39hca03da5_0 53 | - future=0.18.3=py39hca03da5_0 54 | - gflags=2.2.2=hc377ac9_0 55 | - giflib=5.2.1=h80987f9_3 56 | - glog=0.5.0=hc377ac9_0 57 | - grpc-cpp=1.48.2=hc60591f_4 58 | - gtest=1.14.0=h48ca7d4_0 59 | - huggingface_hub=0.17.3=py39hca03da5_0 60 | - icu=73.1=h313beb8_0 61 | - idna=3.4=py39hca03da5_0 62 | - importlib-metadata=7.0.1=py39hca03da5_0 63 | - importlib_resources=6.1.1=py39hca03da5_1 64 | - ipykernel=6.25.0=py39h33ce5c2_0 65 | - ipython=8.15.0=py39hca03da5_0 66 | - ipython_genutils=0.2.0=pyhd3eb1b0_1 67 | - jedi=0.18.1=py39hca03da5_1 68 | - jinja2=3.1.2=py39hca03da5_0 69 | - jpeg=9e=h80987f9_1 70 | - jsonschema=4.19.2=py39hca03da5_0 71 | - jsonschema-specifications=2023.7.1=py39hca03da5_0 72 | - jupyter_client=7.4.9=py39hca03da5_0 73 | - jupyter_core=5.5.0=py39hca03da5_0 74 | - jupyter_server=1.23.4=py39hca03da5_0 75 | - jupyterlab_pygments=0.2.2=py39hca03da5_0 76 | - kiwisolver=1.4.4=py39h313beb8_0 77 | - krb5=1.20.1=hf3e1bf2_1 78 | - lcms2=2.12=hba8e193_0 79 | - lerc=3.0=hc377ac9_0 80 | - libboost=1.82.0=h0bc93f9_2 81 | - libbrotlicommon=1.0.9=h1a28f6b_7 82 | - libbrotlidec=1.0.9=h1a28f6b_7 83 | - libbrotlienc=1.0.9=h1a28f6b_7 84 | - libcurl=8.5.0=h3e2b118_0 85 | - libcxx=14.0.6=h848a8c0_0 86 | - libdeflate=1.17=h80987f9_1 87 | - libedit=3.1.20230828=h80987f9_0 88 | - libev=4.33=h1a28f6b_1 89 | - libevent=2.1.12=h02f6b3c_1 90 | - libffi=3.4.4=hca03da5_0 91 | - 
libgfortran=5.0.0=11_3_0_hca03da5_28 92 | - libgfortran5=11.3.0=h009349e_28 93 | - libiconv=1.16=h1a28f6b_2 94 | - libnghttp2=1.57.0=h62f6fdd_0 95 | - libopenblas=0.3.21=h269037a_0 96 | - libpng=1.6.39=h80987f9_0 97 | - libprotobuf=3.20.3=h514c7bf_0 98 | - libsodium=1.0.18=h1a28f6b_0 99 | - libssh2=1.10.0=h02f6b3c_2 100 | - libthrift=0.15.0=h73c2103_2 101 | - libtiff=4.5.1=h313beb8_0 102 | - libwebp=1.3.2=ha3663a8_0 103 | - libwebp-base=1.3.2=h80987f9_0 104 | - lightgbm=4.1.0=py39h313beb8_0 105 | - llvm-openmp=14.0.6=hc6e5704_0 106 | - lz4-c=1.9.4=h313beb8_0 107 | - markupsafe=2.1.3=py39h80987f9_0 108 | - matplotlib=3.8.0=py39hca03da5_0 109 | - matplotlib-base=3.8.0=py39h46d7db6_0 110 | - matplotlib-inline=0.1.6=py39hca03da5_0 111 | - mistune=2.0.4=py39hca03da5_0 112 | - multidict=6.0.4=py39h80987f9_0 113 | - multiprocess=0.70.14=py39hca03da5_0 114 | - munkres=1.1.4=py_0 115 | - nb_conda=2.2.1=py39hca03da5_1 116 | - nb_conda_kernels=2.3.1=py39hca03da5_0 117 | - nbclassic=1.0.0=py39hca03da5_0 118 | - nbclient=0.8.0=py39hca03da5_0 119 | - nbconvert=7.10.0=py39hca03da5_0 120 | - nbformat=5.9.2=py39hca03da5_0 121 | - ncurses=6.4=h313beb8_0 122 | - nest-asyncio=1.5.6=py39hca03da5_0 123 | - notebook=6.5.4=py39hca03da5_1 124 | - notebook-shim=0.2.3=py39hca03da5_0 125 | - numexpr=2.8.7=py39hecc3335_0 126 | - openjpeg=2.3.0=h7a6adac_2 127 | - openssl=3.0.12=h1a28f6b_0 128 | - orc=1.7.4=hdca1487_1 129 | - packaging=23.1=py39hca03da5_0 130 | - pandocfilters=1.5.0=pyhd3eb1b0_0 131 | - parso=0.8.3=pyhd3eb1b0_0 132 | - pexpect=4.8.0=pyhd3eb1b0_3 133 | - pickleshare=0.7.5=pyhd3eb1b0_1003 134 | - pillow=10.0.1=py39h3b245a6_0 135 | - pip=23.3.1=py39hca03da5_0 136 | - platformdirs=3.10.0=py39hca03da5_0 137 | - progressbar2=4.2.0=py39hca03da5_0 138 | - prometheus_client=0.14.1=py39hca03da5_0 139 | - prompt-toolkit=3.0.43=py39hca03da5_0 140 | - psutil=5.9.0=py39h1a28f6b_0 141 | - ptyprocess=0.7.0=pyhd3eb1b0_2 142 | - pure_eval=0.2.2=pyhd3eb1b0_0 143 | - pyarrow=11.0.0=py39hbfed03b_1 
144 | - pycparser=2.21=pyhd3eb1b0_0 145 | - pygam=0.9.0=pyhd8ed1ab_2 146 | - pygments=2.15.1=py39hca03da5_1 147 | - pyopenssl=23.2.0=py39hca03da5_0 148 | - pyparsing=3.0.9=py39hca03da5_0 149 | - pysocks=1.7.1=py39hca03da5_0 150 | - python=3.9.18=hb885b13_0 151 | - python-dateutil=2.8.2=pyhd3eb1b0_0 152 | - python-fastjsonschema=2.16.2=py39hca03da5_0 153 | - python-tzdata=2023.3=pyhd3eb1b0_0 154 | - python-utils=3.3.3=py39hca03da5_0 155 | - python-xxhash=2.0.2=py39h1a28f6b_1 156 | - pytz=2023.3.post1=py39hca03da5_0 157 | - pyyaml=6.0.1=py39h80987f9_0 158 | - pyzmq=23.2.0=py39hc377ac9_0 159 | - re2=2022.04.01=hc377ac9_0 160 | - readline=8.2=h1a28f6b_0 161 | - referencing=0.30.2=py39hca03da5_0 162 | - regex=2023.10.3=py39h80987f9_0 163 | - requests=2.31.0=py39hca03da5_0 164 | - responses=0.13.3=pyhd3eb1b0_0 165 | - rpds-py=0.10.6=py39hf0e4da2_0 166 | - safetensors=0.4.0=py39h482802a_0 167 | - scipy=1.11.4=py39h20cbe94_0 168 | - seaborn=0.12.2=py39hca03da5_0 169 | - send2trash=1.8.2=py39hca03da5_0 170 | - setuptools=68.2.2=py39hca03da5_0 171 | - six=1.16.0=pyhd3eb1b0_1 172 | - snappy=1.1.10=h313beb8_1 173 | - sniffio=1.3.0=py39hca03da5_0 174 | - soupsieve=2.5=py39hca03da5_0 175 | - sqlite=3.41.2=h80987f9_0 176 | - stack_data=0.2.0=pyhd3eb1b0_0 177 | - terminado=0.17.1=py39hca03da5_0 178 | - tinycss2=1.2.1=py39hca03da5_0 179 | - tk=8.6.12=hb8d0fd4_0 180 | - tokenizers=0.13.2=py39h3dd52b7_1 181 | - tornado=6.3.3=py39h80987f9_0 182 | - tqdm=4.65.0=py39h86d0a89_0 183 | - traitlets=5.7.1=py39hca03da5_0 184 | - transformers=4.32.1=py39hca03da5_0 185 | - typing-extensions=4.9.0=py39hca03da5_1 186 | - typing_extensions=4.9.0=py39hca03da5_1 187 | - tzdata=2023d=h04d1e81_0 188 | - urllib3=1.26.18=py39hca03da5_0 189 | - utf8proc=2.6.1=h1a28f6b_0 190 | - wcwidth=0.2.5=pyhd3eb1b0_0 191 | - webencodings=0.5.1=py39hca03da5_1 192 | - websocket-client=0.58.0=py39hca03da5_4 193 | - wheel=0.41.2=py39hca03da5_0 194 | - xxhash=0.8.0=h1a28f6b_3 195 | - xz=5.4.5=h80987f9_0 196 | - 
yaml=0.2.5=h1a28f6b_0 197 | - yarl=1.9.3=py39h80987f9_0 198 | - zeromq=4.3.5=h313beb8_0 199 | - zipp=3.17.0=py39hca03da5_0 200 | - zlib=1.2.13=h5a0b063_0 201 | - zstd=1.5.5=hd90d995_0 202 | - pip: 203 | - absl-py==2.1.0 204 | - alembic==1.13.1 205 | - antlr4-python3-runtime==4.9.3 206 | - arviz==0.17.0 207 | - azure-common==1.1.28 208 | - azure-core==1.29.7 209 | - azure-identity==1.15.0 210 | - azure-mgmt-core==1.4.0 211 | - azure-storage-blob==12.19.0 212 | - azureml-mlflow==1.54.0.post1 213 | - blinker==1.7.0 214 | - cachetools==5.3.2 215 | - catenets==0.2.3 216 | - causal-learn==0.1.3.3 217 | - causalpy==0.2.0 218 | - causica==0.2.0 219 | - click==8.1.7 220 | - cloudpickle==3.0.0 221 | - cons==0.4.6 222 | - databricks-cli==0.18.0 223 | - dataclasses-json==0.5.14 224 | - docker==7.0.0 225 | - docstring-parser==0.15 226 | - dowhy==0.8 227 | - econml==0.14.1 228 | - etuples==0.3.9 229 | - fastprogress==1.0.3 230 | - flask==3.0.1 231 | - gcastle==1.0.3 232 | - gdown==5.0.0 233 | - gitdb==4.0.11 234 | - gitpython==3.1.41 235 | - google-auth==2.27.0 236 | - google-auth-oauthlib==1.2.0 237 | - grapl-causal==1.5.1 238 | - grpcio==1.60.0 239 | - gunicorn==21.2.0 240 | - h5netcdf==1.3.0 241 | - h5py==3.10.0 242 | - hydra-core==1.3.2 243 | - iniconfig==2.0.0 244 | - isodate==0.6.1 245 | - itsdangerous==2.1.2 246 | - jax==0.4.23 247 | - jaxlib==0.4.23 248 | - joblib==1.3.2 249 | - jsonargparse==4.27.3 250 | - jsonpickle==3.0.2 251 | - lightning-utilities==0.10.1 252 | - llvmlite==0.41.1 253 | - logical-unification==0.4.6 254 | - loguru==0.7.2 255 | - mako==1.3.0 256 | - markdown==3.5.2 257 | - markdown-it-py==3.0.0 258 | - marshmallow==3.20.2 259 | - mdurl==0.1.2 260 | - minikanren==1.0.3 261 | - ml-dtypes==0.3.2 262 | - mlflow==2.10.0 263 | - mlflow-skinny==2.10.0 264 | - mpmath==1.3.0 265 | - msal==1.26.0 266 | - msal-extensions==1.1.0 267 | - msrest==0.7.1 268 | - multipledispatch==1.0.0 269 | - mypy-extensions==1.0.0 270 | - networkx==2.8.7 271 | - numba==0.58.1 272 | 
- numpy==1.25.2 273 | - oauthlib==3.2.2 274 | - omegaconf==2.3.0 275 | - opt-einsum==3.3.0 276 | - pandas==1.5.3 277 | - patsy==0.5.6 278 | - pluggy==1.4.0 279 | - portalocker==2.8.2 280 | - protobuf==4.23.4 281 | - pyasn1==0.5.1 282 | - pyasn1-modules==0.3.0 283 | - pydot==2.0.0 284 | - pyjwt==2.8.0 285 | - pymc==5.10.3 286 | - pytensor==2.18.6 287 | - pytest==7.4.4 288 | - python-graphviz==0.20.1 289 | - pytorch-lightning==1.9.5 290 | - querystring-parser==1.2.4 291 | - requests-oauthlib==1.3.1 292 | - rich==13.7.0 293 | - rsa==4.9 294 | - scikit-learn==1.2.2 295 | - shap==0.41.0 296 | - slicer==0.0.7 297 | - smmap==5.0.1 298 | - sparse==0.15.1 299 | - sqlalchemy==2.0.25 300 | - sqlparse==0.4.4 301 | - statsmodels==0.14.1 302 | - sympy==1.12 303 | - tabulate==0.9.0 304 | - tensorboard==2.15.1 305 | - tensorboard-data-server==0.7.2 306 | - tensorboardx==2.6.2.2 307 | - tensordict==0.1.2 308 | - threadpoolctl==3.2.0 309 | - tomli==2.0.1 310 | - toolz==0.12.1 311 | - torch==2.1.2 312 | - torchmetrics==1.3.0.post0 313 | - types-pyyaml==6.0.12.12 314 | - typeshed-client==2.4.0 315 | - typing-inspect==0.9.0 316 | - werkzeug==3.0.1 317 | - xarray==2024.1.1 318 | - xarray-einstats==0.7.0 319 | -------------------------------------------------------------------------------- /causal_book_py39_cuda117.yml: -------------------------------------------------------------------------------- 1 | name: causal_book_py39_cuda117 2 | channels: 3 | - defaults 4 | - conda-forge 5 | - nvidia 6 | - pytorch 7 | dependencies: 8 | - pip 9 | - python=3.9 10 | - nb_conda 11 | - numpy 12 | - pandas 13 | - matplotlib 14 | - seaborn 15 | - pytorch 16 | - pytorch-cuda=11.7 17 | - pygam 18 | - lightgbm 19 | - transformers 20 | - pip: 21 | - econml 22 | - networkx==2.8.7 23 | - dowhy==0.8 24 | - gcastle==1.0.3 25 | - graphviz 26 | - causal-learn==0.1.3.3 27 | - grapl-causal==1.5.1 28 | - causica==0.2.0 29 | - catenets==0.2.3 30 | 
-------------------------------------------------------------------------------- /causal_model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PacktPublishing/Causal-Inference-and-Discovery-in-Python/de2a9df880e31f7217cb99b2c878feee03fc0cb1/causal_model.png -------------------------------------------------------------------------------- /data/ch_01_drug_data.csv: -------------------------------------------------------------------------------- 1 | Drug,A,A,B,B 2 | Clot,1,0,1,0 3 | Female,24,56,17,25 4 | Male,3,39,6,74 5 | Total,27,95,23,99 6 | -------------------------------------------------------------------------------- /data/gt_social_media_data.csv: -------------------------------------------------------------------------------- 1 | date,twitter,linkedin,tiktok,instagram 2 | 2022-05-15,55,9,23,59 3 | 2022-05-16,54,18,20,59 4 | 2022-05-17,54,20,23,57 5 | 2022-05-18,54,20,21,55 6 | 2022-05-19,49,23,21,52 7 | 2022-05-20,46,18,22,56 8 | 2022-05-21,51,9,23,58 9 | 2022-05-22,47,9,27,59 10 | 2022-05-23,45,19,21,58 11 | 2022-05-24,49,21,23,53 12 | 2022-05-25,55,21,21,61 13 | 2022-05-26,53,19,22,68 14 | 2022-05-27,52,16,23,52 15 | 2022-05-28,46,8,24,59 16 | 2022-05-29,45,7,22,56 17 | 2022-05-30,45,9,24,61 18 | 2022-05-31,46,19,20,58 19 | 2022-06-01,51,21,22,56 20 | 2022-06-02,47,19,22,54 21 | 2022-06-03,46,17,21,56 22 | 2022-06-04,45,9,23,58 23 | 2022-06-05,47,9,23,60 24 | 2022-06-06,48,20,21,58 25 | 2022-06-07,46,23,21,57 26 | 2022-06-08,48,22,24,56 27 | 2022-06-09,48,20,23,55 28 | 2022-06-10,48,17,22,56 29 | 2022-06-11,47,9,25,54 30 | 2022-06-12,46,9,24,56 31 | 2022-06-13,46,20,22,54 32 | 2022-06-14,46,21,23,58 33 | 2022-06-15,47,21,23,58 34 | 2022-06-16,46,21,24,56 35 | 2022-06-17,51,17,23,57 36 | 2022-06-18,44,9,24,54 37 | 2022-06-19,43,8,26,59 38 | 2022-06-20,45,16,25,53 39 | 2022-06-21,53,23,24,56 40 | 2022-06-22,48,21,25,58 41 | 2022-06-23,48,22,26,55 42 | 2022-06-24,54,18,24,56 43 | 
2022-06-25,54,9,25,57 44 | 2022-06-26,48,8,23,62 45 | 2022-06-27,51,19,24,58 46 | 2022-06-28,54,21,25,56 47 | 2022-06-29,49,20,25,61 48 | 2022-06-30,53,19,27,58 49 | 2022-07-01,56,17,26,55 50 | 2022-07-02,47,8,25,56 51 | 2022-07-03,49,8,28,58 52 | 2022-07-04,47,9,25,58 53 | 2022-07-05,52,20,22,61 54 | 2022-07-06,49,20,27,60 55 | 2022-07-07,49,21,26,56 56 | 2022-07-08,58,17,27,53 57 | 2022-07-09,61,10,30,55 58 | 2022-07-10,56,15,31,55 59 | 2022-07-11,59,27,27,55 60 | 2022-07-12,52,25,28,57 61 | 2022-07-13,54,25,26,55 62 | 2022-07-14,61,25,25,60 63 | 2022-07-15,51,22,25,58 64 | 2022-07-16,50,12,24,58 65 | 2022-07-17,53,13,28,57 66 | 2022-07-18,48,24,25,57 67 | 2022-07-19,50,26,25,57 68 | 2022-07-20,47,26,25,53 69 | 2022-07-21,49,24,26,55 70 | 2022-07-22,50,21,26,57 71 | 2022-07-23,47,9,26,55 72 | 2022-07-24,45,8,27,57 73 | 2022-07-25,47,19,26,59 74 | 2022-07-26,47,20,28,58 75 | 2022-07-27,47,20,27,56 76 | 2022-07-28,48,19,26,57 77 | 2022-07-29,47,18,27,58 78 | 2022-07-30,46,9,28,53 79 | 2022-07-31,48,9,27,58 80 | 2022-08-01,48,19,25,56 81 | 2022-08-02,51,21,25,56 82 | 2022-08-03,49,21,27,55 83 | 2022-08-04,47,20,24,54 84 | 2022-08-05,47,17,25,54 85 | 2022-08-06,47,10,26,57 86 | 2022-08-07,46,9,26,58 87 | 2022-08-08,46,19,24,59 88 | 2022-08-09,62,21,25,59 89 | 2022-08-10,52,20,25,56 90 | 2022-08-11,54,21,24,61 91 | 2022-08-12,60,18,24,59 92 | 2022-08-13,55,9,26,56 93 | 2022-08-14,52,9,25,55 94 | 2022-08-15,48,19,24,54 95 | 2022-08-16,49,21,27,55 96 | 2022-08-17,47,20,24,53 97 | 2022-08-18,45,19,24,53 98 | 2022-08-19,47,18,25,53 99 | 2022-08-20,47,9,25,54 100 | 2022-08-21,49,10,27,56 101 | 2022-08-22,44,20,23,55 102 | 2022-08-23,46,22,23,52 103 | 2022-08-24,49,22,23,52 104 | 2022-08-25,47,21,23,56 105 | 2022-08-26,56,17,23,54 106 | 2022-08-27,52,9,25,55 107 | 2022-08-28,51,9,26,57 108 | 2022-08-29,47,19,22,56 109 | 2022-08-30,47,21,22,52 110 | 2022-08-31,46,20,21,52 111 | 2022-09-01,47,18,22,55 112 | 2022-09-02,48,16,21,49 113 | 2022-09-03,50,8,24,52 114 | 
2022-09-04,47,8,25,56 115 | 2022-09-05,48,10,27,56 116 | 2022-09-06,45,19,24,54 117 | 2022-09-07,49,20,21,54 118 | 2022-09-08,51,19,23,50 119 | 2022-09-09,49,16,21,52 120 | 2022-09-10,49,9,21,52 121 | 2022-09-11,52,9,22,55 122 | 2022-09-12,49,19,21,50 123 | 2022-09-13,54,21,20,52 124 | 2022-09-14,47,21,20,49 125 | 2022-09-15,45,20,30,51 126 | 2022-09-16,42,16,22,48 127 | 2022-09-17,47,9,23,52 128 | 2022-09-18,50,9,23,51 129 | 2022-09-19,48,20,21,48 130 | 2022-09-20,47,21,22,50 131 | 2022-09-21,47,22,23,51 132 | 2022-09-22,46,21,23,55 133 | 2022-09-23,46,18,21,49 134 | 2022-09-24,47,9,22,50 135 | 2022-09-25,48,9,23,53 136 | 2022-09-26,45,17,23,51 137 | 2022-09-27,46,20,23,47 138 | 2022-09-28,49,19,23,50 139 | 2022-09-29,45,18,30,56 140 | 2022-09-30,49,17,32,56 141 | 2022-10-01,47,9,36,64 142 | 2022-10-02,52,9,39,66 143 | 2022-10-03,53,19,31,60 144 | 2022-10-04,64,21,33,66 145 | 2022-10-05,58,20,34,68 146 | 2022-10-06,49,19,33,62 147 | 2022-10-07,51,17,22,49 148 | 2022-10-08,50,9,24,48 149 | 2022-10-09,52,9,22,52 150 | 2022-10-10,52,16,21,50 151 | 2022-10-11,51,20,22,50 152 | 2022-10-12,47,19,22,49 153 | 2022-10-13,46,20,23,48 154 | 2022-10-14,46,18,20,49 155 | 2022-10-15,49,9,22,48 156 | 2022-10-16,52,9,23,52 157 | 2022-10-17,47,19,22,52 158 | 2022-10-18,48,20,24,49 159 | 2022-10-19,46,19,24,50 160 | 2022-10-20,46,20,22,49 161 | 2022-10-21,49,18,22,47 162 | 2022-10-22,50,9,23,49 163 | 2022-10-23,58,8,22,53 164 | 2022-10-24,53,18,22,50 165 | 2022-10-25,53,21,22,50 166 | 2022-10-26,52,19,22,49 167 | 2022-10-27,56,19,28,49 168 | 2022-10-28,100,17,26,48 169 | 2022-10-29,75,8,25,49 170 | 2022-10-30,66,9,23,56 171 | 2022-10-31,69,17,21,83 172 | 2022-11-01,75,19,21,58 173 | 2022-11-02,64,21,23,51 174 | 2022-11-03,61,19,24,49 175 | 2022-11-04,76,17,28,52 176 | 2022-11-05,69,9,23,54 177 | 2022-11-06,62,8,25,51 178 | 2022-11-07,66,18,23,50 179 | 2022-11-08,60,18,24,47 180 | 2022-11-09,64,18,21,45 181 | 2022-11-10,61,19,22,49 182 | 2022-11-11,69,16,23,47 
-------------------------------------------------------------------------------- /data/hillstrom_clean_label_mapping.json: -------------------------------------------------------------------------------- 1 | {"control": 0, "womans_email": 1, "mens_email": 2} -------------------------------------------------------------------------------- /data/ml_earnings.csv: -------------------------------------------------------------------------------- 1 | age,took_a_course,earnings 2 | 19,False,110579.0 3 | 28,False,142577.0 4 | 22,True,130520.0 5 | 25,True,142687.0 6 | 24,False,127832.0 7 | 23,False,125557.0 8 | 20,False,113922.0 9 | 23,False,124084.0 10 | 22,True,131905.0 11 | 28,False,138213.0 12 | 22,False,121170.0 13 | 32,False,154397.0 14 | 35,False,174856.0 15 | 33,False,163965.0 16 | 27,False,136437.0 17 | 22,True,132110.0 18 | 22,True,131140.0 19 | 34,False,166636.0 20 | 22,True,131868.0 21 | 30,True,157744.0 22 | 25,False,130011.0 23 | 41,True,207676.0 24 | 24,True,137985.0 25 | 30,True,160678.0 26 | 37,False,181514.0 27 | 33,False,159670.0 28 | 19,True,123372.0 29 | 26,False,134860.0 30 | 22,False,118836.0 31 | 32,True,167632.0 32 | 35,True,180750.0 33 | 24,False,122944.0 34 | 24,False,128146.0 35 | 25,False,133634.0 36 | 21,True,132726.0 37 | 24,True,137189.0 38 | 20,False,113015.0 39 | 22,False,120984.0 40 | 23,False,122567.0 41 | 22,True,130602.0 42 | 36,False,174964.0 43 | 32,False,161667.0 44 | 26,False,133173.0 45 | 22,False,119640.0 46 | 28,True,149414.0 47 | 36,False,177082.0 48 | 24,False,125944.0 49 | 19,False,111847.0 50 | 34,False,162654.0 51 | 20,True,123706.0 52 | 28,False,139394.0 53 | 23,False,126213.0 54 | 33,True,169772.0 55 | 23,True,135389.0 56 | 32,False,155655.0 57 | 21,False,115116.0 58 | 26,False,133087.0 59 | 23,True,134155.0 60 | 26,True,144015.0 61 | 25,True,143167.0 62 | 27,True,148107.0 63 | 33,False,162571.0 64 | 19,True,121933.0 65 | 25,False,127461.0 66 | 33,True,169929.0 67 | 21,False,117047.0 68 | 30,True,156543.0 69 | 
36,False,173676.0 70 | 29,True,154168.0 71 | 23,True,133546.0 72 | 24,False,124253.0 73 | 24,False,128471.0 74 | 30,False,149020.0 75 | 19,False,112694.0 76 | 23,False,124771.0 77 | 30,True,163548.0 78 | 22,False,121280.0 79 | 19,False,108913.0 80 | 21,False,120143.0 81 | 19,True,125577.0 82 | 39,False,185360.0 83 | 33,True,169475.0 84 | 22,False,119802.0 85 | 26,True,143772.0 86 | 42,False,204343.0 87 | 21,False,115600.0 88 | 29,False,146446.0 89 | 22,False,119595.0 90 | 31,True,163592.0 91 | 23,True,134419.0 92 | 30,True,159419.0 93 | 24,True,138847.0 94 | 23,True,134703.0 95 | 23,False,125074.0 96 | 24,False,127307.0 97 | 34,False,165055.0 98 | 20,False,114155.0 99 | 20,False,117019.0 100 | 32,True,172076.0 101 | 42,False,201715.0 102 | 28,False,141481.0 103 | 27,False,139070.0 104 | 32,False,159917.0 105 | 19,False,108433.0 106 | 23,False,123586.0 107 | 33,True,170740.0 108 | 21,False,117066.0 109 | 24,False,124757.0 110 | 24,False,128756.0 111 | 23,False,123440.0 112 | 19,True,122551.0 113 | 34,True,179133.0 114 | 28,False,141329.0 115 | 23,False,125327.0 116 | 22,True,129636.0 117 | 34,False,167906.0 118 | 38,False,186144.0 119 | 41,False,201729.0 120 | 24,False,131654.0 121 | 20,True,126153.0 122 | 42,True,214445.0 123 | 22,False,121647.0 124 | 20,True,126501.0 125 | 28,False,141564.0 126 | 23,True,132636.0 127 | 20,False,118623.0 128 | 25,True,143428.0 129 | 27,False,139404.0 130 | 24,True,136236.0 131 | 23,True,134872.0 132 | 38,False,185658.0 133 | 34,False,166220.0 134 | 19,False,112008.0 135 | 30,False,153864.0 136 | 24,False,125082.0 137 | 34,True,178711.0 138 | 39,False,185525.0 139 | 33,False,163608.0 140 | 31,False,154288.0 141 | 35,False,173291.0 142 | 31,False,150776.0 143 | 30,False,149096.0 144 | 27,True,151007.0 145 | 26,True,146392.0 146 | 35,False,170960.0 147 | 45,True,233035.0 148 | 26,False,134897.0 149 | 21,False,116928.0 150 | 26,True,145022.0 151 | 20,True,124835.0 152 | 20,True,124358.0 153 | 43,False,210243.0 154 | 20,True,128241.0 
155 | 21,False,115543.0 156 | 37,True,187627.0 157 | 22,False,118132.0 158 | 20,False,111554.0 159 | 20,False,118302.0 160 | 20,True,125639.0 161 | 29,True,155448.0 162 | 21,False,117786.0 163 | 21,False,119334.0 164 | 28,True,153722.0 165 | 22,True,132427.0 166 | 23,True,138425.0 167 | 33,True,170184.0 168 | 30,False,149355.0 169 | 48,False,239510.0 170 | 23,True,133744.0 171 | 28,False,144378.0 172 | 35,False,169430.0 173 | 21,True,128129.0 174 | 25,False,130482.0 175 | 29,False,145436.0 176 | 25,True,141578.0 177 | 27,True,147119.0 178 | 21,True,129610.0 179 | 28,True,154320.0 180 | 29,False,143885.0 181 | 24,False,127397.0 182 | 26,True,143563.0 183 | 22,False,121027.0 184 | 25,False,132476.0 185 | 21,False,116586.0 186 | 29,False,143288.0 187 | 39,False,190876.0 188 | 27,False,138602.0 189 | 23,False,123990.0 190 | 22,False,120377.0 191 | 29,False,147188.0 192 | 29,True,148466.0 193 | 19,True,124007.0 194 | 22,True,132727.0 195 | 28,False,143704.0 196 | 27,False,141665.0 197 | 35,True,180059.0 198 | 19,False,113031.0 199 | 38,False,184837.0 200 | 19,False,111382.0 201 | 32,False,155940.0 202 | -------------------------------------------------------------------------------- /data/ml_earnings_interaction_test.csv: -------------------------------------------------------------------------------- 1 | age,python_proficiency,took_a_course,true_effect 2 | 30,0.22387682273767961,True,11120.0 3 | 23,0.39415189924761007,True,11970.0 4 | 37,0.2146377127092679,True,11073.0 5 | 21,0.8690694065487253,True,14345.0 6 | 41,0.8339335515161994,True,14169.0 7 | 33,0.27165464274377815,True,11358.0 8 | 24,0.6749056709382142,True,13374.0 9 | 21,0.6548377191770705,True,13275.0 10 | 36,0.028713567977660337,True,10143.0 11 | 36,0.07706736417814364,True,10385.0 12 | 24,0.5969832080096646,True,12985.0 13 | 19,0.754652317177385,True,13773.0 14 | 22,0.9701476625375967,True,14850.0 15 | 23,0.516337378365627,True,12582.0 16 | 27,0.14763330181007217,True,10738.0 17 | 
32,0.10665130841274617,True,10534.0 18 | 21,0.9034475785889416,True,14517.0 19 | 21,0.00754939322492898,True,10038.0 20 | 23,0.5856538108502074,True,12928.0 21 | 21,0.7731091931329046,True,13865.0 22 | 22,0.4155164444090621,True,12077.0 23 | 24,0.879776135206835,True,14399.0 24 | 30,0.9273541754244794,True,14637.0 25 | 22,0.8829521915372446,True,14415.0 26 | 25,0.41105913268579086,True,12055.0 27 | 25,0.6191696160038245,True,13096.0 28 | 24,0.04733337838583784,True,10236.0 29 | 41,0.47893314888488114,True,12395.0 30 | 44,0.09944093127332043,True,10498.0 31 | 25,0.2141457794330497,True,11071.0 32 | 34,0.3010400850697761,True,11505.0 33 | 19,0.8935352732069651,True,14468.0 34 | 44,0.5011724863627216,True,12506.0 35 | 19,0.18598147770314144,True,10930.0 36 | 19,0.23651636982734914,True,11182.0 37 | 19,0.9515226743334864,True,14757.0 38 | 20,0.5903458191540543,True,12952.0 39 | 22,0.614049897785869,True,13071.0 40 | 21,0.5685133860458594,True,12842.0 41 | 23,0.9657358123678771,True,14828.0 42 | 29,0.2939826080119571,True,11470.0 43 | 23,0.31206223355020046,True,11561.0 44 | 25,0.24394882853743838,True,11219.0 45 | 35,0.1888105446811692,True,10944.0 46 | 21,0.17212297381880481,True,10861.0 47 | 27,0.6787326814995325,True,13393.0 48 | 20,0.8485841294117298,True,14243.0 49 | 35,0.9484767287957341,True,14742.0 50 | 28,0.18462968527393953,True,10923.0 51 | 23,0.19854916503306286,True,10992.0 52 | 30,0.7868749315943776,True,13934.0 53 | 31,0.2560144316176236,True,11281.0 54 | 23,0.7931986715456523,True,13966.0 55 | 19,0.5984763023800789,True,12993.0 56 | 34,0.3218355012581319,True,11609.0 57 | 26,0.8773770629018882,True,14387.0 58 | 20,0.027998583565693846,True,10140.0 59 | 24,0.7016890963543043,True,13508.0 60 | 28,0.414826108610463,True,12074.0 61 | 24,0.4055053361791364,True,12027.0 62 | 27,0.672410497084009,True,13362.0 63 | 24,0.9188827199485727,True,14594.0 64 | 21,0.8581974673286622,True,14291.0 65 | 23,0.24638740966187578,True,11232.0 66 | 
21,0.03482611933114499,True,10174.0 67 | 25,0.6901719167371854,True,13451.0 68 | 32,0.37286235605965223,True,11864.0 69 | 33,0.10630526540167351,True,10532.0 70 | 32,0.14191426691020936,True,10709.0 71 | 27,0.6947074917198913,True,13474.0 72 | 19,0.29825221382550793,True,11491.0 73 | 29,0.6289800314325729,True,13145.0 74 | 21,0.007712802008101538,True,10039.0 75 | 28,0.7339499732538902,True,13670.0 76 | 28,0.6641348201845104,True,13321.0 77 | 27,0.22998238811942795,True,11150.0 78 | 26,0.20427867542820322,True,11022.0 79 | 21,0.5058140845042293,True,12529.0 80 | 30,0.2150062667778866,True,11075.0 81 | 45,0.5996982119068462,True,12998.0 82 | 24,0.46955708258014006,True,12348.0 83 | 34,0.689212768736534,True,13446.0 84 | 28,0.688789641829235,True,13444.0 85 | 19,0.08368086340831893,True,10419.0 86 | 36,0.08353312239165811,True,10418.0 87 | 32,0.6602178774563141,True,13301.0 88 | 25,0.3034097182007709,True,11517.0 89 | 24,0.4722092843068547,True,12361.0 90 | 28,0.894533822251131,True,14473.0 91 | 24,0.012281682704884167,True,10061.0 92 | 27,0.4460295972542465,True,12230.0 93 | 22,0.17985420688312093,True,10899.0 94 | 22,0.8572867893318362,True,14287.0 95 | 29,0.9950724171687579,True,14976.0 96 | 34,0.45387654040874015,True,12270.0 97 | 21,0.27297914716043903,True,11365.0 98 | 30,0.7261709231998095,True,13630.0 99 | 29,0.4287144754158567,True,12143.0 100 | 25,0.17942665780924838,True,10897.0 101 | 19,0.45354055693928375,True,12268.0 102 | -------------------------------------------------------------------------------- /data/shpitser_thesis1.grapl: -------------------------------------------------------------------------------- 1 | "Shpitser thesis, Fig 4.1"; 2 | X; W1; W2; Y1; Y2; 3 | X -> Y1; 4 | W1 <-> Y1; 5 | W1 -> X; 6 | W1 <-> Y2; 7 | W1 <-> W2; 8 | W2 -> Y2; 9 | W2 <-> X; -------------------------------------------------------------------------------- /docs/himsolt-gml-technical-report.pdf: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/PacktPublishing/Causal-Inference-and-Discovery-in-Python/de2a9df880e31f7217cb99b2c878feee03fc0cb1/docs/himsolt-gml-technical-report.pdf -------------------------------------------------------------------------------- /errata/Errata - Early Print (ordered before June 13 2023).ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "d3201184", 6 | "metadata": {}, 7 | "source": [ 8 | "# Errata\n", 9 | "\n", 10 | "\n", 11 | "(Books ordered before ~12:00 PM GMT on June 13, 2023)" 12 | ] 13 | }, 14 | { 15 | "cell_type": "markdown", 16 | "id": "ea449a06", 17 | "metadata": {}, 18 | "source": [ 19 | "## Chapter 02\n", 20 | "\n", 21 | "\n", 22 | "**Page 22**\n", 23 | "\n", 24 | "\n", 25 | "In the book (p. 22) we say:\n", 26 | "_______________________________________________________________\n", 27 | "\n", 28 | "*(...) the probability of buying book A, given we bought book B, is 63.8%. This indicates a\n", 29 | "positive relationship between both variables (if there was no association between them, we would\n", 30 | "expect the result to be **50%**)*\n", 31 | "\n", 32 | "\n", 33 | "This is **incorrect**. \n", 34 | "\n", 35 | "_______________________________________________________________\n", 36 | "\n", 37 | "The **correct** statement should be:\n", 38 | "\n", 39 | "*(...) the probability of buying book A, given we bought book B, is 63.8%. 
This indicates a\n", 40 | "positive relationship between both variables (if there was no association between them, we would\n", 41 | "expect the result to be **39%**)*\n", 42 | "_______________________________________________________________\n", 43 | "\n", 44 | " \n", 45 | "*Thanks to David for spotting this mistake.*" 46 | ] 47 | }, 48 | { 49 | "cell_type": "markdown", 50 | "id": "952febd2", 51 | "metadata": {}, 52 | "source": [ 53 | "## Chapter 04\n", 54 | "\n", 55 | "\n", 56 | "**Figure 4.1** is **incorrect**.\n", 57 | "\n", 58 | "The **correct** version of **Figure 4.1** is the following:\n", 59 | "\n", 60 | "" 61 | ] 62 | }, 63 | { 64 | "cell_type": "markdown", 65 | "id": "85952f08", 66 | "metadata": {}, 67 | "source": [ 68 | "## Chapter 06\n", 69 | "\n", 70 | "### You’re gonna keep ‘em d-separated\n", 71 | "\n", 72 | "\n", 73 | "**Figure 6.3** representing DAG **d)** in the *Keep 'em d-separated* game is **incorrect**.\n", 74 | "\n", 75 | "The **correct** version of **Figure 6.3** is the following:\n", 76 | "\n", 77 | "\n", 78 | "\n", 79 | "

\n", 80 | "\n", 81 | "\n", 82 | "**Figure 6.4** representing DAG **e)** in the *Keep 'em d-separated* game is **incorrect**.\n", 83 | "\n", 84 | "The **correct** version of **Figure 6.4** is the following:\n", 85 | "\n", 86 | "\n", 87 | "\n", 88 | "

\n", 89 | "\n", 90 | "\n", 91 | "**Figure 6.5** representing **all DAGs** in the *Keep 'em d-separated* game is **incorrect**.\n", 92 | "\n", 93 | "The **correct** version of **Figure 6.5** is the following:\n", 94 | "\n", 95 | "\n", 96 | "\n", 97 | "\n", 98 | "### Instrumental variables\n", 99 | "\n", 100 | "**Page 121**\n", 101 | "\n", 102 | "In the book (p. 121) one of the formulas for regression models is **incorrect**. \n", 103 | "\n", 104 | "The book says:\n", 105 | "\n", 106 | "_________\n", 107 | "\n", 108 | "\n", 109 | "*To calculate the causal effect of X on Y in a linear case, all we need to do is fit two linear regression\n", 110 | "models and compute the ratio of their coefficients!*\n", 111 | "\n", 112 | "*The two models are as follows:*\n", 113 | "\n", 114 | "*• Y ~ Z*\n", 115 | "\n", 116 | "*• Y ~ X*\n", 117 | "\n", 118 | "_____________\n", 119 | "\n", 120 | "\n", 121 | "The **correct** formulas should be:\n", 122 | "\n", 123 | "_____________\n", 124 | "\n", 125 | "*The two models are as follows:*\n", 126 | "\n", 127 | "*• Y ~ Z*\n", 128 | "\n", 129 | "*• X ~ Z*\n", 130 | "\n", 131 | "__________________\n", 132 | "\n", 133 | "\n", 134 | "The **code example** below uses the **correct formulas**." 135 | ] 136 | }, 137 | { 138 | "cell_type": "markdown", 139 | "id": "8c0953ef", 140 | "metadata": {}, 141 | "source": [ 142 | "## Chapter 07\n", 143 | "\n", 144 | "\n", 145 | "**Figure 7.6** is **incorrect**.\n", 146 | "\n", 147 | "The **correct** version of **Figure 7.6** is the following:\n", 148 | "\n", 149 | "" 150 | ] 151 | }, 152 | { 153 | "cell_type": "markdown", 154 | "id": "7575e3b4", 155 | "metadata": {}, 156 | "source": [ 157 | "## Chapter 13\n", 158 | "\n", 159 | "\n", 160 | "**Page 343** (point 3)\n", 161 | "\n", 162 | "\n", 163 | "In the book (p. 343) in point **3.** we say:\n", 164 | "\n", 165 | "_______________________________________________________________\n", 166 | "\n", 167 | "3. 
*We remove the edges between **P and Q**, and Q and S, as P ⫫ Q | R and Q ⫫ S | R (C).*\n", 168 | "\n", 169 | "\n", 170 | "This is **incorrect**. \n", 171 | "\n", 172 | "_______________________________________________________________\n", 173 | "\n", 174 | "The **correct** statement should be:\n", 175 | "\n", 176 | "3. *We remove the edges between **P and S**, and Q and S, as P ⫫ Q | R and Q ⫫ S | R (C).*\n", 177 | "\n", 178 | "_______________________________________________________________\n", 179 | "\n", 180 | " \n", 181 | "*Thanks to Miguel for spotting this mistake.*\n", 182 | " \n", 183 | "______________________________________________________\n", 184 | "

\n", 185 | "\n", 186 | "**Page 343** (point 4)\n", 187 | "\n", 188 | "In the book (p. 343) in point **4.** we say:\n", 189 | "_______________________________________________________________\n", 190 | "\n", 191 | "4. *We orient the edges **P -> R <- S** because R is a collider between **P and S** (D).*\n", 192 | "\n", 193 | "\n", 194 | "This is **incorrect**. \n", 195 | "\n", 196 | "_______________________________________________________________\n", 197 | "\n", 198 | "The **correct** statement should be:\n", 199 | "\n", 200 | "4. *We orient the edges **P -> R <- Q** because R is a collider between **P and Q** (D).*\n", 201 | "\n", 202 | "_______________________________________________________________\n", 203 | "\n", 204 | " \n", 205 | "*Thanks to Takashi for spotting this mistake.*" 206 | ] 207 | }, 208 | { 209 | "cell_type": "markdown", 210 | "id": "7d249ae5", 211 | "metadata": {}, 212 | "source": [ 213 | "## Chapter 14\n", 214 | "\n", 215 | "**Page 391** (bullet point 1)\n", 216 | "\n", 217 | "\n", 218 | "In the book (p. 391) in the first bullet point we say:\n", 219 | "\n", 220 | "_______________________________________________________________\n", 221 | "\n", 222 | "*- **X** is not an ancestor of **Y** - there might be...*\n", 223 | "\n", 224 | "\n", 225 | "This is **incorrect**. 
\n", 226 | "\n", 227 | "_______________________________________________________________\n", 228 | "\n", 229 | "The **correct** statement should be:\n", 230 | "\n", 231 | "*- **Y** is not an ancestor of **X** - there might be...*\n", 232 | "_______________________________________________________________\n", 233 | "\n" 234 | ] 235 | } 236 | ], 237 | "metadata": { 238 | "kernelspec": { 239 | "display_name": "Python 3 (ipykernel)", 240 | "language": "python", 241 | "name": "python3" 242 | }, 243 | "language_info": { 244 | "codemirror_mode": { 245 | "name": "ipython", 246 | "version": 3 247 | }, 248 | "file_extension": ".py", 249 | "mimetype": "text/x-python", 250 | "name": "python", 251 | "nbconvert_exporter": "python", 252 | "pygments_lexer": "ipython3", 253 | "version": "3.9.18" 254 | } 255 | }, 256 | "nbformat": 4, 257 | "nbformat_minor": 5 258 | } 259 | -------------------------------------------------------------------------------- /errata/Errata - Non-Early Print (ordered after June 13 2023).ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "d3201184", 6 | "metadata": {}, 7 | "source": [ 8 | "# Errata\n", 9 | "\n", 10 | "(Books ordered after ~12:00 PM GMT on June 13, 2023)" 11 | ] 12 | }, 13 | { 14 | "cell_type": "markdown", 15 | "id": "a9740489", 16 | "metadata": {}, 17 | "source": [ 18 | "## Chapter 02\n", 19 | "\n", 20 | "\n", 21 | "**Page 22**\n", 22 | "\n", 23 | "\n", 24 | "In the book (p. 22) we say:\n", 25 | "_______________________________________________________________\n", 26 | "\n", 27 | "*(...) the probability of buying book A, given we bought book B, is 63.8%. This indicates a\n", 28 | "positive relationship between both variables (if there was no association between them, we would\n", 29 | "expect the result to be **50%**)*\n", 30 | "\n", 31 | "\n", 32 | "This is **incorrect**. 
\n", 33 | "\n", 34 | "_______________________________________________________________\n", 35 | "\n", 36 | "The **correct** statement should be:\n", 37 | "\n", 38 | "*(...) the probability of buying book A, given we bought book B, is 63.8%. This indicates a\n", 39 | "positive relationship between both variables (if there was no association between them, we would\n", 40 | "expect the result to be **39%**)*\n", 41 | "_______________________________________________________________\n", 42 | "\n", 43 | " \n", 44 | "*Thanks to David for spotting this mistake.*" 45 | ] 46 | }, 47 | { 48 | "cell_type": "markdown", 49 | "id": "85952f08", 50 | "metadata": {}, 51 | "source": [ 52 | "## Chapter 06\n", 53 | "\n", 54 | "\n", 55 | "### Instrumental variables\n", 56 | "\n", 57 | "**Page 121**\n", 58 | "\n", 59 | "In the book (p. 121) one of the formulas for regression models is **incorrect**. \n", 60 | "\n", 61 | "The book says:\n", 62 | "\n", 63 | "_________\n", 64 | "\n", 65 | "\n", 66 | "*To calculate the causal effect of X on Y in a linear case, all we need to do is fit two linear regression\n", 67 | "models and compute the ratio of their coefficients!*\n", 68 | "\n", 69 | "*The two models are as follows:*\n", 70 | "\n", 71 | "*• Y ~ Z*\n", 72 | "\n", 73 | "*• Y ~ X*\n", 74 | "\n", 75 | "_____________\n", 76 | "\n", 77 | "\n", 78 | "The **correct** formulas should be:\n", 79 | "\n", 80 | "_____________\n", 81 | "\n", 82 | "*The two models are as follows:*\n", 83 | "\n", 84 | "*• Y ~ Z*\n", 85 | "\n", 86 | "*• X ~ Z*\n", 87 | "\n", 88 | "__________________\n", 89 | "\n", 90 | "\n", 91 | "The **code example** below uses the **correct formulas**." 92 | ] 93 | }, 94 | { 95 | "cell_type": "markdown", 96 | "id": "53434695", 97 | "metadata": {}, 98 | "source": [ 99 | "## Chapter 13\n", 100 | "\n", 101 | "\n", 102 | "**Page 343** (point 3)\n", 103 | "\n", 104 | "\n", 105 | "In the book (p. 
343) in point **3.** we say:\n", 106 | "\n", 107 | "_______________________________________________________________\n", 108 | "\n", 109 | "3. *We remove the edges between **P and Q**, and Q and S, as P ⫫ Q | R and Q ⫫ S | R (C).*\n", 110 | "\n", 111 | "\n", 112 | "This is **incorrect**. \n", 113 | "\n", 114 | "_______________________________________________________________\n", 115 | "\n", 116 | "The **correct** statement should be:\n", 117 | "\n", 118 | "3. *We remove the edges between **P and S**, and Q and S, as P ⫫ Q | R and Q ⫫ S | R (C).*\n", 119 | "\n", 120 | "_______________________________________________________________\n", 121 | "\n", 122 | " \n", 123 | "*Thanks to Miguel for spotting this mistake.*\n", 124 | " \n", 125 | "______________________________________________________\n", 126 | "

\n", 127 | "\n", 128 | "**Page 343** (point 4)\n", 129 | "\n", 130 | "In the book (p. 343) in point **4.** we say:\n", 131 | "_______________________________________________________________\n", 132 | "\n", 133 | "4. *We orient the edges **P -> R <- S** because R is a collider between **P and S** (D).*\n", 134 | "\n", 135 | "\n", 136 | "This is **incorrect**. \n", 137 | "\n", 138 | "_______________________________________________________________\n", 139 | "\n", 140 | "The **correct** statement should be:\n", 141 | "\n", 142 | "4. *We orient the edges **P -> R <- Q** because R is a collider between **P and Q** (D).*\n", 143 | "\n", 144 | "_______________________________________________________________\n", 145 | "\n", 146 | " \n", 147 | "*Thanks to Takashi for spotting this mistake.*" 148 | ] 149 | }, 150 | { 151 | "cell_type": "markdown", 152 | "id": "9e86603e", 153 | "metadata": {}, 154 | "source": [ 155 | "## Chapter 14\n", 156 | "\n", 157 | "**Page 391** (bullet point 1)\n", 158 | "\n", 159 | "\n", 160 | "In the book (p. 391) in the first bullet point we say:\n", 161 | "\n", 162 | "_______________________________________________________________\n", 163 | "\n", 164 | "*- **X** is not an ancestor of **Y** - there might be...*\n", 165 | "\n", 166 | "\n", 167 | "This is **incorrect**. 
\n", 168 | "\n", 169 | "_______________________________________________________________\n", 170 | "\n", 171 | "The **correct** statement should be:\n", 172 | "\n", 173 | "*- **Y** is not an ancestor of **X** - there might be...*\n", 174 | "_______________________________________________________________\n", 175 | "\n" 176 | ] 177 | } 178 | ], 179 | "metadata": { 180 | "kernelspec": { 181 | "display_name": "Python 3 (ipykernel)", 182 | "language": "python", 183 | "name": "python3" 184 | }, 185 | "language_info": { 186 | "codemirror_mode": { 187 | "name": "ipython", 188 | "version": 3 189 | }, 190 | "file_extension": ".py", 191 | "mimetype": "text/x-python", 192 | "name": "python", 193 | "nbconvert_exporter": "python", 194 | "pygments_lexer": "ipython3", 195 | "version": "3.9.18" 196 | } 197 | }, 198 | "nbformat": 4, 199 | "nbformat_minor": 5 200 | } 201 | -------------------------------------------------------------------------------- /errata/img/ch_04__fig_4_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PacktPublishing/Causal-Inference-and-Discovery-in-Python/de2a9df880e31f7217cb99b2c878feee03fc0cb1/errata/img/ch_04__fig_4_1.png -------------------------------------------------------------------------------- /errata/img/ch_06__fig_6_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PacktPublishing/Causal-Inference-and-Discovery-in-Python/de2a9df880e31f7217cb99b2c878feee03fc0cb1/errata/img/ch_06__fig_6_3.png -------------------------------------------------------------------------------- /errata/img/ch_06__fig_6_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PacktPublishing/Causal-Inference-and-Discovery-in-Python/de2a9df880e31f7217cb99b2c878feee03fc0cb1/errata/img/ch_06__fig_6_4.png 
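
Editorial note on the Chapter 06 erratum (instrumental variables, p. 121) in the errata notebooks above: the corrected pair of models is `Y ~ Z` and `X ~ Z`, and the effect of X on Y is the ratio of their slopes. The following is a minimal sketch of that ratio estimator, not the notebook's own code. The data-generating process, the coefficient values, and the helper names `ols_slope` and `iv_ratio_estimate` are all assumptions made for illustration.

```python
import random


def ols_slope(feature, target):
    """Slope of a simple one-variable least-squares regression target ~ feature."""
    n = len(feature)
    mean_f = sum(feature) / n
    mean_t = sum(target) / n
    cov = sum((f - mean_f) * (t - mean_t) for f, t in zip(feature, target))
    var = sum((f - mean_f) ** 2 for f in feature)
    return cov / var


def iv_ratio_estimate(z, x, y):
    """Ratio estimator: slope of Y ~ Z divided by slope of X ~ Z."""
    return ols_slope(z, y) / ols_slope(z, x)


# Hypothetical linear data: Z is the instrument, U an unobserved confounder.
rng = random.Random(0)
n_samples = 20_000
true_effect = 2.0  # assumed causal effect of X on Y

u = [rng.gauss(0, 1) for _ in range(n_samples)]
z = [rng.gauss(0, 1) for _ in range(n_samples)]
x = [1.5 * zi + ui + rng.gauss(0, 1) for zi, ui in zip(z, u)]
y = [true_effect * xi + 3.0 * ui + rng.gauss(0, 1) for xi, ui in zip(x, u)]

iv_estimate = iv_ratio_estimate(z, x, y)  # uses Y ~ Z and X ~ Z
naive_estimate = ols_slope(x, y)          # Y ~ X, biased by the confounder U

print(f"IV ratio estimate: {iv_estimate:.3f}")
print(f"Naive Y ~ X slope: {naive_estimate:.3f}")
```

In this simulated setup the ratio estimate should land near the assumed effect, while the plain `Y ~ X` slope is pulled upward by the confounder. This is why regressing `Y ~ X`, as the uncorrected text suggested, does not recover the causal effect.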
-------------------------------------------------------------------------------- /errata/img/ch_06__fig_6_5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PacktPublishing/Causal-Inference-and-Discovery-in-Python/de2a9df880e31f7217cb99b2c878feee03fc0cb1/errata/img/ch_06__fig_6_5.png -------------------------------------------------------------------------------- /errata/img/ch_07__fig_7_6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PacktPublishing/Causal-Inference-and-Discovery-in-Python/de2a9df880e31f7217cb99b2c878feee03fc0cb1/errata/img/ch_07__fig_7_6.png -------------------------------------------------------------------------------- /errata/minor-errors.txt: -------------------------------------------------------------------------------- 1 | ## Chapter 06 2 | 3 | - Code section "Front-door in practice" 4 | - `class GPSMemorySCM:`: 5 | - `u_y` is not used anywhere while it should be used in the `memory` computations [reported by Arash A.] 
6 | 7 | -------------------------------------------------------------------------------- /img/ch_03_graph_01: -------------------------------------------------------------------------------- 1 | digraph { 2 | A 3 | X 4 | B 5 | Y 6 | A -> X 7 | X -> B 8 | A -> Y 9 | Y -> B 10 | } 11 | -------------------------------------------------------------------------------- /img/ch_03_graph_01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PacktPublishing/Causal-Inference-and-Discovery-in-Python/de2a9df880e31f7217cb99b2c878feee03fc0cb1/img/ch_03_graph_01.png -------------------------------------------------------------------------------- /img/ch_03_graph_02: -------------------------------------------------------------------------------- 1 | digraph { 2 | graph [nodesep=1.5 rankdir=LR ranksep=.6] 3 | A 4 | X 5 | Y 6 | A -> Y 7 | A -> X 8 | X -> Y 9 | } 10 | -------------------------------------------------------------------------------- /img/ch_03_graph_02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PacktPublishing/Causal-Inference-and-Discovery-in-Python/de2a9df880e31f7217cb99b2c878feee03fc0cb1/img/ch_03_graph_02.png -------------------------------------------------------------------------------- /img/ch_04_graph_DAG: -------------------------------------------------------------------------------- 1 | digraph { 2 | A 3 | B 4 | C 5 | D 6 | A -> B 7 | B -> C 8 | A -> D 9 | D -> C 10 | } 11 | -------------------------------------------------------------------------------- /img/ch_04_graph_DAG.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PacktPublishing/Causal-Inference-and-Discovery-in-Python/de2a9df880e31f7217cb99b2c878feee03fc0cb1/img/ch_04_graph_DAG.png -------------------------------------------------------------------------------- /img/ch_04_graph_DCG: 
-------------------------------------------------------------------------------- 1 | digraph { 2 | A 3 | B 4 | C 5 | D 6 | A -> B 7 | A -> D 8 | B -> B 9 | B -> C 10 | D -> C 11 | C -> A 12 | } 13 | -------------------------------------------------------------------------------- /img/ch_04_graph_DCG.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PacktPublishing/Causal-Inference-and-Discovery-in-Python/de2a9df880e31f7217cb99b2c878feee03fc0cb1/img/ch_04_graph_DCG.png -------------------------------------------------------------------------------- /img/ch_04_graph_Fully connected: -------------------------------------------------------------------------------- 1 | graph { 2 | A 3 | B 4 | C 5 | D 6 | A -- B 7 | A -- C 8 | A -- D 9 | B -- C 10 | B -- D 11 | C -- D 12 | } 13 | -------------------------------------------------------------------------------- /img/ch_04_graph_Fully connected.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PacktPublishing/Causal-Inference-and-Discovery-in-Python/de2a9df880e31f7217cb99b2c878feee03fc0cb1/img/ch_04_graph_Fully connected.png -------------------------------------------------------------------------------- /img/ch_04_graph_Partially connected: -------------------------------------------------------------------------------- 1 | graph { 2 | A 3 | B 4 | C 5 | D 6 | A -- B 7 | A -- C 8 | B -- C 9 | } 10 | -------------------------------------------------------------------------------- /img/ch_04_graph_Partially connected.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PacktPublishing/Causal-Inference-and-Discovery-in-Python/de2a9df880e31f7217cb99b2c878feee03fc0cb1/img/ch_04_graph_Partially connected.png -------------------------------------------------------------------------------- /img/ch_04_graph_Undirected: 
--------------------------------------------------------------------------------
graph {
    A
    B
    C
    D
    A -- B
    B -- C
    A -- D
    D -- C
}
--------------------------------------------------------------------------------
/img/ch_04_graph_Undirected.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PacktPublishing/Causal-Inference-and-Discovery-in-Python/de2a9df880e31f7217cb99b2c878feee03fc0cb1/img/ch_04_graph_Undirected.png
--------------------------------------------------------------------------------
/img/ch_04_graph_adj_00:
--------------------------------------------------------------------------------
digraph {
    0
    1
    0 -> 1
}
--------------------------------------------------------------------------------
/img/ch_04_graph_adj_00.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PacktPublishing/Causal-Inference-and-Discovery-in-Python/de2a9df880e31f7217cb99b2c878feee03fc0cb1/img/ch_04_graph_adj_00.png
--------------------------------------------------------------------------------
/img/ch_04_graph_adj_01:
--------------------------------------------------------------------------------
digraph {
    0
    1
    2
    0 -> 1
    2 -> 1
    2 -> 0
}
--------------------------------------------------------------------------------
/img/ch_04_graph_adj_01.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PacktPublishing/Causal-Inference-and-Discovery-in-Python/de2a9df880e31f7217cb99b2c878feee03fc0cb1/img/ch_04_graph_adj_01.png
--------------------------------------------------------------------------------
/img/ch_04_graph_adj_02:
--------------------------------------------------------------------------------
digraph {
    0
    1
    2
    3
    0 -> 2
    1 -> 3
    3 -> 2
    3 -> 0
}
--------------------------------------------------------------------------------
/img/ch_04_graph_adj_02.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PacktPublishing/Causal-Inference-and-Discovery-in-Python/de2a9df880e31f7217cb99b2c878feee03fc0cb1/img/ch_04_graph_adj_02.png
--------------------------------------------------------------------------------
/img/ch_05_chain_00:
--------------------------------------------------------------------------------
digraph {
    A [pos="0,0!"]
    B [pos="1.5,0!"]
    C [pos="3,0!"]
    A -> B
    B -> C
}
--------------------------------------------------------------------------------
/img/ch_05_chain_00.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PacktPublishing/Causal-Inference-and-Discovery-in-Python/de2a9df880e31f7217cb99b2c878feee03fc0cb1/img/ch_05_chain_00.png
--------------------------------------------------------------------------------
/img/ch_05_collider_00:
--------------------------------------------------------------------------------
digraph {
    A [pos="0,0!"]
    B [pos="1.5,0!"]
    C [pos="3,0!"]
    A -> B
    C -> B
}
--------------------------------------------------------------------------------
/img/ch_05_collider_00.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PacktPublishing/Causal-Inference-and-Discovery-in-Python/de2a9df880e31f7217cb99b2c878feee03fc0cb1/img/ch_05_collider_00.png
--------------------------------------------------------------------------------
/img/ch_05_fork_00:
--------------------------------------------------------------------------------
digraph {
    A [pos="0,0!"]
    B [pos="1.5,0!"]
    C [pos="3,0!"]
    B -> A
    B -> C
}
--------------------------------------------------------------------------------
/img/ch_05_fork_00.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PacktPublishing/Causal-Inference-and-Discovery-in-Python/de2a9df880e31f7217cb99b2c878feee03fc0cb1/img/ch_05_fork_00.png
--------------------------------------------------------------------------------
/img/ch_05_markov_01:
--------------------------------------------------------------------------------
digraph {
    A [pos="0,2.75!"]
    B [pos="2,3!"]
    C [pos="0,1!"]
    A -> B
    C -> B
    C -> A
}
--------------------------------------------------------------------------------
/img/ch_05_markov_01.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PacktPublishing/Causal-Inference-and-Discovery-in-Python/de2a9df880e31f7217cb99b2c878feee03fc0cb1/img/ch_05_markov_01.png
--------------------------------------------------------------------------------
/img/ch_05_markov_02:
--------------------------------------------------------------------------------
digraph {
    A [pos="0,2.75!"]
    B [pos="2,3!"]
    C [pos="0,1!"]
    A -> B
    C -> B
    A -> C
}
--------------------------------------------------------------------------------
/img/ch_05_markov_02.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PacktPublishing/Causal-Inference-and-Discovery-in-Python/de2a9df880e31f7217cb99b2c878feee03fc0cb1/img/ch_05_markov_02.png
--------------------------------------------------------------------------------
/img/ch_06_confounding_00:
--------------------------------------------------------------------------------
digraph {
    "Gene X Expression" [pos="4,0!"]
    Stress [pos="2,1.75!"]
    "Immune System Response" [pos="7.5,0!"]
    "Delay of Gratification" [pos="2.5,-1.75!"]
    "Gene X Expression" -> "Immune System Response"
    Stress -> "Gene X Expression"
    Stress -> "Immune System Response"
    Stress -> "Delay of Gratification"
    "Gene X Expression" -> "Delay of Gratification"
}
--------------------------------------------------------------------------------
/img/ch_06_confounding_00.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PacktPublishing/Causal-Inference-and-Discovery-in-Python/de2a9df880e31f7217cb99b2c878feee03fc0cb1/img/ch_06_confounding_00.png
--------------------------------------------------------------------------------
/img/ch_06_d_sep_00:
--------------------------------------------------------------------------------
digraph {
    X [pos="1,0!"]
    B [pos="2.5,0!"]
    Y [pos="4,0!"]
    X -> B
    B -> Y
}
--------------------------------------------------------------------------------
/img/ch_06_d_sep_00.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PacktPublishing/Causal-Inference-and-Discovery-in-Python/de2a9df880e31f7217cb99b2c878feee03fc0cb1/img/ch_06_d_sep_00.png
--------------------------------------------------------------------------------
/img/ch_06_d_sep_01:
--------------------------------------------------------------------------------
digraph {
    X [pos="1,0!"]
    B [pos="2.5,0!"]
    Y [pos="4,0!"]
    X -> B
    Y -> B
}
--------------------------------------------------------------------------------
/img/ch_06_d_sep_01.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PacktPublishing/Causal-Inference-and-Discovery-in-Python/de2a9df880e31f7217cb99b2c878feee03fc0cb1/img/ch_06_d_sep_01.png
--------------------------------------------------------------------------------
/img/ch_06_d_sep_02:
--------------------------------------------------------------------------------
digraph {
    A [pos="1,1!"]
    X [pos="3,1!"]
    B [pos="3,0!"]
    Y [pos="1, 0!"]
    X -> B
    Y -> B
    X -> A
    A -> Y
}
--------------------------------------------------------------------------------
/img/ch_06_d_sep_02.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PacktPublishing/Causal-Inference-and-Discovery-in-Python/de2a9df880e31f7217cb99b2c878feee03fc0cb1/img/ch_06_d_sep_02.png
--------------------------------------------------------------------------------
/img/ch_06_d_sep_03:
--------------------------------------------------------------------------------
digraph {
    A [pos="1.5,0!"]
    X [pos="0,0!"]
    B [pos="3,0!"]
    Y [pos="4.5,0!"]
    X -> A
    B -> A
    B -> Y
}
--------------------------------------------------------------------------------
/img/ch_06_d_sep_03.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PacktPublishing/Causal-Inference-and-Discovery-in-Python/de2a9df880e31f7217cb99b2c878feee03fc0cb1/img/ch_06_d_sep_03.png
--------------------------------------------------------------------------------
/img/ch_06_d_sep_04:
--------------------------------------------------------------------------------
digraph {
    S -> Q
    S -> Y
    Q -> X
    Q -> Y
    X -> Z
    Z -> Y
    X -> P
    Y -> P
}
--------------------------------------------------------------------------------
/img/ch_06_d_sep_04.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PacktPublishing/Causal-Inference-and-Discovery-in-Python/de2a9df880e31f7217cb99b2c878feee03fc0cb1/img/ch_06_d_sep_04.png
--------------------------------------------------------------------------------
/img/ch_06_equivalent_estimands_00:
--------------------------------------------------------------------------------
digraph {
    X [pos="0,0!"]
    Y [pos="3,0!"]
    A [pos=".5,2!"]
    B [pos="1.75,1!"]
    X -> Y
    A -> X
    A -> B
    B -> Y
}
--------------------------------------------------------------------------------
/img/ch_06_equivalent_estimands_00.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PacktPublishing/Causal-Inference-and-Discovery-in-Python/de2a9df880e31f7217cb99b2c878feee03fc0cb1/img/ch_06_equivalent_estimands_00.png
--------------------------------------------------------------------------------
/img/ch_06_equivalent_estimands_01:
--------------------------------------------------------------------------------
digraph {
    X [pos="0,0!"]
    Y [pos="3,0!"]
    A [pos=".5,2!"]
    B [pos="1.75,1!"]
    A [style=dashed]
    X -> Y
    B -> Y
    A -> X [style=dashed]
    A -> B [style=dashed]
}
--------------------------------------------------------------------------------
/img/ch_06_equivalent_estimands_01.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PacktPublishing/Causal-Inference-and-Discovery-in-Python/de2a9df880e31f7217cb99b2c878feee03fc0cb1/img/ch_06_equivalent_estimands_01.png
--------------------------------------------------------------------------------
/img/ch_06_gps_00:
--------------------------------------------------------------------------------
digraph {
    GPS [pos="0,0!"]
    Memory [pos="3,0!"]
    Motivation [pos="1.5,1.5!"]
    Motivation [style=dashed]
    Motivation -> GPS [style=dashed]
    Motivation -> Memory [style=dashed]
}
--------------------------------------------------------------------------------
/img/ch_06_gps_00.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PacktPublishing/Causal-Inference-and-Discovery-in-Python/de2a9df880e31f7217cb99b2c878feee03fc0cb1/img/ch_06_gps_00.png
--------------------------------------------------------------------------------
/img/ch_06_gps_01:
--------------------------------------------------------------------------------
digraph {
    GPS [pos="0,0!"]
    Memory [pos="3,0!"]
    Motivation [pos="1.5,1.5!"]
    Motivation [style=dashed]
    Motivation -> GPS [style=dashed]
    Motivation -> Memory [style=dashed]
    GPS -> Memory
}
--------------------------------------------------------------------------------
/img/ch_06_gps_01.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PacktPublishing/Causal-Inference-and-Discovery-in-Python/de2a9df880e31f7217cb99b2c878feee03fc0cb1/img/ch_06_gps_01.png
--------------------------------------------------------------------------------
/img/ch_06_gps_02:
--------------------------------------------------------------------------------
digraph {
    GPS [pos="0,0!"]
    "Hippocampal vol." [pos="2.5,0!"]
    Memory [pos="5,0!"]
    Motivation [pos="2.5,1.5!"]
    Motivation [style=dashed]
    Motivation -> GPS [style=dashed]
    Motivation -> Memory [style=dashed]
    GPS -> "Hippocampal vol."
    "Hippocampal vol." -> Memory
}
--------------------------------------------------------------------------------
/img/ch_06_gps_02.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PacktPublishing/Causal-Inference-and-Discovery-in-Python/de2a9df880e31f7217cb99b2c878feee03fc0cb1/img/ch_06_gps_02.png
--------------------------------------------------------------------------------
/img/ch_06_gps_03:
--------------------------------------------------------------------------------
digraph {
    X [pos="0,0!"]
    Z [pos="2.5,0!"]
    Y [pos="5,0!"]
    U [pos="2.5,1.5!"]
    U [style=dashed]
    U -> X [style=dashed]
    U -> Y [style=dashed]
    X -> Z
    Z -> Y
}
--------------------------------------------------------------------------------
/img/ch_06_gps_03.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PacktPublishing/Causal-Inference-and-Discovery-in-Python/de2a9df880e31f7217cb99b2c878feee03fc0cb1/img/ch_06_gps_03.png
--------------------------------------------------------------------------------
/img/ch_06_icecream:
--------------------------------------------------------------------------------
digraph {
    ICE [pos="0,0!"]
    TMP [pos="1.5,.75!"]
    ACC [pos="3,0!"]
    TMP -> ICE
    TMP -> ACC
}
--------------------------------------------------------------------------------
/img/ch_06_icecream.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PacktPublishing/Causal-Inference-and-Discovery-in-Python/de2a9df880e31f7217cb99b2c878feee03fc0cb1/img/ch_06_icecream.png
--------------------------------------------------------------------------------
/img/ch_06_instrumental_00:
--------------------------------------------------------------------------------
digraph {
    Z [pos="0,0!"]
    X [pos="1.5,0!"]
    Y [pos="5,0!"]
    U [pos="3.25,1.5!"]
    U [style=dashed]
    U -> X [style=dashed]
    U -> Y [style=dashed]
    Z -> X
    X -> Y
}
--------------------------------------------------------------------------------
/img/ch_06_instrumental_00.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PacktPublishing/Causal-Inference-and-Discovery-in-Python/de2a9df880e31f7217cb99b2c878feee03fc0cb1/img/ch_06_instrumental_00.png
--------------------------------------------------------------------------------
/img/ch_07_full_example:
--------------------------------------------------------------------------------
digraph {
    S [pos="2,2.5!"]
    Q [pos="3,1!"]
    X [pos="3,0!"]
    Y [pos="1, 0!"]
    P [pos="1,2!"]
    S -> Q
    S -> Y
    Q -> X
    Q -> Y
    X -> P
    Y -> P
    X -> Y
}
--------------------------------------------------------------------------------
/img/ch_07_full_example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PacktPublishing/Causal-Inference-and-Discovery-in-Python/de2a9df880e31f7217cb99b2c878feee03fc0cb1/img/ch_07_full_example.png
--------------------------------------------------------------------------------
/img/ch_08_modularity:
--------------------------------------------------------------------------------
digraph {
    X [pos="1,2!"]
    Y [pos="4,.5!"]
    R [pos="2.5,.5!"]
    Z [pos="1, 0!"]
    X -> R
    Z -> R
    R -> Y
    X -> Y
}
--------------------------------------------------------------------------------
/img/ch_08_modularity.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PacktPublishing/Causal-Inference-and-Discovery-in-Python/de2a9df880e31f7217cb99b2c878feee03fc0cb1/img/ch_08_modularity.png
--------------------------------------------------------------------------------
/img/ch_08_modularity_mod:
--------------------------------------------------------------------------------
digraph {
    X [pos="1,2!"]
    Y [pos="4,.5!"]
    R [pos="2.5,.5!"]
    Z [pos="1, 0!"]
    R -> Y
    X -> Y
}
--------------------------------------------------------------------------------
/img/ch_08_modularity_mod.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PacktPublishing/Causal-Inference-and-Discovery-in-Python/de2a9df880e31f7217cb99b2c878feee03fc0cb1/img/ch_08_modularity_mod.png
--------------------------------------------------------------------------------
/img/ch_08_selection:
--------------------------------------------------------------------------------
digraph {
    T [pos="0,.5!"]
    Y [pos="1.5,0!"]
    C [pos="3,.5!"]
    T -> Y
    T -> C
    Y -> C
}
--------------------------------------------------------------------------------
/img/ch_08_selection.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PacktPublishing/Causal-Inference-and-Discovery-in-Python/de2a9df880e31f7217cb99b2c878feee03fc0cb1/img/ch_08_selection.png
--------------------------------------------------------------------------------
/img/ch_08_selection_02:
--------------------------------------------------------------------------------
digraph {
    T [pos="1.,1.5!"]
    Y [pos="0,0!"]
    C [pos="3,0!"]
    Z [pos="3,2!"]
    W [pos="0,2!"]
    Z -> C
    T -> C
    W -> Z
    W -> Y
}
--------------------------------------------------------------------------------
/img/ch_08_selection_02.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PacktPublishing/Causal-Inference-and-Discovery-in-Python/de2a9df880e31f7217cb99b2c878feee03fc0cb1/img/ch_08_selection_02.png
--------------------------------------------------------------------------------
/img/ch_08_selection_03:
--------------------------------------------------------------------------------
digraph {
    T [pos="1.,1.5!"]
    Y [pos="0,0!"]
    C [pos="3,0!"]
    Z [pos="3,2!"]
    W [pos="0,2!"]
    T -> Z
    Z -> C
    W -> Z
    W -> Y
}
--------------------------------------------------------------------------------
/img/ch_08_selection_03.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PacktPublishing/Causal-Inference-and-Discovery-in-Python/de2a9df880e31f7217cb99b2c878feee03fc0cb1/img/ch_08_selection_03.png
--------------------------------------------------------------------------------
/models/causal_bert_pytorch/CausalBert.py:
--------------------------------------------------------------------------------
"""
This code comes originally from https://github.com/rpryzant/causal-bert-pytorch
At the time of writing the original code contained an error
(https://github.com/rpryzant/causal-bert-pytorch/issues/6)
that made using one of the methods (.ATE()) unadvisable.
This version of code fixes this error.

An extensible implementation of the Causal Bert model from
"Adapting Text Embeddings for Causal Inference"
    (https://arxiv.org/abs/1905.12741)
"""
from collections import defaultdict
import os
import pickle

import scipy
from sklearn.model_selection import KFold

from torch.utils.data import Dataset, TensorDataset, DataLoader, RandomSampler, SequentialSampler

from transformers import BertTokenizer
from transformers import BertModel, BertPreTrainedModel, AdamW, BertConfig
from transformers import get_linear_schedule_with_warmup

from transformers import DistilBertTokenizer
from transformers import DistilBertModel, DistilBertPreTrainedModel

from torch.nn import CrossEntropyLoss

import torch
import torch.nn as nn
from scipy.special import softmax
import numpy as np
from scipy.special import logit
from sklearn.linear_model import LogisticRegression

from tqdm import tqdm
import math

CUDA = (torch.cuda.device_count() > 0)
MASK_IDX = 103


def platt_scale(outcome, probs):
    logits = logit(probs)
    logits = logits.reshape(-1, 1)
    log_reg = LogisticRegression(penalty='none', warm_start=True, solver='lbfgs')
    log_reg.fit(logits, outcome)
    return log_reg.predict_proba(logits)


def gelu(x):
    return 0.5 * x * (1.0 + torch.erf(x / math.sqrt(2.0)))


def make_bow_vector(ids, vocab_size, use_counts=False):
    """ Make a sparse BOW vector from a tensor of dense ids.
    Args:
        ids: torch.LongTensor [batch, features]. Dense tensor of ids.
        vocab_size: vocab size for this tensor.
        use_counts: if true, the outgoing BOW vector will contain
            feature counts. If false, will contain binary indicators.
    Returns:
        The sparse bag-of-words representation of ids.
    """
    vec = torch.zeros(ids.shape[0], vocab_size)
    ones = torch.ones_like(ids, dtype=torch.float)
    if CUDA:
        vec = vec.cuda()
        ones = ones.cuda()
        ids = ids.cuda()

    vec.scatter_add_(1, ids, ones)
    vec[:, 1] = 0.0  # zero out pad
    if not use_counts:
        vec = (vec != 0).float()
    return vec



class CausalBert(DistilBertPreTrainedModel):
    """The model itself."""
    def __init__(self, config):
        super().__init__(config)

        self.num_labels = config.num_labels
        self.vocab_size = config.vocab_size

        self.distilbert = DistilBertModel(config)
        # self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.vocab_transform = nn.Linear(config.dim, config.dim)
        self.vocab_layer_norm = nn.LayerNorm(config.dim, eps=1e-12)
        self.vocab_projector = nn.Linear(config.dim, config.vocab_size)

        self.Q_cls = nn.ModuleDict()

        for T in range(2):
            # ModuleDict keys have to be strings..
            self.Q_cls['%d' % T] = nn.Sequential(
                nn.Linear(config.hidden_size + self.num_labels, 200),
                nn.ReLU(),
                nn.Linear(200, self.num_labels))

        self.g_cls = nn.Linear(config.hidden_size + self.num_labels,
            self.config.num_labels)

        self.init_weights()

    def forward(self, W_ids, W_len, W_mask, C, T, Y=None, use_mlm=True):
        if use_mlm:
            W_len = W_len.unsqueeze(1) - 2 # -2 because of the +1 below
            mask_class = torch.cuda.FloatTensor if CUDA else torch.FloatTensor
            mask = (mask_class(W_len.shape).uniform_() * W_len.float()).long() + 1 # + 1 to avoid CLS
            target_words = torch.gather(W_ids, 1, mask)
            mlm_labels = torch.ones(W_ids.shape).long() * -100
            if CUDA:
                mlm_labels = mlm_labels.cuda()
            mlm_labels.scatter_(1, mask, target_words)
            W_ids.scatter_(1, mask, MASK_IDX)

        outputs = self.distilbert(W_ids, attention_mask=W_mask)
        seq_output = outputs[0]
        pooled_output = seq_output[:, 0]
        # seq_output, pooled_output = outputs[:2]
        # pooled_output = self.dropout(pooled_output)

        if use_mlm:
            prediction_logits = self.vocab_transform(seq_output)  # (bs, seq_length, dim)
            prediction_logits = gelu(prediction_logits)  # (bs, seq_length, dim)
            prediction_logits = self.vocab_layer_norm(prediction_logits)  # (bs, seq_length, dim)
            prediction_logits = self.vocab_projector(prediction_logits)  # (bs, seq_length, vocab_size)
            mlm_loss = CrossEntropyLoss()(
                prediction_logits.view(-1, self.vocab_size), mlm_labels.view(-1))
        else:
            mlm_loss = 0.0

        C_bow = make_bow_vector(C.unsqueeze(1), self.num_labels)
        inputs = torch.cat((pooled_output, C_bow), 1)

        # g logits
        g = self.g_cls(inputs)
        if Y is not None:  # TODO train/test mode, this is a lil hacky
            g_loss = CrossEntropyLoss()(g.view(-1, self.num_labels), T.view(-1))
        else:
            g_loss = 0.0

        # conditional expected outcome logits:
        # run each example through its corresponding T matrix
        # TODO this would be cleaner with sigmoid and BCELoss, but less general
        #   (and I couldn't get it to work as well)
        Q_logits_T0 = self.Q_cls['0'](inputs)
        Q_logits_T1 = self.Q_cls['1'](inputs)

        if Y is not None:
            T0_indices = (T == 0).nonzero().squeeze()
            Y_T1_labels = Y.clone().scatter(0, T0_indices, -100)

            T1_indices = (T == 1).nonzero().squeeze()
            Y_T0_labels = Y.clone().scatter(0, T1_indices, -100)

            Q_loss_T1 = CrossEntropyLoss()(
                Q_logits_T1.view(-1, self.num_labels), Y_T1_labels)
            Q_loss_T0 = CrossEntropyLoss()(
                Q_logits_T0.view(-1, self.num_labels), Y_T0_labels)

            Q_loss = Q_loss_T0 + Q_loss_T1
        else:
            Q_loss = 0.0

        sm = nn.Softmax(dim=1)
        Q0 = sm(Q_logits_T0)[:, 1]
        Q1 = sm(Q_logits_T1)[:, 1]
        g = sm(g)[:, 1]

        return g, Q0, Q1, g_loss, Q_loss, mlm_loss



class CausalBertWrapper:
    """Model wrapper in charge of training and inference."""

    def __init__(self, g_weight=1.0, Q_weight=0.1, mlm_weight=1.0,
        batch_size=32):
        self.model = CausalBert.from_pretrained(
            "distilbert-base-uncased",
            num_labels=2,
            output_attentions=False,
            output_hidden_states=False)
        if CUDA:
            self.model = self.model.cuda()

        self.loss_weights = {
            'g': g_weight,
            'Q': Q_weight,
            'mlm': mlm_weight
        }
        self.batch_size = batch_size


    def train(self, texts, confounds, treatments, outcomes,
        learning_rate=2e-5, epochs=3):
        dataloader = self.build_dataloader(
            texts, confounds, treatments, outcomes)

        self.model.train()
        optimizer = AdamW(self.model.parameters(), lr=learning_rate, eps=1e-8)
        total_steps = len(dataloader) * epochs
        warmup_steps = total_steps * 0.1
        scheduler = get_linear_schedule_with_warmup(
            optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps)

        for epoch in range(epochs):
            losses = []
            self.model.train()
            for step, batch in tqdm(enumerate(dataloader), total=len(dataloader)):
                if CUDA:
                    batch = (x.cuda() for x in batch)
                W_ids, W_len, W_mask, C, T, Y = batch
                # while True:
                self.model.zero_grad()
                g, Q0, Q1, g_loss, Q_loss, mlm_loss = self.model(W_ids, W_len, W_mask, C, T, Y)
                loss = self.loss_weights['g'] * g_loss + \
                    self.loss_weights['Q'] * Q_loss + \
                    self.loss_weights['mlm'] * mlm_loss
                loss.backward()
                optimizer.step()
                scheduler.step()
                losses.append(loss.detach().cpu().item())
                # print(np.mean(losses))
                # if step > 5: continue
        return self.model


    def inference(self, texts, confounds, outcome=None):
        self.model.eval()
        dataloader = self.build_dataloader(texts, confounds, outcomes=outcome,
            sampler='sequential')
        Q0s = []
        Q1s = []
        Ys = []
        for i, batch in tqdm(enumerate(dataloader), total=len(dataloader)):
            if CUDA:
                batch = (x.cuda() for x in batch)
            W_ids, W_len, W_mask, C, T, Y = batch
            g, Q0, Q1, _, _, _ = self.model(W_ids, W_len, W_mask, C, T, use_mlm=False)
            Q0s += Q0.detach().cpu().numpy().tolist()
            Q1s += Q1.detach().cpu().numpy().tolist()
            Ys += Y.detach().cpu().numpy().tolist()
            # if i > 5: break
        probs = np.array(list(zip(Q0s, Q1s)))
        preds = np.argmax(probs, axis=1)

        return probs, preds, Ys

    def ATE(self, C, W, Y=None, platt_scaling=False):
        Q_probs, _, Ys = self.inference(W, C, outcome=Y)
        if platt_scaling and Y is not None:
            Q0 = platt_scale(Ys, Q_probs[:, 0])[:, 0]
            Q1 = platt_scale(Ys, Q_probs[:, 1])[:, 1]
        else:
            Q0 = Q_probs[:, 0]
            Q1 = Q_probs[:, 1]

        return np.mean(Q1 - Q0)

    def build_dataloader(self, texts, confounds, treatments=None, outcomes=None,
        tokenizer=None, sampler='random'):
        def collate_CandT(data):
            # sort by (C, T), so you can get boundaries later
            # (do this here on cpu for speed)
            data.sort(key=lambda x: (x[1], x[2]))
            return data
        # fill with dummy values
        if treatments is None:
            treatments = [-1 for _ in range(len(confounds))]
        if outcomes is None:
            outcomes = [-1 for _ in range(len(treatments))]

        if tokenizer is None:
            tokenizer = DistilBertTokenizer.from_pretrained(
                'distilbert-base-uncased', do_lower_case=True)

        out = defaultdict(list)
        for i, (W, C, T, Y) in enumerate(zip(texts, confounds, treatments, outcomes)):
            # out['W_raw'].append(W)
            encoded_sent = tokenizer.encode_plus(W, add_special_tokens=True,
                max_length=128,
                truncation=True,
                pad_to_max_length=True)

            out['W_ids'].append(encoded_sent['input_ids'])
            out['W_mask'].append(encoded_sent['attention_mask'])
            out['W_len'].append(sum(encoded_sent['attention_mask']))
            out['Y'].append(Y)
            out['T'].append(T)
            out['C'].append(C)
            # if i > 100: break

        data = (torch.tensor(out[x]) for x in ['W_ids', 'W_len', 'W_mask', 'C', 'T', 'Y'])
        data = TensorDataset(*data)
        sampler = RandomSampler(data) if sampler == 'random' else SequentialSampler(data)
        dataloader = DataLoader(data, sampler=sampler, batch_size=self.batch_size)
            # collate_fn=collate_CandT)

        return dataloader


if __name__ == '__main__':
    import pandas as pd

    df = pd.read_csv('testdata.csv')
    cb = CausalBertWrapper(batch_size=2,
        g_weight=0.1, Q_weight=0.1, mlm_weight=1)
    print(df.T)
    cb.train(df['text'], df['C'], df['T'], df['Y'], epochs=1)
    print(cb.ATE(df['C'], df.text, platt_scaling=True))
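
Note on the estimator: when Platt scaling is off, the `ATE()` method above reduces to `np.mean(Q1 - Q0)`, the average over examples of the difference between the predicted outcome probability under treatment and under control. The sketch below restates that last step in pure Python so it can be checked without loading a model. It assumes `q0` and `q1` are per-example values of P(Y=1 | T=0, C, text) and P(Y=1 | T=1, C, text), such as the two columns of `probs` returned by `inference()`; the helper name `plug_in_ate` is hypothetical and not part of this module.

```python
from statistics import fmean


def plug_in_ate(q0, q1):
    """Average of per-example differences q1[i] - q0[i].

    q0: predicted P(Y=1 | T=0, ...) for each example (assumed input).
    q1: predicted P(Y=1 | T=1, ...) for each example (assumed input).
    """
    if len(q0) != len(q1):
        raise ValueError("q0 and q1 must have the same length")
    if len(q0) == 0:
        raise ValueError("at least one example is required")
    # Per-example effect estimates, then their mean.
    diffs = [treated - control for control, treated in zip(q0, q1)]
    return fmean(diffs)


if __name__ == "__main__":
    # Toy numbers for illustration only; not taken from any dataset.
    print(plug_in_ate([0.2, 0.4], [0.5, 0.6]))
```

The estimate is only as good as the outcome heads that produce `q0` and `q1`; this helper does no calibration or confounding adjustment of its own.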
--------------------------------------------------------------------------------
/models/causal_bert_pytorch/README.md:
--------------------------------------------------------------------------------
# This version

This code comes originally from https://github.com/rpryzant/causal-bert-pytorch
At the time of writing the original code contained an error
(https://github.com/rpryzant/causal-bert-pytorch/issues/6)
that made using one of the methods (.ATE()) unadvisable.
This version of the code fixes this error.

I also fixed minor errors in the original README.md content.

Below starts the original README.md file:

# Causal Bert -- in Pytorch!
Pytorch implementation of ["Adapting Text Embeddings for Causal Inference" by Victor Veitch, Dhanya Sridhar, and David M. Blei](https://arxiv.org/pdf/1905.12741.pdf).

# Quickstart

```
pip install -r requirements.txt
python CausalBert.py
```

This will train a system on some test data and calculate an average treatment effect (ATE).

# Description

As input this system expects data where each row consists of:
* Freeform **text**
* A categorical variable (numerically coded) representing a **confound**
* A binary **treatment variable**
* A binary **outcome variable**

Then the system will give the text to BERT, and use the BERT embeddings + confound to predict
1) _P(T | C, text)_
2) _P(Y | T = 1, C, text)_
3) _P(Y | T = 0, C, text)_
4) The original masked language modeling objective of BERT.

Once trained, the resulting BERT embeddings will be sufficient for some causal inferences.

# Example

```
df = pd.read_csv('testdata.csv')
cb = CausalBertWrapper(batch_size=2,                         # init a model wrapper
    g_weight=0.1, Q_weight=0.1, mlm_weight=1)
cb.train(df['text'], df['C'], df['T'], df['Y'], epochs=1)    # train the model
print(cb.ATE(df['C'], df['text'], platt_scaling=True))       # use the model to get an average treatment effect
```


# Usage

**Initialize** the model wrapper (handles training and inference):

```
cb = CausalBertWrapper(
    batch_size=2,    # batch size for training
    g_weight=1.0,    # loss weight for P(T | C, text) prediction head
    Q_weight=0.1,    # loss weight for P(Y | T, C, text) prediction heads
    mlm_weight=1)    # loss weight for original MLM objective
```

Then **train**
```
cb.train(
    df['text'],      # list of texts
    df['C'],         # list of confounds
    df['T'],         # list of treatments
    df['Y'],         # list of outcomes
    epochs=1)        # training epochs
```

Perform **inference**

```
# probs[:, 0] = P(Y=1 | T=0), probs[:, 1] = P(Y=1 | T=1), one row per example
probs, preds, Ys = cb.inference(
    df['text'],      # list of texts
    df['C'])         # list of confounds
```

Or estimate an **average treatment effect**

```
ATE = cb.ATE(
    df['C'],                # list of confounds
    df['text'],             # list of texts
    platt_scaling=False)    # https://en.wikipedia.org/wiki/Platt_scaling
```
--------------------------------------------------------------------------------
/models/causal_bert_pytorch/__pycache__/CausalBert.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PacktPublishing/Causal-Inference-and-Discovery-in-Python/de2a9df880e31f7217cb99b2c878feee03fc0cb1/models/causal_bert_pytorch/__pycache__/CausalBert.cpython-38.pyc
--------------------------------------------------------------------------------
/models/causal_bert_pytorch/__pycache__/CausalBert.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PacktPublishing/Causal-Inference-and-Discovery-in-Python/de2a9df880e31f7217cb99b2c878feee03fc0cb1/models/causal_bert_pytorch/__pycache__/CausalBert.cpython-39.pyc
--------------------------------------------------------------------------------
/models/causal_bert_pytorch/requirements.txt:
--------------------------------------------------------------------------------
sklearn==0.0
scipy==1.4.1
torch==1.6.0
keras==2.6.0
transformers==4.12.2
pandas==0.25.3
--------------------------------------------------------------------------------