├── .gitignore ├── LICENSE ├── README.md ├── acb.py ├── alexnet.py ├── convert.py ├── convnet_utils.py ├── dbb_transforms.py ├── dbb_verify.py ├── diversebranchblock.py ├── intro.PNG ├── mobilenet.py ├── resnet.py ├── table1.PNG ├── table2.PNG ├── test.py ├── train.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | .idea/deployment.xml 131 | .idea/DiverseBranchBlock.iml 132 | .idea/inspectionProfiles/profiles_settings.xml 133 | .idea/misc.xml 134 | .idea/modules.xml 135 | .idea/vcs.xml 136 | 137 | .idea/* 138 | *nori* -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 
11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Diverse Branch Block: Building a Convolution as an Inception-like Unit (PyTorch) (CVPR-2021) 2 | 3 | DBB is a powerful ConvNet building block to replace regular conv. It improves the performance **without any extra inference-time costs**. This repo contains the code for building DBB and converting it into a single conv. You can also get the equivalent kernel and bias in a differentiable way at any time (get_equivalent_kernel_bias in diversebranchblock.py). This may help training-based pruning or quantization. 4 | 5 | This is the PyTorch implementation. The MegEngine version is at https://github.com/megvii-model/DiverseBranchBlock 6 | 7 | Another nice implementation by [@zjykzj](https://github.com/zjykzj). Please check [here](https://github.com/DingXiaoH/DiverseBranchBlock/issues/20). 8 | 9 | Paper: https://arxiv.org/abs/2103.13425 10 | 11 | Update: released the code for building the block, transformations and verification. 12 | 13 | Update: a more efficient implementation of BNAndPadLayer 14 | 15 | Update: MobileNet, ResNet-18 and ResNet-50 models released. You can download them from Google Drive or Baidu Cloud. For the 1x1-KxK branch of MobileNet, we used internal_channels = 2x input_channels for every depthwise conv. 1x also worked but the accuracy was slightly lower (72.71% v.s. 72.88%). On dense conv like ResNet, we used internal_channels = input_channels, and larger internal_channels seemed useless. 16 | 17 | Sometimes I call it ACNet v2 because 'DBB' is two bits larger than 'ACB' in ASCII. (lol) 18 | 19 | @inproceedings{ding2021diverse, 20 | title={Diverse Branch Block: Building a Convolution as an Inception-like Unit}, 21 | author={Ding, Xiaohan and Zhang, Xiangyu and Han, Jungong and Ding, Guiguang}, 22 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 23 | pages={10886--10895}, 24 | year={2021} 25 | } 26 | 27 | # Abstract 28 | 29 | We propose a universal building block of Convolutional Neural Network (ConvNet) to improve the performance without any inference-time costs. The block is named Diverse Branch Block (DBB), which enhances the representational capacity of a single convolution by combining diverse branches of different scales and complexities to enrich the feature space, including sequences of convolutions, multi-scale convolutions, and average pooling. After training, a DBB can be equivalently converted into a single conv layer for deployment. Unlike the advancements of novel ConvNet architectures, DBB complicates the training-time microstructure while maintaining the macro architecture, so that it can be used as a drop-in replacement for regular conv layers of any architecture. In this way, the model can be trained to reach a higher level of performance and then transformed into the original inference-time structure for inference. 
DBB improves ConvNets on image classification (up to 1.9% higher top-1 accuracy on ImageNet), object detection and semantic segmentation. 30 | 31 | ![image](https://github.com/DingXiaoH/DiverseBranchBlock/blob/main/intro.PNG) 32 | ![image](https://github.com/DingXiaoH/DiverseBranchBlock/blob/main/table1.PNG) 33 | ![image](https://github.com/DingXiaoH/DiverseBranchBlock/blob/main/table2.PNG) 34 | 35 | 36 | # Use our pretrained models 37 | 38 | You may download the models reported in the paper (ResNet-18, ResNet-50, MobileNet) from Google Drive (https://drive.google.com/drive/folders/1BPuqY_ktKz8LvHjFK5abD0qy3ESp8v6H?usp=sharing) or Baidu Cloud (https://pan.baidu.com/s/1wPaQnLKyNjF_bEMNRo4z6Q, the access code is "dbbk"). For ease of transfer learning on other tasks, we provide both training-time and inference-time models. Taking ResNet-18 as an example, assume IMGNET_PATH is the path to a directory that contains the "train" and "val" directories of ImageNet. You may test the accuracy by running 39 | ``` 40 | python test.py IMGNET_PATH train ResNet-18_DBB_7101.pth -a ResNet-18 -t DBB 41 | ``` 42 | Here "train" indicates the training-time structure. 43 | 44 | 45 | # Convert the training-time models into inference-time 46 | 47 | You may convert a trained model into the inference-time structure with 48 | ``` 49 | python convert.py [weights file of the training-time model to load] [path to save] -a [architecture name] 50 | ``` 51 | For example, 52 | ``` 53 | python convert.py ResNet-18_DBB_7101.pth ResNet-18_DBB_7101_deploy.pth -a ResNet-18 54 | ``` 55 | Then you may test the inference-time model by 56 | ``` 57 | python test.py IMGNET_PATH deploy ResNet-18_DBB_7101_deploy.pth -a ResNet-18 -t DBB 58 | ``` 59 | Note that the argument "deploy" builds an inference-time model. 60 | 61 | 62 | # ImageNet training 63 | 64 | The multi-processing training script in this repo is based on the [official PyTorch example](https://github.com/pytorch/examples/blob/master/imagenet/main.py) for simplicity and readability. The modifications include the model-building part and the cosine learning rate scheduler. 65 | You may train and test like this: 66 | ``` 67 | python train.py -a ResNet-18 -t DBB --dist-url tcp://127.0.0.1:23333 --dist-backend nccl --multiprocessing-distributed --world-size 1 --rank 0 --workers 64 IMGNET_PATH 68 | python test.py IMGNET_PATH train model_best.pth.tar -a ResNet-18 69 | ``` 70 | 71 | 72 | # Use like this in your own code 73 | 74 | Assume your model is like 75 | ``` 76 | class SomeModel(nn.Module): 77 | def __init__(self, ...): 78 | ... 79 | self.some_conv = nn.Conv2d(...) 80 | self.some_bn = nn.BatchNorm2d(...) 81 | ... 82 | 83 | def forward(self, inputs): 84 | out = ... 85 | out = self.some_bn(self.some_conv(out)) 86 | ... 87 | ``` 88 | For training, just use DiverseBranchBlock to replace the conv-BN pair. Then SomeModel will be like 89 | ``` 90 | class SomeModel(nn.Module): 91 | def __init__(self, ...): 92 | ... 93 | self.some_dbb = DiverseBranchBlock(..., deploy=False) 94 | ... 95 | 96 | def forward(self, inputs): 97 | out = ... 98 | out = self.some_dbb(out) 99 | ... 100 | ``` 101 | Train the model just like you train any other regular model. Then call **switch_to_deploy** of every DiverseBranchBlock, test, and save. 102 | ``` 103 | model = SomeModel(...)
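# 'train', 'test' and 'save' are placeholders for your own training loop, evaluation and checkpointing code
# switch_to_deploy() fuses each DiverseBranchBlock into a single equivalent KxK conv in place, so the model you test and save below is already the inference-time model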
104 | train(model) 105 | for m in model.modules(): 106 | if hasattr(m, 'switch_to_deploy'): 107 | m.switch_to_deploy() 108 | test(model) 109 | save(model) 110 | ``` 111 | 112 | # FAQs 113 | 114 | **Q**: Is the inference-time model's output the _same_ as the training-time model? 115 | 116 | **A**: Yes. You can verify that by 117 | ``` 118 | python dbb_verify.py 119 | ``` 120 | 121 | **Q**: What is the relationship between DBB and RepVGG? 122 | 123 | **A**: RepVGG is a plain architecture, and the RepVGG-style structural re-parameterization is designed for plain architectures. On a non-plain architecture, a RepVGG block shows no superiority compared to a single 3x3 conv (it improves Res-50 by only 0.03%, as reported in the RepVGG paper). DBB is a universal building block that can be used on numerous architectures. 124 | 125 | **Q**: How do I quantize a model with DBB? 126 | 127 | **A1**: Post-training quantization. After training and conversion, you may quantize the converted model with any post-training quantization method. Then you may insert a BN after the conv converted from a DBB and finetune to recover the accuracy, just like you quantize and finetune other models. This is the recommended solution. Please see the quantization example of [RepVGG](https://github.com/DingXiaoH/RepVGG). 128 | 129 | **A2**: Quantization-aware training. During quantization-aware training, instead of constraining the params of a single kernel (e.g., making every param in {-127, -126, ..., 126, 127} for int8) as for an ordinary conv, you should constrain the equivalent kernel of a DBB (get_equivalent_kernel_bias()). 130 | 131 | **Q**: I tried to finetune your model with multiple GPUs but got an error. Why are the names of params like "xxxx.weight" in the downloaded weight file but sometimes like "module.xxxx.weight" (shown by nn.Module.named_parameters()) in my model? 132 | 133 | **A**: DistributedDataParallel may prefix "module." to the names of params and cause a mismatch when loading weights by name. The simplest solution is to load the weights (model.load_state_dict(...)) before DistributedDataParallel(model). Otherwise, you may insert "module." before the names like this 134 | ``` 135 | checkpoint = torch.load(...) # This is just a name-value dict 136 | ckpt = {('module.' + k) : v for k, v in checkpoint.items()} 137 | model.load_state_dict(ckpt) 138 | ``` 139 | Likewise, if the param names in the checkpoint file start with "module." but those in your model do not, you may strip the names like 140 | ``` 141 | ckpt = {k.replace('module.', ''):v for k,v in checkpoint.items()} # strip the names 142 | model.load_state_dict(ckpt) 143 | ``` 144 | **Q**: So a DBB derives the equivalent KxK kernels before each forward pass to save computations? 145 | 146 | **A**: No! More precisely, we do the conversion only once right after training. Then the training-time model can be discarded, and every resultant block is just a KxK conv. We only save and use the resultant model. 147 | 148 | 149 | ## Contact 150 | 151 | **xiaohding@gmail.com** (The original Tsinghua mailbox dxh17@mails.tsinghua.edu.cn will expire in several months) 152 | 153 | Google Scholar Profile: https://scholar.google.com/citations?user=CIjw0KoAAAAJ&hl=en 154 | 155 | Homepage: https://dingxiaohan.xyz/ 156 | 157 | My open-sourced papers and repos: 158 | 159 | The **Structural Re-parameterization Universe**: 160 | 161 | 1. 
RepLKNet (CVPR 2022) **Powerful efficient architecture with very large kernels (31x31) and guidelines for using large kernels in model CNNs**\ 162 | [Scaling Up Your Kernels to 31x31: Revisiting Large Kernel Design in CNNs](https://arxiv.org/abs/2203.06717)\ 163 | [code](https://github.com/DingXiaoH/RepLKNet-pytorch). 164 | 165 | 2. **RepOptimizer** (ICLR 2023) uses **Gradient Re-parameterization** to train powerful models efficiently. The training-time **RepOpt-VGG** is **as simple as the inference-time**. It also addresses the problem of quantization.\ 166 | [Re-parameterizing Your Optimizers rather than Architectures](https://arxiv.org/pdf/2205.15242.pdf)\ 167 | [code](https://github.com/DingXiaoH/RepOptimizers). 168 | 169 | 3. RepVGG (CVPR 2021) **A super simple and powerful VGG-style ConvNet architecture**. Up to **84.16%** ImageNet top-1 accuracy!\ 170 | [RepVGG: Making VGG-style ConvNets Great Again](https://arxiv.org/abs/2101.03697)\ 171 | [code](https://github.com/DingXiaoH/RepVGG). 172 | 173 | 4. RepMLP (CVPR 2022) **MLP-style building block and Architecture**\ 174 | [RepMLPNet: Hierarchical Vision MLP with Re-parameterized Locality](https://arxiv.org/abs/2112.11081)\ 175 | [code](https://github.com/DingXiaoH/RepMLP). 176 | 177 | 5. ResRep (ICCV 2021) **State-of-the-art** channel pruning (Res50, 55\% FLOPs reduction, 76.15\% acc)\ 178 | [ResRep: Lossless CNN Pruning via Decoupling Remembering and Forgetting](https://openaccess.thecvf.com/content/ICCV2021/papers/Ding_ResRep_Lossless_CNN_Pruning_via_Decoupling_Remembering_and_Forgetting_ICCV_2021_paper.pdf)\ 179 | [code](https://github.com/DingXiaoH/ResRep). 180 | 181 | 6. ACB (ICCV 2019) is a CNN component without any inference-time costs. The first work of our Structural Re-parameterization Universe.\ 182 | [ACNet: Strengthening the Kernel Skeletons for Powerful CNN via Asymmetric Convolution Blocks](http://openaccess.thecvf.com/content_ICCV_2019/papers/Ding_ACNet_Strengthening_the_Kernel_Skeletons_for_Powerful_CNN_via_Asymmetric_ICCV_2019_paper.pdf).\ 183 | [code](https://github.com/DingXiaoH/ACNet). 184 | 185 | 7. DBB (CVPR 2021) is a CNN component with higher performance than ACB and still no inference-time costs. Sometimes I call it ACNet v2 because "DBB" is 2 bits larger than "ACB" in ASCII (lol).\ 186 | [Diverse Branch Block: Building a Convolution as an Inception-like Unit](https://arxiv.org/abs/2103.13425)\ 187 | [code](https://github.com/DingXiaoH/DiverseBranchBlock). 188 | 189 | **Model compression and acceleration**: 190 | 191 | 1. (CVPR 2019) Channel pruning: [Centripetal SGD for Pruning Very Deep Convolutional Networks with Complicated Structure](http://openaccess.thecvf.com/content_CVPR_2019/html/Ding_Centripetal_SGD_for_Pruning_Very_Deep_Convolutional_Networks_With_Complicated_CVPR_2019_paper.html)\ 192 | [code](https://github.com/DingXiaoH/Centripetal-SGD) 193 | 194 | 2. (ICML 2019) Channel pruning: [Approximated Oracle Filter Pruning for Destructive CNN Width Optimization](http://proceedings.mlr.press/v97/ding19a.html)\ 195 | [code](https://github.com/DingXiaoH/AOFP) 196 | 197 | 3. 
(NeurIPS 2019) Unstructured pruning: [Global Sparse Momentum SGD for Pruning Very Deep Neural Networks](http://papers.nips.cc/paper/8867-global-sparse-momentum-sgd-for-pruning-very-deep-neural-networks.pdf)\ 198 | [code](https://github.com/DingXiaoH/GSM-SGD) 199 | 200 | -------------------------------------------------------------------------------- /acb.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.init as init 3 | import torch 4 | 5 | class ACBlock(nn.Module): 6 | 7 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, padding_mode='zeros', deploy=False, 8 | use_affine=True, reduce_gamma=False, gamma_init=None): 9 | super(ACBlock, self).__init__() 10 | self.deploy = deploy 11 | if deploy: 12 | self.fused_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(kernel_size,kernel_size), stride=stride, 13 | padding=padding, dilation=dilation, groups=groups, bias=True, padding_mode=padding_mode) 14 | else: 15 | self.square_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, 16 | kernel_size=(kernel_size, kernel_size), stride=stride, 17 | padding=padding, dilation=dilation, groups=groups, bias=False, 18 | padding_mode=padding_mode) 19 | self.square_bn = nn.BatchNorm2d(num_features=out_channels, affine=use_affine) 20 | 21 | 22 | if padding - kernel_size // 2 >= 0: 23 | # Common use case. E.g., k=3, p=1 or k=5, p=2 24 | self.crop = 0 25 | # Compared to the KxK layer, the padding of the 1xK layer and Kx1 layer should be adjusted to align the sliding windows (Fig 2 in the paper) 26 | hor_padding = [padding - kernel_size // 2, padding] 27 | ver_padding = [padding, padding - kernel_size // 2] 28 | else: 29 | # A negative "padding" (padding - kernel_size//2 < 0, which is not a common use case) is cropping. 
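# For example, kernel_size=5 with padding=1 gives crop = 5 // 2 - 1 = 1, i.e., one pixel is removed from the borders of the inputs fed to the 1xK and Kx1 convs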
30 | # Since nn.Conv2d does not support negative padding, we implement it manually 31 | self.crop = kernel_size // 2 - padding 32 | hor_padding = [0, padding] 33 | ver_padding = [padding, 0] 34 | 35 | self.ver_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(kernel_size, 1), 36 | stride=stride, 37 | padding=ver_padding, dilation=dilation, groups=groups, bias=False, 38 | padding_mode=padding_mode) 39 | 40 | self.hor_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(1, kernel_size), 41 | stride=stride, 42 | padding=hor_padding, dilation=dilation, groups=groups, bias=False, 43 | padding_mode=padding_mode) 44 | self.ver_bn = nn.BatchNorm2d(num_features=out_channels, affine=use_affine) 45 | self.hor_bn = nn.BatchNorm2d(num_features=out_channels, affine=use_affine) 46 | 47 | if reduce_gamma: 48 | self.init_gamma(1.0 / 3) 49 | 50 | if gamma_init is not None: 51 | assert not reduce_gamma 52 | self.init_gamma(gamma_init) 53 | 54 | 55 | def _fuse_bn_tensor(self, conv, bn): 56 | std = (bn.running_var + bn.eps).sqrt() 57 | t = (bn.weight / std).reshape(-1, 1, 1, 1) 58 | return conv.weight * t, bn.bias - bn.running_mean * bn.weight / std 59 | 60 | def _add_to_square_kernel(self, square_kernel, asym_kernel): 61 | asym_h = asym_kernel.size(2) 62 | asym_w = asym_kernel.size(3) 63 | square_h = square_kernel.size(2) 64 | square_w = square_kernel.size(3) 65 | square_kernel[:, :, square_h // 2 - asym_h // 2: square_h // 2 - asym_h // 2 + asym_h, 66 | square_w // 2 - asym_w // 2: square_w // 2 - asym_w // 2 + asym_w] += asym_kernel 67 | 68 | def get_equivalent_kernel_bias(self): 69 | hor_k, hor_b = self._fuse_bn_tensor(self.hor_conv, self.hor_bn) 70 | ver_k, ver_b = self._fuse_bn_tensor(self.ver_conv, self.ver_bn) 71 | square_k, square_b = self._fuse_bn_tensor(self.square_conv, self.square_bn) 72 | self._add_to_square_kernel(square_k, hor_k) 73 | self._add_to_square_kernel(square_k, ver_k) 74 | return square_k, hor_b + ver_b + square_b 75 | 76 | 77 | def switch_to_deploy(self): 78 | deploy_k, deploy_b = self.get_equivalent_kernel_bias() 79 | self.deploy = True 80 | self.fused_conv = nn.Conv2d(in_channels=self.square_conv.in_channels, out_channels=self.square_conv.out_channels, 81 | kernel_size=self.square_conv.kernel_size, stride=self.square_conv.stride, 82 | padding=self.square_conv.padding, dilation=self.square_conv.dilation, groups=self.square_conv.groups, bias=True, 83 | padding_mode=self.square_conv.padding_mode) 84 | self.__delattr__('square_conv') 85 | self.__delattr__('square_bn') 86 | self.__delattr__('hor_conv') 87 | self.__delattr__('hor_bn') 88 | self.__delattr__('ver_conv') 89 | self.__delattr__('ver_bn') 90 | self.fused_conv.weight.data = deploy_k 91 | self.fused_conv.bias.data = deploy_b 92 | 93 | 94 | def init_gamma(self, gamma_value): 95 | init.constant_(self.square_bn.weight, gamma_value) 96 | init.constant_(self.ver_bn.weight, gamma_value) 97 | init.constant_(self.hor_bn.weight, gamma_value) 98 | print('init gamma of square, ver and hor as ', gamma_value) 99 | 100 | def single_init(self): 101 | init.constant_(self.square_bn.weight, 1.0) 102 | init.constant_(self.ver_bn.weight, 0.0) 103 | init.constant_(self.hor_bn.weight, 0.0) 104 | print('init gamma of square as 1, ver and hor as 0') 105 | 106 | def forward(self, input): 107 | if self.deploy: 108 | return self.fused_conv(input) 109 | else: 110 | square_outputs = self.square_conv(input) 111 | square_outputs = self.square_bn(square_outputs) 112 | if self.crop > 0: 113 | 
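# Emulate the negative padding by cropping: the Kx1 (vertical) conv gets the width-cropped input, and the 1xK (horizontal) conv gets the height-cropped input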
ver_input = input[:, :, :, self.crop:-self.crop] 114 | hor_input = input[:, :, self.crop:-self.crop, :] 115 | else: 116 | ver_input = input 117 | hor_input = input 118 | vertical_outputs = self.ver_conv(ver_input) 119 | vertical_outputs = self.ver_bn(vertical_outputs) 120 | horizontal_outputs = self.hor_conv(hor_input) 121 | horizontal_outputs = self.hor_bn(horizontal_outputs) 122 | result = square_outputs + vertical_outputs + horizontal_outputs 123 | return result 124 | 125 | if __name__ == '__main__': 126 | N = 1 127 | C = 2 128 | H = 62 129 | W = 62 130 | O = 8 131 | groups = 4 132 | 133 | x = torch.randn(N, C, H, W) 134 | print('input shape is ', x.size()) 135 | 136 | test_kernel_padding = [(3,1), (3,0), (5,1), (5,2), (5,3), (5,4), (5,6)] 137 | 138 | for k, p in test_kernel_padding: 139 | acb = ACBlock(C, O, kernel_size=k, padding=p, stride=1, deploy=False) 140 | acb.eval() 141 | for module in acb.modules(): 142 | if isinstance(module, nn.BatchNorm2d): 143 | nn.init.uniform_(module.running_mean, 0, 0.1) 144 | nn.init.uniform_(module.running_var, 0, 0.2) 145 | nn.init.uniform_(module.weight, 0, 0.3) 146 | nn.init.uniform_(module.bias, 0, 0.4) 147 | out = acb(x) 148 | acb.switch_to_deploy() 149 | deployout = acb(x) 150 | print('difference between the outputs of the training-time and converted ACB is') 151 | print(((deployout - out) ** 2).sum()) 152 | 153 | -------------------------------------------------------------------------------- /alexnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | from convnet_utils import conv_bn, conv_bn_relu 4 | 5 | def create_stem(channels): 6 | stem = nn.Sequential() 7 | stem.add_module('conv1', conv_bn_relu(in_channels=3, out_channels=channels[0], kernel_size=11, stride=4, padding=2)) 8 | stem.add_module('maxpool1', nn.MaxPool2d(kernel_size=3, stride=2)) 9 | stem.add_module('conv2', conv_bn_relu(in_channels=channels[0], out_channels=channels[1], kernel_size=5, padding=2)) 10 | stem.add_module('maxpool2', nn.MaxPool2d(kernel_size=3, stride=2)) 11 | stem.add_module('conv3', conv_bn_relu(in_channels=channels[1], out_channels=channels[2], kernel_size=3, padding=1)) 12 | stem.add_module('conv4', conv_bn_relu(in_channels=channels[2], out_channels=channels[3], kernel_size=3, padding=1)) 13 | stem.add_module('conv5', conv_bn_relu(in_channels=channels[3], out_channels=channels[4], kernel_size=3, padding=1)) 14 | stem.add_module('maxpool3', nn.MaxPool2d(kernel_size=3, stride=2)) 15 | return stem 16 | 17 | class AlexNet(nn.Module): 18 | 19 | def __init__(self): 20 | super(AlexNet, self).__init__() 21 | channels = [64, 192, 384, 384, 256] 22 | self.stem = create_stem(channels) 23 | self.linear1 = nn.Linear(in_features=channels[4] * 6 * 6, out_features=4096) 24 | self.relu1 = nn.ReLU() 25 | self.drop1 = nn.Dropout(0.5) 26 | self.linear2 = nn.Linear(in_features=4096, out_features=4096) 27 | self.relu2 = nn.ReLU() 28 | self.drop2 = nn.Dropout(0.5) 29 | self.linear3 = nn.Linear(in_features=4096, out_features=1000) 30 | 31 | def forward(self, x): 32 | out = self.stem(x) 33 | out = out.view(out.size(0), -1) 34 | out = self.linear1(out) 35 | out = self.relu1(out) 36 | out = self.drop1(out) 37 | out = self.linear2(out) 38 | out = self.relu2(out) 39 | out = self.drop2(out) 40 | out = self.linear3(out) 41 | return out 42 | 43 | def create_AlexNet(): 44 | return AlexNet() -------------------------------------------------------------------------------- /convert.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | from convnet_utils import switch_conv_bn_impl, switch_deploy_flag, build_model 5 | 6 | parser = argparse.ArgumentParser(description='DBB Conversion') 7 | parser.add_argument('load', metavar='LOAD', help='path to the weights file') 8 | parser.add_argument('save', metavar='SAVE', help='path to the weights file') 9 | parser.add_argument('-a', '--arch', metavar='ARCH', default='ResNet-18') 10 | 11 | def convert(): 12 | args = parser.parse_args() 13 | 14 | switch_conv_bn_impl('DBB') 15 | switch_deploy_flag(False) 16 | train_model = build_model(args.arch) 17 | 18 | if 'hdf5' in args.load: 19 | from utils import model_load_hdf5 20 | model_load_hdf5(train_model, args.load) 21 | elif os.path.isfile(args.load): 22 | print("=> loading checkpoint '{}'".format(args.load)) 23 | checkpoint = torch.load(args.load) 24 | if 'state_dict' in checkpoint: 25 | checkpoint = checkpoint['state_dict'] 26 | ckpt = {k.replace('module.', ''): v for k, v in checkpoint.items()} # strip the names 27 | train_model.load_state_dict(ckpt) 28 | else: 29 | print("=> no checkpoint found at '{}'".format(args.load)) 30 | 31 | for m in train_model.modules(): 32 | if hasattr(m, 'switch_to_deploy'): 33 | m.switch_to_deploy() 34 | 35 | torch.save(train_model.state_dict(), args.save) 36 | 37 | 38 | if __name__ == '__main__': 39 | convert() -------------------------------------------------------------------------------- /convnet_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from diversebranchblock import DiverseBranchBlock 4 | from acb import ACBlock 5 | from dbb_transforms import transI_fusebn 6 | 7 | CONV_BN_IMPL = 'base' 8 | 9 | DEPLOY_FLAG = False 10 | 11 | class ConvBN(nn.Module): 12 | def __init__(self, in_channels, out_channels, kernel_size, 13 | stride, padding, dilation, groups, deploy=False, nonlinear=None): 14 | super().__init__() 15 | if nonlinear is None: 16 | self.nonlinear = nn.Identity() 17 | else: 18 | self.nonlinear = nonlinear 19 | if deploy: 20 | self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, 21 | stride=stride, padding=padding, dilation=dilation, groups=groups, bias=True) 22 | else: 23 | self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, 24 | stride=stride, padding=padding, dilation=dilation, groups=groups, bias=False) 25 | self.bn = nn.BatchNorm2d(num_features=out_channels) 26 | 27 | def forward(self, x): 28 | if hasattr(self, 'bn'): 29 | return self.nonlinear(self.bn(self.conv(x))) 30 | else: 31 | return self.nonlinear(self.conv(x)) 32 | 33 | def switch_to_deploy(self): 34 | kernel, bias = transI_fusebn(self.conv.weight, self.bn) 35 | conv = nn.Conv2d(in_channels=self.conv.in_channels, out_channels=self.conv.out_channels, kernel_size=self.conv.kernel_size, 36 | stride=self.conv.stride, padding=self.conv.padding, dilation=self.conv.dilation, groups=self.conv.groups, bias=True) 37 | conv.weight.data = kernel 38 | conv.bias.data = bias 39 | for para in self.parameters(): 40 | para.detach_() 41 | self.__delattr__('conv') 42 | self.__delattr__('bn') 43 | self.conv = conv 44 | 45 | 46 | def conv_bn(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1): 47 | if CONV_BN_IMPL == 'base' or kernel_size == 1 or kernel_size >= 7: 48 | blk_type = ConvBN 49 | elif CONV_BN_IMPL 
== 'ACB': 50 | blk_type = ACBlock 51 | else: 52 | blk_type = DiverseBranchBlock 53 | return blk_type(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, 54 | padding=padding, dilation=dilation, groups=groups, deploy=DEPLOY_FLAG) 55 | 56 | def conv_bn_relu(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1): 57 | if CONV_BN_IMPL == 'base' or kernel_size == 1 or kernel_size >= 7: 58 | blk_type = ConvBN 59 | elif CONV_BN_IMPL == 'ACB': 60 | blk_type = ACBlock 61 | else: 62 | blk_type = DiverseBranchBlock 63 | return blk_type(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, 64 | padding=padding, dilation=dilation, groups=groups, deploy=DEPLOY_FLAG, nonlinear=nn.ReLU()) 65 | 66 | 67 | def switch_conv_bn_impl(block_type): 68 | assert block_type in ['base', 'DBB', 'ACB'] 69 | global CONV_BN_IMPL 70 | CONV_BN_IMPL = block_type 71 | 72 | def switch_deploy_flag(deploy): 73 | global DEPLOY_FLAG 74 | DEPLOY_FLAG = deploy 75 | print('deploy flag: ', DEPLOY_FLAG) 76 | 77 | 78 | def build_model(arch): 79 | if arch == 'ResNet-18': 80 | from resnet import create_Res18 81 | model = create_Res18() 82 | elif arch == 'ResNet-50': 83 | from resnet import create_Res50 84 | model = create_Res50() 85 | elif arch == 'MobileNet': 86 | from mobilenet import create_MobileNet 87 | model = create_MobileNet() 88 | else: 89 | raise ValueError('TODO') 90 | return model -------------------------------------------------------------------------------- /dbb_transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn.functional as F 4 | 5 | def transI_fusebn(kernel, bn): 6 | gamma = bn.weight 7 | std = (bn.running_var + bn.eps).sqrt() 8 | return kernel * ((gamma / std).reshape(-1, 1, 1, 1)), bn.bias - bn.running_mean * gamma / std 9 | 10 | def transII_addbranch(kernels, biases): 11 | return sum(kernels), sum(biases) 12 | 13 | def transIII_1x1_kxk(k1, b1, k2, b2, groups): 14 | if groups == 1: 15 | k = F.conv2d(k2, k1.permute(1, 0, 2, 3)) # 16 | b_hat = (k2 * b1.reshape(1, -1, 1, 1)).sum((1, 2, 3)) 17 | else: 18 | k_slices = [] 19 | b_slices = [] 20 | k1_T = k1.permute(1, 0, 2, 3) 21 | k1_group_width = k1.size(0) // groups 22 | k2_group_width = k2.size(0) // groups 23 | for g in range(groups): 24 | k1_T_slice = k1_T[:, g*k1_group_width:(g+1)*k1_group_width, :, :] 25 | k2_slice = k2[g*k2_group_width:(g+1)*k2_group_width, :, :, :] 26 | k_slices.append(F.conv2d(k2_slice, k1_T_slice)) 27 | b_slices.append((k2_slice * b1[g*k1_group_width:(g+1)*k1_group_width].reshape(1, -1, 1, 1)).sum((1, 2, 3))) 28 | k, b_hat = transIV_depthconcat(k_slices, b_slices) 29 | return k, b_hat + b2 30 | 31 | def transIV_depthconcat(kernels, biases): 32 | return torch.cat(kernels, dim=0), torch.cat(biases) 33 | 34 | def transV_avg(channels, kernel_size, groups): 35 | input_dim = channels // groups 36 | k = torch.zeros((channels, input_dim, kernel_size, kernel_size)) 37 | k[np.arange(channels), np.tile(np.arange(input_dim), groups), :, :] = 1.0 / kernel_size ** 2 38 | return k 39 | 40 | # This has not been tested with non-square kernels (kernel.size(2) != kernel.size(3)) nor even-size kernels 41 | def transVI_multiscale(kernel, target_kernel_size): 42 | H_pixels_to_pad = (target_kernel_size - kernel.size(2)) // 2 43 | W_pixels_to_pad = (target_kernel_size - kernel.size(3)) // 2 44 | return F.pad(kernel, [H_pixels_to_pad, H_pixels_to_pad, 
W_pixels_to_pad, W_pixels_to_pad]) -------------------------------------------------------------------------------- /dbb_verify.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from diversebranchblock import DiverseBranchBlock 4 | 5 | 6 | if __name__ == '__main__': 7 | x = torch.randn(1, 32, 56, 56) 8 | for k in (3, 5): 9 | for s in (1, 2): 10 | dbb = DiverseBranchBlock(in_channels=32, out_channels=64, kernel_size=k, stride=s, padding=k//2, 11 | groups=2, deploy=False) 12 | for module in dbb.modules(): 13 | if isinstance(module, torch.nn.BatchNorm2d): 14 | nn.init.uniform_(module.running_mean, 0, 0.1) 15 | nn.init.uniform_(module.running_var, 0, 0.1) 16 | nn.init.uniform_(module.weight, 0, 0.1) 17 | nn.init.uniform_(module.bias, 0, 0.1) 18 | dbb.eval() 19 | print(dbb) 20 | train_y = dbb(x) 21 | dbb.switch_to_deploy() 22 | deploy_y = dbb(x) 23 | print(dbb) 24 | print('========================== The diff is') 25 | print(((train_y - deploy_y) ** 2).sum()) -------------------------------------------------------------------------------- /diversebranchblock.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from dbb_transforms import * 5 | 6 | def conv_bn(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, 7 | padding_mode='zeros'): 8 | conv_layer = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, 9 | stride=stride, padding=padding, dilation=dilation, groups=groups, 10 | bias=False, padding_mode=padding_mode) 11 | bn_layer = nn.BatchNorm2d(num_features=out_channels, affine=True) 12 | se = nn.Sequential() 13 | se.add_module('conv', conv_layer) 14 | se.add_module('bn', bn_layer) 15 | return se 16 | 17 | 18 | class IdentityBasedConv1x1(nn.Conv2d): 19 | 20 | def __init__(self, channels, groups=1): 21 | super(IdentityBasedConv1x1, self).__init__(in_channels=channels, out_channels=channels, kernel_size=1, stride=1, padding=0, groups=groups, bias=False) 22 | 23 | assert channels % groups == 0 24 | input_dim = channels // groups 25 | id_value = np.zeros((channels, input_dim, 1, 1)) 26 | for i in range(channels): 27 | id_value[i, i % input_dim, 0, 0] = 1 28 | self.id_tensor = torch.from_numpy(id_value).type_as(self.weight) 29 | nn.init.zeros_(self.weight) 30 | 31 | def forward(self, input): 32 | kernel = self.weight + self.id_tensor.to(self.weight.device) 33 | result = F.conv2d(input, kernel, None, stride=1, padding=0, dilation=self.dilation, groups=self.groups) 34 | return result 35 | 36 | def get_actual_kernel(self): 37 | return self.weight + self.id_tensor.to(self.weight.device) 38 | 39 | 40 | class BNAndPadLayer(nn.Module): 41 | def __init__(self, 42 | pad_pixels, 43 | num_features, 44 | eps=1e-5, 45 | momentum=0.1, 46 | affine=True, 47 | track_running_stats=True): 48 | super(BNAndPadLayer, self).__init__() 49 | self.bn = nn.BatchNorm2d(num_features, eps, momentum, affine, track_running_stats) 50 | self.pad_pixels = pad_pixels 51 | 52 | def forward(self, input): 53 | output = self.bn(input) 54 | if self.pad_pixels > 0: 55 | if self.bn.affine: 56 | pad_values = self.bn.bias.detach() - self.bn.running_mean * self.bn.weight.detach() / torch.sqrt(self.bn.running_var + self.bn.eps) 57 | else: 58 | pad_values = - self.bn.running_mean / torch.sqrt(self.bn.running_var + self.bn.eps) 59 | output = F.pad(output, [self.pad_pixels] * 4) 60 | 
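# Fill the padded border with BN(0), i.e. bias - running_mean * weight / sqrt(running_var + eps), so that (at inference) applying BN and then this padding is equivalent to zero-padding the input before the BN, which is what the fused conv's own padding will reproduce after re-parameterization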
pad_values = pad_values.view(1, -1, 1, 1) 61 | output[:, :, 0:self.pad_pixels, :] = pad_values 62 | output[:, :, -self.pad_pixels:, :] = pad_values 63 | output[:, :, :, 0:self.pad_pixels] = pad_values 64 | output[:, :, :, -self.pad_pixels:] = pad_values 65 | return output 66 | 67 | @property 68 | def weight(self): 69 | return self.bn.weight 70 | 71 | @property 72 | def bias(self): 73 | return self.bn.bias 74 | 75 | @property 76 | def running_mean(self): 77 | return self.bn.running_mean 78 | 79 | @property 80 | def running_var(self): 81 | return self.bn.running_var 82 | 83 | @property 84 | def eps(self): 85 | return self.bn.eps 86 | 87 | 88 | class DiverseBranchBlock(nn.Module): 89 | 90 | def __init__(self, in_channels, out_channels, kernel_size, 91 | stride=1, padding=0, dilation=1, groups=1, 92 | internal_channels_1x1_3x3=None, 93 | deploy=False, nonlinear=None, single_init=False): 94 | super(DiverseBranchBlock, self).__init__() 95 | self.deploy = deploy 96 | 97 | if nonlinear is None: 98 | self.nonlinear = nn.Identity() 99 | else: 100 | self.nonlinear = nonlinear 101 | 102 | self.kernel_size = kernel_size 103 | self.out_channels = out_channels 104 | self.groups = groups 105 | assert padding == kernel_size // 2 106 | 107 | if deploy: 108 | self.dbb_reparam = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, 109 | padding=padding, dilation=dilation, groups=groups, bias=True) 110 | 111 | else: 112 | 113 | self.dbb_origin = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups) 114 | 115 | self.dbb_avg = nn.Sequential() 116 | if groups < out_channels: 117 | self.dbb_avg.add_module('conv', 118 | nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, 119 | stride=1, padding=0, groups=groups, bias=False)) 120 | self.dbb_avg.add_module('bn', BNAndPadLayer(pad_pixels=padding, num_features=out_channels)) 121 | self.dbb_avg.add_module('avg', nn.AvgPool2d(kernel_size=kernel_size, stride=stride, padding=0)) 122 | self.dbb_1x1 = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=stride, 123 | padding=0, groups=groups) 124 | else: 125 | self.dbb_avg.add_module('avg', nn.AvgPool2d(kernel_size=kernel_size, stride=stride, padding=padding)) 126 | 127 | self.dbb_avg.add_module('avgbn', nn.BatchNorm2d(out_channels)) 128 | 129 | 130 | if internal_channels_1x1_3x3 is None: 131 | internal_channels_1x1_3x3 = in_channels if groups < out_channels else 2 * in_channels # For mobilenet, it is better to have 2X internal channels 132 | 133 | self.dbb_1x1_kxk = nn.Sequential() 134 | if internal_channels_1x1_3x3 == in_channels: 135 | self.dbb_1x1_kxk.add_module('idconv1', IdentityBasedConv1x1(channels=in_channels, groups=groups)) 136 | else: 137 | self.dbb_1x1_kxk.add_module('conv1', nn.Conv2d(in_channels=in_channels, out_channels=internal_channels_1x1_3x3, 138 | kernel_size=1, stride=1, padding=0, groups=groups, bias=False)) 139 | self.dbb_1x1_kxk.add_module('bn1', BNAndPadLayer(pad_pixels=padding, num_features=internal_channels_1x1_3x3, affine=True)) 140 | self.dbb_1x1_kxk.add_module('conv2', nn.Conv2d(in_channels=internal_channels_1x1_3x3, out_channels=out_channels, 141 | kernel_size=kernel_size, stride=stride, padding=0, groups=groups, bias=False)) 142 | self.dbb_1x1_kxk.add_module('bn2', nn.BatchNorm2d(out_channels)) 143 | 144 | # The experiments reported in the paper used the default initialization of bn.weight (all 
as 1). But changing the initialization may be useful in some cases. 145 | if single_init: 146 | # Initialize the bn.weight of dbb_origin as 1 and others as 0. This is not the default setting. 147 | self.single_init() 148 | 149 | def get_equivalent_kernel_bias(self): 150 | k_origin, b_origin = transI_fusebn(self.dbb_origin.conv.weight, self.dbb_origin.bn) 151 | 152 | if hasattr(self, 'dbb_1x1'): 153 | k_1x1, b_1x1 = transI_fusebn(self.dbb_1x1.conv.weight, self.dbb_1x1.bn) 154 | k_1x1 = transVI_multiscale(k_1x1, self.kernel_size) 155 | else: 156 | k_1x1, b_1x1 = 0, 0 157 | 158 | if hasattr(self.dbb_1x1_kxk, 'idconv1'): 159 | k_1x1_kxk_first = self.dbb_1x1_kxk.idconv1.get_actual_kernel() 160 | else: 161 | k_1x1_kxk_first = self.dbb_1x1_kxk.conv1.weight 162 | k_1x1_kxk_first, b_1x1_kxk_first = transI_fusebn(k_1x1_kxk_first, self.dbb_1x1_kxk.bn1) 163 | k_1x1_kxk_second, b_1x1_kxk_second = transI_fusebn(self.dbb_1x1_kxk.conv2.weight, self.dbb_1x1_kxk.bn2) 164 | k_1x1_kxk_merged, b_1x1_kxk_merged = transIII_1x1_kxk(k_1x1_kxk_first, b_1x1_kxk_first, k_1x1_kxk_second, b_1x1_kxk_second, groups=self.groups) 165 | 166 | k_avg = transV_avg(self.out_channels, self.kernel_size, self.groups) 167 | k_1x1_avg_second, b_1x1_avg_second = transI_fusebn(k_avg.to(self.dbb_avg.avgbn.weight.device), self.dbb_avg.avgbn) 168 | if hasattr(self.dbb_avg, 'conv'): 169 | k_1x1_avg_first, b_1x1_avg_first = transI_fusebn(self.dbb_avg.conv.weight, self.dbb_avg.bn) 170 | k_1x1_avg_merged, b_1x1_avg_merged = transIII_1x1_kxk(k_1x1_avg_first, b_1x1_avg_first, k_1x1_avg_second, b_1x1_avg_second, groups=self.groups) 171 | else: 172 | k_1x1_avg_merged, b_1x1_avg_merged = k_1x1_avg_second, b_1x1_avg_second 173 | 174 | return transII_addbranch((k_origin, k_1x1, k_1x1_kxk_merged, k_1x1_avg_merged), (b_origin, b_1x1, b_1x1_kxk_merged, b_1x1_avg_merged)) 175 | 176 | def switch_to_deploy(self): 177 | if hasattr(self, 'dbb_reparam'): 178 | return 179 | kernel, bias = self.get_equivalent_kernel_bias() 180 | self.dbb_reparam = nn.Conv2d(in_channels=self.dbb_origin.conv.in_channels, out_channels=self.dbb_origin.conv.out_channels, 181 | kernel_size=self.dbb_origin.conv.kernel_size, stride=self.dbb_origin.conv.stride, 182 | padding=self.dbb_origin.conv.padding, dilation=self.dbb_origin.conv.dilation, groups=self.dbb_origin.conv.groups, bias=True) 183 | self.dbb_reparam.weight.data = kernel 184 | self.dbb_reparam.bias.data = bias 185 | for para in self.parameters(): 186 | para.detach_() 187 | self.__delattr__('dbb_origin') 188 | self.__delattr__('dbb_avg') 189 | if hasattr(self, 'dbb_1x1'): 190 | self.__delattr__('dbb_1x1') 191 | self.__delattr__('dbb_1x1_kxk') 192 | 193 | def forward(self, inputs): 194 | 195 | if hasattr(self, 'dbb_reparam'): 196 | return self.nonlinear(self.dbb_reparam(inputs)) 197 | 198 | out = self.dbb_origin(inputs) 199 | if hasattr(self, 'dbb_1x1'): 200 | out += self.dbb_1x1(inputs) 201 | out += self.dbb_avg(inputs) 202 | out += self.dbb_1x1_kxk(inputs) 203 | return self.nonlinear(out) 204 | 205 | def init_gamma(self, gamma_value): 206 | if hasattr(self, "dbb_origin"): 207 | torch.nn.init.constant_(self.dbb_origin.bn.weight, gamma_value) 208 | if hasattr(self, "dbb_1x1"): 209 | torch.nn.init.constant_(self.dbb_1x1.bn.weight, gamma_value) 210 | if hasattr(self, "dbb_avg"): 211 | torch.nn.init.constant_(self.dbb_avg.avgbn.weight, gamma_value) 212 | if hasattr(self, "dbb_1x1_kxk"): 213 | torch.nn.init.constant_(self.dbb_1x1_kxk.bn2.weight, gamma_value) 214 | 215 | def single_init(self): 216 | self.init_gamma(0.0) 217 | if 
hasattr(self, "dbb_origin"): 218 | torch.nn.init.constant_(self.dbb_origin.bn.weight, 1.0) 219 | -------------------------------------------------------------------------------- /intro.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DingXiaoH/DiverseBranchBlock/8d2b16b6aee45a33236b2d11685be6857f9ba929/intro.PNG -------------------------------------------------------------------------------- /mobilenet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | from convnet_utils import conv_bn_relu 4 | MOBILE_CHANNELS = [32, 5 | 32, 64, 6 | 64, 128, 7 | 128, 128, 8 | 128, 256, 9 | 256, 256, 10 | 256, 512, 11 | 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 12 | 512, 1024, 13 | 1024, 1024] 14 | 15 | class MobileV1Block(nn.Module): 16 | '''Depthwise conv + Pointwise conv''' 17 | def __init__(self, in_planes, out_planes, stride=1): 18 | super(MobileV1Block, self).__init__() 19 | self.depthwise = conv_bn_relu(in_channels=in_planes, out_channels=in_planes, kernel_size=3, 20 | stride=stride, padding=1, groups=in_planes) 21 | self.pointwise = conv_bn_relu(in_channels=in_planes, out_channels=out_planes, kernel_size=1, 22 | stride=1, padding=0) 23 | 24 | def forward(self, x): 25 | out = self.depthwise(x) 26 | out = self.pointwise(out) 27 | return out 28 | 29 | 30 | class MobileV1(nn.Module): 31 | 32 | def __init__(self, num_classes): 33 | super(MobileV1, self).__init__() 34 | channels = MOBILE_CHANNELS 35 | assert len(channels) == 27 36 | self.conv1 = conv_bn_relu(in_channels=3, out_channels=channels[0], kernel_size=3, stride=2, padding=1) 37 | blocks = [] 38 | for block_idx in range(13): 39 | depthwise_channels = int(channels[block_idx * 2 + 1]) 40 | pointwise_channels = int(channels[block_idx * 2 + 2]) 41 | stride = 2 if block_idx in [1, 3, 5, 11] else 1 42 | blocks.append(MobileV1Block(in_planes=depthwise_channels, out_planes=pointwise_channels, stride=stride)) 43 | self.stem = nn.Sequential(*blocks) 44 | self.gap = nn.AdaptiveAvgPool2d(output_size=1) 45 | self.linear = nn.Linear(channels[-1], num_classes) 46 | 47 | def forward(self, x): 48 | out = self.conv1(x) 49 | out = self.stem(out) 50 | out = self.gap(out) 51 | out = out.view(out.size(0), -1) 52 | out = self.linear(out) 53 | return out 54 | 55 | def create_MobileNet(): 56 | return MobileV1(num_classes=1000) -------------------------------------------------------------------------------- /resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | from convnet_utils import conv_bn, conv_bn_relu 4 | 5 | class BasicBlock(nn.Module): 6 | expansion = 1 7 | def __init__(self, in_planes, planes, stride=1): 8 | super(BasicBlock, self).__init__() 9 | if stride != 1 or in_planes != self.expansion * planes: 10 | self.shortcut = conv_bn(in_channels=in_planes, out_channels=self.expansion * planes, kernel_size=1, stride=stride) 11 | else: 12 | self.shortcut = nn.Identity() 13 | self.conv1 = conv_bn_relu(in_channels=in_planes, out_channels=planes, kernel_size=3, stride=stride, padding=1) 14 | self.conv2 = conv_bn(in_channels=planes, out_channels=self.expansion * planes, kernel_size=3, stride=1, padding=1) 15 | 16 | def forward(self, x): 17 | out = self.conv1(x) 18 | out = self.conv2(out) 19 | out = out + self.shortcut(x) 20 | out = F.relu(out) 21 | return out 22 | 23 | 24 | class 
Bottleneck(nn.Module): 25 | expansion = 4 26 | def __init__(self, in_planes, planes, stride=1): 27 | super(Bottleneck, self).__init__() 28 | 29 | if stride != 1 or in_planes != self.expansion*planes: 30 | self.shortcut = conv_bn(in_planes, self.expansion*planes, kernel_size=1, stride=stride) 31 | else: 32 | self.shortcut = nn.Identity() 33 | 34 | self.conv1 = conv_bn_relu(in_planes, planes, kernel_size=1) 35 | self.conv2 = conv_bn_relu(planes, planes, kernel_size=3, stride=stride, padding=1) 36 | self.conv3 = conv_bn(planes, self.expansion*planes, kernel_size=1) 37 | 38 | def forward(self, x): 39 | out = self.conv1(x) 40 | out = self.conv2(out) 41 | out = self.conv3(out) 42 | out += self.shortcut(x) 43 | out = F.relu(out) 44 | return out 45 | 46 | 47 | class ResNet(nn.Module): 48 | def __init__(self, block, num_blocks, num_classes=1000, width_multiplier=1): 49 | super(ResNet, self).__init__() 50 | 51 | self.in_planes = int(64 * width_multiplier) 52 | self.stage0 = nn.Sequential() 53 | self.stage0.add_module('conv1', conv_bn_relu(in_channels=3, out_channels=self.in_planes, kernel_size=7, stride=2, padding=3)) 54 | self.stage0.add_module('maxpool', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)) 55 | self.stage1 = self._make_stage(block, int(64 * width_multiplier), num_blocks[0], stride=1) 56 | self.stage2 = self._make_stage(block, int(128 * width_multiplier), num_blocks[1], stride=2) 57 | self.stage3 = self._make_stage(block, int(256 * width_multiplier), num_blocks[2], stride=2) 58 | self.stage4 = self._make_stage(block, int(512 * width_multiplier), num_blocks[3], stride=2) 59 | self.gap = nn.AdaptiveAvgPool2d(output_size=1) 60 | self.linear = nn.Linear(int(512*block.expansion*width_multiplier), num_classes) 61 | 62 | def _make_stage(self, block, planes, num_blocks, stride): 63 | strides = [stride] + [1]*(num_blocks-1) 64 | blocks = [] 65 | for stride in strides: 66 | if block is Bottleneck: 67 | blocks.append(block(in_planes=self.in_planes, planes=int(planes), stride=stride)) 68 | else: 69 | blocks.append(block(in_planes=self.in_planes, planes=int(planes), stride=stride)) 70 | self.in_planes = int(planes * block.expansion) 71 | return nn.Sequential(*blocks) 72 | 73 | def forward(self, x): 74 | out = self.stage0(x) 75 | out = self.stage1(out) 76 | out = self.stage2(out) 77 | out = self.stage3(out) 78 | out = self.stage4(out) 79 | out = self.gap(out) 80 | out = out.view(out.size(0), -1) 81 | out = self.linear(out) 82 | return out 83 | 84 | 85 | def create_Res18(): 86 | return ResNet(BasicBlock, [2,2,2,2], num_classes=1000, width_multiplier=1) 87 | 88 | 89 | def create_Res50(): 90 | return ResNet(Bottleneck, [3,4,6,3], num_classes=1000, width_multiplier=1) -------------------------------------------------------------------------------- /table1.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DingXiaoH/DiverseBranchBlock/8d2b16b6aee45a33236b2d11685be6857f9ba929/table1.PNG -------------------------------------------------------------------------------- /table2.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DingXiaoH/DiverseBranchBlock/8d2b16b6aee45a33236b2d11685be6857f9ba929/table2.PNG -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import time 4 | import torch 5 | import torch.backends.cudnn 
as cudnn 6 | import torchvision.datasets as datasets 7 | from utils import accuracy, ProgressMeter, AverageMeter, val_preprocess 8 | from convnet_utils import switch_deploy_flag, switch_conv_bn_impl, build_model 9 | 10 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Test') 11 | parser.add_argument('data', metavar='DIR', help='path to dataset') 12 | parser.add_argument('mode', metavar='MODE', default='train', choices=['train', 'deploy'], help='train or deploy') 13 | parser.add_argument('weights', metavar='WEIGHTS', help='path to the weights file') 14 | parser.add_argument('-a', '--arch', metavar='ARCH', default='ResNet-18') 15 | parser.add_argument('-t', '--blocktype', metavar='BLK', default='DBB', choices=['DBB', 'ACB', 'base']) 16 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 17 | help='number of data loading workers (default: 4)') 18 | parser.add_argument('-b', '--batch-size', default=100, type=int, 19 | metavar='N', 20 | help='mini-batch size (default: 100) for test') 21 | 22 | def test(): 23 | args = parser.parse_args() 24 | 25 | switch_deploy_flag(args.mode == 'deploy') 26 | switch_conv_bn_impl(args.blocktype) 27 | model = build_model(args.arch) 28 | 29 | if not torch.cuda.is_available(): 30 | print('using CPU, this will be slow') 31 | use_gpu = False 32 | else: 33 | model = model.cuda() 34 | use_gpu = True 35 | 36 | # define loss function (criterion) 37 | criterion = torch.nn.CrossEntropyLoss().cuda() if use_gpu else torch.nn.CrossEntropyLoss() 38 | 39 | if 'hdf5' in args.weights: 40 | from utils import model_load_hdf5 41 | model_load_hdf5(model, args.weights) 42 | elif os.path.isfile(args.weights): 43 | print("=> loading checkpoint '{}'".format(args.weights)) 44 | checkpoint = torch.load(args.weights, map_location=None if use_gpu else 'cpu') 45 | if 'state_dict' in checkpoint: 46 | checkpoint = checkpoint['state_dict'] 47 | ckpt = {k.replace('module.', ''):v for k,v in checkpoint.items()} # strip the 'module.' prefix added by DataParallel 48 | model.load_state_dict(ckpt) 49 | else: 50 | print("=> no checkpoint found at '{}'".format(args.weights)) 51 | 52 | 53 | cudnn.benchmark = True 54 | 55 | # Data loading code 56 | valdir = os.path.join(args.data, 'val') 57 | 58 | val_loader = torch.utils.data.DataLoader( 59 | datasets.ImageFolder(valdir, val_preprocess(224)), 60 | batch_size=args.batch_size, shuffle=False, 61 | num_workers=args.workers, pin_memory=True) 62 | 63 | validate(val_loader, model, criterion, use_gpu) 64 | 65 | 66 | def validate(val_loader, model, criterion, use_gpu): 67 | batch_time = AverageMeter('Time', ':6.3f') 68 | losses = AverageMeter('Loss', ':.4e') 69 | top1 = AverageMeter('Acc@1', ':6.2f') 70 | top5 = AverageMeter('Acc@5', ':6.2f') 71 | progress = ProgressMeter( 72 | len(val_loader), 73 | [batch_time, losses, top1, top5], 74 | prefix='Test: ') 75 | 76 | # switch to evaluate mode 77 | model.eval() 78 | 79 | with torch.no_grad(): 80 | end = time.time() 81 | for i, (images, target) in enumerate(val_loader): 82 | if use_gpu: 83 | images = images.cuda(non_blocking=True) 84 | target = target.cuda(non_blocking=True) 85 | 86 | # compute output 87 | output = model(images) 88 | loss = criterion(output, target) 89 | 90 | # measure accuracy and record loss 91 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 92 | losses.update(loss.item(), images.size(0)) 93 | top1.update(acc1[0], images.size(0)) 94 | top5.update(acc5[0], images.size(0)) 95 | 96 | # measure elapsed time 97 | batch_time.update(time.time() - end) 98 | end = time.time() 99 | 100 | if i % 10 == 0: 101 | progress.display(i) 102 | 103 | print(' * Acc@1 
{top1.avg:.3f} Acc@5 {top5.avg:.3f}' 104 | .format(top1=top1, top5=top5)) 105 | 106 | return top1.avg 107 | 108 | 109 | 110 | 111 | if __name__ == '__main__': 112 | test() -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | import shutil 5 | import time 6 | import warnings 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.parallel 11 | import torch.backends.cudnn as cudnn 12 | import torch.distributed as dist 13 | import torch.optim 14 | import torch.multiprocessing as mp 15 | import torch.utils.data 16 | import torch.utils.data.distributed 17 | import torchvision.transforms as transforms 18 | import torchvision.datasets as datasets 19 | from torch.optim.lr_scheduler import CosineAnnealingLR 20 | from utils import AverageMeter, accuracy, ProgressMeter, val_preprocess, strong_train_preprocess, standard_train_preprocess 21 | 22 | IMAGENET_TRAINSET_SIZE = 1281167 23 | 24 | 25 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 26 | parser.add_argument('data', metavar='DIR', 27 | help='path to dataset') 28 | 29 | parser.add_argument('-a', '--arch', metavar='ARCH', default='ResNet-18') 30 | parser.add_argument('-t', '--blocktype', metavar='BLK', default='DBB', choices=['DBB', 'ACB', 'base']) 31 | 32 | parser.add_argument('-j', '--workers', default=8, type=int, metavar='N', 33 | help='number of data loading workers (default: 4)') 34 | parser.add_argument('--epochs', default=120, type=int, metavar='N', 35 | help='number of total epochs to run') 36 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 37 | help='manual epoch number (useful on restarts)') 38 | parser.add_argument('-b', '--batch-size', default=256, type=int, 39 | metavar='N', 40 | help='mini-batch size (default: 256), this is the total ' 41 | 'batch size of all GPUs on the current node when ' 42 | 'using Data Parallel or Distributed Data Parallel') 43 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, 44 | metavar='LR', help='initial learning rate', dest='lr') 45 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 46 | help='momentum') 47 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, 48 | metavar='W', help='weight decay (default: 1e-4)', 49 | dest='weight_decay') 50 | parser.add_argument('-p', '--print-freq', default=10, type=int, 51 | metavar='N', help='print frequency (default: 10)') 52 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 53 | help='path to latest checkpoint (default: none)') 54 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 55 | help='evaluate model on validation set') 56 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 57 | help='use pre-trained model') 58 | parser.add_argument('--world-size', default=-1, type=int, 59 | help='number of nodes for distributed training') 60 | parser.add_argument('--rank', default=-1, type=int, 61 | help='node rank for distributed training') 62 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str, 63 | help='url used to set up distributed training') 64 | parser.add_argument('--dist-backend', default='nccl', type=str, 65 | help='distributed backend') 66 | parser.add_argument('--seed', default=None, type=int, 67 | help='seed for initializing training. 
') 68 | parser.add_argument('--gpu', default=None, type=int, 69 | help='GPU id to use.') 70 | parser.add_argument('--multiprocessing-distributed', action='store_true', 71 | help='Use multi-processing distributed training to launch ' 72 | 'N processes per node, which has N GPUs. This is the ' 73 | 'fastest way to use PyTorch for either single node or ' 74 | 'multi node data parallel training') 75 | 76 | best_acc1 = 0 77 | 78 | 79 | def sgd_optimizer(model, lr, momentum, weight_decay): 80 | params = [] 81 | for key, value in model.named_parameters(): 82 | if not value.requires_grad: 83 | continue 84 | apply_lr = lr 85 | apply_wd = weight_decay 86 | if 'bias' in key: 87 | apply_lr = 2 * lr # Just a Caffe-style common practice. Made no difference. 88 | if 'depth' in key: 89 | apply_wd = 0 90 | print('set weight decay ', key, apply_wd) 91 | params += [{'params': [value], 'lr': apply_lr, 'weight_decay': apply_wd}] 92 | optimizer = torch.optim.SGD(params, lr, momentum=momentum) 93 | return optimizer 94 | 95 | def main(): 96 | args = parser.parse_args() 97 | 98 | if args.seed is not None: 99 | random.seed(args.seed) 100 | torch.manual_seed(args.seed) 101 | cudnn.deterministic = True 102 | warnings.warn('You have chosen to seed training. ' 103 | 'This will turn on the CUDNN deterministic setting, ' 104 | 'which can slow down your training considerably! ' 105 | 'You may see unexpected behavior when restarting ' 106 | 'from checkpoints.') 107 | 108 | if args.gpu is not None: 109 | warnings.warn('You have chosen a specific GPU. This will completely ' 110 | 'disable data parallelism.') 111 | 112 | if args.dist_url == "env://" and args.world_size == -1: 113 | args.world_size = int(os.environ["WORLD_SIZE"]) 114 | 115 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed 116 | 117 | ngpus_per_node = torch.cuda.device_count() 118 | if args.multiprocessing_distributed: 119 | # Since we have ngpus_per_node processes per node, the total world_size 120 | # needs to be adjusted accordingly 121 | args.world_size = ngpus_per_node * args.world_size 122 | # Use torch.multiprocessing.spawn to launch distributed processes: the 123 | # main_worker process function 124 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) 125 | else: 126 | # Simply call main_worker function 127 | main_worker(args.gpu, ngpus_per_node, args) 128 | 129 | 130 | def main_worker(gpu, ngpus_per_node, args): 131 | global best_acc1 132 | args.gpu = gpu 133 | 134 | if args.gpu is not None: 135 | print("Use GPU: {} for training".format(args.gpu)) 136 | 137 | if args.distributed: 138 | if args.dist_url == "env://" and args.rank == -1: 139 | args.rank = int(os.environ["RANK"]) 140 | if args.multiprocessing_distributed: 141 | # For multiprocessing distributed training, rank needs to be the 142 | # global rank among all the processes 143 | args.rank = args.rank * ngpus_per_node + gpu 144 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 145 | world_size=args.world_size, rank=args.rank) 146 | 147 | # =========================== build model 148 | from convnet_utils import switch_deploy_flag, switch_conv_bn_impl, build_model 149 | switch_deploy_flag(False) 150 | switch_conv_bn_impl(args.blocktype) 151 | model = build_model(args.arch) 152 | 153 | if gpu == 0: 154 | for name, param in model.named_parameters(): 155 | print(name, param.size()) 156 | 157 | if not torch.cuda.is_available(): 158 | print('using CPU, this will be slow') 159 | elif args.distributed: 160 | # For 
multiprocessing distributed, DistributedDataParallel constructor 161 | # should always set the single device scope, otherwise, 162 | # DistributedDataParallel will use all available devices. 163 | if args.gpu is not None: 164 | torch.cuda.set_device(args.gpu) 165 | model.cuda(args.gpu) 166 | # When using a single GPU per process and per 167 | # DistributedDataParallel, we need to divide the batch size 168 | # ourselves based on the total number of GPUs we have 169 | args.batch_size = int(args.batch_size / ngpus_per_node) 170 | args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node) 171 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], broadcast_buffers=False) 172 | else: 173 | model.cuda() 174 | # DistributedDataParallel will divide and allocate batch_size to all 175 | # available GPUs if device_ids are not set 176 | model = torch.nn.parallel.DistributedDataParallel(model, broadcast_buffers=False) 177 | elif args.gpu is not None: 178 | torch.cuda.set_device(args.gpu) 179 | model = model.cuda(args.gpu) 180 | else: 181 | # DataParallel will divide and allocate batch_size to all available GPUs 182 | model = torch.nn.DataParallel(model).cuda() 183 | 184 | 185 | # define loss function (criterion) and optimizer 186 | criterion = nn.CrossEntropyLoss().cuda(args.gpu) 187 | 188 | optimizer = sgd_optimizer(model, args.lr, args.momentum, args.weight_decay) 189 | 190 | lr_scheduler = CosineAnnealingLR(optimizer=optimizer, T_max=args.epochs * IMAGENET_TRAINSET_SIZE // args.batch_size // ngpus_per_node) 191 | 192 | # optionally resume from a checkpoint 193 | if args.resume: 194 | if os.path.isfile(args.resume): 195 | print("=> loading checkpoint '{}'".format(args.resume)) 196 | if args.gpu is None: 197 | checkpoint = torch.load(args.resume) 198 | else: 199 | # Map model to be loaded to specified single gpu. 
200 | loc = 'cuda:{}'.format(args.gpu) 201 | checkpoint = torch.load(args.resume, map_location=loc) 202 | args.start_epoch = checkpoint['epoch'] 203 | best_acc1 = checkpoint['best_acc1'] 204 | if args.gpu is not None: 205 | # best_acc1 may be from a checkpoint from a different GPU 206 | best_acc1 = best_acc1.to(args.gpu) 207 | model.load_state_dict(checkpoint['state_dict']) 208 | optimizer.load_state_dict(checkpoint['optimizer']) 209 | lr_scheduler.load_state_dict(checkpoint['scheduler']) 210 | print("=> loaded checkpoint '{}' (epoch {})" 211 | .format(args.resume, checkpoint['epoch'])) 212 | else: 213 | print("=> no checkpoint found at '{}'".format(args.resume)) 214 | 215 | cudnn.benchmark = True 216 | 217 | # Data loading code 218 | traindir = os.path.join(args.data, 'train') 219 | valdir = os.path.join(args.data, 'val') 220 | 221 | trans = strong_train_preprocess(224) if 'ResNet' in args.arch else standard_train_preprocess(224) 222 | print('aug is ', trans) 223 | train_dataset = datasets.ImageFolder(traindir, trans) 224 | 225 | if args.distributed: 226 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=True) 227 | else: 228 | train_sampler = None 229 | 230 | train_loader = torch.utils.data.DataLoader( 231 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), 232 | num_workers=args.workers, pin_memory=True, sampler=train_sampler, drop_last=True) 233 | 234 | 235 | val_dataset = datasets.ImageFolder(valdir, val_preprocess(224)) 236 | val_loader = torch.utils.data.DataLoader( 237 | val_dataset, 238 | batch_size=args.batch_size, shuffle=False, 239 | num_workers=args.workers, pin_memory=True) 240 | 241 | 242 | if args.evaluate: 243 | validate(val_loader, model, criterion, args) 244 | return 245 | 246 | for epoch in range(args.start_epoch, args.epochs): 247 | if args.distributed: 248 | train_sampler.set_epoch(epoch) 249 | print('set sampler') 250 | # adjust_learning_rate(optimizer, epoch, args) 251 | 252 | # train for one epoch 253 | train(train_loader, model, criterion, optimizer, epoch, args, lr_scheduler) 254 | 255 | # evaluate on validation set 256 | acc1 = validate(val_loader, model, criterion, args) 257 | 258 | # remember best acc@1 and save checkpoint 259 | is_best = acc1 > best_acc1 260 | best_acc1 = max(acc1, best_acc1) 261 | 262 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed 263 | and args.rank % ngpus_per_node == 0): 264 | save_checkpoint({ 265 | 'epoch': epoch + 1, 266 | 'arch': args.arch, 267 | 'state_dict': model.state_dict(), 268 | 'best_acc1': best_acc1, 269 | 'optimizer' : optimizer.state_dict(), 270 | 'scheduler': lr_scheduler.state_dict(), 271 | }, is_best, filename='{}_{}.pth.tar'.format(args.arch, args.blocktype)) 272 | 273 | 274 | def train(train_loader, model, criterion, optimizer, epoch, args, lr_scheduler): 275 | batch_time = AverageMeter('Time', ':6.3f') 276 | data_time = AverageMeter('Data', ':6.3f') 277 | losses = AverageMeter('Loss', ':.4e') 278 | top1 = AverageMeter('Acc@1', ':6.2f') 279 | top5 = AverageMeter('Acc@5', ':6.2f') 280 | progress = ProgressMeter( 281 | len(train_loader), 282 | [batch_time, data_time, losses, top1, top5, ], 283 | prefix="Epoch: [{}]".format(epoch)) 284 | 285 | # switch to train mode 286 | model.train() 287 | 288 | end = time.time() 289 | for i, (images, target) in enumerate(train_loader): 290 | # measure data loading time 291 | data_time.update(time.time() - end) 292 | 293 | if args.gpu is not None: 294 | images = images.cuda(args.gpu, 
non_blocking=True) 295 | if torch.cuda.is_available(): 296 | target = target.cuda(args.gpu, non_blocking=True) 297 | 298 | # compute output 299 | output = model(images) 300 | loss = criterion(output, target) 301 | 302 | # measure accuracy and record loss 303 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 304 | losses.update(loss.item(), images.size(0)) 305 | top1.update(acc1[0], images.size(0)) 306 | top5.update(acc5[0], images.size(0)) 307 | 308 | # compute gradient and do SGD step 309 | optimizer.zero_grad() 310 | loss.backward() 311 | optimizer.step() 312 | 313 | # measure elapsed time 314 | batch_time.update(time.time() - end) 315 | end = time.time() 316 | 317 | lr_scheduler.step() 318 | 319 | if i % args.print_freq == 0 and args.gpu == 0: 320 | progress.display(i) 321 | if i % 1000 == 0 and args.gpu == 0: 322 | print('cur lr: ', lr_scheduler.get_lr()[0]) 323 | 324 | 325 | 326 | 327 | def validate(val_loader, model, criterion, args): 328 | batch_time = AverageMeter('Time', ':6.3f') 329 | losses = AverageMeter('Loss', ':.4e') 330 | top1 = AverageMeter('Acc@1', ':6.2f') 331 | top5 = AverageMeter('Acc@5', ':6.2f') 332 | progress = ProgressMeter( 333 | len(val_loader), 334 | [batch_time, losses, top1, top5], 335 | prefix='Test: ') 336 | 337 | # switch to evaluate mode 338 | model.eval() 339 | 340 | with torch.no_grad(): 341 | end = time.time() 342 | for i, (images, target) in enumerate(val_loader): 343 | if args.gpu is not None: 344 | images = images.cuda(args.gpu, non_blocking=True) 345 | if torch.cuda.is_available(): 346 | target = target.cuda(args.gpu, non_blocking=True) 347 | 348 | # compute output 349 | output = model(images) 350 | loss = criterion(output, target) 351 | 352 | # measure accuracy and record loss 353 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 354 | losses.update(loss.item(), images.size(0)) 355 | top1.update(acc1[0], images.size(0)) 356 | top5.update(acc5[0], images.size(0)) 357 | 358 | # measure elapsed time 359 | batch_time.update(time.time() - end) 360 | end = time.time() 361 | 362 | if i % args.print_freq == 0: 363 | progress.display(i) 364 | 365 | 366 | # TODO: this should also be done with the ProgressMeter 367 | print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}' 368 | .format(top1=top1, top5=top5)) 369 | 370 | return top1.avg 371 | 372 | 373 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): 374 | torch.save(state, filename) 375 | if is_best: 376 | shutil.copyfile(filename, filename.replace('.pth.tar', '_best.pth.tar')) 377 | 378 | 379 | 380 | 381 | if __name__ == '__main__': 382 | main() 383 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.transforms as transforms 3 | 4 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 5 | std=[0.229, 0.224, 0.225]) 6 | 7 | class PCALighting(object): 8 | """Lighting noise(AlexNet - style PCA - based noise)""" 9 | def __init__(self, alphastd, eigval, eigvec): 10 | self.alphastd = alphastd 11 | self.eigval = eigval 12 | self.eigvec = eigvec 13 | 14 | def __call__(self, img): 15 | if self.alphastd == 0: 16 | return img 17 | alpha = img.new().resize_(3).normal_(0, self.alphastd) 18 | rgb = self.eigvec.type_as(img).clone()\ 19 | .mul(alpha.view(1, 3).expand(3, 3))\ 20 | .mul(self.eigval.view(1, 3).expand(3, 3))\ 21 | .sum(1).squeeze() 22 | return img.add(rgb.view(3, 1, 1).expand_as(img)) 23 | 24 | 25 | imagenet_pca = { 26 
| 'eigval': torch.Tensor([0.2175, 0.0188, 0.0045]), 27 | 'eigvec': torch.Tensor([ 28 | [-0.5675, 0.7192, 0.4009], 29 | [-0.5808, -0.0045, -0.8140], 30 | [-0.5836, -0.6948, 0.4203], 31 | ]) 32 | } 33 | 34 | def strong_train_preprocess(img_size): 35 | trans = transforms.Compose([ 36 | transforms.RandomResizedCrop(img_size), 37 | transforms.RandomHorizontalFlip(), 38 | transforms.ColorJitter(brightness=0.4, saturation=0.4, hue=0.4), 39 | transforms.ToTensor(), 40 | PCALighting(0.1, imagenet_pca['eigval'], imagenet_pca['eigvec']), 41 | normalize, 42 | ]) 43 | print('---------------------- strong dataaug!') 44 | return trans 45 | 46 | def standard_train_preprocess(img_size): 47 | trans = transforms.Compose([ 48 | transforms.RandomResizedCrop(img_size), 49 | transforms.RandomHorizontalFlip(), 50 | transforms.ToTensor(), 51 | normalize, 52 | ]) 53 | print('---------------------- weak dataaug!') 54 | return trans 55 | 56 | def val_preprocess(img_size): 57 | trans = transforms.Compose([ 58 | transforms.Resize(256), 59 | transforms.CenterCrop(img_size), 60 | transforms.ToTensor(), 61 | normalize, 62 | ]) 63 | return trans 64 | 65 | class AverageMeter(object): 66 | """Computes and stores the average and current value""" 67 | def __init__(self, name, fmt=':f'): 68 | self.name = name 69 | self.fmt = fmt 70 | self.reset() 71 | 72 | def reset(self): 73 | self.val = 0 74 | self.avg = 0 75 | self.sum = 0 76 | self.count = 0 77 | 78 | def update(self, val, n=1): 79 | self.val = val 80 | self.sum += val * n 81 | self.count += n 82 | self.avg = self.sum / self.count 83 | 84 | def __str__(self): 85 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 86 | return fmtstr.format(**self.__dict__) 87 | 88 | 89 | class ProgressMeter(object): 90 | def __init__(self, num_batches, meters, prefix=""): 91 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 92 | self.meters = meters 93 | self.prefix = prefix 94 | 95 | def display(self, batch): 96 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 97 | entries += [str(meter) for meter in self.meters] 98 | print('\t'.join(entries)) 99 | 100 | def _get_batch_fmtstr(self, num_batches): 101 | num_digits = len(str(num_batches // 1)) 102 | fmt = '{:' + str(num_digits) + 'd}' 103 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 104 | 105 | 106 | def accuracy(output, target, topk=(1,)): 107 | """Computes the accuracy over the k top predictions for the specified values of k""" 108 | with torch.no_grad(): 109 | maxk = max(topk) 110 | batch_size = target.size(0) 111 | 112 | _, pred = output.topk(maxk, 1, True, True) 113 | pred = pred.t() 114 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 115 | 116 | res = [] 117 | for k in topk: 118 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 119 | res.append(correct_k.mul_(100.0 / batch_size)) 120 | return res 121 | 122 | 123 | def read_hdf5(file_path): 124 | import h5py 125 | import numpy as np 126 | result = {} 127 | with h5py.File(file_path, 'r') as f: 128 | for k in f.keys(): 129 | value = np.asarray(f[k]) 130 | result[str(k).replace('+', '/')] = value 131 | print('read {} arrays from {}'.format(len(result), file_path)) 132 | f.close() 133 | return result 134 | 135 | def model_load_hdf5(model:torch.nn.Module, hdf5_path, ignore_keys='stage0.'): 136 | weights_dict = read_hdf5(hdf5_path) 137 | for name, param in model.named_parameters(): 138 | print('load param: ', name, param.size()) 139 | if name in weights_dict: 140 | np_value = weights_dict[name] 141 | else: 142 | np_value = 
weights_dict[name.replace(ignore_keys, '')] 143 | value = torch.from_numpy(np_value).float() 144 | assert tuple(value.size()) == tuple(param.size()) 145 | param.data = value 146 | for name, param in model.named_buffers(): 147 | print('load buffer: ', name, param.size()) 148 | if name in weights_dict: 149 | np_value = weights_dict[name] 150 | else: 151 | np_value = weights_dict[name.replace(ignore_keys, '')] 152 | value = torch.from_numpy(np_value).float() 153 | assert tuple(value.size()) == tuple(param.size()) 154 | param.data = value --------------------------------------------------------------------------------
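Note on the re-parameterization in diversebranchblock.py above: get_equivalent_kernel_bias composes transform helpers (transI_fusebn, transII_addbranch, transIII_1x1_kxk, transV_avg, transVI_multiscale) that are imported from elsewhere in the repository and are not reproduced in this section. For orientation, the conv-BN folding that transI_fusebn is assumed to perform is the standard one; the sketch below is a hedged reconstruction of that step, not the repository's own implementation, and the helper name fuse_conv_bn_sketch is invented for illustration:

import torch

def fuse_conv_bn_sketch(kernel, bn):
    # Fold a BatchNorm2d (bn) into the preceding conv kernel: scale each output
    # channel by gamma / sqrt(running_var + eps) and return the matching bias
    # beta - running_mean * gamma / sqrt(running_var + eps).
    std = (bn.running_var + bn.eps).sqrt()
    t = (bn.weight / std).reshape(-1, 1, 1, 1)
    return kernel * t, bn.bias - bn.running_mean * bn.weight / std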
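To see how the pieces shown above fit together end to end, the sketch below builds a DBB model the same way test.py does (via switch_deploy_flag, switch_conv_bn_impl, and build_model from convnet_utils), converts every block in place with switch_to_deploy(), and checks that the fused single-conv form matches the multi-branch form on random input. This is a minimal sketch, assuming it is run from the repository root with PyTorch installed; 'ResNet-18' is simply the default architecture string used by test.py.

import torch
from convnet_utils import switch_deploy_flag, switch_conv_bn_impl, build_model

switch_deploy_flag(False)                 # build the multi-branch training-time structure
switch_conv_bn_impl('DBB')                # use DiverseBranchBlock for conv-BN pairs
model = build_model('ResNet-18').eval()   # eval() so BN uses running statistics

x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    y_train_form = model(x)
    for m in model.modules():             # fuse each block into its dbb_reparam conv
        if hasattr(m, 'switch_to_deploy'):
            m.switch_to_deploy()
    y_deploy_form = model(x)

print('max abs diff:', (y_train_form - y_deploy_form).abs().max().item())  # expected to be tiny (~1e-5 or less)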