├── .gitignore ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── MANIFEST.in ├── README.md ├── equation.svg ├── include ├── checks.h ├── cuda_utils.cuh ├── dispatch.h ├── inplace_abn.h └── utils.h ├── inplace_abn.png ├── inplace_abn ├── __init__.py ├── _backend.pyi ├── abn.py ├── functions.py └── group.py ├── licenses.csv ├── requirements.txt ├── scripts ├── dataset │ ├── __init__.py │ ├── dataset.py │ ├── sampler.py │ └── transform.py ├── experiments │ ├── densenet264_ipabn_lr_256.json │ ├── resnet101_ipabn-sync_lr_512.json │ ├── resnet34_ipabn-sync_lr_512.json │ ├── resnet50_ipabn-sync_lr_512.json │ ├── resnext101_ipabn-sync_lr_256.json │ ├── resnext101_ipabn_lr_512.json │ ├── resnext101_stdbn_lr_256.json │ ├── resnext152_ipabn_lr_256.json │ └── wider_resnet38_ipabn_lr_256.json ├── imagenet │ ├── __init__.py │ ├── config.py │ ├── transforms.py │ └── utils.py ├── models │ ├── __init__.py │ ├── densenet.py │ ├── resnet.py │ ├── resnext.py │ ├── util.py │ └── wider_resnet.py ├── modules │ ├── __init__.py │ ├── deeplab.py │ ├── dense.py │ ├── misc.py │ └── residual.py ├── requirements.txt ├── test_imagenet.py ├── test_vistas.py ├── test_vistas_single_gpu.py └── train_imagenet.py ├── setup.cfg ├── setup.py └── src ├── inplace_abn.cpp ├── inplace_abn_cpu.cpp ├── inplace_abn_cuda.cu ├── inplace_abn_kernels.cuh └── utils.cpp /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled Object files 2 | *.slo 3 | *.lo 4 | *.o 5 | *.obj 6 | 7 | # Precompiled Headers 8 | *.gch 9 | *.pch 10 | 11 | # Compiled Dynamic libraries 12 | *.so 13 | *.dylib 14 | *.dll 15 | 16 | # Fortran module files 17 | *.mod 18 | 19 | # Compiled Static libraries 20 | *.lai 21 | *.la 22 | *.a 23 | *.lib 24 | 25 | # Executables 26 | *.exe 27 | *.out 28 | *.app 29 | *~ 30 | 31 | # doc 32 | doc/html 33 | doc/latex 34 | doc/doc 35 | docs/web-data 36 | 37 | #dmlc 38 | config.mk 39 | 40 | *.pyc 41 | .Rhistory 42 | *log 43 | Debug 44 | *suo 45 | 
tracker 46 | 47 | # vim 48 | *.swp 49 | *.swo 50 | *.swn 51 | .vimrc 52 | .ycm_extra_conf.py 53 | .ycm_extra_conf.pyc 54 | 55 | # Emacs 56 | .#* 57 | .clang_complete 58 | .dir-locals.el 59 | __pycache__ 60 | *.pkl 61 | *.params 62 | *.d 63 | build 64 | data 65 | recommonmark 66 | bin 67 | deps 68 | 69 | # R 70 | *.Rcheck 71 | *.rds 72 | *.Rproj 73 | .Rproj.user 74 | R-package/inst/* 75 | *.tar.gz 76 | *.tgz 77 | R-package/man/*.Rd 78 | 79 | # data 80 | *.rec 81 | *.lst 82 | *.zip 83 | *ubyte 84 | *.bin 85 | 86 | # ipython notebook 87 | *_pb2.py 88 | *.ipynb_checkpoints* 89 | input.txt* 90 | 91 | # Jetbrain 92 | .idea 93 | 94 | # ctags 95 | tags 96 | 97 | # Scala package 98 | *.class 99 | scala-package/*/target/ 100 | scala-package/*/*/target/ 101 | *.scala_dependencies 102 | *.worksheet 103 | *.idea 104 | *.iml 105 | *.classpath 106 | *.project 107 | *.settings 108 | !scala-package/*/bin 109 | 110 | # cffi generated 111 | _ext 112 | 113 | # setuptools 114 | dist/ 115 | *.egg-info 116 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Open Source Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as contributors and maintainers pledge to make participation in our project and our community a harassment-free experience for everyone, regardless of age, body size, disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, religion, or sexual identity and orientation. 
6 | 7 | ## Our Standards 8 | 9 | Examples of behavior that contributes to creating a positive environment include: 10 | 11 | Using welcoming and inclusive language 12 | Being respectful of differing viewpoints and experiences 13 | Gracefully accepting constructive criticism 14 | Focusing on what is best for the community 15 | Showing empathy towards other community members 16 | Examples of unacceptable behavior by participants include: 17 | 18 | The use of sexualized language or imagery and unwelcome sexual attention or advances 19 | Trolling, insulting/derogatory comments, and personal or political attacks 20 | Public or private harassment 21 | Publishing others’ private information, such as a physical or electronic address, without explicit permission 22 | Other conduct which could reasonably be considered inappropriate in a professional setting 23 | 24 | ## Our Responsibilities 25 | 26 | Project maintainers are responsible for clarifying the standards of acceptable behavior and are expected to take appropriate and fair corrective action in response to any instances of unacceptable behavior. 27 | 28 | Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, or to ban temporarily or permanently any contributor for other behaviors that they deem inappropriate, threatening, offensive, or harmful. 29 | 30 | ## Scope 31 | 32 | This Code of Conduct applies within all project spaces, and it also applies when an individual is representing the project or its community in public spaces. Examples of representing a project or community include using an official project e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. Representation of a project may be further defined and clarified by project maintainers. 
33 | 34 | ## Enforcement 35 | 36 | Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by contacting the project team at opensource-conduct@fb.com. All complaints will be reviewed and investigated and will result in a response that is deemed necessary and appropriate to the circumstances. The project team is obligated to maintain confidentiality with regard to the reporter of an incident. Further details of specific enforcement policies may be posted separately. 37 | 38 | Project maintainers who do not follow or enforce the Code of Conduct in good faith may face temporary or permanent repercussions as determined by other members of the project’s leadership. 39 | 40 | ## Attribution 41 | 42 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 43 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 44 | 45 | [homepage]: https://www.contributor-covenant.org 46 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to InPlace-ABN 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome pull requests addressing bugs. Each pull request should be 7 | referenced in a corresponding Issue explaining the bug and how to reproduce it. 8 | 9 | ## Contributor License Agreement ("CLA") 10 | In order to accept your pull request, we need you to submit a CLA. You only need 11 | to do this once to work on any of Facebook's open source projects. 12 | 13 | Complete your CLA here: 14 | 15 | ## License 16 | By contributing to InPlace-ABN, you agree that your contributions will be licensed 17 | under the LICENSE file in the root directory of this source tree. 
18 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2017, mapillary 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
30 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.md 2 | include LICENSE -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # In-Place Activated BatchNorm 2 | 3 | [**In-Place Activated BatchNorm for Memory-Optimized Training of DNNs**](https://arxiv.org/abs/1712.02616) 4 | 5 | In-Place Activated BatchNorm (InPlace-ABN) is a novel approach to reduce the memory required for training deep networks. 6 | It allows for up to 50% memory savings in modern architectures such as ResNet, ResNeXt and Wider ResNet by redefining 7 | BN + non linear activation as a single in-place operation, while smartly dropping or recomputing intermediate buffers as 8 | needed. 9 | 10 | This repository contains a [PyTorch](http://pytorch.org/) implementation of the InPlace-ABN layer, as well as some 11 | training scripts to reproduce the ImageNet classification results reported in our paper. 12 | 13 | - [Overview](#overview) 14 | - [Installation](#installation) 15 | - [Training on ImageNet](#training-on-imagenet) 16 | 17 | We have now also released the inference code for semantic segmentation, together with the Mapillary Vistas trained model leading to [#1 position on the Mapillary Vistas Semantic Segmentation leaderboard](https://eval-vistas.mapillary.com/featured-challenges/1/leaderboard/1). More information can be found at the bottom of this page. 
18 | 19 | ## Citation 20 | 21 | If you use In-Place Activated BatchNorm in your research, please cite: 22 | ```bibtex 23 | @inproceedings{rotabulo2017place, 24 | title={In-Place Activated BatchNorm for Memory-Optimized Training of DNNs}, 25 | author={Rota Bul\`o, Samuel and Porzi, Lorenzo and Kontschieder, Peter}, 26 | booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition}, 27 | year={2018} 28 | } 29 | ``` 30 | 31 | ## Overview 32 | 33 |

