├── test
│   ├── __init__.py
│   ├── test_torch.py
│   └── test_keras.py
├── masksembles
│   ├── exceptions.py
│   ├── __init__.py
│   ├── keras.py
│   ├── torch.py
│   └── common.py
├── setup.cfg
├── MANIFEST.in
├── .gitattributes
├── images
│   ├── mask_logo.png
│   ├── transition.gif
│   └── complex_sample_mnist.npy
├── pyproject.toml
├── LICENSE
├── setup.py
├── .gitignore
├── README.md
└── notebooks
    ├── MNIST_Masksembles_tensoflow.ipynb
    └── MNIST_Masksembles.ipynb
/test/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/masksembles/exceptions.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/masksembles/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "1.1.1"
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [flake8]
2 | max_line_length = 120
3 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include README.md
2 | include LICENSE
3 |
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | *.py linguist-language=python
2 | *.ipynb linguist-documentation
--------------------------------------------------------------------------------
/images/mask_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nikitadurasov/masksembles/HEAD/images/mask_logo.png
--------------------------------------------------------------------------------
/images/transition.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nikitadurasov/masksembles/HEAD/images/transition.gif
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=65", "wheel"]
3 | build-backend = "setuptools.build_meta"
--------------------------------------------------------------------------------
/images/complex_sample_mnist.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nikitadurasov/masksembles/HEAD/images/complex_sample_mnist.npy
--------------------------------------------------------------------------------
/test/test_torch.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import torch
3 |
4 | import masksembles.torch
5 |
6 | class TestCreation(unittest.TestCase):
7 |
8 | def test_init_failed(self):
9 | pass
10 |
11 | def test_init_success(self):
12 | pass
13 |
--------------------------------------------------------------------------------
/test/test_keras.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import tensorflow as tf
3 |
4 | import masksembles.keras
5 |
6 |
7 | class TestCreation(unittest.TestCase):
8 |
9 | def test_init_failed(self):
10 | layer = masksembles.keras.Masksembles2D(4, 11.)
11 | self.assertRaises(ValueError, layer, tf.ones([4, 10, 4, 4]))
12 |
13 | def test_init_success(self):
14 | layer = masksembles.keras.Masksembles2D(4, 11.)
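        # note: Keras builds layers lazily, so construction succeeds even with an
        # invalid scale; the ValueError is only raised on the first call
        # (covered by test_init_failed above)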
15 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Nikita Durasov
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import pathlib
2 | from setuptools import setup, find_packages
3 |
4 | HERE = pathlib.Path(__file__).parent
5 | readme = (HERE / "README.md").read_text(encoding="utf-8")
6 |
7 | setup(
8 | name="masksembles", # change if taken on PyPI
9 | version="1.1.1", # use semantic x.y.z (PyPI won’t accept re-uploads)
10 | author="Nikita Durasov",
11 | author_email="yassnda@gmail.com",
12 | description="Official implementation of Masksembles approach",
13 | long_description=readme,
14 | long_description_content_type="text/markdown",
15 | url="https://github.com/nikitadurasov/masksembles",
16 | project_urls={
17 | "Issues": "https://github.com/nikitadurasov/masksembles/issues",
18 | "Documentation": "https://github.com/nikitadurasov/masksembles#readme",
19 | },
20 | license="MIT",
21 | packages=find_packages(exclude=("tests", "test", "notebooks")),
22 | python_requires=">=3.8",
23 | install_requires=[
24 | "numpy>=1.20",
25 | ],
26 | extras_require={
27 | "torch": ["torch>=1.12"],
28 | "tensorflow": ["tensorflow>=2.9"],
29 | "dev": ["pytest", "black", "flake8"],
30 | },
31 | include_package_data=True,
32 | classifiers=[
33 | "Programming Language :: Python :: 3",
34 | "License :: OSI Approved :: MIT License",
35 | "Operating System :: OS Independent",
36 | ],
37 | keywords=["uncertainty", "ensembles", "deep-learning", "pytorch", "tensorflow"],
38 | )
39 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | # Custom ignores
132 | .ipynb_checkpoints
133 | .idea
134 | dev.ipynb
--------------------------------------------------------------------------------
/masksembles/keras.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from . import common
3 |
4 |
5 | class Masksembles2D(tf.keras.layers.Layer):
6 | """
7 |     :class:`Masksembles2D` is a high-level class that implements the Masksembles approach
8 |     for 2-dimensional inputs (similar to :class:`tensorflow.keras.layers.SpatialDropout2D`).
9 |
10 | :param n: int, number of masks
11 |     :param scale: float, scale parameter similar to *S* in [1]. Larger values decrease \
12 |         correlations between subnetworks but at the same time decrease the capacity of every individual model.
13 |
14 | Shape:
15 | * Input: (N, H, W, C)
16 | * Output: (N, H, W, C) (same shape as input)
17 |
18 | Examples:
19 |
20 | >>> m = Masksembles2D(4, 2.0)
21 | >>> inputs = tf.ones([4, 28, 28, 16])
22 | >>> output = m(inputs)
23 |
24 | References:
25 |
26 | [1] `Masksembles for Uncertainty Estimation`,
27 | Nikita Durasov, Timur Bagautdinov, Pierre Baque, Pascal Fua
28 |
29 | """
30 |
31 | def __init__(self, n: int, scale: float):
32 | super(Masksembles2D, self).__init__()
33 |
34 | self.n = n
35 | self.scale = scale
36 |
37 | def build(self, input_shape):
38 | channels = input_shape[-1]
39 | masks = common.generation_wrapper(channels, self.n, self.scale)
40 | self.masks = self.add_weight("masks",
41 | shape=masks.shape,
42 | trainable=False,
43 | dtype="float32")
44 | self.masks.assign(masks)
45 |
46 | def call(self, inputs, training=False):
47 | # inputs : [N, H, W, C]
48 | # masks : [M, C]
49 | x = tf.stack(tf.split(inputs, self.n))
50 | # x : [M, N // M, H, W, C]
51 | # masks : [M, 1, 1, 1, C]
52 | x = x * self.masks[:, tf.newaxis, tf.newaxis, tf.newaxis]
53 | x = tf.concat(tf.split(x, self.n), axis=1)
54 | return tf.squeeze(x, axis=0)
55 |
56 |
57 | class Masksembles1D(tf.keras.layers.Layer):
58 | """
59 |     :class:`Masksembles1D` is a high-level class that implements the Masksembles approach
60 | for 1-dimensional inputs (similar to :class:`tensorflow.keras.layers.Dropout`).
61 |
62 | :param n: int, number of masks
63 |     :param scale: float, scale parameter similar to *S* in [1]. Larger values decrease \
64 |         correlations between subnetworks but at the same time decrease the capacity of every individual model.
65 |
66 | Shape:
67 | * Input: (N, C)
68 | * Output: (N, C) (same shape as input)
69 |
70 | Examples:
71 |
72 | >>> m = Masksembles1D(4, 2.0)
73 | >>> inputs = tf.ones([4, 16])
74 | >>> output = m(inputs)
75 |
76 |
77 | References:
78 |
79 | [1] `Masksembles for Uncertainty Estimation`,
80 | Nikita Durasov, Timur Bagautdinov, Pierre Baque, Pascal Fua
81 |
82 | """
83 |
84 | def __init__(self, n: int, scale: float):
85 | super(Masksembles1D, self).__init__()
86 |
87 | self.n = n
88 | self.scale = scale
89 |
90 | def build(self, input_shape):
91 | channels = input_shape[-1]
92 | masks = common.generation_wrapper(channels, self.n, self.scale)
93 | self.masks = self.add_weight("masks",
94 | shape=masks.shape,
95 | trainable=False,
96 | dtype="float32")
97 | self.masks.assign(masks)
98 |
99 | def call(self, inputs, training=False):
100 | x = tf.stack(tf.split(inputs, self.n))
101 | x = x * self.masks[:, tf.newaxis]
102 | x = tf.concat(tf.split(x, self.n), axis=1)
103 | return tf.squeeze(x, axis=0)
104 |
--------------------------------------------------------------------------------
/masksembles/torch.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | from . import common
5 |
6 |
7 | class Masksembles2D(nn.Module):
8 | """
9 |     :class:`Masksembles2D` is a high-level class that implements the Masksembles approach
10 | for 2-dimensional inputs (similar to :class:`torch.nn.Dropout2d`).
11 |
12 | :param channels: int, number of channels used in masks.
13 | :param n: int, number of masks
14 |     :param scale: float, scale parameter similar to *S* in [1]. Larger values decrease \
15 |         correlations between subnetworks but at the same time decrease the capacity of every individual model.
16 |
17 | Shape:
18 | * Input: (N, C, H, W)
19 | * Output: (N, C, H, W) (same shape as input)
20 |
21 | Examples:
22 |
23 | >>> m = Masksembles2D(16, 4, 2.0)
24 | >>> input = torch.ones([4, 16, 28, 28])
25 | >>> output = m(input)
26 |
27 | References:
28 |
29 | [1] `Masksembles for Uncertainty Estimation`,
30 | Nikita Durasov, Timur Bagautdinov, Pierre Baque, Pascal Fua
31 |
32 | """
33 | def __init__(self, channels: int, n: int, scale: float):
34 | super().__init__()
35 | self.channels, self.n, self.scale = channels, n, scale
36 | masks_np = common.generation_wrapper(channels, n, scale) # numpy float64 by default
37 | masks = torch.as_tensor(masks_np, dtype=torch.float32) # make float32 here
38 | self.register_buffer('masks', masks, persistent=False) # not trainable, moves with .to()
39 |
40 | def forward(self, inputs):
41 | # make sure masks match dtype/device (usually already true because of buffer, but safe)
42 | masks = self.masks.to(dtype=inputs.dtype, device=inputs.device)
43 |
44 |         # note: the batch size must be divisible by n, otherwise the chunks
45 |         # below have unequal sizes and torch.cat along dim=1 fails
46 |         chunks = torch.chunk(inputs.unsqueeze(1), self.n, dim=0)  # n chunks of [B // n, 1, C, H, W]
47 |         x = torch.cat(chunks, dim=1).permute(1, 0, 2, 3, 4)  # [n, B // n, C, H, W]
48 | x = x * masks.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # broadcast masks
49 | x = torch.cat(torch.split(x, 1, dim=0), dim=1)
50 | return x.squeeze(0)
51 |
52 |
53 |
54 | class Masksembles1D(nn.Module):
55 | """
56 |     :class:`Masksembles1D` is a high-level class that implements the Masksembles approach
57 | for 1-dimensional inputs (similar to :class:`torch.nn.Dropout`).
58 |
59 | :param channels: int, number of channels used in masks.
60 | :param n: int, number of masks
61 |     :param scale: float, scale parameter similar to *S* in [1]. Larger values decrease \
62 |         correlations between subnetworks but at the same time decrease the capacity of every individual model.
63 |
64 | Shape:
65 | * Input: (N, C)
66 | * Output: (N, C) (same shape as input)
67 |
68 | Examples:
69 |
70 | >>> m = Masksembles1D(16, 4, 2.0)
71 | >>> input = torch.ones([4, 16])
72 | >>> output = m(input)
73 |
74 |
75 | References:
76 |
77 | [1] `Masksembles for Uncertainty Estimation`,
78 | Nikita Durasov, Timur Bagautdinov, Pierre Baque, Pascal Fua
79 |
80 | """
81 |
82 | def __init__(self, channels: int, n: int, scale: float):
83 | super().__init__()
84 | self.channels, self.n, self.scale = channels, n, scale
85 | masks_np = common.generation_wrapper(channels, n, scale)
86 | masks = torch.as_tensor(masks_np, dtype=torch.float32)
87 | self.register_buffer('masks', masks, persistent=False)
88 |
89 | def forward(self, inputs):
90 | masks = self.masks.to(dtype=inputs.dtype, device=inputs.device)
91 |
92 |         # note: the batch size must be divisible by n (see Masksembles2D.forward)
93 |         chunks = torch.chunk(inputs.unsqueeze(1), self.n, dim=0)  # n chunks of [B // n, 1, C]
94 |         x = torch.cat(chunks, dim=1).permute(1, 0, 2)  # [n, B // n, C]
95 | x = x * masks.unsqueeze(1)
96 | x = torch.cat(torch.split(x, 1, dim=0), dim=1)
97 | return x.squeeze(0)
98 |
--------------------------------------------------------------------------------
/masksembles/common.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def generate_masks_(m: int, n: int, s: float) -> np.ndarray:
5 | """Generates set of binary masks with properties defined by n, m, s params.
6 |
7 |     Results of this function are stochastic, that is, calls with the same set
8 |     of arguments might generate outputs of different shapes. See the generate_masks
9 |     and generation_wrapper functions for more deterministic behaviour.
10 |
11 | :param m: int, number of ones in each mask
12 | :param n: int, number of masks in the set
13 | :param s: float, scale param controls overlap of generated masks
14 | :return: np.ndarray, matrix of binary vectors
15 | """
16 |
17 | total_positions = int(m * s)
18 | masks = []
19 |
20 | for _ in range(n):
21 | new_vector = np.zeros([total_positions])
22 | idx = np.random.choice(range(total_positions), m, replace=False)
23 | new_vector[idx] = 1
24 | masks.append(new_vector)
25 |
26 | masks = np.array(masks)
27 | # drop useless positions
28 | masks = masks[:, ~np.all(masks == 0, axis=0)]
29 | return masks
30 |
31 |
32 | def generate_masks(m: int, n: int, s: float) -> np.ndarray:
33 | """Generates set of binary masks with properties defined by n, m, s params.
34 |
35 |     Resulting masks are required to have a fixed feature size, as described in [1].
36 |     Since the mask generation process is stochastic, this function evaluates
37 |     generate_masks_ repeatedly until the expected size is obtained.
38 |
39 | :param m: int, number of ones in each mask
40 | :param n: int, number of masks in the set
41 | :param s: float, scale param controls overlap of generated masks
42 | :return: np.ndarray, matrix of binary vectors
43 |
44 | References
45 |
46 | [1] `Masksembles for Uncertainty Estimation: Supplementary Material`,
47 | Nikita Durasov, Timur Bagautdinov, Pierre Baque, Pascal Fua
48 | """
49 |
50 | masks = generate_masks_(m, n, s)
51 | # hardcoded formula for expected size, check reference
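    # derivation sketch: there are m * s candidate positions in total; each mask
    # leaves a given position at zero with probability 1 - 1/s, so a column
    # survives the zero-column drop in generate_masks_ with probability
    # 1 - (1 - 1/s) ** n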
52 | expected_size = int(m * s * (1 - (1 - 1 / s) ** n))
53 | while masks.shape[1] != expected_size:
54 | masks = generate_masks_(m, n, s)
55 | return masks
56 |
57 |
58 | def generation_wrapper(c: int, n: int, scale: float) -> np.ndarray:
59 | """Generates set of binary masks with properties defined by c, n, scale params.
60 |
61 |     Allows generating mask sets with a predefined number of features c. Particularly
62 |     convenient in torch-like layers where the shapes of input tensors need to be
63 |     defined beforehand.
64 |
65 | :param c: int, number of channels in generated masks
66 | :param n: int, number of masks in the set
67 | :param scale: float, scale param controls overlap of generated masks
68 | :return: np.ndarray, matrix of binary vectors
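
    Example (mask values are random; the shape is guaranteed by construction):

    >>> masks = generation_wrapper(16, 4, 2.0)
    >>> masks.shape
    (4, 16)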
69 | """
70 |
71 |     if c < 10:
72 |         raise ValueError("Masksembles approach cannot be used in setups where the "
73 |                          f"number of channels is less than 10. Current value is (channels={c}). "
74 |                          "Please increase the number of features in your layer or remove this "
75 |                          "particular instance of Masksembles from your architecture.")
76 |
77 |     if scale > 6.:
78 |         raise ValueError("Masksembles approach cannot be used in setups where the "
79 |                          f"scale parameter is larger than 6. Current value is (scale={scale}).")
80 |
81 | # inverse formula for number of active features in masks
82 | active_features = int(int(c) / (scale * (1 - (1 - 1 / scale) ** n)))
83 |
84 |     # binary-search the scale so that the generated masks have exactly c features
85 | max_iter = 1000
86 |
87 |     low = np.max([scale * 0.8, 1.0])
88 |     high = scale * 1.2
89 | 
90 |     for _ in range(max_iter):
91 |         mid = (low + high) / 2
92 |         masks = generate_masks(active_features, n, mid)
93 |         if masks.shape[-1] == c:
94 |             break
95 |         elif masks.shape[-1] > c:
96 |             high = mid
97 |         else:
98 |             low = mid
99 |
100 | if masks.shape[-1] != c:
101 | raise ValueError("generation_wrapper function failed to generate masks with "
102 |                          "the requested number of features. Please try changing the scale parameter.")
103 |
104 | return masks
105 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Masksembles for Uncertainty Estimation
2 |
3 |
8 |
9 | 
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 | ### [Project Page](https://nikitadurasov.github.io/projects/masksembles/) | [Paper](https://arxiv.org/abs/2012.08334) | [Video Explanation](https://www.youtube.com/watch?v=YWKVdn3kLp0)
40 |
41 | [](https://colab.research.google.com/github/nikitadurasov/masksembles/blob/main/notebooks/MNIST_Masksembles.ipynb)
42 |
43 |
46 |
47 |
48 | ## Why Masksembles?
49 |
50 | **Uncertainty Estimation** is one of the most important and critical tasks for modern neural networks and deep learning.
51 | Uncertainty has a long list of potential applications: safety-critical systems, active learning, domain adaptation,
52 | reinforcement learning, etc.
53 |
54 | **Masksembles** is a **simple** and **easy-to-use** drop-in method with performance on par with Deep Ensembles at a fraction of the cost.
55 | It requires *almost* no changes to your original model: you only need to add special intermediate layers.
56 |
57 | [](https://youtu.be/YWKVdn3kLp0)
58 |
59 | ### [Watch this video on YouTube](https://youtu.be/YWKVdn3kLp0)
60 |
61 | ## Installation
62 |
63 | To install this package, use:
64 |
65 | ```bash
66 | pip install masksembles
67 | ```
68 | or
69 | ```bash
70 | pip install git+https://github.com/nikitadurasov/masksembles
71 | ```
72 |
73 | In addition, Masksembles requires at least one of the backends: PyTorch or TensorFlow 2 / Keras.
74 | Please follow the official installation instructions for [torch](https://pytorch.org/) or [tensorflow](https://www.tensorflow.org/install)
75 | accordingly.
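
Since `setup.py` declares optional extras for both backends (`extras_require`), a backend can also be pulled in together with the package, assuming you install from a source that includes these extras:

```bash
pip install "masksembles[torch]"       # PyTorch backend
pip install "masksembles[tensorflow]"  # TensorFlow / Keras backend
```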
76 |
77 |
78 | ## Usage
79 |
80 | [comment]: <> (In masksembles module you could find implementations of "Masksembles{1|2|3}D" that)
81 |
82 | [comment]: <> (support different shapes of input vectors (1, 2 and 3-dimentional accordingly))
83 |
84 | This package provides implementations of `Masksembles{1|2|3}D` layers in `masksembles.{torch|keras}`,
85 | where `{1|2|3}` refers to the dimensionality of the input tensors (1-, 2- and 3-dimensional,
86 | respectively).
87 |
88 | * `Masksembles1D`: works with 1-dim inputs, `[B, C]` shaped tensors
89 | * `Masksembles2D`: works with 2-dim inputs, `[B, H, W, C]` (keras) or `[B, C, H, W]` (torch) shaped tensors
90 | * `Masksembles3D`: TBD
91 |
92 | In a nutshell, Masksembles multiplies inputs by binary masks channel-wise. For a more efficient
93 | implementation we follow an approach similar to [this](https://arxiv.org/abs/2002.06715) one. Therefore, after inference
94 | `outputs[:B // N]` stores the results of the first submodel, `outputs[B // N : 2 * B // N]` those of the second, and so on (see the sketch below).
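
For example, here is a minimal sketch of reading one prediction per submodel (hypothetical setup: a trained Keras `model` built with `N = 4` masks and a single image `img` of shape `[H, W, C]`, mirroring the MNIST notebooks):

```python
import numpy as np

# tile one input across the submodels: row i of the batch is routed
# through submodel i, so each output row is one submodel's prediction
n = 4
inputs = np.tile(img[None], [n, 1, 1, 1])  # [N, H, W, C]
outputs = np.asarray(model(inputs))        # [N, num_classes]
disagreement = outputs.std(axis=0)         # high std => high uncertainty
```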
95 | ### Torch
96 |
97 | ```python
98 | import torch
99 | from masksembles.torch import Masksembles1D
100 |
101 | layer = Masksembles1D(10, 4, 2.)
102 | layer(torch.ones([4, 10]))
103 | ```
104 | ```bash
105 | tensor([[0., 1., 0., 0., 1., 0., 1., 1., 1., 1.],
106 | [0., 0., 1., 1., 1., 1., 0., 0., 1., 1.],
107 | [1., 0., 1., 1., 0., 0., 1., 0., 1., 1.],
108 |         [1., 0., 0., 1., 1., 1., 0., 1., 1., 0.]])
109 |
110 | ```
111 |
112 | ### Tensorflow / Keras
113 |
114 | ```python
115 | import tensorflow as tf
116 | from masksembles.keras import Masksembles1D
117 |
118 | layer = Masksembles1D(4, 2.)
119 | layer(tf.ones([4, 10]))
120 | ```
121 | ```bash
122 |
127 | ```
128 |
129 | ### Model example
130 | ```python
131 | from tensorflow import keras
132 | from tensorflow.keras import layers
133 | from masksembles.keras import Masksembles1D, Masksembles2D
134 | 
135 | input_shape = (28, 28, 1)  # e.g. MNIST images
136 | num_classes = 10
137 | 
138 | model = keras.Sequential(
139 |     [
140 |         keras.Input(shape=input_shape),
141 |         layers.Conv2D(32, kernel_size=(3, 3), activation="elu"),
142 |         Masksembles2D(4, 2.0),
143 |         layers.MaxPooling2D(pool_size=(2, 2)),
144 | 
145 |         layers.Conv2D(64, kernel_size=(3, 3), activation="elu"),
146 |         Masksembles2D(4, 2.0),
147 |         layers.MaxPooling2D(pool_size=(2, 2)),
148 | 
149 |         layers.Flatten(),
150 |         Masksembles1D(4, 2.),
151 |         layers.Dense(num_classes, activation="softmax"),
152 |     ]
153 | )
154 | ```
151 |
152 | ## Citation
153 | If you found this work useful for your projects, please don't forget to cite it.
154 | ```
155 | @inproceedings{Durasov21,
156 | author = {N. Durasov and T. Bagautdinov and P. Baque and P. Fua},
157 | title = {{Masksembles for Uncertainty Estimation}},
158 | booktitle = CVPR,
159 | year = 2021
160 | }
161 | ```
162 |
--------------------------------------------------------------------------------
/notebooks/MNIST_Masksembles_tensoflow.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "kernelspec": {
6 | "display_name": "Python 3",
7 | "language": "python",
8 | "name": "python3"
9 | },
10 | "language_info": {
11 | "codemirror_mode": {
12 | "name": "ipython",
13 | "version": 2
14 | },
15 | "file_extension": ".py",
16 | "mimetype": "text/x-python",
17 | "name": "python",
18 | "nbconvert_exporter": "python",
19 | "pygments_lexer": "ipython2",
20 | "version": "2.7.6"
21 | },
22 | "colab": {
23 | "name": "MNIST_Masksembles.ipynb",
24 | "provenance": [],
25 | "toc_visible": true
26 | }
27 | },
28 | "cells": [
29 | {
30 | "cell_type": "code",
31 | "metadata": {
32 | "collapsed": true,
33 | "id": "qrpn5UV2BYLF"
34 | },
35 | "source": [
36 |         "!pip install --upgrade git+https://github.com/nikitadurasov/masksembles\n",
37 | "!wget https://github.com/nikitadurasov/masksembles/raw/main/images/complex_sample_mnist.npy"
38 | ],
39 | "execution_count": 124,
40 | "outputs": []
41 | },
42 | {
43 | "cell_type": "markdown",
44 | "metadata": {
45 | "id": "bHtAoRyhBiKr"
46 | },
47 | "source": [
48 | "# MNIST "
49 | ]
50 | },
51 | {
52 | "cell_type": "markdown",
53 | "metadata": {
54 | "id": "eFvywyCZBlsh"
55 | },
56 | "source": [
57 | "## Keras"
58 | ]
59 | },
60 | {
61 | "cell_type": "code",
62 | "metadata": {
63 | "id": "bnCBbU95BuQ0"
64 | },
65 | "source": [
66 | "import numpy as np\n",
67 | "from tensorflow import keras\n",
68 | "from tensorflow.keras import layers\n",
69 | "import tensorflow as tf\n",
70 | "\n",
71 | "import matplotlib.pyplot as plt"
72 | ],
73 | "execution_count": 113,
74 | "outputs": []
75 | },
76 | {
77 | "cell_type": "code",
78 | "metadata": {
79 | "id": "ZZNyHlDLCOt5"
80 | },
81 | "source": [
82 | "from masksembles.keras import Masksembles2D, Masksembles1D"
83 | ],
84 | "execution_count": 3,
85 | "outputs": []
86 | },
87 | {
88 | "cell_type": "code",
89 | "metadata": {
90 | "colab": {
91 | "base_uri": "https://localhost:8080/"
92 | },
93 | "id": "Gavnqj28CYJ7",
94 | "outputId": "c89047ee-ee6c-4ed5-cc2c-8178d8d14863"
95 | },
96 | "source": [
97 | "# Model / data parameters\n",
98 | "num_classes = 10\n",
99 | "input_shape = (28, 28, 1)\n",
100 | "\n",
101 | "# the data, split between train and test sets\n",
102 | "(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()\n",
103 | "\n",
104 | "# Scale images to the [0, 1] range\n",
105 | "x_train = x_train.astype(\"float32\") / 255\n",
106 | "x_test = x_test.astype(\"float32\") / 255\n",
107 | "# Make sure images have shape (28, 28, 1)\n",
108 | "x_train = np.expand_dims(x_train, -1)\n",
109 | "x_test = np.expand_dims(x_test, -1)\n",
110 | "print(\"x_train shape:\", x_train.shape)\n",
111 | "print(x_train.shape[0], \"train samples\")\n",
112 | "print(x_test.shape[0], \"test samples\")\n",
113 | "\n",
114 | "\n",
115 | "# convert class vectors to binary class matrices\n",
116 | "y_train = keras.utils.to_categorical(y_train, num_classes)\n",
117 | "y_test = keras.utils.to_categorical(y_test, num_classes)"
118 | ],
119 | "execution_count": 6,
120 | "outputs": [
121 | {
122 | "output_type": "stream",
123 | "text": [
124 | "x_train shape: (60000, 28, 28, 1)\n",
125 | "60000 train samples\n",
126 | "10000 test samples\n"
127 | ],
128 | "name": "stdout"
129 | }
130 | ]
131 | },
132 | {
133 | "cell_type": "markdown",
134 | "metadata": {
135 | "id": "Hr5xjOSffo7x"
136 | },
137 | "source": [
138 |         "In order to transform a regular model into a Masksembles model, one should add Masksembles2D or Masksembles1D layers to it. The general recommendation is to insert these layers right before or after convolutional layers. \n",
139 |         "\n",
140 |         "In the example below we'll use both Masksembles2D and Masksembles1D layers applied after convolutions. "
141 | ]
142 | },
143 | {
144 | "cell_type": "code",
145 | "metadata": {
146 | "colab": {
147 | "base_uri": "https://localhost:8080/"
148 | },
149 | "id": "7jw12ljXCaMg",
150 | "outputId": "f324695b-c4f7-4624-be34-3bb22cd687a0"
151 | },
152 | "source": [
153 | "model = keras.Sequential(\n",
154 | " [\n",
155 | " keras.Input(shape=input_shape),\n",
156 | " layers.Conv2D(32, kernel_size=(3, 3), activation=\"elu\"),\n",
157 | " Masksembles2D(4, 2.0), # adding Masksembles2D\n",
158 | " layers.MaxPooling2D(pool_size=(2, 2)),\n",
159 | " \n",
160 | " layers.Conv2D(64, kernel_size=(3, 3), activation=\"elu\"),\n",
161 | " Masksembles2D(4, 2.0), # adding Masksembles2D\n",
162 | " layers.MaxPooling2D(pool_size=(2, 2)),\n",
163 | " \n",
164 | " layers.Flatten(),\n",
165 | " Masksembles1D(4, 2.), # adding Masksembles1D\n",
166 | " layers.Dense(num_classes, activation=\"softmax\"),\n",
167 | " ]\n",
168 | ")\n",
169 | "\n",
170 | "model.summary()"
171 | ],
172 | "execution_count": 29,
173 | "outputs": [
174 | {
175 | "output_type": "stream",
176 | "text": [
177 | "Model: \"sequential_3\"\n",
178 | "_________________________________________________________________\n",
179 | "Layer (type) Output Shape Param # \n",
180 | "=================================================================\n",
181 | "conv2d_6 (Conv2D) (None, 26, 26, 32) 320 \n",
182 | "_________________________________________________________________\n",
183 | "masksembles2d_6 (Masksembles (None, 26, 26, 32) 128 \n",
184 | "_________________________________________________________________\n",
185 | "max_pooling2d_6 (MaxPooling2 (None, 13, 13, 32) 0 \n",
186 | "_________________________________________________________________\n",
187 | "conv2d_7 (Conv2D) (None, 11, 11, 64) 18496 \n",
188 | "_________________________________________________________________\n",
189 | "masksembles2d_7 (Masksembles (None, 11, 11, 64) 256 \n",
190 | "_________________________________________________________________\n",
191 | "max_pooling2d_7 (MaxPooling2 (None, 5, 5, 64) 0 \n",
192 | "_________________________________________________________________\n",
193 | "flatten_3 (Flatten) (None, 1600) 0 \n",
194 | "_________________________________________________________________\n",
195 | "masksembles1d_2 (Masksembles (None, 1600) 6400 \n",
196 | "_________________________________________________________________\n",
197 | "dense_1 (Dense) (None, 10) 16010 \n",
198 | "=================================================================\n",
199 | "Total params: 41,610\n",
200 | "Trainable params: 34,826\n",
201 | "Non-trainable params: 6,784\n",
202 | "_________________________________________________________________\n"
203 | ],
204 | "name": "stdout"
205 | }
206 | ]
207 | },
208 | {
209 | "cell_type": "markdown",
210 | "metadata": {
211 | "id": "DP9Km-bGigRC"
212 | },
213 | "source": [
214 |         "Training a Masksembles model is no different from training a regular one, so we just use the standard Keras fit API."
215 | ]
216 | },
217 | {
218 | "cell_type": "code",
219 | "metadata": {
220 | "colab": {
221 | "base_uri": "https://localhost:8080/"
222 | },
223 | "id": "44akdkDYDDQl",
224 | "outputId": "1e31da86-663a-4fdc-a344-857dedef6e4f"
225 | },
226 | "source": [
227 | "batch_size = 128\n",
228 | "epochs = 5\n",
229 | "\n",
230 | "model.compile(loss=\"categorical_crossentropy\", optimizer=\"adam\", metrics=[\"accuracy\"])\n",
231 | "model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, validation_split=0.1)"
232 | ],
233 | "execution_count": 30,
234 | "outputs": [
235 | {
236 | "output_type": "stream",
237 | "text": [
238 | "Epoch 1/5\n",
239 | "422/422 [==============================] - 54s 126ms/step - loss: 1.0014 - accuracy: 0.7048 - val_loss: 0.1698 - val_accuracy: 0.9520\n",
240 | "Epoch 2/5\n",
241 | "422/422 [==============================] - 53s 125ms/step - loss: 0.1825 - accuracy: 0.9455 - val_loss: 0.1100 - val_accuracy: 0.9693\n",
242 | "Epoch 3/5\n",
243 | "422/422 [==============================] - 53s 125ms/step - loss: 0.1235 - accuracy: 0.9628 - val_loss: 0.0824 - val_accuracy: 0.9773\n",
244 | "Epoch 4/5\n",
245 | "422/422 [==============================] - 53s 126ms/step - loss: 0.0961 - accuracy: 0.9706 - val_loss: 0.0791 - val_accuracy: 0.9767\n",
246 | "Epoch 5/5\n",
247 | "422/422 [==============================] - 53s 125ms/step - loss: 0.0850 - accuracy: 0.9742 - val_loss: 0.0674 - val_accuracy: 0.9827\n"
248 | ],
249 | "name": "stdout"
250 | },
251 | {
252 | "output_type": "execute_result",
253 | "data": {
254 | "text/plain": [
255 | ""
256 | ]
257 | },
258 | "metadata": {
259 | "tags": []
260 | },
261 | "execution_count": 30
262 | }
263 | ]
264 | },
265 | {
266 | "cell_type": "markdown",
267 | "metadata": {
268 | "id": "Wh3pGxWujZ0Z"
269 | },
270 | "source": [
271 |         "After training we can check that all of Masksembles' submodels produce similar predictions for training samples."
272 | ]
273 | },
274 | {
275 | "cell_type": "code",
276 | "metadata": {
277 | "id": "T3GNT8Z5kItR"
278 | },
279 | "source": [
280 |         "img = x_train[0]  # a sample image from the training set"
281 | ],
282 | "execution_count": 118,
283 | "outputs": []
284 | },
285 | {
286 | "cell_type": "code",
287 | "metadata": {
288 | "id": "o0K51rvqkMIs",
289 | "outputId": "2baad7d7-c8b3-41c2-c8d5-857e7754fb07",
290 | "colab": {
291 | "base_uri": "https://localhost:8080/",
292 | "height": 265
293 | }
294 | },
295 | "source": [
296 | "plt.imshow(img[..., 0])\n",
297 | "plt.show()"
298 | ],
299 | "execution_count": 119,
300 | "outputs": [
301 | {
302 | "output_type": "display_data",
303 | "data": {
304 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAOZ0lEQVR4nO3dbYxc5XnG8euKbezamMQbB9chLjjgFAg0Jl0ZEBZQobgOqgSoCsSKIkJpnSY4Ca0rQWlV3IpWbpUQUUqRTHExFS+BBIQ/0CTUQpCowWWhBgwEDMY0NmaNWYENIX5Z3/2w42iBnWeXmTMv3vv/k1Yzc+45c24NXD5nznNmHkeEAIx/H+p0AwDag7ADSRB2IAnCDiRB2IEkJrZzY4d5ckzRtHZuEkjlV3pbe2OPR6o1FXbbiyVdJ2mCpH+LiJWl50/RNJ3qc5rZJICC9bGubq3hw3jbEyTdIOnzkk6UtMT2iY2+HoDWauYz+wJJL0TE5ojYK+lOSedV0xaAqjUT9qMk/WLY4621Ze9ie6ntPtt9+7Snic0BaEbLz8ZHxKqI6I2I3kma3OrNAaijmbBvkzRn2ONP1JYB6ELNhP1RSfNsz7V9mKQvSlpbTVsAqtbw0FtE7Le9TNKPNDT0tjoinq6sMwCVamqcPSLul3R/Rb0AaCEulwWSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiCJpmZxRffzxPJ/4gkfm9nS7T/3F8fUrQ1OPVBc9+hjdxTrU7/uYv3Vaw+rW3u893vFdXcOvl2sn3r38mL9uD9/pFjvhKbCbnuLpN2SBiXtj4jeKpoCUL0q9uy/FxE7K3gdAC3EZ3YgiWbDHpJ+bPsx20tHeoLtpbb7bPft054mNwegUc0exi+MiG22j5T0gO2fR8TDw58QEaskrZKkI9wTTW4PQIOa2rNHxLba7Q5J90paUEVTAKrXcNhtT7M9/eB9SYskbayqMQDVauYwfpake20ffJ3bI+KHlXQ1zkw4YV6xHpMnFeuvnPWRYv2d0+qPCfd8uDxe/JPPlMebO+k/fzm9WP/Hf1lcrK8/+fa6tZf2vVNcd2X/54r1j//k0PtE2nDYI2KzpM9U2AuAFmLoDUiCsANJEHYgCcIOJEHYgST4imsFBs/+bLF+7S03FOufmlT/q5jj2b4YLNb/5vqvFOsT3y4Pf51+97K6tenb9hfXnbyzPDQ3tW99sd6N2LMDSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKMs1dg8nOvFOuP/WpOsf6pSf1VtlOp5dtPK9Y3v1X+Kepbjv1+3dqbB8rj5LP++b+L9VY69L7AOjr27EAShB1IgrADSRB2IAnCDiRB2IEkCDuQhCPaN6J4hHviVJ/Ttu11i4FLTi/Wdy0u/9zzhCcPL9af+Pr1H7ing67Z+TvF+qNnlcfRB994s1iP0+v/APGWbxZX1dwlT5SfgPdZH+u0KwZGnMuaPTuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJME4exeYMPOjxfrg6wPF+ku31x8rf/rM1cV1F/zDN4r1I2/o3HfK8cE1Nc5ue7XtHbY3DlvWY/sB25tqtzOqbBhA9cZyGH+LpPfOen+lpHURMU/SutpjAF1s1LBHxMOS3nsceZ6kNbX7aySdX3FfACrW6G/QzYqI7bX7r0qaVe+JtpdKWipJUzS1wc0BaFbTZ+Nj6Axf3bN8EbEqInojoneSJje7OQANajTs/bZnS1Ltdkd1LQFohUbDvlbSxbX7F0u6r5p2ALTKqJ/Zbd8h6WxJM21vlXS1pJWS7rJ9qaSXJV3YyibHu8Gdrze1/r5djc/v/ukvPVOsv3bjhPILHCjPsY7uMWrYI2JJnRJXxwCHEC6XBZIg7EAShB1IgrADSRB2IAmmbB4HTrji+bq1S04uD5r8+9HrivWzvnBZsT79e48U6+ge7NmBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnG2ceB0rTJr3/thOK6/7f2nWL9ymtuLdb/8sILivX43w/Xrc35+58V11Ubf+Y8A/bsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEUzYnN/BHpxfrt1397WJ97sQpDW/707cuK9bn3bS9WN+/eUvD2x6vmpqyGcD4QNiBJAg7kARhB5Ig7EAShB1IgrADSTDOjqI4Y36xfsTKrcX6HZ/8UcPbPv7BPy7Wf/tv63+PX5IGN21ueNuHqqbG2W2vtr3D9sZhy1bY3mZ7Q+3v3CobBlC9sRzG3yJp8QjLvxsR82t/91fbFoCqjRr2iHhY0kAbegHQQs2coFtm+8naYf6Mek+yvdR2n+2+fdrTxOYANKPRsN8o6VhJ8yVtl/Sdek+MiFUR0RsRvZM0ucHNAWhWQ2GPiP6IGIyIA5JukrSg2rYAVK2hsNuePezhBZI21nsugO4w6ji77TsknS1ppqR+SVfXHs+XFJK2SPpqRJS/fCzG2cejCbOOLNZfuei4urX1V1xXXPdDo+yLvvTSomL9zYWvF+vjUWmcfdRJIiJiyQiLb266KwBtxeWyQBKEHUiCsANJEHYgCcIOJMFXXNExd20tT9k81YcV67+MvcX6H3zj8vqvfe/64rqHKn5KGgBhB7Ig7EAShB1IgrADSRB2IAnCDiQx6rfekNuBheWfkn7xC+Upm0+av6VubbRx9NFcP3BKsT71vr6mXn+8Yc8OJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kwzj7OufekYv35b5bHum86Y02xfuaU8nfKm7En9hXrjwzMLb/AgVF/3TwV9uxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kATj7IeAiXOPLtZfvOTjdWsrLrqzuO4fHr6zoZ6qcFV/b7H+0HWnFesz1pR/dx7vNuqe3fYc2w/afsb207a/VVveY/sB25tqtzNa3y6ARo3lMH6/pOURcaKk0yRdZvtESVdKWhcR8yStqz0G0KVGDXtEbI+Ix2v3d0t6VtJRks6TdPBayjWSzm9VkwCa94E+s9s+RtIpktZLmhURBy8+flXSrDrrLJW0VJKmaGqjfQJo0pjPxts+XNIPJF0eEbuG12JodsgRZ4iMiFUR0RsRvZM0ualmATRuTGG3PUlDQb8tIu6pLe63PbtWny1pR2taBFCFUQ/jbVvSzZKejYhrh5XWSrpY0sra7X0t6XAcmHjMbxXrb/7u7GL9or/7YbH+px+5p1hvpeXby8NjP/vX+sNrPbf8T3HdGQcYWqvSWD6znyHpy5Kesr2htuwqDYX8LtuXSnpZ0oWtaRFAFUYNe0T8VNKIk7tLOqfadgC0CpfLAkkQdiAJwg4kQdiBJAg7kARfcR2jibN/s25tYPW04rpfm/tQsb5ken9DPVVh2baFxfrjN5anbJ75/Y3Fes9uxsq7BXt2IAnCDiRB2IE
kCDuQBGEHkiDsQBKEHUgizTj73t8v/2zx3j8bKNavOu7+urVFv/F2Qz1VpX/wnbq1M9cuL657/F//vFjveaM8Tn6gWEU3Yc8OJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0mkGWffcn7537XnT767Zdu+4Y1ji/XrHlpUrHuw3o/7Djn+mpfq1ub1ry+uO1isYjxhzw4kQdiBJAg7kARhB5Ig7EAShB1IgrADSTgiyk+w50i6VdIsSSFpVURcZ3uFpD+R9FrtqVdFRP0vfUs6wj1xqpn4FWiV9bFOu2JgxAszxnJRzX5JyyPicdvTJT1m+4Fa7bsR8e2qGgXQOmOZn327pO21+7ttPyvpqFY3BqBaH+gzu+1jJJ0i6eA1mMtsP2l7te0ZddZZarvPdt8+7WmqWQCNG3PYbR8u6QeSLo+IXZJulHSspPka2vN/Z6T1ImJVRPRGRO8kTa6gZQCNGFPYbU/SUNBvi4h7JCki+iNiMCIOSLpJ0oLWtQmgWaOG3bYl3Szp2Yi4dtjy2cOedoGk8nSeADpqLGfjz5D0ZUlP2d5QW3aVpCW252toOG6LpK+2pEMAlRjL2fifShpp3K44pg6gu3AFHZAEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IIlRf0q60o3Zr0l6ediimZJ2tq2BD6Zbe+vWviR6a1SVvR0dER8bqdDWsL9v43ZfRPR2rIGCbu2tW/uS6K1R7eqNw3ggCcIOJNHpsK/q8PZLurW3bu1LordGtaW3jn5mB9A+nd6zA2gTwg4k0ZGw215s+znbL9i+shM91GN7i+2nbG+w3dfhXlbb3mF747BlPbYfsL2pdjviHHsd6m2F7W21926D7XM71Nsc2w/afsb207a/VVve0feu0Fdb3re2f2a3PUHS85I+J2mrpEclLYmIZ9raSB22t0jqjYiOX4Bh+0xJb0m6NSJOqi37J0kDEbGy9g/ljIi4okt6WyHprU5P412brWj28GnGJZ0v6Svq4HtX6OtCteF968SefYGkFyJic0TslXSnpPM60EfXi4iHJQ28Z/F5ktbU7q/R0P8sbVent64QEdsj4vHa/d2SDk4z3tH3rtBXW3Qi7EdJ+sWwx1vVXfO9h6Qf237M9tJONzOCWRGxvXb/VUmzOtnMCEadxrud3jPNeNe8d41Mf94sTtC938KI+Kykz0u6rHa42pVi6DNYN42djmka73YZYZrxX+vke9fo9OfN6kTYt0maM+zxJ2rLukJEbKvd7pB0r7pvKur+gzPo1m53dLifX+umabxHmmZcXfDedXL6806E/VFJ82zPtX2YpC9KWtuBPt7H9rTaiRPZniZpkbpvKuq1ki6u3b9Y0n0d7OVdumUa73rTjKvD713Hpz+PiLb/STpXQ2fkX5T0V53ooU5fn5T0RO3v6U73JukODR3W7dPQuY1LJX1U0jpJmyT9l6SeLurtPyQ9JelJDQVrdod6W6ihQ/QnJW2o/Z3b6feu0Fdb3jculwWS4AQdkARhB5Ig7EAShB1IgrADSRB2IAnCDiTx/65XcTNOWsh5AAAAAElFTkSuQmCC\n",
305 | "text/plain": [
306 | ""
307 | ]
308 | },
309 | "metadata": {
310 | "tags": [],
311 | "needs_background": "light"
312 | }
313 | }
314 | ]
315 | },
316 | {
317 | "cell_type": "markdown",
318 | "metadata": {
319 | "id": "yfIVxwixkQdY"
320 | },
321 | "source": [
322 |         "To acquire predictions from the different submodels, one should transform the input (with shape [1, H, W, C]) into a batch (with shape [M, H, W, C]) that consists of M copies of the original input (H - image height, W - image width, C - number of channels).\n",
323 |         "\n",
324 |         "As we can see, the Masksembles submodels produce similar predictions for training set samples."
325 | ]
326 | },
327 | {
328 | "cell_type": "code",
329 | "metadata": {
330 | "id": "thkg1KJJjY85",
331 | "outputId": "010a3f3c-f20c-427a-87c9-f08966707e0e",
332 | "colab": {
333 | "base_uri": "https://localhost:8080/"
334 | }
335 | },
336 | "source": [
337 | "inputs = np.tile(img[None], [4, 1, 1, 1])\n",
338 | "predictions = model(inputs)\n",
339 | "for i, cls in enumerate(tf.argmax(predictions, axis=1)):\n",
340 | " print(f\"PREDICTION OF {i+1} MODEL: {cls} CLASS\")"
341 | ],
342 | "execution_count": 120,
343 | "outputs": [
344 | {
345 | "output_type": "stream",
346 | "text": [
347 | "PREDICTION OF 1 MODEL: 5 CLASS\n",
348 | "PREDICTION OF 2 MODEL: 5 CLASS\n",
349 | "PREDICTION OF 3 MODEL: 5 CLASS\n",
350 | "PREDICTION OF 4 MODEL: 5 CLASS\n"
351 | ],
352 | "name": "stdout"
353 | }
354 | ]
355 | },
356 | {
357 | "cell_type": "markdown",
358 | "metadata": {
359 | "id": "KH81gNyGlcnh"
360 | },
361 | "source": [
362 |         "On out-of-distribution samples Masksembles should produce predictions with high variance. Let's check this on complex samples from MNIST."
363 | ]
364 | },
365 | {
366 | "cell_type": "code",
367 | "metadata": {
368 | "id": "l9rbe1s9fBDf"
369 | },
370 | "source": [
371 | "img = np.load(\"./complex_sample_mnist.npy\")"
372 | ],
373 | "execution_count": 121,
374 | "outputs": []
375 | },
376 | {
377 | "cell_type": "code",
378 | "metadata": {
379 | "id": "n6YghcdrfdkZ",
380 | "outputId": "a0e0ff76-3ad3-416b-d0ed-642debf7d491",
381 | "colab": {
382 | "base_uri": "https://localhost:8080/",
383 | "height": 283
384 | }
385 | },
386 | "source": [
387 | "plt.imshow(img[..., 0])"
388 | ],
389 | "execution_count": 122,
390 | "outputs": [
391 | {
392 | "output_type": "execute_result",
393 | "data": {
394 | "text/plain": [
395 | ""
396 | ]
397 | },
398 | "metadata": {
399 | "tags": []
400 | },
401 | "execution_count": 122
402 | },
403 | {
404 | "output_type": "display_data",
405 | "data": {
406 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAPAklEQVR4nO3dfWxd9X3H8c/XjmOTB2hCmmAlGQks3ZbS1tnctAUKdBRI4Y/QSqNEapcxWncSSO3WTWXsD/ijf2SoLaq0DclA2rRqYVQFEalRIYuoWNcNxYQ0DwQID4mI68RJvCYG8uCH7/7wCTPg87vmPp0bf98vybrX53uPzzc3/vjce3/nnJ+5uwBMfU1FNwCgPgg7EARhB4Ig7EAQhB0IYlo9NzbdWr1NM+u5SSCUk3pTp/2UTVSrKOxmtkrS9yU1S3rA3delHt+mmfqEXV3JJgEkPONbcmtlv4w3s2ZJ/yrpc5KWS1pjZsvL/XkAaquS9+wrJb3s7q+6+2lJD0taXZ22AFRbJWFfKOn1cd8fyJa9g5l1mVmPmfUM6VQFmwNQiZp/Gu/u3e7e6e6dLWqt9eYA5Kgk7L2SFo/7flG2DEADqiTsWyUtM7OlZjZd0s2SNlanLQDVVvbQm7sPm9ntkp7Q2NDbenffXbXOAFRVRePs7r5J0qYq9QKghjhcFgiCsANBEHYgCMIOBEHYgSAIOxAEYQeCIOxAEIQdCIKwA0EQdiAIwg4EQdiBIAg7EARhB4Ig7EAQhB0IgrADQRB2IAjCDgRB2IEg6jplc1hNzeny9Jb0+i0l6kNDuaXRUyWm3HJP1zFlsGcHgiDsQBCEHQiCsANBEHYgCMIOBEHYgSAYZ6+Doc+uSNb3/+VIsn7LR/87WV//3KW5tT+582By3eHf9SXrjMNPHRWF3cz2SRqUNCJp2N07q9EUgOqrxp79M+5+pAo/B0AN8Z4dCKLSsLukJ83sWTPrmugBZtZlZj1m1jOkEsdpA6iZSl/GX+7uvWY2X9JmM3vB3Z8e/wB375bULUnn2lw+7QEKUtGe3d17s9t+SY9JWlmNpgBUX9lhN7OZZjb7zH1J10raVa3GAFRXJS/jF0h6zMzO/Jyfuvsvq9LVWcYv60jWX/tiev1/v6w7WT/p6fPZf9D0qfzi6Gh64wij7LC7+6uSPlbFXgDUEENvQBCEHQiCsANBEHYgCMIOBMEprlVw9JJzkvXVHVuT9ZWt6aG1pY9PeCTy25avyz9NdeRwiXOUOIU1DPbsQBCEHQiCsANBEHYgCMIOBEHYgSAIOxAE4+zVUGKo+vRohU9zS/o0VT92PL82PJxct2nGjGT95BUfTtYHF6b/bTMP5V8me9pb6Utot/bsTdZHjuf/u/Fe7NmBIAg7EARhB4Ig7EAQhB0IgrADQRB2IAjG2augucSsVm8MT6/o539hxbZk/Rd/+8nc2pwX0mP0pz5gybqv+t9k/SPzf5esv3JsXm5t4ERbct3RrZck6xc+kp5uenR/b27Nh04n152K2LMDQRB2IAjCDgRB2IEgCDsQBGEHgiDsQBCMs1fBvN8cStZ7FqfHi+/6i6PJ+j0X9KTrX8mv765wPPmj09Nj4ad8KFnfcPzC3No1M15Krju4Iv3ruUZ/l6wv2ZB/AMRwb/r4gKmo5J7dzNabWb+Z7Rq3bK6ZbTazvdntnNq2CaBSk3kZ/0NJq9617A5JW9x9maQt2fcAGljJsLv705IG3rV4taQN2f0Nkm6scl8Aqqzc9+wL3P3MgckHJS3Ie6CZdUnqkqQ2pa93BqB2Kv403t1diUsuunu3u3e6e2eLWivdHIAylRv2Q2bWLknZbX/1WgJQC+WGfaOktdn9tZIer047AGql5Ht2M3tI0lWS5pnZAUl3SVon6REzu1XSfkk31bLJRjey99VkfcnP0n9TH3vrymT9qes+lKx/av5ryXrKwOmZyfr2wwuT9SP95ybrrQfy555/4wtPJNe9btbuZP3EwvR15/0c3jaOVzLs7r4mp3R1lXsBUEMcLgsEQdiBIAg7EARhB4Ig7EAQnOJaByMvvpyst5eoT/tp7tHIkqTfLkqfQpvSdCJ9iuoHDx5O10fT9WOfzR82vGhN+lisphJzYZ+3uzlZ1++Z0nk89uxAEIQdCIKwA0EQdiAIwg4EQdiBIAg7EATj7GeB4YPpS1WrVD0hfZKo1NSWvpR0/9oVyfqp6/LHum+YcSy57rePfDxZv+C/0tNJjwz8PlmPhj07EARhB4Ig7EAQhB0IgrADQRB2IAjCDgTBOHtwNi39K3Dizz+SrHfcsjNZ/87CJ3Nrrw2nz1d/aNMVyfpFL2xL1jVa6iiCWNizA0EQdiAIwg4EQdiBIAg7EARhB4Ig7EAQjLNPAamx8ub2C5Lr7vvSHyTr3V/9l2R9xfTh9PrHlufWfvDA9cl1lz36erI+fOpUso53KrlnN7P1ZtZvZrvGLbvbzHrNbHv2lf5fA1C4ybyM/6GkVRMsv9fdO7KvTdVtC0C1lQy7uz8taaAOvQCooUo+oLvdzHZkL/Pn5D3IzLrMrMfMeobEeyygKOWG/T5JF0vqkNQn6bt5D3T3bnfvdPfOFrWWuTkAlSor7O5+yN1H3H1U0v2SVla3LQDVVlbYzax93Lefl7Qr77EAGkPJcXYze0jSVZLmmdkBSXdJusrMOiS5pH2SvlbDHsNrmj07Wfc/XpJb23dNet1HunLfgUmSPjz9nGRdmp6s/tuOK3Nrf/jjF5PrDh85WmLbeD9Kht3d10yw+MEa9AKghjhcFgiCsANBEHYgCMIOBEHYgSA4xfUscPLSP0rWj902mFt7ouOe5LqLps0qq6fJ+vuOzbm1B25YnVx33i9fSdZHDvWX1VNU7NmBIAg7EARhB4Ig7EAQhB0IgrADQRB2IAjG2c8CbQffTNb7njs/t/aVWV9MrutuyXpLc3ra4yvnvZSsXzpjb27t9n/8WXLde2felKzPv+9wsi5PTwkdDXt2IAjCDgRB2IEgCDsQBGEHgiDsQBCEHQjCvI5jkefaXP+EXV237U0ZTc3pclv+TDs2o9SloEtoTm97ZGl6Sui9f9OSX7vm/uS61+65MVlvvTl9/MFIwEtRP+NbdNwHJjx4gj07EARhB4Ig7EAQhB0IgrADQRB2IAjCDgTB+exng9H0OeWjb72VX0zVqsCODiTr7b/4s9zaf346/ev3D0ueSNa/9aVbk/VFD+dfd3744KHkulNRyT27mS02s6fM7Hkz221mX8+WzzWzzWa2N7udU/t2AZRrMi/jhyV9092XS/qkpNvMbLmkOyRtcfdlkrZk3wNoUCXD7u597r4tuz8oaY+khZJWS9qQPWyDpPSxjQAK9b7es5vZEkkrJD0jaYG792Wlg5IW5KzTJalLkto0o9w+AVRo0p/Gm9ksST+X9A13Pz6+5mNn00x4Ro27d7t7p7t3tij/hA0AtTWpsJtZi8aC/hN3fzRbfMjM2rN6uySm1AQaWMmX8WZmkh6UtMfdvzeutFHSWknrstv
Ha9IhGpoPDyfrH9jal1u75Vd/nVz3nz+dvtT04NLRZF2Vnt47xUzmPftlkr4saaeZbc+W3amxkD9iZrdK2i8pfZFvAIUqGXZ3/7WkvJkEuBIFcJbgcFkgCMIOBEHYgSAIOxAEYQeC4BRX1NRo/5Hc2vn/szC57ksfb0/WfXZ6jN+npS+DHQ17diAIwg4EQdiBIAg7EARhB4Ig7EAQhB0IgnF21JblnTApjZb47ZvVfDJZ/9jFryfrJ8+bn95AMOzZgSAIOxAEYQeCIOxAEIQdCIKwA0EQdiAIxtlRnAnnEPp/o57eF01rSk9ljXdizw4EQdiBIAg7EARhB4Ig7EAQhB0IgrADQUxmfvbFkn4kaYHGRka73f37Zna3pK9KOpw99E5331SrRnF2spb8X7E305eN17xpx5P1Z3delKwvP9qfW0tfcX5qmsxBNcOSvunu28xstqRnzWxzVrvX3b9Tu/YAVMtk5mfvk9SX3R80sz2SSvxNBtBo3td7djNbImmFpGeyRbeb2Q4zW29mc3LW6TKzHjPrGdKpipoFUL5Jh93MZkn6uaRvuPtxSfdJulhSh8b2/N+daD1373b3TnfvbFFrFVoGUI5Jhd3MWjQW9J+4+6OS5O6H3H3E3Ucl3S9pZe3aBFCpkmE3M5P0oKQ97v69ccvHT7H5eUm7qt8egGqZzKfxl0n6sqSdZrY9W3anpDVm1qGx4bh9kr5Wkw5xVvOR0dxa20D+ZaYl6dvP3ZCsn/dCiV/fE+lLUUczmU/jfy1pov8VxtSBswhH0AFBEHYgCMIOBEHYgSAIOxAEYQeC4FLSqKnRwcHc2gX3/qam2454GmsKe3YgCMIOBEHYgSAIOxAEYQeCIOxAEIQdCMLcS8ybW82NmR2WtH/conmSjtStgfenUXtr1L4keitXNXu70N0/OFGhrmF/z8bNety9s7AGEhq1t0btS6K3ctWrN17GA0EQdiCIosPeXfD2Uxq1t0btS6K3ctWlt0LfswOon6L37ADqhLADQRQSdjNbZWYvmtnLZnZHET3kMbN9ZrbTzLabWU/Bvaw3s34z2zVu2Vwz22xme7PbCefYK6i3u82sN3vutpvZ9QX1ttjMnjKz581st5l9PVte6HOX6Ksuz1vd37ObWbOklyRdI+mApK2S1rj783VtJIeZ7ZPU6e6FH4BhZldIekPSj9z9kmzZPZIG3H1d9odyjrt/q0F6u1vSG0VP453NVtQ+fppxSTdK+isV+Nwl+rpJdXjeitizr5T0sru/6u6nJT0saXUBfTQ8d39a0sC7Fq+WtCG7v0Fjvyx1l9NbQ3D3Pnfflt0flHRmmvFCn7tEX3VRRNgXSnp93PcH1FjzvbukJ83sWTPrKrqZCSxw977s/kFJC4psZgIlp/Gup3dNM94wz105059Xig/o3utyd/9TSZ+TdFv2crUh+dh7sEYaO53UNN71MsE0428r8rkrd/rzShUR9l5Ji8d9vyhb1hDcvTe77Zf0mBpvKupDZ2bQzW77C+7nbY00jfdE04yrAZ67Iqc/LyLsWyUtM7OlZjZd0s2SNhbQx3uY2czsgxOZ2UxJ16rxpqLeKGltdn+tpMcL7OUdGmUa77xpxlXwc1f49OfuXvcvSddr7BP5VyT9UxE95PR1kaTfZl+7i+5N0kMae1k3pLHPNm6VdL6kLZL2SvoPSXMbqLcfS9opaYfGgtVeUG+Xa+wl+g5J27Ov64t+7hJ91eV543BZIAg+oAOCIOxAEIQdCIKwA0EQdiAIwg4EQdiBIP4PbYZ6s7wJM9MAAAAASUVORK5CYII=\n",
407 | "text/plain": [
408 | ""
409 | ]
410 | },
411 | "metadata": {
412 | "tags": [],
413 | "needs_background": "light"
414 | }
415 | }
416 | ]
417 | },
418 | {
419 | "cell_type": "code",
420 | "metadata": {
421 | "colab": {
422 | "base_uri": "https://localhost:8080/"
423 | },
424 | "id": "65Y_-0h9U96h",
425 | "outputId": "9e86c6b5-2fea-4e6a-dfef-437da2979252"
426 | },
427 | "source": [
428 | "inputs = np.tile(img[None], [4, 1, 1, 1])\n",
429 | "predictions = model(inputs)\n",
430 | "for i, cls in enumerate(tf.argmax(predictions, axis=1)):\n",
431 | " print(f\"PREDICTION OF {i+1} MODEL: {cls} CLASS\")"
432 | ],
433 | "execution_count": 123,
434 | "outputs": [
435 | {
436 | "output_type": "stream",
437 | "text": [
438 | "PREDICTION OF 1 MODEL: 3 CLASS\n",
439 | "PREDICTION OF 2 MODEL: 7 CLASS\n",
440 | "PREDICTION OF 3 MODEL: 7 CLASS\n",
441 | "PREDICTION OF 4 MODEL: 7 CLASS\n"
442 | ],
443 | "name": "stdout"
444 | }
445 | ]
446 | }
447 | ]
448 | }
--------------------------------------------------------------------------------
/notebooks/MNIST_Masksembles.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "id": "dacca9e1"
7 | },
8 | "source": [
9 | "\n",
10 | "# Masksembles on MNIST — Tutorial\n",
11 | "\n",
12 | "Welcome! 👋\n",
13 | "\n",
14 | "This notebook walks you through using **Masksembles** for uncertainty‑aware inference on the classic **MNIST** dataset. \n",
15 | "\n",
16 | "You'll see how to:\n",
17 | "\n",
18 | "- Load and preprocess MNIST\n",
19 | "- Define **Masksembles** layers (`Masksembles2D` / `Masksembles1D`)\n",
20 | "- Build a small CNN\n",
21 | "- Train and validate in PyTorch\n",
22 | "- Run **per‑submodel predictions** (ensembled behavior) at inference time\n",
23 | "- Avoid common pitfalls\n",
24 | "\n",
25 | "> **What is Masksembles?** \n",
26 | "> Masksembles builds multiple *subnetworks* inside a single model using deterministic, non‑overlapping masks. At inference time you can obtain diverse predictions (like an ensemble) **without** keeping multiple separate models. This gives you better **uncertainty estimation** and often more robust predictions with minimal overhead.\n"
27 | ]
28 | },
29 | {
30 | "cell_type": "markdown",
31 | "metadata": {
32 | "id": "670b136c"
33 | },
34 | "source": [
35 | "\n",
36 | "## Table of Contents\n",
37 | "1. [Setup & Imports](#setup-imports)\n",
38 | "2. [Load MNIST & Preprocessing](#load-preprocess)\n",
39 | "3. [Masksembles Layers](#masksembles-layers)\n",
40 | "4. [Model Architecture](#model-architecture)\n",
41 | "5. [Training Setup](#training-setup)\n",
42 | "6. [Training Loop](#training-loop)\n",
43 | "7. [Evaluation & Inference](#eval-infer)\n",
44 | "8. [Next Steps](#next-steps)\n"
45 | ]
46 | },
47 | {
48 | "cell_type": "markdown",
49 | "source": [
50 |         "But before we start, let's make sure you have installed the Masksembles package."
51 | ],
52 | "metadata": {
53 | "id": "YuRK1pIbrBL0"
54 | }
55 | },
56 | {
57 | "cell_type": "code",
58 | "execution_count": 1,
59 | "metadata": {
60 | "colab": {
61 | "base_uri": "https://localhost:8080/"
62 | },
63 | "collapsed": true,
64 | "id": "qrpn5UV2BYLF",
65 | "outputId": "b1634cd5-3287-45bc-de91-88c0ba806fd6"
66 | },
67 | "outputs": [
68 | {
69 | "output_type": "stream",
70 | "name": "stdout",
71 | "text": [
72 | "Collecting masksembles\n",
73 | " Downloading masksembles-1.1-py3-none-any.whl.metadata (5.1 kB)\n",
74 | "Downloading masksembles-1.1-py3-none-any.whl (8.4 kB)\n",
75 | "Installing collected packages: masksembles\n",
76 | "Successfully installed masksembles-1.1\n",
77 | "--2025-11-09 12:52:18-- https://github.com/nikitadurasov/masksembles/raw/main/images/complex_sample_mnist.npy\n",
78 | "Resolving github.com (github.com)... 140.82.116.3\n",
79 | "Connecting to github.com (github.com)|140.82.116.3|:443... connected.\n",
80 | "HTTP request sent, awaiting response... 302 Found\n",
81 | "Location: https://raw.githubusercontent.com/nikitadurasov/masksembles/main/images/complex_sample_mnist.npy [following]\n",
82 | "--2025-11-09 12:52:18-- https://raw.githubusercontent.com/nikitadurasov/masksembles/main/images/complex_sample_mnist.npy\n",
83 | "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...\n",
84 | "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.\n",
85 | "HTTP request sent, awaiting response... 200 OK\n",
86 | "Length: 6400 (6.2K) [application/octet-stream]\n",
87 | "Saving to: ‘complex_sample_mnist.npy’\n",
88 | "\n",
89 | "complex_sample_mnis 100%[===================>] 6.25K --.-KB/s in 0s \n",
90 | "\n",
91 | "2025-11-09 12:52:18 (96.2 MB/s) - ‘complex_sample_mnist.npy’ saved [6400/6400]\n",
92 | "\n"
93 | ]
94 | }
95 | ],
96 | "source": [
97 | "!pip install masksembles\n",
98 | "!wget https://github.com/nikitadurasov/masksembles/raw/main/images/complex_sample_mnist.npy"
99 | ]
100 | },
101 | {
102 | "cell_type": "markdown",
103 | "metadata": {
104 | "id": "32ca1561"
105 | },
106 | "source": [
107 | "\n",
108 | "## 1. Setup & Imports \n",
109 | "\n",
110 | "This section imports PyTorch, TorchVision, and the other packages used later. "
111 | ]
112 | },
113 | {
114 | "cell_type": "code",
115 | "source": [
116 | "import numpy as np\n",
117 | "import matplotlib.pyplot as plt\n",
118 | "\n",
119 | "import torch\n",
120 | "import torch.nn as nn\n",
121 | "import torch.nn.functional as F\n",
122 | "import torch.optim as optim\n",
123 | "from torch.utils.data import DataLoader, random_split, TensorDataset\n",
124 | "from torchvision import datasets, transforms"
125 | ],
126 | "metadata": {
127 | "id": "r4tPNSFvqUWp"
128 | },
129 | "execution_count": 2,
130 | "outputs": []
131 | },
132 | {
133 | "cell_type": "markdown",
134 | "metadata": {
135 | "id": "be8e329c"
136 | },
137 | "source": [
138 | "\n",
139 | "## 2. Load MNIST & Preprocessing \n",
140 | "\n",
141 | "Here we load MNIST and transform images to **tensors in [0, 1]**. \n",
142 | "Remember that **PyTorch uses channel‑first** tensors *(N, C, H, W)*, so MNIST images should become shape `(N, 1, 28, 28)`.\n"
143 | ]
144 | },
145 | {
146 | "cell_type": "code",
147 | "source": [
148 | "# Model / data parameters\n",
149 | "num_classes = 10\n",
150 | "input_shape = (1, 28, 28) # PyTorch uses channel-first format (C, H, W)\n",
151 | "\n",
152 | "# Define transform: convert to tensor and normalize to [0, 1]\n",
153 | "transform = transforms.Compose([\n",
154 | " transforms.ToTensor(), # Converts to tensor and scales to [0, 1]\n",
155 | "])\n",
156 | "\n",
157 | "# Load MNIST dataset\n",
158 | "train_dataset = datasets.MNIST(\n",
159 | " root='./data',\n",
160 | " train=True,\n",
161 | " transform=transform,\n",
162 | " download=True\n",
163 | ")\n",
164 | "test_dataset = datasets.MNIST(\n",
165 | " root='./data',\n",
166 | " train=False,\n",
167 | " transform=transform,\n",
168 | " download=True\n",
169 | ")\n",
170 | "\n",
171 | "# Create DataLoaders\n",
172 | "batch_size = 64\n",
173 | "train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
174 | "test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)\n",
175 | "\n",
176 | "# Check shapes\n",
177 | "images, labels = next(iter(train_loader))\n",
178 | "print(\"x_train shape:\", images.shape) # [batch_size, 1, 28, 28]\n",
179 | "print(len(train_dataset), \"train samples\")\n",
180 | "print(len(test_dataset), \"test samples\")\n",
181 | "\n",
182 | "# Convert labels to one-hot if needed\n",
183 | "# (usually not needed for CrossEntropyLoss)\n",
184 | "y_one_hot = torch.nn.functional.one_hot(labels, num_classes=num_classes).float()\n",
185 | "print(\"y_train one-hot shape:\", y_one_hot.shape)\n"
186 | ],
187 | "metadata": {
188 | "colab": {
189 | "base_uri": "https://localhost:8080/"
190 | },
191 | "id": "Cs0l4BmmpuWm",
192 | "outputId": "04a6eef1-f807-4305-db7c-27677aed40f9"
193 | },
194 | "execution_count": 3,
195 | "outputs": [
196 | {
197 | "output_type": "stream",
198 | "name": "stderr",
199 | "text": [
200 | "100%|██████████| 9.91M/9.91M [00:00<00:00, 19.7MB/s]\n",
201 | "100%|██████████| 28.9k/28.9k [00:00<00:00, 480kB/s]\n",
202 | "100%|██████████| 1.65M/1.65M [00:00<00:00, 4.46MB/s]\n",
203 | "100%|██████████| 4.54k/4.54k [00:00<00:00, 9.13MB/s]"
204 | ]
205 | },
206 | {
207 | "output_type": "stream",
208 | "name": "stdout",
209 | "text": [
210 | "x_train shape: torch.Size([64, 1, 28, 28])\n",
211 | "60000 train samples\n",
212 | "10000 test samples\n",
213 | "y_train one-hot shape: torch.Size([64, 10])\n"
214 | ]
215 | },
216 | {
217 | "output_type": "stream",
218 | "name": "stderr",
219 | "text": [
220 | "\n"
221 | ]
222 | }
223 | ]
224 | },
225 | {
226 | "cell_type": "markdown",
227 | "source": [
228 | "\n",
229 | "## 3. Masksembles Layers \n",
230 | "\n",
231 | "These classes (`Masksembles2D`, `Masksembles1D`) create **non‑overlapping channel masks** that define *submodels* within the same network. \n",
232 | "- `Masksembles2D`: apply masks on convolutional feature maps (C×H×W). \n",
233 | "- `Masksembles1D`: apply masks on flattened features (C). \n",
234 | "Key arguments:\n",
235 | "- `channels`: number of channels in the input.\n",
236 | "- `n`: number of submodels (masks).\n",
237 | "- `scale`: controls correlation/capacity trade‑off between submodels.\n"
238 | ],
239 | "metadata": {
240 | "id": "1AgTlby2p0g-"
241 | }
242 | },
243 | {
244 | "cell_type": "code",
245 | "source": [
246 | "from masksembles.torch import Masksembles2D, Masksembles1D"
247 | ],
248 | "metadata": {
249 | "id": "lAnwkpGtp3qP"
250 | },
251 | "execution_count": 4,
252 | "outputs": []
253 | },
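254 | {
255 | "cell_type": "markdown",
256 | "metadata": {},
257 | "source": [
258 | "A minimal sanity check (added here for illustration): both layers preserve the input shape, and the batch size must be divisible by `n` so that every mask receives an equal share of the batch.\n"
259 | ]
260 | },
261 | {
262 | "cell_type": "code",
263 | "source": [
264 | "# Shape check: Masksembles layers are shape-preserving\n",
265 | "m2d = Masksembles2D(channels=16, n=4, scale=2.0)\n",
266 | "m1d = Masksembles1D(channels=16, n=4, scale=2.0)\n",
267 | "\n",
268 | "x2d = torch.ones([4, 16, 28, 28])  # (N, C, H, W), N divisible by n\n",
269 | "x1d = torch.ones([4, 16])          # (N, C)\n",
270 | "\n",
271 | "print(m2d(x2d).shape)  # torch.Size([4, 16, 28, 28])\n",
272 | "print(m1d(x1d).shape)  # torch.Size([4, 16])"
273 | ],
274 | "metadata": {},
275 | "execution_count": null,
276 | "outputs": []
277 | },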
254 | {
255 | "cell_type": "markdown",
256 | "metadata": {
257 | "id": "Hr5xjOSffo7x"
258 | },
259 | "source": [
260 | "In order to transform regular model into Masksembles model one should add Masksembles2D or Masksembles1D layers in it. General recommendation is to insert these layers right before or after convolutional layers.\n",
261 | "\n",
262 | "In example below we'll use both Masksembles2D and Masksembles1D layers applied after convolutions."
263 | ]
264 | },
265 | {
266 | "cell_type": "markdown",
267 | "metadata": {
268 | "id": "75491b02"
269 | },
270 | "source": [
271 | "\n",
272 | "## 4. Model Architecture \n",
273 | "\n",
274 | "We define a small CNN with ELU activations and interleave **Masksembles** after conv and before the final linear layer. \n",
275 | "After two conv+pool blocks, MNIST's 28×28 maps down to 5×5; the final linear projects to 10 classes.\n"
276 | ]
277 | },
278 | {
279 | "cell_type": "code",
280 | "source": [
281 | "import torch\n",
282 | "from torch import nn\n",
283 | "\n",
284 | "from masksembles import common\n",
285 | "\n",
286 | "\n",
287 | "class Masksembles2D(nn.Module):\n",
288 | " \"\"\"\n",
289 | " :class:`Masksembles2D` is high-level class that implements Masksembles approach\n",
290 | " for 2-dimensional inputs (similar to :class:`torch.nn.Dropout2d`).\n",
291 | "\n",
292 | " :param channels: int, number of channels used in masks.\n",
293 | " :param n: int, number of masks\n",
294 | " :param scale: float, scale parameter similar to *S* in [1]. Larger values decrease \\\n",
295 | " subnetworks correlations but at the same time decrease capacity of every individual model.\n",
296 | "\n",
297 | " Shape:\n",
298 | " * Input: (N, C, H, W)\n",
299 | " * Output: (N, C, H, W) (same shape as input)\n",
300 | "\n",
301 | " Examples:\n",
302 | "\n",
303 | " >>> m = Masksembles2D(16, 4, 2.0)\n",
304 | " >>> input = torch.ones([4, 16, 28, 28])\n",
305 | " >>> output = m(input)\n",
306 | "\n",
307 | " References:\n",
308 | "\n",
309 | " [1] `Masksembles for Uncertainty Estimation`,\n",
310 | " Nikita Durasov, Timur Bagautdinov, Pierre Baque, Pascal Fua\n",
311 | "\n",
312 | " \"\"\"\n",
313 | " def __init__(self, channels: int, n: int, scale: float):\n",
314 | " super().__init__()\n",
315 | " self.channels, self.n, self.scale = channels, n, scale\n",
316 | " masks_np = common.generation_wrapper(channels, n, scale) # numpy float64 by default\n",
317 | " masks = torch.as_tensor(masks_np, dtype=torch.float32) # make float32 here\n",
318 | " self.register_buffer('masks', masks, persistent=False) # not trainable, moves with .to()\n",
319 | "\n",
320 | " def forward(self, inputs):\n",
321 | " # make sure masks match dtype/device (usually already true because of buffer, but safe)\n",
322 | " masks = self.masks.to(dtype=inputs.dtype, device=inputs.device)\n",
323 | "\n",
324 | " batch = inputs.shape[0]\n",
325 | " # safer split even if batch % n != 0\n",
326 | " chunks = torch.chunk(inputs.unsqueeze(1), self.n, dim=0) # returns nearly equal chunks\n",
327 | " x = torch.cat(chunks, dim=1).permute(1, 0, 2, 3, 4) # [n, ?, C, H, W]\n",
328 | " x = x * masks.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # broadcast masks\n",
329 | " x = torch.cat(torch.split(x, 1, dim=0), dim=1)\n",
330 | " return x.squeeze(0)\n",
331 | "\n",
332 | "\n",
333 | "\n",
334 | "class Masksembles1D(nn.Module):\n",
335 | " \"\"\"\n",
336 | " :class:`Masksembles1D` is high-level class that implements Masksembles approach\n",
337 | " for 1-dimensional inputs (similar to :class:`torch.nn.Dropout`).\n",
338 | "\n",
339 | " :param channels: int, number of channels used in masks.\n",
340 | " :param n: int, number of masks\n",
341 | " :param scale: float, scale parameter similar to *S* in [1]. Larger values decrease \\\n",
342 | " subnetworks correlations but at the same time decrease capacity of every individual model.\n",
343 | "\n",
344 | " Shape:\n",
345 | " * Input: (N, C)\n",
346 | " * Output: (N, C) (same shape as input)\n",
347 | "\n",
348 | " Examples:\n",
349 | "\n",
350 | " >>> m = Masksembles1D(16, 4, 2.0)\n",
351 | " >>> input = torch.ones([4, 16])\n",
352 | " >>> output = m(input)\n",
353 | "\n",
354 | "\n",
355 | " References:\n",
356 | "\n",
357 | " [1] `Masksembles for Uncertainty Estimation`,\n",
358 | " Nikita Durasov, Timur Bagautdinov, Pierre Baque, Pascal Fua\n",
359 | "\n",
360 | " \"\"\"\n",
361 | "\n",
362 | " def __init__(self, channels: int, n: int, scale: float):\n",
363 | " super().__init__()\n",
364 | " self.channels, self.n, self.scale = channels, n, scale\n",
365 | " masks_np = common.generation_wrapper(channels, n, scale)\n",
366 | " masks = torch.as_tensor(masks_np, dtype=torch.float32)\n",
367 | " self.register_buffer('masks', masks, persistent=False)\n",
368 | "\n",
369 | " def forward(self, inputs):\n",
370 | " masks = self.masks.to(dtype=inputs.dtype, device=inputs.device)\n",
371 | "\n",
372 | " batch = inputs.shape[0]\n",
373 | " chunks = torch.chunk(inputs.unsqueeze(1), self.n, dim=0)\n",
374 | " x = torch.cat(chunks, dim=1).permute(1, 0, 2) # [n, ?, C]\n",
375 | " x = x * masks.unsqueeze(1)\n",
376 | " x = torch.cat(torch.split(x, 1, dim=0), dim=1)\n",
377 | " return x.squeeze(0)\n"
378 | ],
379 | "metadata": {
380 | "id": "ZxZQ-0Fjsa11"
381 | },
382 | "execution_count": 8,
383 | "outputs": []
384 | },
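385 | {
386 | "cell_type": "markdown",
387 | "metadata": {},
388 | "source": [
389 | "The generated masks live in the `masks` buffer, one row of zeros and ones per submodel. A quick inspection (added for illustration) shows that each submodel keeps only a subset of the channels active:\n"
390 | ]
391 | },
392 | {
393 | "cell_type": "code",
394 | "source": [
395 | "m = Masksembles2D(channels=16, n=4, scale=2.0)\n",
396 | "print(\"masks shape:\", tuple(m.masks.shape))  # (n, channels) = (4, 16)\n",
397 | "print(\"active channels per submodel:\", m.masks.sum(dim=1))"
398 | ],
399 | "metadata": {},
400 | "execution_count": null,
401 | "outputs": []
402 | },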
385 | {
386 | "cell_type": "code",
387 | "execution_count": 9,
388 | "metadata": {
389 | "colab": {
390 | "base_uri": "https://localhost:8080/"
391 | },
392 | "id": "7jw12ljXCaMg",
393 | "outputId": "df05375d-7bd1-4d14-c0d5-0a43388381ce"
394 | },
395 | "outputs": [
396 | {
397 | "output_type": "stream",
398 | "name": "stdout",
399 | "text": [
400 | "MasksemblesCNN(\n",
401 | " (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))\n",
402 | " (mask1): Masksembles2D()\n",
403 | " (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
404 | " (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))\n",
405 | " (mask2): Masksembles2D()\n",
406 | " (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
407 | " (flatten): Flatten(start_dim=1, end_dim=-1)\n",
408 | " (mask3): Masksembles1D()\n",
409 | " (fc): Linear(in_features=1600, out_features=10, bias=True)\n",
410 | ")\n",
411 | "\n",
412 | "Total parameters: 34,826\n",
413 | "Trainable parameters: 34,826\n"
414 | ]
415 | }
416 | ],
417 | "source": [
418 | "class MasksemblesCNN(nn.Module):\n",
419 | " def __init__(self, num_classes=10):\n",
420 | " super().__init__()\n",
421 | " self.conv1 = nn.Conv2d(1, 32, kernel_size=3)\n",
422 | " self.mask1 = Masksembles2D(32, n=4, scale=2.0)\n",
423 | " self.pool1 = nn.MaxPool2d(kernel_size=2)\n",
424 | "\n",
425 | " self.conv2 = nn.Conv2d(32, 64, kernel_size=3)\n",
426 | " self.mask2 = Masksembles2D(64, n=4, scale=2.0)\n",
427 | " self.pool2 = nn.MaxPool2d(kernel_size=2)\n",
428 | "\n",
429 | " self.flatten = nn.Flatten()\n",
430 | " self.mask3 = Masksembles1D(64 * 5 * 5, n=4, scale=2.0)\n",
431 | " self.fc = nn.Linear(64 * 5 * 5, num_classes) # 5×5 after two 2×2 pools from 28×28 input\n",
432 | "\n",
433 | " def forward(self, x):\n",
434 | " x = F.elu(self.conv1(x))\n",
435 | " x = self.mask1(x)\n",
436 | " x = self.pool1(x)\n",
437 | "\n",
438 | " x = F.elu(self.conv2(x))\n",
439 | " x = self.mask2(x)\n",
440 | " x = self.pool2(x)\n",
441 | "\n",
442 | " x = self.flatten(x)\n",
443 | " x = self.mask3(x)\n",
444 | " x = F.softmax(self.fc(x), dim=1)\n",
445 | " return x\n",
446 | "\n",
447 | "\n",
448 | "# Instantiate and summarize\n",
449 | "model = MasksemblesCNN(num_classes=10)\n",
450 | "print(model)\n",
451 | "\n",
452 | "# Print parameter count\n",
453 | "total_params = sum(p.numel() for p in model.parameters())\n",
454 | "trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
455 | "print(f\"\\nTotal parameters: {total_params:,}\")\n",
456 | "print(f\"Trainable parameters: {trainable_params:,}\")\n"
457 | ]
458 | },
459 | {
460 | "cell_type": "markdown",
461 | "metadata": {
462 | "id": "DP9Km-bGigRC"
463 | },
464 | "source": [
465 | "Training of Masksembles is not different from training of regular model. So we just use standard Pytorch training loop."
466 | ]
467 | },
468 | {
469 | "cell_type": "markdown",
470 | "source": [
471 | "\n",
472 | "## 5. Training Setup \n",
473 | "\n",
474 | "We configure:\n",
475 | "- **Loss**: `CrossEntropyLoss` (expects integer class indices, not one‑hot)\n",
476 | "- **Optimizer**: `Adam`\n",
477 | "- **DataLoaders**: mini‑batches for train and validation (90/10 split)\n"
478 | ],
479 | "metadata": {
480 | "id": "Xa-4-JNFqE0l"
481 | }
482 | },
483 | {
484 | "cell_type": "code",
485 | "source": [
486 | "# Model, loss, optimizer\n",
487 | "model = MasksemblesCNN(num_classes=10)\n",
488 | "criterion = nn.CrossEntropyLoss()\n",
489 | "optimizer = optim.Adam(model.parameters())\n",
490 | "\n",
491 | "# Training loop\n",
492 | "epochs = 3\n",
493 | "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
494 | "model.to(device)"
495 | ],
496 | "metadata": {
497 | "colab": {
498 | "base_uri": "https://localhost:8080/"
499 | },
500 | "id": "y-trVv38qQYA",
501 | "outputId": "1a5198b6-ecfe-44a1-d155-693af748150d"
502 | },
503 | "execution_count": 10,
504 | "outputs": [
505 | {
506 | "output_type": "execute_result",
507 | "data": {
508 | "text/plain": [
509 | "MasksemblesCNN(\n",
510 | " (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))\n",
511 | " (mask1): Masksembles2D()\n",
512 | " (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
513 | " (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))\n",
514 | " (mask2): Masksembles2D()\n",
515 | " (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
516 | " (flatten): Flatten(start_dim=1, end_dim=-1)\n",
517 | " (mask3): Masksembles1D()\n",
518 | " (fc): Linear(in_features=1600, out_features=10, bias=True)\n",
519 | ")"
520 | ]
521 | },
522 | "metadata": {},
523 | "execution_count": 10
524 | }
525 | ]
526 | },
527 | {
528 | "cell_type": "markdown",
529 | "source": [
530 | "\n",
531 | "## 6. Training Loop \n",
532 | "\n",
533 | "Standard PyTorch loop:\n",
534 | "1. `model.train()`, forward pass\n",
535 | "2. Compute loss, `backward()`, `step()`\n",
536 | "3. Track accuracy\n",
537 | "Then switch to `model.eval()` for validation with `torch.no_grad()`.\n"
538 | ],
539 | "metadata": {
540 | "id": "tKxCoTOZqACI"
541 | }
542 | },
543 | {
544 | "cell_type": "code",
545 | "source": [
546 | "for epoch in range(epochs):\n",
547 | " model.train()\n",
548 | " train_loss, train_correct = 0, 0\n",
549 | " for x_batch, y_batch in train_loader:\n",
550 | " x_batch, y_batch = x_batch.to(device), y_batch.to(device)\n",
551 | "\n",
552 | " optimizer.zero_grad()\n",
553 | " outputs = model(x_batch)\n",
554 | " loss = criterion(outputs, y_batch)\n",
555 | " loss.backward()\n",
556 | " optimizer.step()\n",
557 | "\n",
558 | " train_loss += loss.item() * x_batch.size(0)\n",
559 | " train_correct += (outputs.argmax(1) == y_batch).sum().item()\n",
560 | "\n",
561 | " # Validation\n",
562 | " model.eval()\n",
563 | " val_loss, val_correct = 0, 0\n",
564 | " with torch.no_grad():\n",
565 | " for x_batch, y_batch in test_loader:\n",
566 | " x_batch, y_batch = x_batch.to(device), y_batch.to(device)\n",
567 | " outputs = model(x_batch)\n",
568 | " loss = criterion(outputs, y_batch)\n",
569 | " val_loss += loss.item() * x_batch.size(0)\n",
570 | " val_correct += (outputs.argmax(1) == y_batch).sum().item()\n",
571 | "\n",
572 | " # Print epoch stats\n",
573 | " print(\n",
574 | " f\"Epoch {epoch+1}/{epochs} \"\n",
575 | " f\"- Train loss: {train_loss/len(train_loader.dataset):.4f}, \"\n",
576 | " f\"Train acc: {train_correct/len(train_loader.dataset):.4f}, \"\n",
577 | " f\"Val loss: {val_loss/len(test_loader.dataset):.4f}, \"\n",
578 | " f\"Val acc: {val_correct/len(test_loader.dataset):.4f}\"\n",
579 | " )\n"
580 | ],
581 | "metadata": {
582 | "colab": {
583 | "base_uri": "https://localhost:8080/"
584 | },
585 | "id": "nG18Ftg_qNqQ",
586 | "outputId": "c775536c-861c-42c8-8723-29faa4c2b0ca"
587 | },
588 | "execution_count": 11,
589 | "outputs": [
590 | {
591 | "output_type": "stream",
592 | "name": "stdout",
593 | "text": [
594 | "Epoch 1/3 - Train loss: 1.5990, Train acc: 0.8808, Val loss: 1.5153, Val acc: 0.9487\n",
595 | "Epoch 2/3 - Train loss: 1.5098, Train acc: 0.9545, Val loss: 1.4981, Val acc: 0.9648\n",
596 | "Epoch 3/3 - Train loss: 1.4985, Train acc: 0.9646, Val loss: 1.4927, Val acc: 0.9695\n"
597 | ]
598 | }
599 | ]
600 | },
601 | {
602 | "cell_type": "markdown",
603 | "metadata": {
604 | "id": "25b935df"
605 | },
606 | "source": [
607 | "\n",
608 | "## 7. Evaluation & Inference \n",
609 | "\n",
610 | "For evaluation, we disable gradients (`torch.no_grad()`) and call `model.eval()` to turn off training‑time behavior. \n",
611 | "We then compute `argmax` over class logits to get predicted labels.\n",
612 | "\n",
613 | "### **Important Notes on Inference :** Batch tiling and mask assignment\n",
614 | "\n",
615 | "Masksembles layers divide each batch into *N* segments — one for each submodel (mask). \n",
616 | "That means:\n",
617 | "- The **first** 1/N of the batch goes through the first mask, \n",
618 | "- The **second** 1/N through the second mask, and so on.\n",
619 | "\n",
620 | "So, to obtain predictions from **all submodels**, you need to **tile the same input image** *N* times along the batch dimension. \n",
621 | "Each of these replicated samples will be processed by a different submodel. \n",
622 | "After inference, you can then collect the *N* outputs to see how each submodel predicts the same input.\n",
623 | "\n"
624 | ]
625 | },
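626 | {
627 | "cell_type": "markdown",
628 | "metadata": {},
629 | "source": [
630 | "A tiny sketch (illustrative only) of this segment-to-mask assignment, mirroring how the layers chunk the batch dimension:\n"
631 | ]
632 | },
633 | {
634 | "cell_type": "code",
635 | "source": [
636 | "# With a batch of 8 and n = 4, consecutive pairs of samples share a mask\n",
637 | "batch_indices = torch.arange(8)\n",
638 | "for i, segment in enumerate(torch.chunk(batch_indices, 4)):\n",
639 | "    print(f\"mask {i}: samples {segment.tolist()}\")"
640 | ],
641 | "metadata": {},
642 | "execution_count": null,
643 | "outputs": []
644 | },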
626 | {
627 | "cell_type": "markdown",
628 | "metadata": {
629 | "id": "Wh3pGxWujZ0Z"
630 | },
631 | "source": [
632 | "Now, we will check that all of Masksembles' submodels would predict similar predictions for training samples."
633 | ]
634 | },
635 | {
636 | "cell_type": "code",
637 | "execution_count": 12,
638 | "metadata": {
639 | "id": "T3GNT8Z5kItR"
640 | },
641 | "outputs": [],
642 | "source": [
643 | "img = train_dataset[0][0] # just random image from training set"
644 | ]
645 | },
646 | {
647 | "cell_type": "code",
648 | "execution_count": 13,
649 | "metadata": {
650 | "colab": {
651 | "base_uri": "https://localhost:8080/",
652 | "height": 430
653 | },
654 | "id": "o0K51rvqkMIs",
655 | "outputId": "dae2318b-5899-4402-dd06-f73fad962cb2"
656 | },
657 | "outputs": [
658 | {
659 | "output_type": "display_data",
660 | "data": {
661 | "text/plain": [
662 | ""
663 | ],
664 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaAAAAGdCAYAAABU0qcqAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAHE1JREFUeJzt3X9w1PW97/HXAskKmiyNIb9KwIA/sALxFiVmQMSSS0jnOICMB390BrxeHDF4imj1xlGR1jNp8Y61eqne06lEZ8QfnBGojuWOBhOONaEDShlu25TQWOIhCRUnuyFICMnn/sF160ICftZd3kl4Pma+M2T3++b78evWZ7/ZzTcB55wTAADn2DDrBQAAzk8ECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmBhhvYBT9fb26uDBg0pLS1MgELBeDgDAk3NOHR0dysvL07Bh/V/nDLgAHTx4UPn5+dbLAAB8Q83NzRo7dmy/zw+4AKWlpUmSZur7GqEU49UAAHydULc+0DvR/573J2kBWrdunZ566im1traqsLBQzz33nKZPn37WuS+/7TZCKRoRIEAAMOj8/zuMnu1tlKR8COH111/XqlWrtHr1an300UcqLCxUaWmpDh06lIzDAQAGoaQE6Omnn9ayZct055136jvf+Y5eeOEFjRo1Si+++GIyDgcAGIQSHqDjx49r165dKikp+cdBhg1TSUmJ6urqTtu/q6tLkUgkZgMADH0JD9Bnn32mnp4eZWdnxzyenZ2t1tbW0/avrKxUKBSKbnwCDgDOD+Y/iFpRUaFwOBzdmpubrZcEADgHEv4puMzMTA0fPlxtbW0xj7e1tSknJ+e0/YPBoILBYKKXAQAY4BJ+BZSamqpp06apuro6+lhvb6+qq6tVXFyc6MMBAAappPwc0KpVq7RkyRJdc801mj59up555hl1dnbqzjvvTMbhAACDUFICtHjxYv3973/X448/rtbWVl199dXaunXraR9MAACcvwLOOWe9iK+KRCIKhUKarfncCQEABqETrls12qJwOKz09PR+9zP/FBwA4PxEgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmBhhvQBgIAmM8P+fxPAxmUlYSWI0PHhJXHM9o3q9Z8ZPPOQ9M+regPdM69Op3jMfXfO694wkfdbT6T1TtPEB75lLV9V7zwwFXAEBAEwQIACAiYQH6IknnlAgEIjZJk2alOjDAAAGuaS8B3TVVVfpvffe+8dB4vi+OgBgaEtKGUaMGKGcnJxk/NUAgCEiKe8B7du3T3l5eZowYYLuuOMOHThwoN99u7q6FIlEYjYAwNCX8AAVFRWpqqpKW7du1fPPP6+mpiZdf/316ujo6HP/yspKhUKh6Jafn5/oJQEABqCEB6isrEy33HKLpk6dqtLSUr3zzjtqb2/XG2+80ef+FRUVCofD0a25uTnRSwIADEBJ/3TA6NGjdfnll6uxsbHP54PBoILBYLKXAQAYYJL+c0BHjhzR/v37lZubm+xDAQAGkYQH6MEHH1Rtba0++eQTffjhh1q4cKGGDx+u2267LdGHAgAMYgn/Ftynn36q2267TYcPH9aYMWM0c+ZM1dfXa8yYMYk+FABgEEt4gF577bVE/5UYoIZfeZn3jAumeM8cvGG098wX1/nfRFKSMkL+c/9RGN+NLoea3x5N85752f+a5z2zY8oG75mm7i+8ZyTpp23/1Xsm7z9cXMc6H3EvOACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADARNJ/IR0Gvp7Z341r7umqdd4zl6ekxnUsnFvdrsd75vHnlnrPjOj0v3Fn8cYV3jNp/3nCe0aSgp/538R01M4dcR3rfMQVEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAExwN2wo2HAwrrldx/K9Zy5PaYvrWEPNAy3Xec/89Uim90zVxH/3npGkcK//Xaqzn/0wrmMNZP5nAT64AgIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATHAzUuhES2tcc8/97BbvmX+d1+k9M3zPRd4zf7j3Oe+ZeD352VTvmcaSUd4zPe0t3jO3F9/rPSNJn/yL/0yB/hDXsXD+4goIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADDBzUgRt4z1dd4zY9662Hum5/Dn3jNXTf5v3jOS9H9nveg985t/u8F7Jqv9Q++ZeATq4rtBaIH/v1rAG1dAAAATBAgAYMI7QNu3b9dNN92kvLw8BQIBbd68OeZ555wef/xx5ebmauTIkSopKdG+ffsStV4AwBDhHaDOzk4VFhZq3bp1fT6/du1aPfvss3rhhRe0Y8cOXXjhhSotLdWxY8e+8WIBAEOH94cQysrKVFZW1udzzjk988wzevTRRzV//nxJ0ssvv6zs7Gxt3rxZt9566zdbLQBgyEjoe0BNTU1qbW1VSUlJ9LFQKKSioiLV1fX9sZquri5FIpGYDQAw9CU0QK2trZKk7OzsmMezs7Ojz52qsrJSoVAouuXn5ydySQCAAcr8U3AVFRUKh8PRrbm52XpJAIBzIKEBysnJkSS1tbXFPN7W1hZ97lTBYFDp6ekxGwBg6EtogAoKCpSTk6Pq6uroY5FIRDt27FBxcXEiDwUAGOS8PwV35MgRNTY2Rr9uamrS7t27lZGRoXHjxmnlypV68sknddlll6mgoECPPfaY8vLytGDBgkSuGwAwyHkHaOfOnbrxxhujX69atUqStGTJElVVVemhhx5SZ2en7r77brW3t2vmzJnaunWrLrjggsStGgAw6AWcc856EV8ViUQUCoU0W/M1IpBivRwMUn/539fGN/dPL3jP3Pm3Od4zf5/Z4T2j3h7/GcDACdetGm1ROBw+4/v65p+CAwCcnwgQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGDC+9cxAIPBlQ//Ja65O6f439l6/fjqs+90ihtuKfeeSXu93nsGGMi4AgIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATHAzUgxJPe3
huOYOL7/Se+bAb77wnvkfT77sPVPxzwu9Z9zHIe8ZScr/1zr/IefiOhbOX1wBAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmuBkp8BW9f/iT98yta37kPfPK6v/pPbP7Ov8bmOo6/xFJuurCFd4zl/2qxXvmxF8/8Z7B0MEVEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgIuCcc9aL+KpIJKJQKKTZmq8RgRTr5QBJ4WZc7T2T/tNPvWdenfB/vGfiNen9/+49c8WasPdMz76/es/g3DrhulWjLQqHw0pPT+93P66AAAAmCBAAwIR3gLZv366bbrpJeXl5CgQC2rx5c8zzS5cuVSAQiNnmzZuXqPUCAIYI7wB1dnaqsLBQ69at63efefPmqaWlJbq9+uqr32iRAIChx/s3opaVlamsrOyM+wSDQeXk5MS9KADA0JeU94BqamqUlZWlK664QsuXL9fhw4f73berq0uRSCRmAwAMfQkP0Lx58/Tyyy+rurpaP/vZz1RbW6uysjL19PT0uX9lZaVCoVB0y8/PT/SSAAADkPe34M7m1ltvjf55ypQpmjp1qiZOnKiamhrNmTPntP0rKiq0atWq6NeRSIQIAcB5IOkfw54wYYIyMzPV2NjY5/PBYFDp6ekxGwBg6Et6gD799FMdPnxYubm5yT4UAGAQ8f4W3JEjR2KuZpqamrR7925lZGQoIyNDa9as0aJFi5STk6P9+/froYce0qWXXqrS0tKELhwAMLh5B2jnzp268cYbo19/+f7NkiVL9Pzzz2vPnj166aWX1N7erry8PM2dO1c/+clPFAwGE7dqAMCgx81IgUFieHaW98zBxZfGdawdD//Ce2ZYHN/Rv6NprvdMeGb/P9aBgYGbkQIABjQCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYSPiv5AaQHD1th7xnsp/1n5GkYw+d8J4ZFUj1nvnVJW97z/zTwpXeM6M27fCeQfJxBQQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmOBmpICB3plXe8/sv+UC75nJV3/iPSPFd2PReDz3+X/xnhm1ZWcSVgILXAEBAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACa4GSnwFYFrJnvP/OVf/G/c+asZL3nPzLrguPfMudTlur1n6j8v8D9Qb4v/DAYkroAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABPcjBQD3oiC8d4z++/Mi+tYTyx+zXtm0UWfxXWsgeyRtmu8Z2p/cZ33zLdeqvOewdDBFRAAwAQBAgCY8ApQZWWlrr32WqWlpSkrK0sLFixQQ0NDzD7Hjh1TeXm5Lr74Yl100UVatGiR2traErpoAMDg5xWg2tpalZeXq76+Xu+++666u7s1d+5cdXZ2Rve5//779dZbb2njxo2qra3VwYMHdfPNNyd84QCAwc3rQwhbt26N+bqqqkpZWVnatWuXZs2apXA4rF//+tfasGGDvve970mS1q9fryuvvFL19fW67jr/NykBAEPTN3oPKBwOS5IyMjIkSbt27VJ3d7dKSkqi+0yaNEnjxo1TXV3fn3bp6upSJBKJ2QAAQ1/cAert7dXKlSs1Y8YMTZ48WZLU2tqq1NRUjR49Ombf7Oxstba29vn3VFZWKhQKRbf8/Px4lwQAGETiDlB5ebn27t2r117z/7mJr6qoqFA4HI5uzc3N3+jvAwAMDnH9IOqKFSv09ttva/v27Ro7dmz08ZycHB0/flzt7e0xV0FtbW3Kycnp8+8KBoMKBoPxLAMAMIh5XQE557RixQpt2rRJ27ZtU0FBQczz06ZNU0pKiqqrq6OPNTQ06MCBAyouLk7MigEAQ4LXFVB5ebk2bNigLVu2KC0tLfq+TigU0siRIxUKhXTXXXdp1apVysjIUHp6uu677z4VFxfzCTgAQAyvAD3//POSpNmzZ8c8vn79ei1dulSS9POf/1zDhg3TokWL1NXVpdLSUv3yl79MyGIBAENHwDnnrBfxVZFIRKFQSLM1XyMCKdbLwRmMuGSc90x4Wq73zOIfbz37Tqe4Z/RfvWcGugda/L+LUPdL/5uKSlJG1e/9h3p74joWhp4Trls12qJwOKz09PR+9+NecAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADAR129ExcA1Irfv3zx7Jp+/eGFcx1peUOs9c1taW1zHGshW/OdM75mPnr/aeybz3/d6z2R01HnPAOcKV0AAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAluRnqOHC+9xn/m/s+9Zx659B3vmbkjO71nBrq2ni/impv1mwe8ZyY9+mfvmYx2/5uE9npPAAMbV0AAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAluRnqOfLLAv/V/mbIxCStJnHXtE71nflE713sm0BPwnpn0ZJP3jCRd1rbDe6YnriMB4AoIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADARcM4560V8VSQSUSgU0mzN14hAivVyAACeTrhu1WiLwuGw0tPT+92PKyAAgAkCBAAw4RWgyspKXXvttUpLS1NWVpYWLFighoaGmH1mz56tQCAQs91zzz0JXTQAYPDzClBtba3Ky8tVX1+vd999V93d3Zo7d646Oztj9lu2bJlaWlqi29q1axO6aADA4Of1G1G3bt0a83VVVZWysrK0a9cuzZo1K/r4qFGjlJOTk5gVAgCGpG/0HlA4HJYkZWRkxDz+yiuvKDMzU5MnT1ZFRYWOHj3a79/R1dWlSCQSswEAhj6vK6Cv6u3t1cqVKzVjxgxNnjw5+vjtt9+u8ePHKy8vT3v27NHDDz+shoYGvfnmm33+PZWVlVqzZk28ywAADFJx/xzQ8uXL9dvf/lYffPCBxo4d2+9+27Zt05w5c9TY2KiJEyee9nxXV5e6urqiX0ciEeXn5/NzQAAwSH3dnwOK6wpoxYoVevvtt7V9+/YzxkeSioqKJKnfAAWDQQWDwXiWAQAYxLwC5JzTfffdp02bNqmmpkYFBQVnndm9e7ckKTc3N64FAgCGJq8AlZeXa8OGDdqyZYvS0tLU2toqSQqFQho5cqT279+vDRs26Pvf/74uvvhi7dmzR/fff79mzZqlqVOnJuUfAAAwOHm9BxQIBPp8fP369Vq6dKmam5v1gx
/8QHv37lVnZ6fy8/O1cOFCPfroo2f8PuBXcS84ABjckvIe0NlalZ+fr9raWp+/EgBwnuJecAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEyOsF3Aq55wk6YS6JWe8GACAtxPqlvSP/573Z8AFqKOjQ5L0gd4xXgkA4Jvo6OhQKBTq9/mAO1uizrHe3l4dPHhQaWlpCgQCMc9FIhHl5+erublZ6enpRiu0x3k4ifNwEufhJM7DSQPhPDjn1NHRoby8PA0b1v87PQPuCmjYsGEaO3bsGfdJT08/r19gX+I8nMR5OInzcBLn4STr83CmK58v8SEEAIAJAgQAMDGoAhQMBrV69WoFg0HrpZjiPJzEeTiJ83AS5+GkwXQeBtyHEAAA54dBdQUEABg6CBAAwAQBAgCYIEAAABODJkDr1q3TJZdcogsuuEBFRUX6/e9/b72kc+6JJ55QIBCI2SZNmmS9rKTbvn27brrpJuXl5SkQCGjz5s0xzzvn9Pjjjys3N1cjR45USUmJ9u3bZ7PYJDrbeVi6dOlpr4958+bZLDZJKisrde211yotLU1ZWVlasGCBGhoaYvY5duyYysvLdfHFF+uiiy7SokWL1NbWZrTi5Pg652H27NmnvR7uueceoxX3bVAE6PXXX9eqVau0evVqffTRRyosLFRpaakOHTpkvbRz7qqrrlJLS0t0++CDD6yXlHSdnZ0qLCzUunXr+nx+7dq1evbZZ/XCCy9ox44duvDCC1VaWqpjx46d45Um19nOgyTNmzcv5vXx6quvnsMVJl9tba3Ky8tVX1+vd999V93d3Zo7d646Ozuj+9x///166623tHHjRtXW1urgwYO6+eabDVedeF/nPEjSsmXLYl4Pa9euNVpxP9wgMH36dFdeXh79uqenx+Xl5bnKykrDVZ17q1evdoWFhdbLMCXJbdq0Kfp1b2+vy8nJcU899VT0sfb2dhcMBt2rr75qsMJz49Tz4JxzS5YscfPnzzdZj5VDhw45Sa62ttY5d/LffUpKitu4cWN0nz/96U9Okqurq7NaZtKdeh6cc+6GG25wP/zhD+0W9TUM+Cug48ePa9euXSopKYk+NmzYMJWUlKiurs5wZTb27dunvLw8TZgwQXfccYcOHDhgvSRTTU1Nam1tjXl9hEIhFRUVnZevj5qaGmVlZemKK67Q8uXLdfjwYeslJVU4HJYkZWRkSJJ27dql7u7umNfDpEmTNG7cuCH9ejj1PHzplVdeUWZmpiZPnqyKigodPXrUYnn9GnA3Iz3VZ599pp6eHmVnZ8c8np2drT//+c9Gq7JRVFSkqqoqXXHFFWppadGaNWt0/fXXa+/evUpLS7NenonW1lZJ6vP18eVz54t58+bp5ptvVkFBgfbv369HHnlEZWVlqqur0/Dhw62Xl3C9vb1auXKlZsyYocmTJ0s6+XpITU3V6NGjY/Ydyq+Hvs6DJN1+++0aP3688vLytGfPHj388MNqaGjQm2++abjaWAM+QPiHsrKy6J+nTp2qoqIijR8/Xm+88Ybuuusuw5VhILj11lujf54yZYqmTp2qiRMnqqamRnPmzDFcWXKUl5dr796958X7oGfS33m4++67o3+eMmWKcnNzNWfOHO3fv18TJ04818vs04D/FlxmZqaGDx9+2qdY2tralJOTY7SqgWH06NG6/PLL1djYaL0UM1++Bnh9nG7ChAnKzMwckq+PFStW6O2339b7778f8+tbcnJydPz4cbW3t8fsP1RfD/2dh74UFRVJ0oB6PQz4AKWmpmratGmqrq6OPtbb26vq6moVFxcbrszekSNHtH//fuXm5lovxUxBQYFycnJiXh+RSEQ7duw4718fn376qQ4fPjykXh/OOa1YsUKbNm3Stm3bVFBQEPP8tGnTlJKSEvN6aGho0IEDB4bU6+Fs56Evu3fvlqSB9Xqw/hTE1/Haa6+5YDDoqqqq3B//+Ed39913u9GjR7vW1lbrpZ1TDzzwgKupqXFNTU3ud7/7nSspKXGZmZnu0KFD1ktLqo6ODvfxxx+7jz/+2ElyTz/9tPv444/d3/72N+eccz/96U/d6NGj3ZYtW9yePXvc/PnzXUFBgfviiy+MV55YZzoPHR0d7sEHH3R1dXWuqanJvffee+673/2uu+yyy9yxY8esl54wy5cvd6FQyNXU1LiWlpbodvTo0eg+99xzjxs3bpzbtm2b27lzpysuLnbFxcWGq068s52HxsZG9+Mf/9jt3LnTNTU1uS1btrgJEya4WbNmGa881qAIkHPOPffcc27cuHEuNTXVTZ8+3dXX11sv6ZxbvHixy83Ndampqe7b3/62W7x4sWtsbLReVtK9//77TtJp25IlS5xzJz+K/dhjj7ns7GwXDAbdnDlzXENDg+2ik+BM5+Ho0aNu7ty5bsyYMS4lJcWNHz/eLVu2bMj9n7S+/vklufXr10f3+eKLL9y9997rvvWtb7lRo0a5hQsXupaWFrtFJ8HZzsOBAwfcrFmzXEZGhgsGg+7SSy91P/rRj1w4HLZd+Cn4dQwAABMD/j0gAMDQRIAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCY+H+FuPwJ5J7kjwAAAABJRU5ErkJggg==\n"
665 | },
666 | "metadata": {}
667 | }
668 | ],
669 | "source": [
670 | "plt.imshow(img.permute(1, 2, 0))\n",
671 | "plt.show()"
672 | ]
673 | },
674 | {
675 | "cell_type": "markdown",
676 | "metadata": {
677 | "id": "yfIVxwixkQdY"
678 | },
679 | "source": [
680 | "To acquire predictions from different submodels one should transform input (with shape [1, H, W, C]) into batch (with shape [M, H, W, C]) that consists of M copies of original input (H - height of image, W - width of image, C - number of channels).\n",
681 | "\n",
682 | "As we can see Masksembles submodels produce similar predictions for training set samples."
683 | ]
684 | },
685 | {
686 | "cell_type": "code",
687 | "execution_count": 18,
688 | "metadata": {
689 | "colab": {
690 | "base_uri": "https://localhost:8080/"
691 | },
692 | "id": "thkg1KJJjY85",
693 | "outputId": "f40287cb-fd9b-4301-c175-ef3488e37ee8"
694 | },
695 | "outputs": [
696 | {
697 | "output_type": "stream",
698 | "name": "stdout",
699 | "text": [
700 | "PREDICTION OF 1 MODEL: 3 CLASS\n",
701 | "PREDICTION OF 2 MODEL: 5 CLASS\n",
702 | "PREDICTION OF 3 MODEL: 5 CLASS\n",
703 | "PREDICTION OF 4 MODEL: 5 CLASS\n"
704 | ]
705 | }
706 | ],
707 | "source": [
708 | "# Suppose img has shape (28, 28, 1)\n",
709 | "inputs = torch.tile(img[None], (4, 1, 1, 1)) # (4, 1, 28, 28)\n",
710 | "\n",
711 | "# Send to same device as model\n",
712 | "device = next(model.parameters()).device\n",
713 | "inputs_t = inputs.to(device)\n",
714 | "\n",
715 | "# Inference\n",
716 | "model.eval()\n",
717 | "with torch.no_grad():\n",
718 | " predictions = model(inputs_t) # (4, num_classes)\n",
719 | " predicted_classes = predictions.argmax(dim=1)\n",
720 | "\n",
721 | "# Print predictions\n",
722 | "for i, cls in enumerate(predicted_classes):\n",
723 | " print(f\"PREDICTION OF {i+1} MODEL: {cls.item()} CLASS\")\n"
724 | ]
725 | },
726 | {
727 | "cell_type": "markdown",
728 | "metadata": {
729 | "id": "KH81gNyGlcnh"
730 | },
731 | "source": [
732 | "On out-of-distribution samples Masksembles should produce predictions with high variance, let's check it on complex samples from MNIST."
733 | ]
734 | },
735 | {
736 | "cell_type": "code",
737 | "execution_count": 19,
738 | "metadata": {
739 | "id": "l9rbe1s9fBDf"
740 | },
741 | "outputs": [],
742 | "source": [
743 | "img = np.array(np.load(\"./complex_sample_mnist.npy\")[::-1, ::-1])\n",
744 | "img = torch.tensor(img).permute(2, 0, 1).float()"
745 | ]
746 | },
747 | {
748 | "cell_type": "code",
749 | "execution_count": 20,
750 | "metadata": {
751 | "colab": {
752 | "base_uri": "https://localhost:8080/",
753 | "height": 448
754 | },
755 | "id": "n6YghcdrfdkZ",
756 | "outputId": "d0cda917-f87b-40c9-a424-c5a2c214f644"
757 | },
758 | "outputs": [
759 | {
760 | "output_type": "execute_result",
761 | "data": {
762 | "text/plain": [
763 | ""
764 | ]
765 | },
766 | "metadata": {},
767 | "execution_count": 20
768 | },
769 | {
770 | "output_type": "display_data",
771 | "data": {
772 | "text/plain": [
773 | ""
774 | ],
775 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaAAAAGdCAYAAABU0qcqAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAHNBJREFUeJzt3X9w1fW95/HXSSAHkOTEEPKrBBpQwQqktyhpFqVYMoS01wvKuP7qLDgurjS4RWp101XRtjtpcdY6elOde7cF3Sv+miswWksvBhPWNuCCUC7bmiVpLKEkodKbnBBMyI/P/sF67JEg/RzO4Z2E52PmO0PO+b7yffP1i698c04+CTjnnAAAuMCSrAcAAFycKCAAgAkKCABgggICAJiggAAAJiggAIAJCggAYIICAgCYGGU9wKcNDAzo6NGjSk1NVSAQsB4HAODJOafOzk7l5eUpKens9zlDroCOHj2q/Px86zEAAOepublZkyZNOuvzQ66AUlNTJUnX6msapdHG0wAAfPWpV+/ozcj/z88mYQVUVVWlxx9/XK2trSosLNTTTz+tuXPnnjP38bfdRmm0RgUoIAAYdv7/CqPnehklIW9CePnll7V27VqtW7dO7733ngoLC1VaWqpjx44l4nAAgGEoIQX0xBNPaOXKlbrzzjv1hS98Qc8++6zGjRunn/3sZ4k4HABgGIp7AZ06dUp79+5VSUnJJwdJSlJJSYnq6urO2L+np0fhcDhqAwCMfHEvoA8//FD9/f3Kzs6Oejw7O1utra1n7F9ZWalQKBTZeAccAFwczH8QtaKiQh0dHZGtubnZeiQAwAUQ93fBZWZmKjk5WW1tbVGPt7W1KScn54z9g8GggsFgvMcAAAxxcb8DSklJ0Zw5c1RdXR15bGBgQNXV1SouLo734QAAw1RCfg5o7dq1Wr58ua6++mrNnTtXTz75pLq6unTnnXcm4nAAgGEoIQV0yy236E9/+pMeeeQRtba26otf/KK2bdt2xhsTAAAXr4BzzlkP8ZfC4bBCoZAWaAkrIVwgo3LPfG3ur/HHm6d6Z04Wd3lnxr57iXcm73/8q3dGkgY6O2PKAfhEn+tVjbaqo6NDaWlpZ93P/F1wAICLEwUEADBBAQEATFBAAAATFBAAwAQFBAAwQQEBAExQQAAAExQQAMAEBQQAMEEBAQBMUEAAABMJWQ0bw8zYMTHFOmb0eWce+5ufe2f+W8PN3plAMl9bAUMd/0oBACYoIACACQoIAGCCAgIAmKCAAAAmKCAAgAkKCABgggICAJiggAAAJiggAIAJCggAYIICAgCYoIAAACZYDRvqn5AaU27OrN97Zz7sS/POXPJH74hcr/9K3QAuLO6AAAAmKCAAgAkKCABgggICAJiggAAAJiggAIAJCggAYIICAgCYoIAAACYoIACACQoIAGCCAgIAmGAxUsSsbyDZO5MUGPA/UMA/AmDo4w4IAGCCAgIAmKCAAAAmKCAAgAkKCABgggICAJiggAAAJiggAIAJCggAYIICAgCYoIAAACYoIACACRYjhZI6TsaU+01jvnfm6kv/4J1J6vOOSM7FEAJwIXEHBAAwQQEBAEzEvYAeffRRBQKBqG3GjBnxPgwAYJhLyGtAV111ld56661PDjKKl5oAANES0gyjRo1STk5OIj41AGCESMhrQIcOHVJeXp6mTp2qO+64Q4cPHz7rvj09PQqHw1EbAGDki3sBFRUVaePGjdq2bZueeeYZNTU16brrrlNnZ+eg+1dWVioUCkW2/Hz/t/YCAIafuBdQWVmZbr75Zs2ePVulpaV688031d7erldeeWXQ/SsqKtTR0RHZmpub4z0SAGAISvi7A9LT03XFFVeooaFh0OeDwaCCwWCixwAADDEJ/zmgEydOqLGxUbm5uYk+FABgGIl7Ad1///2qra3VBx98oF//+te68cYblZycrNtuuy3ehwIADGNx/xbckSNHdNttt+n48eOaOHGirr32Wu3atUsTJ06M96EAAMNY3AvopZdeivenRIIF+vpjy3X6Xz5XjGnxzhz/cq93Jnt7pndGkgY+OPuPDACIL9aCAwCYoIAAACYoIACACQoIAGCCAgIAmKCAAAAmKCAAgAkKCABgggICAJiggAAAJiggAIAJCggAYCLhv5AOw8DJj2KKpTb5f/2SltTtndmw4Gfemf+6baV3RpJSjxz1zri+vpiOBVzsuAMCAJiggAAAJiggAIAJCggAYIICAgCYoIAAACYoIACACQoIAGCCAgIAmKCAAAAmKCAAgAkKCABgggICAJhgNWyor7Utptzn/sl/FejH/7bUO/MvV27xzrR8PbYVqsc3f8E7k9zU6n+g/n7/TAxcjCudD3T3xBC6MH8njBzcAQEATFBAAAATFBAAwAQFBAAwQQEBAExQQAAAExQQAMAEBQQAMEEBAQBMUEAAABMUEADABAUEADDBYqSIWf/xP3tn/vzadO/MC/85yzvzP+f/o3dGkn599eXemdoPr/DO9PaneGcCAeed+eB/+Z9vSSp47d+8MwO/+V1Mx8LFizsgAIAJCggAYIICAgCYoIAAACYoIACACQoIAGCCAgIAmKCAAAAmKCAAgAkKCABgggICAJiggAAAJliMFLFz/otj5v5zg3fm77tu9s78x/+y1TsjSd/JaLwgmVgc6Tvhnfm7E3fFdKzuX6d6Z1J+E9OhcBHjDggAYIICAgCY8C6gnTt36oYbblBeXp4CgYC2bNkS9bxzTo888ohyc3M1duxYlZSU6NChQ/GaFwAwQngXUFdXlwoLC1VVVTXo8+vXr9dTTz2lZ599Vrt379Yll1yi0tJSdXd3n/ewAICRw/tNCGVlZSorKxv0OeecnnzyST300ENasmSJJOn5559Xdna2tmzZoltvvfX8pgUAjBhxfQ2oqalJra2tKikpiTwWCoVUVFSkurq6QTM9PT0Kh8NRGwBg5ItrAbW2tkqSsrOzox7Pzs6OPPdplZWVCoVCkS0/Pz+eIwEAhijzd8FVVFSoo6MjsjU3N1uPBAC4AOJaQDk5OZKktra2qMfb2toiz31aMBhUWlpa1AYAGPniWkAFBQXKyclRdXV15LFwOKzdu3eruLg4nocCAAxz3u+CO3HihBoaPllOpampSfv371dGRoYmT56sNWvW6Ac/+IEuv/xyFRQU6OGHH1ZeXp6WLl0az7kBAMOcdwHt2bNH119/feTjtWvXSpKWL1+ujRs36oEHHlBXV5fuvvtutbe369prr9W2bds0ZsyY+E0NABj2As7FsKJkAoXDYYVCIS3QEo0KjLYeB0NAcuYE70zD05NiOtb//cpzMeV8/Z9TH3ln/v0/fNs7k7+90zsjSYH3P/DODHTGdiyMPH2uVzXaqo6Ojs98Xd/8XXAAgIsTBQQAMEEBAQBMUEAAABMUEADABAUEADBBAQEATFBAAAATFBAAwAQFBAAwQQEBAExQQAAAExQQAMCE969jAM7HqCn53pkjN/lnvjn7Te+MJJ0cOOWd2XfK/5/R3f/ov7L15//psHe
mv6XVOyNJA319MeUAH9wBAQBMUEAAABMUEADABAUEADBBAQEATFBAAAATFBAAwAQFBAAwQQEBAExQQAAAExQQAMAEBQQAMMFipIhZIBj0zhy6Z5J35o3bH/fOTEwKeGck6d4ji7wzjd+70juT/y/vemf6WCAUIwx3QAAAExQQAMAEBQQAMEEBAQBMUEAAABMUEADABAUEADBBAQEATFBAAAATFBAAwAQFBAAwQQEBAEywGCmkpOSYYoEZU70zt31tp3emYNQY78zf7P4P3hlJCv4yzTuTtWOfd2aAhUUB7oAAADYoIACACQoIAGCCAgIAmKCAAAAmKCAAgAkKCABgggICAJiggAAAJiggAIAJCggAYIICAgCYYDFSKDkjPaZcy7xLvTO3hv63d+bnJ7O8M6EXx3tnJCn01vv+oUvGeUeSp03xzgyMHe2diVXykT95Z/pa2xIwCUYy7oAAACYoIACACe8C2rlzp2644Qbl5eUpEAhoy5YtUc+vWLFCgUAgalu8eHG85gUAjBDeBdTV1aXCwkJVVVWddZ/FixerpaUlsr344ovnNSQAYOTxfhNCWVmZysrKPnOfYDConJycmIcCAIx8CXkNqKamRllZWZo+fbpWrVql48ePn3Xfnp4ehcPhqA0AMPLFvYAWL16s559/XtXV1frRj36k2tpalZWVqb+/f9D9KysrFQqFIlt+fn68RwIADEFx/zmgW2+9NfLnWbNmafbs2Zo2bZpqamq0cOHCM/avqKjQ2rVrIx+Hw2FKCAAuAgl/G/bUqVOVmZmphoaGQZ8PBoNKS0uL2gAAI1/CC+jIkSM6fvy4cnNzE30oAMAw4v0tuBMnTkTdzTQ1NWn//v3KyMhQRkaGHnvsMS1btkw5OTlqbGzUAw88oMsuu0ylpaVxHRwAMLx5F9CePXt0/fXXRz7++PWb5cuX65lnntGBAwf03HPPqb29XXl5eVq0aJG+//3vKxgMxm9qAMCw511ACxYskHPurM//8pe/PK+BYCA9ttfdOq4a/J2Nn2VAAe/M73v8FyM9Piu27y4fLZvmncnM8v/RgS9OPOKdyUjp8s7Equ5YgXem/Zf/zjsz6efHvDP99YO/nozhh7XgAAAmKCAAgAkKCABgggICAJiggAAAJiggAIAJCggAYIICAgCYoIAAACYoIACACQoIAGCCAgIAmKCAAAAm4v4ruTH8BD7qiSk39o/J3pn+GFbDvjH1gHdm/M3d3hlJWp72B+9MMDDaO3PgVGzz+bpqdEpswaz3vCPfy5nlnXl13Fe8M59/ecA703/o994ZJB53QAAAExQQAMAEBQQAMEEBAQBMUEAAABMUEADABAUEADBBAQEATFBAAAATFBAAwAQFBAAwQQEBAEywGCnUf+zDmHJTXgl6Z27TWu9M0jXt3pm0sbEt9vlO6DLvzL8ey/POBLZd6p0JtjvvzL/NiO1rzK//7S7vzH/P9V/AtOnrE7wz7zdf5Z25lMVIhyTugAAAJiggAIAJCggAYIICAgCYoIAAACYoIACACQoIAGCCAgIAmKCAAAAmKCAAgAkKCABgggICAJhgMVLI9Z6KKdff0OSdmfL3x70zPVdf7p3pGzfeOyNJv83O9s6k/7HPOzNm52+8MwMnT3pnLk0PeWck6bXPfck7E8tipONH+V97/f5r4GKI4g4IAGCCAgIAmKCAAAAmKCAAgAkKCABgggICAJiggAAAJiggAIAJCggAYIICAgCYoIAAACYoIACACRYjxQXVHw57Z0bt2Ouf8U6cNibGnK+BGDKBUf5/q0AoLYYjSeq9MF+bpiT5L+SqQPzngA3ugAAAJiggAIAJrwKqrKzUNddco9TUVGVlZWnp0qWqr6+P2qe7u1vl5eWaMGGCxo8fr2XLlqmtrS2uQwMAhj+vAqqtrVV5ebl27dql7du3q7e3V4sWLVJXV1dkn/vuu0+vv/66Xn31VdXW1uro0aO66aab4j44AGB483pVc9u2bVEfb9y4UVlZWdq7d6/mz5+vjo4O/fSnP9WmTZv01a9+VZK0YcMGXXnlldq1a5e+/OUvx29yAMCwdl6vAXV0dEiSMjIyJEl79+5Vb2+vSkpKIvvMmDFDkydPVl1d3aCfo6enR+FwOGoDAIx8MRfQwMCA1qxZo3nz5mnmzJmSpNbWVqWkpCg9PT1q3+zsbLW2tg76eSorKxUKhSJbfn5+rCMBAIaRmAuovLxcBw8e1EsvvXReA1RUVKijoyOyNTc3n9fnAwAMDzH9vN7q1av1xhtvaOfOnZo0aVLk8ZycHJ06dUrt7e1Rd0FtbW3KyckZ9HMFg0EFg8FYxgAADGNed0DOOa1evVqbN2/Wjh07VFBQEPX8nDlzNHr0aFVXV0ceq6+v1+HDh1VcXByfiQEAI4LXHVB5ebk2bdqkrVu3KjU1NfK6TigU0tixYxUKhXTXXXdp7dq1ysjIUFpamu69914VFxfzDjgAQBSvAnrmmWckSQsWLIh6fMOGDVqxYoUk6cc//rGSkpK0bNky9fT0qLS0VD/5yU/iMiwAYOTwKiDn3Dn3GTNmjKqqqlRVVRXzUMCIF/BfUTN5YqZ35v0f+Gck6bl5/+Cdeben1zuzdf8XvTOXH/zIO4OhibXgAAAmKCAAgAkKCABgggICAJiggAAAJiggAIAJCggAYIICAgCYoIAAACYoIACACQoIAGCCAgIAmKCAAAAmYvqNqAAMJPl/vegG/FfdlqSdJ2Z4ZzYc8P+lkwUve0cU+NV+/xCGJO6AAAAmKCAAgAkKCABgggICAJiggAAAJiggAIAJCggAYIICAgCYoIAAACYoIACACQoIAGCCAgIAmGAxUsCCc96RvqMt3pnp/+nP3hlJ+tXoid6ZK3p/550ZONXrncHIwR0QAMAEBQQAMEEBAQBMUEAAABMUEADABAUEADBBAQEATFBAAAATFBAAwAQFBAAwQQEBAExQQAAAEyxGCgwXMSxgOtDdHduxYs0BHrgDAgCYoIAAACYoIACACQoIAGCCAgIAmKCAAAAmKCAAgAkKCABgggICAJiggAAAJiggAIAJCggAYIICAgCYoIAAACYoIACACa8Cqqys1DXXXKPU1FRlZWVp6dKlqq+vj9pnwYIFCgQCUds999wT16EBAMOfVwHV1taqvLxcu3bt0vbt29Xb26tFixapq6srar+VK1eqpaUlsq1fvz6uQwMAhj+v34i6bdu2qI83btyorKws7d27V/Pnz488Pm7cOOXk5MRnQgDAiHRerwF1dHRIkjIyMqIef+GFF5SZmamZM2eqoqJCJ0+ePOvn6OnpUTgcjtoAACOf1x3QXxoYGNCaNWs0b948zZw5M/L47bffrilTpigvL08HDhzQgw8+qPr6er322muDfp7Kyko99thjsY4BABimAs45F0tw1apV+sUvfqF33nlHkyZNOut+O3bs0MKFC9XQ0KBp06ad8XxPT496enoiH4fDYeXn52uBlmhUYHQsowEADPW5XtVoqzo6OpSWlnbW/WK6A1q9erXeeOMN7dy58zPLR5KKiook6awFFAwGFQ
wGYxkDADCMeRWQc0733nuvNm/erJqaGhUUFJwzs3//fklSbm5uTAMCAEYmrwIqLy/Xpk2btHXrVqWmpqq1tVWSFAqFNHbsWDU2NmrTpk362te+pgkTJujAgQO67777NH/+fM2ePTshfwEAwPDk9RpQIBAY9PENGzZoxYoVam5u1je+8Q0dPHhQXV1dys/P14033qiHHnroM78P+JfC4bBCoRCvAQHAMJWQ14DO1VX5+fmqra31+ZQAgIsUa8EBAExQQAAAExQQAMAEBQQAMEEBAQBMUEAAABMUEADABAUEADBBAQEATFBAAAATFBAAwAQFBAAwQQEBAExQQAAAExQQAMAEBQQAMEEBAQBMUEAAABMUEADABAUEADBBAQEATFBAAAATFBAAwAQFBAAwQQEBAEyMsh7g05xzkqQ+9UrOeBgAgLc+9Ur65P/nZzPkCqizs1OS9I7eNJ4EAHA+Ojs7FQqFzvp8wJ2roi6wgYEBHT16VKmpqQoEAlHPhcNh5efnq7m5WWlpaUYT2uM8nMZ5OI3zcBrn4bShcB6cc+rs7FReXp6Sks7+Ss+QuwNKSkrSpEmTPnOftLS0i/oC+xjn4TTOw2mch9M4D6dZn4fPuvP5GG9CAACYoIAAACaGVQEFg0GtW7dOwWDQehRTnIfTOA+ncR5O4zycNpzOw5B7EwIA4OIwrO6AAAAjBwUEADBBAQEATFBAAAATw6aAqqqq9PnPf15jxoxRUVGR3n33XeuRLrhHH31UgUAgapsxY4b1WAm3c+dO3XDDDcrLy1MgENCWLVuinnfO6ZFHHlFubq7Gjh2rkpISHTp0yGbYBDrXeVixYsUZ18fixYtthk2QyspKXXPNNUpNTVVWVpaWLl2q+vr6qH26u7tVXl6uCRMmaPz48Vq2bJna2tqMJk6Mv+Y8LFiw4Izr4Z577jGaeHDDooBefvllrV27VuvWrdN7772nwsJClZaW6tixY9ajXXBXXXWVWlpaIts777xjPVLCdXV1qbCwUFVVVYM+v379ej311FN69tlntXv3bl1yySUqLS1Vd3f3BZ40sc51HiRp8eLFUdfHiy++eAEnTLza2lqVl5dr165d2r59u3p7e7Vo0SJ1dXVF9rnvvvv0+uuv69VXX1Vtba2OHj2qm266yXDq+PtrzoMkrVy5Mup6WL9+vdHEZ+GGgblz57ry8vLIx/39/S4vL89VVlYaTnXhrVu3zhUWFlqPYUqS27x5c+TjgYEBl5OT4x5//PHIY+3t7S4YDLoXX3zRYMIL49PnwTnnli9f7pYsWWIyj5Vjx445Sa62ttY5d/q//ejRo92rr74a2ed3v/udk+Tq6uqsxky4T58H55z7yle+4r71rW/ZDfVXGPJ3QKdOndLevXtVUlISeSwpKUklJSWqq6sznMzGoUOHlJeXp6lTp+qOO+7Q4cOHrUcy1dTUpNbW1qjrIxQKqaio6KK8PmpqapSVlaXp06dr1apVOn78uPVICdXR0SFJysjIkCTt3btXvb29UdfDjBkzNHny5BF9PXz6PHzshRdeUGZmpmbOnKmKigqdPHnSYryzGnKLkX7ahx9+qP7+fmVnZ0c9np2drffff99oKhtFRUXauHGjpk+frpaWFj322GO67rrrdPDgQaWmplqPZ6K1tVWSBr0+Pn7uYrF48WLddNNNKigoUGNjo7773e+qrKxMdXV1Sk5Oth4v7gYGBrRmzRrNmzdPM2fOlHT6ekhJSVF6enrUviP5ehjsPEjS7bffrilTpigvL08HDhzQgw8+qPr6er322muG00Yb8gWET5SVlUX+PHv2bBUVFWnKlCl65ZVXdNdddxlOhqHg1ltvjfx51qxZmj17tqZNm6aamhotXLjQcLLEKC8v18GDBy+K10E/y9nOw9133x3586xZs5Sbm6uFCxeqsbFR06ZNu9BjDmrIfwsuMzNTycnJZ7yLpa2tTTk5OUZTDQ3p6em64oor1NDQYD2KmY+vAa6PM02dOlWZmZkj8vpYvXq13njjDb399ttRv74lJydHp06dUnt7e9T+I/V6ONt5GExRUZEkDanrYcgXUEpKiubMmaPq6urIYwMDA6qurlZxcbHhZPZOnDihxsZG5ebmWo9ipqCgQDk5OVHXRzgc1u7duy/66+PIkSM6fvz4iLo+nHNavXq1Nm/erB07dqigoCDq+Tlz5mj06NFR10N9fb0OHz48oq6Hc52Hwezfv1+Shtb1YP0uiL/GSy+95ILBoNu4caP77W9/6+6++26Xnp7uWltbrUe7oL797W+7mpoa19TU5H71q1+5kpISl5mZ6Y4dO2Y9WkJ1dna6ffv2uX379jlJ7oknnnD79u1zf/jDH5xzzv3whz906enpbuvWre7AgQNuyZIlrqCgwH300UfGk8fXZ52Hzs5Od//997u6ujrX1NTk3nrrLfelL33JXX755a67u9t69LhZtWqVC4VCrqamxrW0tES2kydPRva555573OTJk92OHTvcnj17XHFxsSsuLjacOv7OdR4aGhrc9773Pbdnzx7X1NTktm7d6qZOnermz59vPHm0YVFAzjn39NNPu8mTJ7uUlBQ3d+5ct2vXLuuRLrhbbrnF5ebmupSUFPe5z33O3XLLLa6hocF6rIR7++23naQztuXLlzvnTr8V++GHH3bZ2dkuGAy6hQsXuvr6etuhE+CzzsPJkyfdokWL3MSJE93o0aPdlClT3MqVK0fcF2mD/f0luQ0bNkT2+eijj9w3v/lNd+mll7px48a5G2+80bW0tNgNnQDnOg+HDx928+fPdxkZGS4YDLrLLrvMfec733EdHR22g38Kv44BAGBiyL8GBAAYmSggAIAJCggAYIICAgCYoIAAACYoIACACQoIAGCCAgIAmKCAAAAmKCAAgAkKCABgggICAJj4f016Bym3u3cJAAAAAElFTkSuQmCC\n"
776 | },
777 | "metadata": {}
778 | }
779 | ],
780 | "source": [
781 | "plt.imshow(img.permute(1, 2, 0))"
782 | ]
783 | },
784 | {
785 | "cell_type": "code",
786 | "execution_count": 21,
787 | "metadata": {
788 | "colab": {
789 | "base_uri": "https://localhost:8080/"
790 | },
791 | "id": "65Y_-0h9U96h",
792 | "outputId": "559893c1-38d5-48fa-bde9-db0f26c421bb"
793 | },
794 | "outputs": [
795 | {
796 | "output_type": "stream",
797 | "name": "stdout",
798 | "text": [
799 | "PREDICTION OF 1 MODEL: 3 CLASS\n",
800 | "PREDICTION OF 2 MODEL: 5 CLASS\n",
801 | "PREDICTION OF 3 MODEL: 5 CLASS\n",
802 | "PREDICTION OF 4 MODEL: 5 CLASS\n"
803 | ]
804 | }
805 | ],
806 | "source": [
807 | "# Suppose img has shape (28, 28, 1)\n",
808 | "# Duplicate it 4 times\n",
809 | "inputs = torch.tile(img[None], (4, 1, 1, 1)) # (4, 28, 28, 1)\n",
810 | "\n",
811 | "# Move to same device as model\n",
812 | "device = next(model.parameters()).device\n",
813 | "inputs_t = inputs.to(device)\n",
814 | "\n",
815 | "# Run the model\n",
816 | "model.eval()\n",
817 | "with torch.no_grad():\n",
818 | " predictions = model(inputs_t) # shape: (4, num_classes)\n",
819 | " predicted_classes = predictions.argmax(dim=1) # argmax over class dimension\n",
820 | "\n",
821 | "# Print results like in TensorFlow\n",
822 | "for i, cls in enumerate(predicted_classes):\n",
823 | " print(f\"PREDICTION OF {i+1} MODEL: {cls.item()} CLASS\")\n"
824 | ]
825 | },
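826 | {
827 | "cell_type": "markdown",
828 | "metadata": {},
829 | "source": [
830 | "The spread between submodels is itself a useful signal. A minimal sketch (added for illustration, reusing `predictions` from the cell above): averaging the per-submodel probabilities gives an ensemble prediction, while their standard deviation is a simple disagreement score that should be large on out-of-distribution samples like this one.\n"
831 | ]
832 | },
833 | {
834 | "cell_type": "code",
835 | "source": [
836 | "# Aggregate the four submodel outputs from the cell above.\n",
837 | "# `predictions` already holds probabilities because the model's forward applies softmax.\n",
838 | "mean_probs = predictions.mean(dim=0)  # ensemble prediction, shape (num_classes,)\n",
839 | "std_probs = predictions.std(dim=0)    # per-class disagreement across submodels\n",
840 | "\n",
841 | "print(\"ENSEMBLE PREDICTION:\", mean_probs.argmax().item(), \"CLASS\")\n",
842 | "print(\"MAX PER-CLASS STD (disagreement):\", round(std_probs.max().item(), 4))"
843 | ],
844 | "metadata": {},
845 | "execution_count": null,
846 | "outputs": []
847 | },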
826 | {
827 | "cell_type": "markdown",
828 | "metadata": {
829 | "id": "3e0fe845"
830 | },
831 | "source": [
832 | "\n",
833 | "## 8. Next Steps \n",
834 | "\n",
835 | "- Try different `n` and `scale` to explore accuracy vs. diversity. \n",
836 | "- Aggregate submodel logits (mean/median) for stronger predictions. \n",
837 | "- Visualize per‑submodel confidence for **uncertainty estimation**. \n",
838 | "- Swap MNIST for CIFAR‑10 or your own dataset to test generalization.\n"
839 | ]
840 | }
841 | ],
842 | "metadata": {
843 | "accelerator": "GPU",
844 | "colab": {
845 | "gpuType": "T4",
846 | "provenance": [],
847 | "toc_visible": true
848 | },
849 | "kernelspec": {
850 | "display_name": "Python 3",
851 | "name": "python3"
852 | },
853 | "language_info": {
854 | "codemirror_mode": {
855 | "name": "ipython",
856 | "version": 2
857 | },
858 | "file_extension": ".py",
859 | "mimetype": "text/x-python",
860 | "name": "python",
861 | "nbconvert_exporter": "python",
862 | "pygments_lexer": "ipython2",
863 | "version": "2.7.6"
864 | }
865 | },
866 | "nbformat": 4,
867 | "nbformat_minor": 0
868 | }
--------------------------------------------------------------------------------