├── LICENSE ├── README.md ├── basic.py ├── blocks.py ├── blocks_repvgg.py ├── convert.py ├── convnet_utils.py ├── dbb_transforms.py ├── images ├── budgets.png ├── coco_cs.PNG ├── imagenet1.PNG ├── imagenet2.PNG ├── intro.png ├── norm.PNG ├── overview.png ├── similarity.png └── supp_grad.png ├── models ├── repvgg.py ├── resnet.py └── resnext.py ├── test.py ├── train.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # OREPA: Online Convolutional Re-parameterization 2 | This repo is the PyTorch implementation of our paper to appear in CVPR 2022, ["Online Convolutional Re-parameterization"](https://arxiv.org/abs/2204.00826), authored by 3 | Mu Hu, [Junyi Feng](https://github.com/Sixkplus), [Jiashen Hua](https://github.com/JerExJs), Baisheng Lai, Jianqiang Huang, [Xiaojin Gong](https://person.zju.edu.cn/en/gongxj) and [Xiansheng Hua](https://damo.alibaba.com/labs/city-brain) from Zhejiang University and Alibaba Cloud. 4 | 5 | ## What is Structural Re-parameterization? 6 | + Re-parameterization (Re-param) means different architectures can be mutually converted through equivalent transformation of parameters. For example, a 1x1 convolution branch and a 3x3 convolution branch can be merged into a single 3x3 convolution branch for faster inference (see the sketch below). 7 | + When the model for deployment is fixed, the task of re-param can be regarded as finding a complex training-time structure that can be transformed back to the original one, for free performance improvements. 8 | 9 |
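Below is a minimal, self-contained sketch (not taken from this repo) of the 1x1-plus-3x3 merge mentioned above; all tensor names and shapes are made up for illustration, while the repo's own equivalent transformations (e.g. `transVI_multiscale`) live in `dbb_transforms.py`:

```python
import torch
import torch.nn.functional as F

# Two parallel bias-free branches: a 3x3 conv and a 1x1 conv on the same input.
x = torch.randn(1, 8, 32, 32)
k3 = torch.randn(16, 8, 3, 3)
k1 = torch.randn(16, 8, 1, 1)
y_two_branches = F.conv2d(x, k3, padding=1) + F.conv2d(x, k1)

# Re-param: zero-pad the 1x1 kernel to 3x3 and add it to the 3x3 kernel,
# yielding a single equivalent 3x3 convolution for inference.
k_merged = k3 + F.pad(k1, [1, 1, 1, 1])
y_single_branch = F.conv2d(x, k_merged, padding=1)

print(torch.allclose(y_two_branches, y_single_branch, atol=1e-5))  # True
```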
10 | 11 |
12 | 13 | ## Why do we propose Online RE-PAram? (OREPA) 14 | + While current re-param blocks ([ACNet](https://github.com/DingXiaoH/ACNet), [ExpandNet](https://github.com/GUOShuxuan/expandnets), [ACNetv2](https://github.com/DingXiaoH/DiverseBranchBlock), *etc.*) are still feasible for small models, more complicated designs for further performance gains on larger models can lead to unaffordable training budgets. 15 | + We observe that batch normalization (norm) layers are essential in re-param blocks, yet their training-time non-linearity prevents us from reducing the computational cost during training. 16 | 17 | 18 | 19 | ## What is OREPA? 20 | OREPA is a two-step pipeline. 21 | + Linearization: Replace the branch-wise norm layers with scaling layers to enable the linear squeezing of a multi-branch/layer topology. 22 | + Squeezing: Squeeze the linearized block into a single layer, so that the convolution over the feature maps is performed once instead of multiple times. 23 | 24 | ![Overview](https://github.com/JUGGHM/OREPA_CVPR2022/blob/main/images/overview.png) 25 | 26 | ## How does OREPA work? 27 | + Through OREPA we reduce the training budget while keeping comparable performance. We then improve accuracy with additional components, which bring only minor extra training costs since they are merged in an online scheme. 28 | + We theoretically show that removing the branch-wise norm layers risks degrading a multi-branch structure into a single-branch one, indicating that the norm-to-scaling layer replacement is critical for protecting branch diversity. 29 | 30 | 31 | 32 | ## ImageNet Results 33 |
34 | 35 |
36 | 37 | ![ImageNet2](https://github.com/JUGGHM/OREPA_CVPR2022/blob/main/images/imagenet2.PNG) 38 | 39 | Create a new issue for any code-related questions, and feel free to contact me at muhu@zju.edu.cn for any paper-related questions. 40 | 41 | ## Contents 42 | 1. [Dependency](#dependency) 43 | 2. [Checkpoints](#checkpoints) 44 | 3. [Training](#training) 45 | 4. [Evaluation](#evaluation) 46 | 5. [Transfer Learning on COCO and Cityscapes](#transfer-learning-on-coco-and-cityscapes) 47 | 6. [About Quantization and Gradient Tweaking](#about-quantization-and-gradient-tweaking) 48 | 7. [Citation](#citation) 49 | 50 | 51 | ## Dependency 52 | Models released in this work are trained and tested on: 53 | + CentOS Linux 54 | + Python 3.8.8 (Anaconda 4.9.1) 55 | + PyTorch 1.9.0 / torchvision 0.10.0 56 | + NVIDIA CUDA 10.2 57 | + 4x NVIDIA V100 GPUs 58 | 59 | ```bash 60 | pip install torch torchvision 61 | pip install numpy matplotlib Pillow 62 | pip install scikit-image 63 | ``` 64 | 65 | ## Checkpoints 66 | Download our pre-trained models with OREPA: 67 | - [ResNet-18](https://drive.google.com/file/d/1Z0cfmxLLWD2xjgUXpE2Y_4vxS4ij16Sx/view?usp=sharing) 68 | - [ResNet-34](https://drive.google.com/file/d/1tOC4yoslHF829Yb66_eE4qsFoazf6uoY/view?usp=sharing) 69 | - [ResNet-50](https://drive.google.com/file/d/1Jn92wGHlFzkPjQxAmYdRU_EbSRQjsw0v/view?usp=sharing) 70 | - [ResNet-101](https://drive.google.com/file/d/10L4fwYlB21vlMKOGIyXdlGBbf3l8nflI/view?usp=sharing) 71 | - [RepVGG-A0](https://drive.google.com/file/d/1r674bxWKL5dwDA8OPEuGWN8zbiUGkqsx/view?usp=sharing) 72 | - [RepVGG-A1](https://drive.google.com/file/d/1NVRn8Xave-0jY3R0xAXCShdBcML6f-3q/view?usp=sharing) 73 | - [RepVGG-A2](https://drive.google.com/file/d/1ImHkTct0ACDtOw8sPgKhqerBFigg32s3/view?usp=sharing) 74 | - [WideResNet-18(x2)](https://drive.google.com/file/d/1gseMlq9JGntggyZoK_Bjt7xiSPNs5deF/view?usp=sharing) 75 | - [ResNeXt-50](https://drive.google.com/file/d/1p_eZDhMHQ_xmfXd8Z1RZnWfgwUbi7WH-/view?usp=sharing) 76 | 77 | Note that the pre-trained models do not need to be decompressed; just load the .pth.tar files directly. 78 | 79 | ## Training 80 | A complete list of training options is available with 81 | ```bash 82 | python train.py -h 83 | python test.py -h 84 | python convert.py -h 85 | ``` 86 | 87 | 1. Train ResNets (ResNeXt and WideResNet included) 88 | ```bash 89 | CUDA_VISIBLE_DEVICES="0,1,2,3" python train.py -a ResNet-18 -t OREPA --data [imagenet-path] 90 | # -a for architecture (ResNet-18, ResNet-34, ResNet-50, ResNet-101, ResNet-18-2x, ResNeXt-50) 91 | # -t for re-param method (base, DBB, OREPA) 92 | ``` 93 | 94 | 2. Train RepVGGs 95 | ```bash 96 | CUDA_VISIBLE_DEVICES="0,1,2,3" python train.py -a RepVGG-A0 -t OREPA_VGG --data [imagenet-path] 97 | # -a for architecture (RepVGG-A0, RepVGG-A1, RepVGG-A2) 98 | # -t for re-param method (base, RepVGG, OREPA_VGG) 99 | ``` 100 | 101 | ## Evaluation 102 | 1. Use your self-trained model or our pretrained model 103 | ```bash 104 | CUDA_VISIBLE_DEVICES="0" python test.py train [trained-model-path] -a ResNet-18 -t OREPA 105 | ``` 106 | 107 | 2. Convert the training-time models into inference-time models 108 | ```bash 109 | CUDA_VISIBLE_DEVICES="0" python convert.py [trained-model-path] [deploy-model-path-to-save] -a ResNet-18 -t OREPA 110 | ``` 111 | 112 | 3. 
Evaluate with the converted model 113 | ```bash 114 | CUDA_VISIBLE_DEVICES="0" python test.py deploy [deploy-model-path] -a ResNet-18 -t OREPA 115 | ``` 116 | 117 | ## Transfer Learning on COCO and Cityscapes 118 | We use the [mmdetection](https://github.com/open-mmlab/mmdetection) and [mmsegmentation](https://github.com/open-mmlab/mmsegmentation) toolboxes on COCO and Cityscapes, respectively. If you use our pretrained models for downstream tasks, we strongly suggest carefully tuning the learning rate of the first stem layer (see the sketch below), since the deep linear stem layer has a very different weight distribution from the vanilla one after ImageNet training. Contact [@Sixkplus](https://github.com/Sixkplus) (Junyi Feng) for more details on configurations and checkpoints of the reported ResNet-50-backbone models. 119 | 120 |
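Below is a minimal sketch (not part of this repo) of one way to give the stem its own, smaller learning rate using plain PyTorch parameter groups; the name filter (`"stem"` / `"conv1"`) is an assumption and depends on how the backbone names its modules:

```python
import torch

def build_optimizer(model, base_lr=0.02, stem_lr_mult=0.1):
    """Put the stem parameters into their own group with a reduced learning rate."""
    stem_params, other_params = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        # Assumption: stem parameters can be identified by their names.
        if "stem" in name or "conv1" in name:
            stem_params.append(param)
        else:
            other_params.append(param)
    return torch.optim.SGD(
        [
            {"params": other_params, "lr": base_lr},
            {"params": stem_params, "lr": base_lr * stem_lr_mult},
        ],
        momentum=0.9,
        weight_decay=1e-4,
    )
```

mmdetection and mmsegmentation expose a similar mechanism through the optimizer's `paramwise_cfg`; the exact multiplier is a per-task tuning choice.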
121 | 122 |
123 | 124 | ## About Quantization and Gradient Tweaking 125 | For re-param models, special weight regularization strategies are required for further quantization. Meanwhile, dynamic gradient tweaking or differentiable searching methods might further boost performance. We have not applied such techniques to OREPA yet, but they will probably be adopted in our industrial applications in the future. To exchange experience on such topics, please contact [@Sixkplus](https://github.com/Sixkplus) (Junyi Feng). 126 | 127 | 128 | ## Citation 129 | If you use our code or method in your work, please cite the following: 130 | 131 | @inproceedings{hu22OREPA, 132 | title={Online Convolutional Re-parameterization}, 133 | author={Mu Hu and Junyi Feng and Jiashen Hua and Baisheng Lai and Jianqiang Huang and Xiansheng Hua and Xiaojin Gong}, 134 | booktitle={CVPR}, 135 | year={2022} 136 | } 137 | 138 | ## Related Repositories 139 | The code of this work is developed upon Xiaohan Ding's re-param repositories ["Diverse Branch Block: Building a Convolution as an Inception-like Unit"](https://github.com/DingXiaoH/DiverseBranchBlock) and ["RepVGG: Making VGG-style ConvNets Great Again"](https://github.com/DingXiaoH/RepVGG) with similar protocols. [Xiaohan Ding](https://scholar.google.com/citations?user=CIjw0KoAAAAJ&hl=en) is a Ph.D. from Tsinghua University and an expert in structural re-parameterization. 140 | -------------------------------------------------------------------------------- /basic.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import torch.nn.init as init 6 | import math 7 | from dbb_transforms import transI_fusebn 8 | 9 | class ConvBN(nn.Module): 10 | def __init__(self, in_channels, out_channels, kernel_size, 11 | stride=1, padding=0, dilation=1, groups=1, deploy=False, nonlinear=None): 12 | super().__init__() 13 | if nonlinear is None: 14 | self.nonlinear = nn.Identity() 15 | else: 16 | self.nonlinear = nonlinear 17 | if deploy: 18 | self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, 19 | stride=stride, padding=padding, dilation=dilation, groups=groups, bias=True) 20 | else: 21 | self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, 22 | stride=stride, padding=padding, dilation=dilation, groups=groups, bias=False) 23 | self.bn = nn.BatchNorm2d(num_features=out_channels) 24 | 25 | def forward(self, x): 26 | if hasattr(self, 'bn'): 27 | return self.nonlinear(self.bn(self.conv(x))) 28 | else: 29 | return self.nonlinear(self.conv(x)) 30 | 31 | def switch_to_deploy(self): 32 | kernel, bias = transI_fusebn(self.conv.weight, self.bn) 33 | conv = nn.Conv2d(in_channels=self.conv.in_channels, out_channels=self.conv.out_channels, kernel_size=self.conv.kernel_size, 34 | stride=self.conv.stride, padding=self.conv.padding, dilation=self.conv.dilation, groups=self.conv.groups, bias=True) 35 | conv.weight.data = kernel 36 | conv.bias.data = bias 37 | for para in self.parameters(): 38 | para.detach_() 39 | self.__delattr__('conv') 40 | self.__delattr__('bn') 41 | self.conv = conv 42 | 43 | class IdentityBasedConv1x1(nn.Conv2d): 44 | 45 | def __init__(self, channels, groups=1): 46 | super(IdentityBasedConv1x1, self).__init__(in_channels=channels, out_channels=channels, kernel_size=1, stride=1, padding=0, groups=groups, bias=False) 47 | 48 | assert
channels % groups == 0 49 | input_dim = channels // groups 50 | id_value = np.zeros((channels, input_dim, 1, 1)) 51 | for i in range(channels): 52 | id_value[i, i % input_dim, 0, 0] = 1 53 | self.id_tensor = torch.from_numpy(id_value).type_as(self.weight) 54 | nn.init.zeros_(self.weight) 55 | 56 | def forward(self, input): 57 | kernel = self.weight + self.id_tensor.to(self.weight.device) 58 | result = F.conv2d(input, kernel, None, stride=1, padding=0, dilation=self.dilation, groups=self.groups) 59 | return result 60 | 61 | def get_actual_kernel(self): 62 | return self.weight + self.id_tensor.to(self.weight.device) 63 | 64 | 65 | class BNAndPadLayer(nn.Module): 66 | def __init__(self, 67 | pad_pixels, 68 | num_features, 69 | eps=1e-5, 70 | momentum=0.1, 71 | affine=True, 72 | track_running_stats=True): 73 | super(BNAndPadLayer, self).__init__() 74 | self.bn = nn.BatchNorm2d(num_features, eps, momentum, affine, track_running_stats) 75 | self.pad_pixels = pad_pixels 76 | 77 | def forward(self, input): 78 | output = self.bn(input) 79 | if self.pad_pixels > 0: 80 | if self.bn.affine: 81 | pad_values = self.bn.bias.detach() - self.bn.running_mean * self.bn.weight.detach() / torch.sqrt(self.bn.running_var + self.bn.eps) 82 | else: 83 | pad_values = - self.bn.running_mean / torch.sqrt(self.bn.running_var + self.bn.eps) 84 | output = F.pad(output, [self.pad_pixels] * 4) 85 | pad_values = pad_values.view(1, -1, 1, 1) 86 | output[:, :, 0:self.pad_pixels, :] = pad_values 87 | output[:, :, -self.pad_pixels:, :] = pad_values 88 | output[:, :, :, 0:self.pad_pixels] = pad_values 89 | output[:, :, :, -self.pad_pixels:] = pad_values 90 | return output 91 | 92 | @property 93 | def weight(self): 94 | return self.bn.weight 95 | 96 | @property 97 | def bias(self): 98 | return self.bn.bias 99 | 100 | @property 101 | def running_mean(self): 102 | return self.bn.running_mean 103 | 104 | @property 105 | def running_var(self): 106 | return self.bn.running_var 107 | 108 | @property 109 | def eps(self): 110 | return self.bn.eps 111 | 112 | class PriorFilter(nn.Module): 113 | 114 | def __init__(self, channels, stride, padding, width=4): 115 | super(PriorFilter, self).__init__() 116 | self.stride = stride 117 | self.width = width 118 | self.group = int(channels/width) 119 | self.padding = padding 120 | self.prior_tensor = torch.Tensor(self.group, 1, 3, 3) 121 | 122 | half_g = int(self.group/2) 123 | for i in range(self.group): 124 | for h in range(3): 125 | for w in range(3): 126 | if i < half_g: 127 | self.prior_tensor[i, 0, h, w] = math.cos(math.pi*(h+0.5)*(i+1)/3) 128 | else: 129 | self.prior_tensor[i, 0, h, w] = math.cos(math.pi*(w+0.5)*(i+1-half_g)/3) 130 | 131 | self.register_buffer('prior', self.prior_tensor) 132 | 133 | def forward(self, x): 134 | b, c, h, w = x.size() 135 | x = x.view(b*self.width, self.group, h, w) 136 | return F.conv2d(input=x, weight=self.prior, bias=None, stride=self.stride, padding=self.padding, groups=self.group).view(b, c, int(h/self.stride), int(w/self.stride)) 137 | 138 | class PriorConv(nn.Module): 139 | 140 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding, width=4): 141 | super(PriorConv, self).__init__() 142 | self.stride = stride 143 | self.width = width 144 | self.group = int(in_channels/width) 145 | self.padding = padding 146 | self.prior_tensor = torch.Tensor(self.group, 3, 3) 147 | 148 | half_g = int(self.group/2) 149 | for i in range(self.group): 150 | for h in range(3): 151 | for w in range(3): 152 | if i < half_g: 153 | self.prior_tensor[i, h, w] 
= math.cos(math.pi*(h+0.5)*(i+1)/3) 154 | else: 155 | self.prior_tensor[i, h, w] = math.cos(math.pi*(w+0.5)*(i+1-half_g)/3) 156 | 157 | self.register_buffer('prior', self.prior_tensor) 158 | 159 | self.weight = torch.nn.Parameter(torch.FloatTensor(out_channels, width, self.group, kernel_size, kernel_size)) 160 | init.kaiming_uniform_(self.weight, a=math.sqrt(5)) 161 | 162 | def forward(self, x): 163 | c_out, wi, g, kh, kw = self.weight.size() 164 | weight = torch.einsum('gmn,cwgmn->cwgmn', self.prior, self.weight).view(c_out, int(wi*g), kh, kw) 165 | return F.conv2d(input=x, weight=weight, bias=None, stride=self.stride, padding=self.padding) -------------------------------------------------------------------------------- /blocks.py: -------------------------------------------------------------------------------- 1 | 2 | from basic import * 3 | import numpy as np 4 | from dbb_transforms import transI_fusebn, transII_addbranch, transIII_1x1_kxk, transV_avg, transVI_multiscale 5 | 6 | class DBB(nn.Module): 7 | 8 | def __init__(self, in_channels, out_channels, kernel_size, 9 | stride=1, padding=0, dilation=1, groups=1, 10 | internal_channels_1x1_3x3=None, 11 | deploy=False, nonlinear=None, single_init=False): 12 | super(DBB, self).__init__() 13 | self.deploy = deploy 14 | 15 | if nonlinear is None: 16 | self.nonlinear = nn.Identity() 17 | else: 18 | self.nonlinear = nonlinear 19 | 20 | self.kernel_size = kernel_size 21 | self.out_channels = out_channels 22 | self.groups = groups 23 | assert padding == kernel_size // 2 24 | 25 | if deploy: 26 | self.dbb_reparam = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, 27 | padding=padding, dilation=dilation, groups=groups, bias=True) 28 | 29 | else: 30 | 31 | self.dbb_origin = ConvBN(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups) 32 | 33 | self.dbb_avg = nn.Sequential() 34 | if groups < out_channels: 35 | self.dbb_avg.add_module('conv', 36 | nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, 37 | stride=1, padding=0, groups=groups, bias=False)) 38 | self.dbb_avg.add_module('bn', BNAndPadLayer(pad_pixels=padding, num_features=out_channels)) 39 | self.dbb_avg.add_module('avg', nn.AvgPool2d(kernel_size=kernel_size, stride=stride, padding=0)) 40 | self.dbb_1x1 = ConvBN(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=stride, 41 | padding=0, groups=groups) 42 | else: 43 | self.dbb_avg.add_module('avg', nn.AvgPool2d(kernel_size=kernel_size, stride=stride, padding=padding)) 44 | 45 | self.dbb_avg.add_module('avgbn', nn.BatchNorm2d(out_channels)) 46 | 47 | 48 | if internal_channels_1x1_3x3 is None: 49 | internal_channels_1x1_3x3 = in_channels if groups < out_channels else 2 * in_channels # For mobilenet, it is better to have 2X internal channels 50 | 51 | self.dbb_1x1_kxk = nn.Sequential() 52 | if internal_channels_1x1_3x3 == in_channels: 53 | self.dbb_1x1_kxk.add_module('idconv1', IdentityBasedConv1x1(channels=in_channels, groups=groups)) 54 | else: 55 | self.dbb_1x1_kxk.add_module('conv1', nn.Conv2d(in_channels=in_channels, out_channels=internal_channels_1x1_3x3, 56 | kernel_size=1, stride=1, padding=0, groups=groups, bias=False)) 57 | self.dbb_1x1_kxk.add_module('bn1', BNAndPadLayer(pad_pixels=padding, num_features=internal_channels_1x1_3x3, affine=True)) 58 | self.dbb_1x1_kxk.add_module('conv2', nn.Conv2d(in_channels=internal_channels_1x1_3x3, 
out_channels=out_channels, 59 | kernel_size=kernel_size, stride=stride, padding=0, groups=groups, bias=False)) 60 | self.dbb_1x1_kxk.add_module('bn2', nn.BatchNorm2d(out_channels)) 61 | 62 | # The experiments reported in the paper used the default initialization of bn.weight (all as 1). But changing the initialization may be useful in some cases. 63 | if single_init: 64 | # Initialize the bn.weight of dbb_origin as 1 and others as 0. This is not the default setting. 65 | self.single_init() 66 | 67 | def get_equivalent_kernel_bias(self): 68 | k_origin, b_origin = transI_fusebn(self.dbb_origin.conv.weight, self.dbb_origin.bn) 69 | 70 | if hasattr(self, 'dbb_1x1'): 71 | k_1x1, b_1x1 = transI_fusebn(self.dbb_1x1.conv.weight, self.dbb_1x1.bn) 72 | k_1x1 = transVI_multiscale(k_1x1, self.kernel_size) 73 | else: 74 | k_1x1, b_1x1 = 0, 0 75 | 76 | if hasattr(self.dbb_1x1_kxk, 'idconv1'): 77 | k_1x1_kxk_first = self.dbb_1x1_kxk.idconv1.get_actual_kernel() 78 | else: 79 | k_1x1_kxk_first = self.dbb_1x1_kxk.conv1.weight 80 | k_1x1_kxk_first, b_1x1_kxk_first = transI_fusebn(k_1x1_kxk_first, self.dbb_1x1_kxk.bn1) 81 | k_1x1_kxk_second, b_1x1_kxk_second = transI_fusebn(self.dbb_1x1_kxk.conv2.weight, self.dbb_1x1_kxk.bn2) 82 | k_1x1_kxk_merged, b_1x1_kxk_merged = transIII_1x1_kxk(k_1x1_kxk_first, b_1x1_kxk_first, k_1x1_kxk_second, b_1x1_kxk_second, groups=self.groups) 83 | 84 | k_avg = transV_avg(self.out_channels, self.kernel_size, self.groups) 85 | k_1x1_avg_second, b_1x1_avg_second = transI_fusebn(k_avg.to(self.dbb_avg.avgbn.weight.device), self.dbb_avg.avgbn) 86 | if hasattr(self.dbb_avg, 'conv'): 87 | k_1x1_avg_first, b_1x1_avg_first = transI_fusebn(self.dbb_avg.conv.weight, self.dbb_avg.bn) 88 | k_1x1_avg_merged, b_1x1_avg_merged = transIII_1x1_kxk(k_1x1_avg_first, b_1x1_avg_first, k_1x1_avg_second, b_1x1_avg_second, groups=self.groups) 89 | else: 90 | k_1x1_avg_merged, b_1x1_avg_merged = k_1x1_avg_second, b_1x1_avg_second 91 | 92 | return transII_addbranch((k_origin, k_1x1, k_1x1_kxk_merged, k_1x1_avg_merged), (b_origin, b_1x1, b_1x1_kxk_merged, b_1x1_avg_merged)) 93 | 94 | def switch_to_deploy(self): 95 | if hasattr(self, 'dbb_reparam'): 96 | return 97 | kernel, bias = self.get_equivalent_kernel_bias() 98 | self.dbb_reparam = nn.Conv2d(in_channels=self.dbb_origin.conv.in_channels, out_channels=self.dbb_origin.conv.out_channels, 99 | kernel_size=self.dbb_origin.conv.kernel_size, stride=self.dbb_origin.conv.stride, 100 | padding=self.dbb_origin.conv.padding, dilation=self.dbb_origin.conv.dilation, groups=self.dbb_origin.conv.groups, bias=True) 101 | self.dbb_reparam.weight.data = kernel 102 | self.dbb_reparam.bias.data = bias 103 | for para in self.parameters(): 104 | para.detach_() 105 | self.__delattr__('dbb_origin') 106 | self.__delattr__('dbb_avg') 107 | if hasattr(self, 'dbb_1x1'): 108 | self.__delattr__('dbb_1x1') 109 | self.__delattr__('dbb_1x1_kxk') 110 | 111 | def forward(self, inputs): 112 | 113 | if hasattr(self, 'dbb_reparam'): 114 | return self.nonlinear(self.dbb_reparam(inputs)) 115 | 116 | out = self.dbb_origin(inputs) 117 | if hasattr(self, 'dbb_1x1'): 118 | out += self.dbb_1x1(inputs) 119 | out += self.dbb_avg(inputs) 120 | out += self.dbb_1x1_kxk(inputs) 121 | return self.nonlinear(out) 122 | 123 | def init_gamma(self, gamma_value): 124 | if hasattr(self, "dbb_origin"): 125 | torch.nn.init.constant_(self.dbb_origin.bn.weight, gamma_value) 126 | if hasattr(self, "dbb_1x1"): 127 | torch.nn.init.constant_(self.dbb_1x1.bn.weight, gamma_value) 128 | if hasattr(self, "dbb_avg"): 129 | 
torch.nn.init.constant_(self.dbb_avg.avgbn.weight, gamma_value) 130 | if hasattr(self, "dbb_1x1_kxk"): 131 | torch.nn.init.constant_(self.dbb_1x1_kxk.bn2.weight, gamma_value) 132 | 133 | def single_init(self): 134 | self.init_gamma(0.0) 135 | if hasattr(self, "dbb_origin"): 136 | torch.nn.init.constant_(self.dbb_origin.bn.weight, 1.0) 137 | 138 | class OREPA_1x1(nn.Module): 139 | 140 | def __init__(self, in_channels, out_channels, kernel_size=1, 141 | stride=1, padding=0, dilation=1, groups=1, 142 | deploy=False, nonlinear=None, single_init=False): 143 | super(OREPA_1x1, self).__init__() 144 | self.deploy = deploy 145 | 146 | if nonlinear is None: 147 | self.nonlinear = nn.Identity() 148 | else: 149 | self.nonlinear = nonlinear 150 | 151 | self.in_channels = in_channels 152 | self.out_channels = out_channels 153 | self.groups = groups 154 | assert groups == 1 155 | assert kernel_size == 1 156 | assert padding == kernel_size // 2 157 | 158 | self.kernel_size = kernel_size 159 | self.stride = stride 160 | self.padding = padding 161 | self.dilation = dilation 162 | 163 | if deploy: 164 | self.or1x1_reparam = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, 165 | padding=padding, dilation=dilation, groups=groups, bias=True) 166 | 167 | else: 168 | self.branch_counter = 0 169 | 170 | self.weight_or1x1_origin = nn.Parameter(torch.Tensor(out_channels, in_channels, 1, 1)) 171 | init.kaiming_uniform_(self.weight_or1x1_origin, a=math.sqrt(1.0)) 172 | self.branch_counter += 1 173 | 174 | if out_channels > in_channels: 175 | self.weight_or1x1_l2i_conv1 = nn.Parameter(torch.eye(in_channels).unsqueeze(2).unsqueeze(3)) 176 | self.weight_or1x1_l2i_conv2 = nn.Parameter(torch.Tensor(out_channels, in_channels, 1, 1)) 177 | init.kaiming_uniform_(self.weight_or1x1_l2i_conv2, a=math.sqrt(1.0)) 178 | else: 179 | self.weight_or1x1_l2i_conv1 = nn.Parameter(torch.Tensor(out_channels, in_channels, 1, 1)) 180 | init.kaiming_uniform_(self.weight_or1x1_l2i_conv1, a=math.sqrt(1.0)) 181 | self.weight_or1x1_l2i_conv2 = nn.Parameter(torch.eye(out_channels).unsqueeze(2).unsqueeze(3)) 182 | self.branch_counter += 1 183 | 184 | self.vector = nn.Parameter(torch.Tensor(self.branch_counter, self.out_channels)) 185 | self.bn = nn.BatchNorm2d(self.out_channels) 186 | 187 | init.constant_(self.vector[0, :], 1.0) 188 | init.constant_(self.vector[1, :], 0.5) 189 | 190 | if single_init: 191 | # Initialize the vector.weight of origin as 1 and others as 0. This is not the default setting. 
192 | self.single_init() 193 | 194 | def weight_gen(self): 195 | 196 | weight_or1x1_origin = torch.einsum('oihw,o->oihw', self.weight_or1x1_origin, self.vector[0, :]) 197 | 198 | weight_or1x1_l2i = torch.einsum('tihw,othw->oihw', self.weight_or1x1_l2i_conv1, self.weight_or1x1_l2i_conv2) 199 | weight_or1x1_l2i = torch.einsum('oihw,o->oihw', weight_or1x1_l2i, self.vector[1, :]) 200 | 201 | return weight_or1x1_origin + weight_or1x1_l2i 202 | 203 | def forward(self, inputs): 204 | if hasattr(self, 'or1x1_reparam'): 205 | return self.nonlinear(self.or1x1_reparam(inputs)) 206 | 207 | weight = self.weight_gen() 208 | out = F.conv2d(inputs, weight, bias=None, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups) 209 | return self.nonlinear(self.bn(out)) 210 | 211 | def get_equivalent_kernel_bias(self): 212 | return transI_fusebn(self.weight_gen(), self.bn) 213 | 214 | def switch_to_deploy(self): 215 | if hasattr(self, 'or1x1_reparam'): 216 | return 217 | kernel, bias = self.get_equivalent_kernel_bias() 218 | self.or1x1_reparam = nn.Conv2d(in_channels=self.in_channels, out_channels=self.out_channels, 219 | kernel_size=self.kernel_size, stride=self.stride, 220 | padding=self.padding, dilation=self.dilation, groups=self.groups, bias=True) 221 | self.or1x1_reparam.weight.data = kernel 222 | self.or1x1_reparam.bias.data = bias 223 | for para in self.parameters(): 224 | para.detach_() 225 | self.__delattr__('weight_or1x1_origin') 226 | self.__delattr__('weight_or1x1_l2i_conv1') 227 | self.__delattr__('weight_or1x1_l2i_conv2') 228 | self.__delattr__('vector') 229 | self.__delattr__('bn') 230 | 231 | def init_gamma(self, gamma_value): 232 | init.constant_(self.vector, gamma_value) 233 | 234 | def single_init(self): 235 | self.init_gamma(0.0) 236 | init.constant_(self.vector[0, :], 1.0) 237 | 238 | class OREPA(nn.Module): 239 | 240 | def __init__(self, 241 | in_channels, 242 | out_channels, 243 | kernel_size, 244 | stride=1, 245 | padding=0, 246 | dilation=1, 247 | groups=1, 248 | internal_channels_1x1_3x3=None, 249 | deploy=False, 250 | nonlinear=None, 251 | single_init=False, 252 | weight_only=False, 253 | init_hyper_para=1.0, init_hyper_gamma=1.0): 254 | super(OREPA, self).__init__() 255 | self.deploy = deploy 256 | 257 | if nonlinear is None: 258 | self.nonlinear = nn.Identity() 259 | else: 260 | self.nonlinear = nonlinear 261 | self.weight_only = weight_only 262 | 263 | self.kernel_size = kernel_size 264 | self.in_channels = in_channels 265 | self.out_channels = out_channels 266 | self.groups = groups 267 | assert padding == kernel_size // 2 268 | 269 | self.stride = stride 270 | self.padding = padding 271 | self.dilation = dilation 272 | 273 | if deploy: 274 | self.orepa_reparam = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, 275 | padding=padding, dilation=dilation, groups=groups, bias=True) 276 | 277 | else: 278 | 279 | self.branch_counter = 0 280 | 281 | self.weight_orepa_origin = nn.Parameter( 282 | torch.Tensor(out_channels, int(in_channels / self.groups), 283 | kernel_size, kernel_size)) 284 | init.kaiming_uniform_(self.weight_orepa_origin, a=math.sqrt(0.0)) 285 | self.branch_counter += 1 286 | 287 | self.weight_orepa_avg_conv = nn.Parameter( 288 | torch.Tensor(out_channels, int(in_channels / self.groups), 1, 289 | 1)) 290 | self.weight_orepa_pfir_conv = nn.Parameter( 291 | torch.Tensor(out_channels, int(in_channels / self.groups), 1, 292 | 1)) 293 | init.kaiming_uniform_(self.weight_orepa_avg_conv, a=0.0) 
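# (Added comment, not in the original source:) weight_orepa_avg_conv and weight_orepa_pfir_conv above are 1x1 tensors;
# weight_gen() expands each of them to a full kxk kernel by combining it with the fixed averaging kernel
# ('weight_orepa_avg_avg', registered just below) and the frequency prior built in fre_init(), respectively.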
294 | init.kaiming_uniform_(self.weight_orepa_pfir_conv, a=0.0) 295 | self.register_buffer( 296 | 'weight_orepa_avg_avg', 297 | torch.ones(kernel_size, 298 | kernel_size).mul(1.0 / kernel_size / kernel_size)) 299 | self.branch_counter += 1 300 | self.branch_counter += 1 301 | 302 | self.weight_orepa_1x1 = nn.Parameter( 303 | torch.Tensor(out_channels, int(in_channels / self.groups), 1, 304 | 1)) 305 | init.kaiming_uniform_(self.weight_orepa_1x1, a=0.0) 306 | self.branch_counter += 1 307 | 308 | if internal_channels_1x1_3x3 is None: 309 | internal_channels_1x1_3x3 = in_channels if groups <= 4 else 2 * in_channels 310 | 311 | if internal_channels_1x1_3x3 == in_channels: 312 | self.weight_orepa_1x1_kxk_idconv1 = nn.Parameter( 313 | torch.zeros(in_channels, int(in_channels / self.groups), 1, 1)) 314 | id_value = np.zeros( 315 | (in_channels, int(in_channels / self.groups), 1, 1)) 316 | for i in range(in_channels): 317 | id_value[i, i % int(in_channels / self.groups), 0, 0] = 1 318 | id_tensor = torch.from_numpy(id_value).type_as( 319 | self.weight_orepa_1x1_kxk_idconv1) 320 | self.register_buffer('id_tensor', id_tensor) 321 | 322 | else: 323 | self.weight_orepa_1x1_kxk_idconv1 = nn.Parameter( 324 | torch.zeros(internal_channels_1x1_3x3, 325 | int(in_channels / self.groups), 1, 1)) 326 | id_value = np.zeros( 327 | (internal_channels_1x1_3x3, int(in_channels / self.groups), 1, 1)) 328 | for i in range(internal_channels_1x1_3x3): 329 | id_value[i, i % int(in_channels / self.groups), 0, 0] = 1 330 | id_tensor = torch.from_numpy(id_value).type_as( 331 | self.weight_orepa_1x1_kxk_idconv1) 332 | self.register_buffer('id_tensor', id_tensor) 333 | #init.kaiming_uniform_( 334 | #self.weight_orepa_1x1_kxk_conv1, a=math.sqrt(0.0)) 335 | self.weight_orepa_1x1_kxk_conv2 = nn.Parameter( 336 | torch.Tensor(out_channels, 337 | int(internal_channels_1x1_3x3 / self.groups), 338 | kernel_size, kernel_size)) 339 | init.kaiming_uniform_(self.weight_orepa_1x1_kxk_conv2, a=math.sqrt(0.0)) 340 | self.branch_counter += 1 341 | 342 | expand_ratio = 8 343 | self.weight_orepa_gconv_dw = nn.Parameter( 344 | torch.Tensor(in_channels * expand_ratio, 1, kernel_size, 345 | kernel_size)) 346 | self.weight_orepa_gconv_pw = nn.Parameter( 347 | torch.Tensor(out_channels, int(in_channels * expand_ratio / self.groups), 1, 1)) 348 | init.kaiming_uniform_(self.weight_orepa_gconv_dw, a=math.sqrt(0.0)) 349 | init.kaiming_uniform_(self.weight_orepa_gconv_pw, a=math.sqrt(0.0)) 350 | self.branch_counter += 1 351 | 352 | self.vector = nn.Parameter(torch.Tensor(self.branch_counter, self.out_channels)) 353 | if weight_only is False: 354 | self.bn = nn.BatchNorm2d(self.out_channels) 355 | 356 | self.fre_init() 357 | 358 | init.constant_(self.vector[0, :], 0.25 * math.sqrt(init_hyper_gamma)) #origin 359 | init.constant_(self.vector[1, :], 0.25 * math.sqrt(init_hyper_gamma)) #avg 360 | init.constant_(self.vector[2, :], 0.0 * math.sqrt(init_hyper_gamma)) #prior 361 | init.constant_(self.vector[3, :], 0.5 * math.sqrt(init_hyper_gamma)) #1x1_kxk 362 | init.constant_(self.vector[4, :], 1.0 * math.sqrt(init_hyper_gamma)) #1x1 363 | init.constant_(self.vector[5, :], 0.5 * math.sqrt(init_hyper_gamma)) #dws_conv 364 | 365 | self.weight_orepa_1x1.data = self.weight_orepa_1x1.mul(init_hyper_para) 366 | self.weight_orepa_origin.data = self.weight_orepa_origin.mul(init_hyper_para) 367 | self.weight_orepa_1x1_kxk_conv2.data = self.weight_orepa_1x1_kxk_conv2.mul(init_hyper_para) 368 | self.weight_orepa_avg_conv.data = 
self.weight_orepa_avg_conv.mul(init_hyper_para) 369 | self.weight_orepa_pfir_conv.data = self.weight_orepa_pfir_conv.mul(init_hyper_para) 370 | 371 | self.weight_orepa_gconv_dw.data = self.weight_orepa_gconv_dw.mul(math.sqrt(init_hyper_para)) 372 | self.weight_orepa_gconv_pw.data = self.weight_orepa_gconv_pw.mul(math.sqrt(init_hyper_para)) 373 | 374 | if single_init: 375 | # Initialize the vector.weight of origin as 1 and others as 0. This is not the default setting. 376 | self.single_init() 377 | 378 | def fre_init(self): 379 | prior_tensor = torch.Tensor(self.out_channels, self.kernel_size, 380 | self.kernel_size) 381 | half_fg = self.out_channels / 2 382 | for i in range(self.out_channels): 383 | for h in range(3): 384 | for w in range(3): 385 | if i < half_fg: 386 | prior_tensor[i, h, w] = math.cos(math.pi * (h + 0.5) * 387 | (i + 1) / 3) 388 | else: 389 | prior_tensor[i, h, w] = math.cos(math.pi * (w + 0.5) * 390 | (i + 1 - half_fg) / 3) 391 | 392 | self.register_buffer('weight_orepa_prior', prior_tensor) 393 | 394 | def weight_gen(self): 395 | weight_orepa_origin = torch.einsum('oihw,o->oihw', 396 | self.weight_orepa_origin, 397 | self.vector[0, :]) 398 | 399 | weight_orepa_avg = torch.einsum('oihw,hw->oihw', self.weight_orepa_avg_conv, self.weight_orepa_avg_avg) 400 | weight_orepa_avg = torch.einsum( 401 | 'oihw,o->oihw', 402 | torch.einsum('oi,hw->oihw', self.weight_orepa_avg_conv.squeeze(3).squeeze(2), 403 | self.weight_orepa_avg_avg), self.vector[1, :]) 404 | 405 | 406 | weight_orepa_pfir = torch.einsum( 407 | 'oihw,o->oihw', 408 | torch.einsum('oi,ohw->oihw', self.weight_orepa_pfir_conv.squeeze(3).squeeze(2), 409 | self.weight_orepa_prior), self.vector[2, :]) 410 | 411 | weight_orepa_1x1_kxk_conv1 = None 412 | if hasattr(self, 'weight_orepa_1x1_kxk_idconv1'): 413 | weight_orepa_1x1_kxk_conv1 = (self.weight_orepa_1x1_kxk_idconv1 + 414 | self.id_tensor).squeeze(3).squeeze(2) 415 | elif hasattr(self, 'weight_orepa_1x1_kxk_conv1'): 416 | weight_orepa_1x1_kxk_conv1 = self.weight_orepa_1x1_kxk_conv1.squeeze(3).squeeze(2) 417 | else: 418 | raise NotImplementedError 419 | weight_orepa_1x1_kxk_conv2 = self.weight_orepa_1x1_kxk_conv2 420 | 421 | if self.groups > 1: 422 | g = self.groups 423 | t, ig = weight_orepa_1x1_kxk_conv1.size() 424 | o, tg, h, w = weight_orepa_1x1_kxk_conv2.size() 425 | weight_orepa_1x1_kxk_conv1 = weight_orepa_1x1_kxk_conv1.view( 426 | g, int(t / g), ig) 427 | weight_orepa_1x1_kxk_conv2 = weight_orepa_1x1_kxk_conv2.view( 428 | g, int(o / g), tg, h, w) 429 | weight_orepa_1x1_kxk = torch.einsum('gti,gothw->goihw', 430 | weight_orepa_1x1_kxk_conv1, 431 | weight_orepa_1x1_kxk_conv2).reshape( 432 | o, ig, h, w) 433 | else: 434 | weight_orepa_1x1_kxk = torch.einsum('ti,othw->oihw', 435 | weight_orepa_1x1_kxk_conv1, 436 | weight_orepa_1x1_kxk_conv2) 437 | weight_orepa_1x1_kxk = torch.einsum('oihw,o->oihw', weight_orepa_1x1_kxk, self.vector[3, :]) 438 | 439 | weight_orepa_1x1 = 0 440 | if hasattr(self, 'weight_orepa_1x1'): 441 | weight_orepa_1x1 = transVI_multiscale(self.weight_orepa_1x1, 442 | self.kernel_size) 443 | weight_orepa_1x1 = torch.einsum('oihw,o->oihw', weight_orepa_1x1, 444 | self.vector[4, :]) 445 | 446 | weight_orepa_gconv = self.dwsc2full(self.weight_orepa_gconv_dw, 447 | self.weight_orepa_gconv_pw, 448 | self.in_channels, self.groups) 449 | weight_orepa_gconv = torch.einsum('oihw,o->oihw', weight_orepa_gconv, 450 | self.vector[5, :]) 451 | 452 | weight = weight_orepa_origin + weight_orepa_avg + weight_orepa_1x1 + weight_orepa_1x1_kxk + weight_orepa_pfir + 
weight_orepa_gconv 453 | 454 | return weight 455 | 456 | def dwsc2full(self, weight_dw, weight_pw, groups, groups_conv=1): 457 | 458 | t, ig, h, w = weight_dw.size() 459 | o, _, _, _ = weight_pw.size() 460 | tg = int(t / groups) 461 | i = int(ig * groups) 462 | ogc = int(o / groups_conv) 463 | groups_gc = int(groups / groups_conv) 464 | weight_dw = weight_dw.view(groups_conv, groups_gc, tg, ig, h, w) 465 | weight_pw = weight_pw.squeeze().view(ogc, groups_conv, groups_gc, tg) 466 | 467 | weight_dsc = torch.einsum('cgtihw,ocgt->cogihw', weight_dw, weight_pw) 468 | return weight_dsc.reshape(o, int(i/groups_conv), h, w) 469 | 470 | def forward(self, inputs=None): 471 | if hasattr(self, 'orepa_reparam'): 472 | return self.nonlinear(self.orepa_reparam(inputs)) 473 | 474 | weight = self.weight_gen() 475 | 476 | if self.weight_only is True: 477 | return weight 478 | 479 | out = F.conv2d( 480 | inputs, 481 | weight, 482 | bias=None, 483 | stride=self.stride, 484 | padding=self.padding, 485 | dilation=self.dilation, 486 | groups=self.groups) 487 | return self.nonlinear(self.bn(out)) 488 | 489 | def get_equivalent_kernel_bias(self): 490 | return transI_fusebn(self.weight_gen(), self.bn) 491 | 492 | def switch_to_deploy(self): 493 | if hasattr(self, 'or1x1_reparam'): 494 | return 495 | kernel, bias = self.get_equivalent_kernel_bias() 496 | self.orepa_reparam = nn.Conv2d(in_channels=self.in_channels, out_channels=self.out_channels, 497 | kernel_size=self.kernel_size, stride=self.stride, 498 | padding=self.padding, dilation=self.dilation, groups=self.groups, bias=True) 499 | self.orepa_reparam.weight.data = kernel 500 | self.orepa_reparam.bias.data = bias 501 | for para in self.parameters(): 502 | para.detach_() 503 | self.__delattr__('weight_orepa_origin') 504 | self.__delattr__('weight_orepa_1x1') 505 | self.__delattr__('weight_orepa_1x1_kxk_conv2') 506 | if hasattr(self, 'weight_orepa_1x1_kxk_idconv1'): 507 | self.__delattr__('id_tensor') 508 | self.__delattr__('weight_orepa_1x1_kxk_idconv1') 509 | elif hasattr(self, 'weight_orepa_1x1_kxk_conv1'): 510 | self.__delattr__('weight_orepa_1x1_kxk_conv1') 511 | else: 512 | raise NotImplementedError 513 | self.__delattr__('weight_orepa_avg_avg') 514 | self.__delattr__('weight_orepa_avg_conv') 515 | self.__delattr__('weight_orepa_pfir_conv') 516 | self.__delattr__('weight_orepa_prior') 517 | self.__delattr__('weight_orepa_gconv_dw') 518 | self.__delattr__('weight_orepa_gconv_pw') 519 | 520 | self.__delattr__('bn') 521 | self.__delattr__('vector') 522 | 523 | def init_gamma(self, gamma_value): 524 | init.constant_(self.vector, gamma_value) 525 | 526 | def single_init(self): 527 | self.init_gamma(0.0) 528 | init.constant_(self.vector[0, :], 1.0) 529 | 530 | class OREPA_LargeConvBase(nn.Module): 531 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation=1, groups=1, deploy=False, nonlinear=None): 532 | super(OREPA_LargeConvBase, self).__init__() 533 | assert kernel_size % 2 == 1 and kernel_size > 3 534 | 535 | self.stride = stride 536 | self.padding = padding 537 | self.layers = int((kernel_size - 1) / 2) 538 | self.groups = groups 539 | self.dilation = dilation 540 | 541 | internal_channels = out_channels 542 | 543 | self.kernel_size = kernel_size 544 | self.in_channels = in_channels 545 | self.out_channels = out_channels 546 | 547 | if nonlinear is None: 548 | self.nonlinear = nn.Identity() 549 | else: 550 | self.nonlinear = nonlinear 551 | 552 | if deploy: 553 | self.or_large_reparam = nn.Conv2d(in_channels=in_channels, 
out_channels=out_channels, kernel_size=kernel_size, stride=stride, 554 | padding=padding, dilation=dilation, groups=groups, bias=True) 555 | 556 | else: 557 | for i in range(self.layers): 558 | if i == 0: 559 | self.__setattr__('weight'+str(i), nn.Parameter(torch.Tensor(internal_channels, int(in_channels/self.groups), 3, 3))) 560 | elif i == self.layers - 1: 561 | self.__setattr__('weight'+str(i), nn.Parameter(torch.Tensor(out_channels, int(internal_channels/self.groups), 3, 3))) 562 | else: 563 | self.__setattr__('weight'+str(i), nn.Parameter(torch.Tensor(internal_channels, int(internal_channels/self.groups), 3, 3))) 564 | init.kaiming_uniform_(getattr(self, 'weight'+str(i)), a=math.sqrt(5)) 565 | 566 | self.bn = nn.BatchNorm2d(out_channels) 567 | #self.unfold = torch.nn.Unfold(kernel_size=3, dilation=1, padding=2, stride=1) 568 | 569 | def weight_gen(self): 570 | weight = getattr(self, 'weight'+str(0)).transpose(0, 1) 571 | for i in range(self.layers - 1): 572 | weight2 = getattr(self, 'weight'+str(i+1)) 573 | weight = F.conv2d(weight, weight2, groups=self.groups, padding=2) 574 | 575 | return weight.transpose(0, 1) 576 | ''' 577 | weight = getattr(self, 'weight'+str(0)).transpose(0, 1) 578 | for i in range(self.layers - 1): 579 | weight = self.unfold(weight) 580 | weight2 = getattr(self, 'weight'+str(i+1)) 581 | 582 | weight = torch.einsum('akl,bk->abl', weight, weight2.view(weight2.size(0), -1)) 583 | k = i * 2 + 5 584 | weight = weight.view(weight.size(0), weight.size(1), k, k) 585 | 586 | return weight.transpose(0, 1) 587 | ''' 588 | 589 | def forward(self, inputs): 590 | if hasattr(self, 'or_large_reparam'): 591 | return self.nonlinear(self.or_large_reparam(inputs)) 592 | 593 | weight = self.weight_gen() 594 | out = F.conv2d(inputs, weight, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups) 595 | return self.nonlinear(self.bn(out)) 596 | 597 | def get_equivalent_kernel_bias(self): 598 | return transI_fusebn(self.weight_gen(), self.bn) 599 | 600 | def switch_to_deploy(self): 601 | if hasattr(self, 'or_large_reparam'): 602 | return 603 | kernel, bias = self.get_equivalent_kernel_bias() 604 | self.or_large_reparam = nn.Conv2d(in_channels=self.in_channels, out_channels=self.out_channels, 605 | kernel_size=self.kernel_size, stride=self.stride, 606 | padding=self.padding, dilation=self.dilation, groups=self.groups, bias=True) 607 | self.or_large_reparam.weight.data = kernel 608 | self.or_large_reparam.bias.data = bias 609 | for para in self.parameters(): 610 | para.detach_() 611 | for i in range(self.layers): 612 | self.__delattr__('weight'+str(i)) 613 | self.__delattr__('bn') 614 | 615 | 616 | class OREPA_LargeConv(nn.Module): 617 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation=1, groups=1, deploy=False, nonlinear=None): 618 | super(OREPA_LargeConv, self).__init__() 619 | assert kernel_size % 2 == 1 and kernel_size > 3 620 | 621 | self.stride = stride 622 | self.padding = padding 623 | self.layers = int((kernel_size - 1) / 2) 624 | self.groups = groups 625 | self.dilation = dilation 626 | 627 | self.kernel_size = kernel_size 628 | self.in_channels = in_channels 629 | self.out_channels = out_channels 630 | 631 | internal_channels = out_channels 632 | 633 | if nonlinear is None: 634 | self.nonlinear = nn.Identity() 635 | else: 636 | self.nonlinear = nonlinear 637 | 638 | if deploy: 639 | self.or_large_reparam = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, 640 | 
padding=padding, dilation=dilation, groups=groups, bias=True) 641 | 642 | else: 643 | for i in range(self.layers): 644 | if i == 0: 645 | self.__setattr__('weight'+str(i), OREPA(in_channels, internal_channels, kernel_size=3, stride=1, padding=1, groups=groups, weight_only=True)) 646 | elif i == self.layers - 1: 647 | self.__setattr__('weight'+str(i), OREPA(internal_channels, out_channels, kernel_size=3, stride=self.stride, padding=1, weight_only=True)) 648 | else: 649 | self.__setattr__('weight'+str(i), OREPA(internal_channels, internal_channels, kernel_size=3, stride=1, padding=1, weight_only=True)) 650 | 651 | self.bn = nn.BatchNorm2d(out_channels) 652 | #self.unfold = torch.nn.Unfold(kernel_size=3, dilation=1, padding=2, stride=1) 653 | 654 | def weight_gen(self): 655 | weight = getattr(self, 'weight'+str(0)).weight_gen().transpose(0, 1) 656 | for i in range(self.layers - 1): 657 | weight2 = getattr(self, 'weight'+str(i+1)).weight_gen() 658 | weight = F.conv2d(weight, weight2, groups=self.groups, padding=2) 659 | 660 | return weight.transpose(0, 1) 661 | ''' 662 | weight = getattr(self, 'weight'+str(0))(inputs=None).transpose(0, 1) 663 | for i in range(self.layers - 1): 664 | weight = self.unfold(weight) 665 | weight2 = getattr(self, 'weight'+str(i+1))(inputs=None) 666 | 667 | weight = torch.einsum('akl,bk->abl', weight, weight2.view(weight2.size(0), -1)) 668 | k = i * 2 + 5 669 | weight = weight.view(weight.size(0), weight.size(1), k, k) 670 | 671 | return weight.transpose(0, 1) 672 | ''' 673 | 674 | def forward(self, inputs): 675 | if hasattr(self, 'or_large_reparam'): 676 | return self.nonlinear(self.or_large_reparam(inputs)) 677 | 678 | weight = self.weight_gen() 679 | out = F.conv2d(inputs, weight, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups) 680 | return self.nonlinear(self.bn(out)) 681 | 682 | def get_equivalent_kernel_bias(self): 683 | return transI_fusebn(self.weight_gen(), self.bn) 684 | 685 | def switch_to_deploy(self): 686 | if hasattr(self, 'or_large_reparam'): 687 | return 688 | kernel, bias = self.get_equivalent_kernel_bias() 689 | self.or_large_reparam = nn.Conv2d(in_channels=self.in_channels, out_channels=self.out_channels, 690 | kernel_size=self.kernel_size, stride=self.stride, 691 | padding=self.padding, dilation=self.dilation, groups=self.groups, bias=True) 692 | self.or_large_reparam.weight.data = kernel 693 | self.or_large_reparam.bias.data = bias 694 | for para in self.parameters(): 695 | para.detach_() 696 | for i in range(self.layers): 697 | self.__delattr__('weight'+str(i)) 698 | self.__delattr__('bn') 699 | -------------------------------------------------------------------------------- /blocks_repvgg.py: -------------------------------------------------------------------------------- 1 | 2 | from basic import * 3 | import numpy as np 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | class RepVGGBlock(nn.Module): 10 | 11 | def __init__(self, in_channels, out_channels, kernel_size, 12 | stride=1, padding=0, dilation=1, groups=1, padding_mode='zeros', deploy=False, use_se=False): 13 | super(RepVGGBlock, self).__init__() 14 | self.deploy = deploy 15 | self.groups = groups 16 | self.in_channels = in_channels 17 | 18 | assert kernel_size == 3 19 | assert padding == 1 20 | 21 | padding_11 = padding - kernel_size // 2 22 | 23 | self.nonlinearity = nn.ReLU() 24 | 25 | if use_se: 26 | self.se = SEBlock(out_channels, internal_neurons=out_channels // 16) 27 | else: 28 | self.se = 
nn.Identity() 29 | 30 | if deploy: 31 | self.rbr_reparam = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, 32 | padding=padding, dilation=dilation, groups=groups, bias=True, padding_mode=padding_mode) 33 | 34 | else: 35 | self.rbr_identity = nn.BatchNorm2d(num_features=in_channels) if out_channels == in_channels and stride == 1 else None 36 | self.rbr_dense = ConvBN(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, groups=groups) 37 | self.rbr_1x1 = ConvBN(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=stride, padding=padding_11, groups=groups) 38 | print('RepVGG Block, identity = ', self.rbr_identity) 39 | 40 | 41 | def forward(self, inputs): 42 | if hasattr(self, 'rbr_reparam'): 43 | return self.nonlinearity(self.se(self.rbr_reparam(inputs))) 44 | 45 | if self.rbr_identity is None: 46 | id_out = 0 47 | else: 48 | id_out = self.rbr_identity(inputs) 49 | 50 | return self.nonlinearity(self.se(self.rbr_dense(inputs) + self.rbr_1x1(inputs) + id_out)) 51 | 52 | 53 | # Optional. This improves the accuracy and facilitates quantization. 54 | # 1. Cancel the original weight decay on rbr_dense.conv.weight and rbr_1x1.conv.weight. 55 | # 2. Use like this. 56 | # loss = criterion(....) 57 | # for every RepVGGBlock blk: 58 | # loss += weight_decay_coefficient * 0.5 * blk.get_custom_L2() 59 | # optimizer.zero_grad() 60 | # loss.backward() 61 | def get_custom_L2(self): 62 | K3 = self.rbr_dense.conv.weight 63 | K1 = self.rbr_1x1.conv.weight 64 | t3 = (self.rbr_dense.bn.weight / ((self.rbr_dense.bn.running_var + self.rbr_dense.bn.eps).sqrt())).reshape(-1, 1, 1, 1).detach() 65 | t1 = (self.rbr_1x1.bn.weight / ((self.rbr_1x1.bn.running_var + self.rbr_1x1.bn.eps).sqrt())).reshape(-1, 1, 1, 1).detach() 66 | 67 | l2_loss_circle = (K3 ** 2).sum() - (K3[:, :, 1:2, 1:2] ** 2).sum() # The L2 loss of the "circle" of weights in 3x3 kernel. Use regular L2 on them. 68 | eq_kernel = K3[:, :, 1:2, 1:2] * t3 + K1 * t1 # The equivalent resultant central point of 3x3 kernel. 69 | l2_loss_eq_kernel = (eq_kernel ** 2 / (t3 ** 2 + t1 ** 2)).sum() # Normalize for an L2 coefficient comparable to regular L2. 70 | return l2_loss_eq_kernel + l2_loss_circle 71 | 72 | 73 | 74 | # This func derives the equivalent kernel and bias in a DIFFERENTIABLE way. 75 | # You can get the equivalent kernel and bias at any time and do whatever you want, 76 | # for example, apply some penalties or constraints during training, just like you do to the other models. 77 | # May be useful for quantization or pruning.
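# For example, a rough hedged sketch (the 64-channel sizes and the 1e-4 penalty weight below are arbitrary placeholders, not part of the original recipe):
#   blk = RepVGGBlock(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1)
#   eq_kernel, eq_bias = blk.get_equivalent_kernel_bias()
#   loss = loss + 1e-4 * (eq_kernel ** 2).sum()   # e.g. an extra penalty on the fused 3x3 kernel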
78 | def get_equivalent_kernel_bias(self): 79 | kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense) 80 | kernel1x1, bias1x1 = self._fuse_bn_tensor(self.rbr_1x1) 81 | kernelid, biasid = self._fuse_bn_tensor(self.rbr_identity) 82 | return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid 83 | 84 | def _pad_1x1_to_3x3_tensor(self, kernel1x1): 85 | if kernel1x1 is None: 86 | return 0 87 | else: 88 | return torch.nn.functional.pad(kernel1x1, [1,1,1,1]) 89 | 90 | def _fuse_bn_tensor(self, branch): 91 | if branch is None: 92 | return 0, 0 93 | if not isinstance(branch, nn.BatchNorm2d): 94 | kernel = branch.conv.weight 95 | running_mean = branch.bn.running_mean 96 | running_var = branch.bn.running_var 97 | gamma = branch.bn.weight 98 | beta = branch.bn.bias 99 | eps = branch.bn.eps 100 | else: 101 | if not hasattr(self, 'id_tensor'): 102 | input_dim = self.in_channels // self.groups 103 | kernel_value = np.zeros((self.in_channels, input_dim, 3, 3), dtype=np.float32) 104 | for i in range(self.in_channels): 105 | kernel_value[i, i % input_dim, 1, 1] = 1 106 | self.id_tensor = torch.from_numpy(kernel_value).to(branch.weight.device) 107 | kernel = self.id_tensor 108 | running_mean = branch.running_mean 109 | running_var = branch.running_var 110 | gamma = branch.weight 111 | beta = branch.bias 112 | eps = branch.eps 113 | std = (running_var + eps).sqrt() 114 | t = (gamma / std).reshape(-1, 1, 1, 1) 115 | return kernel * t, beta - running_mean * gamma / std 116 | 117 | def switch_to_deploy(self): 118 | if hasattr(self, 'rbr_reparam'): 119 | return 120 | kernel, bias = self.get_equivalent_kernel_bias() 121 | self.rbr_reparam = nn.Conv2d(in_channels=self.rbr_dense.conv.in_channels, out_channels=self.rbr_dense.conv.out_channels, 122 | kernel_size=self.rbr_dense.conv.kernel_size, stride=self.rbr_dense.conv.stride, 123 | padding=self.rbr_dense.conv.padding, dilation=self.rbr_dense.conv.dilation, groups=self.rbr_dense.conv.groups, bias=True) 124 | self.rbr_reparam.weight.data = kernel 125 | self.rbr_reparam.bias.data = bias 126 | for para in self.parameters(): 127 | para.detach_() 128 | self.__delattr__('rbr_dense') 129 | self.__delattr__('rbr_1x1') 130 | if hasattr(self, 'rbr_identity'): 131 | self.__delattr__('rbr_identity') 132 | 133 | 134 | class OREPA_3x3_RepVGG(nn.Module): 135 | 136 | def __init__(self, in_channels, out_channels, kernel_size, 137 | stride=1, padding=0, dilation=1, groups=1, 138 | internal_channels_1x1_3x3=None, 139 | deploy=False, nonlinear=None, single_init=False): 140 | super(OREPA_3x3_RepVGG, self).__init__() 141 | self.deploy = deploy 142 | 143 | if nonlinear is None: 144 | self.nonlinear = nn.Identity() 145 | else: 146 | self.nonlinear = nonlinear 147 | 148 | self.kernel_size = kernel_size 149 | self.in_channels = in_channels 150 | self.out_channels = out_channels 151 | self.groups = groups 152 | assert padding == kernel_size // 2 153 | 154 | self.stride = stride 155 | self.padding = padding 156 | self.dilation = dilation 157 | 158 | self.branch_counter = 0 159 | 160 | self.weight_rbr_origin = nn.Parameter(torch.Tensor(out_channels, int(in_channels/self.groups), kernel_size, kernel_size)) 161 | init.kaiming_uniform_(self.weight_rbr_origin, a=math.sqrt(1.0)) 162 | self.branch_counter += 1 163 | 164 | 165 | if groups < out_channels: 166 | self.weight_rbr_avg_conv = nn.Parameter(torch.Tensor(out_channels, int(in_channels/self.groups), 1, 1)) 167 | self.weight_rbr_pfir_conv = nn.Parameter(torch.Tensor(out_channels, 
int(in_channels/self.groups), 1, 1)) 168 | init.kaiming_uniform_(self.weight_rbr_avg_conv, a=1.0) 169 | init.kaiming_uniform_(self.weight_rbr_pfir_conv, a=1.0) 170 | self.weight_rbr_avg_conv.data 171 | self.weight_rbr_pfir_conv.data 172 | self.register_buffer('weight_rbr_avg_avg', torch.ones(kernel_size, kernel_size).mul(1.0/kernel_size/kernel_size)) 173 | self.branch_counter += 1 174 | 175 | else: 176 | raise NotImplementedError 177 | self.branch_counter += 1 178 | 179 | if internal_channels_1x1_3x3 is None: 180 | internal_channels_1x1_3x3 = in_channels if groups < out_channels else 2 * in_channels # For mobilenet, it is better to have 2X internal channels 181 | 182 | if internal_channels_1x1_3x3 == in_channels: 183 | self.weight_rbr_1x1_kxk_idconv1 = nn.Parameter(torch.zeros(in_channels, int(in_channels/self.groups), 1, 1)) 184 | id_value = np.zeros((in_channels, int(in_channels/self.groups), 1, 1)) 185 | for i in range(in_channels): 186 | id_value[i, i % int(in_channels/self.groups), 0, 0] = 1 187 | id_tensor = torch.from_numpy(id_value).type_as(self.weight_rbr_1x1_kxk_idconv1) 188 | self.register_buffer('id_tensor', id_tensor) 189 | 190 | else: 191 | self.weight_rbr_1x1_kxk_conv1 = nn.Parameter(torch.Tensor(internal_channels_1x1_3x3, int(in_channels/self.groups), 1, 1)) 192 | init.kaiming_uniform_(self.weight_rbr_1x1_kxk_conv1, a=math.sqrt(1.0)) 193 | self.weight_rbr_1x1_kxk_conv2 = nn.Parameter(torch.Tensor(out_channels, int(internal_channels_1x1_3x3/self.groups), kernel_size, kernel_size)) 194 | init.kaiming_uniform_(self.weight_rbr_1x1_kxk_conv2, a=math.sqrt(1.0)) 195 | self.branch_counter += 1 196 | 197 | expand_ratio = 8 198 | self.weight_rbr_gconv_dw = nn.Parameter(torch.Tensor(in_channels*expand_ratio, 1, kernel_size, kernel_size)) 199 | self.weight_rbr_gconv_pw = nn.Parameter(torch.Tensor(out_channels, in_channels*expand_ratio, 1, 1)) 200 | init.kaiming_uniform_(self.weight_rbr_gconv_dw, a=math.sqrt(1.0)) 201 | init.kaiming_uniform_(self.weight_rbr_gconv_pw, a=math.sqrt(1.0)) 202 | self.branch_counter += 1 203 | 204 | if out_channels == in_channels and stride == 1: 205 | self.branch_counter += 1 206 | 207 | self.vector = nn.Parameter(torch.Tensor(self.branch_counter, self.out_channels)) 208 | self.bn = nn.BatchNorm2d(out_channels) 209 | 210 | self.fre_init() 211 | 212 | init.constant_(self.vector[0, :], 0.25) #origin 213 | init.constant_(self.vector[1, :], 0.25) #avg 214 | init.constant_(self.vector[2, :], 0.0) #prior 215 | init.constant_(self.vector[3, :], 0.5) #1x1_kxk 216 | init.constant_(self.vector[4, :], 0.5) #dws_conv 217 | 218 | 219 | def fre_init(self): 220 | prior_tensor = torch.Tensor(self.out_channels, self.kernel_size, self.kernel_size) 221 | half_fg = self.out_channels/2 222 | for i in range(self.out_channels): 223 | for h in range(3): 224 | for w in range(3): 225 | if i < half_fg: 226 | prior_tensor[i, h, w] = math.cos(math.pi*(h+0.5)*(i+1)/3) 227 | else: 228 | prior_tensor[i, h, w] = math.cos(math.pi*(w+0.5)*(i+1-half_fg)/3) 229 | 230 | self.register_buffer('weight_rbr_prior', prior_tensor) 231 | 232 | def weight_gen(self): 233 | 234 | weight_rbr_origin = torch.einsum('oihw,o->oihw', self.weight_rbr_origin, self.vector[0, :]) 235 | 236 | weight_rbr_avg = torch.einsum('oihw,o->oihw', torch.einsum('oihw,hw->oihw', self.weight_rbr_avg_conv, self.weight_rbr_avg_avg), self.vector[1, :]) 237 | 238 | weight_rbr_pfir = torch.einsum('oihw,o->oihw', torch.einsum('oihw,ohw->oihw', self.weight_rbr_pfir_conv, self.weight_rbr_prior), self.vector[2, :]) 239 | 240 | 
weight_rbr_1x1_kxk_conv1 = None 241 | if hasattr(self, 'weight_rbr_1x1_kxk_idconv1'): 242 | weight_rbr_1x1_kxk_conv1 = (self.weight_rbr_1x1_kxk_idconv1 + self.id_tensor).squeeze() 243 | elif hasattr(self, 'weight_rbr_1x1_kxk_conv1'): 244 | weight_rbr_1x1_kxk_conv1 = self.weight_rbr_1x1_kxk_conv1.squeeze() 245 | else: 246 | raise NotImplementedError 247 | weight_rbr_1x1_kxk_conv2 = self.weight_rbr_1x1_kxk_conv2 248 | 249 | if self.groups > 1: 250 | g = self.groups 251 | t, ig = weight_rbr_1x1_kxk_conv1.size() 252 | o, tg, h, w = weight_rbr_1x1_kxk_conv2.size() 253 | weight_rbr_1x1_kxk_conv1 = weight_rbr_1x1_kxk_conv1.view(g, int(t/g), ig) 254 | weight_rbr_1x1_kxk_conv2 = weight_rbr_1x1_kxk_conv2.view(g, int(o/g), tg, h, w) 255 | weight_rbr_1x1_kxk = torch.einsum('gti,gothw->goihw', weight_rbr_1x1_kxk_conv1, weight_rbr_1x1_kxk_conv2).view(o, ig, h, w) 256 | else: 257 | weight_rbr_1x1_kxk = torch.einsum('ti,othw->oihw', weight_rbr_1x1_kxk_conv1, weight_rbr_1x1_kxk_conv2) 258 | 259 | weight_rbr_1x1_kxk = torch.einsum('oihw,o->oihw', weight_rbr_1x1_kxk, self.vector[3, :]) 260 | 261 | weight_rbr_gconv = self.dwsc2full(self.weight_rbr_gconv_dw, self.weight_rbr_gconv_pw, self.in_channels) 262 | weight_rbr_gconv = torch.einsum('oihw,o->oihw', weight_rbr_gconv, self.vector[4, :]) 263 | 264 | weight = weight_rbr_origin + weight_rbr_avg + weight_rbr_1x1_kxk + weight_rbr_pfir + weight_rbr_gconv 265 | 266 | return weight 267 | 268 | def dwsc2full(self, weight_dw, weight_pw, groups): 269 | 270 | t, ig, h, w = weight_dw.size() 271 | o, _, _, _ = weight_pw.size() 272 | tg = int(t/groups) 273 | i = int(ig*groups) 274 | weight_dw = weight_dw.view(groups, tg, ig, h, w) 275 | weight_pw = weight_pw.squeeze().view(o, groups, tg) 276 | 277 | weight_dsc = torch.einsum('gtihw,ogt->ogihw', weight_dw, weight_pw) 278 | return weight_dsc.view(o, i, h, w) 279 | 280 | def forward(self, inputs): 281 | weight = self.weight_gen() 282 | out = F.conv2d(inputs, weight, bias=None, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups) 283 | 284 | return self.nonlinear(self.bn(out)) 285 | 286 | class RepVGGBlock_OREPA(nn.Module): 287 | 288 | def __init__(self, in_channels, out_channels, kernel_size, 289 | stride=1, padding=0, dilation=1, groups=1, padding_mode='zeros', deploy=False, use_se=False, nonlinear=nn.ReLU()): 290 | super(RepVGGBlock_OREPA, self).__init__() 291 | self.deploy = deploy 292 | self.groups = groups 293 | self.in_channels = in_channels 294 | self.out_channels = out_channels 295 | 296 | self.padding = padding 297 | self.dilation = dilation 298 | self.groups = groups 299 | 300 | assert kernel_size == 3 301 | assert padding == 1 302 | 303 | padding_11 = padding - kernel_size // 2 304 | 305 | if nonlinear is None: 306 | self.nonlinearity = nn.Identity() 307 | else: 308 | self.nonlinearity = nonlinear 309 | 310 | if use_se: 311 | self.se = SEBlock(out_channels, internal_neurons=out_channels // 16) 312 | else: 313 | self.se = nn.Identity() 314 | 315 | if deploy: 316 | self.rbr_reparam = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, 317 | padding=padding, dilation=dilation, groups=groups, bias=True, padding_mode=padding_mode) 318 | 319 | else: 320 | self.rbr_identity = nn.BatchNorm2d(num_features=in_channels) if out_channels == in_channels and stride == 1 else None 321 | self.rbr_dense = OREPA_3x3_RepVGG(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, groups=groups, 
dilation=1) 322 | self.rbr_1x1 = ConvBN(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=stride, padding=padding_11, groups=groups, dilation=1) 323 | print('RepVGG Block, identity = ', self.rbr_identity) 324 | 325 | 326 | def forward(self, inputs): 327 | if hasattr(self, 'rbr_reparam'): 328 | return self.nonlinearity(self.se(self.rbr_reparam(inputs))) 329 | 330 | if self.rbr_identity is None: 331 | id_out = 0 332 | else: 333 | id_out = self.rbr_identity(inputs) 334 | 335 | out1 = self.rbr_dense(inputs) 336 | out2 = self.rbr_1x1(inputs) 337 | out3 = id_out 338 | out = out1 + out2 + out3 339 | 340 | return self.nonlinearity(self.se(out)) 341 | 342 | 343 | # Optional. This improves the accuracy and facilitates quantization. 344 | # 1. Cancel the original weight decay on rbr_dense.conv.weight and rbr_1x1.conv.weight. 345 | # 2. Use like this. 346 | # loss = criterion(....) 347 | # for every RepVGGBlock blk: 348 | # loss += weight_decay_coefficient * 0.5 * blk.get_custom_L2() 349 | # optimizer.zero_grad() 350 | # loss.backward() 351 | 352 | # Not used for OREPA 353 | def get_custom_L2(self): 354 | K3 = self.rbr_dense.weight_gen() 355 | K1 = self.rbr_1x1.conv.weight 356 | t3 = (self.rbr_dense.bn.weight / ((self.rbr_dense.bn.running_var + self.rbr_dense.bn.eps).sqrt())).reshape(-1, 1, 1, 1).detach() 357 | t1 = (self.rbr_1x1.bn.weight / ((self.rbr_1x1.bn.running_var + self.rbr_1x1.bn.eps).sqrt())).reshape(-1, 1, 1, 1).detach() 358 | 359 | l2_loss_circle = (K3 ** 2).sum() - (K3[:, :, 1:2, 1:2] ** 2).sum() # The L2 loss of the "circle" of weights in 3x3 kernel. Use regular L2 on them. 360 | eq_kernel = K3[:, :, 1:2, 1:2] * t3 + K1 * t1 # The equivalent resultant central point of 3x3 kernel. 361 | l2_loss_eq_kernel = (eq_kernel ** 2 / (t3 ** 2 + t1 ** 2)).sum() # Normalize for an L2 coefficient comparable to regular L2.
362 | return l2_loss_eq_kernel + l2_loss_circle 363 | 364 | def get_equivalent_kernel_bias(self): 365 | kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense) 366 | kernel1x1, bias1x1 = self._fuse_bn_tensor(self.rbr_1x1) 367 | kernelid, biasid = self._fuse_bn_tensor(self.rbr_identity) 368 | return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid 369 | 370 | def _pad_1x1_to_3x3_tensor(self, kernel1x1): 371 | if kernel1x1 is None: 372 | return 0 373 | else: 374 | return torch.nn.functional.pad(kernel1x1, [1,1,1,1]) 375 | 376 | def _fuse_bn_tensor(self, branch): 377 | if branch is None: 378 | return 0, 0 379 | if not isinstance(branch, nn.BatchNorm2d): 380 | if isinstance(branch, OREPA_3x3_RepVGG): 381 | kernel = branch.weight_gen() 382 | elif isinstance(branch, ConvBN): 383 | kernel = branch.conv.weight 384 | else: 385 | raise NotImplementedError 386 | running_mean = branch.bn.running_mean 387 | running_var = branch.bn.running_var 388 | gamma = branch.bn.weight 389 | beta = branch.bn.bias 390 | eps = branch.bn.eps 391 | else: 392 | if not hasattr(self, 'id_tensor'): 393 | input_dim = self.in_channels // self.groups 394 | kernel_value = np.zeros((self.in_channels, input_dim, 3, 3), dtype=np.float32) 395 | for i in range(self.in_channels): 396 | kernel_value[i, i % input_dim, 1, 1] = 1 397 | self.id_tensor = torch.from_numpy(kernel_value).to(branch.weight.device) 398 | kernel = self.id_tensor 399 | running_mean = branch.running_mean 400 | running_var = branch.running_var 401 | gamma = branch.weight 402 | beta = branch.bias 403 | eps = branch.eps 404 | std = (running_var + eps).sqrt() 405 | t = (gamma / std).reshape(-1, 1, 1, 1) 406 | return kernel * t, beta - running_mean * gamma / std 407 | 408 | def switch_to_deploy(self): 409 | if hasattr(self, 'rbr_reparam'): 410 | return 411 | kernel, bias = self.get_equivalent_kernel_bias() 412 | self.rbr_reparam = nn.Conv2d(in_channels=self.rbr_dense.in_channels, out_channels=self.rbr_dense.out_channels, 413 | kernel_size=self.rbr_dense.kernel_size, stride=self.rbr_dense.stride, 414 | padding=self.rbr_dense.padding, dilation=self.rbr_dense.dilation, groups=self.rbr_dense.groups, bias=True) 415 | self.rbr_reparam.weight.data = kernel 416 | self.rbr_reparam.bias.data = bias 417 | for para in self.parameters(): 418 | para.detach_() 419 | self.__delattr__('rbr_dense') 420 | self.__delattr__('rbr_1x1') 421 | if hasattr(self, 'rbr_identity'): 422 | self.__delattr__('rbr_identity') 423 | 424 | class SEBlock(nn.Module): 425 | 426 | def __init__(self, input_channels, internal_neurons): 427 | super(SEBlock, self).__init__() 428 | self.down = nn.Conv2d(in_channels=input_channels, out_channels=internal_neurons, kernel_size=1, stride=1, bias=True) 429 | self.up = nn.Conv2d(in_channels=internal_neurons, out_channels=input_channels, kernel_size=1, stride=1, bias=True) 430 | self.input_channels = input_channels 431 | 432 | def forward(self, inputs): 433 | x = F.avg_pool2d(inputs, kernel_size=inputs.size(3)) 434 | x = self.down(x) 435 | x = F.relu(x) 436 | x = self.up(x) 437 | x = torch.sigmoid(x) 438 | x = x.view(-1, self.input_channels, 1, 1) 439 | return inputs * x -------------------------------------------------------------------------------- /convert.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | from convnet_utils import switch_conv_bn_impl, switch_deploy_flag, build_model 5 | 6 | parser = 
argparse.ArgumentParser(description='Convert Conversion') 7 | parser.add_argument('load', metavar='LOAD', help='path to the weights file') 8 | parser.add_argument('save', metavar='SAVE', help='path to the weights file') 9 | parser.add_argument('-a', '--arch', metavar='ARCH', default='ResNet-18') 10 | parser.add_argument('-t', '--blocktype', metavar='BLK', default='OREPA') 11 | 12 | def convert(): 13 | args = parser.parse_args() 14 | 15 | switch_conv_bn_impl(args.blocktype) 16 | switch_deploy_flag(False) 17 | train_model = build_model(args.arch) 18 | 19 | if 'hdf5' in args.load: 20 | from utils import model_load_hdf5 21 | model_load_hdf5(train_model, args.load) 22 | elif os.path.isfile(args.load): 23 | print("=> loading checkpoint '{}'".format(args.load)) 24 | checkpoint = torch.load(args.load) 25 | if 'state_dict' in checkpoint: 26 | checkpoint = checkpoint['state_dict'] 27 | ckpt = {k.replace('module.', ''): v for k, v in checkpoint.items()} # strip the names 28 | train_model.load_state_dict(ckpt) 29 | else: 30 | print("=> no checkpoint found at '{}'".format(args.load)) 31 | 32 | for m in train_model.modules(): 33 | if hasattr(m, 'switch_to_deploy'): 34 | m.switch_to_deploy() 35 | 36 | torch.save(train_model.state_dict(), args.save) 37 | 38 | 39 | if __name__ == '__main__': 40 | convert() -------------------------------------------------------------------------------- /convnet_utils.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from basic import ConvBN 4 | from blocks import DBB, OREPA_1x1, OREPA, OREPA_LargeConvBase, OREPA_LargeConv 5 | from blocks_repvgg import RepVGGBlock, RepVGGBlock_OREPA 6 | 7 | CONV_BN_IMPL = 'base' 8 | 9 | DEPLOY_FLAG = False 10 | 11 | def choose_blk(kernel_size): 12 | if CONV_BN_IMPL == 'OREPA': 13 | if kernel_size == 1: 14 | blk_type = OREPA_1x1 15 | elif kernel_size >= 7: 16 | blk_type = OREPA_LargeConv 17 | else: 18 | blk_type = OREPA 19 | elif CONV_BN_IMPL == 'base' or kernel_size == 1 or kernel_size >= 7: 20 | blk_type = ConvBN 21 | elif CONV_BN_IMPL == 'DBB': 22 | blk_type = DBB 23 | elif CONV_BN_IMPL == 'RepVGG': 24 | blk_type = RepVGGBlock 25 | elif CONV_BN_IMPL == 'OREPA_VGG': 26 | blk_type = RepVGGBlock_OREPA 27 | else: 28 | raise NotImplementedError 29 | 30 | return blk_type 31 | 32 | def conv_bn(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, assign_type=None): 33 | if assign_type is not None: 34 | blk_type = assign_type 35 | else: 36 | blk_type = choose_blk(kernel_size) 37 | return blk_type(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, 38 | padding=padding, dilation=dilation, groups=groups, deploy=DEPLOY_FLAG) 39 | 40 | def conv_bn_relu(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, assign_type=None): 41 | if assign_type is not None: 42 | blk_type = assign_type 43 | else: 44 | blk_type = choose_blk(kernel_size) 45 | return blk_type(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, 46 | padding=padding, dilation=dilation, groups=groups, deploy=DEPLOY_FLAG, nonlinear=nn.ReLU()) 47 | 48 | def switch_conv_bn_impl(block_type): 49 | global CONV_BN_IMPL 50 | CONV_BN_IMPL = block_type 51 | 52 | def switch_deploy_flag(deploy): 53 | global DEPLOY_FLAG 54 | DEPLOY_FLAG = deploy 55 | print('deploy flag: ', DEPLOY_FLAG) 56 | 57 | def build_model(arch): 58 | if arch == 'ResNet-18': 59 | from models.resnet import create_Res18 60 | model 
= create_Res18() 61 | elif arch == 'ResNet-34': 62 | from models.resnet import create_Res34 63 | model = create_Res34() 64 | elif arch == 'ResNet-50': 65 | from models.resnet import create_Res50 66 | model = create_Res50() 67 | elif arch == 'ResNet-101': 68 | from models.resnet import create_Res101 69 | model = create_Res101() 70 | elif arch == 'RepVGG-A0': 71 | from models.repvgg import create_RepVGG_A0 72 | model = create_RepVGG_A0() 73 | elif arch == 'RepVGG-A1': 74 | from models.repvgg import create_RepVGG_A1 75 | model = create_RepVGG_A1() 76 | elif arch == 'RepVGG-A2': 77 | from models.repvgg import create_RepVGG_A2 78 | model = create_RepVGG_A2() 79 | elif arch == 'RepVGG-B1': 80 | from models.repvgg import create_RepVGG_B1 81 | model = create_RepVGG_B1() 82 | elif arch == 'ResNet-18-1.5x': 83 | from models.resnet import create_Res18_1d5x 84 | model = create_Res18_1d5x() 85 | elif arch == 'ResNet-18-2x': 86 | from models.resnet import create_Res18_2x 87 | model = create_Res18_2x() 88 | elif arch == 'ResNeXt-50': 89 | from models.resnext import create_Res50_32x4d 90 | model = create_Res50_32x4d() 91 | #elif arch == 'RegNet-800MF': 92 | #from models.regnet import create_Reg800MF 93 | #model = create_Reg800MF() 94 | #elif arch == 'ConvNext-T-0.5x': 95 | #from models.convnext import convnext_tiny_0d5x 96 | #model = convnext_tiny_0d5x() 97 | else: 98 | raise ValueError('TODO') 99 | return model -------------------------------------------------------------------------------- /dbb_transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn.functional as F 4 | 5 | def transI_fusebn(kernel, bn): 6 | gamma = bn.weight 7 | std = (bn.running_var + bn.eps).sqrt() 8 | return kernel * ((gamma / std).reshape(-1, 1, 1, 1)), bn.bias - bn.running_mean * gamma / std 9 | 10 | def transII_addbranch(kernels, biases): 11 | return sum(kernels), sum(biases) 12 | 13 | def transIII_1x1_kxk(k1, b1, k2, b2, groups): 14 | if groups == 1: 15 | k = F.conv2d(k2, k1.permute(1, 0, 2, 3)) # 16 | b_hat = (k2 * b1.reshape(1, -1, 1, 1)).sum((1, 2, 3)) 17 | else: 18 | k_slices = [] 19 | b_slices = [] 20 | k1_T = k1.permute(1, 0, 2, 3) 21 | k1_group_width = k1.size(0) // groups 22 | k2_group_width = k2.size(0) // groups 23 | for g in range(groups): 24 | k1_T_slice = k1_T[:, g*k1_group_width:(g+1)*k1_group_width, :, :] 25 | k2_slice = k2[g*k2_group_width:(g+1)*k2_group_width, :, :, :] 26 | k_slices.append(F.conv2d(k2_slice, k1_T_slice)) 27 | b_slices.append((k2_slice * b1[g*k1_group_width:(g+1)*k1_group_width].reshape(1, -1, 1, 1)).sum((1, 2, 3))) 28 | k, b_hat = transIV_depthconcat(k_slices, b_slices) 29 | return k, b_hat + b2 30 | 31 | def transIV_depthconcat(kernels, biases): 32 | return torch.cat(kernels, dim=0), torch.cat(biases) 33 | 34 | def transV_avg(channels, kernel_size, groups): 35 | input_dim = channels // groups 36 | k = torch.zeros((channels, input_dim, kernel_size, kernel_size)) 37 | k[np.arange(channels), np.tile(np.arange(input_dim), groups), :, :] = 1.0 / kernel_size ** 2 38 | return k 39 | 40 | # This has not been tested with non-square kernels (kernel.size(2) != kernel.size(3)) nor even-size kernels 41 | def transVI_multiscale(kernel, target_kernel_size): 42 | H_pixels_to_pad = (target_kernel_size - kernel.size(2)) // 2 43 | W_pixels_to_pad = (target_kernel_size - kernel.size(3)) // 2 44 | return F.pad(kernel, [W_pixels_to_pad, W_pixels_to_pad, H_pixels_to_pad, H_pixels_to_pad]) 
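# A minimal, hedged usage sketch of transI_fusebn: fold a Conv2d + BatchNorm2d pair (in eval
# mode, so the running statistics are used) into a single biased convolution and check that the
# two forms agree. The layer sizes and the statistics filled in below are arbitrary placeholders.
if __name__ == '__main__':
    import torch.nn as nn
    conv = nn.Conv2d(8, 16, kernel_size=3, padding=1, bias=False)
    bn = nn.BatchNorm2d(16)
    bn.eval()
    bn.running_mean.uniform_(-0.1, 0.1)   # give the BN buffers non-trivial statistics
    bn.running_var.uniform_(0.5, 1.5)
    fused_kernel, fused_bias = transI_fusebn(conv.weight, bn)
    x = torch.randn(2, 8, 32, 32)
    y_ref = bn(conv(x))                                               # conv followed by BN
    y_fused = F.conv2d(x, fused_kernel, bias=fused_bias, padding=1)   # single fused conv
    print('max abs diff:', (y_ref - y_fused).abs().max().item())      # expected to be ~1e-6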
-------------------------------------------------------------------------------- /images/budgets.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JUGGHM/OREPA_CVPR2022/3ed037beeb106cc1c3d011eb2678cebfefc4cba0/images/budgets.png -------------------------------------------------------------------------------- /images/coco_cs.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JUGGHM/OREPA_CVPR2022/3ed037beeb106cc1c3d011eb2678cebfefc4cba0/images/coco_cs.PNG -------------------------------------------------------------------------------- /images/imagenet1.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JUGGHM/OREPA_CVPR2022/3ed037beeb106cc1c3d011eb2678cebfefc4cba0/images/imagenet1.PNG -------------------------------------------------------------------------------- /images/imagenet2.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JUGGHM/OREPA_CVPR2022/3ed037beeb106cc1c3d011eb2678cebfefc4cba0/images/imagenet2.PNG -------------------------------------------------------------------------------- /images/intro.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JUGGHM/OREPA_CVPR2022/3ed037beeb106cc1c3d011eb2678cebfefc4cba0/images/intro.png -------------------------------------------------------------------------------- /images/norm.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JUGGHM/OREPA_CVPR2022/3ed037beeb106cc1c3d011eb2678cebfefc4cba0/images/norm.PNG -------------------------------------------------------------------------------- /images/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JUGGHM/OREPA_CVPR2022/3ed037beeb106cc1c3d011eb2678cebfefc4cba0/images/overview.png -------------------------------------------------------------------------------- /images/similarity.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JUGGHM/OREPA_CVPR2022/3ed037beeb106cc1c3d011eb2678cebfefc4cba0/images/similarity.png -------------------------------------------------------------------------------- /images/supp_grad.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JUGGHM/OREPA_CVPR2022/3ed037beeb106cc1c3d011eb2678cebfefc4cba0/images/supp_grad.png -------------------------------------------------------------------------------- /models/repvgg.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import numpy as np 3 | import torch 4 | import copy 5 | from convnet_utils import conv_bn, conv_bn_relu 6 | 7 | ''' 8 | def conv_bn(in_channels, out_channels, kernel_size, stride, padding, groups=1): 9 | result = nn.Sequential() 10 | result.add_module('conv', nn.Conv2d(in_channels=in_channels, out_channels=out_channels, 11 | kernel_size=kernel_size, stride=stride, padding=padding, groups=groups, bias=False)) 12 | result.add_module('bn', nn.BatchNorm2d(num_features=out_channels)) 13 | return result 14 | ''' 15 | 16 | class RepVGG(nn.Module): 17 | 18 | def __init__(self, num_blocks, num_classes=1000, width_multiplier=None, 
override_groups_map=None, deploy=False, use_se=False): 19 | super(RepVGG, self).__init__() 20 | 21 | assert len(width_multiplier) == 4 22 | 23 | self.deploy = deploy 24 | self.override_groups_map = override_groups_map or dict() 25 | self.use_se = use_se 26 | 27 | assert 0 not in self.override_groups_map 28 | 29 | self.in_planes = min(64, int(64 * width_multiplier[0])) 30 | 31 | self.stage0 = conv_bn_relu(in_channels=3, out_channels=self.in_planes, kernel_size=3, stride=2, padding=1) 32 | self.cur_layer_idx = 1 33 | self.stage1 = self._make_stage(int(64 * width_multiplier[0]), num_blocks[0], stride=2) 34 | self.stage2 = self._make_stage(int(128 * width_multiplier[1]), num_blocks[1], stride=2) 35 | self.stage3 = self._make_stage(int(256 * width_multiplier[2]), num_blocks[2], stride=2) 36 | self.stage4 = self._make_stage(int(512 * width_multiplier[3]), num_blocks[3], stride=2) 37 | self.gap = nn.AdaptiveAvgPool2d(output_size=1) 38 | self.linear = nn.Linear(int(512 * width_multiplier[3]), num_classes) 39 | 40 | 41 | def _make_stage(self, planes, num_blocks, stride): 42 | strides = [stride] + [1]*(num_blocks-1) 43 | blocks = [] 44 | for stride in strides: 45 | cur_groups = self.override_groups_map.get(self.cur_layer_idx, 1) 46 | blocks.append(conv_bn_relu(in_channels=self.in_planes, out_channels=planes, kernel_size=3, 47 | stride=stride, padding=1, groups=cur_groups)) 48 | self.in_planes = planes 49 | self.cur_layer_idx += 1 50 | return nn.Sequential(*blocks) 51 | 52 | def forward(self, x): 53 | out = self.stage0(x) 54 | out = self.stage1(out) 55 | out = self.stage2(out) 56 | out = self.stage3(out) 57 | out = self.stage4(out) 58 | out = self.gap(out) 59 | out = out.view(out.size(0), -1) 60 | out = self.linear(out) 61 | return out 62 | 63 | 64 | optional_groupwise_layers = [2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26] 65 | g2_map = {l: 2 for l in optional_groupwise_layers} 66 | g4_map = {l: 4 for l in optional_groupwise_layers} 67 | 68 | def create_RepVGG_A0(deploy=False): 69 | return RepVGG(num_blocks=[2, 4, 14, 1], num_classes=1000, 70 | width_multiplier=[0.75, 0.75, 0.75, 2.5], override_groups_map=None, deploy=deploy) 71 | 72 | def create_RepVGG_A1(deploy=False): 73 | return RepVGG(num_blocks=[2, 4, 14, 1], num_classes=1000, 74 | width_multiplier=[1, 1, 1, 2.5], override_groups_map=None, deploy=deploy) 75 | 76 | def create_RepVGG_A2(deploy=False): 77 | return RepVGG(num_blocks=[2, 4, 14, 1], num_classes=1000, 78 | width_multiplier=[1.5, 1.5, 1.5, 2.75], override_groups_map=None, deploy=deploy) 79 | 80 | def create_RepVGG_B0(deploy=False): 81 | return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000, 82 | width_multiplier=[1, 1, 1, 2.5], override_groups_map=None, deploy=deploy) 83 | 84 | def create_RepVGG_B1(deploy=False): 85 | return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000, 86 | width_multiplier=[2, 2, 2, 4], override_groups_map=None, deploy=deploy) 87 | 88 | def create_RepVGG_B1g2(deploy=False): 89 | return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000, 90 | width_multiplier=[2, 2, 2, 4], override_groups_map=g2_map, deploy=deploy) 91 | 92 | def create_RepVGG_B1g4(deploy=False): 93 | return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000, 94 | width_multiplier=[2, 2, 2, 4], override_groups_map=g4_map, deploy=deploy) 95 | 96 | 97 | def create_RepVGG_B2(deploy=False): 98 | return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000, 99 | width_multiplier=[2.5, 2.5, 2.5, 5], override_groups_map=None, deploy=deploy) 100 | 101 | def create_RepVGG_B2g2(deploy=False): 102 | return 
RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000, 103 | width_multiplier=[2.5, 2.5, 2.5, 5], override_groups_map=g2_map, deploy=deploy) 104 | 105 | def create_RepVGG_B2g4(deploy=False): 106 | return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000, 107 | width_multiplier=[2.5, 2.5, 2.5, 5], override_groups_map=g4_map, deploy=deploy) 108 | 109 | 110 | def create_RepVGG_B3(deploy=False): 111 | return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000, 112 | width_multiplier=[3, 3, 3, 5], override_groups_map=None, deploy=deploy) 113 | 114 | def create_RepVGG_B3g2(deploy=False): 115 | return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000, 116 | width_multiplier=[3, 3, 3, 5], override_groups_map=g2_map, deploy=deploy) 117 | 118 | def create_RepVGG_B3g4(deploy=False): 119 | return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000, 120 | width_multiplier=[3, 3, 3, 5], override_groups_map=g4_map, deploy=deploy) 121 | 122 | def create_RepVGG_D2se(deploy=False): 123 | return RepVGG(num_blocks=[8, 14, 24, 1], num_classes=1000, 124 | width_multiplier=[2.5, 2.5, 2.5, 5], override_groups_map=None, deploy=deploy, use_se=True) 125 | 126 | 127 | func_dict = { 128 | 'RepVGG-A0': create_RepVGG_A0, 129 | 'RepVGG-A1': create_RepVGG_A1, 130 | 'RepVGG-A2': create_RepVGG_A2, 131 | 'RepVGG-B0': create_RepVGG_B0, 132 | 'RepVGG-B1': create_RepVGG_B1, 133 | 'RepVGG-B1g2': create_RepVGG_B1g2, 134 | 'RepVGG-B1g4': create_RepVGG_B1g4, 135 | 'RepVGG-B2': create_RepVGG_B2, 136 | 'RepVGG-B2g2': create_RepVGG_B2g2, 137 | 'RepVGG-B2g4': create_RepVGG_B2g4, 138 | 'RepVGG-B3': create_RepVGG_B3, 139 | 'RepVGG-B3g2': create_RepVGG_B3g2, 140 | 'RepVGG-B3g4': create_RepVGG_B3g4, 141 | 'RepVGG-D2se': create_RepVGG_D2se, # Updated at April 25, 2021. This is not reported in the CVPR paper. 
142 | } 143 | def get_RepVGG_func_by_name(name): 144 | return func_dict[name] 145 | 146 | 147 | 148 | # Use this for converting a RepVGG model or a bigger model with RepVGG as its component 149 | # Use like this 150 | # model = create_RepVGG_A0(deploy=False) 151 | # train model or load weights 152 | # repvgg_model_convert(model, save_path='repvgg_deploy.pth') 153 | # If you want to preserve the original model, call with do_copy=True 154 | 155 | # ====================== for using RepVGG as the backbone of a bigger model, e.g., PSPNet, the pseudo code will be like 156 | # train_backbone = create_RepVGG_B2(deploy=False) 157 | # train_backbone.load_state_dict(torch.load('RepVGG-B2-train.pth')) 158 | # train_pspnet = build_pspnet(backbone=train_backbone) 159 | # segmentation_train(train_pspnet) 160 | # deploy_pspnet = repvgg_model_convert(train_pspnet) 161 | # segmentation_test(deploy_pspnet) 162 | # ===================== example_pspnet.py shows an example 163 | 164 | def repvgg_model_convert(model:torch.nn.Module, save_path=None, do_copy=True): 165 | if do_copy: 166 | model = copy.deepcopy(model) 167 | for module in model.modules(): 168 | if hasattr(module, 'switch_to_deploy'): 169 | module.switch_to_deploy() 170 | if save_path is not None: 171 | torch.save(model.state_dict(), save_path) 172 | return model 173 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | import imp 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from convnet_utils import conv_bn, conv_bn_relu 5 | from basic import ConvBN 6 | 7 | class BasicBlock(nn.Module): 8 | expansion = 1 9 | def __init__(self, in_planes, planes, stride=1): 10 | super(BasicBlock, self).__init__() 11 | if stride != 1 or in_planes != self.expansion * planes: 12 | self.shortcut = conv_bn(in_channels=in_planes, out_channels=self.expansion * planes, kernel_size=1, stride=stride, assign_type=ConvBN) 13 | else: 14 | self.shortcut = nn.Identity() 15 | self.conv1 = conv_bn_relu(in_channels=in_planes, out_channels=planes, kernel_size=3, stride=stride, padding=1) 16 | self.conv2 = conv_bn(in_channels=planes, out_channels=self.expansion * planes, kernel_size=3, stride=1, padding=1) 17 | 18 | def forward(self, x): 19 | out = self.conv1(x) 20 | out = self.conv2(out) 21 | out = out + self.shortcut(x) 22 | out = F.relu(out) 23 | return out 24 | 25 | 26 | class Bottleneck(nn.Module): 27 | expansion = 4 28 | def __init__(self, in_planes, planes, stride=1): 29 | super(Bottleneck, self).__init__() 30 | 31 | if stride != 1 or in_planes != self.expansion*planes: 32 | self.shortcut = conv_bn(in_planes, self.expansion*planes, kernel_size=1, stride=stride) 33 | else: 34 | self.shortcut = nn.Identity() 35 | 36 | self.conv1 = conv_bn_relu(in_planes, planes, kernel_size=1) 37 | self.conv2 = conv_bn_relu(planes, planes, kernel_size=3, stride=stride, padding=1) 38 | self.conv3 = conv_bn(planes, self.expansion*planes, kernel_size=1) 39 | 40 | def forward(self, x): 41 | out = self.conv1(x) 42 | out = self.conv2(out) 43 | out = self.conv3(out) 44 | out += self.shortcut(x) 45 | out = F.relu(out) 46 | return out 47 | 48 | 49 | class ResNet(nn.Module): 50 | def __init__(self, block, num_blocks, num_classes=1000, width_multiplier=1): 51 | super(ResNet, self).__init__() 52 | 53 | self.in_planes = int(64 * width_multiplier) 54 | self.stage0 = nn.Sequential() 55 | self.stage0.add_module('conv1', conv_bn_relu(in_channels=3, 
out_channels=self.in_planes, kernel_size=7, stride=2, padding=3)) 56 | self.stage0.add_module('maxpool', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)) 57 | self.stage1 = self._make_stage(block, int(64 * width_multiplier), num_blocks[0], stride=1) 58 | self.stage2 = self._make_stage(block, int(128 * width_multiplier), num_blocks[1], stride=2) 59 | self.stage3 = self._make_stage(block, int(256 * width_multiplier), num_blocks[2], stride=2) 60 | self.stage4 = self._make_stage(block, int(512 * width_multiplier), num_blocks[3], stride=2) 61 | self.gap = nn.AdaptiveAvgPool2d(output_size=1) 62 | self.linear = nn.Linear(int(512*block.expansion*width_multiplier), num_classes) 63 | 64 | output_channels = int(512*block.expansion*width_multiplier) 65 | 66 | def _make_stage(self, block, planes, num_blocks, stride): 67 | strides = [stride] + [1]*(num_blocks-1) 68 | blocks = [] 69 | for stride in strides: 70 | if block is Bottleneck: 71 | blocks.append(block(in_planes=self.in_planes, planes=int(planes), stride=stride)) 72 | else: 73 | blocks.append(block(in_planes=self.in_planes, planes=int(planes), stride=stride)) 74 | self.in_planes = int(planes * block.expansion) 75 | return nn.Sequential(*blocks) 76 | 77 | def forward(self, x): 78 | out = self.stage0(x) 79 | out = self.stage1(out) 80 | out = self.stage2(out) 81 | out = self.stage3(out) 82 | out = self.stage4(out) 83 | out = self.gap(out) 84 | out = out.view(out.size(0), -1) 85 | out = self.linear(out) 86 | return out 87 | 88 | 89 | def create_Res18(): 90 | return ResNet(BasicBlock, [2,2,2,2], num_classes=1000, width_multiplier=1) 91 | 92 | def create_Res18_1d5x(): 93 | return ResNet(BasicBlock, [2,2,2,2], num_classes=1000, width_multiplier=1.5) 94 | 95 | def create_Res18_2x(): 96 | return ResNet(BasicBlock, [2,2,2,2], num_classes=1000, width_multiplier=2) 97 | 98 | def create_Res34(): 99 | return ResNet(BasicBlock, [3,4,6,3], num_classes=1000, width_multiplier=1) 100 | 101 | def create_Res50(): 102 | return ResNet(Bottleneck, [3,4,6,3], num_classes=1000, width_multiplier=1) 103 | 104 | def create_Res101(): 105 | return ResNet(Bottleneck, [3,4,23,3], num_classes=1000, width_multiplier=1) -------------------------------------------------------------------------------- /models/resnext.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | from convnet_utils import conv_bn, conv_bn_relu 4 | 5 | 6 | class Bottleneck(nn.Module): 7 | expansion = 4 8 | def __init__(self, in_planes, planes, stride=1, cardinality=32, width_per_group=4): 9 | super(Bottleneck, self).__init__() 10 | 11 | D = cardinality * int(planes*width_per_group/64) 12 | 13 | if stride != 1 or in_planes != self.expansion*planes: 14 | self.shortcut = conv_bn(in_planes, self.expansion*planes, kernel_size=1, stride=stride) 15 | else: 16 | self.shortcut = nn.Identity() 17 | 18 | self.conv1 = conv_bn_relu(in_planes, D, kernel_size=1) 19 | self.conv2 = conv_bn_relu(D, D, kernel_size=3, stride=stride, padding=1, groups=cardinality) 20 | self.conv3 = conv_bn(D, self.expansion*planes, kernel_size=1) 21 | 22 | def forward(self, x): 23 | out = self.conv1(x) 24 | out = self.conv2(out) 25 | out = self.conv3(out) 26 | out += self.shortcut(x) 27 | out = F.relu(out) 28 | return out 29 | 30 | class ResNeXt(nn.Module): 31 | def __init__(self, block, num_blocks, num_classes=1000, width_multiplier=1, cardinality=32, width_per_group=4): 32 | super(ResNeXt, self).__init__() 33 | 34 | self.in_planes = int(64 * 
width_multiplier) 35 | self.stage0 = nn.Sequential() 36 | self.stage0.add_module('conv1', conv_bn_relu(in_channels=3, out_channels=self.in_planes, kernel_size=7, stride=2, padding=3)) 37 | self.stage0.add_module('maxpool', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)) 38 | self.stage1 = self._make_stage(block, int(64 * width_multiplier), num_blocks[0], stride=1, cardinality=cardinality, width_per_group=width_per_group) 39 | self.stage2 = self._make_stage(block, int(128 * width_multiplier), num_blocks[1], stride=2, cardinality=cardinality, width_per_group=width_per_group) 40 | self.stage3 = self._make_stage(block, int(256 * width_multiplier), num_blocks[2], stride=2, cardinality=cardinality, width_per_group=width_per_group) 41 | self.stage4 = self._make_stage(block, int(512 * width_multiplier), num_blocks[3], stride=2, cardinality=cardinality, width_per_group=width_per_group) 42 | self.gap = nn.AdaptiveAvgPool2d(output_size=1) 43 | self.linear = nn.Linear(int(512*block.expansion*width_multiplier), num_classes) 44 | 45 | output_channels = int(512*block.expansion*width_multiplier) 46 | 47 | def _make_stage(self, block, planes, num_blocks, stride, cardinality, width_per_group): 48 | strides = [stride] + [1]*(num_blocks-1) 49 | blocks = [] 50 | for stride in strides: 51 | blocks.append(block(in_planes=self.in_planes, planes=int(planes), stride=stride, cardinality=cardinality, width_per_group=width_per_group)) 52 | self.in_planes = int(planes * block.expansion) 53 | return nn.Sequential(*blocks) 54 | 55 | def forward(self, x): 56 | out = self.stage0(x) 57 | out = self.stage1(out) 58 | out = self.stage2(out) 59 | out = self.stage3(out) 60 | out = self.stage4(out) 61 | out = self.gap(out) 62 | out = out.view(out.size(0), -1) 63 | out = self.linear(out) 64 | return out 65 | 66 | 67 | def create_Res50_32x4d(): 68 | return ResNeXt(Bottleneck, [3,4,6,3], num_classes=1000, width_multiplier=1, cardinality=32, width_per_group=4) -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import time 4 | import torch 5 | import torch.backends.cudnn as cudnn 6 | import torchvision.datasets as datasets 7 | from utils import accuracy, ProgressMeter, AverageMeter, get_default_ImageNet_val_loader 8 | from convnet_utils import switch_deploy_flag, switch_conv_bn_impl, build_model 9 | 10 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Test') 11 | parser.add_argument('--data', metavar='DIR', type=str, help='path to dataset') 12 | parser.add_argument('mode', metavar='MODE', default='deploy', choices=['train', 'deploy'], help='train or deploy') 13 | parser.add_argument('weights', metavar='WEIGHTS', help='path to the weights file') 14 | parser.add_argument('-a', '--arch', metavar='ARCH', default='ResNet-18') 15 | parser.add_argument('-t', '--blocktype', metavar='BLK', default='OREPA') 16 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 17 | help='number of data loading workers (default: 4)') 18 | parser.add_argument('-b', '--val-batch-size', default=100, type=int, 19 | metavar='N', 20 | help='mini-batch size (default: 100) for test') 21 | 22 | def test(): 23 | args = parser.parse_args() 24 | args.data = '/disk1/humu.hm/ImageNet/ILSVRC2015/Data/CLS-LOC/' 25 | 26 | switch_deploy_flag(args.mode == 'deploy') 27 | switch_conv_bn_impl(args.blocktype) 28 | model = build_model(args.arch) 29 | 30 | if not torch.cuda.is_available(): 
31 | print('using CPU, this will be slow') 32 | use_gpu = False 33 | else: 34 | model = model.cuda() 35 | use_gpu = True 36 | 37 | # define loss function (criterion) and optimizer 38 | criterion = torch.nn.CrossEntropyLoss().cuda() 39 | 40 | if 'hdf5' in args.weights: 41 | from utils import model_load_hdf5 42 | model_load_hdf5(model, args.weights) 43 | elif os.path.isfile(args.weights): 44 | print("=> loading checkpoint '{}'".format(args.weights)) 45 | checkpoint = torch.load(args.weights) 46 | if 'state_dict' in checkpoint: 47 | checkpoint = checkpoint['state_dict'] 48 | ckpt = {k.replace('module.', ''):v for k,v in checkpoint.items()} # strip the names 49 | model.load_state_dict(ckpt) 50 | else: 51 | print("=> no checkpoint found at '{}'".format(args.weights)) 52 | 53 | 54 | cudnn.benchmark = True 55 | 56 | # Data loading code 57 | val_loader = get_default_ImageNet_val_loader(args) 58 | validate(val_loader, model, criterion, use_gpu) 59 | 60 | 61 | def validate(val_loader, model, criterion, use_gpu): 62 | batch_time = AverageMeter('Time', ':6.3f', warm=True) 63 | losses = AverageMeter('Loss', ':.4e') 64 | top1 = AverageMeter('Acc@1', ':6.2f') 65 | top5 = AverageMeter('Acc@5', ':6.2f') 66 | progress = ProgressMeter( 67 | len(val_loader), 68 | [batch_time, losses, top1, top5], 69 | prefix='Test: ') 70 | 71 | # switch to evaluate mode 72 | model.eval() 73 | 74 | with torch.no_grad(): 75 | #torch.cuda.synchronize() 76 | #end = time.time() 77 | for i, (images, target) in enumerate(val_loader): 78 | if use_gpu: 79 | images = images.cuda(non_blocking=True) 80 | target = target.cuda(non_blocking=True) 81 | 82 | # compute output 83 | torch.cuda.synchronize() 84 | end = time.time() 85 | 86 | output = model(images) 87 | 88 | # measure elapsed time 89 | torch.cuda.synchronize() 90 | batch_time.update(time.time() - end) 91 | 92 | loss = criterion(output, target) 93 | 94 | # measure accuracy and record loss 95 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 96 | losses.update(loss.item(), images.size(0)) 97 | top1.update(acc1[0], images.size(0)) 98 | top5.update(acc5[0], images.size(0)) 99 | 100 | if i % 10 == 0: 101 | progress.display(i) 102 | 103 | print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}' 104 | .format(top1=top1, top5=top5)) 105 | 106 | return top1.avg 107 | 108 | 109 | 110 | 111 | if __name__ == '__main__': 112 | test() -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | import shutil 5 | import time 6 | import warnings 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.parallel 11 | import torch.backends.cudnn as cudnn 12 | import torch.distributed as dist 13 | import torch.optim 14 | import torch.multiprocessing as mp 15 | import torch.utils.data 16 | import torch.utils.data.distributed 17 | from torch.optim.lr_scheduler import CosineAnnealingLR 18 | from utils import WarmupCosineAnnealingLR 19 | from utils import AverageMeter, accuracy, ProgressMeter, get_default_ImageNet_val_loader, get_default_ImageNet_train_sampler_loader, log_msg 20 | 21 | IMAGENET_TRAINSET_SIZE = 1281167 22 | 23 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 24 | parser.add_argument('--data', metavar='DIR', type=str, help='path to dataset') 25 | 26 | parser.add_argument('-a', '--arch', metavar='ARCH', default='ResNet-18') 27 | parser.add_argument('-t', '--blocktype', metavar='BLK', 
default='base') 28 | 29 | parser.add_argument('-j', '--workers', default=64, type=int, metavar='N', 30 | help='number of data loading workers (default: 4)') 31 | parser.add_argument('--epochs', default=120, type=int, metavar='N', 32 | help='number of total epochs to run') 33 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 34 | help='manual epoch number (useful on restarts)') 35 | parser.add_argument('-b', '--batch-size', default=256, type=int, 36 | metavar='N', 37 | help='mini-batch size (default: 256), this is the total ' 38 | 'batch size of all GPUs on the current node when ' 39 | 'using `Data Parallel or Distributed Data Parallel') 40 | parser.add_argument('--val-batch-size', default=100, type=int, metavar='V', 41 | help='validation batch size') 42 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, 43 | metavar='LR', help='initial learning rate', dest='lr') 44 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 45 | help='momentum') 46 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, 47 | metavar='W', help='weight decay (default: 1e-4)', 48 | dest='weight_decay') 49 | parser.add_argument('-p', '--print-freq', default=10, type=int, 50 | metavar='N', help='print frequency (default: 10)') 51 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 52 | help='path to latest checkpoint (default: none)') 53 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 54 | help='evaluate model on validation set') 55 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 56 | help='use pre-trained model') 57 | parser.add_argument('--world-size', default=1, type=int, 58 | help='number of nodes for distributed training') 59 | parser.add_argument('--rank', default=0, type=int, 60 | help='node rank for distributed training') 61 | parser.add_argument('--dist-url', default='20333', type=str, 62 | help='url used to set up distributed training') 63 | parser.add_argument('--dist-backend', default='nccl', type=str, 64 | help='distributed backend') 65 | parser.add_argument('--seed', default=None, type=int, 66 | help='seed for initializing training. ') 67 | parser.add_argument('--gpu', default=None, type=int, 68 | help='GPU id to use.') 69 | parser.add_argument('--multiprocessing-distributed', action='store_true', 70 | help='Use multi-processing distributed training to launch ' 71 | 'N processes per node, which has N GPUs. This is the ' 72 | 'fastest way to use PyTorch for either single node or ' 73 | 'multi node data parallel training') 74 | parser.add_argument('--custwd', dest='custwd', action='store_true', 75 | help='Use custom weight decay. It improves the accuracy and makes quantization easier.') 76 | parser.add_argument('--tag', default='', type=str, 77 | help='the tag for identifying the log and model files. Just a string.') 78 | 79 | best_acc1 = 0 80 | 81 | def sgd_optimizer(model, lr, momentum, weight_decay, use_custwd): 82 | params = [] 83 | for key, value in model.named_parameters(): 84 | if not value.requires_grad: 85 | continue 86 | apply_weight_decay = weight_decay 87 | apply_lr = lr 88 | 89 | if (use_custwd and ('blk' in key)) or (use_custwd and ('rbr_dense' in key or 'rbr_1x1' in key)) or 'bias' in key or 'bn' in key or 'prior' in key: 90 | apply_weight_decay = 0 91 | print('set weight decay=0 for {}'.format(key)) 92 | if 'bias' in key: 93 | apply_lr = 2 * lr # Just a Caffe-style common practice. Made no difference. 
94 | params += [{'params': [value], 'lr': apply_lr, 'weight_decay': apply_weight_decay}] 95 | optimizer = torch.optim.SGD(params, lr, momentum=momentum) 96 | return optimizer 97 | 98 | def adam_optimizer(model): 99 | params = [] 100 | for key, value in model.named_parameters(): 101 | if not value.requires_grad or not 'prior' in key: 102 | continue 103 | params += [{'params': [value], 'lr': 1e-4, 'betas': (0.5, 0.999)}] 104 | print('adam opt for {}'.format(key)) 105 | optimizer = torch.optim.Adam(params, 1e-4) 106 | return optimizer 107 | 108 | def main(): 109 | args = parser.parse_args() 110 | 111 | args.multiprocessing_distributed = True 112 | args.dist_backend = 'nccl' 113 | #args.data = '/gruntdata/humu.hm/ILSVRC2015/Data/CLS-LOC/' 114 | args.data='/disk1/humu.hm/ImageNet/ILSVRC2015/Data/CLS-LOC/' 115 | args.dist_url = 'tcp://127.0.0.1:' + args.dist_url 116 | 117 | if args.seed is not None: 118 | random.seed(args.seed) 119 | torch.manual_seed(args.seed) 120 | cudnn.deterministic = True 121 | warnings.warn('You have chosen to seed training. ' 122 | 'This will turn on the CUDNN deterministic setting, ' 123 | 'which can slow down your training considerably! ' 124 | 'You may see unexpected behavior when restarting ' 125 | 'from checkpoints.') 126 | 127 | if args.gpu is not None: 128 | warnings.warn('You have chosen a specific GPU. This will completely ' 129 | 'disable data parallelism.') 130 | 131 | if args.dist_url == "env://" and args.world_size == -1: 132 | args.world_size = int(os.environ["WORLD_SIZE"]) 133 | 134 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed 135 | 136 | ngpus_per_node = torch.cuda.device_count() 137 | if args.multiprocessing_distributed: 138 | # Since we have ngpus_per_node processes per node, the total world_size 139 | # needs to be adjusted accordingly 140 | args.world_size = ngpus_per_node * args.world_size 141 | # Use torch.multiprocessing.spawn to launch distributed processes: the 142 | # main_worker process function 143 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) 144 | else: 145 | # Simply call main_worker function 146 | main_worker(args.gpu, ngpus_per_node, args) 147 | 148 | 149 | def main_worker(gpu, ngpus_per_node, args): 150 | global best_acc1 151 | args.gpu = gpu 152 | log_file = 'results/train_{}_{}_{}_exp.txt'.format(args.arch, args.blocktype, args.tag) 153 | 154 | if args.gpu is not None: 155 | print("Use GPU: {} for training".format(args.gpu)) 156 | 157 | if args.distributed: 158 | if args.dist_url == "env://" and args.rank == -1: 159 | args.rank = int(os.environ["RANK"]) 160 | if args.multiprocessing_distributed: 161 | # For multiprocessing distributed training, rank needs to be the 162 | # global rank among all the processes 163 | args.rank = args.rank * ngpus_per_node + gpu 164 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 165 | world_size=args.world_size, rank=args.rank) 166 | 167 | from convnet_utils import switch_deploy_flag, switch_conv_bn_impl, build_model 168 | switch_deploy_flag(False) 169 | switch_conv_bn_impl(args.blocktype) 170 | model = build_model(args.arch) 171 | 172 | is_main = not args.multiprocessing_distributed or ( 173 | args.multiprocessing_distributed and args.rank % ngpus_per_node == 0) 174 | 175 | if is_main: 176 | for n, p in model.named_parameters(): 177 | print(n, p.size()) 178 | for n, p in model.named_buffers(): 179 | print(n, p.size()) 180 | log_msg('epochs {}, lr {}, weight_decay {}'.format(args.epochs, args.lr, 
args.weight_decay), log_file) 181 | 182 | if not torch.cuda.is_available(): 183 | print('using CPU, this will be slow') 184 | elif args.distributed: 185 | # For multiprocessing distributed, DistributedDataParallel constructor 186 | # should always set the single device scope, otherwise, 187 | # DistributedDataParallel will use all available devices. 188 | if args.gpu is not None: 189 | torch.cuda.set_device(args.gpu) 190 | model.cuda(args.gpu) 191 | # When using a single GPU per process and per 192 | # DistributedDataParallel, we need to divide the batch size 193 | # ourselves based on the total number of GPUs we have 194 | args.batch_size = int(args.batch_size / ngpus_per_node) 195 | args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node) 196 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 197 | else: 198 | model.cuda() 199 | # DistributedDataParallel will divide and allocate batch_size to all 200 | # available GPUs if device_ids are not set 201 | model = torch.nn.parallel.DistributedDataParallel(model) 202 | elif args.gpu is not None: 203 | torch.cuda.set_device(args.gpu) 204 | model = model.cuda(args.gpu) 205 | else: 206 | # DataParallel will divide and allocate batch_size to all available GPUs 207 | model = torch.nn.DataParallel(model).cuda() 208 | 209 | 210 | # define loss function (criterion) and optimizer 211 | criterion = nn.CrossEntropyLoss().cuda(args.gpu) 212 | 213 | optimizer = sgd_optimizer(model, args.lr, args.momentum, args.weight_decay, args.custwd) 214 | #arch_optimizer = adam_optimizer(model) 215 | 216 | #lr_scheduler = CosineAnnealingLR(optimizer=optimizer, T_max=args.epochs * IMAGENET_TRAINSET_SIZE // args.batch_size // ngpus_per_node) 217 | lr_scheduler = WarmupCosineAnnealingLR(optimizer=optimizer, T_cosine_max=args.epochs * IMAGENET_TRAINSET_SIZE // args.batch_size // ngpus_per_node, warmup=args.epochs/24) 218 | 219 | # optionally resume from a checkpoint 220 | if args.resume: 221 | if os.path.isfile(args.resume): 222 | print("=> loading checkpoint '{}'".format(args.resume)) 223 | if args.gpu is None: 224 | checkpoint = torch.load(args.resume) 225 | else: 226 | # Map model to be loaded to specified single gpu. 
227 | loc = 'cuda:{}'.format(args.gpu) 228 | checkpoint = torch.load(args.resume, map_location=loc) 229 | args.start_epoch = checkpoint['epoch'] 230 | best_acc1 = checkpoint['best_acc1'] 231 | if args.gpu is not None: 232 | # best_acc1 may be from a checkpoint from a different GPU 233 | best_acc1 = best_acc1.to(args.gpu) 234 | model.load_state_dict(checkpoint['state_dict']) 235 | optimizer.load_state_dict(checkpoint['optimizer']) 236 | lr_scheduler.load_state_dict(checkpoint['scheduler']) 237 | print("=> loaded checkpoint '{}' (epoch {})" 238 | .format(args.resume, checkpoint['epoch'])) 239 | else: 240 | print("=> no checkpoint found at '{}'".format(args.resume)) 241 | 242 | cudnn.benchmark = True 243 | 244 | train_sampler, train_loader = get_default_ImageNet_train_sampler_loader(args) 245 | val_loader = get_default_ImageNet_val_loader(args) 246 | 247 | if args.evaluate: 248 | validate(val_loader, model, criterion, args) 249 | return 250 | 251 | for epoch in range(args.start_epoch, args.epochs): 252 | if args.distributed: 253 | train_sampler.set_epoch(epoch) 254 | # adjust_learning_rate(optimizer, epoch, args) 255 | 256 | # train for one epoch 257 | train(train_loader, model, criterion, optimizer, epoch, args, lr_scheduler, is_main=is_main) 258 | 259 | if is_main: 260 | # evaluate on validation set 261 | acc1 = validate(val_loader, model, criterion, args) 262 | msg = '{}, epoch {}, acc {}'.format(args.arch, epoch, acc1) 263 | log_msg(msg, log_file) 264 | 265 | # remember best acc@1 and save checkpoint 266 | is_best = acc1 > best_acc1 267 | best_acc1 = max(acc1, best_acc1) 268 | 269 | save_checkpoint({ 270 | 'epoch': epoch + 1, 271 | 'arch': args.arch, 272 | 'state_dict': model.state_dict(), 273 | 'best_acc1': best_acc1, 274 | 'optimizer' : optimizer.state_dict(), 275 | 'scheduler': lr_scheduler.state_dict(), 276 | }, is_best, filename = '/disk1/humu.hm/ckpt/{}_{}_{}.pth.tar'.format(args.arch, args.blocktype, args.tag), 277 | best_filename='/disk1/humu.hm/ckpt/{}_{}_{}_best.pth.tar'.format(args.arch, args.blocktype, args.tag)) 278 | 279 | 280 | def train(train_loader, model, criterion, optimizer, epoch, args, lr_scheduler, is_main): 281 | batch_time = AverageMeter('Time', ':6.3f', warm=True) 282 | data_time = AverageMeter('Data', ':6.3f') 283 | losses = AverageMeter('Loss', ':.4e') 284 | top1 = AverageMeter('Acc@1', ':6.2f') 285 | top5 = AverageMeter('Acc@5', ':6.2f') 286 | progress = ProgressMeter( 287 | len(train_loader), 288 | [batch_time, data_time, losses, top1, top5, ], 289 | prefix="Epoch: [{}]".format(epoch)) 290 | 291 | # switch to train mode 292 | model.train() 293 | 294 | torch.cuda.synchronize() 295 | end = time.time() 296 | for i, (images, target) in enumerate(train_loader): 297 | # measure data loading time 298 | torch.cuda.synchronize() 299 | data_time.update(time.time() - end) 300 | 301 | images = images.cuda(args.gpu, non_blocking=True) 302 | target = target.cuda(args.gpu, non_blocking=True) 303 | 304 | # compute output 305 | output = model(images) 306 | loss = criterion(output, target) 307 | 308 | if args.custwd: 309 | for module in model.modules(): 310 | if hasattr(module, 'get_custom_L2'): 311 | loss += args.weight_decay * 0.5 * module.get_custom_L2() 312 | #if hasattr(module, 'weight_gen'): 313 | #loss += args.weight_decay * 0.5 * ((module.weight_gen()**2).sum()) 314 | #if hasattr(module, 'weight'): 315 | #loss += args.weight_decay * 0.5 * ((module.weight**2).sum()) 316 | 317 | # measure accuracy and record loss 318 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 
319 | losses.update(loss.item(), images.size(0)) 320 | top1.update(acc1[0], images.size(0)) 321 | top5.update(acc5[0], images.size(0)) 322 | 323 | # compute gradient and do SGD step 324 | optimizer.zero_grad() 325 | loss.backward() 326 | optimizer.step() 327 | 328 | # measure elapsed time 329 | torch.cuda.synchronize() 330 | batch_time.update(time.time() - end) 331 | end = time.time() 332 | 333 | lr_scheduler.step() 334 | 335 | if is_main and i % args.print_freq == 0: 336 | progress.display(i) 337 | if is_main and i % 1000 == 0: 338 | print('cur lr: ', lr_scheduler.get_lr()[0]) 339 | 340 | 341 | 342 | def validate(val_loader, model, criterion, args): 343 | batch_time = AverageMeter('Time', ':6.3f') 344 | losses = AverageMeter('Loss', ':.4e') 345 | top1 = AverageMeter('Acc@1', ':6.2f') 346 | top5 = AverageMeter('Acc@5', ':6.2f') 347 | progress = ProgressMeter( 348 | len(val_loader), 349 | [batch_time, losses, top1, top5], 350 | prefix='Test: ') 351 | 352 | # switch to evaluate mode 353 | model.eval() 354 | 355 | with torch.no_grad(): 356 | end = time.time() 357 | for i, (images, target) in enumerate(val_loader): 358 | if args.gpu is not None: 359 | images = images.cuda(args.gpu, non_blocking=True) 360 | if torch.cuda.is_available(): 361 | target = target.cuda(args.gpu, non_blocking=True) 362 | 363 | # compute output 364 | output = model(images) 365 | loss = criterion(output, target) 366 | 367 | # measure accuracy and record loss 368 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 369 | losses.update(loss.item(), images.size(0)) 370 | top1.update(acc1[0], images.size(0)) 371 | top5.update(acc5[0], images.size(0)) 372 | 373 | # measure elapsed time 374 | batch_time.update(time.time() - end) 375 | end = time.time() 376 | 377 | if i % args.print_freq == 0: 378 | progress.display(i) 379 | 380 | print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}' 381 | .format(top1=top1, top5=top5)) 382 | 383 | return top1.avg 384 | 385 | 386 | def save_checkpoint(state, is_best, filename, best_filename): 387 | torch.save(state, filename) 388 | if is_best: 389 | shutil.copyfile(filename, best_filename) 390 | 391 | 392 | if __name__ == '__main__': 393 | main() -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import torchvision.datasets as datasets 4 | import os 5 | import torchvision.transforms as transforms 6 | import PIL 7 | 8 | from PIL import ImageFile 9 | ImageFile.LOAD_TRUNCATED_IMAGES = True 10 | 11 | class AverageMeter(object): 12 | """Computes and stores the average and current value""" 13 | def __init__(self, name, fmt=':f', warm=False): 14 | self.name = name 15 | self.fmt = fmt 16 | self.reset() 17 | 18 | self.time_thres = -1 19 | if warm: 20 | self.time_thres = 20 21 | 22 | def reset(self): 23 | self.val = 0 24 | self.avg = 0 25 | self.sum = 0 26 | self.count = 0 27 | 28 | def update(self, val, n=1): 29 | if self.time_thres > 0: 30 | self.time_thres -= 1 31 | return 32 | 33 | self.val = val 34 | self.sum += val * n 35 | self.count += n 36 | self.avg = self.sum / self.count 37 | 38 | def __str__(self): 39 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 40 | return fmtstr.format(**self.__dict__) 41 | 42 | 43 | class ProgressMeter(object): 44 | def __init__(self, num_batches, meters, prefix=""): 45 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 46 | self.meters = meters 47 | self.prefix = prefix 48 | 49 | def 
display(self, batch): 50 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 51 | entries += [str(meter) for meter in self.meters] 52 | print('\t'.join(entries)) 53 | 54 | def _get_batch_fmtstr(self, num_batches): 55 | num_digits = len(str(num_batches // 1)) 56 | fmt = '{:' + str(num_digits) + 'd}' 57 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 58 | 59 | 60 | def accuracy(output, target, topk=(1,)): 61 | """Computes the accuracy over the k top predictions for the specified values of k""" 62 | with torch.no_grad(): 63 | maxk = max(topk) 64 | batch_size = target.size(0) 65 | 66 | _, pred = output.topk(maxk, 1, True, True) 67 | pred = pred.t() 68 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 69 | 70 | res = [] 71 | for k in topk: 72 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 73 | res.append(correct_k.mul_(100.0 / batch_size)) 74 | return res 75 | 76 | def load_checkpoint(model, ckpt_path): 77 | checkpoint = torch.load(ckpt_path) 78 | if 'state_dict' in checkpoint: 79 | checkpoint = checkpoint['state_dict'] 80 | ckpt = {} 81 | for k, v in checkpoint.items(): 82 | if k.startswith('module.'): 83 | ckpt[k[7:]] = v 84 | else: 85 | ckpt[k] = v 86 | model.load_state_dict(ckpt) 87 | 88 | def read_hdf5(file_path): 89 | import h5py 90 | import numpy as np 91 | result = {} 92 | with h5py.File(file_path, 'r') as f: 93 | for k in f.keys(): 94 | value = np.asarray(f[k]) 95 | result[str(k).replace('+', '/')] = value 96 | print('read {} arrays from {}'.format(len(result), file_path)) 97 | f.close() 98 | return result 99 | 100 | def model_load_hdf5(model:torch.nn.Module, hdf5_path, ignore_keys='stage0.'): 101 | weights_dict = read_hdf5(hdf5_path) 102 | for name, param in model.named_parameters(): 103 | print('load param: ', name, param.size()) 104 | if name in weights_dict: 105 | np_value = weights_dict[name] 106 | else: 107 | np_value = weights_dict[name.replace(ignore_keys, '')] 108 | value = torch.from_numpy(np_value).float() 109 | assert tuple(value.size()) == tuple(param.size()) 110 | param.data = value 111 | for name, param in model.named_buffers(): 112 | print('load buffer: ', name, param.size()) 113 | if name in weights_dict: 114 | np_value = weights_dict[name] 115 | else: 116 | np_value = weights_dict[name.replace(ignore_keys, '')] 117 | value = torch.from_numpy(np_value).float() 118 | assert tuple(value.size()) == tuple(param.size()) 119 | param.data = value 120 | 121 | class WarmupCosineAnnealingLR(torch.optim.lr_scheduler._LRScheduler): 122 | 123 | def __init__(self, optimizer, T_cosine_max, eta_min=0, last_epoch=-1, warmup=0): 124 | self.eta_min = eta_min 125 | self.T_cosine_max = T_cosine_max 126 | self.warmup = warmup 127 | super(WarmupCosineAnnealingLR, self).__init__(optimizer, last_epoch) 128 | 129 | def get_lr(self): 130 | if self.last_epoch < self.warmup: 131 | return [self.last_epoch / self.warmup * base_lr for base_lr in self.base_lrs] 132 | else: 133 | return [self.eta_min + (base_lr - self.eta_min) * 134 | (1 + math.cos(math.pi * (self.last_epoch - self.warmup) / (self.T_cosine_max - self.warmup))) / 2 135 | for base_lr in self.base_lrs] 136 | 137 | 138 | def log_msg(message, log_file): 139 | print(message) 140 | with open(log_file, 'a') as f: 141 | print(message, file=f) 142 | 143 | def get_ImageNet_train_dataset(args, trans): 144 | traindir = os.path.join(args.data, 'train') 145 | train_dataset = datasets.ImageFolder(traindir, trans) 146 | return train_dataset 147 | 148 | 149 | def get_ImageNet_val_dataset(args, trans): 150 | 
traindir = os.path.join(args.data, 'val') 151 | val_dataset = datasets.ImageFolder(traindir, trans) 152 | return val_dataset 153 | 154 | 155 | def get_default_train_trans(args): 156 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 157 | std=[0.229, 0.224, 0.225]) 158 | if (not hasattr(args, 'resolution')) or args.resolution == 224: 159 | trans = transforms.Compose([ 160 | transforms.RandomResizedCrop(224), 161 | transforms.RandomHorizontalFlip(), 162 | transforms.ToTensor(), 163 | normalize]) 164 | else: 165 | raise ValueError('Not yet implemented.') 166 | return trans 167 | 168 | 169 | def get_default_val_trans(args): 170 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 171 | std=[0.229, 0.224, 0.225]) 172 | if (not hasattr(args, 'resolution')) or args.resolution == 224: 173 | trans = transforms.Compose([ 174 | transforms.Resize(256), 175 | transforms.CenterCrop(224), 176 | transforms.ToTensor(), 177 | normalize]) 178 | else: 179 | trans = transforms.Compose([ 180 | transforms.Resize(args.resolution, interpolation=PIL.Image.BILINEAR), 181 | transforms.CenterCrop(args.resolution), 182 | transforms.ToTensor(), 183 | normalize, 184 | ]) 185 | return trans 186 | 187 | 188 | def get_default_ImageNet_train_sampler_loader(args): 189 | train_trans = get_default_train_trans(args) 190 | train_dataset = get_ImageNet_train_dataset(args, train_trans) 191 | if args.distributed: 192 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 193 | else: 194 | train_sampler = None 195 | train_loader = torch.utils.data.DataLoader( 196 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), 197 | num_workers=args.workers, pin_memory=True, sampler=train_sampler) 198 | return train_sampler, train_loader 199 | 200 | 201 | def get_default_ImageNet_val_loader(args): 202 | val_trans = get_default_val_trans(args) 203 | val_dataset = get_ImageNet_val_dataset(args, val_trans) 204 | val_loader = torch.utils.data.DataLoader( 205 | val_dataset, 206 | batch_size=args.val_batch_size, shuffle=False, 207 | num_workers=args.workers, pin_memory=True) 208 | return val_loader --------------------------------------------------------------------------------
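
A hedged sketch of the custom weight-decay path used in train(): when --custwd is set, sgd_optimizer() assigns weight_decay=0 to the affected parameters inside the optimizer, and train() instead adds args.weight_decay * 0.5 * module.get_custom_L2() to the loss for every module that defines get_custom_L2(). The toy module below is hypothetical (the repo's real implementations live in blocks.py); train() only requires that a module expose a get_custom_L2() method returning a scalar tensor.

```python
import torch
import torch.nn as nn

class ToyCustomWDConv(nn.Module):
    """Hypothetical module exposing the get_custom_L2() hook that train() looks for."""
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3, padding=1)

    def forward(self, x):
        return self.conv(x)

    def get_custom_L2(self):
        # Squared L2 norm of the kernel; a real block would fold its BN scale
        # into the kernel before penalizing it.
        return (self.conv.weight ** 2).sum()

model = nn.Sequential(ToyCustomWDConv(), nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 10))
criterion = nn.CrossEntropyLoss()
weight_decay = 1e-4

x = torch.randn(4, 3, 32, 32)
target = torch.randint(0, 10, (4,))
loss = criterion(model(x), target)
for module in model.modules():                 # same traversal as in train()
    if hasattr(module, 'get_custom_L2'):
        loss = loss + weight_decay * 0.5 * module.get_custom_L2()
loss.backward()
```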
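
Why the training loop calls train_sampler.set_epoch(epoch) when args.distributed is set: DistributedSampler derives its shuffle order from (seed + epoch), so without set_epoch every epoch would replay the same permutation on every rank. A minimal illustration with a toy dataset; num_replicas and rank are passed explicitly here only so that no process group is needed.

```python
import torch
from torch.utils.data.distributed import DistributedSampler

dataset = list(range(12))                      # stand-in for the ImageFolder dataset
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)

sampler.set_epoch(0)
order_epoch0 = list(iter(sampler))
sampler.set_epoch(1)
order_epoch1 = list(iter(sampler))
print(order_epoch0, order_epoch1)              # typically two different permutations of this rank's shard
```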
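
A minimal standalone sketch of WarmupCosineAnnealingLR from utils.py. Because the scheduler simply counts calls to step(), and train() steps it once per batch, both T_cosine_max and warmup are measured in optimizer steps (train.py derives T_cosine_max from epochs * IMAGENET_TRAINSET_SIZE // batch_size // ngpus_per_node). The model, step counts, and base learning rate below are toy values chosen only to show the schedule shape.

```python
import torch
from utils import WarmupCosineAnnealingLR

toy_model = torch.nn.Linear(8, 2)              # stand-in for build_model(args.arch)
optimizer = torch.optim.SGD(toy_model.parameters(), lr=0.1, momentum=0.9)

total_steps, warmup_steps = 600, 25            # toy sizes, in optimizer steps
scheduler = WarmupCosineAnnealingLR(optimizer, T_cosine_max=total_steps, warmup=warmup_steps)

lrs = []
for _ in range(total_steps):
    optimizer.step()                           # dummy step; no loss/backward needed for this demo
    scheduler.step()                           # train() also steps once per batch
    lrs.append(optimizer.param_groups[0]['lr'])
# lrs rises linearly to 0.1 over the first 25 steps, then follows a cosine decay towards 0.
```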