├── LICENSE
├── README.md
└── algnet.py


/LICENSE:
--------------------------------------------------------------------------------
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format.
      We also recommend that a file or class name and description of
      purpose be included on the same "printed page" as the copyright
      notice for easier identification within third-party archives.

         Copyright [yyyy] [name of copyright owner]

         Licensed under the Apache License, Version 2.0 (the "License");
         you may not use this file except in compliance with the License.
         You may obtain a copy of the License at

             http://www.apache.org/licenses/LICENSE-2.0

         Unless required by applicable law or agreed to in writing, software
         distributed under the License is distributed on an "AS IS" BASIS,
         WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         See the License for the specific language governing permissions and
         limitations under the License.

--------------------------------------------------------------------------------


/README.md:
--------------------------------------------------------------------------------
# ALGNet
Learning Enriched Features via Selective State Spaces Model for Efficient Image Deblurring


## Our code will be released after the paper is published

Since we are preparing to extend the paper into a journal version, the full code will be released together with the journal paper once it is accepted.
Our core code uses the SSM module without modification, so it can be reimplemented by following our network architecture diagram.


## Quick Run

Pre-trained models for testing:
[Google Drive](https://drive.google.com/drive/folders/1WOYuuvGCDOJWo0U6PizE4780EmkxbDya?usp=sharing)


Visual results (ALGNet-32, trained only on GoPro):
[Google Drive](https://drive.google.com/drive/folders/1auM3j5Yx2HEKuDUlIJWDDlirvjnv_RvB?usp=sharing)

## Citations
If our code helps your research or work, please consider citing our paper.
The following is a BibTeX reference:

```
@inproceedings{gao2024learning,
  title={Learning Enriched Features via Selective State Spaces Model for Efficient Image Deblurring},
  author={Hu Gao and Bowen Ma and Ying Zhang and Jingfan Yang and Jing Yang and Depeng Dang},
  booktitle={ACM Multimedia 2024},
  year={2024},
}
```


## Contact
If you have any questions, please contact two_bits@163.com.

--------------------------------------------------------------------------------


/algnet.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import time
from torch.cuda.amp import autocast
# LayerNorm2d and MySequential come from utils/arch_utils.py, which is not included in this repository.
from utils.arch_utils import LayerNorm2d, MySequential
# Mamba is only needed by the unreleased SSM body of Glayer.
from mamba_ssm import Mamba


class SimpleGate(nn.Module):
    """Splits the channels in half and multiplies the two halves element-wise."""
    def forward(self, x):
        x1, x2 = x.chunk(2, dim=1)
        return x1 * x2


class Glayer(nn.Module):
    """Global branch built on a selective state space model.

    The SSM computation has not been released (see README); only the interface
    is kept here. As a placeholder, forward() returns its input unchanged so
    the rest of the network can be built and run.
    """
    def __init__(self, dim, d_state=32, d_conv=4, expand=2):
        super().__init__()
        self.dim = dim
        self.norm = nn.LayerNorm(dim)

    @autocast(enabled=False)
    def forward(self, x):
        # This branch runs in full precision even under mixed-precision training.
        if x.dtype == torch.float16:
            x = x.type(torch.float32)
        B, C = x.shape[:2]

        # Placeholder: the original SSM computation is not released.
        out = x
        return out


class Llayer(nn.Module):
    """Local branch. Its body has not been released; it currently acts as an identity."""
    def __init__(self, c, DW_Expand=2):
        super().__init__()

    def forward(self, x):
        return x


class BasicBlock(nn.Module):
    def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.):
        super().__init__()
        self.sg = SimpleGate()

        dw_channel = c * DW_Expand
        self.conv1 = nn.Conv2d(in_channels=c, out_channels=dw_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
        self.conv2 = nn.Conv2d(in_channels=dw_channel, out_channels=dw_channel, kernel_size=3, padding=1,
                               stride=1, groups=dw_channel, bias=True)

        # Note: conv3_2 and gamma are defined but not used in forward().
        self.conv3_2 = nn.Conv2d(in_channels=c*2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
        # Simplified Channel Attention

        self.norm1 = LayerNorm2d(c)
        self.norm2 = LayerNorm2d(c)

        self.dropout1 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity()
        self.dropout2 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity()

        # Feed-forward branch: expand to FFN_Expand * c, gate (halves the channels), project back to c.
        ffn_channel = FFN_Expand * c
        self.conv4 = nn.Conv2d(in_channels=c, out_channels=ffn_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
        self.conv5 = nn.Conv2d(in_channels=ffn_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True)

        self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
        self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
        self.g_layer = Glayer(c)  # global branch
        self.l_layer = Llayer(c)  # local branch

        # Learnable fusion weights for the two branches. Together with beta they are
        # zero-initialised, so every BasicBlock starts out as an identity mapping.
        self.inside_all = nn.Parameter(torch.zeros(c, 1, 1), requires_grad=True)
        self.lamb_g = nn.Parameter(torch.zeros(c), requires_grad=True)
        self.lamb_l = nn.Parameter(torch.zeros(c), requires_grad=True)

    def forward(self, inp):
        x = inp

        x = self.norm1(x)

        # 1x1 expansion, 3x3 depthwise conv, then SimpleGate; with the default
        # DW_Expand=2 the gate returns c channels for the two branches.
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.sg(x)

        x_g = self.g_layer(x)
        x_l = self.l_layer(x)

        x_g = x_g * (self.inside_all + 1.)
        x_g = x_g * self.lamb_g[None, :, None, None]
        x_l = x_l * self.lamb_l[None, :, None, None]
        x = x_g + x_l

        x = self.dropout1(x)
        y = inp + x

        # Gated feed-forward network with a learnable residual scale (beta).
        x = self.conv4(self.norm2(y))
        x = self.sg(x)
        x = self.conv5(x)

        x = self.dropout2(x)

        return y + x * self.beta


class FAM(nn.Module):
    """Fuses two feature maps with the same channel count: concatenation followed by a 1x1 conv."""
    def __init__(self, channel):
        super(FAM, self).__init__()
        self.merge = nn.Conv2d(channel*2, channel, kernel_size=1, stride=1)

    def forward(self, x1, x2):
        return self.merge(torch.cat([x1, x2], dim=1))


class BasicConv(nn.Module):
    def __init__(self, in_channel, out_channel, kernel_size, stride, bias=True, norm=False, relu=True, transpose=False):
        super(BasicConv, self).__init__()
        if bias and norm:
            bias = False

        padding = kernel_size // 2
        layers = list()
        if transpose:
            padding = kernel_size // 2 - 1
            layers.append(nn.ConvTranspose2d(in_channel, out_channel, kernel_size, padding=padding, stride=stride, bias=bias))
        else:
            layers.append(
                nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding, stride=stride, bias=bias))
        if norm:
            layers.append(LayerNorm2d(out_channel))
        if relu:
            # Note: despite the flag name, the activation is GELU.
            layers.append(nn.GELU())
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        return self.main(x)


class SFE(nn.Module):
    """Extracts out_plane-channel features from a downscaled copy of the input image."""
    def __init__(self, out_plane):
        super(SFE, self).__init__()
        self.main = nn.Sequential(
            BasicConv(3, out_plane // 4, kernel_size=3, stride=1, relu=True),
            BasicConv(out_plane // 4, out_plane // 2, kernel_size=1, stride=1, relu=True),
            BasicConv(out_plane // 2, out_plane // 2, kernel_size=3, stride=1, relu=True),
            BasicConv(out_plane // 2, out_plane, kernel_size=1, stride=1, relu=False),
            nn.InstanceNorm2d(out_plane, affine=True)
        )

    def forward(self, x):
        x = self.main(x)
        return x


class ALGNet(nn.Module):

    # Block configurations:
    # middle_blk_num=1, enc_blk_nums=[1,1,1,28], dec_blk_nums=[1,1,1,1]  (the defaults below)
    # middle_blk_num=8, enc_blk_nums=[2,2,4,8], dec_blk_nums=[8,4,2,2]
    def __init__(self, img_channel=3, width=32, middle_blk_num=1, enc_blk_nums=[1,1,1,28], dec_blk_nums=[1,1,1,1]):
        super().__init__()
        self.intro = nn.Conv2d(in_channels=img_channel, out_channels=width, kernel_size=3, padding=1, stride=1, groups=1,
                               bias=True)
        self.ending = nn.Conv2d(in_channels=width, out_channels=img_channel, kernel_size=3, padding=1, stride=1, groups=1,
                                bias=True)

        self.encoders = nn.ModuleList()
        self.decoders = nn.ModuleList()
        self.middle_blks = nn.ModuleList()
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()

        chan = width
        for num in enc_blk_nums:
            self.encoders.append(
                nn.Sequential(
                    *[BasicBlock(chan) for _ in range(num)]
                )
            )
            # 2x2 conv with stride 2: halves the resolution and doubles the channels.
            self.downs.append(
                nn.Conv2d(chan, 2*chan, 2, 2)
            )
            chan = chan * 2

        self.middle_blks = \
            nn.Sequential(
                *[BasicBlock(chan) for _ in range(middle_blk_num)]
            )

        for num in dec_blk_nums:
            # 1x1 conv to 2*chan, then PixelShuffle(2) (channels / 4, resolution x 2):
            # net effect chan -> chan // 2 at twice the resolution.
            self.ups.append(
                nn.Sequential(
                    nn.Conv2d(chan, chan * 2, 1, bias=False),
                    nn.PixelShuffle(2)
                )
            )
            chan = chan // 2
            self.decoders.append(
                nn.Sequential(
                    *[BasicBlock(chan) for _ in range(num)]
                )
            )

        # Input H and W must be divisible by 2 ** (number of encoder stages).
        self.padder_size = 2 ** len(self.encoders)
        # Feature extractors for the 1/2, 1/4 and 1/8 scale inputs, fused into the encoder by FAM.
        self.SFE2 = SFE(width*2)
        self.SFE4 = SFE(width*4)
        self.SFE8 = SFE(width*8)

        self.FAM2 = FAM(width * 2)
        self.FAM4 = FAM(width * 4)
        self.FAM8 = FAM(width * 8)

        # 1x1 convs that fuse the upsampled features with the concatenated encoder skip
        # (channel counts assume four encoder/decoder stages).
        self.Convs = nn.ModuleList([
            nn.Conv2d(width * 16, width * 8, kernel_size=1),
            nn.Conv2d(width * 8, width * 4, kernel_size=1),
            nn.Conv2d(width * 4, width * 2, kernel_size=1),
            nn.Conv2d(width * 2, width, kernel_size=1),
        ])
        # Heads producing 3-channel images from the first three decoder stages (1/8, 1/4, 1/2 scale).
        self.ConvsOut = nn.ModuleList(
            [
                nn.Conv2d(width * 8, 3, kernel_size=3, padding=1, stride=1, groups=1,
                          bias=True),
                nn.Conv2d(width * 4, 3, kernel_size=3, padding=1, stride=1, groups=1,
                          bias=True),
                nn.Conv2d(width * 2, 3, kernel_size=3, padding=1, stride=1, groups=1,
                          bias=True),
            ]
        )

    def forward(self, inp):
        B, C, H, W = inp.shape
        inp = self.check_image_size(inp)
        # Image pyramid at 1/2, 1/4 and 1/8 scale (F.interpolate defaults to nearest-neighbour).
        x_2 = F.interpolate(inp, scale_factor=0.5)
        x_4 = F.interpolate(x_2, scale_factor=0.5)
        x_8 = F.interpolate(x_4, scale_factor=0.5)
        z2 = self.SFE2(x_2)
        z4 = self.SFE4(x_4)
        z8 = self.SFE8(x_8)

        x = self.intro(inp)

        encs = []
        enc_i = 0
        for encoder, down in zip(self.encoders, self.downs):
            x = encoder(x)
            encs.append(x)
            x = down(x)
            enc_i = enc_i + 1
            if enc_i == 1:
                x = self.FAM2(x, z2)
            elif enc_i == 2:
                x = self.FAM4(x, z4)
            elif enc_i == 3:
                x = self.FAM8(x, z8)

        x = self.middle_blks(x)

        index = 0
        decs = []
        for decoder, up, enc_skip in zip(self.decoders, self.ups, encs[::-1]):
            x = up(x)
            x = torch.cat([x, enc_skip], dim=1)
            x = self.Convs[index](x)
            index = index + 1
            x = decoder(x)
            decs.append(x)

        # Coarse-to-fine outputs: each head predicts a residual on the matching
        # downscaled input (1/8, 1/4, 1/2); the last entry is full resolution.
        outs = []
        outs.append(self.ConvsOut[0](decs[0]) + x_8)
        outs.append(self.ConvsOut[1](decs[1]) + x_4)
        outs.append(self.ConvsOut[2](decs[2]) + x_2)

        x = self.ending(x)
        x = x + inp
        # Remove the padding added by check_image_size from the full-resolution output.
        outs.append(x[:, :, :H, :W])

        return outs

    def check_image_size(self, x):
        # Zero-pad bottom/right so that H and W are multiples of padder_size.
        _, _, h, w = x.size()
        mod_pad_h = (self.padder_size - h % self.padder_size) % self.padder_size
        mod_pad_w = (self.padder_size - w % self.padder_size) % self.padder_size
        x = F.pad(x, (0, mod_pad_w, 0,
                      mod_pad_h))
        return x


class AvgPool2d(nn.Module):
    """Average pooling over a fixed window derived from the training size.

    replace_layers() uses it to swap nn.AdaptiveAvgPool2d(1) for pooling over a
    local window at test time.
    """
    def __init__(self, kernel_size=None, base_size=None, auto_pad=True, fast_imp=False, train_size=None):
        super().__init__()
        self.kernel_size = kernel_size
        self.base_size = base_size
        self.auto_pad = auto_pad

        # only used for fast implementation
        self.fast_imp = fast_imp
        self.rs = [5, 4, 3, 2, 1]
        self.max_r1 = self.rs[0]
        self.max_r2 = self.rs[0]
        self.train_size = train_size

    def extra_repr(self) -> str:
        return 'kernel_size={}, base_size={}, stride={}, fast_imp={}'.format(
            self.kernel_size, self.base_size, self.kernel_size, self.fast_imp
        )

    def forward(self, x):
        if self.kernel_size is None and self.base_size:
            # First call (during Local_Base.convert): scale base_size to this layer's
            # feature-map size. The kernel then stays fixed for all later inputs.
            train_size = self.train_size
            if isinstance(self.base_size, int):
                self.base_size = (self.base_size, self.base_size)
            self.kernel_size = list(self.base_size)
            self.kernel_size[0] = x.shape[2] * self.base_size[0] // train_size[-2]
            self.kernel_size[1] = x.shape[3] * self.base_size[1] // train_size[-1]

            # only used for fast implementation
            self.max_r1 = max(1, self.rs[0] * x.shape[2] // train_size[-2])
            self.max_r2 = max(1, self.rs[0] * x.shape[3] // train_size[-1])

        if self.kernel_size[0] >= x.size(-2) and self.kernel_size[1] >= x.size(-1):
            return F.adaptive_avg_pool2d(x, 1)

        if self.fast_imp:  # Non-equivalent implementation but faster
            h, w = x.shape[2:]
            if self.kernel_size[0] >= h and self.kernel_size[1] >= w:
                out = F.adaptive_avg_pool2d(x, 1)
            else:
                r1 = [r for r in self.rs if h % r == 0][0]
                r2 = [r for r in self.rs if w % r == 0][0]
                # reduction_constraint
                r1 = min(self.max_r1, r1)
                r2 = min(self.max_r2, r2)
                s = x[:, :, ::r1, ::r2].cumsum(dim=-1).cumsum(dim=-2)
                n, c, h, w = s.shape
                k1, k2 = min(h - 1, self.kernel_size[0] // r1), min(w - 1, self.kernel_size[1] // r2)
                out = (s[:, :, :-k1, :-k2] - s[:, :, :-k1, k2:] - s[:, :, k1:, :-k2] + s[:, :, k1:, k2:]) / (k1 * k2)
                out = torch.nn.functional.interpolate(out, scale_factor=(r1, r2))
        else:
            # Exact box filter via a summed-area table (2D cumulative sum).
            n, c, h, w = x.shape
            s = x.cumsum(dim=-1).cumsum_(dim=-2)
            s = torch.nn.functional.pad(s, (1, 0, 1, 0))  # pad 0 for convenience
            k1, k2 = min(h, self.kernel_size[0]), min(w, self.kernel_size[1])
            s1, s2, s3, s4 = s[:, :, :-k1, :-k2], s[:, :, :-k1, k2:], s[:, :, k1:, :-k2], s[:, :, k1:, k2:]
            out = s4 + s1 - s2 - s3
            out = out / (k1 * k2)

        if self.auto_pad:
            # Replicate-pad the smaller 'valid' output back to the input size.
            n, c, h, w = x.shape
            _h, _w = out.shape[2:]
            pad2d = ((w - _w) // 2, (w - _w + 1) // 2, (h - _h) // 2, (h - _h + 1) // 2)
            out = torch.nn.functional.pad(out, pad2d, mode='replicate')

        return out


def replace_layers(model, base_size, train_size, fast_imp, **kwargs):
    """Recursively replace nn.AdaptiveAvgPool2d(1) layers with the fixed-window AvgPool2d above."""
    for n, m in model.named_children():
        if len(list(m.children())) > 0:
            ## compound module, go inside it
            replace_layers(m, base_size, train_size, fast_imp, **kwargs)

        if isinstance(m, nn.AdaptiveAvgPool2d):
            pool = AvgPool2d(base_size=base_size, fast_imp=fast_imp, train_size=train_size).cuda()

            if m.output_size == 1:
                setattr(model, n, pool)
            # assert m.output_size == 1

        # if isinstance(m, Attention):
        #     attn = LocalAttention(dim=m.dim, num_heads=m.num_heads, is_prompt=m.is_prompt, bias=True, base_size=base_size, fast_imp=False,
        #                           train_size=train_size)
        #     setattr(model, n, attn)


class Local_Base():
    def convert(self, *args, train_size, **kwargs):
        replace_layers(self, *args, train_size=train_size, **kwargs)
        # One forward pass at the training size lets every AvgPool2d fix its kernel size.
        imgs = torch.rand(train_size).cuda()
        with torch.no_grad():
            self.forward(imgs)


class ALGNetLocal(Local_Base, ALGNet):
    def __init__(self, *args, train_size=(1, 3, 256, 256), fast_imp=False, **kwargs):
        Local_Base.__init__(self)
        ALGNet.__init__(self, *args, **kwargs)
        self.cuda()

        N, C, H, W = train_size
        # Test-time pooling window: 1.5x the training crop, rescaled per layer by AvgPool2d.
        base_size = (int(H * 1.5), int(W * 1.5))

        # Note: no module defined in this file uses nn.AdaptiveAvgPool2d, so convert()
        # only changes layers that contain one.
        self.eval()
        with torch.no_grad():
            self.convert(base_size=base_size, train_size=train_size, fast_imp=fast_imp)


if __name__ == "__main__":
    net = ALGNetLocal().cuda()
    x = torch.randn((1, 3, 256, 256)).cuda()
    print("Total number of params:", sum(p.numel() for p in net.parameters()))

    torch.cuda.synchronize()
    start = time.time()
    with torch.no_grad():
        t = net(x)
    torch.cuda.synchronize()
    end = time.time()
    print(t[0].shape)  # outs[0] is the 1/8-scale prediction
    print("Forward time: {:.4f} s".format(end - start))

    inp_shape = (3, 256, 256)
    from ptflops import get_model_complexity_info
    FLOPS = 0  # extra FLOPs not counted by ptflops can be added here

    macs, params = get_model_complexity_info(net, inp_shape, verbose=False, print_per_layer_stat=True)
    # macs is returned as a string; [:-4] strips the unit and assumes it is "GMac".
    macs = float(macs[:-4]) + FLOPS / 10 ** 9

    print('mac', macs, params)

    from thop import profile
    x3 = torch.randn((1, 3, 256, 256)).to('cuda:0')
    # thop.profile returns multiply-accumulate counts (MACs) and parameter counts.
    flops, params = profile(net.to('cuda:0'), inputs=(x3.to('cuda:0'), ))
    print('FLOPs = ' + str(flops/1000**3) + 'G')
    print('Params = ' + str(params/1000**2) + 'M')

--------------------------------------------------------------------------------
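The padding in `ALGNet.check_image_size` can be checked in isolation. Each encoder stage halves the resolution, so H and W are zero-padded at the bottom and right up to a multiple of `padder_size = 2 ** len(self.encoders)` (16 with the default four stages). A minimal pure-Python sketch of that arithmetic; `pad_to_multiple` and `padded_shape` are hypothetical helper names, not part of algnet.py:

```python
def pad_to_multiple(size: int, multiple: int) -> int:
    """Extra pixels needed so that size + pad is divisible by multiple.

    Mirrors mod_pad_h / mod_pad_w in ALGNet.check_image_size; the outer
    modulo makes the pad 0 when size is already a multiple.
    """
    return (multiple - size % multiple) % multiple


def padded_shape(h: int, w: int, num_encoders: int = 4) -> tuple:
    """Spatial size the network sees after check_image_size (hypothetical helper)."""
    multiple = 2 ** num_encoders  # padder_size in ALGNet
    return h + pad_to_multiple(h, multiple), w + pad_to_multiple(w, multiple)


print(padded_shape(256, 256))  # (256, 256): already divisible by 16
print(padded_shape(30, 90))    # (32, 96)
```

Because the zeros are added only at the bottom and right, slicing a full-resolution output with `[:, :, :H, :W]` recovers the original extent.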
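In `ALGNet.__init__`, every stride-2 `down` convolution doubles the channel count, and every decoder `up` (a 1x1 conv to `2 * chan` followed by `PixelShuffle(2)`, which divides channels by 4) halves it again, so each upsampled tensor matches its encoder skip. The sketch below only does this bookkeeping (no convolutions) for the default `width=32` and a 256x256 input, assuming four encoder and four decoder stages; `trace_shapes` is a hypothetical helper name:

```python
def trace_shapes(width=32, num_stages=4, size=256):
    """Return (encoder_skips, bottleneck, decoder_outputs) as (channels, resolution) pairs."""
    chan, res = width, size
    skips = []
    for _ in range(num_stages):
        skips.append((chan, res))       # encs.append(x) happens before downsampling
        chan, res = chan * 2, res // 2  # Conv2d(chan, 2*chan, 2, 2)
    bottleneck = (chan, res)            # shape seen by middle_blks

    decs = []
    for skip_chan, skip_res in reversed(skips):
        # up: 1x1 conv to 2*chan, then PixelShuffle(2) -> (2*chan) // 4 channels at 2x resolution
        chan, res = (chan * 2) // 4, res * 2
        assert (chan, res) == (skip_chan, skip_res)  # upsampled tensor matches the skip
        # Concatenating the skip doubles the channels; the matching 1x1 conv in Convs halves them.
        decs.append((chan, res))
    return skips, bottleneck, decs


skips, bottleneck, decs = trace_shapes()
print(skips)       # [(32, 256), (64, 128), (128, 64), (256, 32)]
print(bottleneck)  # (512, 16)
print(decs)        # [(256, 32), (128, 64), (64, 128), (32, 256)]
```

The first three decoder outputs carry `width * 8`, `width * 4` and `width * 2` channels at 1/8, 1/4 and 1/2 resolution, which is why the three `ConvsOut` heads take exactly those channel counts and their results are added to `x_8`, `x_4` and `x_2`.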
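The exact (non-`fast_imp`) branch of `AvgPool2d.forward` computes a `k1 x k2` box average with a summed-area table: a 2D cumulative sum, zero-padded by one row and column at the top-left, so each box sum is `s4 + s1 - s2 - s3`; `auto_pad` then replicate-pads the smaller "valid" result back to the input size. A dependency-free sketch of the same steps on one channel stored as a list of lists; `box_mean` and `replicate_pad` are hypothetical names, and unlike `AvgPool2d` this sketch does not short-circuit to a 1x1 global average when the kernel covers the whole map:

```python
def box_mean(img, k1, k2, auto_pad=True):
    """k1 x k2 box average (stride 1, 'valid' windows) via a summed-area table."""
    h, w = len(img), len(img[0])
    k1, k2 = min(h, k1), min(w, k2)
    # s[i][j] = sum of img over rows < i and columns < j; the extra zero row and
    # column play the role of F.pad(s, (1, 0, 1, 0)) in AvgPool2d.
    s = [[0.0] * (w + 1) for _ in range(h + 1)]
    for i in range(h):
        for j in range(w):
            s[i + 1][j + 1] = img[i][j] + s[i][j + 1] + s[i + 1][j] - s[i][j]
    out = []
    for i in range(h - k1 + 1):
        row = []
        for j in range(w - k2 + 1):
            # s4 + s1 - s2 - s3 in the PyTorch code
            total = s[i + k1][j + k2] + s[i][j] - s[i][j + k2] - s[i + k1][j]
            row.append(total / (k1 * k2))
        out.append(row)
    return replicate_pad(out, h, w) if auto_pad else out


def replicate_pad(out, h, w):
    """Pad to h x w by repeating edge values; extra rows/columns go bottom/right when odd."""
    _h, _w = len(out), len(out[0])
    top, left = (h - _h) // 2, (w - _w) // 2
    rows = [out[min(max(i - top, 0), _h - 1)] for i in range(h)]
    return [[r[min(max(j - left, 0), _w - 1)] for j in range(w)] for r in rows]


img = [[1, 2, 3],
       [4, 5, 6],
       [7, 8, 9]]
print(box_mean(img, 2, 2, auto_pad=False))  # [[3.0, 4.0], [6.0, 7.0]]
print(box_mean(img, 2, 2))  # [[3.0, 4.0, 4.0], [6.0, 7.0, 7.0], [6.0, 7.0, 7.0]]
```

The table costs one pass over the image, after which every window sum takes four lookups regardless of kernel size, which is why the PyTorch version can afford kernels as large as the feature map.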
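`ALGNetLocal` sets `base_size` to 1.5x the training crop, lets `replace_layers` swap each `nn.AdaptiveAvgPool2d(1)` for `AvgPool2d`, and runs one forward pass at `train_size`. During that pass each pool fixes `kernel = feature_size * base_size // train_size` and never recomputes it, so a later input larger than the training crop is pooled over local windows rather than the whole feature map. A sketch of that arithmetic, assuming a layer at 1/8 resolution; `tlc_kernel`, `pooling_mode` and the 720x1280 test size are illustrative assumptions, not part of algnet.py:

```python
def tlc_kernel(train_feat_hw, train_img_hw=(256, 256), ratio=1.5):
    """Kernel size an AvgPool2d fixes during Local_Base.convert() (hypothetical helper).

    train_feat_hw: spatial size of this layer's feature map when the model is
    run once on a train_size input.
    """
    base = tuple(int(s * ratio) for s in train_img_hw)  # base_size in ALGNetLocal
    return tuple(f * b // t for f, b, t in zip(train_feat_hw, base, train_img_hw))


def pooling_mode(kernel_hw, feat_hw):
    """'global' when the kernel covers the whole map, mirroring the check in AvgPool2d.forward."""
    if kernel_hw[0] >= feat_hw[0] and kernel_hw[1] >= feat_hw[1]:
        return "global"
    return "local"


# A layer at 1/8 resolution sees a 32x32 map during conversion (256x256 input).
k = tlc_kernel((32, 32))
print(k)                           # (48, 48)
print(pooling_mode(k, (32, 32)))   # global at the training size
print(pooling_mode(k, (90, 160)))  # local for a 720x1280 input (1/8 scale: 90x160)
```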