├── src
├── cover.png
└── teaser.gif
├── LICENSE
├── .gitignore
├── README.md
└── exps
├── toy_regression.py
└── notebooks
└── mnist_classification.ipynb
/src/cover.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cvlab-epfl/iter_unc/HEAD/src/cover.png
--------------------------------------------------------------------------------
/src/teaser.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cvlab-epfl/iter_unc/HEAD/src/teaser.gif
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 CVLAB @ EPFL
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
162 | **/.DS_Store
163 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Enabling Uncertainty Estimation in Iterative Neural Networks
2 |
3 | 
4 |
5 | [](https://arxiv.org/abs/2403.16732)
6 | [](https://www.python.org/downloads/release/python-31014/)
7 | [](https://pytorch.org/)
8 | [](https://github.com/cvlab-epfl/iter_unc/blob/main/LICENSE)
9 |
10 | ### [Project Page](https://www.norange.io/projects/unc_iter/) | [ICML Paper](https://arxiv.org/abs/2403.16732) | [Poster](https://icml.cc/virtual/2024/poster/34213)
11 |
12 | > **For a quick tryout of the method, check the Colab notebooks below!**
13 |
14 | > **Please also refer to the [Zigzag](https://github.com/cvlab-epfl/zigzag) paper, which served as the foundation for our work.**
15 |
16 | ## Abstract
17 |
18 | Turning pass-through network architectures into iterative ones, which use their own output as input, is a well-known approach for boosting performance. In this paper, we argue that such architectures offer an additional benefit: The convergence rate of their successive outputs is highly correlated with the accuracy of the value to which they converge. Thus, we can use the convergence rate as a useful proxy for uncertainty. This results in an approach to uncertainty estimation that provides state-of-the-art estimates at a much lower computational cost than techniques like Ensembles, and without requiring any modifications to the original iterative model. We demonstrate its practical value by embedding it in two application domains: road detection in aerial images and the estimation of aerodynamic properties of 2D and 3D shapes.
19 |
20 | ## TL;DR
21 |
22 |
23 |
24 |
25 |
26 |
27 | **Uncertainty in recursive models:** Such models use their initial predictions as inputs to produce subsequent predictions. We display the output of three consecutive iterations of a network trained to compute distance maps to road pixels. **(Top:)** All roads are clearly visible. The three maps are similar and the per-pixel variance is low. **(Bottom:)** The road in the red square is tree-covered. It is eventually detected correctly, but the variance is high.
28 |
29 | ## Experiments
30 |
31 | ### 1D Regression
32 |
33 | **Uncertainty Estimation for Regression:** The task is to regress $y$-axis values for $x$-axis data points drawn from the range $x \in [-1, 1.5]$, where the targets are generated by a third-degree polynomial with added Gaussian noise. The method displays higher uncertainty (iteration variance) for out-of-distribution inputs and lower uncertainty for in-distribution samples.
34 |
35 | [](https://colab.research.google.com/github/cvlab-epfl/iter_unc/blob/main/exps/notebooks/toy_regression.ipynb)
36 |
37 | ### MNIST Classification
38 |
39 | **MNIST vs FashionMNIST:** We train the networks on MNIST and compute the accuracy and calibration metrics (rAULC). We then use the uncertainty measure they produce to classify images from the test sets of MNIST and FashionMNIST as being within the MNIST distribution or not to compute the OOD metrics, ROC- and PR-AUCs. We use a standard architecture with several convolution and pooling layers, followed by fully connected layers with LeakyReLU activations.
40 |
41 | [](https://colab.research.google.com/github/cvlab-epfl/iter_unc/blob/main/exps/notebooks/mnist_classification.ipynb)
42 |
43 | ### Delineation Experiments (TBD)
44 |
45 | In these experiments, we focus on the delineation task, particularly for road detection in aerial imagery. This task requires precise identification and outlining of narrow, intricate features within varied image datasets. Our experiments utilized the U-Net architecture, known for its efficacy in image segmentation. We employed it to classify pixels within the images as belonging to the target structures (roads) or not, producing a binary map as the final output. The networks were trained and tested on two public datasets, "RoadTracer" and "Massachusetts," which encompass a diverse range of urban and rural landscapes, thereby providing a comprehensive benchmark for our methods.
46 |
47 | ## Citation
48 |
49 | If you find this code useful, please consider citing our paper:
50 |
51 | > Durasov, Nikita, et al. "Enabling Uncertainty Estimation in Iterative Neural Networks", ICML, 2024.
52 |
53 | ```bibtex
54 | @inproceedings{
55 | durasov2024enabling,
56 | title={Enabling Uncertainty Estimation in Iterative Neural Networks},
57 | author={Nikita Durasov and Doruk Oner and Jonathan Donier and Hieu Le and Pascal Fua},
58 | booktitle={Forty-first International Conference on Machine Learning},
59 | year={2024}
60 | }
61 | ```
62 |
63 | > Durasov, Nikita, et al. "ZigZag: Universal Sampling-free Uncertainty Estimation Through Two-Step Inference", TMLR, 2024.
64 |
65 | ```bibtex
66 | @article{durasov2024zigzag,
67 | title = {ZigZag: Universal Sampling-free Uncertainty Estimation Through Two-Step Inference},
68 | author = {Nikita Durasov and Nik Dorndorf and Hieu Le and Pascal Fua},
69 | journal = {Transactions on Machine Learning Research},
70 | issn = {2835-8856},
71 | year = {2024}
72 | }
73 | ```
74 |
--------------------------------------------------------------------------------
/exps/toy_regression.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """simple_regression.ipynb
3 |
4 | Automatically generated by Colab.
5 |
6 | Original file is located at
7 | https://colab.research.google.com/drive/1687Kj81qQYGJ3BBoWiVYbtfHYrlm0DA6
8 |
9 | # Training Dataset
10 | """
11 |
12 | import numpy as np
13 |
14 | import matplotlib.cm as cm # Import cm for colormap
15 | import matplotlib.pyplot as plt
16 |
# Fix every random number generator the script touches (torch for weight
# initialisation, numpy for data sampling, stdlib `random`) so runs are
# reproducible.
seed = 42

import torch
import numpy as np
import random

torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
28 | ###### Generate Data ######
def polynom(X):
    """Ground-truth cubic used to generate the toy data.

    Evaluates ``y = x - 5 x**2 + 1.5 x**3`` element-wise, so it accepts
    Python scalars as well as numpy arrays.
    """
    return X - 5 * X**2 + 1.5 * X**3


def generate(N=1000, x_low=-1, x_high=1.5):
    """Sample ``N`` noisy points from :func:`polynom`.

    Inputs are drawn uniformly from ``[x_low, x_high)`` with numpy's global
    RNG; Gaussian noise with standard deviation 0.2 (``randn / 5``) is then
    added to the targets.

    Args:
        N: number of samples.
        x_low: lower bound of the input range (the default matches the
            in-distribution region shaded in the plots).
        x_high: upper bound of the input range.

    Returns:
        Tuple ``(X, Y)`` of float32 tensors, each of shape ``(N, 1)``.
    """
    X = np.random.uniform(x_low, x_high, N)
    # Reuse polynom() so the target curve is defined in exactly one place.
    Y = polynom(X)
    Y += np.random.randn(Y.shape[0]) / 5
    return torch.tensor(X).float()[:, None], torch.tensor(Y).float()[:, None]
37 |
# Draw 100 training points, then apply a fixed affine rescaling to the
# targets so they sit in a small range around zero.
X_data, Y_data = generate(100)
Y_data = (Y_data + 3.5) / 15 - 0.2
42 |
###### Data Visualization ######
fig, ax = plt.subplots(1, 1, figsize=(10, 6))
X_min, X_max = -5, 5
Y_min, Y_max = -0.9, 0.7

ax.scatter(X_data, Y_data, color="seagreen", label="data")

# Shade the training input range [-1, 1.5] green and everything outside it
# red, so out-of-distribution inputs stand out.
ax.fill_between([-1, 1.5], Y_min, Y_max, color='seagreen', alpha=0.1, label="In-distribution Region")
ax.fill_between([1.5, X_max], Y_min, Y_max, color='red', alpha=0.1, label="Out-of-distribution Region")
ax.fill_between([X_min, -1], Y_min, Y_max, color='red', alpha=0.1)

ax.legend(fontsize=15, ncol=2, loc='upper left', framealpha=0.7)
# A bare `ax.grid()` *toggles* the grid; request it explicitly, once.
ax.grid(True)
ax.set_xlim(X_min, X_max)
ax.set_ylim(Y_min, Y_max)
y_ticks = np.linspace(Y_min, Y_max, 8).round(2)
ax.set_yticks(y_ticks)
ax.set_yticklabels(y_ticks, fontsize=15)
# range() excludes its stop value; add 1 so X_max also gets a tick.
x_ticks = range(X_min, X_max + 1)
ax.set_xticks(x_ticks)
ax.set_xticklabels(x_ticks, fontsize=15)
64 |
65 | ###### Model Architecture ######
66 |
67 | import torch
68 | import numpy as np
69 | import matplotlib.pyplot as plt
70 | import torch.optim as optim
71 |
# Before any prediction exists, the model's second input is the constant
# -A (= 1); see `Net.forward` and `step`.
A = -1


class Net(torch.nn.Module):
    """Two-layer MLP mapping ``(x, y_prev)`` to a new estimate of ``y``.

    Feeding its own output back in as ``y_prev`` makes the model iterative;
    the spread of successive outputs is used as an uncertainty signal.
    """

    def __init__(self, hidden_layer=100):
        super().__init__()

        self.hidden_layer = hidden_layer
        self.fc1 = torch.nn.Linear(2, hidden_layer)
        self.fc2 = torch.nn.Linear(hidden_layer, 1)
        self.activation = torch.nn.LeakyReLU()

    def forward(self, x, y=None):
        """Predict ``y`` from ``x`` and the previous estimate ``y``.

        Args:
            x: inputs. They are concatenated column-wise with ``y`` and fed to
                ``Linear(2, ...)``, so both must be ``(batch, 1)``.
            y: previous prediction of shape ``(batch, 1)``. When ``None``,
                the constant ``-A`` is used as the first-iteration input.

        Returns:
            Tensor of shape ``(batch, 1)``.
        """
        if y is None:
            # new_full inherits x's device and dtype, so the placeholder also
            # works when the model and data live on a GPU.
            y = x.new_full((x.shape[0], 1), -A)
        hidden = self.activation(self.fc1(torch.cat([x, y], dim=1)))
        return self.fc2(hidden)


###### Optimization Step ######
I = 3  # number of refinement iterations after the initial prediction


def step(epoch, model, optimizer, criterion, X_data, Y_data, l_gradient_penalty):
    """Run one full-batch optimization step of the iterative model.

    The loss sums ``criterion`` over the initial prediction (made from the
    constant ``-A`` input) and over each of the ``I`` refinements that feed
    the previous output back in, so every iterate is supervised.

    Args:
        epoch: current epoch index (unused; kept for interface compatibility).
        model: the network being trained.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss function called as ``criterion(pred, target)``.
        X_data: training inputs.
        Y_data: training targets.
        l_gradient_penalty: unused; kept for interface compatibility.

    Returns:
        The total loss as a Python float.
    """
    model.train()
    optimizer.zero_grad()

    # Always go through `model`: the caller decides which network is trained.
    Y_t = model(X_data, -A * torch.ones_like(Y_data))
    loss = criterion(Y_t, Y_data)
    for _ in range(I):
        Y_t = model(X_data, Y_t)
        loss += criterion(Y_t, Y_data)

    loss.backward()
    optimizer.step()

    return loss.item()
115 |
# Full-batch training: every epoch is a single `step` over all of X_data.
epochs = 150

network = Net(hidden_layer=128)
criterion = torch.nn.MSELoss()
optimizer = optim.Adam(network.parameters(), lr=0.05)

###### Training ######
for epoch in range(epochs):
    network.train()
    step(
        epoch,
        network,
        optimizer,
        criterion,
        X_data,
        Y_data,
        l_gradient_penalty=3e-4,
    )
129 |
###### Visualize Final Result ######
network.eval()

X_min, X_max = -5, 5
# 100 evenly spaced test inputs covering in- and out-of-distribution x.
Xt = torch.linspace(X_min, X_max, 100)[:, None]

# Iterate the trained model: start from the default second input, then feed
# each detached prediction back in I times, keeping every iterate.
Y_1 = network(Xt)
Y_t = Y_1.detach()
preds = [Y_t]
for _ in range(I):
    Y_final = Y_t = network(Xt, Y_t).detach()
    preds.append(Y_final)

# Shape (I + 1, 100, 1): one set of predictions per iterate.
preds = torch.stack(preds)

# Per-x standard deviation across iterates, used as the uncertainty estimate.
error = preds.std(dim=0).detach()
149 |
# Create figure and subplots: left = iterates, right = their spread.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 6))

Y_min, Y_max = -0.9, 0.7
# range() excludes its stop value; add 1 so X_max also gets a tick.
x_ticks = range(X_min, X_max + 1)

ax = ax1
ax.scatter(X_data, Y_data, color="seagreen", label="data")

# One colour per iterate from the blue-to-green 'winter' colormap.
# plt.get_cmap is used because matplotlib.cm.get_cmap was removed in 3.9.
cmap = plt.get_cmap('winter', len(preds))
for i, p in enumerate(preds):
    ax.plot(Xt, p, c=cmap(i), label=f"$y_{i}$")

ax.fill_between([-1, 1.5], Y_min, Y_max, color='seagreen', alpha=0.1, label="In-distribution Region")
ax.fill_between([1.5, X_max], Y_min, Y_max, color='red', alpha=0.1, label="Out-of-distribution Region")
ax.fill_between([X_min, -1], Y_min, Y_max, color='red', alpha=0.1)
# Envelope between the smallest and largest iterate at each x.
# Raw string: "\s" is an invalid escape sequence in a normal literal.
ax.fill_between(
    Xt[:, 0],
    preds.min(dim=0).values[:, 0],
    preds.max(dim=0).values[:, 0],
    alpha=0.4,
    color="skyblue",
    label=r"$\sigma_{pred}$",
)

ax.legend(fontsize=15, ncol=2, loc='upper left', framealpha=0.7)
# A bare `ax.grid()` *toggles* the grid; request it explicitly, once.
ax.grid(True)
ax.set_xlim(X_min, X_max)
ax.set_ylim(Y_min, Y_max)
y_ticks = np.linspace(Y_min, Y_max, 8).round(2)
ax.set_yticks(y_ticks)
ax.set_yticklabels(y_ticks, fontsize=15)
ax.set_xticks(x_ticks)
ax.set_xticklabels(x_ticks, fontsize=15)

# Right panel: per-x spread of the iterates (`error` is a standard deviation).
ax = ax2
ax.set_title("Iteration Variance")
# Draw on ax2 explicitly instead of relying on pyplot's implicit current axes.
ax.plot(Xt[:, 0], error, c="dodgerblue", label="Iteration Variance")
ax.fill_between([-1, 1.5], 0, 1, color='seagreen', alpha=0.1, label="In-distribution Region")
ax.fill_between([1.5, X_max], 0, 1, color='red', alpha=0.1, label="Out-of-distribution Region")
ax.fill_between([X_min, -1], 0, 1, color='red', alpha=0.1)
ax.legend(fontsize=15, loc='upper left', framealpha=0.7)
ax.grid(True)
ax.set_xlim(X_min, X_max)
ax.set_ylim(0, 1)
ax.set_xticks(x_ticks)
ax.set_xticklabels(x_ticks, fontsize=15)

fig.suptitle(f"Epoch #{epoch}", fontsize=16)

plt.tight_layout()
plt.show()
196 |
197 |
--------------------------------------------------------------------------------
/exps/notebooks/mnist_classification.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "provenance": [],
7 | "gpuType": "T4"
8 | },
9 | "kernelspec": {
10 | "name": "python3",
11 | "display_name": "Python 3"
12 | },
13 | "language_info": {
14 | "name": "python"
15 | },
16 | "accelerator": "GPU"
17 | },
18 | "cells": [
19 | {
20 | "cell_type": "markdown",
21 | "source": [
22 | "# MNIST Classification with Iterative Uncertainty\n",
23 | "\n",
24 | "In this notebook, we will implement a simple neural network to classify handwritten digits from the MNIST dataset using PyTorch. On top of it, we will apply the iterative uncertainty approach and evaluate its uncertainty quality in terms of out-of-distribution detection."
25 | ],
26 | "metadata": {
27 | "id": "zhr7Isu5wkQj"
28 | }
29 | },
30 | {
31 | "cell_type": "code",
32 | "source": [
33 | "import torch\n",
34 | "import torchvision\n",
35 | "import numpy as np\n",
36 | "import random\n",
37 | "\n",
38 | "n_epochs = 3 # number of epochs for training\n",
39 | "batch_size_train = 64 # batch size for training\n",
40 | "batch_size_test = 1000 # batch size for testing\n",
41 | "learning_rate = 0.01 # learning rate for Adam\n",
42 | "momentum = 0.5 # optimizer momentum\n",
43 | "log_interval = 10 # logging interval for metrics\n",
44 | "device = \"cuda\"\n",
45 | "\n",
46 | "# fixing random seeds\n",
47 | "random_seed = 6\n",
48 | "torch.backends.cudnn.enabled = False\n",
49 | "torch.manual_seed(random_seed)\n",
50 | "np.random.seed(random_seed)\n",
51 | "random.seed(random_seed)"
52 | ],
53 | "metadata": {
54 | "id": "N17oLGkpHftI"
55 | },
56 | "execution_count": 1,
57 | "outputs": []
58 | },
59 | {
60 | "cell_type": "markdown",
61 | "source": [
62 | "# Train / Test data loading\n",
63 | "\n",
64 | "The MNIST dataset is a collection of handwritten digits commonly used for training and testing in the field of machine learning. It contains 70,000 grayscale images of handwritten digits, split into 60,000 training images and 10,000 test images. Each image is 28x28 pixels in size and represents a digit from 0 to 9."
65 | ],
66 | "metadata": {
67 | "id": "lx0CY2jPw_Fn"
68 | }
69 | },
70 | {
71 | "cell_type": "code",
72 | "source": [
73 | "train_loader = torch.utils.data.DataLoader(\n",
74 | " torchvision.datasets.MNIST('./MNIST/', train=True, download=True,\n",
75 | " transform=torchvision.transforms.Compose([\n",
76 | " torchvision.transforms.ToTensor(),\n",
77 | " torchvision.transforms.Normalize(\n",
78 | " (0.1307,), (0.3081,))\n",
79 | " ])),\n",
80 | " batch_size=batch_size_train, shuffle=True)\n",
81 | "\n",
82 | "test_loader = torch.utils.data.DataLoader(\n",
83 | " torchvision.datasets.MNIST('./MNIST/', train=False, download=True,\n",
84 | " transform=torchvision.transforms.Compose([\n",
85 | " torchvision.transforms.ToTensor(),\n",
86 | " torchvision.transforms.Normalize(\n",
87 | " (0.1307,), (0.3081,))\n",
88 | " ])),\n",
89 | " batch_size=batch_size_test, shuffle=True)"
90 | ],
91 | "metadata": {
92 | "id": "wIp68p_WHlVt"
93 | },
94 | "execution_count": 2,
95 | "outputs": []
96 | },
97 | {
98 | "cell_type": "code",
99 | "source": [
100 | "import matplotlib.pyplot as plt\n",
101 | "\n",
102 | "images, targets = next(iter(train_loader))\n",
103 | "\n",
104 | "# Number of images you want to display\n",
105 | "num_images = 10\n",
106 | "\n",
107 | "# Create a figure and a row of subplots\n",
108 | "fig, axes = plt.subplots(1, num_images, figsize=(15, 3))\n",
109 | "\n",
110 | "# Plot each image on a separate subplot\n",
111 | "for i in range(num_images):\n",
112 | " axes[i].imshow(images[i, 0], cmap='gray')\n",
113 | " axes[i].axis('off') # Hide axis\n",
114 | "\n",
115 | "plt.show()"
116 | ],
117 | "metadata": {
118 | "id": "90b10PHvxE7i",
119 | "outputId": "955d5cde-8ded-41b3-bc74-d979fc1b7e4e",
120 | "colab": {
121 | "base_uri": "https://localhost:8080/",
122 | "height": 155
123 | }
124 | },
125 | "execution_count": 3,
126 | "outputs": [
127 | {
128 | "output_type": "display_data",
129 | "data": {
130 | "text/plain": [
131 | ""
132 | ],
133 | "image/png": "\n"
134 | },
135 | "metadata": {}
136 | }
137 | ]
138 | },
139 | {
140 | "cell_type": "markdown",
141 | "source": [
142 | "## NN Architecture (Iterative schema)\n",
143 | "\n",
144 | "We define a simple neural network model with several Conv2D layers followed by pooling layers and some fully connected layers. The input layer takes 28x28 MNIST images. The hidden layers have (320, 50) and (50, 50) units respectively, and the output layer has 10 units (one for each digit).\n",
145 | "\n",
146 | "The only difference from the common architecture is that the forward function, in addition to taking an input image, can also take the output from the previous step. Therefore, our model can process inputs iteratively.\n"
147 | ],
148 | "metadata": {
149 | "id": "wpdY17tRxRTG"
150 | }
151 | },
152 | {
153 | "cell_type": "code",
154 | "source": [
155 | "import torch.nn as nn\n",
156 | "import torch.nn.functional as F\n",
157 | "import torch.optim as optim\n",
158 | "\n",
159 | "class Net(nn.Module):\n",
160 | " def __init__(self):\n",
161 | " super(Net, self).__init__()\n",
162 | " self.conv1 = nn.Conv2d(1, 10, kernel_size=5)\n",
163 | " self.conv2 = nn.Conv2d(10, 20, kernel_size=5)\n",
164 | " self.fc1 = nn.Linear(320, 50)\n",
165 | " self.fc2 = nn.Linear(50, 50)\n",
166 | " self.fc3 = nn.Linear(50, 10)\n",
167 | " self.activation = nn.ELU()\n",
168 | "\n",
169 | " self.cf = nn.Linear(10, 50)\n",
170 | "\n",
171 | " def forward(self, x, pred_prob=None):\n",
172 | " x = self.activation(F.max_pool2d(self.conv1(x), 2))\n",
173 | " x = self.activation(F.max_pool2d(self.conv2(x), 2))\n",
174 | " x = x.view(-1, 320)\n",
175 | " x = self.activation(self.fc1(x))\n",
176 | "\n",
177 | " # iterative part, if \"pred_prob\"\n",
178 | " # is not None, we use it for inference\n",
179 | " if pred_prob is not None:\n",
180 | " pred_prob = torch.nn.functional.softmax(pred_prob, 1)\n",
181 | " x += self.activation(self.cf(pred_prob))\n",
182 | "\n",
183 | " x = self.activation(self.fc2(x))\n",
184 | " x = self.fc3(x)\n",
185 | " return F.log_softmax(x)"
186 | ],
187 | "metadata": {
188 | "id": "t3CXPKqeHlzW"
189 | },
190 | "execution_count": 4,
191 | "outputs": []
192 | },
193 | {
194 | "cell_type": "markdown",
195 | "source": [
196 | "# Model Training"
197 | ],
198 | "metadata": {
199 | "id": "sGGLrajiykjK"
200 | }
201 | },
202 | {
203 | "cell_type": "code",
204 | "source": [
205 | "network = Net().to(device)\n",
206 | "optimizer = optim.Adam(network.parameters(), lr=learning_rate)\n",
207 | "\n",
208 | "train_losses = []\n",
209 | "train_counter = []\n",
210 | "test_losses = []\n",
211 | "test_counter = [i*len(train_loader.dataset) for i in range(n_epochs + 1)]"
212 | ],
213 | "metadata": {
214 | "id": "Av8yDtCaHwE1"
215 | },
216 | "execution_count": 5,
217 | "outputs": []
218 | },
219 | {
220 | "cell_type": "code",
221 | "source": [
222 | "# helper function for training inference\n",
223 | "def inference_train(x, N = 3):\n",
224 | " preds = []\n",
225 | " pred = None\n",
226 | " for _ in range(N):\n",
227 | " pred = network(x, pred)\n",
228 | " preds.append(pred)\n",
229 | " return torch.cat(preds)\n",
230 | "\n",
231 | "# helper function for testing inference\n",
232 | "def inference_test(x, N = 3):\n",
233 | " preds = []\n",
234 | " pred = None\n",
235 | " for _ in range(N):\n",
236 | " pred = torch.nn.functional.softmax(network(x, pred), 1)\n",
237 | " preds.append(pred)\n",
238 | "\n",
239 | " # similar to ensembles, final predictions\n",
240 | " # is an average of predictions\n",
241 | " return sum(preds) / N"
242 | ],
243 | "metadata": {
244 | "id": "xdh7CeqpH8Rs"
245 | },
246 | "execution_count": 6,
247 | "outputs": []
248 | },
249 | {
250 | "cell_type": "code",
251 | "source": [
252 | "def train(epoch):\n",
253 | " network.train()\n",
254 | " for batch_idx, (data, target) in enumerate(train_loader):\n",
255 | "\n",
256 | " data = data.to(device)\n",
257 | " target = target.to(device)\n",
258 | "\n",
259 | " optimizer.zero_grad()\n",
260 | " output = inference_train(data)\n",
261 | "\n",
262 | " # For further details, see Eq. 4 in https://arxiv.org/pdf/2403.16732\n",
263 | " # The first input with an additional \"blank\" channel, the first term in Eq. 1\n",
264 | " loss = F.nll_loss(output, target.tile(3))\n",
265 | "\n",
266 | " loss.backward()\n",
267 | " optimizer.step()\n",
268 | " if batch_idx % log_interval == 0:\n",
269 | " print('Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.format(\n",
270 | " epoch, batch_idx * len(data), len(train_loader.dataset),\n",
271 | " 100. * batch_idx / len(train_loader), loss.item()))\n",
272 | " train_losses.append(loss.item())\n",
273 | " train_counter.append(\n",
274 | " (batch_idx*64) + ((epoch-1)*len(train_loader.dataset)))\n",
275 | "\n",
276 | "def test():\n",
277 | " network.eval()\n",
278 | " test_loss = 0\n",
279 | " correct = 0\n",
280 | " with torch.no_grad():\n",
281 | " for data, target in test_loader:\n",
282 | " data = data.to(device)\n",
283 | " target = target.to(device)\n",
284 | " output = inference_test(data)\n",
285 | " test_loss += F.nll_loss(output, target, size_average=False).item()\n",
286 | " pred = output.data.max(1, keepdim=True)[1]\n",
287 | " correct += pred.eq(target.data.view_as(pred)).sum()\n",
288 | " test_loss /= len(test_loader.dataset)\n",
289 | " test_losses.append(test_loss)\n",
290 | " print('\\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\\n'.format(\n",
291 | " test_loss, correct, len(test_loader.dataset),\n",
292 | " 100. * correct / len(test_loader.dataset)))"
293 | ],
294 | "metadata": {
295 | "id": "l8J7gph4IBK-"
296 | },
297 | "execution_count": 7,
298 | "outputs": []
299 | },
300 | {
301 | "cell_type": "code",
302 | "source": [
303 | "test()\n",
304 | "for epoch in range(1, n_epochs + 1):\n",
305 | " train(epoch)\n",
306 | " test()\n",
307 | "\n",
308 | "for param_group in optimizer.param_groups:\n",
309 | " param_group['lr'] = 1e-3\n",
310 | "\n",
311 | "for _ in range(2):\n",
312 | " train(epoch)\n",
313 | " test()"
314 | ],
315 | "metadata": {
316 | "colab": {
317 | "base_uri": "https://localhost:8080/"
318 | },
319 | "id": "SQwPqCUtIELo",
320 | "outputId": "219030a6-b56a-4031-d1e0-9662e267e33f"
321 | },
322 | "execution_count": 8,
323 | "outputs": [
324 | {
325 | "output_type": "stream",
326 | "name": "stderr",
327 | "text": [
328 | ":31: UserWarning: Implicit dimension choice for log_softmax has been deprecated. Change the call to include dim=X as an argument.\n",
329 | " return F.log_softmax(x)\n",
330 | "/usr/local/lib/python3.10/dist-packages/torch/nn/_reduction.py:42: UserWarning: size_average and reduce args will be deprecated, please use reduction='sum' instead.\n",
331 | " warnings.warn(warning.format(ret))\n"
332 | ]
333 | },
334 | {
335 | "output_type": "stream",
336 | "name": "stdout",
337 | "text": [
338 | "\n",
339 | "Test set: Avg. loss: -0.1003, Accuracy: 1781/10000 (18%)\n",
340 | "\n",
341 | "Train Epoch: 1 [0/60000 (0%)]\tLoss: 2.312376\n",
342 | "Train Epoch: 1 [640/60000 (1%)]\tLoss: 1.144794\n",
343 | "Train Epoch: 1 [1280/60000 (2%)]\tLoss: 0.467108\n",
344 | "Train Epoch: 1 [1920/60000 (3%)]\tLoss: 0.602299\n",
345 | "Train Epoch: 1 [2560/60000 (4%)]\tLoss: 0.497253\n",
346 | "Train Epoch: 1 [3200/60000 (5%)]\tLoss: 0.241655\n",
347 | "Train Epoch: 1 [3840/60000 (6%)]\tLoss: 0.283207\n",
348 | "Train Epoch: 1 [4480/60000 (7%)]\tLoss: 0.301319\n",
349 | "Train Epoch: 1 [5120/60000 (9%)]\tLoss: 0.293853\n",
350 | "Train Epoch: 1 [5760/60000 (10%)]\tLoss: 0.265691\n",
351 | "Train Epoch: 1 [6400/60000 (11%)]\tLoss: 0.175180\n",
352 | "Train Epoch: 1 [7040/60000 (12%)]\tLoss: 0.286978\n",
353 | "Train Epoch: 1 [7680/60000 (13%)]\tLoss: 0.053208\n",
354 | "Train Epoch: 1 [8320/60000 (14%)]\tLoss: 0.179960\n",
355 | "Train Epoch: 1 [8960/60000 (15%)]\tLoss: 0.142399\n",
356 | "Train Epoch: 1 [9600/60000 (16%)]\tLoss: 0.080546\n",
357 | "Train Epoch: 1 [10240/60000 (17%)]\tLoss: 0.230218\n",
358 | "Train Epoch: 1 [10880/60000 (18%)]\tLoss: 0.307157\n",
359 | "Train Epoch: 1 [11520/60000 (19%)]\tLoss: 0.316138\n",
360 | "Train Epoch: 1 [12160/60000 (20%)]\tLoss: 0.254721\n",
361 | "Train Epoch: 1 [12800/60000 (21%)]\tLoss: 0.171003\n",
362 | "Train Epoch: 1 [13440/60000 (22%)]\tLoss: 0.273300\n",
363 | "Train Epoch: 1 [14080/60000 (23%)]\tLoss: 0.064190\n",
364 | "Train Epoch: 1 [14720/60000 (25%)]\tLoss: 0.292505\n",
365 | "Train Epoch: 1 [15360/60000 (26%)]\tLoss: 0.099401\n",
366 | "Train Epoch: 1 [16000/60000 (27%)]\tLoss: 0.321094\n",
367 | "Train Epoch: 1 [16640/60000 (28%)]\tLoss: 0.156088\n",
368 | "Train Epoch: 1 [17280/60000 (29%)]\tLoss: 0.551294\n",
369 | "Train Epoch: 1 [17920/60000 (30%)]\tLoss: 0.109925\n",
370 | "Train Epoch: 1 [18560/60000 (31%)]\tLoss: 0.060247\n",
371 | "Train Epoch: 1 [19200/60000 (32%)]\tLoss: 0.017538\n",
372 | "Train Epoch: 1 [19840/60000 (33%)]\tLoss: 0.154289\n",
373 | "Train Epoch: 1 [20480/60000 (34%)]\tLoss: 0.164849\n",
374 | "Train Epoch: 1 [21120/60000 (35%)]\tLoss: 0.104086\n",
375 | "Train Epoch: 1 [21760/60000 (36%)]\tLoss: 0.261193\n",
376 | "Train Epoch: 1 [22400/60000 (37%)]\tLoss: 0.080790\n",
377 | "Train Epoch: 1 [23040/60000 (38%)]\tLoss: 0.309249\n",
378 | "Train Epoch: 1 [23680/60000 (39%)]\tLoss: 0.038709\n",
379 | "Train Epoch: 1 [24320/60000 (41%)]\tLoss: 0.208796\n",
380 | "Train Epoch: 1 [24960/60000 (42%)]\tLoss: 0.018224\n",
381 | "Train Epoch: 1 [25600/60000 (43%)]\tLoss: 0.165220\n",
382 | "Train Epoch: 1 [26240/60000 (44%)]\tLoss: 0.111875\n",
383 | "Train Epoch: 1 [26880/60000 (45%)]\tLoss: 0.081600\n",
384 | "Train Epoch: 1 [27520/60000 (46%)]\tLoss: 0.040166\n",
385 | "Train Epoch: 1 [28160/60000 (47%)]\tLoss: 0.168901\n",
386 | "Train Epoch: 1 [28800/60000 (48%)]\tLoss: 0.041811\n",
387 | "Train Epoch: 1 [29440/60000 (49%)]\tLoss: 0.152959\n",
388 | "Train Epoch: 1 [30080/60000 (50%)]\tLoss: 0.113740\n",
389 | "Train Epoch: 1 [30720/60000 (51%)]\tLoss: 0.089509\n",
390 | "Train Epoch: 1 [31360/60000 (52%)]\tLoss: 0.188480\n",
391 | "Train Epoch: 1 [32000/60000 (53%)]\tLoss: 0.253904\n",
392 | "Train Epoch: 1 [32640/60000 (54%)]\tLoss: 0.244142\n",
393 | "Train Epoch: 1 [33280/60000 (55%)]\tLoss: 0.286333\n",
394 | "Train Epoch: 1 [33920/60000 (57%)]\tLoss: 0.075342\n",
395 | "Train Epoch: 1 [34560/60000 (58%)]\tLoss: 0.195837\n",
396 | "Train Epoch: 1 [35200/60000 (59%)]\tLoss: 0.148247\n",
397 | "Train Epoch: 1 [35840/60000 (60%)]\tLoss: 0.240426\n",
398 | "Train Epoch: 1 [36480/60000 (61%)]\tLoss: 0.056933\n",
399 | "Train Epoch: 1 [37120/60000 (62%)]\tLoss: 0.191502\n",
400 | "Train Epoch: 1 [37760/60000 (63%)]\tLoss: 0.144343\n",
401 | "Train Epoch: 1 [38400/60000 (64%)]\tLoss: 0.178543\n",
402 | "Train Epoch: 1 [39040/60000 (65%)]\tLoss: 0.156982\n",
403 | "Train Epoch: 1 [39680/60000 (66%)]\tLoss: 0.231199\n",
404 | "Train Epoch: 1 [40320/60000 (67%)]\tLoss: 0.193805\n",
405 | "Train Epoch: 1 [40960/60000 (68%)]\tLoss: 0.430056\n",
406 | "Train Epoch: 1 [41600/60000 (69%)]\tLoss: 0.039209\n",
407 | "Train Epoch: 1 [42240/60000 (70%)]\tLoss: 0.271172\n",
408 | "Train Epoch: 1 [42880/60000 (71%)]\tLoss: 0.039059\n",
409 | "Train Epoch: 1 [43520/60000 (72%)]\tLoss: 0.083934\n",
410 | "Train Epoch: 1 [44160/60000 (74%)]\tLoss: 0.143705\n",
411 | "Train Epoch: 1 [44800/60000 (75%)]\tLoss: 0.014715\n",
412 | "Train Epoch: 1 [45440/60000 (76%)]\tLoss: 0.034582\n",
413 | "Train Epoch: 1 [46080/60000 (77%)]\tLoss: 0.046080\n",
414 | "Train Epoch: 1 [46720/60000 (78%)]\tLoss: 0.142156\n",
415 | "Train Epoch: 1 [47360/60000 (79%)]\tLoss: 0.485598\n",
416 | "Train Epoch: 1 [48000/60000 (80%)]\tLoss: 0.328156\n",
417 | "Train Epoch: 1 [48640/60000 (81%)]\tLoss: 0.087537\n",
418 | "Train Epoch: 1 [49280/60000 (82%)]\tLoss: 0.017084\n",
419 | "Train Epoch: 1 [49920/60000 (83%)]\tLoss: 0.011264\n",
420 | "Train Epoch: 1 [50560/60000 (84%)]\tLoss: 0.027894\n",
421 | "Train Epoch: 1 [51200/60000 (85%)]\tLoss: 0.208086\n",
422 | "Train Epoch: 1 [51840/60000 (86%)]\tLoss: 0.125015\n",
423 | "Train Epoch: 1 [52480/60000 (87%)]\tLoss: 0.127294\n",
424 | "Train Epoch: 1 [53120/60000 (88%)]\tLoss: 0.238284\n",
425 | "Train Epoch: 1 [53760/60000 (90%)]\tLoss: 0.030879\n",
426 | "Train Epoch: 1 [54400/60000 (91%)]\tLoss: 0.089591\n",
427 | "Train Epoch: 1 [55040/60000 (92%)]\tLoss: 0.049601\n",
428 | "Train Epoch: 1 [55680/60000 (93%)]\tLoss: 0.666180\n",
429 | "Train Epoch: 1 [56320/60000 (94%)]\tLoss: 0.067744\n",
430 | "Train Epoch: 1 [56960/60000 (95%)]\tLoss: 0.048680\n",
431 | "Train Epoch: 1 [57600/60000 (96%)]\tLoss: 0.164761\n",
432 | "Train Epoch: 1 [58240/60000 (97%)]\tLoss: 0.180428\n",
433 | "Train Epoch: 1 [58880/60000 (98%)]\tLoss: 0.137088\n",
434 | "Train Epoch: 1 [59520/60000 (99%)]\tLoss: 0.089983\n",
435 | "\n",
436 | "Test set: Avg. loss: -0.9662, Accuracy: 9717/10000 (97%)\n",
437 | "\n",
438 | "Train Epoch: 2 [0/60000 (0%)]\tLoss: 0.021369\n",
439 | "Train Epoch: 2 [640/60000 (1%)]\tLoss: 0.112997\n",
440 | "Train Epoch: 2 [1280/60000 (2%)]\tLoss: 0.182833\n",
441 | "Train Epoch: 2 [1920/60000 (3%)]\tLoss: 0.109863\n",
442 | "Train Epoch: 2 [2560/60000 (4%)]\tLoss: 0.066127\n",
443 | "Train Epoch: 2 [3200/60000 (5%)]\tLoss: 0.312570\n",
444 | "Train Epoch: 2 [3840/60000 (6%)]\tLoss: 0.107947\n",
445 | "Train Epoch: 2 [4480/60000 (7%)]\tLoss: 0.129223\n",
446 | "Train Epoch: 2 [5120/60000 (9%)]\tLoss: 0.006381\n",
447 | "Train Epoch: 2 [5760/60000 (10%)]\tLoss: 0.038859\n",
448 | "Train Epoch: 2 [6400/60000 (11%)]\tLoss: 0.013144\n",
449 | "Train Epoch: 2 [7040/60000 (12%)]\tLoss: 0.020549\n",
450 | "Train Epoch: 2 [7680/60000 (13%)]\tLoss: 0.011785\n",
451 | "Train Epoch: 2 [8320/60000 (14%)]\tLoss: 0.231971\n",
452 | "Train Epoch: 2 [8960/60000 (15%)]\tLoss: 0.147990\n",
453 | "Train Epoch: 2 [9600/60000 (16%)]\tLoss: 0.171210\n",
454 | "Train Epoch: 2 [10240/60000 (17%)]\tLoss: 0.147863\n",
455 | "Train Epoch: 2 [10880/60000 (18%)]\tLoss: 0.333462\n",
456 | "Train Epoch: 2 [11520/60000 (19%)]\tLoss: 0.023902\n",
457 | "Train Epoch: 2 [12160/60000 (20%)]\tLoss: 0.224719\n",
458 | "Train Epoch: 2 [12800/60000 (21%)]\tLoss: 0.025668\n",
459 | "Train Epoch: 2 [13440/60000 (22%)]\tLoss: 0.064675\n",
460 | "Train Epoch: 2 [14080/60000 (23%)]\tLoss: 0.037010\n",
461 | "Train Epoch: 2 [14720/60000 (25%)]\tLoss: 0.050318\n",
462 | "Train Epoch: 2 [15360/60000 (26%)]\tLoss: 0.083974\n",
463 | "Train Epoch: 2 [16000/60000 (27%)]\tLoss: 0.138085\n",
464 | "Train Epoch: 2 [16640/60000 (28%)]\tLoss: 0.217917\n",
465 | "Train Epoch: 2 [17280/60000 (29%)]\tLoss: 0.116676\n",
466 | "Train Epoch: 2 [17920/60000 (30%)]\tLoss: 0.034423\n",
467 | "Train Epoch: 2 [18560/60000 (31%)]\tLoss: 0.032855\n",
468 | "Train Epoch: 2 [19200/60000 (32%)]\tLoss: 0.225311\n",
469 | "Train Epoch: 2 [19840/60000 (33%)]\tLoss: 0.006941\n",
470 | "Train Epoch: 2 [20480/60000 (34%)]\tLoss: 0.072871\n",
471 | "Train Epoch: 2 [21120/60000 (35%)]\tLoss: 0.058232\n",
472 | "Train Epoch: 2 [21760/60000 (36%)]\tLoss: 0.015516\n",
473 | "Train Epoch: 2 [22400/60000 (37%)]\tLoss: 0.105818\n",
474 | "Train Epoch: 2 [23040/60000 (38%)]\tLoss: 0.127326\n",
475 | "Train Epoch: 2 [23680/60000 (39%)]\tLoss: 0.212847\n",
476 | "Train Epoch: 2 [24320/60000 (41%)]\tLoss: 0.173671\n",
477 | "Train Epoch: 2 [24960/60000 (42%)]\tLoss: 0.158303\n",
478 | "Train Epoch: 2 [25600/60000 (43%)]\tLoss: 0.514192\n",
479 | "Train Epoch: 2 [26240/60000 (44%)]\tLoss: 0.491562\n",
480 | "Train Epoch: 2 [26880/60000 (45%)]\tLoss: 0.124600\n",
481 | "Train Epoch: 2 [27520/60000 (46%)]\tLoss: 0.151669\n",
482 | "Train Epoch: 2 [28160/60000 (47%)]\tLoss: 0.042192\n",
483 | "Train Epoch: 2 [28800/60000 (48%)]\tLoss: 0.102364\n",
484 | "Train Epoch: 2 [29440/60000 (49%)]\tLoss: 0.048251\n",
485 | "Train Epoch: 2 [30080/60000 (50%)]\tLoss: 0.084469\n",
486 | "Train Epoch: 2 [30720/60000 (51%)]\tLoss: 0.548334\n",
487 | "Train Epoch: 2 [31360/60000 (52%)]\tLoss: 0.256530\n",
488 | "Train Epoch: 2 [32000/60000 (53%)]\tLoss: 0.018591\n",
489 | "Train Epoch: 2 [32640/60000 (54%)]\tLoss: 0.255403\n",
490 | "Train Epoch: 2 [33280/60000 (55%)]\tLoss: 0.034644\n",
491 | "Train Epoch: 2 [33920/60000 (57%)]\tLoss: 0.044591\n",
492 | "Train Epoch: 2 [34560/60000 (58%)]\tLoss: 0.040136\n",
493 | "Train Epoch: 2 [35200/60000 (59%)]\tLoss: 0.105666\n",
494 | "Train Epoch: 2 [35840/60000 (60%)]\tLoss: 0.225017\n",
495 | "Train Epoch: 2 [36480/60000 (61%)]\tLoss: 0.076942\n",
496 | "Train Epoch: 2 [37120/60000 (62%)]\tLoss: 0.142982\n",
497 | "Train Epoch: 2 [37760/60000 (63%)]\tLoss: 0.008805\n",
498 | "Train Epoch: 2 [38400/60000 (64%)]\tLoss: 0.030216\n",
499 | "Train Epoch: 2 [39040/60000 (65%)]\tLoss: 0.525962\n",
500 | "Train Epoch: 2 [39680/60000 (66%)]\tLoss: 0.081722\n",
501 | "Train Epoch: 2 [40320/60000 (67%)]\tLoss: 0.293573\n",
502 | "Train Epoch: 2 [40960/60000 (68%)]\tLoss: 0.278041\n",
503 | "Train Epoch: 2 [41600/60000 (69%)]\tLoss: 0.066600\n",
504 | "Train Epoch: 2 [42240/60000 (70%)]\tLoss: 0.083429\n",
505 | "Train Epoch: 2 [42880/60000 (71%)]\tLoss: 0.176400\n",
506 | "Train Epoch: 2 [43520/60000 (72%)]\tLoss: 0.070648\n",
507 | "Train Epoch: 2 [44160/60000 (74%)]\tLoss: 0.136096\n",
508 | "Train Epoch: 2 [44800/60000 (75%)]\tLoss: 0.072206\n",
509 | "Train Epoch: 2 [45440/60000 (76%)]\tLoss: 0.044717\n",
510 | "Train Epoch: 2 [46080/60000 (77%)]\tLoss: 0.054737\n",
511 | "Train Epoch: 2 [46720/60000 (78%)]\tLoss: 0.155012\n",
512 | "Train Epoch: 2 [47360/60000 (79%)]\tLoss: 0.033407\n",
513 | "Train Epoch: 2 [48000/60000 (80%)]\tLoss: 0.572021\n",
514 | "Train Epoch: 2 [48640/60000 (81%)]\tLoss: 0.045251\n",
515 | "Train Epoch: 2 [49280/60000 (82%)]\tLoss: 0.180520\n",
516 | "Train Epoch: 2 [49920/60000 (83%)]\tLoss: 0.084542\n",
517 | "Train Epoch: 2 [50560/60000 (84%)]\tLoss: 0.150355\n",
518 | "Train Epoch: 2 [51200/60000 (85%)]\tLoss: 0.001370\n",
519 | "Train Epoch: 2 [51840/60000 (86%)]\tLoss: 0.027981\n",
520 | "Train Epoch: 2 [52480/60000 (87%)]\tLoss: 0.023305\n",
521 | "Train Epoch: 2 [53120/60000 (88%)]\tLoss: 0.031542\n",
522 | "Train Epoch: 2 [53760/60000 (90%)]\tLoss: 0.161370\n",
523 | "Train Epoch: 2 [54400/60000 (91%)]\tLoss: 0.013536\n",
524 | "Train Epoch: 2 [55040/60000 (92%)]\tLoss: 0.026579\n",
525 | "Train Epoch: 2 [55680/60000 (93%)]\tLoss: 0.022276\n",
526 | "Train Epoch: 2 [56320/60000 (94%)]\tLoss: 0.105375\n",
527 | "Train Epoch: 2 [56960/60000 (95%)]\tLoss: 0.092841\n",
528 | "Train Epoch: 2 [57600/60000 (96%)]\tLoss: 0.091326\n",
529 | "Train Epoch: 2 [58240/60000 (97%)]\tLoss: 0.110826\n",
530 | "Train Epoch: 2 [58880/60000 (98%)]\tLoss: 0.081717\n",
531 | "Train Epoch: 2 [59520/60000 (99%)]\tLoss: 0.092696\n",
532 | "\n",
533 | "Test set: Avg. loss: -0.9776, Accuracy: 9831/10000 (98%)\n",
534 | "\n",
535 | "Train Epoch: 3 [0/60000 (0%)]\tLoss: 0.063463\n",
536 | "Train Epoch: 3 [640/60000 (1%)]\tLoss: 0.114469\n",
537 | "Train Epoch: 3 [1280/60000 (2%)]\tLoss: 0.006183\n",
538 | "Train Epoch: 3 [1920/60000 (3%)]\tLoss: 0.039701\n",
539 | "Train Epoch: 3 [2560/60000 (4%)]\tLoss: 0.106729\n",
540 | "Train Epoch: 3 [3200/60000 (5%)]\tLoss: 0.012483\n",
541 | "Train Epoch: 3 [3840/60000 (6%)]\tLoss: 0.236577\n",
542 | "Train Epoch: 3 [4480/60000 (7%)]\tLoss: 0.063634\n",
543 | "Train Epoch: 3 [5120/60000 (9%)]\tLoss: 0.061753\n",
544 | "Train Epoch: 3 [5760/60000 (10%)]\tLoss: 0.052922\n",
545 | "Train Epoch: 3 [6400/60000 (11%)]\tLoss: 0.190764\n",
546 | "Train Epoch: 3 [7040/60000 (12%)]\tLoss: 0.058996\n",
547 | "Train Epoch: 3 [7680/60000 (13%)]\tLoss: 0.022317\n",
548 | "Train Epoch: 3 [8320/60000 (14%)]\tLoss: 0.007046\n",
549 | "Train Epoch: 3 [8960/60000 (15%)]\tLoss: 0.080138\n",
550 | "Train Epoch: 3 [9600/60000 (16%)]\tLoss: 0.063802\n",
551 | "Train Epoch: 3 [10240/60000 (17%)]\tLoss: 0.041988\n",
552 | "Train Epoch: 3 [10880/60000 (18%)]\tLoss: 0.098359\n",
553 | "Train Epoch: 3 [11520/60000 (19%)]\tLoss: 0.153680\n",
554 | "Train Epoch: 3 [12160/60000 (20%)]\tLoss: 0.464352\n",
555 | "Train Epoch: 3 [12800/60000 (21%)]\tLoss: 0.070375\n",
556 | "Train Epoch: 3 [13440/60000 (22%)]\tLoss: 0.018513\n",
557 | "Train Epoch: 3 [14080/60000 (23%)]\tLoss: 0.022059\n",
558 | "Train Epoch: 3 [14720/60000 (25%)]\tLoss: 0.003471\n",
559 | "Train Epoch: 3 [15360/60000 (26%)]\tLoss: 0.019260\n",
560 | "Train Epoch: 3 [16000/60000 (27%)]\tLoss: 0.055360\n",
561 | "Train Epoch: 3 [16640/60000 (28%)]\tLoss: 0.185760\n",
562 | "Train Epoch: 3 [17280/60000 (29%)]\tLoss: 0.094002\n",
563 | "Train Epoch: 3 [17920/60000 (30%)]\tLoss: 0.088621\n",
564 | "Train Epoch: 3 [18560/60000 (31%)]\tLoss: 0.202922\n",
565 | "Train Epoch: 3 [19200/60000 (32%)]\tLoss: 0.217287\n",
566 | "Train Epoch: 3 [19840/60000 (33%)]\tLoss: 0.022548\n",
567 | "Train Epoch: 3 [20480/60000 (34%)]\tLoss: 0.064074\n",
568 | "Train Epoch: 3 [21120/60000 (35%)]\tLoss: 0.217457\n",
569 | "Train Epoch: 3 [21760/60000 (36%)]\tLoss: 0.041618\n",
570 | "Train Epoch: 3 [22400/60000 (37%)]\tLoss: 0.426640\n",
571 | "Train Epoch: 3 [23040/60000 (38%)]\tLoss: 0.064455\n",
572 | "Train Epoch: 3 [23680/60000 (39%)]\tLoss: 0.029709\n",
573 | "Train Epoch: 3 [24320/60000 (41%)]\tLoss: 0.374250\n",
574 | "Train Epoch: 3 [24960/60000 (42%)]\tLoss: 0.019916\n",
575 | "Train Epoch: 3 [25600/60000 (43%)]\tLoss: 0.077315\n",
576 | "Train Epoch: 3 [26240/60000 (44%)]\tLoss: 0.120163\n",
577 | "Train Epoch: 3 [26880/60000 (45%)]\tLoss: 0.007376\n",
578 | "Train Epoch: 3 [27520/60000 (46%)]\tLoss: 0.048888\n",
579 | "Train Epoch: 3 [28160/60000 (47%)]\tLoss: 0.148275\n",
580 | "Train Epoch: 3 [28800/60000 (48%)]\tLoss: 0.070750\n",
581 | "Train Epoch: 3 [29440/60000 (49%)]\tLoss: 0.353416\n",
582 | "Train Epoch: 3 [30080/60000 (50%)]\tLoss: 0.076076\n",
583 | "Train Epoch: 3 [30720/60000 (51%)]\tLoss: 0.074755\n",
584 | "Train Epoch: 3 [31360/60000 (52%)]\tLoss: 0.120992\n",
585 | "Train Epoch: 3 [32000/60000 (53%)]\tLoss: 0.245001\n",
586 | "Train Epoch: 3 [32640/60000 (54%)]\tLoss: 0.023723\n",
587 | "Train Epoch: 3 [33280/60000 (55%)]\tLoss: 0.081988\n",
588 | "Train Epoch: 3 [33920/60000 (57%)]\tLoss: 0.150652\n",
589 | "Train Epoch: 3 [34560/60000 (58%)]\tLoss: 0.028882\n",
590 | "Train Epoch: 3 [35200/60000 (59%)]\tLoss: 0.078787\n",
591 | "Train Epoch: 3 [35840/60000 (60%)]\tLoss: 0.063896\n",
592 | "Train Epoch: 3 [36480/60000 (61%)]\tLoss: 0.012244\n",
593 | "Train Epoch: 3 [37120/60000 (62%)]\tLoss: 0.090776\n",
594 | "Train Epoch: 3 [37760/60000 (63%)]\tLoss: 0.207105\n",
595 | "Train Epoch: 3 [38400/60000 (64%)]\tLoss: 0.006419\n",
596 | "Train Epoch: 3 [39040/60000 (65%)]\tLoss: 0.035537\n",
597 | "Train Epoch: 3 [39680/60000 (66%)]\tLoss: 0.145132\n",
598 | "Train Epoch: 3 [40320/60000 (67%)]\tLoss: 0.049233\n",
599 | "Train Epoch: 3 [40960/60000 (68%)]\tLoss: 0.056022\n",
600 | "Train Epoch: 3 [41600/60000 (69%)]\tLoss: 0.593132\n",
601 | "Train Epoch: 3 [42240/60000 (70%)]\tLoss: 0.180111\n",
602 | "Train Epoch: 3 [42880/60000 (71%)]\tLoss: 0.065261\n",
603 | "Train Epoch: 3 [43520/60000 (72%)]\tLoss: 0.024525\n",
604 | "Train Epoch: 3 [44160/60000 (74%)]\tLoss: 0.086528\n",
605 | "Train Epoch: 3 [44800/60000 (75%)]\tLoss: 0.120266\n",
606 | "Train Epoch: 3 [45440/60000 (76%)]\tLoss: 0.077596\n",
607 | "Train Epoch: 3 [46080/60000 (77%)]\tLoss: 0.043466\n",
608 | "Train Epoch: 3 [46720/60000 (78%)]\tLoss: 0.120706\n",
609 | "Train Epoch: 3 [47360/60000 (79%)]\tLoss: 0.015336\n",
610 | "Train Epoch: 3 [48000/60000 (80%)]\tLoss: 0.051106\n",
611 | "Train Epoch: 3 [48640/60000 (81%)]\tLoss: 0.118524\n",
612 | "Train Epoch: 3 [49280/60000 (82%)]\tLoss: 0.023619\n",
613 | "Train Epoch: 3 [49920/60000 (83%)]\tLoss: 0.203596\n",
614 | "Train Epoch: 3 [50560/60000 (84%)]\tLoss: 0.005664\n",
615 | "Train Epoch: 3 [51200/60000 (85%)]\tLoss: 0.076866\n",
616 | "Train Epoch: 3 [51840/60000 (86%)]\tLoss: 0.049961\n",
617 | "Train Epoch: 3 [52480/60000 (87%)]\tLoss: 0.099215\n",
618 | "Train Epoch: 3 [53120/60000 (88%)]\tLoss: 0.239370\n",
619 | "Train Epoch: 3 [53760/60000 (90%)]\tLoss: 0.044627\n",
620 | "Train Epoch: 3 [54400/60000 (91%)]\tLoss: 0.134020\n",
621 | "Train Epoch: 3 [55040/60000 (92%)]\tLoss: 0.125089\n",
622 | "Train Epoch: 3 [55680/60000 (93%)]\tLoss: 0.096078\n",
623 | "Train Epoch: 3 [56320/60000 (94%)]\tLoss: 0.087274\n",
624 | "Train Epoch: 3 [56960/60000 (95%)]\tLoss: 0.128849\n",
625 | "Train Epoch: 3 [57600/60000 (96%)]\tLoss: 0.022029\n",
626 | "Train Epoch: 3 [58240/60000 (97%)]\tLoss: 0.014636\n",
627 | "Train Epoch: 3 [58880/60000 (98%)]\tLoss: 0.124867\n",
628 | "Train Epoch: 3 [59520/60000 (99%)]\tLoss: 0.113813\n",
629 | "\n",
630 | "Test set: Avg. loss: -0.9710, Accuracy: 9783/10000 (98%)\n",
631 | "\n",
632 | "Train Epoch: 3 [0/60000 (0%)]\tLoss: 0.034476\n",
633 | "Train Epoch: 3 [640/60000 (1%)]\tLoss: 0.415755\n",
634 | "Train Epoch: 3 [1280/60000 (2%)]\tLoss: 0.181356\n",
635 | "Train Epoch: 3 [1920/60000 (3%)]\tLoss: 0.086929\n",
636 | "Train Epoch: 3 [2560/60000 (4%)]\tLoss: 0.029738\n",
637 | "Train Epoch: 3 [3200/60000 (5%)]\tLoss: 0.016178\n",
638 | "Train Epoch: 3 [3840/60000 (6%)]\tLoss: 0.001996\n",
639 | "Train Epoch: 3 [4480/60000 (7%)]\tLoss: 0.051265\n",
640 | "Train Epoch: 3 [5120/60000 (9%)]\tLoss: 0.060758\n",
641 | "Train Epoch: 3 [5760/60000 (10%)]\tLoss: 0.007387\n",
642 | "Train Epoch: 3 [6400/60000 (11%)]\tLoss: 0.016684\n",
643 | "Train Epoch: 3 [7040/60000 (12%)]\tLoss: 0.027652\n",
644 | "Train Epoch: 3 [7680/60000 (13%)]\tLoss: 0.083977\n",
645 | "Train Epoch: 3 [8320/60000 (14%)]\tLoss: 0.031926\n",
646 | "Train Epoch: 3 [8960/60000 (15%)]\tLoss: 0.071717\n",
647 | "Train Epoch: 3 [9600/60000 (16%)]\tLoss: 0.001477\n",
648 | "Train Epoch: 3 [10240/60000 (17%)]\tLoss: 0.023606\n",
649 | "Train Epoch: 3 [10880/60000 (18%)]\tLoss: 0.022856\n",
650 | "Train Epoch: 3 [11520/60000 (19%)]\tLoss: 0.067140\n",
651 | "Train Epoch: 3 [12160/60000 (20%)]\tLoss: 0.020441\n",
652 | "Train Epoch: 3 [12800/60000 (21%)]\tLoss: 0.000635\n",
653 | "Train Epoch: 3 [13440/60000 (22%)]\tLoss: 0.017227\n",
654 | "Train Epoch: 3 [14080/60000 (23%)]\tLoss: 0.022927\n",
655 | "Train Epoch: 3 [14720/60000 (25%)]\tLoss: 0.027828\n",
656 | "Train Epoch: 3 [15360/60000 (26%)]\tLoss: 0.032064\n",
657 | "Train Epoch: 3 [16000/60000 (27%)]\tLoss: 0.026003\n",
658 | "Train Epoch: 3 [16640/60000 (28%)]\tLoss: 0.029695\n",
659 | "Train Epoch: 3 [17280/60000 (29%)]\tLoss: 0.000279\n",
660 | "Train Epoch: 3 [17920/60000 (30%)]\tLoss: 0.000815\n",
661 | "Train Epoch: 3 [18560/60000 (31%)]\tLoss: 0.030595\n",
662 | "Train Epoch: 3 [19200/60000 (32%)]\tLoss: 0.004104\n",
663 | "Train Epoch: 3 [19840/60000 (33%)]\tLoss: 0.012980\n",
664 | "Train Epoch: 3 [20480/60000 (34%)]\tLoss: 0.140954\n",
665 | "Train Epoch: 3 [21120/60000 (35%)]\tLoss: 0.036284\n",
666 | "Train Epoch: 3 [21760/60000 (36%)]\tLoss: 0.056631\n",
667 | "Train Epoch: 3 [22400/60000 (37%)]\tLoss: 0.001214\n",
668 | "Train Epoch: 3 [23040/60000 (38%)]\tLoss: 0.229223\n",
669 | "Train Epoch: 3 [23680/60000 (39%)]\tLoss: 0.018490\n",
670 | "Train Epoch: 3 [24320/60000 (41%)]\tLoss: 0.122923\n",
671 | "Train Epoch: 3 [24960/60000 (42%)]\tLoss: 0.027173\n",
672 | "Train Epoch: 3 [25600/60000 (43%)]\tLoss: 0.006939\n",
673 | "Train Epoch: 3 [26240/60000 (44%)]\tLoss: 0.038346\n",
674 | "Train Epoch: 3 [26880/60000 (45%)]\tLoss: 0.083031\n",
675 | "Train Epoch: 3 [27520/60000 (46%)]\tLoss: 0.144066\n",
676 | "Train Epoch: 3 [28160/60000 (47%)]\tLoss: 0.002842\n",
677 | "Train Epoch: 3 [28800/60000 (48%)]\tLoss: 0.047957\n",
678 | "Train Epoch: 3 [29440/60000 (49%)]\tLoss: 0.032559\n",
679 | "Train Epoch: 3 [30080/60000 (50%)]\tLoss: 0.075398\n",
680 | "Train Epoch: 3 [30720/60000 (51%)]\tLoss: 0.000979\n",
681 | "Train Epoch: 3 [31360/60000 (52%)]\tLoss: 0.039745\n",
682 | "Train Epoch: 3 [32000/60000 (53%)]\tLoss: 0.087644\n",
683 | "Train Epoch: 3 [32640/60000 (54%)]\tLoss: 0.245594\n",
684 | "Train Epoch: 3 [33280/60000 (55%)]\tLoss: 0.263317\n",
685 | "Train Epoch: 3 [33920/60000 (57%)]\tLoss: 0.057245\n",
686 | "Train Epoch: 3 [34560/60000 (58%)]\tLoss: 0.024240\n",
687 | "Train Epoch: 3 [35200/60000 (59%)]\tLoss: 0.079506\n",
688 | "Train Epoch: 3 [35840/60000 (60%)]\tLoss: 0.001688\n",
689 | "Train Epoch: 3 [36480/60000 (61%)]\tLoss: 0.024482\n",
690 | "Train Epoch: 3 [37120/60000 (62%)]\tLoss: 0.090378\n",
691 | "Train Epoch: 3 [37760/60000 (63%)]\tLoss: 0.111796\n",
692 | "Train Epoch: 3 [38400/60000 (64%)]\tLoss: 0.034894\n",
693 | "Train Epoch: 3 [39040/60000 (65%)]\tLoss: 0.047220\n",
694 | "Train Epoch: 3 [39680/60000 (66%)]\tLoss: 0.135415\n",
695 | "Train Epoch: 3 [40320/60000 (67%)]\tLoss: 0.040737\n",
696 | "Train Epoch: 3 [40960/60000 (68%)]\tLoss: 0.015579\n",
697 | "Train Epoch: 3 [41600/60000 (69%)]\tLoss: 0.007837\n",
698 | "Train Epoch: 3 [42240/60000 (70%)]\tLoss: 0.000603\n",
699 | "Train Epoch: 3 [42880/60000 (71%)]\tLoss: 0.020224\n",
700 | "Train Epoch: 3 [43520/60000 (72%)]\tLoss: 0.026869\n",
701 | "Train Epoch: 3 [44160/60000 (74%)]\tLoss: 0.001488\n",
702 | "Train Epoch: 3 [44800/60000 (75%)]\tLoss: 0.007323\n",
703 | "Train Epoch: 3 [45440/60000 (76%)]\tLoss: 0.005850\n",
704 | "Train Epoch: 3 [46080/60000 (77%)]\tLoss: 0.011149\n",
705 | "Train Epoch: 3 [46720/60000 (78%)]\tLoss: 0.017864\n",
706 | "Train Epoch: 3 [47360/60000 (79%)]\tLoss: 0.034819\n",
707 | "Train Epoch: 3 [48000/60000 (80%)]\tLoss: 0.021358\n",
708 | "Train Epoch: 3 [48640/60000 (81%)]\tLoss: 0.006743\n",
709 | "Train Epoch: 3 [49280/60000 (82%)]\tLoss: 0.003964\n",
710 | "Train Epoch: 3 [49920/60000 (83%)]\tLoss: 0.011560\n",
711 | "Train Epoch: 3 [50560/60000 (84%)]\tLoss: 0.000351\n",
712 | "Train Epoch: 3 [51200/60000 (85%)]\tLoss: 0.016029\n",
713 | "Train Epoch: 3 [51840/60000 (86%)]\tLoss: 0.040337\n",
714 | "Train Epoch: 3 [52480/60000 (87%)]\tLoss: 0.021229\n",
715 | "Train Epoch: 3 [53120/60000 (88%)]\tLoss: 0.086175\n",
716 | "Train Epoch: 3 [53760/60000 (90%)]\tLoss: 0.008898\n",
717 | "Train Epoch: 3 [54400/60000 (91%)]\tLoss: 0.021425\n",
718 | "Train Epoch: 3 [55040/60000 (92%)]\tLoss: 0.002073\n",
719 | "Train Epoch: 3 [55680/60000 (93%)]\tLoss: 0.053018\n",
720 | "Train Epoch: 3 [56320/60000 (94%)]\tLoss: 0.001988\n",
721 | "Train Epoch: 3 [56960/60000 (95%)]\tLoss: 0.010265\n",
722 | "Train Epoch: 3 [57600/60000 (96%)]\tLoss: 0.001524\n",
723 | "Train Epoch: 3 [58240/60000 (97%)]\tLoss: 0.048560\n",
724 | "Train Epoch: 3 [58880/60000 (98%)]\tLoss: 0.017292\n",
725 | "Train Epoch: 3 [59520/60000 (99%)]\tLoss: 0.004543\n",
726 | "\n",
727 | "Test set: Avg. loss: -0.9860, Accuracy: 9886/10000 (99%)\n",
728 | "\n",
729 | "Train Epoch: 3 [0/60000 (0%)]\tLoss: 0.000774\n",
730 | "Train Epoch: 3 [640/60000 (1%)]\tLoss: 0.005242\n",
731 | "Train Epoch: 3 [1280/60000 (2%)]\tLoss: 0.087029\n",
732 | "Train Epoch: 3 [1920/60000 (3%)]\tLoss: 0.042792\n",
733 | "Train Epoch: 3 [2560/60000 (4%)]\tLoss: 0.035679\n",
734 | "Train Epoch: 3 [3200/60000 (5%)]\tLoss: 0.000884\n",
735 | "Train Epoch: 3 [3840/60000 (6%)]\tLoss: 0.008985\n",
736 | "Train Epoch: 3 [4480/60000 (7%)]\tLoss: 0.017822\n",
737 | "Train Epoch: 3 [5120/60000 (9%)]\tLoss: 0.000644\n",
738 | "Train Epoch: 3 [5760/60000 (10%)]\tLoss: 0.007079\n",
739 | "Train Epoch: 3 [6400/60000 (11%)]\tLoss: 0.003289\n",
740 | "Train Epoch: 3 [7040/60000 (12%)]\tLoss: 0.014545\n",
741 | "Train Epoch: 3 [7680/60000 (13%)]\tLoss: 0.037752\n",
742 | "Train Epoch: 3 [8320/60000 (14%)]\tLoss: 0.012087\n",
743 | "Train Epoch: 3 [8960/60000 (15%)]\tLoss: 0.036052\n",
744 | "Train Epoch: 3 [9600/60000 (16%)]\tLoss: 0.040073\n",
745 | "Train Epoch: 3 [10240/60000 (17%)]\tLoss: 0.049330\n",
746 | "Train Epoch: 3 [10880/60000 (18%)]\tLoss: 0.001090\n",
747 | "Train Epoch: 3 [11520/60000 (19%)]\tLoss: 0.013898\n",
748 | "Train Epoch: 3 [12160/60000 (20%)]\tLoss: 0.003239\n",
749 | "Train Epoch: 3 [12800/60000 (21%)]\tLoss: 0.167850\n",
750 | "Train Epoch: 3 [13440/60000 (22%)]\tLoss: 0.011250\n",
751 | "Train Epoch: 3 [14080/60000 (23%)]\tLoss: 0.026054\n",
752 | "Train Epoch: 3 [14720/60000 (25%)]\tLoss: 0.015770\n",
753 | "Train Epoch: 3 [15360/60000 (26%)]\tLoss: 0.003927\n",
754 | "Train Epoch: 3 [16000/60000 (27%)]\tLoss: 0.024283\n",
755 | "Train Epoch: 3 [16640/60000 (28%)]\tLoss: 0.002161\n",
756 | "Train Epoch: 3 [17280/60000 (29%)]\tLoss: 0.036340\n",
757 | "Train Epoch: 3 [17920/60000 (30%)]\tLoss: 0.000880\n",
758 | "Train Epoch: 3 [18560/60000 (31%)]\tLoss: 0.101127\n",
759 | "Train Epoch: 3 [19200/60000 (32%)]\tLoss: 0.004294\n",
760 | "Train Epoch: 3 [19840/60000 (33%)]\tLoss: 0.008890\n",
761 | "Train Epoch: 3 [20480/60000 (34%)]\tLoss: 0.005472\n",
762 | "Train Epoch: 3 [21120/60000 (35%)]\tLoss: 0.065323\n",
763 | "Train Epoch: 3 [21760/60000 (36%)]\tLoss: 0.069902\n",
764 | "Train Epoch: 3 [22400/60000 (37%)]\tLoss: 0.001641\n",
765 | "Train Epoch: 3 [23040/60000 (38%)]\tLoss: 0.076958\n",
766 | "Train Epoch: 3 [23680/60000 (39%)]\tLoss: 0.032411\n",
767 | "Train Epoch: 3 [24320/60000 (41%)]\tLoss: 0.034886\n",
768 | "Train Epoch: 3 [24960/60000 (42%)]\tLoss: 0.010985\n",
769 | "Train Epoch: 3 [25600/60000 (43%)]\tLoss: 0.005426\n",
770 | "Train Epoch: 3 [26240/60000 (44%)]\tLoss: 0.032913\n",
771 | "Train Epoch: 3 [26880/60000 (45%)]\tLoss: 0.126640\n",
772 | "Train Epoch: 3 [27520/60000 (46%)]\tLoss: 0.037310\n",
773 | "Train Epoch: 3 [28160/60000 (47%)]\tLoss: 0.008347\n",
774 | "Train Epoch: 3 [28800/60000 (48%)]\tLoss: 0.046984\n",
775 | "Train Epoch: 3 [29440/60000 (49%)]\tLoss: 0.009953\n",
776 | "Train Epoch: 3 [30080/60000 (50%)]\tLoss: 0.002314\n",
777 | "Train Epoch: 3 [30720/60000 (51%)]\tLoss: 0.085070\n",
778 | "Train Epoch: 3 [31360/60000 (52%)]\tLoss: 0.016453\n",
779 | "Train Epoch: 3 [32000/60000 (53%)]\tLoss: 0.045299\n",
780 | "Train Epoch: 3 [32640/60000 (54%)]\tLoss: 0.000546\n",
781 | "Train Epoch: 3 [33280/60000 (55%)]\tLoss: 0.241897\n",
782 | "Train Epoch: 3 [33920/60000 (57%)]\tLoss: 0.005824\n",
783 | "Train Epoch: 3 [34560/60000 (58%)]\tLoss: 0.060355\n",
784 | "Train Epoch: 3 [35200/60000 (59%)]\tLoss: 0.002446\n",
785 | "Train Epoch: 3 [35840/60000 (60%)]\tLoss: 0.006852\n",
786 | "Train Epoch: 3 [36480/60000 (61%)]\tLoss: 0.001183\n",
787 | "Train Epoch: 3 [37120/60000 (62%)]\tLoss: 0.004089\n",
788 | "Train Epoch: 3 [37760/60000 (63%)]\tLoss: 0.011785\n",
789 | "Train Epoch: 3 [38400/60000 (64%)]\tLoss: 0.102750\n",
790 | "Train Epoch: 3 [39040/60000 (65%)]\tLoss: 0.010851\n",
791 | "Train Epoch: 3 [39680/60000 (66%)]\tLoss: 0.005771\n",
792 | "Train Epoch: 3 [40320/60000 (67%)]\tLoss: 0.030049\n",
793 | "Train Epoch: 3 [40960/60000 (68%)]\tLoss: 0.005349\n",
794 | "Train Epoch: 3 [41600/60000 (69%)]\tLoss: 0.003538\n",
795 | "Train Epoch: 3 [42240/60000 (70%)]\tLoss: 0.063063\n",
796 | "Train Epoch: 3 [42880/60000 (71%)]\tLoss: 0.012078\n",
797 | "Train Epoch: 3 [43520/60000 (72%)]\tLoss: 0.000568\n",
798 | "Train Epoch: 3 [44160/60000 (74%)]\tLoss: 0.000667\n",
799 | "Train Epoch: 3 [44800/60000 (75%)]\tLoss: 0.095291\n",
800 | "Train Epoch: 3 [45440/60000 (76%)]\tLoss: 0.002779\n",
801 | "Train Epoch: 3 [46080/60000 (77%)]\tLoss: 0.005778\n",
802 | "Train Epoch: 3 [46720/60000 (78%)]\tLoss: 0.057257\n",
803 | "Train Epoch: 3 [47360/60000 (79%)]\tLoss: 0.007636\n",
804 | "Train Epoch: 3 [48000/60000 (80%)]\tLoss: 0.069810\n",
805 | "Train Epoch: 3 [48640/60000 (81%)]\tLoss: 0.002557\n",
806 | "Train Epoch: 3 [49280/60000 (82%)]\tLoss: 0.139816\n",
807 | "Train Epoch: 3 [49920/60000 (83%)]\tLoss: 0.060314\n",
808 | "Train Epoch: 3 [50560/60000 (84%)]\tLoss: 0.000701\n",
809 | "Train Epoch: 3 [51200/60000 (85%)]\tLoss: 0.001461\n",
810 | "Train Epoch: 3 [51840/60000 (86%)]\tLoss: 0.225825\n",
811 | "Train Epoch: 3 [52480/60000 (87%)]\tLoss: 0.014944\n",
812 | "Train Epoch: 3 [53120/60000 (88%)]\tLoss: 0.041378\n",
813 | "Train Epoch: 3 [53760/60000 (90%)]\tLoss: 0.042047\n",
814 | "Train Epoch: 3 [54400/60000 (91%)]\tLoss: 0.020290\n",
815 | "Train Epoch: 3 [55040/60000 (92%)]\tLoss: 0.010125\n",
816 | "Train Epoch: 3 [55680/60000 (93%)]\tLoss: 0.003389\n",
817 | "Train Epoch: 3 [56320/60000 (94%)]\tLoss: 0.131909\n",
818 | "Train Epoch: 3 [56960/60000 (95%)]\tLoss: 0.051108\n",
819 | "Train Epoch: 3 [57600/60000 (96%)]\tLoss: 0.006570\n",
820 | "Train Epoch: 3 [58240/60000 (97%)]\tLoss: 0.067857\n",
821 | "Train Epoch: 3 [58880/60000 (98%)]\tLoss: 0.006648\n",
822 | "Train Epoch: 3 [59520/60000 (99%)]\tLoss: 0.149868\n",
823 | "\n",
824 | "Test set: Avg. loss: -0.9873, Accuracy: 9892/10000 (99%)\n",
825 | "\n"
826 | ]
827 | }
828 | ]
829 | },
830 | {
831 | "cell_type": "markdown",
832 | "source": [
833 | "# Out-of-distribution Detection (FashionMNIST)\n",
834 | "\n",
835 | "The FashionMNIST dataset serves as an out-of-distribution (OOD) dataset for the MNIST dataset, providing a more challenging and diverse set of images. While MNIST consists of 70,000 grayscale images of handwritten digits (0-9), FashionMNIST contains 70,000 grayscale images of various clothing items, including t-shirts, trousers, pullovers, dresses, coats, sandals, shirts, sneakers, bags, and ankle boots. Both datasets share the same structure, with 28x28 pixel images and 10 classes, which makes FashionMNIST a suitable OOD benchmark for evaluating models trained on MNIST. By using FashionMNIST as OOD data, we can assess the robustness of our model when it is exposed to visually different, yet structurally similar, data."
836 | ],
837 | "metadata": {
838 | "id": "EIXcgft4zcwM"
839 | }
840 | },
841 | {
842 | "cell_type": "code",
843 | "source": [
844 | "train_fashion_loader = torch.utils.data.DataLoader(\n",
845 | " torchvision.datasets.FashionMNIST('./FMNIST/', train=True, download=True,\n",
846 | " transform=torchvision.transforms.Compose([\n",
847 | " torchvision.transforms.ToTensor(),\n",
848 | " torchvision.transforms.Normalize(\n",
849 | " (0.1307,), (0.3081,))\n",
850 | " ])),\n",
851 | " batch_size=batch_size_train, shuffle=True)\n",
852 | "\n",
853 | "test_fashion_loader = torch.utils.data.DataLoader(\n",
854 | " torchvision.datasets.FashionMNIST('./FMNIST/', train=False, download=True,\n",
855 | " transform=torchvision.transforms.Compose([\n",
856 | " torchvision.transforms.ToTensor(),\n",
857 | " torchvision.transforms.Normalize(\n",
858 | " (0.1307,), (0.3081,))\n",
859 | " ])),\n",
860 | " batch_size=batch_size_test, shuffle=True)"
861 | ],
862 | "metadata": {
863 | "id": "LfO5qfyIISEk"
864 | },
865 | "execution_count": 9,
866 | "outputs": []
867 | },
868 | {
869 | "cell_type": "code",
870 | "source": [
871 | "import matplotlib.pyplot as plt\n",
872 | "\n",
873 | "images, targets = next(iter(test_fashion_loader))\n",
874 | "\n",
875 | "# Number of images you want to display\n",
876 | "num_images = 10\n",
877 | "\n",
878 | "# Create a figure and a row of subplots\n",
879 | "fig, axes = plt.subplots(1, num_images, figsize=(15, 3))\n",
880 | "\n",
881 | "# Plot each image on a separate subplot\n",
882 | "for i in range(num_images):\n",
883 | " axes[i].imshow(images[i, 0], cmap='gray')\n",
884 | " axes[i].axis('off') # Hide axis\n",
885 | "\n",
886 | "plt.show()"
887 | ],
888 | "metadata": {
889 | "id": "3i5gW6af0QXT",
890 | "outputId": "7b8279b6-bcd5-4f2e-8af1-07617fa77fd4",
891 | "colab": {
892 | "base_uri": "https://localhost:8080/",
893 | "height": 155
894 | }
895 | },
896 | "execution_count": 10,
897 | "outputs": [
898 | {
899 | "output_type": "display_data",
900 | "data": {
901 | "text/plain": [
902 | ""
903 | ],
904 | "image/png": "\n"
905 | },
906 | "metadata": {}
907 | }
908 | ]
909 | },
910 | {
911 | "cell_type": "code",
912 | "source": [
913 | "# Computing OOD metrics\n",
914 | "\n",
915 | "network.eval()\n",
916 | "\n",
917 | "uncertainties = np.array([])\n",
918 | "labels = np.array([])\n",
919 | "eps = 1e-10\n",
920 | "\n",
921 | "with torch.no_grad():\n",
922 | " for data, target in test_loader:\n",
923 | "\n",
924 | " data = data.to(device)\n",
925 | " prob = inference_test(data)\n",
926 | " uncertainty = (-prob * torch.log(prob + eps)).sum(dim=1).cpu().detach().numpy()\n",
927 | " label = np.zeros_like(uncertainty)\n",
928 | "\n",
929 | " uncertainties = np.concatenate([uncertainties, uncertainty])\n",
930 | " labels = np.concatenate([labels, label])\n",
931 | "\n",
932 | "with torch.no_grad():\n",
933 | " for data, target in test_fashion_loader:\n",
934 | "\n",
935 | " data = data.to(device)\n",
936 | "\n",
937 | " prob = inference_test(data)\n",
938 | " uncertainty = (-prob * torch.log(prob + eps)).sum(dim=1).cpu().detach().numpy()\n",
939 | " label = np.ones_like(uncertainty)\n",
940 | "\n",
941 | " uncertainties = np.concatenate([uncertainties, uncertainty])\n",
942 | " labels = np.concatenate([labels, label])"
943 | ],
944 | "metadata": {
945 | "colab": {
946 | "base_uri": "https://localhost:8080/"
947 | },
948 | "id": "siZj8y_8IVDg",
949 | "outputId": "0281b8fa-98ed-490e-c464-c5b806083a7a"
950 | },
951 | "execution_count": 11,
952 | "outputs": [
953 | {
954 | "output_type": "stream",
955 | "name": "stderr",
956 | "text": [
957 | ":31: UserWarning: Implicit dimension choice for log_softmax has been deprecated. Change the call to include dim=X as an argument.\n",
958 | " return F.log_softmax(x)\n"
959 | ]
960 | }
961 | ]
962 | },
963 | {
964 | "cell_type": "code",
965 | "source": [
966 | "import sklearn.metrics\n",
967 | "roc_auc = sklearn.metrics.roc_auc_score(labels, uncertainties)\n",
968 | "precision, recall, thresholds = sklearn.metrics.precision_recall_curve(labels, uncertainties)\n",
969 | "pr_auc = sklearn.metrics.auc(recall, precision)\n",
970 | "\n",
971 | "# evaluate ROC- and PR-AUC metrics, see https://arxiv.org/abs/1802.10501 for more details\n",
972 | "print(f\"ROC AUC: {roc_auc:.4f} \")\n",
973 | "print(f\"PR AUC: {pr_auc:.4f}\")"
974 | ],
975 | "metadata": {
976 | "id": "Lxy3IUwzMK_x",
977 | "outputId": "0051f24c-f1bd-49f0-8f5a-82443d59cea8",
978 | "colab": {
979 | "base_uri": "https://localhost:8080/"
980 | }
981 | },
982 | "execution_count": 12,
983 | "outputs": [
984 | {
985 | "output_type": "stream",
986 | "name": "stdout",
987 | "text": [
988 | "ROC AUC: 0.9536 \n",
989 | "PR AUC: 0.9517\n"
990 | ]
991 | }
992 | ]
993 | },
994 | {
995 | "cell_type": "code",
996 | "source": [
997 | "# Plot ROC curve\n",
998 | "fpr, tpr, _ = sklearn.metrics.roc_curve(labels, uncertainties)\n",
999 | "plt.figure()\n",
1000 | "plt.plot(fpr, tpr, label=f'ROC curve (area = {roc_auc:.2f})')\n",
1001 | "plt.plot([0, 1], [0, 1], 'g--', label=\"Random classifier\")\n",
1002 | "plt.hlines(1, xmin=0, xmax=1, color='k', linestyle=\"--\")\n",
1003 | "plt.vlines(0, ymin=0, ymax=1, color='k', linestyle=\"--\")\n",
1004 | "plt.xlim([-0.01, 1.05])\n",
1005 | "plt.ylim([0.0, 1.05])\n",
1006 | "plt.xlabel('False Positive Rate')\n",
1007 | "plt.ylabel('True Positive Rate')\n",
1008 | "plt.title('Receiver Operating Characteristic')\n",
1009 | "plt.grid()\n",
1010 | "plt.legend(loc=\"lower right\")\n",
1011 | "plt.show()\n",
1012 | "\n",
1013 | "# Plot PR curve\n",
1014 | "plt.figure()\n",
1015 | "plt.plot(recall, precision, label=f'PR curve (area = {pr_auc:.2f})')\n",
1016 | "plt.hlines(1, xmin=0, xmax=1, color='k', linestyle=\"--\")\n",
1017 | "plt.vlines(1, ymin=0, ymax=1, color='k', linestyle=\"--\")\n",
1018 | "plt.xlabel('Recall')\n",
1019 | "plt.ylabel('Precision')\n",
1020 | "plt.title('Precision-Recall curve')\n",
1021 | "plt.xlim([-0.01, 1.05])\n",
1022 | "plt.ylim([0.0, 1.05])\n",
1023 | "plt.grid()\n",
1024 | "plt.legend(loc=\"lower left\")\n",
1025 | "plt.show()\n"
1026 | ],
1027 | "metadata": {
1028 | "id": "qGoMn6os0oZl",
1029 | "outputId": "e351b1aa-21ed-45f0-b8d3-bd336647afaa",
1030 | "colab": {
1031 | "base_uri": "https://localhost:8080/",
1032 | "height": 947
1033 | }
1034 | },
1035 | "execution_count": 13,
1036 | "outputs": [
1037 | {
1038 | "output_type": "display_data",
1039 | "data": {
1040 | "text/plain": [
1041 | ""
1042 | ],
1043 | "image/png": "\n"
1044 | },
1045 | "metadata": {}
1046 | },
1047 | {
1048 | "output_type": "display_data",
1049 | "data": {
1050 | "text/plain": [
1051 | ""
1052 | ],
1053 | "image/png": "\n"
1054 | },
1055 | "metadata": {}
1056 | }
1057 | ]
1058 | },
1059 | {
1060 | "cell_type": "code",
1061 | "source": [],
1062 | "metadata": {
1063 | "id": "tMHnbvfr0sdC"
1064 | },
1065 | "execution_count": null,
1066 | "outputs": []
1067 | }
1068 | ]
1069 | }
--------------------------------------------------------------------------------