├── AFFBlock.ckpt ├── README.md ├── __init__.py ├── aff_block_LL.py ├── affnet.ckpt ├── affnet_LL.py ├── affnet_config.py ├── base_cls_LL.py ├── base_layer.py ├── base_module.py ├── init_utils.py ├── layers.py ├── logger.py ├── profiler.py └── sync_batch_norm.py /AFFBlock.ckpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TryHard-LL/AFFNet/cd42911f4ebd595a5049653ee8029d4d60eb9b7e/AFFBlock.ckpt -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AFFNet 2 | AFFNet-Unofficial Implementation 3 | 4 | code:https://github.com/NWPU-Li/AFFNet 5 | 6 | ''' 7 | Description: 8 | 9 | Author:LL-Version-V1 10 | 11 | LastEditTime: 2023-08-23 12 | 13 | Description:AFFNet-Pytorch-Unofficial-Implementation 14 | 15 | Reference:https://github.com/microsoft/TokenMixers 16 | 17 | Original Paper:Adaptive Frequency Filters As Efficient Global Token Mixers, ICCV 2023-https://arxiv.org/abs/2307.14008 18 | 19 | ''' 20 | 21 | ### AFFNet 22 | ```bash 23 | python affnet_LL.py 24 | ``` 25 | You can change the 'model-mode' by modifying the "config:model.classification.affnet.mode". There are three options to choose, ie. ["xx_small", "x_small", "small"]. 26 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TryHard-LL/AFFNet/cd42911f4ebd595a5049653ee8029d4d60eb9b7e/__init__.py -------------------------------------------------------------------------------- /aff_block_LL.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Copyright (c) 2023 Microsoft 3 | # Licensed under The MIT License 4 | ''' 5 | Description: 6 | Author:LL-Version-V1 7 | LastEditTime: 2023-08-23 8 | Description:AFFBlock-Pytorch-Unofficial-Implementation 9 | Reference:https://github.com/microsoft/TokenMixers; 10 | Original Paper:Adaptive Frequency Filters As Efficient Global Token Mixers, ICCV 2023-https://arxiv.org/abs/2307.14008 11 | ''' 12 | # -------------------------------------------------------- 13 | import numpy as np 14 | from torch import nn, Tensor 15 | import logger 16 | import math 17 | import torch 18 | from torch.nn import functional as F 19 | from typing import Optional, Dict, Tuple, Union, Sequence 20 | 21 | from base_module import BaseModule 22 | from profiler import module_profile 23 | from layers import ConvLayer, InvertedResidual 24 | from sync_batch_norm import SyncBatchNorm 25 | import math 26 | import torch 27 | import torch.fft 28 | import torch.nn as nn 29 | import torch.nn.functional as F 30 | import time 31 | 32 | 33 | 34 | def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True): 35 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 36 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, 37 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... 38 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for 39 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 40 | 'survival rate' as the argument. 41 | """ 42 | if drop_prob == 0. or not training: 43 | return x 44 | keep_prob = 1 - drop_prob 45 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 46 | random_tensor = x.new_empty(shape).bernoulli_(keep_prob) 47 | if keep_prob > 0.0 and scale_by_keep: 48 | random_tensor.div_(keep_prob) 49 | return x * random_tensor 50 | 51 | 52 | class DropPath(nn.Module): 53 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 54 | """ 55 | 56 | def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True): 57 | super(DropPath, self).__init__() 58 | self.drop_prob = drop_prob 59 | self.scale_by_keep = scale_by_keep 60 | 61 | def forward(self, x): 62 | return drop_path(x, self.drop_prob, self.training, self.scale_by_keep) 63 | 64 | def extra_repr(self): 65 | return f'drop_prob={round(self.drop_prob, 3):0.3f}' 66 | 67 | 68 | class AFNO2D_channelfirst(nn.Module): 69 | """ 70 | hidden_size: channel dimension size 71 | num_blocks: how many blocks to use in the block diagonal weight matrices (higher => less complexity but less parameters) 72 | sparsity_threshold: lambda for softshrink 73 | hard_thresholding_fraction: how many frequencies you want to completely mask out (lower => hard_thresholding_fraction^2 less FLOPs) 74 | input shape [B N C] 75 | """ 76 | 77 | def __init__(self, hidden_size, num_blocks=8, sparsity_threshold=0.01, hard_thresholding_fraction=1, 78 | hidden_size_factor=1): 79 | super().__init__() 80 | assert hidden_size % num_blocks == 0, f"hidden_size {hidden_size} should be divisble by num_blocks {num_blocks}" 81 | 82 | self.hidden_size = hidden_size 83 | self.sparsity_threshold = 0.01 84 | self.num_blocks = num_blocks 85 | self.block_size = self.hidden_size // self.num_blocks 86 | self.hard_thresholding_fraction = hard_thresholding_fraction 87 | self.hidden_size_factor = hidden_size_factor 88 | self.scale = 0.02 89 | 90 | self.w1 = nn.Parameter( 91 | self.scale * torch.randn(2, self.num_blocks, self.block_size, self.block_size * self.hidden_size_factor)) 92 | self.b1 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size * self.hidden_size_factor)) 93 | self.w2 = nn.Parameter( 94 | self.scale * torch.randn(2, self.num_blocks, self.block_size * self.hidden_size_factor, self.block_size)) 95 | self.b2 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size)) 96 | 97 | self.act = self.build_act_layer() 98 | self.act2 = self.build_act_layer() 99 | 100 | @staticmethod 101 | def build_act_layer() -> nn.Module: 102 | act_layer = nn.ReLU() 103 | return act_layer 104 | 105 | @torch.cuda.amp.autocast(enabled=False) 106 | def forward(self, x, spatial_size=None): 107 | bias = x 108 | 109 | dtype = x.dtype 110 | x = x.float() 111 | B, C, H, W = x.shape 112 | # x = self.fu(x) 113 | 114 | x = torch.fft.rfft2(x, dim=(2, 3), norm="ortho") 115 | origin_ffted = x 116 | x = x.reshape(B, self.num_blocks, self.block_size, x.shape[2], x.shape[3]) 117 | 118 | 119 | o1_real = self.act( 120 | torch.einsum('bkihw,kio->bkohw', x.real, self.w1[0]) - \ 121 | torch.einsum('bkihw,kio->bkohw', x.imag, self.w1[1]) + \ 122 | self.b1[0, :, :, None, None] 123 | ) 124 | 125 | o1_imag = self.act2( 126 | torch.einsum('bkihw,kio->bkohw', x.imag, self.w1[0]) + \ 127 | torch.einsum('bkihw,kio->bkohw', x.real, self.w1[1]) + \ 128 | self.b1[1, :, :, None, None] 129 | ) 130 | 131 | o2_real = ( 132 | torch.einsum('bkihw,kio->bkohw', o1_real, self.w2[0]) - \ 133 | torch.einsum('bkihw,kio->bkohw', o1_imag, self.w2[1]) + \ 134 | self.b2[0, :, :, None, None] 135 | ) 136 | 137 | o2_imag = ( 138 | torch.einsum('bkihw,kio->bkohw', o1_imag, self.w2[0]) + \ 139 | torch.einsum('bkihw,kio->bkohw', o1_real, self.w2[1]) + \ 140 | self.b2[1, :, :, None, None] 141 | ) 142 | 143 | x = torch.stack([o2_real, o2_imag], dim=-1) 144 | x = F.softshrink(x, lambd=self.sparsity_threshold) 145 | x = torch.view_as_complex(x) 146 | x = x.reshape(B, C, x.shape[3], x.shape[4]) 147 | 148 | x = x * origin_ffted 149 | x = torch.fft.irfft2(x, s=(H, W), dim=(2, 3), norm="ortho") 150 | x = x.type(dtype) 151 | 152 | return x + bias 153 | 154 | def profile_module( 155 | self, input: Tensor, *args, **kwargs 156 | ) -> Tuple[Tensor, float, float]: 157 | # TODO: to edit it 158 | b_sz, c, h, w = input.shape 159 | seq_len = h * w 160 | 161 | # FFT iFFT 162 | p_ff, m_ff = 0, 5 * b_sz * seq_len * int(math.log(seq_len)) * c 163 | # others 164 | # params = macs = sum([p.numel() for p in self.parameters()]) 165 | params = macs = self.hidden_size * self.hidden_size_factor * self.hidden_size * 2 * 2 // self.num_blocks 166 | # // 2 min n become half after fft 167 | macs = macs * b_sz * seq_len 168 | 169 | # return input, params, macs 170 | return input, params, macs + m_ff 171 | 172 | 173 | def remove_edge(img: np.ndarray): 174 | # // remove the edge of a numpy image 175 | return img[1:-1, 1:-1] 176 | 177 | def save_feature(feature): 178 | import time 179 | import matplotlib.pyplot as plt 180 | import os 181 | now = time.time() 182 | feature = feature.detach() 183 | os.makedirs('visual_example', exist_ok=True) 184 | for i in range(feature.shape[1]): 185 | feature_channel = feature[0, i] 186 | fig, ax = plt.subplots() 187 | img_channel = ax.imshow(remove_edge(feature_channel.cpu().numpy()), cmap='gray') 188 | plt.savefig('visual_example/{now}_channel_{i}_feature.png'.format(now=str(now), i=i)) 189 | for i in range(8): 190 | feature_group = torch.mean(feature[0, i * 8:(i + 1) * 8], dim=1) 191 | fig, ax = plt.subplots() 192 | img_group = ax.imshow(remove_edge(feature_group.cpu().numpy()), cmap='gray') 193 | plt.savefig('visual_example/{now}_group_{i}_feature.png'.format(now=str(now), i=i)) 194 | 195 | def save_kernel(origin_ffted, H, W): 196 | import time 197 | import matplotlib.pyplot as plt 198 | import os 199 | now = time.time() 200 | origin_ffted = origin_ffted.detach() 201 | kernel = torch.fft.irfft2(origin_ffted, s=(H, W), dim=(2, 3), norm="ortho") 202 | group_channels = kernel.shape[1] // 8 203 | os.makedirs('visual_example', exist_ok=True) 204 | for i in range(kernel.shape[1]): 205 | kernel_channel = kernel[0, i] 206 | fig, ax = plt.subplots() 207 | img_channel = ax.imshow(remove_edge(kernel_channel.cpu().numpy()), cmap='gray') 208 | plt.savefig('visual_example/{now}_channel_{i}_kernel.png'.format(now=str(now), i=i)) 209 | for i in range(8): 210 | kernel_group = torch.mean(kernel[0, i*group_channels: (i+1)*group_channels], dim=0) 211 | fig, ax = plt.subplots() 212 | img_group = ax.imshow(remove_edge(kernel_group.cpu().numpy()), cmap='gray') 213 | plt.savefig('visual_example/{now}_group_{i}_kernel.png'.format(now=str(now), i=i)) 214 | kernel_mean = torch.mean(kernel[0], dim=0) 215 | fig, ax = plt.subplots() 216 | img_mean = ax.imshow(remove_edge(kernel_mean.cpu().numpy()), cmap='gray') 217 | plt.savefig('visual_example/{now}_all_kernel.png'.format(now=str(now))) 218 | 219 | abs = origin_ffted.abs() 220 | abs_group_channels = abs.shape[1] // 8 221 | os.makedirs('visual_mask_example', exist_ok=True) 222 | for i in range(abs.shape[1]): 223 | abs_channel = abs[0, i] 224 | fig, ax = plt.subplots() 225 | abs_channel = ax.imshow(abs_channel.cpu().numpy(), cmap='gray') 226 | plt.savefig('visual_mask_example/{now}_channel_{i}_abs.png'.format(now=str(now), i=i)) 227 | for i in range(8): 228 | abs_group = torch.mean(abs[0, i*abs_group_channels: (i+1)*abs_group_channels], dim=0) 229 | fig, ax = plt.subplots() 230 | img_group = ax.imshow(abs_group.cpu().numpy(), cmap='gray') 231 | plt.savefig('visual_mask_example/{now}_group_{i}_abs.png'.format(now=str(now), i=i)) 232 | abs_mean = torch.mean(abs[0], dim=0) 233 | fig, ax = plt.subplots() 234 | img_mean = ax.imshow(abs_mean.cpu().numpy(), cmap='gray') 235 | plt.savefig('visual_mask_example/{now}_all_abs.png'.format(now=str(now))) 236 | 237 | real = origin_ffted.real 238 | real_group_channels = real.shape[1] // 8 239 | os.makedirs('visual_mask_example', exist_ok=True) 240 | for i in range(real.shape[1]): 241 | real_channel = real[0, i] 242 | fig, ax = plt.subplots() 243 | real_channel = ax.imshow(real_channel.cpu().numpy(), cmap='gray') 244 | plt.savefig('visual_mask_example/{now}_channel_{i}_real.png'.format(now=str(now), i=i)) 245 | for i in range(8): 246 | real_group = torch.mean(real[0, i*real_group_channels: (i+1)*real_group_channels], dim=0) 247 | fig, ax = plt.subplots() 248 | img_group = ax.imshow(real_group.cpu().numpy(), cmap='gray') 249 | plt.savefig('visual_mask_example/{now}_group_{i}_mask.png'.format(now=str(now), i=i)) 250 | real_mean = torch.mean(real[0], dim=0) 251 | fig, ax = plt.subplots() 252 | img_mean = ax.imshow(real_mean.cpu().numpy(), cmap='gray') 253 | plt.savefig('visual_mask_example/{now}_all_real.png'.format(now=str(now))) 254 | 255 | imag = origin_ffted.imag 256 | imag_group_channels = imag.shape[1] // 8 257 | os.makedirs('visual_mask_example', exist_ok=True) 258 | for i in range(8): 259 | imag_group = torch.mean(imag[0, i*imag_group_channels: (i+1)*imag_group_channels], dim=0) 260 | fig, ax = plt.subplots() 261 | img_group = ax.imshow(imag_group.cpu().numpy(), cmap='gray') 262 | plt.savefig('visual_mask_example/{now}_group_{i}_imag.png'.format(now=str(now), i=i)) 263 | imag_mean = torch.mean(imag[0], dim=0) 264 | fig, ax = plt.subplots() 265 | img_mean = ax.imshow(imag_mean.cpu().numpy(), cmap='gray') 266 | plt.savefig('visual_mask_example/{now}_all_imag.png'.format(now=str(now))) 267 | 268 | 269 | 270 | class Block(nn.Module): 271 | def __init__(self, dim, hidden_size, num_blocks, double_skip, mlp_ratio=4., drop_path=0., attn_norm_layer='sync_batch_norm', enable_coreml_compatible_fn=False): 272 | # input shape [B C H W] 273 | super().__init__() 274 | # self.norm1 = nn.BatchNorm2d(num_features=dim) 275 | # self.norm1 = nn.SyncBatchNorm.convert_sync_batchnorm() 276 | 277 | if torch.cuda.device_count() < 1 + 1e-10: # '1 + 1e-10; is in order to make sure that when more than 1 GPU can use Sync-batch-norm. 278 | # for a CPU-device, Sync-batch norm does not work. So, change to batch norm 279 | self.norm1 = nn.BatchNorm2d(num_features=dim) 280 | self.norm2 = nn.BatchNorm2d(num_features=dim) 281 | logger.info("Using BatchNorm2d") 282 | else: 283 | self.norm1 = SyncBatchNorm(normalized_shape=dim, num_features=dim) 284 | self.norm2 = SyncBatchNorm(normalized_shape=dim, num_features=dim) 285 | logger.info("Using SyncBatchNorm") 286 | self.filter = AFNO2D_channelfirst(hidden_size=hidden_size, num_blocks=num_blocks, sparsity_threshold=0.01, 287 | hard_thresholding_fraction=1, hidden_size_factor=1) # if not enable_coreml_compatible_fn else \ 288 | # AFNO2D_channelfirst_coreml(hidden_size=hidden_size, num_blocks=num_blocks, sparsity_threshold=0.01, hard_thresholding_fraction=1, hidden_size_factor=1) 289 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 290 | self.mlp = InvertedResidual( 291 | inp=dim, 292 | oup=dim, 293 | stride=1, 294 | expand_ratio=mlp_ratio, 295 | ) 296 | self.double_skip = double_skip 297 | 298 | def forward(self, x): 299 | residual = x 300 | # print(f"Block中传入的x.shape:{x.shape}") 301 | x = self.norm1(x) 302 | # print(x.shape) 303 | # x = self.filter(x) 304 | x = self.mlp(x) 305 | 306 | if self.double_skip: 307 | x = x + residual 308 | residual = x 309 | 310 | x = self.norm2(x) 311 | # x = self.mlp(x) 312 | x = self.filter(x) 313 | x = self.drop_path(x) 314 | x = x + residual 315 | return x 316 | 317 | def profile_module( 318 | self, input: Tensor, *args, **kwargs 319 | ) -> Tuple[Tensor, float, float]: 320 | b_sz, c, h, w = input.shape 321 | seq_len = h * w 322 | 323 | out, p_ffn, m_ffn = module_profile(module=self.mlp, x=input) 324 | # m_ffn = m_ffn * b_sz * seq_len 325 | 326 | out, p_mha, m_mha = module_profile(module=self.filter, x=out) 327 | 328 | 329 | macs = m_mha + m_ffn 330 | params = p_mha + p_ffn 331 | 332 | return input, params, macs 333 | 334 | 335 | class AFFBlock(BaseModule): 336 | 337 | def __init__( 338 | self, 339 | in_channels: int, 340 | transformer_dim: int, 341 | ffn_dim: int, 342 | n_transformer_blocks: Optional[int] = 2, 343 | head_dim: Optional[int] = 32, 344 | attn_dropout: Optional[float] = 0.0, 345 | dropout: Optional[int] = 0.0, 346 | ffn_dropout: Optional[int] = 0.0, 347 | patch_h: Optional[int] = 8, 348 | patch_w: Optional[int] = 8, 349 | attn_norm_layer: Optional[str] = "layer_norm_2d", 350 | conv_ksize: Optional[int] = 3, 351 | dilation: Optional[int] = 1, 352 | no_fusion: Optional[bool] = False, 353 | *args, 354 | **kwargs 355 | ) -> None: 356 | 357 | conv_1x1_out = ConvLayer( 358 | in_channels=transformer_dim, 359 | out_channels=in_channels, 360 | kernel_size=1, 361 | stride=1, 362 | use_norm=True, 363 | use_act=False, 364 | ) 365 | conv_3x3_out = None 366 | if not no_fusion: 367 | conv_3x3_out = ConvLayer( 368 | in_channels=2 * in_channels, 369 | out_channels=in_channels, 370 | kernel_size=1, # conv_ksize -> 1 371 | stride=1, 372 | use_norm=True, 373 | use_act=True, 374 | ) 375 | super().__init__() 376 | 377 | assert transformer_dim % head_dim == 0 378 | num_heads = transformer_dim // head_dim 379 | self.enable_coreml_compatible_fn = False or False 380 | print(self.enable_coreml_compatible_fn) 381 | 382 | global_rep = [ 383 | # TODO: to check the double skip 384 | Block( 385 | dim=transformer_dim, 386 | hidden_size=transformer_dim, 387 | num_blocks=8, 388 | double_skip=False, 389 | mlp_ratio=ffn_dim / transformer_dim, 390 | attn_norm_layer=attn_norm_layer, 391 | enable_coreml_compatible_fn=self.enable_coreml_compatible_fn 392 | ) 393 | for _ in range(n_transformer_blocks) 394 | ] 395 | global_rep.append(nn.BatchNorm2d(num_features=transformer_dim)) 396 | self.global_rep = nn.Sequential(*global_rep) 397 | 398 | self.conv_proj = conv_1x1_out 399 | 400 | self.fusion = conv_3x3_out 401 | 402 | self.patch_h = patch_h 403 | self.patch_w = patch_w 404 | self.patch_area = self.patch_w * self.patch_h 405 | 406 | self.cnn_in_dim = in_channels 407 | self.cnn_out_dim = transformer_dim 408 | self.n_heads = num_heads 409 | self.ffn_dim = ffn_dim 410 | self.dropout = dropout 411 | self.attn_dropout = attn_dropout 412 | self.ffn_dropout = ffn_dropout 413 | self.dilation = dilation 414 | self.n_blocks = n_transformer_blocks 415 | self.conv_ksize = conv_ksize 416 | 417 | def __repr__(self) -> str: 418 | repr_str = "{}(".format(self.__class__.__name__) 419 | 420 | repr_str += "\n\t Global representations with patch size of {}x{}".format( 421 | self.patch_h, self.patch_w 422 | ) 423 | if isinstance(self.global_rep, nn.Sequential): 424 | for m in self.global_rep: 425 | repr_str += "\n\t\t {}".format(m) 426 | else: 427 | repr_str += "\n\t\t {}".format(self.global_rep) 428 | 429 | if isinstance(self.conv_proj, nn.Sequential): 430 | for m in self.conv_proj: 431 | repr_str += "\n\t\t {}".format(m) 432 | else: 433 | repr_str += "\n\t\t {}".format(self.conv_proj) 434 | 435 | if self.fusion is not None: 436 | repr_str += "\n\t Feature fusion" 437 | if isinstance(self.fusion, nn.Sequential): 438 | for m in self.fusion: 439 | repr_str += "\n\t\t {}".format(m) 440 | else: 441 | repr_str += "\n\t\t {}".format(self.fusion) 442 | 443 | repr_str += "\n)" 444 | return repr_str 445 | 446 | def forward_spatial(self, x: Tensor) -> Tensor: 447 | res = x 448 | 449 | # fm = self.local_rep(x) 450 | patches = x 451 | 452 | # b, c, h, w = fm.size() 453 | # patches = einops.rearrange(fm, 'b c h w -> b (h w) c') 454 | 455 | # learn global representations 456 | for transformer_layer in self.global_rep: 457 | patches = transformer_layer(patches) 458 | 459 | # fm = einops.rearrange(patches, 'b (h w) c -> b c h w', h=h, w=w) 460 | 461 | fm = self.conv_proj(patches) 462 | 463 | if self.fusion is not None: 464 | fm = self.fusion(torch.cat((res, fm), dim=1)) 465 | return fm 466 | 467 | 468 | def forward( 469 | self, x: Union[Tensor, Tuple[Tensor]], *args, **kwargs 470 | ) -> Union[Tensor, Tuple[Tensor, Tensor]]: 471 | if isinstance(x, Tuple) and len(x) == 2: 472 | # for spatio-temporal 473 | return self.forward_temporal(x=x[0], x_prev=x[1]) 474 | elif isinstance(x, Tensor): 475 | # For image data 476 | return self.forward_spatial(x) 477 | else: 478 | raise NotImplementedError 479 | 480 | def profile_module( 481 | self, input: Tensor, *args, **kwargs 482 | ) -> Tuple[Tensor, float, float]: 483 | params = macs = 0.0 484 | 485 | res = input 486 | 487 | b, c, h, w = input.size() 488 | 489 | out, p, m = module_profile(module=self.global_rep, x=input) 490 | params += p 491 | macs += m 492 | 493 | out, p, m = module_profile(module=self.conv_proj, x=out) 494 | params += p 495 | macs += m 496 | 497 | if self.fusion is not None: 498 | out, p, m = module_profile( 499 | module=self.fusion, x=torch.cat((out, res), dim=1) 500 | ) 501 | params += p 502 | macs += m 503 | 504 | return res, params, macs 505 | 506 | if __name__ == "__main__": 507 | 508 | head_dim = 4 509 | input = torch.randn(1, 16, 224, 224) 510 | print(input.shape) 511 | affblock = AFFBlock(in_channels=input.shape[1], transformer_dim=input.shape[1], ffn_dim=1, head_dim=head_dim) 512 | torch.save(affblock.state_dict(), "AFFBlock.ckpt") 513 | 514 | output = affblock(input) 515 | print(output.shape) 516 | -------------------------------------------------------------------------------- /affnet.ckpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TryHard-LL/AFFNet/cd42911f4ebd595a5049653ee8029d4d60eb9b7e/affnet.ckpt -------------------------------------------------------------------------------- /affnet_LL.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Copyright (c) 2023 Microsoft 3 | # Licensed under The MIT License 4 | ''' 5 | Description: 6 | Author:LL-Version-V1 7 | LastEditTime: 2023-08-23 8 | Description:AFFNet-Pytorch-Unofficial-Implementation 9 | Reference:https://github.com/microsoft/TokenMixers; 10 | Original Paper:Adaptive Frequency Filters As Efficient Global Token Mixers, ICCV 2023-https://arxiv.org/abs/2307.14008 11 | ''' 12 | # -------------------------------------------------------- 13 | import torch 14 | from torch import nn 15 | import argparse 16 | from typing import Dict, Tuple, Optional 17 | 18 | import logger 19 | from base_cls_LL import BaseEncoder 20 | from affnet_config import get_configuration 21 | from layers import ConvLayer, LinearLayer, GlobalPool, Dropout, InvertedResidual 22 | from aff_block_LL import AFFBlock 23 | 24 | 25 | class AffNet(BaseEncoder): 26 | 27 | def __init__(self, opts, *args, **kwargs) -> None: 28 | num_classes = 1000 29 | classifier_dropout = 0.0 30 | 31 | pool_type = "mean" 32 | image_channels = 3 33 | out_channels = 16 34 | 35 | affnet_config = get_configuration(opts=opts) 36 | 37 | super().__init__(opts, *args, **kwargs) 38 | 39 | # store model configuration in a dictionary 40 | self.model_conf_dict = dict() 41 | self.conv_1 = ConvLayer( 42 | in_channels=image_channels, 43 | out_channels=out_channels, 44 | kernel_size=3, 45 | stride=2, 46 | use_norm=True, 47 | use_act=True, 48 | ) 49 | 50 | self.model_conf_dict["conv1"] = {"in": image_channels, "out": out_channels} 51 | 52 | in_channels = out_channels 53 | self.layer_1, out_channels = self._make_layer( 54 | input_channel=in_channels, cfg=affnet_config["layer1"] 55 | ) 56 | self.model_conf_dict["layer1"] = {"in": in_channels, "out": out_channels} 57 | 58 | in_channels = out_channels 59 | self.layer_2, out_channels = self._make_layer( 60 | input_channel=in_channels, cfg=affnet_config["layer2"] 61 | ) 62 | self.model_conf_dict["layer2"] = {"in": in_channels, "out": out_channels} 63 | 64 | in_channels = out_channels 65 | self.layer_3, out_channels = self._make_layer( 66 | input_channel=in_channels, cfg=affnet_config["layer3"] 67 | ) 68 | self.model_conf_dict["layer3"] = {"in": in_channels, "out": out_channels} 69 | 70 | in_channels = out_channels 71 | self.layer_4, out_channels = self._make_layer( 72 | input_channel=in_channels, 73 | cfg=affnet_config["layer4"], 74 | dilate=self.dilate_l4, 75 | ) 76 | self.model_conf_dict["layer4"] = {"in": in_channels, "out": out_channels} 77 | 78 | in_channels = out_channels 79 | self.layer_5, out_channels = self._make_layer( 80 | input_channel=in_channels, 81 | cfg=affnet_config["layer5"], 82 | dilate=self.dilate_l5, 83 | ) 84 | self.model_conf_dict["layer5"] = {"in": in_channels, "out": out_channels} 85 | 86 | in_channels = out_channels 87 | exp_channels = min(affnet_config["last_layer_exp_factor"] * in_channels, 960) 88 | self.conv_1x1_exp = ConvLayer( 89 | in_channels=in_channels, 90 | out_channels=exp_channels, 91 | kernel_size=1, 92 | stride=1, 93 | use_act=True, 94 | use_norm=True, 95 | ) 96 | 97 | self.model_conf_dict["exp_before_cls"] = { 98 | "in": in_channels, 99 | "out": exp_channels, 100 | } 101 | 102 | self.classifier = nn.Sequential() 103 | self.classifier.add_module( 104 | name="global_pool", module=GlobalPool(pool_type=pool_type, keep_dim=False) 105 | ) 106 | if 0.0 < classifier_dropout < 1.0: 107 | self.classifier.add_module( 108 | name="dropout", module=Dropout(p=classifier_dropout, inplace=True) 109 | ) 110 | self.classifier.add_module( 111 | name="fc", 112 | module=LinearLayer( 113 | in_features=exp_channels, out_features=num_classes, bias=True 114 | ), 115 | ) 116 | 117 | # check model 118 | self.check_model() 119 | 120 | # weight initialization 121 | self.reset_parameters(opts=opts) 122 | 123 | def _make_layer( 124 | self, 125 | input_channel, 126 | cfg: Dict, 127 | dilate: Optional[bool] = False, 128 | *args, 129 | **kwargs 130 | ) -> Tuple[nn.Sequential, int]: 131 | block_type = cfg.get("block_type", "aff_block") 132 | if block_type.lower() == "aff_block": 133 | return self._make_affnet_layer( 134 | input_channel=input_channel, cfg=cfg, dilate=dilate 135 | ) 136 | else: 137 | return self._make_mobilenet_layer( 138 | input_channel=input_channel, cfg=cfg 139 | ) 140 | 141 | @staticmethod 142 | def _make_mobilenet_layer( 143 | input_channel: int, cfg: Dict, *args, **kwargs 144 | ) -> Tuple[nn.Sequential, int]: 145 | output_channels = cfg.get("out_channels") 146 | num_blocks = cfg.get("num_blocks", 2) 147 | expand_ratio = cfg.get("expand_ratio", 4) 148 | block = [] 149 | 150 | for i in range(num_blocks): 151 | stride = cfg.get("stride", 1) if i == 0 else 1 152 | 153 | layer = InvertedResidual( 154 | inp=input_channel, 155 | oup=output_channels, 156 | stride=stride, 157 | expand_ratio=expand_ratio, 158 | ) 159 | block.append(layer) 160 | input_channel = output_channels 161 | return nn.Sequential(*block), input_channel 162 | 163 | def _make_affnet_layer( 164 | self, 165 | input_channel, 166 | cfg: Dict, 167 | dilate: Optional[bool] = False, 168 | *args, 169 | **kwargs 170 | ) -> Tuple[nn.Sequential, int]: 171 | prev_dilation = self.dilation 172 | block = [] 173 | stride = cfg.get("stride", 1) 174 | no_fuse = cfg.get("no_fuse", False) 175 | 176 | if stride == 2: 177 | if dilate: 178 | self.dilation *= 2 179 | stride = 1 180 | 181 | layer = InvertedResidual( 182 | inp=input_channel, 183 | oup=cfg.get("out_channels"), 184 | stride=stride, 185 | expand_ratio=cfg.get("mv_expand_ratio", 4), 186 | # dilation=prev_dilation, 187 | ) 188 | 189 | block.append(layer) 190 | input_channel = cfg.get("out_channels") 191 | 192 | head_dim = cfg.get("head_dim", 32) 193 | transformer_dim = cfg["transformer_channels"] 194 | ffn_dim = cfg.get("ffn_dim") 195 | if head_dim is None: 196 | num_heads = cfg.get("num_heads", 4) 197 | if num_heads is None: 198 | num_heads = 4 199 | head_dim = transformer_dim // num_heads 200 | 201 | if transformer_dim % head_dim != 0: 202 | logger.error( 203 | "Transformer input dimension should be divisible by head dimension. " 204 | "Got {} and {}.".format(transformer_dim, head_dim) 205 | ) 206 | 207 | block.append( 208 | AFFBlock( 209 | in_channels=input_channel, 210 | transformer_dim=transformer_dim, 211 | ffn_dim=ffn_dim, 212 | n_transformer_blocks=cfg.get("transformer_blocks", 1), 213 | patch_h=cfg.get("patch_h", 2), 214 | patch_w=cfg.get("patch_w", 2), 215 | dropout=0.1, 216 | ffn_dropout=0.0, 217 | attn_dropout=0.1, 218 | head_dim=head_dim, 219 | no_fusion=no_fuse, 220 | conv_ksize=3, 221 | attn_norm_layer="layer_norm_2d", 222 | ) 223 | ) 224 | 225 | return nn.Sequential(*block), input_channel 226 | 227 | if __name__ == "__main__": 228 | 229 | parser = argparse.ArgumentParser() 230 | parser.add_argument( 231 | "--model.classification.affnet.mode", 232 | type=str, 233 | default="xx_small", 234 | choices=["xx_small", "x_small", "small"], 235 | ) 236 | parser.add_argument( 237 | "--model.classification.affnet.attn-dropout", 238 | type=float, 239 | default=0.0, 240 | help="Dropout in attention layer. Defaults to 0.0", 241 | ) 242 | parser.add_argument( 243 | "--model.classification.affnet.ffn-dropout", 244 | type=float, 245 | default=0.0, 246 | help="Dropout between FFN layers. Defaults to 0.0", 247 | ) 248 | parser.add_argument( 249 | "--model.classification.affnet.dropout", 250 | type=float, 251 | default=0.0, 252 | help="Dropout in Transformer layer. Defaults to 0.0", 253 | ) 254 | parser.add_argument( 255 | "--model.classification.affnet.attn-norm-layer", 256 | type=str, 257 | default="layer_norm", 258 | help="Normalization layer in transformer. Defaults to LayerNorm", 259 | ) 260 | parser.add_argument( 261 | "--model.classification.affnet.no-fuse-local-global-features", 262 | action="store_true", 263 | help="Do not combine local and global features in MobileViT block", 264 | ) 265 | parser.add_argument( 266 | "--model.classification.affnet.conv-kernel-size", 267 | type=int, 268 | default=3, 269 | ) 270 | 271 | parser.add_argument( 272 | "--model.classification.affnet.head-dim", 273 | type=int, 274 | default=None, 275 | help="Head dimension in transformer", 276 | ) 277 | parser.add_argument( 278 | "--model.classification.affnet.number-heads", 279 | type=int, 280 | default=None, 281 | help="Number of heads in transformer", 282 | ) 283 | args = parser.parse_args() 284 | 285 | # opts = get_configuration(opts=args) 286 | affnet = AffNet(args) 287 | torch.save(affnet.state_dict(), "affnet.ckpt") 288 | 289 | input = torch.randn(1, 3, 640, 640) 290 | 291 | output = affnet(input) 292 | 293 | print(f"output.shape:{output.shape}") -------------------------------------------------------------------------------- /affnet_config.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Copyright (c) 2023 Microsoft 3 | # Licensed under The MIT License 4 | # -------------------------------------------------------- 5 | 6 | from typing import Dict 7 | 8 | import logger 9 | 10 | 11 | def get_configuration(opts) -> Dict: 12 | mode = getattr(opts, "model.classification.affnet.mode", "xx_small") 13 | if mode is None: 14 | logger.error("Please specify mode") 15 | logger.info("model mode is:" + mode) 16 | 17 | head_dim = getattr(opts, "model.classification.affnet.head_dim", None) 18 | num_heads = getattr(opts, "model.classification.affnet.number_heads", 4) 19 | if head_dim is not None: 20 | if num_heads is not None: 21 | logger.error( 22 | "--model.classification.affnet.head-dim and --model.classification.affnet.number-heads " 23 | "are mutually exclusive." 24 | ) 25 | elif num_heads is not None: 26 | if head_dim is not None: 27 | logger.error( 28 | "--model.classification.affnet.head-dim and --model.classification.affnet.number-heads " 29 | "are mutually exclusive." 30 | ) 31 | mode = mode.lower() 32 | if mode == "xx_small": 33 | mv2_exp_mult = 2 34 | config = { 35 | "layer1": { 36 | "out_channels": 32, 37 | "expand_ratio": mv2_exp_mult, 38 | "num_blocks": 1, 39 | "stride": 1, 40 | "block_type": "mv2", 41 | }, 42 | "layer2": { 43 | "out_channels": 48, 44 | "expand_ratio": mv2_exp_mult, 45 | "num_blocks": 3, 46 | "stride": 2, 47 | "block_type": "mv2", 48 | }, 49 | "layer3": { # 28x28 50 | "out_channels": 64, 51 | "transformer_channels": 64, 52 | "ffn_dim": 128, 53 | "transformer_blocks": 2, 54 | "patch_h": 2, # 8, 55 | "patch_w": 2, # 8, 56 | "stride": 2, 57 | "mv_expand_ratio": mv2_exp_mult, 58 | "head_dim": head_dim, 59 | "num_heads": num_heads, 60 | "block_type": "aff_block", 61 | }, 62 | "layer4": { # 14x14 63 | "out_channels": 104, 64 | "transformer_channels": 104, 65 | "ffn_dim": 208, 66 | "transformer_blocks": 4, 67 | "patch_h": 2, # 4, 68 | "patch_w": 2, # 4, 69 | "stride": 2, 70 | "mv_expand_ratio": mv2_exp_mult, 71 | "head_dim": head_dim, 72 | "num_heads": num_heads, 73 | "block_type": "aff_block", 74 | }, 75 | "layer5": { # 7x7 76 | "out_channels": 144, 77 | "transformer_channels": 144, 78 | "ffn_dim": 288, 79 | "transformer_blocks": 3, 80 | "patch_h": 2, 81 | "patch_w": 2, 82 | "stride": 2, 83 | "mv_expand_ratio": mv2_exp_mult, 84 | "head_dim": head_dim, 85 | "num_heads": num_heads, 86 | "block_type": "aff_block", 87 | }, 88 | "last_layer_exp_factor": 4, 89 | } 90 | elif mode == "x_small": 91 | mv2_exp_mult = 4 92 | config = { 93 | "layer1": { 94 | "out_channels": 32, 95 | "expand_ratio": mv2_exp_mult, 96 | "num_blocks": 1, 97 | "stride": 1, 98 | "block_type": "mv2", 99 | }, 100 | "layer2": { 101 | "out_channels": 48, 102 | "expand_ratio": mv2_exp_mult, 103 | "num_blocks": 3, 104 | "stride": 2, 105 | "block_type": "mv2", 106 | }, 107 | "layer3": { # 28x28 108 | "out_channels": 96, 109 | "transformer_channels": 96, 110 | "ffn_dim": 192, 111 | "transformer_blocks": 2, 112 | "patch_h": 2, 113 | "patch_w": 2, 114 | "stride": 2, 115 | "mv_expand_ratio": mv2_exp_mult, 116 | "head_dim": head_dim, 117 | "num_heads": num_heads, 118 | "block_type": "aff_block", 119 | }, 120 | "layer4": { # 14x14 121 | "out_channels": 160, 122 | "transformer_channels": 160, 123 | "ffn_dim": 320, 124 | "transformer_blocks": 4, 125 | "patch_h": 2, 126 | "patch_w": 2, 127 | "stride": 2, 128 | "mv_expand_ratio": mv2_exp_mult, 129 | "head_dim": head_dim, 130 | "num_heads": num_heads, 131 | "block_type": "aff_block", 132 | }, 133 | "layer5": { # 7x7 134 | "out_channels": 192, 135 | "transformer_channels": 192, 136 | "ffn_dim": 384, 137 | "transformer_blocks": 3, 138 | "patch_h": 2, 139 | "patch_w": 2, 140 | "stride": 2, 141 | "mv_expand_ratio": mv2_exp_mult, 142 | "head_dim": head_dim, 143 | "num_heads": num_heads, 144 | "block_type": "aff_block", 145 | }, 146 | "last_layer_exp_factor": 4, 147 | } 148 | elif mode == "small": 149 | mv2_exp_mult = 4 150 | config = { 151 | "layer1": { 152 | "out_channels": 32, 153 | "expand_ratio": mv2_exp_mult, 154 | "num_blocks": 1, 155 | "stride": 1, 156 | "block_type": "mv2", 157 | }, 158 | "layer2": { 159 | "out_channels": 64, 160 | "expand_ratio": mv2_exp_mult, 161 | "num_blocks": 3, 162 | "stride": 2, 163 | "block_type": "mv2", 164 | }, 165 | "layer3": { # 28x28 166 | "out_channels": 128, 167 | "transformer_channels": 128, 168 | "ffn_dim": 256, 169 | "transformer_blocks": 2, 170 | "patch_h": 2, 171 | "patch_w": 2, 172 | "stride": 2, 173 | "mv_expand_ratio": mv2_exp_mult, 174 | "head_dim": head_dim, 175 | "num_heads": num_heads, 176 | "block_type": "aff_block", 177 | }, 178 | "layer4": { # 14x14 179 | "out_channels": 256, 180 | "transformer_channels": 256, 181 | "ffn_dim": 512, 182 | "transformer_blocks": 4, 183 | "patch_h": 2, 184 | "patch_w": 2, 185 | "stride": 2, 186 | "mv_expand_ratio": mv2_exp_mult, 187 | "head_dim": head_dim, 188 | "num_heads": num_heads, 189 | "block_type": "aff_block", 190 | }, 191 | "layer5": { # 7x7 192 | "out_channels": 320, 193 | "transformer_channels": 320, 194 | "ffn_dim": 640, 195 | "transformer_blocks": 3, 196 | "patch_h": 2, 197 | "patch_w": 2, 198 | "stride": 2, 199 | "mv_expand_ratio": mv2_exp_mult, 200 | "head_dim": head_dim, 201 | "num_heads": num_heads, 202 | "block_type": "aff_block", 203 | }, 204 | "last_layer_exp_factor": 4, 205 | } 206 | elif mode == "base": 207 | mv2_exp_mult = 4 208 | config = { 209 | "layer1": { 210 | "out_channels": 64, 211 | "expand_ratio": mv2_exp_mult, 212 | "num_blocks": 1, 213 | "stride": 1, 214 | "block_type": "mv2", 215 | }, 216 | "layer2": { 217 | "out_channels": 128, 218 | "expand_ratio": mv2_exp_mult, 219 | "num_blocks": 3, 220 | "stride": 2, 221 | "block_type": "mv2", 222 | }, 223 | "layer3": { # 28x28 224 | "out_channels": 256, 225 | "transformer_channels": 256, 226 | "ffn_dim": 512, 227 | "transformer_blocks": 2, 228 | "patch_h": 2, 229 | "patch_w": 2, 230 | "stride": 2, 231 | "mv_expand_ratio": mv2_exp_mult, 232 | "head_dim": head_dim, 233 | "num_heads": num_heads, 234 | "block_type": "aff_block", 235 | }, 236 | "layer4": { # 14x14 237 | "out_channels": 512, 238 | "transformer_channels": 512, 239 | "ffn_dim": 1024, 240 | "transformer_blocks": 4, 241 | "patch_h": 2, 242 | "patch_w": 2, 243 | "stride": 2, 244 | "mv_expand_ratio": mv2_exp_mult, 245 | "head_dim": head_dim, 246 | "num_heads": num_heads, 247 | "block_type": "aff_block", 248 | "no_fuse": True 249 | }, 250 | "layer5": { # 7x7 251 | "out_channels": 640, 252 | "transformer_channels": 640, 253 | "ffn_dim": 1280, 254 | "transformer_blocks": 3, 255 | "patch_h": 2, 256 | "patch_w": 2, 257 | "stride": 2, 258 | "mv_expand_ratio": mv2_exp_mult, 259 | "head_dim": head_dim, 260 | "num_heads": num_heads, 261 | "block_type": "aff_block", 262 | "no_fuse": True 263 | }, 264 | "last_layer_exp_factor": 4, 265 | } 266 | elif mode == "large": 267 | mv2_exp_mult = 4 268 | config = { 269 | "layer1": { 270 | "out_channels": 64, 271 | "expand_ratio": mv2_exp_mult, 272 | "num_blocks": 2, 273 | "stride": 1, 274 | "block_type": "mv2", 275 | }, 276 | "layer2": { 277 | "out_channels": 128, 278 | "expand_ratio": mv2_exp_mult, 279 | "num_blocks": 6, 280 | "stride": 2, 281 | "block_type": "mv2", 282 | }, 283 | "layer3": { # 28x28 284 | "out_channels": 256, 285 | "transformer_channels": 256, 286 | "ffn_dim": 512, 287 | "transformer_blocks": 4, 288 | "patch_h": 2, 289 | "patch_w": 2, 290 | "stride": 2, 291 | "mv_expand_ratio": mv2_exp_mult, 292 | "head_dim": head_dim, 293 | "num_heads": num_heads, 294 | "block_type": "aff_block", 295 | }, 296 | "layer4": { # 14x14 297 | "out_channels": 512, 298 | "transformer_channels": 512, 299 | "ffn_dim": 1024, 300 | "transformer_blocks": 18, 301 | "patch_h": 2, 302 | "patch_w": 2, 303 | "stride": 2, 304 | "mv_expand_ratio": mv2_exp_mult, 305 | "head_dim": head_dim, 306 | "num_heads": num_heads, 307 | "block_type": "aff_block", 308 | "no_fuse": True 309 | }, 310 | "layer5": { # 7x7 311 | "out_channels": 768, 312 | "transformer_channels": 768, 313 | "ffn_dim": 1536, 314 | "transformer_blocks": 6, 315 | "patch_h": 2, 316 | "patch_w": 2, 317 | "stride": 2, 318 | "mv_expand_ratio": mv2_exp_mult, 319 | "head_dim": head_dim, 320 | "num_heads": num_heads, 321 | "block_type": "aff_block", 322 | "no_fuse": True 323 | }, 324 | "last_layer_exp_factor": 4, 325 | } 326 | else: 327 | raise NotImplementedError 328 | 329 | return config 330 | -------------------------------------------------------------------------------- /base_cls_LL.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Copyright (c) 2023 Microsoft 3 | # Licensed under The MIT License 4 | # -------------------------------------------------------- 5 | import torch 6 | from torch import nn, Tensor 7 | from torch.utils.checkpoint import checkpoint as gradient_checkpoint_fn 8 | from typing import Optional, Dict, Tuple, Union, Any 9 | import argparse 10 | 11 | import logger 12 | 13 | # from ... import parameter_list 14 | from layers import LinearLayer 15 | from profiler import module_profile 16 | from init_utils import initialize_weights, initialize_fc_layer, norm_layers_tuple 17 | 18 | # norm_layers_tuple = tuple(["batch_norm", "batch_norm_2d", "group_norm", "instance_norm", "instance_norm_2d", "layer_norm", "sync_batch_norm"]) 19 | 20 | class BaseEncoder(nn.Module): 21 | """ 22 | Base class for different classification models 23 | """ 24 | 25 | def __init__(self, opts, *args, **kwargs) -> None: 26 | super().__init__() 27 | self.conv_1 = None 28 | self.layer_1 = None 29 | self.layer_2 = None 30 | self.layer_3 = None 31 | self.layer_4 = None 32 | self.layer_5 = None 33 | self.conv_1x1_exp = None 34 | self.classifier = None 35 | self.round_nearest = 8 36 | 37 | # Segmentation architectures like Deeplab and PSPNet modifies the strides of the backbone 38 | # We allow that using output_stride and replace_stride_with_dilation arguments 39 | self.dilation = 1 40 | output_stride = kwargs.get("output_stride", None) 41 | self.dilate_l4 = False 42 | self.dilate_l5 = False 43 | if output_stride == 8: 44 | self.dilate_l4 = True 45 | self.dilate_l5 = True 46 | elif output_stride == 16: 47 | self.dilate_l5 = True 48 | 49 | self.model_conf_dict = dict() 50 | # self.neural_augmentor = build_neural_augmentor(opts=opts, *args, **kwargs) 51 | self.gradient_checkpointing = getattr( 52 | opts, "model.classification.gradient_checkpointing", False 53 | ) 54 | 55 | @classmethod 56 | def add_arguments(cls, parser: argparse.ArgumentParser): 57 | """Add model-specific arguments""" 58 | group = parser.add_argument_group( 59 | title="".format(cls.__name__), description="".format(cls.__name__) 60 | ) 61 | 62 | group.add_argument( 63 | "--model.classification.classifier-dropout", 64 | type=float, 65 | default=0.0, 66 | help="Dropout rate in classifier", 67 | ) 68 | 69 | group.add_argument( 70 | "--model.classification.name", type=str, default=None, help="Model name" 71 | ) 72 | group.add_argument( 73 | "--model.classification.n-classes", 74 | type=int, 75 | default=1000, 76 | help="Number of classes in the dataset", 77 | ) 78 | group.add_argument( 79 | "--model.classification.pretrained", 80 | type=str, 81 | default=None, 82 | help="Path of the pretrained backbone", 83 | ) 84 | group.add_argument( 85 | "--model.classification.freeze-batch-norm", 86 | action="store_true", 87 | help="Freeze batch norm layers", 88 | ) 89 | group.add_argument( 90 | "--model.classification.activation.name", 91 | default=None, 92 | type=str, 93 | help="Non-linear function name (e.g., relu)", 94 | ) 95 | group.add_argument( 96 | "--model.classification.activation.inplace", 97 | action="store_true", 98 | help="Inplace non-linear functions", 99 | ) 100 | group.add_argument( 101 | "--model.classification.activation.neg-slope", 102 | default=0.1, 103 | type=float, 104 | help="Negative slope in leaky relu", 105 | ) 106 | 107 | group.add_argument( 108 | "--model.classification.finetune-pretrained-model", 109 | action="store_true", 110 | help="Finetune a pretrained model", 111 | ) 112 | group.add_argument( 113 | "--model.classification.n-pretrained-classes", 114 | type=int, 115 | default=None, 116 | help="Number of pre-trained classes", 117 | ) 118 | 119 | group.add_argument( 120 | "--model.classification.gradient-checkpointing", 121 | action="store_true", 122 | help="Checkpoint output of each spatial level in the classification backbone. Note that" 123 | "we only take care of checkpointing in {}. If custom forward functions are used, please" 124 | "implement checkpointing accordingly", 125 | ) 126 | 127 | return parser 128 | 129 | def check_model(self): 130 | assert ( 131 | self.model_conf_dict 132 | ), "Model configuration dictionary should not be empty" 133 | assert self.conv_1 is not None, "Please implement self.conv_1" 134 | assert self.layer_1 is not None, "Please implement self.layer_1" 135 | assert self.layer_2 is not None, "Please implement self.layer_2" 136 | assert self.layer_3 is not None, "Please implement self.layer_3" 137 | assert self.layer_4 is not None, "Please implement self.layer_4" 138 | assert self.layer_5 is not None, "Please implement self.layer_5" 139 | assert self.conv_1x1_exp is not None, "Please implement self.conv_1x1_exp" 140 | assert self.classifier is not None, "Please implement self.classifier" 141 | 142 | def reset_parameters(self, opts): 143 | """Initialize model weights""" 144 | initialize_weights(opts=opts, modules=self.modules()) 145 | 146 | def update_classifier(self, opts, n_classes: int) -> None: 147 | """ 148 | This function updates the classification layer in a model. Useful for finetuning purposes. 149 | """ 150 | linear_init_type = getattr(opts, "model.layer.linear_init", "normal") 151 | if isinstance(self.classifier, nn.Sequential): 152 | in_features = self.classifier[-1].in_features 153 | layer = LinearLayer( 154 | in_features=in_features, out_features=n_classes, bias=True 155 | ) 156 | initialize_fc_layer(layer, init_method=linear_init_type) 157 | self.classifier[-1] = layer 158 | else: 159 | in_features = self.classifier.in_features 160 | layer = LinearLayer( 161 | in_features=in_features, out_features=n_classes, bias=True 162 | ) 163 | initialize_fc_layer(layer, init_method=linear_init_type) 164 | 165 | # re-init head 166 | head_init_scale = 0.001 167 | layer.weight.data.mul_(head_init_scale) 168 | layer.bias.data.mul_(head_init_scale) 169 | 170 | self.classifier = layer 171 | 172 | def _forward_layer(self, layer: nn.Module, x: Tensor) -> Tensor: 173 | # Larger models with large input image size may not be able to fit into memory. 174 | # We can use gradient checkpointing to enable training with large models and large inputs 175 | return ( 176 | gradient_checkpoint_fn(layer, x) 177 | if self.gradient_checkpointing 178 | else layer(x) 179 | ) 180 | 181 | def extract_end_points_all( 182 | self, 183 | x: Tensor, 184 | use_l5: Optional[bool] = True, 185 | use_l5_exp: Optional[bool] = False, 186 | *args, 187 | **kwargs 188 | ) -> Dict[str, Tensor]: 189 | out_dict = {} # Use dictionary over NamedTuple so that JIT is happy 190 | 191 | if self.training and self.neural_augmentor is not None: 192 | x = self.neural_augmentor(x) 193 | out_dict["augmented_tensor"] = x 194 | 195 | x = self._forward_layer(self.conv_1, x) # 112 x112 196 | x = self._forward_layer(self.layer_1, x) # 112 x112 197 | out_dict["out_l1"] = x 198 | 199 | x = self._forward_layer(self.layer_2, x) # 56 x 56 200 | out_dict["out_l2"] = x 201 | 202 | x = self._forward_layer(self.layer_3, x) # 28 x 28 203 | out_dict["out_l3"] = x 204 | 205 | x = self._forward_layer(self.layer_4, x) # 14 x 14 206 | out_dict["out_l4"] = x 207 | 208 | if use_l5: 209 | x = self._forward_layer(self.layer_5, x) # 7 x 7 210 | out_dict["out_l5"] = x 211 | 212 | if use_l5_exp: 213 | x = self._forward_layer(self.conv_1x1_exp, x) 214 | out_dict["out_l5_exp"] = x 215 | return out_dict 216 | 217 | def extract_end_points_l4(self, x: Tensor, *args, **kwargs) -> Dict[str, Tensor]: 218 | return self.extract_end_points_all(x, use_l5=False) 219 | 220 | def _extract_features(self, x: Tensor, *args, **kwargs) -> Tensor: 221 | x = self._forward_layer(self.conv_1, x) 222 | x = self._forward_layer(self.layer_1, x) 223 | x = self._forward_layer(self.layer_2, x) 224 | x = self._forward_layer(self.layer_3, x) 225 | 226 | x = self._forward_layer(self.layer_4, x) 227 | x = self._forward_layer(self.layer_5, x) 228 | x = self._forward_layer(self.conv_1x1_exp, x) 229 | return x 230 | 231 | def _forward_classifier(self, x: Tensor, *args, **kwargs) -> Tensor: 232 | # We add another classifier function so that the classifiers 233 | # that do not adhere to the structure of BaseEncoder can still 234 | # use neural augmentor 235 | x = self._extract_features(x) 236 | x = self.classifier(x) 237 | return x 238 | 239 | def forward(self, x: Any, *args, **kwargs) -> Any: 240 | x = self._forward_classifier(x, *args, **kwargs) 241 | return x 242 | 243 | def freeze_norm_layers(self) -> None: 244 | """Freeze normalization layers""" 245 | for m in self.modules(): 246 | if isinstance(m, norm_layers_tuple): 247 | m.eval() 248 | m.weight.requires_grad = False 249 | m.bias.requires_grad = False 250 | m.training = False 251 | 252 | # def get_trainable_parameters( 253 | # self, 254 | # weight_decay: Optional[float] = 0.0, 255 | # no_decay_bn_filter_bias: Optional[bool] = False, 256 | # *args, 257 | # **kwargs 258 | # ): 259 | # """Get trainable parameters""" 260 | # param_list = parameter_list( 261 | # named_parameters=self.named_parameters, 262 | # weight_decay=weight_decay, 263 | # no_decay_bn_filter_bias=no_decay_bn_filter_bias, 264 | # *args, 265 | # **kwargs 266 | # ) 267 | # return param_list, [1.0] * len(param_list) 268 | 269 | @staticmethod 270 | def _profile_layers( 271 | layers, input, overall_params, overall_macs, *args, **kwargs 272 | ) -> Tuple[Tensor, float, float]: 273 | if not isinstance(layers, list): 274 | layers = [layers] 275 | 276 | for layer in layers: 277 | if layer is None: 278 | continue 279 | input, layer_param, layer_macs = module_profile(module=layer, x=input) 280 | 281 | overall_params += layer_param 282 | overall_macs += layer_macs 283 | 284 | if isinstance(layer, nn.Sequential): 285 | module_name = "\n+".join([l.__class__.__name__ for l in layer]) 286 | else: 287 | module_name = layer.__class__.__name__ 288 | print( 289 | "{:<15} \t {:<5}: {:>8.3f} M \t {:<5}: {:>8.3f} M".format( 290 | module_name, 291 | "Params", 292 | round(layer_param / 1e6, 3), 293 | "MACs", 294 | round(layer_macs / 1e6, 3), 295 | ) 296 | ) 297 | logger.singe_dash_line() 298 | return input, overall_params, overall_macs 299 | 300 | def dummy_input_and_label(self, batch_size: int) -> Dict: 301 | """Create dummy input and labels for CI/CD purposes. Child classes must override it 302 | if functionality is different. 303 | """ 304 | img_channels = 3 305 | height = 224 306 | width = 224 307 | n_labels = 10 308 | img_tensor = torch.randn( 309 | batch_size, img_channels, height, width, dtype=torch.float 310 | ) 311 | label_tensor = torch.randint(low=0, high=n_labels, size=(batch_size,)).long() 312 | return {"samples": img_tensor, "targets": label_tensor} 313 | 314 | def profile_model( 315 | self, input: Tensor, is_classification: Optional[bool] = True, *args, **kwargs 316 | ) -> Tuple[Union[Tensor, Dict[str, Tensor]], float, float]: 317 | """ 318 | Helper function to profile a model. 319 | 320 | .. note:: 321 | Model profiling is for reference only and may contain errors as it solely relies on user implementation to 322 | compute theoretical FLOPs 323 | """ 324 | overall_params, overall_macs = 0.0, 0.0 325 | 326 | input_fvcore = input.clone() 327 | 328 | if is_classification: 329 | logger.log("Model statistics for an input of size {}".format(input.size())) 330 | logger.double_dash_line(dashes=65) 331 | print("{:>35} Summary".format(self.__class__.__name__)) 332 | logger.double_dash_line(dashes=65) 333 | 334 | out_dict = {} 335 | input, overall_params, overall_macs = self._profile_layers( 336 | [self.conv_1, self.layer_1], 337 | input=input, 338 | overall_params=overall_params, 339 | overall_macs=overall_macs, 340 | ) 341 | out_dict["out_l1"] = input 342 | 343 | input, overall_params, overall_macs = self._profile_layers( 344 | self.layer_2, 345 | input=input, 346 | overall_params=overall_params, 347 | overall_macs=overall_macs, 348 | ) 349 | out_dict["out_l2"] = input 350 | 351 | input, overall_params, overall_macs = self._profile_layers( 352 | self.layer_3, 353 | input=input, 354 | overall_params=overall_params, 355 | overall_macs=overall_macs, 356 | ) 357 | out_dict["out_l3"] = input 358 | 359 | input, overall_params, overall_macs = self._profile_layers( 360 | self.layer_4, 361 | input=input, 362 | overall_params=overall_params, 363 | overall_macs=overall_macs, 364 | ) 365 | out_dict["out_l4"] = input 366 | 367 | input, overall_params, overall_macs = self._profile_layers( 368 | self.layer_5, 369 | input=input, 370 | overall_params=overall_params, 371 | overall_macs=overall_macs, 372 | ) 373 | out_dict["out_l5"] = input 374 | 375 | if self.conv_1x1_exp is not None: 376 | input, overall_params, overall_macs = self._profile_layers( 377 | self.conv_1x1_exp, 378 | input=input, 379 | overall_params=overall_params, 380 | overall_macs=overall_macs, 381 | ) 382 | out_dict["out_l5_exp"] = input 383 | 384 | if is_classification: 385 | classifier_params, classifier_macs = 0.0, 0.0 386 | if self.classifier is not None: 387 | input, classifier_params, classifier_macs = module_profile( 388 | module=self.classifier, x=input 389 | ) 390 | print( 391 | "{:<15} \t {:<5}: {:>8.3f} M \t {:<5}: {:>8.3f} M".format( 392 | "Classifier", 393 | "Params", 394 | round(classifier_params / 1e6, 3), 395 | "MACs", 396 | round(classifier_macs / 1e6, 3), 397 | ) 398 | ) 399 | overall_params += classifier_params 400 | overall_macs += classifier_macs 401 | 402 | logger.double_dash_line(dashes=65) 403 | print( 404 | "{:<20} = {:>8.3f} M".format("Overall parameters", overall_params / 1e6) 405 | ) 406 | overall_params_py = sum([p.numel() for p in self.parameters()]) 407 | print( 408 | "{:<20} = {:>8.3f} M".format( 409 | "Overall parameters (sanity check)", overall_params_py / 1e6 410 | ) 411 | ) 412 | 413 | # Counting Addition and Multiplication as 1 operation 414 | print( 415 | "{:<20} = {:>8.3f} M".format( 416 | "Overall MACs (theoretical)", overall_macs / 1e6 417 | ) 418 | ) 419 | 420 | print("Note: Theoretical MACs depends on user-implementation. Be cautious") 421 | logger.double_dash_line(dashes=65) 422 | 423 | return out_dict, overall_params, overall_macs 424 | -------------------------------------------------------------------------------- /base_layer.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from torch import nn, Tensor 7 | import argparse 8 | from typing import Any, Tuple 9 | 10 | 11 | class BaseLayer(nn.Module): 12 | """ 13 | Base class for neural network layers 14 | """ 15 | 16 | def __init__(self, *args, **kwargs) -> None: 17 | super().__init__() 18 | 19 | @classmethod 20 | def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: 21 | """Add layer specific arguments""" 22 | return parser 23 | 24 | def forward(self, *args, **kwargs) -> Any: 25 | pass 26 | 27 | def profile_module(self, *args, **kwargs) -> Tuple[Tensor, float, float]: 28 | raise NotImplementedError 29 | 30 | def __repr__(self): 31 | return "{}".format(self.__class__.__name__) 32 | -------------------------------------------------------------------------------- /base_module.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import torch 7 | from torch import nn, Tensor 8 | from typing import Tuple, Union, Any 9 | 10 | 11 | class BaseModule(nn.Module): 12 | """Base class for all modules""" 13 | 14 | def __init__(self, *args, **kwargs): 15 | super(BaseModule, self).__init__() 16 | 17 | def forward(self, x: Any, *args, **kwargs) -> Any: 18 | raise NotImplementedError 19 | 20 | def profile_module(self, input: Any, *args, **kwargs) -> Tuple[Any, float, float]: 21 | raise NotImplementedError 22 | 23 | def __repr__(self): 24 | return "{}".format(self.__class__.__name__) 25 | -------------------------------------------------------------------------------- /init_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from torch import nn 7 | from typing import Optional 8 | 9 | import logger 10 | 11 | from layers import LinearLayer, GroupLinear, norm_layers_tuple 12 | 13 | supported_conv_inits = [ 14 | "kaiming_normal", 15 | "kaiming_uniform", 16 | "xavier_normal", 17 | "xavier_uniform", 18 | "normal", 19 | "trunc_normal", 20 | ] 21 | supported_fc_inits = [ 22 | "kaiming_normal", 23 | "kaiming_uniform", 24 | "xavier_normal", 25 | "xavier_uniform", 26 | "normal", 27 | "trunc_normal", 28 | ] 29 | 30 | 31 | def _init_nn_layers( 32 | module, 33 | init_method: Optional[str] = "kaiming_normal", 34 | std_val: Optional[float] = None, 35 | ) -> None: 36 | """ 37 | Helper function to initialize neural network module 38 | """ 39 | init_method = init_method.lower() 40 | if init_method == "kaiming_normal": 41 | if module.weight is not None: 42 | nn.init.kaiming_normal_(module.weight, mode="fan_out") 43 | if module.bias is not None: 44 | nn.init.zeros_(module.bias) 45 | elif init_method == "kaiming_uniform": 46 | if module.weight is not None: 47 | nn.init.kaiming_uniform_(module.weight, mode="fan_out") 48 | if module.bias is not None: 49 | nn.init.zeros_(module.bias) 50 | elif init_method == "xavier_normal": 51 | if module.weight is not None: 52 | nn.init.xavier_normal_(module.weight) 53 | if module.bias is not None: 54 | nn.init.zeros_(module.bias) 55 | elif init_method == "xavier_uniform": 56 | if module.weight is not None: 57 | nn.init.xavier_uniform_(module.weight) 58 | if module.bias is not None: 59 | nn.init.zeros_(module.bias) 60 | elif init_method == "normal": 61 | if module.weight is not None: 62 | std = 1.0 / module.weight.size(1) ** 0.5 if std_val is None else std_val 63 | nn.init.normal_(module.weight, mean=0.0, std=std) 64 | if module.bias is not None: 65 | nn.init.zeros_(module.bias) 66 | elif init_method == "trunc_normal": 67 | if module.weight is not None: 68 | std = 1.0 / module.weight.size(1) ** 0.5 if std_val is None else std_val 69 | nn.init.trunc_normal_(module.weight, mean=0.0, std=std) 70 | if module.bias is not None: 71 | nn.init.zeros_(module.bias) 72 | else: 73 | supported_conv_message = "Supported initialization methods are:" 74 | for i, l in enumerate(supported_conv_inits): 75 | supported_conv_message += "\n \t {}) {}".format(i, l) 76 | logger.error("{} \n Got: {}".format(supported_conv_message, init_method)) 77 | 78 | 79 | def initialize_conv_layer( 80 | module, 81 | init_method: Optional[str] = "kaiming_normal", 82 | std_val: Optional[float] = 0.01, 83 | ) -> None: 84 | """Helper function to initialize convolution layers""" 85 | _init_nn_layers(module=module, init_method=init_method, std_val=std_val) 86 | 87 | 88 | def initialize_fc_layer( 89 | module, init_method: Optional[str] = "normal", std_val: Optional[float] = 0.01 90 | ) -> None: 91 | """Helper function to initialize fully-connected layers""" 92 | if hasattr(module, "layer"): 93 | _init_nn_layers(module=module.layer, init_method=init_method, std_val=std_val) 94 | else: 95 | _init_nn_layers(module=module, init_method=init_method, std_val=std_val) 96 | 97 | 98 | def initialize_norm_layers(module) -> None: 99 | """Helper function to initialize normalization layers""" 100 | 101 | def _init_fn(module): 102 | if hasattr(module, "weight") and module.weight is not None: 103 | nn.init.ones_(module.weight) 104 | if hasattr(module, "bias") and module.bias is not None: 105 | nn.init.zeros_(module.bias) 106 | 107 | _init_fn(module.layer) if hasattr(module, "layer") else _init_fn(module=module) 108 | 109 | 110 | def initialize_weights(opts, modules) -> None: 111 | """Helper function to initialize differnet layers in a model""" 112 | # weight initialization 113 | conv_init_type = getattr(opts, "model.layer.conv_init", "kaiming_normal") 114 | linear_init_type = getattr(opts, "model.layer.linear_init", "normal") 115 | 116 | conv_std = getattr(opts, "model.layer.conv_init_std_dev", None) 117 | linear_std = getattr(opts, "model.layer.linear_init_std_dev", 0.01) 118 | group_linear_std = getattr(opts, "model.layer.group_linear_init_std_dev", 0.01) 119 | 120 | if isinstance(modules, nn.Sequential): 121 | for m in modules: 122 | if isinstance(m, (nn.Conv2d, nn.Conv3d)): 123 | initialize_conv_layer( 124 | module=m, init_method=conv_init_type, std_val=conv_std 125 | ) 126 | elif isinstance(m, norm_layers_tuple): 127 | initialize_norm_layers(module=m) 128 | elif isinstance(m, (nn.Linear, LinearLayer)): 129 | initialize_fc_layer( 130 | module=m, init_method=linear_init_type, std_val=linear_std 131 | ) 132 | elif isinstance(m, GroupLinear): 133 | initialize_fc_layer( 134 | module=m, init_method=linear_init_type, std_val=group_linear_std 135 | ) 136 | else: 137 | if isinstance(modules, (nn.Conv2d, nn.Conv3d)): 138 | initialize_conv_layer( 139 | module=modules, init_method=conv_init_type, std_val=conv_std 140 | ) 141 | # elif isinstance(modules, norm_layers_tuple): 142 | # initialize_norm_layers(module=modules) 143 | elif isinstance(modules, (nn.Linear, LinearLayer)): 144 | initialize_fc_layer( 145 | module=modules, init_method=linear_init_type, std_val=linear_std 146 | ) 147 | elif isinstance(modules, GroupLinear): 148 | initialize_fc_layer( 149 | module=modules, init_method=linear_init_type, std_val=group_linear_std 150 | ) 151 | -------------------------------------------------------------------------------- /layers.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | # 5 | # For licensing see accompanying LICENSE file. 6 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 7 | # 8 | 9 | from torch import nn, Tensor 10 | import argparse 11 | import torch 12 | from typing import Any, Optional, Union, Tuple 13 | 14 | # 15 | # For licensing see accompanying LICENSE file. 16 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 17 | # 18 | 19 | import torch 20 | import logger 21 | 22 | from torch import Tensor 23 | import argparse 24 | from typing import List, Optional, Tuple 25 | from torch.nn import functional as F 26 | 27 | norm_layers_tuple = tuple(["batch_norm", "batch_norm_2d", "group_norm", "instance_norm", "instance_norm_2d", "layer_norm", "sync_batch_norm"]) 28 | 29 | class Identity(nn.Module): 30 | """ 31 | This is a place-holder and returns the same tensor. 32 | """ 33 | 34 | def __init__(self): 35 | super(Identity, self).__init__() 36 | 37 | def forward(self, x: Tensor) -> Tensor: 38 | return x 39 | 40 | def profile_module(self, x: Tensor) -> Tuple[Tensor, float, float]: 41 | return x, 0.0, 0.0 42 | 43 | class BaseLayer(nn.Module): 44 | """ 45 | Base class for neural network layers 46 | """ 47 | 48 | def __init__(self, *args, **kwargs) -> None: 49 | super().__init__() 50 | 51 | @classmethod 52 | def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: 53 | """Add layer specific arguments""" 54 | return parser 55 | 56 | def forward(self, *args, **kwargs) -> Any: 57 | pass 58 | 59 | def profile_module(self, *args, **kwargs) -> Tuple[Tensor, float, float]: 60 | raise NotImplementedError 61 | 62 | def __repr__(self): 63 | return "{}".format(self.__class__.__name__) 64 | 65 | 66 | class Conv2d(nn.Conv2d): 67 | """ 68 | Applies a 2D convolution over an input 69 | 70 | Args: 71 | in_channels (int): :math:`C_{in}` from an expected input of size :math:`(N, C_{in}, H_{in}, W_{in})` 72 | out_channels (int): :math:`C_{out}` from an expected output of size :math:`(N, C_{out}, H_{out}, W_{out})` 73 | kernel_size (Union[int, Tuple[int, int]]): Kernel size for convolution. 74 | stride (Union[int, Tuple[int, int]]): Stride for convolution. Defaults to 1 75 | padding (Union[int, Tuple[int, int]]): Padding for convolution. Defaults to 0 76 | dilation (Union[int, Tuple[int, int]]): Dilation rate for convolution. Default: 1 77 | groups (Optional[int]): Number of groups in convolution. Default: 1 78 | bias (bool): Use bias. Default: ``False`` 79 | padding_mode (Optional[str]): Padding mode. Default: ``zeros`` 80 | 81 | use_norm (Optional[bool]): Use normalization layer after convolution. Default: ``True`` 82 | use_act (Optional[bool]): Use activation layer after convolution (or convolution and normalization). 83 | Default: ``True`` 84 | act_name (Optional[str]): Use specific activation function. Overrides the one specified in command line args. 85 | 86 | Shape: 87 | - Input: :math:`(N, C_{in}, H_{in}, W_{in})` 88 | - Output: :math:`(N, C_{out}, H_{out}, W_{out})` 89 | """ 90 | 91 | def __init__( 92 | self, 93 | in_channels: int, 94 | out_channels: int, 95 | kernel_size: Union[int, Tuple[int, int]], 96 | stride: Optional[Union[int, Tuple[int, int]]] = 1, 97 | padding: Optional[Union[int, Tuple[int, int]]] = 0, 98 | dilation: Optional[Union[int, Tuple[int, int]]] = 1, 99 | groups: Optional[int] = 1, 100 | bias: Optional[bool] = False, 101 | padding_mode: Optional[str] = "zeros", 102 | *args, 103 | **kwargs 104 | ) -> None: 105 | super().__init__( 106 | in_channels=in_channels, 107 | out_channels=out_channels, 108 | kernel_size=kernel_size, 109 | stride=stride, 110 | padding=padding, 111 | dilation=dilation, 112 | groups=groups, 113 | bias=bias, 114 | padding_mode=padding_mode, 115 | ) 116 | 117 | 118 | class ConvLayer(BaseLayer): 119 | """ 120 | Applies a 2D convolution over an input 121 | 122 | Args: 123 | opts: command line arguments 124 | in_channels (int): :math:`C_{in}` from an expected input of size :math:`(N, C_{in}, H_{in}, W_{in})` 125 | out_channels (int): :math:`C_{out}` from an expected output of size :math:`(N, C_{out}, H_{out}, W_{out})` 126 | kernel_size (Union[int, Tuple[int, int]]): Kernel size for convolution. 127 | stride (Union[int, Tuple[int, int]]): Stride for convolution. Default: 1 128 | dilation (Union[int, Tuple[int, int]]): Dilation rate for convolution. Default: 1 129 | padding (Union[int, Tuple[int, int]]): Padding for convolution. When not specified, 130 | padding is automatically computed based on kernel size 131 | and dilation rage. Default is ``None`` 132 | groups (Optional[int]): Number of groups in convolution. Default: ``1`` 133 | bias (Optional[bool]): Use bias. Default: ``False`` 134 | padding_mode (Optional[str]): Padding mode. Default: ``zeros`` 135 | use_norm (Optional[bool]): Use normalization layer after convolution. Default: ``True`` 136 | use_act (Optional[bool]): Use activation layer after convolution (or convolution and normalization). 137 | Default: ``True`` 138 | act_name (Optional[str]): Use specific activation function. Overrides the one specified in command line args. 139 | 140 | Shape: 141 | - Input: :math:`(N, C_{in}, H_{in}, W_{in})` 142 | - Output: :math:`(N, C_{out}, H_{out}, W_{out})` 143 | 144 | .. note:: 145 | For depth-wise convolution, `groups=C_{in}=C_{out}`. 146 | """ 147 | 148 | def __init__( 149 | self, 150 | in_channels: int, 151 | out_channels: int, 152 | kernel_size: Union[int, Tuple[int, int]], 153 | stride: Optional[Union[int, Tuple[int, int]]] = 1, 154 | dilation: Optional[Union[int, Tuple[int, int]]] = 1, 155 | padding: Optional[Union[int, Tuple[int, int]]] = None, 156 | groups: Optional[int] = 1, 157 | bias: Optional[bool] = False, 158 | padding_mode: Optional[str] = "zeros", 159 | use_norm: Optional[bool] = True, 160 | use_act: Optional[bool] = True, 161 | act_name: Optional[str] = None, 162 | *args, 163 | **kwargs 164 | ) -> None: 165 | super().__init__() 166 | 167 | if use_norm: 168 | norm_type = "batch_norm" 169 | if norm_type is not None and norm_type.find("batch") > -1: 170 | assert not bias, "Do not use bias when using normalization layers." 171 | elif norm_type is not None and norm_type.find("layer") > -1: 172 | bias = True 173 | if isinstance(kernel_size, int): 174 | kernel_size = (kernel_size, kernel_size) 175 | 176 | if isinstance(stride, int): 177 | stride = (stride, stride) 178 | 179 | if isinstance(dilation, int): 180 | dilation = (dilation, dilation) 181 | 182 | assert isinstance(kernel_size, Tuple) 183 | assert isinstance(stride, Tuple) 184 | assert isinstance(dilation, Tuple) 185 | 186 | if padding is None: 187 | padding = ( 188 | int((kernel_size[0] - 1) / 2) * dilation[0], 189 | int((kernel_size[1] - 1) / 2) * dilation[1], 190 | ) 191 | 192 | block = nn.Sequential() 193 | 194 | conv_layer = Conv2d( 195 | in_channels=in_channels, 196 | out_channels=out_channels, 197 | kernel_size=kernel_size, 198 | stride=stride, 199 | padding=padding, 200 | dilation=dilation, 201 | groups=groups, 202 | bias=bias, 203 | padding_mode=padding_mode, 204 | ) 205 | 206 | block.add_module(name="conv", module=conv_layer) 207 | 208 | self.norm_name = None 209 | if use_norm: 210 | norm_layer = nn.BatchNorm2d(num_features=out_channels) 211 | block.add_module(name="norm", module=norm_layer) 212 | self.norm_name = norm_layer.__class__.__name__ 213 | 214 | self.act_name = None 215 | act_type = ( 216 | "prelu" 217 | if act_name is None 218 | else act_name 219 | ) 220 | 221 | if act_type is not None and use_act: 222 | act_layer = nn.PReLU() 223 | block.add_module(name="act", module=act_layer) 224 | self.act_name = act_layer.__class__.__name__ 225 | 226 | self.block = block 227 | 228 | self.in_channels = in_channels 229 | self.out_channels = out_channels 230 | self.stride = stride 231 | self.groups = groups 232 | self.kernel_size = conv_layer.kernel_size 233 | self.bias = bias 234 | self.dilation = dilation 235 | 236 | @classmethod 237 | def add_arguments(cls, parser: argparse.ArgumentParser): 238 | cls_name = "{} arguments".format(cls.__name__) 239 | group = parser.add_argument_group(title=cls_name, description=cls_name) 240 | group.add_argument( 241 | "--model.layer.conv-init", 242 | type=str, 243 | default="kaiming_normal", 244 | help="Init type for conv layers", 245 | ) 246 | parser.add_argument( 247 | "--model.layer.conv-init-std-dev", 248 | type=float, 249 | default=None, 250 | help="Std deviation for conv layers", 251 | ) 252 | return parser 253 | 254 | def forward(self, x: Tensor) -> Tensor: 255 | return self.block(x) 256 | 257 | def __repr__(self): 258 | repr_str = self.block[0].__repr__() 259 | repr_str = repr_str[:-1] 260 | 261 | if self.norm_name is not None: 262 | repr_str += ", normalization={}".format(self.norm_name) 263 | 264 | if self.act_name is not None: 265 | repr_str += ", activation={}".format(self.act_name) 266 | repr_str += ")" 267 | return repr_str 268 | 269 | def profile_module(self, input: Tensor) -> (Tensor, float, float): 270 | b, in_c, in_h, in_w = input.size() 271 | assert in_c == self.in_channels, "{}!={}".format(in_c, self.in_channels) 272 | 273 | stride_h, stride_w = self.stride 274 | groups = self.groups 275 | 276 | out_h = in_h // stride_h 277 | out_w = in_w // stride_w 278 | 279 | k_h, k_w = self.kernel_size 280 | 281 | # compute MACS 282 | macs = (k_h * k_w) * (in_c * self.out_channels) * (out_h * out_w) * 1.0 283 | macs /= groups 284 | 285 | if self.bias: 286 | macs += self.out_channels * out_h * out_w 287 | 288 | # compute parameters 289 | params = sum([p.numel() for p in self.parameters()]) 290 | 291 | output = torch.zeros( 292 | size=(b, self.out_channels, out_h, out_w), 293 | dtype=input.dtype, 294 | device=input.device, 295 | ) 296 | # print(macs) 297 | return output, params, macs 298 | 299 | class ConvBNReLU(nn.Sequential): 300 | def __init__( 301 | self, 302 | in_planes, 303 | out_planes, 304 | kernel_size=3, 305 | stride=1, 306 | groups=1, 307 | activation="ReLU", 308 | ): 309 | padding = (kernel_size - 1) // 2 310 | super(ConvBNReLU, self).__init__( 311 | nn.Conv2d( 312 | in_planes, 313 | out_planes, 314 | kernel_size, 315 | stride, 316 | padding, 317 | groups=groups, 318 | bias=False, 319 | ), 320 | nn.BatchNorm2d(out_planes), 321 | nn.ReLU(), 322 | ) 323 | class ConvBNReLU(nn.Sequential): 324 | def __init__( 325 | self, 326 | in_planes, 327 | out_planes, 328 | kernel_size=3, 329 | stride=1, 330 | groups=1, 331 | activation="ReLU", 332 | ): 333 | padding = (kernel_size - 1) // 2 334 | super(ConvBNReLU, self).__init__( 335 | nn.Conv2d( 336 | in_planes, 337 | out_planes, 338 | kernel_size, 339 | stride, 340 | padding, 341 | groups=groups, 342 | bias=False, 343 | ), 344 | nn.BatchNorm2d(out_planes), 345 | nn.ReLU(), 346 | ) 347 | 348 | 349 | class InvertedResidual(nn.Module): 350 | def __init__(self, inp, oup, stride, expand_ratio, activation="ReLU"): 351 | super(InvertedResidual, self).__init__() 352 | self.stride = stride 353 | assert stride in [1, 2] 354 | 355 | hidden_dim = int(round(inp * expand_ratio)) 356 | self.use_res_connect = self.stride == 1 and inp == oup 357 | 358 | layers = [] 359 | if expand_ratio != 1: 360 | # pw 361 | layers.append( 362 | ConvBNReLU(inp, hidden_dim, kernel_size=1, activation=activation) 363 | ) 364 | layers.extend( 365 | [ 366 | # dw 367 | ConvBNReLU( 368 | hidden_dim, 369 | hidden_dim, 370 | stride=stride, 371 | groups=hidden_dim, 372 | activation=activation, 373 | ), 374 | # pw-linear 375 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 376 | nn.BatchNorm2d(oup), 377 | ] 378 | ) 379 | self.conv = nn.Sequential(*layers) 380 | 381 | def forward(self, x): 382 | if self.use_res_connect: 383 | return x + self.conv(x) 384 | else: 385 | return self.conv(x) 386 | 387 | 388 | class LinearLayer(BaseLayer): 389 | """ 390 | Applies a linear transformation to the input data 391 | 392 | Args: 393 | in_features (int): number of features in the input tensor 394 | out_features (int): number of features in the output tensor 395 | bias (Optional[bool]): use bias or not 396 | channel_first (Optional[bool]): Channels are first or last dimension. If first, then use Conv2d 397 | 398 | Shape: 399 | - Input: :math:`(N, *, C_{in})` if not channel_first else :math:`(N, C_{in}, *)` where :math:`*` means any number of dimensions. 400 | - Output: :math:`(N, *, C_{out})` if not channel_first else :math:`(N, C_{out}, *)` 401 | 402 | """ 403 | 404 | def __init__( 405 | self, 406 | in_features: int, 407 | out_features: int, 408 | bias: Optional[bool] = True, 409 | channel_first: Optional[bool] = False, 410 | *args, 411 | **kwargs 412 | ) -> None: 413 | super().__init__() 414 | self.weight = nn.Parameter(torch.Tensor(out_features, in_features)) 415 | self.bias = nn.Parameter(torch.Tensor(out_features)) if bias else None 416 | 417 | self.in_features = in_features 418 | self.out_features = out_features 419 | self.channel_first = channel_first 420 | 421 | self.reset_params() 422 | 423 | @classmethod 424 | def add_arguments(cls, parser: argparse.ArgumentParser): 425 | parser.add_argument( 426 | "--model.layer.linear-init", 427 | type=str, 428 | default="xavier_uniform", 429 | help="Init type for linear layers", 430 | ) 431 | parser.add_argument( 432 | "--model.layer.linear-init-std-dev", 433 | type=float, 434 | default=0.01, 435 | help="Std deviation for Linear layers", 436 | ) 437 | return parser 438 | 439 | def reset_params(self): 440 | if self.weight is not None: 441 | torch.nn.init.xavier_uniform_(self.weight) 442 | if self.bias is not None: 443 | torch.nn.init.constant_(self.bias, 0) 444 | 445 | def forward(self, x: Tensor) -> Tensor: 446 | if self.channel_first: 447 | if not self.training: 448 | logger.error("Channel-first mode is only supported during inference") 449 | if x.dim() != 4: 450 | logger.error("Input should be 4D, i.e., (B, C, H, W) format") 451 | # only run during conversion 452 | with torch.no_grad(): 453 | return F.conv2d( 454 | input=x, 455 | weight=self.weight.clone() 456 | .detach() 457 | .reshape(self.out_features, self.in_features, 1, 1), 458 | bias=self.bias, 459 | ) 460 | else: 461 | x = F.linear(x, weight=self.weight, bias=self.bias) 462 | return x 463 | 464 | def __repr__(self): 465 | repr_str = ( 466 | "{}(in_features={}, out_features={}, bias={}, channel_first={})".format( 467 | self.__class__.__name__, 468 | self.in_features, 469 | self.out_features, 470 | True if self.bias is not None else False, 471 | self.channel_first, 472 | ) 473 | ) 474 | return repr_str 475 | 476 | def profile_module( 477 | self, input: Tensor, *args, **kwargs 478 | ) -> Tuple[Tensor, float, float]: 479 | out_size = list(input.shape) 480 | out_size[-1] = self.out_features 481 | params = sum([p.numel() for p in self.parameters()]) 482 | macs = params 483 | output = torch.zeros(size=out_size, dtype=input.dtype, device=input.device) 484 | return output, params, macs 485 | 486 | 487 | class GroupLinear(BaseLayer): 488 | """ 489 | Applies a GroupLinear transformation layer, as defined `here `_, 490 | `here `_ and `here `_ 491 | 492 | Args: 493 | in_features (int): number of features in the input tensor 494 | out_features (int): number of features in the output tensor 495 | n_groups (int): number of groups 496 | bias (Optional[bool]): use bias or not 497 | feature_shuffle (Optional[bool]): Shuffle features between groups 498 | 499 | Shape: 500 | - Input: :math:`(N, *, C_{in})` 501 | - Output: :math:`(N, *, C_{out})` 502 | 503 | """ 504 | 505 | def __init__( 506 | self, 507 | in_features: int, 508 | out_features: int, 509 | n_groups: int, 510 | bias: Optional[bool] = True, 511 | feature_shuffle: Optional[bool] = False, 512 | *args, 513 | **kwargs 514 | ) -> None: 515 | if in_features % n_groups != 0: 516 | logger.error( 517 | "Input dimensions ({}) must be divisible by n_groups ({})".format( 518 | in_features, n_groups 519 | ) 520 | ) 521 | if out_features % n_groups != 0: 522 | logger.error( 523 | "Output dimensions ({}) must be divisible by n_groups ({})".format( 524 | out_features, n_groups 525 | ) 526 | ) 527 | 528 | in_groups = in_features // n_groups 529 | out_groups = out_features // n_groups 530 | 531 | super().__init__() 532 | 533 | self.weight = nn.Parameter(torch.Tensor(n_groups, in_groups, out_groups)) 534 | if bias: 535 | self.bias = nn.Parameter(torch.Tensor(n_groups, 1, out_groups)) 536 | else: 537 | self.bias = None 538 | 539 | self.out_features = out_features 540 | self.in_features = in_features 541 | self.n_groups = n_groups 542 | self.feature_shuffle = feature_shuffle 543 | 544 | self.reset_params() 545 | 546 | @classmethod 547 | def add_arguments(cls, parser: argparse.ArgumentParser): 548 | parser.add_argument( 549 | "--model.layer.group-linear-init", 550 | type=str, 551 | default="xavier_uniform", 552 | help="Init type for group linear layers", 553 | ) 554 | parser.add_argument( 555 | "--model.layer.group-linear-init-std-dev", 556 | type=float, 557 | default=0.01, 558 | help="Std deviation for group linear layers", 559 | ) 560 | return parser 561 | 562 | def reset_params(self): 563 | if self.weight is not None: 564 | torch.nn.init.xavier_uniform_(self.weight.data) 565 | if self.bias is not None: 566 | torch.nn.init.constant_(self.bias.data, 0) 567 | 568 | def _forward(self, x: Tensor) -> Tensor: 569 | bsz = x.shape[0] 570 | # [B, N] --> [B, g, N/g] 571 | x = x.reshape(bsz, self.n_groups, -1) 572 | 573 | # [B, g, N/g] --> [g, B, N/g] 574 | x = x.transpose(0, 1) 575 | # [g, B, N/g] x [g, N/g, M/g] --> [g, B, M/g] 576 | x = torch.bmm(x, self.weight) 577 | 578 | if self.bias is not None: 579 | x = torch.add(x, self.bias) 580 | 581 | if self.feature_shuffle: 582 | # [g, B, M/g] --> [B, M/g, g] 583 | x = x.permute(1, 2, 0) 584 | # [B, M/g, g] --> [B, g, M/g] 585 | x = x.reshape(bsz, self.n_groups, -1) 586 | else: 587 | # [g, B, M/g] --> [B, g, M/g] 588 | x = x.transpose(0, 1) 589 | 590 | return x.reshape(bsz, -1) 591 | 592 | def forward(self, x: Tensor) -> Tensor: 593 | if x.dim() == 2: 594 | x = self._forward(x) 595 | return x 596 | else: 597 | in_dims = x.shape[:-1] 598 | n_elements = x.numel() // self.in_features 599 | x = x.reshape(n_elements, -1) 600 | x = self._forward(x) 601 | x = x.reshape(*in_dims, -1) 602 | return x 603 | 604 | def __repr__(self): 605 | repr_str = "{}(in_features={}, out_features={}, groups={}, bias={}, shuffle={})".format( 606 | self.__class__.__name__, 607 | self.in_features, 608 | self.out_features, 609 | self.n_groups, 610 | True if self.bias is not None else False, 611 | self.feature_shuffle, 612 | ) 613 | return repr_str 614 | 615 | def profile_module(self, input: Tensor) -> Tuple[Tensor, float, float]: 616 | params = sum([p.numel() for p in self.parameters()]) 617 | macs = params 618 | 619 | out_size = list(input.shape) 620 | out_size[-1] = self.out_features 621 | 622 | output = torch.zeros(size=out_size, dtype=input.dtype, device=input.device) 623 | return output, params, macs 624 | 625 | 626 | class GlobalPool(BaseLayer): 627 | """ 628 | This layers applies global pooling over a 4D or 5D input tensor 629 | 630 | Args: 631 | pool_type (Optional[str]): Pooling type. It can be mean, rms, or abs. Default: `mean` 632 | keep_dim (Optional[bool]): Do not squeeze the dimensions of a tensor. Default: `False` 633 | 634 | Shape: 635 | - Input: :math:`(N, C, H, W)` or :math:`(N, C, D, H, W)` 636 | - Output: :math:`(N, C, 1, 1)` or :math:`(N, C, 1, 1, 1)` if keep_dim else :math:`(N, C)` 637 | """ 638 | 639 | pool_types = ["mean", "rms", "abs"] 640 | 641 | def __init__( 642 | self, 643 | pool_type: Optional[str] = "mean", 644 | keep_dim: Optional[bool] = False, 645 | *args, 646 | **kwargs 647 | ) -> None: 648 | super().__init__() 649 | if pool_type not in self.pool_types: 650 | logger.error( 651 | "Supported pool types are: {}. Got {}".format( 652 | self.pool_types, pool_type 653 | ) 654 | ) 655 | self.pool_type = pool_type 656 | self.keep_dim = keep_dim 657 | 658 | @classmethod 659 | def add_arguments(cls, parser: argparse.ArgumentParser): 660 | cls_name = "{} arguments".format(cls.__name__) 661 | group = parser.add_argument_group(title=cls_name, description=cls_name) 662 | group.add_argument( 663 | "--model.layer.global-pool", 664 | type=str, 665 | default="mean", 666 | help="Which global pooling?", 667 | ) 668 | return parser 669 | 670 | def _global_pool(self, x: Tensor, dims: List): 671 | if self.pool_type == "rms": # root mean square 672 | x = x**2 673 | x = torch.mean(x, dim=dims, keepdim=self.keep_dim) 674 | x = x**-0.5 675 | elif self.pool_type == "abs": # absolute 676 | x = torch.mean(torch.abs(x), dim=dims, keepdim=self.keep_dim) 677 | else: 678 | # default is mean 679 | # same as AdaptiveAvgPool 680 | x = torch.mean(x, dim=dims, keepdim=self.keep_dim) 681 | return x 682 | 683 | def forward(self, x: Tensor) -> Tensor: 684 | if x.dim() == 4: 685 | dims = [-2, -1] 686 | elif x.dim() == 5: 687 | dims = [-3, -2, -1] 688 | else: 689 | raise NotImplementedError("Currently 2D and 3D global pooling supported") 690 | return self._global_pool(x, dims=dims) 691 | 692 | def profile_module(self, input: Tensor) -> Tuple[Tensor, float, float]: 693 | input = self.forward(input) 694 | return input, 0.0, 0.0 695 | 696 | def __repr__(self): 697 | return "{}(type={})".format(self.__class__.__name__, self.pool_type) 698 | 699 | 700 | class Dropout(nn.Dropout): 701 | """ 702 | This layer, during training, randomly zeroes some of the elements of the input tensor with probability `p` 703 | using samples from a Bernoulli distribution. 704 | 705 | Args: 706 | p: probability of an element to be zeroed. Default: 0.5 707 | inplace: If set to ``True``, will do this operation in-place. Default: ``False`` 708 | 709 | Shape: 710 | - Input: :math:`(N, *)` where :math:`N` is the batch size 711 | - Output: same as the input 712 | 713 | """ 714 | 715 | def __init__( 716 | self, p: Optional[float] = 0.5, inplace: Optional[bool] = False, *args, **kwargs 717 | ) -> None: 718 | super().__init__(p=p, inplace=inplace) 719 | 720 | def profile_module( 721 | self, input: Tensor, *args, **kwargs 722 | ) -> Tuple[Tensor, float, float]: 723 | return input, 0.0, 0.0 724 | 725 | 726 | class Dropout2d(nn.Dropout2d): 727 | """ 728 | This layer, during training, randomly zeroes some of the elements of the 4D input tensor with probability `p` 729 | using samples from a Bernoulli distribution. 730 | 731 | Args: 732 | p: probability of an element to be zeroed. Default: 0.5 733 | inplace: If set to ``True``, will do this operation in-place. Default: ``False`` 734 | 735 | Shape: 736 | - Input: :math:`(N, C, H, W)` where :math:`N` is the batch size, :math:`C` is the input channels, 737 | :math:`H` is the input tensor height, and :math:`W` is the input tensor width 738 | - Output: same as the input 739 | 740 | """ 741 | 742 | def __init__(self, p: float = 0.5, inplace: bool = False): 743 | super().__init__(p=p, inplace=inplace) 744 | 745 | def profile_module(self, input: Tensor, *args, **kwargs) -> (Tensor, float, float): 746 | return input, 0.0, 0.0 747 | 748 | -------------------------------------------------------------------------------- /logger.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Copyright (c) 2023 Microsoft 3 | # Licensed under The MIT License 4 | # -------------------------------------------------------- 5 | 6 | import time 7 | from typing import Optional 8 | import sys 9 | import os 10 | 11 | text_colors = { 12 | "logs": "\033[34m", # 033 is the escape code and 34 is the color code 13 | "info": "\033[32m", 14 | "warning": "\033[33m", 15 | "debug": "\033[93m", 16 | "error": "\033[31m", 17 | "bold": "\033[1m", 18 | "end_color": "\033[0m", 19 | "light_red": "\033[36m", 20 | } 21 | 22 | 23 | def get_curr_time_stamp() -> str: 24 | return time.strftime("%Y-%m-%d %H:%M:%S") 25 | 26 | 27 | def error(message: str) -> None: 28 | time_stamp = get_curr_time_stamp() 29 | error_str = ( 30 | text_colors["error"] 31 | + text_colors["bold"] 32 | + "ERROR " 33 | + text_colors["end_color"] 34 | ) 35 | 36 | sys.exit("{} - {} - {}. Exiting!!!".format(time_stamp, error_str, message)) 37 | 38 | 39 | def color_text(in_text: str) -> str: 40 | return text_colors["light_red"] + in_text + text_colors["end_color"] 41 | 42 | 43 | def log(message: str, end="\n") -> None: 44 | time_stamp = get_curr_time_stamp() 45 | log_str = ( 46 | text_colors["logs"] + text_colors["bold"] + "LOGS " + text_colors["end_color"] 47 | ) 48 | print("{} - {} - {}".format(time_stamp, log_str, message), end=end) 49 | 50 | 51 | def warning(message: str) -> None: 52 | time_stamp = get_curr_time_stamp() 53 | warn_str = ( 54 | text_colors["warning"] 55 | + text_colors["bold"] 56 | + "WARNING" 57 | + text_colors["end_color"] 58 | ) 59 | print("{} - {} - {}".format(time_stamp, warn_str, message)) 60 | 61 | 62 | def info(message: str, print_line: Optional[bool] = False) -> None: 63 | time_stamp = get_curr_time_stamp() 64 | info_str = ( 65 | text_colors["info"] + text_colors["bold"] + "INFO " + text_colors["end_color"] 66 | ) 67 | print("{} - {} - {}".format(time_stamp, info_str, message)) 68 | if print_line: 69 | double_dash_line(dashes=150) 70 | 71 | 72 | def debug(message: str) -> None: 73 | time_stamp = get_curr_time_stamp() 74 | log_str = ( 75 | text_colors["debug"] 76 | + text_colors["bold"] 77 | + "DEBUG " 78 | + text_colors["end_color"] 79 | ) 80 | print("{} - {} - {}".format(time_stamp, log_str, message)) 81 | 82 | 83 | def double_dash_line(dashes: Optional[int] = 75) -> None: 84 | print(text_colors["error"] + "=" * dashes + text_colors["end_color"]) 85 | 86 | 87 | def singe_dash_line(dashes: Optional[int] = 67) -> None: 88 | print("-" * dashes) 89 | 90 | 91 | def print_header(header: str) -> None: 92 | double_dash_line() 93 | print( 94 | text_colors["info"] 95 | + text_colors["bold"] 96 | + "=" * 50 97 | + str(header) 98 | + text_colors["end_color"] 99 | ) 100 | double_dash_line() 101 | 102 | 103 | def print_header_minor(header: str) -> None: 104 | print( 105 | text_colors["warning"] 106 | + text_colors["bold"] 107 | + "=" * 25 108 | + str(header) 109 | + text_colors["end_color"] 110 | ) 111 | 112 | 113 | def disable_printing(): 114 | sys.stdout = open(os.devnull, "w") 115 | 116 | 117 | def enable_printing(): 118 | sys.stdout = sys.__stdout__ 119 | -------------------------------------------------------------------------------- /profiler.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import torch 7 | from torch import nn, Tensor 8 | from typing import Tuple 9 | 10 | 11 | def module_profile(module, x: Tensor, *args, **kwargs) -> Tuple[Tensor, float, float]: 12 | """ 13 | Helper function to profile a module. 14 | 15 | .. note:: 16 | Module profiling is for reference only and may contain errors as it solely relies on user implementation to 17 | compute theoretical FLOPs 18 | """ 19 | 20 | if isinstance(module, nn.Sequential): 21 | n_macs = n_params = 0.0 22 | for l in module: 23 | try: 24 | x, l_p, l_macs = l.profile_module(x) 25 | n_macs += l_macs 26 | n_params += l_p 27 | except Exception as e: 28 | print(e, l) 29 | pass 30 | else: 31 | x, n_params, n_macs = module.profile_module(x) 32 | return x, n_params, n_macs 33 | -------------------------------------------------------------------------------- /sync_batch_norm.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import torch 7 | from torch import nn, Tensor 8 | from typing import Optional, Tuple 9 | 10 | class SyncBatchNorm(nn.SyncBatchNorm): 11 | """ 12 | Applies a `Syncronized Batch Normalization `_ over the input tensor 13 | 14 | Args: 15 | num_features (Optional, int): :math:`C` from an expected input of size :math:`(N, C, *)` 16 | eps (Optional, float): Value added to the denominator for numerical stability. Default: 1e-5 17 | momentum (Optional, float): Value used for the running_mean and running_var computation. Default: 0.1 18 | affine (bool): If ``True``, use learnable affine parameters. Default: ``True`` 19 | track_running_stats: If ``True``, tracks running mean and variance. Default: ``True`` 20 | 21 | Shape: 22 | - Input: :math:`(N, C, *)` where :math:`N` is the batch size, :math:`C` is the number of input channels, 23 | :math:`*` is the remaining input dimensions 24 | - Output: same shape as the input 25 | 26 | """ 27 | 28 | def __init__( 29 | self, 30 | num_features: int, 31 | eps: Optional[float] = 1e-5, 32 | momentum: Optional[float] = 0.1, 33 | affine: Optional[bool] = True, 34 | track_running_stats: Optional[bool] = True, 35 | *args, 36 | **kwargs 37 | ) -> None: 38 | super().__init__( 39 | num_features=num_features, 40 | eps=eps, 41 | momentum=momentum, 42 | affine=affine, 43 | track_running_stats=track_running_stats, 44 | ) 45 | 46 | def profile_module(self, input: Tensor) -> Tuple[Tensor, float, float]: 47 | # Since normalization layers can be fused, we do not count their operations 48 | params = sum([p.numel() for p in self.parameters()]) 49 | return input, params, 0.0 50 | 51 | 52 | class SyncBatchNormFP32(SyncBatchNorm): 53 | """ 54 | Synchronized BN in FP32 55 | """ 56 | 57 | def __init__( 58 | self, 59 | num_features: int, 60 | eps: Optional[float] = 1e-5, 61 | momentum: Optional[float] = 0.1, 62 | affine: Optional[bool] = True, 63 | track_running_stats: Optional[bool] = True, 64 | *args, 65 | **kwargs 66 | ) -> None: 67 | super().__init__( 68 | num_features=num_features, 69 | eps=eps, 70 | momentum=momentum, 71 | affine=affine, 72 | track_running_stats=track_running_stats, 73 | ) 74 | 75 | def forward(self, x: Tensor, *args, **kwargs) -> Tensor: 76 | in_dtype = x.dtype 77 | return super().forward(x.to(dtype=torch.float)).to(dtype=in_dtype) 78 | 79 | def profile_module(self, input: Tensor) -> Tuple[Tensor, float, float]: 80 | # Since normalization layers can be fused, we do not count their operations 81 | params = sum([p.numel() for p in self.parameters()]) 82 | return input, params, 0.0 83 | --------------------------------------------------------------------------------