34 | 35 | When processing a BN-Activation-Convolution sequence in the forward pass, most deep learning frameworks need to store 36 | two big buffers, _i.e._ the input `x` of BN and the input `z` of Conv. 37 | This is necessary because the standard implementations of the backward passes of BN and Conv depend on their inputs to 38 | calculate the gradients. 39 | Using InPlace-ABN to replace the BN-Activation sequence, we can safely discard `x`, thus saving up to 50% GPU memory at 40 | training time. 41 | To achieve this, we rewrite the backward pass of BN in terms of its output `y`, which is in turn reconstructed from `z` 42 | by inverting the activation function. 43 | 44 | The parametrization of the BN scaling factor differs from that of standard BN, in order to ensure an invertible transformation. Specifically, the scaling factor becomes 45 | $|\gamma| + \epsilon$, where $\gamma$ is the learned BN scale and $\epsilon > 0$ is a small constant. 46 | 47 | ## Requirements 48 | 49 | To install PyTorch, please refer to https://github.com/pytorch/pytorch#installation. 50 | 51 | **NOTE 1: our code _requires_ PyTorch v1.1 or later** 52 | 53 | **NOTE 2: we are only able to provide support for Linux platforms and CUDA versions >= 10.0** 54 | 55 | **NOTE 3: in general, it is not possible to load weights from a network trained with standard BN into an InPlace-ABN network without severe performance degradation, due to the different handling of BN scaling parameters** 56 | 57 | To install the package containing the iABN layers: 58 | ```bash 59 | pip install inplace-abn 60 | ``` 61 | Note that some parts of InPlace-ABN have native C++/CUDA implementations, meaning that the command above will need to 62 | compile them.
63 | 64 | Alternatively, to download and install the latest version of our library, also obtaining a copy of the Imagenet / Vistas 65 | scripts: 66 | ```bash 67 | git clone https://github.com/mapillary/inplace_abn.git 68 | cd inplace_abn 69 | python setup.py install 70 | cd scripts 71 | pip install -r requirements.txt 72 | ``` 73 | The last of the commands above will install some additional libraries required by the Imagenet / Vistas scripts. 74 | 75 | ### Force compiling with CUDA 76 | 77 | In order to force the compilation of the native CUDA functions on systems that do not 78 | have access to a GPU (e.g. Docker containers), two environment variables have to be set: 79 | ```bash 80 | export TORCH_CUDA_ARCH_LIST="{archs}" 81 | export IABN_FORCE_CUDA=1 82 | ``` 83 | where `{archs}` is a list of target CUDA architectures, e.g. `Pascal;Volta`, `6.0;6.1` etc. 84 | 85 | ## Training on ImageNet-1k 86 | 87 | Here you can find the results from our arXiv paper (top-1 / top-5 scores), together with the corresponding trained models and their md5 checksums. The model files provided below are made available under the [license attached to ImageNet](http://www.image-net.org/download-faq).
88 | 89 | | Network | Batch | 224 | 224, 10-crops | 320 | Trained models (+md5) | 90 | |-----------------------------------|-------|----------------|----------------|---------------|----------------------------------| 91 | | [ResNeXt101, Std-BN][1] | 256 | 77.04 / 93.50 | 78.72 / 94.47 | 77.92 / 94.28 | [`448438885986d14db5e870b95f814f91`][6] | 92 | | [ResNeXt101, InPlace-ABN][2] | 512 | 78.08 / 93.79 | 79.52 / 94.66 | 79.38 / 94.67 | [`3b7a221cbc076410eb12c8dd361b7e4e`][7] | 93 | | [ResNeXt152, InPlace-ABN][3] | 256 | 78.28 / 94.04 | 79.73 / 94.82 | 79.56 / 94.67 | [`2c8d572587961ed74611d534c5b2e9ce`][8] | 94 | | [WideResNet38, InPlace-ABN][4] | 256 | 79.72 / 94.78 | 81.03 / 95.43 | 80.69 / 95.27 | [`1c085ab70b789cc1d6c1594f7a761007`][9] | 95 | | [ResNeXt101, InPlace-ABN sync][5] | 256 | 77.70 / 93.78 | 79.18 / 94.60 | 78.98 / 94.56 | [`0a85a21847b15e5a242e17bf3b753849`][10] | 96 | | [DenseNet264, InPlace-ABN][11] | 256 | 78.57 / 94.17 | 79.72 / 94.93 | 79.49 / 94.89 | [`0b413d67b725619441d0646d663865bf`][12] | 97 | | [ResNet50v1, InPlace-ABN sync][13] | 512 | 75.53 / 92.59 | 77.04 / 93.57 | 76.60 / 93.49 | [`2522ca639f7fdfd7c0089ba1f5f6c2e8`][14] | 98 | | [ResNet34v1, InPlace-ABN sync][15] | 512 | 73.27 / 91.34 | 75.19 / 92.66 | 74.87 / 92.42 | [`61515c1484911c3cc753d405131e1dda`][16] | 99 | | [ResNet101v1, InPlace-ABN sync][17] | 512 | 77.07 / 93.45 | 78.58 / 94.40 | 78.25 / 94.19 | [`1552ae0f3d610108df702135f56bd27b`][18] | 100 | 101 | [1]: scripts/experiments/resnext101_stdbn_lr_256.json 102 | [2]: scripts/experiments/resnext101_ipabn_lr_512.json 103 | [3]: scripts/experiments/resnext152_ipabn_lr_256.json 104 | [4]: scripts/experiments/wider_resnet38_ipabn_lr_256.json 105 | [5]: scripts/experiments/resnext101_ipabn-sync_lr_256.json 106 | [6]: https://drive.google.com/file/d/1qT8qCSZzUHorai1EP6Liywa28ASac_G_/view 107 | [7]: https://drive.google.com/file/d/1rQd-NoZuCsGZ7_l_X9GO1GGiXeXHE8CT/view 108 | [8]: 
https://drive.google.com/file/d/1RmHK3tdVTVsHiyNO14bYLkMC0XUjenIn/view 109 | [9]: https://drive.google.com/file/d/1Y0McSz9InDSxMEcBylAbCv1gvyeaz8Ij/view 110 | [10]: https://drive.google.com/file/d/1v2gmUPBMDKf0wZm9r1JwCQLGAig0DdXJ/view 111 | [11]: scripts/experiments/densenet264_ipabn_lr_256.json 112 | [12]: https://drive.google.com/file/d/1J2wp59bzzEd6zttM6oMa1KgbmCL1MS0k/view 113 | [13]: scripts/experiments/resnet50_ipabn-sync_lr_512.json 114 | [14]: https://drive.google.com/file/d/1N7kjWrnUbD_aBOUNi9ZLGnI3E_1ATH8U/view 115 | [15]: scripts/experiments/resnet34_ipabn-sync_lr_512.json 116 | [16]: https://drive.google.com/file/d/1V5dCIZeRCfnZi9krNaQNhXNDHyXz9JR8/view 117 | [17]: scripts/experiments/resnet101_ipabn-sync_lr_512.json 118 | [18]: https://drive.google.com/file/d/1oFVSIUYAxa_uNDq2OLkbhyiFmKwnYzpt/view 119 | 120 | ### Data preparation 121 | 122 | Our script uses [torchvision.datasets.ImageFolder](http://pytorch.org/docs/master/torchvision/datasets.html#torchvision.datasets.ImageFolder) 123 | for loading ImageNet data, which expects folders organized as follows: 124 | ``` 125 | root/train/[class_id1]/xxx.{jpg,png,jpeg} 126 | root/train/[class_id1]/xxy.{jpg,png,jpeg} 127 | root/train/[class_id2]/xxz.{jpg,png,jpeg} 128 | ... 129 | 130 | root/val/[class_id1]/asdas.{jpg,png,jpeg} 131 | root/val/[class_id1]/123456.{jpg,png,jpeg} 132 | root/val/[class_id2]/__32_.{jpg,png,jpeg} 133 | ... 134 | ``` 135 | Images can have any name, as long as the extension is that of a recognized image format. 136 | Class ids are also free-form, but they are expected to match between train and validation data. 137 | Note that the training data in the standard ImageNet distribution is already given in the required format, while 138 | validation images need to be split into class sub-folders as described above. 
139 | 140 | ### Training 141 | 142 | The main training script is `scripts/train_imagenet.py`: this supports training on ImageNet, or any other dataset 143 | formatted as described above, while keeping a log of relevant metrics in Tensorboard format and periodically saving 144 | snapshots. 145 | Most training parameters can be specified as a `json`-formatted configuration file (look 146 | [here](scripts/imagenet/config.py) for a complete list of configurable parameters). 147 | All parameters not explicitly specified in the configuration file are set to their defaults, also available in 148 | [scripts/imagenet/config.py](scripts/imagenet/config.py). 149 | 150 | Our arXiv results can be reproduced by running `scripts/train_imagenet.py` with the configuration files in 151 | `scripts/experiments`. 152 | As an example, the command to train `ResNeXt101` with InPlace-ABN, Leaky ReLU and `batch_size = 512` is: 153 | ```bash 154 | cd scripts 155 | python -m torch.distributed.launch --nproc_per_node <number of GPUs per node> train_imagenet.py --log-dir /path/to/tensorboard/logs experiments/resnext101_ipabn_lr_512.json /path/to/imagenet/root 156 | ``` 157 | 158 | ### Validation 159 | 160 | Validation is run by `scripts/train_imagenet.py` at the end of every training epoch. 161 | To validate a trained model, you can use the `scripts/test_imagenet.py` script, which allows for 10-crops validation and 162 | transferring weights across compatible networks (_e.g._ from `ResNeXt101` with ReLU to `ResNeXt101` with Leaky 163 | ReLU). 164 | This script accepts the same configuration files as `scripts/train_imagenet.py`, but note that the `scale_val` and 165 | `crop_val` parameters are ignored in favour of the `--scale` and `--crop` command-line arguments.
166 | 167 | As an example, to validate the `ResNeXt101` trained above using 10-crops of size `224` from images scaled to `256` 168 | pixels, you can run: 169 | ```bash 170 | cd scripts 171 | python -m torch.distributed.launch --nproc_per_node <number of GPUs per node> test_imagenet.py --crop 224 --scale 256 --ten_crops experiments/resnext101_ipabn_lr_512.json /path/to/checkpoint /path/to/imagenet/root 172 | ``` 173 | 174 | ## Usage for Semantic Segmentation on Cityscapes and Mapillary Vistas 175 | 176 | We have successfully used InPlace-ABN with a DeepLab3 segmentation head that was trained on top of the WideResNet38 177 | model above. 178 | Due to InPlace-ABN, we can significantly increase the amount of input data to this model, which eventually allowed us to 179 | obtain #1 positions on [Cityscapes](https://www.cityscapes-dataset.com/benchmarks/#scene-labeling-task), 180 | [Mapillary Vistas](https://eval-vistas.mapillary.com/featured-challenges/1/leaderboard/1), [AutoNUE](http://cvit.iiit.ac.in/scene-understanding-challenge-2018/benchmarks.php), 181 | [Kitti](http://www.cvlibs.net/datasets/kitti/eval_semseg.php?benchmark=semantics2015) and 182 | [ScanNet](http://dovahkiin.stanford.edu/adai/semantic_label) segmentation leaderboards. 183 | The training settings mostly follow the description in our [paper](https://arxiv.org/abs/1712.02616). 184 | 185 | ### Mapillary Vistas pre-trained model 186 | 187 | We release our WideResNet38 + DeepLab3 segmentation model trained on the Mapillary Vistas research set. 188 | This is the model used to reach the #1 position on the MVD semantic segmentation leaderboard. 189 | The segmentation model file provided below is made available under a 190 | [CC BY-NC-SA 4.0 license](https://creativecommons.org/licenses/by-nc-sa/4.0/).
191 | 192 | | Network | mIOU | Trained model (+md5) | 193 | |-------------------------------|-------|----------------------------------------| 194 | | [WideResNet38 + DeepLab3][19] | 53.42 | [913f78486a34aa1577a7cd295e8a33bb][20] | 195 | 196 | [19]: scripts/test_vistas.py 197 | [20]: https://drive.google.com/file/d/1SJJx5-LFG3J3M99TrPMU-z6ZmgWynxo-/view 198 | 199 | To use this, please download the `.pth.tar` model file linked above and run the `test_vistas.py` script as follows: 200 | ```bash 201 | cd scripts 202 | python test_vistas.py /path/to/model.pth.tar /path/to/input/folder /path/to/output/folder 203 | ``` 204 | 205 | The script will process all `.png`, `.jpg` and `.jpeg` images from the input folder and write the predictions in the 206 | output folder as `.png` images. 207 | For additional options, _e.g._ test time augmentation, please consult the script's help message. 208 | 209 | The results on the test data written above were obtained by employing only scale 1.0 + flipping. 210 | 211 | ## Changelog 212 | 213 | **Update 04 Jul. 2019: version 1.0.0** 214 | - Complete rewrite of the CUDA code following the most recent native BN implementation from Pytorch 215 | - Improved synchronized BN implementation, correctly handling different per-GPU batch sizes and Pytorch distributed groups 216 | - The iABN layers are now packaged in an installable python library to simplify use in other projects 217 | - The Imagenet / Vistas scripts are still available in the `scripts` folder 218 | - Requires now PyTorch 1.1 219 | 220 | **Update 08 Jan. 2019:** 221 | - Enabled multiprocessing and inplace ABN synchronization over multiple processes (previously using threads). It now requires to use DistributedDataParallel instead of DataParallel 222 | - Added compatibility with fp16 (currently allows fp16 input but requires the module to stay in fp32 mode) 223 | - Requires now PyTorch 1.0 224 | 225 | **Update Feb. 
2019:** 226 | - Added ResNet34v1, ResNet50v1 and ResNet101v1 ImageNet-1k pre-trained models 227 | 228 | We have modified the imagenet training code and BN synchronization in order to work with multiple processes. We have also added compatibility of our Inplace ABN module with fp16. 229 | -------------------------------------------------------------------------------- /equation.svg: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /include/checks.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | #pragma once 4 | 5 | #include 6 | 7 | #ifdef TORCH_CHECK 8 | #define IABN_CHECK TORCH_CHECK 9 | #else 10 | #define IABN_CHECK AT_CHECK 11 | #endif 12 | 13 | #define CHECK_CUDA(x) IABN_CHECK((x).is_cuda(), #x " must be a CUDA tensor") 14 | #define CHECK_CPU(x) IABN_CHECK(!(x).is_cuda(), #x " must be a CPU tensor") 15 | #define CHECK_NOT_HALF(x) \ 16 | IABN_CHECK( \ 17 | (x).scalar_type() != at::ScalarType::Half, #x " can't have type Half") 18 | #define CHECK_SAME_TYPE(x, y) \ 19 | IABN_CHECK( \ 20 | (x).scalar_type() == (y).scalar_type(), \ 21 | #x " and " #y " must have the same scalar type") 22 | // True iff x and y have the same rank and identical sizes along every dimension. 23 | inline bool have_same_dims(const at::Tensor& x, const at::Tensor& y) { 24 | // Compare ranks first, so y.size(dim) is never queried past y's last dimension. 25 | if (x.ndimension() != y.ndimension()) return false; 26 | for (int64_t dim = 0; dim < x.ndimension(); ++dim) if (x.size(dim) != y.size(dim)) return false; 27 | return true; 28 | } 29 | 30 | inline bool is_compatible_weight(const at::Tensor& x, const at::Tensor& w) { 31 | // Dimensions check 32 | bool success = w.ndimension() == 1; 33 | success &= x.size(1) == w.size(0); 34 | 35 | // Typing check 36 | if (x.scalar_type() == at::ScalarType::Half) { 37 | success &= (w.scalar_type() == at::ScalarType::Half) || 38 | (w.scalar_type() == at::ScalarType::Float); 39 | } else { 40 | success &=
x.scalar_type() == w.scalar_type(); 41 | } 42 | 43 | return success; 44 | } 45 | // True when s is 1-D with x.size(1) entries and its dtype is Float for Half x, or equal to x's dtype otherwise. 46 | inline bool is_compatible_stat(const at::Tensor& x, const at::Tensor& s) { 47 | // Dimensions check 48 | bool success = s.ndimension() == 1; 49 | success &= x.size(1) == s.size(0); 50 | 51 | // Typing check 52 | if (x.scalar_type() == at::ScalarType::Half) { 53 | success &= s.scalar_type() == at::ScalarType::Float; 54 | } else { 55 | success &= x.scalar_type() == s.scalar_type(); 56 | } 57 | 58 | return success; 59 | } 60 | -------------------------------------------------------------------------------- /include/cuda_utils.cuh: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | #pragma once 4 | 5 | #include 6 | #include 7 | 8 | #if defined(__HIP_PLATFORM_HCC__) 9 | constexpr int WARP_SIZE = 64; 10 | #else 11 | constexpr int WARP_SIZE = 32; 12 | #endif 13 | 14 | // The maximum number of threads in a block 15 | #if defined(__HIP_PLATFORM_HCC__) 16 | constexpr int MAX_BLOCK_SIZE = 256; 17 | #else 18 | constexpr int MAX_BLOCK_SIZE = 512; 19 | #endif 20 | // XOR lane shuffle: __shfl_xor_sync (full-warp mask by default) when CUDA_VERSION >= 9000, legacy __shfl_xor otherwise. 21 | template 22 | __device__ __forceinline__ T WARP_SHFL_XOR(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff) 23 | { 24 | #if CUDA_VERSION >= 9000 25 | return __shfl_xor_sync(mask, value, laneMask, width); 26 | #else 27 | return __shfl_xor(value, laneMask, width); 28 | #endif 29 | } 30 | 31 | // Block size for nElem elements: the smallest entry of threadSizes that is >= nElem, or MAX_BLOCK_SIZE if nElem exceeds them all. 32 | static int getNumThreads(int nElem) { 33 | #if defined(__HIP_PLATFORM_HCC__) 34 | int threadSizes[5] = { 16, 32, 64, 128, MAX_BLOCK_SIZE }; 35 | #else 36 | int threadSizes[5] = { 32, 64, 128, 256, MAX_BLOCK_SIZE }; 37 | #endif 38 | for (int i = 0; i != 5; ++i) { 39 | if (nElem <= threadSizes[i]) { 40 | return threadSizes[i]; 41 | } 42 | } 43 | return MAX_BLOCK_SIZE; 44 | } 45 | // Largest power of two <= n (0 when n == 0). 46 | static int lastPow2(unsigned int n) { 47 | n |= (n >> 1); 48 | n |=
(n >> 2); 49 | n |= (n >> 4); 50 | n |= (n >> 8); 51 | n |= (n >> 16); 52 | return n - (n >> 1); 53 | } 54 | 55 | // Returns the index of the most significant 1 bit in `val`. 56 | __device__ __forceinline__ int getMSB(int val) { 57 | return 31 - __clz(val); 58 | } 59 | 60 | // Sum across all threads within a warp: XOR-butterfly over log2(WARP_SIZE) steps, so every lane ends up with the total 61 | template 62 | static __device__ __forceinline__ T warpSum(T val) { 63 | for (int i = 0; i < getMSB(WARP_SIZE); ++i) { 64 | val += WARP_SHFL_XOR(val, 1 << i, WARP_SIZE); 65 | } 66 | return val; 67 | } 68 | // Accessor for t, or (when t is empty) an accessor with a null data pointer and all-zero sizes/strides 69 | template< 70 | typename scalar_t, 71 | int64_t dim, 72 | template class PtrTraits = at::DefaultPtrTraits, 73 | typename index_t = int64_t> 74 | static at::PackedTensorAccessor packed_accessor_or_dummy( 75 | const std::optional& t) { 76 | if (!t.has_value()) { 77 | const std::vector zeros(dim); 78 | return at::PackedTensorAccessor(nullptr, zeros.data(), zeros.data()); 79 | } 80 | return t.value().packed_accessor(); 81 | } 82 | // Pair of accumulators summed together, e.g. by the Float2 warpSum overload below 83 | template 84 | struct Float2 { 85 | scalar_t v1, v2; 86 | __device__ Float2() {} 87 | __device__ Float2(scalar_t v1, scalar_t v2) : v1(v1), v2(v2) {} 88 | __device__ Float2(int v) : v1(static_cast(v)), v2(static_cast(v)) {} 89 | __device__ Float2& operator+=(const Float2& a) { 90 | v1 += a.v1; 91 | v2 += a.v2; 92 | return *this; 93 | } 94 | }; 95 | 96 | template 97 | static __device__ __forceinline__ Float2 warpSum(Float2 value) { 98 | value.v1 = warpSum(value.v1); 99 | value.v2 = warpSum(value.v2); 100 | return value; 101 | } 102 | 103 | // Sum across (batch, x/y/z) applying Op() pointwise, i.e. op(batch, plane, x) over dims 0 and 2 of `tensor` for one plane 104 | // this works by first having each thread sum its part 105 | // of the data. Then there is a double-shuffling reduction. 106 | // First each warp (of WARP_SIZE threads) uses warpSum to reduce its 107 | // data to the "warp leader", who writes its value into shared memory. 108 | // Then a single warp reads the remaining (at most WARP_SIZE) items 109 | // and reduces them using another warpSum.
110 | // The implicit assumption is that there are no more 111 | // than WARP_SIZE**2 threads. 112 | template 113 | __device__ scalar_t reduce(Op op, PTA tensor, int plane) { 114 | // first the reductions each thread does separately 115 | scalar_t sum = static_cast(0); 116 | for (auto batch = threadIdx.y; batch < tensor.size(0); batch += blockDim.y) { 117 | for (auto x = threadIdx.x; x < tensor.size(2); x += blockDim.x) { 118 | sum += op(batch, plane, x); 119 | } 120 | } 121 | 122 | // first warpSum to get one value per thread to 123 | // one value per warp 124 | sum = warpSum(sum); 125 | // NOTE(review): the __syncthreads() right after the shared[] declaration presumably keeps threads still reading a previous reduce() result in shared[0] from racing with the writes below — confirm 126 | // this writes each warp's item into shared memory 127 | // there are at most WARP_SIZE items left because 128 | // there are at most WARP_SIZE**2 threads at the beginning 129 | __shared__ scalar_t shared[WARP_SIZE]; 130 | __syncthreads(); 131 | auto tid = threadIdx.x + threadIdx.y * blockDim.x; 132 | if (tid % WARP_SIZE == 0) { 133 | shared[tid / WARP_SIZE] = sum; 134 | } 135 | if (tid >= blockDim.x * blockDim.y / WARP_SIZE && tid < WARP_SIZE) { 136 | // zero out the other entries in shared 137 | shared[tid] = (scalar_t)0; 138 | } 139 | __syncthreads(); 140 | // now have a second warpSum to reduce the intermediate values 141 | // from shared memory to a single number. The very first 142 | // thread writes it to shared memory. 143 | 144 | if (tid / WARP_SIZE == 0) { 145 | sum = warpSum(shared[tid]); 146 | if (tid == 0) { 147 | shared[0] = sum; 148 | } 149 | } 150 | __syncthreads(); 151 | 152 | // Everyone picks it up: every thread of the block returns the total stored in shared[0] 153 | return shared[0]; 154 | } 155 | -------------------------------------------------------------------------------- /include/dispatch.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | #pragma once 4 | 5 | #define NORMAL_CASE_TYPE(enum_type, type, ...)
\ 6 | case enum_type: { \ 7 | using scalar_t = type; \ 8 | using prmscalar_t = type; \ 9 | return __VA_ARGS__(); \ 10 | } 11 | 12 | #define HALF_CASE_TYPE(enum_type, x_type, w_scalar_type, ...) \ 13 | case enum_type: { \ 14 | using scalar_t = x_type; \ 15 | if (w_scalar_type == at::ScalarType::Half) { \ 16 | using prmscalar_t = at::Half; \ 17 | return __VA_ARGS__(); \ 18 | } else if (w_scalar_type == at::ScalarType::Float) { \ 19 | using prmscalar_t = float; \ 20 | return __VA_ARGS__(); \ 21 | } else { \ 22 | AT_ERROR("Unsupported type combination '" #enum_type \ 23 | "', '" #w_scalar_type "'"); \ 24 | } \ 25 | } 26 | 27 | #define DOUBLE_DISPATCH(XTYPE, WTYPE, NAME, ...) \ 28 | [&] { \ 29 | const auto& x_type = XTYPE; \ 30 | const auto& w_type = WTYPE; \ 31 | switch (x_type) { \ 32 | NORMAL_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \ 33 | NORMAL_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \ 34 | HALF_CASE_TYPE(at::ScalarType::Half, at::Half, w_type, __VA_ARGS__) \ 35 | default: \ 36 | AT_ERROR(#NAME, " not implemented for '", toString(x_type), "'"); \ 37 | } \ 38 | }() 39 | 40 | #ifdef WITH_CUDA 41 | #define CUDA_DISPATCH(REF_TENSOR, METHOD, ...) \ 42 | if ((REF_TENSOR).is_cuda()) { \ 43 | return METHOD##_cuda(__VA_ARGS__); \ 44 | } else { \ 45 | return METHOD##_cpu(__VA_ARGS__); \ 46 | } 47 | #else 48 | #define CUDA_DISPATCH(REF_TENSOR, METHOD, ...) \ 49 | if ((REF_TENSOR).is_cuda()) { \ 50 | AT_ERROR("CUDA support was not enabled at compile time"); \ 51 | } else { \ 52 | return METHOD##_cpu(__VA_ARGS__); \ 53 | } 54 | #endif 55 | -------------------------------------------------------------------------------- /include/inplace_abn.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 
2 | 3 | #pragma once 4 | 5 | #include 6 | 7 | #include 8 | #include 9 | 10 | #include "utils.h" 11 | #ifdef __CUDACC__ 12 | #include "cuda_utils.cuh" 13 | #endif 14 | 15 | /*********************************************************************************************************************** 16 | * Enums 17 | **********************************************************************************************************************/ 18 | 19 | enum class Activation { LeakyReLU, ELU, Identity }; 20 | 21 | /*********************************************************************************************************************** 22 | * CPU / Cuda methods 23 | **********************************************************************************************************************/ 24 | 25 | std::tuple statistics_cpu( 26 | const at::Tensor& x); 27 | std::tuple statistics_cuda( 28 | const at::Tensor& x); 29 | 30 | std::tuple reduce_statistics_cuda( 31 | const at::Tensor& all_mean, 32 | const at::Tensor& all_var, 33 | const at::Tensor& all_count); 34 | 35 | void forward_cpu( 36 | at::Tensor& x, 37 | const at::Tensor& mean, 38 | const at::Tensor& var, 39 | const std::optional& weight, 40 | const std::optional& bias, 41 | float eps, 42 | Activation activation, 43 | float activation_param); 44 | void forward_cuda( 45 | at::Tensor& x, 46 | const at::Tensor& mean, 47 | const at::Tensor& var, 48 | const std::optional& weight, 49 | const std::optional& bias, 50 | float eps, 51 | Activation activation, 52 | float activation_param); 53 | 54 | std::tuple backward_reduce_cpu( 55 | const at::Tensor& y_act, 56 | const at::Tensor& dy_act, 57 | const std::optional& weight, 58 | const std::optional& bias, 59 | float eps, 60 | Activation activation, 61 | float activation_param); 62 | std::tuple backward_reduce_cuda( 63 | const at::Tensor& y_act, 64 | const at::Tensor& dy_act, 65 | const std::optional& weight, 66 | const std::optional& bias, 67 | float eps, 68 | Activation activation, 69 | float 
activation_param); 70 | 71 | void backward_cpu( 72 | const at::Tensor& xhat, 73 | at::Tensor& dy, 74 | const at::Tensor& var, 75 | const at::Tensor& count, 76 | const at::Tensor& sum_dy, 77 | const at::Tensor& sum_xhat_dy, 78 | const std::optional& weight, 79 | float eps); 80 | void backward_cuda( 81 | const at::Tensor& xhat, 82 | at::Tensor& dy, 83 | const at::Tensor& var, 84 | const at::Tensor& count, 85 | const at::Tensor& sum_dy, 86 | const at::Tensor& sum_xhat_dy, 87 | const std::optional& weight, 88 | float eps); 89 | 90 | /*********************************************************************************************************************** 91 | * Handling of activation functions 92 | **********************************************************************************************************************/ 93 | 94 | template 95 | struct ActivationFn; 96 | 97 | template 98 | struct ActivationFn { 99 | static INLINE_HOST_DEVICE void forward(scalar_t& x, float activation_param) { 100 | x = (x >= 0) ? x : static_cast(x * activation_param); 101 | } 102 | 103 | static INLINE_HOST_DEVICE void backward( 104 | scalar_t y_act, 105 | scalar_t dy_act, 106 | float activation_param, 107 | scalar_t& y, 108 | scalar_t& dy) { 109 | if (y_act >= 0) { 110 | y = y_act; 111 | dy = dy_act; 112 | } else { 113 | y = static_cast(y_act / activation_param); 114 | dy = static_cast(dy_act * activation_param); 115 | } 116 | } 117 | }; 118 | 119 | template 120 | struct ActivationFn { 121 | static INLINE_HOST_DEVICE void forward(scalar_t& x, float activation_param) { 122 | x = (x >= 0) ? 
x : static_cast(activation_param * (::exp(x) - 1)); 123 | } 124 | 125 | static INLINE_HOST_DEVICE void backward( 126 | scalar_t y_act, 127 | scalar_t dy_act, 128 | float activation_param, 129 | scalar_t& y, 130 | scalar_t& dy) { 131 | if (y_act >= 0) { 132 | y = y_act; 133 | dy = dy_act; 134 | } else { 135 | y = ::log1p(static_cast(y_act / activation_param)); 136 | dy = static_cast(dy_act * (y_act + activation_param)); 137 | } 138 | } 139 | }; 140 | 141 | template 142 | struct ActivationFn { 143 | static INLINE_HOST_DEVICE void forward(scalar_t& x, float activation_param) {} 144 | 145 | static INLINE_HOST_DEVICE void backward( 146 | scalar_t y_act, 147 | scalar_t dy_act, 148 | float activation_param, 149 | scalar_t& y, 150 | scalar_t& dy) { 151 | y = y_act; 152 | dy = dy_act; 153 | } 154 | }; 155 | -------------------------------------------------------------------------------- /include/utils.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 
2 | 3 | #pragma once 4 | 5 | #include 6 | #include 7 | 8 | #include 9 | 10 | /*********************************************************************************************************************** 11 | * General defines 12 | **********************************************************************************************************************/ 13 | 14 | #ifdef __CUDACC__ 15 | 16 | #define HOST_DEVICE __host__ __device__ 17 | #define INLINE_HOST_DEVICE __host__ __device__ __forceinline__ 18 | 19 | #else 20 | // CPU versions 21 | 22 | #define HOST_DEVICE 23 | #define INLINE_HOST_DEVICE inline 24 | 25 | #endif // #ifdef __CUDACC__ 26 | 27 | /*********************************************************************************************************************** 28 | * Utility functions 29 | **********************************************************************************************************************/ 30 | 31 | at::Tensor normalize_shape(const at::Tensor& x); 32 | 33 | template 34 | static at::TensorAccessor accessor_or_dummy( 35 | const std::optional& t) { 36 | if (!t.has_value()) { 37 | return at::TensorAccessor(nullptr, nullptr, nullptr); 38 | } 39 | return t.value().accessor(); 40 | } 41 | -------------------------------------------------------------------------------- /inplace_abn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mapillary/inplace_abn/086317a56338649e0f994eaea68be84fbc6d7cf7/inplace_abn.png -------------------------------------------------------------------------------- /inplace_abn/__init__.py: -------------------------------------------------------------------------------- 1 | from .abn import ABN, InPlaceABN, InPlaceABNSync 2 | from .group import active_group, set_active_group 3 | 4 | try: 5 | from ._version import version as __version__ 6 | except ImportError: 7 | pass 8 | -------------------------------------------------------------------------------- 
/inplace_abn/_backend.pyi: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | """Stubs for the native methods""" 4 | from typing import Tuple, Optional 5 | 6 | import torch 7 | 8 | def statistics(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ... 9 | def reduce_statistics( 10 | all_mean: torch.Tensor, all_var: torch.Tensor, all_count: torch.Tensor 11 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ... 12 | def forward( 13 | x: torch.Tensor, 14 | mean: torch.Tensor, 15 | var: torch.Tensor, 16 | weight: Optional[torch.Tensor], 17 | bias: Optional[torch.Tensor], 18 | eps: float, 19 | activation, 20 | activation_param: float, 21 | ) -> None: ... 22 | def backward_reduce( 23 | y_act: torch.Tensor, 24 | dy_act: torch.Tensor, 25 | weight: Optional[torch.Tensor], 26 | bias: Optional[torch.Tensor], 27 | eps: float, 28 | activation, 29 | activation_param: float, 30 | ) -> Tuple[torch.Tensor, torch.Tensor]: ... 31 | def backward_train( 32 | xhat: torch.Tensor, 33 | dy: torch.Tensor, 34 | var: torch.Tensor, 35 | count: torch.Tensor, 36 | sum_dy: torch.Tensor, 37 | sum_xhat_dy: torch.Tensor, 38 | weight: Optional[torch.Tensor], 39 | eps: float, 40 | ) -> torch.Tensor: ... 41 | def backward_test( 42 | dy: torch.Tensor, var: torch.Tensor, weight: Optional[torch.Tensor], eps: float 43 | ) -> torch.Tensor: ... 44 | 45 | class Activation: 46 | LeakyReLU = ... 47 | ELU = ... 48 | Identity = ... 49 | -------------------------------------------------------------------------------- /inplace_abn/abn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | 3 | from typing import Optional, Any 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as functional 8 | 9 | from .functions import inplace_abn, inplace_abn_sync 10 | 11 | 12 | class ABN(nn.Module): 13 | """Activated Batch Normalization 14 | 15 | This gathers a BatchNorm and an activation function in a single module 16 | 17 | Args: 18 | num_features: Number of feature channels in the input and output 19 | eps: Small constant to prevent numerical issues 20 | momentum: Momentum factor applied to compute running statistics with 21 | exponential moving average, or `None` to compute running statistics 22 | with cumulative moving average 23 | affine: If `True` apply learned scale and shift transformation after normalization 24 | track_running_stats: a boolean value that when set to `True`, this 25 | module tracks the running mean and variance, and when set to `False`, 26 | this module does not track such statistics and uses batch statistics instead 27 | in both training and eval modes if the running mean and variance are `None` 28 | activation: Name of the activation functions, one of: `relu`, `leaky_relu`, 29 | `elu` or `identity` 30 | activation_param: Negative slope for the `leaky_relu` activation or `alpha` 31 | parameter for the `elu` activation 32 | """ 33 | 34 | _version = 2 35 | __constants__ = [ 36 | "track_running_stats", 37 | "momentum", 38 | "eps", 39 | "num_features", 40 | "affine", 41 | "activation", 42 | "activation_param", 43 | ] 44 | num_features: int 45 | eps: float 46 | momentum: Optional[float] 47 | affine: bool 48 | track_running_stats: bool 49 | activation: str 50 | activation_param: float 51 | 52 | def __init__( 53 | self, 54 | num_features: int, 55 | eps: float = 1e-5, 56 | momentum: Optional[float] = 0.1, 57 | affine: bool = True, 58 | track_running_stats: bool = True, 59 | activation: str = "leaky_relu", 60 | activation_param: float = 0.01, 61 | ): 62 | super(ABN, self).__init__() 63 | self.num_features = 
num_features 64 | self.eps = eps 65 | self.momentum = momentum 66 | self.affine = affine 67 | self.track_running_stats = track_running_stats 68 | self.activation = activation 69 | self.activation_param = activation_param 70 | if self.affine: 71 | self.weight = nn.Parameter(torch.Tensor(num_features)) 72 | self.bias = nn.Parameter(torch.Tensor(num_features)) 73 | else: 74 | self.register_parameter("weight", None) 75 | self.register_parameter("bias", None) 76 | if self.track_running_stats: 77 | self.register_buffer("running_mean", torch.zeros(num_features)) 78 | self.register_buffer("running_var", torch.ones(num_features)) 79 | self.register_buffer( 80 | "num_batches_tracked", torch.tensor(0, dtype=torch.long) 81 | ) 82 | else: 83 | self.register_parameter("running_mean", None) 84 | self.register_parameter("running_var", None) 85 | self.register_parameter("num_batches_tracked", None) 86 | self.reset_parameters() 87 | 88 | def reset_running_stats(self) -> None: 89 | if self.track_running_stats: 90 | self.running_mean.zero_() 91 | self.running_var.fill_(1) 92 | self.num_batches_tracked.zero_() 93 | 94 | def reset_parameters(self) -> None: 95 | self.reset_running_stats() 96 | if self.affine: 97 | nn.init.ones_(self.weight) 98 | nn.init.zeros_(self.bias) 99 | 100 | def _get_momentum_and_training(self): 101 | if self.momentum is None: 102 | momentum = 0.0 103 | else: 104 | momentum = self.momentum 105 | 106 | if self.training and self.track_running_stats: 107 | if self.num_batches_tracked is not None: 108 | self.num_batches_tracked = self.num_batches_tracked + 1 109 | if self.momentum is None: 110 | momentum = 1.0 / float(self.num_batches_tracked) 111 | else: 112 | momentum = self.momentum 113 | 114 | if self.training: 115 | training = True 116 | else: 117 | training = (self.running_mean is None) and (self.running_var is None) 118 | 119 | return momentum, training 120 | 121 | def _get_running_stats(self): 122 | running_mean = ( 123 | self.running_mean if not self.training 
or self.track_running_stats else None 124 | ) 125 | running_var = ( 126 | self.running_var if not self.training or self.track_running_stats else None 127 | ) 128 | return running_mean, running_var 129 | 130 | def forward(self, x: torch.Tensor) -> torch.Tensor: 131 | momentum, training = self._get_momentum_and_training() 132 | running_mean, running_var = self._get_running_stats() 133 | 134 | x = functional.batch_norm( 135 | x, 136 | running_mean, 137 | running_var, 138 | self.weight, 139 | self.bias, 140 | training, 141 | momentum, 142 | self.eps, 143 | ) 144 | 145 | if self.activation == "relu": 146 | return functional.relu(x, inplace=True) 147 | elif self.activation == "leaky_relu": 148 | return functional.leaky_relu( 149 | x, negative_slope=self.activation_param, inplace=True 150 | ) 151 | elif self.activation == "elu": 152 | return functional.elu(x, alpha=self.activation_param, inplace=True) 153 | elif self.activation == "identity": 154 | return x 155 | else: 156 | raise RuntimeError(f"Unknown activation function {self.activation}") 157 | 158 | def _load_from_state_dict( 159 | self, 160 | state_dict, 161 | prefix, 162 | local_metadata, 163 | strict, 164 | missing_keys, 165 | unexpected_keys, 166 | error_msgs, 167 | ): 168 | version = local_metadata.get("version", None) 169 | 170 | if (version is None or version < 2) and self.track_running_stats: 171 | # at version 2: added num_batches_tracked buffer 172 | # this should have a default value of 0 173 | num_batches_tracked_key = prefix + "num_batches_tracked" 174 | if num_batches_tracked_key not in state_dict: 175 | state_dict[num_batches_tracked_key] = torch.tensor(0, dtype=torch.long) 176 | 177 | super(ABN, self)._load_from_state_dict( 178 | state_dict, 179 | prefix, 180 | local_metadata, 181 | strict, 182 | missing_keys, 183 | unexpected_keys, 184 | error_msgs, 185 | ) 186 | 187 | def extra_repr(self): 188 | rep = "{num_features}, eps={eps}, momentum={momentum}, affine={affine}, activation={activation}" 189 | if 
self.activation in ["leaky_relu", "elu"]: 190 | rep += "[{activation_param}]" 191 | return rep.format(**self.__dict__) 192 | 193 | 194 | class InPlaceABN(ABN): 195 | """InPlace Activated Batch Normalization 196 | 197 | Args: 198 | num_features: Number of feature channels in the input and output 199 | eps: Small constant to prevent numerical issues 200 | momentum: Momentum factor applied to compute running statistics with 201 | exponential moving average, or `None` to compute running statistics 202 | with cumulative moving average 203 | affine: If `True` apply learned scale and shift transformation after normalization 204 | track_running_stats: a boolean value that when set to `True`, this 205 | module tracks the running mean and variance, and when set to `False`, 206 | this module does not track such statistics and uses batch statistics instead 207 | in both training and eval modes if the running mean and variance are `None` 208 | activation: Name of the activation functions, one of: `relu`, `leaky_relu`, 209 | `elu` or `identity` 210 | activation_param: Negative slope for the `leaky_relu` activation or `alpha` 211 | parameter for the `elu` activation 212 | """ 213 | 214 | def __init__( 215 | self, 216 | num_features: int, 217 | eps: float = 1e-5, 218 | momentum: Optional[float] = 0.1, 219 | affine: bool = True, 220 | track_running_stats: bool = True, 221 | activation: str = "leaky_relu", 222 | activation_param: float = 0.01, 223 | ): 224 | super(InPlaceABN, self).__init__( 225 | num_features, 226 | eps, 227 | momentum, 228 | affine, 229 | track_running_stats, 230 | activation, 231 | activation_param, 232 | ) 233 | 234 | def forward(self, x: torch.Tensor) -> torch.Tensor: 235 | momentum, training = self._get_momentum_and_training() 236 | running_mean, running_var = self._get_running_stats() 237 | 238 | return inplace_abn( 239 | x, 240 | self.weight, 241 | self.bias, 242 | running_mean, 243 | running_var, 244 | training, 245 | momentum, 246 | self.eps, 247 | 
self.activation, 248 | self.activation_param, 249 | ) 250 | 251 | 252 | class InPlaceABNSync(ABN): 253 | """InPlace Activated Batch Normalization with distributed synchronization 254 | 255 | This operates like `inplace_abn`, but assumes to be called by all replicas 256 | in a given distributed group, and computes batch statistics across all of them. 257 | Note that the input tensors can have different dimensions in each replica. 258 | 259 | Args: 260 | num_features: Number of feature channels in the input and output 261 | eps: Small constant to prevent numerical issues 262 | momentum: Momentum factor applied to compute running statistics with 263 | exponential moving average, or `None` to compute running statistics 264 | with cumulative moving average 265 | affine: If `True` apply learned scale and shift transformation after normalization 266 | track_running_stats: a boolean value that when set to `True`, this 267 | module tracks the running mean and variance, and when set to `False`, 268 | this module does not track such statistics and uses batch statistics instead 269 | in both training and eval modes if the running mean and variance are `None` 270 | activation: Name of the activation functions, one of: `relu`, `leaky_relu`, 271 | `elu` or `identity` 272 | activation_param: Negative slope for the `leaky_relu` activation or `alpha` 273 | parameter for the `elu` activation 274 | group: Distributed group to synchronize with, or `None` to use the default group 275 | """ 276 | 277 | def __init__( 278 | self, 279 | num_features: int, 280 | eps: float = 1e-5, 281 | momentum: Optional[float] = 0.1, 282 | affine: bool = True, 283 | track_running_stats: bool = True, 284 | activation: str = "leaky_relu", 285 | activation_param: float = 0.01, 286 | group: Optional[Any] = None, 287 | ): 288 | super(InPlaceABNSync, self).__init__( 289 | num_features, 290 | eps, 291 | momentum, 292 | affine, 293 | track_running_stats, 294 | activation, 295 | activation_param, 296 | ) 297 | 
self.group = group 298 | 299 | def set_group(self, group: Optional[Any]) -> None: 300 | """Set distributed group to synchronize with 301 | 302 | This function should never be called between forward and backward 303 | 304 | Args: 305 | group: Distributed group to synchronize with, or `None` to use the default group 306 | """ 307 | self.group = group 308 | 309 | def forward(self, x: torch.Tensor) -> torch.Tensor: 310 | momentum, training = self._get_momentum_and_training() 311 | running_mean, running_var = self._get_running_stats() 312 | 313 | return inplace_abn_sync( 314 | x, 315 | self.weight, 316 | self.bias, 317 | running_mean, 318 | running_var, 319 | training, 320 | momentum, 321 | self.eps, 322 | self.activation, 323 | self.activation_param, 324 | self.group, 325 | ) 326 | -------------------------------------------------------------------------------- /inplace_abn/functions.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | from typing import Optional, Any 4 | from warnings import warn 5 | 6 | import torch 7 | import torch.autograd as autograd 8 | import torch.distributed as distributed 9 | from torch.autograd.function import once_differentiable 10 | 11 | from . 
import _backend 12 | 13 | 14 | def _activation_from_name(activation): 15 | if activation == "leaky_relu": 16 | return _backend.Activation.LeakyReLU 17 | elif activation == "elu": 18 | return _backend.Activation.ELU 19 | elif activation == "identity": 20 | return _backend.Activation.Identity 21 | else: 22 | raise ValueError("Unknown activation function {}".format(activation)) 23 | 24 | 25 | def _count_samples(x): 26 | count = x.size(0) 27 | for i in range(2, x.ndimension()): 28 | count *= x.size(i) 29 | return count 30 | 31 | 32 | class InPlaceABN(autograd.Function): 33 | @staticmethod 34 | def _reduce_forward(mean, var, count, group, world_size): 35 | # Mean and variance 36 | mean_var = torch.cat([mean, var], dim=0) 37 | all_mean_var = mean_var.new_empty(world_size, mean_var.numel()) 38 | distributed.all_gather( 39 | list(all_mean_var.unbind(0)), mean_var, group=group, async_op=False 40 | ) 41 | all_mean, all_var = all_mean_var.split(mean.numel(), dim=1) 42 | 43 | # Count 44 | all_count = count.new_empty(world_size, 1) 45 | distributed.all_gather( 46 | list(all_count.unbind(0)), count, group=group, async_op=False 47 | ) 48 | 49 | return _backend.reduce_statistics(all_mean, all_var, all_count) 50 | 51 | @staticmethod 52 | def _reduce_backward(sum_dy, sum_xhat_dy, group): 53 | stacked = torch.cat([sum_dy, sum_xhat_dy], dim=0) 54 | distributed.all_reduce( 55 | stacked, distributed.ReduceOp.SUM, group=group, async_op=False 56 | ) 57 | return torch.split(stacked, sum_dy.numel(), dim=0) 58 | 59 | @staticmethod 60 | def forward( 61 | ctx, 62 | x, 63 | weight, 64 | bias, 65 | running_mean, 66 | running_var, 67 | training=True, 68 | momentum=0.1, 69 | eps=1e-05, 70 | activation="leaky_relu", 71 | activation_param=0.01, 72 | group=None, 73 | world_size=1, 74 | ): 75 | # Save context 76 | ctx.training = training 77 | ctx.momentum = momentum 78 | ctx.eps = eps 79 | ctx.activation = _activation_from_name(activation) 80 | ctx.activation_param = activation_param 81 | ctx.group = 
group 82 | ctx.world_size = world_size 83 | ctx.has_running_stats = running_mean is not None and running_mean is not None 84 | 85 | if ctx.training: 86 | mean, var, count = _backend.statistics(x) 87 | 88 | # Gather stats from all workers if needed 89 | if ctx.world_size > 1: 90 | mean, var, count = InPlaceABN._reduce_forward( 91 | mean, var, count, ctx.group, ctx.world_size 92 | ) 93 | 94 | # Update running stats if needed 95 | if ctx.has_running_stats: 96 | count_ = count.to(dtype=var.dtype) 97 | running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * mean) 98 | running_var.mul_((1 - ctx.momentum)).add_( 99 | ctx.momentum * var * count_ / (count_ - 1) 100 | ) 101 | else: 102 | mean, var, count = running_mean, running_var, None 103 | 104 | # Transform x 105 | _backend.forward( 106 | x, mean, var, weight, bias, ctx.eps, ctx.activation, ctx.activation_param 107 | ) 108 | 109 | # Save for backward and mark dirty tensors 110 | ctx.save_for_backward(x, var, count, weight, bias) 111 | ctx.mark_dirty(x) 112 | return x 113 | 114 | @staticmethod 115 | @once_differentiable 116 | def backward(ctx, dy_act): 117 | y_act, var, count, weight, bias = ctx.saved_tensors 118 | 119 | # Call backward_reduce if we need to compute at least one of the gradients 120 | if any(ctx.needs_input_grad): 121 | xhat, dy, sum_dy_local, sum_xhat_dy_local = _backend.backward_reduce( 122 | y_act, 123 | dy_act, 124 | weight, 125 | bias, 126 | ctx.eps, 127 | ctx.activation, 128 | ctx.activation_param, 129 | ) 130 | 131 | if ctx.world_size > 1: 132 | sum_dy, sum_xhat_dy = InPlaceABN._reduce_backward( 133 | sum_dy_local, sum_xhat_dy_local, ctx.group 134 | ) 135 | else: 136 | sum_dy, sum_xhat_dy = sum_dy_local, sum_xhat_dy_local 137 | else: 138 | return (None,) * 12 139 | 140 | # Gradient w.r.t. 
x 141 | if ctx.needs_input_grad[0]: 142 | if ctx.training: 143 | # This overwrites dy with dx 144 | _backend.backward_train( 145 | xhat, dy, var, count, sum_dy, sum_xhat_dy, weight, ctx.eps 146 | ) 147 | dx = dy 148 | else: 149 | dx = _backend.backward_test(dy, var, weight, ctx.eps) 150 | else: 151 | dx = None 152 | 153 | # Gradient w.r.t. weight 154 | if weight is not None and ctx.needs_input_grad[1]: 155 | dweight = sum_xhat_dy_local 156 | dweight[weight < 0] *= -1 157 | else: 158 | dweight = None 159 | 160 | # Gradient w.r.t. bias 161 | if bias is not None and ctx.needs_input_grad[2]: 162 | dbias = sum_dy_local 163 | else: 164 | dbias = None 165 | 166 | return (dx, dweight, dbias) + (None,) * 9 167 | 168 | 169 | def inplace_abn( 170 | x: torch.Tensor, 171 | weight: Optional[torch.Tensor], 172 | bias: Optional[torch.Tensor], 173 | running_mean: Optional[torch.Tensor], 174 | running_var: Optional[torch.Tensor], 175 | training: bool = True, 176 | momentum: float = 0.1, 177 | eps: float = 1e-05, 178 | activation: str = "leaky_relu", 179 | activation_param: float = 0.01, 180 | ): 181 | """InPlace Activated Batch Normalization 182 | 183 | This applies the following per-channel combined BatchNorm + activation operation: 184 | 185 | x_hat = (x - mu) / sqrt(sigma^2 + eps) 186 | x <- act(x_hat, p) * (|weight| + eps) + bias 187 | 188 | where: 189 | - mu is the per-channel batch mean, or `running_mean` if `training` is `False` 190 | - sigma^2 is the per-channel batch variance, or `running_var` if `training` is `False` 191 | - act(., p) is the activation function specified by `activation` 192 | - p is `activation_param`, i.e. 
the negative slope of Leaky ReLU or alpha 193 | parameter of ELU 194 | - `weight` and `bias` are the optional affine parameters 195 | - `eps` is a small positive number 196 | 197 | The running statistics, if given and if `training` is `True` are updated as follows: 198 | 199 | running_mean <- running_mean * momentum + (1 - momentum) * mu 200 | running_var <- running_var * momentum + (1 - momentum) * unbiased_sigma^2 201 | 202 | where unbiased_sigma^2 is the unbiased batch variance 203 | 204 | Args: 205 | x: Input tensor with shape N x C or N x C x S_1 x ... x S_n, which will be 206 | overwritten with the result 207 | weight: Tensor of affine scale parameters with shape C, or `None` 208 | bias: Tensor of affine bias parameters with shape C, or `None` 209 | running_mean: Running mean tensor with shape C, or `None` 210 | running_var: Running variance tensor with shape C, or `None` 211 | training: If `True` compute, use and update batch statistics, otherwise use 212 | running statistics 213 | momentum: Momentum factor applied to compute running statistics 214 | eps: Small constant to prevent numerical issues 215 | activation: Name of the activation function, one of: `leaky_relu`, `elu` or `identity` 216 | activation_param: Negative slope for the `leaky_relu` activation or `alpha` 217 | parameter for the `elu` activation 218 | """ 219 | if training: 220 | samples = _count_samples(x) 221 | if samples <= 1: 222 | raise ValueError( 223 | "inplace_abn is trying to compute batch statistics, but the input " 224 | "tensor only contains a single sample per channel" 225 | ) 226 | 227 | return InPlaceABN.apply( 228 | x, 229 | weight, 230 | bias, 231 | running_mean, 232 | running_var, 233 | training, 234 | momentum, 235 | eps, 236 | activation, 237 | activation_param, 238 | None, 239 | 1, 240 | ) 241 | 242 | 243 | def inplace_abn_sync( 244 | x: torch.Tensor, 245 | weight: Optional[torch.Tensor], 246 | bias: Optional[torch.Tensor], 247 | running_mean: Optional[torch.Tensor], 248 | 
running_var: Optional[torch.Tensor], 249 | training: bool = True, 250 | momentum: float = 0.1, 251 | eps: float = 1e-05, 252 | activation: str = "leaky_relu", 253 | activation_param: float = 0.01, 254 | group: Optional[Any] = None, 255 | ): 256 | """InPlace Activated Batch Normalization with distributed synchronization 257 | 258 | This operates like `inplace_abn`, but assumes to be called by all replicas 259 | in the given distributed group, and computes batch statistics across all of them. 260 | Note that the input tensors can have different dimensions in each replica. 261 | 262 | Args: 263 | x: Input tensor with shape N x C or N x C x S_1 x ... x S_n, which will be 264 | overwritten with the result 265 | weight: Tensor of affine scale parameters with shape C, or `None` 266 | bias: Tensor of affine bias parameters with shape C, or `None` 267 | running_mean: Running mean tensor with shape C, or `None` 268 | running_var: Running variance tensor with shape C, or `None` 269 | training: If `True` compute, use and update batch statistics, otherwise use 270 | running statistics 271 | momentum: Momentum factor applied to compute running statistics 272 | eps: Small constant to prevent numerical issues 273 | activation: Name of the activation function, one of: `leaky_relu`, `elu` or `identity` 274 | activation_param: Negative slope for the `leaky_relu` activation or `alpha` 275 | parameter for the `elu` activation 276 | group: Distributed group to synchronize with, or `None` to use the default group 277 | """ 278 | if training: 279 | samples = _count_samples(x) 280 | if samples <= 1: 281 | raise ValueError( 282 | "inplace_abn_sync is trying to compute batch statistics, but the input " 283 | "tensor only contains a single sample per channel" 284 | ) 285 | 286 | if distributed.is_initialized(): 287 | if group is None: 288 | group = distributed.group.WORLD 289 | world_size = distributed.get_world_size(group) 290 | 291 | return InPlaceABN.apply( 292 | x, 293 | weight, 294 | 
bias, 295 | running_mean, 296 | running_var, 297 | training, 298 | momentum, 299 | eps, 300 | activation, 301 | activation_param, 302 | group, 303 | world_size, 304 | ) 305 | else: 306 | warn( 307 | "inplace_abn_sync is being called, but torch.distributed is not initialized. " 308 | "Reverting to non-synchronized inplace_abn.", 309 | category=RuntimeWarning, 310 | ) 311 | 312 | return InPlaceABN.apply( 313 | x, 314 | weight, 315 | bias, 316 | running_mean, 317 | running_var, 318 | training, 319 | momentum, 320 | eps, 321 | activation, 322 | activation_param, 323 | None, 324 | 1, 325 | ) 326 | 327 | 328 | __all__ = ["inplace_abn", "inplace_abn_sync"] 329 | -------------------------------------------------------------------------------- /inplace_abn/group.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | import torch 4 | import torch.distributed as distributed 5 | import torch.nn as nn 6 | 7 | 8 | def active_group(active: bool): 9 | """Initialize a distributed group where each process can independently decide whether to participate or not 10 | 11 | Args: 12 | active: Whether this process will be active in the group or not 13 | 14 | Returns: 15 | group: A distributed group containing all processes that passed `active=True`, 16 | or `None` if all passed `False` 17 | """ 18 | world_size = distributed.get_world_size() 19 | rank = distributed.get_rank() 20 | 21 | # Check if cache is initialized, add WORLD and None to it 22 | if not hasattr(active_group, "__cache__"): 23 | active_group.__cache__ = { 24 | frozenset(range(world_size)): distributed.group.WORLD, 25 | frozenset(): None, 26 | } 27 | 28 | # Gather active status from all workers 29 | active = torch.tensor( 30 | rank if active else -1, dtype=torch.long, device=torch.cuda.current_device() 31 | ) 32 | active_workers = torch.empty( 33 | world_size, dtype=torch.long, device=torch.cuda.current_device() 34 | ) 35 | 
distributed.all_gather(list(active_workers.unbind(0)), active) 36 | 37 | # Create and cache group if it doesn't exist yet 38 | active_workers = frozenset(int(i) for i in active_workers.tolist() if i != -1) 39 | if active_workers not in active_group.__cache__: 40 | group = distributed.new_group(list(active_workers)) 41 | active_group.__cache__[active_workers] = group 42 | 43 | return active_group.__cache__[active_workers] 44 | 45 | 46 | def set_active_group(module: nn.Module, group): 47 | """Scan all submodules, passing a distributed group to all those that implement `set_group`""" 48 | 49 | def _set_group(m): 50 | if hasattr(m, "set_group"): 51 | m.set_group(group) 52 | 53 | module.apply(_set_group) 54 | -------------------------------------------------------------------------------- /licenses.csv: -------------------------------------------------------------------------------- 1 | Pillow, 7.0.0, HPND 2 | numpy, 1.18.1, BSD 3 | protobuf, 3.11.3, "3-Clause BSD License" 4 | setuptools, 20.7.0, MIT 5 | six, 1.14.0, MIT 6 | tensorboardX, 2.0, MIT 7 | torch, 1.4.0, "New BSD" 8 | torchvision, 0.5.0, BSD 9 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pillow 2 | numpy == 1.22.0 3 | setuptools == 70.0.0 4 | tensorboardX == 2.0 5 | torch == 2.2.0 6 | torchvision == 0.5.0 7 | -------------------------------------------------------------------------------- /scripts/dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mapillary/inplace_abn/086317a56338649e0f994eaea68be84fbc6d7cf7/scripts/dataset/__init__.py -------------------------------------------------------------------------------- /scripts/dataset/dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | 3 | import glob 4 | from itertools import chain 5 | from os import path 6 | 7 | import torch 8 | from PIL import Image 9 | from torch.utils.data import Dataset 10 | 11 | 12 | class SegmentationDataset(Dataset): 13 | _EXTENSIONS = ["*.jpg", "*.jpeg", "*.png"] 14 | 15 | def __init__(self, in_dir, transform): 16 | super(SegmentationDataset, self).__init__() 17 | 18 | self.in_dir = in_dir 19 | self.transform = transform 20 | 21 | # Find all images 22 | self.images = [] 23 | for img_path in chain( 24 | *( 25 | glob.iglob(path.join(self.in_dir, ext)) 26 | for ext in SegmentationDataset._EXTENSIONS 27 | ) 28 | ): 29 | _, name_with_ext = path.split(img_path) 30 | idx, _ = path.splitext(name_with_ext) 31 | self.images.append({"idx": idx, "path": img_path}) 32 | 33 | def __len__(self): 34 | return len(self.images) 35 | 36 | def __getitem__(self, item): 37 | # Load image 38 | with Image.open(self.images[item]["path"]) as img_raw: 39 | size = img_raw.size 40 | img = self.transform(img_raw.convert(mode="RGB")) 41 | 42 | return {"img": img, "meta": {"idx": self.images[item]["idx"], "size": size}} 43 | 44 | 45 | def segmentation_collate(items): 46 | imgs = torch.stack([item["img"] for item in items]) 47 | metas = [item["meta"] for item in items] 48 | 49 | return {"img": imgs, "meta": metas} 50 | -------------------------------------------------------------------------------- /scripts/dataset/sampler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | import torch 4 | import torch.distributed as dist 5 | 6 | 7 | class TestDistributedSampler(torch.utils.data.Sampler): 8 | """Sampler that restricts data loading to a subset of the dataset. 9 | 10 | It is especially useful in conjunction with 11 | :class:`torch.nn.parallel.DistributedDataParallel`. 
In such case, each 12 | process can pass a DistributedSampler instance as a DataLoader sampler, 13 | and load a subset of the original dataset that is exclusive to it. 14 | 15 | .. note:: 16 | Dataset is assumed to be of constant size. 17 | 18 | Arguments: 19 | dataset: Dataset used for sampling. 20 | num_replicas (optional): Number of processes participating in 21 | distributed training. 22 | rank (optional): Rank of the current process within num_replicas. 23 | """ 24 | 25 | def __init__(self, dataset, num_replicas=None, rank=None): 26 | if num_replicas is None: 27 | num_replicas = dist.get_world_size() if dist.is_initialized() else 1 28 | if rank is None: 29 | rank = dist.get_rank() if dist.is_initialized() else 0 30 | self.dataset = dataset 31 | self.num_replicas = num_replicas 32 | self.rank = rank 33 | self.num_samples = (len(self.dataset) // self.num_replicas) + int( 34 | (len(self.dataset) % self.num_replicas) < self.rank 35 | ) 36 | 37 | def __iter__(self): 38 | # deterministically shuffle based on epoch 39 | indices = torch.arange(0, len(self.dataset)) 40 | 41 | # subsample 42 | indices = indices[self.rank :: self.num_replicas] 43 | 44 | return iter(indices) 45 | 46 | def __len__(self): 47 | return self.num_samples 48 | -------------------------------------------------------------------------------- /scripts/dataset/transform.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | 3 | from PIL import Image 4 | from torchvision.transforms import functional as tfn 5 | 6 | 7 | class SegmentationTransform: 8 | def __init__(self, longest_max_size, rgb_mean, rgb_std): 9 | self.longest_max_size = longest_max_size 10 | self.rgb_mean = rgb_mean 11 | self.rgb_std = rgb_std 12 | 13 | def __call__(self, img): 14 | # Scaling 15 | scale = self.longest_max_size / float(max(img.size[0], img.size[1])) 16 | if scale != 1.0: 17 | out_size = tuple(int(dim * scale) for dim in img.size) 18 | img = img.resize(out_size, resample=Image.BILINEAR) 19 | 20 | # Convert to torch and normalize 21 | img = tfn.to_tensor(img) 22 | img.sub_(img.new(self.rgb_mean).view(-1, 1, 1)) 23 | img.div_(img.new(self.rgb_std).view(-1, 1, 1)) 24 | 25 | return img 26 | -------------------------------------------------------------------------------- /scripts/experiments/densenet264_ipabn_lr_256.json: -------------------------------------------------------------------------------- 1 | { 2 | "network": { 3 | "arch": "densenet264", 4 | "activation": "leaky_relu", 5 | "activation_param": 0.01, 6 | "input_3x3": true, 7 | "bn_mode": "inplace" 8 | }, 9 | "optimizer": { 10 | "batch_size": 256, 11 | "clip": 0, 12 | "learning_rate": 0.1, 13 | "nesterov": true, 14 | "schedule": { 15 | "type": "step", 16 | "epochs": 90, 17 | "params": { 18 | "step_size": 30 19 | } 20 | } 21 | }, 22 | "input": { 23 | "color_jitter_train": true, 24 | "lighting_train": true 25 | } 26 | } -------------------------------------------------------------------------------- /scripts/experiments/resnet101_ipabn-sync_lr_512.json: -------------------------------------------------------------------------------- 1 | { 2 | "network": { 3 | "arch": "resnet101", 4 | "activation": "leaky_relu", 5 | "activation_param": 0.01, 6 | "bn_mode": "sync" 7 | }, 8 | "optimizer": { 9 | "batch_size": 512, 10 | "clip": 0, 11 | "learning_rate": 0.2, 12 | "nesterov": true, 13 | "schedule": { 14 | "type": "step", 15 | "epochs": 90, 16 | "params": { 
17 | "step_size": 30 18 | } 19 | } 20 | }, 21 | "input": { 22 | "color_jitter_train": true, 23 | "lighting_train": true 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /scripts/experiments/resnet34_ipabn-sync_lr_512.json: -------------------------------------------------------------------------------- 1 | { 2 | "network": { 3 | "arch": "resnet34", 4 | "activation": "leaky_relu", 5 | "activation_param": 0.01, 6 | "bn_mode": "sync" 7 | }, 8 | "optimizer": { 9 | "batch_size": 512, 10 | "clip": 0, 11 | "learning_rate": 0.2, 12 | "nesterov": true, 13 | "schedule": { 14 | "type": "step", 15 | "epochs": 90, 16 | "params": { 17 | "step_size": 30 18 | } 19 | } 20 | }, 21 | "input": { 22 | "color_jitter_train": true, 23 | "lighting_train": true 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /scripts/experiments/resnet50_ipabn-sync_lr_512.json: -------------------------------------------------------------------------------- 1 | { 2 | "network": { 3 | "arch": "resnet50", 4 | "activation": "leaky_relu", 5 | "activation_param": 0.01, 6 | "bn_mode": "sync" 7 | }, 8 | "optimizer": { 9 | "batch_size": 512, 10 | "clip": 0, 11 | "learning_rate": 0.2, 12 | "nesterov": true, 13 | "schedule": { 14 | "type": "step", 15 | "epochs": 90, 16 | "params": { 17 | "step_size": 30 18 | } 19 | } 20 | }, 21 | "input": { 22 | "color_jitter_train": true, 23 | "lighting_train": true 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /scripts/experiments/resnext101_ipabn-sync_lr_256.json: -------------------------------------------------------------------------------- 1 | { 2 | "network": { 3 | "arch": "resnext101", 4 | "activation": "leaky_relu", 5 | "activation_param": 0.01, 6 | "input_3x3": true, 7 | "bn_mode": "sync" 8 | }, 9 | "optimizer": { 10 | "batch_size": 256, 11 | "clip": 0, 12 | "learning_rate": 0.1, 13 | "nesterov": true, 14 | "schedule": 
{ 15 | "type": "step", 16 | "epochs": 90, 17 | "params": { 18 | "step_size": 30 19 | } 20 | } 21 | }, 22 | "input": { 23 | "color_jitter_train": true, 24 | "lighting_train": true 25 | } 26 | } -------------------------------------------------------------------------------- /scripts/experiments/resnext101_ipabn_lr_512.json: -------------------------------------------------------------------------------- 1 | { 2 | "network": { 3 | "arch": "resnext101", 4 | "activation": "leaky_relu", 5 | "activation_param": 0.01, 6 | "input_3x3": true, 7 | "bn_mode": "inplace" 8 | }, 9 | "optimizer": { 10 | "batch_size": 512, 11 | "clip": 0, 12 | "learning_rate": 0.1, 13 | "nesterov": true, 14 | "schedule": { 15 | "type": "step", 16 | "epochs": 90, 17 | "params": { 18 | "step_size": 30 19 | } 20 | } 21 | }, 22 | "input": { 23 | "color_jitter_train": true, 24 | "lighting_train": true 25 | } 26 | } -------------------------------------------------------------------------------- /scripts/experiments/resnext101_stdbn_lr_256.json: -------------------------------------------------------------------------------- 1 | { 2 | "network": { 3 | "arch": "resnext101", 4 | "activation": "leaky_relu", 5 | "activation_param": 0.01, 6 | "input_3x3": true, 7 | "bn_mode": "standard" 8 | }, 9 | "optimizer": { 10 | "batch_size": 256, 11 | "clip": 0, 12 | "learning_rate": 0.1, 13 | "nesterov": true, 14 | "schedule": { 15 | "type": "step", 16 | "epochs": 90, 17 | "params": { 18 | "step_size": 30 19 | } 20 | } 21 | }, 22 | "input": { 23 | "color_jitter_train": true, 24 | "lighting_train": true 25 | } 26 | } -------------------------------------------------------------------------------- /scripts/experiments/resnext152_ipabn_lr_256.json: -------------------------------------------------------------------------------- 1 | { 2 | "network": { 3 | "arch": "resnext152", 4 | "activation": "leaky_relu", 5 | "activation_param": 0.01, 6 | "input_3x3": true, 7 | "bn_mode": "inplace" 8 | }, 9 | "optimizer": { 10 | 
"batch_size": 256, 11 | "clip": 0, 12 | "learning_rate": 0.1, 13 | "nesterov": true, 14 | "schedule": { 15 | "type": "step", 16 | "epochs": 90, 17 | "params": { 18 | "step_size": 30 19 | } 20 | } 21 | }, 22 | "input": { 23 | "color_jitter_train": true, 24 | "lighting_train": true 25 | } 26 | } -------------------------------------------------------------------------------- /scripts/experiments/wider_resnet38_ipabn_lr_256.json: -------------------------------------------------------------------------------- 1 | { 2 | "network": { 3 | "arch": "wider_resnet38", 4 | "activation": "leaky_relu", 5 | "activation_param": 0.01, 6 | "input_3x3": true, 7 | "bn_mode": "inplace" 8 | }, 9 | "optimizer": { 10 | "batch_size": 256, 11 | "clip": 0, 12 | "learning_rate": 0.1, 13 | "nesterov": true, 14 | "schedule": { 15 | "type": "linear", 16 | "mode": "step", 17 | "epochs": 90, 18 | "params": { 19 | "alpha": -0.000002222, 20 | "beta": 1.0 21 | } 22 | } 23 | }, 24 | "input": { 25 | "color_jitter_train": true, 26 | "lighting_train": true 27 | } 28 | } -------------------------------------------------------------------------------- /scripts/imagenet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mapillary/inplace_abn/086317a56338649e0f994eaea68be84fbc6d7cf7/scripts/imagenet/__init__.py -------------------------------------------------------------------------------- /scripts/imagenet/config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | 3 | import json 4 | 5 | DEFAULTS = { 6 | "network": { 7 | "arch": "resnet101", 8 | "activation": "relu", # supported: relu, leaky_relu, elu, identity 9 | "activation_param": 0.01, # slope for leaky_relu, alpha for elu 10 | "input_3x3": False, 11 | "bn_mode": "standard", # supported: standard, inplace, sync 12 | "classes": 1000, 13 | "dilation": 1, 14 | "weight_gain_multiplier": 1, # note: this is ignored if weight_init == kaiming_* 15 | "weight_init": "xavier_normal", # supported: xavier_[normal,uniform], kaiming_[normal,uniform], orthogonal 16 | }, 17 | "optimizer": { 18 | "batch_size": 256, 19 | "type": "SGD", # supported: SGD, Adam 20 | "momentum": 0.9, 21 | "weight_decay": 1e-4, 22 | "clip": 1.0, 23 | "learning_rate": 0.1, 24 | "classifier_lr": -1.0, # If -1 use same learning rate as the rest of the network 25 | "nesterov": False, 26 | "schedule": { 27 | "type": "constant", # supported: constant, step, multistep, exponential, linear 28 | "mode": "epoch", # supported: epoch, step 29 | "epochs": 10, 30 | "params": {}, 31 | }, 32 | }, 33 | "input": { 34 | "scale_train": -1, # If -1 do not scale 35 | "crop_train": 224, 36 | "color_jitter_train": False, 37 | "lighting_train": False, 38 | "scale_val": 256, # If -1 do not scale 39 | "crop_val": 224, 40 | "mean": [0.485, 0.456, 0.406], 41 | "std": [0.229, 0.224, 0.225], 42 | }, 43 | } 44 | 45 | 46 | def _merge(src, dst): 47 | for k, v in src.items(): 48 | if k in dst: 49 | if isinstance(v, dict): 50 | _merge(src[k], dst[k]) 51 | else: 52 | dst[k] = v 53 | 54 | 55 | def load_config(config_file, defaults=DEFAULTS): 56 | with open(config_file, "r") as fd: 57 | config = json.load(fd) 58 | _merge(defaults, config) 59 | return config 60 | -------------------------------------------------------------------------------- /scripts/imagenet/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | 3 | from random import sample 4 | 5 | import torch 6 | 7 | # Default augmentation values compatible with ImageNet data augmentation pipeline 8 | _DEFAULT_ALPHASTD = 0.1 9 | _DEFAULT_EIGVAL = [0.2175, 0.0188, 0.0045] 10 | _DEFAULT_EIGVEC = [ 11 | [-0.5675, 0.7192, 0.4009], 12 | [-0.5808, -0.0045, -0.8140], 13 | [-0.5836, -0.6948, 0.4203], 14 | ] 15 | _DEFAULT_BCS = [0.4, 0.4, 0.4] 16 | 17 | 18 | def _grayscale(img): 19 | alpha = img.new([0.299, 0.587, 0.114]) 20 | return (alpha.view(3, 1, 1) * img).sum(0, keepdim=True) 21 | 22 | 23 | def _blend(img1, img2, alpha): 24 | return img1 * alpha + (1 - alpha) * img2 25 | 26 | 27 | class Lighting: 28 | def __init__( 29 | self, alphastd=_DEFAULT_ALPHASTD, eigval=_DEFAULT_EIGVAL, eigvec=_DEFAULT_EIGVEC 30 | ): 31 | self._alphastd = alphastd 32 | self._eigval = eigval 33 | self._eigvec = eigvec 34 | 35 | def __call__(self, img): 36 | if self._alphastd == 0.0: 37 | return img 38 | 39 | alpha = torch.normal(img.new_zeros(3), self._alphastd) 40 | eigval = img.new(self._eigval) 41 | eigvec = img.new(self._eigvec) 42 | 43 | rgb = (eigvec * alpha * eigval).sum(dim=1) 44 | return img + rgb.view(3, 1, 1) 45 | 46 | 47 | class Saturation: 48 | def __init__(self, var): 49 | self._var = var 50 | 51 | def __call__(self, img): 52 | gs = _grayscale(img) 53 | alpha = img.new(1).uniform_(-self._var, self._var) + 1.0 54 | return _blend(img, gs, alpha) 55 | 56 | 57 | class Brightness: 58 | def __init__(self, var): 59 | self._var = var 60 | 61 | def __call__(self, img): 62 | gs = torch.zeros_like(img) 63 | alpha = img.new(1).uniform_(-self._var, self._var) + 1.0 64 | return _blend(img, gs, alpha) 65 | 66 | 67 | class Contrast: 68 | def __init__(self, var): 69 | self._var = var 70 | 71 | def __call__(self, img): 72 | gs = _grayscale(img) 73 | gs = img.new_full((1, 1, 1), gs.mean()) 74 | alpha = img.new(1).uniform_(-self._var, self._var) + 1.0 75 | return _blend(img, gs, alpha) 76 | 77 | 78 | class ColorJitter: 79 | def __init__( 80 | self, 81 
| saturation=_DEFAULT_BCS[0], 82 | brightness=_DEFAULT_BCS[1], 83 | contrast=_DEFAULT_BCS[2], 84 | ): 85 | self._transforms = [] 86 | if saturation is not None: 87 | self._transforms.append(Saturation(saturation)) 88 | if brightness is not None: 89 | self._transforms.append(Brightness(brightness)) 90 | if contrast is not None: 91 | self._transforms.append(Contrast(contrast)) 92 | 93 | def __call__(self, img): 94 | if len(self._transforms) == 0: 95 | return img 96 | 97 | for t in sample(self._transforms, len(self._transforms)): 98 | img = t(img) 99 | return img 100 | -------------------------------------------------------------------------------- /scripts/imagenet/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | import time 4 | from functools import partial 5 | 6 | import torch 7 | import torch.distributed as dist 8 | import torch.optim as optim 9 | import torch.optim.lr_scheduler as lr_scheduler 10 | import torchvision.transforms as transforms 11 | from inplace_abn import ABN, InPlaceABN, InPlaceABNSync 12 | 13 | from .transforms import ColorJitter, Lighting 14 | 15 | 16 | def _get_norm_act(network_config): 17 | if network_config["bn_mode"] == "standard": 18 | assert network_config["activation"] in ( 19 | "relu", 20 | "leaky_relu", 21 | "elu", 22 | "identity", 23 | ), "Standard batch normalization is only compatible with relu, leaky_relu, elu and identity" 24 | activation_fn = partial( 25 | ABN, 26 | activation=network_config["activation"], 27 | activation_param=network_config["activation_param"], 28 | ) 29 | elif network_config["bn_mode"] == "inplace": 30 | assert network_config["activation"] in ( 31 | "leaky_relu", 32 | "elu", 33 | "identity", 34 | ), "Inplace batch normalization is only compatible with leaky_relu, elu and identity" 35 | activation_fn = partial( 36 | InPlaceABN, 37 | activation=network_config["activation"], 38 | 
activation_param=network_config["activation_param"], 39 | ) 40 | elif network_config["bn_mode"] == "sync": 41 | assert network_config["activation"] in ( 42 | "leaky_relu", 43 | "elu", 44 | "identity", 45 | ), "Sync batch normalization is only compatible with leaky_relu, elu and identity" 46 | activation_fn = partial( 47 | InPlaceABNSync, 48 | activation=network_config["activation"], 49 | activation_param=network_config["activation_param"], 50 | ) 51 | else: 52 | print("Unrecognized batch normalization mode", network_config["bn_mode"]) 53 | exit(1) 54 | 55 | return activation_fn 56 | 57 | 58 | def get_model_params(network_config): 59 | """Convert a configuration to actual model parameters 60 | 61 | Parameters 62 | ---------- 63 | network_config : dict 64 | Dictionary containing the configuration options for the network. 65 | 66 | Returns 67 | ------- 68 | model_params : dict 69 | Dictionary containing the actual parameters to be passed to the `net_*` functions in `models`. 70 | """ 71 | model_params = {} 72 | if network_config["input_3x3"] and not network_config["arch"].startswith("wider"): 73 | model_params["input_3x3"] = True 74 | model_params["norm_act"] = _get_norm_act(network_config) 75 | model_params["classes"] = network_config["classes"] 76 | if not network_config["arch"].startswith("wider"): 77 | model_params["dilation"] = network_config["dilation"] 78 | return model_params 79 | 80 | 81 | def create_optimizer(optimizer_config, model): 82 | """Creates optimizer and schedule from configuration 83 | 84 | Parameters 85 | ---------- 86 | optimizer_config : dict 87 | Dictionary containing the configuration options for the optimizer. 88 | model : Model 89 | The network model. 90 | 91 | Returns 92 | ------- 93 | optimizer : Optimizer 94 | The optimizer. 95 | scheduler : LRScheduler 96 | The learning rate scheduler. 
97 | """ 98 | if optimizer_config["classifier_lr"] != -1: 99 | # Separate classifier parameters from all others 100 | net_params = [] 101 | classifier_params = [] 102 | for k, v in model.named_parameters(): 103 | if k.find("fc") != -1: 104 | classifier_params.append(v) 105 | else: 106 | net_params.append(v) 107 | params = [ 108 | {"params": net_params}, 109 | {"params": classifier_params, "lr": optimizer_config["classifier_lr"]}, 110 | ] 111 | else: 112 | params = model.parameters() 113 | 114 | if optimizer_config["type"] == "SGD": 115 | optimizer = optim.SGD( 116 | params, 117 | lr=optimizer_config["learning_rate"], 118 | momentum=optimizer_config["momentum"], 119 | weight_decay=optimizer_config["weight_decay"], 120 | nesterov=optimizer_config["nesterov"], 121 | ) 122 | elif optimizer_config["type"] == "Adam": 123 | optimizer = optim.Adam( 124 | params, 125 | lr=optimizer_config["learning_rate"], 126 | weight_decay=optimizer_config["weight_decay"], 127 | ) 128 | else: 129 | raise KeyError("unrecognized optimizer {}".format(optimizer_config["type"])) 130 | 131 | if optimizer_config["schedule"]["type"] == "step": 132 | scheduler = lr_scheduler.StepLR( 133 | optimizer, **optimizer_config["schedule"]["params"] 134 | ) 135 | elif optimizer_config["schedule"]["type"] == "multistep": 136 | scheduler = lr_scheduler.MultiStepLR( 137 | optimizer, **optimizer_config["schedule"]["params"] 138 | ) 139 | elif optimizer_config["schedule"]["type"] == "exponential": 140 | scheduler = lr_scheduler.ExponentialLR( 141 | optimizer, **optimizer_config["schedule"]["params"] 142 | ) 143 | elif optimizer_config["schedule"]["type"] == "constant": 144 | scheduler = lr_scheduler.LambdaLR(optimizer, lambda epoch: 1.0) 145 | elif optimizer_config["schedule"]["type"] == "linear": 146 | 147 | def linear_lr(it): 148 | return ( 149 | it * optimizer_config["schedule"]["params"]["alpha"] 150 | + optimizer_config["schedule"]["params"]["beta"] 151 | ) 152 | 153 | scheduler = 
lr_scheduler.LambdaLR(optimizer, linear_lr) 154 | 155 | return optimizer, scheduler 156 | 157 | 158 | def create_transforms(input_config): 159 | """Create transforms from configuration 160 | 161 | Parameters 162 | ---------- 163 | input_config : dict 164 | Dictionary containing the configuration options for input pre-processing. 165 | 166 | Returns 167 | ------- 168 | train_transforms : list 169 | List of transforms to be applied to the input during training. 170 | val_transforms : list 171 | List of transforms to be applied to the input during validation. 172 | """ 173 | normalize = transforms.Normalize(mean=input_config["mean"], std=input_config["std"]) 174 | 175 | train_transforms = [] 176 | if input_config["scale_train"] != -1: 177 | train_transforms.append(transforms.Scale(input_config["scale_train"])) 178 | train_transforms += [ 179 | transforms.RandomResizedCrop(input_config["crop_train"]), 180 | transforms.RandomHorizontalFlip(), 181 | transforms.ToTensor(), 182 | ] 183 | if input_config["color_jitter_train"]: 184 | train_transforms.append(ColorJitter()) 185 | if input_config["lighting_train"]: 186 | train_transforms.append(Lighting()) 187 | train_transforms.append(normalize) 188 | 189 | val_transforms = [] 190 | if input_config["scale_val"] != -1: 191 | val_transforms.append(transforms.Resize(input_config["scale_val"])) 192 | val_transforms += [ 193 | transforms.CenterCrop(input_config["crop_val"]), 194 | transforms.ToTensor(), 195 | normalize, 196 | ] 197 | 198 | return train_transforms, val_transforms 199 | 200 | 201 | def create_test_transforms(config, crop, scale, ten_crops): 202 | normalize = transforms.Normalize(mean=config["mean"], std=config["std"]) 203 | 204 | val_transforms = [] 205 | if scale != -1: 206 | val_transforms.append(transforms.Resize(scale)) 207 | if ten_crops: 208 | val_transforms += [ 209 | transforms.TenCrop(crop), 210 | transforms.Lambda( 211 | lambda crops: [transforms.ToTensor()(crop) for crop in crops] 212 | ), 213 | 
transforms.Lambda(lambda crops: [normalize(crop) for crop in crops]), 214 | transforms.Lambda(lambda crops: torch.stack(crops)), 215 | ] 216 | else: 217 | val_transforms += [ 218 | transforms.CenterCrop(crop), 219 | transforms.ToTensor(), 220 | normalize, 221 | ] 222 | 223 | return val_transforms 224 | 225 | 226 | class AverageMeter: 227 | """Computes and stores the average and current value""" 228 | 229 | def __init__(self): 230 | self.reset() 231 | 232 | def reset(self): 233 | self.val = 0 234 | self.avg = 0 235 | self.sum = 0 236 | self.count = 0 237 | 238 | def update(self, val, n=1): 239 | self.val = val 240 | self.sum += val * n 241 | self.count += n 242 | self.avg = self.sum / self.count 243 | 244 | 245 | def accuracy_sum(output, target, topk=(1,)): 246 | """Computes the precision@k for the specified values of k""" 247 | maxk = max(topk) 248 | 249 | _, pred = output.topk(maxk, 1, True, True) 250 | pred = pred.t() 251 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 252 | 253 | res = [] 254 | for k in topk: 255 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 256 | res.append(correct_k.mul_(100.0)) 257 | return res 258 | 259 | 260 | def validate( 261 | val_loader, 262 | model, 263 | criterion, 264 | ten_crops=False, 265 | print_freq=1, 266 | it=None, 267 | tb=None, 268 | logger=print, 269 | ): 270 | batch_time = AverageMeter() 271 | losses = AverageMeter() 272 | top1 = AverageMeter() 273 | top5 = AverageMeter() 274 | 275 | # switch to evaluate mode 276 | model.eval() 277 | 278 | end = time.time() 279 | 280 | rank = dist.get_rank() if dist.is_initialized() else 0 281 | world_size = dist.get_world_size() if dist.is_initialized() else 1 282 | do_print = rank == 0 283 | 284 | def process(input, target, all_reduce=None): 285 | with torch.no_grad(): 286 | if ten_crops: 287 | bs, ncrops, c, h, w = input.size() 288 | input = input.view(-1, c, h, w) 289 | 290 | target = target.cuda(non_blocking=True) 291 | 292 | # compute output 293 | if 
ten_crops: 294 | output = model(input).view(bs, ncrops, -1).mean(1) 295 | else: 296 | output = model(input) 297 | loss = criterion(output, target) 298 | 299 | # measure accuracy and record loss 300 | prec1, prec5 = accuracy_sum(output.data, target, topk=(1, 5)) 301 | 302 | loss *= target.shape[0] 303 | count = target.new_tensor([target.shape[0]], dtype=torch.long) 304 | if all_reduce: 305 | all_reduce(count) 306 | for meter, val in (losses, loss), (top1, prec1), (top5, prec5): 307 | if all_reduce: 308 | all_reduce(val) 309 | val /= count.item() 310 | meter.update(val.item(), count.item()) 311 | 312 | # deal with remainder 313 | all_reduce = ( 314 | partial(dist.all_reduce, op=dist.ReduceOp.SUM) 315 | if dist.is_initialized() 316 | else None 317 | ) 318 | last_group_size = len(val_loader.dataset) % world_size 319 | for i, (input, target) in enumerate(val_loader): 320 | if input.shape[0] > 1 or last_group_size == 0: 321 | process(input, target, all_reduce) 322 | else: 323 | process( 324 | input, 325 | target, 326 | partial( 327 | dist.all_reduce, 328 | op=dist.ReduceOp.SUM, 329 | group=dist.new_group(range(last_group_size)), 330 | ), 331 | ) 332 | 333 | # measure elapsed time 334 | batch_time.update(time.time() - end) 335 | end = time.time() 336 | 337 | if do_print and i % print_freq == 0: 338 | logger( 339 | "Test: [{0}/{1}]\t" 340 | "Time {batch_time.val:.3f} ({batch_time.avg:.3f}) \t" 341 | "Loss {loss.val:.4f} ({loss.avg:.4f}) \t" 342 | "Prec@1 {top1.val:.3f} ({top1.avg:.3f}) \t" 343 | "Prec@5 {top5.val:.3f} ({top5.avg:.3f})".format( 344 | i, 345 | len(val_loader), 346 | batch_time=batch_time, 347 | loss=losses, 348 | top1=top1, 349 | top5=top5, 350 | ) 351 | ) 352 | if input.shape[0] == 1 and rank > last_group_size > 0: 353 | dist.new_group(range(last_group_size)) 354 | 355 | if do_print: 356 | logger( 357 | " * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}".format( 358 | top1=top1, top5=top5 359 | ) 360 | ) 361 | 362 | if it is not None and (not 
dist.is_initialized() or dist.get_rank() == 0): 363 | tb.add_scalar("val/loss", losses.avg, it) 364 | tb.add_scalar("val/top1", top1.avg, it) 365 | tb.add_scalar("val/top5", top5.avg, it) 366 | 367 | return top1.avg 368 | -------------------------------------------------------------------------------- /scripts/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .densenet import * 2 | from .resnet import * 3 | from .resnext import * 4 | from .wider_resnet import * 5 | -------------------------------------------------------------------------------- /scripts/models/densenet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | import sys 4 | from collections import OrderedDict 5 | from functools import partial 6 | 7 | import torch.nn as nn 8 | from inplace_abn import ABN 9 | from modules import GlobalAvgPool2d, DenseModule 10 | 11 | from .util import try_index 12 | 13 | 14 | class DenseNet(nn.Module): 15 | def __init__( 16 | self, 17 | structure, 18 | norm_act=ABN, 19 | input_3x3=False, 20 | growth=32, 21 | theta=0.5, 22 | classes=0, 23 | dilation=1, 24 | ): 25 | """DenseNet 26 | 27 | Parameters 28 | ---------- 29 | structure : list of int 30 | Number of layers in each of the four dense blocks of the network. 31 | norm_act : callable 32 | Function to create normalization / activation Module. 33 | input_3x3 : bool 34 | If `True` use three `3x3` convolutions in the input module instead of a single `7x7` one. 35 | growth : int 36 | Number of channels in each layer, i.e. the "growth" factor of the DenseNet. 37 | theta : float 38 | Reduction factor for the transition blocks. 39 | classes : int 40 | If not `0` also include global average pooling and a fully-connected layer with `classes` outputs at the end 41 | of the network. 42 | dilation : int or list of int 43 | List of dilation factors, or `1` to ignore dilation. 
If the dilation factor for a module is greater than `1` 44 | skip the pooling in the transition block right before it. 45 | """ 46 | super(DenseNet, self).__init__() 47 | self.structure = structure 48 | if len(structure) != 4: 49 | raise ValueError("Expected a structure with four values") 50 | 51 | # Initial layers 52 | if input_3x3: 53 | layers = [ 54 | ("conv1", nn.Conv2d(3, growth * 2, 3, stride=2, padding=1, bias=False)), 55 | ("bn1", norm_act(growth * 2)), 56 | ( 57 | "conv2", 58 | nn.Conv2d( 59 | growth * 2, growth * 2, 3, stride=1, padding=1, bias=False 60 | ), 61 | ), 62 | ("bn2", norm_act(growth * 2)), 63 | ( 64 | "conv3", 65 | nn.Conv2d( 66 | growth * 2, growth * 2, 3, stride=1, padding=1, bias=False 67 | ), 68 | ), 69 | ("pool", nn.MaxPool2d(3, stride=2, padding=1)), 70 | ] 71 | else: 72 | layers = [ 73 | ("conv1", nn.Conv2d(3, growth * 2, 7, stride=2, padding=3, bias=False)), 74 | ("pool", nn.MaxPool2d(3, stride=2, padding=1)), 75 | ] 76 | self.mod1 = nn.Sequential(OrderedDict(layers)) 77 | 78 | in_channels = growth * 2 79 | for mod_id in range(4): 80 | d = try_index(dilation, mod_id) 81 | s = 2 if d == 1 and mod_id > 0 else 1 82 | 83 | # Create transition module 84 | if mod_id > 0: 85 | out_channels = int(in_channels * theta) 86 | layers = [ 87 | ("bn", norm_act(in_channels)), 88 | ("conv", nn.Conv2d(in_channels, out_channels, 1, bias=False)), 89 | ] 90 | if s == 2: 91 | layers.append(("pool", nn.AvgPool2d(2, 2))) 92 | self.add_module( 93 | "tra%d" % (mod_id + 1), nn.Sequential(OrderedDict(layers)) 94 | ) 95 | in_channels = out_channels 96 | 97 | # Create dense module 98 | mod = DenseModule( 99 | in_channels, growth, structure[mod_id], norm_act=norm_act, dilation=d 100 | ) 101 | self.add_module("mod%d" % (mod_id + 2), mod) 102 | in_channels = mod.out_channels 103 | 104 | # Pooling and predictor 105 | self.bn_out = norm_act(in_channels) 106 | if classes != 0: 107 | self.classifier = nn.Sequential( 108 | OrderedDict( 109 | [ 110 | ("avg_pool", 
GlobalAvgPool2d()), 111 | ("fc", nn.Linear(in_channels, classes)), 112 | ] 113 | ) 114 | ) 115 | 116 | def forward(self, x): 117 | x = self.mod1(x) 118 | x = self.mod2(x) 119 | x = self.tra2(x) 120 | x = self.mod3(x) 121 | x = self.tra3(x) 122 | x = self.mod4(x) 123 | x = self.tra4(x) 124 | x = self.mod5(x) 125 | x = self.bn_out(x) 126 | 127 | if hasattr(self, "classifier"): 128 | x = self.classifier(x) 129 | return x 130 | 131 | 132 | _NETS = { 133 | "121": {"structure": [6, 12, 24, 16]}, 134 | "169": {"structure": [6, 12, 32, 32]}, 135 | "201": {"structure": [6, 12, 48, 32]}, 136 | "264": {"structure": [6, 12, 64, 48]}, 137 | } 138 | 139 | __all__ = [] 140 | for name, params in _NETS.items(): 141 | net_name = "net_densenet" + name 142 | setattr(sys.modules[__name__], net_name, partial(DenseNet, **params)) 143 | __all__.append(net_name) 144 | -------------------------------------------------------------------------------- /scripts/models/resnet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
import sys
from collections import OrderedDict
from functools import partial

import torch.nn as nn
from inplace_abn import ABN
from modules import GlobalAvgPool2d, ResidualBlock

from .util import try_index


class ResNet(nn.Module):
    """Standard residual network

    Parameters
    ----------
    structure : list of int
        Number of residual blocks in each of the four modules of the network
    bottleneck : bool
        If `True` use "bottleneck" residual blocks with 3 convolutions, otherwise use standard blocks
    norm_act : callable
        Function to create normalization / activation Module
    classes : int
        If not `0` also include global average pooling and a fully-connected layer with `classes` outputs at the end
        of the network
    dilation : int or list of int
        List of dilation factors for the four modules of the network, or `1` to ignore dilation
    keep_outputs : bool
        If `True` output a list with the outputs of all modules

    Raises
    ------
    ValueError
        If `structure` does not have four entries, or `dilation` is neither `1` nor a sequence of four values.
    """

    def __init__(
        self,
        structure,
        bottleneck,
        norm_act=ABN,
        classes=0,
        dilation=1,
        keep_outputs=False,
    ):
        super(ResNet, self).__init__()
        self.structure = structure
        self.bottleneck = bottleneck
        self.dilation = dilation
        self.keep_outputs = keep_outputs

        if len(structure) != 4:
            raise ValueError("Expected a structure with four values")
        # A scalar other than 1 has no len(): reject it with the intended ValueError instead of
        # letting len() raise an unrelated TypeError.
        if dilation != 1 and (not hasattr(dilation, "__len__") or len(dilation) != 4):
            raise ValueError("If dilation is not 1 it must contain four values")

        # Initial layers: 7x7 stride-2 conv + norm/act, then a stride-2 max-pool unless the first
        # residual module is dilated
        layers = [
            ("conv1", nn.Conv2d(3, 64, 7, stride=2, padding=3, bias=False)),
            ("bn1", norm_act(64)),
        ]
        if try_index(dilation, 0) == 1:
            layers.append(("pool1", nn.MaxPool2d(3, stride=2, padding=1)))
        self.mod1 = nn.Sequential(OrderedDict(layers))

        # Groups of residual blocks
        in_channels = 64
        if self.bottleneck:
            channels = (64, 64, 256)
        else:
            channels = (64, 64)
        for mod_id, num in enumerate(structure):
            # Create blocks for module
            blocks = []
            for block_id in range(num):
                stride, dil = self._stride_dilation(dilation, mod_id, block_id)
                blocks.append(
                    (
                        "block%d" % (block_id + 1),
                        ResidualBlock(
                            in_channels,
                            channels,
                            norm_act=norm_act,
                            stride=stride,
                            dilation=dil,
                        ),
                    )
                )

                # The next block consumes this block's output channels
                in_channels = channels[-1]

            # Create module (registered as mod2 ... mod5)
            self.add_module("mod%d" % (mod_id + 2), nn.Sequential(OrderedDict(blocks)))

            # Double the number of channels for the next module
            channels = [c * 2 for c in channels]

        # Pooling and predictor
        if classes != 0:
            self.classifier = nn.Sequential(
                OrderedDict(
                    [
                        ("avg_pool", GlobalAvgPool2d()),
                        ("fc", nn.Linear(in_channels, classes)),
                    ]
                )
            )

    @staticmethod
    def _stride_dilation(dilation, mod_id, block_id):
        """Return `(stride, dilation)` for one block.

        Only the first block of modules 2-4 (mod_id > 0) down-samples, and only when that module is
        not dilated; dilated modules keep stride 1.
        """
        d = try_index(dilation, mod_id)
        s = 2 if d == 1 and block_id == 0 and mod_id > 0 else 1
        return s, d

    def forward(self, x):
        """Run the network; return the last output, or all module outputs if `keep_outputs`."""
        outs = list()

        outs.append(self.mod1(x))
        outs.append(self.mod2(outs[-1]))
        outs.append(self.mod3(outs[-1]))
        outs.append(self.mod4(outs[-1]))
        outs.append(self.mod5(outs[-1]))

        if hasattr(self, "classifier"):
            outs.append(self.classifier(outs[-1]))

        if self.keep_outputs:
            return outs
        else:
            return outs[-1]


# Standard depths: blocks per module and whether bottleneck blocks are used
_NETS = {
    "18": {"structure": [2, 2, 2, 2], "bottleneck": False},
    "34": {"structure": [3, 4, 6, 3], "bottleneck": False},
    "50": {"structure": [3, 4, 6, 3], "bottleneck": True},
    "101": {"structure": [3, 4, 23, 3], "bottleneck": True},
    "152": {"structure": [3, 8, 36, 3], "bottleneck": True},
}

# Expose one factory per depth as module attributes named net_resnet<depth>
__all__ = []
for name, params in _NETS.items():
    net_name = "net_resnet" + name
    setattr(sys.modules[__name__], net_name, partial(ResNet, **params))
    __all__.append(net_name)
# -------------------------------------------------------------------------------- /scripts/models/resnext.py: --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.

import sys
from collections import OrderedDict
from functools import partial

import torch.nn as nn
from inplace_abn import ABN
from modules import IdentityResidualBlock, GlobalAvgPool2d

from .util import try_index


class ResNeXt(nn.Module):
    def __init__(
        self,
        structure,
        groups=64,
        norm_act=ABN,
        input_3x3=False,
        classes=0,
        dilation=1,
        base_channels=(128, 128, 256),
    ):
        """Pre-activation (identity mapping) ResNeXt model

        Parameters
        ----------
        structure : list of int
            Number of residual blocks in each of the four modules of the network.
        groups : int
            Number of groups in each ResNeXt block
        norm_act : callable
            Function to create normalization / activation Module.
        input_3x3 : bool
            If `True` use three `3x3` convolutions in the input module instead of a single `7x7` one.
        classes : int
            If not `0` also include global average pooling and a fully-connected layer with `classes` outputs at the end
            of the network.
        dilation : list of list of int or list of int or int
            List of dilation factors, or `1` to ignore dilation. For each module, if a single value is given it is
            used for all its blocks, otherwise this expects a value for each block.
        base_channels : list of int
            Channels in the blocks of the first residual module. Each following module will multiply these values by 2.

        Raises
        ------
        ValueError
            If `structure` does not have four entries, or `dilation` is neither `1` nor a sequence of four values.
        """
        super(ResNeXt, self).__init__()
        self.structure = structure

        if len(structure) != 4:
            raise ValueError("Expected a structure with four values")
        # A scalar other than 1 has no len(): reject it with the intended ValueError instead of
        # letting len() raise an unrelated TypeError.
        if dilation != 1 and (not hasattr(dilation, "__len__") or len(dilation) != 4):
            raise ValueError("If dilation is not 1 it must contain four values")

        # Initial layers. There is no norm/act after the last stem conv: each
        # IdentityResidualBlock normalizes its own input (pre-activation design).
        if input_3x3:
            layers = [
                ("conv1", nn.Conv2d(3, 64, 3, stride=2, padding=1, bias=False)),
                ("bn1", norm_act(64)),
                ("conv2", nn.Conv2d(64, 64, 3, stride=1, padding=1, bias=False)),
                ("bn2", norm_act(64)),
                ("conv3", nn.Conv2d(64, 64, 3, stride=1, padding=1, bias=False)),
                ("pool", nn.MaxPool2d(3, stride=2, padding=1)),
            ]
        else:
            layers = [
                ("conv1", nn.Conv2d(3, 64, 7, stride=2, padding=3, bias=False)),
                ("pool", nn.MaxPool2d(3, stride=2, padding=1)),
            ]
        self.mod1 = nn.Sequential(OrderedDict(layers))

        # Groups of residual blocks
        in_channels = 64
        channels = base_channels
        for mod_id, num in enumerate(structure):
            # Create blocks for module
            blocks = []
            for block_id in range(num):
                s, d = self._stride_dilation(mod_id, block_id, dilation)
                blocks.append(
                    (
                        "block%d" % (block_id + 1),
                        IdentityResidualBlock(
                            in_channels,
                            channels,
                            stride=s,
                            norm_act=norm_act,
                            groups=groups,
                            dilation=d,
                        ),
                    )
                )

                # The next block consumes this block's output channels
                in_channels = channels[-1]

            # Create and add module (registered as mod2 ... mod5)
            self.add_module("mod%d" % (mod_id + 2), nn.Sequential(OrderedDict(blocks)))
            channels = [c * 2 for c in channels]

        # Final norm/act (needed with pre-activation blocks), pooling and predictor
        self.bn_out = norm_act(in_channels)
        if classes != 0:
            self.classifier = nn.Sequential(
                OrderedDict(
                    [
                        ("avg_pool", GlobalAvgPool2d()),
                        ("fc", nn.Linear(in_channels, classes)),
                    ]
                )
            )

    def forward(self, img):
        """Run the stem, the four residual modules, the final norm/act and the optional classifier."""
        out = self.mod1(img)
        out = self.mod2(out)
        out = self.mod3(out)
        out = self.mod4(out)
        out = self.mod5(out)
        out = self.bn_out(out)

        if hasattr(self, "classifier"):
            out = self.classifier(out)

        return out

    @staticmethod
    def _stride_dilation(mod_id, block_id, dilation):
        """Return `(stride, dilation)` for one block.

        Non-dilated modules after the first down-sample in their first block. Dilated modules never
        down-sample and take either a per-module or a per-block dilation value.
        """
        if dilation == 1:
            s = 2 if mod_id > 0 and block_id == 0 else 1
            d = 1
        else:
            if dilation[mod_id] == 1:
                s = 2 if mod_id > 0 and block_id == 0 else 1
                d = 1
            else:
                s = 1
                d = try_index(dilation[mod_id], block_id)
        return s, d


# Standard depths: blocks per module
_NETS = {
    "50": {"structure": [3, 4, 6, 3]},
    "101": {"structure": [3, 4, 23, 3]},
    "152": {"structure": [3, 8, 36, 3]},
}

# Expose one factory per depth as module attributes named net_resnext<depth>
__all__ = []
for name, params in _NETS.items():
    net_name = "net_resnext" + name
    setattr(sys.modules[__name__], net_name, partial(ResNeXt, **params))
    __all__.append(net_name)
# -------------------------------------------------------------------------------- /scripts/models/util.py: --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.


def try_index(scalar_or_list, i):
    """Return `scalar_or_list[i]`, or the value itself if it is not indexable."""
    try:
        return scalar_or_list[i]
    except TypeError:
        return scalar_or_list
# -------------------------------------------------------------------------------- /scripts/models/wider_resnet.py: --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
import sys
from collections import OrderedDict
from functools import partial

import torch.nn as nn
from inplace_abn import ABN
from modules import IdentityResidualBlock, GlobalAvgPool2d


class WiderResNet(nn.Module):
    def __init__(self, structure, norm_act=ABN, classes=0):
        """Wider ResNet with pre-activation (identity mapping) blocks

        Down-sampling is done by a stride-2 max-pooling in front of each of the first five residual
        modules (pool2 ... pool6), for a total output stride of 32.

        Parameters
        ----------
        structure : list of int
            Number of residual blocks in each of the six modules of the network.
        norm_act : callable
            Function to create normalization / activation Module.
        classes : int
            If not `0` also include global average pooling and a fully-connected layer with `classes` outputs at the end
            of the network.

        Raises
        ------
        ValueError
            If `structure` does not have six entries.
        """
        super(WiderResNet, self).__init__()
        self.structure = structure

        if len(structure) != 6:
            raise ValueError("Expected a structure with six values")

        # Initial layers: a single stride-1 3x3 conv (no norm/act, the first block normalizes it)
        self.mod1 = nn.Sequential(
            OrderedDict(
                [("conv1", nn.Conv2d(3, 64, 3, stride=1, padding=1, bias=False))]
            )
        )

        # Groups of residual blocks. Each tuple lists the internal channels of the blocks in one
        # module; the last entry is that module's output width.
        in_channels = 64
        channels = [
            (128, 128),
            (256, 256),
            (512, 512),
            (512, 1024),
            (512, 1024, 2048),
            (1024, 2048, 4096),
        ]
        for mod_id, num in enumerate(structure):
            # Create blocks for module
            blocks = []
            for block_id in range(num):
                blocks.append(
                    (
                        "block%d" % (block_id + 1),
                        IdentityResidualBlock(
                            in_channels, channels[mod_id], norm_act=norm_act
                        ),
                    )
                )

                # The next block consumes this block's output channels
                in_channels = channels[mod_id][-1]

            # Create module; the first five modules are preceded by a stride-2 max-pooling
            if mod_id <= 4:
                self.add_module(
                    "pool%d" % (mod_id + 2), nn.MaxPool2d(3, stride=2, padding=1)
                )
            self.add_module("mod%d" % (mod_id + 2), nn.Sequential(OrderedDict(blocks)))

        # Final norm/act (needed with pre-activation blocks), pooling and predictor
        self.bn_out = norm_act(in_channels)
        if classes != 0:
            self.classifier = nn.Sequential(
                OrderedDict(
                    [
                        ("avg_pool", GlobalAvgPool2d()),
                        ("fc", nn.Linear(in_channels, classes)),
                    ]
                )
            )

    def forward(self, img):
        # pool2 ... pool6 each halve the resolution in front of mod2 ... mod6; mod7 keeps it
        out = self.mod1(img)
        out = self.mod2(self.pool2(out))
        out = self.mod3(self.pool3(out))
        out = self.mod4(self.pool4(out))
        out = self.mod5(self.pool5(out))
        out = self.mod6(self.pool6(out))
        out = self.mod7(out)
        out = self.bn_out(out)

        if hasattr(self, "classifier"):
            out = self.classifier(out)

        return out


class WiderResNetA2(nn.Module):
    def __init__(self, structure, norm_act=ABN, classes=0, dilation=False):
        """Wider ResNet with pre-activation (identity mapping) blocks

        This variant uses down-sampling by max-pooling in the first two blocks and by strided convolution in the others.
        The last two modules also insert `Dropout2d` (p=0.3 and p=0.5 respectively) inside their blocks.

        Parameters
        ----------
        structure : list of int
            Number of residual blocks in each of the six modules of the network.
        norm_act : callable
            Function to create normalization / activation Module.
        classes : int
            If not `0` also include global average pooling and a fully-connected layer with `classes` outputs at the end
            of the network.
        dilation : bool
            If `True` apply dilation to the last three modules and change the down-sampling factor from 32 to 8.

        Raises
        ------
        ValueError
            If `structure` does not have six entries.
        """
        super(WiderResNetA2, self).__init__()
        self.structure = structure
        self.dilation = dilation

        if len(structure) != 6:
            raise ValueError("Expected a structure with six values")

        # Initial layers: a single stride-1 3x3 conv (no norm/act, the first block normalizes it)
        self.mod1 = nn.Sequential(
            OrderedDict(
                [("conv1", nn.Conv2d(3, 64, 3, stride=1, padding=1, bias=False))]
            )
        )

        # Groups of residual blocks. Each tuple lists the internal channels of the blocks in one
        # module; the last entry is that module's output width.
        in_channels = 64
        channels = [
            (128, 128),
            (256, 256),
            (512, 512),
            (512, 1024),
            (512, 1024, 2048),
            (1024, 2048, 4096),
        ]
        for mod_id, num in enumerate(structure):
            # Create blocks for module
            blocks = []
            for block_id in range(num):
                if not dilation:
                    # Strided first block in mod4, mod5, mod6 (mod_id 2..4)
                    dil = 1
                    stride = 2 if block_id == 0 and 2 <= mod_id <= 4 else 1
                else:
                    # Only mod4 (mod_id 2) down-samples; later modules use dilation 2, then 4
                    if mod_id == 3:
                        dil = 2
                    elif mod_id > 3:
                        dil = 4
                    else:
                        dil = 1
                    stride = 2 if block_id == 0 and mod_id == 2 else 1

                # Dropout inside the blocks of the last two modules only
                if mod_id == 4:
                    drop = partial(nn.Dropout2d, p=0.3)
                elif mod_id == 5:
                    drop = partial(nn.Dropout2d, p=0.5)
                else:
                    drop = None

                blocks.append(
                    (
                        "block%d" % (block_id + 1),
                        IdentityResidualBlock(
                            in_channels,
                            channels[mod_id],
                            norm_act=norm_act,
                            stride=stride,
                            dilation=dil,
                            dropout=drop,
                        ),
                    )
                )

                # The next block consumes this block's output channels
                in_channels = channels[mod_id][-1]

            # Create module; only the first two modules are preceded by a max-pooling
            if mod_id < 2:
                self.add_module(
                    "pool%d" % (mod_id + 2), nn.MaxPool2d(3, stride=2, padding=1)
                )
            self.add_module("mod%d" % (mod_id + 2), nn.Sequential(OrderedDict(blocks)))

        # Final norm/act (needed with pre-activation blocks), pooling and predictor
        self.bn_out = norm_act(in_channels)
        if classes != 0:
            self.classifier = nn.Sequential(
                OrderedDict(
                    [
                        ("avg_pool", GlobalAvgPool2d()),
                        ("fc", nn.Linear(in_channels, classes)),
                    ]
                )
            )

    def forward(self, img):
        # Only mod2 and mod3 are preceded by pooling; later down-sampling is done by strided blocks
        out = self.mod1(img)
        out = self.mod2(self.pool2(out))
        out = self.mod3(self.pool3(out))
        out = self.mod4(out)
        out = self.mod5(out)
        out = self.mod6(out)
        out = self.mod7(out)
        out = self.bn_out(out)

        if hasattr(self, "classifier"):
            return self.classifier(out)
        else:
            return out


# Standard depths: blocks per module
_NETS = {
    "16": {"structure": [1, 1, 1, 1, 1, 1]},
    "20": {"structure": [1, 1, 1, 3, 1, 1]},
    "38": {"structure": [3, 3, 6, 3, 1, 1]},
}

# Expose factories net_wider_resnet<depth> and net_wider_resnet<depth>_a2 as module attributes
__all__ = []
for name, params in _NETS.items():
    net_name = "net_wider_resnet" + name
    setattr(sys.modules[__name__], net_name, partial(WiderResNet, **params))
    __all__.append(net_name)
for name, params in _NETS.items():
    net_name = "net_wider_resnet" + name + "_a2"
    setattr(sys.modules[__name__], net_name, partial(WiderResNetA2, **params))
    __all__.append(net_name)
# -------------------------------------------------------------------------------- /scripts/modules/__init__.py: --------------------------------------------------------------------------------
from .deeplab import DeeplabV3
from .dense import DenseModule
from .misc import GlobalAvgPool2d, SingleGPU
from .residual import IdentityResidualBlock, ResidualBlock

# -------------------------------------------------------------------------------- /scripts/modules/deeplab.py: --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
import torch
import torch.nn as nn
import torch.nn.functional as functional
from inplace_abn import ABN
from models.util import try_index


class DeeplabV3(nn.Module):
    """DeepLabV3-style segmentation head.

    Four parallel convolutions (one 1x1 and three dilated 3x3) are concatenated, normalized and
    reduced to `out_channels`; a reduced global-pooling branch is then added before a final
    normalization / activation.

    Parameters
    ----------
    in_channels : int
        Channels of the input feature map.
    out_channels : int
        Channels of the output feature map.
    hidden_channels : int
        Channels produced by each parallel branch and by the pooling branch.
    dilations : sequence of int
        Dilation (and matching padding) of the three 3x3 branches.
    norm_act : callable
        Function to create normalization / activation Module. The created module must expose
        `activation` and `activation_param`, which drive the weight initialization.
    pooling_size : int or pair of int or None
        If set, in eval mode the pooling branch is a stride-1 average pooling with this window
        (clipped to the input size), replicate-padded back to the input size, instead of a true
        global average.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        hidden_channels=256,
        dilations=(12, 24, 36),
        norm_act=ABN,
        pooling_size=None,
    ):
        super(DeeplabV3, self).__init__()
        self.pooling_size = pooling_size

        self.map_convs = nn.ModuleList(
            [
                nn.Conv2d(in_channels, hidden_channels, 1, bias=False),
                nn.Conv2d(
                    in_channels,
                    hidden_channels,
                    3,
                    bias=False,
                    dilation=dilations[0],
                    padding=dilations[0],
                ),
                nn.Conv2d(
                    in_channels,
                    hidden_channels,
                    3,
                    bias=False,
                    dilation=dilations[1],
                    padding=dilations[1],
                ),
                nn.Conv2d(
                    in_channels,
                    hidden_channels,
                    3,
                    bias=False,
                    dilation=dilations[2],
                    padding=dilations[2],
                ),
            ]
        )
        self.map_bn = norm_act(hidden_channels * 4)

        self.global_pooling_conv = nn.Conv2d(
            in_channels, hidden_channels, 1, bias=False
        )
        self.global_pooling_bn = norm_act(hidden_channels)

        self.red_conv = nn.Conv2d(hidden_channels * 4, out_channels, 1, bias=False)
        self.pool_red_conv = nn.Conv2d(hidden_channels, out_channels, 1, bias=False)
        self.red_bn = norm_act(out_channels)

        self.reset_parameters(self.map_bn.activation, self.map_bn.activation_param)

    def reset_parameters(self, activation, slope):
        """Xavier-normal init for all convs (gain from `activation`/`slope`); ABN weight=1, bias=0."""
        # `nn.init.calculate_gain` has no "identity" entry and would raise; the equivalent
        # nonlinearity there is "linear" (gain 1).
        if activation == "identity":
            activation = "linear"
        gain = nn.init.calculate_gain(activation, slope)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight.data, gain)
                if hasattr(m, "bias") and m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, ABN):
                if hasattr(m, "weight") and m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if hasattr(m, "bias") and m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        # Map convolutions
        out = torch.cat([m(x) for m in self.map_convs], dim=1)
        out = self.map_bn(out)
        out = self.red_conv(out)

        # Global pooling
        pool = self._global_pooling(x)
        pool = self.global_pooling_conv(pool)
        pool = self.global_pooling_bn(pool)
        pool = self.pool_red_conv(pool)
        if self.training or self.pooling_size is None:
            # The true global pool is 1x1: tile it over the input's spatial size before adding
            pool = pool.repeat(1, 1, x.size(2), x.size(3))

        out += pool
        out = self.red_bn(out)
        return out

    def _global_pooling(self, x):
        if self.training or self.pooling_size is None:
            # Mean over all spatial positions, kept as a 1x1 map. reshape (unlike view) also
            # accepts non-contiguous inputs and computes the same reduction.
            pool = x.reshape(x.size(0), x.size(1), -1).mean(dim=-1)
            pool = pool.view(x.size(0), x.size(1), 1, 1)
        else:
            pooling_size = (
                min(try_index(self.pooling_size, 0), x.shape[2]),
                min(try_index(self.pooling_size, 1), x.shape[3]),
            )
            # functional.pad order is (left, right, top, bottom); total padding k - 1 per axis
            # restores the size lost by a stride-1 pooling with window k
            padding = (
                (pooling_size[1] - 1) // 2,
                (pooling_size[1] - 1) // 2
                if pooling_size[1] % 2 == 1
                else (pooling_size[1] - 1) // 2 + 1,
                (pooling_size[0] - 1) // 2,
                (pooling_size[0] - 1) // 2
                if pooling_size[0] % 2 == 1
                else (pooling_size[0] - 1) // 2 + 1,
            )

            pool = functional.avg_pool2d(x, pooling_size, stride=1)
            pool = functional.pad(pool, pad=padding, mode="replicate")
        return pool
