├── 1_3_pseudo_correlation.ipynb ├── 4_1_reggresion_adjustment.ipynb ├── 4_2_iptw.ipynb ├── 4_3_dr.ipynb ├── 5_1_randomforest.ipynb ├── 5_2_meta_learners.ipynb ├── 5_2_meta_learners_issue18.ipynb ├── 5_2_meta_learners_issue18_issue36.ipynb ├── 5_3_doubly_robust_learning.ipynb ├── 5_3_doubly_robust_learning_issue18.ipynb ├── 6_3_lingam.ipynb ├── 7_2_bayesian_network_bic.ipynb ├── 7_3_bayesian_network_independence_test.ipynb ├── 7_5_bayesian_network_pc_algorithm.ipynb ├── 7_5_bayesian_network_pc_algorithm_20220421.ipynb ├── 7_5_bayesian_network_pc_algorithm_210410.ipynb ├── 8_3_5_deeplearning_gan_sam.ipynb ├── LICENSE ├── README.md └── etc └── book.jpg /4_1_reggresion_adjustment.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "4_1_reggresion_adjustment.ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [] 9 | }, 10 | "kernelspec": { 11 | "name": "python3", 12 | "display_name": "Python 3" 13 | } 14 | }, 15 | "cells": [ 16 | { 17 | "cell_type": "markdown", 18 | "metadata": { 19 | "id": "d-IAJLC2k1NX", 20 | "colab_type": "text" 21 | }, 22 | "source": [ 23 | "# 4.1 回帰分析による因果推論\n", 24 | "\n", 25 | "本ファイルは、4.1節の実装です。\n", 26 | "\n", 27 | "テレビCMの広告効果の推定を例に、回帰分析による因果推論を実装します。\n" 28 | ] 29 | }, 30 | { 31 | "cell_type": "markdown", 32 | "metadata": { 33 | "id": "2XdIDbdlejUk", 34 | "colab_type": "text" 35 | }, 36 | "source": [ 37 | "## プログラム実行前の設定など" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "metadata": { 43 | "id": "wqHjwstVeXYt", 44 | "colab_type": "code", 45 | "colab": {} 46 | }, 47 | "source": [ 48 | "# 乱数のシードを設定\n", 49 | "import random\n", 50 | "import numpy as np\n", 51 | "\n", 52 | "np.random.seed(1234)\n", 53 | "random.seed(1234)\n" 54 | ], 55 | "execution_count": 0, 56 | "outputs": [] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "metadata": { 61 | "id": "RIhcLRqlem3V", 62 | "colab_type": "code", 63 | "colab": {} 64 | }, 
65 | "source": [ 66 | "# 使用するパッケージ(ライブラリと関数)を定義\n", 67 | "# 標準正規分布の生成用\n", 68 | "from numpy.random import *\n", 69 | "\n", 70 | "# グラフの描画用\n", 71 | "import matplotlib.pyplot as plt\n", 72 | "\n", 73 | "# SciPy 平均0、分散1に正規化(標準化)関数\n", 74 | "import scipy.stats\n", 75 | "\n", 76 | "# シグモイド関数をimport\n", 77 | "from scipy.special import expit\n", 78 | "\n", 79 | "# その他\n", 80 | "import pandas as pd\n" 81 | ], 82 | "execution_count": 0, 83 | "outputs": [] 84 | }, 85 | { 86 | "cell_type": "markdown", 87 | "metadata": { 88 | "id": "AWqP6yeQlI_t", 89 | "colab_type": "text" 90 | }, 91 | "source": [ 92 | "## データの作成" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "metadata": { 98 | "id": "EJaQIHz4fNXb", 99 | "colab_type": "code", 100 | "colab": {} 101 | }, 102 | "source": [ 103 | "# データ数\n", 104 | "num_data = 200\n", 105 | "\n", 106 | "# 年齢\n", 107 | "x_1 = randint(15, 76, num_data) # 15から75歳の一様乱数\n", 108 | "\n", 109 | "# 性別(0を女性、1を男性とします)\n", 110 | "x_2 = randint(0, 2, num_data) # 0か1の乱数\n" 111 | ], 112 | "execution_count": 0, 113 | "outputs": [] 114 | }, 115 | { 116 | "cell_type": "markdown", 117 | "metadata": { 118 | "id": "JiT_gc5ZmAQa", 119 | "colab_type": "text" 120 | }, 121 | "source": [ 122 | "## テレビCMを見たかどうか" 123 | ] 124 | }, 125 | { 126 | "cell_type": "code", 127 | "metadata": { 128 | "id": "hfPbhhm6gvW6", 129 | "colab_type": "code", 130 | "colab": {} 131 | }, 132 | "source": [ 133 | "# ノイズの生成\n", 134 | "e_z = randn(num_data)\n", 135 | "\n", 136 | "# シグモイド関数に入れる部分\n", 137 | "z_base = x_1 + (1-x_2)*10 - 40 + 5*e_z\n", 138 | "\n", 139 | "# シグモイド関数を計算\n", 140 | "z_prob = expit(0.1*z_base)\n", 141 | "\n", 142 | "# テレビCMを見たかどうかの変数(0は見ていない、1は見た)\n", 143 | "Z = np.array([])\n", 144 | "\n", 145 | "for i in range(num_data):\n", 146 | " Z_i = np.random.choice(2, size=1, p=[1-z_prob[i], z_prob[i]])[0]\n", 147 | " Z = np.append(Z, Z_i)\n" 148 | ], 149 | "execution_count": 0, 150 | "outputs": [] 151 | }, 152 | { 153 | "cell_type": "markdown", 154 | "metadata": { 155 | "id": 
"b2PLquJGi2Te", 156 | "colab_type": "text" 157 | }, 158 | "source": [ 159 | "## 購入量Yを作成" 160 | ] 161 | }, 162 | { 163 | "cell_type": "code", 164 | "metadata": { 165 | "id": "nv-ELtFqi5L5", 166 | "colab_type": "code", 167 | "colab": {} 168 | }, 169 | "source": [ 170 | "# ノイズの生成\n", 171 | "e_y = randn(num_data)\n", 172 | "\n", 173 | "Y = -x_1 + 30*x_2 + 10*Z + 80 + 10*e_y\n" 174 | ], 175 | "execution_count": 0, 176 | "outputs": [] 177 | }, 178 | { 179 | "cell_type": "markdown", 180 | "metadata": { 181 | "id": "BHcdUlW9koTa", 182 | "colab_type": "text" 183 | }, 184 | "source": [ 185 | "## データをまとめた表を作成し、平均値を比べる" 186 | ] 187 | }, 188 | { 189 | "cell_type": "code", 190 | "metadata": { 191 | "id": "HPqwrISXktRj", 192 | "colab_type": "code", 193 | "colab": { 194 | "base_uri": "https://localhost:8080/", 195 | "height": 195 196 | }, 197 | "outputId": "2afcd49a-b744-4b4a-db32-838377fa0305" 198 | }, 199 | "source": [ 200 | "df = pd.DataFrame({'年齢': x_1,\n", 201 | " '性別': x_2,\n", 202 | " 'CMを見た': Z,\n", 203 | " '購入量': Y,\n", 204 | " })\n", 205 | "\n", 206 | "df.head() # 先頭を表示\n" 207 | ], 208 | "execution_count": 6, 209 | "outputs": [ 210 | { 211 | "output_type": "execute_result", 212 | "data": { 213 | "text/html": [ 214 | "
\n", 215 | "\n", 228 | "\n", 229 | " \n", 230 | " \n", 231 | " \n", 232 | " \n", 233 | " \n", 234 | " \n", 235 | " \n", 236 | " \n", 237 | " \n", 238 | " \n", 239 | " \n", 240 | " \n", 241 | " \n", 242 | " \n", 243 | " \n", 244 | " \n", 245 | " \n", 246 | " \n", 247 | " \n", 248 | " \n", 249 | " \n", 250 | " \n", 251 | " \n", 252 | " \n", 253 | " \n", 254 | " \n", 255 | " \n", 256 | " \n", 257 | " \n", 258 | " \n", 259 | " \n", 260 | " \n", 261 | " \n", 262 | " \n", 263 | " \n", 264 | " \n", 265 | " \n", 266 | " \n", 267 | " \n", 268 | " \n", 269 | " \n", 270 | " \n", 271 | " \n", 272 | " \n", 273 | " \n", 274 | " \n", 275 | "
年齢 性別 CMを見た 購入量
0 62 0 1.0 24.464285
1 34 0 0.0 45.693411
2 53 1 1.0 64.998281
3 68 1 1.0 47.186898
4 27 1 0.0 100.114260
\n", 276 | "
" 277 | ], 278 | "text/plain": [ 279 | " 年齢 性別 CMを見た 購入量\n", 280 | "0 62 0 1.0 24.464285\n", 281 | "1 34 0 0.0 45.693411\n", 282 | "2 53 1 1.0 64.998281\n", 283 | "3 68 1 1.0 47.186898\n", 284 | "4 27 1 0.0 100.114260" 285 | ] 286 | }, 287 | "metadata": { 288 | "tags": [] 289 | }, 290 | "execution_count": 6 291 | } 292 | ] 293 | }, 294 | { 295 | "cell_type": "code", 296 | "metadata": { 297 | "id": "HHInQ1Sukrg0", 298 | "colab_type": "code", 299 | "colab": { 300 | "base_uri": "https://localhost:8080/", 301 | "height": 210 302 | }, 303 | "outputId": "40cbd6c0-df68-4eec-c8c7-8a518aa4d52c" 304 | }, 305 | "source": [ 306 | "# 平均値を比べる\n", 307 | "\n", 308 | "print(df[df[\"CMを見た\"] == 1.0].mean())\n", 309 | "print(\"--------\")\n", 310 | "print(df[df[\"CMを見た\"] == 0.0].mean())\n" 311 | ], 312 | "execution_count": 7, 313 | "outputs": [ 314 | { 315 | "output_type": "stream", 316 | "text": [ 317 | "年齢 55.836066\n", 318 | "性別 0.483607\n", 319 | "CMを見た 1.000000\n", 320 | "購入量 49.711478\n", 321 | "dtype: float64\n", 322 | "--------\n", 323 | "年齢 32.141026\n", 324 | "性別 0.692308\n", 325 | "CMを見た 0.000000\n", 326 | "購入量 68.827143\n", 327 | "dtype: float64\n" 328 | ], 329 | "name": "stdout" 330 | } 331 | ] 332 | }, 333 | { 334 | "cell_type": "markdown", 335 | "metadata": { 336 | "id": "kwKOk59aogBd", 337 | "colab_type": "text" 338 | }, 339 | "source": [ 340 | "## 回帰分析を実施" 341 | ] 342 | }, 343 | { 344 | "cell_type": "code", 345 | "metadata": { 346 | "id": "rBtHC3smoiMC", 347 | "colab_type": "code", 348 | "colab": { 349 | "base_uri": "https://localhost:8080/", 350 | "height": 34 351 | }, 352 | "outputId": "e5a4dee9-e80a-4c83-b32b-8461d494eb00" 353 | }, 354 | "source": [ 355 | "# scikit-learnから線形回帰をimport\n", 356 | "# https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LinearRegression.html\n", 357 | "from sklearn.linear_model import LinearRegression\n", 358 | "\n", 359 | "# 説明変数\n", 360 | "X = df[[\"年齢\", \"性別\", \"CMを見た\"]]\n", 361 | "\n", 362 | "# 被説明変数(目的変数)\n", 
363 | "y = df[\"購入量\"]\n", 364 | "\n", 365 | "# 回帰の実施\n", 366 | "reg = LinearRegression().fit(X, y)\n", 367 | "\n", 368 | "# 回帰した結果の係数を出力\n", 369 | "print(\"係数:\", reg.coef_)\n" 370 | ], 371 | "execution_count": 8, 372 | "outputs": [ 373 | { 374 | "output_type": "stream", 375 | "text": [ 376 | "係数: [-0.95817951 32.70149412 10.41327647]\n" 377 | ], 378 | "name": "stdout" 379 | } 380 | ] 381 | }, 382 | { 383 | "cell_type": "markdown", 384 | "metadata": { 385 | "id": "1IdVhXmMps-w", 386 | "colab_type": "text" 387 | }, 388 | "source": [ 389 | "以上" 390 | ] 391 | } 392 | ] 393 | } -------------------------------------------------------------------------------- /4_2_iptw.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "4_2_iptw.ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [] 9 | }, 10 | "kernelspec": { 11 | "name": "python3", 12 | "display_name": "Python 3" 13 | } 14 | }, 15 | "cells": [ 16 | { 17 | "cell_type": "markdown", 18 | "metadata": { 19 | "id": "d-IAJLC2k1NX", 20 | "colab_type": "text" 21 | }, 22 | "source": [ 23 | "# 4.2 逆確率重み付け法(IPTW)による因果推論\n", 24 | "\n", 25 | "本ファイルは、4.2節の実装です。\n", 26 | "\n", 27 | "4.1節と同じく、テレビCMの広告効果の推定を例に、回帰分析による因果推論を実装します。\n" 28 | ] 29 | }, 30 | { 31 | "cell_type": "markdown", 32 | "metadata": { 33 | "id": "2XdIDbdlejUk", 34 | "colab_type": "text" 35 | }, 36 | "source": [ 37 | "## プログラム実行前の設定など" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "metadata": { 43 | "id": "wqHjwstVeXYt", 44 | "colab_type": "code", 45 | "colab": {} 46 | }, 47 | "source": [ 48 | "# 乱数のシードを設定\n", 49 | "import random\n", 50 | "import numpy as np\n", 51 | "\n", 52 | "np.random.seed(1234)\n", 53 | "random.seed(1234)\n" 54 | ], 55 | "execution_count": 0, 56 | "outputs": [] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "metadata": { 61 | "id": "RIhcLRqlem3V", 62 | "colab_type": "code", 63 | "colab": {} 64 | }, 65 | 
"source": [ 66 | "# 使用するパッケージ(ライブラリと関数)を定義\n", 67 | "# 標準正規分布の生成用\n", 68 | "from numpy.random import *\n", 69 | "\n", 70 | "# グラフの描画用\n", 71 | "import matplotlib.pyplot as plt\n", 72 | "\n", 73 | "# SciPy 平均0、分散1に正規化(標準化)関数\n", 74 | "import scipy.stats\n", 75 | "\n", 76 | "# シグモイド関数をimport\n", 77 | "from scipy.special import expit\n", 78 | "\n", 79 | "# その他\n", 80 | "import pandas as pd\n" 81 | ], 82 | "execution_count": 0, 83 | "outputs": [] 84 | }, 85 | { 86 | "cell_type": "markdown", 87 | "metadata": { 88 | "id": "AWqP6yeQlI_t", 89 | "colab_type": "text" 90 | }, 91 | "source": [ 92 | "## データの作成" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "metadata": { 98 | "id": "EJaQIHz4fNXb", 99 | "colab_type": "code", 100 | "colab": {} 101 | }, 102 | "source": [ 103 | "# データ数\n", 104 | "num_data = 200\n", 105 | "\n", 106 | "# 年齢\n", 107 | "x_1 = randint(15, 76, num_data) # 15から75歳の一様乱数\n", 108 | "\n", 109 | "# 性別(0を女性、1を男性とします)\n", 110 | "x_2 = randint(0, 2, num_data) # 0か1の乱数\n" 111 | ], 112 | "execution_count": 0, 113 | "outputs": [] 114 | }, 115 | { 116 | "cell_type": "markdown", 117 | "metadata": { 118 | "id": "JiT_gc5ZmAQa", 119 | "colab_type": "text" 120 | }, 121 | "source": [ 122 | "## テレビCMを見たかどうか" 123 | ] 124 | }, 125 | { 126 | "cell_type": "code", 127 | "metadata": { 128 | "id": "hfPbhhm6gvW6", 129 | "colab_type": "code", 130 | "colab": {} 131 | }, 132 | "source": [ 133 | "# ノイズの生成\n", 134 | "e_z = randn(num_data)\n", 135 | "\n", 136 | "# シグモイド関数に入れる部分\n", 137 | "z_base = x_1 + (1-x_2)*10 - 40 + 5*e_z\n", 138 | "\n", 139 | "# シグモイド関数を計算\n", 140 | "z_prob = expit(0.1*z_base)\n", 141 | "\n", 142 | "# テレビCMを見たかどうかの変数(0は見ていない、1は見た)\n", 143 | "Z = np.array([])\n", 144 | "\n", 145 | "for i in range(num_data):\n", 146 | " Z_i = np.random.choice(2, size=1, p=[1-z_prob[i], z_prob[i]])[0]\n", 147 | " Z = np.append(Z, Z_i)\n" 148 | ], 149 | "execution_count": 0, 150 | "outputs": [] 151 | }, 152 | { 153 | "cell_type": "markdown", 154 | "metadata": { 155 | "id": 
"b2PLquJGi2Te", 156 | "colab_type": "text" 157 | }, 158 | "source": [ 159 | "## 購入量Yを作成" 160 | ] 161 | }, 162 | { 163 | "cell_type": "code", 164 | "metadata": { 165 | "id": "nv-ELtFqi5L5", 166 | "colab_type": "code", 167 | "colab": {} 168 | }, 169 | "source": [ 170 | "# ノイズの生成\n", 171 | "e_y = randn(num_data)\n", 172 | "\n", 173 | "Y = -x_1 + 30*x_2 + 10*Z + 80 + 10*e_y\n" 174 | ], 175 | "execution_count": 0, 176 | "outputs": [] 177 | }, 178 | { 179 | "cell_type": "markdown", 180 | "metadata": { 181 | "id": "BHcdUlW9koTa", 182 | "colab_type": "text" 183 | }, 184 | "source": [ 185 | "## データをまとめた表を作成し、平均値を比べる" 186 | ] 187 | }, 188 | { 189 | "cell_type": "code", 190 | "metadata": { 191 | "id": "HPqwrISXktRj", 192 | "colab_type": "code", 193 | "outputId": "9ad6013c-9715-481b-d68a-17fdd3edd281", 194 | "colab": { 195 | "base_uri": "https://localhost:8080/", 196 | "height": 195 197 | } 198 | }, 199 | "source": [ 200 | "df = pd.DataFrame({'年齢': x_1,\n", 201 | " '性別': x_2,\n", 202 | " 'CMを見た': Z,\n", 203 | " '購入量': Y,\n", 204 | " })\n", 205 | "\n", 206 | "df.head() # 先頭を表示\n" 207 | ], 208 | "execution_count": 6, 209 | "outputs": [ 210 | { 211 | "output_type": "execute_result", 212 | "data": { 213 | "text/html": [ 214 | "
\n", 215 | "\n", 228 | "\n", 229 | " \n", 230 | " \n", 231 | " \n", 232 | " \n", 233 | " \n", 234 | " \n", 235 | " \n", 236 | " \n", 237 | " \n", 238 | " \n", 239 | " \n", 240 | " \n", 241 | " \n", 242 | " \n", 243 | " \n", 244 | " \n", 245 | " \n", 246 | " \n", 247 | " \n", 248 | " \n", 249 | " \n", 250 | " \n", 251 | " \n", 252 | " \n", 253 | " \n", 254 | " \n", 255 | " \n", 256 | " \n", 257 | " \n", 258 | " \n", 259 | " \n", 260 | " \n", 261 | " \n", 262 | " \n", 263 | " \n", 264 | " \n", 265 | " \n", 266 | " \n", 267 | " \n", 268 | " \n", 269 | " \n", 270 | " \n", 271 | " \n", 272 | " \n", 273 | " \n", 274 | " \n", 275 | "
年齢 性別 CMを見た 購入量
0 62 0 1.0 24.464285
1 34 0 0.0 45.693411
2 53 1 1.0 64.998281
3 68 1 1.0 47.186898
4 27 1 0.0 100.114260
\n", 276 | "
" 277 | ], 278 | "text/plain": [ 279 | " 年齢 性別 CMを見た 購入量\n", 280 | "0 62 0 1.0 24.464285\n", 281 | "1 34 0 0.0 45.693411\n", 282 | "2 53 1 1.0 64.998281\n", 283 | "3 68 1 1.0 47.186898\n", 284 | "4 27 1 0.0 100.114260" 285 | ] 286 | }, 287 | "metadata": { 288 | "tags": [] 289 | }, 290 | "execution_count": 6 291 | } 292 | ] 293 | }, 294 | { 295 | "cell_type": "code", 296 | "metadata": { 297 | "id": "HHInQ1Sukrg0", 298 | "colab_type": "code", 299 | "outputId": "ec83304b-9f7c-4334-93a9-c13022813ae9", 300 | "colab": { 301 | "base_uri": "https://localhost:8080/", 302 | "height": 210 303 | } 304 | }, 305 | "source": [ 306 | "# 平均値を比べる\n", 307 | "\n", 308 | "print(df[df[\"CMを見た\"] == 1.0].mean())\n", 309 | "print(\"--------\")\n", 310 | "print(df[df[\"CMを見た\"] == 0.0].mean())\n" 311 | ], 312 | "execution_count": 7, 313 | "outputs": [ 314 | { 315 | "output_type": "stream", 316 | "text": [ 317 | "年齢 55.836066\n", 318 | "性別 0.483607\n", 319 | "CMを見た 1.000000\n", 320 | "購入量 49.711478\n", 321 | "dtype: float64\n", 322 | "--------\n", 323 | "年齢 32.141026\n", 324 | "性別 0.692308\n", 325 | "CMを見た 0.000000\n", 326 | "購入量 68.827143\n", 327 | "dtype: float64\n" 328 | ], 329 | "name": "stdout" 330 | } 331 | ] 332 | }, 333 | { 334 | "cell_type": "markdown", 335 | "metadata": { 336 | "id": "kwKOk59aogBd", 337 | "colab_type": "text" 338 | }, 339 | "source": [ 340 | "## ここからが4.1節と異なります。傾向スコアの推定" 341 | ] 342 | }, 343 | { 344 | "cell_type": "code", 345 | "metadata": { 346 | "id": "rBtHC3smoiMC", 347 | "colab_type": "code", 348 | "outputId": "95ab93a5-7d8e-4226-f654-e1351b2537ba", 349 | "colab": { 350 | "base_uri": "https://localhost:8080/", 351 | "height": 52 352 | } 353 | }, 354 | "source": [ 355 | "# scikit-learnからロジスティク回帰をimport\n", 356 | "# https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html\n", 357 | "from sklearn.linear_model import LogisticRegression\n", 358 | "\n", 359 | "# 説明変数\n", 360 | "X = df[[\"年齢\", \"性別\"]]\n", 361 | "\n", 362 | "# 
被説明変数(目的変数)\n", 363 | "Z = df[\"CMを見た\"]\n", 364 | "\n", 365 | "# 回帰の実施\n", 366 | "reg = LogisticRegression().fit(X,Z)\n", 367 | "\n", 368 | "# 回帰した結果の係数を出力\n", 369 | "print(\"係数beta:\", reg.coef_)\n", 370 | "print(\"係数alpha:\", reg.intercept_)" 371 | ], 372 | "execution_count": 8, 373 | "outputs": [ 374 | { 375 | "output_type": "stream", 376 | "text": [ 377 | "係数beta: [[ 0.10562765 -1.38263933]]\n", 378 | "係数alpha: [-3.37146523]\n" 379 | ], 380 | "name": "stdout" 381 | } 382 | ] 383 | }, 384 | { 385 | "cell_type": "markdown", 386 | "metadata": { 387 | "id": "nV0dm82l3QXy", 388 | "colab_type": "text" 389 | }, 390 | "source": [ 391 | "### 各人の傾向スコアを求める" 392 | ] 393 | }, 394 | { 395 | "cell_type": "code", 396 | "metadata": { 397 | "id": "gGCKiujL3P1i", 398 | "colab_type": "code", 399 | "colab": { 400 | "base_uri": "https://localhost:8080/", 401 | "height": 228 402 | }, 403 | "outputId": "c9201825-cd6f-41fa-aec2-52858d82b59e" 404 | }, 405 | "source": [ 406 | "Z_pre = reg.predict_proba(X)\n", 407 | "print(Z_pre[0:5]) # 5人ほどの結果を見てみる\n", 408 | "print(\"----\")\n", 409 | "print(Z[0:5]) # 5人ほどの正解\n" 410 | ], 411 | "execution_count": 9, 412 | "outputs": [ 413 | { 414 | "output_type": "stream", 415 | "text": [ 416 | "[[0.04002323 0.95997677]\n", 417 | " [0.44525168 0.55474832]\n", 418 | " [0.30065918 0.69934082]\n", 419 | " [0.08101946 0.91898054]\n", 420 | " [0.87013558 0.12986442]]\n", 421 | "----\n", 422 | "0 1.0\n", 423 | "1 0.0\n", 424 | "2 1.0\n", 425 | "3 1.0\n", 426 | "4 0.0\n", 427 | "Name: CMを見た, dtype: float64\n" 428 | ], 429 | "name": "stdout" 430 | } 431 | ] 432 | }, 433 | { 434 | "cell_type": "markdown", 435 | "metadata": { 436 | "colab_type": "text", 437 | "id": "wL-hlBN36DZf" 438 | }, 439 | "source": [ 440 | "### 平均処置効果ATEを求める" 441 | ] 442 | }, 443 | { 444 | "cell_type": "code", 445 | "metadata": { 446 | "id": "6Ujy7JJa6Gwi", 447 | "colab_type": "code", 448 | "colab": { 449 | "base_uri": "https://localhost:8080/", 450 | "height": 34 451 | }, 452 | "outputId": 
"2c4ab7d4-2393-4937-ed5e-5e394fe2116d" 453 | }, 454 | "source": [ 455 | "ATE_i = Y/Z_pre[:, 1]*Z - Y/Z_pre[:, 0]*(1-Z)\n", 456 | "ATE = 1/len(Y)*ATE_i.sum()\n", 457 | "print(\"推定したATE\", ATE)\n" 458 | ], 459 | "execution_count": 10, 460 | "outputs": [ 461 | { 462 | "output_type": "stream", 463 | "text": [ 464 | "推定したATE 8.847476810855458\n" 465 | ], 466 | "name": "stdout" 467 | } 468 | ] 469 | }, 470 | { 471 | "cell_type": "markdown", 472 | "metadata": { 473 | "id": "1IdVhXmMps-w", 474 | "colab_type": "text" 475 | }, 476 | "source": [ 477 | "以上" 478 | ] 479 | } 480 | ] 481 | } -------------------------------------------------------------------------------- /4_3_dr.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "4_3_dr.ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [] 9 | }, 10 | "kernelspec": { 11 | "name": "python3", 12 | "display_name": "Python 3" 13 | } 14 | }, 15 | "cells": [ 16 | { 17 | "cell_type": "markdown", 18 | "metadata": { 19 | "id": "d-IAJLC2k1NX" 20 | }, 21 | "source": [ 22 | "# 4.3 Doubly Robust法(DR法)による因果推論の実装\n", 23 | "\n", 24 | "本ファイルは、4.3節の実装です。\n", 25 | "\n", 26 | "4.1節、4.2節と同じく、テレビCMの広告効果の推定を例に、回帰分析による因果推論を実装します。\n" 27 | ] 28 | }, 29 | { 30 | "cell_type": "markdown", 31 | "metadata": { 32 | "id": "2XdIDbdlejUk" 33 | }, 34 | "source": [ 35 | "## プログラム実行前の設定など" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "metadata": { 41 | "id": "wqHjwstVeXYt" 42 | }, 43 | "source": [ 44 | "# 乱数のシードを設定\n", 45 | "import random\n", 46 | "import numpy as np\n", 47 | "\n", 48 | "np.random.seed(1234)\n", 49 | "random.seed(1234)\n" 50 | ], 51 | "execution_count": null, 52 | "outputs": [] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "metadata": { 57 | "id": "RIhcLRqlem3V" 58 | }, 59 | "source": [ 60 | "# 使用するパッケージ(ライブラリと関数)を定義\n", 61 | "# 標準正規分布の生成用\n", 62 | "from numpy.random import *\n", 63 | "\n", 64 | "# 
グラフの描画用\n", 65 | "import matplotlib.pyplot as plt\n", 66 | "\n", 67 | "# SciPy 平均0、分散1に正規化(標準化)関数\n", 68 | "import scipy.stats\n", 69 | "\n", 70 | "# シグモイド関数をimport\n", 71 | "from scipy.special import expit\n", 72 | "\n", 73 | "# その他\n", 74 | "import pandas as pd\n" 75 | ], 76 | "execution_count": null, 77 | "outputs": [] 78 | }, 79 | { 80 | "cell_type": "markdown", 81 | "metadata": { 82 | "id": "AWqP6yeQlI_t" 83 | }, 84 | "source": [ 85 | "## データの作成" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "metadata": { 91 | "id": "EJaQIHz4fNXb" 92 | }, 93 | "source": [ 94 | "# データ数\n", 95 | "num_data = 200\n", 96 | "\n", 97 | "# 年齢\n", 98 | "x_1 = randint(15, 76, num_data) # 15から75歳の一様乱数\n", 99 | "\n", 100 | "# 性別(0を女性、1を男性とします)\n", 101 | "x_2 = randint(0, 2, num_data) # 0か1の乱数\n" 102 | ], 103 | "execution_count": null, 104 | "outputs": [] 105 | }, 106 | { 107 | "cell_type": "markdown", 108 | "metadata": { 109 | "id": "JiT_gc5ZmAQa" 110 | }, 111 | "source": [ 112 | "## テレビCMを見たかどうか" 113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "metadata": { 118 | "id": "hfPbhhm6gvW6" 119 | }, 120 | "source": [ 121 | "# ノイズの生成\n", 122 | "e_z = randn(num_data)\n", 123 | "\n", 124 | "# シグモイド関数に入れる部分\n", 125 | "z_base = x_1 + (1-x_2)*10 - 40 + 5*e_z\n", 126 | "\n", 127 | "# シグモイド関数を計算\n", 128 | "z_prob = expit(0.1*z_base)\n", 129 | "\n", 130 | "# テレビCMを見たかどうかの変数(0は見ていない、1は見た)\n", 131 | "Z = np.array([])\n", 132 | "\n", 133 | "for i in range(num_data):\n", 134 | " Z_i = np.random.choice(2, size=1, p=[1-z_prob[i], z_prob[i]])[0]\n", 135 | " Z = np.append(Z, Z_i)\n" 136 | ], 137 | "execution_count": null, 138 | "outputs": [] 139 | }, 140 | { 141 | "cell_type": "markdown", 142 | "metadata": { 143 | "id": "b2PLquJGi2Te" 144 | }, 145 | "source": [ 146 | "## 購入量Yを作成" 147 | ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "metadata": { 152 | "id": "nv-ELtFqi5L5" 153 | }, 154 | "source": [ 155 | "# ノイズの生成\n", 156 | "e_y = randn(num_data)\n", 157 | "\n", 158 | "Y = -x_1 + 30*x_2 + 
10*Z + 80 + 10*e_y\n" 159 | ], 160 | "execution_count": null, 161 | "outputs": [] 162 | }, 163 | { 164 | "cell_type": "markdown", 165 | "metadata": { 166 | "id": "BHcdUlW9koTa" 167 | }, 168 | "source": [ 169 | "## データをまとめた表を作成し、平均値を比べる" 170 | ] 171 | }, 172 | { 173 | "cell_type": "code", 174 | "metadata": { 175 | "id": "HPqwrISXktRj", 176 | "colab": { 177 | "base_uri": "https://localhost:8080/", 178 | "height": 195 179 | }, 180 | "outputId": "539bb1b5-9936-461e-dd94-258b6980d366" 181 | }, 182 | "source": [ 183 | "df = pd.DataFrame({'年齢': x_1,\n", 184 | " '性別': x_2,\n", 185 | " 'CMを見た': Z,\n", 186 | " '購入量': Y,\n", 187 | " })\n", 188 | "\n", 189 | "df.head() # 先頭を表示\n" 190 | ], 191 | "execution_count": null, 192 | "outputs": [ 193 | { 194 | "output_type": "execute_result", 195 | "data": { 196 | "text/html": [ 197 | "
\n", 198 | "\n", 211 | "\n", 212 | " \n", 213 | " \n", 214 | " \n", 215 | " \n", 216 | " \n", 217 | " \n", 218 | " \n", 219 | " \n", 220 | " \n", 221 | " \n", 222 | " \n", 223 | " \n", 224 | " \n", 225 | " \n", 226 | " \n", 227 | " \n", 228 | " \n", 229 | " \n", 230 | " \n", 231 | " \n", 232 | " \n", 233 | " \n", 234 | " \n", 235 | " \n", 236 | " \n", 237 | " \n", 238 | " \n", 239 | " \n", 240 | " \n", 241 | " \n", 242 | " \n", 243 | " \n", 244 | " \n", 245 | " \n", 246 | " \n", 247 | " \n", 248 | " \n", 249 | " \n", 250 | " \n", 251 | " \n", 252 | " \n", 253 | " \n", 254 | " \n", 255 | " \n", 256 | " \n", 257 | " \n", 258 | "
年齢 性別 CMを見た 購入量
0 62 0 1.0 24.464285
1 34 0 0.0 45.693411
2 53 1 1.0 64.998281
3 68 1 1.0 47.186898
4 27 1 0.0 100.114260
\n", 259 | "
" 260 | ], 261 | "text/plain": [ 262 | " 年齢 性別 CMを見た 購入量\n", 263 | "0 62 0 1.0 24.464285\n", 264 | "1 34 0 0.0 45.693411\n", 265 | "2 53 1 1.0 64.998281\n", 266 | "3 68 1 1.0 47.186898\n", 267 | "4 27 1 0.0 100.114260" 268 | ] 269 | }, 270 | "metadata": { 271 | "tags": [] 272 | }, 273 | "execution_count": 6 274 | } 275 | ] 276 | }, 277 | { 278 | "cell_type": "code", 279 | "metadata": { 280 | "id": "HHInQ1Sukrg0", 281 | "colab": { 282 | "base_uri": "https://localhost:8080/", 283 | "height": 202 284 | }, 285 | "outputId": "4739a485-80aa-425b-ccd8-91294cb7d9fb" 286 | }, 287 | "source": [ 288 | "# 平均値を比べる\n", 289 | "\n", 290 | "print(df[df[\"CMを見た\"] == 1.0].mean())\n", 291 | "print(\"--------\")\n", 292 | "print(df[df[\"CMを見た\"] == 0.0].mean())\n" 293 | ], 294 | "execution_count": null, 295 | "outputs": [ 296 | { 297 | "output_type": "stream", 298 | "text": [ 299 | "年齢 55.836066\n", 300 | "性別 0.483607\n", 301 | "CMを見た 1.000000\n", 302 | "購入量 49.711478\n", 303 | "dtype: float64\n", 304 | "--------\n", 305 | "年齢 32.141026\n", 306 | "性別 0.692308\n", 307 | "CMを見た 0.000000\n", 308 | "購入量 68.827143\n", 309 | "dtype: float64\n" 310 | ], 311 | "name": "stdout" 312 | } 313 | ] 314 | }, 315 | { 316 | "cell_type": "markdown", 317 | "metadata": { 318 | "id": "-KMbYTvx-D4N" 319 | }, 320 | "source": [ 321 | "## 回帰分析を実施" 322 | ] 323 | }, 324 | { 325 | "cell_type": "code", 326 | "metadata": { 327 | "id": "CiVVt59d-gdj" 328 | }, 329 | "source": [ 330 | "# scikit-learnから線形回帰をimport\n", 331 | "# https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LinearRegression.html\n", 332 | "from sklearn.linear_model import LinearRegression\n", 333 | "\n", 334 | "# 説明変数\n", 335 | "X = df[[\"年齢\", \"性別\", \"CMを見た\"]]\n", 336 | "\n", 337 | "# 被説明変数(目的変数)\n", 338 | "y = df[\"購入量\"]\n", 339 | "\n", 340 | "# 回帰の実施\n", 341 | "reg2 = LinearRegression().fit(X, y)\n", 342 | "\n", 343 | "# Z=0の場合\n", 344 | "X_0 = X.copy()\n", 345 | "X_0[\"CMを見た\"] = 0\n", 346 | "Y_0 = reg2.predict(X_0)\n", 
347 | "\n", 348 | "# Z=1の場合\n", 349 | "X_1 = X.copy()\n", 350 | "X_1[\"CMを見た\"] = 1\n", 351 | "Y_1 = reg2.predict(X_1)\n" 352 | ], 353 | "execution_count": null, 354 | "outputs": [] 355 | }, 356 | { 357 | "cell_type": "markdown", 358 | "metadata": { 359 | "id": "kwKOk59aogBd" 360 | }, 361 | "source": [ 362 | "## 傾向スコアの推定" 363 | ] 364 | }, 365 | { 366 | "cell_type": "code", 367 | "metadata": { 368 | "id": "rBtHC3smoiMC", 369 | "colab": { 370 | "base_uri": "https://localhost:8080/", 371 | "height": 101 372 | }, 373 | "outputId": "6b9e06dc-3ee1-45be-a573-39a6839d9c85" 374 | }, 375 | "source": [ 376 | "# scikit-learnからロジスティク回帰をimport\n", 377 | "# https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html\n", 378 | "from sklearn.linear_model import LogisticRegression\n", 379 | "\n", 380 | "# 説明変数\n", 381 | "X = df[[\"年齢\", \"性別\"]]\n", 382 | "\n", 383 | "# 被説明変数(目的変数)\n", 384 | "Z = df[\"CMを見た\"]\n", 385 | "\n", 386 | "# 回帰の実施\n", 387 | "reg = LogisticRegression().fit(X, Z)\n", 388 | "\n", 389 | "# 傾向スコアを求める\n", 390 | "Z_pre = reg.predict_proba(X)\n", 391 | "print(Z_pre[0:5]) # 5人ほどの結果を見てみる\n" 392 | ], 393 | "execution_count": null, 394 | "outputs": [ 395 | { 396 | "output_type": "stream", 397 | "text": [ 398 | "[[0.04002323 0.95997677]\n", 399 | " [0.44525168 0.55474832]\n", 400 | " [0.30065918 0.69934082]\n", 401 | " [0.08101946 0.91898054]\n", 402 | " [0.87013558 0.12986442]]\n" 403 | ], 404 | "name": "stdout" 405 | } 406 | ] 407 | }, 408 | { 409 | "cell_type": "markdown", 410 | "metadata": { 411 | "id": "wL-hlBN36DZf" 412 | }, 413 | "source": [ 414 | "### 平均処置効果ATEを求める" 415 | ] 416 | }, 417 | { 418 | "cell_type": "code", 419 | "metadata": { 420 | "id": "F7bIHOC2ABSK", 421 | "colab": { 422 | "base_uri": "https://localhost:8080/", 423 | "height": 34 424 | }, 425 | "outputId": "2bc8f4d0-2d8a-4620-a60f-844afc4c96e3" 426 | }, 427 | "source": [ 428 | "ATE_1_i = Y/Z_pre[:, 1]*Z + (1-Z/Z_pre[:, 1])*Y_1\n", 429 | "ATE_0_i = Y/Z_pre[:, 
0]*(1-Z) + (1-(1-Z)/Z_pre[:, 0])*Y_0\n", 430 | "ATE = 1/len(Y)*(ATE_1_i-ATE_0_i).sum()\n", 431 | "print(\"推定したATE\", ATE)\n" 432 | ], 433 | "execution_count": null, 434 | "outputs": [ 435 | { 436 | "output_type": "stream", 437 | "text": [ 438 | "推定したATE 9.75277505424846\n" 439 | ], 440 | "name": "stdout" 441 | } 442 | ] 443 | }, 444 | { 445 | "cell_type": "markdown", 446 | "metadata": { 447 | "id": "1IdVhXmMps-w" 448 | }, 449 | "source": [ 450 | "以上" 451 | ] 452 | } 453 | ] 454 | } -------------------------------------------------------------------------------- /5_1_randomforest.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "5_1_randomforest.ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [] 9 | }, 10 | "kernelspec": { 11 | "name": "python3", 12 | "display_name": "Python 3" 13 | } 14 | }, 15 | "cells": [ 16 | { 17 | "cell_type": "markdown", 18 | "metadata": { 19 | "id": "d-IAJLC2k1NX", 20 | "colab_type": "text" 21 | }, 22 | "source": [ 23 | "# 5.1 ランダムフォレストとは\n", 24 | "\n", 25 | "本ファイルは、5.1節の実装です。\n", 26 | "\n", 27 | "機械学習モデルのランダムフォレストを解説・実装します。\n", 28 | "決定木の分類、決定木の回帰、ランダムフォレストの分類、ランダムフォレストの回帰を実施します" 29 | ] 30 | }, 31 | { 32 | "cell_type": "markdown", 33 | "metadata": { 34 | "id": "2XdIDbdlejUk", 35 | "colab_type": "text" 36 | }, 37 | "source": [ 38 | "## プログラム実行前の設定など" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "metadata": { 44 | "id": "wqHjwstVeXYt", 45 | "colab_type": "code", 46 | "colab": {} 47 | }, 48 | "source": [ 49 | "# 乱数のシードを設定\n", 50 | "import random\n", 51 | "import numpy as np\n", 52 | "\n", 53 | "np.random.seed(1234)\n", 54 | "random.seed(1234)\n" 55 | ], 56 | "execution_count": 0, 57 | "outputs": [] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "metadata": { 62 | "id": "RIhcLRqlem3V", 63 | "colab_type": "code", 64 | "colab": {} 65 | }, 66 | "source": [ 67 | "# 使用するパッケージ(ライブラリと関数)を定義\n", 68 | "# 
標準正規分布の生成用\n", 69 | "from numpy.random import *\n", 70 | "\n", 71 | "# グラフの描画用\n", 72 | "import matplotlib.pyplot as plt\n", 73 | "\n", 74 | "# SciPy 平均0、分散1に正規化(標準化)関数\n", 75 | "import scipy.stats\n", 76 | "\n", 77 | "# シグモイド関数をimport\n", 78 | "from scipy.special import expit\n", 79 | "\n", 80 | "# その他\n", 81 | "import pandas as pd\n" 82 | ], 83 | "execution_count": 0, 84 | "outputs": [] 85 | }, 86 | { 87 | "cell_type": "markdown", 88 | "metadata": { 89 | "id": "AWqP6yeQlI_t", 90 | "colab_type": "text" 91 | }, 92 | "source": [ 93 | "## データの作成" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "metadata": { 99 | "id": "EJaQIHz4fNXb", 100 | "colab_type": "code", 101 | "colab": {} 102 | }, 103 | "source": [ 104 | "# データ数\n", 105 | "num_data = 200\n", 106 | "\n", 107 | "# 年齢\n", 108 | "x_1 = randint(15, 76, num_data) # 15から75歳の一様乱数\n", 109 | "\n", 110 | "# 性別(0を女性、1を男性とします)\n", 111 | "x_2 = randint(0, 2, num_data) # 0か1の乱数\n" 112 | ], 113 | "execution_count": 0, 114 | "outputs": [] 115 | }, 116 | { 117 | "cell_type": "markdown", 118 | "metadata": { 119 | "id": "JiT_gc5ZmAQa", 120 | "colab_type": "text" 121 | }, 122 | "source": [ 123 | "## テレビCMを見たかどうか" 124 | ] 125 | }, 126 | { 127 | "cell_type": "code", 128 | "metadata": { 129 | "id": "hfPbhhm6gvW6", 130 | "colab_type": "code", 131 | "colab": {} 132 | }, 133 | "source": [ 134 | "# ノイズの生成\n", 135 | "e_z = randn(num_data)\n", 136 | "\n", 137 | "# シグモイド関数に入れる部分\n", 138 | "z_base = x_1 + (1-x_2)*10 - 40 + 5*e_z\n", 139 | "\n", 140 | "# シグモイド関数を計算\n", 141 | "z_prob = expit(0.1*z_base)\n", 142 | "\n", 143 | "# テレビCMを見たかどうかの変数(0は見ていない、1は見た)\n", 144 | "Z = np.array([])\n", 145 | "\n", 146 | "for i in range(num_data):\n", 147 | " Z_i = np.random.choice(2, size=1, p=[1-z_prob[i], z_prob[i]])[0]\n", 148 | " Z = np.append(Z, Z_i)\n" 149 | ], 150 | "execution_count": 0, 151 | "outputs": [] 152 | }, 153 | { 154 | "cell_type": "markdown", 155 | "metadata": { 156 | "id": "b2PLquJGi2Te", 157 | "colab_type": "text" 158 | }, 159 | 
"source": [ 160 | "## 購入量Yを作成" 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "metadata": { 166 | "id": "nv-ELtFqi5L5", 167 | "colab_type": "code", 168 | "colab": {} 169 | }, 170 | "source": [ 171 | "# ノイズの生成\n", 172 | "e_y = randn(num_data)\n", 173 | "\n", 174 | "Y = -x_1 + 30*x_2 + 10*Z + 80 + 10*e_y\n" 175 | ], 176 | "execution_count": 0, 177 | "outputs": [] 178 | }, 179 | { 180 | "cell_type": "markdown", 181 | "metadata": { 182 | "id": "BHcdUlW9koTa", 183 | "colab_type": "text" 184 | }, 185 | "source": [ 186 | "## データをまとめた表を作成し、平均値を比べる" 187 | ] 188 | }, 189 | { 190 | "cell_type": "code", 191 | "metadata": { 192 | "id": "HPqwrISXktRj", 193 | "colab_type": "code", 194 | "outputId": "472a7556-d2c0-4341-f570-d101075ca857", 195 | "colab": { 196 | "base_uri": "https://localhost:8080/", 197 | "height": 195 198 | } 199 | }, 200 | "source": [ 201 | "df = pd.DataFrame({'年齢': x_1,\n", 202 | " '性別': x_2,\n", 203 | " 'CMを見た': Z,\n", 204 | " '購入量': Y,\n", 205 | " })\n", 206 | "\n", 207 | "df.head() # 先頭を表示\n" 208 | ], 209 | "execution_count": 6, 210 | "outputs": [ 211 | { 212 | "output_type": "execute_result", 213 | "data": { 214 | "text/html": [ 215 | "
\n", 216 | "\n", 229 | "\n", 230 | " \n", 231 | " \n", 232 | " \n", 233 | " \n", 234 | " \n", 235 | " \n", 236 | " \n", 237 | " \n", 238 | " \n", 239 | " \n", 240 | " \n", 241 | " \n", 242 | " \n", 243 | " \n", 244 | " \n", 245 | " \n", 246 | " \n", 247 | " \n", 248 | " \n", 249 | " \n", 250 | " \n", 251 | " \n", 252 | " \n", 253 | " \n", 254 | " \n", 255 | " \n", 256 | " \n", 257 | " \n", 258 | " \n", 259 | " \n", 260 | " \n", 261 | " \n", 262 | " \n", 263 | " \n", 264 | " \n", 265 | " \n", 266 | " \n", 267 | " \n", 268 | " \n", 269 | " \n", 270 | " \n", 271 | " \n", 272 | " \n", 273 | " \n", 274 | " \n", 275 | " \n", 276 | "
   年齢  性別  CMを見た  購入量
0  62  0  1.0  24.464285
1  34  0  0.0  45.693411
2  53  1  1.0  64.998281
3  68  1  1.0  47.186898
4  27  1  0.0  100.114260
\n", 277 | "
" 278 | ], 279 | "text/plain": [ 280 | " 年齢 性別 CMを見た 購入量\n", 281 | "0 62 0 1.0 24.464285\n", 282 | "1 34 0 0.0 45.693411\n", 283 | "2 53 1 1.0 64.998281\n", 284 | "3 68 1 1.0 47.186898\n", 285 | "4 27 1 0.0 100.114260" 286 | ] 287 | }, 288 | "metadata": { 289 | "tags": [] 290 | }, 291 | "execution_count": 6 292 | } 293 | ] 294 | }, 295 | { 296 | "cell_type": "code", 297 | "metadata": { 298 | "id": "HHInQ1Sukrg0", 299 | "colab_type": "code", 300 | "outputId": "66879f86-ad4f-46cc-9a20-8edf74c860e1", 301 | "colab": { 302 | "base_uri": "https://localhost:8080/", 303 | "height": 202 304 | } 305 | }, 306 | "source": [ 307 | "# 平均値を比べる\n", 308 | "\n", 309 | "print(df[df[\"CMを見た\"] == 1.0].mean())\n", 310 | "print(\"--------\")\n", 311 | "print(df[df[\"CMを見た\"] == 0.0].mean())\n" 312 | ], 313 | "execution_count": 7, 314 | "outputs": [ 315 | { 316 | "output_type": "stream", 317 | "text": [ 318 | "年齢 55.836066\n", 319 | "性別 0.483607\n", 320 | "CMを見た 1.000000\n", 321 | "購入量 49.711478\n", 322 | "dtype: float64\n", 323 | "--------\n", 324 | "年齢 32.141026\n", 325 | "性別 0.692308\n", 326 | "CMを見た 0.000000\n", 327 | "購入量 68.827143\n", 328 | "dtype: float64\n" 329 | ], 330 | "name": "stdout" 331 | } 332 | ] 333 | }, 334 | { 335 | "cell_type": "markdown", 336 | "metadata": { 337 | "id": "kwKOk59aogBd", 338 | "colab_type": "text" 339 | }, 340 | "source": [ 341 | "## 決定木で分類\n", 342 | "\n", 343 | "決定木でCMを見たかどうかを分類予測するモデルを構築します" 344 | ] 345 | }, 346 | { 347 | "cell_type": "code", 348 | "metadata": { 349 | "id": "rBtHC3smoiMC", 350 | "colab_type": "code", 351 | "outputId": "1b5ab97b-27f8-42fe-dc8b-995d9f6b5d0d", 352 | "colab": { 353 | "base_uri": "https://localhost:8080/", 354 | "height": 67 355 | } 356 | }, 357 | "source": [ 358 | "# scikit-learnから決定木の分類をimport\n", 359 | "# https://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeClassifier.html\n", 360 | "from sklearn.tree import DecisionTreeClassifier\n", 361 | "\n", 362 | "# データを訓練と検証に分割する\n", 363 | "# 
https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.train_test_split.html\n", 364 | "from sklearn.model_selection import train_test_split\n", 365 | "\n", 366 | "\n", 367 | "# Explanatory variables\n", 368 | "X = df[[\"年齢\", \"性別\"]]\n", 369 | "\n", 370 | "# Response (target) variable\n", 371 | "Z = df[\"CMを見た\"]\n", 372 | "\n", 373 | "# Split the data into training and validation sets\n", 374 | "X_train, X_val, Z_train, Z_val = train_test_split(\n", 375 | "    X, Z, train_size=0.6, random_state=0)\n", 376 | "\n", 377 | "# Train and check performance\n", 378 | "clf = DecisionTreeClassifier(max_depth=1, random_state=0)\n", 379 | "clf.fit(X_train, Z_train)\n", 380 | "print(\"Depth 1 score:\", clf.score(X_val, Z_val)) # print the accuracy\n", 381 | "\n", 382 | "# Train and check performance\n", 383 | "clf = DecisionTreeClassifier(max_depth=2, random_state=0)\n", 384 | "clf.fit(X_train, Z_train)\n", 385 | "print(\"Depth 2 score:\", clf.score(X_val, Z_val)) # print the accuracy\n", 386 | "\n", 387 | "# Train and check performance\n", 388 | "clf = DecisionTreeClassifier(max_depth=3, random_state=0)\n", 389 | "clf.fit(X_train, Z_train)\n", 390 | "print(\"Depth 3 score:\", clf.score(X_val, Z_val)) # print the accuracy\n" 391 | ], 392 | "execution_count": 8, 393 | "outputs": [ 394 | { 395 | "output_type": "stream", 396 | "text": [ 397 | "Depth 1 score: 0.85\n", 398 | "Depth 2 score: 0.85\n", 399 | "Depth 3 score: 0.825\n" 400 | ], 401 | "name": "stdout" 402 | } 403 | ] 404 | }, 405 | { 406 | "cell_type": "markdown", 407 | "metadata": { 408 | "colab_type": "text", 409 | "id": "-KMbYTvx-D4N" 410 | }, 411 | "source": [ 412 | "## Regression with a decision tree\n", 413 | "\n", 414 | "Build a decision tree model that predicts the purchase amount." 415 | ] 416 | }, 417 | { 418 | "cell_type": "code", 419 | "metadata": { 420 | "id": "CiVVt59d-gdj", 421 | "colab_type": "code", 422 | "outputId": "93ee9046-b983-4dbb-d8cb-2d86c14bf7f2", 423 | "colab": { 424 | "base_uri": "https://localhost:8080/", 425 | "height": 67 426 | } 427 | }, 428 | "source": [ 429 | "# Import the decision tree regressor from scikit-learn\n", 430 | "# https://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeRegressor.html#sklearn.tree.DecisionTreeRegressor\n", 431 | 
"from sklearn.tree import DecisionTreeRegressor\n", 432 | "\n", 433 | "# データを訓練と検証に分割する\n", 434 | "# https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.train_test_split.html\n", 435 | "from sklearn.model_selection import train_test_split\n", 436 | "\n", 437 | "\n", 438 | "# 説明変数\n", 439 | "X = df[[\"年齢\", \"性別\"]]\n", 440 | "\n", 441 | "# 被説明変数(目的変数)\n", 442 | "Y = df[\"購入量\"]\n", 443 | "\n", 444 | "# データを訓練と検証に分割\n", 445 | "X_train, X_val, Y_train, Y_val = train_test_split(\n", 446 | " X, Y, train_size=0.6, random_state=0)\n", 447 | "\n", 448 | "# 学習と性能確認\n", 449 | "reg = DecisionTreeRegressor(max_depth=2, random_state=0)\n", 450 | "reg = reg.fit(X_train, Y_train)\n", 451 | "print(\"深さ2の性能:\", reg.score(X_val, Y_val)) # 決定係数R2を表示\n", 452 | "\n", 453 | "# 学習と性能確認\n", 454 | "reg = DecisionTreeRegressor(max_depth=3, random_state=0)\n", 455 | "reg = reg.fit(X_train, Y_train)\n", 456 | "print(\"深さ3の性能:\", reg.score(X_val, Y_val)) # 決定係数R2を表示\n", 457 | "\n", 458 | "# 学習と性能確認\n", 459 | "reg = DecisionTreeRegressor(max_depth=4, random_state=0)\n", 460 | "reg = reg.fit(X_train, Y_train)\n", 461 | "print(\"深さ4の性能:\", reg.score(X_val, Y_val)) # 決定係数R2を表示\n" 462 | ], 463 | "execution_count": 9, 464 | "outputs": [ 465 | { 466 | "output_type": "stream", 467 | "text": [ 468 | "深さ2の性能: 0.7257496664596153\n", 469 | "深さ3の性能: 0.7399348963931736\n", 470 | "深さ4の性能: 0.7165539691159019\n" 471 | ], 472 | "name": "stdout" 473 | } 474 | ] 475 | }, 476 | { 477 | "cell_type": "markdown", 478 | "metadata": { 479 | "colab_type": "text", 480 | "id": "1LHqDZTHyMeA" 481 | }, 482 | "source": [ 483 | "## ランダムフォレストで分類\n", 484 | "\n", 485 | "ランダムフォレストでCMを見たかどうかを分類予測するモデルを構築します" 486 | ] 487 | }, 488 | { 489 | "cell_type": "code", 490 | "metadata": { 491 | "id": "QZCX_vszyRIF", 492 | "colab_type": "code", 493 | "outputId": "fcc9656b-0c0f-4884-d682-992ba88bb9a7", 494 | "colab": { 495 | "base_uri": "https://localhost:8080/", 496 | "height": 67 497 | } 498 | }, 499 | "source": [ 500 
| "# scikit-learnからランダムフォレストの分類をimport\n", 501 | "# https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestClassifier.html?highlight=randomforest\n", 502 | "from sklearn.ensemble import RandomForestClassifier\n", 503 | "from sklearn.model_selection import train_test_split\n", 504 | "\n", 505 | "# 説明変数\n", 506 | "X = df[[\"年齢\", \"性別\"]]\n", 507 | "\n", 508 | "# 被説明変数(目的変数)\n", 509 | "Z = df[\"CMを見た\"]\n", 510 | "\n", 511 | "# データを訓練と検証に分割\n", 512 | "X_train, X_val, Z_train, Z_val = train_test_split(\n", 513 | " X, Z, train_size=0.6, random_state=0)\n", 514 | "\n", 515 | "# 学習と性能確認\n", 516 | "clf = RandomForestClassifier(max_depth=1, random_state=0)\n", 517 | "clf.fit(X_train, Z_train)\n", 518 | "print(\"深さ1の性能:\", clf.score(X_val, Z_val)) # 正解率を表示\n", 519 | "\n", 520 | "# 学習と性能確認\n", 521 | "clf = RandomForestClassifier(max_depth=2, random_state=0)\n", 522 | "clf.fit(X_train, Z_train)\n", 523 | "print(\"深さ2の性能:\", clf.score(X_val, Z_val)) # 正解率を表示\n", 524 | "\n", 525 | "# 学習と性能確認\n", 526 | "clf = RandomForestClassifier(max_depth=3, random_state=0)\n", 527 | "clf.fit(X_train, Z_train)\n", 528 | "print(\"深さ3の性能:\", clf.score(X_val, Z_val)) # 正解率を表示\n" 529 | ], 530 | "execution_count": 10, 531 | "outputs": [ 532 | { 533 | "output_type": "stream", 534 | "text": [ 535 | "深さ1の性能: 0.775\n", 536 | "深さ2の性能: 0.85\n", 537 | "深さ3の性能: 0.825\n" 538 | ], 539 | "name": "stdout" 540 | } 541 | ] 542 | }, 543 | { 544 | "cell_type": "markdown", 545 | "metadata": { 546 | "colab_type": "text", 547 | "id": "OuTwc5Kt4AiW" 548 | }, 549 | "source": [ 550 | "## ランダムフォレストで回帰\n", 551 | "\n", 552 | "ランダムフォレストで購入量を回帰予測するモデルを構築します" 553 | ] 554 | }, 555 | { 556 | "cell_type": "code", 557 | "metadata": { 558 | "id": "evnAhZnb4DXj", 559 | "colab_type": "code", 560 | "outputId": "1b439677-66ce-49dc-8760-70916ab7d2fc", 561 | "colab": { 562 | "base_uri": "https://localhost:8080/", 563 | "height": 67 564 | } 565 | }, 566 | "source": [ 567 | "# scikit-learnから決定木の回帰をimport\n", 568 | 
"# https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestRegressor.html?highlight=randomforest\n", 569 | "from sklearn.ensemble import RandomForestRegressor\n", 570 | "from sklearn.model_selection import train_test_split\n", 571 | "\n", 572 | "\n", 573 | "# 説明変数\n", 574 | "X = df[[\"年齢\", \"性別\"]]\n", 575 | "\n", 576 | "# 被説明変数(目的変数)\n", 577 | "Y = df[\"購入量\"]\n", 578 | "\n", 579 | "# データを訓練と検証に分割\n", 580 | "X_train, X_val, Y_train, Y_val = train_test_split(\n", 581 | " X, Y, train_size=0.6, random_state=0)\n", 582 | "\n", 583 | "# 学習と性能確認\n", 584 | "reg = RandomForestRegressor(max_depth=2, random_state=0)\n", 585 | "reg = reg.fit(X_train, Y_train)\n", 586 | "print(\"深さ2の性能:\", reg.score(X_val, Y_val)) # 決定係数R2を表示\n", 587 | "\n", 588 | "# 学習と性能確認\n", 589 | "reg = RandomForestRegressor(max_depth=3, random_state=0)\n", 590 | "reg = reg.fit(X_train, Y_train)\n", 591 | "print(\"深さ3の性能:\", reg.score(X_val, Y_val)) # 決定係数R2を表示\n", 592 | "\n", 593 | "# 学習と性能確認\n", 594 | "reg = RandomForestRegressor(max_depth=4, random_state=0)\n", 595 | "reg = reg.fit(X_train, Y_train)\n", 596 | "print(\"深さ4の性能:\", reg.score(X_val, Y_val)) # 決定係数R2を表示\n" 597 | ], 598 | "execution_count": 11, 599 | "outputs": [ 600 | { 601 | "output_type": "stream", 602 | "text": [ 603 | "深さ2の性能: 0.7618786062003249\n", 604 | "深さ3の性能: 0.7810610687821996\n", 605 | "深さ4の性能: 0.7655149049335735\n" 606 | ], 607 | "name": "stdout" 608 | } 609 | ] 610 | }, 611 | { 612 | "cell_type": "markdown", 613 | "metadata": { 614 | "id": "1IdVhXmMps-w", 615 | "colab_type": "text" 616 | }, 617 | "source": [ 618 | "以上" 619 | ] 620 | } 621 | ] 622 | } -------------------------------------------------------------------------------- /5_3_doubly_robust_learning_issue18.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "kernelspec": { 6 | "name": "python3", 7 | "display_name": "Python 3" 8 | }, 9 | 
"colab": { 10 | "name": "5_3_doubly_robust_learning_issue18.ipynb", 11 | "provenance": [], 12 | "collapsed_sections": [] 13 | } 14 | }, 15 | "cells": [ 16 | { 17 | "cell_type": "markdown", 18 | "metadata": { 19 | "id": "aoxI3DOK9vm2" 20 | }, 21 | "source": [ 22 | "# 5.3 Doubly Robust Learningの実装\n", 23 | "\n", 24 | "本ファイルは、5.3節の実装です。\n", 25 | "\n", 26 | "5.2節と同じく、人事研修の効果について因果推論を実施します。" 27 | ] 28 | }, 29 | { 30 | "cell_type": "markdown", 31 | "metadata": { 32 | "id": "2XdIDbdlejUk" 33 | }, 34 | "source": [ 35 | "## プログラム実行前の設定など" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "metadata": { 41 | "id": "XZFKJwcu-_Oj" 42 | }, 43 | "source": [ 44 | "# 乱数のシードを設定\n", 45 | "import random\n", 46 | "import numpy as np\n", 47 | "\n", 48 | "np.random.seed(1234)\n", 49 | "random.seed(1234)\n" 50 | ], 51 | "execution_count": 1, 52 | "outputs": [] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "metadata": { 57 | "id": "hx1idArc_F15" 58 | }, 59 | "source": [ 60 | "# 使用するパッケージ(ライブラリと関数)を定義\n", 61 | "# 標準正規分布の生成用\n", 62 | "from numpy.random import *\n", 63 | "\n", 64 | "# グラフの描画用\n", 65 | "import matplotlib.pyplot as plt\n", 66 | "\n", 67 | "# SciPy 平均0、分散1に正規化(標準化)関数\n", 68 | "import scipy.stats\n", 69 | "\n", 70 | "# シグモイド関数をimport\n", 71 | "from scipy.special import expit\n", 72 | "\n", 73 | "# その他\n", 74 | "import pandas as pd\n" 75 | ], 76 | "execution_count": 2, 77 | "outputs": [] 78 | }, 79 | { 80 | "cell_type": "markdown", 81 | "metadata": { 82 | "id": "AWqP6yeQlI_t" 83 | }, 84 | "source": [ 85 | "## データの作成" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "metadata": { 91 | "id": "DpnGB2KZ_L8x", 92 | "outputId": "b67517df-c4d8-4c40-851b-669d4aefabdc", 93 | "colab": { 94 | "base_uri": "https://localhost:8080/", 95 | "height": 282 96 | } 97 | }, 98 | "source": [ 99 | "# データ数\n", 100 | "num_data = 500\n", 101 | "\n", 102 | "# 部下育成への熱心さ\n", 103 | "x = np.random.uniform(low=-1, high=1, size=num_data) # -1から1の一様乱数\n", 104 | "\n", 105 | "# 
Whether the manager attended the training 'For managers: key points of career interviews with subordinates'\n", 106 | "e_z = randn(num_data) # generate noise\n", 107 | "z_prob = expit(-1*-5.0*x+5*e_z) # fixed: the effect of x previously had the opposite sign (Issue #18)\n", 108 | "Z = np.array([])\n", 109 | "\n", 110 | "# Draw Z (1 = attended the training, 0 = did not) with probability z_prob\n", 111 | "for i in range(num_data):\n", 112 | "    Z_i = np.random.choice(2, size=1, p=[1-z_prob[i], z_prob[i]])[0]\n", 113 | "    Z = np.append(Z, Z_i)\n", 114 | "\n", 115 | "# Nonlinear treatment effect: changes stepwise with the manager's enthusiasm x\n", 116 | "t = np.zeros(num_data)\n", 117 | "for i in range(num_data):\n", 118 | "    if x[i] < 0:\n", 119 | "        t[i] = 0.5\n", 120 | "    elif x[i] >= 0 and x[i] < 0.5:\n", 121 | "        t[i] = 0.7\n", 122 | "    elif x[i] >= 0.5:\n", 123 | "        t[i] = 1.0\n", 124 | "\n", 125 | "e_y = randn(num_data)\n", 126 | "Y = 2.0 + t*Z + 0.3*x + 0.1*e_y\n", 127 | "\n", 128 | "# Plot the treatment effect to check it\n", 129 | "plt.scatter(x, t, label=\"treatment-effect\")\n" 130 | ], 131 | "execution_count": 3, 132 | "outputs": [ 133 | { 134 | "output_type": "execute_result", 135 | "data": { 136 | "text/plain": [ 137 | "" 138 | ] 139 | }, 140 | "metadata": { 141 | "tags": [] 142 | }, 143 | "execution_count": 3 144 | }, 145 | { 146 | "output_type": "display_data", 147 | "data": { 148 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAASc0lEQVR4nO3df5BdZX3H8fc3mwQjVkPIaiUJJDgRTRsFvYO0zBT8SWA6SfzZZMoIlppqxc60yhRGxjIMjra049BKa5GhqNggUqXrFCejAuOMQ5ClCBiYQIgjJFBZQZxpoZAf3/5xT+R0c3fv3ey5e7MP79fMzt7zPM8955vn3nz27Dnn7onMRJI0+80ZdAGSpGYY6JJUCANdkgphoEtSIQx0SSrE3EFtePHixbl8+fJBbV6SZqW77rrrF5k53KlvYIG+fPlyRkdHB7V5SZqVIuJnE/V5yEWSCmGgS1IhDHRJKoSBLkmFMNAlqRBdr3KJiGuA3weeyMzf7tAfwBXAWcAzwLmZ+Z9NFyrp8HTxTffxtTseYbp/52/unGDv/slXEkACEUx7e4O2cME8Lln7W6w/aUlj6+xlD/1aYM0k/WcCK6uvTcA/Tb8sSbPBxTfdx3Vbpx/mQNcwh3aYw+wPc4Cnn93DBd+4h5vu3t3YOrsGemb+AHhqkiHrgK9k21ZgYUS8uqkCJR2+Nt/x6KBLmNX27E8u37K9sfU1cQx9CVB/VXdVbQeJiE0RMRoRo2NjYw1sWtIg7SthV3nAHnv62cbWNaMnRTPzqsxsZWZreLjjJ1clzSJDEYMuYdY7ZuGCxtbVRKDvBpbVlpdWbZIKt/Ety7oP0oTmzQkuOOOExtbXRKCPAB+MtlOAX2Xm4w2sV9Jh7rL1qzn7lGNpYkd97pzuKzkwooRfDBYumMfl739jo1e5RLd7ikbEZuB0YDHwc+CvgHkAmfnF6rLFL9C+EuYZ4EOZ2fWvbrVarfSPc0nS1ETEXZnZ6tTX9Tr0zNzYpT+Bjx1ibZKkhvhJUUkqhIEuSYUw0CWpEAa6JBXCQJekQhjoklQIA12SCmGgS1IhDHRJKoSBLkmFMNAlqRAGuiQVwkCXpEIY6JJUCANdkgphoEtSIQx0SSqEgS5JhTDQJakQBrokFcJAl6RCGOiSVAgDXZIKYaBLUiEMdEkqhIEuSYUw0CWpEAa6JBXCQJekQhjoklQIA12SCtFToEfEmojYHhE7IuLCDv3HRcT3I+LeiLgtIpY2X6okaTJdAz0ihoArgTOBVcDGiFg1btjfAl/JzDcAlwKfbbpQSdLketlDPxnYkZk7M/N54Hpg3bgxq4Bbqse3duiXJPVZL4G+BHi0tryraqu7B3hP9fjdwG9ExNHjVxQRmyJiNCJGx8bGDqVeSdIEmjop+kngtIi4GzgN2A3sGz8oM6/KzFZmtoaHhxvatCQJYG4PY3YDy2rLS6u2X8vMx6j20CPiZcB7M/PppoqUJHXXyx76ncDKiFgREfOBDcBIfUBELI6IA+u6CLim2TIlSd10DfTM3AucD2wBHgBuyMxtEXFpRKythp0ObI+IB4FXAZ/pU72SpAlEZg5kw61WK0dHRweybUmarSLirsxsderzk6KSVAgDXZIKYaBLUiEMdEkqhIEuSYUw0CWpEAa6JBXCQJekQhjoklQIA12SCmGgS1IhDHRJKoSBLkmFMNAlqRAGuiQVwkCXpEIY6JJUCANdkgphoEtSIQx0SSqEgS5JhTDQJakQBrokFcJAl6RCGOiSVAgDXZIKYaBLUiEMdEkqhIEuSYUw0CWpED0FekSsiYjtEbEjIi7s0H9sRNwaEXdHxL0RcVbzpUqSJtM10CNiCLgSOBNYBWyMiFXjhl0M3JCZJwEbgH9sulBJ0uR62UM/GdiRmTsz83ngemDduDEJvLx6/ArgseZKlCT1opdAXwI8WlveVbXVXQKcHRG7gJuBj3daUURsiojRiBgdGxs7hHIlSRN
p6qToRuDazFwKnAV8NSIOWndmXpWZrcxsDQ8PN7RpSRL0Fui7gWW15aVVW915wA0AmXk78BJgcRMFSpJ600ug3wmsjIgVETGf9knPkXFjHgHeDhARr6cd6B5TkaQZ1DXQM3MvcD6wBXiA9tUs2yLi0ohYWw37BPDhiLgH2Aycm5nZr6IlSQeb28ugzLyZ9snOetuna4/vB05ttjRJ0lT4SVFJKoSBLkmFMNAlqRAGuiQVwkCXpEIY6JJUCANdkgphoEtSIQx0SSqEgS5JhTDQJakQBrokFcJAl6RCGOiSVAgDXZIKYaBLUiEMdEkqhIEuSYUw0CWpEAa6JBXCQJekQhjoklQIA12SCmGgS1IhDHRJKoSBLkmFMNAlqRAGuiQVwkCXpEIY6JJUCANdkgoxt5dBEbEGuAIYAq7OzM+N6/888NZq8aXAKzNzYZOFSv128U33cd3WRwZdxqx06msW8bUP/86gy3jR67qHHhFDwJXAmcAqYGNErKqPycw/z8wTM/NE4B+Ab/ajWKlfDPPp+eHDT/GHX7p90GW86PVyyOVkYEdm7szM54HrgXWTjN8IbG6iOGmmbL7j0UGXMOv98OGnBl3Ci14vgb4EqL/bd1VtB4mI44AVwC0T9G+KiNGIGB0bG5tqrVLf7MscdAnStDV9UnQDcGNm7uvUmZlXZWYrM1vDw8MNb1o6dEMRgy5BmrZeAn03sKy2vLRq62QDHm7RLLTxLcu6D9KkTn3NokGX8KLXS6DfCayMiBURMZ92aI+MHxQRrwOOAjwzolnnsvWrOfuUYwddxqzlVS6Hh66XLWbm3og4H9hC+7LFazJzW0RcCoxm5oFw3wBcn+nBSM1Ol61fzWXrVw+6DOmQ9XQdembeDNw8ru3T45Yvaa4sSdJU+UlRSSqEgS5JhTDQJakQBrokFcJAl6RCGOiSVAgDXZIKYaBLUiEMdEkqhIEuSYUw0CWpEAa6JBXCQJekQhjoklQIA12SCmGgS1IhDHRJKoSBLkmFMNAlqRAGuiQVwkCXpEIY6JJUCANdkgphoEtSIQx0SSqEgS5JhTDQJakQBrokFcJAl6RCGOiSVIieAj0i1kTE9ojYEREXTjDmAxFxf0Rsi4h/bbZMSVI3c7sNiIgh4ErgncAu4M6IGMnM+2tjVgIXAadm5i8j4pX9KliS1Fkve+gnAzsyc2dmPg9cD6wbN+bDwJWZ+UuAzHyi2TIlSd30EuhLgEdry7uqtrrXAq+NiB9GxNaIWNNpRRGxKSJGI2J0bGzs0CqWJHXU1EnRucBK4HRgI/CliFg4flBmXpWZrcxsDQ8PN7RpSRL0Fui7gWW15aVVW90uYCQz92TmT4EHaQe8JGmG9BLodwIrI2JFRMwHNgAj48bcRHvvnIhYTPsQzM4G65QkddE10DNzL3A+sAV4ALghM7dFxKURsbYatgV4MiLuB24FLsjMJ/tVtCTpYJGZA9lwq9XK0dHRgWxbkmariLgrM1ud+vykqCQVwkCXpEIY6JJUCANdkgphoEtSIQx0SSqEgS5JhTDQJakQBrokFcJAl6RCGOiSVAgDXZIKYaBLUiEMdEkqhIEuSYUw0CWpEAa6JBXCQJekQhjoklQIA12SCmGgS1IhDHRJKoSBLkmFMNAlqRAGuiQVwkCXpEIY6JJUCANdkgphoEtSIQx0SSqEgS5JhZjby6CIWANcAQwBV2fm58b1nwtcDuyumr6QmVc3WCcAN929m8u3bGf30882vepJHTl/iMzkmT37Z3S707Vg3hze++alfP1Hj9BE6XOArL6atmThApYfvYDbdz7F/n5sADhi7hye29t9Io6cP8Rn3r2a9Sct6U8hUp90DfSIGAKuBN4J7ALujIiRzLx/3NCvZ+b5fagRaIf5Rd+8j2f37OvXJib0P8/P/Dab8Oye/Vy39ZHG1tfPH2e7n3627z+oewlzaL/en/jGPQCGumaVXg65nAzsyMydmfk8cD2wrr9lHezyLdsHEuZ6cdq
3P7l8y/ZBlyFNSS+BvgR4tLa8q2ob770RcW9E3BgRyzqtKCI2RcRoRIyOjY1NqdDHZvgwi+R7TrNNUydFvw0sz8w3AN8FvtxpUGZelZmtzGwNDw9PaQPHLFww/SqlKfA9p9mml0DfDdT3uJfywslPADLzycx8rlq8GnhzM+W94IIzTmDBvKGmVyt1NDQnuOCMEwZdhjQlvQT6ncDKiFgREfOBDcBIfUBEvLq2uBZ4oLkS29aftITPvmc1Swaw13Tk/CFeOm/2XeG5YN4czj7lWJoqfQ4QzazqIEsWLuDU1yxiTr82QPsql14cOX+Iv3v/Gz0hqlmn61Uumbk3Is4HttC+bPGazNwWEZcCo5k5AvxZRKwF9gJPAef2o9j1Jy3xP9khuGz96kGXIGkGRGafLvrtotVq5ejo6EC2LUmzVUTclZmtTn2z7ziCJKkjA12SCmGgS1IhDHRJKsTATopGxBjws0N8+mLgFw2W0xTrmhrrmhrrmrrDtbbp1HVcZnb8ZObAAn06ImJ0orO8g2RdU2NdU2NdU3e41tavujzkIkmFMNAlqRCzNdCvGnQBE7CuqbGuqbGuqTtca+tLXbPyGLok6WCzdQ9dkjSOgS5JhTgsAz0i3h8R2yJif0RMeGlPRKyJiO0RsSMiLqy1r4iIO6r2r1d/9rep2hZFxHcj4qHq+1Edxrw1In5c+/rfiFhf9V0bET+t9Z04U3VV4/bVtj1Sa+/LnPU4XydGxO3Va35vRPxBra/R+ZroPVPrP6L69++o5mN5re+iqn17RJwxnToOoa6/iIj7q/n5fkQcV+vr+JrOUF3nRsRYbft/XOs7p3rdH4qIc2a4rs/XanowIp6u9fVzvq6JiCci4icT9EdE/H1V970R8aZa3/TnKzMPuy/g9cAJwG1Aa4IxQ8DDwPHAfOAeYFXVdwOwoXr8ReCjDdb2N8CF1eMLgb/uMn4R7T8p/NJq+VrgfX2Ys57qAv57gva+zFkvdQGvBVZWj48BHgcWNj1fk71namP+FPhi9XgD7ZufA6yqxh8BrKjWMzSDdb219h766IG6JntNZ6iuc4EvdHjuImBn9f2o6vFRM1XXuPEfp/1nv/s6X9W6fw94E/CTCfrPAr5D+9YCpwB3NDlfh+UeemY+kJnd7tDb8ebVERHA24Abq3FfBtY3WN46XrjFXi/rfh/wncx8psEaOplqXb/W5znrWldmPpiZD1WPHwOeAKZ2j8Le9HLD83q9NwJvr+ZnHXB9Zj6XmT8FdlTrm5G6MvPW2ntoK+07h/XbdG4Qfwbw3cx8KjN/SfvWlGsGVNdGYHND255UZv6A9g7cRNYBX8m2rcDCaN8gqJH5OiwDvUcT3bz6aODpzNw7rr0pr8rMx6vH/wW8qsv4DRz8ZvpM9evW5yPiiBmu6yXRvlH31gOHgejvnE1pviLiZNp7XQ/Xmpuar15ueP7rMdV8/Ir2/PR6s/R+1VV3Hu29vAM6vaYzWVenG8QfFvNVHZpaAdxSa+7XfPViotobma+udyzql4j4HvCbHbo+lZn/PtP11E1WW30hMzMiJrzus/rJu5r23Z4OuIh2sM2nfS3qXwKXzmBdx2Xm7og4HrglIu6jHVqHrOH5+ipwTmbur5oPeb5KFBFnAy3gtFrzQa9pZj7ceQ2N+zawOTOfi4g/of3bzdtmaNu92ADcmJn7am2DnK++GligZ+Y7prmKiW5e/STtX2PmVntYB93Uejq1RcTPI+LVmfl4FUBPTLKqDwDfysw9tXUf2Ft9LiL+BfjkTNaVmbur7zsj4jbgJODfmMacNVFXRLwc+A/aP9C31tZ9yPPVQdcbntfG7IqIucAraL+nenluP+siIt5B+4fkafnCTdknek2bCKiebhBfW7ya9jmTA889fdxzb2ugpp7qqtkAfKze0Mf56sVEtTcyX7P5kEvHm1dn+wzDrbSPXQOcAzS5xz9SrbOXdR907K4KtQPHrdcDHc+G96OuiDjqwCGLiFgMnAr
c3+c566Wu+cC3aB9bvHFcX5Pz1fWG5+PqfR9wSzU/I8CGaF8FswJYCfxoGrVMqa6IOAn4Z2BtZj5Ra+/4ms5gXRPdIH4L8K6qvqOAd/H/f1Pta11Vba+jfYLx9lpbP+erFyPAB6urXU4BflXttDQzX/062zudL+DdtI8hPQf8HNhStR8D3FwbdxbwIO2frp+qtR9P+z/bDuAbwBEN1nY08H3gIeB7wKKqvQVcXRu3nPZP3Tnjnn8LcB/tYLoOeNlM1QX8brXte6rv5/V7znqs62xgD/Dj2teJ/ZivTu8Z2odw1laPX1L9+3dU83F87bmfqp63HTiz4fd8t7q+V/1fODA/I91e0xmq67PAtmr7twKvqz33j6p53AF8aCbrqpYvAT437nn9nq/NtK/S2kM7w84DPgJ8pOoP4Mqq7vuoXcXXxHz50X9JKsRsPuQiSaox0CWpEAa6JBXCQJekQhjoklQIA12SCmGgS1Ih/g9Sp439On0bGgAAAABJRU5ErkJggg==\n", 149 | "text/plain": [ 150 | "
" 151 | ] 152 | }, 153 | "metadata": { 154 | "tags": [], 155 | "needs_background": "light" 156 | } 157 | } 158 | ] 159 | }, 160 | { 161 | "cell_type": "markdown", 162 | "metadata": { 163 | "id": "BHcdUlW9koTa" 164 | }, 165 | "source": [ 166 | "## データをまとめた表を作成し、可視化する" 167 | ] 168 | }, 169 | { 170 | "cell_type": "code", 171 | "metadata": { 172 | "id": "1EMwdGIIIPrK", 173 | "outputId": "c3575b0b-6d6d-4bbf-b59c-321d7e5978e2", 174 | "colab": { 175 | "base_uri": "https://localhost:8080/", 176 | "height": 195 177 | } 178 | }, 179 | "source": [ 180 | "df = pd.DataFrame({'x': x,\n", 181 | " 'Z': Z,\n", 182 | " 't': t,\n", 183 | " 'Y': Y,\n", 184 | " })\n", 185 | "\n", 186 | "df.head() # 先頭を表示\n" 187 | ], 188 | "execution_count": 4, 189 | "outputs": [ 190 | { 191 | "output_type": "execute_result", 192 | "data": { 193 | "text/html": [ 194 | "
\n", 195 | "\n", 208 | "\n", 209 | " \n", 210 | " \n", 211 | " \n", 212 | " \n", 213 | " \n", 214 | " \n", 215 | " \n", 216 | " \n", 217 | " \n", 218 | " \n", 219 | " \n", 220 | " \n", 221 | " \n", 222 | " \n", 223 | " \n", 224 | " \n", 225 | " \n", 226 | " \n", 227 | " \n", 228 | " \n", 229 | " \n", 230 | " \n", 231 | " \n", 232 | " \n", 233 | " \n", 234 | " \n", 235 | " \n", 236 | " \n", 237 | " \n", 238 | " \n", 239 | " \n", 240 | " \n", 241 | " \n", 242 | " \n", 243 | " \n", 244 | " \n", 245 | " \n", 246 | " \n", 247 | " \n", 248 | " \n", 249 | " \n", 250 | " \n", 251 | " \n", 252 | " \n", 253 | " \n", 254 | " \n", 255 | "
   x  Z  t  Y
0  -0.616961  0.0  0.5  1.803183
1  0.244218  1.0  0.7  2.668873
2  -0.124545  0.0  0.5  2.193123
3  0.570717  1.0  1.0  3.245229
4  0.559952  1.0  1.0  3.139868
\n", 256 | "
" 257 | ], 258 | "text/plain": [ 259 | " x Z t Y\n", 260 | "0 -0.616961 0.0 0.5 1.803183\n", 261 | "1 0.244218 1.0 0.7 2.668873\n", 262 | "2 -0.124545 0.0 0.5 2.193123\n", 263 | "3 0.570717 1.0 1.0 3.245229\n", 264 | "4 0.559952 1.0 1.0 3.139868" 265 | ] 266 | }, 267 | "metadata": { 268 | "tags": [] 269 | }, 270 | "execution_count": 4 271 | } 272 | ] 273 | }, 274 | { 275 | "cell_type": "code", 276 | "metadata": { 277 | "id": "L6Tb2Hjk9vno", 278 | "outputId": "dd55119a-872c-4dc1-ebc7-0f66a3bab881", 279 | "colab": { 280 | "base_uri": "https://localhost:8080/", 281 | "height": 282 282 | } 283 | }, 284 | "source": [ 285 | "plt.scatter(x, Y)\n" 286 | ], 287 | "execution_count": 5, 288 | "outputs": [ 289 | { 290 | "output_type": "execute_result", 291 | "data": { 292 | "text/plain": [ 293 | "" 294 | ] 295 | }, 296 | "metadata": { 297 | "tags": [] 298 | }, 299 | "execution_count": 5 300 | }, 301 | { 302 | "output_type": "display_data", 303 | "data": { 304 | "image/png": "\n", 305 | "text/plain": [ 306 | "
" 307 | ] 308 | }, 309 | "metadata": { 310 | "tags": [], 311 | "needs_background": "light" 312 | } 313 | } 314 | ] 315 | }, 316 | { 317 | "cell_type": "markdown", 318 | "metadata": { 319 | "id": "AeC7Uv29KsXC" 320 | }, 321 | "source": [ 322 | "## DR-Learnerの開始、まずはT-Learner" 323 | ] 324 | }, 325 | { 326 | "cell_type": "code", 327 | "metadata": { 328 | "id": "xp2P-IDT9vql", 329 | "outputId": "122ee358-9b12-4166-94aa-38ad5ef19208", 330 | "colab": { 331 | "base_uri": "https://localhost:8080/" 332 | } 333 | }, 334 | "source": [ 335 | "# ランダムフォレストモデルを作成\n", 336 | "from sklearn.ensemble import RandomForestRegressor\n", 337 | "\n", 338 | "# 集団を2つに分ける\n", 339 | "df_0 = df[df.Z == 0.0] # 介入を受けていない集団\n", 340 | "df_1 = df[df.Z == 1.0] # 介入を受けた集団\n", 341 | "\n", 342 | "# 介入を受けていないモデル\n", 343 | "M_0 = RandomForestRegressor(max_depth=3)\n", 344 | "M_0.fit(df_0[[\"x\"]], df_0[[\"Y\"]])\n", 345 | "\n", 346 | "# 介入を受けたモデル\n", 347 | "M_1 = RandomForestRegressor(max_depth=3)\n", 348 | "M_1.fit(df_1[[\"x\"]], df_1[[\"Y\"]])\n" 349 | ], 350 | "execution_count": 6, 351 | "outputs": [ 352 | { 353 | "output_type": "stream", 354 | "text": [ 355 | "/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:10: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n", 356 | " # Remove the CWD from sys.path while we load stuff.\n", 357 | "/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:14: DataConversionWarning: A column-vector y was passed when a 1d array was expected. 
Please change the shape of y to (n_samples,), for example using ravel().\n", 358 | " \n" 359 | ], 360 | "name": "stderr" 361 | }, 362 | { 363 | "output_type": "execute_result", 364 | "data": { 365 | "text/plain": [ 366 | "RandomForestRegressor(bootstrap=True, ccp_alpha=0.0, criterion='mse',\n", 367 | " max_depth=3, max_features='auto', max_leaf_nodes=None,\n", 368 | " max_samples=None, min_impurity_decrease=0.0,\n", 369 | " min_impurity_split=None, min_samples_leaf=1,\n", 370 | " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", 371 | " n_estimators=100, n_jobs=None, oob_score=False,\n", 372 | " random_state=None, verbose=0, warm_start=False)" 373 | ] 374 | }, 375 | "metadata": { 376 | "tags": [] 377 | }, 378 | "execution_count": 6 379 | } 380 | ] 381 | }, 382 | { 383 | "cell_type": "code", 384 | "metadata": { 385 | "id": "wAeHIJiqOF-h" 386 | }, 387 | "source": [ 388 | "# Estimate the propensity score\n", 389 | "from sklearn.linear_model import LogisticRegression\n", 390 | "\n", 391 | "# Explanatory variable\n", 392 | "X = df[[\"x\"]]\n", 393 | "\n", 394 | "# Response (target) variable\n", 395 | "Z = df[\"Z\"]\n", 396 | "\n", 397 | "# Fit the logistic regression\n", 398 | "g_x = LogisticRegression().fit(X, Z)\n", 399 | "g_x_val = g_x.predict_proba(X)\n" 400 | ], 401 | "execution_count": 7, 402 | "outputs": [] 403 | }, 404 | { 405 | "cell_type": "markdown", 406 | "metadata": { 407 | "id": "xTjMfuZTNrLO" 408 | }, 409 | "source": [ 410 | "## Estimation based on the DR method" 411 | ] 412 | }, 413 | { 414 | "cell_type": "code", 415 | "metadata": { 416 | "id": "jwEhxJQeNvhw", 417 | "outputId": "a56db28c-a884-489c-b597-0d7190281fbf", 418 | "colab": { 419 | "base_uri": "https://localhost:8080/", 420 | "height": 406 421 | } 422 | }, 423 | "source": [ 424 | "# Treated group\n", 425 | "Y_1 = M_1.predict(df_1[[\"x\"]]) + (df_1[\"Y\"] - M_1.predict(df_1[[\"x\"]])) / \\\n", 426 | "    g_x.predict_proba(df_1[[\"x\"]])[:, 1] # [:, 1] is the probability of Z=1\n", 427 | "df_1[\"ITE\"] = Y_1 - M_0.predict(df_1[[\"x\"]])\n", 428 | "\n", 429 | "# Untreated group\n", 430 | "Y_0 = M_0.predict(df_0[[\"x\"]]) + 
(df_0[\"Y\"] - M_0.predict(df_0[[\"x\"]])) / \\\n", 431 | "    g_x.predict_proba(df_0[[\"x\"]])[:, 0] # [:, 0] is the probability of Z=0\n", 432 | "df_0[\"ITE\"] = M_1.predict(df_0[[\"x\"]]) - Y_0\n", 433 | "\n", 434 | "# Concatenate the two tables\n", 435 | "df_DR = pd.concat([df_0, df_1])\n", 436 | "df_DR.head()\n" 437 | ], 438 | "execution_count": 8, 439 | "outputs": [ 440 | { 441 | "output_type": "stream", 442 | "text": [ 443 | "/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:3: SettingWithCopyWarning: \n", 444 | "A value is trying to be set on a copy of a slice from a DataFrame.\n", 445 | "Try using .loc[row_indexer,col_indexer] = value instead\n", 446 | "\n", 447 | "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", 448 | " This is separate from the ipykernel package so we can avoid doing imports until\n", 449 | "/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:7: SettingWithCopyWarning: \n", 450 | "A value is trying to be set on a copy of a slice from a DataFrame.\n", 451 | "Try using .loc[row_indexer,col_indexer] = value instead\n", 452 | "\n", 453 | "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", 454 | " import sys\n" 455 | ], 456 | "name": "stderr" 457 | }, 458 | { 459 | "output_type": "execute_result", 460 | "data": { 461 | "text/html": [ 462 | "
\n", 463 | "\n", 476 | "\n", 477 | " \n", 478 | " \n", 479 | " \n", 480 | " \n", 481 | " \n", 482 | " \n", 483 | " \n", 484 | " \n", 485 | " \n", 486 | " \n", 487 | " \n", 488 | " \n", 489 | " \n", 490 | " \n", 491 | " \n", 492 | " \n", 493 | " \n", 494 | " \n", 495 | " \n", 496 | " \n", 497 | " \n", 498 | " \n", 499 | " \n", 500 | " \n", 501 | " \n", 502 | " \n", 503 | " \n", 504 | " \n", 505 | " \n", 506 | " \n", 507 | " \n", 508 | " \n", 509 | " \n", 510 | " \n", 511 | " \n", 512 | " \n", 513 | " \n", 514 | " \n", 515 | " \n", 516 | " \n", 517 | " \n", 518 | " \n", 519 | " \n", 520 | " \n", 521 | " \n", 522 | " \n", 523 | " \n", 524 | " \n", 525 | " \n", 526 | " \n", 527 | " \n", 528 | " \n", 529 | "
   x  Z  t  Y  ITE
0  -0.616961  0.0  0.5  1.803183  0.514190
2  -0.124545  0.0  0.5  2.193123  0.081865
5  -0.454815  0.0  0.5  1.973293  0.333970
6  -0.447071  0.0  0.5  1.953387  0.364906
9  0.751865  0.0  1.0  2.289369  0.776072
\n", 530 | "
" 531 | ], 532 | "text/plain": [ 533 | " x Z t Y ITE\n", 534 | "0 -0.616961 0.0 0.5 1.803183 0.514190\n", 535 | "2 -0.124545 0.0 0.5 2.193123 0.081865\n", 536 | "5 -0.454815 0.0 0.5 1.973293 0.333970\n", 537 | "6 -0.447071 0.0 0.5 1.953387 0.364906\n", 538 | "9 0.751865 0.0 1.0 2.289369 0.776072" 539 | ] 540 | }, 541 | "metadata": { 542 | "tags": [] 543 | }, 544 | "execution_count": 8 545 | } 546 | ] 547 | }, 548 | { 549 | "cell_type": "code", 550 | "metadata": { 551 | "id": "XvOWVBt99vq7", 552 | "outputId": "97da2ae4-09f7-4411-a4ef-1f177442fa60", 553 | "colab": { 554 | "base_uri": "https://localhost:8080/", 555 | "height": 338 556 | } 557 | }, 558 | "source": [ 559 | "# モデルM_DRを構築し、各人の治療効果をモデルから求める\n", 560 | "\n", 561 | "# モデルM_DR\n", 562 | "M_DR = RandomForestRegressor(max_depth=3)\n", 563 | "M_DR.fit(df_DR[[\"x\"]], df_DR[[\"ITE\"]])\n", 564 | "\n", 565 | "\n", 566 | "# 推定された治療効果を各人ごとに求めます\n", 567 | "t_estimated = M_DR.predict(df_DR[[\"x\"]])\n", 568 | "plt.scatter(df_DR[[\"x\"]], t_estimated,\n", 569 | " label=\"estimated_treatment-effect\")\n", 570 | "\n", 571 | "# 正解のグラフを作成\n", 572 | "x_index = np.arange(-1, 1, 0.01)\n", 573 | "t_ans = np.zeros(len(x_index))\n", 574 | "for i in range(len(x_index)):\n", 575 | " if x_index[i] < 0:\n", 576 | " t_ans[i] = 0.5\n", 577 | " elif x_index[i] >= 0 and x_index[i] < 0.5:\n", 578 | " t_ans[i] = 0.7\n", 579 | " elif x_index[i] >= 0.5:\n", 580 | " t_ans[i] = 1.0\n", 581 | "\n", 582 | "\n", 583 | "# 正解を描画\n", 584 | "plt.plot(x_index, t_ans, color='black', ls='--', label='Baseline')\n" 585 | ], 586 | "execution_count": 9, 587 | "outputs": [ 588 | { 589 | "output_type": "stream", 590 | "text": [ 591 | "/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:5: DataConversionWarning: A column-vector y was passed when a 1d array was expected. 
Please change the shape of y to (n_samples,), for example using ravel().\n", 592 | " \"\"\"\n" 593 | ], 594 | "name": "stderr" 595 | }, 596 | { 597 | "output_type": "execute_result", 598 | "data": { 599 | "text/plain": [ 600 | "[]" 601 | ] 602 | }, 603 | "metadata": { 604 | "tags": [] 605 | }, 606 | "execution_count": 9 607 | }, 608 | { 609 | "output_type": "display_data", 610 | "data": { 611 | "image/png": "\n", 612 | "text/plain": [ 613 | "
" 614 | ] 615 | }, 616 | "metadata": { 617 | "tags": [], 618 | "needs_background": "light" 619 | } 620 | } 621 | ] 622 | }, 623 | { 624 | "cell_type": "markdown", 625 | "metadata": { 626 | "id": "4riPBjmmWX__" 627 | }, 628 | "source": [ 629 | "以上" 630 | ] 631 | } 632 | ] 633 | } -------------------------------------------------------------------------------- /6_3_lingam.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "kernelspec": { 6 | "name": "python3", 7 | "display_name": "Python 3" 8 | }, 9 | "colab": { 10 | "name": "6_3_lingam.ipynb", 11 | "provenance": [], 12 | "collapsed_sections": [] 13 | } 14 | }, 15 | "cells": [ 16 | { 17 | "cell_type": "markdown", 18 | "metadata": { 19 | "id": "5-FKx0scdZyU", 20 | "colab_type": "text" 21 | }, 22 | "source": [ 23 | "# 6.3 LinGAMを用いた因果探索\n", 24 | "\n", 25 | "本ファイルは、6.3節の内容となります。LiNGAMを実装しながらその内容の理解を深めていきます。" 26 | ] 27 | }, 28 | { 29 | "cell_type": "markdown", 30 | "metadata": { 31 | "id": "A1B4kBJTjXCc", 32 | "colab_type": "text" 33 | }, 34 | "source": [ 35 | "## プログラム実行前の設定など" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "metadata": { 41 | "id": "qE00vj2hjUsc", 42 | "colab_type": "code", 43 | "colab": {} 44 | }, 45 | "source": [ 46 | "# 乱数のシードを固定\n", 47 | "import random\n", 48 | "import numpy as np\n", 49 | "\n", 50 | "random.seed(1234)\n", 51 | "np.random.seed(1234)\n" 52 | ], 53 | "execution_count": 0, 54 | "outputs": [] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "metadata": { 59 | "id": "PtmPH1FUjZZA", 60 | "colab_type": "code", 61 | "colab": {} 62 | }, 63 | "source": [ 64 | "# 使用するパッケージ(ライブラリと関数)を定義\n", 65 | "import pandas as pd" 66 | ], 67 | "execution_count": 0, 68 | "outputs": [] 69 | }, 70 | { 71 | "cell_type": "markdown", 72 | "metadata": { 73 | "id": "ONKBX56LdZyX", 74 | "colab_type": "text" 75 | }, 76 | "source": [ 77 | "# データ生成" 78 | ] 79 | }, 80 | { 81 | "cell_type": "markdown", 82 | 
"metadata": { 83 | "colab_type": "text", 84 | "id": "3IVwBzaljlko" 85 | }, 86 | "source": [ 87 | "## モデル\n", 88 | "x1 = 3×x2 + ex1\n", 89 | "\n", 90 | "x2 = ex2\n", 91 | "\n", 92 | "x3 = 2×x1 + 4×x2 + ex3\n", 93 | " \n" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "metadata": { 99 | "id": "-CXDyGWOmOQb", 100 | "colab_type": "code", 101 | "outputId": "abf2a734-ee59-43ea-af8b-e8d21abdefe0", 102 | "colab": { 103 | "base_uri": "https://localhost:8080/", 104 | "height": 195 105 | } 106 | }, 107 | "source": [ 108 | "# データ数\n", 109 | "num_data = 200\n", 110 | "\n", 111 | "# 非ガウスのノイズ\n", 112 | "ex1 = 2*(np.random.rand(num_data)-0.5) # -1.0から1.0\n", 113 | "ex2 = 2*(np.random.rand(num_data)-0.5)\n", 114 | "ex3 = 2*(np.random.rand(num_data)-0.5)\n", 115 | "\n", 116 | "# データ生成\n", 117 | "x2 = ex2\n", 118 | "x1 = 3*x2 + ex1\n", 119 | "x3 = 2*x1 + 4*x2 + ex3\n", 120 | "\n", 121 | "# 表にまとめる\n", 122 | "df = pd.DataFrame({\"x1\": x1, \"x2\": x2, \"x3\": x3})\n", 123 | "df.head()\n" 124 | ], 125 | "execution_count": 3, 126 | "outputs": [ 127 | { 128 | "output_type": "execute_result", 129 | "data": { 130 | "text/html": [ 131 | "
\n", 132 | "\n", 145 | "\n", 146 | " \n", 147 | " \n", 148 | " \n", 149 | " \n", 150 | " \n", 151 | " \n", 152 | " \n", 153 | " \n", 154 | " \n", 155 | " \n", 156 | " \n", 157 | " \n", 158 | " \n", 159 | " \n", 160 | " \n", 161 | " \n", 162 | " \n", 163 | " \n", 164 | " \n", 165 | " \n", 166 | " \n", 167 | " \n", 168 | " \n", 169 | " \n", 170 | " \n", 171 | " \n", 172 | " \n", 173 | " \n", 174 | " \n", 175 | " \n", 176 | " \n", 177 | " \n", 178 | " \n", 179 | " \n", 180 | " \n", 181 | " \n", 182 | " \n", 183 | " \n", 184 | " \n", 185 | " \n", 186 | "
         x1        x2        x3
0  2.257272  0.958078  8.776842
1  2.531611  0.762464  8.561263
2  0.641547  0.255364  1.341902
3  3.153636  0.860973  9.322791
4  1.908691  0.449580  5.776675
\n", 187 | "
" 188 | ], 189 | "text/plain": [ 190 | " x1 x2 x3\n", 191 | "0 2.257272 0.958078 8.776842\n", 192 | "1 2.531611 0.762464 8.561263\n", 193 | "2 0.641547 0.255364 1.341902\n", 194 | "3 3.153636 0.860973 9.322791\n", 195 | "4 1.908691 0.449580 5.776675" 196 | ] 197 | }, 198 | "metadata": { 199 | "tags": [] 200 | }, 201 | "execution_count": 3 202 | } 203 | ] 204 | }, 205 | { 206 | "cell_type": "markdown", 207 | "metadata": { 208 | "id": "Z3Z3H7PldZ0I", 209 | "colab_type": "text" 210 | }, 211 | "source": [ 212 | "## 独立成分分析を実施" 213 | ] 214 | }, 215 | { 216 | "cell_type": "code", 217 | "metadata": { 218 | "id": "um30h2fLdZ0K", 219 | "colab_type": "code", 220 | "outputId": "3fc228bc-1ef1-4c88-f437-4ff435e2f7fa", 221 | "colab": { 222 | "base_uri": "https://localhost:8080/", 223 | "height": 67 224 | } 225 | }, 226 | "source": [ 227 | "# 独立成分分析はscikit-learnの関数を使用します\n", 228 | "from sklearn.decomposition import FastICA\n", 229 | "# https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.FastICA.html\n", 230 | "\n", 231 | "ica = FastICA(random_state=1234).fit(df)\n", 232 | "\n", 233 | "# ICAで求めた行列A\n", 234 | "A_ica = ica.mixing_\n", 235 | "\n", 236 | "# 行列Aの逆行列を求める\n", 237 | "A_ica_inv = np.linalg.pinv(A_ica)\n", 238 | "print(A_ica_inv)\n" 239 | ], 240 | "execution_count": 4, 241 | "outputs": [ 242 | { 243 | "output_type": "stream", 244 | "text": [ 245 | "[[-0.23203107 -0.4635971 0.1154553 ]\n", 246 | " [-0.02158245 0.12961253 0.00557934]\n", 247 | " [-0.11326384 0.40437635 -0.00563091]]\n" 248 | ], 249 | "name": "stdout" 250 | } 251 | ] 252 | }, 253 | { 254 | "cell_type": "markdown", 255 | "metadata": { 256 | "colab_type": "text", 257 | "id": "VWpDEqLNtGe1" 258 | }, 259 | "source": [ 260 | "## 行列A_invを求め、行の順番と大きさを調整\n", 261 | "\n", 262 | "プログラムの参考\n", 263 | "\n", 264 | "https://qiita.com/m__k/items/bd87c063a7496897ba7c" 265 | ] 266 | }, 267 | { 268 | "cell_type": "code", 269 | "metadata": { 270 | "id": "nmBp3lFddZ0e", 271 | "colab_type": "code", 272 | 
"outputId": "7967ae8a-b03f-4f4f-aef8-6ab7de221eeb", 273 | "colab": { 274 | "base_uri": "https://localhost:8080/", 275 | "height": 34 276 | } 277 | }, 278 | "source": [ 279 | "!pip install munkres\n", 280 | "from munkres import Munkres\n", 281 | "from copy import deepcopy\n" 282 | ], 283 | "execution_count": 5, 284 | "outputs": [ 285 | { 286 | "output_type": "stream", 287 | "text": [ 288 | "Requirement already satisfied: munkres in /usr/local/lib/python3.6/dist-packages (1.1.2)\n" 289 | ], 290 | "name": "stdout" 291 | } 292 | ] 293 | }, 294 | { 295 | "cell_type": "code", 296 | "metadata": { 297 | "id": "NUhUjh3YtVto", 298 | "colab_type": "code", 299 | "outputId": "d70aa90c-2aa9-4ac5-f58f-053df8759f11", 300 | "colab": { 301 | "base_uri": "https://localhost:8080/", 302 | "height": 67 303 | } 304 | }, 305 | "source": [ 306 | "# 実装の参考\n", 307 | "# [5] Qiita:LiNGAMモデルの推定方法について\n", 308 | "# https://qiita.com/m__k/items/bd87c063a7496897ba7c\n", 309 | "\n", 310 | "# ①「行の順番を変換」→対角成分の絶対値を最大にする\n", 311 | "# (元のA^-1の対角成分は必ず0ではないので)\n", 312 | "\n", 313 | "# 絶対値の逆数にして対角成分の和を最小にする問題に置き換える\n", 314 | "A_ica_inv_small = 1 / np.abs(A_ica_inv)\n", 315 | "\n", 316 | "# 対角成分の和を最小にする行の入れ替え順を求める\n", 317 | "m = Munkres() # ハンガリアン法\n", 318 | "ixs = np.vstack(m.compute(deepcopy(A_ica_inv_small)))\n", 319 | "\n", 320 | "# 求めた順番で変換\n", 321 | "ixs = ixs[np.argsort(ixs[:, 0]), :]\n", 322 | "ixs_perm = ixs[:, 1]\n", 323 | "A_ica_inv_perm = np.zeros_like(A_ica_inv)\n", 324 | "A_ica_inv_perm[ixs_perm] = A_ica_inv\n", 325 | "print(A_ica_inv_perm)\n" 326 | ], 327 | "execution_count": 6, 328 | "outputs": [ 329 | { 330 | "output_type": "stream", 331 | "text": [ 332 | "[[-0.11326384 0.40437635 -0.00563091]\n", 333 | " [-0.02158245 0.12961253 0.00557934]\n", 334 | " [-0.23203107 -0.4635971 0.1154553 ]]\n" 335 | ], 336 | "name": "stdout" 337 | } 338 | ] 339 | }, 340 | { 341 | "cell_type": "code", 342 | "metadata": { 343 | "id": "aVTjcyKs1aFR", 344 | "colab_type": "code", 345 | "colab": { 346 | "base_uri": 
"https://localhost:8080/", 347 | "height": 67 348 | }, 349 | "outputId": "c3a08db9-1fad-4485-f18e-c34f0d16cc3f" 350 | }, 351 | "source": [ 352 | "# 並び替わった順番\n", 353 | "print(ixs)" 354 | ], 355 | "execution_count": 7, 356 | "outputs": [ 357 | { 358 | "output_type": "stream", 359 | "text": [ 360 | "[[0 2]\n", 361 | " [1 1]\n", 362 | " [2 0]]\n" 363 | ], 364 | "name": "stdout" 365 | } 366 | ] 367 | }, 368 | { 369 | "cell_type": "code", 370 | "metadata": { 371 | "id": "uD47ajJBdZ1B", 372 | "colab_type": "code", 373 | "outputId": "82d6f933-1954-42da-a7f2-529910fc0fd3", 374 | "colab": { 375 | "base_uri": "https://localhost:8080/", 376 | "height": 67 377 | } 378 | }, 379 | "source": [ 380 | "# ②「行の大きさを調整」\n", 381 | "D = np.diag(A_ica_inv_perm)[:, np.newaxis] # D倍されているDを求める\n", 382 | "A_ica_inv_perm_D = A_ica_inv_perm / D\n", 383 | "print(A_ica_inv_perm_D)\n" 384 | ], 385 | "execution_count": 8, 386 | "outputs": [ 387 | { 388 | "output_type": "stream", 389 | "text": [ 390 | "[[ 1. -3.57021564 0.04971498]\n", 391 | " [-0.16651518 1. 0.0430463 ]\n", 392 | " [-2.00970483 -4.01538182 1. ]]\n" 393 | ], 394 | "name": "stdout" 395 | } 396 | ] 397 | }, 398 | { 399 | "cell_type": "code", 400 | "metadata": { 401 | "id": "Hfu7SLMwdZ1G", 402 | "colab_type": "code", 403 | "outputId": "9e1e43ef-8efd-4e03-8b82-ccd860c4dd20", 404 | "colab": { 405 | "base_uri": "https://localhost:8080/", 406 | "height": 67 407 | } 408 | }, 409 | "source": [ 410 | "# ③「B=I-A_inv」\n", 411 | "B_est = np.eye(3) - A_ica_inv_perm_D\n", 412 | "print(B_est)\n" 413 | ], 414 | "execution_count": 9, 415 | "outputs": [ 416 | { 417 | "output_type": "stream", 418 | "text": [ 419 | "[[ 0. 3.57021564 -0.04971498]\n", 420 | " [ 0.16651518 0. -0.0430463 ]\n", 421 | " [ 2.00970483 4.01538182 0. 
]]\n" 422 | ], 423 | "name": "stdout" 424 | } 425 | ] 426 | }, 427 | { 428 | "cell_type": "code", 429 | "metadata": { 430 | "id": "Vipp7leM5pRd", 431 | "colab_type": "code", 432 | "colab": { 433 | "base_uri": "https://localhost:8080/", 434 | "height": 67 435 | }, 436 | "outputId": "7b5c03cf-29d8-4e0d-e68a-de6552b4ece5" 437 | }, 438 | "source": [ 439 | "# ①上側成分の0になるはずの数(3×3であれば3個、4×4であれば6個と、対角成分の上側の要素数分)、絶対値が小さい成分を0にする\n", 440 | "# ②変数の順番を入れ替えて、下三角行列になるかを確かめる、\n", 441 | "# 実装の参考\n", 442 | "# [5] Qiita:LiNGAMモデルの推定方法について\n", 443 | "# https://qiita.com/m__k/items/bd87c063a7496897ba7c\n", 444 | "\n", 445 | "def _slttestperm(b_i):\n", 446 | "# b_iの行を並び替えて下三角行列にできるかどうかチェック\n", 447 | " n = b_i.shape[0]\n", 448 | " remnodes = np.arange(n)\n", 449 | " b_rem = deepcopy(b_i)\n", 450 | " p = list() \n", 451 | "\n", 452 | " for i in range(n):\n", 453 | " # 成分が全て0である行番号のリスト\n", 454 | " ixs = np.where(np.sum(np.abs(b_rem), axis=1) < 1e-12)[0]\n", 455 | "\n", 456 | " if len(ixs) == 0:\n", 457 | " return None\n", 458 | " else:\n", 459 | " ix = ixs[0]\n", 460 | " p.append(remnodes[ix])\n", 461 | "\n", 462 | " # 成分が全て0である行を削除\n", 463 | " remnodes = np.hstack((remnodes[:ix], remnodes[(ix + 1):]))\n", 464 | " ixs = np.hstack((np.arange(ix), np.arange(ix + 1, len(b_rem))))\n", 465 | " b_rem = b_rem[ixs, :]\n", 466 | " b_rem = b_rem[:, ixs]\n", 467 | "\n", 468 | " return np.array(p)\n", 469 | "\n", 470 | "b = B_est\n", 471 | "n = b.shape[0]\n", 472 | "assert(b.shape == (n, n))\n", 473 | "\n", 474 | "ixs = np.argsort(np.abs(b).ravel())\n", 475 | "\n", 476 | "for i in range(int(n * (n + 1) / 2) - 1, (n * n) - 1):\n", 477 | " b_i = deepcopy(b)\n", 478 | " b_i.ravel()[ixs[:i]] = 0\n", 479 | " ixs_perm = _slttestperm(b_i)\n", 480 | " if ixs_perm is not None:\n", 481 | " b_opt = deepcopy(b)\n", 482 | " b_opt = b_opt[ixs_perm, :]\n", 483 | " b_opt = b_opt[:, ixs_perm]\n", 484 | " break\n", 485 | "b_csl = np.tril(b_opt, -1)\n", 486 | "b_csl[ixs_perm, :] = deepcopy(b_csl)\n", 487 | "b_csl[:, 
ixs_perm] = deepcopy(b_csl)\n", 488 | "\n", 489 | "B_est1 = b_csl\n", 490 | "print(B_est1)\n", 491 | "\n" 492 | ], 493 | "execution_count": 10, 494 | "outputs": [ 495 | { 496 | "output_type": "stream", 497 | "text": [ 498 | "[[0. 3.57021564 0. ]\n", 499 | " [0. 0. 0. ]\n", 500 | " [2.00970483 4.01538182 0. ]]\n" 501 | ], 502 | "name": "stdout" 503 | } 504 | ] 505 | }, 506 | { 507 | "cell_type": "markdown", 508 | "metadata": { 509 | "id": "EKf9DwAK84hB", 510 | "colab_type": "text" 511 | }, 512 | "source": [ 513 | "## Bの非ゼロ要素を求め直す" 514 | ] 515 | }, 516 | { 517 | "cell_type": "code", 518 | "metadata": { 519 | "id": "QBdpU_dS88Lf", 520 | "colab_type": "code", 521 | "outputId": "176e10b7-a52d-4cf3-fd43-e1a9567b4472", 522 | "colab": { 523 | "base_uri": "https://localhost:8080/", 524 | "height": 50 525 | } 526 | }, 527 | "source": [ 528 | "# scikit-learnから線形回帰をimport\n", 529 | "from sklearn.linear_model import LinearRegression\n", 530 | "\n", 531 | "# 説明変数\n", 532 | "X1 = df[[\"x2\"]]\n", 533 | "X3 = df[[\"x1\", \"x2\"]]\n", 534 | "\n", 535 | "# 被説明変数(目的変数)\n", 536 | "# df[\"x1\"]\n", 537 | "# df[\"x3\"]\n", 538 | "\n", 539 | "# 回帰の実施\n", 540 | "reg1 = LinearRegression().fit(X1, df[\"x1\"])\n", 541 | "reg3 = LinearRegression().fit(X3, df[\"x3\"])\n", 542 | "\n", 543 | "# 回帰した結果の係数を出力\n", 544 | "print(\"係数:\", reg1.coef_)\n", 545 | "print(\"係数:\", reg3.coef_)\n" 546 | ], 547 | "execution_count": 11, 548 | "outputs": [ 549 | { 550 | "output_type": "stream", 551 | "text": [ 552 | "係数: [3.14642595]\n", 553 | "係数: [1.96164568 4.11256441]\n" 554 | ], 555 | "name": "stdout" 556 | } 557 | ] 558 | }, 559 | { 560 | "cell_type": "markdown", 561 | "metadata": { 562 | "id": "i_V6mNsXyXs2", 563 | "colab_type": "text" 564 | }, 565 | "source": [ 566 | "以上" 567 | ] 568 | } 569 | ] 570 | } -------------------------------------------------------------------------------- /7_2_bayesian_network_bic.ipynb: -------------------------------------------------------------------------------- 1 | { 2 
| "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "kernelspec": { 6 | "name": "python3", 7 | "display_name": "Python 3" 8 | }, 9 | "colab": { 10 | "name": "7_2_bayesian_network_bic.ipynb", 11 | "provenance": [], 12 | "collapsed_sections": [] 13 | } 14 | }, 15 | "cells": [ 16 | { 17 | "cell_type": "markdown", 18 | "metadata": { 19 | "id": "aoxI3DOK9vm2", 20 | "colab_type": "text" 21 | }, 22 | "source": [ 23 | "# 7.2 BICの計算\n", 24 | "\n", 25 | "本ファイルは、7.2節の実装です。\n", 26 | "\n", 27 | "データに対してBICの値を求めます。" 28 | ] 29 | }, 30 | { 31 | "cell_type": "markdown", 32 | "metadata": { 33 | "id": "2XdIDbdlejUk", 34 | "colab_type": "text" 35 | }, 36 | "source": [ 37 | "## プログラム実行前の設定など" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "metadata": { 43 | "id": "XZFKJwcu-_Oj", 44 | "colab_type": "code", 45 | "colab": {} 46 | }, 47 | "source": [ 48 | "# 乱数のシードを設定\n", 49 | "import random\n", 50 | "import numpy as np\n", 51 | "\n", 52 | "np.random.seed(1234)\n", 53 | "random.seed(1234)\n" 54 | ], 55 | "execution_count": 0, 56 | "outputs": [] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "metadata": { 61 | "id": "hx1idArc_F15", 62 | "colab_type": "code", 63 | "colab": {} 64 | }, 65 | "source": [ 66 | "# 使用するパッケージ(ライブラリと関数)を定義\n", 67 | "from numpy.random import *\n", 68 | "import pandas as pd\n" 69 | ], 70 | "execution_count": 0, 71 | "outputs": [] 72 | }, 73 | { 74 | "cell_type": "markdown", 75 | "metadata": { 76 | "id": "AWqP6yeQlI_t", 77 | "colab_type": "text" 78 | }, 79 | "source": [ 80 | "## データの作成" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "metadata": { 86 | "id": "DpnGB2KZ_L8x", 87 | "colab_type": "code", 88 | "outputId": "b2ca2d8d-76dc-48f1-8dea-ad30de30b06c", 89 | "colab": { 90 | "base_uri": "https://localhost:8080/", 91 | "height": 195 92 | } 93 | }, 94 | "source": [ 95 | "# データ数\n", 96 | "num_data = 10\n", 97 | "\n", 98 | "# x1:0か1の値をnum_data個生成、0の確率は0.6、1の確率は0.4\n", 99 | "x1 = np.random.choice([0, 1], num_data, p=[0.6, 0.4])\n", 100 | "\n", 101 | 
"# x2:0か1の値をnum_data個生成、0の確率は0.4、1の確率は0.6\n", 102 | "x2 = np.random.choice([0, 1], num_data, p=[0.4, 0.6])\n", 103 | "\n", 104 | "# 2変数で表にする\n", 105 | "df = pd.DataFrame({'x1': x1,\n", 106 | " 'x2': x2,\n", 107 | " })\n", 108 | "\n", 109 | "df.head() # 先頭を表示\n" 110 | ], 111 | "execution_count": 0, 112 | "outputs": [ 113 | { 114 | "output_type": "execute_result", 115 | "data": { 116 | "text/html": [ 117 | "
\n", 118 | "\n", 131 | "\n", 132 | " \n", 133 | " \n", 134 | " \n", 135 | " \n", 136 | " \n", 137 | " \n", 138 | " \n", 139 | " \n", 140 | " \n", 141 | " \n", 142 | " \n", 143 | " \n", 144 | " \n", 145 | " \n", 146 | " \n", 147 | " \n", 148 | " \n", 149 | " \n", 150 | " \n", 151 | " \n", 152 | " \n", 153 | " \n", 154 | " \n", 155 | " \n", 156 | " \n", 157 | " \n", 158 | " \n", 159 | " \n", 160 | " \n", 161 | " \n", 162 | " \n", 163 | " \n", 164 | " \n", 165 | " \n", 166 | "
   x1  x2
0   0   0
1   1   1
2   0   1
3   1   1
4   1   0
 \n", 167 | "
" 168 | ], 169 | "text/plain": [ 170 | " x1 x2\n", 171 | "0 0 0\n", 172 | "1 1 1\n", 173 | "2 0 1\n", 174 | "3 1 1\n", 175 | "4 1 0" 176 | ] 177 | }, 178 | "metadata": { 179 | "tags": [] 180 | }, 181 | "execution_count": 3 182 | } 183 | ] 184 | }, 185 | { 186 | "cell_type": "code", 187 | "metadata": { 188 | "id": "YvCfB7uRZvZI", 189 | "colab_type": "code", 190 | "outputId": "bad340c8-374a-42e7-bc43-1c9b22fb1251", 191 | "colab": { 192 | "base_uri": "https://localhost:8080/", 193 | "height": 343 194 | } 195 | }, 196 | "source": [ 197 | "# Variable x3: generate num_data values, each 0 or 1\n", 198 | "# When (x1, x2) = (0, 0), P(x3 = 0) = 0.2\n", 199 | "# When (x1, x2) = (0, 1), P(x3 = 0) = 0.3\n", 200 | "# When (x1, x2) = (1, 0), P(x3 = 0) = 0.4\n", 201 | "# When (x1, x2) = (1, 1), P(x3 = 0) = 0.1\n", 202 | "\n", 203 | "x3 = []\n", 204 | "for i in range(num_data):\n", 205 | "    if x1[i] == 0 and x2[i] == 0:\n", 206 | "        x3_value = np.random.choice([0, 1], 1, p=[0.2, 0.8])\n", 207 | "        x3.append(x3_value[0])  # x3_value is a length-1 array, so append its first element\n", 208 | "    elif x1[i] == 0 and x2[i] == 1:\n", 209 | "        x3_value = np.random.choice([0, 1], 1, p=[0.3, 0.7])\n", 210 | "        x3.append(x3_value[0])\n", 211 | "    elif x1[i] == 1 and x2[i] == 0:\n", 212 | "        x3_value = np.random.choice([0, 1], 1, p=[0.4, 0.6])\n", 213 | "        x3.append(x3_value[0])\n", 214 | "    elif x1[i] == 1 and x2[i] == 1:\n", 215 | "        x3_value = np.random.choice([0, 1], 1, p=[0.1, 0.9])\n", 216 | "        x3.append(x3_value[0])\n", 217 | "\n", 218 | "df[\"x3\"] = x3\n", 219 | "\n", 220 | "df  # display the table\n" 221 | ], 222 | "execution_count": 0, 223 | "outputs": [ 224 | { 225 | "output_type": "execute_result", 226 | "data": { 227 | "text/html": [ 228 | "
\n", 229 | "\n", 242 | "\n", 243 | " \n", 244 | " \n", 245 | " \n", 246 | " \n", 247 | " \n", 248 | " \n", 249 | " \n", 250 | " \n", 251 | " \n", 252 | " \n", 253 | " \n", 254 | " \n", 255 | " \n", 256 | " \n", 257 | " \n", 258 | " \n", 259 | " \n", 260 | " \n", 261 | " \n", 262 | " \n", 263 | " \n", 264 | " \n", 265 | " \n", 266 | " \n", 267 | " \n", 268 | " \n", 269 | " \n", 270 | " \n", 271 | " \n", 272 | " \n", 273 | " \n", 274 | " \n", 275 | " \n", 276 | " \n", 277 | " \n", 278 | " \n", 279 | " \n", 280 | " \n", 281 | " \n", 282 | " \n", 283 | " \n", 284 | " \n", 285 | " \n", 286 | " \n", 287 | " \n", 288 | " \n", 289 | " \n", 290 | " \n", 291 | " \n", 292 | " \n", 293 | " \n", 294 | " \n", 295 | " \n", 296 | " \n", 297 | " \n", 298 | " \n", 299 | " \n", 300 | " \n", 301 | " \n", 302 | " \n", 303 | " \n", 304 | " \n", 305 | " \n", 306 | " \n", 307 | " \n", 308 | " \n", 309 | " \n", 310 | " \n", 311 | " \n", 312 | " \n", 313 | "
   x1  x2  x3
0   0   0   1
1   1   1   1
2   0   1   0
3   1   1   1
4   1   0   1
5   0   1   1
6   0   1   1
7   1   0   1
8   1   1   1
9   1   1   1
\n", 314 | "
" 315 | ], 316 | "text/plain": [ 317 | " x1 x2 x3\n", 318 | "0 0 0 1\n", 319 | "1 1 1 1\n", 320 | "2 0 1 0\n", 321 | "3 1 1 1\n", 322 | "4 1 0 1\n", 323 | "5 0 1 1\n", 324 | "6 0 1 1\n", 325 | "7 1 0 1\n", 326 | "8 1 1 1\n", 327 | "9 1 1 1" 328 | ] 329 | }, 330 | "metadata": { 331 | "tags": [] 332 | }, 333 | "execution_count": 4 334 | } 335 | ] 336 | }, 337 | { 338 | "cell_type": "markdown", 339 | "metadata": { 340 | "id": "BHcdUlW9koTa", 341 | "colab_type": "text" 342 | }, 343 | "source": [ 344 | "## pgmpy(Python library for Probabilistic Graphical Models)によるBICの計算\n" 345 | ] 346 | }, 347 | { 348 | "cell_type": "code", 349 | "metadata": { 350 | "id": "25oDRf7qtNtF", 351 | "colab_type": "code", 352 | "outputId": "8d2c208a-2371-4a82-9415-59ee92f0fdca", 353 | "colab": { 354 | "base_uri": "https://localhost:8080/", 355 | "height": 101 356 | } 357 | }, 358 | "source": [ 359 | "!pip install pgmpy==0.1.9" 360 | ], 361 | "execution_count": 0, 362 | "outputs": [ 363 | { 364 | "output_type": "stream", 365 | "text": [ 366 | "Collecting pgmpy==0.1.9\n", 367 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/5a/b1/18dfdfcb10dcce71fd39f8c6801407e9aebd953939682558a5317e4a021c/pgmpy-0.1.9-py3-none-any.whl (331kB)\n", 368 | "\u001b[K |████████████████████████████████| 337kB 2.8MB/s \n", 369 | "\u001b[?25hInstalling collected packages: pgmpy\n", 370 | "Successfully installed pgmpy-0.1.9\n" 371 | ], 372 | "name": "stdout" 373 | } 374 | ] 375 | }, 376 | { 377 | "cell_type": "code", 378 | "metadata": { 379 | "id": "H1i9-YjMdDTh", 380 | "colab_type": "code", 381 | "colab": {} 382 | }, 383 | "source": [ 384 | "# 正解のDAGを与える\n", 385 | "from pgmpy.models import BayesianModel\n", 386 | "model = BayesianModel([('x1', 'x3'), ('x2', 'x3')]) # x1 -> x3 <- x2\n" 387 | ], 388 | "execution_count": 0, 389 | "outputs": [] 390 | }, 391 | { 392 | "cell_type": "code", 393 | "metadata": { 394 | "id": "WFKQb7XudDW3", 395 | "colab_type": "code", 396 | "outputId": 
"460d06f3-9e80-430e-e3c3-4e9b29066eaa", 397 | "colab": { 398 | "base_uri": "https://localhost:8080/", 399 | "height": 286 400 | } 401 | }, 402 | "source": [ 403 | "# 各データパターンの個数を表示する\n", 404 | "from pgmpy.estimators import ParameterEstimator\n", 405 | "pe = ParameterEstimator(model, df)\n", 406 | "print(\"\\n\", pe.state_counts('x1'))\n", 407 | "print(\"\\n\", pe.state_counts('x2'))\n", 408 | "print(\"\\n\", pe.state_counts('x3'))\n" 409 | ], 410 | "execution_count": 0, 411 | "outputs": [ 412 | { 413 | "output_type": "stream", 414 | "text": [ 415 | "/usr/local/lib/python3.6/dist-packages/statsmodels/tools/_testing.py:19: FutureWarning: pandas.util.testing is deprecated. Use the functions in the public API at pandas.testing instead.\n", 416 | " import pandas.util.testing as tm\n" 417 | ], 418 | "name": "stderr" 419 | }, 420 | { 421 | "output_type": "stream", 422 | "text": [ 423 | "\n", 424 | " x1\n", 425 | "0 4\n", 426 | "1 6\n", 427 | "\n", 428 | " x2\n", 429 | "0 3\n", 430 | "1 7\n", 431 | "\n", 432 | " x1 0 1 \n", 433 | "x2 0 1 0 1\n", 434 | "x3 \n", 435 | "0 0.0 1.0 0.0 0.0\n", 436 | "1 1.0 2.0 2.0 4.0\n" 437 | ], 438 | "name": "stdout" 439 | } 440 | ] 441 | }, 442 | { 443 | "cell_type": "code", 444 | "metadata": { 445 | "id": "ztZJyobWwalY", 446 | "colab_type": "code", 447 | "outputId": "11bbeb28-dc26-4a7b-cbd3-08f4cd7f75e9", 448 | "colab": { 449 | "base_uri": "https://localhost:8080/", 450 | "height": 336 451 | } 452 | }, 453 | "source": [ 454 | "# CPT(条件付き確率表)を推定する\n", 455 | "from pgmpy.estimators import BayesianEstimator\n", 456 | "\n", 457 | "estimator = BayesianEstimator(model, df)\n", 458 | "\n", 459 | "cpd_x1 = estimator.estimate_cpd(\n", 460 | " 'x1', prior_type=\"dirichlet\", pseudo_counts=[[0], [0]])\n", 461 | "cpd_x2 = estimator.estimate_cpd(\n", 462 | " 'x2', prior_type=\"dirichlet\", pseudo_counts=[[0], [0]])\n", 463 | "cpd_x3 = estimator.estimate_cpd('x3', prior_type=\"dirichlet\", pseudo_counts=[\n", 464 | " [0, 0, 0, 0], [0, 0, 0, 0]])\n", 465 | 
"# 注意:pseudo_countsはハイパーパラメータ0のディリクレ分布の設定を与えています。\n", 466 | "\n", 467 | "print(cpd_x1)\n", 468 | "print(cpd_x2)\n", 469 | "print(cpd_x3)\n" 470 | ], 471 | "execution_count": 0, 472 | "outputs": [ 473 | { 474 | "output_type": "stream", 475 | "text": [ 476 | "+-------+-----+\n", 477 | "| x1(0) | 0.4 |\n", 478 | "+-------+-----+\n", 479 | "| x1(1) | 0.6 |\n", 480 | "+-------+-----+\n", 481 | "+-------+-----+\n", 482 | "| x2(0) | 0.3 |\n", 483 | "+-------+-----+\n", 484 | "| x2(1) | 0.7 |\n", 485 | "+-------+-----+\n", 486 | "+-------+-------+--------------------+-------+-------+\n", 487 | "| x1 | x1(0) | x1(0) | x1(1) | x1(1) |\n", 488 | "+-------+-------+--------------------+-------+-------+\n", 489 | "| x2 | x2(0) | x2(1) | x2(0) | x2(1) |\n", 490 | "+-------+-------+--------------------+-------+-------+\n", 491 | "| x3(0) | 0.0 | 0.3333333333333333 | 0.0 | 0.0 |\n", 492 | "+-------+-------+--------------------+-------+-------+\n", 493 | "| x3(1) | 1.0 | 0.6666666666666666 | 1.0 | 1.0 |\n", 494 | "+-------+-------+--------------------+-------+-------+\n" 495 | ], 496 | "name": "stdout" 497 | } 498 | ] 499 | }, 500 | { 501 | "cell_type": "code", 502 | "metadata": { 503 | "id": "T8UqcSmXyX_4", 504 | "colab_type": "code", 505 | "outputId": "3207d204-6108-476f-c9a6-069a676aa1d3", 506 | "colab": { 507 | "base_uri": "https://localhost:8080/", 508 | "height": 34 509 | } 510 | }, 511 | "source": [ 512 | "# BICを求める\n", 513 | "from pgmpy.estimators import BicScore\n", 514 | "bic = BicScore(df)\n", 515 | "print(bic.score(model))\n" 516 | ], 517 | "execution_count": 0, 518 | "outputs": [ 519 | { 520 | "output_type": "stream", 521 | "text": [ 522 | "-21.65605747450808\n" 523 | ], 524 | "name": "stdout" 525 | } 526 | ] 527 | }, 528 | { 529 | "cell_type": "markdown", 530 | "metadata": { 531 | "id": "6Pvo1RIbEoyY", 532 | "colab_type": "text" 533 | }, 534 | "source": [ 535 | "## 異なるDAGでのBICの計算" 536 | ] 537 | }, 538 | { 539 | "cell_type": "code", 540 | "metadata": { 541 | "id": 
"y2ZRLS0fEtnc", 543 | "colab_type": "code", 544 | "outputId": "6e9cfc89-cd37-4b73-ee48-39c598f5f59f", 545 | "colab": { 546 | "base_uri": "https://localhost:8080/", 547 | "height": 34 548 | } 549 | }, 550 | "source": [ 551 | "# Provide a DAG that is not the true one\n", 552 | "from pgmpy.models import BayesianModel\n", 553 | "model = BayesianModel([('x2', 'x1'), ('x2', 'x3')]) # x1 <- x2 -> x3\n", 554 | "bic = BicScore(df)\n", 555 | "print(bic.score(model))\n" 556 | ], 557 | "execution_count": 0, 558 | "outputs": [ 559 | { 560 | "output_type": "stream", 561 | "text": [ 562 | "-21.425819218840655\n" 563 | ], 564 | "name": "stdout" 565 | } 566 | ] 567 | }, 568 | { 569 | "cell_type": "markdown", 570 | "metadata": { 571 | "colab_type": "text", 572 | "id": "I6P1x9vAdG3i" 573 | }, 574 | "source": [ 575 | "End" 576 | ] 577 | } 578 | ] 579 | } -------------------------------------------------------------------------------- /7_3_bayesian_network_independence_test.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "kernelspec": { 6 | "name": "python3", 7 | "display_name": "Python 3" 8 | }, 9 | "colab": { 10 | "name": "7_3_bayesian_network__independence_test.ipynb", 11 | "provenance": [], 12 | "collapsed_sections": [] 13 | } 14 | }, 15 | "cells": [ 16 | { 17 | "cell_type": "markdown", 18 | "metadata": { 19 | "id": "aoxI3DOK9vm2", 20 | "colab_type": "text" 21 | }, 22 | "source": [ 23 | "# 7.3 Independence testing\n", 24 | "\n", 25 | "This file is the implementation for Section 7.3.\n", 26 | "\n", 27 | "We run a chi-square test of independence on the data." 28 | ] 29 | }, 30 | { 31 | "cell_type": "markdown", 32 | "metadata": { 33 | "id": "2XdIDbdlejUk", 34 | "colab_type": "text" 35 | }, 36 | "source": [ 37 | "## Setup before running the program" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "metadata": { 43 | "id": "XZFKJwcu-_Oj", 44 | "colab_type": "code", 45 | "colab": {} 46 | }, 47 | "source": [ 48 | "# Set the random seeds\n", 49 | "import random\n", 50 | "import numpy as np\n", 51 | "\n", 
52 | "np.random.seed(1234)\n", 53 | "random.seed(1234)\n" 54 | ], 55 | "execution_count": 0, 56 | "outputs": [] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "metadata": { 61 | "id": "hx1idArc_F15", 62 | "colab_type": "code", 63 | "colab": {} 64 | }, 65 | "source": [ 66 | "# Import the packages (libraries and functions) used below\n", 67 | "from numpy.random import *\n", 68 | "import pandas as pd\n" 69 | ], 70 | "execution_count": 0, 71 | "outputs": [] 72 | }, 73 | { 74 | "cell_type": "markdown", 75 | "metadata": { 76 | "id": "AWqP6yeQlI_t", 77 | "colab_type": "text" 78 | }, 79 | "source": [ 80 | "## Creating the data" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "metadata": { 86 | "id": "DpnGB2KZ_L8x", 87 | "colab_type": "code", 88 | "outputId": "7a7e8d89-2383-4b04-d0d0-3a15be37316f", 89 | "colab": { 90 | "base_uri": "https://localhost:8080/", 91 | "height": 195 92 | } 93 | }, 94 | "source": [ 95 | "# Number of samples\n", 96 | "num_data = 100\n", 97 | "\n", 98 | "# x1: generate num_data values of 0 or 1, with P(0) = 0.6 and P(1) = 0.4\n", 99 | "x1 = np.random.choice([0, 1], num_data, p=[0.6, 0.4])\n", 100 | "\n", 101 | "# x2: generate num_data values of 0 or 1, with P(0) = 0.4 and P(1) = 0.6\n", 102 | "x2 = np.random.choice([0, 1], num_data, p=[0.4, 0.6])\n", 103 | "\n", 104 | "# Make x2 causally dependent on x1 (x2 can be 1 only when x1 is 1)\n", 105 | "x2 = x2*x1\n", 106 | "\n", 107 | "# Collect the two variables into a table\n", 108 | "df = pd.DataFrame({'x1': x1,\n", 109 | "                   'x2': x2,\n", 110 | "                   })\n", 111 | "\n", 112 | "df.head()  # show the first rows\n" 113 | ], 114 | "execution_count": 3, 115 | "outputs": [ 116 | { 117 | "output_type": "execute_result", 118 | "data": { 119 | "text/html": [ 120 | "
\n", 121 | "\n", 134 | "\n", 135 | " \n", 136 | " \n", 137 | " \n", 138 | " \n", 139 | " \n", 140 | " \n", 141 | " \n", 142 | " \n", 143 | " \n", 144 | " \n", 145 | " \n", 146 | " \n", 147 | " \n", 148 | " \n", 149 | " \n", 150 | " \n", 151 | " \n", 152 | " \n", 153 | " \n", 154 | " \n", 155 | " \n", 156 | " \n", 157 | " \n", 158 | " \n", 159 | " \n", 160 | " \n", 161 | " \n", 162 | " \n", 163 | " \n", 164 | " \n", 165 | " \n", 166 | " \n", 167 | " \n", 168 | " \n", 169 | "
   x1  x2
0   0   0
1   1   1
2   0   0
3   1   1
4   1   1
\n", 170 | "
" 171 | ], 172 | "text/plain": [ 173 | " x1 x2\n", 174 | "0 0 0\n", 175 | "1 1 1\n", 176 | "2 0 0\n", 177 | "3 1 1\n", 178 | "4 1 1" 179 | ] 180 | }, 181 | "metadata": { 182 | "tags": [] 183 | }, 184 | "execution_count": 3 185 | } 186 | ] 187 | }, 188 | { 189 | "cell_type": "code", 190 | "metadata": { 191 | "id": "YvCfB7uRZvZI", 192 | "colab_type": "code", 193 | "outputId": "0852e5c6-bdcc-4aae-b0d9-722dbfd24514", 194 | "colab": { 195 | "base_uri": "https://localhost:8080/", 196 | "height": 84 197 | } 198 | }, 199 | "source": [ 200 | "# 各カウント\n", 201 | "print(((df[\"x1\"] == 0) & (df[\"x2\"] == 0)).sum())\n", 202 | "print(((df[\"x1\"] == 1) & (df[\"x2\"] == 0)).sum())\n", 203 | "print(((df[\"x1\"] == 0) & (df[\"x2\"] == 1)).sum())\n", 204 | "print(((df[\"x1\"] == 1) & (df[\"x2\"] == 1)).sum())\n" 205 | ], 206 | "execution_count": 4, 207 | "outputs": [ 208 | { 209 | "output_type": "stream", 210 | "text": [ 211 | "58\n", 212 | "9\n", 213 | "0\n", 214 | "33\n" 215 | ], 216 | "name": "stdout" 217 | } 218 | ] 219 | }, 220 | { 221 | "cell_type": "markdown", 222 | "metadata": { 223 | "id": "BHcdUlW9koTa", 224 | "colab_type": "text" 225 | }, 226 | "source": [ 227 | "## pgmpy(Python library for Probabilistic Graphical Models)による独立性の検定\n" 228 | ] 229 | }, 230 | { 231 | "cell_type": "code", 232 | "metadata": { 233 | "id": "25oDRf7qtNtF", 234 | "colab_type": "code", 235 | "outputId": "64ca1b9e-846d-4040-a37b-acd185f2d66a", 236 | "colab": { 237 | "base_uri": "https://localhost:8080/", 238 | "height": 101 239 | } 240 | }, 241 | "source": [ 242 | "!pip install pgmpy==0.1.9" 243 | ], 244 | "execution_count": 5, 245 | "outputs": [ 246 | { 247 | "output_type": "stream", 248 | "text": [ 249 | "Collecting pgmpy==0.1.9\n", 250 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/5a/b1/18dfdfcb10dcce71fd39f8c6801407e9aebd953939682558a5317e4a021c/pgmpy-0.1.9-py3-none-any.whl (331kB)\n", 251 | "\r\u001b[K |█ | 10kB 19.2MB/s eta 0:00:01\r\u001b[K |██ | 20kB 1.7MB/s eta 
0:00:01\r\u001b[K |████████████████████████████████| 337kB 2.8MB/s \n", 252 | "\u001b[?25hInstalling collected packages: pgmpy\n", 253 | "Successfully installed pgmpy-0.1.9\n" 254 | ], 255 | "name": "stdout" 256 | } 257 | ] 258 | }, 259 | { 260 | "cell_type": "code", 261 | 
"metadata": { 262 | "id": "K8rFdErsnUhd", 263 | "colab_type": "code", 264 | "outputId": "72888e1d-07dd-4a0f-886d-469df5010b54", 265 | "colab": { 266 | "base_uri": "https://localhost:8080/", 267 | "height": 84 268 | } 269 | }, 270 | "source": [ 271 | "# データ数\n", 272 | "num_data = 100\n", 273 | "\n", 274 | "# x1:0か1の値をnum_data個生成、0の確率は0.6、1の確率は0.4\n", 275 | "x1 = np.random.choice([0, 1], num_data, p=[0.6, 0.4])\n", 276 | "\n", 277 | "# x2:0か1の値をnum_data個生成、0の確率は0.4、1の確率は0.6\n", 278 | "x2 = np.random.choice([0, 1], num_data, p=[0.4, 0.6])\n", 279 | "\n", 280 | "# 2変数で表にする\n", 281 | "df2 = pd.DataFrame({'x1': x1,\n", 282 | " 'x2': x2,\n", 283 | " })\n", 284 | "\n", 285 | "# 各カウント\n", 286 | "print(((df2[\"x1\"] == 0) & (df2[\"x2\"] == 0)).sum())\n", 287 | "print(((df2[\"x1\"] == 1) & (df2[\"x2\"] == 0)).sum())\n", 288 | "print(((df2[\"x1\"] == 0) & (df2[\"x2\"] == 1)).sum())\n", 289 | "print(((df2[\"x1\"] == 1) & (df2[\"x2\"] == 1)).sum())\n" 290 | ], 291 | "execution_count": 6, 292 | "outputs": [ 293 | { 294 | "output_type": "stream", 295 | "text": [ 296 | "20\n", 297 | "15\n", 298 | "35\n", 299 | "30\n" 300 | ], 301 | "name": "stdout" 302 | } 303 | ] 304 | }, 305 | { 306 | "cell_type": "code", 307 | "metadata": { 308 | "id": "nJJBRMKEnbjO", 309 | "colab_type": "code", 310 | "outputId": "fb9adaf1-b048-46a6-a96c-4cc5987a4055", 311 | "colab": { 312 | "base_uri": "https://localhost:8080/", 313 | "height": 84 314 | } 315 | }, 316 | "source": [ 317 | "from pgmpy.estimators import ConstraintBasedEstimator\n", 318 | "\n", 319 | "est = ConstraintBasedEstimator(df2)\n", 320 | "print(est.test_conditional_independence(\n", 321 | " 'x1', 'x2', method=\"chi_square\", tol=0.05)) # 独立\n", 322 | "\n", 323 | "# 最初の例の場合\n", 324 | "est = ConstraintBasedEstimator(df)\n", 325 | "print(est.test_conditional_independence(\n", 326 | " 'x1', 'x2', method=\"chi_square\", tol=0.05)) # 独立でない\n" 327 | ], 328 | "execution_count": 7, 329 | "outputs": [ 330 | { 331 | "output_type": "stream", 332 | 
"text": [ 333 | "/usr/local/lib/python3.6/dist-packages/statsmodels/tools/_testing.py:19: FutureWarning: pandas.util.testing is deprecated. Use the functions in the public API at pandas.testing instead.\n", 334 | " import pandas.util.testing as tm\n" 335 | ], 336 | "name": "stderr" 337 | }, 338 | { 339 | "output_type": "stream", 340 | "text": [ 341 | "True\n", 342 | "False\n" 343 | ], 344 | "name": "stdout" 345 | } 346 | ] 347 | }, 348 | { 349 | "cell_type": "markdown", 350 | "metadata": { 351 | "colab_type": "text", 352 | "id": "I6P1x9vAdG3i" 353 | }, 354 | "source": [ 355 | "以上" 356 | ] 357 | } 358 | ] 359 | } -------------------------------------------------------------------------------- /8_3_5_deeplearning_gan_sam.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "accelerator": "GPU", 6 | "colab": { 7 | "name": "8_3_5_deeplearning_gan_sam.ipynb", 8 | "provenance": [], 9 | "collapsed_sections": [] 10 | }, 11 | "kernelspec": { 12 | "display_name": "Python 3", 13 | "language": "python", 14 | "name": "python3" 15 | }, 16 | "language_info": { 17 | "codemirror_mode": { 18 | "name": "ipython", 19 | "version": 3 20 | }, 21 | "file_extension": ".py", 22 | "mimetype": "text/x-python", 23 | "name": "python", 24 | "nbconvert_exporter": "python", 25 | "pygments_lexer": "ipython3", 26 | "version": "3.6.5" 27 | } 28 | }, 29 | "cells": [ 30 | { 31 | "cell_type": "markdown", 32 | "metadata": { 33 | "colab_type": "text", 34 | "id": "aoxI3DOK9vm2" 35 | }, 36 | "source": [ 37 | "# 8.3 SAM(Structural Agnostic Model)による因果探索の実装\n", 38 | "\n", 39 | "本ファイルは、8.3節の実装です。\n", 40 | "\n", 41 | "7.5節と同じく、「上司向け:部下とのキャリア面談のポイント研修」の疑似データを作成して、SAMによる因果探索を実施します。" 42 | ] 43 | }, 44 | { 45 | "cell_type": "markdown", 46 | "metadata": { 47 | "colab_type": "text", 48 | "id": "2XdIDbdlejUk" 49 | }, 50 | "source": [ 51 | "## プログラム実行前の設定など" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | 
"metadata": { 57 | "id": "_QZagoIYv44f", 58 | "colab_type": "code", 59 | "outputId": "1c133069-f1aa-4f79-ed78-ab14a644889b", 60 | "colab": { 61 | "base_uri": "https://localhost:8080/", 62 | "height": 122 63 | } 64 | }, 65 | "source": [ 66 | "# PyTorchのバージョンを下げる\n", 67 | "!pip install torch==1.4.0+cu92 torchvision==0.5.0+cu92 -f https://download.pytorch.org/whl/torch_stable.html" 68 | ], 69 | "execution_count": 0, 70 | "outputs": [ 71 | { 72 | "output_type": "stream", 73 | "text": [ 74 | "Looking in links: https://download.pytorch.org/whl/torch_stable.html\n", 75 | "Requirement already satisfied: torch==1.4.0+cu92 in /usr/local/lib/python3.6/dist-packages (1.4.0+cu92)\n", 76 | "Requirement already satisfied: torchvision==0.5.0+cu92 in /usr/local/lib/python3.6/dist-packages (0.5.0+cu92)\n", 77 | "Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from torchvision==0.5.0+cu92) (1.18.4)\n", 78 | "Requirement already satisfied: pillow>=4.1.1 in /usr/local/lib/python3.6/dist-packages (from torchvision==0.5.0+cu92) (7.0.0)\n", 79 | "Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from torchvision==0.5.0+cu92) (1.12.0)\n" 80 | ], 81 | "name": "stdout" 82 | } 83 | ] 84 | }, 85 | { 86 | "cell_type": "code", 87 | "metadata": { 88 | "id": "iqh9FyP-wHGa", 89 | "colab_type": "code", 90 | "outputId": "d70b471a-b77c-47b2-af12-990ad611b00d", 91 | "colab": { 92 | "base_uri": "https://localhost:8080/", 93 | "height": 34 94 | } 95 | }, 96 | "source": [ 97 | "import torch \n", 98 | "print(torch.__version__) # 元は1.5.0+cu101、versionを1.4に下げた" 99 | ], 100 | "execution_count": 0, 101 | "outputs": [ 102 | { 103 | "output_type": "stream", 104 | "text": [ 105 | "1.4.0+cu92\n" 106 | ], 107 | "name": "stdout" 108 | } 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "metadata": { 114 | "colab_type": "code", 115 | "id": "XZFKJwcu-_Oj", 116 | "colab": {} 117 | }, 118 | "source": [ 119 | "# 乱数のシードを設定\n", 120 | "import random\n", 121 
| "import numpy as np\n", 122 | "\n", 123 | "np.random.seed(1234)\n", 124 | "random.seed(1234)\n" 125 | ], 126 | "execution_count": 0, 127 | "outputs": [] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "metadata": { 132 | "colab_type": "code", 133 | "id": "hx1idArc_F15", 134 | "colab": {} 135 | }, 136 | "source": [ 137 | "# Import the packages (libraries and functions) used below\n", 138 | "# For drawing standard normal samples (randn)\n", 139 | "from numpy.random import *\n", 140 | "\n", 141 | "# For plotting\n", 142 | "import matplotlib.pyplot as plt\n", 143 | "\n", 144 | "# Other\n", 145 | "import pandas as pd\n", 146 | "\n", 147 | "# Import the sigmoid function\n", 148 | "from scipy.special import expit\n" 149 | ], 150 | "execution_count": 0, 151 | "outputs": [] 152 | }, 153 | { 154 | "cell_type": "markdown", 155 | "metadata": { 156 | "colab_type": "text", 157 | "id": "AWqP6yeQlI_t" 158 | }, 159 | "source": [ 160 | "## Generating the data" 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "metadata": { 166 | "colab_type": "code", 167 | "id": "QBsAEiQ77xww", 168 | "colab": {} 169 | }, 170 | "source": [ 171 | "# Number of samples\n", 172 | "num_data = 2000\n", 173 | "\n", 174 | "# The manager's enthusiasm for developing subordinates\n", 175 | "x = np.random.uniform(low=-1, high=1, size=num_data) # uniform random numbers in [-1, 1)\n", 176 | "\n", 177 | "# Probability that the manager attended the \"For managers: key points for career interviews with subordinates\" training\n", 178 | "e_z = randn(num_data) # generate the noise\n", 179 | "z_prob = expit(-5.0*x+5*e_z)\n", 180 | "Z = np.array([])\n", 181 | "\n", 182 | "# Sample whether each manager attended the training (Z = 1) or not (Z = 0)\n", 183 | "for i in range(num_data):\n", 184 | " Z_i = np.random.choice(2, size=1, p=[1-z_prob[i], z_prob[i]])[0]\n", 185 | " Z = np.append(Z, Z_i)\n", 186 | "\n", 187 | "# Nonlinear treatment effect: changes stepwise with the manager's enthusiasm x\n", 188 | "t = np.zeros(num_data)\n", 189 | "for i in range(num_data):\n", 190 | " if x[i] < 0:\n", 191 | " t[i] = 0.5\n", 192 | " elif x[i] >= 0 and x[i] < 0.5:\n", 193 | " t[i] = 0.7\n", 194 | " elif x[i] >= 0.5:\n", 195 | " t[i] = 1.0\n", 196 | "\n", 197 | "e_y = randn(num_data)\n", 198 | "Y = 2.0 + t*Z + 0.3*x + 0.1*e_y \n", 199 | "\n", 200 | "\n", 201 | "# 
Generate the additional data introduced in this chapter\n", 202 | "\n", 203 | "# Y2: the subordinate's satisfaction with team members, on a 5-point scale from 1 to 5\n", 204 | "Y2 = np.random.choice([1.0, 2.0, 3.0, 4.0, 5.0],\n", 205 | " num_data, p=[0.1, 0.2, 0.3, 0.2, 0.2])\n", 206 | "\n", 207 | "# Y3: the subordinate's job satisfaction\n", 208 | "e_y3 = randn(num_data)\n", 209 | "Y3 = 3*Y + Y2 + e_y3\n", 210 | "\n", 211 | "# Y4: the subordinate's job performance\n", 212 | "e_y4 = randn(num_data)\n", 213 | "Y4 = 3*Y3 + 2*e_y4 + 5\n" 214 | ], 215 | "execution_count": 0, 216 | "outputs": [] 217 | }, 218 | { 219 | "cell_type": "markdown", 220 | "metadata": { 221 | "colab_type": "text", 222 | "id": "BHcdUlW9koTa" 223 | }, 224 | "source": [ 225 | "## Combine the data into a table, normalize it, and visualize it" 226 | ] 227 | }, 228 | { 229 | "cell_type": "code", 230 | "metadata": { 231 | "colab_type": "code", 232 | "id": "1EMwdGIIIPrK", 233 | "outputId": "9530a268-91e5-4d68-b8ab-c2d1f62dddcd", 234 | "colab": { 235 | "base_uri": "https://localhost:8080/", 236 | "height": 195 237 | } 238 | }, 239 | "source": [ 240 | "df = pd.DataFrame({'x': x,\n", 241 | " 'Z': Z,\n", 242 | " 't': t,\n", 243 | " 'Y': Y,\n", 244 | " 'Y2': Y2,\n", 245 | " 'Y3': Y3,\n", 246 | " 'Y4': Y4,\n", 247 | " })\n", 248 | "\n", 249 | "del df[\"t\"] # drop t because it cannot be observed\n", 250 | "\n", 251 | "df.head() # show the first rows\n" 252 | ], 253 | "execution_count": 0, 254 | "outputs": [ 255 | { 256 | "output_type": "execute_result", 257 | "data": { 258 | "text/html": [ 259 | "
<div>\n", 273 | "<table border=\"1\" class=\"dataframe\">\n", 274 | " <thead>\n", 275 | " <tr style=\"text-align: right;\"><th></th><th>x</th><th>Z</th><th>Y</th><th>Y2</th><th>Y3</th><th>Y4</th></tr>\n", 284 | " </thead>\n", 285 | " <tbody>\n", 286 | 
" <tr><th>0</th><td>-0.616961</td><td>1.0</td><td>2.286924</td><td>2.0</td><td>8.732544</td><td>30.326507</td></tr>\n", 295 | 
" <tr><th>1</th><td>0.244218</td><td>1.0</td><td>2.864636</td><td>3.0</td><td>10.743959</td><td>37.149014</td></tr>\n", 304 | 
" <tr><th>2</th><td>-0.124545</td><td>0.0</td><td>2.198515</td><td>3.0</td><td>10.569163</td><td>38.481185</td></tr>\n", 313 | 
" <tr><th>3</th><td>0.570717</td><td>1.0</td><td>3.230572</td><td>3.0</td><td>12.312526</td><td>43.709229</td></tr>\n", 322 | 
" <tr><th>4</th><td>0.559952</td><td>0.0</td><td>2.459267</td><td>5.0</td><td>12.418739</td><td>40.833938</td></tr>\n", 331 | 
" </tbody>\n", 332 | "</table>\n", 333 | "</div>
" 334 | ], 335 | "text/plain": [ 336 | " x Z Y Y2 Y3 Y4\n", 337 | "0 -0.616961 1.0 2.286924 2.0 8.732544 30.326507\n", 338 | "1 0.244218 1.0 2.864636 3.0 10.743959 37.149014\n", 339 | "2 -0.124545 0.0 2.198515 3.0 10.569163 38.481185\n", 340 | "3 0.570717 1.0 3.230572 3.0 12.312526 43.709229\n", 341 | "4 0.559952 0.0 2.459267 5.0 12.418739 40.833938" 342 | ] 343 | }, 344 | "metadata": { 345 | "tags": [] 346 | }, 347 | "execution_count": 6 348 | } 349 | ] 350 | }, 351 | { 352 | "cell_type": "markdown", 353 | "metadata": { 354 | "colab_type": "text", 355 | "id": "1TPIeXDg6QDG" 356 | }, 357 | "source": [ 358 | "## Running inference with SAM" 359 | ] 360 | }, 361 | { 362 | "cell_type": "code", 363 | "metadata": { 364 | "colab_type": "code", 365 | "id": "edNNPSLY6u6d", 366 | "outputId": "36f912b0-1dbb-4516-8ea5-d5af39cce918", 367 | "colab": { 368 | "base_uri": "https://localhost:8080/", 369 | "height": 386 370 | } 371 | }, 372 | "source": [ 373 | "!pip install cdt==0.5.18" 374 | ], 375 | "execution_count": 0, 376 | "outputs": [ 377 | { 378 | "output_type": "stream", 379 | "text": [ 380 | "Requirement already satisfied: cdt==0.5.18 in /usr/local/lib/python3.6/dist-packages (0.5.18)\n", 381 | "Requirement already satisfied: GPUtil in /usr/local/lib/python3.6/dist-packages (from cdt==0.5.18) (1.4.0)\n", 382 | "Requirement already satisfied: statsmodels in /usr/local/lib/python3.6/dist-packages (from cdt==0.5.18) (0.10.2)\n", 383 | "Requirement already satisfied: tqdm in /usr/local/lib/python3.6/dist-packages (from cdt==0.5.18) (4.41.1)\n", 384 | "Requirement already satisfied: requests in /usr/local/lib/python3.6/dist-packages (from cdt==0.5.18) (2.23.0)\n", 385 | "Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from cdt==0.5.18) (1.18.4)\n", 386 | "Requirement already satisfied: joblib in /usr/local/lib/python3.6/dist-packages (from cdt==0.5.18) (0.15.1)\n", 387 | "Requirement already satisfied: pandas in /usr/local/lib/python3.6/dist-packages (from 
cdt==0.5.18) (1.0.4)\n", 388 | "Requirement already satisfied: scipy in /usr/local/lib/python3.6/dist-packages (from cdt==0.5.18) (1.4.1)\n", 389 | "Requirement already satisfied: networkx in /usr/local/lib/python3.6/dist-packages (from cdt==0.5.18) (2.4)\n", 390 | "Requirement already satisfied: scikit-learn in /usr/local/lib/python3.6/dist-packages (from cdt==0.5.18) (0.22.2.post1)\n", 391 | "Requirement already satisfied: skrebate in /usr/local/lib/python3.6/dist-packages (from cdt==0.5.18) (0.6)\n", 392 | "Requirement already satisfied: patsy>=0.4.0 in /usr/local/lib/python3.6/dist-packages (from statsmodels->cdt==0.5.18) (0.5.1)\n", 393 | "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests->cdt==0.5.18) (2020.4.5.1)\n", 394 | "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests->cdt==0.5.18) (2.9)\n", 395 | "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests->cdt==0.5.18) (1.24.3)\n", 396 | "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->cdt==0.5.18) (3.0.4)\n", 397 | "Requirement already satisfied: python-dateutil>=2.6.1 in /usr/local/lib/python3.6/dist-packages (from pandas->cdt==0.5.18) (2.8.1)\n", 398 | "Requirement already satisfied: pytz>=2017.2 in /usr/local/lib/python3.6/dist-packages (from pandas->cdt==0.5.18) (2018.9)\n", 399 | "Requirement already satisfied: decorator>=4.3.0 in /usr/local/lib/python3.6/dist-packages (from networkx->cdt==0.5.18) (4.4.2)\n", 400 | "Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from patsy>=0.4.0->statsmodels->cdt==0.5.18) (1.12.0)\n" 401 | ], 402 | "name": "stdout" 403 | } 404 | ] 405 | }, 406 | { 407 | "cell_type": "markdown", 408 | "metadata": { 409 | "colab_type": "text", 410 | "id": "ihTvgRcv1E8s" 411 | }, 412 | "source": [ 413 | "### 
Implementing the SAM discriminator D" 414 | ] 415 | }, 416 | { 417 | "cell_type": "code", 418 | "metadata": { 419 | "colab_type": "code", 420 | "id": "sJQ2_9LY8MQ8", 421 | "colab": {} 422 | }, 423 | "source": [ 424 | "# Import what we need from PyTorch\n", 425 | "import torch\n", 426 | "import torch.nn as nn\n", 427 | "\n", 428 | "\n", 429 | "class SAMDiscriminator(nn.Module):\n", 430 | " \"\"\"Neural network for the SAM discriminator\n", 431 | " \"\"\"\n", 432 | "\n", 433 | " def __init__(self, nfeatures, dnh, hlayers):\n", 434 | " super(SAMDiscriminator, self).__init__()\n", 435 | "\n", 436 | " # ----------------------------------\n", 437 | " # Build the network\n", 438 | " # ----------------------------------\n", 439 | " self.nfeatures = nfeatures # number of input variables\n", 440 | "\n", 441 | " layers = []\n", 442 | " layers.append(nn.Linear(nfeatures, dnh))\n", 443 | " layers.append(nn.BatchNorm1d(dnh))\n", 444 | " layers.append(nn.LeakyReLU(.2))\n", 445 | "\n", 446 | " for i in range(hlayers-1):\n", 447 | " layers.append(nn.Linear(dnh, dnh))\n", 448 | " layers.append(nn.BatchNorm1d(dnh))\n", 449 | " layers.append(nn.LeakyReLU(.2))\n", 450 | "\n", 451 | " layers.append(nn.Linear(dnh, 1)) # final output\n", 452 | "\n", 453 | " self.layers = nn.Sequential(*layers)\n", 454 | "\n", 455 | " # ----------------------------------\n", 456 | " # Prepare the mask (a matrix with 1 on the diagonal and 0 elsewhere)\n", 457 | " # ----------------------------------\n", 458 | " mask = torch.eye(nfeatures, nfeatures) # identity matrix of size (number of variables) x (number of variables)\n", 459 | " self.register_buffer(\"mask\", mask.unsqueeze(0)) # store the identity-matrix mask as a buffer\n", 460 | "\n", 461 | " # Note: register_buffer is the PyTorch method for registering a tensor that is not a model parameter but is used later in forward\n", 462 | " # Once registered, it can be accessed as self.<name>\n", 463 | " # https://pytorch.org/docs/stable/nn.html?highlight=register_buffer#torch.nn.Module.register_buffer\n", 464 | "\n", 465 | " def forward(self, input, obs_data=None):\n", 466 | " \"\"\" Forward pass\n", 467 | " Args:\n", 468 | " input (torch.Size([number of samples, number of observed variables])): observed data or generated data\n", 469 | " obs_data (torch.Size([number of samples, number of observed variables])): observed data\n", 
470 | " Returns:\n", 471 | " torch.Tensor (or a list of torch.Tensor when obs_data is given): the judgment of whether the data is observed or generated\n", 472 | " \"\"\"\n", 473 | "\n", 474 | " if obs_data is not None:\n", 475 | " # Case: generated data is fed to the discriminator\n", 476 | " return [self.layers(i) for i in torch.unbind(obs_data.unsqueeze(1) * (1 - self.mask)\n", 477 | " + input.unsqueeze(1) * self.mask, 1)]\n", 478 | " # Only the diagonal entries come from the generated data; all others are observed data\n", 479 | " # For each variable, mix its generated value with the observed values of the other variables, and feed in one generated variable at a time\n", 480 | " # torch.unbind(x, 1) splits tensor x along dimension 1 into a tuple\n", 481 | " # With a minibatch of 2000 and 6 observed variables,\n", 482 | " # [2000,6]→[2000,6,6]→([2000,6], [2000,6], [2000,6], [2000,6], [2000,6], [2000,6])→([2000,1], [2000,1], [2000,1], [2000,1], [2000,1], [2000,1])\n", 483 | " # the return value is [torch.Size([2000, 1]), torch.Size([2000, 1]), torch.Size([2000, 1]), torch.Size([2000, 1]), torch.Size([2000, 1]), torch.Size([2000, 1])]\n", 484 | "\n", 485 | " # Note: the discriminator never judges a sample built from all generated variables at once.\n", 486 | " # That is, one generated variable is combined with the observed values of the others, and the discriminator judges whether the result is observed or generated\n", 487 | "\n", 488 | " else:\n", 489 | " # Case: observed data is fed to the discriminator\n", 490 | "\n", 491 | " return self.layers(input)\n", 492 | " # the return value has shape torch.Size([2000, 1])\n", 493 | "\n", 494 | "\n", 495 | " def reset_parameters(self):\n", 496 | " \"\"\"Initialize the weights of discriminator D\"\"\"\n", 497 | " for layer in self.layers:\n", 498 | " if hasattr(layer, 'reset_parameters'):\n", 499 | " layer.reset_parameters()\n" 500 | ], 501 | "execution_count": 0, 502 | "outputs": [] 503 | }, 504 | { 505 | "cell_type": "markdown", 506 | "metadata": { 507 | "colab_type": "text", 508 | "id": "yLyjZsSc1S2i" 509 | }, 510 | "source": [ 511 | "### Implementing the SAM generator G" 512 | ] 513 | }, 514 | { 515 | "cell_type": "code", 516 | "metadata": { 517 | "colab_type": "code", 518 | "id": "pBUh-fKh8X-E", 519 | "outputId": "b30b006a-9c56-4e93-9185-1d4dc3d2abfb", 520 | "colab": { 521 | "base_uri": "https://localhost:8080/", 522 | "height": 72 523 | } 524 | }, 525 | "source": [ 526 | "from cdt.utils.torch import ChannelBatchNorm1d, MatrixSampler, Linear3D\n", 527 | "\n", 528 | "\n", 529 | "class 
SAMGenerator(nn.Module):\n", 530 | " \"\"\"Neural network for the SAM generator\n", 531 | " \"\"\"\n", 532 | "\n", 533 | " def __init__(self, data_shape, nh):\n", 534 | " \"\"\"Initialization\"\"\"\n", 535 | " super(SAMGenerator, self).__init__()\n", 536 | "\n", 537 | " # ----------------------------------\n", 538 | " # Create skeleton, a mask with 0 on the diagonal and 1 elsewhere\n", 539 | " # (the last row is all ones)\n", 540 | " # ----------------------------------\n", 541 | " nb_vars = data_shape[1] # number of variables\n", 542 | " skeleton = 1 - torch.eye(nb_vars + 1, nb_vars)\n", 543 | "\n", 544 | " self.register_buffer('skeleton', skeleton)\n", 545 | "\n", 546 | " # Note: register_buffer is the PyTorch method for registering a tensor that is not a model parameter but is used later in forward\n", 547 | " # Once registered, it can be accessed as self.<name>\n", 548 | " # https://pytorch.org/docs/stable/nn.html?highlight=register_buffer#torch.nn.Module.register_buffer\n", 549 | "\n", 550 | " # ----------------------------------\n", 551 | " # Build the network\n", 552 | " # ----------------------------------\n", 553 | " # Input layer (fully connected layer in the SAM form) \n", 554 | " self.input_layer = Linear3D(\n", 555 | " (nb_vars, nb_vars + 1, nh)) # nh is the number of hidden-layer neurons\n", 556 | " # https://github.com/FenTechSolutions/CausalDiscoveryToolbox/blob/32200779ab9b63762be3a24a2147cff09ba2bb72/cdt/utils/torch.py#L289\n", 557 | "\n", 558 | " # Hidden layer\n", 559 | " layers = []\n", 560 | " # Module that flattens 2D to 1D and applies batch normalization\n", 561 | " layers.append(ChannelBatchNorm1d(nb_vars, nh))\n", 562 | " layers.append(nn.Tanh())\n", 563 | " self.layers = nn.Sequential(*layers)\n", 564 | "\n", 565 | " # ChannelBatchNorm1d\n", 566 | " # https://github.com/FenTechSolutions/CausalDiscoveryToolbox/blob/32200779ab9b63762be3a24a2147cff09ba2bb72/cdt/utils/torch.py#L130\n", 567 | "\n", 568 | " # Output layer (again a fully connected layer in the SAM form)\n", 569 | " self.output_layer = Linear3D((nb_vars, nh, 1))\n", 570 | "\n", 571 | " def forward(self, data, noise, adj_matrix, drawn_neurons=None):\n", 572 | " \"\"\" Forward pass\n", 573 | " Args:\n", 574 | " data (torch.Tensor): observed data\n", 575 | " noise (torch.Tensor): noise used to generate data\n", 576 | " 
adj_matrix (torch.Tensor): causal structure matrix M, which encodes the causal relationships\n", 577 | " drawn_neurons (torch.Tensor): complexity matrix Z, which controls the complexity of Linear3D\n", 578 | " Returns:\n", 579 | " torch.Tensor: generated data\n", 580 | " \"\"\"\n", 581 | "\n", 582 | " # Input layer\n", 583 | " x = self.input_layer(data, noise, adj_matrix *\n", 584 | " self.skeleton) # Linear3D\n", 585 | "\n", 586 | " # Hidden layer (batch normalization and Tanh)\n", 587 | " x = self.layers(x)\n", 588 | "\n", 589 | " # Output layer\n", 590 | " output = self.output_layer(\n", 591 | " x, noise=None, adj_matrix=drawn_neurons) # Linear3D\n", 592 | "\n", 593 | " return output.squeeze(2)\n", 594 | "\n", 595 | " def reset_parameters(self):\n", 596 | " \"\"\"Initialize the weights\"\"\"\n", 597 | "\n", 598 | " self.input_layer.reset_parameters()\n", 599 | " self.output_layer.reset_parameters()\n", 600 | "\n", 601 | " for layer in self.layers:\n", 602 | " if hasattr(layer, 'reset_parameters'):\n", 603 | " layer.reset_parameters()\n" 604 | ], 605 | "execution_count": 0, 606 | "outputs": [ 607 | { 608 | "output_type": "stream", 609 | "text": [ 610 | "Detecting 1 CUDA device(s).\n", 611 | "sklearn.externals.joblib is deprecated in 0.21 and will be removed in 0.23. Please import this functionality directly from joblib, which can be installed with: pip install joblib. If this warning is raised when loading pickled models, you may need to re-serialize those models with scikit-learn 0.21+.\n" 612 | ], 613 | "name": "stderr" 614 | } 615 | ] 616 | }, 617 | { 618 | "cell_type": "markdown", 619 | "metadata": { 620 | "colab_type": "text", 621 | "id": "2MubteRua0mj" 622 | }, 623 | "source": [ 624 | "### SAM loss function" 625 | ] 626 | }, 627 | { 628 | "cell_type": "code", 629 | "metadata": { 630 | "colab_type": "code", 631 | "id": "Hy2GqNNdapc6", 632 | "colab": {} 633 | }, 634 | "source": [ 635 | "# Loss term added so that the causal structure matrix M, which describes the network, becomes a DAG (directed acyclic graph)\n", 636 | "\n", 637 | "def notears_constr(adj_m, max_pow=None):\n", 638 | " \"\"\"No Tears constraint for binary adjacency matrices. 
\n", 639 | " Args:\n", 640 | " adj_m (array-like): Adjacency matrix of the graph\n", 641 | " max_pow (int): maximum value to which the infinite sum is to be computed.\n", 642 | " defaults to the shape of the adjacency_matrix\n", 643 | " Returns:\n", 644 | " np.ndarray or torch.Tensor: Scalar value of the loss with the type\n", 645 | " depending on the input.\n", 646 | " Reference: https://github.com/FenTechSolutions/CausalDiscoveryToolbox/blob/32200779ab9b63762be3a24a2147cff09ba2bb72/cdt/utils/loss.py#L215\n", 647 | " \"\"\"\n", 648 | " m_exp = [adj_m]\n", 649 | " if max_pow is None:\n", 650 | " max_pow = adj_m.shape[1]\n", 651 | " while(m_exp[-1].sum() > 0 and len(m_exp) < max_pow):\n", 652 | " m_exp.append(m_exp[-1] @ adj_m/len(m_exp))\n", 653 | "\n", 654 | " return sum([i.diag().sum() for idx, i in enumerate(m_exp)])\n", 655 | " " 656 | ], 657 | "execution_count": 0, 658 | "outputs": [] 659 | }, 660 | { 661 | "cell_type": "markdown", 662 | "metadata": { 663 | "colab_type": "text", 664 | "id": "d01nY6IKKmXe" 665 | }, 666 | "source": [ 667 | "### Function that trains SAM" 668 | ] 669 | }, 670 | { 671 | "cell_type": "code", 672 | "metadata": { 673 | "colab_type": "code", 674 | "id": "LdgNruwmJkxj", 675 | "colab": {} 676 | }, 677 | "source": [ 678 | "from sklearn.preprocessing import scale\n", 679 | "from torch import optim\n", 680 | "from torch.utils.data import DataLoader\n", 681 | "from tqdm import tqdm\n", 682 | "\n", 683 | "\n", 684 | "def run_SAM(in_data, lr_gen, lr_disc, lambda1, lambda2, hlayers, nh, dnh, train_epochs, test_epochs, device):\n", 685 | " '''Run SAM training'''\n", 686 | "\n", 687 | " # ---------------------------------------------------\n", 688 | " # Preprocess the input data\n", 689 | " # ---------------------------------------------------\n", 690 | " list_nodes = list(in_data.columns) # list of the input data's column names\n", 691 | " data = scale(in_data[list_nodes].values) # standardize the input data\n", 692 | " nb_var = len(list_nodes) # number of input variables = d\n", 693 | " data = data.astype('float32') # cast the input data to float32\n", 
694 | " data = torch.from_numpy(data).to(device) # convert the input data to a PyTorch tensor\n", 695 | " rows, cols = data.size() # rows is the number of samples, cols the number of variables\n", 696 | "\n", 697 | " # ---------------------------------------------------\n", 698 | " # Create the DataLoader (the batch size is the full dataset)\n", 699 | " # ---------------------------------------------------\n", 700 | " batch_size = rows # each batch uses all of the input data\n", 701 | " data_iterator = DataLoader(data, batch_size=batch_size,\n", 702 | " shuffle=True, drop_last=True)\n", 703 | " # Note: drop_last discards the leftover samples at the end when the data does not divide evenly into batches of batch_size\n", 704 | "\n", 705 | " # ---------------------------------------------------\n", 706 | " # [Generator] Build the networks and initialize their parameters\n", 707 | " # cols: number of input variables; nh: number of hidden neurons; hlayers: number of hidden layers\n", 708 | " # neuron_sampler is the network that learns the variables z of the functional gates\n", 709 | " # graph_sampler is the network that learns the variables a of the structural gates\n", 710 | " # ---------------------------------------------------\n", 711 | " sam = SAMGenerator((batch_size, cols), nh).to(device) # generator G\n", 712 | " graph_sampler = MatrixSampler(nb_var, mask=None, gumbel=False).to(\n", 713 | " device) # network that produces the causal structure matrix M\n", 714 | " neuron_sampler = MatrixSampler((nh, nb_var), mask=False, gumbel=True).to(\n", 715 | " device) # network that produces the complexity matrix Z\n", 716 | "\n", 717 | " # Note: MatrixSampler is a neural network that uses Gumbel-Softmax to output 0 or 1\n", 718 | " # We use MatrixSampler, the module implemented by the SAM authors\n", 719 | " # https://github.com/FenTechSolutions/CausalDiscoveryToolbox/blob/32200779ab9b63762be3a24a2147cff09ba2bb72/cdt/utils/torch.py#L212\n", 720 | "\n", 721 | " # Initialize the weights\n", 722 | " sam.reset_parameters()\n", 723 | " graph_sampler.weights.data.fill_(2)\n", 724 | "\n", 725 | " # ---------------------------------------------------\n", 726 | " # [Discriminator] Build the network and initialize its parameters\n", 727 | " # cols: number of input variables; dnh: number of hidden neurons; hlayers: number of hidden layers\n", 728 | " # ---------------------------------------------------\n", 729 | " discriminator = SAMDiscriminator(cols, dnh, hlayers).to(device)\n", 730 | " 
discriminator.reset_parameters() # initialize the weights\n", 731 | "\n", 732 | " # ---------------------------------------------------\n", 733 | " # Optimizer setup\n", 734 | " # ---------------------------------------------------\n", 735 | " # Generator\n", 736 | "\n", 737 | " g_optimizer = optim.Adam(sam.parameters(), lr=lr_gen)\n", 738 | " graph_optimizer = optim.Adam(graph_sampler.parameters(), lr=lr_gen)\n", 739 | " neuron_optimizer = optim.Adam(neuron_sampler.parameters(), lr=lr_gen)\n", 740 | "\n", 741 | " # Discriminator\n", 742 | " d_optimizer = optim.Adam(discriminator.parameters(), lr=lr_disc)\n", 743 | "\n", 744 | " # Loss function\n", 745 | " criterion = nn.BCEWithLogitsLoss()\n", 746 | " # nn.BCEWithLogitsLoss() combines a sigmoid (logistic) layer with binary cross-entropy\n", 747 | " # https://pytorch.org/docs/stable/nn.html#bcewithlogitsloss\n", 748 | "\n", 749 | " # Parameters for the DAG constraint in the loss function\n", 750 | " dagstart = 0.5\n", 751 | " dagpenalization_increase = 0.001*10\n", 752 | "\n", 753 | " # ---------------------------------------------------\n", 754 | " # Prepare the variables used in the forward pass and the loss computation\n", 755 | " # ---------------------------------------------------\n", 756 | " _true = torch.ones(1).to(device)\n", 757 | " _false = torch.zeros(1).to(device)\n", 758 | "\n", 759 | " noise = torch.randn(batch_size, nb_var).to(device) # noise used by generator G\n", 760 | " noise_row = torch.ones(1, nb_var).to(device)\n", 761 | "\n", 762 | " output = torch.zeros(nb_var, nb_var).to(device) # accumulator for the estimated adjacency matrix\n", 763 | " output_loss = torch.zeros(1, 1).to(device)\n", 764 | "\n", 765 | " # ---------------------------------------------------\n", 766 | " # Train the networks by running the forward pass\n", 767 | " # ---------------------------------------------------\n", 768 | " pbar = tqdm(range(train_epochs + test_epochs)) # progress bar\n", 769 | "\n", 770 | " for epoch in pbar:\n", 771 | " for i_batch, batch in enumerate(data_iterator):\n", 772 | "\n", 773 | " # Reset the gradients\n", 774 | " g_optimizer.zero_grad()\n", 775 | " graph_optimizer.zero_grad()\n", 776 | " 
neuron_optimizer.zero_grad()\n", 777 | " d_optimizer.zero_grad()\n", 778 | "\n", 779 | " # Sample the causal structure matrix M (drawn_graph) and the complexity matrix Z (drawn_neurons) from the MatrixSamplers\n", 780 | " drawn_graph = graph_sampler()\n", 781 | " drawn_neurons = neuron_sampler()\n", 782 | " # drawn_graph has size torch.Size([nb_var, nb_var]); its entries are 0 or 1\n", 783 | " # drawn_neurons has size torch.Size([nh, nb_var]); its entries are 0 or 1\n", 784 | "\n", 785 | " # Resample the noise and generate synthetic data with generator G\n", 786 | " noise.normal_()\n", 787 | " generated_variables = sam(data=batch, noise=noise,\n", 788 | " adj_matrix=torch.cat(\n", 789 | " [drawn_graph, noise_row], 0),\n", 790 | " drawn_neurons=drawn_neurons)\n", 791 | "\n", 792 | " # Judge with discriminator D\n", 793 | " # Returns a list with one entry per observed variable, each of size torch.Size([number of samples, 1])\n", 794 | " disc_vars_d = discriminator(generated_variables.detach(), batch)\n", 795 | " # Returns a list with one entry per observed variable, each of size torch.Size([number of samples, 1])\n", 796 | " disc_vars_g = discriminator(generated_variables, batch)\n", 797 | " true_vars_disc = discriminator(batch) # size torch.Size([number of samples, 1])\n", 798 | "\n", 799 | " # Compute the losses (DCGAN-style)\n", 800 | " disc_loss = sum([criterion(gen, _false.expand_as(gen)) for gen in disc_vars_d]) / nb_var \\\n", 801 | " + criterion(true_vars_disc, _true.expand_as(true_vars_disc))\n", 802 | "\n", 803 | " gen_loss = sum([criterion(gen,\n", 804 | " _true.expand_as(gen))\n", 805 | " for gen in disc_vars_g])\n", 806 | "\n", 807 | " # Loss computation (the original f-GAN formulation from the SAM paper)\n", 808 | " #disc_loss = sum([torch.mean(torch.exp(gen - 1)) for gen in disc_vars_d]) / nb_var - torch.mean(true_vars_disc)\n", 809 | " #gen_loss = -sum([torch.mean(torch.exp(gen - 1)) for gen in disc_vars_g])\n", 810 | "\n", 811 | " # Backpropagate and update the parameters of discriminator D\n", 812 | " if epoch < train_epochs:\n", 813 | " disc_loss.backward()\n", 814 | " d_optimizer.step()\n", 815 | "\n", 816 | " # Remaining terms of the generator G loss (matrix complexity and the NO TEARS DAG constraint)\n", 817 | " struc_loss = lambda1 / batch_size*drawn_graph.sum() # loss on M\n", 818 | " func_loss = lambda2 / batch_size*drawn_neurons.sum() # loss on Z\n", 819 | "\n", 820 | 
" regul_loss = struc_loss + func_loss\n", 821 | "\n", 822 | " if epoch <= train_epochs * dagstart:\n", 823 | " # Before the threshold epoch, do not apply the NO TEARS constraint that pushes M toward a DAG\n", 824 | " loss = gen_loss + regul_loss\n", 825 | "\n", 826 | " else:\n", 827 | " # After the threshold epoch, apply the NO TEARS constraint so that M becomes a DAG\n", 828 | " filters = graph_sampler.get_proba() # get the entries of M (as probabilities of being 1, not as 0/1 values)\n", 829 | " dag_constraint = notears_constr(filters*filters) # compute the NO TEARS term\n", 830 | "\n", 831 | " # Strengthen the DAG regularization linearly over the epochs\n", 832 | " loss = gen_loss + regul_loss + \\\n", 833 | " ((epoch - train_epochs * dagstart) *\n", 834 | " dagpenalization_increase) * dag_constraint\n", 835 | "\n", 836 | " if epoch >= train_epochs:\n", 837 | " # In a test epoch, accumulate the results\n", 838 | " output.add_(filters.data)\n", 839 | " output_loss.add_(gen_loss.data)\n", 840 | " else:\n", 841 | " # In a training epoch, backpropagate and update generator G\n", 842 | " # retain_graph=True lets the three step() calls below run\n", 843 | " loss.backward(retain_graph=True)\n", 844 | " g_optimizer.step()\n", 845 | " graph_optimizer.step()\n", 846 | " neuron_optimizer.step()\n", 847 | "\n", 848 | " # Show progress\n", 849 | " if epoch % 50 == 0:\n", 850 | " pbar.set_postfix(gen=gen_loss.item()/cols,\n", 851 | " disc=disc_loss.item(),\n", 852 | " regul_loss=regul_loss.item(),\n", 853 | " tot=loss.item())\n", 854 | "\n", 855 | " return output.cpu().numpy()/test_epochs, output_loss.cpu().numpy()/test_epochs/cols # return M and the loss\n" 856 | ], 857 | "execution_count": 0, 858 | "outputs": [] 859 | }, 860 | { 861 | "cell_type": "markdown", 862 | "metadata": { 863 | "colab_type": "text", 864 | "id": "S5SXuXOCUgmg" 865 | }, 866 | "source": [ 867 | "### Check that a GPU is available\n", 868 | "\n", 869 | "From the menu at the top of the screen, choose Runtime > Change runtime type to open the Notebook settings.\n", 870 | "\n", 871 | "Select GPU as the hardware accelerator and click Save." 872 | ] 873 | }, 874 | { 875 | "cell_type": "code", 876 | "metadata": { 877 | "colab_type": "code", 878 | "id": "ClTdYzxzXsL2", 879 | "outputId": "854e7cdb-d51e-4cb4-fbc6-9830415f4c44", 880 | "colab": { 881 | "base_uri": 
"https://localhost:8080/", 882 | "height": 34 883 | } 884 | }, 885 | "source": [ 886 | "# Check whether a GPU is available: True or False\n", 887 | "torch.cuda.is_available()\n" 888 | ], 889 | "execution_count": 0, 890 | "outputs": [ 891 | { 892 | "output_type": "execute_result", 893 | "data": { 894 | "text/plain": [ 895 | "True" 896 | ] 897 | }, 898 | "metadata": { 899 | "tags": [] 900 | }, 901 | "execution_count": 12 902 | } 903 | ] 904 | }, 905 | { 906 | "cell_type": "markdown", 907 | "metadata": { 908 | "colab_type": "text", 909 | "id": "R-FzZ-W3Xseu" 910 | }, 911 | "source": [ 912 | "### Train SAM" 913 | ] 914 | }, 915 | { 916 | "cell_type": "code", 917 | "metadata": { 918 | "colab_type": "code", 919 | "id": "xfqAztolY1fo", 920 | "outputId": "8489d950-76d7-46fb-c951-6381c4f871d2", 921 | "colab": { 922 | "base_uri": "https://localhost:8080/", 923 | "height": 826 924 | } 925 | }, 926 | "source": [ 927 | "# Print NumPy values with 2 decimal places\n", 928 | "np.set_printoptions(precision=2, floatmode='fixed', suppress=True)\n", 929 | "\n", 930 | "# Lists that store the causal discovery results\n", 931 | "m_list = []\n", 932 | "loss_list = []\n", 933 | "\n", 934 | "for i in range(5):\n", 935 | " m, loss = run_SAM(in_data=df, lr_gen=0.01*0.5,\n", 936 | " lr_disc=0.01*0.5*2,\n", 937 | " #lambda1=0.01, lambda2=1e-05,\n", 938 | " lambda1=5.0*20, lambda2=0.005*20,\n", 939 | " hlayers=2,\n", 940 | " nh=200, dnh=200,\n", 941 | " train_epochs=10000,\n", 942 | " test_epochs=1000,\n", 943 | " device='cuda:0')\n", 944 | "\n", 945 | " print(loss)\n", 946 | " print(m)\n", 947 | "\n", 948 | " m_list.append(m)\n", 949 | " loss_list.append(loss)\n", 950 | "\n", 951 | "# Network structure (average over the 5 runs)\n", 952 | "print(sum(m_list) / len(m_list))\n", 953 | "\n", 954 | "# Ideally, m would look like this\n", 955 | "# x Z Y Y2 Y3 Y4\n", 956 | "# x 0 1 1 0 0 0\n", 957 | "# Z 0 0 1 0 0 0\n", 958 | "# Y 0 0 0 0 1 0\n", 959 | "# Y2 0 0 0 0 1 0\n", 960 | "# Y3 0 0 0 0 0 1\n", 961 | "# Y4 0 0 0 0 0 0\n" 962 | ], 963 | "execution_count": 0, 964 | "outputs": [ 965 | { 966 | "output_type": "stream", 967 | 
"text": [ 968 | "100%|██████████| 11000/11000 [05:20<00:00, 34.29it/s, disc=0.259, gen=5.63, regul_loss=0.564, tot=42.9]\n", 969 | " 0%| | 4/11000 [00:00<05:14, 34.97it/s, disc=1.43, gen=0.626, regul_loss=1.48, tot=5.23]" 970 | ], 971 | "name": "stderr" 972 | }, 973 | { 974 | "output_type": "stream", 975 | "text": [ 976 | "[[7.23]]\n", 977 | "[[0.00 0.11 0.96 0.00 0.01 0.00]\n", 978 | " [0.37 0.00 0.96 0.00 0.81 0.00]\n", 979 | " [0.00 0.03 0.00 0.99 1.00 0.66]\n", 980 | " [0.02 0.00 0.00 0.00 0.07 0.00]\n", 981 | " [0.02 0.00 0.02 1.00 0.00 0.98]\n", 982 | " [0.00 0.00 0.04 0.59 0.25 0.00]]\n" 983 | ], 984 | "name": "stdout" 985 | }, 986 | { 987 | "output_type": "stream", 988 | "text": [ 989 | "100%|██████████| 11000/11000 [05:22<00:00, 34.15it/s, disc=0.301, gen=5.6, regul_loss=0.515, tot=40.2]\n", 990 | " 0%| | 3/11000 [00:00<06:53, 26.59it/s, disc=1.46, gen=0.8, regul_loss=1.38, tot=6.18]" 991 | ], 992 | "name": "stderr" 993 | }, 994 | { 995 | "output_type": "stream", 996 | "text": [ 997 | "[[7.37]]\n", 998 | "[[0.00 1.00 0.99 0.00 0.38 0.14]\n", 999 | " [0.05 0.00 0.98 0.00 0.19 0.94]\n", 1000 | " [0.03 0.10 0.00 1.00 0.24 0.03]\n", 1001 | " [0.00 0.00 0.00 0.00 0.10 0.03]\n", 1002 | " [0.05 0.00 0.09 0.98 0.00 0.23]\n", 1003 | " [0.03 0.01 0.33 0.04 0.66 0.00]]\n" 1004 | ], 1005 | "name": "stdout" 1006 | }, 1007 | { 1008 | "output_type": "stream", 1009 | "text": [ 1010 | "100%|██████████| 11000/11000 [05:21<00:00, 34.18it/s, disc=0.666, gen=6.01, regul_loss=0.412, tot=41.8]\n", 1011 | " 0%| | 4/11000 [00:00<04:54, 37.32it/s, disc=1.46, gen=0.887, regul_loss=1.48, tot=6.8]" 1012 | ], 1013 | "name": "stderr" 1014 | }, 1015 | { 1016 | "output_type": "stream", 1017 | "text": [ 1018 | "[[4.51]]\n", 1019 | "[[0.00 0.96 0.96 0.00 0.33 0.02]\n", 1020 | " [0.05 0.00 0.14 0.00 0.00 0.35]\n", 1021 | " [0.00 0.94 0.00 0.99 0.99 0.98]\n", 1022 | " [0.02 0.00 0.00 0.00 0.10 0.00]\n", 1023 | " [0.00 0.00 0.02 0.99 0.00 0.97]\n", 1024 | " [0.01 0.00 0.01 0.00 0.12 0.00]]\n" 
1025 | ], 1026 | "name": "stdout" 1027 | }, 1028 | { 1029 | "output_type": "stream", 1030 | "text": [ 1031 | "100%|██████████| 11000/11000 [05:23<00:00, 34.01it/s, disc=0.409, gen=5.87, regul_loss=0.365, tot=39.8]\n", 1032 | " 0%| | 4/11000 [00:00<04:48, 38.08it/s, disc=1.41, gen=0.776, regul_loss=1.53, tot=6.19]" 1033 | ], 1034 | "name": "stderr" 1035 | }, 1036 | { 1037 | "output_type": "stream", 1038 | "text": [ 1039 | "[[5.35]]\n", 1040 | "[[0.00 1.00 0.15 0.02 0.00 0.00]\n", 1041 | " [0.02 0.00 1.00 0.00 0.00 0.00]\n", 1042 | " [0.01 0.07 0.00 1.00 1.00 0.02]\n", 1043 | " [0.04 0.00 0.00 0.00 0.09 0.00]\n", 1044 | " [0.00 0.00 0.01 1.00 0.00 0.99]\n", 1045 | " [0.09 0.03 0.04 0.01 0.14 0.00]]\n" 1046 | ], 1047 | "name": "stdout" 1048 | }, 1049 | { 1050 | "output_type": "stream", 1051 | "text": [ 1052 | "100%|██████████| 11000/11000 [05:23<00:00, 34.01it/s, disc=0.597, gen=4.92, regul_loss=0.413, tot=32.8]" 1053 | ], 1054 | "name": "stderr" 1055 | }, 1056 | { 1057 | "output_type": "stream", 1058 | "text": [ 1059 | "[[4.98]]\n", 1060 | "[[0.00 0.97 0.01 0.00 0.00 0.00]\n", 1061 | " [0.06 0.00 0.06 0.00 0.00 0.00]\n", 1062 | " [0.73 0.99 0.00 0.99 1.00 0.03]\n", 1063 | " [0.92 0.00 0.00 0.00 0.08 0.00]\n", 1064 | " [0.18 0.00 0.00 0.98 0.00 1.00]\n", 1065 | " [0.19 0.00 0.00 0.00 0.09 0.00]]\n", 1066 | "[[0.00 0.81 0.62 0.01 0.14 0.03]\n", 1067 | " [0.11 0.00 0.63 0.00 0.20 0.26]\n", 1068 | " [0.16 0.42 0.00 1.00 0.85 0.34]\n", 1069 | " [0.20 0.00 0.00 0.00 0.09 0.01]\n", 1070 | " [0.05 0.00 0.03 0.99 0.00 0.84]\n", 1071 | " [0.06 0.01 0.08 0.13 0.25 0.00]]\n" 1072 | ], 1073 | "name": "stdout" 1074 | }, 1075 | { 1076 | "output_type": "stream", 1077 | "text": [ 1078 | "\n" 1079 | ], 1080 | "name": "stderr" 1081 | } 1082 | ] 1083 | }, 1084 | { 1085 | "cell_type": "markdown", 1086 | "metadata": { 1087 | "colab_type": "text", 1088 | "id": "MGNG7pzi8LI6" 1089 | }, 1090 | "source": [ 1091 | "End" 1092 | ] 1093 | }, 1094 | { 1095 | "cell_type": "code", 1096 | "metadata": 
{ 1097 | "id": "S9LudNsLxfkd", 1098 | "colab_type": "code", 1099 | "colab": {} 1100 | }, 1101 | "source": [ 1102 | "" 1103 | ], 1104 | "execution_count": 0, 1105 | "outputs": [] 1106 | } 1107 | ] 1108 | } -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Yutaro Ogawa 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## つくりながら学ぶ! Python による因果分析 ~因果推論・因果探索の実践入門 (Learn by Building! Causal Analysis with Python: A Practical Introduction to Causal Inference and Causal Discovery) 2 | 3 |
4 | 5 |
6 | 7 |
8 | 9 | This is the support repository for the book 10 | 11 | [「つくりながら学ぶ! Python による因果分析 ~因果推論・因果探索の実践入門」 (Yutaro Ogawa, Mynavi Publishing, 2020-06-30)](https://www.amazon.co.jp/dp/4839973571/). 12 | 13 |
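14 | 15 | ### 1. What this book covers 16 | 17 | The following article (in Japanese) gives an overview of the book. 18 | 19 | [「Python による因果推論と因果探索(初心者の方向け)」 (Causal Inference and Causal Discovery with Python, for Beginners)](https://qiita.com/sugulu/items/2cffb239b44853b07f70) 20 | 21 |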
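22 | 23 | **Table of contents** 24 | 25 | - Chapter 1: Understanding the difference between correlation and causation 26 | - Chapter 2: Getting to know the types of causal effects 27 | - Chapter 3: Understanding graph representations and the backdoor criterion 28 | - Chapter 4: Implementing causal inference 29 | - 4-1 Implementing causal inference with regression analysis 30 | - 4-2 Implementing inverse probability of treatment weighting (IPTW) with propensity scores 31 | - 4-3 Implementing causal inference with the Doubly Robust (DR) method 32 | - Chapter 5: Causal inference with machine learning 33 | - 5-1 How classification and regression with random forests work 34 | - 5-2 Implementing Meta-Learners (T-Learner, S-Learner, X-Learner) 35 | - 5-3 Implementing Doubly Robust Learning 36 | - Chapter 6: Implementing LiNGAM 37 | - 6-1 What is LiNGAM (Linear Non-Gaussian Acyclic Model)? 38 | - 6-2 What is independent component analysis? 39 | - 6-3 Implementing causal discovery with LiNGAM 40 | - Chapter 7: Implementing Bayesian networks 41 | - 7-1 What is a Bayesian network? 42 | - 7-2 How to measure a network's goodness of fit 43 | - 7-3 Testing independence between variables 44 | - 7-4 Three types of search methods for Bayesian networks 45 | - 7-5 Implementing Bayesian network search with the PC algorithm 46 | - Chapter 8: Causal discovery with deep learning 47 | - 8-1 The relationship between causal discovery and GANs (Generative Adversarial Networks) 48 | - 8-2 Overview of SAM (Structural Agnostic Model) 49 | - 8-3 Implementing SAM's discriminator D and generator G 50 | - 8-4 SAM's loss functions explained and causal discovery implemented 51 | - 8-5 Running causal discovery on a GPU in Google Colaboratory 52 | 53 |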
54 | 55 | ### 2. Questions and corrections are tracked in Issues 56 | 57 | Questions and corrections are tracked in the Issues of this GitHub repository. 58 | 59 | If anything is unclear, please check there. 60 | 61 | https://github.com/YutaroOgawa/causal_book/issues 62 | 63 |
64 | 65 | ### 3. Errata 66 | 67 | The list of typographical errors in the book is available below. 68 | We sincerely apologize for them. 69 | 70 | [Errata list](https://github.com/YutaroOgawa/causal_book/labels/%E8%AA%A4%E6%A4%8D) 71 | -------------------------------------------------------------------------------- /etc/book.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YutaroOgawa/causal_book/bde8891125b8c85bb8cf521c8cba0ec08226b2b9/etc/book.jpg --------------------------------------------------------------------------------