├── .gitignore ├── LICENSE ├── README.md ├── csv └── italy-raw-data.csv ├── files ├── example.pgmx └── xmlbelief.xml ├── images ├── 1 │ ├── Iris_BN.png │ └── student.png ├── 2 │ ├── student_full_param.png │ ├── three_nodes.png │ └── two_nodes.png ├── connection.png ├── connections_directed_graphs.png ├── housing_price.png ├── housing_price_small.png ├── housing_price_with_CPD.png ├── simple_pendulum.gif └── student.png ├── notebooks ├── 1. Introduction to Probabilistic Graphical Models.ipynb ├── 10. Learning Bayesian Networks from Data.ipynb ├── 11. A Bayesian Network to model the influence of energy consumption on greenhouse gases in Italy.ipynb ├── 2. Bayesian Networks.ipynb ├── 3. Causal Bayesian Networks.ipynb ├── 4. Markov Models.ipynb ├── 5. Exact Inference in Graphical Models.ipynb ├── 6. Approximate Inference in Graphical Models.ipynb ├── 7. Parameterizing with Continuous Variables.ipynb ├── 8. Sampling Algorithms.ipynb └── 9. Reading and Writing from pgmpy file formats.ipynb ├── pdfs └── Probabilistic Graphical Models using pgmpy.pdf ├── requirements.txt └── scripts └── 1 └── discretize.py /.gitignore: -------------------------------------------------------------------------------- 1 | .ipynb_checkpoints 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2016 Ankur Ankan 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice 
shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | 23 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Archived repository. For updated examples, please refer: https://github.com/pgmpy/pgmpy/tree/dev/examples 2 | -------------------------------------------------------------------------------- /csv/italy-raw-data.csv: -------------------------------------------------------------------------------- 1 | "Country Name","Country Code","Indicator Name","Indicator Code","1960","1961","1962","1963","1964","1965","1966","1967","1968","1969","1970","1971","1972","1973","1974","1975","1976","1977","1978","1979","1980","1981","1982","1983","1984","1985","1986","1987","1988","1989","1990","1991","1992","1993","1994","1995","1996","1997","1998","1999","2000","2001","2002","2003","2004","2005","2006","2007","2008","2009","2010","2011","2012","2013","2014","2015","2016","2017","2018","2019" 2 | "Italy","ITA","Population growth (annual 
%)","SP.POP.GROW","1.99392840634543","0.668382874410405","0.676622989670017","0.729553242231653","0.822623699867765","0.842108616866269","0.77730440730804","0.723778172767654","0.631737252752781","0.56605850865814","0.528877031717835","0.466452872314018","0.567712498645759","0.678187647585975","0.654388935648374","0.597247218554801","0.498851058340731","0.424722057350148","0.356312950366128","0.289147207649738","0.20599975185947","0.12005432007943","0.0740818599916188","0.0362946366697159","0.0223509131041464","0.0288999392122477","0.00544928189677628","0.0102051195547741","0.0483205970471755","0.0750090045352437","0.0837085729587984","0.0692311889532843","0.0679244330808808","0.061135853002924","0.0203720739457194","0.00158856230917393","0.0281044080072756","0.05290695026061","0.0287740158774372","0.0168208444553697","0.045303631138613","0.0561676014374784","0.148916429490753","0.444507312655549","0.647182705470772","0.491389107503214","0.30055968851778","0.504933687397673","0.662469252941527","0.455613449577732","0.30759122234095","0.171978290995717","0.269541239520996","1.1592511168095","0.917504095962437","-0.0963761331392095","-0.169884073301448","-0.14986111697397","-0.190063640113791","" 3 | "Italy","ITA","Urban population growth (annual 
%)","SP.URB.GROW","2.83640101522505","1.498806985842","1.50683335249242","1.55128737033283","1.63602671711505","1.64248495869882","1.56811258601466","1.50360997508773","1.40395453417646","1.32608737490657","1.27850466807299","1.20741706004214","1.01305086677094","0.988289736577197","0.962002539492351","0.903918571621992","0.804585418380234","0.728011211530533","0.657172197390672","0.589107555455562","0.503561294703949","0.41523676692889","0.147367695402826","0.00638799801151978","-0.00756463983110026","-0.00102329301802305","-0.0244854742714839","-0.0197350167961586","0.018367630826705","0.0450504195104213","0.0537393805274748","0.0392535418060194","0.121878595507805","0.150993200701863","0.110149317664983","0.0912844846717047","0.117720946939261","0.142443745434026","0.118229703401032","0.106198406299854","0.134599938879629","0.145383965430874","0.297433898475445","0.620956423105583","0.823321280014444","0.665741727415553","0.47460792573121","0.678681719512836","0.835914949481381","0.627292740575372","0.480438950517546","0.343066345015609","0.619579150261878","1.58783518674393","1.34137129947699","0.325701454305887","0.24612722181536","0.262998539787345","0.228197851468404","" 4 | "Italy","ITA","GDP per capita growth (annual 
%)","NY.GDP.PCAP.KD.ZG","","7.48641883914016","5.48747776983932","4.84205248927523","1.95553270968082","2.40204614684968","5.16416328074524","6.40567760378134","5.87359459594813","5.49917950556602","4.71342067483722","1.34428044727264","3.10348970019818","6.40180593473045","4.81201385791708","-2.67318392284196","6.59231973824173","2.12582746405019","2.8729363910716","5.65322401027726","3.21717011535372","0.723232418069614","0.339224954168117","1.13249117163346","3.2027830351187","2.76838143154765","2.85436691162933","3.18143029401365","4.14404208471082","3.31086199529021","1.90043978320719","1.46817561124544","0.765807606496722","-0.913401708049989","2.13021548225964","2.88520232953732","1.23832837048366","1.77635122351289","1.78132437367471","1.60863473141113","3.73994653301939","1.89412399481279","0.104759513559856","-0.305508789463318","0.769317663351927","0.323656244406962","1.48515735957702","0.975922132963376","-1.61594062742927","-5.71150838276657","1.40091534299962","0.534287439584816","-3.24206011995115","-2.97240379999698","-0.917813879753353","0.875477401786128","1.45187539107614","1.86871464141127","0.966057526934264","" 5 | "Italy","ITA","Energy use (kg of oil equivalent per 
capita)","EG.USE.PCAP.KG.OE","794.816044717399","890.688346111265","1007.04834663111","1119.72317177866","1221.73080588714","1316.46285381488","1431.0744873284","1551.51917278665","1698.42756042697","1832.49991828227","2026.21747487312","1949.14872333929","2050.73787711576","2175.58754564221","2212.20669215371","2106.07874846993","2264.05718699758","2204.17980666785","2247.56551683966","2340.04203541587","2318.42730013811","2265.93108965354","2206.54273410646","2203.86418771745","2253.71892374528","2284.81129783538","2311.54807247948","2407.25635314456","2462.79677752615","2566.10892465158","2583.88807748482","2645.67323732766","2627.34000777188","2611.99043754026","2578.46597846012","2799.37576154289","2796.15473585155","2834.4029460732","2912.94736525428","2957.31148591361","3012.19499987601","3020.62119454278","3037.2669647055","3168.52902114021","3169.93264162306","3214.67412061146","3175.79987086883","3149.57655346296","3087.5663310273","2869.92071205584","2930.58852412547","2828.40489139601","2709.29772810307","2579.47254262663","2414.48400158318","2481.75464546017","","","","" 6 | "Italy","ITA","Fossil fuel energy consumption (% of 
total)","EG.USE.COMM.FO.ZS","85.6104208954658","87.617144513331","89.3109802387936","89.1611811462079","90.1198093870077","90.0526868903059","90.5498640135678","91.6289825520547","92.4356126799736","93.3052262281325","93.3180237787483","93.4644277422204","93.7367118008874","94.5455530891835","93.683737865099","93.1620177584507","93.9231518849086","93.0277292510405","93.4931025734777","93.844069813235","93.7303886554851","93.3510872412231","92.58894558046","92.5221878531444","91.6157211130719","91.8147752149175","91.4578323656889","93.383433772931","92.9048437377345","93.29463061558","93.4381888314656","92.8953720752908","92.5812418928353","92.2702900571247","92.0879374489954","93.0246520668659","92.6146554090799","92.4434641764684","92.2858909543463","91.8413718480898","91.7300953931907","91.4034247883699","91.0358673382802","90.2161116960919","90.406242362097","89.8014461800991","89.1908787219636","88.1830431207349","86.8421531044015","84.8513256669208","84.6287740206212","84.4647972084689","82.2013557429068","79.9643345866816","78.5864226219386","79.9484547354947","","","","" 7 | "Italy","ITA","Renewable energy consumption (% of total final energy consumption)","EG.FEC.RNEW.ZS","","","","","","","","","","","","","","","","","","","","","","","","","","","","","","","3.78146348144952","4.66941317046864","4.83946324160706","4.77149529502877","5.2480221051995","4.46227833744118","4.74891389547652","4.83141976728715","4.75762543217441","5.02245978426758","5.11951479788946","5.38256684916779","5.59666992531442","6.24058670916462","5.93594854215673","6.70263907861876","7.51256404147516","8.73172974748555","10.8233492990728","12.5397527301242","12.7940929982975","11.9049402680297","14.3916489650323","16.3208850422585","17.0903588438944","16.5168505774397","","","","" 8 | "Italy","ITA","Energy imports, net (% of energy 
use)","EG.IMP.CONS.ZS","65.3167016240568","68.6299659882852","72.5788986649106","74.4074707397282","74.8770028375011","76.6254237199936","77.7889599089265","79.2965129695427","80.5622987209328","81.1469375436769","81.7830568497905","81.4697038559081","81.9590870613341","82.8876260334664","82.5988613481289","82.0314727132063","82.7663495977271","83.0537976783657","83.4055147196634","84.3999752629075","84.7939688402871","83.730668876626","81.9072901855161","82.7731712175379","82.2424550236134","82.3905160971675","81.0349454334449","82.2457029543014","81.7355944614857","82.6596403559659","82.7279681603029","82.6341212470324","81.7361864990966","80.9387506237783","79.8704844175778","81.5123078720173","80.9609613285211","81.1459930992474","81.7286017337053","82.639022489153","83.5770384679968","84.4230696083365","83.7003423076947","83.5382570544466","84.064837070475","83.7920833042666","83.7151253356548","83.0854301565539","81.8823860631965","81.3307662429071","81.0026718172754","80.9821353995273","78.3338476613038","76.338774185009","75.0013037014096","76.4179821065515","","","","" 9 | "Italy","ITA","CO2 emissions (kt)","EN.ATM.CO2E.KT","109357.274","124549.655","146456.313","164780.312","175957.328","189767.25","214248.142","234427.643","249495.346","270008.544","296740.974","311588.657","328970.237","354320.208","359406.337","342310.783","367312.389","356175.71","373234.594","387466.221","388969.691","378020.029","369681.271","361364.515","367649.753","372134.494","366454.311","383670.876","389780.098","408907.17","417550.289","423894.199","420535.227","411917.777","407378.031","430483.798","425830.375","430487.465","440153.677","441998.178","450564.29","450347.937","452610.476","468349.24","473970.751","473384.031","469346.664","462676.391","447186.983","401591.505","405361.181","397994.178","369468.585","345317.723","320411.459","","","","","" 10 | "Italy","ITA","Methane emissions in energy sector (thousand metric tons of CO2 
equivalent)","EN.ATM.METH.EG.KT.CE","","","","","","","","","","","4611.97656711245","4555.8883317771","4772.68274496568","5093.31144892858","5551.81603429452","5838.5639377942","6284.56787907265","6412.82730428888","6527.34290358489","6744.29186704468","6808.36005127144","6893.97146439988","6988.74051687669","6982.10169998515","7227.8908349347","7436.56892677431","7722.67634745305","7953.22909730878","8149.11572062463","8421.71069826367","8599.63149859163","8731.39924415843","8780.55917897662","8660.102724886","8633.6154322617","8422.07226873451","8562.42324880918","8360.32951056901","8061.47951516973","7807.0795129992","7587.2090967216","7170.05326174286","7085.60527846628","6739.42211792086","6548.34506224891","6154.04622231858","6177.78125932456","6261.2671605469","6244.73502392927","","","","","","","","","","","" 11 | "Italy","ITA","Nitrous oxide emissions in energy sector (thousand metric tons of CO2 equivalent)","EN.ATM.NOXE.EG.KT.CE","","","","","","","","","","","1480.02984916","1534.15141089","1587.834578587","1644.833729701","1669.159068923","1628.2190822","1671.259018316","1694.598128839","1741.21993562","1865.8930597","1885.72524367","1849.73549576","1858.16878071","1841.63314666","1890.62245131","1944.03723722","1955.98341675","1968.74459465","2137.70375918","2304.89575321","2344.6291783","2315.6671555","2432.181787","2557.7053378","2626.4618188","2754.4412299","2750.9560642","2763.7745611","2827.2205902","2811.76301525","2767.996696","2816.5296295","2868.3449422","2942.274852","3313.2912778","3287.2587922","3387.9273893","3242.0693141","3196.8895421","","","","","","","","","","","" -------------------------------------------------------------------------------- /files/example.pgmx: -------------------------------------------------------------------------------- 1 | 2 | 3 | Student example model from Probabilistic Graphical Models: Principles and Techniques by Daphne Koller 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 0.01 13 | 14 | 15 | 16 | 17 | 18 | 19 
| 20 | 21 | 22 | 23 | 24 | 0.01 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 0.01 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 0.01 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 0.01 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 0.01 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 0.01 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 0.01 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | 142 | 143 | 0.95 0.05 0.02 0.98 144 | 145 | 146 | 147 | 148 | 149 | 150 | 0.7 0.3 0.4 0.6 151 | 152 | 153 | 154 | 155 | 156 | 157 | 158 | 0.9 0.1 0.3 0.7 0.2 0.8 0.1 0.9 159 | 160 | 161 | 162 | 163 | 164 | 0.99 0.01 165 | 166 | 167 | 168 | 169 | 170 | 0.5 0.5 171 | 172 | 173 | 174 | 175 | 176 | 177 | 0.99 0.01 0.9 0.1 178 | 179 | 180 | 181 | 182 | 183 | 184 | 0.99 0.01 0.95 0.05 185 | 186 | 187 | 188 | 189 | 190 | 191 | 192 | 1.0 0.0 0.0 1.0 0.0 1.0 0.0 1.0 193 | 194 | 195 | 196 | 197 | 198 | 199 | 200 | 201 | 202 | 203 | 204 | -------------------------------------------------------------------------------- /files/xmlbelief.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | (a) Metastatic Cancer 11 | Present 12 | Absent 13 | 14 | 15 | (b) Serum Calcium Increase 16 | Present 17 | Absent 18 | 19 | 20 | (c) Brain Tumor 21 | Present 22 | Absent 23 | 24 | 25 | (d) Coma 26 | Present 27 | Absent 28 | 29 | 30 | (e) Papilledema 31 | Present 32 | Absent 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 0.2 0.8 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 0.8 0.2 56 | 0.2 0.8 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 0.2 0.8 66 | 0.05 0.95 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 0.8 0.2 77 | 
0.9 0.1 78 | 0.7 0.3 79 | 0.05 0.95 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 0.8 0.2 89 | 0.6 0.4 90 | 91 | 92 | 93 | 94 | 95 | -------------------------------------------------------------------------------- /images/1/Iris_BN.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pgmpy/pgmpy_notebook/6bcd9fc2d84d7181055ddfd8c9459ebecd0b67b0/images/1/Iris_BN.png -------------------------------------------------------------------------------- /images/1/student.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pgmpy/pgmpy_notebook/6bcd9fc2d84d7181055ddfd8c9459ebecd0b67b0/images/1/student.png -------------------------------------------------------------------------------- /images/2/student_full_param.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pgmpy/pgmpy_notebook/6bcd9fc2d84d7181055ddfd8c9459ebecd0b67b0/images/2/student_full_param.png -------------------------------------------------------------------------------- /images/2/three_nodes.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pgmpy/pgmpy_notebook/6bcd9fc2d84d7181055ddfd8c9459ebecd0b67b0/images/2/three_nodes.png -------------------------------------------------------------------------------- /images/2/two_nodes.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pgmpy/pgmpy_notebook/6bcd9fc2d84d7181055ddfd8c9459ebecd0b67b0/images/2/two_nodes.png -------------------------------------------------------------------------------- /images/connection.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pgmpy/pgmpy_notebook/6bcd9fc2d84d7181055ddfd8c9459ebecd0b67b0/images/connection.png 
-------------------------------------------------------------------------------- /images/connections_directed_graphs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pgmpy/pgmpy_notebook/6bcd9fc2d84d7181055ddfd8c9459ebecd0b67b0/images/connections_directed_graphs.png -------------------------------------------------------------------------------- /images/housing_price.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pgmpy/pgmpy_notebook/6bcd9fc2d84d7181055ddfd8c9459ebecd0b67b0/images/housing_price.png -------------------------------------------------------------------------------- /images/housing_price_small.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pgmpy/pgmpy_notebook/6bcd9fc2d84d7181055ddfd8c9459ebecd0b67b0/images/housing_price_small.png -------------------------------------------------------------------------------- /images/housing_price_with_CPD.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pgmpy/pgmpy_notebook/6bcd9fc2d84d7181055ddfd8c9459ebecd0b67b0/images/housing_price_with_CPD.png -------------------------------------------------------------------------------- /images/simple_pendulum.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pgmpy/pgmpy_notebook/6bcd9fc2d84d7181055ddfd8c9459ebecd0b67b0/images/simple_pendulum.gif -------------------------------------------------------------------------------- /images/student.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pgmpy/pgmpy_notebook/6bcd9fc2d84d7181055ddfd8c9459ebecd0b67b0/images/student.png -------------------------------------------------------------------------------- 
/notebooks/10. Learning Bayesian Networks from Data.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Learning Bayesian Networks from Data\n", 8 | "\n", 9 | "\n", 10 | "Previous notebooks showed how Bayesian networks economically encode a probability distribution over a set of variables, and how they can be used e.g. to predict variable states, or to generate new samples from the joint distribution. This section will be about obtaining a Bayesian network, given a set of sample data. Learning a Bayesian network can be split into two problems:\n", 11 | "\n", 12 | " **Parameter learning:** Given a set of data samples and a DAG that captures the dependencies between the variables, estimate the (conditional) probability distributions of the individual variables.\n", 13 | " \n", 14 | " **Structure learning:** Given a set of data samples, estimate a DAG that captures the dependencies between the variables.\n", 15 | " \n", 16 | "This notebook aims to illustrate how parameter learning and structure learning can be done with pgmpy.\n", 17 | "Currently, the library supports:\n", 18 | " - Parameter learning for *discrete* nodes:\n", 19 | " - Maximum Likelihood Estimation\n", 20 | " - Bayesian Estimation\n", 21 | " - Structure learning for *discrete*, *fully observed* networks:\n", 22 | " - Score-based structure estimation (BIC/BDeu/K2 score; exhaustive search, hill climb/tabu search)\n", 23 | " - Constraint-based structure estimation (PC)\n", 24 | " - Hybrid structure estimation (MMHC)\n", 25 | "\n", 26 | "\n", 27 | "## Parameter Learning\n", 28 | "\n", 29 | "Suppose we have the following data:" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": 1, 35 | "metadata": {}, 36 | "outputs": [ 37 | { 38 | "name": "stdout", 39 | "output_type": "stream", 40 | "text": [ 41 | " fruit tasty size\n", 42 | "0 banana yes large\n", 43 | 
"1 apple no large\n", 44 | "2 banana yes large\n", 45 | "3 apple yes small\n", 46 | "4 banana yes large\n", 47 | "5 apple yes large\n", 48 | "6 banana yes large\n", 49 | "7 apple yes small\n", 50 | "8 apple yes large\n", 51 | "9 apple yes large\n", 52 | "10 banana yes large\n", 53 | "11 banana no large\n", 54 | "12 apple no small\n", 55 | "13 banana no small\n" 56 | ] 57 | } 58 | ], 59 | "source": [ 60 | "import pandas as pd\n", 61 | "data = pd.DataFrame(data={'fruit': [\"banana\", \"apple\", \"banana\", \"apple\", \"banana\",\"apple\", \"banana\", \n", 62 | " \"apple\", \"apple\", \"apple\", \"banana\", \"banana\", \"apple\", \"banana\",], \n", 63 | " 'tasty': [\"yes\", \"no\", \"yes\", \"yes\", \"yes\", \"yes\", \"yes\", \n", 64 | " \"yes\", \"yes\", \"yes\", \"yes\", \"no\", \"no\", \"no\"], \n", 65 | " 'size': [\"large\", \"large\", \"large\", \"small\", \"large\", \"large\", \"large\",\n", 66 | " \"small\", \"large\", \"large\", \"large\", \"large\", \"small\", \"small\"]})\n", 67 | "print(data)" 68 | ] 69 | }, 70 | { 71 | "cell_type": "markdown", 72 | "metadata": {}, 73 | "source": [ 74 | "We know that the variables relate as follows:" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 2, 80 | "metadata": {}, 81 | "outputs": [ 82 | { 83 | "name": "stderr", 84 | "output_type": "stream", 85 | "text": [ 86 | "/home/ankur/pgmpy_notebook/notebooks/pgmpy/models/BayesianModel.py:8: FutureWarning: BayesianModel has been renamed to BayesianNetwork. 
Please use BayesianNetwork class, BayesianModel will be removed in future.\n", 87 | " warnings.warn(\n" 88 | ] 89 | } 90 | ], 91 | "source": [ 92 | "from pgmpy.models import BayesianModel\n", 93 | "\n", 94 | "model = BayesianModel([('fruit', 'tasty'), ('size', 'tasty')]) # fruit -> tasty <- size" 95 | ] 96 | }, 97 | { 98 | "cell_type": "markdown", 99 | "metadata": {}, 100 | "source": [ 101 | "Parameter learning is the task of estimating the values of the conditional probability distributions (CPDs) for the variables `fruit`, `size`, and `tasty`. \n", 102 | "\n", 103 | "#### State counts\n", 104 | "To make sense of the given data, we can start by counting how often each state of the variable occurs. If the variable is dependent on parents, the counts are done conditionally on the parents' states, i.e. separately for each parent configuration:" 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": 3, 110 | "metadata": {}, 111 | "outputs": [ 112 | { 113 | "name": "stdout", 114 | "output_type": "stream", 115 | "text": [ 116 | "\n", 117 | " fruit\n", 118 | "apple 7\n", 119 | "banana 7\n", 120 | "\n", 121 | " fruit apple banana \n", 122 | "size large small large small\n", 123 | "tasty \n", 124 | "no 1.0 1.0 1.0 1.0\n", 125 | "yes 3.0 2.0 5.0 0.0\n" 126 | ] 127 | } 128 | ], 129 | "source": [ 130 | "from pgmpy.estimators import ParameterEstimator\n", 131 | "pe = ParameterEstimator(model, data)\n", 132 | "print(\"\\n\", pe.state_counts('fruit')) # unconditional\n", 133 | "print(\"\\n\", pe.state_counts('tasty')) # conditional on fruit and size" 134 | ] 135 | }, 136 | { 137 | "cell_type": "markdown", 138 | "metadata": {}, 139 | "source": [ 140 | "We can see, for example, that as many apples as bananas were observed and that `5` large bananas were tasty, while only `1` was not.\n", 141 | "\n", 142 | "#### Maximum Likelihood Estimation\n", 143 | "\n", 144 | "A natural estimate for the CPDs is to simply use the *relative frequencies* with which the
variable states have occurred. We observed `7 apples` among a total of `14 fruits`, so we might guess that about `50%` of `fruits` are `apples`.\n", 145 | "\n", 146 | "This approach is *Maximum Likelihood Estimation (MLE)*. According to MLE, we should fill the CPDs in such a way that $P(\\text{data}|\\text{model})$ is maximal. This is achieved when using the *relative frequencies*. See [1], section 17.1 for an introduction to ML parameter estimation. pgmpy supports MLE as follows:" 147 | ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "execution_count": 4, 152 | "metadata": {}, 153 | "outputs": [ 154 | { 155 | "name": "stdout", 156 | "output_type": "stream", 157 | "text": [ 158 | "+---------------+-----+\n", 159 | "| fruit(apple) | 0.5 |\n", 160 | "+---------------+-----+\n", 161 | "| fruit(banana) | 0.5 |\n", 162 | "+---------------+-----+\n", 163 | "+------------+--------------+-----+---------------+\n", 164 | "| fruit | fruit(apple) | ... | fruit(banana) |\n", 165 | "+------------+--------------+-----+---------------+\n", 166 | "| size | size(large) | ... | size(small) |\n", 167 | "+------------+--------------+-----+---------------+\n", 168 | "| tasty(no) | 0.25 | ... | 1.0 |\n", 169 | "+------------+--------------+-----+---------------+\n", 170 | "| tasty(yes) | 0.75 | ... | 0.0 |\n", 171 | "+------------+--------------+-----+---------------+\n" 172 | ] 173 | } 174 | ], 175 | "source": [ 176 | "from pgmpy.estimators import MaximumLikelihoodEstimator\n", 177 | "mle = MaximumLikelihoodEstimator(model, data)\n", 178 | "print(mle.estimate_cpd('fruit')) # unconditional\n", 179 | "print(mle.estimate_cpd('tasty')) # conditional" 180 | ] 181 | }, 182 | { 183 | "cell_type": "markdown", 184 | "metadata": {}, 185 | "source": [ 186 | "`mle.estimate_cpd(variable)` computes the state counts and divides each cell by the (conditional) sample size.
The `mle.get_parameters()`-method returns a list of CPDs for all variables of the model.\n", 187 | "\n", 188 | "The built-in `fit()`-method of `BayesianModel` provides more convenient access to parameter estimators:\n" 189 | ] 190 | }, 191 | { 192 | "cell_type": "code", 193 | "execution_count": 5, 194 | "metadata": {}, 195 | "outputs": [], 196 | "source": [ 197 | "# Calibrate all CPDs of `model` using MLE:\n", 198 | "model.fit(data, estimator=MaximumLikelihoodEstimator)" 199 | ] 200 | }, 201 | { 202 | "cell_type": "markdown", 203 | "metadata": {}, 204 | "source": [ 205 | "\n", 206 | "While very straightforward, the ML estimator has the problem of *overfitting* to the data. In the above CPD, the probability of a large banana being tasty is estimated at `0.833`, because `5` out of `6` observed large bananas were tasty. Fine. But note that the probability of a small banana being tasty is estimated at `0.0`, because we observed only one small banana and it happened not to be tasty. But that should hardly make us certain that small bananas aren't tasty!\n", 207 | "We simply do not have enough observations to rely on the observed frequencies. If the observed data is not representative of the underlying distribution, ML estimations will be extremely far off. \n", 208 | "\n", 209 | "When estimating parameters for Bayesian networks, lack of data is a frequent problem. Even if the total sample size is very large, the fact that state counts are done conditionally for each parent configuration causes immense fragmentation. If a variable has 3 parents that can each take 10 states, then state counts will be done separately for `10^3 = 1000` parent configurations. This makes MLE very fragile and unstable for learning Bayesian Network parameters.
A way to mitigate MLE's overfitting is *Bayesian Parameter Estimation*.\n", 210 | "\n", 211 | "#### Bayesian Parameter Estimation\n", 212 | "\n", 213 | "The Bayesian Parameter Estimator starts with already existing prior CPDs that express our beliefs about the variables *before* the data was observed. Those \"priors\" are then updated using the state counts from the observed data. See [1], Section 17.3 for a general introduction to Bayesian estimators.\n", 214 | "\n", 215 | "One can think of the priors as consisting of *pseudo state counts* that are added to the actual counts before normalization.\n", 216 | "Unless one wants to encode specific beliefs about the distributions of the variables, one commonly chooses uniform priors, i.e. ones that deem all states equiprobable.\n", 217 | "\n", 218 | "A very simple prior is the so-called *K2* prior, which simply adds `1` to the count of every single state.\n", 219 | "A somewhat more sensible choice of prior is *BDeu* (Bayesian Dirichlet equivalent uniform prior). For BDeu we need to specify an *equivalent sample size* `N`; the pseudo-counts are then the equivalent of having observed `N` uniform samples in total, spread evenly over all combinations of the variable's states and its parent configurations (i.e. every state count receives `N / (r * q)` pseudo-counts, where `r` is the number of states of the variable and `q` the number of parent configurations). In pgmpy:\n", 220 | "\n", 221 | "\n" 222 | ] 223 | }, 224 | { 225 | "cell_type": "code", 226 | "execution_count": 6, 227 | "metadata": { 228 | "scrolled": true 229 | }, 230 | "outputs": [ 231 | { 232 | "name": "stdout", 233 | "output_type": "stream", 234 | "text": [ 235 | "+------------+---------------------+-----+---------------------+\n", 236 | "| fruit | fruit(apple) | ... | fruit(banana) |\n", 237 | "+------------+---------------------+-----+---------------------+\n", 238 | "| size | size(large) | ... | size(small) |\n", 239 | "+------------+---------------------+-----+---------------------+\n", 240 | "| tasty(no) | 0.34615384615384615 | ...
| 0.6428571428571429 |\n", 241 | "+------------+---------------------+-----+---------------------+\n", 242 | "| tasty(yes) | 0.6538461538461539 | ... | 0.35714285714285715 |\n", 243 | "+------------+---------------------+-----+---------------------+\n" 244 | ] 245 | } 246 | ], 247 | "source": [ 248 | "from pgmpy.estimators import BayesianEstimator\n", 249 | "est = BayesianEstimator(model, data)\n", 250 | "\n", 251 | "print(est.estimate_cpd('tasty', prior_type='BDeu', equivalent_sample_size=10))" 252 | ] 253 | }, 254 | { 255 | "cell_type": "markdown", 256 | "metadata": {}, 257 | "source": [ 258 | "The estimated values in the CPDs are now more conservative. In particular, the estimate for a small banana being not tasty is now around `0.64` rather than `1.0`. Setting `equivalent_sample_size` to `10` means that a total of 10 uniform pseudo-samples are spread evenly over all parent configurations and states of `tasty`: with 4 parent configurations and 2 states, every state count receives `10 / (4 * 2) = 1.25` pseudo-counts (here: `+1.25` small bananas that are tasty and `+1.25` that aren't, giving `(1 + 1.25) / (1 + 0 + 2.5) = 0.643`).\n", 259 | "\n", 260 | "`BayesianEstimator`, too, can be used via the `fit()`-method. Full example:" 261 | ] 262 | }, 263 | { 264 | "cell_type": "code", 265 | "execution_count": 7, 266 | "metadata": {}, 267 | "outputs": [ 268 | { 269 | "name": "stderr", 270 | "output_type": "stream", 271 | "text": [ 272 | "/home/ankur/pgmpy_notebook/notebooks/pgmpy/models/BayesianModel.py:8: FutureWarning: BayesianModel has been renamed to BayesianNetwork. 
Please use BayesianNetwork class, BayesianModel will be removed in future.\n", 273 | " warnings.warn(\n" 274 | ] 275 | }, 276 | { 277 | "name": "stdout", 278 | "output_type": "stream", 279 | "text": [ 280 | "+------+----------+\n", 281 | "| A(0) | 0.511788 |\n", 282 | "+------+----------+\n", 283 | "| A(1) | 0.488212 |\n", 284 | "+------+----------+\n", 285 | "+------+---------------------+---------------------+\n", 286 | "| A | A(0) | A(1) |\n", 287 | "+------+---------------------+---------------------+\n", 288 | "| B(0) | 0.49199687682998244 | 0.5002046245140168 |\n", 289 | "+------+---------------------+---------------------+\n", 290 | "| B(1) | 0.5080031231700176 | 0.49979537548598324 |\n", 291 | "+------+---------------------+---------------------+\n", 292 | "+------+--------------------+-----+---------------------+\n", 293 | "| A | A(0) | ... | A(1) |\n", 294 | "+------+--------------------+-----+---------------------+\n", 295 | "| D | D(0) | ... | D(1) |\n", 296 | "+------+--------------------+-----+---------------------+\n", 297 | "| C(0) | 0.4882005899705015 | ... | 0.5085907138474126 |\n", 298 | "+------+--------------------+-----+---------------------+\n", 299 | "| C(1) | 0.5117994100294986 | ... 
| 0.49140928615258744 |\n", 300 | "+------+--------------------+-----+---------------------+\n", 301 | "+------+--------------------+---------------------+\n", 302 | "| B | B(0) | B(1) |\n", 303 | "+------+--------------------+---------------------+\n", 304 | "| D(0) | 0.5120845921450151 | 0.48414271555996036 |\n", 305 | "+------+--------------------+---------------------+\n", 306 | "| D(1) | 0.4879154078549849 | 0.5158572844400396 |\n", 307 | "+------+--------------------+---------------------+\n" 308 | ] 309 | } 310 | ], 311 | "source": [ 312 | "import numpy as np\n", 313 | "import pandas as pd\n", 314 | "from pgmpy.models import BayesianModel\n", 315 | "from pgmpy.estimators import BayesianEstimator\n", 316 | "\n", 317 | "# generate data\n", 318 | "data = pd.DataFrame(np.random.randint(low=0, high=2, size=(5000, 4)), columns=['A', 'B', 'C', 'D'])\n", 319 | "model = BayesianModel([('A', 'B'), ('A', 'C'), ('D', 'C'), ('B', 'D')])\n", 320 | "\n", 321 | "model.fit(data, estimator=BayesianEstimator, prior_type=\"BDeu\") # default equivalent_sample_size=5\n", 322 | "for cpd in model.get_cpds():\n", 323 | " print(cpd)\n" 324 | ] 325 | }, 326 | { 327 | "cell_type": "markdown", 328 | "metadata": {}, 329 | "source": [ 330 | "## Structure Learning\n", 331 | "\n", 332 | "To learn model structure (a DAG) from a data set, there are two broad techniques:\n", 333 | "\n", 334 | " - score-based structure learning\n", 335 | " - constraint-based structure learning\n", 336 | "\n", 337 | "The combination of both techniques allows further improvement:\n", 338 | " - hybrid structure learning\n", 339 | "\n", 340 | "We briefly discuss all approaches and give examples.\n", 341 | "\n", 342 | "### Score-based Structure Learning\n", 343 | "\n", 344 | "\n", 345 | "This approach construes model selection as an optimization task. 
It has two building blocks:\n", 346 | "\n", 347 | "- A _scoring function_ $s_D\\colon M \\to \\mathbb R$ that maps models to a numerical score, based on how well they fit to a given data set $D$.\n", 348 | "- A _search strategy_ to traverse the search space of possible models $M$ and select a model with optimal score.\n", 349 | "\n", 350 | "\n", 351 | "#### Scoring functions\n", 352 | "\n", 353 | "Commonly used scores to measure the fit between model and data are _Bayesian Dirichlet scores_ such as *BDeu* or *K2* and the _Bayesian Information Criterion_ (BIC, also called MDL). See [1], Section 18.3 for a detailed introduction on scores. As before, BDeu is dependent on an equivalent sample size." 354 | ] 355 | }, 356 | { 357 | "cell_type": "code", 358 | "execution_count": 8, 359 | "metadata": {}, 360 | "outputs": [ 361 | { 362 | "name": "stdout", 363 | "output_type": "stream", 364 | "text": [ 365 | "-13938.353002020234\n", 366 | "-14329.194269073454\n", 367 | "-14294.390420213556\n", 368 | "-20906.432489257266\n", 369 | "-20933.26023936978\n", 370 | "-20950.47339067585\n" 371 | ] 372 | }, 373 | { 374 | "name": "stderr", 375 | "output_type": "stream", 376 | "text": [ 377 | "/home/ankur/pgmpy_notebook/notebooks/pgmpy/models/BayesianModel.py:8: FutureWarning: BayesianModel has been renamed to BayesianNetwork. 
Please use BayesianNetwork class, BayesianModel will be removed in future.\n", 378 | " warnings.warn(\n" 379 | ] 380 | } 381 | ], 382 | "source": [ 383 | "import pandas as pd\n", 384 | "import numpy as np\n", 385 | "from pgmpy.estimators import BDeuScore, K2Score, BicScore\n", 386 | "from pgmpy.models import BayesianModel\n", 387 | "\n", 388 | "# create random data sample with 3 variables, where Z is dependent on X, Y:\n", 389 | "data = pd.DataFrame(np.random.randint(0, 4, size=(5000, 2)), columns=list('XY'))\n", 390 | "data['Z'] = data['X'] + data['Y']\n", 391 | "\n", 392 | "bdeu = BDeuScore(data, equivalent_sample_size=5)\n", 393 | "k2 = K2Score(data)\n", 394 | "bic = BicScore(data)\n", 395 | "\n", 396 | "model1 = BayesianModel([('X', 'Z'), ('Y', 'Z')]) # X -> Z <- Y\n", 397 | "model2 = BayesianModel([('X', 'Z'), ('X', 'Y')]) # Y <- X -> Z\n", 398 | "\n", 399 | "\n", 400 | "print(bdeu.score(model1))\n", 401 | "print(k2.score(model1))\n", 402 | "print(bic.score(model1))\n", 403 | "\n", 404 | "print(bdeu.score(model2))\n", 405 | "print(k2.score(model2))\n", 406 | "print(bic.score(model2))\n" 407 | ] 408 | }, 409 | { 410 | "cell_type": "markdown", 411 | "metadata": {}, 412 | "source": [ 413 | "While the scores vary slightly, we can see that the correct `model1` has a much higher score than `model2`.\n", 414 | "Importantly, these scores _decompose_, i.e. 
they can be computed locally for each of the variables given their potential parents, independent of other parts of the network:" 415 | ] 416 | }, 417 | { 418 | "cell_type": "code", 419 | "execution_count": 9, 420 | "metadata": {}, 421 | "outputs": [ 422 | { 423 | "name": "stdout", 424 | "output_type": "stream", 425 | "text": [ 426 | "-9282.88160824462\n", 427 | "-6993.603560250576\n", 428 | "-57.1217389219957\n" 429 | ] 430 | } 431 | ], 432 | "source": [ 433 | "print(bdeu.local_score('Z', parents=[]))\n", 434 | "print(bdeu.local_score('Z', parents=['X']))\n", 435 | "print(bdeu.local_score('Z', parents=['X', 'Y']))" 436 | ] 437 | }, 438 | { 439 | "cell_type": "markdown", 440 | "metadata": {}, 441 | "source": [ 442 | "#### Search strategies\n", 443 | "The search space of DAGs is super-exponential in the number of variables and the above scoring functions allow for local maxima. The first property makes exhaustive search intractable for all but very small networks; the second prevents efficient local optimization algorithms from always finding the optimal structure. Thus, identifying the ideal structure is often not tractable.
Despite this bad news, heuristic search strategies often yield good results.\n", 444 | "\n", 445 | "If only a few nodes are involved (read: fewer than 5), `ExhaustiveSearch` can be used to compute the score for every DAG and return the best-scoring one:" 446 | ] 447 | }, 448 | { 449 | "cell_type": "code", 450 | "execution_count": 10, 451 | "metadata": {}, 452 | "outputs": [ 453 | { 454 | "name": "stdout", 455 | "output_type": "stream", 456 | "text": [ 457 | "[('X', 'Z'), ('Y', 'Z')]\n", 458 | "\n", 459 | "All DAGs by score:\n", 460 | "-14294.390420213556 [('X', 'Z'), ('Y', 'Z')]\n", 461 | "-14330.086974085189 [('X', 'Z'), ('Y', 'Z'), ('Y', 'X')]\n", 462 | "-14330.086974085189 [('X', 'Y'), ('X', 'Z'), ('Z', 'Y')]\n", 463 | "-14330.08697408519 [('Y', 'X'), ('Z', 'X'), ('Z', 'Y')]\n", 464 | "-14330.08697408519 [('Y', 'Z'), ('Y', 'X'), ('Z', 'X')]\n", 465 | "-14330.08697408519 [('X', 'Y'), ('Z', 'X'), ('Z', 'Y')]\n", 466 | "-14330.08697408519 [('X', 'Y'), ('X', 'Z'), ('Y', 'Z')]\n", 467 | "-16586.926723773093 [('Y', 'X'), ('Z', 'X')]\n", 468 | "-16587.66791728165 [('X', 'Y'), ('Z', 'Y')]\n", 469 | "-18657.937087116316 [('Z', 'X'), ('Z', 'Y')]\n", 470 | "-18657.937087116316 [('Y', 'Z'), ('Z', 'X')]\n", 471 | "-18657.937087116316 [('X', 'Z'), ('Z', 'Y')]\n", 472 | "-20914.776836804216 [('Z', 'X')]\n", 473 | "-20914.776836804216 [('X', 'Z')]\n", 474 | "-20915.518030312778 [('Z', 'Y')]\n", 475 | "-20915.518030312778 [('Y', 'Z')]\n", 476 | "-20950.47339067585 [('X', 'Z'), ('Y', 'X')]\n", 477 | "-20950.47339067585 [('X', 'Y'), ('Z', 'X')]\n", 478 | "-20950.47339067585 [('X', 'Y'), ('X', 'Z')]\n", 479 | "-20951.21458418441 [('Y', 'X'), ('Z', 'Y')]\n", 480 | "-20951.21458418441 [('Y', 'Z'), ('Y', 'X')]\n", 481 | "-20951.21458418441 [('X', 'Y'), ('Y', 'Z')]\n", 482 | "-23172.357780000675 []\n", 483 | "-23208.05433387231 [('Y', 'X')]\n", 484 | "-23208.05433387231 [('X', 'Y')]\n" 485 | ] 486 | } 487 | ], 488 | "source": [ 489 | "from pgmpy.estimators import ExhaustiveSearch\n", 
490 | "\n", 491 | "es = ExhaustiveSearch(data, scoring_method=bic)\n", 492 | "best_model = es.estimate()\n", 493 | "print(best_model.edges())\n", 494 | "\n", 495 | "print(\"\\nAll DAGs by score:\")\n", 496 | "for score, dag in reversed(es.all_scores()):\n", 497 | " print(score, dag.edges())" 498 | ] 499 | }, 500 | { 501 | "cell_type": "markdown", 502 | "metadata": {}, 503 | "source": [ 504 | "Once more nodes are involved, one needs to switch to heuristic search. `HillClimbSearch` implements a greedy local search that starts from the DAG `start` (default: disconnected DAG) and proceeds by iteratively performing single-edge manipulations that maximally increase the score. The search terminates once a local maximum is found.\n", 505 | "\n", 506 | "\n" 507 | ] 508 | }, 509 | { 510 | "cell_type": "code", 511 | "execution_count": 11, 512 | "metadata": {}, 513 | "outputs": [ 514 | { 515 | "data": { 516 | "application/vnd.jupyter.widget-view+json": { 517 | "model_id": "e1f781a2436e4066ab8e015257068eb8", 518 | "version_major": 2, 519 | "version_minor": 0 520 | }, 521 | "text/plain": [ 522 | " 0%| | 0/1000000 [00:00\"Open" 11 | ] 12 | }, 13 | { 14 | "cell_type": "markdown", 15 | "metadata": { 16 | "colab_type": "text", 17 | "id": "qR-vG2CF4jaz" 18 | }, 19 | "source": [ 20 | "# Causal Bayesian Networks\n", 21 | "\n", 22 | "Causal Inference is a new feature for pgmpy, so I wanted to develop a few examples which show off the features that we're developing! \n", 23 | "\n", 24 | "This particular notebook walks through the 5 games that are used as examples for building intuition about backdoor paths in *The Book of Why* by Judea Pearl.
I have consistently been using them to test different implementations of backdoor adjustment from different libraries and include them as unit tests in pgmpy, so I wanted to walk through them and a few other related games as a potential resource both to understand the implementation of CausalInference in pgmpy and to develop some useful intuitions about backdoor paths. \n", 25 | "\n", 26 | "## Objective of the Games\n", 27 | "\n", 28 | "For each game we get a causal graph, and our goal is to identify the set of deconfounders (often denoted $Z$) which will close all backdoor paths from nodes $X$ to $Y$. For the time being, I'll assume that you're familiar with the concept of backdoor paths, though I may expand this portion to explain it. " 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 21, 34 | "metadata": { 35 | "cellView": "both", 36 | "colab": { 37 | "base_uri": "https://localhost:8080/", 38 | "height": 258 39 | }, 40 | "colab_type": "code", 41 | "id": "p1uBjhCQgaaG", 42 | "outputId": "25270025-2cd2-4ce6-a8e6-7d16e2855492" 43 | }, 44 | "outputs": [ 45 | { 46 | "name": "stderr", 47 | "output_type": "stream", 48 | "text": [ 49 | "/usr/local/lib/python3.6/dist-packages/statsmodels/compat/pandas.py:56: FutureWarning: The pandas.core.datetools module is deprecated and will be removed in a future version.
Please use the pandas.tseries module instead.\n", 50 | " from pandas.core import datetools\n" 51 | ] 52 | }, 53 | { 54 | "name": "stdout", 55 | "output_type": "stream", 56 | "text": [ 57 | "Your branch is up to date with 'origin/feature/causalmodel'.\n", 58 | "Updating c459420..95af23b\n", 59 | "Fast-forward\n", 60 | " pgmpy/inference/causal_inference.py | 14 +++++++-------\n", 61 | " pgmpy/tests/test_inference/test_causal_inference.py | 2 +-\n", 62 | " 2 files changed, 8 insertions(+), 8 deletions(-)\n" 63 | ] 64 | }, 65 | { 66 | "name": "stderr", 67 | "output_type": "stream", 68 | "text": [ 69 | "fatal: destination path 'pgmpy' already exists and is not an empty directory.\n", 70 | "mv: cannot move '/content/pgmpy' to '/content/pgmpydev/pgmpy': Directory not empty\n", 71 | "Already on 'feature/causalmodel'\n", 72 | "From https://github.com/mrklees/pgmpy\n", 73 | " c459420..95af23b feature/causalmodel -> origin/feature/causalmodel\n" 74 | ] 75 | } 76 | ], 77 | "source": [ 78 | "#@title Clone the Development Repo & Install Requirements\n", 79 | "#@markdown Because the Causal Inference class is currently in dev, we will actually need to pull the code from GitHub. 
This cell will give us a pretty good development environment for interactively developing and testing the CausalModel class and its methods.\n", 80 | "#@markdown You only need to run this the first time you've started the kernel.\n", 81 | "%%sh\n", 82 | "git clone https://github.com/mrklees/pgmpy.git\n", 83 | "mv /content/pgmpy /content/pgmpydev\n", 84 | "cd pgmpydev/\n", 85 | "git checkout feature/causalmodel\n", 86 | "git pull\n", 87 | "#@markdown In testing the CausalModel and Bayesian Network portion of pgmpy we've actually been able to use up to date version of Networkx and other packages, but we may be forced to downgrade to networkx 1.11 if errors arise.\n", 88 | "#pip install -U -r requirements-dev.txt" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": 22, 94 | "metadata": { 95 | "cellView": "both", 96 | "colab": { 97 | "base_uri": "https://localhost:8080/", 98 | "height": 51 99 | }, 100 | "colab_type": "code", 101 | "id": "EkFVGD3Tv7ma", 102 | "outputId": "a3fd160d-dcdf-4495-e8ea-e5b3aedadaf5" 103 | }, 104 | "outputs": [ 105 | { 106 | "name": "stdout", 107 | "output_type": "stream", 108 | "text": [ 109 | "The autoreload extension is already loaded. 
To reload it, use:\n", 110 | " %reload_ext autoreload\n" 111 | ] 112 | } 113 | ], 114 | "source": [ 115 | "#@title Imports\n", 116 | "%load_ext autoreload\n", 117 | "%autoreload 2\n", 118 | "import sys\n", 119 | "sys.path.append('/content/pgmpydev')\n", 120 | "\n", 121 | "!pip3 install -q daft\n", 122 | "import matplotlib.pyplot as plt\n", 123 | "%matplotlib inline\n", 124 | "import daft\n", 125 | "from daft import PGM\n", 126 | "\n", 127 | "# We can now import the development version of pgmpy\n", 128 | "from pgmpy.models.BayesianModel import BayesianModel\n", 129 | "from pgmpy.inference.causal_inference import CausalInference\n", 130 | "\n", 131 | "def convert_pgm_to_pgmpy(pgm):\n", 132 | " \"\"\"Takes a Daft PGM object and converts it to a pgmpy BayesianModel\"\"\"\n", 133 | " edges = [(edge.node1.name, edge.node2.name) for edge in pgm._edges]\n", 134 | " model = BayesianModel(edges)\n", 135 | " return model" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": 23, 141 | "metadata": { 142 | "cellView": "form", 143 | "colab": { 144 | "base_uri": "https://localhost:8080/", 145 | "height": 201 146 | }, 147 | "colab_type": "code", 148 | "id": "P90trQAQ7Clc", 149 | "outputId": "2e62d5b8-de0e-4e28-b14d-5c03bcd1ab61" 150 | }, 151 | "outputs": [ 152 | { 153 | "data": { 154 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAPEAAAC4CAYAAAAynAqtAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAD/FJREFUeJzt3XnUVdV5x/HvowiIOMVZghJq3NE6\nVdSSOKFiNV0xNU44VINlKUaJbUptrFOXY0yzbKpdWhPTCA61LI0ZOjkkS63G1EQUtRaPicUJYwyJ\njSAgDk//eA7G9S7gved9z7n33YffZ6274IV79t7nvud3z7T3PubuiEi+1ul1A0RkcBRikcwpxCKZ\nU4hFMqcQi2ROIRbJnEIskjmFWCRzCrFI5hRikcwpxCKZU4hFMqcQi2ROIRbJnEIskjmFWCRzCrFI\n5hRikcwpxCKZU4hFMqcQi2ROIRbJnEIskjmFWCRzCrFI5hRikcwpxCKZU4hFMqcQi2ROIRbJnEIs\nkjmFWCRzCrFI5hRikcwpxCKZU4hFMqcQi2ROIRbJnEIskjmFWCRzCrFI5hRikcwpxCKZU4hFMqcQ\ni2ROIRbJnEIskrlhvW5AL5jZKGAnYDTgwK+Bwt3f7mnDamRmWwHjgZHACmAh8IK7e08bVhMzWwfY\nAdiS2I6XAc+6++s9bVgPWEt+p/0ysz2A04EDiI37WeANwIAtgLHAfwP3AF9395d61NQBMbNhwKeA\nk4F9gA2IdVwODAfGlX8+CtwO3ObuS3vS2AEys02BqcAfAb9HfPkuBN4FRgEfA34B/Ai4EbivLV9a\na+TurX4Bk4CHgReBC4A9geGreN+GRMCvAX4F3Ans0Ov2d7B+6wJ/CrwEPAR8lviSslW8dxvgKOB7\nwCLgS8CoXq9DB+u4OfA14HXgFuBwYLPVfBY7AWcBTwHzgRNX9Vm06dXzBjT4ix8N/D3wMnAcMKzi\nsn9RbuhnA+v0en1W084E/BC4H9ij4rLjgFuBnwL79Xpd1tDOo4CfA18FtqywnJVf4E8B3wG27vW6\nNPYZ9boBDf3itwaeAG4GNh1EOR8t9+J3ACN6vV592jYZ+CUwYzBfMsCRZUhO7/U69WmXAVcAPwP2\nHUQ5I4DLiMPuXXu9Xo18Vr1uQAO//C3Kw6iL6jiMKjeCO8tv84735g2v48HAa8D+NZW3A/D8UAoy\ncCUwF9i8pvKmAK8CO/d63ep+terCVnnF8l7gJ+5+bo3lDgf+BXi8znIH2JaxxMZ9rLs/UGO5OwAP\nluU+VFe5A2zLycD5xB74VzWXezGwm7svqavcXmtbiM8krs7u5+7v1lz2VsQh+qfd/cd1ll2hDQb8\nB/Cgu1/eQPlHAl8BdvceXbk2s22Iz/lwd3+sgfJvBJa6+1l1l90rrQlx+ct/igjwMw3VMQW4kPgm\nf6+JOvqp/yTgC8BEd3+noTpuJe4nn9dE+R3UP4e433thQ+VvQmwnx7n7j5qoo9vaFOKLgG3c/XMN\n1mHAY8AX3f2epupZQ92PAue7+10N1jOurGe7bu+NzWx74lRhe3d/s8F6ZhDXE6Y0VUc3taLbpZmt\nB0wH/qHJejy+8a4l7kN22z7AJkRnlMa4+/NEZ4njm6xnNaYDNzcZ4NLNwB+UR2/Za0WIgY8Dr7r7\nk528OaU0OaV0/wd+HpNSKlJKG3Ww+G3AZDPbYGBNHbDjgNlVDuNTSieklN5OKW1esa4by/q67Vhg\nVidvTCldkFK64gM/r5NSmpdS2q2/Zd39N0SHl88MtKFDSVtCvBex9+hIURTfB15MKZ1S/tNVwPlF\nUbzR37LlXmI+sPtAGjoIldaxdCLwHHBMxeX+C5hQHsJ3Rdmlcmui62snrgKOTimNKX8+FXikKIqO\nvsgp17FaK4emtoR4AnEuVcWfA+emlI4ENiyK4o4Ky84lQlUbM
xtmZtPM7HAz26LP/61D9BXu+Gpt\nSulDxCH4TOCEKm1x91eAd4DtqizXHzPbxsw+Z2YTzWz9Pv+9J/BEp3cViqJYBlwKXJZSGkX0sKty\nMWwuLQlxW0YxfRhYUGWBoigWpZSuAuYQHeerWADsZma7VFxuTXYGvg4sAUaa2WJgHvAA8DRAxXum\nxwL/CtwF3JBSGlMUxcIKy78A7G9mG1ZYpj9fAE4BlgKjzOxl4BHi/vSWRIeTKm4leqzdAMwqiuK1\nCssuILab7LUlxMOBgQwj3J3YcPai2pfACmBa+arbyvPyzYBDyhfEULsqTgQuLYri3ZTSHUSPpb+t\nsPwWxAWgujm/Xcdx5WvlVeLvVimoKApPKZ1HDIr4k4rtWEFsN9lry+H0cmLcbMdSSvsAvwscBFyc\nUhpdYfH1gS+7u9X1IvbE75Tr8gYxIulyom9zAt7r9Bw1pfRh4PeBq1JK84hRP1WvNi8CPlHzOp5b\nruPbwCvExaWZxECF04g9dFX/C7xSFMVbFZdbH6i6zJDUlj3xc8SG/oNO3pxSGgZcB3y2KIpXUkrf\nJLrjzeywvh2JkNXG3eeb2X7EqKtX/AM38MvwLgO2JTry9+cE4NqiKGYCpJQM+GlK6XeKoniuv4XL\n+nYkPtc6XU0MKJnn7ov71Pk2cEbN9a1JIgZXZK8te+KqFylmAg8URfF0+fPVwKEppV07XH4C0SGi\nVu7+iLsv/GCAy393qq3jCcRtIiAOO4HZdL43Hg8sdvcq55j9cvfl7v5g3wCX5gE7l/3Uu2EgF0OH\npFb02DKzCcT929Q3AA3UtTlx/ryZu69osq4+9V4CrO/u53ShrqnAEe5+dNN19al3HjCjGwMwzOzb\nwHfcfXbTdTWtLXvix4hzrQO6UNc04PZuBrh0C3CKmY3oQl2nATd1oZ6+biamUGqUmW1LnId/u+m6\nuqEVIS73vtfRcHdIM1uXOG+7rsl6VsXdnyVG91TtuFFJORfZWODfmqxnNWYBR/S9T96A04B/dvd+\nO/fkoBWH0wBmthHlnEp1jrPtU8c5wGHuPrmJ8juo/xDgH4lRVLVvgGWnkvuAb7n7NXWX32EbrgY2\ndvepDZU/DvgJMdqtaKKObmtNiAHM7Ajg74iNvNZO9Ga2E9EpYW93r9SxpOZ2fAN4192nN1D22USf\n6QPrHo9doQ2jgSeBz7t7rUcD5VX3e4Hvu/uVdZbdS60KMbw/6HtD4Pi6xtya2WZEz6lr3b3RkVId\ntGVj4kruJe5+Y3/vr1DuAcC3iNk0nq2r3AG25SDgn4CD6hwbbmaXEZ1n9m9qPHYvtOKcuI8ziB5B\nt5RDFAfFzLYkvr3/Hbh+sOUNVjkC53DgcjM7tY4yzWwSMRng8b0OMIC73wf8FXCvme082PIsXAwc\nTczM0poAA+2bKK88shhJXHn8MYOYGA34JDFf9cUMsbmLic4KC4j5mDcaYBnrEXNxv0bs9Xq+Xn3a\n98fEjJ6nDfTzJ+ba/i5xT7jjKW9zevW8AQ1uAEYMMv8lcAkV5h0mukDOJvpVH9LrdVlDOzcmOv8/\nX27wHU2rSxyBfbLcsO8CxvZ6XdbQ1l2IjjX3APt2GmbiaGwG8USIyzr9bHJ8te6cuK/yauR5xKie\nu4mJ5h4FnvHy4o2ZjQR2IwZCHEuMaroB+IqvunfRkFJetf5LYA/iXuuDREDf7/1VntdPACYST4l4\nnRgQcZsP8Y2gPC06A/g80f30FuIo63Evr9KXF63GE+t4MHGB7gfAl7yBCfeGktaHeKXygtBJRIeQ\nCcS90KXEXmkEUBAb/t1ET55ud+YYNDP7KLFH3odYxw2IkTrLgfeAx4l1nOM9mrFzMMpbYIcQz2Ka\nQHzxvlO+RhGDNuYSA/5v8hgX3XprTYj7Kve+GxAb9xJv0RMR4f0902hiRNREYi7urs/Q2aSy881o\n4tx+mTc/N9eQ1JZRTJW5+
3JiD9VK5SHy4nL04pttCzBAeTr0m163o9faeItJZK2iEItkTiEWyZxC\nLJI5hVgkcwqxSOYUYpHMKcQimVOIRTKnEItkTiEWyZxCLJI5hVgkcwqxSOYUYpHMKcQimVOIRTKn\nEItkTiFuKTPbwszuLH+cZWadPntZMqMQt9d44A/Lv+9KTMcrLaQQt9eTwLrl398ipnKVFlKIW8rd\nlwEvlT+OBP6nh82RBinE7fZI+edz3raHiMn7FOJ2e7D884c9bYU0SiFut5XnwQpxi621T4BoKzMb\nTlyNngDsCDhwgJltRDxIbt7a+riTtlprn8XUNma2N3AmcAzxqNO5wLPEw8YM+AgR7J2JpwVeC9zb\nxse7rG0U4syZ2XjiMawfAa4Hvunui9bw/g2AE4CziAeRTXP3R1b3fhn6FOJMlU89PAO4FLgS+OrK\n5y1XWP444BpgFnC+rmDnSSHOUBnALxM9so5x92cGUdZWxIPJlwJT3P2telop3aIQZ8jMLgE+BUx2\n91/XUN5wYA5x/jxF58l50S2mzJjZZGAqcFgdAQZw9xXA8cAYYEYdZUr3aE+ckfI20ZPAdHe/u4Hy\ndwQeBia6+8/qLl+aoRBnpDyM3s7dpzZYxznAx939qKbqkHopxJkoz1tfAA529/kN1jMaeBHYzd1f\nbqoeqY/OifPxGWB+kwEGcPclwK3A6U3WI/VRt8t8HAbc0embU0rjgKeInltODEc8pyiKhzpY/Hbg\nb4CLqjdTuk174nxMIPo+V1EURTGpKIqDgC8CF3a43OPArmamL/kMKMQZMLP1gERcmR6orYCFnbzR\n3RcTEwp8bBD1SZfomzYPGwLL3H15xeVSSul+4lB6DHFI3qlFwKYV65Me0J44D0ac11a18nB6InAo\nMCel1OkX93to+8iCfkl5eBMYZWbr9vvO1SiK4hlgGTC2w0U2ApYMtD7pHoU4A+Vh9MsM4hw1pfQh\nYBs6OC82sxHEhAKN3s6SeuicOB9ziSvUT1dYZuU5McR58YyiKFZ0sNwuxOR6S6s1UXpBPbYyYWZn\nAge4+/FdqOsCYFt3P7PpumTwFOJMmNkmwAJgJ3d/tcF6hpX1HOHu85qqR+qjc+JMuPv/ET2pmt47\nHgW8pADnQ3vijJjZ9kSvrQPdvfYnOpjZpkRXzZPc/YG6y5dmKMSZMbPpwDRgX3d/u8ZyDbgJeN3d\nz66rXGmeQpyZMmx3AsuBk+ua3M7M/po4lP6E5qXOi0KcITMbCXwPeAOYWg4fHGhZ6wJXAJ8GJrn7\nL+pppXSLLmxlqOz8cQSwGHjCzA4cSDlmthPxiJe9gP0V4DwpxJly97fc/VTgz4Bbzex2M5tUHm6v\nkZntbmbXA/8JzAYOXdOE8zK06XC6BcoJ9E4mnuowDHiI6OFVEP2lh/Pbx7hMJIYlfg34hrv/vBdt\nlvooxC1S7oX3BPYmArsDMAJYQfS9nlu+HtbTHtpDIRbJnM6JRTKnEItkTiEWyZxCLJI5hVgkcwqx\nSOYUYpHMKcQimVOIRTKnEItkTiEWyZxCLJI5hVgkcwqxSOYUYpHMKcQimVOIRTKnEItkTiEWyZxC\nLJI5hVgkcwqxSOYUYpHMKcQimVOIRTKnEItkTiEWyZxCLJI5hVgkcwqxSOYUYpHMKcQimVOIRTKn\nEItkTiEWyZxCLJI5hVgkcwqxSOYUYpHMKcQimVOIRTKnEItkTiEWyZxCLJI5hVgkcwqxSOYUYpHM\nKcQimft/k4wcZaGmP/YAAAAASUVORK5CYII=\n", 155 | "text/plain": [ 156 | "
" 157 | ] 158 | }, 159 | "metadata": { 160 | "tags": [] 161 | }, 162 | "output_type": "display_data" 163 | } 164 | ], 165 | "source": [ 166 | "#@title # Game 1\n", 167 | "#@markdown While this is a \"trivial\" example, many statisticians would consider including either or both A and B in their models \"just for good measure\". Notice though how controlling for A would close off the path of causal information from X to Y, actually *impeding* your effort to measure that effect.\n", 168 | "pgm = PGM(shape=[4, 3])\n", 169 | "\n", 170 | "pgm.add_node(daft.Node('X', r\"X\", 1, 2))\n", 171 | "pgm.add_node(daft.Node('Y', r\"Y\", 3, 2))\n", 172 | "pgm.add_node(daft.Node('A', r\"A\", 2, 2))\n", 173 | "pgm.add_node(daft.Node('B', r\"B\", 2, 1))\n", 174 | "\n", 175 | "\n", 176 | "pgm.add_edge('X', 'A')\n", 177 | "pgm.add_edge('A', 'Y')\n", 178 | "pgm.add_edge('A', 'B')\n", 179 | "\n", 180 | "pgm.render()\n", 181 | "plt.show()" 182 | ] 183 | }, 184 | { 185 | "cell_type": "code", 186 | "execution_count": 24, 187 | "metadata": { 188 | "cellView": "both", 189 | "colab": { 190 | "base_uri": "https://localhost:8080/", 191 | "height": 51 192 | }, 193 | "colab_type": "code", 194 | "id": "yQyYJEC83ODX", 195 | "outputId": "07597081-bbec-4209-db8f-b731de34ae3a" 196 | }, 197 | "outputs": [ 198 | { 199 | "name": "stdout", 200 | "output_type": "stream", 201 | "text": [ 202 | "Are there are active backdoor paths? False\n", 203 | "If so, what's the possible backdoor adjustment sets? frozenset()\n" 204 | ] 205 | } 206 | ], 207 | "source": [ 208 | "#@markdown Notice how there are no nodes with arrows pointing into X. Said another way, X has no parents. Therefore, there can't be any backdoor paths confounding X and Y. pgmpy will confirm this in the following way:\n", 209 | "game1 = convert_pgm_to_pgmpy(pgm)\n", 210 | "inference1 = CausalInference(game1)\n", 211 | "print(f\"Are there are active backdoor paths? 
{inference1._has_active_backdoors('X', 'Y')}\")\n", 212 | "adj_sets = inference1.get_all_backdoor_adjustment_sets(\"X\", \"Y\")\n", 213 | "print(f\"If so, what's the possible backdoor adjustment sets? {adj_sets}\")" 214 | ] 215 | }, 216 | { 217 | "cell_type": "code", 218 | "execution_count": 25, 219 | "metadata": { 220 | "cellView": "form", 221 | "colab": { 222 | "base_uri": "https://localhost:8080/", 223 | "height": 258 224 | }, 225 | "colab_type": "code", 226 | "id": "b5RJ0UsH_kQ4", 227 | "outputId": "3e4f9430-fca7-40d4-f8b3-ffd241c6a8f8" 228 | }, 229 | "outputs": [ 230 | { 231 | "data": { 232 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPEAAADxCAYAAAAay1EJAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAGkFJREFUeJztnXn4XtO1xz8rI5KYQwlKDCeGCEnE\nUJFQw9XWrfGa3dCiJLeToUXog7qqqKaK0tbU1hRVrVIzEU0NiZlmI6KIoeYaYl73j7V/lRtJfud9\n33Pe89sn6/M875O8yTlrrf2e8z17n73X3ltUFcdx0qVb1QE4jtMaLmLHSRwXseMkjovYcRLHRew4\nieMidpzEcRE7TuK4iB0ncVzEjpM4LmLHSRwXseMkjovYcRLHRew4ieMidpzEcRE7TuK4iB0ncVzE\njpM4LmLHSRwXseMkjovYcRLHRew4ieMidpzEcRE7TuK4iB0ncVzEjpM4LmLHSRwXseMkjovYcRLH\nRew4ieMidpzEcRE7TuK4iB0ncVzEjpM4LmLHSRwXseMkjovYcRLHRew4ieMidpzEcRE7TuK4iB0n\ncVzEjpM4LmLHSRwXseMkjovYcRLHRew4ieMidpzEcRE7TuK4iB0ncXpUHUC7EJHuwDbAFsAwYD2g\nL6DAa8ADwDTgBlWdWlWcrSAiSwA7AiOwMq4OLAJ8ADyHlW8qcLWqPl9VnK0gImsCO2DlGwosj93H\ns4HpWBn/Blyjqu9VFWc7EVWtOoZSEZHFgUOBbwAvA9dhF/pB4E1AgOWwG2I4sDPwCnAW8BtV/biC\nsBtCRDLgO8DuwG3AZKyMAXgP6Amsht34mwL/CdwE/FRVp1QRcyOIiABfAr4JbAj8HrgHK+Ms4COg\nD7AuVsat4nEXABNUdVYFYbcPVa3tB6t5/wH8Dtgo5zndgS8Dd2BP9EFVl2MBsfYAjsIeTscBK+Q8\nb3FgLPAs9rDqW3VZFhDr8pho/w7sCyyS87w1gDPib7M/scKq46fyAEq68N2An0QBb9uCjUPjTTCm\n6jLNI77l4kPmJuDzTdpYEqutngLWqbpM84hvNPAicHJe8c7DxhDgfuBqYLGqy1TK71R1ACVc+G7A\n+bEmXbIAe4Piw+CQqss2R0zLx5rpxCJqmFjDvQgMrrpsc8S0DfBPYKsCbPUELgYm1VHItXsnFpGT\nsCf4tqr6TkE2V8PeM8ep6tVF2Gwhlt7AFODPqvqDAu3ujrVehqnqi0XZbTKW9YFbgJ1U9c6CbHYD\nLsReJXbSGt34tRKxiGwKXAWsr6ovF2x7M+zdrHDbDcZxEtaBU/iNWKb
tBmLoCdwN/FxVzy/B9j3A\nz1T1giJtV0ltRCwii2DDRONV9cqSfJwGrKSqe5RhP4f/ocBfgCFl1Jaxlp8KnKyqlxRtP2cM44HN\ngC+X8SARkSHAzcAGWpNe6zqJeD9gX1XdpkQfiwIzgdGqOr0sPwvwfyVwu6r+vEQfo4FzsI6utt4c\nItIHeAYYrqozS/RzBvCBqn6vLB/tpE4ZW4cCZ5bpQFVnA7/GxpzbiogMwMY/LyrZ1STgY6xfod3s\nBUwuU8CRs4H9Y+steWpRE4vIOsANwKraQHJGlmV7Yr2WK4QQXsnpaxWs2d6/EV+tIiLfx4aSDslz\nfJZlqwIPYwkRimVuHRFC6LSjSETGARur6r7NR9w4IjIFOEFVr89zfJZlawI/Bfpj4/tTgMNDCO/n\n8HUTcJ6qTmwh5C5BXWriTYBJTYhqL2AGsGveE1T1GeBVIGvQV6tsCtza4DkhhDA6hLAl8D3g2Jzn\n3Yr9pm0jvo9vgA0NdkqWZd2xjsYfhxBGYNl2YEkveWh7GcuiLiIehtU4ucmybGksx/gwYM8G/U3j\n05umMERkBRGZXz57w2Wci+WxFMU8BOBzIrJkC/4+g4h0E5EBMY1ybtYDZqjquznNbQNMDyFMAggh\nKHAkcELO86dhv2ny1GUCxCDgjw2esxvwZ+B64JdZlg0IIeS9yR+j4JpYRFbHEvhVRJ4A/oo1D6dh\nySZLY51qjZBlWXY71pQeAGyX5yRV/VhEArAWNiRTFGOBCcDbIvIw9v7dkQOdYb9rXgZhrzX/JoQw\nu4HzH402kqcuIl4UaDSxYy/gxBDCx1mWXYlNHvhJznPfBo4XkaMb9JmXdeJnDJZt9AHwURO9xSGE\nMBogy7JBwMQsyzYMIXyU49y+wN3zrjRbQoF+2DBSR3O2G/Au1q/RiJ3uLcTxNrBYC+d3Geoi4o9o\noCxZlq0EbAycnmWZYhfzDfKLuCc21viXBuNcEKtjPeyzgQ9jTM9hiQ9TgZays0II07Msmw2sTL4a\n/S2st//pVvzOxRis2fwWJtweWProZOw3XaIBW9OBcXP+Q5ZlvYE1QwiP5Di/J/Y7J09dRPwS1lzM\ny57AWSGEwwCyLBPgiSzLVg8hzMhx/gBs3nFe0XeKiCyG1UYzseblQ3FIq2Mu9Aki0k9V32rGfuwD\nWIH878X9gTNV9Ylm/M0LEbkeOBh7ME0DnlDVT+L/bQ8c3oC5m4BTsyzbIYRwTZZl3YBTsAdEHhEP\nwHKzk6cuHVv30VgnxZ7Y7B3g350iFwF5M7Fa7WT6DKr6rqoeoapnq+rdHQKO//cxNly0YYNmsyzL\nbo/vxdcB40IIH3R2kogsAyyF9dwXhqo+pqrfUtVLVDV0CDhyHzB0Pp1enyGE8An2jn9QlmVTgTux\n+eF5WyyFX8OqqMs48bbAsao6sg2+emErgaykqm+U7W8Ov2cBM1X1tDb42h44UlW3LNvXXH6fxWYt\nFVb7L8DXL4CgqmeU7ats6lITTwbWFpGBbfC1IzC1nQKOXA3sm7emapF9gD+1wc/c/DH6LpWYPrsL\ncE3ZvtpBLUQcm54XYe9bZTMWWw2j3dyC9cJvWqYTEVkOWwrnwjL9zIdzgANja6dMdgOmqeqTJftp\nC7UQceQc4AARWbEsByKyFbZWVdvnFMf3xzOxoa0ya+OjgCtU9fUSfcwTVX0UGysu7WEc86WPAn5W\nlo92U4t34g5E5Hisw2KHEuba9sU6l8aq6nVF2m4ghh7AXcAvVPVXJdjfHLgCW+Hj1aLt54xhENZJ\nNUJVnyrB/o+w9bd2q8vCAHUTcS/gXuBcVT27QLuCNS9VVccUZbfJWNbDVrQcpaqNZDh1Zrc/tmbX\nEar6h6LsNhnL4cBOwNZz9tIXYPeL2KKJQ1T1paLsVk2dmtOo6gdYh8UxIlLIDJy4rMtPsbTAcZ0c\nXjqq+gi2PO0NsdZqGRFZFks/vbx
qAUfOwOYVXxk7oVpGREYClwK710nAQP0Wyosti3Ww5ViPB3q1\nYGcZrHk5BViq6nLNFdsY4AXgKy3aGYZlP/2ILrSsK5ZRdSmWQz6wBTsCfA1btXSbqstVym9VdQAl\n3gQDgGuxJPnhDZ7bDavRn8dq4S65QiI2cf8prGe+f4Pn9gF+iGUt7VN1WeYTY3csi+sV4H8afSBj\nnZDXY4kkXWYlz8J/p6oDKPkmkFhjPYeNJe85vxo1HjsA+DY2Fe9BYGTVZchRxr7xQfM6tsDBSKD3\nfI7tjq3DPAGbE305ORecr7iMGZZm+QI21XCt+bUasJzz7bEx51eB8UDPqstQ5qdWHVvzI65yuAM2\ndLEZVvs8iC1D8wk2q2YodpPfiC3fMkUT+nFiquQYbA3ptbCJBY/z6QyvVTEBv4C9IpyrtsBBMojI\nusAh2DY0/bBF4WdhExn6Yq9RA7HW1/nApVrQssVdmYVCxHMSJxOsBQzGaqKZWEfR/cCzKQl3fsTJ\nFEOwGUPnYX0DdwD3afszzUohJqUMxXbC6IlNHgnAI2odnAsNC52I50REFLhXVUdUHUsZxJUrb8PG\ntgsbcnO6FrUaYnI+Q8fMrtInhjjV4SKuN6Pin7VYEM6ZNy7ietOxmN9KRSVNOF0PF3FNiZurLxu/\nvgusX2E4Tom4iOvLkPhnxxpiQyuMxSkRF3F9uRdbLF6wJJBKt2R1ysOHmGo8xAT/LuNgtYkTTg3x\nmthxEsdF7DiJ4yJ2nMRxETtO4riIHSdxXMSOkzguYsdJHBex4ySOi9hxEsdF7DiJ4yJ2nMRxETtO\n4riIHSdxXMSOkzguYsdJHBex4ySOi9hxEsdF7DiJ06PqAJxyEJHlsX2JADYQkbeBf9Rhmxrn/+Nr\nbNVkjS0R6QF8BdtQbQS2denj2LK1rwCrAL2AqcBEbLOxd6uJ1ikSb04njoh0F5FvYRvDHQ78Cdv5\nYRlV3URV14h/rohtInce8FXgGRE5OW6+5iSM18QJ18QikmFbeH4IfFtVH2jg3FWBk7Bae39VvbOM\nGJ3y8Zo4UURka+BO4FJgq0YEDKCqT6vq3sARwEQROaiEMJ024B1bCSIiWwGXADur6uRWbKnq1SLy\nCHCziKCq5xUSpNM2XMSJISIrA5cBu7Uq4A5U9clYs08Wkce8aZ0W3pxOCBER4JfABFWdVKRtVX0S\nOAS4wDu70sJFnBZ7YUNGp5RhXFWvBu4Bxpdh3ykHF3EixFr4u8B4Vf2oRFfHAAd5bZwOLuJ0GAEs\nCdxYphNVfRr4G7BHmX6c4vCOrXT4L+AiVf0kz8FZlq0KPAxMw7Y3/Qj43xDCLTlOvwA4CBuDdro4\nXhOnw3CshmyEEEIYHUIYhYnyzCzL1s9x3l3AsNiEd7o4LuIEEJFuwIbAfc3aCCHMwDK0xnZ2rKo+\nj9XcqzTrz2kfLuI06Aegqq+2aGcqsE7OY2cCA1r057QBF3Ea9AI+KMBOP+DjnMd+GP06XRwXcRq8\nByxSwDvqcOD+nMcuArzfoj+nDXjvdBq8DcwGVgRmNWMgy7LVsXHmrTs7Nj4s1gJmNOPLaS8u4gRQ\nVRWRacAwGhNxlmXZ7UBvoDswNoTwTI7zBgJvqeo/Gw7WaTsu4nS4BxiJTfrvlBDC08QOsSYYCdzb\n5LlOm/F34nT4LbCfiPRug68DgYvb4McpABdxIqjq48CDwK5l+hGRDYCVgWvL9OMUh4s4LU4BThKR\nxcswHpNKJgCnlTzJwikQF3FCqOotwM3AqSW5GId1gJ1Vkn2nBFzE6XEYsK2I7F+kURHZAjgWOEBV\n8yaEOF0A751ODFV9U0T+A7gtrol1Qas2RWQ0cAWwR3z3dhLCa+IEUdUAbAkcJyLnNvuOLCI9RWQ8\nJuDdY3PdSQwXcaJEIW+AXcOHRGSfvMNPItJNRLbHphxuDgxT1dvKi9YpE188PuHF4zsQkS8CR2Ki\
n/g0wGVsMYFbH3ksisgyW8bUJ8N/A68BPsO1cFt6boAa4iGsg4g5EZE1gH2wpn2HYfky9sAkUn2CT\nH6YBl6vqPVXF6RSLi7hGIp6TOImhL/AvrPa9N+/SPk5a+DtxTVHjrfj1HRdwfXERO07iuIgdJ3Fc\nxI6TOC5ix0kcF7HjJI6L2HESx0XsOInjInacxHERO07iuIgdJ3FcxI6TOC5ix0kcF7HjJI6L2HES\nx0XsOInjInacxHERO07iuIgdJ3EWShGLcXr8upGIHFxpQCUgIv1F5Kr49UIRGVxpQE5pLJQijnx9\njr/vXlkU5TEQ+FL8+2BgeIWxOCWyUIo4rrP8UPz6ETCpwnDK4iFsczSA97Glap0aslCKODIJ+Bh4\nB7i34lgKR1VnA8/Gr4sAj1UYjlMiC7OI78UEvAj1raXujn/O8P2G68vCLOJp2OLq76rqS1UHUxKT\n459/rTQKp1QWZhHPAt4GHqg6kBLpaGG4iGvMQrU/cdzaZCC2T9F62BYni4jIMZiYp6nqixWG2DIi\n0gvrjR4GrAUosEXc/nQq8ICqvlNhiC0jIksCQ7EyLofdx7OBgJVx+sK0UfpCsReTiKwIHBg/itVQ\nDwFvAQIsC2yI3RTPAWcDv5tjG5Quj4hsBBwK7Ao8jZXxcaz3XYDVsPKtA9wCnAXclMr2LvHhtCNW\nxmHAg8SdH7Ey9sHKNgy7nr8BzlHV6ZUE3E5UtbYfYFHgNOA1TJiDOzlegC8CvwdeBb5BfNB11Q/W\nsrgFeArb3nTZTo7vg42R3w88AmxcdRlylPHLwDPAbcBuQM9Ojl8F+CHwInAZ0L/qMpT6+1QdQIkX\nfgTWvGrqImJP9XuAm4GVqi7PPOIT4BDgFeBwoHsT5+8OvAScAvSoukzziLEfcFF8QG3ZxPmLAqcC\nLwA7Vl2e0n6nqgMo6eJvB7wM7NainR7AcVjzdM2qyzVHXAL8ONakg1q0tTxwI3A10Lvqss0R11LY\nENlFQN8WbW2GvSYdUnW5yvjU7p1YRDYH/oA9eQvplRWRg4Cjgc1V9bkibLYYzwnAV4CtVfW1Auz1\nAi7H3i1314rfk0VkMewV4W7gO1rATSoiA7Hm+A9U9cJW7XUlaiXi2Gv5MHCgql5fsO3jgC2AbYq4\nqVqIY2vgfGCYqr5coN3e2E1+mar+rCi7TcYyAfgcsEeRv7WIDALuxB7GtenwqpuILwBmq+qhJdju\nAUwBfq2q5xZtP2cMi2O96ger6g0l2F8LK+Mmqvpk0fZzxrAFcCnWCdlyK2Me9scC+2BCrsUwVG1E\nHIdYrgLWVtW3S/KxLnAH8PmyfHTi/wRgFVUdU6KPI4BNVXXnsnwswLdg4/XHq+pVnR3fpI9uwO3A\n+XVpVtcpY+tQ4OdliktVH8UmTuxdlo/5Ed9bD8R6ksvkHGC0iKxUsp958QWgN9anUQrxff/H2P1S\nC2pRE4vI0tgwxJp53hOzLNsaGB9CGB2/DwBuBTYKIfyrE19bA6er6pCWA28AEdkda0Zvlef4LMtW\nxfoH5p7csXMIYYHNVBE5E3hdVY9rJtZmEZFLgLtVdUJnx2ZZNh5YLIRwdPzeDbgP2C+E8NCCzhWR\n7sAMYFdVndp65NVSl5p4JHbxc3X0hBBuBp7Jsmy/+E+nA8d0JuDIrcDKIrJ8c6E2zXbAlQ2eE0II\no+f65HnPnAhs23iIzROb0tuSv4ynA7vEBzDA/sDdnQkYIL4L/542l7Es6iLiYVjObCN8F/h+lmU7\nAv1CCLluntgcuy/6bCfNlLFZ7gcGx868drEK8KGqzspzcAhhNnAi8MMsyxbDEl6ObcDfVNp/DUuh\nLiLeABNWbkIIr2BP88uBcQ36uw/LtS4MEVlaRG4VkQtE5OsisoGI9Iz/1xPI+HQ1klJRyxl/FhhU\npF0RGSUiU0TkVBHZVURWizUw2DW8v0GTvwPWBn4JXBhC+GcD5
94XfSZPXWYxLQ00cgE7GIJlYw0H\nZjZw3svAOBH5ahM+58dKwArx77sBn2AzrGZitcYHqvpegzazLMtun+N7CCHkXRRwNnCFiBTZUbgR\nNgFlY2waaA9AReRRLL/9zUaMhRA0y7Kjgd8CBzQYy8vYfZM8dRGxYDdHbrIsGwGsC2wJ3Jxl2V9C\nCHlv2E+w5t8qDUWZnz7Au8CHwOrYdWqm1RQ6Ou+aoCdWy5VBNywv+m1sZZXBWA53M68LTwHPhxDe\nb/C8T6hJS7QWhcBuhiXyHpxlWQ9sVtM3QwjPYxlQxzfgbwngRFWVoj5YTfwm8CRWs3wHyxDrgz1s\nesRe1XbxETC84DLuhdXw04AJ2JDZ2rGM3+PThf3awRLYVNTkqUtN/AjWNL425/GHAZNCCI/G7xOA\naVmWDQ4hPJzj/CGY0ApDVWeJyFLzSzMUkeewd9RH5/X/82Hu5jTAkSGEexZ0UkzBXAv4ewO+OkVV\nLxWRy+ZVRhF5BFi/SH+dMAS7b5KnLuPEewG7qOoubfL3HLCFqj7VDn/R5+XAtap6cRt8DQMuVNW2\nLTgfWxlvACur6htt8HcCNn3zmLJ9lU1dmtN3AqNEZNGyHYlIR23RSEdYEUzi08Xgy2Z7Pl1kry3E\nsdsp0Xc7aHsZy6IWIlbVZ7AJ/O3YyeFQ4LwKZjJdAmwnIp8r00kcGz4YOK9MP/PhPNqQDhnz7JcF\nbirbVzuohYgjZwPfKrPzR0T6Yw+KX5XlY37EJuZEyr/JdwaeVdUqVgH9EzBQRMrecubbwLk+i6mL\nEWen3AZcrapnlOTjcuwGP7wM+zn8fx4bhhmlqoXv6CAiS2H51nuraiVb24jIGExkG6nqhyXY3w44\nF1hfVfOk2XZ5aiNiABFZHbgLGFn0pG8R2Q1L89tQbYuUSog7OH4N+EKRN3nMnLoYm/jwzaLsNhnH\nn4GpqvqDgm0viWW9HaCqNxdpu0rq1JxGVWdg443XFTmVTkS+gDXX961SwJHzsGVaLy44t/k4bIjn\nqAJtNkzsazgIOEBE9uvs+LyISB/gGmBinQQMNRMxgKqejwnujrgcS0uIyPbYInJ7q2rlG6/Fm3xP\nYBngMhHp24o9EekuIqcAewDbahdYWD5OgtgOOFlExs6RX90UIrIcthjg48ARBYTYpaidiAFU9TRs\n3eHJIvLdZjq7RKSfiJyDvT/tpKo3Fh1ns8Qc6h2wjKMHRWRUM3ZEZG1si5fh2CtIl9mTKr7zj8Jq\n5WviBgANE1+DHsL6Sw7URBbLbwjtAktulvXB8o4nYbNj9gcWzXFOf+D72CyeXwNLVF2OTuLdAVuO\ndSIwmhyL3WPZSr/AJgEcAnSruhwLiLUXlhL7MrYix8Ac5/QAvooNIf0dWzOs8rKU9alVx9a8iL3W\nX8Ju1o2xNbKmYtuAvIm1Rvpje/sMj8dcBZytiaz6EBfQ2xcYi93Ad2L5yQHLVe7Fp9u4bIKtNX0u\n8CtVfaGKmBtFRNbAduQYA0zH8gI6tnH5ENvhcl2sjKOwh/DZwBWq2ujkiKSovYjnRERWwdZx6thQ\nrR82m+U14oZqwJ1awiqL7SC+Ow7FpvwNA9bA1qz6AKutp8XPFE10v+KYlbc5Vr6ODdV6YrO+pmPl\nu0tVC8377sosVCJ2nDpSy44tx1mYcBE7TuK4iB0ncVzEjpM4LmLHSRwXseMkjovYcRLHRew4ieMi\ndpzEcRE7TuK4iB0ncVzEjpM4LmLHSRwXseMkjovYcRLHRew4ieMidpzEcRE7TuK4iB0ncVzEjpM4\nLmLHSRwXseMkjovYcRLHRew4ieMidpzEcRE7TuK4iB0ncVzEjpM4LmLHSRwXseMkjovYcRLHRew4\nieMidpzEcRE7TuK4iB0ncVzEjpM4LmLHSRwXseMkjovYcRLHRew4ieMidpzEcRE7TuK4iB0ncVzE\njpM4LmLHSRwXseMkjovYc
RLHRew4ifN/jHxgrkKc1LAAAAAASUVORK5CYII=\n", 233 | "text/plain": [ 234 | "
" 235 | ] 236 | }, 237 | "metadata": { 238 | "tags": [] 239 | }, 240 | "output_type": "display_data" 241 | } 242 | ], 243 | "source": [ 244 | "#@title # Game 2\n", 245 | "#@markdown This graph looks harder, but actualy is also trivial to solve. The key is noticing the one backdoor path, which goes from X <- A -> B <- D -> E -> Y, has a collider at B (or a 'V structure'), and therefore the backdoor path is closed. \n", 246 | "pgm = PGM(shape=[4, 4])\n", 247 | "\n", 248 | "pgm.add_node(daft.Node('X', r\"X\", 1, 1))\n", 249 | "pgm.add_node(daft.Node('Y', r\"Y\", 3, 1))\n", 250 | "pgm.add_node(daft.Node('A', r\"A\", 1, 3))\n", 251 | "pgm.add_node(daft.Node('B', r\"B\", 2, 3))\n", 252 | "pgm.add_node(daft.Node('C', r\"C\", 3, 3))\n", 253 | "pgm.add_node(daft.Node('D', r\"D\", 2, 2))\n", 254 | "pgm.add_node(daft.Node('E', r\"E\", 2, 1))\n", 255 | "\n", 256 | "\n", 257 | "pgm.add_edge('X', 'E')\n", 258 | "pgm.add_edge('A', 'X')\n", 259 | "pgm.add_edge('A', 'B')\n", 260 | "pgm.add_edge('B', 'C')\n", 261 | "pgm.add_edge('D', 'B')\n", 262 | "pgm.add_edge('D', 'E')\n", 263 | "pgm.add_edge('E', 'Y')\n", 264 | "\n", 265 | "pgm.render()\n", 266 | "plt.show()" 267 | ] 268 | }, 269 | { 270 | "cell_type": "code", 271 | "execution_count": 26, 272 | "metadata": { 273 | "colab": { 274 | "base_uri": "https://localhost:8080/", 275 | "height": 51 276 | }, 277 | "colab_type": "code", 278 | "id": "2d6Ezs6PDDON", 279 | "outputId": "c597c0df-3271-4e8e-9a68-3200c7f105a5" 280 | }, 281 | "outputs": [ 282 | { 283 | "name": "stdout", 284 | "output_type": "stream", 285 | "text": [ 286 | "Are there are active backdoor paths? False\n", 287 | "If so, what's the possible backdoor adjustment sets? frozenset()\n" 288 | ] 289 | } 290 | ], 291 | "source": [ 292 | "graph = convert_pgm_to_pgmpy(pgm)\n", 293 | "inference = CausalInference(graph)\n", 294 | "print(f\"Are there are active backdoor paths? 
{inference._has_active_backdoors('X', 'Y')}\")\n", 295 | "adj_sets = inference.get_all_backdoor_adjustment_sets(\"X\", \"Y\")\n", 296 | "print(f\"If so, what's the possible backdoor adjustment sets? {adj_sets}\")" 297 | ] 298 | }, 299 | { 300 | "cell_type": "code", 301 | "execution_count": 27, 302 | "metadata": { 303 | "cellView": "form", 304 | "colab": { 305 | "base_uri": "https://localhost:8080/", 306 | "height": 258 307 | }, 308 | "colab_type": "code", 309 | "id": "Pg6T2WA3DZ8n", 310 | "outputId": "5d79efd3-b8f7-46a8-dcf0-da2b7997be3c" 311 | }, 312 | "outputs": [ 313 | { 314 | "data": { 315 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPEAAADxCAYAAAAay1EJAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAFKxJREFUeJzt3X/0Z1O9x/Hne3wlw6TkRzW4fuX4\nkTs0QshQIV0tolYhwi3dhKRcinJTuusWq+sS+XGVcsMoJS3cuAzGhMbvX3NakSu/iluEBtN43z/2\n/pjPqvnO95zP5+zz+ezPvB5rfdfM1/fsvc/4ft/f1/nsz9n7mLsjIvmaNOgTEJH+qIhFMqciFsmc\nilgkcypikcypiEUypyIWyZyKWCRzKmKRzKmIRTKnIhbJnIpYJHMqYpHMqYhFMqciFsmcilgkcypi\nkcypiEUypyIWyZyKWCRzKmKRzKmIRTKnIhbJnIpYJHMqYpHMqYhFMqciFsmcilgkcypikcypiEUy\npyIWyZyKWCRzKmKRzKmIRTKnIhbJnIpYJHMqYpHMqYhFMqciFsmcilgkcypikcypiEUypyIWyZyK\nWCRzKmKRzKmIRTKnIhbJnIpYJHNjgz4BaY6ZrQTsAWwJTAfWA14NvAQ8AtwKzAV+4u6PDeo8pVlK\n4hFgZoWZfRt4CNgd+DXwz8BbgDWAAvgYcAuwFXCvmc00s20Gc8bSJHP3QZ+D9MjMxoCjgCOBU4Gz\n3f3xCu1eA+wHHAP8FDja3Z9Lea6Sjoo4U2a2GnAp8BzwMXf/3x76eC3wTWAGsJu739fsWUobdDmd\nITNbHbgOuBrYuZcCBnD3p939QOB44Boz27TB05SWaGIrM2a2HHA5MNPdj2+iT3f/vpm9BFxpZtPd\n/Ykm+pV26HI6M2Z2IrAJ8H5v+JuXsm9JR0WcETN7K3AFMC1FWsaUnwv8q7v/oOn+JQ0VcUbM7IfA\nLHc/LeEYOwBnABsrjfOgia1MmNlU4J3AeYmHug5YCOyQeBxpiCa28rEfcJG7P1vl4KIo1gbuJtyl\n5YQ7t44qy3L2ktq5u8cbRw4Cru3rjKUVSuJ8vB24pmabsizLHcqy3BE4GvhixXbXAFvXHEsGREWc\nj+mEVO3V6sCjFY8tgTfEm0FkyOlyOgNmtgKwMvCbmk2LoihmES6lpwK7VGnk7gvNrAQ2INxvLUNM\nSZyH5YH5PcwWdy6ntwZ2Ai4qiqLqL+7ngck1x5MBUBHn4S/AMv10UJblPGA+sGbFJmPAgn7GlHao\niPPwLLCsmU3ptYOiKFYG3kj118VTgd/3Op60R6+JMxBfo94NbA5cX6Np5
zUxhNfFh5Zl+dJEjczs\n9cDrgAfqnqu0T0Wcj1sJO3ZUKuKyLB8Cek3uLYHb3P3lHttLi3Q5nY+fAPuZmbUw1kcImwVIBnTv\ndCbMbBIwDzjA3eckHGc1wvvE67r7H1ONI81REmciXtqeCnw5cRp/nrBWWQWcCSVxRuKeWjcB33b3\ncxL0vx0wE9jU3f+v6f4lDRVxZszsLYSFCTOa3BPLzFYFfgEc5e4/bqpfSU+X05lx93uAzwD/bWYb\nNtGnma0CXElYJaUCzoyKOEPufj5hRdK1ZrZbP32Z2XRgNnAVcFwDpyct0+V0xuIuHOcCNwCfc/cn\na7RdgTCJdTBwZPzFIBlSEWfOzFYk3AiyDnAhcDZwi7u/uJhjlyE8FeIgwhMhANavsuG8DC8Vcebi\n5nYvACcBTxB2ANkAuB/4FWEl0nPA2sA04HHCDPRc4BLCXlr3t37i0hgVcebM7OeEZYaTOksVzWwy\noWDXA74PfAG4mXAr5dNdbR3A3du4C0wSURFnrCuFv+LuXxrnGCe873vPYr62DXAjSuOsqYgztrgU\nXswx4xZx19eVxhnTW0yZiim8EyGF+/lNvG3sb6NGTkxapyTOVJUUjsctMYm7jlEaZ0pJnKEGU7hD\naZwxJXGGqqZwPHbCJO46TmmcISVxZhKkcIfSOFNK4szUSeF4fKUk7jpWaZwZJXFGEqZwh9I4Q0ri\njNRN4dimchJ3Ha80zoiSOBMtpHCH0jgzSuJM9JLCsV2tJO5qozTOhJI4Ay2mcIfSOCNK4gz0msKx\nbe0k7mqnNM6AknjIDSCFO5TGmVASD7l+Uji27ymJu9oqjYeckniIDTCFO5TGGVASD7F+Uzj20XMS\nd7VXGg8xJfGQGoIU7lAaDzkl8ZBqIoVjP30lcVcfSuMhpSQeQkOUwh1K4yGmJB5CTaVw7KvvJO7q\nR2k8hJTEQ2YIU7hDaTyklMRDpskUjv01ksRdfSmNh4ySeIgMcQp3KI2HkJJ4iDSdwrHPxpK4qz+l\n8RBREg+JDFK4Q2k8ZJTEQyJFCsd+G03irj6VxkNCSTwEMkrhDqXxEFESD4FUKRz7bjyJu/pVGg8B\nJfGAZZjCHUrjIaEkHrCUKRz7T5LEXX0rjQdMSTxAGadwh9J4CKiIB+uy+OfxKTo3sw3jXzcys7Gm\n+3f3OfGv9zXdt1SnIh6Q1ClsZlsBd8dPLwQ+2vQYkdJ4wFTEg5M0hYFHgQXx788Cj6QYRGk8eCri\nAWjptfCjwF/i35cHbk00DiiNB0qz0wOQeka6a5zZhAJ7yt1XTTVOHEsz1QOiJG5ZyzPS18c/b088\nDiiNB0ZF3L7Ur4W73RL/vC71QHptPDgq4ha1mcJmZsC8+Ol9ZtbG91ppPAB6TdyiFu7O2gDYF9gS\n2AKYDBiLJrhuA+YCFwO3JLxDTK+NW6QkbknKFDazd8VfEDcQCvcMYDNgRXef7O6vAdYFvgb8CfgB\n8Esz2ycmdpOUxi1TErck0a4dKwEnATsDxwIXu/uLFdpNAnYBTgR+Bxzs7r9t4pxi/0rjFimJW5Ai\nhc2sAO4AOgsczq9SwADu/rK7XwFsBcwBbjOzHZs4r0hp3CIlcQsS7GBZANcCx7n7uQ30twMwE9jH\n3a/ut7/Yp9K4JUrixJpO4XgJfSVwbBMFDODus4APABfEybEmKI1boiROLEEKnwMsdPdP9H1yf9v3\n4cCHgO3dfWED/SmNW6AkTihBCr8LeDdwVL99jeM0wttRn2qoP6VxC5TECSVI4Z8D33P38/s+ufHH\n2By4FFjX3f8y0fEV+lMaJ6YkTiRBCm8ATCPcqJGMu99OWLb4Dw11qTROTEmcSIIUPgGY7O6fq9qm\nKIq9ge8BbyzL8qkaYx0AvM/d96p9oovvT2mckJI4gUR3Z23JolVJVe0DPECYea7jBuBtNdssidI4\nIRVxGo2uVIq3Rk6nxsL+oihWJhT+Z
4G9aw75IDDFzFar2W6xtMIpLRVxwxKl8IqE3Tkeq9Hmg8DP\nCO8pv7koiqlVG8bz/hWwXp2TnIDSOBEVcfNSrBd+NfBCzV8K+wAXlGW5EPgh4f3fOl4AlqvZZlxK\n43RUxA1KuFLpJeBVVQ8uimINwn3RJxdFcQfwHuDDNcdcNo7bJKVxAiriZqXateNZADN7fcXj9wa+\nVZbltLIsNwMKYOWiKOpcHq9D2GyvMUrjNFTEDUm5XtjdXybsk/XWik32Br7T+aQsSwfOo2Iam9mb\ngDHg4XpnWonSuGF6n7ghLezacTLwjLuf0HTfixlrT8Ia4/ck6l/vGzdISdyAlvbOmgl8tKW9sg6M\n46WiNG6QkrgBbewjHd8rnktYgnhlijHiOGvHcdZy9z8nHEdp3BAlcZ/a2sEy9v1N4EQzWzbVOIQt\ne85OWcCR0rghSuI+tfU0hziWEfaQvs7dv5ig/z2AbwDTWihipXFDlMR9aHkf6Y0Ir1O3BY6IW+o0\n2f/6hF0yD2yjgCOlcQOUxH1o6bXwdMIl7gzCnVuXAqcCFwAfcPe6iyIWN8b6wNXA19z9rH77qzm2\n0rhPSuIetfB84WnxgWjXE7aknQTcDHzI3f+HcFvlj8zssH5mrOMl9A3AiW0XcKQ07pOSuEctvC/8\nLeATwDKELXMeArZw92e6jimAc+PXP+3ud9Tof21Cwm9JuISe3dS516U07o+SuActpPAYcAihgB34\nIzCju4AB3L0Etgd+BFxmZjea2QFmtt7inuxgZm8ysz3N7DLC20gPEyaxBlbAkdK4D0riHqRMYTPb\nBLgnfnoocDDhEnre+K1eKfzdgP0JC/pXJCwnnE9YPLEOYVHDrYQJsgtanMCakNK4dyrimmIKv0BI\n4S813PdxwFfip8u7+wt99LU6YT3wcoTVSI8AD6eeRe+VmW0D3Ahs7O73D/p8cqIirinRM5XGgGcI\nD0M7xd2PaKLf3CiNe6PXxDUkeqbSJsACQgFvvbQWcKTXxj1QEteQYAfLxi6fR4XSuD4lcUVNprCZ\njZnZ84QCPsXdTQX8CqVxTUriippK4b+afd7a3W9u4vxGidK4HiVxBU2lcLx87hTw8irgcSmNa1AS\nV9BvCmv2uT6lcXVK4iUws73M7AlCAf9bjwWs2efevAPAzJ40sy8P+mSGmZJ4CczsCsJ2rwBPAFPj\npnVV22v2uUdmdh3hllKAR9x9zUGezzBTEo/DzJZh0Q/Ri8C9Ndpq9rl/c4HObaGrmJmKeBwq4vFt\nASyMf/8DYe3uhCmsy+fGHE2YBFxA+D7sNNjTGV4q4vHtSijEPwO7uvvTEzXQ7HNzPDzg/H3An4AV\ngD0He0bDa6l6TRyX561LeMLgWwgrfZyQtHcAt7r7E/HYu4BNgQPc/bwJ+tXscyJm9nbCpgUvACu5\n+0Izey1hI/3pwGqEje7nAyXhMnyeuy8cp8uRs1QUcXyiwcfjhxOW491FKDwj/CBsTviheISw0P5k\n4Lvu/o8T9K2bNxIzs8OA/wA+T5honA7cSfg+PkrYFGEFYOP4tVWA7wNnTLSEcxSMDfoEUjKz5QmT\nSwcBFxIui+9ewvEGvBM4grB873Yzs/HeWtLsc2seJFxWvx84CfiJuy8Y72AzW4uwDnuWmc0CDnP3\nJ9s40YFw95H8IGw7UxKKd9Ue2m8M3ELYQG6Nv/raGPA8IdX/fdD/1lH9AKYQniH1ILBjD+2XJ2zB\n+ziwx6D/Pcn+Pw36BBJ983cBngQ+2Gc/Y8CXCPtbvTn+t01i8Tqw1aD/raP6AbyOsDHgecCKffa1\nDeFl0icH/e9K8TFyl9Nmth1wPuE374399OVhhvSEeNfWVWY2EzgqflmXz4mY2WTgcuAXwGc8VmKv\n3H2OmW0PXGtm8939uw2c5tAYqYmtOGt5N/Bxb/h5RWZ2PPAvaPY5OTM7BXgD8OF+C/iv+t0QmA1s\n5
yM04TVqRfwdYL67H5Kg7zFgDvCf7n5m0/1LEBPzAmBTd/9Dgv4/BXyEUMgj8TbUyBSxmb0NuATY\nyN2fSzTGJoTN3P8u1RhLs/juwB3Al939kkRjTAJmAeeOymX1KN2xdQhwWsricvd7CQ802zfVGEu5\nbQm7c/441QAebp39OuHnZSSMRBGb2cqE9xDPrXJ8URTvLopiVtfnU4uiKIuieE2F5qczQj8AQ+YQ\nwg0aE14eFkVxXFEUX+v6fFJRFHcURfH3Fca5AljNzLbo41yHxkgUMWHt6c1e8Q39siyvBh4uimL/\n+J9OBo4ty/JPFZpfA6wZ93WWhsRL6Z2BH1ZscjKwV1EUU+PnBwI3l2V510QN42vhH8XxsjcqRTyd\ncM9sHUcCxxRFsQcwpSzLSj888XLstjimNGctYIG7P1rl4LIs5xPulvtqURSTgc8BdZ7ZPJcR+R6O\nShFvRiisysqyfIrw2/wiwuNS6riNcK+11GBmM8xsjpl9w8w+YGbrdD0zajPg9ppd/hewEXA28N2y\nLH9fo+1tcczsjcrNHisDdb6BHdOITxsEflOj3ZPAoWa2ew9jLs3eRrzTDXiO8PPnZnYvYSXZM0to\n+zfKsvSiKL5AuLnnoJrn8iTh5yZ7o1LERvjhqKwoii0Jt1DuCFxdFMUVZVlWndl+mXD5t1ats5SO\nSYT7op8jPDh9U+B31H9JBOG+6sfKsnyxZruXGZEr0ZH4RxB+GFaqenBRFGOEWebDy7J8jDCrXWcz\ntpUI29eaPqp/EB6MPp+whPAUwtLQjQjLCI8mPMq1LSsBz7Y4XjKjUsT3EC6Nq/oscF1Zlp19s04B\ndiqKYtOK7acRbu+UGtz9AsJihi3c/TPufpG7PxDfUroHqPL2UFOmsWgdeNZG4o4tM9sH2Mvd92pp\nvEeA7d39wTbGWxrEjQmfBtb0ClshNTDeCcAy7n5s6rFSG5Ukng3MiJsAJGVmnbSoMxEmE4jv3c4h\n7G3Whl0J2/5kbySK2N0fJizg/1ALwx0CnNXk6hp5xVm0cDdcvM9+FeCq1GO1YSSKODod+HS8LEvC\nzFYl/KI4J9UYS7mfAuu2cDvkEcCZo7KKaZSK+HLCPkyHJxzjNMJSxMcSjrHU8rBv1rHAOWa2bIox\nzGwXwkKL01P0PwgjMbHVYWbrATcB72h60beZfZBwm9/m7j6/yb5lkXgH18+Aue5+fMN9v5awy+lB\n7n51k30P0iglMe7+AOH9xsvNbI2m+jWzzm/u/VTAacW5hoOBg8xs/4mOr8rMVgAuAy4epQKG0blj\n6xXufm5cmni9mb2330Q2s12B7wH7uvsvGzlJWSJ3fzRe9l5lZlOA0/uZSDSz1QhrlOexaI+0kTFS\nSdzh7icBXwVuMLMje5nsMrMpZnYGcCbwfnf/edPnKeNz9/uAGYRUviw+AKC2+DLoLuBawt5rlZ9q\nmYuRLGIIiQxsDewOzDWzA6u8j2xmq5rZMcB9wKsIez3NTnu2sjju/mvCoolbgTvN7Otmtu5E7eJT\nKXc3s6uAEwg7nx43igUMIzaxtThxT6X3Ap8krJ65nnCj/Z2EVTOTgFUJz/bZIh5zCeESrpcb8iUB\nM1sf+CfgAMJl8S0seozLAsJztTYhrBGeAfyWMI8x093rLo7IysgXcbf4eI9tWfRAtSmE1SyvPFAN\nmO0JdlmUZsSrqe0I38POA9WWJTy9ch7he3iTu98/sJNs2VJVxCKjaGRfE4ssLVTEIplTEYtkTkUs\nkjkVsUjmVMQimVMRi2RORSySORWxSOZUxCKZUxGLZE5FLJI5FbFI5lTEIplTEYtkTkUskjkVsUjm\nVMQimVMRi2RORSySORWxSOZUxCKZUxGLZE5FLJI5FbFI5lTEIplTEYtkTkUskjkVsUjmVMQimVMR\ni2RORSySORWxSOZUxCKZUxGLZE5FLJI5FbFI5lTEIplTEYtkTkU
skjkVsUjmVMQimVMRi2RORSyS\nORWxSOZUxCKZUxGLZE5FLJK5/weF62ecIGohUQAAAABJRU5ErkJggg==\n", 316 | "text/plain": [ 317 | "
" 318 | ] 319 | }, 320 | "metadata": { 321 | "tags": [] 322 | }, 323 | "output_type": "display_data" 324 | } 325 | ], 326 | "source": [ 327 | "#@title # Game 3\n", 328 | "#@markdown This game actually requires some action. Notice the backdoor path X <- B -> Y. This is a confounding pattern, is one of the clearest signs that we'll need to control for something, in this case B. \n", 329 | "pgm = PGM(shape=[4, 4])\n", 330 | "\n", 331 | "pgm.add_node(daft.Node('X', r\"X\", 1, 1))\n", 332 | "pgm.add_node(daft.Node('Y', r\"Y\", 3, 1))\n", 333 | "pgm.add_node(daft.Node('A', r\"A\", 2, 1.75))\n", 334 | "pgm.add_node(daft.Node('B', r\"B\", 2, 3))\n", 335 | "\n", 336 | "\n", 337 | "pgm.add_edge('X', 'Y')\n", 338 | "pgm.add_edge('X', 'A')\n", 339 | "pgm.add_edge('B', 'A')\n", 340 | "pgm.add_edge('B', 'X')\n", 341 | "pgm.add_edge('B', 'Y')\n", 342 | "\n", 343 | "pgm.render()\n", 344 | "plt.show()" 345 | ] 346 | }, 347 | { 348 | "cell_type": "code", 349 | "execution_count": 28, 350 | "metadata": { 351 | "colab": { 352 | "base_uri": "https://localhost:8080/", 353 | "height": 51 354 | }, 355 | "colab_type": "code", 356 | "id": "l0GI2mM3WQeI", 357 | "outputId": "7ca879e8-e678-42d9-9d03-404ffb67c51c" 358 | }, 359 | "outputs": [ 360 | { 361 | "name": "stdout", 362 | "output_type": "stream", 363 | "text": [ 364 | "Are there are active backdoor paths? True\n", 365 | "If so, what's the possible backdoor adjustment sets? frozenset({frozenset({'B'})})\n" 366 | ] 367 | } 368 | ], 369 | "source": [ 370 | "graph = convert_pgm_to_pgmpy(pgm)\n", 371 | "inference = CausalInference(graph)\n", 372 | "print(f\"Are there are active backdoor paths? {inference._has_active_backdoors('X', 'Y')}\")\n", 373 | "adj_sets = inference.get_all_backdoor_adjustment_sets(\"X\", \"Y\")\n", 374 | "print(f\"If so, what's the possible backdoor adjustment sets? 
{adj_sets}\")" 375 | ] 376 | }, 377 | { 378 | "cell_type": "code", 379 | "execution_count": 29, 380 | "metadata": { 381 | "cellView": "form", 382 | "colab": { 383 | "base_uri": "https://localhost:8080/", 384 | "height": 258 385 | }, 386 | "colab_type": "code", 387 | "id": "XP2ORZw8EtyZ", 388 | "outputId": "9f267191-b408-483b-87d9-754c949a7d72" 389 | }, 390 | "outputs": [ 391 | { 392 | "data": { 393 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPEAAADxCAYAAAAay1EJAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAFO1JREFUeJzt3Xm4HUWdxvHvLyEhEBZBFmUXwUJA\niAYiQsJqwOCgMDFjCMIAThACbgjiDAEJo44wZBQHgyCyiDygKDKiuATFkAxgSCCsk1ZZRILKkgEF\ngmw//6i+5opJbvc93eec6vN+nuc+z73J6aq6t897urq6u8rcHRFJ15BON0BEWqMQiyROIRZJnEIs\nkjiFWCRxCrFI4hRikcQpxCKJU4hFEqcQiyROIRZJnEIskjiFWCRxCrFI4hRikcQpxCKJU4hFEqcQ\niyROIRZJnEIskjiFWCRxCrFI4hRikcQpxCKJU4hFEqcQiyROIRZJnEIskjiFWCRxCrFI4hRikcQp\nxCKJU4hFEqcQiyROIRZJnEIskjiFWCRxCrFI4hRikcQpxCKJU4hFEqcQiyROIRZJnEIskjiFWCRx\nCrFI4hRikcQpxCKJW63TDWgXMxsKjAf2BEYDOwJrAQ4sBRYBC4Efu/uCTrVTVs3MtgUOIu7DtwEb\nE9/Hy4DFxH14C3Cduz/fqXa2k7l7p9tQKzNbB5gGHAs8DlxP3NF3Ak8DBmxEfEPsAvwj8ATwZeBy\nd3+5A82WfszMgAOBjwBvBb4DzCfuxyXAS8BIYAdiuPfNX3cJcK67L+lAs9um0SE2s/HARcA84Ivu\nfluBbYYC7wJOAYYBR7n74lobKitlZhsDs4Dtgc8BVxc5wprZNsDxwAeATwKXekPf7I0MsZkNAc4B\nJgJT3f0ngyzjWGAGcLK7X1ppI2VAZrY3cBXxiDpjMN1jM9sZuBT4DTDF3Z+rso3doHEhzsN3EbAN\n8B53f6rF8rYDfgx83t3Pr6CJUkDei7oCmOzuP2uxrGHA14AtgQlNC3ITQ/xZYG9gf3d/tqIy3wDM\nBU5w92urKFNWzsx2An4KHOLu8yoqcwjxiLxOXm5j3viNCrGZvQO4BtjJ3R+vuOzdiQMqlZcty+VH\nzV8A57n7xTWUPR/4krtfUmXZndSYEJvZCOJlounu/u2a6jgH2MzdJ9dRvoCZTQd2B95dx9EyP0e+\nARjVlFHrJoX4COBwdx9fYx1rAA8Ce2vEunpmNhJ4GNjF3R+ssZ4vAC+4+yl11dFOTbpjaxrw33VW\n4O7LiAMkx9ZZTw+bAsytM8C5WcBRee8teY04EpvZ9sQR5K3K3JwRQjgU+Drw+izLnihY1xbEbvuG\nuhGkWmZ2M3Cmu/+oyOtDCNsCXwQ2BIYCNwMnZVn25wJ1zQYudPerW2hyV2jKkXg3YM4gQjUFuB94\nX9EN3P1h4EkglKxLVsHMVgdGATcVeX0IYShxoPHsLMvGEO+2Azi9YJU/I75vkteUEI8m3oJXWAhh\nfWAM8Ang0JL1LWT5m0aqsS
Nwf4lruOOBxVmWzQHIssyJd2adWXD7hcT3TfKaEuLtgHtLbjMJ+D7w\nI2DbEMKmJba9Dx2JqxaIf9eitiOe1vxVlmXLinSlc/fmZSSvKSFeAyh7Y8cU4Mosy14Gvg28v8S2\nzwBrlqxPVq3sPnTiefBgNWYfNiXEL1HiscoQwmbA24GZIYRFxAceylz7HQa8WKqFMpBS+5D42OGY\n/v8QQlg9hLBjwe0bsw+bEuI/AGW6w4cCX86ybOcsy0YRu3LrhxDeWHD7TYHHSrZRVu0xyu3D2cCW\nIYSDAEIIQ4CzKN6jasw+bEqIb6fcIMWhxCdjgL8OilxG8aNx6YE0GdDtwNvyZ4cHlGXZK8ABwDEh\nhAXEx02fBj5dsL7G7MOmXCfeHzjN3ce1oa7hxJlANmv1CSn5W2b2W2Bfd/9VG+r6CpC5+xfqrqtu\nTTkSzwXebGZbt6Gug4EFCnAt/of4EH+t8ttnJwLX1V1XOzQixPntkJcBH2pDdccTp+6R6p0PTM17\nO3WaBCx091/XXE9bNCLEufOBo81sk7oqMLN9gTcAeqa4Bu5+L/FacW0fxvn90v8KfKmuOtqtEefE\nfcxsBnHA4qCqH2Mzs7WAu4Hj3f36KsuW5fKZVOYBY9z9gRrK/zxx1pdJTZkYoGkhHg7cBlzg7rMq\nLNeIs0K4ux9ZVbmyYmZ2EnAI8M78VKmqcvcjTvmzs7v/oapyO61RIYa/znI4B/iUu19eQXl9c3Yd\nBLzB3Z9ptUxZtXzG0TuA54B9qgiymY0jPjAxyd3ntFpeN2nSOTEA+WDFeOBzZjajlUESM3stcbbF\ntwMbAKVnzZRBOQ14C/G2yBtauepg0QeJ0zYd1rQAQwNDDODu9xEfM9sFmG9mpZ44MrMhZjaReA78\nKLAr8G7gHfkzr1ITMzuDeMPGDOIE8N8l7sMPl/1Azic4/CHxisK+7j674uZ2B3dv7BdxdYcjgUeI\n15IPBdZbxWs3BT4GZMQVIsa96jUHEm+8v7nTv1sTv4Az8r/vGa/690C8zfJ3xEcN30R+KriCMtYE\nJhCvOT8JTAeGdfp3q/OrcefEK5LPcngQ8dLF7sR7Zu8EXgZeAdYmLuMylNhlnkUM6t/9cczsQOAH\nwC3uvntbfoEe0P8I7O5nrOQ1OwDHAe8h7rM7iPfNr058gGJ7YGviI4oXA1d6RdMWd7OeCHF/+aDJ\nm4jnXN8kTnz3ceIb4rcrCu4KylCQK1QkwCvYpm/9rCOIPawpxB7UPe7+Qj0t7U49F+L+zMyB29x9\nzIAv/vttFeQKDCbAr9r+RGCmuxd6cKKJGjmw1Q4eb/jQYFcLWg2wRApxCxTkwVOAq6MQt0hBLk8B\nrpZCXAEFuTgFuHoKcUUU5IEpwPVQiCukIK+cAlwfhbhiCvLfU4DrpRDXQEFeTgGun0JcEwVZAW4X\nhbhGvRxkBbh9FOKa9WKQFeD2UojboJeCrAC3n0LcJr0QZAW4MxTiNmpykBXgzlGI26yJQVaAO0sh\n7oAmBVkB7jyFuEOaEGQFuDsoxB2UcpAV4O6hEHdYikFWgLuLQtwFUgqyAtx9FOIukUKQFeDupBB3\nkW4OsgLcvRTiLtONQVaAu5tC3IW6KcgKcPdTiLvUq4K8zMz+aGZ7taNuM9vSzJaa2WIU4K63Wqcb\nICvn7teb2bnAR4ERwFVmtk2R9YXMbGPiukQjgBeAJcBvBlqmJl9Q/SpgHWA9IFOAu5tC3P3GAC8C\nw4B1gf8iLgz3N8xsNeAfgMPzbUYCvwSeB4YDWwHDzWwBcDVxsbHnVlDfscR1qoYSF5zbqNpfR6qm\n7nT3O5C4QPZzwBrA4f271WY21Mw+SlwY7iTge8BewGvdfTd339vdd3f3TYjhvBB4L/Cwmf2Hma3Z\nr6ytgP8kfgA8C8wnrhEsXUwh7nLu/pS7TwYmAUuJy3heZWYjzSwANwGHAAe5+1h3v8zdH1hR
t9nd\nf+fu17j7e4gLsG8B3GlmY/Nu9JXE9X2fJX4g7OHuv2nLLyqDphAnIh/oeiOxK/w64rKs84jB29fd\nF5Us7yF3Pww4OS/zG8BuwK3ADu7+lSLLvErn6Zw4Ie7+FDDZzBYBpwMHuPvcFsu81szuIXadrwAO\nV3jToiNxYsxsc+BEYEKrAe7j7r8mDobtB+xRRZnSPgpxQvLz1q8C57r7nCrLzoN8HHBJ/8Eu6X4K\ncVqmABsAZ9VRuLtfS+xWT6+jfKmHQpyI/Ch8IjDd3V+qsapTgWN0NE6HQpyOMcBrgJ/UWYm7PwTc\nAkyusx6pjkan0/FPwGXu/kqRF4cQtgLuBhYCTrz98uQsy+YV2PwS4Bjg4sE1VdpJR+J07EI8QpaR\nZVm2d5Zl+wCnAKcV3O5WYHTehZcupxAnwMyGEG9/vL2FYjYmPgQxIHd/FHiJeEeXdDl1p9OwNoC7\nP1lyuxBC+DmxK70pcECJbR/Mt9Ftl11OR+I0DCc+TlhWX3d6N2A88M0QQtEP7hfzeqXLKcRpeB4Y\n0co5apZli4FlwOYFNxkB/Hmw9Un7KMRpeIYYwE0GW0AIYX3g9RQ4L84/LN4E3D/Y+qR9dE6cAHd3\nM1sIjKbg4FSu75wY4pH1hCzLinTLtwb+5O6PlWupdIJCnI75wDjiQ/8DyrLsIfIBsUEYB9w2yG2l\nzdSdTsc3gCPMbPU21DUV+Hob6pEKKMSJcPdfAncC76uzHjMbRRz8+kGd9Uh1FOK0nAV81szWqaPw\n/KaSc4Fzan7IQiqkECfE3X8K3ECczK4OJxBnufxyTeVLDRTi9HwC2N/MjqqyUDPbk3hv9dHu/nKV\nZUu9NDqdGHd/2szeBdxoZrj7Ja2WaWZ7A98CJufn3pIQHYkT5O4ZsA9wupldMNhzZDMbZmbTiQF+\nf95dl8QoxInKgzyKuA/vMrMPFL38ZGZDzGwC8ZHDscBod7+xvtZKndSdTpi7Pw1MNbP9gE8CM83s\ncmAucTKAJX3Tz5rZa4l3fO0G/DPw/8BM4nIumqI2YQpxA+Td4J+a2bbAB4jrKY0GRprZ88R1nBy4\ngxjuQ919fqfaK9VSiBvE3X9FXIq07yGGtYjLvrwAPFN0ah9Ji0LcUHkX+U/5lzSYBrZEEqcQiyRO\nIRZJnEIskjiFWCRxCrFI4hRikcQpxCKJU4hFEqcQiyROIRZJnEIskjiFWCRxCrFI4hRikcQpxCKJ\nU4hFEqcQiySuJ0Ns0cz8x13N7EMdbZAMipkdAMzIv7/czNbocJM6oidDnPuXft+/v2OtkFbsA4zM\nv5/S7/ue0pMhzieRuyv/8SVgTgebI4M3n+UTAS519yc62ZhO6ckQ5+YALwPPArd1uC0yOAuJc2pD\nnFO7J/VyiG8jBngE8c0g6XkYeIUe7031cogXEidXf87d/9Dpxkh5+WnR3cQ1lRd0uDkd08shXgI8\nAyzqdEOkJXMAo4d7Uz21AkS+tMnWxHWKdgT+CIwws1OJYV7o7r/vYBOlADN7DfA24n7cjnhadKKZ\nZcQj8uJeWijdemFBPDPbBJiafznxU/su4simARsAbyW+KR4BZgFXuLuWQOkSZjYcOBiYRtxPdxL3\n46PE8+IRwPb5/20AXA6c7+6LO9LgdnL3xn4BawDnAEuJwXzLAK83YD/gO8CTxNUFrdO/R69/Ae8m\nDmLdCEwChg3w+i2AzwC/B64CNuz071DnV2OPxGY2hvhpfAfwYXd/vOT22wOXErvcR7r7I5U3UlbJ\nzNYGzgPGAR/0kguh53dwnUlc7vU4d7+2+lZ2XiMHtvLb8X4ATHf3yWUDDODu9wG7AzcB8/K1f6VN\nzGw94Ib8x53KBhjA3Ze5+8nAROA8MzuuyjZ2i8YNbJnZWOAbwMHu/r+tlOXuLwFnmtnvgdlmNlZH\n5PqZ2ZrA9cAtwMe9xe6iu99sZnsCN5rZMne/tIJmdo1G
dafzUcu7ganu/qOKyz4d2BMY3+qbSlbN\nzM4FXgdMrvJvbWbbAfOAsd6gAa+mhfgSYJm7T6uh7NWAm4GvufsFVZcvUX7EvJI4CLm0hvKPJ54j\nj/WGXIZqTIjNbFfgGuDN7v5MTXXsQDxH3rKuOnpZfh1/ETDD3a+pqY4hwM+Bi5vSrW7SwNY04Lw6\nw+Xu9xLvEDqsrjp63B7A6sB366rA3V8Bzia+XxqhESE2s/WBQ4CLi7w+hPDOEMLP+/28aQghCyGs\nU2DzWTToDdBlphFv0BiwexhCmB5C+Fy/n4eEEBaFEHYqUM8PgY3MbJcW2to1GhFi4nXEXxS9lJRl\n2Q3AwyGEI/J/mgmcmmXZHwts/jNgczPbeHBNlRXJu9L7A98uuMlMYGIIYdP856OAX2RZdtcqtgEg\nPxf+Tl5f8poS4tGUf4rlROBTIYSDgbWzLCv05sm7Y7fndUp1tgBedPclRV6cZdky4N+Bz4QQ1gRO\nAk4rUd8CGrIPmxLiUcRgFZZl2RPET/NvAieUrO924r3WUp1RlH+w/wrgzcBXgUuzLHusxLa353Um\nrykhXh8oswP77Aw8BJQ9N3ocWG8Q9cnKld6HWZY58G/Euba+WLK+x/M6k9eUEBvx6aTCQghjgB2I\nb4AZIYS1Smz+Cs3523WL0vsw9wDwaJZlfy65XWP2YSN+CeLD/esWfXEIYTXiKPNHsix7lDiqPaNE\nfevmdUp1Su3DCqzL8kn2ktaUEN9D7BoX9QlgTpZl9+Y/nwuMDyG8peD2OxNv75Tq3AMUuTxUlZ3z\nOpPXiDu2zGwKMNHdJ7apvkeAPd39gXbU1wvMbCjwFLC5uz/VhvrOBIa6+6l111W3phyJ5wF7tWMF\nADPrO1o8WHddvSS/dnszMKFNVU4A5raprlo1IsTu/jBxIvF2rOQwDbhQTzLV4kLacDdcfp/9BsDs\nuutqh0aEODcL+GjeLauFmW1I/KC4qK46etz3gK3bcDvkx4ALmvIUU5NCfD1xKp2P1FjHecRHER+t\nsY6e5e4vAqcCF5nZsIFePxj5rC97ED/0G6ERA1t9zOyNwK3AuKof+jazScTb/N7q7suqLFuWy++h\n/j6wwN0/XXHZryHOcnq0u98w0OtT0aQjMe5+P3AKcL2ZbVZVuWbW98l9uAJcr3ys4RjgaDM7YqDX\nF2VmI4HrgKubFGBo4Bxb7n5x/mjiTWZ2YKtHZDObAHwdOMzdtfBaG7j7krzbOzuf8XJWKwOJZrYR\n8RnlxcDJFTWzazTqSNzH3c8hzjs818xOHMxgl5mtbWbnAxcAh7j7T6pup6xcPtvoXsSj8nX5AgCl\n5adBdxHnrJ6aP4XWKI0MMcQjMrAb8F5ggZkdVeQ6spltaGafAu4DhhPneppXb2tlRdz918CuxJUe\n7jSzs81s64G2M7PVzOy9ZjabOO/0we4+vYkBhoYNbK1IPqfSgcBxwNuJc2QtIC4D8jTxg2xD4to+\nu+SvuYbYhevZlfa6jZltQ1yR40hit3g+MdxLgBeJK1zuQHxGeC/gt8RxjG+5e9mHI5LS+BD3Z2Zb\nEC8v9C2otjbxaZal5AuqAfPqmGVRqpH3psYS9+FoYCPiQuPPEcO9ELjV3f+vY41ss54KsUgTNfac\nWKRXKMQiiVOIRRKnEIskTiEWSZxCLJI4hVgkcQqxSOIUYpHEKcQiiVOIRRKnEIskTiEWSZxCLJI4\nhVgkcQqxSOIUYpHEKcQiiVOIRRKnEIskTiEWSZxCLJI4hVgkcQqxSOIUYpHEKcQiiVOIRRKnEIsk\nTiEWSZxCLJI4hVgkcQqxSOIUYpHEKcQiiVOIRRKnEIskTiEWSZxCLJI4hVgkcQqxSOIUYpHEKcQi\niVOIRRKnEIskTiEWSZxCLJI4hVgkcQqxSOL+AgEKR0Xy3CYuAAAAAElFTkSuQmCC\n", 394 | "text/plain": [ 
395 | "
" 396 | ] 397 | }, 398 | "metadata": { 399 | "tags": [] 400 | }, 401 | "output_type": "display_data" 402 | } 403 | ], 404 | "source": [ 405 | "#@title # Game 4\n", 406 | "#@markdown Pearl named this particular configuration \"M Bias\", not only because of it's shape, but also because of the common practice of statisticians to want to control for B in many situations. However, notice how in this configuration X and Y start out as *not confounded* and how by controlling for B we would actually introduce confounding by opening the path at the collider, B. \n", 407 | "pgm = PGM(shape=[4, 4])\n", 408 | "\n", 409 | "pgm.add_node(daft.Node('X', r\"X\", 1, 1))\n", 410 | "pgm.add_node(daft.Node('Y', r\"Y\", 3, 1))\n", 411 | "pgm.add_node(daft.Node('A', r\"A\", 1, 3))\n", 412 | "pgm.add_node(daft.Node('B', r\"B\", 2, 2))\n", 413 | "pgm.add_node(daft.Node('C', r\"C\", 3, 3))\n", 414 | "\n", 415 | "\n", 416 | "pgm.add_edge('A', 'X')\n", 417 | "pgm.add_edge('A', 'B')\n", 418 | "pgm.add_edge('C', 'B')\n", 419 | "pgm.add_edge('C', 'Y')\n", 420 | "\n", 421 | "pgm.render()\n", 422 | "plt.show()" 423 | ] 424 | }, 425 | { 426 | "cell_type": "code", 427 | "execution_count": 30, 428 | "metadata": { 429 | "colab": { 430 | "base_uri": "https://localhost:8080/", 431 | "height": 51 432 | }, 433 | "colab_type": "code", 434 | "id": "CBaGzLKSFmnQ", 435 | "outputId": "95c1c6a5-dbe3-4768-8f37-94d9547e9db6" 436 | }, 437 | "outputs": [ 438 | { 439 | "name": "stdout", 440 | "output_type": "stream", 441 | "text": [ 442 | "Are there are active backdoor paths? False\n", 443 | "If so, what's the possible backdoor adjustment sets? frozenset()\n" 444 | ] 445 | } 446 | ], 447 | "source": [ 448 | "graph = convert_pgm_to_pgmpy(pgm)\n", 449 | "inference = CausalInference(graph)\n", 450 | "print(f\"Are there are active backdoor paths? 
{inference._has_active_backdoors('X', 'Y')}\")\n", 451 | "adj_sets = inference.get_all_backdoor_adjustment_sets(\"X\", \"Y\")\n", 452 | "print(f\"If so, what's the possible backdoor adjustment sets? {adj_sets}\")" 453 | ] 454 | }, 455 | { 456 | "cell_type": "code", 457 | "execution_count": 31, 458 | "metadata": { 459 | "cellView": "form", 460 | "colab": { 461 | "base_uri": "https://localhost:8080/", 462 | "height": 258 463 | }, 464 | "colab_type": "code", 465 | "id": "ZAbSVPvZFxZH", 466 | "outputId": "88a7fff6-0744-4dda-b5e9-8d93dcea84d7" 467 | }, 468 | "outputs": [ 469 | { 470 | "data": { 471 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPEAAADxCAYAAAAay1EJAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAFdFJREFUeJzt3Xm4HFWdxvHvLxuBEBBkUXYQPKyC\nJgQEwmpggkZgkDGEZYSZIAbcEMSRgASXcSGjcTAIYgCRBxBERxRRghqSQQgJJmyTUllEgrJF0EBY\n85s/Tl1yxSS3+3ZVdZ/q9/M8/Ty5N13nVN/qt+tU9VnM3RGRdA1o9w6ISGsUYpHEKcQiiVOIRRKn\nEIskTiEWSZxCLJI4hVgkcQqxSOIUYpHEKcQiiVOIRRKnEIskTiEWSZxCLJI4hVgkcQqxSOIUYpHE\nKcQiiVOIRRKnEIskTiEWSZxCLJI4hVgkcQqxSOIUYpHEKcQiiVOIRRKnEIskTiEWSZxCLJI4hVgk\ncQqxSOIUYpHEKcQiiVOIRRKnEIskTiEWSZxCLJI4hVgkcQqxSOIUYpHEKcQiiVOIRRKnEIskTiEW\nSZxCLJI4hVgkcYPavQNVMbOBwBhgX2AEsDOwNuDAEmABMB/4mbvPa9d+yuqZ2XbAOOIxfAewMfF9\nvAxYRDyGvwZucPcX2rWfVTJ3b/c+lMrM1gEmAScDTwI3Eg/0QuBZwICNiG+IkcA/A08B3wCucPdX\n27Db0ouZGXAo8BHg7cD3gbnE47gYeAUYBuxEDPeB+fMuBaa5++I27HZlah1iMxsDXALMAb7m7nc2\nsM1A4J+AM4HBwAnuvqjUHZVVMrONgenAjsAXgGsbOcOa2bbAKcCxwCeBy7ymb/ZahtjMBgDnA0cC\nE9395/0s42RgCnCGu19W6E5Kn8xsf+Bq4hl1Sn+ax2a2K3AZ8Adggrs/X+Q+doLahTgP3yXAtsB7\n3f2ZFsvbHvgZ8EV3v7CAXZQG5K2oK4Hx7v6LFssaDHwb2BIYW7cg1zHEnwf2Bw529+cKKnNrYDZw\nqrv/sIgyZdXM7G3ALcAR7j6noDIHEM/I6+Tl1uaNX6sQm9k7geuBt7n7kwWXvRfxhkrhZcsK+Vnz\nDuACd59RQtlzga+7+6VFlt1OtQmxmQ0lfk002d2vK6mO84HN3H18GeULmNlkYC/g3WWcLfNr5JnA\nbnW5a12nEB8PHOfuY0qsY03gIWB/3bEunpkNAx4BRrr7QyXW81XgJXc/s6w6qlSnHluTgP8uswJ3\nX0a8QXJymfV0sQnA7DIDnJsOnJC33pJXizOxme1IvIO8VTOdM0IIRwPfA
d6cZdlTDda1BbHZvqE6\nghTLzG4DznP3mxp5fghhO+BrwIbAQOA24PQsy15soK6bgYvd/doWdrkj1OVMvCcwqx+hmgA8ALyv\n0Q3c/RHgaSA0WZeshpmtAewG3NrI80MIA4k3Gr+cZdkoYm87gHMarPIXxPdN8uoS4hHELngNCyGs\nD4wCPgEc3WR981nxppFi7Aw80MR3uGOARVmWzQLIssyJPbPOa3D7+cT3TfLqEuLtgfua3OYo4MfA\nTcB2IYRNm9j2fnQmLlog/l0btT3xsuY1WZYta6QpnbsvLyN5dQnxmkCzHTsmAFdlWfYqcB3w/ia2\nXQqs1WR9snrNHkMnXgf3V22OYV1C/ApNDKsMIWwG7AFMDSEsIA54aOa738HAy03tofSlqWNIHHY4\nqvcvQghrhBB2bnD72hzDuoT4caCZ5vDRwDeyLNs1y7LdiE259UMIb2lw+02BJ5rcR1m9J2juGN4M\nbBlCGAcQQhgAfInGW1S1OYZ1CfFdNHeT4mjiyBjgtZsil9P42bjpG2nSp7uAd+Rjh/uUZdly4BDg\npBDCPOJw02eBzzRYX22OYV2+Jz4YONvdR1dQ1xDiTCCbtTpCSv6emf0RONDdf1dBXd8EMnf/atl1\nla0uZ+LZwA5mtk0FdR0OzFOAS/E/xEH8pcq7zx4J3FB2XVWoRYjz7pCXAx+soLpTiFP3SPEuBCbm\nrZ0yHQXMd/ffl1xPJWoR4tyFwIlmtklZFZjZgcDWgMYUl8Dd7yN+V1zah3HeX/o/gK+XVUfVanFN\n3MPMphBvWIwrehibma0N3AOc4u43Flm2rJDPpDIHGOXuD5ZQ/heJs74cVZeJAeoW4iHAncBF7j69\nwHKNOCuEu/sHiipXVs7MTgeOAN6VXyoVVe5BxCl/dnX3x4sqt91qFWJ4bZbDWcCn3P2KAsrrmbNr\nHLC1uy9ttUxZvXzG0d8AzwMHFBFkMxtNHDBxlLvParW8TlKna2IA8psVY4AvmNmUVm6SmNkbibMt\n7gFsADQ9a6b0y9nALsRukTNb+dbBon8jTtt0TN0CDDUMMYC7308cZjYSmGtmTY04MrMBZnYk8Rr4\nMWB34N3AO/Mxr1ISMzuX2GFjCnEC+B8Qj+GHm/1Azic4/CnxG4UD3f3mgne3M7h7bR/E1R0+ADxK\n/C75aGC91Tx3U+BjQEZcIWL0655zKLHj/W3tfm11fADn5n/fc1/3+0DsZvkn4lDDt5JfCq6kjLWA\nscTvnJ8GJgOD2/3aynzU7pp4ZfJZDscRv7rYi9hndiHwKrAcGE5cxmUgsck8nRjUf/jjmNmhwE+A\nX7v7XpW8gC7Q+wzs7ueu4jk7AR8C3ks8Zr8h9ptfgziAYkdgG+IQxRnAVV7QtMWdrCtC3Ft+0+St\nxGuua4gT332c+Ib448qCu5IyFOQCNRLglWzTs37W8cQW1gRiC+ped3+pnD3tTF0X4t7MzIE73X1U\nn0/+x20V5AL0J8Cv2/40YKq7NzRwoo5qeWOrCh47fOhmVwtaDbBECnELFOT+U4CLoxC3SEFungJc\nLIW4AApy4xTg4inEBVGQ+6YAl0MhLpCCvGoKcHkU4oIpyP9IAS6XQlwCBXkFBbh8CnFJFGQFuCoK\ncYm6OcgKcHUU4pJ1Y5AV4GopxBXopiArwNVTiCvSDUFWgNtDIa5QnYOsALePQlyxOgZZAW4vhbgN\n6hRkBbj9FOI2qUOQFeDOoBC3UcpBVoA7h0LcZikGWQHuLApxB0gpyApw51GIO0QKQVaAO5NC3EE6\nOcgKcOdSiDtMJwZZAe5sCnEH6qQgK8CdTyHuUK8L8jIz+6uZ7VdF3Wa2pZktMbNFKMAdb1C7d0BW\nzd1vNLNpwEeBocDVZrZtI+sLmdnGxHWJhgIvAYuBP/S1TE2+oPrVwDrAekCmAHc2hbjzjQJeBgYD\n6wL/RVwY7u+Y2SDgPcBx+TbDgN8CL
wBDgK2AIWY2D7iWuNjY8yup72TiOlUDiQvObVTsy5GiqTnd\n+Q4lLpD9PLAmcFzvZrWZDTSzjxIXhjsd+BGwH/BGd9/T3fd3973cfRNiOC8GDgMeMbP/NLO1epW1\nFfAV4gfAc8Bc4hrB0sEU4g7n7s+4+3jgKGAJcRnPq81smJkF4FbgCGCcu+/j7pe7+4Mraza7+5/c\n/Xp3fy9xAfYtgIVmtk/ejL6KuL7vc8QPhL3d/Q+VvFDpN4U4EfmNrrcQm8JvIi7LOocYvAPdfUGT\n5T3s7scAZ+RlfhfYE7gd2Mndv9nIMq/SfromToi7PwOMN7MFwDnAIe4+u8Uyf2hm9xKbzlcCxym8\nadGZODFmtjlwGjC21QD3cPffE2+GHQTsXUSZUh2FOCH5deu3gGnuPqvIsvMgfwi4tPfNLul8CnFa\nJgAbAF8qo3B3/yGxWT25jPKlHApxIvKz8GnAZHd/pcSqzgJO0tk4HQpxOkYBbwB+XmYl7v4w8Gtg\nfJn1SHF0dzod/wJc7u7LG3lyCGEr4B5gPuDE7pdnZFk2p4HNLwVOAmb0b1elSjoTp2Mk8QzZjCzL\nsv2zLDsAOBM4u8HtbgdG5E146XAKcQLMbACx++NdLRSzMXEQRJ/c/THgFWKPLulwak6nYTiAuz/d\n5HYhhPArYlN6U+CQJrZ9KN9G3S47nM7EaRhCHE7YrJ7m9J7AGOCaEEKjH9wv5/VKh1OI0/ACMLSV\na9QsyxYBy4DNG9xkKPBif+uT6ijEaVhKDOAm/S0ghLA+8GYauC7OPyzeCjzQ3/qkOromToC7u5nN\nB0bQ4M2pXM81McQz66lZljXSLN8G+Ju7P9Hcnko7KMTpmAuMJg7671OWZQ+T3xDrh9HAnf3cViqm\n5nQ6vgscb2ZrVFDXROA7FdQjBVCIE+HuvwUWAu8rsx4z24148+snZdYjxVGI0/Il4PNmtk4Zheed\nSqYB55c8yEIKpBAnxN1vAWYSJ7Mrw6nEWS6/UVL5UgKFOD2fAA42sxOKLNTM9iX2rT7R3V8tsmwp\nl0KcGHd/FvgxMKOoIJvZ/sB1wPj82lsSohAnJl8b6VTgAuAcM7uov9fIZjbYzCYD3wPenzfXJTEK\ncUJet7jZh4HdiMfwbjM7ttGvn8xsgJmNJQ453AcY4e6/LGm3pWTq7JGIla1OmDetJ5rZQcAngalm\ndgUwmzgZwOKe6WfN7I3EHl97Av8K/AWYSlzORVPUJkwhTkBfy4vmzeBbzGw74FjiekojgGFm9gJx\nHScHfkMM99HuPreavZeyKcQdrpn1gd39d/lzewYxrE1c9uUlYGmjU/tIWhTiDtbKAt95E/lv+UNq\nTDe2OlQrAZbuohB3IAVYmqEQdxgFWJqlEHcQBVj6QyHuEAqw9JdC3AEUYGmFQtxmCrC0SiFuIwVY\niqAQt4kCLEVRiNtAAZYiKcQVU4ClaApxhRRgKYNCXBEFWMqiEFdAAZYyKcQlU4ClbApxiRRgqYJC\nXBIFWKqiEJdAAZYqdWWILZqa/7i7mX2wwLLPRQGuhJkdAkzJ/32Fma3Z5l1qi64Mce7fe/37/UUU\nqABX7gBgWP7vCb3+3VW6MsT5JHJ35z++AsxqtUwFuC3msmIiwCXu/lQ7d6ZdujLEuVnAq8BzwJ2t\nFKQAt8184pzaEOfU7krdHOI7iQEeSnwz9IsC3FaPAMspqDWVqm6ed3o+cXL1Z9398WY2NLNtiG+c\nE1GA28bd3czuAfYA5rV7f9qlm0O8GFgKLGhmIzNbn/iGWS//lQLcXrOI60v1uzWVuq4Kcb60yTbE\ndYp2Bv4KDDWzs4hhnu/uf+6jmIuAtXr9PKOMfZVVM7M3AO8gHsftiZdFp5lZRvyAXdRNC6VbNyyI\nZ2abABPzhxM/te8m3tk0YAPg7cQ3xaPAdOBKd//b68oZB1xNDPGLwPPEZUEfquaVdC8zGwIcDkwi\nH
qeFxOP4GPG6eCiwY/5/GwBXABe6+6K27HCV3L22D2BN4HxgCTGYu/TxfAMOAr4PPE1cXbDng269\nvBwnfvJ/D1iv3a+xGx7Au4k3sX4JHAUM7uP5WwCfA/5M/NDdsN2vodS/T7t3oMQDPwrI+nsQiZ/q\nc4GZwGbAtXmAlwDvaffr64YHMBy4HHgQOKAf268JfAX4E3B4u19PWY9aNqfz7njfBSa5+7UtlDMI\n+DSxCbcxMcgfdPe/FLKjskpmth5wE7AIOMXdl7ZQ1l7EltPn3f3CgnaxY9QuxGa2D/AD4ifv/xZU\n5oeBs4CR7v5oEWXKqpnZWsAtwB3Ax72AN2n+teAvgc+4+2WtltdJahXi/K7lPcBEd7+p4LLPAfYF\nxhTxppJVM7NpwJuA8UX+rc1se2AOsI/X6IZX3UJ8KbDM3SeVUPYg4Dbg2+5+UdHlS2Rm+wJXEW9C\nLimh/FOAY4lBrsXXULUJsZntDlwP7NDK9VMfdewE3ApsWVYd3Sz/Hn8BsQPN9SXVMQD4FTCjLs3q\nOvWdngRcUGa43P0+Yg+hY8qqo8vtDaxBvKdRCndfDnyZ+H6phVqEOO8KeQQN9p4KIbwrhPCrXj9v\nGkLIQgjrNLD5dGr0Bugwk4gdNPpsHoYQJocQvtDr5wEhhAUhhLc1UM9PgY3MbGQL+9oxahFiYDRw\nh7s/2ciTsyybCTwSQjg+/9VU4Kwsy/7awOa/ADY3s437t6uyMnlT+mDgugY3mQocGULYNP/5BOCO\nLMvuXs02AOTXwt/P60teXUI8guZHsZwGfCqEcDgwPMuyht48eXPsrrxOKc4WwMvuvriRJ2dZtgz4\nLPC5EMJawOnA2U3UN4+aHMO6hHg3YrAalmXZU8RP82uAU5us7y5iX2tpgpntZ2a3mdlXzOx9ZrZ1\nfgaGeAybHdh/JbAD8C3gsizLnmhi27vyOpNXl1FM6wPNHMAeuwIPAyOBZgYxPAmcamaH9aPObrY7\nsevqHsRhoIMAN7P7iN1Zn22msCzLPITwaWLvvBOb3Jcnie+b5NUlxEZ8czQshDAK2Ik42drMEMJP\nsyxr9M72cmLzb4um9lJ6DCD2i15KHH20C/A4/RvY/yDwWJZlLza53XJq0hKtxYsgvhnWbfTJIYRB\nxLvMH8my7DHiXe0pTdS3LvBZdzc9Gn8QZ6RcRhxCOI04NHQH4iyVZwIDmzgGrVqXFZPsJa0uIb6X\n2DRu1CeAWVmW3Zf/PA0YE0LYpcHtdyV275QmuPtVwNruPtLdP+7u17j7A/lXSvcCjXw9VJRd8zqT\nV4seW2Y2ATjS3Y+sqL5HgX3d/cEq6usGZjYQeAbY3N2fqaC+84CB7n5W2XWVrS5n4jnAflWsAGBm\nPWcLzeZRoPy729uAsRVVORaYXVFdpapFiN39EeIA/kJWcujDJOBijWQqxcVU0Bsu72e/AXBz2XVV\noRYhzk0HPpo3y0phZhsSPyguKauOLvcjYJsKukN+DLioLqOY6hTiG4mzV36kxDouIA5FfKzEOrqW\nu79MnHzhEjMb3Nfz+yOf9WVv4od+LdTixlYPM3sLcDswuuhB32Z2FLGb39vdfVmRZcsKeQ+uHwPz\n3P0zBZf9BuIspye6+8wiy26nOp2JcfcHiN833mhmmxVVrpn1fHIfpwCXK7/XcBJwopkd39fzG2Vm\nw4AbgGvrFGCoT4+t17j7jHxo4q1mdmirZ2QzGwt8BzjG3VtaeE0a4+6L82bvzWY2HJjeyo1EM9uI\nOEZ5EXBGQbvZMWp1Ju7h7ucT5x2ebWan9edml5kNN7MLiSs+HOHuPy96P2XV3P1+YD/iWfmGfAGA\npuWXQXcTJ8mbmI9Cq5VahhjiGZm4Rs9hwDwzO6GR75HNbEMz+xRwPzCEONfTnHL3VlbG3X9PHDQx\nH1hoZl/OZ61cLTMbZGaHmdnNwHnEmU8n1zHAULMbWyuTz6l0KPA
h4uiZW4kd7RcSR80MADYkru0z\nMn/O9cQmXNeutNdpzGxb4oocHyA2i+cSw70YeJm4wuVOxDHC+wF/JN7H+J67Nzs4Iim1D3FvZrYF\n8euFngXVhhNHsywhX1ANmOMlzLIoxchbU/sQj+EIYCPiQuPPE8M9H7jd3f+vbTtZsa4KsUgd1faa\nWKRbKMQiiVOIRRKnEIskTiEWSZxCLJI4hVgkcQqxSOIUYpHEKcQiiVOIRRKnEIskTiEWSZxCLJI4\nhVgkcQqxSOIUYpHEKcQiiVOIRRKnEIskTiEWSZxCLJI4hVgkcQqxSOIUYpHEKcQiiVOIRRKnEIsk\nTiEWSZxCLJI4hVgkcQqxSOIUYpHEKcQiiVOIRRKnEIskTiEWSZxCLJI4hVgkcQqxSOIUYpHEKcQi\niVOIRRKnEIskTiEWSZxCLJI4hVgkcQqxSOL+H73NVvPwcCIZAAAAAElFTkSuQmCC\n", 472 | "text/plain": [ 473 | "
" 474 | ] 475 | }, 476 | "metadata": { 477 | "tags": [] 478 | }, 479 | "output_type": "display_data" 480 | } 481 | ], 482 | "source": [ 483 | "#@title # Game 5\n", 484 | "#@markdown This is the last game in The Book of Why is the most complex. In this case we have two backdoor paths, one going through A and the other through B, and it's important to notice that if we only control for B that the path: X <- A -> B <- C -> Y (which starts out as closed because B is a collider) actually is opened. Therefore we have to either close both A and B or, as astute observers will notice, we can also just close C and completely close both backdoor paths. pgmpy will nicely confirm these results for us. \n", 485 | "pgm = PGM(shape=[4, 4])\n", 486 | "\n", 487 | "pgm.add_node(daft.Node('X', r\"X\", 1, 1))\n", 488 | "pgm.add_node(daft.Node('Y', r\"Y\", 3, 1))\n", 489 | "pgm.add_node(daft.Node('A', r\"A\", 1, 3))\n", 490 | "pgm.add_node(daft.Node('B', r\"B\", 2, 2))\n", 491 | "pgm.add_node(daft.Node('C', r\"C\", 3, 3))\n", 492 | "\n", 493 | "\n", 494 | "pgm.add_edge('A', 'X')\n", 495 | "pgm.add_edge('A', 'B')\n", 496 | "pgm.add_edge('C', 'B')\n", 497 | "pgm.add_edge('C', 'Y')\n", 498 | "pgm.add_edge(\"X\", \"Y\")\n", 499 | "pgm.add_edge(\"B\", \"X\")\n", 500 | "\n", 501 | "pgm.render()\n", 502 | "plt.show()" 503 | ] 504 | }, 505 | { 506 | "cell_type": "code", 507 | "execution_count": 32, 508 | "metadata": { 509 | "colab": { 510 | "base_uri": "https://localhost:8080/", 511 | "height": 51 512 | }, 513 | "colab_type": "code", 514 | "id": "IF1jYMq_eNHd", 515 | "outputId": "68bcc15d-c1aa-40c4-fb61-b400e7c31abe" 516 | }, 517 | "outputs": [ 518 | { 519 | "name": "stdout", 520 | "output_type": "stream", 521 | "text": [ 522 | "Are there are active backdoor paths? True\n", 523 | "If so, what's the possible backdoor adjustment sets? 
frozenset({frozenset({'B', 'A'}), frozenset({'C'})})\n" 524 | ] 525 | } 526 | ], 527 | "source": [ 528 | "graph = convert_pgm_to_pgmpy(pgm)\n", 529 | "inference = CausalInference(graph)\n", 530 | "print(f\"Are there are active backdoor paths? {inference._has_active_backdoors('X', 'Y')}\")\n", 531 | "adj_sets = inference.get_all_backdoor_adjustment_sets(\"X\", \"Y\")\n", 532 | "print(f\"If so, what's the possible backdoor adjustment sets? {adj_sets}\")" 533 | ] 534 | }, 535 | { 536 | "cell_type": "code", 537 | "execution_count": 36, 538 | "metadata": { 539 | "cellView": "form", 540 | "colab": { 541 | "base_uri": "https://localhost:8080/", 542 | "height": 258 543 | }, 544 | "colab_type": "code", 545 | "id": "dSjZqd5fHF06", 546 | "outputId": "05857131-230d-4dd9-88d0-25d4b800eb68" 547 | }, 548 | "outputs": [ 549 | { 550 | "data": { 551 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPEAAADxCAYAAAAay1EJAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAHr9JREFUeJztnXmcHVWVx78n+0LYV4EQwlIQlgAJ\niwRIWCNoFIHIJgyokZ1RBgRlkTDiMsqHAYcgiOwIGFkURSAIRCIgJGHHlBoY2VSCGYFAEAhn/ji3\nSdt0uqveq+VVvfP9fN4n3emqc+59dX91b9269xxRVRzHqS59yi6A4zjN4SJ2nIrjInaciuMidpyK\n4yJ2nIrjInaciuMidpyK4yJ2nIrjInaciuMidpyK4yJ2nIrjInaciuMidpyK4yJ2nIrjInaciuMi\ndpyK4yJ2nIrjInaciuMidpyK4yJ2nIrjInaciuMidpyK4yJ2nIrjInaciuMidpyK4yJ2nIrjInac\niuMidpyK4yJ2nIrjInaciuMidpyK4yJ2nIrjInaciuMidpyK4yJ2nIrjInaciuMidpyK4yJ2nIrj\nInaciuMidpyK4yJ2nIrjInaciuMidpyK4yJ2nIrjInaciuMidpyK4yJ2nIrjInacitOv7AIUhYj0\nBfYEdgHGAJsDywEKLAQeA+YAd6rq7LLK6fSMiGwETMKu4TbAGlg7XgzMw67hg8Btqvp2WeUsElHV\nssuQKyKyPHAscDSwALgdu9CPA68BAqyONYixwH7Aq8BFwDWquqSEYjudEBEB9gFOBLYGbgIexq7j\nS8B7wFBgM0zcu4XjrgAuUNWXSih2YdRaxCKyJ3AZMAv4b1V9JME5fYGPAacC/YEjVXVergV1lomI\nrAFMA0YB3wSmJ+lhRWRD4Djgs8BXgCu1po29liIWkT7A94D9gSmqeleDNo4GpgKnqOqVmRbS6RUR\nmQDcgPWoUxsZHovIaOBK4M/AIar6VpZlbAVqJ+IgvsuADYFPquo/mrS3CXAn8G1VvTiDIjoJCKOo\n64CDVPWeJm31B34ErAfsXTch11HE5wITgL1U9c2MbK4P3A8c
r6q3ZmHTWTYisiXwa+DTqjorI5t9\nsB55+WC3Ng2/ViIWkY8CNwNbquqCjG3viE2oZG7bWUroNX8H/I+qXp6D7YeBC1X1iixtl0ltRCwi\ng7DXRGeo6k9z8vE9YB1VPSgP+w6IyBnAjsDH8+gtwzPy3cBWdZm1rpOIDwcOU9U9c/QxGHgOmOAz\n1tkjIkOB54Gxqvpcjn7OB95R1VPz8lEkdVqxdSzw/TwdqOpibILk6Dz9tDGHAPfnKeDANODIMHqr\nPLXoiUVkFDaDPCLN4owoig4GrgbWiuP41YS+hmPD9tV8IUi2iMgDwDmqekeS46MoGgE8iS36UGAQ\ncEocx71OhonIDOBSVZ3eeIlbg7r0xDsAMxsQ1SHAfOCApCeo6vPA34EopS+nB0RkILAV8JuUp8Zx\nHE+I43hXbIHOmQnPuwdrN5WnLiIeg92NExNF0crAdsB/AAen9DcHW6LpZMfmwPwm3+GugS3DTMIc\nrN1UnrpsgNgE+FnKcyYDvwDuAH4YRdHacRwnbQDP4D1x1kTY95r6vCiK7sOG0msDExOe9zTWbipP\nXXriwUDahR2HANfHcbwE+ClwYIpzFwFDUvpzeqaRawhLh9M7YLvUboyiKEnnVJtrWBcRv0eKUUUU\nResA2wPnRVH0GLbhIc273/7Au6lK6PRGqmvYHXEcz8O2JK6b4PDaXMO6iPhv2FAqKQcDF8VxPDqO\n462wodzKURRtkPD8tYFXUpbR6ZlXSHcNP0SY51iLZM/FtbmGdRHxXNJNUhyM7YwBII5jBa4ieW+c\neiLN6ZW5wDZh73AaoiiK7gvPxbcDx8dx/E6C82pzDevynngv4ExV3bkAXwOwSCDrNLtDyvlXROQF\nYDdV/WMBvn4AxKp6ft6+8qYuPfH9wKYiMrIAX/sCs13AufAzbBN/roTls/sDt+XtqwhqIeKwHPIq\n4KgC3B2Hhe5xsudiYEoY7eTJZGCOqv4pZz+FUAsRBy4GPiciH8nLgYjsBqwP+J7iHFDVp7F3xbnd\njMN66a8CF+blo2hq8UzcgYhMxSYsJmW9jU1ElsPW6R6nqrdnadtZSoikMgvYTlWfzcH+t7GoL5Pr\nEhigbiIeADwCXKKq0zK0K1hUCFXVI7Ky63SPiJwMfBrYIzwqZWV3dyzkz2hV/VtWdsumViKGD6Ic\nzgROU9VrMrDXEbNrErC+qi5q1qbTMyHi6KPAW8CuWQhZRHbGIrNMVtWZzdprJer0TAxAmKzYE/im\niExtZpJERFbBoi1uD6wKpI6a6TTEmcAW2LLIu5t56yDG57GwTYfWTcBQQxEDqOoz2DazscDDIpJq\nx5GI9BGR/bFn4JeBbYGPAx8Ne16dnBCRs4GvY6GCtwZuwa7hCWlvyCHA4a+wNwq7qeqMjIvbGqhq\nbT9YdocjgBexd8kHAyv1cOzawJeAGMsQsXOXY/bBNp8/UHbd6vgBzg7f79ld/j8CZgB/Ac4BNiY8\nCnZjYwiwN/bO+e/AGUD/suuW56d2z8TdEaIcTsJeXeyIrZl9HFgCvA8Mw9K49MWGzNMwoX7oyxGR\nfYBfAg+q6o6FVKAN6NwDq+rZyzhmM+AY4JPYNXsUWzc/ENtAMQoYiUVeuRy4XjMKW9zKtIWIOxMm\nTTbGnrluxALffRlrEC90J9xubLiQMySJgLs5pyN/1uHYCOsQbAT1lKomWTtdG9pOxJ0REQUeUdXt\nGjjXhZwBjQi4y/knAeepatqNE7WhlhNbRaC24MMnu5qgWQE7hou4CVzIjeMCzg4XcZO4kNPjAs4W\nF3EGuJCT4wLOHhdxRriQe8cFnA8u4gxxIS8bF3B+uIgzxoX8YVzA+eIizgEX8lJcwPnjIs4JF7IL\nuChcxDnSzkJ2AReHizhn2lHILuBicREXQDsJ2QVcPC7igmgHIbuAy8FFXCB1FrILuDxcxAVTRyG7\ngMvFRVwCdRKyC7h8XMQl
UQchu4BbAxdxiVRZyC7g1qFtRSwi48KP24bUIaVQRSG3ioBFZC0sxnhH\nuKS2pC1jbIW0LIuAwVi0y9+p6riez8q9TJWI2dUqAg5luRCLftkPC3W7rqq+VGaZyqAte+IQ0fI5\nLNZ0H1ogY3wVeuRWEnDgMeCf4ee3sUD/bUdbijgwK/y7CHiozIJ00MpCbkEBg9183w8/P5Mk3HAd\naWcRP4gJuCV64g5aUcgtKmCwXMaDsKH0b0ouS2m0s4hnA/2xrA9/LLks/0IrCbmFBYyqvgvMxzJ5\ntMRoqgzaWcTzsPrPU9X3ezu4aLoIebGIvC4i44vwLSLrichCEZlHiwq4E7/FJrZaZjRVNG0pYhEZ\nAmyFJVqbLyKbh3xNLUUQ8gXYkHEYcIOIDE1yroisISIfFZFdRWSciIwIs/K9nSdYOtflsURmcSsK\nOGSu3BhYgE1qrSoiK5VcrFJom1dMIrIV8EVgFyzp1h/Cn97FBLIu8BSWUO1SVX2hjHJ2JQynx2JD\n/8XANap6VDfH9QM+ARwGbAcMxer4NjAAGBH+nQ1Mx5KNvdWNnWOA74bzlwCvq+rKmVesAYJIjwA+\nhaU9XYglxxuMzVJvgiVYexC4Ari3LSa7yk7LmPcHmAA8ADyPpbncBhjQzXHDMIFfiKXEvBnYsAXK\nvyLWM76JTeC8BYzv9Pe+wL8DL2Az7v+G3aQ+lPoTWAvYD/g58CrwLWBIp7+PwCb7NPz7ALBeC3wH\nqwKXAP8HXAt8DFilm+P6Apti+YifBH6PJVrrNg1qXT6lFyDHC78c8H1syPwZoF/Kc08ODf1EoE8L\n1GefcHNZguXpHYoNd38L3AdsldLeCOA6bFJvJ+yd+YPYK5tFwNGt0PjDTecvwPnA6inOk3ADfxK4\nFViz7Lrk9h2VXYCcLvyaWP7ha1hGUvGEdjYKvdFPgYEtUK+OXlmBX2DPg8c3c5MB9g0iuS7YbZXe\nV4BvAn8CxjVhZyDwDeAlYIuy65XLd1V2AXK4+KuFYdRZWfQkoRHcHO7miXvznOt4Gjas3jkjexti\nz5fXtkLvG8r0bWzGedWM7B0I/BUYVXbdsv7UamJLRPoAM7Ccw6dlaHcAcBvwaJZ2GyzLuljjnqyq\nMzO0uyFwf7A7q7fj80REDgNOx3rgv2dsdyqwpaouyspu2dRNxMdis7M7qeqSjG2vgQ3RP6mqD2dp\nO0UZBPgVcL+qnpuD/X2xmenR2s3MdRGEnUmPAx9T1bk52L8CeEtVj8vadlnURsTh4j+JCXheTj4O\nBM7E7uSFLxARkUOBLwM7qOp7Ofm4Dvizqn4tD/sJ/N8I/EFVz8zJ/opYO/mMqj6Yh4+iqZOIzwLW\nUtVjcvQhwFzgVFW9Ky8/PfieDZyuqnfk6GdE8DO86N5YRNbDHhXWU9U3c/RzPDafcGBePoqkFiu2\nwmqro4CL8/Sjdse7CHsPWTTbYbPTud48VPV/sVdNB+XpZxkchS1myU3AgWuAvcLorfLUoicWkV2A\n81V1TNJzoijaCPhvbDa7L/Zq5eQ4jv/Z03lh2eMr2DvLvBtbZ7/nAa+p6jlJjo+iaAQ2bJyDva55\nD/hmHMe/TuBrP+CLqvqxxkucHhH5I3CAqj6e5PgudezMfnEcL+zF11VYMIhpjZS1lehXdgEyYizW\neyQiiqK+wE3ACXEcz4yiSLCVWmdhs6LLRFXfFJHfA6Mx4RfFWOx9ZxriOI4nAERRtAFwWxRFB8Vx\n/EQv5z0EXCIiogXd5cOSyjWxpa9p+KCOKXkI+04rTy2G08AY0u1i2ROYF8fxTIA4jhX4CpColwu+\nCmsA4dXZ1tjzeEPEcTwfOJcEjwKq+jLWcw9v1F8DbAM8nvVbhR6Yg7WbylOXnngdLNxOUjbBQrt8\nQBzHi1Oc/xywdorjm2UYQAbvTGdjyymT0FHHPzfpMylrk+4aNstzWLupPHUR8QBsN1JSFH
sObpR3\ngs+iGBB8NsswbO11Et6l+DqmuYYdRFEU3dfp9ziO4w/t8uqGoq9hbtRFxG9je26TMg9bc/wBURQN\nBDaK4zjJM1nH1reieBsYlMEz6ljg0YTHDqKEOjZwXqPPxEVfw9yoyzPxfGxHT1JmAOtFUTQJIIqi\nPsB3sPW1Sdg4+CyKRdhe4o80aiBMbJ2E7QbqkfBOuug6Pht8FkWEba6oPHXpiVNNNMVx/H4URROB\nS6Mo+jo2tJqBratNwhjs9VQhqKqKSMdETJq4yh1DzYHY48NxcRw/n+C8kcAbqvpK6sI2zmPAKBEZ\noKppHh26DqcBvhLHcW9LY9NOhrYsdXlPPAa4HojyfiUiIqtikyKrpGxszfo9BxisqqcU4OsIYJKq\n7p+3ry5+HwOOL2IDhojcAtyqqlfl7Stv6jKcnou9EtmlAF+fB6YXKeDAtcDhIjKwAF9TgKsL8NOV\na7AQSrkiIh/BAgbckrevIqiFiEPvO42cl0OKSF/sFU3hq3xU9Q/Y7p4D8vQTYpGti6WUKZorgUki\nslrOfqYAN6jq6zn7KYRaDKcBRGR5QkylLPfZdvFxCjBRVffIw34C/7sDP8J2UWXeAMOiknuBm1T1\nwqztJyzDBcAKqnpETvZHAI9gu93iPHwUTW1EDCAik7AJpy2zXtcsIptim+a3VdUiFyV0LcdlwBLt\nJuJlBrZPxOKRjS9w5VTXMiwHPAGcoKqZjgbCrPsM4G5V/XaWtsukViKGDzZ9DwMOymrPrYisAswE\nLlLVXHdKJSjLCthM7jmqekWGdnfB1pOPC0P30hCRXYEfA7tmuTdcRL4B7I5tQ8xlP3YZ1OKZuAtH\nY4HPr80iILyIrI7dvW8HftCsvWZR1dewIHmXi8iRWdgUkQnAPcArZQsYQFXvBb4KzBCRUc3aE2Mq\nsD8WmaU2AgbqFygvjCwGYTOPD9NEYDRgbyxe9VRaJ4Dc2diy0e9jr7ouAZZv0FZ/LBb3K1jwPQUe\nKLuOncr3WSyi55RGv38s1vbPsHfCiUPeVulTegFybACCbTJfgO1OShx3GBgFXAX8L7B72XXpVK4O\nAZ8dfl8B+GEo52dJGFYXG4HtHRr2HVhybrDY1q0m5M2xjRt3AeOSihkbjR2PZYT4RtLvpoqf0gtQ\nQCMYAVyKZQ+4AcuQsBnQt9Mxg7DIGcdis7N/CcIfVnb5O5XxXwTc5W+7A3eGBvs9LM3JOp0bPLAK\nsBe2Z3p+EMaHsiO0qJD7AydgaWkeB04BxncegYSb9gbYxNwPsBC804Ftyi5/3p/aTWwtizAhdCi2\nIGQM9i5UsYwHfYAY65nuxFbyFL2YY5kkTS8qIhthPfJ2WB2HYhsL+mN1fRSr443aQ8ROEdkHe0/8\noKrumE0tmie8Atsdu0mNAbbErt+A8O+rWP0eAq5W2xdde9pGxF0RkUHYpoK5WPTIRrbB5U6j+YHD\n65TlsHXT7wCLNEWEzlYVcmfC4ptTsWAHy2mB4ZJaiTrOTidCVd8OPy6pm4DBVrGp6huq+qqqvp5G\nwOH8lkl0vizU3mW/HX5uSwFDG4u41WlGwFlRBSE7LuKWpBUE3IELufVxEbcYrSTgDlzIrY2LuIVo\nRQF34EJuXVzELUIrC7gDF3Jr4iJuAaog4A5cyK2Hi7hkqiTgDlzIrYWLuESqKOAOXMitg4u4JKos\n4A5cyK2Bi7gE6iDgDlzI5eMiLpg6CbgDF3K5uIgLpI4C7sCFXB4u4oKos4A7cCGXg4u4ANpBwB24\nkIvHRZwz7STgDlzIxeIizpF2FHAHLuTicBHnRDsLuAMXcjG4iHPABbwUF3L+tKWIQzDx88Kv24pI\nZilRXMAfJi8hi8hEQk5pEblGRAZnZbtKtKWIA1/o9POBWRh0AS+bnIS8KxbREyz87tAejq0tbSli\ntRCfT4Rf38PyLDWFC7h3chDyw8Ab4eeFqvpqBjYrR1
uKODATWAK8iaW6bBgXcHIyFvIcLKY2WEzt\ntqSdRfwIJuBBWGNoCBdwejIU8vNY0PhMRlNVpZ1FPAcLrv6Wqv4tzYkiMlJEhruAG6erkEVkZRHZ\nIqUNBZ4E+mJpadqSfmUXoEReAhZhuX4TIyIrYw1mpfBfLuAGUdXbReTjWKaJvwPviMjWqvpMCjMz\ngR1oYjRVddpKxCG1yUgsj8/mwOvAIBE5HRPzHFX9ay9mLgGGdPr98jzK2igiMgDYAqvjSOxx4R3s\npjUbeKzFsiXci2VxGIA9304XkdHaQw5hEVkR2Aar4ybYY9FJIhJjdZwXskO0BW2Ri0lEPoLluJ2C\nJRabg81Ov4Fl01sV2BprFC8C04DrVPWNLnYmYZkVhwD/BN4Cxqjqc8XUZNmIyLZYVscDsFSnc7As\ngh0J1dbH6jcK+DVwETAjbXqXrAk5sZ4B1sC+1zeBb6nquV2OGwDsi9VxDJYdcQ7wMvZcPAir2xjs\nel4DXKyq84qpSYmUnZYxzw8wGEv1uRAT5ha9HC9Y1r2bsOHd0Sy90a0U7CjW0H4CrNQCdRyJifJZ\n4CvAqr0cPxR7R/4o8BSwfQvUYSDwHeymqOHfUZ3+/nFsEuteYDLQvxd7w7GcxH/FbrqrlV3HXL+/\nsguQY8PYDktX2tBFxO7qDwN3Y7l+p4cGthD4RAvUT4BjsHSeJ9Mp33KK8w/Echp/B+jXAnUaG25G\n7wNPAytiyd6fBXZtwN5g4LtYvul9y65fbt9b2QXIqTFMBBYAk5u00w9Lyv3XIOBW6X0F+K/Qk27S\npK01gLuAW4GBLVC3jl55CfDnIOLlmrS5I/aYdEzZ9cvjU7tnYhHZCbgFu/P+NiObJwCnA2NV9cUs\nbDZZnnOATwB7qOrCDOwNAG7E3rceqOU/Jw/BEoU/AnxBM2ikIjISG45/XVWvbNZeK1ErEYdZyyeB\nKap6R8a2zwJ2AfbMolE1UY49sBnxMaq6IEO7A7FGfoOqXpiV3QbLcgGwJnBQlt+1iGwCzAJ20hpN\neNVNxFcAi1X12Bxs9wMeAH6kqpdkbT9hGZbHZtWPUtU7c7C/MVbHHVT1T1nbT1iGXYDrsUnIpkcZ\n3dg/DvgsJuRavIaqjYjDK5abgU1VdVFOPjYDfgOsl5ePXvyfAwxX1SNy9HEK8FFV3S8vHz34Fux9\n/VRVvTknH32A+4DL6zKsrtOyy2OB/8lTXKr6NLZC6NC8fCyL8Nw6BZv0yZOLgQkisk7OfrpjHDax\ndUteDsLz/n9h7aUW1KInDkshnwU2SvKcGEXRHsAZcRxPCL+vDdwDbBvH8eu9+NoDOE9VRzdd8BSI\nyIHYMHq3JMdHUTQCmx/ovBzxsTiOv5TA1/eB/1PVsxopa6OIyI+B36nqBb0dG0XRGcCQOI6/Fn7v\nA8wFDo/j+ImezhWRvsB84ABVrfya67r0xDtjFz/RRE8cx3cDz0dRdHj4r/OA03sTcOAeYF0RWaOx\nojbMROCnKc+J4zie0OnTq4AD04G9UvpqijCU3ovkdTwP2D/cgAGOBH7Xm4ABwrPwTRRcx7yoi4jH\nkH4Xy0nAaVEU7QsMi+M4UeMJw7G5wWeRNFLHRnkU2CJM5hXFcOBdVX0pycFxHC8G/hP4RhRFQ7AF\nL2em8Deb4q9hLtRFxFthwkpMHMevYnfzG4HjU/qbi621zoywFe8eEblCRL4gIluJSP/wt/5AxNJo\nJLmitmb8BWxzQWaIyPiw7fC7InKAiKwfemCwa5h2Y/91wKbAD4Er4zh+JcW5c4PPylOXXUwrA2ku\nYAejsc0CY4E0mxgWAMeLyKca8Lks1gHWCj9PJizqF5HnsF7jHVV9O6XNKIqi+zr9PiOO43OXdXAX\nFgM/EZEsJwq3xVa+bY9tA+0HqIg8jS1nfS2NsTiONYqirwHXAp9LWZYFWLupPHURsWCNIzFRFG0H\nbIYFW7s7iqJfxX
GctMG+jw3/hqcqZXKGYpsA3gU2wK5TI6OmuGPyrgH6Y71cHvQBhmFCHoRtnfwb\njT0uPAu8HMfxP1Oe9z41GYnWohJYY1gh6cFRFPXDdjWdGMfxy9gKqKkp/K0A/KeqSlYfrCd+DfgT\n1rN8GVshNhS72fQLs6pF8R62zDTLOh6C9fBzgAuwV2abhjqeikXoKIoVWBpkr9LUpSd+Chsa/zLh\n8f8BzIzj+Onw+wXAnCiKtojj+MkE54/GhJYZqvqSiKy0rGWGIvIi9oz6dHd/z5KwBHNj4PdZ2lXV\n60Xkhu7qKCJPAVtm6a8XRmPtpvLU5T3xIcD+qrp/Qf5eBHZR1WeL8Bd83gj8UlWvLsDXGOBKVU0V\n86pJn32BfwDrquo/CvB3DrZ98/S8feVNXYbTs4DxRWQAEJGO3qLoaB4zgX0K8rU3cH9BvoAP3t0+\nEHwXQeF1zItaiFhVn8c28GeSyaEXjgUuLWEn04+BiSKyZp5Owrvho4BL8/SzDC6lgOWQYZ39qsCM\nvH0VQS1EHJgG/Huekz8ishp2o7gsLx/LIgwxp5N/I98PeEFVU0UBzYifAyNFZGzOfr4EXOK7mFqM\nsDvlXuBWVT0/Jx83Yg385DzsJ/C/HvYaZrymC+ua1P5K2HrrQ1W1lGDsInIEJrJtVfXdHOxPxCKW\nbqmqSZbZtjy1ETGAiGyARYTYOetN3yIyGVvmt7WqLs7SdspyHAV8HhiXZSMPK6euxjY+nJiV3QbL\n8Qtgtqp+PWPbK2Kr3j6nqndnabtM6jScRlXnY+8bb89yK52IjMOG64eVKeDApVgM6aszXtt8FvaK\n56sZ2kxNmGv4IvA5ETm8t+OTIiJDgduA6XUSMNRMxACqejkmuN+EcCxNISJ7Y0HkDlXVphKvZUFo\n5AcDqwA3iMhyzdgTkb4i8h3gIGAvbYHA8mETxETgWyJyXKf11Q0hIqtjwQD/AJySQRFbitqJGEBV\nv4fFHb5fRE5qZLJLRIaJyMXY89OnVfWurMvZKGEN9SRsxdHjIjK+ETsisinwW2zt+M6aMidVnoRn\n/vFYr3xbSACQmvAY9AQ2XzJFSw4CmAvaAiE38/pg645nYrtjjgQGJzhnNeA0bBfPj4AVyq5HL+Wd\nhIVjnQ5MIMxz9HLOaOAH2CaAY4A+Zdejh7IOwJbELsAicoxMcE4/4FPYK6TfYzHDSq9LXp9aTWx1\nR5i13gdrrNtjMbJmY2lAXsNGI6thuX3GhmNuBqZpRaI+hAB6hwHHYQ14FrY+OcbWKg9gaRqXHbBY\n05cAl6nqX8ooc1pEZEMsI8cRwDxsXcAcbH7gXSzD5WZYHcdjN+FpwE9UNe3miEpRexF3RkSGY3Gc\nOhKqDcN2sywkJFQDZmkOURaLIDw7boNt+RsDbIjFrHoH663nhM8D2kPCslYmrMrbCavfGGB1bMfV\nW5i45wAPqWqm675bmbYSsePUkVpObDlOO+EidpyK4yJ2nIrjInaciuMidpyK4yJ2nIrjInaciuMi\ndpyK4yJ2nIrjInaciuMidpyK4yJ2nIrjInaciuMidpyK4yJ2nIrjInaciuMidpyK4yJ2nIrjInac\niuMidpyK4yJ2nIrjInaciuMidpyK4yJ2nIrjInaciuMidpyK4yJ2nIrjInaciuMidpyK4yJ2nIrj\nInaciuMidpyK4yJ2nIrjInaciuMidpyK4yJ2nIrjInaciuMidpyK4yJ2nIrjInaciuMidpyK4yJ2\nnIrjInaciuMidpyK4yJ2nIrjInaciuMidpyK4yJ2nIrz//fZsGHg3Z2QAAAAAElFTkSuQmCC\n", 552 | "text/plain": [ 553 | "
" 554 | ] 555 | }, 556 | "metadata": { 557 | "tags": [] 558 | }, 559 | "output_type": "display_data" 560 | } 561 | ], 562 | "source": [ 563 | "#@title # Game 6\n", 564 | "#@markdown These games are no longer drawn from The Book of Why; they were either drawn from another source (which I will reference) or developed to try to induce a specific bug. \n", 565 | "#@markdown This example is drawn from Pearl's Causality, p. 80. It is interesting because there are many possible combinations of nodes that close the backdoor paths in this graph. It turns out that D plus any one other node in {A, B, C, E} will deconfound X and Y: conditioning on D blocks every open backdoor path, but it also opens X <- C <- A -> D <- B -> E -> Y (D is a collider there), so one more node on that path must be controlled for. \n", 566 | "\n", 567 | "pgm = PGM(shape=[4, 4])\n", 568 | "\n", 569 | "pgm.add_node(daft.Node('X', r\"X\", 1, 1))\n", 570 | "pgm.add_node(daft.Node('Y', r\"Y\", 3, 1))\n", 571 | "pgm.add_node(daft.Node('A', r\"A\", 1, 3))\n", 572 | "pgm.add_node(daft.Node('B', r\"B\", 3, 3))\n", 573 | "pgm.add_node(daft.Node('C', r\"C\", 1, 2))\n", 574 | "pgm.add_node(daft.Node('D', r\"D\", 2, 2))\n", 575 | "pgm.add_node(daft.Node('E', r\"E\", 3, 2))\n", 576 | "pgm.add_node(daft.Node('F', r\"F\", 2, 1))\n", 577 | "\n", 578 | "\n", 579 | "pgm.add_edge('X', 'F')\n", 580 | "pgm.add_edge('F', 'Y')\n", 581 | "pgm.add_edge('C', 'X')\n", 582 | "pgm.add_edge('A', 'C')\n", 583 | "pgm.add_edge('A', 'D')\n", 584 | "pgm.add_edge('D', 'X')\n", 585 | "pgm.add_edge('D', 'Y')\n", 586 | "pgm.add_edge('B', 'D')\n", 587 | "pgm.add_edge('B', 'E')\n", 588 | "pgm.add_edge('E', 'Y')\n", 589 | "\n", 590 | "pgm.render()\n", 591 | "plt.show()" 592 | ] 593 | }, 594 | { 595 | "cell_type": "code", 596 | "execution_count": 38, 597 | "metadata": { 598 | "colab": { 599 | "base_uri": "https://localhost:8080/", 600 | "height": 68 601 | }, 602 | "colab_type": "code", 603 | "id": "30OIiRt7raN2", 604 | "outputId": "4c285443-724c-4988-b9a6-2ad2fcf2654e" 605 | }, 606 | "outputs": [ 607 | { 608 | "name": "stdout", 609 | "output_type": "stream", 610 | "text": [ 611 | 
"Are there are active backdoor paths? True\n", 612 | "If so, what's the possible backdoor adjustment sets? frozenset({frozenset({'B', 'D'}), frozenset({'A', 'D'}), frozenset({'C', 'D'}), frozenset({'E', 'D'})})\n", 613 | "Ehat's the possible front adjustment sets? frozenset({frozenset({'F'})})\n" 614 | ] 615 | } 616 | ], 617 | "source": [ 618 | "graph = convert_pgm_to_pgmpy(pgm)\n", 619 | "inference = CausalInference(graph)\n", 620 | "print(f\"Are there are active backdoor paths? {inference._has_active_backdoors('X', 'Y')}\")\n", 621 | "bd_adj_sets = inference.get_all_backdoor_adjustment_sets(\"X\", \"Y\")\n", 622 | "print(f\"If so, what's the possible backdoor adjustment sets? {bd_adj_sets}\")\n", 623 | "fd_adj_sets = inference.get_all_frontdoor_adjustment_sets(\"X\", \"Y\")\n", 624 | "print(f\"Ehat's the possible front adjustment sets? {fd_adj_sets}\")" 625 | ] 626 | }, 627 | { 628 | "cell_type": "code", 629 | "execution_count": 40, 630 | "metadata": { 631 | "cellView": "form", 632 | "colab": { 633 | "base_uri": "https://localhost:8080/", 634 | "height": 201 635 | }, 636 | "colab_type": "code", 637 | "id": "Z4pkuyOwM9xq", 638 | "outputId": "82e581df-b378-423c-95bd-d16b901c5296" 639 | }, 640 | "outputs": [ 641 | { 642 | "data": { 643 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAPEAAAC4CAYAAAAynAqtAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAEMpJREFUeJzt3Xm0HGWZx/HvE5IQBAxEFtkCBuQV\nEAIk7LsCGRGOwBBlH2BYZNFBhcEZNgGZGQScYY7CEJQgg7JElIMMIoMsgoiQQNiE4ogsAgpxGQQJ\nEOCZP57qIYMht6tvVXe/nd/nnHuSe2/Xu9yuX1fVW29VmbsjIvka0esGiMjwKMQimVOIRTKnEItk\nTiEWyZxCLJI5hVgkcwqxSOYUYpHMKcQimVOIRTKnEItkTiEWyZxCLJI5hVgkcwqxSOYUYpHMKcQi\nmVOIRTKnEItkTiEWyZxCLJI5hVgkcwqxSOYUYpHMKcQimVOIRTKnEItkTiEWyZxCLJI5hVgkcwqx\nSOYUYpHMKcQimVOIRTKnEItkTiEWyZxCLJI5hVgkcwqxSOYUYpHMKcQimVOIRTKnEItkTiEWyZxC\nLJI5hVgkcyN73QBphpmtCEwAxgCvA88CT7m797RhUjvTezoYzGwksCtwALApsCTwGPAqMBpYo/x3\nJjADuNzdX+lJY6VWCnHmzGwx4BjgOOAp4CLgduCJd251zWwlYAvgIGDL8rVnKMx5U4gzZmYJuBiY\nBxzr7rMrLLsGcCax1T7Y3e9ooo3SPA1sZcrMdgTuAC4HPlIlwADu/qS77wccD8wws8MbaKZ0gQa2\nMmRmHwG+A+zp7rcPpyx3v8bMHgJuMjPcfVotjZSu0e50ZsxsNWAWMNXdb6ux3LWIY+mp2rXOi0Kc\nETMz4IfA7e5+ZgPl7w6cDUzUYFc+dEycl32B5YCzmijc3a8B7gZOaqJ8aYa2xJkot8IzgRPd/YYG\n61mjrGe8tsZ50JY4H5sCywA3NlmJuz8J/AzYu8l6pD4anc7HJ4Fvuftb7bw4pbQG8CAxCObE9Mvj\ni6JoZ9BqOnA4cQ5a+py2xPmYTGwhqyiKoti+KIodgBOAk9tc7i5gUrkLL31OIc6AmY0ANgLuHUYx\nKxIXQQzJ3Z8D3gDGD6M+6RLtTudhaQB3/33F5VJK6VZiV3oVYEqFZZ8ol3mqYp3SZdoS52E0cTlh\nVa3d6c2BnYArU0rtfnDPK+uVPqcQ5+FVYMxwjlGLongUmAus1uYiY4DXOq1PukchzsPLRABX7rSA\nlNI4YCXaOC4uPyzWBh7vtD7pHh0TZ8Dd3cxmAZNoc3Cq1DomhtiyHlMURTu75ROAl9z9hWotlV5Q\niPNxN7ANcG07Ly6K4knKAbEObAPc0+Gy0mXanc7HZcCBZrZ4F+o6DLi0C/VIDRTiTLj7Y8D9wF5N\n1mNmGxKDX//VZD1SH4U4L2cBZ5rZe5sovJxUch5wjru/0UQdUj+FOCPu/mPgJuKa3yYcAywGfL2h\n8qUBuhQxM2Y2FpgNnO7u02ssd1vgamCrctddMqEtcWbc/UXgOuBiMzu4jjLNbHvgZuAFBTg/CnFm\nzOxLxG7v14BTzOzCTo+RzWyUmZ0EXEXczWNdM7uztsZKVyjEGSkDfCpwmrt/BtiQeA8fMLP92z39\nZGYjzOxjxCWHWwOT3P1fgI8DWyjIedExcSbeEeAvveN3HwX+ngj1fxJ3rZwFPNt6CoSZvY+Y8bU5\n8DfAH4GvEo9z8fnK2oU4vfQzd9+y2V5JHRTiDCwswO943QeB/Ylb+Uwinsf0KjCKuLvHfUS4r3T3\nuxdSjoKcEYW4z7Ub4AUsZ8BSwOLEZYwvt3trn3J5BTkTCnEf6zTANdavIGdAA1t9qtcBBnD369Fg\nV99TiPtQPwS4RUHufwpxn+mnALcoyP1NIe4j/RjgFgW5fynEf
aKfA9yiIPcnhbgP5BDgFgW5/yjE\nPZZTgFsU5P6iEPdQjgFuUZD7h0LcIzkHuEVB7g8KcQ8MQoBbFOTeU4i7bJAC3KIg95ZC3EWDGOAW\nBbl3FOIuGeQAtyjIvaEQd8GiEOAWBbn7FOKGLUoBblGQu0shbtCiGOAWBbl7FOKGLMoBblGQu0Mh\nboAC/DYFuXkKcc0U4L+kIDdLIa6RAvzuFOTmKMQ1UYCHpiA3QyGugQLcPgW5fgrxMCnA1SnI9dJ9\npztgZhOAN4BDUIA7Nv99rYFdgVXc/cHetio/CnFFZjYO+CWwbPkjBXgY5gsyxJMqNnL3X/SwSdlZ\npEJcPtpkAvGcog8Tjzlx4A/Eg7tnuftvhyhjBrAb8XgUgNXd/enGGl2RmY0G1if6OAEYQ4TjWWAm\nMNvd/9y7Fv5/ZrYE8fcfDRjwCDDR3d9YyDLLABsTfVwBGAnMBQqij4+6+5sNN71vLBLHxGa2spmd\nCvwauBXYt/zVs8BvgLHAscAvzOwBM/u0mS29gHJ2A3YhAvwa8WTBxZrvwdDMbBMzmw78HrgE2BL4\nE/B0+bN1gH8DXjCza81sipn1w/vvxHvwKhHi1YET3vkiMxttZp80s1uJ9/F0YFVgTvn9XOCvgO8D\nc8zsPDP7UFd60GvuPrBfwBLAOcQn/fnA+kO83oCPAlcTK/6neXtvZdmyHAf+TDyYe9k+6OME4MfA\nr4jHmy43xOuXBA4lnpD4ELBZH/RhceAs4JXy7/sKsO58v/848WF0CzAVGDVEeeOBLwO/Ba4Alu91\nHxv9+/W6AQ2uGJsSu1cdvYnAusDdwE3EJ/4M3t713rUP+mfAkcDvgOOAxTpY/lPA82WARvZBnyaX\nH0ZvAQ8DywDfKn+2QwflLQGcTWzpd+91/xr7u/W6AQ2tDFOI3aypwyxnJHBK+YnufbT1NeAr5Zb0\nQ8Msa0XgRuAaYPE+6Ftrq/wm8FQZ4qWGWeaWwDPAkb3uXxNfAzewZWZbE8dFu7v7T2sq8zPAicBk\nd3+mjjKH2Z7TiVMyO7r7H2oobzRwJXHa7FNe4TnGTTCz9wB3AfcAh3oNK2l5WvAW4FR3v2S45fWT\ngQpxOWr5IHCYu99Qc9mnANsCO9WxUg2jHTsCFwOT3H1OjeUuTqzkV7j7v9dVbodtOQ94P7B3nX/r\ncqDrDmBrd3+0rnJ7bdBCPB2Y6+5HNVD2SOBO4JvufmHd5bfZhvcCDwBHuPuPGih/baKPm7v7L+su\nv802bAtcTgxCDnsvYwHlHw3sTwR5IE5DDUyIzWwT4HvAOu7+ckN1rAf8hDg33EgdQ9R/OjDe3Q9q\nsI7jgS3cfc+m6lhI3Uacrz/N3b/XUB0jiNOMFw/KbnU/nCesy1HA15oMl7s/DNwG7NdUHe+mPG49\njBj0adIFwPZmtmrD9SzIVsTA1vebqqA83v8Ksb4MhIEIcTkVcg/iWLFp59ObFWAP4BF3f6TJSsoP\nwW8DhzdZz7s4CrigC2MOPwRWMLPJDdfTFQMRYmAb4OftDvSklHZMKd063/erpJSKlNJ721j8ZmA1\nM1uxs6Z2bArw3SoLpJT2SSnNSyktV7GuGcDOFZcZlnJXemfa7GNK6aSU0j/N9/2IlNLslNIGQy1b\nHgtfTZf72JRBCfEkYs5sW4qiuAl4OqV0YPmjc4ETi6L401DLlrtj95Z1dlOlPpb2BR4H9qq43H3A\n+uVgXreMB+a5+7Ntvv5c4K9TSquU3x8M/LwoigfaXH4m3X8PGzEoId6QCFYVnwe+mFLaHVi6KIoq\nW7l7gY0q1rdQZjbOzG42s+lmdqiZbWhmo8rfjQISMTLdlpTSOGLW2heAfaq0xd1fIuYj1zr32My2\nM7M7zexsM9vLzD5QboEh3sP72i2rKIq5wBnAl1NK7yFmrZ1coTn3lnVmr5uftE0aB7xQZYGiKH6X\nUjqXmORQdWWdAxxjZp+ou
NzCrAqsVP5/KjH1cIyZPUFsNV5391crlDcVuA64AbgopbRKURTtbuUg\nLii4yszqHCjchJj5thnwMrH+uZk9TExnfbFied8GjgEuAi4piqLKOjCHWG+yNyghNmLlqGoi8CQx\nZ/eJCsu9Rez+je+gznYsSVwEMA9Yk3ifqu417QucURTFmyml7xLzpL9aYflRxJVPTRgBLE0EeQxx\n6eTzVDxcKIrCU0r/CFxG3KChircYkD3RgegEsTKMrbJASmlTYD1gB+C0lNJSFRYfC5zh7lbXF7El\nfpG44cBlwOeIGWJLlu0caWZtXfaYUlqV2Nqdm1KaTVyit3eF/kFMwZxccx/3Jbbws4DziFNm65R9\nPIHOLuv8FfBcURSvVVxuLPBSB/X1nUEJ8UPEVrUtKaWRxKmizxZF8Rxxauq0CvVNJKZ31qYc0FnW\n3T/o7ge4+zR3n+Xur5W70c/Q/m7/PsDXi6KYWBTFhsTx9LiU0prtLFxOwVybuEC/Nu5+OXExw2R3\n/5y7X+nuj5enlB4ChhxZrtHEss7sDUqIZ1FtpPELwG1FUTxcfn8esFNKaf02l59U1lmrIc6PVunj\nPsD01jdFUThxNVC7W+MPA4+7+yttvr5tC+ljAaxYzn/vhsk08B72wkBMuzSz8cRo42ruPrfhujYA\nri/r6tofz8yOArZ196q7xZ3UdRKwchNz0Ieo90fAJeUWu+m67gFOrvtCmV4YiC2xxz2u7iYGb5p2\nFDCtB1cyfQeYYmbvb7KS8tzwEcC0Jut5F9Powmy4cp79csB/N11XNwxEiEvnA3/X7uBPJ8xseeKD\n4htN1fFu3P1/iJlUTa/kewK/dvfZDdezINcCE7owHfJY4EJdxdRnyqtTbgGucfd/baiOK4kV/Lgm\nym+j/tWJ0zDbeQO3dTWzZYkBu/3c/ba6y2+zDQcRIdvE3ec1UP4U4EJgA3cfcoZeDgYmxABmtiZx\nR4ht6r7o28ymEjOENmr6uHuIdhwB/C2wVZ0reTlz6lLgj+7+2brK7bAd1wEz3f3Umstehpj1doi7\n31Rn2b00SLvTuPvjxPnG6+u8lM7MtiJ21w/oZYBL04hb7V5a89zmU4hTPP9QY5mVlWMNhwOHmNmB\nQ72+XWa2JPADYMYgBRgGLMQA7n4xEbif1HHfYTP7GHETuf3c/Z7hljdc5Uq+D/A+4AozqzJJ5S+Y\n2WJmdhZx+mln74Mby5fnzKcA/2xmR883v7ojZrYCcTPAx4Dja2hiXxm4EAO4+znEfYdvN7PPdzLY\nZWZLm9kFxPHTHu5+Y93t7FQ5+WM3YsbR/Wa2XSflmNk6wE+Jc6bbuPvz9bVyeMpj/u2IrfIPzGzl\nTsopD4MeIMZLDvMe3wSwEd4Ht9xs6ouYd3wbcXXMwcASbSyzPPBF4iqebwJje92PIdq7GzGbawaw\nPeU4xxDLTAT+g7gI4EhgRK/7sZC2jiZm080h7sgxoY1lRgKfIE4hPULcM6znfWnqa6AGthakHLXe\nhVhZNyPukTUTuJ+YqzyCCO7GxBZpM+JeXee7e9Xrd3uivIHeAcDRxAp8BzEbqSDmKo8GPkDM+Nqc\nuNf0hcA33P03vWhzVWa2FvFEjoOAR4l5AbOI8YF5xHO11iP6uB3xIXw+cJW7V51XnZWBD/H8ypld\nW/H2A9WWJq5m+b8HqgF3eAN3WeyG8thxY+KSv0nAWsQ9q14nttazyq87fSEPLOtn5QPYtib613qg\n2ijiqq9Hif7d5Q3fxqifLFIhFhlEAzmwJbIoUYhFMqcQi2ROIRbJnEIskjmFWCRzCrFI5hRikcwp\nxCKZU4hFMqcQi2ROIRbJnEIskjmFWCRzCrFI5hRikcwpxCKZU4hFMqcQi2ROIRbJnEIskjmFWCRz\nCrFI5hRikcwpxCKZU4hFMqcQi2ROIRbJnEIskjmFWCRzCrFI5hRikcwpxCKZU4hFMqcQi2R
OIRbJ\nnEIskjmFWCRzCrFI5hRikcwpxCKZU4hFMqcQi2ROIRbJnEIskjmFWCRzCrFI5hRikcz9L/I0184T\nXIh7AAAAAElFTkSuQmCC\n", 644 | "text/plain": [ 645 | "
" 646 | ] 647 | }, 648 | "metadata": { 649 | "tags": [] 650 | }, 651 | "output_type": "display_data" 652 | } 653 | ], 654 | "source": [ 655 | "#@title # Game 7\n", 656 | "#@markdown This game tests the front door adjustment. B is taken to be unobserved, and therefore we cannot close the backdoor path X <- B -> Y. (The graph passed to pgmpy does not mark B as latent, so pgmpy still reports {B} as a backdoor adjustment set.) \n", 657 | "pgm = PGM(shape=[4, 3])\n", 658 | "\n", 659 | "pgm.add_node(daft.Node('X', r\"X\", 1, 1))\n", 660 | "pgm.add_node(daft.Node('Y', r\"Y\", 3, 1))\n", 661 | "pgm.add_node(daft.Node('A', r\"A\", 2, 1))\n", 662 | "pgm.add_node(daft.Node('B', r\"B\", 2, 2))\n", 663 | "\n", 664 | "\n", 665 | "pgm.add_edge('X', 'A')\n", 666 | "pgm.add_edge('A', 'Y')\n", 667 | "pgm.add_edge('B', 'X')\n", 668 | "pgm.add_edge('B', 'Y')\n", 669 | "\n", 670 | "pgm.render()\n", 671 | "plt.show()" 672 | ] 673 | }, 674 | { 675 | "cell_type": "code", 676 | "execution_count": 41, 677 | "metadata": { 678 | "colab": { 679 | "base_uri": "https://localhost:8080/", 680 | "height": 68 681 | }, 682 | "colab_type": "code", 683 | "id": "m8DZd_FQ4uLV", 684 | "outputId": "b73751d4-8aa9-4ec1-ba10-d509fe13c1a1" 685 | }, 686 | "outputs": [ 687 | { 688 | "name": "stdout", 689 | "output_type": "stream", 690 | "text": [ 691 | "Are there are active backdoor paths? True\n", 692 | "If so, what's the possible backdoor adjustment sets? frozenset({frozenset({'B'})})\n", 693 | "Ehat's the possible front adjustment sets? frozenset({frozenset({'A'})})\n" 694 | ] 695 | } 696 | ], 697 | "source": [ 698 | "graph = convert_pgm_to_pgmpy(pgm)\n", 699 | "inference = CausalInference(graph)\n", 700 | "print(f\"Are there are active backdoor paths? {inference._has_active_backdoors('X', 'Y')}\")\n", 701 | "bd_adj_sets = inference.get_all_backdoor_adjustment_sets(\"X\", \"Y\")\n", 702 | "print(f\"If so, what's the possible backdoor adjustment sets? {bd_adj_sets}\")\n", 703 | "fd_adj_sets = inference.get_all_frontdoor_adjustment_sets(\"X\", \"Y\")\n", 704 | "print(f\"Ehat's the possible front adjustment sets? 
{fd_adj_sets}\")" 705 | ] 706 | }, 707 | { 708 | "cell_type": "code", 709 | "execution_count": 0, 710 | "metadata": { 711 | "colab": {}, 712 | "colab_type": "code", 713 | "id": "8zeU4DX0Bbxl" 714 | }, 715 | "outputs": [], 716 | "source": [] 717 | } 718 | ], 719 | "metadata": { 720 | "colab": { 721 | "collapsed_sections": [], 722 | "include_colab_link": true, 723 | "name": "Causal Games.ipynb", 724 | "provenance": [], 725 | "toc_visible": true, 726 | "version": "0.3.2" 727 | }, 728 | "kernelspec": { 729 | "display_name": "Python 3", 730 | "language": "python", 731 | "name": "python3" 732 | }, 733 | "language_info": { 734 | "codemirror_mode": { 735 | "name": "ipython", 736 | "version": 3 737 | }, 738 | "file_extension": ".py", 739 | "mimetype": "text/x-python", 740 | "name": "python", 741 | "nbconvert_exporter": "python", 742 | "pygments_lexer": "ipython3", 743 | "version": "3.8.10" 744 | } 745 | }, 746 | "nbformat": 4, 747 | "nbformat_minor": 1 748 | } 749 | -------------------------------------------------------------------------------- /notebooks/4. Markov Models.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Markov Networks" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": { 14 | "collapsed": true 15 | }, 16 | "outputs": [], 17 | "source": [ 18 | "from IPython.display import Image" 19 | ] 20 | }, 21 | { 22 | "cell_type": "markdown", 23 | "metadata": {}, 24 | "source": [ 25 | "## Markov Models\n", 26 | "1. What are Markov Models\n", 27 | "2. Independencies in Markov Models\n", 28 | "3. How a Markov Model encodes the Joint Probability Distribution\n", 29 | "4. How to do inference with Markov Models\n", 30 | "5. Types of methods for inference" 31 | ] 32 | }, 33 | { 34 | "cell_type": "markdown", 35 | "metadata": {}, 36 | "source": [ 37 | "### 1. 
What are Markov Models\n" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": null, 43 | "metadata": { 44 | "collapsed": true 45 | }, 46 | "outputs": [], 47 | "source": [] 48 | } 49 | ], 50 | "metadata": { 51 | "kernelspec": { 52 | "display_name": "Python 3", 53 | "language": "python", 54 | "name": "python3" 55 | }, 56 | "language_info": { 57 | "codemirror_mode": { 58 | "name": "ipython", 59 | "version": 3 60 | }, 61 | "file_extension": ".py", 62 | "mimetype": "text/x-python", 63 | "name": "python", 64 | "nbconvert_exporter": "python", 65 | "pygments_lexer": "ipython3", 66 | "version": "3.8.10" 67 | } 68 | }, 69 | "nbformat": 4, 70 | "nbformat_minor": 1 71 | } 72 | -------------------------------------------------------------------------------- /notebooks/5. Exact Inference in Graphical Models.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Exact Inference in Graphical Models" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "### Inference\n", 15 | "Inference is the same as asking conditional probability questions of the model. For example, in our student example we might want to know the probability of a student getting a good grade given that the student is intelligent, which is equivalent to asking $P(g^1 | i^1)$. Inference algorithms deal with efficiently answering these conditional probability queries.\n", 16 | "\n", 17 | "There are two main categories of inference algorithms:\n", 18 | "\n", 19 | "1. Exact Inference: These algorithms find the exact probability values for our queries.\n", 20 | "\n", 21 | "2. Approximate Inference: These algorithms find approximate values, trading some accuracy for reduced computation."
22 | ] 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "metadata": {}, 27 | "source": [ 28 | "### Exact Inference\n", 29 | "There are multiple algorithms for doing exact inference. In this notebook we will mainly discuss two very common ones:\n", 30 | "\n", 31 | "1. Variable Elimination\n", 32 | "\n", 33 | "2. Clique Tree Belief Propagation" 34 | ] 35 | }, 36 | { 37 | "cell_type": "markdown", 38 | "metadata": {}, 39 | "source": [ 40 | "### Variable Elimination\n", 41 | "The basic idea of variable elimination is the same as marginalizing the joint distribution, but variable elimination avoids computing the joint distribution by marginalizing over much smaller factors. To eliminate $X$ from our distribution, we multiply together only the factors that involve $X$ and sum $X$ out of that product, so we only ever work with small factors. Let's use the student example to make this clearer:\n", 42 | "\n", 43 | "$$ P(D) = \\sum_I \\sum_S \\sum_G \\sum_L P(D, I, S, G, L) $$\n", 44 | "$$ P(D) = \\sum_I \\sum_S \\sum_G \\sum_L P(D) * P(I) * P(S | I) * P(G | D, I) * P(L | G) $$\n", 45 | "$$ P(D) = P(D) \\sum_I P(I) \\sum_S P(S | I) \\sum_G P(G | D, I) \\sum_L P(L | G) $$\n", 46 | "\n", 47 | "In the above equations, we pushed each summation as far inside as possible so that it applies only to the factors that involve its variable, thereby avoiding computing the complete joint distribution.\n", 48 | "\n", 49 | "Let's now see some code examples:" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": null, 55 | "metadata": {}, 56 | "outputs": [], 57 | "source": [ 58 | "## Add code examples for Variable Elimination" 59 | ] 60 | }, 61 | { 62 | "cell_type": "markdown", 63 | "metadata": {}, 64 | "source": [ 65 | "### Clique Tree Belief Propagation" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": null, 71 | "metadata": {}, 72 | "outputs": [], 73 | "source": [] 74 | } 75 | ],
76 | "metadata": { 77 | "kernelspec": { 78 | "display_name": "Python 3", 79 | "language": "python", 80 | "name": "python3" 81 | }, 82 | "language_info": { 83 | "codemirror_mode": { 84 | "name": "ipython", 85 | "version": 3 86 | }, 87 | "file_extension": ".py", 88 | "mimetype": "text/x-python", 89 | "name": "python", 90 | "nbconvert_exporter": "python", 91 | "pygments_lexer": "ipython3", 92 | "version": "3.8.10" 93 | } 94 | }, 95 | "nbformat": 4, 96 | "nbformat_minor": 1 97 | } 98 | -------------------------------------------------------------------------------- /notebooks/6. Approximate Inference in Graphical Models.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "collapsed": true 7 | }, 8 | "source": [ 9 | "# Approximate Inference in Graphical Models" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [] 18 | } 19 | ], 20 | "metadata": { 21 | "kernelspec": { 22 | "display_name": "Python 3", 23 | "language": "python", 24 | "name": "python3" 25 | }, 26 | "language_info": { 27 | "codemirror_mode": { 28 | "name": "ipython", 29 | "version": 3 30 | }, 31 | "file_extension": ".py", 32 | "mimetype": "text/x-python", 33 | "name": "python", 34 | "nbconvert_exporter": "python", 35 | "pygments_lexer": "ipython3", 36 | "version": "3.8.10" 37 | } 38 | }, 39 | "nbformat": 4, 40 | "nbformat_minor": 1 41 | } 42 | -------------------------------------------------------------------------------- /notebooks/7. 
Parameterizing with Continuous Variables.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Parameterizing with Continuous Variables" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "from IPython.display import Image" 17 | ] 18 | }, 19 | { 20 | "cell_type": "markdown", 21 | "metadata": {}, 22 | "source": [ 23 | "## Continuous Factors" 24 | ] 25 | }, 26 | { 27 | "cell_type": "markdown", 28 | "metadata": {}, 29 | "source": [ 30 | "1. Base Class for Continuous Factors\n", 31 | "2. Joint Gaussian Distributions\n", 32 | "3. Canonical Factors\n", 33 | "4. Linear Gaussian CPD" 34 | ] 35 | }, 36 | { 37 | "cell_type": "markdown", 38 | "metadata": {}, 39 | "source": [ 40 | "In many situations, some variables are best modeled as taking values in some continuous space. Examples include variables such as position, velocity, temperature, and pressure. Clearly, we cannot use a table representation in this case. \n", 41 | "\n", 42 | "Nothing in the formulation of a Bayesian network requires that we restrict attention to discrete variables. The only requirement is that the CPD, $P(X | Y_1, Y_2, \\cdots, Y_n)$ represent, for every assignment of values $y_1 \\in Val(Y_1), y_2 \\in Val(Y_2), \\cdots, y_n \\in Val(Y_n)$, a distribution over $X$. In this case, $X$ might be continuous, in which case the CPD would need to represent distributions over a continuum of values; we might also have $X$’s parents continuous, so that the CPD would also need to represent a continuum of different probability distributions. There exist implicit representations for CPDs of this type, allowing us to apply all the network machinery to the continuous case as well."
43 | ] 44 | }, 45 | { 46 | "cell_type": "markdown", 47 | "metadata": {}, 48 | "source": [ 49 | "### Base Class for Continuous Factors" 50 | ] 51 | }, 52 | { 53 | "cell_type": "markdown", 54 | "metadata": {}, 55 | "source": [ 56 | "This class will behave as a base class for the continuous factor representations. All the present and future factor classes will be derived from this base class. We need to specify the variable names and a pdf function to initialize this class." 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": 3, 62 | "metadata": {}, 63 | "outputs": [], 64 | "source": [ 65 | "import numpy as np\n", 66 | "from scipy.special import beta\n", 67 | "\n", 68 | "# Two variable drichlet ditribution with alpha = (1,2)\n", 69 | "def drichlet_pdf(x, y):\n", 70 | " return (np.power(x, 1)*np.power(y, 2))/beta(x, y)\n", 71 | "\n", 72 | "from pgmpy.factors.continuous import ContinuousFactor\n", 73 | "drichlet_factor = ContinuousFactor(['x', 'y'], drichlet_pdf)" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": 4, 79 | "metadata": {}, 80 | "outputs": [ 81 | { 82 | "data": { 83 | "text/plain": [ 84 | "(['x', 'y'], 226800.0)" 85 | ] 86 | }, 87 | "execution_count": 4, 88 | "metadata": {}, 89 | "output_type": "execute_result" 90 | } 91 | ], 92 | "source": [ 93 | "drichlet_factor.scope(), drichlet_factor.assignment(5,6)" 94 | ] 95 | }, 96 | { 97 | "cell_type": "markdown", 98 | "metadata": {}, 99 | "source": [ 100 | "This class supports methods like **marginalize, reduce, product and divide** just like what we have with discrete classes. One caveat is that when there are a number of variables involved, these methods prove to be inefficient and hence we resort to certain Gaussian or some other approximations which are discussed later." 
101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": 5, 106 | "metadata": {}, 107 | "outputs": [], 108 | "source": [ 109 | "def custom_pdf(x, y, z):\n", 110 | " return z*(np.power(x, 1)*np.power(y, 2))/beta(x, y)\n", 111 | "\n", 112 | "custom_factor = ContinuousFactor(['x', 'y', 'z'], custom_pdf)" 113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": 6, 118 | "metadata": {}, 119 | "outputs": [ 120 | { 121 | "data": { 122 | "text/plain": [ 123 | "(['x', 'y', 'z'], 24.0)" 124 | ] 125 | }, 126 | "execution_count": 6, 127 | "metadata": {}, 128 | "output_type": "execute_result" 129 | } 130 | ], 131 | "source": [ 132 | "custom_factor.scope(), custom_factor.assignment(1, 2, 3)" 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": 7, 138 | "metadata": {}, 139 | "outputs": [ 140 | { 141 | "data": { 142 | "text/plain": [ 143 | "(['x', 'z'], 24.0)" 144 | ] 145 | }, 146 | "execution_count": 7, 147 | "metadata": {}, 148 | "output_type": "execute_result" 149 | } 150 | ], 151 | "source": [ 152 | "custom_factor.reduce([('y', 2)])\n", 153 | "custom_factor.scope(), custom_factor.assignment(1, 3)" 154 | ] 155 | }, 156 | { 157 | "cell_type": "code", 158 | "execution_count": 8, 159 | "metadata": {}, 160 | "outputs": [ 161 | { 162 | "data": { 163 | "text/plain": [ 164 | "(['x1', 'x2'], 0.058549831524319168)" 165 | ] 166 | }, 167 | "execution_count": 8, 168 | "metadata": {}, 169 | "output_type": "execute_result" 170 | } 171 | ], 172 | "source": [ 173 | "from scipy.stats import multivariate_normal\n", 174 | "\n", 175 | "std_normal_pdf = lambda *x: multivariate_normal.pdf(x, [0, 0], [[1, 0], [0, 1]])\n", 176 | "std_normal = ContinuousFactor(['x1', 'x2'], std_normal_pdf)\n", 177 | "std_normal.scope(), std_normal.assignment([1, 1])" 178 | ] 179 | }, 180 | { 181 | "cell_type": "code", 182 | "execution_count": 9, 183 | "metadata": {}, 184 | "outputs": [ 185 | { 186 | "data": { 187 | "text/plain": [ 188 | "(['x1'], 
0.24197072451914328)" 189 | ] 190 | }, 191 | "execution_count": 9, 192 | "metadata": {}, 193 | "output_type": "execute_result" 194 | } 195 | ], 196 | "source": [ 197 | "std_normal.marginalize(['x2'])\n", 198 | "std_normal.scope(), std_normal.assignment(1)" 199 | ] 200 | }, 201 | { 202 | "cell_type": "code", 203 | "execution_count": 10, 204 | "metadata": {}, 205 | "outputs": [ 206 | { 207 | "data": { 208 | "text/plain": [ 209 | "(0.063493635934240983, 0.3989422804014327)" 210 | ] 211 | }, 212 | "execution_count": 10, 213 | "metadata": {}, 214 | "output_type": "execute_result" 215 | } 216 | ], 217 | "source": [ 218 | "sn_pdf1 = lambda x: multivariate_normal.pdf([x], [0], [[1]])\n", 219 | "sn_pdf2 = lambda x1,x2: multivariate_normal.pdf([x1, x2], [0, 0], [[1, 0], [0, 1]])\n", 220 | "sn1 = ContinuousFactor(['x2'], sn_pdf1)\n", 221 | "sn2 = ContinuousFactor(['x1', 'x2'], sn_pdf2)\n", 222 | "sn3 = sn1 * sn2\n", 223 | "sn4 = sn2 / sn1\n", 224 | "sn3.assignment(0, 0), sn4.assignment(0, 0)" 225 | ] 226 | }, 227 | { 228 | "cell_type": "markdown", 229 | "metadata": {}, 230 | "source": [ 231 | "The ContinuousFactor class also has a method **discretize** that takes a pgmpy Discretizer class as input. It will output a list of discrete probability masses or a Factor or TabularCPD object depending upon the discretization method used. Although, we do not have inbuilt discretization algorithms for multivariate distributions for now, the users can always define their own Discretizer class by subclassing the pgmpy.BaseDiscretizer class." 232 | ] 233 | }, 234 | { 235 | "cell_type": "markdown", 236 | "metadata": {}, 237 | "source": [ 238 | "### Joint Gaussian Distributions" 239 | ] 240 | }, 241 | { 242 | "cell_type": "markdown", 243 | "metadata": {}, 244 | "source": [ 245 | "In its most common representation, a multivariate Gaussian distribution over $X_1 \\cdots X_n$ is characterized by an n-dimensional mean vector $\\mu$, and a symmetric $n \\times n$ covariance matrix $\\Sigma$. 
The density function is most commonly defined as:" 246 | ] 247 | }, 248 | { 249 | "cell_type": "markdown", 250 | "metadata": {}, 251 | "source": [ 252 | "$$p(x) = \\dfrac{1}{(2\\pi)^{n/2}| \\Sigma |^{1/2}} \\exp[-0.5*(x- \\mu )^T \\Sigma^{-1}(x- \\mu)]$$\n" 253 | ] 254 | }, 255 | { 256 | "cell_type": "markdown", 257 | "metadata": {}, 258 | "source": [ 259 | "The class pgmpy.factors.distributions.GaussianDistribution (imported as JGD below) provides its representation. This is derived from the class pgmpy.ContinuousFactor. We need to specify the variable names, a mean vector and a covariance matrix for its initialization. It will automatically compute the pdf function given these parameters." 260 | ] 261 | }, 262 | { 263 | "cell_type": "code", 264 | "execution_count": 11, 265 | "metadata": {}, 266 | "outputs": [ 267 | { 268 | "data": { 269 | "text/plain": [ 270 | "['x1', 'x2', 'x3']" 271 | ] 272 | }, 273 | "execution_count": 11, 274 | "metadata": {}, 275 | "output_type": "execute_result" 276 | } 277 | ], 278 | "source": [ 279 | "from pgmpy.factors.distributions import GaussianDistribution as JGD\n", 280 | "dis = JGD(['x1', 'x2', 'x3'], np.array([[1], [-3], [4]]),\n", 281 | " np.array([[4, 2, -2], [2, 5, -5], [-2, -5, 8]]))\n", 282 | "dis.variables" 283 | ] 284 | }, 285 | { 286 | "cell_type": "code", 287 | "execution_count": 12, 288 | "metadata": {}, 289 | "outputs": [ 290 | { 291 | "data": { 292 | "text/plain": [ 293 | "array([[ 1.],\n", 294 | " [-3.],\n", 295 | " [ 4.]])" 296 | ] 297 | }, 298 | "execution_count": 12, 299 | "metadata": {}, 300 | "output_type": "execute_result" 301 | } 302 | ], 303 | "source": [ 304 | "dis.mean" 305 | ] 306 | }, 307 | { 308 | "cell_type": "code", 309 | "execution_count": 13, 310 | "metadata": {}, 311 | "outputs": [ 312 | { 313 | "data": { 314 | "text/plain": [ 315 | "array([[ 4., 2., -2.],\n", 316 | " [ 2., 5., -5.],\n", 317 | " [-2., -5., 8.]])" 318 | ] 319 | }, 320 | "execution_count": 13, 321 | "metadata": {}, 322 | "output_type": "execute_result" 323 | } 324 | ], 325 | "source": [
326 | "dis.covariance" 327 | ] 328 | }, 329 | { 330 | "cell_type": "code", 331 | "execution_count": 14, 332 | "metadata": {}, 333 | "outputs": [ 334 | { 335 | "data": { 336 | "text/plain": [ 337 | "0.0014805631279234139" 338 | ] 339 | }, 340 | "execution_count": 14, 341 | "metadata": {}, 342 | "output_type": "execute_result" 343 | } 344 | ], 345 | "source": [ 346 | "dis.pdf([0,0,0])" 347 | ] 348 | }, 349 | { 350 | "cell_type": "markdown", 351 | "metadata": {}, 352 | "source": [ 353 | "This class overrides the basic operation methods **(marginalize, reduce, normalize, product and divide)** because these operations can be implemented more efficiently here than the generic versions in its parent class. Most of these operations involve a matrix inversion, which is $\\mathcal{O}(n^3)$ with respect to the number of variables." 354 | ] 355 | }, 356 | { 357 | "cell_type": "code", 358 | "execution_count": 15, 359 | "metadata": {}, 360 | "outputs": [ 361 | { 362 | "data": { 363 | "text/plain": [ 364 | "['x1', 'x2', 'x3', 'x4']" 365 | ] 366 | }, 367 | "execution_count": 15, 368 | "metadata": {}, 369 | "output_type": "execute_result" 370 | } 371 | ], 372 | "source": [ 373 | "dis1 = JGD(['x1', 'x2', 'x3'], np.array([[1], [-3], [4]]),\n", 374 | " np.array([[4, 2, -2], [2, 5, -5], [-2, -5, 8]]))\n", 375 | "dis2 = JGD(['x3', 'x4'], [1, 2], [[2, 3], [5, 6]])\n", 376 | "dis3 = dis1 * dis2\n", 377 | "dis3.variables" 378 | ] 379 | }, 380 | { 381 | "cell_type": "code", 382 | "execution_count": 16, 383 | "metadata": {}, 384 | "outputs": [ 385 | { 386 | "data": { 387 | "text/plain": [ 388 | "array([[ 1.6],\n", 389 | " [-1.5],\n", 390 | " [ 1.6],\n", 391 | " [ 3.5]])" 392 | ] 393 | }, 394 | "execution_count": 16, 395 | "metadata": {}, 396 | "output_type": "execute_result" 397 | } 398 | ], 399 | "source": [ 400 | "dis3.mean" 401 | ] 402 | }, 403 | { 404 | "cell_type": "code", 405 | "execution_count": 17, 406 | "metadata": {}, 407 | "outputs": [ 408 | { 409 | "data": { 410 | "text/plain": [ 411 | "array([[ 3.6, 1.
, -0.4, -0.6],\n", 412 | " [ 1. , 2.5, -1. , -1.5],\n", 413 | " [-0.4, -1. , 1.6, 2.4],\n", 414 | " [-1. , -2.5, 4. , 4.5]])" 415 | ] 416 | }, 417 | "execution_count": 17, 418 | "metadata": {}, 419 | "output_type": "execute_result" 420 | } 421 | ], 422 | "source": [ 423 | "dis3.covariance" 424 | ] 425 | }, 426 | { 427 | "cell_type": "markdown", 428 | "metadata": {}, 429 | "source": [ 430 | "The others methods can also be used in a similar fashion." 431 | ] 432 | }, 433 | { 434 | "cell_type": "markdown", 435 | "metadata": {}, 436 | "source": [ 437 | "### Canonical Factors" 438 | ] 439 | }, 440 | { 441 | "cell_type": "markdown", 442 | "metadata": {}, 443 | "source": [ 444 | "While the Joint Gaussian representation is useful for certain sampling algorithms, a closer look reveals that it can also not be used directly in the sum-product algorithms. Why? Because operations like product and reduce, as mentioned above involve matrix inversions at each step. \n", 445 | "\n", 446 | "So, in order to compactly describe the intermediate factors in a Gaussian network without the costly matrix inversions at each step, a simple parametric representation is used known as the Canonical Factor. This representation is closed under the basic operations used in inference: factor product, factor division, factor reduction, and marginalization. Thus, we can define a set of simple data structures that allow the inference process to be performed. 
Moreover, the integration operation required by marginalization is always well defined, and it is guaranteed to produce a finite integral under certain conditions; when it is well defined, it has a simple analytical solution.\n", 447 | "\n", 448 | "A canonical form $C (X; K,h, g)$ is defined as:" 449 | ] 450 | }, 451 | { 452 | "cell_type": "markdown", 453 | "metadata": {}, 454 | "source": [ 455 | "$$C(X; K,h,g) = \\exp(-0.5 X^T K X + h^T X + g)$$" 456 | ] 457 | }, 458 | { 459 | "cell_type": "markdown", 460 | "metadata": {}, 461 | "source": [ 462 | "We can represent every Gaussian as a canonical form. Rewriting the joint Gaussian pdf we obtain," 463 | ] 464 | }, 465 | { 466 | "cell_type": "markdown", 467 | "metadata": {}, 468 | "source": [ 469 | "$\\mathcal{N}(\\mu; \\Sigma) = C(K, h, g)$ where:" 470 | ] 471 | }, 472 | { 473 | "cell_type": "markdown", 474 | "metadata": {}, 475 | "source": [ 476 | "$$K = \\Sigma^{-1}$$\n", 477 | "$$h = \\Sigma^{-1} \\mu$$\n", 478 | "$$g = -0.5 \\mu^T \\Sigma^{-1} \\mu - \\log\\left((2 \\pi)^{n/2} | \\Sigma |^{1/2}\\right)$$" 479 | ] 480 | }, 481 | { 482 | "cell_type": "markdown", 483 | "metadata": {}, 484 | "source": [ 485 | "Similar to the joint Gaussian distribution class, the CanonicalDistribution class is also derived from the ContinuousFactor class but with its own implementations of the methods required for the sum-product algorithms that are much more efficient than its parent class methods. Let us have a look at the API of a few methods in this class."
486 | ] 487 | }, 488 | { 489 | "cell_type": "code", 490 | "execution_count": 18, 491 | "metadata": {}, 492 | "outputs": [ 493 | { 494 | "data": { 495 | "text/plain": [ 496 | "['x1', 'x2', 'x3']" 497 | ] 498 | }, 499 | "execution_count": 18, 500 | "metadata": {}, 501 | "output_type": "execute_result" 502 | } 503 | ], 504 | "source": [ 505 | "from pgmpy.factors.continuous import CanonicalDistribution\n", 506 | "\n", 507 | "phi1 = CanonicalDistribution(['x1', 'x2', 'x3'],\n", 508 | " np.array([[1, -1, 0], [-1, 4, -2], [0, -2, 4]]),\n", 509 | " np.array([[1], [4], [-1]]), -2)\n", 510 | "phi2 = CanonicalDistribution(['x1', 'x2'], np.array([[3, -2], [-2, 4]]),\n", 511 | " np.array([[5], [-1]]), 1)\n", 512 | "\n", 513 | "phi3 = phi1 * phi2\n", 514 | "phi3.variables" 515 | ] 516 | }, 517 | { 518 | "cell_type": "code", 519 | "execution_count": 19, 520 | "metadata": {}, 521 | "outputs": [ 522 | { 523 | "data": { 524 | "text/plain": [ 525 | "array([[ 6.],\n", 526 | " [ 3.],\n", 527 | " [-1.]])" 528 | ] 529 | }, 530 | "execution_count": 19, 531 | "metadata": {}, 532 | "output_type": "execute_result" 533 | } 534 | ], 535 | "source": [ 536 | "phi3.h" 537 | ] 538 | }, 539 | { 540 | "cell_type": "code", 541 | "execution_count": 20, 542 | "metadata": {}, 543 | "outputs": [ 544 | { 545 | "data": { 546 | "text/plain": [ 547 | "array([[ 4., -3., 0.],\n", 548 | " [-3., 8., -2.],\n", 549 | " [ 0., -2., 4.]])" 550 | ] 551 | }, 552 | "execution_count": 20, 553 | "metadata": {}, 554 | "output_type": "execute_result" 555 | } 556 | ], 557 | "source": [ 558 | "phi3.K" 559 | ] 560 | }, 561 | { 562 | "cell_type": "code", 563 | "execution_count": 21, 564 | "metadata": {}, 565 | "outputs": [ 566 | { 567 | "data": { 568 | "text/plain": [ 569 | "-1" 570 | ] 571 | }, 572 | "execution_count": 21, 573 | "metadata": {}, 574 | "output_type": "execute_result" 575 | } 576 | ], 577 | "source": [ 578 | "phi3.g" 579 | ] 580 | }, 581 | { 582 | "cell_type": "markdown", 583 | "metadata": {}, 584 | "source": [ 
585 | "This class also has a method, `to_joint_gaussian`, to convert the canonical representation back into a joint Gaussian distribution." 586 | ] 587 | }, 588 | { 589 | "cell_type": "code", 590 | "execution_count": 22, 591 | "metadata": {}, 592 | "outputs": [ 593 | { 594 | "data": { 595 | "text/plain": [ 596 | "['x1', 'x2']" 597 | ] 598 | }, 599 | "execution_count": 22, 600 | "metadata": {}, 601 | "output_type": "execute_result" 602 | } 603 | ], 604 | "source": [ 605 | "phi = CanonicalDistribution(['x1', 'x2'], np.array([[3, -2], [-2, 4]]),\n", 606 | " np.array([[5], [-1]]), 1)\n", 607 | "jgd = phi.to_joint_gaussian()\n", 608 | "jgd.variables" 609 | ] 610 | }, 611 | { 612 | "cell_type": "code", 613 | "execution_count": 23, 614 | "metadata": {}, 615 | "outputs": [ 616 | { 617 | "data": { 618 | "text/plain": [ 619 | "array([[ 0.5 , 0.25 ],\n", 620 | " [ 0.25 , 0.375]])" 621 | ] 622 | }, 623 | "execution_count": 23, 624 | "metadata": {}, 625 | "output_type": "execute_result" 626 | } 627 | ], 628 | "source": [ 629 | "jgd.covariance" 630 | ] 631 | }, 632 | { 633 | "cell_type": "code", 634 | "execution_count": 24, 635 | "metadata": {}, 636 | "outputs": [ 637 | { 638 | "data": { 639 | "text/plain": [ 640 | "array([[ 2.25 ],\n", 641 | " [ 0.875]])" 642 | ] 643 | }, 644 | "execution_count": 24, 645 | "metadata": {}, 646 | "output_type": "execute_result" 647 | } 648 | ], 649 | "source": [ 650 | "jgd.mean" 651 | ] 652 | }, 653 | { 654 | "cell_type": "markdown", 655 | "metadata": {}, 656 | "source": [ 657 | "### Linear Gaussian CPD" 658 | ] 659 | }, 660 | { 661 | "cell_type": "markdown", 662 | "metadata": {}, 663 | "source": [ 664 | "A linear Gaussian conditional probability distribution is defined on a continuous variable. All the parents of this variable are also continuous.
The mean of this variable depends linearly on the values of its parent variables, while its variance $\\sigma^2$ does not depend on them.\n", 665 | "\n", 666 | "For example,\n", 667 | "$$P(Y | x_1, x_2, x_3) = \\mathcal{N}(\\beta_1 x_1 + \\beta_2 x_2 + \\beta_3 x_3 + \\beta_0 ; \\sigma^2)$$\n", 668 | "\n", 669 | "Let $Y$ be a linear Gaussian of its parents $X_1, \\cdots, X_k$:\n", 670 | "$$p(Y | x) = \\mathcal{N}(\\beta_0 + \\beta^T x ; \\sigma^2)$$\n", 671 | "\n", 672 | "The distribution of $Y$ is a normal distribution $p(Y)$ where:\n", 673 | "$$\\mu_Y = \\beta_0 + \\beta^T \\mu$$\n", 674 | "\n", 675 | "$$\\sigma^2_Y = \\sigma^2 + \\beta^T \\Sigma \\beta$$\n", 676 | "\n", 677 | "\n", 678 | "The joint distribution over $\\{X, Y\\}$ is a normal distribution where:\n", 679 | "\n", 680 | "$$\\mathrm{Cov}[X_i; Y] = \\sum_{j=1}^{k} \\beta_j \\Sigma_{i,j}$$\n", 681 | "\n", 682 | "\n", 683 | "The results above assume that $X_1, \\cdots, X_k$ are jointly Gaussian with distribution $\\mathcal{N}(\\mu; \\Sigma)$.\n\n", 684 | "For its representation, pgmpy has a class named LinearGaussianCPD in the module pgmpy.factors.continuous. To instantiate an object of this class, one needs to provide the variable name, a list of coefficients whose first element is the intercept $\\beta_0$ followed by one coefficient per parent (in the same order as the parents), the variance $\\sigma^2$, and the list of parent variable names. The list of parents can be omitted for a variable that has no parents, in which case the coefficient list contains only $\\beta_0$."
685 | ] 686 | }, 687 | { 688 | "cell_type": "code", 689 | "execution_count": 25, 690 | "metadata": {}, 691 | "outputs": [ 692 | { 693 | "name": "stdout", 694 | "output_type": "stream", 695 | "text": [ 696 | "P(Y | X1, X2, X3) = N(-2*X1 + 3*X2 + 7*X3 + 0.2; 9.6)\n" 697 | ] 698 | } 699 | ], 700 | "source": [ 701 | "# For P(Y| X1, X2, X3) = N(-2x1 + 3x2 + 7x3 + 0.2; 9.6)\n", 702 | "from pgmpy.factors.continuous import LinearGaussianCPD\n", 703 | "cpd = LinearGaussianCPD('Y', [0.2, -2, 3, 7], 9.6, ['X1', 'X2', 'X3'])\n", 704 | "print(cpd)" 705 | ] 706 | }, 707 | { 708 | "cell_type": "markdown", 709 | "metadata": { 710 | "collapsed": true 711 | }, 712 | "source": [ 713 | "A Gaussian Bayesian network is defined as a network all of whose variables are continuous, and where all of the CPDs are linear Gaussians. These networks are of particular interest because they are an alternative representation of a joint Gaussian distribution.\n", 714 | "\n", 715 | "These networks are implemented as the LinearGaussianBayesianNetwork class, available from pgmpy.models. This class is a subclass of the BayesianModel class in pgmpy.models and inherits most of its methods. It has a special method, to_joint_gaussian, that returns an equivalent joint Gaussian distribution object for the model."
716 | ] 717 | }, 718 | { 719 | "cell_type": "code", 720 | "execution_count": 26, 721 | "metadata": {}, 722 | "outputs": [ 723 | { 724 | "data": { 725 | "text/plain": [ 726 | "['x1', 'x2', 'x3']" 727 | ] 728 | }, 729 | "execution_count": 26, 730 | "metadata": {}, 731 | "output_type": "execute_result" 732 | } 733 | ], 734 | "source": [ 735 | "from pgmpy.models import LinearGaussianBayesianNetwork\n", 736 | "\n", 737 | "model = LinearGaussianBayesianNetwork([('x1', 'x2'), ('x2', 'x3')])\n", 738 | "cpd1 = LinearGaussianCPD('x1', [1], 4)\n", 739 | "cpd2 = LinearGaussianCPD('x2', [-5, 0.5], 4, ['x1'])\n", 740 | "cpd3 = LinearGaussianCPD('x3', [4, -1], 3, ['x2'])\n", 741 | "# This is a hack due to a bug in pgmpy (LinearGaussianCPD\n", 742 | "# doesn't have `variables` attribute but `add_cpds` function\n", 743 | "# wants to check that...)\n", 744 | "cpd1.variables = [*cpd1.evidence, cpd1.variable]\n", 745 | "cpd2.variables = [*cpd2.evidence, cpd2.variable]\n", 746 | "cpd3.variables = [*cpd3.evidence, cpd3.variable]\n", 747 | "model.add_cpds(cpd1, cpd2, cpd3)\n", 748 | "jgd = model.to_joint_gaussian()\n", 749 | "jgd.variables" 750 | ] 751 | }, 752 | { 753 | "cell_type": "code", 754 | "execution_count": 27, 755 | "metadata": {}, 756 | "outputs": [ 757 | { 758 | "data": { 759 | "text/plain": [ 760 | "array([[ 1. 
],\n", 761 | " [-4.5],\n", 762 | " [ 8.5]])" 763 | ] 764 | }, 765 | "execution_count": 27, 766 | "metadata": {}, 767 | "output_type": "execute_result" 768 | } 769 | ], 770 | "source": [ 771 | "jgd.mean" 772 | ] 773 | }, 774 | { 775 | "cell_type": "code", 776 | "execution_count": 28, 777 | "metadata": {}, 778 | "outputs": [ 779 | { 780 | "data": { 781 | "text/plain": [ 782 | "array([[ 4., 2., -2.],\n", 783 | " [ 2., 5., -5.],\n", 784 | " [-2., -5., 8.]])" 785 | ] 786 | }, 787 | "execution_count": 28, 788 | "metadata": {}, 789 | "output_type": "execute_result" 790 | } 791 | ], 792 | "source": [ 793 | "jgd.covariance" 794 | ] 795 | } 796 | ], 797 | "metadata": { 798 | "kernelspec": { 799 | "display_name": "Python 3 (ipykernel)", 800 | "language": "python", 801 | "name": "python3" 802 | }, 803 | "language_info": { 804 | "codemirror_mode": { 805 | "name": "ipython", 806 | "version": 3 807 | }, 808 | "file_extension": ".py", 809 | "mimetype": "text/x-python", 810 | "name": "python", 811 | "nbconvert_exporter": "python", 812 | "pygments_lexer": "ipython3", 813 | "version": "3.8.12" 814 | } 815 | }, 816 | "nbformat": 4, 817 | "nbformat_minor": 1 818 | } 819 | -------------------------------------------------------------------------------- /notebooks/9. Reading and Writing from pgmpy file formats.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Reading and Writing from pgmpy file formats" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "pgmpy is a python library for creation, manipulation and implementation of Probabilistic graph models. There are various standard file formats for representing PGM data. 
PGM data basically consists of a graph, a distribution associated with each node and a few other attributes of the graph.\n", 15 | "\n", 16 | "pgmpy has functionality to read networks from and write networks to these standard file formats. Currently pgmpy supports 5 file formats: ProbModelXML, PomDPX, XMLBIF, XMLBeliefNetwork and UAI. Using these modules, models can be specified in a uniform file format and readily converted to Bayesian or Markov model objects. \n", 17 | "\n", 18 | "Now, let's read a ProbModelXML file and get the corresponding model instance.\n" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 55, 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "from pgmpy.readwrite import ProbModelXMLReader" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": 56, 33 | "metadata": {}, 34 | "outputs": [], 35 | "source": [ 36 | "reader_string = ProbModelXMLReader('../files/example.pgmx')" 37 | ] 38 | }, 39 | { 40 | "cell_type": "markdown", 41 | "metadata": {}, 42 | "source": [ 43 | "Now, to get the corresponding model instance, we need `get_model`:" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": 57, 49 | "metadata": {}, 50 | "outputs": [], 51 | "source": [ 52 | "model = reader_string.get_model()" 53 | ] 54 | }, 55 | { 56 | "cell_type": "markdown", 57 | "metadata": {}, 58 | "source": [ 59 | "Now we can query this model according to our requirements.
It is an instance of BayesianModel or MarkovModel depending on the type of the model which is given.\n", 60 | "\n", 61 | "Suppose we want to know all the nodes in the given model, we can do:" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": 58, 67 | "metadata": {}, 68 | "outputs": [ 69 | { 70 | "name": "stdout", 71 | "output_type": "stream", 72 | "text": [ 73 | "['X-ray', 'Bronchitis', 'Dyspnea', 'VisitToAsia', 'Smoker', 'LungCancer', 'Tuberculosis', 'TuberculosisOrCancer']\n" 74 | ] 75 | } 76 | ], 77 | "source": [ 78 | "print(model.nodes())" 79 | ] 80 | }, 81 | { 82 | "cell_type": "markdown", 83 | "metadata": {}, 84 | "source": [ 85 | "To get all the edges we can use `model.edges` method." 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": 59, 91 | "metadata": {}, 92 | "outputs": [ 93 | { 94 | "data": { 95 | "text/plain": [ 96 | "OutEdgeView([('Bronchitis', 'Dyspnea'), ('VisitToAsia', 'Tuberculosis'), ('Smoker', 'Bronchitis'), ('Smoker', 'LungCancer'), ('LungCancer', 'TuberculosisOrCancer'), ('Tuberculosis', 'TuberculosisOrCancer'), ('TuberculosisOrCancer', 'Dyspnea'), ('TuberculosisOrCancer', 'X-ray')])" 97 | ] 98 | }, 99 | "execution_count": 59, 100 | "metadata": {}, 101 | "output_type": "execute_result" 102 | } 103 | ], 104 | "source": [ 105 | "model.edges()" 106 | ] 107 | }, 108 | { 109 | "cell_type": "markdown", 110 | "metadata": {}, 111 | "source": [ 112 | "To get all the cpds of the given model we can use `model.get_cpds` and to get the corresponding values we can iterate over each cpd and call the corresponding `get_cpd` method." 
113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": 60, 118 | "metadata": {}, 119 | "outputs": [ 120 | { 121 | "name": "stdout", 122 | "output_type": "stream", 123 | "text": [ 124 | "+----------------------+-------------------------+-------------------------+\n", 125 | "| TuberculosisOrCancer | TuberculosisOrCancer(0) | TuberculosisOrCancer(1) |\n", 126 | "+----------------------+-------------------------+-------------------------+\n", 127 | "| X-ray(0) | 0.95 | 0.05 |\n", 128 | "+----------------------+-------------------------+-------------------------+\n", 129 | "| X-ray(1) | 0.02 | 0.98 |\n", 130 | "+----------------------+-------------------------+-------------------------+\n", 131 | "+---------------+-----------+-----------+\n", 132 | "| Smoker | Smoker(0) | Smoker(1) |\n", 133 | "+---------------+-----------+-----------+\n", 134 | "| Bronchitis(0) | 0.7 | 0.3 |\n", 135 | "+---------------+-----------+-----------+\n", 136 | "| Bronchitis(1) | 0.4 | 0.6 |\n", 137 | "+---------------+-----------+-----------+\n", 138 | "+----------------------+-------------------------+-------------------------+-------------------------+-------------------------+\n", 139 | "| TuberculosisOrCancer | TuberculosisOrCancer(0) | TuberculosisOrCancer(0) | TuberculosisOrCancer(1) | TuberculosisOrCancer(1) |\n", 140 | "+----------------------+-------------------------+-------------------------+-------------------------+-------------------------+\n", 141 | "| Bronchitis | Bronchitis(0) | Bronchitis(1) | Bronchitis(0) | Bronchitis(1) |\n", 142 | "+----------------------+-------------------------+-------------------------+-------------------------+-------------------------+\n", 143 | "| Dyspnea(0) | 0.9 | 0.1 | 0.3 | 0.7 |\n", 144 | "+----------------------+-------------------------+-------------------------+-------------------------+-------------------------+\n", 145 | "| Dyspnea(1) | 0.2 | 0.8 | 0.1 | 0.9 |\n", 146 | 
"+----------------------+-------------------------+-------------------------+-------------------------+-------------------------+\n", 147 | "+----------------+------+\n", 148 | "| VisitToAsia(0) | 0.99 |\n", 149 | "+----------------+------+\n", 150 | "| VisitToAsia(1) | 0.01 |\n", 151 | "+----------------+------+\n", 152 | "+-----------+-----+\n", 153 | "| Smoker(0) | 0.5 |\n", 154 | "+-----------+-----+\n", 155 | "| Smoker(1) | 0.5 |\n", 156 | "+-----------+-----+\n", 157 | "+---------------+-----------+-----------+\n", 158 | "| Smoker | Smoker(0) | Smoker(1) |\n", 159 | "+---------------+-----------+-----------+\n", 160 | "| LungCancer(0) | 0.99 | 0.01 |\n", 161 | "+---------------+-----------+-----------+\n", 162 | "| LungCancer(1) | 0.9 | 0.1 |\n", 163 | "+---------------+-----------+-----------+\n", 164 | "+-----------------+----------------+----------------+\n", 165 | "| VisitToAsia | VisitToAsia(0) | VisitToAsia(1) |\n", 166 | "+-----------------+----------------+----------------+\n", 167 | "| Tuberculosis(0) | 0.99 | 0.01 |\n", 168 | "+-----------------+----------------+----------------+\n", 169 | "| Tuberculosis(1) | 0.95 | 0.05 |\n", 170 | "+-----------------+----------------+----------------+\n", 171 | "+-------------------------+-----------------+-----------------+-----------------+-----------------+\n", 172 | "| LungCancer | LungCancer(0) | LungCancer(0) | LungCancer(1) | LungCancer(1) |\n", 173 | "+-------------------------+-----------------+-----------------+-----------------+-----------------+\n", 174 | "| Tuberculosis | Tuberculosis(0) | Tuberculosis(1) | Tuberculosis(0) | Tuberculosis(1) |\n", 175 | "+-------------------------+-----------------+-----------------+-----------------+-----------------+\n", 176 | "| TuberculosisOrCancer(0) | 1.0 | 0.0 | 0.0 | 1.0 |\n", 177 | "+-------------------------+-----------------+-----------------+-----------------+-----------------+\n", 178 | "| TuberculosisOrCancer(1) | 0.0 | 1.0 | 0.0 | 1.0 |\n", 179 | 
"+-------------------------+-----------------+-----------------+-----------------+-----------------+\n" 180 | ] 181 | } 182 | ], 183 | "source": [ 184 | "cpds = model.get_cpds()\n", 185 | "for cpd in cpds:\n", 186 | " print(cpd)" 187 | ] 188 | }, 189 | { 190 | "cell_type": "markdown", 191 | "metadata": {}, 192 | "source": [ 193 | "pgmpy not only allows us to read from the specific file format but also helps us to write the given model into the specific file format.\n", 194 | "Let's write a sample model into Probmodel XML file.\n", 195 | "\n", 196 | "For that first define our data for the model." 197 | ] 198 | }, 199 | { 200 | "cell_type": "code", 201 | "execution_count": 61, 202 | "metadata": {}, 203 | "outputs": [], 204 | "source": [ 205 | "import numpy as np\n", 206 | "\n", 207 | "edges_list = [('VisitToAsia', 'Tuberculosis'),\n", 208 | " ('LungCancer', 'TuberculosisOrCancer'),\n", 209 | " ('Smoker', 'LungCancer'),\n", 210 | " ('Smoker', 'Bronchitis'),\n", 211 | " ('Tuberculosis', 'TuberculosisOrCancer'),\n", 212 | " ('Bronchitis', 'Dyspnea'),\n", 213 | " ('TuberculosisOrCancer', 'Dyspnea'),\n", 214 | " ('TuberculosisOrCancer', 'X-ray')]\n", 215 | "nodes = {'Smoker': {'States': {'no': {}, 'yes': {}},\n", 216 | " 'role': 'chance',\n", 217 | " 'type': 'finiteStates',\n", 218 | " 'Coordinates': {'y': '52', 'x': '568'},\n", 219 | " 'AdditionalProperties': {'Title': 'S', 'Relevance': '7.0'}},\n", 220 | " 'Bronchitis': {'States': {'no': {}, 'yes': {}},\n", 221 | " 'role': 'chance',\n", 222 | " 'type': 'finiteStates',\n", 223 | " 'Coordinates': {'y': '181', 'x': '698'},\n", 224 | " 'AdditionalProperties': {'Title': 'B', 'Relevance': '7.0'}},\n", 225 | " 'VisitToAsia': {'States': {'no': {}, 'yes': {}},\n", 226 | " 'role': 'chance',\n", 227 | " 'type': 'finiteStates',\n", 228 | " 'Coordinates': {'y': '58', 'x': '290'},\n", 229 | " 'AdditionalProperties': {'Title': 'A', 'Relevance': '7.0'}},\n", 230 | " 'Tuberculosis': {'States': {'no': {}, 'yes': {}},\n", 231 | " 'role': 
'chance',\n", 232 | " 'type': 'finiteStates',\n", 233 | " 'Coordinates': {'y': '150', 'x': '201'},\n", 234 | " 'AdditionalProperties': {'Title': 'T', 'Relevance': '7.0'}},\n", 235 | " 'X-ray': {'States': {'no': {}, 'yes': {}},\n", 236 | " 'role': 'chance',\n", 237 | " 'AdditionalProperties': {'Title': 'X', 'Relevance': '7.0'},\n", 238 | " 'Coordinates': {'y': '322', 'x': '252'},\n", 239 | " 'Comment': 'Indica si el test de rayos X ha sido positivo',\n", 240 | " 'type': 'finiteStates'},\n", 241 | " 'Dyspnea': {'States': {'no': {}, 'yes': {}},\n", 242 | " 'role': 'chance',\n", 243 | " 'type': 'finiteStates',\n", 244 | " 'Coordinates': {'y': '321', 'x': '533'},\n", 245 | " 'AdditionalProperties': {'Title': 'D', 'Relevance': '7.0'}},\n", 246 | " 'TuberculosisOrCancer': {'States': {'no': {}, 'yes': {}},\n", 247 | " 'role': 'chance',\n", 248 | " 'type': 'finiteStates',\n", 249 | " 'Coordinates': {'y': '238', 'x': '336'},\n", 250 | " 'AdditionalProperties': {'Title': 'E', 'Relevance': '7.0'}},\n", 251 | " 'LungCancer': {'States': {'no': {}, 'yes': {}},\n", 252 | " 'role': 'chance',\n", 253 | " 'type': 'finiteStates',\n", 254 | " 'Coordinates': {'y': '152', 'x': '421'},\n", 255 | " 'AdditionalProperties': {'Title': 'L', 'Relevance': '7.0'}}}\n", 256 | "edges = {'LungCancer': {'TuberculosisOrCancer': {'directed': 'true'}},\n", 257 | " 'Smoker': {'LungCancer': {'directed': 'true'},\n", 258 | " 'Bronchitis': {'directed': 'true'}},\n", 259 | " 'Dyspnea': {},\n", 260 | " 'X-ray': {},\n", 261 | " 'VisitToAsia': {'Tuberculosis': {'directed': 'true'}},\n", 262 | " 'TuberculosisOrCancer': {'X-ray': {'directed': 'true'},\n", 263 | " 'Dyspnea': {'directed': 'true'}},\n", 264 | " 'Bronchitis': {'Dyspnea': {'directed': 'true'}},\n", 265 | " 'Tuberculosis': {'TuberculosisOrCancer': {'directed': 'true'}}}\n", 266 | "# TabularCPD layout: one row per state of the variable, one column per parent configuration (last parent varies fastest); every column sums to 1.\n", 267 | "cpds = [{'Values': np.array([[0.95, 0.02], [0.05, 0.98]]),\n", 268 | " 'Variables': {'X-ray': ['TuberculosisOrCancer']}},\n", 269 | " {'Values': np.array([[0.7, 
0.4], [0.3, 0.6]]),\n", 270 | " 'Variables': {'Bronchitis': ['Smoker']}},\n", 271 | " {'Values': np.array([[0.9, 0.2, 0.3, 0.1], [0.1, 0.8, 0.7, 0.9]]),\n", 272 | " 'Variables': {'Dyspnea': ['TuberculosisOrCancer', 'Bronchitis']}},\n", 273 | " {'Values': np.array([[0.99], [0.01]]),\n", 274 | " 'Variables': {'VisitToAsia': []}},\n", 275 | " {'Values': np.array([[0.5], [0.5]]),\n", 276 | " 'Variables': {'Smoker': []}},\n", 277 | " {'Values': np.array([[0.99, 0.9], [0.01, 0.1]]),\n", 278 | " 'Variables': {'LungCancer': ['Smoker']}},\n", 279 | " {'Values': np.array([[0.99, 0.95], [0.01, 0.05]]),\n", 280 | " 'Variables': {'Tuberculosis': ['VisitToAsia']}},\n", 281 | " {'Values': np.array([[1, 0, 0, 0], [0, 1, 1, 1]]),\n", 282 | " 'Variables': {'TuberculosisOrCancer': ['LungCancer', 'Tuberculosis']}}]" 283 | ] 284 | }, 285 | { 286 | "cell_type": "markdown", 287 | "metadata": {}, 288 | "source": [ 289 | "Now let's create a `BayesianModel` for this data." 290 | ] 291 | }, 292 | { 293 | "cell_type": "code", 294 | "execution_count": 62, 295 | "metadata": {}, 296 | "outputs": [], 297 | "source": [ 298 | "from pgmpy.models import BayesianModel\n", 299 | "from pgmpy.factors.discrete import TabularCPD\n", 300 | "\n", 301 | "model = BayesianModel(edges_list)\n", 302 | "\n", 303 | "for node in nodes:\n", 304 | " for key, value in nodes[node].items():\n", 305 | " model.nodes[node][key] = value\n", 306 | "\n", 307 | "for u in edges.keys():\n", 308 | " for v in edges[u].keys():\n", 309 | " # Copy this edge's attributes (e.g. 'directed') onto the model's edge.\n", 310 | " for key, value in edges[u][v].items():\n", 311 | " model.edges[(u, v)][key] = value\n", 312 | "\n", 313 | "tabular_cpds = []\n", 314 | "for cpd in cpds:\n", 315 | " var = list(cpd['Variables'].keys())[0]\n", 316 | " evidence = cpd['Variables'][var]\n", 317 | " values = cpd['Values']\n", 318 | " states = len(nodes[var]['States'])\n", 319 | " evidence_card = [len(nodes[evidence_var]['States'])\n", 320 | " for evidence_var in evidence]\n", 321 | " 
tabular_cpds.append(\n", 322 | " TabularCPD(var, states, values, evidence, evidence_card))\n", 323 | "\n", 324 | "model.add_cpds(*tabular_cpds)" 325 | ] 326 | }, 327 | { 328 | "cell_type": "code", 329 | "execution_count": 63, 330 | "metadata": {}, 331 | "outputs": [], 332 | "source": [ 333 | "from pgmpy.readwrite import ProbModelXMLWriter, get_probmodel_data" 334 | ] 335 | }, 336 | { 337 | "cell_type": "markdown", 338 | "metadata": {}, 339 | "source": [ 340 | "To get the data which we need to give to the ProbModelXMLWriter to get the corresponding fileformat we need to use the method get_probmodel_data. This method is only specific to ProbModelXML file, for other file formats we would directly pass the model to the given Writer Class." 341 | ] 342 | }, 343 | { 344 | "cell_type": "code", 345 | "execution_count": 64, 346 | "metadata": {}, 347 | "outputs": [ 348 | { 349 | "name": "stdout", 350 | "output_type": "stream", 351 | "text": [ 352 | "\n", 353 | " \n", 354 | " \n", 355 | " \n", 356 | " \n", 357 | " \n", 358 | " \n", 359 | " \n", 360 | " \n", 361 | " \n", 362 | " \n", 363 | " \n", 364 | " \n", 365 | " \n", 366 | " \n", 367 | " \n", 368 | " \n", 369 | " \n", 370 | " \n", 371 | " \n", 372 | " \n", 373 | " \n", 374 | " \n", 375 | " \n", 376 | " \n", 377 | " \n", 378 | " \n", 379 | " \n", 380 | " \n", 381 | " \n", 382 | " \n", 383 | " \n", 384 | " \n", 385 | " \n", 386 | " \n", 387 | " \n", 388 | " \n", 389 | " \n", 390 | " \n", 391 | " \n", 392 | " \n", 393 | " \n", 394 | " \n", 395 | " \n", 396 | " \n", 397 | " \n", 398 | " \n", 399 | " \n", 400 | " \n", 401 | " \n", 402 | " \n", 403 | " \n", 404 | " \n", 405 | " \n", 406 | " \n", 407 | " \n", 408 | " \n", 409 | " \n", 410 | " \n", 411 | " \n", 412 | " \n", 413 | " \n", 414 | " \n", 415 | " \n", 416 | " \n", 417 | " \n", 418 | " \n", 419 | " \n", 420 | " \n", 421 | " \n", 422 | " \n", 423 | " \n", 424 | " \n", 425 | " \n", 426 | " \n", 427 | " \n", 428 | " \n", 429 | " \n", 430 | " \n", 431 | " \n", 432 | " \n", 
433 | " \n", 434 | " \n", 435 | " \n", 436 | " \n", 437 | " \n", 438 | " \n", 439 | " \n", 440 | " \n", 441 | " \n", 442 | " \n", 443 | " \n", 444 | " \n", 445 | " \n", 446 | " \n", 447 | " Indica si el test de rayos X ha sido positivo\n", 448 | " \n", 449 | " \n", 450 | " \n", 451 | " \n", 452 | " \n", 453 | " \n", 454 | " \n", 455 | " \n", 456 | " \n", 457 | " \n", 458 | " \n", 459 | " \n", 460 | " \n", 461 | " \n", 462 | " \n", 463 | " \n", 464 | " \n", 465 | " \n", 466 | " \n", 467 | " \n", 468 | " \n", 469 | " \n", 470 | " \n", 471 | " \n", 472 | " \n", 473 | " \n", 474 | " \n", 475 | " \n", 476 | " \n", 477 | " \n", 478 | " \n", 479 | " \n", 480 | " \n", 481 | " \n", 482 | " \n", 483 | " \n", 484 | " \n", 485 | " \n", 486 | " \n", 487 | " \n", 488 | " \n", 489 | " \n", 490 | " \n", 491 | " \n", 492 | " \n", 493 | " 0.95 0.05 0.02 0.98 \n", 494 | " \n", 495 | " \n", 496 | " \n", 497 | " \n", 498 | " \n", 499 | " \n", 500 | " 0.7 0.3 0.4 0.6 \n", 501 | " \n", 502 | " \n", 503 | " \n", 504 | " \n", 505 | " \n", 506 | " \n", 507 | " \n", 508 | " 0.9 0.1 0.3 0.7 0.2 0.8 0.1 0.9 \n", 509 | " \n", 510 | " \n", 511 | " \n", 512 | " \n", 513 | " \n", 514 | " 0.99 0.01 \n", 515 | " \n", 516 | " \n", 517 | " \n", 518 | " \n", 519 | " \n", 520 | " 0.5 0.5 \n", 521 | " \n", 522 | " \n", 523 | " \n", 524 | " \n", 525 | " \n", 526 | " \n", 527 | " 0.99 0.01 0.9 0.1 \n", 528 | " \n", 529 | " \n", 530 | " \n", 531 | " \n", 532 | " \n", 533 | " \n", 534 | " 0.99 0.01 0.95 0.05 \n", 535 | " \n", 536 | " \n", 537 | " \n", 538 | " \n", 539 | " \n", 540 | " \n", 541 | " \n", 542 | " 1.0 0.0 0.0 1.0 0.0 1.0 0.0 1.0 \n", 543 | " \n", 544 | " \n", 545 | " \n", 546 | " \n", 547 | " \n", 548 | " \n", 549 | "\n", 550 | "\n" 551 | ] 552 | } 553 | ], 554 | "source": [ 555 | "model_data = get_probmodel_data(model)\n", 556 | "writer = ProbModelXMLWriter(model_data=model_data)\n", 557 | "print(writer.__str__().decode('utf-8'))" 558 | ] 559 | }, 560 | { 561 | "cell_type": "markdown", 562 | 
"metadata": {}, 563 | "source": [ 564 | "To write the XML data to a file, we can use the `write_file` method of the given Writer class." 565 | ] 566 | }, 567 | { 568 | "cell_type": "code", 569 | "execution_count": 65, 570 | "metadata": {}, 571 | "outputs": [], 572 | "source": [ 573 | "writer.write_file('probmodelxml.pgmx')" 574 | ] 575 | }, 576 | { 577 | "cell_type": "markdown", 578 | "metadata": {}, 579 | "source": [ 580 | "## General workflow of the readwrite module" 581 | ] 582 | }, 583 | { 584 | "cell_type": "markdown", 585 | "metadata": {}, 586 | "source": [ 587 | "`pgmpy.readwrite.[fileformat]Reader` is the class for reading a file in the given format: replace `[fileformat]` with the name of the format you want to read from. The class defines several methods to parse the given file. For example, for the XMLBeliefNetwork format (`XBNReader`), the methods defined are as follows:" 588 | ] 589 | }, 590 | { 591 | "cell_type": "code", 592 | "execution_count": 66, 593 | "metadata": {}, 594 | "outputs": [], 595 | "source": [ 596 | "from pgmpy.readwrite.XMLBeliefNetwork import XBNReader\n", 597 | "reader = XBNReader('../files/xmlbelief.xml')" 598 | ] 599 | }, 600 | { 601 | "cell_type": "markdown", 602 | "metadata": {}, 603 | "source": [ 604 | "`get_model`: returns an instance of the corresponding model, e.g. a `BayesianModel` in the case of the XMLBelief format."
605 | ] 606 | } 607 | { 608 | "cell_type": "code", 609 | "execution_count": 67, 610 | "metadata": {}, 611 | "outputs": [ 612 | { 613 | "name": "stdout", 614 | "output_type": "stream", 615 | "text": [ 616 | "['a', 'b', 'c', 'd', 'e']\n", 617 | "[('a', 'b'), ('a', 'c'), ('b', 'd'), ('c', 'd'), ('c', 'e')]\n" 618 | ] 619 | } 620 | ], 621 | "source": [ 622 | "model = reader.get_model()\n", 623 | "print(model.nodes())\n", 624 | "print(model.edges())" 625 | ] 626 | }, 627 | { 628 | "cell_type": "markdown", 629 | "metadata": {}, 630 | "source": [ 631 | "`pgmpy.readwrite.[fileformat]Writer` is the class for writing a model into the given file format: replace `[fileformat]` with the name of the format you want to write to. It takes a model as an argument, which can be an instance of `BayesianModel` or `MarkovModel`. The class defines several methods to set the contents of the new file created from the given model. For example, for the XMLBeliefNetwork format (`XBNWriter`), methods such as `set_analysisnotebook` are defined, which help to set up the network data."
632 | ] 633 | }, 634 | { 635 | "cell_type": "code", 636 | "execution_count": 53, 637 | "metadata": {}, 638 | "outputs": [], 639 | "source": [ 640 | "from pgmpy.models import BayesianModel\n", 641 | "from pgmpy.factors.discrete import TabularCPD\n", 642 | "import numpy as np\n", 643 | "nodes = {'c': {'STATES': ['Present', 'Absent'],\n", 644 | " 'DESCRIPTION': '(c) Brain Tumor',\n", 645 | " 'YPOS': '11935',\n", 646 | " 'XPOS': '15250',\n", 647 | " 'TYPE': 'discrete'},\n", 648 | " 'a': {'STATES': ['Present', 'Absent'],\n", 649 | " 'DESCRIPTION': '(a) Metastatic Cancer',\n", 650 | " 'YPOS': '10465',\n", 651 | " 'XPOS': '13495',\n", 652 | " 'TYPE': 'discrete'},\n", 653 | " 'b': {'STATES': ['Present', 'Absent'],\n", 654 | " 'DESCRIPTION': '(b) Serum Calcium Increase',\n", 655 | " 'YPOS': '11965',\n", 656 | " 'XPOS': '11290',\n", 657 | " 'TYPE': 'discrete'},\n", 658 | " 'e': {'STATES': ['Present', 'Absent'],\n", 659 | " 'DESCRIPTION': '(e) Papilledema',\n", 660 | " 'YPOS': '13240',\n", 661 | " 'XPOS': '17305',\n", 662 | " 'TYPE': 'discrete'},\n", 663 | " 'd': {'STATES': ['Present', 'Absent'],\n", 664 | " 'DESCRIPTION': '(d) Coma',\n", 665 | " 'YPOS': '12985',\n", 666 | " 'XPOS': '13960',\n", 667 | " 'TYPE': 'discrete'}}\n", 668 | "model = BayesianModel([('b', 'd'), ('a', 'b'), ('a', 'c'), ('c', 'd'), ('c', 'e')])\n", 669 | "cpd_distribution = {'a': {'TYPE': 'discrete', 'DPIS': np.array([[0.2, 0.8]])},\n", 670 | " 'e': {'TYPE': 'discrete', 'DPIS': np.array([[0.8, 0.2],\n", 671 | " [0.6, 0.4]]), 'CONDSET': ['c'], 'CARDINALITY': [2]},\n", 672 | " 'b': {'TYPE': 'discrete', 'DPIS': np.array([[0.8, 0.2],\n", 673 | " [0.2, 0.8]]), 'CONDSET': ['a'], 'CARDINALITY': [2]},\n", 674 | " 'c': {'TYPE': 'discrete', 'DPIS': np.array([[0.2, 0.8],\n", 675 | " [0.05, 0.95]]), 'CONDSET': ['a'], 'CARDINALITY': [2]},\n", 676 | " 'd': {'TYPE': 'discrete', 'DPIS': np.array([[0.8, 0.2],\n", 677 | " [0.9, 0.1],\n", 678 | " [0.7, 0.3],\n", 679 | " [0.05, 0.95]]), 'CONDSET': ['b', 'c'], 
'CARDINALITY': [2, 2]}}\n", 680 | "\n", 681 | "tabular_cpds = []\n", 682 | "for var, values in cpd_distribution.items():\n", 683 | " evidence = values['CONDSET'] if 'CONDSET' in values else []\n", 684 | " cpd = values['DPIS']\n", 685 | " evidence_card = values['CARDINALITY'] if 'CARDINALITY' in values else []\n", 686 | " states = nodes[var]['STATES']\n", 687 | " cpd = TabularCPD(var, len(states), cpd,\n", 688 | " evidence=evidence,\n", 689 | " evidence_card=evidence_card)\n", 690 | " tabular_cpds.append(cpd)\n", 691 | "model.add_cpds(*tabular_cpds)\n", 692 | "\n", 693 | "for var, properties in nodes.items():\n", 694 | " for key, value in properties.items():\n", 695 | " model.nodes[var][key] = value\n" 696 | ] 697 | }, 698 | { 699 | "cell_type": "code", 700 | "execution_count": 54, 701 | "metadata": {}, 702 | "outputs": [], 703 | "source": [ 704 | "from pgmpy.readwrite.XMLBeliefNetwork import XBNWriter\n", 705 | "writer = XBNWriter(model = model)" 706 | ] 707 | } 708 | ], 709 | "metadata": { 710 | "kernelspec": { 711 | "display_name": "Python 3", 712 | "language": "python", 713 | "name": "python3" 714 | }, 715 | "language_info": { 716 | "codemirror_mode": { 717 | "name": "ipython", 718 | "version": 3 719 | }, 720 | "file_extension": ".py", 721 | "mimetype": "text/x-python", 722 | "name": "python", 723 | "nbconvert_exporter": "python", 724 | "pygments_lexer": "ipython3", 725 | "version": "3.8.10" 726 | } 727 | }, 728 | "nbformat": 4, 729 | "nbformat_minor": 1 730 | } 731 | -------------------------------------------------------------------------------- /pdfs/Probabilistic Graphical Models using pgmpy.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pgmpy/pgmpy_notebook/6bcd9fc2d84d7181055ddfd8c9459ebecd0b67b0/pdfs/Probabilistic Graphical Models using pgmpy.pdf -------------------------------------------------------------------------------- /requirements.txt: 
-------------------------------------------------------------------------------- 1 | networkx 2 | numpy 3 | scipy 4 | pandas 5 | pyparsing 6 | torch 7 | statsmodels 8 | tqdm 9 | joblib 10 | pgmpy 11 | scikit-learn 12 | -------------------------------------------------------------------------------- /scripts/1/discretize.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | from sklearn.datasets import load_iris 4 | 5 | iris = load_iris() 6 | mini_iris = np.round(iris.data[:, :2]).astype(int) 7 | data = pd.DataFrame(mini_iris, columns=['length', 'width']) 8 | data['type'] = iris.target 9 | 10 | #Shuffle data 11 | data = data.iloc[np.random.permutation(len(data))] 12 | --------------------------------------------------------------------------------