├── .gitignore
├── Colaboratory_へようこそ.ipynb
├── LICENSE
├── README.md
├── images
│   ├── amazon.jpeg
│   ├── fig-01-04.png
│   ├── fig-02-05.png
│   ├── fig-02-06.png
│   ├── fig-05-04-13-ng.pdf
│   ├── fig-05-04-13-ng.png
│   ├── fig-05-04-13-ok.pdf
│   ├── fig-05-04-13-ok.png
│   ├── fig-06-01-10.png
│   ├── fig-06-03-12-after.png
│   ├── fig-06-03-12-before.png
│   ├── kinokunuya.jpeg
│   ├── model01-01.png
│   ├── pdf-04-04-06.png
│   ├── seikyo.jpeg
│   ├── step1.png
│   ├── step2.png
│   ├── step3.png
│   ├── step4.png
│   ├── table-1-6.png
│   └── 表紙-v1.png
├── notebooks
│   ├── 1章_確率分布とは.ipynb
│   ├── 2章_よく利用される確率分布.ipynb
│   ├── 3章_ベイス推論とは.ipynb
│   ├── 4章_はじめてのベイズ推論実習.ipynb
│   ├── 5_1_データ分布のベイズ推論.ipynb
│   ├── 5_2_線形回帰のベイズ推論.ipynb
│   ├── 5_3_階層ベイズモデル.ipynb
│   ├── 5_4_潜在変数モデル.ipynb
│   ├── 6_1_ABテスト効果検証.ipynb
│   ├── 6_2_ベイス回帰モデルによる効果検証.ipynb
│   └── 6_3_IRTによるテスト結果評価.ipynb
├── refs
│   ├── 3クラス潜在変数モデル.pdf
│   ├── errors.md
│   ├── faqs.md
│   ├── how-to-run.md
│   └── 目次.md
└── sample-notebooks
    ├── 4章_図4_4用.ipynb
    ├── 5_4_潜在変数モデル_簡略版.ipynb
    ├── 6_3_IRTによるテスト結果評価_GPU版.ipynb
    ├── A_3クラス潜在変数モデル.ipynb
    ├── FAQ_潜在変数モデル.ipynb
    └── 書籍評価.ipynb
/.gitignore:
--------------------------------------------------------------------------------
1 | local
2 | .DS_Store
3 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Support site for the book "Pythonでスラスラわかる ベイズ推論「超」入門"
2 |
3 |
4 |

5 |
6 |
7 |
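8 | This is the support site for the book **Pythonでスラスラわかる ベイズ推論「超」入門** (roughly, "A Super-Introduction to Bayesian Inference, Made Easy with Python").
Understanding Bayesian inference requires mathematical background such as probability distributions. With this book, even readers whose mathematical background is limited can come to understand Bayesian inference through hands-on Python exercises.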
9 |
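10 | ## Announcements
11 | ### The lecture video is available on YouTube
12 | The talk given at the "Bayesian Inference" web seminar hosted by データサイエンティスト協会 (the Japan DataScientist Society) on Monday, March 25, 2024 is available at the link below.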
13 | https://www.youtube.com/watch?v=ex119MhbyCk
14 |
15 | ### The book has been very well received
16 |
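17 | |Category|Rank|Photo|
18 | |---|---|---|
19 | |Amazon, by category (Informatics and Information Science, general)|1st||
20 | |Kinokuniya Shinjuku Main Store, weekly IT book ranking|1st||
21 | |University of Tokyo Co-op, Hongo Book Department, bestsellers|4th||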
22 |
23 |
24 |
25 | ## Amazon links
26 |
27 | [Print edition](https://www.amazon.co.jp/dp/4065337631)
28 |
29 | [Kindle edition](https://www.amazon.co.jp/dp/B0CP26YYRJ)
30 |
31 |
32 |
33 |
34 | ## Exercise notebook links
35 | All of the book's exercise code is published on GitHub (this support site) and is written to run on Google Colab.
36 |
37 | [List of exercise notebooks](https://github.com/makaishi2/python_bayes_intro/tree/main/notebooks)
38 |
39 | [How to run the exercise notebooks](refs/how-to-run.md)
40 |
41 | [Reference notebooks](https://github.com/makaishi2/python_bayes_intro/tree/main/sample-notebooks)
42 |
43 |
44 |
45 |
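46 | ## Features of the book
47 |
48 | * You learn the basics of probability distributions, the mathematical concept essential to building models in Bayesian inference, by comparing them with the object-oriented programming model.
49 |
50 | * You learn how to use PyMC and ArviZ step by step.
51 |
52 | * You grasp the Bayesian way of thinking through a simple example: inferring the probability of winning a lottery from the results of five draws.
53 |
54 | * Through hands-on programs, you come to understand how a range of Bayesian inference problems work, from the simple task of inferring the mean and standard deviation of a normal distribution to the advanced topic of latent variable models.
55 |
56 | * You learn business-oriented applications of Bayesian inference, such as A/B testing and effect verification with linear regression models.
57 |
58 | * The columns at the end of each chapter and section cover somewhat advanced but useful concepts and techniques, such as the difference between prior and posterior distributions, the difference between HDI and CI, tuning with `target_accept`, and using variational inference.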
59 |
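As a rough illustration of the lottery example in the list above, the following sketch approximates the posterior of the win probability on a grid using only the standard library. The observed result (2 wins in 5 draws), the uniform prior, and the function name `grid_posterior` are assumptions made for illustration; the book's notebooks use PyMC rather than a hand-written grid.

```python
# Grid approximation of the posterior for a lottery's win probability.
# Assumptions (not from the book): 2 wins in 5 draws, uniform prior on p.
from math import comb


def grid_posterior(wins, draws, grid_size=1001):
    """Return (grid, posterior) for the win probability p under a uniform prior."""
    grid = [i / (grid_size - 1) for i in range(grid_size)]
    # Binomial likelihood of the observed result at each candidate p
    likelihood = [
        comb(draws, wins) * p**wins * (1 - p) ** (draws - wins) for p in grid
    ]
    total = sum(likelihood)
    # With a uniform prior, the posterior is the normalized likelihood
    posterior = [value / total for value in likelihood]
    return grid, posterior


grid, posterior = grid_posterior(wins=2, draws=5)
posterior_mean = sum(p * w for p, w in zip(grid, posterior))
posterior_mode = grid[posterior.index(max(posterior))]
print(f"posterior mean = {posterior_mean:.3f}, posterior mode = {posterior_mode:.3f}")
```

With these assumed data the exact posterior is a Beta(3, 4) distribution, whose mean is 3/7 and whose mode is 2/5, so the grid result can be checked against those values.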
60 |
61 |
62 | ## Intended readers
63 |
64 |
65 |
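66 | This book is aimed mainly at readers who **have already mastered ordinary machine learning** with libraries such as scikit-learn and **want to learn Bayesian inference as their next step**.
Because some parts explain probability and other mathematics by analogy with programming, **a certain level of Python programming skill is assumed**. The specific prerequisites are as follows.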
67 |
68 |
69 |
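70 | * Basic knowledge of Python syntax
71 |
72 |   - Basic types such as integers, floating-point numbers, and booleans
73 |
74 |   - Optional arguments in function calls
75 |
76 |   - Basic object-oriented programming concepts (classes, instances, constructors)
77 |
78 |
79 | * Basic operations in NumPy, pandas, Matplotlib, and seaborn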
80 |
81 |
82 |
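83 | As for mathematics, I have tried as far as possible to make the book **readable with roughly first-year high school mathematics**. Formulas do appear in places, such as the explanations of probability distributions, but I have tried to arrange things so that you can skip them and keep reading.
On the other hand, the book places great weight on the correspondence between Python code and mathematical concepts, so I strongly recommend that you **read it while running the exercise code on Google Colab**, the environment the book assumes.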
84 |
85 |
86 |
87 |
88 | ## Table of contents
89 |
90 | [Link to the table of contents](refs/目次.md)
91 |
92 |
93 |
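94 | ## Highlights of the book
95 |
96 | #### Mathematical concepts around probability distributions are explained by comparison with the object-oriented programming model
97 |
98 | The mathematics of probability distributions is the biggest hurdle in understanding Bayesian inference models. Comparing it with the object-oriented programming model makes it easier to form a concrete picture.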
99 |
100 |
101 |
102 |
103 |
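104 | #### Probabilistic model definitions are explained in detail using visualizations
105 |
106 | The probabilistic model definition is the core of Bayesian inference development. Every example comes with a visualization of the model, and the explanation is built around that figure.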
107 |
108 |
109 |
110 | #### The steps of Bayesian inference, explained clearly
111 |
112 | The four steps of Bayesian inference are explained clearly, including how each one relates to the libraries used.
113 |
114 |
115 |
116 |
117 |
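118 | #### How Bayesian inference works, explained by contrast with maximum likelihood estimation
119 |
120 | How Bayesian inference works is explained by contrasting it with maximum likelihood estimation, a different statistical analysis method, to show clearly what differs and how.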
121 |
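The contrast can be sketched in a few lines for a single success probability. The data (2 successes in 5 trials, then 400 in 1000), the Beta(1, 1) prior, and the helper names below are assumptions made for illustration and are not taken from the book. Maximum likelihood returns one point estimate, whereas Bayesian inference returns a distribution from which summaries such as the mean and standard deviation are derived.

```python
# Maximum likelihood vs. Bayesian inference for a success probability p.
# Assumptions (not from the book): the data below and a Beta(1, 1) prior.
from math import sqrt


def mle(successes, trials):
    """Maximum likelihood estimate: the p that maximizes the binomial likelihood."""
    return successes / trials


def beta_posterior(successes, trials, a=1.0, b=1.0):
    """Conjugate update: Beta(a, b) prior -> Beta(a + k, b + n - k) posterior."""
    return a + successes, b + (trials - successes)


def beta_mean_sd(a, b):
    """Mean and standard deviation of a Beta(a, b) distribution."""
    mean = a / (a + b)
    sd = sqrt(a * b / ((a + b) ** 2 * (a + b + 1)))
    return mean, sd


for k, n in [(2, 5), (400, 1000)]:
    a_post, b_post = beta_posterior(k, n)
    mean, sd = beta_mean_sd(a_post, b_post)
    print(f"n={n}: MLE={mle(k, n):.3f}, posterior mean={mean:.3f}, posterior sd={sd:.3f}")
```

Both data sets have the same success ratio, so the maximum likelihood estimate is identical; the posterior standard deviation shrinks as the number of trials grows, which is information a point estimate alone does not carry.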
122 |
123 |
124 |
125 |
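126 | #### Latent variable models
127 |
128 | The examples also include latent variable models, one of the most interesting yet most difficult topics in Bayesian inference. A column explains the key points to watch when developing a latent variable model on real data.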
129 |
130 |
131 |
132 |
133 |
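134 | #### Real-world business use cases
135 |
136 | Chapter 6 presents use cases that can be applied to work right away. Reading this chapter should give you a picture of how Bayesian inference is used in real business settings.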
137 |
138 |
139 |
140 |
141 |
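142 | ## Other explanatory articles
143 |
144 | |Source|Title and link|Notes|
145 | |---|---|---|
146 | |Qiita|[Amazon review analysis](https://qiita.com/makaishi2/items/074c803de4368ef7874f)|Amazon reviews carry scores from 5 down to 1, so Bayesian inference of the average score makes a good exercise for the multinomial distribution. This article applies the idea to the scores of the author's own seven books.|
147 | |This support site|[Three-class latent variable model](refs/3クラス潜在変数モデル.pdf)|The latent variable model of Section 5.4 can be extended to three classes. This document explains how.|
148 | |Qiita|[Supplement on latent variable models](https://qiita.com/makaishi2/items/8dae04e51a79ae456995)|A detailed explanation of "label switching" in the latent variable model of Section 5.4.|
149 | |Qiita|[Simplified latent variable model](https://qiita.com/makaishi2/items/2ce59c2562537b92f383)|The latent variable model of Section 5.4 is a very complex probabilistic model that performs several inference tasks at once. A simplified version that fixes the category ratio p turned out to be a good way to understand it, and this article explains that version in detail.|
150 | |Qiita|[Running PyMC5 on a GPU in Colab](https://qiita.com/makaishi2/items/6247e80006b341216df8)|Among the book's exercises, the one in Section 6.3 takes especially long to sample. Colab's GPU turned out to be a usable remedy, and this article explains the procedure.|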
151 |
152 |
153 |
154 |
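155 | ## Links
156 |
157 |
158 |
159 | ### Information from the author
160 |
161 | |Source|Title and link|Notes|
162 | |---|---|---|
163 | |Twitter|[@makaishi2](https://twitter.com/makaishi2)|The author's Twitter account, mainly book-related announcements and retweets.|
164 | |Invited talk slides, 異業種データサイエンス研究会|[For those who want to study the AI and data science fields](https://speakerdeck.com/makaishi2/aidsling-yu-woxue-xi-sitaifang-nixiang-kete)|Slides from an invited talk at an event hosted by 異業種データサイエンス研究会 (a cross-industry data science study group) on 2022-12-03.|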
165 |
166 |
167 | ### External links
168 |
169 |
170 | |Source|Title and link|Notes|
171 | |---|---|---|
172 | |Amazon|[Amazon reviews](https://www.amazon.co.jp/product-reviews/4065337631/)||
173 | |Togetter (curation site)|https://togetter.com/li/2274456 |A collection of comments on the book by the well-known blogger kenken.|
174 |
175 |
176 |
177 | ***
178 |
179 |
180 | ## Errata and FAQ
181 |
182 |
183 | * [Errata](refs/errors.md)
184 |
185 | * [FAQ](refs/faqs.md)
186 |
187 |
--------------------------------------------------------------------------------
/images/amazon.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/makaishi2/python_bayes_intro/1643281f7ab360443a66e3be943a7d562bf4dc9e/images/amazon.jpeg
--------------------------------------------------------------------------------
/images/fig-01-04.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/makaishi2/python_bayes_intro/1643281f7ab360443a66e3be943a7d562bf4dc9e/images/fig-01-04.png
--------------------------------------------------------------------------------
/images/fig-02-05.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/makaishi2/python_bayes_intro/1643281f7ab360443a66e3be943a7d562bf4dc9e/images/fig-02-05.png
--------------------------------------------------------------------------------
/images/fig-02-06.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/makaishi2/python_bayes_intro/1643281f7ab360443a66e3be943a7d562bf4dc9e/images/fig-02-06.png
--------------------------------------------------------------------------------
/images/fig-05-04-13-ng.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/makaishi2/python_bayes_intro/1643281f7ab360443a66e3be943a7d562bf4dc9e/images/fig-05-04-13-ng.pdf
--------------------------------------------------------------------------------
/images/fig-05-04-13-ng.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/makaishi2/python_bayes_intro/1643281f7ab360443a66e3be943a7d562bf4dc9e/images/fig-05-04-13-ng.png
--------------------------------------------------------------------------------
/images/fig-05-04-13-ok.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/makaishi2/python_bayes_intro/1643281f7ab360443a66e3be943a7d562bf4dc9e/images/fig-05-04-13-ok.pdf
--------------------------------------------------------------------------------
/images/fig-05-04-13-ok.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/makaishi2/python_bayes_intro/1643281f7ab360443a66e3be943a7d562bf4dc9e/images/fig-05-04-13-ok.png
--------------------------------------------------------------------------------
/images/fig-06-01-10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/makaishi2/python_bayes_intro/1643281f7ab360443a66e3be943a7d562bf4dc9e/images/fig-06-01-10.png
--------------------------------------------------------------------------------
/images/fig-06-03-12-after.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/makaishi2/python_bayes_intro/1643281f7ab360443a66e3be943a7d562bf4dc9e/images/fig-06-03-12-after.png
--------------------------------------------------------------------------------
/images/fig-06-03-12-before.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/makaishi2/python_bayes_intro/1643281f7ab360443a66e3be943a7d562bf4dc9e/images/fig-06-03-12-before.png
--------------------------------------------------------------------------------
/images/kinokunuya.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/makaishi2/python_bayes_intro/1643281f7ab360443a66e3be943a7d562bf4dc9e/images/kinokunuya.jpeg
--------------------------------------------------------------------------------
/images/model01-01.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/makaishi2/python_bayes_intro/1643281f7ab360443a66e3be943a7d562bf4dc9e/images/model01-01.png
--------------------------------------------------------------------------------
/images/pdf-04-04-06.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/makaishi2/python_bayes_intro/1643281f7ab360443a66e3be943a7d562bf4dc9e/images/pdf-04-04-06.png
--------------------------------------------------------------------------------
/images/seikyo.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/makaishi2/python_bayes_intro/1643281f7ab360443a66e3be943a7d562bf4dc9e/images/seikyo.jpeg
--------------------------------------------------------------------------------
/images/step1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/makaishi2/python_bayes_intro/1643281f7ab360443a66e3be943a7d562bf4dc9e/images/step1.png
--------------------------------------------------------------------------------
/images/step2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/makaishi2/python_bayes_intro/1643281f7ab360443a66e3be943a7d562bf4dc9e/images/step2.png
--------------------------------------------------------------------------------
/images/step3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/makaishi2/python_bayes_intro/1643281f7ab360443a66e3be943a7d562bf4dc9e/images/step3.png
--------------------------------------------------------------------------------
/images/step4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/makaishi2/python_bayes_intro/1643281f7ab360443a66e3be943a7d562bf4dc9e/images/step4.png
--------------------------------------------------------------------------------
/images/table-1-6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/makaishi2/python_bayes_intro/1643281f7ab360443a66e3be943a7d562bf4dc9e/images/table-1-6.png
--------------------------------------------------------------------------------
/images/表紙-v1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/makaishi2/python_bayes_intro/1643281f7ab360443a66e3be943a7d562bf4dc9e/images/表紙-v1.png
--------------------------------------------------------------------------------
/notebooks/1章_確率分布とは.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "id": "qSPzKMVJyI1Q"
7 | },
8 | "source": [
9 | "## Chapter 1: Probability Distributions"
10 | ]
11 | },
12 | {
13 | "cell_type": "markdown",
14 | "metadata": {
15 | "id": "k78Unf4mzxD6"
16 | },
17 | "source": [
18 | ""
19 | ]
20 | },
21 | {
22 | "cell_type": "markdown",
23 | "metadata": {
24 | "id": "GnnBseyayPhn"
25 | },
26 | "source": [
27 | "### Common setup"
28 | ]
29 | },
30 | {
31 | "cell_type": "code",
32 | "execution_count": null,
33 | "metadata": {
34 | "id": "if-gXKFKx5c9"
35 | },
36 | "outputs": [],
37 | "source": [
38 | "%matplotlib inline\n",
39 | "# Install the library that enables Japanese text in Matplotlib\n",
40 | "!pip install japanize-matplotlib | tail -n 1"
41 | ]
42 | },
43 | {
44 | "cell_type": "code",
45 | "execution_count": null,
46 | "metadata": {
47 | "id": "zS2KvH3Kx-zQ"
48 | },
49 | "outputs": [],
50 | "source": [
51 | "# Import libraries\n",
52 | "\n",
53 | "# NumPy\n",
54 | "import numpy as np\n",
55 | "# pyplot from Matplotlib\n",
56 | "import matplotlib.pyplot as plt\n",
57 | "# Japanese text support for Matplotlib\n",
58 | "import japanize_matplotlib\n",
59 | "# pandas\n",
60 | "import pandas as pd\n",
61 | "# Function for displaying DataFrames\n",
62 | "from IPython.display import display\n",
63 | "# seaborn\n",
64 | "import seaborn as sns\n",
65 | "# Adjust display options\n",
66 | "# NumPy print format\n",
67 | "np.set_printoptions(precision=3, floatmode='fixed')\n",
68 | "# Default font size for plots\n",
69 | "plt.rcParams[\"font.size\"] = 14\n",
70 | "# Default figure size\n",
71 | "plt.rcParams['figure.figsize'] = (6, 6)\n",
72 | "# Show grid lines\n",
73 | "plt.rcParams['axes.grid'] = True\n",
74 | "# Float display precision in DataFrames\n",
75 | "pd.options.display.float_format = '{:.3f}'.format\n",
76 | "# Show all columns of a DataFrame\n",
77 | "pd.set_option(\"display.max_columns\",None)"
78 | ]
79 | },
80 | {
81 | "cell_type": "markdown",
82 | "metadata": {
83 | "id": "wTy9T_N0Yz8b"
84 | },
85 | "source": [
86 | "### 1.3 Discrete and continuous distributions"
87 | ]
88 | },
89 | {
90 | "cell_type": "markdown",
91 | "metadata": {
92 | "id": "kMDHDLNGFVJ5"
93 | },
94 | "source": [
95 | "#### Program to visualize the probability distribution of the lottery problem"
96 | ]
97 | },
98 | {
99 | "cell_type": "code",
100 | "execution_count": null,
101 | "metadata": {
102 | "id": "vqA9_i3NUaMj"
103 | },
104 | "outputs": [],
105 | "source": [
106 | "from scipy.special import comb\n",
107 | "n = 5\n",
108 | "x = range(n+1)\n",
109 | "y = [comb(n, i)/2**n for i in x]\n",
110 | "plt.bar(x, y)\n",
111 | "plt.title('Probability distribution of the lottery problem');"
112 | ]
113 | },
114 | {
115 | "cell_type": "markdown",
116 | "metadata": {
117 | "id": "mEFPmS4AFkrI"
118 | },
119 | "source": [
120 | "#### Program to visualize the probability distribution of the lottery problem (n=1000)"
121 | ]
122 | },
123 | {
124 | "cell_type": "code",
125 | "execution_count": null,
126 | "metadata": {
127 | "id": "RGPPAx6kmywq"
128 | },
129 | "outputs": [],
130 | "source": [
131 | "from scipy.special import comb\n",
132 | "n = 1000\n",
133 | "x = range(n+1)\n",
134 | "y = [comb(n, i)/2**n for i in x]\n",
135 | "plt.bar(x, y)\n",
136 | "plt.xlim((430,570))\n",
137 | "plt.title('Probability distribution of the lottery problem (n=1000)');"
138 | ]
139 | },
140 | {
141 | "cell_type": "markdown",
142 | "metadata": {
143 | "id": "5s48or3cFrqI"
144 | },
145 | "source": [
146 | "#### Overlaying the probability distribution and the normal distribution function"
147 | ]
148 | },
149 | {
150 | "cell_type": "code",
151 | "execution_count": null,
152 | "metadata": {
153 | "id": "ZZDtKNGkoMKA"
154 | },
155 | "outputs": [],
156 | "source": [
157 | "# Define the normal distribution function (probability density)\n",
158 | "def norm(x, mu, sigma):\n",
159 | "    return np.exp(-((x - mu)/sigma)**2/2) / (np.sqrt(2 * np.pi) * sigma)\n",
160 | "\n",
161 | "n = 1000\n",
162 | "\n",
163 | "# Plot setup\n",
164 | "plt.xlim((430,570))\n",
165 | "x = np.arange(430, 571)\n",
166 | "\n",
167 | "# Plot the probability distribution\n",
168 | "y1 = [comb(n, i)/2**n for i in x]\n",
169 | "plt.bar(x, y1)\n",
170 | "\n",
171 | "# Plot the normal density; for p = 1/2 the binomial has mean n/2 and variance n/4\n",
172 | "mu = n/2\n",
173 | "sigma = np.sqrt(mu/2)\n",
174 | "y2 = norm(x, mu, sigma)\n",
175 | "plt.plot(x, y2, c='k')\n",
176 | "\n",
177 | "plt.title('Probability distribution overlaid with the normal distribution function');"
178 | ]
179 | },
180 | {
181 | "cell_type": "markdown",
182 | "metadata": {
183 | "id": "jVV7J1e_GFwf"
184 | },
185 | "source": [
186 | "#### Relationship between the normal distribution function and probability"
187 | ]
188 | },
189 | {
190 | "cell_type": "code",
191 | "execution_count": null,
192 | "metadata": {
193 | "id": "k36bvH-Kr8kn"
194 | },
195 | "outputs": [],
196 | "source": [
197 | "n = 1000\n",
198 | "\n",
199 | "# Plot setup\n",
200 | "plt.xlim((430,570))\n",
201 | "x = np.arange(430, 571)\n",
202 | "x1 = 460\n",
203 | "x2 = 480\n",
204 | "x_range = np.arange(x1, x2+1)\n",
205 | "\n",
206 | "# Normal distribution function; the shaded area between x1 and x2 represents a probability\n",
207 | "mu = n/2\n",
208 | "sigma = np.sqrt(mu/2)\n",
209 | "plt.plot(x, norm(x, mu, sigma), c='k')\n",
210 | "plt.fill_between(x_range, 0, norm(x_range, mu, sigma), facecolor='b', alpha=0.2)\n",
211 | "plt.plot((x1, x1), (0, norm(x1, mu, sigma)), c='b')\n",
212 | "plt.plot((x2, x2), (0, norm(x2, mu, sigma)), c='b')\n",
213 | "plt.title('Relationship between the normal distribution function and probability');"
214 | ]
215 | },
216 | {
217 | "cell_type": "markdown",
218 | "metadata": {
219 | "id": "4Ancrb2l6EBU"
220 | },
221 | "source": [
222 | "### 1.4 PyMC programming"
223 | ]
224 | },
225 | {
226 | "cell_type": "markdown",
227 | "metadata": {
228 | "id": "JGsKiNojRCXl"
229 | },
230 | "source": [
231 | "#### Import statements"
232 | ]
233 | },
234 | {
235 | "cell_type": "code",
236 | "execution_count": null,
237 | "metadata": {
238 | "id": "iXT5wRuQyw3w"
239 | },
240 | "outputs": [],
241 | "source": [
242 | "import pymc as pm\n",
243 | "import arviz as az\n",
244 | "\n",
245 | "print(f\"Running on PyMC v{pm.__version__}\")\n",
246 | "print(f\"Running on ArviZ v{az.__version__}\")"
247 | ]
248 | },
249 | {
250 | "cell_type": "markdown",
251 | "metadata": {
252 | "id": "lHJ7uJVuRLac"
253 | },
254 | "source": [
255 | "#### Defining the probabilistic model"
256 | ]
257 | },
258 | {
259 | "cell_type": "code",
260 | "execution_count": null,
261 | "metadata": {
262 | "id": "FwlXwkox9v-k"
263 | },
264 | "outputs": [],
265 | "source": [
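266 | "# Create the instance that serves as the probabilistic model context\n",
267 | "model = pm.Model()\n",
268 | "\n",
269 | "# Use that instance in a with statement to set the context\n",
270 | "with model:\n",
271 | "    # Binomial: probability distribution class for the binomial distribution\n",
272 | "    # p: success probability of each underlying trial\n",
273 | "    # n: number of trials\n",
274 | "    # 'x': label used to refer to the sample values of the random variable x\n",
275 | "    x = pm.Binomial('x', p=0.5, n=5)"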
276 | ]
277 | },
278 | {
279 | "cell_type": "markdown",
280 | "metadata": {
281 | "id": "_jZHMFWhRRJn"
282 | },
283 | "source": [
284 | "#### Running the sampling"
285 | ]
286 | },
287 | {
288 | "cell_type": "code",
289 | "execution_count": null,
290 | "metadata": {
291 | "id": "-wHEgO4mNr31"
292 | },
293 | "outputs": [],
294 | "source": [
295 | "# The with model context ties this call to\n",
296 | "# the probabilistic model defined above\n",
297 | "# sample_prior_predictive: function that draws samples from the prior predictive distribution\n",
298 | "# The randomly generated sample values are stored in prior_samples\n",
299 | "with model:\n",
300 | "    prior_samples = pm.sample_prior_predictive(random_seed=42)"
301 | ]
302 | },
303 | {
304 | "cell_type": "markdown",
305 | "metadata": {
306 | "id": "8w0TiB2eRWI5"
307 | },
308 | "source": [
309 | "### 1.5 Analyzing the sample values"
310 | ]
311 | },
312 | {
313 | "cell_type": "markdown",
314 | "metadata": {
315 | "id": "ynpfI4uAvleZ"
316 | },
317 | "source": [
318 | "#### Using the notebook UI"
319 | ]
320 | },
321 | {
322 | "cell_type": "code",
323 | "execution_count": null,
324 | "metadata": {
325 | "id": "HRdnOGBQ8tD4"
326 | },
327 | "outputs": [],
328 | "source": [
329 | "prior_samples"
330 | ]
331 | },
332 | {
333 | "cell_type": "markdown",
334 | "metadata": {
335 | "id": "LZFPv7JAvuNe"
336 | },
337 | "source": [
338 | "#### Extracting the data as NumPy arrays for analysis"
339 | ]
340 | },
341 | {
342 | "cell_type": "code",
343 | "execution_count": null,
344 | "metadata": {
345 | "id": "wTQH6mM2N2ZQ"
346 | },
347 | "outputs": [],
348 | "source": [
349 | "# Get the sampled values from the prior group\n",
350 | "x_samples = prior_samples['prior']['x'].values\n",
351 | "print('type: ', type(x_samples))\n",
352 | "print('shape: ', x_samples.shape)\n",
353 | "print('values: ', x_samples, '\\n')\n",
354 | "\n",
355 | "# Frequency count, treating the samples as ordinary NumPy data\n",
356 | "value_counts = pd.DataFrame(\n",
357 | "    x_samples.reshape(-1)).value_counts().sort_index()\n",
358 | "print(value_counts)"
359 | ]
360 | },
361 | {
362 | "cell_type": "markdown",
363 | "metadata": {
364 | "id": "EVt5HAdSvzLs"
365 | },
366 | "source": [
367 | "#### Analysis with ArviZ"
368 | ]
369 | },
370 | {
371 | "cell_type": "markdown",
372 | "metadata": {
373 | "id": "ay1ljOEvHpVe"
374 | },
375 | "source": [
376 | "##### Summarizing the samples with ArviZ"
377 | ]
378 | },
379 | {
380 | "cell_type": "code",
381 | "execution_count": null,
382 | "metadata": {
383 | "id": "OHu8odHkwPfq"
384 | },
385 | "outputs": [],
386 | "source": [
387 | "summary = az.summary(prior_samples, kind='stats')\n",
388 | "display(summary)"
389 | ]
390 | },
391 | {
392 | "cell_type": "markdown",
393 | "metadata": {
394 | "id": "FIK3QIh_Hv5w"
395 | },
396 | "source": [
397 | "##### Visualizing the samples with ArviZ"
398 | ]
399 | },
400 | {
401 | "cell_type": "code",
402 | "execution_count": null,
403 | "metadata": {
404 | "id": "721FFuePfYT9"
405 | },
406 | "outputs": [],
407 | "source": [
408 | "ax = az.plot_dist(x_samples)\n",
409 | "ax.set_title('Sampled values visualized with ArviZ');"
410 | ]
411 | },
412 | {
413 | "cell_type": "markdown",
414 | "metadata": {
415 | "id": "2oUa1cRnsU_q"
416 | },
417 | "source": [
418 | "### Checking versions"
419 | ]
420 | },
421 | {
422 | "cell_type": "code",
423 | "execution_count": null,
424 | "metadata": {
425 | "id": "HXPitM6GS8ti"
426 | },
427 | "outputs": [],
428 | "source": [
429 | "!pip install watermark | tail -n 1\n",
430 | "%load_ext watermark\n",
431 | "%watermark --iversions"
432 | ]
433 | },
434 | {
435 | "cell_type": "code",
436 | "execution_count": null,
437 | "metadata": {
438 | "id": "5dY6lqRIsbWV"
439 | },
440 | "outputs": [],
441 | "source": []
442 | }
443 | ],
444 | "metadata": {
445 | "colab": {
446 | "provenance": [],
447 | "toc_visible": true
448 | },
449 | "kernelspec": {
450 | "display_name": "Python 3",
451 | "name": "python3"
452 | },
453 | "language_info": {
454 | "name": "python"
455 | }
456 | },
457 | "nbformat": 4,
458 | "nbformat_minor": 0
459 | }
--------------------------------------------------------------------------------
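The binomial model `pm.Binomial('x', p=0.5, n=5)` sampled in the notebook above has a known closed-form distribution, so the sampled frequencies can be sanity-checked against it. The following is a minimal standard-library sketch, not code from the book: `binomial_pmf` and `sample_binomial` are hypothetical helper names, and the sample size of 500 and the seed are arbitrary choices made for illustration.

```python
import math
import random
from collections import Counter


def binomial_pmf(k, n, p):
    """Exact probability of k successes in n independent trials."""
    return math.comb(n, k) * p**k * (1 - p) ** (n - k)


def sample_binomial(n, p, size, seed):
    """Draw `size` binomial samples by summing n Bernoulli trials each."""
    rng = random.Random(seed)
    return [sum(rng.random() < p for _ in range(n)) for _ in range(size)]


# Same distribution parameters as the notebook; size and seed are assumptions.
n, p, size = 5, 0.5, 500
samples = sample_binomial(n, p, size, seed=42)
counts = Counter(samples)

# Exact probabilities and empirical frequencies for k = 0..n
exact = [binomial_pmf(k, n, p) for k in range(n + 1)]
empirical = [counts.get(k, 0) / size for k in range(n + 1)]

for k in range(n + 1):
    print(f"k={k}: exact={exact[k]:.3f}, empirical={empirical[k]:.3f}")
```

With a larger `size`, the empirical column should move closer to the exact one, which is the same behavior one would expect from the `value_counts` output of the PyMC samples.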
/notebooks/3章_ベイス推論とは.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "id": "z9YCAOVcfPie"
7 | },
8 | "source": [
9 | "## Chapter 3: What Is Bayesian Inference?"
10 | ]
11 | },
12 | {
13 | "cell_type": "markdown",
14 | "metadata": {
15 | "id": "tNWSPFiM1zaG"
16 | },
17 | "source": [
18 | "\n"
19 | ]
20 | },
21 | {
22 | "cell_type": "markdown",
23 | "metadata": {
24 | "id": "_HG_NLmyfdF6"
25 | },
26 | "source": [
27 | "### Common setup"
28 | ]
29 | },
30 | {
31 | "cell_type": "code",
32 | "execution_count": null,
33 | "metadata": {
34 | "id": "LF-JMPVefZXr"
35 | },
36 | "outputs": [],
37 | "source": [
38 | "%matplotlib inline\n",
39 | "# Install the library that enables Japanese text in Matplotlib\n",
40 | "!pip install japanize-matplotlib | tail -n 1"
41 | ]
42 | },
43 | {
44 | "cell_type": "code",
45 | "execution_count": null,
46 | "metadata": {
47 | "id": "SllTvgULfivc"
48 | },
49 | "outputs": [],
50 | "source": [
51 | "# Library imports\n",
52 | "\n",
53 | "# NumPy\n",
54 | "import numpy as np\n",
55 | "# pyplot from Matplotlib\n",
56 | "import matplotlib.pyplot as plt\n",
57 | "# Japanese font support for Matplotlib\n",
58 | "import japanize_matplotlib\n",
59 | "# pandas\n",
60 | "import pandas as pd\n",
61 | "# Function for displaying data frames\n",
62 | "from IPython.display import display\n",
63 | "# seaborn\n",
64 | "import seaborn as sns\n",
65 | "# Display options\n",
66 | "# NumPy print format\n",
67 | "np.set_printoptions(precision=3, floatmode='fixed')\n",
68 | "# Default font size for plots\n",
69 | "plt.rcParams[\"font.size\"] = 14\n",
70 | "# Default figure size\n",
71 | "plt.rcParams['figure.figsize'] = (6, 6)\n",
72 | "# Show grid lines\n",
73 | "plt.rcParams['axes.grid'] = True\n",
74 | "# Float precision in data frames\n",
75 | "pd.options.display.float_format = '{:.3f}'.format\n",
76 | "# Show all columns of data frames\n",
77 | "pd.set_option(\"display.max_columns\", None)"
78 | ]
79 | },
80 | {
81 | "cell_type": "code",
82 | "execution_count": null,
83 | "metadata": {
84 | "id": "MBzeKi-lfltG"
85 | },
86 | "outputs": [],
87 | "source": [
88 | "import pymc as pm\n",
89 | "import arviz as az\n",
90 | "\n",
91 | "print(f\"Running on PyMC v{pm.__version__}\")\n",
92 | "print(f\"Running on ArviZ v{az.__version__}\")"
93 | ]
94 | },
95 | {
96 | "cell_type": "markdown",
97 | "metadata": {
98 | "id": "tSoiwbzgfyWE"
99 | },
100 | "source": [
101 | "### 3.1 Why Use Bayesian Inference"
102 | ]
103 | },
104 | {
105 | "cell_type": "code",
106 | "execution_count": null,
107 | "metadata": {
108 | "id": "uXtb8isNfoal"
109 | },
110 | "outputs": [],
111 | "source": [
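112 | "# Probability density function of the uniform distribution on [0, 1]\n",
113 | "def f(x):\n",
114 | "    # x - x + 1.0 broadcasts, so an array input yields an array of ones\n",
115 | "    return x - x + 1.0\n",
116 | "\n",
117 | "x = np.arange(0.0, 1.1, 0.1)\n",
118 | "plt.fill_between(x, f(x))\n",
119 | "plt.title('Probability density function of the uniform distribution on [0, 1]');"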
120 | ]
121 | },
122 | {
123 | "cell_type": "markdown",
124 | "metadata": {
125 | "id": "W3l3sKXzanNg"
126 | },
127 | "source": [
128 | "### Checking versions"
129 | ]
130 | },
131 | {
132 | "cell_type": "code",
133 | "execution_count": null,
134 | "metadata": {
135 | "id": "KOTCOgxggKcP"
136 | },
137 | "outputs": [],
138 | "source": [
139 | "!pip install watermark | tail -n 1\n",
140 | "%load_ext watermark\n",
141 | "%watermark --iversions"
142 | ]
143 | }
144 | ],
145 | "metadata": {
146 | "colab": {
147 | "provenance": [],
148 | "toc_visible": true
149 | },
150 | "kernelspec": {
151 | "display_name": "Python 3",
152 | "name": "python3"
153 | },
154 | "language_info": {
155 | "name": "python"
156 | }
157 | },
158 | "nbformat": 4,
159 | "nbformat_minor": 0
160 | }
--------------------------------------------------------------------------------
/notebooks/4章_はじめてのベイズ推論実習.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "source": [
6 | "Run the following command first so that the ``plot_trace`` function does not emit warnings."
7 | ],
8 | "metadata": {
9 | "id": "yYHN_GYMfaBA"
10 | }
11 | },
12 | {
13 | "cell_type": "code",
14 | "source": [
15 | "!pip install --upgrade numba | tail -n 1"
16 | ],
17 | "metadata": {
18 | "id": "vAm7pfE6fa-Y"
19 | },
20 | "execution_count": null,
21 | "outputs": []
22 | },
23 | {
24 | "cell_type": "markdown",
25 | "metadata": {
26 | "id": "nU4c1bIkBgOe"
27 | },
28 | "source": [
29 | "## Chapter 4: First Hands-On Bayesian Inference"
30 | ]
31 | },
32 | {
33 | "cell_type": "markdown",
34 | "metadata": {
35 | "id": "xk73aEA_1OU3"
36 | },
37 | "source": [
38 | ""
39 | ]
40 | },
41 | {
42 | "cell_type": "markdown",
43 | "metadata": {
44 | "id": "PqwDnWT-BR0q"
45 | },
46 | "source": [
47 | "### Common setup"
48 | ]
49 | },
50 | {
51 | "cell_type": "code",
52 | "execution_count": null,
53 | "metadata": {
54 | "id": "GuyTLGJkBO7o"
55 | },
56 | "outputs": [],
57 | "source": [
58 | "%matplotlib inline\n",
59 | "# Install the library that enables Japanese text in Matplotlib\n",
60 | "!pip install japanize-matplotlib | tail -n 1"
61 | ]
62 | },
63 | {
64 | "cell_type": "code",
65 | "execution_count": null,
66 | "metadata": {
67 | "id": "GUpYvKuKBolx"
68 | },
69 | "outputs": [],
70 | "source": [
71 | "# Library imports\n",
72 | "\n",
73 | "# NumPy\n",
74 | "import numpy as np\n",
75 | "# pyplot from Matplotlib\n",
76 | "import matplotlib.pyplot as plt\n",
77 | "# Japanese font support for Matplotlib\n",
78 | "import japanize_matplotlib\n",
79 | "# pandas\n",
80 | "import pandas as pd\n",
81 | "# Function for displaying data frames\n",
82 | "from IPython.display import display\n",
83 | "# seaborn\n",
84 | "import seaborn as sns\n",
85 | "# Display options\n",
86 | "# NumPy print format\n",
87 | "np.set_printoptions(precision=3, floatmode='fixed')\n",
88 | "# Default font size for plots\n",
89 | "plt.rcParams[\"font.size\"] = 14\n",
90 | "# Default figure size\n",
91 | "plt.rcParams['figure.figsize'] = (6, 6)\n",
92 | "# Show grid lines\n",
93 | "plt.rcParams['axes.grid'] = True\n",
94 | "# Float precision in data frames\n",
95 | "pd.options.display.float_format = '{:.3f}'.format\n",
96 | "# Show all columns of data frames\n",
97 | "pd.set_option(\"display.max_columns\", None)"
98 | ]
99 | },
100 | {
101 | "cell_type": "code",
102 | "execution_count": null,
103 | "metadata": {
104 | "id": "YoSNt5Z_qFIF"
105 | },
106 | "outputs": [],
107 | "source": [
108 | "import pymc as pm\n",
109 | "import arviz as az\n",
110 | "\n",
111 | "print(f\"Running on PyMC v{pm.__version__}\")\n",
112 | "print(f\"Running on ArviZ v{az.__version__}\")"
113 | ]
114 | },
115 | {
116 | "cell_type": "markdown",
117 | "metadata": {
118 | "id": "36QxNPyYB0AW"
119 | },
120 | "source": [
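121 | "### Problem definition\n",
122 | "Consider a lottery in which the winning probability is always the same and the previous result has no effect on the next one (mathematically, the draws are independent events).  \n",
123 | "Someone drew from this lottery five times, and the results were: win, lose, lose, win, lose.  \n",
124 | "Let $p$ be the probability of winning a single draw. Find the value of $p$."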
125 | ]
126 | },
127 | {
128 | "cell_type": "markdown",
129 | "metadata": {
130 | "id": "SWSYXTrlCKHM"
131 | },
132 | "source": [
133 | "### 4.2 Maximum likelihood estimation"
134 | ]
135 | },
136 | {
137 | "cell_type": "markdown",
138 | "metadata": {
139 | "id": "pDKCb_6YL5UU"
140 | },
141 | "source": [
142 | "#### Graph of the likelihood function"
143 | ]
144 | },
145 | {
146 | "cell_type": "code",
147 | "execution_count": null,
148 | "metadata": {
149 | "id": "n3gFFtsE1qoJ"
150 | },
151 | "outputs": [],
152 | "source": [
153 | "def lh(p):  # likelihood of 2 wins and 3 losses\n",
154 | "    return p ** 2 * (1-p) ** 3\n",
155 | "\n",
156 | "# x coordinates for the plot\n",
157 | "# 0.0 <= p < 1.0\n",
158 | "p = np.arange(0.0, 1.0, 0.01)\n",
159 | "\n",
160 | "# Draw the graph\n",
161 | "plt.rcParams['figure.figsize'] = (6, 4)\n",
162 | "plt.plot(p, lh(p))\n",
163 | "plt.xlabel('p (probability)')\n",
164 | "plt.ylabel('Likelihood')\n",
165 | "plt.title('Likelihood function');"
166 | ]
167 | },
168 | {
169 | "cell_type": "markdown",
170 | "metadata": {
171 | "id": "bUWTucW_L_q0"
172 | },
173 | "source": [
174 | "#### Solving with PyTorch"
175 | ]
176 | },
177 | {
178 | "cell_type": "code",
179 | "execution_count": null,
180 | "metadata": {
181 | "id": "7SUwHW7rBuVn"
182 | },
183 | "outputs": [],
184 | "source": [
185 | "import torch  # import the library\n",
186 | "\n",
187 | "def log_lh(p):  # log-likelihood function\n",
188 | "    return (2 * torch.log(p) + 3 * torch.log(1-p))\n",
189 | "\n",
190 | "num_epochs = 40  # number of iterations\n",
191 | "lr = 0.01  # learning rate\n",
192 | "\n",
193 | "# Initial parameter value (p=0.1)\n",
194 | "p = torch.tensor(0.1, dtype=torch.float32, requires_grad=True)\n",
195 | "\n",
196 | "logs = np.zeros((0,3))\n",
197 | "for epoch in range(num_epochs):\n",
198 | "    loss = -log_lh(p)  # compute the loss (negative log-likelihood)\n",
199 | "    loss.backward()  # compute the gradient\n",
200 | "    with torch.no_grad():\n",
201 | "        p -= lr * p.grad  # update the parameter\n",
202 | "        p.grad.zero_()  # reset the gradient\n",
203 | "    log = np.array([epoch, p.item(), loss.item()]).reshape(1,-1)\n",
204 | "    logs = np.vstack([logs, log])"
205 | ]
206 | },
207 | {
208 | "cell_type": "code",
209 | "execution_count": null,
210 | "metadata": {
211 | "id": "eFIxQYmhLuIw"
212 | },
213 | "outputs": [],
214 | "source": [
215 | "plt.rcParams['figure.figsize'] = (8, 4)\n",
216 | "fig, axes = plt.subplots(1, 2)\n",
217 | "axes[0].plot(logs[:,0], logs[:,1])\n",
218 | "axes[0].set_title('p (probability)')\n",
219 | "axes[1].plot(logs[:,0], logs[:,2])\n",
220 | "axes[1].set_title('loss')\n",
221 | "plt.tight_layout();"
222 | ]
223 | },
224 | {
225 | "cell_type": "markdown",
226 | "metadata": {
227 | "id": "8gX0x79yp08G"
228 | },
229 | "source": [
230 | "### 4.3 Bayesian inference (defining the probability model)"
231 | ]
232 | },
233 | {
234 | "cell_type": "markdown",
235 | "metadata": {
236 | "id": "JWd5GvBy7lQJ"
237 | },
238 | "source": [
239 | "#### Preparing the data (observations)"
240 | ]
241 | },
242 | {
243 | "cell_type": "code",
244 | "execution_count": null,
245 | "metadata": {
246 | "id": "TxqXnyE_DPI7"
247 | },
248 | "outputs": [],
249 | "source": [
250 | "X = np.array([1, 0, 0, 1, 0])\n",
251 | "print(X)"
252 | ]
253 | },
254 | {
255 | "cell_type": "markdown",
256 | "metadata": {
257 | "id": "6wEg7dkc7t3c"
258 | },
259 | "source": [
260 | "#### Defining the probability model"
261 | ]
262 | },
263 | {
264 | "cell_type": "code",
265 | "execution_count": null,
266 | "metadata": {
267 | "id": "F5ThhRYXFwQH"
268 | },
269 | "outputs": [],
270 | "source": [
271 | "# Define the context\n",
272 | "model1 = pm.Model()\n",
273 | "\n",
274 | "with model1:\n",
275 | "    # pm.Uniform: uniform distribution\n",
276 | "    p = pm.Uniform('p', lower=0.0, upper=1.0)\n",
277 | "\n",
278 | "    # pm.Bernoulli: Bernoulli distribution\n",
279 | "    X_obs = pm.Bernoulli('X_obs', p=p, observed=X)"
280 | ]
281 | },
282 | {
283 | "cell_type": "markdown",
284 | "metadata": {
285 | "id": "3LN0Loj4g11H"
286 | },
287 | "source": [
288 | "#### Visualizing the probability model"
289 | ]
290 | },
291 | {
292 | "cell_type": "code",
293 | "execution_count": null,
294 | "metadata": {
295 | "id": "9MOWF2ZSHlPc"
296 | },
297 | "outputs": [],
298 | "source": [
299 | "g = pm.model_to_graphviz(model1)\n",
300 | "display(g)"
301 | ]
302 | },
303 | {
304 | "cell_type": "markdown",
305 | "metadata": {
306 | "id": "DcpA25EXrSKN"
307 | },
308 | "source": [
309 | "### 4.4 Bayesian inference (sampling)"
310 | ]
311 | },
312 | {
313 | "cell_type": "markdown",
314 | "metadata": {
315 | "id": "d2QcWGvcM-AL"
316 | },
317 | "source": [
318 | "#### Sampling with explicitly set parameters"
319 | ]
320 | },
321 | {
322 | "cell_type": "code",
323 | "execution_count": null,
324 | "metadata": {
325 | "id": "dVuGeMxRrds_"
326 | },
327 | "outputs": [],
328 | "source": [
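329 | "with model1:\n",
330 | "    idata1_1 = pm.sample(\n",
331 | "        # Number of chains (default: the larger of 2 and the number of cores used)\n",
332 | "        chains=3,\n",
333 | "        # Number of tuning samples, which are discarded (default 1000)\n",
334 | "        tune=2000,\n",
335 | "        # Number of samples to keep per chain (default 1000)\n",
336 | "        draws=2000,\n",
337 | "        random_seed=42)"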
338 | ]
339 | },
340 | {
341 | "cell_type": "markdown",
342 | "metadata": {
343 | "id": "18J9j5ZKPJ4V"
344 | },
345 | "source": [
346 | "#### Sampling with all default values"
347 | ]
348 | },
349 | {
350 | "cell_type": "code",
351 | "execution_count": null,
352 | "metadata": {
353 | "id": "0obEALDftJCl"
354 | },
355 | "outputs": [],
356 | "source": [
357 | "with model1:\n",
358 | " idata1_2 = pm.sample(random_seed=42)"
359 | ]
360 | },
361 | {
362 | "cell_type": "markdown",
363 | "metadata": {
364 | "id": "q8hfG__lrfUE"
365 | },
366 | "source": [
367 | "### 4.5 Bayesian inference (analyzing the results)"
368 | ]
369 | },
370 | {
371 | "cell_type": "markdown",
372 | "metadata": {
373 | "id": "P426h2UMYi2e"
374 | },
375 | "source": [
376 | "#### Calling the plot_trace function"
377 | ]
378 | },
379 | {
380 | "cell_type": "code",
381 | "execution_count": null,
382 | "metadata": {
383 | "id": "SEdQFb7wrD04"
384 | },
385 | "outputs": [],
386 | "source": [
387 | "axes = az.plot_trace(idata1_2, compact=False)\n",
388 | "plt.tight_layout();"
389 | ]
390 | },
391 | {
392 | "cell_type": "markdown",
393 | "metadata": {
394 | "id": "jXfQ_7P0ufoj"
395 | },
396 | "source": [
397 | "#### Calling the plot_posterior function"
398 | ]
399 | },
400 | {
401 | "cell_type": "code",
402 | "execution_count": null,
403 | "metadata": {
404 | "id": "fVXbgG9_uGUW"
405 | },
406 | "outputs": [],
407 | "source": [
408 | "plt.rcParams['figure.figsize'] = (6, 6)\n",
409 | "ax = az.plot_posterior(idata1_2)\n",
410 | "ax.set_xlim(0, 1)\n",
411 | "ax.set_title('Bayesian inference result: initial version');"
412 | ]
413 | },
414 | {
415 | "cell_type": "markdown",
416 | "metadata": {
417 | "id": "_i9NiSEGusAg"
418 | },
419 | "source": [
420 | "#### Calling the summary function"
421 | ]
422 | },
423 | {
424 | "cell_type": "code",
425 | "execution_count": null,
426 | "metadata": {
427 | "id": "Qk5vZR7ZttxV"
428 | },
429 | "outputs": [],
430 | "source": [
431 | "summary1_2 = az.summary(idata1_2)\n",
432 | "display(summary1_2)"
433 | ]
434 | },
435 | {
436 | "cell_type": "markdown",
437 | "metadata": {
438 | "id": "Xqou6AwPsROC"
439 | },
440 | "source": [
441 | "### 4.6 Bayesian inference (binomial version)"
442 | ]
443 | },
444 | {
445 | "cell_type": "markdown",
446 | "metadata": {
447 | "id": "mUEsQFlBugPs"
448 | },
449 | "source": [
450 | "#### Defining the probability model: binomial version"
451 | ]
452 | },
453 | {
454 | "cell_type": "code",
455 | "execution_count": null,
456 | "metadata": {
457 | "id": "nkg_R5H_sVzi"
458 | },
459 | "outputs": [],
460 | "source": [
461 | "# Define the context\n",
462 | "model2 = pm.Model()\n",
463 | "\n",
464 | "with model2:\n",
465 | "    # pm.Uniform: uniform distribution\n",
466 | "    p = pm.Uniform('p', lower=0.0, upper=1.0)\n",
467 | "\n",
468 | "    # pm.Binomial: binomial distribution\n",
469 | "    # p: success probability\n",
470 | "    # n: number of trials\n",
471 | "    X_obs = pm.Binomial('X_obs', p=p, n=5, observed=2)"
472 | ]
473 | },
474 | {
475 | "cell_type": "markdown",
476 | "metadata": {
477 | "id": "nO32Zyy7pK4z"
478 | },
479 | "source": [
480 | "#### Visualizing the binomial-version probability model"
481 | ]
482 | },
483 | {
484 | "cell_type": "code",
485 | "execution_count": null,
486 | "metadata": {
487 | "id": "LNy1Cc9kvQas"
488 | },
489 | "outputs": [],
490 | "source": [
491 | "# Visualize the model structure\n",
492 | "g = pm.model_to_graphviz(model2)\n",
493 | "display(g)"
494 | ]
495 | },
496 | {
497 | "cell_type": "markdown",
498 | "metadata": {
499 | "id": "zI_Jps3lukwQ"
500 | },
501 | "source": [
502 | "#### Sampling"
503 | ]
504 | },
505 | {
506 | "cell_type": "code",
507 | "execution_count": null,
508 | "metadata": {
509 | "id": "6elFd8Nwuov_"
510 | },
511 | "outputs": [],
512 | "source": [
513 | "with model2:\n",
514 | " idata2 = pm.sample(random_seed=42)"
515 | ]
516 | },
517 | {
518 | "cell_type": "markdown",
519 | "metadata": {
520 | "id": "oxMJmW6kupjP"
521 | },
522 | "source": [
523 | "#### Analyzing the results"
524 | ]
525 | },
526 | {
527 | "cell_type": "markdown",
528 | "metadata": {
529 | "id": "NZimo5WIzgix"
530 | },
531 | "source": [
532 | "##### Calling the plot_trace function"
533 | ]
534 | },
535 | {
536 | "cell_type": "code",
537 | "execution_count": null,
538 | "metadata": {
539 | "id": "HnS-4Ntfutf2"
540 | },
541 | "outputs": [],
542 | "source": [
543 | "axes = az.plot_trace(idata2, compact=False)\n",
544 | "plt.tight_layout();"
545 | ]
546 | },
547 | {
548 | "cell_type": "markdown",
549 | "metadata": {
550 | "id": "QqS0xe2oYDF9"
551 | },
552 | "source": [
553 | "##### Calling the plot_posterior function"
554 | ]
555 | },
556 | {
557 | "cell_type": "code",
558 | "execution_count": null,
559 | "metadata": {
560 | "id": "ybB2q3jAYMU2"
561 | },
562 | "outputs": [],
563 | "source": [
564 | "plt.rcParams['figure.figsize'] = (6, 6)\n",
565 | "ax = az.plot_posterior(idata2)\n",
566 | "ax.set_xlim(0, 1)\n",
567 | "ax.set_title('Bayesian inference result: binomial version');"
568 | ]
569 | },
570 | {
571 | "cell_type": "markdown",
572 | "metadata": {
573 | "id": "Mrl1-9gqzr2K"
574 | },
575 | "source": [
576 | "##### Calling the summary function"
577 | ]
578 | },
579 | {
580 | "cell_type": "code",
581 | "execution_count": null,
582 | "metadata": {
583 | "id": "kjOh_5J4vnmS"
584 | },
585 | "outputs": [],
586 | "source": [
587 | "summary2 = az.summary(idata2)\n",
588 | "display(summary2)"
589 | ]
590 | },
591 | {
592 | "cell_type": "markdown",
593 | "metadata": {
594 | "id": "DMUKtXZbrzIl"
595 | },
596 | "source": [
597 | "### 4.7 Bayesian inference (increasing the number of trials)"
598 | ]
599 | },
600 | {
601 | "cell_type": "markdown",
602 | "metadata": {
603 | "id": "ipSWvogSwOT1"
604 | },
605 | "source": [
606 | "#### Defining the probability model"
607 | ]
608 | },
609 | {
610 | "cell_type": "code",
611 | "execution_count": null,
612 | "metadata": {
613 | "id": "53VXoEwlwSFK"
614 | },
615 | "outputs": [],
616 | "source": [
617 | "# Define the context\n",
618 | "model3 = pm.Model()\n",
619 | "\n",
620 | "with model3:\n",
621 | "    # pm.Uniform: uniform distribution\n",
622 | "    p = pm.Uniform('p', lower=0.0, upper=1.0)\n",
623 | "\n",
624 | "    # pm.Binomial: binomial distribution\n",
625 | "    # p: success probability\n",
626 | "    # n: number of trials\n",
627 | "    X_obs = pm.Binomial('X_obs', p=p, n=50, observed=20)"
628 | ]
629 | },
630 | {
631 | "cell_type": "markdown",
632 | "metadata": {
633 | "id": "5dtaWhYN0BBF"
634 | },
635 | "source": [
636 | "#### Visualizing the probability model"
637 | ]
638 | },
639 | {
640 | "cell_type": "code",
641 | "execution_count": null,
642 | "metadata": {
643 | "id": "Qwszi6e9znQb"
644 | },
645 | "outputs": [],
646 | "source": [
647 | "g = pm.model_to_graphviz(model3)\n",
648 | "display(g)"
649 | ]
650 | },
651 | {
652 | "cell_type": "markdown",
653 | "metadata": {
654 | "id": "UnJcsyBkwStT"
655 | },
656 | "source": [
657 | "#### Sampling"
658 | ]
659 | },
660 | {
661 | "cell_type": "code",
662 | "execution_count": null,
663 | "metadata": {
664 | "id": "_JQzyGrtsWt2"
665 | },
666 | "outputs": [],
667 | "source": [
668 | "# Run the sampling\n",
669 | "with model3:\n",
670 | " idata3 = pm.sample(random_seed=42)"
671 | ]
672 | },
673 | {
674 | "cell_type": "markdown",
675 | "metadata": {
676 | "id": "eFYbErMBwWw-"
677 | },
678 | "source": [
679 | "#### Analyzing the results"
680 | ]
681 | },
682 | {
683 | "cell_type": "markdown",
684 | "metadata": {
685 | "id": "7lR42jx-0Ll4"
686 | },
687 | "source": [
688 | "##### Calling the plot_trace function"
689 | ]
690 | },
691 | {
692 | "cell_type": "code",
693 | "execution_count": null,
694 | "metadata": {
695 | "id": "fLYRGRdDwalA"
696 | },
697 | "outputs": [],
698 | "source": [
699 | "axes = az.plot_trace(idata3, compact=False)\n",
700 | "plt.tight_layout();"
701 | ]
702 | },
703 | {
704 | "cell_type": "markdown",
705 | "metadata": {
706 | "id": "riapq5By0SIX"
707 | },
708 | "source": [
709 | "##### Calling the plot_posterior function"
710 | ]
711 | },
712 | {
713 | "cell_type": "code",
714 | "execution_count": null,
715 | "metadata": {
716 | "id": "g43nAFOix2rj"
717 | },
718 | "outputs": [],
719 | "source": [
720 | "plt.rcParams['figure.figsize'] = (6, 6)\n",
721 | "ax = az.plot_posterior(idata3)\n",
722 | "ax.set_xlim(0, 1)\n",
723 | "ax.set_title('More trials (n=50)');"
724 | ]
725 | },
726 | {
727 | "cell_type": "markdown",
728 | "metadata": {
729 | "id": "Z7sMEfRi0a7H"
730 | },
731 | "source": [
732 | "##### Calling the summary function"
733 | ]
734 | },
735 | {
736 | "cell_type": "code",
737 | "execution_count": null,
738 | "metadata": {
739 | "id": "EHTGm9VHz-Dg"
740 | },
741 | "outputs": [],
742 | "source": [
743 | "summary3 = az.summary(idata3)\n",
744 | "display(summary3)"
745 | ]
746 | },
747 | {
748 | "cell_type": "markdown",
749 | "metadata": {
750 | "id": "vwx-xiBCIaCC"
751 | },
752 | "source": [
753 | "### 4.8 Changing the prior distribution\n"
754 | ]
755 | },
756 | {
757 | "cell_type": "code",
758 | "execution_count": null,
759 | "metadata": {
760 | "id": "zTWoxuwkI072"
761 | },
762 | "outputs": [],
763 | "source": [
764 | "# Define the context\n",
765 | "model4 = pm.Model()\n",
766 | "\n",
767 | "with model4:\n",
768 | "    # Define the probability model\n",
769 | "\n",
770 | "    # Change the parameters of the uniform prior\n",
771 | "    p = pm.Uniform('p', lower=0.1, upper=0.9)\n",
772 | "\n",
773 | "    # The observation (2 wins in 5 draws) is unchanged\n",
774 | "    X_obs = pm.Binomial('X_obs', p=p, n=5, observed=2)\n",
775 | "\n",
776 | "    # Draw the samples\n",
777 | "    idata4 = pm.sample(random_seed=42)"
778 | ]
779 | },
780 | {
781 | "cell_type": "code",
782 | "execution_count": null,
783 | "metadata": {
784 | "id": "TboUyaG-u4DM"
785 | },
786 | "outputs": [],
787 | "source": [
788 | "plt.rcParams['figure.figsize'] = (6, 6)\n",
789 | "# Visualize the posterior distribution\n",
790 | "ax = az.plot_posterior(idata4)\n",
791 | "ax.set_title('Modified prior [0.1, 0.9]')\n",
792 | "ax.set_xlim(0, 1)\n",
793 | "\n",
794 | "# Summarize the samples\n",
795 | "summary4 = az.summary(idata4)\n",
796 | "display(summary4)"
797 | ]
798 | },
799 | {
800 | "cell_type": "markdown",
801 | "metadata": {
802 | "id": "r00m_BKFsdkV"
803 | },
804 | "source": [
805 | "### 4.9 Comparison with the Beta distribution"
806 | ]
807 | },
808 | {
809 | "cell_type": "markdown",
810 | "metadata": {
811 | "id": "L9faosvZ6PFS"
812 | },
813 | "source": [
814 | "#### Overlaying the Beta distribution on the Bayesian inference result"
815 | ]
816 | },
817 | {
818 | "cell_type": "code",
819 | "execution_count": null,
820 | "metadata": {
821 | "id": "EGClbDSX1OSp"
822 | },
823 | "outputs": [],
824 | "source": [
825 | "# Define the exact Beta posterior distribution\n",
826 | "from scipy import stats\n",
827 | "alpha = 20 + 1  # wins + 1 (uniform prior)\n",
828 | "beta = 30 + 1  # losses + 1 (uniform prior)\n",
829 | "true_beta = stats.beta(alpha, beta)\n",
830 | "\n",
831 | "plt.rcParams['figure.figsize'] = (6, 6)\n",
832 | "# Visualize the Bayesian inference result\n",
833 | "# idata3 is the result computed in Section 4.7\n",
834 | "ax = az.plot_posterior(idata3)\n",
835 | "ax.lines[0].set_label('Bayesian inference result')\n",
836 | "\n",
837 | "# Visualize the exact Beta distribution\n",
838 | "x = np.linspace(*ax.get_xlim())\n",
839 | "ax.plot(x, true_beta.pdf(x), color='orange', label='Exact')\n",
840 | "ax.legend(loc='center right');"
841 | ]
842 | },
843 | {
844 | "cell_type": "markdown",
845 | "metadata": {
846 | "id": "SE-35GfF1igk"
847 | },
848 | "source": [
849 | "### Column: ArviZ FAQ"
850 | ]
851 | },
852 | {
853 | "cell_type": "markdown",
854 | "metadata": {
855 | "id": "CcUfjEW31-vT"
856 | },
857 | "source": [
858 | "#### Showing the y-axis scale in plot_posterior"
859 | ]
860 | },
861 | {
862 | "cell_type": "code",
863 | "execution_count": null,
864 | "metadata": {
865 | "id": "4r93XQ7lvWpd"
866 | },
867 | "outputs": [],
868 | "source": [
869 | "plt.rcParams['figure.figsize'] = (6, 6)\n",
870 | "ax = az.plot_posterior(idata1_2)\n",
871 | "\n",
872 | "# Show the y-axis line\n",
873 | "ax.spines['left'].set_visible(True)\n",
874 | "\n",
875 | "# Show the y-axis label\n",
876 | "ax.set_ylabel(\"Density\")\n",
877 | "\n",
878 | "# Let Matplotlib choose the y-axis tick positions automatically\n",
879 | "from matplotlib.ticker import AutoLocator\n",
880 | "ax.yaxis.set_major_locator(AutoLocator())\n",
881 | "\n",
882 | "ax.set_xlim(0, 1)\n",
883 | "ax.set_title('Bayesian inference result: with y-axis scale');"
884 | ]
885 | },
886 | {
887 | "cell_type": "markdown",
888 | "metadata": {
889 | "id": "nqhQ5QON5U7F"
890 | },
891 | "source": [
892 | "#### Changing a subplot title in the plot_trace figure"
893 | ]
894 | },
895 | {
896 | "cell_type": "code",
897 | "execution_count": null,
898 | "metadata": {
899 | "id": "XOEzJKjj5Ag2"
900 | },
901 | "outputs": [],
902 | "source": [
903 | "axes = az.plot_trace(idata1_2, compact=False)\n",
904 | "plt.tight_layout()\n",
905 | "axes[0,1].set_title('Changed graph title');"
906 | ]
907 | },
908 | {
909 | "cell_type": "markdown",
910 | "metadata": {
911 | "id": "daMkk5jQmAgf"
912 | },
913 | "source": [
914 | "#### Checking versions"
915 | ]
916 | },
917 | {
918 | "cell_type": "code",
919 | "execution_count": null,
920 | "metadata": {
921 | "id": "oMZeD9bY9F-m"
922 | },
923 | "outputs": [],
924 | "source": [
925 | "!pip install watermark | tail -n 1\n",
926 | "%load_ext watermark\n",
927 | "%watermark --iversions"
928 | ]
929 | },
930 | {
931 | "cell_type": "code",
932 | "execution_count": null,
933 | "metadata": {
934 | "id": "Vqizv8yl14Kg"
935 | },
936 | "outputs": [],
937 | "source": []
938 | }
939 | ],
940 | "metadata": {
941 | "colab": {
942 | "provenance": [],
943 | "toc_visible": true
944 | },
945 | "kernelspec": {
946 | "display_name": "Python 3",
947 | "name": "python3"
948 | },
949 | "language_info": {
950 | "name": "python"
951 | }
952 | },
953 | "nbformat": 4,
954 | "nbformat_minor": 0
955 | }
--------------------------------------------------------------------------------
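For the models in the notebook above that use a Uniform(0, 1) prior with a binomial observation, the posterior has a closed form: with k wins in n draws it is Beta(k + 1, n - k + 1), which is also what the notebook's Section 4.9 overlays for n = 50, k = 20. The sketch below, which is not code from the book, uses only the standard library to compute the posterior mean, mode, and standard deviation, so that the `az.summary` output can be compared against exact values. The helper names (`beta_posterior`, `beta_mean`, `beta_mode`, `beta_sd`) are hypothetical and introduced only for this illustration.

```python
import math


def beta_posterior(successes, trials):
    """Parameters (a, b) of the Beta posterior under a Uniform(0, 1) prior."""
    a = successes + 1
    b = trials - successes + 1
    return a, b


def beta_mean(a, b):
    """Mean of Beta(a, b)."""
    return a / (a + b)


def beta_mode(a, b):
    """Mode of Beta(a, b); this formula is valid only for a > 1 and b > 1."""
    return (a - 1) / (a + b - 2)


def beta_sd(a, b):
    """Standard deviation of Beta(a, b)."""
    return math.sqrt(a * b / ((a + b) ** 2 * (a + b + 1)))


# The two observation settings used in the notebook: 2 of 5 and 20 of 50
for successes, trials in [(2, 5), (20, 50)]:
    a, b = beta_posterior(successes, trials)
    print(
        f"n={trials}, k={successes}: Beta({a}, {b}), "
        f"mean={beta_mean(a, b):.3f}, mode={beta_mode(a, b):.3f}, "
        f"sd={beta_sd(a, b):.3f}"
    )
```

The mode equals k / n, the maximum likelihood estimate from Section 4.2, while the standard deviation shrinks as the number of trials grows. This closed form does not apply to the Section 4.8 model, whose prior is restricted to [0.1, 0.9].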
/notebooks/5_1_データ分布のベイズ推論.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "source": [
6 | "Run the following command first so that the ``plot_trace`` function does not emit warnings."
7 | ],
8 | "metadata": {
9 | "id": "R4d32npiN0Be"
10 | }
11 | },
12 | {
13 | "cell_type": "code",
14 | "source": [
15 | "!pip install --upgrade numba | tail -n 1"
16 | ],
17 | "metadata": {
18 | "id": "UkCP90SUN1GO"
19 | },
20 | "execution_count": null,
21 | "outputs": []
22 | },
23 | {
24 | "cell_type": "markdown",
25 | "metadata": {
26 | "id": "uIoj_UW_qasp"
27 | },
28 | "source": [
29 | "## 5.1 Bayesian Inference for a Data Distribution"
30 | ]
31 | },
32 | {
33 | "cell_type": "markdown",
34 | "metadata": {
35 | "id": "nwV0yADH5ED0"
36 | },
37 | "source": [
38 | "\n"
39 | ]
40 | },
41 | {
42 | "cell_type": "markdown",
43 | "metadata": {
44 | "id": "vdi2sqRyqr4p"
45 | },
46 | "source": [
47 | "### Common setup"
48 | ]
49 | },
50 | {
51 | "cell_type": "code",
52 | "execution_count": null,
53 | "metadata": {
54 | "id": "0dTjJOmIqmiS"
55 | },
56 | "outputs": [],
57 | "source": [
58 | "%matplotlib inline\n",
59 | "# Install the library that enables Japanese text in Matplotlib\n",
60 | "!pip install japanize-matplotlib | tail -n 1"
61 | ]
62 | },
63 | {
64 | "cell_type": "code",
65 | "execution_count": null,
66 | "metadata": {
67 | "id": "M46-MnBXq5-Z"
68 | },
69 | "outputs": [],
70 | "source": [
71 | "# Library imports\n",
72 | "\n",
73 | "# NumPy\n",
74 | "import numpy as np\n",
75 | "# pyplot from Matplotlib\n",
76 | "import matplotlib.pyplot as plt\n",
77 | "# Japanese font support for Matplotlib\n",
78 | "import japanize_matplotlib\n",
79 | "# pandas\n",
80 | "import pandas as pd\n",
81 | "# Function for displaying data frames\n",
82 | "from IPython.display import display\n",
83 | "# seaborn\n",
84 | "import seaborn as sns\n",
85 | "# Display options\n",
86 | "# NumPy print format\n",
87 | "np.set_printoptions(precision=3, floatmode='fixed')\n",
88 | "# Default font size for plots\n",
89 | "plt.rcParams[\"font.size\"] = 14\n",
90 | "# Default figure size\n",
91 | "plt.rcParams['figure.figsize'] = (6, 6)\n",
92 | "# Show grid lines\n",
93 | "plt.rcParams['axes.grid'] = True\n",
94 | "# Float precision in data frames\n",
95 | "pd.options.display.float_format = '{:.3f}'.format\n",
96 | "# Show all columns of data frames\n",
97 | "pd.set_option(\"display.max_columns\", None)"
98 | ]
99 | },
100 | {
101 | "cell_type": "code",
102 | "execution_count": null,
103 | "metadata": {
104 | "id": "my5YU9YCq-ou"
105 | },
106 | "outputs": [],
107 | "source": [
108 | "import pymc as pm\n",
109 | "import arviz as az\n",
110 | "\n",
111 | "print(f\"Running on PyMC v{pm.__version__}\")\n",
112 | "print(f\"Running on ArviZ v{az.__version__}\")"
113 | ]
114 | },
115 | {
116 | "cell_type": "markdown",
117 | "metadata": {
118 | "id": "GKrDeczv3ULO"
119 | },
120 | "source": [
121 | "### Chapter 5"
122 | ]
123 | },
124 | {
125 | "cell_type": "code",
126 | "execution_count": null,
127 | "metadata": {
128 | "id": "SXSs1e4u3cWX"
129 | },
130 | "outputs": [],
131 | "source": [
132 | "df = sns.load_dataset('iris')\n",
133 | "df1 = df.iloc[[0, 1, 50, 51, 100, 101]]\n",
134 | "display(df1)"
135 | ]
136 | },
137 | {
138 | "cell_type": "markdown",
139 | "metadata": {
140 | "id": "rBIfYJ_BrGUD"
141 | },
142 | "source": [
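143 | "### 5.1.1 Problem definition\n",
144 | "Assume that, in the Iris dataset, the distribution of a given measurement for a given species can be regarded as normal.  \n",
145 | "Using Bayesian inference, find the parameters (mean and standard deviation) of the normal distribution that best fits sepal_length (the length of the sepal) for setosa, one of the three species.\n"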
146 | ]
147 | },
148 | {
149 | "cell_type": "markdown",
150 | "metadata": {
151 | "id": "lnTxrJeXtJlG"
152 | },
153 | "source": [
154 | "### 5.1.2 Preparing the data"
155 | ]
156 | },
157 | {
158 | "cell_type": "markdown",
159 | "metadata": {
160 | "id": "Jzjj2s1HLA6T"
161 | },
162 | "source": [
163 | "#### Loading and checking the data"
164 | ]
165 | },
166 | {
167 | "cell_type": "code",
168 | "execution_count": null,
169 | "metadata": {
170 | "id": "4DkJdzLIrCFM"
171 | },
172 | "outputs": [],
173 | "source": [
174 | "# Load the Iris dataset\n",
175 | "df = sns.load_dataset('iris')\n",
176 | "\n",
177 | "# Check the first 5 rows\n",
178 | "display(df.head())\n",
179 | "\n",
180 | "# Check the distribution of species\n",
181 | "df['species'].value_counts()"
182 | ]
183 | },
184 | {
185 | "cell_type": "markdown",
186 | "metadata": {
187 | "id": "DTKV9xhELLcO"
188 | },
189 | "source": [
190 | "#### Extracting the data and drawing a histogram"
191 | ]
192 | },
193 | {
194 | "cell_type": "code",
195 | "execution_count": null,
196 | "metadata": {
197 | "id": "CPKscIcDu1vf"
198 | },
199 | "outputs": [],
200 | "source": [
201 | "# Extract only the setosa rows\n",
202 | "df1 = df.query('species == \"setosa\"')\n",
203 | "\n",
204 | "bins = np.arange(4.0, 6.2, 0.2)\n",
205 | "# Draw the histogram\n",
206 | "sns.histplot(df1, x='sepal_length', bins=bins, kde=True)\n",
207 | "plt.xticks(bins);"
208 | ]
209 | },
210 | {
211 | "cell_type": "markdown",
212 | "metadata": {
213 | "id": "yCkS_qvfLVGf"
214 | },
215 | "source": [
216 | "#### Extracting variable X and checking its values"
217 | ]
218 | },
219 | {
220 | "cell_type": "code",
221 | "execution_count": null,
222 | "metadata": {
223 | "id": "VvS8dmXFvHO1"
224 | },
225 | "outputs": [],
226 | "source": [
227 | "# Extract the sepal_length column\n",
228 | "s1 = df1['sepal_length']\n",
229 | "\n",
230 | "# Convert to a 1-D NumPy array\n",
231 | "X = s1.values\n",
232 | "\n",
233 | "# Check the summary statistics\n",
234 | "print(s1.describe())\n",
235 | "\n",
236 | "# Check the values\n",
237 | "print(X)"
238 | ]
239 | },
240 | {
241 | "cell_type": "markdown",
242 | "metadata": {
243 | "id": "PgH4jGWpvoW9"
244 | },
245 | "source": [
246 | "### 5.1.3 Defining the probability model"
247 | ]
248 | },
249 | {
250 | "cell_type": "markdown",
251 | "metadata": {
252 | "id": "3FkUoMNwVSzF"
253 | },
254 | "source": [
255 | "#### Probabilistic model definition"
256 | ]
257 | },
258 | {
259 | "cell_type": "code",
260 | "execution_count": null,
261 | "metadata": {
262 | "id": "Bdk8zw3fuphC"
263 | },
264 | "outputs": [],
265 | "source": [
266 | "model1 = pm.Model()\n",
267 | "\n",
268 | "with model1:\n",
269 | " mu = pm.Normal('mu', mu=0.0, sigma=10.0)\n",
270 | " sigma = pm.HalfNormal('sigma', sigma=10.0)\n",
271 | " X_obs = pm.Normal('X_obs', mu=mu, sigma=sigma, observed=X)"
272 | ]
273 | },
274 | {
275 | "cell_type": "markdown",
276 | "metadata": {
277 | "id": "43pVQ0NsVX7V"
278 | },
279 | "source": [
280 | "#### Visualizing the model structure"
281 | ]
282 | },
283 | {
284 | "cell_type": "code",
285 | "execution_count": null,
286 | "metadata": {
287 | "id": "WoopmWwGvvWe"
288 | },
289 | "outputs": [],
290 | "source": [
291 | "g = pm.model_to_graphviz(model1)\n",
292 | "display(g)"
293 | ]
294 | },
295 | {
296 | "cell_type": "markdown",
297 | "metadata": {
298 | "id": "_mP-yzOLv_9n"
299 | },
300 | "source": [
301 | "### 5.1.4 Sampling"
302 | ]
303 | },
304 | {
305 | "cell_type": "code",
306 | "execution_count": null,
307 | "metadata": {
308 | "id": "zQSSO-S9v4Qr"
309 | },
310 | "outputs": [],
311 | "source": [
312 | "with model1:\n",
313 | " idata1 = pm.sample(random_seed=42)"
314 | ]
315 | },
316 | {
317 | "cell_type": "markdown",
318 | "metadata": {
319 | "id": "A5aRjCmEyVnE"
320 | },
321 | "source": [
322 | "### 5.1.5 Analyzing the results"
323 | ]
324 | },
325 | {
326 | "cell_type": "markdown",
327 | "metadata": {
328 | "id": "FUKh3mO-fMg8"
329 | },
330 | "source": [
331 | "#### Analysis with the plot_trace function"
332 | ]
333 | },
334 | {
335 | "cell_type": "code",
336 | "execution_count": null,
337 | "metadata": {
338 | "id": "dHkdwKit1A3W"
339 | },
340 | "outputs": [],
341 | "source": [
342 | "az.plot_trace(idata1, compact=False)\n",
343 | "plt.tight_layout();"
344 | ]
345 | },
346 | {
347 | "cell_type": "markdown",
348 | "metadata": {
349 | "id": "20YE1pt5iI7x"
350 | },
351 | "source": [
352 | "#### Inspecting idata directly"
353 | ]
354 | },
355 | {
356 | "cell_type": "code",
357 | "execution_count": null,
358 | "metadata": {
359 | "id": "qj4Ioag3iOUT"
360 | },
361 | "outputs": [],
362 | "source": [
363 | "idata1"
364 | ]
365 | },
366 | {
367 | "cell_type": "markdown",
368 | "metadata": {
369 | "id": "fWpyG8h5fWXZ"
370 | },
371 | "source": [
372 | "#### Analysis with the plot_posterior function"
373 | ]
374 | },
375 | {
376 | "cell_type": "code",
377 | "execution_count": null,
378 | "metadata": {
379 | "id": "gZRIn1vF1I7t"
380 | },
381 | "outputs": [],
382 | "source": [
383 | "az.plot_posterior(idata1);"
384 | ]
385 | },
386 | {
387 | "cell_type": "markdown",
388 | "metadata": {
389 | "id": "xp7dgWNZff50"
390 | },
391 | "source": [
392 | "#### Statistical analysis with the summary function"
393 | ]
394 | },
395 | {
396 | "cell_type": "code",
397 | "execution_count": null,
398 | "metadata": {
399 | "id": "CtqOQID11XYs"
400 | },
401 | "outputs": [],
402 | "source": [
403 | "summary1 = az.summary(idata1)\n",
404 | "display(summary1)"
405 | ]
406 | },
407 | {
408 | "cell_type": "markdown",
409 | "metadata": {
410 | "id": "-lAI6N26fn5T"
411 | },
412 | "source": [
413 | "#### Getting the mean of each random variable"
414 | ]
415 | },
416 | {
417 | "cell_type": "code",
418 | "execution_count": null,
419 | "metadata": {
420 | "id": "nrL3o7ky0nt4"
421 | },
422 | "outputs": [],
423 | "source": [
424 | "mu_mean1 = summary1.loc['mu','mean']\n",
425 | "sigma_mean1 = summary1.loc['sigma','mean']\n",
426 | "\n",
427 | "# Check the result\n",
428 | "print(f'mu={mu_mean1}, sigma={sigma_mean1}')"
429 | ]
430 | },
431 | {
432 | "cell_type": "markdown",
433 | "metadata": {
434 | "id": "PTcyd5k32g6G"
435 | },
436 | "source": [
437 | "### 5.1.6 Overlaying the normal density function on the histogram"
438 | ]
439 | },
440 | {
441 | "cell_type": "markdown",
442 | "metadata": {
443 | "id": "mgmraDXF8z_b"
444 | },
445 | "source": [
446 | "#### Normal density function"
447 | ]
448 | },
449 | {
450 | "cell_type": "code",
451 | "execution_count": null,
452 | "metadata": {
453 | "id": "Xs55zld60zZQ"
454 | },
455 | "outputs": [],
456 | "source": [
457 | "def norm(x, mu, sigma):\n",
458 | " y = (x-mu)/sigma\n",
459 | " a = np.exp(-(y**2)/2)\n",
460 | " b = np.sqrt(2*np.pi)*sigma\n",
461 | " return a/b"
462 | ]
463 | },
464 | {
465 | "cell_type": "markdown",
466 | "metadata": {
467 | "id": "0I4B_8va9PtY"
468 | },
469 | "source": [
470 | "#### Computing function values from the Bayesian inference result"
471 | ]
472 | },
473 | {
474 | "cell_type": "code",
475 | "execution_count": null,
476 | "metadata": {
477 | "id": "GogVKUbO2xu8"
478 | },
479 | "outputs": [],
480 | "source": [
481 | "x_min = X.min()\n",
482 | "x_max = X.max()\n",
483 | "x_list = np.arange(x_min, x_max, 0.01)\n",
484 | "y_list = norm(x_list, mu_mean1, sigma_mean1)"
485 | ]
486 | },
487 | {
488 | "cell_type": "markdown",
489 | "metadata": {
490 | "id": "0E2zs-LV9nx_"
491 | },
492 | "source": [
493 | "#### Overlaying the inferred normal density function and the KDE curve"
494 | ]
495 | },
496 | {
497 | "cell_type": "code",
498 | "execution_count": null,
499 | "metadata": {
500 | "id": "zN5vCU-hvsIx"
501 | },
502 | "outputs": [],
503 | "source": [
504 | "delta = 0.2\n",
505 | "bins = np.arange(4.0, 6.0, delta)\n",
506 | "fig, ax = plt.subplots()\n",
507 | "sns.histplot(df1, ax=ax, x='sepal_length',\n",
508 | " bins=bins, kde=True, stat='probability')\n",
509 | "ax.get_lines()[0].set_label('KDE curve')\n",
510 | "ax.set_xticks(bins)\n",
511 | "ax.plot(x_list, y_list*delta, c='b', label='Bayesian inference result')\n",
512 | "ax.set_title('Bayesian inference result vs. KDE curve')\n",
513 | "plt.legend();"
514 | ]
515 | },
516 | {
517 | "cell_type": "markdown",
518 | "metadata": {
519 | "id": "yCzyPliQ3qWQ"
520 | },
521 | "source": [
522 | "### 5.1.7 Bayesian inference with a small sample"
523 | ]
524 | },
525 | {
526 | "cell_type": "markdown",
527 | "metadata": {
528 | "id": "xWmhdRGHlPmR"
529 | },
530 | "source": [
531 | "#### Narrowing down the data"
532 | ]
533 | },
534 | {
535 | "cell_type": "code",
536 | "execution_count": null,
537 | "metadata": {
538 | "id": "sDZlagN54f_5"
539 | },
540 | "outputs": [],
541 | "source": [
542 | "# Keep only the first five values\n",
543 | "X_less = X[:5]\n",
544 | "\n",
545 | "# Check the result\n",
546 | "print(X_less)\n",
547 | "\n",
548 | "# Check the summary statistics\n",
549 | "print(pd.Series(X_less).describe())"
550 | ]
551 | },
552 | {
553 | "cell_type": "markdown",
554 | "metadata": {
555 | "id": "e61Psv1qlWVQ"
556 | },
557 | "source": [
558 | "#### Model definition and sampling"
559 | ]
560 | },
561 | {
562 | "cell_type": "code",
563 | "execution_count": null,
564 | "metadata": {
565 | "id": "pNqa0GoI3a7S"
566 | },
567 | "outputs": [],
568 | "source": [
569 | "model2 = pm.Model()\n",
570 | "\n",
571 | "with model2:\n",
572 | " mu = pm.Normal('mu', mu=0.0, sigma=10.0)\n",
573 | " sigma = pm.HalfNormal('sigma', sigma=10.0)\n",
574 | " X_obs = pm.Normal('X_obs', mu=mu, sigma=sigma, observed=X_less)\n",
575 | "\n",
576 | " # Sampling\n",
577 | " idata2 = pm.sample(random_seed=42)"
578 | ]
579 | },
580 | {
581 | "cell_type": "markdown",
582 | "metadata": {
583 | "id": "8rTSxuAKlhIH"
584 | },
585 | "source": [
586 | "#### Visualizing the sampling results"
587 | ]
588 | },
589 | {
590 | "cell_type": "code",
591 | "execution_count": null,
592 | "metadata": {
593 | "id": "-ZyvC1oH43IE"
594 | },
595 | "outputs": [],
596 | "source": [
597 | "az.plot_posterior(idata2);"
598 | ]
599 | },
600 | {
601 | "cell_type": "markdown",
602 | "metadata": {
603 | "id": "--zHIudRlpQs"
604 | },
605 | "source": [
606 | "#### Statistical analysis of the sampling results"
607 | ]
608 | },
609 | {
610 | "cell_type": "code",
611 | "execution_count": null,
612 | "metadata": {
613 | "id": "9AqNBNe232p6"
614 | },
615 | "outputs": [],
616 | "source": [
617 | "summary2 = az.summary(idata2)\n",
618 | "display(summary2)"
619 | ]
620 | },
621 | {
622 | "cell_type": "markdown",
623 | "metadata": {
624 | "id": "pGtpoDNOMxVt"
625 | },
626 | "source": [
627 | "### Column: Defining the probabilistic model with tau"
628 | ]
629 | },
630 | {
631 | "cell_type": "markdown",
632 | "metadata": {
633 | "id": "fV0AOLZgQayD"
634 | },
635 | "source": [
636 | "#### Model definition and sampling"
637 | ]
638 | },
639 | {
640 | "cell_type": "code",
641 | "execution_count": null,
642 | "metadata": {
643 | "id": "1mNaKloG5O2n"
644 | },
645 | "outputs": [],
646 | "source": [
647 | "model3 = pm.Model()\n",
648 | "\n",
649 | "with model3:\n",
650 | " mu = pm.Normal('mu', mu=0.0, sigma=10.0)\n",
651 | " tau = pm.HalfNormal('tau', sigma=10.0)\n",
652 | " X_obs = pm.Normal('X_obs', mu=mu, tau=tau, observed=X)\n",
653 | " sigma = pm.Deterministic('sigma', 1/pm.math.sqrt(tau))\n",
654 | "\n",
655 | " # Sampling\n",
656 | " idata3 = pm.sample(random_seed=42)"
657 | ]
658 | },
659 | {
660 | "cell_type": "markdown",
661 | "metadata": {
662 | "id": "OsbJLlF9S-yn"
663 | },
664 | "source": [
665 | "#### Visualizing the sampling results"
666 | ]
667 | },
668 | {
669 | "cell_type": "code",
670 | "execution_count": null,
671 | "metadata": {
672 | "id": "GRmiLRw_RTHU"
673 | },
674 | "outputs": [],
675 | "source": [
676 | "az.plot_posterior(idata3);"
677 | ]
678 | },
679 | {
680 | "cell_type": "markdown",
681 | "metadata": {
682 | "id": "f45Q3TnfNefC"
683 | },
684 | "source": [
685 | "#### Checking versions"
686 | ]
687 | },
688 | {
689 | "cell_type": "code",
690 | "execution_count": null,
691 | "metadata": {
692 | "id": "vD-SCcaGRqrJ"
693 | },
694 | "outputs": [],
695 | "source": [
696 | "!pip install watermark | tail -n 1\n",
697 | "%load_ext watermark\n",
698 | "%watermark --iversions"
699 | ]
700 | }
701 | ],
702 | "metadata": {
703 | "colab": {
704 | "provenance": [],
705 | "toc_visible": true
706 | },
707 | "kernelspec": {
708 | "display_name": "Python 3",
709 | "name": "python3"
710 | },
711 | "language_info": {
712 | "name": "python"
713 | }
714 | },
715 | "nbformat": 4,
716 | "nbformat_minor": 0
717 | }
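
Note on the tau column in notebook 5.1: the column defines the same normal model through the precision `tau` instead of `sigma`, and recovers `sigma` as `1/sqrt(tau)`. The following is a minimal sketch of that relationship using only NumPy. The helper names `normal_pdf_sigma` and `normal_pdf_tau` and the numeric values are illustrative assumptions, not part of PyMC and not results from the notebook.

```python
import numpy as np


def normal_pdf_sigma(x, mu, sigma):
    """Normal density parameterized by the standard deviation sigma."""
    z = (x - mu) / sigma
    return np.exp(-0.5 * z**2) / (np.sqrt(2 * np.pi) * sigma)


def normal_pdf_tau(x, mu, tau):
    """Normal density parameterized by the precision tau = 1 / sigma**2."""
    return np.sqrt(tau / (2 * np.pi)) * np.exp(-0.5 * tau * (x - mu) ** 2)


# Illustrative values only (assumed, not taken from the sampling results)
mu, sigma = 5.0, 0.35

# Convert sigma to precision and back again
tau = 1.0 / sigma**2
sigma_from_tau = 1.0 / np.sqrt(tau)

# Both parameterizations should describe the same density
x = np.linspace(4.0, 6.0, 201)
pdf_a = normal_pdf_sigma(x, mu, sigma)
pdf_b = normal_pdf_tau(x, mu, tau)
max_abs_diff = float(np.max(np.abs(pdf_a - pdf_b)))

print(f"sigma recovered from tau: {sigma_from_tau:.3f}")
print(f"largest difference between the two densities: {max_abs_diff:.2e}")
```

Because the two forms are equivalent, the choice mainly affects which quantity receives the prior: a HalfNormal prior on `tau` is not the same assumption as a HalfNormal prior on `sigma`, so the posteriors of the two models may differ slightly.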
--------------------------------------------------------------------------------
/notebooks/5_2_線形回帰のベイズ推論.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "source": [
6 | "Run the following command so that the ``plot_trace`` function does not emit warnings."
7 | ],
8 | "metadata": {
9 | "id": "sPUJpOrtKLYo"
10 | }
11 | },
12 | {
13 | "cell_type": "code",
14 | "source": [
15 | "!pip install --upgrade numba | tail -n 1"
16 | ],
17 | "metadata": {
18 | "id": "2sy-eC_jKMhy"
19 | },
20 | "execution_count": null,
21 | "outputs": []
22 | },
23 | {
24 | "cell_type": "markdown",
25 | "metadata": {
26 | "id": "uIoj_UW_qasp"
27 | },
28 | "source": [
29 | "## 5.2 Bayesian inference for linear regression"
30 | ]
31 | },
32 | {
33 | "cell_type": "markdown",
34 | "metadata": {
35 | "id": "CW5ODy9I5R71"
36 | },
37 | "source": [
38 | ""
39 | ]
40 | },
41 | {
42 | "cell_type": "markdown",
43 | "metadata": {
44 | "id": "vdi2sqRyqr4p"
45 | },
46 | "source": [
47 | "### Common setup"
48 | ]
49 | },
50 | {
51 | "cell_type": "code",
52 | "execution_count": null,
53 | "metadata": {
54 | "id": "0dTjJOmIqmiS"
55 | },
56 | "outputs": [],
57 | "source": [
58 | "%matplotlib inline\n",
59 | "# Install the library that enables Japanese text in matplotlib\n",
60 | "!pip install japanize-matplotlib | tail -n 1"
61 | ]
62 | },
63 | {
64 | "cell_type": "code",
65 | "execution_count": null,
66 | "metadata": {
67 | "id": "M46-MnBXq5-Z"
68 | },
69 | "outputs": [],
70 | "source": [
71 | "# Library imports\n",
72 | "\n",
73 | "# NumPy\n",
74 | "import numpy as np\n",
75 | "# pyplot from Matplotlib\n",
76 | "import matplotlib.pyplot as plt\n",
77 | "# Japanese font support for matplotlib\n",
78 | "import japanize_matplotlib\n",
79 | "# pandas\n",
80 | "import pandas as pd\n",
81 | "# Function for displaying data frames\n",
82 | "from IPython.display import display\n",
83 | "# seaborn\n",
84 | "import seaborn as sns\n",
85 | "# Display options\n",
86 | "# NumPy print format\n",
87 | "np.set_printoptions(precision=3, floatmode='fixed')\n",
88 | "# Default font size for plots\n",
89 | "plt.rcParams[\"font.size\"] = 14\n",
90 | "# Figure size\n",
91 | "plt.rcParams['figure.figsize'] = (6, 6)\n",
92 | "# Show grid lines\n",
93 | "plt.rcParams['axes.grid'] = True\n",
94 | "# Display precision for data frames\n",
95 | "pd.options.display.float_format = '{:.3f}'.format\n",
96 | "# Show all columns of a data frame\n",
97 | "pd.set_option(\"display.max_columns\",None)"
98 | ]
99 | },
100 | {
101 | "cell_type": "code",
102 | "execution_count": null,
103 | "metadata": {
104 | "id": "my5YU9YCq-ou"
105 | },
106 | "outputs": [],
107 | "source": [
108 | "import pymc as pm\n",
109 | "import arviz as az\n",
110 | "\n",
111 | "print(f\"Running on PyMC v{pm.__version__}\")\n",
112 | "print(f\"Running on ArviZ v{az.__version__}\")"
113 | ]
114 | },
115 | {
116 | "cell_type": "markdown",
117 | "metadata": {
118 | "id": "rBIfYJ_BrGUD"
119 | },
120 | "source": [
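121 | "### 5.2.1 Problem setting\n",
122 | "Within a single species of the iris dataset, two of the measurements are positively correlated and can be regarded as following a linear regression (simple regression).  \n",
123 | "Taking this as given, this exercise uses Bayesian inference to find the best-fitting regression equation.  \n"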
124 | ]
125 | },
126 | {
127 | "cell_type": "markdown",
128 | "metadata": {
129 | "id": "lnTxrJeXtJlG"
130 | },
131 | "source": [
132 | "### 5.2.2 Data preparation"
133 | ]
134 | },
135 | {
136 | "cell_type": "markdown",
137 | "metadata": {
138 | "id": "qtTYOjAQlH-G"
139 | },
140 | "source": [
141 | "#### Loading and inspecting the data"
142 | ]
143 | },
144 | {
145 | "cell_type": "code",
146 | "execution_count": null,
147 | "metadata": {
148 | "id": "4DkJdzLIrCFM"
149 | },
150 | "outputs": [],
151 | "source": [
152 | "# Load the iris dataset\n",
153 | "df = sns.load_dataset('iris')\n",
154 | "\n",
155 | "# Show the first five rows\n",
156 | "display(df.head())\n",
157 | "\n",
158 | "# Check the distribution of species\n",
159 | "df['species'].value_counts()"
160 | ]
161 | },
162 | {
163 | "cell_type": "markdown",
164 | "metadata": {
165 | "id": "OZIm6GndlUDA"
166 | },
167 | "source": [
168 | "#### Extracting the data to analyze"
169 | ]
170 | },
171 | {
172 | "cell_type": "code",
173 | "execution_count": null,
174 | "metadata": {
175 | "id": "ugzTmiojNV6m"
176 | },
177 | "outputs": [],
178 | "source": [
179 | "# Keep only the versicolor rows\n",
180 | "df1 = df.query('species == \"versicolor\"')\n",
181 | "\n",
182 | "# Extract the sepal_length and sepal_width columns\n",
183 | "X = df1['sepal_length']\n",
184 | "Y = df1['sepal_width']"
185 | ]
186 | },
187 | {
188 | "cell_type": "markdown",
189 | "metadata": {
190 | "id": "TMDozZ0XojEK"
191 | },
192 | "source": [
193 | "#### Scatter plot of the two variables under analysis"
194 | ]
195 | },
196 | {
197 | "cell_type": "code",
198 | "execution_count": null,
199 | "metadata": {
200 | "id": "RVvUfPLGVTH-"
201 | },
202 | "outputs": [],
203 | "source": [
204 | "plt.title('Relationship between the two variables')\n",
205 | "plt.scatter(X, Y, label='Used for Bayesian inference', c='b', marker='o')\n",
206 | "plt.legend()\n",
207 | "plt.xlabel('sepal_length')\n",
208 | "plt.ylabel('sepal_width');"
209 | ]
210 | },
211 | {
212 | "cell_type": "markdown",
213 | "metadata": {
214 | "id": "PgH4jGWpvoW9"
215 | },
216 | "source": [
217 | "### 5.2.3 Model definition 1"
218 | ]
219 | },
220 | {
221 | "cell_type": "markdown",
222 | "metadata": {
223 | "id": "_OL7srviYyiX"
224 | },
225 | "source": [
226 | "$ y_n = \\alpha x_n + \\beta + \\epsilon_n$"
227 | ]
228 | },
229 | {
230 | "cell_type": "markdown",
231 | "metadata": {
232 | "id": "KGL7gLc4Iy8M"
233 | },
234 | "source": [
235 | "#### Model definition 1\n",
236 | "A simple way to define the model"
237 | ]
238 | },
239 | {
240 | "cell_type": "code",
241 | "execution_count": null,
242 | "metadata": {
243 | "id": "Bdk8zw3fuphC"
244 | },
245 | "outputs": [],
246 | "source": [
247 | "model1 = pm.Model()\n",
248 | "\n",
249 | "with model1:\n",
250 | " # Define the random variables alpha and beta (slope and intercept of the line)\n",
251 | " alpha = pm.Normal('alpha', mu=0.0, sigma=10.0)\n",
252 | " beta = pm.Normal('beta', mu=0.0, sigma=10.0)\n",
253 | "\n",
254 | " # Compute the mean mu\n",
255 | " mu = alpha * X + beta\n",
256 | "\n",
257 | " # Define epsilon, the random variable for the noise scale\n",
258 | " epsilon = pm.HalfNormal('epsilon', sigma=1.0)\n",
259 | "\n",
260 | " # Define the random variable that carries the observations as Y_obs\n",
261 | " Y_obs = pm.Normal('Y_obs', mu=mu, sigma=epsilon, observed=Y)"
262 | ]
263 | },
264 | {
265 | "cell_type": "markdown",
266 | "metadata": {
267 | "id": "KP68nsy3I9Eo"
268 | },
269 | "source": [
270 | "#### Visualizing the model structure"
271 | ]
272 | },
273 | {
274 | "cell_type": "code",
275 | "execution_count": null,
276 | "metadata": {
277 | "id": "WoopmWwGvvWe"
278 | },
279 | "outputs": [],
280 | "source": [
281 | "g = pm.model_to_graphviz(model1)\n",
282 | "display(g)"
283 | ]
284 | },
285 | {
286 | "cell_type": "markdown",
287 | "metadata": {
288 | "id": "eAKDls9dJ-Q_"
289 | },
290 | "source": [
291 | "### 5.2.4 Model definition 2\n",
292 | "A more detailed way to define the model"
293 | ]
294 | },
295 | {
296 | "cell_type": "markdown",
297 | "metadata": {
298 | "id": "QdqlMSQOJGBz"
299 | },
300 | "source": [
301 | "#### Model definition 2"
302 | ]
303 | },
304 | {
305 | "cell_type": "code",
306 | "execution_count": null,
307 | "metadata": {
308 | "id": "8ibX_Xo2J2pB"
309 | },
310 | "outputs": [],
311 | "source": [
312 | "model2 = pm.Model()\n",
313 | "\n",
314 | "with model2:\n",
315 | " # Register the observed X and Y values as ConstantData\n",
316 | " X_data = pm.ConstantData('X_data', X)\n",
317 | " Y_data = pm.ConstantData('Y_data', Y)\n",
318 | "\n",
319 | " # Define the random variables alpha and beta (slope and intercept of the line)\n",
320 | " alpha = pm.Normal('alpha', mu=0.0, sigma=10.0)\n",
321 | " beta = pm.Normal('beta', mu=0.0, sigma=10.0)\n",
322 | "\n",
323 | " # Compute the mean mu\n",
324 | " mu = pm.Deterministic('mu', alpha * X_data + beta)\n",
325 | "\n",
326 | " # Define epsilon, the random variable for the noise scale\n",
327 | " epsilon = pm.HalfNormal('epsilon', sigma=1.0)\n",
328 | "\n",
329 | " # Define the random variable that carries the observations as obs\n",
330 | " obs = pm.Normal('obs', mu=mu, sigma=epsilon, observed=Y_data)"
331 | ]
332 | },
333 | {
334 | "cell_type": "markdown",
335 | "metadata": {
336 | "id": "yag_0AbFxpJ3"
337 | },
338 | "source": [
339 | "#### Visualizing the model structure"
340 | ]
341 | },
342 | {
343 | "cell_type": "code",
344 | "execution_count": null,
345 | "metadata": {
346 | "id": "QtK06_gdKESL"
347 | },
348 | "outputs": [],
349 | "source": [
350 | "g = pm.model_to_graphviz(model2)\n",
351 | "display(g)"
352 | ]
353 | },
354 | {
355 | "cell_type": "markdown",
356 | "metadata": {
357 | "id": "_mP-yzOLv_9n"
358 | },
359 | "source": [
360 | "### 5.2.5 Sampling and analyzing the results"
361 | ]
362 | },
363 | {
364 | "cell_type": "markdown",
365 | "metadata": {
366 | "id": "pIhC2Q0NxysT"
367 | },
368 | "source": [
369 | "#### Sampling"
370 | ]
371 | },
372 | {
373 | "cell_type": "code",
374 | "execution_count": null,
375 | "metadata": {
376 | "id": "zQSSO-S9v4Qr"
377 | },
378 | "outputs": [],
379 | "source": [
380 | "with model2:\n",
381 | " idata2 = pm.sample(random_seed=42)"
382 | ]
383 | },
384 | {
385 | "cell_type": "markdown",
386 | "metadata": {
387 | "id": "jUIFku6B0Pt0"
388 | },
389 | "source": [
390 | "#### Calling the plot_trace function"
391 | ]
392 | },
393 | {
394 | "cell_type": "code",
395 | "execution_count": null,
396 | "metadata": {
397 | "id": "IFYibU57KOdZ"
398 | },
399 | "outputs": [],
400 | "source": [
401 | "az.plot_trace(idata2, compact=False, var_names=['alpha', 'beta', 'epsilon'])\n",
402 | "plt.tight_layout();"
403 | ]
404 | },
405 | {
406 | "cell_type": "markdown",
407 | "metadata": {
408 | "id": "Nsc3eedw0ab-"
409 | },
410 | "source": [
411 | "#### Calling the plot_posterior function"
412 | ]
413 | },
414 | {
415 | "cell_type": "code",
416 | "execution_count": null,
417 | "metadata": {
418 | "id": "UbMt15TcrCTi"
419 | },
420 | "outputs": [],
421 | "source": [
422 | "az.plot_posterior(idata2, var_names=['alpha', 'beta', 'epsilon']);"
423 | ]
424 | },
425 | {
426 | "cell_type": "markdown",
427 | "metadata": {
428 | "id": "PTcyd5k32g6G"
429 | },
430 | "source": [
431 | "### 5.2.6 Overlaying the regression lines on the scatter plot"
432 | ]
433 | },
434 | {
435 | "cell_type": "markdown",
436 | "metadata": {
437 | "id": "d-UzacE93MZ3"
438 | },
439 | "source": [
440 | "#### Computing regression-line predictions for each individual sample"
441 | ]
442 | },
443 | {
444 | "cell_type": "code",
445 | "execution_count": null,
446 | "metadata": {
447 | "id": "kIMQA4gojgSk"
448 | },
449 | "outputs": [],
450 | "source": [
451 | "# Put the two x endpoints into a NumPy array\n",
452 | "x_values = np.array([X.min()-0.1, X.max()+0.1])\n",
453 | "print(x_values, x_values.shape)\n",
454 | "\n",
455 | "# Extract alpha and beta from the sampling results and reshape them into column vectors\n",
456 | "alphas2 = idata2['posterior']['alpha'].values.reshape(-1, 1)\n",
457 | "betas2 = idata2['posterior']['beta'].values.reshape(-1, 1)\n",
458 | "\n",
459 | "# Check the shapes\n",
460 | "print(alphas2.shape, betas2.shape)\n",
461 | "\n",
462 | "# For each of the 2000 samples, evaluate the linear function at the two points\n",
463 | "y_preds = x_values * alphas2 + betas2\n",
464 | "print(y_preds.shape)"
465 | ]
466 | },
467 | {
468 | "cell_type": "markdown",
469 | "metadata": {
470 | "id": "Px8DRNbF3PP3"
471 | },
472 | "source": [
473 | "#### Regression lines from Bayesian inference"
474 | ]
475 | },
476 | {
477 | "cell_type": "code",
478 | "execution_count": null,
479 | "metadata": {
480 | "id": "Eox4W8MtnkSv"
481 | },
482 | "outputs": [],
483 | "source": [
484 | "for y_pred in y_preds:\n",
485 | " plt.plot(x_values, y_pred, lw=1, alpha=0.01, c='g')\n",
486 | "plt.scatter(X, Y)\n",
487 | "plt.title('Regression lines from Bayesian inference')\n",
488 | "plt.xlabel('sepal_length')\n",
489 | "plt.ylabel('sepal_width');"
490 | ]
491 | },
492 | {
493 | "cell_type": "markdown",
494 | "metadata": {
495 | "id": "yCzyPliQ3qWQ"
496 | },
497 | "source": [
498 | "### 5.2.7 Bayesian inference with a small sample"
499 | ]
500 | },
501 | {
502 | "cell_type": "markdown",
503 | "metadata": {
504 | "id": "PCVhshFFCdgH"
505 | },
506 | "source": [
507 | "#### Generating three random indexes"
508 | ]
509 | },
510 | {
511 | "cell_type": "code",
512 | "execution_count": null,
513 | "metadata": {
514 | "id": "SRWppjjmPGat"
515 | },
516 | "outputs": [],
517 | "source": [
518 | "import random\n",
519 | "random.seed(42)\n",
520 | "indexes = range(len(X))\n",
521 | "sample_indexes = random.sample(indexes, 3)\n",
522 | "print('Index values', sample_indexes)\n",
523 | "\n",
524 | "# Reduce the data to three points\n",
525 | "X_less = X.iloc[sample_indexes]\n",
526 | "Y_less = Y.iloc[sample_indexes]\n",
527 | "print('x values', X_less.values)\n",
528 | "print('y values', Y_less.values)"
529 | ]
530 | },
531 | {
532 | "cell_type": "markdown",
533 | "metadata": {
534 | "id": "FHCDa9bKCmHI"
535 | },
536 | "source": [
537 | "#### Scatter plot of the three selected points"
538 | ]
539 | },
540 | {
541 | "cell_type": "code",
542 | "execution_count": null,
543 | "metadata": {
544 | "id": "bW2snNg6Q5sp"
545 | },
546 | "outputs": [],
547 | "source": [
548 | "plt.title('Relationship between sepal_length and sepal_width')\n",
549 | "plt.scatter(X_less, Y_less, label='Used for Bayesian inference', c='b', marker='o')\n",
550 | "plt.legend()\n",
551 | "plt.xlabel('sepal_length')\n",
552 | "plt.ylabel('sepal_width');"
553 | ]
554 | },
555 | {
556 | "cell_type": "markdown",
557 | "metadata": {
558 | "id": "zbkUV8zaCs-8"
559 | },
560 | "source": [
561 | "#### Model definition and sampling"
562 | ]
563 | },
564 | {
565 | "cell_type": "code",
566 | "execution_count": null,
567 | "metadata": {
568 | "id": "pNqa0GoI3a7S"
569 | },
570 | "outputs": [],
571 | "source": [
572 | "model3 = pm.Model()\n",
573 | "\n",
574 | "with model3:\n",
575 | " # Register the observed X and Y values as ConstantData\n",
576 | " X_data = pm.ConstantData('X_data', X_less)\n",
577 | " Y_data = pm.ConstantData('Y_data', Y_less)\n",
578 | "\n",
579 | " # Define the random variables alpha and beta (slope and intercept of the line)\n",
580 | " alpha = pm.Normal('alpha', mu=0.0, sigma=10.0)\n",
581 | " beta = pm.Normal('beta', mu=0.0, sigma=10.0)\n",
582 | "\n",
583 | " # Compute the mean mu\n",
584 | " mu = pm.Deterministic('mu', alpha * X_data + beta)\n",
585 | "\n",
586 | " # Define epsilon, the random variable for the noise scale\n",
587 | " epsilon = pm.HalfNormal('epsilon', sigma=1.0)\n",
588 | "\n",
589 | " # Define the random variable that carries the observations as obs\n",
590 | " obs = pm.Normal('obs', mu=mu, sigma=epsilon, observed=Y_data)\n",
591 | "\n",
592 | " # Sampling\n",
593 | " idata3 = pm.sample(random_seed=42, target_accept=0.995)"
594 | ]
595 | },
596 | {
597 | "cell_type": "markdown",
598 | "metadata": {
599 | "id": "c2XPr0QNC3h9"
600 | },
601 | "source": [
602 | "#### Checking the Bayesian inference results with the plot_trace function"
603 | ]
604 | },
605 | {
606 | "cell_type": "code",
607 | "execution_count": null,
608 | "metadata": {
609 | "id": "qmASPaG_goiN"
610 | },
611 | "outputs": [],
612 | "source": [
613 | "az.plot_trace(idata3, compact=False, var_names=['alpha', 'beta', 'epsilon'])\n",
614 | "plt.tight_layout();"
615 | ]
616 | },
617 | {
618 | "cell_type": "markdown",
619 | "metadata": {
620 | "id": "qLM8PGmpC-cY"
621 | },
622 | "source": [
623 | "#### Overlaying the regression lines on the scatter plot"
624 | ]
625 | },
626 | {
627 | "cell_type": "code",
628 | "execution_count": null,
629 | "metadata": {
630 | "id": "1mNaKloG5O2n"
631 | },
632 | "outputs": [],
633 | "source": [
634 | "# Put the two x endpoints into a NumPy array\n",
635 | "x_values = np.array([X_less.min()-0.1, X_less.max()+0.1])\n",
636 | "\n",
637 | "# Extract alpha and beta from the samples and reshape them into column vectors\n",
638 | "alphas3 = idata3['posterior']['alpha'].values.reshape(-1, 1)\n",
639 | "betas3 = idata3['posterior']['beta'].values.reshape(-1, 1)\n",
640 | "\n",
641 | "# For each of the 2000 samples, evaluate the linear function at the two points\n",
642 | "y_preds = x_values * alphas3 + betas3\n",
643 | "\n",
644 | "# Draw the 2000 lines together with the scatter plot\n",
645 | "for y_pred in y_preds:\n",
646 | " plt.plot(x_values, y_pred, lw=1, alpha=0.01, c='g')\n",
647 | "plt.scatter(X_less, Y_less)\n",
648 | "plt.ylim(1.75, 3.75)\n",
649 | "plt.title('Regression lines from Bayesian inference')\n",
650 | "plt.xlabel('sepal_length')\n",
651 | "plt.ylabel('sepal_width');"
652 | ]
653 | },
654 | {
655 | "cell_type": "markdown",
656 | "metadata": {
657 | "id": "bBOCfigvJl9A"
658 | },
659 | "source": [
660 | "### Column: Tuning with target_accept"
661 | ]
662 | },
663 | {
664 | "cell_type": "markdown",
665 | "metadata": {
666 | "id": "Gv2T_qMeDGaQ"
667 | },
668 | "source": [
669 | "#### Model definition and sampling (without the ``target_accept`` parameter)"
670 | ]
671 | },
672 | {
673 | "cell_type": "code",
674 | "execution_count": null,
675 | "metadata": {
676 | "id": "pf97GBZ_05sJ"
677 | },
678 | "outputs": [],
679 | "source": [
680 | "model4 = pm.Model()\n",
681 | "\n",
682 | "with model4:\n",
683 | " # Register the observed X and Y values as ConstantData\n",
684 | " X_data = pm.ConstantData('X_data', X_less)\n",
685 | " Y_data = pm.ConstantData('Y_data', Y_less)\n",
686 | "\n",
687 | " # Define the random variables alpha and beta (slope and intercept of the line)\n",
688 | " alpha = pm.Normal('alpha', mu=0.0, sigma=10.0)\n",
689 | " beta = pm.Normal('beta', mu=0.0, sigma=10.0)\n",
690 | "\n",
691 | " # Compute the mean mu\n",
692 | " mu = pm.Deterministic('mu', alpha * X_data + beta)\n",
693 | "\n",
694 | " # Define epsilon, the random variable for the noise scale\n",
695 | " epsilon = pm.HalfNormal('epsilon', sigma=1.0)\n",
696 | "\n",
697 | " # Define the random variable that carries the observations as obs\n",
698 | " obs = pm.Normal('obs', mu=mu, sigma=epsilon, observed=Y_data)\n",
699 | "\n",
700 | " # Sampling\n",
701 | " idata4 = pm.sample(random_seed=42)"
702 | ]
703 | },
704 | {
705 | "cell_type": "markdown",
706 | "metadata": {
707 | "id": "LO0YvmJmDLe_"
708 | },
709 | "source": [
710 | "#### Checking the Bayesian inference results with the plot_trace function"
711 | ]
712 | },
713 | {
714 | "cell_type": "code",
715 | "execution_count": null,
716 | "metadata": {
717 | "id": "UFnNkdxtJ0vq"
718 | },
719 | "outputs": [],
720 | "source": [
721 | "az.plot_trace(idata4, compact=False, var_names=['alpha', 'beta', 'epsilon'])\n",
722 | "plt.tight_layout();"
723 | ]
724 | },
725 | {
726 | "cell_type": "markdown",
727 | "metadata": {
728 | "id": "hdRMXdvd3cw9"
729 | },
730 | "source": [
731 | "#### Checking the Bayesian inference results with the summary function"
732 | ]
733 | },
734 | {
735 | "cell_type": "code",
736 | "execution_count": null,
737 | "metadata": {
738 | "id": "_sUOZ7-n1RMm"
739 | },
740 | "outputs": [],
741 | "source": [
742 | "summary4 = az.summary(idata4, var_names=['alpha', 'beta', 'epsilon'])\n",
743 | "display(summary4)"
744 | ]
745 | },
746 | {
747 | "cell_type": "markdown",
748 | "metadata": {
749 | "id": "F4m2U-HADRGo"
750 | },
751 | "source": [
752 | "#### Overlaying the regression lines on the scatter plot"
753 | ]
754 | },
755 | {
756 | "cell_type": "code",
757 | "execution_count": null,
758 | "metadata": {
759 | "id": "Y3axvdwBJ6Qc"
760 | },
761 | "outputs": [],
762 | "source": [
763 | "# Put the two x endpoints into a NumPy array\n",
764 | "x_values = np.array([X_less.min()-0.1, X_less.max()+0.1])\n",
765 | "\n",
766 | "# Extract alpha and beta from the samples and reshape them into column vectors\n",
767 | "alphas4 = idata4['posterior']['alpha'].values.reshape(-1, 1)\n",
768 | "betas4 = idata4['posterior']['beta'].values.reshape(-1, 1)\n",
769 | "\n",
770 | "# For each of the 2000 samples, evaluate the linear function at the two points\n",
771 | "y_preds = x_values * alphas4 + betas4\n",
772 | "\n",
773 | "# Draw the 2000 lines together with the scatter plot\n",
774 | "for y_pred in y_preds:\n",
775 | " plt.plot(x_values, y_pred, lw=1, alpha=0.01, c='g')\n",
776 | "plt.scatter(X_less, Y_less)\n",
777 | "plt.ylim(1.75, 3.75)\n",
778 | "plt.title('Regression lines from Bayesian inference')\n",
779 | "plt.xlabel('sepal_length')\n",
780 | "plt.ylabel('sepal_width');"
781 | ]
782 | },
783 | {
784 | "cell_type": "markdown",
785 | "metadata": {
786 | "id": "P-xWewzUPm65"
787 | },
788 | "source": [
789 | "#### Checking versions"
790 | ]
791 | },
792 | {
793 | "cell_type": "code",
794 | "execution_count": null,
795 | "metadata": {
796 | "id": "TMop6cFw1jlv"
797 | },
798 | "outputs": [],
799 | "source": [
800 | "!pip install watermark | tail -n 1\n",
801 | "%load_ext watermark\n",
802 | "%watermark --iversions"
803 | ]
804 | },
805 | {
806 | "cell_type": "code",
807 | "execution_count": null,
808 | "metadata": {
809 | "id": "t2UAcikzUdlq"
810 | },
811 | "outputs": [],
812 | "source": []
813 | }
814 | ],
815 | "metadata": {
816 | "colab": {
817 | "provenance": [],
818 | "toc_visible": true
819 | },
820 | "kernelspec": {
821 | "display_name": "Python 3",
822 | "name": "python3"
823 | },
824 | "language_info": {
825 | "name": "python"
826 | }
827 | },
828 | "nbformat": 4,
829 | "nbformat_minor": 0
830 | }
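
Note on section 5.2.6 of the notebook above: the regression-line cells evaluate one line per posterior sample by multiplying an array of shape `(2,)` with column vectors of shape `(N, 1)`. The following is a minimal sketch of that broadcasting step using only NumPy. Random numbers stand in for the posterior samples, so the values, the `(2, 1000)` chain layout, and the endpoint values are illustrative assumptions rather than sampler output.

```python
import numpy as np

rng = np.random.default_rng(42)

# Stand-ins for idata['posterior']['alpha'] and ['beta'].
# Assumed layout: (chains, draws) = (2, 1000), i.e. 2000 samples in total.
alpha_post = rng.normal(loc=0.3, scale=0.05, size=(2, 1000))
beta_post = rng.normal(loc=1.0, scale=0.3, size=(2, 1000))

# Flatten chains and draws into column vectors of shape (2000, 1)
alphas = alpha_post.reshape(-1, 1)
betas = beta_post.reshape(-1, 1)

# Two x endpoints are enough to draw a straight line (values are illustrative)
x_values = np.array([4.8, 7.1])

# Broadcasting: (2,) * (2000, 1) + (2000, 1) -> (2000, 2)
# Row i holds the two y values of the line defined by sample i.
y_preds = x_values * alphas + betas

# The same computation written as an explicit loop, for comparison
y_loop = np.array([
    [a * x + b for x in x_values]
    for a, b in zip(alphas[:, 0], betas[:, 0])
])

print(alphas.shape, betas.shape, y_preds.shape)
print(np.allclose(y_preds, y_loop))
```

Reshaping to `(-1, 1)` rather than `(-1,)` is what makes the broadcast work: a flat vector of length 2000 could not be combined with the length-2 `x_values` array.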
--------------------------------------------------------------------------------
/notebooks/5_3_階層ベイズモデル.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "source": [
6 | "Run the following command so that the ``plot_trace`` function does not emit warnings."
7 | ],
8 | "metadata": {
9 | "id": "u1Dxl3l7GFiO"
10 | }
11 | },
12 | {
13 | "cell_type": "code",
14 | "source": [
15 | "!pip install --upgrade numba | tail -n 1"
16 | ],
17 | "metadata": {
18 | "id": "R2CT9qNvGAMl"
19 | },
20 | "execution_count": null,
21 | "outputs": []
22 | },
23 | {
24 | "cell_type": "markdown",
25 | "metadata": {
26 | "id": "uIoj_UW_qasp"
27 | },
28 | "source": [
29 | "## 5.3 Hierarchical Bayesian model"
30 | ]
31 | },
32 | {
33 | "cell_type": "markdown",
34 | "metadata": {
35 | "id": "rZsf6jGB5Lwg"
36 | },
37 | "source": [
38 | ""
39 | ]
40 | },
41 | {
42 | "cell_type": "markdown",
43 | "metadata": {
44 | "id": "vdi2sqRyqr4p"
45 | },
46 | "source": [
47 | "### Common setup"
48 | ]
49 | },
50 | {
51 | "cell_type": "code",
52 | "execution_count": null,
53 | "metadata": {
54 | "id": "0dTjJOmIqmiS"
55 | },
56 | "outputs": [],
57 | "source": [
58 | "%matplotlib inline\n",
59 | "# Install the library that enables Japanese text in matplotlib\n",
60 | "!pip install japanize-matplotlib | tail -n 1"
61 | ]
62 | },
63 | {
64 | "cell_type": "code",
65 | "execution_count": null,
66 | "metadata": {
67 | "id": "M46-MnBXq5-Z"
68 | },
69 | "outputs": [],
70 | "source": [
71 | "# Library imports\n",
72 | "\n",
73 | "# NumPy\n",
74 | "import numpy as np\n",
75 | "# pyplot from Matplotlib\n",
76 | "import matplotlib.pyplot as plt\n",
77 | "# Japanese font support for matplotlib\n",
78 | "import japanize_matplotlib\n",
79 | "# pandas\n",
80 | "import pandas as pd\n",
81 | "# Function for displaying data frames\n",
82 | "from IPython.display import display\n",
83 | "# seaborn\n",
84 | "import seaborn as sns\n",
85 | "# Display options\n",
86 | "# NumPy print format\n",
87 | "np.set_printoptions(precision=3, floatmode='fixed')\n",
88 | "# Default font size for plots\n",
89 | "plt.rcParams[\"font.size\"] = 14\n",
90 | "# Figure size\n",
91 | "plt.rcParams['figure.figsize'] = (6, 6)\n",
92 | "# Show grid lines\n",
93 | "plt.rcParams['axes.grid'] = True\n",
94 | "# Display precision for data frames\n",
95 | "pd.options.display.float_format = '{:.3f}'.format\n",
96 | "# Show all columns of a data frame\n",
97 | "pd.set_option(\"display.max_columns\",None)"
98 | ]
99 | },
100 | {
101 | "cell_type": "code",
102 | "execution_count": null,
103 | "metadata": {
104 | "id": "my5YU9YCq-ou"
105 | },
106 | "outputs": [],
107 | "source": [
108 | "import pymc as pm\n",
109 | "import arviz as az\n",
110 | "\n",
111 | "print(f\"Running on PyMC v{pm.__version__}\")\n",
112 | "print(f\"Running on ArviZ v{az.__version__}\")"
113 | ]
114 | },
115 | {
116 | "cell_type": "markdown",
117 | "metadata": {
118 | "id": "rBIfYJ_BrGUD"
119 | },
120 | "source": [
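121 | "### 5.3.1 Problem setting\n",
122 | "Find the regression lines for the three iris species in the dataset simultaneously.  \n",
123 | "With only a small amount of data available, we assume that the three regression lines share a common tendency.  \n",
124 | "Under this assumption, the task reduces to a hierarchical Bayesian model.\n"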
125 | ]
126 | },
127 | {
128 | "cell_type": "markdown",
129 | "metadata": {
130 | "id": "lnTxrJeXtJlG"
131 | },
132 | "source": [
133 | "### 5.3.2 Data preparation"
134 | ]
135 | },
136 | {
137 | "cell_type": "markdown",
138 | "metadata": {
139 | "id": "lqlJlSuGfR-A"
140 | },
141 | "source": [
142 | "#### Loading the Iris dataset and inspecting its contents"
143 | ]
144 | },
145 | {
146 | "cell_type": "code",
147 | "execution_count": null,
148 | "metadata": {
149 | "id": "4DkJdzLIrCFM"
150 | },
151 | "outputs": [],
152 | "source": [
153 | "# Load the Iris dataset\n",
154 | "df = sns.load_dataset('iris')\n",
155 | "\n",
156 | "# Check the first five rows\n",
157 | "display(df.head())\n",
158 | "\n",
159 | "# Check the distribution of species\n",
160 | "df['species'].value_counts()"
161 | ]
162 | },
163 | {
164 | "cell_type": "code",
165 | "execution_count": null,
166 | "metadata": {
167 | "id": "WNzQpPLHqP6x"
168 | },
169 | "outputs": [],
170 | "source": [
171 | "df.head().to_excel('df.xlsx')"
172 | ]
173 | },
174 | {
175 | "cell_type": "markdown",
176 | "metadata": {
177 | "id": "wTLIiuipfX-W"
178 | },
179 | "source": [
180 | "#### Extracting the target data"
181 | ]
182 | },
183 | {
184 | "cell_type": "code",
185 | "execution_count": null,
186 | "metadata": {
187 | "id": "qu0fKUVOpBct"
188 | },
189 | "outputs": [],
190 | "source": [
191 | "# Extract only the setosa rows\n",
192 | "df0 = df.query('species == \"setosa\"')\n",
193 | "\n",
194 | "# Extract only the versicolor rows\n",
195 | "df1 = df.query('species == \"versicolor\"')\n",
196 | "\n",
197 | "# Extract only the virginica rows\n",
198 | "df2 = df.query('species == \"virginica\"')\n",
199 | "\n",
200 | "# Randomly choose three positional indexes (seeded for reproducibility)\n",
201 | "import random\n",
202 | "random.seed(42)\n",
203 | "indexes = range(len(df0))\n",
204 | "sample_indexes = random.sample(indexes, 3)\n",
205 | "\n",
206 | "# Reduce df0, df1, and df2 to three rows each (same positions in every species)\n",
207 | "df0_sel = df0.iloc[sample_indexes]\n",
208 | "df1_sel = df1.iloc[sample_indexes]\n",
209 | "df2_sel = df2.iloc[sample_indexes]\n",
210 | "\n",
211 | "# Concatenate everything into a single DataFrame\n",
212 | "df_sel = pd.concat([df0_sel, df1_sel, df2_sel]).reset_index(drop=True)"
213 | ]
214 | },
215 | {
216 | "cell_type": "markdown",
217 | "metadata": {
218 | "id": "xr2VZ6wtgB9T"
219 | },
220 | "source": [
221 | "#### Checking the processed result"
222 | ]
223 | },
224 | {
225 | "cell_type": "code",
226 | "execution_count": null,
227 | "metadata": {
228 | "id": "FBLpngQSqWLI"
229 | },
230 | "outputs": [],
231 | "source": [
232 | "display(df_sel)"
233 | ]
234 | },
235 | {
236 | "cell_type": "markdown",
237 | "metadata": {
238 | "id": "Of0MuJdwgJYU"
239 | },
240 | "source": [
241 | "#### Scatter plot of the extracted data"
242 | ]
243 | },
244 | {
245 | "cell_type": "code",
246 | "execution_count": null,
247 | "metadata": {
248 | "id": "1CPM_dirr_v-"
249 | },
250 | "outputs": [],
251 | "source": [
252 | "sns.scatterplot(\n",
253 | " x='sepal_length', y='sepal_width', hue='species', style='species',\n",
254 | " data=df_sel, s=100)\n",
255 | "plt.title('Scatter plot of the nine extracted observations');"
256 | ]
257 | },
258 | {
259 | "cell_type": "markdown",
260 | "metadata": {
261 | "id": "LJtp411JgVLz"
262 | },
263 | "source": [
264 | "#### Extracting the variables for Bayesian inference"
265 | ]
266 | },
267 | {
268 | "cell_type": "code",
269 | "execution_count": null,
270 | "metadata": {
271 | "id": "Rhjla-2GsHFz"
272 | },
273 | "outputs": [],
274 | "source": [
275 | "X = df_sel['sepal_length'].values\n",
276 | "Y = df_sel['sepal_width'].values\n",
277 | "species = df_sel['species']\n",
278 | "cl = pd.Categorical(species).codes\n",
279 | "\n",
280 | "# Check the results\n",
281 | "print(X)\n",
282 | "print(Y)\n",
283 | "print(species.values)\n",
284 | "print(cl)"
285 | ]
286 | },
287 | {
288 | "cell_type": "markdown",
289 | "metadata": {
290 | "id": "PgH4jGWpvoW9"
291 | },
292 | "source": [
293 | "### 5.3.3 Probabilistic model definition"
294 | ]
295 | },
296 | {
297 | "cell_type": "markdown",
298 | "metadata": {
299 | "id": "q2wEx9Eui877"
300 | },
301 | "source": [
302 | "#### Probabilistic model definition for the hierarchical Bayesian model"
303 | ]
304 | },
305 | {
306 | "cell_type": "code",
307 | "execution_count": null,
308 | "metadata": {
309 | "id": "Bdk8zw3fuphC"
310 | },
311 | "outputs": [],
312 | "source": [
313 | "model1 = pm.Model()\n",
314 | "\n",
315 | "with model1:\n",
316 | "    # Define the observed X and Y as ConstantData (same as in an ordinary Bayesian model)\n",
317 | "    X_data = pm.ConstantData('X_data', X)\n",
318 | "    Y_data = pm.ConstantData('Y_data', Y)\n",
319 | "\n",
320 | "    # Class (species) index variable (specific to the hierarchical model)\n",
321 | "    cl_data = pm.ConstantData('cl_data', cl)\n",
322 | "\n",
323 | "    # Random variable alpha, the slope (specific to the hierarchical model)\n",
324 | "    a_mu = pm.Normal('a_mu', mu=0.0, sigma=10.0)\n",
325 | "    a_sigma = pm.HalfNormal('a_sigma', sigma=10.0)\n",
326 | "    alpha = pm.Normal('alpha', mu=a_mu, sigma=a_sigma, shape=(3,))\n",
327 | "\n",
328 | "    # Random variable beta, the intercept (specific to the hierarchical model)\n",
329 | "    b_mu = pm.Normal('b_mu', mu=0.0, sigma=10.0)\n",
330 | "    b_sigma = pm.HalfNormal('b_sigma', sigma=10.0)\n",
331 | "    beta = pm.Normal('beta', mu=b_mu, sigma=b_sigma, shape=(3,))\n",
332 | "\n",
333 | "    # Error term epsilon (same as in an ordinary Bayesian model)\n",
334 | "    epsilon = pm.HalfNormal('epsilon', sigma=1.0)\n",
335 | "\n",
336 | "    # Compute mu, switching the index with cl_data (specific to the hierarchical model)\n",
337 | "    mu = pm.Deterministic('mu', X_data * alpha[cl_data] + beta[cl_data])\n",
338 | "\n",
339 | "    # Define the observation model from mu and epsilon (same as in an ordinary Bayesian model)\n",
340 | "    obs = pm.Normal('obs', mu=mu, sigma=epsilon, observed=Y_data)"
341 | ]
342 | },
343 | {
344 | "cell_type": "markdown",
345 | "metadata": {
346 | "id": "UJq-5kQ1yeqf"
347 | },
348 | "source": [
349 | "#### How mu is computed, illustrated by replacing the PyMC variables with NumPy arrays"
350 | ]
351 | },
352 | {
353 | "cell_type": "code",
354 | "execution_count": null,
355 | "metadata": {
356 | "id": "rh6TJNT0w3Ub"
357 | },
358 | "outputs": [],
359 | "source": [
360 | "# ALPHA is an array with 3 elements\n",
361 | "ALPHA = np.array([0.1, 0.2, 0.3])\n",
362 | "print(ALPHA)\n",
363 | "\n",
364 | "# CL is an array with 9 elements\n",
365 | "CL = np.array([0, 0, 0, 1, 1, 1, 2, 2, 2])\n",
366 | "print(CL)\n",
367 | "\n",
368 | "# MU also becomes an array with 9 elements\n",
369 | "MU = ALPHA[CL]\n",
370 | "print(MU)"
371 | ]
372 | },
373 | {
374 | "cell_type": "markdown",
375 | "metadata": {
376 | "id": "mUwpWWu6hIXj"
377 | },
378 | "source": [
379 | "#### Visualizing the structure of the probabilistic model"
380 | ]
381 | },
382 | {
383 | "cell_type": "code",
384 | "execution_count": null,
385 | "metadata": {
386 | "id": "WoopmWwGvvWe"
387 | },
388 | "outputs": [],
389 | "source": [
390 | "g = pm.model_to_graphviz(model1)\n",
391 | "display(g)"
392 | ]
393 | },
394 | {
395 | "cell_type": "markdown",
396 | "metadata": {
397 | "id": "_mP-yzOLv_9n"
398 | },
399 | "source": [
400 | "### 5.3.4 Sampling and analysis of the results"
401 | ]
402 | },
403 | {
404 | "cell_type": "markdown",
405 | "metadata": {
406 | "id": "fFuw4MLrhQVb"
407 | },
408 | "source": [
409 | "#### Sampling"
410 | ]
411 | },
412 | {
413 | "cell_type": "code",
414 | "execution_count": null,
415 | "metadata": {
416 | "id": "zQSSO-S9v4Qr"
417 | },
418 | "outputs": [],
419 | "source": [
420 | "with model1:\n",
421 | " idata1 = pm.sample(random_seed=42, target_accept=0.998)"
422 | ]
423 | },
424 | {
425 | "cell_type": "markdown",
426 | "metadata": {
427 | "id": "YBM38RmvhYg5"
428 | },
429 | "source": [
430 | "#### Checking the inference results with the plot_trace function"
431 | ]
432 | },
433 | {
434 | "cell_type": "code",
435 | "execution_count": null,
436 | "metadata": {
437 | "id": "dHkdwKit1A3W"
438 | },
439 | "outputs": [],
440 | "source": [
441 | "az.plot_trace(idata1, compact=False, var_names=['alpha', 'beta'])\n",
442 | "plt.tight_layout();"
443 | ]
444 | },
445 | {
446 | "cell_type": "markdown",
447 | "metadata": {
448 | "id": "EnKKNK2ihgt5"
449 | },
450 | "source": [
451 | "#### Checking the inference results with the summary function"
452 | ]
453 | },
454 | {
455 | "cell_type": "code",
456 | "execution_count": null,
457 | "metadata": {
458 | "id": "gGp8Zzmumfjt"
459 | },
460 | "outputs": [],
461 | "source": [
462 | "summary1 = az.summary(idata1, var_names=['alpha', 'beta'])\n",
463 | "display(summary1)"
464 | ]
465 | },
466 | {
467 | "cell_type": "markdown",
468 | "metadata": {
469 | "id": "PTcyd5k32g6G"
470 | },
471 | "source": [
472 | "### 5.3.5 Overlaying the regression lines on the scatter plot"
473 | ]
474 | },
475 | {
476 | "cell_type": "markdown",
477 | "metadata": {
478 | "id": "P_OzcUr0h5cs"
479 | },
480 | "source": [
481 | "#### Overlaying the regression lines on the scatter plot"
482 | ]
483 | },
484 | {
485 | "cell_type": "code",
486 | "execution_count": null,
487 | "metadata": {
488 | "id": "7abMvBMkVx30"
489 | },
490 | "outputs": [],
491 | "source": [
492 | "# Get the posterior means of alpha and beta\n",
493 | "means = summary1['mean']\n",
494 | "alpha0 = means['alpha[0]']\n",
495 | "alpha1 = means['alpha[1]']\n",
496 | "alpha2 = means['alpha[2]']\n",
497 | "beta0 = means['beta[0]']\n",
498 | "beta1 = means['beta[1]']\n",
499 | "beta2 = means['beta[2]']\n",
500 | "\n",
501 | "# Compute the coordinates for the regression lines\n",
502 | "x_range = np.array([X.min()-0.1,X.max()+0.1])\n",
503 | "y0_range = alpha0 * x_range + beta0\n",
504 | "y1_range = alpha1 * x_range + beta1\n",
505 | "y2_range = alpha2 * x_range + beta2\n",
506 | "\n",
507 | "# Draw the scatter plot\n",
508 | "sns.scatterplot(\n",
509 | " x='sepal_length', y='sepal_width', hue='species', style='species',\n",
510 | " data=df_sel, s=100)\n",
511 | "plt.plot(x_range, y0_range, label='setosa')\n",
512 | "plt.plot(x_range, y1_range, label='versicolor')\n",
513 | "plt.plot(x_range, y2_range, label='virginica')\n",
514 | "plt.legend();"
515 | ]
516 | },
517 | {
518 | "cell_type": "markdown",
519 | "metadata": {
520 | "id": "b0Q5cPTaiALL"
521 | },
522 | "source": [
523 | "#### Overlaying the regression lines on the scatter plot of the original data"
524 | ]
525 | },
526 | {
527 | "cell_type": "code",
528 | "execution_count": null,
529 | "metadata": {
530 | "id": "AY3snblC-Ge6"
531 | },
532 | "outputs": [],
533 | "source": [
534 | "# Compute the coordinates for the regression lines\n",
535 | "x_range = np.array([\n",
536 | " df['sepal_length'].min()-0.1,\n",
537 | " df['sepal_length'].max()+0.1])\n",
538 | "y0_range = alpha0 * x_range + beta0\n",
539 | "y1_range = alpha1 * x_range + beta1\n",
540 | "y2_range = alpha2 * x_range + beta2\n",
541 | "\n",
542 | "# Draw the scatter plot\n",
543 | "sns.scatterplot(\n",
544 | " x='sepal_length', y='sepal_width', hue='species', style='species',\n",
545 | " s=50, data=df)\n",
546 | "plt.plot(x_range, y0_range, label='setosa')\n",
547 | "plt.plot(x_range, y1_range, label='versicolor')\n",
548 | "plt.plot(x_range, y2_range, label='virginica')\n",
549 | "plt.legend();"
550 | ]
551 | },
552 | {
553 | "cell_type": "markdown",
554 | "metadata": {
555 | "id": "uLxgWJrRB1xB"
556 | },
557 | "source": [
558 | "### Column: How finely should the components of a PyMC model be defined?"
559 | ]
560 | },
561 | {
562 | "cell_type": "markdown",
563 | "metadata": {
564 | "id": "EZg_xHzbS3vT"
565 | },
566 | "source": [
567 | "#### A simplified approach that focuses only on the relationships between random variables"
568 | ]
569 | },
570 | {
571 | "cell_type": "code",
572 | "execution_count": null,
573 | "metadata": {
574 | "id": "drOFaYy3KJ2U"
575 | },
576 | "outputs": [],
577 | "source": [
578 | "model2 = pm.Model()\n",
579 | "\n",
580 | "with model2:\n",
581 | "    # Random variable alpha, the slope (specific to the hierarchical model)\n",
582 | "    a_mu = pm.Normal('a_mu', mu=0.0, sigma=10.0)\n",
583 | "    a_sigma = pm.HalfNormal('a_sigma', sigma=10.0)\n",
584 | "    alpha = pm.Normal('alpha', mu=a_mu, sigma=a_sigma, shape=(3,))\n",
585 | "\n",
586 | "    # Random variable beta, the intercept (specific to the hierarchical model)\n",
587 | "    b_mu = pm.Normal('b_mu', mu=0.0, sigma=10.0)\n",
588 | "    b_sigma = pm.HalfNormal('b_sigma', sigma=10.0)\n",
589 | "    beta = pm.Normal('beta', mu=b_mu, sigma=b_sigma, shape=(3,))\n",
590 | "\n",
591 | "    # Error term epsilon (same as in an ordinary Bayesian model)\n",
592 | "    epsilon = pm.HalfNormal('epsilon', sigma=1.0)\n",
593 | "\n",
594 | "    # Compute mu, switching the index with cl (specific to the hierarchical model)\n",
595 | "    mu = X * alpha[cl] + beta[cl]\n",
596 | "\n",
597 | "    # Define the observation model from mu and epsilon (same as in an ordinary Bayesian model)\n",
598 | "    Y_obs = pm.Normal('Y_obs', mu=mu, sigma=epsilon, observed=Y)\n",
599 | "\n",
600 | "g = pm.model_to_graphviz(model2)\n",
601 | "display(g)"
602 | ]
603 | },
604 | {
605 | "cell_type": "markdown",
606 | "metadata": {
607 | "id": "0uRqtLwkTB03"
608 | },
609 | "source": [
610 | "#### Expressing the whole computation, including observed values and intermediate results"
611 | ]
612 | },
613 | {
614 | "cell_type": "code",
615 | "execution_count": null,
616 | "metadata": {
617 | "id": "F-LqzuFftSHq"
618 | },
619 | "outputs": [],
620 | "source": [
621 | "model3 = pm.Model()\n",
622 | "\n",
623 | "with model3:\n",
624 | "    # Define the observed X and Y as ConstantData (same as in an ordinary Bayesian model)\n",
625 | "    X_data = pm.ConstantData('X_data', X)\n",
626 | "    Y_data = pm.ConstantData('Y_data', Y)\n",
627 | "\n",
628 | "    # Class (species) index variable (specific to the hierarchical model)\n",
629 | "    cl_data = pm.ConstantData('cl_data', cl)\n",
630 | "\n",
631 | "    # Random variable alpha, the slope (specific to the hierarchical model)\n",
632 | "    a_mu = pm.Normal('a_mu', mu=0.0, sigma=10.0)\n",
633 | "    a_sigma = pm.HalfNormal('a_sigma', sigma=10.0)\n",
634 | "    alpha = pm.Normal('alpha', mu=a_mu, sigma=a_sigma, shape=(3,))\n",
635 | "\n",
636 | "    # Random variable beta, the intercept (specific to the hierarchical model)\n",
637 | "    b_mu = pm.Normal('b_mu', mu=0.0, sigma=10.0)\n",
638 | "    b_sigma = pm.HalfNormal('b_sigma', sigma=10.0)\n",
639 | "    beta = pm.Normal('beta', mu=b_mu, sigma=b_sigma, shape=(3,))\n",
640 | "\n",
641 | "    # Error term epsilon (same as in an ordinary Bayesian model)\n",
642 | "    epsilon = pm.HalfNormal('epsilon', sigma=1.0)\n",
643 | "\n",
644 | "    # Compute mu, switching the index with cl_data (specific to the hierarchical model)\n",
645 | "    mu = pm.Deterministic('mu', X_data * alpha[cl_data] + beta[cl_data])\n",
646 | "\n",
647 | "    # Define the observation model from mu and epsilon (same as in an ordinary Bayesian model)\n",
648 | "    obs = pm.Normal('obs', mu=mu, sigma=epsilon, observed=Y_data)\n",
649 | "\n",
650 | "g = pm.model_to_graphviz(model3)\n",
651 | "display(g)"
652 | ]
653 | },
654 | {
655 | "cell_type": "markdown",
656 | "metadata": {
657 | "id": "W2DN4aGCVLBL"
658 | },
659 | "source": [
660 | "#### Checking versions"
661 | ]
662 | },
663 | {
664 | "cell_type": "code",
665 | "execution_count": null,
666 | "metadata": {
667 | "id": "jdB0-PL_Cn10"
668 | },
669 | "outputs": [],
670 | "source": [
671 | "!pip install watermark | tail -n 1\n",
672 | "%load_ext watermark\n",
673 | "%watermark --iversions"
674 | ]
675 | },
676 | {
677 | "cell_type": "code",
678 | "execution_count": null,
679 | "metadata": {
680 | "id": "Nl1O5pvJpt56"
681 | },
682 | "outputs": [],
683 | "source": []
684 | }
685 | ],
686 | "metadata": {
687 | "colab": {
688 | "provenance": [],
689 | "toc_visible": true
690 | },
691 | "kernelspec": {
692 | "display_name": "Python 3",
693 | "name": "python3"
694 | },
695 | "language_info": {
696 | "name": "python"
697 | }
698 | },
699 | "nbformat": 4,
700 | "nbformat_minor": 0
701 | }
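
The hierarchical model in this notebook computes `mu` as `X_data * alpha[cl_data] + beta[cl_data]`, and its NumPy cell shows the `ALPHA[CL]` lookup on its own. As a minimal sketch, assuming made-up coefficient and input values (`BETA`, `X_demo`, and the numbers below are illustrative and are not posterior estimates from the notebook), the full per-observation computation could be reproduced in plain NumPy like this:

```python
import numpy as np

# Hypothetical per-species slopes and intercepts (illustrative values only)
ALPHA = np.array([0.1, 0.2, 0.3])
BETA = np.array([1.0, 2.0, 3.0])

# Class index of each of the 9 observations (3 per species), as in the notebook
CL = np.array([0, 0, 0, 1, 1, 1, 2, 2, 2])

# Hypothetical explanatory values, one per observation
X_demo = np.array([5.0, 5.0, 5.0, 6.0, 6.0, 6.0, 7.0, 7.0, 7.0])

# Integer-array indexing expands the 3 coefficients into 9 per-observation values,
# so each observation uses the slope and intercept of its own species
MU = X_demo * ALPHA[CL] + BETA[CL]
print(MU.shape)
print(MU)
```

Each observation picks up the slope and intercept of its own class, which mirrors what the `alpha[cl_data]` and `beta[cl_data]` expressions do inside the PyMC model.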
--------------------------------------------------------------------------------
/notebooks/5_4_潜在変数モデル.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "source": [
6 | "To keep the ``plot_trace`` function from emitting warnings, run the following command first."
7 | ],
8 | "metadata": {
9 | "id": "ChK2aGZpAppi"
10 | }
11 | },
12 | {
13 | "cell_type": "code",
14 | "source": [
15 | "!pip install --upgrade numba | tail -n 1"
16 | ],
17 | "metadata": {
18 | "id": "Ws5z32Up-jnB"
19 | },
20 | "execution_count": null,
21 | "outputs": []
22 | },
23 | {
24 | "cell_type": "markdown",
25 | "metadata": {
26 | "id": "uIoj_UW_qasp"
27 | },
28 | "source": [
29 | "## 5.4 Latent variable model"
30 | ]
31 | },
32 | {
33 | "cell_type": "markdown",
34 | "metadata": {
35 | "id": "JHS5qeidVZvJ"
36 | },
37 | "source": [
38 | ""
39 | ]
40 | },
41 | {
42 | "cell_type": "markdown",
43 | "metadata": {
44 | "id": "vdi2sqRyqr4p"
45 | },
46 | "source": [
47 | "### Common setup"
48 | ]
49 | },
50 | {
51 | "cell_type": "code",
52 | "execution_count": null,
53 | "metadata": {
54 | "id": "0dTjJOmIqmiS"
55 | },
56 | "outputs": [],
57 | "source": [
58 | "%matplotlib inline\n",
59 | "# Install the library that enables Japanese text in Matplotlib\n",
60 | "!pip install japanize-matplotlib | tail -n 1"
61 | ]
62 | },
63 | {
64 | "cell_type": "code",
65 | "execution_count": null,
66 | "metadata": {
67 | "id": "M46-MnBXq5-Z"
68 | },
69 | "outputs": [],
70 | "source": [
71 | "# Library imports\n",
72 | "\n",
73 | "# NumPy\n",
74 | "import numpy as np\n",
75 | "# pyplot from Matplotlib\n",
76 | "import matplotlib.pyplot as plt\n",
77 | "# Japanese font support for Matplotlib\n",
78 | "import japanize_matplotlib\n",
79 | "# pandas\n",
80 | "import pandas as pd\n",
81 | "# Function for displaying DataFrames\n",
82 | "from IPython.display import display\n",
83 | "# seaborn\n",
84 | "import seaborn as sns\n",
85 | "# Display option adjustments\n",
86 | "# NumPy print format\n",
87 | "np.set_printoptions(precision=3, floatmode='fixed')\n",
88 | "# Default font size for plots\n",
89 | "plt.rcParams[\"font.size\"] = 14\n",
90 | "# Default figure size\n",
91 | "plt.rcParams['figure.figsize'] = (6, 6)\n",
92 | "# Show grid lines\n",
93 | "plt.rcParams['axes.grid'] = True\n",
94 | "# Float precision when displaying DataFrames\n",
95 | "pd.options.display.float_format = '{:.3f}'.format\n",
96 | "# Show all DataFrame columns\n",
97 | "pd.set_option(\"display.max_columns\", None)"
98 | ]
99 | },
100 | {
101 | "cell_type": "code",
102 | "execution_count": null,
103 | "metadata": {
104 | "id": "my5YU9YCq-ou"
105 | },
106 | "outputs": [],
107 | "source": [
108 | "import pymc as pm\n",
109 | "import arviz as az\n",
110 | "\n",
111 | "print(f\"Running on PyMC v{pm.__version__}\")\n",
112 | "print(f\"Running on ArviZ v{az.__version__}\")"
113 | ]
114 | },
115 | {
116 | "cell_type": "markdown",
117 | "metadata": {
118 | "id": "rBIfYJ_BrGUD"
119 | },
120 | "source": [
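121 | "### 5.4.1 Problem setting\n",
122 | "Use only the values of one specific column of the Iris dataset.  \n",
123 | "Without using the species labels, infer the statistical characteristics of the two species.\n"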
124 | ]
125 | },
126 | {
127 | "cell_type": "markdown",
128 | "metadata": {
129 | "id": "lnTxrJeXtJlG"
130 | },
131 | "source": [
132 | "### 5.4.2 Data preparation"
133 | ]
134 | },
135 | {
136 | "cell_type": "markdown",
137 | "metadata": {
138 | "id": "Z6oJ-5Yj4QQI"
139 | },
140 | "source": [
141 | "#### Loading and checking the data"
142 | ]
143 | },
144 | {
145 | "cell_type": "code",
146 | "execution_count": null,
147 | "metadata": {
148 | "id": "4DkJdzLIrCFM"
149 | },
150 | "outputs": [],
151 | "source": [
152 | "# Load the Iris dataset\n",
153 | "df = sns.load_dataset('iris')\n",
154 | "\n",
155 | "# Check the first five rows\n",
156 | "display(df.head())\n",
157 | "\n",
158 | "# Check the distribution of species\n",
159 | "df['species'].value_counts()"
160 | ]
161 | },
162 | {
163 | "cell_type": "markdown",
164 | "metadata": {
165 | "id": "x8bv5SkA4U-f"
166 | },
167 | "source": [
168 | "#### Narrowing down the data to analyze\n"
169 | ]
170 | },
171 | {
172 | "cell_type": "code",
173 | "execution_count": null,
174 | "metadata": {
175 | "id": "CZFE_MrE2Vr8"
176 | },
177 | "outputs": [],
178 | "source": [
179 | "# Restrict the data to the two species other than setosa\n",
180 | "df2 = df.query('species != \"setosa\"')\n",
181 | "\n",
182 | "# Renumber the index from 0\n",
183 | "df2 = df2.reset_index(drop=True)\n",
184 | "\n",
185 | "# Store the petal_width values in X\n",
186 | "X = df2['petal_width'].values"
187 | ]
188 | },
189 | {
190 | "cell_type": "markdown",
191 | "metadata": {
192 | "id": "areTBgXp4qVs"
193 | },
194 | "source": [
195 | "#### Histogram of the data to analyze, without color-coding"
196 | ]
197 | },
198 | {
199 | "cell_type": "code",
200 | "execution_count": null,
201 | "metadata": {
202 | "id": "yk01phVIlugz"
203 | },
204 | "outputs": [],
205 | "source": [
206 | "bins = np.arange(0.8, 3.0, 0.1)\n",
207 | "fig, ax = plt.subplots()\n",
208 | "sns.histplot(bins=bins, x=X)\n",
209 | "ax.set_xlabel('petal_width')\n",
210 | "ax.xaxis.set_tick_params(rotation=90)\n",
211 | "ax.set_title('Histogram of petal_width')\n",
212 | "ax.set_xticks(bins);"
213 | ]
214 | },
215 | {
216 | "cell_type": "markdown",
217 | "metadata": {
218 | "id": "qT7F6BWp4x6l"
219 | },
220 | "source": [
221 | "#### Histogram of petal_width (color-coded by species)"
222 | ]
223 | },
224 | {
225 | "cell_type": "code",
226 | "execution_count": null,
227 | "metadata": {
228 | "id": "4KUJ4s76hQM-"
229 | },
230 | "outputs": [],
231 | "source": [
232 | "bins = np.arange(0.8, 3.0, 0.1)\n",
233 | "fig, ax = plt.subplots()\n",
234 | "sns.histplot(data=df2, bins=bins, x='petal_width',\n",
235 | " hue='species', kde=True)\n",
236 | "ax.xaxis.set_tick_params(rotation=90)\n",
237 | "ax.set_title('Histogram of petal_width')\n",
238 | "ax.set_xticks(bins);"
239 | ]
240 | },
241 | {
242 | "cell_type": "markdown",
243 | "metadata": {
244 | "id": "PgH4jGWpvoW9"
245 | },
246 | "source": [
247 | "### 5.4.3 Probabilistic model definition"
248 | ]
249 | },
250 | {
251 | "cell_type": "markdown",
252 | "metadata": {
253 | "id": "TwLsB_CftLay"
254 | },
255 | "source": [
256 | "#### Probabilistic model definition for the latent variable model"
257 | ]
258 | },
259 | {
260 | "cell_type": "code",
261 | "execution_count": null,
262 | "metadata": {
263 | "id": "k_Lw-2-XvIDi"
264 | },
265 | "outputs": [],
266 | "source": [
267 | "# Initial variable settings\n",
268 | "\n",
269 | "# Number of normal-distribution components\n",
270 | "n_components = 2\n",
271 | "\n",
272 | "# Shape of the observed data (kept as a tuple because it is passed to shape= below)\n",
273 | "N = X.shape\n",
274 | "\n",
275 | "model1 = pm.Model()\n",
276 | "\n",
277 | "with model1:\n",
278 | "    # Define the observed X as ConstantData\n",
279 | "    X_data = pm.ConstantData('X_data', X)\n",
280 | "\n",
281 | "    # p: probability that the latent variable takes the value 1\n",
282 | "    p = pm.Uniform('p', lower=0.0, upper=1.0)\n",
283 | "\n",
284 | "    # s: latent variable; takes the value 0 or 1 according to the probability p\n",
285 | "    s = pm.Bernoulli('s', p=p, shape=N)\n",
286 | "\n",
287 | "    # mus: mean for each of the two species\n",
288 | "    mus = pm.Normal('mus', mu=0.0, sigma=10.0, shape=n_components)\n",
289 | "\n",
290 | "    # taus: precision (inverse variance) for each of the two species\n",
291 | "    # Related to the standard deviations sigmas by taus = 1/(sigmas*sigmas)\n",
292 | "    taus = pm.HalfNormal('taus', sigma=10.0, shape=n_components)\n",
293 | "\n",
294 | "    # sigmas is needed for plotting and analysis, so derive it from taus\n",
295 | "    sigmas = pm.Deterministic('sigmas', 1/pm.math.sqrt(taus))\n",
296 | "\n",
297 | "    # For each observation, select mu and tau according to its latent variable\n",
298 | "    mu = pm.Deterministic('mu', mus[s])\n",
299 | "    tau = pm.Deterministic('tau', taus[s])\n",
300 | "\n",
301 | "    # Define the normally distributed random variable X_obs\n",
302 | "    X_obs = pm.Normal('X_obs', mu=mu, tau=tau, observed=X_data)"
303 | ]
304 | },
305 | {
306 | "cell_type": "markdown",
307 | "metadata": {
308 | "id": "5FPUpEQOtSl7"
309 | },
310 | "source": [
311 | "#### Visualizing the structure of the probabilistic model"
312 | ]
313 | },
314 | {
315 | "cell_type": "code",
316 | "execution_count": null,
317 | "metadata": {
318 | "id": "WoopmWwGvvWe"
319 | },
320 | "outputs": [],
321 | "source": [
322 | "g = pm.model_to_graphviz(model1)\n",
323 | "display(g)"
324 | ]
325 | },
326 | {
327 | "cell_type": "markdown",
328 | "metadata": {
329 | "id": "_mP-yzOLv_9n"
330 | },
331 | "source": [
332 | "### 5.4.4 Sampling and analysis of the results"
333 | ]
334 | },
335 | {
336 | "cell_type": "markdown",
337 | "metadata": {
338 | "id": "PLyVQMlKyD8a"
339 | },
340 | "source": [
341 | "#### Sampling"
342 | ]
343 | },
344 | {
345 | "cell_type": "code",
346 | "execution_count": null,
347 | "metadata": {
348 | "id": "UOt2twamwm7q"
349 | },
350 | "outputs": [],
351 | "source": [
352 | "with model1:\n",
353 | " idata1 = pm.sample(chains=1, draws=2000, target_accept=0.99,\n",
354 | " random_seed=42)"
355 | ]
356 | },
357 | {
358 | "cell_type": "markdown",
359 | "metadata": {
360 | "id": "_sl5QzimyNJv"
361 | },
362 | "source": [
363 | "#### Checking the inference results with the plot_trace function"
364 | ]
365 | },
366 | {
367 | "cell_type": "code",
368 | "execution_count": null,
369 | "metadata": {
370 | "id": "dHkdwKit1A3W"
371 | },
372 | "outputs": [],
373 | "source": [
374 | "az.plot_trace(idata1, var_names=['p', 'mus', 'sigmas'], compact=False)\n",
375 | "plt.tight_layout();"
376 | ]
377 | },
378 | {
379 | "cell_type": "markdown",
380 | "metadata": {
381 | "id": "nCNaXqIm1k3P"
382 | },
383 | "source": [
384 | "#### Checking the posterior distribution of each random variable with the plot_posterior function"
385 | ]
386 | },
387 | {
388 | "cell_type": "code",
389 | "execution_count": null,
390 | "metadata": {
391 | "id": "4zMT8RzI-BFK"
392 | },
393 | "outputs": [],
394 | "source": [
395 | "plt.rcParams['figure.figsize']=(6,6)\n",
396 | "az.plot_posterior(idata1, var_names=['p', 'mus', 'sigmas'])\n",
397 | "plt.tight_layout();"
398 | ]
399 | },
400 | {
401 | "cell_type": "markdown",
402 | "metadata": {
403 | "id": "8rKTNX-zyWwE"
404 | },
405 | "source": [
406 | "#### Getting summary statistics with the summary function"
407 | ]
408 | },
409 | {
410 | "cell_type": "code",
411 | "execution_count": null,
412 | "metadata": {
413 | "id": "SXEWuooOfo-c"
414 | },
415 | "outputs": [],
416 | "source": [
417 | "summary1 = az.summary(idata1, var_names=['p', 'mus', 'sigmas'])\n",
418 | "display(summary1)"
419 | ]
420 | },
421 | {
422 | "cell_type": "markdown",
423 | "metadata": {
424 | "id": "PTcyd5k32g6G"
425 | },
426 | "source": [
427 | "### 5.4.5 Overlaying the normal density functions on the histogram"
428 | ]
429 | },
430 | {
431 | "cell_type": "code",
432 | "execution_count": null,
433 | "metadata": {
434 | "id": "7abMvBMkVx30"
435 | },
436 | "outputs": [],
437 | "source": [
438 | "# Define the normal probability density function\n",
439 | "def norm(x, mu, sigma):\n",
440 | "    return np.exp(-((x - mu)/sigma)**2/2) / (np.sqrt(2 * np.pi) * sigma)\n",
441 | "\n",
442 | "# Get the posterior mean of each parameter from the inference results\n",
443 | "mean = summary1['mean']\n",
444 | "\n",
445 | "# Posterior means of mu\n",
446 | "mean_mu0 = mean['mus[0]']\n",
447 | "mean_mu1 = mean['mus[1]']\n",
448 | "\n",
449 | "# Posterior means of sigma\n",
450 | "mean_sigma0 = mean['sigmas[0]']\n",
451 | "mean_sigma1 = mean['sigmas[1]']\n",
452 | "\n",
453 | "# Evaluate the normal density functions\n",
454 | "x = np.arange(0.8, 3.0, 0.05)\n",
455 | "delta = 0.1\n",
456 | "y0 = norm(x, mean_mu0, mean_sigma0) * delta / n_components\n",
457 | "y1 = norm(x, mean_mu1, mean_sigma1) * delta / n_components\n",
458 | "\n",
459 | "# Draw the plot\n",
460 | "bins = np.arange(0.8, 3.0, delta)\n",
461 | "plt.rcParams['figure.figsize']=(6,6)\n",
462 | "fig, ax = plt.subplots()\n",
463 | "sns.histplot(data=df2, bins=bins, x='petal_width',\n",
464 | "    hue='species', kde=True, ax=ax, stat='probability')\n",
465 | "ax.get_lines()[1].set_label('KDE versicolor')\n",
466 | "ax.get_lines()[0].set_label('KDE virginica')\n",
467 | "ax.plot(x, y0, c='b', lw=3, label='Bayes versicolor')\n",
468 | "ax.plot(x, y1, c='y', lw=3, label='Bayes virginica')\n",
469 | "ax.set_xticks(bins);\n",
470 | "ax.xaxis.set_tick_params(rotation=90)\n",
471 | "ax.set_title('Histogram overlaid with the normal density functions')\n",
472 | "plt.legend();"
473 | ]
474 | },
475 | {
476 | "cell_type": "markdown",
477 | "metadata": {
478 | "id": "m2-RanqsKuu_"
479 | },
480 | "source": [
481 | "### 5.4.6 Probability distribution of the latent variable"
482 | ]
483 | },
484 | {
485 | "cell_type": "markdown",
486 | "metadata": {
487 | "id": "KsMpViBfI7Rb"
488 | },
489 | "source": [
490 | "#### Finding the indexes where petal_width is 1.0, 1.5, 1.7, 2.0, and 2.5"
491 | ]
492 | },
493 | {
494 | "cell_type": "code",
495 | "execution_count": null,
496 | "metadata": {
497 | "id": "ANHu670nK4u1"
498 | },
499 | "outputs": [],
500 | "source": [
501 | "value_list = [1.0, 1.5, 1.7, 2.0, 2.5]\n",
502 | "\n",
503 | "df_heads = pd.DataFrame(None)\n",
504 | "\n",
505 | "# For each value from 1.0 to 2.5, extract the first row whose petal_width equals it\n",
506 | "for value in value_list:\n",
507 | "\n",
508 | "    # Extract from df2 only the rows whose petal_width equals value\n",
509 | "    w = df2.query('petal_width == @value', engine='python')\n",
510 | "\n",
511 | "    # Take the first row and append it to df_heads\n",
512 | "    df_heads = pd.concat([df_heads, w.head(1)], axis=0)\n",
513 | "\n",
514 | "# Check the result\n",
515 | "display(df_heads)"
516 | ]
517 | },
518 | {
519 | "cell_type": "markdown",
520 | "metadata": {
521 | "id": "j5_2fT5LKQll"
522 | },
523 | "source": [
524 | "#### Visualizing how the distribution of the latent variable s varies with petal_width"
525 | ]
526 | },
527 | {
528 | "cell_type": "code",
529 | "execution_count": null,
530 | "metadata": {
531 | "id": "pE8grEh-t4xk"
532 | },
533 | "outputs": [],
534 | "source": [
535 | "# Get the indexes of df_heads\n",
536 | "indexes, n_indexes = df_heads.index, len(df_heads)\n",
537 | "\n",
538 | "# From the samples of the latent variable s, extract those at index=7, 1, 27, 60, 50\n",
539 | "sval = idata1.posterior['s'][:,:,indexes].values.reshape(-1,n_indexes).T\n",
540 | "\n",
541 | "# Draw a histogram for each case\n",
542 | "plt.rcParams['figure.figsize']=(15,3)\n",
543 | "vlist = df_heads['petal_width']\n",
544 | "fig, axes = plt.subplots(1, n_indexes)\n",
545 | "for ax, item, value, index in zip(axes, sval, vlist, indexes):\n",
546 | "    f = pd.DataFrame(item)\n",
547 | "    f.hist(ax=ax, bins=2)\n",
548 | "    ax.set_ylim(0,2000)\n",
549 | "    ax.set_title(f'value={value} index={index}')\n",
550 | "plt.tight_layout();"
551 | ]
552 | },
553 | {
554 | "cell_type": "markdown",
555 | "metadata": {
556 | "id": "yh7xE04ZZ88t"
557 | },
558 | "source": [
559 | "### Column: Key points of Bayesian inference for latent variable models"
560 | ]
561 | },
562 | {
563 | "cell_type": "markdown",
564 | "metadata": {
565 | "id": "evzU0aqd5XNK"
566 | },
567 | "source": [
568 | "#### A probabilistic model that gives unintended results"
569 | ]
570 | },
571 | {
572 | "cell_type": "markdown",
573 | "metadata": {
574 | "id": "9VaeidrvVNVx"
575 | },
576 | "source": [
577 | "##### Probabilistic model definition"
578 | ]
579 | },
580 | {
581 | "cell_type": "code",
582 | "execution_count": null,
583 | "metadata": {
584 | "id": "OPgJxdeL5JM0"
585 | },
586 | "outputs": [],
587 | "source": [
588 | "model2 = pm.Model()\n",
589 | "\n",
590 | "with model2:\n",
591 | "    # Define the observed X as ConstantData\n",
592 | "    X_data = pm.ConstantData('X_data', X)\n",
593 | "\n",
594 | "    # p: probability that the latent variable takes the value 1\n",
595 | "    p = pm.Uniform('p', lower=0.0, upper=1.0)\n",
596 | "\n",
597 | "    # s: latent variable; takes the value 0 or 1 according to the probability p\n",
598 | "    s = pm.Bernoulli('s', p=p, shape=N)\n",
599 | "\n",
600 | "    # mus: mean for each of the two species\n",
601 | "    mus = pm.Normal('mus', mu=0.0, sigma=10.0, shape=n_components)\n",
602 | "\n",
603 | "    # sigmas: standard deviation for each of the two species\n",
604 | "    sigmas = pm.HalfNormal('sigmas', sigma=10.0, shape=n_components)\n",
605 | "\n",
606 | "    # For each observation, select the mean and standard deviation according to its latent variable\n",
607 | "    mu = pm.Deterministic('mu', mus[s])\n",
608 | "    sigma = pm.Deterministic('sigma', sigmas[s])\n",
609 | "\n",
610 | "    # Model the observed x values with a normal distribution\n",
611 | "    X_obs = pm.Normal('X_obs', mu=mu, sigma=sigma, observed=X_data)\n",
612 | "\n",
613 | "# Visualize the structure of the probabilistic model\n",
614 | "g = pm.model_to_graphviz(model2)\n",
615 | "display(g)"
616 | ]
617 | },
618 | {
619 | "cell_type": "markdown",
620 | "metadata": {
621 | "id": "lpa-9fp5bhu1"
622 | },
623 | "source": [
624 | "##### Sampling and calling the plot_trace function"
625 | ]
626 | },
627 | {
628 | "cell_type": "code",
629 | "execution_count": null,
630 | "metadata": {
631 | "id": "aAuXArC7brMj"
632 | },
633 | "outputs": [],
634 | "source": [
635 | "with model2:\n",
636 | "    # Sampling\n",
637 | "    idata2 = pm.sample(random_seed=42, chains=1, target_accept=0.998)\n",
638 | "\n",
639 | "# Check the inference results with the plot_trace function\n",
640 | "az.plot_trace(idata2, var_names=['p', 'mus', 'sigmas'], compact=False)\n",
641 | "plt.tight_layout();"
642 | ]
643 | },
644 | {
645 | "cell_type": "code",
646 | "execution_count": null,
647 | "metadata": {
648 | "id": "R0FKbOg-gIVp"
649 | },
650 | "outputs": [],
651 | "source": [
652 | "summary2 = az.summary(idata2, var_names=['p', 'mus', 'sigmas'])\n",
653 | "display(summary2)"
654 | ]
655 | },
656 | {
657 | "cell_type": "markdown",
658 | "metadata": {
659 | "id": "Tn63ZJlGmPzI"
660 | },
661 | "source": [
662 | "#### A probabilistic model in which label switching does not occur"
663 | ]
664 | },
665 | {
666 | "cell_type": "markdown",
667 | "metadata": {
668 | "id": "X6nN310nVS14"
669 | },
670 | "source": [
671 | "##### Probabilistic model definition"
672 | ]
673 | },
674 | {
675 | "cell_type": "code",
676 | "execution_count": null,
677 | "metadata": {
678 | "id": "Ifp_8MNKxyBn"
679 | },
680 | "outputs": [],
681 | "source": [
682 | "# Initial variable settings\n",
683 | "\n",
684 | "# Number of normal-distribution components\n",
685 | "n_components = 2\n",
686 | "\n",
687 | "# Shape of the observed data (kept as a tuple because it is passed to shape= below)\n",
688 | "N = X.shape\n",
689 | "\n",
690 | "model3 = pm.Model()\n",
691 | "\n",
692 | "with model3:\n",
693 | "    # Define the observed X as ConstantData\n",
694 | "    X_data = pm.ConstantData('X_data', X)\n",
695 | "\n",
696 | "    # p: probability that the latent variable takes the value 1\n",
697 | "    p = pm.Uniform('p', lower=0.0, upper=1.0)\n",
698 | "\n",
699 | "    # s: latent variable; takes the value 0 or 1 according to the probability p\n",
700 | "    s = pm.Bernoulli('s', p=p, shape=N)\n",
701 | "\n",
702 | "    # mus: mean for each of the two species; delta0 >= 0, so mu1 = mu0 + delta0 keeps mu0 <= mu1\n",
703 | "    mu0 = pm.HalfNormal('mu0', sigma=10.0)\n",
704 | "    delta0 = pm.HalfNormal('delta0', sigma=10.0)\n",
705 | "    mu1 = pm.Deterministic('mu1', mu0+delta0)\n",
706 | "    mus = pm.Deterministic('mus',pm.math.stack([mu0, mu1]))\n",
707 | "\n",
708 | "    # taus: precision (inverse variance) for each of the two species\n",
709 | "    # Related to the standard deviations sigmas by taus = 1/(sigmas*sigmas)\n",
710 | "    taus = pm.HalfNormal('taus', sigma=10.0, shape=n_components)\n",
711 | "\n",
712 | "    # sigmas is needed for plotting and analysis, so derive it from taus\n",
713 | "    sigmas = pm.Deterministic('sigmas', 1/pm.math.sqrt(taus))\n",
714 | "\n",
715 | "    # For each observation, select mu and tau according to its latent variable\n",
716 | "    mu = pm.Deterministic('mu', mus[s])\n",
717 | "    tau = pm.Deterministic('tau', taus[s])\n",
718 | "\n",
719 | "    # Define the normally distributed random variable X_obs\n",
720 | "    X_obs = pm.Normal('X_obs', mu=mu, tau=tau, observed=X_data)\n",
721 | "\n",
722 | "g = pm.model_to_graphviz(model3)\n",
723 | "display(g)"
724 | ]
725 | },
726 | {
727 | "cell_type": "markdown",
728 | "metadata": {
729 | "id": "qXYa1ZbwcBrq"
730 | },
731 | "source": [
732 | "##### Sampling and calling the plot_trace function"
733 | ]
734 | },
735 | {
736 | "cell_type": "code",
737 | "execution_count": null,
738 | "metadata": {
739 | "id": "R3eULAMgnAz4"
740 | },
741 | "outputs": [],
742 | "source": [
743 | "with model3:\n",
744 | "    # Sampling\n",
745 | "    idata3 = pm.sample(random_seed=42, target_accept=0.999)\n",
746 | "\n",
747 | "# Check the inference results with plot_trace\n",
748 | "az.plot_trace(idata3, var_names=['p', 'mus', 'sigmas'], compact=False)\n",
749 | "plt.tight_layout();"
750 | ]
751 | },
752 | {
753 | "cell_type": "code",
754 | "execution_count": null,
755 | "metadata": {
756 | "id": "0jG4FmB3nxHm"
757 | },
758 | "outputs": [],
759 | "source": [
760 | "summary3 = az.summary(idata3, var_names=['p', 'mus', 'sigmas'])\n",
761 | "display(summary3)"
762 | ]
763 | },
764 | {
765 | "cell_type": "markdown",
766 | "metadata": {
767 | "id": "1VoXqEGUH3Nt"
768 | },
769 | "source": [
770 | "##### Plotting"
771 | ]
772 | },
773 | {
774 | "cell_type": "code",
775 | "execution_count": null,
776 | "metadata": {
777 | "id": "zTr-6VdbDK6w"
778 | },
779 | "outputs": [],
780 | "source": [
781 | "# Figure size\n",
782 | "plt.rcParams['figure.figsize'] = (6, 6)\n",
783 | "\n",
784 | "# Normal probability density function\n",
785 | "def norm(x, mu, sigma):\n",
786 | "    return np.exp(-((x - mu)/sigma)**2/2) / (np.sqrt(2 * np.pi) * sigma)\n",
787 | "\n",
788 | "# Posterior mean of each parameter, taken from the inference summary\n",
789 | "mean = summary3['mean']\n",
790 | "\n",
791 | "# Posterior means of mu\n",
792 | "mean_mu0 = mean['mus[0]']\n",
793 | "mean_mu1 = mean['mus[1]']\n",
794 | "\n",
795 | "# Posterior means of sigma\n",
796 | "mean_sigma0 = mean['sigmas[0]']\n",
797 | "mean_sigma1 = mean['sigmas[1]']\n",
798 | "\n",
799 | "# Bin width, histogram bins, and x grid (defined before their first use)\n",
800 | "delta = 0.1\n",
801 | "bins = np.arange(0.8, 3.0, delta)\n",
802 | "x = np.arange(0.8, 3.0, 0.05)\n",
803 | "\n",
804 | "# Evaluate the normal densities, scaled to match the histogram\n",
805 | "y0 = norm(x, mean_mu0, mean_sigma0) * delta / n_components\n",
806 | "y1 = norm(x, mean_mu1, mean_sigma1) * delta / n_components\n",
807 | "# Draw the plot\n",
808 | "fig, ax = plt.subplots()\n",
809 | "sns.histplot(data=df2, bins=bins, x='petal_width',\n",
810 | "    hue='species', kde=True, ax=ax, stat='probability')\n",
811 | "ax.get_lines()[1].set_label('KDE versicolor')\n",
812 | "ax.get_lines()[0].set_label('KDE virginica')\n",
813 | "ax.xaxis.set_tick_params(rotation=90)\n",
814 | "ax.set_title('Histogram overlaid with normal density functions\\n(label-switching countermeasure)')\n",
815 | "ax.set_xticks(bins);\n",
816 | "ax.plot(x, y0, c='b', lw=3, label='Bayes versicolor')\n",
817 | "ax.plot(x, y1, c='y', lw=3, label='Bayes virginica')\n",
818 | "plt.legend();"
819 | ]
820 | },
821 | {
822 | "cell_type": "markdown",
823 | "metadata": {
824 | "id": "x5DEy2muwyeQ"
825 | },
826 | "source": [
827 | "#### Version check"
828 | ]
829 | },
830 | {
831 | "cell_type": "code",
832 | "execution_count": null,
833 | "metadata": {
834 | "id": "zEa9F-tuH9fU"
835 | },
836 | "outputs": [],
837 | "source": [
838 | "!pip install watermark | tail -n 1\n",
839 | "%load_ext watermark\n",
840 | "%watermark --iversions"
841 | ]
842 | }
843 | ],
844 | "metadata": {
845 | "colab": {
846 | "provenance": [],
847 | "toc_visible": true
848 | },
849 | "kernelspec": {
850 | "display_name": "Python 3",
851 | "name": "python3"
852 | },
853 | "language_info": {
854 | "name": "python"
855 | }
856 | },
857 | "nbformat": 4,
858 | "nbformat_minor": 0
859 | }
--------------------------------------------------------------------------------
/notebooks/6_1_ABテスト効果検証.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "source": [
6 | "Run the following command so that the ``plot_trace`` function does not emit warnings."
7 | ],
8 | "metadata": {
9 | "id": "zk5cJ_HGf8jU"
10 | }
11 | },
12 | {
13 | "cell_type": "code",
14 | "source": [
15 | "!pip install --upgrade numba | tail -n 1"
16 | ],
17 | "metadata": {
18 | "id": "NvfsFBPEf7pw"
19 | },
20 | "execution_count": null,
21 | "outputs": []
22 | },
23 | {
24 | "cell_type": "markdown",
25 | "metadata": {
26 | "id": "uIoj_UW_qasp"
27 | },
28 | "source": [
29 | "## 6.1 Evaluating the effect of an A/B test"
30 | ]
31 | },
32 | {
33 | "cell_type": "markdown",
34 | "metadata": {
35 | "id": "h0cR7yvwll73"
36 | },
37 | "source": [
38 | ""
39 | ]
40 | },
41 | {
42 | "cell_type": "markdown",
43 | "metadata": {
44 | "id": "vdi2sqRyqr4p"
45 | },
46 | "source": [
47 | "### Common setup"
48 | ]
49 | },
50 | {
51 | "cell_type": "code",
52 | "execution_count": null,
53 | "metadata": {
54 | "id": "0dTjJOmIqmiS"
55 | },
56 | "outputs": [],
57 | "source": [
58 | "%matplotlib inline\n",
59 | "# Install the Japanese-font support library for matplotlib\n",
60 | "!pip install japanize-matplotlib | tail -n 1"
61 | ]
62 | },
63 | {
64 | "cell_type": "code",
65 | "execution_count": null,
66 | "metadata": {
67 | "id": "M46-MnBXq5-Z"
68 | },
69 | "outputs": [],
70 | "source": [
71 | "# Import libraries\n",
72 | "\n",
73 | "# NumPy\n",
74 | "import numpy as np\n",
75 | "# pyplot from Matplotlib\n",
76 | "import matplotlib.pyplot as plt\n",
77 | "# Japanese-font support for matplotlib\n",
78 | "import japanize_matplotlib\n",
79 | "# pandas\n",
80 | "import pandas as pd\n",
81 | "# Function for displaying DataFrames\n",
82 | "from IPython.display import display\n",
83 | "# seaborn\n",
84 | "import seaborn as sns\n",
85 | "# Adjust display options\n",
86 | "# NumPy print format\n",
87 | "np.set_printoptions(precision=3, floatmode='fixed')\n",
88 | "# Default font size for plots\n",
89 | "plt.rcParams[\"font.size\"] = 14\n",
90 | "# Figure size\n",
91 | "plt.rcParams['figure.figsize'] = (6, 6)\n",
92 | "# Show grid lines\n",
93 | "plt.rcParams['axes.grid'] = True\n",
94 | "# Float display precision in DataFrames\n",
95 | "pd.options.display.float_format = '{:.3f}'.format\n",
96 | "# Show all columns of a DataFrame\n",
97 | "pd.set_option(\"display.max_columns\",None)"
98 | ]
99 | },
100 | {
101 | "cell_type": "code",
102 | "execution_count": null,
103 | "metadata": {
104 | "id": "my5YU9YCq-ou"
105 | },
106 | "outputs": [],
107 | "source": [
108 | "import pymc as pm\n",
109 | "import arviz as az\n",
110 | "\n",
111 | "print(f\"Running on PyMC v{pm.__version__}\")\n",
112 | "print(f\"Running on ArViz v{az.__version__}\")"
113 | ]
114 | },
115 | {
116 | "cell_type": "markdown",
117 | "metadata": {
118 | "id": "rBIfYJ_BrGUD"
119 | },
120 | "source": [
121 | "### 6.1.1 Problem setup\n",
122 | "Suzuki and Yamada each made improvements to the web page they are responsible for.  \n",
123 | "They ran an A/B test that randomly showed either the improved page B or the original page A, and used it to evaluate the effect.  \n",
124 | "Given the results below, the business question is whether each improvement actually had an effect.\n",
125 | "\n",
126 | "\n",
127 | "| | | Suzuki | Yamada |\n",
128 | "| --------------- | ---------- | --------: | --------: |\n",
129 | "| Original page A | Impressions | 40 | 1200 |\n",
130 | "| | Clicks | 2 | 60 |\n",
131 | "| | Click-through rate | 5% | 5% |\n",
132 | "| Improved page B | Impressions | 25 | 1600 |\n",
133 | "| | Clicks | 2 | 110 |\n",
134 | "| | Click-through rate | 8% | 6.88% |\n",
135 | "\n"
136 | ]
137 | },
138 | {
139 | "cell_type": "markdown",
140 | "metadata": {
141 | "id": "PgH4jGWpvoW9"
142 | },
143 | "source": [
144 | "### 6.1.2 Probabilistic model definition (Suzuki's case)\n"
145 | ]
146 | },
147 | {
148 | "cell_type": "markdown",
149 | "metadata": {
150 | "id": "XPXTh1tG6Ldo"
151 | },
152 | "source": [
153 | "#### Defining the model and visualizing its structure\n",
154 | "We start by analyzing Suzuki's case."
155 | ]
156 | },
157 | {
158 | "cell_type": "code",
159 | "execution_count": null,
160 | "metadata": {
161 | "id": "Bdk8zw3fuphC"
162 | },
163 | "outputs": [],
164 | "source": [
165 | "model_s = pm.Model()\n",
166 | "\n",
167 | "with model_s:\n",
168 | "    # Use uniform distributions as the priors\n",
169 | "    p_s_a = pm.Uniform('p_s_a', lower=0.0, upper=1.0)\n",
170 | "    p_s_b = pm.Uniform('p_s_b', lower=0.0, upper=1.0)\n",
171 | "\n",
172 | "    # Define the probabilistic model with binomial distributions\n",
173 | "    # n: number of impressions, observed: number of clicks\n",
174 | "    obs_s_a = pm.Binomial('obs_s_a', p=p_s_a, n=40, observed=2)\n",
175 | "    obs_s_b = pm.Binomial('obs_s_b', p=p_s_b, n=25, observed=2)\n",
176 | "\n",
177 | "    # Define the difference of the two probabilities as a new random variable\n",
178 | "    delta_prob_s = pm.Deterministic('delta_prob_s', p_s_b - p_s_a)\n",
179 | "\n",
180 | "# Visualize the model structure\n",
181 | "g = pm.model_to_graphviz(model_s)\n",
182 | "display(g)"
183 | ]
184 | },
185 | {
186 | "cell_type": "markdown",
187 | "metadata": {
188 | "id": "_mP-yzOLv_9n"
189 | },
190 | "source": [
191 | "### 6.1.3 Sampling and analyzing the results"
192 | ]
193 | },
194 | {
195 | "cell_type": "markdown",
196 | "metadata": {
197 | "id": "wcImDr5fDwI8"
198 | },
199 | "source": [
200 | "#### Sampling"
201 | ]
202 | },
203 | {
204 | "cell_type": "code",
205 | "execution_count": null,
206 | "metadata": {
207 | "id": "zQSSO-S9v4Qr"
208 | },
209 | "outputs": [],
210 | "source": [
211 | "with model_s:\n",
212 | " idata_s = pm.sample(random_seed=42, target_accept=0.99)"
213 | ]
214 | },
215 | {
216 | "cell_type": "markdown",
217 | "metadata": {
218 | "id": "clM0U3APD3Yc"
219 | },
220 | "source": [
221 | "#### Checking the inference results with plot_trace"
222 | ]
223 | },
224 | {
225 | "cell_type": "code",
226 | "execution_count": null,
227 | "metadata": {
228 | "id": "dHkdwKit1A3W"
229 | },
230 | "outputs": [],
231 | "source": [
232 | "az.plot_trace(idata_s, compact=False)\n",
233 | "plt.tight_layout();"
234 | ]
235 | },
236 | {
237 | "cell_type": "markdown",
238 | "metadata": {
239 | "id": "zJVAPn7RECIc"
240 | },
241 | "source": [
242 | "#### Visualizing the distribution of delta_prob_s"
243 | ]
244 | },
245 | {
246 | "cell_type": "code",
247 | "execution_count": null,
248 | "metadata": {
249 | "id": "gZRIn1vF1I7t"
250 | },
251 | "outputs": [],
252 | "source": [
253 | "ax = az.plot_posterior(idata_s, var_names=['delta_prob_s'])\n",
254 | "xx, yy = ax.get_lines()[0].get_data()\n",
255 | "ax.fill_between(xx[xx<0], yy[xx<0]);"
256 | ]
257 | },
258 | {
259 | "cell_type": "markdown",
260 | "metadata": {
261 | "id": "lXP15bA1EN9M"
262 | },
263 | "source": [
264 | "#### Computing the probability that page A has the higher click-through rate"
265 | ]
266 | },
267 | {
268 | "cell_type": "code",
269 | "execution_count": null,
270 | "metadata": {
271 | "id": "VjWALJCcTDD-"
272 | },
273 | "outputs": [],
274 | "source": [
275 | "# Extract the delta_prob values from the sampling results\n",
276 | "delta_prob_s = idata_s['posterior'].data_vars['delta_prob_s']\n",
277 | "delta_prob_s_values = delta_prob_s.values.reshape(-1)\n",
278 | "\n",
279 | "# Number of samples in which delta_prob is negative\n",
280 | "n1_s = (delta_prob_s_values < 0).sum()\n",
281 | "\n",
282 | "# Total number of samples\n",
283 | "n_s = len(delta_prob_s_values)\n",
284 | "\n",
285 | "# Compute the ratio\n",
286 | "n1_rate_s = n1_s/n_s\n",
287 | "print(f'Suzuki case, probability that page A has the higher click-through rate: {n1_rate_s*100:.02f}%')"
288 | ]
289 | },
290 | {
291 | "cell_type": "markdown",
292 | "metadata": {
293 | "id": "4hIky2aPLPPj"
294 | },
295 | "source": [
296 | "### 6.1.4 Evaluating the A/B test for Yamada's case"
297 | ]
298 | },
299 | {
300 | "cell_type": "markdown",
301 | "metadata": {
302 | "id": "XmbI0RY7LhAQ"
303 | },
304 | "source": [
305 | "#### Model construction, sampling, and result analysis"
306 | ]
307 | },
308 | {
309 | "cell_type": "code",
310 | "execution_count": null,
311 | "metadata": {
312 | "id": "F-LqzuFftSHq"
313 | },
314 | "outputs": [],
315 | "source": [
316 | "model_y = pm.Model()\n",
317 | "\n",
318 | "with model_y:\n",
319 | "    # Use uniform distributions as the priors\n",
320 | "    p_y_a = pm.Uniform('p_y_a', lower=0.0, upper=1.0)\n",
321 | "    p_y_b = pm.Uniform('p_y_b', lower=0.0, upper=1.0)\n",
322 | "\n",
323 | "    # Define the probabilistic model with binomial distributions\n",
324 | "    # n: number of impressions, observed: number of clicks\n",
325 | "    obs_y_a = pm.Binomial('obs_y_a', p=p_y_a, n=1200, observed=60)\n",
326 | "    obs_y_b = pm.Binomial('obs_y_b', p=p_y_b, n=1600, observed=110)\n",
327 | "\n",
328 | "    # Define the difference of the two probabilities as a new random variable\n",
329 | "    delta_prob_y = pm.Deterministic('delta_prob_y', p_y_b - p_y_a)\n",
330 | "\n",
331 | "# Sampling\n",
332 | "with model_y:\n",
333 | "    idata_y = pm.sample(random_seed=42, target_accept=0.99)\n",
334 | "\n",
335 | "# Check the trace\n",
336 | "az.plot_trace(idata_y, compact=False)\n",
337 | "plt.tight_layout();\n",
338 | "plt.show()\n",
339 | "\n",
340 | "# Visualize the distribution of delta_prob_y\n",
341 | "ax = az.plot_posterior(idata_y, var_names=['delta_prob_y'])\n",
342 | "xx, yy = ax.get_lines()[0].get_data()\n",
343 | "ax.fill_between(xx[xx<0], yy[xx<0]);"
344 | ]
345 | },
346 | {
347 | "cell_type": "code",
348 | "execution_count": null,
349 | "metadata": {
350 | "id": "G-kXINLGvObv"
351 | },
352 | "outputs": [],
353 | "source": [
354 | "# Extract the delta_prob values from the sampling results\n",
355 | "delta_prob_y = idata_y['posterior'].data_vars['delta_prob_y']\n",
356 | "delta_prob_y_values = delta_prob_y.values.reshape(-1)\n",
357 | "\n",
358 | "# Number of samples in which delta_prob is negative\n",
359 | "n1_y = (delta_prob_y_values < 0).sum()\n",
360 | "\n",
361 | "# Total number of samples\n",
362 | "n_y = len(delta_prob_y_values)\n",
363 | "\n",
364 | "# Compute the ratio\n",
365 | "n1_rate_y = n1_y/n_y\n",
366 | "\n",
367 | "print(f'Yamada case, probability that page A has the higher click-through rate: {n1_rate_y*100:.02f}%')"
368 | ]
369 | },
370 | {
371 | "cell_type": "markdown",
372 | "metadata": {
373 | "id": "I0_cnO4fOoPL"
374 | },
375 | "source": [
376 | "### 6.1.5 Alternative solution using the probability distributions directly\n"
377 | ]
378 | },
379 | {
380 | "cell_type": "markdown",
381 | "metadata": {
382 | "id": "cXPwAFZbO7oE"
383 | },
384 | "source": [
385 | "#### Suzuki's case"
386 | ]
387 | },
388 | {
389 | "cell_type": "code",
390 | "execution_count": null,
391 | "metadata": {
392 | "id": "onseiB__Xppa"
393 | },
394 | "outputs": [],
395 | "source": [
396 | "# Page A: 2 successes, 38 failures\n",
397 | "alpha_a = 2 + 1\n",
398 | "beta_a = 38 + 1\n",
399 | "\n",
400 | "# Page B: 2 successes, 23 failures\n",
401 | "alpha_b = 2 + 1\n",
402 | "beta_b = 23 + 1\n",
403 | "\n",
404 | "model_s2 = pm.Model()\n",
405 | "with model_s2:\n",
406 | "    # Probabilistic model definition\n",
407 | "    # pm.Beta: beta distribution (the posterior of a uniform prior with a binomial likelihood)\n",
408 | "    # alpha: number of successes in the trials of interest + 1\n",
409 | "    # beta: number of failures in the trials of interest + 1\n",
410 | "    p_a = pm.Beta('p_a', alpha=alpha_a, beta=beta_a)\n",
411 | "    p_b = pm.Beta('p_b', alpha=alpha_b, beta=beta_b)\n",
412 | "\n",
413 | "    # Draw samples directly from the beta distributions defined above\n",
414 | "    samples_s2 = pm.sample_prior_predictive(random_seed=42, samples=10000)\n",
415 | "\n",
416 | "# Extract the sampled values\n",
417 | "p_a_samples_s2 = samples_s2['prior']['p_a'].values.reshape(-1)\n",
418 | "p_b_samples_s2 = samples_s2['prior']['p_b'].values.reshape(-1)\n",
419 | "delta_a_b_s2 = p_b_samples_s2 - p_a_samples_s2\n",
420 | "\n",
421 | "# Number of samples in which the difference is negative\n",
422 | "n1_s2 = (delta_a_b_s2 < 0).sum()\n",
423 | "\n",
424 | "# Total number of samples\n",
425 | "n_s2 = len(delta_a_b_s2)\n",
426 | "\n",
427 | "# Compute the ratio\n",
428 | "n1_rate_s2 = n1_s2/n_s2\n",
429 | "\n",
430 | "# Visualization\n",
431 | "ax = az.plot_dist(delta_a_b_s2)\n",
432 | "xx, yy = ax.get_lines()[0].get_data()\n",
433 | "ax.fill_between(xx[xx<0], yy[xx<0])\n",
434 | "\n",
435 | "# Plot title\n",
436 | "title = f'Suzuki case, probability that page A has the higher click-through rate (alternative): \\\n",
437 | "{n1_rate_s2*100:.02f}%'\n",
438 | "ax.set_title(title, fontsize=12);"
439 | ]
440 | },
441 | {
442 | "cell_type": "markdown",
443 | "metadata": {
444 | "id": "fKvSJbNdTe6M"
445 | },
446 | "source": [
447 | "#### Yamada's case"
448 | ]
449 | },
450 | {
451 | "cell_type": "code",
452 | "execution_count": null,
453 | "metadata": {
454 | "id": "lK-Y5pQLRqqA"
455 | },
456 | "outputs": [],
457 | "source": [
458 | "# Page A: 60 successes, 1140 failures\n",
459 | "alpha_a = 60 + 1\n",
460 | "beta_a = 1140 + 1\n",
461 | "\n",
462 | "# Page B: 110 successes, 1490 failures\n",
463 | "alpha_b = 110 + 1\n",
464 | "beta_b = 1490 + 1\n",
465 | "\n",
466 | "model_y2 = pm.Model()\n",
467 | "with model_y2:\n",
468 | "    # Probabilistic model definition\n",
469 | "    # pm.Beta: beta distribution (the posterior of a uniform prior with a binomial likelihood)\n",
470 | "    # alpha: number of successes in the trials of interest + 1\n",
471 | "    # beta: number of failures in the trials of interest + 1\n",
472 | "    p_a = pm.Beta('p_a', alpha=alpha_a, beta=beta_a)\n",
473 | "    p_b = pm.Beta('p_b', alpha=alpha_b, beta=beta_b)\n",
474 | "\n",
475 | "    # Draw samples directly from the beta distributions defined above\n",
476 | "    samples_y2 = pm.sample_prior_predictive(random_seed=42, samples=10000)\n",
477 | "\n",
478 | "# Extract the sampled values\n",
479 | "p_a_samples_y2 = samples_y2['prior']['p_a'].values.reshape(-1)\n",
480 | "p_b_samples_y2 = samples_y2['prior']['p_b'].values.reshape(-1)\n",
481 | "delta_a_b_y2 = p_b_samples_y2 - p_a_samples_y2\n",
482 | "\n",
483 | "# Number of samples in which the difference is negative\n",
484 | "n1_y2 = (delta_a_b_y2 < 0).sum()\n",
485 | "\n",
486 | "# Total number of samples\n",
487 | "n_y2 = len(delta_a_b_y2)\n",
488 | "\n",
489 | "# Compute the ratio\n",
490 | "n1_rate_y2 = n1_y2/n_y2\n",
491 | "\n",
492 | "# Visualization\n",
493 | "ax = az.plot_dist(delta_a_b_y2)\n",
494 | "xx, yy = ax.get_lines()[0].get_data()\n",
495 | "ax.fill_between(xx[xx<0], yy[xx<0])\n",
496 | "\n",
497 | "# Plot title\n",
498 | "title = f'Yamada case, probability that page A has the higher click-through rate (alternative): \\\n",
499 | "{n1_rate_y2*100:.02f}%'\n",
500 | "ax.set_title(title, fontsize=12);"
501 | ]
502 | },
503 | {
504 | "cell_type": "markdown",
505 | "metadata": {
506 | "id": "jE102bygTmYs"
507 | },
508 | "source": [
509 | "#### Version check"
510 | ]
511 | },
512 | {
513 | "cell_type": "code",
514 | "execution_count": null,
515 | "metadata": {
516 | "id": "b2skIcAaSbFC"
517 | },
518 | "outputs": [],
519 | "source": [
520 | "!pip install watermark | tail -n 1\n",
521 | "%load_ext watermark\n",
522 | "%watermark --iversions"
523 | ]
524 | },
525 | {
526 | "cell_type": "code",
527 | "execution_count": null,
528 | "metadata": {
529 | "id": "Jsl0TNsgVlA-"
530 | },
531 | "outputs": [],
532 | "source": []
533 | }
534 | ],
535 | "metadata": {
536 | "colab": {
537 | "provenance": [],
538 | "toc_visible": true
539 | },
540 | "kernelspec": {
541 | "display_name": "Python 3",
542 | "name": "python3"
543 | },
544 | "language_info": {
545 | "name": "python"
546 | }
547 | },
548 | "nbformat": 4,
549 | "nbformat_minor": 0
550 | }
--------------------------------------------------------------------------------
/notebooks/6_2_ベイス回帰モデルによる効果検証.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "source": [
6 | "Run the following command so that the ``plot_trace`` function does not emit warnings."
7 | ],
8 | "metadata": {
9 | "id": "IvEX8cm2gUo8"
10 | }
11 | },
12 | {
13 | "cell_type": "code",
14 | "source": [
15 | "!pip install --upgrade numba | tail -n 1"
16 | ],
17 | "metadata": {
18 | "id": "Q3bey6oygVkm"
19 | },
20 | "execution_count": null,
21 | "outputs": []
22 | },
23 | {
24 | "cell_type": "markdown",
25 | "metadata": {
26 | "id": "mfW1NghXXjph"
27 | },
28 | "source": [
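29 | "## 6.2 Effect evaluation with a Bayesian regression model\n",
30 | "Reference (PyMC tutorial): https://www.pymc.io/projects/docs/en/stable/learn/core_notebooks/pymc_overview.html#case-study-1-educational-outcomes-for-hearing-impaired-children\n",
31 | "\n",
32 | "The tutorial builds a more complex model that includes regularization. Here we simplify it to a plain linear regression model and examine the influence of each explanatory variable."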
33 | ]
34 | },
35 | {
36 | "cell_type": "markdown",
37 | "metadata": {
38 | "id": "5GnkdjsgeSem"
39 | },
40 | "source": [
41 | ""
42 | ]
43 | },
44 | {
45 | "cell_type": "markdown",
46 | "metadata": {
47 | "id": "44VPXX5bgtKT"
48 | },
49 | "source": [
50 | "### Common setup"
51 | ]
52 | },
53 | {
54 | "cell_type": "code",
55 | "execution_count": null,
56 | "metadata": {
57 | "id": "64-SAjF1gkms"
58 | },
59 | "outputs": [],
60 | "source": [
61 | "%matplotlib inline\n",
62 | "# Install the Japanese-font support library for matplotlib\n",
63 | "!pip install japanize-matplotlib | tail -n 1"
64 | ]
65 | },
66 | {
67 | "cell_type": "code",
68 | "execution_count": null,
69 | "metadata": {
70 | "id": "UKyvpbGBg3q4"
71 | },
72 | "outputs": [],
73 | "source": [
74 | "# Import libraries\n",
75 | "\n",
76 | "# NumPy\n",
77 | "import numpy as np\n",
78 | "# pyplot from Matplotlib\n",
79 | "import matplotlib.pyplot as plt\n",
80 | "# Japanese-font support for matplotlib\n",
81 | "import japanize_matplotlib\n",
82 | "# pandas\n",
83 | "import pandas as pd\n",
84 | "# Function for displaying DataFrames\n",
85 | "from IPython.display import display\n",
86 | "# seaborn\n",
87 | "import seaborn as sns\n",
88 | "# Adjust display options\n",
89 | "# NumPy print format\n",
90 | "np.set_printoptions(precision=3, floatmode='fixed')\n",
91 | "# Default font size for plots\n",
92 | "plt.rcParams[\"font.size\"] = 14\n",
93 | "# Figure size\n",
94 | "plt.rcParams['figure.figsize'] = (6, 6)\n",
95 | "# Show grid lines\n",
96 | "plt.rcParams['axes.grid'] = True\n",
97 | "# Float display precision in DataFrames\n",
98 | "pd.options.display.float_format = '{:.3f}'.format\n",
99 | "# Show all columns of a DataFrame\n",
100 | "pd.set_option(\"display.max_columns\",None)"
101 | ]
102 | },
103 | {
104 | "cell_type": "code",
105 | "execution_count": null,
106 | "metadata": {
107 | "id": "Iw1YsHsqiKOP"
108 | },
109 | "outputs": [],
110 | "source": [
111 | "import pymc as pm\n",
112 | "import arviz as az\n",
113 | "\n",
114 | "print(f\"Running on PyMC v{pm.__version__}\")\n",
115 | "print(f\"Running on ArViz v{az.__version__}\")"
116 | ]
117 | },
118 | {
119 | "cell_type": "markdown",
120 | "metadata": {
121 | "id": "pKjboyDdhAqu"
122 | },
123 | "source": [
124 | "### 6.2.1 Problem setup"
125 | ]
126 | },
127 | {
128 | "cell_type": "markdown",
129 | "metadata": {
130 | "id": "WyjYNL7khjm5"
131 | },
132 | "source": [
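133 | "In this section we use data provided by the Listening and Spoken Language Data Repository (LSL-DR) to build a linear regression model with Bayesian inference.\n",
134 | "\n",
135 | "LSL-DR is an international data repository for specialized education programs that support the development of spoken-language skills in children with hearing loss. It collects information on 5,748 children with hearing loss across 48 programs in 4 countries, in order to investigate the factors that affect their spoken-language development.\n",
136 | "\n",
137 | "\n",
138 | "The dataset contains the following fields.\n",
139 | "\n",
140 | "| Field | Description | Values |\n",
141 | "| ------------- | ---------------------------------------------- | ----------- |\n",
142 | "| score | Score on the ability test (target variable) | Integer, 0-144 |\n",
143 | "| male | Sex | 1/0 |\n",
144 | "| siblings | Number of siblings in the household | Non-negative integer |\n",
145 | "| family_inv | Index of family involvement | Integer, 0-4 |\n",
146 | "| non_english | Whether the primary language at home is not English | True/False |\n",
147 | "| prev_disab | Presence of a previous disability | 1/0 |\n",
148 | "| age_test | Age at testing | Integer, 48-59 |\n",
149 | "| non_severe_hl | Whether the hearing loss is not severe | 1/0 |\n",
150 | "| mother_hs | Whether the subject's mother has at least a high-school education | 1/0 |\n",
151 | "| early_ident | Whether the hearing impairment was identified by 3 months of age | True/False |\n",
152 | "| non_white | Non-white | True/False |\n",
153 | "\n",
154 | "\n",
155 | "The target variable (score) is one of the standardized test scores in the learning domain."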
156 | ]
157 | },
158 | {
159 | "cell_type": "markdown",
160 | "metadata": {
161 | "id": "j3bqcZ5oidnY"
162 | },
163 | "source": [
164 | "### 6.2.2 Loading the data"
165 | ]
166 | },
167 | {
168 | "cell_type": "code",
169 | "execution_count": null,
170 | "metadata": {
171 | "id": "JsG19bl8g7dQ"
172 | },
173 | "outputs": [],
174 | "source": [
175 | "# Load the LSL-DR data\n",
176 | "df = pd.read_csv(pm.get_data('test_scores.csv'), index_col=0)\n",
177 | "\n",
178 | "# Check the result\n",
179 | "display(df.head())"
180 | ]
181 | },
182 | {
183 | "cell_type": "markdown",
184 | "metadata": {
185 | "id": "a3R3OTtUmjvH"
186 | },
187 | "source": [
188 | "### 6.2.3 Inspecting the data"
189 | ]
190 | },
191 | {
192 | "cell_type": "markdown",
193 | "metadata": {
194 | "id": "EDLkaX5HmzrI"
195 | },
196 | "source": [
197 | "#### Distribution of the target variable score"
198 | ]
199 | },
200 | {
201 | "cell_type": "code",
202 | "execution_count": null,
203 | "metadata": {
204 | "id": "LgbugmoTipUH"
205 | },
206 | "outputs": [],
207 | "source": [
208 | "bins = np.arange(0, 150, 10)\n",
209 | "fig, ax = plt.subplots()\n",
210 | "df['score'].hist(bins=bins, align='left')\n",
211 | "plt.setp(ax.get_xticklabels(), rotation=90)\n",
212 | "plt.title('Distribution of the target variable score')\n",
213 | "plt.xticks(bins);"
214 | ]
215 | },
216 | {
217 | "cell_type": "markdown",
218 | "metadata": {
219 | "id": "gh3mRSKPm6zZ"
220 | },
221 | "source": [
222 | "#### Summary statistics"
223 | ]
224 | },
225 | {
226 | "cell_type": "code",
227 | "execution_count": null,
228 | "metadata": {
229 | "id": "i9BSIMn8isEm"
230 | },
231 | "outputs": [],
232 | "source": [
233 | "df.describe()"
234 | ]
235 | },
236 | {
237 | "cell_type": "markdown",
238 | "metadata": {
239 | "id": "d4-nsEaHnETb"
240 | },
241 | "source": [
242 | "#### Checking the row count and missing values"
243 | ]
244 | },
245 | {
246 | "cell_type": "code",
247 | "execution_count": null,
248 | "metadata": {
249 | "id": "usmq6yUGjTJy"
250 | },
251 | "outputs": [],
252 | "source": [
253 | "# Check the number of rows\n",
254 | "print(f'Number of rows: {len(df)}\\n')\n",
255 | "\n",
256 | "# Check for missing values\n",
257 | "print(df.isnull().sum())"
258 | ]
259 | },
260 | {
261 | "cell_type": "markdown",
262 | "metadata": {
263 | "id": "kwypJMkVkkry"
264 | },
265 | "source": [
266 | "### 6.2.4 Data preparation"
267 | ]
268 | },
269 | {
270 | "cell_type": "markdown",
271 | "metadata": {
272 | "id": "e3j_p0P1nS8K"
273 | },
274 | "source": [
275 | "#### Removing missing values"
276 | ]
277 | },
278 | {
279 | "cell_type": "code",
280 | "execution_count": null,
281 | "metadata": {
282 | "id": "ACHpxyaEjy5s"
283 | },
284 | "outputs": [],
285 | "source": [
286 | "# Drop rows with missing values and convert every column to float\n",
287 | "df1 = df.dropna().astype(float)\n",
288 | "\n",
289 | "# Check the number of rows\n",
290 | "print(f'Number of rows: {len(df1)}')"
291 | ]
292 | },
293 | {
294 | "cell_type": "markdown",
295 | "metadata": {
296 | "id": "MIfjVv1JneCl"
297 | },
298 | "source": [
299 | "#### Splitting into target variable y and explanatory variables X"
300 | ]
301 | },
302 | {
303 | "cell_type": "code",
304 | "execution_count": null,
305 | "metadata": {
306 | "id": "AaLZGSspkcK9"
307 | },
308 | "outputs": [],
309 | "source": [
310 | "y = df1.pop(\"score\")\n",
311 | "X = df1.copy()\n",
312 | "\n",
313 | "# Check X\n",
314 | "display(X.head())"
315 | ]
316 | },
317 | {
318 | "cell_type": "markdown",
319 | "metadata": {
320 | "id": "dka6ajgMnl-o"
321 | },
322 | "source": [
323 | "#### Standardizing X"
324 | ]
325 | },
326 | {
327 | "cell_type": "code",
328 | "execution_count": null,
329 | "metadata": {
330 | "id": "hVfZQT4KlGhi"
331 | },
332 | "outputs": [],
333 | "source": [
334 | "X -= X.mean()\n",
335 | "X /= X.std()\n",
336 | "\n",
337 | "# Check the result\n",
338 | "display(X.head())"
339 | ]
340 | },
341 | {
342 | "cell_type": "markdown",
343 | "metadata": {
344 | "id": "dgmMK-F9nr26"
345 | },
346 | "source": [
347 | "#### Defining the variables needed to build the Bayesian model"
348 | ]
349 | },
350 | {
351 | "cell_type": "code",
352 | "execution_count": null,
353 | "metadata": {
354 | "id": "2vZARmUiti9t"
355 | },
356 | "outputs": [],
357 | "source": [
358 | "# Number of rows and number of columns\n",
359 | "N, D = X.shape\n",
360 | "\n",
361 | "# Store the list of column names in columns\n",
362 | "columns = X.columns.values\n",
363 | "\n",
364 | "# Check the result\n",
365 | "print(f'N: {N} (number of rows)\\n')\n",
366 | "print(f'D: {D} (number of explanatory variables)\\n')\n",
367 | "print(f'Column names: {columns}')"
368 | ]
369 | },
370 | {
371 | "cell_type": "markdown",
372 | "metadata": {
373 | "id": "mJJgPCHlmmSe"
374 | },
375 | "source": [
376 | "### 6.2.5 Probabilistic model definition"
377 | ]
378 | },
379 | {
380 | "cell_type": "code",
381 | "execution_count": null,
382 | "metadata": {
383 | "id": "xYuuSAdIloc3"
384 | },
385 | "outputs": [],
386 | "source": [
387 | "# Register the list of explanatory variables as the coordinate 'predictors'\n",
388 | "model1 = pm.Model(coords={'predictors': columns})\n",
389 | "\n",
390 | "with model1:\n",
391 | "    # X, a vector in simple regression, is now a matrix. Note that it is transposed\n",
392 | "    X_data = pm.ConstantData('X_data', X.T)\n",
393 | "\n",
394 | "    # y is the target variable of the regression model\n",
395 | "    y_data = pm.ConstantData('y_data', y)\n",
396 | "\n",
397 | "    # alpha, a scalar in simple regression, becomes a vector in multiple regression\n",
398 | "    # Its length is set indirectly through dims='predictors' (enabled by the coords argument above)\n",
399 | "    alpha = pm.Normal('alpha', mu=0.0, sigma=10.0, dims='predictors')\n",
400 | "\n",
401 | "    # beta and epsilon are the same as in simple regression (the book text explains the choice of parameter values)\n",
402 | "    beta = pm.Normal('beta', mu=100.0, sigma=25.0)\n",
403 | "    epsilon = pm.HalfNormal('epsilon', sigma=25.0)\n",
404 | "\n",
405 | "    # In computing mu, the multiplication used in simple regression becomes an inner product\n",
406 | "    mu = pm.Deterministic('mu', alpha @ X_data + beta)\n",
407 | "\n",
408 | "    # The normal distribution is defined as in the simple regression of Section 5.2\n",
409 | "    obs = pm.Normal('obs', mu=mu, sigma=epsilon, observed=y_data)\n",
410 | "\n",
411 | "g = pm.model_to_graphviz(model1)\n",
412 | "display(g)"
413 | ]
414 | },
415 | {
416 | "cell_type": "markdown",
417 | "metadata": {
418 | "id": "Gh-YIGW4xVqm"
419 | },
420 | "source": [
421 | "### 6.2.6 Sampling and analyzing the results"
422 | ]
423 | },
424 | {
425 | "cell_type": "markdown",
426 | "metadata": {
427 | "id": "7UVoAYaaJRRu"
428 | },
429 | "source": [
430 | "#### Sampling and analyzing the results with plot_trace"
431 | ]
432 | },
433 | {
434 | "cell_type": "code",
435 | "execution_count": null,
436 | "metadata": {
437 | "id": "0ogASrZAwJWy"
438 | },
439 | "outputs": [],
440 | "source": [
441 | "with model1:\n",
442 | " idata1 = pm.sample(random_seed=42, target_accept=0.95)\n",
443 | "\n",
444 | "# Check the inference results with plot_trace\n",
445 | "az.plot_trace(idata1, var_names=['alpha', 'beta', 'epsilon'], compact=False)\n",
446 | "plt.tight_layout();"
447 | ]
448 | },
449 | {
450 | "cell_type": "markdown",
451 | "metadata": {
452 | "id": "IpIKPWPqJkmL"
453 | },
454 | "source": [
455 | "#### Summary statistics of the sampling results"
456 | ]
457 | },
458 | {
459 | "cell_type": "code",
460 | "execution_count": null,
461 | "metadata": {
462 | "id": "3Q5d5AY2249Q"
463 | },
464 | "outputs": [],
465 | "source": [
466 | "summary1 = az.summary(idata1, var_names=['alpha'])\n",
467 | "display(summary1)"
468 | ]
469 | },
470 | {
471 | "cell_type": "markdown",
472 | "metadata": {
473 | "id": "mIxZ821fRM5O"
474 | },
475 | "source": [
476 | "#### Checking each variable's contribution with plot_forest"
477 | ]
478 | },
479 | {
480 | "cell_type": "code",
481 | "execution_count": null,
482 | "metadata": {
483 | "id": "d3yNOWxn0jbF"
484 | },
485 | "outputs": [],
486 | "source": [
487 | "az.plot_forest(idata1, combined=True, var_names=['alpha']);"
488 | ]
489 | },
490 | {
491 | "cell_type": "markdown",
492 | "metadata": {
493 | "id": "kEXNOZHARYxZ"
494 | },
495 | "source": [
496 | "#### plot_forest without the combined option"
497 | ]
498 | },
499 | {
500 | "cell_type": "code",
501 | "execution_count": null,
502 | "metadata": {
503 | "id": "gzG5D3Dbx95z"
504 | },
505 | "outputs": [],
506 | "source": [
507 | "az.plot_forest(idata1, var_names=['alpha']);"
508 | ]
509 | },
510 | {
511 | "cell_type": "markdown",
512 | "metadata": {
513 | "id": "l_NxOrYJvHsk"
514 | },
515 | "source": [
516 | "### Column: the probabilistic model from the tutorial"
517 | ]
518 | },
519 | {
520 | "cell_type": "markdown",
521 | "metadata": {
522 | "id": "AbrseeJTvu-8"
523 | },
524 | "source": [
525 | "#### Probabilistic model definition"
526 | ]
527 | },
528 | {
529 | "cell_type": "code",
530 | "execution_count": null,
531 | "metadata": {
532 | "id": "rO7ETJPNvdXl"
533 | },
534 | "outputs": [],
535 | "source": [
536 | "# Define D0\n",
537 | "D0 = int(D / 2)\n",
538 | "\n",
539 | "# Register the list of explanatory variables as the coordinate 'predictors'\n",
540 | "model2 = pm.Model(coords={'predictors': columns})\n",
541 | "\n",
542 | "with model2:\n",
543 | "\n",
544 | "    # X, a vector in simple regression, is now a matrix. Note that it is transposed\n",
545 | "    X_data = pm.ConstantData('X_data', X.T)\n",
546 | "\n",
547 | "    # y is the target variable of the regression model\n",
548 | "    y_data = pm.ConstantData('y_data', y)\n",
549 | "\n",
550 | "    # Error distribution: sigma -> epsilon (renamed only)\n",
551 | "    epsilon = pm.HalfNormal('epsilon', sigma=25.0)\n",
552 | "\n",
553 | "    # Distribution of the linear coefficients: beta -> alpha (the definition itself is also completely changed)\n",
554 | "\n",
555 | "    # Global shrinkage of the prior\n",
556 | "    tau = pm.HalfStudentT(\"tau\", 2, D0 / (D - D0) * epsilon / np.sqrt(N))\n",
557 | "\n",
558 | "    # Local shrinkage of the prior\n",
559 | "    lam = pm.HalfStudentT(\"lam\", 2, dims=\"predictors\")\n",
560 | "    c2 = pm.InverseGamma(\"c2\", 1, 0.1)\n",
561 | "    z = pm.Normal(\"z\", 0.0, 1.0, dims=\"predictors\")\n",
562 | "\n",
563 | "    alpha = pm.Deterministic(\n",
564 | "        \"alpha\", z * tau * lam * pm.math.sqrt(\n",
565 | "            c2 / (c2 + tau**2 * lam**2)), dims=\"predictors\")\n",
566 | "\n",
567 | "    # Intercept: beta0 -> beta (renamed only)\n",
568 | "    beta = pm.Normal(\"beta\", mu=100.0, sigma=25.0)\n",
569 | "\n",
570 | "    # Mean of the normal distribution, mu: np.dot is replaced by @, but the logic is the same\n",
571 | "    mu = pm.Deterministic('mu', alpha @ X_data + beta)\n",
572 | "\n",
573 | "    # Distribution of the observations: scores -> obs (renamed only)\n",
574 | "    obs = pm.Normal(\"obs\", mu, epsilon, observed=y_data)\n",
575 | "\n",
576 | "# Visualize the model\n",
577 | "g = pm.model_to_graphviz(model2)\n",
578 | "display(g)"
579 | ]
580 | },
581 | {
582 | "cell_type": "markdown",
583 | "metadata": {
584 | "id": "i5ZqZqnh2T4k"
585 | },
586 | "source": [
587 | "#### Sampling and checking the inference results"
588 | ]
589 | },
590 | {
591 | "cell_type": "code",
592 | "execution_count": null,
593 | "metadata": {
594 | "id": "JtNJ2g0ev4TO"
595 | },
596 | "outputs": [],
597 | "source": [
598 | "with model2:\n",
599 | " idata2 = pm.sample(random_seed=42, target_accept=0.95)\n",
600 | "\n",
601 | "# Check the inference results with plot_trace\n",
602 | "az.plot_trace(idata2, var_names=['alpha', 'beta', 'epsilon'], compact=False)\n",
603 | "plt.tight_layout();"
604 | ]
605 | },
606 | {
607 | "cell_type": "markdown",
608 | "metadata": {
609 | "id": "YuSQrLJQ2xCs"
610 | },
611 | "source": [
612 | "#### Checking each variable's contribution with plot_forest"
613 | ]
614 | },
615 | {
616 | "cell_type": "code",
617 | "execution_count": null,
618 | "metadata": {
619 | "id": "enV-c9OF2XDv"
620 | },
621 | "outputs": [],
622 | "source": [
623 | "az.plot_forest(idata2, combined=True, var_names=['alpha']);"
624 | ]
625 | },
626 | {
627 | "cell_type": "markdown",
628 | "metadata": {
629 | "id": "Kbeq8nKofJWK"
630 | },
631 | "source": [
632 | "#### Statistical summary of the sampling results"
633 | ]
634 | },
635 | {
636 | "cell_type": "code",
637 | "execution_count": null,
638 | "metadata": {
639 | "id": "xvOgJLxa21JJ"
640 | },
641 | "outputs": [],
642 | "source": [
643 | "summary2 = az.summary(idata2, var_names=['alpha'])\n",
644 | "display(summary2)"
645 | ]
646 | },
647 | {
648 | "cell_type": "markdown",
649 | "metadata": {
650 | "id": "yIpoduRaWSsJ"
651 | },
652 | "source": [
653 | "#### Checking versions"
654 | ]
655 | },
656 | {
657 | "cell_type": "code",
658 | "execution_count": null,
659 | "metadata": {
660 | "id": "Tn17A0Iu3hIV"
661 | },
662 | "outputs": [],
663 | "source": [
664 | "!pip install watermark | tail -n 1\n",
665 | "%load_ext watermark\n",
666 | "%watermark --iversions"
667 | ]
668 | },
669 | {
670 | "cell_type": "code",
671 | "execution_count": null,
672 | "metadata": {
673 | "id": "EPeyuEI6WcFC"
674 | },
675 | "outputs": [],
676 | "source": []
677 | }
678 | ],
679 | "metadata": {
680 | "colab": {
681 | "provenance": [],
682 | "toc_visible": true
683 | },
684 | "kernelspec": {
685 | "display_name": "Python 3",
686 | "name": "python3"
687 | },
688 | "language_info": {
689 | "name": "python"
690 | }
691 | },
692 | "nbformat": 4,
693 | "nbformat_minor": 0
694 | }
--------------------------------------------------------------------------------
/refs/3クラス潜在変数モデル.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/makaishi2/python_bayes_intro/1643281f7ab360443a66e3be943a7d562bf4dc9e/refs/3クラス潜在変数モデル.pdf
--------------------------------------------------------------------------------
/refs/errors.md:
--------------------------------------------------------------------------------
1 | ### Errata
2 |
3 |
4 | #### 1st edition, 1st printing
5 | |Chapter|Page|Content|Notes|Last updated|
6 | |:--|---|:--|:--|:--|
7 | |Preface|ⅵ|Line 2<br>(Wrong) 重要なな<br>(Correct) 重要な||2023-11-27|
8 | |Ch. 1|p.2|2nd line from the bottom<br>(Wrong) 対応がづく<br>(Correct) 対応がつく||2023-11-27|
9 |
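10 | #### 1st edition, 1st-2nd printings
11 | |Chapter|Page|Content|Notes|Last updated|
12 | |:--|---|:--|:--|:--|
13 | |Preface|v|Line 17<br>(Wrong) 活用されてるかを含めてを紹介し<br>(Correct) 活用されているかを含めて紹介し||2023-12-12|
14 | |Ch. 1|p.3|4th line from the bottom<br>(Wrong) 意味を理解することでなく<br>(Correct) 意味を理解することではなく||2023-12-12|
15 | |Ch. 1|p.19|Caption of Code 1.11<br>(Wrong) (コード1.3からコード1.6の一部を再掲)<br>(Correct) (コード1.5、1.6、1.8の一部を再掲)||2023-12-18|
16 | |Ch. 1|p.19|Code 1.11, line 7<br>(Wrong) ``x_samples = samples['prior']['x'].values``<br>(Correct) ``x_samples = prior_samples['prior']['x'].values``||2023-12-18|
17 | |Ch. 2|p.50|Line 3<br>(Wrong) に対応して<br>(Correct) に対応した||2023-12-12|
18 | |Ch. 4|p.82|Lines 7-8<br>(Before) "2.4. Diagnosing Numerical Inference"<br>(After) |The text inside the double quotes is not needed as part of the URL, so it is removed|2023-12-12|
19 | |Ch. 5|p.97|Line 6<br>(Wrong) なりますが。<br>(Correct) なりますが、||2023-12-12|
20 | |Ch. 5|p.101|Last line of Code 5.1.13<br>(Wrong) y_list = norm(x_list, mu_mean, sigma_mean)<br>(Correct) y_list = norm(x_list, mu_mean1, sigma_mean1)||2023-11-27|
21 | |Ch. 5|p.102|Line 2<br>(Wrong) mu_meanとsigma_mean<br>(Correct) mu_mean1とsigma_mean1||2023-11-27|
22 | |Ch. 5|p.103|Code 5.1.15, line 2<br>(Wrong) X_less = x_result[:5]<br>(Correct) X_less = X[:5]||2023-11-27|
23 | |Ch. 6|p.209|Code 6.3.12, line 13<br>(Wrong) x1 = summary_theta1['mean'].values<br>(Correct) x1 = summary_theta1['mean']|Fixed because, with the implementation as printed, the calculation on line 16 becomes the sample standard deviation (標本標準偏差)|2023-12-12|
24 | |Ch. 6|p.209|Code 6.3.12, line 18<br>(Wrong) df_sum1['能力値'] = x1<br>(Correct) df_sum1['能力値'] = x1.values|Fixed because, with the implementation as printed, the calculation on line 16 becomes the sample standard deviation (標本標準偏差)|2023-12-12|
25 | |Ch. 6|p.209|Execution result<br>(Before the code fix)<br>(image omitted)<br>(After the code fix)<br>(image omitted)|The values in the 能力値 column change slightly as a result of the code fix above|2023-12-12|
26 | |Ch. 6|p.211|Code 6.3.15, 2nd line from the bottom<br>(As published) w3 = (w1 * b_mean1).sum(axis=1)/w2[0]<br>(Current) w3 = (w1 * b_mean1).sum(axis=1)/w2.iloc[0]|A warning is known to appear with newer pandas versions, so the code was fixed ahead of the upgrade|2023-12-12|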
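
The last row above swaps `w2[0]` for `w2.iloc[0]`. As background that the errata do not spell out: recent pandas releases warn when an integer key is used as a position on a Series whose index is not integer-based, and `.iloc` is the explicit position-based accessor. The sketch below uses a made-up Series (the name `w2` is borrowed only for readability; the values and index labels are hypothetical).

```python
import pandas as pd

# Hypothetical Series with a non-integer index, for illustration only
w2 = pd.Series([10.0, 20.0, 30.0], index=["a", "b", "c"])

# Position-based access: always the first element, whatever the labels are
first_by_position = w2.iloc[0]

# Label-based access: look the value up by its index label
first_by_label = w2.loc["a"]

print(first_by_position, first_by_label)
```

Using `.iloc` for positions and `.loc` for labels keeps the intent explicit and does not depend on how a bare `w2[0]` is interpreted by a given pandas version.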
27 |
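The correction to Code 6.3.12 turns on which standard deviation is computed. As background (an explanation added here, not part of the errata): NumPy's `ndarray.std` divides by n by default (`ddof=0`), while pandas' `Series.std` divides by n-1 by default (`ddof=1`), so calling `.std()` on `x1.values` and on `x1` gives different numbers. The sketch below reproduces the two divisors with the standard library only; the data values are made up.

```python
import math
import statistics

# Hypothetical scores, for illustration only
scores = [48.0, 52.0, 55.0, 61.0, 64.0]

# Divisor n: what NumPy's ndarray.std() computes by default (ddof=0)
std_n = statistics.pstdev(scores)

# Divisor n-1: what pandas' Series.std() computes by default (ddof=1)
std_n_minus_1 = statistics.stdev(scores)

# The two differ by a factor of sqrt(n / (n - 1))
n = len(scores)
ratio = std_n_minus_1 / std_n
expected_ratio = math.sqrt(n / (n - 1))

print(std_n, std_n_minus_1, ratio, expected_ratio)
```

For a small number of examinees the difference is visible in the standardized scores, which is consistent with the note that the 能力値 column changes slightly after the fix.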
28 | #### 1st edition, 1st-3rd printings
29 | |Chapter|Page|Content|Notes|Last updated|
30 | |:--|---|:--|:--|:--|
31 | |Ch. 4|p.74|Code 4.6, line 1<br>(Wrong) pm.model_to_graphviz(model)<br>(Correct) pm.model_to_graphviz(model1)||2024-01-06|
32 | |Ch. 6|p.164|Table 6.1.1, row 5, rightmost column<br>(Wrong) 1500<br>(Correct) 1600||2024-01-13|
33 |
34 | #### 1st edition, 1st-4th printings
35 | |Chapter|Page|Content|Notes|Last updated|
36 | |:--|---|:--|:--|:--|
37 | |Ch. 1|p.17|Line 6<br>(Wrong) samples['prior']<br>(Correct) prior_samples['prior']||2024-02-18|
38 | |Ch. 2|p.25|Last line<br>(Wrong) x_samplesを用いて<br>(Correct) x_samples1を用いて||2024-02-18|
39 | |Ch. 4|p.83|Code 4.13, line 1<br>(Wrong) model_to_graphviz(model)<br>(Correct) model_to_graphviz(model2)||2024-02-18|
40 | |Ch. 5|p.100|Line 1<br>(Wrong) 4章でも説明したとおり<br>(Correct) 1.5.1項でも説明したとおり||2024-02-18|
41 | |Ch. 5|p.101|Line 3<br>(Wrong) mu_meanとsigma_mean<br>(Correct) mu_mean1とsigma_mean1||2024-02-18|
42 | |Ch. 5|p.155|Line 1<br>(Wrong) コード5.4.15の4行目<br>(Correct) コード5.4.15の5行目||2024-02-18|
43 | |Ch. 6|p.210|Code 6.3.13, line 1<br>(Wrong) df_sum['偏差値'], df_sum['能力値']<br>(Correct) df_sum1['偏差値'], df_sum1['能力値']||2024-02-18|
44 |
45 | #### 1st edition, 1st-5th printings
46 | |Chapter|Page|Content|Notes|Last updated|
47 | |:--|---|:--|:--|:--|
48 | |Ch. 5|p.152|Code 5.4.13, lines 28 and 29<br>(Wrong)<br>``ax.get_lines()[0].set_label('KDE versicolor')``<br>``ax.get_lines()[1].set_label('KDE virginica')``<br>(Correct)<br>``ax.get_lines()[1].set_label('KDE versicolor')``<br>``ax.get_lines()[0].set_label('KDE virginica')``||2024-03-11|
49 | |Ch. 5|p.152|Output of Code 5.4.13 (graph)<br>(Wrong)<br>(image omitted)<br>(Correct)<br>(image omitted)||2024-03-11|
50 |
51 |
52 |
53 |
54 | [Back to the main page](../README.md)
55 |
--------------------------------------------------------------------------------
/refs/faqs.md:
--------------------------------------------------------------------------------
1 | ### FAQ
2 |
3 |
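4 | |Chapter-section|Page|Question|Answer|Last updated|
5 | |---|---|---|---|---|
6 | |Ch. 4|p.79 Code 4.9|Running the code shows the following warning (the same happens in the Chapter 5 and Chapter 6 exercises)<br>NumbaDeprecationWarning: The 'nopython' keyword argument was not supplied to the 'numba.jit' decorator.|The warning appears because the numba version on Colab is old. Running the following command at the top of the notebook resolves it, so the notebooks on GitHub have been updated<br>!pip install --upgrade numba \| tail -n 1|2024-03-09|
7 | |Ch. 5|p.148 Code 5.4.9|I called the sample function with the following arguments, but label switching did not occur. Why?<br>``pm.sample(target_accept=0.99, random_seed=42)``|A detailed explanation is given in a Qiita article; please refer to it.<br>[Qiita link](https://qiita.com/makaishi2/items/8dae04e51a79ae456995)|2024-02-08|
8 | |Ch. 5|p.152 Code 5.4.13, lines 19 and 20|I don't understand what ``* delta / n_components`` means.|The essential question is what scale factor the normal density function must be multiplied by so that its height lines up with the histogram.<br>To keep things simple, consider the simplest continuous distribution, the uniform distribution on [0, 1].<br>Note also that the "height" of a histogram bar is the proportion of samples that fall in that bin.<br>When bins is 1, that is, when there is only one bin, all samples fall into that single bin. The vertical axis is a proportion, so the value is 1.<br>The probability density function of the uniform distribution on [0, 1] is ``f(x) = 1``, so its value is 1 and no scale adjustment is needed in this case.<br>What about bins=10? Each bin, such as [0, 0.1], then has width 0.1.<br>Because the vertical axis is a proportion, the height is 1/10 of the previous value, namely 0.1. The scale-adjustment factor is therefore also 0.1.<br>For the histogram on p.152 the bin width is 0.1, and this is what the variable delta represents.<br>In our case, samples from two probability distributions are mixed. This is taken into account next.<br>Again to simplify the problem, suppose that random variables following the uniform distribution on [0, 1] and random variables following the uniform distribution on [3, 4] are mixed in a 1-to-1 ratio. Compared with the case where every sample follows the uniform distribution on [0, 1], the proportion (height) is halved.<br>This is why ``delta / n_components`` divides delta (= 0.1) by 2.<br>The same thought experiment shows how to compute the factor when the mixing ratio is not 1 to 1.<br>For example, if the counts of group A and group B are in the ratio 2 to 3, the correction scale for group A's distribution is ``delta * 2 / 5``.|2024-02-08|
9 | |Ch. 6|p.217, line 3|I don't understand the meaning of 「2パラメータ・ロジスティックモデルでは自由度1が残っており」 ("in the two-parameter logistic model, one degree of freedom remains").|In Table 6.3.1 on p.198, consider doubling the values of a, halving the values of b, and also halving the ability values. Equation (6.3.2) on p.197 then shows that the probability of a correct answer is unchanged. The existence of other combinations of a, b, and θ that keep a fixed relationship to one another and give the same probability of a correct answer is what it means for one degree of freedom to remain.|2024-01-08|
10 | |(General)|(General)|Can the exercises be run in an Anaconda environment on a local PC?|So that beginners with no programming experience can get started, the book gives detailed step-by-step instructions that assume Google Colab. Readers who are reasonably proficient in programming and Linux can also run the exercises locally, but the issues below are to be expected, and you need to be able to handle them on your own.<br>1. The links on the support page are URLs that assume integration with Google Colab. To run locally, either download the whole repository as a zip file through the GitHub UI, or clone the repository to your machine with the git clone command.<br>2. If a library is missing, you must install it yourself. Doing so may cause version conflicts between libraries.|2023-12-10|
11 | |(General)|(General)|Which book would you recommend as the next step after mastering this one?|Many books on Bayesian inference have been published and the author is not familiar with all of them, but the book below belongs to the same Kodansha series and is a suitable next step.<br>[『Pythonではじめるベイズ機械学習入門』](https://www.amazon.co.jp/dp/406527978X)|2023-12-10|
12 | |(General)|(General)|I don't have the prerequisite knowledge listed in the preface. How can I fill the gaps?|Most of the prerequisites for this book are covered by the author's other books. At the risk of self-promotion, the mapping between the required knowledge and those books is summarized below for reference.<br>* Python syntax: [最短コースでわかる Pythonプログラミングとデータ分析](https://www.amazon.co.jp/dp/4296201123) Chapter 1<br>* NumPy basics: [最短コースでわかる Pythonプログラミングとデータ分析](https://www.amazon.co.jp/dp/4296201123) Section 2.2<br>* Data preprocessing and machine learning in general: [Pythonで儲かるAIをつくる](https://www.amazon.co.jp/dp/4296106961/) Chapters 1 to 3, Sections 4.1-4.3, 5.1, and 5.2<br>* Object-oriented programming: [最短コースでわかる PyTorch &深層学習プログラミング](https://www.amazon.co.jp/dp/4296110322) Section 1.5|2023-12-10|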
13 |
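The thought experiment in the answer about ``* delta / n_components`` can be checked numerically. The sketch below is an illustration written for this FAQ, not code from the book: it mixes samples from the uniform distributions on [0, 1] and [3, 4] in a 1-to-1 ratio and compares the observed proportion in one bin with the density multiplied by `delta / n_components`. The helper name `scale_factor` and the sample size are assumptions made for the example.

```python
import random

def scale_factor(delta, weight):
    """Factor that converts a component's density into an expected bin proportion."""
    return delta * weight

rng = random.Random(42)
n_samples = 100_000
delta = 0.1        # bin width
n_components = 2   # two components mixed 1:1

# Mix U[0, 1] and U[3, 4] with equal probability
samples = [
    rng.random() if rng.random() < 0.5 else 3.0 + rng.random()
    for _ in range(n_samples)
]

# Observed proportion of all samples that fall in the bin [0.2, 0.3)
observed = sum(0.2 <= x < 0.3 for x in samples) / n_samples

# The density of U[0, 1] is 1, so the predicted bar height is 1 * delta / n_components
predicted = 1.0 * scale_factor(delta, 1 / n_components)

print(observed, predicted)
```

For a 2-to-3 mix, the same helper gives `scale_factor(delta, 2 / 5)` for group A, matching the ``delta * 2 / 5`` in the answer.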
14 |
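The remaining degree of freedom in the two-parameter logistic model can also be seen numerically. The sketch below assumes the standard 2PL form, p = 1 / (1 + exp(-a(θ - b))); the book's Equation (6.3.2) is not reproduced in this FAQ, so treat that form, the function name `correct_prob`, and the parameter values as assumptions.

```python
import math

def correct_prob(a, b, theta):
    """Probability of a correct answer under the assumed 2PL form."""
    return 1.0 / (1.0 + math.exp(-a * (theta - b)))

# Hypothetical item parameters and ability value
a, b, theta = 1.5, 0.4, 1.0

p_original = correct_prob(a, b, theta)

# Rescale as in the answer: a doubled, b and theta halved
p_rescaled = correct_prob(2 * a, b / 2, theta / 2)

print(p_original, p_rescaled)
```

Both calls return the same probability, because the product a(θ - b) is unchanged by the rescaling. The data alone therefore cannot pin down the overall scale of a, b, and θ.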
15 |
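Related to the label-switching question above: the sample notebook `A_3クラス潜在変数モデル.ipynb` in this repository builds its three means as `mu0`, `mu0 + delta0`, and `mu1 + delta1` from non-negative (half-normal) terms, which forces the means into a fixed order. The sketch below is plain Python with made-up draws, not PyMC code. It only illustrates why such a construction always yields ordered means; the function name `ordered_means` is hypothetical.

```python
import random
from itertools import accumulate

rng = random.Random(0)

def ordered_means(n_components, scale=10.0):
    """Build non-decreasing means from non-negative increments."""
    # abs(gauss(0, scale)) mimics a draw from a half-normal distribution
    increments = [abs(rng.gauss(0.0, scale)) for _ in range(n_components)]
    # Running sums: mu0, mu0 + delta0, mu0 + delta0 + delta1, ...
    return list(accumulate(increments))

for _ in range(3):
    mus = ordered_means(3)
    print(mus, mus == sorted(mus))
```

Because each component's mean can only sit above the previous one, the labels 0, 1, 2 cannot be permuted between draws, which is one common way to reduce label switching.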
16 | [Back to the main page](../README.md)
17 |
--------------------------------------------------------------------------------
/refs/how-to-run.md:
--------------------------------------------------------------------------------
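1 | ## How to run the exercise notebooks
2 |
3 | All of the exercise notebooks in this book fully support Google Colab. If you already have a Gmail account, the steps below let you **run the book's exercise code with zero environment setup**.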
4 |
5 |
6 |
7 | ### Prerequisites
8 |
9 | Obtain a Gmail address in advance and stay logged in to Gmail in another tab.
10 |
11 |
12 | ### 1. Display the notebook list
13 |
14 |
15 | Click the link below to display the notebook list.
16 |
17 |
18 | [Notebook list](../notebooks/)
19 |
20 |
21 |
22 | Holding the Control key while clicking the link opens it in a new tab.
23 |
24 |
25 |
26 | ### 2. Select the notebook you want to run
27 |
28 | In the list on the left, click the link for the notebook you want to run. (You can also launch it from the list on the right.)
29 |
30 |
31 |
32 |

33 |
34 |
35 |
36 | The steps below use 「2章_よく利用される確率分布.ipynb」 as the selected notebook.
37 |
38 |
39 | ### 3. Launch Google Colab
40 |
41 | Click the **Open in Colab** icon at the top right of the screen.
42 |
43 |
44 |

45 |
46 |
47 |
48 |
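49 | ### 4. Copy the notebook
50 | If an icon that looks like the digit 8 turned on its side appears at the top left of the screen, Google Colab is ready to use.
51 |
52 | As it stands, however, you cannot save changes to the notebook, so copy the notebook to your own home directory.
53 | To do so, click **Copy to Drive** (ドライブにコピー) at the top right of the screen.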
54 |
55 |
56 |

57 |
58 |
59 |
60 |
61 | ### 5. Run the notebook
62 |
63 | If the file name at the top of the screen reads 「xxx.ipynbのコピー」 ("Copy of xxx.ipynb"), the copy to your home directory is complete.
64 |
65 |
66 |
67 |
68 |

69 |
70 |
71 |
72 |
73 |
74 |
75 | In this state, pressing Shift + Enter (holding Shift while pressing Enter) runs the selected cell. Alternatively, choosing "Runtime" and then "Run all" from the menu runs every cell in one go.
76 |
77 | [Back to the main page](../README.md)
78 |
--------------------------------------------------------------------------------
/refs/目次.md:
--------------------------------------------------------------------------------
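1 | ## Chapter 1: Understanding probability distributions
2 | ### 1.1 Why probability distributions are needed in Bayesian inference
3 | ### 1.2 Random variables and probability distributions
4 | ### 1.3 Discrete and continuous distributions
5 | ### 1.4 Defining probabilistic models and sampling with PyMC
6 | ### 1.5 Analyzing sampling results with PyMC
7 | ### 1.6 How probability distributions relate to PyMC programming
8 |
9 | ### Column: The relationship between probabilistic models and sample values (observations)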
10 |
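11 | ## Chapter 2: Commonly used probability distributions
12 | ### 2.1 Bernoulli distribution (pm.Bernoulli function)
13 | ### 2.2 Binomial distribution (pm.Binomial function)
14 | ### 2.3 Normal distribution (pm.Normal function)
15 | ### 2.4 Uniform distribution (pm.Uniform function)
16 | ### 2.5 Beta distribution (pm.Beta function)
17 | ### 2.6 Half-normal distribution (pm.HalfNormal function)
18 |
19 | ### Column: The difference between HDI and CI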
20 |
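21 | ## Chapter 3: What is Bayesian inference?
22 | ### 3.1 The purpose of using Bayesian inference
23 | ### 3.2 Problem setup
24 | ### 3.3 Solution by maximum likelihood estimation
25 | ### 3.4 Solution by Bayesian inference
26 | ### 3.5 How to improve the accuracy of Bayesian inference
27 | ### 3.6 Applications of Bayesian inference
28 |
29 | ### Column: Prior and posterior distributions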
30 |
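31 | ## Chapter 4: First hands-on Bayesian inference
32 | ### 4.1 Problem setup (recap)
33 | ### 4.2 Maximum likelihood estimation
34 | ### 4.3 Bayesian inference (probabilistic model definition)
35 | ### 4.4 Bayesian inference (sampling)
36 | ### 4.5 Bayesian inference (result analysis)
37 | ### 4.6 Bayesian inference (binomial distribution version)
38 | ### 4.7 Bayesian inference (increasing the number of trials)
39 | ### 4.8 Bayesian inference (changing the prior)
40 | ### 4.9 Obtaining the probability distribution directly with the beta distribution
41 |
42 | ### Column: ArviZ FAQ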
43 |
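44 | ## Chapter 5: Bayesian inference programming
45 | ## 5.1 Bayesian inference of a data distribution
46 | ### 5.1.1 Problem setup
47 | ### 5.1.2 Data preparation
48 | ### 5.1.3 Probabilistic model definition
49 | ### 5.1.4 Sampling
50 | ### 5.1.5 Result analysis
51 | ### 5.1.6 Overlaying the normal density function
52 | ### 5.1.7 Bayesian inference with a small number of samples
53 |
54 | ### Column: Defining the probabilistic model with tau
55 |
56 | ## 5.2 Bayesian inference for linear regression
57 | ### 5.2.1 Problem setup
58 | ### 5.2.2 Data preparation
59 | ### 5.2.3 Probabilistic model definition 1
60 | ### 5.2.4 Probabilistic model definition 2
61 | ### 5.2.5 Sampling and result analysis
62 | ### 5.2.6 Overlaying the regression line on the scatter plot
63 | ### 5.2.7 Bayesian inference with few observations
64 | ### Column: Tuning with target_accept
65 |
66 | ## 5.3 Hierarchical Bayesian model
67 | ### 5.3.1 Problem setup
68 | ### 5.3.2 Data preparation
69 | ### 5.3.3 Probabilistic model definition
70 | ### 5.3.4 Sampling and result analysis
71 | ### 5.3.5 Overlaying regression lines on the scatter plot
72 | ### Column: How finely should the components of a PyMC model be defined?
73 |
74 | ## 5.4 Latent variable model
75 | ### 5.4.1 Problem setup
76 | ### 5.4.2 Data preparation
77 | ### 5.4.3 Probabilistic model definition
78 | ### 5.4.4 Sampling and result analysis
79 | ### 5.4.5 Overlaying distribution functions on the histogram
80 | ### 5.4.6 Probability distribution of the latent variable
81 | ### Column: Key points of Bayesian inference in latent variable models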
82 |
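83 | ## Chapter 6: Business applications of Bayesian inference
84 | ## 6.1 Evaluating the effect of an A/B test
85 | ### 6.1.1 Problem setup
86 | ### 6.1.2 Building the probabilistic model
87 | ### 6.1.3 Sampling and result analysis
88 | ### 6.1.4 Analysis of Yamada-san's case
89 | ### 6.1.5 An alternative solution using the probabilistic model directly
90 |
91 | ### Column: Is there really no need to use Bayesian inference for A/B test evaluation?
92 |
93 | ## 6.2 Evaluating effects with a Bayesian regression model
94 | ### 6.2.1 Problem setup
95 | ### 6.2.2 Loading the data
96 | ### 6.2.3 Inspecting the data
97 | ### 6.2.4 Data processing
98 | ### 6.2.5 Probabilistic model definition
99 | ### 6.2.6 Sampling and result analysis
100 | ### 6.2.7 Interpreting the results
101 |
102 | ### Column: The probabilistic model in the tutorial
103 |
104 | ## 6.3 Evaluating test results with IRT
105 | ### 6.3.1 What is IRT (Item Response Theory)?
106 | ### 6.3.2 Problem setup
107 | ### 6.3.3 Loading the data
108 | ### 6.3.4 Data processing
109 | ### 6.3.5 Probabilistic model definition
110 | ### 6.3.6 Sampling and result analysis
111 | ### 6.3.7 Detailed analysis
112 | ### 6.3.8 Relationship between deviation scores and ability values
113 | ### 6.3.9 Analyzing differences in ability among examinees with the same deviation score
114 |
115 | ### Column: Using variational inference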
116 |
117 |
--------------------------------------------------------------------------------
/sample-notebooks/A_3クラス潜在変数モデル.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "id": "gT2-RDc3oM5q"
7 | },
8 | "source": [
9 | "## Supplement: 3-class latent variable model"
10 | ]
11 | },
12 | {
13 | "cell_type": "markdown",
14 | "metadata": {},
15 | "source": [
16 | "\n"
17 | ]
18 | },
19 | {
20 | "cell_type": "markdown",
21 | "metadata": {
22 | "id": "-76dVLz4oWnR"
23 | },
24 | "source": [
25 | "### Common setup"
26 | ]
27 | },
28 | {
29 | "cell_type": "code",
30 | "execution_count": null,
31 | "metadata": {
32 | "id": "mO7SvaGzoRtu"
33 | },
34 | "outputs": [],
35 | "source": [
36 | "%matplotlib inline\n",
37 | "# Install the Japanese-font library for Matplotlib\n",
38 | "!pip install japanize-matplotlib | tail -n 1"
39 | ]
40 | },
41 | {
42 | "cell_type": "code",
43 | "execution_count": null,
44 | "metadata": {
45 | "id": "31bw6AQRoSiV"
46 | },
47 | "outputs": [],
48 | "source": [
49 | "# Import libraries\n",
50 | "\n",
51 | "# NumPy\n",
52 | "import numpy as np\n",
53 | "\n",
54 | "# pyplot from Matplotlib\n",
55 | "import matplotlib.pyplot as plt\n",
56 | "\n",
57 | "# Japanese-font support for Matplotlib\n",
58 | "import japanize_matplotlib\n",
59 | "\n",
60 | "# pandas\n",
61 | "import pandas as pd\n",
62 | "\n",
63 | "# Function for displaying data frames\n",
64 | "from IPython.display import display\n",
65 | "\n",
66 | "# seaborn\n",
67 | "import seaborn as sns\n",
68 | "\n",
69 | "# Display options\n",
70 | "\n",
71 | "# NumPy print format\n",
72 | "np.set_printoptions(precision=3, floatmode='fixed')\n",
73 | "\n",
74 | "# Default font size for plots\n",
75 | "plt.rcParams[\"font.size\"] = 14\n",
76 | "\n",
77 | "# Figure size\n",
78 | "plt.rcParams['figure.figsize'] = (6, 6)\n",
79 | "\n",
80 | "# Show grid lines\n",
81 | "plt.rcParams['axes.grid'] = True\n",
82 | "\n",
83 | "# Display precision for data frames\n",
84 | "pd.options.display.float_format = '{:.3f}'.format\n",
85 | "\n",
86 | "# Show all columns of data frames\n",
87 | "pd.set_option(\"display.max_columns\",None)"
88 | ]
89 | },
90 | {
91 | "cell_type": "code",
92 | "execution_count": null,
93 | "metadata": {
94 | "id": "8bhlEvpJohj3"
95 | },
96 | "outputs": [],
97 | "source": [
98 | "import pymc as pm\n",
99 | "import arviz as az\n",
100 | "\n",
101 | "print(f\"Running on PyMC v{pm.__version__}\")\n",
102 | "print(f\"Running on ArviZ v{az.__version__}\")"
103 | ]
104 | },
105 | {
106 | "cell_type": "markdown",
107 | "metadata": {
108 | "id": "FUPlPalapU-m"
109 | },
110 | "source": [
111 | "### A.1 Categorical distribution"
112 | ]
113 | },
114 | {
115 | "cell_type": "markdown",
116 | "metadata": {
117 | "id": "YNC3IVj0LyaW"
118 | },
119 | "source": [
120 | "#### Probabilistic model definition"
121 | ]
122 | },
123 | {
124 | "cell_type": "code",
125 | "execution_count": null,
126 | "metadata": {
127 | "id": "gEQtGCvfC4cx"
128 | },
129 | "outputs": [],
130 | "source": [
131 | "# Parameter settings\n",
132 | "p = [0.2, 0.5, 0.3]\n",
133 | "\n",
134 | "model1 = pm.Model()\n",
135 | "with model1:\n",
136 | " # pm.Categorical: categorical distribution\n",
137 | " # p: probability of each category\n",
138 | " x = pm.Categorical('x', p=p)"
139 | ]
140 | },
141 | {
142 | "cell_type": "markdown",
143 | "metadata": {
144 | "id": "9GZg8qfeL4ii"
145 | },
146 | "source": [
147 | "#### Sampling from the prior and extracting sample values"
148 | ]
149 | },
150 | {
151 | "cell_type": "code",
152 | "execution_count": null,
153 | "metadata": {
154 | "id": "wMbozycaL_s5"
155 | },
156 | "outputs": [],
157 | "source": [
158 | "with model1:\n",
159 | " # Sample from the prior\n",
160 | " prior_samples1 = pm.sample_prior_predictive(random_seed=42)\n",
161 | "\n",
162 | "x_samples1 = prior_samples1['prior']['x'].values\n",
163 | "print(x_samples1)"
164 | ]
165 | },
166 | {
167 | "cell_type": "markdown",
168 | "metadata": {
169 | "id": "Fa73TMsyMNhB"
170 | },
171 | "source": [
172 | "#### Visualizing the sampling results"
173 | ]
174 | },
175 | {
176 | "cell_type": "code",
177 | "execution_count": null,
178 | "metadata": {
179 | "id": "M4yjjoeJC4yY"
180 | },
181 | "outputs": [],
182 | "source": [
183 | "ax = az.plot_dist(x_samples1)\n",
184 | "ax.set_title(f'Categorical distribution p={p}');"
185 | ]
186 | },
187 | {
188 | "cell_type": "markdown",
189 | "metadata": {
190 | "id": "__ivn9rporp2"
191 | },
192 | "source": [
193 | "### A.2 Dirichlet distribution"
194 | ]
195 | },
196 | {
197 | "cell_type": "markdown",
198 | "metadata": {
199 | "id": "hCxdKrX3MpNr"
200 | },
201 | "source": [
202 | "#### Probabilistic model definition"
203 | ]
204 | },
205 | {
206 | "cell_type": "code",
207 | "execution_count": null,
208 | "metadata": {
209 | "id": "Tft9PJ-dopyX"
210 | },
211 | "outputs": [],
212 | "source": [
213 | "# Parameter settings\n",
214 | "n_components = 3\n",
215 | "\n",
216 | "model2 = pm.Model()\n",
217 | "with model2:\n",
218 | " # Dirichlet distribution\n",
219 | " # a: concentration parameters; [1, 1, 1] gives a uniform distribution over the simplex\n",
220 | " p = pm.Dirichlet('p', a=np.ones(n_components))"
221 | ]
222 | },
223 | {
224 | "cell_type": "markdown",
225 | "metadata": {
226 | "id": "ci26H8TnMxlV"
227 | },
228 | "source": [
229 | "#### Sampling from the prior and extracting sample values"
230 | ]
231 | },
232 | {
233 | "cell_type": "code",
234 | "execution_count": null,
235 | "metadata": {
236 | "id": "VGIS7kQbMt7P"
237 | },
238 | "outputs": [],
239 | "source": [
240 | "with model2:\n",
241 | " # Draw samples\n",
242 | " samples2 = pm.sample_prior_predictive(random_seed=42)\n",
243 | "\n",
244 | "# Extract the sample values\n",
245 | "x_samples2 = samples2['prior']['p'].values\n",
246 | "# The output is long, so show only the first 10 draws\n",
247 | "print(x_samples2[:,:10])"
248 | ]
249 | },
250 | {
251 | "cell_type": "markdown",
252 | "metadata": {
253 | "id": "O8sHHas3NCMo"
254 | },
255 | "source": [
256 | "#### Visualizing the sampling results"
257 | ]
258 | },
259 | {
260 | "cell_type": "code",
261 | "execution_count": null,
262 | "metadata": {
263 | "id": "qZyzSLJiok-U"
264 | },
265 | "outputs": [],
266 | "source": [
267 | "# Visualize the sample values\n",
268 | "samples2 = x_samples2.reshape(-1,3)\n",
269 | "plt.title('Dirichlet distribution with a=(1,1,1)')\n",
270 | "x1 = samples2[:,0]\n",
271 | "x2 = samples2[:,1]\n",
272 | "plt.scatter(x1,x2, s=5);"
273 | ]
274 | },
275 | {
276 | "cell_type": "markdown",
277 | "metadata": {
278 | "id": "oEUc-ARQWD49"
279 | },
280 | "source": [
281 | "### A.3 3-class latent variable model"
282 | ]
283 | },
284 | {
285 | "cell_type": "markdown",
286 | "metadata": {
287 | "id": "slBA8sjXWSo-"
288 | },
289 | "source": [
290 | "#### Loading the data"
291 | ]
292 | },
293 | {
294 | "cell_type": "code",
295 | "execution_count": null,
296 | "metadata": {
297 | "id": "3rljUGYwWY3q"
298 | },
299 | "outputs": [],
300 | "source": [
301 | "# Load the iris dataset\n",
302 | "df = sns.load_dataset('iris')\n",
303 | "\n",
304 | "# Check the first 5 rows\n",
305 | "display(df.head())\n",
306 | "\n",
307 | "# Check the distribution of species\n",
308 | "df['species'].value_counts()"
309 | ]
310 | },
311 | {
312 | "cell_type": "markdown",
313 | "metadata": {
314 | "id": "s5CNWBeilv5q"
315 | },
316 | "source": [
317 | "#### Variable setup"
318 | ]
319 | },
320 | {
321 | "cell_type": "code",
322 | "execution_count": null,
323 | "metadata": {
324 | "id": "yWoSC4qkWnAo"
325 | },
326 | "outputs": [],
327 | "source": [
328 | "# Observed data\n",
329 | "X = df['petal_width'].values\n",
330 | "\n",
331 | "# Number of data points (kept as a shape tuple so it can be passed as shape=N)\n",
332 | "N = X.shape\n",
333 | "\n",
334 | "# Number of classes\n",
335 | "n_components = 3"
336 | ]
337 | },
338 | {
339 | "cell_type": "markdown",
340 | "metadata": {
341 | "id": "9wOBXpofW1xu"
342 | },
343 | "source": [
344 | "#### Probabilistic model definition"
345 | ]
346 | },
347 | {
348 | "cell_type": "code",
349 | "execution_count": null,
350 | "metadata": {
351 | "id": "C8QJ9PZ7W0LR"
352 | },
353 | "outputs": [],
354 | "source": [
355 | "model3 = pm.Model()\n",
356 | "\n",
357 | "with model3:\n",
358 | " # Define the observed values with pm.ConstantData\n",
359 | " X_data = pm.ConstantData('X_data', X)\n",
360 | "\n",
361 | " # p: 3-element vector of the probabilities of each class\n",
362 | " p = pm.Dirichlet('p', a=np.ones(n_components))\n",
363 | "\n",
364 | " # s: takes the value 0, 1, or 2 according to the probabilities p\n",
365 | " s = pm.Categorical('s', p=p, shape=N)\n",
366 | "\n",
367 | " # mus: mean for each of the 3 species\n",
368 | " mus = pm.Normal('mus', mu=0.0, sigma=10.0, shape=n_components)\n",
369 | "\n",
370 | " # taus: precision (spread) for each of the 3 species\n",
371 | " # Related to the standard deviations by taus = 1/(sigmas*sigmas)\n",
372 | " taus = pm.HalfNormal('taus', sigma=10.0, shape=n_components)\n",
373 | "\n",
374 | " # sigma is needed for plotting and analysis, so derive it from tau\n",
375 | " sigmas = pm.Deterministic('sigmas', 1/pm.math.sqrt(taus))\n",
376 | "\n",
377 | " # For each observation, look up mu and tau via the latent variable\n",
378 | " mu = pm.Deterministic('mu', mus[s])\n",
379 | " tau = pm.Deterministic('tau', taus[s])\n",
380 | "\n",
381 | " # Define the normally distributed random variable X_obs\n",
382 | " X_obs = pm.Normal('X_obs', mu=mu, tau=tau, observed=X_data)\n",
383 | "\n",
384 | "# Visualize the model structure\n",
385 | "g = pm.model_to_graphviz(model3)\n",
386 | "display(g);"
387 | ]
388 | },
389 | {
390 | "cell_type": "markdown",
391 | "metadata": {
392 | "id": "ZdEdJ-yaXVGq"
393 | },
394 | "source": [
395 | "#### Sampling"
396 | ]
397 | },
398 | {
399 | "cell_type": "code",
400 | "execution_count": null,
401 | "metadata": {
402 | "id": "FvKBaWrvqVvb"
403 | },
404 | "outputs": [],
405 | "source": [
406 | "with model3:\n",
407 | " idata3 = pm.sample(\n",
408 | " chains=1, draws=2000, target_accept=0.99,\n",
409 | " random_seed=42)"
410 | ]
411 | },
412 | {
413 | "cell_type": "markdown",
414 | "metadata": {
415 | "id": "vqtyI9SWpt9H"
416 | },
417 | "source": [
418 | "#### Checking the inference results"
419 | ]
420 | },
421 | {
422 | "cell_type": "code",
423 | "execution_count": null,
424 | "metadata": {
425 | "id": "SW6IYZY5b5IR"
426 | },
427 | "outputs": [],
428 | "source": [
429 | "az.plot_trace(idata3, var_names=['p', 'mus', 'sigmas'], compact=False)\n",
430 | "plt.tight_layout();"
431 | ]
432 | },
433 | {
434 | "cell_type": "markdown",
435 | "metadata": {
436 | "id": "z8eqJS3FvtNA"
437 | },
438 | "source": [
439 | "#### Summary statistics"
440 | ]
441 | },
442 | {
443 | "cell_type": "code",
444 | "execution_count": null,
445 | "metadata": {
446 | "id": "c7rpy_LfdKif"
447 | },
448 | "outputs": [],
449 | "source": [
450 | "summary3 = az.summary(idata3, var_names=['p', 'mus', 'sigmas'])\n",
451 | "display(summary3)"
452 | ]
453 | },
454 | {
455 | "cell_type": "markdown",
456 | "metadata": {
457 | "id": "kYVTQ7nSw-UE"
458 | },
459 | "source": [
460 | "#### Overlaying the inference results on the histogram"
461 | ]
462 | },
463 | {
464 | "cell_type": "code",
465 | "execution_count": null,
466 | "metadata": {
467 | "id": "FyGWFSg7qyZ1"
468 | },
469 | "outputs": [],
470 | "source": [
471 | "# Define the normal density function\n",
472 | "def norm(x, mu, sigma):\n",
473 | " return np.exp(-((x - mu)/sigma)**2/2) / (np.sqrt(2 * np.pi) * sigma)\n",
474 | "\n",
475 | "# Get the mean of each parameter from the inference results\n",
476 | "mean3 = summary3['mean']\n",
477 | "\n",
478 | "# Means of mu\n",
479 | "mean3_mu0 = mean3['mus[0]']\n",
480 | "mean3_mu1 = mean3['mus[1]']\n",
481 | "mean3_mu2 = mean3['mus[2]']\n",
482 | "\n",
483 | "# Means of sigma\n",
484 | "mean3_sigma0 = mean3['sigmas[0]']\n",
485 | "mean3_sigma1 = mean3['sigmas[1]']\n",
486 | "mean3_sigma2 = mean3['sigmas[2]']\n",
487 | "\n",
488 | "# Plot\n",
489 | "x = np.arange(0.0, 3.0, 0.05)\n",
490 | "plt.rcParams['figure.figsize']=(8,6)\n",
491 | "fig, ax = plt.subplots()\n",
492 | "sns.histplot(\n",
493 | " data=df,\n",
494 | " bins=np.arange(0.0, 3.0, 0.1),\n",
495 | " x='petal_width',\n",
496 | " hue='species', kde=True)\n",
497 | "plt.setp(ax.get_xticklabels(), rotation=90)\n",
498 | "plt.title('Histogram of petal_width')\n",
499 | "plt.xticks(np.arange(0.0, 3.0, 0.1));\n",
500 | "plt.title('Histogram overlaid with normal density functions')\n",
501 | "plt.plot(x, norm(x, mean3_mu0, mean3_sigma0)*5.0, c='y', lw=3)\n",
502 | "plt.plot(x, norm(x, mean3_mu1, mean3_sigma1)*5.0, c='g', lw=3)\n",
503 | "plt.plot(x, norm(x, mean3_mu2, mean3_sigma2)*5.0, c='b', lw=3);"
504 | ]
505 | },
506 | {
507 | "cell_type": "markdown",
508 | "metadata": {
509 | "id": "Oba0vsmLXtC-"
510 | },
511 | "source": [
512 | "### A.4 3-class latent variable model (failure case)"
513 | ]
514 | },
515 | {
516 | "cell_type": "markdown",
517 | "metadata": {
518 | "id": "xsP2fx_PX33M"
519 | },
520 | "source": [
521 | "#### Probabilistic model definition"
522 | ]
523 | },
524 | {
525 | "cell_type": "code",
526 | "execution_count": null,
527 | "metadata": {
528 | "id": "QT0cH5MpXQGY"
529 | },
530 | "outputs": [],
531 | "source": [
532 | "model4 = pm.Model()\n",
533 | "\n",
534 | "with model4:\n",
535 | " # Define the observed values with pm.ConstantData\n",
536 | " X_data = pm.ConstantData('X_data', X)\n",
537 | "\n",
538 | " # p: 3-element vector of the probabilities of each class\n",
539 | " p = pm.Dirichlet('p', a=np.ones(n_components))\n",
540 | "\n",
541 | " # s: takes the value 0, 1, or 2 according to the probabilities p\n",
542 | " s = pm.Categorical('s', p=p, shape=N)\n",
543 | "\n",
544 | " # mus: mean for each of the 3 species\n",
545 | " mus = pm.Normal('mus', mu=0.0, sigma=10.0, shape=n_components)\n",
546 | "\n",
547 | " # sigmas: standard deviation for each of the 3 species\n",
548 | " # Unlike model3, the standard deviation is sampled directly rather than via taus\n",
549 | " sigmas = pm.HalfNormal('sigmas', sigma=10.0, shape=n_components)\n",
550 | "\n",
551 | " # For each observation, look up mu and sigma via the latent variable\n",
552 | " mu = pm.Deterministic('mu', mus[s])\n",
553 | " sigma = pm.Deterministic('sigma', sigmas[s])\n",
554 | "\n",
555 | " # Define the normally distributed random variable X_obs\n",
556 | " X_obs = pm.Normal('X_obs', mu=mu, sigma=sigma, observed=X_data)\n",
557 | "\n",
558 | "# Visualize the model structure\n",
559 | "g = pm.model_to_graphviz(model4)\n",
560 | "display(g);"
561 | ]
562 | },
563 | {
564 | "cell_type": "markdown",
565 | "metadata": {
566 | "id": "RiCh50j_yy_Y"
567 | },
568 | "source": [
569 | "#### Sampling and checking the inference results"
570 | ]
571 | },
572 | {
573 | "cell_type": "code",
574 | "execution_count": null,
575 | "metadata": {
576 | "id": "ShlnIZssZqfU"
577 | },
578 | "outputs": [],
579 | "source": [
580 | "with model4:\n",
581 | " idata4 = pm.sample(\n",
582 | " chains=1, draws=2000, target_accept=0.99,\n",
583 | " random_seed=42)\n",
584 | "\n",
585 | "az.plot_trace(idata4, var_names=['p', 'mus', 'sigmas'], compact=False)\n",
586 | "plt.tight_layout();"
587 | ]
588 | },
589 | {
590 | "cell_type": "markdown",
591 | "metadata": {
592 | "id": "btrpEJTX2c5v"
593 | },
594 | "source": [
595 | "### A.5 3-class latent variable model (improved version)"
596 | ]
597 | },
598 | {
599 | "cell_type": "markdown",
600 | "metadata": {
601 | "id": "9fb0jstC2k4a"
602 | },
603 | "source": [
604 | "#### Probabilistic model definition"
605 | ]
606 | },
607 | {
608 | "cell_type": "code",
609 | "execution_count": null,
610 | "metadata": {
611 | "id": "mE31ONmzfPXj"
612 | },
613 | "outputs": [],
614 | "source": [
615 | "model5 = pm.Model()\n",
616 | "\n",
617 | "with model5:\n",
618 | " # Define the observed values with pm.ConstantData\n",
619 | " X_data = pm.ConstantData('X_data', X)\n",
620 | "\n",
621 | " # p: 3-element vector of the probabilities of each class\n",
622 | " p = pm.Dirichlet('p', a=np.ones(n_components))\n",
623 | "\n",
624 | " # s: takes the value 0, 1, or 2 according to the probabilities p\n",
625 | " s = pm.Categorical('s', p=p, shape=N)\n",
626 | "\n",
627 | " # mus: means of the 3 species, built from non-negative increments so that mu0 <= mu1 <= mu2\n",
628 | " mu0 = pm.HalfNormal('mu0', sigma=10.0)\n",
629 | " delta0 = pm.HalfNormal('delta0', sigma=10.0)\n",
630 | " mu1 = pm.Deterministic('mu1', mu0+delta0)\n",
631 | " delta1 = pm.HalfNormal('delta1', sigma=10.0)\n",
632 | " mu2 = pm.Deterministic('mu2', mu1+delta1)\n",
633 | " mus = pm.Deterministic('mus', pm.math.stack([mu0, mu1, mu2]))\n",
634 | "\n",
635 | " # taus: precision (spread) for each of the 3 species\n",
636 | " # Related to the standard deviations by taus = 1/(sigmas*sigmas)\n",
637 | " taus = pm.HalfNormal('taus', sigma=10.0, shape=n_components)\n",
638 | "\n",
639 | " # sigma is needed for plotting and analysis, so derive it from tau\n",
640 | " sigmas = pm.Deterministic('sigmas', 1/pm.math.sqrt(taus))\n",
641 | "\n",
642 | " # For each observation, look up mu and tau via the latent variable\n",
643 | " mu = pm.Deterministic('mu', mus[s])\n",
644 | " tau = pm.Deterministic('tau', taus[s])\n",
645 | "\n",
646 | " # Define the normally distributed random variable X_obs\n",
647 | " X_obs = pm.Normal('X_obs', mu=mu, tau=tau, observed=X_data)\n",
648 | "\n",
649 | "# Visualize the model structure\n",
650 | "g = pm.model_to_graphviz(model5)\n",
651 | "display(g);"
652 | ]
653 | },
654 | {
655 | "cell_type": "markdown",
656 | "metadata": {
657 | "id": "AH_2X9Zr4MaD"
658 | },
659 | "source": [
660 | "#### Sampling and checking the estimation results"
661 | ]
662 | },
663 | {
664 | "cell_type": "code",
665 | "execution_count": null,
666 | "metadata": {
667 | "id": "yOK34Kt34EuC"
668 | },
669 | "outputs": [],
670 | "source": [
671 | "with model5:\n",
672 | " idata5 = pm.sample(target_accept=0.99, random_seed=42)\n",
673 | "\n",
674 | "plt.rcParams['figure.figsize']=(6,6)\n",
675 | "az.plot_trace(idata5, var_names=['p', 'mus', 'sigmas'], compact=False)\n",
676 | "plt.tight_layout();"
677 | ]
678 | },
679 | {
680 | "cell_type": "markdown",
681 | "metadata": {
682 | "id": "lBQRemhm7YpW"
683 | },
684 | "source": [
685 | "#### Summary statistics"
686 | ]
687 | },
688 | {
689 | "cell_type": "code",
690 | "execution_count": null,
691 | "metadata": {
692 | "id": "MR1yEsm67dBz"
693 | },
694 | "outputs": [],
695 | "source": [
696 | "summary5 = az.summary(idata5, var_names=['p', 'mus', 'sigmas'])\n",
697 | "display(summary5)"
698 | ]
699 | },
700 | {
701 | "cell_type": "markdown",
702 | "metadata": {
703 | "id": "IxSD7nVp62nw"
704 | },
705 | "source": [
706 | "#### Overlaying the inference results on the histogram"
707 | ]
708 | },
709 | {
710 | "cell_type": "code",
711 | "execution_count": null,
712 | "metadata": {
713 | "id": "fHtLga_l5AjH"
714 | },
715 | "outputs": [],
716 | "source": [
717 | "# Get the mean of each parameter from the inference results\n",
718 | "mean5 = summary5['mean']\n",
719 | "\n",
720 | "# Means of mu\n",
721 | "mean5_mu0 = mean5['mus[0]']\n",
722 | "mean5_mu1 = mean5['mus[1]']\n",
723 | "mean5_mu2 = mean5['mus[2]']\n",
724 | "\n",
725 | "# Means of sigma\n",
726 | "mean5_sigma0 = mean5['sigmas[0]']\n",
727 | "mean5_sigma1 = mean5['sigmas[1]']\n",
728 | "mean5_sigma2 = mean5['sigmas[2]']\n",
729 | "\n",
730 | "# Plot\n",
731 | "x = np.arange(0.0, 3.0, 0.05)\n",
732 | "plt.rcParams['figure.figsize']=(8,6)\n",
733 | "fig, ax = plt.subplots()\n",
734 | "sns.histplot(\n",
735 | " data=df,\n",
736 | " bins=np.arange(0.0, 3.0, 0.1),\n",
737 | " x='petal_width',\n",
738 | " hue='species', kde=True)\n",
739 | "plt.setp(ax.get_xticklabels(), rotation=90)\n",
740 | "plt.title('Histogram of petal_width')\n",
741 | "plt.xticks(np.arange(0.0, 3.0, 0.1));\n",
742 | "plt.title('Histogram overlaid with normal density functions')\n",
743 | "plt.plot(x, norm(x, mean5_mu0, mean5_sigma0)*5.0, c='b', lw=3)\n",
744 | "plt.plot(x, norm(x, mean5_mu1, mean5_sigma1)*5.0, c='y', lw=3)\n",
745 | "plt.plot(x, norm(x, mean5_mu2, mean5_sigma2)*5.0, c='g', lw=3);"
746 | ]
747 | },
748 | {
749 | "cell_type": "code",
750 | "execution_count": null,
751 | "metadata": {
752 | "id": "Jkrtu7tz7M5V"
753 | },
754 | "outputs": [],
755 | "source": []
756 | }
757 | ],
758 | "metadata": {
759 | "colab": {
760 | "provenance": [],
761 | "toc_visible": true
762 | },
763 | "kernelspec": {
764 | "display_name": "Python 3",
765 | "name": "python3"
766 | },
767 | "language_info": {
768 | "name": "python"
769 | }
770 | },
771 | "nbformat": 4,
772 | "nbformat_minor": 0
773 | }
774 |
--------------------------------------------------------------------------------