├── src
│   └── method.png
├── augmentations
│   └── dino_augmentation.py
├── utils
│   ├── preprocess.py
│   ├── optimizers.py
│   ├── metrics.py
│   ├── checkpoint_io.py
│   └── utils.py
├── README.md
├── backbones
│   ├── resnet.py
│   └── vision_transformer.py
├── models
│   └── mokd.py
├── LICENSE
├── eval_knn.py
├── eval_linear.py
└── main_mokd.py
/src/method.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/skyoux/mokd/HEAD/src/method.png
--------------------------------------------------------------------------------
/augmentations/dino_augmentation.py:
--------------------------------------------------------------------------------
1 |
2 | from torchvision import transforms
3 | from PIL import Image
4 |
5 | from utils.preprocess import GaussianBlur, Solarization
6 |
7 |
8 | class DINODataAugmentation(object):
9 | def __init__(self, global_crops_scale, local_crops_scale, local_crops_number):
10 | flip_and_color_jitter = transforms.Compose([
11 | transforms.RandomHorizontalFlip(p=0.5),
12 | transforms.RandomApply(
13 | [transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)],
14 | p=0.8
15 | ),
16 | transforms.RandomGrayscale(p=0.2),
17 | ])
18 | normalize = transforms.Compose([
19 | transforms.ToTensor(),
20 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
21 | ])
22 |
23 | # first global crop
24 | self.global_transfo1 = transforms.Compose([
25 | transforms.RandomResizedCrop(224, scale=global_crops_scale, interpolation=Image.BICUBIC),
26 | flip_and_color_jitter,
27 | GaussianBlur(1.0),
28 | normalize,
29 | ])
30 | # second global crop
31 | self.global_transfo2 = transforms.Compose([
32 | transforms.RandomResizedCrop(224, scale=global_crops_scale, interpolation=Image.BICUBIC),
33 | flip_and_color_jitter,
34 | GaussianBlur(0.1),
35 | Solarization(0.2),
36 | normalize,
37 | ])
38 | # transformation for the local small crops
39 | self.local_crops_number = local_crops_number
40 | self.local_transfo = transforms.Compose([
41 | transforms.RandomResizedCrop(96, scale=local_crops_scale, interpolation=Image.BICUBIC),
42 | flip_and_color_jitter,
43 | GaussianBlur(p=0.5),
44 | normalize,
45 | ])
46 |
47 | def __call__(self, image):
48 | crops = []
49 | crops.append(self.global_transfo1(image))
50 | crops.append(self.global_transfo2(image))
51 | for _ in range(self.local_crops_number):
52 | crops.append(self.local_transfo(image))
53 | return crops
--------------------------------------------------------------------------------
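The transform above returns a list of views per image: two 224x224 global crops plus `local_crops_number` 96x96 local crops. A minimal usage sketch (not part of the repository; the dataset path is a placeholder and the crop parameters mirror the README defaults):

```python
from torchvision import datasets
from augmentations.dino_augmentation import DINODataAugmentation

transform = DINODataAugmentation(
    global_crops_scale=(0.25, 1.0),   # --global_crops_scale 0.25 1
    local_crops_scale=(0.05, 0.25),   # --local_crops_scale 0.05 0.25
    local_crops_number=8,             # --local_crops_number 8
)
dataset = datasets.ImageFolder("/path/to/imagenet/train", transform=transform)
crops, label = dataset[0]
# crops is a list of 2 + 8 = 10 tensors: two of shape (3, 224, 224), eight of shape (3, 96, 96)
```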
/utils/preprocess.py:
--------------------------------------------------------------------------------
1 |
2 | import random
3 | from PIL import ImageFilter, ImageOps
4 | import numpy as np
5 | from numpy.random import randint
6 |
7 | import torch
8 |
9 | class GaussianBlur(object):
10 | """
11 | Apply Gaussian Blur to the PIL image.
12 | """
13 | def __init__(self, p=0.5, radius_min=0.1, radius_max=2.):
14 | self.prob = p
15 | self.radius_min = radius_min
16 | self.radius_max = radius_max
17 |
18 | def __call__(self, img):
19 | do_it = random.random() <= self.prob
20 | if not do_it:
21 | return img
22 |
23 | return img.filter(
24 | ImageFilter.GaussianBlur(
25 | radius=random.uniform(self.radius_min, self.radius_max)
26 | )
27 | )
28 |
29 |
30 | class Solarization(object):
31 | """
32 | Apply Solarization to the PIL image.
33 | """
34 | def __init__(self, p):
35 | self.p = p
36 |
37 | def __call__(self, img):
38 | if random.random() < self.p:
39 | return ImageOps.solarize(img)
40 | else:
41 | return img
42 |
43 |
44 | def drop_rand_patches(X, X_rep=None, max_drop=0.3, max_block_sz=0.25, tolr=0.05):
45 | #######################
46 | # X_rep: replace X with patches from X_rep. If X_rep is None, replace the patches with Noise
47 | # max_drop: percentage of image to be dropped
48 | # max_block_sz: percentage of the maximum block to be dropped
49 | # tolr: minimum size of the block in terms of percentage of the image size
50 | #######################
51 |
52 | C, H, W = X.size()
53 | n_drop_pix = np.random.uniform(0, max_drop)*H*W
54 | mx_blk_height = int(H*max_block_sz)
55 | mx_blk_width = int(W*max_block_sz)
56 |
57 | tolr = (int(tolr*H), int(tolr*W))
58 |
59 | total_pix = 0
60 | while total_pix < n_drop_pix:
61 |
62 | # get a random block by selecting a random row, column, width, height
63 | rnd_r = randint(0, H-tolr[0])
64 | rnd_c = randint(0, W-tolr[1])
65 |         rnd_h = min(randint(tolr[0], mx_blk_height)+rnd_r, H)  # rnd_r is already added, so this is the end row, not a height
66 | rnd_w = min(randint(tolr[1], mx_blk_width)+rnd_c, W)
67 |
68 | if X_rep is None:
69 | X[:, rnd_r:rnd_h, rnd_c:rnd_w] = torch.empty((C, rnd_h-rnd_r, rnd_w-rnd_c), dtype=X.dtype, device='cuda').normal_()
70 | else:
71 | X[:, rnd_r:rnd_h, rnd_c:rnd_w] = X_rep[:, rnd_r:rnd_h, rnd_c:rnd_w]
72 |
73 | total_pix = total_pix + (rnd_h-rnd_r)*(rnd_w-rnd_c)
74 |
75 | return X
76 |
--------------------------------------------------------------------------------
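A small illustration of `drop_rand_patches` with made-up tensors (note that when `X_rep` is None the dropped blocks are filled with noise allocated on `'cuda'`, so the input must live on a GPU in that case):

```python
import torch
from utils.preprocess import drop_rand_patches

img = torch.rand(3, 224, 224, device="cuda")            # a single CHW image tensor
noisy = drop_rand_patches(img.clone(), max_drop=0.3)     # blocks replaced by Gaussian noise

other = torch.rand(3, 224, 224, device="cuda")
mixed = drop_rand_patches(img.clone(), X_rep=other)      # blocks copied from the second image
```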
/utils/optimizers.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 |
4 |
5 | class LARS(torch.optim.Optimizer):
6 | def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, eta=0.001,
7 | weight_decay_filter=None, lars_adaptation_filter=None):
8 | defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum,
9 | eta=eta, weight_decay_filter=weight_decay_filter,
10 | lars_adaptation_filter=lars_adaptation_filter)
11 | super().__init__(params, defaults)
12 |
13 | @torch.no_grad()
14 | def step(self):
15 | for g in self.param_groups:
16 | for p in g['params']:
17 | dp = p.grad
18 |
19 | if dp is None:
20 | continue
21 |
22 | if p.ndim != 1:
23 | dp = dp.add(p, alpha=g['weight_decay'])
24 |
25 | if p.ndim != 1:
26 | param_norm = torch.norm(p)
27 | update_norm = torch.norm(dp)
28 | one = torch.ones_like(param_norm)
29 | q = torch.where(param_norm > 0.,
30 | torch.where(update_norm > 0,
31 | (g['eta'] * param_norm / update_norm), one), one)
32 | dp = dp.mul(q)
33 |
34 | param_state = self.state[p]
35 | if 'mu' not in param_state:
36 | param_state['mu'] = torch.zeros_like(p)
37 | mu = param_state['mu']
38 | mu.mul_(g['momentum']).add_(dp)
39 |
40 | p.add_(mu, alpha=-g['lr'])
41 |
42 |
43 | class NativeScalerWithGradNormCount:
44 | state_dict_key = "amp_scaler"
45 |
46 | def __init__(self):
47 | self._scaler = torch.cuda.amp.GradScaler()
48 |
49 | def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
50 | self._scaler.scale(loss).backward(create_graph=create_graph)
51 | if update_grad:
52 | if clip_grad is not None:
53 | assert parameters is not None
54 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place
55 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
56 | else:
57 | self._scaler.unscale_(optimizer)
58 | norm = get_grad_norm_(parameters)
59 | self._scaler.step(optimizer)
60 | self._scaler.update()
61 | else:
62 | norm = None
63 | return norm
64 |
65 | def state_dict(self):
66 | return self._scaler.state_dict()
67 |
68 | def load_state_dict(self, state_dict):
69 | self._scaler.load_state_dict(state_dict)
--------------------------------------------------------------------------------
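Note that `NativeScalerWithGradNormCount` calls `get_grad_norm_`, which is neither defined nor imported in this file; it is presumably provided elsewhere in the repo (e.g., `utils/utils.py`, not shown here). A minimal sketch of the MAE-style helper it is assumed to be:

```python
import torch

def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
    # Total gradient norm over the given parameters (sketch; the repo's own version may differ).
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = [p for p in parameters if p.grad is not None]
    if len(parameters) == 0:
        return torch.tensor(0.)
    device = parameters[0].grad.device
    if norm_type == float('inf'):
        return max(p.grad.detach().abs().max().to(device) for p in parameters)
    return torch.norm(
        torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]),
        norm_type)
```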
/README.md:
--------------------------------------------------------------------------------
1 | # Multi-Mode Online Knowledge Distillation for Self-Supervised Visual Representation Learning
2 | Official Implementation of our paper "**Multi-Mode Online Knowledge Distillation for Self-Supervised Visual Representation Learning**", in **CVPR 2023**.
3 |
4 | by Kaiyou Song, Jin Xie, Shan Zhang and Zimeng Luo.
5 |
6 | **[[arXiv]](https://arxiv.org/abs/2304.06461)** **[[Paper]](https://arxiv.org/pdf/2304.06461.pdf)**
7 |
8 | ## Method
9 |
10 | ![MOKD method overview](src/method.png)
11 |
12 | ## Usage
13 |
14 | ### ImageNet Pre-training
15 |
16 | This implementation only supports **DistributedDataParallel** training; single-GPU and DataParallel training are not supported.
17 | 
18 | To pre-train a ResNet50-ViT-S model pair on ImageNet across 16 GPUs (2 nodes with 8 GPUs each), run:
19 |
20 | ```
21 | python3 -m torch.distributed.launch --nproc_per_node=8 \
22 | --nnodes 2 --node_rank 0 --master_addr='100.123.45.67' --master_port='10001' \
23 | main_mokd.py \
24 | --arch_cnn resnet50 --arch_vit vit_small \
25 | --out_dim 65536 --norm_last_layer False \
26 | --clip_grad_cnn 3 --clip_grad_vit 3 --freeze_last_layer 1 \
27 | --optimizer sgd \
28 | --lr_cnn 0.1 --lr_vit 0.0003 --warmup_epochs 10 \
29 | --use_fp16 True \
30 | --warmup_teacher_temp 0.04 --teacher_temp 0.07 \
31 | --warmup_teacher_temp_epochs_cnn 50 --warmup_teacher_temp_epochs_vit 30 \
32 | --patch_size 16 --drop_path_rate 0.1 \
33 | --local_crops_number 8 --global_crops_scale 0.25 1 --local_crops_scale 0.05 0.25 \
34 | --momentum_teacher 0.996 \
35 | --num_workers 10 \
36 | --batch_size_per_gpu 16 --epochs 100 \
37 | --lamda_t 0.1 --lamda_c 1.0 \
38 | --data_path /path to imagenet/ \
39 | --output_dir output/ \
40 | ```
41 |
42 | ### ImageNet Linear Classification
43 |
44 | With a pre-trained model, to train supervised linear classifiers on frozen features on an 8-GPU machine, run the following commands for the ResNet-50 and ViT-S branches, respectively:
45 |
46 | ```
47 | python3 -m torch.distributed.launch --nproc_per_node=8 eval_linear.py \
48 | --arch resnet50 \
49 | --lr 0.01 \
50 | --batch_size_per_gpu 256 \
51 | --num_workers 10 \
52 | --pretrained_weights /path to pretrained checkpoints/xxx.pth \
53 | --checkpoint_key teacher_cnn \
54 | --data_path /path to imagenet/ \
55 | --output_dir output/ \
56 | --method mokd
57 | ```
58 |
59 | ```
60 | python3 -m torch.distributed.launch --nproc_per_node=8 eval_linear.py \
61 | --arch vit_small \
62 | --n_last_blocks 4 \
63 | --lr 0.001 \
64 | --batch_size_per_gpu 256 \
65 | --pretrained_weights /path to pretrained checkpoints/xxx.pth \
66 | --checkpoint_key teacher_vit \
67 | --data_path /path to imagenet/ \
68 | --output_dir output/ \
69 | --method mokd
70 | ```
71 |
72 | ### Evaluation: k-NN classification on ImageNet
73 | To evaluate a k-NN classifier on the frozen features of a pre-trained model (here with 8 GPUs), run:
74 |
75 | ```
76 | python3 -m torch.distributed.launch --nproc_per_node=8 eval_knn.py \
77 | --arch resnet18 \
78 | --batch_size_per_gpu 512 \
79 | --pretrained_weights /path to pretrained checkpoints/xxx.pth \
80 | --checkpoint_key teacher_cnn \
81 | --num_workers 20 \
82 | --data_path /path to imagenet/ \
83 | --use_cuda True \
84 | --method mokd
85 | ```
86 |
87 | ## Acknowledgement
88 | This project is based on [DINO](https://github.com/facebookresearch/dino).
89 | Thanks for the wonderful work.
90 |
91 | ## License
92 | This project is under the Apache License 2.0 license. See [LICENSE](LICENSE) for details.
93 |
94 | ## Citation
95 | ```bibtex
96 | @InProceedings{Song_2023_CVPR,
97 | author = {Song, Kaiyou and Xie, Jin and Zhang, Shan and Luo, Zimeng},
98 | title = {Multi-Mode Online Knowledge Distillation for Self-Supervised Visual Representation Learning},
99 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
100 | month = {June},
101 | year = {2023},
102 | pages = {11848-11857}
103 | }
104 | ```
105 |
--------------------------------------------------------------------------------
/utils/metrics.py:
--------------------------------------------------------------------------------
1 |
2 | import numpy as np
3 |
4 |
5 | def accuracy(output, target, topk=(1,)):
6 | """Computes the accuracy over the k top predictions for the specified values of k"""
7 | maxk = max(topk)
8 | batch_size = target.size(0)
9 | _, pred = output.topk(maxk, 1, True, True)
10 | pred = pred.t()
11 | correct = pred.eq(target.reshape(1, -1).expand_as(pred))
12 | return [correct[:k].reshape(-1).float().sum(0) * 100. / batch_size for k in topk]
13 |
14 |
15 | def compute_ap(ranks, nres):
16 | """
17 | Computes average precision for given ranked indexes.
18 | Arguments
19 | ---------
20 |     ranks : zero-based ranks of positive images
21 | nres : number of positive images
22 | Returns
23 | -------
24 | ap : average precision
25 | """
26 |
27 | # number of images ranked by the system
28 | nimgranks = len(ranks)
29 |
30 | # accumulate trapezoids in PR-plot
31 | ap = 0
32 |
33 | recall_step = 1. / nres
34 |
35 | for j in np.arange(nimgranks):
36 | rank = ranks[j]
37 |
38 | if rank == 0:
39 | precision_0 = 1.
40 | else:
41 | precision_0 = float(j) / rank
42 |
43 | precision_1 = float(j + 1) / (rank + 1)
44 |
45 | ap += (precision_0 + precision_1) * recall_step / 2.
46 |
47 | return ap
48 |
49 |
50 | def compute_map(ranks, gnd, kappas=[]):
51 | """
52 | Computes the mAP for a given set of returned results.
53 | Usage:
54 | map = compute_map (ranks, gnd)
55 |          computes mean average precision (map) only
56 | map, aps, pr, prs = compute_map (ranks, gnd, kappas)
57 | computes mean average precision (map), average precision (aps) for each query
58 | computes mean precision at kappas (pr), precision at kappas (prs) for each query
59 | Notes:
60 | 1) ranks starts from 0, ranks.shape = db_size X #queries
61 |        2) The junk results (e.g., the query itself) should be declared in the gnd struct array
62 | 3) If there are no positive images for some query, that query is excluded from the evaluation
63 | """
64 |
65 | map = 0.
66 | nq = len(gnd) # number of queries
67 | aps = np.zeros(nq)
68 | pr = np.zeros(len(kappas))
69 | prs = np.zeros((nq, len(kappas)))
70 | nempty = 0
71 |
72 | for i in np.arange(nq):
73 | qgnd = np.array(gnd[i]['ok'])
74 |
75 | # no positive images, skip from the average
76 | if qgnd.shape[0] == 0:
77 | aps[i] = float('nan')
78 | prs[i, :] = float('nan')
79 | nempty += 1
80 | continue
81 |
82 | try:
83 | qgndj = np.array(gnd[i]['junk'])
84 | except:
85 | qgndj = np.empty(0)
86 |
87 | # sorted positions of positive and junk images (0 based)
88 | pos = np.arange(ranks.shape[0])[np.in1d(ranks[:,i], qgnd)]
89 | junk = np.arange(ranks.shape[0])[np.in1d(ranks[:,i], qgndj)]
90 |
91 |         k = 0
92 |         ij = 0
93 | if len(junk):
94 | # decrease positions of positives based on the number of
95 | # junk images appearing before them
96 | ip = 0
97 | while (ip < len(pos)):
98 | while (ij < len(junk) and pos[ip] > junk[ij]):
99 | k += 1
100 | ij += 1
101 | pos[ip] = pos[ip] - k
102 | ip += 1
103 |
104 | # compute ap
105 | ap = compute_ap(pos, len(qgnd))
106 | map = map + ap
107 | aps[i] = ap
108 |
109 | # compute precision @ k
110 | pos += 1 # get it to 1-based
111 | for j in np.arange(len(kappas)):
112 |             kq = min(max(pos), kappas[j])
113 | prs[i, j] = (pos <= kq).sum() / kq
114 | pr = pr + prs[i, :]
115 |
116 | map = map / (nq - nempty)
117 | pr = pr / (nq - nempty)
118 |
119 | return map, aps, pr, prs
--------------------------------------------------------------------------------
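A toy retrieval example for `compute_map` with made-up data, just to show the expected shapes: `ranks` is (database size x number of queries) holding zero-based database indices sorted from most to least similar, and `gnd[i]` lists the positive (`'ok'`) and junk indices for query `i`:

```python
import numpy as np
from utils.metrics import compute_map

ranks = np.array([[0, 3],
                  [2, 1],
                  [1, 4],
                  [4, 0],
                  [3, 2]])          # 5 database images, 2 queries
gnd = [{'ok': np.array([0, 1]), 'junk': np.array([])},
       {'ok': np.array([3]),    'junk': np.array([4])}]

mAP, aps, pr, prs = compute_map(ranks, gnd, kappas=[1, 5])
```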
/utils/checkpoint_io.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 |
4 |
5 | def load_pretrained_weights(model, pretrained_weights, checkpoint_key, method):
6 | if os.path.isfile(pretrained_weights):
7 | state_dict = torch.load(pretrained_weights, map_location="cpu")
8 | if checkpoint_key is not None and checkpoint_key in state_dict:
9 | print(f"Take key {checkpoint_key} in provided checkpoint dict")
10 | print("epoch:", state_dict["epoch"])
11 | state_dict = state_dict[checkpoint_key]
12 |
13 |         if "mokd" in method:
14 | # remove `module.` prefix
15 | state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
16 | # remove `backbone.` prefix induced by multicrop wrapper
17 | state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
18 |
19 | msg = model.load_state_dict(state_dict, strict=False)
20 | print('Pretrained weights found at {} and loaded with msg: {}'.format(pretrained_weights, msg))
21 | return True
22 | else:
23 | print("There is no reference weights available for this model => Random weights will be used.")
24 | return False
25 |
26 |
27 | def load_state_dict(model, state_dict, prefix='', ignore_missing="relative_position_index"):
28 | missing_keys = []
29 | unexpected_keys = []
30 | error_msgs = []
31 | # copy state_dict so _load_from_state_dict can modify it
32 | metadata = getattr(state_dict, '_metadata', None)
33 | state_dict = state_dict.copy()
34 | if metadata is not None:
35 | state_dict._metadata = metadata
36 |
37 | def load(module, prefix=''):
38 | local_metadata = {} if metadata is None else metadata.get(
39 | prefix[:-1], {})
40 | module._load_from_state_dict(
41 | state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
42 | for name, child in module._modules.items():
43 | if child is not None:
44 | load(child, prefix + name + '.')
45 |
46 | load(model, prefix=prefix)
47 |
48 |
49 | def load_pretrained_linear_weights(linear_classifier, ckpt_path, method):
50 | if not os.path.isfile(ckpt_path):
51 |         print("Cannot find checkpoint at {}. Using random linear weights.".format(ckpt_path))
52 | return False
53 | print("Found checkpoint at {}".format(ckpt_path))
54 | # open checkpoint file
55 | checkpoint = torch.load(ckpt_path, map_location="cpu")
56 | state_dict = checkpoint['state_dict']
57 | linear_classifier.load_state_dict(state_dict, strict=True)
58 | return True
59 |
60 |
61 | def restart_from_checkpoint(ckpt_path, run_variables=None, **kwargs):
62 | """
63 | Re-start from checkpoint
64 | """
65 | if not os.path.isfile(ckpt_path):
66 |         print("Checkpoint not found at {}, training from random initialization".format(ckpt_path))
67 | return
68 | print("Found checkpoint at {}".format(ckpt_path))
69 |
70 | # open checkpoint file
71 | checkpoint = torch.load(ckpt_path, map_location="cpu")
72 |
73 | # key is what to look for in the checkpoint file
74 | # value is the object to load
75 | # example: {'state_dict': model}
76 | for key, value in kwargs.items():
77 | if key in checkpoint and value is not None:
78 | try:
79 | msg = value.load_state_dict(checkpoint[key], strict=False)
80 | print("=> loaded '{}' from checkpoint '{}' with msg {}".format(key, ckpt_path, msg))
81 | except TypeError:
82 | try:
83 | msg = value.load_state_dict(checkpoint[key])
84 | print("=> loaded '{}' from checkpoint: '{}'".format(key, ckpt_path))
85 | except ValueError:
86 | print("=> failed to load '{}' from checkpoint: '{}'".format(key, ckpt_path))
87 | else:
88 | print("=> key '{}' not found in checkpoint: '{}'".format(key, ckpt_path))
89 |
90 |     # reload variables important for the run
91 | if run_variables is not None:
92 | for var_name in run_variables:
93 | if var_name in checkpoint:
94 | run_variables[var_name] = checkpoint[var_name]
--------------------------------------------------------------------------------
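A sketch of how `restart_from_checkpoint` is typically used to resume training. The checkpoint path and the key names `student`, `teacher` and `optimizer` are assumptions following the DINO convention; the actual keys depend on how `main_mokd.py` saves its checkpoints:

```python
from utils import checkpoint_io

to_restore = {"epoch": 0}
checkpoint_io.restart_from_checkpoint(
    "output/checkpoint.pth",      # placeholder path
    run_variables=to_restore,
    student=student,              # existing objects exposing load_state_dict()
    teacher=teacher,
    optimizer=optimizer,
)
start_epoch = to_restore["epoch"]
```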
/backbones/resnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor
3 | import torch.nn as nn
4 | from typing import Type, Any, Callable, Union, List, Optional
5 | from torchvision import models as torchvision_models
6 | from torchvision.models.resnet import BasicBlock, Bottleneck, model_urls
7 | from torch.hub import load_state_dict_from_url
8 |
9 |
10 | class ResNet(torchvision_models.ResNet):
11 | def __init__(self, block: Type[Union[BasicBlock, Bottleneck]],
12 | layers: List[int], **kwargs):
13 | super(ResNet, self).__init__(block, layers, **kwargs)
14 |
15 | def _forward_impl(self, x: Tensor) -> Tensor:
16 | # See note [TorchScript super()]
17 | x = self.conv1(x)
18 | x = self.bn1(x)
19 | x = self.relu(x)
20 | x = self.maxpool(x)
21 |
22 | x = self.layer1(x)
23 | x = self.layer2(x)
24 | x = self.layer3(x)
25 | x = self.layer4(x)
26 |
27 | # sky
28 | f4 = x
29 |
30 | x = self.avgpool(x)
31 | x = torch.flatten(x, 1)
32 | x = self.fc(x)
33 |
34 | return x, f4
35 |
36 | def _resnet(
37 | arch: str,
38 | block: Type[Union[BasicBlock, Bottleneck]],
39 | layers: List[int],
40 | pretrained: bool,
41 | progress: bool,
42 | **kwargs: Any
43 | ) -> ResNet:
44 | model = ResNet(block, layers, **kwargs)
45 | if pretrained:
46 | state_dict = load_state_dict_from_url(model_urls[arch],
47 | progress=progress)
48 | model.load_state_dict(state_dict)
49 | return model
50 |
51 |
52 | def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
53 | r"""ResNet-18 model from
54 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
55 |
56 | Args:
57 | pretrained (bool): If True, returns a model pre-trained on ImageNet
58 | progress (bool): If True, displays a progress bar of the download to stderr
59 | """
60 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
61 | **kwargs)
62 |
63 |
64 | def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
65 | r"""ResNet-34 model from
66 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
67 |
68 | Args:
69 | pretrained (bool): If True, returns a model pre-trained on ImageNet
70 | progress (bool): If True, displays a progress bar of the download to stderr
71 | """
72 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
73 | **kwargs)
74 |
75 |
76 | def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
77 | r"""ResNet-50 model from
78 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
79 |
80 | Args:
81 | pretrained (bool): If True, returns a model pre-trained on ImageNet
82 | progress (bool): If True, displays a progress bar of the download to stderr
83 | """
84 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
85 | **kwargs)
86 |
87 |
88 | def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
89 | r"""ResNet-101 model from
90 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
91 |
92 | Args:
93 | pretrained (bool): If True, returns a model pre-trained on ImageNet
94 | progress (bool): If True, displays a progress bar of the download to stderr
95 | """
96 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
97 | **kwargs)
98 |
99 |
100 | def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
101 | r"""ResNet-152 model from
102 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
103 |
104 | Args:
105 | pretrained (bool): If True, returns a model pre-trained on ImageNet
106 | progress (bool): If True, displays a progress bar of the download to stderr
107 | """
108 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
109 | **kwargs)
110 |
111 |
112 | def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
113 | r"""ResNeXt-50 32x4d model from
114 |     `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_.
115 |
116 | Args:
117 | pretrained (bool): If True, returns a model pre-trained on ImageNet
118 | progress (bool): If True, displays a progress bar of the download to stderr
119 | """
120 | kwargs['groups'] = 32
121 | kwargs['width_per_group'] = 4
122 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
123 | pretrained, progress, **kwargs)
124 |
125 |
126 | def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
127 | r"""ResNeXt-101 32x8d model from
128 |     `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_.
129 |
130 | Args:
131 | pretrained (bool): If True, returns a model pre-trained on ImageNet
132 | progress (bool): If True, displays a progress bar of the download to stderr
133 | """
134 | kwargs['groups'] = 32
135 | kwargs['width_per_group'] = 8
136 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
137 | pretrained, progress, **kwargs)
138 |
139 |
140 | def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
141 | r"""Wide ResNet-50-2 model from
142 |     `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_.
143 |
144 | The model is the same as ResNet except for the bottleneck number of channels
145 | which is twice larger in every block. The number of channels in outer 1x1
146 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
147 | channels, and in Wide ResNet-50-2 has 2048-1024-2048.
148 |
149 | Args:
150 | pretrained (bool): If True, returns a model pre-trained on ImageNet
151 | progress (bool): If True, displays a progress bar of the download to stderr
152 | """
153 | kwargs['width_per_group'] = 64 * 2
154 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
155 | pretrained, progress, **kwargs)
156 |
157 |
158 | def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
159 | r"""Wide ResNet-101-2 model from
160 |     `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_.
161 |
162 | The model is the same as ResNet except for the bottleneck number of channels
163 | which is twice larger in every block. The number of channels in outer 1x1
164 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
165 | channels, and in Wide ResNet-50-2 has 2048-1024-2048.
166 |
167 | Args:
168 | pretrained (bool): If True, returns a model pre-trained on ImageNet
169 | progress (bool): If True, displays a progress bar of the download to stderr
170 | """
171 | kwargs['width_per_group'] = 64 * 2
172 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
173 | pretrained, progress, **kwargs)
174 |
--------------------------------------------------------------------------------
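Unlike torchvision's ResNet, `_forward_impl` above returns both the pooled embedding and the stage-4 feature map that MOKD's heads consume. A quick illustrative check:

```python
import torch
from backbones.resnet import resnet50

model = resnet50()
model.fc = torch.nn.Identity()                 # drop the classification layer, as done in eval_knn.py
emb, f4 = model(torch.randn(2, 3, 224, 224))
print(emb.shape)   # torch.Size([2, 2048])        pooled features
print(f4.shape)    # torch.Size([2, 2048, 7, 7])  last-stage feature map
```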
/models/mokd.py:
--------------------------------------------------------------------------------
1 | 
2 | import torch
3 | import torch.nn as nn
4 | import torch.distributed as dist
5 | import torch.nn.functional as F
6 | import numpy as np
7 |
8 | from utils.utils import trunc_normal_
9 |
10 |
11 | class DINOHead(nn.Module):
12 | def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, nlayers=3, hidden_dim=2048, bottleneck_dim=256):
13 | super().__init__()
14 | nlayers = max(nlayers, 1)
15 | if nlayers == 1:
16 | self.mlp = nn.Linear(in_dim, bottleneck_dim)
17 | else:
18 | layers = [nn.Linear(in_dim, hidden_dim)]
19 | if use_bn:
20 | layers.append(nn.BatchNorm1d(hidden_dim))
21 | layers.append(nn.GELU())
22 | for _ in range(nlayers - 2):
23 | layers.append(nn.Linear(hidden_dim, hidden_dim))
24 | if use_bn:
25 | layers.append(nn.BatchNorm1d(hidden_dim))
26 | layers.append(nn.GELU())
27 | layers.append(nn.Linear(hidden_dim, bottleneck_dim))
28 | self.mlp = nn.Sequential(*layers)
29 | self.apply(self._init_weights)
30 | self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
31 | self.last_layer.weight_g.data.fill_(1)
32 | if norm_last_layer:
33 | self.last_layer.weight_g.requires_grad = False
34 |
35 | def _init_weights(self, m):
36 | if isinstance(m, nn.Linear):
37 | trunc_normal_(m.weight, std=.02)
38 | if isinstance(m, nn.Linear) and m.bias is not None:
39 | nn.init.constant_(m.bias, 0)
40 |
41 | def forward(self, x):
42 | x = self.mlp(x)
43 | x = nn.functional.normalize(x, dim=-1, p=2)
44 | x = self.last_layer(x)
45 | return x
46 |
47 |
48 | class CNNStudentWrapper(nn.Module):
49 | def __init__(self, backbone, head, transhead):
50 | super(CNNStudentWrapper, self).__init__()
51 | # disable layers dedicated to ImageNet labels classification
52 | backbone.fc, backbone.head = nn.Identity(), nn.Identity()
53 | self.backbone = backbone
54 | self.head = head
55 | self.transhead = transhead
56 |
57 | def forward(self, x):
58 | fea, fea4 = self.backbone(torch.cat(x[:]))
59 |
60 | fea_trans, fea_reduce_dim = self.transhead(fea4)
61 |
62 | return self.head(fea), fea_trans, fea_reduce_dim
63 |
64 |
65 | class CNNTeacherWrapper(nn.Module):
66 | def __init__(self, backbone, head, transhead, num_crops):
67 | super(CNNTeacherWrapper, self).__init__()
68 | # disable layers dedicated to ImageNet labels classification
69 | backbone.fc, backbone.head = nn.Identity(), nn.Identity()
70 | self.backbone = backbone
71 | self.head = head
72 | self.transhead = transhead
73 | self.num_crops = num_crops
74 |
75 |
76 | def forward(self, x, local_token=None, global_token=None):
77 | fea, fea4 = self.backbone(torch.cat(x[:]))
78 |
79 | fea4_trans, _ = self.transhead(fea4)
80 | local_search_fea = None
81 | global_search_fea = None
82 |         if local_token is not None:
83 |             local_search_fea, _ = self.transhead(fea4.chunk(2)[0].repeat(self.num_crops, 1, 1, 1), local_token)
84 |         if global_token is not None:
85 |             global_search_fea, _ = self.transhead(fea4, global_token)
86 | return self.head(fea), fea4_trans, local_search_fea, global_search_fea
87 |
88 |
89 | class ViTStudentWrapper(nn.Module):
90 | def __init__(self, backbone, head, transhead):
91 | super(ViTStudentWrapper, self).__init__()
92 | # disable layers dedicated to ImageNet labels classification
93 | backbone.fc, backbone.head = nn.Identity(), nn.Identity()
94 | self.backbone = backbone
95 | self.head = head
96 | self.transhead = transhead
97 |
98 | def forward(self, x):
99 | tokens = self.backbone(torch.cat(x[:]), return_all_tokens=True)
100 | if isinstance(tokens, tuple):
101 | tokens = tokens[0]
102 |
103 | return self.head(tokens[:, 0]), self.transhead(tokens[:, 1:]), tokens[:, 0]
104 |
105 |
106 | class ViTTeacherWrapper(nn.Module):
107 | def __init__(self, backbone, head, transhead, num_crops):
108 | super(ViTTeacherWrapper, self).__init__()
109 | # disable layers dedicated to ImageNet labels classification
110 | backbone.fc, backbone.head = nn.Identity(), nn.Identity()
111 | self.backbone = backbone
112 | self.head = head
113 | self.transhead = transhead
114 | self.num_crops = num_crops
115 |
116 | def forward(self, x, local_token=None, global_token=None):
117 | tokens = self.backbone(torch.cat(x[:]), return_all_tokens=True)
118 | if isinstance(tokens, tuple):
119 | tokens = tokens[0]
120 |
121 | cls_token = tokens[:, 0]
122 | pth_token = tokens[:, 1:] # B,N,C
123 |
124 | local_search_fea = None
125 | global_search_fea = None
126 |         if local_token is not None:
127 |             local_search_fea = self.transhead(pth_token.chunk(2)[0].repeat(self.num_crops, 1, 1), local_token)
128 |         if global_token is not None:
129 |             global_search_fea = self.transhead(pth_token, global_token)
130 | return self.head(cls_token), self.transhead(pth_token), local_search_fea, global_search_fea
131 |
132 |
133 | class DINOLoss(nn.Module):
134 | def __init__(self, out_dim, ncrops, warmup_teacher_temp, teacher_temp,
135 | warmup_teacher_temp_epochs, nepochs, student_temp=0.1,
136 | center_momentum=0.9):
137 | super().__init__()
138 | self.student_temp = student_temp
139 | self.center_momentum = center_momentum
140 | self.ncrops = ncrops
141 | self.register_buffer("center", torch.zeros(1, out_dim))
142 |         # we apply a warm-up to the teacher temperature because too high a
143 |         # temperature makes training unstable at the beginning
144 | self.teacher_temp_schedule = np.concatenate((
145 | np.linspace(warmup_teacher_temp,
146 | teacher_temp, warmup_teacher_temp_epochs),
147 | np.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp
148 | ))
149 |
150 | def forward(self, student_output, teacher_output, epoch):
151 | """
152 | Cross-entropy between softmax outputs of the teacher and student networks.
153 | """
154 | student_out = student_output / self.student_temp
155 | student_out = student_out.chunk(self.ncrops)
156 |
157 | # teacher centering and sharpening
158 | temp = self.teacher_temp_schedule[epoch]
159 | teacher_out = F.softmax((teacher_output - self.center) / temp, dim=-1)
160 | teacher_out = teacher_out.detach().chunk(2)
161 |
162 | total_loss = 0
163 | n_loss_terms = 0
164 | for iq, q in enumerate(teacher_out):
165 | for v in range(len(student_out)):
166 | if v == iq:
167 | # we skip cases where student and teacher operate on the same view
168 | continue
169 | loss = torch.sum(-q * F.log_softmax(student_out[v], dim=-1), dim=-1)
170 | total_loss += loss.mean()
171 | n_loss_terms += 1
172 | total_loss /= n_loss_terms
173 | self.update_center(teacher_output)
174 | return total_loss
175 |
176 | @torch.no_grad()
177 | def update_center(self, teacher_output):
178 | """
179 | Update center used for teacher output.
180 | """
181 | batch_center = torch.sum(teacher_output, dim=0, keepdim=True)
182 | dist.all_reduce(batch_center)
183 | batch_center = batch_center / (len(teacher_output) * dist.get_world_size())
184 |
185 | # ema update
186 | self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum)
187 |
188 |
189 | class CTSearchLoss(nn.Module):
190 | def __init__(self, warmup_teacher_temp, teacher_temp,
191 | warmup_teacher_temp_epochs, nepochs, student_temp=0.1):
192 | super().__init__()
193 | self.student_temp = student_temp
194 | self.teacher_temp_schedule = np.concatenate((
195 | np.linspace(warmup_teacher_temp,
196 | teacher_temp, warmup_teacher_temp_epochs),
197 | np.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp
198 | ))
199 |
200 | def forward(self, student, teacher, epoch):
201 |
202 | temp = self.teacher_temp_schedule[epoch]
203 |
204 | student_out = student / self.student_temp
205 | teacher_out = F.softmax(teacher / temp, dim=-1)
206 | loss = torch.sum(-teacher_out * F.log_softmax(student_out, dim=-1), dim=-1).mean()
207 |
208 | return loss
209 |
210 |
--------------------------------------------------------------------------------
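`DINOLoss.update_center` all-reduces the batch center, so the loss only runs with `torch.distributed` initialized; a single-process `gloo` group is enough to exercise it. The dimensions below are placeholders:

```python
import os
import torch
import torch.distributed as dist
from models.mokd import DINOLoss

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

ncrops, batch, out_dim = 10, 4, 65536                 # 2 global + 8 local views per image
loss_fn = DINOLoss(out_dim, ncrops,
                   warmup_teacher_temp=0.04, teacher_temp=0.07,
                   warmup_teacher_temp_epochs=30, nepochs=100)
student_out = torch.randn(ncrops * batch, out_dim)    # student sees every crop
teacher_out = torch.randn(2 * batch, out_dim)         # teacher sees only the global crops
print(loss_fn(student_out, teacher_out, epoch=0))
```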
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright 2023 Kaiyou Song
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/eval_knn.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | import os
3 | import sys
4 | import argparse
5 | import numpy as np
6 |
7 | import torch
8 | from torch import nn
9 | import torch.distributed as dist
10 | import torch.backends.cudnn as cudnn
11 | from torchvision import datasets
12 | from torchvision import transforms as pth_transforms
13 | from torchvision import models as torchvision_models
14 | import timm.models as timm_models
15 |
16 | from utils import utils
17 | from utils import checkpoint_io
18 | import backbones.vision_transformer as vits
19 |
20 |
21 | def extract_feature_pipeline(args):
22 | ######################## preparing data ... ########################
23 | resize_size = 256 if args.input_size == 224 else 512
24 | transform = pth_transforms.Compose([
25 | pth_transforms.Resize(resize_size, interpolation=3),
26 | pth_transforms.CenterCrop(args.input_size),
27 | pth_transforms.ToTensor(),
28 | pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
29 | ])
30 |
31 | dataset_train = ReturnIndexDataset(os.path.join(args.data_path, "train"), transform=transform)
32 | dataset_val = ReturnIndexDataset(os.path.join(args.data_path, "val"), transform=transform)
33 |
34 |     train_labels = torch.tensor(dataset_train.targets).long()  # ImageFolder exposes `targets`
35 |     test_labels = torch.tensor(dataset_val.targets).long()
36 |
37 | sampler = torch.utils.data.DistributedSampler(dataset_train, shuffle=False)
38 | data_loader_train = torch.utils.data.DataLoader(
39 | dataset_train,
40 | sampler=sampler,
41 | batch_size=args.batch_size_per_gpu,
42 | num_workers=args.num_workers,
43 | pin_memory=False,
44 | drop_last=False,
45 | )
46 | data_loader_val = torch.utils.data.DataLoader(
47 | dataset_val,
48 | batch_size=args.batch_size_per_gpu,
49 | num_workers=args.num_workers,
50 | pin_memory=False,
51 | drop_last=False,
52 | )
53 | print(f"Data loaded with {len(dataset_train)} train and {len(dataset_val)} val imgs.")
54 |
55 | ######################## building network ... ########################
56 | if "vit" in args.arch:
57 | model = vits.__dict__[args.arch](patch_size=args.patch_size, num_classes=0)
58 | # model = timm_models.__dict__[args.arch](num_classes=0)
59 | print(f"Model {args.arch} {args.patch_size}x{args.patch_size} built.")
60 | elif args.arch in torchvision_models.__dict__.keys():
61 | model = torchvision_models.__dict__[args.arch](num_classes=0)
62 | model.fc = nn.Identity()
63 | else:
64 |         print(f"Architecture {args.arch} not supported")
65 | sys.exit(1)
66 | # print(model)
67 | model.cuda()
68 | is_load_success = checkpoint_io.load_pretrained_weights(model, args.pretrained_weights,
69 | args.checkpoint_key, args.method)
70 |     if not is_load_success:
71 | sys.exit(1)
72 |
73 | model.eval()
74 |
75 | ######################## extract features ... ########################
76 | print("Extracting features for train set...")
77 | train_features = extract_features(model, data_loader_train, args.arch, args.avgpool_patchtokens, args.use_cuda)
78 | print("Extracting features for val set...")
79 | test_features = extract_features(model, data_loader_val, args.arch, args.avgpool_patchtokens, args.use_cuda)
80 |
81 | if utils.get_rank() == 0:
82 | train_features = nn.functional.normalize(train_features, dim=1, p=2)
83 | test_features = nn.functional.normalize(test_features, dim=1, p=2)
84 |
85 | # save features and labels
86 | if args.dump_features and dist.get_rank() == 0:
87 | torch.save(train_features.cpu(), os.path.join(args.dump_features, "trainfeat.pth"))
88 | torch.save(test_features.cpu(), os.path.join(args.dump_features, "testfeat.pth"))
89 | torch.save(train_labels.cpu(), os.path.join(args.dump_features, "trainlabels.pth"))
90 | torch.save(test_labels.cpu(), os.path.join(args.dump_features, "testlabels.pth"))
91 | return train_features, test_features, train_labels, test_labels
92 |
93 |
94 | @torch.no_grad()
95 | def extract_features(model, data_loader, arch="resnet50", avgpool_patchtokens=1, use_cuda=True, multiscale=False):
96 | metric_logger = utils.MetricLogger(delimiter=" ")
97 | features = None
98 | for samples, index in metric_logger.log_every(data_loader, 10):
99 | samples = samples.cuda(non_blocking=True)
100 | index = index.cuda(non_blocking=True)
101 | if multiscale:
102 | feats = utils.multi_scale(samples, model)
103 | else:
104 | if "resnet" in arch:
105 | feats = model(samples).clone()
106 | else: # vits
107 | output = model(samples, return_all_tokens=True) # cl_mode: class_token patch_token
108 | if isinstance(output, tuple):
109 | output = output[0]
110 | output = output.clone()
111 | if avgpool_patchtokens == 0:
112 | # norm(x[:, 0])
113 | feats = output[:, 0].contiguous()
114 | elif avgpool_patchtokens == 1:
115 | # x[:, 1:].mean(1)
116 | feats = torch.mean(output[:, 1:], dim=1)
117 | elif avgpool_patchtokens == 2:
118 | # norm(x[:, 0]) + norm(x[:, 1:]).mean(1)
119 | feats = output[:, 0] + torch.mean(output[:, 1:], dim=1)
120 | else:
121 |                     assert False, "Unknown avgpool type {}".format(avgpool_patchtokens)
122 |
123 | # print(feats.shape)
124 | if len(feats.shape) != 2:
125 | feats = feats.squeeze()
126 |
127 | # init storage feature matrix
128 | if dist.get_rank() == 0 and features is None:
129 | features = torch.zeros(len(data_loader.dataset), feats.shape[-1])
130 | if use_cuda:
131 | features = features.cuda(non_blocking=True)
132 | print(f"Storing features into tensor of shape {features.shape}")
133 |
134 | # get indexes from all processes
135 | y_all = torch.empty(dist.get_world_size(), index.size(0), dtype=index.dtype, device=index.device)
136 | y_l = list(y_all.unbind(0))
137 | y_all_reduce = torch.distributed.all_gather(y_l, index, async_op=True)
138 | y_all_reduce.wait()
139 | index_all = torch.cat(y_l)
140 |
141 | # share features between processes
142 | feats_all = torch.empty(
143 | dist.get_world_size(),
144 | feats.size(0),
145 | feats.size(1),
146 | dtype=feats.dtype,
147 | device=feats.device,
148 | )
149 | output_l = list(feats_all.unbind(0))
150 | output_all_reduce = torch.distributed.all_gather(output_l, feats, async_op=True)
151 | output_all_reduce.wait()
152 |
153 | # update storage feature matrix
154 | if dist.get_rank() == 0:
155 | if use_cuda:
156 | features.index_copy_(0, index_all, torch.cat(output_l))
157 | else:
158 | features.index_copy_(0, index_all.cpu(), torch.cat(output_l).cpu())
159 | return features
160 |
161 |
162 | @torch.no_grad()
163 | def knn_classifier(train_features, train_labels, test_features, test_labels, k, T, num_classes=1000, use_cuda=True):
164 | top1, top5, total = 0.0, 0.0, 0
165 | train_features = train_features.t()
166 | num_test_images, num_chunks = test_labels.shape[0], 100
167 | imgs_per_chunk = num_test_images // num_chunks
168 | retrieval_one_hot = torch.zeros(k, num_classes)
169 | if use_cuda:
170 | retrieval_one_hot = retrieval_one_hot.cuda()
171 | for idx in range(0, num_test_images, imgs_per_chunk):
172 | # get the features for test images
173 | features = test_features[
174 | idx : min((idx + imgs_per_chunk), num_test_images), :
175 | ]
176 | targets = test_labels[idx : min((idx + imgs_per_chunk), num_test_images)]
177 | batch_size = targets.shape[0]
178 |
179 | # calculate the dot product and compute top-k neighbors
180 | similarity = torch.mm(features, train_features)
181 | distances, indices = similarity.topk(k, largest=True, sorted=True)
182 | candidates = train_labels.view(1, -1).expand(batch_size, -1) #500x1281167
183 | retrieved_neighbors = torch.gather(candidates, 1, indices) # 500x10
184 | retrieval_one_hot.resize_(batch_size * k, num_classes).zero_() #5000x0
185 | retrieval_one_hot.scatter_(1, retrieved_neighbors.view(-1, 1), 1)
186 | distances_transform = distances.clone().div_(T).exp_()
187 | probs = torch.sum(
188 | torch.mul(
189 | retrieval_one_hot.view(batch_size, -1, num_classes),
190 | distances_transform.view(batch_size, -1, 1),
191 | ),
192 | 1,
193 | )
194 | _, predictions = probs.sort(1, True)
195 |
196 | # find the predictions that match the target
197 | correct = predictions.eq(targets.data.view(-1, 1))
198 | top1 = top1 + correct.narrow(1, 0, 1).sum().item()
199 | top5 = top5 + correct.narrow(1, 0, min(5, k)).sum().item() # top5 does not make sense if k < 5
200 | total += targets.size(0)
201 | top1 = top1 * 100.0 / total
202 | top5 = top5 * 100.0 / total
203 | return top1, top5
204 |
205 |
206 | # ReturnIndexDataset_nori depends on an internal `ImagenetLoader` class that is not part of
207 | # this repository, so it is disabled here; use ReturnIndexDataset below instead.
208 | # class ReturnIndexDataset_nori(ImagenetLoader):
209 | #     def __getitem__(self, idx): return super().__getitem__(idx)[0], idx
210 |
211 | class ReturnIndexDataset(datasets.ImageFolder):
212 | def __getitem__(self, idx):
213 | img, lab = super(ReturnIndexDataset, self).__getitem__(idx)
214 | return img, idx
215 |
216 | def get_args_parser():
217 | parser = argparse.ArgumentParser("KNN Evaluation", add_help=False)
218 | parser.add_argument('--input_size', default=224, type=int, help='input image size')
219 | parser.add_argument('--batch_size_per_gpu', default=128, type=int, help='Per-GPU batch-size')
220 | parser.add_argument('--nb_knn', default=[20, 10, 30], nargs='+', type=int,
221 |         help='Number of nearest neighbors to use. 20 usually works best.')
222 | parser.add_argument('--num_labels', default=1000, type=int, help='Number of labels for linear classifier')
223 | parser.add_argument('--temperature', default=0.07, type=float,
224 | help='Temperature used in the voting coefficient')
225 | parser.add_argument('--pretrained_weights', default='', type=str, help="Path to pretrained weights to evaluate.")
226 | parser.add_argument('--use_cuda', default=False, type=utils.bool_flag,
227 | help="Should we store the features on GPU? We recommend setting this to False if you encounter OOM")
228 | parser.add_argument('--arch', default='vit_small', type=str, help='Architecture')
229 | parser.add_argument("--checkpoint_key", default="teacher", type=str,
230 | help='Key to use in the checkpoint (example: "teacher")')
231 | # for ViTs
232 | parser.add_argument('--patch_size', default=16, type=int, help='Patch resolution of the model.')
233 | parser.add_argument('--avgpool_patchtokens', default=0, choices=[0, 1, 2], type=int,
234 | help="""Whether or not to use global average pooled features or the [CLS] token.
235 | We typically set this to 1 for BEiT and 0 for models with [CLS] token (e.g., DINO).
236 | we set this to 2 for base-size models with [CLS] token when doing linear classification.""")
237 |
238 | parser.add_argument('--dump_features', default=None,
239 | help='Path where to save computed features, empty for no saving')
240 | parser.add_argument('--load_features', default=None, help="""If the features have
241 | already been computed, where to find them.""")
242 | parser.add_argument('--num_workers', default=10, type=int, help='Number of data loading workers per GPU.')
243 | parser.add_argument("--dist_url", default="env://", type=str, help="""url used to set up
244 | distributed training; see https://pytorch.org/docs/stable/distributed.html""")
245 | parser.add_argument("--local_rank", default=0, type=int, help="Please ignore and do not set this argument.")
246 | parser.add_argument('--data_path', default='/path/to/imagenet/', type=str)
247 | parser.add_argument('--method', default='moco', type=str, help='model name')
248 |
249 | return parser
250 |
251 |
252 | if __name__ == '__main__':
253 | parser = argparse.ArgumentParser("KNN Evaluation", parents=[get_args_parser()])
254 | args = parser.parse_args()
255 |
256 | utils.init_distributed_mode(args)
257 | print("git:\n {}\n".format(utils.get_sha()))
258 | print("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items())))
259 | cudnn.benchmark = True
260 |
261 | if args.load_features:
262 | train_features = torch.load(os.path.join(args.load_features, "trainfeat.pth"))
263 | test_features = torch.load(os.path.join(args.load_features, "testfeat.pth"))
264 | train_labels = torch.load(os.path.join(args.load_features, "trainlabels.pth"))
265 | test_labels = torch.load(os.path.join(args.load_features, "testlabels.pth"))
266 | else:
267 | # need to extract features !
268 | train_features, test_features, train_labels, test_labels = extract_feature_pipeline(args)
269 |
270 | if utils.get_rank() == 0:
271 | if args.use_cuda:
272 | train_features = train_features.cuda()
273 | test_features = test_features.cuda()
274 | train_labels = train_labels.cuda()
275 | test_labels = test_labels.cuda()
276 | else:
277 | train_features = train_features.cpu()
278 | test_features = test_features.cpu()
279 | train_labels = train_labels.cpu()
280 | test_labels = test_labels.cpu()
281 |
282 | print("Features are ready!\nStart the k-NN classification.")
283 | for k in args.nb_knn:
284 | top1, top5 = knn_classifier(train_features, train_labels,
285 | test_features, test_labels, k, args.temperature, args.num_labels, args.use_cuda)
286 | print(f"{k}-NN classifier result: Top1: {top1}, Top5: {top5}")
287 | dist.barrier()
288 |
--------------------------------------------------------------------------------
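Note: the k-NN evaluation above follows the DINO recipe of temperature-weighted voting: each test feature votes for the labels of its k nearest training features, with each vote weighted by exp(similarity / T). The snippet below is a minimal, self-contained sketch of that idea on synthetic, L2-normalized features; the function and variable names are illustrative and are not part of this repository.

import torch

def knn_vote(train_features, train_labels, test_features, k=20, T=0.07, num_classes=10):
    # features are assumed L2-normalized, so the dot product is the cosine similarity
    sim = test_features @ train_features.t()              # [num_test, num_train]
    topk_sim, topk_idx = sim.topk(k, dim=1)               # k nearest training samples
    topk_labels = train_labels[topk_idx]                  # [num_test, k]
    weights = (topk_sim / T).exp()                        # temperature-scaled vote weights
    votes = torch.zeros(test_features.size(0), num_classes)
    votes.scatter_add_(1, topk_labels, weights)           # accumulate weighted class votes
    return votes.argmax(dim=1)

train_feat = torch.nn.functional.normalize(torch.randn(500, 64), dim=1)
test_feat = torch.nn.functional.normalize(torch.randn(8, 64), dim=1)
train_lab = torch.randint(0, 10, (500,))
print(knn_vote(train_feat, train_lab, test_feat))          # predicted labels for the 8 test samples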
/eval_linear.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | import os
3 | import sys
4 | import argparse
5 | import json
6 | from pathlib import Path
7 | import time
8 | import builtins
9 | import random
10 |
11 | import torch
12 | from torch import nn
13 | import torch.distributed as dist
14 | import torch.multiprocessing as mp
15 | import torch.backends.cudnn as cudnn
16 | from torchvision import datasets
17 | from torchvision import transforms as transforms
18 | from torchvision import models as torchvision_models
19 | from torch.utils.tensorboard import SummaryWriter
20 |
21 | from utils import utils
22 | from utils import checkpoint_io
23 | from utils import optimizers
24 | from utils import metrics
25 | import backbones.vision_transformer as vits
26 |
27 | def main():
28 | parser = argparse.ArgumentParser("Linear Evaluation", parents=[get_args_parser()])
29 | args = parser.parse_args()
30 |
31 | Path(args.output_dir).mkdir(parents=True, exist_ok=True)
32 |
33 | utils.init_distributed_mode(args)
34 | print("git:\n {}\n".format(utils.get_sha()))
35 | print("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items())))
36 | cudnn.benchmark = True
37 |
38 | train(args)
39 |
40 |
41 | def get_args_parser():
42 | parser = argparse.ArgumentParser("Linear Evaluation", add_help=False)
43 |
44 | #################################
45 | #### input and output parameters ####
46 | #################################
47 | parser.add_argument('--data_path', default='/path/to/imagenet/', type=str)
48 | parser.add_argument('--pretrained_weights', default='', type=str,
49 | help="Path to pretrained weights to evaluate.")
50 | parser.add_argument("--checkpoint_key", default="teacher", type=str,
51 | help='Key to use in the checkpoint (example: "teacher")')
52 | parser.add_argument('--output_dir', default=".", help='Path to save logs and checkpoints')
53 |     parser.add_argument('--evaluate', dest='evaluate', type=str, default=None, help='path to trained linear classifier weights; if set, only evaluate on the validation set')
54 |     parser.add_argument('--method', default='moco', type=str, help='method name')
55 | parser.add_argument('--experiment', default='exp', type=str, help='experiment name')
56 |
57 | #################################
58 |     #### model parameters ####
59 | #################################
60 | parser.add_argument('--arch', default='resnet50', type=str, help='Architecture')
61 | parser.add_argument('--num_labels', default=1000, type=int, help='Number of labels for linear classifier')
62 | #for ViTs
63 | parser.add_argument('--n_last_blocks', default=4, type=int, help="""Concatenate [CLS] tokens
64 | for the `n` last blocks. Use `n=4` when evaluating ViT-Small and `n=1` with ViT-Base.""")
65 | parser.add_argument('--avgpool_patchtokens', default=0, type=int,
66 |         help="""Whether or not to concatenate the global average pooled features to the [CLS] token.""")
67 | parser.add_argument('--patch_size', default=16, type=int, help='Patch resolution of the model.')
68 |
69 | #################################
70 | #### optim parameters ###
71 | #################################
72 | parser.add_argument('--epochs', default=100, type=int, help='Number of epochs of training.')
73 | parser.add_argument('--optimizer', default='sgd', type=str,
74 | choices=['sgd', 'lars'], help="""Type of optimizer.""")
75 | parser.add_argument("--lr", default=0.001, type=float, help="""Learning rate at the beginning of
76 | training (highest LR used during training). The learning rate is linearly scaled
77 | with the batch size, and specified here for a reference batch size of 256.""")
78 | parser.add_argument('--batch_size_per_gpu', default=128, type=int, help='Per-GPU batch-size')
79 | parser.add_argument('--val_freq', default=5, type=int, help="Epoch frequency for validation.")
80 |
81 | #################################
82 | #### dist parameters ###
83 | #################################
84 | parser.add_argument("--dist_url", default="env://", type=str, help="""url used to set up distributed training""")
85 | parser.add_argument("--local_rank", default=0, type=int, help="Please ignore and do not set this argument.")
86 | parser.add_argument('--num_workers', default=10, type=int, help='Number of data loading workers per GPU.')
87 |
88 | return parser
89 |
90 |
91 | def train(args):
92 | ######################## building network ... ########################
93 | if args.arch in vits.__dict__.keys():
94 | model = vits.__dict__[args.arch](patch_size=args.patch_size, num_classes=0)
95 | embed_dim = model.embed_dim * (args.n_last_blocks + int(args.avgpool_patchtokens))
96 | # embed_dim = model.embed_dim * (args.n_last_blocks + int(args.avgpool_patchtokens))
97 | # otherwise, we check if the architecture is in torchvision models
98 | elif args.arch in torchvision_models.__dict__.keys():
99 | model = torchvision_models.__dict__[args.arch]()
100 | embed_dim = model.fc.weight.shape[1]
101 | model.fc = nn.Identity()
102 | else:
103 |         print(f"Unknown architecture: {args.arch}")
104 | sys.exit(1)
105 | model.cuda()
106 | model.eval()
107 |
108 | ######################## load pretrained weights to evaluate ########################
109 | is_load_success = checkpoint_io.load_pretrained_weights(model, args.pretrained_weights,
110 | args.checkpoint_key, args.method)
111 | if is_load_success:
112 | print(f"Model {args.arch} built.")
113 | else:
114 | sys.exit(1)
115 |
116 | linear_classifier = LinearClassifier(embed_dim, num_labels=args.num_labels)
117 | linear_classifier = linear_classifier.cuda()
118 | linear_classifier = nn.parallel.DistributedDataParallel(linear_classifier, device_ids=[args.gpu])
119 |
120 | ######################## preparing data ########################
121 | val_transform = transforms.Compose([
122 | transforms.Resize(256, interpolation=3),
123 | transforms.CenterCrop(224),
124 | transforms.ToTensor(),
125 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
126 | ])
127 |
128 | dataset_val = datasets.ImageFolder(os.path.join(args.data_path, "val"), transform=val_transform)
129 |
130 | val_loader = torch.utils.data.DataLoader(
131 | dataset_val,
132 | batch_size=args.batch_size_per_gpu,
133 | num_workers=args.num_workers,
134 | pin_memory=False,
135 | )
136 |
137 | ######################## just do evaluation ########################
138 |     if args.evaluate is not None:
139 | is_load_success = checkpoint_io.load_pretrained_linear_weights(linear_classifier, args.evaluate, args.method)
140 |         if not is_load_success:
141 | sys.exit(1)
142 | test_stats = validate(val_loader, model, linear_classifier, args)
143 | print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
144 | return
145 |
146 | ######################## preparing data ########################
147 | train_transform = transforms.Compose([
148 | transforms.RandomResizedCrop(224),
149 | transforms.RandomHorizontalFlip(),
150 | transforms.ToTensor(),
151 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
152 | ])
153 |
154 | dataset_train = datasets.ImageFolder(os.path.join(args.data_path, "train"), transform=train_transform)
155 |
156 | sampler = torch.utils.data.distributed.DistributedSampler(dataset_train)
157 | train_loader = torch.utils.data.DataLoader(
158 | dataset_train,
159 | sampler=sampler,
160 | batch_size=args.batch_size_per_gpu,
161 | num_workers=args.num_workers,
162 | pin_memory=False,
163 | )
164 | print(f"Data loaded with {len(dataset_train)} train and {len(dataset_val)} val imgs.")
165 |
166 | ######################## preparing optimizer ########################
167 | optimizer = torch.optim.SGD(
168 | linear_classifier.parameters(),
169 | args.lr * (args.batch_size_per_gpu * utils.get_world_size()) / 256., # linear scaling rule
170 | momentum=0.9,
171 | weight_decay=0,
172 | )
173 | if args.optimizer == "lars":
174 | optimizer = optimizers.LARS(linear_classifier.parameters(),
175 | lr=args.lr * (args.batch_size_per_gpu * utils.get_world_size()) / 256.,
176 | momentum=0.9,
177 | weight_decay=0,
178 | ) # to use with convnet and large batches
179 |
180 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs, eta_min=0)
181 |
182 | ######################## optionally resume training ########################
183 | to_restore = {"epoch": 0, "best_acc": 0.}
184 | checkpoint_io.restart_from_checkpoint(
185 | os.path.join(args.output_dir,
186 | "{}_{}_linear_{}.pth".format(args.method, args.arch, args.experiment)),
187 | run_variables=to_restore,
188 | state_dict=linear_classifier,
189 | optimizer=optimizer,
190 | scheduler=scheduler,
191 | )
192 | start_epoch = to_restore["epoch"]
193 | best_acc = to_restore["best_acc"]
194 |
195 | summary_writer = SummaryWriter(log_dir=os.path.join(args.output_dir,
196 | "tb", "{}_{}_linear_{}".format(args.method, args.arch, args.experiment))) if utils.is_main_process() else None
197 |
198 | ######################## start training ########################
199 | for epoch in range(start_epoch, args.epochs):
200 | train_loader.sampler.set_epoch(epoch)
201 |
202 | train_stats = train_one_epoch(model, linear_classifier, optimizer, train_loader, epoch, summary_writer, args)
203 | scheduler.step()
204 |
205 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 'epoch': epoch}
206 | if epoch % args.val_freq == 0 or epoch == args.epochs - 1:
207 | test_stats = validate(val_loader, model, linear_classifier, args)
208 |
209 | print(f"Accuracy at epoch {epoch} of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
210 | best_acc = max(best_acc, test_stats["acc1"])
211 |
212 | print(f'Max accuracy so far: {best_acc:.2f}%')
213 | log_stats = {**{k: v for k, v in log_stats.items()},
214 | **{f'test_{k}': v for k, v in test_stats.items()}}
215 |
216 | if utils.is_main_process():
217 | summary_writer.add_scalar("acc1", test_stats["acc1"], epoch)
218 | summary_writer.add_scalar("acc5", test_stats["acc5"], epoch)
219 |
220 | if utils.is_main_process():
221 | with (Path(args.output_dir) / "{}_{}_linear_{}_log.txt".format(args.method, args.arch, args.experiment)).open("a") as f:
222 | f.write(json.dumps(log_stats) + "\n")
223 | save_dict = {
224 | "epoch": epoch + 1,
225 | "state_dict": linear_classifier.state_dict(),
226 | "optimizer": optimizer.state_dict(),
227 | "scheduler": scheduler.state_dict(),
228 | "best_acc": best_acc,
229 | }
230 | torch.save(save_dict, os.path.join(args.output_dir,
231 | "{}_{}_linear_{}.pth".format(args.method, args.arch, args.experiment)))
232 |
233 | if utils.is_main_process():
234 | summary_writer.close()
235 | print("Training of the supervised linear classifier on frozen features completed.\n"
236 | "Top-1 test accuracy: {acc:.1f}".format(acc=best_acc))
237 | time.sleep(30)
238 |
239 |
240 | def train_one_epoch(model, linear_classifier, optimizer, loader, epoch, summary_writer, args):
241 | linear_classifier.train()
242 | metric_logger = utils.MetricLogger(delimiter=" ")
243 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
244 | header = 'Epoch: [{}/{}]'.format(epoch, args.epochs)
245 | iters_per_epoch = len(loader)
246 | for it, (inp, target) in enumerate(metric_logger.log_every(loader, 20, header)):
247 | # move to gpu
248 | inp = inp.cuda(non_blocking=True)
249 | target = target.cuda(non_blocking=True)
250 |
251 | # forward
252 | with torch.no_grad():
253 | if "vit" in args.arch:
254 | intermediate_output = model.get_intermediate_layers(inp, args.n_last_blocks) # take n last block out
255 | output = torch.cat([x[:, 0] for x in intermediate_output], dim=-1)
256 | if args.avgpool_patchtokens:
257 | output = torch.cat((output.unsqueeze(-1),
258 | torch.mean(intermediate_output[-1][:, 1:], dim=1).unsqueeze(-1)), dim=-1)
259 | output = output.reshape(output.shape[0], -1)
260 | else:
261 | output = model(inp)
262 | if len(output.shape) != 2:
263 | output = output.squeeze()
264 | output = linear_classifier(output)
265 |
266 | # compute cross entropy loss
267 | loss = nn.CrossEntropyLoss()(output, target)
268 |
269 | # compute the gradients
270 | optimizer.zero_grad()
271 | loss.backward()
272 |
273 | # step
274 | optimizer.step()
275 |
276 | # logging
277 | if utils.is_main_process():
278 | summary_writer.add_scalar("loss", loss.item(), epoch * iters_per_epoch + it)
279 | summary_writer.add_scalar("learning rate", optimizer.param_groups[0]["lr"],
280 | epoch * iters_per_epoch + it)
281 | torch.cuda.synchronize()
282 | metric_logger.update(loss=loss.item())
283 | metric_logger.update(lr=optimizer.param_groups[0]["lr"])
284 | # gather the stats from all processes
285 | metric_logger.synchronize_between_processes()
286 | print("Averaged stats:", metric_logger)
287 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
288 |
289 |
290 | @torch.no_grad()
291 | def validate(val_loader, model, linear_classifier, args):
292 | linear_classifier.eval()
293 | metric_logger = utils.MetricLogger(delimiter=" ")
294 | header = 'Test:'
295 | for inp, target in metric_logger.log_every(val_loader, 20, header):
296 | # move to gpu
297 | inp = inp.cuda(non_blocking=True)
298 | target = target.cuda(non_blocking=True)
299 |
300 | # forward
301 | with torch.no_grad():
302 | if "vit" in args.arch:
303 | intermediate_output = model.get_intermediate_layers(inp, args.n_last_blocks) # take n last block out
304 | output = torch.cat([x[:, 0] for x in intermediate_output], dim=-1)
305 | if args.avgpool_patchtokens:
306 | output = torch.cat((output.unsqueeze(-1),
307 | torch.mean(intermediate_output[-1][:, 1:], dim=1).unsqueeze(-1)), dim=-1)
308 | output = output.reshape(output.shape[0], -1)
309 | else:
310 | output = model(inp)
311 | output = linear_classifier(output)
312 | loss = nn.CrossEntropyLoss()(output, target)
313 |
314 | if linear_classifier.module.num_labels >= 5:
315 | acc1, acc5 = metrics.accuracy(output, target, topk=(1, 5))
316 | else:
317 | acc1, = metrics.accuracy(output, target, topk=(1,))
318 |
319 | batch_size = inp.shape[0]
320 | metric_logger.update(loss=loss.item())
321 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
322 | if linear_classifier.module.num_labels >= 5:
323 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
324 | if linear_classifier.module.num_labels >= 5:
325 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
326 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))
327 | else:
328 | print('* Acc@1 {top1.global_avg:.3f} loss {losses.global_avg:.3f}'
329 | .format(top1=metric_logger.acc1, losses=metric_logger.loss))
330 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
331 |
332 |
333 | class LinearClassifier(nn.Module):
334 | """Linear layer to train on top of frozen features"""
335 | def __init__(self, dim, num_labels=1000):
336 | super(LinearClassifier, self).__init__()
337 | self.num_labels = num_labels
338 | self.linear = nn.Linear(dim, num_labels)
339 | self.linear.weight.data.normal_(mean=0.0, std=0.01)
340 | self.linear.bias.data.zero_()
341 |
342 | def forward(self, x):
343 | # flatten
344 | x = x.view(x.size(0), -1)
345 |
346 | # linear layer
347 | return self.linear(x)
348 |
349 |
350 | if __name__ == '__main__':
351 | main()
352 |
--------------------------------------------------------------------------------
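Note: eval_linear.py above is a standard frozen-feature linear probe: features are extracted under torch.no_grad(), only a single linear layer is trained, and the base learning rate is scaled linearly with the total batch size (reference batch size 256). The sketch below shows the same structure with a synthetic stand-in backbone; the module and tensors are illustrative and not part of this repository.

import torch
from torch import nn

backbone = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 128)).eval()  # stand-in for a frozen encoder
for p in backbone.parameters():
    p.requires_grad = False

classifier = nn.Linear(128, 10)
base_lr, batch_size, world_size = 0.001, 128, 1
lr = base_lr * (batch_size * world_size) / 256.0       # linear scaling rule
optimizer = torch.optim.SGD(classifier.parameters(), lr=lr, momentum=0.9, weight_decay=0)

images = torch.randn(batch_size, 3, 32, 32)
targets = torch.randint(0, 10, (batch_size,))

with torch.no_grad():                                   # the backbone stays frozen
    feats = backbone(images)
loss = nn.CrossEntropyLoss()(classifier(feats), targets)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(f"probe loss: {loss.item():.4f}")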
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import time
4 | import math
5 | import datetime
6 | import argparse
7 | import subprocess
8 | import warnings
9 | from collections import defaultdict, deque
10 |
11 | import numpy as np
12 | import torch
13 | from torch import nn
14 | import torch.distributed as dist
15 |
16 |
17 | class Dict(dict):
18 | __setattr__ = dict.__setitem__
19 | __getattr__ = dict.__getitem__
20 |
21 |
22 | def DictToObj(dictObj):
23 | if not isinstance(dictObj, dict):
24 | return dictObj
25 | d = Dict()
26 | for k, v in dictObj.items():
27 | d[k] = DictToObj(v)
28 | return d
29 |
30 |
31 | def clip_gradients(model, clip):
32 | norms = []
33 | for name, p in model.named_parameters():
34 | if p.grad is not None:
35 | param_norm = p.grad.data.norm(2)
36 | norms.append(param_norm.item())
37 | clip_coef = clip / (param_norm + 1e-6)
38 | if clip_coef < 1:
39 | p.grad.data.mul_(clip_coef)
40 | return norms
41 |
42 |
43 | def cancel_gradients_last_layer(epoch, model, freeze_last_layer):
44 | if epoch >= freeze_last_layer:
45 | return
46 | for n, p in model.named_parameters():
47 | if "last_layer" in n:
48 | p.grad = None
49 |
50 |
51 | def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0):
52 | warmup_schedule = np.array([])
53 | warmup_iters = warmup_epochs * niter_per_ep
54 | if warmup_epochs > 0:
55 | warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
56 |
57 | iters = np.arange(epochs * niter_per_ep - warmup_iters)
58 | schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters)))
59 |
60 | schedule = np.concatenate((warmup_schedule, schedule))
61 | assert len(schedule) == epochs * niter_per_ep
62 | return schedule
63 |
64 | def piecewise_scheduler(start_value, end_value, epochs, niter_per_ep, warmup_epochs=100):
65 |
66 | warmup_iters = warmup_epochs * niter_per_ep
67 | warmup_schedule = np.ones_like(np.arange(warmup_iters)) * start_value
68 |
69 | schedule = np.ones_like(np.arange(epochs * niter_per_ep - warmup_iters)) * end_value
70 |
71 | schedule = np.concatenate((warmup_schedule, schedule))
72 | assert len(schedule) == epochs * niter_per_ep
73 | return schedule
74 |
75 | def adjust_learning_rate(optimizer, epoch, args):
76 | """Decay the learning rate based on schedule"""
77 | lr = args.lr
78 | if args.cos_learning_rate: # cosine lr schedule
79 | lr *= 0.5 * (1. + math.cos(math.pi * epoch / args.epochs))
80 | else: # stepwise lr schedule
81 | for milestone in args.schedule:
82 | lr *= 0.1 if epoch >= milestone else 1.
83 | for param_group in optimizer.param_groups:
84 | param_group['lr'] = lr
85 |
86 | def adjust_learning_rate_with_fix_lr(optimizer, epoch, args):
87 | """Decay the learning rate based on schedule"""
88 | lr = args.lr * 0.5 * (1. + math.cos(math.pi * epoch / args.epochs))
89 | for param_group in optimizer.param_groups:
90 | if 'fix_lr' in param_group and param_group['fix_lr']:
91 | param_group['lr'] = args.lr
92 | else:
93 | param_group['lr'] = lr
94 |
95 | def adjust_learning_rate_with_warmup(optimizer, epoch, args):
96 | """Decays the learning rate with half-cycle cosine after warmup"""
97 | if epoch < args.warmup_epochs:
98 | lr = args.lr * epoch / args.warmup_epochs
99 | else:
100 | lr = args.lr * 0.5 * (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs)))
101 | for param_group in optimizer.param_groups:
102 | param_group['lr'] = lr
103 | return lr
104 |
105 | def adjust_ema_momentum(epoch, args):
106 | """Adjust ema momentum based on current epoch"""
107 | m = 1. - 0.5 * (1. + math.cos(math.pi * epoch / args.epochs)) * (1. - args.ema_m)
108 | return m
109 |
110 | def bool_flag(s):
111 | """
112 | Parse boolean arguments from the command line.
113 | """
114 | FALSY_STRINGS = {"off", "false", "0"}
115 | TRUTHY_STRINGS = {"on", "true", "1"}
116 | if s.lower() in FALSY_STRINGS:
117 | return False
118 | elif s.lower() in TRUTHY_STRINGS:
119 | return True
120 | else:
121 | raise argparse.ArgumentTypeError("invalid value for a boolean flag")
122 |
123 |
124 | def fix_random_seeds(seed=31):
125 | """
126 | Fix random seeds.
127 | """
128 | torch.manual_seed(seed)
129 | torch.cuda.manual_seed_all(seed)
130 | np.random.seed(seed)
131 |
132 |
133 | class SmoothedValue(object):
134 | """Track a series of values and provide access to smoothed values over a
135 | window or the global series average.
136 | """
137 |
138 | def __init__(self, window_size=20, fmt=None):
139 | if fmt is None:
140 | fmt = "{median:.6f} ({global_avg:.6f})"
141 | self.deque = deque(maxlen=window_size)
142 | self.total = 0.0
143 | self.count = 0
144 | self.fmt = fmt
145 |
146 | def update(self, value, n=1):
147 | self.deque.append(value)
148 | self.count += n
149 | self.total += value * n
150 |
151 | def synchronize_between_processes(self):
152 | """
153 | Warning: does not synchronize the deque!
154 | """
155 | if not is_dist_avail_and_initialized():
156 | return
157 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
158 | dist.barrier()
159 | dist.all_reduce(t)
160 | t = t.tolist()
161 | self.count = int(t[0])
162 | self.total = t[1]
163 |
164 | @property
165 | def median(self):
166 | d = torch.tensor(list(self.deque))
167 | return d.median().item()
168 |
169 | @property
170 | def avg(self):
171 | d = torch.tensor(list(self.deque), dtype=torch.float32)
172 | return d.mean().item()
173 |
174 | @property
175 | def global_avg(self):
176 | return self.total / self.count
177 |
178 | @property
179 | def max(self):
180 | return max(self.deque)
181 |
182 | @property
183 | def value(self):
184 | return self.deque[-1]
185 |
186 | def __str__(self):
187 | return self.fmt.format(
188 | median=self.median,
189 | avg=self.avg,
190 | global_avg=self.global_avg,
191 | max=self.max,
192 | value=self.value)
193 |
194 |
195 | def reduce_dict(input_dict, average=True):
196 | """
197 | Args:
198 | input_dict (dict): all the values will be reduced
199 | average (bool): whether to do average or sum
200 | Reduce the values in the dictionary from all processes so that all processes
201 | have the averaged results. Returns a dict with the same fields as
202 | input_dict, after reduction.
203 | """
204 | world_size = get_world_size()
205 | if world_size < 2:
206 | return input_dict
207 | with torch.no_grad():
208 | names = []
209 | values = []
210 | # sort the keys so that they are consistent across processes
211 | for k in sorted(input_dict.keys()):
212 | names.append(k)
213 | values.append(input_dict[k])
214 | values = torch.stack(values, dim=0)
215 | dist.all_reduce(values)
216 | if average:
217 | values /= world_size
218 | reduced_dict = {k: v for k, v in zip(names, values)}
219 | return reduced_dict
220 |
221 |
222 | class MetricLogger(object):
223 | def __init__(self, delimiter="\t"):
224 | self.meters = defaultdict(SmoothedValue)
225 | self.delimiter = delimiter
226 |
227 | def update(self, **kwargs):
228 | for k, v in kwargs.items():
229 | if isinstance(v, torch.Tensor):
230 | v = v.item()
231 | assert isinstance(v, (float, int))
232 | self.meters[k].update(v)
233 |
234 | def __getattr__(self, attr):
235 | if attr in self.meters:
236 | return self.meters[attr]
237 | if attr in self.__dict__:
238 | return self.__dict__[attr]
239 | raise AttributeError("'{}' object has no attribute '{}'".format(
240 | type(self).__name__, attr))
241 |
242 | def __str__(self):
243 | loss_str = []
244 | for name, meter in self.meters.items():
245 | loss_str.append(
246 | "{}: {}".format(name, str(meter))
247 | )
248 | return self.delimiter.join(loss_str)
249 |
250 | def synchronize_between_processes(self):
251 | for meter in self.meters.values():
252 | meter.synchronize_between_processes()
253 |
254 | def add_meter(self, name, meter):
255 | self.meters[name] = meter
256 |
257 | def log_every(self, iterable, print_freq, header=None):
258 | i = 0
259 | if not header:
260 | header = ''
261 | start_time = time.time()
262 | end = time.time()
263 | iter_time = SmoothedValue(fmt='{avg:.6f}')
264 | data_time = SmoothedValue(fmt='{avg:.6f}')
265 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
266 | if torch.cuda.is_available():
267 | log_msg = self.delimiter.join([
268 | header,
269 | '[{0' + space_fmt + '}/{1}]',
270 | 'eta: {eta}',
271 | '{meters}',
272 | 'time: {time}',
273 | 'data: {data}',
274 | 'max mem: {memory:.0f}'
275 | ])
276 | else:
277 | log_msg = self.delimiter.join([
278 | header,
279 | '[{0' + space_fmt + '}/{1}]',
280 | 'eta: {eta}',
281 | '{meters}',
282 | 'time: {time}',
283 | 'data: {data}'
284 | ])
285 | MB = 1024.0 * 1024.0
286 | for obj in iterable:
287 | data_time.update(time.time() - end)
288 | yield obj
289 | iter_time.update(time.time() - end)
290 | if i % print_freq == 0 or i == len(iterable) - 1:
291 | eta_seconds = iter_time.global_avg * (len(iterable) - i)
292 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
293 | if torch.cuda.is_available():
294 | print(log_msg.format(
295 | i, len(iterable), eta=eta_string,
296 | meters=str(self),
297 | time=str(iter_time), data=str(data_time),
298 | memory=torch.cuda.max_memory_allocated() / MB))
299 | else:
300 | print(log_msg.format(
301 | i, len(iterable), eta=eta_string,
302 | meters=str(self),
303 | time=str(iter_time), data=str(data_time)))
304 | i += 1
305 | end = time.time()
306 | total_time = time.time() - start_time
307 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
308 | print('{} Total time: {} ({:.6f} s / it)'.format(
309 | header, total_time_str, total_time / len(iterable)))
310 |
311 |
312 | def get_sha():
313 | cwd = os.path.dirname(os.path.abspath(__file__))
314 |
315 | def _run(command):
316 | return subprocess.check_output(command, cwd=cwd).decode('ascii').strip()
317 | sha = 'N/A'
318 | diff = "clean"
319 | branch = 'N/A'
320 | try:
321 | sha = _run(['git', 'rev-parse', 'HEAD'])
322 | subprocess.check_output(['git', 'diff'], cwd=cwd)
323 | diff = _run(['git', 'diff-index', 'HEAD'])
324 |         diff = "has uncommitted changes" if diff else "clean"
325 | branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])
326 | except Exception:
327 | pass
328 | message = f"sha: {sha}, status: {diff}, branch: {branch}"
329 | return message
330 |
331 |
332 | def is_dist_avail_and_initialized():
333 | if not dist.is_available():
334 | return False
335 | if not dist.is_initialized():
336 | return False
337 | return True
338 |
339 |
340 | def get_world_size():
341 | if not is_dist_avail_and_initialized():
342 | return 1
343 | return dist.get_world_size()
344 |
345 |
346 | def get_rank():
347 | if not is_dist_avail_and_initialized():
348 | return 0
349 | return dist.get_rank()
350 |
351 |
352 | def is_main_process():
353 | return get_rank() == 0
354 |
355 |
356 | def save_on_master(*args, **kwargs):
357 | if is_main_process():
358 | torch.save(*args, **kwargs)
359 |
360 |
361 | def setup_for_distributed(is_master):
362 | """
363 | This function disables printing when not in master process
364 | """
365 | import builtins as __builtin__
366 | builtin_print = __builtin__.print
367 |
368 | def print(*args, **kwargs):
369 | force = kwargs.pop('force', False)
370 | if is_master or force:
371 | builtin_print(*args, **kwargs)
372 |
373 | __builtin__.print = print
374 |
375 |
376 | def init_distributed_mode_global(args):
377 | if args.dist_url == "env://" and args.world_size == -1:
378 | args.world_size = int(os.environ["WORLD_SIZE"])
379 |
380 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed
381 | args.ngpus_per_node = torch.cuda.device_count()
382 | if args.multiprocessing_distributed:
383 | args.world_size = args.ngpus_per_node * args.world_size
384 |
385 |
386 | def init_distributed_mode(args):
387 | # launched with torch.distributed.launch
388 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
389 | args.rank = int(os.environ["RANK"])
390 | args.world_size = int(os.environ['WORLD_SIZE'])
391 | args.gpu = int(os.environ['LOCAL_RANK'])
392 | # launched with submitit on a slurm cluster
393 | elif 'SLURM_PROCID' in os.environ:
394 | args.rank = int(os.environ['SLURM_PROCID'])
395 | args.gpu = args.rank % torch.cuda.device_count()
396 | elif torch.cuda.is_available():
397 |         print('Running the code on a single GPU.')
398 | args.rank, args.gpu, args.world_size = 0, 0, 1
399 | os.environ['MASTER_ADDR'] = '127.0.0.1'
400 | os.environ['MASTER_PORT'] = '29500'
401 | else:
402 |         print('Training without a GPU is not supported.')
403 | sys.exit(1)
404 |
405 | dist.init_process_group(
406 | backend='nccl',
407 | init_method=args.dist_url,
408 | world_size=args.world_size,
409 | rank=args.rank,
410 | )
411 |
412 | torch.cuda.set_device(args.gpu)
413 | print('| distributed init (rank {}): {}'.format(
414 | args.rank, args.dist_url), flush=True)
415 | dist.barrier()
416 | setup_for_distributed(args.rank == 0)
417 |
418 |
419 | def _no_grad_trunc_normal_(tensor, mean, std, a, b):
420 | # Cut & paste from PyTorch official master until it's in a few official releases - RW
421 | def norm_cdf(x):
422 | # Computes standard normal cumulative distribution function
423 | return (1. + math.erf(x / math.sqrt(2.))) / 2.
424 |
425 | if (mean < a - 2 * std) or (mean > b + 2 * std):
426 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
427 | "The distribution of values may be incorrect.",
428 | stacklevel=2)
429 |
430 | with torch.no_grad():
431 | l = norm_cdf((a - mean) / std)
432 | u = norm_cdf((b - mean) / std)
433 |
434 | # Uniformly fill tensor with values from [l, u], then translate to
435 | # [2l-1, 2u-1].
436 | tensor.uniform_(2 * l - 1, 2 * u - 1)
437 |
438 | # Use inverse cdf transform for normal distribution to get truncated
439 | # standard normal
440 | tensor.erfinv_()
441 |
442 | # Transform to proper mean, std
443 | tensor.mul_(std * math.sqrt(2.))
444 | tensor.add_(mean)
445 |
446 | # Clamp to ensure it's in the proper range
447 | tensor.clamp_(min=a, max=b)
448 | return tensor
449 |
450 |
451 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
452 | return _no_grad_trunc_normal_(tensor, mean, std, a, b)
453 |
454 |
455 | def get_params_groups(model):
456 | regularized = []
457 | not_regularized = []
458 | for name, param in model.named_parameters():
459 | if not param.requires_grad:
460 | continue
461 | # we do not regularize biases nor Norm parameters
462 | if name.endswith(".bias") or len(param.shape) == 1:
463 | not_regularized.append(param)
464 | else:
465 | regularized.append(param)
466 | return [{'params': regularized}, {'params': not_regularized, 'weight_decay': 0.}]
467 |
468 |
469 | def get_params_groups2(model, head_name):
470 | regularized = []
471 | not_regularized = []
472 | transhead = []
473 | for name, param in model.named_parameters():
474 | if not param.requires_grad:
475 | continue
476 | if head_name in name:
477 | transhead.append(param)
478 | continue
479 | # we do not regularize biases nor Norm parameters
480 | if name.endswith(".bias") or len(param.shape) == 1:
481 | not_regularized.append(param)
482 | else:
483 | regularized.append(param)
484 | return [{'params': regularized}, {'params': not_regularized, 'weight_decay': 0.}, {'params': transhead}]
485 |
486 |
487 | def has_batchnorms(model):
488 | bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
489 | for name, module in model.named_modules():
490 | if isinstance(module, bn_types):
491 | return True
492 | return False
493 |
494 |
495 | def multi_scale(samples, model):
496 | v = None
497 | for s in [1, 1/2**(1/2), 1/2]: # we use 3 different scales
498 | if s == 1:
499 | inp = samples.clone()
500 | else:
501 | inp = nn.functional.interpolate(samples, scale_factor=s, mode='bilinear', align_corners=False)
502 | feats = model(inp).clone()
503 | if v is None:
504 | v = feats
505 | else:
506 | v += feats
507 | v /= 3
508 | v /= v.norm()
509 | return v
510 |
--------------------------------------------------------------------------------
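Note: cosine_scheduler above precomputes one value per training iteration (a linear warmup followed by a half-cosine decay), so a training loop only needs to index the array with the global iteration counter. The standalone sketch below mirrors that behavior; the hyperparameters are illustrative, not the repository defaults.

import numpy as np
import torch
from torch import nn

def cosine_schedule(base, final, epochs, niter_per_ep, warmup_epochs=0):
    # same shape as cosine_scheduler above: warmup ramp, then cosine decay to `final`
    warmup_iters = warmup_epochs * niter_per_ep
    warmup = np.linspace(0, base, warmup_iters)
    iters = np.arange(epochs * niter_per_ep - warmup_iters)
    decay = final + 0.5 * (base - final) * (1 + np.cos(np.pi * iters / len(iters)))
    return np.concatenate((warmup, decay))

epochs, niter_per_ep = 3, 5
lr_schedule = cosine_schedule(5e-4, 1e-6, epochs, niter_per_ep, warmup_epochs=1)

model = nn.Linear(8, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0)
for epoch in range(epochs):
    for i in range(niter_per_ep):
        it = epoch * niter_per_ep + i                   # global iteration index
        for group in optimizer.param_groups:
            group["lr"] = lr_schedule[it]               # one schedule entry per iteration
        loss = model(torch.randn(4, 8)).sum()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
print("final lr:", optimizer.param_groups[0]["lr"])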
/backbones/vision_transformer.py:
--------------------------------------------------------------------------------
1 |
2 | """
3 | Mostly copy-paste from timm library.
4 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
5 | """
6 | import math
7 | from functools import partial, reduce
8 | from operator import mul
9 |
10 | import torch
11 | import torch.nn as nn
12 |
13 | from timm.models.vision_transformer import VisionTransformer, _cfg
14 | from timm.models.layers.helpers import to_2tuple
15 | # from timm.models.layers import PatchEmbed
16 |
17 | from utils.utils import trunc_normal_
18 |
19 | __all__ = [
20 | 'vit_small',
21 | 'vit_base',
22 |     'vit_large',
23 |     'vit_huge',
24 | ]
25 |
26 | def drop_path(x, drop_prob: float = 0., training: bool = False):
27 | if drop_prob == 0. or not training:
28 | return x
29 | keep_prob = 1 - drop_prob
30 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
31 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
32 | random_tensor.floor_() # binarize
33 | output = x.div(keep_prob) * random_tensor
34 | return output
35 |
36 |
37 | class DropPath(nn.Module):
38 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
39 | """
40 | def __init__(self, drop_prob=None):
41 | super(DropPath, self).__init__()
42 | self.drop_prob = drop_prob
43 |
44 | def forward(self, x):
45 | return drop_path(x, self.drop_prob, self.training)
46 |
47 |
48 | class Mlp(nn.Module):
49 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
50 | super().__init__()
51 | out_features = out_features or in_features
52 | hidden_features = hidden_features or in_features
53 | self.fc1 = nn.Linear(in_features, hidden_features)
54 | self.act = act_layer()
55 | self.fc2 = nn.Linear(hidden_features, out_features)
56 | self.drop = nn.Dropout(drop)
57 |
58 | def forward(self, x):
59 | x = self.fc1(x)
60 | x = self.act(x)
61 | x = self.drop(x)
62 | x = self.fc2(x)
63 | x = self.drop(x)
64 | return x
65 |
66 |
67 | class Attention(nn.Module):
68 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
69 | super().__init__()
70 | self.num_heads = num_heads
71 | head_dim = dim // num_heads
72 | self.scale = qk_scale or head_dim ** -0.5
73 |
74 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
75 | self.attn_drop = nn.Dropout(attn_drop)
76 | self.proj = nn.Linear(dim, dim)
77 | self.proj_drop = nn.Dropout(proj_drop)
78 |
79 | def forward(self, x):
80 | B, N, C = x.shape
81 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
82 | q, k, v = qkv[0], qkv[1], qkv[2]
83 |
84 | attn = (q @ k.transpose(-2, -1)) * self.scale
85 | attn = attn.softmax(dim=-1)
86 | attn = self.attn_drop(attn)
87 |
88 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
89 | x = self.proj(x)
90 | x = self.proj_drop(x)
91 | return x, attn
92 |
93 |
94 | class Block(nn.Module):
95 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
96 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
97 | super().__init__()
98 | self.norm1 = norm_layer(dim)
99 | self.attn = Attention(
100 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
101 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
102 | self.norm2 = norm_layer(dim)
103 | mlp_hidden_dim = int(dim * mlp_ratio)
104 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
105 |
106 | def forward(self, x, return_attention=False):
107 | y, attn = self.attn(self.norm1(x))
108 | if return_attention:
109 | return attn
110 | x = x + self.drop_path(y)
111 | x = x + self.drop_path(self.mlp(self.norm2(x)))
112 | return x
113 |
114 |
115 | class PatchEmbed(nn.Module):
116 | """ Image to Patch Embedding
117 | """
118 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
119 | super().__init__()
120 | num_patches = (img_size // patch_size) * (img_size // patch_size)
121 | self.img_size = img_size
122 | self.patch_size = patch_size
123 | self.num_patches = num_patches
124 |
125 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
126 |
127 | def forward(self, x):
128 | B, C, H, W = x.shape
129 | x = self.proj(x).flatten(2).transpose(1, 2)
130 | return x
131 |
132 |
133 | class VisionTransformer(nn.Module):
134 | """ Vision Transformer """
135 | def __init__(self, img_size=[224], patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12,
136 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
137 | drop_path_rate=0., norm_layer=nn.LayerNorm, **kwargs):
138 | super().__init__()
139 | self.num_features = self.embed_dim = embed_dim
140 |
141 | self.patch_embed = PatchEmbed(
142 | img_size=img_size[0], patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
143 | num_patches = self.patch_embed.num_patches
144 |
145 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
146 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
147 | self.pos_drop = nn.Dropout(p=drop_rate)
148 |
149 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
150 | self.blocks = nn.ModuleList([
151 | Block(
152 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
153 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
154 | for i in range(depth)])
155 | self.norm = norm_layer(embed_dim)
156 |
157 | # Classifier head
158 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
159 |
160 | trunc_normal_(self.pos_embed, std=.02)
161 | trunc_normal_(self.cls_token, std=.02)
162 | self.apply(self._init_weights)
163 |
164 | def _init_weights(self, m):
165 | if isinstance(m, nn.Linear):
166 | trunc_normal_(m.weight, std=.02)
167 | if isinstance(m, nn.Linear) and m.bias is not None:
168 | nn.init.constant_(m.bias, 0)
169 | elif isinstance(m, nn.LayerNorm):
170 | nn.init.constant_(m.bias, 0)
171 | nn.init.constant_(m.weight, 1.0)
172 |
173 | def interpolate_pos_encoding(self, x, w, h):
174 | npatch = x.shape[1] - 1
175 | N = self.pos_embed.shape[1] - 1
176 | if npatch == N and w == h:
177 | return self.pos_embed
178 | class_pos_embed = self.pos_embed[:, 0]
179 | patch_pos_embed = self.pos_embed[:, 1:]
180 | dim = x.shape[-1]
181 | w0 = w // self.patch_embed.patch_size
182 | h0 = h // self.patch_embed.patch_size
183 | # we add a small number to avoid floating point error in the interpolation
184 | # see discussion at https://github.com/facebookresearch/dino/issues/8
185 | w0, h0 = w0 + 0.1, h0 + 0.1
186 | patch_pos_embed = nn.functional.interpolate(
187 | patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
188 | scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
189 | mode='bicubic',
190 | )
191 | assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
192 | patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
193 | return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
194 |
195 | def prepare_tokens(self, x):
196 | B, nc, w, h = x.shape
197 | x = self.patch_embed(x) # patch linear embedding
198 |
199 | # add the [CLS] token to the embed patch tokens
200 | cls_tokens = self.cls_token.expand(B, -1, -1)
201 | x = torch.cat((cls_tokens, x), dim=1)
202 |
203 | # add positional encoding to each token
204 | x = x + self.interpolate_pos_encoding(x, w, h)
205 |
206 | return self.pos_drop(x)
207 |
208 | def forward(self, x, return_all_tokens=False, return_attention=False):
209 | x = self.prepare_tokens(x)
210 |
211 | if not return_attention:
212 | for blk in self.blocks:
213 | x = blk(x)
214 | x = self.norm(x)
215 | if return_all_tokens:
216 | return x, None
217 | else:
218 | return x[:, 0], None
219 | else:
220 | for i, blk in enumerate(self.blocks):
221 | if i < len(self.blocks) - 1:
222 | x = blk(x)
223 | else:
224 | # return attention of the last block
225 | out = blk(x)
226 | with torch.no_grad():
227 | attentions = blk(x, return_attention=True) # B nHead N N
228 | out = self.norm(out)
229 | # generate mask
230 | with torch.no_grad():
231 | # keep only the output patch attention
232 | attentions = attentions[:, :, 0, 1:] # B nHead N
233 | # attentions = torch.mean(attentions, dim=1) # B N
234 | if return_all_tokens:
235 | return out, attentions
236 | else:
237 | return out[:, 0], attentions
238 |
239 | def get_last_selfattention(self, x):
240 | x = self.prepare_tokens(x)
241 | for i, blk in enumerate(self.blocks):
242 | if i < len(self.blocks) - 1:
243 | x = blk(x)
244 | else:
245 | # return attention of the last block
246 | return blk(x, return_attention=True)
247 |
248 | def get_selfattention(self, x, block_index=11):
249 | x = self.prepare_tokens(x)
250 | if block_index == 0:
251 | return self.blocks[0](x, return_attention=True)
252 |
253 | for i, blk in enumerate(self.blocks):
254 | if i < block_index:
255 | x = blk(x)
256 | else:
257 | return blk(x, return_attention=True)
258 |
259 | def get_layer(self, x, block_index=11):
260 | x = self.prepare_tokens(x)
261 | if block_index == 0:
262 | return self.blocks[0](x)
263 |
264 | for i, blk in enumerate(self.blocks):
265 | if i < block_index:
266 | x = blk(x)
267 | else:
268 | return blk(x)
269 |
270 | def get_intermediate_layers(self, x, n=1):
271 | x = self.prepare_tokens(x)
272 | # we return the output tokens from the `n` last blocks
273 | output = []
274 | for i, blk in enumerate(self.blocks):
275 | x = blk(x)
276 | if len(self.blocks) - i <= n:
277 | output.append(self.norm(x))
278 | return output
279 |
280 |
281 | def vit_tiny(patch_size=16, **kwargs):
282 | model = VisionTransformer(
283 | patch_size=patch_size, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4,
284 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
285 | return model
286 |
287 |
288 | def vit_small(patch_size=16, **kwargs):
289 | model = VisionTransformer(
290 | patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4,
291 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
292 | return model
293 |
294 |
295 | def vit_base(patch_size=16, **kwargs):
296 | model = VisionTransformer(
297 | patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
298 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
299 | return model
300 |
301 | def vit_large(patch_size=16, **kwargs):
302 | model = VisionTransformer(
303 | patch_size=patch_size, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4,
304 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
305 | return model
306 |
307 | def vit_huge(patch_size=16, **kwargs):
308 | model = VisionTransformer(
309 | patch_size=patch_size, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4,
310 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
311 | return model
312 |
313 | # input is a 4-d feature map with shape B C H W
314 | class VisionTransformerHead(nn.Module):
315 |     """ Transformer head operating on a 4-d CNN feature map (B C H W) """
316 | def __init__(self, featuremap_size=7, in_chans=2048, embed_dim=384, depth=12,
317 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
318 | drop_path_rate=0., norm_layer=nn.LayerNorm, **kwargs):
319 | super().__init__()
320 |
321 | self.num_features = self.embed_dim = embed_dim
322 | self.num_patches = featuremap_size * featuremap_size
323 |
324 | self.patch_embed = nn.Conv2d(in_chans, embed_dim, kernel_size=1, stride=1)
325 |
326 | self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, embed_dim))
327 | self.pos_drop = nn.Dropout(p=drop_rate)
328 |
329 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
330 | self.blocks = nn.ModuleList([
331 | Block(
332 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
333 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
334 | for i in range(depth)])
335 | self.norm = norm_layer(embed_dim)
336 |
337 | trunc_normal_(self.pos_embed, std=.02)
338 | self.apply(self._init_weights)
339 |
340 |
341 | def _init_weights(self, m):
342 | if isinstance(m, nn.Linear):
343 | trunc_normal_(m.weight, std=.02)
344 | if isinstance(m, nn.Linear) and m.bias is not None:
345 | nn.init.constant_(m.bias, 0)
346 | elif isinstance(m, nn.LayerNorm):
347 | nn.init.constant_(m.bias, 0)
348 | nn.init.constant_(m.weight, 1.0)
349 |
350 |
351 | def prepare_tokens(self, x):
352 | B, nc, h, w = x.shape
353 | x = self.patch_embed(x).flatten(2).transpose(1, 2) # patch linear embedding B N C
354 |
355 | # add positional encoding to each token
356 | x = x + self.pos_embed
357 |
358 | return self.pos_drop(x)
359 |
360 |
361 | def forward(self, x, return_all_tokens=False, return_attention=False):
362 | x = self.prepare_tokens(x)
363 |
364 | if not return_attention:
365 | for blk in self.blocks:
366 | x = blk(x)
367 | x = self.norm(x)
368 | if return_all_tokens:
369 | return x, None
370 | else:
371 | return x[:, 0], None
372 | else:
373 | for i, blk in enumerate(self.blocks):
374 | if i < len(self.blocks) - 1:
375 | x = blk(x)
376 | else:
377 | # return attention of the last block
378 | out = blk(x)
379 | with torch.no_grad():
380 | attentions = blk(x, return_attention=True) # B nHead N N
381 | out = self.norm(out)
382 | # generate mask
383 | with torch.no_grad():
384 | # keep only the output patch attention
385 | attentions = attentions[:, :, 0, 1:] # B nHead N
386 | # attentions = torch.mean(attentions, dim=1) # B N
387 | if return_all_tokens:
388 | return out, attentions
389 | else:
390 | return out[:, 0], attentions
391 |
392 |
393 | # input is a 4-d feature map with shape B C H W
394 | class THead_CNN(nn.Module):
395 |     """ Transformer projection head for 4-d CNN feature maps (B C H W) """
396 | def __init__(self, out_dim, featuremap_size=7, in_chans=2048, embed_dim=384, depth=12,
397 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
398 | drop_path_rate=0., norm_layer=nn.LayerNorm, **kwargs):
399 | super().__init__()
400 |
401 | self.num_features = self.embed_dim = embed_dim
402 | self.num_patches = featuremap_size * featuremap_size
403 |
404 | self.patch_embed = nn.Conv2d(in_chans, embed_dim, kernel_size=1, stride=1)
405 |
406 | self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, embed_dim))
407 | self.pos_drop = nn.Dropout(p=drop_rate)
408 |
409 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
410 | self.blocks = nn.ModuleList([
411 | Block(
412 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
413 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
414 | for i in range(depth)])
415 | self.norm = norm_layer(embed_dim)
416 |
417 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
418 |
419 | # self.last_layer = nn.Linear(embed_dim, out_dim, bias=False)
420 | self.last_layer = nn.utils.weight_norm(nn.Linear(embed_dim, out_dim, bias=False))
421 | self.last_layer.weight_g.data.fill_(1)
422 |
423 | trunc_normal_(self.pos_embed, std=.02)
424 | self.apply(self._init_weights)
425 |
426 | def _init_weights(self, m):
427 | if isinstance(m, nn.Linear):
428 | trunc_normal_(m.weight, std=.02)
429 | if isinstance(m, nn.Linear) and m.bias is not None:
430 | nn.init.constant_(m.bias, 0)
431 | elif isinstance(m, nn.LayerNorm):
432 | nn.init.constant_(m.bias, 0)
433 | nn.init.constant_(m.weight, 1.0)
434 |
435 | def interpolate_pos_encoding(self, x, w, h):
436 | npatch = x.shape[1]
437 | N = self.pos_embed.shape[1]
438 | if npatch == N and w == h:
439 | return self.pos_embed
440 | patch_pos_embed = self.pos_embed
441 | dim = x.shape[-1]
442 | # w0 = w // self.patch_embed.patch_size
443 | # h0 = h // self.patch_embed.patch_size
444 | # we add a small number to avoid floating point error in the interpolation
445 | # see discussion at https://github.com/facebookresearch/dino/issues/8
446 | w0, h0 = w + 0.1, h + 0.1
447 | patch_pos_embed = nn.functional.interpolate(
448 | patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
449 | scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
450 | mode='bicubic',
451 | )
452 | assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
453 | patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
454 | return patch_pos_embed
455 |
456 | def prepare_tokens(self, x):
457 | B, nc, h, w = x.shape
458 |         x = self.patch_embed(x) # reduce channel dimension
459 | x_pool = self.avgpool(x).flatten(1)
460 |
461 | x = x.flatten(2).transpose(1, 2) # patch linear embedding B N C
462 |
463 | # add positional encoding to each token
464 | # x = x + self.pos_embed
465 | x = x + self.interpolate_pos_encoding(x, w, h)
466 |
467 | # return self.pos_drop(x)
468 | return x, x_pool
469 |
470 | def forward(self, x, query_token=None):
471 | x, x_pool = self.prepare_tokens(x)
472 |
473 | # add query token
474 | if query_token is not None:
475 | query_token = query_token.unsqueeze(1) # B C => B 1 C
476 | x = torch.cat((query_token, x), dim=1)
477 |
478 | for blk in self.blocks:
479 | x = blk(x)
480 | x = self.norm(x)
481 | if query_token is not None:
482 | x = x[:, 1:]
483 | x = torch.mean(x, dim=1)
484 | x = self.last_layer(x) # B out_dim
485 | return x, x_pool
486 |
487 |
488 | # input is a 3-d token sequence with shape B N C
489 | class THead_ViT(nn.Module):
490 |     """ Transformer projection head for ViT patch tokens (B N C) """
491 | def __init__(self, out_dim, featuremap_size=14, embed_dim=384, depth=12,
492 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
493 | drop_path_rate=0., norm_layer=nn.LayerNorm, **kwargs):
494 | super().__init__()
495 |
496 | self.num_features = self.embed_dim = embed_dim
497 | self.num_patches = featuremap_size * featuremap_size
498 |
499 | self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, embed_dim))
500 | self.pos_drop = nn.Dropout(p=drop_rate)
501 |
502 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
503 | self.blocks = nn.ModuleList([
504 | Block(
505 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
506 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
507 | for i in range(depth)])
508 | self.norm = norm_layer(embed_dim)
509 |
510 | # self.last_layer = nn.Linear(embed_dim, out_dim, bias=False)
511 | self.last_layer = nn.utils.weight_norm(nn.Linear(embed_dim, out_dim, bias=False))
512 | self.last_layer.weight_g.data.fill_(1)
513 |
514 | trunc_normal_(self.pos_embed, std=.02)
515 | self.apply(self._init_weights)
516 |
517 | def _init_weights(self, m):
518 | if isinstance(m, nn.Linear):
519 | trunc_normal_(m.weight, std=.02)
520 | if isinstance(m, nn.Linear) and m.bias is not None:
521 | nn.init.constant_(m.bias, 0)
522 | elif isinstance(m, nn.LayerNorm):
523 | nn.init.constant_(m.bias, 0)
524 | nn.init.constant_(m.weight, 1.0)
525 |
526 | def interpolate_pos_encoding(self, x, w, h):
527 | npatch = x.shape[1]
528 | N = self.pos_embed.shape[1]
529 | if npatch == N and w == h:
530 | return self.pos_embed
531 | patch_pos_embed = self.pos_embed
532 | dim = x.shape[-1]
533 | # w0 = w // self.patch_embed.patch_size
534 | # h0 = h // self.patch_embed.patch_size
535 | # we add a small number to avoid floating point error in the interpolation
536 | # see discussion at https://github.com/facebookresearch/dino/issues/8
537 | w0, h0 = w + 0.1, h + 0.1
538 | patch_pos_embed = nn.functional.interpolate(
539 | patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
540 | scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
541 | mode='bicubic',
542 | )
543 | assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
544 | patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
545 | return patch_pos_embed
546 |
547 | def prepare_tokens(self, x):
548 | B, N, C = x.shape
549 | size = int(math.sqrt(N))
550 |
551 | # add positional encoding to each token
552 | # x = x + self.pos_embed
553 | x = x + self.interpolate_pos_encoding(x, size, size)
554 |
555 | # return self.pos_drop(x)
556 | return x
557 |
558 | def forward(self, x, query_token=None):
559 | x = self.prepare_tokens(x)
560 |
561 | # add query token
562 | if query_token is not None:
563 | query_token = query_token.unsqueeze(1) # B C => B 1 C
564 | x = torch.cat((query_token, x), dim=1)
565 |
566 | for blk in self.blocks:
567 | x = blk(x)
568 | x = self.norm(x)
569 | if query_token is not None:
570 | x = x[:, 1:]
571 | x = torch.mean(x, dim=1)
572 | x = self.last_layer(x) # B out_dim
573 | return x
574 |
575 |
--------------------------------------------------------------------------------
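Note: the factory functions above (vit_small, vit_base, ...) build headless backbones (num_classes=0 gives an identity head), and get_intermediate_layers returns the normalized token sequences of the last n blocks, which eval_linear.py concatenates for probing. A small usage sketch, assuming the repository root is on PYTHONPATH and timm is installed:

import torch
import backbones.vision_transformer as vits

model = vits.vit_small(patch_size=16, num_classes=0).eval()
images = torch.randn(2, 3, 224, 224)

with torch.no_grad():
    cls_feat, _ = model(images)                          # forward() returns (features, attentions)
    layers = model.get_intermediate_layers(images, n=4)  # normalized tokens of the 4 last blocks

print(cls_feat.shape)                 # torch.Size([2, 384]): [CLS] embedding of ViT-Small
print(len(layers), layers[-1].shape)  # 4 blocks, each [2, 197, 384] (1 [CLS] + 196 patch tokens)
cls_concat = torch.cat([x[:, 0] for x in layers], dim=-1)
print(cls_concat.shape)               # torch.Size([2, 1536]): input dimension for the linear probe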
/main_mokd.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | import argparse
3 | import os
4 | import sys
5 | import builtins
6 | import datetime
7 | import time
8 | import math
9 | import json
10 | from pathlib import Path
11 | from functools import partial
12 |
13 | import torch
14 | import torch.nn as nn
15 | import torch.backends.cudnn as cudnn
16 | # from torchvision import models as torchvision_models
17 | from torch.utils.tensorboard import SummaryWriter
18 | from torchvision import datasets
19 |
20 | from utils import utils
21 | from utils import checkpoint_io
22 | from utils import optimizers
23 | from augmentations.dino_augmentation import DINODataAugmentation
24 | import backbones.vision_transformer as vits
25 | import backbones.resnet as resnets
26 | from models.mokd import *
27 |
28 | method = "mokd"
29 |
30 | def get_args_parser():
31 | parser = argparse.ArgumentParser(method, add_help=False)
32 |
33 | #################################
34 | #### input and output parameters ####
35 | #################################
36 | parser.add_argument('--data_path', default='/path/to/imagenet/train/', type=str,
37 | help='Please specify path to the ImageNet training data.')
38 | parser.add_argument('--output_dir', default=".", type=str, help='Path to save logs and checkpoints.')
39 | parser.add_argument('--saveckp_freq', default=10, type=int, help='Save checkpoint every x epochs.')
40 | parser.add_argument('--experiment', default='exp', type=str, help='experiment name')
41 |
42 | #################################
43 | #### augmentation parameters ####
44 | #################################
45 | # multi-crop parameters
46 | parser.add_argument('--input_size', default=224, type=int, help='input image size')
47 | parser.add_argument('--global_crops_scale', type=float, nargs='+', default=(0.4, 1.),
48 |         help="""Scale range of the cropped image before resizing, relative to the original image.
49 | Used for large global view cropping.""")
50 | parser.add_argument('--local_crops_number', type=int, default=8, help="""Number of small
51 | local views to generate.""")
52 | parser.add_argument('--local_crops_scale', type=float, nargs='+', default=(0.05, 0.4),
53 |         help="""Scale range of the cropped image before resizing, relative to the original image.
54 | Used for small local view cropping of multi-crop.""")
55 |
56 | #################################
57 |     #### model parameters ####
58 | #################################
59 | parser.add_argument('--arch_cnn', default='resnet50', type=str,
60 |         help="""Name of the CNN architecture to train.""")
61 | parser.add_argument('--arch_vit', default='vit_small', type=str,
62 |         help="""Name of the ViT architecture to train.""")
63 | parser.add_argument('--out_dim', default=65536, type=int, help="""Dimensionality of
64 | the DINO head output. For complex and large datasets large values (like 65k) work well.""")
65 | parser.add_argument('--norm_last_layer', default=True, type=utils.bool_flag,
66 |         help="""Whether or not to weight normalize the last layer of the DINO head.""")
67 | parser.add_argument('--momentum_teacher', default=0.996, type=float, help="""Base EMA
68 | parameter for teacher update.""")
69 | parser.add_argument('--use_bn_in_head', default=False, type=utils.bool_flag,
70 | help="Whether to use batch normalizations in projection head (Default: False)")
71 | # for ViTs
72 | parser.add_argument('--patch_size', default=16, type=int, help="""Size in pixels
73 | of input square patches - default 16.""")
74 | parser.add_argument('--drop_path_rate', type=float, default=0.1, help="stochastic depth rate")
75 | # for cross-distillation
76 |     parser.add_argument("--lamda_c", default=1.0, type=float, help="""Weight of the cross-distillation (ct) loss terms for the CNN branch.""")
77 |     parser.add_argument("--lamda_t", default=1.0, type=float, help="""Weight of the cross-distillation (ct) loss terms for the ViT branch.""")
78 |
79 | #################################
80 |     #### optim parameters ####
81 |     #################################
82 |     # training/optimization parameters
83 | parser.add_argument('--batch_size_per_gpu', default=64, type=int,
84 | help='Per-GPU batch-size : number of distinct images loaded on one GPU.')
85 | parser.add_argument('--epochs', default=200, type=int, help='Number of epochs of training.')
86 | parser.add_argument('--optimizer', default='adamw', type=str,
87 | choices=['adamw', 'sgd', 'lars'], help="""Type of optimizer. We recommend using adamw with ViTs.""")
88 | parser.add_argument("--lr_cnn", default=0.3, type=float, help="""Learning rate at the end of
89 | linear warmup (highest LR used during training).""")
90 | parser.add_argument("--lr_vit", default=0.0005, type=float, help="""Learning rate at the end of
91 | linear warmup (highest LR used during training). """)
92 | parser.add_argument("--warmup_epochs", default=10, type=int,
93 | help="Number of epochs for the linear learning-rate warm up.")
94 | parser.add_argument('--min_lr', type=float, default=1e-6, help="""Target LR at the
95 | end of optimization. We use a cosine LR schedule with linear warmup.""")
96 | parser.add_argument('--weight_decay', type=float, default=0.04, help="""Initial value of the
97 | weight decay. With ViT, a smaller value at the beginning of training works well.""")
98 | parser.add_argument('--weight_decay_end', type=float, default=0.4, help="""Final value of the
99 | weight decay. We use a cosine schedule for WD and using a larger decay by
100 | the end of training improves performance for ViTs.""")
101 | parser.add_argument('--use_fp16', type=utils.bool_flag, default=True, help="""Whether or not
102 | to use half precision for training.""")
103 | parser.add_argument('--clip_grad_cnn', type=float, default=3.0, help="""Maximal parameter
104 | gradient norm if using gradient clipping. Clipping with norm .3 ~ 1.0 can
105 | help optimization for larger ViT architectures. 0 for disabling.""")
106 | parser.add_argument('--clip_grad_vit', type=float, default=0.0, help="""Maximal parameter
107 | gradient norm if using gradient clipping. Clipping with norm .3 ~ 1.0 can
108 | help optimization for larger ViT architectures. 0 for disabling.""")
109 | parser.add_argument('--freeze_last_layer', default=1, type=int, help="""Number of epochs
110 | during which we keep the output layer fixed. Typically doing so during
111 | the first epoch helps training. Try increasing this value if the loss does not decrease.""")
112 | # temperature teacher parameters
113 | parser.add_argument('--warmup_teacher_temp', default=0.04, type=float,
114 | help="""Initial value for the teacher temperature: 0.04 works well in most cases.
115 | Try decreasing it if the training loss does not decrease.""")
116 | parser.add_argument('--teacher_temp', default=0.04, type=float, help="""Final value (after linear warmup)
117 | of the teacher temperature. For most experiments, anything above 0.07 is unstable. We recommend
118 |         starting with the default value of 0.04 and increasing it slightly if needed.""")
119 | parser.add_argument('--warmup_teacher_temp_epochs_cnn', default=50, type=int,
120 | help='Number of warmup epochs for the teacher temperature (Default: 50).')
121 | parser.add_argument('--warmup_teacher_temp_epochs_vit', default=30, type=int,
122 | help='Number of warmup epochs for the teacher temperature (Default: 30).')
123 |
124 | #################################
125 | #### dist parameters ###
126 | #################################
127 | parser.add_argument('--seed', default=0, type=int, help='Random seed.')
128 | parser.add_argument('--rank', default=-1, type=int,
129 | help='node rank for distributed training')
130 | parser.add_argument('--num_workers', default=10, type=int, help='Number of data loading workers per GPU.')
131 | parser.add_argument("--dist_url", default="env://", type=str, help="url used to set up distributed training")
132 | parser.add_argument("--local_rank", default=0, type=int, help="Please ignore and do not set this argument.")
133 |
134 | return parser
135 |
136 | def train(args):
137 |
138 | ######################## init dist ########################
139 | utils.init_distributed_mode(args)
140 | utils.fix_random_seeds(args.seed)
141 | print("git:\n {}\n".format(utils.get_sha()))
142 | print("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items())))
143 | cudnn.benchmark = True
144 |
145 | if args.gpu != 0:
146 |         def print_pass(*args, **kwargs):
147 | pass
148 | builtins.print = print_pass
149 |
150 | if args.gpu is not None:
151 | print("Use GPU: {} for training".format(args.gpu))
152 |
153 | ######################## preparing data ... ########################
154 | transform = DINODataAugmentation(
155 | args.global_crops_scale,
156 | args.local_crops_scale,
157 | args.local_crops_number,
158 | )
159 |
160 | dataset = datasets.ImageFolder(args.data_path, transform=transform)
161 |
162 | sampler = torch.utils.data.DistributedSampler(dataset, shuffle=True)
163 | data_loader = torch.utils.data.DataLoader(
164 | dataset,
165 | sampler=sampler,
166 | batch_size=args.batch_size_per_gpu,
167 | num_workers=args.num_workers,
168 | pin_memory=True,
169 | drop_last=True,
170 | )
171 | print(f"Loaded {len(dataset)} training images.")
172 |
173 |     ######################## building networks ... ########################
174 | if args.arch_vit in vits.__dict__.keys():
175 | student_vit = vits.__dict__[args.arch_vit](
176 | patch_size=args.patch_size,
177 | drop_path_rate=args.drop_path_rate, # stochastic depth
178 | )
179 | teacher_vit = vits.__dict__[args.arch_vit](patch_size=args.patch_size)
180 | embed_dim_vit = student_vit.embed_dim
181 | else:
182 |         print(f"Unknown architecture: {args.arch_vit}")
183 |         sys.exit(1)
184 | if args.arch_cnn in resnets.__dict__.keys():
185 | student_cnn = resnets.__dict__[args.arch_cnn]()
186 | teacher_cnn = resnets.__dict__[args.arch_cnn]()
187 | embed_dim_cnn = student_cnn.fc.weight.shape[1]
188 | else:
189 |         print(f"Unknown architecture: {args.arch_cnn}")
190 |         sys.exit(1)
191 | # multi-crop wrapper handles forward with inputs of different resolutions
192 | # use_bn = True if args.cnn_checkpoint is None else False
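    # featuremap_size = input_size // 32 corresponds to the spatial size of the
    # final ResNet stage (224 // 32 = 7 for resnet50), i.e. the feature map that
    # the trans-head (THead_CNN) operates on.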
193 | student_cnn = CNNStudentWrapper(student_cnn, DINOHead(
194 | embed_dim_cnn,
195 | args.out_dim,
196 | use_bn=False,
197 | norm_last_layer=True),
198 | vits.THead_CNN(out_dim=args.out_dim, featuremap_size=args.input_size // 32, in_chans=embed_dim_cnn,
199 | embed_dim=embed_dim_vit, depth=3, num_heads=6, mlp_ratio=4,
200 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)),
201 | )
202 |
203 | teacher_cnn = CNNTeacherWrapper(
204 | teacher_cnn,
205 | DINOHead(embed_dim_cnn, args.out_dim, False),
206 | vits.THead_CNN(out_dim=args.out_dim, featuremap_size=args.input_size // 32, in_chans=embed_dim_cnn,
207 | embed_dim=embed_dim_vit, depth=3, num_heads=6, mlp_ratio=4,
208 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)),
209 | args.local_crops_number,
210 | )
211 |
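    # The ViT trans-head uses featuremap_size=14, i.e. 224 // patch_size 16 = 14
    # patch tokens per side; note this value is hard-coded for the default 224 / 16 setup.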
212 | student_vit = ViTStudentWrapper(student_vit, DINOHead(
213 | embed_dim_vit,
214 | args.out_dim,
215 | use_bn=False,
216 | norm_last_layer=False),
217 | vits.THead_ViT(out_dim=args.out_dim, featuremap_size=14,
218 | embed_dim=embed_dim_vit, depth=3, num_heads=6, mlp_ratio=4,
219 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)),
220 | )
221 | teacher_vit = ViTTeacherWrapper(
222 | teacher_vit,
223 | DINOHead(embed_dim_vit, args.out_dim, False),
224 | vits.THead_ViT(out_dim=args.out_dim, featuremap_size=14,
225 | embed_dim=embed_dim_vit, depth=3, num_heads=6, mlp_ratio=4,
226 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)),
227 | args.local_crops_number,
228 | )
229 |
230 | # move networks to gpu
231 | student_cnn, teacher_cnn = student_cnn.cuda(), teacher_cnn.cuda()
232 | # synchronize batch norms
233 | if utils.has_batchnorms(student_cnn):
234 | student_cnn = nn.SyncBatchNorm.convert_sync_batchnorm(student_cnn)
235 | teacher_cnn = nn.SyncBatchNorm.convert_sync_batchnorm(teacher_cnn)
236 | # use DDP wrapper to have synchro batch norms working...
237 | teacher_cnn = nn.parallel.DistributedDataParallel(teacher_cnn, device_ids=[args.gpu])
238 | teacher_cnn_without_ddp = teacher_cnn.module
239 | else:
240 | teacher_cnn_without_ddp = teacher_cnn
241 | student_cnn = nn.parallel.DistributedDataParallel(student_cnn, device_ids=[args.gpu], find_unused_parameters=False)
242 | teacher_cnn_without_ddp.load_state_dict(student_cnn.module.state_dict(), strict=False)
243 | for p in teacher_cnn.parameters():
244 | p.requires_grad = False
245 |     print(f"CNN student and teacher are built: they are both {args.arch_cnn} networks.")
246 |
247 | student_vit, teacher_vit = student_vit.cuda(), teacher_vit.cuda()
248 | # synchronize batch norms
249 | if utils.has_batchnorms(student_vit):
250 | student_vit = nn.SyncBatchNorm.convert_sync_batchnorm(student_vit)
251 | teacher_vit = nn.SyncBatchNorm.convert_sync_batchnorm(teacher_vit)
252 | # use DDP wrapper to have synchro batch norms working...
253 | teacher_vit = nn.parallel.DistributedDataParallel(teacher_vit, device_ids=[args.gpu])
254 | teacher_vit_without_ddp = teacher_vit.module
255 | else:
256 | teacher_vit_without_ddp = teacher_vit
257 | student_vit = nn.parallel.DistributedDataParallel(student_vit, device_ids=[args.gpu], find_unused_parameters=False)
258 | teacher_vit_without_ddp.load_state_dict(student_vit.module.state_dict(), strict=False)
259 | for p in teacher_vit.parameters():
260 | p.requires_grad = False
261 |     print(f"ViT student and teacher are built: they are both {args.arch_vit} networks.")
262 |
263 | ######################## preparing loss ... ########################
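    # Every DINOLoss below is built for ncrops = 2 global + local_crops_number
    # local views, matching the crop list produced by DINODataAugmentation.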
264 | loss_cnn_fn = DINOLoss(
265 | args.out_dim,
266 | args.local_crops_number + 2,
267 | args.warmup_teacher_temp,
268 | args.teacher_temp,
269 | args.warmup_teacher_temp_epochs_cnn, # args.warmup_teacher_temp_epochs 50
270 | args.epochs,
271 | ).cuda()
272 |
273 | loss_vit_fn = DINOLoss(
274 | args.out_dim,
275 | args.local_crops_number + 2,
276 | args.warmup_teacher_temp,
277 | args.teacher_temp,
278 | args.warmup_teacher_temp_epochs_vit, # args.warmup_teacher_temp_epochs 30
279 | args.epochs,
280 | ).cuda()
281 |
282 | loss_cnn_ct_fn = DINOLoss(
283 | args.out_dim,
284 | args.local_crops_number + 2,
285 | args.warmup_teacher_temp,
286 | args.teacher_temp,
287 | args.warmup_teacher_temp_epochs_vit, # args.warmup_teacher_temp_epochs 30
288 | args.epochs,
289 | ).cuda()
290 |
291 | loss_vit_ct_fn = DINOLoss(
292 | args.out_dim,
293 | args.local_crops_number + 2,
294 | args.warmup_teacher_temp,
295 | args.teacher_temp,
296 | args.warmup_teacher_temp_epochs_vit, # args.warmup_teacher_temp_epochs 30
297 | args.epochs,
298 | ).cuda()
299 |
300 | loss_cnn_thead_fn = DINOLoss(
301 | args.out_dim,
302 | args.local_crops_number + 2,
303 | args.warmup_teacher_temp,
304 | args.teacher_temp,
305 | args.warmup_teacher_temp_epochs_cnn, # args.warmup_teacher_temp_epochs 50
306 | args.epochs,
307 | ).cuda()
308 |
309 | loss_vit_thead_fn = DINOLoss(
310 | args.out_dim,
311 | args.local_crops_number + 2,
312 | args.warmup_teacher_temp,
313 | args.teacher_temp,
314 | args.warmup_teacher_temp_epochs_vit, # args.warmup_teacher_temp_epochs 30
315 | args.epochs,
316 | ).cuda()
317 |
318 | loss_search_ct_fn = CTSearchLoss(
319 | args.warmup_teacher_temp,
320 | args.teacher_temp,
321 | args.warmup_teacher_temp_epochs_vit,
322 | args.epochs,)
323 |
324 | ######################## preparing optimizer ... ########################
325 | params_groups_cnn = utils.get_params_groups2(student_cnn, head_name="transhead")
326 | if args.optimizer == "adamw":
327 | optimizer_cnn = torch.optim.AdamW(params_groups_cnn) # to use with ViTs
328 | elif args.optimizer == "sgd":
329 | optimizer_cnn = torch.optim.SGD(params_groups_cnn, lr=0, momentum=0.9) # lr is set by scheduler
330 | elif args.optimizer == "lars":
331 | optimizer_cnn = optimizers.LARS(params_groups_cnn) # to use with convnet and large batches
332 |
333 | params_groups_vit = utils.get_params_groups(student_vit)
334 | optimizer_vit = torch.optim.AdamW(params_groups_vit) # to use with ViTs
335 |
336 | # for mixed precision training
337 | fp16_scaler_cnn = None
338 | fp16_scaler_vit = None
339 | if args.use_fp16:
340 | fp16_scaler_cnn = torch.cuda.amp.GradScaler()
341 | fp16_scaler_vit = torch.cuda.amp.GradScaler()
342 |
343 | ######################## init schedulers ... ########################
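    # Base learning rates follow the linear scaling rule lr * total_batch_size / 256;
    # e.g., with 8 GPUs at 64 images each, lr_cnn becomes 0.3 * 512 / 256 = 0.6.
    # Note that several schedule endpoints below are hard-coded literals (see the
    # inline "# args.*" comments) rather than read from the parsed arguments.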
344 | lr_schedule_cnn = utils.cosine_scheduler(
345 | args.lr_cnn * (args.batch_size_per_gpu * utils.get_world_size()) / 256.,
346 | 0.0048, # args.min_lr
347 | args.epochs, len(data_loader),
348 | warmup_epochs=args.warmup_epochs,
349 | )
350 | wd_schedule_cnn = utils.cosine_scheduler(
351 | 1e-4, # args.weight_decay
352 | 1e-4, # args.weight_decay_end
353 | args.epochs, len(data_loader),
354 | )
355 |
356 | lr_schedule_vit = utils.cosine_scheduler(
357 | args.lr_vit * (args.batch_size_per_gpu * utils.get_world_size()) / 256.,
358 | 1e-5, # args.min_lr
359 | args.epochs, len(data_loader),
360 | warmup_epochs=args.warmup_epochs,
361 | )
362 | wd_schedule_vit = utils.cosine_scheduler(
363 | 0.04, # args.weight_decay
364 | 0.4, # args.weight_decay_end
365 | args.epochs, len(data_loader),
366 | )
367 |
368 | # momentum parameter is increased to 1. during training with a cosine schedule
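    # Assuming utils.cosine_scheduler follows the usual DINO convention, the teacher
    # momentum at step t of T total steps is m(t) = 1 + 0.5 * (m0 - 1) * (1 + cos(pi * t / T)),
    # starting at m0 = args.momentum_teacher (0.996) and reaching 1.0 at the end of training.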
369 | momentum_schedule = utils.cosine_scheduler(args.momentum_teacher, 1, args.epochs, len(data_loader))
370 |     print("Loss, optimizer and schedulers ready.")
371 |
372 | summary_writer = SummaryWriter(log_dir=os.path.join(args.output_dir,
373 | "tb", "{}_{}.{}_pretrain_{}".format(method, args.arch_cnn, args.arch_vit, args.experiment))) if args.rank == 0 else None
374 |
375 | to_restore = {"epoch": 0}
376 | checkpoint_io.restart_from_checkpoint(
377 | os.path.join(args.output_dir, "{}_{}.{}_pretrain_{}_temp.pth".format(method, args.arch_cnn,
378 | args.arch_vit, args.experiment)),
379 | run_variables=to_restore,
380 | student_cnn=student_cnn,
381 | teacher_cnn=teacher_cnn,
382 | student_vit=student_vit,
383 | teacher_vit=teacher_vit,
384 | optimizer_cnn=optimizer_cnn,
385 | optimizer_vit=optimizer_vit,
386 | fp16_scaler_cnn=fp16_scaler_cnn,
387 | fp16_scaler_vit=fp16_scaler_vit,
388 | loss_cnn_fn=loss_cnn_fn,
389 | loss_vit_fn=loss_vit_fn,
390 | loss_cnn_ct_fn=loss_cnn_ct_fn,
391 | loss_vit_ct_fn=loss_vit_ct_fn,
392 | loss_cnn_thead_fn=loss_cnn_thead_fn,
393 | loss_vit_thead_fn=loss_vit_thead_fn,
394 | )
395 | start_epoch = to_restore["epoch"]
396 |
397 | ######################## start training ########################
398 | start_time = time.time()
399 |     print("Starting {} training!".format(method))
400 | for epoch in range(start_epoch, args.epochs):
401 | data_loader.sampler.set_epoch(epoch)
402 |
403 |         ######################## training one epoch of MOKD ... ########################
404 | train_stats = train_one_epoch(
405 | student_cnn, teacher_cnn, teacher_cnn_without_ddp,
406 | student_vit, teacher_vit, teacher_vit_without_ddp,
407 | loss_cnn_fn, loss_vit_fn, loss_cnn_ct_fn, loss_vit_ct_fn, loss_cnn_thead_fn, loss_vit_thead_fn, loss_search_ct_fn,
408 | data_loader, optimizer_cnn, lr_schedule_cnn, wd_schedule_cnn,
409 | optimizer_vit, lr_schedule_vit, wd_schedule_vit,
410 | momentum_schedule, epoch,
411 | fp16_scaler_cnn, fp16_scaler_vit,
412 | summary_writer, args)
413 |
414 |         ######################## writing logs ... ########################
415 | save_dict = {
416 | 'student_cnn': student_cnn.state_dict(),
417 | 'teacher_cnn': teacher_cnn.state_dict(),
418 | 'student_vit': student_vit.state_dict(),
419 | 'teacher_vit': teacher_vit.state_dict(),
420 | 'optimizer_cnn': optimizer_cnn.state_dict(),
421 | 'optimizer_vit': optimizer_vit.state_dict(),
422 | 'epoch': epoch + 1,
423 | 'arch_cnn': args.arch_cnn,
424 | 'arch_vit': args.arch_vit,
425 | 'loss_cnn_fn': loss_cnn_fn.state_dict(),
426 | 'loss_vit_fn': loss_vit_fn.state_dict(),
427 | 'loss_cnn_ct_fn': loss_cnn_ct_fn.state_dict(),
428 | 'loss_vit_ct_fn': loss_vit_ct_fn.state_dict(),
429 | 'loss_cnn_thead_fn': loss_cnn_thead_fn.state_dict(),
430 | 'loss_vit_thead_fn': loss_vit_thead_fn.state_dict(),
431 | }
432 | if fp16_scaler_cnn is not None:
433 | save_dict['fp16_scaler_cnn'] = fp16_scaler_cnn.state_dict()
434 | save_dict['fp16_scaler_vit'] = fp16_scaler_vit.state_dict()
435 |
436 | utils.save_on_master(save_dict, os.path.join(args.output_dir,
437 | "{}_{}.{}_pretrain_{}_temp.pth".format(method, args.arch_cnn, args.arch_vit, args.experiment)))
438 |
439 | if (args.saveckp_freq and epoch % args.saveckp_freq == 0) or epoch == args.epochs - 1:
440 | utils.save_on_master(save_dict, os.path.join(args.output_dir,
441 | "{}_{}.{}_pretrain_{}_{:04d}.pth".format(method, args.arch_cnn, args.arch_vit, args.experiment, epoch)))
442 |
443 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
444 | 'epoch': epoch}
445 |
446 | if utils.is_main_process():
447 | with (Path(args.output_dir) / "{}_{}.{}_pretrain_{}_log.txt".format(method, args.arch_cnn, args.arch_vit, args.experiment)).open("a") as f:
448 | f.write(json.dumps(log_stats) + "\n")
449 |
450 | if args.rank == 0:
451 | summary_writer.close()
452 | total_time = time.time() - start_time
453 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
454 | print('Training time {}'.format(total_time_str))
455 |
456 | def train_one_epoch(student_cnn, teacher_cnn, teacher_cnn_without_ddp,
457 | student_vit, teacher_vit, teacher_vit_without_ddp,
458 | loss_cnn_fn, loss_vit_fn, loss_cnn_ct_fn, loss_vit_ct_fn, loss_cnn_thead_fn, loss_vit_thead_fn, loss_search_ct_fn,
459 | data_loader, optimizer_cnn, lr_schedule_cnn, wd_schedule_cnn,
460 | optimizer_vit, lr_schedule_vit, wd_schedule_vit,
461 | momentum_schedule, epoch,
462 | fp16_scaler_cnn, fp16_scaler_vit,
463 | summary_writer, args):
464 | metric_logger = utils.MetricLogger(delimiter=" ")
465 | header = 'Epoch: [{}/{}]'.format(epoch, args.epochs)
466 | iters_per_epoch = len(data_loader)
467 | for it, (images, _) in enumerate(metric_logger.log_every(data_loader, 10, header)):
468 | # update weight decay and learning rate
469 | it = len(data_loader) * epoch + it
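        # The per-group overrides below assume the layout produced by
        # utils.get_params_groups2(..., head_name="transhead"): group 0 holds the
        # regularized CNN parameters, group 1 the non-regularized ones (biases/norms),
        # and group 2 the trans-head parameters, which follow the ViT schedules.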
470 | for i, param_group in enumerate(optimizer_cnn.param_groups):
471 | param_group["lr"] = lr_schedule_cnn[it]
472 | if i == 0: # only the first group is regularized
473 | param_group["weight_decay"] = wd_schedule_cnn[it]
474 | if i == 2: # transhead
475 | param_group["lr"] = lr_schedule_vit[it]
476 | param_group["weight_decay"] = wd_schedule_vit[it]
477 |
478 | for i, param_group in enumerate(optimizer_vit.param_groups):
479 | param_group["lr"] = lr_schedule_vit[it]
480 | if i == 0: # only the first group is regularized
481 | param_group["weight_decay"] = wd_schedule_vit[it]
482 |
483 | # move images to gpu
484 | images = [im.cuda(non_blocking=True) for im in images]
485 |
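        # images[:2] are the two global crops and images[2:] the local crops;
        # global and local student outputs are concatenated along the batch dimension.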
486 | # student cnn
487 | with torch.cuda.amp.autocast(fp16_scaler_cnn is not None):
488 | s_cnn_m, s_cnn_t_g, s_cnn_out_g = student_cnn(images[:2])
489 | s_cnn_m_l, s_cnn_t_l, s_cnn_out_l = student_cnn(images[2:])
490 | s_cnn_m = torch.cat([s_cnn_m, s_cnn_m_l], dim=0)
491 | s_cnn_t = torch.cat([s_cnn_t_g, s_cnn_t_l], dim=0)
492 |
493 | # student vit
494 | with torch.cuda.amp.autocast(fp16_scaler_vit is not None):
495 | s_vit_m, s_vit_t_g, s_vit_out_g = student_vit(images[:2])
496 | s_vit_m_l, s_vit_t_l, s_vit_out_l = student_vit(images[2:])
497 | s_vit_m = torch.cat([s_vit_m, s_vit_m_l], dim=0)
498 | s_vit_t = torch.cat([s_vit_t_g, s_vit_t_l], dim=0)
499 |
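        # Each teacher sees only the two global crops; the detached local-crop outputs
        # of the other branch's student are passed in as local_token, presumably used
        # as query tokens by that teacher's search head.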
500 | # teacher cnn
501 | with torch.cuda.amp.autocast(fp16_scaler_cnn is not None):
502 | t_cnn_m, t_cnn_t, t_cnn_search_l, _ = teacher_cnn(images[:2], local_token=s_vit_out_l.detach())
503 |
504 | # teacher vit
505 | with torch.cuda.amp.autocast(fp16_scaler_vit is not None):
506 | t_vit_m, t_vit_t, t_vit_search_l, _ = teacher_vit(images[:2], local_token=s_cnn_out_l.detach())
507 |
508 | # loss
509 | with torch.cuda.amp.autocast(fp16_scaler_cnn is not None):
510 | loss_cnn_m = loss_cnn_fn(s_cnn_m, t_cnn_m, epoch)
511 | loss_cnn_t = loss_cnn_thead_fn(s_cnn_t, t_cnn_t, epoch)
512 | loss_ct_cnn = loss_cnn_ct_fn(s_cnn_m, t_vit_m.detach(), epoch)
513 | loss_ct_search_cnn = loss_search_ct_fn(s_cnn_t_l, t_vit_search_l.detach(), epoch)
514 |
515 | loss_cnn_total = loss_cnn_m + loss_cnn_t + args.lamda_c * (loss_ct_cnn + loss_ct_search_cnn)
516 |
517 | with torch.cuda.amp.autocast(fp16_scaler_vit is not None):
518 | loss_vit_m = loss_vit_fn(s_vit_m, t_vit_m, epoch)
519 | loss_vit_t = loss_vit_thead_fn(s_vit_t, t_vit_t, epoch)
520 | loss_ct_vit = loss_vit_ct_fn(s_vit_m, t_cnn_m.detach(), epoch)
521 | loss_ct_search_vit = loss_search_ct_fn(s_vit_t_l, t_cnn_search_l.detach(), epoch)
522 |
523 | loss_vit_total = loss_vit_m + loss_vit_t + args.lamda_t * (loss_ct_vit + loss_ct_search_vit)
524 |
525 | if not math.isfinite(loss_cnn_total.item()):
526 | print("Loss is {}, stopping training".format(loss_cnn_total.item()), force=True)
527 | sys.exit(1)
528 | if not math.isfinite(loss_vit_total.item()):
529 | print("Loss is {}, stopping training".format(loss_vit_total.item()), force=True)
530 | sys.exit(1)
531 |
532 | optimizer_cnn.zero_grad()
533 | param_norms = None
534 | if fp16_scaler_cnn is None:
535 | loss_cnn_total.backward()
536 | if args.clip_grad_cnn:
537 | param_norms = utils.clip_gradients(student_cnn, args.clip_grad_cnn)
538 | utils.cancel_gradients_last_layer(epoch, student_cnn, args.freeze_last_layer)
539 | optimizer_cnn.step()
540 | else:
541 | fp16_scaler_cnn.scale(loss_cnn_total).backward() # retain_graph=True
542 | if args.clip_grad_cnn:
543 | fp16_scaler_cnn.unscale_(optimizer_cnn) # unscale the gradients of optimizer's assigned params in-place
544 | param_norms = utils.clip_gradients(student_cnn, args.clip_grad_cnn)
545 | utils.cancel_gradients_last_layer(epoch, student_cnn, args.freeze_last_layer)
546 | fp16_scaler_cnn.step(optimizer_cnn)
547 | fp16_scaler_cnn.update()
548 |
549 | # EMA update for the cnn teacher
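        # param_k <- m * param_k + (1 - m) * param_q, with m annealed towards 1 over training.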
550 | with torch.no_grad():
551 | m = momentum_schedule[it]
552 | for param_q, param_k in zip(student_cnn.module.parameters(), teacher_cnn_without_ddp.parameters()):
553 | param_k.data.mul_(m).add_((1 - m) * param_q.detach().data)
554 |
555 | optimizer_vit.zero_grad()
556 | param_norms = None
557 | if fp16_scaler_vit is None:
558 | loss_vit_total.backward()
559 | if args.clip_grad_vit:
560 | param_norms = utils.clip_gradients(student_vit, args.clip_grad_vit)
561 | utils.cancel_gradients_last_layer(epoch, student_vit, args.freeze_last_layer)
562 | optimizer_vit.step()
563 | else:
564 | fp16_scaler_vit.scale(loss_vit_total).backward()
565 | if args.clip_grad_vit:
566 | fp16_scaler_vit.unscale_(optimizer_vit) # unscale the gradients of optimizer's assigned params in-place
567 | param_norms = utils.clip_gradients(student_vit, args.clip_grad_vit)
568 | utils.cancel_gradients_last_layer(epoch, student_vit, args.freeze_last_layer)
569 | fp16_scaler_vit.step(optimizer_vit)
570 | fp16_scaler_vit.update()
571 |
572 |         # EMA update for the vit teacher
573 | with torch.no_grad():
574 | m = momentum_schedule[it]
575 | for param_q, param_k in zip(student_vit.module.parameters(), teacher_vit_without_ddp.parameters()):
576 | param_k.data.mul_(m).add_((1 - m) * param_q.detach().data)
577 |
578 | # logging
579 | if args.rank == 0:
580 | summary_writer.add_scalar("loss_cnn", loss_cnn_total.item(), it)
581 | summary_writer.add_scalar("loss_vit", loss_vit_total.item(), it)
582 | summary_writer.add_scalar("lr_cnn", optimizer_cnn.param_groups[0]["lr"], it)
583 | summary_writer.add_scalar("lr_vit", optimizer_vit.param_groups[0]["lr"], it)
584 |
585 | torch.cuda.synchronize()
586 | metric_logger.update(loss_cnn=loss_cnn_m.item())
587 | metric_logger.update(loss_vit=loss_vit_m.item())
588 |
589 | metric_logger.synchronize_between_processes()
590 | print("Averaged stats:", metric_logger)
591 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
592 |
593 | if __name__ == '__main__':
594 | parser = argparse.ArgumentParser(method, parents=[get_args_parser()])
595 | args = parser.parse_args()
596 |
597 | Path(args.output_dir).mkdir(parents=True, exist_ok=True)
598 |
599 | train(args)
600 |
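# Example launch (hypothetical paths and settings; adjust to your environment),
# e.g. on a single node with 8 GPUs via torch.distributed:
#   python -m torch.distributed.launch --nproc_per_node=8 main_mokd.py \
#       --data_path /path/to/imagenet/train/ --output_dir ./output \
#       --arch_cnn resnet50 --arch_vit vit_small --batch_size_per_gpu 64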
--------------------------------------------------------------------------------