# -------------------------------------------------------------------------------- /scripts/modules/dense.py: --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
from collections import OrderedDict

import torch
import torch.nn as nn
from inplace_abn import ABN


class DenseModule(nn.Module):
    """DenseNet-style module: every layer sees the concatenation of all previous outputs.

    Each layer is a norm/act + 1x1 conv (to `growth * bottleneck_factor` channels) followed by a
    norm/act + 3x3 conv (to `growth` channels). The output concatenates the input with every
    layer's output, so it has `in_channels + growth * layers` channels.
    """

    def __init__(
        self, in_channels, growth, layers, bottleneck_factor=4, norm_act=ABN, dilation=1
    ):
        super(DenseModule, self).__init__()
        self.in_channels = in_channels
        self.growth = growth
        self.layers = layers

        hidden = self.growth * bottleneck_factor
        width = in_channels  # channels seen by the next layer (grows by `growth` per layer)

        self.convs1 = nn.ModuleList()
        self.convs3 = nn.ModuleList()
        for _ in range(self.layers):
            # Construction order (bn, conv, bn, conv) matches registration order.
            reduce_bn = norm_act(width)
            reduce_conv = nn.Conv2d(width, hidden, 1, bias=False)
            self.convs1.append(
                nn.Sequential(OrderedDict([("bn", reduce_bn), ("conv", reduce_conv)]))
            )

            expand_bn = norm_act(hidden)
            expand_conv = nn.Conv2d(
                hidden, self.growth, 3, padding=dilation, bias=False, dilation=dilation
            )
            self.convs3.append(
                nn.Sequential(OrderedDict([("bn", expand_bn), ("conv", expand_conv)]))
            )

            width += self.growth

    @property
    def out_channels(self):
        """Number of channels produced by `forward`."""
        return self.in_channels + self.growth * self.layers

    def forward(self, x):
        features = [x]
        for idx in range(self.layers):
            merged = torch.cat(features, dim=1)
            features.append(self.convs3[idx](self.convs1[idx](merged)))
        return torch.cat(features, dim=1)
