├── .github └── workflows │ └── build.yml ├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── README.org ├── batchrenorm ├── __init__.py ├── __version__.py └── batchrenorm.py ├── poetry.lock ├── pyproject.toml └── tests ├── __init__.py ├── conftest.py └── test_batchrenorm.py /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | name: build 2 | 3 | on: [push, pull_request] 4 | jobs: 5 | linting: 6 | runs-on: ubuntu-latest 7 | steps: 8 | - uses: actions/checkout@v1 9 | - name: Set up Python 3.7 10 | uses: actions/setup-python@v1 11 | with: 12 | python-version: 3.7 13 | - name: Install dependencies 14 | run: | 15 | python -m pip install --upgrade pip 16 | pip install poetry==1.0.0 17 | poetry install 18 | - name: Linting 19 | run: poetry run pre-commit run --all-files 20 | 21 | testing: 22 | needs: linting 23 | runs-on: ${{ matrix.os }} 24 | strategy: 25 | matrix: 26 | os: [ubuntu-latest, macos-latest] 27 | steps: 28 | - uses: actions/checkout@v1 29 | - name: Set up Python 3.7 30 | uses: actions/setup-python@v1 31 | with: 32 | python-version: 3.7 33 | - name: Install dependencies 34 | run: | 35 | python -m pip install --upgrade pip 36 | pip install poetry==1.0.0 37 | poetry install 38 | - name: Testing 39 | run: poetry run pytest --cov=./batchrenorm --cov-report=xml 40 | - name: Upload coverage report 41 | uses: codecov/codecov-action@v1 42 | with: 43 | token: ${{ secrets.CODECOV_TOKEN }} 44 | file: ./coverage.xml 45 | flags: unittests 46 | fail_ci_if_error: true 47 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | *.egg-info 3 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: 
https://github.com/pre-commit/pre-commit-hooks 3 | rev: v2.4.0 4 | hooks: 5 | - id: trailing-whitespace 6 | - id: end-of-file-fixer 7 | - id: check-added-large-files 8 | 9 | - repo: https://github.com/psf/black 10 | rev: 19.10b0 11 | hooks: 12 | - id: black 13 | 14 | - repo: https://gitlab.com/pycqa/flake8 15 | rev: 3.7.9 16 | hooks: 17 | - id: flake8 18 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 2 | 3 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 4 | 5 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
6 | -------------------------------------------------------------------------------- /README.org: -------------------------------------------------------------------------------- 1 | #+TITLE: Batch Renormalization 2 | 3 | [[https://github.com/ludvb/batchrenorm/actions?query=workflow%3Abuild+branch%3Amaster][https://github.com/ludvb/batchrenorm/workflows/build/badge.svg?branch=master]] 4 | [[https://codecov.io/gh/ludvb/batchrenorm/branch/master][https://codecov.io/gh/ludvb/batchrenorm/branch/master/graph/badge.svg]] 5 | 6 | PyTorch implementation of Batch Renormalization, introduced in https://arxiv.org/abs/1702.03275. 7 | 8 | * Installation 9 | 10 | Requires Python 3.7. 11 | To install the latest version of this package, run 12 | 13 | #+BEGIN_SRC 14 | pip install git+https://github.com/ludvb/batchrenorm@master 15 | #+END_SRC 16 | 17 | * Usage 18 | 19 | #+BEGIN_SRC python 20 | import torch 21 | from batchrenorm import BatchRenorm2d 22 | 23 | # Create batch renormalization layer 24 | br = BatchRenorm2d(3) 25 | 26 | # Create some example data with dimensions N x C x H x W 27 | x = torch.randn(1, 3, 10, 10) 28 | 29 | # Batch renormalize the data 30 | x = br(x) 31 | #+END_SRC 32 | -------------------------------------------------------------------------------- /batchrenorm/__init__.py: -------------------------------------------------------------------------------- 1 | from .__version__ import __version__ 2 | from .batchrenorm import BatchRenorm1d, BatchRenorm2d, BatchRenorm3d 3 | 4 | 5 | __all__ = ["__version__", "BatchRenorm1d", "BatchRenorm2d", "BatchRenorm3d"] 6 | -------------------------------------------------------------------------------- /batchrenorm/__version__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.1.0" 2 | -------------------------------------------------------------------------------- /batchrenorm/batchrenorm.py:
-------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | __all__ = ["BatchRenorm1d", "BatchRenorm2d", "BatchRenorm3d"] 5 | 6 | 7 | class BatchRenorm(torch.jit.ScriptModule): 8 | def __init__( 9 | self, 10 | num_features: int, 11 | eps: float = 1e-3, 12 | momentum: float = 0.01, 13 | affine: bool = True, 14 | ): 15 | super().__init__() 16 | self.register_buffer( 17 | "running_mean", torch.zeros(num_features, dtype=torch.float) 18 | ) 19 | self.register_buffer( 20 | "running_std", torch.ones(num_features, dtype=torch.float) 21 | ) 22 | self.register_buffer( 23 | "num_batches_tracked", torch.tensor(0, dtype=torch.long) 24 | ) 25 | self.weight = torch.nn.Parameter( 26 | torch.ones(num_features, dtype=torch.float) 27 | ) 28 | self.bias = torch.nn.Parameter( 29 | torch.zeros(num_features, dtype=torch.float) 30 | ) 31 | self.affine = affine 32 | self.eps = eps 33 | self.step = 0 34 | self.momentum = momentum 35 | 36 | def _check_input_dim(self, x: torch.Tensor) -> None: 37 | raise NotImplementedError() # pragma: no cover 38 | 39 | @property 40 | def rmax(self) -> torch.Tensor: 41 | return (2 / 35000 * self.num_batches_tracked + 25 / 35).clamp_( 42 | 1.0, 3.0 43 | ) 44 | 45 | @property 46 | def dmax(self) -> torch.Tensor: 47 | return (5 / 20000 * self.num_batches_tracked - 25 / 20).clamp_( 48 | 0.0, 5.0 49 | ) 50 | 51 | def forward(self, x: torch.Tensor, mask = None) -> torch.Tensor: 52 | ''' 53 | Mask is a boolean tensor used for indexing, where True values are padded 54 | i.e for 3D input, mask should be of shape (batch_size, seq_len) 55 | mask is used to prevent padded values from affecting the batch statistics 56 | ''' 57 | self._check_input_dim(x) 58 | if x.dim() > 2: 59 | x = x.transpose(1, -1) 60 | if self.training: 61 | dims = [i for i in range(x.dim() - 1)] 62 | if mask is not None: 63 | z = x[~mask] 64 | batch_mean = z.mean(0) 65 | batch_std = z.std(0, unbiased=False) + self.eps 66 | else: 67 | batch_mean = 
x.mean(dims) 68 | batch_std = x.std(dims, unbiased=False) + self.eps 69 | 70 | r = ( 71 | batch_std.detach() / self.running_std.view_as(batch_std) 72 | ).clamp_(1 / self.rmax, self.rmax) 73 | d = ( 74 | (batch_mean.detach() - self.running_mean.view_as(batch_mean)) 75 | / self.running_std.view_as(batch_std) 76 | ).clamp_(-self.dmax, self.dmax) 77 | x = (x - batch_mean) / batch_std * r + d 78 | self.running_mean += self.momentum * ( 79 | batch_mean.detach() - self.running_mean 80 | ) 81 | self.running_std += self.momentum * ( 82 | batch_std.detach() - self.running_std 83 | ) 84 | self.num_batches_tracked += 1 85 | else: 86 | x = (x - self.running_mean) / self.running_std 87 | if self.affine: 88 | x = self.weight * x + self.bias 89 | if x.dim() > 2: 90 | x = x.transpose(1, -1) 91 | return x 92 | 93 | 94 | class BatchRenorm1d(BatchRenorm): 95 | def _check_input_dim(self, x: torch.Tensor) -> None: 96 | if x.dim() not in [2, 3]: 97 | raise ValueError("expected 2D or 3D input (got {x.dim()}D input)") 98 | 99 | 100 | class BatchRenorm2d(BatchRenorm): 101 | def _check_input_dim(self, x: torch.Tensor) -> None: 102 | if x.dim() != 4: 103 | raise ValueError("expected 4D input (got {x.dim()}D input)") 104 | 105 | 106 | class BatchRenorm3d(BatchRenorm): 107 | def _check_input_dim(self, x: torch.Tensor) -> None: 108 | if x.dim() != 5: 109 | raise ValueError("expected 5D input (got {x.dim()}D input)") 110 | -------------------------------------------------------------------------------- /poetry.lock: -------------------------------------------------------------------------------- 1 | [[package]] 2 | category = "dev" 3 | description = "A few extensions to pyyaml." 4 | name = "aspy.yaml" 5 | optional = false 6 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" 7 | version = "1.3.0" 8 | 9 | [package.dependencies] 10 | pyyaml = "*" 11 | 12 | [[package]] 13 | category = "dev" 14 | description = "Atomic file writes." 
15 | marker = "sys_platform == \"win32\"" 16 | name = "atomicwrites" 17 | optional = false 18 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" 19 | version = "1.3.0" 20 | 21 | [[package]] 22 | category = "dev" 23 | description = "Classes Without Boilerplate" 24 | name = "attrs" 25 | optional = false 26 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" 27 | version = "19.3.0" 28 | 29 | [package.extras] 30 | azure-pipelines = ["coverage", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "zope.interface", "pytest-azurepipelines"] 31 | dev = ["coverage", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "zope.interface", "sphinx", "pre-commit"] 32 | docs = ["sphinx", "zope.interface"] 33 | tests = ["coverage", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "zope.interface"] 34 | 35 | [[package]] 36 | category = "dev" 37 | description = "Validate configuration and produce human readable error messages." 38 | name = "cfgv" 39 | optional = false 40 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" 41 | version = "2.0.1" 42 | 43 | [package.dependencies] 44 | six = "*" 45 | 46 | [[package]] 47 | category = "dev" 48 | description = "Cross-platform colored terminal text." 
49 | marker = "sys_platform == \"win32\"" 50 | name = "colorama" 51 | optional = false 52 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" 53 | version = "0.4.3" 54 | 55 | [[package]] 56 | category = "dev" 57 | description = "Code coverage measurement for Python" 58 | name = "coverage" 59 | optional = false 60 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, <4" 61 | version = "5.0.1" 62 | 63 | [package.extras] 64 | toml = ["toml"] 65 | 66 | [[package]] 67 | category = "main" 68 | description = "Clean single-source support for Python 3 and 2" 69 | name = "future" 70 | optional = false 71 | python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*" 72 | version = "0.18.2" 73 | 74 | [[package]] 75 | category = "dev" 76 | description = "File identification library for Python" 77 | name = "identify" 78 | optional = false 79 | python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,>=2.7" 80 | version = "1.4.9" 81 | 82 | [package.extras] 83 | license = ["editdistance"] 84 | 85 | [[package]] 86 | category = "dev" 87 | description = "Read metadata from Python packages" 88 | marker = "python_version < \"3.8\"" 89 | name = "importlib-metadata" 90 | optional = false 91 | python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7" 92 | version = "1.3.0" 93 | 94 | [package.dependencies] 95 | zipp = ">=0.5" 96 | 97 | [package.extras] 98 | docs = ["sphinx", "rst.linker"] 99 | testing = ["packaging", "importlib-resources"] 100 | 101 | [[package]] 102 | category = "dev" 103 | description = "More routines for operating on iterables, beyond itertools" 104 | name = "more-itertools" 105 | optional = false 106 | python-versions = ">=3.5" 107 | version = "8.0.2" 108 | 109 | [[package]] 110 | category = "dev" 111 | description = "Node.js virtual environment builder" 112 | name = "nodeenv" 113 | optional = false 114 | python-versions = "*" 115 | version = "1.3.3" 116 | 117 | [[package]] 118 | category = "main" 119 | description = "NumPy is the 
fundamental package for array computing with Python." 120 | name = "numpy" 121 | optional = false 122 | python-versions = ">=3.5" 123 | version = "1.18.0" 124 | 125 | [[package]] 126 | category = "dev" 127 | description = "Core utilities for Python packages" 128 | name = "packaging" 129 | optional = false 130 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" 131 | version = "19.2" 132 | 133 | [package.dependencies] 134 | pyparsing = ">=2.0.2" 135 | six = "*" 136 | 137 | [[package]] 138 | category = "dev" 139 | description = "plugin and hook calling mechanisms for python" 140 | name = "pluggy" 141 | optional = false 142 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" 143 | version = "0.13.1" 144 | 145 | [package.dependencies] 146 | [package.dependencies.importlib-metadata] 147 | python = "<3.8" 148 | version = ">=0.12" 149 | 150 | [package.extras] 151 | dev = ["pre-commit", "tox"] 152 | 153 | [[package]] 154 | category = "dev" 155 | description = "A framework for managing and maintaining multi-language pre-commit hooks." 
156 | name = "pre-commit" 157 | optional = false 158 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" 159 | version = "1.20.0" 160 | 161 | [package.dependencies] 162 | "aspy.yaml" = "*" 163 | cfgv = ">=2.0.0" 164 | identify = ">=1.0.0" 165 | nodeenv = ">=0.11.1" 166 | pyyaml = "*" 167 | six = "*" 168 | toml = "*" 169 | virtualenv = ">=15.2" 170 | 171 | [package.dependencies.importlib-metadata] 172 | python = "<3.8" 173 | version = "*" 174 | 175 | [[package]] 176 | category = "dev" 177 | description = "library with cross-python path, ini-parsing, io, code, log facilities" 178 | name = "py" 179 | optional = false 180 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" 181 | version = "1.8.0" 182 | 183 | [[package]] 184 | category = "dev" 185 | description = "Python parsing module" 186 | name = "pyparsing" 187 | optional = false 188 | python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*" 189 | version = "2.4.6" 190 | 191 | [[package]] 192 | category = "dev" 193 | description = "pytest: simple powerful testing with Python" 194 | name = "pytest" 195 | optional = false 196 | python-versions = ">=3.5" 197 | version = "5.3.2" 198 | 199 | [package.dependencies] 200 | atomicwrites = ">=1.0" 201 | attrs = ">=17.4.0" 202 | colorama = "*" 203 | more-itertools = ">=4.0.0" 204 | packaging = "*" 205 | pluggy = ">=0.12,<1.0" 206 | py = ">=1.5.0" 207 | wcwidth = "*" 208 | 209 | [package.dependencies.importlib-metadata] 210 | python = "<3.8" 211 | version = ">=0.12" 212 | 213 | [package.extras] 214 | testing = ["argcomplete", "hypothesis (>=3.56)", "mock", "nose", "requests", "xmlschema"] 215 | 216 | [[package]] 217 | category = "dev" 218 | description = "Pytest plugin for measuring coverage." 
219 | name = "pytest-cov" 220 | optional = false 221 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" 222 | version = "2.8.1" 223 | 224 | [package.dependencies] 225 | coverage = ">=4.4" 226 | pytest = ">=3.6" 227 | 228 | [package.extras] 229 | testing = ["fields", "hunter", "process-tests (2.0.2)", "six", "virtualenv"] 230 | 231 | [[package]] 232 | category = "dev" 233 | description = "YAML parser and emitter for Python" 234 | name = "pyyaml" 235 | optional = false 236 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" 237 | version = "5.2" 238 | 239 | [[package]] 240 | category = "dev" 241 | description = "Python 2 and 3 compatibility utilities" 242 | name = "six" 243 | optional = false 244 | python-versions = ">=2.6, !=3.0.*, !=3.1.*" 245 | version = "1.13.0" 246 | 247 | [[package]] 248 | category = "dev" 249 | description = "Python Library for Tom's Obvious, Minimal Language" 250 | name = "toml" 251 | optional = false 252 | python-versions = "*" 253 | version = "0.10.0" 254 | 255 | [[package]] 256 | category = "main" 257 | description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" 258 | name = "torch" 259 | optional = false 260 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" 261 | version = "1.3.1" 262 | 263 | [package.dependencies] 264 | future = "*" 265 | numpy = "*" 266 | 267 | [[package]] 268 | category = "dev" 269 | description = "Virtual Python Environment builder" 270 | name = "virtualenv" 271 | optional = false 272 | python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,>=2.7" 273 | version = "16.7.9" 274 | 275 | [package.extras] 276 | docs = ["sphinx (>=1.8.0,<2)", "towncrier (>=18.5.0)", "sphinx-rtd-theme (>=0.4.2,<1)"] 277 | testing = ["pytest (>=4.0.0,<5)", "coverage (>=4.5.0,<5)", "pytest-timeout (>=1.3.0,<2)", "six (>=1.10.0,<2)", "pytest-xdist", "pytest-localserver", "pypiserver", "mock", "xonsh"] 278 | 279 | [[package]] 280 | category = "dev" 281 | description = 
"Measures number of Terminal column cells of wide-character codes" 282 | name = "wcwidth" 283 | optional = false 284 | python-versions = "*" 285 | version = "0.1.7" 286 | 287 | [[package]] 288 | category = "dev" 289 | description = "Backport of pathlib-compatible object wrapper for zip files" 290 | marker = "python_version < \"3.8\"" 291 | name = "zipp" 292 | optional = false 293 | python-versions = ">=2.7" 294 | version = "0.6.0" 295 | 296 | [package.dependencies] 297 | more-itertools = "*" 298 | 299 | [package.extras] 300 | docs = ["sphinx", "jaraco.packaging (>=3.2)", "rst.linker (>=1.9)"] 301 | testing = ["pathlib2", "contextlib2", "unittest2"] 302 | 303 | [metadata] 304 | content-hash = "9a464bada20020895a83a0eb922ea1b52c4030c52a74f272c10ed2ea9b6eed9e" 305 | python-versions = "3.7" 306 | 307 | [metadata.files] 308 | "aspy.yaml" = [ 309 | {file = "aspy.yaml-1.3.0-py2.py3-none-any.whl", hash = "sha256:463372c043f70160a9ec950c3f1e4c3a82db5fca01d334b6bc89c7164d744bdc"}, 310 | {file = "aspy.yaml-1.3.0.tar.gz", hash = "sha256:e7c742382eff2caed61f87a39d13f99109088e5e93f04d76eb8d4b28aa143f45"}, 311 | ] 312 | atomicwrites = [ 313 | {file = "atomicwrites-1.3.0-py2.py3-none-any.whl", hash = "sha256:03472c30eb2c5d1ba9227e4c2ca66ab8287fbfbbda3888aa93dc2e28fc6811b4"}, 314 | {file = "atomicwrites-1.3.0.tar.gz", hash = "sha256:75a9445bac02d8d058d5e1fe689654ba5a6556a1dfd8ce6ec55a0ed79866cfa6"}, 315 | ] 316 | attrs = [ 317 | {file = "attrs-19.3.0-py2.py3-none-any.whl", hash = "sha256:08a96c641c3a74e44eb59afb61a24f2cb9f4d7188748e76ba4bb5edfa3cb7d1c"}, 318 | {file = "attrs-19.3.0.tar.gz", hash = "sha256:f7b7ce16570fe9965acd6d30101a28f62fb4a7f9e926b3bbc9b61f8b04247e72"}, 319 | ] 320 | cfgv = [ 321 | {file = "cfgv-2.0.1-py2.py3-none-any.whl", hash = "sha256:fbd93c9ab0a523bf7daec408f3be2ed99a980e20b2d19b50fc184ca6b820d289"}, 322 | {file = "cfgv-2.0.1.tar.gz", hash = "sha256:edb387943b665bf9c434f717bf630fa78aecd53d5900d2e05da6ad6048553144"}, 323 | ] 324 | colorama = [ 325 | {file = 
"colorama-0.4.3-py2.py3-none-any.whl", hash = "sha256:7d73d2a99753107a36ac6b455ee49046802e59d9d076ef8e47b61499fa29afff"}, 326 | {file = "colorama-0.4.3.tar.gz", hash = "sha256:e96da0d330793e2cb9485e9ddfd918d456036c7149416295932478192f4436a1"}, 327 | ] 328 | coverage = [ 329 | {file = "coverage-5.0.1-cp27-cp27m-macosx_10_12_x86_64.whl", hash = "sha256:c90bda74e16bcd03861b09b1d37c0a4158feda5d5a036bb2d6e58de6ff65793e"}, 330 | {file = "coverage-5.0.1-cp27-cp27m-macosx_10_13_intel.whl", hash = "sha256:bb3d29df5d07d5399d58a394d0ef50adf303ab4fbf66dfd25b9ef258effcb692"}, 331 | {file = "coverage-5.0.1-cp27-cp27m-manylinux1_i686.whl", hash = "sha256:1ca43dbd739c0fc30b0a3637a003a0d2c7edc1dd618359d58cc1e211742f8bd1"}, 332 | {file = "coverage-5.0.1-cp27-cp27m-manylinux1_x86_64.whl", hash = "sha256:591506e088901bdc25620c37aec885e82cc896528f28c57e113751e3471fc314"}, 333 | {file = "coverage-5.0.1-cp27-cp27m-win32.whl", hash = "sha256:a50b0888d8a021a3342d36a6086501e30de7d840ab68fca44913e97d14487dc1"}, 334 | {file = "coverage-5.0.1-cp27-cp27m-win_amd64.whl", hash = "sha256:c792d3707a86c01c02607ae74364854220fb3e82735f631cd0a345dea6b4cee5"}, 335 | {file = "coverage-5.0.1-cp27-cp27mu-manylinux1_i686.whl", hash = "sha256:f425f50a6dd807cb9043d15a4fcfba3b5874a54d9587ccbb748899f70dc18c47"}, 336 | {file = "coverage-5.0.1-cp27-cp27mu-manylinux1_x86_64.whl", hash = "sha256:25b8f60b5c7da71e64c18888f3067d5b6f1334b9681876b2fb41eea26de881ae"}, 337 | {file = "coverage-5.0.1-cp35-cp35m-macosx_10_12_x86_64.whl", hash = "sha256:7362a7f829feda10c7265b553455de596b83d1623b3d436b6d3c51c688c57bf6"}, 338 | {file = "coverage-5.0.1-cp35-cp35m-manylinux1_i686.whl", hash = "sha256:fcd4459fe35a400b8f416bc57906862693c9f88b66dc925e7f2a933e77f6b18b"}, 339 | {file = "coverage-5.0.1-cp35-cp35m-manylinux1_x86_64.whl", hash = "sha256:40fbfd6b044c9db13aeec1daf5887d322c710d811f944011757526ef6e323fd9"}, 340 | {file = "coverage-5.0.1-cp35-cp35m-win32.whl", hash = 
"sha256:7f2675750c50151f806070ec11258edf4c328340916c53bac0adbc465abd6b1e"}, 341 | {file = "coverage-5.0.1-cp35-cp35m-win_amd64.whl", hash = "sha256:24bcfa86fd9ce86b73a8368383c39d919c497a06eebb888b6f0c12f13e920b1a"}, 342 | {file = "coverage-5.0.1-cp36-cp36m-macosx_10_13_x86_64.whl", hash = "sha256:eeafb646f374988c22c8e6da5ab9fb81367ecfe81c70c292623373d2a021b1a1"}, 343 | {file = "coverage-5.0.1-cp36-cp36m-manylinux1_i686.whl", hash = "sha256:2ca2cd5264e84b2cafc73f0045437f70c6378c0d7dbcddc9ee3fe192c1e29e5d"}, 344 | {file = "coverage-5.0.1-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:2cc707fc9aad2592fc686d63ef72dc0031fc98b6fb921d2f5395d9ab84fbc3ef"}, 345 | {file = "coverage-5.0.1-cp36-cp36m-win32.whl", hash = "sha256:04b961862334687549eb91cd5178a6fbe977ad365bddc7c60f2227f2f9880cf4"}, 346 | {file = "coverage-5.0.1-cp36-cp36m-win_amd64.whl", hash = "sha256:232f0b52a5b978288f0bbc282a6c03fe48cd19a04202df44309919c142b3bb9c"}, 347 | {file = "coverage-5.0.1-cp37-cp37m-macosx_10_13_x86_64.whl", hash = "sha256:cfce79ce41cc1a1dc7fc85bb41eeeb32d34a4cf39a645c717c0550287e30ff06"}, 348 | {file = "coverage-5.0.1-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:46c9c6a1d1190c0b75ec7c0f339088309952b82ae8d67a79ff1319eb4e749b96"}, 349 | {file = "coverage-5.0.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:1cbb88b34187bdb841f2599770b7e6ff8e259dc3bb64fc7893acf44998acf5f8"}, 350 | {file = "coverage-5.0.1-cp37-cp37m-win32.whl", hash = "sha256:ff3936dd5feaefb4f91c8c1f50a06c588b5dc69fba4f7d9c79a6617ad80bb7df"}, 351 | {file = "coverage-5.0.1-cp37-cp37m-win_amd64.whl", hash = "sha256:65bead1ac8c8930cf92a1ccaedcce19a57298547d5d1db5c9d4d068a0675c38b"}, 352 | {file = "coverage-5.0.1-cp38-cp38-macosx_10_13_x86_64.whl", hash = "sha256:348630edea485f4228233c2f310a598abf8afa5f8c716c02a9698089687b6085"}, 353 | {file = "coverage-5.0.1-cp38-cp38-manylinux1_i686.whl", hash = "sha256:960d7f42277391e8b1c0b0ae427a214e1b31a1278de6b73f8807b20c2e913bba"}, 354 | {file = 
"coverage-5.0.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:0101888bd1592a20ccadae081ba10e8b204d20235d18d05c6f7d5e904a38fc10"}, 355 | {file = "coverage-5.0.1-cp38-cp38m-win32.whl", hash = "sha256:c0fff2733f7c2950f58a4fd09b5db257b00c6fec57bf3f68c5bae004d804b407"}, 356 | {file = "coverage-5.0.1-cp38-cp38m-win_amd64.whl", hash = "sha256:5f622f19abda4e934938e24f1d67599249abc201844933a6f01aaa8663094489"}, 357 | {file = "coverage-5.0.1-cp39-cp39m-win32.whl", hash = "sha256:2714160a63da18aed9340c70ed514973971ee7e665e6b336917ff4cca81a25b1"}, 358 | {file = "coverage-5.0.1-cp39-cp39m-win_amd64.whl", hash = "sha256:b7dbc5e8c39ea3ad3db22715f1b5401cd698a621218680c6daf42c2f9d36e205"}, 359 | {file = "coverage-5.0.1.tar.gz", hash = "sha256:5ac71bba1e07eab403b082c4428f868c1c9e26a21041436b4905c4c3d4e49b08"}, 360 | ] 361 | future = [ 362 | {file = "future-0.18.2.tar.gz", hash = "sha256:b1bead90b70cf6ec3f0710ae53a525360fa360d306a86583adc6bf83a4db537d"}, 363 | ] 364 | identify = [ 365 | {file = "identify-1.4.9-py2.py3-none-any.whl", hash = "sha256:72e9c4ed3bc713c7045b762b0d2e2115c572b85abfc1f4604f5a4fd4c6642b71"}, 366 | {file = "identify-1.4.9.tar.gz", hash = "sha256:6f44e637caa40d1b4cb37f6ed3b262ede74901d28b1cc5b1fc07360871edd65d"}, 367 | ] 368 | importlib-metadata = [ 369 | {file = "importlib_metadata-1.3.0-py2.py3-none-any.whl", hash = "sha256:d95141fbfa7ef2ec65cfd945e2af7e5a6ddbd7c8d9a25e66ff3be8e3daf9f60f"}, 370 | {file = "importlib_metadata-1.3.0.tar.gz", hash = "sha256:073a852570f92da5f744a3472af1b61e28e9f78ccf0c9117658dc32b15de7b45"}, 371 | ] 372 | more-itertools = [ 373 | {file = "more-itertools-8.0.2.tar.gz", hash = "sha256:b84b238cce0d9adad5ed87e745778d20a3f8487d0f0cb8b8a586816c7496458d"}, 374 | {file = "more_itertools-8.0.2-py3-none-any.whl", hash = "sha256:c833ef592a0324bcc6a60e48440da07645063c453880c9477ceb22490aec1564"}, 375 | ] 376 | nodeenv = [ 377 | {file = "nodeenv-1.3.3.tar.gz", hash = "sha256:ad8259494cf1c9034539f6cced78a1da4840a4b157e23640bc4a0c0546b0cb7a"}, 
378 | ] 379 | numpy = [ 380 | {file = "numpy-1.18.0-cp35-cp35m-macosx_10_6_intel.whl", hash = "sha256:b091e5d4cbbe79f0e8b6b6b522346e54a282eadb06e3fd761e9b6fafc2ca91ad"}, 381 | {file = "numpy-1.18.0-cp35-cp35m-manylinux1_i686.whl", hash = "sha256:443ab93fc35b31f01db8704681eb2fd82f3a1b2fa08eed2dd0e71f1f57423d4a"}, 382 | {file = "numpy-1.18.0-cp35-cp35m-manylinux1_x86_64.whl", hash = "sha256:88c5ccbc4cadf39f32193a5ef22e3f84674418a9fd877c63322917ae8f295a56"}, 383 | {file = "numpy-1.18.0-cp35-cp35m-win32.whl", hash = "sha256:e1080e37c090534adb2dd7ae1c59ee883e5d8c3e63d2a4d43c20ee348d0459c5"}, 384 | {file = "numpy-1.18.0-cp35-cp35m-win_amd64.whl", hash = "sha256:f084d513de729ff10cd72a1f80db468cff464fedb1ef2fea030221a0f62d7ff4"}, 385 | {file = "numpy-1.18.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:1baefd1fb4695e7f2e305467dbd876d765e6edd30c522894df76f8301efaee36"}, 386 | {file = "numpy-1.18.0-cp36-cp36m-manylinux1_i686.whl", hash = "sha256:cc070fc43a494e42732d6ae2f6621db040611c1dde64762a40c8418023af56d7"}, 387 | {file = "numpy-1.18.0-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:6f8113c8dbfc192b58996ee77333696469ea121d1c44ea429d8fd266e4c6be51"}, 388 | {file = "numpy-1.18.0-cp36-cp36m-win32.whl", hash = "sha256:a30f5c3e1b1b5d16ec1f03f4df28e08b8a7529d8c920bbed657f4fde61f1fbcd"}, 389 | {file = "numpy-1.18.0-cp36-cp36m-win_amd64.whl", hash = "sha256:3c68c827689ca0ca713dba598335073ce0966850ec0b30715527dce4ecd84055"}, 390 | {file = "numpy-1.18.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:f6a7421da632fc01e8a3ecd19c3f7350258d82501a646747664bae9c6a87c731"}, 391 | {file = "numpy-1.18.0-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:905cd6fa6ac14654a6a32b21fad34670e97881d832e24a3ca32e19b455edb4a8"}, 392 | {file = "numpy-1.18.0-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:854f6ed4fa91fa6da5d764558804ba5b0f43a51e5fe9fc4fdc93270b052f188a"}, 393 | {file = "numpy-1.18.0-cp37-cp37m-win32.whl", hash = 
"sha256:ac3cf835c334fcc6b74dc4e630f9b5ff7b4c43f7fb2a7813208d95d4e10b5623"}, 394 | {file = "numpy-1.18.0-cp37-cp37m-win_amd64.whl", hash = "sha256:62506e9e4d2a39c87984f081a2651d4282a1d706b1a82fe9d50a559bb58e705a"}, 395 | {file = "numpy-1.18.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:9d6de2ad782aae68f7ed0e0e616477fbf693d6d7cc5f0f1505833ff12f84a673"}, 396 | {file = "numpy-1.18.0-cp38-cp38-manylinux1_i686.whl", hash = "sha256:1c35fb1131362e6090d30286cfda52ddd42e69d3e2bf1fea190a0fad83ea3a18"}, 397 | {file = "numpy-1.18.0-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:56710a756c5009af9f35b91a22790701420406d9ac24cf6b652b0e22cfbbb7ff"}, 398 | {file = "numpy-1.18.0-cp38-cp38-win32.whl", hash = "sha256:03bbde29ac8fba860bb2c53a1525b3604a9b60417855ac3119d89868ec6041c3"}, 399 | {file = "numpy-1.18.0-cp38-cp38-win_amd64.whl", hash = "sha256:712f0c32555132f4b641b918bdb1fd3c692909ae916a233ce7f50eac2de87e37"}, 400 | {file = "numpy-1.18.0.zip", hash = "sha256:a9d72d9abaf65628f0f31bbb573b7d9304e43b1e6bbae43149c17737a42764c4"}, 401 | ] 402 | packaging = [ 403 | {file = "packaging-19.2-py2.py3-none-any.whl", hash = "sha256:d9551545c6d761f3def1677baf08ab2a3ca17c56879e70fecba2fc4dde4ed108"}, 404 | {file = "packaging-19.2.tar.gz", hash = "sha256:28b924174df7a2fa32c1953825ff29c61e2f5e082343165438812f00d3a7fc47"}, 405 | ] 406 | pluggy = [ 407 | {file = "pluggy-0.13.1-py2.py3-none-any.whl", hash = "sha256:966c145cd83c96502c3c3868f50408687b38434af77734af1e9ca461a4081d2d"}, 408 | {file = "pluggy-0.13.1.tar.gz", hash = "sha256:15b2acde666561e1298d71b523007ed7364de07029219b604cf808bfa1c765b0"}, 409 | ] 410 | pre-commit = [ 411 | {file = "pre_commit-1.20.0-py2.py3-none-any.whl", hash = "sha256:c2e4810d2d3102d354947907514a78c5d30424d299dc0fe48f5aa049826e9b50"}, 412 | {file = "pre_commit-1.20.0.tar.gz", hash = "sha256:9f152687127ec90642a2cc3e4d9e1e6240c4eb153615cb02aa1ad41d331cbb6e"}, 413 | ] 414 | py = [ 415 | {file = "py-1.8.0-py2.py3-none-any.whl", hash = 
"sha256:64f65755aee5b381cea27766a3a147c3f15b9b6b9ac88676de66ba2ae36793fa"}, 416 | {file = "py-1.8.0.tar.gz", hash = "sha256:dc639b046a6e2cff5bbe40194ad65936d6ba360b52b3c3fe1d08a82dd50b5e53"}, 417 | ] 418 | pyparsing = [ 419 | {file = "pyparsing-2.4.6-py2.py3-none-any.whl", hash = "sha256:c342dccb5250c08d45fd6f8b4a559613ca603b57498511740e65cd11a2e7dcec"}, 420 | {file = "pyparsing-2.4.6.tar.gz", hash = "sha256:4c830582a84fb022400b85429791bc551f1f4871c33f23e44f353119e92f969f"}, 421 | ] 422 | pytest = [ 423 | {file = "pytest-5.3.2-py3-none-any.whl", hash = "sha256:e41d489ff43948babd0fad7ad5e49b8735d5d55e26628a58673c39ff61d95de4"}, 424 | {file = "pytest-5.3.2.tar.gz", hash = "sha256:6b571215b5a790f9b41f19f3531c53a45cf6bb8ef2988bc1ff9afb38270b25fa"}, 425 | ] 426 | pytest-cov = [ 427 | {file = "pytest-cov-2.8.1.tar.gz", hash = "sha256:cc6742d8bac45070217169f5f72ceee1e0e55b0221f54bcf24845972d3a47f2b"}, 428 | {file = "pytest_cov-2.8.1-py2.py3-none-any.whl", hash = "sha256:cdbdef4f870408ebdbfeb44e63e07eb18bb4619fae852f6e760645fa36172626"}, 429 | ] 430 | pyyaml = [ 431 | {file = "PyYAML-5.2-cp27-cp27m-win32.whl", hash = "sha256:35ace9b4147848cafac3db142795ee42deebe9d0dad885ce643928e88daebdcc"}, 432 | {file = "PyYAML-5.2-cp27-cp27m-win_amd64.whl", hash = "sha256:ebc4ed52dcc93eeebeae5cf5deb2ae4347b3a81c3fa12b0b8c976544829396a4"}, 433 | {file = "PyYAML-5.2-cp35-cp35m-win32.whl", hash = "sha256:38a4f0d114101c58c0f3a88aeaa44d63efd588845c5a2df5290b73db8f246d15"}, 434 | {file = "PyYAML-5.2-cp35-cp35m-win_amd64.whl", hash = "sha256:483eb6a33b671408c8529106df3707270bfacb2447bf8ad856a4b4f57f6e3075"}, 435 | {file = "PyYAML-5.2-cp36-cp36m-win32.whl", hash = "sha256:7f38e35c00e160db592091751d385cd7b3046d6d51f578b29943225178257b31"}, 436 | {file = "PyYAML-5.2-cp36-cp36m-win_amd64.whl", hash = "sha256:0e7f69397d53155e55d10ff68fdfb2cf630a35e6daf65cf0bdeaf04f127c09dc"}, 437 | {file = "PyYAML-5.2-cp37-cp37m-win32.whl", hash = 
"sha256:e4c015484ff0ff197564917b4b4246ca03f411b9bd7f16e02a2f586eb48b6d04"}, 438 | {file = "PyYAML-5.2-cp37-cp37m-win_amd64.whl", hash = "sha256:4b6be5edb9f6bb73680f5bf4ee08ff25416d1400fbd4535fe0069b2994da07cd"}, 439 | {file = "PyYAML-5.2-cp38-cp38-win32.whl", hash = "sha256:8100c896ecb361794d8bfdb9c11fce618c7cf83d624d73d5ab38aef3bc82d43f"}, 440 | {file = "PyYAML-5.2-cp38-cp38-win_amd64.whl", hash = "sha256:2e9f0b7c5914367b0916c3c104a024bb68f269a486b9d04a2e8ac6f6597b7803"}, 441 | {file = "PyYAML-5.2.tar.gz", hash = "sha256:c0ee8eca2c582d29c3c2ec6e2c4f703d1b7f1fb10bc72317355a746057e7346c"}, 442 | ] 443 | six = [ 444 | {file = "six-1.13.0-py2.py3-none-any.whl", hash = "sha256:1f1b7d42e254082a9db6279deae68afb421ceba6158efa6131de7b3003ee93fd"}, 445 | {file = "six-1.13.0.tar.gz", hash = "sha256:30f610279e8b2578cab6db20741130331735c781b56053c59c4076da27f06b66"}, 446 | ] 447 | toml = [ 448 | {file = "toml-0.10.0-py2.7.egg", hash = "sha256:f1db651f9657708513243e61e6cc67d101a39bad662eaa9b5546f789338e07a3"}, 449 | {file = "toml-0.10.0-py2.py3-none-any.whl", hash = "sha256:235682dd292d5899d361a811df37e04a8828a5b1da3115886b73cf81ebc9100e"}, 450 | {file = "toml-0.10.0.tar.gz", hash = "sha256:229f81c57791a41d65e399fc06bf0848bab550a9dfd5ed66df18ce5f05e73d5c"}, 451 | ] 452 | torch = [ 453 | {file = "torch-1.3.1-cp27-cp27m-manylinux1_x86_64.whl", hash = "sha256:d8e1d904a6193ed14a4fed220b00503b2baa576e71471286d1ebba899c851fae"}, 454 | {file = "torch-1.3.1-cp27-cp27mu-manylinux1_x86_64.whl", hash = "sha256:b6f01d851d1c5989d4a99b50ae0187762b15b7718dcd1a33704b665daa2402f9"}, 455 | {file = "torch-1.3.1-cp27-none-macosx_10_7_x86_64.whl", hash = "sha256:31062923ac2e60eac676f6a0ae14702b051c158bbcf7f440eaba266b0defa197"}, 456 | {file = "torch-1.3.1-cp35-cp35m-manylinux1_x86_64.whl", hash = "sha256:458f1d87e5b7064b2c39e36675d84e163be3143dd2fc806057b7878880c461bc"}, 457 | {file = "torch-1.3.1-cp35-none-macosx_10_6_x86_64.whl", hash = 
"sha256:3b05233481b51bb636cee63dc761bb7f602e198178782ff4159d385d1759608b"}, 458 | {file = "torch-1.3.1-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:0cec2e13a2e95c24c34f17d437f354ee2a40902e8d515a524556b350e12555dd"}, 459 | {file = "torch-1.3.1-cp36-none-macosx_10_7_x86_64.whl", hash = "sha256:77fd8866c0bf529861ffd850a5dada2190a8d9c5167719fb0cfa89163e23b143"}, 460 | {file = "torch-1.3.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:72a1c85bffd2154f085bc0a1d378d8a54e55a57d49664b874fe7c949022bf071"}, 461 | {file = "torch-1.3.1-cp37-none-macosx_10_7_x86_64.whl", hash = "sha256:134e8291a97151b1ffeea09cb9ddde5238beb4e6d9dfb66657143d6990bfb865"}, 462 | ] 463 | virtualenv = [ 464 | {file = "virtualenv-16.7.9-py2.py3-none-any.whl", hash = "sha256:55059a7a676e4e19498f1aad09b8313a38fcc0cdbe4fdddc0e9b06946d21b4bb"}, 465 | {file = "virtualenv-16.7.9.tar.gz", hash = "sha256:0d62c70883c0342d59c11d0ddac0d954d0431321a41ab20851facf2b222598f3"}, 466 | ] 467 | wcwidth = [ 468 | {file = "wcwidth-0.1.7-py2.py3-none-any.whl", hash = "sha256:f4ebe71925af7b40a864553f761ed559b43544f8f71746c2d756c7fe788ade7c"}, 469 | {file = "wcwidth-0.1.7.tar.gz", hash = "sha256:3df37372226d6e63e1b1e1eda15c594bca98a22d33a23832a90998faa96bc65e"}, 470 | ] 471 | zipp = [ 472 | {file = "zipp-0.6.0-py2.py3-none-any.whl", hash = "sha256:f06903e9f1f43b12d371004b4ac7b06ab39a44adc747266928ae6debfa7b3335"}, 473 | {file = "zipp-0.6.0.tar.gz", hash = "sha256:3718b1cbcd963c7d4c5511a8240812904164b7f381b647143a89d3b98f9bcd8e"}, 474 | ] 475 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["poetry>=0.12"] 3 | build-backend = "poetry.masonry.api" 4 | 5 | [tool.black] 6 | line-length = 79 7 | 8 | [tool.poetry] 9 | name = "batchrenorm" 10 | version = "0.1.0" 11 | description = "Batch Renormalization" 12 | authors = ["Ludvig Bergenstråhle "] 13 | license = 
"MIT" 14 | 15 | [tool.poetry.dependencies] 16 | python = "^3.7.0" 17 | torch = "^1.3" 18 | 19 | [tool.poetry.dev-dependencies] 20 | pre-commit = "^1.20.0" 21 | pytest = "^5.3" 22 | pytest-cov = "^2.8.1" 23 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ludvb/batchrenorm/143076f952f0861bbf8af55833d046e10ffaa99b/tests/__init__.py -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def pytest_configure(config): 5 | config.addinivalue_line("markers", "seed_rng: seed the pytorch RNG") 6 | 7 | 8 | def pytest_runtest_setup(item): 9 | seed_marker = item.get_closest_marker("seed_rng") 10 | if seed_marker is not None: 11 | if len(seed_marker.args) != 0: 12 | seed = seed_marker.args[0] 13 | else: 14 | seed = 1337 15 | torch.manual_seed(seed) 16 | -------------------------------------------------------------------------------- /tests/test_batchrenorm.py: -------------------------------------------------------------------------------- 1 | import batchrenorm 2 | import pytest 3 | import torch 4 | 5 | 6 | @pytest.mark.parametrize("kwargs", [{"affine": True}, {"affine": False}]) 7 | def test_step1(kwargs): 8 | for _ in range(100): 9 | batch_renorm = batchrenorm.BatchRenorm1d(10, eps=0.0, **kwargs) 10 | x = torch.randn(100, 10) 11 | assert ( 12 | batch_renorm(x) 13 | - torch.nn.functional.batch_norm( 14 | x, torch.zeros(10), torch.ones(10), eps=0.0, training=True 15 | ) 16 | ).abs().max() < 1e-5 17 | 18 | 19 | @pytest.mark.parametrize("kwargs", [{"affine": True}, {"affine": False}]) 20 | @pytest.mark.seed_rng() 21 | def test_ablation(kwargs): 22 | br = batchrenorm.BatchRenorm1d(5, eps=0.0, **kwargs) 23 | br.num_batches_tracked = torch.tensor(50000) 24 | xs = 
torch.randn(10, 5, 5) * 10 25 | xs_mean = xs.mean((0, 1)) 26 | xs_var = xs.var((0, 1), unbiased=False) 27 | 28 | def _step(): 29 | return ( 30 | torch.stack( 31 | [ 32 | br(x) 33 | - torch.nn.functional.batch_norm( 34 | x, xs_mean, xs_var, eps=0.0, training=False 35 | ) 36 | for x in xs 37 | ] 38 | ) 39 | .abs() 40 | .mean() 41 | ) 42 | 43 | errors = torch.stack([_step() for _ in range(100)]) 44 | assert errors[-10:].mean() < errors[:10].mean() 45 | 46 | 47 | def test_batchnorm1d(): 48 | br = batchrenorm.BatchRenorm1d(3).eval() 49 | x = torch.randn(5, 3) 50 | assert (br(x) == br(x.unsqueeze(-1)).squeeze(-1)).all() 51 | with pytest.raises(ValueError, match="expected 2D or 3D input"): 52 | br(x[0]) 53 | with pytest.raises(ValueError, match="expected 2D or 3D input"): 54 | br(x[..., None, None]) 55 | with pytest.raises(ValueError, match="expected 2D or 3D input"): 56 | br(x[..., None, None, None]) 57 | 58 | 59 | def test_batchnorm2d(): 60 | br = batchrenorm.BatchRenorm2d(3).eval() 61 | x = torch.randn(5, 3, 10, 10) 62 | br(x) 63 | assert ( 64 | br(x[:, :, :1, :1]).squeeze() 65 | == batchrenorm.BatchRenorm1d(3).eval()(x[:, :, 0, 0]) 66 | ).all() 67 | with pytest.raises(ValueError, match="expected 4D input"): 68 | br(x[0, :, 0, 0]) 69 | with pytest.raises(ValueError, match="expected 4D input"): 70 | br(x[:, :, 0, 0]) 71 | with pytest.raises(ValueError, match="expected 4D input"): 72 | br(x[:, :, :, 0]) 73 | with pytest.raises(ValueError, match="expected 4D input"): 74 | br(x[:, :, None]) 75 | 76 | 77 | def test_batchnorm3d(): 78 | br = batchrenorm.BatchRenorm3d(3).eval() 79 | x = torch.randn(5, 3, 10, 10, 10) 80 | br(x) 81 | assert ( 82 | br(x[:, :, :1]).squeeze() 83 | == batchrenorm.BatchRenorm2d(3).eval()(x[:, :, 0]) 84 | ).all() 85 | with pytest.raises(ValueError, match="expected 5D input"): 86 | br(x[0, :, 0, 0, 0]) 87 | with pytest.raises(ValueError, match="expected 5D input"): 88 | br(x[:, :, 0, 0, 0]) 89 | with pytest.raises(ValueError, match="expected 5D 
input"): 90 | br(x[:, :, :, 0, 0]) 91 | with pytest.raises(ValueError, match="expected 5D input"): 92 | br(x[:, :, 0, :, :]) 93 | --------------------------------------------------------------------------------