├── README.md ├── .DS_Store ├── notebooks ├── .DS_Store ├── 09_x_numpy.ipynb ├── 07_x_numpy.ipynb ├── 11_03_weight_init.ipynb ├── 11_02_num_diff.ipynb ├── 11_01_pytorch.ipynb ├── 11_04_batch_size.ipynb ├── 11_06_optimizarion.ipynb ├── 07_regression.ipynb ├── 11_07_cnn.ipynb ├── 11_05_overfitting_prev.ipynb ├── 08_bi_classify.ipynb └── 09_multi_classify.ipynb ├── math-answers ├── q2_answers.md └── q3_answers.md ├── code-list.md └── LICENSE /README.md: -------------------------------------------------------------------------------- 1 | # math_dl_book_info_v2 2 | Support page for the book ディープラーニングの数学 第2版 (Mathematics of Deep Learning, 2nd edition) 3 | -------------------------------------------------------------------------------- /math-answers/q2_answers.md: -------------------------------------------------------------------------------- 1 | 2 | ## Q2-1 3 | Find the composite function 4 | 5 | $$ h(x)=g\circ f(x) $$ 6 | 7 | for the functions $f$ and $g$ below. 8 | 9 | --- 10 | 11 | $$ 12 | f(x)=x^2+1,\quad g(x)=\sqrt{x} 13 | $$ 14 | 15 | Therefore, 16 | 17 | $$ 18 | h(x)=g(f(x))=\sqrt{x^2+1} 19 | $$ 20 | 21 | --- 22 | 23 | ## Q2-2 24 | Find the inverse of the following function. 25 | 26 | $$ 27 | f(x)=\sqrt{2x+1} 28 | $$ 29 | 30 | --- 31 | 32 | $$ 33 | y=\sqrt{2x+1} 34 | $$ 35 | 36 | $$ 37 | y^2=2x+1 38 | $$ 39 | 40 | $$ 41 | x=\frac{y^2-1}{2} 42 | $$ 43 | 44 | Swapping $x$ and $y$ gives the inverse function. Since $f$ takes only non-negative values, the domain of $f^{-1}$ is $x\ge 0$: 45 | 46 | $$ 47 | f^{-1}(x)=\frac{x^2-1}{2}\quad (x\ge 0) 48 | $$ 49 | 50 | --- 51 | 52 | ## Q2-3 53 | Differentiate the composite function obtained in Q2-1. 54 | 55 | --- 56 | 57 | $$ 58 | h(x)=\sqrt{x^2+1}=(x^2+1)^{1/2} 59 | $$ 60 | 61 | $$ 62 |
h'(x)=\frac{1}{2}(x^2+1)^{-1/2}\cdot 2x 63 | =\frac{x}{\sqrt{x^2+1}} 64 | $$ 65 | 66 | --- 67 | 68 | ## Q2-4 69 | Differentiate the following function. 70 | 71 | $$ 72 | f(x)=\frac{x}{\sqrt{x^2+1}} 73 | $$ 74 | 75 | --- 76 | 77 | Rewrite the denominator as a power. 78 | 79 | $$ 80 | f(x)=x(x^2+1)^{-1/2} 81 | $$ 82 | 83 | Apply the product rule, using the chain rule on the second factor. 84 | 85 | $$ 86 | f'(x)=1\cdot (x^2+1)^{-1/2} 87 | +x\cdot\left(-\frac12\right)(x^2+1)^{-3/2}\cdot 2x 88 | $$ 89 | 90 | $$ 91 | f'(x)=(x^2+1)^{-1/2}-x^2(x^2+1)^{-3/2} 92 | $$ 93 | 94 | Combine the two terms over the common denominator $(x^2+1)^{3/2}$. 95 | 96 | $$ 97 | f'(x)=\frac{x^2+1-x^2}{(x^2+1)^{3/2}} 98 | =\frac{1}{(x^2+1)^{3/2}} 99 | $$ 100 | 101 | --- 102 | 103 | ## Q2-5 104 | Evaluate the following indefinite integral. 105 | 106 | $$ 107 | \int \frac{x}{\sqrt{x^2+1}}\,dx 108 | $$ 109 | 110 | --- 111 | 112 |
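As an optional numerical sanity check (not part of the original answers), the closed-form derivatives from Q2-3 and Q2-4 can be compared with a central-difference approximation. The helper names `central_diff` and `max_error`, the step size `1e-5`, and the sample grid on $[-2, 2]$ are illustrative assumptions. The integrand of Q2-5 is exactly $h'(x)$ from Q2-3, so the first check also confirms that $\sqrt{x^2+1}$ is an antiderivative of that integrand.

```python
import math

def central_diff(func, x, step=1e-5):
    """Approximate func'(x) with the central difference (func(x+step) - func(x-step)) / (2*step)."""
    return (func(x + step) - func(x - step)) / (2 * step)

# Q2-1 / Q2-3: h(x) = sqrt(x^2 + 1) and its closed-form derivative
def h(x):
    return math.sqrt(x**2 + 1)

def h_prime(x):
    return x / math.sqrt(x**2 + 1)

# Q2-4: f(x) = x / sqrt(x^2 + 1) and its closed-form derivative
def f(x):
    return x / math.sqrt(x**2 + 1)

def f_prime(x):
    return 1 / (x**2 + 1) ** 1.5

def max_error(func, deriv, xs):
    """Largest gap between the numerical and the closed-form derivative over the points xs."""
    return max(abs(central_diff(func, x) - deriv(x)) for x in xs)

xs = [i / 4 for i in range(-8, 9)]  # sample points -2.0, -1.75, ..., 2.0
err_h = max_error(h, h_prime, xs)
err_f = max_error(f, f_prime, xs)
print(f"Q2-3 max error: {err_h:.2e}")
print(f"Q2-4 max error: {err_f:.2e}")
```

Both errors should be tiny (far below `1e-6`); a wrong closed form would show up as an error of order 0.1 or larger at some sample point.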
113 | From Q2-3, $\frac{d}{dx}\sqrt{x^2+1}=\frac{x}{\sqrt{x^2+1}}$, so $\sqrt{x^2+1}$ is an antiderivative of the integrand. With $C$ the constant of integration, 114 | 115 | $$ 116 | \int \frac{x}{\sqrt{x^2+1}}\,dx=\sqrt{x^2+1}+C 117 | $$ 118 | -------------------------------------------------------------------------------- /code-list.md: -------------------------------------------------------------------------------- 1 | Hands-on code notebooks 2 | 3 | 4 | 5 | | Chapter | Title (link) | Notes | 6 | | ---- | -------------- | ------------------------------------------------------------ | 7 | | Ch. 7 | [Regression model](https://colab.research.google.com/github/makaishi2/math_dl_book_info_v2/blob/main/notebooks/07_regression.ipynb) | | 8 | | Ch. 8 | [Binary classification model](https://colab.research.google.com/github/makaishi2/math_dl_book_info_v2/blob/main/notebooks/08_bi_classify.ipynb) | | 9 | | Ch. 9 |[Multiclass classification model](https://colab.research.google.com/github/makaishi2/math_dl_book_info_v2/blob/main/notebooks/09_multi_classify.ipynb) | | 10 | | Ch. 10 |[Deep learning model](https://colab.research.google.com/github/makaishi2/math_dl_book_info_v2/blob/main/notebooks/10_deeplearning.ipynb) | | 11 | | Sec. 11.1 |[Frameworks](https://colab.research.google.com/github/makaishi2/math_dl_book_info_v2/blob/main/notebooks/11_01_pytorch.ipynb) | | 12 | | Sec. 11.2 |[Numerical differentiation](https://colab.research.google.com/github/makaishi2/math_dl_book_info_v2/blob/main/notebooks/11_02_num_diff.ipynb) | | 13 | | Sec. 11.3 |[Weight matrix initialization](https://colab.research.google.com/github/makaishi2/math_dl_book_info_v2/blob/main/notebooks/11_03_weight_init.ipynb) | | 14 | | Sec. 11.4 |[Units of training (batch size)](https://colab.research.google.com/github/makaishi2/math_dl_book_info_v2/blob/main/notebooks/11_04_batch_size.ipynb) | | 15 | | Sec. 11.5 |[Preventing overfitting](https://colab.research.google.com/github/makaishi2/math_dl_book_info_v2/blob/main/notebooks/11_05_overfitting_prev.ipynb) | | 16 | | Sec. 11.6 |[Optimizers](https://colab.research.google.com/github/makaishi2/math_dl_book_info_v2/blob/main/notebooks/11_06_optimizarion.ipynb) | | 17 | | Sec. 11.7 |[CNN](https://colab.research.google.com/github/makaishi2/math_dl_book_info_v2/blob/main/notebooks/11_07_cnn.ipynb) | | 18 | | Sec. 11.8
|[RNN/LSTM](https://colab.research.google.com/github/makaishi2/math_dl_book_info_v2/blob/main/notebooks/11_08_lstm.ipynb) | | 19 | | Sec. 11.9 |[Transformer](https://colab.research.google.com/github/makaishi2/math_dl_book_info_v2/blob/main/notebooks/11_09_transformer.ipynb) | | 20 | | | | | 21 | 22 | -------------------------------------------------------------------------------- /notebooks/09_x_numpy.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Chapter 9 supplement" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [ 15 | { 16 | "name": "stdout", 17 | "output_type": "stream", 18 | "text": [ 19 | "Intel MKL WARNING: Support of Intel(R) Streaming SIMD Extensions 4.2 (Intel(R) SSE4.2) enabled only processors has been deprecated. Intel oneAPI Math Kernel Library 2025.0 will require Intel(R) Advanced Vector Extensions (Intel(R) AVX) instructions.\n", 20 | "Intel MKL WARNING: Support of Intel(R) Streaming SIMD Extensions 4.2 (Intel(R) SSE4.2) enabled only processors has been deprecated.
Intel oneAPI Math Kernel Library 2025.0 will require Intel(R) Advanced Vector Extensions (Intel(R) AVX) instructions.\n", 21 | "[[1 2 3]\n", 22 | " [4 5 6]]\n" 23 | ] 24 | } 25 | ], 26 | "source": [ 27 | "# Define the data (a 2x3 array)\n", 28 | "import numpy as np\n", 29 | "\n", 30 | "x = np.array([[1,2,3],[4,5,6]])\n", 31 | "print(x)" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": null, 37 | "metadata": {}, 38 | "outputs": [ 39 | { 40 | "name": "stdout", 41 | "output_type": "stream", 42 | "text": [ 43 | "[5 7 9]\n", 44 | "[ 6 15]\n", 45 | "21\n" 46 | ] 47 | } 48 | ], 49 | "source": [ 50 | "# Three ways to call sum\n", 51 | "\n", 52 | "y = x.sum(axis=0)  # sum along axis 0: one total per column\n", 53 | "print(y)\n", 54 | "\n", 55 | "z = x.sum(axis=1)  # sum along axis 1: one total per row\n", 56 | "print(z)\n", 57 | "\n", 58 | "w = x.sum()  # no axis: sum of all elements\n", 59 | "print(w)" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": null, 65 | "metadata": {}, 66 | "outputs": [], 67 | "source": [] 68 | } 69 | ], 70 | "metadata": { 71 | "kernelspec": { 72 | "display_name": "base", 73 | "language": "python", 74 | "name": "python3" 75 | }, 76 | "language_info": { 77 | "codemirror_mode": { 78 | "name": "ipython", 79 | "version": 3 80 | }, 81 | "file_extension": ".py", 82 | "mimetype": "text/x-python", 83 | "name": "python", 84 | "nbconvert_exporter": "python", 85 | "pygments_lexer": "ipython3", 86 | "version": "3.11.4" 87 | } 88 | }, 89 | "nbformat": 4, 90 | "nbformat_minor": 2 91 | } 92 | -------------------------------------------------------------------------------- /notebooks/07_x_numpy.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "collapsed": true 7 | }, 8 | "source": [ 9 | "# Chapter 7 column: coding techniques with NumPy" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import numpy as np" 19 | ] 20 | }, 21 | { 22 | "cell_type": "markdown", 23 | "metadata": {}, 24
| "source": [ 25 | "## Inner product of two vectors" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": null, 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "# w = (1, 2)\n", 35 | "w = np.array([1, 2])\n", 36 | "print(w)\n", 37 | "print(w.shape)" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": null, 43 | "metadata": {}, 44 | "outputs": [], 45 | "source": [ 46 | "# x = (3, 4)\n", 47 | "x = np.array([3, 4])\n", 48 | "print(x)\n", 49 | "print(x.shape)" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": null, 55 | "metadata": {}, 56 | "outputs": [], 57 | "source": [ 58 | "# Implementation of the inner product in Eq. (3.7.2)\n", 59 | "# y = 1*3 + 2*4 = 11\n", 60 | "y = x @ w\n", 61 | "print(y)" 62 | ] 63 | }, 64 | { 65 | "cell_type": "markdown", 66 | "metadata": {}, 67 | "source": [ 68 | "## Matrix-vector product" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": null, 74 | "metadata": {}, 75 | "outputs": [], 76 | "source": [ 77 | "# X is a matrix with 3 rows and 2 columns\n", 78 | "X = np.array([[1,2],[3,4],[5,6]])\n", 79 | "print(X)\n", 80 | "print(X.shape)" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": null, 86 | "metadata": {}, 87 | "outputs": [], 88 | "source": [ 89 | "Y = X @ w  # inner product of each row of X with w\n", 90 | "print(Y)\n", 91 | "print(Y.shape)" 92 | ] 93 | }, 94 | { 95 | "cell_type": "markdown", 96 | "metadata": {}, 97 | "source": [ 98 | "## Matrix-vector product along the data (sample) direction" 99 | ] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": null, 104 | "metadata": {}, 105 | "outputs": [], 106 | "source": [ 107 | "# Create the transposed matrix\n", 108 | "XT = X.T\n", 109 | "print(X)\n", 110 | "print(XT)" 111 | ] 112 | }, 113 | { 114 | "cell_type": "code", 115 | "execution_count": null, 116 | "metadata": {}, 117 | "outputs": [], 118 | "source": [ 119 | "yd = np.array([1, 2, 3])\n", 120 | "print(yd)" 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": null, 126 | "metadata": {}, 127 | "outputs": [], 128 | "source": [ 129 | "# Part of the gradient computation\n", 130 | "grad = XT @ yd\n", 131 | "print(grad)" 132 | ] 133 | } 134 | ], 135 | "metadata": { 136 | "kernelspec": { 137 | "display_name": "Python 3", 138 | "language": "python", 139 | "name": "python3" 140 | }, 141 | "language_info": { 142 | "codemirror_mode": { 143 | "name": "ipython", 144 | "version": 3 145 | }, 146 | "file_extension": ".py", 147 | "mimetype": "text/x-python", 148 | "name": "python", 149 | "nbconvert_exporter": "python", 150 | "pygments_lexer": "ipython3", 151 | "version": "3.6.8" 152 | } 153 | }, 154 | "nbformat": 4, 155 | "nbformat_minor": 1 156 | } 157 | -------------------------------------------------------------------------------- /notebooks/11_03_weight_init.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "provenance": [], 7 | "toc_visible": true 8 | }, 9 | "kernelspec": { 10 | "name": "python3", 11 | "display_name": "Python 3" 12 | }, 13 | "language_info": { 14 | "name": "python" 15 | } 16 | }, 17 | "cells": [ 18 | { 19 | "cell_type": "markdown", 20 | "source": [ 21 | "## 11.3 Weight matrix initialization" 22 | ], 23 | "metadata": { 24 | "id": "pD3dNabNPCT5" 25 | } 26 | }, 27 | { 28 | "cell_type": "markdown", 29 | "source": [ 30 | "### Import libraries" 31 | ], 32 | "metadata": { 33 | "id": "n6GiWXlLQDhw" 34 | } 35 | }, 36 | { 37 | "cell_type": "code", 38 | "source": [ 39 | "# Import libraries\n", 40 | "import torch\n", 41 | "import torch.nn as nn" 42 | ], 43 | "metadata": { 44 | "id": "yIED_UaFPGfU" 45 | }, 46 | "execution_count": null, 47 | "outputs": [] 48 | }, 49 | { 50 | "cell_type": "markdown", 51 | "source": [ 52 | "### Default initialization in PyTorch" 53 | ], 54 | "metadata": { 55 | "id": "BW5sZ-FFNCe1" 56 | } 57 | }, 58 | { 59 | "cell_type": "code", 60 | "source": [ 61 | "# Default initialization in PyTorch\n", 62 | "\n", 63 | "# Fix the random seed\n", 64 | "torch.manual_seed(123)\n", 65 | "\n", 66 | "# Fully connected layer: 3-dimensional input → 2-dimensional output\n", 67 | "# By default the weight uses Kaiming (He) uniform init: kaiming_uniform_ with a=sqrt(5)\n", 68 | "layer = nn.Linear(3, 2)\n", 69 | "\n", 70 | "# Check the result\n", 71 | "print(layer.weight)\n", 72 | "print(layer.bias)" 73 | ], 74 | "metadata": { 75 | "id": "y48Yyj4WNH9N" 76 | }, 77 | "execution_count": null, 78 | "outputs": [] 79 | }, 80 | { 81 | "cell_type": "markdown", 82 | "source": [ 83 | "### Changing the weight initialization manually in PyTorch" 84 | ], 85 | "metadata": { 86 | "id": "S9wDI19tPx1N" 87 | } 88 | }, 89 | { 90 | "cell_type": "code", 91 | "source": [ 92 | "# Changing the weight initialization manually in PyTorch\n", 93 | "\n", 94 | "# Import the initialization module\n", 95 | "import torch.nn.init as init\n", 96 | "\n", 97 | "torch.manual_seed(123)\n", 98 | "layer = nn.Linear(3, 2)\n", 99 | "\n", 100 | "# Initialize the weight from the Xavier (Glorot) uniform distribution\n", 101 | "init.xavier_uniform_(layer.weight)\n", 102 | "\n", 103 | "# Set every bias to zero\n", 104 | "init.zeros_(layer.bias)\n", 105 | "\n", 106 | "# Check the result\n", 107 | "print(layer.weight)\n", 108 | "print(layer.bias)" 109 | ], 110 | "metadata": { 111 | "id": "gVm_bz7FNOWu" 112 | }, 113 | "execution_count": null, 114 | "outputs": [] 115 | }, 116 | { 117 | "cell_type": "markdown", 118 | "source": [ 119 | "### Check library versions" 120 | ], 121 | "metadata": { 122 | "id": "XjEnM4uwSt2T" 123 | } 124 | }, 125 | { 126 | "cell_type": "code", 127 | "source": [ 128 | "!pip install watermark -qq\n", 129 | "%load_ext watermark\n", 130 | "%watermark --iversions" 131 | ], 132 | "metadata": { 133 | "id": "4csWtLmsOC0R" 134 | }, 135 | "execution_count": null, 136 | "outputs": [] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "source": [], 141 | "metadata": { 142 | "id": "VRPt-bXLS0M2" 143 | }, 144 | "execution_count": null, 145 | "outputs": [] 146
| } 147 | ] 148 | } -------------------------------------------------------------------------------- /math-answers/q3_answers.md: -------------------------------------------------------------------------------- 1 | 2 | ## Q3-1 3 | Find the magnitude (absolute value) of the following vector. 4 | 5 | $$ 6 | \mathbf{a}=(15,\,36) 7 | $$ 8 | 9 | --- 10 | 11 | ### Answer 12 | $$ 13 | |\mathbf{a}|=\sqrt{15^2+36^2} 14 | =\sqrt{225+1296} 15 | =\sqrt{1521} 16 | =39 17 | $$ 18 | 19 | --- 20 | 21 | ## Q3-2 22 | Find the distance between the following two vectors. 23 | 24 | $$ 25 | \mathbf{b}=(9,\,2,\,4),\quad 26 | \mathbf{c}=(6,\,-3,\,8) 27 | $$ 28 | 29 | --- 30 | 31 | ### Answer 32 | The distance is the magnitude of the difference vector. 33 | 34 | $$ 35 | \mathbf{b}-\mathbf{c} 36 | =(9-6,\;2-(-3),\;4-8) 37 | =(3,\,5,\,-4) 38 | $$ 39 | 40 | $$ 41 | \text{distance} 42 | =\sqrt{3^2+5^2+(-4)^2} 43 | =\sqrt{9+25+16} 44 | =\sqrt{50} 45 | =5\sqrt{2} 46 | $$ 47 | 48 | --- 49 | 50 | ## Q3-3 51 | Find the angle between the following two vectors. 52 | 53 | $$ 54 | \mathbf{u}=(4,\,1,\,1),\quad 55 | \mathbf{v}=(1,\,1,\,4) 56 | $$ 57 | 58 | --- 59 | 60 | ### Answer 61 | First compute the inner product. 62 | 63 | $$ 64 | \mathbf{u}\cdot\mathbf{v} 65 | =4\cdot1+1\cdot1+1\cdot4 66 | =4+1+4 67 | =9 68 | $$ 69 | 70 | Next, find the magnitude of each vector. 71 | 72 | $$ 73 | |\mathbf{u}|=\sqrt{4^2+1^2+1^2} 74 | =\sqrt{16+1+1} 75 | =\sqrt{18} 76 | $$ 77 | 78 | $$ 79 | |\mathbf{v}|=\sqrt{1^2+1^2+4^2} 80 | =\sqrt{1+1+16} 81 | =\sqrt{18} 82 | $$ 83 | 84 | Letting $\theta$ be the angle between the two vectors, 85 | 86 | $$ 87 | \cos\theta 88 | =\frac{\mathbf{u}\cdot\mathbf{v}}{|\mathbf{u}||\mathbf{v}|} 89 | =\frac{9}{\sqrt{18}\sqrt{18}} 90 | =\frac{9}{18} 91 | =\frac{1}{2} 92 | $$ 93 | 94 | Since $0^\circ\le\theta\le 180^\circ$, 95 | 96 | $$ 97 | \theta=60^\circ 98 | $$ 99 | 100 | --- 101 | 102 | ## Q3-4 103 | Compute the **cosine similarity** between each pair of the following three vectors, 104 | and find the pair of vectors that are "closest" to each other. 105 | 106 | $$ 107 | \mathbf{a}=(-3,\,-3,\,-3,\,-3) 108 | $$ 109 | 110 | $$ 111 | \mathbf{b}=(-3,\,-2,\,-2,\,1) 112 | $$ 113 | 114 | $$ 115 | \mathbf{c}=(-4,\,1,\,4,\,-2) 116 | $$ 117 | 118 | --- 119 | 120 | ### Answer 121 | The cosine similarity of two vectors $\mathbf{x}$ and $\mathbf{y}$ is 122 | 123 | $$ 124 | \cos(\mathbf{x},\mathbf{y}) 125 |
=\frac{\mathbf{x}\cdot\mathbf{y}}{|\mathbf{x}||\mathbf{y}|} 126 | $$ 127 | 128 | that is, the cosine of the angle between the two vectors (the same quantity computed in Q3-3). 129 | 130 | --- 131 | 132 | ### 1) Norms 133 | $$ 134 | |\mathbf{a}|=\sqrt{(-3)^2+(-3)^2+(-3)^2+(-3)^2} 135 | =\sqrt{36} 136 | =6 137 | $$ 138 | 139 | $$ 140 | |\mathbf{b}|=\sqrt{(-3)^2+(-2)^2+(-2)^2+1^2} 141 | =\sqrt{9+4+4+1} 142 | =\sqrt{18} 143 | $$ 144 | 145 | $$ 146 | |\mathbf{c}|=\sqrt{(-4)^2+1^2+4^2+(-2)^2} 147 | =\sqrt{16+1+16+4} 148 | =\sqrt{37} 149 | $$ 150 | 151 | --- 152 | 153 | ### 2) Inner products and cosine similarities 154 | 155 | #### (a, b) 156 | $$ 157 | \mathbf{a}\cdot\mathbf{b} 158 | =(-3)(-3)+(-3)(-2)+(-3)(-2)+(-3)(1) 159 | =9+6+6-3 160 | =18 161 | $$ 162 | 163 | $$ 164 | \cos(\mathbf{a},\mathbf{b}) 165 | =\frac{18}{6\sqrt{18}} 166 | =\frac{3}{\sqrt{18}} 167 | =\frac{1}{\sqrt{2}} 168 | $$ 169 | 170 | #### (a, c) 171 | $$ 172 | \mathbf{a}\cdot\mathbf{c} 173 | =(-3)(-4)+(-3)(1)+(-3)(4)+(-3)(-2) 174 | =12-3-12+6 175 | =3 176 | $$ 177 | 178 | $$ 179 | \cos(\mathbf{a},\mathbf{c}) 180 | =\frac{3}{6\sqrt{37}} 181 | =\frac{1}{2\sqrt{37}} 182 | $$ 183 | 184 | #### (b, c) 185 | $$ 186 | \mathbf{b}\cdot\mathbf{c} 187 | =(-3)(-4)+(-2)(1)+(-2)(4)+1(-2) 188 | =12-2-8-2 189 | =0 190 | $$ 191 | 192 | $$ 193 | \cos(\mathbf{b},\mathbf{c})=0 194 | $$ 195 | 196 | --- 197 | 198 | ### 3) The closest pair 199 | $$ 200 | \cos(\mathbf{a},\mathbf{b})=\frac{1}{\sqrt{2}}\approx 0.707 201 | ,\quad 202 | \cos(\mathbf{a},\mathbf{c})=\frac{1}{2\sqrt{37}}\approx 0.082 203 | ,\quad 204 | \cos(\mathbf{b},\mathbf{c})=0 205 | $$ 206 | 207 | The cosine similarity is therefore largest for $\mathbf{a}$ and $\mathbf{b}$. 208 | 209 | $$ 210 | \boxed{\text{The closest pair is }\mathbf{a}\text{ and }\mathbf{b}} 211 | $$ 212 | 213 | ## Q3-5 214 | Compute the following matrix-vector product. 215 | 216 | $$ 217 | \begin{pmatrix} 218 | 2 & -1 \\ 219 | 4 & 3 \\ 220 | -5 & 6 221 | \end{pmatrix} 222 | \begin{pmatrix} 223 | 3 \\ 224 | 2 225 | \end{pmatrix} 226 | $$ 227 | 228 | --- 229 | 230 | ### Answer 231 | Each component of a matrix-vector product is the **inner product of the corresponding row with the vector**. 232 | 233 | #### First component 234 | $$ 235 | 2\cdot 3 + (-1)\cdot 2 236 | =6-2 237 | =4 238 | $$ 239 |
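The row-by-row rule used in this answer can be sketched in plain Python as a way to check all three components at once. This is a minimal illustration, not code from the book; `dot` and `mat_vec` are hypothetical helper names. In NumPy the same product is written with the `@` operator, as in the `X @ w` cell of the Chapter 7 NumPy notebook.

```python
# Matrix from Q3-5 (3 rows x 2 columns) and the vector (3, 2)
A = [[2, -1],
     [4, 3],
     [-5, 6]]
v = [3, 2]

def dot(row, vec):
    """Inner product of two equal-length sequences."""
    return sum(r * x for r, x in zip(row, vec))

def mat_vec(matrix, vec):
    """Matrix-vector product: one inner product per row of the matrix."""
    return [dot(row, vec) for row in matrix]

result = mat_vec(A, v)
print(result)  # [4, 18, -3]
```

Writing the product as one `dot` call per row mirrors the hand calculation; a library routine computes the same values but hides the per-row structure.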
240 | #### Second component 241 | $$ 242 | 4\cdot 3 + 3\cdot 2 243 | =12+6 244 | =18 245 | $$ 246 | 247 | #### Third component 248 | $$ 249 | (-5)\cdot 3 + 6\cdot 2 250 | =-15+12 251 | =-3 252 | $$ 253 | 254 | ### Result 255 | 256 | $$ 257 | \begin{pmatrix} 258 | 2 & -1 \\ 259 | 4 & 3 \\ 260 | -5 & 6 261 | \end{pmatrix} 262 | \begin{pmatrix} 263 | 3 \\ 264 | 2 265 | \end{pmatrix} 266 | = 267 | \begin{pmatrix} 268 | 4 \\ 269 | 18 \\ 270 | -3 271 | \end{pmatrix} 272 | $$ 273 | 274 | -------------------------------------------------------------------------------- /notebooks/11_02_num_diff.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "provenance": [], 7 | "toc_visible": true 8 | }, 9 | "kernelspec": { 10 | "name": "python3", 11 | "display_name": "Python 3" 12 | }, 13 | "language_info": { 14 | "name": "python" 15 | } 16 | }, 17 | "cells": [ 18 | { 19 | "cell_type": "markdown", 20 | "source": [ 21 | "## 11.2 Numerical differentiation" 22 | ], 23 | "metadata": { 24 | "id": "xUFtXDp0eSHj" 25 | } 26 | }, 27 | { 28 | "cell_type": "markdown", 29 | "source": [ 30 | "### Install required libraries" 31 | ], 32 | "metadata": { 33 | "id": "eWC1xCN3exGZ" 34 | } 35 | }, 36 | { 37 | "cell_type": "code", 38 | "source": [ 39 | "# Install extra libraries: Japanese labels for matplotlib, torchviz, torchinfo\n", 40 | "!pip install japanize-matplotlib -qq\n", 41 | "!pip install torchviz -qq\n", 42 | "!pip install torchinfo -qq" 43 | ], 44 | "metadata": { 45 | "id": "lVRPvERSeo4m" 46 | }, 47 | "execution_count": null, 48 | "outputs": [] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "source": [ 53 | "# Common setup\n", 54 | "\n", 55 | "# Import required libraries\n", 56 | "import pandas as pd\n", 57 | "import numpy as np\n", 58 | "import matplotlib.pyplot as plt\n", 59 | "\n", 60 | "# Enable Japanese text in matplotlib\n", 61 | "import japanize_matplotlib\n", 62 | "\n", 63 | "# Show NumPy floats in a fixed 3-decimal format\n", 64 | "np.set_printoptions(formatter={'float': '{:0.3f}'.format})\n", 65 | "\n", 66 | "# Function for displaying DataFrames\n", 67 | "from
IPython.display import display\n", 68 | "\n", 69 | "# Floating-point display precision in pandas\n", 70 | "pd.options.display.float_format = '{:.3f}'.format\n", 71 | "\n", 72 | "# Hide unnecessary warnings\n", 73 | "import warnings\n", 74 | "warnings.filterwarnings('ignore')" 75 | ], 76 | "metadata": { 77 | "id": "Q0HTOcPOe3qg" 78 | }, 79 | "execution_count": null, 80 | "outputs": [] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "source": [ 85 | "# Import PyTorch-related libraries\n", 86 | "import torch\n", 87 | "import torch.nn as nn\n", 88 | "import torch.optim as optim\n", 89 | "from torchinfo import summary\n", 90 | "from torchviz import make_dot" 91 | ], 92 | "metadata": { 93 | "id": "baZFF0XXe74D" 94 | }, 95 | "execution_count": null, 96 | "outputs": [] 97 | }, 98 | { 99 | "cell_type": "markdown", 100 | "source": [ 101 | "### Numerical differentiation example with PyTorch" 102 | ], 103 | "metadata": { 104 | "id": "iwMVLdrKfK-8" 105 | } 106 | }, 107 | { 108 | "cell_type": "markdown", 109 | "source": [ 110 | "#### Numerical differentiation of a quadratic function" 111 | ], 112 | "metadata": { 113 | "id": "7ZhKXh_C8Tf5" 114 | } 115 | }, 116 | { 117 | "cell_type": "code", 118 | "source": [ 119 | "# Numerical differentiation of a quadratic function\n", 120 | "\n", 121 | "# Define x as a NumPy array\n", 122 | "x_np = np.arange(-2, 2.1, 0.25)\n", 123 | "\n", 124 | "# Define x as a tensor whose gradient will be computed\n", 125 | "x = torch.tensor(x_np, requires_grad=True,\n", 126 | " dtype=torch.float32)\n", 127 | "\n", 128 | "# Compute the quadratic function\n", 129 | "# The computational graph is built automatically behind the scenes\n", 130 | "y = 2 * x**2 + 2\n", 131 | "\n", 132 | "# backward() needs a scalar, so reduce the 1-D tensor of function values with sum()\n", 133 | "z = y.sum()\n", 134 | "\n", 135 | "# Compute the gradient\n", 136 | "z.backward()\n", 137 | "\n", 138 | "# Retrieve the gradient values (x.grad holds dz/dx = 4x)\n", 139 | "print('values of x')\n", 140 | "print(x)\n", 141 | "print('gradient of x')\n", 142 | "print(x.grad)" 143 | ], 144 | "metadata": { 145 | "id": "zaZPdRmKe-sF" 146 | }, 147 | "execution_count": null, 148 | "outputs": [] 149 | }, 150 | { 151 | "cell_type": "markdown", 152 | "source": [ 153 | "#### Graph of the quadratic function and its gradient" 154 | ], 155 | "metadata": { 156 | "id": "f9WA2v678iP5" 157 | } 158 | }, 159 | { 160 | "cell_type": "code", 161 | "source": [ 162 | "# Graph of the quadratic function and its gradient\n", 163 | "\n", 164 | "plt.figure(figsize=(6,6))\n", 165 | "plt.plot(x.data, y.data, c='b', label='y')\n", 166 | "plt.plot(x.data, x.grad.data, c='k', label='x.grad')\n", 167 | "plt.legend()\n", 168 | "plt.grid()\n", 169 | "plt.title('Quadratic function and its gradient')\n", 170 | "plt.show()" 171 | ], 172 | "metadata": { 173 | "id": "-hhBAqWFfSjU" 174 | }, 175 | "execution_count": null, 176 | "outputs": [] 177 | }, 178 | { 179 | "cell_type": "markdown", 180 | "source": [ 181 | "#### Visualizing the computational graph" 182 | ], 183 | "metadata": { 184 | "id": "dZXTbx8H8EAd" 185 | } 186 | }, 187 | { 188 | "cell_type": "code", 189 | "source": [ 190 | "# Visualize the computational graph\n", 191 | "\n", 192 | "# Import the required library\n", 193 | "from torchviz import make_dot\n", 194 | "\n", 195 | "# Call the visualization function\n", 196 | "g = make_dot(z, params={'x': x})\n", 197 | "display(g)" 198 | ], 199 | "metadata": { 200 | "id": "JOdXStlU7qSG" 201 | }, 202 | "execution_count": null, 203 | "outputs": [] 204 | }, 205 | { 206 | "cell_type": "markdown", 207 | "source": [ 208 | "### Check library versions" 209 | ], 210 | "metadata": { 211 | "id": "OMYT26W-9IgV" 212 | } 213 | }, 214 | { 215 | "cell_type": "code", 216 | "source": [ 217 | "!pip install watermark -qq\n", 218 | "%load_ext watermark\n", 219 | "%watermark --iversions" 220 | ], 221 | "metadata": { 222 | "id": "k0GOosfyfyDH" 223 | }, 224 | "execution_count": null, 225 | "outputs": [] 226 | }, 227 | { 228 | "cell_type": "code", 229 | "source": [], 230 | "metadata": { 231 | "id": "FlJnFEDKf_yw" 232 | }, 233 | "execution_count": null, 234 | "outputs": [] 235 | } 236 | ] 237 | } -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions.
8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /notebooks/11_01_pytorch.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "provenance": [], 7 | "toc_visible": true, 8 | "gpuType": "T4" 9 | }, 10 | "kernelspec": { 11 | "name": "python3", 12 | "display_name": "Python 3" 13 | }, 14 | "language_info": { 15 | "name": "python" 16 | }, 17 | "accelerator": "GPU" 18 | }, 19 | "cells": [ 20 | { 21 | "cell_type": "markdown", 22 | "source": [ 23 | "### 11.1 Using a Framework" 24 | ], 25 | "metadata": { 26 | "id": "LM12uBlAv7Nh" 27 | } 28 | }, 29 | { 30 | "cell_type": "markdown", 31 | "source": [ 32 | "### Libraries and Environment Setup" 33 | ], 34 | "metadata": { 35 | "id": "jylsKAu3xsAR" 36 | } 37 | }, 38 | { 39 | "cell_type": "markdown", 40 | "source": [ 41 | "#### Install Libraries" 42 | ], 43 | "metadata": { 44 | "id": "SSq84FlOwK_d" 45 | } 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": null, 50 | "metadata": { 51 | "id": "L9gNlklzvwP5" 52 | }, 53 | "outputs": [], 54 | "source": [ 55 | "# 
Install additional required libraries\n", 56 | "!pip install japanize-matplotlib -qq\n", 57 | "!pip install torchviz -qq\n", 58 | "!pip install torchinfo -qq" 59 | ] 60 | }, 61 | { 62 | "cell_type": "markdown", 63 | "source": [ 64 | "#### Import Libraries" 65 | ], 66 | "metadata": { 67 | "id": "QvjcvAf8xLQ6" 68 | } 69 | }, 70 | { 71 | "cell_type": "code", 72 | "source": [ 73 | "import warnings\n", 74 | "import numpy as np\n", 75 | "import pandas as pd\n", 76 | "import matplotlib.pyplot as plt\n", 77 | "import japanize_matplotlib\n", 78 | "import torch\n", 79 | "import torch.nn as nn\n", 80 | "import torch.optim as optim\n", 81 | "from torch.utils.data import DataLoader\n", 82 | "from torchvision import datasets, transforms\n", 83 | "from torchinfo import summary\n", 84 | "from tqdm.notebook import tqdm" 85 | ], 86 | "metadata": { 87 | "id": "Ht5HZ0vSv_2K" 88 | }, 89 | "execution_count": null, 90 | "outputs": [] 91 | }, 92 | { 93 | "cell_type": "markdown", 94 | "source": [ 95 | "#### Environment Settings" 96 | ], 97 | "metadata": { 98 | "id": "rS6Q8vz_xjxN" 99 | } 100 | }, 101 | { 102 | "cell_type": "code", 103 | "source": [ 104 | "np.set_printoptions(formatter={'float': '{:0.3f}'.format})\n", 105 | "pd.options.display.float_format = '{:.3f}'.format\n", 106 | "warnings.filterwarnings('ignore')" 107 | ], 108 | "metadata": { 109 | "id": "JDsTIanrxiuT" 110 | }, 111 | "execution_count": null, 112 | "outputs": [] 113 | }, 114 | { 115 | "cell_type": "markdown", 116 | "source": [ 117 | "#### Check for a GPU" 118 | ], 119 | "metadata": { 120 | "id": "cfPB0_4CyYas" 121 | } 122 | }, 123 | { 124 | "cell_type": "code", 125 | "source": [ 126 | "# Check for a GPU\n", 127 | "\n", 128 | "# Check whether a GPU is available\n", 129 | "device = torch.device(\"cuda:0\"\n", 130 | "    if torch.cuda.is_available() else \"cpu\")\n", 131 | "\n", 132 | "# Prints \"cuda:0\" when a GPU is available\n", 133 | "print(device)" 134 | ], 135 | "metadata": { 136 | "id": "MEWmXgQPx_Tg" 137 | }, 138 | "execution_count": null, 139 | "outputs": [] 140 | }, 141 | { 142 | "cell_type": 
"markdown", 143 | "source": [ 144 | "### Data Preparation" 145 | ], 146 | "metadata": { 147 | "id": "Aa7AHIUkVmDe" 148 | } 149 | }, 150 | { 151 | "cell_type": "markdown", 152 | "source": [ 153 | "#### Build Data Loaders" 154 | ], 155 | "metadata": { 156 | "id": "ghXrMUzUy0iK" 157 | } 158 | }, 159 | { 160 | "cell_type": "code", 161 | "source": [ 162 | "# Build data loaders\n", 163 | "def get_data_loaders(batch_size=100, data_dir=\"./data\"):\n", 164 | "    \"\"\"Return the MNIST training and test sets as DataLoaders.\"\"\"\n", 165 | "    transform = transforms.Compose([\n", 166 | "        transforms.ToTensor(),\n", 167 | "        transforms.Normalize((0.1307,), (0.3081,))\n", 168 | "    ])\n", 169 | "\n", 170 | "    train_set = datasets.MNIST(root=data_dir, train=True, download=True, transform=transform)\n", 171 | "    test_set = datasets.MNIST(root=data_dir, train=False, download=True, transform=transform)\n", 172 | "\n", 173 | "    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, pin_memory=True)\n", 174 | "    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, pin_memory=True)\n", 175 | "\n", 176 | "    return train_loader, test_loader\n", 177 | "\n", 178 | "batch_size = 100\n", 179 | "train_loader, test_loader = get_data_loaders(batch_size)" 180 | ], 181 | "metadata": { 182 | "id": "78o7UspryhQv" 183 | }, 184 | "execution_count": null, 185 | "outputs": [] 186 | }, 187 | { 188 | "cell_type": "markdown", 189 | "source": [ 190 | "### Model Construction" 191 | ], 192 | "metadata": { 193 | "id": "6u7rvOBVzGWJ" 194 | } 195 | }, 196 | { 197 | "cell_type": "markdown", 198 | "source": [ 199 | "#### Model Definition" 200 | ], 201 | "metadata": { 202 | "id": "qP-2o1IeV0_e" 203 | } 204 | }, 205 | { 206 | "cell_type": "code", 207 | "source": [ 208 | "# Model definition\n", 209 | "\n", 210 | "class Net(nn.Module):\n", 211 | "    \"\"\"Simple NN with one fully connected hidden layer (MNIST defaults: 28*28 -> 100 -> 10)\"\"\"\n", 212 | "    def __init__(self, n_input=28*28, n_hidden=100, n_output=10):\n", 213 | "        super().__init__()\n", 214 | "        self.net = nn.Sequential(\n", 215 | "            nn.Linear(n_input, 
n_hidden),\n", 216 | "            nn.ReLU(inplace=True),\n", 217 | "            nn.Linear(n_hidden, n_output)\n", 218 | "        )\n", 219 | "\n", 220 | "    def forward(self, x):\n", 221 | "        # (B, 1, 28, 28) -> (B, 784)\n", 222 | "        x = torch.flatten(x, 1)\n", 223 | "        return self.net(x)" 224 | ], 225 | "metadata": { 226 | "id": "0YPkV7sRy8zu" 227 | }, 228 | "execution_count": null, 229 | "outputs": [] 230 | }, 231 | { 232 | "cell_type": "markdown", 233 | "source": [ 234 | "#### Training Function" 235 | ], 236 | "metadata": { 237 | "id": "a1-SrTEDzkPD" 238 | } 239 | }, 240 | { 241 | "cell_type": "code", 242 | "source": [ 243 | "# Training function (one epoch)\n", 244 | "\n", 245 | "def train_one_epoch(model, loader, criterion, optimizer):\n", 246 | "    model.train()\n", 247 | "    running_loss, correct, total = 0.0, 0, 0\n", 248 | "\n", 249 | "    for inputs, labels in loader:\n", 250 | "        inputs = inputs.to(device)\n", 251 | "        labels = labels.to(device)\n", 252 | "\n", 253 | "        optimizer.zero_grad()\n", 254 | "        outputs = model(inputs)\n", 255 | "        loss = criterion(outputs, labels)\n", 256 | "        loss.backward()\n", 257 | "        optimizer.step()\n", 258 | "\n", 259 | "        with torch.no_grad():\n", 260 | "            preds = outputs.argmax(1)\n", 261 | "            running_loss += loss.item() * labels.size(0)\n", 262 | "            correct += (preds == labels).sum().item()\n", 263 | "            total += labels.size(0)\n", 264 | "\n", 265 | "    return running_loss / total, correct / total" 266 | ], 267 | "metadata": { 268 | "id": "IC62CIeKzWHw" 269 | }, 270 | "execution_count": null, 271 | "outputs": [] 272 | }, 273 | { 274 | "cell_type": "markdown", 275 | "source": [ 276 | "#### Validation Function" 277 | ], 278 | "metadata": { 279 | "id": "Z1wa7CvWSALc" 280 | } 281 | }, 282 | { 283 | "cell_type": "code", 284 | "source": [ 285 | "# Validation function\n", 286 | "\n", 287 | "@torch.no_grad()\n", 288 | "def validate(model, loader, criterion):\n", 289 | "    model.eval()\n", 290 | "    running_loss, correct, total = 0.0, 0, 0\n", 291 | "\n", 292 | "    for inputs, labels in loader:\n", 293 | "        inputs = inputs.to(device)\n", 294 | "        labels 
= labels.to(device)\n", 295 | "\n", 296 | "        outputs = model(inputs)\n", 297 | "        loss = criterion(outputs, labels)\n", 298 | "        preds = outputs.argmax(1)\n", 299 | "\n", 300 | "        running_loss += loss.item() * labels.size(0)\n", 301 | "        correct += (preds == labels).sum().item()\n", 302 | "        total += labels.size(0)\n", 303 | "\n", 304 | "    return running_loss / total, correct / total" 305 | ], 306 | "metadata": { 307 | "id": "nJDJfVU1zvS-" 308 | }, 309 | "execution_count": null, 310 | "outputs": [] 311 | }, 312 | { 313 | "cell_type": "markdown", 314 | "source": [ 315 | "#### Fit Function" 316 | ], 317 | "metadata": { 318 | "id": "FbNqnE-PUS2K" 319 | } 320 | }, 321 | { 322 | "cell_type": "code", 323 | "source": [ 324 | "def fit(*, n_hidden=100, num_epochs=20, lr=0.01, optimizer_class=optim.SGD, seed=42):\n", 325 | "    \"\"\"\n", 326 | "    Arguments (all keyword-only):\n", 327 | "        n_hidden        : number of hidden-layer units\n", 328 | "        num_epochs      : number of training epochs\n", 329 | "        lr              : learning rate\n", 330 | "        optimizer_class : e.g. optim.SGD / optim.Adam\n", 331 | "        seed            : random seed\n", 332 | "\n", 333 | "    Returns:\n", 334 | "        model, history (np.ndarray: [epoch, train_loss, train_acc, val_loss, val_acc])\n", 335 | "\n", 336 | "    Depends on (globals):\n", 337 | "        train_loader, test_loader, device\n", 338 | "    \"\"\"\n", 339 | "    torch.manual_seed(seed)\n", 340 | "    np.random.seed(seed)\n", 341 | "\n", 342 | "    model = Net(n_hidden=n_hidden).to(device)\n", 343 | "    criterion = nn.CrossEntropyLoss()\n", 344 | "    optimizer = optimizer_class(model.parameters(), lr=lr)\n", 345 | "\n", 346 | "    history = []\n", 347 | "    for epoch in tqdm(range(1, num_epochs + 1), desc=\"Training\"):\n", 348 | "        tr_loss, tr_acc = train_one_epoch(model, train_loader, criterion, optimizer)\n", 349 | "        va_loss, va_acc = validate(model, test_loader, criterion)\n", 350 | "\n", 351 | "        print(\n", 352 | "            f\"Epoch [{epoch}/{num_epochs}] \"\n", 353 | "            f\"train_loss: {tr_loss:.5f}, train_acc: {tr_acc:.5f}, \"\n", 354 | "            f\"val_loss: {va_loss:.5f}, val_acc: {va_acc:.5f}\"\n", 355 | "        )\n", 356 | "        
history.append([epoch, tr_loss, tr_acc, va_loss, va_acc])\n", 357 | "\n", 358 | "    return model, np.array(history, dtype=float)" 359 | ], 360 | "metadata": { 361 | "id": "V4gsqUMzUX2M" 362 | }, 363 | "execution_count": null, 364 | "outputs": [] 365 | }, 366 | { 367 | "cell_type": "markdown", 368 | "source": [ 369 | "#### Training" 370 | ], 371 | "metadata": { 372 | "id": "8E69wVDmz38y" 373 | } 374 | }, 375 | { 376 | "cell_type": "code", 377 | "source": [ 378 | "model1, history1 = fit(num_epochs=20, lr=0.01, optimizer_class=optim.SGD)" 379 | ], 380 | "metadata": { 381 | "id": "VE0ijWPlzzpy" 382 | }, 383 | "execution_count": null, 384 | "outputs": [] 385 | }, 386 | { 387 | "cell_type": "markdown", 388 | "source": [ 389 | "### Results" 390 | ], 391 | "metadata": { 392 | "id": "JZRkD-xWWQRD" 393 | } 394 | }, 395 | { 396 | "cell_type": "markdown", 397 | "source": [ 398 | "#### Plot the Learning Curves" 399 | ], 400 | "metadata": { 401 | "id": "luVmPDGB0Yat" 402 | } 403 | }, 404 | { 405 | "cell_type": "code", 406 | "source": [ 407 | "def plot_learning_curves(history, title_suffix=\"\"):\n", 408 | "    \"\"\"Plot the loss and accuracy learning curves side by side.\"\"\"\n", 409 | "    epochs = history[:, 0]\n", 410 | "    train_loss = history[:, 1]\n", 411 | "    train_acc = history[:, 2]\n", 412 | "    val_loss = history[:, 3]\n", 413 | "    val_acc = history[:, 4]\n", 414 | "\n", 415 | "    plt.figure(figsize=(8, 4), tight_layout=True)\n", 416 | "\n", 417 | "    # --- Loss curves ---\n", 418 | "    plt.subplot(1, 2, 1)\n", 419 | "    plt.plot(epochs, train_loss, label='train', color='b')\n", 420 | "    plt.plot(epochs, val_loss, label='test', color='k', linestyle='--')\n", 421 | "    plt.xlabel('Epoch')\n", 422 | "    plt.ylabel('Loss')\n", 423 | "    plt.title(f'Learning curve (loss){title_suffix}')\n", 424 | "    plt.legend()\n", 425 | "    plt.grid(True)\n", 426 | "\n", 427 | "    # --- Accuracy curves ---\n", 428 | "    plt.subplot(1, 2, 2)\n", 429 | "    plt.plot(epochs, train_acc, label='train', color='b')\n", 430 | "    plt.plot(epochs, val_acc, label='test', color='k', linestyle='--')\n", 431 | "    plt.xlabel('Epoch')\n", 432
| "    plt.ylabel('Accuracy')\n", 433 | "    plt.title(f'Learning curve (accuracy){title_suffix}')\n", 434 | "    plt.legend()\n", 435 | "    plt.grid(True)\n", 436 | "    plt.show()\n", 437 | "\n", 438 | "plot_learning_curves(history1)" 439 | ], 440 | "metadata": { 441 | "id": "5W9AnVQ80AXs" 442 | }, 443 | "execution_count": null, 444 | "outputs": [] 445 | }, 446 | { 447 | "cell_type": "code", 448 | "source": [ 449 | "model2, history2 = fit(num_epochs=50, lr=0.01, optimizer_class=optim.SGD)\n", 450 | "plot_learning_curves(history2)" 451 | ], 452 | "metadata": { 453 | "id": "Ytx6UaAvZqAU" 454 | }, 455 | "execution_count": null, 456 | "outputs": [] 457 | }, 458 | { 459 | "cell_type": "markdown", 460 | "source": [ 461 | "### Version Check" 462 | ], 463 | "metadata": { 464 | "id": "HgzYQ9vc0vtU" 465 | } 466 | }, 467 | { 468 | "cell_type": "code", 469 | "source": [ 470 | "!pip install watermark -qq\n", 471 | "%load_ext watermark\n", 472 | "%watermark --iversions" 473 | ], 474 | "metadata": { 475 | "id": "VSZrDXEP0oXe" 476 | }, 477 | "execution_count": null, 478 | "outputs": [] 479 | }, 480 | { 481 | "cell_type": "code", 482 | "source": [], 483 | "metadata": { 484 | "id": "GQdgyLIt12RM" 485 | }, 486 | "execution_count": null, 487 | "outputs": [] 488 | } 489 | ] 490 | } -------------------------------------------------------------------------------- /notebooks/11_04_batch_size.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "provenance": [], 7 | "toc_visible": true, 8 | "gpuType": "T4" 9 | }, 10 | "kernelspec": { 11 | "name": "python3", 12 | "display_name": "Python 3" 13 | }, 14 | "language_info": { 15 | "name": "python" 16 | }, 17 | "accelerator": "GPU" 18 | }, 19 | "cells": [ 20 | { 21 | "cell_type": "markdown", 22 | "source": [ 23 | "### 11.4 Units of Training" 24 | ], 25 | "metadata": { 26 | "id": "LM12uBlAv7Nh" 27 | } 28 | }, 29 | { 30 | "cell_type": "markdown", 31 | "source": [ 32 | "### 
Libraries and Environment Setup" 33 | ], 34 | "metadata": { 35 | "id": "jylsKAu3xsAR" 36 | } 37 | }, 38 | { 39 | "cell_type": "markdown", 40 | "source": [ 41 | "#### Install Libraries" 42 | ], 43 | "metadata": { 44 | "id": "SSq84FlOwK_d" 45 | } 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": null, 50 | "metadata": { 51 | "id": "L9gNlklzvwP5" 52 | }, 53 | "outputs": [], 54 | "source": [ 55 | "# Install additional required libraries\n", 56 | "!pip install japanize-matplotlib -qq\n", 57 | "!pip install torchviz -qq\n", 58 | "!pip install torchinfo -qq" 59 | ] 60 | }, 61 | { 62 | "cell_type": "markdown", 63 | "source": [ 64 | "#### Import Libraries" 65 | ], 66 | "metadata": { 67 | "id": "QvjcvAf8xLQ6" 68 | } 69 | }, 70 | { 71 | "cell_type": "code", 72 | "source": [ 73 | "import warnings\n", 74 | "import numpy as np\n", 75 | "import pandas as pd\n", 76 | "import matplotlib.pyplot as plt\n", 77 | "import japanize_matplotlib\n", 78 | "import torch\n", 79 | "import torch.nn as nn\n", 80 | "import torch.optim as optim\n", 81 | "from torch.utils.data import DataLoader\n", 82 | "from torchvision import datasets, transforms\n", 83 | "from torchinfo import summary\n", 84 | "from tqdm.notebook import tqdm" 85 | ], 86 | "metadata": { 87 | "id": "Ht5HZ0vSv_2K" 88 | }, 89 | "execution_count": null, 90 | "outputs": [] 91 | }, 92 | { 93 | "cell_type": "markdown", 94 | "source": [ 95 | "#### Environment Settings" 96 | ], 97 | "metadata": { 98 | "id": "rS6Q8vz_xjxN" 99 | } 100 | }, 101 | { 102 | "cell_type": "code", 103 | "source": [ 104 | "np.set_printoptions(formatter={'float': '{:0.3f}'.format})\n", 105 | "pd.options.display.float_format = '{:.3f}'.format\n", 106 | "warnings.filterwarnings('ignore')" 107 | ], 108 | "metadata": { 109 | "id": "JDsTIanrxiuT" 110 | }, 111 | "execution_count": null, 112 | "outputs": [] 113 | }, 114 | { 115 | "cell_type": "markdown", 116 | "source": [ 117 | "#### Check for a GPU" 118 | ], 119 | "metadata": { 120 | "id": "cfPB0_4CyYas" 121 | } 122 | }, 123 | { 124 | "cell_type": "code", 125 | "source": [ 126 | "# 
Check for a GPU\n", 127 | "\n", 128 | "# Check whether a GPU is available\n", 129 | "device = torch.device(\"cuda:0\"\n", 130 | "    if torch.cuda.is_available() else \"cpu\")\n", 131 | "\n", 132 | "# Prints \"cuda:0\" when a GPU is available\n", 133 | "print(device)" 134 | ], 135 | "metadata": { 136 | "id": "MEWmXgQPx_Tg" 137 | }, 138 | "execution_count": null, 139 | "outputs": [] 140 | }, 141 | { 142 | "cell_type": "markdown", 143 | "source": [ 144 | "### Data Preparation" 145 | ], 146 | "metadata": { 147 | "id": "Aa7AHIUkVmDe" 148 | } 149 | }, 150 | { 151 | "cell_type": "markdown", 152 | "source": [ 153 | "#### Build Data Loaders" 154 | ], 155 | "metadata": { 156 | "id": "ghXrMUzUy0iK" 157 | } 158 | }, 159 | { 160 | "cell_type": "code", 161 | "source": [ 162 | "# Build data loaders\n", 163 | "def get_data_loaders(batch_size=100, data_dir=\"./data\"):\n", 164 | "    \"\"\"Return the MNIST training and test sets as DataLoaders.\"\"\"\n", 165 | "    transform = transforms.Compose([\n", 166 | "        transforms.ToTensor(),\n", 167 | "        transforms.Normalize((0.1307,), (0.3081,))\n", 168 | "    ])\n", 169 | "\n", 170 | "    train_set = datasets.MNIST(root=data_dir, train=True, download=True, transform=transform)\n", 171 | "    test_set = datasets.MNIST(root=data_dir, train=False, download=True, transform=transform)\n", 172 | "\n", 173 | "    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, pin_memory=True)\n", 174 | "    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, pin_memory=True)\n", 175 | "\n", 176 | "    return train_loader, test_loader\n", 177 | "\n", 178 | "batch_size = 100\n", 179 | "train_loader, test_loader = get_data_loaders(batch_size)" 180 | ], 181 | "metadata": { 182 | "id": "78o7UspryhQv" 183 | }, 184 | "execution_count": null, 185 | "outputs": [] 186 | }, 187 | { 188 | "cell_type": "markdown", 189 | "source": [ 190 | "### Model Construction" 191 | ], 192 | "metadata": { 193 | "id": "6u7rvOBVzGWJ" 194 | } 195 | }, 196 | { 197 | "cell_type": "markdown", 198 | "source": [ 199 | "#### Model Definition" 200 | ], 201 | "metadata": { 202 | "id": 
"qP-2o1IeV0_e" 203 | } 204 | }, 205 | { 206 | "cell_type": "code", 207 | "source": [ 208 | "# Model definition\n", 209 | "\n", 210 | "class Net(nn.Module):\n", 211 | "    \"\"\"Simple NN with one fully connected hidden layer (MNIST defaults: 28*28 -> 100 -> 10)\"\"\"\n", 212 | "    def __init__(self, n_input=28*28, n_hidden=100, n_output=10):\n", 213 | "        super().__init__()\n", 214 | "        self.net = nn.Sequential(\n", 215 | "            nn.Linear(n_input, n_hidden),\n", 216 | "            nn.ReLU(inplace=True),\n", 217 | "            nn.Linear(n_hidden, n_output)\n", 218 | "        )\n", 219 | "\n", 220 | "    def forward(self, x):\n", 221 | "        # (B, 1, 28, 28) -> (B, 784)\n", 222 | "        x = torch.flatten(x, 1)\n", 223 | "        return self.net(x)" 224 | ], 225 | "metadata": { 226 | "id": "0YPkV7sRy8zu" 227 | }, 228 | "execution_count": null, 229 | "outputs": [] 230 | }, 231 | { 232 | "cell_type": "markdown", 233 | "source": [ 234 | "#### Training Function" 235 | ], 236 | "metadata": { 237 | "id": "a1-SrTEDzkPD" 238 | } 239 | }, 240 | { 241 | "cell_type": "code", 242 | "source": [ 243 | "# Training function (one epoch)\n", 244 | "\n", 245 | "def train_one_epoch(model, loader, criterion, optimizer):\n", 246 | "    model.train()\n", 247 | "    running_loss, correct, total = 0.0, 0, 0\n", 248 | "\n", 249 | "    for inputs, labels in loader:\n", 250 | "        inputs = inputs.to(device)\n", 251 | "        labels = labels.to(device)\n", 252 | "\n", 253 | "        # Reset gradients\n", 254 | "        optimizer.zero_grad()\n", 255 | "        # Forward pass (predictions)\n", 256 | "        outputs = model(inputs)\n", 257 | "        # Compute the loss\n", 258 | "        loss = criterion(outputs, labels)\n", 259 | "        # Compute gradients (derivatives)\n", 260 | "        loss.backward()\n", 261 | "        # Update parameters\n", 262 | "        optimizer.step()\n", 263 | "\n", 264 | "        with torch.no_grad():\n", 265 | "            preds = outputs.argmax(1)\n", 266 | "            running_loss += loss.item() * labels.size(0)\n", 267 | "            correct += (preds == labels).sum().item()\n", 268 | "            total += labels.size(0)\n", 269 | "\n", 270 | "    return running_loss / total, correct / total" 271 | ], 272 | "metadata": { 273 | "id": "IC62CIeKzWHw" 274 | }, 275 | "execution_count": null, 276 | "outputs": [] 277 | }, 278 | {
279 | "cell_type": "markdown", 280 | "source": [ 281 | "#### Validation Function" 282 | ], 283 | "metadata": { 284 | "id": "Z1wa7CvWSALc" 285 | } 286 | }, 287 | { 288 | "cell_type": "code", 289 | "source": [ 290 | "# Validation function\n", 291 | "\n", 292 | "@torch.no_grad()\n", 293 | "def validate(model, loader, criterion):\n", 294 | "    model.eval()\n", 295 | "    running_loss, correct, total = 0.0, 0, 0\n", 296 | "\n", 297 | "    for inputs, labels in loader:\n", 298 | "        inputs = inputs.to(device)\n", 299 | "        labels = labels.to(device)\n", 300 | "\n", 301 | "        outputs = model(inputs)\n", 302 | "        loss = criterion(outputs, labels)\n", 303 | "        preds = outputs.argmax(1)\n", 304 | "\n", 305 | "        running_loss += loss.item() * labels.size(0)\n", 306 | "        correct += (preds == labels).sum().item()\n", 307 | "        total += labels.size(0)\n", 308 | "\n", 309 | "    return running_loss / total, correct / total" 310 | ], 311 | "metadata": { 312 | "id": "nJDJfVU1zvS-" 313 | }, 314 | "execution_count": null, 315 | "outputs": [] 316 | }, 317 | { 318 | "cell_type": "markdown", 319 | "source": [ 320 | "#### Fit Function" 321 | ], 322 | "metadata": { 323 | "id": "FbNqnE-PUS2K" 324 | } 325 | }, 326 | { 327 | "cell_type": "code", 328 | "source": [ 329 | "def fit(*, net_class=Net, n_hidden=100, num_epochs=20, lr=0.01,\n", 330 | "        batch_size=100, optimizer_class=optim.SGD,\n", 331 | "        seed=42, data_dir=\"./data\"):\n", 332 | "    \"\"\"\n", 333 | "    Arguments (all keyword-only):\n", 334 | "        net_class       : model class (e.g. Net, CustomCNN)\n", 335 | "        n_hidden        : number of hidden-layer nodes (used only if net_class accepts it, e.g. Net)\n", 336 | "        num_epochs      : number of epochs\n", 337 | "        lr              : learning rate\n", 338 | "        batch_size      : batch size\n", 339 | "        optimizer_class : optimizer class (e.g. optim.SGD, optim.Adam)\n", 340 | "\n", 341 | "    Returns:\n", 342 | "        model, history (np.ndarray: [epoch, train_loss, train_acc, val_loss, val_acc])\n", 343 | "    \"\"\"\n", 344 | "    torch.manual_seed(seed)\n", 345 | "    np.random.seed(seed)\n", 346 | "\n", 347 | "    # Create DataLoaders with the specified batch size\n", 348 | "    train_loader, test_loader = get_data_loaders(batch_size=batch_size,\n", 349
| "            data_dir=data_dir)\n", 350 | "\n", 351 | "    # Build the model (fall back to no arguments if net_class does not accept n_hidden)\n", 352 | "    try:\n", 353 | "        model = net_class(n_hidden=n_hidden).to(device)\n", 354 | "    except TypeError:\n", 355 | "        model = net_class().to(device)\n", 356 | "\n", 357 | "    criterion = nn.CrossEntropyLoss()\n", 358 | "    optimizer = optimizer_class(model.parameters(), lr=lr)\n", 359 | "\n", 360 | "    history = []\n", 361 | "    for epoch in tqdm(range(1, num_epochs + 1), desc=\"Training\"):\n", 362 | "        tr_loss, tr_acc = train_one_epoch(model, train_loader, criterion,\n", 363 | "                optimizer)\n", 364 | "        va_loss, va_acc = validate(model, test_loader, criterion)\n", 365 | "\n", 366 | "        print(\n", 367 | "            f\"Epoch [{epoch}/{num_epochs}] \"\n", 368 | "            f\"train_loss: {tr_loss:.5f}, train_acc: {tr_acc:.5f}, \"\n", 369 | "            f\"val_loss: {va_loss:.5f}, val_acc: {va_acc:.5f}\"\n", 370 | "        )\n", 371 | "        history.append([epoch, tr_loss, tr_acc, va_loss, va_acc])\n", 372 | "\n", 373 | "    return model, np.array(history, dtype=float)" 374 | ], 375 | "metadata": { 376 | "id": "V4gsqUMzUX2M" 377 | }, 378 | "execution_count": null, 379 | "outputs": [] 380 | }, 381 | { 382 | "cell_type": "markdown", 383 | "source": [ 384 | "#### Training" 385 | ], 386 | "metadata": { 387 | "id": "8E69wVDmz38y" 388 | } 389 | }, 390 | { 391 | "cell_type": "code", 392 | "source": [ 393 | "batch_sizes = [500, 200, 100, 50]\n", 394 | "histories = []\n", 395 | "models = []\n", 396 | "for batch_size in batch_sizes:\n", 397 | "    model, history = fit(num_epochs=20, lr=0.01, batch_size=batch_size)\n", 398 | "    histories.append(history)\n", 399 | "    models.append(model)" 400 | ], 401 | "metadata": { 402 | "id": "Vfd2bVOcDu1F" 403 | }, 404 | "execution_count": null, 405 | "outputs": [] 406 | }, 407 | { 408 | "cell_type": "markdown", 409 | "source": [ 410 | "### Results" 411 | ], 412 | "metadata": { 413 | "id": "gxycwR5pKhbF" 414 | } 415 | }, 416 | { 417 | "cell_type": "markdown", 418 | "source": [ 419 | "#### Plotting Function" 420 | ], 421 | "metadata": { 422 | "id": "orXMGV7BUbAd" 
423 | } 424 | }, 425 | { 426 | "cell_type": "code", 427 | "source": [ 428 | "# Plotting function\n", 429 | "\n", 430 | "import matplotlib.pyplot as plt\n", 431 | "\n", 432 | "def plot_learning_curves_multi(histories, labels=None, title_suffix=\"\"):\n", 433 | "    \"\"\"\n", 434 | "    Overlay the test-data curves of several histories, using only black and blue.\n", 435 | "    Each model is distinguished by its combination of color and line style.\n", 436 | "    \"\"\"\n", 437 | "    plt.figure(figsize=(10, 4), tight_layout=True)\n", 438 | "\n", 439 | "    # --- Line style combinations (black and blue only) ---\n", 440 | "    colors = ['b', 'k', 'b', 'k']  # distinct for up to 4 curves\n", 441 | "    linestyles = ['-', '--', ':', '-.']  # distinguished by dash pattern\n", 442 | "\n", 443 | "    # --- Loss curves ---\n", 444 | "    plt.subplot(1, 2, 1)\n", 445 | "    for i, history in enumerate(histories):\n", 446 | "        epochs = history[:, 0]\n", 447 | "        train_loss = history[:, 1]\n", 448 | "        val_loss = history[:, 3]\n", 449 | "        color = colors[i % len(colors)]\n", 450 | "        ls_val = linestyles[i % len(linestyles)]  # one line style per model\n", 451 | "\n", 452 | "        label_val = f\"{labels[i]} (test)\" if labels else f\"model{i+1} (test)\"\n", 453 | "\n", 454 | "        plt.plot(epochs, val_loss, color=color, linestyle=ls_val, label=label_val)\n", 455 | "\n", 456 | "    plt.xlabel('Epoch')\n", 457 | "    plt.ylabel('Loss')\n", 458 | "    plt.title(f'Learning curve (loss){title_suffix}')\n", 459 | "    plt.legend(fontsize=8)\n", 460 | "    plt.grid(True)\n", 461 | "\n", 462 | "    # --- Accuracy curves ---\n", 463 | "    plt.subplot(1, 2, 2)\n", 464 | "    for i, history in enumerate(histories):\n", 465 | "        epochs = history[:, 0]\n", 466 | "        train_acc = history[:, 2]\n", 467 | "        val_acc = history[:, 4]\n", 468 | "        color = colors[i % len(colors)]\n", 469 | "        ls_val = linestyles[i % len(linestyles)]\n", 470 | "\n", 471 | "        label_val = f\"{labels[i]} (test)\" if labels else f\"model{i+1} (test)\"\n", 472 | "\n", 473 | "        plt.plot(epochs, val_acc, color=color, linestyle=ls_val, label=label_val)\n", 474 | "\n", 475 | "    plt.xlabel('Epoch')\n", 476 | "    plt.ylabel('Accuracy')\n", 477 | "    plt.xticks(np.arange(0, int(epochs.max()) + 1, 2))\n", 478 | "    plt.title(f'Learning curve (accuracy){title_suffix}')\n", 479 | "    
plt.legend(fontsize=8)\n", 480 | "    plt.grid(True)\n", 481 | "    plt.show()" 482 | ], 483 | "metadata": { 484 | "id": "VE0ijWPlzzpy" 485 | }, 486 | "execution_count": null, 487 | "outputs": [] 488 | }, 489 | { 490 | "cell_type": "markdown", 491 | "source": [ 492 | "#### Plot the Results" 493 | ], 494 | "metadata": { 495 | "id": "t1g4b_0hUfw2" 496 | } 497 | }, 498 | { 499 | "cell_type": "code", 500 | "source": [ 501 | "# Plot the results\n", 502 | "plot_learning_curves_multi(\n", 503 | "    histories,\n", 504 | "    labels=[\"500\", \"200\", \"100\", \"50\"],\n", 505 | "    title_suffix=\" (by batch size)\"\n", 506 | ")" 507 | ], 508 | "metadata": { 509 | "id": "4hLdyqQaFtFI" 510 | }, 511 | "execution_count": null, 512 | "outputs": [] 513 | }, 514 | { 515 | "cell_type": "markdown", 516 | "source": [ 517 | "### Version Check" 518 | ], 519 | "metadata": { 520 | "id": "HgzYQ9vc0vtU" 521 | } 522 | }, 523 | { 524 | "cell_type": "code", 525 | "source": [ 526 | "!pip install watermark -qq\n", 527 | "%load_ext watermark\n", 528 | "%watermark --iversions" 529 | ], 530 | "metadata": { 531 | "id": "VSZrDXEP0oXe" 532 | }, 533 | "execution_count": null, 534 | "outputs": [] 535 | }, 536 | { 537 | "cell_type": "code", 538 | "source": [], 539 | "metadata": { 540 | "id": "GQdgyLIt12RM" 541 | }, 542 | "execution_count": null, 543 | "outputs": [] 544 | } 545 | ] 546 | } -------------------------------------------------------------------------------- /notebooks/11_06_optimizarion.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "provenance": [], 7 | "toc_visible": true, 8 | "gpuType": "T4" 9 | }, 10 | "kernelspec": { 11 | "name": "python3", 12 | "display_name": "Python 3" 13 | }, 14 | "language_info": { 15 | "name": "python" 16 | }, 17 | "accelerator": "GPU" 18 | }, 19 | "cells": [ 20 | { 21 | "cell_type": "markdown", 22 | "source": [ 23 | "### 11.6 Optimizers" 24 | ], 25 | "metadata": { 26 | "id": 
"LM12uBlAv7Nh" 27 | } 28 | }, 29 | { 30 | "cell_type": "markdown", 31 | "source": [ 32 | "### Libraries and Environment Setup" 33 | ], 34 | "metadata": { 35 | "id": "jylsKAu3xsAR" 36 | } 37 | }, 38 | { 39 | "cell_type": "markdown", 40 | "source": [ 41 | "#### Install Libraries" 42 | ], 43 | "metadata": { 44 | "id": "SSq84FlOwK_d" 45 | } 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": null, 50 | "metadata": { 51 | "id": "L9gNlklzvwP5" 52 | }, 53 | "outputs": [], 54 | "source": [ 55 | "# Install additional required libraries\n", 56 | "!pip install japanize-matplotlib -qq\n", 57 | "!pip install torchviz -qq\n", 58 | "!pip install torchinfo -qq" 59 | ] 60 | }, 61 | { 62 | "cell_type": "markdown", 63 | "source": [ 64 | "#### Import Libraries" 65 | ], 66 | "metadata": { 67 | "id": "QvjcvAf8xLQ6" 68 | } 69 | }, 70 | { 71 | "cell_type": "code", 72 | "source": [ 73 | "import warnings\n", 74 | "import numpy as np\n", 75 | "import pandas as pd\n", 76 | "import matplotlib.pyplot as plt\n", 77 | "import japanize_matplotlib\n", 78 | "import torch\n", 79 | "import torch.nn as nn\n", 80 | "import torch.optim as optim\n", 81 | "from torch.utils.data import DataLoader\n", 82 | "from torchvision import datasets, transforms\n", 83 | "from torchinfo import summary\n", 84 | "from tqdm.notebook import tqdm" 85 | ], 86 | "metadata": { 87 | "id": "Ht5HZ0vSv_2K" 88 | }, 89 | "execution_count": null, 90 | "outputs": [] 91 | }, 92 | { 93 | "cell_type": "markdown", 94 | "source": [ 95 | "#### Environment Settings" 96 | ], 97 | "metadata": { 98 | "id": "rS6Q8vz_xjxN" 99 | } 100 | }, 101 | { 102 | "cell_type": "code", 103 | "source": [ 104 | "np.set_printoptions(formatter={'float': '{:0.3f}'.format})\n", 105 | "pd.options.display.float_format = '{:.3f}'.format\n", 106 | "warnings.filterwarnings('ignore')" 107 | ], 108 | "metadata": { 109 | "id": "JDsTIanrxiuT" 110 | }, 111 | "execution_count": null, 112 | "outputs": [] 113 | }, 114 | { 115 | "cell_type": "markdown", 116 | "source": [ 117 | "#### Check for a GPU" 118 | ], 119 | "metadata": { 120 | "id": 
"cfPB0_4CyYas" 121 | } 122 | }, 123 | { 124 | "cell_type": "code", 125 | "source": [ 126 | "# Check for a GPU\n", 127 | "\n", 128 | "# Check whether a GPU is available\n", 129 | "device = torch.device(\"cuda:0\"\n", 130 | "    if torch.cuda.is_available() else \"cpu\")\n", 131 | "\n", 132 | "# Prints \"cuda:0\" when a GPU is available\n", 133 | "print(device)" 134 | ], 135 | "metadata": { 136 | "id": "MEWmXgQPx_Tg" 137 | }, 138 | "execution_count": null, 139 | "outputs": [] 140 | }, 141 | { 142 | "cell_type": "markdown", 143 | "source": [ 144 | "### Data Preparation" 145 | ], 146 | "metadata": { 147 | "id": "Aa7AHIUkVmDe" 148 | } 149 | }, 150 | { 151 | "cell_type": "markdown", 152 | "source": [ 153 | "#### Build Data Loaders" 154 | ], 155 | "metadata": { 156 | "id": "ghXrMUzUy0iK" 157 | } 158 | }, 159 | { 160 | "cell_type": "code", 161 | "source": [ 162 | "# Build data loaders\n", 163 | "def get_data_loaders(batch_size=100, data_dir=\"./data\"):\n", 164 | "    \"\"\"Return the MNIST training and test sets as DataLoaders.\"\"\"\n", 165 | "    transform = transforms.Compose([\n", 166 | "        transforms.ToTensor(),\n", 167 | "        transforms.Normalize((0.1307,), (0.3081,))\n", 168 | "    ])\n", 169 | "\n", 170 | "    train_set = datasets.MNIST(root=data_dir, train=True, download=True, transform=transform)\n", 171 | "    test_set = datasets.MNIST(root=data_dir, train=False, download=True, transform=transform)\n", 172 | "\n", 173 | "    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, pin_memory=True)\n", 174 | "    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, pin_memory=True)\n", 175 | "\n", 176 | "    return train_loader, test_loader\n", 177 | "\n", 178 | "batch_size = 100\n", 179 | "train_loader, test_loader = get_data_loaders(batch_size)" 180 | ], 181 | "metadata": { 182 | "id": "78o7UspryhQv" 183 | }, 184 | "execution_count": null, 185 | "outputs": [] 186 | }, 187 | { 188 | "cell_type": "markdown", 189 | "source": [ 190 | "### Model Construction" 191 | ], 192 | "metadata": { 193 | "id": "6u7rvOBVzGWJ" 194 | } 195 | }, 196 | { 197 | "cell_type": 
"markdown", 198 | "source": [ 199 | "#### Model Definition\n", 200 | "\n" 201 | ], 202 | "metadata": { 203 | "id": "t2qWJkSpaK_p" 204 | } 205 | }, 206 | { 207 | "cell_type": "code", 208 | "source": [ 209 | "class Net(nn.Module):\n", 210 | "    \"\"\"Simple NN with one fully connected hidden layer (MNIST defaults: 28*28 -> 100 -> 10)\"\"\"\n", 211 | "    def __init__(self, n_input=28*28, n_hidden=100, n_output=10):\n", 212 | "        super().__init__()\n", 213 | "        self.net = nn.Sequential(\n", 214 | "            nn.Linear(n_input, n_hidden),\n", 215 | "            nn.ReLU(inplace=True),\n", 216 | "            nn.Linear(n_hidden, n_output)\n", 217 | "        )\n", 218 | "\n", 219 | "    def forward(self, x):\n", 220 | "        # (B, 1, 28, 28) -> (B, 784)\n", 221 | "        x = torch.flatten(x, 1)\n", 222 | "        return self.net(x)" 223 | ], 224 | "metadata": { 225 | "id": "c3jPRrSHaUbV" 226 | }, 227 | "execution_count": null, 228 | "outputs": [] 229 | }, 230 | { 231 | "cell_type": "markdown", 232 | "source": [ 233 | "#### Training Function" 234 | ], 235 | "metadata": { 236 | "id": "a1-SrTEDzkPD" 237 | } 238 | }, 239 | { 240 | "cell_type": "code", 241 | "source": [ 242 | "# Training function (one epoch)\n", 243 | "\n", 244 | "def train_one_epoch(model, loader, criterion, optimizer):\n", 245 | "    model.train()\n", 246 | "    running_loss, correct, total = 0.0, 0, 0\n", 247 | "\n", 248 | "    for inputs, labels in loader:\n", 249 | "        inputs = inputs.to(device)\n", 250 | "        labels = labels.to(device)\n", 251 | "\n", 252 | "        # Reset gradients\n", 253 | "        optimizer.zero_grad()\n", 254 | "        # Forward pass (predictions)\n", 255 | "        outputs = model(inputs)\n", 256 | "        # Compute the loss\n", 257 | "        loss = criterion(outputs, labels)\n", 258 | "        # Compute gradients (derivatives)\n", 259 | "        loss.backward()\n", 260 | "        # Update parameters\n", 261 | "        optimizer.step()\n", 262 | "\n", 263 | "        with torch.no_grad():\n", 264 | "            preds = outputs.argmax(1)\n", 265 | "            running_loss += loss.item() * labels.size(0)\n", 266 | "            correct += (preds == labels).sum().item()\n", 267 | "            total += labels.size(0)\n", 268 | "\n", 269 | "    return running_loss / total, correct / total" 270 | ], 271 | "metadata": { 272 | "id": "IC62CIeKzWHw" 273 
| }, 274 | "execution_count": null, 275 | "outputs": [] 276 | }, 277 | { 278 | "cell_type": "markdown", 279 | "source": [ 280 | "#### 検証用関数" 281 | ], 282 | "metadata": { 283 | "id": "Z1wa7CvWSALc" 284 | } 285 | }, 286 | { 287 | "cell_type": "code", 288 | "source": [ 289 | "# 検証用関数\n", 290 | "\n", 291 | "@torch.no_grad()\n", 292 | "def validate(model, loader, criterion):\n", 293 | " model.eval()\n", 294 | " running_loss, correct, total = 0.0, 0, 0\n", 295 | "\n", 296 | " for inputs, labels in loader:\n", 297 | " inputs = inputs.to(device)\n", 298 | " labels = labels.to(device)\n", 299 | "\n", 300 | " outputs = model(inputs)\n", 301 | " loss = criterion(outputs, labels)\n", 302 | " preds = outputs.argmax(1)\n", 303 | "\n", 304 | " running_loss += loss.item() * labels.size(0)\n", 305 | " correct += (preds == labels).sum().item()\n", 306 | " total += labels.size(0)\n", 307 | "\n", 308 | " return running_loss / total, correct / total" 309 | ], 310 | "metadata": { 311 | "id": "nJDJfVU1zvS-" 312 | }, 313 | "execution_count": null, 314 | "outputs": [] 315 | }, 316 | { 317 | "cell_type": "markdown", 318 | "source": [ 319 | "#### 学習関数" 320 | ], 321 | "metadata": { 322 | "id": "FbNqnE-PUS2K" 323 | } 324 | }, 325 | { 326 | "cell_type": "code", 327 | "source": [ 328 | "def fit(*, net_class=Net, n_hidden=100, num_epochs=20, lr=0.01,\n", 329 | " batch_size=100, optimizer_class=optim.SGD,\n", 330 | " seed=42, data_dir=\"./data\"):\n", 331 | " \"\"\"\n", 332 | " 引数(すべてキーワード専用):\n", 333 | " net_class : モデルクラス(例: Net, CustomCNNなど)\n", 334 | " n_hidden : 隠れ層ノード数(Net使用時のみ)\n", 335 | " num_epochs : 繰り返し数\n", 336 | " lr : 学習率\n", 337 | " batch_size : バッチサイズ\n", 338 | " optimizer_class : 最適化関数クラス (例: optim.SGD, optim.Adam)\n", 339 | "\n", 340 | " 戻り値:\n", 341 | " model, history(np.ndarray: [epoch, train_loss, train_acc, val_loss, val_acc])\n", 342 | " \"\"\"\n", 343 | " torch.manual_seed(seed)\n", 344 | " np.random.seed(seed)\n", 345 | "\n", 346 | " # DataLoader作成(バッチサイズ指定)\n", 347 | " 
train_loader, test_loader = get_data_loaders(batch_size=batch_size,\n", 348 | " data_dir=data_dir)\n", 349 | "\n", 350 | " # モデル構築(引数対応)\n", 351 | " try:\n", 352 | " model = net_class(n_hidden=n_hidden).to(device)\n", 353 | " except TypeError:\n", 354 | " model = net_class().to(device)\n", 355 | "\n", 356 | " criterion = nn.CrossEntropyLoss()\n", 357 | " optimizer = optimizer_class(model.parameters(), lr=lr)\n", 358 | "\n", 359 | " history = []\n", 360 | " for epoch in tqdm(range(1, num_epochs + 1), desc=\"Training\"):\n", 361 | " tr_loss, tr_acc = train_one_epoch(model, train_loader, criterion,\n", 362 | " optimizer)\n", 363 | " va_loss, va_acc = validate(model, test_loader, criterion)\n", 364 | "\n", 365 | " print(\n", 366 | " f\"Epoch [{epoch}/{num_epochs}] \"\n", 367 | " f\"train_loss: {tr_loss:.5f}, train_acc: {tr_acc:.5f}, \"\n", 368 | " f\"val_loss: {va_loss:.5f}, val_acc: {va_acc:.5f}\"\n", 369 | " )\n", 370 | " history.append([epoch, tr_loss, tr_acc, va_loss, va_acc])\n", 371 | "\n", 372 | " return model, np.array(history, dtype=float)" 373 | ], 374 | "metadata": { 375 | "id": "V4gsqUMzUX2M" 376 | }, 377 | "execution_count": null, 378 | "outputs": [] 379 | }, 380 | { 381 | "cell_type": "markdown", 382 | "source": [ 383 | "#### 学習(最適化関数間の比較)" 384 | ], 385 | "metadata": { 386 | "id": "8E69wVDmz38y" 387 | } 388 | }, 389 | { 390 | "cell_type": "code", 391 | "source": [ 392 | "# 学習(最適化関数間の比較)\n", 393 | "from functools import partial\n", 394 | "import torch.optim as optim\n", 395 | "\n", 396 | "# オリジナル(SGD, momentumなし)\n", 397 | "model1, history1 = fit(optimizer_class=optim.SGD, num_epochs=20, lr=0.01)\n", 398 | "\n", 399 | "# Momentum(SGD + momentum=0.9)\n", 400 | "model2, history2 = fit(optimizer_class=partial(optim.SGD, momentum=0.9),\n", 401 | " num_epochs=20, lr=0.01)\n", 402 | "\n", 403 | "# RMSProp\n", 404 | "model3, history3 = fit(optimizer_class=optim.RMSprop, num_epochs=20, lr=0.001)\n", 405 | "\n", 406 | "# Adam\n", 407 | "model4, history4 = 
fit(optimizer_class=optim.Adam, num_epochs=20, lr=0.001)" 408 | ], 409 | "metadata": { 410 | "id": "Vfd2bVOcDu1F" 411 | }, 412 | "execution_count": null, 413 | "outputs": [] 414 | }, 415 | { 416 | "cell_type": "markdown", 417 | "source": [ 418 | "### 結果確認" 419 | ], 420 | "metadata": { 421 | "id": "gxycwR5pKhbF" 422 | } 423 | }, 424 | { 425 | "cell_type": "code", 426 | "source": [ 427 | "import matplotlib.pyplot as plt\n", 428 | "\n", 429 | "def plot_learning_curves_multi(histories, labels=None, title_suffix=\"\"):\n", 430 | " \"\"\"\n", 431 | " 複数のhistoryを黒と青のみで重ねて描画\n", 432 | " テストデータの曲線のみを描画(モデルは色と線種で区別)\n", 433 | " \"\"\"\n", 434 | " plt.figure(figsize=(10, 4), tight_layout=True)\n", 435 | "\n", 436 | " # --- 線スタイルの組み合わせ(黒と青のみ)---\n", 437 | " colors = ['b', 'k', 'b', 'k'] # 4本まで対応\n", 438 | " linestyles = ['-', '--', ':', '-.'] # パターンで区別\n", 439 | "\n", 440 | " # --- 損失曲線 ---\n", 441 | " plt.subplot(1, 2, 1)\n", 442 | " for i, history in enumerate(histories):\n", 443 | " epochs = history[:, 0]\n", 444 | " train_loss = history[:, 1]\n", 445 | " val_loss = history[:, 3]\n", 446 | " color = colors[i % len(colors)]\n", 447 | " ls_val = linestyles[(i + 0) % len(linestyles)] # モデルごとに異なる線種\n", 448 | "\n", 449 | " label_val = f\"{labels[i]}(テスト)\" if labels else f\"model{i+1}(テスト)\"\n", 450 | "\n", 451 | " plt.plot(epochs, val_loss, color=color, linestyle=ls_val, label=label_val)\n", 452 | "\n", 453 | " plt.xlabel('繰り返し回数')\n", 454 | " plt.ylabel('損失')\n", 455 | " plt.xticks(np.arange(0,21,2))\n", 456 | " plt.title(f'学習曲線(損失){title_suffix}')\n", 457 | " plt.legend(fontsize=8)\n", 458 | " plt.grid(True)\n", 459 | "\n", 460 | " # --- 精度曲線 ---\n", 461 | " plt.subplot(1, 2, 2)\n", 462 | " for i, history in enumerate(histories):\n", 463 | " epochs = history[:, 0]\n", 464 | " train_acc = history[:, 2]\n", 465 | " val_acc = history[:, 4]\n", 466 | " color = colors[i % len(colors)]\n", 467 | " ls_val = linestyles[(i + 0) % len(linestyles)]\n", 468 | "\n", 469 | " label_val = 
f\"{labels[i]}(テスト)\" if labels else f\"model{i+1}(テスト)\"\n", 470 | "\n", 471 | " plt.plot(epochs, val_acc, color=color, linestyle=ls_val, label=label_val)\n", 472 | "\n", 473 | " plt.xlabel('繰り返し回数')\n", 474 | " plt.ylabel('精度')\n", 475 | " plt.xticks(np.arange(0,21,2))\n", 476 | " plt.title(f'学習曲線(精度){title_suffix}')\n", 477 | " plt.legend(fontsize=8)\n", 478 | " plt.grid(True)\n", 479 | " plt.show()" 480 | ], 481 | "metadata": { 482 | "id": "VE0ijWPlzzpy" 483 | }, 484 | "execution_count": null, 485 | "outputs": [] 486 | }, 487 | { 488 | "cell_type": "code", 489 | "source": [ 490 | "# 4つのhistoryを比較して描画\n", 491 | "histories = [history1, history2, history3, history4]\n", 492 | "plot_learning_curves_multi(\n", 493 | " histories,\n", 494 | " labels=[\"SGD\", \"Momentum\", \"RMSProp\", \"Adam\"],\n", 495 | " title_suffix=\"最適化関数\"\n", 496 | ")" 497 | ], 498 | "metadata": { 499 | "id": "4hLdyqQaFtFI" 500 | }, 501 | "execution_count": null, 502 | "outputs": [] 503 | }, 504 | { 505 | "cell_type": "markdown", 506 | "source": [ 507 | "### バージョン確認" 508 | ], 509 | "metadata": { 510 | "id": "HgzYQ9vc0vtU" 511 | } 512 | }, 513 | { 514 | "cell_type": "code", 515 | "source": [ 516 | "!pip install watermark -qq\n", 517 | "%load_ext watermark\n", 518 | "%watermark --iversions" 519 | ], 520 | "metadata": { 521 | "id": "VSZrDXEP0oXe" 522 | }, 523 | "execution_count": null, 524 | "outputs": [] 525 | }, 526 | { 527 | "cell_type": "code", 528 | "source": [], 529 | "metadata": { 530 | "id": "GQdgyLIt12RM" 531 | }, 532 | "execution_count": null, 533 | "outputs": [] 534 | } 535 | ] 536 | } -------------------------------------------------------------------------------- /notebooks/07_regression.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "provenance": [], 7 | "toc_visible": true 8 | }, 9 | "kernelspec": { 10 | "name": "python3", 11 | "display_name": "Python 3"
12 | }, 13 | "language_info": { 14 | "name": "python" 15 | } 16 | }, 17 | "cells": [ 18 | { 19 | "cell_type": "markdown", 20 | "source": [ 21 | "### 7章 線形回帰" 22 | ], 23 | "metadata": { 24 | "id": "Ax41Ij4jFBRz" 25 | } 26 | }, 27 | { 28 | "cell_type": "markdown", 29 | "source": [ 30 | "### 環境準備" 31 | ], 32 | "metadata": { 33 | "id": "v8v_1Z_wFZwx" 34 | } 35 | }, 36 | { 37 | "cell_type": "markdown", 38 | "source": [ 39 | "#### ライブラリ導入" 40 | ], 41 | "metadata": { 42 | "id": "3985O09wou7_" 43 | } 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": null, 48 | "metadata": { 49 | "id": "jerrUYY3FAt8" 50 | }, 51 | "outputs": [], 52 | "source": [ 53 | "# ライブラリ導入\n", 54 | "!pip install japanize-matplotlib -q" 55 | ] 56 | }, 57 | { 58 | "cell_type": "markdown", 59 | "source": [ 60 | "#### ライブラリインポート" 61 | ], 62 | "metadata": { 63 | "id": "o2A88whXo2GI" 64 | } 65 | }, 66 | { 67 | "cell_type": "code", 68 | "source": [ 69 | "# ライブラリインポート\n", 70 | "import pandas as pd\n", 71 | "import numpy as np\n", 72 | "import matplotlib.pyplot as plt\n", 73 | "import japanize_matplotlib\n", 74 | "from IPython.display import display\n", 75 | "from sklearn.datasets import fetch_openml\n", 76 | "import warnings" 77 | ], 78 | "metadata": { 79 | "id": "y0-MVONWFHsc" 80 | }, 81 | "execution_count": null, 82 | "outputs": [] 83 | }, 84 | { 85 | "cell_type": "markdown", 86 | "source": [ 87 | "#### 環境設定" 88 | ], 89 | "metadata": { 90 | "id": "w9mfx28Xo_ER" 91 | } 92 | }, 93 | { 94 | "cell_type": "code", 95 | "source": [ 96 | "# 環境設定\n", 97 | "np.set_printoptions(formatter={'float': '{:0.3f}'.format})\n", 98 | "pd.options.display.float_format = '{:.3f}'.format\n", 99 | "warnings.filterwarnings('ignore')" 100 | ], 101 | "metadata": { 102 | "id": "K2IuA25qpCl0" 103 | }, 104 | "execution_count": null, 105 | "outputs": [] 106 | }, 107 | { 108 | "cell_type": "markdown", 109 | "source": [ 110 | "### データ読み込み" 111 | ], 112 | "metadata": { 113 | "id": "fT_aj13PF3V5" 114 | } 115 | }, 116 | { 117 | 
"cell_type": "markdown", 118 | "source": [ 119 | "#### 読み込み関数定義" 120 | ], 121 | "metadata": { 122 | "id": "bWDggqoLF6k9" 123 | } 124 | }, 125 | { 126 | "cell_type": "code", 127 | "source": [ 128 | "# 読み込み関数定義\n", 129 | "\n", 130 | "def load_california_housing():\n", 131 | " \"\"\"California Housing データセットを取得し DataFrame を返す\"\"\"\n", 132 | " try:\n", 133 | " data = fetch_openml(name=\"california\", version=2, as_frame=True)\n", 134 | " features = data.data\n", 135 | " target = data.target.astype(float)\n", 136 | " df = pd.concat([features, target.rename(\"MedianHouseValue\")], axis=1)\n", 137 | "\n", 138 | " print(\"データの読み込みが完了しました。\")\n", 139 | " print(f\"サンプル数: {df.shape[0]:,} 件, 特徴量数: {df.shape[1]-1} 個\")\n", 140 | " return df\n", 141 | " except Exception as e:\n", 142 | " print(\"データ読み込みエラー:\", e)\n", 143 | " return None" 144 | ], 145 | "metadata": { 146 | "id": "PsZevToFFreV" 147 | }, 148 | "execution_count": null, 149 | "outputs": [] 150 | }, 151 | { 152 | "cell_type": "markdown", 153 | "source": [ 154 | "#### 読み込み" 155 | ], 156 | "metadata": { 157 | "id": "boq-ThkjGBqC" 158 | } 159 | }, 160 | { 161 | "cell_type": "code", 162 | "source": [ 163 | "# 読み込み\n", 164 | "df = load_california_housing()" 165 | ], 166 | "metadata": { 167 | "id": "2bcyo0MhF_7A" 168 | }, 169 | "execution_count": null, 170 | "outputs": [] 171 | }, 172 | { 173 | "cell_type": "markdown", 174 | "source": [ 175 | "#### 内容確認" 176 | ], 177 | "metadata": { 178 | "id": "BGanIGYqtB8w" 179 | } 180 | }, 181 | { 182 | "cell_type": "code", 183 | "source": [ 184 | "# 内容確認\n", 185 | "\n", 186 | "# 先頭5行表示\n", 187 | "display(df.head())" 188 | ], 189 | "metadata": { 190 | "id": "Zw6iw7kftGT3" 191 | }, 192 | "execution_count": null, 193 | "outputs": [] 194 | }, 195 | { 196 | "cell_type": "markdown", 197 | "source": [ 198 | "### データ加工" 199 | ], 200 | "metadata": { 201 | "id": "umEh8dGOGp6C" 202 | } 203 | }, 204 | { 205 | "cell_type": "markdown", 206 | "source": [ 207 | "#### 特徴量MedIncの抽出" 208 | ], 209 | 
"metadata": { 210 | "id": "c0h0tYfPuwPq" 211 | } 212 | }, 213 | { 214 | "cell_type": "code", 215 | "source": [ 216 | "# 特徴量MedIncの抽出\n", 217 | "x_data = df[['MedInc']].values" 218 | ], 219 | "metadata": { 220 | "id": "bi-moONJu0yD" 221 | }, 222 | "execution_count": null, 223 | "outputs": [] 224 | }, 225 | { 226 | "cell_type": "markdown", 227 | "source": [ 228 | "#### ダミー変数(定数項)追加" 229 | ], 230 | "metadata": { 231 | "id": "fBNfTZU3vAvI" 232 | } 233 | }, 234 | { 235 | "cell_type": "code", 236 | "source": [ 237 | "# ダミー変数(定数項)追加\n", 238 | "x = np.insert(x_data, 0, 1.0, axis=1)" 239 | ], 240 | "metadata": { 241 | "id": "gSaLyZgCvEZc" 242 | }, 243 | "execution_count": null, 244 | "outputs": [] 245 | }, 246 | { 247 | "cell_type": "markdown", 248 | "source": [ 249 | "#### 目的変数ytの設定" 250 | ], 251 | "metadata": { 252 | "id": "4LiXu3ZUvO41" 253 | } 254 | }, 255 | { 256 | "cell_type": "code", 257 | "source": [ 258 | "# 目的変数ytの設定\n", 259 | "yt = df['MedianHouseValue'].values" 260 | ], 261 | "metadata": { 262 | "id": "Ym199LilvTNd" 263 | }, 264 | "execution_count": null, 265 | "outputs": [] 266 | }, 267 | { 268 | "cell_type": "markdown", 269 | "source": [ 270 | "#### 加工結果確認" 271 | ], 272 | "metadata": { 273 | "id": "lLM0XE7yvml2" 274 | } 275 | }, 276 | { 277 | "cell_type": "code", 278 | "source": [ 279 | "# 加工結果確認\n", 280 | "\n", 281 | "print(\"xとytのshape\")\n", 282 | "print(f\"x.shape: {x.shape}, yt.shape: {yt.shape}\")\n", 283 | "\n", 284 | "print('xの先頭5行')\n", 285 | "print(x[:5])\n", 286 | "\n", 287 | "print('ytの先頭5要素')\n", 288 | "print(yt[:5])" 289 | ], 290 | "metadata": { 291 | "id": "jZYvB9WevqWw" 292 | }, 293 | "execution_count": null, 294 | "outputs": [] 295 | }, 296 | { 297 | "cell_type": "markdown", 298 | "source": [ 299 | "#### 散布図表示" 300 | ], 301 | "metadata": { 302 | "id": "zeyvLEnVHZAU" 303 | } 304 | }, 305 | { 306 | "cell_type": "code", 307 | "source": [ 308 | "# 散布図表示\n", 309 | "# x(収入)とyt(不動産価格)の関係を散布図表示する\n", 310 | "\n", 311 | "plt.figure(figsize=(6,6))\n", 
312 | "plt.scatter(x[:,1], yt, s=0.5, c='blue')\n", 313 | "plt.title('収入 vs 不動産価格の散布図', fontsize=14)\n", 314 | "plt.xlabel('MedInc(収入)', fontsize=13)\n", 315 | "plt.ylabel('MedianHouseValue(不動産価格)', fontsize=13)\n", 316 | "plt.grid(True)\n", 317 | "plt.show()" 318 | ], 319 | "metadata": { 320 | "id": "D2hXYdIkGwVk" 321 | }, 322 | "execution_count": null, 323 | "outputs": [] 324 | }, 325 | { 326 | "cell_type": "markdown", 327 | "source": [ 328 | "### 基本関数定義" 329 | ], 330 | "metadata": { 331 | "id": "h2PC3BxYHlVO" 332 | } 333 | }, 334 | { 335 | "cell_type": "markdown", 336 | "source": [ 337 | "#### 予測関数" 338 | ], 339 | "metadata": { 340 | "id": "URizE3UCHpLR" 341 | } 342 | }, 343 | { 344 | "cell_type": "code", 345 | "source": [ 346 | "# 予測関数\n", 347 | "\n", 348 | "def predict(x, w):\n", 349 | " \"\"\"線形回帰モデルによる予測値を計算\"\"\"\n", 350 | " return x @ w" 351 | ], 352 | "metadata": { 353 | "id": "-NoEaiqKHdK9" 354 | }, 355 | "execution_count": null, 356 | "outputs": [] 357 | }, 358 | { 359 | "cell_type": "markdown", 360 | "source": [ 361 | "#### 損失関数" 362 | ], 363 | "metadata": { 364 | "id": "jt6j5ka_IPQa" 365 | } 366 | }, 367 | { 368 | "cell_type": "code", 369 | "source": [ 370 | "# 損失関数\n", 371 | "\n", 372 | "def compute_loss(yp, yt):\n", 373 | " \"\"\"平均二乗誤差 (MSE) に基づく損失関数\"\"\"\n", 374 | " return np.mean((yp - yt) ** 2) / 2" 375 | ], 376 | "metadata": { 377 | "id": "oqnIEYuvH00S" 378 | }, 379 | "execution_count": null, 380 | "outputs": [] 381 | }, 382 | { 383 | "cell_type": "markdown", 384 | "source": [ 385 | "### 学習" 386 | ], 387 | "metadata": { 388 | "id": "7H5wZ0O-IKjq" 389 | } 390 | }, 391 | { 392 | "cell_type": "markdown", 393 | "source": [ 394 | "#### 学習関数" 395 | ], 396 | "metadata": { 397 | "id": "xjYsaqd0IVwT" 398 | } 399 | }, 400 | { 401 | "cell_type": "code", 402 | "source": [ 403 | "# 学習関数\n", 404 | "def train_linear_regression(x, yt, alpha=0.005, iters=5000, his_unit=100):\n", 405 | " \"\"\"勾配降下法による線形回帰モデルの学習\"\"\"\n", 406 | " # M(データ件数)とD(入力データ要素数)の設定\n", 
407 | " M, D = x.shape\n", 408 | " # 重みベクトル初期化(全要素1を設定)\n", 409 | " w = np.ones(D)\n", 410 | " # 学習過程記録用\n", 411 | " history = np.zeros((0,2))\n", 412 | "\n", 413 | " # 繰り返し処理\n", 414 | " for k in range(iters):\n", 415 | " # 予測計算\n", 416 | " yp = predict(x, w)\n", 417 | " # 誤差計算\n", 418 | " yd = yp - yt\n", 419 | " # 勾配(微分)計算\n", 420 | " grad = (x.T @ yd)/M\n", 421 | " # パラメータ修正\n", 422 | " w -= alpha * grad\n", 423 | "\n", 424 | " if k % his_unit == 0:\n", 425 | " # 損失計算\n", 426 | " loss = compute_loss(yp, yt)\n", 427 | " # 記録用変数に追記\n", 428 | " history = np.vstack((history, np.array([k, loss])))\n", 429 | " # 結果の画面表示\n", 430 | " print(f\"iter={k:5d} | loss={loss:.6f}\")\n", 431 | " return w, history" 432 | ], 433 | "metadata": { 434 | "id": "S_aiRDf-IUsK" 435 | }, 436 | "execution_count": null, 437 | "outputs": [] 438 | }, 439 | { 440 | "cell_type": "markdown", 441 | "source": [ 442 | "#### 学習" 443 | ], 444 | "metadata": { 445 | "id": "zswqLhrZIfkr" 446 | } 447 | }, 448 | { 449 | "cell_type": "code", 450 | "source": [ 451 | "# 学習\n", 452 | "\n", 453 | "# 学習率と繰り返し回数の設定\n", 454 | "alpha = 0.005\n", 455 | "iters = 5000\n", 456 | "his_unit = 100\n", 457 | "\n", 458 | "# 繰り返し処理\n", 459 | "w, history = train_linear_regression(x, yt, alpha=alpha, \\\n", 460 | " iters=iters, his_unit=his_unit)" 461 | ], 462 | "metadata": { 463 | "id": "4LXWiMjkIceP" 464 | }, 465 | "execution_count": null, 466 | "outputs": [] 467 | }, 468 | { 469 | "cell_type": "markdown", 470 | "source": [ 471 | "### 結果分析" 472 | ], 473 | "metadata": { 474 | "id": "1S-oHUrHI6We" 475 | } 476 | }, 477 | { 478 | "cell_type": "markdown", 479 | "source": [ 480 | "#### 損失の確認" 481 | ], 482 | "metadata": { 483 | "id": "gh817X_zI9Pq" 484 | } 485 | }, 486 | { 487 | "cell_type": "code", 488 | "source": [ 489 | "# 損失の確認\n", 490 | "print(f\"損失初期値: {history[0,1]:.06f}\")\n", 491 | "print(f\"損失最終値: {history[-1,1]:.06f}\")" 492 | ], 493 | "metadata": { 494 | "id": "CHaiOhW-Ik7g" 495 | }, 496 | "execution_count": null, 497 
| "outputs": [] 498 | }, 499 | { 500 | "cell_type": "markdown", 501 | "source": [ 502 | "#### 学習曲線(損失)" 503 | ], 504 | "metadata": { 505 | "id": "eFAMfXO5JS9v" 506 | } 507 | }, 508 | { 509 | "cell_type": "code", 510 | "source": [ 511 | "# 学習曲線(損失)\n", 512 | "plt.figure(figsize=(6,6))\n", 513 | "plt.plot(history[1:,0], history[1:,1])\n", 514 | "plt.title('学習曲線(損失)', fontsize=14)\n", 515 | "plt.grid()\n", 516 | "plt.xlabel('繰り返し回数', fontsize=13)\n", 517 | "plt.ylabel('損失', fontsize=13)\n", 518 | "plt.show()" 519 | ], 520 | "metadata": { 521 | "id": "SC4uuiroJQKP" 522 | }, 523 | "execution_count": null, 524 | "outputs": [] 525 | }, 526 | { 527 | "cell_type": "markdown", 528 | "source": [ 529 | "#### 回帰直線の可視化" 530 | ], 531 | "metadata": { 532 | "id": "LiAdKEWcJzN-" 533 | } 534 | }, 535 | { 536 | "cell_type": "code", 537 | "source": [ 538 | "# 回帰直線の可視化\n", 539 | "\n", 540 | "# 回帰直線の座標計算\n", 541 | "xall = x[:,1].ravel()\n", 542 | "xl = np.array([[1, xall.min()], [1, xall.max()]])\n", 543 | "yl = predict(xl, w)\n", 544 | "\n", 545 | "# グラフ描画\n", 546 | "plt.figure(figsize=(6,6))\n", 547 | "plt.scatter(x[:,1], yt, s=0.1, c='b', label='観測データ')\n", 548 | "plt.plot(xl[:,1], yl, c='k', lw=2, label='学習後の回帰直線')\n", 549 | "plt.title('散布図と回帰直線', fontsize=14)\n", 550 | "plt.xlabel('収入', fontsize=14)\n", 551 | "plt.ylabel('不動産価格', fontsize=14)\n", 552 | "plt.legend()\n", 553 | "plt.grid()\n", 554 | "plt.show()" 555 | ], 556 | "metadata": { 557 | "id": "yr3otJj0JfuE" 558 | }, 559 | "execution_count": null, 560 | "outputs": [] 561 | }, 562 | { 563 | "cell_type": "markdown", 564 | "source": [ 565 | "### 重回帰" 566 | ], 567 | "metadata": { 568 | "id": "-SpOocCpKivY" 569 | } 570 | }, 571 | { 572 | "cell_type": "markdown", 573 | "source": [ 574 | "#### データ加工(重回帰)" 575 | ], 576 | "metadata": { 577 | "id": "xOc8xIboNMKE" 578 | } 579 | }, 580 | { 581 | "cell_type": "code", 582 | "source": [ 583 | "# データ加工(重回帰)\n", 584 | "\n", 585 | "# 特徴量AveRoomsの抽出\n", 586 | "x_add = 
df[['AveRooms']].values\n", 587 | "\n", 588 | "# 特徴量AveRoomsの説明変数xへの追加\n", 589 | "x2 = np.hstack((x, x_add))\n", 590 | "\n", 591 | "# 結果確認\n", 592 | "print(\"x2のshape\")\n", 593 | "print(f\"x2.shape: {x2.shape}\")\n", 594 | "print('x2の先頭5行')\n", 595 | "print(x2[:5])" 596 | ], 597 | "metadata": { 598 | "id": "6_zXLPse1zb2" 599 | }, 600 | "execution_count": null, 601 | "outputs": [] 602 | }, 603 | { 604 | "cell_type": "markdown", 605 | "source": [ 606 | "#### 学習(重回帰)" 607 | ], 608 | "metadata": { 609 | "id": "yLdJibTSLBve" 610 | } 611 | }, 612 | { 613 | "cell_type": "code", 614 | "source": [ 615 | "# 学習(重回帰)\n", 616 | "\n", 617 | "# 学習率と繰り返し回数の設定\n", 618 | "alpha = 0.005\n", 619 | "iters = 5000\n", 620 | "his_unit = 100\n", 621 | "\n", 622 | "# 繰り返し処理\n", 623 | "w2, history2 = train_linear_regression(x2, yt, alpha=alpha, \\\n", 624 | " iters=iters, his_unit=his_unit)" 625 | ], 626 | "metadata": { 627 | "id": "sk18uh1oKtkG" 628 | }, 629 | "execution_count": null, 630 | "outputs": [] 631 | }, 632 | { 633 | "cell_type": "markdown", 634 | "source": [ 635 | "#### 結果分析(重回帰)" 636 | ], 637 | "metadata": { 638 | "id": "HUnuqX5WLMGS" 639 | } 640 | }, 641 | { 642 | "cell_type": "code", 643 | "source": [ 644 | "# 結果分析(重回帰)\n", 645 | "\n", 646 | "# 損失の確認\n", 647 | "print(f\"損失 初期値 : {history2[0,1]:.06f}\")\n", 648 | "print(f\"損失 最終値 : {history2[-1,1]:.06f}\")\n", 649 | "\n", 650 | "# 学習曲線(損失 重回帰モデル)\n", 651 | "plt.figure(figsize=(6,6))\n", 652 | "plt.plot(history2[1:,0], history2[1:,1], color='blue')\n", 653 | "plt.title('学習曲線(損失 重回帰モデル)', fontsize=14)\n", 654 | "plt.xlabel('繰り返し回数', fontsize=13)\n", 655 | "plt.ylabel('損失', fontsize=13)\n", 656 | "plt.grid(True)\n", 657 | "plt.show()" 658 | ], 659 | "metadata": { 660 | "id": "5MJTE9HQLIjh" 661 | }, 662 | "execution_count": null, 663 | "outputs": [] 664 | }, 665 | { 666 | "cell_type": "markdown", 667 | "source": [ 668 | "### バージョン確認" 669 | ], 670 | "metadata": { 671 | "id": "hSkYuskK3e9f" 672 | } 673 | }, 674 | { 675 | "cell_type":
"code", 676 | "source": [ 677 | "!pip install watermark -qq\n", 678 | "%load_ext watermark\n", 679 | "%watermark --iversions" 680 | ], 681 | "metadata": { 682 | "id": "LkOSsgSZLT8N" 683 | }, 684 | "execution_count": null, 685 | "outputs": [] 686 | }, 687 | { 688 | "cell_type": "code", 689 | "source": [], 690 | "metadata": { 691 | "id": "3BR-_dbR3s14" 692 | }, 693 | "execution_count": null, 694 | "outputs": [] 695 | } 696 | ] 697 | } -------------------------------------------------------------------------------- /notebooks/11_07_cnn.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "provenance": [], 7 | "toc_visible": true, 8 | "gpuType": "T4" 9 | }, 10 | "kernelspec": { 11 | "name": "python3", 12 | "display_name": "Python 3" 13 | }, 14 | "language_info": { 15 | "name": "python" 16 | }, 17 | "accelerator": "GPU" 18 | }, 19 | "cells": [ 20 | { 21 | "cell_type": "markdown", 22 | "source": [ 23 | "### 11.7 CNN" 24 | ], 25 | "metadata": { 26 | "id": "LM12uBlAv7Nh" 27 | } 28 | }, 29 | { 30 | "cell_type": "markdown", 31 | "source": [ 32 | "### ライブラリ・環境設定" 33 | ], 34 | "metadata": { 35 | "id": "jylsKAu3xsAR" 36 | } 37 | }, 38 | { 39 | "cell_type": "markdown", 40 | "source": [ 41 | "#### ライブラリ導入" 42 | ], 43 | "metadata": { 44 | "id": "SSq84FlOwK_d" 45 | } 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": null, 50 | "metadata": { 51 | "id": "L9gNlklzvwP5" 52 | }, 53 | "outputs": [], 54 | "source": [ 55 | "# 必要ライブラリ追加導入\n", 56 | "!pip install japanize-matplotlib -qq\n", 57 | "!pip install torchviz -qq\n", 58 | "!pip install torchinfo -qq" 59 | ] 60 | }, 61 | { 62 | "cell_type": "markdown", 63 | "source": [ 64 | "#### ライブラリインポート" 65 | ], 66 | "metadata": { 67 | "id": "QvjcvAf8xLQ6" 68 | } 69 | }, 70 | { 71 | "cell_type": "code", 72 | "source": [ 73 | "import warnings\n", 74 | "import numpy as np\n", 75 | "import pandas as pd\n", 76 | 
"import matplotlib.pyplot as plt\n", 77 | "import japanize_matplotlib\n", 78 | "import torch\n", 79 | "import torch.nn as nn\n", 80 | "import torch.optim as optim\n", 81 | "from torch.utils.data import DataLoader\n", 82 | "from torchvision import datasets, transforms\n", 83 | "from torchinfo import summary\n", 84 | "from tqdm.notebook import tqdm" 85 | ], 86 | "metadata": { 87 | "id": "Ht5HZ0vSv_2K" 88 | }, 89 | "execution_count": null, 90 | "outputs": [] 91 | }, 92 | { 93 | "cell_type": "markdown", 94 | "source": [ 95 | "#### 環境設定" 96 | ], 97 | "metadata": { 98 | "id": "rS6Q8vz_xjxN" 99 | } 100 | }, 101 | { 102 | "cell_type": "code", 103 | "source": [ 104 | "np.set_printoptions(formatter={'float': '{:0.3f}'.format})\n", 105 | "pd.options.display.float_format = '{:.3f}'.format\n", 106 | "warnings.filterwarnings('ignore')" 107 | ], 108 | "metadata": { 109 | "id": "JDsTIanrxiuT" 110 | }, 111 | "execution_count": null, 112 | "outputs": [] 113 | }, 114 | { 115 | "cell_type": "markdown", 116 | "source": [ 117 | "#### GPU存在チェック" 118 | ], 119 | "metadata": { 120 | "id": "cfPB0_4CyYas" 121 | } 122 | }, 123 | { 124 | "cell_type": "code", 125 | "source": [ 126 | "# GPU存在チェック\n", 127 | "\n", 128 | "# GPUが利用可能かどうかのチェック\n", 129 | "device = torch.device(\"cuda:0\" \\\n", 130 | "if torch.cuda.is_available() else \"cpu\")\n", 131 | "\n", 132 | "# 利用可能な場合は\"cuda:0\"が出力される\n", 133 | "print(device)" 134 | ], 135 | "metadata": { 136 | "id": "MEWmXgQPx_Tg" 137 | }, 138 | "execution_count": null, 139 | "outputs": [] 140 | }, 141 | { 142 | "cell_type": "markdown", 143 | "source": [ 144 | "### データ準備" 145 | ], 146 | "metadata": { 147 | "id": "Aa7AHIUkVmDe" 148 | } 149 | }, 150 | { 151 | "cell_type": "markdown", 152 | "source": [ 153 | "#### データローダー構築" 154 | ], 155 | "metadata": { 156 | "id": "ghXrMUzUy0iK" 157 | } 158 | }, 159 | { 160 | "cell_type": "code", 161 | "source": [ 162 | "# データローダー構築\n", 163 | "def get_data_loaders(batch_size=100, data_dir=\"./data\"):\n", 164 | " 
\"\"\"MNISTの訓練・テストデータをDataLoaderで返す\"\"\"\n", 165 | " transform = transforms.Compose([\n", 166 | " transforms.ToTensor(),\n", 167 | " transforms.Normalize((0.1307,), (0.3081,))\n", 168 | " ])\n", 169 | "\n", 170 | " train_set = datasets.MNIST(root=data_dir, train=True, download=True, transform=transform)\n", 171 | " test_set = datasets.MNIST(root=data_dir, train=False, download=True, transform=transform)\n", 172 | "\n", 173 | " train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, pin_memory=True)\n", 174 | " test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, pin_memory=True)\n", 175 | "\n", 176 | " return train_loader, test_loader\n", 177 | "\n", 178 | "batch_size = 100\n", 179 | "train_loader, test_loader = get_data_loaders(batch_size)" 180 | ], 181 | "metadata": { 182 | "id": "78o7UspryhQv" 183 | }, 184 | "execution_count": null, 185 | "outputs": [] 186 | }, 187 | { 188 | "cell_type": "markdown", 189 | "source": [ 190 | "### モデル構築" 191 | ], 192 | "metadata": { 193 | "id": "6u7rvOBVzGWJ" 194 | } 195 | }, 196 | { 197 | "cell_type": "markdown", 198 | "source": [ 199 | "#### モデル定義" 200 | ], 201 | "metadata": { 202 | "id": "qP-2o1IeV0_e" 203 | } 204 | }, 205 | { 206 | "cell_type": "markdown", 207 | "source": [ 208 | "#### 全結合型モデル定義" 209 | ], 210 | "metadata": { 211 | "id": "8PVH74TOWXsb" 212 | } 213 | }, 214 | { 215 | "cell_type": "code", 216 | "source": [ 217 | "# 全結合型モデル定義\n", 218 | "\n", 219 | "class Net(nn.Module):\n", 220 | " \"\"\"全結合1層のシンプルなNN(MNIST想定: 28*28 -> 100 -> 10)\"\"\"\n", 221 | " def __init__(self, n_input=28*28, n_hidden=100, n_output=10):\n", 222 | " super().__init__()\n", 223 | " self.net = nn.Sequential(\n", 224 | " nn.Linear(n_input, n_hidden),\n", 225 | " nn.ReLU(inplace=True),\n", 226 | " nn.Linear(n_hidden, n_output)\n", 227 | " )\n", 228 | "\n", 229 | " def forward(self, x):\n", 230 | " # (B, 1, 28, 28) -> (B, 784)\n", 231 | " x = torch.flatten(x, 1)\n", 232 | " return self.net(x)" 233 | ], 234 | 
"metadata": { 235 | "id": "0YPkV7sRy8zu" 236 | }, 237 | "execution_count": null, 238 | "outputs": [] 239 | }, 240 | { 241 | "cell_type": "markdown", 242 | "source": [ 243 | "#### CNNモデル定義" 244 | ], 245 | "metadata": { 246 | "id": "oOIgds5lWhXf" 247 | } 248 | }, 249 | { 250 | "cell_type": "code", 251 | "source": [ 252 | "# CNNモデル定義\n", 253 | "class CNN(nn.Module):\n", 254 | " def __init__(self):\n", 255 | " super().__init__()\n", 256 | " self.features = nn.Sequential(\n", 257 | " nn.Conv2d(1, 32, kernel_size=3, padding=1),\n", 258 | " nn.ReLU(inplace=True),\n", 259 | " nn.MaxPool2d(2),\n", 260 | " nn.Conv2d(32, 64, kernel_size=3, padding=1),\n", 261 | " nn.ReLU(inplace=True),\n", 262 | " nn.MaxPool2d(2),\n", 263 | " )\n", 264 | " self.classifier = nn.Linear(64 * 7 * 7, 10)\n", 265 | "\n", 266 | " def forward(self, x):\n", 267 | " # 入力は (B, 1, 28, 28) を想定(ToTensor()でOK)\n", 268 | " x = self.features(x)\n", 269 | " x = x.view(x.size(0), -1) # Flatten\n", 270 | " x = self.classifier(x)\n", 271 | " return x # CrossEntropyLoss用にlogitsを返す" 272 | ], 273 | "metadata": { 274 | "id": "wD4Qm1lVWkfN" 275 | }, 276 | "execution_count": null, 277 | "outputs": [] 278 | }, 279 | { 280 | "cell_type": "markdown", 281 | "source": [ 282 | "#### 訓練用関数" 283 | ], 284 | "metadata": { 285 | "id": "a1-SrTEDzkPD" 286 | } 287 | }, 288 | { 289 | "cell_type": "code", 290 | "source": [ 291 | "# 学習用関数\n", 292 | "\n", 293 | "def train_one_epoch(model, loader, criterion, optimizer):\n", 294 | " model.train()\n", 295 | " running_loss, correct, total = 0.0, 0, 0\n", 296 | "\n", 297 | " for inputs, labels in loader:\n", 298 | " inputs = inputs.to(device)\n", 299 | " labels = labels.to(device)\n", 300 | "\n", 301 | " # 勾配初期化\n", 302 | " optimizer.zero_grad()\n", 303 | " # 予測計算\n", 304 | " outputs = model(inputs)\n", 305 | " # 損失計算\n", 306 | " loss = criterion(outputs, labels)\n", 307 | " # 勾配(微分)計算\n", 308 | " loss.backward()\n", 309 | " # パラメータ更新\n", 310 | " optimizer.step()\n", 311 | "\n", 312 | " with 
torch.no_grad():\n", 313 | " preds = outputs.argmax(1)\n", 314 | " running_loss += loss.item() * labels.size(0)\n", 315 | " correct += (preds == labels).sum().item()\n", 316 | " total += labels.size(0)\n", 317 | "\n", 318 | " return running_loss / total, correct / total" 319 | ], 320 | "metadata": { 321 | "id": "IC62CIeKzWHw" 322 | }, 323 | "execution_count": null, 324 | "outputs": [] 325 | }, 326 | { 327 | "cell_type": "markdown", 328 | "source": [ 329 | "#### 検証用関数" 330 | ], 331 | "metadata": { 332 | "id": "Z1wa7CvWSALc" 333 | } 334 | }, 335 | { 336 | "cell_type": "code", 337 | "source": [ 338 | "# 検証用関数\n", 339 | "\n", 340 | "@torch.no_grad()\n", 341 | "def validate(model, loader, criterion):\n", 342 | " model.eval()\n", 343 | " running_loss, correct, total = 0.0, 0, 0\n", 344 | "\n", 345 | " for inputs, labels in loader:\n", 346 | " inputs = inputs.to(device)\n", 347 | " labels = labels.to(device)\n", 348 | "\n", 349 | " outputs = model(inputs)\n", 350 | " loss = criterion(outputs, labels)\n", 351 | " preds = outputs.argmax(1)\n", 352 | "\n", 353 | " running_loss += loss.item() * labels.size(0)\n", 354 | " correct += (preds == labels).sum().item()\n", 355 | " total += labels.size(0)\n", 356 | "\n", 357 | " return running_loss / total, correct / total" 358 | ], 359 | "metadata": { 360 | "id": "nJDJfVU1zvS-" 361 | }, 362 | "execution_count": null, 363 | "outputs": [] 364 | }, 365 | { 366 | "cell_type": "markdown", 367 | "source": [ 368 | "#### 学習関数" 369 | ], 370 | "metadata": { 371 | "id": "FbNqnE-PUS2K" 372 | } 373 | }, 374 | { 375 | "cell_type": "code", 376 | "source": [ 377 | "def fit(*, net_class=Net, n_hidden=100, num_epochs=20, lr=0.01,\n", 378 | " batch_size=100, optimizer_class=optim.SGD,\n", 379 | " seed=42, data_dir=\"./data\"):\n", 380 | " \"\"\"\n", 381 | " 引数(すべてキーワード専用):\n", 382 | " net_class : モデルクラス(例: Net, CustomCNNなど)\n", 383 | " n_hidden : 隠れ層ノード数(Net使用時のみ)\n", 384 | " num_epochs : 繰り返し数\n", 385 | " lr : 学習率\n", 386 | " batch_size : バッチサイズ\n", 
387 | " optimizer_class : 最適化関数クラス (例: optim.SGD, optim.Adam)\n", 388 | "\n", 389 | " 戻り値:\n", 390 | " model, history(np.ndarray: [epoch, train_loss, train_acc, val_loss, val_acc])\n", 391 | " \"\"\"\n", 392 | " torch.manual_seed(seed)\n", 393 | " np.random.seed(seed)\n", 394 | "\n", 395 | " # DataLoader作成(バッチサイズ指定)\n", 396 | " train_loader, test_loader = get_data_loaders(batch_size=batch_size,\n", 397 | " data_dir=data_dir)\n", 398 | "\n", 399 | " # モデル構築(引数対応)\n", 400 | " try:\n", 401 | " model = net_class(n_hidden=n_hidden).to(device)\n", 402 | " except TypeError:\n", 403 | " model = net_class().to(device)\n", 404 | "\n", 405 | " criterion = nn.CrossEntropyLoss()\n", 406 | " optimizer = optimizer_class(model.parameters(), lr=lr)\n", 407 | "\n", 408 | " history = []\n", 409 | " for epoch in tqdm(range(1, num_epochs + 1), desc=\"Training\"):\n", 410 | " tr_loss, tr_acc = train_one_epoch(model, train_loader, criterion,\n", 411 | " optimizer)\n", 412 | " va_loss, va_acc = validate(model, test_loader, criterion)\n", 413 | "\n", 414 | " print(\n", 415 | " f\"Epoch [{epoch}/{num_epochs}] \"\n", 416 | " f\"train_loss: {tr_loss:.5f}, train_acc: {tr_acc:.5f}, \"\n", 417 | " f\"val_loss: {va_loss:.5f}, val_acc: {va_acc:.5f}\"\n", 418 | " )\n", 419 | " history.append([epoch, tr_loss, tr_acc, va_loss, va_acc])\n", 420 | "\n", 421 | " return model, np.array(history, dtype=float)" 422 | ], 423 | "metadata": { 424 | "id": "V4gsqUMzUX2M" 425 | }, 426 | "execution_count": null, 427 | "outputs": [] 428 | }, 429 | { 430 | "cell_type": "markdown", 431 | "source": [ 432 | "#### 学習" 433 | ], 434 | "metadata": { 435 | "id": "8E69wVDmz38y" 436 | } 437 | }, 438 | { 439 | "cell_type": "code", 440 | "source": [ 441 | "model, history = fit()\n" 442 | ], 443 | "metadata": { 444 | "id": "pYJL083SQmQZ" 445 | }, 446 | "execution_count": null, 447 | "outputs": [] 448 | }, 449 | { 450 | "cell_type": "code", 451 | "source": [ 452 | "model_cnn, history_cnn = fit(net_class=CNN)\n" 453 | ], 454 | 
"metadata": { 455 | "id": "Vfd2bVOcDu1F" 456 | }, 457 | "execution_count": null, 458 | "outputs": [] 459 | }, 460 | { 461 | "cell_type": "markdown", 462 | "source": [ 463 | "### Results" 464 | ], 465 | "metadata": { 466 | "id": "gxycwR5pKhbF" 467 | } 468 | }, 469 | { 470 | "cell_type": "markdown", 471 | "source": [ 472 | "#### Plotting function" 473 | ], 474 | "metadata": { 475 | "id": "orXMGV7BUbAd" 476 | } 477 | }, 478 | { 479 | "cell_type": "code", 480 | "source": [ 481 | "# Plotting function\n", 482 | "\n", 483 | "import matplotlib.pyplot as plt\n", 484 | "\n", 485 | "def plot_learning_curves_multi(histories, labels=None, title_suffix=\"\"):\n", 486 | " \"\"\"\n", 487 | " Overlay the test-data curves of several histories (blue and black only)\n", 488 | " Training curves are not drawn; each model is told apart by color\n", 489 | " \"\"\"\n", 490 | " plt.figure(figsize=(10, 4), tight_layout=True)\n", 491 | "\n", 492 | " # --- Line styles (blue and black only) ---\n", 493 | " colors = ['b', 'k'] # distinguishes up to 2 models\n", 494 | " linestyles = ['-', '-'] # all solid: models differ by color only\n", 495 | "\n", 496 | " # --- Loss curves ---\n", 497 | " plt.subplot(1, 2, 1)\n", 498 | " for i, history in enumerate(histories):\n", 499 | " epochs = history[:, 0]\n", 500 | " train_loss = history[:, 1]\n", 501 | " val_loss = history[:, 3]\n", 502 | " color = colors[i % len(colors)]\n", 503 | " ls_val = linestyles[i % len(linestyles)] # line style for model i\n", 504 | "\n", 505 | " label_val = f\"{labels[i]} (test)\" if labels else f\"model{i+1} (test)\"\n", 506 | "\n", 507 | " plt.plot(epochs, val_loss, color=color, linestyle=ls_val, label=label_val)\n", 508 | "\n", 509 | " plt.xlabel('Epoch')\n", 510 | " plt.ylabel('Loss')\n", 511 | " plt.xticks(np.arange(0,21,2))\n", 512 | " plt.title(f'Learning curve (loss) {title_suffix}')\n", 513 | " plt.legend(fontsize=8)\n", 514 | " plt.grid(True)\n", 515 | "\n", 516 | " # --- Accuracy curves ---\n", 517 | " plt.subplot(1, 2, 2)\n", 518 | " for i, history in enumerate(histories):\n", 519 | " epochs = history[:, 0]\n", 520 | " train_acc = history[:, 2]\n", 521 | " val_acc = history[:, 4]\n", 522 | " color = colors[i % len(colors)]\n", 523 | " 
ls_val = linestyles[i % len(linestyles)]\n", 524 | "\n", 525 | " label_val = f\"{labels[i]} (test)\" if labels else f\"model{i+1} (test)\"\n", 526 | "\n", 527 | " plt.plot(epochs, val_acc, color=color, linestyle=ls_val, label=label_val)\n", 528 | "\n", 529 | " plt.xlabel('Epoch')\n", 530 | " plt.ylabel('Accuracy')\n", 531 | " plt.xticks(np.arange(0,21,2))\n", 532 | " plt.title(f'Learning curve (accuracy) {title_suffix}')\n", 533 | " plt.ylim(0.9,1.0)\n", 534 | " plt.legend(fontsize=8)\n", 535 | " plt.grid(True)\n", 536 | " plt.show()" 537 | ], 538 | "metadata": { 539 | "id": "VE0ijWPlzzpy" 540 | }, 541 | "execution_count": null, 542 | "outputs": [] 543 | }, 544 | { 545 | "cell_type": "markdown", 546 | "source": [ 547 | "#### Plotting" 548 | ], 549 | "metadata": { 550 | "id": "t1g4b_0hUfw2" 551 | } 552 | }, 553 | { 554 | "cell_type": "code", 555 | "source": [ 556 | "# Plot the learning curves\n", 557 | "plot_learning_curves_multi(\n", 558 | " [history, history_cnn],\n", 559 | " labels=[\"Net\", \"CNN\"],\n", 560 | " title_suffix=\"Net vs CNN\"\n", 561 | ")" 562 | ], 563 | "metadata": { 564 | "id": "4hLdyqQaFtFI" 565 | }, 566 | "execution_count": null, 567 | "outputs": [] 568 | }, 569 | { 570 | "cell_type": "markdown", 571 | "source": [ 572 | "### Version check" 573 | ], 574 | "metadata": { 575 | "id": "HgzYQ9vc0vtU" 576 | } 577 | }, 578 | { 579 | "cell_type": "code", 580 | "source": [ 581 | "!pip install watermark -qq\n", 582 | "%load_ext watermark\n", 583 | "%watermark --iversions" 584 | ], 585 | "metadata": { 586 | "id": "VSZrDXEP0oXe" 587 | }, 588 | "execution_count": null, 589 | "outputs": [] 590 | }, 591 | { 592 | "cell_type": "code", 593 | "source": [], 594 | "metadata": { 595 | "id": "GQdgyLIt12RM" 596 | }, 597 | "execution_count": null, 598 | "outputs": [] 599 | } 600 | ] 601 | } -------------------------------------------------------------------------------- /notebooks/11_05_overfitting_prev.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 
| "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "provenance": [], 7 | "toc_visible": true, 8 | "gpuType": "T4" 9 | }, 10 | "kernelspec": { 11 | "name": "python3", 12 | "display_name": "Python 3" 13 | }, 14 | "language_info": { 15 | "name": "python" 16 | }, 17 | "accelerator": "GPU" 18 | }, 19 | "cells": [ 20 | { 21 | "cell_type": "markdown", 22 | "source": [ 23 | "### 11.5  過学習対策" 24 | ], 25 | "metadata": { 26 | "id": "LM12uBlAv7Nh" 27 | } 28 | }, 29 | { 30 | "cell_type": "markdown", 31 | "source": [ 32 | "### ライブラリ・環境設定" 33 | ], 34 | "metadata": { 35 | "id": "jylsKAu3xsAR" 36 | } 37 | }, 38 | { 39 | "cell_type": "markdown", 40 | "source": [ 41 | "#### ライブラリ導入" 42 | ], 43 | "metadata": { 44 | "id": "SSq84FlOwK_d" 45 | } 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": null, 50 | "metadata": { 51 | "id": "L9gNlklzvwP5" 52 | }, 53 | "outputs": [], 54 | "source": [ 55 | "# 必要ライブラリ追加導入\n", 56 | "!pip install japanize-matplotlib -qq\n", 57 | "!pip install torchviz -qq\n", 58 | "!pip install torchinfo -qq" 59 | ] 60 | }, 61 | { 62 | "cell_type": "markdown", 63 | "source": [ 64 | "#### ライブラリインポート" 65 | ], 66 | "metadata": { 67 | "id": "QvjcvAf8xLQ6" 68 | } 69 | }, 70 | { 71 | "cell_type": "code", 72 | "source": [ 73 | "import warnings\n", 74 | "import numpy as np\n", 75 | "import pandas as pd\n", 76 | "import matplotlib.pyplot as plt\n", 77 | "import japanize_matplotlib\n", 78 | "import torch\n", 79 | "import torch.nn as nn\n", 80 | "import torch.optim as optim\n", 81 | "from torch.utils.data import DataLoader\n", 82 | "from torchvision import datasets, transforms\n", 83 | "from torchinfo import summary\n", 84 | "from tqdm.notebook import tqdm" 85 | ], 86 | "metadata": { 87 | "id": "Ht5HZ0vSv_2K" 88 | }, 89 | "execution_count": null, 90 | "outputs": [] 91 | }, 92 | { 93 | "cell_type": "markdown", 94 | "source": [ 95 | "#### 環境設定" 96 | ], 97 | "metadata": { 98 | "id": "rS6Q8vz_xjxN" 99 | } 100 | }, 101 | { 102 | "cell_type": "code", 103 | 
"source": [ 104 | "np.set_printoptions(formatter={'float': '{:0.3f}'.format})\n", 105 | "pd.options.display.float_format = '{:.3f}'.format\n", 106 | "warnings.filterwarnings('ignore')" 107 | ], 108 | "metadata": { 109 | "id": "JDsTIanrxiuT" 110 | }, 111 | "execution_count": null, 112 | "outputs": [] 113 | }, 114 | { 115 | "cell_type": "markdown", 116 | "source": [ 117 | "#### GPU存在チェック" 118 | ], 119 | "metadata": { 120 | "id": "cfPB0_4CyYas" 121 | } 122 | }, 123 | { 124 | "cell_type": "code", 125 | "source": [ 126 | "# GPU存在チェック\n", 127 | "\n", 128 | "# GPUが利用可能かどうかのチェック\n", 129 | "device = torch.device(\"cuda:0\" \\\n", 130 | "if torch.cuda.is_available() else \"cpu\")\n", 131 | "\n", 132 | "# 利用可能な場合は\"cuda:0\"が出力される\n", 133 | "print(device)" 134 | ], 135 | "metadata": { 136 | "id": "MEWmXgQPx_Tg" 137 | }, 138 | "execution_count": null, 139 | "outputs": [] 140 | }, 141 | { 142 | "cell_type": "markdown", 143 | "source": [ 144 | "### データ準備" 145 | ], 146 | "metadata": { 147 | "id": "Aa7AHIUkVmDe" 148 | } 149 | }, 150 | { 151 | "cell_type": "markdown", 152 | "source": [ 153 | "#### データローダー構築" 154 | ], 155 | "metadata": { 156 | "id": "ghXrMUzUy0iK" 157 | } 158 | }, 159 | { 160 | "cell_type": "code", 161 | "source": [ 162 | "# データローダー構築\n", 163 | "def get_data_loaders(batch_size=100, data_dir=\"./data\"):\n", 164 | " \"\"\"MNISTの訓練・テストデータをDataLoaderで返す\"\"\"\n", 165 | " transform = transforms.Compose([\n", 166 | " transforms.ToTensor(),\n", 167 | " transforms.Normalize((0.1307,), (0.3081,))\n", 168 | " ])\n", 169 | "\n", 170 | " train_set = datasets.MNIST(root=data_dir, train=True, download=True, transform=transform)\n", 171 | " test_set = datasets.MNIST(root=data_dir, train=False, download=True, transform=transform)\n", 172 | "\n", 173 | " train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, pin_memory=True)\n", 174 | " test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, pin_memory=True)\n", 175 | "\n", 176 | " return 
train_loader, test_loader\n", 177 | "\n", 178 | "batch_size = 100\n", 179 | "train_loader, test_loader = get_data_loaders(batch_size)" 180 | ], 181 | "metadata": { 182 | "id": "78o7UspryhQv" 183 | }, 184 | "execution_count": null, 185 | "outputs": [] 186 | }, 187 | { 188 | "cell_type": "markdown", 189 | "source": [ 190 | "### モデル構築" 191 | ], 192 | "metadata": { 193 | "id": "6u7rvOBVzGWJ" 194 | } 195 | }, 196 | { 197 | "cell_type": "markdown", 198 | "source": [ 199 | "#### モデル定義(オリジナル)\n", 200 | "\n" 201 | ], 202 | "metadata": { 203 | "id": "t2qWJkSpaK_p" 204 | } 205 | }, 206 | { 207 | "cell_type": "code", 208 | "source": [ 209 | "class Net(nn.Module):\n", 210 | " \"\"\"全結合1層のシンプルなNN(MNIST想定: 28*28 -> 100 -> 10)\"\"\"\n", 211 | " def __init__(self, n_input=28*28, n_hidden=100, n_output=10):\n", 212 | " super().__init__()\n", 213 | " self.net = nn.Sequential(\n", 214 | " nn.Linear(n_input, n_hidden),\n", 215 | " nn.ReLU(inplace=True),\n", 216 | " nn.Linear(n_hidden, n_output)\n", 217 | " )\n", 218 | "\n", 219 | " def forward(self, x):\n", 220 | " # (B, 1, 28, 28) -> (B, 784)\n", 221 | " x = torch.flatten(x, 1)\n", 222 | " return self.net(x)" 223 | ], 224 | "metadata": { 225 | "id": "c3jPRrSHaUbV" 226 | }, 227 | "execution_count": null, 228 | "outputs": [] 229 | }, 230 | { 231 | "cell_type": "markdown", 232 | "source": [ 233 | "#### モデル定義(dropout付き)" 234 | ], 235 | "metadata": { 236 | "id": "qP-2o1IeV0_e" 237 | } 238 | }, 239 | { 240 | "cell_type": "code", 241 | "source": [ 242 | "# モデル定義(dropout付き)\n", 243 | "\n", 244 | "class Net2(nn.Module):\n", 245 | " \"\"\"ドロップアウトを導入したモデル\"\"\"\n", 246 | " def __init__(self, n_input=28*28, n_hidden=100, n_output=10, dropout_p=0.5):\n", 247 | " super().__init__()\n", 248 | " self.net = nn.Sequential(\n", 249 | " nn.Linear(n_input, n_hidden),\n", 250 | " nn.ReLU(inplace=True),\n", 251 | " nn.Dropout(p=dropout_p), # ドロップアウト層を追加\n", 252 | " nn.Linear(n_hidden, n_output)\n", 253 | " )\n", 254 | "\n", 255 | " def forward(self, x):\n", 
256 | " x = torch.flatten(x, 1)\n", 257 | " return self.net(x)" 258 | ], 259 | "metadata": { 260 | "id": "0YPkV7sRy8zu" 261 | }, 262 | "execution_count": null, 263 | "outputs": [] 264 | }, 265 | { 266 | "cell_type": "markdown", 267 | "source": [ 268 | "#### モデル定義(batch normalization付き)" 269 | ], 270 | "metadata": { 271 | "id": "mJSELHkQZqXr" 272 | } 273 | }, 274 | { 275 | "cell_type": "code", 276 | "source": [ 277 | "# モデル定義(batch normalization付き)\n", 278 | "class Net3(nn.Module):\n", 279 | " \"\"\"バッチ正規化を導入したモデル\"\"\"\n", 280 | " def __init__(self, n_input=28*28, n_hidden=100, n_output=10):\n", 281 | " super().__init__()\n", 282 | " self.net = nn.Sequential(\n", 283 | " nn.Linear(n_input, n_hidden),\n", 284 | " nn.BatchNorm1d(n_hidden), # バッチノーマライゼーションを追加\n", 285 | " nn.ReLU(inplace=True),\n", 286 | " nn.Linear(n_hidden, n_output)\n", 287 | " )\n", 288 | "\n", 289 | " def forward(self, x):\n", 290 | " x = torch.flatten(x, 1)\n", 291 | " return self.net(x)" 292 | ], 293 | "metadata": { 294 | "id": "5aa3W8mHZ0je" 295 | }, 296 | "execution_count": null, 297 | "outputs": [] 298 | }, 299 | { 300 | "cell_type": "markdown", 301 | "source": [ 302 | "#### 訓練用関数" 303 | ], 304 | "metadata": { 305 | "id": "a1-SrTEDzkPD" 306 | } 307 | }, 308 | { 309 | "cell_type": "code", 310 | "source": [ 311 | "# 学習用関数\n", 312 | "\n", 313 | "def train_one_epoch(model, loader, criterion, optimizer):\n", 314 | " model.train()\n", 315 | " running_loss, correct, total = 0.0, 0, 0\n", 316 | "\n", 317 | " for inputs, labels in loader:\n", 318 | " inputs = inputs.to(device)\n", 319 | " labels = labels.to(device)\n", 320 | "\n", 321 | " # 勾配初期化\n", 322 | " optimizer.zero_grad()\n", 323 | " # 予測計算\n", 324 | " outputs = model(inputs)\n", 325 | " # 損失計算\n", 326 | " loss = criterion(outputs, labels)\n", 327 | " # 勾配(微分)計算\n", 328 | " loss.backward()\n", 329 | " # パラメータ更新\n", 330 | " optimizer.step()\n", 331 | "\n", 332 | " with torch.no_grad():\n", 333 | " preds = outputs.argmax(1)\n", 334 | " 
running_loss += loss.item() * labels.size(0)\n", 335 | " correct += (preds == labels).sum().item()\n", 336 | " total += labels.size(0)\n", 337 | "\n", 338 | " return running_loss / total, correct / total" 339 | ], 340 | "metadata": { 341 | "id": "IC62CIeKzWHw" 342 | }, 343 | "execution_count": null, 344 | "outputs": [] 345 | }, 346 | { 347 | "cell_type": "markdown", 348 | "source": [ 349 | "#### 検証用関数" 350 | ], 351 | "metadata": { 352 | "id": "Z1wa7CvWSALc" 353 | } 354 | }, 355 | { 356 | "cell_type": "code", 357 | "source": [ 358 | "# 検証用関数\n", 359 | "\n", 360 | "@torch.no_grad()\n", 361 | "def validate(model, loader, criterion):\n", 362 | " model.eval()\n", 363 | " running_loss, correct, total = 0.0, 0, 0\n", 364 | "\n", 365 | " for inputs, labels in loader:\n", 366 | " inputs = inputs.to(device)\n", 367 | " labels = labels.to(device)\n", 368 | "\n", 369 | " outputs = model(inputs)\n", 370 | " loss = criterion(outputs, labels)\n", 371 | " preds = outputs.argmax(1)\n", 372 | "\n", 373 | " running_loss += loss.item() * labels.size(0)\n", 374 | " correct += (preds == labels).sum().item()\n", 375 | " total += labels.size(0)\n", 376 | "\n", 377 | " return running_loss / total, correct / total" 378 | ], 379 | "metadata": { 380 | "id": "nJDJfVU1zvS-" 381 | }, 382 | "execution_count": null, 383 | "outputs": [] 384 | }, 385 | { 386 | "cell_type": "markdown", 387 | "source": [ 388 | "#### 学習関数" 389 | ], 390 | "metadata": { 391 | "id": "FbNqnE-PUS2K" 392 | } 393 | }, 394 | { 395 | "cell_type": "code", 396 | "source": [ 397 | "def fit(*, net_class=Net, n_hidden=100, num_epochs=20, lr=0.01,\n", 398 | " batch_size=100, optimizer_class=optim.SGD,\n", 399 | " seed=42, data_dir=\"./data\"):\n", 400 | " \"\"\"\n", 401 | " 引数(すべてキーワード専用):\n", 402 | " net_class : モデルクラス(例: Net, CustomCNNなど)\n", 403 | " n_hidden : 隠れ層ノード数(Net使用時のみ)\n", 404 | " num_epochs : 繰り返し数\n", 405 | " lr : 学習率\n", 406 | " batch_size : バッチサイズ\n", 407 | " optimizer_class : 最適化関数クラス (例: optim.SGD, optim.Adam)\n", 
408 | "\n", 409 | " 戻り値:\n", 410 | " model, history(np.ndarray: [epoch, train_loss, train_acc, val_loss, val_acc])\n", 411 | " \"\"\"\n", 412 | " torch.manual_seed(seed)\n", 413 | " np.random.seed(seed)\n", 414 | "\n", 415 | " # DataLoader作成(バッチサイズ指定)\n", 416 | " train_loader, test_loader = get_data_loaders(batch_size=batch_size,\n", 417 | " data_dir=data_dir)\n", 418 | "\n", 419 | " # モデル構築(引数対応)\n", 420 | " try:\n", 421 | " model = net_class(n_hidden=n_hidden).to(device)\n", 422 | " except TypeError:\n", 423 | " model = net_class().to(device)\n", 424 | "\n", 425 | " criterion = nn.CrossEntropyLoss()\n", 426 | " optimizer = optimizer_class(model.parameters(), lr=lr)\n", 427 | "\n", 428 | " history = []\n", 429 | " for epoch in tqdm(range(1, num_epochs + 1), desc=\"Training\"):\n", 430 | " tr_loss, tr_acc = train_one_epoch(model, train_loader, criterion,\n", 431 | " optimizer)\n", 432 | " va_loss, va_acc = validate(model, test_loader, criterion)\n", 433 | "\n", 434 | " print(\n", 435 | " f\"Epoch [{epoch}/{num_epochs}] \"\n", 436 | " f\"train_loss: {tr_loss:.5f}, train_acc: {tr_acc:.5f}, \"\n", 437 | " f\"val_loss: {va_loss:.5f}, val_acc: {va_acc:.5f}\"\n", 438 | " )\n", 439 | " history.append([epoch, tr_loss, tr_acc, va_loss, va_acc])\n", 440 | "\n", 441 | " return model, np.array(history, dtype=float)" 442 | ], 443 | "metadata": { 444 | "id": "V4gsqUMzUX2M" 445 | }, 446 | "execution_count": null, 447 | "outputs": [] 448 | }, 449 | { 450 | "cell_type": "markdown", 451 | "source": [ 452 | "#### 学習" 453 | ], 454 | "metadata": { 455 | "id": "8E69wVDmz38y" 456 | } 457 | }, 458 | { 459 | "cell_type": "code", 460 | "source": [ 461 | "# オリジナルモデル\n", 462 | "model1, history1 = fit(net_class=Net, num_epochs=20, lr=0.01)\n", 463 | "\n", 464 | "# Dropout付きモデル\n", 465 | "model2, history2 = fit(net_class=Net2, num_epochs=20, lr=0.01)\n", 466 | "\n", 467 | "# BatchNorm付きモデル\n", 468 | "model3, history3 = fit(net_class=Net3, num_epochs=20, lr=0.01)" 469 | ], 470 | "metadata": { 
471 | "id": "Vfd2bVOcDu1F" 472 | }, 473 | "execution_count": null, 474 | "outputs": [] 475 | }, 476 | { 477 | "cell_type": "markdown", 478 | "source": [ 479 | "### Results" 480 | ], 481 | "metadata": { 482 | "id": "gxycwR5pKhbF" 483 | } 484 | }, 485 | { 486 | "cell_type": "code", 487 | "source": [ 488 | "import matplotlib.pyplot as plt\n", 489 | "\n", 490 | "def plot_learning_curves_multi(histories, labels=None, title_suffix=\"\"):\n", 491 | " \"\"\"\n", 492 | " Overlay the test-data curves of several histories (blue and black only)\n", 493 | " Training curves are not drawn; models differ by color and line style\n", 494 | " \"\"\"\n", 495 | " plt.figure(figsize=(10, 4), tight_layout=True)\n", 496 | "\n", 497 | " # --- Line styles (blue and black only) ---\n", 498 | " colors = ['b', 'k', 'b'] # with linestyles, distinguishes up to 3 models\n", 499 | " linestyles = ['-', '--', ':'] # a different line pattern per model\n", 500 | "\n", 501 | " # --- Loss curves ---\n", 502 | " plt.subplot(1, 2, 1)\n", 503 | " for i, history in enumerate(histories):\n", 504 | " epochs = history[:, 0]\n", 505 | " train_loss = history[:, 1]\n", 506 | " val_loss = history[:, 3]\n", 507 | " color = colors[i % len(colors)]\n", 508 | " ls_val = linestyles[i % len(linestyles)] # line style for model i\n", 509 | "\n", 510 | " label_val = f\"{labels[i]} (test)\" if labels else f\"model{i+1} (test)\"\n", 511 | "\n", 512 | " plt.plot(epochs, val_loss, color=color, linestyle=ls_val, label=label_val)\n", 513 | "\n", 514 | " plt.xlabel('Epoch')\n", 515 | " plt.ylabel('Loss')\n", 516 | " plt.title(f'Learning curve (loss) {title_suffix}')\n", 517 | " plt.legend(fontsize=8)\n", 518 | " plt.grid(True)\n", 519 | "\n", 520 | " # --- Accuracy curves ---\n", 521 | " plt.subplot(1, 2, 2)\n", 522 | " for i, history in enumerate(histories):\n", 523 | " epochs = history[:, 0]\n", 524 | " train_acc = history[:, 2]\n", 525 | " val_acc = history[:, 4]\n", 526 | " color = colors[i % len(colors)]\n", 527 | " ls_val = linestyles[i % len(linestyles)]\n", 528 | "\n", 529 | " label_val = f\"{labels[i]} (test)\" if labels else f\"model{i+1} (test)\"\n", 530 | "\n", 531 | " plt.plot(epochs, val_acc, color=color, linestyle=ls_val, 
label=label_val)\n", 532 | "\n", 533 | " plt.xlabel('Epoch')\n", 534 | " plt.ylabel('Accuracy')\n", 535 | " plt.xticks(np.arange(0,21,2))\n", 536 | " plt.title(f'Learning curve (accuracy) {title_suffix}')\n", 537 | " plt.legend(fontsize=8)\n", 538 | " plt.grid(True)\n", 539 | " plt.show()" 540 | ], 541 | "metadata": { 542 | "id": "VE0ijWPlzzpy" 543 | }, 544 | "execution_count": null, 545 | "outputs": [] 546 | }, 547 | { 548 | "cell_type": "code", 549 | "source": [ 550 | "# Plot the three histories for comparison\n", 551 | "histories = [history1, history2, history3]\n", 552 | "plot_learning_curves_multi(\n", 553 | " histories,\n", 554 | " labels=[\"original\", \"dropout\", \"BN\"],\n", 555 | " title_suffix=\"overfitting countermeasures\"\n", 556 | ")" 557 | ], 558 | "metadata": { 559 | "id": "4hLdyqQaFtFI" 560 | }, 561 | "execution_count": null, 562 | "outputs": [] 563 | }, 564 | { 565 | "cell_type": "markdown", 566 | "source": [ 567 | "### Results" 568 | ], 569 | "metadata": { 570 | "id": "JZRkD-xWWQRD" 571 | } 572 | }, 573 | { 574 | "cell_type": "markdown", 575 | "source": [ 576 | "#### Plotting learning curves" 577 | ], 578 | "metadata": { 579 | "id": "luVmPDGB0Yat" 580 | } 581 | }, 582 | { 583 | "cell_type": "code", 584 | "source": [ 585 | "def plot_learning_curves(history, title_suffix=\"\"):\n", 586 | " \"\"\"Draw the loss and accuracy learning curves side by side\"\"\"\n", 587 | " epochs = history[:, 0]\n", 588 | " train_loss = history[:, 1]\n", 589 | " train_acc = history[:, 2]\n", 590 | " val_loss = history[:, 3]\n", 591 | " val_acc = history[:, 4]\n", 592 | "\n", 593 | " plt.figure(figsize=(8, 4), tight_layout=True)\n", 594 | "\n", 595 | " # --- Loss curves ---\n", 596 | " plt.subplot(1, 2, 1)\n", 597 | " plt.plot(epochs, train_loss, label='train', color='b')\n", 598 | " plt.plot(epochs, val_loss, label='test', color='k', linestyle='--')\n", 599 | " plt.xlabel('Epoch')\n", 600 | " plt.ylabel('Loss')\n", 601 | " plt.title(f'Learning curve (loss) {title_suffix}')\n", 602 | " plt.legend()\n", 603 | " plt.grid(True)\n", 604 | "\n", 605 | " # --- Accuracy curves ---\n", 606 | " plt.subplot(1, 2, 2)\n", 607 | " 
plt.plot(epochs, train_acc, label='train', color='b')\n", 608 | " plt.plot(epochs, val_acc, label='test', color='k', linestyle='--')\n", 609 | " plt.xlabel('Epoch')\n", 610 | " plt.ylabel('Accuracy')\n", 611 | " plt.title(f'Learning curve (accuracy) {title_suffix}')\n", 612 | " plt.legend()\n", 613 | " plt.grid(True)\n", 614 | " plt.show()" 615 | ], 616 | "metadata": { 617 | "id": "5W9AnVQ80AXs" 618 | }, 619 | "execution_count": null, 620 | "outputs": [] 621 | }, 622 | { 623 | "cell_type": "markdown", 624 | "source": [ 625 | "### Version check" 626 | ], 627 | "metadata": { 628 | "id": "HgzYQ9vc0vtU" 629 | } 630 | }, 631 | { 632 | "cell_type": "code", 633 | "source": [ 634 | "!pip install watermark -qq\n", 635 | "%load_ext watermark\n", 636 | "%watermark --iversions" 637 | ], 638 | "metadata": { 639 | "id": "VSZrDXEP0oXe" 640 | }, 641 | "execution_count": null, 642 | "outputs": [] 643 | }, 644 | { 645 | "cell_type": "code", 646 | "source": [], 647 | "metadata": { 648 | "id": "GQdgyLIt12RM" 649 | }, 650 | "execution_count": null, 651 | "outputs": [] 652 | } 653 | ] 654 | } -------------------------------------------------------------------------------- /notebooks/08_bi_classify.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "provenance": [], 7 | "toc_visible": true 8 | }, 9 | "kernelspec": { 10 | "name": "python3", 11 | "display_name": "Python 3" 12 | }, 13 | "language_info": { 14 | "name": "python" 15 | } 16 | }, 17 | "cells": [ 18 | { 19 | "cell_type": "markdown", 20 | "source": [ 21 | "### Chapter 8: Binary logistic regression" 22 | ], 23 | "metadata": { 24 | "id": "pxnU7W46X7fq" 25 | } 26 | }, 27 | { 28 | "cell_type": "markdown", 29 | "source": [ 30 | "### Environment setup" 31 | ], 32 | "metadata": { 33 | "id": "ZMcdPpbprxoJ" 34 | } 35 | }, 36 | { 37 | "cell_type": "markdown", 38 | "source": [ 39 | "#### Installing libraries" 40 | ], 41 | "metadata": { 42 | "id": "OMClzwfxYTnb" 43 | } 44 | }, 45 | { 46 | "cell_type": 
"code", 47 | "source": [ 48 | "# 日本語化ライブラリ導入\n", 49 | "!pip install japanize-matplotlib -q" 50 | ], 51 | "metadata": { 52 | "id": "W9cejPkrYFMX" 53 | }, 54 | "execution_count": null, 55 | "outputs": [] 56 | }, 57 | { 58 | "cell_type": "markdown", 59 | "source": [ 60 | "#### ライブラリインポート" 61 | ], 62 | "metadata": { 63 | "id": "eJzHdbjar51-" 64 | } 65 | }, 66 | { 67 | "cell_type": "code", 68 | "source": [ 69 | "# ライブラリインポート\n", 70 | "import pandas as pd\n", 71 | "import numpy as np\n", 72 | "import matplotlib.pyplot as plt\n", 73 | "import japanize_matplotlib\n", 74 | "from IPython.display import display\n", 75 | "from sklearn.datasets import load_iris\n", 76 | "import warnings" 77 | ], 78 | "metadata": { 79 | "id": "LjLvpx_AYYTg" 80 | }, 81 | "execution_count": null, 82 | "outputs": [] 83 | }, 84 | { 85 | "cell_type": "markdown", 86 | "source": [ 87 | "#### 環境設定" 88 | ], 89 | "metadata": { 90 | "id": "Ccazfpcjr_Zs" 91 | } 92 | }, 93 | { 94 | "cell_type": "code", 95 | "source": [ 96 | "# 環境設定\n", 97 | "np.set_printoptions(formatter={'float': '{:0.3f}'.format})\n", 98 | "pd.options.display.float_format = '{:.3f}'.format\n", 99 | "warnings.filterwarnings('ignore')" 100 | ], 101 | "metadata": { 102 | "id": "2HwOFwCnsCOs" 103 | }, 104 | "execution_count": null, 105 | "outputs": [] 106 | }, 107 | { 108 | "cell_type": "markdown", 109 | "source": [ 110 | "### データ読み込み" 111 | ], 112 | "metadata": { 113 | "id": "1vVr9AcSYjYW" 114 | } 115 | }, 116 | { 117 | "cell_type": "markdown", 118 | "source": [ 119 | "#### 読み込み関数" 120 | ], 121 | "metadata": { 122 | "id": "6vC56VJxtt-o" 123 | } 124 | }, 125 | { 126 | "cell_type": "code", 127 | "source": [ 128 | "# 読み込み関数\n", 129 | "def load_iris_dataset():\n", 130 | " iris = load_iris(as_frame=True)\n", 131 | " df = iris.data.copy()\n", 132 | " df.columns = ['がく片長', 'がく片幅', '花弁長', '花弁幅']\n", 133 | " df['品種'] = iris.target.map({0:'setosa', 1:'versicolor', 2:'virginica'})\n", 134 | " print(f\"データ読み込み完了 ({df.shape[0]}件, 特徴量4)\")\n", 135 | " 
return df, iris" 136 | ], 137 | "metadata": { 138 | "id": "qHv15g2aYftl" 139 | }, 140 | "execution_count": null, 141 | "outputs": [] 142 | }, 143 | { 144 | "cell_type": "markdown", 145 | "source": [ 146 | "#### 読み込み" 147 | ], 148 | "metadata": { 149 | "id": "M-goTVKptw_F" 150 | } 151 | }, 152 | { 153 | "cell_type": "code", 154 | "source": [ 155 | "# 読み込み\n", 156 | "df, iris = load_iris_dataset()\n" 157 | ], 158 | "metadata": { 159 | "id": "Q2SxV7yYtzVW" 160 | }, 161 | "execution_count": null, 162 | "outputs": [] 163 | }, 164 | { 165 | "cell_type": "markdown", 166 | "source": [ 167 | "#### 内容確認" 168 | ], 169 | "metadata": { 170 | "id": "RX4_YIHzt0OW" 171 | } 172 | }, 173 | { 174 | "cell_type": "code", 175 | "source": [ 176 | "# 内容確認\n", 177 | "\n", 178 | "# 先頭5行表示\n", 179 | "display(df.head())" 180 | ], 181 | "metadata": { 182 | "id": "hTaSVmD5t28V" 183 | }, 184 | "execution_count": null, 185 | "outputs": [] 186 | }, 187 | { 188 | "cell_type": "markdown", 189 | "source": [ 190 | "### データ加工" 191 | ], 192 | "metadata": { 193 | "id": "0JJtExx7ZAro" 194 | } 195 | }, 196 | { 197 | "cell_type": "markdown", 198 | "source": [ 199 | "#### 2クラスのみ抽出" 200 | ], 201 | "metadata": { 202 | "id": "_l0iYrZPZDzG" 203 | } 204 | }, 205 | { 206 | "cell_type": "code", 207 | "source": [ 208 | "# 2クラスのみ抽出\n", 209 | "# 先頭の100行が2クラス分のデータになっていることを利用\n", 210 | "\n", 211 | "# x_dataは、更に列についても2列の絞り込みを行う\n", 212 | "x_data = df[['がく片長', 'がく片幅']].head(100)\n", 213 | "y_data = iris.target[:100].values\n", 214 | "\n", 215 | "# x_dataとy_dataのshape確認\n", 216 | "print(f\"x_data.shape: {x_data.shape}\")\n", 217 | "print(f\"y_data.shape: {y_data.shape}\")" 218 | ], 219 | "metadata": { 220 | "id": "h_ZTVYjaYrPq" 221 | }, 222 | "execution_count": null, 223 | "outputs": [] 224 | }, 225 | { 226 | "cell_type": "markdown", 227 | "source": [ 228 | "#### ダミー変数追加" 229 | ], 230 | "metadata": { 231 | "id": "Y62bFz_DZTcp" 232 | } 233 | }, 234 | { 235 | "cell_type": "code", 236 | "source": [ 237 | "# ダミー変数追加\n", 238 | 
"x_data2 = np.insert(x_data, 0, 1.0, axis=1)\n", 239 | "\n", 240 | "# shape確認\n", 241 | "print(f\"x_data2.shape = {x_data2.shape}\")\n", 242 | "\n", 243 | "# 先頭5行のデータ確認\n", 244 | "display(x_data2[:5])" 245 | ], 246 | "metadata": { 247 | "id": "UMU5gP2oZMsV" 248 | }, 249 | "execution_count": null, 250 | "outputs": [] 251 | }, 252 | { 253 | "cell_type": "markdown", 254 | "source": [ 255 | "#### 訓練データとテストデータの分割" 256 | ], 257 | "metadata": { 258 | "id": "ivI-5QlEZp5K" 259 | } 260 | }, 261 | { 262 | "cell_type": "code", 263 | "source": [ 264 | "# 訓練データとテストデータの分割\n", 265 | "\n", 266 | "# データ分割用のライブラリ関数インポート\n", 267 | "from sklearn.model_selection import train_test_split\n", 268 | "\n", 269 | "# データ分割の実施\n", 270 | "x_train, x_test, y_train, y_test = train_test_split(\n", 271 | " x_data2, y_data, train_size=70, test_size=30, random_state=123)\n", 272 | "\n", 273 | "# 分割後の各変数のshape確認\n", 274 | "print(f\"x_train.shape: {x_train.shape}\")\n", 275 | "print(f\"x_test.shape : {x_test.shape}\")\n", 276 | "print(f\"y_train.shape: {y_train.shape}\")\n", 277 | "print(f\"y_test.shape : {y_test.shape}\")" 278 | ], 279 | "metadata": { 280 | "id": "rEX4w8RLZeuW" 281 | }, 282 | "execution_count": null, 283 | "outputs": [] 284 | }, 285 | { 286 | "cell_type": "markdown", 287 | "source": [ 288 | "#### 散布図表示(訓練データ)" 289 | ], 290 | "metadata": { 291 | "id": "4SuTqWZ7Z37m" 292 | } 293 | }, 294 | { 295 | "cell_type": "code", 296 | "source": [ 297 | "# 散布図表示(訓練データ)\n", 298 | "\n", 299 | "# 正解値によるデータ分割\n", 300 | "x_t0 = x_train[y_train==0]\n", 301 | "x_t1 = x_train[y_train==1]\n", 302 | "\n", 303 | "# 散布図表示\n", 304 | "plt.figure(figsize=(6,6))\n", 305 | "plt.scatter(x_t0[:,1], x_t0[:,2], marker='x', c='b', label='0 (setosa)')\n", 306 | "plt.scatter(x_t1[:,1], x_t1[:,2], marker='o', c='k', label='1 (versicolor)')\n", 307 | "plt.title('訓練データの散布図')\n", 308 | "plt.xlabel('がく片長')\n", 309 | "plt.ylabel('がく片幅')\n", 310 | "plt.legend()\n", 311 | "plt.grid(True)\n", 312 | "plt.show()" 313 | ], 314 | 
"metadata": { 315 | "id": "Vo12_TDPZwi0" 316 | }, 317 | "execution_count": null, 318 | "outputs": [] 319 | }, 320 | { 321 | "cell_type": "markdown", 322 | "source": [ 323 | "### 基本関数定義" 324 | ], 325 | "metadata": { 326 | "id": "HWbX9XWkaNiK" 327 | } 328 | }, 329 | { 330 | "cell_type": "markdown", 331 | "source": [ 332 | "#### シグモイド関数" 333 | ], 334 | "metadata": { 335 | "id": "x7TqUG8_aO_b" 336 | } 337 | }, 338 | { 339 | "cell_type": "code", 340 | "source": [ 341 | "# シグモイド関数\n", 342 | "def sigmoid(x):\n", 343 | " \"\"\"シグモイド関数(確率出力)\"\"\"\n", 344 | " return 1 / (1 + np.exp(-x))" 345 | ], 346 | "metadata": { 347 | "id": "MrDs7NCPZ9WQ" 348 | }, 349 | "execution_count": null, 350 | "outputs": [] 351 | }, 352 | { 353 | "cell_type": "markdown", 354 | "source": [ 355 | "#### 予測関数" 356 | ], 357 | "metadata": { 358 | "id": "UgJf569IaZ09" 359 | } 360 | }, 361 | { 362 | "cell_type": "code", 363 | "source": [ 364 | "# 予測関数\n", 365 | "def pred(x, w):\n", 366 | " \"\"\"予測関数(確率を出力)\"\"\"\n", 367 | " return sigmoid(x @ w)" 368 | ], 369 | "metadata": { 370 | "id": "2diO594YaWNc" 371 | }, 372 | "execution_count": null, 373 | "outputs": [] 374 | }, 375 | { 376 | "cell_type": "markdown", 377 | "source": [ 378 | "#### 交差エントロピー関数" 379 | ], 380 | "metadata": { 381 | "id": "VSxtzQrvafd4" 382 | } 383 | }, 384 | { 385 | "cell_type": "code", 386 | "source": [ 387 | "# 交差エントロピー関数\n", 388 | "def cross_entropy(yt, yp):\n", 389 | " \"\"\"交差エントロピー損失\"\"\"\n", 390 | "\n", 391 | " # 個別データごとに交差エントロピーを計算\n", 392 | " ce = -(yt * np.log(yp) + (1 - yt) * np.log(1 - yp))\n", 393 | "\n", 394 | " # 全データの平均を取り戻り値とする\n", 395 | " return np.mean(ce)" 396 | ], 397 | "metadata": { 398 | "id": "W2z6Nmcoaeb8" 399 | }, 400 | "execution_count": null, 401 | "outputs": [] 402 | }, 403 | { 404 | "cell_type": "markdown", 405 | "source": [ 406 | "#### クラス変換関数" 407 | ], 408 | "metadata": { 409 | "id": "p3Sa5KLRaqnL" 410 | } 411 | }, 412 | { 413 | "cell_type": "code", 414 | "source": [ 415 | "# クラス変換関数\n", 416 | "def 
classify(y):\n", 417 | " \"\"\"確率→クラス(0 or 1)変換\"\"\"\n", 418 | "\n", 419 | " # 確率値と閾値(0.5)の比較で0/1を判定\n", 420 | " return np.where(y < 0.5, 0, 1)" 421 | ], 422 | "metadata": { 423 | "id": "chic_4Ceakgs" 424 | }, 425 | "execution_count": null, 426 | "outputs": [] 427 | }, 428 | { 429 | "cell_type": "markdown", 430 | "source": [ 431 | "#### 評価関数" 432 | ], 433 | "metadata": { 434 | "id": "xRtX4pWNazgb" 435 | } 436 | }, 437 | { 438 | "cell_type": "code", 439 | "source": [ 440 | "# 評価関数\n", 441 | "from sklearn.metrics import accuracy_score\n", 442 | "def evaluate(xt, yt, w):\n", 443 | " \"\"\"損失と精度を計算\"\"\"\n", 444 | "\n", 445 | " # 予測値の計算(確率値)\n", 446 | " yp = pred(xt, w)\n", 447 | "\n", 448 | " # 予測クラスの計算(0/1)\n", 449 | " yp_b = classify(yp)\n", 450 | "\n", 451 | " # 損失の計算(確率値を利用)\n", 452 | " loss = cross_entropy(yt, yp)\n", 453 | "\n", 454 | " # 精度の計算(予測クラスを利用)\n", 455 | " score = accuracy_score(yt, yp_b)\n", 456 | "\n", 457 | " # 損失と精度を戻す\n", 458 | " return loss, score" 459 | ], 460 | "metadata": { 461 | "id": "Uu6SSNI5avqL" 462 | }, 463 | "execution_count": null, 464 | "outputs": [] 465 | }, 466 | { 467 | "cell_type": "markdown", 468 | "source": [ 469 | "### 学習" 470 | ], 471 | "metadata": { 472 | "id": "8SDFHf9XcMWQ" 473 | } 474 | }, 475 | { 476 | "cell_type": "markdown", 477 | "source": [ 478 | "#### 学習関数" 479 | ], 480 | "metadata": { 481 | "id": "pCJtQD2ucfaE" 482 | } 483 | }, 484 | { 485 | "cell_type": "code", 486 | "source": [ 487 | "# 学習関数\n", 488 | "def train_logistic_regression(x, yt, x_test, y_test, \\\n", 489 | " alpha=0.01, iters=5000, his_unit=100):\n", 490 | " # M(データ件数)とD(入力データ要素数)の設定\n", 491 | " M, D = x.shape\n", 492 | " # 重みベクトル初期化(全要素1を設定)\n", 493 | " w = np.ones(D)\n", 494 | " # 学習過程記録用\n", 495 | " history = np.zeros((0,3))\n", 496 | "\n", 497 | " # 繰り返し処理\n", 498 | " for k in range(iters):\n", 499 | " # 予測計算\n", 500 | " yp = pred(x, w)\n", 501 | " # 誤差計算\n", 502 | " yd = yp - yt\n", 503 | " # 勾配計算\n", 504 | " grad = (x.T @ yd) / M\n", 505 | " # 
パラメータ修正\n", 506 | " w -= alpha * grad\n", 507 | "\n", 508 | " if k % his_unit == 0:\n", 509 | " loss, score = evaluate(x_test, y_test, w)\n", 510 | " history = np.vstack((history, np.array([k, loss, score])))\n", 511 | " print(f\"iter={k:5d} | loss={loss:.6f} | score={score:.6f}\")\n", 512 | " return w, history" 513 | ], 514 | "metadata": { 515 | "id": "4P_tAOZDcXvc" 516 | }, 517 | "execution_count": null, 518 | "outputs": [] 519 | }, 520 | { 521 | "cell_type": "markdown", 522 | "source": [ 523 | "#### 学習" 524 | ], 525 | "metadata": { 526 | "id": "ZA4cDjlYcuVT" 527 | } 528 | }, 529 | { 530 | "cell_type": "code", 531 | "source": [ 532 | "# 学習\n", 533 | "\n", 534 | "# 学習用変数設定\n", 535 | "x, yt = x_train, y_train\n", 536 | "\n", 537 | "# 学習率と繰り返し回数の設定\n", 538 | "alpha = 0.01\n", 539 | "iters = 5000\n", 540 | "his_unit = 100\n", 541 | "\n", 542 | "# 繰り返し処理\n", 543 | "w, history = train_logistic_regression(x, yt, \\\n", 544 | " x_test, y_test, alpha=alpha, iters=iters, his_unit=his_unit)" 545 | ], 546 | "metadata": { 547 | "id": "dG3B2WMVcsvy" 548 | }, 549 | "execution_count": null, 550 | "outputs": [] 551 | }, 552 | { 553 | "cell_type": "markdown", 554 | "source": [ 555 | "### 結果分析" 556 | ], 557 | "metadata": { 558 | "id": "6d2QMgInc9Nt" 559 | } 560 | }, 561 | { 562 | "cell_type": "markdown", 563 | "source": [ 564 | "#### 損失と精度の確認" 565 | ], 566 | "metadata": { 567 | "id": "XE3ohGrxdJjD" 568 | } 569 | }, 570 | { 571 | "cell_type": "code", 572 | "source": [ 573 | "# 損失と精度の確認\n", 574 | "print(f\"初期状態: 損失={history[0,1]:.6f}, 精度={history[0,2]:.6f}\")\n", 575 | "print(f\"最終状態: 損失={history[-1,1]:.6f}, 精度={history[-1,2]:.6f}\")" 576 | ], 577 | "metadata": { 578 | "id": "VVPDuvk2c0UZ" 579 | }, 580 | "execution_count": null, 581 | "outputs": [] 582 | }, 583 | { 584 | "cell_type": "markdown", 585 | "source": [ 586 | "#### 学習曲線(損失)" 587 | ], 588 | "metadata": { 589 | "id": "-enGG1ECdZoo" 590 | } 591 | }, 592 | { 593 | "cell_type": "code", 594 | "source": [ 595 | "# 学習曲線(損失)\n", 596 
| "plt.figure(figsize=(6,6))\n", 597 | "plt.plot(history[1:,0], history[1:,1], color='blue')\n", 598 | "plt.title('学習曲線(損失)', fontsize=14)\n", 599 | "plt.xlabel('繰り返し回数', fontsize=13)\n", 600 | "plt.ylabel('損失', fontsize=13)\n", 601 | "plt.grid(True)\n", 602 | "plt.show()" 603 | ], 604 | "metadata": { 605 | "id": "44wI-nJIdQsO" 606 | }, 607 | "execution_count": null, 608 | "outputs": [] 609 | }, 610 | { 611 | "cell_type": "markdown", 612 | "source": [ 613 | "#### 学習曲線(精度)" 614 | ], 615 | "metadata": { 616 | "id": "euKMBD-Fde7d" 617 | } 618 | }, 619 | { 620 | "cell_type": "code", 621 | "source": [ 622 | "# 学習曲線(精度)\n", 623 | "plt.figure(figsize=(6,6))\n", 624 | "plt.plot(history[1:,0], history[1:,2], color='black')\n", 625 | "plt.title('学習曲線(精度)', fontsize=14)\n", 626 | "plt.xlabel('繰り返し回数', fontsize=13)\n", 627 | "plt.ylabel('精度(Accuracy)', fontsize=13)\n", 628 | "plt.grid(True)\n", 629 | "plt.show()" 630 | ], 631 | "metadata": { 632 | "id": "lBZEkk37ddgo" 633 | }, 634 | "execution_count": null, 635 | "outputs": [] 636 | }, 637 | { 638 | "cell_type": "markdown", 639 | "source": [ 640 | "#### 決定境界の計算" 641 | ], 642 | "metadata": { 643 | "id": "3xF1XPrYeGW5" 644 | } 645 | }, 646 | { 647 | "cell_type": "code", 648 | "source": [ 649 | "# 決定境界の計算\n", 650 | "def decision_boundary_x2(x1, w):\n", 651 | " \"\"\"決定境界 x2 = -(w0 + w1*x1)/w2\"\"\"\n", 652 | " return -(w[0] + w[1]*x1) / w[2]\n", 653 | "\n", 654 | "xl = np.array([x_test[:,1].min(), x_test[:,1].max()])\n", 655 | "yl = decision_boundary_x2(xl, w)" 656 | ], 657 | "metadata": { 658 | "id": "3LJRMN6bdlgh" 659 | }, 660 | "execution_count": null, 661 | "outputs": [] 662 | }, 663 | { 664 | "cell_type": "markdown", 665 | "source": [ 666 | "#### 決定境界の可視化" 667 | ], 668 | "metadata": { 669 | "id": "MxxdcopfeOf_" 670 | } 671 | }, 672 | { 673 | "cell_type": "code", 674 | "source": [ 675 | "# 決定境界の可視化\n", 676 | "x_t0 = x_test[y_test==0]\n", 677 | "x_t1 = x_test[y_test==1]\n", 678 | "\n", 679 | "plt.figure(figsize=(6,6))\n", 680 
| "plt.scatter(x_t0[:,1], x_t0[:,2], marker='x', c='b', s=50, label='0 (setosa)')\n", 681 | "plt.scatter(x_t1[:,1], x_t1[:,2], marker='o', c='k', s=50, label='1 (versicolor)')\n", 682 | "plt.plot(xl, yl, c='k', label='決定境界')\n", 683 | "plt.title('テストデータと決定境界')\n", 684 | "plt.xlabel('がく片長')\n", 685 | "plt.ylabel('がく片幅')\n", 686 | "plt.legend()\n", 687 | "plt.grid(True)\n", 688 | "plt.show()" 689 | ], 690 | "metadata": { 691 | "id": "BspHVUfOeLwP" 692 | }, 693 | "execution_count": null, 694 | "outputs": [] 695 | }, 696 | { 697 | "cell_type": "markdown", 698 | "source": [ 699 | "#### 予測関数の3次元曲面表示" 700 | ], 701 | "metadata": { 702 | "id": "v1pD-CLrf8O7" 703 | } 704 | }, 705 | { 706 | "cell_type": "code", 707 | "source": [ 708 | "# 予測関数の3次元曲面表示\n", 709 | "\n", 710 | "from mpl_toolkits.mplot3d import Axes3D\n", 711 | "x1 = np.linspace(4, 7.5, 100)\n", 712 | "x2 = np.linspace(2, 4.5, 100)\n", 713 | "xx1, xx2 = np.meshgrid(x1, x2)\n", 714 | "xxx = np.asarray([np.ones(xx1.ravel().shape),\n", 715 | " xx1.ravel(), xx2.ravel()]).T\n", 716 | "c = pred(xxx, w).reshape(xx1.shape)\n", 717 | "plt.figure(figsize=(8,8))\n", 718 | "ax = plt.subplot(1, 1, 1, projection='3d')\n", 719 | "ax.plot_surface(xx1, xx2, c, color='blue',\n", 720 | " edgecolor='black', rstride=10, cstride=10, alpha=0.1)\n", 721 | "ax.scatter(x_t0[:,1], x_t0[:,2], 0, s=20, alpha=0.9, marker='x', c='b')\n", 722 | "ax.scatter(x_t1[:,1], x_t1[:,2], 1, s=20, alpha=0.9, marker='o', c='k')\n", 723 | "ax.set_xlim(4,7.5)\n", 724 | "ax.set_ylim(2,4.5)\n", 725 | "ax.view_init(elev=20, azim=60)" 726 | ], 727 | "metadata": { 728 | "id": "CjrQBTOigAxx" 729 | }, 730 | "execution_count": null, 731 | "outputs": [] 732 | }, 733 | { 734 | "cell_type": "markdown", 735 | "source": [ 736 | "### scikit-learn ライブラリとの比較" 737 | ], 738 | "metadata": { 739 | "id": "5twmVrORgVEM" 740 | } 741 | }, 742 | { 743 | "cell_type": "markdown", 744 | "source": [ 745 | "#### 必要ライブラリのロード" 746 | ], 747 | "metadata": { 748 | "id": "5O3Slhq3gZR_" 749 | } 
750 | }, 751 | { 752 | "cell_type": "code", 753 | "source": [ 754 | "# 必要ライブラリのロード\n", 755 | "from sklearn.linear_model import LogisticRegression\n", 756 | "from sklearn import svm" 757 | ], 758 | "metadata": { 759 | "id": "t0uAjNGsge4Y" 760 | }, 761 | "execution_count": null, 762 | "outputs": [] 763 | }, 764 | { 765 | "cell_type": "markdown", 766 | "source": [ 767 | "#### モデル生成" 768 | ], 769 | "metadata": { 770 | "id": "U3oxxIcSgqEd" 771 | } 772 | }, 773 | { 774 | "cell_type": "code", 775 | "source": [ 776 | "# モデル生成\n", 777 | "model_lr = LogisticRegression(solver='liblinear')\n", 778 | "model_svm = svm.SVC(kernel='linear')" 779 | ], 780 | "metadata": { 781 | "id": "vbtVGT1dgumR" 782 | }, 783 | "execution_count": null, 784 | "outputs": [] 785 | }, 786 | { 787 | "cell_type": "markdown", 788 | "source": [ 789 | "#### 学習" 790 | ], 791 | "metadata": { 792 | "id": "KCIN1B5ggzFF" 793 | } 794 | }, 795 | { 796 | "cell_type": "code", 797 | "source": [ 798 | "# 学習\n", 799 | "model_lr.fit(x, yt)\n", 800 | "model_svm.fit(x, yt)" 801 | ], 802 | "metadata": { 803 | "id": "MOhLGY6Ug3g4" 804 | }, 805 | "execution_count": null, 806 | "outputs": [] 807 | }, 808 | { 809 | "cell_type": "markdown", 810 | "source": [ 811 | "#### ロジスティック回帰の決定直線" 812 | ], 813 | "metadata": { 814 | "id": "q2a72F_ehJx5" 815 | } 816 | }, 817 | { 818 | "cell_type": "code", 819 | "source": [ 820 | "# ロジスティック回帰の決定直線\n", 821 | "# 切片の値(xの先頭列はダミー変数なので、その係数coef_[0,0]も加える)\n", 822 | "lr_w0 = model_lr.intercept_[0] + model_lr.coef_[0,0]\n", 823 | "# x1(sepal_length)の係数\n", 824 | "lr_w1 = model_lr.coef_[0,1]\n", 825 | "# x2(sepal_width)の係数\n", 826 | "lr_w2 = model_lr.coef_[0,2]\n", 827 | "\n", 828 | "def rl(x):\n", 829 | " wk = lr_w0 + lr_w1 * x\n", 830 | " wk2 = -wk / lr_w2\n", 831 | " return(wk2)\n", 832 | "\n", 833 | "y_rl = rl(xl)\n", 834 | "\n", 835 | "# 結果確認\n", 836 | "print(xl, y_rl)" 837 | ], 838 | "metadata": { 839 | "id": "N34hOyc9hOG9" 840 | }, 841 | "execution_count": null, 842 | "outputs": [] 843 | }, 844 | { 845 | "cell_type": 
"markdown", 846 | "source": [ 847 | "#### 
SVMの決定直線" 848 | ], 849 | "metadata": { 850 | "id": "m0YYu80jhTSR" 851 | } 852 | }, 853 | { 854 | "cell_type": "code", 855 | "source": [ 856 | "# SVMの決定直線\n", 857 | "# 切片の値(xの先頭列はダミー変数なので、その係数coef_[0,0]も加える)\n", 858 | "svm_w0 = model_svm.intercept_[0] + model_svm.coef_[0,0]\n", 859 | "# x1(sepal_length)の係数\n", 860 | "svm_w1 = model_svm.coef_[0,1]\n", 861 | "# x2(sepal_width)の係数\n", 862 | "svm_w2 = model_svm.coef_[0,2]\n", 863 | "\n", 864 | "def svm_line(x):\n", 865 | " wk = svm_w0 + svm_w1 * x\n", 866 | " wk2 = -wk / svm_w2\n", 867 | " return(wk2)\n", 868 | "\n", 869 | "y_svm = svm_line(xl)\n", 870 | "\n", 871 | "# 結果確認\n", 872 | "print(xl, y_svm)" 873 | ], 874 | "metadata": { 875 | "id": "bUozJn0AhaIv" 876 | }, 877 | "execution_count": null, 878 | "outputs": [] 879 | }, 880 | { 881 | "cell_type": "markdown", 882 | "source": [ 883 | "#### 散布図と決定直線の同時表示" 884 | ], 885 | "metadata": { 886 | "id": "jNFq0p0_ho6X" 887 | } 888 | }, 889 | { 890 | "cell_type": "code", 891 | "source": [ 892 | "# 散布図と決定直線の同時表示\n", 893 | "\n", 894 | "plt.figure(figsize=(6,6))\n", 895 | "\n", 896 | "# 散布図の表示\n", 897 | "plt.scatter(x_t0[:,1], x_t0[:,2], marker='x',c='b', s=50, label='0 (setosa)')\n", 898 | "plt.scatter(x_t1[:,1], x_t1[:,2], marker='o',c='k', s=50, label='1 (versicolor)')\n", 899 | "\n", 900 | "# 決定直線の表示\n", 901 | "plt.plot(xl, yl, linewidth=2, c='k', label='Hands On')\n", 902 | "# lr model\n", 903 | "plt.plot(xl, y_rl, linewidth=2, c='k', linestyle=\"--\", label='scikit LR')\n", 904 | "# svm\n", 905 | "plt.plot(xl, y_svm, linewidth=2, c='b', linestyle=\"-.\", label='scikit SVM')\n", 906 | "\n", 907 | "# グラフのキレイ化\n", 908 | "plt.title('テストデータの散布図と決定境界')\n", 909 | "plt.xlabel('がく片長')\n", 910 | "plt.ylabel('がく片幅')\n", 911 | "plt.legend()\n", 912 | "plt.grid()\n", 913 | "plt.show()" 914 | ], 915 | "metadata": { 916 | "id": "_MFPzXeahqYL" 917 | }, 918 | "execution_count": null, 919 | "outputs": [] 920 | }, 921 | { 922 | "cell_type": "markdown", 923 | "source": [ 924 | "### バージョン確認" 925 | ], 926 | "metadata": { 927 | "id": "YPtwY72cA_yX" 928 | } 929 | },
930 | { 931 | "cell_type": "code", 932 | "source": [ 933 | "!pip install watermark -qq\n", 934 | "%load_ext watermark\n", 935 | "%watermark --iversions" 936 | ], 937 | "metadata": { 938 | "id": "D_mwjZ-deVwq" 939 | }, 940 | "execution_count": null, 941 | "outputs": [] 942 | }, 943 | { 944 | "cell_type": "code", 945 | "source": [], 946 | "metadata": { 947 | "id": "aBSrdxzKGRfG" 948 | }, 949 | "execution_count": null, 950 | "outputs": [] 951 | } 952 | ] 953 | } -------------------------------------------------------------------------------- /notebooks/09_multi_classify.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "BFAZGXMPI1q9" 7 | }, 8 | "source": [ 9 | "### 9章 多値ロジスティック回帰" 10 | ] 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "metadata": { 15 | "id": "2UYadA_-KVUo" 16 | }, 17 | "source": [ 18 | "### 環境準備" 19 | ] 20 | }, 21 | { 22 | "cell_type": "markdown", 23 | "metadata": { 24 | "id": "-0UN5cm7KbqD" 25 | }, 26 | "source": [ 27 | "#### ライブラリ導入" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": null, 33 | "metadata": { 34 | "id": "37h9wcBWInXD" 35 | }, 36 | "outputs": [], 37 | "source": [ 38 | "# 日本語化ライブラリ導入\n", 39 | "!pip install japanize-matplotlib -q" 40 | ] 41 | }, 42 | { 43 | "cell_type": "markdown", 44 | "metadata": { 45 | "id": "AfWRRzVGKii9" 46 | }, 47 | "source": [ 48 | "#### ライブラリインポート" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": null, 54 | "metadata": { 55 | "id": "-JPZjeEYI7-3" 56 | }, 57 | "outputs": [], 58 | "source": [ 59 | "# ライブラリインポート\n", 60 | "import pandas as pd\n", 61 | "import numpy as np\n", 62 | "import matplotlib.pyplot as plt\n", 63 | "import japanize_matplotlib\n", 64 | "from IPython.display import display\n", 65 | "from sklearn.datasets import load_iris\n", 66 | "import warnings" 67 | ] 68 | }, 69 | { 70 | "cell_type": "markdown", 71 | "metadata": { 72 | "id": 
"4nFN1LA0Kvkx" 73 | }, 74 | "source": [ 75 | "#### 環境設定" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": null, 81 | "metadata": { 82 | "id": "ILcfAZSqKtcc" 83 | }, 84 | "outputs": [], 85 | "source": [ 86 | "# 環境設定\n", 87 | "np.set_printoptions(formatter={'float': '{:0.3f}'.format})\n", 88 | "pd.options.display.float_format = '{:.3f}'.format\n", 89 | "warnings.filterwarnings('ignore')" 90 | ] 91 | }, 92 | { 93 | "cell_type": "markdown", 94 | "metadata": { 95 | "id": "tjkMy5K2LAha" 96 | }, 97 | "source": [ 98 | "### データ読み込み" 99 | ] 100 | }, 101 | { 102 | "cell_type": "markdown", 103 | "metadata": { 104 | "id": "8efg5d8PLFFl" 105 | }, 106 | "source": [ 107 | "#### 読み込み関数" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": null, 113 | "metadata": { 114 | "id": "xkdTcFQ3K0HH" 115 | }, 116 | "outputs": [], 117 | "source": [ 118 | "# 読み込み関数\n", 119 | "def load_iris_dataset():\n", 120 | " iris = load_iris(as_frame=True)\n", 121 | " df = iris.data.copy()\n", 122 | " df.columns = ['がく片長', 'がく片幅', '花弁長', '花弁幅']\n", 123 | " df['品種'] = iris.target.map({0:'setosa', 1:'versicolor', 2:'virginica'})\n", 124 | " print(f\"データ読み込み完了 ({df.shape[0]}件, 特徴量4)\")\n", 125 | " return df, iris" 126 | ] 127 | }, 128 | { 129 | "cell_type": "markdown", 130 | "metadata": { 131 | "id": "LRCmGjVoLSya" 132 | }, 133 | "source": [ 134 | "#### 読み込み" 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": null, 140 | "metadata": { 141 | "id": "cJvhwzNJLNWk" 142 | }, 143 | "outputs": [], 144 | "source": [ 145 | "# 読み込み\n", 146 | "df, iris = load_iris_dataset()" 147 | ] 148 | }, 149 | { 150 | "cell_type": "markdown", 151 | "metadata": { 152 | "id": "c22OtPsqLY9N" 153 | }, 154 | "source": [ 155 | "#### 内容表示" 156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": null, 161 | "metadata": { 162 | "id": "RRKM-6yVLXuA" 163 | }, 164 | "outputs": [], 165 | "source": [ 166 | "# 内容表示\n", 167 | "\n", 168 | "# 先頭5行表示\n", 169 
| "display(df.head())" 170 | ] 171 | }, 172 | { 173 | "cell_type": "markdown", 174 | "metadata": { 175 | "id": "9dguBjfbL9bG" 176 | }, 177 | "source": [ 178 | "### データ加工" 179 | ] 180 | }, 181 | { 182 | "cell_type": "markdown", 183 | "metadata": { 184 | "id": "4JjCE9z7MNes" 185 | }, 186 | "source": [ 187 | "#### 入力データ絞り込み(2項目のみ)" 188 | ] 189 | }, 190 | { 191 | "cell_type": "code", 192 | "execution_count": null, 193 | "metadata": { 194 | "id": "NrHAS2SQLeNo" 195 | }, 196 | "outputs": [], 197 | "source": [ 198 | "# 入力データ絞り込み(2項目のみ)\n", 199 | "\n", 200 | "# クラス: すべて\n", 201 | "# 項目: がく片長と花弁長のみ\n", 202 | "x_data = df[['がく片長','花弁長']]\n", 203 | "y_data = iris.target.values\n", 204 | "\n", 205 | "# x_dataとy_dataのshape確認\n", 206 | "print(f\"x_data.shape: {x_data.shape}\")\n", 207 | "print(f\"y_data.shape: {y_data.shape}\")" 208 | ] 209 | }, 210 | { 211 | "cell_type": "markdown", 212 | "metadata": { 213 | "id": "iqERrOAuMwt5" 214 | }, 215 | "source": [ 216 | "#### ダミー変数追加" 217 | ] 218 | }, 219 | { 220 | "cell_type": "code", 221 | "execution_count": null, 222 | "metadata": { 223 | "id": "aDeW8iiLMhiX" 224 | }, 225 | "outputs": [], 226 | "source": [ 227 | "# ダミー変数追加\n", 228 | "x_data2 = np.insert(x_data, 0, 1.0, axis=1)\n", 229 | "\n", 230 | "# shape確認\n", 231 | "print(f\"x_data2.shape = {x_data2.shape}\")\n", 232 | "\n", 233 | "# 先頭5行のデータ確認\n", 234 | "display(x_data2[:5])" 235 | ] 236 | }, 237 | { 238 | "cell_type": "markdown", 239 | "metadata": { 240 | "id": "uOo06Wz-OhFG" 241 | }, 242 | "source": [ 243 | "#### 正解値のOne Hotベクトル化" 244 | ] 245 | }, 246 | { 247 | "cell_type": "code", 248 | "execution_count": null, 249 | "metadata": { 250 | "id": "VBnEuioRM4GL" 251 | }, 252 | "outputs": [], 253 | "source": [ 254 | "# 正解値のOne Hotベクトル化\n", 255 | "\n", 256 | "# OneHotEncoderのインポート\n", 257 | "from sklearn.preprocessing import OneHotEncoder\n", 258 | "\n", 259 | "# one hot encoderインスタンスの生成\n", 260 | "ohe = OneHotEncoder(sparse_output=False,categories='auto')\n", 261 | "\n", 262 | "# 
y_dataの行列化\n", 263 | "y_data_matrix = y_data.reshape(-1,1)\n", 264 | "\n", 265 | "# y_data_matrixのOne Hotベクトル化\n", 266 | "y_data_ohe = ohe.fit_transform(y_data_matrix)\n", 267 | "\n", 268 | "# 各変数のshape確認\n", 269 | "print('オリジナル', y_data.shape)\n", 270 | "print('2次元化', y_data_matrix.shape)\n", 271 | "print('One Hot Vector化後', y_data_ohe.shape)" 272 | ] 273 | }, 274 | { 275 | "cell_type": "markdown", 276 | "metadata": { 277 | "id": "1jPLlGKsPBA5" 278 | }, 279 | "source": [ 280 | "#### 訓練データとテストデータの分割" 281 | ] 282 | }, 283 | { 284 | "cell_type": "code", 285 | "execution_count": null, 286 | "metadata": { 287 | "id": "gDc4Y1STOu1t" 288 | }, 289 | "outputs": [], 290 | "source": [ 291 | "# 訓練データとテストデータへの分割\n", 292 | "\n", 293 | "# 分割用関数train_test_splitのインポート\n", 294 | "from sklearn.model_selection import train_test_split\n", 295 | "\n", 296 | "# データ分割の実施(x_data2, y_data, y_data_oheを同時に関数に渡す)\n", 297 | "x_train, x_test, y_train, y_test, \\\n", 298 | "y_train_ohe, y_test_ohe = train_test_split(\n", 299 | " x_data2, y_data, y_data_ohe,\n", 300 | " train_size=75, test_size=75, random_state=123)\n", 301 | "\n", 302 | "# 各変数のshape確認\n", 303 | "print('x_train', x_train.shape)\n", 304 | "print('x_test', x_test.shape)\n", 305 | "print('y_train', y_train.shape)\n", 306 | "print('y_test', y_test.shape)\n", 307 | "print('y_train_ohe', y_train_ohe.shape)\n", 308 | "print('y_test_ohe', y_test_ohe.shape)" 309 | ] 310 | }, 311 | { 312 | "cell_type": "markdown", 313 | "metadata": { 314 | "id": "FSFTndvbPihX" 315 | }, 316 | "source": [ 317 | "#### 訓練用データ確認" 318 | ] 319 | }, 320 | { 321 | "cell_type": "code", 322 | "execution_count": null, 323 | "metadata": { 324 | "id": "HRZmgca8PRGO" 325 | }, 326 | "outputs": [], 327 | "source": [ 328 | "# 訓練用データ確認\n", 329 | "\n", 330 | "print('x_train先頭5行')\n", 331 | "print(x_train[:5])\n", 332 | "print('y_train先頭5要素')\n", 333 | "print(y_train[:5])\n", 334 | "print('y_train_ohe先頭5行')\n", 335 | "print(y_train_ohe[:5])" 336 | ] 337 | }, 338 | { 339 |
"cell_type": "markdown", 340 | "metadata": { 341 | "id": "4vmNXNWKQWnh" 342 | }, 343 | "source": [ 344 | "#### 散布図表示(訓練データ)" 345 | ] 346 | }, 347 | { 348 | "cell_type": "code", 349 | "execution_count": null, 350 | "metadata": { 351 | "id": "qpZZihcwP1tZ" 352 | }, 353 | "outputs": [], 354 | "source": [ 355 | "# 散布図表示用に入力データを分割\n", 356 | "\n", 357 | "# 正解値によるデータ分割\n", 358 | "x_t0 = x_train[y_train == 0]\n", 359 | "x_t1 = x_train[y_train == 1]\n", 360 | "x_t2 = x_train[y_train == 2]\n", 361 | "\n", 362 | "# グラフのサイズ指定\n", 363 | "plt.figure(figsize=(6,6))\n", 364 | "\n", 365 | "# マーカを変えて散布図表示\n", 366 | "plt.scatter(x_t0[:,1], x_t0[:,2], marker='x', c='k', s=50, label='0 (setosa)')\n", 367 | "plt.scatter(x_t1[:,1], x_t1[:,2], marker='o', c='b', s=50, label='1 (versicolor)')\n", 368 | "plt.scatter(x_t2[:,1], x_t2[:,2], marker='+', c='k', s=50, label='2 (virginica)')\n", 369 | "\n", 370 | "# グラフのキレイ化\n", 371 | "plt.title('アイリスデータセットの散布図(がく片長vs花弁長)')\n", 372 | "plt.xlabel('がく片長')\n", 373 | "plt.ylabel('花弁長')\n", 374 | "plt.legend()\n", 375 | "plt.grid()\n", 376 | "plt.show()" 377 | ] 378 | }, 379 | { 380 | "cell_type": "markdown", 381 | "metadata": { 382 | "id": "nxUX4gd5SRbl" 383 | }, 384 | "source": [ 385 | "### 基本関数定義" 386 | ] 387 | }, 388 | { 389 | "cell_type": "markdown", 390 | "metadata": { 391 | "id": "gt_aKU0DSVNf" 392 | }, 393 | "source": [ 394 | "#### softmax関数" 395 | ] 396 | }, 397 | { 398 | "cell_type": "code", 399 | "execution_count": null, 400 | "metadata": { 401 | "id": "AQ8XI2oWQ-Y3" 402 | }, 403 | "outputs": [], 404 | "source": [ 405 | "# softmax関数\n", 406 | "def softmax(x):\n", 407 | " \"\"\"softmax関数(行ごとに確率へ変換。最大値を引いてオーバーフローを防ぐ)\"\"\"\n", 408 | " x_max = x.max(axis=1, keepdims=True)\n", 409 | " x = x - x_max\n", 410 | " w = np.exp(x)\n", 411 | " return w / w.sum(axis=1, keepdims=True)" 412 | ] 413 | }, 414 | { 415 | "cell_type": "markdown", 416 | "metadata": { 417 | "id": "x8B3BpvRS88D" 418 | }, 419 | "source": [ 420 | "#### 予測関数" 421 | ] 422 | }, 423 | { 424 | "cell_type":
"code", 425 | "execution_count": null, 426 | "metadata": { 427 | "id": "B_-NLlH6SpDY" 428 | }, 429 | "outputs": [], 430 | "source": [ 431 | "# 予測関数\n", 432 | "def pred(x, W):\n", 433 | " \"\"\"予測関数(確率を出力)\"\"\"\n", 434 | " return softmax(x @ W)" 435 | ] 436 | }, 437 | { 438 | "cell_type": "markdown", 439 | "metadata": { 440 | "id": "c4uPYI9xTLgP" 441 | }, 442 | "source": [ 443 | "#### 交差エントロピー関数" 444 | ] 445 | }, 446 | { 447 | "cell_type": "code", 448 | "execution_count": null, 449 | "metadata": { 450 | "id": "OatFwoigTF_-" 451 | }, 452 | "outputs": [], 453 | "source": [ 454 | "# 交差エントロピー関数\n", 455 | "def cross_entropy(yt, yp):\n", 456 | " \"\"\"交差エントロピー損失\"\"\"\n", 457 | "\n", 458 | " # 個別データごとに交差エントロピーを計算\n", 459 | " ce = -np.sum(yt * np.log(yp), axis=1)\n", 460 | "\n", 461 | " # 全データの平均を取り戻り値とする\n", 462 | " return np.mean(ce)" 463 | ] 464 | }, 465 | { 466 | "cell_type": "markdown", 467 | "metadata": { 468 | "id": "oFVM_3tLT6EM" 469 | }, 470 | "source": [ 471 | "#### クラス変換関数" 472 | ] 473 | }, 474 | { 475 | "cell_type": "code", 476 | "execution_count": null, 477 | "metadata": { 478 | "id": "4HTpjxokTorN" 479 | }, 480 | "outputs": [], 481 | "source": [ 482 | "# クラス変換関数\n", 483 | "def classify(yp_ohe):\n", 484 | " \"\"\"確率ベクトル→クラス(0 or 1 or 2)変換\"\"\"\n", 485 | "\n", 486 | " # 確率値ベクトルとargmax関数で0/1/2を判定\n", 487 | " return np.argmax(yp_ohe, axis=1)" 488 | ] 489 | }, 490 | { 491 | "cell_type": "markdown", 492 | "metadata": { 493 | "id": "phetP14UUizx" 494 | }, 495 | "source": [ 496 | "#### 評価関数" 497 | ] 498 | }, 499 | { 500 | "cell_type": "code", 501 | "execution_count": null, 502 | "metadata": { 503 | "id": "bsjfZIyAUcGb" 504 | }, 505 | "outputs": [], 506 | "source": [ 507 | "# 評価関数\n", 508 | "from sklearn.metrics import accuracy_score\n", 509 | "def evaluate(x_test, y_test, y_test_ohe, W):\n", 510 | " \"\"\"損失と精度を計算\"\"\"\n", 511 | "\n", 512 | " # 予測値の計算(確率値)\n", 513 | " yp_test_ohe = pred(x_test, W)\n", 514 | "\n", 515 | " # 予測クラス計算(0, 1, 2)\n", 516 | " yp_test = 
classify(yp_test_ohe)\n", 517 | "\n", 518 | " # 損失計算(確率値を利用)\n", 519 | " loss = cross_entropy(y_test_ohe, yp_test_ohe)\n", 520 | "\n", 521 | " # 精度計算(予測クラスを利用)\n", 522 | " score = accuracy_score(y_test, yp_test)\n", 523 | "\n", 524 | " # 損失と精度を戻す\n", 525 | " return loss, score" 526 | ] 527 | }, 528 | { 529 | "cell_type": "markdown", 530 | "metadata": { 531 | "id": "WAeAEVzwVpGC" 532 | }, 533 | "source": [ 534 | "### 学習" 535 | ] 536 | }, 537 | { 538 | "cell_type": "markdown", 539 | "metadata": { 540 | "id": "0rxmTSN3VuwJ" 541 | }, 542 | "source": [ 543 | "#### 学習関数" 544 | ] 545 | }, 546 | { 547 | "cell_type": "code", 548 | "execution_count": null, 549 | "metadata": { 550 | "id": "k0f2oyauVgfR" 551 | }, 552 | "outputs": [], 553 | "source": [ 554 | "# 学習関数\n", 555 | "\n", 556 | "def train_multi_logistic_regression(x, yt_ohe, x_test, \\\n", 557 | "y_test, y_test_ohe, alpha=0.01, iters=10000, his_unit=100):\n", 558 | " # M(データ件数)とD(入力データ要素数)の設定\n", 559 | " M, D = x.shape\n", 560 | " # 分類先クラス数\n", 561 | " N = yt_ohe.shape[1]\n", 562 | " # 重み行列初期化(全要素1を設定)\n", 563 | " W = np.ones((D, N))\n", 564 | " # 学習過程記録用\n", 565 | " history = np.zeros((0,3))\n", 566 | "\n", 567 | " # 繰り返し処理\n", 568 | " for k in range(iters):\n", 569 | " # 予測計算\n", 570 | " yp = pred(x, W)\n", 571 | " # 誤差計算(引数のyt_oheを使う)\n", 572 | " yd = yp - yt_ohe\n", 573 | " # 勾配計算\n", 574 | " grad = (x.T @ yd) / M\n", 575 | " # パラメータ修正\n", 576 | " W -= alpha * grad\n", 577 | "\n", 578 | " if k % his_unit == 0:\n", 579 | " loss, score = evaluate(x_test, y_test, y_test_ohe, W)\n", 580 | " history = np.vstack((history, np.array([k, loss, score])))\n", 581 | " print(f\"iter={k:5d} | loss={loss:.6f} | score={score:.6f}\")\n", 582 | " return W, history\n" 583 | ] 584 | }, 585 | { 586 | "cell_type": "markdown", 587 | "metadata": { 588 | "id": "9iksDyP8Xh1U" 589 | }, 590 | "source": [ 591 | "#### 学習" 592 | ] 593 | }, 594 | { 595 | "cell_type": "code", 596 | "execution_count": null, 597 | "metadata": { 598 | "id": "VDNTRukHXb7V" 599 | }, 
600 | "outputs": [], 601 | "source": [ 602 | "# 学習\n", 603 | "\n", 604 | "# 変数設定\n", 605 | "x, yt = x_train, y_train_ohe\n", 606 | "\n", 607 | "# 学習率と繰り返し回数の設定\n", 608 | "alpha = 0.01\n", 609 | "iters = 10000\n", 610 | "his_unit = 100\n", 611 | "\n", 612 | "# 繰り返し処理\n", 613 | "W, history = train_multi_logistic_regression(x, yt, \\\n", 614 | " x_test, y_test, y_test_ohe, alpha=alpha, iters=iters, his_unit=his_unit)\n" 615 | ] 616 | }, 617 | { 618 | "cell_type": "markdown", 619 | "metadata": { 620 | "id": "TC-qOioQa4r7" 621 | }, 622 | "source": [ 623 | "### 結果分析" 624 | ] 625 | }, 626 | { 627 | "cell_type": "markdown", 628 | "metadata": { 629 | "id": "EhDPfti3bCLI" 630 | }, 631 | "source": [ 632 | "#### 初期状態と最終状態の比較" 633 | ] 634 | }, 635 | { 636 | "cell_type": "code", 637 | "execution_count": null, 638 | "metadata": { 639 | "id": "j0IMShDEXyK5" 640 | }, 641 | "outputs": [], 642 | "source": [ 643 | "# 初期状態と最終状態の比較\n", 644 | "print(f\"初期状態: 損失={history[0,1]:.6f}, 精度={history[0,2]:.6f}\")\n", 645 | "print(f\"最終状態: 損失={history[-1,1]:.6f}, 精度={history[-1,2]:.6f}\")" 646 | ] 647 | }, 648 | { 649 | "cell_type": "markdown", 650 | "metadata": { 651 | "id": "tgfeKafLbJdM" 652 | }, 653 | "source": [ 654 | "#### 学習曲線(損失)" 655 | ] 656 | }, 657 | { 658 | "cell_type": "code", 659 | "execution_count": null, 660 | "metadata": { 661 | "id": "Re7FAZi4bGnt" 662 | }, 663 | "outputs": [], 664 | "source": [ 665 | "# 学習曲線(損失)\n", 666 | "plt.figure(figsize=(6,6))\n", 667 | "plt.plot(history[1:,0], history[1:,1], color='blue')\n", 668 | "plt.title('学習曲線(損失)', fontsize=14)\n", 669 | "plt.xlabel('繰り返し回数', fontsize=13)\n", 670 | "plt.ylabel('損失', fontsize=13)\n", 671 | "plt.grid(True)\n", 672 | "plt.show()" 673 | ] 674 | }, 675 | { 676 | "cell_type": "markdown", 677 | "metadata": { 678 | "id": "P-i170bLbng5" 679 | }, 680 | "source": [ 681 | "#### 学習曲線(精度)" 682 | ] 683 | }, 684 | { 685 | "cell_type": "code", 686 | "execution_count": null, 687 | "metadata": { 688 | "id": "Immy3pvcbmwe" 689 | }, 690 
| "outputs": [], 691 | "source": [ 692 | "# 学習曲線(精度)\n", 693 | "plt.figure(figsize=(6,6))\n", 694 | "plt.plot(history[1:,0], history[1:,2], color='black')\n", 695 | "plt.title('学習曲線(精度)', fontsize=14)\n", 696 | "plt.xlabel('繰り返し回数', fontsize=13)\n", 697 | "plt.ylabel('精度(Accuracy)', fontsize=13)\n", 698 | "plt.grid(True)\n", 699 | "plt.show()" 700 | ] 701 | }, 702 | { 703 | "cell_type": "markdown", 704 | "metadata": { 705 | "id": "ZVDktbELcEdi" 706 | }, 707 | "source": [ 708 | "#### 予測関数の3次元曲面表示" 709 | ] 710 | }, 711 | { 712 | "cell_type": "code", 713 | "execution_count": null, 714 | "metadata": { 715 | "id": "OECpLnNbbm1p" 716 | }, 717 | "outputs": [], 718 | "source": [ 719 | "# 予測関数の3次元曲面表示\n", 720 | "from mpl_toolkits.mplot3d import Axes3D\n", 721 | "x1 = np.linspace(4, 8.5, 100)\n", 722 | "x2 = np.linspace(0.5, 7.5, 100)\n", 723 | "xx1, xx2 = np.meshgrid(x1, x2)\n", 724 | "xxx = np.array([np.ones(xx1.ravel().shape),\n", 725 | " xx1.ravel(), xx2.ravel()]).T\n", 726 | "pp = pred(xxx, W)\n", 727 | "c0 = pp[:,0].reshape(xx1.shape)\n", 728 | "c1 = pp[:,1].reshape(xx1.shape)\n", 729 | "c2 = pp[:,2].reshape(xx1.shape)\n", 730 | "plt.figure(figsize=(8,8))\n", 731 | "ax = plt.subplot(1, 1, 1, projection='3d')\n", 732 | "ax.plot_surface(xx1, xx2, c0, color='lightblue',\n", 733 | " edgecolor='black', rstride=10, cstride=10, alpha=0.7)\n", 734 | "ax.plot_surface(xx1, xx2, c1, color='blue',\n", 735 | " edgecolor='black', rstride=10, cstride=10, alpha=0.7)\n", 736 | "ax.plot_surface(xx1, xx2, c2, color='lightgrey',\n", 737 | " edgecolor='black', rstride=10, cstride=10, alpha=0.7)\n", 738 | "ax.scatter(x_t0[:,1], x_t0[:,2], 1, s=50, alpha=1, marker='x', c='k')\n", 739 | "ax.scatter(x_t1[:,1], x_t1[:,2], 1, s=30, alpha=1, marker='o', c='k')\n", 740 | "ax.scatter(x_t2[:,1], x_t2[:,2], 1, s=50, alpha=1, marker='+', c='k')\n", 741 | "ax.set_xlim(4,8.5)\n", 742 | "ax.set_ylim(0.5,7.5)\n", 743 | "ax.view_init(elev=40, azim=70)" 744 | ] 745 | }, 746 | { 747 | "cell_type":
"markdown", 748 | "metadata": { 749 | "id": "YTvtog9scTfr" 750 | }, 751 | "source": [ 752 | "#### 詳細な精度評価" 753 | ] 754 | }, 755 | { 756 | "cell_type": "code", 757 | "execution_count": null, 758 | "metadata": { 759 | "id": "hH-JIxztblZc" 760 | }, 761 | "outputs": [], 762 | "source": [ 763 | "# 詳細な精度評価\n", 764 | "from sklearn.metrics import accuracy_score\n", 765 | "from sklearn.metrics import confusion_matrix\n", 766 | "from sklearn.metrics import classification_report\n", 767 | "\n", 768 | "# テストデータで予測値の計算\n", 769 | "yp_test_ohe = pred(x_test, W)\n", 770 | "yp_test = np.argmax(yp_test_ohe, axis=1)\n", 771 | "\n", 772 | "# 精度の計算\n", 773 | "from sklearn.metrics import accuracy_score\n", 774 | "score = accuracy_score(y_test, yp_test)\n", 775 | "print('accuracy: %f' % score)\n", 776 | "\n", 777 | "# 混同行列の表示\n", 778 | "from sklearn.metrics import confusion_matrix\n", 779 | "print(confusion_matrix(y_test, yp_test))\n", 780 | "print(classification_report(y_test, yp_test))" 781 | ] 782 | }, 783 | { 784 | "cell_type": "markdown", 785 | "metadata": { 786 | "id": "KeMAcG-pclCC" 787 | }, 788 | "source": [ 789 | "### 入力変数をオリジナルの4つに変更" 790 | ] 791 | }, 792 | { 793 | "cell_type": "markdown", 794 | "metadata": { 795 | "id": "dWqIW70vc30M" 796 | }, 797 | "source": [ 798 | "#### データ加工" 799 | ] 800 | }, 801 | { 802 | "cell_type": "code", 803 | "execution_count": null, 804 | "metadata": { 805 | "id": "S5HcFyHtcZaW" 806 | }, 807 | "outputs": [], 808 | "source": [ 809 | "# データ加工\n", 810 | "\n", 811 | "# x_data3: 4要素を持つNumPy配列\n", 812 | "x_data3 = df[['がく片長', 'がく片幅', '花弁長', '花弁幅']].values\n", 813 | "\n", 814 | "# x_data4: x_data3にダミー変数を追加\n", 815 | "x_data4 = np.insert(x_data3, 0, 1.0, axis=1)\n", 816 | "\n", 817 | "# 訓練データとテストデータへの分割\n", 818 | "x_train2, x_test2, y_train, y_test,\\\n", 819 | "y_train_ohe, y_test_ohe = train_test_split(\n", 820 | " x_data4, y_data, y_data_ohe,\n", 821 | " train_size=75, test_size=75, random_state=123)\n", 822 | "\n", 823 | "# 各変数のshape確認\n", 824 |
"print('x_train2', x_train2.shape)\n", 825 | "print('x_test2', x_test2.shape)\n", 826 | "\n", 827 | "# x_train2の内容確認\n", 828 | "print(x_train2[:5])" 829 | ] 830 | }, 831 | { 832 | "cell_type": "markdown", 833 | "metadata": { 834 | "id": "N4Eu9dbceN_8" 835 | }, 836 | "source": [ 837 | "#### 学習" 838 | ] 839 | }, 840 | { 841 | "cell_type": "code", 842 | "execution_count": null, 843 | "metadata": { 844 | "id": "KQ9b4hoediza" 845 | }, 846 | "outputs": [], 847 | "source": [ 848 | "# 学習\n", 849 | "\n", 850 | "# 変数設定\n", 851 | "x, yt = x_train2, y_train_ohe\n", 852 | "\n", 853 | "# 学習率と繰り返し回数の設定\n", 854 | "alpha = 0.01\n", 855 | "iters = 10000\n", 856 | "his_unit = 100\n", 857 | "\n", 858 | "# 繰り返し処理\n", 859 | "W, history = train_multi_logistic_regression(x, yt, \\\n", 860 | " x_test2, y_test, y_test_ohe, alpha=alpha, iters=iters, his_unit=his_unit)" 861 | ] 862 | }, 863 | { 864 | "cell_type": "markdown", 865 | "metadata": { 866 | "id": "jOOfchKwfMb6" 867 | }, 868 | "source": [ 869 | "#### 結果分析" 870 | ] 871 | }, 872 | { 873 | "cell_type": "code", 874 | "execution_count": null, 875 | "metadata": { 876 | "id": "ZGRmqSiyeria" 877 | }, 878 | "outputs": [], 879 | "source": [ 880 | "# 結果分析\n", 881 | "\n", 882 | "# 初期状態と最終状態の比較\n", 883 | "print(f\"初期状態: 損失={history[0,1]:.6f}, 精度={history[0,2]:.6f}\")\n", 884 | "print(f\"最終状態: 損失={history[-1,1]:.6f}, 精度={history[-1,2]:.6f}\")\n", 885 | "\n", 886 | "# 学習曲線(損失)\n", 887 | "plt.figure(figsize=(6,6))\n", 888 | "plt.plot(history[1:,0], history[1:,1], color='blue')\n", 889 | "plt.title('学習曲線(損失)', fontsize=14)\n", 890 | "plt.xlabel('繰り返し回数', fontsize=13)\n", 891 | "plt.ylabel('損失', fontsize=13)\n", 892 | "plt.grid(True)\n", 893 | "plt.show()\n", 894 | "\n", 895 | "# 学習曲線(精度)\n", 896 | "plt.figure(figsize=(6,6))\n", 897 | "plt.plot(history[1:,0], history[1:,2], color='black')\n", 898 | "plt.title('学習曲線(精度)', fontsize=14)\n", 899 | "plt.xlabel('繰り返し回数', fontsize=13)\n", 900 | "plt.ylabel('精度(Accuracy)', fontsize=13)\n", 901 |
"plt.grid(True)\n", 902 | "plt.show()" 903 | ] 904 | }, 905 | { 906 | "cell_type": "markdown", 907 | "metadata": { 908 | "id": "DYvIGhG4f7mq" 909 | }, 910 | "source": [ 911 | "### バージョン確認" 912 | ] 913 | }, 914 | { 915 | "cell_type": "code", 916 | "execution_count": null, 917 | "metadata": { 918 | "id": "pwKdDx85fiwv" 919 | }, 920 | "outputs": [], 921 | "source": [ 922 | "!pip install watermark -qq\n", 923 | "%load_ext watermark\n", 924 | "%watermark --iversions" 925 | ] 926 | }, 927 | { 928 | "cell_type": "code", 929 | "execution_count": null, 930 | "metadata": { 931 | "id": "JVMCPlR1gBdO" 932 | }, 933 | "outputs": [], 934 | "source": [] 935 | } 936 | ], 937 | "metadata": { 938 | "colab": { 939 | "provenance": [], 940 | "toc_visible": true 941 | }, 942 | "kernelspec": { 943 | "display_name": "Python 3", 944 | "name": "python3" 945 | }, 946 | "language_info": { 947 | "name": "python" 948 | } 949 | }, 950 | "nbformat": 4, 951 | "nbformat_minor": 0 952 | } 953 | --------------------------------------------------------------------------------