├── 1_3_pseudo_correlation.ipynb
├── 4_1_reggresion_adjustment.ipynb
├── 4_2_iptw.ipynb
├── 4_3_dr.ipynb
├── 5_1_randomforest.ipynb
├── 5_2_meta_learners.ipynb
├── 5_2_meta_learners_issue18.ipynb
├── 5_2_meta_learners_issue18_issue36.ipynb
├── 5_3_doubly_robust_learning.ipynb
├── 5_3_doubly_robust_learning_issue18.ipynb
├── 6_3_lingam.ipynb
├── 7_2_bayesian_network_bic.ipynb
├── 7_3_bayesian_network_independence_test.ipynb
├── 7_5_bayesian_network_pc_algorithm.ipynb
├── 7_5_bayesian_network_pc_algorithm_20220421.ipynb
├── 7_5_bayesian_network_pc_algorithm_210410.ipynb
├── 8_3_5_deeplearning_gan_sam.ipynb
├── LICENSE
├── README.md
└── etc
    └── book.jpg
/4_1_reggresion_adjustment.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "4_1_reggresion_adjustment.ipynb",
7 | "provenance": [],
8 | "collapsed_sections": []
9 | },
10 | "kernelspec": {
11 | "name": "python3",
12 | "display_name": "Python 3"
13 | }
14 | },
15 | "cells": [
16 | {
17 | "cell_type": "markdown",
18 | "metadata": {
19 | "id": "d-IAJLC2k1NX",
20 | "colab_type": "text"
21 | },
22 | "source": [
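23 | "# 4.1 Causal Inference with Regression Analysis\n",
24 | "\n",
25 | "This file implements Section 4.1.\n",
26 | "\n",
27 | "Using the estimation of a TV commercial's advertising effect as an example, we implement causal inference with regression analysis.\n"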
28 | ]
29 | },
30 | {
31 | "cell_type": "markdown",
32 | "metadata": {
33 | "id": "2XdIDbdlejUk",
34 | "colab_type": "text"
35 | },
36 | "source": [
37 | "## Setup before running the program"
38 | ]
39 | },
40 | {
41 | "cell_type": "code",
42 | "metadata": {
43 | "id": "wqHjwstVeXYt",
44 | "colab_type": "code",
45 | "colab": {}
46 | },
47 | "source": [
48 | "# Set the random seeds\n",
49 | "import random\n",
50 | "import numpy as np\n",
51 | "\n",
52 | "np.random.seed(1234)\n",
53 | "random.seed(1234)\n"
54 | ],
55 | "execution_count": 0,
56 | "outputs": []
57 | },
58 | {
59 | "cell_type": "code",
60 | "metadata": {
61 | "id": "RIhcLRqlem3V",
62 | "colab_type": "code",
63 | "colab": {}
64 | },
65 | "source": [
66 | "# Import the packages (libraries and functions) used below\n",
67 | "# For generating random numbers (randint, randn)\n",
68 | "from numpy.random import *\n",
69 | "\n",
70 | "# For plotting\n",
71 | "import matplotlib.pyplot as plt\n",
72 | "\n",
73 | "# SciPy, for standardization (mean 0, variance 1)\n",
74 | "import scipy.stats\n",
75 | "\n",
76 | "# Import the sigmoid function\n",
77 | "from scipy.special import expit\n",
78 | "\n",
79 | "# Other\n",
80 | "import pandas as pd\n"
81 | ],
82 | "execution_count": 0,
83 | "outputs": []
84 | },
85 | {
86 | "cell_type": "markdown",
87 | "metadata": {
88 | "id": "AWqP6yeQlI_t",
89 | "colab_type": "text"
90 | },
91 | "source": [
92 | "## Creating the data"
93 | ]
94 | },
95 | {
96 | "cell_type": "code",
97 | "metadata": {
98 | "id": "EJaQIHz4fNXb",
99 | "colab_type": "code",
100 | "colab": {}
101 | },
102 | "source": [
103 | "# Number of samples\n",
104 | "num_data = 200\n",
105 | "\n",
106 | "# Age\n",
107 | "x_1 = randint(15, 76, num_data) # uniform random integers from 15 to 75\n",
108 | "\n",
109 | "# Gender (0 = female, 1 = male)\n",
110 | "x_2 = randint(0, 2, num_data) # random 0 or 1\n"
111 | ],
112 | "execution_count": 0,
113 | "outputs": []
114 | },
115 | {
116 | "cell_type": "markdown",
117 | "metadata": {
118 | "id": "JiT_gc5ZmAQa",
119 | "colab_type": "text"
120 | },
121 | "source": [
122 | "## Whether the TV commercial was seen"
123 | ]
124 | },
125 | {
126 | "cell_type": "code",
127 | "metadata": {
128 | "id": "hfPbhhm6gvW6",
129 | "colab_type": "code",
130 | "colab": {}
131 | },
132 | "source": [
133 | "# Generate noise\n",
134 | "e_z = randn(num_data)\n",
135 | "\n",
136 | "# Argument passed to the sigmoid function\n",
137 | "z_base = x_1 + (1-x_2)*10 - 40 + 5*e_z\n",
138 | "\n",
139 | "# Compute the sigmoid function\n",
140 | "z_prob = expit(0.1*z_base)\n",
141 | "\n",
142 | "# Whether the person saw the TV commercial (0 = did not see, 1 = saw)\n",
143 | "Z = np.array([])\n",
144 | "\n",
145 | "for i in range(num_data):\n",
146 | "    Z_i = np.random.choice(2, size=1, p=[1-z_prob[i], z_prob[i]])[0]\n",
147 | "    Z = np.append(Z, Z_i)\n"
148 | ],
149 | "execution_count": 0,
150 | "outputs": []
151 | },
152 | {
153 | "cell_type": "markdown",
154 | "metadata": {
155 | "id": "b2PLquJGi2Te",
156 | "colab_type": "text"
157 | },
158 | "source": [
159 | "## Creating the purchase amount Y"
160 | ]
161 | },
162 | {
163 | "cell_type": "code",
164 | "metadata": {
165 | "id": "nv-ELtFqi5L5",
166 | "colab_type": "code",
167 | "colab": {}
168 | },
169 | "source": [
170 | "# Generate noise\n",
171 | "e_y = randn(num_data)\n",
172 | "\n",
173 | "Y = -x_1 + 30*x_2 + 10*Z + 80 + 10*e_y\n"
174 | ],
175 | "execution_count": 0,
176 | "outputs": []
177 | },
178 | {
179 | "cell_type": "markdown",
180 | "metadata": {
181 | "id": "BHcdUlW9koTa",
182 | "colab_type": "text"
183 | },
184 | "source": [
185 | "## Build a table of the data and compare the means"
186 | ]
187 | },
188 | {
189 | "cell_type": "code",
190 | "metadata": {
191 | "id": "HPqwrISXktRj",
192 | "colab_type": "code",
193 | "colab": {
194 | "base_uri": "https://localhost:8080/",
195 | "height": 195
196 | },
197 | "outputId": "2afcd49a-b744-4b4a-db32-838377fa0305"
198 | },
199 | "source": [
200 | "df = pd.DataFrame({'年齢': x_1,  # age\n",
201 | "                   '性別': x_2,  # gender\n",
202 | "                   'CMを見た': Z,  # saw the commercial\n",
203 | "                   '購入量': Y,  # purchase amount\n",
204 | "                   })\n",
205 | "\n",
206 | "df.head() # show the first rows\n"
207 | ],
208 | "execution_count": 6,
209 | "outputs": [
210 | {
211 | "output_type": "execute_result",
212 | "data": {
213 | "text/html": [
214 | "<table border=\"1\" class=\"dataframe\">\n",
215 | "<thead>\n",
216 | "<tr><th></th><th>年齢</th><th>性別</th><th>CMを見た</th><th>購入量</th></tr>\n",
217 | "</thead>\n",
218 | "<tbody>\n",
219 | "<tr><th>0</th><td>62</td><td>0</td><td>1.0</td><td>24.464285</td></tr>\n",
220 | "<tr><th>1</th><td>34</td><td>0</td><td>0.0</td><td>45.693411</td></tr>\n",
221 | "<tr><th>2</th><td>53</td><td>1</td><td>1.0</td><td>64.998281</td></tr>\n",
222 | "<tr><th>3</th><td>68</td><td>1</td><td>1.0</td><td>47.186898</td></tr>\n",
223 | "<tr><th>4</th><td>27</td><td>1</td><td>0.0</td><td>100.114260</td></tr>\n",
224 | "</tbody>\n",
225 | "</table>"
277 | ],
278 | "text/plain": [
279 | " 年齢 性別 CMを見た 購入量\n",
280 | "0 62 0 1.0 24.464285\n",
281 | "1 34 0 0.0 45.693411\n",
282 | "2 53 1 1.0 64.998281\n",
283 | "3 68 1 1.0 47.186898\n",
284 | "4 27 1 0.0 100.114260"
285 | ]
286 | },
287 | "metadata": {
288 | "tags": []
289 | },
290 | "execution_count": 6
291 | }
292 | ]
293 | },
294 | {
295 | "cell_type": "code",
296 | "metadata": {
297 | "id": "HHInQ1Sukrg0",
298 | "colab_type": "code",
299 | "colab": {
300 | "base_uri": "https://localhost:8080/",
301 | "height": 210
302 | },
303 | "outputId": "40cbd6c0-df68-4eec-c8c7-8a518aa4d52c"
304 | },
305 | "source": [
306 | "# Compare the means\n",
307 | "\n",
308 | "print(df[df[\"CMを見た\"] == 1.0].mean())\n",
309 | "print(\"--------\")\n",
310 | "print(df[df[\"CMを見た\"] == 0.0].mean())\n"
311 | ],
312 | "execution_count": 7,
313 | "outputs": [
314 | {
315 | "output_type": "stream",
316 | "text": [
317 | "年齢 55.836066\n",
318 | "性別 0.483607\n",
319 | "CMを見た 1.000000\n",
320 | "購入量 49.711478\n",
321 | "dtype: float64\n",
322 | "--------\n",
323 | "年齢 32.141026\n",
324 | "性別 0.692308\n",
325 | "CMを見た 0.000000\n",
326 | "購入量 68.827143\n",
327 | "dtype: float64\n"
328 | ],
329 | "name": "stdout"
330 | }
331 | ]
332 | },
333 | {
334 | "cell_type": "markdown",
335 | "metadata": {
336 | "id": "kwKOk59aogBd",
337 | "colab_type": "text"
338 | },
339 | "source": [
340 | "## Running the regression"
341 | ]
342 | },
343 | {
344 | "cell_type": "code",
345 | "metadata": {
346 | "id": "rBtHC3smoiMC",
347 | "colab_type": "code",
348 | "colab": {
349 | "base_uri": "https://localhost:8080/",
350 | "height": 34
351 | },
352 | "outputId": "e5a4dee9-e80a-4c83-b32b-8461d494eb00"
353 | },
354 | "source": [
355 | "# Import linear regression from scikit-learn\n",
356 | "# https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LinearRegression.html\n",
357 | "from sklearn.linear_model import LinearRegression\n",
358 | "\n",
359 | "# Explanatory variables\n",
360 | "X = df[[\"年齢\", \"性別\", \"CMを見た\"]]\n",
361 | "\n",
362 | "# Dependent (target) variable\n",
363 | "y = df[\"購入量\"]\n",
364 | "\n",
365 | "# Fit the regression\n",
366 | "reg = LinearRegression().fit(X, y)\n",
367 | "\n",
368 | "# Print the fitted coefficients (\"係数\" means \"coefficients\")\n",
369 | "print(\"係数:\", reg.coef_)\n"
370 | ],
371 | "execution_count": 8,
372 | "outputs": [
373 | {
374 | "output_type": "stream",
375 | "text": [
376 | "係数: [-0.95817951 32.70149412 10.41327647]\n"
377 | ],
378 | "name": "stdout"
379 | }
380 | ]
381 | },
382 | {
383 | "cell_type": "markdown",
384 | "metadata": {
385 | "id": "1IdVhXmMps-w",
386 | "colab_type": "text"
387 | },
388 | "source": [
389 | "That's all."
390 | ]
391 | }
392 | ]
393 | }
--------------------------------------------------------------------------------
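The regression adjustment in 4_1_reggresion_adjustment.ipynb can also be sketched without scikit-learn. The block below is a minimal illustration, not the notebook's code. It regenerates data from the same structural equations but under assumptions of its own: NumPy's `default_rng` with seed 0, 2000 samples instead of 200, English variable names, and `np.linalg.lstsq` in place of `LinearRegression`. Its numbers will therefore not match the outputs recorded in the notebook.

```python
import numpy as np

# Assumptions: generator, seed, and sample size differ from the notebook.
rng = np.random.default_rng(0)
n = 2000

# Same data-generating process as the notebook, with English names.
age = rng.integers(15, 76, n)            # ages 15 to 75
male = rng.integers(0, 2, n)             # 0 = female, 1 = male
z_base = age + (1 - male) * 10 - 40 + 5 * rng.standard_normal(n)
p_saw = 1.0 / (1.0 + np.exp(-0.1 * z_base))
saw = rng.binomial(1, p_saw)             # 1 = saw the commercial
purchase = -age + 30 * male + 10 * saw + 80 + 10 * rng.standard_normal(n)

# Naive comparison of group means, confounded by age and gender.
naive_diff = purchase[saw == 1].mean() - purchase[saw == 0].mean()

# Regression adjustment: ordinary least squares on [age, male, saw, 1].
design = np.column_stack([age, male, saw, np.ones(n)])
coef, *_ = np.linalg.lstsq(design, purchase, rcond=None)
adjusted_effect = coef[2]                # coefficient on `saw`

print("naive difference in means:", naive_diff)
print("regression-adjusted effect:", adjusted_effect)
```

The coefficient on `saw` should land near the true effect of 10 built into the data, while the naive difference in means is pulled negative because older people are more likely to see the commercial and also buy less.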
/4_2_iptw.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "4_2_iptw.ipynb",
7 | "provenance": [],
8 | "collapsed_sections": []
9 | },
10 | "kernelspec": {
11 | "name": "python3",
12 | "display_name": "Python 3"
13 | }
14 | },
15 | "cells": [
16 | {
17 | "cell_type": "markdown",
18 | "metadata": {
19 | "id": "d-IAJLC2k1NX",
20 | "colab_type": "text"
21 | },
22 | "source": [
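23 | "# 4.2 Causal Inference with Inverse Probability of Treatment Weighting (IPTW)\n",
24 | "\n",
25 | "This file implements Section 4.2.\n",
26 | "\n",
27 | "As in Section 4.1, we use the estimation of a TV commercial's advertising effect as an example, this time implementing causal inference with IPTW.\n"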
28 | ]
29 | },
30 | {
31 | "cell_type": "markdown",
32 | "metadata": {
33 | "id": "2XdIDbdlejUk",
34 | "colab_type": "text"
35 | },
36 | "source": [
37 | "## Setup before running the program"
38 | ]
39 | },
40 | {
41 | "cell_type": "code",
42 | "metadata": {
43 | "id": "wqHjwstVeXYt",
44 | "colab_type": "code",
45 | "colab": {}
46 | },
47 | "source": [
48 | "# Set the random seeds\n",
49 | "import random\n",
50 | "import numpy as np\n",
51 | "\n",
52 | "np.random.seed(1234)\n",
53 | "random.seed(1234)\n"
54 | ],
55 | "execution_count": 0,
56 | "outputs": []
57 | },
58 | {
59 | "cell_type": "code",
60 | "metadata": {
61 | "id": "RIhcLRqlem3V",
62 | "colab_type": "code",
63 | "colab": {}
64 | },
65 | "source": [
66 | "# Import the packages (libraries and functions) used below\n",
67 | "# For generating random numbers (randint, randn)\n",
68 | "from numpy.random import *\n",
69 | "\n",
70 | "# For plotting\n",
71 | "import matplotlib.pyplot as plt\n",
72 | "\n",
73 | "# SciPy, for standardization (mean 0, variance 1)\n",
74 | "import scipy.stats\n",
75 | "\n",
76 | "# Import the sigmoid function\n",
77 | "from scipy.special import expit\n",
78 | "\n",
79 | "# Other\n",
80 | "import pandas as pd\n"
81 | ],
82 | "execution_count": 0,
83 | "outputs": []
84 | },
85 | {
86 | "cell_type": "markdown",
87 | "metadata": {
88 | "id": "AWqP6yeQlI_t",
89 | "colab_type": "text"
90 | },
91 | "source": [
92 | "## Creating the data"
93 | ]
94 | },
95 | {
96 | "cell_type": "code",
97 | "metadata": {
98 | "id": "EJaQIHz4fNXb",
99 | "colab_type": "code",
100 | "colab": {}
101 | },
102 | "source": [
103 | "# Number of samples\n",
104 | "num_data = 200\n",
105 | "\n",
106 | "# Age\n",
107 | "x_1 = randint(15, 76, num_data) # uniform random integers from 15 to 75\n",
108 | "\n",
109 | "# Gender (0 = female, 1 = male)\n",
110 | "x_2 = randint(0, 2, num_data) # random 0 or 1\n"
111 | ],
112 | "execution_count": 0,
113 | "outputs": []
114 | },
115 | {
116 | "cell_type": "markdown",
117 | "metadata": {
118 | "id": "JiT_gc5ZmAQa",
119 | "colab_type": "text"
120 | },
121 | "source": [
122 | "## Whether the TV commercial was seen"
123 | ]
124 | },
125 | {
126 | "cell_type": "code",
127 | "metadata": {
128 | "id": "hfPbhhm6gvW6",
129 | "colab_type": "code",
130 | "colab": {}
131 | },
132 | "source": [
133 | "# Generate noise\n",
134 | "e_z = randn(num_data)\n",
135 | "\n",
136 | "# Argument passed to the sigmoid function\n",
137 | "z_base = x_1 + (1-x_2)*10 - 40 + 5*e_z\n",
138 | "\n",
139 | "# Compute the sigmoid function\n",
140 | "z_prob = expit(0.1*z_base)\n",
141 | "\n",
142 | "# Whether the person saw the TV commercial (0 = did not see, 1 = saw)\n",
143 | "Z = np.array([])\n",
144 | "\n",
145 | "for i in range(num_data):\n",
146 | "    Z_i = np.random.choice(2, size=1, p=[1-z_prob[i], z_prob[i]])[0]\n",
147 | "    Z = np.append(Z, Z_i)\n"
148 | ],
149 | "execution_count": 0,
150 | "outputs": []
151 | },
152 | {
153 | "cell_type": "markdown",
154 | "metadata": {
155 | "id": "b2PLquJGi2Te",
156 | "colab_type": "text"
157 | },
158 | "source": [
159 | "## Creating the purchase amount Y"
160 | ]
161 | },
162 | {
163 | "cell_type": "code",
164 | "metadata": {
165 | "id": "nv-ELtFqi5L5",
166 | "colab_type": "code",
167 | "colab": {}
168 | },
169 | "source": [
170 | "# Generate noise\n",
171 | "e_y = randn(num_data)\n",
172 | "\n",
173 | "Y = -x_1 + 30*x_2 + 10*Z + 80 + 10*e_y\n"
174 | ],
175 | "execution_count": 0,
176 | "outputs": []
177 | },
178 | {
179 | "cell_type": "markdown",
180 | "metadata": {
181 | "id": "BHcdUlW9koTa",
182 | "colab_type": "text"
183 | },
184 | "source": [
185 | "## Build a table of the data and compare the means"
186 | ]
187 | },
188 | {
189 | "cell_type": "code",
190 | "metadata": {
191 | "id": "HPqwrISXktRj",
192 | "colab_type": "code",
193 | "outputId": "9ad6013c-9715-481b-d68a-17fdd3edd281",
194 | "colab": {
195 | "base_uri": "https://localhost:8080/",
196 | "height": 195
197 | }
198 | },
199 | "source": [
200 | "df = pd.DataFrame({'年齢': x_1,  # age\n",
201 | "                   '性別': x_2,  # gender\n",
202 | "                   'CMを見た': Z,  # saw the commercial\n",
203 | "                   '購入量': Y,  # purchase amount\n",
204 | "                   })\n",
205 | "\n",
206 | "df.head() # show the first rows\n"
207 | ],
208 | "execution_count": 6,
209 | "outputs": [
210 | {
211 | "output_type": "execute_result",
212 | "data": {
213 | "text/html": [
214 | "<table border=\"1\" class=\"dataframe\">\n",
215 | "<thead>\n",
216 | "<tr><th></th><th>年齢</th><th>性別</th><th>CMを見た</th><th>購入量</th></tr>\n",
217 | "</thead>\n",
218 | "<tbody>\n",
219 | "<tr><th>0</th><td>62</td><td>0</td><td>1.0</td><td>24.464285</td></tr>\n",
220 | "<tr><th>1</th><td>34</td><td>0</td><td>0.0</td><td>45.693411</td></tr>\n",
221 | "<tr><th>2</th><td>53</td><td>1</td><td>1.0</td><td>64.998281</td></tr>\n",
222 | "<tr><th>3</th><td>68</td><td>1</td><td>1.0</td><td>47.186898</td></tr>\n",
223 | "<tr><th>4</th><td>27</td><td>1</td><td>0.0</td><td>100.114260</td></tr>\n",
224 | "</tbody>\n",
225 | "</table>"
277 | ],
278 | "text/plain": [
279 | " 年齢 性別 CMを見た 購入量\n",
280 | "0 62 0 1.0 24.464285\n",
281 | "1 34 0 0.0 45.693411\n",
282 | "2 53 1 1.0 64.998281\n",
283 | "3 68 1 1.0 47.186898\n",
284 | "4 27 1 0.0 100.114260"
285 | ]
286 | },
287 | "metadata": {
288 | "tags": []
289 | },
290 | "execution_count": 6
291 | }
292 | ]
293 | },
294 | {
295 | "cell_type": "code",
296 | "metadata": {
297 | "id": "HHInQ1Sukrg0",
298 | "colab_type": "code",
299 | "outputId": "ec83304b-9f7c-4334-93a9-c13022813ae9",
300 | "colab": {
301 | "base_uri": "https://localhost:8080/",
302 | "height": 210
303 | }
304 | },
305 | "source": [
306 | "# Compare the means\n",
307 | "\n",
308 | "print(df[df[\"CMを見た\"] == 1.0].mean())\n",
309 | "print(\"--------\")\n",
310 | "print(df[df[\"CMを見た\"] == 0.0].mean())\n"
311 | ],
312 | "execution_count": 7,
313 | "outputs": [
314 | {
315 | "output_type": "stream",
316 | "text": [
317 | "年齢 55.836066\n",
318 | "性別 0.483607\n",
319 | "CMを見た 1.000000\n",
320 | "購入量 49.711478\n",
321 | "dtype: float64\n",
322 | "--------\n",
323 | "年齢 32.141026\n",
324 | "性別 0.692308\n",
325 | "CMを見た 0.000000\n",
326 | "購入量 68.827143\n",
327 | "dtype: float64\n"
328 | ],
329 | "name": "stdout"
330 | }
331 | ]
332 | },
333 | {
334 | "cell_type": "markdown",
335 | "metadata": {
336 | "id": "kwKOk59aogBd",
337 | "colab_type": "text"
338 | },
339 | "source": [
340 | "## From here on the notebook differs from Section 4.1: estimating the propensity score"
341 | ]
342 | },
343 | {
344 | "cell_type": "code",
345 | "metadata": {
346 | "id": "rBtHC3smoiMC",
347 | "colab_type": "code",
348 | "outputId": "95ab93a5-7d8e-4226-f654-e1351b2537ba",
349 | "colab": {
350 | "base_uri": "https://localhost:8080/",
351 | "height": 52
352 | }
353 | },
354 | "source": [
355 | "# Import logistic regression from scikit-learn\n",
356 | "# https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html\n",
357 | "from sklearn.linear_model import LogisticRegression\n",
358 | "\n",
359 | "# Explanatory variables\n",
360 | "X = df[[\"年齢\", \"性別\"]]\n",
361 | "\n",
362 | "# Dependent (target) variable\n",
363 | "Z = df[\"CMを見た\"]\n",
364 | "\n",
365 | "# Fit the regression\n",
366 | "reg = LogisticRegression().fit(X,Z)\n",
367 | "\n",
368 | "# Print the fitted coefficients (\"係数\" means \"coefficient\")\n",
369 | "print(\"係数beta:\", reg.coef_)\n",
370 | "print(\"係数alpha:\", reg.intercept_)"
371 | ],
372 | "execution_count": 8,
373 | "outputs": [
374 | {
375 | "output_type": "stream",
376 | "text": [
377 | "係数beta: [[ 0.10562765 -1.38263933]]\n",
378 | "係数alpha: [-3.37146523]\n"
379 | ],
380 | "name": "stdout"
381 | }
382 | ]
383 | },
384 | {
385 | "cell_type": "markdown",
386 | "metadata": {
387 | "id": "nV0dm82l3QXy",
388 | "colab_type": "text"
389 | },
390 | "source": [
391 | "### Computing each person's propensity score"
392 | ]
393 | },
394 | {
395 | "cell_type": "code",
396 | "metadata": {
397 | "id": "gGCKiujL3P1i",
398 | "colab_type": "code",
399 | "colab": {
400 | "base_uri": "https://localhost:8080/",
401 | "height": 228
402 | },
403 | "outputId": "c9201825-cd6f-41fa-aec2-52858d82b59e"
404 | },
405 | "source": [
406 | "Z_pre = reg.predict_proba(X)\n",
407 | "print(Z_pre[0:5]) # look at the predictions for the first five people\n",
408 | "print(\"----\")\n",
409 | "print(Z[0:5]) # the true labels for the same five people\n"
410 | ],
411 | "execution_count": 9,
412 | "outputs": [
413 | {
414 | "output_type": "stream",
415 | "text": [
416 | "[[0.04002323 0.95997677]\n",
417 | " [0.44525168 0.55474832]\n",
418 | " [0.30065918 0.69934082]\n",
419 | " [0.08101946 0.91898054]\n",
420 | " [0.87013558 0.12986442]]\n",
421 | "----\n",
422 | "0 1.0\n",
423 | "1 0.0\n",
424 | "2 1.0\n",
425 | "3 1.0\n",
426 | "4 0.0\n",
427 | "Name: CMを見た, dtype: float64\n"
428 | ],
429 | "name": "stdout"
430 | }
431 | ]
432 | },
433 | {
434 | "cell_type": "markdown",
435 | "metadata": {
436 | "colab_type": "text",
437 | "id": "wL-hlBN36DZf"
438 | },
439 | "source": [
440 | "### Computing the average treatment effect (ATE)"
441 | ]
442 | },
443 | {
444 | "cell_type": "code",
445 | "metadata": {
446 | "id": "6Ujy7JJa6Gwi",
447 | "colab_type": "code",
448 | "colab": {
449 | "base_uri": "https://localhost:8080/",
450 | "height": 34
451 | },
452 | "outputId": "2c4ab7d4-2393-4937-ed5e-5e394fe2116d"
453 | },
454 | "source": [
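455 | "ATE_i = Y/Z_pre[:, 1]*Z - Y/Z_pre[:, 0]*(1-Z)  # per-person inverse-probability-weighted contribution\n",
456 | "ATE = 1/len(Y)*ATE_i.sum()  # average over all samples\n",
457 | "print(\"推定したATE\", ATE)  # \"推定したATE\" means \"estimated ATE\"\n"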
458 | ],
459 | "execution_count": 10,
460 | "outputs": [
461 | {
462 | "output_type": "stream",
463 | "text": [
464 | "推定したATE 8.847476810855458\n"
465 | ],
466 | "name": "stdout"
467 | }
468 | ]
469 | },
470 | {
471 | "cell_type": "markdown",
472 | "metadata": {
473 | "id": "1IdVhXmMps-w",
474 | "colab_type": "text"
475 | },
476 | "source": [
477 | "That's all."
478 | ]
479 | }
480 | ]
481 | }
--------------------------------------------------------------------------------
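The IPTW estimate in 4_2_iptw.ipynb weights each observed outcome by the inverse of the estimated probability of the treatment that person actually received. The sketch below reproduces that formula on freshly simulated data. Everything outside the formula is an assumption made for illustration: the hypothetical helper `fit_propensity` fits an unregularized logistic regression by Newton's method (the notebook instead uses scikit-learn's `LogisticRegression`, which applies L2 regularization by default), the data come from `default_rng` with seed 0, and the sample size is 20000 because the unnormalized IPTW estimator is noisy when some propensities are close to 0 or 1.

```python
import numpy as np

def sigmoid(t):
    return 1.0 / (1.0 + np.exp(-t))

def fit_propensity(features, treated, n_iter=25):
    """Hypothetical helper: unregularized logistic regression fitted by
    Newton's method. Returns the fitted P(treated = 1 | features)."""
    design = np.column_stack([np.ones(len(features)), features])
    w = np.zeros(design.shape[1])
    for _ in range(n_iter):
        p = sigmoid(design @ w)
        gradient = design.T @ (treated - p)
        hessian = (design * (p * (1 - p))[:, None]).T @ design
        w += np.linalg.solve(hessian, gradient)
    return sigmoid(design @ w)

# Assumptions: generator, seed, and sample size differ from the notebook.
rng = np.random.default_rng(0)
n = 20000
age = rng.integers(15, 76, n)
male = rng.integers(0, 2, n)
z_base = age + (1 - male) * 10 - 40 + 5 * rng.standard_normal(n)
saw = rng.binomial(1, sigmoid(0.1 * z_base))
purchase = -age + 30 * male + 10 * saw + 80 + 10 * rng.standard_normal(n)

# Standardize age so the Newton iterations are well conditioned.
features = np.column_stack([(age - age.mean()) / age.std(), male])
e_hat = fit_propensity(features, saw)

# Naive difference in means, for comparison.
naive_diff = purchase[saw == 1].mean() - purchase[saw == 0].mean()

# IPTW estimate of the ATE, same formula as the notebook.
ate_iptw = np.mean(purchase * saw / e_hat - purchase * (1 - saw) / (1 - e_hat))

print("naive difference in means:", naive_diff)
print("IPTW estimate of the ATE:", ate_iptw)
```

Under these assumptions the IPTW estimate should sit much closer to the true effect of 10 than the naive difference does, although individual runs can still be off by a few units.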
/4_3_dr.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "4_3_dr.ipynb",
7 | "provenance": [],
8 | "collapsed_sections": []
9 | },
10 | "kernelspec": {
11 | "name": "python3",
12 | "display_name": "Python 3"
13 | }
14 | },
15 | "cells": [
16 | {
17 | "cell_type": "markdown",
18 | "metadata": {
19 | "id": "d-IAJLC2k1NX"
20 | },
21 | "source": [
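22 | "# 4.3 Implementing Causal Inference with the Doubly Robust (DR) Method\n",
23 | "\n",
24 | "This file implements Section 4.3.\n",
25 | "\n",
26 | "As in Sections 4.1 and 4.2, we use the estimation of a TV commercial's advertising effect as an example, this time implementing causal inference with the DR method.\n"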
27 | ]
28 | },
29 | {
30 | "cell_type": "markdown",
31 | "metadata": {
32 | "id": "2XdIDbdlejUk"
33 | },
34 | "source": [
35 | "## Setup before running the program"
36 | ]
37 | },
38 | {
39 | "cell_type": "code",
40 | "metadata": {
41 | "id": "wqHjwstVeXYt"
42 | },
43 | "source": [
44 | "# Set the random seeds\n",
45 | "import random\n",
46 | "import numpy as np\n",
47 | "\n",
48 | "np.random.seed(1234)\n",
49 | "random.seed(1234)\n"
50 | ],
51 | "execution_count": null,
52 | "outputs": []
53 | },
54 | {
55 | "cell_type": "code",
56 | "metadata": {
57 | "id": "RIhcLRqlem3V"
58 | },
59 | "source": [
60 | "# Import the packages (libraries and functions) used below\n",
61 | "# For generating random numbers (randint, randn)\n",
62 | "from numpy.random import *\n",
63 | "\n",
64 | "# For plotting\n",
65 | "import matplotlib.pyplot as plt\n",
66 | "\n",
67 | "# SciPy, for standardization (mean 0, variance 1)\n",
68 | "import scipy.stats\n",
69 | "\n",
70 | "# Import the sigmoid function\n",
71 | "from scipy.special import expit\n",
72 | "\n",
73 | "# Other\n",
74 | "import pandas as pd\n"
75 | ],
76 | "execution_count": null,
77 | "outputs": []
78 | },
79 | {
80 | "cell_type": "markdown",
81 | "metadata": {
82 | "id": "AWqP6yeQlI_t"
83 | },
84 | "source": [
85 | "## Creating the data"
86 | ]
87 | },
88 | {
89 | "cell_type": "code",
90 | "metadata": {
91 | "id": "EJaQIHz4fNXb"
92 | },
93 | "source": [
94 | "# Number of samples\n",
95 | "num_data = 200\n",
96 | "\n",
97 | "# Age\n",
98 | "x_1 = randint(15, 76, num_data) # uniform random integers from 15 to 75\n",
99 | "\n",
100 | "# Gender (0 = female, 1 = male)\n",
101 | "x_2 = randint(0, 2, num_data) # random 0 or 1\n"
102 | ],
103 | "execution_count": null,
104 | "outputs": []
105 | },
106 | {
107 | "cell_type": "markdown",
108 | "metadata": {
109 | "id": "JiT_gc5ZmAQa"
110 | },
111 | "source": [
112 | "## Whether the TV commercial was seen"
113 | ]
114 | },
115 | {
116 | "cell_type": "code",
117 | "metadata": {
118 | "id": "hfPbhhm6gvW6"
119 | },
120 | "source": [
121 | "# Generate noise\n",
122 | "e_z = randn(num_data)\n",
123 | "\n",
124 | "# Argument passed to the sigmoid function\n",
125 | "z_base = x_1 + (1-x_2)*10 - 40 + 5*e_z\n",
126 | "\n",
127 | "# Compute the sigmoid function\n",
128 | "z_prob = expit(0.1*z_base)\n",
129 | "\n",
130 | "# Whether the person saw the TV commercial (0 = did not see, 1 = saw)\n",
131 | "Z = np.array([])\n",
132 | "\n",
133 | "for i in range(num_data):\n",
134 | "    Z_i = np.random.choice(2, size=1, p=[1-z_prob[i], z_prob[i]])[0]\n",
135 | "    Z = np.append(Z, Z_i)\n"
136 | ],
137 | "execution_count": null,
138 | "outputs": []
139 | },
140 | {
141 | "cell_type": "markdown",
142 | "metadata": {
143 | "id": "b2PLquJGi2Te"
144 | },
145 | "source": [
146 | "## Creating the purchase amount Y"
147 | ]
148 | },
149 | {
150 | "cell_type": "code",
151 | "metadata": {
152 | "id": "nv-ELtFqi5L5"
153 | },
154 | "source": [
155 | "# Generate noise\n",
156 | "e_y = randn(num_data)\n",
157 | "\n",
158 | "Y = -x_1 + 30*x_2 + 10*Z + 80 + 10*e_y\n"
159 | ],
160 | "execution_count": null,
161 | "outputs": []
162 | },
163 | {
164 | "cell_type": "markdown",
165 | "metadata": {
166 | "id": "BHcdUlW9koTa"
167 | },
168 | "source": [
169 | "## Build a table of the data and compare the means"
170 | ]
171 | },
172 | {
173 | "cell_type": "code",
174 | "metadata": {
175 | "id": "HPqwrISXktRj",
176 | "colab": {
177 | "base_uri": "https://localhost:8080/",
178 | "height": 195
179 | },
180 | "outputId": "539bb1b5-9936-461e-dd94-258b6980d366"
181 | },
182 | "source": [
183 | "df = pd.DataFrame({'年齢': x_1,  # age\n",
184 | "                   '性別': x_2,  # gender\n",
185 | "                   'CMを見た': Z,  # saw the commercial\n",
186 | "                   '購入量': Y,  # purchase amount\n",
187 | "                   })\n",
188 | "\n",
189 | "df.head() # show the first rows\n"
190 | ],
191 | "execution_count": null,
192 | "outputs": [
193 | {
194 | "output_type": "execute_result",
195 | "data": {
196 | "text/html": [
197 | "<table border=\"1\" class=\"dataframe\">\n",
198 | "<thead>\n",
199 | "<tr><th></th><th>年齢</th><th>性別</th><th>CMを見た</th><th>購入量</th></tr>\n",
200 | "</thead>\n",
201 | "<tbody>\n",
202 | "<tr><th>0</th><td>62</td><td>0</td><td>1.0</td><td>24.464285</td></tr>\n",
203 | "<tr><th>1</th><td>34</td><td>0</td><td>0.0</td><td>45.693411</td></tr>\n",
204 | "<tr><th>2</th><td>53</td><td>1</td><td>1.0</td><td>64.998281</td></tr>\n",
205 | "<tr><th>3</th><td>68</td><td>1</td><td>1.0</td><td>47.186898</td></tr>\n",
206 | "<tr><th>4</th><td>27</td><td>1</td><td>0.0</td><td>100.114260</td></tr>\n",
207 | "</tbody>\n",
208 | "</table>"
260 | ],
261 | "text/plain": [
262 | " 年齢 性別 CMを見た 購入量\n",
263 | "0 62 0 1.0 24.464285\n",
264 | "1 34 0 0.0 45.693411\n",
265 | "2 53 1 1.0 64.998281\n",
266 | "3 68 1 1.0 47.186898\n",
267 | "4 27 1 0.0 100.114260"
268 | ]
269 | },
270 | "metadata": {
271 | "tags": []
272 | },
273 | "execution_count": 6
274 | }
275 | ]
276 | },
277 | {
278 | "cell_type": "code",
279 | "metadata": {
280 | "id": "HHInQ1Sukrg0",
281 | "colab": {
282 | "base_uri": "https://localhost:8080/",
283 | "height": 202
284 | },
285 | "outputId": "4739a485-80aa-425b-ccd8-91294cb7d9fb"
286 | },
287 | "source": [
288 | "# Compare the means\n",
289 | "\n",
290 | "print(df[df[\"CMを見た\"] == 1.0].mean())\n",
291 | "print(\"--------\")\n",
292 | "print(df[df[\"CMを見た\"] == 0.0].mean())\n"
293 | ],
294 | "execution_count": null,
295 | "outputs": [
296 | {
297 | "output_type": "stream",
298 | "text": [
299 | "年齢 55.836066\n",
300 | "性別 0.483607\n",
301 | "CMを見た 1.000000\n",
302 | "購入量 49.711478\n",
303 | "dtype: float64\n",
304 | "--------\n",
305 | "年齢 32.141026\n",
306 | "性別 0.692308\n",
307 | "CMを見た 0.000000\n",
308 | "購入量 68.827143\n",
309 | "dtype: float64\n"
310 | ],
311 | "name": "stdout"
312 | }
313 | ]
314 | },
315 | {
316 | "cell_type": "markdown",
317 | "metadata": {
318 | "id": "-KMbYTvx-D4N"
319 | },
320 | "source": [
321 | "## Running the regression"
322 | ]
323 | },
324 | {
325 | "cell_type": "code",
326 | "metadata": {
327 | "id": "CiVVt59d-gdj"
328 | },
329 | "source": [
330 | "# Import linear regression from scikit-learn\n",
331 | "# https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LinearRegression.html\n",
332 | "from sklearn.linear_model import LinearRegression\n",
333 | "\n",
334 | "# Explanatory variables\n",
335 | "X = df[[\"年齢\", \"性別\", \"CMを見た\"]]\n",
336 | "\n",
337 | "# Dependent (target) variable\n",
338 | "y = df[\"購入量\"]\n",
339 | "\n",
340 | "# Fit the regression\n",
341 | "reg2 = LinearRegression().fit(X, y)\n",
342 | "\n",
343 | "# Predicted outcome when Z=0 for everyone\n",
344 | "X_0 = X.copy()\n",
345 | "X_0[\"CMを見た\"] = 0\n",
346 | "Y_0 = reg2.predict(X_0)\n",
347 | "\n",
348 | "# Predicted outcome when Z=1 for everyone\n",
349 | "X_1 = X.copy()\n",
350 | "X_1[\"CMを見た\"] = 1\n",
351 | "Y_1 = reg2.predict(X_1)\n"
352 | ],
353 | "execution_count": null,
354 | "outputs": []
355 | },
356 | {
357 | "cell_type": "markdown",
358 | "metadata": {
359 | "id": "kwKOk59aogBd"
360 | },
361 | "source": [
362 | "## Estimating the propensity score"
363 | ]
364 | },
365 | {
366 | "cell_type": "code",
367 | "metadata": {
368 | "id": "rBtHC3smoiMC",
369 | "colab": {
370 | "base_uri": "https://localhost:8080/",
371 | "height": 101
372 | },
373 | "outputId": "6b9e06dc-3ee1-45be-a573-39a6839d9c85"
374 | },
375 | "source": [
376 | "# Import logistic regression from scikit-learn\n",
377 | "# https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html\n",
378 | "from sklearn.linear_model import LogisticRegression\n",
379 | "\n",
380 | "# Explanatory variables\n",
381 | "X = df[[\"年齢\", \"性別\"]]\n",
382 | "\n",
383 | "# Dependent (target) variable\n",
384 | "Z = df[\"CMを見た\"]\n",
385 | "\n",
386 | "# Fit the regression\n",
387 | "reg = LogisticRegression().fit(X, Z)\n",
388 | "\n",
389 | "# Compute the propensity scores\n",
390 | "Z_pre = reg.predict_proba(X)\n",
391 | "print(Z_pre[0:5]) # look at the predictions for the first five people\n"
392 | ],
393 | "execution_count": null,
394 | "outputs": [
395 | {
396 | "output_type": "stream",
397 | "text": [
398 | "[[0.04002323 0.95997677]\n",
399 | " [0.44525168 0.55474832]\n",
400 | " [0.30065918 0.69934082]\n",
401 | " [0.08101946 0.91898054]\n",
402 | " [0.87013558 0.12986442]]\n"
403 | ],
404 | "name": "stdout"
405 | }
406 | ]
407 | },
408 | {
409 | "cell_type": "markdown",
410 | "metadata": {
411 | "id": "wL-hlBN36DZf"
412 | },
413 | "source": [
414 | "### Computing the average treatment effect (ATE)"
415 | ]
416 | },
417 | {
418 | "cell_type": "code",
419 | "metadata": {
420 | "id": "F7bIHOC2ABSK",
421 | "colab": {
422 | "base_uri": "https://localhost:8080/",
423 | "height": 34
424 | },
425 | "outputId": "2bc8f4d0-2d8a-4620-a60f-844afc4c96e3"
426 | },
427 | "source": [
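428 | "ATE_1_i = Y/Z_pre[:, 1]*Z + (1-Z/Z_pre[:, 1])*Y_1  # doubly robust estimate of the outcome under Z=1\n",
429 | "ATE_0_i = Y/Z_pre[:, 0]*(1-Z) + (1-(1-Z)/Z_pre[:, 0])*Y_0  # doubly robust estimate of the outcome under Z=0\n",
430 | "ATE = 1/len(Y)*(ATE_1_i-ATE_0_i).sum()  # average the per-person differences\n",
431 | "print(\"推定したATE\", ATE)  # \"推定したATE\" means \"estimated ATE\"\n"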
432 | ],
433 | "execution_count": null,
434 | "outputs": [
435 | {
436 | "output_type": "stream",
437 | "text": [
438 | "推定したATE 9.75277505424846\n"
439 | ],
440 | "name": "stdout"
441 | }
442 | ]
443 | },
444 | {
445 | "cell_type": "markdown",
446 | "metadata": {
447 | "id": "1IdVhXmMps-w"
448 | },
449 | "source": [
450 | "That's all."
451 | ]
452 | }
453 | ]
454 | }
--------------------------------------------------------------------------------
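The doubly robust estimate in 4_3_dr.ipynb combines the two earlier ingredients: an outcome regression that predicts the purchase amount under Z=1 and Z=0, and a propensity score that reweights the residuals. The following sketch applies the same two per-person formulas to simulated data. As assumptions for illustration only, it uses `np.linalg.lstsq` for the outcome model, a hypothetical helper `fit_propensity` (unregularized logistic regression by Newton's method, rather than scikit-learn's regularized `LogisticRegression`), `default_rng` with seed 0, and 5000 samples.

```python
import numpy as np

def sigmoid(t):
    return 1.0 / (1.0 + np.exp(-t))

def fit_propensity(features, treated, n_iter=25):
    """Hypothetical helper: unregularized logistic regression fitted by
    Newton's method. Returns the fitted P(treated = 1 | features)."""
    design = np.column_stack([np.ones(len(features)), features])
    w = np.zeros(design.shape[1])
    for _ in range(n_iter):
        p = sigmoid(design @ w)
        gradient = design.T @ (treated - p)
        hessian = (design * (p * (1 - p))[:, None]).T @ design
        w += np.linalg.solve(hessian, gradient)
    return sigmoid(design @ w)

# Assumptions: generator, seed, and sample size differ from the notebook.
rng = np.random.default_rng(0)
n = 5000
age = rng.integers(15, 76, n)
male = rng.integers(0, 2, n)
z_base = age + (1 - male) * 10 - 40 + 5 * rng.standard_normal(n)
saw = rng.binomial(1, sigmoid(0.1 * z_base))
purchase = -age + 30 * male + 10 * saw + 80 + 10 * rng.standard_normal(n)

# Outcome model: least squares on [1, age, male, saw].
ones = np.ones(n)
beta, *_ = np.linalg.lstsq(
    np.column_stack([ones, age, male, saw]), purchase, rcond=None
)
y1_hat = np.column_stack([ones, age, male, ones]) @ beta          # Z set to 1
y0_hat = np.column_stack([ones, age, male, np.zeros(n)]) @ beta   # Z set to 0

# Propensity model on standardized age and gender.
features = np.column_stack([(age - age.mean()) / age.std(), male])
e_hat = fit_propensity(features, saw)

# Doubly robust per-person terms, same formulas as the notebook.
mu1 = purchase * saw / e_hat + (1 - saw / e_hat) * y1_hat
mu0 = purchase * (1 - saw) / (1 - e_hat) + (1 - (1 - saw) / (1 - e_hat)) * y0_hat
ate_dr = np.mean(mu1 - mu0)

print("doubly robust estimate of the ATE:", ate_dr)
```

Because the outcome model here matches the data-generating equation, the weighted terms only correct small residuals, which is why this estimate tends to be less noisy than plain IPTW on the same kind of data.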
/5_1_randomforest.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "5_1_randomforest.ipynb",
7 | "provenance": [],
8 | "collapsed_sections": []
9 | },
10 | "kernelspec": {
11 | "name": "python3",
12 | "display_name": "Python 3"
13 | }
14 | },
15 | "cells": [
16 | {
17 | "cell_type": "markdown",
18 | "metadata": {
19 | "id": "d-IAJLC2k1NX",
20 | "colab_type": "text"
21 | },
22 | "source": [
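23 | "# 5.1 What Is a Random Forest?\n",
24 | "\n",
25 | "This file implements Section 5.1.\n",
26 | "\n",
27 | "We explain and implement the random forest, a machine learning model.\n",
28 | "We cover decision-tree classification, decision-tree regression, random-forest classification, and random-forest regression."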
29 | ]
30 | },
31 | {
32 | "cell_type": "markdown",
33 | "metadata": {
34 | "id": "2XdIDbdlejUk",
35 | "colab_type": "text"
36 | },
37 | "source": [
38 | "## Setup before running the program"
39 | ]
40 | },
41 | {
42 | "cell_type": "code",
43 | "metadata": {
44 | "id": "wqHjwstVeXYt",
45 | "colab_type": "code",
46 | "colab": {}
47 | },
48 | "source": [
49 | "# Set the random seeds\n",
50 | "import random\n",
51 | "import numpy as np\n",
52 | "\n",
53 | "np.random.seed(1234)\n",
54 | "random.seed(1234)\n"
55 | ],
56 | "execution_count": 0,
57 | "outputs": []
58 | },
59 | {
60 | "cell_type": "code",
61 | "metadata": {
62 | "id": "RIhcLRqlem3V",
63 | "colab_type": "code",
64 | "colab": {}
65 | },
66 | "source": [
67 | "# Import the packages (libraries and functions) used below\n",
68 | "# For generating random numbers (randint, randn)\n",
69 | "from numpy.random import *\n",
70 | "\n",
71 | "# For plotting\n",
72 | "import matplotlib.pyplot as plt\n",
73 | "\n",
74 | "# SciPy, for standardization (mean 0, variance 1)\n",
75 | "import scipy.stats\n",
76 | "\n",
77 | "# Import the sigmoid function\n",
78 | "from scipy.special import expit\n",
79 | "\n",
80 | "# Other\n",
81 | "import pandas as pd\n"
82 | ],
83 | "execution_count": 0,
84 | "outputs": []
85 | },
86 | {
87 | "cell_type": "markdown",
88 | "metadata": {
89 | "id": "AWqP6yeQlI_t",
90 | "colab_type": "text"
91 | },
92 | "source": [
93 | "## Creating the data"
94 | ]
95 | },
96 | {
97 | "cell_type": "code",
98 | "metadata": {
99 | "id": "EJaQIHz4fNXb",
100 | "colab_type": "code",
101 | "colab": {}
102 | },
103 | "source": [
104 | "# Number of samples\n",
105 | "num_data = 200\n",
106 | "\n",
107 | "# Age\n",
108 | "x_1 = randint(15, 76, num_data) # uniform random integers from 15 to 75\n",
109 | "\n",
110 | "# Gender (0 = female, 1 = male)\n",
111 | "x_2 = randint(0, 2, num_data) # random 0 or 1\n"
112 | ],
113 | "execution_count": 0,
114 | "outputs": []
115 | },
116 | {
117 | "cell_type": "markdown",
118 | "metadata": {
119 | "id": "JiT_gc5ZmAQa",
120 | "colab_type": "text"
121 | },
122 | "source": [
123 | "## Whether the TV commercial was seen"
124 | ]
125 | },
126 | {
127 | "cell_type": "code",
128 | "metadata": {
129 | "id": "hfPbhhm6gvW6",
130 | "colab_type": "code",
131 | "colab": {}
132 | },
133 | "source": [
134 | "# Generate noise\n",
135 | "e_z = randn(num_data)\n",
136 | "\n",
137 | "# Argument passed to the sigmoid function\n",
138 | "z_base = x_1 + (1-x_2)*10 - 40 + 5*e_z\n",
139 | "\n",
140 | "# Compute the sigmoid function\n",
141 | "z_prob = expit(0.1*z_base)\n",
142 | "\n",
143 | "# Whether the person saw the TV commercial (0 = did not see, 1 = saw)\n",
144 | "Z = np.array([])\n",
145 | "\n",
146 | "for i in range(num_data):\n",
147 | "    Z_i = np.random.choice(2, size=1, p=[1-z_prob[i], z_prob[i]])[0]\n",
148 | "    Z = np.append(Z, Z_i)\n"
149 | ],
150 | "execution_count": 0,
151 | "outputs": []
152 | },
153 | {
154 | "cell_type": "markdown",
155 | "metadata": {
156 | "id": "b2PLquJGi2Te",
157 | "colab_type": "text"
158 | },
159 | "source": [
160 | "## Creating the purchase amount Y"
161 | ]
162 | },
163 | {
164 | "cell_type": "code",
165 | "metadata": {
166 | "id": "nv-ELtFqi5L5",
167 | "colab_type": "code",
168 | "colab": {}
169 | },
170 | "source": [
171 | "# Generate noise\n",
172 | "e_y = randn(num_data)\n",
173 | "\n",
174 | "Y = -x_1 + 30*x_2 + 10*Z + 80 + 10*e_y\n"
175 | ],
176 | "execution_count": 0,
177 | "outputs": []
178 | },
179 | {
180 | "cell_type": "markdown",
181 | "metadata": {
182 | "id": "BHcdUlW9koTa",
183 | "colab_type": "text"
184 | },
185 | "source": [
186 | "## Building a summary table and comparing the means"
187 | ]
188 | },
189 | {
190 | "cell_type": "code",
191 | "metadata": {
192 | "id": "HPqwrISXktRj",
193 | "colab_type": "code",
194 | "outputId": "472a7556-d2c0-4341-f570-d101075ca857",
195 | "colab": {
196 | "base_uri": "https://localhost:8080/",
197 | "height": 195
198 | }
199 | },
200 | "source": [
201 | "df = pd.DataFrame({'年齢': x_1,  # age\n",
202 | "                   '性別': x_2,  # gender\n",
203 | "                   'CMを見た': Z,  # saw the commercial\n",
204 | "                   '購入量': Y,  # purchase amount\n",
205 | "                   })\n",
206 | "\n",
207 | "df.head()  # show the first rows\n"
208 | ],
209 | "execution_count": 6,
210 | "outputs": [
211 | {
212 | "output_type": "execute_result",
213 | "data": {
214 | "text/html": [
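215 | "<div>\n",
216 | "<style scoped>\n",
217 | "    .dataframe tbody tr th:only-of-type {\n",
218 | "        vertical-align: middle;\n",
219 | "    }\n",
220 | "\n",
221 | "    .dataframe tbody tr th {\n",
222 | "        vertical-align: top;\n",
223 | "    }\n",
224 | "\n",
225 | "    .dataframe thead th {\n",
226 | "        text-align: right;\n",
227 | "    }\n",
228 | "</style>\n",
229 | "<table border=\"1\" class=\"dataframe\">\n",
230 | "  <thead>\n",
231 | "    <tr style=\"text-align: right;\">\n",
232 | "      <th></th>\n",
233 | "      <th>年齢</th>\n",
234 | "      <th>性別</th>\n",
235 | "      <th>CMを見た</th>\n",
236 | "      <th>購入量</th>\n",
237 | "    </tr>\n",
238 | "  </thead>\n",
239 | "  <tbody>\n",
240 | "    <tr>\n",
241 | "      <th>0</th>\n",
242 | "      <td>62</td>\n",
243 | "      <td>0</td>\n",
244 | "      <td>1.0</td>\n",
245 | "      <td>24.464285</td>\n",
246 | "    </tr>\n",
247 | "    <tr>\n",
248 | "      <th>1</th>\n",
249 | "      <td>34</td>\n",
250 | "      <td>0</td>\n",
251 | "      <td>0.0</td>\n",
252 | "      <td>45.693411</td>\n",
253 | "    </tr>\n",
254 | "    <tr>\n",
255 | "      <th>2</th>\n",
256 | "      <td>53</td>\n",
257 | "      <td>1</td>\n",
258 | "      <td>1.0</td>\n",
259 | "      <td>64.998281</td>\n",
260 | "    </tr>\n",
261 | "    <tr>\n",
262 | "      <th>3</th>\n",
263 | "      <td>68</td>\n",
264 | "      <td>1</td>\n",
265 | "      <td>1.0</td>\n",
266 | "      <td>47.186898</td>\n",
267 | "    </tr>\n",
268 | "    <tr>\n",
269 | "      <th>4</th>\n",
270 | "      <td>27</td>\n",
271 | "      <td>1</td>\n",
272 | "      <td>0.0</td>\n",
273 | "      <td>100.114260</td>\n",
274 | "    </tr>\n",
275 | "  </tbody>\n",
276 | "</table>\n",
277 | "</div>"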
278 | ],
279 | "text/plain": [
280 | " 年齢 性別 CMを見た 購入量\n",
281 | "0 62 0 1.0 24.464285\n",
282 | "1 34 0 0.0 45.693411\n",
283 | "2 53 1 1.0 64.998281\n",
284 | "3 68 1 1.0 47.186898\n",
285 | "4 27 1 0.0 100.114260"
286 | ]
287 | },
288 | "metadata": {
289 | "tags": []
290 | },
291 | "execution_count": 6
292 | }
293 | ]
294 | },
295 | {
296 | "cell_type": "code",
297 | "metadata": {
298 | "id": "HHInQ1Sukrg0",
299 | "colab_type": "code",
300 | "outputId": "66879f86-ad4f-46cc-9a20-8edf74c860e1",
301 | "colab": {
302 | "base_uri": "https://localhost:8080/",
303 | "height": 202
304 | }
305 | },
306 | "source": [
307 | "# Compare the means\n",
308 | "\n",
309 | "print(df[df[\"CMを見た\"] == 1.0].mean())\n",
310 | "print(\"--------\")\n",
311 | "print(df[df[\"CMを見た\"] == 0.0].mean())\n"
312 | ],
313 | "execution_count": 7,
314 | "outputs": [
315 | {
316 | "output_type": "stream",
317 | "text": [
318 | "年齢 55.836066\n",
319 | "性別 0.483607\n",
320 | "CMを見た 1.000000\n",
321 | "購入量 49.711478\n",
322 | "dtype: float64\n",
323 | "--------\n",
324 | "年齢 32.141026\n",
325 | "性別 0.692308\n",
326 | "CMを見た 0.000000\n",
327 | "購入量 68.827143\n",
328 | "dtype: float64\n"
329 | ],
330 | "name": "stdout"
331 | }
332 | ]
333 | },
334 | {
335 | "cell_type": "markdown",
336 | "metadata": {
337 | "id": "kwKOk59aogBd",
338 | "colab_type": "text"
339 | },
340 | "source": [
341 | "## Classification with a decision tree\n",
342 | "\n",
343 | "We build a decision tree model that predicts whether each person saw the commercial."
344 | ]
345 | },
346 | {
347 | "cell_type": "code",
348 | "metadata": {
349 | "id": "rBtHC3smoiMC",
350 | "colab_type": "code",
351 | "outputId": "1b5ab97b-27f8-42fe-dc8b-995d9f6b5d0d",
352 | "colab": {
353 | "base_uri": "https://localhost:8080/",
354 | "height": 67
355 | }
356 | },
357 | "source": [
358 | "# Import the decision tree classifier from scikit-learn\n",
359 | "# https://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeClassifier.html\n",
360 | "from sklearn.tree import DecisionTreeClassifier\n",
361 | "\n",
362 | "# Utility for splitting the data into training and validation sets\n",
363 | "# https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.train_test_split.html\n",
364 | "from sklearn.model_selection import train_test_split\n",
365 | "\n",
366 | "\n",
367 | "# Explanatory variables\n",
368 | "X = df[[\"年齢\", \"性別\"]]\n",
369 | "\n",
370 | "# Explained (target) variable\n",
371 | "Z = df[\"CMを見た\"]\n",
372 | "\n",
373 | "# Split the data into training and validation sets\n",
374 | "X_train, X_val, Z_train, Z_val = train_test_split(\n",
375 | "    X, Z, train_size=0.6, random_state=0)\n",
376 | "\n",
377 | "# Fit and evaluate\n",
378 | "clf = DecisionTreeClassifier(max_depth=1, random_state=0)\n",
379 | "clf.fit(X_train, Z_train)\n",
380 | "print(\"Accuracy at depth 1:\", clf.score(X_val, Z_val))  # mean accuracy\n",
381 | "\n",
382 | "# Fit and evaluate\n",
383 | "clf = DecisionTreeClassifier(max_depth=2, random_state=0)\n",
384 | "clf.fit(X_train, Z_train)\n",
385 | "print(\"Accuracy at depth 2:\", clf.score(X_val, Z_val))  # mean accuracy\n",
386 | "\n",
387 | "# Fit and evaluate\n",
388 | "clf = DecisionTreeClassifier(max_depth=3, random_state=0)\n",
389 | "clf.fit(X_train, Z_train)\n",
390 | "print(\"Accuracy at depth 3:\", clf.score(X_val, Z_val))  # mean accuracy\n"
391 | ],
392 | "execution_count": 8,
393 | "outputs": [
394 | {
395 | "output_type": "stream",
396 | "text": [
397 | "Accuracy at depth 1: 0.85\n",
398 | "Accuracy at depth 2: 0.85\n",
399 | "Accuracy at depth 3: 0.825\n"
400 | ],
401 | "name": "stdout"
402 | }
403 | ]
404 | },
405 | {
406 | "cell_type": "markdown",
407 | "metadata": {
408 | "colab_type": "text",
409 | "id": "-KMbYTvx-D4N"
410 | },
411 | "source": [
412 | "## Regression with a decision tree\n",
413 | "\n",
414 | "We build a decision tree model that predicts the purchase amount."
415 | ]
416 | },
417 | {
418 | "cell_type": "code",
419 | "metadata": {
420 | "id": "CiVVt59d-gdj",
421 | "colab_type": "code",
422 | "outputId": "93ee9046-b983-4dbb-d8cb-2d86c14bf7f2",
423 | "colab": {
424 | "base_uri": "https://localhost:8080/",
425 | "height": 67
426 | }
427 | },
428 | "source": [
429 | "# Import the decision tree regressor from scikit-learn\n",
430 | "# https://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeRegressor.html#sklearn.tree.DecisionTreeRegressor\n",
431 | "from sklearn.tree import DecisionTreeRegressor\n",
432 | "\n",
433 | "# Utility for splitting the data into training and validation sets\n",
434 | "# https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.train_test_split.html\n",
435 | "from sklearn.model_selection import train_test_split\n",
436 | "\n",
437 | "\n",
438 | "# Explanatory variables\n",
439 | "X = df[[\"年齢\", \"性別\"]]\n",
440 | "\n",
441 | "# Explained (target) variable\n",
442 | "Y = df[\"購入量\"]\n",
443 | "\n",
444 | "# Split the data into training and validation sets\n",
445 | "X_train, X_val, Y_train, Y_val = train_test_split(\n",
446 | "    X, Y, train_size=0.6, random_state=0)\n",
447 | "\n",
448 | "# Fit and evaluate\n",
449 | "reg = DecisionTreeRegressor(max_depth=2, random_state=0)\n",
450 | "reg = reg.fit(X_train, Y_train)\n",
451 | "print(\"R^2 at depth 2:\", reg.score(X_val, Y_val))  # coefficient of determination R^2\n",
452 | "\n",
453 | "# Fit and evaluate\n",
454 | "reg = DecisionTreeRegressor(max_depth=3, random_state=0)\n",
455 | "reg = reg.fit(X_train, Y_train)\n",
456 | "print(\"R^2 at depth 3:\", reg.score(X_val, Y_val))  # coefficient of determination R^2\n",
457 | "\n",
458 | "# Fit and evaluate\n",
459 | "reg = DecisionTreeRegressor(max_depth=4, random_state=0)\n",
460 | "reg = reg.fit(X_train, Y_train)\n",
461 | "print(\"R^2 at depth 4:\", reg.score(X_val, Y_val))  # coefficient of determination R^2\n"
462 | ],
463 | "execution_count": 9,
464 | "outputs": [
465 | {
466 | "output_type": "stream",
467 | "text": [
468 | "R^2 at depth 2: 0.7257496664596153\n",
469 | "R^2 at depth 3: 0.7399348963931736\n",
470 | "R^2 at depth 4: 0.7165539691159019\n"
471 | ],
472 | "name": "stdout"
473 | }
474 | ]
475 | },
476 | {
477 | "cell_type": "markdown",
478 | "metadata": {
479 | "colab_type": "text",
480 | "id": "1LHqDZTHyMeA"
481 | },
482 | "source": [
483 | "## Classification with a random forest\n",
484 | "\n",
485 | "We build a random forest model that predicts whether each person saw the commercial."
486 | ]
487 | },
488 | {
489 | "cell_type": "code",
490 | "metadata": {
491 | "id": "QZCX_vszyRIF",
492 | "colab_type": "code",
493 | "outputId": "fcc9656b-0c0f-4884-d682-992ba88bb9a7",
494 | "colab": {
495 | "base_uri": "https://localhost:8080/",
496 | "height": 67
497 | }
498 | },
499 | "source": [
500 | "# Import the random forest classifier from scikit-learn\n",
501 | "# https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestClassifier.html?highlight=randomforest\n",
502 | "from sklearn.ensemble import RandomForestClassifier\n",
503 | "from sklearn.model_selection import train_test_split\n",
504 | "\n",
505 | "# Explanatory variables\n",
506 | "X = df[[\"年齢\", \"性別\"]]\n",
507 | "\n",
508 | "# Explained (target) variable\n",
509 | "Z = df[\"CMを見た\"]\n",
510 | "\n",
511 | "# Split the data into training and validation sets\n",
512 | "X_train, X_val, Z_train, Z_val = train_test_split(\n",
513 | "    X, Z, train_size=0.6, random_state=0)\n",
514 | "\n",
515 | "# Fit and evaluate\n",
516 | "clf = RandomForestClassifier(max_depth=1, random_state=0)\n",
517 | "clf.fit(X_train, Z_train)\n",
518 | "print(\"Accuracy at depth 1:\", clf.score(X_val, Z_val))  # mean accuracy\n",
519 | "\n",
520 | "# Fit and evaluate\n",
521 | "clf = RandomForestClassifier(max_depth=2, random_state=0)\n",
522 | "clf.fit(X_train, Z_train)\n",
523 | "print(\"Accuracy at depth 2:\", clf.score(X_val, Z_val))  # mean accuracy\n",
524 | "\n",
525 | "# Fit and evaluate\n",
526 | "clf = RandomForestClassifier(max_depth=3, random_state=0)\n",
527 | "clf.fit(X_train, Z_train)\n",
528 | "print(\"Accuracy at depth 3:\", clf.score(X_val, Z_val))  # mean accuracy\n"
529 | ],
530 | "execution_count": 10,
531 | "outputs": [
532 | {
533 | "output_type": "stream",
534 | "text": [
535 | "Accuracy at depth 1: 0.775\n",
536 | "Accuracy at depth 2: 0.85\n",
537 | "Accuracy at depth 3: 0.825\n"
538 | ],
539 | "name": "stdout"
540 | }
541 | ]
542 | },
543 | {
544 | "cell_type": "markdown",
545 | "metadata": {
546 | "colab_type": "text",
547 | "id": "OuTwc5Kt4AiW"
548 | },
549 | "source": [
550 | "## Regression with a random forest\n",
551 | "\n",
552 | "We build a random forest model that predicts the purchase amount."
553 | ]
554 | },
555 | {
556 | "cell_type": "code",
557 | "metadata": {
558 | "id": "evnAhZnb4DXj",
559 | "colab_type": "code",
560 | "outputId": "1b439677-66ce-49dc-8760-70916ab7d2fc",
561 | "colab": {
562 | "base_uri": "https://localhost:8080/",
563 | "height": 67
564 | }
565 | },
566 | "source": [
567 | "# Import the random forest regressor from scikit-learn\n",
568 | "# https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestRegressor.html?highlight=randomforest\n",
569 | "from sklearn.ensemble import RandomForestRegressor\n",
570 | "from sklearn.model_selection import train_test_split\n",
571 | "\n",
572 | "\n",
573 | "# Explanatory variables\n",
574 | "X = df[[\"年齢\", \"性別\"]]\n",
575 | "\n",
576 | "# Explained (target) variable\n",
577 | "Y = df[\"購入量\"]\n",
578 | "\n",
579 | "# Split the data into training and validation sets\n",
580 | "X_train, X_val, Y_train, Y_val = train_test_split(\n",
581 | "    X, Y, train_size=0.6, random_state=0)\n",
582 | "\n",
583 | "# Fit and evaluate\n",
584 | "reg = RandomForestRegressor(max_depth=2, random_state=0)\n",
585 | "reg = reg.fit(X_train, Y_train)\n",
586 | "print(\"R^2 at depth 2:\", reg.score(X_val, Y_val))  # coefficient of determination R^2\n",
587 | "\n",
588 | "# Fit and evaluate\n",
589 | "reg = RandomForestRegressor(max_depth=3, random_state=0)\n",
590 | "reg = reg.fit(X_train, Y_train)\n",
591 | "print(\"R^2 at depth 3:\", reg.score(X_val, Y_val))  # coefficient of determination R^2\n",
592 | "\n",
593 | "# Fit and evaluate\n",
594 | "reg = RandomForestRegressor(max_depth=4, random_state=0)\n",
595 | "reg = reg.fit(X_train, Y_train)\n",
596 | "print(\"R^2 at depth 4:\", reg.score(X_val, Y_val))  # coefficient of determination R^2\n"
597 | ],
598 | "execution_count": 11,
599 | "outputs": [
600 | {
601 | "output_type": "stream",
602 | "text": [
603 | "R^2 at depth 2: 0.7618786062003249\n",
604 | "R^2 at depth 3: 0.7810610687821996\n",
605 | "R^2 at depth 4: 0.7655149049335735\n"
606 | ],
607 | "name": "stdout"
608 | }
609 | ]
610 | },
611 | {
612 | "cell_type": "markdown",
613 | "metadata": {
614 | "id": "1IdVhXmMps-w",
615 | "colab_type": "text"
616 | },
617 | "source": [
618 | "End of notebook."
619 | ]
620 | }
621 | ]
622 | }
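
The mean comparison in the notebook above shows viewers of the commercial buying less (about 49.7 versus 68.8), even though the purchase amount is generated with a +10 effect of viewing. As a minimal sketch of why, the script below re-creates the same structural equations using only the standard library and measures the naive difference in means. The names `simulate`, `naive_difference`, and `TRUE_EFFECT`, the sample size of 20000, and the use of `random` and `math` in place of NumPy and SciPy are assumptions made for illustration; they are not part of the notebook, and the random draws will not match the notebook's values.

```python
import math
import random

TRUE_EFFECT = 10.0  # coefficient on Z in the notebook's outcome equation


def simulate(num_data, seed=1234):
    """Draw (age, gender, Z, Y) rows from the notebook's structural equations."""
    rng = random.Random(seed)
    rows = []
    for _ in range(num_data):
        x_1 = rng.randint(15, 75)  # age, uniform integers 15..75 inclusive
        x_2 = rng.randint(0, 1)    # gender (0 = female, 1 = male)
        # Older people and women are more likely to see the commercial.
        z_base = x_1 + (1 - x_2) * 10 - 40 + 5 * rng.gauss(0, 1)
        z_prob = 1.0 / (1.0 + math.exp(-0.1 * z_base))  # sigmoid, same role as expit
        z = 1 if rng.random() < z_prob else 0
        # Purchases fall with age, rise for men, and rise by TRUE_EFFECT when Z = 1.
        y = -x_1 + 30 * x_2 + TRUE_EFFECT * z + 80 + 10 * rng.gauss(0, 1)
        rows.append((x_1, x_2, z, y))
    return rows


def naive_difference(rows):
    """Mean outcome of the treated minus mean outcome of the untreated."""
    treated = [y for _, _, z, y in rows if z == 1]
    control = [y for _, _, z, y in rows if z == 0]
    return sum(treated) / len(treated) - sum(control) / len(control)


rows = simulate(20000)
naive = naive_difference(rows)
print("true effect of the commercial:", TRUE_EFFECT)
print("naive difference in means:", round(naive, 2))
```

Because age and gender influence both viewing and purchasing, the raw difference in means mixes their influence with the effect of the commercial, which is why it can come out negative even though the true effect is positive.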
--------------------------------------------------------------------------------
/5_3_doubly_robust_learning_issue18.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "kernelspec": {
6 | "name": "python3",
7 | "display_name": "Python 3"
8 | },
9 | "colab": {
10 | "name": "5_3_doubly_robust_learning_issue18.ipynb",
11 | "provenance": [],
12 | "collapsed_sections": []
13 | }
14 | },
15 | "cells": [
16 | {
17 | "cell_type": "markdown",
18 | "metadata": {
19 | "id": "aoxI3DOK9vm2"
20 | },
21 | "source": [
22 | "# 5.3 Implementing Doubly Robust Learning\n",
23 | "\n",
24 | "This file implements Section 5.3.\n",
25 | "\n",
26 | "As in Section 5.2, we run causal inference on the effect of an HR training program."
27 | ]
28 | },
29 | {
30 | "cell_type": "markdown",
31 | "metadata": {
32 | "id": "2XdIDbdlejUk"
33 | },
34 | "source": [
35 | "## Setup before running the program"
36 | ]
37 | },
38 | {
39 | "cell_type": "code",
40 | "metadata": {
41 | "id": "XZFKJwcu-_Oj"
42 | },
43 | "source": [
44 | "# Set the random seeds\n",
45 | "import random\n",
46 | "import numpy as np\n",
47 | "\n",
48 | "np.random.seed(1234)\n",
49 | "random.seed(1234)\n"
50 | ],
51 | "execution_count": 1,
52 | "outputs": []
53 | },
54 | {
55 | "cell_type": "code",
56 | "metadata": {
57 | "id": "hx1idArc_F15"
58 | },
59 | "source": [
60 | "# Packages (libraries and functions) used in this notebook\n",
61 | "# For generating standard normal random numbers\n",
62 | "from numpy.random import *\n",
63 | "\n",
64 | "# For plotting\n",
65 | "import matplotlib.pyplot as plt\n",
66 | "\n",
67 | "# SciPy, for standardizing to mean 0 and variance 1\n",
68 | "import scipy.stats\n",
69 | "\n",
70 | "# Import the sigmoid function\n",
71 | "from scipy.special import expit\n",
72 | "\n",
73 | "# Other\n",
74 | "import pandas as pd\n"
75 | ],
76 | "execution_count": 2,
77 | "outputs": []
78 | },
79 | {
80 | "cell_type": "markdown",
81 | "metadata": {
82 | "id": "AWqP6yeQlI_t"
83 | },
84 | "source": [
85 | "## Creating the data"
86 | ]
87 | },
88 | {
89 | "cell_type": "code",
90 | "metadata": {
91 | "id": "DpnGB2KZ_L8x",
92 | "outputId": "b67517df-c4d8-4c40-851b-669d4aefabdc",
93 | "colab": {
94 | "base_uri": "https://localhost:8080/",
95 | "height": 282
96 | }
97 | },
98 | "source": [
99 | "# Number of samples\n",
100 | "num_data = 500\n",
101 | "\n",
102 | "# Manager's enthusiasm for developing subordinates\n",
103 | "x = np.random.uniform(low=-1, high=1, size=num_data)  # uniform on [-1, 1)\n",
104 | "\n",
105 | "# Whether the manager attended the training 'For managers: key points of career interviews with subordinates'\n",
106 | "e_z = randn(num_data)  # generate noise\n",
107 | "z_prob = expit(-1*-5.0*x+5*e_z)  # fixes the reversed sign of x's effect (Issue #18)\n",
108 | "Z = np.array([])\n",
109 | "\n",
110 | "# Draw each manager's attendance (1 = attended) with probability z_prob\n",
111 | "for i in range(num_data):\n",
112 | "    Z_i = np.random.choice(2, size=1, p=[1-z_prob[i], z_prob[i]])[0]\n",
113 | "    Z = np.append(Z, Z_i)\n",
114 | "\n",
115 | "# Nonlinear treatment effect: changes stepwise with the enthusiasm x\n",
116 | "t = np.zeros(num_data)\n",
117 | "for i in range(num_data):\n",
118 | "    if x[i] < 0:\n",
119 | "        t[i] = 0.5\n",
120 | "    elif x[i] >= 0 and x[i] < 0.5:\n",
121 | "        t[i] = 0.7\n",
122 | "    elif x[i] >= 0.5:\n",
123 | "        t[i] = 1.0\n",
124 | "\n",
125 | "e_y = randn(num_data)\n",
126 | "Y = 2.0 + t*Z + 0.3*x + 0.1*e_y\n",
127 | "\n",
128 | "# Check the treatment effect in a plot\n",
129 | "plt.scatter(x, t, label=\"treatment-effect\")\n"
130 | ],
131 | "execution_count": 3,
132 | "outputs": [
133 | {
134 | "output_type": "execute_result",
135 | "data": {
136 | "text/plain": [
137 | ""
138 | ]
139 | },
140 | "metadata": {
141 | "tags": []
142 | },
143 | "execution_count": 3
144 | },
145 | {
146 | "output_type": "display_data",
147 | "data": {
148 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAASc0lEQVR4nO3df5BdZX3H8fc3mwQjVkPIaiUJJDgRTRsFvYO0zBT8SWA6SfzZZMoIlppqxc60yhRGxjIMjra049BKa5GhqNggUqXrFCejAuOMQ5ClCBiYQIgjJFBZQZxpoZAf3/5xT+R0c3fv3ey5e7MP79fMzt7zPM8955vn3nz27Dnn7onMRJI0+80ZdAGSpGYY6JJUCANdkgphoEtSIQx0SSrE3EFtePHixbl8+fJBbV6SZqW77rrrF5k53KlvYIG+fPlyRkdHB7V5SZqVIuJnE/V5yEWSCmGgS1IhDHRJKoSBLkmFMNAlqRBdr3KJiGuA3weeyMzf7tAfwBXAWcAzwLmZ+Z9NFyrp8HTxTffxtTseYbp/52/unGDv/slXEkACEUx7e4O2cME8Lln7W6w/aUlj6+xlD/1aYM0k/WcCK6uvTcA/Tb8sSbPBxTfdx3Vbpx/mQNcwh3aYw+wPc4Cnn93DBd+4h5vu3t3YOrsGemb+AHhqkiHrgK9k21ZgYUS8uqkCJR2+Nt/x6KBLmNX27E8u37K9sfU1cQx9CVB/VXdVbQeJiE0RMRoRo2NjYw1sWtIg7SthV3nAHnv62cbWNaMnRTPzqsxsZWZreLjjJ1clzSJDEYMuYdY7ZuGCxtbVRKDvBpbVlpdWbZIKt/Ety7oP0oTmzQkuOOOExtbXRKCPAB+MtlOAX2Xm4w2sV9Jh7rL1qzn7lGNpYkd97pzuKzkwooRfDBYumMfl739jo1e5RLd7ikbEZuB0YDHwc+CvgHkAmfnF6rLFL9C+EuYZ4EOZ2fWvbrVarfSPc0nS1ETEXZnZ6tTX9Tr0zNzYpT+Bjx1ibZKkhvhJUUkqhIEuSYUw0CWpEAa6JBXCQJekQhjoklQIA12SCmGgS1IhDHRJKoSBLkmFMNAlqRAGuiQVwkCXpEIY6JJUCANdkgphoEtSIQx0SSqEgS5JhTDQJakQBrokFcJAl6RCGOiSVAgDXZIKYaBLUiEMdEkqhIEuSYUw0CWpEAa6JBXCQJekQhjoklQIA12SCtFToEfEmojYHhE7IuLCDv3HRcT3I+LeiLgtIpY2X6okaTJdAz0ihoArgTOBVcDGiFg1btjfAl/JzDcAlwKfbbpQSdLketlDPxnYkZk7M/N54Hpg3bgxq4Bbqse3duiXJPVZL4G+BHi0tryraqu7B3hP9fjdwG9ExNHjVxQRmyJiNCJGx8bGDqVeSdIEmjop+kngtIi4GzgN2A3sGz8oM6/KzFZmtoaHhxvatCQJYG4PY3YDy2rLS6u2X8vMx6j20CPiZcB7M/PppoqUJHXXyx76ncDKiFgREfOBDcBIfUBELI6IA+u6CLim2TIlSd10DfTM3AucD2wBHgBuyMxtEXFpRKythp0ObI+IB4FXAZ/pU72SpAlEZg5kw61WK0dHRweybUmarSLirsxsderzk6KSVAgDXZIKYaBLUiEMdEkqhIEuSYUw0CWpEAa6JBXCQJekQhjoklQIA12SCmGgS1IhDHRJKoSBLkmFMNAlqRAGuiQVwkCXpEIY6JJUCANdkgphoEtSIQx0SSqEgS5JhTDQJakQBrokFcJAl6RCGOiSVAgDXZIKYaBLUiEMdEkqhIEuSYUw0CWpED0FekSsiYjtEbEjIi7s0H9sRNwaEXdHxL0RcVbzpUqSJtM10CNiCLgSOBNYBWyMiFXjhl0M3JCZJwEbgH9sulBJ0uR62UM/GdiRmTsz83ngemDduDEJvLx6/ArgseZKlCT1opdAXwI8WlveVbXVXQKcHRG7gJuBj3daUURs
iojRiBgdGxs7hHIlSRNp6qToRuDazFwKnAV8NSIOWndmXpWZrcxsDQ8PN7RpSRL0Fui7gWW15aVVW915wA0AmXk78BJgcRMFSpJ600ug3wmsjIgVETGf9knPkXFjHgHeDhARr6cd6B5TkaQZ1DXQM3MvcD6wBXiA9tUs2yLi0ohYWw37BPDhiLgH2Aycm5nZr6IlSQeb28ugzLyZ9snOetuna4/vB05ttjRJ0lT4SVFJKoSBLkmFMNAlqRAGuiQVwkCXpEIY6JJUCANdkgphoEtSIQx0SSqEgS5JhTDQJakQBrokFcJAl6RCGOiSVAgDXZIKYaBLUiEMdEkqhIEuSYUw0CWpEAa6JBXCQJekQhjoklQIA12SCmGgS1IhDHRJKoSBLkmFMNAlqRAGuiQVwkCXpEIY6JJUCANdkgoxt5dBEbEGuAIYAq7OzM+N6/888NZq8aXAKzNzYZOFSv128U33cd3WRwZdxqx06msW8bUP/86gy3jR67qHHhFDwJXAmcAqYGNErKqPycw/z8wTM/NE4B+Ab/ajWKlfDPPp+eHDT/GHX7p90GW86PVyyOVkYEdm7szM54HrgXWTjN8IbG6iOGmmbL7j0UGXMOv98OGnBl3Ci14vgb4EqL/bd1VtB4mI44AVwC0T9G+KiNGIGB0bG5tqrVLf7MscdAnStDV9UnQDcGNm7uvUmZlXZWYrM1vDw8MNb1o6dEMRgy5BmrZeAn03sKy2vLRq62QDHm7RLLTxLcu6D9KkTn3NokGX8KLXS6DfCayMiBURMZ92aI+MHxQRrwOOAjwzolnnsvWrOfuUYwddxqzlVS6Hh66XLWbm3og4H9hC+7LFazJzW0RcCoxm5oFw3wBcn+nBSM1Ol61fzWXrVw+6DOmQ9XQdembeDNw8ru3T45Yvaa4sSdJU+UlRSSqEgS5JhTDQJakQBrokFcJAl6RCGOiSVAgDXZIKYaBLUiEMdEkqhIEuSYUw0CWpEAa6JBXCQJekQhjoklQIA12SCmGgS1IhDHRJKoSBLkmFMNAlqRAGuiQVwkCXpEIY6JJUCANdkgphoEtSIQx0SSqEgS5JhTDQJakQBrokFcJAl6RCGOiSVIieAj0i1kTE9ojYEREXTjDmAxFxf0Rsi4h/bbZMSVI3c7sNiIgh4ErgncAu4M6IGMnM+2tjVgIXAadm5i8j4pX9KliS1Fkve+gnAzsyc2dmPg9cD6wbN+bDwJWZ+UuAzHyi2TIlSd30EuhLgEdry7uqtrrXAq+NiB9GxNaIWNNpRRGxKSJGI2J0bGzs0CqWJHXU1EnRucBK4HRgI/CliFg4flBmXpWZrcxsDQ8PN7RpSRL0Fui7gWW15aVVW90uYCQz92TmT4EHaQe8JGmG9BLodwIrI2JFRMwHNgAj48bcRHvvnIhYTPsQzM4G65QkddE10DNzL3A+sAV4ALghM7dFxKURsbYatgV4MiLuB24FLsjMJ/tVtCTpYJGZA9lwq9XK0dHRgWxbkmariLgrM1ud+vykqCQVwkCXpEIY6JJUCANdkgphoEtSIQx0SSqEgS5JhTDQJakQBrokFcJAl6RCGOiSVAgDXZIKYaBLUiEMdEkqhIEuSYUw0CWpEAa6JBXCQJekQhjoklQIA12SCmGgS1IhDHRJKoSBLkmFMNAlqRAGuiQVwkCXpEIY6JJUCANdkgphoEtSIQx0SSqEgS5JhZjby6CIWANcAQwBV2fm58b1nwtcDuyumr6QmVc3WCcAN929m8u3bGf30882vepJHTl/iMzkmT37Z3S707Vg3hze++alfP1Hj9BE6XOArL6atmThApYfvYDbdz7F/n5sADhi7hye29t9Io6cP8Rn3r2a9Sct6U8hUp90DfSIGAKuBN4J7ALujIiRzLx/3NCvZ+b5fagRaIf5Rd+8j2f37OvXJib0P8/P/Dab8Oye/Vy39ZHG1tfPH2e7n3627z+oewlzaL/en/jGPQCGumaVXg65nAzsyMydmfk8cD2w
rr9lHezyLdsHEuZ6cdq3P7l8y/ZBlyFNSS+BvgR4tLa8q2ob770RcW9E3BgRyzqtKCI2RcRoRIyOjY1NqdDHZvgwi+R7TrNNUydFvw0sz8w3AN8FvtxpUGZelZmtzGwNDw9PaQPHLFww/SqlKfA9p9mml0DfDdT3uJfywslPADLzycx8rlq8GnhzM+W94IIzTmDBvKGmVyt1NDQnuOCMEwZdhjQlvQT6ncDKiFgREfOBDcBIfUBEvLq2uBZ4oLkS29aftITPvmc1Swaw13Tk/CFeOm/2XeG5YN4czj7lWJoqfQ4QzazqIEsWLuDU1yxiTr82QPsql14cOX+Iv3v/Gz0hqlmn61Uumbk3Is4HttC+bPGazNwWEZcCo5k5AvxZRKwF9gJPAef2o9j1Jy3xP9khuGz96kGXIGkGRGafLvrtotVq5ejo6EC2LUmzVUTclZmtTn2z7ziCJKkjA12SCmGgS1IhDHRJKsTATopGxBjws0N8+mLgFw2W0xTrmhrrmhrrmrrDtbbp1HVcZnb8ZObAAn06ImJ0orO8g2RdU2NdU2NdU3e41tavujzkIkmFMNAlqRCzNdCvGnQBE7CuqbGuqbGuqTtca+tLXbPyGLok6WCzdQ9dkjSOgS5JhTgsAz0i3h8R2yJif0RMeGlPRKyJiO0RsSMiLqy1r4iIO6r2r1d/9rep2hZFxHcj4qHq+1Edxrw1In5c+/rfiFhf9V0bET+t9Z04U3VV4/bVtj1Sa+/LnPU4XydGxO3Va35vRPxBra/R+ZroPVPrP6L69++o5mN5re+iqn17RJwxnToOoa6/iIj7q/n5fkQcV+vr+JrOUF3nRsRYbft/XOs7p3rdH4qIc2a4rs/XanowIp6u9fVzvq6JiCci4icT9EdE/H1V970R8aZa3/TnKzMPuy/g9cAJwG1Aa4IxQ8DDwPHAfOAeYFXVdwOwoXr8ReCjDdb2N8CF1eMLgb/uMn4R7T8p/NJq+VrgfX2Ys57qAv57gva+zFkvdQGvBVZWj48BHgcWNj1fk71namP+FPhi9XgD7ZufA6yqxh8BrKjWMzSDdb219h766IG6JntNZ6iuc4EvdHjuImBn9f2o6vFRM1XXuPEfp/1nv/s6X9W6fw94E/CTCfrPAr5D+9YCpwB3NDlfh+UeemY+kJnd7tDb8ebVERHA24Abq3FfBtY3WN46XrjFXi/rfh/wncx8psEaOplqXb/W5znrWldmPpiZD1WPHwOeAKZ2j8Le9HLD83q9NwJvr+ZnHXB9Zj6XmT8FdlTrm5G6MvPW2ntoK+07h/XbdG4Qfwbw3cx8KjN/SfvWlGsGVNdGYHND255UZv6A9g7cRNYBX8m2rcDCaN8gqJH5OiwDvUcT3bz6aODpzNw7rr0pr8rMx6vH/wW8qsv4DRz8ZvpM9evW5yPiiBmu6yXRvlH31gOHgejvnE1pviLiZNp7XQ/Xmpuar15ueP7rMdV8/Ir2/PR6s/R+1VV3Hu29vAM6vaYzWVenG8QfFvNVHZpaAdxSa+7XfPViotobma+udyzql4j4HvCbHbo+lZn/PtP11E1WW30hMzMiJrzus/rJu5r23Z4OuIh2sM2nfS3qXwKXzmBdx2Xm7og4HrglIu6jHVqHrOH5+ipwTmbur5oPeb5KFBFnAy3gtFrzQa9pZj7ceQ2N+zawOTOfi4g/of3bzdtmaNu92ADcmJn7am2DnK++GligZ+Y7prmKiW5e/STtX2PmVntYB93Uejq1RcTPI+LVmfl4FUBPTLKqDwDfysw9tXUf2Ft9LiL+BfjkTNaVmbur7zsj4jbgJODfmMacNVFXRLwc+A/aP9C31tZ9yPPVQdcbntfG7IqIucAraL+nenluP+siIt5B+4fkafnCTdknek2bCKiebhBfW7ya9jmTA889fdxzb2ugpp7qqtkAfKze0Mf56sVEtTcyX7P5kEvHm1dn+wzDrbSPXQOcAzS5xz9SrbOXdR907K4KtQPHrdcDHc+G
96OuiDjqwCGLiFgMnArc3+c566Wu+cC3aB9bvHFcX5Pz1fWG5+PqfR9wSzU/I8CGaF8FswJYCfxoGrVMqa6IOAn4Z2BtZj5Ra+/4ms5gXRPdIH4L8K6qvqOAd/H/f1Pta11Vba+jfYLx9lpbP+erFyPAB6urXU4BflXttDQzX/062zudL+DdtI8hPQf8HNhStR8D3FwbdxbwIO2frp+qtR9P+z/bDuAbwBEN1nY08H3gIeB7wKKqvQVcXRu3nPZP3Tnjnn8LcB/tYLoOeNlM1QX8brXte6rv5/V7znqs62xgD/Dj2teJ/ZivTu8Z2odw1laPX1L9+3dU83F87bmfqp63HTiz4fd8t7q+V/1fODA/I91e0xmq67PAtmr7twKvqz33j6p53AF8aCbrqpYvAT437nn9nq/NtK/S2kM7w84DPgJ8pOoP4Mqq7vuoXcXXxHz50X9JKsRsPuQiSaox0CWpEAa6JBXCQJekQhjoklQIA12SCmGgS1Ih/g9Sp439On0bGgAAAABJRU5ErkJggg==\n",
149 | "text/plain": [
150 | ""
151 | ]
152 | },
153 | "metadata": {
154 | "tags": [],
155 | "needs_background": "light"
156 | }
157 | }
158 | ]
159 | },
160 | {
161 | "cell_type": "markdown",
162 | "metadata": {
163 | "id": "BHcdUlW9koTa"
164 | },
165 | "source": [
166 | "## Building a summary table and visualizing the data"
167 | ]
168 | },
169 | {
170 | "cell_type": "code",
171 | "metadata": {
172 | "id": "1EMwdGIIIPrK",
173 | "outputId": "c3575b0b-6d6d-4bbf-b59c-321d7e5978e2",
174 | "colab": {
175 | "base_uri": "https://localhost:8080/",
176 | "height": 195
177 | }
178 | },
179 | "source": [
180 | "df = pd.DataFrame({'x': x,\n",
181 | "                   'Z': Z,\n",
182 | "                   't': t,\n",
183 | "                   'Y': Y,\n",
184 | "                   })\n",
185 | "\n",
186 | "df.head()  # show the first rows\n"
187 | ],
188 | "execution_count": 4,
189 | "outputs": [
190 | {
191 | "output_type": "execute_result",
192 | "data": {
193 | "text/html": [
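194 | "<div>\n",
195 | "<style scoped>\n",
196 | "    .dataframe tbody tr th:only-of-type {\n",
197 | "        vertical-align: middle;\n",
198 | "    }\n",
199 | "\n",
200 | "    .dataframe tbody tr th {\n",
201 | "        vertical-align: top;\n",
202 | "    }\n",
203 | "\n",
204 | "    .dataframe thead th {\n",
205 | "        text-align: right;\n",
206 | "    }\n",
207 | "</style>\n",
208 | "<table border=\"1\" class=\"dataframe\">\n",
209 | "  <thead>\n",
210 | "    <tr style=\"text-align: right;\">\n",
211 | "      <th></th>\n",
212 | "      <th>x</th>\n",
213 | "      <th>Z</th>\n",
214 | "      <th>t</th>\n",
215 | "      <th>Y</th>\n",
216 | "    </tr>\n",
217 | "  </thead>\n",
218 | "  <tbody>\n",
219 | "    <tr>\n",
220 | "      <th>0</th>\n",
221 | "      <td>-0.616961</td>\n",
222 | "      <td>0.0</td>\n",
223 | "      <td>0.5</td>\n",
224 | "      <td>1.803183</td>\n",
225 | "    </tr>\n",
226 | "    <tr>\n",
227 | "      <th>1</th>\n",
228 | "      <td>0.244218</td>\n",
229 | "      <td>1.0</td>\n",
230 | "      <td>0.7</td>\n",
231 | "      <td>2.668873</td>\n",
232 | "    </tr>\n",
233 | "    <tr>\n",
234 | "      <th>2</th>\n",
235 | "      <td>-0.124545</td>\n",
236 | "      <td>0.0</td>\n",
237 | "      <td>0.5</td>\n",
238 | "      <td>2.193123</td>\n",
239 | "    </tr>\n",
240 | "    <tr>\n",
241 | "      <th>3</th>\n",
242 | "      <td>0.570717</td>\n",
243 | "      <td>1.0</td>\n",
244 | "      <td>1.0</td>\n",
245 | "      <td>3.245229</td>\n",
246 | "    </tr>\n",
247 | "    <tr>\n",
248 | "      <th>4</th>\n",
249 | "      <td>0.559952</td>\n",
250 | "      <td>1.0</td>\n",
251 | "      <td>1.0</td>\n",
252 | "      <td>3.139868</td>\n",
253 | "    </tr>\n",
254 | "  </tbody>\n",
255 | "</table>\n",
256 | "</div>"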
257 | ],
258 | "text/plain": [
259 | " x Z t Y\n",
260 | "0 -0.616961 0.0 0.5 1.803183\n",
261 | "1 0.244218 1.0 0.7 2.668873\n",
262 | "2 -0.124545 0.0 0.5 2.193123\n",
263 | "3 0.570717 1.0 1.0 3.245229\n",
264 | "4 0.559952 1.0 1.0 3.139868"
265 | ]
266 | },
267 | "metadata": {
268 | "tags": []
269 | },
270 | "execution_count": 4
271 | }
272 | ]
273 | },
274 | {
275 | "cell_type": "code",
276 | "metadata": {
277 | "id": "L6Tb2Hjk9vno",
278 | "outputId": "dd55119a-872c-4dc1-ebc7-0f66a3bab881",
279 | "colab": {
280 | "base_uri": "https://localhost:8080/",
281 | "height": 282
282 | }
283 | },
284 | "source": [
285 | "plt.scatter(x, Y)\n"
286 | ],
287 | "execution_count": 5,
288 | "outputs": [
289 | {
290 | "output_type": "execute_result",
291 | "data": {
292 | "text/plain": [
293 | ""
294 | ]
295 | },
296 | "metadata": {
297 | "tags": []
298 | },
299 | "execution_count": 5
300 | },
301 | {
302 | "output_type": "display_data",
303 | "data": {
304 | "image/png": "\n",
305 | "text/plain": [
306 | ""
307 | ]
308 | },
309 | "metadata": {
310 | "tags": [],
311 | "needs_background": "light"
312 | }
313 | }
314 | ]
315 | },
316 | {
317 | "cell_type": "markdown",
318 | "metadata": {
319 | "id": "AeC7Uv29KsXC"
320 | },
321 | "source": [
322 | "## Starting the DR-Learner: first, the T-Learner"
323 | ]
324 | },
325 | {
326 | "cell_type": "code",
327 | "metadata": {
328 | "id": "xp2P-IDT9vql",
329 | "outputId": "122ee358-9b12-4166-94aa-38ad5ef19208",
330 | "colab": {
331 | "base_uri": "https://localhost:8080/"
332 | }
333 | },
334 | "source": [
335 | "# Build random forest models\n",
336 | "from sklearn.ensemble import RandomForestRegressor\n",
337 | "\n",
338 | "# Split the population into two groups\n",
339 | "df_0 = df[df.Z == 0.0]  # untreated group\n",
340 | "df_1 = df[df.Z == 1.0]  # treated group\n",
341 | "\n",
342 | "# Outcome model for the untreated group\n",
343 | "M_0 = RandomForestRegressor(max_depth=3)\n",
344 | "M_0.fit(df_0[[\"x\"]], df_0[[\"Y\"]])\n",
345 | "\n",
346 | "# Outcome model for the treated group\n",
347 | "M_1 = RandomForestRegressor(max_depth=3)\n",
348 | "M_1.fit(df_1[[\"x\"]], df_1[[\"Y\"]])\n"
349 | ],
350 | "execution_count": 6,
351 | "outputs": [
352 | {
353 | "output_type": "stream",
354 | "text": [
355 | "/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:10: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
356 | " # Remove the CWD from sys.path while we load stuff.\n",
357 | "/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:14: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
358 | " \n"
359 | ],
360 | "name": "stderr"
361 | },
362 | {
363 | "output_type": "execute_result",
364 | "data": {
365 | "text/plain": [
366 | "RandomForestRegressor(bootstrap=True, ccp_alpha=0.0, criterion='mse',\n",
367 | " max_depth=3, max_features='auto', max_leaf_nodes=None,\n",
368 | " max_samples=None, min_impurity_decrease=0.0,\n",
369 | " min_impurity_split=None, min_samples_leaf=1,\n",
370 | " min_samples_split=2, min_weight_fraction_leaf=0.0,\n",
371 | " n_estimators=100, n_jobs=None, oob_score=False,\n",
372 | " random_state=None, verbose=0, warm_start=False)"
373 | ]
374 | },
375 | "metadata": {
376 | "tags": []
377 | },
378 | "execution_count": 6
379 | }
380 | ]
381 | },
382 | {
383 | "cell_type": "code",
384 | "metadata": {
385 | "id": "wAeHIJiqOF-h"
386 | },
387 | "source": [
388 | "# Estimate the propensity score\n",
389 | "from sklearn.linear_model import LogisticRegression\n",
390 | "\n",
391 | "# Explanatory variable\n",
392 | "X = df[[\"x\"]]\n",
393 | "\n",
394 | "# Explained (target) variable\n",
395 | "Z = df[\"Z\"]\n",
396 | "\n",
397 | "# Fit the logistic regression\n",
398 | "g_x = LogisticRegression().fit(X, Z)\n",
399 | "g_x_val = g_x.predict_proba(X)\n"
400 | ],
401 | "execution_count": 7,
402 | "outputs": []
403 | },
404 | {
405 | "cell_type": "markdown",
406 | "metadata": {
407 | "id": "xTjMfuZTNrLO"
408 | },
409 | "source": [
410 | "## Estimation based on the DR method"
411 | ]
412 | },
413 | {
414 | "cell_type": "code",
415 | "metadata": {
416 | "id": "jwEhxJQeNvhw",
417 | "outputId": "a56db28c-a884-489c-b597-0d7190281fbf",
418 | "colab": {
419 | "base_uri": "https://localhost:8080/",
420 | "height": 406
421 | }
422 | },
423 | "source": [
424 | "# Treated group\n",
425 | "Y_1 = M_1.predict(df_1[[\"x\"]]) + (df_1[\"Y\"] - M_1.predict(df_1[[\"x\"]])) / \\\n",
426 | "    g_x.predict_proba(df_1[[\"x\"]])[:, 1]  # [:, 1] is the probability of Z=1\n",
427 | "df_1[\"ITE\"] = Y_1 - M_0.predict(df_1[[\"x\"]])\n",
428 | "\n",
429 | "# Untreated group\n",
430 | "Y_0 = M_0.predict(df_0[[\"x\"]]) + (df_0[\"Y\"] - M_0.predict(df_0[[\"x\"]])) / \\\n",
431 | "    g_x.predict_proba(df_0[[\"x\"]])[:, 0]  # [:, 0] is the probability of Z=0\n",
432 | "df_0[\"ITE\"] = M_1.predict(df_0[[\"x\"]]) - Y_0\n",
433 | "\n",
434 | "# Concatenate the two tables\n",
435 | "df_DR = pd.concat([df_0, df_1])\n",
436 | "df_DR.head()\n"
437 | ],
438 | "execution_count": 8,
439 | "outputs": [
440 | {
441 | "output_type": "stream",
442 | "text": [
443 | "/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:3: SettingWithCopyWarning: \n",
444 | "A value is trying to be set on a copy of a slice from a DataFrame.\n",
445 | "Try using .loc[row_indexer,col_indexer] = value instead\n",
446 | "\n",
447 | "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
448 | " This is separate from the ipykernel package so we can avoid doing imports until\n",
449 | "/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:7: SettingWithCopyWarning: \n",
450 | "A value is trying to be set on a copy of a slice from a DataFrame.\n",
451 | "Try using .loc[row_indexer,col_indexer] = value instead\n",
452 | "\n",
453 | "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
454 | " import sys\n"
455 | ],
456 | "name": "stderr"
457 | },
458 | {
459 | "output_type": "execute_result",
460 | "data": {
461 | "text/html": [
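462 | "<div>\n",
463 | "<style scoped>\n",
464 | "    .dataframe tbody tr th:only-of-type {\n",
465 | "        vertical-align: middle;\n",
466 | "    }\n",
467 | "\n",
468 | "    .dataframe tbody tr th {\n",
469 | "        vertical-align: top;\n",
470 | "    }\n",
471 | "\n",
472 | "    .dataframe thead th {\n",
473 | "        text-align: right;\n",
474 | "    }\n",
475 | "</style>\n",
476 | "<table border=\"1\" class=\"dataframe\">\n",
477 | "  <thead>\n",
478 | "    <tr style=\"text-align: right;\">\n",
479 | "      <th></th>\n",
480 | "      <th>x</th>\n",
481 | "      <th>Z</th>\n",
482 | "      <th>t</th>\n",
483 | "      <th>Y</th>\n",
484 | "      <th>ITE</th>\n",
485 | "    </tr>\n",
486 | "  </thead>\n",
487 | "  <tbody>\n",
488 | "    <tr>\n",
489 | "      <th>0</th>\n",
490 | "      <td>-0.616961</td>\n",
491 | "      <td>0.0</td>\n",
492 | "      <td>0.5</td>\n",
493 | "      <td>1.803183</td>\n",
494 | "      <td>0.514190</td>\n",
495 | "    </tr>\n",
496 | "    <tr>\n",
497 | "      <th>2</th>\n",
498 | "      <td>-0.124545</td>\n",
499 | "      <td>0.0</td>\n",
500 | "      <td>0.5</td>\n",
501 | "      <td>2.193123</td>\n",
502 | "      <td>0.081865</td>\n",
503 | "    </tr>\n",
504 | "    <tr>\n",
505 | "      <th>5</th>\n",
506 | "      <td>-0.454815</td>\n",
507 | "      <td>0.0</td>\n",
508 | "      <td>0.5</td>\n",
509 | "      <td>1.973293</td>\n",
510 | "      <td>0.333970</td>\n",
511 | "    </tr>\n",
512 | "    <tr>\n",
513 | "      <th>6</th>\n",
514 | "      <td>-0.447071</td>\n",
515 | "      <td>0.0</td>\n",
516 | "      <td>0.5</td>\n",
517 | "      <td>1.953387</td>\n",
518 | "      <td>0.364906</td>\n",
519 | "    </tr>\n",
520 | "    <tr>\n",
521 | "      <th>9</th>\n",
522 | "      <td>0.751865</td>\n",
523 | "      <td>0.0</td>\n",
524 | "      <td>1.0</td>\n",
525 | "      <td>2.289369</td>\n",
526 | "      <td>0.776072</td>\n",
527 | "    </tr>\n",
528 | "  </tbody>\n",
529 | "</table>\n",
530 | "</div>"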
531 | ],
532 | "text/plain": [
533 | " x Z t Y ITE\n",
534 | "0 -0.616961 0.0 0.5 1.803183 0.514190\n",
535 | "2 -0.124545 0.0 0.5 2.193123 0.081865\n",
536 | "5 -0.454815 0.0 0.5 1.973293 0.333970\n",
537 | "6 -0.447071 0.0 0.5 1.953387 0.364906\n",
538 | "9 0.751865 0.0 1.0 2.289369 0.776072"
539 | ]
540 | },
541 | "metadata": {
542 | "tags": []
543 | },
544 | "execution_count": 8
545 | }
546 | ]
547 | },
548 | {
549 | "cell_type": "code",
550 | "metadata": {
551 | "id": "XvOWVBt99vq7",
552 | "outputId": "97da2ae4-09f7-4411-a4ef-1f177442fa60",
553 | "colab": {
554 | "base_uri": "https://localhost:8080/",
555 | "height": 338
556 | }
557 | },
558 | "source": [
559 | "# Build model M_DR and estimate each individual's treatment effect from it\n",
560 | "\n",
561 | "# Model M_DR\n",
562 | "M_DR = RandomForestRegressor(max_depth=3)\n",
563 | "M_DR.fit(df_DR[[\"x\"]], df_DR[[\"ITE\"]])\n",
564 | "\n",
565 | "\n",
566 | "# Estimate the treatment effect for each individual\n",
567 | "t_estimated = M_DR.predict(df_DR[[\"x\"]])\n",
568 | "plt.scatter(df_DR[[\"x\"]], t_estimated,\n",
569 | "            label=\"estimated_treatment-effect\")\n",
570 | "\n",
571 | "# Build the ground-truth curve\n",
572 | "x_index = np.arange(-1, 1, 0.01)\n",
573 | "t_ans = np.zeros(len(x_index))\n",
574 | "for i in range(len(x_index)):\n",
575 | "    if x_index[i] < 0:\n",
576 | "        t_ans[i] = 0.5\n",
577 | "    elif x_index[i] >= 0 and x_index[i] < 0.5:\n",
578 | "        t_ans[i] = 0.7\n",
579 | "    elif x_index[i] >= 0.5:\n",
580 | "        t_ans[i] = 1.0\n",
581 | "\n",
582 | "\n",
583 | "# Plot the ground truth\n",
584 | "plt.plot(x_index, t_ans, color='black', ls='--', label='Baseline')\n"
585 | ],
586 | "execution_count": 9,
587 | "outputs": [
588 | {
589 | "output_type": "stream",
590 | "text": [
591 | "/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:5: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
592 | " \"\"\"\n"
593 | ],
594 | "name": "stderr"
595 | },
596 | {
597 | "output_type": "execute_result",
598 | "data": {
599 | "text/plain": [
600 | "[]"
601 | ]
602 | },
603 | "metadata": {
604 | "tags": []
605 | },
606 | "execution_count": 9
607 | },
608 | {
609 | "output_type": "display_data",
610 | "data": {
611 | "image/png": "[base64-encoded PNG omitted: scatter plot of the estimated treatment effect against x, with the true step function drawn as a dashed black line]",
612 | "text/plain": [
613 | ""
614 | ]
615 | },
616 | "metadata": {
617 | "tags": [],
618 | "needs_background": "light"
619 | }
620 | }
621 | ]
622 | },
623 | {
624 | "cell_type": "markdown",
625 | "metadata": {
626 | "id": "4riPBjmmWX__"
627 | },
628 | "source": [
629 | "End of notebook."
630 | ]
631 | }
632 | ]
633 | }
--------------------------------------------------------------------------------
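The last code cell of the notebook above fits a random forest on pseudo individual treatment effects and compares the predictions with a step-shaped ground truth. The following is a minimal, self-contained sketch of that idea, not the notebook's own code: the data here are synthetic (the names `true_effect`, `ite_pseudo`, and the noise level 0.3 are assumptions made for illustration), the ground truth is vectorized with `np.select` instead of a loop, and the target is passed as a 1-D array, which is the shape scikit-learn's `fit` expects and which avoids the `DataConversionWarning` recorded in the cell output.

```python
import numpy as np
from sklearn.ensemble import RandomForestRegressor

rng = np.random.default_rng(1234)


def true_effect(x):
    """Step-shaped effect: 0.5 for x < 0, 0.7 for 0 <= x < 0.5, 1.0 otherwise."""
    return np.select([x < 0, x < 0.5], [0.5, 0.7], default=1.0)


# Synthetic stand-in for the pseudo-ITE column (assumption: additive Gaussian noise).
n = 500
x = rng.uniform(-1, 1, size=n)
ite_pseudo = true_effect(x) + rng.normal(0, 0.3, size=n)

# Features must be 2-D (n_samples, n_features); the target is 1-D (n_samples,).
X = x.reshape(-1, 1)
model = RandomForestRegressor(max_depth=3, random_state=1234)
model.fit(X, ite_pseudo)

# Per-individual estimated effect and its mean absolute error against the truth.
t_estimated = model.predict(X)
mae = np.mean(np.abs(t_estimated - true_effect(x)))
print("mean absolute error:", mae)
```

With a real DataFrame, the same 1-D target could presumably be obtained with `df_DR["ITE"]` (single brackets) rather than `df_DR[["ITE"]]`.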
/6_3_lingam.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "kernelspec": {
6 | "name": "python3",
7 | "display_name": "Python 3"
8 | },
9 | "colab": {
10 | "name": "6_3_lingam.ipynb",
11 | "provenance": [],
12 | "collapsed_sections": []
13 | }
14 | },
15 | "cells": [
16 | {
17 | "cell_type": "markdown",
18 | "metadata": {
19 | "id": "5-FKx0scdZyU",
20 | "colab_type": "text"
21 | },
22 | "source": [
23 | "# 6.3 Causal Discovery with LiNGAM\n",
24 | "\n",
25 | "This file covers Section 6.3. We implement LiNGAM step by step to deepen our understanding of how it works."
26 | ]
27 | },
28 | {
29 | "cell_type": "markdown",
30 | "metadata": {
31 | "id": "A1B4kBJTjXCc",
32 | "colab_type": "text"
33 | },
34 | "source": [
35 | "## Setup before running the program"
36 | ]
37 | },
38 | {
39 | "cell_type": "code",
40 | "metadata": {
41 | "id": "qE00vj2hjUsc",
42 | "colab_type": "code",
43 | "colab": {}
44 | },
45 | "source": [
46 | "# Fix the random seeds\n",
47 | "import random\n",
48 | "import numpy as np\n",
49 | "\n",
50 | "random.seed(1234)\n",
51 | "np.random.seed(1234)\n"
52 | ],
53 | "execution_count": 0,
54 | "outputs": []
55 | },
56 | {
57 | "cell_type": "code",
58 | "metadata": {
59 | "id": "PtmPH1FUjZZA",
60 | "colab_type": "code",
61 | "colab": {}
62 | },
63 | "source": [
64 | "# Import the packages (libraries and functions) used\n",
65 | "import pandas as pd"
66 | ],
67 | "execution_count": 0,
68 | "outputs": []
69 | },
70 | {
71 | "cell_type": "markdown",
72 | "metadata": {
73 | "id": "ONKBX56LdZyX",
74 | "colab_type": "text"
75 | },
76 | "source": [
77 | "# Data generation"
78 | ]
79 | },
80 | {
81 | "cell_type": "markdown",
82 | "metadata": {
83 | "colab_type": "text",
84 | "id": "3IVwBzaljlko"
85 | },
86 | "source": [
87 | "## Model\n",
88 | "x1 = 3×x2 + ex1\n",
89 | "\n",
90 | "x2 = ex2\n",
91 | "\n",
92 | "x3 = 2×x1 + 4×x2 + ex3\n",
93 | " \n"
94 | ]
95 | },
96 | {
97 | "cell_type": "code",
98 | "metadata": {
99 | "id": "-CXDyGWOmOQb",
100 | "colab_type": "code",
101 | "outputId": "abf2a734-ee59-43ea-af8b-e8d21abdefe0",
102 | "colab": {
103 | "base_uri": "https://localhost:8080/",
104 | "height": 195
105 | }
106 | },
107 | "source": [
108 | "# Number of samples\n",
109 | "num_data = 200\n",
110 | "\n",
111 | "# Non-Gaussian noise\n",
112 | "ex1 = 2*(np.random.rand(num_data)-0.5)  # uniform on [-1.0, 1.0)\n",
113 | "ex2 = 2*(np.random.rand(num_data)-0.5)\n",
114 | "ex3 = 2*(np.random.rand(num_data)-0.5)\n",
115 | "\n",
116 | "# Generate the data\n",
117 | "x2 = ex2\n",
118 | "x1 = 3*x2 + ex1\n",
119 | "x3 = 2*x1 + 4*x2 + ex3\n",
120 | "\n",
121 | "# Collect the variables into a table\n",
122 | "df = pd.DataFrame({\"x1\": x1, \"x2\": x2, \"x3\": x3})\n",
123 | "df.head()\n"
124 | ],
125 | "execution_count": 3,
126 | "outputs": [
127 | {
128 | "output_type": "execute_result",
129 | "data": {
189 | "text/plain": [
190 | " x1 x2 x3\n",
191 | "0 2.257272 0.958078 8.776842\n",
192 | "1 2.531611 0.762464 8.561263\n",
193 | "2 0.641547 0.255364 1.341902\n",
194 | "3 3.153636 0.860973 9.322791\n",
195 | "4 1.908691 0.449580 5.776675"
196 | ]
197 | },
198 | "metadata": {
199 | "tags": []
200 | },
201 | "execution_count": 3
202 | }
203 | ]
204 | },
205 | {
206 | "cell_type": "markdown",
207 | "metadata": {
208 | "id": "Z3Z3H7PldZ0I",
209 | "colab_type": "text"
210 | },
211 | "source": [
212 | "## Run independent component analysis"
213 | ]
214 | },
215 | {
216 | "cell_type": "code",
217 | "metadata": {
218 | "id": "um30h2fLdZ0K",
219 | "colab_type": "code",
220 | "outputId": "3fc228bc-1ef1-4c88-f437-4ff435e2f7fa",
221 | "colab": {
222 | "base_uri": "https://localhost:8080/",
223 | "height": 67
224 | }
225 | },
226 | "source": [
227 | "# Use scikit-learn's implementation of independent component analysis\n",
228 | "from sklearn.decomposition import FastICA\n",
229 | "# https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.FastICA.html\n",
230 | "\n",
231 | "ica = FastICA(random_state=1234).fit(df)\n",
232 | "\n",
233 | "# Mixing matrix A estimated by ICA\n",
234 | "A_ica = ica.mixing_\n",
235 | "\n",
236 | "# Compute the (pseudo-)inverse of A\n",
237 | "A_ica_inv = np.linalg.pinv(A_ica)\n",
238 | "print(A_ica_inv)\n"
239 | ],
240 | "execution_count": 4,
241 | "outputs": [
242 | {
243 | "output_type": "stream",
244 | "text": [
245 | "[[-0.23203107 -0.4635971 0.1154553 ]\n",
246 | " [-0.02158245 0.12961253 0.00557934]\n",
247 | " [-0.11326384 0.40437635 -0.00563091]]\n"
248 | ],
249 | "name": "stdout"
250 | }
251 | ]
252 | },
253 | {
254 | "cell_type": "markdown",
255 | "metadata": {
256 | "colab_type": "text",
257 | "id": "VWpDEqLNtGe1"
258 | },
259 | "source": [
260 | "## Compute A_inv and adjust the order and scale of its rows\n",
261 | "\n",
262 | "Reference for the code\n",
263 | "\n",
264 | "https://qiita.com/m__k/items/bd87c063a7496897ba7c"
265 | ]
266 | },
267 | {
268 | "cell_type": "code",
269 | "metadata": {
270 | "id": "nmBp3lFddZ0e",
271 | "colab_type": "code",
272 | "outputId": "7967ae8a-b03f-4f4f-aef8-6ab7de221eeb",
273 | "colab": {
274 | "base_uri": "https://localhost:8080/",
275 | "height": 34
276 | }
277 | },
278 | "source": [
279 | "!pip install munkres\n",
280 | "from munkres import Munkres\n",
281 | "from copy import deepcopy\n"
282 | ],
283 | "execution_count": 5,
284 | "outputs": [
285 | {
286 | "output_type": "stream",
287 | "text": [
288 | "Requirement already satisfied: munkres in /usr/local/lib/python3.6/dist-packages (1.1.2)\n"
289 | ],
290 | "name": "stdout"
291 | }
292 | ]
293 | },
294 | {
295 | "cell_type": "code",
296 | "metadata": {
297 | "id": "NUhUjh3YtVto",
298 | "colab_type": "code",
299 | "outputId": "d70aa90c-2aa9-4ac5-f58f-053df8759f11",
300 | "colab": {
301 | "base_uri": "https://localhost:8080/",
302 | "height": 67
303 | }
304 | },
305 | "source": [
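306 | "# Implementation reference\n",
307 | "# [5] Qiita: \"LiNGAMモデルの推定方法について\" (on estimating LiNGAM models)\n",
308 | "# https://qiita.com/m__k/items/bd87c063a7496897ba7c\n",
309 | "\n",
310 | "# ① Permute the rows so that the absolute values on the diagonal are as large as possible\n",
311 | "# (the diagonal entries of the true A^-1 are never 0)\n",
312 | "\n",
313 | "# Take reciprocals of the absolute values to recast this as minimizing the sum of the diagonal\n",
314 | "A_ica_inv_small = 1 / np.abs(A_ica_inv)\n",
315 | "\n",
316 | "# Find the row permutation that minimizes the sum of the diagonal\n",
317 | "m = Munkres()  # Hungarian algorithm\n",
318 | "ixs = np.vstack(m.compute(deepcopy(A_ica_inv_small)))\n",
319 | "\n",
320 | "# Apply the permutation found\n",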
321 | "ixs = ixs[np.argsort(ixs[:, 0]), :]\n",
322 | "ixs_perm = ixs[:, 1]\n",
323 | "A_ica_inv_perm = np.zeros_like(A_ica_inv)\n",
324 | "A_ica_inv_perm[ixs_perm] = A_ica_inv\n",
325 | "print(A_ica_inv_perm)\n"
326 | ],
327 | "execution_count": 6,
328 | "outputs": [
329 | {
330 | "output_type": "stream",
331 | "text": [
332 | "[[-0.11326384 0.40437635 -0.00563091]\n",
333 | " [-0.02158245 0.12961253 0.00557934]\n",
334 | " [-0.23203107 -0.4635971 0.1154553 ]]\n"
335 | ],
336 | "name": "stdout"
337 | }
338 | ]
339 | },
340 | {
341 | "cell_type": "code",
342 | "metadata": {
343 | "id": "aVTjcyKs1aFR",
344 | "colab_type": "code",
345 | "colab": {
346 | "base_uri": "https://localhost:8080/",
347 | "height": 67
348 | },
349 | "outputId": "c3a08db9-1fad-4485-f18e-c34f0d16cc3f"
350 | },
351 | "source": [
352 | "# The resulting row permutation\n",
353 | "print(ixs)"
354 | ],
355 | "execution_count": 7,
356 | "outputs": [
357 | {
358 | "output_type": "stream",
359 | "text": [
360 | "[[0 2]\n",
361 | " [1 1]\n",
362 | " [2 0]]\n"
363 | ],
364 | "name": "stdout"
365 | }
366 | ]
367 | },
368 | {
369 | "cell_type": "code",
370 | "metadata": {
371 | "id": "uD47ajJBdZ1B",
372 | "colab_type": "code",
373 | "outputId": "82d6f933-1954-42da-a7f2-529910fc0fd3",
374 | "colab": {
375 | "base_uri": "https://localhost:8080/",
376 | "height": 67
377 | }
378 | },
379 | "source": [
380 | "# ② Rescale the rows\n",
381 | "D = np.diag(A_ica_inv_perm)[:, np.newaxis]  # the factor D by which each row is scaled\n",
382 | "A_ica_inv_perm_D = A_ica_inv_perm / D\n",
383 | "print(A_ica_inv_perm_D)\n"
384 | ],
385 | "execution_count": 8,
386 | "outputs": [
387 | {
388 | "output_type": "stream",
389 | "text": [
390 | "[[ 1. -3.57021564 0.04971498]\n",
391 | " [-0.16651518 1. 0.0430463 ]\n",
392 | " [-2.00970483 -4.01538182 1. ]]\n"
393 | ],
394 | "name": "stdout"
395 | }
396 | ]
397 | },
398 | {
399 | "cell_type": "code",
400 | "metadata": {
401 | "id": "Hfu7SLMwdZ1G",
402 | "colab_type": "code",
403 | "outputId": "9e1e43ef-8efd-4e03-8b82-ccd860c4dd20",
404 | "colab": {
405 | "base_uri": "https://localhost:8080/",
406 | "height": 67
407 | }
408 | },
409 | "source": [
410 | "# ③ B = I - A_inv\n",
411 | "B_est = np.eye(3) - A_ica_inv_perm_D\n",
412 | "print(B_est)\n"
413 | ],
414 | "execution_count": 9,
415 | "outputs": [
416 | {
417 | "output_type": "stream",
418 | "text": [
419 | "[[ 0. 3.57021564 -0.04971498]\n",
420 | " [ 0.16651518 0. -0.0430463 ]\n",
421 | " [ 2.00970483 4.01538182 0. ]]\n"
422 | ],
423 | "name": "stdout"
424 | }
425 | ]
426 | },
427 | {
428 | "cell_type": "code",
429 | "metadata": {
430 | "id": "Vipp7leM5pRd",
431 | "colab_type": "code",
432 | "colab": {
433 | "base_uri": "https://localhost:8080/",
434 | "height": 67
435 | },
436 | "outputId": "7b5c03cf-29d8-4e0d-e68a-de6552b4ece5"
437 | },
438 | "source": [
439 | "# ① Set to 0 the entries with the smallest absolute values, as many as should be 0 in the upper part (3 for a 3×3 matrix, 6 for 4×4: the number of entries above the diagonal)\n",
440 | "# ② Permute the variables and check whether the matrix becomes lower triangular\n",
441 | "# Implementation reference\n",
442 | "# [5] Qiita: \"LiNGAMモデルの推定方法について\" (on estimating LiNGAM models)\n",
443 | "# https://qiita.com/m__k/items/bd87c063a7496897ba7c\n",
444 | "\n",
445 | "def _slttestperm(b_i):\n",
446 | "    # Check whether b_i can be permuted into a lower triangular matrix\n",
447 | "    n = b_i.shape[0]\n",
448 | "    remnodes = np.arange(n)\n",
449 | "    b_rem = deepcopy(b_i)\n",
450 | "    p = list()\n",
451 | "\n",
452 | "    for i in range(n):\n",
453 | "        # Indices of the rows whose entries are all 0\n",
454 | "        ixs = np.where(np.sum(np.abs(b_rem), axis=1) < 1e-12)[0]\n",
455 | "\n",
456 | "        if len(ixs) == 0:\n",
457 | "            return None\n",
458 | "        else:\n",
459 | "            ix = ixs[0]\n",
460 | "            p.append(remnodes[ix])\n",
461 | "\n",
462 | "            # Remove the all-zero row (and the matching column)\n",
463 | "            remnodes = np.hstack((remnodes[:ix], remnodes[(ix + 1):]))\n",
464 | "            ixs = np.hstack((np.arange(ix), np.arange(ix + 1, len(b_rem))))\n",
465 | "            b_rem = b_rem[ixs, :]\n",
466 | "            b_rem = b_rem[:, ixs]\n",
467 | "\n",
468 | "    return np.array(p)\n",
469 | "\n",
470 | "b = B_est\n",
471 | "n = b.shape[0]\n",
472 | "assert(b.shape == (n, n))\n",
473 | "\n",
474 | "ixs = np.argsort(np.abs(b).ravel())\n",
475 | "\n",
476 | "for i in range(int(n * (n + 1) / 2) - 1, (n * n) - 1):\n",
477 | "    b_i = deepcopy(b)\n",
478 | "    b_i.ravel()[ixs[:i]] = 0\n",
479 | "    ixs_perm = _slttestperm(b_i)\n",
480 | "    if ixs_perm is not None:\n",
481 | "        b_opt = deepcopy(b)\n",
482 | "        b_opt = b_opt[ixs_perm, :]\n",
483 | "        b_opt = b_opt[:, ixs_perm]\n",
484 | "        break\n",
485 | "b_csl = np.tril(b_opt, -1)\n",
486 | "b_csl[ixs_perm, :] = deepcopy(b_csl)\n",
487 | "b_csl[:, ixs_perm] = deepcopy(b_csl)\n",
488 | "\n",
489 | "B_est1 = b_csl\n",
490 | "print(B_est1)\n",
491 | "\n"
492 | ],
493 | "execution_count": 10,
494 | "outputs": [
495 | {
496 | "output_type": "stream",
497 | "text": [
498 | "[[0. 3.57021564 0. ]\n",
499 | " [0. 0. 0. ]\n",
500 | " [2.00970483 4.01538182 0. ]]\n"
501 | ],
502 | "name": "stdout"
503 | }
504 | ]
505 | },
506 | {
507 | "cell_type": "markdown",
508 | "metadata": {
509 | "id": "EKf9DwAK84hB",
510 | "colab_type": "text"
511 | },
512 | "source": [
513 | "## Re-estimate the nonzero elements of B"
514 | ]
515 | },
516 | {
517 | "cell_type": "code",
518 | "metadata": {
519 | "id": "QBdpU_dS88Lf",
520 | "colab_type": "code",
521 | "outputId": "176e10b7-a52d-4cf3-fd43-e1a9567b4472",
522 | "colab": {
523 | "base_uri": "https://localhost:8080/",
524 | "height": 50
525 | }
526 | },
527 | "source": [
528 | "# Import linear regression from scikit-learn\n",
529 | "from sklearn.linear_model import LinearRegression\n",
530 | "\n",
531 | "# Explanatory variables\n",
532 | "X1 = df[[\"x2\"]]\n",
533 | "X3 = df[[\"x1\", \"x2\"]]\n",
534 | "\n",
535 | "# Response (target) variables\n",
536 | "# df[\"x1\"]\n",
537 | "# df[\"x3\"]\n",
538 | "\n",
539 | "# Fit the regressions\n",
540 | "reg1 = LinearRegression().fit(X1, df[\"x1\"])\n",
541 | "reg3 = LinearRegression().fit(X3, df[\"x3\"])\n",
542 | "\n",
543 | "# Print the estimated coefficients\n",
544 | "print(\"Coefficients:\", reg1.coef_)\n",
545 | "print(\"Coefficients:\", reg3.coef_)\n"
546 | ],
547 | "execution_count": 11,
548 | "outputs": [
549 | {
550 | "output_type": "stream",
551 | "text": [
552 | "Coefficients: [3.14642595]\n",
553 | "Coefficients: [1.96164568 4.11256441]\n"
554 | ],
555 | "name": "stdout"
556 | }
557 | ]
558 | },
559 | {
560 | "cell_type": "markdown",
561 | "metadata": {
562 | "id": "i_V6mNsXyXs2",
563 | "colab_type": "text"
564 | },
565 | "source": [
566 | "End of notebook."
567 | ]
568 | }
569 | ]
570 | }
--------------------------------------------------------------------------------
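The final step of the LiNGAM notebook above re-estimates the nonzero entries of B by regressing each variable on its parents once the causal order is known. A minimal NumPy-only sketch of that step follows. It regenerates data from the same structural model (x2 = ex2, x1 = 3·x2 + ex1, x3 = 2·x1 + 4·x2 + ex3), but the helper `ols_coefs`, the sample size of 2000, and the use of `np.linalg.lstsq` in place of scikit-learn's `LinearRegression` are assumptions made for illustration, so the numbers will differ from those printed in the notebook.

```python
import numpy as np

rng = np.random.default_rng(1234)
n = 2000  # larger than the notebook's 200 so the estimates are tighter

# Non-Gaussian (uniform on [-1, 1)) noise terms, as in the notebook's model.
e1, e2, e3 = (rng.uniform(-1, 1, size=n) for _ in range(3))
x2 = e2
x1 = 3 * x2 + e1
x3 = 2 * x1 + 4 * x2 + e3


def ols_coefs(y, *parents):
    """Least-squares slopes of y on its parents (an intercept is fitted and dropped)."""
    X = np.column_stack([np.ones(len(y)), *parents])
    beta, *_ = np.linalg.lstsq(X, y, rcond=None)
    return beta[1:]


# Row i of B holds the coefficients of the parents of variable i (order: x1, x2, x3).
B_est = np.zeros((3, 3))
B_est[0, 1] = ols_coefs(x1, x2)[0]        # x1 <- x2
B_est[2, [0, 1]] = ols_coefs(x3, x1, x2)  # x3 <- x1, x2

B_true = np.array([[0.0, 3.0, 0.0],
                   [0.0, 0.0, 0.0],
                   [2.0, 4.0, 0.0]])
print(np.round(B_est, 3))
```

Only the entries that the pruning step left nonzero are refitted; every other entry of `B_est` stays exactly 0, which is what makes this second pass cleaner than the raw ICA-based estimate.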
/7_2_bayesian_network_bic.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "kernelspec": {
6 | "name": "python3",
7 | "display_name": "Python 3"
8 | },
9 | "colab": {
10 | "name": "7_2_bayesian_network_bic.ipynb",
11 | "provenance": [],
12 | "collapsed_sections": []
13 | }
14 | },
15 | "cells": [
16 | {
17 | "cell_type": "markdown",
18 | "metadata": {
19 | "id": "aoxI3DOK9vm2",
20 | "colab_type": "text"
21 | },
22 | "source": [
23 | "# 7.2 Computing BIC\n",
24 | "\n",
25 | "This file implements Section 7.2.\n",
26 | "\n",
27 | "We compute the BIC value for a dataset."
28 | ]
29 | },
30 | {
31 | "cell_type": "markdown",
32 | "metadata": {
33 | "id": "2XdIDbdlejUk",
34 | "colab_type": "text"
35 | },
36 | "source": [
37 | "## Setup before running the program"
38 | ]
39 | },
40 | {
41 | "cell_type": "code",
42 | "metadata": {
43 | "id": "XZFKJwcu-_Oj",
44 | "colab_type": "code",
45 | "colab": {}
46 | },
47 | "source": [
48 | "# Set the random seeds\n",
49 | "import random\n",
50 | "import numpy as np\n",
51 | "\n",
52 | "np.random.seed(1234)\n",
53 | "random.seed(1234)\n"
54 | ],
55 | "execution_count": 0,
56 | "outputs": []
57 | },
58 | {
59 | "cell_type": "code",
60 | "metadata": {
61 | "id": "hx1idArc_F15",
62 | "colab_type": "code",
63 | "colab": {}
64 | },
65 | "source": [
66 | "# Import the packages (libraries and functions) used\n",
67 | "from numpy.random import *\n",
68 | "import pandas as pd\n"
69 | ],
70 | "execution_count": 0,
71 | "outputs": []
72 | },
73 | {
74 | "cell_type": "markdown",
75 | "metadata": {
76 | "id": "AWqP6yeQlI_t",
77 | "colab_type": "text"
78 | },
79 | "source": [
80 | "## Create the data"
81 | ]
82 | },
83 | {
84 | "cell_type": "code",
85 | "metadata": {
86 | "id": "DpnGB2KZ_L8x",
87 | "colab_type": "code",
88 | "outputId": "b2ca2d8d-76dc-48f1-8dea-ad30de30b06c",
89 | "colab": {
90 | "base_uri": "https://localhost:8080/",
91 | "height": 195
92 | }
93 | },
94 | "source": [
95 | "# Number of samples\n",
96 | "num_data = 10\n",
97 | "\n",
98 | "# x1: draw num_data values from {0, 1} with P(0) = 0.6 and P(1) = 0.4\n",
99 | "x1 = np.random.choice([0, 1], num_data, p=[0.6, 0.4])\n",
100 | "\n",
101 | "# x2: draw num_data values from {0, 1} with P(0) = 0.4 and P(1) = 0.6\n",
102 | "x2 = np.random.choice([0, 1], num_data, p=[0.4, 0.6])\n",
103 | "\n",
104 | "# Combine the two variables into a table\n",
105 | "df = pd.DataFrame({'x1': x1,\n",
106 | "                   'x2': x2,\n",
107 | "                   })\n",
108 | "\n",
109 | "df.head()  # show the first rows\n"
110 | ],
111 | "execution_count": 0,
112 | "outputs": [
113 | {
114 | "output_type": "execute_result",
115 | "data": {
169 | "text/plain": [
170 | " x1 x2\n",
171 | "0 0 0\n",
172 | "1 1 1\n",
173 | "2 0 1\n",
174 | "3 1 1\n",
175 | "4 1 0"
176 | ]
177 | },
178 | "metadata": {
179 | "tags": []
180 | },
181 | "execution_count": 3
182 | }
183 | ]
184 | },
185 | {
186 | "cell_type": "code",
187 | "metadata": {
188 | "id": "YvCfB7uRZvZI",
189 | "colab_type": "code",
190 | "outputId": "bad340c8-374a-42e7-bc43-1c9b22fb1251",
191 | "colab": {
192 | "base_uri": "https://localhost:8080/",
193 | "height": 343
194 | }
195 | },
196 | "source": [
197 | "# Variable x3: draw num_data values from {0, 1}\n",
198 | "# when (x1,x2) = (0,0), P(x3=0) = 0.2\n",
199 | "# when (x1,x2) = (0,1), P(x3=0) = 0.3\n",
200 | "# when (x1,x2) = (1,0), P(x3=0) = 0.4\n",
201 | "# when (x1,x2) = (1,1), P(x3=0) = 0.1\n",
202 | "\n",
203 | "x3 = []\n",
204 | "for i in range(num_data):\n",
205 | "    if x1[i] == 0 and x2[i] == 0:\n",
206 | "        x3_value = np.random.choice([0, 1], 1, p=[0.2, 0.8])\n",
207 | "        x3.append(x3_value[0])  # x3_value is a length-1 array, so append its first element\n",
208 | "    elif x1[i] == 0 and x2[i] == 1:\n",
209 | "        x3_value = np.random.choice([0, 1], 1, p=[0.3, 0.7])\n",
210 | "        x3.append(x3_value[0])\n",
211 | "    elif x1[i] == 1 and x2[i] == 0:\n",
212 | "        x3_value = np.random.choice([0, 1], 1, p=[0.4, 0.6])\n",
213 | "        x3.append(x3_value[0])\n",
214 | "    elif x1[i] == 1 and x2[i] == 1:\n",
215 | "        x3_value = np.random.choice([0, 1], 1, p=[0.1, 0.9])\n",
216 | "        x3.append(x3_value[0])\n",
217 | "\n",
218 | "df[\"x3\"] = x3\n",
219 | "\n",
220 | "df  # display the table\n"
221 | ],
222 | "execution_count": 0,
223 | "outputs": [
224 | {
225 | "output_type": "execute_result",
226 | "data": {
316 | "text/plain": [
317 | " x1 x2 x3\n",
318 | "0 0 0 1\n",
319 | "1 1 1 1\n",
320 | "2 0 1 0\n",
321 | "3 1 1 1\n",
322 | "4 1 0 1\n",
323 | "5 0 1 1\n",
324 | "6 0 1 1\n",
325 | "7 1 0 1\n",
326 | "8 1 1 1\n",
327 | "9 1 1 1"
328 | ]
329 | },
330 | "metadata": {
331 | "tags": []
332 | },
333 | "execution_count": 4
334 | }
335 | ]
336 | },
337 | {
338 | "cell_type": "markdown",
339 | "metadata": {
340 | "id": "BHcdUlW9koTa",
341 | "colab_type": "text"
342 | },
343 | "source": [
344 | "## Computing BIC with pgmpy (Python library for Probabilistic Graphical Models)\n"
345 | ]
346 | },
347 | {
348 | "cell_type": "code",
349 | "metadata": {
350 | "id": "25oDRf7qtNtF",
351 | "colab_type": "code",
352 | "outputId": "8d2c208a-2371-4a82-9415-59ee92f0fdca",
353 | "colab": {
354 | "base_uri": "https://localhost:8080/",
355 | "height": 101
356 | }
357 | },
358 | "source": [
359 | "!pip install pgmpy==0.1.9"
360 | ],
361 | "execution_count": 0,
362 | "outputs": [
363 | {
364 | "output_type": "stream",
365 | "text": [
366 | "Collecting pgmpy==0.1.9\n",
367 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/5a/b1/18dfdfcb10dcce71fd39f8c6801407e9aebd953939682558a5317e4a021c/pgmpy-0.1.9-py3-none-any.whl (331kB)\n",
368 | "\u001b[K |████████████████████████████████| 337kB 2.8MB/s \n",
369 | "\u001b[?25hInstalling collected packages: pgmpy\n",
370 | "Successfully installed pgmpy-0.1.9\n"
371 | ],
372 | "name": "stdout"
373 | }
374 | ]
375 | },
376 | {
377 | "cell_type": "code",
378 | "metadata": {
379 | "id": "H1i9-YjMdDTh",
380 | "colab_type": "code",
381 | "colab": {}
382 | },
383 | "source": [
384 | "# Specify the true DAG\n",
385 | "from pgmpy.models import BayesianModel\n",
386 | "model = BayesianModel([('x1', 'x3'), ('x2', 'x3')])  # x1 -> x3 <- x2\n"
387 | ],
388 | "execution_count": 0,
389 | "outputs": []
390 | },
391 | {
392 | "cell_type": "code",
393 | "metadata": {
394 | "id": "WFKQb7XudDW3",
395 | "colab_type": "code",
396 | "outputId": "460d06f3-9e80-430e-e3c3-4e9b29066eaa",
397 | "colab": {
398 | "base_uri": "https://localhost:8080/",
399 | "height": 286
400 | }
401 | },
402 | "source": [
403 | "# Show how many times each data pattern occurs\n",
404 | "from pgmpy.estimators import ParameterEstimator\n",
405 | "pe = ParameterEstimator(model, df)\n",
406 | "print(\"\\n\", pe.state_counts('x1'))\n",
407 | "print(\"\\n\", pe.state_counts('x2'))\n",
408 | "print(\"\\n\", pe.state_counts('x3'))\n"
409 | ],
410 | "execution_count": 0,
411 | "outputs": [
412 | {
413 | "output_type": "stream",
414 | "text": [
415 | "/usr/local/lib/python3.6/dist-packages/statsmodels/tools/_testing.py:19: FutureWarning: pandas.util.testing is deprecated. Use the functions in the public API at pandas.testing instead.\n",
416 | " import pandas.util.testing as tm\n"
417 | ],
418 | "name": "stderr"
419 | },
420 | {
421 | "output_type": "stream",
422 | "text": [
423 | "\n",
424 | " x1\n",
425 | "0 4\n",
426 | "1 6\n",
427 | "\n",
428 | " x2\n",
429 | "0 3\n",
430 | "1 7\n",
431 | "\n",
432 | " x1 0 1 \n",
433 | "x2 0 1 0 1\n",
434 | "x3 \n",
435 | "0 0.0 1.0 0.0 0.0\n",
436 | "1 1.0 2.0 2.0 4.0\n"
437 | ],
438 | "name": "stdout"
439 | }
440 | ]
441 | },
442 | {
443 | "cell_type": "code",
444 | "metadata": {
445 | "id": "ztZJyobWwalY",
446 | "colab_type": "code",
447 | "outputId": "11bbeb28-dc26-4a7b-cbd3-08f4cd7f75e9",
448 | "colab": {
449 | "base_uri": "https://localhost:8080/",
450 | "height": 336
451 | }
452 | },
453 | "source": [
454 | "# Estimate the CPTs (conditional probability tables)\n",
455 | "from pgmpy.estimators import BayesianEstimator\n",
456 | "\n",
457 | "estimator = BayesianEstimator(model, df)\n",
458 | "\n",
459 | "cpd_x1 = estimator.estimate_cpd(\n",
460 | "    'x1', prior_type=\"dirichlet\", pseudo_counts=[[0], [0]])\n",
461 | "cpd_x2 = estimator.estimate_cpd(\n",
462 | "    'x2', prior_type=\"dirichlet\", pseudo_counts=[[0], [0]])\n",
463 | "cpd_x3 = estimator.estimate_cpd('x3', prior_type=\"dirichlet\", pseudo_counts=[\n",
464 | "    [0, 0, 0, 0], [0, 0, 0, 0]])\n",
465 | "# Note: pseudo_counts of 0 set a Dirichlet prior with hyperparameters 0, so the estimates equal the observed frequencies.\n",
466 | "\n",
467 | "print(cpd_x1)\n",
468 | "print(cpd_x2)\n",
469 | "print(cpd_x3)\n"
470 | ],
471 | "execution_count": 0,
472 | "outputs": [
473 | {
474 | "output_type": "stream",
475 | "text": [
476 | "+-------+-----+\n",
477 | "| x1(0) | 0.4 |\n",
478 | "+-------+-----+\n",
479 | "| x1(1) | 0.6 |\n",
480 | "+-------+-----+\n",
481 | "+-------+-----+\n",
482 | "| x2(0) | 0.3 |\n",
483 | "+-------+-----+\n",
484 | "| x2(1) | 0.7 |\n",
485 | "+-------+-----+\n",
486 | "+-------+-------+--------------------+-------+-------+\n",
487 | "| x1 | x1(0) | x1(0) | x1(1) | x1(1) |\n",
488 | "+-------+-------+--------------------+-------+-------+\n",
489 | "| x2 | x2(0) | x2(1) | x2(0) | x2(1) |\n",
490 | "+-------+-------+--------------------+-------+-------+\n",
491 | "| x3(0) | 0.0 | 0.3333333333333333 | 0.0 | 0.0 |\n",
492 | "+-------+-------+--------------------+-------+-------+\n",
493 | "| x3(1) | 1.0 | 0.6666666666666666 | 1.0 | 1.0 |\n",
494 | "+-------+-------+--------------------+-------+-------+\n"
495 | ],
496 | "name": "stdout"
497 | }
498 | ]
499 | },
500 | {
501 | "cell_type": "code",
502 | "metadata": {
503 | "id": "T8UqcSmXyX_4",
504 | "colab_type": "code",
505 | "outputId": "3207d204-6108-476f-c9a6-069a676aa1d3",
506 | "colab": {
507 | "base_uri": "https://localhost:8080/",
508 | "height": 34
509 | }
510 | },
511 | "source": [
512 | "# Compute the BIC\n",
513 | "from pgmpy.estimators import BicScore\n",
514 | "bic = BicScore(df)\n",
515 | "print(bic.score(model))\n"
516 | ],
517 | "execution_count": 0,
518 | "outputs": [
519 | {
520 | "output_type": "stream",
521 | "text": [
522 | "-21.65605747450808\n"
523 | ],
524 | "name": "stdout"
525 | }
526 | ]
527 | },
528 | {
529 | "cell_type": "markdown",
530 | "metadata": {
531 | "id": "6Pvo1RIbEoyY",
532 | "colab_type": "text"
533 | },
534 | "source": [
535 | "## Computing BIC for a different DAG"
536 | ]
537 | },
538 | {
539 | "cell_type": "code",
540 | "metadata": {
541 | "id": "y2ZRLS0fEtnc",
542 | "colab_type": "code",
543 | "outputId": "6e9cfc89-cd37-4b73-ee48-39c598f5f59f",
544 | "colab": {
545 | "base_uri": "https://localhost:8080/",
546 | "height": 34
547 | }
548 | },
549 | "source": [
550 | "# Specify a DAG that is not the true one\n",
551 | "from pgmpy.models import BayesianModel\n",
552 | "model = BayesianModel([('x2', 'x1'), ('x2', 'x3')])  # x1 <- x2 -> x3\n",
553 | "bic = BicScore(df)\n",
554 | "print(bic.score(model))\n"
555 | ],
556 | "execution_count": 0,
557 | "outputs": [
558 | {
559 | "output_type": "stream",
560 | "text": [
561 | "-21.425819218840655\n"
562 | ],
563 | "name": "stdout"
564 | }
565 | ]
566 | },
567 | {
568 | "cell_type": "markdown",
569 | "metadata": {
570 | "colab_type": "text",
571 | "id": "I6P1x9vAdG3i"
572 | },
573 | "source": [
574 | "End of notebook."
575 | ]
576 | }
577 | ]
578 | }
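
The two BIC values printed in the notebook above can be cross-checked without pgmpy. The sketch below assumes the usual decomposable form of the score: for each variable, the maximized log-likelihood minus 0.5 * ln(N) * q * (r - 1), where r is the number of states of the variable and q is the number of parent configurations. The helper names `local_bic` and `bic` are illustrative and are not part of pgmpy. The 10 rows are copied from the `df` output shown in the notebook.

```python
from collections import Counter
from math import log

# The 10 samples shown in the notebook's df output (copied by hand).
data = {
    "x1": [0, 1, 0, 1, 1, 0, 0, 1, 1, 1],
    "x2": [0, 1, 1, 1, 0, 1, 1, 0, 1, 1],
    "x3": [1, 1, 0, 1, 1, 1, 1, 1, 1, 1],
}


def local_bic(data, variable, parents):
    """BIC contribution of one variable given its parents (hypothetical helper)."""
    n = len(data[variable])
    r = len(set(data[variable]))       # number of states of the variable
    q = 1                              # number of parent configurations
    for p in parents:
        q *= len(set(data[p]))

    joint = Counter()                  # counts of (parent config, state)
    parent_totals = Counter()          # counts of parent config
    for i in range(n):
        key = tuple(data[p][i] for p in parents)
        joint[(key, data[variable][i])] += 1
        parent_totals[key] += 1

    # Maximized log-likelihood: sum of count * log(count / parent count).
    # Only observed combinations are in the Counter, so log(0) never occurs.
    log_lik = sum(c * log(c / parent_totals[key])
                  for (key, _), c in joint.items())
    return log_lik - 0.5 * log(n) * q * (r - 1)


def bic(data, dag):
    """Sum the local scores; dag maps each variable to its list of parents."""
    return sum(local_bic(data, v, parents) for v, parents in dag.items())


true_dag = {"x1": [], "x2": [], "x3": ["x1", "x2"]}    # x1 -> x3 <- x2
other_dag = {"x2": [], "x1": ["x2"], "x3": ["x2"]}     # x1 <- x2 -> x3

bic_true = bic(data, true_dag)
bic_other = bic(data, other_dag)
print(bic_true, bic_other)
```

Under these assumptions the sketch gives about -21.656 for x1 -> x3 <- x2 and about -21.426 for x1 <- x2 -> x3, in line with the scores printed by the notebook. With only 10 samples the incorrect DAG scores slightly higher, because the penalty term for the four parent configurations of x3 outweighs the gain in likelihood.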
--------------------------------------------------------------------------------
/7_3_bayesian_network_independence_test.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "kernelspec": {
6 | "name": "python3",
7 | "display_name": "Python 3"
8 | },
9 | "colab": {
10 | "name": "7_3_bayesian_network__independence_test.ipynb",
11 | "provenance": [],
12 | "collapsed_sections": []
13 | }
14 | },
15 | "cells": [
16 | {
17 | "cell_type": "markdown",
18 | "metadata": {
19 | "id": "aoxI3DOK9vm2",
20 | "colab_type": "text"
21 | },
22 | "source": [
23 | "# 7.3 Independence Tests\n",
24 | "\n",
25 | "This file is the implementation for Section 7.3.\n",
26 | "\n",
27 | "We run a chi-square test of independence on the data."
28 | ]
29 | },
30 | {
31 | "cell_type": "markdown",
32 | "metadata": {
33 | "id": "2XdIDbdlejUk",
34 | "colab_type": "text"
35 | },
36 | "source": [
37 | "## Setup before running the program"
38 | ]
39 | },
40 | {
41 | "cell_type": "code",
42 | "metadata": {
43 | "id": "XZFKJwcu-_Oj",
44 | "colab_type": "code",
45 | "colab": {}
46 | },
47 | "source": [
48 | "# Set the random seeds\n",
49 | "import random\n",
50 | "import numpy as np\n",
51 | "\n",
52 | "np.random.seed(1234)\n",
53 | "random.seed(1234)\n"
54 | ],
55 | "execution_count": 0,
56 | "outputs": []
57 | },
58 | {
59 | "cell_type": "code",
60 | "metadata": {
61 | "id": "hx1idArc_F15",
62 | "colab_type": "code",
63 | "colab": {}
64 | },
65 | "source": [
66 | "# Import the packages (libraries and functions) used\n",
67 | "from numpy.random import *\n",
68 | "import pandas as pd\n"
69 | ],
70 | "execution_count": 0,
71 | "outputs": []
72 | },
73 | {
74 | "cell_type": "markdown",
75 | "metadata": {
76 | "id": "AWqP6yeQlI_t",
77 | "colab_type": "text"
78 | },
79 | "source": [
80 | "## Creating the data"
81 | ]
82 | },
83 | {
84 | "cell_type": "code",
85 | "metadata": {
86 | "id": "DpnGB2KZ_L8x",
87 | "colab_type": "code",
88 | "outputId": "7a7e8d89-2383-4b04-d0d0-3a15be37316f",
89 | "colab": {
90 | "base_uri": "https://localhost:8080/",
91 | "height": 195
92 | }
93 | },
94 | "source": [
95 | "# Number of data points\n",
96 | "num_data = 100\n",
97 | "\n",
98 | "# x1: generate num_data values of 0 or 1; P(0) = 0.6, P(1) = 0.4\n",
99 | "x1 = np.random.choice([0, 1], num_data, p=[0.6, 0.4])\n",
100 | "\n",
101 | "# x2: generate num_data values of 0 or 1; P(0) = 0.4, P(1) = 0.6\n",
102 | "x2 = np.random.choice([0, 1], num_data, p=[0.4, 0.6])\n",
103 | "\n",
104 | "# Make x2 causally depend on x1: x2 can be 1 only when x1 is 1\n",
105 | "x2 = x2*x1\n",
106 | "\n",
107 | "# Combine the two variables into a table\n",
108 | "df = pd.DataFrame({'x1': x1,\n",
109 | "                   'x2': x2,\n",
110 | "                   })\n",
111 | "\n",
112 | "df.head()  # show the first rows\n"
113 | ],
114 | "execution_count": 3,
115 | "outputs": [
116 | {
117 | "output_type": "execute_result",
118 | "data": {
172 | "text/plain": [
173 | " x1 x2\n",
174 | "0 0 0\n",
175 | "1 1 1\n",
176 | "2 0 0\n",
177 | "3 1 1\n",
178 | "4 1 1"
179 | ]
180 | },
181 | "metadata": {
182 | "tags": []
183 | },
184 | "execution_count": 3
185 | }
186 | ]
187 | },
188 | {
189 | "cell_type": "code",
190 | "metadata": {
191 | "id": "YvCfB7uRZvZI",
192 | "colab_type": "code",
193 | "outputId": "0852e5c6-bdcc-4aae-b0d9-722dbfd24514",
194 | "colab": {
195 | "base_uri": "https://localhost:8080/",
196 | "height": 84
197 | }
198 | },
199 | "source": [
200 | "# Count each (x1, x2) combination\n",
201 | "print(((df[\"x1\"] == 0) & (df[\"x2\"] == 0)).sum())\n",
202 | "print(((df[\"x1\"] == 1) & (df[\"x2\"] == 0)).sum())\n",
203 | "print(((df[\"x1\"] == 0) & (df[\"x2\"] == 1)).sum())\n",
204 | "print(((df[\"x1\"] == 1) & (df[\"x2\"] == 1)).sum())\n"
205 | ],
206 | "execution_count": 4,
207 | "outputs": [
208 | {
209 | "output_type": "stream",
210 | "text": [
211 | "58\n",
212 | "9\n",
213 | "0\n",
214 | "33\n"
215 | ],
216 | "name": "stdout"
217 | }
218 | ]
219 | },
220 | {
221 | "cell_type": "markdown",
222 | "metadata": {
223 | "id": "BHcdUlW9koTa",
224 | "colab_type": "text"
225 | },
226 | "source": [
227 | "## Independence test with pgmpy (Python library for Probabilistic Graphical Models)\n"
228 | ]
229 | },
230 | {
231 | "cell_type": "code",
232 | "metadata": {
233 | "id": "25oDRf7qtNtF",
234 | "colab_type": "code",
235 | "outputId": "64ca1b9e-846d-4040-a37b-acd185f2d66a",
236 | "colab": {
237 | "base_uri": "https://localhost:8080/",
238 | "height": 101
239 | }
240 | },
241 | "source": [
242 | "!pip install pgmpy==0.1.9"
243 | ],
244 | "execution_count": 5,
245 | "outputs": [
246 | {
247 | "output_type": "stream",
248 | "text": [
249 | "Collecting pgmpy==0.1.9\n",
250 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/5a/b1/18dfdfcb10dcce71fd39f8c6801407e9aebd953939682558a5317e4a021c/pgmpy-0.1.9-py3-none-any.whl (331kB)\n",
251 | "\u001b[K |████████████████████████████████| 337kB 2.8MB/s \n",
252 | "\u001b[?25hInstalling collected packages: pgmpy\n",
253 | "Successfully installed pgmpy-0.1.9\n"
254 | ],
255 | "name": "stdout"
256 | }
257 | ]
258 | },
259 | {
260 | "cell_type": "code",
261 | "metadata": {
262 | "id": "K8rFdErsnUhd",
263 | "colab_type": "code",
264 | "outputId": "72888e1d-07dd-4a0f-886d-469df5010b54",
265 | "colab": {
266 | "base_uri": "https://localhost:8080/",
267 | "height": 84
268 | }
269 | },
270 | "source": [
271 | "# Number of data points\n",
272 | "num_data = 100\n",
273 | "\n",
274 | "# x1: generate num_data values of 0 or 1; P(0) = 0.6, P(1) = 0.4\n",
275 | "x1 = np.random.choice([0, 1], num_data, p=[0.6, 0.4])\n",
276 | "\n",
277 | "# x2: generate num_data values of 0 or 1; P(0) = 0.4, P(1) = 0.6\n",
278 | "x2 = np.random.choice([0, 1], num_data, p=[0.4, 0.6])\n",
279 | "\n",
280 | "# Combine the two variables into a table (here x2 is generated independently of x1)\n",
281 | "df2 = pd.DataFrame({'x1': x1,\n",
282 | "                    'x2': x2,\n",
283 | "                    })\n",
284 | "\n",
285 | "# Count each (x1, x2) combination\n",
286 | "print(((df2[\"x1\"] == 0) & (df2[\"x2\"] == 0)).sum())\n",
287 | "print(((df2[\"x1\"] == 1) & (df2[\"x2\"] == 0)).sum())\n",
288 | "print(((df2[\"x1\"] == 0) & (df2[\"x2\"] == 1)).sum())\n",
289 | "print(((df2[\"x1\"] == 1) & (df2[\"x2\"] == 1)).sum())\n"
290 | ],
291 | "execution_count": 6,
292 | "outputs": [
293 | {
294 | "output_type": "stream",
295 | "text": [
296 | "20\n",
297 | "15\n",
298 | "35\n",
299 | "30\n"
300 | ],
301 | "name": "stdout"
302 | }
303 | ]
304 | },
305 | {
306 | "cell_type": "code",
307 | "metadata": {
308 | "id": "nJJBRMKEnbjO",
309 | "colab_type": "code",
310 | "outputId": "fb9adaf1-b048-46a6-a96c-4cc5987a4055",
311 | "colab": {
312 | "base_uri": "https://localhost:8080/",
313 | "height": 84
314 | }
315 | },
316 | "source": [
317 | "from pgmpy.estimators import ConstraintBasedEstimator\n",
318 | "\n",
319 | "est = ConstraintBasedEstimator(df2)\n",
320 | "print(est.test_conditional_independence(\n",
321 | "    'x1', 'x2', method=\"chi_square\", tol=0.05))  # independent\n",
322 | "\n",
323 | "# The first example, where x2 depends on x1\n",
324 | "est = ConstraintBasedEstimator(df)\n",
325 | "print(est.test_conditional_independence(\n",
326 | "    'x1', 'x2', method=\"chi_square\", tol=0.05))  # not independent\n"
327 | ],
328 | "execution_count": 7,
329 | "outputs": [
330 | {
331 | "output_type": "stream",
332 | "text": [
333 | "/usr/local/lib/python3.6/dist-packages/statsmodels/tools/_testing.py:19: FutureWarning: pandas.util.testing is deprecated. Use the functions in the public API at pandas.testing instead.\n",
334 | " import pandas.util.testing as tm\n"
335 | ],
336 | "name": "stderr"
337 | },
338 | {
339 | "output_type": "stream",
340 | "text": [
341 | "True\n",
342 | "False\n"
343 | ],
344 | "name": "stdout"
345 | }
346 | ]
347 | },
348 | {
349 | "cell_type": "markdown",
350 | "metadata": {
351 | "colab_type": "text",
352 | "id": "I6P1x9vAdG3i"
353 | },
354 | "source": [
355 | "End of notebook."
356 | ]
357 | }
358 | ]
359 | }
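
The True/False results in the notebook above can be sanity-checked by hand from the four printed counts. The sketch below is a plain Pearson chi-square test without continuity correction, which is an assumption about the test variant and not necessarily what pgmpy does internally. The function name `chi_square_independence` is illustrative. For a 2x2 table the statistic has 1 degree of freedom, so its p-value can be written with the standard library as erfc(sqrt(chi2 / 2)).

```python
from math import erfc, sqrt


def chi_square_independence(table):
    """Pearson chi-square statistic and p-value for a 2x2 table of counts."""
    row_totals = [sum(row) for row in table]
    col_totals = [sum(col) for col in zip(*table)]
    n = sum(row_totals)

    chi2 = 0.0
    for i, row in enumerate(table):
        for j, observed in enumerate(row):
            # Expected count if the row and column variables were independent
            expected = row_totals[i] * col_totals[j] / n
            chi2 += (observed - expected) ** 2 / expected

    # Survival function of the chi-square distribution with 1 degree of freedom
    p_value = erfc(sqrt(chi2 / 2))
    return chi2, p_value


# Rows are x1 = 0, 1 and columns are x2 = 0, 1, using the counts printed above.
counts_df2 = [[20, 35], [15, 30]]   # x2 generated independently of x1
counts_df = [[58, 0], [9, 33]]      # x2 = x2 * x1, so x2 depends on x1

chi2_df2, p_df2 = chi_square_independence(counts_df2)
chi2_df, p_df = chi_square_independence(counts_df)

# At a 0.05 significance level, independence is kept when p >= 0.05.
print("df2 independent:", p_df2 >= 0.05)
print("df  independent:", p_df >= 0.05)
```

This hand calculation only covers the unconditional 2x2 case; a conditional test would repeat it within each stratum of the conditioning variables.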
--------------------------------------------------------------------------------
/8_3_5_deeplearning_gan_sam.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "accelerator": "GPU",
6 | "colab": {
7 | "name": "8_3_5_deeplearning_gan_sam.ipynb",
8 | "provenance": [],
9 | "collapsed_sections": []
10 | },
11 | "kernelspec": {
12 | "display_name": "Python 3",
13 | "language": "python",
14 | "name": "python3"
15 | },
16 | "language_info": {
17 | "codemirror_mode": {
18 | "name": "ipython",
19 | "version": 3
20 | },
21 | "file_extension": ".py",
22 | "mimetype": "text/x-python",
23 | "name": "python",
24 | "nbconvert_exporter": "python",
25 | "pygments_lexer": "ipython3",
26 | "version": "3.6.5"
27 | }
28 | },
29 | "cells": [
30 | {
31 | "cell_type": "markdown",
32 | "metadata": {
33 | "colab_type": "text",
34 | "id": "aoxI3DOK9vm2"
35 | },
36 | "source": [
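37 | "# 8.3 Implementing causal discovery with SAM (Structural Agnostic Model)\n",
38 | "\n",
39 | "This file is the implementation for Section 8.3.\n",
40 | "\n",
41 | "As in Section 7.5, we create synthetic data for the training course \"For managers: key points for career interviews with subordinates\" and run causal discovery with SAM."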
42 | ]
43 | },
44 | {
45 | "cell_type": "markdown",
46 | "metadata": {
47 | "colab_type": "text",
48 | "id": "2XdIDbdlejUk"
49 | },
50 | "source": [
51 | "## Setup before running the program"
52 | ]
53 | },
54 | {
55 | "cell_type": "code",
56 | "metadata": {
57 | "id": "_QZagoIYv44f",
58 | "colab_type": "code",
59 | "outputId": "1c133069-f1aa-4f79-ed78-ab14a644889b",
60 | "colab": {
61 | "base_uri": "https://localhost:8080/",
62 | "height": 122
63 | }
64 | },
65 | "source": [
66 | "# Downgrade PyTorch\n",
67 | "!pip install torch==1.4.0+cu92 torchvision==0.5.0+cu92 -f https://download.pytorch.org/whl/torch_stable.html"
68 | ],
69 | "execution_count": 0,
70 | "outputs": [
71 | {
72 | "output_type": "stream",
73 | "text": [
74 | "Looking in links: https://download.pytorch.org/whl/torch_stable.html\n",
75 | "Requirement already satisfied: torch==1.4.0+cu92 in /usr/local/lib/python3.6/dist-packages (1.4.0+cu92)\n",
76 | "Requirement already satisfied: torchvision==0.5.0+cu92 in /usr/local/lib/python3.6/dist-packages (0.5.0+cu92)\n",
77 | "Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from torchvision==0.5.0+cu92) (1.18.4)\n",
78 | "Requirement already satisfied: pillow>=4.1.1 in /usr/local/lib/python3.6/dist-packages (from torchvision==0.5.0+cu92) (7.0.0)\n",
79 | "Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from torchvision==0.5.0+cu92) (1.12.0)\n"
80 | ],
81 | "name": "stdout"
82 | }
83 | ]
84 | },
85 | {
86 | "cell_type": "code",
87 | "metadata": {
88 | "id": "iqh9FyP-wHGa",
89 | "colab_type": "code",
90 | "outputId": "d70b471a-b77c-47b2-af12-990ad611b00d",
91 | "colab": {
92 | "base_uri": "https://localhost:8080/",
93 | "height": 34
94 | }
95 | },
96 | "source": [
97 | "import torch \n",
98 | "print(torch.__version__)  # originally 1.5.0+cu101; downgraded to 1.4"
99 | ],
100 | "execution_count": 0,
101 | "outputs": [
102 | {
103 | "output_type": "stream",
104 | "text": [
105 | "1.4.0+cu92\n"
106 | ],
107 | "name": "stdout"
108 | }
109 | ]
110 | },
111 | {
112 | "cell_type": "code",
113 | "metadata": {
114 | "colab_type": "code",
115 | "id": "XZFKJwcu-_Oj",
116 | "colab": {}
117 | },
118 | "source": [
119 | "# Set the random seeds\n",
120 | "import random\n",
121 | "import numpy as np\n",
122 | "\n",
123 | "np.random.seed(1234)\n",
124 | "random.seed(1234)\n"
125 | ],
126 | "execution_count": 0,
127 | "outputs": []
128 | },
129 | {
130 | "cell_type": "code",
131 | "metadata": {
132 | "colab_type": "code",
133 | "id": "hx1idArc_F15",
134 | "colab": {}
135 | },
136 | "source": [
137 | "# Import the packages (libraries and functions) used\n",
138 | "# For generating standard normal samples\n",
139 | "from numpy.random import *\n",
140 | "\n",
141 | "# For plotting\n",
142 | "import matplotlib.pyplot as plt\n",
143 | "\n",
144 | "# Other\n",
145 | "import pandas as pd\n",
146 | "\n",
147 | "# Import the sigmoid function\n",
148 | "from scipy.special import expit\n"
149 | ],
150 | "execution_count": 0,
151 | "outputs": []
152 | },
153 | {
154 | "cell_type": "markdown",
155 | "metadata": {
156 | "colab_type": "text",
157 | "id": "AWqP6yeQlI_t"
158 | },
159 | "source": [
160 | "## Creating the data"
161 | ]
162 | },
163 | {
164 | "cell_type": "code",
165 | "metadata": {
166 | "colab_type": "code",
167 | "id": "QBsAEiQ77xww",
168 | "colab": {}
169 | },
170 | "source": [
171 | "# Number of data points\n",
172 | "num_data = 2000\n",
173 | "\n",
174 | "# Manager's enthusiasm for developing subordinates\n",
175 | "x = np.random.uniform(low=-1, high=1, size=num_data)  # uniform random numbers from -1 to 1\n",
176 | "\n",
177 | "# Whether the manager attended the training course \"For managers: key points for career interviews with subordinates\"\n",
178 | "e_z = randn(num_data)  # generate noise\n",
179 | "z_prob = expit(-5.0*x+5*e_z)\n",
180 | "Z = np.array([])\n",
181 | "\n",
182 | "# Sample attendance Z (0 or 1) for each manager with probability z_prob\n",
183 | "for i in range(num_data):\n",
184 | "    Z_i = np.random.choice(2, size=1, p=[1-z_prob[i], z_prob[i]])[0]\n",
185 | "    Z = np.append(Z, Z_i)\n",
186 | "\n",
187 | "# Nonlinear treatment effect: changes stepwise with the enthusiasm x\n",
188 | "t = np.zeros(num_data)\n",
189 | "for i in range(num_data):\n",
190 | "    if x[i] < 0:\n",
191 | "        t[i] = 0.5\n",
192 | "    elif x[i] >= 0 and x[i] < 0.5:\n",
193 | "        t[i] = 0.7\n",
194 | "    elif x[i] >= 0.5:\n",
195 | "        t[i] = 1.0\n",
196 | "\n",
197 | "e_y = randn(num_data)\n",
198 | "Y = 2.0 + t*Z + 0.3*x + 0.1*e_y\n",
199 | "\n",
200 | "\n",
201 | "# Generate the additional variables introduced in this chapter\n",
202 | "\n",
203 | "# Y2: the subordinate's satisfaction with team members, on a 5-point scale from 1 to 5\n",
204 | "Y2 = np.random.choice([1.0, 2.0, 3.0, 4.0, 5.0],\n",
205 | "                      num_data, p=[0.1, 0.2, 0.3, 0.2, 0.2])\n",
206 | "\n",
207 | "# Y3: the subordinate's job satisfaction\n",
208 | "e_y3 = randn(num_data)\n",
209 | "Y3 = 3*Y + Y2 + e_y3\n",
210 | "\n",
211 | "# Y4: the subordinate's job performance\n",
212 | "e_y4 = randn(num_data)\n",
213 | "Y4 = 3*Y3 + 2*e_y4 + 5\n"
214 | ],
215 | "execution_count": 0,
216 | "outputs": []
217 | },
218 | {
219 | "cell_type": "markdown",
220 | "metadata": {
221 | "colab_type": "text",
222 | "id": "BHcdUlW9koTa"
223 | },
224 | "source": [
225 | "## Build a table of the data, normalize it, and visualize it"
226 | ]
227 | },
228 | {
229 | "cell_type": "code",
230 | "metadata": {
231 | "colab_type": "code",
232 | "id": "1EMwdGIIIPrK",
233 | "outputId": "9530a268-91e5-4d68-b8ab-c2d1f62dddcd",
234 | "colab": {
235 | "base_uri": "https://localhost:8080/",
236 | "height": 195
237 | }
238 | },
239 | "source": [
240 | "df = pd.DataFrame({'x': x,\n",
241 | "                   'Z': Z,\n",
242 | "                   't': t,\n",
243 | "                   'Y': Y,\n",
244 | "                   'Y2': Y2,\n",
245 | "                   'Y3': Y3,\n",
246 | "                   'Y4': Y4,\n",
247 | "                   })\n",
248 | "\n",
249 | "del df[\"t\"]  # drop t because it cannot be observed\n",
250 | "\n",
251 | "df.head()  # show the first rows\n"
252 | ],
253 | "execution_count": 0,
254 | "outputs": [
255 | {
256 | "output_type": "execute_result",
257 | "data": {
335 | "text/plain": [
336 | " x Z Y Y2 Y3 Y4\n",
337 | "0 -0.616961 1.0 2.286924 2.0 8.732544 30.326507\n",
338 | "1 0.244218 1.0 2.864636 3.0 10.743959 37.149014\n",
339 | "2 -0.124545 0.0 2.198515 3.0 10.569163 38.481185\n",
340 | "3 0.570717 1.0 3.230572 3.0 12.312526 43.709229\n",
341 | "4 0.559952 0.0 2.459267 5.0 12.418739 40.833938"
342 | ]
343 | },
344 | "metadata": {
345 | "tags": []
346 | },
347 | "execution_count": 6
348 | }
349 | ]
350 | },
351 | {
352 | "cell_type": "markdown",
353 | "metadata": {
354 | "colab_type": "text",
355 | "id": "1TPIeXDg6QDG"
356 | },
357 | "source": [
358 | "## Running inference with SAM"
359 | ]
360 | },
361 | {
362 | "cell_type": "code",
363 | "metadata": {
364 | "colab_type": "code",
365 | "id": "edNNPSLY6u6d",
366 | "outputId": "36f912b0-1dbb-4516-8ea5-d5af39cce918",
367 | "colab": {
368 | "base_uri": "https://localhost:8080/",
369 | "height": 386
370 | }
371 | },
372 | "source": [
373 | "!pip install cdt==0.5.18"
374 | ],
375 | "execution_count": 0,
376 | "outputs": [
377 | {
378 | "output_type": "stream",
379 | "text": [
380 | "Requirement already satisfied: cdt==0.5.18 in /usr/local/lib/python3.6/dist-packages (0.5.18)\n",
381 | "Requirement already satisfied: GPUtil in /usr/local/lib/python3.6/dist-packages (from cdt==0.5.18) (1.4.0)\n",
382 | "Requirement already satisfied: statsmodels in /usr/local/lib/python3.6/dist-packages (from cdt==0.5.18) (0.10.2)\n",
383 | "Requirement already satisfied: tqdm in /usr/local/lib/python3.6/dist-packages (from cdt==0.5.18) (4.41.1)\n",
384 | "Requirement already satisfied: requests in /usr/local/lib/python3.6/dist-packages (from cdt==0.5.18) (2.23.0)\n",
385 | "Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from cdt==0.5.18) (1.18.4)\n",
386 | "Requirement already satisfied: joblib in /usr/local/lib/python3.6/dist-packages (from cdt==0.5.18) (0.15.1)\n",
387 | "Requirement already satisfied: pandas in /usr/local/lib/python3.6/dist-packages (from cdt==0.5.18) (1.0.4)\n",
388 | "Requirement already satisfied: scipy in /usr/local/lib/python3.6/dist-packages (from cdt==0.5.18) (1.4.1)\n",
389 | "Requirement already satisfied: networkx in /usr/local/lib/python3.6/dist-packages (from cdt==0.5.18) (2.4)\n",
390 | "Requirement already satisfied: scikit-learn in /usr/local/lib/python3.6/dist-packages (from cdt==0.5.18) (0.22.2.post1)\n",
391 | "Requirement already satisfied: skrebate in /usr/local/lib/python3.6/dist-packages (from cdt==0.5.18) (0.6)\n",
392 | "Requirement already satisfied: patsy>=0.4.0 in /usr/local/lib/python3.6/dist-packages (from statsmodels->cdt==0.5.18) (0.5.1)\n",
393 | "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests->cdt==0.5.18) (2020.4.5.1)\n",
394 | "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests->cdt==0.5.18) (2.9)\n",
395 | "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests->cdt==0.5.18) (1.24.3)\n",
396 | "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->cdt==0.5.18) (3.0.4)\n",
397 | "Requirement already satisfied: python-dateutil>=2.6.1 in /usr/local/lib/python3.6/dist-packages (from pandas->cdt==0.5.18) (2.8.1)\n",
398 | "Requirement already satisfied: pytz>=2017.2 in /usr/local/lib/python3.6/dist-packages (from pandas->cdt==0.5.18) (2018.9)\n",
399 | "Requirement already satisfied: decorator>=4.3.0 in /usr/local/lib/python3.6/dist-packages (from networkx->cdt==0.5.18) (4.4.2)\n",
400 | "Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from patsy>=0.4.0->statsmodels->cdt==0.5.18) (1.12.0)\n"
401 | ],
402 | "name": "stdout"
403 | }
404 | ]
405 | },
406 | {
407 | "cell_type": "markdown",
408 | "metadata": {
409 | "colab_type": "text",
410 | "id": "ihTvgRcv1E8s"
411 | },
412 | "source": [
413 | "### Implementing the SAM discriminator D"
414 | ]
415 | },
416 | {
417 | "cell_type": "code",
418 | "metadata": {
419 | "colab_type": "code",
420 | "id": "sJQ2_9LY8MQ8",
421 | "colab": {}
422 | },
423 | "source": [
424 | "# Import what we need from PyTorch\n",
425 | "import torch\n",
426 | "import torch.nn as nn\n",
427 | "\n",
428 | "\n",
429 | "class SAMDiscriminator(nn.Module):\n",
430 | "    \"\"\"Neural network for the SAM discriminator\n",
431 | "    \"\"\"\n",
432 | "\n",
433 | "    def __init__(self, nfeatures, dnh, hlayers):\n",
434 | "        super(SAMDiscriminator, self).__init__()\n",
435 | "\n",
436 | "        # ----------------------------------\n",
437 | "        # Build the network\n",
438 | "        # ----------------------------------\n",
439 | "        self.nfeatures = nfeatures  # number of input variables\n",
440 | "\n",
441 | "        layers = []\n",
442 | "        layers.append(nn.Linear(nfeatures, dnh))\n",
443 | "        layers.append(nn.BatchNorm1d(dnh))\n",
444 | "        layers.append(nn.LeakyReLU(.2))\n",
445 | "\n",
446 | "        for i in range(hlayers-1):\n",
447 | "            layers.append(nn.Linear(dnh, dnh))\n",
448 | "            layers.append(nn.BatchNorm1d(dnh))\n",
449 | "            layers.append(nn.LeakyReLU(.2))\n",
450 | "\n",
451 | "        layers.append(nn.Linear(dnh, 1))  # final output\n",
452 | "\n",
453 | "        self.layers = nn.Sequential(*layers)\n",
454 | "\n",
455 | "        # ----------------------------------\n",
456 | "        # Prepare the mask (a matrix with 1 on the diagonal and 0 elsewhere)\n",
457 | "        # ----------------------------------\n",
458 | "        mask = torch.eye(nfeatures, nfeatures)  # identity matrix, (number of variables) x (number of variables)\n",
459 | "        self.register_buffer(\"mask\", mask.unsqueeze(0))  # store the identity-matrix mask\n",
460 | "\n",
461 | "        # Note: register_buffer is a PyTorch method that registers a tensor which is not a model parameter but is used later in forward\n",
462 | "        # After registration it can be accessed as self.<name>\n",
463 | "        # https://pytorch.org/docs/stable/nn.html?highlight=register_buffer#torch.nn.Module.register_buffer\n",
464 | "\n",
465 | "    def forward(self, input, obs_data=None):\n",
466 | "        \"\"\" Forward pass\n",
467 | "        Args:\n",
468 | "            input (torch.Size([number of samples, number of observed variables])): observed data or generated data\n",
469 | "            obs_data (torch.Size([number of samples, number of observed variables])): observed data\n",
470 | "        Returns:\n",
471 | "            torch.Tensor: the judgment of whether the data is observed or generated\n",
472 | "        \"\"\"\n",
473 | "\n",
474 | "        if obs_data is not None:\n",
475 | "            # Case: feeding generated data to the discriminator\n",
476 | "            return [self.layers(i) for i in torch.unbind(obs_data.unsqueeze(1) * (1 - self.mask)\n",
477 | "                                                         + input.unsqueeze(1) * self.mask, 1)]\n",
478 | "            # Only the diagonal entries come from generated data; the rest come from observed data\n",
479 | "            # For each variable, the generated value of that variable is mixed with the observed values of the others, and one such mix is fed in per variable\n",
480 | "            # torch.unbind(x, 1) splits tensor x along dimension 1 into a tuple\n",
481 | "            # With a minibatch size of 2000 and 6 observed variables:\n",
482 | "            # [2000,6] -> [2000,6,6] -> ([2000,6], [2000,6], [2000,6], [2000,6], [2000,6], [2000,6]) -> ([2000,1], [2000,1], [2000,1], [2000,1], [2000,1], [2000,1])\n",
483 | "            # The return value is [torch.Size([2000, 1]), torch.Size([2000, 1]), torch.Size([2000, 1]), torch.Size([2000, 1]), torch.Size([2000, 1]), torch.Size([2000, 1])]\n",
484 | "\n",
485 | "            # Note: the discriminator never judges a sample in which every variable is generated.\n",
486 | "            # Instead, one generated variable is combined with the observed values of the other variables, and the discriminator judges whether the result is observed or generated\n",
487 | "\n",
488 | "        else:\n",
489 | "            # Case: feeding observed data to the discriminator\n",
490 | "\n",
491 | "            return self.layers(input)\n",
492 | "            # The return value has shape torch.Size([2000, 1])\n",
493 | "\n",
494 | "\n",
495 | "    def reset_parameters(self):\n",
496 | "        \"\"\"Reset the weight parameters of discriminator D\"\"\"\n",
497 | "        for layer in self.layers:\n",
498 | "            if hasattr(layer, 'reset_parameters'):\n",
499 | "                layer.reset_parameters()\n"
500 | ],
501 | "execution_count": 0,
502 | "outputs": []
503 | },
504 | {
505 | "cell_type": "markdown",
506 | "metadata": {
507 | "colab_type": "text",
508 | "id": "yLyjZsSc1S2i"
509 | },
510 | "source": [
511 | "### Implementing the SAM generator G"
512 | ]
513 | },
514 | {
515 | "cell_type": "code",
516 | "metadata": {
517 | "colab_type": "code",
518 | "id": "pBUh-fKh8X-E",
519 | "outputId": "b30b006a-9c56-4e93-9185-1d4dc3d2abfb",
520 | "colab": {
521 | "base_uri": "https://localhost:8080/",
522 | "height": 72
523 | }
524 | },
525 | "source": [
526 | "from cdt.utils.torch import ChannelBatchNorm1d, MatrixSampler, Linear3D\n",
527 | "\n",
528 | "\n",
529 | "class SAMGenerator(nn.Module):\n",
530 | "    \"\"\"Neural network for the SAM generator\n",
531 | "    \"\"\"\n",
532 | "\n",
533 | "    def __init__(self, data_shape, nh):\n",
534 | "        \"\"\"Initialization\"\"\"\n",
535 | "        super(SAMGenerator, self).__init__()\n",
536 | "\n",
537 | "        # ----------------------------------\n",
538 | "        # Create skeleton, a mask that is 0 on the diagonal and 1 elsewhere\n",
539 | "        # Note: the last row is all ones\n",
540 | "        # ----------------------------------\n",
541 | "        nb_vars = data_shape[1]  # number of variables\n",
542 | "        skeleton = 1 - torch.eye(nb_vars + 1, nb_vars)\n",
543 | "\n",
544 | "        self.register_buffer('skeleton', skeleton)\n",
545 | "\n",
546 | "        # Note: register_buffer is the PyTorch method for registering a tensor that is not a model parameter but is used later in forward\n",
547 | "        # After registration it is available as self.<name>\n",
548 | "        # https://pytorch.org/docs/stable/nn.html?highlight=register_buffer#torch.nn.Module.register_buffer\n",
549 | "\n",
550 | "        # ----------------------------------\n",
551 | "        # Build the network\n",
552 | "        # ----------------------------------\n",
553 | "        # Input layer (a fully connected layer in SAM's form)\n",
554 | "        self.input_layer = Linear3D(\n",
555 | "            (nb_vars, nb_vars + 1, nh))  # nh is the number of hidden-layer neurons\n",
556 | "        # https://github.com/FenTechSolutions/CausalDiscoveryToolbox/blob/32200779ab9b63762be3a24a2147cff09ba2bb72/cdt/utils/torch.py#L289\n",
557 | "\n",
558 | "        # Hidden layer\n",
559 | "        layers = []\n",
560 | "        # Module that reshapes 2D into 1D and then applies batch normalization\n",
561 | "        layers.append(ChannelBatchNorm1d(nb_vars, nh))\n",
562 | "        layers.append(nn.Tanh())\n",
563 | "        self.layers = nn.Sequential(*layers)\n",
564 | "\n",
565 | "        # ChannelBatchNorm1d\n",
566 | "        # https://github.com/FenTechSolutions/CausalDiscoveryToolbox/blob/32200779ab9b63762be3a24a2147cff09ba2bb72/cdt/utils/torch.py#L130\n",
567 | "\n",
568 | "        # Output layer (again, a fully connected layer in SAM's form)\n",
569 | "        self.output_layer = Linear3D((nb_vars, nh, 1))\n",
570 | "\n",
571 | "    def forward(self, data, noise, adj_matrix, drawn_neurons=None):\n",
572 | "        \"\"\"Forward pass\n",
573 | "        Args:\n",
574 | "            data (torch.Tensor): observed data\n",
575 | "            noise (torch.Tensor): noise used for data generation\n",
576 | "            adj_matrix (torch.Tensor): causal structure matrix M representing the causal relationships\n",
577 | "            drawn_neurons (torch.Tensor): complexity matrix Z that controls the complexity of Linear3D\n",
578 | "        Returns:\n",
579 | "            torch.Tensor: generated data\n",
580 | "        \"\"\"\n",
581 | "\n",
582 | "        # Input layer\n",
583 | "        x = self.input_layer(data, noise, adj_matrix *\n",
584 | "                             self.skeleton)  # Linear3D\n",
585 | "\n",
586 | "        # Hidden layer (batch normalization and Tanh)\n",
587 | "        x = self.layers(x)\n",
588 | "\n",
589 | "        # Output layer\n",
590 | "        output = self.output_layer(\n",
591 | "            x, noise=None, adj_matrix=drawn_neurons)  # Linear3D\n",
592 | "\n",
593 | "        return output.squeeze(2)\n",
594 | "\n",
595 | "    def reset_parameters(self):\n",
596 | "        \"\"\"Initialize the weight parameters\"\"\"\n",
597 | "\n",
598 | "        self.input_layer.reset_parameters()\n",
599 | "        self.output_layer.reset_parameters()\n",
600 | "\n",
601 | "        for layer in self.layers:\n",
602 | "            if hasattr(layer, 'reset_parameters'):\n",
603 | "                layer.reset_parameters()\n"
604 | ],
605 | "execution_count": 0,
606 | "outputs": [
607 | {
608 | "output_type": "stream",
609 | "text": [
610 | "Detecting 1 CUDA device(s).\n",
611 | "sklearn.externals.joblib is deprecated in 0.21 and will be removed in 0.23. Please import this functionality directly from joblib, which can be installed with: pip install joblib. If this warning is raised when loading pickled models, you may need to re-serialize those models with scikit-learn 0.21+.\n"
612 | ],
613 | "name": "stderr"
614 | }
615 | ]
616 | },
617 | {
618 | "cell_type": "markdown",
619 | "metadata": {
620 | "colab_type": "text",
621 | "id": "2MubteRua0mj"
622 | },
623 | "source": [
624 | "### SAM loss function"
625 | ]
626 | },
627 | {
628 | "cell_type": "code",
629 | "metadata": {
630 | "colab_type": "code",
631 | "id": "Hy2GqNNdapc6",
632 | "colab": {}
633 | },
634 | "source": [
635 | "# Loss term added so that the causal structure matrix M describing the network becomes a DAG (directed acyclic graph)\n",
636 | "\n",
637 | "def notears_constr(adj_m, max_pow=None):\n",
638 | "    \"\"\"No Tears constraint for binary adjacency matrices.\n",
639 | "    Args:\n",
640 | "        adj_m (array-like): Adjacency matrix of the graph\n",
641 | "        max_pow (int): maximum number of terms of the infinite sum to compute.\n",
642 | "            Defaults to the size of the adjacency matrix\n",
643 | "    Returns:\n",
644 | "        np.ndarray or torch.Tensor: Scalar value of the loss with the type\n",
645 | "            depending on the input.\n",
646 | "    Reference: https://github.com/FenTechSolutions/CausalDiscoveryToolbox/blob/32200779ab9b63762be3a24a2147cff09ba2bb72/cdt/utils/loss.py#L215\n",
647 | "    \"\"\"\n",
648 | "    m_exp = [adj_m]\n",
649 | "    if max_pow is None:\n",
650 | "        max_pow = adj_m.shape[1]\n",
651 | "    while m_exp[-1].sum() > 0 and len(m_exp) < max_pow:\n",
652 | "        m_exp.append(m_exp[-1] @ adj_m/len(m_exp))  # m_exp[k] = adj_m^(k+1) / k!\n",
653 | "\n",
654 | "    return sum([i.diag().sum() for i in m_exp])\n",
655 | " "
656 | ],
657 | "execution_count": 0,
658 | "outputs": []
659 | },
660 | {
661 | "cell_type": "markdown",
662 | "metadata": {
663 | "colab_type": "text",
664 | "id": "d01nY6IKKmXe"
665 | },
666 | "source": [
667 | "### Function that runs SAM training"
668 | ]
669 | },
670 | {
671 | "cell_type": "code",
672 | "metadata": {
673 | "colab_type": "code",
674 | "id": "LdgNruwmJkxj",
675 | "colab": {}
676 | },
677 | "source": [
678 | "from sklearn.preprocessing import scale\n",
679 | "from torch import optim\n",
680 | "from torch.utils.data import DataLoader\n",
681 | "from tqdm import tqdm\n",
682 | "\n",
683 | "\n",
684 | "def run_SAM(in_data, lr_gen, lr_disc, lambda1, lambda2, hlayers, nh, dnh, train_epochs, test_epochs, device):\n",
685 | "    '''Function that runs SAM training'''\n",
686 | "\n",
687 | "    # ---------------------------------------------------\n",
688 | "    # Preprocess the input data\n",
689 | "    # ---------------------------------------------------\n",
690 | "    list_nodes = list(in_data.columns)  # list of the input data's column names\n",
691 | "    data = scale(in_data[list_nodes].values)  # standardize the input data\n",
692 | "    nb_var = len(list_nodes)  # number of input variables = d\n",
693 | "    data = data.astype('float32')  # cast the input data to float32\n",
694 | "    data = torch.from_numpy(data).to(device)  # convert the input data to a PyTorch tensor\n",
695 | "    rows, cols = data.size()  # rows is the number of samples, cols is the number of variables\n",
696 | "\n",
697 | "    # ---------------------------------------------------\n",
698 | "    # Create the DataLoader (the batch size is the whole dataset)\n",
699 | "    # ---------------------------------------------------\n",
700 | "    batch_size = rows  # each minibatch contains all of the input data\n",
701 | "    data_iterator = DataLoader(data, batch_size=batch_size,\n",
702 | "                               shuffle=True, drop_last=True)\n",
703 | "    # Note: drop_last discards the leftover samples at the end when the data is drawn in chunks of batch_size\n",
704 | "\n",
705 | "    # ---------------------------------------------------\n",
706 | "    # [Generator] Create the networks and initialize their parameters\n",
707 | "    # cols: number of input variables, nh: number of hidden neurons, hlayers: number of hidden layers\n",
708 | "    # neuron_sampler is the network that learns the functional-gate variables z\n",
709 | "    # graph_sampler is the network that learns the structural-gate variables a\n",
710 | "    # ---------------------------------------------------\n",
711 | "    sam = SAMGenerator((batch_size, cols), nh).to(device)  # generator G\n",
712 | "    graph_sampler = MatrixSampler(nb_var, mask=None, gumbel=False).to(\n",
713 | "        device)  # network that produces the causal structure matrix M\n",
714 | "    neuron_sampler = MatrixSampler((nh, nb_var), mask=False, gumbel=True).to(\n",
715 | "        device)  # network that produces the complexity matrix Z\n",
716 | "\n",
717 | "    # Note: MatrixSampler is a neural network that uses Gumbel-Softmax to output 0 or 1\n",
718 | "    # We use MatrixSampler, the module implemented by the SAM authors\n",
719 | "    # https://github.com/FenTechSolutions/CausalDiscoveryToolbox/blob/32200779ab9b63762be3a24a2147cff09ba2bb72/cdt/utils/torch.py#L212\n",
720 | "\n",
721 | "    # Initialize the weight parameters\n",
722 | "    sam.reset_parameters()\n",
723 | "    graph_sampler.weights.data.fill_(2)\n",
724 | "\n",
725 | "    # ---------------------------------------------------\n",
726 | "    # [Discriminator] Create the network and initialize its parameters\n",
727 | "    # cols: number of input variables, dnh: number of hidden neurons, hlayers: number of hidden layers\n",
728 | "    # ---------------------------------------------------\n",
729 | "    discriminator = SAMDiscriminator(cols, dnh, hlayers).to(device)\n",
730 | "    discriminator.reset_parameters()  # initialize the weight parameters\n",
731 | "\n",
732 | "    # ---------------------------------------------------\n",
733 | "    # Optimizer settings\n",
734 | "    # ---------------------------------------------------\n",
735 | "    # Generator\n",
736 | "\n",
737 | "    g_optimizer = optim.Adam(sam.parameters(), lr=lr_gen)\n",
738 | "    graph_optimizer = optim.Adam(graph_sampler.parameters(), lr=lr_gen)\n",
739 | "    neuron_optimizer = optim.Adam(neuron_sampler.parameters(), lr=lr_gen)\n",
740 | "\n",
741 | "    # Discriminator\n",
742 | "    d_optimizer = optim.Adam(discriminator.parameters(), lr=lr_disc)\n",
743 | "\n",
744 | "    # Loss function\n",
745 | "    criterion = nn.BCEWithLogitsLoss()\n",
746 | "    # nn.BCEWithLogitsLoss() is binary cross entropy combined with the logistic (sigmoid) function\n",
747 | "    # https://pytorch.org/docs/stable/nn.html#bcewithlogitsloss\n",
748 | "\n",
749 | "    # Parameters for the DAG constraint in the loss function\n",
750 | "    dagstart = 0.5\n",
751 | "    dagpenalization_increase = 0.001*10\n",
752 | "\n",
753 | "    # ---------------------------------------------------\n",
754 | "    # Prepare the variables used in the forward pass and the loss computation\n",
755 | "    # ---------------------------------------------------\n",
756 | "    _true = torch.ones(1).to(device)\n",
757 | "    _false = torch.zeros(1).to(device)\n",
758 | "\n",
759 | "    noise = torch.randn(batch_size, nb_var).to(device)  # generation noise used by the generator G\n",
760 | "    noise_row = torch.ones(1, nb_var).to(device)\n",
761 | "\n",
762 | "    output = torch.zeros(nb_var, nb_var).to(device)  # the estimated adjacency matrix\n",
763 | "    output_loss = torch.zeros(1, 1).to(device)\n",
764 | "\n",
765 | "    # ---------------------------------------------------\n",
766 | "    # Train the networks by running forward passes\n",
767 | "    # ---------------------------------------------------\n",
768 | "    pbar = tqdm(range(train_epochs + test_epochs))  # show a progress bar\n",
769 | "\n",
770 | "    for epoch in pbar:\n",
771 | "        for i_batch, batch in enumerate(data_iterator):\n",
772 | "\n",
773 | "            # Reset the gradients held by the optimizers\n",
774 | "            g_optimizer.zero_grad()\n",
775 | "            graph_optimizer.zero_grad()\n",
776 | "            neuron_optimizer.zero_grad()\n",
777 | "            d_optimizer.zero_grad()\n",
778 | "\n",
779 | "            # Get the causal structure matrix M (drawn_graph) and the complexity matrix Z (drawn_neurons) from MatrixSampler\n",
780 | "            drawn_graph = graph_sampler()\n",
781 | "            drawn_neurons = neuron_sampler()\n",
782 | "            # drawn_graph has size torch.Size([nb_var, nb_var]); its values are 0 or 1\n",
783 | "            # drawn_neurons has size torch.Size([nh, nb_var]); its values are 0 or 1\n",
784 | "\n",
785 | "            # Resample the noise and generate pseudo data with the generator G\n",
786 | "            noise.normal_()\n",
787 | "            generated_variables = sam(data=batch, noise=noise,\n",
788 | "                                      adj_matrix=torch.cat(\n",
789 | "                                          [drawn_graph, noise_row], 0),\n",
790 | "                                      drawn_neurons=drawn_neurons)\n",
791 | "\n",
792 | "            # Judge with the discriminator D\n",
793 | "            # Returns a list over the observed variables; each element has torch.Size([n_samples, 1])\n",
794 | "            disc_vars_d = discriminator(generated_variables.detach(), batch)\n",
795 | "            # Returns a list over the observed variables; each element has torch.Size([n_samples, 1])\n",
796 | "            disc_vars_g = discriminator(generated_variables, batch)\n",
797 | "            true_vars_disc = discriminator(batch)  # returns torch.Size([n_samples, 1])\n",
798 | "\n",
799 | "            # Compute the loss (DCGAN)\n",
800 | "            disc_loss = sum([criterion(gen, _false.expand_as(gen)) for gen in disc_vars_d]) / nb_var \\\n",
801 | "                + criterion(true_vars_disc, _true.expand_as(true_vars_disc))\n",
802 | "\n",
803 | "            gen_loss = sum([criterion(gen,\n",
804 | "                                      _true.expand_as(gen))\n",
805 | "                            for gen in disc_vars_g])\n",
806 | "\n",
807 | "            # Loss computation (the original f-GAN loss from the SAM paper)\n",
808 | "            #disc_loss = sum([torch.mean(torch.exp(gen - 1)) for gen in disc_vars_d]) / nb_var - torch.mean(true_vars_disc)\n",
809 | "            #gen_loss = -sum([torch.mean(torch.exp(gen - 1)) for gen in disc_vars_g])\n",
810 | "\n",
811 | "            # Backpropagation and parameter update for the discriminator D\n",
812 | "            if epoch < train_epochs:\n",
813 | "                disc_loss.backward()\n",
814 | "                d_optimizer.step()\n",
815 | "\n",
816 | "            # Remaining terms of the generator G's loss (matrix complexity and the NO TEARS DAG term)\n",
817 | "            struc_loss = lambda1 / batch_size*drawn_graph.sum()  # loss on M\n",
818 | "            func_loss = lambda2 / batch_size*drawn_neurons.sum()  # loss on Z\n",
819 | "\n",
820 | "            regul_loss = struc_loss + func_loss\n",
821 | "\n",
822 | "            if epoch <= train_epochs * dagstart:\n",
823 | "                # Before the threshold epoch, do not apply the NO TEARS constraint that pushes M toward a DAG\n",
824 | "                loss = gen_loss + regul_loss\n",
825 | "\n",
826 | "            else:\n",
827 | "                # After the threshold epoch, apply the NO TEARS constraint so that M becomes a DAG\n",
828 | "                filters = graph_sampler.get_proba()  # get the entries of M (as probabilities of being 1, not as 0/1 values)\n",
829 | "                dag_constraint = notears_constr(filters*filters)  # compute the NO TEARS term\n",
830 | "\n",
831 | "                # Strengthen the DAG regularization gradually and linearly\n",
832 | "                loss = gen_loss + regul_loss + \\\n",
833 | "                    ((epoch - train_epochs * dagstart) *\n",
834 | "                     dagpenalization_increase) * dag_constraint\n",
835 | "\n",
836 | "            if epoch >= train_epochs:\n",
837 | "                # In a test epoch, accumulate the results\n",
838 | "                output.add_(filters.data)\n",
839 | "                output_loss.add_(gen_loss.data)\n",
840 | "            else:\n",
841 | "                # In a training epoch, backpropagate and update the generator G\n",
842 | "                # retain_graph=True keeps the computation graph after backward(); the three step() calls then update the parameters\n",
843 | "                loss.backward(retain_graph=True)\n",
844 | "                g_optimizer.step()\n",
845 | "                graph_optimizer.step()\n",
846 | "                neuron_optimizer.step()\n",
847 | "\n",
848 | "            # Show progress\n",
849 | "            if epoch % 50 == 0:\n",
850 | "                pbar.set_postfix(gen=gen_loss.item()/cols,\n",
851 | "                                 disc=disc_loss.item(),\n",
852 | "                                 regul_loss=regul_loss.item(),\n",
853 | "                                 tot=loss.item())\n",
854 | "\n",
855 | "    return output.cpu().numpy()/test_epochs, output_loss.cpu().numpy()/test_epochs/cols  # return M and the loss\n"
856 | ],
857 | "execution_count": 0,
858 | "outputs": []
859 | },
860 | {
861 | "cell_type": "markdown",
862 | "metadata": {
863 | "colab_type": "text",
864 | "id": "S5SXuXOCUgmg"
865 | },
866 | "source": [
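867 | "### Check that a GPU is available\n",
868 | "\n",
869 | "From the menu at the top of the screen, choose Runtime > Change runtime type to open the Notebook settings.\n",
870 | "\n",
871 | "Select GPU as the hardware accelerator and click Save."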
872 | ]
873 | },
874 | {
875 | "cell_type": "code",
876 | "metadata": {
877 | "colab_type": "code",
878 | "id": "ClTdYzxzXsL2",
879 | "outputId": "854e7cdb-d51e-4cb4-fbc6-9830415f4c44",
880 | "colab": {
881 | "base_uri": "https://localhost:8080/",
882 | "height": 34
883 | }
884 | },
885 | "source": [
886 | "# Check whether a GPU can be used: True or False\n",
887 | "torch.cuda.is_available()\n"
888 | ],
889 | "execution_count": 0,
890 | "outputs": [
891 | {
892 | "output_type": "execute_result",
893 | "data": {
894 | "text/plain": [
895 | "True"
896 | ]
897 | },
898 | "metadata": {
899 | "tags": []
900 | },
901 | "execution_count": 12
902 | }
903 | ]
904 | },
905 | {
906 | "cell_type": "markdown",
907 | "metadata": {
908 | "colab_type": "text",
909 | "id": "R-FzZ-W3Xseu"
910 | },
911 | "source": [
912 | "### Run SAM training"
913 | ]
914 | },
915 | {
916 | "cell_type": "code",
917 | "metadata": {
918 | "colab_type": "code",
919 | "id": "xfqAztolY1fo",
920 | "outputId": "8489d950-76d7-46fb-c951-6381c4f871d2",
921 | "colab": {
922 | "base_uri": "https://localhost:8080/",
923 | "height": 826
924 | }
925 | },
926 | "source": [
927 | "# Print numpy arrays with 2 decimal places\n",
928 | "np.set_printoptions(precision=2, floatmode='fixed', suppress=True)\n",
929 | "\n",
930 | "# Lists that store the causal discovery results\n",
931 | "m_list = []\n",
932 | "loss_list = []\n",
933 | "\n",
934 | "for i in range(5):\n",
935 | "    m, loss = run_SAM(in_data=df, lr_gen=0.01*0.5,\n",
936 | "                      lr_disc=0.01*0.5*2,\n",
937 | "                      #lambda1=0.01, lambda2=1e-05,\n",
938 | "                      lambda1=5.0*20, lambda2=0.005*20,\n",
939 | "                      hlayers=2,\n",
940 | "                      nh=200, dnh=200,\n",
941 | "                      train_epochs=10000,\n",
942 | "                      test_epochs=1000,\n",
943 | "                      device='cuda:0')\n",
944 | "\n",
945 | "    print(loss)\n",
946 | "    print(m)\n",
947 | "\n",
948 | "    m_list.append(m)\n",
949 | "    loss_list.append(loss)\n",
950 | "\n",
951 | "# Network structure (average over the 5 runs)\n",
952 | "print(sum(m_list) / len(m_list))\n",
953 | "\n",
954 | "# Ideally, m would look like this\n",
955 | "#      x   Z   Y   Y2  Y3  Y4\n",
956 | "# x    0   1   1   0   0   0\n",
957 | "# Z    0   0   1   0   0   0\n",
958 | "# Y    0   0   0   0   1   0\n",
959 | "# Y2   0   0   0   0   1   0\n",
960 | "# Y3   0   0   0   0   0   1\n",
961 | "# Y4   0   0   0   0   0   0\n"
962 | ],
963 | "execution_count": 0,
964 | "outputs": [
965 | {
966 | "output_type": "stream",
967 | "text": [
968 | "100%|██████████| 11000/11000 [05:20<00:00, 34.29it/s, disc=0.259, gen=5.63, regul_loss=0.564, tot=42.9]\n",
969 | " 0%| | 4/11000 [00:00<05:14, 34.97it/s, disc=1.43, gen=0.626, regul_loss=1.48, tot=5.23]"
970 | ],
971 | "name": "stderr"
972 | },
973 | {
974 | "output_type": "stream",
975 | "text": [
976 | "[[7.23]]\n",
977 | "[[0.00 0.11 0.96 0.00 0.01 0.00]\n",
978 | " [0.37 0.00 0.96 0.00 0.81 0.00]\n",
979 | " [0.00 0.03 0.00 0.99 1.00 0.66]\n",
980 | " [0.02 0.00 0.00 0.00 0.07 0.00]\n",
981 | " [0.02 0.00 0.02 1.00 0.00 0.98]\n",
982 | " [0.00 0.00 0.04 0.59 0.25 0.00]]\n"
983 | ],
984 | "name": "stdout"
985 | },
986 | {
987 | "output_type": "stream",
988 | "text": [
989 | "100%|██████████| 11000/11000 [05:22<00:00, 34.15it/s, disc=0.301, gen=5.6, regul_loss=0.515, tot=40.2]\n",
990 | " 0%| | 3/11000 [00:00<06:53, 26.59it/s, disc=1.46, gen=0.8, regul_loss=1.38, tot=6.18]"
991 | ],
992 | "name": "stderr"
993 | },
994 | {
995 | "output_type": "stream",
996 | "text": [
997 | "[[7.37]]\n",
998 | "[[0.00 1.00 0.99 0.00 0.38 0.14]\n",
999 | " [0.05 0.00 0.98 0.00 0.19 0.94]\n",
1000 | " [0.03 0.10 0.00 1.00 0.24 0.03]\n",
1001 | " [0.00 0.00 0.00 0.00 0.10 0.03]\n",
1002 | " [0.05 0.00 0.09 0.98 0.00 0.23]\n",
1003 | " [0.03 0.01 0.33 0.04 0.66 0.00]]\n"
1004 | ],
1005 | "name": "stdout"
1006 | },
1007 | {
1008 | "output_type": "stream",
1009 | "text": [
1010 | "100%|██████████| 11000/11000 [05:21<00:00, 34.18it/s, disc=0.666, gen=6.01, regul_loss=0.412, tot=41.8]\n",
1011 | " 0%| | 4/11000 [00:00<04:54, 37.32it/s, disc=1.46, gen=0.887, regul_loss=1.48, tot=6.8]"
1012 | ],
1013 | "name": "stderr"
1014 | },
1015 | {
1016 | "output_type": "stream",
1017 | "text": [
1018 | "[[4.51]]\n",
1019 | "[[0.00 0.96 0.96 0.00 0.33 0.02]\n",
1020 | " [0.05 0.00 0.14 0.00 0.00 0.35]\n",
1021 | " [0.00 0.94 0.00 0.99 0.99 0.98]\n",
1022 | " [0.02 0.00 0.00 0.00 0.10 0.00]\n",
1023 | " [0.00 0.00 0.02 0.99 0.00 0.97]\n",
1024 | " [0.01 0.00 0.01 0.00 0.12 0.00]]\n"
1025 | ],
1026 | "name": "stdout"
1027 | },
1028 | {
1029 | "output_type": "stream",
1030 | "text": [
1031 | "100%|██████████| 11000/11000 [05:23<00:00, 34.01it/s, disc=0.409, gen=5.87, regul_loss=0.365, tot=39.8]\n",
1032 | " 0%| | 4/11000 [00:00<04:48, 38.08it/s, disc=1.41, gen=0.776, regul_loss=1.53, tot=6.19]"
1033 | ],
1034 | "name": "stderr"
1035 | },
1036 | {
1037 | "output_type": "stream",
1038 | "text": [
1039 | "[[5.35]]\n",
1040 | "[[0.00 1.00 0.15 0.02 0.00 0.00]\n",
1041 | " [0.02 0.00 1.00 0.00 0.00 0.00]\n",
1042 | " [0.01 0.07 0.00 1.00 1.00 0.02]\n",
1043 | " [0.04 0.00 0.00 0.00 0.09 0.00]\n",
1044 | " [0.00 0.00 0.01 1.00 0.00 0.99]\n",
1045 | " [0.09 0.03 0.04 0.01 0.14 0.00]]\n"
1046 | ],
1047 | "name": "stdout"
1048 | },
1049 | {
1050 | "output_type": "stream",
1051 | "text": [
1052 | "100%|██████████| 11000/11000 [05:23<00:00, 34.01it/s, disc=0.597, gen=4.92, regul_loss=0.413, tot=32.8]"
1053 | ],
1054 | "name": "stderr"
1055 | },
1056 | {
1057 | "output_type": "stream",
1058 | "text": [
1059 | "[[4.98]]\n",
1060 | "[[0.00 0.97 0.01 0.00 0.00 0.00]\n",
1061 | " [0.06 0.00 0.06 0.00 0.00 0.00]\n",
1062 | " [0.73 0.99 0.00 0.99 1.00 0.03]\n",
1063 | " [0.92 0.00 0.00 0.00 0.08 0.00]\n",
1064 | " [0.18 0.00 0.00 0.98 0.00 1.00]\n",
1065 | " [0.19 0.00 0.00 0.00 0.09 0.00]]\n",
1066 | "[[0.00 0.81 0.62 0.01 0.14 0.03]\n",
1067 | " [0.11 0.00 0.63 0.00 0.20 0.26]\n",
1068 | " [0.16 0.42 0.00 1.00 0.85 0.34]\n",
1069 | " [0.20 0.00 0.00 0.00 0.09 0.01]\n",
1070 | " [0.05 0.00 0.03 0.99 0.00 0.84]\n",
1071 | " [0.06 0.01 0.08 0.13 0.25 0.00]]\n"
1072 | ],
1073 | "name": "stdout"
1074 | },
1075 | {
1076 | "output_type": "stream",
1077 | "text": [
1078 | "\n"
1079 | ],
1080 | "name": "stderr"
1081 | }
1082 | ]
1083 | },
1084 | {
1085 | "cell_type": "markdown",
1086 | "metadata": {
1087 | "colab_type": "text",
1088 | "id": "MGNG7pzi8LI6"
1089 | },
1090 | "source": [
1091 | "End of notebook."
1092 | ]
1093 | },
1094 | {
1095 | "cell_type": "code",
1096 | "metadata": {
1097 | "id": "S9LudNsLxfkd",
1098 | "colab_type": "code",
1099 | "colab": {}
1100 | },
1101 | "source": [
1102 | ""
1103 | ],
1104 | "execution_count": 0,
1105 | "outputs": []
1106 | }
1107 | ]
1108 | }
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Yutaro Ogawa
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
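1 | ## つくりながら学ぶ! Python による因果分析 ~因果推論・因果探索の実践入門 (Learn by Building! Causal Analysis with Python: A Practical Introduction to Causal Inference and Causal Discovery)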
2 |
3 |
4 |

5 |
6 |
7 |
8 |
9 | This is the support repository for the following book:
10 |
11 | [「つくりながら学ぶ! Python による因果分析 ~因果推論・因果探索の実践入門」 (Yutaro Ogawa, Mynavi Publishing, June 30, 2020)](https://www.amazon.co.jp/dp/4839973571/)
12 |
13 |
14 |
15 | ### 1. What the book covers
16 |
17 | The following article gives an overview of the book.
18 |
19 | ["Causal Inference and Causal Discovery with Python (for Beginners)"](https://qiita.com/sugulu/items/2cffb239b44853b07f70)
20 |
21 |
22 |
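23 | **Table of contents**
24 |
25 | - Chapter 1: Understanding the difference between correlation and causation
26 | - Chapter 2: Understanding the types of causal effects
27 | - Chapter 3: Understanding graph representations and the backdoor criterion
28 | - Chapter 4: Implementing causal inference
29 |   - 4-1 Implementing causal inference with regression analysis
30 |   - 4-2 Implementing inverse probability of treatment weighting (IPTW) with propensity scores
31 |   - 4-3 Implementing causal inference with the Doubly Robust (DR) method
32 | - Chapter 5: Causal inference with machine learning
33 |   - 5-1 How random forests perform classification and regression
34 |   - 5-2 Implementing Meta-Learners (T-Learner, S-Learner, X-Learner)
35 |   - 5-3 Implementing Doubly Robust Learning
36 | - Chapter 6: Implementing LiNGAM
37 |   - 6-1 What is LiNGAM (Linear Non-Gaussian Acyclic Model)?
38 |   - 6-2 What is independent component analysis?
39 |   - 6-3 Implementing causal discovery with LiNGAM
40 | - Chapter 7: Implementing Bayesian networks
41 |   - 7-1 What is a Bayesian network?
42 |   - 7-2 How to measure how well a network fits the data
43 |   - 7-3 Testing independence between variables
44 |   - 7-4 Three types of search methods for Bayesian networks
45 |   - 7-5 Implementing Bayesian network search with the PC algorithm
46 | - Chapter 8: Causal discovery with deep learning
47 |   - 8-1 The relationship between causal discovery and GANs (Generative Adversarial Networks)
48 |   - 8-2 Overview of SAM (Structural Agnostic Model)
49 |   - 8-3 Implementing SAM's discriminator D and generator G
50 |   - 8-4 SAM's loss function explained and causal discovery implemented
51 |   - 8-5 Running causal discovery on a GPU in Google Colaboratory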
52 |
53 |
54 |
55 | ### 2. Questions and corrections are tracked in Issues
56 |
57 | Questions and corrections are tracked in this repository's GitHub Issues.
58 |
59 | If anything is unclear, please look there first.
60 |
61 | https://github.com/YutaroOgawa/causal_book/issues
62 |
63 |
64 |
65 | ### 3. Errata
66 |
67 | The list of typos found in the book is available at the link below.
68 | We sincerely apologize for these errors.
69 |
70 | [Errata list](https://github.com/YutaroOgawa/causal_book/labels/%E8%AA%A4%E6%A4%8D)
71 |
--------------------------------------------------------------------------------
/etc/book.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YutaroOgawa/causal_book/bde8891125b8c85bb8cf521c8cba0ec08226b2b9/etc/book.jpg
--------------------------------------------------------------------------------