# -------------------------------------------------------------------------------- /scripts/modules/misc.py: --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
import torch.nn as nn


class GlobalAvgPool2d(nn.Module):
    def __init__(self):
        """Global average pooling over the input's spatial dimensions"""
        super(GlobalAvgPool2d, self).__init__()

    def forward(self, inputs):
        """Average every position after the first two dims; returns shape (N, C)."""
        in_size = inputs.size()
        # reshape instead of view: same result, but also accepts non-contiguous inputs
        return inputs.reshape((in_size[0], in_size[1], -1)).mean(dim=2)


class SingleGPU(nn.Module):
    """Wrap a module so that its input is moved to the current CUDA device before the call."""

    def __init__(self, module):
        super(SingleGPU, self).__init__()
        self.module = module

    def forward(self, x):
        return self.module(x.cuda(non_blocking=True))
# -------------------------------------------------------------------------------- /scripts/modules/residual.py: --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.

from collections import OrderedDict

import torch.nn as nn
import torch.nn.functional as functional
from inplace_abn import ABN


class ResidualBlock(nn.Module):
    """Configurable residual block

    Parameters
    ----------
    in_channels : int
        Number of input channels.
    channels : list of int
        Number of channels in the internal feature maps. Can either have two or three elements: if two construct
        a residual block with two `3 x 3` convolutions, otherwise construct a bottleneck block with `1 x 1`, then
        `3 x 3` then `1 x 1` convolutions.
    stride : int
        Stride of the first `3 x 3` convolution
    dilation : int
        Dilation to apply to the `3 x 3` convolutions.
    groups : int
        Number of convolution groups. This is used to create ResNeXt-style blocks and is only compatible with
        bottleneck blocks.
    norm_act : callable
        Function to create normalization / activation Module.
    dropout: callable
        Function to create Dropout Module.

    Raises
    ------
    ValueError
        If `channels` does not have two or three entries, or `groups != 1` with two entries.
    """

    def __init__(
        self,
        in_channels,
        channels,
        stride=1,
        dilation=1,
        groups=1,
        norm_act=ABN,
        dropout=None,
    ):
        super(ResidualBlock, self).__init__()

        # Check parameters for inconsistencies
        if len(channels) != 2 and len(channels) != 3:
            raise ValueError("channels must contain either two or three values")
        if len(channels) == 2 and groups != 1:
            raise ValueError("groups > 1 are only valid if len(channels) == 3")

        is_bottleneck = len(channels) == 3
        need_proj_conv = stride != 1 or in_channels != channels[-1]

        # The last norm of the branch has no activation: it is applied in forward() after the
        # residual sum, using bn1's activation settings.
        if not is_bottleneck:
            bn2 = norm_act(channels[1])
            bn2.activation = "identity"
            layers = [
                (
                    "conv1",
                    nn.Conv2d(
                        in_channels,
                        channels[0],
                        3,
                        stride=stride,
                        padding=dilation,
                        bias=False,
                        dilation=dilation,
                    ),
                ),
                ("bn1", norm_act(channels[0])),
                (
                    "conv2",
                    nn.Conv2d(
                        channels[0],
                        channels[1],
                        3,
                        stride=1,
                        padding=dilation,
                        bias=False,
                        dilation=dilation,
                    ),
                ),
                ("bn2", bn2),
            ]
            if dropout is not None:
                # Dropout goes after conv1/bn1
                layers = layers[0:2] + [("dropout", dropout())] + layers[2:]
        else:
            bn3 = norm_act(channels[2])
            bn3.activation = "identity"
            layers = [
                (
                    "conv1",
                    nn.Conv2d(
                        in_channels, channels[0], 1, stride=1, padding=0, bias=False
                    ),
                ),
                ("bn1", norm_act(channels[0])),
                (
                    "conv2",
                    nn.Conv2d(
                        channels[0],
                        channels[1],
                        3,
                        stride=stride,
                        padding=dilation,
                        bias=False,
                        groups=groups,
                        dilation=dilation,
                    ),
                ),
                ("bn2", norm_act(channels[1])),
                (
                    "conv3",
                    nn.Conv2d(
                        channels[1], channels[2], 1, stride=1, padding=0, bias=False
                    ),
                ),
                ("bn3", bn3),
            ]
            if dropout is not None:
                # Dropout goes after conv2/bn2
                layers = layers[0:4] + [("dropout", dropout())] + layers[4:]
        self.convs = nn.Sequential(OrderedDict(layers))

        if need_proj_conv:
            self.proj_conv = nn.Conv2d(
                in_channels, channels[-1], 1, stride=stride, padding=0, bias=False
            )
            self.proj_bn = norm_act(channels[-1])
            self.proj_bn.activation = "identity"

    def forward(self, x):
        if hasattr(self, "proj_conv"):
            residual = self.proj_conv(x)
            residual = self.proj_bn(residual)
        else:
            residual = x
        x = self.convs(x) + residual

        # Apply the activation configured on bn1 to the sum
        activation = self.convs.bn1.activation
        if activation == "leaky_relu":
            return functional.leaky_relu(
                x, negative_slope=self.convs.bn1.activation_param, inplace=True
            )
        elif activation == "elu":
            return functional.elu(
                x, alpha=self.convs.bn1.activation_param, inplace=True
            )
        elif activation == "relu":
            return functional.relu(x, inplace=True)
        elif activation == "identity":
            return x
        else:
            # Previously fell through and returned None, failing later far from the cause
            raise RuntimeError("Unknown activation function {}".format(activation))


