├── Results.png
├── image
│   ├── vis1.png
│   ├── vis2.png
│   └── Results.png
├── LICENSE
├── README.md
├── samplers.py
├── losses.py
├── engine.py
├── datasets.py
├── models.py
├── utils.py
├── main.py
└── KNN_VisionTransformer.py

/Results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/damo-cv/KVT/HEAD/Results.png
--------------------------------------------------------------------------------
/image/vis1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/damo-cv/KVT/HEAD/image/vis1.png
--------------------------------------------------------------------------------
/image/vis2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/damo-cv/KVT/HEAD/image/vis2.png
--------------------------------------------------------------------------------
/image/Results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/damo-cv/KVT/HEAD/image/Results.png
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2022 DamoCV

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# KVT

This repository contains PyTorch evaluation code, training code and pretrained models for the following project:
* K-NN Attention for Boosting Vision Transformers, ECCV 2022

For details see [K-NN Attention for Boosting Vision Transformers](https://arxiv.org/abs/2106.00515) by Pichao Wang, Xue Wang, Fan Wang, Ming Lin, Shuning Chang, Hao Li, Rong Jin.

The code is based on [DeiT](https://github.com/facebookresearch/deit).

## Results on ImageNet-1K


## Visualization

Self-attention heads from the last layer in Dino-small.


Images from different classes are visualized using the Transformer Attribution method on DeiT-Tiny.


# Usage

First, clone the repository locally:
```
git clone https://github.com/damo-cv/KVT.git
```
Then, install PyTorch 1.7.0+ and torchvision 0.8.1+ and [pytorch-image-models 0.4.12](https://github.com/rwightman/pytorch-image-models):

```
conda install -c pytorch pytorch torchvision
pip install timm==0.4.12
```

## Data preparation

Download and extract ImageNet train and val images from http://image-net.org/.
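Before launching training, it can help to confirm that the extracted folders match the `train/`/`val/` class-folder layout shown below. The snippet is a minimal, stdlib-only sketch and is not part of this repository: the helper names `check_imagefolder_layout` and `class_sets_match`, and the extension set, are assumptions for illustration. It counts the image files directly inside each class folder and checks that both splits contain the same class names, since `ImageFolder` derives label indices from the sorted folder names of each split independently.

```python
import sys
from pathlib import Path

# Assumed set of image extensions for this check; torchvision's ImageFolder
# accepts a broader list (e.g. .bmp, .tif, .webp) by default.
IMG_EXTENSIONS = {".jpeg", ".jpg", ".png"}


def check_imagefolder_layout(root):
    """Hypothetical helper: map each split ('train', 'val') to {class_name: image_count}.

    Only files directly inside each class folder are counted.
    Raises FileNotFoundError if a split folder is missing.
    """
    report = {}
    for split in ("train", "val"):
        split_dir = Path(root) / split
        if not split_dir.is_dir():
            raise FileNotFoundError(f"missing split folder: {split_dir}")
        counts = {}
        # Sorted, like ImageFolder, which assigns label indices from sorted folder names.
        for class_dir in sorted(p for p in split_dir.iterdir() if p.is_dir()):
            counts[class_dir.name] = sum(
                1 for f in class_dir.iterdir()
                if f.is_file() and f.suffix.lower() in IMG_EXTENSIONS
            )
        report[split] = counts
    return report


def class_sets_match(report):
    """Hypothetical helper: True if train/ and val/ contain the same class folder names."""
    return set(report["train"]) == set(report["val"])


if __name__ == "__main__" and len(sys.argv) > 1:
    layout = check_imagefolder_layout(sys.argv[1])
    for split, counts in layout.items():
        print(f"{split}: {len(counts)} classes, {sum(counts.values())} images")
    if not class_sets_match(layout):
        print("warning: train/ and val/ class folders differ; label indices will not line up")
```

Saved as a standalone script (the file name is up to you), it would be run as `python check_layout.py /path/to/imagenet`.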
The directory structure is the standard layout for the torchvision [`datasets.ImageFolder`](https://pytorch.org/docs/stable/torchvision/datasets.html#imagefolder), and the training and validation data are expected to be in the `train/` and `val/` folders, respectively:

```
/path/to/imagenet/
  train/
    class1/
      img1.jpeg
    class2/
      img2.jpeg
  val/
    class1/
      img3.jpeg
    class2/
      img4.jpeg
```

## Training
To train DeiT-KVT-tiny on ImageNet on a single node with 4 GPUs for 300 epochs, run:

DeiT-KVT-tiny
```
python -m torch.distributed.launch --nproc_per_node=4 --use_env main.py --model deit_tiny_patch16_224 --batch-size 256 --data-path /path/to/imagenet --output_dir /path/to/save
```

## Citation
If you use this code for a paper, please cite:

```
@article{wang2021kvt,
  title={Kvt: k-nn attention for boosting vision transformers},
  author={Wang, Pichao and Wang, Xue and Wang, Fan and Lin, Ming and Chang, Shuning and Xie, Wen and Li, Hao and Jin, Rong},
  journal={arXiv preprint arXiv:2106.00515},
  year={2021}
}
```

--------------------------------------------------------------------------------
/samplers.py:
--------------------------------------------------------------------------------
"""
This file is from https://github.com/facebookresearch/deit/blob/main/samplers.py.
"""
import torch
import torch.distributed as dist
import math


class RASampler(torch.utils.data.Sampler):
    """Sampler that restricts data loading to a subset of the dataset for distributed,
    with repeated augmentation.
    It ensures that each augmented version of a sample will be visible to a
    different process (GPU).
    Heavily based on torch.utils.data.DistributedSampler
    """

    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, num_repeats: int = 3):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        if num_repeats < 1:
            raise ValueError("num_repeats should be greater than 0")
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.num_repeats = num_repeats
        self.epoch = 0
        self.num_samples = int(math.ceil(len(self.dataset) * self.num_repeats / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        # self.num_selected_samples = int(math.ceil(len(self.dataset) / self.num_replicas))
        self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas))
        self.shuffle = shuffle

    def __iter__(self):
        if self.shuffle:
            # deterministically shuffle based on epoch
            g = torch.Generator()
            g.manual_seed(self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g)
        else:
            indices = torch.arange(start=0, end=len(self.dataset))

        # add extra samples to make it evenly divisible
        indices = torch.repeat_interleave(indices, repeats=self.num_repeats, dim=0).tolist()
        padding_size: int = self.total_size - len(indices)
        if padding_size > 0:
            indices += indices[:padding_size]
        assert len(indices) == self.total_size

        # subsample
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices[:self.num_selected_samples])

    def __len__(self):
        return self.num_selected_samples

    def set_epoch(self, epoch):
        self.epoch = epoch

--------------------------------------------------------------------------------
/losses.py:
--------------------------------------------------------------------------------
"""
This file is from https://github.com/facebookresearch/deit/blob/main/losses.py.
"""
import torch
from torch.nn import functional as F


class DistillationLoss(torch.nn.Module):
    """
    This module wraps a standard criterion and adds an extra knowledge distillation loss by
    taking a teacher model prediction and using it as additional supervision.
    """
    def __init__(self, base_criterion: torch.nn.Module, teacher_model: torch.nn.Module,
                 distillation_type: str, alpha: float, tau: float):
        super().__init__()
        self.base_criterion = base_criterion
        self.teacher_model = teacher_model
        assert distillation_type in ['none', 'soft', 'hard']
        self.distillation_type = distillation_type
        self.alpha = alpha
        self.tau = tau

    def forward(self, inputs, outputs, labels):
        """
        Args:
            inputs: The original inputs that are fed to the teacher model
            outputs: the outputs of the model to be trained. It is expected to be
                either a Tensor, or a Tuple[Tensor, Tensor], with the original output
                in the first position and the distillation predictions as the second output
            labels: the labels for the base criterion
        """
        outputs_kd = None
        if not isinstance(outputs, torch.Tensor):
            # assume that the model outputs a tuple of [outputs, outputs_kd]
            outputs, outputs_kd = outputs
        base_loss = self.base_criterion(outputs, labels)
        if self.distillation_type == 'none':
            return base_loss

        if outputs_kd is None:
            raise ValueError("When knowledge distillation is enabled, the model is "
                             "expected to return a Tuple[Tensor, Tensor] with the output of the "
                             "class_token and the dist_token")
        # don't backprop through the teacher
        with torch.no_grad():
            teacher_outputs = self.teacher_model(inputs)

        if self.distillation_type == 'soft':
            T = self.tau
            # taken from https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
            # with slight modifications
            distillation_loss = F.kl_div(
                F.log_softmax(outputs_kd / T, dim=1),
                # We provide the teacher's targets in log probability because we use log_target=True
                # (as recommended in pytorch https://github.com/pytorch/pytorch/blob/9324181d0ac7b4f7949a574dbc3e8be30abe7041/torch/nn/functional.py#L2719)
                # but it is possible to give just the probabilities and set log_target=False. In our experiments we tried both.
                F.log_softmax(teacher_outputs / T, dim=1),
                reduction='sum',
                log_target=True
            ) * (T * T) / outputs_kd.numel()
            # We divide by outputs_kd.numel() to have the legacy PyTorch behavior.
            # But we also experimented with outputs_kd.size(0)
            # see issue 61 (https://github.com/facebookresearch/deit/issues/61) for more details
        elif self.distillation_type == 'hard':
            distillation_loss = F.cross_entropy(outputs_kd, teacher_outputs.argmax(dim=1))

        loss = base_loss * (1 - self.alpha) + distillation_loss * self.alpha
        return loss

--------------------------------------------------------------------------------
/engine.py:
--------------------------------------------------------------------------------
"""
This file is from https://github.com/facebookresearch/deit/blob/main/engine.py.
"""
import math
import sys
from typing import Iterable, Optional

import torch

from timm.data import Mixup
from timm.utils import accuracy, ModelEma

from losses import DistillationLoss
import utils


def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
                    model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None,
                    set_training_mode=True):
    model.train(set_training_mode)
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 10

    for samples, targets in metric_logger.log_every(data_loader, print_freq, header):
        samples = samples.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        if mixup_fn is not None:
            samples, targets = mixup_fn(samples, targets)

        with torch.cuda.amp.autocast():
            outputs = model(samples)
            loss = criterion(samples, outputs, targets)

        loss_value = loss.item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        optimizer.zero_grad()

        # this attribute is added by timm on one optimizer (adahessian)
        is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
        loss_scaler(loss, optimizer, clip_grad=max_norm,
                    parameters=model.parameters(), create_graph=is_second_order)

        torch.cuda.synchronize()
        if model_ema is not None:
            model_ema.update(model)

        metric_logger.update(loss=loss_value)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}


@torch.no_grad()
def evaluate(data_loader, model, device):
    criterion = torch.nn.CrossEntropyLoss()

    metric_logger = utils.MetricLogger(delimiter="  ")
    header = 'Test:'

    # switch to evaluation mode
    model.eval()

    for images, target in metric_logger.log_every(data_loader, 10, header):
        images = images.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)

        # compute output
        with torch.cuda.amp.autocast():
            output = model(images)
            loss = criterion(output, target)

        acc1, acc5 = accuracy(output, target, topk=(1, 5))

        batch_size = images.shape[0]
        metric_logger.update(loss=loss.item())
        metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
        metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
          .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))

    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}

--------------------------------------------------------------------------------
/datasets.py:
--------------------------------------------------------------------------------
"""
This file is from https://github.com/facebookresearch/deit/blob/main/datasets.py.
"""

import os
import json

from torchvision import datasets, transforms
from torchvision.datasets.folder import ImageFolder, default_loader

from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.data import create_transform


class INatDataset(ImageFolder):
    def __init__(self, root, train=True, year=2018, transform=None, target_transform=None,
                 category='name', loader=default_loader):
        self.transform = transform
        self.loader = loader
        self.target_transform = target_transform
        self.year = year
        # assert category in ['kingdom','phylum','class','order','supercategory','family','genus','name']
        path_json = os.path.join(root, f'{"train" if train else "val"}{year}.json')
        with open(path_json) as json_file:
            data = json.load(json_file)

        with open(os.path.join(root, 'categories.json')) as json_file:
            data_catg = json.load(json_file)

        path_json_for_targeter = os.path.join(root, f"train{year}.json")

        with open(path_json_for_targeter) as json_file:
            data_for_targeter = json.load(json_file)

        targeter = {}
        indexer = 0
        for elem in data_for_targeter['annotations']:
            king = []
            king.append(data_catg[int(elem['category_id'])][category])
            if king[0] not in targeter.keys():
                targeter[king[0]] = indexer
                indexer += 1
        self.nb_classes = len(targeter)

        self.samples = []
        for elem in data['images']:
            cut = elem['file_name'].split('/')
            target_current = int(cut[2])
            path_current = os.path.join(root, cut[0], cut[2], cut[3])

            categors = data_catg[target_current]
            target_current_true = targeter[categors[category]]
            self.samples.append((path_current, target_current_true))

    # __getitem__ and __len__ inherited from ImageFolder


def build_dataset(is_train, args):
    transform = build_transform(is_train, args)

    if args.data_set == 'CIFAR':
        dataset = datasets.CIFAR100(args.data_path, train=is_train, transform=transform)
        nb_classes = 100
    elif args.data_set == 'IMNET':
        root = os.path.join(args.data_path, 'train' if is_train else 'val')
        dataset = datasets.ImageFolder(root, transform=transform)
        nb_classes = 1000
    elif args.data_set == 'INAT':
        dataset = INatDataset(args.data_path, train=is_train, year=2018,
                              category=args.inat_category, transform=transform)
        nb_classes = dataset.nb_classes
    elif args.data_set == 'INAT19':
        dataset = INatDataset(args.data_path, train=is_train, year=2019,
                              category=args.inat_category, transform=transform)
        nb_classes = dataset.nb_classes

    return dataset, nb_classes


def build_transform(is_train, args):
    resize_im = args.input_size > 32
    if is_train:
        # this should always dispatch to transforms_imagenet_train
        transform = create_transform(
            input_size=args.input_size,
            is_training=True,
            color_jitter=args.color_jitter,
            auto_augment=args.aa,
            interpolation=args.train_interpolation,
            re_prob=args.reprob,
            re_mode=args.remode,
            re_count=args.recount,
        )
        if not resize_im:
            # replace RandomResizedCropAndInterpolation with
            # RandomCrop
            transform.transforms[0] = transforms.RandomCrop(
                args.input_size, padding=4)
        return transform

    t = []
    if resize_im:
        size = int((256 / 224) * args.input_size)
        t.append(
            transforms.Resize(size, interpolation=3),  # to maintain same ratio w.r.t. 224 images
        )
        t.append(transforms.CenterCrop(args.input_size))

    t.append(transforms.ToTensor())
    t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))
    return transforms.Compose(t)

--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
"""
This file is from https://github.com/facebookresearch/deit/blob/main/models.py.
"""
import torch
import torch.nn as nn
from functools import partial

from timm.models.vision_transformer import VisionTransformer, _cfg
from timm.models.registry import register_model
from timm.models.layers import trunc_normal_


__all__ = [
    'deit_tiny_patch16_224', 'deit_small_patch16_224', 'deit_base_patch16_224',
    'deit_tiny_distilled_patch16_224', 'deit_small_distilled_patch16_224',
    'deit_base_distilled_patch16_224', 'deit_base_patch16_384',
    'deit_base_distilled_patch16_384',
]


class DistilledVisionTransformer(VisionTransformer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
        num_patches = self.patch_embed.num_patches
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, self.embed_dim))
        self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity()

        trunc_normal_(self.dist_token, std=.02)
        trunc_normal_(self.pos_embed, std=.02)
        self.head_dist.apply(self._init_weights)

    def forward_features(self, x):
        # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
        # with slight modifications to add the dist_token
        B = x.shape[0]
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        dist_token = self.dist_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, dist_token, x), dim=1)

        x = x + self.pos_embed
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        x = self.norm(x)
        return x[:, 0], x[:, 1]

    def forward(self, x):
        x, x_dist = self.forward_features(x)
        x = self.head(x)
        x_dist = self.head_dist(x_dist)
        if self.training:
            return x, x_dist
        else:
            # during inference, return the average of both classifier predictions
            return (x + x_dist) / 2


@register_model
def deit_tiny_patch16_224(pretrained=False, **kwargs):
    model = VisionTransformer(
        patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth",
            map_location="cpu", check_hash=True
        )
        model.load_state_dict(checkpoint["model"])
    return model


@register_model
def deit_small_patch16_224(pretrained=False, **kwargs):
    model = VisionTransformer(
        patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth",
            map_location="cpu", check_hash=True
        )
        model.load_state_dict(checkpoint["model"])
    return model


@register_model
def deit_base_patch16_224(pretrained=False, **kwargs):
    model = VisionTransformer(
        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth",
            map_location="cpu", check_hash=True
        )
        model.load_state_dict(checkpoint["model"])
    return model


@register_model
def deit_tiny_distilled_patch16_224(pretrained=False, **kwargs):
    model = DistilledVisionTransformer(
        patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth",
            map_location="cpu", check_hash=True
        )
        model.load_state_dict(checkpoint["model"])
    return model


@register_model
def deit_small_distilled_patch16_224(pretrained=False, **kwargs):
    model = DistilledVisionTransformer(
        patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth",
            map_location="cpu", check_hash=True
        )
        model.load_state_dict(checkpoint["model"])
    return model


@register_model
def deit_base_distilled_patch16_224(pretrained=False, **kwargs):
    model = DistilledVisionTransformer(
        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth",
            map_location="cpu", check_hash=True
        )
        model.load_state_dict(checkpoint["model"])
    return model


@register_model
def deit_base_patch16_384(pretrained=False, **kwargs):
    model = VisionTransformer(
        img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth",
            map_location="cpu", check_hash=True
        )
        model.load_state_dict(checkpoint["model"])
    return model


@register_model
def deit_base_distilled_patch16_384(pretrained=False, **kwargs):
    model = DistilledVisionTransformer(
        img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth",
            map_location="cpu", check_hash=True
        )
        model.load_state_dict(checkpoint["model"])
    return model

--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
"""
This file is from https://github.com/facebookresearch/deit/blob/main/utils.py.
"""
import io
import os
import time
from collections import defaultdict, deque
import datetime

import torch
import torch.distributed as dist


class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """

    def __init__(self, window_size=20, fmt=None):
        if fmt is None:
            fmt = "{median:.4f} ({global_avg:.4f})"
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
        """
        if not is_dist_avail_and_initialized():
            return
        t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
        dist.barrier()
        dist.all_reduce(t)
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

    @property
    def median(self):
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value)


class MetricLogger(object):
    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
85 | assert isinstance(v, (float, int)) 86 | self.meters[k].update(v) 87 | 88 | def __getattr__(self, attr): 89 | if attr in self.meters: 90 | return self.meters[attr] 91 | if attr in self.__dict__: 92 | return self.__dict__[attr] 93 | raise AttributeError("'{}' object has no attribute '{}'".format( 94 | type(self).__name__, attr)) 95 | 96 | def __str__(self): 97 | loss_str = [] 98 | for name, meter in self.meters.items(): 99 | loss_str.append( 100 | "{}: {}".format(name, str(meter)) 101 | ) 102 | return self.delimiter.join(loss_str) 103 | 104 | def synchronize_between_processes(self): 105 | for meter in self.meters.values(): 106 | meter.synchronize_between_processes() 107 | 108 | def add_meter(self, name, meter): 109 | self.meters[name] = meter 110 | 111 | def log_every(self, iterable, print_freq, header=None): 112 | i = 0 113 | if not header: 114 | header = '' 115 | start_time = time.time() 116 | end = time.time() 117 | iter_time = SmoothedValue(fmt='{avg:.4f}') 118 | data_time = SmoothedValue(fmt='{avg:.4f}') 119 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 120 | log_msg = [ 121 | header, 122 | '[{0' + space_fmt + '}/{1}]', 123 | 'eta: {eta}', 124 | '{meters}', 125 | 'time: {time}', 126 | 'data: {data}' 127 | ] 128 | if torch.cuda.is_available(): 129 | log_msg.append('max mem: {memory:.0f}') 130 | log_msg = self.delimiter.join(log_msg) 131 | MB = 1024.0 * 1024.0 132 | for obj in iterable: 133 | data_time.update(time.time() - end) 134 | yield obj 135 | iter_time.update(time.time() - end) 136 | if i % print_freq == 0 or i == len(iterable) - 1: 137 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 138 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 139 | if torch.cuda.is_available(): 140 | print(log_msg.format( 141 | i, len(iterable), eta=eta_string, 142 | meters=str(self), 143 | time=str(iter_time), data=str(data_time), 144 | memory=torch.cuda.max_memory_allocated() / MB)) 145 | else: 146 | print(log_msg.format( 147 | i, 
len(iterable), eta=eta_string, 148 | meters=str(self), 149 | time=str(iter_time), data=str(data_time))) 150 | i += 1 151 | end = time.time() 152 | total_time = time.time() - start_time 153 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 154 | print('{} Total time: {} ({:.4f} s / it)'.format( 155 | header, total_time_str, total_time / len(iterable))) 156 | 157 | 158 | def _load_checkpoint_for_ema(model_ema, checkpoint): 159 | """ 160 | Workaround for ModelEma._load_checkpoint to accept an already-loaded object 161 | """ 162 | mem_file = io.BytesIO() 163 | torch.save(checkpoint, mem_file) 164 | mem_file.seek(0) 165 | model_ema._load_checkpoint(mem_file) 166 | 167 | 168 | def setup_for_distributed(is_master): 169 | """ 170 | This function disables printing when not in master process 171 | """ 172 | import builtins as __builtin__ 173 | builtin_print = __builtin__.print 174 | 175 | def print(*args, **kwargs): 176 | force = kwargs.pop('force', False) 177 | if is_master or force: 178 | builtin_print(*args, **kwargs) 179 | 180 | __builtin__.print = print 181 | 182 | 183 | def is_dist_avail_and_initialized(): 184 | if not dist.is_available(): 185 | return False 186 | if not dist.is_initialized(): 187 | return False 188 | return True 189 | 190 | 191 | def get_world_size(): 192 | if not is_dist_avail_and_initialized(): 193 | return 1 194 | return dist.get_world_size() 195 | 196 | 197 | def get_rank(): 198 | if not is_dist_avail_and_initialized(): 199 | return 0 200 | return dist.get_rank() 201 | 202 | 203 | def is_main_process(): 204 | return get_rank() == 0 205 | 206 | 207 | def save_on_master(*args, **kwargs): 208 | if is_main_process(): 209 | torch.save(*args, **kwargs) 210 | 211 | 212 | def init_distributed_mode(args): 213 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 214 | args.rank = int(os.environ["RANK"]) 215 | args.world_size = int(os.environ['WORLD_SIZE']) 216 | args.gpu = int(os.environ['LOCAL_RANK']) 217 | elif 'SLURM_PROCID' in 
os.environ: 218 | args.rank = int(os.environ['SLURM_PROCID']) 219 | args.gpu = args.rank % torch.cuda.device_count() 220 | else: 221 | print('Not using distributed mode') 222 | args.distributed = False 223 | return 224 | 225 | args.distributed = True 226 | 227 | torch.cuda.set_device(args.gpu) 228 | args.dist_backend = 'nccl' 229 | print('| distributed init (rank {}): {}'.format( 230 | args.rank, args.dist_url), flush=True) 231 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 232 | world_size=args.world_size, rank=args.rank) 233 | torch.distributed.barrier() 234 | setup_for_distributed(args.rank == 0) 235 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file is from https://github.com/facebookresearch/deit/blob/main/main.py. 3 | """ 4 | import argparse 5 | import datetime 6 | import numpy as np 7 | import time 8 | import torch 9 | import torch.backends.cudnn as cudnn 10 | import json 11 | 12 | from pathlib import Path 13 | 14 | from timm.data import Mixup 15 | from timm.models import create_model 16 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy 17 | from timm.scheduler import create_scheduler 18 | from timm.optim import create_optimizer 19 | from timm.utils import NativeScaler, get_state_dict, ModelEma 20 | 21 | from datasets import build_dataset 22 | from engine import train_one_epoch, evaluate 23 | from losses import DistillationLoss 24 | from samplers import RASampler 25 | import models 26 | import utils 27 | 28 | 29 | def get_args_parser(): 30 | parser = argparse.ArgumentParser('DeiT training and evaluation script', add_help=False) 31 | parser.add_argument('--batch-size', default=64, type=int) 32 | parser.add_argument('--epochs', default=300, type=int) 33 | 34 | # Model parameters 35 | parser.add_argument('--model', default='deit_base_patch16_224', 
type=str, metavar='MODEL', 36 | help='Name of model to train') 37 | parser.add_argument('--input-size', default=224, type=int, help='images input size') 38 | 39 | parser.add_argument('--drop', type=float, default=0.0, metavar='PCT', 40 | help='Dropout rate (default: 0.)') 41 | parser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT', 42 | help='Drop path rate (default: 0.1)') 43 | 44 | parser.add_argument('--model-ema', action='store_true') 45 | parser.add_argument('--no-model-ema', action='store_false', dest='model_ema') 46 | parser.set_defaults(model_ema=True) 47 | parser.add_argument('--model-ema-decay', type=float, default=0.99996, help='') 48 | parser.add_argument('--model-ema-force-cpu', action='store_true', default=False, help='') 49 | 50 | # Optimizer parameters 51 | parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER', 52 | help='Optimizer (default: "adamw")') 53 | parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON', 54 | help='Optimizer Epsilon (default: 1e-8)') 55 | parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA', 56 | help='Optimizer Betas (default: None, use opt default)') 57 | parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM', 58 | help='Clip gradient norm (default: None, no clipping)') 59 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', 60 | help='SGD momentum (default: 0.9)') 61 | parser.add_argument('--weight-decay', type=float, default=0.05, 62 | help='weight decay (default: 0.05)') 63 | # Learning rate schedule parameters 64 | parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER', 65 | help='LR scheduler (default: "cosine")') 66 | parser.add_argument('--lr', type=float, default=5e-4, metavar='LR', 67 | help='learning rate (default: 5e-4)') 68 | parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct', 69 | help='learning rate noise on/off
epoch percentages') 70 | parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT', 71 | help='learning rate noise limit percent (default: 0.67)') 72 | parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV', 73 | help='learning rate noise std-dev (default: 1.0)') 74 | parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR', 75 | help='warmup learning rate (default: 1e-6)') 76 | parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR', 77 | help='lower lr bound for cyclic schedulers that hit 0 (default: 1e-5)') 78 | 79 | parser.add_argument('--decay-epochs', type=float, default=30, metavar='N', 80 | help='epoch interval to decay LR') 81 | parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N', 82 | help='epochs to warmup LR, if scheduler supports') 83 | parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N', 84 | help='epochs to cooldown LR at min_lr, after cyclic schedule ends') 85 | parser.add_argument('--patience-epochs', type=int, default=10, metavar='N', 86 | help='patience epochs for Plateau LR scheduler (default: 10)') 87 | parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE', 88 | help='LR decay rate (default: 0.1)') 89 | 90 | # Augmentation parameters 91 | parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT', 92 | help='Color jitter factor (default: 0.4)') 93 | parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME', 94 | help='Use AutoAugment policy. "v0" or "original".
" + \ 95 | "(default: rand-m9-mstd0.5-inc1)'), 96 | parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)') 97 | parser.add_argument('--train-interpolation', type=str, default='bicubic', 98 | help='Training interpolation (random, bilinear, bicubic default: "bicubic")') 99 | 100 | parser.add_argument('--repeated-aug', action='store_true') 101 | parser.add_argument('--no-repeated-aug', action='store_false', dest='repeated_aug') 102 | parser.set_defaults(repeated_aug=True) 103 | 104 | # * Random Erase params 105 | parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT', 106 | help='Random erase prob (default: 0.25)') 107 | parser.add_argument('--remode', type=str, default='pixel', 108 | help='Random erase mode (default: "pixel")') 109 | parser.add_argument('--recount', type=int, default=1, 110 | help='Random erase count (default: 1)') 111 | parser.add_argument('--resplit', action='store_true', default=False, 112 | help='Do not random erase first (clean) augmentation split') 113 | 114 | # * Mixup params 115 | parser.add_argument('--mixup', type=float, default=0.8, 116 | help='mixup alpha, mixup enabled if > 0. (default: 0.8)') 117 | parser.add_argument('--cutmix', type=float, default=1.0, 118 | help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)') 119 | parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None, 120 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)') 121 | parser.add_argument('--mixup-prob', type=float, default=1.0, 122 | help='Probability of performing mixup or cutmix when either/both is enabled') 123 | parser.add_argument('--mixup-switch-prob', type=float, default=0.5, 124 | help='Probability of switching to cutmix when both mixup and cutmix enabled') 125 | parser.add_argument('--mixup-mode', type=str, default='batch', 126 | help='How to apply mixup/cutmix params. 
Per "batch", "pair", or "elem"') 127 | 128 | # Distillation parameters 129 | parser.add_argument('--teacher-model', default='regnety_160', type=str, metavar='MODEL', 130 | help='Name of teacher model used for distillation (default: "regnety_160")') 131 | parser.add_argument('--teacher-path', type=str, default='') 132 | parser.add_argument('--distillation-type', default='none', choices=['none', 'soft', 'hard'], type=str, help="") 133 | parser.add_argument('--distillation-alpha', default=0.5, type=float, help="") 134 | parser.add_argument('--distillation-tau', default=1.0, type=float, help="") 135 | 136 | # * Finetuning params 137 | parser.add_argument('--finetune', default='', help='finetune from checkpoint') 138 | 139 | # Dataset parameters 140 | parser.add_argument('--data-path', default='/datasets01/imagenet_full_size/061417/', type=str, 141 | help='dataset path') 142 | parser.add_argument('--data-set', default='IMNET', choices=['CIFAR', 'IMNET', 'INAT', 'INAT19'], 143 | type=str, help='dataset to use') 144 | parser.add_argument('--inat-category', default='name', 145 | choices=['kingdom', 'phylum', 'class', 'order', 'supercategory', 'family', 'genus', 'name'], 146 | type=str, help='semantic granularity') 147 | 148 | parser.add_argument('--output_dir', default='', 149 | help='path where to save, empty for no saving') 150 | parser.add_argument('--device', default='cuda', 151 | help='device to use for training / testing') 152 | parser.add_argument('--seed', default=0, type=int) 153 | parser.add_argument('--resume', default='', help='resume from checkpoint') 154 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 155 | help='start epoch') 156 | parser.add_argument('--eval', action='store_true', help='Perform evaluation only') 157 | parser.add_argument('--dist-eval', action='store_true', default=False, help='Enable distributed evaluation') 158 | parser.add_argument('--num_workers', default=10, type=int) 159 | parser.add_argument('--pin-mem',
action='store_true', 160 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 161 | parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem', 162 | help='') 163 | parser.set_defaults(pin_mem=True) 164 | 165 | # distributed training parameters 166 | parser.add_argument('--world_size', default=1, type=int, 167 | help='number of distributed processes') 168 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') 169 | return parser 170 | 171 | 172 | def main(args): 173 | utils.init_distributed_mode(args) 174 | 175 | print(args) 176 | 177 | if args.distillation_type != 'none' and args.finetune and not args.eval: 178 | raise NotImplementedError("Finetuning with distillation not yet supported") 179 | 180 | device = torch.device(args.device) 181 | 182 | # fix the seed for reproducibility 183 | seed = args.seed + utils.get_rank() 184 | torch.manual_seed(seed) 185 | np.random.seed(seed) 186 | # random.seed(seed) 187 | 188 | cudnn.benchmark = True 189 | 190 | dataset_train, args.nb_classes = build_dataset(is_train=True, args=args) 191 | dataset_val, _ = build_dataset(is_train=False, args=args) 192 | 193 | if True: # args.distributed: 194 | num_tasks = utils.get_world_size() 195 | global_rank = utils.get_rank() 196 | if args.repeated_aug: 197 | sampler_train = RASampler( 198 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 199 | ) 200 | else: 201 | sampler_train = torch.utils.data.DistributedSampler( 202 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 203 | ) 204 | if args.dist_eval: 205 | if len(dataset_val) % num_tasks != 0: 206 | print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. 
' 207 | 'This will slightly alter validation results as extra duplicate entries are added to achieve ' 208 | 'equal num of samples per-process.') 209 | sampler_val = torch.utils.data.DistributedSampler( 210 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False) 211 | else: 212 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 213 | else: 214 | sampler_train = torch.utils.data.RandomSampler(dataset_train) 215 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 216 | 217 | data_loader_train = torch.utils.data.DataLoader( 218 | dataset_train, sampler=sampler_train, 219 | batch_size=args.batch_size, 220 | num_workers=args.num_workers, 221 | pin_memory=args.pin_mem, 222 | drop_last=True, 223 | ) 224 | 225 | data_loader_val = torch.utils.data.DataLoader( 226 | dataset_val, sampler=sampler_val, 227 | batch_size=int(1.5 * args.batch_size), 228 | num_workers=args.num_workers, 229 | pin_memory=args.pin_mem, 230 | drop_last=False 231 | ) 232 | 233 | mixup_fn = None 234 | mixup_active = args.mixup > 0 or args.cutmix > 0. 
or args.cutmix_minmax is not None 235 | if mixup_active: 236 | mixup_fn = Mixup( 237 | mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, 238 | prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, 239 | label_smoothing=args.smoothing, num_classes=args.nb_classes) 240 | 241 | print(f"Creating model: {args.model}") 242 | model = create_model( 243 | args.model, 244 | pretrained=False, 245 | num_classes=args.nb_classes, 246 | drop_rate=args.drop, 247 | drop_path_rate=args.drop_path, 248 | drop_block_rate=None, 249 | ) 250 | 251 | if args.finetune: 252 | if args.finetune.startswith('https'): 253 | checkpoint = torch.hub.load_state_dict_from_url( 254 | args.finetune, map_location='cpu', check_hash=True) 255 | else: 256 | checkpoint = torch.load(args.finetune, map_location='cpu') 257 | 258 | checkpoint_model = checkpoint['model'] 259 | state_dict = model.state_dict() 260 | for k in ['head.weight', 'head.bias', 'head_dist.weight', 'head_dist.bias']: 261 | if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape: 262 | print(f"Removing key {k} from pretrained checkpoint") 263 | del checkpoint_model[k] 264 | 265 | # interpolate position embedding 266 | pos_embed_checkpoint = checkpoint_model['pos_embed'] 267 | embedding_size = pos_embed_checkpoint.shape[-1] 268 | num_patches = model.patch_embed.num_patches 269 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches 270 | # height (== width) for the checkpoint position embedding 271 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 272 | # height (== width) for the new position embedding 273 | new_size = int(num_patches ** 0.5) 274 | # class_token and dist_token are kept unchanged 275 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 276 | # only the position tokens are interpolated 277 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 278 | pos_tokens = pos_tokens.reshape(-1, orig_size, 
orig_size, embedding_size).permute(0, 3, 1, 2) 279 | pos_tokens = torch.nn.functional.interpolate( 280 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 281 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 282 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 283 | checkpoint_model['pos_embed'] = new_pos_embed 284 | 285 | model.load_state_dict(checkpoint_model, strict=False) 286 | 287 | model.to(device) 288 | 289 | model_ema = None 290 | if args.model_ema: 291 | # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper 292 | model_ema = ModelEma( 293 | model, 294 | decay=args.model_ema_decay, 295 | device='cpu' if args.model_ema_force_cpu else '', 296 | resume='') 297 | 298 | model_without_ddp = model 299 | if args.distributed: 300 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 301 | model_without_ddp = model.module 302 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 303 | print('number of params:', n_parameters) 304 | 305 | linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size() / 512.0 306 | args.lr = linear_scaled_lr 307 | optimizer = create_optimizer(args, model_without_ddp) 308 | loss_scaler = NativeScaler() 309 | 310 | lr_scheduler, _ = create_scheduler(args, optimizer) 311 | 312 | criterion = LabelSmoothingCrossEntropy() 313 | 314 | if mixup_active: 315 | # smoothing is handled with mixup label transform 316 | criterion = SoftTargetCrossEntropy() 317 | elif args.smoothing: 318 | criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing) 319 | else: 320 | criterion = torch.nn.CrossEntropyLoss() 321 | 322 | teacher_model = None 323 | if args.distillation_type != 'none': 324 | assert args.teacher_path, 'need to specify teacher-path when using distillation' 325 | print(f"Creating teacher model: {args.teacher_model}") 326 | teacher_model = create_model( 327 | args.teacher_model, 328 | 
pretrained=False, 329 | num_classes=args.nb_classes, 330 | global_pool='avg', 331 | ) 332 | if args.teacher_path.startswith('https'): 333 | checkpoint = torch.hub.load_state_dict_from_url( 334 | args.teacher_path, map_location='cpu', check_hash=True) 335 | else: 336 | checkpoint = torch.load(args.teacher_path, map_location='cpu') 337 | teacher_model.load_state_dict(checkpoint['model']) 338 | teacher_model.to(device) 339 | teacher_model.eval() 340 | 341 | # wrap the criterion in our custom DistillationLoss, which 342 | # just dispatches to the original criterion if args.distillation_type is 'none' 343 | criterion = DistillationLoss( 344 | criterion, teacher_model, args.distillation_type, args.distillation_alpha, args.distillation_tau 345 | ) 346 | 347 | output_dir = Path(args.output_dir) 348 | if args.resume: 349 | if args.resume.startswith('https'): 350 | checkpoint = torch.hub.load_state_dict_from_url( 351 | args.resume, map_location='cpu', check_hash=True) 352 | else: 353 | checkpoint = torch.load(args.resume, map_location='cpu') 354 | model_without_ddp.load_state_dict(checkpoint['model']) 355 | if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint: 356 | optimizer.load_state_dict(checkpoint['optimizer']) 357 | lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) 358 | args.start_epoch = checkpoint['epoch'] + 1 359 | if args.model_ema: 360 | utils._load_checkpoint_for_ema(model_ema, checkpoint['model_ema']) 361 | if 'scaler' in checkpoint: 362 | loss_scaler.load_state_dict(checkpoint['scaler']) 363 | lr_scheduler.step(args.start_epoch) 364 | if args.eval: 365 | test_stats = evaluate(data_loader_val, model, device) 366 | print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") 367 | return 368 | 369 | print(f"Start training for {args.epochs} epochs") 370 | start_time = time.time() 371 | max_accuracy = 0.0 372 | for epoch in range(args.start_epoch, args.epochs): 373 | 
if args.distributed: 374 | data_loader_train.sampler.set_epoch(epoch) 375 | 376 | train_stats = train_one_epoch( 377 | model, criterion, data_loader_train, 378 | optimizer, device, epoch, loss_scaler, 379 | args.clip_grad, model_ema, mixup_fn, 380 | set_training_mode=args.finetune == '' # keep in eval mode during finetuning 381 | ) 382 | 383 | lr_scheduler.step(epoch) 384 | if args.output_dir: 385 | checkpoint_paths = [output_dir / 'checkpoint.pth'] 386 | for checkpoint_path in checkpoint_paths: 387 | utils.save_on_master({ 388 | 'model': model_without_ddp.state_dict(), 389 | 'optimizer': optimizer.state_dict(), 390 | 'lr_scheduler': lr_scheduler.state_dict(), 391 | 'epoch': epoch, 392 | 'model_ema': get_state_dict(model_ema), 393 | 'scaler': loss_scaler.state_dict(), 394 | 'args': args, 395 | }, checkpoint_path) 396 | 397 | 398 | test_stats = evaluate(data_loader_val, model, device) 399 | print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") 400 | 401 | if max_accuracy < test_stats["acc1"]: 402 | max_accuracy = test_stats["acc1"] 403 | if args.output_dir: 404 | checkpoint_paths = [output_dir / 'best_checkpoint.pth'] 405 | for checkpoint_path in checkpoint_paths: 406 | utils.save_on_master({ 407 | 'model': model_without_ddp.state_dict(), 408 | 'optimizer': optimizer.state_dict(), 409 | 'lr_scheduler': lr_scheduler.state_dict(), 410 | 'epoch': epoch, 411 | 'model_ema': get_state_dict(model_ema), 412 | 'scaler': loss_scaler.state_dict(), 413 | 'args': args, 414 | }, checkpoint_path) 415 | 416 | print(f'Max accuracy: {max_accuracy:.2f}%') 417 | 418 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 419 | **{f'test_{k}': v for k, v in test_stats.items()}, 420 | 'epoch': epoch, 421 | 'n_parameters': n_parameters} 422 | 423 | 424 | 425 | 426 | if args.output_dir and utils.is_main_process(): 427 | with (output_dir / "log.txt").open("a") as f: 428 | f.write(json.dumps(log_stats) + "\n") 429 | 430 | total_time = 
time.time() - start_time 431 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 432 | print('Training time {}'.format(total_time_str)) 433 | 434 | 435 | if __name__ == '__main__': 436 | parser = argparse.ArgumentParser('DeiT training and evaluation script', parents=[get_args_parser()]) 437 | args = parser.parse_args() 438 | if args.output_dir: 439 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 440 | main(args) 441 | -------------------------------------------------------------------------------- /KNN_VisionTransformer.py: -------------------------------------------------------------------------------- 1 | """ 2 | modified from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py 3 | """ 4 | import math 5 | import logging 6 | from functools import partial 7 | from collections import OrderedDict 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 14 | from timm.models.helpers import load_pretrained 15 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 16 | from timm.models.resnet import resnet26d, resnet50d 17 | from timm.models.resnetv2 import ResNetV2, StdConv2dSame 18 | from timm.models.registry import register_model 19 | 20 | _logger = logging.getLogger(__name__) 21 | 22 | class Mlp(nn.Module): 23 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 24 | super().__init__() 25 | out_features = out_features or in_features 26 | hidden_features = hidden_features or in_features 27 | self.fc1 = nn.Linear(in_features, hidden_features) 28 | self.act = act_layer() 29 | self.fc2 = nn.Linear(hidden_features, out_features) 30 | self.drop = nn.Dropout(drop) 31 | 32 | def forward(self, x): 33 | x = self.fc1(x) 34 | x = self.act(x) 35 | x = self.drop(x) 36 | x = self.fc2(x) 37 | x = self.drop(x) 38 | return x 39 | 40 | 41 | class 
kNNAttention(nn.Module): 42 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.,topk=100): 43 | super().__init__() 44 | self.num_heads = num_heads 45 | head_dim = dim // num_heads 46 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 47 | self.scale = qk_scale or head_dim ** -0.5 48 | 49 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 50 | self.attn_drop = nn.Dropout(attn_drop) 51 | self.proj = nn.Linear(dim, dim) 52 | self.proj_drop = nn.Dropout(proj_drop) 53 | self.topk = topk 54 | 55 | def forward(self, x): 56 | B, N, C = x.shape 57 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 58 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 59 | attn = (q @ k.transpose(-2, -1)) * self.scale 60 | # the core code block 61 | mask=torch.zeros(B,self.num_heads,N,N,device=x.device,requires_grad=False) 62 | index=torch.topk(attn,k=self.topk,dim=-1,largest=True)[1] 63 | mask.scatter_(-1,index,1.) 64 | attn=torch.where(mask>0,attn,torch.full_like(attn,float('-inf'))) 65 | # end of the core code block 66 | 67 | attn = attn.softmax(dim=-1) 68 | attn = self.attn_drop(attn) 69 | 70 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 71 | x = self.proj(x) 72 | x = self.proj_drop(x) 73 | return x 74 | 75 | 76 | class Block(nn.Module): 77 | 78 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 79 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 80 | super().__init__() 81 | self.norm1 = norm_layer(dim) 82 | self.attn = kNNAttention( 83 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 84 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 85 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 86 | self.norm2 = norm_layer(dim) 87 | mlp_hidden_dim = int(dim * mlp_ratio) 88 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 89 | 90 | def forward(self, x): 91 | x = x + self.drop_path(self.attn(self.norm1(x))) 92 | x = x + self.drop_path(self.mlp(self.norm2(x))) 93 | return x 94 | 95 | 96 | class PatchEmbed(nn.Module): 97 | """ Image to Patch Embedding 98 | """ 99 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): 100 | super().__init__() 101 | img_size = to_2tuple(img_size) 102 | patch_size = to_2tuple(patch_size) 103 | num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) 104 | self.img_size = img_size 105 | self.patch_size = patch_size 106 | self.num_patches = num_patches 107 | 108 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 109 | 110 | def forward(self, x): 111 | B, C, H, W = x.shape 112 | # FIXME look at relaxing size constraints 113 | assert H == self.img_size[0] and W == self.img_size[1], \ 114 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 115 | x = self.proj(x).flatten(2).transpose(1, 2) 116 | return x 117 | 118 | 119 | class HybridEmbed(nn.Module): 120 | """ CNN Feature Map Embedding 121 | Extract feature map from CNN, flatten, project to embedding dim. 
122 | """ 123 | def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768): 124 | super().__init__() 125 | assert isinstance(backbone, nn.Module) 126 | img_size = to_2tuple(img_size) 127 | self.img_size = img_size 128 | self.backbone = backbone 129 | if feature_size is None: 130 | with torch.no_grad(): 131 | # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature 132 | # map for all networks, the feature metadata has reliable channel and stride info, but using 133 | # stride to calc feature dim requires info about padding of each stage that isn't captured. 134 | training = backbone.training 135 | if training: 136 | backbone.eval() 137 | o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1])) 138 | if isinstance(o, (list, tuple)): 139 | o = o[-1] # last feature if backbone outputs list/tuple of features 140 | feature_size = o.shape[-2:] 141 | feature_dim = o.shape[1] 142 | backbone.train(training) 143 | else: 144 | feature_size = to_2tuple(feature_size) 145 | if hasattr(self.backbone, 'feature_info'): 146 | feature_dim = self.backbone.feature_info.channels()[-1] 147 | else: 148 | feature_dim = self.backbone.num_features 149 | self.num_patches = feature_size[0] * feature_size[1] 150 | self.proj = nn.Conv2d(feature_dim, embed_dim, 1) 151 | 152 | def forward(self, x): 153 | x = self.backbone(x) 154 | if isinstance(x, (list, tuple)): 155 | x = x[-1] # last feature if backbone outputs list/tuple of features 156 | x = self.proj(x).flatten(2).transpose(1, 2) 157 | return x 158 | 159 | 160 | class VisionTransformer(nn.Module): 161 | """ Vision Transformer 162 | 163 | A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` - 164 | https://arxiv.org/abs/2010.11929 165 | """ 166 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, 167 | num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, 
representation_size=None, 168 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0., hybrid_backbone=None, norm_layer=None): 169 | """ 170 | Args: 171 | img_size (int, tuple): input image size 172 | patch_size (int, tuple): patch size 173 | in_chans (int): number of input channels 174 | num_classes (int): number of classes for classification head 175 | embed_dim (int): embedding dimension 176 | depth (int): depth of transformer 177 | num_heads (int): number of attention heads 178 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim 179 | qkv_bias (bool): enable bias for qkv if True 180 | qk_scale (float): override default qk scale of head_dim ** -0.5 if set 181 | representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set 182 | drop_rate (float): dropout rate 183 | attn_drop_rate (float): attention dropout rate 184 | drop_path_rate (float): stochastic depth rate 185 | hybrid_backbone (nn.Module): CNN backbone to use in-place of PatchEmbed module 186 | norm_layer: (nn.Module): normalization layer 187 | """ 188 | super().__init__() 189 | self.num_classes = num_classes 190 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 191 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) 192 | 193 | if hybrid_backbone is not None: 194 | self.patch_embed = HybridEmbed( 195 | hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim) 196 | else: 197 | self.patch_embed = PatchEmbed( 198 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 199 | num_patches = self.patch_embed.num_patches 200 | 201 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 202 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) 203 | self.pos_drop = nn.Dropout(p=drop_rate) 204 | 205 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 206 | self.blocks = nn.ModuleList([ 207 
| Block( 208 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 209 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) 210 | for i in range(depth)]) 211 | self.norm = norm_layer(embed_dim) 212 | 213 | # Representation layer 214 | if representation_size: 215 | self.num_features = representation_size 216 | self.pre_logits = nn.Sequential(OrderedDict([ 217 | ('fc', nn.Linear(embed_dim, representation_size)), 218 | ('act', nn.Tanh()) 219 | ])) 220 | else: 221 | self.pre_logits = nn.Identity() 222 | 223 | # Classifier head 224 | self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() 225 | 226 | trunc_normal_(self.pos_embed, std=.02) 227 | trunc_normal_(self.cls_token, std=.02) 228 | self.apply(self._init_weights) 229 | 230 | def _init_weights(self, m): 231 | if isinstance(m, nn.Linear): 232 | trunc_normal_(m.weight, std=.02) 233 | if isinstance(m, nn.Linear) and m.bias is not None: 234 | nn.init.constant_(m.bias, 0) 235 | elif isinstance(m, nn.LayerNorm): 236 | nn.init.constant_(m.bias, 0) 237 | nn.init.constant_(m.weight, 1.0) 238 | 239 | @torch.jit.ignore 240 | def no_weight_decay(self): 241 | return {'pos_embed', 'cls_token'} 242 | 243 | def get_classifier(self): 244 | return self.head 245 | 246 | def reset_classifier(self, num_classes, global_pool=''): 247 | self.num_classes = num_classes 248 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 249 | 250 | def forward_features(self, x): 251 | B = x.shape[0] 252 | x = self.patch_embed(x) 253 | 254 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 255 | x = torch.cat((cls_tokens, x), dim=1) 256 | x = x + self.pos_embed 257 | x = self.pos_drop(x) 258 | 259 | for blk in self.blocks: 260 | x = blk(x) 261 | 262 | x = self.norm(x)[:, 0] 263 | x = self.pre_logits(x) 264 | return x 265 | 266 | def forward(self, x): 267 | x = 
self.forward_features(x) 268 | x = self.head(x) 269 | return x 270 | 271 | 272 | class DistilledVisionTransformer(VisionTransformer): 273 | """ Vision Transformer with distillation token. 274 | 275 | Paper: `Training data-efficient image transformers & distillation through attention` - 276 | https://arxiv.org/abs/2012.12877 277 | 278 | This impl of distilled ViT is taken from https://github.com/facebookresearch/deit 279 | """ 280 | def __init__(self, *args, **kwargs): 281 | super().__init__(*args, **kwargs) 282 | self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) 283 | num_patches = self.patch_embed.num_patches 284 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, self.embed_dim)) 285 | self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity() 286 | 287 | trunc_normal_(self.dist_token, std=.02) 288 | trunc_normal_(self.pos_embed, std=.02) 289 | self.head_dist.apply(self._init_weights) 290 | 291 | def forward_features(self, x): 292 | B = x.shape[0] 293 | x = self.patch_embed(x) 294 | 295 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 296 | dist_token = self.dist_token.expand(B, -1, -1) 297 | x = torch.cat((cls_tokens, dist_token, x), dim=1) 298 | 299 | x = x + self.pos_embed 300 | x = self.pos_drop(x) 301 | 302 | for blk in self.blocks: 303 | x = blk(x) 304 | 305 | x = self.norm(x) 306 | return x[:, 0], x[:, 1] 307 | 308 | def forward(self, x): 309 | x, x_dist = self.forward_features(x) 310 | x = self.head(x) 311 | x_dist = self.head_dist(x_dist) 312 | if self.training: 313 | return x, x_dist 314 | else: 315 | # during inference, return the average of both classifier predictions 316 | return (x + x_dist) / 2 317 | 318 | 319 | def resize_pos_embed(posemb, posemb_new): 320 | # Rescale the grid of position embeddings when loading from state_dict. 
Adapted from 321 | # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 322 | _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape) 323 | ntok_new = posemb_new.shape[1] 324 | if True: 325 | posemb_tok, posemb_grid = posemb[:, :1], posemb[0, 1:] 326 | ntok_new -= 1 327 | else: 328 | posemb_tok, posemb_grid = posemb[:, :0], posemb[0] 329 | gs_old = int(math.sqrt(len(posemb_grid))) 330 | gs_new = int(math.sqrt(ntok_new)) 331 | _logger.info('Position embedding grid-size from %s to %s', gs_old, gs_new) 332 | posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) 333 | posemb_grid = F.interpolate(posemb_grid, size=(gs_new, gs_new), mode='bilinear') 334 | posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new * gs_new, -1) 335 | posemb = torch.cat([posemb_tok, posemb_grid], dim=1) 336 | return posemb 337 | 338 | 339 | def checkpoint_filter_fn(state_dict, model): 340 | """ convert patch embedding weight from manual patchify + linear proj to conv""" 341 | out_dict = {} 342 | if 'model' in state_dict: 343 | # For deit models 344 | state_dict = state_dict['model'] 345 | for k, v in state_dict.items(): 346 | if 'patch_embed.proj.weight' in k and len(v.shape) < 4: 347 | # For old models that I trained prior to conv based patchification 348 | O, I, H, W = model.patch_embed.proj.weight.shape 349 | v = v.reshape(O, -1, H, W) 350 | elif k == 'pos_embed' and v.shape != model.pos_embed.shape: 351 | # To resize pos embedding when using model at different size from pretrained weights 352 | v = resize_pos_embed(v, model.pos_embed) 353 | out_dict[k] = v 354 | return out_dict 355 | 356 | 357 | def _create_vision_transformer(variant, pretrained=False, distilled=False, **kwargs): 358 | default_cfg = default_cfgs[variant] 359 | default_num_classes = default_cfg['num_classes'] 360 | default_img_size = default_cfg['input_size'][-1] 361 | 362 | num_classes = 
kwargs.pop('num_classes', default_num_classes) 363 | img_size = kwargs.pop('img_size', default_img_size) 364 | repr_size = kwargs.pop('representation_size', None) 365 | if repr_size is not None and num_classes != default_num_classes: 366 | # Remove representation layer if fine-tuning. This may not always be the desired action, 367 | # but I feel better than doing nothing by default for fine-tuning. Perhaps a better interface? 368 | _logger.warning("Removing representation layer for fine-tuning.") 369 | repr_size = None 370 | 371 | model_cls = DistilledVisionTransformer if distilled else VisionTransformer 372 | model = model_cls(img_size=img_size, num_classes=num_classes, representation_size=repr_size, **kwargs) 373 | model.default_cfg = default_cfg 374 | 375 | if pretrained: 376 | load_pretrained( 377 | model, num_classes=num_classes, in_chans=kwargs.get('in_chans', 3), 378 | filter_fn=partial(checkpoint_filter_fn, model=model)) 379 | return model 380 | 381 | 382 | @register_model 383 | def vit_small_patch16_224(pretrained=False, **kwargs): 384 | """ My custom 'small' ViT model. Depth=8, heads=8, mlp_ratio=3.""" 385 | model_kwargs = dict( 386 | patch_size=16, embed_dim=768, depth=8, num_heads=8, mlp_ratio=3., 387 | qkv_bias=False, norm_layer=nn.LayerNorm, **kwargs) 388 | if pretrained: 389 | # NOTE my scale was wrong for original weights, leaving this here until I have better ones for this model 390 | model_kwargs.setdefault('qk_scale', 768 ** -0.5) 391 | model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **model_kwargs) 392 | return model 393 | 394 | 395 | @register_model 396 | def vit_base_patch16_224(pretrained=False, **kwargs): 397 | """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). 398 | ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
399 | """ 400 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) 401 | model = _create_vision_transformer('vit_base_patch16_224', pretrained=pretrained, **model_kwargs) 402 | return model 403 | 404 | 405 | @register_model 406 | def vit_base_patch32_224(pretrained=False, **kwargs): 407 | """ ViT-Base (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights. 408 | """ 409 | model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs) 410 | model = _create_vision_transformer('vit_base_patch32_224', pretrained=pretrained, **model_kwargs) 411 | return model 412 | 413 | 414 | @register_model 415 | def vit_base_patch16_384(pretrained=False, **kwargs): 416 | """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). 417 | ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. 418 | """ 419 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) 420 | model = _create_vision_transformer('vit_base_patch16_384', pretrained=pretrained, **model_kwargs) 421 | return model 422 | 423 | 424 | @register_model 425 | def vit_base_patch32_384(pretrained=False, **kwargs): 426 | """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). 427 | ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. 428 | """ 429 | model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs) 430 | model = _create_vision_transformer('vit_base_patch32_384', pretrained=pretrained, **model_kwargs) 431 | return model 432 | 433 | 434 | @register_model 435 | def vit_large_patch16_224(pretrained=False, **kwargs): 436 | """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
437 | ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. 438 | """ 439 | model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs) 440 | model = _create_vision_transformer('vit_large_patch16_224', pretrained=pretrained, **model_kwargs) 441 | return model 442 | 443 | 444 | @register_model 445 | def vit_large_patch32_224(pretrained=False, **kwargs): 446 | """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights. 447 | """ 448 | model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs) 449 | model = _create_vision_transformer('vit_large_patch32_224', pretrained=pretrained, **model_kwargs) 450 | return model 451 | 452 | 453 | @register_model 454 | def vit_large_patch16_384(pretrained=False, **kwargs): 455 | """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929). 456 | ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. 457 | """ 458 | model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs) 459 | model = _create_vision_transformer('vit_large_patch16_384', pretrained=pretrained, **model_kwargs) 460 | return model 461 | 462 | 463 | @register_model 464 | def vit_large_patch32_384(pretrained=False, **kwargs): 465 | """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). 466 | ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. 
467 | """ 468 | model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs) 469 | model = _create_vision_transformer('vit_large_patch32_384', pretrained=pretrained, **model_kwargs) 470 | return model 471 | 472 | 473 | @register_model 474 | def vit_base_patch16_224_in21k(pretrained=False, **kwargs): 475 | """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). 476 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. 477 | """ 478 | model_kwargs = dict( 479 | patch_size=16, embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs) 480 | model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=pretrained, **model_kwargs) 481 | return model 482 | 483 | 484 | @register_model 485 | def vit_base_patch32_224_in21k(pretrained=False, **kwargs): 486 | """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). 487 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. 488 | """ 489 | model_kwargs = dict( 490 | patch_size=32, embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs) 491 | model = _create_vision_transformer('vit_base_patch32_224_in21k', pretrained=pretrained, **model_kwargs) 492 | return model 493 | 494 | 495 | @register_model 496 | def vit_large_patch16_224_in21k(pretrained=False, **kwargs): 497 | """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929). 498 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. 
499 | """ 500 | model_kwargs = dict( 501 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, representation_size=1024, **kwargs) 502 | model = _create_vision_transformer('vit_large_patch16_224_in21k', pretrained=pretrained, **model_kwargs) 503 | return model 504 | 505 | 506 | @register_model 507 | def vit_large_patch32_224_in21k(pretrained=False, **kwargs): 508 | """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). 509 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. 510 | """ 511 | model_kwargs = dict( 512 | patch_size=32, embed_dim=1024, depth=24, num_heads=16, representation_size=1024, **kwargs) 513 | model = _create_vision_transformer('vit_large_patch32_224_in21k', pretrained=pretrained, **model_kwargs) 514 | return model 515 | 516 | 517 | @register_model 518 | def vit_huge_patch14_224_in21k(pretrained=False, **kwargs): 519 | """ ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929). 520 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. 521 | NOTE: converted weights not currently available, too large for github release hosting. 522 | """ 523 | model_kwargs = dict( 524 | patch_size=14, embed_dim=1280, depth=32, num_heads=16, representation_size=1280, **kwargs) 525 | model = _create_vision_transformer('vit_huge_patch14_224_in21k', pretrained=pretrained, **model_kwargs) 526 | return model 527 | 528 | 529 | @register_model 530 | def vit_base_resnet50_224_in21k(pretrained=False, **kwargs): 531 | """ R50+ViT-B/16 hybrid model from original paper (https://arxiv.org/abs/2010.11929). 532 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. 
533 | """ 534 | # create a ResNetV2 w/o pre-activation, that uses StdConv and GroupNorm and has 3 stages, no head 535 | backbone = ResNetV2( 536 | layers=(3, 4, 9), num_classes=0, global_pool='', in_chans=kwargs.get('in_chans', 3), 537 | preact=False, stem_type='same', conv_layer=StdConv2dSame) 538 | model_kwargs = dict( 539 | embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, 540 | representation_size=768, **kwargs) 541 | model = _create_vision_transformer('vit_base_resnet50_224_in21k', pretrained=pretrained, **model_kwargs) 542 | return model 543 | 544 | 545 | @register_model 546 | def vit_base_resnet50_384(pretrained=False, **kwargs): 547 | """ R50+ViT-B/16 hybrid from original paper (https://arxiv.org/abs/2010.11929). 548 | ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. 549 | """ 550 | # create a ResNetV2 w/o pre-activation, that uses StdConv and GroupNorm and has 3 stages, no head 551 | backbone = ResNetV2( 552 | layers=(3, 4, 9), num_classes=0, global_pool='', in_chans=kwargs.get('in_chans', 3), 553 | preact=False, stem_type='same', conv_layer=StdConv2dSame) 554 | model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs) 555 | model = _create_vision_transformer('vit_base_resnet50_384', pretrained=pretrained, **model_kwargs) 556 | return model 557 | 558 | 559 | @register_model 560 | def vit_small_resnet26d_224(pretrained=False, **kwargs): 561 | """ Custom ViT small hybrid w/ ResNet26D stride 32. No pretrained weights. 
562 | """ 563 | backbone = resnet26d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[4]) 564 | model_kwargs = dict(embed_dim=768, depth=8, num_heads=8, mlp_ratio=3, hybrid_backbone=backbone, **kwargs) 565 | model = _create_vision_transformer('vit_small_resnet26d_224', pretrained=pretrained, **model_kwargs) 566 | return model 567 | 568 | 569 | @register_model 570 | def vit_small_resnet50d_s3_224(pretrained=False, **kwargs): 571 | """ Custom ViT small hybrid w/ ResNet50D 3-stages, stride 16. No pretrained weights. 572 | """ 573 | backbone = resnet50d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[3]) 574 | model_kwargs = dict(embed_dim=768, depth=8, num_heads=8, mlp_ratio=3, hybrid_backbone=backbone, **kwargs) 575 | model = _create_vision_transformer('vit_small_resnet50d_s3_224', pretrained=pretrained, **model_kwargs) 576 | return model 577 | 578 | 579 | @register_model 580 | def vit_base_resnet26d_224(pretrained=False, **kwargs): 581 | """ Custom ViT base hybrid w/ ResNet26D stride 32. No pretrained weights. 582 | """ 583 | backbone = resnet26d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[4]) 584 | model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs) 585 | model = _create_vision_transformer('vit_base_resnet26d_224', pretrained=pretrained, **model_kwargs) 586 | return model 587 | 588 | 589 | @register_model 590 | def vit_base_resnet50d_224(pretrained=False, **kwargs): 591 | """ Custom ViT base hybrid w/ ResNet50D stride 32. No pretrained weights. 
592 | """ 593 | backbone = resnet50d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[4]) 594 | model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs) 595 | model = _create_vision_transformer('vit_base_resnet50d_224', pretrained=pretrained, **model_kwargs) 596 | return model 597 | 598 | 599 | @register_model 600 | def vit_deit_tiny_patch16_224(pretrained=False, **kwargs): 601 | """ DeiT-tiny model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). 602 | ImageNet-1k weights from https://github.com/facebookresearch/deit. 603 | """ 604 | model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) 605 | model = _create_vision_transformer('vit_deit_tiny_patch16_224', pretrained=pretrained, **model_kwargs) 606 | return model 607 | 608 | 609 | @register_model 610 | def vit_deit_small_patch16_224(pretrained=False, **kwargs): 611 | """ DeiT-small model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). 612 | ImageNet-1k weights from https://github.com/facebookresearch/deit. 613 | """ 614 | model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) 615 | model = _create_vision_transformer('vit_deit_small_patch16_224', pretrained=pretrained, **model_kwargs) 616 | return model 617 | 618 | 619 | @register_model 620 | def vit_deit_base_patch16_224(pretrained=False, **kwargs): 621 | """ DeiT base model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). 622 | ImageNet-1k weights from https://github.com/facebookresearch/deit. 623 | """ 624 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) 625 | model = _create_vision_transformer('vit_deit_base_patch16_224', pretrained=pretrained, **model_kwargs) 626 | return model 627 | 628 | 629 | @register_model 630 | def vit_deit_base_patch16_384(pretrained=False, **kwargs): 631 | """ DeiT base model @ 384x384 from paper (https://arxiv.org/abs/2012.12877). 
632 | ImageNet-1k weights from https://github.com/facebookresearch/deit. 633 | """ 634 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) 635 | model = _create_vision_transformer('vit_deit_base_patch16_384', pretrained=pretrained, **model_kwargs) 636 | return model 637 | 638 | 639 | @register_model 640 | def vit_deit_tiny_distilled_patch16_224(pretrained=False, **kwargs): 641 | """ DeiT-tiny distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). 642 | ImageNet-1k weights from https://github.com/facebookresearch/deit. 643 | """ 644 | model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) 645 | model = _create_vision_transformer( 646 | 'vit_deit_tiny_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs) 647 | return model 648 | 649 | 650 | @register_model 651 | def vit_deit_small_distilled_patch16_224(pretrained=False, **kwargs): 652 | """ DeiT-small distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). 653 | ImageNet-1k weights from https://github.com/facebookresearch/deit. 654 | """ 655 | model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) 656 | model = _create_vision_transformer( 657 | 'vit_deit_small_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs) 658 | return model 659 | 660 | 661 | @register_model 662 | def vit_deit_base_distilled_patch16_224(pretrained=False, **kwargs): 663 | """ DeiT-base distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). 664 | ImageNet-1k weights from https://github.com/facebookresearch/deit. 
665 | """ 666 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) 667 | model = _create_vision_transformer( 668 | 'vit_deit_base_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs) 669 | return model 670 | 671 | 672 | @register_model 673 | def vit_deit_base_distilled_patch16_384(pretrained=False, **kwargs): 674 | """ DeiT-base distilled model @ 384x384 from paper (https://arxiv.org/abs/2012.12877). 675 | ImageNet-1k weights from https://github.com/facebookresearch/deit. 676 | """ 677 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) 678 | model = _create_vision_transformer( 679 | 'vit_deit_base_distilled_patch16_384', pretrained=pretrained, distilled=True, **model_kwargs) 680 | return model 681 | --------------------------------------------------------------------------------
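The position-embedding arithmetic that `resize_pos_embed` relies on can be checked without PyTorch. The sketch below uses hypothetical helpers (`pos_embed_tokens` and `grid_size` are not part of this repository). It assumes a square image and a square patch grid, as every ViT/DeiT configuration registered above uses, and mirrors the `int(math.sqrt(...))` step in `resize_pos_embed`. It also adds an explicit squareness check that the original function does not perform.

```python
import math


def pos_embed_tokens(img_size: int, patch_size: int, num_prefix_tokens: int = 1) -> int:
    """Length of the position embedding: one entry per patch plus prefix tokens.

    num_prefix_tokens is 1 for the class token only; DistilledVisionTransformer
    also prepends a distillation token, giving num_patches + 2.
    """
    grid = img_size // patch_size
    return grid * grid + num_prefix_tokens


def grid_size(num_tokens: int, num_prefix_tokens: int = 1) -> int:
    """Side length of the square patch grid after splitting off prefix tokens.

    Mirrors gs = int(math.sqrt(len(posemb_grid))) in resize_pos_embed, but
    raises instead of silently truncating when the count is not a square.
    """
    n = num_tokens - num_prefix_tokens
    gs = int(math.sqrt(n))
    if gs * gs != n:
        raise ValueError(f"{n} patch tokens do not form a square grid")
    return gs


if __name__ == "__main__":
    for img, patch in [(224, 16), (384, 16), (384, 32)]:
        ntok = pos_embed_tokens(img, patch)
        print(img, patch, ntok, grid_size(ntok))
```

For example, `vit_deit_base_patch16_224` has 14 * 14 = 196 patches plus the class token (197 entries), and the 384x384 variant has 24 * 24 + 1 = 577. The distilled models carry `num_patches + 2` entries, so splitting off only the first token, as `resize_pos_embed` does, leaves 197 grid entries for a 224x224 input, which is not a perfect square.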