├── figures
    ├── SemAIM.png
    └── visualization.png
├── util
    ├── lr_sched.py
    ├── crop.py
    ├── lars.py
    ├── lr_decay.py
    ├── pos_embed.py
    ├── blocks.py
    └── misc.py
├── LICENSE
├── models
    ├── models_vit.py
    ├── models_clip.py
    └── models_semaim.py
├── datasets
    └── datasets.py
├── engines
    ├── engine_finetune.py
    └── engine_pretrain.py
├── README.md
├── main_knn.py
├── main_pretrain.py
├── main_linprobe.py
└── main_finetune.py
/figures/SemAIM.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/skyoux/SemAIM/HEAD/figures/SemAIM.png
--------------------------------------------------------------------------------
/figures/visualization.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/skyoux/SemAIM/HEAD/figures/visualization.png
--------------------------------------------------------------------------------
/util/lr_sched.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | 
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | 
7 | import math
8 | 
9 | def adjust_learning_rate(optimizer, epoch, args):
10 |     """Decay the learning rate with half-cycle cosine after warmup"""
11 |     if epoch < args.warmup_epochs:
12 |         lr = args.lr * epoch / args.warmup_epochs
13 |     else:
14 |         lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \
15 |             (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs)))
16 |     for param_group in optimizer.param_groups:
17 |         if "lr_scale" in param_group:
18 |             param_group["lr"] = lr * param_group["lr_scale"]
19 |         else:
20 |             param_group["lr"] = lr
21 |     return lr
22 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2022 SenseTime
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/util/crop.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | 9 | import torch 10 | 11 | from torchvision import transforms 12 | from torchvision.transforms import functional as F 13 | 14 | 15 | class RandomResizedCrop(transforms.RandomResizedCrop): 16 | """ 17 | RandomResizedCrop for matching TF/TPU implementation: no for-loop is used. 18 | This may lead to results different with torchvision's version. 19 | Following BYOL's TF code: 20 | https://github.com/deepmind/deepmind-research/blob/master/byol/utils/dataset.py#L206 21 | """ 22 | @staticmethod 23 | def get_params(img, scale, ratio): 24 | width, height = F.get_image_size(img) 25 | area = height * width 26 | 27 | target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item() 28 | log_ratio = torch.log(torch.tensor(ratio)) 29 | aspect_ratio = torch.exp( 30 | torch.empty(1).uniform_(log_ratio[0], log_ratio[1]) 31 | ).item() 32 | 33 | w = int(round(math.sqrt(target_area * aspect_ratio))) 34 | h = int(round(math.sqrt(target_area / aspect_ratio))) 35 | 36 | w = min(w, width) 37 | h = min(h, height) 38 | 39 | i = torch.randint(0, height - h + 1, size=(1,)).item() 40 | j = torch.randint(0, width - w + 1, size=(1,)).item() 41 | 42 | return i, j, h, w -------------------------------------------------------------------------------- /util/lars.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # LARS optimizer, implementation from MoCo v3: 8 | # https://github.com/facebookresearch/moco-v3 9 | # -------------------------------------------------------- 10 | 11 | import torch 12 | 13 | 14 | class LARS(torch.optim.Optimizer): 15 | """ 16 | LARS optimizer, no rate scaling or weight decay for parameters <= 1D. 17 | """ 18 | def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, trust_coefficient=0.001): 19 | defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, trust_coefficient=trust_coefficient) 20 | super().__init__(params, defaults) 21 | 22 | @torch.no_grad() 23 | def step(self): 24 | for g in self.param_groups: 25 | for p in g['params']: 26 | dp = p.grad 27 | 28 | if dp is None: 29 | continue 30 | 31 | if p.ndim > 1: # if not normalization gamma/beta or bias 32 | dp = dp.add(p, alpha=g['weight_decay']) 33 | param_norm = torch.norm(p) 34 | update_norm = torch.norm(dp) 35 | one = torch.ones_like(param_norm) 36 | q = torch.where(param_norm > 0., 37 | torch.where(update_norm > 0, 38 | (g['trust_coefficient'] * param_norm / update_norm), one), 39 | one) 40 | dp = dp.mul(q) 41 | 42 | param_state = self.state[p] 43 | if 'mu' not in param_state: 44 | param_state['mu'] = torch.zeros_like(p) 45 | mu = param_state['mu'] 46 | mu.mul_(g['momentum']).add_(dp) 47 | p.add_(mu, alpha=-g['lr']) -------------------------------------------------------------------------------- /util/lr_decay.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # -------------------------------------------------------- 7 | # References: 8 | # ELECTRA https://github.com/google-research/electra 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | 12 | import json 13 | 14 | 15 | def param_groups_lrd(model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=.75): 16 | """ 17 | Parameter groups for layer-wise lr decay 18 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58 19 | """ 20 | param_group_names = {} 21 | param_groups = {} 22 | 23 | num_layers = len(model.blocks) + 1 24 | 25 | layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1)) 26 | 27 | for n, p in model.named_parameters(): 28 | if not p.requires_grad: 29 | continue 30 | 31 | # no decay: all 1D parameters and model specific ones 32 | if p.ndim == 1 or n in no_weight_decay_list: 33 | g_decay = "no_decay" 34 | this_decay = 0. 35 | else: 36 | g_decay = "decay" 37 | this_decay = weight_decay 38 | 39 | layer_id = get_layer_id_for_vit(n, num_layers) 40 | group_name = "layer_%d_%s" % (layer_id, g_decay) 41 | 42 | if group_name not in param_group_names: 43 | this_scale = layer_scales[layer_id] 44 | 45 | param_group_names[group_name] = { 46 | "lr_scale": this_scale, 47 | "weight_decay": this_decay, 48 | "params": [], 49 | } 50 | param_groups[group_name] = { 51 | "lr_scale": this_scale, 52 | "weight_decay": this_decay, 53 | "params": [], 54 | } 55 | 56 | param_group_names[group_name]["params"].append(n) 57 | param_groups[group_name]["params"].append(p) 58 | 59 | # print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2)) 60 | 61 | return list(param_groups.values()) 62 | 63 | 64 | def get_layer_id_for_vit(name, num_layers): 65 | """ 66 | Assign a parameter with its layer id 67 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33 68 | """ 69 | if name in ['cls_token', 'pos_embed']: 70 | return 0 71 | elif name.startswith('patch_embed'): 72 | return 0 73 | elif name.startswith('blocks'): 74 | return int(name.split('.')[1]) + 1 75 | else: 76 | return num_layers -------------------------------------------------------------------------------- /models/models_vit.py: -------------------------------------------------------------------------------- 1 | # This source code is licensed under the license found in the 2 | # LICENSE file in the root directory of this source tree. 
3 | # -------------------------------------------------------- 4 | # References: 5 | # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm 6 | # DeiT: https://github.com/facebookresearch/deit 7 | # MAE: https://github.com/facebookresearch/mae 8 | # -------------------------------------------------------- 9 | 10 | from functools import partial 11 | 12 | import torch 13 | import torch.nn as nn 14 | 15 | import timm.models.vision_transformer 16 | 17 | 18 | class VisionTransformer(timm.models.vision_transformer.VisionTransformer): 19 | """ Vision Transformer with support for global average pooling 20 | """ 21 | def __init__(self, global_pool=False, **kwargs): 22 | super(VisionTransformer, self).__init__(**kwargs) 23 | 24 | self.global_pool = global_pool 25 | if self.global_pool: 26 | norm_layer = kwargs['norm_layer'] 27 | embed_dim = kwargs['embed_dim'] 28 | self.fc_norm = norm_layer(embed_dim) 29 | 30 | del self.norm # remove the original norm 31 | 32 | def forward_features(self, x): 33 | B = x.shape[0] 34 | x = self.patch_embed(x) 35 | 36 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 37 | x = torch.cat((cls_tokens, x), dim=1) 38 | x = x + self.pos_embed 39 | x = self.pos_drop(x) 40 | 41 | for blk in self.blocks: 42 | x = blk(x) 43 | 44 | if self.global_pool: 45 | x = x[:, 1:, :].mean(dim=1) # global pool without cls token 46 | outcome = self.fc_norm(x) 47 | else: 48 | x = self.norm(x) 49 | outcome = x[:, 0] 50 | 51 | return outcome 52 | 53 | def forward_head(self, x): 54 | return self.head(x) 55 | 56 | 57 | def vit_small_patch16(**kwargs): 58 | model = VisionTransformer( 59 | patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True, 60 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 61 | return model 62 | 63 | 64 | def vit_base_patch16(**kwargs): 65 | model = VisionTransformer( 66 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 67 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 68 | return model 69 | 70 | 71 | def vit_large_patch16(**kwargs): 72 | model = VisionTransformer( 73 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, 74 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 75 | return model 76 | 77 | 78 | def vit_huge_patch14(**kwargs): 79 | model = VisionTransformer( 80 | patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, qkv_bias=True, 81 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 82 | return model -------------------------------------------------------------------------------- /datasets/datasets.py: -------------------------------------------------------------------------------- 1 | # This source code is licensed under the license found in the 2 | # LICENSE file in the root directory of this source tree. 
3 | # -------------------------------------------------------- 4 | # References: 5 | # DeiT: https://github.com/facebookresearch/deit 6 | # MAE: https://github.com/facebookresearch/mae 7 | # -------------------------------------------------------- 8 | 9 | import os 10 | import PIL 11 | import torchvision.transforms 12 | from PIL import Image 13 | 14 | from torch.utils.data import Dataset 15 | from torchvision import datasets, transforms 16 | from torchvision.datasets.folder import default_loader 17 | 18 | from timm.data import create_transform 19 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 20 | 21 | # from refile import smart_open as open 22 | # import nori2 as nori 23 | from io import BytesIO as Bytes2Data 24 | 25 | 26 | class ImageListFolder(datasets.ImageFolder): 27 | def __init__(self, root, transform=None, target_transform=None, 28 | ann_file=None, loader=default_loader): 29 | self.root = root 30 | self.transform = transform 31 | self.loader = loader 32 | self.target_transform = target_transform 33 | self.nb_classes = 1000 34 | 35 | assert ann_file is not None 36 | print('load info from', ann_file) 37 | 38 | self.samples = [] 39 | ann = open(ann_file) 40 | for elem in ann.readlines(): 41 | cut = elem.split(' ') 42 | path_current = os.path.join(root, cut[0]) 43 | target_current = int(cut[1]) 44 | self.samples.append((path_current, target_current)) 45 | ann.close() 46 | 47 | print('load finish') 48 | 49 | 50 | def build_dataset(is_train, args): 51 | transform = build_transform(is_train, args) 52 | 53 | # TODO modify your own dataset here 54 | folder = os.path.join(args.data_path, 'train' if is_train else 'val') 55 | ann_file = os.path.join(args.data_path, 'train.txt' if is_train else 'val.txt') 56 | dataset = ImageListFolder(folder, transform=transform, ann_file=ann_file) 57 | 58 | print(dataset) 59 | 60 | return dataset 61 | 62 | 63 | def build_transform(is_train, args): 64 | mean = IMAGENET_DEFAULT_MEAN 65 | std = IMAGENET_DEFAULT_STD 66 | # train transform 67 | if is_train: 68 | # this should always dispatch to transforms_imagenet_train 69 | transform = create_transform( 70 | input_size=args.input_size, 71 | is_training=True, 72 | color_jitter=args.color_jitter, 73 | auto_augment=args.aa, 74 | interpolation='bicubic', 75 | re_prob=args.reprob, 76 | re_mode=args.remode, 77 | re_count=args.recount, 78 | mean=mean, 79 | std=std, 80 | ) 81 | return transform 82 | 83 | # eval transform 84 | t = [] 85 | if args.input_size <= 224: 86 | crop_pct = 224 / 256 87 | else: 88 | crop_pct = 1.0 89 | size = int(args.input_size / crop_pct) 90 | t.append( 91 | transforms.Resize(size, interpolation=torchvision.transforms.InterpolationMode.BICUBIC), # to maintain same ratio w.r.t. 224 images 92 | ) 93 | t.append(transforms.CenterCrop(args.input_size)) 94 | 95 | t.append(transforms.ToTensor()) 96 | t.append(transforms.Normalize(mean, std)) 97 | return transforms.Compose(t) 98 | -------------------------------------------------------------------------------- /util/pos_embed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # --------------------------------------------------------
7 | # Position embedding utils
8 | # --------------------------------------------------------
9 | 
10 | import numpy as np
11 | 
12 | import torch
13 | 
14 | # --------------------------------------------------------
15 | # 2D sine-cosine position embedding
16 | # References:
17 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
18 | # MoCo v3: https://github.com/facebookresearch/moco-v3
19 | # --------------------------------------------------------
20 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
21 |     """
22 |     grid_size: int of the grid height and width
23 |     return:
24 |     pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
25 |     """
26 |     grid_h = np.arange(grid_size, dtype=np.float32)
27 |     grid_w = np.arange(grid_size, dtype=np.float32)
28 |     grid = np.meshgrid(grid_w, grid_h)  # here w goes first
29 |     grid = np.stack(grid, axis=0)
30 | 
31 |     grid = grid.reshape([2, 1, grid_size, grid_size])
32 |     pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
33 |     if cls_token:
34 |         pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
35 |     return pos_embed
36 | 
37 | 
38 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
39 |     assert embed_dim % 2 == 0
40 | 
41 |     # use half of dimensions to encode grid_h
42 |     emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
43 |     emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)
44 | 
45 |     emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
46 |     return emb
47 | 
48 | 
49 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
50 |     """
51 |     embed_dim: output dimension for each position
52 |     pos: a list of positions to be encoded: size (M,)
53 |     out: (M, D)
54 |     """
55 |     assert embed_dim % 2 == 0
56 |     omega = np.arange(embed_dim // 2, dtype=float)  # use the builtin float; np.float was removed in NumPy 1.24
57 |     omega /= embed_dim / 2.
58 |     omega = 1. 
/ 10000**omega # (D/2,) 59 | 60 | pos = pos.reshape(-1) # (M,) 61 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 62 | 63 | emb_sin = np.sin(out) # (M, D/2) 64 | emb_cos = np.cos(out) # (M, D/2) 65 | 66 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 67 | return emb 68 | 69 | 70 | # -------------------------------------------------------- 71 | # Interpolate position embeddings for high-resolution 72 | # References: 73 | # DeiT: https://github.com/facebookresearch/deit 74 | # -------------------------------------------------------- 75 | def interpolate_pos_embed(model, checkpoint_model): 76 | if 'pos_embed' in checkpoint_model: 77 | pos_embed_checkpoint = checkpoint_model['pos_embed'] 78 | embedding_size = pos_embed_checkpoint.shape[-1] 79 | num_patches = model.patch_embed.num_patches 80 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches 81 | # height (== width) for the checkpoint position embedding 82 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 83 | # height (== width) for the new position embedding 84 | new_size = int(num_patches ** 0.5) 85 | # class_token and dist_token are kept unchanged 86 | if orig_size != new_size: 87 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) 88 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 89 | # only the position tokens are interpolated 90 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 91 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 92 | pos_tokens = torch.nn.functional.interpolate( 93 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 94 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 95 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 96 | checkpoint_model['pos_embed'] = new_pos_embed 97 | -------------------------------------------------------------------------------- /engines/engine_finetune.py: -------------------------------------------------------------------------------- 1 | # References: 2 | # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm 3 | # DeiT: https://github.com/facebookresearch/deit 4 | # MAE: https://github.com/facebookresearch/mae 5 | # -------------------------------------------------------- 6 | 7 | import math 8 | import sys 9 | from typing import Iterable, Optional 10 | 11 | import torch 12 | 13 | from timm.data import Mixup 14 | from timm.utils import accuracy 15 | 16 | import util.misc as misc 17 | import util.lr_sched as lr_sched 18 | 19 | 20 | def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module, 21 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 22 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0, 23 | mixup_fn: Optional[Mixup] = None, log_writer=None, 24 | args=None): 25 | model.train(True) 26 | metric_logger = misc.MetricLogger(delimiter=" ") 27 | metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}')) 28 | header = 'Epoch: [{}/{}]'.format(epoch, args.epochs) 29 | print_freq = 20 30 | 31 | accum_iter = args.accum_iter 32 | 33 | optimizer.zero_grad() 34 | 35 | if log_writer is not None: 36 | print('log_dir: {}'.format(log_writer.log_dir)) 37 | 38 | for data_iter_step, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 39 | 40 | # we use a per iteration (instead of per epoch) lr scheduler 41 | if data_iter_step % accum_iter == 0: 42 | 
lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args) 43 | 44 | samples = samples.to(device, non_blocking=True) 45 | targets = targets.to(device, non_blocking=True) 46 | 47 | if mixup_fn is not None: 48 | samples, targets = mixup_fn(samples, targets) 49 | 50 | with torch.cuda.amp.autocast(): 51 | outputs = model(samples) 52 | loss = criterion(outputs, targets) 53 | 54 | loss_value = loss.item() 55 | 56 | if not math.isfinite(loss_value): 57 | print("Loss is {}, stopping training".format(loss_value)) 58 | sys.exit(1) 59 | 60 | loss /= accum_iter 61 | loss_scaler(loss, optimizer, clip_grad=max_norm, 62 | parameters=model.parameters(), create_graph=False, 63 | update_grad=(data_iter_step + 1) % accum_iter == 0) 64 | if (data_iter_step + 1) % accum_iter == 0: 65 | optimizer.zero_grad() 66 | 67 | torch.cuda.synchronize() 68 | 69 | metric_logger.update(loss=loss_value) 70 | min_lr = 10. 71 | max_lr = 0. 72 | for group in optimizer.param_groups: 73 | min_lr = min(min_lr, group["lr"]) 74 | max_lr = max(max_lr, group["lr"]) 75 | 76 | metric_logger.update(lr=max_lr) 77 | 78 | loss_value_reduce = misc.all_reduce_mean(loss_value) 79 | if log_writer is not None and (data_iter_step + 1) % accum_iter == 0: 80 | """ We use epoch_1000x as the x-axis in tensorboard. 81 | This calibrates different curves when batch size changes. 82 | """ 83 | epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000) 84 | log_writer.add_scalar('loss', loss_value_reduce, epoch_1000x) 85 | log_writer.add_scalar('lr', max_lr, epoch_1000x) 86 | 87 | # gather the stats from all processes 88 | metric_logger.synchronize_between_processes() 89 | print("Averaged stats:", metric_logger) 90 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 91 | 92 | 93 | @torch.no_grad() 94 | def evaluate(data_loader, model, device): 95 | criterion = torch.nn.CrossEntropyLoss() 96 | 97 | metric_logger = misc.MetricLogger(delimiter=" ") 98 | header = 'Test:' 99 | 100 | # switch to evaluation mode 101 | model.eval() 102 | 103 | for batch in metric_logger.log_every(data_loader, 10, header): 104 | images = batch[0] 105 | target = batch[-1] 106 | images = images.to(device, non_blocking=True) 107 | target = target.to(device, non_blocking=True) 108 | 109 | # compute output 110 | with torch.cuda.amp.autocast(): 111 | output = model(images) 112 | loss = criterion(output, target) 113 | 114 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 115 | 116 | batch_size = images.shape[0] 117 | metric_logger.update(loss=loss.item()) 118 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) 119 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) 120 | # gather the stats from all processes 121 | metric_logger.synchronize_between_processes() 122 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}' 123 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss)) 124 | 125 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SemAIM 2 | Official PyTorch Implementation of [Semantic-Aware Autoregressive Image Modeling for Visual Representation Learning](https://arxiv.org/abs/2312.10457), Accepted by AAAI 2024. 
3 | ## Introduction
4 | ![Pipeline](./figures/SemAIM.png)
5 | 
6 | **SemAIM** is a novel autoregressive image modeling method for self-supervised learning. The key insight of SemAIM is to autoregressively model images from the more semantic patches to the less semantic ones.
7 | 
8 | ## Main Results on ImageNet-1k
9 | 
10 | The fine-tuning accuracy (%) on ImageNet-1k is as follows:
11 | 
12 | | Models | Pretrain Epochs | ViT-B | ViT-L |
13 | | :------: | :-------------: | :---------: | :---------: |
14 | | DINO | 800 | 82.8 | - |
15 | | BEiT | 800 | 83.2 | 85.2 |
16 | | MAE | 1600 | 83.6 | 85.9 |
17 | | SimMIM | 1600 | 83.8 | - |
18 | | LocalMIM | 1600 | 84.0 | - |
19 | | HPM | 800 | 84.2 | 85.8 |
20 | | iGPT | - | 72.6 | - |
21 | | ViT-iGPT | 300 | 82.7 | - |
22 | | RandSAC | 1600 | 83.9 | - |
23 | | SAIM | 800 | 83.9 | - |
24 | | SemAIM | 400 | 83.8 | 85.5 |
25 | | SemAIM | 800 | 84.1 | 85.8 |
26 | | SemAIM* | 800 | **85.3** | **86.5** |
27 | 
28 | \* means using CLIP features as the prediction targets.
29 | 
30 | ## Getting Started
31 | 
32 | ### Install
33 | - Clone this repo:
34 | 
35 | ```bash
36 | git clone https://github.com/skyoux/SemAIM
37 | cd SemAIM
38 | ```
39 | 
40 | - Create a conda environment and activate it:
41 | ```bash
42 | conda create -n semaim python=3.9
43 | conda activate semaim
44 | ```
45 | 
46 | - Install `PyTorch==1.13.0` and `torchvision==0.14.0` with `CUDA==11.6`
47 | 
48 | ```bash
49 | conda install pytorch torchvision torchaudio pytorch-cuda=11.6 -c pytorch -c nvidia
50 | ```
51 | 
52 | - Install `timm==0.4.5`
53 | 
54 | ```bash
55 | pip install timm==0.4.5
56 | ```
57 | 
58 | ### Data preparation
59 | 
60 | You can download ImageNet-1K [here](https://image-net.org) and prepare it in the following format:
61 | ```tree data
62 | imagenet
63 | ├── train
64 | │   ├── class1
65 | │   │   ├── img1.jpeg
66 | │   │   ├── img2.jpeg
67 | │   │   └── ...
68 | │   ├── class2
69 | │   │   ├── img3.jpeg
70 | │   │   └── ...
71 | │   └── ...
72 | └── val
73 |     ├── class1
74 |     │   ├── img4.jpeg
75 |     │   ├── img5.jpeg
76 |     │   └── ...
77 |     ├── class2
78 |     │   ├── img6.jpeg
79 |     │   └── ...
80 |     └── ...
81 | ```
82 | 
83 | ### Pretrain
84 | ```shell
85 | python -m torch.distributed.launch --nproc_per_node 8 --nnodes 4 --node_rank 0 \
86 | main_pretrain.py \
87 | --batch_size 64 --epochs 800 --accum_iter 1 \
88 | --model aim_base --query_depth 12 --prediction_head_type MLP \
89 | --gaussian_kernel_size 9 --gaussian_sigma 1 --norm_pix_loss \
90 | --permutation_type attention_center --attention_type cls \
91 | --blr 2e-4 --warmup_epochs 30 --weight_decay 0.05 --clip_grad 3 \
92 | --data_path <data_path> --output_dir <output_dir> \
93 | --log_dir <log_dir> \
104 | --finetune <pretrained_checkpoint> --output_dir <output_dir> \
105 | --log_dir <log_dir> \
114 | --finetune <pretrained_checkpoint> --output_dir <output_dir> \
115 | --log_dir <log_dir> \
124 | --checkpoint_key state_dict \
125 | --data_path <data_path> \
126 | --use_cuda \
127 | ```
128 | 
129 | ## Visualization
130 | 
131 | ![SemAIM-visualization](./figures/visualization.png)
132 | 
133 | Visualization of different autoregression orders. (a) input images, (b) raster order used in iGPT, (c) stochastic order used in SAIM, (d) similarity order (the similarity map is directly used as the autoregression order), and (e) semantic-aware order used in SemAIM. In (b)-(e), the first column shows the self-attention maps from the last block, the second column shows the similarity maps from the last block, and the last column shows the corresponding autoregression orders (warmer-colored patches are predicted first).
134 | The self-attention maps and the similarity maps of the semantic-aware order used in SemAIM are localized on semantic regions more accurately than
135 | those of the other methods, which indicates that SemAIM learns more semantic representations.
136 | 
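For intuition, the snippet below is a minimal, illustrative sketch of the difference between a fixed raster order and a semantic-aware order derived from the [CLS]-token attention over patch tokens. The helper names are illustrative only and are not part of this codebase; in SemAIM the attention/similarity maps are computed in `engines/engine_pretrain.py`, and the actual permutation is generated inside the model and controlled by the `--permutation_type` and `--attention_type` options of `main_pretrain.py`.

```python
import torch

def raster_order(num_patches: int) -> torch.Tensor:
    # iGPT-style fixed order: patches are predicted from top-left to bottom-right.
    return torch.arange(num_patches)

def semantic_order(cls_attention: torch.Tensor) -> torch.Tensor:
    # cls_attention: [B, N] attention of the [CLS] token over the N patch tokens.
    # Patches with higher [CLS] attention are treated as more semantic and predicted first.
    return cls_attention.argsort(dim=-1, descending=True)

# Example: ViT-B/16 on 224x224 images has 14 * 14 = 196 patch tokens.
order = semantic_order(torch.rand(2, 196))  # [2, 196] permutation indices
```
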
137 | ## Acknowledgement
138 | 
139 | This project is based on [SAIM](https://github.com/qiy20/SAIM), [DeiT](https://github.com/facebookresearch/deit), [BEiT](https://github.com/microsoft/unilm/tree/master/beit), [MAE](https://github.com/facebookresearch/mae), and [DINO](https://github.com/facebookresearch/dino).
140 | 
141 | ## LICENSE
142 | 
143 | SemAIM is released under the [MIT License](./LICENSE).
144 | 
145 | ## Citation
146 | 
147 | ```
148 | @article{song2023semantic,
149 |   title={Semantic-Aware Autoregressive Image Modeling for Visual Representation Learning},
150 |   author={Song, Kaiyou and Zhang, Shan and Wang, Tong},
151 |   journal={arXiv preprint arXiv:2312.10457},
152 |   year={2023}
153 | }
154 | ```
155 | 
156 | 
--------------------------------------------------------------------------------
/util/blocks.py:
--------------------------------------------------------------------------------
1 | import math
2 | 
3 | import torch
4 | from torch import nn
5 | import torch.nn.functional as F
6 | from timm.models.vision_transformer import Mlp, DropPath
7 | 
8 | 
9 | class Attention_SelfMask(nn.Module):
10 |     def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
11 |         super().__init__()
12 |         self.num_heads = num_heads
13 |         head_dim = dim // num_heads
14 |         # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
15 |         self.scale = qk_scale or head_dim ** -0.5
16 | 
17 |         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
18 |         self.attn_drop = nn.Dropout(attn_drop)
19 |         self.proj = nn.Linear(dim, dim)
20 |         self.proj_drop = nn.Dropout(proj_drop)
21 | 
22 |     def forward(self, x, mask=None, return_attention=False):
23 |         B, N, C = x.shape
24 |         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
25 |         q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
26 | 
27 |         attn = (q @ k.transpose(-2, -1)) * self.scale
28 |         if mask is not None:
29 |             attn += mask
30 |         attn = attn.softmax(dim=-1)
31 |         if return_attention:
32 |             return attn  # B H N N
33 | 
34 |         attn = self.attn_drop(attn)
35 |         x = (attn @ v).transpose(1, 2).reshape(B, N, C)
36 |         x = self.proj(x)
37 |         x = self.proj_drop(x)
38 |         return x
39 | 
40 | 
41 | class Block_SelfMask(nn.Module):
42 | 
43 |     def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
44 |                  drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
45 |         super().__init__()
46 |         self.norm1 = norm_layer(dim)
47 |         self.attn = Attention_SelfMask(
48 |             dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
49 |         # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
50 |         self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 51 | self.norm2 = norm_layer(dim) 52 | mlp_hidden_dim = int(dim * mlp_ratio) 53 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 54 | 55 | def forward(self, x, mask=None, return_attention=False): 56 | if return_attention: 57 | return self.attn(self.norm1(x), mask, return_attention) 58 | x = x + self.drop_path(self.attn(self.norm1(x), mask)) 59 | x = x + self.drop_path(self.mlp(self.norm2(x))) 60 | return x 61 | 62 | 63 | class Attention_SelfCrossMask(nn.Module): 64 | def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): 65 | super().__init__() 66 | self.num_heads = num_heads 67 | head_dim = dim // num_heads 68 | self.scale = head_dim ** -0.5 69 | 70 | self.q = nn.Linear(dim, dim, bias=qkv_bias) 71 | self.k = nn.Linear(dim, dim, bias=qkv_bias) 72 | self.v = nn.Linear(dim, dim, bias=qkv_bias) 73 | self.attn_drop = nn.Dropout(attn_drop) 74 | self.proj = nn.Linear(dim, dim) 75 | self.proj_drop = nn.Dropout(proj_drop) 76 | 77 | def forward(self, q, k, v, mask=None, return_attention=False): 78 | B, N, C = q.shape 79 | # B, N_k, C = k.shape 80 | # B, N_v, C = v.shape 81 | q = self.q(q).reshape(B, N, self.num_heads, C // self.num_heads).transpose(1, 2) 82 | k = self.k(k).reshape(B, N, self.num_heads, C // self.num_heads).transpose(1, 2) 83 | v = self.v(v).reshape(B, N, self.num_heads, C // self.num_heads).transpose(1, 2) 84 | attn = (q @ k.transpose(-2, -1)) * self.scale 85 | if mask is not None: 86 | attn += mask 87 | attn = attn.softmax(dim=-1) 88 | attn = self.attn_drop(attn) 89 | if return_attention: 90 | return attn 91 | 92 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 93 | x = self.proj(x) 94 | x = self.proj_drop(x) 95 | return x 96 | 97 | 98 | class Block_SelfCrossMask(nn.Module): 99 | """ 100 | The universal attention block can be used as both self-attention and cross-attention. 101 | q,k,v can define separately. 102 | If we only assign a value to q, it's a self-attention block; 103 | if we assign values for q and k, it's a cross-attention block. 104 | """ 105 | 106 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., 107 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 108 | super().__init__() 109 | self.norm1 = norm_layer(dim) 110 | self.attn = Attention_SelfCrossMask(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) 111 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 112 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 113 | self.norm2 = norm_layer(dim) 114 | mlp_hidden_dim = int(dim * mlp_ratio) 115 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 116 | 117 | def forward(self, q, k=None, v=None, mask=None, return_attention=False): 118 | if k is None: 119 | k = q 120 | if v is None: 121 | v = k 122 | if return_attention: 123 | return self.attn(self.norm1(q), self.norm1(k), self.norm1(v), mask, return_attention) 124 | x = q + self.drop_path(self.attn(self.norm1(q), self.norm1(k), self.norm1(v), mask)) 125 | x = x + self.drop_path(self.mlp(self.norm2(x))) 126 | return x 127 | 128 | 129 | 130 | class GaussianConv2d(nn.Module): 131 | def __init__(self, channels=3, kernel_size=9, sigma=1): 132 | super().__init__() 133 | position = torch.stack(torch.meshgrid([torch.arange(kernel_size), torch.arange(kernel_size)]), dim=-1) 134 | mean = torch.tensor([(kernel_size - 1) // 2, (kernel_size - 1) // 2]) 135 | std = torch.tensor([sigma, sigma]) 136 | kernel = 1 / (2 * math.pi * torch.prod(std, dim=-1)) * math.e ** (-((position - mean) ** 2 / std ** 2).sum(-1)/2) 137 | kernel = kernel / kernel.sum() 138 | 139 | kernel = kernel.view(1, 1, kernel_size, kernel_size).repeat(channels, 1, 1, 1) 140 | 141 | self.register_buffer('weight', kernel) 142 | self.groups = channels 143 | self.padding = kernel_size // 2 144 | 145 | def forward(self, input): 146 | return F.conv2d(input, weight=self.weight, groups=self.groups, padding=self.padding) -------------------------------------------------------------------------------- /engines/engine_pretrain.py: -------------------------------------------------------------------------------- 1 | # References: 2 | # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm 3 | # DeiT: https://github.com/facebookresearch/deit 4 | # MAE: https://github.com/facebookresearch/mae 5 | # -------------------------------------------------------- 6 | import math 7 | import sys 8 | from typing import Iterable 9 | 10 | import torch 11 | import torchvision 12 | import cv2 13 | import numpy as np 14 | 15 | import util.misc as misc 16 | import util.lr_sched as lr_sched 17 | 18 | 19 | def generate_saliency(saliency_model, imgs): 20 | patch_size = 16 21 | width = imgs.shape[2] // patch_size 22 | with torch.no_grad(): 23 | # print("imgs:", imgs) 24 | # d1, d2, d3, d4, d5, d6, d7, d8 = self.saliency_model(imgs) 25 | d1, _, _, _, _, _, _, _ = saliency_model(imgs) 26 | saliency = d1[:, 0, :, :] # Bx224x224 27 | # mx, mn = torch.max(pred), torch.min(pred) 28 | # pred = (pred - mn) / (mx - mn) 29 | # print('saliency:',saliency) 30 | # print("max value:", saliency.max()) 31 | pred = torch.nn.functional.interpolate(saliency.unsqueeze(dim=1), (width, width), 32 | mode='bilinear',align_corners=True) 33 | N, _, _, _ = pred.shape 34 | pred = pred.reshape(N, -1) 35 | mx, mn = torch.max(pred, dim=-1, keepdim=True)[0], torch.min(pred, dim=-1, keepdim=True)[0] 36 | pred = (pred - mn) / (mx - mn + 1e-5) 37 | 38 | return pred, saliency.unsqueeze(dim=1) 39 | 40 | 41 | def forward_teacher_features(model, x, model_type): 42 | assert model_type in ['dino', 'clip'] 43 | if model_type == 'dino': 44 | return forward_features_dino(model, x) 45 | else: 46 | return forward_features_clip(model, x) 47 | 48 | 49 | def forward_features_dino(model, x): 50 | B = x.shape[0] 51 | x = model.patch_embed(x) 52 | 53 | cls_tokens = model.cls_token.expand(B, -1, -1) 54 | x = torch.cat((cls_tokens, x), dim=1) 55 | x = x + model.pos_embed 56 | x = model.pos_drop(x) 57 | 
58 | for blk in model.blocks: 59 | x = blk(x) 60 | 61 | x = model.norm(x) 62 | # return x[:, 1:, :] 63 | return x 64 | 65 | 66 | def forward_features_clip(model, x): 67 | x = model.conv1(x) # shape = [*, width, grid, grid] 68 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 69 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 70 | x = torch.cat( 71 | [model.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], 72 | dim=1) # shape = [*, grid ** 2 + 1, width] 73 | x = x + model.positional_embedding.to(x.dtype) 74 | x = model.ln_pre(x) 75 | 76 | x = x.permute(1, 0, 2) # NLD -> LND 77 | x = model.transformer(x) 78 | x = x.permute(1, 0, 2) # LND -> NLD 79 | 80 | # x = model.ln_post(x[:, 0, :]) 81 | x = model.ln_post(x) 82 | 83 | if model.proj is not None: 84 | x = x @ model.proj 85 | 86 | # return x[:, 1:, :] 87 | return x 88 | 89 | def calculate_similarity(tokens): 90 | tokens = torch.nn.functional.normalize(tokens, p=2, dim=-1) 91 | similarity = torch.sum(tokens[:, 0, :].unsqueeze(1) * tokens[:, 1:, :], dim=-1) 92 | 93 | mx, mn = torch.max(similarity, dim=1, keepdim=True)[0], torch.min(similarity, dim=1, keepdim=True)[0] 94 | similarity = (similarity - mn) / (mx - mn + 1e-6) 95 | 96 | return similarity 97 | 98 | def applyColorMap_on_tensor(tensor, images, alpha=0.3, norm=False, inverse=False): 99 | # tensor: B C H W 100 | heat_map = [] 101 | tensor = tensor.cpu() 102 | for i in range(tensor.shape[0]): 103 | temp_map = tensor[i] 104 | if norm: 105 | temp_map = (temp_map - temp_map.min()) / (temp_map.max() - temp_map.min() + 1e-5) 106 | if inverse: 107 | temp_map = 1 - temp_map # 这里不应该 1-的,但是显示的colormap反了,所以这样操作了 108 | temp_map = np.uint8(255 * temp_map) 109 | temp_map = cv2.applyColorMap(temp_map, cv2.COLORMAP_JET) # 0-255 110 | heat_map.append(temp_map) 111 | heat_map = torch.Tensor(np.array(heat_map)).cuda().permute(0, 3, 1, 2) 112 | heat_map = torch.clip(heat_map * alpha + images * (1 - alpha), 0, 255) 113 | return heat_map 114 | 115 | 116 | def train_one_epoch(model: torch.nn.Module, 117 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 118 | device: torch.device, epoch: int, loss_scaler, max_norm=None, 119 | log_writer=None, 120 | args=None, model_ema=None, teacher_model=None): 121 | model.train(True) 122 | metric_logger = misc.MetricLogger(delimiter=" ") 123 | metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}')) 124 | header = 'Epoch: [{}/{}]'.format(epoch, args.epochs) 125 | print_freq = 20 126 | 127 | accum_iter = args.accum_iter 128 | 129 | if args.use_ema_model: 130 | assert model_ema is not None 131 | if epoch < 100: 132 | model_ema.decay = 0.999 + epoch / 100 * (0.9999 - 0.999) 133 | else: 134 | model_ema.decay = 0.9999 135 | 136 | optimizer.zero_grad() 137 | 138 | if log_writer is not None: 139 | print('log_dir: {}'.format(log_writer.log_dir)) 140 | 141 | for data_iter_step, (samples, _) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 142 | it = len(data_loader) * epoch + data_iter_step 143 | # we use a per iteration (instead of per epoch) lr scheduler 144 | if data_iter_step % accum_iter == 0: 145 | lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args) 146 | 147 | samples = samples.to(device, non_blocking=True) 148 | 149 | enc_tokens, attention = None, None 150 | 151 | # generate attention 152 | feature_attention, self_attention = None, None 153 | if args.predict_feature == 'none' and 'attention' 
in args.permutation_type: # stochastic for debug 154 | model.eval() 155 | with torch.no_grad(): 156 | if args.use_ema_model: 157 | _, feature_attention, self_attention = model_ema.ema(samples, forward_encoder=True) 158 | else: 159 | enc_tokens, feature_attention, self_attention = model(samples, forward_encoder=True) 160 | model.train() 161 | feature_attention, self_attention = feature_attention.detach(), self_attention.detach() 162 | attention = self_attention if args.attention_type == 'self' else feature_attention 163 | 164 | with torch.cuda.amp.autocast(loss_scaler is not None): 165 | if args.predict_feature == 'inference': 166 | model.eval() 167 | with torch.no_grad(): 168 | enc_tokens, feature_attention, self_attention = model(samples, forward_encoder=True) 169 | enc_tokens = enc_tokens.detach() 170 | feature_attention, self_attention = feature_attention.detach(), self_attention.detach() 171 | model.train() 172 | attention = self_attention if args.attention_type == 'self' else feature_attention 173 | elif args.predict_feature == 'ema': 174 | if enc_tokens == None: 175 | with torch.no_grad(): 176 | enc_tokens, _, _ = model_ema.ema(samples, forward_encoder=True) 177 | enc_tokens = enc_tokens.detach() 178 | attention = self_attention if args.attention_type == 'self' else feature_attention 179 | elif args.predict_feature == 'dino': 180 | with torch.no_grad(): 181 | enc_tokens = forward_teacher_features(teacher_model, samples, 'dino') 182 | enc_tokens = enc_tokens.detach() 183 | attention = calculate_similarity(enc_tokens) 184 | feature_attention = attention 185 | elif args.predict_feature == 'clip': 186 | with torch.no_grad(): 187 | enc_tokens = forward_teacher_features(teacher_model, samples, 'clip') 188 | enc_tokens = enc_tokens.detach() 189 | attention = calculate_similarity(enc_tokens) 190 | feature_attention = attention 191 | 192 | loss, permutation, loss_map = model(samples, enc_tokens, attention) 193 | 194 | loss_value = loss.item() 195 | 196 | if not math.isfinite(loss_value): 197 | print("Loss is {}, stopping training".format(loss_value)) 198 | sys.exit(1) 199 | 200 | loss /= accum_iter 201 | 202 | if loss_scaler is None: 203 | loss.backward() 204 | if (data_iter_step + 1) % accum_iter == 0: 205 | norm = 0 206 | if max_norm is not None: 207 | norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) 208 | optimizer.step() 209 | else: 210 | norm = None 211 | else: 212 | norm = loss_scaler(loss, optimizer, parameters=model.parameters(), clip_grad=max_norm, 213 | update_grad=(data_iter_step + 1) % accum_iter == 0) 214 | fp16_scaler = loss_scaler._scaler.get_scale() 215 | 216 | if (data_iter_step + 1) % accum_iter == 0: 217 | optimizer.zero_grad() 218 | if model_ema is not None: 219 | model_ema.update(model) 220 | 221 | torch.cuda.synchronize() 222 | 223 | metric_logger.update(loss=loss_value, total_norm=norm) 224 | 225 | lr = optimizer.param_groups[0]["lr"] 226 | metric_logger.update(lr=lr) 227 | 228 | loss_value_reduce = misc.all_reduce_mean(loss_value) 229 | if log_writer is not None and (data_iter_step + 1) % accum_iter == 0: 230 | """ We use epoch_1000x as the x-axis in tensorboard. 231 | This calibrates different curves when batch size changes. 
232 | """ 233 | # epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000) 234 | log_writer.add_scalar('loss', loss_value_reduce, it) 235 | log_writer.add_scalar('lr', lr, it) 236 | log_writer.add_scalar('grad_norm', norm, it) 237 | if loss_scaler is not None: 238 | log_writer.add_scalar('fp16_scaler', fp16_scaler, it) 239 | 240 | # gather the stats from all processes 241 | metric_logger.synchronize_between_processes() 242 | print("Averaged stats:", metric_logger) 243 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} -------------------------------------------------------------------------------- /main_knn.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import os 3 | import sys 4 | import argparse 5 | import numpy as np 6 | 7 | import torch 8 | from torch import nn 9 | import torch.distributed as dist 10 | import torch.backends.cudnn as cudnn 11 | from torchvision import datasets 12 | from torchvision import transforms as pth_transforms 13 | from torchvision import models as torchvision_models 14 | import timm.models as timm_models 15 | from timm.models.layers import trunc_normal_ 16 | 17 | import util.misc as misc 18 | from util.pos_embed import interpolate_pos_embed 19 | 20 | from models import models_vit 21 | 22 | def extract_feature_pipeline(args): 23 | ######################## preparing data ... ######################## 24 | resize_size = 256 if args.input_size == 224 else 512 25 | transform = pth_transforms.Compose([ 26 | pth_transforms.Resize(resize_size, interpolation=3), 27 | pth_transforms.CenterCrop(args.input_size), 28 | pth_transforms.ToTensor(), 29 | pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 30 | ]) 31 | 32 | dataset_train = ReturnIndexDataset(os.path.join(args.data_path, 'train'), transform) 33 | dataset_val = ReturnIndexDataset(os.path.join(args.data_path, 'val'), transform) 34 | 35 | 36 | train_labels = torch.tensor(dataset_train.target).long() 37 | test_labels = torch.tensor(dataset_val.target).long() 38 | 39 | sampler = torch.utils.data.DistributedSampler(dataset_train, shuffle=False) 40 | data_loader_train = torch.utils.data.DataLoader( 41 | dataset_train, 42 | sampler=sampler, 43 | batch_size=args.batch_size_per_gpu, 44 | num_workers=args.num_workers, 45 | pin_memory=False, 46 | drop_last=False, 47 | ) 48 | data_loader_val = torch.utils.data.DataLoader( 49 | dataset_val, 50 | batch_size=args.batch_size_per_gpu, 51 | num_workers=args.num_workers, 52 | pin_memory=False, 53 | drop_last=False, 54 | ) 55 | print(f"Data loaded with {len(dataset_train)} train and {len(dataset_val)} val imgs.") 56 | 57 | ######################## building network ... 
######################## 58 | model = models_vit.__dict__[args.model]( 59 | num_classes=0, 60 | global_pool=args.global_pool, 61 | ) 62 | 63 | if args.pretrained_weights: 64 | checkpoint = torch.load(args.pretrained_weights, map_location='cpu') 65 | 66 | print("Load pre-trained checkpoint from: %s" % args.pretrained_weights) 67 | if args.checkpoint_key in checkpoint: 68 | checkpoint_model = checkpoint[args.checkpoint_key] 69 | else: 70 | print(f"There is no {args.checkpoint_key} in given checkpoints!") 71 | sys.exit(1) 72 | state_dict = model.state_dict() 73 | checkpoint_model = {k.replace("module.", ""): v for k, v in checkpoint_model.items()} 74 | for k in ['head.weight', 'head.bias']: 75 | if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape: 76 | print(f"Removing key {k} from pretrained checkpoint") 77 | del checkpoint_model[k] 78 | 79 | # interpolate position embedding 80 | interpolate_pos_embed(model, checkpoint_model) 81 | 82 | # load pre-trained model 83 | msg = model.load_state_dict(checkpoint_model, strict=False) 84 | print(msg) 85 | 86 | # print(model) 87 | model.cuda() 88 | model.eval() 89 | 90 | ######################## extract features ... ######################## 91 | print("Extracting features for train set...") 92 | train_features = extract_features(model, data_loader_train, args.model, args.avgpool_patchtokens, args.use_cuda) 93 | print("Extracting features for val set...") 94 | test_features = extract_features(model, data_loader_val, args.model, args.avgpool_patchtokens, args.use_cuda) 95 | 96 | global_rank = misc.get_rank() 97 | if global_rank == 0: 98 | train_features = nn.functional.normalize(train_features, dim=1, p=2) 99 | test_features = nn.functional.normalize(test_features, dim=1, p=2) 100 | 101 | # save features and labels 102 | if args.dump_features and dist.get_rank() == 0: 103 | torch.save(train_features.cpu(), os.path.join(args.dump_features, "trainfeat.pth")) 104 | torch.save(test_features.cpu(), os.path.join(args.dump_features, "testfeat.pth")) 105 | torch.save(train_labels.cpu(), os.path.join(args.dump_features, "trainlabels.pth")) 106 | torch.save(test_labels.cpu(), os.path.join(args.dump_features, "testlabels.pth")) 107 | return train_features, test_features, train_labels, test_labels 108 | 109 | 110 | @torch.no_grad() 111 | def extract_features(model, data_loader, arch="resnet50", avgpool_patchtokens=1, use_cuda=True): 112 | metric_logger = misc.MetricLogger(delimiter=" ") 113 | features = None 114 | for samples, index in metric_logger.log_every(data_loader, 10): 115 | samples = samples.cuda(non_blocking=True) 116 | index = index.cuda(non_blocking=True) 117 | 118 | feats = model(samples).clone() 119 | if len(feats.shape) != 2: 120 | feats = feats.squeeze() 121 | 122 | # init storage feature matrix 123 | if dist.get_rank() == 0 and features is None: 124 | features = torch.zeros(len(data_loader.dataset), feats.shape[-1]) 125 | if use_cuda: 126 | features = features.cuda(non_blocking=True) 127 | print(f"Storing features into tensor of shape {features.shape}") 128 | 129 | # get indexes from all processes 130 | y_all = torch.empty(dist.get_world_size(), index.size(0), dtype=index.dtype, device=index.device) 131 | y_l = list(y_all.unbind(0)) 132 | y_all_reduce = torch.distributed.all_gather(y_l, index, async_op=True) 133 | y_all_reduce.wait() 134 | index_all = torch.cat(y_l) 135 | 136 | # share features between processes 137 | feats_all = torch.empty( 138 | dist.get_world_size(), 139 | feats.size(0), 140 | feats.size(1), 141 | 
dtype=feats.dtype, 142 | device=feats.device, 143 | ) 144 | output_l = list(feats_all.unbind(0)) 145 | output_all_reduce = torch.distributed.all_gather(output_l, feats, async_op=True) 146 | output_all_reduce.wait() 147 | 148 | # update storage feature matrix 149 | if dist.get_rank() == 0: 150 | if use_cuda: 151 | features.index_copy_(0, index_all, torch.cat(output_l)) 152 | else: 153 | features.index_copy_(0, index_all.cpu(), torch.cat(output_l).cpu()) 154 | return features 155 | 156 | 157 | @torch.no_grad() 158 | def knn_classifier(train_features, train_labels, test_features, test_labels, k, T, num_classes=1000, use_cuda=True): 159 | top1, top5, total = 0.0, 0.0, 0 160 | train_features = train_features.t() 161 | num_test_images, num_chunks = test_labels.shape[0], 100 162 | imgs_per_chunk = num_test_images // num_chunks 163 | retrieval_one_hot = torch.zeros(k, num_classes) 164 | if use_cuda: 165 | retrieval_one_hot = retrieval_one_hot.cuda() 166 | for idx in range(0, num_test_images, imgs_per_chunk): 167 | # get the features for test images 168 | features = test_features[ 169 | idx : min((idx + imgs_per_chunk), num_test_images), : 170 | ] 171 | targets = test_labels[idx : min((idx + imgs_per_chunk), num_test_images)] 172 | batch_size = targets.shape[0] 173 | 174 | # calculate the dot product and compute top-k neighbors 175 | similarity = torch.mm(features, train_features) 176 | distances, indices = similarity.topk(k, largest=True, sorted=True) 177 | candidates = train_labels.view(1, -1).expand(batch_size, -1) #500x1281167 178 | retrieved_neighbors = torch.gather(candidates, 1, indices) # 500x10 179 | retrieval_one_hot.resize_(batch_size * k, num_classes).zero_() #5000x0 180 | retrieval_one_hot.scatter_(1, retrieved_neighbors.view(-1, 1), 1) 181 | distances_transform = distances.clone().div_(T).exp_() 182 | probs = torch.sum( 183 | torch.mul( 184 | retrieval_one_hot.view(batch_size, -1, num_classes), 185 | distances_transform.view(batch_size, -1, 1), 186 | ), 187 | 1, 188 | ) 189 | _, predictions = probs.sort(1, True) 190 | 191 | # find the predictions that match the target 192 | correct = predictions.eq(targets.data.view(-1, 1)) 193 | top1 = top1 + correct.narrow(1, 0, 1).sum().item() 194 | top5 = top5 + correct.narrow(1, 0, min(5, k)).sum().item() # top5 does not make sense if k < 5 195 | total += targets.size(0) 196 | top1 = top1 * 100.0 / total 197 | top5 = top5 * 100.0 / total 198 | return top1, top5 199 | 200 | 201 | class ReturnIndexDataset(datasets.ImageFolder): 202 | def __getitem__(self, idx): 203 | img, lab = super(ReturnIndexDataset, self).__getitem__(idx) 204 | return img, idx 205 | 206 | 207 | def get_args_parser(): 208 | parser = argparse.ArgumentParser("KNN Evaluation", add_help=False) 209 | parser.add_argument('--input_size', default=224, type=int, help='input image size') 210 | parser.add_argument('--batch_size_per_gpu', default=128, type=int, help='Per-GPU batch-size') 211 | parser.add_argument('--nb_knn', default=[20, 10, 30], nargs='+', type=int, 212 | help='Number of NN to use. 
20 is usually working the best.') 213 | parser.add_argument('--nb_classes', default=1000, type=int, help='Number of labels for linear classifier') 214 | parser.add_argument('--temperature', default=0.07, type=float, 215 | help='Temperature used in the voting coefficient') 216 | parser.add_argument('--pretrained_weights', default='', type=str, help="Path to pretrained weights to evaluate.") 217 | parser.add_argument('--use_cuda', default=False, action='store_true', 218 | help="Should we store the features on GPU? We recommend setting this to False if you encounter OOM") 219 | parser.set_defaults(norm_pix_loss=False) 220 | parser.add_argument('--model', default='vit_small', type=str, help='Architecture') 221 | parser.add_argument("--checkpoint_key", default="state_dict", type=str, 222 | help='Key to use in the checkpoint (example: "teacher")') 223 | # for ViTs 224 | parser.add_argument('--patch_size', default=16, type=int, help='Patch resolution of the model.') 225 | parser.add_argument('--avgpool_patchtokens', default=0, choices=[0, 1, 2], type=int, 226 | help="""Whether or not to use global average pooled features or the [CLS] token. 227 | We typically set this to 1 for BEiT and 0 for models with [CLS] token (e.g., DINO). 228 | we set this to 2 for base-size models with [CLS] token when doing linear classification.""") 229 | parser.add_argument('--global_pool', action='store_true') 230 | parser.set_defaults(global_pool=False) 231 | 232 | parser.add_argument('--dump_features', default=None, 233 | help='Path where to save computed features, empty for no saving') 234 | parser.add_argument('--load_features', default=None, help="""If the features have 235 | already been computed, where to find them.""") 236 | parser.add_argument('--num_workers', default=10, type=int, help='Number of data loading workers per GPU.') 237 | parser.add_argument("--dist_url", default="env://", type=str, help="""url used to set up 238 | distributed training; see https://pytorch.org/docs/stable/distributed.html""") 239 | parser.add_argument('--dist_backend', default='nccl', type=str, help='experiment name (for log)') 240 | parser.add_argument('--dist_on_itp', action='store_true') 241 | parser.add_argument("--local_rank", default=0, type=int, help="Please ignore and do not set this argument.") 242 | parser.add_argument('--data_path', default='/path/to/imagenet/', type=str) 243 | parser.add_argument('--method', default='moco', type=str, help='model name') 244 | 245 | return parser 246 | 247 | 248 | if __name__ == '__main__': 249 | parser = argparse.ArgumentParser("KNN Evaluation", parents=[get_args_parser()]) 250 | args = parser.parse_args() 251 | 252 | misc.init_distributed_mode(args) 253 | print("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items()))) 254 | cudnn.benchmark = True 255 | 256 | if args.load_features: 257 | train_features = torch.load(os.path.join(args.load_features, "trainfeat.pth")) 258 | test_features = torch.load(os.path.join(args.load_features, "testfeat.pth")) 259 | train_labels = torch.load(os.path.join(args.load_features, "trainlabels.pth")) 260 | test_labels = torch.load(os.path.join(args.load_features, "testlabels.pth")) 261 | else: 262 | # need to extract features ! 
263 | train_features, test_features, train_labels, test_labels = extract_feature_pipeline(args) 264 | 265 | if misc.get_rank() == 0: 266 | if args.use_cuda: 267 | train_features = train_features.cuda() 268 | test_features = test_features.cuda() 269 | train_labels = train_labels.cuda() 270 | test_labels = test_labels.cuda() 271 | else: 272 | train_features = train_features.cpu() 273 | test_features = test_features.cpu() 274 | train_labels = train_labels.cpu() 275 | test_labels = test_labels.cpu() 276 | 277 | print("Features are ready!\nStart the k-NN classification.") 278 | for k in args.nb_knn: 279 | top1, top5 = knn_classifier(train_features, train_labels, 280 | test_features, test_labels, k, args.temperature, args.nb_classes, args.use_cuda) 281 | print(f"{k}-NN classifier result: Top1: {top1}, Top5: {top5}") 282 | dist.barrier() 283 | -------------------------------------------------------------------------------- /util/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | 12 | import builtins 13 | import datetime 14 | import os 15 | import time 16 | from collections import defaultdict, deque 17 | from pathlib import Path 18 | 19 | import torch 20 | import torch.distributed as dist 21 | from torch._six import inf 22 | 23 | 24 | class SmoothedValue(object): 25 | """Track a series of values and provide access to smoothed values over a 26 | window or the global series average. 27 | """ 28 | 29 | def __init__(self, window_size=20, fmt=None): 30 | if fmt is None: 31 | fmt = "{median:.4f} ({global_avg:.4f})" 32 | self.deque = deque(maxlen=window_size) 33 | self.total = 0.0 34 | self.count = 0 35 | self.fmt = fmt 36 | 37 | def update(self, value, n=1): 38 | self.deque.append(value) 39 | self.count += n 40 | self.total += value * n 41 | 42 | def synchronize_between_processes(self): 43 | """ 44 | Warning: does not synchronize the deque! 
45 | """ 46 | if not is_dist_avail_and_initialized(): 47 | return 48 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 49 | dist.barrier() 50 | dist.all_reduce(t) 51 | t = t.tolist() 52 | self.count = int(t[0]) 53 | self.total = t[1] 54 | 55 | @property 56 | def median(self): 57 | d = torch.tensor(list(self.deque)) 58 | return d.median().item() 59 | 60 | @property 61 | def avg(self): 62 | d = torch.tensor(list(self.deque), dtype=torch.float32) 63 | return d.mean().item() 64 | 65 | @property 66 | def global_avg(self): 67 | return self.total / self.count 68 | 69 | @property 70 | def max(self): 71 | return max(self.deque) 72 | 73 | @property 74 | def value(self): 75 | return self.deque[-1] 76 | 77 | def __str__(self): 78 | return self.fmt.format( 79 | median=self.median, 80 | avg=self.avg, 81 | global_avg=self.global_avg, 82 | max=self.max, 83 | value=self.value) 84 | 85 | 86 | class MetricLogger(object): 87 | def __init__(self, delimiter="\t"): 88 | self.meters = defaultdict(SmoothedValue) 89 | self.delimiter = delimiter 90 | 91 | def update(self, **kwargs): 92 | for k, v in kwargs.items(): 93 | if v is None: 94 | continue 95 | if isinstance(v, torch.Tensor): 96 | v = v.item() 97 | assert isinstance(v, (float, int)) 98 | self.meters[k].update(v) 99 | 100 | def __getattr__(self, attr): 101 | if attr in self.meters: 102 | return self.meters[attr] 103 | if attr in self.__dict__: 104 | return self.__dict__[attr] 105 | raise AttributeError("'{}' object has no attribute '{}'".format( 106 | type(self).__name__, attr)) 107 | 108 | def __str__(self): 109 | loss_str = [] 110 | for name, meter in self.meters.items(): 111 | loss_str.append( 112 | "{}: {}".format(name, str(meter)) 113 | ) 114 | return self.delimiter.join(loss_str) 115 | 116 | def synchronize_between_processes(self): 117 | for meter in self.meters.values(): 118 | meter.synchronize_between_processes() 119 | 120 | def add_meter(self, name, meter): 121 | self.meters[name] = meter 122 | 123 | def log_every(self, iterable, print_freq, header=None): 124 | i = 0 125 | if not header: 126 | header = '' 127 | start_time = time.time() 128 | end = time.time() 129 | iter_time = SmoothedValue(fmt='{avg:.4f}') 130 | data_time = SmoothedValue(fmt='{avg:.4f}') 131 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 132 | log_msg = [ 133 | header, 134 | '[{0' + space_fmt + '}/{1}]', 135 | 'eta: {eta}', 136 | '{meters}', 137 | 'time: {time}', 138 | 'data: {data}' 139 | ] 140 | if torch.cuda.is_available(): 141 | log_msg.append('max mem: {memory:.0f}') 142 | log_msg = self.delimiter.join(log_msg) 143 | MB = 1024.0 * 1024.0 144 | for obj in iterable: 145 | data_time.update(time.time() - end) 146 | yield obj 147 | iter_time.update(time.time() - end) 148 | if (i % print_freq == 0 or i == len(iterable) - 1) and is_main_process(): 149 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 150 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 151 | if torch.cuda.is_available(): 152 | print(log_msg.format( 153 | i, len(iterable), eta=eta_string, 154 | meters=str(self), 155 | time=str(iter_time), data=str(data_time), 156 | memory=torch.cuda.max_memory_allocated() / MB)) 157 | else: 158 | print(log_msg.format( 159 | i, len(iterable), eta=eta_string, 160 | meters=str(self), 161 | time=str(iter_time), data=str(data_time))) 162 | i += 1 163 | end = time.time() 164 | total_time = time.time() - start_time 165 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 166 | if is_main_process(): 167 | print('{} 
Total time: {} ({:.4f} s / it)'.format( 168 | header, total_time_str, total_time / len(iterable))) 169 | 170 | 171 | def setup_for_distributed(is_master): 172 | """ 173 | This function disables printing when not in master process 174 | """ 175 | builtin_print = builtins.print 176 | 177 | def print(*args, **kwargs): 178 | force = kwargs.pop('force', False) 179 | force = force or (get_world_size() > 8) 180 | if is_master or force: 181 | now = datetime.datetime.now().time() 182 | builtin_print('[{}] '.format(now), end='') # print with time stamp 183 | builtin_print(*args, **kwargs) 184 | 185 | builtins.print = print 186 | 187 | 188 | def is_dist_avail_and_initialized(): 189 | if not dist.is_available(): 190 | return False 191 | if not dist.is_initialized(): 192 | return False 193 | return True 194 | 195 | 196 | def get_world_size(): 197 | if not is_dist_avail_and_initialized(): 198 | return 1 199 | return dist.get_world_size() 200 | 201 | 202 | def get_rank(): 203 | if not is_dist_avail_and_initialized(): 204 | return 0 205 | return dist.get_rank() 206 | 207 | 208 | def is_main_process(): 209 | return get_rank() == 0 210 | 211 | 212 | def save_on_master(*args, **kwargs): 213 | if is_main_process(): 214 | torch.save(*args, **kwargs) 215 | 216 | 217 | def init_distributed_mode(args): 218 | if args.dist_on_itp: 219 | args.rank = int(os.environ['OMPI_COMM_WORLD_RANK']) 220 | args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) 221 | args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) 222 | args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT']) 223 | os.environ['LOCAL_RANK'] = str(args.gpu) 224 | os.environ['RANK'] = str(args.rank) 225 | os.environ['WORLD_SIZE'] = str(args.world_size) 226 | # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"] 227 | elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 228 | args.rank = int(os.environ["RANK"]) 229 | args.world_size = int(os.environ['WORLD_SIZE']) 230 | args.gpu = int(os.environ['LOCAL_RANK']) 231 | elif 'SLURM_PROCID' in os.environ: 232 | import subprocess 233 | num_gpus = torch.cuda.device_count() 234 | args.rank = int(os.environ['SLURM_PROCID']) 235 | args.world_size = int(os.environ['SLURM_NTASKS']) 236 | args.gpu = args.rank % torch.cuda.device_count() 237 | node_list = os.environ["SLURM_NODELIST"] 238 | addr = subprocess.getoutput(f"scontrol show hostname {node_list} | head -n1") 239 | if "MASTER_PORT" not in os.environ: 240 | os.environ["MASTER_PORT"] = "23233" 241 | if "MASTER_ADDR" not in os.environ: 242 | os.environ["MASTER_ADDR"] = addr 243 | # args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT']) 244 | os.environ["WORLD_SIZE"] = str(args.world_size) 245 | os.environ["LOCAL_RANK"] = str(args.rank % num_gpus) 246 | os.environ["RANK"] = str(args.rank) 247 | else: 248 | print('Not using distributed mode') 249 | setup_for_distributed(is_master=True) # hack 250 | args.distributed = False 251 | return 252 | 253 | args.distributed = True 254 | 255 | torch.cuda.set_device(args.gpu) 256 | args.dist_backend = 'nccl' 257 | print('| distributed init (rank {}): {}, gpu {}'.format( 258 | args.rank, args.dist_url, args.gpu), flush=True) 259 | dist.init_process_group( 260 | backend=args.dist_backend, 261 | world_size=args.world_size, 262 | rank=args.rank, 263 | ) 264 | # torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 265 | # world_size=args.world_size, rank=args.rank) 266 | torch.distributed.barrier() 267 | 
setup_for_distributed(args.rank == 0) 268 | 269 | 270 | class NativeScalerWithGradNormCount: 271 | state_dict_key = "amp_scaler" 272 | 273 | def __init__(self): 274 | self._scaler = torch.cuda.amp.GradScaler() 275 | 276 | def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True): 277 | self._scaler.scale(loss).backward(create_graph=create_graph) 278 | if update_grad: 279 | if clip_grad is not None: 280 | assert parameters is not None 281 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place 282 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) 283 | else: 284 | self._scaler.unscale_(optimizer) 285 | norm = get_grad_norm_(parameters) 286 | self._scaler.step(optimizer) 287 | self._scaler.update() 288 | else: 289 | norm = None 290 | return norm 291 | 292 | def state_dict(self): 293 | return self._scaler.state_dict() 294 | 295 | def load_state_dict(self, state_dict): 296 | self._scaler.load_state_dict(state_dict) 297 | 298 | 299 | def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: 300 | if isinstance(parameters, torch.Tensor): 301 | parameters = [parameters] 302 | parameters = [p for p in parameters if p.grad is not None] 303 | norm_type = float(norm_type) 304 | if len(parameters) == 0: 305 | return torch.tensor(0.) 306 | device = parameters[0].grad.device 307 | if norm_type == inf: 308 | total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) 309 | else: 310 | total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type) 311 | return total_norm 312 | 313 | 314 | def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler): 315 | output_dir = Path(args.output_dir) 316 | epoch_name = str(epoch) 317 | if loss_scaler is not None: 318 | checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)] 319 | for checkpoint_path in checkpoint_paths: 320 | to_save = { 321 | 'model': model_without_ddp.state_dict(), 322 | 'optimizer': optimizer.state_dict(), 323 | 'epoch': epoch, 324 | 'loss_scaler': loss_scaler.state_dict(), 325 | 'args': args, 326 | } 327 | 328 | save_on_master(to_save, checkpoint_path) 329 | save_on_master(to_save, os.path.join(args.output_dir, "checkpoint.pth")) 330 | 331 | else: 332 | client_state = {'epoch': epoch} 333 | model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-%s" % epoch_name, client_state=client_state) 334 | 335 | 336 | def load_model(args, ckpt_path, model_without_ddp, optimizer, loss_scaler, model_ema=None): 337 | if ckpt_path.startswith('https'): 338 | checkpoint = torch.hub.load_state_dict_from_url( 339 | ckpt_path, map_location='cpu', check_hash=True) 340 | else: 341 | checkpoint = torch.load(ckpt_path, map_location='cpu') 342 | 343 | if 'state_dict' in checkpoint: 344 | model_without_ddp.load_state_dict(checkpoint['state_dict']) 345 | else: 346 | model_without_ddp.load_state_dict(checkpoint['module']) 347 | 348 | if 'ema_state_dict' in checkpoint: 349 | model_ema.load_state_dict(checkpoint['ema_state_dict']) 350 | 351 | print("Resume checkpoint %s" % ckpt_path) 352 | if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval): 353 | optimizer.load_state_dict(checkpoint['optimizer']) 354 | args.start_epoch = checkpoint['epoch'] 355 | if 'loss_scaler' in checkpoint and loss_scaler is not None: 356 | loss_scaler.load_state_dict(checkpoint['loss_scaler']) 357 | print("With optim & 
sched!") 358 | 359 | def all_reduce_mean(x): 360 | world_size = get_world_size() 361 | if world_size > 1: 362 | x_reduce = torch.tensor(x).cuda() 363 | dist.all_reduce(x_reduce) 364 | x_reduce /= world_size 365 | return x_reduce.item() 366 | else: 367 | return x 368 | -------------------------------------------------------------------------------- /main_pretrain.py: -------------------------------------------------------------------------------- 1 | # References: 2 | # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm 3 | # DeiT: https://github.com/facebookresearch/deit 4 | # MAE: https://github.com/facebookresearch/mae 5 | # -------------------------------------------------------- 6 | import argparse 7 | import datetime 8 | import json 9 | import os 10 | import time 11 | from pathlib import Path 12 | 13 | import numpy as np 14 | import timm 15 | import timm.optim.optim_factory as optim_factory 16 | from timm.utils import ModelEma 17 | import torch 18 | import torch.backends.cudnn as cudnn 19 | import torchvision.datasets as datasets 20 | import torchvision.transforms as transforms 21 | from torch.utils.tensorboard import SummaryWriter 22 | 23 | from models import models_semaim as models_aim 24 | from engines.engine_pretrain import train_one_epoch 25 | from datasets.datasets import ImagenetLoader 26 | import util.misc as misc 27 | from util.misc import NativeScalerWithGradNormCount as NativeScaler 28 | 29 | 30 | def get_args_parser(): 31 | parser = argparse.ArgumentParser('SemAIM pre-training', add_help=False) 32 | parser.add_argument('--batch_size', default=64, type=int, 33 | help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus') 34 | parser.add_argument('--epochs', default=800, type=int) 35 | parser.add_argument('--accum_iter', default=1, type=int, 36 | help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)') 37 | 38 | # Model parameters 39 | parser.add_argument('--model', default='saim_base', type=str, metavar='MODEL', 40 | help='Name of model to train') 41 | parser.add_argument('--input_size', default=224, type=int, 42 | help='images input size') 43 | 44 | parser.add_argument('--query_depth', default=12, type=int, 45 | help='decoder depth') 46 | parser.add_argument('--share_weight', action='store_true', 47 | help='Share weight between encoder and decoder') 48 | 49 | parser.add_argument('--prediction_head_type', default='MLP', type=str, 50 | help='the type of prediction head: MLP or LINEAR') 51 | parser.add_argument('--gaussian_kernel_size', default=None, type=int, 52 | help='Use gaussian blur to smooth the target image') 53 | parser.add_argument('--gaussian_sigma', default=None, type=int, 54 | help='standard deviation of gaussian blur') 55 | parser.add_argument('--loss_type', default='L2', type=str, 56 | help='Calculate loss between prediction and target per pixel: L1 or L2') 57 | parser.add_argument('--norm_pix_loss', action='store_true', 58 | help='Use (per-patch) normalized pixels as targets for computing loss') 59 | 60 | # semaim 61 | parser.add_argument('--permutation_type', default='stochastic', type=str, 62 | help='Permutation type for autoregression: zigzag, raster, stochastic, center2out, out2center, saliency,' 63 | ' attention, attention_guided, saliency_guided, stochastic_center, attention_center') 64 | parser.add_argument('--use_ema_model', action='store_true', help='Use ema features as targets for computing loss') 65 | parser.set_defaults(use_ema_model=False) 66 | 
parser.add_argument('--predict_feature', default='none', type=str, help='Use features as targets: none, inference, ema, dino, clip') 67 | # parser.set_defaults(predict_feature=False) 68 | parser.add_argument('--attention_type', default='cls', type=str, help='Attention type: gap, cls and self') 69 | 70 | # Optimizer parameters 71 | parser.add_argument('--weight_decay', type=float, default=0.05, 72 | help='weight decay (default: 0.05)') 73 | parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM', 74 | help='Clip gradient norm (default: None, no clipping)') 75 | 76 | parser.add_argument('--lr', type=float, default=None, metavar='LR', 77 | help='learning rate (absolute lr)') 78 | parser.add_argument('--blr', type=float, default=2e-4, metavar='LR', 79 | help='base learning rate: absolute_lr = base_lr * total_batch_size / 256') 80 | parser.add_argument('--min_lr', type=float, default=0., metavar='LR', 81 | help='lower lr bound for cyclic schedulers that hit 0') 82 | 83 | parser.add_argument('--warmup_epochs', type=int, default=30, metavar='N', 84 | help='epochs to warmup LR') 85 | parser.add_argument('--not_use_fp16', action='store_true', help='whether to use fp16') 86 | parser.set_defaults(not_use_fp16=False) 87 | 88 | # Dataset parameters 89 | parser.add_argument('--data_path', default='../imagenet', type=str, help='dataset path') 90 | 91 | parser.add_argument('--output_dir', default='./pretrain/saim_base', 92 | help='path where to save, empty for no saving') 93 | parser.add_argument('--log_dir', default='./output_dir', help='path where to tensorboard log') 94 | parser.add_argument('--saveckp_freq', default=20, type=int, help='Save checkpoint every x epochs.') 95 | parser.add_argument('--experiment', default='exp', type=str, help='experiment name (for log)') 96 | parser.add_argument('--device', default='cuda', 97 | help='device to use for training / testing') 98 | parser.add_argument('--seed', default=0, type=int) 99 | parser.add_argument('--resume', default='', 100 | help='resume from checkpoint') 101 | 102 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 103 | help='start epoch') 104 | parser.add_argument('--num_workers', default=4, type=int) 105 | parser.add_argument('--pin_mem', action='store_true', 106 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 107 | 108 | # distributed training parameters 109 | parser.add_argument('--world_size', default=1, type=int, 110 | help='number of distributed processes') 111 | parser.add_argument('--local_rank', default=-1, type=int) 112 | parser.add_argument('--dist_on_itp', action='store_true') 113 | parser.add_argument('--dist_url', default='env://', 114 | help='url used to set up distributed training') 115 | 116 | return parser 117 | 118 | 119 | def main(args): 120 | misc.init_distributed_mode(args) 121 | 122 | print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__)))) 123 | print("{}".format(args).replace(', ', ',\n')) 124 | 125 | device = torch.device(args.device) 126 | 127 | # fix the seed for reproducibility 128 | seed = args.seed + misc.get_rank() 129 | torch.manual_seed(seed) 130 | np.random.seed(seed) 131 | 132 | cudnn.benchmark = True 133 | 134 | # simple augmentation 135 | transform_train = transforms.Compose([ 136 | transforms.RandomResizedCrop(args.input_size, scale=(0.2, 1.0), interpolation=3), # 3 is bicubic 137 | transforms.RandomHorizontalFlip(), 138 | transforms.ToTensor(), 139 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 
0.224, 0.225])]) 140 | 141 | dataset_train = datasets.ImageFolder(os.path.join(args.data_path, 'train'), transform=transform_train) 142 | print(dataset_train) 143 | 144 | if True: # args.distributed: 145 | num_tasks = misc.get_world_size() 146 | global_rank = misc.get_rank() 147 | sampler_train = torch.utils.data.DistributedSampler( 148 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 149 | ) 150 | print("Sampler_train = %s" % str(sampler_train)) 151 | else: 152 | sampler_train = torch.utils.data.RandomSampler(dataset_train) 153 | 154 | data_loader_train = torch.utils.data.DataLoader( 155 | dataset_train, sampler=sampler_train, 156 | batch_size=args.batch_size, 157 | num_workers=args.num_workers, 158 | pin_memory=args.pin_mem, 159 | drop_last=True, 160 | ) 161 | 162 | # define the model 163 | out_dim = 512 164 | model = models_aim.__dict__[args.model](permutation_type=args.permutation_type,attention_type=args.attention_type, 165 | query_depth=args.query_depth, share_weight=args.share_weight,out_dim=out_dim, 166 | prediction_head_type=args.prediction_head_type, 167 | gaussian_kernel_size=args.gaussian_kernel_size, 168 | gaussian_sigma=args.gaussian_sigma, 169 | loss_type=args.loss_type, predict_feature=args.predict_feature, 170 | norm_pix_loss=args.norm_pix_loss) 171 | 172 | model.to(device) 173 | 174 | model_without_ddp = model 175 | if misc.is_main_process(): 176 | print("Model = %s" % str(model_without_ddp)) 177 | 178 | # define ema model 179 | model_ema = None 180 | teacher_model = None 181 | if args.use_ema_model: 182 | # if args.predict_feature == 'ema': 183 | # assert args.predict_feature == 'ema' 184 | model_ema = ModelEma(model, decay=0.999, device=args.device, resume='') 185 | elif args.predict_feature == 'dino': 186 | teacher_model = timm.models.vit_base_patch16_224(num_classes=0) 187 | state_dict = torch.load('/path_to_dino_model/dino_vitbase16_pretrain.pth') 188 | # state_dict = torch.load('/data/code/ssl/checkpoints/ssl_ckpt/ar/ibot_vitbase16_pretrain.pth') 189 | msg = teacher_model.load_state_dict(state_dict, strict=False) 190 | print("loaded dino model with msg:", msg) 191 | teacher_model.to(device) 192 | teacher_model.eval() 193 | elif args.predict_feature == 'clip': 194 | from models.models_clip import build_model 195 | state_dict = torch.load('/path_to_clip_model/clip_vitbase16_pretrain.pth', map_location='cpu') 196 | model_clip = build_model(state_dict) 197 | msg = model_clip.load_state_dict(state_dict, strict=False) 198 | print("loaded clip model with msg:", msg) 199 | model_clip.float() 200 | teacher_model = model_clip.visual 201 | teacher_model.to(device) 202 | teacher_model.eval() 203 | 204 | eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size() 205 | 206 | if args.lr is None: # only base_lr is specified 207 | args.lr = args.blr * eff_batch_size / 256 208 | 209 | print("base lr: %.2e" % (args.lr * 256 / eff_batch_size)) 210 | print("actual lr: %.2e" % args.lr) 211 | 212 | print("accumulate grad iterations: %d" % args.accum_iter) 213 | print("effective batch size: %d" % eff_batch_size) 214 | 215 | if args.distributed: 216 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], 217 | find_unused_parameters=True) 218 | model_without_ddp = model.module 219 | 220 | # following timm: set wd as 0 for bias and norm layers 221 | param_groups = optim_factory.add_weight_decay(model_without_ddp, args.weight_decay) 222 | optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95)) 223 | 
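A quick note on the learning-rate rule applied a few lines above: when `--lr` is not set, the script derives it from the base rate as lr = blr * eff_batch_size / 256, with eff_batch_size = batch_size * accum_iter * world_size. As an illustrative example (not a recommended configuration): the defaults blr=2e-4 and batch_size=64 on 8 GPUs with accum_iter=1 give eff_batch_size = 64 * 1 * 8 = 512, hence lr = 2e-4 * 512 / 256 = 4e-4. The same scaling rule reappears in main_linprobe.py and main_finetune.py below.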
print(optimizer) 224 | if args.not_use_fp16: 225 | loss_scaler = None 226 | else: 227 | loss_scaler = NativeScaler() 228 | 229 | ckpt_path = os.path.join(args.output_dir, f"{args.model}.{args.experiment}.temp.pth") 230 | if not os.path.isfile(ckpt_path): 231 | print("Checkpoint not founded in {}, train from random initialization".format(ckpt_path)) 232 | else: 233 | print("Found checkpoint at {}".format(ckpt_path)) 234 | model_ema_state_dict = model_ema.ema if args.use_ema_model else None 235 | misc.load_model(args=args, ckpt_path=ckpt_path, model_without_ddp=model, model_ema=model_ema_state_dict, 236 | optimizer=optimizer, loss_scaler=loss_scaler) 237 | 238 | if global_rank == 0: 239 | log_dir = os.path.join(args.log_dir, f"{args.model}.{args.experiment}") 240 | os.makedirs(log_dir, exist_ok=True) 241 | log_writer = SummaryWriter(log_dir=log_dir) 242 | else: 243 | log_writer = None 244 | 245 | print(f"Start training for {args.epochs} epochs") 246 | start_time = time.time() 247 | for epoch in range(args.start_epoch, args.epochs): 248 | if args.distributed: 249 | data_loader_train.sampler.set_epoch(epoch) 250 | train_stats = train_one_epoch( 251 | model, data_loader_train, 252 | optimizer, device, epoch, loss_scaler, args.clip_grad, 253 | log_writer=log_writer, 254 | args=args, model_ema=model_ema,teacher_model=teacher_model, 255 | ) 256 | 257 | save_dict = { 258 | "epoch": epoch + 1, 259 | "state_dict": model.state_dict(), 260 | "optimizer": optimizer.state_dict(), 261 | "model": args.model, 262 | } 263 | if loss_scaler is not None: 264 | save_dict['loss_scaler'] = loss_scaler.state_dict() 265 | if model_ema is not None: 266 | save_dict['ema_state_dict'] = model_ema.ema.state_dict() 267 | 268 | ckpt_path = os.path.join(args.output_dir, f"{args.model}.{args.experiment}.temp.pth") 269 | misc.save_on_master(save_dict, ckpt_path) 270 | print(f"model_path: {ckpt_path}") 271 | 272 | if args.output_dir and ((epoch + 1) % args.saveckp_freq == 0 or epoch + 1 == args.epochs): 273 | ckpt_path = os.path.join(args.output_dir, "{}.{}.{:04d}.pth".format(args.model, args.experiment, epoch+1)) 274 | misc.save_on_master(save_dict, ckpt_path) 275 | 276 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 'epoch': epoch, } 277 | 278 | if args.output_dir and misc.is_main_process(): 279 | if log_writer is not None: 280 | log_writer.flush() 281 | with open(os.path.join(args.output_dir,"{}.{}.log.txt".format(args.model,args.experiment)), mode="a", encoding="utf-8") as f: 282 | f.write(json.dumps(log_stats) + "\n") 283 | 284 | total_time = time.time() - start_time 285 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 286 | print('Training time {}'.format(total_time_str)) 287 | 288 | 289 | if __name__ == '__main__': 290 | args = get_args_parser() 291 | args = args.parse_args() 292 | if args.output_dir: 293 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 294 | main(args) 295 | -------------------------------------------------------------------------------- /main_linprobe.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # MoCo v3: https://github.com/facebookresearch/moco-v3 10 | # -------------------------------------------------------- 11 | 12 | import argparse 13 | import datetime 14 | import json 15 | import numpy as np 16 | import os 17 | import time 18 | from pathlib import Path 19 | 20 | import torch 21 | import torch.backends.cudnn as cudnn 22 | from torch.utils.tensorboard import SummaryWriter 23 | import torchvision.transforms as transforms 24 | import torchvision.datasets as datasets 25 | 26 | import timm 27 | 28 | # assert timm.__version__ == "0.3.2" # version check 29 | from timm.models.layers import trunc_normal_ 30 | 31 | import util.misc as misc 32 | from util.pos_embed import interpolate_pos_embed 33 | from util.misc import NativeScalerWithGradNormCount as NativeScaler 34 | from util.lars import LARS 35 | from util.crop import RandomResizedCrop 36 | 37 | from models import models_vit 38 | 39 | from engines.engine_finetune import train_one_epoch, evaluate 40 | 41 | 42 | def get_args_parser(): 43 | parser = argparse.ArgumentParser('MAE linear probing for image classification', add_help=False) 44 | parser.add_argument('--batch_size', default=512, type=int, 45 | help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus') 46 | parser.add_argument('--epochs', default=90, type=int) 47 | parser.add_argument('--accum_iter', default=1, type=int, 48 | help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)') 49 | 50 | # Model parameters 51 | parser.add_argument('--model', default='vit_large_patch16', type=str, metavar='MODEL', 52 | help='Name of model to train') 53 | 54 | # Optimizer parameters 55 | parser.add_argument('--weight_decay', type=float, default=0, 56 | help='weight decay (default: 0 for linear probe following MoCo v1)') 57 | 58 | parser.add_argument('--lr', type=float, default=None, metavar='LR', 59 | help='learning rate (absolute lr)') 60 | parser.add_argument('--blr', type=float, default=0.1, metavar='LR', 61 | help='base learning rate: absolute_lr = base_lr * total_batch_size / 256') 62 | 63 | parser.add_argument('--min_lr', type=float, default=0., metavar='LR', 64 | help='lower lr bound for cyclic schedulers that hit 0') 65 | 66 | parser.add_argument('--warmup_epochs', type=int, default=10, metavar='N', 67 | help='epochs to warmup LR') 68 | 69 | # * Finetuning params 70 | parser.add_argument('--finetune', default='', 71 | help='finetune from checkpoint') 72 | parser.add_argument('--global_pool', action='store_true') 73 | parser.set_defaults(global_pool=True) 74 | parser.add_argument('--cls_token', action='store_false', dest='global_pool', 75 | help='Use class token instead of global pool for classification') 76 | 77 | # Dataset parameters 78 | parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str, 79 | help='dataset path') 80 | parser.add_argument('--nb_classes', default=1000, type=int, 81 | help='number of the classification types') 82 | 83 | parser.add_argument('--output_dir', default='./output_dir', 84 | help='path where to save, empty for no saving') 85 | parser.add_argument('--log_dir', default='./output_dir', 86 | help='path where to tensorboard log') 87 | parser.add_argument('--device', default='cuda', 88 | help='device to use for training / testing') 89 | parser.add_argument('--seed', default=0, type=int) 90 | 
parser.add_argument('--resume', default='', 91 | help='resume from checkpoint') 92 | parser.add_argument('--experiment', default='exp', type=str, help='experiment name (for log)') 93 | 94 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 95 | help='start epoch') 96 | parser.add_argument('--eval', action='store_true', 97 | help='Perform evaluation only') 98 | parser.add_argument('--dist_eval', action='store_true', default=False, 99 | help='Enabling distributed evaluation (recommended during training for faster monitor') 100 | parser.add_argument('--num_workers', default=10, type=int) 101 | parser.add_argument('--pin_mem', action='store_true', 102 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 103 | parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem') 104 | parser.set_defaults(pin_mem=False) 105 | 106 | # distributed training parameters 107 | parser.add_argument('--world_size', default=1, type=int, 108 | help='number of distributed processes') 109 | parser.add_argument('--local_rank', default=-1, type=int) 110 | parser.add_argument('--dist_on_itp', action='store_true') 111 | parser.add_argument('--dist_url', default='env://', 112 | help='url used to set up distributed training') 113 | parser.add_argument('--dist_backend', default='nccl', type=str, help='experiment name (for log)') 114 | 115 | return parser 116 | 117 | 118 | def main(args): 119 | misc.init_distributed_mode(args) 120 | 121 | print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__)))) 122 | print("{}".format(args).replace(', ', ',\n')) 123 | 124 | device = torch.device(args.device) 125 | 126 | # fix the seed for reproducibility 127 | seed = args.seed + misc.get_rank() 128 | torch.manual_seed(seed) 129 | np.random.seed(seed) 130 | 131 | cudnn.benchmark = True 132 | 133 | # linear probe: weak augmentation 134 | transform_train = transforms.Compose([ 135 | # RandomResizedCrop(224, interpolation=transforms.InterpolationMode.BICUBIC), 136 | transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC), 137 | transforms.RandomCrop((224, 224)), 138 | transforms.RandomHorizontalFlip(), 139 | transforms.ToTensor(), 140 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) 141 | transform_val = transforms.Compose([ 142 | transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC), 143 | transforms.CenterCrop(224), 144 | transforms.ToTensor(), 145 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) 146 | 147 | dataset_train = datasets.ImageFolder(os.path.join(args.data_path, 'train'), transform=transform_train) 148 | dataset_val = datasets.ImageFolder(os.path.join(args.data_path, 'val'), transform=transform_val) 149 | print(dataset_train) 150 | print(dataset_val) 151 | 152 | if True: # args.distributed: 153 | num_tasks = misc.get_world_size() 154 | global_rank = misc.get_rank() 155 | sampler_train = torch.utils.data.DistributedSampler( 156 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 157 | ) 158 | print("Sampler_train = %s" % str(sampler_train)) 159 | if args.dist_eval: 160 | if len(dataset_val) % num_tasks != 0: 161 | print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. 
' 162 | 'This will slightly alter validation results as extra duplicate entries are added to achieve ' 163 | 'equal num of samples per-process.') 164 | sampler_val = torch.utils.data.DistributedSampler( 165 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=True) # shuffle=True to reduce monitor bias 166 | else: 167 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 168 | else: 169 | sampler_train = torch.utils.data.RandomSampler(dataset_train) 170 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 171 | 172 | data_loader_train = torch.utils.data.DataLoader( 173 | dataset_train, sampler=sampler_train, 174 | batch_size=args.batch_size, 175 | num_workers=args.num_workers, 176 | pin_memory=args.pin_mem, 177 | drop_last=True, 178 | ) 179 | 180 | data_loader_val = torch.utils.data.DataLoader( 181 | dataset_val, sampler=sampler_val, 182 | batch_size=args.batch_size, 183 | num_workers=args.num_workers, 184 | pin_memory=args.pin_mem, 185 | drop_last=False 186 | ) 187 | 188 | model = models_vit.__dict__[args.model]( 189 | num_classes=args.nb_classes, 190 | global_pool=args.global_pool, 191 | ) 192 | 193 | if args.finetune and not args.eval: 194 | checkpoint = torch.load(args.finetune, map_location='cpu') 195 | 196 | print("Load pre-trained checkpoint from: %s" % args.finetune) 197 | if 'state_dict' in checkpoint: 198 | checkpoint_model = checkpoint['state_dict'] 199 | else: 200 | checkpoint_model = checkpoint['model'] 201 | state_dict = model.state_dict() 202 | checkpoint_model = {k.replace("module.", ""): v for k, v in checkpoint_model.items()} 203 | for k in ['head.weight', 'head.bias']: 204 | if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape: 205 | print(f"Removing key {k} from pretrained checkpoint") 206 | del checkpoint_model[k] 207 | 208 | # interpolate position embedding 209 | interpolate_pos_embed(model, checkpoint_model) 210 | 211 | # load pre-trained model 212 | msg = model.load_state_dict(checkpoint_model, strict=False) 213 | print(msg) 214 | 215 | print("global_pool =", args.global_pool) 216 | 217 | if args.global_pool: 218 | assert set(msg.missing_keys) == {'head.weight', 'head.bias', 'fc_norm.weight', 'fc_norm.bias'} 219 | else: 220 | assert set(msg.missing_keys) == {'head.weight', 'head.bias'} 221 | 222 | # manually initialize fc layer: following MoCo v3 223 | trunc_normal_(model.head.weight, std=0.01) 224 | 225 | # for linear prob only 226 | # hack: revise model's head with BN 227 | model.head = torch.nn.Sequential(torch.nn.BatchNorm1d(model.head.in_features, affine=False, eps=1e-6), model.head) 228 | # freeze all but the head 229 | for _, p in model.named_parameters(): 230 | p.requires_grad = False 231 | for _, p in model.head.named_parameters(): 232 | p.requires_grad = True 233 | 234 | model.to(device) 235 | 236 | model_without_ddp = model 237 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 238 | 239 | # print("Model = %s" % str(model_without_ddp)) 240 | print('number of params (M): %.2f' % (n_parameters / 1.e6)) 241 | 242 | eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size() 243 | 244 | if args.lr is None: # only base_lr is specified 245 | args.lr = args.blr * eff_batch_size / 256 246 | 247 | print("base lr: %.2e" % (args.lr * 256 / eff_batch_size)) 248 | print("actual lr: %.2e" % args.lr) 249 | 250 | print("accumulate grad iterations: %d" % args.accum_iter) 251 | print("effective batch size: %d" % eff_batch_size) 252 | 253 | if args.distributed: 254 | model = 
torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 255 | model_without_ddp = model.module 256 | 257 | # optimizer = LARS(model_without_ddp.head.parameters(), lr=args.lr, weight_decay=args.weight_decay) 258 | optimizer = torch.optim.SGD(model_without_ddp.head.parameters(), lr=args.lr, weight_decay=args.weight_decay, 259 | momentum=0.9) 260 | print(optimizer) 261 | loss_scaler = NativeScaler() 262 | 263 | criterion = torch.nn.CrossEntropyLoss() 264 | 265 | print("criterion = %s" % str(criterion)) 266 | 267 | # resume model 268 | ckpt_path = os.path.join(args.output_dir, f"{args.model}.{args.experiment}.linear.pth") 269 | if not os.path.isfile(ckpt_path): 270 | print("Checkpoint not founded in {}, train from random initialization".format(ckpt_path)) 271 | else: 272 | print("Found checkpoint at {}".format(ckpt_path)) 273 | misc.load_model(args=args, ckpt_path=ckpt_path, model_without_ddp=model, optimizer=optimizer, 274 | loss_scaler=loss_scaler) 275 | 276 | if args.eval: 277 | test_stats = evaluate(data_loader_val, model, device) 278 | print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") 279 | exit(0) 280 | 281 | if global_rank == 0 and args.log_dir is not None and not args.eval: 282 | log_dir = os.path.join(args.log_dir, f"{args.model}.{args.experiment}") 283 | os.makedirs(log_dir, exist_ok=True) 284 | log_writer = SummaryWriter(log_dir=log_dir) 285 | else: 286 | log_writer = None 287 | 288 | print(f"Start training for {args.epochs} epochs") 289 | start_time = time.time() 290 | max_accuracy = 0.0 291 | for epoch in range(args.start_epoch, args.epochs): 292 | if args.distributed: 293 | data_loader_train.sampler.set_epoch(epoch) 294 | train_stats = train_one_epoch( 295 | model, criterion, data_loader_train, 296 | optimizer, device, epoch, loss_scaler, 297 | max_norm=None, 298 | log_writer=log_writer, 299 | args=args 300 | ) 301 | 302 | save_dict = { 303 | "epoch": epoch + 1, 304 | "state_dict": model.state_dict(), 305 | "optimizer": optimizer.state_dict(), 306 | "model": args.model, 307 | } 308 | if loss_scaler is not None: 309 | save_dict['loss_scaler'] = loss_scaler.state_dict() 310 | 311 | ckpt_path = os.path.join(args.output_dir, f"{args.model}.{args.experiment}.linear.pth") 312 | misc.save_on_master(save_dict, ckpt_path) 313 | print(f"model_path: {ckpt_path}") 314 | 315 | test_stats = evaluate(data_loader_val, model, device) 316 | print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") 317 | max_accuracy = max(max_accuracy, test_stats["acc1"]) 318 | print(f'Max accuracy: {max_accuracy:.2f}%') 319 | 320 | if log_writer is not None: 321 | log_writer.add_scalar('perf/test_acc1', test_stats['acc1'], epoch) 322 | log_writer.add_scalar('perf/test_acc5', test_stats['acc5'], epoch) 323 | log_writer.add_scalar('perf/test_loss', test_stats['loss'], epoch) 324 | 325 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 326 | **{f'test_{k}': v for k, v in test_stats.items()}, 327 | 'epoch': epoch, 328 | 'n_parameters': n_parameters} 329 | 330 | if args.output_dir and misc.is_main_process(): 331 | if log_writer is not None: 332 | log_writer.flush() 333 | with open(os.path.join(args.output_dir,"{}.{}.log.txt".format(args.model, args.experiment)), mode="a", encoding="utf-8") as f: 334 | f.write(json.dumps(log_stats) + "\n") 335 | 336 | total_time = time.time() - start_time 337 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 338 | print('Training time 
{}'.format(total_time_str)) 339 | 340 | 341 | if __name__ == '__main__': 342 | args = get_args_parser() 343 | args = args.parse_args() 344 | if args.output_dir: 345 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 346 | main(args) 347 | -------------------------------------------------------------------------------- /main_finetune.py: -------------------------------------------------------------------------------- 1 | # This source code is licensed under the license found in the 2 | # LICENSE file in the root directory of this source tree. 3 | # -------------------------------------------------------- 4 | # References: 5 | # DeiT: https://github.com/facebookresearch/deit 6 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 7 | # MAE: https://github.com/facebookresearch/mae 8 | # -------------------------------------------------------- 9 | 10 | import argparse 11 | import datetime 12 | import json 13 | import numpy as np 14 | import os 15 | import time 16 | from pathlib import Path 17 | import builtins 18 | 19 | import torch 20 | import torch.backends.cudnn as cudnn 21 | from torch.utils.tensorboard import SummaryWriter 22 | 23 | import timm 24 | 25 | # assert timm.__version__ == "0.3.2" # version check 26 | from timm.models.layers import trunc_normal_ 27 | from timm.data.mixup import Mixup 28 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy 29 | 30 | import util.lr_decay as lrd 31 | import util.misc as misc 32 | from datasets.datasets import ImageListFolder, build_transform 33 | from util.pos_embed import interpolate_pos_embed 34 | from util.misc import NativeScalerWithGradNormCount as NativeScaler 35 | 36 | from models import models_vit 37 | 38 | from engines.engine_finetune import train_one_epoch, evaluate 39 | 40 | 41 | def get_args_parser(): 42 | parser = argparse.ArgumentParser('UM-MAE fine-tuning for image classification', add_help=False) 43 | parser.add_argument('--batch_size', default=64, type=int, 44 | help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus') 45 | parser.add_argument('--epochs', default=50, type=int) 46 | parser.add_argument('--accum_iter', default=1, type=int, 47 | help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)') 48 | 49 | # Model parameters 50 | parser.add_argument('--model', default='vit_large_patch16', type=str, metavar='MODEL', 51 | help='Name of model to train') 52 | 53 | parser.add_argument('--input_size', default=224, type=int, 54 | help='images input size') 55 | 56 | parser.add_argument('--drop_path', type=float, default=0.1, metavar='PCT', 57 | help='Drop path rate (default: 0.1)') 58 | 59 | # Optimizer parameters 60 | parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM', 61 | help='Clip gradient norm (default: None, no clipping)') 62 | parser.add_argument('--weight_decay', type=float, default=0.05, 63 | help='weight decay (default: 0.05)') 64 | 65 | parser.add_argument('--lr', type=float, default=None, metavar='LR', 66 | help='learning rate (absolute lr)') 67 | parser.add_argument('--blr', type=float, default=1e-3, metavar='LR', 68 | help='base learning rate: absolute_lr = base_lr * total_batch_size / 256') 69 | parser.add_argument('--layer_decay', type=float, default=0.75, 70 | help='layer-wise lr decay from ELECTRA/BEiT') 71 | 72 | parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR', 73 | help='lower lr bound for cyclic schedulers that hit 0') 74 | 75 | 
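The `--layer_decay` option above (default 0.75) drives the layer-wise learning-rate decay that `lrd.param_groups_lrd` applies when the optimizer is built later in this script. A rough sketch of the MAE-style grouping it follows (illustrative; the exact assignment of parameters to layer ids lives in util/lr_decay.py):

# assumed MAE-style layer grouping: id 0 = patch/pos embedding, 1..12 = blocks, 13 = head
num_layers = 12          # ViT-Base depth
layer_decay = 0.75
scales = [layer_decay ** (num_layers + 1 - i) for i in range(num_layers + 2)]
# scales[0] ~= 0.024 (embeddings), scales[12] = 0.75 (last block), scales[13] = 1.0 (head)

so earlier layers are fine-tuned far more gently than the classification head.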
parser.add_argument('--warmup_epochs', type=int, default=5, metavar='N', 76 | help='epochs to warmup LR') 77 | 78 | # Augmentation parameters 79 | parser.add_argument('--color_jitter', type=float, default=None, metavar='PCT', 80 | help='Color jitter factor (enabled only when not using Auto/RandAug)') 81 | parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME', 82 | help='Use AutoAugment policy. "v0" or "original". " + "(default: rand-m9-mstd0.5-inc1)'), 83 | parser.add_argument('--smoothing', type=float, default=0.1, 84 | help='Label smoothing (default: 0.1)') 85 | 86 | # * Random Erase params 87 | parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT', 88 | help='Random erase prob (default: 0.25)') 89 | parser.add_argument('--remode', type=str, default='pixel', 90 | help='Random erase mode (default: "pixel")') 91 | parser.add_argument('--recount', type=int, default=1, 92 | help='Random erase count (default: 1)') 93 | parser.add_argument('--resplit', action='store_true', default=False, 94 | help='Do not random erase first (clean) augmentation split') 95 | 96 | # * Mixup params 97 | parser.add_argument('--mixup', type=float, default=0, 98 | help='mixup alpha, mixup enabled if > 0.') 99 | parser.add_argument('--cutmix', type=float, default=0, 100 | help='cutmix alpha, cutmix enabled if > 0.') 101 | parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None, 102 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)') 103 | parser.add_argument('--mixup_prob', type=float, default=1.0, 104 | help='Probability of performing mixup or cutmix when either/both is enabled') 105 | parser.add_argument('--mixup_switch_prob', type=float, default=0.5, 106 | help='Probability of switching to cutmix when both mixup and cutmix enabled') 107 | parser.add_argument('--mixup_mode', type=str, default='batch', 108 | help='How to apply mixup/cutmix params. 
Per "batch", "pair", or "elem"') 109 | 110 | # * Finetuning params 111 | parser.add_argument('--finetune', default='', 112 | help='finetune from checkpoint') 113 | parser.add_argument('--global_pool', action='store_true') 114 | parser.set_defaults(global_pool=True) 115 | parser.add_argument('--cls_token', action='store_false', dest='global_pool', 116 | help='Use class token instead of global pool for classification') 117 | 118 | # Dataset parameters 119 | parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str, 120 | help='dataset path') 121 | parser.add_argument('--nb_classes', default=1000, type=int, 122 | help='number of the classification types') 123 | 124 | parser.add_argument('--output_dir', default='./output_dir', 125 | help='path where to save, empty for no saving') 126 | parser.add_argument('--log_dir', default='./output_dir', 127 | help='path where to tensorboard log') 128 | parser.add_argument('--saveckp_freq', default=20, type=int, help='Save checkpoint every x epochs.') 129 | parser.add_argument('--device', default='cuda', 130 | help='device to use for training / testing') 131 | parser.add_argument('--seed', default=0, type=int) 132 | parser.add_argument('--resume', default='', 133 | help='resume from checkpoint') 134 | parser.add_argument('--experiment', default='exp', type=str, help='experiment name (for log)') 135 | 136 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 137 | help='start epoch') 138 | parser.add_argument('--eval', action='store_true', 139 | help='Perform evaluation only') 140 | parser.add_argument('--dist_eval', action='store_true', default=False, 141 | help='Enabling distributed evaluation (recommended during training for faster monitor') 142 | parser.add_argument('--num_workers', default=10, type=int) 143 | parser.add_argument('--pin_mem', action='store_true', 144 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 145 | parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem') 146 | parser.set_defaults(pin_mem=False) 147 | 148 | # distributed training parameters 149 | parser.add_argument('--world_size', default=1, type=int, 150 | help='number of distributed processes') 151 | parser.add_argument('--local_rank', default=-1, type=int) 152 | parser.add_argument('--dist_on_itp', action='store_true') 153 | parser.add_argument('--dist_url', default='env://', 154 | help='url used to set up distributed training') 155 | parser.add_argument('--dist_backend', default='nccl', type=str, help='experiment name (for log)') 156 | 157 | return parser 158 | 159 | 160 | def main(args): 161 | misc.init_distributed_mode(args) 162 | 163 | print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__)))) 164 | print("{}".format(args).replace(', ', ',\n')) 165 | 166 | device = torch.device(args.device) 167 | 168 | # fix the seed for reproducibility 169 | seed = args.seed + misc.get_rank() 170 | torch.manual_seed(seed) 171 | np.random.seed(seed) 172 | 173 | cudnn.benchmark = True 174 | 175 | transform_train = build_transform(is_train=True, args=args) 176 | transform_val = build_transform(is_train=False, args=args) 177 | dataset_train = ImageListFolder(os.path.join(args.data_path, 'train'), transform=transform_train, 178 | ann_file=os.path.join(args.data_path, 'train.txt')) 179 | print(dataset_train) 180 | num_tasks = misc.get_world_size() 181 | global_rank = misc.get_rank() 182 | sampler_train = torch.utils.data.DistributedSampler( 183 | dataset_train, 
num_replicas=num_tasks, rank=global_rank, shuffle=True 184 | ) 185 | print("Sampler_train = %s" % str(sampler_train)) 186 | 187 | data_loader_train = torch.utils.data.DataLoader( 188 | dataset_train, sampler=sampler_train, 189 | batch_size=args.batch_size, 190 | num_workers=args.num_workers, 191 | pin_memory=args.pin_mem, 192 | drop_last=True, 193 | ) 194 | 195 | 196 | dataset_val = ImageListFolder(os.path.join(args.data_path, 'train'), transform=transform_val, 197 | ann_file=os.path.join(args.data_path, 'train.txt')) 198 | num_tasks = misc.get_world_size() 199 | global_rank = misc.get_rank() 200 | sampler_val = torch.utils.data.DistributedSampler( 201 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False 202 | ) 203 | print("Sampler_val = %s" % str(sampler_val)) 204 | 205 | data_loader_val = torch.utils.data.DataLoader( 206 | dataset_val, sampler=sampler_val, 207 | batch_size=args.batch_size, 208 | num_workers=args.num_workers, 209 | pin_memory=args.pin_mem, 210 | drop_last=False, 211 | shuffle=False, 212 | ) 213 | 214 | mixup_fn = None 215 | mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None 216 | if mixup_active: 217 | print("Mixup is activated!") 218 | mixup_fn = Mixup( 219 | mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, 220 | prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, 221 | label_smoothing=args.smoothing, num_classes=args.nb_classes) 222 | 223 | model = models_vit.__dict__[args.model]( 224 | num_classes=args.nb_classes, 225 | drop_path_rate=args.drop_path, 226 | global_pool=args.global_pool, 227 | ) 228 | 229 | if args.finetune and not args.eval: 230 | # load pretrained model 231 | checkpoint = torch.load(args.finetune, map_location='cpu') 232 | 233 | print("Load pre-trained checkpoint from: %s" % args.finetune) 234 | if 'state_dict' in checkpoint: 235 | checkpoint_model = checkpoint['state_dict'] 236 | else: 237 | checkpoint_model = checkpoint['model'] 238 | state_dict = model.state_dict() 239 | checkpoint_model = {k.replace("module.", ""): v for k, v in checkpoint_model.items()} 240 | 241 | for k in ['head.weight', 'head.bias']: 242 | if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape: 243 | print(f"Removing key {k} from pretrained checkpoint") 244 | del checkpoint_model[k] 245 | 246 | # interpolate position embedding 247 | interpolate_pos_embed(model, checkpoint_model) 248 | 249 | # load pre-trained model 250 | msg = model.load_state_dict(checkpoint_model, strict=False) 251 | print(msg) 252 | 253 | print("global_pool = ", args.global_pool) 254 | 255 | if args.global_pool: 256 | assert set(msg.missing_keys) == {'head.weight', 'head.bias', 'fc_norm.weight', 'fc_norm.bias'} 257 | else: 258 | assert set(msg.missing_keys) == {'head.weight', 'head.bias'} 259 | 260 | # manually initialize fc layer 261 | trunc_normal_(model.head.weight, std=2e-5) 262 | 263 | model.to(device) 264 | 265 | model_without_ddp = model 266 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 267 | 268 | # print("Model = %s" % str(model_without_ddp)) 269 | print('number of params (M): %.2f' % (n_parameters / 1.e6)) 270 | 271 | eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size() 272 | 273 | if args.lr is None: # only base_lr is specified 274 | args.lr = args.blr * eff_batch_size / 256 275 | 276 | print("base lr: %.2e" % (args.lr * 256 / eff_batch_size)) 277 | print("actual lr: %.2e" % args.lr) 278 | 279 | 
print("accumulate grad iterations: %d" % args.accum_iter) 280 | print("effective batch size: %d" % eff_batch_size) 281 | 282 | if args.distributed: 283 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 284 | model_without_ddp = model.module 285 | 286 | # build optimizer with layer-wise lr decay (lrd) 287 | param_groups = lrd.param_groups_lrd(model_without_ddp, args.weight_decay, 288 | no_weight_decay_list=model_without_ddp.no_weight_decay(), 289 | layer_decay=args.layer_decay 290 | ) 291 | optimizer = torch.optim.AdamW(param_groups, lr=args.lr) 292 | loss_scaler = NativeScaler() 293 | 294 | if mixup_fn is not None: 295 | # smoothing is handled with mixup label transform 296 | criterion = SoftTargetCrossEntropy() 297 | elif args.smoothing > 0.: 298 | criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing) 299 | else: 300 | criterion = torch.nn.CrossEntropyLoss() 301 | 302 | print("criterion = %s" % str(criterion)) 303 | 304 | # resume model 305 | ckpt_path = os.path.join(args.output_dir, f"{args.model}.{args.experiment}.temp.pth") 306 | if not os.path.isfile(ckpt_path): 307 | print("Checkpoint not founded in {}, train from random initialization".format(ckpt_path)) 308 | else: 309 | print("Found checkpoint at {}".format(ckpt_path)) 310 | misc.load_model(args=args, ckpt_path=ckpt_path, model_without_ddp=model, optimizer=optimizer, 311 | loss_scaler=loss_scaler) 312 | 313 | if args.eval: 314 | test_stats = evaluate(data_loader_val, model, device) 315 | print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") 316 | exit(0) 317 | 318 | if global_rank == 0 and args.log_dir is not None and not args.eval: 319 | log_dir = os.path.join(args.log_dir, f"{args.model}.{args.experiment}") 320 | os.makedirs(log_dir, exist_ok=True) 321 | log_writer = SummaryWriter(log_dir=log_dir) 322 | else: 323 | log_writer = None 324 | 325 | print(f"Start training for {args.epochs} epochs") 326 | start_time = time.time() 327 | max_accuracy = 0.0 328 | for epoch in range(args.start_epoch, args.epochs): 329 | if args.distributed: 330 | data_loader_train.sampler.set_epoch(epoch) 331 | train_stats = train_one_epoch( 332 | model, criterion, data_loader_train, 333 | optimizer, device, epoch, loss_scaler, 334 | args.clip_grad, mixup_fn, 335 | log_writer=log_writer, 336 | args=args 337 | ) 338 | 339 | save_dict = { 340 | "epoch": epoch + 1, 341 | "state_dict": model.state_dict(), 342 | "optimizer": optimizer.state_dict(), 343 | "model": args.model, 344 | } 345 | if loss_scaler is not None: 346 | save_dict['loss_scaler'] = loss_scaler.state_dict() 347 | 348 | ckpt_path = os.path.join(args.output_dir, f"{args.model}.{args.experiment}.temp.pth") 349 | misc.save_on_master(save_dict, ckpt_path) 350 | print(f"model_path: {ckpt_path}") 351 | 352 | test_stats = evaluate(data_loader_val, model, device) 353 | print(f"Pretrained from: {args.finetune}") 354 | print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") 355 | max_accuracy = max(max_accuracy, test_stats["acc1"]) 356 | print(f'Max accuracy: {max_accuracy:.2f}%') 357 | 358 | if log_writer is not None: 359 | log_writer.add_scalar('perf/test_acc1', test_stats['acc1'], epoch) 360 | log_writer.add_scalar('perf/test_acc5', test_stats['acc5'], epoch) 361 | log_writer.add_scalar('perf/test_loss', test_stats['loss'], epoch) 362 | 363 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 364 | **{f'test_{k}': v for k, v in test_stats.items()}, 365 | 
'epoch': epoch, 366 | 'n_parameters': n_parameters} 367 | 368 | if args.output_dir and misc.is_main_process(): 369 | if log_writer is not None: 370 | log_writer.flush() 371 | with open(os.path.join(args.output_dir,"{}.{}.log.txt".format(args.model, args.experiment)), mode="a", encoding="utf-8") as f: 372 | f.write(json.dumps(log_stats) + "\n") 373 | 374 | total_time = time.time() - start_time 375 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 376 | print('Training time {}'.format(total_time_str)) 377 | 378 | 379 | if __name__ == '__main__': 380 | if not misc.is_main_process(): 381 | def print_pass(*args): 382 | pass 383 | builtins.print = print_pass 384 | 385 | args = get_args_parser() 386 | args = args.parse_args() 387 | if args.output_dir: 388 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 389 | main(args) 390 | -------------------------------------------------------------------------------- /models/models_clip.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Tuple, Union 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn 8 | 9 | 10 | class Bottleneck(nn.Module): 11 | expansion = 4 12 | 13 | def __init__(self, inplanes, planes, stride=1): 14 | super().__init__() 15 | 16 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 17 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 18 | self.bn1 = nn.BatchNorm2d(planes) 19 | self.relu1 = nn.ReLU(inplace=True) 20 | 21 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 22 | self.bn2 = nn.BatchNorm2d(planes) 23 | self.relu2 = nn.ReLU(inplace=True) 24 | 25 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 26 | 27 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 28 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 29 | self.relu3 = nn.ReLU(inplace=True) 30 | 31 | self.downsample = None 32 | self.stride = stride 33 | 34 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 35 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 36 | self.downsample = nn.Sequential(OrderedDict([ 37 | ("-1", nn.AvgPool2d(stride)), 38 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), 39 | ("1", nn.BatchNorm2d(planes * self.expansion)) 40 | ])) 41 | 42 | def forward(self, x: torch.Tensor): 43 | identity = x 44 | 45 | out = self.relu1(self.bn1(self.conv1(x))) 46 | out = self.relu2(self.bn2(self.conv2(out))) 47 | out = self.avgpool(out) 48 | out = self.bn3(self.conv3(out)) 49 | 50 | if self.downsample is not None: 51 | identity = self.downsample(x) 52 | 53 | out += identity 54 | out = self.relu3(out) 55 | return out 56 | 57 | 58 | class AttentionPool2d(nn.Module): 59 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 60 | super().__init__() 61 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) 62 | self.k_proj = nn.Linear(embed_dim, embed_dim) 63 | self.q_proj = nn.Linear(embed_dim, embed_dim) 64 | self.v_proj = nn.Linear(embed_dim, embed_dim) 65 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 66 | self.num_heads = num_heads 67 | 68 | def forward(self, x): 69 | x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC 70 | x = torch.cat([x.mean(dim=0, keepdim=True), x], 
dim=0) # (HW+1)NC 71 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 72 | x, _ = F.multi_head_attention_forward( 73 | query=x[:1], key=x, value=x, 74 | embed_dim_to_check=x.shape[-1], 75 | num_heads=self.num_heads, 76 | q_proj_weight=self.q_proj.weight, 77 | k_proj_weight=self.k_proj.weight, 78 | v_proj_weight=self.v_proj.weight, 79 | in_proj_weight=None, 80 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 81 | bias_k=None, 82 | bias_v=None, 83 | add_zero_attn=False, 84 | dropout_p=0, 85 | out_proj_weight=self.c_proj.weight, 86 | out_proj_bias=self.c_proj.bias, 87 | use_separate_proj_weight=True, 88 | training=self.training, 89 | need_weights=False 90 | ) 91 | return x.squeeze(0) 92 | 93 | 94 | class ModifiedResNet(nn.Module): 95 | """ 96 | A ResNet class that is similar to torchvision's but contains the following changes: 97 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 98 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 99 | - The final pooling layer is a QKV attention instead of an average pool 100 | """ 101 | 102 | def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): 103 | super().__init__() 104 | self.output_dim = output_dim 105 | self.input_resolution = input_resolution 106 | 107 | # the 3-layer stem 108 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) 109 | self.bn1 = nn.BatchNorm2d(width // 2) 110 | self.relu1 = nn.ReLU(inplace=True) 111 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) 112 | self.bn2 = nn.BatchNorm2d(width // 2) 113 | self.relu2 = nn.ReLU(inplace=True) 114 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) 115 | self.bn3 = nn.BatchNorm2d(width) 116 | self.relu3 = nn.ReLU(inplace=True) 117 | self.avgpool = nn.AvgPool2d(2) 118 | 119 | # residual layers 120 | self._inplanes = width # this is a *mutable* variable used during construction 121 | self.layer1 = self._make_layer(width, layers[0]) 122 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 123 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 124 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 125 | 126 | embed_dim = width * 32 # the ResNet feature dimension 127 | self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) 128 | 129 | def _make_layer(self, planes, blocks, stride=1): 130 | layers = [Bottleneck(self._inplanes, planes, stride)] 131 | 132 | self._inplanes = planes * Bottleneck.expansion 133 | for _ in range(1, blocks): 134 | layers.append(Bottleneck(self._inplanes, planes)) 135 | 136 | return nn.Sequential(*layers) 137 | 138 | def forward(self, x): 139 | def stem(x): 140 | x = self.relu1(self.bn1(self.conv1(x))) 141 | x = self.relu2(self.bn2(self.conv2(x))) 142 | x = self.relu3(self.bn3(self.conv3(x))) 143 | x = self.avgpool(x) 144 | return x 145 | 146 | x = x.type(self.conv1.weight.dtype) 147 | x = stem(x) 148 | x = self.layer1(x) 149 | x = self.layer2(x) 150 | x = self.layer3(x) 151 | x = self.layer4(x) 152 | x = self.attnpool(x) 153 | 154 | return x 155 | 156 | 157 | class LayerNorm(nn.LayerNorm): 158 | """Subclass torch's LayerNorm to handle fp16.""" 159 | 160 | def forward(self, x: torch.Tensor): 161 | orig_type = x.dtype 162 | ret = super().forward(x.type(torch.float32)) 163 | return ret.type(orig_type) 164 | 165 | 166 | 
class QuickGELU(nn.Module): 167 | def forward(self, x: torch.Tensor): 168 | return x * torch.sigmoid(1.702 * x) 169 | 170 | 171 | class ResidualAttentionBlock(nn.Module): 172 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): 173 | super().__init__() 174 | 175 | self.attn = nn.MultiheadAttention(d_model, n_head) 176 | self.ln_1 = LayerNorm(d_model) 177 | self.mlp = nn.Sequential(OrderedDict([ 178 | ("c_fc", nn.Linear(d_model, d_model * 4)), 179 | ("gelu", QuickGELU()), 180 | ("c_proj", nn.Linear(d_model * 4, d_model)) 181 | ])) 182 | self.ln_2 = LayerNorm(d_model) 183 | self.attn_mask = attn_mask 184 | 185 | def attention(self, x: torch.Tensor): 186 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None 187 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] 188 | 189 | def forward(self, x: torch.Tensor): 190 | x = x + self.attention(self.ln_1(x)) 191 | x = x + self.mlp(self.ln_2(x)) 192 | return x 193 | 194 | 195 | class Transformer(nn.Module): 196 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): 197 | super().__init__() 198 | self.width = width 199 | self.layers = layers 200 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) 201 | 202 | def forward(self, x: torch.Tensor): 203 | return self.resblocks(x) 204 | 205 | 206 | class VisionTransformer(nn.Module): 207 | def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int): 208 | super().__init__() 209 | self.input_resolution = input_resolution 210 | self.output_dim = output_dim 211 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) 212 | 213 | scale = width ** -0.5 214 | self.class_embedding = nn.Parameter(scale * torch.randn(width)) 215 | self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) 216 | self.ln_pre = LayerNorm(width) 217 | 218 | self.transformer = Transformer(width, layers, heads) 219 | 220 | self.ln_post = LayerNorm(width) 221 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) 222 | 223 | def forward(self, x: torch.Tensor): 224 | x = self.conv1(x) # shape = [*, width, grid, grid] 225 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 226 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 227 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] 228 | x = x + self.positional_embedding.to(x.dtype) 229 | x = self.ln_pre(x) 230 | 231 | x = x.permute(1, 0, 2) # NLD -> LND 232 | x = self.transformer(x) 233 | x = x.permute(1, 0, 2) # LND -> NLD 234 | 235 | x = self.ln_post(x[:, 0, :]) 236 | 237 | if self.proj is not None: 238 | x = x @ self.proj 239 | 240 | return x 241 | 242 | 243 | class CLIP(nn.Module): 244 | def __init__(self, 245 | embed_dim: int, 246 | # vision 247 | image_resolution: int, 248 | vision_layers: Union[Tuple[int, int, int, int], int], 249 | vision_width: int, 250 | vision_patch_size: int, 251 | # text 252 | context_length: int, 253 | vocab_size: int, 254 | transformer_width: int, 255 | transformer_heads: int, 256 | transformer_layers: int 257 | ): 258 | super().__init__() 259 | 260 | self.context_length = context_length 261 | 262 | if isinstance(vision_layers, (tuple, list)): 
263 | vision_heads = vision_width * 32 // 64 264 | self.visual = ModifiedResNet( 265 | layers=vision_layers, 266 | output_dim=embed_dim, 267 | heads=vision_heads, 268 | input_resolution=image_resolution, 269 | width=vision_width 270 | ) 271 | else: 272 | vision_heads = vision_width // 64 273 | self.visual = VisionTransformer( 274 | input_resolution=image_resolution, 275 | patch_size=vision_patch_size, 276 | width=vision_width, 277 | layers=vision_layers, 278 | heads=vision_heads, 279 | output_dim=embed_dim 280 | ) 281 | 282 | self.transformer = Transformer( 283 | width=transformer_width, 284 | layers=transformer_layers, 285 | heads=transformer_heads, 286 | attn_mask=self.build_attention_mask() 287 | ) 288 | 289 | self.vocab_size = vocab_size 290 | self.token_embedding = nn.Embedding(vocab_size, transformer_width) 291 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) 292 | self.ln_final = LayerNorm(transformer_width) 293 | 294 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) 295 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 296 | 297 | self.initialize_parameters() 298 | 299 | def initialize_parameters(self): 300 | nn.init.normal_(self.token_embedding.weight, std=0.02) 301 | nn.init.normal_(self.positional_embedding, std=0.01) 302 | 303 | if isinstance(self.visual, ModifiedResNet): 304 | if self.visual.attnpool is not None: 305 | std = self.visual.attnpool.c_proj.in_features ** -0.5 306 | nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) 307 | nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) 308 | nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) 309 | nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) 310 | 311 | for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: 312 | for name, param in resnet_block.named_parameters(): 313 | if name.endswith("bn3.weight"): 314 | nn.init.zeros_(param) 315 | 316 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) 317 | attn_std = self.transformer.width ** -0.5 318 | fc_std = (2 * self.transformer.width) ** -0.5 319 | for block in self.transformer.resblocks: 320 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 321 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 322 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 323 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 324 | 325 | if self.text_projection is not None: 326 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) 327 | 328 | def build_attention_mask(self): 329 | # lazily create causal attention mask, with full attention between the vision tokens 330 | # pytorch uses additive attention mask; fill with -inf 331 | mask = torch.empty(self.context_length, self.context_length) 332 | mask.fill_(float("-inf")) 333 | mask.triu_(1) # zero out the lower diagonal 334 | return mask 335 | 336 | @property 337 | def dtype(self): 338 | return self.visual.conv1.weight.dtype 339 | 340 | def encode_image(self, image): 341 | return self.visual(image.type(self.dtype)) 342 | 343 | def encode_text(self, text): 344 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] 345 | 346 | x = x + self.positional_embedding.type(self.dtype) 347 | x = x.permute(1, 0, 2) # NLD -> LND 348 | x = self.transformer(x) 349 | x = x.permute(1, 0, 2) # LND -> NLD 350 | x = self.ln_final(x).type(self.dtype) 351 | 352 | # x.shape = 
[batch_size, n_ctx, transformer.width] 353 | # take features from the eot embedding (eot_token is the highest number in each sequence) 354 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection 355 | 356 | return x 357 | 358 | def forward(self, image, text): 359 | image_features = self.encode_image(image) 360 | text_features = self.encode_text(text) 361 | 362 | # normalized features 363 | image_features = image_features / image_features.norm(dim=1, keepdim=True) 364 | text_features = text_features / text_features.norm(dim=1, keepdim=True) 365 | 366 | # cosine similarity as logits 367 | logit_scale = self.logit_scale.exp() 368 | logits_per_image = logit_scale * image_features @ text_features.t() 369 | logits_per_text = logits_per_image.t() 370 | 371 | # shape = [global_batch_size, global_batch_size] 372 | return logits_per_image, logits_per_text 373 | 374 | 375 | def convert_weights(model: nn.Module): 376 | """Convert applicable model parameters to fp16""" 377 | 378 | def _convert_weights_to_fp16(l): 379 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): 380 | l.weight.data = l.weight.data.half() 381 | if l.bias is not None: 382 | l.bias.data = l.bias.data.half() 383 | 384 | if isinstance(l, nn.MultiheadAttention): 385 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: 386 | tensor = getattr(l, attr) 387 | if tensor is not None: 388 | tensor.data = tensor.data.half() 389 | 390 | for name in ["text_projection", "proj"]: 391 | if hasattr(l, name): 392 | attr = getattr(l, name) 393 | if attr is not None: 394 | attr.data = attr.data.half() 395 | 396 | model.apply(_convert_weights_to_fp16) 397 | 398 | 399 | def build_model(state_dict: dict): 400 | vit = "visual.proj" in state_dict 401 | 402 | if vit: 403 | vision_width = state_dict["visual.conv1.weight"].shape[0] 404 | vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) 405 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] 406 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) 407 | image_resolution = vision_patch_size * grid_size 408 | else: 409 | counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] 410 | vision_layers = tuple(counts) 411 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] 412 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) 413 | vision_patch_size = None 414 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] 415 | image_resolution = output_width * 32 416 | 417 | embed_dim = state_dict["text_projection"].shape[1] 418 | context_length = state_dict["positional_embedding"].shape[0] 419 | vocab_size = state_dict["token_embedding.weight"].shape[0] 420 | transformer_width = state_dict["ln_final.weight"].shape[0] 421 | transformer_heads = transformer_width // 64 422 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks"))) 423 | 424 | model = CLIP( 425 | embed_dim, 426 | image_resolution, vision_layers, vision_width, vision_patch_size, 427 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers 428 | ) 429 | 430 | for key in ["input_resolution", "context_length", "vocab_size"]: 431 | if key in state_dict: 432 | del state_dict[key] 433 | 434 | convert_weights(model) 435 | 
model.load_state_dict(state_dict) 436 | return model.eval() -------------------------------------------------------------------------------- /models/models_semaim.py: -------------------------------------------------------------------------------- 1 | # References: 2 | # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm 3 | # DeiT: https://github.com/facebookresearch/deit 4 | # MAE: https://github.com/facebookresearch/mae 5 | # -------------------------------------------------------- 6 | import math 7 | from functools import partial 8 | 9 | import torch 10 | import torch.nn as nn 11 | from timm.models.vision_transformer import PatchEmbed, Mlp 12 | 13 | from util.pos_embed import get_2d_sincos_pos_embed 14 | from util.blocks import GaussianConv2d 15 | from util.blocks import Block_SelfMask, Block_SelfCrossMask 16 | 17 | 18 | class AimViT(nn.Module): 19 | """ 20 | Pretrain vision transformer backbone with AIM 21 | parallel encoder-decoder architecture 22 | Modified by sky: use the blocks in ViT (+ mask) for the encoder, which is more convenient for finetuning and linear probing 23 | modify the permutation from stochastic mask to center-out mask 24 | """ 25 | 26 | def __init__(self, 27 | # vision transformer backbone 28 | img_size=224, patch_size=16, in_chans=3, 29 | embed_dim=1024, depth=24, num_heads=16, drop_path_rate=0., out_dim=768, 30 | mlp_ratio=4., norm_layer=partial(nn.LayerNorm, eps=1e-6), 31 | # aim 32 | permutation_type='center2out', attention_type='cls', 33 | # decoder 34 | query_depth=12, share_weight=False, 35 | prediction_head_type='MLP', 36 | # loss function 37 | gaussian_kernel_size=None, gaussian_sigma=None, 38 | loss_type='L2', predict_feature='none', norm_pix_loss=True): 39 | super().__init__() 40 | 41 | # patch embedding 42 | self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim) 43 | num_patches = self.patch_embed.num_patches 44 | self.patch_size = patch_size 45 | 46 | # cls token 47 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 48 | 49 | # position embedding 50 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False) 51 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] 52 | 53 | # encoder 54 | self.blocks = nn.ModuleList([ 55 | Block_SelfMask(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer, drop_path=dpr[i]) 56 | for i in range(depth)]) 57 | 58 | # decoder 59 | if share_weight: 60 | self.query_blocks = self.blocks 61 | else: 62 | self.query_blocks = nn.ModuleList([ 63 | Block_SelfCrossMask(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer, drop_path=dpr[i]) 64 | for i in range(query_depth)]) 65 | self.depth = depth 66 | self.step = depth // query_depth 67 | 68 | # permutation type 69 | self.permutation_type = permutation_type 70 | 71 | # prediction head 72 | self.norm = norm_layer(embed_dim) 73 | self.predict_feature = predict_feature 74 | self.attention_type = attention_type 75 | if prediction_head_type == 'LINEAR': 76 | if predict_feature == 'none': 77 | self.prediction_head = nn.Linear(embed_dim, patch_size ** 2 * 3) 78 | else: 79 | rec_dim = out_dim if predict_feature == 'clip' else embed_dim 80 | self.prediction_head = nn.Linear(embed_dim, rec_dim) 81 | elif prediction_head_type == 'MLP': 82 | if predict_feature == 'none': 83 | self.prediction_head = Mlp(embed_dim, int(embed_dim * mlp_ratio), patch_size ** 2 * 3) 84 | else: 85 | rec_dim = out_dim if predict_feature == 'clip' else embed_dim 86 | self.prediction_head =
Mlp(embed_dim, int(embed_dim * mlp_ratio), rec_dim) 87 | 88 | # define loss parameters 89 | self.loss_type = loss_type 90 | self.norm_pix_loss = norm_pix_loss 91 | if gaussian_kernel_size is not None and gaussian_sigma is not None and self.predict_feature == 'none': 92 | self.gaussian_blur = GaussianConv2d(3, gaussian_kernel_size, gaussian_sigma) 93 | else: 94 | self.gaussian_blur = nn.Identity() 95 | 96 | # split matrix for guided center permutation 97 | num_patch = img_size // patch_size 98 | split_matrix = torch.zeros((num_patch, 2, 4)) 99 | split_matrix[0, :, :] = torch.tensor([[0, 0, 0, 0], [2, 6, 10, 13]]) 100 | split_matrix[1, :, :] = torch.tensor([[0, 0, 0, 0], [2, 6, 10, 13]]) 101 | split_matrix[2, :, :] = torch.tensor([[0, 0, 0, 0], [2, 6, 10, 13]]) 102 | split_matrix[3, :, :] = torch.tensor([[2, 0, 0, 0], [4, 6, 10, 13]]) 103 | split_matrix[4, :, :] = torch.tensor([[3, 1, 0, 0], [5, 7, 10, 13]]) 104 | split_matrix[5, :, :] = torch.tensor([[4, 2, 0, 0], [6, 8, 10, 13]]) 105 | split_matrix[6, :, :] = torch.tensor([[5, 3, 1, 0], [7, 9, 11, 13]]) 106 | split_matrix[7, :, :] = torch.tensor([[6, 4, 2, 0], [8, 10, 12, 13]]) 107 | split_matrix[8, :, :] = torch.tensor([[7, 5, 3, 0], [9, 11, 13, 13]]) 108 | split_matrix[9, :, :] = torch.tensor([[8, 6, 3, 0], [10, 12, 13, 13]]) 109 | split_matrix[10, :, :] = torch.tensor([[9, 7, 3, 0], [11, 13, 13, 13]]) 110 | split_matrix[11, :, :] = torch.tensor([[11, 7, 3, 0], [13, 13, 13, 13]]) 111 | split_matrix[12, :, :] = torch.tensor([[11, 7, 3, 0], [13, 13, 13, 13]]) 112 | split_matrix[13, :, :] = torch.tensor([[11, 7, 3, 0], [13, 13, 13, 13]]) 113 | self.split_matrix = split_matrix 114 | 115 | # coordinates for patches (row, col) 116 | coordinates = torch.zeros((num_patches, 2)) 117 | for i in range(num_patch): 118 | for j in range(num_patch): 119 | coordinates[i*num_patch+j, 0] = i # row 120 | coordinates[i*num_patch+j, 1] = j # col 121 | self.coordinates = coordinates.unsqueeze(0) 122 | 123 | # initialize weight 124 | self.initialize_weights() 125 | 126 | def initialize_weights(self): 127 | # initialization 128 | # initialize (and freeze) pos_embed by sin-cos embedding 129 | pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches ** .5), 130 | cls_token=True) 131 | self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) 132 | 133 | # initialize patch_embed like nn.Linear (instead of nn.Conv2d) 134 | w = self.patch_embed.proj.weight.data 135 | torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) 136 | 137 | # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
138 | torch.nn.init.normal_(self.cls_token, std=.02) 139 | 140 | # initialize nn.Linear and nn.LayerNorm 141 | self.apply(self._init_weights) 142 | 143 | def _init_weights(self, m): 144 | if isinstance(m, nn.Linear): 145 | # we use xavier_uniform following official JAX ViT: 146 | torch.nn.init.xavier_uniform_(m.weight) 147 | if isinstance(m, nn.Linear) and m.bias is not None: 148 | nn.init.constant_(m.bias, 0) 149 | elif isinstance(m, nn.LayerNorm): 150 | nn.init.constant_(m.bias, 0) 151 | nn.init.constant_(m.weight, 1.0) 152 | 153 | def patchify(self, imgs): 154 | """ 155 | imgs: (N, 3, H, W) 156 | x: (N, L, patch_size**2 *3) 157 | """ 158 | p = self.patch_embed.patch_size[0] 159 | assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0 160 | h = w = imgs.shape[2] // p 161 | x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p)) 162 | x = torch.einsum('nchpwq->nhwpqc', x) 163 | x = x.reshape(shape=(imgs.shape[0], h * w, p ** 2 * 3)) 164 | return x 165 | 166 | def unpatchify(self, x): 167 | """ 168 | x: (N, L, patch_size**2 *3) 169 | imgs: (N, 3, H, W) 170 | """ 171 | p = self.patch_embed.patch_size[0] 172 | h = w = int(x.shape[1] ** .5) 173 | assert h * w == x.shape[1] 174 | 175 | x = x.reshape(shape=(x.shape[0], h, w, p, p, 3)) 176 | x = torch.einsum('nhwpqc->nchpwq', x) 177 | imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p)) 178 | return imgs 179 | 180 | def generate_raster_permutation(self, N, L): 181 | """ 182 | Generate raster permutation 183 | small to large 184 | """ 185 | 186 | width = int(L ** 0.5) 187 | permutation = torch.zeros((N, width, width)) 188 | 189 | init_value = 0 190 | odd_row = torch.tensor([13 - i for i in range(width)]) 191 | even_row = torch.tensor([i for i in range(width)]) 192 | for i in range(width): 193 | if i % 2 == 0: 194 | permutation[:, i, :] = even_row + init_value 195 | else: 196 | permutation[:, i, :] = odd_row + init_value 197 | 198 | init_value += width 199 | 200 | # print(permutation) 201 | permutation = permutation.reshape(N, L) 202 | 203 | return permutation 204 | 205 | def generate_center_permutation(self, N, L, center_first=True): 206 | """ 207 | Generate center-out permutation 208 | small to large 209 | """ 210 | 211 | width = int(L ** 0.5) 212 | half_width = width // 2 213 | permutation = torch.rand((N, width, width)) 214 | 215 | if center_first: 216 | # center 6-7: (-3, -2) 217 | permutation[:, half_width-1:half_width+1, half_width-1:half_width+1] -= 1 218 | # surrounding 4-9 (-2 -1) 219 | permutation[:, half_width-3:half_width+3, half_width-3:half_width+3] -= 1 220 | # surrounding 2-11 (-1 -0) 221 | permutation[:, half_width-5:half_width+5, half_width-5:half_width+5] -= 1 222 | # surrounding 0-13 (0 1) 223 | # permutation[:, half_width-7:half_width+7, half_width-7:half_width+7] -= 1 224 | else: 225 | # center 6-7: (-3, -2) 226 | permutation[:, half_width-1:half_width+1, half_width-1:half_width+1] += 1 227 | # surrounding 4-9 (-2 -1) 228 | permutation[:, half_width-3:half_width+3, half_width-3:half_width+3] += 1 229 | # surrounding 2-11 (-1 -0) 230 | permutation[:, half_width-5:half_width+5, half_width-5:half_width+5] += 1 231 | # surrounding 0-13 (0 1) 232 | # permutation[:, half_width-7:half_width+7, half_width-7:half_width+7] += 1 233 | 234 | permutation = permutation.reshape(N, L) 235 | 236 | return permutation 237 | 238 | def generate_stochastic_center_permutation(self, N, L): 239 | """ 240 | Generate stochastic center permutation 241 | small to large 242 | """ 243 | 244 | width = int(L ** 0.5) 245 | permutation = 
torch.rand((N, width, width)) 246 | 247 | center_row, center_col = torch.rand((N)) * (width - 1), torch.rand((N)) * (width - 1) 248 | 249 | for i in range(N): 250 | row_split = self.split_matrix[int(center_row[i]), :, :] # 2x4 251 | col_split = self.split_matrix[int(center_col[i]), :, :] # 2x4 252 | for j in range(3): 253 | permutation[i, int(row_split[0][j]):int(row_split[1][j]), int(col_split[0][j]):int(col_split[1][j])] -= 1 254 | 255 | permutation = permutation.reshape(N, L) 256 | return permutation 257 | 258 | def generate_guided_center_permutation(self, attention_maps): 259 | """ 260 | Generate attention guided center permutation 261 | small to large 262 | """ 263 | 264 | N, L = attention_maps.shape 265 | width = int(L ** 0.5) 266 | permutation = torch.rand((N, width, width)) 267 | 268 | _, max_index = torch.max(attention_maps, dim=-1) 269 | center_row, center_col = max_index // width, max_index % width 270 | # attention_maps = attention_maps.reshape(N, width, width) 271 | 272 | for i in range(N): 273 | row_split = self.split_matrix[center_row[i], :, :] # 2x4 274 | col_split = self.split_matrix[center_col[i], :, :] # 2x4 275 | for j in range(3): 276 | permutation[i, int(row_split[0][j]):int(row_split[1][j]), int(col_split[0][j]):int(col_split[1][j])] -= 1 277 | 278 | permutation = permutation.reshape(N, L) 279 | return permutation 280 | 281 | def generate_attention_distance_center_permutation(self, attention_maps): 282 | """ 283 | Generate attention guided gaussian center permutation 284 | small to large 285 | """ 286 | 287 | N, L = attention_maps.shape 288 | width = int(L ** 0.5) 289 | 290 | _, max_index = torch.max(attention_maps, dim=-1) 291 | center_row, center_col = max_index // width, max_index % width 292 | 293 | # smaller distance to center, autoregression first 294 | self.coordinates = self.coordinates.cuda() 295 | permutation = (self.coordinates[:, :, 0] - center_row.unsqueeze(1)) ** 2 + (self.coordinates[:, :, 1] - center_col.unsqueeze(1)) ** 2 # N L 296 | permutation = permutation ** 0.5 297 | 298 | # add randomness for patches with the same distance 299 | permutation += torch.rand(N, L).cuda() * 1e-3 300 | 301 | return permutation 302 | 303 | def generate_attention_mask(self, x, attention_maps=None): 304 | """ 305 | Generate permutation mask(content mask and query mask) 306 | """ 307 | N, L, D = x.shape # batch, length, dim 308 | 309 | # generate permutation 310 | if self.permutation_type == 'zigzag': 311 | permutation = [i for i in range(L)] 312 | permutation = torch.tensor(permutation).repeat(N, 1).cuda() 313 | elif self.permutation_type == 'raster': 314 | permutation = self.generate_raster_permutation(N, L).cuda() 315 | elif self.permutation_type == 'stochastic': 316 | permutation = torch.rand(N, L, device=x.device) # noise in [0, 1] 317 | elif self.permutation_type == 'stochastic_center': 318 | permutation = self.generate_stochastic_center_permutation(N, L).cuda() 319 | elif self.permutation_type == 'center2out': 320 | permutation = self.generate_center_permutation(N, L).cuda() 321 | elif self.permutation_type == 'attention': 322 | assert attention_maps != None 323 | assert attention_maps.shape[1] == L 324 | permutation = 1 - attention_maps 325 | 326 | elif self.permutation_type == 'attention_guided': 327 | assert attention_maps != None 328 | assert attention_maps.shape[1] == L 329 | permutation = self.generate_guided_center_permutation(attention_maps).cuda() 330 | 331 | elif self.permutation_type == 'attention_center': 332 | assert attention_maps != None 333 | 
assert attention_maps.shape[1] == L 334 | permutation = self.generate_attention_distance_center_permutation(attention_maps) 335 | else: 336 | print("Not supported permutation type!") 337 | 338 | # content mask 339 | full_mask = torch.full((N, L, L), -math.inf, device=x.device) 340 | no_mask = torch.zeros((N, L, L), device=x.device) 341 | mask_h = torch.where(permutation.unsqueeze(-1) < permutation.unsqueeze(1), full_mask, no_mask) # broadcast-->N*L*L 342 | 343 | # query mask 344 | mask_g = torch.where(permutation.unsqueeze(-1) <= permutation.unsqueeze(1), full_mask, no_mask) 345 | 346 | # consider cls_token 347 | top_padding = torch.full((N, 1, L), -math.inf, device=x.device) # cls token can't see other tokens 348 | left_padding = torch.zeros((N, L + 1, 1), device=x.device) # other tokens can see cls token 349 | mask_h = torch.cat((top_padding, mask_h), dim=1) 350 | mask_h = torch.cat((left_padding, mask_h), dim=2) 351 | mask_g = torch.cat((top_padding, mask_g), dim=1) 352 | mask_g = torch.cat((left_padding, mask_g), dim=2) 353 | return mask_h.unsqueeze(1), mask_g.unsqueeze(1), permutation 354 | 355 | def forward_aim(self, x, attention_maps=None): 356 | 357 | # embed patches 358 | x = self.patch_embed(x) 359 | 360 | mask_h, mask_g, permutation = self.generate_attention_mask(x, attention_maps) 361 | 362 | # add pos embed w/o cls token 363 | x = x + self.pos_embed[:, 1:, :] 364 | # append cls token 365 | cls_token = self.cls_token + self.pos_embed[:, :1, :] 366 | cls_tokens = cls_token.expand(x.shape[0], -1, -1) 367 | x = torch.cat((cls_tokens, x), dim=1) 368 | # permutation mask 369 | h = x 370 | g = self.pos_embed.expand(x.shape[0], -1, -1) # use fixed pos-embedding, not learnable tensor 371 | for i in range(self.depth): 372 | h = self.blocks[i](h, mask=mask_h) 373 | if (i + 1) % self.step == 0: 374 | g = self.query_blocks[i // self.step](g, h, mask=mask_g) 375 | g = self.norm(g) 376 | g = self.prediction_head(g) 377 | 378 | return g, permutation 379 | 380 | def forward_aim_no_mask(self, x, attention_maps=None): 381 | 382 | # embed patches 383 | x = self.patch_embed(x) 384 | 385 | mask_h, mask_g, permutation = self.generate_attention_mask(x, attention_maps) 386 | 387 | # add pos embed w/o cls token 388 | x = x + self.pos_embed[:, 1:, :] 389 | # append cls token 390 | cls_token = self.cls_token + self.pos_embed[:, :1, :] 391 | cls_tokens = cls_token.expand(x.shape[0], -1, -1) 392 | x = torch.cat((cls_tokens, x), dim=1) 393 | # permutation mask 394 | h = x 395 | g = self.pos_embed.expand(x.shape[0], -1, -1) # use fixed pos-embedding, not learnable tensor 396 | for i in range(self.depth): 397 | h = self.blocks[i](h) 398 | if (i + 1) % self.step == 0: 399 | g = self.query_blocks[i // self.step](g, h) 400 | g = self.norm(g) 401 | g = self.prediction_head(g) 402 | 403 | return g, permutation 404 | 405 | 406 | ############################ for generate 407 | def generate_raster_permutation_for_generate(self, N, L): 408 | """ 409 | Generate raster permutation 410 | small to large 411 | """ 412 | 413 | width = int(L ** 0.5) 414 | permutation = torch.zeros((N, width, width)) 415 | 416 | init_value = 0 417 | odd_row = torch.tensor([13 - i for i in range(width)]) 418 | even_row = torch.tensor([i for i in range(width)]) 419 | for i in range(width): 420 | if i < width // 2: 421 | continue 422 | if i % 2 == 0: 423 | permutation[:, i, :] = even_row + init_value 424 | else: 425 | permutation[:, i, :] = odd_row + init_value 426 | 427 | init_value += width 428 | 429 | # print(permutation) 430 | permutation = 
permutation.reshape(N, L) 431 | 432 | return permutation 433 | 434 | def generate_center_permutation_for_generate(self, N, L): 435 | """ 436 | Generate center-out permutation 437 | small to large 438 | """ 439 | 440 | width = int(L ** 0.5) 441 | half_width = width // 2 442 | permutation = torch.zeros((N, width, width)) 443 | 444 | # center 6-7: (-3, -2) 445 | # permutation[:, half_width-1:half_width+1, half_width-1:half_width+1] -= 1 446 | # surrounding 4-9 (-2 -1) 447 | permutation[:, half_width-3:half_width+3, half_width-3:half_width+3] -= 1 448 | # surrounding 2-11 (-1 -0) 449 | permutation[:, half_width-5:half_width+5, half_width-5:half_width+5] -= 1 450 | # surrounding 0-13 (0 1) 451 | permutation[:, half_width-7:half_width+7, half_width-7:half_width+7] -= 1 452 | 453 | 454 | permutation = permutation.reshape(N, L) 455 | 456 | return permutation 457 | 458 | 459 | def generate_attention_mask_for_generate(self, x): 460 | """ 461 | Generate permutation mask(content mask and query mask) 462 | """ 463 | N, L, D = x.shape # batch, length, dim 464 | 465 | # generate permutation 466 | if self.permutation_type == 'raster': 467 | permutation = self.generate_raster_permutation_for_generate(N, L).cuda() 468 | elif self.permutation_type == 'center2out': 469 | permutation = self.generate_center_permutation_for_generate(N, L).cuda() 470 | else: 471 | print("Not supported permutation type!") 472 | 473 | # content mask 474 | full_mask = torch.full((N, L, L), -math.inf, device=x.device) 475 | no_mask = torch.zeros((N, L, L), device=x.device) 476 | 477 | # query mask 478 | mask_g = torch.where(permutation.unsqueeze(-1) < permutation.unsqueeze(1), full_mask, no_mask) 479 | 480 | # consider cls_token 481 | top_padding = torch.zeros((N, 1, L), device=x.device) # cls token can't see other tokens 482 | left_padding = torch.zeros((N, L + 1, 1), device=x.device) # other tokens can see cls token 483 | mask_g = torch.cat((top_padding, mask_g), dim=1) 484 | mask_g = torch.cat((left_padding, mask_g), dim=2) 485 | return mask_g.unsqueeze(1), permutation 486 | 487 | def forward_aim_for_generate(self, x): 488 | 489 | # embed patches 490 | x = self.patch_embed(x) 491 | 492 | mask_g, permutation = self.generate_attention_mask_for_generate(x) 493 | 494 | # add pos embed w/o cls token 495 | x = x + self.pos_embed[:, 1:, :] 496 | 497 | # take up half 498 | # x = x[:, :98, :] 499 | 500 | # append cls token 501 | cls_token = self.cls_token + self.pos_embed[:, :1, :] 502 | cls_tokens = cls_token.expand(x.shape[0], -1, -1) 503 | x = torch.cat((cls_tokens, x), dim=1) 504 | 505 | # encoder 506 | h = x 507 | g = self.pos_embed.expand(x.shape[0], -1, -1) 508 | for i in range(self.depth): 509 | h = self.blocks[i](h) 510 | if (i + 1) % self.step == 0: 511 | g = self.query_blocks[i // self.step](g, h, mask=mask_g) 512 | g = self.norm(g) 513 | g = self.prediction_head(g) 514 | 515 | return g, permutation 516 | 517 | def forward_encoder(self, x): 518 | # embed patches 519 | x = self.patch_embed(x) 520 | # add pos embed w/o cls token 521 | x = x + self.pos_embed[:, 1:, :] 522 | # append cls token 523 | cls_token = self.cls_token + self.pos_embed[:, :1, :] 524 | cls_tokens = cls_token.expand(x.shape[0], -1, -1) 525 | x = torch.cat((cls_tokens, x), dim=1) 526 | 527 | for i in range(len(self.blocks)-1): 528 | x = self.blocks[i](x) 529 | 530 | self_attention = self.blocks[len(self.blocks)-1](x, return_attention=True) # B H N+1 N+1 531 | self_attention = torch.mean(self_attention, dim=1)[:, 0, 1:] # B N 532 | 533 | x = 
self.blocks[len(self.blocks)-1](x) 534 | x = self.norm(x) 535 | 536 | # calculate attention 537 | if self.attention_type == 'gap': 538 | feature_attention = self.calculate_attention_gap(x) 539 | else: 540 | feature_attention = self.calculate_attention_cls(x) 541 | 542 | return x, feature_attention, self_attention 543 | 544 | def calculate_attention_cls(self, tokens): 545 | tokens = torch.nn.functional.normalize(tokens, p=2, dim=-1) 546 | attention = torch.sum(tokens[:, 0, :].unsqueeze(1) * tokens[:, 1:, :], dim=-1) 547 | 548 | attention = attention.softmax(dim=1) 549 | 550 | return attention 551 | 552 | def calculate_attention_gap(self, tokens): 553 | pth_gap = torch.mean(tokens[:, 1:, :], dim=1, keepdim=True) 554 | pth_gap = torch.nn.functional.normalize(pth_gap, p=2, dim=-1) 555 | tokens = torch.nn.functional.normalize(tokens, p=2, dim=-1) 556 | attention = torch.sum(pth_gap * tokens[:, 1:, :], dim=-1) 557 | 558 | attention = attention.softmax(dim=1) 559 | 560 | return attention 561 | 562 | def forward_pixel_loss(self, imgs, pred): 563 | imgs = self.gaussian_blur(imgs) 564 | target = self.patchify(imgs) 565 | pred = pred[:, 1:, :] 566 | if self.norm_pix_loss: 567 | mean = target.mean(dim=-1, keepdim=True) 568 | var = target.var(dim=-1, keepdim=True) 569 | target = (target - mean) / (var + 1.e-6) ** .5 570 | if self.loss_type == 'L1': 571 | loss = (pred - target).abs() 572 | elif self.loss_type == 'L2': 573 | loss = (pred - target) ** 2 574 | return loss.mean(), loss.mean(dim=-1) 575 | 576 | def forward_feature_loss(self, feature, pred): 577 | feature = feature[:, 1:, :] 578 | pred = pred[:, 1:, :] 579 | feature = torch.nn.functional.normalize(feature, p=2, dim=-1) 580 | pred = torch.nn.functional.normalize(pred, p=2, dim=-1) 581 | loss = ((pred - feature) ** 2).sum(dim=-1) 582 | return loss.mean(), loss 583 | 584 | def forward(self, imgs, tokens=None, attention_maps=None, forward_encoder=False): 585 | if forward_encoder: 586 | enc_tokens, feature_attention, self_attention = self.forward_encoder(imgs) 587 | return enc_tokens, feature_attention, self_attention 588 | 589 | pred, permutation = self.forward_aim(imgs, attention_maps) 590 | 591 | if self.predict_feature == 'none': 592 | loss, loss_map = self.forward_pixel_loss(imgs, pred) 593 | else: 594 | assert tokens != None 595 | loss, loss_map = self.forward_feature_loss(tokens, pred) 596 | return loss, permutation, loss_map 597 | 598 | def forward_for_visilization(self, imgs, attention_maps=None): 599 | pred, permutation = self.forward_aim(imgs, attention_maps) 600 | 601 | loss, loss_map = self.forward_pixel_loss(imgs, pred) 602 | imgs_blur = self.gaussian_blur(imgs) 603 | 604 | return loss, permutation, pred, imgs_blur 605 | 606 | 607 | def aim_base(**kwargs): 608 | return AimViT(patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, **kwargs) 609 | 610 | def aim_large(**kwargs): 611 | return AimViT(patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, **kwargs) 612 | 613 | def aim_huge(**kwargs): 614 | return AimViT(patch_size=16, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, **kwargs) 615 | 616 | 617 | if __name__ == '__main__': 618 | torch.manual_seed(2023) 619 | model = aim_base(img_size=224, norm_pix_loss=False, 620 | permutation_type='attention_center', 621 | prediction_head_type='MLP', loss_type='L2', 622 | query_depth=12, share_weight=False, 623 | gaussian_kernel_size=9, gaussian_sigma=1) 624 | model.eval() 625 | x = torch.rand(1, 3, 224, 224) 626 | print(model(x)) 627 | 
--------------------------------------------------------------------------------
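A minimal, self-contained sketch (not part of the repository) of how the content mask (mask_h) and query mask (mask_g) in AimViT.generate_attention_mask follow from a permutation, shown on a toy 2x2 patch grid with a zigzag ordering; the cls-token padding and the batched permutation generators are omitted, and the variable names below are illustrative only.

import math
import torch

N, L = 1, 4  # one image, a 2x2 grid of patches
permutation = torch.arange(L).float().repeat(N, 1)  # zigzag ordering 0, 1, 2, 3

full_mask = torch.full((N, L, L), -math.inf)
no_mask = torch.zeros((N, L, L))

# content mask: patch i may attend to patch j only if j is not predicted after i
mask_h = torch.where(permutation.unsqueeze(-1) < permutation.unsqueeze(1), full_mask, no_mask)
# query mask: only strictly earlier patches are visible, so the prediction target itself stays hidden
mask_g = torch.where(permutation.unsqueeze(-1) <= permutation.unsqueeze(1), full_mask, no_mask)

print(mask_h[0])  # -inf above the diagonal for this ordering
print(mask_g[0])  # -inf on and above the diagonal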