class IdentityResidualBlock(nn.Module):
    def __init__(
        self,
        in_channels,
        channels,
        stride=1,
        dilation=1,
        groups=1,
        norm_act=ABN,
        dropout=None,
    ):
        """Configurable identity-mapping residual block

        Parameters
        ----------
        in_channels : int
            Number of input channels.
        channels : list of int
            Number of channels in the internal feature maps. Can either have two or three elements: if two construct
            a residual block with two `3 x 3` convolutions, otherwise construct a bottleneck block with `1 x 1`, then
            `3 x 3` then `1 x 1` convolutions.
        stride : int
            Stride of the first convolution: the first `3 x 3` one in the two-convolution variant, the leading
            `1 x 1` one in the bottleneck variant. Also applied to the projection shortcut.
        dilation : int
            Dilation to apply to the `3 x 3` convolutions.
        groups : int
            Number of convolution groups. This is used to create ResNeXt-style blocks and is only compatible with
            bottleneck blocks.
        norm_act : callable
            Function to create normalization / activation Module.
        dropout: callable
            Function to create Dropout Module.

        Raises
        ------
        ValueError
            If `channels` does not have two or three entries, or `groups != 1` with two entries.
        """
        super(IdentityResidualBlock, self).__init__()

        # Check parameters for inconsistencies
        if len(channels) != 2 and len(channels) != 3:
            raise ValueError("channels must contain either two or three values")
        if len(channels) == 2 and groups != 1:
            raise ValueError("groups > 1 are only valid if len(channels) == 3")

        is_bottleneck = len(channels) == 3
        need_proj_conv = stride != 1 or in_channels != channels[-1]

        # Pre-activation: the block input is normalized by bn1 before the branch convolutions
        self.bn1 = norm_act(in_channels)
        if not is_bottleneck:
            layers = [
                (
                    "conv1",
                    nn.Conv2d(
                        in_channels,
                        channels[0],
                        3,
                        stride=stride,
                        padding=dilation,
                        bias=False,
                        dilation=dilation,
                    ),
                ),
                ("bn2", norm_act(channels[0])),
                (
                    "conv2",
                    nn.Conv2d(
                        channels[0],
                        channels[1],
                        3,
                        stride=1,
                        padding=dilation,
                        bias=False,
                        dilation=dilation,
                    ),
                ),
            ]
            if dropout is not None:
                # Dropout goes after conv1/bn2
                layers = layers[0:2] + [("dropout", dropout())] + layers[2:]
        else:
            layers = [
                (
                    "conv1",
                    nn.Conv2d(
                        in_channels,
                        channels[0],
                        1,
                        stride=stride,
                        padding=0,
                        bias=False,
                    ),
                ),
                ("bn2", norm_act(channels[0])),
                (
                    "conv2",
                    nn.Conv2d(
                        channels[0],
                        channels[1],
                        3,
                        stride=1,
                        padding=dilation,
                        bias=False,
                        groups=groups,
                        dilation=dilation,
                    ),
                ),
                ("bn3", norm_act(channels[1])),
                (
                    "conv3",
                    nn.Conv2d(
                        channels[1], channels[2], 1, stride=1, padding=0, bias=False
                    ),
                ),
            ]
            if dropout is not None:
                # Dropout goes after conv2/bn3
                layers = layers[0:4] + [("dropout", dropout())] + layers[4:]
        self.convs = nn.Sequential(OrderedDict(layers))

        if need_proj_conv:
            self.proj_conv = nn.Conv2d(
                in_channels, channels[-1], 1, stride=stride, padding=0, bias=False
            )

    def forward(self, x):
        if hasattr(self, "proj_conv"):
            # The projection shortcut is taken from the normalized input
            bn1 = self.bn1(x)
            shortcut = self.proj_conv(bn1)
        else:
            # Copy x first: norm_act may presumably modify its input in place (e.g. an in-place
            # ABN variant) - confirm against the norm_act used
            shortcut = x.clone()
            bn1 = self.bn1(x)

        out = self.convs(bn1)
        out.add_(shortcut)

        return out
