├── test ├── __init__.py ├── test_torch.py └── test_keras.py ├── masksembles ├── exceptions.py ├── __init__.py ├── keras.py ├── torch.py └── common.py ├── setup.cfg ├── MANIFEST.in ├── .gitattributes ├── images ├── mask_logo.png ├── transition.gif └── complex_sample_mnist.npy ├── pyproject.toml ├── LICENSE ├── setup.py ├── .gitignore ├── README.md └── notebooks ├── MNIST_Masksembles_tensoflow.ipynb └── MNIST_Masksembles.ipynb /test/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /masksembles/exceptions.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /masksembles/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "1.1.1" -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max_line_length = 120 3 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.md 2 | include LICENSE 3 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | *.py linguist-language=python 2 | *.ipynb linguist-documentation -------------------------------------------------------------------------------- /images/mask_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nikitadurasov/masksembles/HEAD/images/mask_logo.png -------------------------------------------------------------------------------- /images/transition.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nikitadurasov/masksembles/HEAD/images/transition.gif -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=65", "wheel"] 3 | build-backend = "setuptools.build_meta" -------------------------------------------------------------------------------- /images/complex_sample_mnist.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nikitadurasov/masksembles/HEAD/images/complex_sample_mnist.npy -------------------------------------------------------------------------------- /test/test_torch.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import torch 3 | 4 | import masksembles.torch 5 | 6 | class TestCreation(unittest.TestCase): 7 | 8 | def test_init_failed(self): 9 | pass 10 | 11 | def test_init_success(self): 12 | pass 13 | -------------------------------------------------------------------------------- /test/test_keras.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import tensorflow as tf 3 | 4 | import masksembles.keras 5 | 6 | 7 | class TestCreation(unittest.TestCase): 8 | 9 | def test_init_failed(self): 10 | layer = masksembles.keras.Masksembles2D(4, 11.) 
11 | self.assertRaises(ValueError, layer, tf.ones([4, 10, 4, 4])) 12 | 13 | def test_init_success(self): 14 | layer = masksembles.keras.Masksembles2D(4, 11.) 15 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Nikita Durasov 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | from setuptools import setup, find_packages 3 | 4 | HERE = pathlib.Path(__file__).parent 5 | readme = (HERE / "README.md").read_text(encoding="utf-8") 6 | 7 | setup( 8 | name="masksembles", # change if taken on PyPI 9 | version="1.1.1", # use semantic x.y.z (PyPI won’t accept re-uploads) 10 | author="Nikita Durasov", 11 | author_email="yassnda@gmail.com", 12 | description="Official implementation of Masksembles approach", 13 | long_description=readme, 14 | long_description_content_type="text/markdown", 15 | url="https://github.com/nikitadurasov/masksembles", 16 | project_urls={ 17 | "Issues": "https://github.com/nikitadurasov/masksembles/issues", 18 | "Documentation": "https://github.com/nikitadurasov/masksembles#readme", 19 | }, 20 | license="MIT", 21 | packages=find_packages(exclude=("tests", "test", "notebooks")), 22 | python_requires=">=3.8", 23 | install_requires=[ 24 | "numpy>=1.20", 25 | ], 26 | extras_require={ 27 | "torch": ["torch>=1.12"], 28 | "tensorflow": ["tensorflow>=2.9"], 29 | "dev": ["pytest", "black", "flake8"], 30 | }, 31 | include_package_data=True, 32 | classifiers=[ 33 | "Programming Language :: Python :: 3", 34 | "License :: OSI Approved :: MIT License", 35 | "Operating System :: OS Independent", 36 | ], 37 | keywords=["uncertainty", "ensembles", "deep-learning", "pytorch", "tensorflow"], 38 | ) 39 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | 
pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # Custom ignores 132 | .ipynb_checkpoints 133 | .idea 134 | dev.ipynb -------------------------------------------------------------------------------- /masksembles/keras.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from . import common 3 | 4 | 5 | class Masksembles2D(tf.keras.layers.Layer): 6 | """ 7 | :class:`Masksembles2D` is high-level class that implements Masksembles approach 8 | for 2-dimensional inputs (similar to :class:`tensorflow.keras.layers.SpatialDropout1D`). 9 | 10 | :param n: int, number of masks 11 | :param scale: float, scale parameter similar to *S* in [1]. Larger values decrease \ 12 | subnetworks correlations but at the same time decrease capacity of every individual model. 
13 | 14 | Shape: 15 | * Input: (N, H, W, C) 16 | * Output: (N, H, W, C) (same shape as input) 17 | 18 | Examples: 19 | 20 | >>> m = Masksembles2D(4, 2.0) 21 | >>> inputs = tf.ones([4, 28, 28, 16]) 22 | >>> output = m(inputs) 23 | 24 | References: 25 | 26 | [1] `Masksembles for Uncertainty Estimation`, 27 | Nikita Durasov, Timur Bagautdinov, Pierre Baque, Pascal Fua 28 | 29 | """ 30 | 31 | def __init__(self, n: int, scale: float): 32 | super(Masksembles2D, self).__init__() 33 | 34 | self.n = n 35 | self.scale = scale 36 | 37 | def build(self, input_shape): 38 | channels = input_shape[-1] 39 | masks = common.generation_wrapper(channels, self.n, self.scale) 40 | self.masks = self.add_weight("masks", 41 | shape=masks.shape, 42 | trainable=False, 43 | dtype="float32") 44 | self.masks.assign(masks) 45 | 46 | def call(self, inputs, training=False): 47 | # inputs : [N, H, W, C] 48 | # masks : [M, C] 49 | x = tf.stack(tf.split(inputs, self.n)) 50 | # x : [M, N // M, H, W, C] 51 | # masks : [M, 1, 1, 1, C] 52 | x = x * self.masks[:, tf.newaxis, tf.newaxis, tf.newaxis] 53 | x = tf.concat(tf.split(x, self.n), axis=1) 54 | return tf.squeeze(x, axis=0) 55 | 56 | 57 | class Masksembles1D(tf.keras.layers.Layer): 58 | """ 59 | :class:`Masksembles1D` is high-level class that implements Masksembles approach 60 | for 1-dimensional inputs (similar to :class:`tensorflow.keras.layers.Dropout`). 61 | 62 | :param n: int, number of masks 63 | :param scale: float, scale parameter similar to *S* in [1]. Larger values decrease \ 64 | subnetworks correlations but at the same time decrease capacity of every individual model. 65 | 66 | Shape: 67 | * Input: (N, C) 68 | * Output: (N, C) (same shape as input) 69 | 70 | Examples: 71 | 72 | >>> m = Masksembles1D(4, 2.0) 73 | >>> inputs = tf.ones([4, 16]) 74 | >>> output = m(inputs) 75 | 76 | 77 | References: 78 | 79 | [1] `Masksembles for Uncertainty Estimation`, 80 | Nikita Durasov, Timur Bagautdinov, Pierre Baque, Pascal Fua 81 | 82 | """ 83 | 84 | def __init__(self, n: int, scale: float): 85 | super(Masksembles1D, self).__init__() 86 | 87 | self.n = n 88 | self.scale = scale 89 | 90 | def build(self, input_shape): 91 | channels = input_shape[-1] 92 | masks = common.generation_wrapper(channels, self.n, self.scale) 93 | self.masks = self.add_weight("masks", 94 | shape=masks.shape, 95 | trainable=False, 96 | dtype="float32") 97 | self.masks.assign(masks) 98 | 99 | def call(self, inputs, training=False): 100 | x = tf.stack(tf.split(inputs, self.n)) 101 | x = x * self.masks[:, tf.newaxis] 102 | x = tf.concat(tf.split(x, self.n), axis=1) 103 | return tf.squeeze(x, axis=0) 104 | -------------------------------------------------------------------------------- /masksembles/torch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from . import common 5 | 6 | 7 | class Masksembles2D(nn.Module): 8 | """ 9 | :class:`Masksembles2D` is high-level class that implements Masksembles approach 10 | for 2-dimensional inputs (similar to :class:`torch.nn.Dropout2d`). 11 | 12 | :param channels: int, number of channels used in masks. 13 | :param n: int, number of masks 14 | :param scale: float, scale parameter similar to *S* in [1]. Larger values decrease \ 15 | subnetworks correlations but at the same time decrease capacity of every individual model. 
16 | 17 | Shape: 18 | * Input: (N, C, H, W) 19 | * Output: (N, C, H, W) (same shape as input) 20 | 21 | Examples: 22 | 23 | >>> m = Masksembles2D(16, 4, 2.0) 24 | >>> input = torch.ones([4, 16, 28, 28]) 25 | >>> output = m(input) 26 | 27 | References: 28 | 29 | [1] `Masksembles for Uncertainty Estimation`, 30 | Nikita Durasov, Timur Bagautdinov, Pierre Baque, Pascal Fua 31 | 32 | """ 33 | def __init__(self, channels: int, n: int, scale: float): 34 | super().__init__() 35 | self.channels, self.n, self.scale = channels, n, scale 36 | masks_np = common.generation_wrapper(channels, n, scale) # numpy float64 by default 37 | masks = torch.as_tensor(masks_np, dtype=torch.float32) # make float32 here 38 | self.register_buffer('masks', masks, persistent=False) # not trainable, moves with .to() 39 | 40 | def forward(self, inputs): 41 | # make sure masks match dtype/device (usually already true because of buffer, but safe) 42 | masks = self.masks.to(dtype=inputs.dtype, device=inputs.device) 43 | 44 | batch = inputs.shape[0] 45 | # safer split even if batch % n != 0 46 | chunks = torch.chunk(inputs.unsqueeze(1), self.n, dim=0) # returns nearly equal chunks 47 | x = torch.cat(chunks, dim=1).permute(1, 0, 2, 3, 4) # [n, ?, C, H, W] 48 | x = x * masks.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # broadcast masks 49 | x = torch.cat(torch.split(x, 1, dim=0), dim=1) 50 | return x.squeeze(0) 51 | 52 | 53 | 54 | class Masksembles1D(nn.Module): 55 | """ 56 | :class:`Masksembles1D` is high-level class that implements Masksembles approach 57 | for 1-dimensional inputs (similar to :class:`torch.nn.Dropout`). 58 | 59 | :param channels: int, number of channels used in masks. 60 | :param n: int, number of masks 61 | :param scale: float, scale parameter similar to *S* in [1]. Larger values decrease \ 62 | subnetworks correlations but at the same time decrease capacity of every individual model. 63 | 64 | Shape: 65 | * Input: (N, C) 66 | * Output: (N, C) (same shape as input) 67 | 68 | Examples: 69 | 70 | >>> m = Masksembles1D(16, 4, 2.0) 71 | >>> input = torch.ones([4, 16]) 72 | >>> output = m(input) 73 | 74 | 75 | References: 76 | 77 | [1] `Masksembles for Uncertainty Estimation`, 78 | Nikita Durasov, Timur Bagautdinov, Pierre Baque, Pascal Fua 79 | 80 | """ 81 | 82 | def __init__(self, channels: int, n: int, scale: float): 83 | super().__init__() 84 | self.channels, self.n, self.scale = channels, n, scale 85 | masks_np = common.generation_wrapper(channels, n, scale) 86 | masks = torch.as_tensor(masks_np, dtype=torch.float32) 87 | self.register_buffer('masks', masks, persistent=False) 88 | 89 | def forward(self, inputs): 90 | masks = self.masks.to(dtype=inputs.dtype, device=inputs.device) 91 | 92 | batch = inputs.shape[0] 93 | chunks = torch.chunk(inputs.unsqueeze(1), self.n, dim=0) 94 | x = torch.cat(chunks, dim=1).permute(1, 0, 2) # [n, ?, C] 95 | x = x * masks.unsqueeze(1) 96 | x = torch.cat(torch.split(x, 1, dim=0), dim=1) 97 | return x.squeeze(0) 98 | -------------------------------------------------------------------------------- /masksembles/common.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def generate_masks_(m: int, n: int, s: float) -> np.ndarray: 5 | """Generates set of binary masks with properties defined by n, m, s params. 6 | 7 | Results of this function are stochastic, that is, calls with the same sets 8 | of arguments might generate outputs of different shapes. 
Check generate_masks 9 | and generation_wrapper function for more deterministic behaviour. 10 | 11 | :param m: int, number of ones in each mask 12 | :param n: int, number of masks in the set 13 | :param s: float, scale param controls overlap of generated masks 14 | :return: np.ndarray, matrix of binary vectors 15 | """ 16 | 17 | total_positions = int(m * s) 18 | masks = [] 19 | 20 | for _ in range(n): 21 | new_vector = np.zeros([total_positions]) 22 | idx = np.random.choice(range(total_positions), m, replace=False) 23 | new_vector[idx] = 1 24 | masks.append(new_vector) 25 | 26 | masks = np.array(masks) 27 | # drop useless positions 28 | masks = masks[:, ~np.all(masks == 0, axis=0)] 29 | return masks 30 | 31 | 32 | def generate_masks(m: int, n: int, s: float) -> np.ndarray: 33 | """Generates set of binary masks with properties defined by n, m, s params. 34 | 35 | Resulting masks are required to have fixed features size as it's described in [1]. 36 | Since process of masks generation is stochastic therefore function evaluates 37 | generate_masks_ multiple times till expected size is acquired. 38 | 39 | :param m: int, number of ones in each mask 40 | :param n: int, number of masks in the set 41 | :param s: float, scale param controls overlap of generated masks 42 | :return: np.ndarray, matrix of binary vectors 43 | 44 | References 45 | 46 | [1] `Masksembles for Uncertainty Estimation: Supplementary Material`, 47 | Nikita Durasov, Timur Bagautdinov, Pierre Baque, Pascal Fua 48 | """ 49 | 50 | masks = generate_masks_(m, n, s) 51 | # hardcoded formula for expected size, check reference 52 | expected_size = int(m * s * (1 - (1 - 1 / s) ** n)) 53 | while masks.shape[1] != expected_size: 54 | masks = generate_masks_(m, n, s) 55 | return masks 56 | 57 | 58 | def generation_wrapper(c: int, n: int, scale: float) -> np.ndarray: 59 | """Generates set of binary masks with properties defined by c, n, scale params. 60 | 61 | Allows to generate masks sets with predefined features number c. Particularly 62 | convenient to use in torch-like layers where one need to define shapes inputs 63 | tensors beforehand. 64 | 65 | :param c: int, number of channels in generated masks 66 | :param n: int, number of masks in the set 67 | :param scale: float, scale param controls overlap of generated masks 68 | :return: np.ndarray, matrix of binary vectors 69 | """ 70 | 71 | if c < 10: 72 | raise ValueError("Masksembles approach couldn't be used in such setups where " 73 | f"number of channels is less then 10. Current value is (channels={c}). " 74 | "Please increase number of features in your layer or remove this " 75 | "particular instance of Masksembles from your architecture.") 76 | 77 | if scale > 6.: 78 | raise ValueError("Masksembles approach couldn't be used in such setups where " 79 | f"scale parameter is larger then 6. Current value is (scale={scale}).") 80 | 81 | # inverse formula for number of active features in masks 82 | active_features = int(int(c) / (scale * (1 - (1 - 1 / scale) ** n))) 83 | 84 | # Fix the last part by using binary search 85 | max_iter = 1000 86 | 87 | min = np.max([scale * 0.8, 1.0]) 88 | max = scale * 1.2 89 | 90 | for _ in range(max_iter): 91 | mid = (min + max) / 2 92 | masks = generate_masks(active_features, n, mid) 93 | if masks.shape[-1] == c: 94 | break 95 | elif masks.shape[-1] > c: 96 | max = mid 97 | else: 98 | min = mid 99 | 100 | if masks.shape[-1] != c: 101 | raise ValueError("generation_wrapper function failed to generate masks with " 102 | "requested number of features. 
Please try to change scale parameter") 103 | 104 | return masks 105 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Masksembles for Uncertainty Estimation 2 | 3 | 8 | 9 | ![Project Page](./images/mask_logo.png) 10 | 11 |

12 | 13 | PyPI 14 | 15 | 16 | PyPI Downloads 17 | 18 | 19 | Monthly Downloads 20 | 21 | 22 | Python Versions 23 | 24 | 25 | GitHub stars 26 | 27 | 28 | GitHub forks 29 | 30 | 31 | Issues 32 | 33 | 34 | License 35 | 36 |

37 | 38 | 39 | ### [Project Page](https://nikitadurasov.github.io/projects/masksembles/) | [Paper](https://arxiv.org/abs/2012.08334) | [Video Explanation](https://www.youtube.com/watch?v=YWKVdn3kLp0) 40 | 41 | [![Open Masksembles in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/nikitadurasov/masksembles/blob/main/notebooks/MNIST_Masksembles.ipynb) 42 | 43 | 46 | 47 | 48 | ## Why Masksembles? 49 | 50 | **Uncertainty Estimation** is one of the most important and critical tasks in the area of modern neural networks and deep learning. 51 | There is a long list of potential applications of uncertainty: safety-critical applications, active learning, domain adaptation, 52 | reinforcement learning and etc. 53 | 54 | **Masksembles** is a **simple** and **easy-to-use** drop-in method with performance on par with Deep Ensembles at a fraction of the cost. 55 | It makes *almost* no changes in your original model and requires only to add special intermediate layers. 56 | 57 | [![Watch the video](https://img.youtube.com/vi/YWKVdn3kLp0/maxresdefault.jpg)](https://youtu.be/YWKVdn3kLp0) 58 | 59 | ### [Watch this video on YouTube](https://youtu.be/YWKVdn3kLp0) 60 | 61 | ## Installation 62 | 63 | To install this package, use: 64 | 65 | ```bash 66 | pip install masksembles 67 | ``` 68 | or 69 | ```bash 70 | pip install git+http://github.com/nikitadurasov/masksembles 71 | ``` 72 | 73 | In addition, Masksembles requires installing at least one of the backends: torch or tensorflow2 / keras. 74 | Please follow official installation instructions for [torch](https://pytorch.org/) or [tensorflow](https://www.tensorflow.org/install) 75 | accordingly. 76 | 77 | 78 | ## Usage 79 | 80 | [comment]: <> (In masksembles module you could find implementations of "Masksembles{1|2|3}D" that) 81 | 82 | [comment]: <> (support different shapes of input vectors (1, 2 and 3-dimentional accordingly)) 83 | 84 | This package provides implementations for `Masksembles{1|2|3}D` layers in `masksembles.{torch|keras}` 85 | where `{1|2|3}` refers to dimensionality of input tensors (1-, 2- and 3-dimensional 86 | accordingly). 87 | 88 | * `Masksembles1D`: works with 1-dim inputs,`[B, C]` shaped tensors 89 | * `Masksembles2D`: works with 2-dim inputs,`[B, H, W, C]` (keras) or `[B, C, H, W]` (torch) shaped tensors 90 | * `Masksembles3D` : TBD 91 | 92 | In a Nutshell, Masksembles applies binary masks to inputs via multiplying them both channel-wise. For more efficient 93 | implementation we've followed approach similar to [this](https://arxiv.org/abs/2002.06715) one. Therefore, after inference 94 | `outputs[:B // N]` - stores results for the first submodel, `outputs[B // N : 2 * B // N]` - for the second and etc. 95 | ### Torch 96 | 97 | ```python 98 | import torch 99 | from masksembles.torch import Masksembles1D 100 | 101 | layer = Masksembles1D(10, 4, 2.) 102 | layer(torch.ones([4, 10])) 103 | ``` 104 | ```bash 105 | tensor([[0., 1., 0., 0., 1., 0., 1., 1., 1., 1.], 106 | [0., 0., 1., 1., 1., 1., 0., 0., 1., 1.], 107 | [1., 0., 1., 1., 0., 0., 1., 0., 1., 1.], 108 | [1., 0., 0., 1., 1., 1., 0., 1., 1., 0.]], dtype=torch.float64) 109 | 110 | ``` 111 | 112 | ### Tensorflow / Keras 113 | 114 | ```python 115 | import tensorflow as tf 116 | from masksembles.keras import Masksembles1D 117 | 118 | layer = Masksembles1D(4, 2.) 
119 | layer(tf.ones([4, 10])) 120 | ``` 121 | ```bash 122 | 127 | ``` 128 | 129 | ### Model example 130 | ```python 131 | import tensorflow as tf 132 | from masksembles.keras import Masksembles1D, Masksembles2D 133 | 134 | model = keras.Sequential( 135 | [ 136 | keras.Input(shape=input_shape), 137 | layers.Conv2D(32, kernel_size=(3, 3), activation="elu"), 138 | Masksembles2D(4, 2.0), 139 | layers.MaxPooling2D(pool_size=(2, 2)), 140 | 141 | layers.Conv2D(64, kernel_size=(3, 3), activation="elu"), 142 | Masksembles2D(4, 2.0), 143 | layers.MaxPooling2D(pool_size=(2, 2)), 144 | 145 | layers.Flatten(), 146 | Masksembles1D(4, 2.), 147 | layers.Dense(num_classes, activation="softmax"), 148 | ] 149 | ) 150 | ``` 151 | 152 | ## Citation 153 | If you found this work useful for your projects, please don't forget to cite it. 154 | ``` 155 | @inproceedings{Durasov21, 156 | author = {N. Durasov and T. Bagautdinov and P. Baque and P. Fua}, 157 | title = {{Masksembles for Uncertainty Estimation}}, 158 | booktitle = CVPR, 159 | year = 2021 160 | } 161 | ``` 162 | -------------------------------------------------------------------------------- /notebooks/MNIST_Masksembles_tensoflow.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "kernelspec": { 6 | "display_name": "Python 3", 7 | "language": "python", 8 | "name": "python3" 9 | }, 10 | "language_info": { 11 | "codemirror_mode": { 12 | "name": "ipython", 13 | "version": 2 14 | }, 15 | "file_extension": ".py", 16 | "mimetype": "text/x-python", 17 | "name": "python", 18 | "nbconvert_exporter": "python", 19 | "pygments_lexer": "ipython2", 20 | "version": "2.7.6" 21 | }, 22 | "colab": { 23 | "name": "MNIST_Masksembles.ipynb", 24 | "provenance": [], 25 | "toc_visible": true 26 | } 27 | }, 28 | "cells": [ 29 | { 30 | "cell_type": "code", 31 | "metadata": { 32 | "collapsed": true, 33 | "id": "qrpn5UV2BYLF" 34 | }, 35 | "source": [ 36 | "!pip install --upgrade git+http://github.com/nikitadurasov/masksembles\n", 37 | "!wget https://github.com/nikitadurasov/masksembles/raw/main/images/complex_sample_mnist.npy" 38 | ], 39 | "execution_count": 124, 40 | "outputs": [] 41 | }, 42 | { 43 | "cell_type": "markdown", 44 | "metadata": { 45 | "id": "bHtAoRyhBiKr" 46 | }, 47 | "source": [ 48 | "# MNIST " 49 | ] 50 | }, 51 | { 52 | "cell_type": "markdown", 53 | "metadata": { 54 | "id": "eFvywyCZBlsh" 55 | }, 56 | "source": [ 57 | "## Keras" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "metadata": { 63 | "id": "bnCBbU95BuQ0" 64 | }, 65 | "source": [ 66 | "import numpy as np\n", 67 | "from tensorflow import keras\n", 68 | "from tensorflow.keras import layers\n", 69 | "import tensorflow as tf\n", 70 | "\n", 71 | "import matplotlib.pyplot as plt" 72 | ], 73 | "execution_count": 113, 74 | "outputs": [] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "metadata": { 79 | "id": "ZZNyHlDLCOt5" 80 | }, 81 | "source": [ 82 | "from masksembles.keras import Masksembles2D, Masksembles1D" 83 | ], 84 | "execution_count": 3, 85 | "outputs": [] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "metadata": { 90 | "colab": { 91 | "base_uri": "https://localhost:8080/" 92 | }, 93 | "id": "Gavnqj28CYJ7", 94 | "outputId": "c89047ee-ee6c-4ed5-cc2c-8178d8d14863" 95 | }, 96 | "source": [ 97 | "# Model / data parameters\n", 98 | "num_classes = 10\n", 99 | "input_shape = (28, 28, 1)\n", 100 | "\n", 101 | "# the data, split between train and test sets\n", 102 | "(x_train, y_train), (x_test, y_test) 
= keras.datasets.mnist.load_data()\n", 103 | "\n", 104 | "# Scale images to the [0, 1] range\n", 105 | "x_train = x_train.astype(\"float32\") / 255\n", 106 | "x_test = x_test.astype(\"float32\") / 255\n", 107 | "# Make sure images have shape (28, 28, 1)\n", 108 | "x_train = np.expand_dims(x_train, -1)\n", 109 | "x_test = np.expand_dims(x_test, -1)\n", 110 | "print(\"x_train shape:\", x_train.shape)\n", 111 | "print(x_train.shape[0], \"train samples\")\n", 112 | "print(x_test.shape[0], \"test samples\")\n", 113 | "\n", 114 | "\n", 115 | "# convert class vectors to binary class matrices\n", 116 | "y_train = keras.utils.to_categorical(y_train, num_classes)\n", 117 | "y_test = keras.utils.to_categorical(y_test, num_classes)" 118 | ], 119 | "execution_count": 6, 120 | "outputs": [ 121 | { 122 | "output_type": "stream", 123 | "text": [ 124 | "x_train shape: (60000, 28, 28, 1)\n", 125 | "60000 train samples\n", 126 | "10000 test samples\n" 127 | ], 128 | "name": "stdout" 129 | } 130 | ] 131 | }, 132 | { 133 | "cell_type": "markdown", 134 | "metadata": { 135 | "id": "Hr5xjOSffo7x" 136 | }, 137 | "source": [ 138 | "In order to transform regular model into Masksembles model one should add Masksembles2D or Masksembles1D layers in it. General recommendation is to insert these layers right before or after convolutional layers. \n", 139 | "\n", 140 | "In example below we'll use both Masksembles2D and Masksembles1D layers applied after convolutions. " 141 | ] 142 | }, 143 | { 144 | "cell_type": "code", 145 | "metadata": { 146 | "colab": { 147 | "base_uri": "https://localhost:8080/" 148 | }, 149 | "id": "7jw12ljXCaMg", 150 | "outputId": "f324695b-c4f7-4624-be34-3bb22cd687a0" 151 | }, 152 | "source": [ 153 | "model = keras.Sequential(\n", 154 | " [\n", 155 | " keras.Input(shape=input_shape),\n", 156 | " layers.Conv2D(32, kernel_size=(3, 3), activation=\"elu\"),\n", 157 | " Masksembles2D(4, 2.0), # adding Masksembles2D\n", 158 | " layers.MaxPooling2D(pool_size=(2, 2)),\n", 159 | " \n", 160 | " layers.Conv2D(64, kernel_size=(3, 3), activation=\"elu\"),\n", 161 | " Masksembles2D(4, 2.0), # adding Masksembles2D\n", 162 | " layers.MaxPooling2D(pool_size=(2, 2)),\n", 163 | " \n", 164 | " layers.Flatten(),\n", 165 | " Masksembles1D(4, 2.), # adding Masksembles1D\n", 166 | " layers.Dense(num_classes, activation=\"softmax\"),\n", 167 | " ]\n", 168 | ")\n", 169 | "\n", 170 | "model.summary()" 171 | ], 172 | "execution_count": 29, 173 | "outputs": [ 174 | { 175 | "output_type": "stream", 176 | "text": [ 177 | "Model: \"sequential_3\"\n", 178 | "_________________________________________________________________\n", 179 | "Layer (type) Output Shape Param # \n", 180 | "=================================================================\n", 181 | "conv2d_6 (Conv2D) (None, 26, 26, 32) 320 \n", 182 | "_________________________________________________________________\n", 183 | "masksembles2d_6 (Masksembles (None, 26, 26, 32) 128 \n", 184 | "_________________________________________________________________\n", 185 | "max_pooling2d_6 (MaxPooling2 (None, 13, 13, 32) 0 \n", 186 | "_________________________________________________________________\n", 187 | "conv2d_7 (Conv2D) (None, 11, 11, 64) 18496 \n", 188 | "_________________________________________________________________\n", 189 | "masksembles2d_7 (Masksembles (None, 11, 11, 64) 256 \n", 190 | "_________________________________________________________________\n", 191 | "max_pooling2d_7 (MaxPooling2 (None, 5, 5, 64) 0 \n", 192 | 
"_________________________________________________________________\n", 193 | "flatten_3 (Flatten) (None, 1600) 0 \n", 194 | "_________________________________________________________________\n", 195 | "masksembles1d_2 (Masksembles (None, 1600) 6400 \n", 196 | "_________________________________________________________________\n", 197 | "dense_1 (Dense) (None, 10) 16010 \n", 198 | "=================================================================\n", 199 | "Total params: 41,610\n", 200 | "Trainable params: 34,826\n", 201 | "Non-trainable params: 6,784\n", 202 | "_________________________________________________________________\n" 203 | ], 204 | "name": "stdout" 205 | } 206 | ] 207 | }, 208 | { 209 | "cell_type": "markdown", 210 | "metadata": { 211 | "id": "DP9Km-bGigRC" 212 | }, 213 | "source": [ 214 | "Training of Masksembles is not different from training of regular model. So we just use standard fit Keras API." 215 | ] 216 | }, 217 | { 218 | "cell_type": "code", 219 | "metadata": { 220 | "colab": { 221 | "base_uri": "https://localhost:8080/" 222 | }, 223 | "id": "44akdkDYDDQl", 224 | "outputId": "1e31da86-663a-4fdc-a344-857dedef6e4f" 225 | }, 226 | "source": [ 227 | "batch_size = 128\n", 228 | "epochs = 5\n", 229 | "\n", 230 | "model.compile(loss=\"categorical_crossentropy\", optimizer=\"adam\", metrics=[\"accuracy\"])\n", 231 | "model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, validation_split=0.1)" 232 | ], 233 | "execution_count": 30, 234 | "outputs": [ 235 | { 236 | "output_type": "stream", 237 | "text": [ 238 | "Epoch 1/5\n", 239 | "422/422 [==============================] - 54s 126ms/step - loss: 1.0014 - accuracy: 0.7048 - val_loss: 0.1698 - val_accuracy: 0.9520\n", 240 | "Epoch 2/5\n", 241 | "422/422 [==============================] - 53s 125ms/step - loss: 0.1825 - accuracy: 0.9455 - val_loss: 0.1100 - val_accuracy: 0.9693\n", 242 | "Epoch 3/5\n", 243 | "422/422 [==============================] - 53s 125ms/step - loss: 0.1235 - accuracy: 0.9628 - val_loss: 0.0824 - val_accuracy: 0.9773\n", 244 | "Epoch 4/5\n", 245 | "422/422 [==============================] - 53s 126ms/step - loss: 0.0961 - accuracy: 0.9706 - val_loss: 0.0791 - val_accuracy: 0.9767\n", 246 | "Epoch 5/5\n", 247 | "422/422 [==============================] - 53s 125ms/step - loss: 0.0850 - accuracy: 0.9742 - val_loss: 0.0674 - val_accuracy: 0.9827\n" 248 | ], 249 | "name": "stdout" 250 | }, 251 | { 252 | "output_type": "execute_result", 253 | "data": { 254 | "text/plain": [ 255 | "" 256 | ] 257 | }, 258 | "metadata": { 259 | "tags": [] 260 | }, 261 | "execution_count": 30 262 | } 263 | ] 264 | }, 265 | { 266 | "cell_type": "markdown", 267 | "metadata": { 268 | "id": "Wh3pGxWujZ0Z" 269 | }, 270 | "source": [ 271 | "After training we could check that all of Masksembles' submodels would predict similar predictions for training samples." 
272 | ] 273 | }, 274 | { 275 | "cell_type": "code", 276 | "metadata": { 277 | "id": "T3GNT8Z5kItR" 278 | }, 279 | "source": [ 280 | "img = x_train[0] # just random image from training set" 281 | ], 282 | "execution_count": 118, 283 | "outputs": [] 284 | }, 285 | { 286 | "cell_type": "code", 287 | "metadata": { 288 | "id": "o0K51rvqkMIs", 289 | "outputId": "2baad7d7-c8b3-41c2-c8d5-857e7754fb07", 290 | "colab": { 291 | "base_uri": "https://localhost:8080/", 292 | "height": 265 293 | } 294 | }, 295 | "source": [ 296 | "plt.imshow(img[..., 0])\n", 297 | "plt.show()" 298 | ], 299 | "execution_count": 119, 300 | "outputs": [ 301 | { 302 | "output_type": "display_data", 303 | "data": { 304 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAOZ0lEQVR4nO3dbYxc5XnG8euKbezamMQbB9chLjjgFAg0Jl0ZEBZQobgOqgSoCsSKIkJpnSY4Ca0rQWlV3IpWbpUQUUqRTHExFS+BBIQ/0CTUQpCowWWhBgwEDMY0NmaNWYENIX5Z3/2w42iBnWeXmTMv3vv/k1Yzc+45c24NXD5nznNmHkeEAIx/H+p0AwDag7ADSRB2IAnCDiRB2IEkJrZzY4d5ckzRtHZuEkjlV3pbe2OPR6o1FXbbiyVdJ2mCpH+LiJWl50/RNJ3qc5rZJICC9bGubq3hw3jbEyTdIOnzkk6UtMT2iY2+HoDWauYz+wJJL0TE5ojYK+lOSedV0xaAqjUT9qMk/WLY4621Ze9ie6ntPtt9+7Snic0BaEbLz8ZHxKqI6I2I3kma3OrNAaijmbBvkzRn2ONP1JYB6ELNhP1RSfNsz7V9mKQvSlpbTVsAqtbw0FtE7Le9TNKPNDT0tjoinq6sMwCVamqcPSLul3R/Rb0AaCEulwWSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiCJpmZxRffzxPJ/4gkfm9nS7T/3F8fUrQ1OPVBc9+hjdxTrU7/uYv3Vaw+rW3u893vFdXcOvl2sn3r38mL9uD9/pFjvhKbCbnuLpN2SBiXtj4jeKpoCUL0q9uy/FxE7K3gdAC3EZ3YgiWbDHpJ+bPsx20tHeoLtpbb7bPft054mNwegUc0exi+MiG22j5T0gO2fR8TDw58QEaskrZKkI9wTTW4PQIOa2rNHxLba7Q5J90paUEVTAKrXcNhtT7M9/eB9SYskbayqMQDVauYwfpake20ffJ3bI+KHlXQ1zkw4YV6xHpMnFeuvnPWRYv2d0+qPCfd8uDxe/JPPlMebO+k/fzm9WP/Hf1lcrK8/+fa6tZf2vVNcd2X/54r1j//k0PtE2nDYI2KzpM9U2AuAFmLoDUiCsANJEHYgCcIOJEHYgST4imsFBs/+bLF+7S03FOufmlT/q5jj2b4YLNb/5vqvFOsT3y4Pf51+97K6tenb9hfXnbyzPDQ3tW99sd6N2LMDSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKMs1dg8nOvFOuP/WpOsf6pSf1VtlOp5dtPK9Y3v1X+Kepbjv1+3dqbB8rj5LP++b+L9VY69L7AOjr27EAShB1IgrADSRB2IAnCDiRB2IEkCDuQhCPaN6J4hHviVJ/Ttu11i4FLTi/Wdy0u/9zzhCcPL9af+Pr1H7ing67Z+TvF+qNnlcfRB994s1iP0+v/APGWbxZX1dwlT5SfgPdZH+u0KwZGnMuaPTuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJME4exeYMPOjxfrg6wPF+ku31x8rf/rM1cV1F/zDN4r1I2/o3HfK8cE1Nc5ue7XtHbY3DlvWY/sB25tqtzOqbBhA9cZyGH+LpPfOen+lpHURMU/SutpjAF1s1LBHxMOS3nsceZ6kNbX7aySdX3FfACrW6G/QzYqI7bX7r0qaVe+JtpdKWipJUzS1wc0BaFbTZ+Nj6Axf3bN8EbEqInojoneSJje7OQANajTs/bZnS1Ltdkd1LQFohUbDvlbSxbX7F0u6r5p2ALTKqJ/Zbd8h6WxJM21vlXS1pJWS7rJ9qaSXJV3YyibHu8Gdrze1/r5djc/v/ukvPVOsv3bjhPILHCjPsY7uMWrYI2JJnRJXxwCHEC6XBZIg7EAShB1IgrADSRB2IAmmbB4HTrji+bq1S04uD5r8+9HrivWzvnBZsT79e48U6+ge7NmBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnG2ceB0rTJr3/thOK6/7f2nWL9ymtuLdb/8sILivX43w/Xrc35+58V11Ubf+Y8A/bsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEUzYnN/BHpxfrt1397WJ97sQpDW/707cuK9bn3bS9WN+/eUvD2x6vmpqyGcD4QNiBJAg7kARhB5Ig7EAShB1IgrADSTDOjqI4Y36xfsTKrcX6HZ/8UcPbPv7BPy7Wf/tv63+PX5IGN21ueNuHqqbG2W2vtr3D9sZhy1bY3mZ7Q+3v3CobBlC9sRzG3yJp8QjLvxsR82t/91fbFoCqjRr2iHhY0kAbegHQQs2coFtm+8naYf6Mek+yvdR2n+2+fdrTxOYANKPRsN8o6VhJ8yVtl/Sdek+MiFUR0RsRvZM0ucHNAWhWQ2GPiP6IGIyIA5JukrSg2rYAVK2hsNuePezhBZI21nsugO4w6ji77TsknS1ppqR+SVfXHs+XFJK2SPpqRJS/fCzG2cejCbOOLNZfuei4urX1V1xXXPdDo+yLvvTSomL9zYWvF+vjUWmcfdRJIiJiyQiLb266KwBtxeWyQBKEHUiCsANJEHYgCcIOJMFXXNExd20tT9k81YcV67+MvcX6H3zj8vqvfe/64rqHKn5KGgBhB7Ig7EAShB1IgrADSRB2IAnCDiQx6rfekNuBheWfkn7xC+Upm0+av6VubbRx9NFcP3BKsT71vr6mXn+8Yc8OJEHYgSQIO5AEYQeSIOxAEoQdSIK
wA0kwzj7OufekYv35b5bHum86Y02xfuaU8nfKm7En9hXrjwzMLb/AgVF/3TwV9uxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kATj7IeAiXOPLtZfvOTjdWsrLrqzuO4fHr6zoZ6qcFV/b7H+0HWnFesz1pR/dx7vNuqe3fYc2w/afsb207a/VVveY/sB25tqtzNa3y6ARo3lMH6/pOURcaKk0yRdZvtESVdKWhcR8yStqz0G0KVGDXtEbI+Ix2v3d0t6VtJRks6TdPBayjWSzm9VkwCa94E+s9s+RtIpktZLmhURBy8+flXSrDrrLJW0VJKmaGqjfQJo0pjPxts+XNIPJF0eEbuG12JodsgRZ4iMiFUR0RsRvZM0ualmATRuTGG3PUlDQb8tIu6pLe63PbtWny1pR2taBFCFUQ/jbVvSzZKejYhrh5XWSrpY0sra7X0t6XAcmHjMbxXrb/7u7GL9or/7YbH+px+5p1hvpeXby8NjP/vX+sNrPbf8T3HdGQcYWqvSWD6znyHpy5Kesr2htuwqDYX8LtuXSnpZ0oWtaRFAFUYNe0T8VNKIk7tLOqfadgC0CpfLAkkQdiAJwg4kQdiBJAg7kARfcR2jibN/s25tYPW04rpfm/tQsb5ken9DPVVh2baFxfrjN5anbJ75/Y3Fes9uxsq7BXt2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUgizTj73t8v/2zx3j8bKNavOu7+urVFv/F2Qz1VpX/wnbq1M9cuL657/F//vFjveaM8Tn6gWEU3Yc8OJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0mkGWffcn7537XnT767Zdu+4Y1ji/XrHlpUrHuw3o/7Djn+mpfq1ub1ry+uO1isYjxhzw4kQdiBJAg7kARhB5Ig7EAShB1IgrADSTgiyk+w50i6VdIsSSFpVURcZ3uFpD+R9FrtqVdFRP0vfUs6wj1xqpn4FWiV9bFOu2JgxAszxnJRzX5JyyPicdvTJT1m+4Fa7bsR8e2qGgXQOmOZn327pO21+7ttPyvpqFY3BqBaH+gzu+1jJJ0i6eA1mMtsP2l7te0ZddZZarvPdt8+7WmqWQCNG3PYbR8u6QeSLo+IXZJulHSspPka2vN/Z6T1ImJVRPRGRO8kTa6gZQCNGFPYbU/SUNBvi4h7JCki+iNiMCIOSLpJ0oLWtQmgWaOG3bYl3Szp2Yi4dtjy2cOedoGk8nSeADpqLGfjz5D0ZUlP2d5QW3aVpCW252toOG6LpK+2pEMAlRjL2fifShpp3K44pg6gu3AFHZAEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IIlRf0q60o3Zr0l6ediimZJ2tq2BD6Zbe+vWviR6a1SVvR0dER8bqdDWsL9v43ZfRPR2rIGCbu2tW/uS6K1R7eqNw3ggCcIOJNHpsK/q8PZLurW3bu1LordGtaW3jn5mB9A+nd6zA2gTwg4k0ZGw215s+znbL9i+shM91GN7i+2nbG+w3dfhXlbb3mF747BlPbYfsL2pdjviHHsd6m2F7W21926D7XM71Nsc2w/afsb207a/VVve0feu0Fdb3re2f2a3PUHS85I+J2mrpEclLYmIZ9raSB22t0jqjYiOX4Bh+0xJb0m6NSJOqi37J0kDEbGy9g/ljIi4okt6WyHprU5P412brWj28GnGJZ0v6Svq4HtX6OtCteF968SefYGkFyJic0TslXSnpPM60EfXi4iHJQ28Z/F5ktbU7q/R0P8sbVent64QEdsj4vHa/d2SDk4z3tH3rtBXW3Qi7EdJ+sWwx1vVXfO9h6Qf237M9tJONzOCWRGxvXb/VUmzOtnMCEadxrud3jPNeNe8d41Mf94sTtC938KI+Kykz0u6rHa42pVi6DNYN42djmka73YZYZrxX+vke9fo9OfN6kTYt0maM+zxJ2rLukJEbKvd7pB0r7pvKur+gzPo1m53dLifX+umabxHmmZcXfDedXL6806E/VFJ82zPtX2YpC9KWtuBPt7H9rTaiRPZniZpkbpvKuq1ki6u3b9Y0n0d7OVdumUa73rTjKvD713Hpz+PiLb/STpXQ2fkX5T0V53ooU5fn5T0RO3v6U73JukODR3W7dPQuY1LJX1U0jpJmyT9l6SeLurtPyQ9JelJDQVrdod6W6ihQ/QnJW2o/Z3b6feu0Fdb3jculwWS4AQdkARhB5Ig7EAShB1IgrADSRB2IAnCDiTx/65XcTNOWsh5AAAAAElFTkSuQmCC\n", 305 | "text/plain": [ 306 | "
" 307 | ] 308 | }, 309 | "metadata": { 310 | "tags": [], 311 | "needs_background": "light" 312 | } 313 | } 314 | ] 315 | }, 316 | { 317 | "cell_type": "markdown", 318 | "metadata": { 319 | "id": "yfIVxwixkQdY" 320 | }, 321 | "source": [ 322 | "To acquire predictions from different submodels one should transform input (with shape [1, H, W, C]) into batch (with shape [M, H, W, C]) that consists of M copies of original input (H - height of image, W - width of image, C - number of channels).\n", 323 | "\n", 324 | "As we can see Masksembles submodels produce similar predictions for training set samples." 325 | ] 326 | }, 327 | { 328 | "cell_type": "code", 329 | "metadata": { 330 | "id": "thkg1KJJjY85", 331 | "outputId": "010a3f3c-f20c-427a-87c9-f08966707e0e", 332 | "colab": { 333 | "base_uri": "https://localhost:8080/" 334 | } 335 | }, 336 | "source": [ 337 | "inputs = np.tile(img[None], [4, 1, 1, 1])\n", 338 | "predictions = model(inputs)\n", 339 | "for i, cls in enumerate(tf.argmax(predictions, axis=1)):\n", 340 | " print(f\"PREDICTION OF {i+1} MODEL: {cls} CLASS\")" 341 | ], 342 | "execution_count": 120, 343 | "outputs": [ 344 | { 345 | "output_type": "stream", 346 | "text": [ 347 | "PREDICTION OF 1 MODEL: 5 CLASS\n", 348 | "PREDICTION OF 2 MODEL: 5 CLASS\n", 349 | "PREDICTION OF 3 MODEL: 5 CLASS\n", 350 | "PREDICTION OF 4 MODEL: 5 CLASS\n" 351 | ], 352 | "name": "stdout" 353 | } 354 | ] 355 | }, 356 | { 357 | "cell_type": "markdown", 358 | "metadata": { 359 | "id": "KH81gNyGlcnh" 360 | }, 361 | "source": [ 362 | "On out-of-distribution samples Masksembles should produce predictions with high variance, let's check it on complex samples from MNIST." 363 | ] 364 | }, 365 | { 366 | "cell_type": "code", 367 | "metadata": { 368 | "id": "l9rbe1s9fBDf" 369 | }, 370 | "source": [ 371 | "img = np.load(\"./complex_sample_mnist.npy\")" 372 | ], 373 | "execution_count": 121, 374 | "outputs": [] 375 | }, 376 | { 377 | "cell_type": "code", 378 | "metadata": { 379 | "id": "n6YghcdrfdkZ", 380 | "outputId": "a0e0ff76-3ad3-416b-d0ed-642debf7d491", 381 | "colab": { 382 | "base_uri": "https://localhost:8080/", 383 | "height": 283 384 | } 385 | }, 386 | "source": [ 387 | "plt.imshow(img[..., 0])" 388 | ], 389 | "execution_count": 122, 390 | "outputs": [ 391 | { 392 | "output_type": "execute_result", 393 | "data": { 394 | "text/plain": [ 395 | "" 396 | ] 397 | }, 398 | "metadata": { 399 | "tags": [] 400 | }, 401 | "execution_count": 122 402 | }, 403 | { 404 | "output_type": "display_data", 405 | "data": { 406 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAPAklEQVR4nO3dfWxd9X3H8c/XjmOTB2hCmmAlGQks3ZbS1tnctAUKdBRI4Y/QSqNEapcxWncSSO3WTWXsD/ijf2SoLaq0DclA2rRqYVQFEalRIYuoWNcNxYQ0DwQID4mI68RJvCYG8uCH7/7wCTPg87vmPp0bf98vybrX53uPzzc3/vjce3/nnJ+5uwBMfU1FNwCgPgg7EARhB4Ig7EAQhB0IYlo9NzbdWr1NM+u5SSCUk3pTp/2UTVSrKOxmtkrS9yU1S3rA3delHt+mmfqEXV3JJgEkPONbcmtlv4w3s2ZJ/yrpc5KWS1pjZsvL/XkAaquS9+wrJb3s7q+6+2lJD0taXZ22AFRbJWFfKOn1cd8fyJa9g5l1mVmPmfUM6VQFmwNQiZp/Gu/u3e7e6e6dLWqt9eYA5Kgk7L2SFo/7flG2DEADqiTsWyUtM7OlZjZd0s2SNlanLQDVVvbQm7sPm9ntkp7Q2NDbenffXbXOAFRVRePs7r5J0qYq9QKghjhcFgiCsANBEHYgCMIOBEHYgSAIOxAEYQeCIOxAEIQdCIKwA0EQdiAIwg4EQdiBIAg7EARhB4Ig7EAQhB0IgrADQRB2IAjCDgRB2IEg6jplc1hNzeny9Jb0+i0l6kNDuaXRUyWm3HJP1zFlsGcHgiDsQBCEHQiCsANBEHYgCMIOBEHYgSAYZ6+Doc+uSNb3/+VIsn7LR/87WV//3KW5tT+582By3eHf9SXrjMNPHRWF3cz2SRqUNCJp2N07q9EUgOqrxp79M+5+pAo/B0AN8Z4dCKLSsLukJ83sWTPrmugBZtZlZj1m1jOkEsdpA6iZSl/GX+7uvWY2X9JmM3vB3Z8e/wB375bULUnn2lw+7QEKUtGe3d17s9t+SY9JWlmNpgBUX9lhN7OZZjb7zH1J10raVa3GAFRXJS/jF0h6zMzO/Jyfuvsvq9LVWcYv60jWX/tiev1/v6w7WT/p6fPZf9D0qfzi6Gh64wij7LC7+6uSPlbFXgDUEENvQBCEHQiCsANBEHYgCMIOBMEprlVw9JJzkvXVHVuT9ZWt6aG1pY9PeCTy25avyz9NdeRwiXOUOIU1DPbsQBCEHQiCsANBEHYgCMIOBEHYgSAIOxAE4+zVUGKo+vRohU9zS/o0VT92PL82PJxct2nGjGT95BUfTtYHF6b/bTMP5V8me9pb6Utot/bsTdZHjuf/u/Fe7NmBIAg7EARhB4Ig7EAQhB0IgrADQRB2IAjG2augucSsVm8MT6/o539hxbZk/Rd/+8nc2pwX0mP0pz5gybqv+t9k/SPzf5esv3JsXm5t4ERbct3RrZck6xc+kp5uenR/b27Nh04n152K2LMDQRB2IAjCDgRB2IEgCDsQBGEHgiDsQBCMs1fBvN8cStZ7FqfHi+/6i6PJ+j0X9KTrX8mv765wPPmj09Nj4ad8KFnfcPzC3No1M15Krju4Iv3ruUZ/l6wv2ZB/AMRwb/r4gKmo5J7dzNabWb+Z7Rq3bK6ZbTazvdntnNq2CaBSk3kZ/0NJq9617A5JW9x9maQt2fcAGljJsLv705IG3rV4taQN2f0Nkm6scl8Aqqzc9+wL3P3MgckHJS3Ie6CZdUnqkqQ2pa93BqB2Kv403t1diUsuunu3u3e6e2eLWivdHIAylRv2Q2bWLknZbX/1WgJQC+WGfaOktdn9tZIer047AGql5Ht2M3tI0lWS5pnZAUl3SVon6REzu1XSfkk31bLJRjey99VkfcnP0n9TH3vrymT9qes+lKx/av5ryXrKwOmZyfr2wwuT9SP95ybrrQfy555/4wtPJNe9btbuZP3EwvR15/0c3jaOVzLs7r4mp3R1lXsBUEMcLgsEQdiBIAg7EARhB4Ig7EAQnOJaByMvvpyst5eoT/tp7tHIkqTfLkqfQpvSdCJ9iuoHDx5O10fT9WOfzR82vGhN+lisphJzYZ+3uzlZ1++Z0nk89uxAEIQdCIKwA0EQdiAIwg4EQdiBIAg7EATj7GeB4YPpS1WrVD0hfZKo1NSWvpR0/9oVyfqp6/LHum+YcSy57rePfDxZv+C/0tNJjwz8PlmPhj07EARhB4Ig7EAQhB0IgrADQRB2IAjCDgTBOHtwNi39K3Dizz+SrHfcsjNZ/87CJ3Nrrw2nz1d/aNMVyfpFL2xL1jVa6iiCWNizA0EQdiAIwg4EQdiBIAg7EARhB4Ig7EAQjLNPAamx8ub2C5Lr7vvSHyTr3V/9l2R9xfTh9PrHlufWfvDA9cl1lz36erI+fOpUso53KrlnN7P1ZtZvZrvGLbvbzHrNbHv2lf5fA1C4ybyM/6GkVRMsv9fdO7KvTdVtC0C1lQy7uz8taaAOvQCooUo+oLvdzHZkL/Pn5D3IzLrMrMfMeobEeyygKOWG/T5JF0vqkNQn6bt5D3T3bnfvdPfOFrWWuTkAlSor7O5+yN1H3H1U0v2SVla3LQDVVlbYzax93Lefl7Qr77EAGkPJcXYze0jSVZLmmdkBSXdJusrMOiS5pH2SvlbDHsNrmj07Wfc/XpJb23dNet1HunLfgUmSPjz9nGRdmp6s/tuOK3Nrf/jjF5PrDh85WmLbeD9Kht3d10yw+MEa9AKghjhcFgiCsANBEHYgCMIOBEHYgSA4xfUscPLSP0rWj902mFt7ouOe5LqLps0qq6fJ+vuOzbm1B25YnVx33i9fSdZHDvWX1VNU7NmBIAg7EARhB4Ig7EAQhB0IgrADQRB2IAjG2c8CbQffTNb7njs/t/aVWV9MrutuyXpLc3ra4yvnvZSsXzpjb27t9n/8WXLde2felKzPv+9wsi5PTwkdDXt2IAjCDgRB2IEgCDsQBGEHgiDsQBCEHQjCvI5jkefaXP+EXV237U0ZTc3pclv+TDs2o9SloEtoTm97ZGl6Sui9f9OSX7vm/uS61+65MVlvvTl9/MFIwEtRP+NbdNwHJjx4gj07EARhB4Ig7EAQhB0IgrADQRB2IAjCDgTB+exng9H0OeWjb72VX0zVqsCODiTr7b/4s9zaf346/ev3D0ueSNa/9aVbk/VFD+dfd3744KHkulNRyT27mS02s6fM7Hkz221mX8+WzzWzzWa2N7udU/t2AZRrMi/jhyV9092XS/qkpNvMbLmkOyRtcfdlkrZk3wNoUCXD7u597r4tuz8oaY+khZJWS9qQPWyDpPSxjQAK9b7es5vZEkkrJD0jaYG792Wlg5IW5KzTJalLkto0o9w+AVRo0p/Gm9ksST+X9A13Pz6+5mNn00x4Ro27d7t7p7t3tij/hA0AtTWpsJtZi8aC/hN3fzRbfMjM2rN6uySm1AQaWMmX8WZmkh6UtMfdvzeutFHSWknrstvHa9IhGpoPDyfrH9jal1
u75Vd/nVz3nz+dvtT04NLRZF2Vnt47xUzmPftlkr4saaeZbc+W3amxkD9iZrdK2i8pfZFvAIUqGXZ3/7WkvJkEuBIFcJbgcFkgCMIOBEHYgSAIOxAEYQeC4BRX1NRo/5Hc2vn/szC57ksfb0/WfXZ6jN+npS+DHQ17diAIwg4EQdiBIAg7EARhB4Ig7EAQhB0IgnF21JblnTApjZb47ZvVfDJZ/9jFryfrJ8+bn95AMOzZgSAIOxAEYQeCIOxAEIQdCIKwA0EQdiAIxtlRnAnnEPp/o57eF01rSk9ljXdizw4EQdiBIAg7EARhB4Ig7EAQhB0IgrADQUxmfvbFkn4kaYHGRka73f37Zna3pK9KOpw99E5331SrRnF2spb8X7E305eN17xpx5P1Z3delKwvP9qfW0tfcX5qmsxBNcOSvunu28xstqRnzWxzVrvX3b9Tu/YAVMtk5mfvk9SX3R80sz2SSvxNBtBo3td7djNbImmFpGeyRbeb2Q4zW29mc3LW6TKzHjPrGdKpipoFUL5Jh93MZkn6uaRvuPtxSfdJulhSh8b2/N+daD1373b3TnfvbFFrFVoGUI5Jhd3MWjQW9J+4+6OS5O6H3H3E3Ucl3S9pZe3aBFCpkmE3M5P0oKQ97v69ccvHT7H5eUm7qt8egGqZzKfxl0n6sqSdZrY9W3anpDVm1qGx4bh9kr5Wkw5xVvOR0dxa20D+ZaYl6dvP3ZCsn/dCiV/fE+lLUUczmU/jfy1pov8VxtSBswhH0AFBEHYgCMIOBEHYgSAIOxAEYQeC4FLSqKnRwcHc2gX3/qam2454GmsKe3YgCMIOBEHYgSAIOxAEYQeCIOxAEIQdCMLcS8ybW82NmR2WtH/conmSjtStgfenUXtr1L4keitXNXu70N0/OFGhrmF/z8bNety9s7AGEhq1t0btS6K3ctWrN17GA0EQdiCIosPeXfD2Uxq1t0btS6K3ctWlt0LfswOon6L37ADqhLADQRQSdjNbZWYvmtnLZnZHET3kMbN9ZrbTzLabWU/Bvaw3s34z2zVu2Vwz22xme7PbCefYK6i3u82sN3vutpvZ9QX1ttjMnjKz581st5l9PVte6HOX6Ksuz1vd37ObWbOklyRdI+mApK2S1rj783VtJIeZ7ZPU6e6FH4BhZldIekPSj9z9kmzZPZIG3H1d9odyjrt/q0F6u1vSG0VP453NVtQ+fppxSTdK+isV+Nwl+rpJdXjeitizr5T0sru/6u6nJT0saXUBfTQ8d39a0sC7Fq+WtCG7v0Fjvyx1l9NbQ3D3Pnfflt0flHRmmvFCn7tEX3VRRNgXSnp93PcH1FjzvbukJ83sWTPrKrqZCSxw977s/kFJC4psZgIlp/Gup3dNM94wz105059Xig/o3utyd/9TSZ+TdFv2crUh+dh7sEYaO53UNN71MsE0428r8rkrd/rzShUR9l5Ji8d9vyhb1hDcvTe77Zf0mBpvKupDZ2bQzW77C+7nbY00jfdE04yrAZ67Iqc/LyLsWyUtM7OlZjZd0s2SNhbQx3uY2czsgxOZ2UxJ16rxpqLeKGltdn+tpMcL7OUdGmUa77xpxlXwc1f49OfuXvcvSddr7BP5VyT9UxE95PR1kaTfZl+7i+5N0kMae1k3pLHPNm6VdL6kLZL2SvoPSXMbqLcfS9opaYfGgtVeUG+Xa+wl+g5J27Ov64t+7hJ91eV543BZIAg+oAOCIOxAEIQdCIKwA0EQdiAIwg4EQdiBIP4PbYZ6s7wJM9MAAAAASUVORK5CYII=\n", 407 | "text/plain": [ 408 | "
" 409 | ] 410 | }, 411 | "metadata": { 412 | "tags": [], 413 | "needs_background": "light" 414 | } 415 | } 416 | ] 417 | }, 418 | { 419 | "cell_type": "code", 420 | "metadata": { 421 | "colab": { 422 | "base_uri": "https://localhost:8080/" 423 | }, 424 | "id": "65Y_-0h9U96h", 425 | "outputId": "9e86c6b5-2fea-4e6a-dfef-437da2979252" 426 | }, 427 | "source": [ 428 | "inputs = np.tile(img[None], [4, 1, 1, 1])\n", 429 | "predictions = model(inputs)\n", 430 | "for i, cls in enumerate(tf.argmax(predictions, axis=1)):\n", 431 | " print(f\"PREDICTION OF {i+1} MODEL: {cls} CLASS\")" 432 | ], 433 | "execution_count": 123, 434 | "outputs": [ 435 | { 436 | "output_type": "stream", 437 | "text": [ 438 | "PREDICTION OF 1 MODEL: 3 CLASS\n", 439 | "PREDICTION OF 2 MODEL: 7 CLASS\n", 440 | "PREDICTION OF 3 MODEL: 7 CLASS\n", 441 | "PREDICTION OF 4 MODEL: 7 CLASS\n" 442 | ], 443 | "name": "stdout" 444 | } 445 | ] 446 | } 447 | ] 448 | } -------------------------------------------------------------------------------- /notebooks/MNIST_Masksembles.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "dacca9e1" 7 | }, 8 | "source": [ 9 | "\n", 10 | "# Masksembles on MNIST — Tutorial\n", 11 | "\n", 12 | "Welcome! 👋\n", 13 | "\n", 14 | "This notebook walks you through using **Masksembles** for uncertainty‑aware inference on the classic **MNIST** dataset. \n", 15 | "\n", 16 | "You'll see how to:\n", 17 | "\n", 18 | "- Load and preprocess MNIST\n", 19 | "- Define **Masksembles** layers (`Masksembles2D` / `Masksembles1D`)\n", 20 | "- Build a small CNN\n", 21 | "- Train and validate in PyTorch\n", 22 | "- Run **per‑submodel predictions** (ensembled behavior) at inference time\n", 23 | "- Avoid common pitfalls\n", 24 | "\n", 25 | "> **What is Masksembles?** \n", 26 | "> Masksembles builds multiple *subnetworks* inside a single model using deterministic, non‑overlapping masks. At inference time you can obtain diverse predictions (like an ensemble) **without** keeping multiple separate models. This gives you better **uncertainty estimation** and often more robust predictions with minimal overhead.\n" 27 | ] 28 | }, 29 | { 30 | "cell_type": "markdown", 31 | "metadata": { 32 | "id": "670b136c" 33 | }, 34 | "source": [ 35 | "\n", 36 | "## Table of Contents\n", 37 | "1. [Setup & Imports](#setup-imports)\n", 38 | "2. [Load MNIST & Preprocessing](#load-preprocess)\n", 39 | "3. [Masksembles Layers](#masksembles-layers)\n", 40 | "4. [Model Architecture](#model-architecture)\n", 41 | "5. [Training Setup](#training-setup)\n", 42 | "6. [Training Loop](#training-loop)\n", 43 | "7. [Evaluation & Inference](#eval-infer)\n", 44 | "8. [Next Steps](#next-steps)\n" 45 | ] 46 | }, 47 | { 48 | "cell_type": "markdown", 49 | "source": [ 50 | "But before we start, lets make sure you installed Masksembles package." 
51 | ], 52 | "metadata": { 53 | "id": "YuRK1pIbrBL0" 54 | } 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": 1, 59 | "metadata": { 60 | "colab": { 61 | "base_uri": "https://localhost:8080/" 62 | }, 63 | "collapsed": true, 64 | "id": "qrpn5UV2BYLF", 65 | "outputId": "b1634cd5-3287-45bc-de91-88c0ba806fd6" 66 | }, 67 | "outputs": [ 68 | { 69 | "output_type": "stream", 70 | "name": "stdout", 71 | "text": [ 72 | "Collecting masksembles\n", 73 | " Downloading masksembles-1.1-py3-none-any.whl.metadata (5.1 kB)\n", 74 | "Downloading masksembles-1.1-py3-none-any.whl (8.4 kB)\n", 75 | "Installing collected packages: masksembles\n", 76 | "Successfully installed masksembles-1.1\n", 77 | "--2025-11-09 12:52:18-- https://github.com/nikitadurasov/masksembles/raw/main/images/complex_sample_mnist.npy\n", 78 | "Resolving github.com (github.com)... 140.82.116.3\n", 79 | "Connecting to github.com (github.com)|140.82.116.3|:443... connected.\n", 80 | "HTTP request sent, awaiting response... 302 Found\n", 81 | "Location: https://raw.githubusercontent.com/nikitadurasov/masksembles/main/images/complex_sample_mnist.npy [following]\n", 82 | "--2025-11-09 12:52:18-- https://raw.githubusercontent.com/nikitadurasov/masksembles/main/images/complex_sample_mnist.npy\n", 83 | "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...\n", 84 | "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.\n", 85 | "HTTP request sent, awaiting response... 200 OK\n", 86 | "Length: 6400 (6.2K) [application/octet-stream]\n", 87 | "Saving to: ‘complex_sample_mnist.npy’\n", 88 | "\n", 89 | "complex_sample_mnis 100%[===================>] 6.25K --.-KB/s in 0s \n", 90 | "\n", 91 | "2025-11-09 12:52:18 (96.2 MB/s) - ‘complex_sample_mnist.npy’ saved [6400/6400]\n", 92 | "\n" 93 | ] 94 | } 95 | ], 96 | "source": [ 97 | "!pip install masksembles\n", 98 | "!wget https://github.com/nikitadurasov/masksembles/raw/main/images/complex_sample_mnist.npy" 99 | ] 100 | }, 101 | { 102 | "cell_type": "markdown", 103 | "metadata": { 104 | "id": "32ca1561" 105 | }, 106 | "source": [ 107 | "\n", 108 | "## 1. Setup & Imports \n", 109 | "\n", 110 | "This section imports PyTorch, TorchVision, and the other packages used later. " 111 | ] 112 | }, 113 | { 114 | "cell_type": "code", 115 | "source": [ 116 | "import numpy as np\n", 117 | "import matplotlib.pyplot as plt\n", 118 | "\n", 119 | "import torch\n", 120 | "import torch.nn as nn\n", 121 | "import torch.nn.functional as F\n", 122 | "import torch.optim as optim\n", 123 | "from torch.utils.data import DataLoader, random_split, TensorDataset\n", 124 | "from torchvision import datasets, transforms" 125 | ], 126 | "metadata": { 127 | "id": "r4tPNSFvqUWp" 128 | }, 129 | "execution_count": 2, 130 | "outputs": [] 131 | }, 132 | { 133 | "cell_type": "markdown", 134 | "metadata": { 135 | "id": "be8e329c" 136 | }, 137 | "source": [ 138 | "\n", 139 | "## 2. Load MNIST & Preprocessing \n", 140 | "\n", 141 | "Here we load MNIST and transform images to **tensors in [0, 1]**. 
\n", 142 | "Remember that **PyTorch uses channel‑first** tensors *(N, C, H, W)*, so MNIST images should become shape `(N, 1, 28, 28)`.\n" 143 | ] 144 | }, 145 | { 146 | "cell_type": "code", 147 | "source": [ 148 | "# Model / data parameters\n", 149 | "num_classes = 10\n", 150 | "input_shape = (1, 28, 28) # PyTorch uses channel-first format (C, H, W)\n", 151 | "\n", 152 | "# Define transform: convert to tensor and normalize to [0, 1]\n", 153 | "transform = transforms.Compose([\n", 154 | " transforms.ToTensor(), # Converts to tensor and scales to [0, 1]\n", 155 | "])\n", 156 | "\n", 157 | "# Load MNIST dataset\n", 158 | "train_dataset = datasets.MNIST(\n", 159 | " root='./data',\n", 160 | " train=True,\n", 161 | " transform=transform,\n", 162 | " download=True\n", 163 | ")\n", 164 | "test_dataset = datasets.MNIST(\n", 165 | " root='./data',\n", 166 | " train=False,\n", 167 | " transform=transform,\n", 168 | " download=True\n", 169 | ")\n", 170 | "\n", 171 | "# Create DataLoaders\n", 172 | "batch_size = 64\n", 173 | "train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n", 174 | "test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)\n", 175 | "\n", 176 | "# Check shapes\n", 177 | "images, labels = next(iter(train_loader))\n", 178 | "print(\"x_train shape:\", images.shape) # [batch_size, 1, 28, 28]\n", 179 | "print(len(train_dataset), \"train samples\")\n", 180 | "print(len(test_dataset), \"test samples\")\n", 181 | "\n", 182 | "# Convert labels to one-hot if needed\n", 183 | "# (usually not needed for CrossEntropyLoss)\n", 184 | "y_one_hot = torch.nn.functional.one_hot(labels, num_classes=num_classes).float()\n", 185 | "print(\"y_train one-hot shape:\", y_one_hot.shape)\n" 186 | ], 187 | "metadata": { 188 | "colab": { 189 | "base_uri": "https://localhost:8080/" 190 | }, 191 | "id": "Cs0l4BmmpuWm", 192 | "outputId": "04a6eef1-f807-4305-db7c-27677aed40f9" 193 | }, 194 | "execution_count": 3, 195 | "outputs": [ 196 | { 197 | "output_type": "stream", 198 | "name": "stderr", 199 | "text": [ 200 | "100%|██████████| 9.91M/9.91M [00:00<00:00, 19.7MB/s]\n", 201 | "100%|██████████| 28.9k/28.9k [00:00<00:00, 480kB/s]\n", 202 | "100%|██████████| 1.65M/1.65M [00:00<00:00, 4.46MB/s]\n", 203 | "100%|██████████| 4.54k/4.54k [00:00<00:00, 9.13MB/s]" 204 | ] 205 | }, 206 | { 207 | "output_type": "stream", 208 | "name": "stdout", 209 | "text": [ 210 | "x_train shape: torch.Size([64, 1, 28, 28])\n", 211 | "60000 train samples\n", 212 | "10000 test samples\n", 213 | "y_train one-hot shape: torch.Size([64, 10])\n" 214 | ] 215 | }, 216 | { 217 | "output_type": "stream", 218 | "name": "stderr", 219 | "text": [ 220 | "\n" 221 | ] 222 | } 223 | ] 224 | }, 225 | { 226 | "cell_type": "markdown", 227 | "source": [ 228 | "\n", 229 | "## 3. Masksembles Layers \n", 230 | "\n", 231 | "These classes (`Masksembles2D`, `Masksembles1D`) create **non‑overlapping channel masks** that define *submodels* within the same network. \n", 232 | "- `Masksembles2D`: apply masks on convolutional feature maps (C×H×W). \n", 233 | "- `Masksembles1D`: apply masks on flattened features (C). 
\n",
234 | "Key arguments:\n",
235 | "- `channels`: number of channels in the input.\n",
236 | "- `n`: number of submodels (masks).\n",
237 | "- `scale`: controls correlation/capacity trade‑off between submodels.\n"
238 | ],
239 | "metadata": {
240 | "id": "1AgTlby2p0g-"
241 | }
242 | },
243 | {
244 | "cell_type": "code",
245 | "source": [
246 | "from masksembles.torch import Masksembles2D, Masksembles1D"
247 | ],
248 | "metadata": {
249 | "id": "lAnwkpGtp3qP"
250 | },
251 | "execution_count": 4,
252 | "outputs": []
253 | },
254 | {
255 | "cell_type": "markdown",
256 | "metadata": {
257 | "id": "Hr5xjOSffo7x"
258 | },
259 | "source": [
260 | "To turn a regular model into a Masksembles model, add Masksembles2D or Masksembles1D layers to it. The general recommendation is to insert these layers right before or after convolutional layers.\n",
261 | "\n",
262 | "In the example below we use both Masksembles2D and Masksembles1D layers, applied after the convolutions."
263 | ]
264 | },
265 | {
266 | "cell_type": "markdown",
267 | "metadata": {
268 | "id": "75491b02"
269 | },
270 | "source": [
271 | "\n",
272 | "## 4. Model Architecture \n",
273 | "\n",
274 | "We define a small CNN with ELU activations and interleave **Masksembles** layers after each convolution and before the final linear layer. \n",
275 | "After two conv+pool blocks, MNIST's 28×28 maps down to 5×5; the final linear projects to 10 classes.\n"
276 | ]
277 | },
278 | {
279 | "cell_type": "code",
280 | "source": [
281 | "import torch\n",
282 | "from torch import nn\n",
283 | "\n",
284 | "from masksembles import common\n",
285 | "\n",
286 | "\n",
287 | "class Masksembles2D(nn.Module):\n",
288 | "    \"\"\"\n",
289 | "    :class:`Masksembles2D` is a high-level class that implements the Masksembles approach\n",
290 | "    for 2-dimensional inputs (similar to :class:`torch.nn.Dropout2d`).\n",
291 | "\n",
292 | "    :param channels: int, number of channels used in masks.\n",
293 | "    :param n: int, number of masks.\n",
294 | "    :param scale: float, scale parameter similar to *S* in [1]. 
Larger values decrease \\\n", 295 | " subnetworks correlations but at the same time decrease capacity of every individual model.\n", 296 | "\n", 297 | " Shape:\n", 298 | " * Input: (N, C, H, W)\n", 299 | " * Output: (N, C, H, W) (same shape as input)\n", 300 | "\n", 301 | " Examples:\n", 302 | "\n", 303 | " >>> m = Masksembles2D(16, 4, 2.0)\n", 304 | " >>> input = torch.ones([4, 16, 28, 28])\n", 305 | " >>> output = m(input)\n", 306 | "\n", 307 | " References:\n", 308 | "\n", 309 | " [1] `Masksembles for Uncertainty Estimation`,\n", 310 | " Nikita Durasov, Timur Bagautdinov, Pierre Baque, Pascal Fua\n", 311 | "\n", 312 | " \"\"\"\n", 313 | " def __init__(self, channels: int, n: int, scale: float):\n", 314 | " super().__init__()\n", 315 | " self.channels, self.n, self.scale = channels, n, scale\n", 316 | " masks_np = common.generation_wrapper(channels, n, scale) # numpy float64 by default\n", 317 | " masks = torch.as_tensor(masks_np, dtype=torch.float32) # make float32 here\n", 318 | " self.register_buffer('masks', masks, persistent=False) # not trainable, moves with .to()\n", 319 | "\n", 320 | " def forward(self, inputs):\n", 321 | " # make sure masks match dtype/device (usually already true because of buffer, but safe)\n", 322 | " masks = self.masks.to(dtype=inputs.dtype, device=inputs.device)\n", 323 | "\n", 324 | " batch = inputs.shape[0]\n", 325 | " # safer split even if batch % n != 0\n", 326 | " chunks = torch.chunk(inputs.unsqueeze(1), self.n, dim=0) # returns nearly equal chunks\n", 327 | " x = torch.cat(chunks, dim=1).permute(1, 0, 2, 3, 4) # [n, ?, C, H, W]\n", 328 | " x = x * masks.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # broadcast masks\n", 329 | " x = torch.cat(torch.split(x, 1, dim=0), dim=1)\n", 330 | " return x.squeeze(0)\n", 331 | "\n", 332 | "\n", 333 | "\n", 334 | "class Masksembles1D(nn.Module):\n", 335 | " \"\"\"\n", 336 | " :class:`Masksembles1D` is high-level class that implements Masksembles approach\n", 337 | " for 1-dimensional inputs (similar to :class:`torch.nn.Dropout`).\n", 338 | "\n", 339 | " :param channels: int, number of channels used in masks.\n", 340 | " :param n: int, number of masks\n", 341 | " :param scale: float, scale parameter similar to *S* in [1]. 
Larger values decrease \\\n", 342 | " subnetworks correlations but at the same time decrease capacity of every individual model.\n", 343 | "\n", 344 | " Shape:\n", 345 | " * Input: (N, C)\n", 346 | " * Output: (N, C) (same shape as input)\n", 347 | "\n", 348 | " Examples:\n", 349 | "\n", 350 | " >>> m = Masksembles1D(16, 4, 2.0)\n", 351 | " >>> input = torch.ones([4, 16])\n", 352 | " >>> output = m(input)\n", 353 | "\n", 354 | "\n", 355 | " References:\n", 356 | "\n", 357 | " [1] `Masksembles for Uncertainty Estimation`,\n", 358 | " Nikita Durasov, Timur Bagautdinov, Pierre Baque, Pascal Fua\n", 359 | "\n", 360 | " \"\"\"\n", 361 | "\n", 362 | " def __init__(self, channels: int, n: int, scale: float):\n", 363 | " super().__init__()\n", 364 | " self.channels, self.n, self.scale = channels, n, scale\n", 365 | " masks_np = common.generation_wrapper(channels, n, scale)\n", 366 | " masks = torch.as_tensor(masks_np, dtype=torch.float32)\n", 367 | " self.register_buffer('masks', masks, persistent=False)\n", 368 | "\n", 369 | " def forward(self, inputs):\n", 370 | " masks = self.masks.to(dtype=inputs.dtype, device=inputs.device)\n", 371 | "\n", 372 | " batch = inputs.shape[0]\n", 373 | " chunks = torch.chunk(inputs.unsqueeze(1), self.n, dim=0)\n", 374 | " x = torch.cat(chunks, dim=1).permute(1, 0, 2) # [n, ?, C]\n", 375 | " x = x * masks.unsqueeze(1)\n", 376 | " x = torch.cat(torch.split(x, 1, dim=0), dim=1)\n", 377 | " return x.squeeze(0)\n" 378 | ], 379 | "metadata": { 380 | "id": "ZxZQ-0Fjsa11" 381 | }, 382 | "execution_count": 8, 383 | "outputs": [] 384 | }, 385 | { 386 | "cell_type": "code", 387 | "execution_count": 9, 388 | "metadata": { 389 | "colab": { 390 | "base_uri": "https://localhost:8080/" 391 | }, 392 | "id": "7jw12ljXCaMg", 393 | "outputId": "df05375d-7bd1-4d14-c0d5-0a43388381ce" 394 | }, 395 | "outputs": [ 396 | { 397 | "output_type": "stream", 398 | "name": "stdout", 399 | "text": [ 400 | "MasksemblesCNN(\n", 401 | " (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))\n", 402 | " (mask1): Masksembles2D()\n", 403 | " (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", 404 | " (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))\n", 405 | " (mask2): Masksembles2D()\n", 406 | " (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", 407 | " (flatten): Flatten(start_dim=1, end_dim=-1)\n", 408 | " (mask3): Masksembles1D()\n", 409 | " (fc): Linear(in_features=1600, out_features=10, bias=True)\n", 410 | ")\n", 411 | "\n", 412 | "Total parameters: 34,826\n", 413 | "Trainable parameters: 34,826\n" 414 | ] 415 | } 416 | ], 417 | "source": [ 418 | "class MasksemblesCNN(nn.Module):\n", 419 | " def __init__(self, num_classes=10):\n", 420 | " super().__init__()\n", 421 | " self.conv1 = nn.Conv2d(1, 32, kernel_size=3)\n", 422 | " self.mask1 = Masksembles2D(32, n=4, scale=2.0)\n", 423 | " self.pool1 = nn.MaxPool2d(kernel_size=2)\n", 424 | "\n", 425 | " self.conv2 = nn.Conv2d(32, 64, kernel_size=3)\n", 426 | " self.mask2 = Masksembles2D(64, n=4, scale=2.0)\n", 427 | " self.pool2 = nn.MaxPool2d(kernel_size=2)\n", 428 | "\n", 429 | " self.flatten = nn.Flatten()\n", 430 | " self.mask3 = Masksembles1D(64 * 5 * 5, n=4, scale=2.0)\n", 431 | " self.fc = nn.Linear(64 * 5 * 5, num_classes) # 5×5 after two 2×2 pools from 28×28 input\n", 432 | "\n", 433 | " def forward(self, x):\n", 434 | " x = F.elu(self.conv1(x))\n", 435 | " x = self.mask1(x)\n", 436 | " x = self.pool1(x)\n", 437 | "\n", 438 | " x = F.elu(self.conv2(x))\n", 
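"        # second Masksembles stage: mask conv2's 64 feature maps, one channel subset per submodel\n",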
439 | "        x = self.mask2(x)\n",
440 | "        x = self.pool2(x)\n",
441 | "\n",
442 | "        x = self.flatten(x)\n",
443 | "        x = self.mask3(x)\n",
444 | "        x = self.fc(x)  # return raw logits: nn.CrossEntropyLoss applies log-softmax itself\n",
445 | "        return x\n",
446 | "\n",
447 | "\n",
448 | "# Instantiate and summarize\n",
449 | "model = MasksemblesCNN(num_classes=10)\n",
450 | "print(model)\n",
451 | "\n",
452 | "# Print parameter count\n",
453 | "total_params = sum(p.numel() for p in model.parameters())\n",
454 | "trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
455 | "print(f\"\\nTotal parameters: {total_params:,}\")\n",
456 | "print(f\"Trainable parameters: {trainable_params:,}\")\n"
457 | ]
458 | },
459 | {
460 | "cell_type": "markdown",
461 | "metadata": {
462 | "id": "DP9Km-bGigRC"
463 | },
464 | "source": [
465 | "Training a Masksembles model is no different from training a regular one, so we just use a standard PyTorch training loop."
466 | ]
467 | },
468 | {
469 | "cell_type": "markdown",
470 | "source": [
471 | "\n",
472 | "## 5. Training Setup \n",
473 | "\n",
474 | "We configure:\n",
475 | "- **Loss**: `CrossEntropyLoss` (expects integer class indices, not one‑hot)\n",
476 | "- **Optimizer**: `Adam`\n",
477 | "- **DataLoaders**: mini‑batches from the training set, with the MNIST test set used for validation\n"
478 | ],
479 | "metadata": {
480 | "id": "Xa-4-JNFqE0l"
481 | }
482 | },
483 | {
484 | "cell_type": "code",
485 | "source": [
486 | "# Model, loss, optimizer\n",
487 | "model = MasksemblesCNN(num_classes=10)\n",
488 | "criterion = nn.CrossEntropyLoss()\n",
489 | "optimizer = optim.Adam(model.parameters())\n",
490 | "\n",
491 | "# Training configuration\n",
492 | "epochs = 3\n",
493 | "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
494 | "model.to(device)"
495 | ],
496 | "metadata": {
497 | "colab": {
498 | "base_uri": "https://localhost:8080/"
499 | },
500 | "id": "y-trVv38qQYA",
501 | "outputId": "1a5198b6-ecfe-44a1-d155-693af748150d"
502 | },
503 | "execution_count": 10,
504 | "outputs": [
505 | {
506 | "output_type": "execute_result",
507 | "data": {
508 | "text/plain": [
509 | "MasksemblesCNN(\n",
510 | "  (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))\n",
511 | "  (mask1): Masksembles2D()\n",
512 | "  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
513 | "  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))\n",
514 | "  (mask2): Masksembles2D()\n",
515 | "  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
516 | "  (flatten): Flatten(start_dim=1, end_dim=-1)\n",
517 | "  (mask3): Masksembles1D()\n",
518 | "  (fc): Linear(in_features=1600, out_features=10, bias=True)\n",
519 | ")"
520 | ]
521 | },
522 | "metadata": {},
523 | "execution_count": 10
524 | }
525 | ]
526 | },
527 | {
528 | "cell_type": "markdown",
529 | "source": [
530 | "\n",
531 | "## 6. Training Loop \n",
532 | "\n",
533 | "Standard PyTorch loop:\n",
534 | "1. `model.train()`, forward pass\n",
535 | "2. Compute loss, `backward()`, `step()`\n",
536 | "3. 
Track accuracy\n", 537 | "Then switch to `model.eval()` for validation with `torch.no_grad()`.\n" 538 | ], 539 | "metadata": { 540 | "id": "tKxCoTOZqACI" 541 | } 542 | }, 543 | { 544 | "cell_type": "code", 545 | "source": [ 546 | "for epoch in range(epochs):\n", 547 | " model.train()\n", 548 | " train_loss, train_correct = 0, 0\n", 549 | " for x_batch, y_batch in train_loader:\n", 550 | " x_batch, y_batch = x_batch.to(device), y_batch.to(device)\n", 551 | "\n", 552 | " optimizer.zero_grad()\n", 553 | " outputs = model(x_batch)\n", 554 | " loss = criterion(outputs, y_batch)\n", 555 | " loss.backward()\n", 556 | " optimizer.step()\n", 557 | "\n", 558 | " train_loss += loss.item() * x_batch.size(0)\n", 559 | " train_correct += (outputs.argmax(1) == y_batch).sum().item()\n", 560 | "\n", 561 | " # Validation\n", 562 | " model.eval()\n", 563 | " val_loss, val_correct = 0, 0\n", 564 | " with torch.no_grad():\n", 565 | " for x_batch, y_batch in test_loader:\n", 566 | " x_batch, y_batch = x_batch.to(device), y_batch.to(device)\n", 567 | " outputs = model(x_batch)\n", 568 | " loss = criterion(outputs, y_batch)\n", 569 | " val_loss += loss.item() * x_batch.size(0)\n", 570 | " val_correct += (outputs.argmax(1) == y_batch).sum().item()\n", 571 | "\n", 572 | " # Print epoch stats\n", 573 | " print(\n", 574 | " f\"Epoch {epoch+1}/{epochs} \"\n", 575 | " f\"- Train loss: {train_loss/len(train_loader.dataset):.4f}, \"\n", 576 | " f\"Train acc: {train_correct/len(train_loader.dataset):.4f}, \"\n", 577 | " f\"Val loss: {val_loss/len(test_loader.dataset):.4f}, \"\n", 578 | " f\"Val acc: {val_correct/len(test_loader.dataset):.4f}\"\n", 579 | " )\n" 580 | ], 581 | "metadata": { 582 | "colab": { 583 | "base_uri": "https://localhost:8080/" 584 | }, 585 | "id": "nG18Ftg_qNqQ", 586 | "outputId": "c775536c-861c-42c8-8723-29faa4c2b0ca" 587 | }, 588 | "execution_count": 11, 589 | "outputs": [ 590 | { 591 | "output_type": "stream", 592 | "name": "stdout", 593 | "text": [ 594 | "Epoch 1/3 - Train loss: 1.5990, Train acc: 0.8808, Val loss: 1.5153, Val acc: 0.9487\n", 595 | "Epoch 2/3 - Train loss: 1.5098, Train acc: 0.9545, Val loss: 1.4981, Val acc: 0.9648\n", 596 | "Epoch 3/3 - Train loss: 1.4985, Train acc: 0.9646, Val loss: 1.4927, Val acc: 0.9695\n" 597 | ] 598 | } 599 | ] 600 | }, 601 | { 602 | "cell_type": "markdown", 603 | "metadata": { 604 | "id": "25b935df" 605 | }, 606 | "source": [ 607 | "\n", 608 | "## 7. Evaluation & Inference \n", 609 | "\n", 610 | "For evaluation, we disable gradients (`torch.no_grad()`) and call `model.eval()` to turn off training‑time behavior. \n", 611 | "We then compute `argmax` over class logits to get predicted labels.\n", 612 | "\n", 613 | "### **Important Notes on Inference :** Batch tiling and mask assignment\n", 614 | "\n", 615 | "Masksembles layers divide each batch into *N* segments — one for each submodel (mask). \n", 616 | "That means:\n", 617 | "- The **first** 1/N of the batch goes through the first mask, \n", 618 | "- The **second** 1/N through the second mask, and so on.\n", 619 | "\n", 620 | "So, to obtain predictions from **all submodels**, you need to **tile the same input image** *N* times along the batch dimension. \n", 621 | "Each of these replicated samples will be processed by a different submodel. 
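\n",
"\n",
"For instance, a minimal sketch (assuming a trained `model` with `n = 4` masks, as above, and a single image tensor `img` of shape `(1, 28, 28)`):\n",
"\n",
"```python\n",
"batch = img[None].repeat(4, 1, 1, 1)  # tile one image into a (4, 1, 28, 28) batch\n",
"outs = model(batch)                   # row i holds submodel i's output\n",
"ensemble = outs.mean(dim=0)           # optional: average the submodel outputs\n",
"```\n",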
622 | "After inference, you can then collect the *N* outputs to see how each submodel predicts the same input.\n",
623 | "\n"
624 | ]
625 | },
626 | {
627 | "cell_type": "markdown",
628 | "metadata": {
629 | "id": "Wh3pGxWujZ0Z"
630 | },
631 | "source": [
632 | "Now we will check that all of Masksembles' submodels produce similar predictions for a sample from the training set."
633 | ]
634 | },
635 | {
636 | "cell_type": "code",
637 | "execution_count": 12,
638 | "metadata": {
639 | "id": "T3GNT8Z5kItR"
640 | },
641 | "outputs": [],
642 | "source": [
643 | "img = train_dataset[0][0]  # a sample image from the training set"
644 | ]
645 | },
646 | {
647 | "cell_type": "code",
648 | "execution_count": 13,
649 | "metadata": {
650 | "colab": {
651 | "base_uri": "https://localhost:8080/",
652 | "height": 430
653 | },
654 | "id": "o0K51rvqkMIs",
655 | "outputId": "dae2318b-5899-4402-dd06-f73fad962cb2"
656 | },
657 | "outputs": [
658 | {
659 | "output_type": "display_data",
660 | "data": {
661 | "text/plain": [
662 | "
" 663 | ], 664 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaAAAAGdCAYAAABU0qcqAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAHE1JREFUeJzt3X9w1PW97/HXAskKmiyNIb9KwIA/sALxFiVmQMSSS0jnOICMB390BrxeHDF4imj1xlGR1jNp8Y61eqne06lEZ8QfnBGojuWOBhOONaEDShlu25TQWOIhCRUnuyFICMnn/sF160ICftZd3kl4Pma+M2T3++b78evWZ7/ZzTcB55wTAADn2DDrBQAAzk8ECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmBhhvYBT9fb26uDBg0pLS1MgELBeDgDAk3NOHR0dysvL07Bh/V/nDLgAHTx4UPn5+dbLAAB8Q83NzRo7dmy/zw+4AKWlpUmSZur7GqEU49UAAHydULc+0DvR/573J2kBWrdunZ566im1traqsLBQzz33nKZPn37WuS+/7TZCKRoRIEAAMOj8/zuMnu1tlKR8COH111/XqlWrtHr1an300UcqLCxUaWmpDh06lIzDAQAGoaQE6Omnn9ayZct055136jvf+Y5eeOEFjRo1Si+++GIyDgcAGIQSHqDjx49r165dKikp+cdBhg1TSUmJ6urqTtu/q6tLkUgkZgMADH0JD9Bnn32mnp4eZWdnxzyenZ2t1tbW0/avrKxUKBSKbnwCDgDOD+Y/iFpRUaFwOBzdmpubrZcEADgHEv4puMzMTA0fPlxtbW0xj7e1tSknJ+e0/YPBoILBYKKXAQAY4BJ+BZSamqpp06apuro6+lhvb6+qq6tVXFyc6MMBAAappPwc0KpVq7RkyRJdc801mj59up555hl1dnbqzjvvTMbhAACDUFICtHjxYv3973/X448/rtbWVl199dXaunXraR9MAACcvwLOOWe9iK+KRCIKhUKarfncCQEABqETrls12qJwOKz09PR+9zP/FBwA4PxEgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmBhhvQBgIAmM8P+fxPAxmUlYSWI0PHhJXHM9o3q9Z8ZPPOQ9M+regPdM69Op3jMfXfO694wkfdbT6T1TtPEB75lLV9V7zwwFXAEBAEwQIACAiYQH6IknnlAgEIjZJk2alOjDAAAGuaS8B3TVVVfpvffe+8dB4vi+OgBgaEtKGUaMGKGcnJxk/NUAgCEiKe8B7du3T3l5eZowYYLuuOMOHThwoN99u7q6FIlEYjYAwNCX8AAVFRWpqqpKW7du1fPPP6+mpiZdf/316ujo6HP/yspKhUKh6Jafn5/oJQEABqCEB6isrEy33HKLpk6dqtLSUr3zzjtqb2/XG2+80ef+FRUVCofD0a25uTnRSwIADEBJ/3TA6NGjdfnll6uxsbHP54PBoILBYLKXAQAYYJL+c0BHjhzR/v37lZubm+xDAQAGkYQH6MEHH1Rtba0++eQTffjhh1q4cKGGDx+u2267LdGHAgAMYgn/Ftynn36q2267TYcPH9aYMWM0c+ZM1dfXa8yYMYk+FABgEEt4gF577bVE/5UYoIZfeZn3jAumeM8cvGG098wX1/nfRFKSMkL+c/9RGN+NLoea3x5N85752f+a5z2zY8oG75mm7i+8ZyTpp23/1Xsm7z9cXMc6H3EvOACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADARNJ/IR0Gvp7Z341r7umqdd4zl6ekxnUsnFvdrsd75vHnlnrPjOj0v3Fn8cYV3jNp/3nCe0aSgp/538R01M4dcR3rfMQVEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAExwN2wo2HAwrrldx/K9Zy5PaYvrWEPNAy3Xec/89Uim90zVxH/3npGkcK//Xaqzn/0wrmMNZP5nAT64AgIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATHAzUuhES2tcc8/97BbvmX+d1+k9M3zPRd4zf7j3Oe+ZeD352VTvmcaSUd4zPe0t3jO3F9/rPSNJn/yL/0yB/hDXsXD+4goIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADDBzUgRt4z1dd4zY9662Hum5/Dn3jNXTf5v3jOS9H9nveg985t/u8F7Jqv9Q++ZeATq4rtBaIH/v1rAG1dAAAATBAgAYMI7QNu3b9dNN92kvLw8BQIBbd68OeZ555wef/xx5ebmauTIkSopKdG+ffsStV4AwBDhHaDOzk4VFhZq3bp1fT6/du1aPfvss3rhhRe0Y8cOXXjhhSotLdWxY8e+8WIBAEOH94cQysrKVFZW1udzzjk988wzevTRRzV//nxJ0ssvv6zs7Gxt3rxZt9566zdbLQBgyEjoe0BNTU1qbW1VSUlJ9LFQKKSioiLV1fX9sZquri5FIpGYDQAw9CU0QK2trZKk7OzsmMezs7Ojz52qsrJSoVAouuXn5ydySQCAAcr8U3AVFRUKh8PRrbm52XpJAIBzIKEBysnJkSS1tbXFPN7W1hZ97lTBYFDp6ekxGwBg6EtogAoKCpSTk6Pq6uroY5FIRDt27FBxcXEiDwUAGOS8PwV35MgRNTY2Rr9uamrS7t27lZGRoXHjxmnlypV68sknddlll6mgoECPPfaY8vLytGDBgkSuGwAwyHkHaOfOnbrxxhujX69atUqStGTJElVVVemhhx5SZ2en7r77brW3t2vmzJnaunWrLrjggsStGgAw6AWcc856EV8ViUQUCoU0W/M1IpBivRwMUn/539fGN/dPL3jP3Pm3Od4zf5/Z4T2j3h7/GcDACdetGm1ROBw+4/v65p+CAwCcnwgQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGDC+9cxAIPBlQ//Ja65O6f439l6/fjqs+90ihtuKfeeSXu93nsGGMi4AgIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEA
THAzUgxJPe3huOYOL7/Se+bAb77wnvkfT77sPVPxzwu9Z9zHIe8ZScr/1zr/IefiOhbOX1wBAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmuBkp8BW9f/iT98yta37kPfPK6v/pPbP7Ov8bmOo6/xFJuurCFd4zl/2qxXvmxF8/8Z7B0MEVEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgIuCcc9aL+KpIJKJQKKTZmq8RgRTr5QBJ4WZc7T2T/tNPvWdenfB/vGfiNen9/+49c8WasPdMz76/es/g3DrhulWjLQqHw0pPT+93P66AAAAmCBAAwIR3gLZv366bbrpJeXl5CgQC2rx5c8zzS5cuVSAQiNnmzZuXqPUCAIYI7wB1dnaqsLBQ69at63efefPmqaWlJbq9+uqr32iRAIChx/s3opaVlamsrOyM+wSDQeXk5MS9KADA0JeU94BqamqUlZWlK664QsuXL9fhw4f73berq0uRSCRmAwAMfQkP0Lx58/Tyyy+rurpaP/vZz1RbW6uysjL19PT0uX9lZaVCoVB0y8/PT/SSAAADkPe34M7m1ltvjf55ypQpmjp1qiZOnKiamhrNmTPntP0rKiq0atWq6NeRSIQIAcB5IOkfw54wYYIyMzPV2NjY5/PBYFDp6ekxGwBg6Et6gD799FMdPnxYubm5yT4UAGAQ8f4W3JEjR2KuZpqamrR7925lZGQoIyNDa9as0aJFi5STk6P9+/froYce0qWXXqrS0tKELhwAMLh5B2jnzp268cYbo19/+f7NkiVL9Pzzz2vPnj166aWX1N7erry8PM2dO1c/+clPFAwGE7dqAMCgx81IgUFieHaW98zBxZfGdawdD//Ce2ZYHN/Rv6NprvdMeGb/P9aBgYGbkQIABjQCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYSPiv5AaQHD1th7xnsp/1n5GkYw+d8J4ZFUj1nvnVJW97z/zTwpXeM6M27fCeQfJxBQQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmOBmpICB3plXe8/sv+UC75nJV3/iPSPFd2PReDz3+X/xnhm1ZWcSVgILXAEBAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACa4GSnwFYFrJnvP/OVf/G/c+asZL3nPzLrguPfMudTlur1n6j8v8D9Qb4v/DAYkroAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABPcjBQD3oiC8d4z++/Mi+tYTyx+zXtm0UWfxXWsgeyRtmu8Z2p/cZ33zLdeqvOewdDBFRAAwAQBAgCY8ApQZWWlrr32WqWlpSkrK0sLFixQQ0NDzD7Hjh1TeXm5Lr74Yl100UVatGiR2traErpoAMDg5xWg2tpalZeXq76+Xu+++666u7s1d+5cdXZ2Rve5//779dZbb2njxo2qra3VwYMHdfPNNyd84QCAwc3rQwhbt26N+bqqqkpZWVnatWuXZs2apXA4rF//+tfasGGDvve970mS1q9fryuvvFL19fW67jr/NykBAEPTN3oPKBwOS5IyMjIkSbt27VJ3d7dKSkqi+0yaNEnjxo1TXV3fn3bp6upSJBKJ2QAAQ1/cAert7dXKlSs1Y8YMTZ48WZLU2tqq1NRUjR49Ombf7Oxstba29vn3VFZWKhQKRbf8/Px4lwQAGETiDlB5ebn27t2r117z/7mJr6qoqFA4HI5uzc3N3+jvAwAMDnH9IOqKFSv09ttva/v27Ro7dmz08ZycHB0/flzt7e0xV0FtbW3Kycnp8+8KBoMKBoPxLAMAMIh5XQE557RixQpt2rRJ27ZtU0FBQczz06ZNU0pKiqqrq6OPNTQ06MCBAyouLk7MigEAQ4LXFVB5ebk2bNigLVu2KC0tLfq+TigU0siRIxUKhXTXXXdp1apVysjIUHp6uu677z4VFxfzCTgAQAyvAD3//POSpNmzZ8c8vn79ei1dulSS9POf/1zDhg3TokWL1NXVpdLSUv3yl79MyGIBAENHwDnnrBfxVZFIRKFQSLM1XyMCKdbLwRmMuGSc90x4Wq73zOIfbz37Tqe4Z/RfvWcGugda/L+LUPdL/5uKSlJG1e/9h3p74joWhp4Trls12qJwOKz09PR+9+NecAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADAR129ExcA1Irfv3zx7Jp+/eGFcx1peUOs9c1taW1zHGshW/OdM75mPnr/aeybz3/d6z2R01HnPAOcKV0AAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAluRnqOHC+9xn/m/s+9Zx659B3vmbkjO71nBrq2ni/impv1mwe8ZyY9+mfvmYx2/5uE9npPAAMbV0AAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAluRnqOfLLAv/V/mbIxCStJnHXtE71nflE713sm0BPwnpn0ZJP3jCRd1rbDe6YnriMB4AoIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADARcM4560V8VSQSUSgU0mzN14hAivVyAACeTrhu1WiLwuGw0tPT+92PKyAAgAkCBAAw4RWgyspKXXvttUpLS1NWVpYWLFighoaGmH1mz56tQCAQs91zzz0JXTQAYPDzClBtba3Ky8tVX1+vd999V93d3Zo7d646Oztj9lu2bJlaWlqi29q1axO6aADA4Of1G1G3bt0a83VVVZWysrK0a9cuzZo1K/r4qFGjlJOTk5gVAgCGpG/0HlA4HJYkZWRkxDz+yiuvKDMzU5MnT1ZFRYWOHj3a79/R1dWlSCQSswEAhj6vK6Cv6u3t1cqVKzVjxgxNnjw5+vjtt9+u8ePHKy8vT3v27NHDDz+shoYGvfnmm33+PZWVlVqzZk28ywAADFJx/xzQ8uXL9dvf/lYffPCBxo4d2+9+27Zt05w5c9TY2KiJEyee9nxXV5e6urqiX0ciEeXn5/NzQAAwSH3dnwOK6wpoxYoVevvtt7V9+/YzxkeSioqKJKnfAAWDQQWDwXiWAQAYxLwC5JzTfffdp02bNqmmpkYFBQVnndm9e7ckKTc3N64FAgCGJq8AlZeXa8OGDdqyZYvS0tLU2toqSQqFQho5cqT279+vDRs26Pvf/74uvvhi7dmzR/fff79mzZqlqVOnJuUfAAAwOHm9BxQIBPp8fP369Vq
6dKmam5v1gx/8QHv37lVnZ6fy8/O1cOFCPfroo2f8PuBXcS84ABjckvIe0NlalZ+fr9raWp+/EgBwnuJecAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEyOsF3Aq55wk6YS6JWe8GACAtxPqlvSP/573Z8AFqKOjQ5L0gd4xXgkA4Jvo6OhQKBTq9/mAO1uizrHe3l4dPHhQaWlpCgQCMc9FIhHl5+erublZ6enpRiu0x3k4ifNwEufhJM7DSQPhPDjn1NHRoby8PA0b1v87PQPuCmjYsGEaO3bsGfdJT08/r19gX+I8nMR5OInzcBLn4STr83CmK58v8SEEAIAJAgQAMDGoAhQMBrV69WoFg0HrpZjiPJzEeTiJ83AS5+GkwXQeBtyHEAAA54dBdQUEABg6CBAAwAQBAgCYIEAAABODJkDr1q3TJZdcogsuuEBFRUX6/e9/b72kc+6JJ55QIBCI2SZNmmS9rKTbvn27brrpJuXl5SkQCGjz5s0xzzvn9Pjjjys3N1cjR45USUmJ9u3bZ7PYJDrbeVi6dOlpr4958+bZLDZJKisrde211yotLU1ZWVlasGCBGhoaYvY5duyYysvLdfHFF+uiiy7SokWL1NbWZrTi5Pg652H27NmnvR7uueceoxX3bVAE6PXXX9eqVau0evVqffTRRyosLFRpaakOHTpkvbRz7qqrrlJLS0t0++CDD6yXlHSdnZ0qLCzUunXr+nx+7dq1evbZZ/XCCy9ox44duvDCC1VaWqpjx46d45Um19nOgyTNmzcv5vXx6quvnsMVJl9tba3Ky8tVX1+vd999V93d3Zo7d646Ozuj+9x///166623tHHjRtXW1urgwYO6+eabDVedeF/nPEjSsmXLYl4Pa9euNVpxP9wgMH36dFdeXh79uqenx+Xl5bnKykrDVZ17q1evdoWFhdbLMCXJbdq0Kfp1b2+vy8nJcU899VT0sfb2dhcMBt2rr75qsMJz49Tz4JxzS5YscfPnzzdZj5VDhw45Sa62ttY5d/LffUpKitu4cWN0nz/96U9Okqurq7NaZtKdeh6cc+6GG25wP/zhD+0W9TUM+Cug48ePa9euXSopKYk+NmzYMJWUlKiurs5wZTb27dunvLw8TZgwQXfccYcOHDhgvSRTTU1Nam1tjXl9hEIhFRUVnZevj5qaGmVlZemKK67Q8uXLdfjwYeslJVU4HJYkZWRkSJJ27dql7u7umNfDpEmTNG7cuCH9ejj1PHzplVdeUWZmpiZPnqyKigodPXrUYnn9GnA3Iz3VZ599pp6eHmVnZ8c8np2drT//+c9Gq7JRVFSkqqoqXXHFFWppadGaNWt0/fXXa+/evUpLS7NenonW1lZJ6vP18eVz54t58+bp5ptvVkFBgfbv369HHnlEZWVlqqur0/Dhw62Xl3C9vb1auXKlZsyYocmTJ0s6+XpITU3V6NGjY/Ydyq+Hvs6DJN1+++0aP3688vLytGfPHj388MNqaGjQm2++abjaWAM+QPiHsrKy6J+nTp2qoqIijR8/Xm+88Ybuuusuw5VhILj11lujf54yZYqmTp2qiRMnqqamRnPmzDFcWXKUl5dr796958X7oGfS33m4++67o3+eMmWKcnNzNWfOHO3fv18TJ04818vs04D/FlxmZqaGDx9+2qdY2tralJOTY7SqgWH06NG6/PLL1djYaL0UM1++Bnh9nG7ChAnKzMwckq+PFStW6O2339b7778f8+tbcnJydPz4cbW3t8fsP1RfD/2dh74UFRVJ0oB6PQz4AKWmpmratGmqrq6OPtbb26vq6moVFxcbrszekSNHtH//fuXm5lovxUxBQYFycnJiXh+RSEQ7duw4718fn376qQ4fPjykXh/OOa1YsUKbNm3Stm3bVFBQEPP8tGnTlJKSEvN6aGho0IEDB4bU6+Fs56Evu3fvlqSB9Xqw/hTE1/Haa6+5YDDoqqqq3B//+Ed39913u9GjR7vW1lbrpZ1TDzzwgKupqXFNTU3ud7/7nSspKXGZmZnu0KFD1ktLqo6ODvfxxx+7jz/+2ElyTz/9tPv444/d3/72N+eccz/96U/d6NGj3ZYtW9yePXvc/PnzXUFBgfviiy+MV55YZzoPHR0d7sEHH3R1dXWuqanJvffee+673/2uu+yyy9yxY8esl54wy5cvd6FQyNXU1LiWlpbodvTo0eg+99xzjxs3bpzbtm2b27lzpysuLnbFxcWGq068s52HxsZG9+Mf/9jt3LnTNTU1uS1btrgJEya4WbNmGa881qAIkHPOPffcc27cuHEuNTXVTZ8+3dXX11sv6ZxbvHixy83Ndampqe7b3/62W7x4sWtsbLReVtK9//77TtJp25IlS5xzJz+K/dhjj7ns7GwXDAbdnDlzXENDg+2ik+BM5+Ho0aNu7ty5bsyYMS4lJcWNHz/eLVu2bMj9n7S+/vklufXr10f3+eKLL9y9997rvvWtb7lRo0a5hQsXupaWFrtFJ8HZzsOBAwfcrFmzXEZGhgsGg+7SSy91P/rRj1w4HLZd+Cn4dQwAABMD/j0gAMDQRIAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCY+H+FuPwJ5J7kjwAAAABJRU5ErkJggg==\n" 665 | }, 666 | "metadata": {} 667 | } 668 | ], 669 | "source": [ 670 | "plt.imshow(img.permute(1, 2, 0))\n", 671 | "plt.show()" 672 | ] 673 | }, 674 | { 675 | "cell_type": "markdown", 676 | "metadata": { 677 | "id": "yfIVxwixkQdY" 678 | }, 679 | "source": [ 680 | "To acquire predictions from different submodels one should transform input (with shape [1, H, W, C]) into batch (with shape [M, H, W, C]) that consists of M copies of original input (H - height of image, W - width of image, C - number of channels).\n", 681 | "\n", 682 | "As we can see Masksembles submodels produce similar predictions for training set samples." 
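,
"\n",
"(Note: this PyTorch version is channel-first, so the tiled batch actually has shape [M, C, H, W]; the [H, W, C] layout above is the TensorFlow convention.)"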
683 | ]
684 | },
685 | {
686 | "cell_type": "code",
687 | "execution_count": 18,
688 | "metadata": {
689 | "colab": {
690 | "base_uri": "https://localhost:8080/"
691 | },
692 | "id": "thkg1KJJjY85",
693 | "outputId": "f40287cb-fd9b-4301-c175-ef3488e37ee8"
694 | },
695 | "outputs": [
696 | {
697 | "output_type": "stream",
698 | "name": "stdout",
699 | "text": [
700 | "PREDICTION OF 1 MODEL: 3 CLASS\n",
701 | "PREDICTION OF 2 MODEL: 5 CLASS\n",
702 | "PREDICTION OF 3 MODEL: 5 CLASS\n",
703 | "PREDICTION OF 4 MODEL: 5 CLASS\n"
704 | ]
705 | }
706 | ],
707 | "source": [
708 | "# img has shape (1, 28, 28); tile it along the batch dimension, one copy per submodel\n",
709 | "inputs = torch.tile(img[None], (4, 1, 1, 1))  # (4, 1, 28, 28)\n",
710 | "\n",
711 | "# Send to the same device as the model\n",
712 | "device = next(model.parameters()).device\n",
713 | "inputs_t = inputs.to(device)\n",
714 | "\n",
715 | "# Inference\n",
716 | "model.eval()\n",
717 | "with torch.no_grad():\n",
718 | "    predictions = model(inputs_t)  # (4, num_classes)\n",
719 | "    predicted_classes = predictions.argmax(dim=1)\n",
720 | "\n",
721 | "# Print predictions\n",
722 | "for i, cls in enumerate(predicted_classes):\n",
723 | "    print(f\"PREDICTION OF {i+1} MODEL: {cls.item()} CLASS\")\n"
724 | ]
725 | },
726 | {
727 | "cell_type": "markdown",
728 | "metadata": {
729 | "id": "KH81gNyGlcnh"
730 | },
731 | "source": [
732 | "On out-of-distribution samples Masksembles should produce predictions with high variance; let's check this on a complex sample from MNIST."
733 | ]
734 | },
735 | {
736 | "cell_type": "code",
737 | "execution_count": 19,
738 | "metadata": {
739 | "id": "l9rbe1s9fBDf"
740 | },
741 | "outputs": [],
742 | "source": [
743 | "img = np.array(np.load(\"./complex_sample_mnist.npy\")[::-1, ::-1])\n",
744 | "img = torch.tensor(img).permute(2, 0, 1).float()"
745 | ]
746 | },
747 | {
748 | "cell_type": "code",
749 | "execution_count": 20,
750 | "metadata": {
751 | "colab": {
752 | "base_uri": "https://localhost:8080/",
753 | "height": 448
754 | },
755 | "id": "n6YghcdrfdkZ",
756 | "outputId": "d0cda917-f87b-40c9-a424-c5a2c214f644"
757 | },
758 | "outputs": [
759 | {
760 | "output_type": "execute_result",
761 | "data": {
762 | "text/plain": [
763 | ""
764 | ]
765 | },
766 | "metadata": {},
767 | "execution_count": 20
768 | },
769 | {
770 | "output_type": "display_data",
771 | "data": {
772 | "text/plain": [
773 | "
" 774 | ], 775 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaAAAAGdCAYAAABU0qcqAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAHNBJREFUeJzt3X9w1fW95/HXSSAHkOTEEPKrBBpQwQqktyhpFqVYMoS01wvKuP7qLDgurjS4RWp101XRtjtpcdY6elOde7cF3Sv+miswWksvBhPWNuCCUC7bmiVpLKEkodKbnBBMyI/P/sF67JEg/RzO4Z2E52PmO0PO+b7yffP1i698c04+CTjnnAAAuMCSrAcAAFycKCAAgAkKCABgggICAJiggAAAJiggAIAJCggAYIICAgCYGGU9wKcNDAzo6NGjSk1NVSAQsB4HAODJOafOzk7l5eUpKens9zlDroCOHj2q/Px86zEAAOepublZkyZNOuvzQ66AUlNTJUnX6msapdHG0wAAfPWpV+/ozcj/z88mYQVUVVWlxx9/XK2trSosLNTTTz+tuXPnnjP38bfdRmm0RgUoIAAYdv7/CqPnehklIW9CePnll7V27VqtW7dO7733ngoLC1VaWqpjx44l4nAAgGEoIQX0xBNPaOXKlbrzzjv1hS98Qc8++6zGjRunn/3sZ4k4HABgGIp7AZ06dUp79+5VSUnJJwdJSlJJSYnq6urO2L+np0fhcDhqAwCMfHEvoA8//FD9/f3Kzs6Oejw7O1utra1n7F9ZWalQKBTZeAccAFwczH8QtaKiQh0dHZGtubnZeiQAwAUQ93fBZWZmKjk5WW1tbVGPt7W1KScn54z9g8GggsFgvMcAAAxxcb8DSklJ0Zw5c1RdXR15bGBgQNXV1SouLo734QAAw1RCfg5o7dq1Wr58ua6++mrNnTtXTz75pLq6unTnnXcm4nAAgGEoIQV0yy236E9/+pMeeeQRtba26otf/KK2bdt2xhsTAAAXr4BzzlkP8ZfC4bBCoZAWaAkrIVwgo3LPfG3ur/HHm6d6Z04Wd3lnxr57iXcm73/8q3dGkgY6O2PKAfhEn+tVjbaqo6NDaWlpZ93P/F1wAICLEwUEADBBAQEATFBAAAATFBAAwAQFBAAwQQEBAExQQAAAExQQAMAEBQQAMEEBAQBMUEAAABMJWQ0bw8zYMTHFOmb0eWce+5ufe2f+W8PN3plAMl9bAUMd/0oBACYoIACACQoIAGCCAgIAmKCAAAAmKCAAgAkKCABgggICAJiggAAAJiggAIAJCggAYIICAgCYoIAAACZYDRvqn5AaU27OrN97Zz7sS/POXPJH74hcr/9K3QAuLO6AAAAmKCAAgAkKCABgggICAJiggAAAJiggAIAJCggAYIICAgCYoIAAACYoIACACQoIAGCCAgIAmGAxUsSsbyDZO5MUGPA/UMA/AmDo4w4IAGCCAgIAmKCAAAAmKCAAgAkKCABgggICAJiggAAAJiggAIAJCggAYIICAgCYoIAAACYoIACACRYjhZI6TsaU+01jvnfm6kv/4J1J6vOOSM7FEAJwIXEHBAAwQQEBAEzEvYAeffRRBQKBqG3GjBnxPgwAYJhLyGtAV111ld56661PDjKKl5oAANES0gyjRo1STk5OIj41AGCESMhrQIcOHVJeXp6mTp2qO+64Q4cPHz7rvj09PQqHw1EbAGDki3sBFRUVaePGjdq2bZueeeYZNTU16brrrlNnZ+eg+1dWVioUCkW2/Hz/t/YCAIafuBdQWVmZbr75Zs2ePVulpaV688031d7erldeeWXQ/SsqKtTR0RHZmpub4z0SAGAISvi7A9LT03XFFVeooaFh0OeDwaCCwWCixwAADDEJ/zmgEydOqLGxUbm5uYk+FABgGIl7Ad1///2qra3VBx98oF//+te68cYblZycrNtuuy3ehwIADGNx/xbckSNHdNttt+n48eOaOHGirr32Wu3atUsTJ06M96EAAMNY3AvopZdeivenRIIF+vpjy3X6Xz5XjGnxzhz/cq93Jnt7pndGkgY+OPuPDACIL9aCAwCYoIAAACYoIACACQoIAGCCAgIAmKCAAAAmKCAAgAkKCABgggICAJiggAAAJiggAIAJCggAYCLhv5AOw8DJj2KKpTb5f/2SltTtndmw4Gfemf+6baV3RpJSjxz1zri+vpiOBVzsuAMCAJiggAAAJiggAIAJCggAYIICAgCYoIAAACYoIACACQoIAGCCAgIAmKCAAAAmKCAAgAkKCABgggICAJhgNWyor7Utptzn/sl/FejH/7bUO/MvV27xzrR8PbYVqsc3f8E7k9zU6n+g/n7/TAxcjCudD3T3xBC6MH8njBzcAQEATFBAAAATFBAAwAQFBAAwQQEBAExQQAAAExQQAMAEBQQAMEEBAQBMUEAAABMUEADABAUEADDBYqSIWf/xP3tn/vzadO/MC/85yzvzP+f/o3dGkn599eXemdoPr/DO9PaneGcCAeed+eB/+Z9vSSp47d+8MwO/+V1Mx8LFizsgAIAJCggAYIICAgCYoIAAACYoIACACQoIAGCCAgIAmKCAAAAmKCAAgAkKCABgggICAJiggAAAJliMFLFz/otj5v5zg3fm77tu9s78x/+y1TsjSd/JaLwgmVgc6Tvhnfm7E3fFdKzuX6d6Z1J+E9OhcBHjDggAYIICAgCY8C6gnTt36oYbblBeXp4CgYC2bNkS9bxzTo888ohyc3M1duxYlZSU6NChQ/GaFwAwQngXUFdXlwoLC1VVVTXo8+vXr9dTTz2lZ599Vrt379Yll1yi0tJSdXd3n/ewAICRw/tNCGVlZSorKxv0OeecnnzyST300ENasmSJJOn5559Xdna2tmzZoltvvfX8pgUAjBhxfQ2oqalJra2tKikpiTwWCoVUVFSkurq6QTM9PT0Kh8NRGwBg5ItrAbW2tkqSsrOzox7Pzs6OPPdplZWVCoVCkS0/Pz+eIwEAhijzd8FVVFSoo6MjsjU3N1uPBAC4AOJaQDk5OZKktra2qMfb2toiz31aMBhUWlpa1AYAGPniWkAFBQXKyclRdXV15LFwOKzdu3eruLg4nocCAAxz3u+CO3HihBoaPllOpampSfv371dGRoYmT56sNWvW6Ac/+IEuv/xyFRQU6OGHH1ZeXp6WLl0az7kBAMOcdwHt2bNH119/feTjtWvXSpKWL1+ujRs36oEHHlBXV5fuvvtutbe369prr9W2bds0ZsyY+E0NABj2As7FsKJkAoXDYYVCIS3QEo0KjLYeB0NAcuYE70zD05NiOtb//cpzMeV8/Z9TH3ln/v0/fNs7k7+90zsjSYH3P/DODHTGdiyMPH2uVzXaqo6Ojs98Xd/8XXAAgIsTBQQAMEEBAQBMUEAAABMUEADABAUEADBBAQEATFBAAAATFBAAwAQFBAAwQQEBAExQQAAAExQQAMCE969jAM7HqCn53pkjN/lnvjn7Te+MJJ0cOOWd2XfK/5/R3f/o
v7L15//psHemv6XVOyNJA319MeUAH9wBAQBMUEAAABMUEADABAUEADBBAQEATFBAAAATFBAAwAQFBAAwQQEBAExQQAAAExQQAMAEBQQAMMFipIhZIBj0zhy6Z5J35o3bH/fOTEwKeGck6d4ji7wzjd+70juT/y/vemf6WCAUIwx3QAAAExQQAMAEBQQAMEEBAQBMUEAAABMUEADABAUEADBBAQEATFBAAAATFBAAwAQFBAAwQQEBAEywGCmkpOSYYoEZU70zt31tp3emYNQY78zf7P4P3hlJCv4yzTuTtWOfd2aAhUUB7oAAADYoIACACQoIAGCCAgIAmKCAAAAmKCAAgAkKCABgggICAJiggAAAJiggAIAJCggAYIICAgCYYDFSKDkjPaZcy7xLvTO3hv63d+bnJ7O8M6EXx3tnJCn01vv+oUvGeUeSp03xzgyMHe2diVXykT95Z/pa2xIwCUYy7oAAACYoIACACe8C2rlzp2644Qbl5eUpEAhoy5YtUc+vWLFCgUAgalu8eHG85gUAjBDeBdTV1aXCwkJVVVWddZ/FixerpaUlsr344ovnNSQAYOTxfhNCWVmZysrKPnOfYDConJycmIcCAIx8CXkNqKamRllZWZo+fbpWrVql48ePn3Xfnp4ehcPhqA0AMPLFvYAWL16s559/XtXV1frRj36k2tpalZWVqb+/f9D9KysrFQqFIlt+fn68RwIADEFx/zmgW2+9NfLnWbNmafbs2Zo2bZpqamq0cOHCM/avqKjQ2rVrIx+Hw2FKCAAuAgl/G/bUqVOVmZmphoaGQZ8PBoNKS0uL2gAAI1/CC+jIkSM6fvy4cnNzE30oAMAw4v0tuBMnTkTdzTQ1NWn//v3KyMhQRkaGHnvsMS1btkw5OTlqbGzUAw88oMsuu0ylpaVxHRwAMLx5F9CePXt0/fXXRz7++PWb5cuX65lnntGBAwf03HPPqb29XXl5eVq0aJG+//3vKxgMxm9qAMCw511ACxYskHPurM//8pe/PK+BYCA9ttfdOq4a/J2Nn2VAAe/M73v8FyM9Piu27y4fLZvmncnM8v/RgS9OPOKdyUjp8s7Equ5YgXem/Zf/zjsz6efHvDP99YO/nozhh7XgAAAmKCAAgAkKCABgggICAJiggAAAJiggAIAJCggAYIICAgCYoIAAACYoIACACQoIAGCCAgIAmKCAAAAm4v4ruTH8BD7qiSk39o/J3pn+GFbDvjH1gHdm/M3d3hlJWp72B+9MMDDaO3PgVGzz+bpqdEpswaz3vCPfy5nlnXl13Fe8M59/ecA703/o994ZJB53QAAAExQQAMAEBQQAMEEBAQBMUEAAABMUEADABAUEADBBAQEATFBAAAATFBAAwAQFBAAwQQEBAEywGCnUf+zDmHJTXgl6Z27TWu9M0jXt3pm0sbEt9vlO6DLvzL8ey/POBLZd6p0JtjvvzL/NiO1rzK//7S7vzH/P9V/AtOnrE7wz7zdf5Z25lMVIhyTugAAAJiggAIAJCggAYIICAgCYoIAAACYoIACACQoIAGCCAgIAmKCAAAAmKCAAgAkKCABgggICAJhgMVLI9Z6KKdff0OSdmfL3x70zPVdf7p3pGzfeOyNJv83O9s6k/7HPOzNm52+8MwMnT3pnLk0PeWck6bXPfck7E8tipONH+V97/f5r4GKI4g4IAGCCAgIAmKCAAAAmKCAAgAkKCABgggICAJiggAAAJiggAIAJCggAYIICAgCYoIAAACYoIACACRYjxQXVHw57Z0bt2Ouf8U6cNibGnK+BGDKBUf5/q0AoLYYjSeq9MF+bpiT5L+SqQPzngA3ugAAAJiggAIAJrwKqrKzUNddco9TUVGVlZWnp0qWqr6+P2qe7u1vl5eWaMGGCxo8fr2XLlqmtrS2uQwMAhj+vAqqtrVV5ebl27dql7du3q7e3V4sWLVJXV1dkn/vuu0+vv/66Xn31VdXW1uro0aO66aab4j44AGB483pVc9u2bVEfb9y4UVlZWdq7d6/mz5+vjo4O/fSnP9WmTZv01a9+VZK0YcMGXXnlldq1a5e+/OUvx29yAMCwdl6vAXV0dEiSMjIyJEl79+5Vb2+vSkpKIvvMmDFDkydPVl1d3aCfo6enR+FwOGoDAIx8MRfQwMCA1qxZo3nz5mnmzJmSpNbWVqWkpCg9PT1q3+zsbLW2tg76eSorKxUKhSJbfn5+rCMBAIaRmAuovLxcBw8e1EsvvXReA1RUVKijoyOyNTc3n9fnAwAMDzH9vN7q1av1xhtvaOfOnZo0aVLk8ZycHJ06dUrt7e1Rd0FtbW3KyckZ9HMFg0EFg8FYxgAADGNed0DOOa1evVqbN2/Wjh07VFBQEPX8nDlzNHr0aFVXV0ceq6+v1+HDh1VcXByfiQEAI4LXHVB5ebk2bdqkrVu3KjU1NfK6TigU0tixYxUKhXTXXXdp7dq1ysjIUFpamu69914VFxfzDjgAQBSvAnrmmWckSQsWLIh6fMOGDVqxYoUk6cc//rGSkpK0bNky9fT0qLS0VD/5yU/iMiwAYOTwKiDn3Dn3GTNmjKqqqlRVVRXzUMCIF/BfUTN5YqZ35v0f+Gck6bl5/+Cdeben1zuzdf8XvTOXH/zIO4OhibXgAAAmKCAAgAkKCABgggICAJiggAAAJiggAIAJCggAYIICAgCYoIAAACYoIACACQoIAGCCAgIAmKCAAAAmYvqNqAAMJPl/vegG/FfdlqSdJ2Z4ZzYc8P+lkwUve0cU+NV+/xCGJO6AAAAmKCAAgAkKCABgggICAJiggAAAJiggAIAJCggAYIICAgCYoIAAACYoIACACQoIAGCCAgIAmGAxUsCCc96RvqMt3pnp/+nP3hlJ+tXoid6ZK3p/550ZONXrncHIwR0QAMAEBQQAMEEBAQBMUEAAABMUEADABAUEADBBAQEATFBAAAATFBAAwAQFBAAwQQEBAExQQAAAEyxGCgwXMSxgOtDdHduxYs0BHrgDAgCYoIAAACYoIACACQoIAGCCAgIAmKCAAAAmKCAAgAkKCABgggICAJiggAAAJiggAIAJCggAYIICAgCYoIAAACYoIACACa8Cqqys1DXXXKPU1FRlZWVp6dKlqq+vj9pnwYIFCgQCUds999wT16EBAMOfVwHV1taqvLxcu3bt0vbt29Xb26tFixapq6srar+VK1eqpaUlsq1fvz6uQwMAhj+v34i6bdu2qI83btyorKws7d27V/Pnz488Pm7cOOXk5MRnQgDAiHRerwF1dHRIkjIyMqIef+GFF5SZmamZM2eqoqJCJ0+ePOvn6OnpUTgcjtoAACOf1x3QXxoYGNCaNWs0b948zZw5M/L47bffrilTpigvL08HDhzQgw8+qPr6er322muDfp7Kyko99thjsY4BABimAs45F0tw1apV+sUvfqF33nlHkyZNOut+O3bs0MKFC9XQ0KBp06ad8XxPT496enoiH4fDYeXn52uBlmhUYHQsowEADPW5XtVoqzo6OpSWlnbW/WK6A1q9erXeeOMN7dy58zPLR5KKioo
k6awFFAwGFQwGYxkDADCMeRWQc0733nuvNm/erJqaGhUUFJwzs3//fklSbm5uTAMCAEYmrwIqLy/Xpk2btHXrVqWmpqq1tVWSFAqFNHbsWDU2NmrTpk362te+pgkTJujAgQO67777NH/+fM2ePTshfwEAwPDk9RpQIBAY9PENGzZoxYoVam5u1je+8Q0dPHhQXV1dys/P14033qiHHnroM78P+JfC4bBCoRCvAQHAMJWQ14DO1VX5+fmqra31+ZQAgIsUa8EBAExQQAAAExQQAMAEBQQAMEEBAQBMUEAAABMUEADABAUEADBBAQEATFBAAAATFBAAwAQFBAAwQQEBAExQQAAAExQQAMAEBQQAMEEBAQBMUEAAABMUEADABAUEADBBAQEATFBAAAATFBAAwAQFBAAwQQEBAEyMsh7g05xzkqQ+9UrOeBgAgLc+9Ur65P/nZzPkCqizs1OS9I7eNJ4EAHA+Ojs7FQqFzvp8wJ2roi6wgYEBHT16VKmpqQoEAlHPhcNh5efnq7m5WWlpaUYT2uM8nMZ5OI3zcBrn4bShcB6cc+rs7FReXp6Sks7+Ss+QuwNKSkrSpEmTPnOftLS0i/oC+xjn4TTOw2mch9M4D6dZn4fPuvP5GG9CAACYoIAAACaGVQEFg0GtW7dOwWDQehRTnIfTOA+ncR5O4zycNpzOw5B7EwIA4OIwrO6AAAAjBwUEADBBAQEATFBAAAATw6aAqqqq9PnPf15jxoxRUVGR3n33XeuRLrhHH31UgUAgapsxY4b1WAm3c+dO3XDDDcrLy1MgENCWLVuinnfO6ZFHHlFubq7Gjh2rkpISHTp0yGbYBDrXeVixYsUZ18fixYtthk2QyspKXXPNNUpNTVVWVpaWLl2q+vr6qH26u7tVXl6uCRMmaPz48Vq2bJna2tqMJk6Mv+Y8LFiw4Izr4Z577jGaeHDDooBefvllrV27VuvWrdN7772nwsJClZaW6tixY9ajXXBXXXWVWlpaIts777xjPVLCdXV1qbCwUFVVVYM+v379ej311FN69tlntXv3bl1yySUqLS1Vd3f3BZ40sc51HiRp8eLFUdfHiy++eAEnTLza2lqVl5dr165d2r59u3p7e7Vo0SJ1dXVF9rnvvvv0+uuv69VXX1Vtba2OHj2qm266yXDq+PtrzoMkrVy5Mup6WL9+vdHEZ+GGgblz57ry8vLIx/39/S4vL89VVlYaTnXhrVu3zhUWFlqPYUqS27x5c+TjgYEBl5OT4x5//PHIY+3t7S4YDLoXX3zRYMIL49PnwTnnli9f7pYsWWIyj5Vjx445Sa62ttY5d/q//ejRo92rr74a2ed3v/udk+Tq6uqsxky4T58H55z7yle+4r71rW/ZDfVXGPJ3QKdOndLevXtVUlISeSwpKUklJSWqq6sznMzGoUOHlJeXp6lTp+qOO+7Q4cOHrUcy1dTUpNbW1qjrIxQKqaio6KK8PmpqapSVlaXp06dr1apVOn78uPVICdXR0SFJysjIkCTt3btXvb29UdfDjBkzNHny5BF9PXz6PHzshRdeUGZmpmbOnKmKigqdPHnSYryzGnKLkX7ahx9+qP7+fmVnZ0c9np2drffff99oKhtFRUXauHGjpk+frpaWFj322GO67rrrdPDgQaWmplqPZ6K1tVWSBr0+Pn7uYrF48WLddNNNKigoUGNjo7773e+qrKxMdXV1Sk5Oth4v7gYGBrRmzRrNmzdPM2fOlHT6ekhJSVF6enrUviP5ehjsPEjS7bffrilTpigvL08HDhzQgw8+qPr6er322muG00Yb8gWET5SVlUX+PHv2bBUVFWnKlCl65ZVXdNdddxlOhqHg1ltvjfx51qxZmj17tqZNm6aamhotXLjQcLLEKC8v18GDBy+K10E/y9nOw9133x3586xZs5Sbm6uFCxeqsbFR06ZNu9BjDmrIfwsuMzNTycnJZ7yLpa2tTTk5OUZTDQ3p6em64oor1NDQYD2KmY+vAa6PM02dOlWZmZkj8vpYvXq13njjDb399ttRv74lJydHp06dUnt7e9T+I/V6ONt5GExRUZEkDanrYcgXUEpKiubMmaPq6urIYwMDA6qurlZxcbHhZPZOnDihxsZG5ebmWo9ipqCgQDk5OVHXRzgc1u7duy/66+PIkSM6fvz4iLo+nHNavXq1Nm/erB07dqigoCDq+Tlz5mj06NFR10N9fb0OHz48oq6Hc52Hwezfv1+Shtb1YP0uiL/GSy+95ILBoNu4caP77W9/6+6++26Xnp7uWltbrUe7oL797W+7mpoa19TU5H71q1+5kpISl5mZ6Y4dO2Y9WkJ1dna6ffv2uX379jlJ7oknnnD79u1zf/jDH5xzzv3whz906enpbuvWre7AgQNuyZIlrqCgwH300UfGk8fXZ52Hzs5Od//997u6ujrX1NTk3nrrLfelL33JXX755a67u9t69LhZtWqVC4VCrqamxrW0tES2kydPRva555573OTJk92OHTvcnj17XHFxsSsuLjacOv7OdR4aGhrc9773Pbdnzx7X1NTktm7d6qZOnermz59vPHm0YVFAzjn39NNPu8mTJ7uUlBQ3d+5ct2vXLuuRLrhbbrnF5ebmupSUFPe5z33O3XLLLa6hocF6rIR7++23naQztuXLlzvnTr8V++GHH3bZ2dkuGAy6hQsXuvr6etuhE+CzzsPJkyfdokWL3MSJE93o0aPdlClT3MqVK0fcF2mD/f0luQ0bNkT2+eijj9w3v/lNd+mll7px48a5G2+80bW0tNgNnQDnOg+HDx928+fPdxkZGS4YDLrLLrvMfec733EdHR22g38Kv44BAGBiyL8GBAAYmSggAIAJCggAYIICAgCYoIAAACYoIACACQoIAGCCAgIAmKCAAAAmKCAAgAkKCABgggICAJj4f016Bym3u3cJAAAAAElFTkSuQmCC\n" 776 | }, 777 | "metadata": {} 778 | } 779 | ], 780 | "source": [ 781 | "plt.imshow(img.permute(1, 2, 0))" 782 | ] 783 | }, 784 | { 785 | "cell_type": "code", 786 | "execution_count": 21, 787 | "metadata": { 788 | "colab": { 789 | "base_uri": "https://localhost:8080/" 790 | }, 791 | "id": "65Y_-0h9U96h", 792 | "outputId": "559893c1-38d5-48fa-bde9-db0f26c421bb" 793 | }, 794 | "outputs": [ 795 | { 796 | "output_type": "stream", 797 | "name": "stdout", 798 | "text": [ 799 | "PREDICTION OF 1 MODEL: 3 CLASS\n", 800 | "PREDICTION OF 2 MODEL: 5 CLASS\n", 801 | "PREDICTION OF 3 MODEL: 5 CLASS\n", 
802 | "PREDICTION OF 4 MODEL: 5 CLASS\n"
803 | ]
804 | }
805 | ],
806 | "source": [
807 | "# img is the complex sample loaded above, with shape (1, 28, 28)\n",
808 | "# Duplicate it 4 times along the batch dimension\n",
809 | "inputs = torch.tile(img[None], (4, 1, 1, 1))  # (4, 1, 28, 28)\n",
810 | "\n",
811 | "# Move to same device as model\n",
812 | "device = next(model.parameters()).device\n",
813 | "inputs_t = inputs.to(device)\n",
814 | "\n",
815 | "# Run the model\n",
816 | "model.eval()\n",
817 | "with torch.no_grad():\n",
818 | "    predictions = model(inputs_t)  # shape: (4, num_classes)\n",
819 | "    predicted_classes = predictions.argmax(dim=1)  # argmax over class dimension\n",
820 | "\n",
821 | "# Print the prediction of each submodel\n",
822 | "for i, cls in enumerate(predicted_classes):\n",
823 | "    print(f\"PREDICTION OF {i+1} MODEL: {cls.item()} CLASS\")\n"
824 | ]
825 | },
826 | {
827 | "cell_type": "markdown",
828 | "metadata": {
829 | "id": "3e0fe845"
830 | },
831 | "source": [
832 | "\n",
833 | "## 8. Next Steps \n",
834 | "\n",
835 | "- Try different `n` and `scale` to explore accuracy vs. diversity. \n",
836 | "- Aggregate submodel logits (mean/median) for stronger predictions. \n",
837 | "- Visualize per‑submodel confidence for **uncertainty estimation**. \n",
838 | "- Swap MNIST for CIFAR‑10 or your own dataset to test generalization.\n"
839 | ]
840 | }
841 | ],
842 | "metadata": {
843 | "accelerator": "GPU",
844 | "colab": {
845 | "gpuType": "T4",
846 | "provenance": [],
847 | "toc_visible": true
848 | },
849 | "kernelspec": {
850 | "display_name": "Python 3",
851 | "name": "python3"
852 | },
853 | "language_info": {
854 | "codemirror_mode": {
855 | "name": "ipython",
856 | "version": 2
857 | },
858 | "file_extension": ".py",
859 | "mimetype": "text/x-python",
860 | "name": "python",
861 | "nbconvert_exporter": "python",
862 | "pygments_lexer": "ipython2",
863 | "version": "2.7.6"
864 | }
865 | },
866 | "nbformat": 4,
867 | "nbformat_minor": 0
868 | }
--------------------------------------------------------------------------------