├── Results.png
├── image
│   ├── vis1.png
│   ├── vis2.png
│   └── Results.png
├── LICENSE
├── README.md
├── samplers.py
├── losses.py
├── engine.py
├── datasets.py
├── models.py
├── utils.py
├── main.py
└── KNN_VisionTransformer.py
/Results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/damo-cv/KVT/HEAD/Results.png
--------------------------------------------------------------------------------
/image/vis1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/damo-cv/KVT/HEAD/image/vis1.png
--------------------------------------------------------------------------------
/image/vis2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/damo-cv/KVT/HEAD/image/vis2.png
--------------------------------------------------------------------------------
/image/Results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/damo-cv/KVT/HEAD/image/Results.png
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 DamoCV
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # KVT
2 |
3 | This repository contains PyTorch evaluation code, training code and pretrained models for the following project:
4 | * K-NN Attention for Boosting Vision Transformers, ECCV 2022
5 |
6 | For details see [K-NN Attention for Boosting Vision Transformers](https://arxiv.org/abs/2106.00515) by Pichao Wang, Xue Wang, Fan Wang, Ming Lin, Shuning Chang, Hao Li, Rong Jin.
7 |
8 | The code is based on [DeiT](https://github.com/facebookresearch/deit).
9 |
10 | ## Results on ImageNet-1K
11 |
12 | ![Results on ImageNet-1K](Results.png)
13 | ## Visualization
14 |
15 | Self-attention heads from the last layer of DINO-small.
16 | ![Self-attention heads of the last layer of DINO-small](image/vis1.png)
17 |
18 | Images from different classes are visualized using the Transformer Attribution method on DeiT-Tiny.
19 |
20 | ![Transformer Attribution visualizations on DeiT-Tiny](image/vis2.png)
21 | # Usage
22 |
23 | First, clone the repository locally:
24 | ```
25 | git clone https://github.com/damo-cv/KVT.git
26 | ```
27 | Then, install PyTorch 1.7.0+, torchvision 0.8.1+, and [pytorch-image-models 0.4.12](https://github.com/rwightman/pytorch-image-models):
28 |
29 | ```
30 | conda install -c pytorch pytorch torchvision
31 | pip install timm==0.4.12
32 | ```
33 |
34 | ## Data preparation
35 |
36 | Download and extract ImageNet train and val images from http://image-net.org/.
37 | The directory structure is the standard layout for torchvision's [`datasets.ImageFolder`](https://pytorch.org/docs/stable/torchvision/datasets.html#imagefolder); the training and validation data are expected to be in the `train/` and `val/` folders, respectively:
38 |
39 | ```
40 | /path/to/imagenet/
41 | train/
42 | class1/
43 | img1.jpeg
44 | class2/
45 | img2.jpeg
46 | val/
47 | class1/
48 | img3.jpeg
49 |     class2/
50 | img4.jpeg
51 | ```
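
A quick, optional sanity check of this layout (the path below is a placeholder) using `torchvision.datasets.ImageFolder`, the same loader that `datasets.py` uses for ImageNet:
```
python -c "from torchvision.datasets import ImageFolder; d = ImageFolder('/path/to/imagenet/val'); print(len(d), 'images,', len(d.classes), 'classes')"
```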
52 |
53 | ## Training
54 | To train DeiT-KVT-tiny on ImageNet on a single node with 4 GPUs for 300 epochs, run:
55 |
56 | DeiT-KVT-tiny
57 | ```
58 | python -m torch.distributed.launch --nproc_per_node=4 --use_env main.py --model deit_tiny_patch16_224 --batch-size 256 --data-path /path/to/imagenet --output_dir /path/to/save
59 | ```
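
## Evaluation

A minimal evaluation sketch, assuming a checkpoint produced by the training command above (the paths are placeholders); `--eval` and `--resume` are the flags defined in `main.py`:
```
python main.py --eval --model deit_tiny_patch16_224 --resume /path/to/checkpoint.pth --data-path /path/to/imagenet
```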
60 |
61 | ## Citation
62 | If you use this code for a paper, please cite:
63 |
64 | ```
65 | @article{wang2021kvt,
66 | title={Kvt: k-nn attention for boosting vision transformers},
67 | author={Wang, Pichao and Wang, Xue and Wang, Fan and Lin, Ming and Chang, Shuning and Xie, Wen and Li, Hao and Jin, Rong},
68 | journal={arXiv preprint arXiv:2106.00515},
69 | year={2021}
70 | }
71 | ```
72 |
--------------------------------------------------------------------------------
/samplers.py:
--------------------------------------------------------------------------------
1 | """
2 | This file is from https://github.com/facebookresearch/deit/blob/main/samplers.py.
3 | """
4 | import torch
5 | import torch.distributed as dist
6 | import math
7 |
8 |
9 | class RASampler(torch.utils.data.Sampler):
10 |     """Sampler that restricts data loading to a subset of the dataset for distributed training,
11 |     with repeated augmentation.
12 |     It ensures that each augmented version of a sample is visible to a
13 |     different process (GPU).
14 |     Heavily based on torch.utils.data.DistributedSampler.
15 |     """
16 |
17 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, num_repeats: int = 3):
18 | if num_replicas is None:
19 | if not dist.is_available():
20 | raise RuntimeError("Requires distributed package to be available")
21 | num_replicas = dist.get_world_size()
22 | if rank is None:
23 | if not dist.is_available():
24 | raise RuntimeError("Requires distributed package to be available")
25 | rank = dist.get_rank()
26 | if num_repeats < 1:
27 | raise ValueError("num_repeats should be greater than 0")
28 | self.dataset = dataset
29 | self.num_replicas = num_replicas
30 | self.rank = rank
31 | self.num_repeats = num_repeats
32 | self.epoch = 0
33 | self.num_samples = int(math.ceil(len(self.dataset) * self.num_repeats / self.num_replicas))
34 | self.total_size = self.num_samples * self.num_replicas
35 | # self.num_selected_samples = int(math.ceil(len(self.dataset) / self.num_replicas))
36 | self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas))
37 | self.shuffle = shuffle
38 |
39 | def __iter__(self):
40 | if self.shuffle:
41 | # deterministically shuffle based on epoch
42 | g = torch.Generator()
43 | g.manual_seed(self.epoch)
44 | indices = torch.randperm(len(self.dataset), generator=g)
45 | else:
46 | indices = torch.arange(start=0, end=len(self.dataset))
47 |
48 | # add extra samples to make it evenly divisible
49 | indices = torch.repeat_interleave(indices, repeats=self.num_repeats, dim=0).tolist()
50 | padding_size: int = self.total_size - len(indices)
51 | if padding_size > 0:
52 | indices += indices[:padding_size]
53 | assert len(indices) == self.total_size
54 |
55 | # subsample
56 | indices = indices[self.rank:self.total_size:self.num_replicas]
57 | assert len(indices) == self.num_samples
58 |
59 | return iter(indices[:self.num_selected_samples])
60 |
61 | def __len__(self):
62 | return self.num_selected_samples
63 |
64 | def set_epoch(self, epoch):
65 | self.epoch = epoch
66 |
--------------------------------------------------------------------------------
/losses.py:
--------------------------------------------------------------------------------
1 | """
2 | This file is from https://github.com/facebookresearch/deit/blob/main/losses.py.
3 | """
4 | import torch
5 | from torch.nn import functional as F
6 |
7 |
8 | class DistillationLoss(torch.nn.Module):
9 | """
10 | This module wraps a standard criterion and adds an extra knowledge distillation loss by
11 | taking a teacher model prediction and using it as additional supervision.
12 | """
13 | def __init__(self, base_criterion: torch.nn.Module, teacher_model: torch.nn.Module,
14 | distillation_type: str, alpha: float, tau: float):
15 | super().__init__()
16 | self.base_criterion = base_criterion
17 | self.teacher_model = teacher_model
18 | assert distillation_type in ['none', 'soft', 'hard']
19 | self.distillation_type = distillation_type
20 | self.alpha = alpha
21 | self.tau = tau
22 |
23 | def forward(self, inputs, outputs, labels):
24 | """
25 | Args:
26 |             inputs: The original inputs that are fed to the teacher model
27 | outputs: the outputs of the model to be trained. It is expected to be
28 | either a Tensor, or a Tuple[Tensor, Tensor], with the original output
29 | in the first position and the distillation predictions as the second output
30 | labels: the labels for the base criterion
31 | """
32 | outputs_kd = None
33 | if not isinstance(outputs, torch.Tensor):
34 | # assume that the model outputs a tuple of [outputs, outputs_kd]
35 | outputs, outputs_kd = outputs
36 | base_loss = self.base_criterion(outputs, labels)
37 | if self.distillation_type == 'none':
38 | return base_loss
39 |
40 | if outputs_kd is None:
41 | raise ValueError("When knowledge distillation is enabled, the model is "
42 | "expected to return a Tuple[Tensor, Tensor] with the output of the "
43 | "class_token and the dist_token")
44 |         # don't backprop through the teacher
45 | with torch.no_grad():
46 | teacher_outputs = self.teacher_model(inputs)
47 |
48 | if self.distillation_type == 'soft':
49 | T = self.tau
50 | # taken from https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
51 | # with slight modifications
52 | distillation_loss = F.kl_div(
53 | F.log_softmax(outputs_kd / T, dim=1),
54 | #We provide the teacher's targets in log probability because we use log_target=True
55 | #(as recommended in pytorch https://github.com/pytorch/pytorch/blob/9324181d0ac7b4f7949a574dbc3e8be30abe7041/torch/nn/functional.py#L2719)
56 | #but it is possible to give just the probabilities and set log_target=False. In our experiments we tried both.
57 | F.log_softmax(teacher_outputs / T, dim=1),
58 | reduction='sum',
59 | log_target=True
60 | ) * (T * T) / outputs_kd.numel()
61 | #We divide by outputs_kd.numel() to have the legacy PyTorch behavior.
62 |             #But we also experimented with outputs_kd.size(0);
63 | #see issue 61(https://github.com/facebookresearch/deit/issues/61) for more details
64 | elif self.distillation_type == 'hard':
65 | distillation_loss = F.cross_entropy(outputs_kd, teacher_outputs.argmax(dim=1))
66 |
67 | loss = base_loss * (1 - self.alpha) + distillation_loss * self.alpha
68 | return loss
69 |
--------------------------------------------------------------------------------
/engine.py:
--------------------------------------------------------------------------------
1 | """
2 | This file is from https://github.com/facebookresearch/deit/blob/main/engine.py.
3 | """
4 | import math
5 | import sys
6 | from typing import Iterable, Optional
7 |
8 | import torch
9 |
10 | from timm.data import Mixup
11 | from timm.utils import accuracy, ModelEma
12 |
13 | from losses import DistillationLoss
14 | import utils
15 |
16 |
17 | def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss,
18 | data_loader: Iterable, optimizer: torch.optim.Optimizer,
19 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
20 | model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None,
21 | set_training_mode=True):
22 | model.train(set_training_mode)
23 | metric_logger = utils.MetricLogger(delimiter=" ")
24 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
25 | header = 'Epoch: [{}]'.format(epoch)
26 | print_freq = 10
27 |
28 | for samples, targets in metric_logger.log_every(data_loader, print_freq, header):
29 | samples = samples.to(device, non_blocking=True)
30 | targets = targets.to(device, non_blocking=True)
31 |
32 | if mixup_fn is not None:
33 | samples, targets = mixup_fn(samples, targets)
34 |
35 | with torch.cuda.amp.autocast():
36 | outputs = model(samples)
37 | loss = criterion(samples, outputs, targets)
38 |
39 | loss_value = loss.item()
40 |
41 | if not math.isfinite(loss_value):
42 | print("Loss is {}, stopping training".format(loss_value))
43 | sys.exit(1)
44 |
45 | optimizer.zero_grad()
46 |
47 | # this attribute is added by timm on one optimizer (adahessian)
48 | is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
49 | loss_scaler(loss, optimizer, clip_grad=max_norm,
50 | parameters=model.parameters(), create_graph=is_second_order)
51 |
52 | torch.cuda.synchronize()
53 | if model_ema is not None:
54 | model_ema.update(model)
55 |
56 | metric_logger.update(loss=loss_value)
57 | metric_logger.update(lr=optimizer.param_groups[0]["lr"])
58 | # gather the stats from all processes
59 | metric_logger.synchronize_between_processes()
60 | print("Averaged stats:", metric_logger)
61 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
62 |
63 |
64 | @torch.no_grad()
65 | def evaluate(data_loader, model, device):
66 | criterion = torch.nn.CrossEntropyLoss()
67 |
68 | metric_logger = utils.MetricLogger(delimiter=" ")
69 | header = 'Test:'
70 |
71 | # switch to evaluation mode
72 | model.eval()
73 |
74 | for images, target in metric_logger.log_every(data_loader, 10, header):
75 | images = images.to(device, non_blocking=True)
76 | target = target.to(device, non_blocking=True)
77 |
78 | # compute output
79 | with torch.cuda.amp.autocast():
80 | output = model(images)
81 | loss = criterion(output, target)
82 |
83 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
84 |
85 | batch_size = images.shape[0]
86 | metric_logger.update(loss=loss.item())
87 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
88 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
89 | # gather the stats from all processes
90 | metric_logger.synchronize_between_processes()
91 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
92 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))
93 |
94 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
95 |
--------------------------------------------------------------------------------
/datasets.py:
--------------------------------------------------------------------------------
1 | """
2 | This file is from https://github.com/facebookresearch/deit/blob/main/datasets.py.
3 | """
4 |
5 | import os
6 | import json
7 |
8 | from torchvision import datasets, transforms
9 | from torchvision.datasets.folder import ImageFolder, default_loader
10 |
11 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
12 | from timm.data import create_transform
13 |
14 |
15 | class INatDataset(ImageFolder):
16 | def __init__(self, root, train=True, year=2018, transform=None, target_transform=None,
17 | category='name', loader=default_loader):
18 | self.transform = transform
19 | self.loader = loader
20 | self.target_transform = target_transform
21 | self.year = year
22 | # assert category in ['kingdom','phylum','class','order','supercategory','family','genus','name']
23 | path_json = os.path.join(root, f'{"train" if train else "val"}{year}.json')
24 | with open(path_json) as json_file:
25 | data = json.load(json_file)
26 |
27 | with open(os.path.join(root, 'categories.json')) as json_file:
28 | data_catg = json.load(json_file)
29 |
30 | path_json_for_targeter = os.path.join(root, f"train{year}.json")
31 |
32 | with open(path_json_for_targeter) as json_file:
33 | data_for_targeter = json.load(json_file)
34 |
35 | targeter = {}
36 | indexer = 0
37 | for elem in data_for_targeter['annotations']:
38 | king = []
39 | king.append(data_catg[int(elem['category_id'])][category])
40 | if king[0] not in targeter.keys():
41 | targeter[king[0]] = indexer
42 | indexer += 1
43 | self.nb_classes = len(targeter)
44 |
45 | self.samples = []
46 | for elem in data['images']:
47 | cut = elem['file_name'].split('/')
48 | target_current = int(cut[2])
49 | path_current = os.path.join(root, cut[0], cut[2], cut[3])
50 |
51 | categors = data_catg[target_current]
52 | target_current_true = targeter[categors[category]]
53 | self.samples.append((path_current, target_current_true))
54 |
55 | # __getitem__ and __len__ inherited from ImageFolder
56 |
57 |
58 | def build_dataset(is_train, args):
59 | transform = build_transform(is_train, args)
60 |
61 | if args.data_set == 'CIFAR':
62 | dataset = datasets.CIFAR100(args.data_path, train=is_train, transform=transform)
63 | nb_classes = 100
64 | elif args.data_set == 'IMNET':
65 | root = os.path.join(args.data_path, 'train' if is_train else 'val')
66 | dataset = datasets.ImageFolder(root, transform=transform)
67 | nb_classes = 1000
68 | elif args.data_set == 'INAT':
69 | dataset = INatDataset(args.data_path, train=is_train, year=2018,
70 | category=args.inat_category, transform=transform)
71 | nb_classes = dataset.nb_classes
72 | elif args.data_set == 'INAT19':
73 | dataset = INatDataset(args.data_path, train=is_train, year=2019,
74 | category=args.inat_category, transform=transform)
75 | nb_classes = dataset.nb_classes
76 |
77 | return dataset, nb_classes
78 |
79 |
80 | def build_transform(is_train, args):
81 | resize_im = args.input_size > 32
82 | if is_train:
83 | # this should always dispatch to transforms_imagenet_train
84 | transform = create_transform(
85 | input_size=args.input_size,
86 | is_training=True,
87 | color_jitter=args.color_jitter,
88 | auto_augment=args.aa,
89 | interpolation=args.train_interpolation,
90 | re_prob=args.reprob,
91 | re_mode=args.remode,
92 | re_count=args.recount,
93 | )
94 | if not resize_im:
95 | # replace RandomResizedCropAndInterpolation with
96 | # RandomCrop
97 | transform.transforms[0] = transforms.RandomCrop(
98 | args.input_size, padding=4)
99 | return transform
100 |
101 | t = []
102 | if resize_im:
103 | size = int((256 / 224) * args.input_size)
104 | t.append(
105 | transforms.Resize(size, interpolation=3), # to maintain same ratio w.r.t. 224 images
106 | )
107 | t.append(transforms.CenterCrop(args.input_size))
108 |
109 | t.append(transforms.ToTensor())
110 | t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))
111 | return transforms.Compose(t)
112 |
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | """
2 | This file is from https://github.com/facebookresearch/deit/blob/main/models.py.
3 | """
4 | import torch
5 | import torch.nn as nn
6 | from functools import partial
7 |
8 | from timm.models.vision_transformer import VisionTransformer, _cfg
9 | from timm.models.registry import register_model
10 | from timm.models.layers import trunc_normal_
11 |
12 |
13 | __all__ = [
14 | 'deit_tiny_patch16_224', 'deit_small_patch16_224', 'deit_base_patch16_224',
15 | 'deit_tiny_distilled_patch16_224', 'deit_small_distilled_patch16_224',
16 | 'deit_base_distilled_patch16_224', 'deit_base_patch16_384',
17 | 'deit_base_distilled_patch16_384',
18 | ]
19 |
20 |
21 | class DistilledVisionTransformer(VisionTransformer):
22 | def __init__(self, *args, **kwargs):
23 | super().__init__(*args, **kwargs)
24 | self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
25 | num_patches = self.patch_embed.num_patches
26 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, self.embed_dim))
27 | self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity()
28 |
29 | trunc_normal_(self.dist_token, std=.02)
30 | trunc_normal_(self.pos_embed, std=.02)
31 | self.head_dist.apply(self._init_weights)
32 |
33 | def forward_features(self, x):
34 | # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
35 | # with slight modifications to add the dist_token
36 | B = x.shape[0]
37 | x = self.patch_embed(x)
38 |
39 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
40 | dist_token = self.dist_token.expand(B, -1, -1)
41 | x = torch.cat((cls_tokens, dist_token, x), dim=1)
42 |
43 | x = x + self.pos_embed
44 | x = self.pos_drop(x)
45 |
46 | for blk in self.blocks:
47 | x = blk(x)
48 |
49 | x = self.norm(x)
50 | return x[:, 0], x[:, 1]
51 |
52 | def forward(self, x):
53 | x, x_dist = self.forward_features(x)
54 | x = self.head(x)
55 | x_dist = self.head_dist(x_dist)
56 | if self.training:
57 | return x, x_dist
58 | else:
59 | # during inference, return the average of both classifier predictions
60 | return (x + x_dist) / 2
61 |
62 |
63 | @register_model
64 | def deit_tiny_patch16_224(pretrained=False, **kwargs):
65 | model = VisionTransformer(
66 | patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
67 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
68 | model.default_cfg = _cfg()
69 | if pretrained:
70 | checkpoint = torch.hub.load_state_dict_from_url(
71 | url="https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth",
72 | map_location="cpu", check_hash=True
73 | )
74 | model.load_state_dict(checkpoint["model"])
75 | return model
76 |
77 |
78 | @register_model
79 | def deit_small_patch16_224(pretrained=False, **kwargs):
80 | model = VisionTransformer(
81 | patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
82 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
83 | model.default_cfg = _cfg()
84 | if pretrained:
85 | checkpoint = torch.hub.load_state_dict_from_url(
86 | url="https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth",
87 | map_location="cpu", check_hash=True
88 | )
89 | model.load_state_dict(checkpoint["model"])
90 | return model
91 |
92 |
93 | @register_model
94 | def deit_base_patch16_224(pretrained=False, **kwargs):
95 | model = VisionTransformer(
96 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
97 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
98 | model.default_cfg = _cfg()
99 | if pretrained:
100 | checkpoint = torch.hub.load_state_dict_from_url(
101 | url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth",
102 | map_location="cpu", check_hash=True
103 | )
104 | model.load_state_dict(checkpoint["model"])
105 | return model
106 |
107 |
108 | @register_model
109 | def deit_tiny_distilled_patch16_224(pretrained=False, **kwargs):
110 | model = DistilledVisionTransformer(
111 | patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
112 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
113 | model.default_cfg = _cfg()
114 | if pretrained:
115 | checkpoint = torch.hub.load_state_dict_from_url(
116 | url="https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth",
117 | map_location="cpu", check_hash=True
118 | )
119 | model.load_state_dict(checkpoint["model"])
120 | return model
121 |
122 |
123 | @register_model
124 | def deit_small_distilled_patch16_224(pretrained=False, **kwargs):
125 | model = DistilledVisionTransformer(
126 | patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
127 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
128 | model.default_cfg = _cfg()
129 | if pretrained:
130 | checkpoint = torch.hub.load_state_dict_from_url(
131 | url="https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth",
132 | map_location="cpu", check_hash=True
133 | )
134 | model.load_state_dict(checkpoint["model"])
135 | return model
136 |
137 |
138 | @register_model
139 | def deit_base_distilled_patch16_224(pretrained=False, **kwargs):
140 | model = DistilledVisionTransformer(
141 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
142 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
143 | model.default_cfg = _cfg()
144 | if pretrained:
145 | checkpoint = torch.hub.load_state_dict_from_url(
146 | url="https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth",
147 | map_location="cpu", check_hash=True
148 | )
149 | model.load_state_dict(checkpoint["model"])
150 | return model
151 |
152 |
153 | @register_model
154 | def deit_base_patch16_384(pretrained=False, **kwargs):
155 | model = VisionTransformer(
156 | img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
157 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
158 | model.default_cfg = _cfg()
159 | if pretrained:
160 | checkpoint = torch.hub.load_state_dict_from_url(
161 | url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth",
162 | map_location="cpu", check_hash=True
163 | )
164 | model.load_state_dict(checkpoint["model"])
165 | return model
166 |
167 |
168 | @register_model
169 | def deit_base_distilled_patch16_384(pretrained=False, **kwargs):
170 | model = DistilledVisionTransformer(
171 | img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
172 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
173 | model.default_cfg = _cfg()
174 | if pretrained:
175 | checkpoint = torch.hub.load_state_dict_from_url(
176 | url="https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth",
177 | map_location="cpu", check_hash=True
178 | )
179 | model.load_state_dict(checkpoint["model"])
180 | return model
181 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | This file is from https://github.com/facebookresearch/deit/blob/main/utils.py.
3 | """
4 | import io
5 | import os
6 | import time
7 | from collections import defaultdict, deque
8 | import datetime
9 |
10 | import torch
11 | import torch.distributed as dist
12 |
13 |
14 | class SmoothedValue(object):
15 | """Track a series of values and provide access to smoothed values over a
16 | window or the global series average.
17 | """
18 |
19 | def __init__(self, window_size=20, fmt=None):
20 | if fmt is None:
21 | fmt = "{median:.4f} ({global_avg:.4f})"
22 | self.deque = deque(maxlen=window_size)
23 | self.total = 0.0
24 | self.count = 0
25 | self.fmt = fmt
26 |
27 | def update(self, value, n=1):
28 | self.deque.append(value)
29 | self.count += n
30 | self.total += value * n
31 |
32 | def synchronize_between_processes(self):
33 | """
34 | Warning: does not synchronize the deque!
35 | """
36 | if not is_dist_avail_and_initialized():
37 | return
38 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
39 | dist.barrier()
40 | dist.all_reduce(t)
41 | t = t.tolist()
42 | self.count = int(t[0])
43 | self.total = t[1]
44 |
45 | @property
46 | def median(self):
47 | d = torch.tensor(list(self.deque))
48 | return d.median().item()
49 |
50 | @property
51 | def avg(self):
52 | d = torch.tensor(list(self.deque), dtype=torch.float32)
53 | return d.mean().item()
54 |
55 | @property
56 | def global_avg(self):
57 | return self.total / self.count
58 |
59 | @property
60 | def max(self):
61 | return max(self.deque)
62 |
63 | @property
64 | def value(self):
65 | return self.deque[-1]
66 |
67 | def __str__(self):
68 | return self.fmt.format(
69 | median=self.median,
70 | avg=self.avg,
71 | global_avg=self.global_avg,
72 | max=self.max,
73 | value=self.value)
74 |
75 |
76 | class MetricLogger(object):
77 | def __init__(self, delimiter="\t"):
78 | self.meters = defaultdict(SmoothedValue)
79 | self.delimiter = delimiter
80 |
81 | def update(self, **kwargs):
82 | for k, v in kwargs.items():
83 | if isinstance(v, torch.Tensor):
84 | v = v.item()
85 | assert isinstance(v, (float, int))
86 | self.meters[k].update(v)
87 |
88 | def __getattr__(self, attr):
89 | if attr in self.meters:
90 | return self.meters[attr]
91 | if attr in self.__dict__:
92 | return self.__dict__[attr]
93 | raise AttributeError("'{}' object has no attribute '{}'".format(
94 | type(self).__name__, attr))
95 |
96 | def __str__(self):
97 | loss_str = []
98 | for name, meter in self.meters.items():
99 | loss_str.append(
100 | "{}: {}".format(name, str(meter))
101 | )
102 | return self.delimiter.join(loss_str)
103 |
104 | def synchronize_between_processes(self):
105 | for meter in self.meters.values():
106 | meter.synchronize_between_processes()
107 |
108 | def add_meter(self, name, meter):
109 | self.meters[name] = meter
110 |
111 | def log_every(self, iterable, print_freq, header=None):
112 | i = 0
113 | if not header:
114 | header = ''
115 | start_time = time.time()
116 | end = time.time()
117 | iter_time = SmoothedValue(fmt='{avg:.4f}')
118 | data_time = SmoothedValue(fmt='{avg:.4f}')
119 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
120 | log_msg = [
121 | header,
122 | '[{0' + space_fmt + '}/{1}]',
123 | 'eta: {eta}',
124 | '{meters}',
125 | 'time: {time}',
126 | 'data: {data}'
127 | ]
128 | if torch.cuda.is_available():
129 | log_msg.append('max mem: {memory:.0f}')
130 | log_msg = self.delimiter.join(log_msg)
131 | MB = 1024.0 * 1024.0
132 | for obj in iterable:
133 | data_time.update(time.time() - end)
134 | yield obj
135 | iter_time.update(time.time() - end)
136 | if i % print_freq == 0 or i == len(iterable) - 1:
137 | eta_seconds = iter_time.global_avg * (len(iterable) - i)
138 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
139 | if torch.cuda.is_available():
140 | print(log_msg.format(
141 | i, len(iterable), eta=eta_string,
142 | meters=str(self),
143 | time=str(iter_time), data=str(data_time),
144 | memory=torch.cuda.max_memory_allocated() / MB))
145 | else:
146 | print(log_msg.format(
147 | i, len(iterable), eta=eta_string,
148 | meters=str(self),
149 | time=str(iter_time), data=str(data_time)))
150 | i += 1
151 | end = time.time()
152 | total_time = time.time() - start_time
153 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
154 | print('{} Total time: {} ({:.4f} s / it)'.format(
155 | header, total_time_str, total_time / len(iterable)))
156 |
157 |
158 | def _load_checkpoint_for_ema(model_ema, checkpoint):
159 | """
160 | Workaround for ModelEma._load_checkpoint to accept an already-loaded object
161 | """
162 | mem_file = io.BytesIO()
163 | torch.save(checkpoint, mem_file)
164 | mem_file.seek(0)
165 | model_ema._load_checkpoint(mem_file)
166 |
167 |
168 | def setup_for_distributed(is_master):
169 | """
170 |     This function disables printing when not in the master process
171 | """
172 | import builtins as __builtin__
173 | builtin_print = __builtin__.print
174 |
175 | def print(*args, **kwargs):
176 | force = kwargs.pop('force', False)
177 | if is_master or force:
178 | builtin_print(*args, **kwargs)
179 |
180 | __builtin__.print = print
181 |
182 |
183 | def is_dist_avail_and_initialized():
184 | if not dist.is_available():
185 | return False
186 | if not dist.is_initialized():
187 | return False
188 | return True
189 |
190 |
191 | def get_world_size():
192 | if not is_dist_avail_and_initialized():
193 | return 1
194 | return dist.get_world_size()
195 |
196 |
197 | def get_rank():
198 | if not is_dist_avail_and_initialized():
199 | return 0
200 | return dist.get_rank()
201 |
202 |
203 | def is_main_process():
204 | return get_rank() == 0
205 |
206 |
207 | def save_on_master(*args, **kwargs):
208 | if is_main_process():
209 | torch.save(*args, **kwargs)
210 |
211 |
212 | def init_distributed_mode(args):
213 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
214 | args.rank = int(os.environ["RANK"])
215 | args.world_size = int(os.environ['WORLD_SIZE'])
216 | args.gpu = int(os.environ['LOCAL_RANK'])
217 | elif 'SLURM_PROCID' in os.environ:
218 | args.rank = int(os.environ['SLURM_PROCID'])
219 | args.gpu = args.rank % torch.cuda.device_count()
220 | else:
221 | print('Not using distributed mode')
222 | args.distributed = False
223 | return
224 |
225 | args.distributed = True
226 |
227 | torch.cuda.set_device(args.gpu)
228 | args.dist_backend = 'nccl'
229 | print('| distributed init (rank {}): {}'.format(
230 | args.rank, args.dist_url), flush=True)
231 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
232 | world_size=args.world_size, rank=args.rank)
233 | torch.distributed.barrier()
234 | setup_for_distributed(args.rank == 0)
235 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | """
2 | This file is from https://github.com/facebookresearch/deit/blob/main/main.py.
3 | """
4 | import argparse
5 | import datetime
6 | import numpy as np
7 | import time
8 | import torch
9 | import torch.backends.cudnn as cudnn
10 | import json
11 |
12 | from pathlib import Path
13 |
14 | from timm.data import Mixup
15 | from timm.models import create_model
16 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
17 | from timm.scheduler import create_scheduler
18 | from timm.optim import create_optimizer
19 | from timm.utils import NativeScaler, get_state_dict, ModelEma
20 |
21 | from datasets import build_dataset
22 | from engine import train_one_epoch, evaluate
23 | from losses import DistillationLoss
24 | from samplers import RASampler
25 | import models
26 | import utils
27 |
28 |
29 | def get_args_parser():
30 | parser = argparse.ArgumentParser('DeiT training and evaluation script', add_help=False)
31 | parser.add_argument('--batch-size', default=64, type=int)
32 | parser.add_argument('--epochs', default=300, type=int)
33 |
34 | # Model parameters
35 | parser.add_argument('--model', default='deit_base_patch16_224', type=str, metavar='MODEL',
36 | help='Name of model to train')
37 | parser.add_argument('--input-size', default=224, type=int, help='images input size')
38 |
39 | parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
40 | help='Dropout rate (default: 0.)')
41 | parser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT',
42 | help='Drop path rate (default: 0.1)')
43 |
44 | parser.add_argument('--model-ema', action='store_true')
45 | parser.add_argument('--no-model-ema', action='store_false', dest='model_ema')
46 | parser.set_defaults(model_ema=True)
47 | parser.add_argument('--model-ema-decay', type=float, default=0.99996, help='')
48 | parser.add_argument('--model-ema-force-cpu', action='store_true', default=False, help='')
49 |
50 | # Optimizer parameters
51 | parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
52 |                         help='Optimizer (default: "adamw")')
53 | parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
54 | help='Optimizer Epsilon (default: 1e-8)')
55 | parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
56 | help='Optimizer Betas (default: None, use opt default)')
57 | parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
58 | help='Clip gradient norm (default: None, no clipping)')
59 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
60 | help='SGD momentum (default: 0.9)')
61 | parser.add_argument('--weight-decay', type=float, default=0.05,
62 | help='weight decay (default: 0.05)')
63 | # Learning rate schedule parameters
64 | parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
65 |                         help='LR scheduler (default: "cosine")')
66 | parser.add_argument('--lr', type=float, default=5e-4, metavar='LR',
67 | help='learning rate (default: 5e-4)')
68 | parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
69 | help='learning rate noise on/off epoch percentages')
70 | parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
71 | help='learning rate noise limit percent (default: 0.67)')
72 | parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
73 | help='learning rate noise std-dev (default: 1.0)')
74 | parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR',
75 | help='warmup learning rate (default: 1e-6)')
76 | parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
77 | help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
78 |
79 | parser.add_argument('--decay-epochs', type=float, default=30, metavar='N',
80 | help='epoch interval to decay LR')
81 | parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N',
82 | help='epochs to warmup LR, if scheduler supports')
83 | parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
84 | help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
85 | parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
86 |                         help='patience epochs for Plateau LR scheduler (default: 10)')
87 | parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
88 | help='LR decay rate (default: 0.1)')
89 |
90 | # Augmentation parameters
91 | parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
92 | help='Color jitter factor (default: 0.4)')
93 | parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
94 |                         help='Use AutoAugment policy. "v0" or "original". '
95 |                              '(default: rand-m9-mstd0.5-inc1)')
96 | parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)')
97 | parser.add_argument('--train-interpolation', type=str, default='bicubic',
98 | help='Training interpolation (random, bilinear, bicubic default: "bicubic")')
99 |
100 | parser.add_argument('--repeated-aug', action='store_true')
101 | parser.add_argument('--no-repeated-aug', action='store_false', dest='repeated_aug')
102 | parser.set_defaults(repeated_aug=True)
103 |
104 | # * Random Erase params
105 | parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
106 | help='Random erase prob (default: 0.25)')
107 | parser.add_argument('--remode', type=str, default='pixel',
108 | help='Random erase mode (default: "pixel")')
109 | parser.add_argument('--recount', type=int, default=1,
110 | help='Random erase count (default: 1)')
111 | parser.add_argument('--resplit', action='store_true', default=False,
112 | help='Do not random erase first (clean) augmentation split')
113 |
114 | # * Mixup params
115 | parser.add_argument('--mixup', type=float, default=0.8,
116 | help='mixup alpha, mixup enabled if > 0. (default: 0.8)')
117 | parser.add_argument('--cutmix', type=float, default=1.0,
118 | help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)')
119 | parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
120 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
121 | parser.add_argument('--mixup-prob', type=float, default=1.0,
122 | help='Probability of performing mixup or cutmix when either/both is enabled')
123 | parser.add_argument('--mixup-switch-prob', type=float, default=0.5,
124 | help='Probability of switching to cutmix when both mixup and cutmix enabled')
125 | parser.add_argument('--mixup-mode', type=str, default='batch',
126 | help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
127 |
128 | # Distillation parameters
129 | parser.add_argument('--teacher-model', default='regnety_160', type=str, metavar='MODEL',
130 |                         help='Name of teacher model to train (default: "regnety_160")')
131 | parser.add_argument('--teacher-path', type=str, default='')
132 | parser.add_argument('--distillation-type', default='none', choices=['none', 'soft', 'hard'], type=str, help="")
133 | parser.add_argument('--distillation-alpha', default=0.5, type=float, help="")
134 | parser.add_argument('--distillation-tau', default=1.0, type=float, help="")
135 |
136 | # * Finetuning params
137 | parser.add_argument('--finetune', default='', help='finetune from checkpoint')
138 |
139 | # Dataset parameters
140 | parser.add_argument('--data-path', default='/datasets01/imagenet_full_size/061417/', type=str,
141 | help='dataset path')
142 | parser.add_argument('--data-set', default='IMNET', choices=['CIFAR', 'IMNET', 'INAT', 'INAT19'],
143 |                         type=str, help='dataset type')
144 | parser.add_argument('--inat-category', default='name',
145 | choices=['kingdom', 'phylum', 'class', 'order', 'supercategory', 'family', 'genus', 'name'],
146 | type=str, help='semantic granularity')
147 |
148 | parser.add_argument('--output_dir', default='',
149 | help='path where to save, empty for no saving')
150 | parser.add_argument('--device', default='cuda',
151 | help='device to use for training / testing')
152 | parser.add_argument('--seed', default=0, type=int)
153 | parser.add_argument('--resume', default='', help='resume from checkpoint')
154 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
155 | help='start epoch')
156 | parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
157 | parser.add_argument('--dist-eval', action='store_true', default=False, help='Enabling distributed evaluation')
158 | parser.add_argument('--num_workers', default=10, type=int)
159 | parser.add_argument('--pin-mem', action='store_true',
160 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
161 | parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem',
162 | help='')
163 | parser.set_defaults(pin_mem=True)
164 |
165 | # distributed training parameters
166 | parser.add_argument('--world_size', default=1, type=int,
167 | help='number of distributed processes')
168 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
169 | return parser
170 |
171 |
172 | def main(args):
173 | utils.init_distributed_mode(args)
174 |
175 | print(args)
176 |
177 | if args.distillation_type != 'none' and args.finetune and not args.eval:
178 | raise NotImplementedError("Finetuning with distillation not yet supported")
179 |
180 | device = torch.device(args.device)
181 |
182 | # fix the seed for reproducibility
183 | seed = args.seed + utils.get_rank()
184 | torch.manual_seed(seed)
185 | np.random.seed(seed)
186 | # random.seed(seed)
187 |
188 | cudnn.benchmark = True
189 |
190 | dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)
191 | dataset_val, _ = build_dataset(is_train=False, args=args)
192 |
193 | if True: # args.distributed:
194 | num_tasks = utils.get_world_size()
195 | global_rank = utils.get_rank()
196 | if args.repeated_aug:
197 | sampler_train = RASampler(
198 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
199 | )
200 | else:
201 | sampler_train = torch.utils.data.DistributedSampler(
202 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
203 | )
204 | if args.dist_eval:
205 | if len(dataset_val) % num_tasks != 0:
206 | print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
207 | 'This will slightly alter validation results as extra duplicate entries are added to achieve '
208 | 'equal num of samples per-process.')
209 | sampler_val = torch.utils.data.DistributedSampler(
210 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False)
211 | else:
212 | sampler_val = torch.utils.data.SequentialSampler(dataset_val)
213 | else:
214 | sampler_train = torch.utils.data.RandomSampler(dataset_train)
215 | sampler_val = torch.utils.data.SequentialSampler(dataset_val)
216 |
217 | data_loader_train = torch.utils.data.DataLoader(
218 | dataset_train, sampler=sampler_train,
219 | batch_size=args.batch_size,
220 | num_workers=args.num_workers,
221 | pin_memory=args.pin_mem,
222 | drop_last=True,
223 | )
224 |
225 | data_loader_val = torch.utils.data.DataLoader(
226 | dataset_val, sampler=sampler_val,
227 | batch_size=int(1.5 * args.batch_size),
228 | num_workers=args.num_workers,
229 | pin_memory=args.pin_mem,
230 | drop_last=False
231 | )
232 |
233 | mixup_fn = None
234 | mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
235 | if mixup_active:
236 | mixup_fn = Mixup(
237 | mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
238 | prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
239 | label_smoothing=args.smoothing, num_classes=args.nb_classes)
240 |
241 | print(f"Creating model: {args.model}")
242 | model = create_model(
243 | args.model,
244 | pretrained=False,
245 | num_classes=args.nb_classes,
246 | drop_rate=args.drop,
247 | drop_path_rate=args.drop_path,
248 | drop_block_rate=None,
249 | )
250 |
251 | if args.finetune:
252 | if args.finetune.startswith('https'):
253 | checkpoint = torch.hub.load_state_dict_from_url(
254 | args.finetune, map_location='cpu', check_hash=True)
255 | else:
256 | checkpoint = torch.load(args.finetune, map_location='cpu')
257 |
258 | checkpoint_model = checkpoint['model']
259 | state_dict = model.state_dict()
260 | for k in ['head.weight', 'head.bias', 'head_dist.weight', 'head_dist.bias']:
261 | if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
262 | print(f"Removing key {k} from pretrained checkpoint")
263 | del checkpoint_model[k]
264 |
265 | # interpolate position embedding
266 | pos_embed_checkpoint = checkpoint_model['pos_embed']
267 | embedding_size = pos_embed_checkpoint.shape[-1]
268 | num_patches = model.patch_embed.num_patches
269 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches
270 | # height (== width) for the checkpoint position embedding
271 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
272 | # height (== width) for the new position embedding
273 | new_size = int(num_patches ** 0.5)
274 | # class_token and dist_token are kept unchanged
275 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
276 | # only the position tokens are interpolated
277 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
278 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
279 | pos_tokens = torch.nn.functional.interpolate(
280 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
281 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
282 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
283 | checkpoint_model['pos_embed'] = new_pos_embed
284 |
285 | model.load_state_dict(checkpoint_model, strict=False)
286 |
287 | model.to(device)
288 |
289 | model_ema = None
290 | if args.model_ema:
291 | # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
292 | model_ema = ModelEma(
293 | model,
294 | decay=args.model_ema_decay,
295 | device='cpu' if args.model_ema_force_cpu else '',
296 | resume='')
297 |
298 | model_without_ddp = model
299 | if args.distributed:
300 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
301 | model_without_ddp = model.module
302 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
303 | print('number of params:', n_parameters)
304 |
305 | linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size() / 512.0
306 | args.lr = linear_scaled_lr
307 | optimizer = create_optimizer(args, model_without_ddp)
308 | loss_scaler = NativeScaler()
309 |
310 | lr_scheduler, _ = create_scheduler(args, optimizer)
311 |
312 | criterion = LabelSmoothingCrossEntropy()
313 |
314 | if mixup_active:
315 | # smoothing is handled with mixup label transform
316 | criterion = SoftTargetCrossEntropy()
317 | elif args.smoothing:
318 | criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
319 | else:
320 | criterion = torch.nn.CrossEntropyLoss()
321 |
322 | teacher_model = None
323 | if args.distillation_type != 'none':
324 | assert args.teacher_path, 'need to specify teacher-path when using distillation'
325 | print(f"Creating teacher model: {args.teacher_model}")
326 | teacher_model = create_model(
327 | args.teacher_model,
328 | pretrained=False,
329 | num_classes=args.nb_classes,
330 | global_pool='avg',
331 | )
332 | if args.teacher_path.startswith('https'):
333 | checkpoint = torch.hub.load_state_dict_from_url(
334 | args.teacher_path, map_location='cpu', check_hash=True)
335 | else:
336 | checkpoint = torch.load(args.teacher_path, map_location='cpu')
337 | teacher_model.load_state_dict(checkpoint['model'])
338 | teacher_model.to(device)
339 | teacher_model.eval()
340 |
341 | # wrap the criterion in our custom DistillationLoss, which
342 | # just dispatches to the original criterion if args.distillation_type is 'none'
343 | criterion = DistillationLoss(
344 | criterion, teacher_model, args.distillation_type, args.distillation_alpha, args.distillation_tau
345 | )
346 |
347 | output_dir = Path(args.output_dir)
348 | if args.resume:
349 | if args.resume.startswith('https'):
350 | checkpoint = torch.hub.load_state_dict_from_url(
351 | args.resume, map_location='cpu', check_hash=True)
352 | else:
353 | checkpoint = torch.load(args.resume, map_location='cpu')
354 | model_without_ddp.load_state_dict(checkpoint['model'])
355 | if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
356 | optimizer.load_state_dict(checkpoint['optimizer'])
357 | lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
358 | args.start_epoch = checkpoint['epoch'] + 1
359 | if args.model_ema:
360 | utils._load_checkpoint_for_ema(model_ema, checkpoint['model_ema'])
361 | if 'scaler' in checkpoint:
362 | loss_scaler.load_state_dict(checkpoint['scaler'])
363 | lr_scheduler.step(args.start_epoch)
364 | if args.eval:
365 | test_stats = evaluate(data_loader_val, model, device)
366 | print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
367 | return
368 |
369 | print(f"Start training for {args.epochs} epochs")
370 | start_time = time.time()
371 | max_accuracy = 0.0
372 | for epoch in range(args.start_epoch, args.epochs):
373 | if args.distributed:
374 | data_loader_train.sampler.set_epoch(epoch)
375 |
376 | train_stats = train_one_epoch(
377 | model, criterion, data_loader_train,
378 | optimizer, device, epoch, loss_scaler,
379 | args.clip_grad, model_ema, mixup_fn,
380 | set_training_mode=args.finetune == '' # keep in eval mode during finetuning
381 | )
382 |
383 | lr_scheduler.step(epoch)
384 | if args.output_dir:
385 | checkpoint_paths = [output_dir / 'checkpoint.pth']
386 | for checkpoint_path in checkpoint_paths:
387 | utils.save_on_master({
388 | 'model': model_without_ddp.state_dict(),
389 | 'optimizer': optimizer.state_dict(),
390 | 'lr_scheduler': lr_scheduler.state_dict(),
391 | 'epoch': epoch,
392 | 'model_ema': get_state_dict(model_ema),
393 | 'scaler': loss_scaler.state_dict(),
394 | 'args': args,
395 | }, checkpoint_path)
396 |
397 |
398 | test_stats = evaluate(data_loader_val, model, device)
399 | print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
400 |
401 | if max_accuracy < test_stats["acc1"]:
402 | max_accuracy = test_stats["acc1"]
403 | if args.output_dir:
404 | checkpoint_paths = [output_dir / 'best_checkpoint.pth']
405 | for checkpoint_path in checkpoint_paths:
406 | utils.save_on_master({
407 | 'model': model_without_ddp.state_dict(),
408 | 'optimizer': optimizer.state_dict(),
409 | 'lr_scheduler': lr_scheduler.state_dict(),
410 | 'epoch': epoch,
411 | 'model_ema': get_state_dict(model_ema),
412 | 'scaler': loss_scaler.state_dict(),
413 | 'args': args,
414 | }, checkpoint_path)
415 |
416 | print(f'Max accuracy: {max_accuracy:.2f}%')
417 |
418 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
419 | **{f'test_{k}': v for k, v in test_stats.items()},
420 | 'epoch': epoch,
421 | 'n_parameters': n_parameters}
422 |
423 |
424 |
425 |
426 | if args.output_dir and utils.is_main_process():
427 | with (output_dir / "log.txt").open("a") as f:
428 | f.write(json.dumps(log_stats) + "\n")
429 |
430 | total_time = time.time() - start_time
431 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
432 | print('Training time {}'.format(total_time_str))
433 |
434 |
435 | if __name__ == '__main__':
436 | parser = argparse.ArgumentParser('DeiT training and evaluation script', parents=[get_args_parser()])
437 | args = parser.parse_args()
438 | if args.output_dir:
439 | Path(args.output_dir).mkdir(parents=True, exist_ok=True)
440 | main(args)
441 |
--------------------------------------------------------------------------------
/KNN_VisionTransformer.py:
--------------------------------------------------------------------------------
1 | """
2 | modified from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
3 | """
4 | import math
5 | import logging
6 | from functools import partial
7 | from collections import OrderedDict
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 |
13 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
14 | from timm.models.helpers import load_pretrained
15 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_
16 | from timm.models.resnet import resnet26d, resnet50d
17 | from timm.models.resnetv2 import ResNetV2, StdConv2dSame
18 | from timm.models.registry import register_model
19 |
20 | _logger = logging.getLogger(__name__)
21 |
22 | class Mlp(nn.Module):
23 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
24 | super().__init__()
25 | out_features = out_features or in_features
26 | hidden_features = hidden_features or in_features
27 | self.fc1 = nn.Linear(in_features, hidden_features)
28 | self.act = act_layer()
29 | self.fc2 = nn.Linear(hidden_features, out_features)
30 | self.drop = nn.Dropout(drop)
31 |
32 | def forward(self, x):
33 | x = self.fc1(x)
34 | x = self.act(x)
35 | x = self.drop(x)
36 | x = self.fc2(x)
37 | x = self.drop(x)
38 | return x
39 |
40 |
41 | class kNNAttention(nn.Module):
42 |     def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., topk=100):
43 | super().__init__()
44 | self.num_heads = num_heads
45 | head_dim = dim // num_heads
46 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
47 | self.scale = qk_scale or head_dim ** -0.5
48 |
49 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
50 | self.attn_drop = nn.Dropout(attn_drop)
51 | self.proj = nn.Linear(dim, dim)
52 | self.proj_drop = nn.Dropout(proj_drop)
53 | self.topk = topk
54 |
55 | def forward(self, x):
56 | B, N, C = x.shape
57 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
58 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
59 | attn = (q @ k.transpose(-2, -1)) * self.scale
60 |         # the core code block: for each query, keep only the self.topk largest attention scores and mask out the rest before the softmax (k-NN attention)
61 |         mask = torch.zeros(B, self.num_heads, N, N, device=x.device, requires_grad=False)
62 |         index = torch.topk(attn, k=self.topk, dim=-1, largest=True)[1]  # indices of the top-k scores per query
63 |         mask.scatter_(-1, index, 1.)  # mark the selected positions
64 |         attn = torch.where(mask > 0, attn, torch.full_like(attn, float('-inf')))  # -inf -> zero weight after softmax
65 |         # end of the core code block
66 |
67 | attn = attn.softmax(dim=-1)
68 | attn = self.attn_drop(attn)
69 |
70 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
71 | x = self.proj(x)
72 | x = self.proj_drop(x)
73 | return x
74 |
75 |
76 | class Block(nn.Module):
77 |
78 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
79 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
80 | super().__init__()
81 | self.norm1 = norm_layer(dim)
82 | self.attn = kNNAttention(
83 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
84 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
85 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
86 | self.norm2 = norm_layer(dim)
87 | mlp_hidden_dim = int(dim * mlp_ratio)
88 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
89 |
90 | def forward(self, x):
91 | x = x + self.drop_path(self.attn(self.norm1(x)))
92 | x = x + self.drop_path(self.mlp(self.norm2(x)))
93 | return x
94 |
95 |
96 | class PatchEmbed(nn.Module):
97 | """ Image to Patch Embedding
98 | """
99 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
100 | super().__init__()
101 | img_size = to_2tuple(img_size)
102 | patch_size = to_2tuple(patch_size)
103 | num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
104 | self.img_size = img_size
105 | self.patch_size = patch_size
106 | self.num_patches = num_patches
107 |
108 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
109 |
110 | def forward(self, x):
111 | B, C, H, W = x.shape
112 | # FIXME look at relaxing size constraints
113 | assert H == self.img_size[0] and W == self.img_size[1], \
114 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
115 | x = self.proj(x).flatten(2).transpose(1, 2)
116 | return x
117 |
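# Illustrative shape check: with the defaults, a 224x224 image is split into
# 14x14 = 196 non-overlapping 16x16 patches, each projected to embed_dim.
#
#   >>> pe = PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768)
#   >>> pe(torch.randn(1, 3, 224, 224)).shape
#   torch.Size([1, 196, 768])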
118 |
119 | class HybridEmbed(nn.Module):
120 | """ CNN Feature Map Embedding
121 | Extract feature map from CNN, flatten, project to embedding dim.
122 | """
123 | def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768):
124 | super().__init__()
125 | assert isinstance(backbone, nn.Module)
126 | img_size = to_2tuple(img_size)
127 | self.img_size = img_size
128 | self.backbone = backbone
129 | if feature_size is None:
130 | with torch.no_grad():
131 |                 # FIXME this is hacky, but it is the most reliable way of determining the exact output
132 |                 # feature-map dim for all networks; the feature metadata has reliable channel and stride info,
133 |                 # but using the stride to compute the feature dim would require per-stage padding info that isn't captured.
134 | training = backbone.training
135 | if training:
136 | backbone.eval()
137 | o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))
138 | if isinstance(o, (list, tuple)):
139 | o = o[-1] # last feature if backbone outputs list/tuple of features
140 | feature_size = o.shape[-2:]
141 | feature_dim = o.shape[1]
142 | backbone.train(training)
143 | else:
144 | feature_size = to_2tuple(feature_size)
145 | if hasattr(self.backbone, 'feature_info'):
146 | feature_dim = self.backbone.feature_info.channels()[-1]
147 | else:
148 | feature_dim = self.backbone.num_features
149 | self.num_patches = feature_size[0] * feature_size[1]
150 | self.proj = nn.Conv2d(feature_dim, embed_dim, 1)
151 |
152 | def forward(self, x):
153 | x = self.backbone(x)
154 | if isinstance(x, (list, tuple)):
155 | x = x[-1] # last feature if backbone outputs list/tuple of features
156 | x = self.proj(x).flatten(2).transpose(1, 2)
157 | return x
158 |
159 |
160 | class VisionTransformer(nn.Module):
161 | """ Vision Transformer
162 |
163 |     A PyTorch impl of: `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
164 | https://arxiv.org/abs/2010.11929
165 | """
166 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
167 | num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
168 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0., hybrid_backbone=None, norm_layer=None):
169 | """
170 | Args:
171 | img_size (int, tuple): input image size
172 | patch_size (int, tuple): patch size
173 | in_chans (int): number of input channels
174 | num_classes (int): number of classes for classification head
175 | embed_dim (int): embedding dimension
176 | depth (int): depth of transformer
177 | num_heads (int): number of attention heads
178 |             mlp_ratio (float): ratio of mlp hidden dim to embedding dim
179 | qkv_bias (bool): enable bias for qkv if True
180 | qk_scale (float): override default qk scale of head_dim ** -0.5 if set
181 | representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
182 | drop_rate (float): dropout rate
183 | attn_drop_rate (float): attention dropout rate
184 | drop_path_rate (float): stochastic depth rate
185 | hybrid_backbone (nn.Module): CNN backbone to use in-place of PatchEmbed module
186 |             norm_layer (nn.Module): normalization layer
187 | """
188 | super().__init__()
189 | self.num_classes = num_classes
190 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
191 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
192 |
193 | if hybrid_backbone is not None:
194 | self.patch_embed = HybridEmbed(
195 | hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
196 | else:
197 | self.patch_embed = PatchEmbed(
198 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
199 | num_patches = self.patch_embed.num_patches
200 |
201 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
202 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
203 | self.pos_drop = nn.Dropout(p=drop_rate)
204 |
205 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
206 | self.blocks = nn.ModuleList([
207 | Block(
208 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
209 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
210 | for i in range(depth)])
211 | self.norm = norm_layer(embed_dim)
212 |
213 | # Representation layer
214 | if representation_size:
215 | self.num_features = representation_size
216 | self.pre_logits = nn.Sequential(OrderedDict([
217 | ('fc', nn.Linear(embed_dim, representation_size)),
218 | ('act', nn.Tanh())
219 | ]))
220 | else:
221 | self.pre_logits = nn.Identity()
222 |
223 | # Classifier head
224 | self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
225 |
226 | trunc_normal_(self.pos_embed, std=.02)
227 | trunc_normal_(self.cls_token, std=.02)
228 | self.apply(self._init_weights)
229 |
230 | def _init_weights(self, m):
231 | if isinstance(m, nn.Linear):
232 | trunc_normal_(m.weight, std=.02)
233 | if isinstance(m, nn.Linear) and m.bias is not None:
234 | nn.init.constant_(m.bias, 0)
235 | elif isinstance(m, nn.LayerNorm):
236 | nn.init.constant_(m.bias, 0)
237 | nn.init.constant_(m.weight, 1.0)
238 |
239 | @torch.jit.ignore
240 | def no_weight_decay(self):
241 | return {'pos_embed', 'cls_token'}
242 |
243 | def get_classifier(self):
244 | return self.head
245 |
246 | def reset_classifier(self, num_classes, global_pool=''):
247 | self.num_classes = num_classes
248 |         self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()  # num_features matches the pre-logits output dim
249 |
250 | def forward_features(self, x):
251 | B = x.shape[0]
252 | x = self.patch_embed(x)
253 |
254 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
255 | x = torch.cat((cls_tokens, x), dim=1)
256 | x = x + self.pos_embed
257 | x = self.pos_drop(x)
258 |
259 | for blk in self.blocks:
260 | x = blk(x)
261 |
262 | x = self.norm(x)[:, 0]
263 | x = self.pre_logits(x)
264 | return x
265 |
266 | def forward(self, x):
267 | x = self.forward_features(x)
268 | x = self.head(x)
269 | return x
270 |
271 |
272 | class DistilledVisionTransformer(VisionTransformer):
273 | """ Vision Transformer with distillation token.
274 |
275 | Paper: `Training data-efficient image transformers & distillation through attention` -
276 | https://arxiv.org/abs/2012.12877
277 |
278 | This impl of distilled ViT is taken from https://github.com/facebookresearch/deit
279 | """
280 | def __init__(self, *args, **kwargs):
281 | super().__init__(*args, **kwargs)
282 | self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
283 | num_patches = self.patch_embed.num_patches
284 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, self.embed_dim))
285 | self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity()
286 |
287 | trunc_normal_(self.dist_token, std=.02)
288 | trunc_normal_(self.pos_embed, std=.02)
289 | self.head_dist.apply(self._init_weights)
290 |
291 | def forward_features(self, x):
292 | B = x.shape[0]
293 | x = self.patch_embed(x)
294 |
295 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
296 | dist_token = self.dist_token.expand(B, -1, -1)
297 | x = torch.cat((cls_tokens, dist_token, x), dim=1)
298 |
299 | x = x + self.pos_embed
300 | x = self.pos_drop(x)
301 |
302 | for blk in self.blocks:
303 | x = blk(x)
304 |
305 | x = self.norm(x)
306 | return x[:, 0], x[:, 1]
307 |
308 | def forward(self, x):
309 | x, x_dist = self.forward_features(x)
310 | x = self.head(x)
311 | x_dist = self.head_dist(x_dist)
312 | if self.training:
313 | return x, x_dist
314 | else:
315 | # during inference, return the average of both classifier predictions
316 | return (x + x_dist) / 2
317 |
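# Illustrative note: in training mode the distilled model returns a (cls_logits,
# dist_logits) tuple for the DeiT distillation loss; in eval mode it returns the
# average of the two heads, e.g. (assuming the default 1000 classes):
#
#   >>> m = DistilledVisionTransformer(embed_dim=192, depth=1, num_heads=3).eval()
#   >>> m(torch.randn(1, 3, 224, 224)).shape
#   torch.Size([1, 1000])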
318 |
319 | def resize_pos_embed(posemb, posemb_new):
320 | # Rescale the grid of position embeddings when loading from state_dict. Adapted from
321 | # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
322 | _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape)
323 | ntok_new = posemb_new.shape[1]
324 |     # the checkpoints handled here always carry a class token, so split it off
325 |     # and interpolate only the patch-grid portion of the position embedding
326 |     posemb_tok, posemb_grid = posemb[:, :1], posemb[0, 1:]
327 |     ntok_new -= 1
328 |
329 | gs_old = int(math.sqrt(len(posemb_grid)))
330 | gs_new = int(math.sqrt(ntok_new))
331 | _logger.info('Position embedding grid-size from %s to %s', gs_old, gs_new)
332 | posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
333 | posemb_grid = F.interpolate(posemb_grid, size=(gs_new, gs_new), mode='bilinear')
334 | posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new * gs_new, -1)
335 | posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
336 | return posemb
337 |
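# Illustrative shape check (assuming ViT-B/16-style weights with a class token):
# resizing from a 224x224 checkpoint (14x14 grid + 1 token) to a 384x384 model
# (24x24 grid + 1 token).
#
#   >>> resize_pos_embed(torch.zeros(1, 197, 768), torch.zeros(1, 577, 768)).shape
#   torch.Size([1, 577, 768])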
338 |
339 | def checkpoint_filter_fn(state_dict, model):
340 | """ convert patch embedding weight from manual patchify + linear proj to conv"""
341 | out_dict = {}
342 | if 'model' in state_dict:
343 | # For deit models
344 | state_dict = state_dict['model']
345 | for k, v in state_dict.items():
346 | if 'patch_embed.proj.weight' in k and len(v.shape) < 4:
347 | # For old models that I trained prior to conv based patchification
348 | O, I, H, W = model.patch_embed.proj.weight.shape
349 | v = v.reshape(O, -1, H, W)
350 | elif k == 'pos_embed' and v.shape != model.pos_embed.shape:
351 | # To resize pos embedding when using model at different size from pretrained weights
352 | v = resize_pos_embed(v, model.pos_embed)
353 | out_dict[k] = v
354 | return out_dict
355 |
356 |
357 | def _create_vision_transformer(variant, pretrained=False, distilled=False, **kwargs):
358 | default_cfg = default_cfgs[variant]
359 | default_num_classes = default_cfg['num_classes']
360 | default_img_size = default_cfg['input_size'][-1]
361 |
362 | num_classes = kwargs.pop('num_classes', default_num_classes)
363 | img_size = kwargs.pop('img_size', default_img_size)
364 | repr_size = kwargs.pop('representation_size', None)
365 | if repr_size is not None and num_classes != default_num_classes:
366 |         # Remove the representation layer when fine-tuning to a different number of classes. This may not
367 |         # always be the desired action, but it is better than doing nothing by default. Perhaps a better interface?
368 | _logger.warning("Removing representation layer for fine-tuning.")
369 | repr_size = None
370 |
371 | model_cls = DistilledVisionTransformer if distilled else VisionTransformer
372 | model = model_cls(img_size=img_size, num_classes=num_classes, representation_size=repr_size, **kwargs)
373 | model.default_cfg = default_cfg
374 |
375 | if pretrained:
376 | load_pretrained(
377 | model, num_classes=num_classes, in_chans=kwargs.get('in_chans', 3),
378 | filter_fn=partial(checkpoint_filter_fn, model=model))
379 | return model
380 |
381 |
382 | @register_model
383 | def vit_small_patch16_224(pretrained=False, **kwargs):
384 |     """ My custom 'small' ViT model. Depth=8, heads=8, mlp_ratio=3."""
385 | model_kwargs = dict(
386 | patch_size=16, embed_dim=768, depth=8, num_heads=8, mlp_ratio=3.,
387 | qkv_bias=False, norm_layer=nn.LayerNorm, **kwargs)
388 | if pretrained:
389 | # NOTE my scale was wrong for original weights, leaving this here until I have better ones for this model
390 | model_kwargs.setdefault('qk_scale', 768 ** -0.5)
391 | model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **model_kwargs)
392 | return model
393 |
394 |
395 | @register_model
396 | def vit_base_patch16_224(pretrained=False, **kwargs):
397 | """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
398 | ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
399 | """
400 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
401 | model = _create_vision_transformer('vit_base_patch16_224', pretrained=pretrained, **model_kwargs)
402 | return model
403 |
404 |
405 | @register_model
406 | def vit_base_patch32_224(pretrained=False, **kwargs):
407 | """ ViT-Base (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights.
408 | """
409 | model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs)
410 | model = _create_vision_transformer('vit_base_patch32_224', pretrained=pretrained, **model_kwargs)
411 | return model
412 |
413 |
414 | @register_model
415 | def vit_base_patch16_384(pretrained=False, **kwargs):
416 | """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
417 | ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
418 | """
419 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
420 | model = _create_vision_transformer('vit_base_patch16_384', pretrained=pretrained, **model_kwargs)
421 | return model
422 |
423 |
424 | @register_model
425 | def vit_base_patch32_384(pretrained=False, **kwargs):
426 | """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
427 | ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
428 | """
429 | model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs)
430 | model = _create_vision_transformer('vit_base_patch32_384', pretrained=pretrained, **model_kwargs)
431 | return model
432 |
433 |
434 | @register_model
435 | def vit_large_patch16_224(pretrained=False, **kwargs):
436 |     """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
437 | ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
438 | """
439 | model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
440 | model = _create_vision_transformer('vit_large_patch16_224', pretrained=pretrained, **model_kwargs)
441 | return model
442 |
443 |
444 | @register_model
445 | def vit_large_patch32_224(pretrained=False, **kwargs):
446 | """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights.
447 | """
448 | model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs)
449 | model = _create_vision_transformer('vit_large_patch32_224', pretrained=pretrained, **model_kwargs)
450 | return model
451 |
452 |
453 | @register_model
454 | def vit_large_patch16_384(pretrained=False, **kwargs):
455 | """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
456 | ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
457 | """
458 | model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
459 | model = _create_vision_transformer('vit_large_patch16_384', pretrained=pretrained, **model_kwargs)
460 | return model
461 |
462 |
463 | @register_model
464 | def vit_large_patch32_384(pretrained=False, **kwargs):
465 | """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
466 | ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
467 | """
468 | model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs)
469 | model = _create_vision_transformer('vit_large_patch32_384', pretrained=pretrained, **model_kwargs)
470 | return model
471 |
472 |
473 | @register_model
474 | def vit_base_patch16_224_in21k(pretrained=False, **kwargs):
475 | """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
476 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
477 | """
478 | model_kwargs = dict(
479 | patch_size=16, embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs)
480 | model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
481 | return model
482 |
483 |
484 | @register_model
485 | def vit_base_patch32_224_in21k(pretrained=False, **kwargs):
486 | """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
487 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
488 | """
489 | model_kwargs = dict(
490 | patch_size=32, embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs)
491 | model = _create_vision_transformer('vit_base_patch32_224_in21k', pretrained=pretrained, **model_kwargs)
492 | return model
493 |
494 |
495 | @register_model
496 | def vit_large_patch16_224_in21k(pretrained=False, **kwargs):
497 | """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
498 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
499 | """
500 | model_kwargs = dict(
501 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, representation_size=1024, **kwargs)
502 | model = _create_vision_transformer('vit_large_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
503 | return model
504 |
505 |
506 | @register_model
507 | def vit_large_patch32_224_in21k(pretrained=False, **kwargs):
508 | """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
509 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
510 | """
511 | model_kwargs = dict(
512 | patch_size=32, embed_dim=1024, depth=24, num_heads=16, representation_size=1024, **kwargs)
513 | model = _create_vision_transformer('vit_large_patch32_224_in21k', pretrained=pretrained, **model_kwargs)
514 | return model
515 |
516 |
517 | @register_model
518 | def vit_huge_patch14_224_in21k(pretrained=False, **kwargs):
519 | """ ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929).
520 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
521 | NOTE: converted weights not currently available, too large for github release hosting.
522 | """
523 | model_kwargs = dict(
524 | patch_size=14, embed_dim=1280, depth=32, num_heads=16, representation_size=1280, **kwargs)
525 | model = _create_vision_transformer('vit_huge_patch14_224_in21k', pretrained=pretrained, **model_kwargs)
526 | return model
527 |
528 |
529 | @register_model
530 | def vit_base_resnet50_224_in21k(pretrained=False, **kwargs):
531 | """ R50+ViT-B/16 hybrid model from original paper (https://arxiv.org/abs/2010.11929).
532 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
533 | """
534 | # create a ResNetV2 w/o pre-activation, that uses StdConv and GroupNorm and has 3 stages, no head
535 | backbone = ResNetV2(
536 | layers=(3, 4, 9), num_classes=0, global_pool='', in_chans=kwargs.get('in_chans', 3),
537 | preact=False, stem_type='same', conv_layer=StdConv2dSame)
538 | model_kwargs = dict(
539 | embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone,
540 | representation_size=768, **kwargs)
541 | model = _create_vision_transformer('vit_base_resnet50_224_in21k', pretrained=pretrained, **model_kwargs)
542 | return model
543 |
544 |
545 | @register_model
546 | def vit_base_resnet50_384(pretrained=False, **kwargs):
547 | """ R50+ViT-B/16 hybrid from original paper (https://arxiv.org/abs/2010.11929).
548 | ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
549 | """
550 | # create a ResNetV2 w/o pre-activation, that uses StdConv and GroupNorm and has 3 stages, no head
551 | backbone = ResNetV2(
552 | layers=(3, 4, 9), num_classes=0, global_pool='', in_chans=kwargs.get('in_chans', 3),
553 | preact=False, stem_type='same', conv_layer=StdConv2dSame)
554 | model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs)
555 | model = _create_vision_transformer('vit_base_resnet50_384', pretrained=pretrained, **model_kwargs)
556 | return model
557 |
558 |
559 | @register_model
560 | def vit_small_resnet26d_224(pretrained=False, **kwargs):
561 | """ Custom ViT small hybrid w/ ResNet26D stride 32. No pretrained weights.
562 | """
563 | backbone = resnet26d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[4])
564 | model_kwargs = dict(embed_dim=768, depth=8, num_heads=8, mlp_ratio=3, hybrid_backbone=backbone, **kwargs)
565 | model = _create_vision_transformer('vit_small_resnet26d_224', pretrained=pretrained, **model_kwargs)
566 | return model
567 |
568 |
569 | @register_model
570 | def vit_small_resnet50d_s3_224(pretrained=False, **kwargs):
571 | """ Custom ViT small hybrid w/ ResNet50D 3-stages, stride 16. No pretrained weights.
572 | """
573 | backbone = resnet50d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[3])
574 | model_kwargs = dict(embed_dim=768, depth=8, num_heads=8, mlp_ratio=3, hybrid_backbone=backbone, **kwargs)
575 | model = _create_vision_transformer('vit_small_resnet50d_s3_224', pretrained=pretrained, **model_kwargs)
576 | return model
577 |
578 |
579 | @register_model
580 | def vit_base_resnet26d_224(pretrained=False, **kwargs):
581 | """ Custom ViT base hybrid w/ ResNet26D stride 32. No pretrained weights.
582 | """
583 | backbone = resnet26d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[4])
584 | model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs)
585 | model = _create_vision_transformer('vit_base_resnet26d_224', pretrained=pretrained, **model_kwargs)
586 | return model
587 |
588 |
589 | @register_model
590 | def vit_base_resnet50d_224(pretrained=False, **kwargs):
591 | """ Custom ViT base hybrid w/ ResNet50D stride 32. No pretrained weights.
592 | """
593 | backbone = resnet50d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[4])
594 | model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs)
595 | model = _create_vision_transformer('vit_base_resnet50d_224', pretrained=pretrained, **model_kwargs)
596 | return model
597 |
598 |
599 | @register_model
600 | def vit_deit_tiny_patch16_224(pretrained=False, **kwargs):
601 | """ DeiT-tiny model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
602 | ImageNet-1k weights from https://github.com/facebookresearch/deit.
603 | """
604 | model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
605 | model = _create_vision_transformer('vit_deit_tiny_patch16_224', pretrained=pretrained, **model_kwargs)
606 | return model
607 |
608 |
609 | @register_model
610 | def vit_deit_small_patch16_224(pretrained=False, **kwargs):
611 | """ DeiT-small model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
612 | ImageNet-1k weights from https://github.com/facebookresearch/deit.
613 | """
614 | model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
615 | model = _create_vision_transformer('vit_deit_small_patch16_224', pretrained=pretrained, **model_kwargs)
616 | return model
617 |
618 |
619 | @register_model
620 | def vit_deit_base_patch16_224(pretrained=False, **kwargs):
621 | """ DeiT base model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
622 | ImageNet-1k weights from https://github.com/facebookresearch/deit.
623 | """
624 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
625 | model = _create_vision_transformer('vit_deit_base_patch16_224', pretrained=pretrained, **model_kwargs)
626 | return model
627 |
628 |
629 | @register_model
630 | def vit_deit_base_patch16_384(pretrained=False, **kwargs):
631 | """ DeiT base model @ 384x384 from paper (https://arxiv.org/abs/2012.12877).
632 | ImageNet-1k weights from https://github.com/facebookresearch/deit.
633 | """
634 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
635 | model = _create_vision_transformer('vit_deit_base_patch16_384', pretrained=pretrained, **model_kwargs)
636 | return model
637 |
638 |
639 | @register_model
640 | def vit_deit_tiny_distilled_patch16_224(pretrained=False, **kwargs):
641 | """ DeiT-tiny distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
642 | ImageNet-1k weights from https://github.com/facebookresearch/deit.
643 | """
644 | model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
645 | model = _create_vision_transformer(
646 | 'vit_deit_tiny_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
647 | return model
648 |
649 |
650 | @register_model
651 | def vit_deit_small_distilled_patch16_224(pretrained=False, **kwargs):
652 | """ DeiT-small distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
653 | ImageNet-1k weights from https://github.com/facebookresearch/deit.
654 | """
655 | model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
656 | model = _create_vision_transformer(
657 | 'vit_deit_small_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
658 | return model
659 |
660 |
661 | @register_model
662 | def vit_deit_base_distilled_patch16_224(pretrained=False, **kwargs):
663 | """ DeiT-base distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
664 | ImageNet-1k weights from https://github.com/facebookresearch/deit.
665 | """
666 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
667 | model = _create_vision_transformer(
668 | 'vit_deit_base_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
669 | return model
670 |
671 |
672 | @register_model
673 | def vit_deit_base_distilled_patch16_384(pretrained=False, **kwargs):
674 | """ DeiT-base distilled model @ 384x384 from paper (https://arxiv.org/abs/2012.12877).
675 | ImageNet-1k weights from https://github.com/facebookresearch/deit.
676 | """
677 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
678 | model = _create_vision_transformer(
679 | 'vit_deit_base_distilled_patch16_384', pretrained=pretrained, distilled=True, **model_kwargs)
680 | return model
681 |
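if __name__ == '__main__':
    # Minimal smoke test (illustrative): build a small kNN-attention ViT directly,
    # bypassing the registered factory functions above (and therefore any
    # `default_cfgs` / pretrained-weight setup), and run a dummy batch through it.
    model = VisionTransformer(patch_size=16, embed_dim=192, depth=12, num_heads=3,
                              num_classes=1000)
    out = model(torch.randn(2, 3, 224, 224))
    print(out.shape)  # expected: torch.Size([2, 1000])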
--------------------------------------------------------------------------------