├── .gitignore
├── img
│   └── 300_spike_driven_transformer_v2_me.png
├── classification
│   ├── eval.sh
│   ├── train.sh
│   ├── train_finetune_from_t1_to_t4.sh
│   ├── util
│   │   ├── lr_sched.py
│   │   ├── crop.py
│   │   ├── datasets.py
│   │   ├── lars.py
│   │   ├── lr_decay.py
│   │   ├── lr_decay_spikformer.py
│   │   ├── kd_loss.py
│   │   ├── pos_embed.py
│   │   └── misc.py
│   ├── engine_finetune.py
│   ├── models.py
│   ├── main_finetune.py
│   └── metaformer.py
└── README.md

/.gitignore:
--------------------------------------------------------------------------------
1 | *.out
2 | __pycache__/
3 | *M.sh
4 | outputs/
5 | *.pth
--------------------------------------------------------------------------------
/img/300_spike_driven_transformer_v2_me.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BICLab/Spike-Driven-Transformer-V2/HEAD/img/300_spike_driven_transformer_v2_me.png
--------------------------------------------------------------------------------
/classification/eval.sh:
--------------------------------------------------------------------------------
1 | python main_finetune.py --batch_size 50 --model metaspikformer_8_512 --data_path /raid/ligq/imagenet1-k --eval --resume checkpoint-299.pth
--------------------------------------------------------------------------------
/classification/train.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --standalone --nproc_per_node=8 \
2 | main_finetune.py \
3 | --batch_size 128 \
4 | --blr 6e-4 \
5 | --warmup_epochs 10 \
6 | --epochs 200 \
7 | --model metaspikformer_8_512 \
8 | --data_path /raid/ligq/imagenet1-k \
9 | --output_dir outputs/55M \
10 | --log_dir outputs/55M \
11 | --model_mode ms \
12 | --dist_eval
--------------------------------------------------------------------------------
/classification/train_finetune_from_t1_to_t4.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --standalone --nproc_per_node=8 \
2 | main_finetune.py \
3 | --batch_size 24 \
4 | --blr 2e-5 \
5 | --warmup_epochs 5 \
6 | --epochs 50 \
7 | --model metaspikformer_8_512 \
8 | --data_path /raid/ligq/imagenet1-k \
9 | --output_dir outputs/55M_T4 \
10 | --log_dir outputs/55M_T4 \
11 | --model_mode ms \
12 | --dist_eval \
13 | --finetune checkpoint-299.pth \
14 | --time_steps 4 \
15 | --kd \
16 | --teacher_model caformer_b36_in21ft1k \
17 | --distillation_type hard
--------------------------------------------------------------------------------
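A note on the flags above: in `main_finetune.py`, `--blr` is a *base* learning rate; the absolute rate is `base_lr * total_batch_size / 256`, where the total (effective) batch size is `batch_size * accum_iter * num_gpus`. A minimal sketch of that arithmetic for `train.sh` (assuming the default `accum_iter=1`):

```python
# Effective batch size and absolute lr implied by train.sh
# (scaling rule from the --blr help text in main_finetune.py).
batch_size, num_gpus, accum_iter = 128, 8, 1
blr = 6e-4
eff_batch_size = batch_size * accum_iter * num_gpus  # 1024
abs_lr = blr * eff_batch_size / 256                  # 2.4e-3
print(eff_batch_size, abs_lr)
```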
/classification/util/lr_sched.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | 
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | 
7 | import math
8 | 
9 | 
10 | def adjust_learning_rate(optimizer, epoch, args):
11 |     """Decay the learning rate with half-cycle cosine after warmup"""
12 |     if epoch < args.warmup_epochs:
13 |         lr = args.lr * epoch / args.warmup_epochs
14 |     else:
15 |         lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * (
16 |             1.0
17 |             + math.cos(
18 |                 math.pi
19 |                 * (epoch - args.warmup_epochs)
20 |                 / (args.epochs - args.warmup_epochs)
21 |             )
22 |         )
23 |     for param_group in optimizer.param_groups:
24 |         if "lr_scale" in param_group:
25 |             param_group["lr"] = lr * param_group["lr_scale"]
26 |         else:
27 |             param_group["lr"] = lr
28 |     return lr
29 | 
--------------------------------------------------------------------------------
/classification/util/crop.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | 
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | 
7 | import math
8 | 
9 | import torch
10 | 
11 | from torchvision import transforms
12 | from torchvision.transforms import functional as F
13 | 
14 | 
15 | class RandomResizedCrop(transforms.RandomResizedCrop):
16 |     """
17 |     RandomResizedCrop for matching TF/TPU implementation: no for-loop is used.
18 |     This may lead to results different from torchvision's version.
19 |     Following BYOL's TF code:
20 |     https://github.com/deepmind/deepmind-research/blob/master/byol/utils/dataset.py#L206
21 |     """
22 | 
23 |     @staticmethod
24 |     def get_params(img, scale, ratio):
25 |         width, height = F._get_image_size(img)
26 |         area = height * width
27 | 
28 |         target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
29 |         log_ratio = torch.log(torch.tensor(ratio))
30 |         aspect_ratio = torch.exp(
31 |             torch.empty(1).uniform_(log_ratio[0], log_ratio[1])
32 |         ).item()
33 | 
34 |         w = int(round(math.sqrt(target_area * aspect_ratio)))
35 |         h = int(round(math.sqrt(target_area / aspect_ratio)))
36 | 
37 |         w = min(w, width)
38 |         h = min(h, height)
39 | 
40 |         i = torch.randint(0, height - h + 1, size=(1,)).item()
41 |         j = torch.randint(0, width - w + 1, size=(1,)).item()
42 | 
43 |         return i, j, h, w
44 | 
--------------------------------------------------------------------------------
/classification/util/datasets.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | 
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # -------------------------------------------------------- 10 | 11 | import os 12 | import PIL 13 | 14 | from torchvision import datasets, transforms 15 | 16 | from timm.data import create_transform 17 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 18 | 19 | 20 | def build_dataset(is_train, args): 21 | transform = build_transform(is_train, args) 22 | 23 | root = os.path.join(args.data_path, "train" if is_train else "val") 24 | dataset = datasets.ImageFolder(root, transform=transform) 25 | 26 | print(dataset) 27 | 28 | return dataset 29 | 30 | 31 | def build_transform(is_train, args): 32 | mean = IMAGENET_DEFAULT_MEAN 33 | std = IMAGENET_DEFAULT_STD 34 | # train transform 35 | if is_train: 36 | # this should always dispatch to transforms_imagenet_train 37 | transform = create_transform( 38 | input_size=args.input_size, 39 | is_training=True, 40 | color_jitter=args.color_jitter, 41 | auto_augment=args.aa, 42 | interpolation="bicubic", 43 | re_prob=args.reprob, 44 | re_mode=args.remode, 45 | re_count=args.recount, 46 | mean=mean, 47 | std=std, 48 | ) 49 | return transform 50 | 51 | # eval transform 52 | t = [] 53 | if args.input_size <= 224: 54 | crop_pct = 224 / 256 55 | else: 56 | crop_pct = 1.0 57 | size = int(args.input_size / crop_pct) 58 | t.append( 59 | transforms.Resize( 60 | size, interpolation=PIL.Image.BICUBIC 61 | ), # to maintain same ratio w.r.t. 224 images 62 | ) 63 | t.append(transforms.CenterCrop(args.input_size)) 64 | 65 | t.append(transforms.ToTensor()) 66 | t.append(transforms.Normalize(mean, std)) 67 | return transforms.Compose(t) 68 | -------------------------------------------------------------------------------- /classification/util/lars.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # LARS optimizer, implementation from MoCo v3: 8 | # https://github.com/facebookresearch/moco-v3 9 | # -------------------------------------------------------- 10 | 11 | import torch 12 | 13 | 14 | class LARS(torch.optim.Optimizer): 15 | """ 16 | LARS optimizer, no rate scaling or weight decay for parameters <= 1D. 
17 | """ 18 | 19 | def __init__( 20 | self, params, lr=0, weight_decay=0, momentum=0.9, trust_coefficient=0.001 21 | ): 22 | defaults = dict( 23 | lr=lr, 24 | weight_decay=weight_decay, 25 | momentum=momentum, 26 | trust_coefficient=trust_coefficient, 27 | ) 28 | super().__init__(params, defaults) 29 | 30 | @torch.no_grad() 31 | def step(self): 32 | for g in self.param_groups: 33 | for p in g["params"]: 34 | dp = p.grad 35 | 36 | if dp is None: 37 | continue 38 | 39 | if p.ndim > 1: # if not normalization gamma/beta or bias 40 | dp = dp.add(p, alpha=g["weight_decay"]) 41 | param_norm = torch.norm(p) 42 | update_norm = torch.norm(dp) 43 | one = torch.ones_like(param_norm) 44 | q = torch.where( 45 | param_norm > 0.0, 46 | torch.where( 47 | update_norm > 0, 48 | (g["trust_coefficient"] * param_norm / update_norm), 49 | one, 50 | ), 51 | one, 52 | ) 53 | dp = dp.mul(q) 54 | 55 | param_state = self.state[p] 56 | if "mu" not in param_state: 57 | param_state["mu"] = torch.zeros_like(p) 58 | mu = param_state["mu"] 59 | mu.mul_(g["momentum"]).add_(dp) 60 | p.add_(mu, alpha=-g["lr"]) 61 | -------------------------------------------------------------------------------- /classification/util/lr_decay.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # ELECTRA https://github.com/google-research/electra 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | 12 | import json 13 | 14 | 15 | def param_groups_lrd( 16 | model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=0.75 17 | ): 18 | """ 19 | Parameter groups for layer-wise lr decay 20 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58 21 | """ 22 | param_group_names = {} 23 | param_groups = {} 24 | 25 | num_layers = len(model.blocks) + 1 26 | 27 | layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1)) 28 | 29 | for n, p in model.named_parameters(): 30 | if not p.requires_grad: 31 | continue 32 | 33 | # no decay: all 1D parameters and model specific ones 34 | if p.ndim == 1 or n in no_weight_decay_list: 35 | g_decay = "no_decay" 36 | this_decay = 0.0 37 | else: 38 | g_decay = "decay" 39 | this_decay = weight_decay 40 | 41 | layer_id = get_layer_id_for_vit(n, num_layers) 42 | group_name = "layer_%d_%s" % (layer_id, g_decay) 43 | 44 | if group_name not in param_group_names: 45 | this_scale = layer_scales[layer_id] 46 | 47 | param_group_names[group_name] = { 48 | "lr_scale": this_scale, 49 | "weight_decay": this_decay, 50 | "params": [], 51 | } 52 | param_groups[group_name] = { 53 | "lr_scale": this_scale, 54 | "weight_decay": this_decay, 55 | "params": [], 56 | } 57 | 58 | param_group_names[group_name]["params"].append(n) 59 | param_groups[group_name]["params"].append(p) 60 | 61 | # print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2)) 62 | 63 | return list(param_groups.values()) 64 | 65 | 66 | def get_layer_id_for_vit(name, num_layers): 67 | """ 68 | Assign a parameter with its layer id 69 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33 70 | """ 71 | if name in ["cls_token", "pos_embed"]: 72 | return 0 73 | 
elif name.startswith("patch_embed"):
74 |         return 0
75 |     elif name.startswith("blocks"):
76 |         return int(name.split(".")[1]) + 1
77 |     else:
78 |         return num_layers
79 | 
--------------------------------------------------------------------------------
/classification/util/lr_decay_spikformer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | 
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # References:
8 | # ELECTRA https://github.com/google-research/electra
9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit
10 | # --------------------------------------------------------
11 | 
12 | import json
13 | 
14 | 
15 | def param_groups_lrd(
16 |     model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=0.75
17 | ):
18 |     """
19 |     Parameter groups for layer-wise lr decay
20 |     Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58
21 |     """
22 |     param_group_names = {}
23 |     param_groups = {}
24 | 
25 |     num_layers = len(model.block3) + len(model.block4) + 1
26 | 
27 |     layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1))
28 | 
29 |     for n, p in model.named_parameters():
30 |         if not p.requires_grad:  # only consider parameters that are updated via gradients
31 |             continue
32 | 
33 |         # no decay: all 1D parameters and model specific ones
34 |         if p.ndim == 1 or n in no_weight_decay_list:
35 |             g_decay = "no_decay"
36 |             this_decay = 0.0
37 |         else:
38 |             g_decay = "decay"
39 |             this_decay = weight_decay
40 | 
41 |         layer_id = get_layer_id_for_vit(n, num_layers)
42 |         group_name = "layer_%d_%s" % (layer_id, g_decay)
43 | 
44 |         if group_name not in param_group_names:
45 |             this_scale = layer_scales[layer_id]
46 | 
47 |             param_group_names[group_name] = {
48 |                 "lr_scale": this_scale,
49 |                 "weight_decay": this_decay,
50 |                 "params": [],
51 |             }
52 |             param_groups[group_name] = {
53 |                 "lr_scale": this_scale,
54 |                 "weight_decay": this_decay,
55 |                 "params": [],
56 |             }
57 | 
58 |         param_group_names[group_name]["params"].append(n)
59 |         param_groups[group_name]["params"].append(p)
60 | 
61 |     # print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2))
62 | 
63 |     return list(param_groups.values())
64 | 
65 | 
66 | def get_layer_id_for_vit(name, num_layers):
67 |     """
68 |     Assign a parameter with its layer id
69 |     Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33
70 |     """
71 |     if name in ["cls_token", "pos_embed"]:
72 |         return 0
73 |     elif name.startswith("patch_embed"):
74 |         return 0
75 |     elif name.startswith("block"):
76 |         # return int(name.split('.')[1]) + 1
77 |         return num_layers
78 |     else:
79 |         return num_layers
80 | 
--------------------------------------------------------------------------------
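A minimal sketch of how `param_groups_lrd` above is typically consumed (the wiring mirrors what `main_finetune.py` does; `model` is assumed to be a `Spiking_vit_MetaFormer` instance, and the empty `no_weight_decay_list` is an assumption):

```python
import torch
import util.lr_decay_spikformer as lrd

# Build layer-wise-decayed parameter groups and hand them to AdamW.
# Each group carries its own "lr_scale", which util/lr_sched.py multiplies
# into the scheduled learning rate at every iteration.
param_groups = lrd.param_groups_lrd(
    model,                    # assumed: a Spiking_vit_MetaFormer instance
    weight_decay=0.05,
    no_weight_decay_list=[],  # assumption: no model-specific exclusions
    layer_decay=0.75,
)
optimizer = torch.optim.AdamW(param_groups, lr=2.4e-3)
```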
/classification/util/kd_loss.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2015-present, Facebook, Inc.
2 | # All rights reserved.
3 | """
4 | Implements the knowledge distillation loss
5 | """
6 | import torch
7 | from torch.nn import functional as F
8 | 
9 | 
10 | class DistillationLoss(torch.nn.Module):
11 |     """
12 |     This module wraps a standard criterion and adds an extra knowledge distillation loss by
13 |     taking a teacher model prediction and using it as additional supervision.
14 |     """
15 | 
16 |     def __init__(
17 |         self,
18 |         base_criterion: torch.nn.Module,
19 |         teacher_model: torch.nn.Module,
20 |         distillation_type: str,
21 |         alpha: float,
22 |         tau: float,
23 |     ):
24 |         super().__init__()
25 |         self.base_criterion = base_criterion
26 |         self.teacher_model = teacher_model
27 |         assert distillation_type in ["none", "soft", "hard"]
28 |         self.distillation_type = distillation_type
29 |         self.alpha = alpha
30 |         self.tau = tau
31 | 
32 |     def forward(self, inputs, outputs, labels):
33 |         """
34 |         Args:
35 |             inputs: The original inputs that are fed to the teacher model
36 |             outputs: the outputs of the model to be trained. It is expected to be
37 |                 either a Tensor, or a Tuple[Tensor, Tensor], with the original output
38 |                 in the first position and the distillation predictions as the second output
39 |             labels: the labels for the base criterion
40 |         """
41 |         outputs_kd = None
42 |         if not isinstance(outputs, torch.Tensor):
43 |             # assume that the model outputs a tuple of [outputs, outputs_kd]
44 |             outputs, outputs_kd = outputs
45 |         base_loss = self.base_criterion(outputs, labels)
46 |         if self.distillation_type == "none":
47 |             return base_loss
48 | 
49 |         if outputs_kd is None:
50 |             raise ValueError(
51 |                 "When knowledge distillation is enabled, the model is "
52 |                 "expected to return a Tuple[Tensor, Tensor] with the output of the "
53 |                 "class_token and the dist_token"
54 |             )
55 |         # don't backprop through the teacher
56 |         with torch.no_grad():
57 |             teacher_outputs = self.teacher_model(inputs)
58 | 
59 |         if self.distillation_type == "soft":
60 |             T = self.tau
61 |             # taken from https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
62 |             # with slight modifications
63 |             distillation_loss = (
64 |                 F.kl_div(
65 |                     F.log_softmax(outputs_kd / T, dim=1),
66 |                     # We provide the teacher's targets in log probability because we use log_target=True
67 |                     # (as recommended in pytorch https://github.com/pytorch/pytorch/blob/9324181d0ac7b4f7949a574dbc3e8be30abe7041/torch/nn/functional.py#L2719)
68 |                     # but it is possible to give just the probabilities and set log_target=False. In our experiments we tried both.
69 |                     F.log_softmax(teacher_outputs / T, dim=1),
70 |                     reduction="sum",
71 |                     log_target=True,
72 |                 )
73 |                 * (T * T)
74 |                 / outputs_kd.numel()
75 |             )
76 |             # We divide by outputs_kd.numel() to have the legacy PyTorch behavior.
77 |             # But we also experimented with output_kd.size(0)
78 |             # see issue 61 (https://github.com/facebookresearch/deit/issues/61) for more details
79 |         elif self.distillation_type == "hard":
80 |             distillation_loss = F.cross_entropy(
81 |                 outputs_kd, teacher_outputs.argmax(dim=1)
82 |             )
83 | 
84 |         loss = base_loss * (1 - self.alpha) + distillation_loss * self.alpha
85 |         return loss
86 | 
--------------------------------------------------------------------------------
/classification/util/pos_embed.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | 
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # -------------------------------------------------------- 7 | # Position embedding utils 8 | # -------------------------------------------------------- 9 | 10 | import numpy as np 11 | 12 | import torch 13 | 14 | 15 | # -------------------------------------------------------- 16 | # 2D sine-cosine position embedding 17 | # References: 18 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py 19 | # MoCo v3: https://github.com/facebookresearch/moco-v3 20 | # -------------------------------------------------------- 21 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): 22 | """ 23 | grid_size: int of the grid height and width 24 | return: 25 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 26 | """ 27 | grid_h = np.arange(grid_size, dtype=np.float32) 28 | grid_w = np.arange(grid_size, dtype=np.float32) 29 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 30 | grid = np.stack(grid, axis=0) 31 | 32 | grid = grid.reshape([2, 1, grid_size, grid_size]) 33 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 34 | if cls_token: 35 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 36 | return pos_embed 37 | 38 | 39 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 40 | assert embed_dim % 2 == 0 41 | 42 | # use half of dimensions to encode grid_h 43 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 44 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 45 | 46 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 47 | return emb 48 | 49 | 50 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 51 | """ 52 | embed_dim: output dimension for each position 53 | pos: a list of positions to be encoded: size (M,) 54 | out: (M, D) 55 | """ 56 | assert embed_dim % 2 == 0 57 | omega = np.arange(embed_dim // 2, dtype=np.float64) 58 | omega /= embed_dim / 2.0 59 | omega = 1.0 / 10000**omega # (D/2,) 60 | 61 | pos = pos.reshape(-1) # (M,) 62 | out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product 63 | 64 | emb_sin = np.sin(out) # (M, D/2) 65 | emb_cos = np.cos(out) # (M, D/2) 66 | 67 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 68 | return emb 69 | 70 | 71 | # -------------------------------------------------------- 72 | # Interpolate position embeddings for high-resolution 73 | # References: 74 | # DeiT: https://github.com/facebookresearch/deit 75 | # -------------------------------------------------------- 76 | def interpolate_pos_embed(model, checkpoint_model): 77 | if "pos_embed" in checkpoint_model: 78 | pos_embed_checkpoint = checkpoint_model["pos_embed"] 79 | embedding_size = pos_embed_checkpoint.shape[-1] 80 | num_patches = model.patch_embed.num_patches 81 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches 82 | # height (== width) for the checkpoint position embedding 83 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 84 | # height (== width) for the new position embedding 85 | new_size = int(num_patches**0.5) 86 | # class_token and dist_token are kept unchanged 87 | if orig_size != new_size: 88 | print( 89 | "Position interpolate from %dx%d to %dx%d" 90 | % (orig_size, orig_size, new_size, new_size) 91 | ) 92 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 93 | # only the position tokens are interpolated 94 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 95 | 
pos_tokens = pos_tokens.reshape( 96 | -1, orig_size, orig_size, embedding_size 97 | ).permute(0, 3, 1, 2) 98 | pos_tokens = torch.nn.functional.interpolate( 99 | pos_tokens, 100 | size=(new_size, new_size), 101 | mode="bicubic", 102 | align_corners=False, 103 | ) 104 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 105 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 106 | checkpoint_model["pos_embed"] = new_pos_embed 107 | -------------------------------------------------------------------------------- /classification/engine_finetune.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | 12 | import math 13 | import sys 14 | from typing import Iterable, Optional 15 | 16 | import torch 17 | 18 | from timm.data import Mixup 19 | from timm.utils import accuracy 20 | 21 | import util.misc as misc 22 | import util.lr_sched as lr_sched 23 | from spikingjelly.clock_driven import functional 24 | 25 | 26 | def train_one_epoch( 27 | model: torch.nn.Module, 28 | criterion: torch.nn.Module, 29 | data_loader: Iterable, 30 | optimizer: torch.optim.Optimizer, 31 | device: torch.device, 32 | epoch: int, 33 | loss_scaler, 34 | max_norm: float = 0, 35 | mixup_fn: Optional[Mixup] = None, 36 | log_writer=None, 37 | args=None, 38 | ): 39 | model.train(True) 40 | metric_logger = misc.MetricLogger(delimiter=" ") 41 | metric_logger.add_meter("lr", misc.SmoothedValue(window_size=1, fmt="{value:.6f}")) 42 | header = "Epoch: [{}]".format(epoch) 43 | print_freq = 100 44 | 45 | accum_iter = args.accum_iter 46 | 47 | optimizer.zero_grad() 48 | 49 | if log_writer is not None: 50 | print("log_dir: {}".format(log_writer.log_dir)) 51 | 52 | for data_iter_step, (samples, targets) in enumerate( 53 | metric_logger.log_every(data_loader, print_freq, header) 54 | ): 55 | # we use a per iteration (instead of per epoch) lr scheduler 56 | if data_iter_step % accum_iter == 0: 57 | lr_sched.adjust_learning_rate( 58 | optimizer, data_iter_step / len(data_loader) + epoch, args 59 | ) 60 | 61 | samples = samples.to(device, non_blocking=True) 62 | targets = targets.to(device, non_blocking=True) 63 | 64 | if mixup_fn is not None: 65 | samples, targets = mixup_fn(samples, targets) 66 | 67 | with torch.cuda.amp.autocast(): 68 | outputs = model(samples) 69 | if args.kd: 70 | loss = criterion(samples, outputs, targets) 71 | else: 72 | loss = criterion(outputs, targets) 73 | 74 | loss_value = loss.item() 75 | 76 | if not math.isfinite(loss_value): 77 | print("Loss is {}, stopping training".format(loss_value)) 78 | sys.exit(1) 79 | 80 | loss = loss / accum_iter 81 | loss_scaler( 82 | loss, 83 | optimizer, 84 | clip_grad=max_norm, 85 | parameters=model.parameters(), 86 | create_graph=False, 87 | update_grad=(data_iter_step + 1) % accum_iter == 0, 88 | ) 89 | if (data_iter_step + 1) % accum_iter == 0: 90 | optimizer.zero_grad() 91 | 92 | torch.cuda.synchronize() 93 | functional.reset_net(model) 94 | metric_logger.update(loss=loss_value) 95 | min_lr = 10.0 96 | max_lr = 0.0 97 | for group in optimizer.param_groups: 98 
|             min_lr = min(min_lr, group["lr"])
99 |             max_lr = max(max_lr, group["lr"])
100 | 
101 |         metric_logger.update(lr=max_lr)
102 | 
103 |         loss_value_reduce = misc.all_reduce_mean(loss_value)
104 |         if log_writer is not None and (data_iter_step + 1) % accum_iter == 0:
105 |             """We use epoch_1000x as the x-axis in tensorboard.
106 |             This calibrates different curves when batch size changes.
107 |             """
108 |             epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
109 |             log_writer.add_scalar("loss", loss_value_reduce, epoch_1000x)
110 |             log_writer.add_scalar("lr", max_lr, epoch_1000x)
111 | 
112 |     # gather the stats from all processes
113 |     metric_logger.synchronize_between_processes()
114 |     print("Averaged stats:", metric_logger)
115 |     return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
116 | 
117 | 
118 | @torch.no_grad()
119 | def evaluate(data_loader, model, device):
120 |     criterion = torch.nn.CrossEntropyLoss()
121 | 
122 |     metric_logger = misc.MetricLogger(delimiter="  ")
123 |     header = "Test:"
124 | 
125 |     # switch to evaluation mode
126 |     model.eval()
127 | 
128 |     for batch in metric_logger.log_every(data_loader, 100, header):
129 |         images = batch[0]
130 |         target = batch[-1]
131 |         images = images.to(device, non_blocking=True)
132 |         target = target.to(device, non_blocking=True)
133 | 
134 |         # compute output
135 |         with torch.cuda.amp.autocast():
136 |             output = model(images)
137 |             loss = criterion(output, target)
138 | 
139 |         acc1, acc5 = accuracy(output, target, topk=(1, 5))
140 |         functional.reset_net(model)
141 | 
142 |         batch_size = images.shape[0]
143 |         metric_logger.update(loss=loss.item())
144 |         metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
145 |         metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
146 |     # gather the stats from all processes
147 |     metric_logger.synchronize_between_processes()
148 |     print(
149 |         "* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}".format(
150 |             top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss
151 |         )
152 |     )
153 | 
154 |     return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
155 | 
--------------------------------------------------------------------------------
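One SNN-specific detail in the loops above is `functional.reset_net(model)`: spiking neurons such as `MultiStepLIFNode` are stateful and keep their membrane potential across forward calls, so the state must be cleared after every batch. A minimal sketch of the pattern (with `model`, `criterion`, and `data_loader` as placeholders):

```python
from spikingjelly.clock_driven import functional

# Without reset_net, batch 2 would start from batch 1's residual membrane
# potential instead of the resting state, silently coupling samples.
for samples, targets in data_loader:
    loss = criterion(model(samples), targets)
    # ... backward / optimizer step as in train_one_epoch ...
    functional.reset_net(model)  # zero all neuron states before the next batch
```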
/README.md:
--------------------------------------------------------------------------------
1 | # Spike-driven Transformer V2: Meta Spiking Neural Network Architecture Inspiring the Design of Next-generation Neuromorphic Chips ([ICLR2024](https://openreview.net/forum?id=1SIBN5Xyw7))
2 | 
3 | [Man Yao](https://scholar.google.com/citations?user=eE4vvp0AAAAJ), [Jiakui Hu](https://github.com/jkhu29), [Tianxiang Hu](), [Yifan Xu](https://scholar.google.com/citations?hl=zh-CN&user=pbcoTgsAAAAJ), [Zhaokun Zhou](https://scholar.google.com/citations?user=4nz-h1QAAAAJ), [Yonghong Tian](https://scholar.google.com/citations?user=fn6hJx0AAAAJ), [Bo Xu](), [Guoqi Li](https://scholar.google.com/citations?user=qCfE--MAAAAJ&)
4 | 
5 | BICLab, Institute of Automation, Chinese Academy of Sciences
6 | 
7 | ---
8 | 
9 | :rocket: :rocket: :rocket: **News**:
10 | 
11 | - **Jan. 16, 2024**: Accepted as a poster at ICLR 2024.
12 | - **Feb. 15, 2024**: Released the training and inference code for classification tasks.
13 | - **Apr. 19, 2024**: Released the [pre-trained checkpoints and training logs](https://drive.google.com/drive/folders/12JcIRG8BF6JcgPsXIetSS14udtHXeSSx?usp=sharing) of SDT-v2.
14 | 
15 | TODO:
16 | 
17 | - [x] Upload train and test scripts.
18 | - [x] Upload checkpoints.
19 | 
20 | ## Abstract
21 | 
22 | Neuromorphic computing, which exploits Spiking Neural Networks (SNNs) on neuromorphic chips, is a promising energy-efficient alternative to traditional AI. CNN-based SNNs are the current mainstream of neuromorphic computing. By contrast, no neuromorphic chips are designed especially for Transformer-based SNNs, which have just emerged, and their performance is only on par with CNN-based SNNs, offering no distinct advantage. In this work, we propose a general Transformer-based SNN architecture, termed "Meta-SpikeFormer", whose goals are: (1) **Lower power**, supporting the spike-driven paradigm in which there is only sparse addition in the network; (2) **Versatility**, handling various vision tasks; (3) **High performance**, showing overwhelming performance advantages over CNN-based SNNs; (4) **Meta-architecture**, providing inspiration for future next-generation Transformer-based neuromorphic chip designs. Specifically, we extend the [Spike-driven Transformer](https://github.com/BICLab/Spike-Driven-Transformer) into a meta architecture, and explore the impact of structure, spike-driven self-attention, and skip connection on its performance. On ImageNet-1K, Meta-SpikeFormer achieves **80.0% top-1 accuracy** (55M), surpassing the current state-of-the-art (SOTA) SNN baselines (66M) by 3.7%. This is the first directly trained SNN backbone that can simultaneously **support classification, detection, and segmentation**, obtaining SOTA results in SNNs. Finally, we discuss the inspiration that the meta SNN architecture provides for neuromorphic chip design.
23 | 
24 | ![V2](./img/300_spike_driven_transformer_v2_me.png)
25 | 
26 | ## Classification
27 | 
28 | ### Requirements
29 | 
30 | ```python3
31 | pytorch >= 2.0.0
32 | cupy
33 | spikingjelly == 0.0.0.0.12
34 | ```
35 | 
36 | ### Results on ImageNet-1K
37 | 
38 | Pre-trained checkpoints and training logs of the 55M model: [here](https://drive.google.com/drive/folders/12JcIRG8BF6JcgPsXIetSS14udtHXeSSx?usp=sharing).
39 | 
40 | ### Train & Test
41 | 
42 | The hyper-parameters are in `./conf/`.
43 | 
44 | Train:
45 | 
46 | ```shell
47 | torchrun --standalone --nproc_per_node=8 \
48 | main_finetune.py \
49 | --batch_size 128 \
50 | --blr 6e-4 \
51 | --warmup_epochs 10 \
52 | --epochs 200 \
53 | --model metaspikformer_8_512 \
54 | --data_path /your/data/path \
55 | --output_dir outputs/T1 \
56 | --log_dir outputs/T1 \
57 | --model_mode ms \
58 | --dist_eval
59 | ```
60 | 
61 | Finetune:
62 | 
63 | > Please download caformer_b36_in21_ft1k.pth first, following [PoolFormer](https://github.com/sail-sg/poolformer).
64 | 
65 | ```shell
66 | torchrun --standalone --nproc_per_node=8 \
67 | main_finetune.py \
68 | --batch_size 24 \
69 | --blr 2e-5 \
70 | --warmup_epochs 5 \
71 | --epochs 50 \
72 | --model metaspikformer_8_512 \
73 | --data_path /your/data/path \
74 | --output_dir outputs/T4 \
75 | --log_dir outputs/T4 \
76 | --model_mode ms \
77 | --dist_eval \
78 | --finetune /your/ckpt/path \
79 | --time_steps 4 \
80 | --kd \
81 | --teacher_model caformer_b36_in21ft1k \
82 | --distillation_type hard
83 | ```
84 | 
85 | Test:
86 | 
87 | ```shell
88 | python main_finetune.py --batch_size 128 --model metaspikformer_8_512 --data_path /your/data/path --eval --resume /your/ckpt/path
89 | ```
90 | 
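For reference, these flags drive the schedule in `classification/util/lr_sched.py`: linear warmup to the absolute learning rate, then half-cycle cosine decay. A small sketch of the resulting values for the Train command above (assuming `accum_iter=1`, so the absolute lr is `6e-4 * 1024 / 256 = 2.4e-3`, and assuming `min_lr=0`):

```python
import math

# Warmup + half-cycle cosine rule, as in util/lr_sched.py.
lr, min_lr, warmup_epochs, epochs = 2.4e-3, 0.0, 10, 200

def lr_at(epoch):
    if epoch < warmup_epochs:
        return lr * epoch / warmup_epochs
    return min_lr + (lr - min_lr) * 0.5 * (
        1.0 + math.cos(math.pi * (epoch - warmup_epochs) / (epochs - warmup_epochs))
    )

# 0.0 at epoch 0, peak 2.4e-3 at epoch 10, ~1.2e-3 mid-run, ~0 at epoch 200
print([round(lr_at(e), 6) for e in (0, 5, 10, 105, 200)])
```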
91 | ### Data Prepare
92 | 
93 | Prepare ImageNet with the following folder structure; you can extract ImageNet with this [script](https://gist.github.com/BIGBALLON/8a71d225eff18d88e469e6ea9b39cef4).
94 | 
95 | ```shell
96 | │imagenet/
97 | ├──train/
98 | │  ├── n01440764
99 | │  │   ├── n01440764_10026.JPEG
100 | │  │   ├── n01440764_10027.JPEG
101 | │  │   ├── ......
102 | │  ├── ......
103 | ├──val/
104 | │  ├── n01440764
105 | │  │   ├── ILSVRC2012_val_00000293.JPEG
106 | │  │   ├── ILSVRC2012_val_00002138.JPEG
107 | │  │   ├── ......
108 | │  ├── ......
109 | ```
110 | 
111 | ## Contact Information
112 | 
113 | ```
114 | @inproceedings{
115 | yao2024spikedriven,
116 | title={Spike-driven Transformer V2: Meta Spiking Neural Network Architecture Inspiring the Design of Next-generation Neuromorphic Chips},
117 | author={Man Yao and JiaKui Hu and Tianxiang Hu and Yifan Xu and Zhaokun Zhou and Yonghong Tian and Bo XU and Guoqi Li},
118 | booktitle={The Twelfth International Conference on Learning Representations},
119 | year={2024},
120 | url={https://openreview.net/forum?id=1SIBN5Xyw7}
121 | }
122 | ```
123 | 
124 | For help or issues using this repository, please submit a GitHub issue.
125 | 
126 | For other communications related to this repository, please contact `manyao@ia.ac.cn` and `jkhu29@stu.pku.edu.cn`.
127 | 
128 | ## Thanks
129 | 
130 | Our implementation is mainly based on the following codebases. We gratefully thank the authors for their wonderful work.
131 | 
132 | [deit](https://github.com/facebookresearch/deit)
--------------------------------------------------------------------------------
/classification/util/misc.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | 
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # References:
8 | # DeiT: https://github.com/facebookresearch/deit
9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit
10 | # --------------------------------------------------------
11 | 
12 | import builtins
13 | import datetime
14 | import os
15 | import time
16 | from collections import defaultdict, deque
17 | from pathlib import Path
18 | 
19 | import torch
20 | import torch.distributed as dist
21 | from torch import inf
22 | 
23 | 
24 | class SmoothedValue(object):
25 |     """Track a series of values and provide access to smoothed values over a
26 |     window or the global series average.
27 |     """
28 | 
29 |     def __init__(self, window_size=20, fmt=None):
30 |         if fmt is None:
31 |             fmt = "{median:.4f} ({global_avg:.4f})"
32 |         self.deque = deque(maxlen=window_size)
33 |         self.total = 0.0
34 |         self.count = 0
35 |         self.fmt = fmt
36 | 
37 |     def update(self, value, n=1):
38 |         self.deque.append(value)
39 |         self.count += n
40 |         self.total += value * n
41 | 
42 |     def synchronize_between_processes(self):
43 |         """
44 |         Warning: does not synchronize the deque!
45 | """ 46 | if not is_dist_avail_and_initialized(): 47 | return 48 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda") 49 | dist.barrier() 50 | dist.all_reduce(t) 51 | t = t.tolist() 52 | self.count = int(t[0]) 53 | self.total = t[1] 54 | 55 | @property 56 | def median(self): 57 | d = torch.tensor(list(self.deque)) 58 | return d.median().item() 59 | 60 | @property 61 | def avg(self): 62 | d = torch.tensor(list(self.deque), dtype=torch.float32) 63 | return d.mean().item() 64 | 65 | @property 66 | def global_avg(self): 67 | return self.total / self.count 68 | 69 | @property 70 | def max(self): 71 | return max(self.deque) 72 | 73 | @property 74 | def value(self): 75 | return self.deque[-1] 76 | 77 | def __str__(self): 78 | return self.fmt.format( 79 | median=self.median, 80 | avg=self.avg, 81 | global_avg=self.global_avg, 82 | max=self.max, 83 | value=self.value, 84 | ) 85 | 86 | 87 | class MetricLogger(object): 88 | def __init__(self, delimiter="\t"): 89 | self.meters = defaultdict(SmoothedValue) 90 | self.delimiter = delimiter 91 | 92 | def update(self, **kwargs): 93 | for k, v in kwargs.items(): 94 | if v is None: 95 | continue 96 | if isinstance(v, torch.Tensor): 97 | v = v.item() 98 | assert isinstance(v, (float, int)) 99 | self.meters[k].update(v) 100 | 101 | def __getattr__(self, attr): 102 | if attr in self.meters: 103 | return self.meters[attr] 104 | if attr in self.__dict__: 105 | return self.__dict__[attr] 106 | raise AttributeError( 107 | "'{}' object has no attribute '{}'".format(type(self).__name__, attr) 108 | ) 109 | 110 | def __str__(self): 111 | loss_str = [] 112 | for name, meter in self.meters.items(): 113 | loss_str.append("{}: {}".format(name, str(meter))) 114 | return self.delimiter.join(loss_str) 115 | 116 | def synchronize_between_processes(self): 117 | for meter in self.meters.values(): 118 | meter.synchronize_between_processes() 119 | 120 | def add_meter(self, name, meter): 121 | self.meters[name] = meter 122 | 123 | def log_every(self, iterable, print_freq, header=None): 124 | i = 0 125 | if not header: 126 | header = "" 127 | start_time = time.time() 128 | end = time.time() 129 | iter_time = SmoothedValue(fmt="{avg:.4f}") 130 | data_time = SmoothedValue(fmt="{avg:.4f}") 131 | space_fmt = ":" + str(len(str(len(iterable)))) + "d" 132 | log_msg = [ 133 | header, 134 | "[{0" + space_fmt + "}/{1}]", 135 | "eta: {eta}", 136 | "{meters}", 137 | "time: {time}", 138 | "data: {data}", 139 | ] 140 | if torch.cuda.is_available(): 141 | log_msg.append("max mem: {memory:.0f}") 142 | log_msg = self.delimiter.join(log_msg) 143 | MB = 1024.0 * 1024.0 144 | for obj in iterable: 145 | data_time.update(time.time() - end) 146 | yield obj 147 | iter_time.update(time.time() - end) 148 | if i % print_freq == 0 or i == len(iterable) - 1: 149 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 150 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 151 | if torch.cuda.is_available(): 152 | print( 153 | log_msg.format( 154 | i, 155 | len(iterable), 156 | eta=eta_string, 157 | meters=str(self), 158 | time=str(iter_time), 159 | data=str(data_time), 160 | memory=torch.cuda.max_memory_allocated() / MB, 161 | ) 162 | ) 163 | else: 164 | print( 165 | log_msg.format( 166 | i, 167 | len(iterable), 168 | eta=eta_string, 169 | meters=str(self), 170 | time=str(iter_time), 171 | data=str(data_time), 172 | ) 173 | ) 174 | i += 1 175 | end = time.time() 176 | total_time = time.time() - start_time 177 | total_time_str = 
str(datetime.timedelta(seconds=int(total_time))) 178 | print( 179 | "{} Total time: {} ({:.4f} s / it)".format( 180 | header, total_time_str, total_time / len(iterable) 181 | ) 182 | ) 183 | 184 | 185 | def setup_for_distributed(is_master): 186 | """ 187 | This function disables printing when not in master process 188 | """ 189 | builtin_print = builtins.print 190 | 191 | def print(*args, **kwargs): 192 | force = kwargs.pop("force", False) 193 | force = force or (get_world_size() > 8) 194 | if is_master or force: 195 | now = datetime.datetime.now().time() 196 | builtin_print("[{}] ".format(now), end="") # print with time stamp 197 | builtin_print(*args, **kwargs) 198 | 199 | builtins.print = print 200 | 201 | 202 | def is_dist_avail_and_initialized(): 203 | if not dist.is_available(): 204 | return False 205 | if not dist.is_initialized(): 206 | return False 207 | return True 208 | 209 | 210 | def get_world_size(): 211 | if not is_dist_avail_and_initialized(): 212 | return 1 213 | return dist.get_world_size() 214 | 215 | 216 | def get_rank(): 217 | if not is_dist_avail_and_initialized(): 218 | return 0 219 | return dist.get_rank() 220 | 221 | 222 | def is_main_process(): 223 | return get_rank() == 0 224 | 225 | 226 | def save_on_master(*args, **kwargs): 227 | if is_main_process(): 228 | torch.save(*args, **kwargs) 229 | 230 | 231 | def init_distributed_mode(args): 232 | if args.dist_on_itp: 233 | args.rank = int(os.environ["OMPI_COMM_WORLD_RANK"]) 234 | args.world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"]) 235 | args.gpu = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]) 236 | args.dist_url = "tcp://%s:%s" % ( 237 | os.environ["MASTER_ADDR"], 238 | os.environ["MASTER_PORT"], 239 | ) 240 | os.environ["LOCAL_RANK"] = str(args.gpu) 241 | os.environ["RANK"] = str(args.rank) 242 | os.environ["WORLD_SIZE"] = str(args.world_size) 243 | # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"] 244 | elif "RANK" in os.environ and "WORLD_SIZE" in os.environ: 245 | args.rank = int(os.environ["RANK"]) 246 | args.world_size = int(os.environ["WORLD_SIZE"]) 247 | args.gpu = int(os.environ["LOCAL_RANK"]) 248 | elif "SLURM_PROCID" in os.environ: 249 | args.rank = int(os.environ["SLURM_PROCID"]) 250 | args.gpu = args.rank % torch.cuda.device_count() 251 | else: 252 | print("Not using distributed mode") 253 | setup_for_distributed(is_master=True) # hack 254 | args.distributed = False 255 | return 256 | 257 | args.distributed = True 258 | 259 | torch.cuda.set_device(args.gpu) 260 | args.dist_backend = "nccl" 261 | print( 262 | "| distributed init (rank {}): {}, gpu {}".format( 263 | args.rank, args.dist_url, args.gpu 264 | ), 265 | flush=True, 266 | ) 267 | torch.distributed.init_process_group( 268 | backend=args.dist_backend, 269 | init_method=args.dist_url, 270 | world_size=args.world_size, 271 | rank=args.rank, 272 | ) 273 | torch.distributed.barrier() 274 | setup_for_distributed(args.rank == 0) 275 | 276 | 277 | class NativeScalerWithGradNormCount: 278 | state_dict_key = "amp_scaler" 279 | 280 | def __init__(self): 281 | self._scaler = torch.cuda.amp.GradScaler() 282 | 283 | def __call__( 284 | self, 285 | loss, 286 | optimizer, 287 | clip_grad=None, 288 | parameters=None, 289 | create_graph=False, 290 | update_grad=True, 291 | ): 292 | self._scaler.scale(loss).backward(create_graph=create_graph) 293 | if update_grad: 294 | if clip_grad is not None: 295 | assert parameters is not None 296 | self._scaler.unscale_( 297 | optimizer 298 | ) # unscale the gradients of optimizer's assigned params 
in-place 299 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) 300 | else: 301 | self._scaler.unscale_(optimizer) 302 | norm = get_grad_norm_(parameters) 303 | self._scaler.step(optimizer) 304 | self._scaler.update() 305 | else: 306 | norm = None 307 | return norm 308 | 309 | def state_dict(self): 310 | return self._scaler.state_dict() 311 | 312 | def load_state_dict(self, state_dict): 313 | self._scaler.load_state_dict(state_dict) 314 | 315 | 316 | def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: 317 | if isinstance(parameters, torch.Tensor): 318 | parameters = [parameters] 319 | parameters = [p for p in parameters if p.grad is not None] 320 | norm_type = float(norm_type) 321 | if len(parameters) == 0: 322 | return torch.tensor(0.0) 323 | device = parameters[0].grad.device 324 | if norm_type == inf: 325 | total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) 326 | else: 327 | total_norm = torch.norm( 328 | torch.stack( 329 | [torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters] 330 | ), 331 | norm_type, 332 | ) 333 | return total_norm 334 | 335 | 336 | def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler): 337 | output_dir = Path(args.output_dir) 338 | epoch_name = str(epoch) 339 | if loss_scaler is not None: 340 | checkpoint_paths = [output_dir / ("checkpoint-%s.pth" % epoch_name)] 341 | for checkpoint_path in checkpoint_paths: 342 | to_save = { 343 | "model": model_without_ddp.state_dict(), 344 | "optimizer": optimizer.state_dict(), 345 | "epoch": epoch, 346 | "scaler": loss_scaler.state_dict(), 347 | "args": args, 348 | } 349 | 350 | save_on_master(to_save, checkpoint_path) 351 | else: 352 | client_state = {"epoch": epoch} 353 | model.save_checkpoint( 354 | save_dir=args.output_dir, 355 | tag="checkpoint-%s" % epoch_name, 356 | client_state=client_state, 357 | ) 358 | 359 | 360 | def load_model(args, model_without_ddp, optimizer, loss_scaler): 361 | if args.resume: 362 | if args.resume.startswith("https"): 363 | checkpoint = torch.hub.load_state_dict_from_url( 364 | args.resume, map_location="cpu", check_hash=True 365 | ) 366 | else: 367 | checkpoint = torch.load(args.resume, map_location="cpu") 368 | model_without_ddp.load_state_dict(checkpoint["model"]) 369 | print("Resume checkpoint %s" % args.resume) 370 | if ( 371 | "optimizer" in checkpoint 372 | and "epoch" in checkpoint 373 | and not (hasattr(args, "eval") and args.eval) 374 | ): 375 | optimizer.load_state_dict(checkpoint["optimizer"]) 376 | args.start_epoch = checkpoint["epoch"] + 1 377 | if "scaler" in checkpoint: 378 | loss_scaler.load_state_dict(checkpoint["scaler"]) 379 | print("With optim & sched!") 380 | 381 | 382 | def all_reduce_mean(x): 383 | world_size = get_world_size() 384 | if world_size > 1: 385 | x_reduce = torch.tensor(x).cuda() 386 | dist.all_reduce(x_reduce) 387 | x_reduce /= world_size 388 | return x_reduce.item() 389 | else: 390 | return x 391 | 392 | 393 | def accuracy(output, target, topk=(1,)): 394 | """Computes the accuracy over the k top predictions for the specified values of k""" 395 | maxk = min(max(topk), output.size()[1]) 396 | batch_size = target.size(0) 397 | _, pred = output.topk(maxk, 1, True, True) 398 | pred = pred.t() 399 | correct = pred.eq(target.reshape(1, -1).expand_as(pred)) 400 | return [ 401 | correct[: min(k, maxk)].reshape(-1).float().sum(0) * 100.0 / batch_size 402 | for k in topk 403 | ] 404 | -------------------------------------------------------------------------------- 
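Before moving on to `models.py`, a brief sketch of how the `NativeScalerWithGradNormCount` defined above is meant to be driven; this mirrors the call in `engine_finetune.py` (here `model`, `optimizer`, `criterion`, `samples`, and `targets` are placeholders):

```python
import torch

scaler = NativeScalerWithGradNormCount()  # wraps torch.cuda.amp.GradScaler

with torch.cuda.amp.autocast():
    loss = criterion(model(samples), targets)

# Scales the loss, backprops, optionally clips, and steps the optimizer.
grad_norm = scaler(
    loss,
    optimizer,
    clip_grad=None,               # or a max-norm float to enable clipping
    parameters=model.parameters(),
    update_grad=True,             # set False on intermediate accumulation steps
)
optimizer.zero_grad()
```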
/classification/models.py: -------------------------------------------------------------------------------- 1 | # from visualizer import get_local 2 | import torch 3 | import torchinfo 4 | import torch.nn as nn 5 | from spikingjelly.clock_driven.neuron import ( 6 | MultiStepParametricLIFNode, 7 | MultiStepLIFNode, 8 | ) 9 | from spikingjelly.clock_driven import layer 10 | from timm.models.layers import to_2tuple, trunc_normal_, DropPath 11 | from timm.models.registry import register_model 12 | from timm.models.vision_transformer import _cfg 13 | from einops.layers.torch import Rearrange 14 | import torch.nn.functional as F 15 | from functools import partial 16 | 17 | 18 | class BNAndPadLayer(nn.Module): 19 | def __init__( 20 | self, 21 | pad_pixels, 22 | num_features, 23 | eps=1e-5, 24 | momentum=0.1, 25 | affine=True, 26 | track_running_stats=True, 27 | ): 28 | super(BNAndPadLayer, self).__init__() 29 | self.bn = nn.BatchNorm2d( 30 | num_features, eps, momentum, affine, track_running_stats 31 | ) 32 | self.pad_pixels = pad_pixels 33 | 34 | def forward(self, input): 35 | output = self.bn(input) 36 | if self.pad_pixels > 0: 37 | if self.bn.affine: 38 | pad_values = ( 39 | self.bn.bias.detach() 40 | - self.bn.running_mean 41 | * self.bn.weight.detach() 42 | / torch.sqrt(self.bn.running_var + self.bn.eps) 43 | ) 44 | else: 45 | pad_values = -self.bn.running_mean / torch.sqrt( 46 | self.bn.running_var + self.bn.eps 47 | ) 48 | output = F.pad(output, [self.pad_pixels] * 4) 49 | pad_values = pad_values.view(1, -1, 1, 1) 50 | output[:, :, 0 : self.pad_pixels, :] = pad_values 51 | output[:, :, -self.pad_pixels :, :] = pad_values 52 | output[:, :, :, 0 : self.pad_pixels] = pad_values 53 | output[:, :, :, -self.pad_pixels :] = pad_values 54 | return output 55 | 56 | @property 57 | def weight(self): 58 | return self.bn.weight 59 | 60 | @property 61 | def bias(self): 62 | return self.bn.bias 63 | 64 | @property 65 | def running_mean(self): 66 | return self.bn.running_mean 67 | 68 | @property 69 | def running_var(self): 70 | return self.bn.running_var 71 | 72 | @property 73 | def eps(self): 74 | return self.bn.eps 75 | 76 | 77 | class RepConv(nn.Module): 78 | def __init__( 79 | self, 80 | in_channel, 81 | out_channel, 82 | bias=False, 83 | ): 84 | super().__init__() 85 | # hidden_channel = in_channel 86 | conv1x1 = nn.Conv2d(in_channel, in_channel, 1, 1, 0, bias=False, groups=1) 87 | bn = BNAndPadLayer(pad_pixels=1, num_features=in_channel) 88 | conv3x3 = nn.Sequential( 89 | nn.Conv2d(in_channel, in_channel, 3, 1, 0, groups=in_channel, bias=False), 90 | nn.Conv2d(in_channel, out_channel, 1, 1, 0, groups=1, bias=False), 91 | nn.BatchNorm2d(out_channel), 92 | ) 93 | 94 | self.body = nn.Sequential(conv1x1, bn, conv3x3) 95 | 96 | def forward(self, x): 97 | return self.body(x) 98 | 99 | 100 | class SepConv(nn.Module): 101 | r""" 102 | Inverted separable convolution from MobileNetV2: https://arxiv.org/abs/1801.04381. 
103 |     """
104 | 
105 |     def __init__(
106 |         self,
107 |         dim,
108 |         expansion_ratio=2,
109 |         act2_layer=nn.Identity,
110 |         bias=False,
111 |         kernel_size=7,
112 |         padding=3,
113 |     ):
114 |         super().__init__()
115 |         med_channels = int(expansion_ratio * dim)
116 |         self.lif1 = MultiStepLIFNode(tau=2.0, detach_reset=True, backend="cupy")
117 |         self.pwconv1 = nn.Conv2d(dim, med_channels, kernel_size=1, stride=1, bias=bias)
118 |         self.bn1 = nn.BatchNorm2d(med_channels)
119 |         self.lif2 = MultiStepLIFNode(tau=2.0, detach_reset=True, backend="cupy")
120 |         self.dwconv = nn.Conv2d(
121 |             med_channels,
122 |             med_channels,
123 |             kernel_size=kernel_size,
124 |             padding=padding,
125 |             groups=med_channels,
126 |             bias=bias,
127 |         )  # depthwise conv
128 |         self.pwconv2 = nn.Conv2d(med_channels, dim, kernel_size=1, stride=1, bias=bias)
129 |         self.bn2 = nn.BatchNorm2d(dim)
130 | 
131 |     def forward(self, x):
132 |         T, B, C, H, W = x.shape
133 |         x = self.lif1(x)
134 |         x = self.bn1(self.pwconv1(x.flatten(0, 1))).reshape(T, B, -1, H, W)
135 |         x = self.lif2(x)
136 |         x = self.dwconv(x.flatten(0, 1))
137 |         x = self.bn2(self.pwconv2(x)).reshape(T, B, -1, H, W)
138 |         return x
139 | 
140 | 
141 | class MS_ConvBlock(nn.Module):
142 |     def __init__(
143 |         self,
144 |         dim,
145 |         mlp_ratio=4.0,
146 |     ):
147 |         super().__init__()
148 | 
149 |         self.Conv = SepConv(dim=dim)
150 |         # self.Conv = MHMC(dim=dim)
151 | 
152 |         self.lif1 = MultiStepLIFNode(tau=2.0, detach_reset=True, backend="cupy")
153 |         self.conv1 = nn.Conv2d(
154 |             dim, dim * mlp_ratio, kernel_size=3, padding=1, groups=1, bias=False
155 |         )
156 |         # self.conv1 = RepConv(dim, dim*mlp_ratio)
157 |         self.bn1 = nn.BatchNorm2d(dim * mlp_ratio)  # this could be improved
158 |         self.lif2 = MultiStepLIFNode(tau=2.0, detach_reset=True, backend="cupy")
159 |         self.conv2 = nn.Conv2d(
160 |             dim * mlp_ratio, dim, kernel_size=3, padding=1, groups=1, bias=False
161 |         )
162 |         # self.conv2 = RepConv(dim*mlp_ratio, dim)
163 |         self.bn2 = nn.BatchNorm2d(dim)  # this could be improved
164 | 
165 |     def forward(self, x):
166 |         T, B, C, H, W = x.shape
167 | 
168 |         x = self.Conv(x) + x
169 |         x_feat = x
170 |         x = self.bn1(self.conv1(self.lif1(x).flatten(0, 1))).reshape(T, B, 4 * C, H, W)
171 |         x = self.bn2(self.conv2(self.lif2(x).flatten(0, 1))).reshape(T, B, C, H, W)
172 |         x = x_feat + x
173 | 
174 |         return x
175 | 
176 | 
177 | class MS_MLP(nn.Module):
178 |     def __init__(
179 |         self, in_features, hidden_features=None, out_features=None, drop=0.0, layer=0
180 |     ):
181 |         super().__init__()
182 |         out_features = out_features or in_features
183 |         hidden_features = hidden_features or in_features
184 |         # self.fc1 = linear_unit(in_features, hidden_features)
185 |         self.fc1_conv = nn.Conv1d(in_features, hidden_features, kernel_size=1, stride=1)
186 |         self.fc1_bn = nn.BatchNorm1d(hidden_features)
187 |         self.fc1_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend="cupy")
188 | 
189 |         # self.fc2 = linear_unit(hidden_features, out_features)
190 |         self.fc2_conv = nn.Conv1d(
191 |             hidden_features, out_features, kernel_size=1, stride=1
192 |         )
193 |         self.fc2_bn = nn.BatchNorm1d(out_features)
194 |         self.fc2_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend="cupy")
195 |         # self.drop = nn.Dropout(0.1)
196 | 
197 |         self.c_hidden = hidden_features
198 |         self.c_output = out_features
199 | 
200 |     def forward(self, x):
201 |         T, B, C, H, W = x.shape
202 |         N = H * W
203 |         x = x.flatten(3)
204 |         x = self.fc1_lif(x)
205 |         x = self.fc1_conv(x.flatten(0, 1))
206 |         x = self.fc1_bn(x).reshape(T, B, self.c_hidden, N).contiguous()
207 | 
208 |         x = self.fc2_lif(x)
209 | x = self.fc2_conv(x.flatten(0, 1)) 210 | x = self.fc2_bn(x).reshape(T, B, C, H, W).contiguous() 211 | 212 | return x 213 | 214 | 215 | class MS_Attention_RepConv_qkv_id(nn.Module): 216 | def __init__( 217 | self, 218 | dim, 219 | num_heads=8, 220 | qkv_bias=False, 221 | qk_scale=None, 222 | attn_drop=0.0, 223 | proj_drop=0.0, 224 | sr_ratio=1, 225 | ): 226 | super().__init__() 227 | assert ( 228 | dim % num_heads == 0 229 | ), f"dim {dim} should be divided by num_heads {num_heads}." 230 | self.dim = dim 231 | self.num_heads = num_heads 232 | self.scale = 0.125 233 | 234 | self.head_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend="cupy") 235 | 236 | self.q_conv = nn.Sequential(RepConv(dim, dim, bias=False), nn.BatchNorm2d(dim)) 237 | 238 | self.k_conv = nn.Sequential(RepConv(dim, dim, bias=False), nn.BatchNorm2d(dim)) 239 | 240 | self.v_conv = nn.Sequential(RepConv(dim, dim, bias=False), nn.BatchNorm2d(dim)) 241 | 242 | self.q_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend="cupy") 243 | 244 | self.k_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend="cupy") 245 | 246 | self.v_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend="cupy") 247 | 248 | self.attn_lif = MultiStepLIFNode( 249 | tau=2.0, v_threshold=0.5, detach_reset=True, backend="cupy" 250 | ) 251 | 252 | self.proj_conv = nn.Sequential( 253 | RepConv(dim, dim, bias=False), nn.BatchNorm2d(dim) 254 | ) 255 | 256 | def forward(self, x): 257 | T, B, C, H, W = x.shape 258 | N = H * W 259 | 260 | x = self.head_lif(x) 261 | 262 | q = self.q_conv(x.flatten(0, 1)).reshape(T, B, C, H, W) 263 | k = self.k_conv(x.flatten(0, 1)).reshape(T, B, C, H, W) 264 | v = self.v_conv(x.flatten(0, 1)).reshape(T, B, C, H, W) 265 | 266 | q = self.q_lif(q).flatten(3) 267 | q = ( 268 | q.transpose(-1, -2) 269 | .reshape(T, B, N, self.num_heads, C // self.num_heads) 270 | .permute(0, 1, 3, 2, 4) 271 | .contiguous() 272 | ) 273 | 274 | k = self.k_lif(k).flatten(3) 275 | k = ( 276 | k.transpose(-1, -2) 277 | .reshape(T, B, N, self.num_heads, C // self.num_heads) 278 | .permute(0, 1, 3, 2, 4) 279 | .contiguous() 280 | ) 281 | 282 | v = self.v_lif(v).flatten(3) 283 | v = ( 284 | v.transpose(-1, -2) 285 | .reshape(T, B, N, self.num_heads, C // self.num_heads) 286 | .permute(0, 1, 3, 2, 4) 287 | .contiguous() 288 | ) 289 | 290 | x = k.transpose(-2, -1) @ v 291 | x = (q @ x) * self.scale 292 | 293 | x = x.transpose(3, 4).reshape(T, B, C, N).contiguous() 294 | x = self.attn_lif(x).reshape(T, B, C, H, W) 295 | x = x.reshape(T, B, C, H, W) 296 | x = x.flatten(0, 1) 297 | x = self.proj_conv(x).reshape(T, B, C, H, W) 298 | 299 | return x 300 | 301 | 302 | class MS_Block(nn.Module): 303 | def __init__( 304 | self, 305 | dim, 306 | num_heads, 307 | mlp_ratio=4.0, 308 | qkv_bias=False, 309 | qk_scale=None, 310 | drop=0.0, 311 | attn_drop=0.0, 312 | drop_path=0.0, 313 | norm_layer=nn.LayerNorm, 314 | sr_ratio=1, 315 | ): 316 | super().__init__() 317 | 318 | self.attn = MS_Attention_RepConv_qkv_id( 319 | dim, 320 | num_heads=num_heads, 321 | qkv_bias=qkv_bias, 322 | qk_scale=qk_scale, 323 | attn_drop=attn_drop, 324 | proj_drop=drop, 325 | sr_ratio=sr_ratio, 326 | ) 327 | 328 | self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 329 | mlp_hidden_dim = int(dim * mlp_ratio) 330 | self.mlp = MS_MLP(in_features=dim, hidden_features=mlp_hidden_dim, drop=drop) 331 | 332 | def forward(self, x): 333 | x = x + self.attn(x) 334 | x = x + self.mlp(x) 335 | 336 | return x 337 | 338 | 339 | class MS_DownSampling(nn.Module): 
340 |     def __init__(
341 |         self,
342 |         in_channels=2,
343 |         embed_dims=256,
344 |         kernel_size=3,
345 |         stride=2,
346 |         padding=1,
347 |         first_layer=True,
348 |     ):
349 |         super().__init__()
350 | 
351 |         self.encode_conv = nn.Conv2d(
352 |             in_channels,
353 |             embed_dims,
354 |             kernel_size=kernel_size,
355 |             stride=stride,
356 |             padding=padding,
357 |         )
358 | 
359 |         self.encode_bn = nn.BatchNorm2d(embed_dims)
360 |         if not first_layer:
361 |             self.encode_lif = MultiStepLIFNode(
362 |                 tau=2.0, detach_reset=True, backend="cupy"
363 |             )
364 | 
365 |     def forward(self, x):
366 |         T, B, _, _, _ = x.shape
367 | 
368 |         if hasattr(self, "encode_lif"):
369 |             x = self.encode_lif(x)
370 |         x = self.encode_conv(x.flatten(0, 1))
371 |         _, _, H, W = x.shape
372 |         x = self.encode_bn(x).reshape(T, B, -1, H, W).contiguous()
373 | 
374 |         return x
375 | 
376 | 
377 | class Spiking_vit_MetaFormer(nn.Module):
378 |     def __init__(
379 |         self,
380 |         img_size_h=128,
381 |         img_size_w=128,
382 |         patch_size=16,
383 |         in_channels=2,
384 |         num_classes=11,
385 |         embed_dim=[64, 128, 256, 512],  # four stages; embed_dim[3] is indexed below
386 |         num_heads=8,  # this and the defaults below are scalars, matching the factory functions
387 |         mlp_ratios=4,
388 |         qkv_bias=False,
389 |         qk_scale=None,
390 |         drop_rate=0.0,
391 |         attn_drop_rate=0.0,
392 |         drop_path_rate=0.0,
393 |         norm_layer=nn.LayerNorm,
394 |         depths=8,
395 |         sr_ratios=1,
396 |         kd=False,
397 |     ):
398 |         super().__init__()
399 |         self.num_classes = num_classes
400 |         self.depths = depths
401 |         self.T = 1
402 | 
403 | 
404 |         dpr = [
405 |             x.item() for x in torch.linspace(0, drop_path_rate, depths)
406 |         ]  # stochastic depth decay rule
407 | 
408 |         self.downsample1_1 = MS_DownSampling(
409 |             in_channels=in_channels,
410 |             embed_dims=embed_dim[0] // 2,
411 |             kernel_size=7,
412 |             stride=2,
413 |             padding=3,
414 |             first_layer=True,
415 |         )
416 | 
417 |         self.ConvBlock1_1 = nn.ModuleList(
418 |             [MS_ConvBlock(dim=embed_dim[0] // 2, mlp_ratio=mlp_ratios)]
419 |         )
420 | 
421 |         self.downsample1_2 = MS_DownSampling(
422 |             in_channels=embed_dim[0] // 2,
423 |             embed_dims=embed_dim[0],
424 |             kernel_size=3,
425 |             stride=2,
426 |             padding=1,
427 |             first_layer=False,
428 |         )
429 | 
430 |         self.ConvBlock1_2 = nn.ModuleList(
431 |             [MS_ConvBlock(dim=embed_dim[0], mlp_ratio=mlp_ratios)]
432 |         )
433 | 
434 |         self.downsample2 = MS_DownSampling(
435 |             in_channels=embed_dim[0],
436 |             embed_dims=embed_dim[1],
437 |             kernel_size=3,
438 |             stride=2,
439 |             padding=1,
440 |             first_layer=False,
441 |         )
442 | 
443 |         self.ConvBlock2_1 = nn.ModuleList(
444 |             [MS_ConvBlock(dim=embed_dim[1], mlp_ratio=mlp_ratios)]
445 |         )
446 | 
447 |         self.ConvBlock2_2 = nn.ModuleList(
448 |             [MS_ConvBlock(dim=embed_dim[1], mlp_ratio=mlp_ratios)]
449 |         )
450 | 
451 |         self.downsample3 = MS_DownSampling(
452 |             in_channels=embed_dim[1],
453 |             embed_dims=embed_dim[2],
454 |             kernel_size=3,
455 |             stride=2,
456 |             padding=1,
457 |             first_layer=False,
458 |         )
459 | 
460 |         self.block3 = nn.ModuleList(
461 |             [
462 |                 MS_Block(
463 |                     dim=embed_dim[2],
464 |                     num_heads=num_heads,
465 |                     mlp_ratio=mlp_ratios,
466 |                     qkv_bias=qkv_bias,
467 |                     qk_scale=qk_scale,
468 |                     drop=drop_rate,
469 |                     attn_drop=attn_drop_rate,
470 |                     drop_path=dpr[j],
471 |                     norm_layer=norm_layer,
472 |                     sr_ratio=sr_ratios,
473 |                 )
474 |                 for j in range(6)
475 |             ]
476 |         )
477 | 
478 |         self.downsample4 = MS_DownSampling(
479 |             in_channels=embed_dim[2],
480 |             embed_dims=embed_dim[3],
481 |             kernel_size=3,
482 |             stride=1,
483 |             padding=1,
484 |             first_layer=False,
485 |         )
486 | 
487 |         self.block4 = nn.ModuleList(
488 |             [
489 |                 MS_Block(
490 |                     dim=embed_dim[3],
491 |                     num_heads=num_heads,
492 |                     mlp_ratio=mlp_ratios,
493 |                     qkv_bias=qkv_bias,
494 |                     qk_scale=qk_scale,
495 |                     drop=drop_rate,
496 |                     attn_drop=attn_drop_rate,
497 |                     drop_path=dpr[j],
498 |                     norm_layer=norm_layer,
499 |                     sr_ratio=sr_ratios,
500 |                 )
501 |                 for j in range(2)
502 |             ]
503 |         )
504 | 
505 |         self.lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend="cupy")
506 |         self.head = (
507 |             nn.Linear(embed_dim[3], num_classes) if num_classes > 0 else nn.Identity()
508 |         )
509 | 
510 |         self.kd = kd
511 |         if self.kd:
512 |             self.head_kd = (
513 |                 nn.Linear(embed_dim[3], num_classes)
514 |                 if num_classes > 0
515 |                 else nn.Identity()
516 |             )
517 |         self.apply(self._init_weights)
518 | 
519 |     def _init_weights(self, m):
520 |         if isinstance(m, nn.Linear):
521 |             trunc_normal_(m.weight, std=0.02)
522 |             if m.bias is not None:
523 |                 nn.init.constant_(m.bias, 0)
524 |         elif isinstance(m, nn.LayerNorm):
525 |             nn.init.constant_(m.bias, 0)
526 |             nn.init.constant_(m.weight, 1.0)
527 | 
528 |     def forward_features(self, x):
529 |         x = self.downsample1_1(x)
530 |         for blk in self.ConvBlock1_1:
531 |             x = blk(x)
532 |         x = self.downsample1_2(x)
533 |         for blk in self.ConvBlock1_2:
534 |             x = blk(x)
535 | 
536 |         x = self.downsample2(x)
537 |         for blk in self.ConvBlock2_1:
538 |             x = blk(x)
539 |         for blk in self.ConvBlock2_2:
540 |             x = blk(x)
541 | 
542 |         x = self.downsample3(x)
543 |         for blk in self.block3:
544 |             x = blk(x)
545 | 
546 |         x = self.downsample4(x)
547 |         for blk in self.block4:
548 |             x = blk(x)
549 |         return x  # T, B, C, H, W
550 | 
551 |     def forward(self, x):
552 |         x = (x.unsqueeze(0)).repeat(self.T, 1, 1, 1, 1)
553 |         x = self.forward_features(x)
554 |         x = x.flatten(3).mean(3)
555 |         x_lif = self.lif(x)
556 |         x = self.head(x_lif).mean(0)
557 |         if self.kd:
558 |             x_kd = self.head_kd(x_lif).mean(0)
559 |             if self.training:
560 |                 return x, x_kd
561 |             else:
562 |                 return (x + x_kd) / 2
563 |         return x
564 | 
565 | 
566 | def metaspikformer_8_384(**kwargs):
567 |     model = Spiking_vit_MetaFormer(
568 |         img_size_h=224,
569 |         img_size_w=224,
570 |         patch_size=16,
571 |         embed_dim=[96, 192, 384, 480],
572 |         num_heads=8,
573 |         mlp_ratios=4,
574 |         in_channels=3,
575 |         num_classes=1000,
576 |         qkv_bias=False,
577 |         norm_layer=partial(nn.LayerNorm, eps=1e-6),
578 |         depths=8,
579 |         sr_ratios=1,
580 |         **kwargs,
581 |     )
582 |     return model
583 | 
584 | 
585 | def metaspikformer_8_512(**kwargs):
586 |     model = Spiking_vit_MetaFormer(
587 |         img_size_h=224,
588 |         img_size_w=224,
589 |         patch_size=16,
590 |         embed_dim=[128, 256, 512, 640],
591 |         num_heads=8,
592 |         mlp_ratios=4,
593 |         in_channels=3,
594 |         num_classes=1000,
595 |         qkv_bias=False,
596 |         norm_layer=partial(nn.LayerNorm, eps=1e-6),
597 |         depths=8,
598 |         sr_ratios=1,
599 |         **kwargs,
600 |     )
601 |     return model
602 | 
603 | 
604 | def metaspikformer_8_768(**kwargs):
605 |     model = Spiking_vit_MetaFormer(
606 |         img_size_h=224,
607 |         img_size_w=224,
608 |         patch_size=16,
609 |         embed_dim=[192, 384, 768, 960],
610 |         num_heads=8,
611 |         mlp_ratios=4,
612 |         in_channels=3,
613 |         num_classes=1000,
614 |         qkv_bias=False,
615 |         norm_layer=partial(nn.LayerNorm, eps=1e-6),
616 |         depths=8,
617 |         sr_ratios=1,
618 |         **kwargs,
619 |     )
620 |     return model
621 | 
622 | 
623 | from timm.models import create_model
624 | 
625 | 
--------------------------------------------------------------------------------
/classification/main_finetune.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | 
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # References:
8 | # DeiT: https://github.com/facebookresearch/deit
9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit
10 | # --------------------------------------------------------
11 | 
12 | import argparse
13 | import datetime
14 | import json
15 | import numpy as np
16 | import os
17 | import time
18 | from pathlib import Path
19 | import importlib
20 | 
21 | import torch
22 | 
23 | # import torchinfo
24 | import torch.backends.cudnn as cudnn
25 | from torch.utils.tensorboard import SummaryWriter
26 | 
27 | import timm
28 | 
29 | # assert timm.__version__ == "0.5.4"  # version check
30 | from timm.models.layers import trunc_normal_
31 | import timm.optim.optim_factory as optim_factory
32 | from timm.data.mixup import Mixup
33 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
34 | 
35 | import util.lr_decay_spikformer as lrd
36 | import util.misc as misc
37 | from util.datasets import build_dataset
38 | from util.misc import NativeScalerWithGradNormCount as NativeScaler
39 | from util.kd_loss import DistillationLoss
40 | 
41 | import models
42 | import models_sew  # note: models_sew.py is not part of this snapshot; it is needed only for --model_mode sew
43 | 
44 | from engine_finetune import train_one_epoch, evaluate
45 | from timm.data import create_loader
46 | 
47 | 
48 | def get_args_parser():
49 |     # important params
50 |     parser = argparse.ArgumentParser(
51 |         "MAE fine-tuning for image classification", add_help=False
52 |     )
53 |     parser.add_argument(
54 |         "--batch_size",
55 |         default=64,
56 |         type=int,
57 |         help="Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus)",
58 |     )
59 |     parser.add_argument("--epochs", default=200, type=int)  # 20-30 for T=4 fine-tuning
60 |     parser.add_argument(
61 |         "--accum_iter",
62 |         default=1,
63 |         type=int,
64 |         help="Accumulate gradient iterations (for increasing the effective batch size under memory constraints)",
65 |     )
66 |     parser.add_argument("--finetune", default="", help="finetune from checkpoint")
67 |     parser.add_argument(
68 |         "--data_path", default="/raid/ligq/imagenet1-k/", type=str, help="dataset path"
69 |     )
70 | 
71 |     # Model parameters
72 |     parser.add_argument(
73 |         "--model",
74 |         default="spikformer_8_384_CAFormer",
75 |         type=str,
76 |         metavar="MODEL",
77 |         help="Name of model to train",
78 |     )
79 |     parser.add_argument(
80 |         "--model_mode",
81 |         default="ms",
82 |         type=str,
83 |         help="Mode of model to train",
84 |     )
85 | 
86 |     parser.add_argument("--input_size", default=224, type=int, help="images input size")
87 | 
88 |     parser.add_argument(
89 |         "--drop_path",
90 |         type=float,
91 |         default=0.1,
92 |         metavar="PCT",
93 |         help="Drop path rate (default: 0.1)",
94 |     )
95 | 
96 |     # Optimizer parameters
97 |     parser.add_argument(
98 |         "--clip_grad",
99 |         type=float,
100 |         default=None,
101 |         metavar="NORM",
102 |         help="Clip gradient norm (default: None, no clipping)",
103 |     )
104 |     parser.add_argument(
105 |         "--weight_decay", type=float, default=0.05, help="weight decay (default: 0.05)"
106 |     )
107 | 
108 |     parser.add_argument(
109 |         "--lr",
110 |         type=float,
111 |         default=None,
112 |         metavar="LR",
113 |         help="learning rate (absolute lr)",
114 |     )
115 |     parser.add_argument(
116 |         "--blr",
117 |         type=float,
118 |         default=6e-4,
119 |         metavar="LR",  # 1e-5 to 2e-5 for T=4 fine-tuning
120 |         help="base learning rate: absolute_lr = base_lr * total_batch_size / 256",
121 |     )
122 |     parser.add_argument(
123 |         "--layer_decay",
124 |         type=float,
125 |         default=1.0,
126 |         help="layer-wise lr decay from ELECTRA/BEiT",
127 |     )
128 | 
129 |     parser.add_argument(
130 |         "--min_lr",
131 |         type=float,
132 |         default=1e-6,
133 |         metavar="LR",
134 |         help="lower lr bound for cyclic schedulers that hit 0",
135 |     )
136 | 
137 |     parser.add_argument(
138 |         "--warmup_epochs", type=int, default=10, metavar="N", help="epochs to warmup LR"
139 |     )
140 | 
141 |     # Augmentation parameters
142 |     parser.add_argument(
143 |         "--color_jitter",
144 |         type=float,
145 |         default=None,
146 |         metavar="PCT",
147 |         help="Color jitter factor (enabled only when not using Auto/RandAug)",
148 |     )
149 |     parser.add_argument(
150 |         "--aa",
151 |         type=str,
152 |         default="rand-m9-mstd0.5-inc1",
153 |         metavar="NAME",
154 |         help='Use AutoAugment policy, e.g. "v0" or "original" (default: rand-m9-mstd0.5-inc1)',
155 |     )
156 |     parser.add_argument(
157 |         "--smoothing", type=float, default=0.1, help="Label smoothing (default: 0.1)"
158 |     )
159 | 
160 |     # * Random Erase params
161 |     parser.add_argument(
162 |         "--reprob",
163 |         type=float,
164 |         default=0.25,
165 |         metavar="PCT",
166 |         help="Random erase prob (default: 0.25)",
167 |     )
168 |     parser.add_argument(
169 |         "--remode",
170 |         type=str,
171 |         default="pixel",
172 |         help='Random erase mode (default: "pixel")',
173 |     )
174 |     parser.add_argument(
175 |         "--recount", type=int, default=1, help="Random erase count (default: 1)"
176 |     )
177 |     parser.add_argument(
178 |         "--resplit",
179 |         action="store_true",
180 |         default=False,
181 |         help="Do not random erase first (clean) augmentation split",
182 |     )
183 | 
184 |     # * Mixup params
185 |     parser.add_argument(
186 |         "--mixup", type=float, default=0, help="mixup alpha, mixup enabled if > 0."
187 |     )
188 |     parser.add_argument(
189 |         "--cutmix", type=float, default=0, help="cutmix alpha, cutmix enabled if > 0."
190 |     )
191 |     parser.add_argument(
192 |         "--cutmix_minmax",
193 |         type=float,
194 |         nargs="+",
195 |         default=None,
196 |         help="cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)",
197 |     )
198 |     parser.add_argument(
199 |         "--mixup_prob",
200 |         type=float,
201 |         default=1.0,
202 |         help="Probability of performing mixup or cutmix when either/both is enabled",
203 |     )
204 |     parser.add_argument(
205 |         "--mixup_switch_prob",
206 |         type=float,
207 |         default=0.5,
208 |         help="Probability of switching to cutmix when both mixup and cutmix enabled",
209 |     )
210 |     parser.add_argument(
211 |         "--mixup_mode",
212 |         type=str,
213 |         default="batch",
214 |         help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"',
215 |     )
216 | 
217 |     # * Finetuning params
218 | 
219 |     parser.add_argument("--global_pool", action="store_true")
220 |     parser.set_defaults(global_pool=True)
221 |     parser.add_argument(
222 |         "--cls_token",
223 |         action="store_false",
224 |         dest="global_pool",
225 |         help="Use class token instead of global pool for classification",
226 |     )
227 |     parser.add_argument("--time_steps", default=1, type=int)
228 | 
229 |     # Dataset parameters
230 | 
231 |     parser.add_argument(
232 |         "--nb_classes",
233 |         default=1000,
234 |         type=int,
235 |         help="number of classification classes",
236 |     )
237 | 
238 |     parser.add_argument(
239 |         "--output_dir",
240 |         default="/raid/ligq/htx/spikemae/output_dir",
241 |         help="path where to save, empty for no saving",
242 |     )
243 |     parser.add_argument(
244 |         "--log_dir",
245 |         default="/raid/ligq/htx/spikemae/output_dir",
246 |         help="path where to tensorboard log",
247 |     )
248 |     parser.add_argument(
249 |         "--device", default="cuda", help="device to use for training / testing"
250 |     )
251 |     parser.add_argument("--seed", default=0, type=int)
252 |     parser.add_argument("--resume", default=None, help="resume from checkpoint")
253 | 
254 |     parser.add_argument(
255 |         "--start_epoch", default=0, type=int, metavar="N", help="start epoch"
256 |     )
257 |     parser.add_argument("--eval", action="store_true", help="Perform evaluation only")
258 |     parser.add_argument(
259 |         "--dist_eval",
260 |         action="store_true",
261 |         default=False,
262 |         help="Enabling distributed evaluation (recommended during training for faster monitoring)",
263 |     )
264 |     parser.add_argument("--num_workers", default=10, type=int)
265 |     parser.add_argument(
266 |         "--pin_mem",
267 |         action="store_true",
268 |         help="Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.",
269 |     )
270 |     parser.add_argument("--no_pin_mem", action="store_false", dest="pin_mem")
271 |     parser.set_defaults(pin_mem=True)
272 | 
273 |     # Distillation parameters
274 |     parser.add_argument(
275 |         "--kd",
276 |         action="store_true",
277 |         default=False,
278 |         help="enable knowledge distillation",
279 |     )
280 |     parser.add_argument(
281 |         "--teacher_model",
282 |         default="caformer_b36_in21ft1k",
283 |         type=str,
284 |         metavar="MODEL",
285 |         help='Name of teacher model (default: "caformer_b36_in21ft1k")',
286 |     )
287 |     parser.add_argument(
288 |         "--distillation_type",
289 |         default="none",
290 |         choices=["none", "soft", "hard"],
291 |         type=str,
292 |         help="distillation type: none, soft or hard",
293 |     )
294 |     parser.add_argument("--distillation_alpha", default=0.5, type=float, help="weight of the distillation term")
295 |     parser.add_argument("--distillation_tau", default=1.0, type=float, help="distillation temperature")
296 | 
297 |     # distributed training parameters
298 |     parser.add_argument(
299 |         "--world_size", default=1, type=int, help="number of distributed processes"
300 |     )
301 |     parser.add_argument("--local-rank", default=-1, type=int)
302 |     parser.add_argument("--dist_on_itp", action="store_true")
303 |     parser.add_argument(
304 |         "--dist_url", default="env://", help="url used to set up distributed training"
305 |     )
306 | 
307 |     return parser
308 | 
309 | 
310 | def main(args):
311 |     misc.init_distributed_mode(args)
312 | 
313 |     print("job dir: {}".format(os.path.dirname(os.path.realpath(__file__))))
314 |     print("{}".format(args).replace(", ", ",\n"))
315 | 
316 |     device = torch.device(args.device)
317 | 
318 |     # fix the seed for reproducibility
319 |     seed = args.seed + misc.get_rank()
320 |     torch.manual_seed(seed)
321 |     np.random.seed(seed)
322 | 
323 |     cudnn.benchmark = True
324 | 
325 |     dataset_train = build_dataset(is_train=True, args=args)
326 |     dataset_val = build_dataset(is_train=False, args=args)
327 | 
328 |     if True:  # args.distributed:
329 |         num_tasks = misc.get_world_size()
330 |         global_rank = misc.get_rank()
331 |         sampler_train = torch.utils.data.DistributedSampler(
332 |             dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
333 |         )
334 |         print("Sampler_train = %s" % str(sampler_train))
335 |         if args.dist_eval:
336 |             if len(dataset_val) % num_tasks != 0:
337 |                 print(
338 |                     "Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. "
339 |                     "This will slightly alter validation results as extra duplicate entries are added to achieve "
340 |                     "equal num of samples per-process."
341 |                 )
342 |             sampler_val = torch.utils.data.DistributedSampler(
343 |                 dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=True
344 |             )  # shuffle=True to reduce monitor bias
345 |         else:
346 |             sampler_val = torch.utils.data.SequentialSampler(dataset_val)
347 |     else:
348 |         sampler_train = torch.utils.data.RandomSampler(dataset_train)
349 |         sampler_val = torch.utils.data.SequentialSampler(dataset_val)
350 | 
351 |     if global_rank == 0 and args.log_dir is not None and not args.eval:
352 |         os.makedirs(args.log_dir, exist_ok=True)
353 |         log_writer = SummaryWriter(log_dir=args.log_dir)
354 |     else:
355 |         log_writer = None
356 | 
357 |     data_loader_train = torch.utils.data.DataLoader(
358 |         dataset_train,
359 |         sampler=sampler_train,
360 |         batch_size=args.batch_size,
361 |         num_workers=args.num_workers,
362 |         pin_memory=args.pin_mem,
363 |         drop_last=True,
364 |     )
365 | 
366 |     data_loader_val = torch.utils.data.DataLoader(
367 |         dataset_val,
368 |         sampler=sampler_val,
369 |         batch_size=args.batch_size,
370 |         num_workers=args.num_workers,
371 |         pin_memory=args.pin_mem,
372 |         drop_last=False,
373 |     )
374 | 
375 |     mixup_fn = None
376 |     mixup_active = args.mixup > 0 or args.cutmix > 0.0 or args.cutmix_minmax is not None
377 |     if mixup_active:
378 |         print("Mixup is activated!")
379 |         mixup_fn = Mixup(
380 |             mixup_alpha=args.mixup,
381 |             cutmix_alpha=args.cutmix,
382 |             cutmix_minmax=args.cutmix_minmax,
383 |             prob=args.mixup_prob,
384 |             switch_prob=args.mixup_switch_prob,
385 |             mode=args.mixup_mode,
386 |             label_smoothing=args.smoothing,
387 |             num_classes=args.nb_classes,
388 |         )
389 | 
390 |     if args.model_mode == "ms":
391 |         model = models.__dict__[args.model](kd=args.kd)
392 |     elif args.model_mode == "sew":
393 |         model = models_sew.__dict__[args.model]()
394 |     model.T = args.time_steps
395 | 
396 |     if args.finetune and not args.eval:
397 |         checkpoint = torch.load(args.finetune, map_location="cpu")
398 | 
399 |         print("Load pre-trained checkpoint from: %s" % args.finetune)
400 |         checkpoint_model = checkpoint["model"]
401 |         # state_dict = model.state_dict()
402 |         # for k in ["head.weight", "head.bias"]:
403 |         #     if (
404 |         #         k in checkpoint_model
405 |         #         and checkpoint_model[k].shape != state_dict[k].shape
406 |         #     ):
407 |         #         print(f"Removing key {k} from pretrained checkpoint")
408 |         #         del checkpoint_model[k]  # commented out for T=4 fine-tuning
409 | 
410 |         # load pre-trained model
411 |         msg = model.load_state_dict(checkpoint_model, strict=False)
412 |         print(msg)
413 | 
414 |         # if args.global_pool:
415 |         #     assert set(msg.missing_keys) == {'head.weight', 'head.bias', 'fc_norm.weight', 'fc_norm.bias'}
416 |         # else:
417 |         #     assert set(msg.missing_keys) == {'head.weight', 'head.bias'}
418 | 
419 |         # manually initialize fc layer
420 |         # trunc_normal_(model.head.weight, std=2e-5)  # commented out for T=4 fine-tuning
421 | 
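422 |     # lr note: when --lr is unset, it is derived below as blr * eff_batch_size / 256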
model.to(device) 423 | 424 | model_without_ddp = model 425 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 426 | 427 | print("Model = %s" % str(model_without_ddp)) 428 | print("number of params (M): %.2f" % (n_parameters / 1.0e6)) 429 | 430 | eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size() 431 | 432 | if args.lr is None: # only base_lr is specified 433 | args.lr = args.blr * eff_batch_size / 256 434 | 435 | print("base lr: %.2e" % (args.lr * 256 / eff_batch_size)) 436 | print("actual lr: %.2e" % args.lr) 437 | 438 | print("accumulate grad iterations: %d" % args.accum_iter) 439 | print("effective batch size: %d" % eff_batch_size) 440 | 441 | if args.distributed: 442 | model = torch.nn.parallel.DistributedDataParallel( 443 | model, device_ids=[args.gpu], find_unused_parameters=False 444 | ) 445 | model_without_ddp = model.module 446 | 447 | # build optimizer with layer-wise lr decay (lrd) 448 | param_groups = lrd.param_groups_lrd( 449 | model_without_ddp, 450 | args.weight_decay, 451 | # no_weight_decay_list=model_without_ddp.no_weight_decay(), 452 | layer_decay=args.layer_decay, 453 | ) 454 | # optimizer = torch.optim.AdamW(param_groups, lr=args.lr) # lamb 455 | optimizer = optim_factory.Lamb(param_groups, trust_clip=True, lr=args.lr) 456 | loss_scaler = NativeScaler() 457 | 458 | if mixup_fn is not None: 459 | # smoothing is handled with mixup label transform 460 | criterion = SoftTargetCrossEntropy() 461 | elif args.smoothing > 0.0: 462 | criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing) 463 | else: 464 | criterion = torch.nn.CrossEntropyLoss() 465 | 466 | if args.kd: 467 | teacher_model = None 468 | if args.distillation_type == "none": 469 | args.distillation_type = "hard" 470 | print(f"Creating teacher model: {args.teacher_model}") 471 | # teacher_model_name = importlib.import_module("metaformer."+args.teacher_model) 472 | from metaformer import caformer_b36_in21ft1k 473 | 474 | teacher_model = caformer_b36_in21ft1k(pretrained=True) 475 | teacher_model.to(device) 476 | teacher_model.eval() 477 | # wrap the criterion in our custom DistillationLoss, which 478 | # just dispatches to the original criterion if args.distillation_type is 'none' 479 | criterion = DistillationLoss( 480 | criterion, 481 | teacher_model, 482 | args.distillation_type, 483 | args.distillation_alpha, 484 | args.distillation_tau, 485 | ) 486 | 487 | print("criterion = %s" % str(criterion)) 488 | 489 | misc.load_model( 490 | args=args, 491 | model_without_ddp=model_without_ddp, 492 | optimizer=optimizer, 493 | loss_scaler=loss_scaler, 494 | ) 495 | 496 | if args.eval: 497 | test_stats = evaluate(data_loader_val, model, device) 498 | print( 499 | f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%" 500 | ) 501 | exit(0) 502 | 503 | print(f"Start training for {args.epochs} epochs") 504 | start_time = time.time() 505 | max_accuracy = 0.0 506 | best_acc = 0 507 | best_epoch = 0 508 | for epoch in range(args.start_epoch, args.epochs): 509 | if args.distributed: 510 | data_loader_train.sampler.set_epoch(epoch) 511 | train_stats = train_one_epoch( 512 | model, 513 | criterion, 514 | data_loader_train, 515 | optimizer, 516 | device, 517 | epoch, 518 | loss_scaler, 519 | args.clip_grad, 520 | mixup_fn, 521 | log_writer=log_writer, 522 | args=args, 523 | ) 524 | if args.output_dir and (epoch % 50 == 0 or epoch + 1 == args.epochs): 525 | print("Saving model at epoch:", epoch) 526 | misc.save_model( 527 | args=args, 528 
|                 model=model,
529 |                 model_without_ddp=model_without_ddp,
530 |                 optimizer=optimizer,
531 |                 loss_scaler=loss_scaler,
532 |                 epoch=epoch,
533 |             )
534 | 
535 |         test_stats = evaluate(data_loader_val, model, device)
536 |         print(
537 |             f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%"
538 |         )
539 |         max_accuracy = max(max_accuracy, test_stats["acc1"])
540 |         print(f"Max accuracy: {max_accuracy:.2f}%")
541 |         if args.output_dir and test_stats["acc1"] > best_acc:
542 |             best_acc, best_epoch = test_stats["acc1"], epoch  # track best so we only checkpoint on improvement
543 |             print("Saving best model at epoch:", epoch)
544 |             misc.save_model(
545 |                 args=args,
546 |                 model=model,
547 |                 model_without_ddp=model_without_ddp,
548 |                 optimizer=optimizer,
549 |                 loss_scaler=loss_scaler,
550 |                 epoch=epoch,
551 |             )
552 |         if log_writer is not None:
553 |             log_writer.add_scalar("perf/test_acc1", test_stats["acc1"], epoch)
554 |             log_writer.add_scalar("perf/test_acc5", test_stats["acc5"], epoch)
555 |             log_writer.add_scalar("perf/test_loss", test_stats["loss"], epoch)
556 | 
557 |         log_stats = {
558 |             **{f"train_{k}": v for k, v in train_stats.items()},
559 |             **{f"test_{k}": v for k, v in test_stats.items()},
560 |             "epoch": epoch,
561 |             "n_parameters": n_parameters,
562 |         }
563 | 
564 |         if args.output_dir and misc.is_main_process():
565 |             if log_writer is not None:
566 |                 log_writer.flush()
567 |             with open(
568 |                 os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8"
569 |             ) as f:
570 |                 f.write(json.dumps(log_stats) + "\n")
571 | 
572 |     total_time = time.time() - start_time
573 |     total_time_str = str(datetime.timedelta(seconds=int(total_time)))
574 |     print("Training time {}".format(total_time_str))
575 | 
576 | 
577 | if __name__ == "__main__":
578 |     args = get_args_parser()
579 |     args = args.parse_args()
580 |     if args.output_dir:
581 |         Path(args.output_dir).mkdir(parents=True, exist_ok=True)
582 |     main(args)
583 | 
--------------------------------------------------------------------------------
/classification/metaformer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Garena Online Private Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | """
16 | MetaFormer baselines including IdentityFormer, RandFormer, PoolFormerV2,
17 | ConvFormer and CAFormer.
18 | Some implementations are modified from timm (https://github.com/rwightman/pytorch-image-models).
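caformer_b36_in21ft1k below is loaded as the distillation teacher by
main_finetune.py; a minimal sketch of that call site:

    from metaformer import caformer_b36_in21ft1k
    teacher = caformer_b36_in21ft1k(pretrained=True)
    teacher.eval()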
19 | """ 20 | from functools import partial 21 | import torch 22 | import torch.nn as nn 23 | import torch.nn.functional as F 24 | 25 | from timm.models.layers import trunc_normal_, DropPath 26 | from timm.models.registry import register_model 27 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 28 | from timm.models.layers import to_2tuple 29 | 30 | 31 | def _cfg(url="", **kwargs): 32 | return { 33 | "url": url, 34 | "num_classes": 1000, 35 | "input_size": (3, 224, 224), 36 | "pool_size": None, 37 | "crop_pct": 1.0, 38 | "interpolation": "bicubic", 39 | "mean": IMAGENET_DEFAULT_MEAN, 40 | "std": IMAGENET_DEFAULT_STD, 41 | "classifier": "head", 42 | **kwargs, 43 | } 44 | 45 | 46 | default_cfgs = { 47 | "identityformer_s12": _cfg( 48 | url="https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_s12.pth" 49 | ), 50 | "identityformer_s24": _cfg( 51 | url="https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_s24.pth" 52 | ), 53 | "identityformer_s36": _cfg( 54 | url="https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_s36.pth" 55 | ), 56 | "identityformer_m36": _cfg( 57 | url="https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_m36.pth" 58 | ), 59 | "identityformer_m48": _cfg( 60 | url="https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_m48.pth" 61 | ), 62 | "randformer_s12": _cfg( 63 | url="https://huggingface.co/sail/dl/resolve/main/randformer/randformer_s12.pth" 64 | ), 65 | "randformer_s24": _cfg( 66 | url="https://huggingface.co/sail/dl/resolve/main/randformer/randformer_s24.pth" 67 | ), 68 | "randformer_s36": _cfg( 69 | url="https://huggingface.co/sail/dl/resolve/main/randformer/randformer_s36.pth" 70 | ), 71 | "randformer_m36": _cfg( 72 | url="https://huggingface.co/sail/dl/resolve/main/randformer/randformer_m36.pth" 73 | ), 74 | "randformer_m48": _cfg( 75 | url="https://huggingface.co/sail/dl/resolve/main/randformer/randformer_m48.pth" 76 | ), 77 | "poolformerv2_s12": _cfg( 78 | url="https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_s12.pth" 79 | ), 80 | "poolformerv2_s24": _cfg( 81 | url="https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_s24.pth" 82 | ), 83 | "poolformerv2_s36": _cfg( 84 | url="https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_s36.pth" 85 | ), 86 | "poolformerv2_m36": _cfg( 87 | url="https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_m36.pth" 88 | ), 89 | "poolformerv2_m48": _cfg( 90 | url="https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_m48.pth" 91 | ), 92 | "convformer_s18": _cfg( 93 | url="https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18.pth" 94 | ), 95 | "convformer_s18_384": _cfg( 96 | url="https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_384.pth", 97 | input_size=(3, 384, 384), 98 | ), 99 | "convformer_s18_in21ft1k": _cfg( 100 | url="https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_in21ft1k.pth" 101 | ), 102 | "convformer_s18_384_in21ft1k": _cfg( 103 | url="https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_384_in21ft1k.pth", 104 | input_size=(3, 384, 384), 105 | ), 106 | "convformer_s18_in21k": _cfg( 107 | url="https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_in21k.pth", 108 | num_classes=21841, 109 | ), 110 | "convformer_s36": _cfg( 111 | url="https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36.pth" 112 | ), 113 | 
"convformer_s36_384": _cfg( 114 | url="https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36_384.pth", 115 | input_size=(3, 384, 384), 116 | ), 117 | "convformer_s36_in21ft1k": _cfg( 118 | url="https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36_in21ft1k.pth" 119 | ), 120 | "convformer_s36_384_in21ft1k": _cfg( 121 | url="https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36_384_in21ft1k.pth", 122 | input_size=(3, 384, 384), 123 | ), 124 | "convformer_s36_in21k": _cfg( 125 | url="https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36_in21k.pth", 126 | num_classes=21841, 127 | ), 128 | "convformer_m36": _cfg( 129 | url="https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36.pth" 130 | ), 131 | "convformer_m36_384": _cfg( 132 | url="https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36_384.pth", 133 | input_size=(3, 384, 384), 134 | ), 135 | "convformer_m36_in21ft1k": _cfg( 136 | url="https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36_in21ft1k.pth" 137 | ), 138 | "convformer_m36_384_in21ft1k": _cfg( 139 | url="https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36_384_in21ft1k.pth", 140 | input_size=(3, 384, 384), 141 | ), 142 | "convformer_m36_in21k": _cfg( 143 | url="https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36_in21k.pth", 144 | num_classes=21841, 145 | ), 146 | "convformer_b36": _cfg( 147 | url="https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36.pth" 148 | ), 149 | "convformer_b36_384": _cfg( 150 | url="https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_384.pth", 151 | input_size=(3, 384, 384), 152 | ), 153 | "convformer_b36_in21ft1k": _cfg( 154 | url="https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_in21ft1k.pth" 155 | ), 156 | "convformer_b36_384_in21ft1k": _cfg( 157 | url="https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_384_in21ft1k.pth", 158 | input_size=(3, 384, 384), 159 | ), 160 | "convformer_b36_in21k": _cfg( 161 | url="https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_in21k.pth", 162 | num_classes=21841, 163 | ), 164 | "caformer_s18": _cfg( 165 | url="https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18.pth" 166 | ), 167 | "caformer_s18_384": _cfg( 168 | url="https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18_384.pth", 169 | input_size=(3, 384, 384), 170 | ), 171 | "caformer_s18_in21ft1k": _cfg( 172 | url="https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18_in21ft1k.pth" 173 | ), 174 | "caformer_s18_384_in21ft1k": _cfg( 175 | url="https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18_384_in21ft1k.pth", 176 | input_size=(3, 384, 384), 177 | ), 178 | "caformer_s18_in21k": _cfg( 179 | url="https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18_in21k.pth", 180 | num_classes=21841, 181 | ), 182 | "caformer_s36": _cfg( 183 | url="https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36.pth" 184 | ), 185 | "caformer_s36_384": _cfg( 186 | url="https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36_384.pth", 187 | input_size=(3, 384, 384), 188 | ), 189 | "caformer_s36_in21ft1k": _cfg( 190 | url="https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36_in21ft1k.pth" 191 | ), 192 | "caformer_s36_384_in21ft1k": _cfg( 193 | url="https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36_384_in21ft1k.pth", 194 | input_size=(3, 384, 384), 195 | 
), 196 | "caformer_s36_in21k": _cfg( 197 | url="https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36_in21k.pth", 198 | num_classes=21841, 199 | ), 200 | "caformer_m36": _cfg( 201 | url="https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36.pth" 202 | ), 203 | "caformer_m36_384": _cfg( 204 | url="https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_384.pth", 205 | input_size=(3, 384, 384), 206 | ), 207 | "caformer_m36_in21ft1k": _cfg( 208 | url="https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_in21ft1k.pth" 209 | ), 210 | "caformer_m36_384_in21ft1k": _cfg( 211 | url="https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_384_in21ft1k.pth", 212 | input_size=(3, 384, 384), 213 | ), 214 | "caformer_m36_in21k": _cfg( 215 | url="https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_in21k.pth", 216 | num_classes=21841, 217 | ), 218 | "caformer_b36": _cfg( 219 | url="https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36.pth" 220 | ), 221 | "caformer_b36_384": _cfg( 222 | url="https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_384.pth", 223 | input_size=(3, 384, 384), 224 | ), 225 | "caformer_b36_in21ft1k": _cfg( 226 | url="https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_in21ft1k.pth" 227 | ), 228 | "caformer_b36_384_in21ft1k": _cfg( 229 | url="https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_384_in21ft1k.pth", 230 | input_size=(3, 384, 384), 231 | ), 232 | "caformer_b36_in21k": _cfg( 233 | url="https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_in21k.pth", 234 | num_classes=21841, 235 | ), 236 | } 237 | 238 | 239 | class Downsampling(nn.Module): 240 | """ 241 | Downsampling implemented by a layer of convolution. 242 | """ 243 | 244 | def __init__( 245 | self, 246 | in_channels, 247 | out_channels, 248 | kernel_size, 249 | stride=1, 250 | padding=0, 251 | pre_norm=None, 252 | post_norm=None, 253 | pre_permute=False, 254 | ): 255 | super().__init__() 256 | self.pre_norm = pre_norm(in_channels) if pre_norm else nn.Identity() 257 | self.pre_permute = pre_permute 258 | self.conv = nn.Conv2d( 259 | in_channels, 260 | out_channels, 261 | kernel_size=kernel_size, 262 | stride=stride, 263 | padding=padding, 264 | ) 265 | self.post_norm = post_norm(out_channels) if post_norm else nn.Identity() 266 | 267 | def forward(self, x): 268 | x = self.pre_norm(x) 269 | if self.pre_permute: 270 | # if take [B, H, W, C] as input, permute it to [B, C, H, W] 271 | x = x.permute(0, 3, 1, 2) 272 | x = self.conv(x) 273 | x = x.permute(0, 2, 3, 1) # [B, C, H, W] -> [B, H, W, C] 274 | x = self.post_norm(x) 275 | return x 276 | 277 | 278 | class Scale(nn.Module): 279 | """ 280 | Scale vector by element multiplications. 
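    Used in MetaFormerBlock for both layer scale and residual scale.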
281 | """ 282 | 283 | def __init__(self, dim, init_value=1.0, trainable=True): 284 | super().__init__() 285 | self.scale = nn.Parameter(init_value * torch.ones(dim), requires_grad=trainable) 286 | 287 | def forward(self, x): 288 | return x * self.scale 289 | 290 | 291 | class SquaredReLU(nn.Module): 292 | """ 293 | Squared ReLU: https://arxiv.org/abs/2109.08668 294 | """ 295 | 296 | def __init__(self, inplace=False): 297 | super().__init__() 298 | self.relu = nn.ReLU(inplace=inplace) 299 | 300 | def forward(self, x): 301 | return torch.square(self.relu(x)) 302 | 303 | 304 | class StarReLU(nn.Module): 305 | """ 306 | StarReLU: s * relu(x) ** 2 + b 307 | """ 308 | 309 | def __init__( 310 | self, 311 | scale_value=1.0, 312 | bias_value=0.0, 313 | scale_learnable=True, 314 | bias_learnable=True, 315 | mode=None, 316 | inplace=False, 317 | ): 318 | super().__init__() 319 | self.inplace = inplace 320 | self.relu = nn.ReLU(inplace=inplace) 321 | self.scale = nn.Parameter( 322 | scale_value * torch.ones(1), requires_grad=scale_learnable 323 | ) 324 | self.bias = nn.Parameter( 325 | bias_value * torch.ones(1), requires_grad=bias_learnable 326 | ) 327 | 328 | def forward(self, x): 329 | return self.scale * self.relu(x) ** 2 + self.bias 330 | 331 | 332 | class Attention(nn.Module): 333 | """ 334 | Vanilla self-attention from Transformer: https://arxiv.org/abs/1706.03762. 335 | Modified from timm. 336 | """ 337 | 338 | def __init__( 339 | self, 340 | dim, 341 | head_dim=32, 342 | num_heads=None, 343 | qkv_bias=False, 344 | attn_drop=0.0, 345 | proj_drop=0.0, 346 | proj_bias=False, 347 | **kwargs, 348 | ): 349 | super().__init__() 350 | 351 | self.head_dim = head_dim 352 | self.scale = head_dim**-0.5 353 | 354 | self.num_heads = num_heads if num_heads else dim // head_dim 355 | if self.num_heads == 0: 356 | self.num_heads = 1 357 | 358 | self.attention_dim = self.num_heads * self.head_dim 359 | 360 | self.qkv = nn.Linear(dim, self.attention_dim * 3, bias=qkv_bias) 361 | self.attn_drop = nn.Dropout(attn_drop) 362 | self.proj = nn.Linear(self.attention_dim, dim, bias=proj_bias) 363 | self.proj_drop = nn.Dropout(proj_drop) 364 | 365 | def forward(self, x): 366 | B, H, W, C = x.shape 367 | N = H * W 368 | qkv = ( 369 | self.qkv(x) 370 | .reshape(B, N, 3, self.num_heads, self.head_dim) 371 | .permute(2, 0, 3, 1, 4) 372 | ) 373 | q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) 374 | 375 | attn = (q @ k.transpose(-2, -1)) * self.scale 376 | attn = attn.softmax(dim=-1) 377 | attn = self.attn_drop(attn) 378 | 379 | x = (attn @ v).transpose(1, 2).reshape(B, H, W, self.attention_dim) 380 | x = self.proj(x) 381 | x = self.proj_drop(x) 382 | return x 383 | 384 | 385 | class RandomMixing(nn.Module): 386 | def __init__(self, num_tokens=196, **kwargs): 387 | super().__init__() 388 | self.random_matrix = nn.parameter.Parameter( 389 | data=torch.softmax(torch.rand(num_tokens, num_tokens), dim=-1), 390 | requires_grad=False, 391 | ) 392 | 393 | def forward(self, x): 394 | B, H, W, C = x.shape 395 | x = x.reshape(B, H * W, C) 396 | x = torch.einsum("mn, bnc -> bmc", self.random_matrix, x) 397 | x = x.reshape(B, H, W, C) 398 | return x 399 | 400 | 401 | class LayerNormGeneral(nn.Module): 402 | r"""General LayerNorm for different situations. 403 | 404 | Args: 405 | affine_shape (int, list or tuple): The shape of affine weight and bias. 406 | Usually the affine_shape=C, but in some implementation, like torch.nn.LayerNorm, 407 | the affine_shape is the same as normalized_dim by default. 
408 |             To adapt to different situations, we offer this argument here.
409 |         normalized_dim (tuple or list): Which dims to compute mean and variance.
410 |         scale (bool): Flag indicates whether to use scale or not.
411 |         bias (bool): Flag indicates whether to use bias or not.
412 | 
413 |     We give several examples to show how to specify the arguments.
414 | 
415 |     LayerNorm (https://arxiv.org/abs/1607.06450):
416 |         For input shape of (B, *, C) like (B, N, C) or (B, H, W, C),
417 |             affine_shape=C, normalized_dim=(-1, ), scale=True, bias=True;
418 |         For input shape of (B, C, H, W),
419 |             affine_shape=(C, 1, 1), normalized_dim=(1, ), scale=True, bias=True.
420 | 
421 |     Modified LayerNorm (https://arxiv.org/abs/2111.11418)
422 |         that is identical to partial(torch.nn.GroupNorm, num_groups=1):
423 |         For input shape of (B, N, C),
424 |             affine_shape=C, normalized_dim=(1, 2), scale=True, bias=True;
425 |         For input shape of (B, H, W, C),
426 |             affine_shape=C, normalized_dim=(1, 2, 3), scale=True, bias=True;
427 |         For input shape of (B, C, H, W),
428 |             affine_shape=(C, 1, 1), normalized_dim=(1, 2, 3), scale=True, bias=True.
429 | 
430 |     For the several MetaFormer baselines,
431 |         IdentityFormer, RandFormer and PoolFormerV2 utilize Modified LayerNorm without bias (bias=False);
432 |         ConvFormer and CAFormer utilize LayerNorm without bias (bias=False).
433 |     """
434 | 
435 |     def __init__(
436 |         self, affine_shape=None, normalized_dim=(-1,), scale=True, bias=True, eps=1e-5
437 |     ):
438 |         super().__init__()
439 |         self.normalized_dim = normalized_dim
440 |         self.use_scale = scale
441 |         self.use_bias = bias
442 |         self.weight = nn.Parameter(torch.ones(affine_shape)) if scale else None
443 |         self.bias = nn.Parameter(torch.zeros(affine_shape)) if bias else None
444 |         self.eps = eps
445 | 
446 |     def forward(self, x):
447 |         c = x - x.mean(self.normalized_dim, keepdim=True)
448 |         s = c.pow(2).mean(self.normalized_dim, keepdim=True)
449 |         x = c / torch.sqrt(s + self.eps)
450 |         if self.use_scale:
451 |             x = x * self.weight
452 |         if self.use_bias:
453 |             x = x + self.bias
454 |         return x
455 | 
456 | 
457 | class LayerNormWithoutBias(nn.Module):
458 |     """
459 |     Equal to partial(LayerNormGeneral, bias=False) but faster,
460 |     because it directly utilizes optimized F.layer_norm
461 |     """
462 | 
463 |     def __init__(self, normalized_shape, eps=1e-5, **kwargs):
464 |         super().__init__()
465 |         self.eps = eps
466 |         self.bias = None
467 |         if isinstance(normalized_shape, int):
468 |             normalized_shape = (normalized_shape,)
469 |         self.weight = nn.Parameter(torch.ones(normalized_shape))
470 |         self.normalized_shape = normalized_shape
471 | 
472 |     def forward(self, x):
473 |         return F.layer_norm(
474 |             x, self.normalized_shape, weight=self.weight, bias=self.bias, eps=self.eps
475 |         )
476 | 
477 | 
478 | class SepConv(nn.Module):
479 |     r"""
480 |     Inverted separable convolution from MobileNetV2: https://arxiv.org/abs/1801.04381.
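    Order here: pointwise expand (Linear) -> StarReLU -> 7x7 depthwise conv ->
    pointwise project, operating on [B, H, W, C] input.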
481 | """ 482 | 483 | def __init__( 484 | self, 485 | dim, 486 | expansion_ratio=2, 487 | act1_layer=StarReLU, 488 | act2_layer=nn.Identity, 489 | bias=False, 490 | kernel_size=7, 491 | padding=3, 492 | **kwargs, 493 | ): 494 | super().__init__() 495 | med_channels = int(expansion_ratio * dim) 496 | self.pwconv1 = nn.Linear(dim, med_channels, bias=bias) 497 | self.act1 = act1_layer() 498 | self.dwconv = nn.Conv2d( 499 | med_channels, 500 | med_channels, 501 | kernel_size=kernel_size, 502 | padding=padding, 503 | groups=med_channels, 504 | bias=bias, 505 | ) # depthwise conv 506 | self.act2 = act2_layer() 507 | self.pwconv2 = nn.Linear(med_channels, dim, bias=bias) 508 | 509 | def forward(self, x): 510 | x = self.pwconv1(x) 511 | x = self.act1(x) 512 | x = x.permute(0, 3, 1, 2) 513 | x = self.dwconv(x) 514 | x = x.permute(0, 2, 3, 1) 515 | x = self.act2(x) 516 | x = self.pwconv2(x) 517 | return x 518 | 519 | 520 | class Pooling(nn.Module): 521 | """ 522 | Implementation of pooling for PoolFormer: https://arxiv.org/abs/2111.11418 523 | Modfiled for [B, H, W, C] input 524 | """ 525 | 526 | def __init__(self, pool_size=3, **kwargs): 527 | super().__init__() 528 | self.pool = nn.AvgPool2d( 529 | pool_size, stride=1, padding=pool_size // 2, count_include_pad=False 530 | ) 531 | 532 | def forward(self, x): 533 | y = x.permute(0, 3, 1, 2) 534 | y = self.pool(y) 535 | y = y.permute(0, 2, 3, 1) 536 | return y - x 537 | 538 | 539 | class Mlp(nn.Module): 540 | """MLP as used in MetaFormer models, eg Transformer, MLP-Mixer, PoolFormer, MetaFormer baslines and related networks. 541 | Mostly copied from timm. 542 | """ 543 | 544 | def __init__( 545 | self, 546 | dim, 547 | mlp_ratio=4, 548 | out_features=None, 549 | act_layer=StarReLU, 550 | drop=0.0, 551 | bias=False, 552 | **kwargs, 553 | ): 554 | super().__init__() 555 | in_features = dim 556 | out_features = out_features or in_features 557 | hidden_features = int(mlp_ratio * in_features) 558 | drop_probs = to_2tuple(drop) 559 | 560 | self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) 561 | self.act = act_layer() 562 | self.drop1 = nn.Dropout(drop_probs[0]) 563 | self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) 564 | self.drop2 = nn.Dropout(drop_probs[1]) 565 | 566 | def forward(self, x): 567 | x = self.fc1(x) 568 | x = self.act(x) 569 | x = self.drop1(x) 570 | x = self.fc2(x) 571 | x = self.drop2(x) 572 | return x 573 | 574 | 575 | class MlpHead(nn.Module): 576 | """MLP classification head""" 577 | 578 | def __init__( 579 | self, 580 | dim, 581 | num_classes=1000, 582 | mlp_ratio=4, 583 | act_layer=SquaredReLU, 584 | norm_layer=nn.LayerNorm, 585 | head_dropout=0.0, 586 | bias=True, 587 | ): 588 | super().__init__() 589 | hidden_features = int(mlp_ratio * dim) 590 | self.fc1 = nn.Linear(dim, hidden_features, bias=bias) 591 | self.act = act_layer() 592 | self.norm = norm_layer(hidden_features) 593 | self.fc2 = nn.Linear(hidden_features, num_classes, bias=bias) 594 | self.head_dropout = nn.Dropout(head_dropout) 595 | 596 | def forward(self, x): 597 | x = self.fc1(x) 598 | x = self.act(x) 599 | x = self.norm(x) 600 | x = self.head_dropout(x) 601 | x = self.fc2(x) 602 | return x 603 | 604 | 605 | class MetaFormerBlock(nn.Module): 606 | """ 607 | Implementation of one MetaFormer block. 
608 | """ 609 | 610 | def __init__( 611 | self, 612 | dim, 613 | token_mixer=nn.Identity, 614 | mlp=Mlp, 615 | norm_layer=nn.LayerNorm, 616 | drop=0.0, 617 | drop_path=0.0, 618 | layer_scale_init_value=None, 619 | res_scale_init_value=None, 620 | ): 621 | super().__init__() 622 | 623 | self.norm1 = norm_layer(dim) 624 | self.token_mixer = token_mixer(dim=dim, drop=drop) 625 | self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 626 | self.layer_scale1 = ( 627 | Scale(dim=dim, init_value=layer_scale_init_value) 628 | if layer_scale_init_value 629 | else nn.Identity() 630 | ) 631 | self.res_scale1 = ( 632 | Scale(dim=dim, init_value=res_scale_init_value) 633 | if res_scale_init_value 634 | else nn.Identity() 635 | ) 636 | 637 | self.norm2 = norm_layer(dim) 638 | self.mlp = mlp(dim=dim, drop=drop) 639 | self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 640 | self.layer_scale2 = ( 641 | Scale(dim=dim, init_value=layer_scale_init_value) 642 | if layer_scale_init_value 643 | else nn.Identity() 644 | ) 645 | self.res_scale2 = ( 646 | Scale(dim=dim, init_value=res_scale_init_value) 647 | if res_scale_init_value 648 | else nn.Identity() 649 | ) 650 | 651 | def forward(self, x): 652 | x = self.res_scale1(x) + self.layer_scale1( 653 | self.drop_path1(self.token_mixer(self.norm1(x))) 654 | ) 655 | x = self.res_scale2(x) + self.layer_scale2( 656 | self.drop_path2(self.mlp(self.norm2(x))) 657 | ) 658 | return x 659 | 660 | 661 | r""" 662 | downsampling (stem) for the first stage is a layer of conv with k7, s4 and p2 663 | downsamplings for the last 3 stages is a layer of conv with k3, s2 and p1 664 | DOWNSAMPLE_LAYERS_FOUR_STAGES format: [Downsampling, Downsampling, Downsampling, Downsampling] 665 | use `partial` to specify some arguments 666 | """ 667 | DOWNSAMPLE_LAYERS_FOUR_STAGES = [ 668 | partial( 669 | Downsampling, 670 | kernel_size=7, 671 | stride=4, 672 | padding=2, 673 | post_norm=partial(LayerNormGeneral, bias=False, eps=1e-6), 674 | ) 675 | ] + [ 676 | partial( 677 | Downsampling, 678 | kernel_size=3, 679 | stride=2, 680 | padding=1, 681 | pre_norm=partial(LayerNormGeneral, bias=False, eps=1e-6), 682 | pre_permute=True, 683 | ) 684 | ] * 3 685 | 686 | 687 | class MetaFormer(nn.Module): 688 | r"""MetaFormer 689 | A PyTorch impl of : `MetaFormer Baselines for Vision` - 690 | https://arxiv.org/abs/2210.13452 691 | 692 | Args: 693 | in_chans (int): Number of input image channels. Default: 3. 694 | num_classes (int): Number of classes for classification head. Default: 1000. 695 | depths (list or tuple): Number of blocks at each stage. Default: [2, 2, 6, 2]. 696 | dims (int): Feature dimension at each stage. Default: [64, 128, 320, 512]. 697 | downsample_layers: (list or tuple): Downsampling layers before each stage. 698 | token_mixers (list, tuple or token_fcn): Token mixer for each stage. Default: nn.Identity. 699 | mlps (list, tuple or mlp_fcn): Mlp for each stage. Default: Mlp. 700 | norm_layers (list, tuple or norm_fcn): Norm layers for each stage. Default: partial(LayerNormGeneral, eps=1e-6, bias=False). 701 | drop_path_rate (float): Stochastic depth rate. Default: 0. 702 | head_dropout (float): dropout for MLP classifier. Default: 0. 703 | layer_scale_init_values (list, tuple, float or None): Init value for Layer Scale. Default: None. 704 | None means not use the layer scale. Form: https://arxiv.org/abs/2103.17239. 705 | res_scale_init_values (list, tuple, float or None): Init value for Layer Scale. Default: [None, None, 1.0, 1.0]. 
706 | None means not use the layer scale. From: https://arxiv.org/abs/2110.09456. 707 | output_norm: norm before classifier head. Default: partial(nn.LayerNorm, eps=1e-6). 708 | head_fn: classification head. Default: nn.Linear. 709 | """ 710 | 711 | def __init__( 712 | self, 713 | in_chans=3, 714 | num_classes=1000, 715 | depths=[2, 2, 6, 2], 716 | dims=[64, 128, 320, 512], 717 | downsample_layers=DOWNSAMPLE_LAYERS_FOUR_STAGES, 718 | token_mixers=nn.Identity, 719 | mlps=Mlp, 720 | norm_layers=partial( 721 | LayerNormWithoutBias, eps=1e-6 722 | ), # partial(LayerNormGeneral, eps=1e-6, bias=False), 723 | drop_path_rate=0.0, 724 | head_dropout=0.0, 725 | layer_scale_init_values=None, 726 | res_scale_init_values=[None, None, 1.0, 1.0], 727 | output_norm=partial(nn.LayerNorm, eps=1e-6), 728 | head_fn=nn.Linear, 729 | **kwargs, 730 | ): 731 | super().__init__() 732 | self.num_classes = num_classes 733 | 734 | if not isinstance(depths, (list, tuple)): 735 | depths = [depths] # it means the model has only one stage 736 | if not isinstance(dims, (list, tuple)): 737 | dims = [dims] 738 | 739 | num_stage = len(depths) 740 | self.num_stage = num_stage 741 | 742 | if not isinstance(downsample_layers, (list, tuple)): 743 | downsample_layers = [downsample_layers] * num_stage 744 | down_dims = [in_chans] + dims 745 | self.downsample_layers = nn.ModuleList( 746 | [ 747 | downsample_layers[i](down_dims[i], down_dims[i + 1]) 748 | for i in range(num_stage) 749 | ] 750 | ) 751 | 752 | if not isinstance(token_mixers, (list, tuple)): 753 | token_mixers = [token_mixers] * num_stage 754 | 755 | if not isinstance(mlps, (list, tuple)): 756 | mlps = [mlps] * num_stage 757 | 758 | if not isinstance(norm_layers, (list, tuple)): 759 | norm_layers = [norm_layers] * num_stage 760 | 761 | dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] 762 | 763 | if not isinstance(layer_scale_init_values, (list, tuple)): 764 | layer_scale_init_values = [layer_scale_init_values] * num_stage 765 | if not isinstance(res_scale_init_values, (list, tuple)): 766 | res_scale_init_values = [res_scale_init_values] * num_stage 767 | 768 | self.stages = ( 769 | nn.ModuleList() 770 | ) # each stage consists of multiple metaformer blocks 771 | cur = 0 772 | for i in range(num_stage): 773 | stage = nn.Sequential( 774 | *[ 775 | MetaFormerBlock( 776 | dim=dims[i], 777 | token_mixer=token_mixers[i], 778 | mlp=mlps[i], 779 | norm_layer=norm_layers[i], 780 | drop_path=dp_rates[cur + j], 781 | layer_scale_init_value=layer_scale_init_values[i], 782 | res_scale_init_value=res_scale_init_values[i], 783 | ) 784 | for j in range(depths[i]) 785 | ] 786 | ) 787 | self.stages.append(stage) 788 | cur += depths[i] 789 | 790 | self.norm = output_norm(dims[-1]) 791 | 792 | if head_dropout > 0.0: 793 | self.head = head_fn(dims[-1], num_classes, head_dropout=head_dropout) 794 | else: 795 | self.head = head_fn(dims[-1], num_classes) 796 | 797 | self.apply(self._init_weights) 798 | 799 | def _init_weights(self, m): 800 | if isinstance(m, (nn.Conv2d, nn.Linear)): 801 | trunc_normal_(m.weight, std=0.02) 802 | if m.bias is not None: 803 | nn.init.constant_(m.bias, 0) 804 | 805 | @torch.jit.ignore 806 | def no_weight_decay(self): 807 | return {"norm"} 808 | 809 | def forward_features(self, x): 810 | for i in range(self.num_stage): 811 | x = self.downsample_layers[i](x) 812 | x = self.stages[i](x) 813 | return self.norm(x.mean([1, 2])) # (B, H, W, C) -> (B, C) 814 | 815 | def forward(self, x): 816 | x = self.forward_features(x) 817 | x = 
self.head(x) 818 | return x 819 | 820 | 821 | @register_model 822 | def identityformer_s12(pretrained=False, **kwargs): 823 | model = MetaFormer( 824 | depths=[2, 2, 6, 2], 825 | dims=[64, 128, 320, 512], 826 | token_mixers=nn.Identity, 827 | norm_layers=partial( 828 | LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False 829 | ), 830 | **kwargs, 831 | ) 832 | model.default_cfg = default_cfgs["identityformer_s12"] 833 | if pretrained: 834 | state_dict = torch.hub.load_state_dict_from_url( 835 | url=model.default_cfg["url"], map_location="cpu", check_hash=True 836 | ) 837 | model.load_state_dict(state_dict) 838 | return model 839 | 840 | 841 | @register_model 842 | def identityformer_s24(pretrained=False, **kwargs): 843 | model = MetaFormer( 844 | depths=[4, 4, 12, 4], 845 | dims=[64, 128, 320, 512], 846 | token_mixers=nn.Identity, 847 | norm_layers=partial( 848 | LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False 849 | ), 850 | **kwargs, 851 | ) 852 | model.default_cfg = default_cfgs["identityformer_s24"] 853 | if pretrained: 854 | state_dict = torch.hub.load_state_dict_from_url( 855 | url=model.default_cfg["url"], map_location="cpu", check_hash=True 856 | ) 857 | model.load_state_dict(state_dict) 858 | return model 859 | 860 | 861 | @register_model 862 | def identityformer_s36(pretrained=False, **kwargs): 863 | model = MetaFormer( 864 | depths=[6, 6, 18, 6], 865 | dims=[64, 128, 320, 512], 866 | token_mixers=nn.Identity, 867 | norm_layers=partial( 868 | LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False 869 | ), 870 | **kwargs, 871 | ) 872 | model.default_cfg = default_cfgs["identityformer_s36"] 873 | if pretrained: 874 | state_dict = torch.hub.load_state_dict_from_url( 875 | url=model.default_cfg["url"], map_location="cpu", check_hash=True 876 | ) 877 | model.load_state_dict(state_dict) 878 | return model 879 | 880 | 881 | @register_model 882 | def identityformer_m36(pretrained=False, **kwargs): 883 | model = MetaFormer( 884 | depths=[6, 6, 18, 6], 885 | dims=[96, 192, 384, 768], 886 | token_mixers=nn.Identity, 887 | norm_layers=partial( 888 | LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False 889 | ), 890 | **kwargs, 891 | ) 892 | model.default_cfg = default_cfgs["identityformer_m36"] 893 | if pretrained: 894 | state_dict = torch.hub.load_state_dict_from_url( 895 | url=model.default_cfg["url"], map_location="cpu", check_hash=True 896 | ) 897 | model.load_state_dict(state_dict) 898 | return model 899 | 900 | 901 | @register_model 902 | def identityformer_m48(pretrained=False, **kwargs): 903 | model = MetaFormer( 904 | depths=[8, 8, 24, 8], 905 | dims=[96, 192, 384, 768], 906 | token_mixers=nn.Identity, 907 | norm_layers=partial( 908 | LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False 909 | ), 910 | **kwargs, 911 | ) 912 | model.default_cfg = default_cfgs["identityformer_m48"] 913 | if pretrained: 914 | state_dict = torch.hub.load_state_dict_from_url( 915 | url=model.default_cfg["url"], map_location="cpu", check_hash=True 916 | ) 917 | model.load_state_dict(state_dict) 918 | return model 919 | 920 | 921 | @register_model 922 | def randformer_s12(pretrained=False, **kwargs): 923 | model = MetaFormer( 924 | depths=[2, 2, 6, 2], 925 | dims=[64, 128, 320, 512], 926 | token_mixers=[ 927 | nn.Identity, 928 | nn.Identity, 929 | RandomMixing, 930 | partial(RandomMixing, num_tokens=49), 931 | ], 932 | norm_layers=partial( 933 | LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False 934 | ), 935 | **kwargs, 936 | ) 937 | 
model.default_cfg = default_cfgs["randformer_s12"] 938 | if pretrained: 939 | state_dict = torch.hub.load_state_dict_from_url( 940 | url=model.default_cfg["url"], map_location="cpu", check_hash=True 941 | ) 942 | model.load_state_dict(state_dict) 943 | return model 944 | 945 | 946 | @register_model 947 | def randformer_s24(pretrained=False, **kwargs): 948 | model = MetaFormer( 949 | depths=[4, 4, 12, 4], 950 | dims=[64, 128, 320, 512], 951 | token_mixers=[ 952 | nn.Identity, 953 | nn.Identity, 954 | RandomMixing, 955 | partial(RandomMixing, num_tokens=49), 956 | ], 957 | norm_layers=partial( 958 | LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False 959 | ), 960 | **kwargs, 961 | ) 962 | model.default_cfg = default_cfgs["randformer_s24"] 963 | if pretrained: 964 | state_dict = torch.hub.load_state_dict_from_url( 965 | url=model.default_cfg["url"], map_location="cpu", check_hash=True 966 | ) 967 | model.load_state_dict(state_dict) 968 | return model 969 | 970 | 971 | @register_model 972 | def randformer_s36(pretrained=False, **kwargs): 973 | model = MetaFormer( 974 | depths=[6, 6, 18, 6], 975 | dims=[64, 128, 320, 512], 976 | token_mixers=[ 977 | nn.Identity, 978 | nn.Identity, 979 | RandomMixing, 980 | partial(RandomMixing, num_tokens=49), 981 | ], 982 | norm_layers=partial( 983 | LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False 984 | ), 985 | **kwargs, 986 | ) 987 | model.default_cfg = default_cfgs["randformer_s36"] 988 | if pretrained: 989 | state_dict = torch.hub.load_state_dict_from_url( 990 | url=model.default_cfg["url"], map_location="cpu", check_hash=True 991 | ) 992 | model.load_state_dict(state_dict) 993 | return model 994 | 995 | 996 | @register_model 997 | def randformer_m36(pretrained=False, **kwargs): 998 | model = MetaFormer( 999 | depths=[6, 6, 18, 6], 1000 | dims=[96, 192, 384, 768], 1001 | token_mixers=[ 1002 | nn.Identity, 1003 | nn.Identity, 1004 | RandomMixing, 1005 | partial(RandomMixing, num_tokens=49), 1006 | ], 1007 | norm_layers=partial( 1008 | LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False 1009 | ), 1010 | **kwargs, 1011 | ) 1012 | model.default_cfg = default_cfgs["randformer_m36"] 1013 | if pretrained: 1014 | state_dict = torch.hub.load_state_dict_from_url( 1015 | url=model.default_cfg["url"], map_location="cpu", check_hash=True 1016 | ) 1017 | model.load_state_dict(state_dict) 1018 | return model 1019 | 1020 | 1021 | @register_model 1022 | def randformer_m48(pretrained=False, **kwargs): 1023 | model = MetaFormer( 1024 | depths=[8, 8, 24, 8], 1025 | dims=[96, 192, 384, 768], 1026 | token_mixers=[ 1027 | nn.Identity, 1028 | nn.Identity, 1029 | RandomMixing, 1030 | partial(RandomMixing, num_tokens=49), 1031 | ], 1032 | norm_layers=partial( 1033 | LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False 1034 | ), 1035 | **kwargs, 1036 | ) 1037 | model.default_cfg = default_cfgs["randformer_m48"] 1038 | if pretrained: 1039 | state_dict = torch.hub.load_state_dict_from_url( 1040 | url=model.default_cfg["url"], map_location="cpu", check_hash=True 1041 | ) 1042 | model.load_state_dict(state_dict) 1043 | return model 1044 | 1045 | 1046 | @register_model 1047 | def poolformerv2_s12(pretrained=False, **kwargs): 1048 | model = MetaFormer( 1049 | depths=[2, 2, 6, 2], 1050 | dims=[64, 128, 320, 512], 1051 | token_mixers=Pooling, 1052 | norm_layers=partial( 1053 | LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False 1054 | ), 1055 | **kwargs, 1056 | ) 1057 | model.default_cfg = default_cfgs["poolformerv2_s12"] 
1058 | if pretrained: 1059 | state_dict = torch.hub.load_state_dict_from_url( 1060 | url=model.default_cfg["url"], map_location="cpu", check_hash=True 1061 | ) 1062 | model.load_state_dict(state_dict) 1063 | return model 1064 | 1065 | 1066 | @register_model 1067 | def poolformerv2_s24(pretrained=False, **kwargs): 1068 | model = MetaFormer( 1069 | depths=[4, 4, 12, 4], 1070 | dims=[64, 128, 320, 512], 1071 | token_mixers=Pooling, 1072 | norm_layers=partial( 1073 | LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False 1074 | ), 1075 | **kwargs, 1076 | ) 1077 | model.default_cfg = default_cfgs["poolformerv2_s24"] 1078 | if pretrained: 1079 | state_dict = torch.hub.load_state_dict_from_url( 1080 | url=model.default_cfg["url"], map_location="cpu", check_hash=True 1081 | ) 1082 | model.load_state_dict(state_dict) 1083 | return model 1084 | 1085 | 1086 | @register_model 1087 | def poolformerv2_s36(pretrained=False, **kwargs): 1088 | model = MetaFormer( 1089 | depths=[6, 6, 18, 6], 1090 | dims=[64, 128, 320, 512], 1091 | token_mixers=Pooling, 1092 | norm_layers=partial( 1093 | LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False 1094 | ), 1095 | **kwargs, 1096 | ) 1097 | model.default_cfg = default_cfgs["poolformerv2_s36"] 1098 | if pretrained: 1099 | state_dict = torch.hub.load_state_dict_from_url( 1100 | url=model.default_cfg["url"], map_location="cpu", check_hash=True 1101 | ) 1102 | model.load_state_dict(state_dict) 1103 | return model 1104 | 1105 | 1106 | @register_model 1107 | def poolformerv2_m36(pretrained=False, **kwargs): 1108 | model = MetaFormer( 1109 | depths=[6, 6, 18, 6], 1110 | dims=[96, 192, 384, 768], 1111 | token_mixers=Pooling, 1112 | norm_layers=partial( 1113 | LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False 1114 | ), 1115 | **kwargs, 1116 | ) 1117 | model.default_cfg = default_cfgs["poolformerv2_m36"] 1118 | if pretrained: 1119 | state_dict = torch.hub.load_state_dict_from_url( 1120 | url=model.default_cfg["url"], map_location="cpu", check_hash=True 1121 | ) 1122 | model.load_state_dict(state_dict) 1123 | return model 1124 | 1125 | 1126 | @register_model 1127 | def poolformerv2_m48(pretrained=False, **kwargs): 1128 | model = MetaFormer( 1129 | depths=[8, 8, 24, 8], 1130 | dims=[96, 192, 384, 768], 1131 | token_mixers=Pooling, 1132 | norm_layers=partial( 1133 | LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False 1134 | ), 1135 | **kwargs, 1136 | ) 1137 | model.default_cfg = default_cfgs["poolformerv2_m48"] 1138 | if pretrained: 1139 | state_dict = torch.hub.load_state_dict_from_url( 1140 | url=model.default_cfg["url"], map_location="cpu", check_hash=True 1141 | ) 1142 | model.load_state_dict(state_dict) 1143 | return model 1144 | 1145 | 1146 | @register_model 1147 | def convformer_s18(pretrained=False, **kwargs): 1148 | model = MetaFormer( 1149 | depths=[3, 3, 9, 3], 1150 | dims=[64, 128, 320, 512], 1151 | token_mixers=SepConv, 1152 | head_fn=MlpHead, 1153 | **kwargs, 1154 | ) 1155 | model.default_cfg = default_cfgs["convformer_s18"] 1156 | if pretrained: 1157 | state_dict = torch.hub.load_state_dict_from_url( 1158 | url=model.default_cfg["url"], map_location="cpu", check_hash=True 1159 | ) 1160 | model.load_state_dict(state_dict) 1161 | return model 1162 | 1163 | 1164 | @register_model 1165 | def convformer_s18_384(pretrained=False, **kwargs): 1166 | model = MetaFormer( 1167 | depths=[3, 3, 9, 3], 1168 | dims=[64, 128, 320, 512], 1169 | token_mixers=SepConv, 1170 | head_fn=MlpHead, 1171 | **kwargs, 1172 | ) 1173 | 
model.default_cfg = default_cfgs["convformer_s18_384"] 1174 | if pretrained: 1175 | state_dict = torch.hub.load_state_dict_from_url( 1176 | url=model.default_cfg["url"], map_location="cpu", check_hash=True 1177 | ) 1178 | model.load_state_dict(state_dict) 1179 | return model 1180 | 1181 | 1182 | @register_model 1183 | def convformer_s18_in21ft1k(pretrained=False, **kwargs): 1184 | model = MetaFormer( 1185 | depths=[3, 3, 9, 3], 1186 | dims=[64, 128, 320, 512], 1187 | token_mixers=SepConv, 1188 | head_fn=MlpHead, 1189 | **kwargs, 1190 | ) 1191 | model.default_cfg = default_cfgs["convformer_s18_in21ft1k"] 1192 | if pretrained: 1193 | state_dict = torch.hub.load_state_dict_from_url( 1194 | url=model.default_cfg["url"], map_location="cpu", check_hash=True 1195 | ) 1196 | model.load_state_dict(state_dict) 1197 | return model 1198 | 1199 | 1200 | @register_model 1201 | def convformer_s18_384_in21ft1k(pretrained=False, **kwargs): 1202 | model = MetaFormer( 1203 | depths=[3, 3, 9, 3], 1204 | dims=[64, 128, 320, 512], 1205 | token_mixers=SepConv, 1206 | head_fn=MlpHead, 1207 | **kwargs, 1208 | ) 1209 | model.default_cfg = default_cfgs["convformer_s18_384_in21ft1k"] 1210 | if pretrained: 1211 | state_dict = torch.hub.load_state_dict_from_url( 1212 | url=model.default_cfg["url"], map_location="cpu", check_hash=True 1213 | ) 1214 | model.load_state_dict(state_dict) 1215 | return model 1216 | 1217 | 1218 | @register_model 1219 | def convformer_s18_in21k(pretrained=False, **kwargs): 1220 | model = MetaFormer( 1221 | depths=[3, 3, 9, 3], 1222 | dims=[64, 128, 320, 512], 1223 | token_mixers=SepConv, 1224 | head_fn=MlpHead, 1225 | **kwargs, 1226 | ) 1227 | model.default_cfg = default_cfgs["convformer_s18_in21k"] 1228 | if pretrained: 1229 | state_dict = torch.hub.load_state_dict_from_url( 1230 | url=model.default_cfg["url"], map_location="cpu", check_hash=True 1231 | ) 1232 | model.load_state_dict(state_dict) 1233 | return model 1234 | 1235 | 1236 | @register_model 1237 | def convformer_s36(pretrained=False, **kwargs): 1238 | model = MetaFormer( 1239 | depths=[3, 12, 18, 3], 1240 | dims=[64, 128, 320, 512], 1241 | token_mixers=SepConv, 1242 | head_fn=MlpHead, 1243 | **kwargs, 1244 | ) 1245 | model.default_cfg = default_cfgs["convformer_s36"] 1246 | if pretrained: 1247 | state_dict = torch.hub.load_state_dict_from_url( 1248 | url=model.default_cfg["url"], map_location="cpu", check_hash=True 1249 | ) 1250 | model.load_state_dict(state_dict) 1251 | return model 1252 | 1253 | 1254 | @register_model 1255 | def convformer_s36_384(pretrained=False, **kwargs): 1256 | model = MetaFormer( 1257 | depths=[3, 12, 18, 3], 1258 | dims=[64, 128, 320, 512], 1259 | token_mixers=SepConv, 1260 | head_fn=MlpHead, 1261 | **kwargs, 1262 | ) 1263 | model.default_cfg = default_cfgs["convformer_s36_384"] 1264 | if pretrained: 1265 | state_dict = torch.hub.load_state_dict_from_url( 1266 | url=model.default_cfg["url"], map_location="cpu", check_hash=True 1267 | ) 1268 | model.load_state_dict(state_dict) 1269 | return model 1270 | 1271 | 1272 | @register_model 1273 | def convformer_s36_in21ft1k(pretrained=False, **kwargs): 1274 | model = MetaFormer( 1275 | depths=[3, 12, 18, 3], 1276 | dims=[64, 128, 320, 512], 1277 | token_mixers=SepConv, 1278 | head_fn=MlpHead, 1279 | **kwargs, 1280 | ) 1281 | model.default_cfg = default_cfgs["convformer_s36_in21ft1k"] 1282 | if pretrained: 1283 | state_dict = torch.hub.load_state_dict_from_url( 1284 | url=model.default_cfg["url"], map_location="cpu", check_hash=True 1285 | ) 1286 | 
model.load_state_dict(state_dict) 1287 | return model 1288 | 1289 | 1290 | @register_model 1291 | def convformer_s36_384_in21ft1k(pretrained=False, **kwargs): 1292 | model = MetaFormer( 1293 | depths=[3, 12, 18, 3], 1294 | dims=[64, 128, 320, 512], 1295 | token_mixers=SepConv, 1296 | head_fn=MlpHead, 1297 | **kwargs, 1298 | ) 1299 | model.default_cfg = default_cfgs["convformer_s36_384_in21ft1k"] 1300 | if pretrained: 1301 | state_dict = torch.hub.load_state_dict_from_url( 1302 | url=model.default_cfg["url"], map_location="cpu", check_hash=True 1303 | ) 1304 | model.load_state_dict(state_dict) 1305 | return model 1306 | 1307 | 1308 | @register_model 1309 | def convformer_s36_in21k(pretrained=False, **kwargs): 1310 | model = MetaFormer( 1311 | depths=[3, 12, 18, 3], 1312 | dims=[64, 128, 320, 512], 1313 | token_mixers=SepConv, 1314 | head_fn=MlpHead, 1315 | **kwargs, 1316 | ) 1317 | model.default_cfg = default_cfgs["convformer_s36_in21k"] 1318 | if pretrained: 1319 | state_dict = torch.hub.load_state_dict_from_url( 1320 | url=model.default_cfg["url"], map_location="cpu", check_hash=True 1321 | ) 1322 | model.load_state_dict(state_dict) 1323 | return model 1324 | 1325 | 1326 | @register_model 1327 | def convformer_m36(pretrained=False, **kwargs): 1328 | model = MetaFormer( 1329 | depths=[3, 12, 18, 3], 1330 | dims=[96, 192, 384, 576], 1331 | token_mixers=SepConv, 1332 | head_fn=MlpHead, 1333 | **kwargs, 1334 | ) 1335 | model.default_cfg = default_cfgs["convformer_m36"] 1336 | if pretrained: 1337 | state_dict = torch.hub.load_state_dict_from_url( 1338 | url=model.default_cfg["url"], map_location="cpu", check_hash=True 1339 | ) 1340 | model.load_state_dict(state_dict) 1341 | return model 1342 | 1343 | 1344 | @register_model 1345 | def convformer_m36_384(pretrained=False, **kwargs): 1346 | model = MetaFormer( 1347 | depths=[3, 12, 18, 3], 1348 | dims=[96, 192, 384, 576], 1349 | token_mixers=SepConv, 1350 | head_fn=MlpHead, 1351 | **kwargs, 1352 | ) 1353 | model.default_cfg = default_cfgs["convformer_m36_384"] 1354 | if pretrained: 1355 | state_dict = torch.hub.load_state_dict_from_url( 1356 | url=model.default_cfg["url"], map_location="cpu", check_hash=True 1357 | ) 1358 | model.load_state_dict(state_dict) 1359 | return model 1360 | 1361 | 1362 | @register_model 1363 | def convformer_m36_in21ft1k(pretrained=False, **kwargs): 1364 | model = MetaFormer( 1365 | depths=[3, 12, 18, 3], 1366 | dims=[96, 192, 384, 576], 1367 | token_mixers=SepConv, 1368 | head_fn=MlpHead, 1369 | **kwargs, 1370 | ) 1371 | model.default_cfg = default_cfgs["convformer_m36_in21ft1k"] 1372 | if pretrained: 1373 | state_dict = torch.hub.load_state_dict_from_url( 1374 | url=model.default_cfg["url"], map_location="cpu", check_hash=True 1375 | ) 1376 | model.load_state_dict(state_dict) 1377 | return model 1378 | 1379 | 1380 | @register_model 1381 | def convformer_m36_384_in21ft1k(pretrained=False, **kwargs): 1382 | model = MetaFormer( 1383 | depths=[3, 12, 18, 3], 1384 | dims=[96, 192, 384, 576], 1385 | token_mixers=SepConv, 1386 | head_fn=MlpHead, 1387 | **kwargs, 1388 | ) 1389 | model.default_cfg = default_cfgs["convformer_m36_384_in21ft1k"] 1390 | if pretrained: 1391 | state_dict = torch.hub.load_state_dict_from_url( 1392 | url=model.default_cfg["url"], map_location="cpu", check_hash=True 1393 | ) 1394 | model.load_state_dict(state_dict) 1395 | return model 1396 | 1397 | 1398 | @register_model 1399 | def convformer_m36_in21k(pretrained=False, **kwargs): 1400 | model = MetaFormer( 1401 | depths=[3, 12, 18, 3], 1402 | 
dims=[96, 192, 384, 576], 1403 | token_mixers=SepConv, 1404 | head_fn=MlpHead, 1405 | **kwargs, 1406 | ) 1407 | model.default_cfg = default_cfgs["convformer_m36_in21k"] 1408 | if pretrained: 1409 | state_dict = torch.hub.load_state_dict_from_url( 1410 | url=model.default_cfg["url"], map_location="cpu", check_hash=True 1411 | ) 1412 | model.load_state_dict(state_dict) 1413 | return model 1414 | 1415 | 1416 | @register_model 1417 | def convformer_b36(pretrained=False, **kwargs): 1418 | model = MetaFormer( 1419 | depths=[3, 12, 18, 3], 1420 | dims=[128, 256, 512, 768], 1421 | token_mixers=SepConv, 1422 | head_fn=MlpHead, 1423 | **kwargs, 1424 | ) 1425 | model.default_cfg = default_cfgs["convformer_b36"] 1426 | if pretrained: 1427 | state_dict = torch.hub.load_state_dict_from_url( 1428 | url=model.default_cfg["url"], map_location="cpu", check_hash=True 1429 | ) 1430 | model.load_state_dict(state_dict) 1431 | return model 1432 | 1433 | 1434 | @register_model 1435 | def convformer_b36_384(pretrained=False, **kwargs): 1436 | model = MetaFormer( 1437 | depths=[3, 12, 18, 3], 1438 | dims=[128, 256, 512, 768], 1439 | token_mixers=SepConv, 1440 | head_fn=MlpHead, 1441 | **kwargs, 1442 | ) 1443 | model.default_cfg = default_cfgs["convformer_b36_384"] 1444 | if pretrained: 1445 | state_dict = torch.hub.load_state_dict_from_url( 1446 | url=model.default_cfg["url"], map_location="cpu", check_hash=True 1447 | ) 1448 | model.load_state_dict(state_dict) 1449 | return model 1450 | 1451 | 1452 | @register_model 1453 | def convformer_b36_in21ft1k(pretrained=False, **kwargs): 1454 | model = MetaFormer( 1455 | depths=[3, 12, 18, 3], 1456 | dims=[128, 256, 512, 768], 1457 | token_mixers=SepConv, 1458 | head_fn=MlpHead, 1459 | **kwargs, 1460 | ) 1461 | model.default_cfg = default_cfgs["convformer_b36_in21ft1k"] 1462 | if pretrained: 1463 | state_dict = torch.hub.load_state_dict_from_url( 1464 | url=model.default_cfg["url"], map_location="cpu", check_hash=True 1465 | ) 1466 | model.load_state_dict(state_dict) 1467 | return model 1468 | 1469 | 1470 | @register_model 1471 | def convformer_b36_384_in21ft1k(pretrained=False, **kwargs): 1472 | model = MetaFormer( 1473 | depths=[3, 12, 18, 3], 1474 | dims=[128, 256, 512, 768], 1475 | token_mixers=SepConv, 1476 | head_fn=MlpHead, 1477 | **kwargs, 1478 | ) 1479 | model.default_cfg = default_cfgs["convformer_b36_384_in21ft1k"] 1480 | if pretrained: 1481 | state_dict = torch.hub.load_state_dict_from_url( 1482 | url=model.default_cfg["url"], map_location="cpu", check_hash=True 1483 | ) 1484 | model.load_state_dict(state_dict) 1485 | return model 1486 | 1487 | 1488 | @register_model 1489 | def convformer_b36_in21k(pretrained=False, **kwargs): 1490 | model = MetaFormer( 1491 | depths=[3, 12, 18, 3], 1492 | dims=[128, 256, 512, 768], 1493 | token_mixers=SepConv, 1494 | head_fn=MlpHead, 1495 | **kwargs, 1496 | ) 1497 | model.default_cfg = default_cfgs["convformer_b36_in21k"] 1498 | if pretrained: 1499 | state_dict = torch.hub.load_state_dict_from_url( 1500 | url=model.default_cfg["url"], map_location="cpu", check_hash=True 1501 | ) 1502 | model.load_state_dict(state_dict) 1503 | return model 1504 | 1505 | 1506 | @register_model 1507 | def caformer_s18(pretrained=False, **kwargs): 1508 | model = MetaFormer( 1509 | depths=[3, 3, 9, 3], 1510 | dims=[64, 128, 320, 512], 1511 | token_mixers=[SepConv, SepConv, Attention, Attention], 1512 | head_fn=MlpHead, 1513 | **kwargs, 1514 | ) 1515 | model.default_cfg = default_cfgs["caformer_s18"] 1516 | if pretrained: 1517 | state_dict = 
torch.hub.load_state_dict_from_url( 1518 | url=model.default_cfg["url"], map_location="cpu", check_hash=True 1519 | ) 1520 | model.load_state_dict(state_dict) 1521 | return model 1522 | 1523 | 1524 | @register_model 1525 | def caformer_s18_384(pretrained=False, **kwargs): 1526 | model = MetaFormer( 1527 | depths=[3, 3, 9, 3], 1528 | dims=[64, 128, 320, 512], 1529 | token_mixers=[SepConv, SepConv, Attention, Attention], 1530 | head_fn=MlpHead, 1531 | **kwargs, 1532 | ) 1533 | model.default_cfg = default_cfgs["caformer_s18_384"] 1534 | if pretrained: 1535 | state_dict = torch.hub.load_state_dict_from_url( 1536 | url=model.default_cfg["url"], map_location="cpu", check_hash=True 1537 | ) 1538 | model.load_state_dict(state_dict) 1539 | return model 1540 | 1541 | 1542 | @register_model 1543 | def caformer_s18_in21ft1k(pretrained=False, **kwargs): 1544 | model = MetaFormer( 1545 | depths=[3, 3, 9, 3], 1546 | dims=[64, 128, 320, 512], 1547 | token_mixers=[SepConv, SepConv, Attention, Attention], 1548 | head_fn=MlpHead, 1549 | **kwargs, 1550 | ) 1551 | model.default_cfg = default_cfgs["caformer_s18_in21ft1k"] 1552 | if pretrained: 1553 | state_dict = torch.hub.load_state_dict_from_url( 1554 | url=model.default_cfg["url"], map_location="cpu", check_hash=True 1555 | ) 1556 | model.load_state_dict(state_dict) 1557 | return model 1558 | 1559 | 1560 | @register_model 1561 | def caformer_s18_384_in21ft1k(pretrained=False, **kwargs): 1562 | model = MetaFormer( 1563 | depths=[3, 3, 9, 3], 1564 | dims=[64, 128, 320, 512], 1565 | token_mixers=[SepConv, SepConv, Attention, Attention], 1566 | head_fn=MlpHead, 1567 | **kwargs, 1568 | ) 1569 | model.default_cfg = default_cfgs["caformer_s18_384_in21ft1k"] 1570 | if pretrained: 1571 | state_dict = torch.hub.load_state_dict_from_url( 1572 | url=model.default_cfg["url"], map_location="cpu", check_hash=True 1573 | ) 1574 | model.load_state_dict(state_dict) 1575 | return model 1576 | 1577 | 1578 | @register_model 1579 | def caformer_s18_in21k(pretrained=False, **kwargs): 1580 | model = MetaFormer( 1581 | depths=[3, 3, 9, 3], 1582 | dims=[64, 128, 320, 512], 1583 | token_mixers=[SepConv, SepConv, Attention, Attention], 1584 | head_fn=MlpHead, 1585 | **kwargs, 1586 | ) 1587 | model.default_cfg = default_cfgs["caformer_s18_in21k"] 1588 | if pretrained: 1589 | state_dict = torch.hub.load_state_dict_from_url( 1590 | url=model.default_cfg["url"], map_location="cpu", check_hash=True 1591 | ) 1592 | model.load_state_dict(state_dict) 1593 | return model 1594 | 1595 | 1596 | @register_model 1597 | def caformer_s36(pretrained=False, **kwargs): 1598 | model = MetaFormer( 1599 | depths=[3, 12, 18, 3], 1600 | dims=[64, 128, 320, 512], 1601 | token_mixers=[SepConv, SepConv, Attention, Attention], 1602 | head_fn=MlpHead, 1603 | **kwargs, 1604 | ) 1605 | model.default_cfg = default_cfgs["caformer_s36"] 1606 | if pretrained: 1607 | state_dict = torch.hub.load_state_dict_from_url( 1608 | url=model.default_cfg["url"], map_location="cpu", check_hash=True 1609 | ) 1610 | model.load_state_dict(state_dict) 1611 | return model 1612 | 1613 | 1614 | @register_model 1615 | def caformer_s36_384(pretrained=False, **kwargs): 1616 | model = MetaFormer( 1617 | depths=[3, 12, 18, 3], 1618 | dims=[64, 128, 320, 512], 1619 | token_mixers=[SepConv, SepConv, Attention, Attention], 1620 | head_fn=MlpHead, 1621 | **kwargs, 1622 | ) 1623 | model.default_cfg = default_cfgs["caformer_s36_384"] 1624 | if pretrained: 1625 | state_dict = torch.hub.load_state_dict_from_url( 1626 | url=model.default_cfg["url"], 
map_location="cpu", check_hash=True 1627 | ) 1628 | model.load_state_dict(state_dict) 1629 | return model 1630 | 1631 | 1632 | @register_model 1633 | def caformer_s36_in21ft1k(pretrained=False, **kwargs): 1634 | model = MetaFormer( 1635 | depths=[3, 12, 18, 3], 1636 | dims=[64, 128, 320, 512], 1637 | token_mixers=[SepConv, SepConv, Attention, Attention], 1638 | head_fn=MlpHead, 1639 | **kwargs, 1640 | ) 1641 | model.default_cfg = default_cfgs["caformer_s36_in21ft1k"] 1642 | if pretrained: 1643 | state_dict = torch.hub.load_state_dict_from_url( 1644 | url=model.default_cfg["url"], map_location="cpu", check_hash=True 1645 | ) 1646 | model.load_state_dict(state_dict) 1647 | return model 1648 | 1649 | 1650 | @register_model 1651 | def caformer_s36_384_in21ft1k(pretrained=False, **kwargs): 1652 | model = MetaFormer( 1653 | depths=[3, 12, 18, 3], 1654 | dims=[64, 128, 320, 512], 1655 | token_mixers=[SepConv, SepConv, Attention, Attention], 1656 | head_fn=MlpHead, 1657 | **kwargs, 1658 | ) 1659 | model.default_cfg = default_cfgs["caformer_s36_384_in21ft1k"] 1660 | if pretrained: 1661 | state_dict = torch.hub.load_state_dict_from_url( 1662 | url=model.default_cfg["url"], map_location="cpu", check_hash=True 1663 | ) 1664 | model.load_state_dict(state_dict) 1665 | return model 1666 | 1667 | 1668 | @register_model 1669 | def caformer_s36_in21k(pretrained=False, **kwargs): 1670 | model = MetaFormer( 1671 | depths=[3, 12, 18, 3], 1672 | dims=[64, 128, 320, 512], 1673 | token_mixers=[SepConv, SepConv, Attention, Attention], 1674 | head_fn=MlpHead, 1675 | **kwargs, 1676 | ) 1677 | model.default_cfg = default_cfgs["caformer_s36_in21k"] 1678 | if pretrained: 1679 | state_dict = torch.hub.load_state_dict_from_url( 1680 | url=model.default_cfg["url"], map_location="cpu", check_hash=True 1681 | ) 1682 | model.load_state_dict(state_dict) 1683 | return model 1684 | 1685 | 1686 | @register_model 1687 | def caformer_m36(pretrained=False, **kwargs): 1688 | model = MetaFormer( 1689 | depths=[3, 12, 18, 3], 1690 | dims=[96, 192, 384, 576], 1691 | token_mixers=[SepConv, SepConv, Attention, Attention], 1692 | head_fn=MlpHead, 1693 | **kwargs, 1694 | ) 1695 | model.default_cfg = default_cfgs["caformer_m36"] 1696 | if pretrained: 1697 | state_dict = torch.hub.load_state_dict_from_url( 1698 | url=model.default_cfg["url"], map_location="cpu", check_hash=True 1699 | ) 1700 | model.load_state_dict(state_dict) 1701 | return model 1702 | 1703 | 1704 | @register_model 1705 | def caformer_m36_384(pretrained=False, **kwargs): 1706 | model = MetaFormer( 1707 | depths=[3, 12, 18, 3], 1708 | dims=[96, 192, 384, 576], 1709 | token_mixers=[SepConv, SepConv, Attention, Attention], 1710 | head_fn=MlpHead, 1711 | **kwargs, 1712 | ) 1713 | model.default_cfg = default_cfgs["caformer_m36_384"] 1714 | if pretrained: 1715 | state_dict = torch.hub.load_state_dict_from_url( 1716 | url=model.default_cfg["url"], map_location="cpu", check_hash=True 1717 | ) 1718 | model.load_state_dict(state_dict) 1719 | return model 1720 | 1721 | 1722 | @register_model 1723 | def caformer_m36_in21ft1k(pretrained=False, **kwargs): 1724 | model = MetaFormer( 1725 | depths=[3, 12, 18, 3], 1726 | dims=[96, 192, 384, 576], 1727 | token_mixers=[SepConv, SepConv, Attention, Attention], 1728 | head_fn=MlpHead, 1729 | **kwargs, 1730 | ) 1731 | model.default_cfg = default_cfgs["caformer_m36_in21ft1k"] 1732 | if pretrained: 1733 | state_dict = torch.hub.load_state_dict_from_url( 1734 | url=model.default_cfg["url"], map_location="cpu", check_hash=True 1735 | ) 1736 | 
model.load_state_dict(state_dict) 1737 | return model 1738 | 1739 | 1740 | @register_model 1741 | def caformer_m36_384_in21ft1k(pretrained=False, **kwargs): 1742 | model = MetaFormer( 1743 | depths=[3, 12, 18, 3], 1744 | dims=[96, 192, 384, 576], 1745 | token_mixers=[SepConv, SepConv, Attention, Attention], 1746 | head_fn=MlpHead, 1747 | **kwargs, 1748 | ) 1749 | model.default_cfg = default_cfgs["caformer_m36_384_in21ft1k"] 1750 | if pretrained: 1751 | state_dict = torch.hub.load_state_dict_from_url( 1752 | url=model.default_cfg["url"], map_location="cpu", check_hash=True 1753 | ) 1754 | model.load_state_dict(state_dict) 1755 | return model 1756 | 1757 | 1758 | @register_model 1759 | def caformer_m36_in21k(pretrained=False, **kwargs): 1760 | model = MetaFormer( 1761 | depths=[3, 12, 18, 3], 1762 | dims=[96, 192, 384, 576], 1763 | token_mixers=[SepConv, SepConv, Attention, Attention], 1764 | head_fn=MlpHead, 1765 | **kwargs, 1766 | ) 1767 | model.default_cfg = default_cfgs["caformer_m36_in21k"] 1768 | if pretrained: 1769 | state_dict = torch.hub.load_state_dict_from_url( 1770 | url=model.default_cfg["url"], map_location="cpu", check_hash=True 1771 | ) 1772 | model.load_state_dict(state_dict) 1773 | return model 1774 | 1775 | 1776 | @register_model 1777 | def caformer_b36(pretrained=False, **kwargs): 1778 | model = MetaFormer( 1779 | depths=[3, 12, 18, 3], 1780 | dims=[128, 256, 512, 768], 1781 | token_mixers=[SepConv, SepConv, Attention, Attention], 1782 | head_fn=MlpHead, 1783 | **kwargs, 1784 | ) 1785 | model.default_cfg = default_cfgs["caformer_b36"] 1786 | if pretrained: 1787 | state_dict = torch.hub.load_state_dict_from_url( 1788 | url=model.default_cfg["url"], map_location="cpu", check_hash=True 1789 | ) 1790 | model.load_state_dict(state_dict) 1791 | return model 1792 | 1793 | 1794 | @register_model 1795 | def caformer_b36_384(pretrained=False, **kwargs): 1796 | model = MetaFormer( 1797 | depths=[3, 12, 18, 3], 1798 | dims=[128, 256, 512, 768], 1799 | token_mixers=[SepConv, SepConv, Attention, Attention], 1800 | head_fn=MlpHead, 1801 | **kwargs, 1802 | ) 1803 | model.default_cfg = default_cfgs["caformer_b36_384"] 1804 | if pretrained: 1805 | state_dict = torch.hub.load_state_dict_from_url( 1806 | url=model.default_cfg["url"], map_location="cpu", check_hash=True 1807 | ) 1808 | model.load_state_dict(state_dict) 1809 | return model 1810 | 1811 | 1812 | @register_model 1813 | def caformer_b36_in21ft1k(pretrained=False, **kwargs): 1814 | model = MetaFormer( 1815 | depths=[3, 12, 18, 3], 1816 | dims=[128, 256, 512, 768], 1817 | token_mixers=[SepConv, SepConv, Attention, Attention], 1818 | head_fn=MlpHead, 1819 | **kwargs, 1820 | ) 1821 | model.default_cfg = default_cfgs["caformer_b36_in21ft1k"] 1822 | if pretrained: 1823 | state_dict = torch.hub.load_state_dict_from_url( 1824 | url=model.default_cfg["url"], map_location="cpu", check_hash=True 1825 | ) 1826 | model.load_state_dict(state_dict) 1827 | return model 1828 | 1829 | 1830 | @register_model 1831 | def caformer_b36_384_in21ft1k(pretrained=False, **kwargs): 1832 | model = MetaFormer( 1833 | depths=[3, 12, 18, 3], 1834 | dims=[128, 256, 512, 768], 1835 | token_mixers=[SepConv, SepConv, Attention, Attention], 1836 | head_fn=MlpHead, 1837 | **kwargs, 1838 | ) 1839 | model.default_cfg = default_cfgs["caformer_b36_384_in21ft1k"] 1840 | if pretrained: 1841 | state_dict = torch.hub.load_state_dict_from_url( 1842 | url=model.default_cfg["url"], map_location="cpu", check_hash=True 1843 | ) 1844 | model.load_state_dict(state_dict) 1845 |
return model 1846 | 1847 | 1848 | @register_model 1849 | def caformer_b36_in21k(pretrained=False, **kwargs): 1850 | model = MetaFormer( 1851 | depths=[3, 12, 18, 3], 1852 | dims=[128, 256, 512, 768], 1853 | token_mixers=[SepConv, SepConv, Attention, Attention], 1854 | head_fn=MlpHead, 1855 | **kwargs, 1856 | ) 1857 | model.default_cfg = default_cfgs["caformer_b36_in21k"] 1858 | if pretrained: 1859 | state_dict = torch.hub.load_state_dict_from_url( 1860 | url=model.default_cfg["url"], map_location="cpu", check_hash=True 1861 | ) 1862 | model.load_state_dict(state_dict) 1863 | return model 1864 | --------------------------------------------------------------------------------
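Usage note: a minimal sketch, not a file from the repository. Each factory in metaformer.py is decorated with @register_model (timm's model registry in the upstream MetaFormer code), so a checkpoint such as caformer_b36_in21ft1k, which classification/train_finetune_from_t1_to_t4.sh passes as the distillation teacher via --teacher_model, can be built by calling the factory directly. The snippet assumes torch and timm are installed and that classification/metaformer.py is importable:

    import torch
    from metaformer import caformer_b36_in21ft1k

    # pretrained=True fetches the checkpoint from default_cfg["url"] via torch.hub.
    teacher = caformer_b36_in21ft1k(pretrained=True)
    teacher.eval()

    # With the default num_classes=1000, one 224x224 image yields ImageNet-1k logits.
    with torch.no_grad():
        logits = teacher(torch.randn(1, 3, 224, 224))
    print(logits.shape)  # expected: torch.Size([1, 1000])

Calling the factory directly is equivalent to timm.create_model("caformer_b36_in21ft1k", pretrained=True) once the module has been imported and its entry points registered.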