# -------------------------------------------------------------------------------- /scripts/requirements.txt: --------------------------------------------------------------------------------
# torch>=1.0
# torchvision
# tensorboardX
# pillow
# -------------------------------------------------------------------------------- /scripts/test_imagenet.py: --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.

import argparse
import os

import models
import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.nn as nn
import torch.nn.parallel
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from dataset.sampler import TestDistributedSampler
from imagenet import config, utils
from modules import SingleGPU

parser = argparse.ArgumentParser(description="PyTorch ImageNet Testing.")
parser.add_argument(
    "config",
    metavar="CONFIG_FILE",
    help="path to configuration file. NOTE: validation-related settings are ignored",
)
parser.add_argument(
    "checkpoint",
    metavar="PATH",
    type=str,
    help="path to latest checkpoint (default: none)",
)
parser.add_argument("data", metavar="DIR", help="path to dataset")
parser.add_argument(
    "-j",
    "--workers",
    default=4,
    type=int,
    metavar="N",
    help="number of data loading workers (default: 4)",
)
parser.add_argument(
    "--print-freq",
    "-p",
    default=10,
    type=int,
    metavar="N",
    help="print frequency (default: 10)",
)
parser.add_argument(
    "--crop", "-c", metavar="N", type=int, default=224, help="crop size"
)
parser.add_argument(
    "--scale",
    "-s",
    metavar="N",
    type=int,
    default=256,
    help="scale size, if -1 do not scale input",
)
parser.add_argument(
    "--ten_crops",
    action="store_true",
    help="run ten-crops testing instead of center-crop testing",
)
parser.add_argument("--local_rank", default=0, type=int, help="process rank on node")
parser.add_argument(
    "--dist-backend", default="nccl", type=str, help="distributed backend"
)

args = None
conf = None
cudnn.benchmark = True


def main():
    """Evaluate a checkpoint on the ImageNet validation split (optionally distributed)."""
    global args, conf
    args = parser.parse_args()

    torch.cuda.set_device(args.local_rank)

    # WORLD_SIZE is set by the distributed launcher; missing or malformed means single process.
    # Only those two failures are caught (a bare except also swallowed KeyboardInterrupt).
    try:
        world_size = int(os.environ["WORLD_SIZE"])
        distributed = world_size > 1
    except (KeyError, ValueError):
        distributed = False
        world_size = 1

    if distributed:
        dist.init_process_group(backend=args.dist_backend, init_method="env://")

    # Load configuration
    conf = config.load_config(args.config)

    # Create model
    model_params = utils.get_model_params(conf["network"])
    model = models.__dict__["net_" + conf["network"]["arch"]](**model_params)
    model.cuda()
    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank
        )
    else:
        model = SingleGPU(model)

    # Load weights into the wrapped model (checkpoint keys presumably carry the wrapper's
    # "module." prefix - confirm against the training script)
    checkpoint = torch.load(args.checkpoint)
    model.load_state_dict(checkpoint["state_dict"])

    # Data loading code
    valdir = os.path.join(args.data, "val")
    val_transforms = utils.create_test_transforms(
        conf["input"], args.crop, args.scale, args.ten_crops
    )

    # Ten-crop testing divides the configured batch size by 10
    batch_size = (
        conf["optimizer"]["batch_size"]
        if not args.ten_crops
        else conf["optimizer"]["batch_size"] // 10
    )
    dataset = datasets.ImageFolder(valdir, transforms.Compose(val_transforms))
    val_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size // world_size,
        shuffle=False,
        sampler=TestDistributedSampler(dataset),
        num_workers=args.workers,
        pin_memory=True,
    )

    criterion = nn.CrossEntropyLoss().cuda()
    utils.validate(val_loader, model, criterion, args.ten_crops, args.print_freq)


if __name__ == "__main__":
    main()
