├── LICENSE
├── README.md
├── basic.py
├── blocks.py
├── blocks_repvgg.py
├── convert.py
├── convnet_utils.py
├── dbb_transforms.py
├── images
│   ├── budgets.png
│   ├── coco_cs.PNG
│   ├── imagenet1.PNG
│   ├── imagenet2.PNG
│   ├── intro.png
│   ├── norm.PNG
│   ├── overview.png
│   ├── similarity.png
│   └── supp_grad.png
├── models
│   ├── repvgg.py
│   ├── resnet.py
│   └── resnext.py
├── test.py
├── train.py
└── utils.py
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # OREPA: Online Convolutional Re-parameterization
2 | This repo is the PyTorch implementation of our CVPR 2022 paper ["Online Convolutional Re-parameterization"](https://arxiv.org/abs/2204.00826), authored by
3 | Mu Hu, [Junyi Feng](https://github.com/Sixkplus), [Jiashen Hua](https://github.com/JerExJs), Baisheng Lai, Jianqiang Huang, [Xiaojin Gong](https://person.zju.edu.cn/en/gongxj) and [Xiansheng Hua](https://damo.alibaba.com/labs/city-brain) from Zhejiang University and Alibaba Cloud.
4 |
5 | ## What is Structural Re-parameterization?
6 | + Re-parameterization (Re-param) means that different architectures can be mutually converted through equivalent transformations of their parameters. For example, a 1x1-convolution branch and a 3x3-convolution branch can be merged into a single 3x3-convolution branch for faster inference (see the sketch below).
7 | + When the model for deployment is fixed, re-param can be regarded as finding a more complex training-time structure that can be transformed back into the original one, yielding performance improvements for free.
8 |
9 |
10 |
11 |
12 |
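Below is a minimal sketch (not part of this repo) of such a merge: the 1x1 kernel is zero-padded to 3x3 and added to the 3x3 kernel, and the single merged convolution reproduces the two-branch output.

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 8, 14, 14)
k3 = torch.randn(16, 8, 3, 3)   # 3x3 branch
k1 = torch.randn(16, 8, 1, 1)   # 1x1 branch

# Training-time structure: two branches, two convolutions.
y_two_branch = F.conv2d(x, k3, padding=1) + F.conv2d(x, k1)

# Inference-time structure: zero-pad the 1x1 kernel to 3x3 and merge the kernels.
k_merged = k3 + F.pad(k1, [1, 1, 1, 1])
y_merged = F.conv2d(x, k_merged, padding=1)

print(torch.allclose(y_two_branch, y_merged, atol=1e-5))  # True
```
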
13 | ## Why do we propose Online RE-PAram? (OREPA)
14 | + While current re-param blocks ([ACNet](https://github.com/DingXiaoH/ACNet), [ExpandNet](https://github.com/GUOShuxuan/expandnets), [ACNetv2](https://github.com/DingXiaoH/DiverseBranchBlock), *etc.*) remain affordable for small models, the more complicated designs needed for further performance gains on larger models can lead to unaffordable training budgets.
15 | + We observe that batch normalization (norm) layers play a significant role in re-param blocks, yet their training-time non-linearity prevents us from reducing the computational cost during training.
16 |
17 | 
18 |
19 | ## What is OREPA?
20 | OREPA is a two-step pipeline.
21 | + Linearization: Replace the branch-wise norm layers with scaling layers to enable the linear squeezing of a multi-branch/multi-layer topology.
22 | + Squeezing: Squeeze the linearized block into a single layer, so that the convolution over the feature maps is performed only once instead of multiple times (see the sketch below).
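A minimal sketch of the squeezing step (illustrative shapes only; the actual implementation is `weight_gen()` in `blocks.py`): once every branch ends in a linear, channel-wise scaling layer, the branch kernels can be scaled and summed into one kernel, so the feature maps are convolved only once.

```python
import torch
import torch.nn.functional as F

x   = torch.randn(1, 8, 14, 14)
k_a = torch.randn(16, 8, 3, 3)   # kernel of branch A
k_b = torch.randn(16, 8, 3, 3)   # kernel of branch B
s_a = torch.rand(16)             # channel-wise scaling (replaces branch-wise norm)
s_b = torch.rand(16)

# Multi-branch form: one convolution per branch, then channel-wise scaling.
y_multi = (F.conv2d(x, k_a, padding=1) * s_a.view(1, -1, 1, 1)
           + F.conv2d(x, k_b, padding=1) * s_b.view(1, -1, 1, 1))

# Squeezed form: scale and sum the kernels first, then convolve once.
k_single = (torch.einsum('oihw,o->oihw', k_a, s_a)
            + torch.einsum('oihw,o->oihw', k_b, s_b))
y_single = F.conv2d(x, k_single, padding=1)

print(torch.allclose(y_multi, y_single, atol=1e-5))  # True
```
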
23 |
24 | 
25 |
26 | ## How does OREPA work?
27 | + With OREPA we reduce the training budget while keeping comparable performance. We then improve accuracy with additional components, which bring only minor extra training cost since they are merged in an online scheme.
28 | + We theoretically show that removing the branch-wise norm layers risks degrading a multi-branch structure into a single-branch one (parallel purely linear branches collapse into one equivalent kernel), indicating that the norm-to-scaling replacement is critical for preserving branch diversity.
29 |
30 | 
31 |
32 | ## ImageNet Results
33 |
34 |
35 |
36 |
37 | 
38 |
39 | Create a new issue for any code-related questions. For paper-related questions, feel free to contact me directly at muhu@zju.edu.cn.
40 |
41 | ## Contents
42 | 1. [Dependency](#dependency)
43 | 2. [Checkpoints](#checkpoints)
44 | 3. [Training](#training)
45 | 4. [Evaluation](#evaluation)
46 | 5. [Transfer Learning on COCO and Cityscapes](#transfer-learning-on-coco-and-cityscapes)
47 | 6. [About Quantization and Gradient Tweaking](#about-quantization-and-gradient-tweaking)
48 | 7. [Citation](#citation)
49 |
50 |
51 | ## Dependency
52 | Models released in this work are trained and tested on:
53 | + CentOS Linux
54 | + Python 3.8.8 (Anaconda 4.9.1)
55 | + PyTorch 1.9.0 / torchvision 0.10.0
56 | + NVIDIA CUDA 10.2
57 | + 4x NVIDIA V100 GPUs
58 |
59 | ```bash
60 | pip install torch torchvision
61 | pip install numpy matplotlib Pillow
62 | pip install scikit-image
63 | ```
64 |
65 | ## Checkpoints
66 | Download our pre-trained models with OREPA:
67 | - [ResNet-18](https://drive.google.com/file/d/1Z0cfmxLLWD2xjgUXpE2Y_4vxS4ij16Sx/view?usp=sharing)
68 | - [ResNet-34](https://drive.google.com/file/d/1tOC4yoslHF829Yb66_eE4qsFoazf6uoY/view?usp=sharing)
69 | - [ResNet-50](https://drive.google.com/file/d/1Jn92wGHlFzkPjQxAmYdRU_EbSRQjsw0v/view?usp=sharing)
70 | - [ResNet-101](https://drive.google.com/file/d/10L4fwYlB21vlMKOGIyXdlGBbf3l8nflI/view?usp=sharing)
71 | - [RepVGG-A0](https://drive.google.com/file/d/1r674bxWKL5dwDA8OPEuGWN8zbiUGkqsx/view?usp=sharing)
72 | - [RepVGG-A1](https://drive.google.com/file/d/1NVRn8Xave-0jY3R0xAXCShdBcML6f-3q/view?usp=sharing)
73 | - [RepVGG-A2](https://drive.google.com/file/d/1ImHkTct0ACDtOw8sPgKhqerBFigg32s3/view?usp=sharing)
74 | - [WideResNet-18(x2)](https://drive.google.com/file/d/1gseMlq9JGntggyZoK_Bjt7xiSPNs5deF/view?usp=sharing)
75 | - [ResNeXt-50](https://drive.google.com/file/d/1p_eZDhMHQ_xmfXd8Z1RZnWfgwUbi7WH-/view?usp=sharing)
76 |
77 | Note that the pre-trained models do not need to be decompressed; just load the .pth.tar files directly (a loading sketch follows below).
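A minimal loading sketch, assuming an ImageNet-style checkpoint layout (a `state_dict` entry with possibly `module.`-prefixed keys from `DataParallel`); adjust the keys if your file differs:

```python
import torch

def load_orepa_checkpoint(model, ckpt_path):
    # Hypothetical helper: the exact key layout depends on how the checkpoint was saved.
    ckpt = torch.load(ckpt_path, map_location='cpu')
    state_dict = ckpt.get('state_dict', ckpt) if isinstance(ckpt, dict) else ckpt
    # Drop the 'module.' prefix that DataParallel adds to parameter names.
    state_dict = {k.replace('module.', '', 1): v for k, v in state_dict.items()}
    model.load_state_dict(state_dict)
    return model
```
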
78 |
79 | ## Training
80 | A complete list of training options is available with
81 | ```bash
82 | python train.py -h
83 | python test.py -h
84 | python convert.py -h
85 | ```
86 |
87 | 1. Train ResNets (ResNeXt and WideResNet included)
88 | ```bash
89 | CUDA_VISIBLE_DEVICES="0,1,2,3" python train.py -a ResNet-18 -t OREPA --data [imagenet-path]
90 | # -a for architecture (ResNet-18, ResNet-34, ResNet-50, ResNet-101, ResNet-18-2x, ResNeXt-50)
91 | # -t for re-param method (base, DBB, OREPA)
92 | ```
93 |
94 | 2. Train RepVGGs
95 | ```bash
96 | CUDA_VISIBLE_DEVICES="0,1,2,3" python train.py -a RepVGG-A0 -t OREPA_VGG --data [imagenet-path]
97 | # -a for architecture (RepVGG-A0, RepVGG-A1, RepVGG-A2)
98 | # -t for re-param method (base, RepVGG, OREPA_VGG)
99 | ```
100 |
101 | ## Evaluation
102 | 1. Use your self-trained model or our pretrained model
103 | ```bash
104 | CUDA_VISIBLE_DEVICES="0" python test.py train [trained-model-path] -a ResNet-18 -t OREPA
105 | ```
106 |
107 | 2. Convert the training-time models into inference-time models (a conceptual sketch follows after step 3)
108 | ```bash
109 | CUDA_VISIBLE_DEVICES="0" python convert.py [trained-model-path] [deploy-model-path-to-save] -a ResNet-18 -t OREPA
110 | ```
111 |
112 | 3. Evaluate with the converted model
113 | ```bash
114 | CUDA_VISIBLE_DEVICES="0" python test.py deploy [deploy-model-path] -a ResNet-18 -t OREPA
115 | ```
116 |
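Conceptually, the conversion relies on each block's `switch_to_deploy()` method (defined in `blocks.py`), which folds all branches and the norm layer into a single biased convolution. A minimal sketch with a standalone OREPA block:

```python
import torch
from blocks import OREPA

block = OREPA(in_channels=8, out_channels=16, kernel_size=3, padding=1)
block.eval()  # use running BN statistics so that the fusion is exact

x = torch.randn(1, 8, 14, 14)
y_train = block(x)

block.switch_to_deploy()  # fold all branches and the BN into one conv with bias
y_deploy = block(x)

print(torch.allclose(y_train, y_deploy, atol=1e-4))  # True
```
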
117 | ## Transfer Learning on COCO and Cityscapes
118 | We use the [mmdetection](https://github.com/open-mmlab/mmdetection) and [mmsegmentation](https://github.com/open-mmlab/mmsegmentation) tools on COCO and Cityscapes respectively. If you use our pretrained models for downstream tasks, it is strongly suggested that the learning rate of the first stem layer be carefully tuned (a sketch follows below), since the deep linear stem layer has a very different weight distribution from the vanilla one after ImageNet training. Contact [@Sixkplus](https://github.com/Sixkplus) (Junyi Feng) for more details on configurations and checkpoints of the reported ResNet-50-backbone models.
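A minimal sketch of such a per-layer learning-rate adjustment (illustrative only; the `stem` prefix is a placeholder, not necessarily the actual parameter name in the released backbones):

```python
import torch

def build_finetune_optimizer(model, base_lr=0.01, stem_lr=0.001):
    # 'stem' is a placeholder; substitute the real name of the first conv stage.
    stem_params, other_params = [], []
    for name, p in model.named_parameters():
        (stem_params if name.startswith('stem') else other_params).append(p)
    return torch.optim.SGD(
        [{'params': stem_params, 'lr': stem_lr},    # smaller LR for the deep linear stem
         {'params': other_params, 'lr': base_lr}],
        momentum=0.9, weight_decay=1e-4)
```
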
119 |
120 |
121 |
122 |
123 |
124 | ## About Quantization and Gradient Tweaking
125 | For re-param models, special weight regularization strategies are required for further quantization. Meanwhile, dynamic gradient tweaking or differentiable searching methods might greatly boost performance. We have not applied such techniques to OREPA yet, but they could probably be applied in our industrial usage in the future. To exchange experience on such topics, please contact [@Sixkplus](https://github.com/Sixkplus) (Junyi Feng).
126 |
127 |
128 | ## Citation
129 | If you use our code or method in your work, please cite the following:
130 |
131 | @inproceedings{hu22OREPA,
132 | title={Online Convolutional Re-parameterization},
133 | author={Mu Hu and Junyi Feng and Jiashen Hua and Baisheng Lai and Jianqiang Huang and Xiansheng Hua and Xiaojin Gong},
134 | booktitle={CVPR},
135 | year={2022}
136 | }
137 |
138 | ## Related Repositories
139 | The code of this work is developed upon Xiaohan Ding's re-param repositories ["Diverse Branch Block: Building a Convolution as an Inception-like Unit"](https://github.com/DingXiaoH/DiverseBranchBlock) and ["RepVGG: Making VGG-style ConvNets Great Again"](https://github.com/DingXiaoH/RepVGG), with similar protocols. [Xiaohan Ding](https://scholar.google.com/citations?user=CIjw0KoAAAAJ&hl=en) is a Ph.D. from Tsinghua University and an expert in structural re-parameterization.
140 |
--------------------------------------------------------------------------------
/basic.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | import torch.nn.init as init
6 | import math
7 | from dbb_transforms import transI_fusebn
8 |
9 | class ConvBN(nn.Module):
10 | def __init__(self, in_channels, out_channels, kernel_size,
11 | stride=1, padding=0, dilation=1, groups=1, deploy=False, nonlinear=None):
12 | super().__init__()
13 | if nonlinear is None:
14 | self.nonlinear = nn.Identity()
15 | else:
16 | self.nonlinear = nonlinear
17 | if deploy:
18 | self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
19 | stride=stride, padding=padding, dilation=dilation, groups=groups, bias=True)
20 | else:
21 | self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
22 | stride=stride, padding=padding, dilation=dilation, groups=groups, bias=False)
23 | self.bn = nn.BatchNorm2d(num_features=out_channels)
24 |
25 | def forward(self, x):
26 | if hasattr(self, 'bn'):
27 | return self.nonlinear(self.bn(self.conv(x)))
28 | else:
29 | return self.nonlinear(self.conv(x))
30 |
31 | def switch_to_deploy(self):
32 | kernel, bias = transI_fusebn(self.conv.weight, self.bn)
33 | conv = nn.Conv2d(in_channels=self.conv.in_channels, out_channels=self.conv.out_channels, kernel_size=self.conv.kernel_size,
34 | stride=self.conv.stride, padding=self.conv.padding, dilation=self.conv.dilation, groups=self.conv.groups, bias=True)
35 | conv.weight.data = kernel
36 | conv.bias.data = bias
37 | for para in self.parameters():
38 | para.detach_()
39 | self.__delattr__('conv')
40 | self.__delattr__('bn')
41 | self.conv = conv
42 |
43 | class IdentityBasedConv1x1(nn.Conv2d):
44 |
45 | def __init__(self, channels, groups=1):
46 | super(IdentityBasedConv1x1, self).__init__(in_channels=channels, out_channels=channels, kernel_size=1, stride=1, padding=0, groups=groups, bias=False)
47 |
48 | assert channels % groups == 0
49 | input_dim = channels // groups
50 | id_value = np.zeros((channels, input_dim, 1, 1))
51 | for i in range(channels):
52 | id_value[i, i % input_dim, 0, 0] = 1
53 | self.id_tensor = torch.from_numpy(id_value).type_as(self.weight)
54 | nn.init.zeros_(self.weight)
55 |
56 | def forward(self, input):
57 | kernel = self.weight + self.id_tensor.to(self.weight.device)
58 | result = F.conv2d(input, kernel, None, stride=1, padding=0, dilation=self.dilation, groups=self.groups)
59 | return result
60 |
61 | def get_actual_kernel(self):
62 | return self.weight + self.id_tensor.to(self.weight.device)
63 |
64 |
65 | class BNAndPadLayer(nn.Module):
66 | def __init__(self,
67 | pad_pixels,
68 | num_features,
69 | eps=1e-5,
70 | momentum=0.1,
71 | affine=True,
72 | track_running_stats=True):
73 | super(BNAndPadLayer, self).__init__()
74 | self.bn = nn.BatchNorm2d(num_features, eps, momentum, affine, track_running_stats)
75 | self.pad_pixels = pad_pixels
76 |
77 | def forward(self, input):
78 | output = self.bn(input)
79 | if self.pad_pixels > 0:
80 | if self.bn.affine:
81 | pad_values = self.bn.bias.detach() - self.bn.running_mean * self.bn.weight.detach() / torch.sqrt(self.bn.running_var + self.bn.eps)
82 | else:
83 | pad_values = - self.bn.running_mean / torch.sqrt(self.bn.running_var + self.bn.eps)
84 | output = F.pad(output, [self.pad_pixels] * 4)
85 | pad_values = pad_values.view(1, -1, 1, 1)
86 | output[:, :, 0:self.pad_pixels, :] = pad_values
87 | output[:, :, -self.pad_pixels:, :] = pad_values
88 | output[:, :, :, 0:self.pad_pixels] = pad_values
89 | output[:, :, :, -self.pad_pixels:] = pad_values
90 | return output
91 |
92 | @property
93 | def weight(self):
94 | return self.bn.weight
95 |
96 | @property
97 | def bias(self):
98 | return self.bn.bias
99 |
100 | @property
101 | def running_mean(self):
102 | return self.bn.running_mean
103 |
104 | @property
105 | def running_var(self):
106 | return self.bn.running_var
107 |
108 | @property
109 | def eps(self):
110 | return self.bn.eps
111 |
112 | class PriorFilter(nn.Module):
113 |
114 | def __init__(self, channels, stride, padding, width=4):
115 | super(PriorFilter, self).__init__()
116 | self.stride = stride
117 | self.width = width
118 | self.group = int(channels/width)
119 | self.padding = padding
120 | self.prior_tensor = torch.Tensor(self.group, 1, 3, 3)
121 |
122 | half_g = int(self.group/2)
123 | for i in range(self.group):
124 | for h in range(3):
125 | for w in range(3):
126 | if i < half_g:
127 | self.prior_tensor[i, 0, h, w] = math.cos(math.pi*(h+0.5)*(i+1)/3)
128 | else:
129 | self.prior_tensor[i, 0, h, w] = math.cos(math.pi*(w+0.5)*(i+1-half_g)/3)
130 |
131 | self.register_buffer('prior', self.prior_tensor)
132 |
133 | def forward(self, x):
134 | b, c, h, w = x.size()
135 | x = x.view(b*self.width, self.group, h, w)
136 | return F.conv2d(input=x, weight=self.prior, bias=None, stride=self.stride, padding=self.padding, groups=self.group).view(b, c, int(h/self.stride), int(w/self.stride))
137 |
138 | class PriorConv(nn.Module):
139 |
140 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding, width=4):
141 | super(PriorConv, self).__init__()
142 | self.stride = stride
143 | self.width = width
144 | self.group = int(in_channels/width)
145 | self.padding = padding
146 | self.prior_tensor = torch.Tensor(self.group, 3, 3)
147 |
148 | half_g = int(self.group/2)
149 | for i in range(self.group):
150 | for h in range(3):
151 | for w in range(3):
152 | if i < half_g:
153 | self.prior_tensor[i, h, w] = math.cos(math.pi*(h+0.5)*(i+1)/3)
154 | else:
155 | self.prior_tensor[i, h, w] = math.cos(math.pi*(w+0.5)*(i+1-half_g)/3)
156 |
157 | self.register_buffer('prior', self.prior_tensor)
158 |
159 | self.weight = torch.nn.Parameter(torch.FloatTensor(out_channels, width, self.group, kernel_size, kernel_size))
160 | init.kaiming_uniform_(self.weight, a=math.sqrt(5))
161 |
162 | def forward(self, x):
163 | c_out, wi, g, kh, kw = self.weight.size()
164 | weight = torch.einsum('gmn,cwgmn->cwgmn', self.prior, self.weight).view(c_out, int(wi*g), kh, kw)
165 | return F.conv2d(input=x, weight=weight, bias=None, stride=self.stride, padding=self.padding)
--------------------------------------------------------------------------------
/blocks.py:
--------------------------------------------------------------------------------
1 |
2 | from basic import *
3 | import numpy as np
4 | from dbb_transforms import transI_fusebn, transII_addbranch, transIII_1x1_kxk, transV_avg, transVI_multiscale
5 |
6 | class DBB(nn.Module):
7 |
8 | def __init__(self, in_channels, out_channels, kernel_size,
9 | stride=1, padding=0, dilation=1, groups=1,
10 | internal_channels_1x1_3x3=None,
11 | deploy=False, nonlinear=None, single_init=False):
12 | super(DBB, self).__init__()
13 | self.deploy = deploy
14 |
15 | if nonlinear is None:
16 | self.nonlinear = nn.Identity()
17 | else:
18 | self.nonlinear = nonlinear
19 |
20 | self.kernel_size = kernel_size
21 | self.out_channels = out_channels
22 | self.groups = groups
23 | assert padding == kernel_size // 2
24 |
25 | if deploy:
26 | self.dbb_reparam = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride,
27 | padding=padding, dilation=dilation, groups=groups, bias=True)
28 |
29 | else:
30 |
31 | self.dbb_origin = ConvBN(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups)
32 |
33 | self.dbb_avg = nn.Sequential()
34 | if groups < out_channels:
35 | self.dbb_avg.add_module('conv',
36 | nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1,
37 | stride=1, padding=0, groups=groups, bias=False))
38 | self.dbb_avg.add_module('bn', BNAndPadLayer(pad_pixels=padding, num_features=out_channels))
39 | self.dbb_avg.add_module('avg', nn.AvgPool2d(kernel_size=kernel_size, stride=stride, padding=0))
40 | self.dbb_1x1 = ConvBN(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=stride,
41 | padding=0, groups=groups)
42 | else:
43 | self.dbb_avg.add_module('avg', nn.AvgPool2d(kernel_size=kernel_size, stride=stride, padding=padding))
44 |
45 | self.dbb_avg.add_module('avgbn', nn.BatchNorm2d(out_channels))
46 |
47 |
48 | if internal_channels_1x1_3x3 is None:
49 | internal_channels_1x1_3x3 = in_channels if groups < out_channels else 2 * in_channels # For mobilenet, it is better to have 2X internal channels
50 |
51 | self.dbb_1x1_kxk = nn.Sequential()
52 | if internal_channels_1x1_3x3 == in_channels:
53 | self.dbb_1x1_kxk.add_module('idconv1', IdentityBasedConv1x1(channels=in_channels, groups=groups))
54 | else:
55 | self.dbb_1x1_kxk.add_module('conv1', nn.Conv2d(in_channels=in_channels, out_channels=internal_channels_1x1_3x3,
56 | kernel_size=1, stride=1, padding=0, groups=groups, bias=False))
57 | self.dbb_1x1_kxk.add_module('bn1', BNAndPadLayer(pad_pixels=padding, num_features=internal_channels_1x1_3x3, affine=True))
58 | self.dbb_1x1_kxk.add_module('conv2', nn.Conv2d(in_channels=internal_channels_1x1_3x3, out_channels=out_channels,
59 | kernel_size=kernel_size, stride=stride, padding=0, groups=groups, bias=False))
60 | self.dbb_1x1_kxk.add_module('bn2', nn.BatchNorm2d(out_channels))
61 |
62 | # The experiments reported in the paper used the default initialization of bn.weight (all as 1). But changing the initialization may be useful in some cases.
63 | if single_init:
64 | # Initialize the bn.weight of dbb_origin as 1 and others as 0. This is not the default setting.
65 | self.single_init()
66 |
67 | def get_equivalent_kernel_bias(self):
68 | k_origin, b_origin = transI_fusebn(self.dbb_origin.conv.weight, self.dbb_origin.bn)
69 |
70 | if hasattr(self, 'dbb_1x1'):
71 | k_1x1, b_1x1 = transI_fusebn(self.dbb_1x1.conv.weight, self.dbb_1x1.bn)
72 | k_1x1 = transVI_multiscale(k_1x1, self.kernel_size)
73 | else:
74 | k_1x1, b_1x1 = 0, 0
75 |
76 | if hasattr(self.dbb_1x1_kxk, 'idconv1'):
77 | k_1x1_kxk_first = self.dbb_1x1_kxk.idconv1.get_actual_kernel()
78 | else:
79 | k_1x1_kxk_first = self.dbb_1x1_kxk.conv1.weight
80 | k_1x1_kxk_first, b_1x1_kxk_first = transI_fusebn(k_1x1_kxk_first, self.dbb_1x1_kxk.bn1)
81 | k_1x1_kxk_second, b_1x1_kxk_second = transI_fusebn(self.dbb_1x1_kxk.conv2.weight, self.dbb_1x1_kxk.bn2)
82 | k_1x1_kxk_merged, b_1x1_kxk_merged = transIII_1x1_kxk(k_1x1_kxk_first, b_1x1_kxk_first, k_1x1_kxk_second, b_1x1_kxk_second, groups=self.groups)
83 |
84 | k_avg = transV_avg(self.out_channels, self.kernel_size, self.groups)
85 | k_1x1_avg_second, b_1x1_avg_second = transI_fusebn(k_avg.to(self.dbb_avg.avgbn.weight.device), self.dbb_avg.avgbn)
86 | if hasattr(self.dbb_avg, 'conv'):
87 | k_1x1_avg_first, b_1x1_avg_first = transI_fusebn(self.dbb_avg.conv.weight, self.dbb_avg.bn)
88 | k_1x1_avg_merged, b_1x1_avg_merged = transIII_1x1_kxk(k_1x1_avg_first, b_1x1_avg_first, k_1x1_avg_second, b_1x1_avg_second, groups=self.groups)
89 | else:
90 | k_1x1_avg_merged, b_1x1_avg_merged = k_1x1_avg_second, b_1x1_avg_second
91 |
92 | return transII_addbranch((k_origin, k_1x1, k_1x1_kxk_merged, k_1x1_avg_merged), (b_origin, b_1x1, b_1x1_kxk_merged, b_1x1_avg_merged))
93 |
94 | def switch_to_deploy(self):
95 | if hasattr(self, 'dbb_reparam'):
96 | return
97 | kernel, bias = self.get_equivalent_kernel_bias()
98 | self.dbb_reparam = nn.Conv2d(in_channels=self.dbb_origin.conv.in_channels, out_channels=self.dbb_origin.conv.out_channels,
99 | kernel_size=self.dbb_origin.conv.kernel_size, stride=self.dbb_origin.conv.stride,
100 | padding=self.dbb_origin.conv.padding, dilation=self.dbb_origin.conv.dilation, groups=self.dbb_origin.conv.groups, bias=True)
101 | self.dbb_reparam.weight.data = kernel
102 | self.dbb_reparam.bias.data = bias
103 | for para in self.parameters():
104 | para.detach_()
105 | self.__delattr__('dbb_origin')
106 | self.__delattr__('dbb_avg')
107 | if hasattr(self, 'dbb_1x1'):
108 | self.__delattr__('dbb_1x1')
109 | self.__delattr__('dbb_1x1_kxk')
110 |
111 | def forward(self, inputs):
112 |
113 | if hasattr(self, 'dbb_reparam'):
114 | return self.nonlinear(self.dbb_reparam(inputs))
115 |
116 | out = self.dbb_origin(inputs)
117 | if hasattr(self, 'dbb_1x1'):
118 | out += self.dbb_1x1(inputs)
119 | out += self.dbb_avg(inputs)
120 | out += self.dbb_1x1_kxk(inputs)
121 | return self.nonlinear(out)
122 |
123 | def init_gamma(self, gamma_value):
124 | if hasattr(self, "dbb_origin"):
125 | torch.nn.init.constant_(self.dbb_origin.bn.weight, gamma_value)
126 | if hasattr(self, "dbb_1x1"):
127 | torch.nn.init.constant_(self.dbb_1x1.bn.weight, gamma_value)
128 | if hasattr(self, "dbb_avg"):
129 | torch.nn.init.constant_(self.dbb_avg.avgbn.weight, gamma_value)
130 | if hasattr(self, "dbb_1x1_kxk"):
131 | torch.nn.init.constant_(self.dbb_1x1_kxk.bn2.weight, gamma_value)
132 |
133 | def single_init(self):
134 | self.init_gamma(0.0)
135 | if hasattr(self, "dbb_origin"):
136 | torch.nn.init.constant_(self.dbb_origin.bn.weight, 1.0)
137 |
138 | class OREPA_1x1(nn.Module):
139 |
140 | def __init__(self, in_channels, out_channels, kernel_size=1,
141 | stride=1, padding=0, dilation=1, groups=1,
142 | deploy=False, nonlinear=None, single_init=False):
143 | super(OREPA_1x1, self).__init__()
144 | self.deploy = deploy
145 |
146 | if nonlinear is None:
147 | self.nonlinear = nn.Identity()
148 | else:
149 | self.nonlinear = nonlinear
150 |
151 | self.in_channels = in_channels
152 | self.out_channels = out_channels
153 | self.groups = groups
154 | assert groups == 1
155 | assert kernel_size == 1
156 | assert padding == kernel_size // 2
157 |
158 | self.kernel_size = kernel_size
159 | self.stride = stride
160 | self.padding = padding
161 | self.dilation = dilation
162 |
163 | if deploy:
164 | self.or1x1_reparam = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride,
165 | padding=padding, dilation=dilation, groups=groups, bias=True)
166 |
167 | else:
168 | self.branch_counter = 0
169 |
170 | self.weight_or1x1_origin = nn.Parameter(torch.Tensor(out_channels, in_channels, 1, 1))
171 | init.kaiming_uniform_(self.weight_or1x1_origin, a=math.sqrt(1.0))
172 | self.branch_counter += 1
173 |
174 | if out_channels > in_channels:
175 | self.weight_or1x1_l2i_conv1 = nn.Parameter(torch.eye(in_channels).unsqueeze(2).unsqueeze(3))
176 | self.weight_or1x1_l2i_conv2 = nn.Parameter(torch.Tensor(out_channels, in_channels, 1, 1))
177 | init.kaiming_uniform_(self.weight_or1x1_l2i_conv2, a=math.sqrt(1.0))
178 | else:
179 | self.weight_or1x1_l2i_conv1 = nn.Parameter(torch.Tensor(out_channels, in_channels, 1, 1))
180 | init.kaiming_uniform_(self.weight_or1x1_l2i_conv1, a=math.sqrt(1.0))
181 | self.weight_or1x1_l2i_conv2 = nn.Parameter(torch.eye(out_channels).unsqueeze(2).unsqueeze(3))
182 | self.branch_counter += 1
183 |
184 | self.vector = nn.Parameter(torch.Tensor(self.branch_counter, self.out_channels))
185 | self.bn = nn.BatchNorm2d(self.out_channels)
186 |
187 | init.constant_(self.vector[0, :], 1.0)
188 | init.constant_(self.vector[1, :], 0.5)
189 |
190 | if single_init:
191 | # Initialize the vector.weight of origin as 1 and others as 0. This is not the default setting.
192 | self.single_init()
193 |
194 | def weight_gen(self):
195 |
196 | weight_or1x1_origin = torch.einsum('oihw,o->oihw', self.weight_or1x1_origin, self.vector[0, :])
197 |
198 | weight_or1x1_l2i = torch.einsum('tihw,othw->oihw', self.weight_or1x1_l2i_conv1, self.weight_or1x1_l2i_conv2)
199 | weight_or1x1_l2i = torch.einsum('oihw,o->oihw', weight_or1x1_l2i, self.vector[1, :])
200 |
201 | return weight_or1x1_origin + weight_or1x1_l2i
202 |
203 | def forward(self, inputs):
204 | if hasattr(self, 'or1x1_reparam'):
205 | return self.nonlinear(self.or1x1_reparam(inputs))
206 |
207 | weight = self.weight_gen()
208 | out = F.conv2d(inputs, weight, bias=None, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups)
209 | return self.nonlinear(self.bn(out))
210 |
211 | def get_equivalent_kernel_bias(self):
212 | return transI_fusebn(self.weight_gen(), self.bn)
213 |
214 | def switch_to_deploy(self):
215 | if hasattr(self, 'or1x1_reparam'):
216 | return
217 | kernel, bias = self.get_equivalent_kernel_bias()
218 | self.or1x1_reparam = nn.Conv2d(in_channels=self.in_channels, out_channels=self.out_channels,
219 | kernel_size=self.kernel_size, stride=self.stride,
220 | padding=self.padding, dilation=self.dilation, groups=self.groups, bias=True)
221 | self.or1x1_reparam.weight.data = kernel
222 | self.or1x1_reparam.bias.data = bias
223 | for para in self.parameters():
224 | para.detach_()
225 | self.__delattr__('weight_or1x1_origin')
226 | self.__delattr__('weight_or1x1_l2i_conv1')
227 | self.__delattr__('weight_or1x1_l2i_conv2')
228 | self.__delattr__('vector')
229 | self.__delattr__('bn')
230 |
231 | def init_gamma(self, gamma_value):
232 | init.constant_(self.vector, gamma_value)
233 |
234 | def single_init(self):
235 | self.init_gamma(0.0)
236 | init.constant_(self.vector[0, :], 1.0)
237 |
238 | class OREPA(nn.Module):
239 |
240 | def __init__(self,
241 | in_channels,
242 | out_channels,
243 | kernel_size,
244 | stride=1,
245 | padding=0,
246 | dilation=1,
247 | groups=1,
248 | internal_channels_1x1_3x3=None,
249 | deploy=False,
250 | nonlinear=None,
251 | single_init=False,
252 | weight_only=False,
253 | init_hyper_para=1.0, init_hyper_gamma=1.0):
254 | super(OREPA, self).__init__()
255 | self.deploy = deploy
256 |
257 | if nonlinear is None:
258 | self.nonlinear = nn.Identity()
259 | else:
260 | self.nonlinear = nonlinear
261 | self.weight_only = weight_only
262 |
263 | self.kernel_size = kernel_size
264 | self.in_channels = in_channels
265 | self.out_channels = out_channels
266 | self.groups = groups
267 | assert padding == kernel_size // 2
268 |
269 | self.stride = stride
270 | self.padding = padding
271 | self.dilation = dilation
272 |
273 | if deploy:
274 | self.orepa_reparam = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride,
275 | padding=padding, dilation=dilation, groups=groups, bias=True)
276 |
277 | else:
278 |
279 | self.branch_counter = 0
280 |
281 | self.weight_orepa_origin = nn.Parameter(
282 | torch.Tensor(out_channels, int(in_channels / self.groups),
283 | kernel_size, kernel_size))
284 | init.kaiming_uniform_(self.weight_orepa_origin, a=math.sqrt(0.0))
285 | self.branch_counter += 1
286 |
287 | self.weight_orepa_avg_conv = nn.Parameter(
288 | torch.Tensor(out_channels, int(in_channels / self.groups), 1,
289 | 1))
290 | self.weight_orepa_pfir_conv = nn.Parameter(
291 | torch.Tensor(out_channels, int(in_channels / self.groups), 1,
292 | 1))
293 | init.kaiming_uniform_(self.weight_orepa_avg_conv, a=0.0)
294 | init.kaiming_uniform_(self.weight_orepa_pfir_conv, a=0.0)
295 | self.register_buffer(
296 | 'weight_orepa_avg_avg',
297 | torch.ones(kernel_size,
298 | kernel_size).mul(1.0 / kernel_size / kernel_size))
299 | self.branch_counter += 1
300 | self.branch_counter += 1
301 |
302 | self.weight_orepa_1x1 = nn.Parameter(
303 | torch.Tensor(out_channels, int(in_channels / self.groups), 1,
304 | 1))
305 | init.kaiming_uniform_(self.weight_orepa_1x1, a=0.0)
306 | self.branch_counter += 1
307 |
308 | if internal_channels_1x1_3x3 is None:
309 | internal_channels_1x1_3x3 = in_channels if groups <= 4 else 2 * in_channels
310 |
311 | if internal_channels_1x1_3x3 == in_channels:
312 | self.weight_orepa_1x1_kxk_idconv1 = nn.Parameter(
313 | torch.zeros(in_channels, int(in_channels / self.groups), 1, 1))
314 | id_value = np.zeros(
315 | (in_channels, int(in_channels / self.groups), 1, 1))
316 | for i in range(in_channels):
317 | id_value[i, i % int(in_channels / self.groups), 0, 0] = 1
318 | id_tensor = torch.from_numpy(id_value).type_as(
319 | self.weight_orepa_1x1_kxk_idconv1)
320 | self.register_buffer('id_tensor', id_tensor)
321 |
322 | else:
323 | self.weight_orepa_1x1_kxk_idconv1 = nn.Parameter(
324 | torch.zeros(internal_channels_1x1_3x3,
325 | int(in_channels / self.groups), 1, 1))
326 | id_value = np.zeros(
327 | (internal_channels_1x1_3x3, int(in_channels / self.groups), 1, 1))
328 | for i in range(internal_channels_1x1_3x3):
329 | id_value[i, i % int(in_channels / self.groups), 0, 0] = 1
330 | id_tensor = torch.from_numpy(id_value).type_as(
331 | self.weight_orepa_1x1_kxk_idconv1)
332 | self.register_buffer('id_tensor', id_tensor)
333 | #init.kaiming_uniform_(
334 | #self.weight_orepa_1x1_kxk_conv1, a=math.sqrt(0.0))
335 | self.weight_orepa_1x1_kxk_conv2 = nn.Parameter(
336 | torch.Tensor(out_channels,
337 | int(internal_channels_1x1_3x3 / self.groups),
338 | kernel_size, kernel_size))
339 | init.kaiming_uniform_(self.weight_orepa_1x1_kxk_conv2, a=math.sqrt(0.0))
340 | self.branch_counter += 1
341 |
342 | expand_ratio = 8
343 | self.weight_orepa_gconv_dw = nn.Parameter(
344 | torch.Tensor(in_channels * expand_ratio, 1, kernel_size,
345 | kernel_size))
346 | self.weight_orepa_gconv_pw = nn.Parameter(
347 | torch.Tensor(out_channels, int(in_channels * expand_ratio / self.groups), 1, 1))
348 | init.kaiming_uniform_(self.weight_orepa_gconv_dw, a=math.sqrt(0.0))
349 | init.kaiming_uniform_(self.weight_orepa_gconv_pw, a=math.sqrt(0.0))
350 | self.branch_counter += 1
351 |
352 | self.vector = nn.Parameter(torch.Tensor(self.branch_counter, self.out_channels))
353 | if weight_only is False:
354 | self.bn = nn.BatchNorm2d(self.out_channels)
355 |
356 | self.fre_init()
357 |
358 | init.constant_(self.vector[0, :], 0.25 * math.sqrt(init_hyper_gamma)) #origin
359 | init.constant_(self.vector[1, :], 0.25 * math.sqrt(init_hyper_gamma)) #avg
360 | init.constant_(self.vector[2, :], 0.0 * math.sqrt(init_hyper_gamma)) #prior
361 | init.constant_(self.vector[3, :], 0.5 * math.sqrt(init_hyper_gamma)) #1x1_kxk
362 | init.constant_(self.vector[4, :], 1.0 * math.sqrt(init_hyper_gamma)) #1x1
363 | init.constant_(self.vector[5, :], 0.5 * math.sqrt(init_hyper_gamma)) #dws_conv
364 |
365 | self.weight_orepa_1x1.data = self.weight_orepa_1x1.mul(init_hyper_para)
366 | self.weight_orepa_origin.data = self.weight_orepa_origin.mul(init_hyper_para)
367 | self.weight_orepa_1x1_kxk_conv2.data = self.weight_orepa_1x1_kxk_conv2.mul(init_hyper_para)
368 | self.weight_orepa_avg_conv.data = self.weight_orepa_avg_conv.mul(init_hyper_para)
369 | self.weight_orepa_pfir_conv.data = self.weight_orepa_pfir_conv.mul(init_hyper_para)
370 |
371 | self.weight_orepa_gconv_dw.data = self.weight_orepa_gconv_dw.mul(math.sqrt(init_hyper_para))
372 | self.weight_orepa_gconv_pw.data = self.weight_orepa_gconv_pw.mul(math.sqrt(init_hyper_para))
373 |
374 | if single_init:
375 | # Initialize the vector.weight of origin as 1 and others as 0. This is not the default setting.
376 | self.single_init()
377 |
378 | def fre_init(self):
379 | prior_tensor = torch.zeros(self.out_channels, self.kernel_size,
380 | self.kernel_size)
381 | half_fg = self.out_channels / 2
382 | for i in range(self.out_channels):
383 | for h in range(3):
384 | for w in range(3):
385 | if i < half_fg:
386 | prior_tensor[i, h, w] = math.cos(math.pi * (h + 0.5) *
387 | (i + 1) / 3)
388 | else:
389 | prior_tensor[i, h, w] = math.cos(math.pi * (w + 0.5) *
390 | (i + 1 - half_fg) / 3)
391 |
392 | self.register_buffer('weight_orepa_prior', prior_tensor)
393 |
394 | def weight_gen(self):
395 | weight_orepa_origin = torch.einsum('oihw,o->oihw',
396 | self.weight_orepa_origin,
397 | self.vector[0, :])
398 |
400 | weight_orepa_avg = torch.einsum(
401 | 'oihw,o->oihw',
402 | torch.einsum('oi,hw->oihw', self.weight_orepa_avg_conv.squeeze(3).squeeze(2),
403 | self.weight_orepa_avg_avg), self.vector[1, :])
404 |
405 |
406 | weight_orepa_pfir = torch.einsum(
407 | 'oihw,o->oihw',
408 | torch.einsum('oi,ohw->oihw', self.weight_orepa_pfir_conv.squeeze(3).squeeze(2),
409 | self.weight_orepa_prior), self.vector[2, :])
410 |
411 | weight_orepa_1x1_kxk_conv1 = None
412 | if hasattr(self, 'weight_orepa_1x1_kxk_idconv1'):
413 | weight_orepa_1x1_kxk_conv1 = (self.weight_orepa_1x1_kxk_idconv1 +
414 | self.id_tensor).squeeze(3).squeeze(2)
415 | elif hasattr(self, 'weight_orepa_1x1_kxk_conv1'):
416 | weight_orepa_1x1_kxk_conv1 = self.weight_orepa_1x1_kxk_conv1.squeeze(3).squeeze(2)
417 | else:
418 | raise NotImplementedError
419 | weight_orepa_1x1_kxk_conv2 = self.weight_orepa_1x1_kxk_conv2
420 |
421 | if self.groups > 1:
422 | g = self.groups
423 | t, ig = weight_orepa_1x1_kxk_conv1.size()
424 | o, tg, h, w = weight_orepa_1x1_kxk_conv2.size()
425 | weight_orepa_1x1_kxk_conv1 = weight_orepa_1x1_kxk_conv1.view(
426 | g, int(t / g), ig)
427 | weight_orepa_1x1_kxk_conv2 = weight_orepa_1x1_kxk_conv2.view(
428 | g, int(o / g), tg, h, w)
429 | weight_orepa_1x1_kxk = torch.einsum('gti,gothw->goihw',
430 | weight_orepa_1x1_kxk_conv1,
431 | weight_orepa_1x1_kxk_conv2).reshape(
432 | o, ig, h, w)
433 | else:
434 | weight_orepa_1x1_kxk = torch.einsum('ti,othw->oihw',
435 | weight_orepa_1x1_kxk_conv1,
436 | weight_orepa_1x1_kxk_conv2)
437 | weight_orepa_1x1_kxk = torch.einsum('oihw,o->oihw', weight_orepa_1x1_kxk, self.vector[3, :])
438 |
439 | weight_orepa_1x1 = 0
440 | if hasattr(self, 'weight_orepa_1x1'):
441 | weight_orepa_1x1 = transVI_multiscale(self.weight_orepa_1x1,
442 | self.kernel_size)
443 | weight_orepa_1x1 = torch.einsum('oihw,o->oihw', weight_orepa_1x1,
444 | self.vector[4, :])
445 |
446 | weight_orepa_gconv = self.dwsc2full(self.weight_orepa_gconv_dw,
447 | self.weight_orepa_gconv_pw,
448 | self.in_channels, self.groups)
449 | weight_orepa_gconv = torch.einsum('oihw,o->oihw', weight_orepa_gconv,
450 | self.vector[5, :])
451 |
452 | weight = weight_orepa_origin + weight_orepa_avg + weight_orepa_1x1 + weight_orepa_1x1_kxk + weight_orepa_pfir + weight_orepa_gconv
453 |
454 | return weight
455 |
456 | def dwsc2full(self, weight_dw, weight_pw, groups, groups_conv=1):
457 |
458 | t, ig, h, w = weight_dw.size()
459 | o, _, _, _ = weight_pw.size()
460 | tg = int(t / groups)
461 | i = int(ig * groups)
462 | ogc = int(o / groups_conv)
463 | groups_gc = int(groups / groups_conv)
464 | weight_dw = weight_dw.view(groups_conv, groups_gc, tg, ig, h, w)
465 | weight_pw = weight_pw.squeeze().view(ogc, groups_conv, groups_gc, tg)
466 |
467 | weight_dsc = torch.einsum('cgtihw,ocgt->cogihw', weight_dw, weight_pw)
468 | return weight_dsc.reshape(o, int(i/groups_conv), h, w)
469 |
470 | def forward(self, inputs=None):
471 | if hasattr(self, 'orepa_reparam'):
472 | return self.nonlinear(self.orepa_reparam(inputs))
473 |
474 | weight = self.weight_gen()
475 |
476 | if self.weight_only is True:
477 | return weight
478 |
479 | out = F.conv2d(
480 | inputs,
481 | weight,
482 | bias=None,
483 | stride=self.stride,
484 | padding=self.padding,
485 | dilation=self.dilation,
486 | groups=self.groups)
487 | return self.nonlinear(self.bn(out))
488 |
489 | def get_equivalent_kernel_bias(self):
490 | return transI_fusebn(self.weight_gen(), self.bn)
491 |
492 | def switch_to_deploy(self):
493 | if hasattr(self, 'orepa_reparam'):
494 | return
495 | kernel, bias = self.get_equivalent_kernel_bias()
496 | self.orepa_reparam = nn.Conv2d(in_channels=self.in_channels, out_channels=self.out_channels,
497 | kernel_size=self.kernel_size, stride=self.stride,
498 | padding=self.padding, dilation=self.dilation, groups=self.groups, bias=True)
499 | self.orepa_reparam.weight.data = kernel
500 | self.orepa_reparam.bias.data = bias
501 | for para in self.parameters():
502 | para.detach_()
503 | self.__delattr__('weight_orepa_origin')
504 | self.__delattr__('weight_orepa_1x1')
505 | self.__delattr__('weight_orepa_1x1_kxk_conv2')
506 | if hasattr(self, 'weight_orepa_1x1_kxk_idconv1'):
507 | self.__delattr__('id_tensor')
508 | self.__delattr__('weight_orepa_1x1_kxk_idconv1')
509 | elif hasattr(self, 'weight_orepa_1x1_kxk_conv1'):
510 | self.__delattr__('weight_orepa_1x1_kxk_conv1')
511 | else:
512 | raise NotImplementedError
513 | self.__delattr__('weight_orepa_avg_avg')
514 | self.__delattr__('weight_orepa_avg_conv')
515 | self.__delattr__('weight_orepa_pfir_conv')
516 | self.__delattr__('weight_orepa_prior')
517 | self.__delattr__('weight_orepa_gconv_dw')
518 | self.__delattr__('weight_orepa_gconv_pw')
519 |
520 | self.__delattr__('bn')
521 | self.__delattr__('vector')
522 |
523 | def init_gamma(self, gamma_value):
524 | init.constant_(self.vector, gamma_value)
525 |
526 | def single_init(self):
527 | self.init_gamma(0.0)
528 | init.constant_(self.vector[0, :], 1.0)
529 |
530 | class OREPA_LargeConvBase(nn.Module):
531 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation=1, groups=1, deploy=False, nonlinear=None):
532 | super(OREPA_LargeConvBase, self).__init__()
533 | assert kernel_size % 2 == 1 and kernel_size > 3
534 |
535 | self.stride = stride
536 | self.padding = padding
537 | self.layers = int((kernel_size - 1) / 2)
538 | self.groups = groups
539 | self.dilation = dilation
540 |
541 | internal_channels = out_channels
542 |
543 | self.kernel_size = kernel_size
544 | self.in_channels = in_channels
545 | self.out_channels = out_channels
546 |
547 | if nonlinear is None:
548 | self.nonlinear = nn.Identity()
549 | else:
550 | self.nonlinear = nonlinear
551 |
552 | if deploy:
553 | self.or_large_reparam = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride,
554 | padding=padding, dilation=dilation, groups=groups, bias=True)
555 |
556 | else:
557 | for i in range(self.layers):
558 | if i == 0:
559 | self.__setattr__('weight'+str(i), nn.Parameter(torch.Tensor(internal_channels, int(in_channels/self.groups), 3, 3)))
560 | elif i == self.layers - 1:
561 | self.__setattr__('weight'+str(i), nn.Parameter(torch.Tensor(out_channels, int(internal_channels/self.groups), 3, 3)))
562 | else:
563 | self.__setattr__('weight'+str(i), nn.Parameter(torch.Tensor(internal_channels, int(internal_channels/self.groups), 3, 3)))
564 | init.kaiming_uniform_(getattr(self, 'weight'+str(i)), a=math.sqrt(5))
565 |
566 | self.bn = nn.BatchNorm2d(out_channels)
567 | #self.unfold = torch.nn.Unfold(kernel_size=3, dilation=1, padding=2, stride=1)
568 |
569 | def weight_gen(self):
570 | weight = getattr(self, 'weight'+str(0)).transpose(0, 1)
571 | for i in range(self.layers - 1):
572 | weight2 = getattr(self, 'weight'+str(i+1))
573 | weight = F.conv2d(weight, weight2, groups=self.groups, padding=2)
574 |
575 | return weight.transpose(0, 1)
576 | '''
577 | weight = getattr(self, 'weight'+str(0)).transpose(0, 1)
578 | for i in range(self.layers - 1):
579 | weight = self.unfold(weight)
580 | weight2 = getattr(self, 'weight'+str(i+1))
581 |
582 | weight = torch.einsum('akl,bk->abl', weight, weight2.view(weight2.size(0), -1))
583 | k = i * 2 + 5
584 | weight = weight.view(weight.size(0), weight.size(1), k, k)
585 |
586 | return weight.transpose(0, 1)
587 | '''
588 |
589 | def forward(self, inputs):
590 | if hasattr(self, 'or_large_reparam'):
591 | return self.nonlinear(self.or_large_reparam(inputs))
592 |
593 | weight = self.weight_gen()
594 | out = F.conv2d(inputs, weight, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups)
595 | return self.nonlinear(self.bn(out))
596 |
597 | def get_equivalent_kernel_bias(self):
598 | return transI_fusebn(self.weight_gen(), self.bn)
599 |
600 | def switch_to_deploy(self):
601 | if hasattr(self, 'or_large_reparam'):
602 | return
603 | kernel, bias = self.get_equivalent_kernel_bias()
604 | self.or_large_reparam = nn.Conv2d(in_channels=self.in_channels, out_channels=self.out_channels,
605 | kernel_size=self.kernel_size, stride=self.stride,
606 | padding=self.padding, dilation=self.dilation, groups=self.groups, bias=True)
607 | self.or_large_reparam.weight.data = kernel
608 | self.or_large_reparam.bias.data = bias
609 | for para in self.parameters():
610 | para.detach_()
611 | for i in range(self.layers):
612 | self.__delattr__('weight'+str(i))
613 | self.__delattr__('bn')
614 |
615 |
616 | class OREPA_LargeConv(nn.Module):
617 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation=1, groups=1, deploy=False, nonlinear=None):
618 | super(OREPA_LargeConv, self).__init__()
619 | assert kernel_size % 2 == 1 and kernel_size > 3
620 |
621 | self.stride = stride
622 | self.padding = padding
623 | self.layers = int((kernel_size - 1) / 2)
624 | self.groups = groups
625 | self.dilation = dilation
626 |
627 | self.kernel_size = kernel_size
628 | self.in_channels = in_channels
629 | self.out_channels = out_channels
630 |
631 | internal_channels = out_channels
632 |
633 | if nonlinear is None:
634 | self.nonlinear = nn.Identity()
635 | else:
636 | self.nonlinear = nonlinear
637 |
638 | if deploy:
639 | self.or_large_reparam = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride,
640 | padding=padding, dilation=dilation, groups=groups, bias=True)
641 |
642 | else:
643 | for i in range(self.layers):
644 | if i == 0:
645 | self.__setattr__('weight'+str(i), OREPA(in_channels, internal_channels, kernel_size=3, stride=1, padding=1, groups=groups, weight_only=True))
646 | elif i == self.layers - 1:
647 | self.__setattr__('weight'+str(i), OREPA(internal_channels, out_channels, kernel_size=3, stride=self.stride, padding=1, weight_only=True))
648 | else:
649 | self.__setattr__('weight'+str(i), OREPA(internal_channels, internal_channels, kernel_size=3, stride=1, padding=1, weight_only=True))
650 |
651 | self.bn = nn.BatchNorm2d(out_channels)
652 | #self.unfold = torch.nn.Unfold(kernel_size=3, dilation=1, padding=2, stride=1)
653 |
654 | def weight_gen(self):
655 | weight = getattr(self, 'weight'+str(0)).weight_gen().transpose(0, 1)
656 | for i in range(self.layers - 1):
657 | weight2 = getattr(self, 'weight'+str(i+1)).weight_gen()
658 | weight = F.conv2d(weight, weight2, groups=self.groups, padding=2)
659 |
660 | return weight.transpose(0, 1)
661 | '''
662 | weight = getattr(self, 'weight'+str(0))(inputs=None).transpose(0, 1)
663 | for i in range(self.layers - 1):
664 | weight = self.unfold(weight)
665 | weight2 = getattr(self, 'weight'+str(i+1))(inputs=None)
666 |
667 | weight = torch.einsum('akl,bk->abl', weight, weight2.view(weight2.size(0), -1))
668 | k = i * 2 + 5
669 | weight = weight.view(weight.size(0), weight.size(1), k, k)
670 |
671 | return weight.transpose(0, 1)
672 | '''
673 |
674 | def forward(self, inputs):
675 | if hasattr(self, 'or_large_reparam'):
676 | return self.nonlinear(self.or_large_reparam(inputs))
677 |
678 | weight = self.weight_gen()
679 | out = F.conv2d(inputs, weight, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups)
680 | return self.nonlinear(self.bn(out))
681 |
682 | def get_equivalent_kernel_bias(self):
683 | return transI_fusebn(self.weight_gen(), self.bn)
684 |
685 | def switch_to_deploy(self):
686 | if hasattr(self, 'or_large_reparam'):
687 | return
688 | kernel, bias = self.get_equivalent_kernel_bias()
689 | self.or_large_reparam = nn.Conv2d(in_channels=self.in_channels, out_channels=self.out_channels,
690 | kernel_size=self.kernel_size, stride=self.stride,
691 | padding=self.padding, dilation=self.dilation, groups=self.groups, bias=True)
692 | self.or_large_reparam.weight.data = kernel
693 | self.or_large_reparam.bias.data = bias
694 | for para in self.parameters():
695 | para.detach_()
696 | for i in range(self.layers):
697 | self.__delattr__('weight'+str(i))
698 | self.__delattr__('bn')
699 |
--------------------------------------------------------------------------------
/blocks_repvgg.py:
--------------------------------------------------------------------------------
1 |
2 | from basic import *
3 | import numpy as np
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 | class RepVGGBlock(nn.Module):
10 |
11 | def __init__(self, in_channels, out_channels, kernel_size,
12 | stride=1, padding=0, dilation=1, groups=1, padding_mode='zeros', deploy=False, use_se=False):
13 | super(RepVGGBlock, self).__init__()
14 | self.deploy = deploy
15 | self.groups = groups
16 | self.in_channels = in_channels
17 |
18 | assert kernel_size == 3
19 | assert padding == 1
20 |
21 | padding_11 = padding - kernel_size // 2
22 |
23 | self.nonlinearity = nn.ReLU()
24 |
25 | if use_se:
26 | self.se = SEBlock(out_channels, internal_neurons=out_channels // 16)
27 | else:
28 | self.se = nn.Identity()
29 |
30 | if deploy:
31 | self.rbr_reparam = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride,
32 | padding=padding, dilation=dilation, groups=groups, bias=True, padding_mode=padding_mode)
33 |
34 | else:
35 | self.rbr_identity = nn.BatchNorm2d(num_features=in_channels) if out_channels == in_channels and stride == 1 else None
36 | self.rbr_dense = ConvBN(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, groups=groups)
37 | self.rbr_1x1 = ConvBN(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=stride, padding=padding_11, groups=groups)
38 | print('RepVGG Block, identity = ', self.rbr_identity)
39 |
40 |
41 | def forward(self, inputs):
42 | if hasattr(self, 'rbr_reparam'):
43 | return self.nonlinearity(self.se(self.rbr_reparam(inputs)))
44 |
45 | if self.rbr_identity is None:
46 | id_out = 0
47 | else:
48 | id_out = self.rbr_identity(inputs)
49 |
50 | return self.nonlinearity(self.se(self.rbr_dense(inputs) + self.rbr_1x1(inputs) + id_out))
51 |
52 |
53 | # Optional. This improves the accuracy and facilitates quantization.
54 | # 1. Cancel the original weight decay on rbr_dense.conv.weight and rbr_1x1.conv.weight.
55 | # 2. Use like this.
56 | # loss = criterion(....)
57 | # for every RepVGGBlock blk:
58 |     #       loss += weight_decay_coefficient * 0.5 * blk.get_custom_L2()
59 | # optimizer.zero_grad()
60 | # loss.backward()
61 | def get_custom_L2(self):
62 | K3 = self.rbr_dense.conv.weight
63 | K1 = self.rbr_1x1.conv.weight
64 | t3 = (self.rbr_dense.bn.weight / ((self.rbr_dense.bn.running_var + self.rbr_dense.bn.eps).sqrt())).reshape(-1, 1, 1, 1).detach()
65 | t1 = (self.rbr_1x1.bn.weight / ((self.rbr_1x1.bn.running_var + self.rbr_1x1.bn.eps).sqrt())).reshape(-1, 1, 1, 1).detach()
66 |
67 | l2_loss_circle = (K3 ** 2).sum() - (K3[:, :, 1:2, 1:2] ** 2).sum() # The L2 loss of the "circle" of weights in 3x3 kernel. Use regular L2 on them.
68 | eq_kernel = K3[:, :, 1:2, 1:2] * t3 + K1 * t1 # The equivalent resultant central point of 3x3 kernel.
69 | l2_loss_eq_kernel = (eq_kernel ** 2 / (t3 ** 2 + t1 ** 2)).sum() # Normalize for an L2 coefficient comparable to regular L2.
70 | return l2_loss_eq_kernel + l2_loss_circle
71 |
72 |
73 |
74 | # This func derives the equivalent kernel and bias in a DIFFERENTIABLE way.
75 | # You can get the equivalent kernel and bias at any time and do whatever you want,
76 | # for example, apply some penalties or constraints during training, just like you do to the other models.
77 | # May be useful for quantization or pruning.
78 | def get_equivalent_kernel_bias(self):
79 | kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense)
80 | kernel1x1, bias1x1 = self._fuse_bn_tensor(self.rbr_1x1)
81 | kernelid, biasid = self._fuse_bn_tensor(self.rbr_identity)
82 | return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid
83 |
84 | def _pad_1x1_to_3x3_tensor(self, kernel1x1):
85 | if kernel1x1 is None:
86 | return 0
87 | else:
88 | return torch.nn.functional.pad(kernel1x1, [1,1,1,1])
89 |
90 | def _fuse_bn_tensor(self, branch):
91 | if branch is None:
92 | return 0, 0
93 | if not isinstance(branch, nn.BatchNorm2d):
94 | kernel = branch.conv.weight
95 | running_mean = branch.bn.running_mean
96 | running_var = branch.bn.running_var
97 | gamma = branch.bn.weight
98 | beta = branch.bn.bias
99 | eps = branch.bn.eps
100 | else:
101 | if not hasattr(self, 'id_tensor'):
102 | input_dim = self.in_channels // self.groups
103 | kernel_value = np.zeros((self.in_channels, input_dim, 3, 3), dtype=np.float32)
104 | for i in range(self.in_channels):
105 | kernel_value[i, i % input_dim, 1, 1] = 1
106 | self.id_tensor = torch.from_numpy(kernel_value).to(branch.weight.device)
107 | kernel = self.id_tensor
108 | running_mean = branch.running_mean
109 | running_var = branch.running_var
110 | gamma = branch.weight
111 | beta = branch.bias
112 | eps = branch.eps
113 | std = (running_var + eps).sqrt()
114 | t = (gamma / std).reshape(-1, 1, 1, 1)
115 | return kernel * t, beta - running_mean * gamma / std
116 |
117 | def switch_to_deploy(self):
118 | if hasattr(self, 'rbr_reparam'):
119 | return
120 | kernel, bias = self.get_equivalent_kernel_bias()
121 | self.rbr_reparam = nn.Conv2d(in_channels=self.rbr_dense.conv.in_channels, out_channels=self.rbr_dense.conv.out_channels,
122 | kernel_size=self.rbr_dense.conv.kernel_size, stride=self.rbr_dense.conv.stride,
123 | padding=self.rbr_dense.conv.padding, dilation=self.rbr_dense.conv.dilation, groups=self.rbr_dense.conv.groups, bias=True)
124 | self.rbr_reparam.weight.data = kernel
125 | self.rbr_reparam.bias.data = bias
126 | for para in self.parameters():
127 | para.detach_()
128 | self.__delattr__('rbr_dense')
129 | self.__delattr__('rbr_1x1')
130 | if hasattr(self, 'rbr_identity'):
131 | self.__delattr__('rbr_identity')
132 |
133 |
134 | class OREPA_3x3_RepVGG(nn.Module):
135 |
136 | def __init__(self, in_channels, out_channels, kernel_size,
137 | stride=1, padding=0, dilation=1, groups=1,
138 | internal_channels_1x1_3x3=None,
139 | deploy=False, nonlinear=None, single_init=False):
140 | super(OREPA_3x3_RepVGG, self).__init__()
141 | self.deploy = deploy
142 |
143 | if nonlinear is None:
144 | self.nonlinear = nn.Identity()
145 | else:
146 | self.nonlinear = nonlinear
147 |
148 | self.kernel_size = kernel_size
149 | self.in_channels = in_channels
150 | self.out_channels = out_channels
151 | self.groups = groups
152 | assert padding == kernel_size // 2
153 |
154 | self.stride = stride
155 | self.padding = padding
156 | self.dilation = dilation
157 |
158 | self.branch_counter = 0
159 |
160 | self.weight_rbr_origin = nn.Parameter(torch.Tensor(out_channels, int(in_channels/self.groups), kernel_size, kernel_size))
161 | init.kaiming_uniform_(self.weight_rbr_origin, a=math.sqrt(1.0))
162 | self.branch_counter += 1
163 |
164 |
165 | if groups < out_channels:
166 | self.weight_rbr_avg_conv = nn.Parameter(torch.Tensor(out_channels, int(in_channels/self.groups), 1, 1))
167 | self.weight_rbr_pfir_conv = nn.Parameter(torch.Tensor(out_channels, int(in_channels/self.groups), 1, 1))
168 | init.kaiming_uniform_(self.weight_rbr_avg_conv, a=1.0)
169 | init.kaiming_uniform_(self.weight_rbr_pfir_conv, a=1.0)
170 |             # self.weight_rbr_avg_conv.data    (no-op statement, commented out)
171 |             # self.weight_rbr_pfir_conv.data   (no-op statement, commented out)
172 | self.register_buffer('weight_rbr_avg_avg', torch.ones(kernel_size, kernel_size).mul(1.0/kernel_size/kernel_size))
173 | self.branch_counter += 1
174 |
175 | else:
176 | raise NotImplementedError
177 |             # self.branch_counter += 1    (unreachable: follows the raise above)
178 |
179 | if internal_channels_1x1_3x3 is None:
180 | internal_channels_1x1_3x3 = in_channels if groups < out_channels else 2 * in_channels # For mobilenet, it is better to have 2X internal channels
181 |
182 | if internal_channels_1x1_3x3 == in_channels:
183 | self.weight_rbr_1x1_kxk_idconv1 = nn.Parameter(torch.zeros(in_channels, int(in_channels/self.groups), 1, 1))
184 | id_value = np.zeros((in_channels, int(in_channels/self.groups), 1, 1))
185 | for i in range(in_channels):
186 | id_value[i, i % int(in_channels/self.groups), 0, 0] = 1
187 | id_tensor = torch.from_numpy(id_value).type_as(self.weight_rbr_1x1_kxk_idconv1)
188 | self.register_buffer('id_tensor', id_tensor)
189 |
190 | else:
191 | self.weight_rbr_1x1_kxk_conv1 = nn.Parameter(torch.Tensor(internal_channels_1x1_3x3, int(in_channels/self.groups), 1, 1))
192 | init.kaiming_uniform_(self.weight_rbr_1x1_kxk_conv1, a=math.sqrt(1.0))
193 | self.weight_rbr_1x1_kxk_conv2 = nn.Parameter(torch.Tensor(out_channels, int(internal_channels_1x1_3x3/self.groups), kernel_size, kernel_size))
194 | init.kaiming_uniform_(self.weight_rbr_1x1_kxk_conv2, a=math.sqrt(1.0))
195 | self.branch_counter += 1
196 |
197 | expand_ratio = 8
198 | self.weight_rbr_gconv_dw = nn.Parameter(torch.Tensor(in_channels*expand_ratio, 1, kernel_size, kernel_size))
199 | self.weight_rbr_gconv_pw = nn.Parameter(torch.Tensor(out_channels, in_channels*expand_ratio, 1, 1))
200 | init.kaiming_uniform_(self.weight_rbr_gconv_dw, a=math.sqrt(1.0))
201 | init.kaiming_uniform_(self.weight_rbr_gconv_pw, a=math.sqrt(1.0))
202 | self.branch_counter += 1
203 |
204 | if out_channels == in_channels and stride == 1:
205 | self.branch_counter += 1
206 |
207 | self.vector = nn.Parameter(torch.Tensor(self.branch_counter, self.out_channels))
208 | self.bn = nn.BatchNorm2d(out_channels)
209 |
210 | self.fre_init()
211 |
212 | init.constant_(self.vector[0, :], 0.25) #origin
213 | init.constant_(self.vector[1, :], 0.25) #avg
214 | init.constant_(self.vector[2, :], 0.0) #prior
215 | init.constant_(self.vector[3, :], 0.5) #1x1_kxk
216 | init.constant_(self.vector[4, :], 0.5) #dws_conv
217 |
218 |
219 | def fre_init(self):
220 | prior_tensor = torch.Tensor(self.out_channels, self.kernel_size, self.kernel_size)
221 | half_fg = self.out_channels/2
222 | for i in range(self.out_channels):
223 | for h in range(3):
224 | for w in range(3):
225 | if i < half_fg:
226 | prior_tensor[i, h, w] = math.cos(math.pi*(h+0.5)*(i+1)/3)
227 | else:
228 | prior_tensor[i, h, w] = math.cos(math.pi*(w+0.5)*(i+1-half_fg)/3)
229 |
230 | self.register_buffer('weight_rbr_prior', prior_tensor)
231 |
232 | def weight_gen(self):
233 |
234 | weight_rbr_origin = torch.einsum('oihw,o->oihw', self.weight_rbr_origin, self.vector[0, :])
235 |
236 | weight_rbr_avg = torch.einsum('oihw,o->oihw', torch.einsum('oihw,hw->oihw', self.weight_rbr_avg_conv, self.weight_rbr_avg_avg), self.vector[1, :])
237 |
238 | weight_rbr_pfir = torch.einsum('oihw,o->oihw', torch.einsum('oihw,ohw->oihw', self.weight_rbr_pfir_conv, self.weight_rbr_prior), self.vector[2, :])
239 |
240 | weight_rbr_1x1_kxk_conv1 = None
241 | if hasattr(self, 'weight_rbr_1x1_kxk_idconv1'):
242 | weight_rbr_1x1_kxk_conv1 = (self.weight_rbr_1x1_kxk_idconv1 + self.id_tensor).squeeze()
243 | elif hasattr(self, 'weight_rbr_1x1_kxk_conv1'):
244 | weight_rbr_1x1_kxk_conv1 = self.weight_rbr_1x1_kxk_conv1.squeeze()
245 | else:
246 | raise NotImplementedError
247 | weight_rbr_1x1_kxk_conv2 = self.weight_rbr_1x1_kxk_conv2
248 |
249 | if self.groups > 1:
250 | g = self.groups
251 | t, ig = weight_rbr_1x1_kxk_conv1.size()
252 | o, tg, h, w = weight_rbr_1x1_kxk_conv2.size()
253 | weight_rbr_1x1_kxk_conv1 = weight_rbr_1x1_kxk_conv1.view(g, int(t/g), ig)
254 | weight_rbr_1x1_kxk_conv2 = weight_rbr_1x1_kxk_conv2.view(g, int(o/g), tg, h, w)
255 | weight_rbr_1x1_kxk = torch.einsum('gti,gothw->goihw', weight_rbr_1x1_kxk_conv1, weight_rbr_1x1_kxk_conv2).view(o, ig, h, w)
256 | else:
257 | weight_rbr_1x1_kxk = torch.einsum('ti,othw->oihw', weight_rbr_1x1_kxk_conv1, weight_rbr_1x1_kxk_conv2)
258 |
259 | weight_rbr_1x1_kxk = torch.einsum('oihw,o->oihw', weight_rbr_1x1_kxk, self.vector[3, :])
260 |
261 | weight_rbr_gconv = self.dwsc2full(self.weight_rbr_gconv_dw, self.weight_rbr_gconv_pw, self.in_channels)
262 | weight_rbr_gconv = torch.einsum('oihw,o->oihw', weight_rbr_gconv, self.vector[4, :])
263 |
264 | weight = weight_rbr_origin + weight_rbr_avg + weight_rbr_1x1_kxk + weight_rbr_pfir + weight_rbr_gconv
265 |
266 | return weight
267 |
268 | def dwsc2full(self, weight_dw, weight_pw, groups):
269 |
270 | t, ig, h, w = weight_dw.size()
271 | o, _, _, _ = weight_pw.size()
272 | tg = int(t/groups)
273 | i = int(ig*groups)
274 | weight_dw = weight_dw.view(groups, tg, ig, h, w)
275 | weight_pw = weight_pw.squeeze().view(o, groups, tg)
276 |
277 | weight_dsc = torch.einsum('gtihw,ogt->ogihw', weight_dw, weight_pw)
278 | return weight_dsc.view(o, i, h, w)
279 |
280 | def forward(self, inputs):
281 | weight = self.weight_gen()
282 | out = F.conv2d(inputs, weight, bias=None, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups)
283 |
284 | return self.nonlinear(self.bn(out))
285 |
286 | class RepVGGBlock_OREPA(nn.Module):
287 |
288 | def __init__(self, in_channels, out_channels, kernel_size,
289 | stride=1, padding=0, dilation=1, groups=1, padding_mode='zeros', deploy=False, use_se=False, nonlinear=nn.ReLU()):
290 | super(RepVGGBlock_OREPA, self).__init__()
291 | self.deploy = deploy
292 | self.groups = groups
293 | self.in_channels = in_channels
294 | self.out_channels = out_channels
295 |
296 | self.padding = padding
297 | self.dilation = dilation
298 | self.groups = groups
299 |
300 | assert kernel_size == 3
301 | assert padding == 1
302 |
303 | padding_11 = padding - kernel_size // 2
304 |
305 | if nonlinear is None:
306 | self.nonlinearity = nn.Identity()
307 | else:
308 | self.nonlinearity = nonlinear
309 |
310 | if use_se:
311 | self.se = SEBlock(out_channels, internal_neurons=out_channels // 16)
312 | else:
313 | self.se = nn.Identity()
314 |
315 | if deploy:
316 | self.rbr_reparam = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride,
317 | padding=padding, dilation=dilation, groups=groups, bias=True, padding_mode=padding_mode)
318 |
319 | else:
320 | self.rbr_identity = nn.BatchNorm2d(num_features=in_channels) if out_channels == in_channels and stride == 1 else None
321 | self.rbr_dense = OREPA_3x3_RepVGG(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, groups=groups, dilation=1)
322 | self.rbr_1x1 = ConvBN(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=stride, padding=padding_11, groups=groups, dilation=1)
323 | print('RepVGG Block, identity = ', self.rbr_identity)
324 |
325 |
326 | def forward(self, inputs):
327 | if hasattr(self, 'rbr_reparam'):
328 | return self.nonlinearity(self.se(self.rbr_reparam(inputs)))
329 |
330 | if self.rbr_identity is None:
331 | id_out = 0
332 | else:
333 | id_out = self.rbr_identity(inputs)
334 |
335 | out1 = self.rbr_dense(inputs)
336 | out2 = self.rbr_1x1(inputs)
337 | out3 = id_out
338 | out = out1 + out2 + out3
339 |
340 | return self.nonlinearity(self.se(out))
341 |
342 |
343 | # Optional. This improves the accuracy and facilitates quantization.
344 | # 1. Cancel the original weight decay on rbr_dense.conv.weight and rbr_1x1.conv.weight.
345 | # 2. Use like this.
346 | # loss = criterion(....)
347 | # for every RepVGGBlock blk:
348 |     #       loss += weight_decay_coefficient * 0.5 * blk.get_custom_L2()
349 | # optimizer.zero_grad()
350 | # loss.backward()
351 |
352 | # Not used for OREPA
353 | def get_custom_L2(self):
354 | K3 = self.rbr_dense.weight_gen()
355 | K1 = self.rbr_1x1.conv.weight
356 | t3 = (self.rbr_dense.bn.weight / ((self.rbr_dense.bn.running_var + self.rbr_dense.bn.eps).sqrt())).reshape(-1, 1, 1, 1).detach()
357 | t1 = (self.rbr_1x1.bn.weight / ((self.rbr_1x1.bn.running_var + self.rbr_1x1.bn.eps).sqrt())).reshape(-1, 1, 1, 1).detach()
358 |
359 | l2_loss_circle = (K3 ** 2).sum() - (K3[:, :, 1:2, 1:2] ** 2).sum() # The L2 loss of the "circle" of weights in 3x3 kernel. Use regular L2 on them.
360 | eq_kernel = K3[:, :, 1:2, 1:2] * t3 + K1 * t1 # The equivalent resultant central point of 3x3 kernel.
361 | l2_loss_eq_kernel = (eq_kernel ** 2 / (t3 ** 2 + t1 ** 2)).sum() # Normalize for an L2 coefficient comparable to regular L2.
362 | return l2_loss_eq_kernel + l2_loss_circle
363 |
364 | def get_equivalent_kernel_bias(self):
365 | kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense)
366 | kernel1x1, bias1x1 = self._fuse_bn_tensor(self.rbr_1x1)
367 | kernelid, biasid = self._fuse_bn_tensor(self.rbr_identity)
368 | return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid
369 |
370 | def _pad_1x1_to_3x3_tensor(self, kernel1x1):
371 | if kernel1x1 is None:
372 | return 0
373 | else:
374 | return torch.nn.functional.pad(kernel1x1, [1,1,1,1])
375 |
376 | def _fuse_bn_tensor(self, branch):
377 | if branch is None:
378 | return 0, 0
379 | if not isinstance(branch, nn.BatchNorm2d):
380 | if isinstance(branch, OREPA_3x3_RepVGG):
381 | kernel = branch.weight_gen()
382 | elif isinstance(branch, ConvBN):
383 | kernel = branch.conv.weight
384 | else:
385 | raise NotImplementedError
386 | running_mean = branch.bn.running_mean
387 | running_var = branch.bn.running_var
388 | gamma = branch.bn.weight
389 | beta = branch.bn.bias
390 | eps = branch.bn.eps
391 | else:
392 | if not hasattr(self, 'id_tensor'):
393 | input_dim = self.in_channels // self.groups
394 | kernel_value = np.zeros((self.in_channels, input_dim, 3, 3), dtype=np.float32)
395 | for i in range(self.in_channels):
396 | kernel_value[i, i % input_dim, 1, 1] = 1
397 | self.id_tensor = torch.from_numpy(kernel_value).to(branch.weight.device)
398 | kernel = self.id_tensor
399 | running_mean = branch.running_mean
400 | running_var = branch.running_var
401 | gamma = branch.weight
402 | beta = branch.bias
403 | eps = branch.eps
404 | std = (running_var + eps).sqrt()
405 | t = (gamma / std).reshape(-1, 1, 1, 1)
406 | return kernel * t, beta - running_mean * gamma / std
407 |
408 | def switch_to_deploy(self):
409 | if hasattr(self, 'rbr_reparam'):
410 | return
411 | kernel, bias = self.get_equivalent_kernel_bias()
412 | self.rbr_reparam = nn.Conv2d(in_channels=self.rbr_dense.in_channels, out_channels=self.rbr_dense.out_channels,
413 | kernel_size=self.rbr_dense.kernel_size, stride=self.rbr_dense.stride,
414 | padding=self.rbr_dense.padding, dilation=self.rbr_dense.dilation, groups=self.rbr_dense.groups, bias=True)
415 | self.rbr_reparam.weight.data = kernel
416 | self.rbr_reparam.bias.data = bias
417 | for para in self.parameters():
418 | para.detach_()
419 | self.__delattr__('rbr_dense')
420 | self.__delattr__('rbr_1x1')
421 | if hasattr(self, 'rbr_identity'):
422 | self.__delattr__('rbr_identity')
423 |
424 | class SEBlock(nn.Module):
425 |
426 | def __init__(self, input_channels, internal_neurons):
427 | super(SEBlock, self).__init__()
428 | self.down = nn.Conv2d(in_channels=input_channels, out_channels=internal_neurons, kernel_size=1, stride=1, bias=True)
429 | self.up = nn.Conv2d(in_channels=internal_neurons, out_channels=input_channels, kernel_size=1, stride=1, bias=True)
430 | self.input_channels = input_channels
431 |
432 | def forward(self, inputs):
433 | x = F.avg_pool2d(inputs, kernel_size=inputs.size(3))
434 | x = self.down(x)
435 | x = F.relu(x)
436 | x = self.up(x)
437 | x = torch.sigmoid(x)
438 | x = x.view(-1, self.input_channels, 1, 1)
439 | return inputs * x
--------------------------------------------------------------------------------
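A quick way to sanity-check the re-parameterization logic in the file above is to compare a block's output before and after switch_to_deploy(). This is a minimal sketch, assuming blocks_repvgg.py and its dependency basic.py (which provides ConvBN and SEBlock) are importable; sizes and the tolerance are illustrative, and the same check applies to RepVGGBlock_OREPA:

    import torch
    from blocks_repvgg import RepVGGBlock

    torch.manual_seed(0)
    blk = RepVGGBlock(in_channels=16, out_channels=16, kernel_size=3, stride=1, padding=1)
    blk.eval()                      # BN must be in eval mode so the fusion uses running statistics
    x = torch.randn(2, 16, 32, 32)
    y_train = blk(x)                # three-branch forward: 3x3 + 1x1 + identity
    blk.switch_to_deploy()          # folds everything into the single blk.rbr_reparam conv
    y_deploy = blk(x)
    print((y_train - y_deploy).abs().max())    # expected to be tiny (float error, around 1e-6)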
/convert.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import torch
4 | from convnet_utils import switch_conv_bn_impl, switch_deploy_flag, build_model
5 |
6 | parser = argparse.ArgumentParser(description='Convert a training-time checkpoint to the deploy form')
7 | parser.add_argument('load', metavar='LOAD', help='path to the weights file')
8 | parser.add_argument('save', metavar='SAVE', help='path to save the converted weights file')
9 | parser.add_argument('-a', '--arch', metavar='ARCH', default='ResNet-18')
10 | parser.add_argument('-t', '--blocktype', metavar='BLK', default='OREPA')
11 |
12 | def convert():
13 | args = parser.parse_args()
14 |
15 | switch_conv_bn_impl(args.blocktype)
16 | switch_deploy_flag(False)
17 | train_model = build_model(args.arch)
18 |
19 | if 'hdf5' in args.load:
20 | from utils import model_load_hdf5
21 | model_load_hdf5(train_model, args.load)
22 | elif os.path.isfile(args.load):
23 | print("=> loading checkpoint '{}'".format(args.load))
24 | checkpoint = torch.load(args.load)
25 | if 'state_dict' in checkpoint:
26 | checkpoint = checkpoint['state_dict']
27 | ckpt = {k.replace('module.', ''): v for k, v in checkpoint.items()} # strip the names
28 | train_model.load_state_dict(ckpt)
29 | else:
30 | print("=> no checkpoint found at '{}'".format(args.load))
31 |
32 | for m in train_model.modules():
33 | if hasattr(m, 'switch_to_deploy'):
34 | m.switch_to_deploy()
35 |
36 | torch.save(train_model.state_dict(), args.save)
37 |
38 |
39 | if __name__ == '__main__':
40 | convert()
--------------------------------------------------------------------------------
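For reference, convert.py above is typically run as: python convert.py ResNet-18_OREPA_train.pth ResNet-18_OREPA_deploy.pth -a ResNet-18 -t OREPA, where the architecture and block type must match those used for training and the paths are placeholders. The same flow, driven from Python, looks roughly like this (a sketch; convert.py additionally unwraps a {'state_dict': ...} checkpoint and strips DDP's 'module.' prefix):

    import torch
    from convnet_utils import switch_conv_bn_impl, switch_deploy_flag, build_model

    switch_conv_bn_impl('OREPA')    # must match the block type the checkpoint was trained with
    switch_deploy_flag(False)       # build the multi-branch (training-time) structure
    model = build_model('ResNet-18')
    model.load_state_dict(torch.load('ResNet-18_OREPA_train.pth'))   # placeholder path
    for m in model.modules():
        if hasattr(m, 'switch_to_deploy'):
            m.switch_to_deploy()
    torch.save(model.state_dict(), 'ResNet-18_OREPA_deploy.pth')     # placeholder path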
/convnet_utils.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | from basic import ConvBN
4 | from blocks import DBB, OREPA_1x1, OREPA, OREPA_LargeConvBase, OREPA_LargeConv
5 | from blocks_repvgg import RepVGGBlock, RepVGGBlock_OREPA
6 |
7 | CONV_BN_IMPL = 'base'
8 |
9 | DEPLOY_FLAG = False
10 |
11 | def choose_blk(kernel_size):
12 | if CONV_BN_IMPL == 'OREPA':
13 | if kernel_size == 1:
14 | blk_type = OREPA_1x1
15 | elif kernel_size >= 7:
16 | blk_type = OREPA_LargeConv
17 | else:
18 | blk_type = OREPA
19 | elif CONV_BN_IMPL == 'base' or kernel_size == 1 or kernel_size >= 7:
20 | blk_type = ConvBN
21 | elif CONV_BN_IMPL == 'DBB':
22 | blk_type = DBB
23 | elif CONV_BN_IMPL == 'RepVGG':
24 | blk_type = RepVGGBlock
25 | elif CONV_BN_IMPL == 'OREPA_VGG':
26 | blk_type = RepVGGBlock_OREPA
27 | else:
28 | raise NotImplementedError
29 |
30 | return blk_type
31 |
32 | def conv_bn(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, assign_type=None):
33 | if assign_type is not None:
34 | blk_type = assign_type
35 | else:
36 | blk_type = choose_blk(kernel_size)
37 | return blk_type(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride,
38 | padding=padding, dilation=dilation, groups=groups, deploy=DEPLOY_FLAG)
39 |
40 | def conv_bn_relu(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, assign_type=None):
41 | if assign_type is not None:
42 | blk_type = assign_type
43 | else:
44 | blk_type = choose_blk(kernel_size)
45 | return blk_type(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride,
46 | padding=padding, dilation=dilation, groups=groups, deploy=DEPLOY_FLAG, nonlinear=nn.ReLU())
47 |
48 | def switch_conv_bn_impl(block_type):
49 | global CONV_BN_IMPL
50 | CONV_BN_IMPL = block_type
51 |
52 | def switch_deploy_flag(deploy):
53 | global DEPLOY_FLAG
54 | DEPLOY_FLAG = deploy
55 | print('deploy flag: ', DEPLOY_FLAG)
56 |
57 | def build_model(arch):
58 | if arch == 'ResNet-18':
59 | from models.resnet import create_Res18
60 | model = create_Res18()
61 | elif arch == 'ResNet-34':
62 | from models.resnet import create_Res34
63 | model = create_Res34()
64 | elif arch == 'ResNet-50':
65 | from models.resnet import create_Res50
66 | model = create_Res50()
67 | elif arch == 'ResNet-101':
68 | from models.resnet import create_Res101
69 | model = create_Res101()
70 | elif arch == 'RepVGG-A0':
71 | from models.repvgg import create_RepVGG_A0
72 | model = create_RepVGG_A0()
73 | elif arch == 'RepVGG-A1':
74 | from models.repvgg import create_RepVGG_A1
75 | model = create_RepVGG_A1()
76 | elif arch == 'RepVGG-A2':
77 | from models.repvgg import create_RepVGG_A2
78 | model = create_RepVGG_A2()
79 | elif arch == 'RepVGG-B1':
80 | from models.repvgg import create_RepVGG_B1
81 | model = create_RepVGG_B1()
82 | elif arch == 'ResNet-18-1.5x':
83 | from models.resnet import create_Res18_1d5x
84 | model = create_Res18_1d5x()
85 | elif arch == 'ResNet-18-2x':
86 | from models.resnet import create_Res18_2x
87 | model = create_Res18_2x()
88 | elif arch == 'ResNeXt-50':
89 | from models.resnext import create_Res50_32x4d
90 | model = create_Res50_32x4d()
91 | #elif arch == 'RegNet-800MF':
92 | #from models.regnet import create_Reg800MF
93 | #model = create_Reg800MF()
94 | #elif arch == 'ConvNext-T-0.5x':
95 | #from models.convnext import convnext_tiny_0d5x
96 | #model = convnext_tiny_0d5x()
97 | else:
98 |         raise ValueError('unsupported architecture: {}'.format(arch))
99 | return model
--------------------------------------------------------------------------------
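The two module-level switches above are the only configuration the model definitions see: every conv_bn / conv_bn_relu call resolves its block class through choose_blk at construction time, so they must be set before build_model is called. A small sketch with arbitrary channel sizes:

    from convnet_utils import switch_conv_bn_impl, switch_deploy_flag, conv_bn, build_model

    switch_conv_bn_impl('OREPA')    # 1x1 -> OREPA_1x1, >=7 -> OREPA_LargeConv, otherwise OREPA
    switch_deploy_flag(False)       # training-time (multi-branch) structure

    blk = conv_bn(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1)
    print(type(blk).__name__)       # OREPA

    model = build_model('ResNet-18')    # the whole network is assembled from such blocks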
/dbb_transforms.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import torch.nn.functional as F
4 |
5 | def transI_fusebn(kernel, bn):
6 | gamma = bn.weight
7 | std = (bn.running_var + bn.eps).sqrt()
8 | return kernel * ((gamma / std).reshape(-1, 1, 1, 1)), bn.bias - bn.running_mean * gamma / std
9 |
10 | def transII_addbranch(kernels, biases):
11 | return sum(kernels), sum(biases)
12 |
13 | def transIII_1x1_kxk(k1, b1, k2, b2, groups):
14 | if groups == 1:
15 | k = F.conv2d(k2, k1.permute(1, 0, 2, 3)) #
16 | b_hat = (k2 * b1.reshape(1, -1, 1, 1)).sum((1, 2, 3))
17 | else:
18 | k_slices = []
19 | b_slices = []
20 | k1_T = k1.permute(1, 0, 2, 3)
21 | k1_group_width = k1.size(0) // groups
22 | k2_group_width = k2.size(0) // groups
23 | for g in range(groups):
24 | k1_T_slice = k1_T[:, g*k1_group_width:(g+1)*k1_group_width, :, :]
25 | k2_slice = k2[g*k2_group_width:(g+1)*k2_group_width, :, :, :]
26 | k_slices.append(F.conv2d(k2_slice, k1_T_slice))
27 | b_slices.append((k2_slice * b1[g*k1_group_width:(g+1)*k1_group_width].reshape(1, -1, 1, 1)).sum((1, 2, 3)))
28 | k, b_hat = transIV_depthconcat(k_slices, b_slices)
29 | return k, b_hat + b2
30 |
31 | def transIV_depthconcat(kernels, biases):
32 | return torch.cat(kernels, dim=0), torch.cat(biases)
33 |
34 | def transV_avg(channels, kernel_size, groups):
35 | input_dim = channels // groups
36 | k = torch.zeros((channels, input_dim, kernel_size, kernel_size))
37 | k[np.arange(channels), np.tile(np.arange(input_dim), groups), :, :] = 1.0 / kernel_size ** 2
38 | return k
39 |
40 | # This has not been tested with non-square kernels (kernel.size(2) != kernel.size(3)) nor even-size kernels
41 | def transVI_multiscale(kernel, target_kernel_size):
42 | H_pixels_to_pad = (target_kernel_size - kernel.size(2)) // 2
43 | W_pixels_to_pad = (target_kernel_size - kernel.size(3)) // 2
44 | return F.pad(kernel, [W_pixels_to_pad, W_pixels_to_pad, H_pixels_to_pad, H_pixels_to_pad])
--------------------------------------------------------------------------------
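transI_fusebn is the transform every block ultimately relies on in get_equivalent_kernel_bias: an inference-mode BatchNorm that follows a bias-free conv is absorbed into the conv's kernel and bias. A small numerical check, with arbitrary sizes and randomized running statistics:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from dbb_transforms import transI_fusebn

    torch.manual_seed(0)
    conv = nn.Conv2d(8, 16, kernel_size=3, padding=1, bias=False)
    bn = nn.BatchNorm2d(16)
    bn.running_mean.uniform_(-1.0, 1.0)     # pretend some statistics were accumulated
    bn.running_var.uniform_(0.5, 2.0)
    bn.eval()                               # the fusion assumes inference-mode BN

    x = torch.randn(4, 8, 14, 14)
    ref = bn(conv(x))

    kernel, bias = transI_fusebn(conv.weight, bn)
    fused = F.conv2d(x, kernel, bias=bias, padding=1)
    print((ref - fused).abs().max())        # expected to be around 1e-6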
/images/budgets.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JUGGHM/OREPA_CVPR2022/3ed037beeb106cc1c3d011eb2678cebfefc4cba0/images/budgets.png
--------------------------------------------------------------------------------
/images/coco_cs.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JUGGHM/OREPA_CVPR2022/3ed037beeb106cc1c3d011eb2678cebfefc4cba0/images/coco_cs.PNG
--------------------------------------------------------------------------------
/images/imagenet1.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JUGGHM/OREPA_CVPR2022/3ed037beeb106cc1c3d011eb2678cebfefc4cba0/images/imagenet1.PNG
--------------------------------------------------------------------------------
/images/imagenet2.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JUGGHM/OREPA_CVPR2022/3ed037beeb106cc1c3d011eb2678cebfefc4cba0/images/imagenet2.PNG
--------------------------------------------------------------------------------
/images/intro.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JUGGHM/OREPA_CVPR2022/3ed037beeb106cc1c3d011eb2678cebfefc4cba0/images/intro.png
--------------------------------------------------------------------------------
/images/norm.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JUGGHM/OREPA_CVPR2022/3ed037beeb106cc1c3d011eb2678cebfefc4cba0/images/norm.PNG
--------------------------------------------------------------------------------
/images/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JUGGHM/OREPA_CVPR2022/3ed037beeb106cc1c3d011eb2678cebfefc4cba0/images/overview.png
--------------------------------------------------------------------------------
/images/similarity.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JUGGHM/OREPA_CVPR2022/3ed037beeb106cc1c3d011eb2678cebfefc4cba0/images/similarity.png
--------------------------------------------------------------------------------
/images/supp_grad.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JUGGHM/OREPA_CVPR2022/3ed037beeb106cc1c3d011eb2678cebfefc4cba0/images/supp_grad.png
--------------------------------------------------------------------------------
/models/repvgg.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import numpy as np
3 | import torch
4 | import copy
5 | from convnet_utils import conv_bn, conv_bn_relu
6 |
7 | '''
8 | def conv_bn(in_channels, out_channels, kernel_size, stride, padding, groups=1):
9 | result = nn.Sequential()
10 | result.add_module('conv', nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
11 | kernel_size=kernel_size, stride=stride, padding=padding, groups=groups, bias=False))
12 | result.add_module('bn', nn.BatchNorm2d(num_features=out_channels))
13 | return result
14 | '''
15 |
16 | class RepVGG(nn.Module):
17 |
18 | def __init__(self, num_blocks, num_classes=1000, width_multiplier=None, override_groups_map=None, deploy=False, use_se=False):
19 | super(RepVGG, self).__init__()
20 |
21 | assert len(width_multiplier) == 4
22 |
23 | self.deploy = deploy
24 | self.override_groups_map = override_groups_map or dict()
25 | self.use_se = use_se
26 |
27 | assert 0 not in self.override_groups_map
28 |
29 | self.in_planes = min(64, int(64 * width_multiplier[0]))
30 |
31 | self.stage0 = conv_bn_relu(in_channels=3, out_channels=self.in_planes, kernel_size=3, stride=2, padding=1)
32 | self.cur_layer_idx = 1
33 | self.stage1 = self._make_stage(int(64 * width_multiplier[0]), num_blocks[0], stride=2)
34 | self.stage2 = self._make_stage(int(128 * width_multiplier[1]), num_blocks[1], stride=2)
35 | self.stage3 = self._make_stage(int(256 * width_multiplier[2]), num_blocks[2], stride=2)
36 | self.stage4 = self._make_stage(int(512 * width_multiplier[3]), num_blocks[3], stride=2)
37 | self.gap = nn.AdaptiveAvgPool2d(output_size=1)
38 | self.linear = nn.Linear(int(512 * width_multiplier[3]), num_classes)
39 |
40 |
41 | def _make_stage(self, planes, num_blocks, stride):
42 | strides = [stride] + [1]*(num_blocks-1)
43 | blocks = []
44 | for stride in strides:
45 | cur_groups = self.override_groups_map.get(self.cur_layer_idx, 1)
46 | blocks.append(conv_bn_relu(in_channels=self.in_planes, out_channels=planes, kernel_size=3,
47 | stride=stride, padding=1, groups=cur_groups))
48 | self.in_planes = planes
49 | self.cur_layer_idx += 1
50 | return nn.Sequential(*blocks)
51 |
52 | def forward(self, x):
53 | out = self.stage0(x)
54 | out = self.stage1(out)
55 | out = self.stage2(out)
56 | out = self.stage3(out)
57 | out = self.stage4(out)
58 | out = self.gap(out)
59 | out = out.view(out.size(0), -1)
60 | out = self.linear(out)
61 | return out
62 |
63 |
64 | optional_groupwise_layers = [2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26]
65 | g2_map = {l: 2 for l in optional_groupwise_layers}
66 | g4_map = {l: 4 for l in optional_groupwise_layers}
67 |
68 | def create_RepVGG_A0(deploy=False):
69 | return RepVGG(num_blocks=[2, 4, 14, 1], num_classes=1000,
70 | width_multiplier=[0.75, 0.75, 0.75, 2.5], override_groups_map=None, deploy=deploy)
71 |
72 | def create_RepVGG_A1(deploy=False):
73 | return RepVGG(num_blocks=[2, 4, 14, 1], num_classes=1000,
74 | width_multiplier=[1, 1, 1, 2.5], override_groups_map=None, deploy=deploy)
75 |
76 | def create_RepVGG_A2(deploy=False):
77 | return RepVGG(num_blocks=[2, 4, 14, 1], num_classes=1000,
78 | width_multiplier=[1.5, 1.5, 1.5, 2.75], override_groups_map=None, deploy=deploy)
79 |
80 | def create_RepVGG_B0(deploy=False):
81 | return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000,
82 | width_multiplier=[1, 1, 1, 2.5], override_groups_map=None, deploy=deploy)
83 |
84 | def create_RepVGG_B1(deploy=False):
85 | return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000,
86 | width_multiplier=[2, 2, 2, 4], override_groups_map=None, deploy=deploy)
87 |
88 | def create_RepVGG_B1g2(deploy=False):
89 | return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000,
90 | width_multiplier=[2, 2, 2, 4], override_groups_map=g2_map, deploy=deploy)
91 |
92 | def create_RepVGG_B1g4(deploy=False):
93 | return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000,
94 | width_multiplier=[2, 2, 2, 4], override_groups_map=g4_map, deploy=deploy)
95 |
96 |
97 | def create_RepVGG_B2(deploy=False):
98 | return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000,
99 | width_multiplier=[2.5, 2.5, 2.5, 5], override_groups_map=None, deploy=deploy)
100 |
101 | def create_RepVGG_B2g2(deploy=False):
102 | return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000,
103 | width_multiplier=[2.5, 2.5, 2.5, 5], override_groups_map=g2_map, deploy=deploy)
104 |
105 | def create_RepVGG_B2g4(deploy=False):
106 | return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000,
107 | width_multiplier=[2.5, 2.5, 2.5, 5], override_groups_map=g4_map, deploy=deploy)
108 |
109 |
110 | def create_RepVGG_B3(deploy=False):
111 | return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000,
112 | width_multiplier=[3, 3, 3, 5], override_groups_map=None, deploy=deploy)
113 |
114 | def create_RepVGG_B3g2(deploy=False):
115 | return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000,
116 | width_multiplier=[3, 3, 3, 5], override_groups_map=g2_map, deploy=deploy)
117 |
118 | def create_RepVGG_B3g4(deploy=False):
119 | return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000,
120 | width_multiplier=[3, 3, 3, 5], override_groups_map=g4_map, deploy=deploy)
121 |
122 | def create_RepVGG_D2se(deploy=False):
123 | return RepVGG(num_blocks=[8, 14, 24, 1], num_classes=1000,
124 | width_multiplier=[2.5, 2.5, 2.5, 5], override_groups_map=None, deploy=deploy, use_se=True)
125 |
126 |
127 | func_dict = {
128 | 'RepVGG-A0': create_RepVGG_A0,
129 | 'RepVGG-A1': create_RepVGG_A1,
130 | 'RepVGG-A2': create_RepVGG_A2,
131 | 'RepVGG-B0': create_RepVGG_B0,
132 | 'RepVGG-B1': create_RepVGG_B1,
133 | 'RepVGG-B1g2': create_RepVGG_B1g2,
134 | 'RepVGG-B1g4': create_RepVGG_B1g4,
135 | 'RepVGG-B2': create_RepVGG_B2,
136 | 'RepVGG-B2g2': create_RepVGG_B2g2,
137 | 'RepVGG-B2g4': create_RepVGG_B2g4,
138 | 'RepVGG-B3': create_RepVGG_B3,
139 | 'RepVGG-B3g2': create_RepVGG_B3g2,
140 | 'RepVGG-B3g4': create_RepVGG_B3g4,
141 | 'RepVGG-D2se': create_RepVGG_D2se, # Updated at April 25, 2021. This is not reported in the CVPR paper.
142 | }
143 | def get_RepVGG_func_by_name(name):
144 | return func_dict[name]
145 |
146 |
147 |
148 | # Use this for converting a RepVGG model or a bigger model with RepVGG as its component
149 | # Use like this
150 | # model = create_RepVGG_A0(deploy=False)
151 | # train model or load weights
152 | # repvgg_model_convert(model, save_path='repvgg_deploy.pth')
153 | # If you want to preserve the original model, call with do_copy=True
154 |
155 | # ====================== for using RepVGG as the backbone of a bigger model, e.g., PSPNet, the pseudo code will be like
156 | # train_backbone = create_RepVGG_B2(deploy=False)
157 | # train_backbone.load_state_dict(torch.load('RepVGG-B2-train.pth'))
158 | # train_pspnet = build_pspnet(backbone=train_backbone)
159 | # segmentation_train(train_pspnet)
160 | # deploy_pspnet = repvgg_model_convert(train_pspnet)
161 | # segmentation_test(deploy_pspnet)
162 | # ===================== example_pspnet.py (in the original RepVGG repository) shows an example
163 |
164 | def repvgg_model_convert(model:torch.nn.Module, save_path=None, do_copy=True):
165 | if do_copy:
166 | model = copy.deepcopy(model)
167 | for module in model.modules():
168 | if hasattr(module, 'switch_to_deploy'):
169 | module.switch_to_deploy()
170 | if save_path is not None:
171 | torch.save(model.state_dict(), save_path)
172 | return model
173 |
--------------------------------------------------------------------------------
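Following the usage comments in the file above, an end-to-end sketch: the block implementation is picked with the global switch from convnet_utils before the model is created, and repvgg_model_convert then folds every block for deployment. Checkpoint paths are placeholders:

    import torch
    from convnet_utils import switch_conv_bn_impl, switch_deploy_flag
    from models.repvgg import create_RepVGG_A0, repvgg_model_convert

    switch_conv_bn_impl('OREPA_VGG')   # or 'RepVGG', 'DBB', 'base' (see convnet_utils.choose_blk)
    switch_deploy_flag(False)
    train_model = create_RepVGG_A0(deploy=False)
    # ... train, or load trained weights, e.g.:
    # train_model.load_state_dict(torch.load('RepVGG-A0-train.pth'))   # placeholder path
    deploy_model = repvgg_model_convert(train_model, save_path='RepVGG-A0-deploy.pth', do_copy=True)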
/models/resnet.py:
--------------------------------------------------------------------------------
1 |
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from convnet_utils import conv_bn, conv_bn_relu
5 | from basic import ConvBN
6 |
7 | class BasicBlock(nn.Module):
8 | expansion = 1
9 | def __init__(self, in_planes, planes, stride=1):
10 | super(BasicBlock, self).__init__()
11 | if stride != 1 or in_planes != self.expansion * planes:
12 | self.shortcut = conv_bn(in_channels=in_planes, out_channels=self.expansion * planes, kernel_size=1, stride=stride, assign_type=ConvBN)
13 | else:
14 | self.shortcut = nn.Identity()
15 | self.conv1 = conv_bn_relu(in_channels=in_planes, out_channels=planes, kernel_size=3, stride=stride, padding=1)
16 | self.conv2 = conv_bn(in_channels=planes, out_channels=self.expansion * planes, kernel_size=3, stride=1, padding=1)
17 |
18 | def forward(self, x):
19 | out = self.conv1(x)
20 | out = self.conv2(out)
21 | out = out + self.shortcut(x)
22 | out = F.relu(out)
23 | return out
24 |
25 |
26 | class Bottleneck(nn.Module):
27 | expansion = 4
28 | def __init__(self, in_planes, planes, stride=1):
29 | super(Bottleneck, self).__init__()
30 |
31 | if stride != 1 or in_planes != self.expansion*planes:
32 | self.shortcut = conv_bn(in_planes, self.expansion*planes, kernel_size=1, stride=stride)
33 | else:
34 | self.shortcut = nn.Identity()
35 |
36 | self.conv1 = conv_bn_relu(in_planes, planes, kernel_size=1)
37 | self.conv2 = conv_bn_relu(planes, planes, kernel_size=3, stride=stride, padding=1)
38 | self.conv3 = conv_bn(planes, self.expansion*planes, kernel_size=1)
39 |
40 | def forward(self, x):
41 | out = self.conv1(x)
42 | out = self.conv2(out)
43 | out = self.conv3(out)
44 | out += self.shortcut(x)
45 | out = F.relu(out)
46 | return out
47 |
48 |
49 | class ResNet(nn.Module):
50 | def __init__(self, block, num_blocks, num_classes=1000, width_multiplier=1):
51 | super(ResNet, self).__init__()
52 |
53 | self.in_planes = int(64 * width_multiplier)
54 | self.stage0 = nn.Sequential()
55 | self.stage0.add_module('conv1', conv_bn_relu(in_channels=3, out_channels=self.in_planes, kernel_size=7, stride=2, padding=3))
56 | self.stage0.add_module('maxpool', nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
57 | self.stage1 = self._make_stage(block, int(64 * width_multiplier), num_blocks[0], stride=1)
58 | self.stage2 = self._make_stage(block, int(128 * width_multiplier), num_blocks[1], stride=2)
59 | self.stage3 = self._make_stage(block, int(256 * width_multiplier), num_blocks[2], stride=2)
60 | self.stage4 = self._make_stage(block, int(512 * width_multiplier), num_blocks[3], stride=2)
61 | self.gap = nn.AdaptiveAvgPool2d(output_size=1)
62 | self.linear = nn.Linear(int(512*block.expansion*width_multiplier), num_classes)
63 |
64 | output_channels = int(512*block.expansion*width_multiplier)
65 |
66 | def _make_stage(self, block, planes, num_blocks, stride):
67 | strides = [stride] + [1]*(num_blocks-1)
68 | blocks = []
69 | for stride in strides:
70 | if block is Bottleneck:
71 | blocks.append(block(in_planes=self.in_planes, planes=int(planes), stride=stride))
72 | else:
73 | blocks.append(block(in_planes=self.in_planes, planes=int(planes), stride=stride))
74 | self.in_planes = int(planes * block.expansion)
75 | return nn.Sequential(*blocks)
76 |
77 | def forward(self, x):
78 | out = self.stage0(x)
79 | out = self.stage1(out)
80 | out = self.stage2(out)
81 | out = self.stage3(out)
82 | out = self.stage4(out)
83 | out = self.gap(out)
84 | out = out.view(out.size(0), -1)
85 | out = self.linear(out)
86 | return out
87 |
88 |
89 | def create_Res18():
90 | return ResNet(BasicBlock, [2,2,2,2], num_classes=1000, width_multiplier=1)
91 |
92 | def create_Res18_1d5x():
93 | return ResNet(BasicBlock, [2,2,2,2], num_classes=1000, width_multiplier=1.5)
94 |
95 | def create_Res18_2x():
96 | return ResNet(BasicBlock, [2,2,2,2], num_classes=1000, width_multiplier=2)
97 |
98 | def create_Res34():
99 | return ResNet(BasicBlock, [3,4,6,3], num_classes=1000, width_multiplier=1)
100 |
101 | def create_Res50():
102 | return ResNet(Bottleneck, [3,4,6,3], num_classes=1000, width_multiplier=1)
103 |
104 | def create_Res101():
105 | return ResNet(Bottleneck, [3,4,23,3], num_classes=1000, width_multiplier=1)
--------------------------------------------------------------------------------
/models/resnext.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | from convnet_utils import conv_bn, conv_bn_relu
4 |
5 |
6 | class Bottleneck(nn.Module):
7 | expansion = 4
8 | def __init__(self, in_planes, planes, stride=1, cardinality=32, width_per_group=4):
9 | super(Bottleneck, self).__init__()
10 |
11 | D = cardinality * int(planes*width_per_group/64)
12 |
13 | if stride != 1 or in_planes != self.expansion*planes:
14 | self.shortcut = conv_bn(in_planes, self.expansion*planes, kernel_size=1, stride=stride)
15 | else:
16 | self.shortcut = nn.Identity()
17 |
18 | self.conv1 = conv_bn_relu(in_planes, D, kernel_size=1)
19 | self.conv2 = conv_bn_relu(D, D, kernel_size=3, stride=stride, padding=1, groups=cardinality)
20 | self.conv3 = conv_bn(D, self.expansion*planes, kernel_size=1)
21 |
22 | def forward(self, x):
23 | out = self.conv1(x)
24 | out = self.conv2(out)
25 | out = self.conv3(out)
26 | out += self.shortcut(x)
27 | out = F.relu(out)
28 | return out
29 |
30 | class ResNeXt(nn.Module):
31 | def __init__(self, block, num_blocks, num_classes=1000, width_multiplier=1, cardinality=32, width_per_group=4):
32 | super(ResNeXt, self).__init__()
33 |
34 | self.in_planes = int(64 * width_multiplier)
35 | self.stage0 = nn.Sequential()
36 | self.stage0.add_module('conv1', conv_bn_relu(in_channels=3, out_channels=self.in_planes, kernel_size=7, stride=2, padding=3))
37 | self.stage0.add_module('maxpool', nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
38 | self.stage1 = self._make_stage(block, int(64 * width_multiplier), num_blocks[0], stride=1, cardinality=cardinality, width_per_group=width_per_group)
39 | self.stage2 = self._make_stage(block, int(128 * width_multiplier), num_blocks[1], stride=2, cardinality=cardinality, width_per_group=width_per_group)
40 | self.stage3 = self._make_stage(block, int(256 * width_multiplier), num_blocks[2], stride=2, cardinality=cardinality, width_per_group=width_per_group)
41 | self.stage4 = self._make_stage(block, int(512 * width_multiplier), num_blocks[3], stride=2, cardinality=cardinality, width_per_group=width_per_group)
42 | self.gap = nn.AdaptiveAvgPool2d(output_size=1)
43 | self.linear = nn.Linear(int(512*block.expansion*width_multiplier), num_classes)
44 |
45 | output_channels = int(512*block.expansion*width_multiplier)
46 |
47 | def _make_stage(self, block, planes, num_blocks, stride, cardinality, width_per_group):
48 | strides = [stride] + [1]*(num_blocks-1)
49 | blocks = []
50 | for stride in strides:
51 | blocks.append(block(in_planes=self.in_planes, planes=int(planes), stride=stride, cardinality=cardinality, width_per_group=width_per_group))
52 | self.in_planes = int(planes * block.expansion)
53 | return nn.Sequential(*blocks)
54 |
55 | def forward(self, x):
56 | out = self.stage0(x)
57 | out = self.stage1(out)
58 | out = self.stage2(out)
59 | out = self.stage3(out)
60 | out = self.stage4(out)
61 | out = self.gap(out)
62 | out = out.view(out.size(0), -1)
63 | out = self.linear(out)
64 | return out
65 |
66 |
67 | def create_Res50_32x4d():
68 | return ResNeXt(Bottleneck, [3,4,6,3], num_classes=1000, width_multiplier=1, cardinality=32, width_per_group=4)
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import time
4 | import torch
5 | import torch.backends.cudnn as cudnn
6 | import torchvision.datasets as datasets
7 | from utils import accuracy, ProgressMeter, AverageMeter, get_default_ImageNet_val_loader
8 | from convnet_utils import switch_deploy_flag, switch_conv_bn_impl, build_model
9 |
10 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Test')
11 | parser.add_argument('--data', metavar='DIR', type=str, help='path to dataset')
12 | parser.add_argument('mode', metavar='MODE', default='deploy', choices=['train', 'deploy'], help='train or deploy')
13 | parser.add_argument('weights', metavar='WEIGHTS', help='path to the weights file')
14 | parser.add_argument('-a', '--arch', metavar='ARCH', default='ResNet-18')
15 | parser.add_argument('-t', '--blocktype', metavar='BLK', default='OREPA')
16 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
17 | help='number of data loading workers (default: 4)')
18 | parser.add_argument('-b', '--val-batch-size', default=100, type=int,
19 | metavar='N',
20 | help='mini-batch size (default: 100) for test')
21 |
22 | def test():
23 | args = parser.parse_args()
24 |     #args.data = '/disk1/humu.hm/ImageNet/ILSVRC2015/Data/CLS-LOC/'  # hard-coded local path; pass --data instead
25 |
26 | switch_deploy_flag(args.mode == 'deploy')
27 | switch_conv_bn_impl(args.blocktype)
28 | model = build_model(args.arch)
29 |
30 | if not torch.cuda.is_available():
31 | print('using CPU, this will be slow')
32 | use_gpu = False
33 | else:
34 | model = model.cuda()
35 | use_gpu = True
36 |
37 | # define loss function (criterion) and optimizer
38 | criterion = torch.nn.CrossEntropyLoss().cuda()
39 |
40 | if 'hdf5' in args.weights:
41 | from utils import model_load_hdf5
42 | model_load_hdf5(model, args.weights)
43 | elif os.path.isfile(args.weights):
44 | print("=> loading checkpoint '{}'".format(args.weights))
45 | checkpoint = torch.load(args.weights)
46 | if 'state_dict' in checkpoint:
47 | checkpoint = checkpoint['state_dict']
48 | ckpt = {k.replace('module.', ''):v for k,v in checkpoint.items()} # strip the names
49 | model.load_state_dict(ckpt)
50 | else:
51 | print("=> no checkpoint found at '{}'".format(args.weights))
52 |
53 |
54 | cudnn.benchmark = True
55 |
56 | # Data loading code
57 | val_loader = get_default_ImageNet_val_loader(args)
58 | validate(val_loader, model, criterion, use_gpu)
59 |
60 |
61 | def validate(val_loader, model, criterion, use_gpu):
62 | batch_time = AverageMeter('Time', ':6.3f', warm=True)
63 | losses = AverageMeter('Loss', ':.4e')
64 | top1 = AverageMeter('Acc@1', ':6.2f')
65 | top5 = AverageMeter('Acc@5', ':6.2f')
66 | progress = ProgressMeter(
67 | len(val_loader),
68 | [batch_time, losses, top1, top5],
69 | prefix='Test: ')
70 |
71 | # switch to evaluate mode
72 | model.eval()
73 |
74 | with torch.no_grad():
75 | #torch.cuda.synchronize()
76 | #end = time.time()
77 | for i, (images, target) in enumerate(val_loader):
78 | if use_gpu:
79 | images = images.cuda(non_blocking=True)
80 | target = target.cuda(non_blocking=True)
81 |
82 | # compute output
83 | torch.cuda.synchronize()
84 | end = time.time()
85 |
86 | output = model(images)
87 |
88 | # measure elapsed time
89 | torch.cuda.synchronize()
90 | batch_time.update(time.time() - end)
91 |
92 | loss = criterion(output, target)
93 |
94 | # measure accuracy and record loss
95 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
96 | losses.update(loss.item(), images.size(0))
97 | top1.update(acc1[0], images.size(0))
98 | top5.update(acc5[0], images.size(0))
99 |
100 | if i % 10 == 0:
101 | progress.display(i)
102 |
103 | print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
104 | .format(top1=top1, top5=top5))
105 |
106 | return top1.avg
107 |
108 |
109 |
110 |
111 | if __name__ == '__main__':
112 | test()
--------------------------------------------------------------------------------
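For reference, test.py above takes the structure mode first and the checkpoint path second; the mode, architecture and block type must match how the checkpoint was produced, and the paths below are placeholders:

    python test.py train ResNet-18_OREPA_train.pth -a ResNet-18 -t OREPA --data /path/to/imagenet
    python test.py deploy ResNet-18_OREPA_deploy.pth -a ResNet-18 -t OREPA --data /path/to/imagenet

'train' evaluates the multi-branch structure directly, while 'deploy' expects a checkpoint already folded by convert.py.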
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import random
4 | import shutil
5 | import time
6 | import warnings
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.parallel
11 | import torch.backends.cudnn as cudnn
12 | import torch.distributed as dist
13 | import torch.optim
14 | import torch.multiprocessing as mp
15 | import torch.utils.data
16 | import torch.utils.data.distributed
17 | from torch.optim.lr_scheduler import CosineAnnealingLR
18 | from utils import WarmupCosineAnnealingLR
19 | from utils import AverageMeter, accuracy, ProgressMeter, get_default_ImageNet_val_loader, get_default_ImageNet_train_sampler_loader, log_msg
20 |
21 | IMAGENET_TRAINSET_SIZE = 1281167
22 |
23 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
24 | parser.add_argument('--data', metavar='DIR', type=str, help='path to dataset')
25 |
26 | parser.add_argument('-a', '--arch', metavar='ARCH', default='ResNet-18')
27 | parser.add_argument('-t', '--blocktype', metavar='BLK', default='base')
28 |
29 | parser.add_argument('-j', '--workers', default=64, type=int, metavar='N',
30 |                     help='number of data loading workers (default: 64)')
31 | parser.add_argument('--epochs', default=120, type=int, metavar='N',
32 | help='number of total epochs to run')
33 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
34 | help='manual epoch number (useful on restarts)')
35 | parser.add_argument('-b', '--batch-size', default=256, type=int,
36 | metavar='N',
37 | help='mini-batch size (default: 256), this is the total '
38 | 'batch size of all GPUs on the current node when '
39 |                          'using Data Parallel or Distributed Data Parallel')
40 | parser.add_argument('--val-batch-size', default=100, type=int, metavar='V',
41 | help='validation batch size')
42 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
43 | metavar='LR', help='initial learning rate', dest='lr')
44 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
45 | help='momentum')
46 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
47 | metavar='W', help='weight decay (default: 1e-4)',
48 | dest='weight_decay')
49 | parser.add_argument('-p', '--print-freq', default=10, type=int,
50 | metavar='N', help='print frequency (default: 10)')
51 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
52 | help='path to latest checkpoint (default: none)')
53 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
54 | help='evaluate model on validation set')
55 | parser.add_argument('--pretrained', dest='pretrained', action='store_true',
56 | help='use pre-trained model')
57 | parser.add_argument('--world-size', default=1, type=int,
58 | help='number of nodes for distributed training')
59 | parser.add_argument('--rank', default=0, type=int,
60 | help='node rank for distributed training')
61 | parser.add_argument('--dist-url', default='20333', type=str,
62 |                     help='port used to set up distributed training (tcp://127.0.0.1: is prepended in main)')
63 | parser.add_argument('--dist-backend', default='nccl', type=str,
64 | help='distributed backend')
65 | parser.add_argument('--seed', default=None, type=int,
66 | help='seed for initializing training. ')
67 | parser.add_argument('--gpu', default=None, type=int,
68 | help='GPU id to use.')
69 | parser.add_argument('--multiprocessing-distributed', action='store_true',
70 | help='Use multi-processing distributed training to launch '
71 | 'N processes per node, which has N GPUs. This is the '
72 | 'fastest way to use PyTorch for either single node or '
73 | 'multi node data parallel training')
74 | parser.add_argument('--custwd', dest='custwd', action='store_true',
75 | help='Use custom weight decay. It improves the accuracy and makes quantization easier.')
76 | parser.add_argument('--tag', default='', type=str,
77 | help='the tag for identifying the log and model files. Just a string.')
78 |
79 | best_acc1 = 0
80 |
81 | def sgd_optimizer(model, lr, momentum, weight_decay, use_custwd):
82 | params = []
83 | for key, value in model.named_parameters():
84 | if not value.requires_grad:
85 | continue
86 | apply_weight_decay = weight_decay
87 | apply_lr = lr
88 |
89 | if (use_custwd and ('blk' in key)) or (use_custwd and ('rbr_dense' in key or 'rbr_1x1' in key)) or 'bias' in key or 'bn' in key or 'prior' in key:
90 | apply_weight_decay = 0
91 | print('set weight decay=0 for {}'.format(key))
92 | if 'bias' in key:
93 | apply_lr = 2 * lr # Just a Caffe-style common practice. Made no difference.
94 | params += [{'params': [value], 'lr': apply_lr, 'weight_decay': apply_weight_decay}]
95 | optimizer = torch.optim.SGD(params, lr, momentum=momentum)
96 | return optimizer
97 |
98 | def adam_optimizer(model):
99 | params = []
100 | for key, value in model.named_parameters():
101 | if not value.requires_grad or not 'prior' in key:
102 | continue
103 | params += [{'params': [value], 'lr': 1e-4, 'betas': (0.5, 0.999)}]
104 | print('adam opt for {}'.format(key))
105 | optimizer = torch.optim.Adam(params, 1e-4)
106 | return optimizer
107 |
108 | def main():
109 | args = parser.parse_args()
110 |
111 | args.multiprocessing_distributed = True
112 | args.dist_backend = 'nccl'
113 | #args.data = '/gruntdata/humu.hm/ILSVRC2015/Data/CLS-LOC/'
114 |     #args.data = '/disk1/humu.hm/ImageNet/ILSVRC2015/Data/CLS-LOC/'  # hard-coded local path; pass --data instead
115 | args.dist_url = 'tcp://127.0.0.1:' + args.dist_url
116 |
117 | if args.seed is not None:
118 | random.seed(args.seed)
119 | torch.manual_seed(args.seed)
120 | cudnn.deterministic = True
121 | warnings.warn('You have chosen to seed training. '
122 | 'This will turn on the CUDNN deterministic setting, '
123 | 'which can slow down your training considerably! '
124 | 'You may see unexpected behavior when restarting '
125 | 'from checkpoints.')
126 |
127 | if args.gpu is not None:
128 | warnings.warn('You have chosen a specific GPU. This will completely '
129 | 'disable data parallelism.')
130 |
131 | if args.dist_url == "env://" and args.world_size == -1:
132 | args.world_size = int(os.environ["WORLD_SIZE"])
133 |
134 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed
135 |
136 | ngpus_per_node = torch.cuda.device_count()
137 | if args.multiprocessing_distributed:
138 | # Since we have ngpus_per_node processes per node, the total world_size
139 | # needs to be adjusted accordingly
140 | args.world_size = ngpus_per_node * args.world_size
141 | # Use torch.multiprocessing.spawn to launch distributed processes: the
142 | # main_worker process function
143 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
144 | else:
145 | # Simply call main_worker function
146 | main_worker(args.gpu, ngpus_per_node, args)
147 |
148 |
149 | def main_worker(gpu, ngpus_per_node, args):
150 | global best_acc1
151 | args.gpu = gpu
152 | log_file = 'results/train_{}_{}_{}_exp.txt'.format(args.arch, args.blocktype, args.tag)
153 |
154 | if args.gpu is not None:
155 | print("Use GPU: {} for training".format(args.gpu))
156 |
157 | if args.distributed:
158 | if args.dist_url == "env://" and args.rank == -1:
159 | args.rank = int(os.environ["RANK"])
160 | if args.multiprocessing_distributed:
161 | # For multiprocessing distributed training, rank needs to be the
162 | # global rank among all the processes
163 | args.rank = args.rank * ngpus_per_node + gpu
164 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
165 | world_size=args.world_size, rank=args.rank)
166 |
167 | from convnet_utils import switch_deploy_flag, switch_conv_bn_impl, build_model
168 | switch_deploy_flag(False)
169 | switch_conv_bn_impl(args.blocktype)
170 | model = build_model(args.arch)
171 |
172 | is_main = not args.multiprocessing_distributed or (
173 | args.multiprocessing_distributed and args.rank % ngpus_per_node == 0)
174 |
175 | if is_main:
176 | for n, p in model.named_parameters():
177 | print(n, p.size())
178 | for n, p in model.named_buffers():
179 | print(n, p.size())
180 | log_msg('epochs {}, lr {}, weight_decay {}'.format(args.epochs, args.lr, args.weight_decay), log_file)
181 |
182 | if not torch.cuda.is_available():
183 | print('using CPU, this will be slow')
184 | elif args.distributed:
185 | # For multiprocessing distributed, DistributedDataParallel constructor
186 | # should always set the single device scope, otherwise,
187 | # DistributedDataParallel will use all available devices.
188 | if args.gpu is not None:
189 | torch.cuda.set_device(args.gpu)
190 | model.cuda(args.gpu)
191 | # When using a single GPU per process and per
192 | # DistributedDataParallel, we need to divide the batch size
193 | # ourselves based on the total number of GPUs we have
194 | args.batch_size = int(args.batch_size / ngpus_per_node)
195 | args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
196 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
197 | else:
198 | model.cuda()
199 | # DistributedDataParallel will divide and allocate batch_size to all
200 | # available GPUs if device_ids are not set
201 | model = torch.nn.parallel.DistributedDataParallel(model)
202 | elif args.gpu is not None:
203 | torch.cuda.set_device(args.gpu)
204 | model = model.cuda(args.gpu)
205 | else:
206 | # DataParallel will divide and allocate batch_size to all available GPUs
207 | model = torch.nn.DataParallel(model).cuda()
208 |
209 |
210 | # define loss function (criterion) and optimizer
211 | criterion = nn.CrossEntropyLoss().cuda(args.gpu)
212 |
213 | optimizer = sgd_optimizer(model, args.lr, args.momentum, args.weight_decay, args.custwd)
214 | #arch_optimizer = adam_optimizer(model)
215 |
216 | #lr_scheduler = CosineAnnealingLR(optimizer=optimizer, T_max=args.epochs * IMAGENET_TRAINSET_SIZE // args.batch_size // ngpus_per_node)
217 | lr_scheduler = WarmupCosineAnnealingLR(optimizer=optimizer, T_cosine_max=args.epochs * IMAGENET_TRAINSET_SIZE // args.batch_size // ngpus_per_node, warmup=args.epochs/24)
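    # Note: lr_scheduler.step() is called once per iteration in train(), so T_cosine_max
    # above is intended to be the total number of optimizer steps per process for the
    # whole run (one step per batch), and `warmup` is measured in the same per-step units.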
218 |
219 | # optionally resume from a checkpoint
220 | if args.resume:
221 | if os.path.isfile(args.resume):
222 | print("=> loading checkpoint '{}'".format(args.resume))
223 | if args.gpu is None:
224 | checkpoint = torch.load(args.resume)
225 | else:
226 | # Map model to be loaded to specified single gpu.
227 | loc = 'cuda:{}'.format(args.gpu)
228 | checkpoint = torch.load(args.resume, map_location=loc)
229 | args.start_epoch = checkpoint['epoch']
230 | best_acc1 = checkpoint['best_acc1']
231 | if args.gpu is not None:
232 | # best_acc1 may be from a checkpoint from a different GPU
233 | best_acc1 = best_acc1.to(args.gpu)
234 | model.load_state_dict(checkpoint['state_dict'])
235 | optimizer.load_state_dict(checkpoint['optimizer'])
236 | lr_scheduler.load_state_dict(checkpoint['scheduler'])
237 | print("=> loaded checkpoint '{}' (epoch {})"
238 | .format(args.resume, checkpoint['epoch']))
239 | else:
240 | print("=> no checkpoint found at '{}'".format(args.resume))
241 |
242 | cudnn.benchmark = True
243 |
244 | train_sampler, train_loader = get_default_ImageNet_train_sampler_loader(args)
245 | val_loader = get_default_ImageNet_val_loader(args)
246 |
247 | if args.evaluate:
248 | validate(val_loader, model, criterion, args)
249 | return
250 |
251 | for epoch in range(args.start_epoch, args.epochs):
252 | if args.distributed:
253 | train_sampler.set_epoch(epoch)
254 | # adjust_learning_rate(optimizer, epoch, args)
255 |
256 | # train for one epoch
257 | train(train_loader, model, criterion, optimizer, epoch, args, lr_scheduler, is_main=is_main)
258 |
259 | if is_main:
260 | # evaluate on validation set
261 | acc1 = validate(val_loader, model, criterion, args)
262 | msg = '{}, epoch {}, acc {}'.format(args.arch, epoch, acc1)
263 | log_msg(msg, log_file)
264 |
265 | # remember best acc@1 and save checkpoint
266 | is_best = acc1 > best_acc1
267 | best_acc1 = max(acc1, best_acc1)
268 |
269 | save_checkpoint({
270 | 'epoch': epoch + 1,
271 | 'arch': args.arch,
272 | 'state_dict': model.state_dict(),
273 | 'best_acc1': best_acc1,
274 |                 'optimizer': optimizer.state_dict(),
275 | 'scheduler': lr_scheduler.state_dict(),
276 | }, is_best, filename = '/disk1/humu.hm/ckpt/{}_{}_{}.pth.tar'.format(args.arch, args.blocktype, args.tag),
277 | best_filename='/disk1/humu.hm/ckpt/{}_{}_{}_best.pth.tar'.format(args.arch, args.blocktype, args.tag))
278 |
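# A minimal sketch of loading a saved checkpoint for evaluation (illustrative; the arch,
# blocktype and filename below are placeholders, not values confirmed by this script).
# switch_deploy_flag / switch_conv_bn_impl / build_model come from convnet_utils as
# imported in main_worker, and load_checkpoint (utils.py) strips any 'module.' prefixes
# that DataParallel / DistributedDataParallel add to parameter names.
#
#   from convnet_utils import switch_deploy_flag, switch_conv_bn_impl, build_model
#   from utils import load_checkpoint
#   switch_deploy_flag(False)                    # training-time (multi-branch) structure
#   switch_conv_bn_impl('DBB')                   # assumed blocktype value
#   model = build_model('ResNet-18')             # assumed arch value
#   load_checkpoint(model, 'path/to/checkpoint_best.pth.tar')   # placeholder path
#   model.eval()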
279 |
280 | def train(train_loader, model, criterion, optimizer, epoch, args, lr_scheduler, is_main):
281 | batch_time = AverageMeter('Time', ':6.3f', warm=True)
282 | data_time = AverageMeter('Data', ':6.3f')
283 | losses = AverageMeter('Loss', ':.4e')
284 | top1 = AverageMeter('Acc@1', ':6.2f')
285 | top5 = AverageMeter('Acc@5', ':6.2f')
286 | progress = ProgressMeter(
287 | len(train_loader),
288 | [batch_time, data_time, losses, top1, top5, ],
289 | prefix="Epoch: [{}]".format(epoch))
290 |
291 | # switch to train mode
292 | model.train()
293 |
294 | torch.cuda.synchronize()
295 | end = time.time()
296 | for i, (images, target) in enumerate(train_loader):
297 | # measure data loading time
298 | torch.cuda.synchronize()
299 | data_time.update(time.time() - end)
300 |
301 | images = images.cuda(args.gpu, non_blocking=True)
302 | target = target.cuda(args.gpu, non_blocking=True)
303 |
304 | # compute output
305 | output = model(images)
306 | loss = criterion(output, target)
307 |
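        # with --custwd, blocks that implement get_custom_L2() contribute their own
        # L2 penalty to the loss, scaled by weight_decay / 2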
308 | if args.custwd:
309 | for module in model.modules():
310 | if hasattr(module, 'get_custom_L2'):
311 | loss += args.weight_decay * 0.5 * module.get_custom_L2()
312 | #if hasattr(module, 'weight_gen'):
313 | #loss += args.weight_decay * 0.5 * ((module.weight_gen()**2).sum())
314 | #if hasattr(module, 'weight'):
315 | #loss += args.weight_decay * 0.5 * ((module.weight**2).sum())
316 |
317 | # measure accuracy and record loss
318 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
319 | losses.update(loss.item(), images.size(0))
320 | top1.update(acc1[0], images.size(0))
321 | top5.update(acc5[0], images.size(0))
322 |
323 | # compute gradient and do SGD step
324 | optimizer.zero_grad()
325 | loss.backward()
326 | optimizer.step()
327 |
328 | # measure elapsed time
329 | torch.cuda.synchronize()
330 | batch_time.update(time.time() - end)
331 | end = time.time()
332 |
333 | lr_scheduler.step()
334 |
335 | if is_main and i % args.print_freq == 0:
336 | progress.display(i)
337 | if is_main and i % 1000 == 0:
338 | print('cur lr: ', lr_scheduler.get_lr()[0])
339 |
340 |
341 |
342 | def validate(val_loader, model, criterion, args):
343 | batch_time = AverageMeter('Time', ':6.3f')
344 | losses = AverageMeter('Loss', ':.4e')
345 | top1 = AverageMeter('Acc@1', ':6.2f')
346 | top5 = AverageMeter('Acc@5', ':6.2f')
347 | progress = ProgressMeter(
348 | len(val_loader),
349 | [batch_time, losses, top1, top5],
350 | prefix='Test: ')
351 |
352 | # switch to evaluate mode
353 | model.eval()
354 |
355 | with torch.no_grad():
356 | end = time.time()
357 | for i, (images, target) in enumerate(val_loader):
358 | if args.gpu is not None:
359 | images = images.cuda(args.gpu, non_blocking=True)
360 | if torch.cuda.is_available():
361 | target = target.cuda(args.gpu, non_blocking=True)
362 |
363 | # compute output
364 | output = model(images)
365 | loss = criterion(output, target)
366 |
367 | # measure accuracy and record loss
368 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
369 | losses.update(loss.item(), images.size(0))
370 | top1.update(acc1[0], images.size(0))
371 | top5.update(acc5[0], images.size(0))
372 |
373 | # measure elapsed time
374 | batch_time.update(time.time() - end)
375 | end = time.time()
376 |
377 | if i % args.print_freq == 0:
378 | progress.display(i)
379 |
380 | print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
381 | .format(top1=top1, top5=top5))
382 |
383 | return top1.avg
384 |
385 |
386 | def save_checkpoint(state, is_best, filename, best_filename):
387 | torch.save(state, filename)
388 | if is_best:
389 | shutil.copyfile(filename, best_filename)
390 |
391 |
392 | if __name__ == '__main__':
393 | main()
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | import torchvision.datasets as datasets
4 | import os
5 | import torchvision.transforms as transforms
6 | import PIL
7 |
8 | from PIL import ImageFile
9 | ImageFile.LOAD_TRUNCATED_IMAGES = True
10 |
11 | class AverageMeter(object):
12 | """Computes and stores the average and current value"""
13 | def __init__(self, name, fmt=':f', warm=False):
14 | self.name = name
15 | self.fmt = fmt
16 | self.reset()
17 |
18 | self.time_thres = -1
19 | if warm:
20 | self.time_thres = 20
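        # when warm=True the first 20 updates are ignored, so that e.g. GPU warm-up
        # iterations do not skew the averaged batch time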
21 |
22 | def reset(self):
23 | self.val = 0
24 | self.avg = 0
25 | self.sum = 0
26 | self.count = 0
27 |
28 | def update(self, val, n=1):
29 | if self.time_thres > 0:
30 | self.time_thres -= 1
31 | return
32 |
33 | self.val = val
34 | self.sum += val * n
35 | self.count += n
36 | self.avg = self.sum / self.count
37 |
38 | def __str__(self):
39 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
40 | return fmtstr.format(**self.__dict__)
41 |
42 |
43 | class ProgressMeter(object):
44 | def __init__(self, num_batches, meters, prefix=""):
45 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
46 | self.meters = meters
47 | self.prefix = prefix
48 |
49 | def display(self, batch):
50 | entries = [self.prefix + self.batch_fmtstr.format(batch)]
51 | entries += [str(meter) for meter in self.meters]
52 | print('\t'.join(entries))
53 |
54 | def _get_batch_fmtstr(self, num_batches):
55 |         num_digits = len(str(num_batches))
56 | fmt = '{:' + str(num_digits) + 'd}'
57 | return '[' + fmt + '/' + fmt.format(num_batches) + ']'
58 |
59 |
60 | def accuracy(output, target, topk=(1,)):
61 | """Computes the accuracy over the k top predictions for the specified values of k"""
62 | with torch.no_grad():
63 | maxk = max(topk)
64 | batch_size = target.size(0)
65 |
66 | _, pred = output.topk(maxk, 1, True, True)
67 | pred = pred.t()
68 | correct = pred.eq(target.view(1, -1).expand_as(pred))
69 |
70 | res = []
71 | for k in topk:
72 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
73 | res.append(correct_k.mul_(100.0 / batch_size))
74 | return res
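# Example (illustrative): for logits of shape (batch, num_classes) and integer class
# targets, accuracy(output, target, topk=(1, 5)) returns [top1, top5] as one-element
# tensors expressed in percent, e.g. tensor([75.]) means 75% of the batch was correct.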
75 |
76 | def load_checkpoint(model, ckpt_path):
77 | checkpoint = torch.load(ckpt_path)
78 | if 'state_dict' in checkpoint:
79 | checkpoint = checkpoint['state_dict']
80 | ckpt = {}
81 | for k, v in checkpoint.items():
82 | if k.startswith('module.'):
83 | ckpt[k[7:]] = v
84 | else:
85 | ckpt[k] = v
86 | model.load_state_dict(ckpt)
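# load_checkpoint accepts either a raw state_dict or a training checkpoint that wraps it
# under 'state_dict', and drops the 'module.' prefix that (Distributed)DataParallel adds,
# so weights saved from a wrapped model load into an unwrapped one.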
87 |
88 | def read_hdf5(file_path):
89 | import h5py
90 | import numpy as np
91 | result = {}
92 | with h5py.File(file_path, 'r') as f:
93 | for k in f.keys():
94 | value = np.asarray(f[k])
95 | result[str(k).replace('+', '/')] = value
96 | print('read {} arrays from {}'.format(len(result), file_path))
97 |     # the file is closed automatically when the with-block exits
98 | return result
99 |
100 | def model_load_hdf5(model: torch.nn.Module, hdf5_path, ignore_keys='stage0.'):
101 | weights_dict = read_hdf5(hdf5_path)
102 | for name, param in model.named_parameters():
103 | print('load param: ', name, param.size())
104 | if name in weights_dict:
105 | np_value = weights_dict[name]
106 | else:
107 | np_value = weights_dict[name.replace(ignore_keys, '')]
108 | value = torch.from_numpy(np_value).float()
109 | assert tuple(value.size()) == tuple(param.size())
110 | param.data = value
111 | for name, param in model.named_buffers():
112 | print('load buffer: ', name, param.size())
113 | if name in weights_dict:
114 | np_value = weights_dict[name]
115 | else:
116 | np_value = weights_dict[name.replace(ignore_keys, '')]
117 | value = torch.from_numpy(np_value).float()
118 | assert tuple(value.size()) == tuple(param.size())
119 | param.data = value
120 |
121 | class WarmupCosineAnnealingLR(torch.optim.lr_scheduler._LRScheduler):
122 |
123 | def __init__(self, optimizer, T_cosine_max, eta_min=0, last_epoch=-1, warmup=0):
124 | self.eta_min = eta_min
125 | self.T_cosine_max = T_cosine_max
126 | self.warmup = warmup
127 | super(WarmupCosineAnnealingLR, self).__init__(optimizer, last_epoch)
128 |
129 | def get_lr(self):
130 | if self.last_epoch < self.warmup:
131 | return [self.last_epoch / self.warmup * base_lr for base_lr in self.base_lrs]
132 | else:
133 | return [self.eta_min + (base_lr - self.eta_min) *
134 | (1 + math.cos(math.pi * (self.last_epoch - self.warmup) / (self.T_cosine_max - self.warmup))) / 2
135 | for base_lr in self.base_lrs]
136 |
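# A minimal usage sketch of WarmupCosineAnnealingLR (illustrative; the model, base lr
# and step counts below are assumptions, not values used by the training script).
# The learning rate ramps linearly from 0 over the first `warmup` steps, then follows
# a half-cosine from base_lr down to eta_min over the remaining T_cosine_max - warmup
# steps; step() is expected once per iteration.
def _warmup_cosine_demo():
    model = torch.nn.Linear(8, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scheduler = WarmupCosineAnnealingLR(optimizer, T_cosine_max=1000, warmup=50)
    lrs = []
    for _ in range(100):
        optimizer.step()        # a real loop would compute a loss and call backward() first
        scheduler.step()
        lrs.append(scheduler.get_lr()[0])
    return lrs                  # rises to ~0.1 by step 50, then decays along the cosine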
137 |
138 | def log_msg(message, log_file):
139 | print(message)
140 | with open(log_file, 'a') as f:
141 | print(message, file=f)
142 |
143 | def get_ImageNet_train_dataset(args, trans):
144 | traindir = os.path.join(args.data, 'train')
145 | train_dataset = datasets.ImageFolder(traindir, trans)
146 | return train_dataset
147 |
148 |
149 | def get_ImageNet_val_dataset(args, trans):
150 |     valdir = os.path.join(args.data, 'val')
151 |     val_dataset = datasets.ImageFolder(valdir, trans)
152 | return val_dataset
153 |
154 |
155 | def get_default_train_trans(args):
156 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
157 | std=[0.229, 0.224, 0.225])
158 | if (not hasattr(args, 'resolution')) or args.resolution == 224:
159 | trans = transforms.Compose([
160 | transforms.RandomResizedCrop(224),
161 | transforms.RandomHorizontalFlip(),
162 | transforms.ToTensor(),
163 | normalize])
164 | else:
165 | raise ValueError('Not yet implemented.')
166 | return trans
167 |
168 |
169 | def get_default_val_trans(args):
170 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
171 | std=[0.229, 0.224, 0.225])
172 | if (not hasattr(args, 'resolution')) or args.resolution == 224:
173 | trans = transforms.Compose([
174 | transforms.Resize(256),
175 | transforms.CenterCrop(224),
176 | transforms.ToTensor(),
177 | normalize])
178 | else:
179 | trans = transforms.Compose([
180 | transforms.Resize(args.resolution, interpolation=PIL.Image.BILINEAR),
181 | transforms.CenterCrop(args.resolution),
182 | transforms.ToTensor(),
183 | normalize,
184 | ])
185 | return trans
186 |
187 |
188 | def get_default_ImageNet_train_sampler_loader(args):
189 | train_trans = get_default_train_trans(args)
190 | train_dataset = get_ImageNet_train_dataset(args, train_trans)
191 | if args.distributed:
192 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
193 | else:
194 | train_sampler = None
195 | train_loader = torch.utils.data.DataLoader(
196 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
197 | num_workers=args.workers, pin_memory=True, sampler=train_sampler)
198 | return train_sampler, train_loader
199 |
200 |
201 | def get_default_ImageNet_val_loader(args):
202 | val_trans = get_default_val_trans(args)
203 | val_dataset = get_ImageNet_val_dataset(args, val_trans)
204 | val_loader = torch.utils.data.DataLoader(
205 | val_dataset,
206 | batch_size=args.val_batch_size, shuffle=False,
207 | num_workers=args.workers, pin_memory=True)
208 | return val_loader
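# A minimal usage sketch (illustrative; the path and sizes below are assumptions).
# The loader helpers only read a handful of attributes from `args`, so they can be
# driven by a plain Namespace outside of the training script:
#
#   from argparse import Namespace
#   args = Namespace(data='/path/to/imagenet', batch_size=256, val_batch_size=100,
#                    workers=8, distributed=False, resolution=224)
#   train_sampler, train_loader = get_default_ImageNet_train_sampler_loader(args)
#   val_loader = get_default_ImageNet_val_loader(args)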
--------------------------------------------------------------------------------