# -------------------------------------------------------------------------------- /scripts/test_vistas.py: --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
2 | 3 | import argparse 4 | import subprocess 5 | import sys 6 | 7 | import torch 8 | 9 | parser = argparse.ArgumentParser( 10 | description="Testing script for the Vistas segmentation model" 11 | ) 12 | parser.add_argument( 13 | "--scales", metavar="LIST", type=str, default="[0.7, 1, 1.2]", help="List of scales" 14 | ) 15 | parser.add_argument("--flip", action="store_true", help="Use horizontal flipping") 16 | parser.add_argument( 17 | "--fusion-mode", 18 | metavar="NAME", 19 | type=str, 20 | choices=["mean", "voting", "max"], 21 | default="mean", 22 | help="How to fuse the outputs. Options: 'mean', 'voting', 'max'", 23 | ) 24 | parser.add_argument( 25 | "--output-mode", 26 | metavar="NAME", 27 | type=str, 28 | choices=["palette", "raw", "prob"], 29 | default="final", 30 | help="How the output files are formatted." 31 | " -- palette: color coded predictions" 32 | " -- raw: gray-scale predictions" 33 | " -- prob: gray-scale predictions plus probabilities", 34 | ) 35 | parser.add_argument( 36 | "snapshot", metavar="SNAPSHOT_FILE", type=str, help="Snapshot file to load" 37 | ) 38 | parser.add_argument("data", metavar="IN_DIR", type=str, help="Path to dataset") 39 | parser.add_argument("output", metavar="OUT_DIR", type=str, help="Path to output folder") 40 | 41 | 42 | def docstring_hack(): 43 | """ 44 | Multiproc file which will launch a set of processes locally for multi-gpu 45 | usage: python -m apex.parallel.multiproc main.py ... 
46 | """ 47 | pass 48 | 49 | 50 | def main(): 51 | # Load configuration 52 | args = parser.parse_args() 53 | 54 | argslist = list(sys.argv)[1:] 55 | world_size = torch.cuda.device_count() 56 | 57 | if "--world-size" in argslist: 58 | world_size = int(argslist[argslist.index("--world-size") + 1]) 59 | else: 60 | argslist.append("--world-size") 61 | argslist.append(str(world_size)) 62 | 63 | workers = [] 64 | 65 | for i in range(world_size): 66 | if "--rank" in argslist: 67 | argslist[argslist.index("--rank") + 1] = str(i) 68 | else: 69 | argslist.append("--rank") 70 | argslist.append(str(i)) 71 | stdout = None if i == 0 else open("GPU_" + str(i) + ".log", "w") 72 | print(argslist) 73 | p = subprocess.Popen( 74 | [str(sys.executable), "test_vistas_single_gpu.py"] + argslist, stdout=stdout 75 | ) 76 | workers.append(p) 77 | 78 | for p in workers: 79 | p.wait() 80 | 81 | 82 | if __name__ == "__main__": 83 | main() 84 | -------------------------------------------------------------------------------- /scripts/test_vistas_single_gpu.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | 3 | import argparse 4 | from functools import partial 5 | from os import path 6 | 7 | import models 8 | import numpy as np 9 | import torch 10 | import torch.backends.cudnn as cudnn 11 | import torch.nn as nn 12 | import torch.nn.functional as functional 13 | from dataset.dataset import SegmentationDataset, segmentation_collate 14 | from dataset.transform import SegmentationTransform 15 | from inplace_abn import InPlaceABN 16 | from modules import DeeplabV3 17 | from PIL import Image, ImagePalette 18 | from torch.utils.data import DataLoader 19 | from torch.utils.data.distributed import DistributedSampler 20 | 21 | parser = argparse.ArgumentParser( 22 | description="Testing script for the Vistas segmentation model" 23 | ) 24 | parser.add_argument( 25 | "--scales", metavar="LIST", type=str, default="[0.7, 1, 1.2]", help="List of scales" 26 | ) 27 | parser.add_argument("--flip", action="store_true", help="Use horizontal flipping") 28 | parser.add_argument( 29 | "--fusion-mode", 30 | metavar="NAME", 31 | type=str, 32 | choices=["mean", "voting", "max"], 33 | default="mean", 34 | help="How to fuse the outputs. Options: 'mean', 'voting', 'max'", 35 | ) 36 | parser.add_argument( 37 | "--output-mode", 38 | metavar="NAME", 39 | type=str, 40 | choices=["palette", "raw", "prob"], 41 | default="final", 42 | help="How the output files are formatted." 
43 | " -- palette: color coded predictions" 44 | " -- raw: gray-scale predictions" 45 | " -- prob: gray-scale predictions plus probabilities", 46 | ) 47 | parser.add_argument( 48 | "snapshot", metavar="SNAPSHOT_FILE", type=str, help="Snapshot file to load" 49 | ) 50 | parser.add_argument("data", metavar="IN_DIR", type=str, help="Path to dataset") 51 | parser.add_argument("output", metavar="OUT_DIR", type=str, help="Path to output folder") 52 | parser.add_argument( 53 | "--world-size", metavar="WS", type=int, default=1, help="Number of GPUs" 54 | ) 55 | parser.add_argument("--rank", metavar="RANK", type=int, default=0, help="GPU id") 56 | 57 | 58 | def flip(x, dim): 59 | indices = [slice(None)] * x.dim() 60 | indices[dim] = torch.arange( 61 | x.size(dim) - 1, -1, -1, dtype=torch.long, device=x.device 62 | ) 63 | return x[tuple(indices)] 64 | 65 | 66 | class SegmentationModule(nn.Module): 67 | _IGNORE_INDEX = 255 68 | 69 | class _MeanFusion: 70 | def __init__(self, x, classes): 71 | self.buffer = x.new_zeros(x.size(0), classes, x.size(2), x.size(3)) 72 | self.counter = 0 73 | 74 | def update(self, sem_logits): 75 | probs = functional.softmax(sem_logits, dim=1) 76 | self.counter += 1 77 | self.buffer.add_((probs - self.buffer) / self.counter) 78 | 79 | def output(self): 80 | probs, cls = self.buffer.max(1) 81 | return probs, cls 82 | 83 | class _VotingFusion: 84 | def __init__(self, x, classes): 85 | self.votes = x.new_zeros(x.size(0), classes, x.size(2), x.size(3)) 86 | self.probs = x.new_zeros(x.size(0), classes, x.size(2), x.size(3)) 87 | 88 | def update(self, sem_logits): 89 | probs = functional.softmax(sem_logits, dim=1) 90 | probs, cls = probs.max(1, keepdim=True) 91 | 92 | self.votes.scatter_add_(1, cls, self.votes.new_ones(cls.size())) 93 | self.probs.scatter_add_(1, cls, probs) 94 | 95 | def output(self): 96 | cls, idx = self.votes.max(1, keepdim=True) 97 | probs = self.probs / self.votes.clamp(min=1) 98 | probs = probs.gather(1, idx) 99 | return 
probs.squeeze(1), cls.squeeze(1) 100 | 101 | class _MaxFusion: 102 | def __init__(self, x, _): 103 | self.buffer_cls = x.new_zeros( 104 | x.size(0), x.size(2), x.size(3), dtype=torch.long 105 | ) 106 | self.buffer_prob = x.new_zeros(x.size(0), x.size(2), x.size(3)) 107 | 108 | def update(self, sem_logits): 109 | probs = functional.softmax(sem_logits, dim=1) 110 | max_prob, max_cls = probs.max(1) 111 | 112 | replace_idx = max_prob > self.buffer_prob 113 | self.buffer_cls[replace_idx] = max_cls[replace_idx] 114 | self.buffer_prob[replace_idx] = max_prob[replace_idx] 115 | 116 | def output(self): 117 | return self.buffer_prob, self.buffer_cls 118 | 119 | def __init__(self, body, head, head_channels, classes, fusion_mode="mean"): 120 | super(SegmentationModule, self).__init__() 121 | self.body = body 122 | self.head = head 123 | self.cls = nn.Conv2d(head_channels, classes, 1) 124 | 125 | self.classes = classes 126 | if fusion_mode == "mean": 127 | self.fusion_cls = SegmentationModule._MeanFusion 128 | elif fusion_mode == "voting": 129 | self.fusion_cls = SegmentationModule._VotingFusion 130 | elif fusion_mode == "max": 131 | self.fusion_cls = SegmentationModule._MaxFusion 132 | 133 | def _network(self, x, scale): 134 | if scale != 1: 135 | scaled_size = [round(s * scale) for s in x.shape[-2:]] 136 | x_up = functional.upsample(x, size=scaled_size, mode="bilinear") 137 | else: 138 | x_up = x 139 | 140 | x_up = self.body(x_up) 141 | x_up = self.head(x_up) 142 | sem_logits = self.cls(x_up) 143 | 144 | del x_up 145 | return sem_logits 146 | 147 | def forward(self, x, scales, do_flip=True): 148 | out_size = x.shape[-2:] 149 | fusion = self.fusion_cls(x, self.classes) 150 | 151 | for scale in scales: 152 | # Main orientation 153 | sem_logits = self._network(x, scale) 154 | sem_logits = functional.upsample(sem_logits, size=out_size, mode="bilinear") 155 | fusion.update(sem_logits) 156 | 157 | # Flipped orientation 158 | if do_flip: 159 | # Main orientation 160 | sem_logits = 
self._network(flip(x, -1), scale) 161 | sem_logits = functional.upsample( 162 | sem_logits, size=out_size, mode="bilinear" 163 | ) 164 | fusion.update(flip(sem_logits, -1)) 165 | 166 | return fusion.output() 167 | 168 | 169 | def main(): 170 | # Load configuration 171 | args = parser.parse_args() 172 | 173 | # Torch stuff 174 | torch.cuda.set_device(args.rank) 175 | cudnn.benchmark = True 176 | 177 | # Create model by loading a snapshot 178 | body, head, cls_state = load_snapshot(args.snapshot) 179 | model = SegmentationModule(body, head, 256, 65, args.fusion_mode) 180 | model.cls.load_state_dict(cls_state) 181 | model = model.cuda().eval() 182 | print(model) 183 | 184 | # Create data loader 185 | transformation = SegmentationTransform( 186 | 2048, 187 | (0.41738699, 0.45732192, 0.46886091), 188 | (0.25685097, 0.26509955, 0.29067996), 189 | ) 190 | dataset = SegmentationDataset(args.data, transformation) 191 | data_loader = DataLoader( 192 | dataset, 193 | batch_size=1, 194 | pin_memory=True, 195 | sampler=DistributedSampler(dataset, args.world_size, args.rank), 196 | num_workers=2, 197 | collate_fn=segmentation_collate, 198 | shuffle=False, 199 | ) 200 | 201 | # Run testing 202 | scales = eval(args.scales) 203 | with torch.no_grad(): 204 | for batch_i, rec in enumerate(data_loader): 205 | print("Testing batch [{:3d}/{:3d}]".format(batch_i + 1, len(data_loader))) 206 | 207 | img = rec["img"].cuda(non_blocking=True) 208 | probs, preds = model(img, scales, args.flip) 209 | 210 | for i, (prob, pred) in enumerate( 211 | zip(torch.unbind(probs, dim=0), torch.unbind(preds, dim=0)) 212 | ): 213 | out_size = rec["meta"][i]["size"] 214 | img_name = rec["meta"][i]["idx"] 215 | 216 | # Save prediction 217 | prob = prob.cpu() 218 | pred = pred.cpu() 219 | pred_img = get_pred_image(pred, out_size, args.output_mode == "palette") 220 | pred_img.save(path.join(args.output, img_name + ".png")) 221 | 222 | # Optionally save probabilities 223 | if args.output_mode == "prob": 224 | 
prob_img = get_prob_image(prob, out_size) 225 | prob_img.save(path.join(args.output, img_name + "_prob.png")) 226 | 227 | 228 | def load_snapshot(snapshot_file): 229 | """Load a training snapshot""" 230 | print("--- Loading model from snapshot") 231 | 232 | # Create network 233 | norm_act = partial(InPlaceABN, activation="leaky_relu", activation_param=0.01) 234 | body = models.__dict__["net_wider_resnet38_a2"]( 235 | norm_act=norm_act, dilation=(1, 2, 4, 4) 236 | ) 237 | head = DeeplabV3(4096, 256, 256, norm_act=norm_act, pooling_size=(84, 84)) 238 | 239 | # Load snapshot and recover network state 240 | data = torch.load(snapshot_file) 241 | body.load_state_dict(data["state_dict"]["body"]) 242 | head.load_state_dict(data["state_dict"]["head"]) 243 | 244 | return body, head, data["state_dict"]["cls"] 245 | 246 | 247 | _PALETTE = np.array( 248 | [ 249 | [165, 42, 42], 250 | [0, 192, 0], 251 | [196, 196, 196], 252 | [190, 153, 153], 253 | [180, 165, 180], 254 | [90, 120, 150], 255 | [102, 102, 156], 256 | [128, 64, 255], 257 | [140, 140, 200], 258 | [170, 170, 170], 259 | [250, 170, 160], 260 | [96, 96, 96], 261 | [230, 150, 140], 262 | [128, 64, 128], 263 | [110, 110, 110], 264 | [244, 35, 232], 265 | [150, 100, 100], 266 | [70, 70, 70], 267 | [150, 120, 90], 268 | [220, 20, 60], 269 | [255, 0, 0], 270 | [255, 0, 100], 271 | [255, 0, 200], 272 | [200, 128, 128], 273 | [255, 255, 255], 274 | [64, 170, 64], 275 | [230, 160, 50], 276 | [70, 130, 180], 277 | [190, 255, 255], 278 | [152, 251, 152], 279 | [107, 142, 35], 280 | [0, 170, 30], 281 | [255, 255, 128], 282 | [250, 0, 30], 283 | [100, 140, 180], 284 | [220, 220, 220], 285 | [220, 128, 128], 286 | [222, 40, 40], 287 | [100, 170, 30], 288 | [40, 40, 40], 289 | [33, 33, 33], 290 | [100, 128, 160], 291 | [142, 0, 0], 292 | [70, 100, 150], 293 | [210, 170, 100], 294 | [153, 153, 153], 295 | [128, 128, 128], 296 | [0, 0, 80], 297 | [250, 170, 30], 298 | [192, 192, 192], 299 | [220, 220, 0], 300 | [140, 140, 20], 301 | 
[119, 11, 32], 302 | [150, 0, 255], 303 | [0, 60, 100], 304 | [0, 0, 142], 305 | [0, 0, 90], 306 | [0, 0, 230], 307 | [0, 80, 100], 308 | [128, 64, 64], 309 | [0, 0, 110], 310 | [0, 0, 70], 311 | [0, 0, 192], 312 | [32, 32, 32], 313 | [120, 10, 10], 314 | ], 315 | dtype=np.uint8, 316 | ) 317 | _PALETTE = np.concatenate( 318 | [_PALETTE, np.zeros((256 - _PALETTE.shape[0], 3), dtype=np.uint8)], axis=0 319 | ) 320 | _PALETTE = ImagePalette.ImagePalette( 321 | palette=list(_PALETTE[:, 0]) + list(_PALETTE[:, 1]) + list(_PALETTE[:, 2]), 322 | mode="RGB", 323 | ) 324 | 325 | 326 | def get_pred_image(tensor, out_size, with_palette): 327 | tensor = tensor.numpy() 328 | if with_palette: 329 | img = Image.fromarray(tensor.astype(np.uint8), mode="P") 330 | img.putpalette(_PALETTE) 331 | else: 332 | img = Image.fromarray(tensor.astype(np.uint8), mode="L") 333 | 334 | return img.resize(out_size, Image.NEAREST) 335 | 336 | 337 | def get_prob_image(tensor, out_size): 338 | tensor = (tensor * 255).to(torch.uint8) 339 | img = Image.fromarray(tensor.numpy(), mode="L") 340 | return img.resize(out_size, Image.NEAREST) 341 | 342 | 343 | if __name__ == "__main__": 344 | main() 345 | -------------------------------------------------------------------------------- /scripts/train_imagenet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | 3 | import argparse 4 | import logging 5 | import os 6 | import shutil 7 | import time 8 | 9 | import models 10 | import torch 11 | import torch.backends.cudnn as cudnn 12 | import torch.distributed as dist 13 | import torch.nn as nn 14 | import torch.nn.parallel 15 | import torch.optim 16 | import torch.utils.data 17 | import torch.utils.data.distributed 18 | import torchvision.datasets as datasets 19 | import torchvision.transforms as transforms 20 | from dataset.sampler import TestDistributedSampler 21 | from imagenet import config, utils 22 | from inplace_abn import ABN 23 | from modules import SingleGPU 24 | from tensorboardX import SummaryWriter 25 | 26 | parser = argparse.ArgumentParser(description="PyTorch ImageNet Training") 27 | parser.add_argument("config", metavar="CONFIG_FILE", help="path to configuration file") 28 | parser.add_argument("data", metavar="DIR", help="path to dataset") 29 | parser.add_argument( 30 | "-j", 31 | "--workers", 32 | default=2, 33 | type=int, 34 | metavar="N", 35 | help="number of data loading workers (default: 2)", 36 | ) 37 | parser.add_argument( 38 | "--print-freq", 39 | "-p", 40 | default=10, 41 | type=int, 42 | metavar="N", 43 | help="print frequency (default: 10)", 44 | ) 45 | parser.add_argument( 46 | "--resume", 47 | default="", 48 | type=str, 49 | metavar="PATH", 50 | help="path to latest checkpoint (default: none)", 51 | ) 52 | parser.add_argument( 53 | "-e", 54 | "--evaluate", 55 | dest="evaluate", 56 | action="store_true", 57 | help="evaluate model on validation set", 58 | ) 59 | parser.add_argument("--local_rank", default=0, type=int, help="process rank on node") 60 | parser.add_argument( 61 | "--dist-backend", default="nccl", type=str, help="distributed backend" 62 | ) 63 | parser.add_argument( 64 | "--log-dir", 65 | type=str, 66 | default=".", 67 | metavar="PATH", 68 | help="output directory for Tensorboard log", 69 | ) 70 | parser.add_argument( 71 | "--log-hist", action="store_true", help="log histograms of 
the weights" 72 | ) 73 | 74 | best_prec1 = 0 75 | args = None 76 | conf = None 77 | tb = None 78 | logger = None 79 | 80 | 81 | def init_logger(rank, log_dir): 82 | global logger 83 | logger = logging.getLogger(__name__) 84 | logger.setLevel(logging.INFO) 85 | handler = logging.FileHandler(os.path.join(log_dir, "training_{}.log".format(rank))) 86 | formatter = logging.Formatter("%(asctime)s - %(message)s") 87 | handler.setFormatter(formatter) 88 | logger.addHandler(handler) 89 | if rank == 0: 90 | handler = logging.StreamHandler() 91 | handler.setFormatter(formatter) 92 | logger.addHandler(handler) 93 | 94 | 95 | def main(): 96 | global args, best_prec1, logger, conf, tb 97 | args = parser.parse_args() 98 | 99 | torch.cuda.set_device(args.local_rank) 100 | 101 | try: 102 | world_size = int(os.environ["WORLD_SIZE"]) 103 | distributed = world_size > 1 104 | except: 105 | distributed = False 106 | world_size = 1 107 | 108 | if distributed: 109 | dist.init_process_group(backend=args.dist_backend, init_method="env://") 110 | 111 | rank = 0 if not distributed else dist.get_rank() 112 | init_logger(rank, args.log_dir) 113 | tb = SummaryWriter(args.log_dir) if rank == 0 else None 114 | 115 | # Load configuration 116 | conf = config.load_config(args.config) 117 | 118 | # Create model 119 | model_params = utils.get_model_params(conf["network"]) 120 | model = models.__dict__["net_" + conf["network"]["arch"]](**model_params) 121 | 122 | model.cuda() 123 | if distributed: 124 | model = torch.nn.parallel.DistributedDataParallel( 125 | model, device_ids=[args.local_rank], output_device=args.local_rank 126 | ) 127 | else: 128 | model = SingleGPU(model) 129 | 130 | # define loss function (criterion) and optimizer 131 | criterion = nn.CrossEntropyLoss().cuda() 132 | optimizer, scheduler = utils.create_optimizer(conf["optimizer"], model) 133 | 134 | # optionally resume from a checkpoint 135 | if args.resume: 136 | if os.path.isfile(args.resume): 137 | logger.info("=> loading 
checkpoint '{}'".format(args.resume)) 138 | checkpoint = torch.load(args.resume) 139 | args.start_epoch = checkpoint["epoch"] 140 | best_prec1 = checkpoint["best_prec1"] 141 | model.load_state_dict(checkpoint["state_dict"]) 142 | optimizer.load_state_dict(checkpoint["optimizer"]) 143 | logger.info( 144 | "=> loaded checkpoint '{}' (epoch {})".format( 145 | args.resume, checkpoint["epoch"] 146 | ) 147 | ) 148 | else: 149 | logger.warning("=> no checkpoint found at '{}'".format(args.resume)) 150 | else: 151 | init_weights(model) 152 | args.start_epoch = 0 153 | 154 | cudnn.benchmark = True 155 | 156 | # Data loading code 157 | traindir = os.path.join(args.data, "train") 158 | valdir = os.path.join(args.data, "val") 159 | 160 | train_transforms, val_transforms = utils.create_transforms(conf["input"]) 161 | train_dataset = datasets.ImageFolder(traindir, transforms.Compose(train_transforms)) 162 | 163 | if distributed: 164 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 165 | else: 166 | train_sampler = None 167 | 168 | train_loader = torch.utils.data.DataLoader( 169 | train_dataset, 170 | batch_size=conf["optimizer"]["batch_size"] // world_size, 171 | shuffle=(train_sampler is None), 172 | num_workers=args.workers, 173 | pin_memory=True, 174 | sampler=train_sampler, 175 | ) 176 | 177 | val_dataset = datasets.ImageFolder(valdir, transforms.Compose(val_transforms)) 178 | val_loader = torch.utils.data.DataLoader( 179 | val_dataset, 180 | batch_size=conf["optimizer"]["batch_size"] // world_size, 181 | shuffle=False, 182 | num_workers=args.workers, 183 | pin_memory=True, 184 | sampler=TestDistributedSampler(val_dataset), 185 | ) 186 | 187 | if args.evaluate: 188 | utils.validate( 189 | val_loader, 190 | model, 191 | criterion, 192 | print_freq=args.print_freq, 193 | tb=tb, 194 | logger=logger.info, 195 | ) 196 | return 197 | 198 | for epoch in range(args.start_epoch, conf["optimizer"]["schedule"]["epochs"]): 199 | if distributed: 200 | 
train_sampler.set_epoch(epoch) 201 | 202 | # train for one epoch 203 | train(train_loader, model, criterion, optimizer, scheduler, epoch) 204 | 205 | # evaluate on validation set 206 | prec1 = utils.validate( 207 | val_loader, 208 | model, 209 | criterion, 210 | it=epoch * len(train_loader), 211 | print_freq=args.print_freq, 212 | tb=tb, 213 | logger=logger.info, 214 | ) 215 | 216 | # remember best prec@1 and save checkpoint 217 | is_best = prec1 > best_prec1 218 | best_prec1 = max(prec1, best_prec1) 219 | if rank == 0: 220 | save_checkpoint( 221 | { 222 | "epoch": epoch + 1, 223 | "arch": conf["network"]["arch"], 224 | "state_dict": model.state_dict(), 225 | "best_prec1": best_prec1, 226 | "optimizer": optimizer.state_dict(), 227 | }, 228 | is_best, 229 | args.log_dir, 230 | ) 231 | 232 | 233 | def train(train_loader, model, criterion, optimizer, scheduler, epoch): 234 | global logger, conf, tb 235 | batch_time = utils.AverageMeter() 236 | data_time = utils.AverageMeter() 237 | losses = utils.AverageMeter() 238 | top1 = utils.AverageMeter() 239 | top5 = utils.AverageMeter() 240 | 241 | if conf["optimizer"]["schedule"]["mode"] == "epoch": 242 | scheduler.step(epoch) 243 | 244 | # switch to train mode 245 | model.train() 246 | 247 | end = time.time() 248 | for i, (input, target) in enumerate(train_loader): 249 | if conf["optimizer"]["schedule"]["mode"] == "step": 250 | scheduler.step(i + epoch * len(train_loader)) 251 | 252 | # measure data loading time 253 | data_time.update(time.time() - end) 254 | 255 | target = target.cuda(non_blocking=True) 256 | 257 | # compute output 258 | output = model(input) 259 | loss = criterion(output, target) 260 | 261 | # compute gradient and do SGD step 262 | optimizer.zero_grad() 263 | loss.backward() 264 | if conf["optimizer"]["clip"] != 0.0: 265 | nn.utils.clip_grad_norm(model.parameters(), conf["optimizer"]["clip"]) 266 | optimizer.step() 267 | 268 | # measure accuracy and record loss 269 | with torch.no_grad(): 270 | output = 
output.detach() 271 | loss = loss.detach() * target.shape[0] 272 | prec1, prec5 = utils.accuracy_sum(output, target, topk=(1, 5)) 273 | count = target.new_tensor([target.shape[0]], dtype=torch.long) 274 | if dist.is_initialized(): 275 | dist.all_reduce(count, dist.ReduceOp.SUM) 276 | for meter, val in (losses, loss), (top1, prec1), (top5, prec5): 277 | if dist.is_initialized(): 278 | dist.all_reduce(val, dist.ReduceOp.SUM) 279 | val /= count.item() 280 | meter.update(val.item(), count.item()) 281 | 282 | # measure elapsed time 283 | batch_time.update(time.time() - end) 284 | end = time.time() 285 | 286 | if i % args.print_freq == 0: 287 | logger.info( 288 | "Epoch: [{0}][{1}/{2}]\t" 289 | "Time {batch_time.val:.3f} ({batch_time.avg:.3f}) \t" 290 | "Data {data_time.val:.3f} ({data_time.avg:.3f}) \t" 291 | "Loss {loss.val:.4f} ({loss.avg:.4f}) \t" 292 | "Prec@1 {top1.val:.3f} ({top1.avg:.3f}) \t" 293 | "Prec@5 {top5.val:.3f} ({top5.avg:.3f})".format( 294 | epoch, 295 | i, 296 | len(train_loader), 297 | batch_time=batch_time, 298 | data_time=data_time, 299 | loss=losses, 300 | top1=top1, 301 | top5=top5, 302 | ) 303 | ) 304 | 305 | if not dist.is_initialized() or dist.get_rank() == 0: 306 | tb.add_scalar("train/loss", losses.val, i + epoch * len(train_loader)) 307 | tb.add_scalar( 308 | "train/lr", scheduler.get_lr()[0], i + epoch * len(train_loader) 309 | ) 310 | tb.add_scalar("train/top1", top1.val, i + epoch * len(train_loader)) 311 | tb.add_scalar("train/top5", top5.val, i + epoch * len(train_loader)) 312 | if args.log_hist and i % 10 == 0: 313 | for name, param in model.named_parameters(): 314 | if name.find("fc") != -1 or name.find("bn_out") != -1: 315 | tb.add_histogram( 316 | name, 317 | param.clone().cpu().data.numpy(), 318 | i + epoch * len(train_loader), 319 | ) 320 | 321 | 322 | def save_checkpoint(state, is_best, log_dir): 323 | filepath = os.path.join(log_dir, "checkpoint.pth.tar") 324 | torch.save(state, filepath) 325 | if is_best: 326 | 
shutil.copyfile(filepath, os.path.join(log_dir, "model_best.pth.tar")) 327 | 328 | 329 | def init_weights(model): 330 | global conf 331 | for name, m in model.named_modules(): 332 | if isinstance(m, nn.Conv2d): 333 | init_fn = getattr(nn.init, conf["network"]["weight_init"] + "_") 334 | if ( 335 | conf["network"]["weight_init"].startswith("xavier") 336 | or conf["network"]["weight_init"] == "orthogonal" 337 | ): 338 | gain = conf["network"]["weight_gain_multiplier"] 339 | if ( 340 | conf["network"]["activation"] == "relu" 341 | or conf["network"]["activation"] == "elu" 342 | ): 343 | gain *= nn.init.calculate_gain("relu") 344 | elif conf["network"]["activation"] == "leaky_relu": 345 | gain *= nn.init.calculate_gain( 346 | "leaky_relu", conf["network"]["activation_param"] 347 | ) 348 | init_fn(m.weight, gain) 349 | elif conf["network"]["weight_init"].startswith("kaiming"): 350 | if ( 351 | conf["network"]["activation"] == "relu" 352 | or conf["network"]["activation"] == "elu" 353 | ): 354 | init_fn(m.weight, 0) 355 | else: 356 | init_fn(m.weight, conf["network"]["activation_param"]) 357 | 358 | if hasattr(m, "bias") and m.bias is not None: 359 | nn.init.constant_(m.bias, 0.0) 360 | elif isinstance(m, nn.BatchNorm2d) or isinstance(m, ABN): 361 | nn.init.constant_(m.weight, 1.0) 362 | nn.init.constant_(m.bias, 0.0) 363 | elif isinstance(m, nn.Linear): 364 | nn.init.xavier_uniform_(m.weight, 0.1) 365 | nn.init.constant_(m.bias, 0.0) 366 | 367 | 368 | if __name__ == "__main__": 369 | main() 370 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | license_files = LICENSE -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | 3 | from os import path, walk, getenv 4 | 5 | import setuptools 6 | import torch 7 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CppExtension 8 | 9 | 10 | def find_sources(root_dir, with_cuda=True): 11 | extensions = [".cpp", ".cu"] if with_cuda else [".cpp"] 12 | 13 | sources = [] 14 | for subdir, _, files in walk(root_dir): 15 | for filename in files: 16 | _, ext = path.splitext(filename) 17 | if ext in extensions: 18 | sources.append(path.join(subdir, filename)) 19 | 20 | return sources 21 | 22 | 23 | here = path.abspath(path.dirname(__file__)) 24 | 25 | with open(path.join(here, "README.md"), encoding="utf-8") as f: 26 | long_description = f.read() 27 | 28 | if torch.has_cuda or getenv("IABN_FORCE_CUDA") == "1": 29 | ext_modules = [ 30 | CUDAExtension( 31 | name="inplace_abn._backend", 32 | sources=find_sources("src"), 33 | extra_compile_args={"cxx": ["-O3"], "nvcc": []}, 34 | include_dirs=[path.join(here, "include")], 35 | define_macros=[("WITH_CUDA", 1)], 36 | ) 37 | ] 38 | else: 39 | ext_modules = [ 40 | CppExtension( 41 | name="inplace_abn._backend", 42 | sources=find_sources("src", False), 43 | extra_compile_args=["-O3"], 44 | include_dirs=[path.join(here, "include")], 45 | ) 46 | ] 47 | 48 | setuptools.setup( 49 | # Meta-data 50 | name="inplace-abn", 51 | author="Lorenzo Porzi", 52 | author_email="lorenzo@mapillary.com", 53 | description="In-Place Activate BatchNorm for Pytorch", 54 | long_description=long_description, 55 | long_description_content_type="text/markdown", 56 | url="https://github.com/mapillary/inplace_abn", 57 | classifiers=[ 58 | "Programming Language :: Python :: 3", 59 | "Programming Language :: Python :: 3.4", 60 | "Programming Language :: Python :: 3.5", 61 | "Programming Language :: Python :: 3.6", 62 | "Programming Language :: Python :: 3.7", 63 | ], 64 | # Versioning 65 | use_scm_version={ 66 | "root": ".", 67 | "relative_to": __file__, 68 | "write_to": "inplace_abn/_version.py", 69 | }, 70 | # 
Requirements 71 | setup_requires=["setuptools_scm"], 72 | python_requires=">=3, <4", 73 | # Package description 74 | packages=["inplace_abn"], 75 | ext_modules=ext_modules, 76 | cmdclass={"build_ext": BuildExtension}, 77 | ) 78 | -------------------------------------------------------------------------------- /src/inplace_abn.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | #include 4 | 5 | #include 6 | #include 7 | 8 | #include "checks.h" 9 | #include "dispatch.h" 10 | #include "inplace_abn.h" 11 | #include "utils.h" 12 | 13 | /*********************************************************************************************************************** 14 | * Exposed methods 15 | **********************************************************************************************************************/ 16 | 17 | std::tuple statistics(const at::Tensor& x) { 18 | IABN_CHECK(x.ndimension() >= 2, "x should have at least 2 dimensions"); 19 | 20 | CUDA_DISPATCH(x, statistics, x) 21 | } 22 | 23 | #ifdef WITH_CUDA 24 | std::tuple reduce_statistics( 25 | const at::Tensor& all_mean, 26 | const at::Tensor& all_var, 27 | const at::Tensor& all_count) { 28 | // Inputs shouldn't be half 29 | CHECK_NOT_HALF(all_mean); 30 | CHECK_NOT_HALF(all_var); 31 | 32 | // reduce_statistics is only used on GPU 33 | CHECK_CUDA(all_mean); 34 | CHECK_CUDA(all_var); 35 | CHECK_CUDA(all_count); 36 | 37 | // Check types and dimensions 38 | CHECK_SAME_TYPE(all_mean, all_var); 39 | IABN_CHECK( 40 | all_count.scalar_type() == at::ScalarType::Long, 41 | "all_count should have type int64"); 42 | IABN_CHECK(all_mean.ndimension() == 2, "all_mean should have size N x C"); 43 | IABN_CHECK(all_var.ndimension() == 2, "all_var should have size N x C"); 44 | IABN_CHECK( 45 | all_count.ndimension() == 2 && all_count.size(1) == 1, 46 | "all_count should have size N x 1"); 47 | IABN_CHECK( 48 | all_mean.size(0) == 
all_var.size(0) && 49 | all_mean.size(0) == all_count.size(0), 50 | "Inputs should have the same size in dimension 0"); 51 | IABN_CHECK( 52 | all_mean.size(1) == all_var.size(1), 53 | "all_mean and all_var should have the same size in dimension 1"); 54 | 55 | return reduce_statistics_cuda(all_mean, all_var, all_count); 56 | } 57 | #endif 58 | 59 | void forward( 60 | at::Tensor& x, 61 | const at::Tensor& mean, 62 | const at::Tensor& var, 63 | const std::optional& weight, 64 | const std::optional& bias, 65 | float eps, 66 | Activation activation, 67 | float activation_param) { 68 | // Check dimensions and types 69 | IABN_CHECK(x.ndimension() >= 2, "x should have at least 2 dimensions"); 70 | IABN_CHECK( 71 | is_compatible_stat(x, mean), 72 | "mean is not compatible with x (wrong size or scalar type)"); 73 | IABN_CHECK( 74 | is_compatible_stat(x, var), 75 | "var is not compatible with x (wrong size or scalar type)"); 76 | if (weight.has_value()) 77 | IABN_CHECK( 78 | is_compatible_weight(x, weight.value()), 79 | "weight is not compatible with x (wrong size or scalar type)"); 80 | if (bias.has_value()) 81 | IABN_CHECK( 82 | is_compatible_weight(x, bias.value()), 83 | "bias is not compatible with x (wrong size or scalar type)"); 84 | if (weight.has_value() && bias.has_value()) 85 | CHECK_SAME_TYPE(weight.value(), bias.value()); 86 | 87 | IABN_CHECK( 88 | (weight.has_value() && bias.has_value()) || 89 | (!weight.has_value() && !bias.has_value()), 90 | "weight and bias must be equally present or not present"); 91 | 92 | CUDA_DISPATCH( 93 | x, forward, x, mean, var, weight, bias, eps, activation, activation_param) 94 | } 95 | 96 | std::tuple backward_reduce( 97 | const at::Tensor& y_act, 98 | const at::Tensor& dy_act, 99 | const std::optional& weight, 100 | const std::optional& bias, 101 | float eps, 102 | Activation activation, 103 | float activation_param) { 104 | // Check dimensions and types 105 | IABN_CHECK( 106 | y_act.ndimension() >= 2, "y_act should have at least 2 
dimensions"); 107 | IABN_CHECK( 108 | have_same_dims(y_act, dy_act), 109 | "y_act and dy_act should have the same size"); 110 | CHECK_SAME_TYPE(y_act, dy_act); 111 | if (weight.has_value()) 112 | IABN_CHECK( 113 | is_compatible_weight(y_act, weight.value()), 114 | "weight is not compatible with y_act (wrong size or scalar type)"); 115 | if (bias.has_value()) 116 | IABN_CHECK( 117 | is_compatible_weight(y_act, bias.value()), 118 | "bias is not compatible with y_act (wrong size or scalar type)"); 119 | if (weight.has_value() && bias.has_value()) 120 | CHECK_SAME_TYPE(weight.value(), bias.value()); 121 | 122 | IABN_CHECK( 123 | (weight.has_value() && bias.has_value()) || 124 | (!weight.has_value() && !bias.has_value()), 125 | "weight and bias must be equally present or not present"); 126 | 127 | CUDA_DISPATCH( 128 | y_act, 129 | backward_reduce, 130 | y_act, 131 | dy_act, 132 | weight, 133 | bias, 134 | eps, 135 | activation, 136 | activation_param) 137 | } 138 | 139 | void backward_train( 140 | const at::Tensor& xhat, 141 | at::Tensor& dy, 142 | const at::Tensor& var, 143 | const at::Tensor& count, 144 | const at::Tensor& sum_dy, 145 | const at::Tensor& sum_xhat_dy, 146 | const std::optional& weight, 147 | float eps) { 148 | // Check dimensions and types 149 | IABN_CHECK(xhat.ndimension() >= 2, "xhat should have at least 2 dimensions"); 150 | IABN_CHECK(have_same_dims(xhat, dy), "xhat and dy should have the same size"); 151 | CHECK_SAME_TYPE(xhat, dy); 152 | IABN_CHECK( 153 | is_compatible_stat(xhat, var), 154 | "var is not compatible with xhat (wrong size or scalar type)"); 155 | IABN_CHECK( 156 | count.ndimension() == 1 && count.size(0) == 1, 157 | "count should be a vector with a single element"); 158 | IABN_CHECK( 159 | count.scalar_type() == at::ScalarType::Long, 160 | "count should have type int64"); 161 | IABN_CHECK( 162 | is_compatible_stat(xhat, sum_dy), 163 | "sum_dy is not compatible with xhat (wrong size or scalar type)"); 164 | IABN_CHECK( 165 | 
is_compatible_stat(xhat, sum_xhat_dy), 166 | "sum_xhat_dy is not compatible with xhat (wrong size or scalar type)"); 167 | if (weight.has_value()) 168 | IABN_CHECK( 169 | is_compatible_weight(xhat, weight.value()), 170 | "weight is not compatible with xhat (wrong size or scalar type)"); 171 | 172 | CUDA_DISPATCH( 173 | xhat, backward, xhat, dy, var, count, sum_dy, sum_xhat_dy, weight, eps) 174 | } 175 | 176 | at::Tensor backward_test( 177 | const at::Tensor& dy_, 178 | const at::Tensor& var, 179 | const std::optional& weight, 180 | float eps) { 181 | // Check dimensions and types 182 | IABN_CHECK(dy_.ndimension() >= 2, "dy should have at least 2 dimensions"); 183 | IABN_CHECK( 184 | is_compatible_stat(dy_, var), 185 | "var is not compatible with dy (wrong size or scalar type)"); 186 | if (weight.has_value()) 187 | IABN_CHECK( 188 | is_compatible_weight(dy_, weight.value()), 189 | "weight is not compatible with dy (wrong size or scalar type)"); 190 | 191 | // TODO: optimize implementation for GPU 192 | auto dy = normalize_shape(dy_); 193 | auto mult = weight.has_value() 194 | ? (weight.value().to(var.options()).abs() + eps) / (var + eps).sqrt() 195 | : 1 / (var + eps).sqrt(); 196 | auto dx = normalize_shape(mult) * dy.to(var.options()); 197 | return dx.to(dy_.options()).view(dy_.sizes()); 198 | } 199 | 200 | /*********************************************************************************************************************** 201 | * Python Bindings 202 | **********************************************************************************************************************/ 203 | 204 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 205 | pybind11::enum_(m, "Activation") 206 | .value("LeakyReLU", Activation::LeakyReLU) 207 | .value("ELU", Activation::ELU) 208 | .value("Identity", Activation::Identity); 209 | 210 | // Forward methods 211 | m.def( 212 | "statistics", 213 | &statistics, 214 | "Compute iABN statistics, i.e. 
mean, biased variance and sample count"); 215 | #ifdef WITH_CUDA 216 | m.def( 217 | "reduce_statistics", 218 | &reduce_statistics, 219 | "Reduce statistics from multiple GPUs"); 220 | #endif 221 | m.def( 222 | "forward", 223 | &forward, 224 | "iABN forward pass. This is an in-place operation w.r.t. x"); 225 | 226 | // Backward methods 227 | m.def("backward_reduce", &backward_reduce, "First step of the backward pass"); 228 | m.def( 229 | "backward_train", 230 | &backward_train, 231 | "Second step of the backward pass. This is an in-place operation w.r.t. dy"); 232 | m.def( 233 | "backward_test", 234 | &backward_test, 235 | "Second step of the backward pass, test mode"); 236 | } 237 | -------------------------------------------------------------------------------- /src/inplace_abn_cpu.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | #include 4 | 5 | #include 6 | #include 7 | 8 | #include "checks.h" 9 | #include "inplace_abn.h" 10 | #include "utils.h" 11 | 12 | /*********************************************************************************************************************** 13 | * Utility functions 14 | **********************************************************************************************************************/ 15 | 16 | int32_t count_samples(const at::Tensor& x) { 17 | return x.size(0) * x.size(2); 18 | } 19 | 20 | /*********************************************************************************************************************** 21 | * Templated implementations 22 | **********************************************************************************************************************/ 23 | 24 | template 25 | std::tuple backward_reduce_impl( 26 | const at::Tensor& y_act_, 27 | const at::Tensor& dy_act_, 28 | const std::optional& weight_, 29 | const std::optional& bias_, 30 | float eps, 31 | float activation_param) { 32 | // Initialize output tensors 
33 | auto xhat_ = at::empty_like(y_act_); 34 | auto dy_ = at::empty_like(y_act_); 35 | auto sum_dy_ = at::zeros({y_act_.size(1)}, y_act_.options()); 36 | auto sum_xhat_dy_ = at::zeros({y_act_.size(1)}, y_act_.options()); 37 | 38 | // Normalize shapes 39 | auto y_act_norm_ = normalize_shape(y_act_); 40 | auto dy_act_norm_ = normalize_shape(dy_act_); 41 | auto xhat_norm_ = normalize_shape(xhat_); 42 | auto dy_norm_ = normalize_shape(dy_); 43 | 44 | // Get dimensions 45 | int64_t num = y_act_norm_.size(0), chn = y_act_norm_.size(1), 46 | sp = y_act_norm_.size(2); 47 | 48 | // Make accessors 49 | auto y_act = y_act_norm_.accessor(); 50 | auto dy_act = dy_act_norm_.accessor(); 51 | auto xhat = xhat_norm_.accessor(); 52 | auto dy = dy_norm_.accessor(); 53 | auto weight = accessor_or_dummy(weight_); 54 | auto bias = accessor_or_dummy(bias_); 55 | auto sum_dy = sum_dy_.accessor(); 56 | auto sum_xhat_dy = sum_xhat_dy_.accessor(); 57 | 58 | // Main loop 59 | for (int64_t c = 0; c < chn; ++c) { 60 | auto inv_gamma_c = weight_.has_value() 61 | ? scalar_t(1) / (std::abs(weight[c]) + eps) 62 | : scalar_t(1); 63 | auto beta_c = bias_.has_value() ? 
bias[c] : scalar_t(0); 64 | 65 | for (int64_t n = 0; n < num; ++n) { 66 | auto y_act_nc = y_act[n][c]; 67 | auto dy_act_nc = dy_act[n][c]; 68 | auto xhat_nc = xhat[n][c]; 69 | auto dy_nc = dy[n][c]; 70 | 71 | for (int64_t s = 0; s < sp; ++s) { 72 | // Invert activation 73 | ActivationFn::backward( 74 | y_act_nc[s], dy_act_nc[s], activation_param, xhat_nc[s], dy_nc[s]); 75 | 76 | // Invert affine transformation 77 | xhat_nc[s] = (xhat_nc[s] - beta_c) * inv_gamma_c; 78 | 79 | // Accumulate 80 | sum_dy[c] += dy_nc[s]; 81 | sum_xhat_dy[c] += xhat_nc[s] * dy_nc[s]; 82 | } 83 | } 84 | } 85 | 86 | return std::make_tuple(xhat_, dy_, sum_dy_, sum_xhat_dy_); 87 | } 88 | 89 | /*********************************************************************************************************************** 90 | * Interface methods 91 | **********************************************************************************************************************/ 92 | 93 | std::tuple statistics_cpu( 94 | const at::Tensor& x_) { 95 | CHECK_NOT_HALF(x_); 96 | 97 | auto x = normalize_shape(x_); 98 | 99 | auto mean = x.mean(c10::IntArrayRef({0, 2})); 100 | auto var = (x - normalize_shape(mean)).pow(2).mean(c10::IntArrayRef({0, 2})); 101 | auto count = 102 | at::full({1}, count_samples(x), x.options().dtype(at::ScalarType::Long)); 103 | 104 | return std::make_tuple(mean, var, count); 105 | } 106 | 107 | void forward_cpu( 108 | at::Tensor& x_, 109 | const at::Tensor& mean, 110 | const at::Tensor& var, 111 | const std::optional& weight, 112 | const std::optional& bias, 113 | float eps, 114 | Activation activation, 115 | float activation_param) { 116 | CHECK_NOT_HALF(x_); 117 | 118 | auto x = normalize_shape(x_); 119 | 120 | // Apply normalization 121 | auto abs_weight = weight.has_value() 122 | ? weight.value().abs() + eps 123 | : at::ones({mean.size(0)}, mean.options()); 124 | auto inv_std = 1 / at::sqrt(var + eps); 125 | 126 | auto scale = weight.has_value() ? 
abs_weight * inv_std : inv_std; 127 | auto shift = weight.has_value() ? bias.value() - mean * abs_weight * inv_std 128 | : -mean * inv_std; 129 | 130 | x.mul_(normalize_shape(scale)).add_(normalize_shape(shift)); 131 | 132 | switch (activation) { 133 | case Activation::LeakyReLU: 134 | at::leaky_relu_(x, activation_param); 135 | break; 136 | case Activation::ELU: 137 | at::elu_(x, activation_param); 138 | break; 139 | case Activation::Identity: 140 | break; 141 | } 142 | } 143 | 144 | std::tuple backward_reduce_cpu( 145 | const at::Tensor& y_act, 146 | const at::Tensor& dy_act, 147 | const std::optional& weight, 148 | const std::optional& bias, 149 | float eps, 150 | Activation activation, 151 | float activation_param) { 152 | CHECK_NOT_HALF(y_act); 153 | 154 | // Run templated implementation 155 | return AT_DISPATCH_FLOATING_TYPES( 156 | y_act.scalar_type(), "backward_reduce_cpu", [&] { 157 | switch (activation) { 158 | case Activation::LeakyReLU: 159 | return backward_reduce_impl( 160 | y_act, dy_act, weight, bias, eps, activation_param); 161 | case Activation::ELU: 162 | return backward_reduce_impl( 163 | y_act, dy_act, weight, bias, eps, activation_param); 164 | case Activation::Identity: 165 | default: 166 | return backward_reduce_impl( 167 | y_act, dy_act, weight, bias, eps, activation_param); 168 | } 169 | }); 170 | } 171 | 172 | void backward_cpu( 173 | const at::Tensor& xhat_, 174 | at::Tensor& dy_, 175 | const at::Tensor& var, 176 | const at::Tensor& count, 177 | const at::Tensor& sum_dy, 178 | const at::Tensor& sum_xhat_dy, 179 | const std::optional& weight, 180 | float eps) { 181 | CHECK_NOT_HALF(xhat_); 182 | 183 | auto xhat = normalize_shape(xhat_); 184 | auto dy = normalize_shape(dy_); 185 | auto mean_dy = normalize_shape(sum_dy / count.to(sum_dy.options())); 186 | auto mean_xhat_dy = 187 | normalize_shape(sum_xhat_dy / count.to(sum_xhat_dy.options())); 188 | 189 | auto mult = weight.has_value() 190 | ? 
(weight.value().abs() + eps) / (var + eps).sqrt() 191 | : 1 / (var + eps).sqrt(); 192 | 193 | // dy = (dy - mean_dy - xhat * mean_xhat_dy) * mult 194 | dy.sub_(mean_dy).sub_(xhat * mean_xhat_dy).mul_(normalize_shape(mult)); 195 | } 196 | -------------------------------------------------------------------------------- /src/inplace_abn_cuda.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | #include 4 | 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | #include "inplace_abn.h" 12 | #include "utils.h" 13 | #include "cuda_utils.cuh" 14 | #include "inplace_abn_kernels.cuh" 15 | #include "dispatch.h" 16 | 17 | /*********************************************************************************************************************** 18 | * Templated implementations 19 | **********************************************************************************************************************/ 20 | 21 | template 22 | std::tuple statistics_template(const at::Tensor& x_) { 23 | // Normalize shape and get dimensions 24 | auto x = normalize_shape(x_); 25 | auto num = x.size(0), chn = x.size(1), sp = x.size(2); 26 | 27 | // Type handling 28 | using accscalar_t = at::acc_type; 29 | auto acc_options = x.options(); 30 | if (x.scalar_type() == at::ScalarType::Half) { 31 | acc_options = acc_options.dtype(at::ScalarType::Float); 32 | } 33 | 34 | // Initialize output tensors 35 | auto mean = at::empty({chn}, acc_options); 36 | auto var = at::empty({chn}, acc_options); 37 | auto count = at::full({1}, num * sp, x.options().dtype(at::ScalarType::Long)); 38 | 39 | // Make accessors 40 | auto x_accessor = x.packed_accessor(); 41 | auto mean_accessor = mean.packed_accessor(); 42 | auto var_accessor = var.packed_accessor(); 43 | 44 | // Kernel parameters 45 | auto stream = at::cuda::getCurrentCUDAStream(); 46 | dim3 blocks(chn); 47 | int tf = getNumThreads(sp); 48 | dim3 
threads(tf, std::max(1, MAX_BLOCK_SIZE / tf)); 49 | 50 | // Invoke kernel 51 | statistics_kernel<<>>( 52 | x_accessor, mean_accessor, var_accessor); 53 | 54 | return std::make_tuple(mean, var, count); 55 | } 56 | 57 | template 58 | std::tuple reduce_statistics_template( 59 | const at::Tensor& all_mean, const at::Tensor& all_var, const at::Tensor& all_count) { 60 | auto num = all_mean.size(0), chn = all_mean.size(1); 61 | 62 | // Initialize output tensors 63 | auto mean = at::empty({chn}, all_mean.options()); 64 | auto var = at::empty({chn}, all_var.options()); 65 | auto count = all_count.sum({0}); 66 | 67 | // Make accessors 68 | auto all_mean_accessor = all_mean.packed_accessor(); 69 | auto all_var_accessor = all_var.packed_accessor(); 70 | auto all_count_accessor = all_count.packed_accessor(); 71 | auto mean_accessor = mean.packed_accessor(); 72 | auto var_accessor = var.packed_accessor(); 73 | 74 | // Kernel parameters 75 | auto stream = at::cuda::getCurrentCUDAStream(); 76 | int threads = getNumThreads(chn); 77 | int blocks = std::max(1, chn / threads); 78 | 79 | // Invoke kernel 80 | reduce_statistics_kernel<<>>( 81 | all_mean_accessor, all_var_accessor, all_count_accessor, mean_accessor, var_accessor); 82 | 83 | return std::make_tuple(mean, var, count); 84 | } 85 | 86 | template 87 | void forward_template(at::Tensor& x_, const at::Tensor& mean, const at::Tensor& var, 88 | const std::optional& weight, const std::optional& bias, 89 | float eps, Activation activation, float activation_param) { 90 | // Normalize shape and get dimensions 91 | auto x = normalize_shape(x_); 92 | auto num = x.size(0), chn = x.size(1), sp = x.size(2); 93 | 94 | // Type handling 95 | using accscalar_t = at::acc_type; 96 | 97 | // Make accessors 98 | auto x_accessor = x.packed_accessor(); 99 | auto mean_accessor = mean.packed_accessor(); 100 | auto var_accessor = var.packed_accessor(); 101 | auto weight_accessor = packed_accessor_or_dummy(weight); 102 | auto bias_accessor = 
packed_accessor_or_dummy(bias); 103 | 104 | // Kernel parameters 105 | auto stream = at::cuda::getCurrentCUDAStream(); 106 | int tf = std::max(getNumThreads(sp / 4), std::min(getNumThreads(sp), 64)); 107 | int tb = std::max(64 / tf, 1); 108 | dim3 blocks(chn, std::max(1, std::min((256 * 1024) / chn, (chn + tb - 1) / tb))); 109 | blocks.y = std::min(blocks.y, 65535); 110 | dim3 threads(tf, tb); 111 | 112 | // Invoke kernel 113 | switch (activation) { 114 | case Activation::LeakyReLU: 115 | forward_kernel<<>>( 116 | x_accessor, mean_accessor, var_accessor, weight_accessor, bias_accessor, eps, activation_param); 117 | break; 118 | case Activation::ELU: 119 | forward_kernel<<>>( 120 | x_accessor, mean_accessor, var_accessor, weight_accessor, bias_accessor, eps, activation_param); 121 | break; 122 | case Activation::Identity: 123 | forward_kernel<<>>( 124 | x_accessor, mean_accessor, var_accessor, weight_accessor, bias_accessor, eps, activation_param); 125 | break; 126 | } 127 | } 128 | 129 | template 130 | std::tuple backward_reduce_template( 131 | const at::Tensor& y_act_, const at::Tensor& dy_act_, const std::optional& weight, 132 | const std::optional& bias, float eps, Activation activation, float activation_param) { 133 | // Normalize shape and get dimensions 134 | auto y_act = normalize_shape(y_act_); 135 | auto dy_act = normalize_shape(dy_act_); 136 | auto num = y_act.size(0), chn = y_act.size(1), sp = y_act.size(2); 137 | 138 | // Type handling 139 | using accscalar_t = at::acc_type; 140 | auto acc_options = y_act.options(); 141 | if (y_act.scalar_type() == at::ScalarType::Half) { 142 | acc_options = acc_options.dtype(at::ScalarType::Float); 143 | } 144 | 145 | // Initialize output tensors 146 | auto xhat = at::empty_like(y_act); 147 | auto dy = at::empty_like(y_act); 148 | auto sum_dy = at::empty({chn}, acc_options); 149 | auto sum_xhat_dy = at::empty({chn}, acc_options); 150 | 151 | // Make accessors 152 | auto y_act_accessor = y_act.packed_accessor(); 153 | 
auto dy_act_accessor = dy_act.packed_accessor(); 154 | auto xhat_accessor = xhat.packed_accessor(); 155 | auto dy_accessor = dy.packed_accessor(); 156 | auto weight_accessor = packed_accessor_or_dummy(weight); 157 | auto bias_accessor = packed_accessor_or_dummy(bias); 158 | auto sum_dy_accessor = sum_dy.packed_accessor(); 159 | auto sum_xhat_dy_accessor = sum_xhat_dy.packed_accessor(); 160 | 161 | // Kernel parameters 162 | auto stream = at::cuda::getCurrentCUDAStream(); 163 | int block_y = std::min(lastPow2(num), MAX_BLOCK_SIZE / 32); 164 | int block_x = std::min(getNumThreads(sp), MAX_BLOCK_SIZE / block_y); 165 | const dim3 threads(block_x, block_y); 166 | const dim3 blocks(chn); 167 | 168 | // Invoke kernel 169 | switch (activation) { 170 | case Activation::LeakyReLU: 171 | backward_reduce_kernel<<>>( 172 | y_act_accessor, dy_act_accessor, weight_accessor, bias_accessor, xhat_accessor, dy_accessor, sum_dy_accessor, sum_xhat_dy_accessor, 173 | eps, activation_param); 174 | break; 175 | case Activation::ELU: 176 | backward_reduce_kernel<<>>( 177 | y_act_accessor, dy_act_accessor, weight_accessor, bias_accessor, xhat_accessor, dy_accessor, sum_dy_accessor, sum_xhat_dy_accessor, 178 | eps, activation_param); 179 | break; 180 | case Activation::Identity: 181 | backward_reduce_kernel<<>>( 182 | y_act_accessor, dy_act_accessor, weight_accessor, bias_accessor, xhat_accessor, dy_accessor, sum_dy_accessor, sum_xhat_dy_accessor, 183 | eps, activation_param); 184 | break; 185 | } 186 | 187 | return std::make_tuple(xhat.view(y_act_.sizes()), dy.view(y_act_.sizes()), sum_dy, sum_xhat_dy); 188 | } 189 | 190 | template 191 | void backward_template(const at::Tensor& xhat_, at::Tensor& dy_, const at::Tensor& var, 192 | const at::Tensor& count, const at::Tensor& sum_dy, const at::Tensor& sum_xhat_dy, 193 | const std::optional& weight, float eps) { 194 | // Normalize shape and get dimensions 195 | auto xhat = normalize_shape(xhat_); 196 | auto dy = normalize_shape(dy_); 197 | auto 
num = xhat.size(0), chn = xhat.size(1), sp = xhat.size(2); 198 | 199 | // Type handling 200 | using accscalar_t = at::acc_type; 201 | 202 | // Make accessors 203 | auto xhat_accessor = xhat.packed_accessor(); 204 | auto dy_accessor = dy.packed_accessor(); 205 | auto var_accessor = var.packed_accessor(); 206 | auto count_accessor = count.packed_accessor(); 207 | auto sum_dy_accessor = sum_dy.packed_accessor(); 208 | auto sum_xhat_dy_accessor = sum_xhat_dy.packed_accessor(); 209 | auto weight_accessor = packed_accessor_or_dummy(weight); 210 | 211 | // Kernel parameters 212 | auto stream = at::cuda::getCurrentCUDAStream(); 213 | int tf = std::max(getNumThreads(sp / 4), std::min(getNumThreads(sp), 64)); 214 | int tb = std::max(64 / tf, 1); 215 | dim3 blocks(chn, std::max(1, std::min((256 * 1024) / chn, (chn + tb - 1) / tb))); 216 | blocks.y = std::min(blocks.y, 65535); 217 | dim3 threads(tf, tb); 218 | 219 | // Invoke kernel 220 | backward_kernel<<>>( 221 | xhat_accessor, dy_accessor, var_accessor, count_accessor, sum_dy_accessor, sum_xhat_dy_accessor, 222 | weight_accessor, eps); 223 | } 224 | 225 | /*********************************************************************************************************************** 226 | * Interface methods 227 | **********************************************************************************************************************/ 228 | 229 | std::tuple statistics_cuda(const at::Tensor& x) { 230 | return AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "statistics_cuda", [&] { 231 | if (at::cuda::detail::canUse32BitIndexMath(x)) { 232 | return statistics_template(x); 233 | } else { 234 | return statistics_template(x); 235 | } 236 | }); 237 | } 238 | 239 | std::tuple reduce_statistics_cuda( 240 | const at::Tensor& all_mean, const at::Tensor& all_var, const at::Tensor& all_count) { 241 | return AT_DISPATCH_FLOATING_TYPES(all_mean.scalar_type(), "reduce_statistics_cuda", [&] { 242 | if 
(at::cuda::detail::canUse32BitIndexMath(all_mean)) { 243 | return reduce_statistics_template(all_mean, all_var, all_count); 244 | } else { 245 | return reduce_statistics_template(all_mean, all_var, all_count); 246 | } 247 | }); 248 | } 249 | 250 | void forward_cuda(at::Tensor& x, const at::Tensor& mean, const at::Tensor& var, 251 | const std::optional& weight, const std::optional& bias, 252 | float eps, Activation activation, float activation_param) { 253 | const auto& w_scalar_type = weight.has_value() ? weight.value().scalar_type() : x.scalar_type(); 254 | 255 | DOUBLE_DISPATCH(x.scalar_type(), w_scalar_type, "forward_cuda", [&] { 256 | if (at::cuda::detail::canUse32BitIndexMath(x)) { 257 | forward_template(x, mean, var, weight, bias, eps, activation, activation_param); 258 | } else { 259 | forward_template(x, mean, var, weight, bias, eps, activation, activation_param); 260 | } 261 | }); 262 | } 263 | 264 | std::tuple backward_reduce_cuda( 265 | const at::Tensor& y_act, const at::Tensor& dy_act, const std::optional& weight, 266 | const std::optional& bias, float eps, Activation activation, float activation_param) { 267 | const auto& w_scalar_type = weight.has_value() ? weight.value().scalar_type() : y_act.scalar_type(); 268 | 269 | return DOUBLE_DISPATCH(y_act.scalar_type(), w_scalar_type, "backward_reduce_cuda", [&] { 270 | if (at::cuda::detail::canUse32BitIndexMath(y_act)) { 271 | return backward_reduce_template( 272 | y_act, dy_act, weight, bias, eps, activation, activation_param); 273 | } else { 274 | return backward_reduce_template( 275 | y_act, dy_act, weight, bias, eps, activation, activation_param); 276 | } 277 | }); 278 | } 279 | 280 | void backward_cuda(const at::Tensor& xhat, at::Tensor& dy, const at::Tensor& var, const at::Tensor& count, 281 | const at::Tensor& sum_dy, const at::Tensor& sum_xhat_dy, 282 | const std::optional& weight, float eps) { 283 | const auto& w_scalar_type = weight.has_value() ? 
weight.value().scalar_type() : xhat.scalar_type(); 284 | 285 | return DOUBLE_DISPATCH(xhat.scalar_type(), w_scalar_type, "backward_cuda", [&] { 286 | if (at::cuda::detail::canUse32BitIndexMath(xhat)) { 287 | backward_template(xhat, dy, var, count, sum_dy, sum_xhat_dy, weight, eps); 288 | } else { 289 | backward_template(xhat, dy, var, count, sum_dy, sum_xhat_dy, weight, eps); 290 | } 291 | }); 292 | } 293 | -------------------------------------------------------------------------------- /src/inplace_abn_kernels.cuh: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | #pragma once 4 | 5 | #include 6 | 7 | #include "inplace_abn.h" 8 | #include "cuda_utils.cuh" 9 | 10 | /*********************************************************************************************************************** 11 | * Kernels 12 | * ------- 13 | * 14 | * These are copy-pasted (+ some minor modifications) from the pytorch 1.1 native implementation of BN 15 | **********************************************************************************************************************/ 16 | 17 | template 18 | __global__ void statistics_kernel( 19 | const at::PackedTensorAccessor input, 20 | at::PackedTensorAccessor mean, 21 | at::PackedTensorAccessor var) { 22 | 23 | __shared__ int shared_n[2 * 2 * WARP_SIZE + WARP_SIZE]; 24 | 25 | auto plane = blockIdx.x; 26 | int N = input.size(0) * input.size(2); 27 | auto tid = threadIdx.x + threadIdx.y * blockDim.x; 28 | 29 | // Compute the mean and variance across (batch, x/y/z) 30 | // this uses the Welford (in the for loop)/parallel algorithm (to sum across the block) 31 | // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_Online_algorithm 32 | // and the parallel algorithm on the same page. 33 | // We use two shuffles to reduce across the entire block. 34 | // https://devblogs.nvidia.com/faster-parallel-reductions-kepler/ has a description. 
35 | accscalar_t* shared_avg_var = (accscalar_t*) &shared_n[WARP_SIZE]; 36 | 37 | // first the reductions each thread does separately 38 | accscalar_t avg = 0; 39 | accscalar_t var_n = 0; 40 | int n = 0; 41 | for (auto batch = threadIdx.y; batch < input.size(0); batch += blockDim.y) { 42 | for (auto x = threadIdx.x; x < input.size(2); x += blockDim.x) { 43 | accscalar_t v = input[batch][plane][x]; 44 | accscalar_t d1 = v - avg; 45 | n++; 46 | avg += d1 / n; 47 | var_n += d1 * (v - avg); 48 | } 49 | } 50 | 51 | // first warpSum to get one value per thread to 52 | // one value per warp 53 | for (int i = 0; i < getMSB(WARP_SIZE); ++i) { 54 | accscalar_t o_avg = WARP_SHFL_XOR(avg, 1 << i, WARP_SIZE); 55 | int o_n = WARP_SHFL_XOR(n, 1 << i, WARP_SIZE); 56 | accscalar_t factor = 1.0 / fmaxf(1.0, n + o_n); 57 | var_n += WARP_SHFL_XOR(var_n, 1 << i, WARP_SIZE) + (avg - o_avg) * (avg - o_avg) * n * o_n * factor; 58 | avg = (n * avg + o_n * o_avg) * factor; 59 | n += o_n; 60 | } 61 | 62 | // this writes each warps item into shared memory 63 | // there are at most WARP_SIZE items left because 64 | // there are at most WARP_SIZE**2 threads at the beginning 65 | __syncthreads(); 66 | if (tid % WARP_SIZE == 0) { 67 | shared_n[tid / WARP_SIZE] = n; 68 | shared_avg_var[tid / WARP_SIZE * 2] = avg; 69 | shared_avg_var[tid / WARP_SIZE * 2 + 1] = var_n; 70 | } 71 | __syncthreads(); 72 | 73 | // now have a second warpSum to reduce the intermediate values 74 | // from shared memory to a single number. The very first 75 | // thread writes it to shared memory. 76 | if (tid < WARP_SIZE) { 77 | n = (tid < blockDim.x * blockDim.y / WARP_SIZE ? shared_n[tid] : 0); 78 | avg = (tid < blockDim.x * blockDim.y / WARP_SIZE ? shared_avg_var[2 * tid] : accscalar_t(0)); 79 | var_n = (tid < blockDim.x * blockDim.y / WARP_SIZE ? 
shared_avg_var[2 * tid + 1] : accscalar_t(0)); 80 | } 81 | for (int i = 0; i < getMSB(WARP_SIZE); ++i) { 82 | accscalar_t o_avg = WARP_SHFL_XOR(avg, 1 << i, WARP_SIZE); 83 | int o_n = WARP_SHFL_XOR(n, 1 << i, WARP_SIZE); 84 | accscalar_t factor = 1.0 / fmaxf(1.0, n + o_n); 85 | var_n += WARP_SHFL_XOR(var_n, 1 << i, WARP_SIZE) + (avg - o_avg) * (avg - o_avg) * n * o_n * factor; 86 | avg = (n * avg + o_n * o_avg) * factor; 87 | n += o_n; 88 | } 89 | 90 | // Save mean and variance 91 | if (tid == 0) { 92 | mean[plane] = avg; 93 | var[plane] = var_n / N; 94 | } 95 | } 96 | 97 | template 98 | __global__ void reduce_statistics_kernel( 99 | const at::PackedTensorAccessor all_mean, 100 | const at::PackedTensorAccessor all_var, 101 | const at::PackedTensorAccessor all_count, 102 | at::PackedTensorAccessor mean, 103 | at::PackedTensorAccessor var) { 104 | int num = all_mean.size(0), chn = all_mean.size(1); 105 | auto tid = threadIdx.x, bid = blockIdx.x; 106 | 107 | for (auto c = bid * blockDim.x + tid; c < chn; c += gridDim.x * blockDim.x) { 108 | scalar_t mean_c = 0; 109 | scalar_t var_c = 0; 110 | int64_t count_c = 0; 111 | 112 | for (int n = 0; n < num; ++n) { 113 | auto count_n = all_count[n][0]; 114 | auto mean_n = all_mean[n][c]; 115 | auto var_n = all_var[n][c] * count_n; 116 | 117 | auto delta = mean_n - mean_c; 118 | auto new_count = count_c + count_n; 119 | 120 | mean_c = (count_c * mean_c + count_n * mean_n) / new_count; 121 | var_c += var_n + delta * delta * count_c * count_n / new_count; 122 | count_c = new_count; 123 | } 124 | 125 | mean[c] = mean_c; 126 | var[c] = var_c / count_c; 127 | } 128 | } 129 | 130 | template 131 | __global__ void forward_kernel( 132 | at::PackedTensorAccessor x, 133 | const at::PackedTensorAccessor mean_, 134 | const at::PackedTensorAccessor var, 135 | const at::PackedTensorAccessor weight_, 136 | const at::PackedTensorAccessor bias_, 137 | float eps_, float activation_param) { 138 | index_t c = blockIdx.x; 139 | if (c >= x.size(1)) 
return; 140 | 141 | // Cache channel-wise values 142 | accscalar_t eps = static_cast(eps_); 143 | accscalar_t mean = mean_[c]; 144 | accscalar_t inv_std = accscalar_t(1) / ::sqrt(var[c] + eps); 145 | accscalar_t weight = weight_.size(0) > 0 ? ::abs(static_cast(weight_[c])) + eps : accscalar_t(1); 146 | accscalar_t bias = bias_.size(0) > 0 ? static_cast(bias_[c]) : accscalar_t(0); 147 | 148 | index_t num = x.size(0); 149 | index_t sp = x.size(2); 150 | 151 | index_t step = blockDim.y * gridDim.y; 152 | for (index_t n = threadIdx.y + blockIdx.y * blockDim.y; n < num; n += step) { 153 | auto x_nc = x[n][c]; 154 | 155 | for (index_t s = threadIdx.x; s < sp; s += blockDim.x) { 156 | x_nc[s] = static_cast(weight * (static_cast(x_nc[s]) - mean) * inv_std + bias); 157 | ActivationFn::forward(x_nc[s], activation_param); 158 | } 159 | } 160 | } 161 | 162 | // Functor used in the backward_reduce kernel 163 | template 164 | struct GradOp { 165 | __device__ GradOp(const PTA& y_act, const PTA& dy_act, PTA& xhat, PTA& dy, 166 | accscalar_t inv_gamma, accscalar_t beta, float activation_param) 167 | : y_act(y_act), dy_act(dy_act), xhat(xhat), dy(dy), inv_gamma(inv_gamma), beta(beta), activation_param(activation_param) {} 168 | 169 | __device__ __forceinline__ Float2 operator()(int b, int c, int s) { 170 | const scalar_t y_act_ = y_act[b][c][s]; 171 | const scalar_t dy_act_ = dy_act[b][c][s]; 172 | scalar_t& xhat_ = xhat[b][c][s]; 173 | scalar_t& dy_ = dy[b][c][s]; 174 | 175 | // Invert activation 176 | ActivationFn::backward(y_act_, dy_act_, activation_param, xhat_, dy_); 177 | 178 | // Invert affine transform 179 | xhat_ = (xhat_ - beta) * inv_gamma; 180 | 181 | // Accumulate 182 | accscalar_t xhat_accscalar = static_cast(xhat_); 183 | accscalar_t dy_accscalar = static_cast(dy_); 184 | return Float2(dy_accscalar, xhat_accscalar * dy_accscalar); 185 | } 186 | 187 | const PTA& y_act; 188 | const PTA& dy_act; 189 | PTA& xhat; 190 | PTA& dy; 191 | const accscalar_t inv_gamma; 192 | 
const accscalar_t beta; 193 | const float activation_param; 194 | }; 195 | 196 | template 197 | __global__ void backward_reduce_kernel( 198 | const at::PackedTensorAccessor y_act, 199 | const at::PackedTensorAccessor dy_act, 200 | const at::PackedTensorAccessor weight, 201 | const at::PackedTensorAccessor bias, 202 | at::PackedTensorAccessor xhat, 203 | at::PackedTensorAccessor dy, 204 | at::PackedTensorAccessor sum_dy, 205 | at::PackedTensorAccessor sum_xhat_dy, 206 | float eps_, float activation_param) { 207 | typedef at::PackedTensorAccessor pta_t; 208 | typedef GradOp gradop_t; 209 | index_t c = blockIdx.x; 210 | 211 | accscalar_t eps = static_cast(eps_); 212 | accscalar_t inv_gamma = weight.size(0) > 0 213 | ? accscalar_t(1) / (::abs(static_cast(weight[c])) + eps) 214 | : accscalar_t(1); 215 | accscalar_t beta = bias.size(0) > 0 ? static_cast(bias[c]) : accscalar_t(0); 216 | 217 | gradop_t gop(y_act, dy_act, xhat, dy, inv_gamma, beta, activation_param); 218 | Float2 res = reduce, gradop_t, pta_t>(gop, y_act, c); 219 | 220 | if (threadIdx.x == 0) { 221 | sum_dy[c] = res.v1; 222 | sum_xhat_dy[c] = res.v2; 223 | } 224 | } 225 | 226 | template 227 | __global__ void backward_kernel( 228 | const at::PackedTensorAccessor xhat, 229 | at::PackedTensorAccessor dy, 230 | const at::PackedTensorAccessor var, 231 | const at::PackedTensorAccessor count, 232 | const at::PackedTensorAccessor sum_dy, 233 | const at::PackedTensorAccessor sum_xhat_dy, 234 | const at::PackedTensorAccessor weight_, 235 | float eps_) { 236 | index_t c = blockIdx.x; 237 | if (c >= xhat.size(1)) return; 238 | 239 | // Cache channel-wise values 240 | accscalar_t eps = static_cast(eps_); 241 | accscalar_t mult = weight_.size(0) > 0 242 | ? 
(::abs(static_cast(weight_[c])) + eps) / ::sqrt(var[c] + eps) 243 | : accscalar_t(1) / ::sqrt(var[c] + eps); 244 | 245 | accscalar_t norm = accscalar_t(1) / static_cast(count[0]); 246 | accscalar_t mean_dy_c = sum_dy[c] * norm; 247 | accscalar_t mean_xhat_dy_c = sum_xhat_dy[c] * norm; 248 | 249 | index_t num = xhat.size(0); 250 | index_t sp = xhat.size(2); 251 | 252 | index_t step = blockDim.y * gridDim.y; 253 | for (index_t n = threadIdx.y + blockIdx.y * blockDim.y; n < num; n += step) { 254 | auto xhat_nc = xhat[n][c]; 255 | auto dy_nc = dy[n][c]; 256 | 257 | for (index_t s = threadIdx.x; s < sp; s += blockDim.x) { 258 | dy_nc[s] = static_cast(mult * ( 259 | static_cast(dy_nc[s]) - mean_dy_c - static_cast(xhat_nc[s]) * mean_xhat_dy_c)); 260 | } 261 | } 262 | } 263 | -------------------------------------------------------------------------------- /src/utils.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | #include 4 | 5 | #include 6 | 7 | #include "utils.h" 8 | 9 | /*********************************************************************************************************************** 10 | * Utility functions 11 | **********************************************************************************************************************/ 12 | 13 | at::Tensor normalize_shape(const at::Tensor& x) { 14 | if (x.ndimension() == 1) { 15 | return x.view({1, -1, 1}); 16 | } else { 17 | return x.view({x.size(0), x.size(1), -1}); 18 | } 19 | } 20 | --------------------------------------------------------------------------------