├── src
│   └── method.png
├── augmentations
│   └── dino_augmentation.py
├── utils
│   ├── preprocess.py
│   ├── optimizers.py
│   ├── metrics.py
│   ├── checkpoint_io.py
│   └── utils.py
├── README.md
├── backbones
│   ├── resnet.py
│   └── vision_transformer.py
├── models
│   └── mokd.py
├── LICENSE
├── eval_knn.py
├── eval_linear.py
└── main_mokd.py
/src/method.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/skyoux/mokd/HEAD/src/method.png -------------------------------------------------------------------------------- /augmentations/dino_augmentation.py: -------------------------------------------------------------------------------- 1 | 2 | from torchvision import transforms 3 | from PIL import Image 4 | 5 | from utils.preprocess import GaussianBlur, Solarization 6 | 7 | 8 | class DINODataAugmentation(object): 9 | def __init__(self, global_crops_scale, local_crops_scale, local_crops_number): 10 | flip_and_color_jitter = transforms.Compose([ 11 | transforms.RandomHorizontalFlip(p=0.5), 12 | transforms.RandomApply( 13 | [transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)], 14 | p=0.8 15 | ), 16 | transforms.RandomGrayscale(p=0.2), 17 | ]) 18 | normalize = transforms.Compose([ 19 | transforms.ToTensor(), 20 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 21 | ]) 22 | 23 | # first global crop 24 | self.global_transfo1 = transforms.Compose([ 25 | transforms.RandomResizedCrop(224, scale=global_crops_scale, interpolation=Image.BICUBIC), 26 | flip_and_color_jitter, 27 | GaussianBlur(1.0), 28 | normalize, 29 | ]) 30 | # second global crop 31 | self.global_transfo2 = transforms.Compose([ 32 | transforms.RandomResizedCrop(224, scale=global_crops_scale, interpolation=Image.BICUBIC), 33 | flip_and_color_jitter, 34 | GaussianBlur(0.1), 35 | Solarization(0.2), 36 | normalize, 37 | ]) 38 | # transformation for the local small crops 39 | self.local_crops_number = local_crops_number 40 | self.local_transfo = transforms.Compose([ 41 | transforms.RandomResizedCrop(96, scale=local_crops_scale, interpolation=Image.BICUBIC), 42 | flip_and_color_jitter, 43 | GaussianBlur(p=0.5), 44 | normalize, 45 | ]) 46 | 47 | def __call__(self, image): 48 | crops = [] 49 | crops.append(self.global_transfo1(image)) 50 | crops.append(self.global_transfo2(image)) 51 | for _ in range(self.local_crops_number): 52 | crops.append(self.local_transfo(image)) 53 | return crops -------------------------------------------------------------------------------- /utils/preprocess.py: -------------------------------------------------------------------------------- 1 | 2 | import random 3 | from PIL import ImageFilter, ImageOps 4 | import numpy as np 5 | from numpy.random import randint 6 | 7 | import torch 8 | 9 | class GaussianBlur(object): 10 | """ 11 | Apply Gaussian Blur to the PIL image. 12 | """ 13 | def __init__(self, p=0.5, radius_min=0.1, radius_max=2.): 14 | self.prob = p 15 | self.radius_min = radius_min 16 | self.radius_max = radius_max 17 | 18 | def __call__(self, img): 19 | do_it = random.random() <= self.prob 20 | if not do_it: 21 | return img 22 | 23 | return img.filter( 24 | ImageFilter.GaussianBlur( 25 | radius=random.uniform(self.radius_min, self.radius_max) 26 | ) 27 | ) 28 | 29 | 30 | class Solarization(object): 31 | """ 32 | Apply Solarization to the PIL image. 
33 | """ 34 | def __init__(self, p): 35 | self.p = p 36 | 37 | def __call__(self, img): 38 | if random.random() < self.p: 39 | return ImageOps.solarize(img) 40 | else: 41 | return img 42 | 43 | 44 | def drop_rand_patches(X, X_rep=None, max_drop=0.3, max_block_sz=0.25, tolr=0.05): 45 | ####################### 46 | # X_rep: replace X with patches from X_rep. If X_rep is None, replace the patches with Noise 47 | # max_drop: percentage of image to be dropped 48 | # max_block_sz: percentage of the maximum block to be dropped 49 | # tolr: minimum size of the block in terms of percentage of the image size 50 | ####################### 51 | 52 | C, H, W = X.size() 53 | n_drop_pix = np.random.uniform(0, max_drop)*H*W 54 | mx_blk_height = int(H*max_block_sz) 55 | mx_blk_width = int(W*max_block_sz) 56 | 57 | tolr = (int(tolr*H), int(tolr*W)) 58 | 59 | total_pix = 0 60 | while total_pix < n_drop_pix: 61 | 62 | # get a random block by selecting a random row, column, width, height 63 | rnd_r = randint(0, H-tolr[0]) 64 | rnd_c = randint(0, W-tolr[1]) 65 | rnd_h = min(randint(tolr[0], mx_blk_height)+rnd_r, H) #rnd_r is alread added - this is not height anymore 66 | rnd_w = min(randint(tolr[1], mx_blk_width)+rnd_c, W) 67 | 68 | if X_rep is None: 69 | X[:, rnd_r:rnd_h, rnd_c:rnd_w] = torch.empty((C, rnd_h-rnd_r, rnd_w-rnd_c), dtype=X.dtype, device='cuda').normal_() 70 | else: 71 | X[:, rnd_r:rnd_h, rnd_c:rnd_w] = X_rep[:, rnd_r:rnd_h, rnd_c:rnd_w] 72 | 73 | total_pix = total_pix + (rnd_h-rnd_r)*(rnd_w-rnd_c) 74 | 75 | return X 76 | -------------------------------------------------------------------------------- /utils/optimizers.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | 4 | 5 | class LARS(torch.optim.Optimizer): 6 | def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, eta=0.001, 7 | weight_decay_filter=None, lars_adaptation_filter=None): 8 | defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, 9 | eta=eta, weight_decay_filter=weight_decay_filter, 10 | lars_adaptation_filter=lars_adaptation_filter) 11 | super().__init__(params, defaults) 12 | 13 | @torch.no_grad() 14 | def step(self): 15 | for g in self.param_groups: 16 | for p in g['params']: 17 | dp = p.grad 18 | 19 | if dp is None: 20 | continue 21 | 22 | if p.ndim != 1: 23 | dp = dp.add(p, alpha=g['weight_decay']) 24 | 25 | if p.ndim != 1: 26 | param_norm = torch.norm(p) 27 | update_norm = torch.norm(dp) 28 | one = torch.ones_like(param_norm) 29 | q = torch.where(param_norm > 0., 30 | torch.where(update_norm > 0, 31 | (g['eta'] * param_norm / update_norm), one), one) 32 | dp = dp.mul(q) 33 | 34 | param_state = self.state[p] 35 | if 'mu' not in param_state: 36 | param_state['mu'] = torch.zeros_like(p) 37 | mu = param_state['mu'] 38 | mu.mul_(g['momentum']).add_(dp) 39 | 40 | p.add_(mu, alpha=-g['lr']) 41 | 42 | 43 | class NativeScalerWithGradNormCount: 44 | state_dict_key = "amp_scaler" 45 | 46 | def __init__(self): 47 | self._scaler = torch.cuda.amp.GradScaler() 48 | 49 | def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True): 50 | self._scaler.scale(loss).backward(create_graph=create_graph) 51 | if update_grad: 52 | if clip_grad is not None: 53 | assert parameters is not None 54 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place 55 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) 56 | else: 57 | self._scaler.unscale_(optimizer) 58 | norm = 
get_grad_norm_(parameters) 59 | self._scaler.step(optimizer) 60 | self._scaler.update() 61 | else: 62 | norm = None 63 | return norm 64 | 65 | def state_dict(self): 66 | return self._scaler.state_dict() 67 | 68 | def load_state_dict(self, state_dict): 69 | self._scaler.load_state_dict(state_dict) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Multi-Mode Online Knowledge Distillation for Self-Supervised Visual Representation Learning 2 | Official implementation of our paper "**Multi-Mode Online Knowledge Distillation for Self-Supervised Visual Representation Learning**", in **CVPR 2023**. 3 | 4 | by Kaiyou Song, Jin Xie, Shan Zhang and Zimeng Luo. 5 | 6 | **[[arXiv]](https://arxiv.org/abs/2304.06461)** **[[Paper]](https://arxiv.org/pdf/2304.06461.pdf)** 7 | 8 | ## Method 9 | 10 | ![MOKD method overview](src/method.png) 11 | 12 | ## Usage 13 | 14 | ### ImageNet Pre-training 15 | 16 | This implementation supports **DistributedDataParallel** training; single-GPU or DataParallel training is not supported. 17 | 18 | To pre-train a ResNet50-ViT-S model pair on ImageNet on 16 GPUs (2 nodes with 8 GPUs each), run: 19 | 20 | ``` 21 | python3 -m torch.distributed.launch --nproc_per_node=8 \ 22 | --nnodes 2 --node_rank 0 --master_addr='100.123.45.67' --master_port='10001' \ 23 | main_mokd.py \ 24 | --arch_cnn resnet50 --arch_vit vit_small \ 25 | --out_dim 65536 --norm_last_layer False \ 26 | --clip_grad_cnn 3 --clip_grad_vit 3 --freeze_last_layer 1 \ 27 | --optimizer sgd \ 28 | --lr_cnn 0.1 --lr_vit 0.0003 --warmup_epochs 10 \ 29 | --use_fp16 True \ 30 | --warmup_teacher_temp 0.04 --teacher_temp 0.07 \ 31 | --warmup_teacher_temp_epochs_cnn 50 --warmup_teacher_temp_epochs_vit 30 \ 32 | --patch_size 16 --drop_path_rate 0.1 \ 33 | --local_crops_number 8 --global_crops_scale 0.25 1 --local_crops_scale 0.05 0.25 \ 34 | --momentum_teacher 0.996 \ 35 | --num_workers 10 \ 36 | --batch_size_per_gpu 16 --epochs 100 \ 37 | --lamda_t 0.1 --lamda_c 1.0 \ 38 | --data_path /path/to/imagenet/ \ 39 | --output_dir output/ 40 | ``` 41 | 42 | ### ImageNet Linear Classification 43 | 44 | With a pre-trained model, to train a supervised linear classifier on frozen features/weights on an 8-GPU machine, run: 45 | 46 | ``` 47 | python3 -m torch.distributed.launch --nproc_per_node=8 eval_linear.py \ 48 | --arch resnet50 \ 49 | --lr 0.01 \ 50 | --batch_size_per_gpu 256 \ 51 | --num_workers 10 \ 52 | --pretrained_weights /path/to/pretrained_checkpoints/xxx.pth \ 53 | --checkpoint_key teacher_cnn \ 54 | --data_path /path/to/imagenet/ \ 55 | --output_dir output/ \ 56 | --method mokd 57 | ``` 58 | 59 | ``` 60 | python3 -m torch.distributed.launch --nproc_per_node=8 eval_linear.py \ 61 | --arch vit_small \ 62 | --n_last_blocks 4 \ 63 | --lr 0.001 \ 64 | --batch_size_per_gpu 256 \ 65 | --pretrained_weights /path/to/pretrained_checkpoints/xxx.pth \ 66 | --checkpoint_key teacher_vit \ 67 | --data_path /path/to/imagenet/ \ 68 | --output_dir output/ \ 69 | --method mokd 70 | ``` 71 | 72 | ### Evaluation: k-NN classification on ImageNet 73 | To evaluate a k-NN classifier on a pre-trained model (the command below uses 8 GPUs), run: 74 | 75 | ``` 76 | python3 -m torch.distributed.launch --nproc_per_node=8 eval_knn.py \ 77 | --arch resnet18 \ 78 | --batch_size_per_gpu 512 \ 79 | --pretrained_weights /path/to/pretrained_checkpoints/xxx.pth \ 80 | --checkpoint_key teacher_cnn \ 81 | --num_workers 20 \ 82 | --data_path /path/to/imagenet/ \ 83 | --use_cuda True \ 84 | 
--method mokd 85 | ``` 86 | 87 | ## Acknowledgement 88 | This project is based on [DINO](https://github.com/facebookresearch/dino). 89 | Thanks for the wonderful work. 90 | 91 | ## License 92 | This project is released under the Apache License 2.0. See [LICENSE](LICENSE) for details. 93 | 94 | ## Citation 95 | ```bibtex 96 | @InProceedings{Song_2023_CVPR, 97 | author = {Song, Kaiyou and Xie, Jin and Zhang, Shan and Luo, Zimeng}, 98 | title = {Multi-Mode Online Knowledge Distillation for Self-Supervised Visual Representation Learning}, 99 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 100 | month = {June}, 101 | year = {2023}, 102 | pages = {11848-11857} 103 | } 104 | ``` 105 | -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | 4 | 5 | def accuracy(output, target, topk=(1,)): 6 | """Computes the accuracy over the k top predictions for the specified values of k""" 7 | maxk = max(topk) 8 | batch_size = target.size(0) 9 | _, pred = output.topk(maxk, 1, True, True) 10 | pred = pred.t() 11 | correct = pred.eq(target.reshape(1, -1).expand_as(pred)) 12 | return [correct[:k].reshape(-1).float().sum(0) * 100. / batch_size for k in topk] 13 | 14 | 15 | def compute_ap(ranks, nres): 16 | """ 17 | Computes average precision for given ranked indexes. 18 | Arguments 19 | --------- 20 | ranks : zero-based ranks of positive images 21 | nres : number of positive images 22 | Returns 23 | ------- 24 | ap : average precision 25 | """ 26 | 27 | # number of images ranked by the system 28 | nimgranks = len(ranks) 29 | 30 | # accumulate trapezoids in PR-plot 31 | ap = 0 32 | 33 | recall_step = 1. / nres 34 | 35 | for j in np.arange(nimgranks): 36 | rank = ranks[j] 37 | 38 | if rank == 0: 39 | precision_0 = 1. 40 | else: 41 | precision_0 = float(j) / rank 42 | 43 | precision_1 = float(j + 1) / (rank + 1) 44 | 45 | ap += (precision_0 + precision_1) * recall_step / 2. 46 | 47 | return ap 48 | 49 | 50 | def compute_map(ranks, gnd, kappas=[]): 51 | """ 52 | Computes the mAP for a given set of returned results. 53 | Usage: 54 | map = compute_map (ranks, gnd) 55 | computes mean average precision (map) only 56 | map, aps, pr, prs = compute_map (ranks, gnd, kappas) 57 | computes mean average precision (map), average precision (aps) for each query 58 | computes mean precision at kappas (pr), precision at kappas (prs) for each query 59 | Notes: 60 | 1) ranks starts from 0, ranks.shape = db_size X #queries 61 | 2) The junk results (e.g., the query itself) should be declared in the gnd struct array 62 | 3) If there are no positive images for some query, that query is excluded from the evaluation 63 | """ 64 | 65 | map = 0. 
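    # For each query below, the zero-based ranks of its positives are gathered, shifted
    # down to skip any junk entries ranked ahead of them, and passed to compute_ap.
    # Worked example: two positives retrieved at ranks 0 and 2 give
    # compute_ap([0, 2], 2) = (1 + 1)/2 * 0.5 + (0.5 + 2/3)/2 * 0.5 ≈ 0.792.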
66 | nq = len(gnd) # number of queries 67 | aps = np.zeros(nq) 68 | pr = np.zeros(len(kappas)) 69 | prs = np.zeros((nq, len(kappas))) 70 | nempty = 0 71 | 72 | for i in np.arange(nq): 73 | qgnd = np.array(gnd[i]['ok']) 74 | 75 | # no positive images, skip from the average 76 | if qgnd.shape[0] == 0: 77 | aps[i] = float('nan') 78 | prs[i, :] = float('nan') 79 | nempty += 1 80 | continue 81 | 82 | try: 83 | qgndj = np.array(gnd[i]['junk']) 84 | except: 85 | qgndj = np.empty(0) 86 | 87 | # sorted positions of positive and junk images (0 based) 88 | pos = np.arange(ranks.shape[0])[np.in1d(ranks[:,i], qgnd)] 89 | junk = np.arange(ranks.shape[0])[np.in1d(ranks[:,i], qgndj)] 90 | 91 | k = 0; 92 | ij = 0; 93 | if len(junk): 94 | # decrease positions of positives based on the number of 95 | # junk images appearing before them 96 | ip = 0 97 | while (ip < len(pos)): 98 | while (ij < len(junk) and pos[ip] > junk[ij]): 99 | k += 1 100 | ij += 1 101 | pos[ip] = pos[ip] - k 102 | ip += 1 103 | 104 | # compute ap 105 | ap = compute_ap(pos, len(qgnd)) 106 | map = map + ap 107 | aps[i] = ap 108 | 109 | # compute precision @ k 110 | pos += 1 # get it to 1-based 111 | for j in np.arange(len(kappas)): 112 | kq = min(max(pos), kappas[j]); 113 | prs[i, j] = (pos <= kq).sum() / kq 114 | pr = pr + prs[i, :] 115 | 116 | map = map / (nq - nempty) 117 | pr = pr / (nq - nempty) 118 | 119 | return map, aps, pr, prs -------------------------------------------------------------------------------- /utils/checkpoint_io.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | 5 | def load_pretrained_weights(model, pretrained_weights, checkpoint_key, method): 6 | if os.path.isfile(pretrained_weights): 7 | state_dict = torch.load(pretrained_weights, map_location="cpu") 8 | if checkpoint_key is not None and checkpoint_key in state_dict: 9 | print(f"Take key {checkpoint_key} in provided checkpoint dict") 10 | print("epoch:", state_dict["epoch"]) 11 | state_dict = state_dict[checkpoint_key] 12 | 13 | if "mokd" in method : 14 | # remove `module.` prefix 15 | state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} 16 | # remove `backbone.` prefix induced by multicrop wrapper 17 | state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()} 18 | 19 | msg = model.load_state_dict(state_dict, strict=False) 20 | print('Pretrained weights found at {} and loaded with msg: {}'.format(pretrained_weights, msg)) 21 | return True 22 | else: 23 | print("There is no reference weights available for this model => Random weights will be used.") 24 | return False 25 | 26 | 27 | def load_state_dict(model, state_dict, prefix='', ignore_missing="relative_position_index"): 28 | missing_keys = [] 29 | unexpected_keys = [] 30 | error_msgs = [] 31 | # copy state_dict so _load_from_state_dict can modify it 32 | metadata = getattr(state_dict, '_metadata', None) 33 | state_dict = state_dict.copy() 34 | if metadata is not None: 35 | state_dict._metadata = metadata 36 | 37 | def load(module, prefix=''): 38 | local_metadata = {} if metadata is None else metadata.get( 39 | prefix[:-1], {}) 40 | module._load_from_state_dict( 41 | state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) 42 | for name, child in module._modules.items(): 43 | if child is not None: 44 | load(child, prefix + name + '.') 45 | 46 | load(model, prefix=prefix) 47 | 48 | 49 | def load_pretrained_linear_weights(linear_classifier, ckpt_path, method): 50 | if not 
os.path.isfile(ckpt_path): 51 | print("Cannot find checkpoint at {}, Use random linear weights.".format(ckpt_path)) 52 | return False 53 | print("Found checkpoint at {}".format(ckpt_path)) 54 | # open checkpoint file 55 | checkpoint = torch.load(ckpt_path, map_location="cpu") 56 | state_dict = checkpoint['state_dict'] 57 | linear_classifier.load_state_dict(state_dict, strict=True) 58 | return True 59 | 60 | 61 | def restart_from_checkpoint(ckpt_path, run_variables=None, **kwargs): 62 | """ 63 | Re-start from checkpoint 64 | """ 65 | if not os.path.isfile(ckpt_path): 66 | print("Checkpoint not founded in {}, train from random initialization".format(ckpt_path)) 67 | return 68 | print("Found checkpoint at {}".format(ckpt_path)) 69 | 70 | # open checkpoint file 71 | checkpoint = torch.load(ckpt_path, map_location="cpu") 72 | 73 | # key is what to look for in the checkpoint file 74 | # value is the object to load 75 | # example: {'state_dict': model} 76 | for key, value in kwargs.items(): 77 | if key in checkpoint and value is not None: 78 | try: 79 | msg = value.load_state_dict(checkpoint[key], strict=False) 80 | print("=> loaded '{}' from checkpoint '{}' with msg {}".format(key, ckpt_path, msg)) 81 | except TypeError: 82 | try: 83 | msg = value.load_state_dict(checkpoint[key]) 84 | print("=> loaded '{}' from checkpoint: '{}'".format(key, ckpt_path)) 85 | except ValueError: 86 | print("=> failed to load '{}' from checkpoint: '{}'".format(key, ckpt_path)) 87 | else: 88 | print("=> key '{}' not found in checkpoint: '{}'".format(key, ckpt_path)) 89 | 90 | # re load variable important for the run 91 | if run_variables is not None: 92 | for var_name in run_variables: 93 | if var_name in checkpoint: 94 | run_variables[var_name] = checkpoint[var_name] -------------------------------------------------------------------------------- /backbones/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | import torch.nn as nn 4 | from typing import Type, Any, Callable, Union, List, Optional 5 | from torchvision import models as torchvision_models 6 | from torchvision.models.resnet import BasicBlock, Bottleneck, model_urls 7 | from torch.hub import load_state_dict_from_url 8 | 9 | 10 | class ResNet(torchvision_models.ResNet): 11 | def __init__(self, block: Type[Union[BasicBlock, Bottleneck]], 12 | layers: List[int], **kwargs): 13 | super(ResNet, self).__init__(block, layers, **kwargs) 14 | 15 | def _forward_impl(self, x: Tensor) -> Tensor: 16 | # See note [TorchScript super()] 17 | x = self.conv1(x) 18 | x = self.bn1(x) 19 | x = self.relu(x) 20 | x = self.maxpool(x) 21 | 22 | x = self.layer1(x) 23 | x = self.layer2(x) 24 | x = self.layer3(x) 25 | x = self.layer4(x) 26 | 27 | # sky 28 | f4 = x 29 | 30 | x = self.avgpool(x) 31 | x = torch.flatten(x, 1) 32 | x = self.fc(x) 33 | 34 | return x, f4 35 | 36 | def _resnet( 37 | arch: str, 38 | block: Type[Union[BasicBlock, Bottleneck]], 39 | layers: List[int], 40 | pretrained: bool, 41 | progress: bool, 42 | **kwargs: Any 43 | ) -> ResNet: 44 | model = ResNet(block, layers, **kwargs) 45 | if pretrained: 46 | state_dict = load_state_dict_from_url(model_urls[arch], 47 | progress=progress) 48 | model.load_state_dict(state_dict) 49 | return model 50 | 51 | 52 | def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 53 | r"""ResNet-18 model from 54 | `"Deep Residual Learning for Image Recognition" `_. 
55 | 56 | Args: 57 | pretrained (bool): If True, returns a model pre-trained on ImageNet 58 | progress (bool): If True, displays a progress bar of the download to stderr 59 | """ 60 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 61 | **kwargs) 62 | 63 | 64 | def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 65 | r"""ResNet-34 model from 66 | `"Deep Residual Learning for Image Recognition" `_. 67 | 68 | Args: 69 | pretrained (bool): If True, returns a model pre-trained on ImageNet 70 | progress (bool): If True, displays a progress bar of the download to stderr 71 | """ 72 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 73 | **kwargs) 74 | 75 | 76 | def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 77 | r"""ResNet-50 model from 78 | `"Deep Residual Learning for Image Recognition" `_. 79 | 80 | Args: 81 | pretrained (bool): If True, returns a model pre-trained on ImageNet 82 | progress (bool): If True, displays a progress bar of the download to stderr 83 | """ 84 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 85 | **kwargs) 86 | 87 | 88 | def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 89 | r"""ResNet-101 model from 90 | `"Deep Residual Learning for Image Recognition" `_. 91 | 92 | Args: 93 | pretrained (bool): If True, returns a model pre-trained on ImageNet 94 | progress (bool): If True, displays a progress bar of the download to stderr 95 | """ 96 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, 97 | **kwargs) 98 | 99 | 100 | def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 101 | r"""ResNet-152 model from 102 | `"Deep Residual Learning for Image Recognition" `_. 103 | 104 | Args: 105 | pretrained (bool): If True, returns a model pre-trained on ImageNet 106 | progress (bool): If True, displays a progress bar of the download to stderr 107 | """ 108 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, 109 | **kwargs) 110 | 111 | 112 | def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 113 | r"""ResNeXt-50 32x4d model from 114 | `"Aggregated Residual Transformation for Deep Neural Networks" `_. 115 | 116 | Args: 117 | pretrained (bool): If True, returns a model pre-trained on ImageNet 118 | progress (bool): If True, displays a progress bar of the download to stderr 119 | """ 120 | kwargs['groups'] = 32 121 | kwargs['width_per_group'] = 4 122 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], 123 | pretrained, progress, **kwargs) 124 | 125 | 126 | def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 127 | r"""ResNeXt-101 32x8d model from 128 | `"Aggregated Residual Transformation for Deep Neural Networks" `_. 129 | 130 | Args: 131 | pretrained (bool): If True, returns a model pre-trained on ImageNet 132 | progress (bool): If True, displays a progress bar of the download to stderr 133 | """ 134 | kwargs['groups'] = 32 135 | kwargs['width_per_group'] = 8 136 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], 137 | pretrained, progress, **kwargs) 138 | 139 | 140 | def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 141 | r"""Wide ResNet-50-2 model from 142 | `"Wide Residual Networks" `_. 
143 | 144 | The model is the same as ResNet except for the bottleneck number of channels 145 | which is twice larger in every block. The number of channels in outer 1x1 146 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 147 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 148 | 149 | Args: 150 | pretrained (bool): If True, returns a model pre-trained on ImageNet 151 | progress (bool): If True, displays a progress bar of the download to stderr 152 | """ 153 | kwargs['width_per_group'] = 64 * 2 154 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], 155 | pretrained, progress, **kwargs) 156 | 157 | 158 | def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 159 | r"""Wide ResNet-101-2 model from 160 | `"Wide Residual Networks" `_. 161 | 162 | The model is the same as ResNet except for the bottleneck number of channels 163 | which is twice larger in every block. The number of channels in outer 1x1 164 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 165 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 166 | 167 | Args: 168 | pretrained (bool): If True, returns a model pre-trained on ImageNet 169 | progress (bool): If True, displays a progress bar of the download to stderr 170 | """ 171 | kwargs['width_per_group'] = 64 * 2 172 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], 173 | pretrained, progress, **kwargs) 174 | -------------------------------------------------------------------------------- /models/mokd.py: -------------------------------------------------------------------------------- 1 | from operator import is_ 2 | import torch 3 | import torch.nn as nn 4 | import torch.distributed as dist 5 | import torch.nn.functional as F 6 | import numpy as np 7 | 8 | from utils.utils import trunc_normal_ 9 | 10 | 11 | class DINOHead(nn.Module): 12 | def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, nlayers=3, hidden_dim=2048, bottleneck_dim=256): 13 | super().__init__() 14 | nlayers = max(nlayers, 1) 15 | if nlayers == 1: 16 | self.mlp = nn.Linear(in_dim, bottleneck_dim) 17 | else: 18 | layers = [nn.Linear(in_dim, hidden_dim)] 19 | if use_bn: 20 | layers.append(nn.BatchNorm1d(hidden_dim)) 21 | layers.append(nn.GELU()) 22 | for _ in range(nlayers - 2): 23 | layers.append(nn.Linear(hidden_dim, hidden_dim)) 24 | if use_bn: 25 | layers.append(nn.BatchNorm1d(hidden_dim)) 26 | layers.append(nn.GELU()) 27 | layers.append(nn.Linear(hidden_dim, bottleneck_dim)) 28 | self.mlp = nn.Sequential(*layers) 29 | self.apply(self._init_weights) 30 | self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False)) 31 | self.last_layer.weight_g.data.fill_(1) 32 | if norm_last_layer: 33 | self.last_layer.weight_g.requires_grad = False 34 | 35 | def _init_weights(self, m): 36 | if isinstance(m, nn.Linear): 37 | trunc_normal_(m.weight, std=.02) 38 | if isinstance(m, nn.Linear) and m.bias is not None: 39 | nn.init.constant_(m.bias, 0) 40 | 41 | def forward(self, x): 42 | x = self.mlp(x) 43 | x = nn.functional.normalize(x, dim=-1, p=2) 44 | x = self.last_layer(x) 45 | return x 46 | 47 | 48 | class CNNStudentWrapper(nn.Module): 49 | def __init__(self, backbone, head, transhead): 50 | super(CNNStudentWrapper, self).__init__() 51 | # disable layers dedicated to ImageNet labels classification 52 | backbone.fc, backbone.head = nn.Identity(), nn.Identity() 53 | self.backbone = backbone 54 | self.head = head 55 | self.transhead = transhead 56 | 57 | 
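    # The forward pass below concatenates the list of input views into one batch,
    # runs the CNN backbone (which returns both the pooled feature `fea` and the
    # last-stage feature map `fea4`), projects `fea` with the DINO-style head, and
    # feeds `fea4` to the transformer head to obtain the cross-architecture feature
    # together with its reduced-dimension counterpart.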
def forward(self, x): 58 | fea, fea4 = self.backbone(torch.cat(x[:])) 59 | 60 | fea_trans, fea_reduce_dim = self.transhead(fea4) 61 | 62 | return self.head(fea), fea_trans, fea_reduce_dim 63 | 64 | 65 | class CNNTeacherWrapper(nn.Module): 66 | def __init__(self, backbone, head, transhead, num_crops): 67 | super(CNNTeacherWrapper, self).__init__() 68 | # disable layers dedicated to ImageNet labels classification 69 | backbone.fc, backbone.head = nn.Identity(), nn.Identity() 70 | self.backbone = backbone 71 | self.head = head 72 | self.transhead = transhead 73 | self.num_crops = num_crops 74 | 75 | 76 | def forward(self, x, local_token=None, global_token=None): 77 | fea, fea4 = self.backbone(torch.cat(x[:])) 78 | 79 | fea4_trans, _ = self.transhead(fea4) 80 | local_search_fea = None 81 | global_search_fea = None 82 | if local_token != None: 83 | local_search_fea, _ = self.transhead(fea4.chunk(2)[0].repeat(self.num_crops, 1, 1, 1), local_token) 84 | if global_token != None: 85 | global_search_fea, _ = self.transhead(fea4, global_token) 86 | return self.head(fea), fea4_trans, local_search_fea, global_search_fea 87 | 88 | 89 | class ViTStudentWrapper(nn.Module): 90 | def __init__(self, backbone, head, transhead): 91 | super(ViTStudentWrapper, self).__init__() 92 | # disable layers dedicated to ImageNet labels classification 93 | backbone.fc, backbone.head = nn.Identity(), nn.Identity() 94 | self.backbone = backbone 95 | self.head = head 96 | self.transhead = transhead 97 | 98 | def forward(self, x): 99 | tokens = self.backbone(torch.cat(x[:]), return_all_tokens=True) 100 | if isinstance(tokens, tuple): 101 | tokens = tokens[0] 102 | 103 | return self.head(tokens[:, 0]), self.transhead(tokens[:, 1:]), tokens[:, 0] 104 | 105 | 106 | class ViTTeacherWrapper(nn.Module): 107 | def __init__(self, backbone, head, transhead, num_crops): 108 | super(ViTTeacherWrapper, self).__init__() 109 | # disable layers dedicated to ImageNet labels classification 110 | backbone.fc, backbone.head = nn.Identity(), nn.Identity() 111 | self.backbone = backbone 112 | self.head = head 113 | self.transhead = transhead 114 | self.num_crops = num_crops 115 | 116 | def forward(self, x, local_token=None, global_token=None): 117 | tokens = self.backbone(torch.cat(x[:]), return_all_tokens=True) 118 | if isinstance(tokens, tuple): 119 | tokens = tokens[0] 120 | 121 | cls_token = tokens[:, 0] 122 | pth_token = tokens[:, 1:] # B,N,C 123 | 124 | local_search_fea = None 125 | global_search_fea = None 126 | if local_token != None: 127 | local_search_fea = self.transhead(pth_token.chunk(2)[0].repeat(self.num_crops, 1, 1), local_token) 128 | if global_token != None: 129 | global_search_fea = self.transhead(pth_token, global_token) 130 | return self.head(cls_token), self.transhead(pth_token), local_search_fea, global_search_fea 131 | 132 | 133 | class DINOLoss(nn.Module): 134 | def __init__(self, out_dim, ncrops, warmup_teacher_temp, teacher_temp, 135 | warmup_teacher_temp_epochs, nepochs, student_temp=0.1, 136 | center_momentum=0.9): 137 | super().__init__() 138 | self.student_temp = student_temp 139 | self.center_momentum = center_momentum 140 | self.ncrops = ncrops 141 | self.register_buffer("center", torch.zeros(1, out_dim)) 142 | # we apply a warm up for the teacher temperature because 143 | # a too high temperature makes the training instable at the beginning 144 | self.teacher_temp_schedule = np.concatenate(( 145 | np.linspace(warmup_teacher_temp, 146 | teacher_temp, warmup_teacher_temp_epochs), 147 | np.ones(nepochs - 
warmup_teacher_temp_epochs) * teacher_temp 148 | )) 149 | 150 | def forward(self, student_output, teacher_output, epoch): 151 | """ 152 | Cross-entropy between softmax outputs of the teacher and student networks. 153 | """ 154 | student_out = student_output / self.student_temp 155 | student_out = student_out.chunk(self.ncrops) 156 | 157 | # teacher centering and sharpening 158 | temp = self.teacher_temp_schedule[epoch] 159 | teacher_out = F.softmax((teacher_output - self.center) / temp, dim=-1) 160 | teacher_out = teacher_out.detach().chunk(2) 161 | 162 | total_loss = 0 163 | n_loss_terms = 0 164 | for iq, q in enumerate(teacher_out): 165 | for v in range(len(student_out)): 166 | if v == iq: 167 | # we skip cases where student and teacher operate on the same view 168 | continue 169 | loss = torch.sum(-q * F.log_softmax(student_out[v], dim=-1), dim=-1) 170 | total_loss += loss.mean() 171 | n_loss_terms += 1 172 | total_loss /= n_loss_terms 173 | self.update_center(teacher_output) 174 | return total_loss 175 | 176 | @torch.no_grad() 177 | def update_center(self, teacher_output): 178 | """ 179 | Update center used for teacher output. 180 | """ 181 | batch_center = torch.sum(teacher_output, dim=0, keepdim=True) 182 | dist.all_reduce(batch_center) 183 | batch_center = batch_center / (len(teacher_output) * dist.get_world_size()) 184 | 185 | # ema update 186 | self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum) 187 | 188 | 189 | class CTSearchLoss(nn.Module): 190 | def __init__(self, warmup_teacher_temp, teacher_temp, 191 | warmup_teacher_temp_epochs, nepochs, student_temp=0.1): 192 | super().__init__() 193 | self.student_temp = student_temp 194 | self.teacher_temp_schedule = np.concatenate(( 195 | np.linspace(warmup_teacher_temp, 196 | teacher_temp, warmup_teacher_temp_epochs), 197 | np.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp 198 | )) 199 | 200 | def forward(self, student, teacher, epoch): 201 | 202 | temp = self.teacher_temp_schedule[epoch] 203 | 204 | student_out = student / self.student_temp 205 | teacher_out = F.softmax(teacher / temp, dim=-1) 206 | loss = torch.sum(-teacher_out * F.log_softmax(student_out, dim=-1), dim=-1).mean() 207 | 208 | return loss 209 | 210 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2023 Kaiyou Song 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /eval_knn.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import os 3 | import sys 4 | import argparse 5 | import numpy as np 6 | 7 | import torch 8 | from torch import nn 9 | import torch.distributed as dist 10 | import torch.backends.cudnn as cudnn 11 | from torchvision import datasets 12 | from torchvision import transforms as pth_transforms 13 | from torchvision import models as torchvision_models 14 | import timm.models as timm_models 15 | 16 | from utils import utils 17 | from utils import checkpoint_io 18 | import backbones.vision_transformer as vits 19 | 20 | 21 | def extract_feature_pipeline(args): 22 | ######################## preparing data ... ######################## 23 | resize_size = 256 if args.input_size == 224 else 512 24 | transform = pth_transforms.Compose([ 25 | pth_transforms.Resize(resize_size, interpolation=3), 26 | pth_transforms.CenterCrop(args.input_size), 27 | pth_transforms.ToTensor(), 28 | pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 29 | ]) 30 | 31 | dataset_train = ReturnIndexDataset(os.path.join(args.data_path, "train"), transform=transform) 32 | dataset_val = ReturnIndexDataset(os.path.join(args.data_path, "val"), transform=transform) 33 | 34 | train_labels = torch.tensor(dataset_train.targets).long() 35 | test_labels = torch.tensor(dataset_val.targets).long() 36 | 37 | sampler = torch.utils.data.DistributedSampler(dataset_train, shuffle=False) 38 | data_loader_train = torch.utils.data.DataLoader( 39 | dataset_train, 40 | sampler=sampler, 41 | batch_size=args.batch_size_per_gpu, 42 | num_workers=args.num_workers, 43 | pin_memory=False, 44 | drop_last=False, 45 | ) 46 | data_loader_val = torch.utils.data.DataLoader( 47 | dataset_val, 48 | batch_size=args.batch_size_per_gpu, 49 | num_workers=args.num_workers, 50 | pin_memory=False, 51 | drop_last=False, 52 | ) 53 | print(f"Data loaded with {len(dataset_train)} train and {len(dataset_val)} val imgs.") 54 | 55 | ######################## building network ... ######################## 56 | if "vit" in args.arch: 57 | model = vits.__dict__[args.arch](patch_size=args.patch_size, num_classes=0) 58 | # model = timm_models.__dict__[args.arch](num_classes=0) 59 | print(f"Model {args.arch} {args.patch_size}x{args.patch_size} built.") 60 | elif args.arch in torchvision_models.__dict__.keys(): 61 | model = torchvision_models.__dict__[args.arch](num_classes=0) 62 | model.fc = nn.Identity() 63 | else: 64 | print(f"Architecture {args.arch} not supported") 65 | sys.exit(1) 66 | # print(model) 67 | model.cuda() 68 | is_load_success = checkpoint_io.load_pretrained_weights(model, args.pretrained_weights, 69 | args.checkpoint_key, args.method) 70 | if not is_load_success: 71 | sys.exit(1) 72 | 73 | model.eval() 74 | 75 | ######################## extract features ... 
######################## 76 | print("Extracting features for train set...") 77 | train_features = extract_features(model, data_loader_train, args.arch, args.avgpool_patchtokens, args.use_cuda) 78 | print("Extracting features for val set...") 79 | test_features = extract_features(model, data_loader_val, args.arch, args.avgpool_patchtokens, args.use_cuda) 80 | 81 | if utils.get_rank() == 0: 82 | train_features = nn.functional.normalize(train_features, dim=1, p=2) 83 | test_features = nn.functional.normalize(test_features, dim=1, p=2) 84 | 85 | # save features and labels 86 | if args.dump_features and dist.get_rank() == 0: 87 | torch.save(train_features.cpu(), os.path.join(args.dump_features, "trainfeat.pth")) 88 | torch.save(test_features.cpu(), os.path.join(args.dump_features, "testfeat.pth")) 89 | torch.save(train_labels.cpu(), os.path.join(args.dump_features, "trainlabels.pth")) 90 | torch.save(test_labels.cpu(), os.path.join(args.dump_features, "testlabels.pth")) 91 | return train_features, test_features, train_labels, test_labels 92 | 93 | 94 | @torch.no_grad() 95 | def extract_features(model, data_loader, arch="resnet50", avgpool_patchtokens=1, use_cuda=True, multiscale=False): 96 | metric_logger = utils.MetricLogger(delimiter=" ") 97 | features = None 98 | for samples, index in metric_logger.log_every(data_loader, 10): 99 | samples = samples.cuda(non_blocking=True) 100 | index = index.cuda(non_blocking=True) 101 | if multiscale: 102 | feats = utils.multi_scale(samples, model) 103 | else: 104 | if "resnet" in arch: 105 | feats = model(samples).clone() 106 | else: # vits 107 | output = model(samples, return_all_tokens=True) # cl_mode: class_token patch_token 108 | if isinstance(output, tuple): 109 | output = output[0] 110 | output = output.clone() 111 | if avgpool_patchtokens == 0: 112 | # norm(x[:, 0]) 113 | feats = output[:, 0].contiguous() 114 | elif avgpool_patchtokens == 1: 115 | # x[:, 1:].mean(1) 116 | feats = torch.mean(output[:, 1:], dim=1) 117 | elif avgpool_patchtokens == 2: 118 | # norm(x[:, 0]) + norm(x[:, 1:]).mean(1) 119 | feats = output[:, 0] + torch.mean(output[:, 1:], dim=1) 120 | else: 121 | assert False, "Unkown avgpool type {}".format(avgpool_patchtokens) 122 | 123 | # print(feats.shape) 124 | if len(feats.shape) != 2: 125 | feats = feats.squeeze() 126 | 127 | # init storage feature matrix 128 | if dist.get_rank() == 0 and features is None: 129 | features = torch.zeros(len(data_loader.dataset), feats.shape[-1]) 130 | if use_cuda: 131 | features = features.cuda(non_blocking=True) 132 | print(f"Storing features into tensor of shape {features.shape}") 133 | 134 | # get indexes from all processes 135 | y_all = torch.empty(dist.get_world_size(), index.size(0), dtype=index.dtype, device=index.device) 136 | y_l = list(y_all.unbind(0)) 137 | y_all_reduce = torch.distributed.all_gather(y_l, index, async_op=True) 138 | y_all_reduce.wait() 139 | index_all = torch.cat(y_l) 140 | 141 | # share features between processes 142 | feats_all = torch.empty( 143 | dist.get_world_size(), 144 | feats.size(0), 145 | feats.size(1), 146 | dtype=feats.dtype, 147 | device=feats.device, 148 | ) 149 | output_l = list(feats_all.unbind(0)) 150 | output_all_reduce = torch.distributed.all_gather(output_l, feats, async_op=True) 151 | output_all_reduce.wait() 152 | 153 | # update storage feature matrix 154 | if dist.get_rank() == 0: 155 | if use_cuda: 156 | features.index_copy_(0, index_all, torch.cat(output_l)) 157 | else: 158 | features.index_copy_(0, index_all.cpu(), torch.cat(output_l).cpu()) 159 
| return features 160 | 161 | 162 | @torch.no_grad() 163 | def knn_classifier(train_features, train_labels, test_features, test_labels, k, T, num_classes=1000, use_cuda=True): 164 | top1, top5, total = 0.0, 0.0, 0 165 | train_features = train_features.t() 166 | num_test_images, num_chunks = test_labels.shape[0], 100 167 | imgs_per_chunk = num_test_images // num_chunks 168 | retrieval_one_hot = torch.zeros(k, num_classes) 169 | if use_cuda: 170 | retrieval_one_hot = retrieval_one_hot.cuda() 171 | for idx in range(0, num_test_images, imgs_per_chunk): 172 | # get the features for test images 173 | features = test_features[ 174 | idx : min((idx + imgs_per_chunk), num_test_images), : 175 | ] 176 | targets = test_labels[idx : min((idx + imgs_per_chunk), num_test_images)] 177 | batch_size = targets.shape[0] 178 | 179 | # calculate the dot product and compute top-k neighbors 180 | similarity = torch.mm(features, train_features) 181 | distances, indices = similarity.topk(k, largest=True, sorted=True) 182 | candidates = train_labels.view(1, -1).expand(batch_size, -1) #500x1281167 183 | retrieved_neighbors = torch.gather(candidates, 1, indices) # 500x10 184 | retrieval_one_hot.resize_(batch_size * k, num_classes).zero_() #5000x0 185 | retrieval_one_hot.scatter_(1, retrieved_neighbors.view(-1, 1), 1) 186 | distances_transform = distances.clone().div_(T).exp_() 187 | probs = torch.sum( 188 | torch.mul( 189 | retrieval_one_hot.view(batch_size, -1, num_classes), 190 | distances_transform.view(batch_size, -1, 1), 191 | ), 192 | 1, 193 | ) 194 | _, predictions = probs.sort(1, True) 195 | 196 | # find the predictions that match the target 197 | correct = predictions.eq(targets.data.view(-1, 1)) 198 | top1 = top1 + correct.narrow(1, 0, 1).sum().item() 199 | top5 = top5 + correct.narrow(1, 0, min(5, k)).sum().item() # top5 does not make sense if k < 5 200 | total += targets.size(0) 201 | top1 = top1 * 100.0 / total 202 | top5 = top5 * 100.0 / total 203 | return top1, top5 204 | 205 | 206 | # class ReturnIndexDataset_nori(ImagenetLoader):  # disabled: `ImagenetLoader` is an internal data loader that is not available in this repo 207 | #     def __getitem__(self, idx): 208 | #         img, lab = super(ReturnIndexDataset_nori, self).__getitem__(idx) 209 | #         return img, idx 210 | 211 | class ReturnIndexDataset(datasets.ImageFolder): 212 | def __getitem__(self, idx): 213 | img, lab = super(ReturnIndexDataset, self).__getitem__(idx) 214 | return img, idx 215 | 216 | def get_args_parser(): 217 | parser = argparse.ArgumentParser("KNN Evaluation", add_help=False) 218 | parser.add_argument('--input_size', default=224, type=int, help='input image size') 219 | parser.add_argument('--batch_size_per_gpu', default=128, type=int, help='Per-GPU batch-size') 220 | parser.add_argument('--nb_knn', default=[20, 10, 30], nargs='+', type=int, 221 | help='Number of NN to use. 20 usually works best.') 222 | parser.add_argument('--num_labels', default=1000, type=int, help='Number of labels for linear classifier') 223 | parser.add_argument('--temperature', default=0.07, type=float, 224 | help='Temperature used in the voting coefficient') 225 | parser.add_argument('--pretrained_weights', default='', type=str, help="Path to pretrained weights to evaluate.") 226 | parser.add_argument('--use_cuda', default=False, type=utils.bool_flag, 227 | help="Should we store the features on GPU? 
We recommend setting this to False if you encounter OOM") 228 | parser.add_argument('--arch', default='vit_small', type=str, help='Architecture') 229 | parser.add_argument("--checkpoint_key", default="teacher", type=str, 230 | help='Key to use in the checkpoint (example: "teacher")') 231 | # for ViTs 232 | parser.add_argument('--patch_size', default=16, type=int, help='Patch resolution of the model.') 233 | parser.add_argument('--avgpool_patchtokens', default=0, choices=[0, 1, 2], type=int, 234 | help="""Whether or not to use global average pooled features or the [CLS] token. 235 | We typically set this to 1 for BEiT and 0 for models with [CLS] token (e.g., DINO). 236 | we set this to 2 for base-size models with [CLS] token when doing linear classification.""") 237 | 238 | parser.add_argument('--dump_features', default=None, 239 | help='Path where to save computed features, empty for no saving') 240 | parser.add_argument('--load_features', default=None, help="""If the features have 241 | already been computed, where to find them.""") 242 | parser.add_argument('--num_workers', default=10, type=int, help='Number of data loading workers per GPU.') 243 | parser.add_argument("--dist_url", default="env://", type=str, help="""url used to set up 244 | distributed training; see https://pytorch.org/docs/stable/distributed.html""") 245 | parser.add_argument("--local_rank", default=0, type=int, help="Please ignore and do not set this argument.") 246 | parser.add_argument('--data_path', default='/path/to/imagenet/', type=str) 247 | parser.add_argument('--method', default='moco', type=str, help='model name') 248 | 249 | return parser 250 | 251 | 252 | if __name__ == '__main__': 253 | parser = argparse.ArgumentParser("KNN Evaluation", parents=[get_args_parser()]) 254 | args = parser.parse_args() 255 | 256 | utils.init_distributed_mode(args) 257 | print("git:\n {}\n".format(utils.get_sha())) 258 | print("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items()))) 259 | cudnn.benchmark = True 260 | 261 | if args.load_features: 262 | train_features = torch.load(os.path.join(args.load_features, "trainfeat.pth")) 263 | test_features = torch.load(os.path.join(args.load_features, "testfeat.pth")) 264 | train_labels = torch.load(os.path.join(args.load_features, "trainlabels.pth")) 265 | test_labels = torch.load(os.path.join(args.load_features, "testlabels.pth")) 266 | else: 267 | # need to extract features ! 
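        # extract_feature_pipeline builds the train/val dataloaders, loads the pretrained
        # backbone, extracts features for both splits (L2-normalized on rank 0), and
        # optionally dumps the features and labels to args.dump_features.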
268 | train_features, test_features, train_labels, test_labels = extract_feature_pipeline(args) 269 | 270 | if utils.get_rank() == 0: 271 | if args.use_cuda: 272 | train_features = train_features.cuda() 273 | test_features = test_features.cuda() 274 | train_labels = train_labels.cuda() 275 | test_labels = test_labels.cuda() 276 | else: 277 | train_features = train_features.cpu() 278 | test_features = test_features.cpu() 279 | train_labels = train_labels.cpu() 280 | test_labels = test_labels.cpu() 281 | 282 | print("Features are ready!\nStart the k-NN classification.") 283 | for k in args.nb_knn: 284 | top1, top5 = knn_classifier(train_features, train_labels, 285 | test_features, test_labels, k, args.temperature, args.num_labels, args.use_cuda) 286 | print(f"{k}-NN classifier result: Top1: {top1}, Top5: {top5}") 287 | dist.barrier() 288 | -------------------------------------------------------------------------------- /eval_linear.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import os 3 | import sys 4 | import argparse 5 | import json 6 | from pathlib import Path 7 | import time 8 | import builtins 9 | import random 10 | 11 | import torch 12 | from torch import nn 13 | import torch.distributed as dist 14 | import torch.multiprocessing as mp 15 | import torch.backends.cudnn as cudnn 16 | from torchvision import datasets 17 | from torchvision import transforms as transforms 18 | from torchvision import models as torchvision_models 19 | from torch.utils.tensorboard import SummaryWriter 20 | 21 | from utils import utils 22 | from utils import checkpoint_io 23 | from utils import optimizers 24 | from utils import metrics 25 | import backbones.vision_transformer as vits 26 | 27 | def main(): 28 | parser = argparse.ArgumentParser("Linear Evaluation", parents=[get_args_parser()]) 29 | args = parser.parse_args() 30 | 31 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 32 | 33 | utils.init_distributed_mode(args) 34 | print("git:\n {}\n".format(utils.get_sha())) 35 | print("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items()))) 36 | cudnn.benchmark = True 37 | 38 | train(args) 39 | 40 | 41 | def get_args_parser(): 42 | parser = argparse.ArgumentParser("Linear Evaluation", add_help=False) 43 | 44 | ################################# 45 | #### input and output parameters #### 46 | ################################# 47 | parser.add_argument('--data_path', default='/path/to/imagenet/', type=str) 48 | parser.add_argument('--pretrained_weights', default='', type=str, 49 | help="Path to pretrained weights to evaluate.") 50 | parser.add_argument("--checkpoint_key", default="teacher", type=str, 51 | help='Key to use in the checkpoint (example: "teacher")') 52 | parser.add_argument('--output_dir', default=".", help='Path to save logs and checkpoints') 53 | parser.add_argument('--evaluate', dest='evaluate', type=str, default=None, help='evaluate model on validation set') 54 | parser.add_argument('--method', default='moco', type=str, help='model name') 55 | parser.add_argument('--experiment', default='exp', type=str, help='experiment name') 56 | 57 | ################################# 58 | ####model parameters #### 59 | ################################# 60 | parser.add_argument('--arch', default='resnet50', type=str, help='Architecture') 61 | parser.add_argument('--num_labels', default=1000, type=int, help='Number of labels for linear classifier') 62 | #for ViTs 63 | parser.add_argument('--n_last_blocks', 
default=4, type=int, help="""Concatenate [CLS] tokens 64 | for the `n` last blocks. Use `n=4` when evaluating ViT-Small and `n=1` with ViT-Base.""") 65 | parser.add_argument('--avgpool_patchtokens', default=0, type=int, 66 | help="""Whether ot not to concatenate the global average pooled features to the [CLS] token.""") 67 | parser.add_argument('--patch_size', default=16, type=int, help='Patch resolution of the model.') 68 | 69 | ################################# 70 | #### optim parameters ### 71 | ################################# 72 | parser.add_argument('--epochs', default=100, type=int, help='Number of epochs of training.') 73 | parser.add_argument('--optimizer', default='sgd', type=str, 74 | choices=['sgd', 'lars'], help="""Type of optimizer.""") 75 | parser.add_argument("--lr", default=0.001, type=float, help="""Learning rate at the beginning of 76 | training (highest LR used during training). The learning rate is linearly scaled 77 | with the batch size, and specified here for a reference batch size of 256.""") 78 | parser.add_argument('--batch_size_per_gpu', default=128, type=int, help='Per-GPU batch-size') 79 | parser.add_argument('--val_freq', default=5, type=int, help="Epoch frequency for validation.") 80 | 81 | ################################# 82 | #### dist parameters ### 83 | ################################# 84 | parser.add_argument("--dist_url", default="env://", type=str, help="""url used to set up distributed training""") 85 | parser.add_argument("--local_rank", default=0, type=int, help="Please ignore and do not set this argument.") 86 | parser.add_argument('--num_workers', default=10, type=int, help='Number of data loading workers per GPU.') 87 | 88 | return parser 89 | 90 | 91 | def train(args): 92 | ######################## building network ... 
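# --- Editor's sketch (assumption-labelled, not part of the original source) ---
# A simplified view of the linear-probe input assembled in train_one_epoch /
# validate below for ViT backbones: the [CLS] tokens of the last n blocks are
# concatenated, optionally followed by the mean of the last block's patch
# tokens. This is a sketch, not a byte-for-byte copy of the code below;
# shapes are illustrative (D = 384 for vit_small).
import torch

def build_probe_features(intermediate_output, avgpool_patchtokens=False):
    # intermediate_output: list of n tensors of shape (B, 1 + num_patches, D)
    feats = torch.cat([blk[:, 0] for blk in intermediate_output], dim=-1)  # (B, n * D) concatenated [CLS] tokens
    if avgpool_patchtokens:
        patch_mean = intermediate_output[-1][:, 1:].mean(dim=1)           # (B, D) mean of last-block patch tokens
        feats = torch.cat((feats, patch_mean), dim=-1)                    # (B, (n + 1) * D)
    return feats
# --- end sketch ---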
######################## 93 | if args.arch in vits.__dict__.keys(): 94 | model = vits.__dict__[args.arch](patch_size=args.patch_size, num_classes=0) 95 | embed_dim = model.embed_dim * (args.n_last_blocks + int(args.avgpool_patchtokens)) 96 | # embed_dim = model.embed_dim * (args.n_last_blocks + int(args.avgpool_patchtokens)) 97 | # otherwise, we check if the architecture is in torchvision models 98 | elif args.arch in torchvision_models.__dict__.keys(): 99 | model = torchvision_models.__dict__[args.arch]() 100 | embed_dim = model.fc.weight.shape[1] 101 | model.fc = nn.Identity() 102 | else: 103 | print(f"Unknow architecture: {args.arch}") 104 | sys.exit(1) 105 | model.cuda() 106 | model.eval() 107 | 108 | ######################## load pretrained weights to evaluate ######################## 109 | is_load_success = checkpoint_io.load_pretrained_weights(model, args.pretrained_weights, 110 | args.checkpoint_key, args.method) 111 | if is_load_success: 112 | print(f"Model {args.arch} built.") 113 | else: 114 | sys.exit(1) 115 | 116 | linear_classifier = LinearClassifier(embed_dim, num_labels=args.num_labels) 117 | linear_classifier = linear_classifier.cuda() 118 | linear_classifier = nn.parallel.DistributedDataParallel(linear_classifier, device_ids=[args.gpu]) 119 | 120 | ######################## preparing data ######################## 121 | val_transform = transforms.Compose([ 122 | transforms.Resize(256, interpolation=3), 123 | transforms.CenterCrop(224), 124 | transforms.ToTensor(), 125 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 126 | ]) 127 | 128 | dataset_val = datasets.ImageFolder(os.path.join(args.data_path, "val"), transform=val_transform) 129 | 130 | val_loader = torch.utils.data.DataLoader( 131 | dataset_val, 132 | batch_size=args.batch_size_per_gpu, 133 | num_workers=args.num_workers, 134 | pin_memory=False, 135 | ) 136 | 137 | ######################## just do evaluation ######################## 138 | if args.evaluate != None: 139 | is_load_success = checkpoint_io.load_pretrained_linear_weights(linear_classifier, args.evaluate, args.method) 140 | if is_load_success == False: 141 | sys.exit(1) 142 | test_stats = validate(val_loader, model, linear_classifier, args) 143 | print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") 144 | return 145 | 146 | ######################## preparing data ######################## 147 | train_transform = transforms.Compose([ 148 | transforms.RandomResizedCrop(224), 149 | transforms.RandomHorizontalFlip(), 150 | transforms.ToTensor(), 151 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 152 | ]) 153 | 154 | dataset_train = datasets.ImageFolder(os.path.join(args.data_path, "train"), transform=train_transform) 155 | 156 | sampler = torch.utils.data.distributed.DistributedSampler(dataset_train) 157 | train_loader = torch.utils.data.DataLoader( 158 | dataset_train, 159 | sampler=sampler, 160 | batch_size=args.batch_size_per_gpu, 161 | num_workers=args.num_workers, 162 | pin_memory=False, 163 | ) 164 | print(f"Data loaded with {len(dataset_train)} train and {len(dataset_val)} val imgs.") 165 | 166 | ######################## preparing optimizer ######################## 167 | optimizer = torch.optim.SGD( 168 | linear_classifier.parameters(), 169 | args.lr * (args.batch_size_per_gpu * utils.get_world_size()) / 256., # linear scaling rule 170 | momentum=0.9, 171 | weight_decay=0, 172 | ) 173 | if args.optimizer == "lars": 174 | optimizer = 
optimizers.LARS(linear_classifier.parameters(), 175 | lr=args.lr * (args.batch_size_per_gpu * utils.get_world_size()) / 256., 176 | momentum=0.9, 177 | weight_decay=0, 178 | ) # to use with convnet and large batches 179 | 180 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs, eta_min=0) 181 | 182 | ######################## optionally resume training ######################## 183 | to_restore = {"epoch": 0, "best_acc": 0.} 184 | checkpoint_io.restart_from_checkpoint( 185 | os.path.join(args.output_dir, 186 | "{}_{}_linear_{}.pth".format(args.method, args.arch, args.experiment)), 187 | run_variables=to_restore, 188 | state_dict=linear_classifier, 189 | optimizer=optimizer, 190 | scheduler=scheduler, 191 | ) 192 | start_epoch = to_restore["epoch"] 193 | best_acc = to_restore["best_acc"] 194 | 195 | summary_writer = SummaryWriter(log_dir=os.path.join(args.output_dir, 196 | "tb", "{}_{}_linear_{}".format(args.method, args.arch, args.experiment))) if utils.is_main_process() else None 197 | 198 | ######################## start training ######################## 199 | for epoch in range(start_epoch, args.epochs): 200 | train_loader.sampler.set_epoch(epoch) 201 | 202 | train_stats = train_one_epoch(model, linear_classifier, optimizer, train_loader, epoch, summary_writer, args) 203 | scheduler.step() 204 | 205 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 'epoch': epoch} 206 | if epoch % args.val_freq == 0 or epoch == args.epochs - 1: 207 | test_stats = validate(val_loader, model, linear_classifier, args) 208 | 209 | print(f"Accuracy at epoch {epoch} of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") 210 | best_acc = max(best_acc, test_stats["acc1"]) 211 | 212 | print(f'Max accuracy so far: {best_acc:.2f}%') 213 | log_stats = {**{k: v for k, v in log_stats.items()}, 214 | **{f'test_{k}': v for k, v in test_stats.items()}} 215 | 216 | if utils.is_main_process(): 217 | summary_writer.add_scalar("acc1", test_stats["acc1"], epoch) 218 | summary_writer.add_scalar("acc5", test_stats["acc5"], epoch) 219 | 220 | if utils.is_main_process(): 221 | with (Path(args.output_dir) / "{}_{}_linear_{}_log.txt".format(args.method, args.arch, args.experiment)).open("a") as f: 222 | f.write(json.dumps(log_stats) + "\n") 223 | save_dict = { 224 | "epoch": epoch + 1, 225 | "state_dict": linear_classifier.state_dict(), 226 | "optimizer": optimizer.state_dict(), 227 | "scheduler": scheduler.state_dict(), 228 | "best_acc": best_acc, 229 | } 230 | torch.save(save_dict, os.path.join(args.output_dir, 231 | "{}_{}_linear_{}.pth".format(args.method, args.arch, args.experiment))) 232 | 233 | if utils.is_main_process(): 234 | summary_writer.close() 235 | print("Training of the supervised linear classifier on frozen features completed.\n" 236 | "Top-1 test accuracy: {acc:.1f}".format(acc=best_acc)) 237 | time.sleep(30) 238 | 239 | 240 | def train_one_epoch(model, linear_classifier, optimizer, loader, epoch, summary_writer, args): 241 | linear_classifier.train() 242 | metric_logger = utils.MetricLogger(delimiter=" ") 243 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 244 | header = 'Epoch: [{}/{}]'.format(epoch, args.epochs) 245 | iters_per_epoch = len(loader) 246 | for it, (inp, target) in enumerate(metric_logger.log_every(loader, 20, header)): 247 | # move to gpu 248 | inp = inp.cuda(non_blocking=True) 249 | target = target.cuda(non_blocking=True) 250 | 251 | # forward 252 | with torch.no_grad(): 253 | if "vit" 
in args.arch: 254 | intermediate_output = model.get_intermediate_layers(inp, args.n_last_blocks) # take n last block out 255 | output = torch.cat([x[:, 0] for x in intermediate_output], dim=-1) 256 | if args.avgpool_patchtokens: 257 | output = torch.cat((output.unsqueeze(-1), 258 | torch.mean(intermediate_output[-1][:, 1:], dim=1).unsqueeze(-1)), dim=-1) 259 | output = output.reshape(output.shape[0], -1) 260 | else: 261 | output = model(inp) 262 | if len(output.shape) != 2: 263 | output = output.squeeze() 264 | output = linear_classifier(output) 265 | 266 | # compute cross entropy loss 267 | loss = nn.CrossEntropyLoss()(output, target) 268 | 269 | # compute the gradients 270 | optimizer.zero_grad() 271 | loss.backward() 272 | 273 | # step 274 | optimizer.step() 275 | 276 | # logging 277 | if utils.is_main_process(): 278 | summary_writer.add_scalar("loss", loss.item(), epoch * iters_per_epoch + it) 279 | summary_writer.add_scalar("learning rate", optimizer.param_groups[0]["lr"], 280 | epoch * iters_per_epoch + it) 281 | torch.cuda.synchronize() 282 | metric_logger.update(loss=loss.item()) 283 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 284 | # gather the stats from all processes 285 | metric_logger.synchronize_between_processes() 286 | print("Averaged stats:", metric_logger) 287 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 288 | 289 | 290 | @torch.no_grad() 291 | def validate(val_loader, model, linear_classifier, args): 292 | linear_classifier.eval() 293 | metric_logger = utils.MetricLogger(delimiter=" ") 294 | header = 'Test:' 295 | for inp, target in metric_logger.log_every(val_loader, 20, header): 296 | # move to gpu 297 | inp = inp.cuda(non_blocking=True) 298 | target = target.cuda(non_blocking=True) 299 | 300 | # forward 301 | with torch.no_grad(): 302 | if "vit" in args.arch: 303 | intermediate_output = model.get_intermediate_layers(inp, args.n_last_blocks) # take n last block out 304 | output = torch.cat([x[:, 0] for x in intermediate_output], dim=-1) 305 | if args.avgpool_patchtokens: 306 | output = torch.cat((output.unsqueeze(-1), 307 | torch.mean(intermediate_output[-1][:, 1:], dim=1).unsqueeze(-1)), dim=-1) 308 | output = output.reshape(output.shape[0], -1) 309 | else: 310 | output = model(inp) 311 | output = linear_classifier(output) 312 | loss = nn.CrossEntropyLoss()(output, target) 313 | 314 | if linear_classifier.module.num_labels >= 5: 315 | acc1, acc5 = metrics.accuracy(output, target, topk=(1, 5)) 316 | else: 317 | acc1, = metrics.accuracy(output, target, topk=(1,)) 318 | 319 | batch_size = inp.shape[0] 320 | metric_logger.update(loss=loss.item()) 321 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) 322 | if linear_classifier.module.num_labels >= 5: 323 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) 324 | if linear_classifier.module.num_labels >= 5: 325 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}' 326 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss)) 327 | else: 328 | print('* Acc@1 {top1.global_avg:.3f} loss {losses.global_avg:.3f}' 329 | .format(top1=metric_logger.acc1, losses=metric_logger.loss)) 330 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 331 | 332 | 333 | class LinearClassifier(nn.Module): 334 | """Linear layer to train on top of frozen features""" 335 | def __init__(self, dim, num_labels=1000): 336 | super(LinearClassifier, self).__init__() 337 | self.num_labels = 
num_labels 338 | self.linear = nn.Linear(dim, num_labels) 339 | self.linear.weight.data.normal_(mean=0.0, std=0.01) 340 | self.linear.bias.data.zero_() 341 | 342 | def forward(self, x): 343 | # flatten 344 | x = x.view(x.size(0), -1) 345 | 346 | # linear layer 347 | return self.linear(x) 348 | 349 | 350 | if __name__ == '__main__': 351 | main() 352 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import math 5 | import datetime 6 | import argparse 7 | import subprocess 8 | import warnings 9 | from collections import defaultdict, deque 10 | 11 | import numpy as np 12 | import torch 13 | from torch import nn 14 | import torch.distributed as dist 15 | 16 | 17 | class Dict(dict): 18 | __setattr__ = dict.__setitem__ 19 | __getattr__ = dict.__getitem__ 20 | 21 | 22 | def DictToObj(dictObj): 23 | if not isinstance(dictObj, dict): 24 | return dictObj 25 | d = Dict() 26 | for k, v in dictObj.items(): 27 | d[k] = DictToObj(v) 28 | return d 29 | 30 | 31 | def clip_gradients(model, clip): 32 | norms = [] 33 | for name, p in model.named_parameters(): 34 | if p.grad is not None: 35 | param_norm = p.grad.data.norm(2) 36 | norms.append(param_norm.item()) 37 | clip_coef = clip / (param_norm + 1e-6) 38 | if clip_coef < 1: 39 | p.grad.data.mul_(clip_coef) 40 | return norms 41 | 42 | 43 | def cancel_gradients_last_layer(epoch, model, freeze_last_layer): 44 | if epoch >= freeze_last_layer: 45 | return 46 | for n, p in model.named_parameters(): 47 | if "last_layer" in n: 48 | p.grad = None 49 | 50 | 51 | def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0): 52 | warmup_schedule = np.array([]) 53 | warmup_iters = warmup_epochs * niter_per_ep 54 | if warmup_epochs > 0: 55 | warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters) 56 | 57 | iters = np.arange(epochs * niter_per_ep - warmup_iters) 58 | schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters))) 59 | 60 | schedule = np.concatenate((warmup_schedule, schedule)) 61 | assert len(schedule) == epochs * niter_per_ep 62 | return schedule 63 | 64 | def piecewise_scheduler(start_value, end_value, epochs, niter_per_ep, warmup_epochs=100): 65 | 66 | warmup_iters = warmup_epochs * niter_per_ep 67 | warmup_schedule = np.ones_like(np.arange(warmup_iters)) * start_value 68 | 69 | schedule = np.ones_like(np.arange(epochs * niter_per_ep - warmup_iters)) * end_value 70 | 71 | schedule = np.concatenate((warmup_schedule, schedule)) 72 | assert len(schedule) == epochs * niter_per_ep 73 | return schedule 74 | 75 | def adjust_learning_rate(optimizer, epoch, args): 76 | """Decay the learning rate based on schedule""" 77 | lr = args.lr 78 | if args.cos_learning_rate: # cosine lr schedule 79 | lr *= 0.5 * (1. + math.cos(math.pi * epoch / args.epochs)) 80 | else: # stepwise lr schedule 81 | for milestone in args.schedule: 82 | lr *= 0.1 if epoch >= milestone else 1. 83 | for param_group in optimizer.param_groups: 84 | param_group['lr'] = lr 85 | 86 | def adjust_learning_rate_with_fix_lr(optimizer, epoch, args): 87 | """Decay the learning rate based on schedule""" 88 | lr = args.lr * 0.5 * (1. 
+ math.cos(math.pi * epoch / args.epochs)) 89 | for param_group in optimizer.param_groups: 90 | if 'fix_lr' in param_group and param_group['fix_lr']: 91 | param_group['lr'] = args.lr 92 | else: 93 | param_group['lr'] = lr 94 | 95 | def adjust_learning_rate_with_warmup(optimizer, epoch, args): 96 | """Decays the learning rate with half-cycle cosine after warmup""" 97 | if epoch < args.warmup_epochs: 98 | lr = args.lr * epoch / args.warmup_epochs 99 | else: 100 | lr = args.lr * 0.5 * (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs))) 101 | for param_group in optimizer.param_groups: 102 | param_group['lr'] = lr 103 | return lr 104 | 105 | def adjust_ema_momentum(epoch, args): 106 | """Adjust ema momentum based on current epoch""" 107 | m = 1. - 0.5 * (1. + math.cos(math.pi * epoch / args.epochs)) * (1. - args.ema_m) 108 | return m 109 | 110 | def bool_flag(s): 111 | """ 112 | Parse boolean arguments from the command line. 113 | """ 114 | FALSY_STRINGS = {"off", "false", "0"} 115 | TRUTHY_STRINGS = {"on", "true", "1"} 116 | if s.lower() in FALSY_STRINGS: 117 | return False 118 | elif s.lower() in TRUTHY_STRINGS: 119 | return True 120 | else: 121 | raise argparse.ArgumentTypeError("invalid value for a boolean flag") 122 | 123 | 124 | def fix_random_seeds(seed=31): 125 | """ 126 | Fix random seeds. 127 | """ 128 | torch.manual_seed(seed) 129 | torch.cuda.manual_seed_all(seed) 130 | np.random.seed(seed) 131 | 132 | 133 | class SmoothedValue(object): 134 | """Track a series of values and provide access to smoothed values over a 135 | window or the global series average. 136 | """ 137 | 138 | def __init__(self, window_size=20, fmt=None): 139 | if fmt is None: 140 | fmt = "{median:.6f} ({global_avg:.6f})" 141 | self.deque = deque(maxlen=window_size) 142 | self.total = 0.0 143 | self.count = 0 144 | self.fmt = fmt 145 | 146 | def update(self, value, n=1): 147 | self.deque.append(value) 148 | self.count += n 149 | self.total += value * n 150 | 151 | def synchronize_between_processes(self): 152 | """ 153 | Warning: does not synchronize the deque! 154 | """ 155 | if not is_dist_avail_and_initialized(): 156 | return 157 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 158 | dist.barrier() 159 | dist.all_reduce(t) 160 | t = t.tolist() 161 | self.count = int(t[0]) 162 | self.total = t[1] 163 | 164 | @property 165 | def median(self): 166 | d = torch.tensor(list(self.deque)) 167 | return d.median().item() 168 | 169 | @property 170 | def avg(self): 171 | d = torch.tensor(list(self.deque), dtype=torch.float32) 172 | return d.mean().item() 173 | 174 | @property 175 | def global_avg(self): 176 | return self.total / self.count 177 | 178 | @property 179 | def max(self): 180 | return max(self.deque) 181 | 182 | @property 183 | def value(self): 184 | return self.deque[-1] 185 | 186 | def __str__(self): 187 | return self.fmt.format( 188 | median=self.median, 189 | avg=self.avg, 190 | global_avg=self.global_avg, 191 | max=self.max, 192 | value=self.value) 193 | 194 | 195 | def reduce_dict(input_dict, average=True): 196 | """ 197 | Args: 198 | input_dict (dict): all the values will be reduced 199 | average (bool): whether to do average or sum 200 | Reduce the values in the dictionary from all processes so that all processes 201 | have the averaged results. Returns a dict with the same fields as 202 | input_dict, after reduction. 
203 | """ 204 | world_size = get_world_size() 205 | if world_size < 2: 206 | return input_dict 207 | with torch.no_grad(): 208 | names = [] 209 | values = [] 210 | # sort the keys so that they are consistent across processes 211 | for k in sorted(input_dict.keys()): 212 | names.append(k) 213 | values.append(input_dict[k]) 214 | values = torch.stack(values, dim=0) 215 | dist.all_reduce(values) 216 | if average: 217 | values /= world_size 218 | reduced_dict = {k: v for k, v in zip(names, values)} 219 | return reduced_dict 220 | 221 | 222 | class MetricLogger(object): 223 | def __init__(self, delimiter="\t"): 224 | self.meters = defaultdict(SmoothedValue) 225 | self.delimiter = delimiter 226 | 227 | def update(self, **kwargs): 228 | for k, v in kwargs.items(): 229 | if isinstance(v, torch.Tensor): 230 | v = v.item() 231 | assert isinstance(v, (float, int)) 232 | self.meters[k].update(v) 233 | 234 | def __getattr__(self, attr): 235 | if attr in self.meters: 236 | return self.meters[attr] 237 | if attr in self.__dict__: 238 | return self.__dict__[attr] 239 | raise AttributeError("'{}' object has no attribute '{}'".format( 240 | type(self).__name__, attr)) 241 | 242 | def __str__(self): 243 | loss_str = [] 244 | for name, meter in self.meters.items(): 245 | loss_str.append( 246 | "{}: {}".format(name, str(meter)) 247 | ) 248 | return self.delimiter.join(loss_str) 249 | 250 | def synchronize_between_processes(self): 251 | for meter in self.meters.values(): 252 | meter.synchronize_between_processes() 253 | 254 | def add_meter(self, name, meter): 255 | self.meters[name] = meter 256 | 257 | def log_every(self, iterable, print_freq, header=None): 258 | i = 0 259 | if not header: 260 | header = '' 261 | start_time = time.time() 262 | end = time.time() 263 | iter_time = SmoothedValue(fmt='{avg:.6f}') 264 | data_time = SmoothedValue(fmt='{avg:.6f}') 265 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 266 | if torch.cuda.is_available(): 267 | log_msg = self.delimiter.join([ 268 | header, 269 | '[{0' + space_fmt + '}/{1}]', 270 | 'eta: {eta}', 271 | '{meters}', 272 | 'time: {time}', 273 | 'data: {data}', 274 | 'max mem: {memory:.0f}' 275 | ]) 276 | else: 277 | log_msg = self.delimiter.join([ 278 | header, 279 | '[{0' + space_fmt + '}/{1}]', 280 | 'eta: {eta}', 281 | '{meters}', 282 | 'time: {time}', 283 | 'data: {data}' 284 | ]) 285 | MB = 1024.0 * 1024.0 286 | for obj in iterable: 287 | data_time.update(time.time() - end) 288 | yield obj 289 | iter_time.update(time.time() - end) 290 | if i % print_freq == 0 or i == len(iterable) - 1: 291 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 292 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 293 | if torch.cuda.is_available(): 294 | print(log_msg.format( 295 | i, len(iterable), eta=eta_string, 296 | meters=str(self), 297 | time=str(iter_time), data=str(data_time), 298 | memory=torch.cuda.max_memory_allocated() / MB)) 299 | else: 300 | print(log_msg.format( 301 | i, len(iterable), eta=eta_string, 302 | meters=str(self), 303 | time=str(iter_time), data=str(data_time))) 304 | i += 1 305 | end = time.time() 306 | total_time = time.time() - start_time 307 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 308 | print('{} Total time: {} ({:.6f} s / it)'.format( 309 | header, total_time_str, total_time / len(iterable))) 310 | 311 | 312 | def get_sha(): 313 | cwd = os.path.dirname(os.path.abspath(__file__)) 314 | 315 | def _run(command): 316 | return subprocess.check_output(command, cwd=cwd).decode('ascii').strip() 
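# --- Editor's sketch (assumption-labelled, not part of the original source) ---
# Usage sketch for the cosine_scheduler helper defined earlier in this file:
# it returns one value per training iteration (linear warmup followed by a
# cosine decay), which the training loop indexes with the global step.
# The numbers below are illustrative, not the repository defaults.
lr_schedule = cosine_scheduler(
    base_value=5e-4, final_value=1e-6,
    epochs=100, niter_per_ep=1250, warmup_epochs=10,
)
assert len(lr_schedule) == 100 * 1250
# inside the training loop, at global iteration `it`:
#     for group in optimizer.param_groups:
#         group["lr"] = lr_schedule[it]
# --- end sketch ---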
317 | sha = 'N/A' 318 | diff = "clean" 319 | branch = 'N/A' 320 | try: 321 | sha = _run(['git', 'rev-parse', 'HEAD']) 322 | subprocess.check_output(['git', 'diff'], cwd=cwd) 323 | diff = _run(['git', 'diff-index', 'HEAD']) 324 | diff = "has uncommited changes" if diff else "clean" 325 | branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD']) 326 | except Exception: 327 | pass 328 | message = f"sha: {sha}, status: {diff}, branch: {branch}" 329 | return message 330 | 331 | 332 | def is_dist_avail_and_initialized(): 333 | if not dist.is_available(): 334 | return False 335 | if not dist.is_initialized(): 336 | return False 337 | return True 338 | 339 | 340 | def get_world_size(): 341 | if not is_dist_avail_and_initialized(): 342 | return 1 343 | return dist.get_world_size() 344 | 345 | 346 | def get_rank(): 347 | if not is_dist_avail_and_initialized(): 348 | return 0 349 | return dist.get_rank() 350 | 351 | 352 | def is_main_process(): 353 | return get_rank() == 0 354 | 355 | 356 | def save_on_master(*args, **kwargs): 357 | if is_main_process(): 358 | torch.save(*args, **kwargs) 359 | 360 | 361 | def setup_for_distributed(is_master): 362 | """ 363 | This function disables printing when not in master process 364 | """ 365 | import builtins as __builtin__ 366 | builtin_print = __builtin__.print 367 | 368 | def print(*args, **kwargs): 369 | force = kwargs.pop('force', False) 370 | if is_master or force: 371 | builtin_print(*args, **kwargs) 372 | 373 | __builtin__.print = print 374 | 375 | 376 | def init_distributed_mode_global(args): 377 | if args.dist_url == "env://" and args.world_size == -1: 378 | args.world_size = int(os.environ["WORLD_SIZE"]) 379 | 380 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed 381 | args.ngpus_per_node = torch.cuda.device_count() 382 | if args.multiprocessing_distributed: 383 | args.world_size = args.ngpus_per_node * args.world_size 384 | 385 | 386 | def init_distributed_mode(args): 387 | # launched with torch.distributed.launch 388 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 389 | args.rank = int(os.environ["RANK"]) 390 | args.world_size = int(os.environ['WORLD_SIZE']) 391 | args.gpu = int(os.environ['LOCAL_RANK']) 392 | # launched with submitit on a slurm cluster 393 | elif 'SLURM_PROCID' in os.environ: 394 | args.rank = int(os.environ['SLURM_PROCID']) 395 | args.gpu = args.rank % torch.cuda.device_count() 396 | elif torch.cuda.is_available(): 397 | print('Run the code on one GPU.') 398 | args.rank, args.gpu, args.world_size = 0, 0, 1 399 | os.environ['MASTER_ADDR'] = '127.0.0.1' 400 | os.environ['MASTER_PORT'] = '29500' 401 | else: 402 | print('Not support training without GPU.') 403 | sys.exit(1) 404 | 405 | dist.init_process_group( 406 | backend='nccl', 407 | init_method=args.dist_url, 408 | world_size=args.world_size, 409 | rank=args.rank, 410 | ) 411 | 412 | torch.cuda.set_device(args.gpu) 413 | print('| distributed init (rank {}): {}'.format( 414 | args.rank, args.dist_url), flush=True) 415 | dist.barrier() 416 | setup_for_distributed(args.rank == 0) 417 | 418 | 419 | def _no_grad_trunc_normal_(tensor, mean, std, a, b): 420 | # Cut & paste from PyTorch official master until it's in a few official releases - RW 421 | def norm_cdf(x): 422 | # Computes standard normal cumulative distribution function 423 | return (1. + math.erf(x / math.sqrt(2.))) / 2. 424 | 425 | if (mean < a - 2 * std) or (mean > b + 2 * std): 426 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. 
" 427 | "The distribution of values may be incorrect.", 428 | stacklevel=2) 429 | 430 | with torch.no_grad(): 431 | l = norm_cdf((a - mean) / std) 432 | u = norm_cdf((b - mean) / std) 433 | 434 | # Uniformly fill tensor with values from [l, u], then translate to 435 | # [2l-1, 2u-1]. 436 | tensor.uniform_(2 * l - 1, 2 * u - 1) 437 | 438 | # Use inverse cdf transform for normal distribution to get truncated 439 | # standard normal 440 | tensor.erfinv_() 441 | 442 | # Transform to proper mean, std 443 | tensor.mul_(std * math.sqrt(2.)) 444 | tensor.add_(mean) 445 | 446 | # Clamp to ensure it's in the proper range 447 | tensor.clamp_(min=a, max=b) 448 | return tensor 449 | 450 | 451 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): 452 | return _no_grad_trunc_normal_(tensor, mean, std, a, b) 453 | 454 | 455 | def get_params_groups(model): 456 | regularized = [] 457 | not_regularized = [] 458 | for name, param in model.named_parameters(): 459 | if not param.requires_grad: 460 | continue 461 | # we do not regularize biases nor Norm parameters 462 | if name.endswith(".bias") or len(param.shape) == 1: 463 | not_regularized.append(param) 464 | else: 465 | regularized.append(param) 466 | return [{'params': regularized}, {'params': not_regularized, 'weight_decay': 0.}] 467 | 468 | 469 | def get_params_groups2(model, head_name): 470 | regularized = [] 471 | not_regularized = [] 472 | transhead = [] 473 | for name, param in model.named_parameters(): 474 | if not param.requires_grad: 475 | continue 476 | if head_name in name: 477 | transhead.append(param) 478 | continue 479 | # we do not regularize biases nor Norm parameters 480 | if name.endswith(".bias") or len(param.shape) == 1: 481 | not_regularized.append(param) 482 | else: 483 | regularized.append(param) 484 | return [{'params': regularized}, {'params': not_regularized, 'weight_decay': 0.}, {'params': transhead}] 485 | 486 | 487 | def has_batchnorms(model): 488 | bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm) 489 | for name, module in model.named_modules(): 490 | if isinstance(module, bn_types): 491 | return True 492 | return False 493 | 494 | 495 | def multi_scale(samples, model): 496 | v = None 497 | for s in [1, 1/2**(1/2), 1/2]: # we use 3 different scales 498 | if s == 1: 499 | inp = samples.clone() 500 | else: 501 | inp = nn.functional.interpolate(samples, scale_factor=s, mode='bilinear', align_corners=False) 502 | feats = model(inp).clone() 503 | if v is None: 504 | v = feats 505 | else: 506 | v += feats 507 | v /= 3 508 | v /= v.norm() 509 | return v 510 | -------------------------------------------------------------------------------- /backbones/vision_transformer.py: -------------------------------------------------------------------------------- 1 | 2 | """ 3 | Mostly copy-paste from timm library. 4 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py 5 | """ 6 | import math 7 | from functools import partial, reduce 8 | from operator import mul 9 | 10 | import torch 11 | import torch.nn as nn 12 | 13 | from timm.models.vision_transformer import VisionTransformer, _cfg 14 | from timm.models.layers.helpers import to_2tuple 15 | # from timm.models.layers import PatchEmbed 16 | 17 | from utils.utils import trunc_normal_ 18 | 19 | __all__ = [ 20 | 'vit_small', 21 | 'vit_base', 22 | 'vit_conv_small', 23 | 'vit_conv_base', 24 | ] 25 | 26 | def drop_path(x, drop_prob: float = 0., training: bool = False): 27 | if drop_prob == 0. 
or not training: 28 | return x 29 | keep_prob = 1 - drop_prob 30 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 31 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) 32 | random_tensor.floor_() # binarize 33 | output = x.div(keep_prob) * random_tensor 34 | return output 35 | 36 | 37 | class DropPath(nn.Module): 38 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 39 | """ 40 | def __init__(self, drop_prob=None): 41 | super(DropPath, self).__init__() 42 | self.drop_prob = drop_prob 43 | 44 | def forward(self, x): 45 | return drop_path(x, self.drop_prob, self.training) 46 | 47 | 48 | class Mlp(nn.Module): 49 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 50 | super().__init__() 51 | out_features = out_features or in_features 52 | hidden_features = hidden_features or in_features 53 | self.fc1 = nn.Linear(in_features, hidden_features) 54 | self.act = act_layer() 55 | self.fc2 = nn.Linear(hidden_features, out_features) 56 | self.drop = nn.Dropout(drop) 57 | 58 | def forward(self, x): 59 | x = self.fc1(x) 60 | x = self.act(x) 61 | x = self.drop(x) 62 | x = self.fc2(x) 63 | x = self.drop(x) 64 | return x 65 | 66 | 67 | class Attention(nn.Module): 68 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 69 | super().__init__() 70 | self.num_heads = num_heads 71 | head_dim = dim // num_heads 72 | self.scale = qk_scale or head_dim ** -0.5 73 | 74 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 75 | self.attn_drop = nn.Dropout(attn_drop) 76 | self.proj = nn.Linear(dim, dim) 77 | self.proj_drop = nn.Dropout(proj_drop) 78 | 79 | def forward(self, x): 80 | B, N, C = x.shape 81 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 82 | q, k, v = qkv[0], qkv[1], qkv[2] 83 | 84 | attn = (q @ k.transpose(-2, -1)) * self.scale 85 | attn = attn.softmax(dim=-1) 86 | attn = self.attn_drop(attn) 87 | 88 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 89 | x = self.proj(x) 90 | x = self.proj_drop(x) 91 | return x, attn 92 | 93 | 94 | class Block(nn.Module): 95 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 96 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 97 | super().__init__() 98 | self.norm1 = norm_layer(dim) 99 | self.attn = Attention( 100 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 101 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 102 | self.norm2 = norm_layer(dim) 103 | mlp_hidden_dim = int(dim * mlp_ratio) 104 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 105 | 106 | def forward(self, x, return_attention=False): 107 | y, attn = self.attn(self.norm1(x)) 108 | if return_attention: 109 | return attn 110 | x = x + self.drop_path(y) 111 | x = x + self.drop_path(self.mlp(self.norm2(x))) 112 | return x 113 | 114 | 115 | class PatchEmbed(nn.Module): 116 | """ Image to Patch Embedding 117 | """ 118 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): 119 | super().__init__() 120 | num_patches = (img_size // patch_size) * (img_size // patch_size) 121 | self.img_size = img_size 122 | self.patch_size = patch_size 123 | self.num_patches = num_patches 124 | 125 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 126 | 127 | def forward(self, x): 128 | B, C, H, W = x.shape 129 | x = self.proj(x).flatten(2).transpose(1, 2) 130 | return x 131 | 132 | 133 | class VisionTransformer(nn.Module): 134 | """ Vision Transformer """ 135 | def __init__(self, img_size=[224], patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12, 136 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., 137 | drop_path_rate=0., norm_layer=nn.LayerNorm, **kwargs): 138 | super().__init__() 139 | self.num_features = self.embed_dim = embed_dim 140 | 141 | self.patch_embed = PatchEmbed( 142 | img_size=img_size[0], patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 143 | num_patches = self.patch_embed.num_patches 144 | 145 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 146 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) 147 | self.pos_drop = nn.Dropout(p=drop_rate) 148 | 149 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 150 | self.blocks = nn.ModuleList([ 151 | Block( 152 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 153 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) 154 | for i in range(depth)]) 155 | self.norm = norm_layer(embed_dim) 156 | 157 | # Classifier head 158 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() 159 | 160 | trunc_normal_(self.pos_embed, std=.02) 161 | trunc_normal_(self.cls_token, std=.02) 162 | self.apply(self._init_weights) 163 | 164 | def _init_weights(self, m): 165 | if isinstance(m, nn.Linear): 166 | trunc_normal_(m.weight, std=.02) 167 | if isinstance(m, nn.Linear) and m.bias is not None: 168 | nn.init.constant_(m.bias, 0) 169 | elif isinstance(m, nn.LayerNorm): 170 | nn.init.constant_(m.bias, 0) 171 | nn.init.constant_(m.weight, 1.0) 172 | 173 | def interpolate_pos_encoding(self, x, w, h): 174 | npatch = x.shape[1] - 1 175 | N = self.pos_embed.shape[1] - 1 176 | if npatch == N and w == h: 177 | return self.pos_embed 178 | class_pos_embed = self.pos_embed[:, 0] 179 | patch_pos_embed = self.pos_embed[:, 1:] 180 | dim = x.shape[-1] 181 | w0 = w // self.patch_embed.patch_size 182 | h0 = h // self.patch_embed.patch_size 183 | # we add a small number to avoid floating point error in the interpolation 184 | # see discussion at https://github.com/facebookresearch/dino/issues/8 185 | w0, h0 = w0 + 0.1, h0 + 0.1 186 | patch_pos_embed = nn.functional.interpolate( 187 | patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), 
dim).permute(0, 3, 1, 2), 188 | scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)), 189 | mode='bicubic', 190 | ) 191 | assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1] 192 | patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) 193 | return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) 194 | 195 | def prepare_tokens(self, x): 196 | B, nc, w, h = x.shape 197 | x = self.patch_embed(x) # patch linear embedding 198 | 199 | # add the [CLS] token to the embed patch tokens 200 | cls_tokens = self.cls_token.expand(B, -1, -1) 201 | x = torch.cat((cls_tokens, x), dim=1) 202 | 203 | # add positional encoding to each token 204 | x = x + self.interpolate_pos_encoding(x, w, h) 205 | 206 | return self.pos_drop(x) 207 | 208 | def forward(self, x, return_all_tokens=False, return_attention=False): 209 | x = self.prepare_tokens(x) 210 | 211 | if not return_attention: 212 | for blk in self.blocks: 213 | x = blk(x) 214 | x = self.norm(x) 215 | if return_all_tokens: 216 | return x, None 217 | else: 218 | return x[:, 0], None 219 | else: 220 | for i, blk in enumerate(self.blocks): 221 | if i < len(self.blocks) - 1: 222 | x = blk(x) 223 | else: 224 | # return attention of the last block 225 | out = blk(x) 226 | with torch.no_grad(): 227 | attentions = blk(x, return_attention=True) # B nHead N N 228 | out = self.norm(out) 229 | # generate mask 230 | with torch.no_grad(): 231 | # keep only the output patch attention 232 | attentions = attentions[:, :, 0, 1:] # B nHead N 233 | # attentions = torch.mean(attentions, dim=1) # B N 234 | if return_all_tokens: 235 | return out, attentions 236 | else: 237 | return out[:, 0], attentions 238 | 239 | def get_last_selfattention(self, x): 240 | x = self.prepare_tokens(x) 241 | for i, blk in enumerate(self.blocks): 242 | if i < len(self.blocks) - 1: 243 | x = blk(x) 244 | else: 245 | # return attention of the last block 246 | return blk(x, return_attention=True) 247 | 248 | def get_selfattention(self, x, block_index=11): 249 | x = self.prepare_tokens(x) 250 | if block_index == 0: 251 | return self.blocks[0](x, return_attention=True) 252 | 253 | for i, blk in enumerate(self.blocks): 254 | if i < block_index: 255 | x = blk(x) 256 | else: 257 | return blk(x, return_attention=True) 258 | 259 | def get_layer(self, x, block_index=11): 260 | x = self.prepare_tokens(x) 261 | if block_index == 0: 262 | return self.blocks[0](x) 263 | 264 | for i, blk in enumerate(self.blocks): 265 | if i < block_index: 266 | x = blk(x) 267 | else: 268 | return blk(x) 269 | 270 | def get_intermediate_layers(self, x, n=1): 271 | x = self.prepare_tokens(x) 272 | # we return the output tokens from the `n` last blocks 273 | output = [] 274 | for i, blk in enumerate(self.blocks): 275 | x = blk(x) 276 | if len(self.blocks) - i <= n: 277 | output.append(self.norm(x)) 278 | return output 279 | 280 | 281 | def vit_tiny(patch_size=16, **kwargs): 282 | model = VisionTransformer( 283 | patch_size=patch_size, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, 284 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 285 | return model 286 | 287 | 288 | def vit_small(patch_size=16, **kwargs): 289 | model = VisionTransformer( 290 | patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, 291 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 292 | return model 293 | 294 | 295 | def vit_base(patch_size=16, **kwargs): 296 | model = VisionTransformer( 297 | patch_size=patch_size, embed_dim=768, 
depth=12, num_heads=12, mlp_ratio=4, 298 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 299 | return model 300 | 301 | def vit_large(patch_size=16, **kwargs): 302 | model = VisionTransformer( 303 | patch_size=patch_size, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, 304 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 305 | return model 306 | 307 | def vit_huge(patch_size=16, **kwargs): 308 | model = VisionTransformer( 309 | patch_size=patch_size, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, 310 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 311 | return model 312 | 313 | # input is a 4-d feature map with shape B C H W 314 | class VisionTransformerHead(nn.Module): 315 | """ Vision Transformer """ 316 | def __init__(self, featuremap_size=7, in_chans=2048, embed_dim=384, depth=12, 317 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., 318 | drop_path_rate=0., norm_layer=nn.LayerNorm, **kwargs): 319 | super().__init__() 320 | 321 | self.num_features = self.embed_dim = embed_dim 322 | self.num_patches = featuremap_size * featuremap_size 323 | 324 | self.patch_embed = nn.Conv2d(in_chans, embed_dim, kernel_size=1, stride=1) 325 | 326 | self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, embed_dim)) 327 | self.pos_drop = nn.Dropout(p=drop_rate) 328 | 329 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 330 | self.blocks = nn.ModuleList([ 331 | Block( 332 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 333 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) 334 | for i in range(depth)]) 335 | self.norm = norm_layer(embed_dim) 336 | 337 | trunc_normal_(self.pos_embed, std=.02) 338 | self.apply(self._init_weights) 339 | 340 | 341 | def _init_weights(self, m): 342 | if isinstance(m, nn.Linear): 343 | trunc_normal_(m.weight, std=.02) 344 | if isinstance(m, nn.Linear) and m.bias is not None: 345 | nn.init.constant_(m.bias, 0) 346 | elif isinstance(m, nn.LayerNorm): 347 | nn.init.constant_(m.bias, 0) 348 | nn.init.constant_(m.weight, 1.0) 349 | 350 | 351 | def prepare_tokens(self, x): 352 | B, nc, h, w = x.shape 353 | x = self.patch_embed(x).flatten(2).transpose(1, 2) # patch linear embedding B N C 354 | 355 | # add positional encoding to each token 356 | x = x + self.pos_embed 357 | 358 | return self.pos_drop(x) 359 | 360 | 361 | def forward(self, x, return_all_tokens=False, return_attention=False): 362 | x = self.prepare_tokens(x) 363 | 364 | if not return_attention: 365 | for blk in self.blocks: 366 | x = blk(x) 367 | x = self.norm(x) 368 | if return_all_tokens: 369 | return x, None 370 | else: 371 | return x[:, 0], None 372 | else: 373 | for i, blk in enumerate(self.blocks): 374 | if i < len(self.blocks) - 1: 375 | x = blk(x) 376 | else: 377 | # return attention of the last block 378 | out = blk(x) 379 | with torch.no_grad(): 380 | attentions = blk(x, return_attention=True) # B nHead N N 381 | out = self.norm(out) 382 | # generate mask 383 | with torch.no_grad(): 384 | # keep only the output patch attention 385 | attentions = attentions[:, :, 0, 1:] # B nHead N 386 | # attentions = torch.mean(attentions, dim=1) # B N 387 | if return_all_tokens: 388 | return out, attentions 389 | else: 390 | return out[:, 0], attentions 391 | 392 | 393 | # input is a 4-d feature map with shape B C H W 394 | class THead_CNN(nn.Module): 395 | """ Vision 
Transformer """ 396 | def __init__(self, out_dim, featuremap_size=7, in_chans=2048, embed_dim=384, depth=12, 397 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., 398 | drop_path_rate=0., norm_layer=nn.LayerNorm, **kwargs): 399 | super().__init__() 400 | 401 | self.num_features = self.embed_dim = embed_dim 402 | self.num_patches = featuremap_size * featuremap_size 403 | 404 | self.patch_embed = nn.Conv2d(in_chans, embed_dim, kernel_size=1, stride=1) 405 | 406 | self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, embed_dim)) 407 | self.pos_drop = nn.Dropout(p=drop_rate) 408 | 409 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 410 | self.blocks = nn.ModuleList([ 411 | Block( 412 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 413 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) 414 | for i in range(depth)]) 415 | self.norm = norm_layer(embed_dim) 416 | 417 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 418 | 419 | # self.last_layer = nn.Linear(embed_dim, out_dim, bias=False) 420 | self.last_layer = nn.utils.weight_norm(nn.Linear(embed_dim, out_dim, bias=False)) 421 | self.last_layer.weight_g.data.fill_(1) 422 | 423 | trunc_normal_(self.pos_embed, std=.02) 424 | self.apply(self._init_weights) 425 | 426 | def _init_weights(self, m): 427 | if isinstance(m, nn.Linear): 428 | trunc_normal_(m.weight, std=.02) 429 | if isinstance(m, nn.Linear) and m.bias is not None: 430 | nn.init.constant_(m.bias, 0) 431 | elif isinstance(m, nn.LayerNorm): 432 | nn.init.constant_(m.bias, 0) 433 | nn.init.constant_(m.weight, 1.0) 434 | 435 | def interpolate_pos_encoding(self, x, w, h): 436 | npatch = x.shape[1] 437 | N = self.pos_embed.shape[1] 438 | if npatch == N and w == h: 439 | return self.pos_embed 440 | patch_pos_embed = self.pos_embed 441 | dim = x.shape[-1] 442 | # w0 = w // self.patch_embed.patch_size 443 | # h0 = h // self.patch_embed.patch_size 444 | # we add a small number to avoid floating point error in the interpolation 445 | # see discussion at https://github.com/facebookresearch/dino/issues/8 446 | w0, h0 = w + 0.1, h + 0.1 447 | patch_pos_embed = nn.functional.interpolate( 448 | patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), 449 | scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)), 450 | mode='bicubic', 451 | ) 452 | assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1] 453 | patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) 454 | return patch_pos_embed 455 | 456 | def prepare_tokens(self, x): 457 | B, nc, h, w = x.shape 458 | x = self.patch_embed(x) # reduce demision 459 | x_pool = self.avgpool(x).flatten(1) 460 | 461 | x = x.flatten(2).transpose(1, 2) # patch linear embedding B N C 462 | 463 | # add positional encoding to each token 464 | # x = x + self.pos_embed 465 | x = x + self.interpolate_pos_encoding(x, w, h) 466 | 467 | # return self.pos_drop(x) 468 | return x, x_pool 469 | 470 | def forward(self, x, query_token=None): 471 | x, x_pool = self.prepare_tokens(x) 472 | 473 | # add query token 474 | if query_token is not None: 475 | query_token = query_token.unsqueeze(1) # B C => B 1 C 476 | x = torch.cat((query_token, x), dim=1) 477 | 478 | for blk in self.blocks: 479 | x = blk(x) 480 | x = self.norm(x) 481 | if query_token is not None: 482 | x = x[:, 1:] 483 | x = torch.mean(x, dim=1) 484 | x = 
self.last_layer(x) # B out_dim 485 | return x, x_pool 486 | 487 | 488 | # input is a 4-d feature map with shape B N C 489 | class THead_ViT(nn.Module): 490 | """ Vision Transformer """ 491 | def __init__(self, out_dim, featuremap_size=14, embed_dim=384, depth=12, 492 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., 493 | drop_path_rate=0., norm_layer=nn.LayerNorm, **kwargs): 494 | super().__init__() 495 | 496 | self.num_features = self.embed_dim = embed_dim 497 | self.num_patches = featuremap_size * featuremap_size 498 | 499 | self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, embed_dim)) 500 | self.pos_drop = nn.Dropout(p=drop_rate) 501 | 502 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 503 | self.blocks = nn.ModuleList([ 504 | Block( 505 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 506 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) 507 | for i in range(depth)]) 508 | self.norm = norm_layer(embed_dim) 509 | 510 | # self.last_layer = nn.Linear(embed_dim, out_dim, bias=False) 511 | self.last_layer = nn.utils.weight_norm(nn.Linear(embed_dim, out_dim, bias=False)) 512 | self.last_layer.weight_g.data.fill_(1) 513 | 514 | trunc_normal_(self.pos_embed, std=.02) 515 | self.apply(self._init_weights) 516 | 517 | def _init_weights(self, m): 518 | if isinstance(m, nn.Linear): 519 | trunc_normal_(m.weight, std=.02) 520 | if isinstance(m, nn.Linear) and m.bias is not None: 521 | nn.init.constant_(m.bias, 0) 522 | elif isinstance(m, nn.LayerNorm): 523 | nn.init.constant_(m.bias, 0) 524 | nn.init.constant_(m.weight, 1.0) 525 | 526 | def interpolate_pos_encoding(self, x, w, h): 527 | npatch = x.shape[1] 528 | N = self.pos_embed.shape[1] 529 | if npatch == N and w == h: 530 | return self.pos_embed 531 | patch_pos_embed = self.pos_embed 532 | dim = x.shape[-1] 533 | # w0 = w // self.patch_embed.patch_size 534 | # h0 = h // self.patch_embed.patch_size 535 | # we add a small number to avoid floating point error in the interpolation 536 | # see discussion at https://github.com/facebookresearch/dino/issues/8 537 | w0, h0 = w + 0.1, h + 0.1 538 | patch_pos_embed = nn.functional.interpolate( 539 | patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), 540 | scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)), 541 | mode='bicubic', 542 | ) 543 | assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1] 544 | patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) 545 | return patch_pos_embed 546 | 547 | def prepare_tokens(self, x): 548 | B, N, C = x.shape 549 | size = int(math.sqrt(N)) 550 | 551 | # add positional encoding to each token 552 | # x = x + self.pos_embed 553 | x = x + self.interpolate_pos_encoding(x, size, size) 554 | 555 | # return self.pos_drop(x) 556 | return x 557 | 558 | def forward(self, x, query_token=None): 559 | x = self.prepare_tokens(x) 560 | 561 | # add query token 562 | if query_token is not None: 563 | query_token = query_token.unsqueeze(1) # B C => B 1 C 564 | x = torch.cat((query_token, x), dim=1) 565 | 566 | for blk in self.blocks: 567 | x = blk(x) 568 | x = self.norm(x) 569 | if query_token is not None: 570 | x = x[:, 1:] 571 | x = torch.mean(x, dim=1) 572 | x = self.last_layer(x) # B out_dim 573 | return x 574 | 575 | -------------------------------------------------------------------------------- 
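# --- Editor's sketch (assumption-labelled, not part of the original source) ---
# Usage sketch for the backbones defined in backbones/vision_transformer.py
# above; shapes assume patch_size=16 and 224x224 inputs.
import torch
from backbones.vision_transformer import vit_small

model = vit_small(patch_size=16)                           # embed_dim=384, depth=12
x = torch.randn(2, 3, 224, 224)
cls_feat, _ = model(x)                                     # (2, 384): [CLS] token of the final block
layers = model.get_intermediate_layers(x, n=4)             # list of 4 tensors, each (2, 197, 384)
probe_in = torch.cat([t[:, 0] for t in layers], dim=-1)    # (2, 1536) linear-probe features
# --- end sketch ---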
/main_mokd.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import argparse 3 | import os 4 | import sys 5 | import builtins 6 | import datetime 7 | import time 8 | import math 9 | import json 10 | from pathlib import Path 11 | from functools import partial 12 | 13 | import torch 14 | import torch.nn as nn 15 | import torch.backends.cudnn as cudnn 16 | # from torchvision import models as torchvision_models 17 | from torch.utils.tensorboard import SummaryWriter 18 | from torchvision import datasets 19 | 20 | from utils import utils 21 | from utils import checkpoint_io 22 | from utils import optimizers 23 | from augmentations.dino_augmentation import DINODataAugmentation 24 | import backbones.vision_transformer as vits 25 | import backbones.resnet as resnets 26 | from models.mokd import * 27 | 28 | method = "mokd" 29 | 30 | def get_args_parser(): 31 | parser = argparse.ArgumentParser(method, add_help=False) 32 | 33 | ################################# 34 | #### input and output parameters #### 35 | ################################# 36 | parser.add_argument('--data_path', default='/path/to/imagenet/train/', type=str, 37 | help='Please specify path to the ImageNet training data.') 38 | parser.add_argument('--output_dir', default=".", type=str, help='Path to save logs and checkpoints.') 39 | parser.add_argument('--saveckp_freq', default=10, type=int, help='Save checkpoint every x epochs.') 40 | parser.add_argument('--experiment', default='exp', type=str, help='experiment name') 41 | 42 | ################################# 43 | #### augmentation parameters #### 44 | ################################# 45 | # multi-crop parameters 46 | parser.add_argument('--input_size', default=224, type=int, help='input image size') 47 | parser.add_argument('--global_crops_scale', type=float, nargs='+', default=(0.4, 1.), 48 | help="""Scale range of the cropped image before resizing, relatively to the origin image. 49 | Used for large global view cropping.""") 50 | parser.add_argument('--local_crops_number', type=int, default=8, help="""Number of small 51 | local views to generate.""") 52 | parser.add_argument('--local_crops_scale', type=float, nargs='+', default=(0.05, 0.4), 53 | help="""Scale range of the cropped image before resizing, relatively to the origin image. 54 | Used for small local view cropping of multi-crop.""") 55 | 56 | ################################# 57 | ####model parameters #### 58 | ################################# 59 | parser.add_argument('--arch_cnn', default='resnet50', type=str, 60 | help="""Name of architecture to train. For quick experiments with ViTs""") 61 | parser.add_argument('--arch_vit', default='vit_small', type=str, 62 | help="""Name of architecture to train. For quick experiments with ViTs""") 63 | parser.add_argument('--out_dim', default=65536, type=int, help="""Dimensionality of 64 | the DINO head output. 
For complex and large datasets large values (like 65k) work well.""") 65 | parser.add_argument('--norm_last_layer', default=True, type=utils.bool_flag, 66 | help="""Whether or not to weight normalize the last layer of the DINO head..""") 67 | parser.add_argument('--momentum_teacher', default=0.996, type=float, help="""Base EMA 68 | parameter for teacher update.""") 69 | parser.add_argument('--use_bn_in_head', default=False, type=utils.bool_flag, 70 | help="Whether to use batch normalizations in projection head (Default: False)") 71 | # for ViTs 72 | parser.add_argument('--patch_size', default=16, type=int, help="""Size in pixels 73 | of input square patches - default 16.""") 74 | parser.add_argument('--drop_path_rate', type=float, default=0.1, help="stochastic depth rate") 75 | # for cross-distillation 76 | parser.add_argument("--lamda_c", default=1.0, type=float, help=""" weight for ct loss cnn.""") 77 | parser.add_argument("--lamda_t", default=1.0, type=float, help=""" weight for ct loss vit.""") 78 | 79 | ################################# 80 | #### optim parameters ### 81 | ################################# 82 | # training/pptimization parameters 83 | parser.add_argument('--batch_size_per_gpu', default=64, type=int, 84 | help='Per-GPU batch-size : number of distinct images loaded on one GPU.') 85 | parser.add_argument('--epochs', default=200, type=int, help='Number of epochs of training.') 86 | parser.add_argument('--optimizer', default='adamw', type=str, 87 | choices=['adamw', 'sgd', 'lars'], help="""Type of optimizer. We recommend using adamw with ViTs.""") 88 | parser.add_argument("--lr_cnn", default=0.3, type=float, help="""Learning rate at the end of 89 | linear warmup (highest LR used during training).""") 90 | parser.add_argument("--lr_vit", default=0.0005, type=float, help="""Learning rate at the end of 91 | linear warmup (highest LR used during training). """) 92 | parser.add_argument("--warmup_epochs", default=10, type=int, 93 | help="Number of epochs for the linear learning-rate warm up.") 94 | parser.add_argument('--min_lr', type=float, default=1e-6, help="""Target LR at the 95 | end of optimization. We use a cosine LR schedule with linear warmup.""") 96 | parser.add_argument('--weight_decay', type=float, default=0.04, help="""Initial value of the 97 | weight decay. With ViT, a smaller value at the beginning of training works well.""") 98 | parser.add_argument('--weight_decay_end', type=float, default=0.4, help="""Final value of the 99 | weight decay. We use a cosine schedule for WD and using a larger decay by 100 | the end of training improves performance for ViTs.""") 101 | parser.add_argument('--use_fp16', type=utils.bool_flag, default=True, help="""Whether or not 102 | to use half precision for training.""") 103 | parser.add_argument('--clip_grad_cnn', type=float, default=3.0, help="""Maximal parameter 104 | gradient norm if using gradient clipping. Clipping with norm .3 ~ 1.0 can 105 | help optimization for larger ViT architectures. 0 for disabling.""") 106 | parser.add_argument('--clip_grad_vit', type=float, default=0.0, help="""Maximal parameter 107 | gradient norm if using gradient clipping. Clipping with norm .3 ~ 1.0 can 108 | help optimization for larger ViT architectures. 0 for disabling.""") 109 | parser.add_argument('--freeze_last_layer', default=1, type=int, help="""Number of epochs 110 | during which we keep the output layer fixed. Typically doing so during 111 | the first epoch helps training. 
Try increasing this value if the loss does not decrease.""") 112 | # temperature teacher parameters 113 | parser.add_argument('--warmup_teacher_temp', default=0.04, type=float, 114 | help="""Initial value for the teacher temperature: 0.04 works well in most cases. 115 | Try decreasing it if the training loss does not decrease.""") 116 | parser.add_argument('--teacher_temp', default=0.04, type=float, help="""Final value (after linear warmup) 117 | of the teacher temperature. For most experiments, anything above 0.07 is unstable. We recommend 118 | starting with the default value of 0.04 and increase this slightly if needed.""") 119 | parser.add_argument('--warmup_teacher_temp_epochs_cnn', default=50, type=int, 120 | help='Number of warmup epochs for the teacher temperature (Default: 50).') 121 | parser.add_argument('--warmup_teacher_temp_epochs_vit', default=30, type=int, 122 | help='Number of warmup epochs for the teacher temperature (Default: 30).') 123 | 124 | ################################# 125 | #### dist parameters ### 126 | ################################# 127 | parser.add_argument('--seed', default=0, type=int, help='Random seed.') 128 | parser.add_argument('--rank', default=-1, type=int, 129 | help='node rank for distributed training') 130 | parser.add_argument('--num_workers', default=10, type=int, help='Number of data loading workers per GPU.') 131 | parser.add_argument("--dist_url", default="env://", type=str, help="url used to set up distributed training") 132 | parser.add_argument("--local_rank", default=0, type=int, help="Please ignore and do not set this argument.") 133 | 134 | return parser 135 | 136 | def train(args): 137 | 138 | ######################## init dist ######################## 139 | utils.init_distributed_mode(args) 140 | utils.fix_random_seeds(args.seed) 141 | print("git:\n {}\n".format(utils.get_sha())) 142 | print("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items()))) 143 | cudnn.benchmark = True 144 | 145 | if args.gpu != 0: 146 | def print_pass(*args): 147 | pass 148 | builtins.print = print_pass 149 | 150 | if args.gpu is not None: 151 | print("Use GPU: {} for training".format(args.gpu)) 152 | 153 | ######################## preparing data ... 
########################
154 |     transform = DINODataAugmentation(
155 |         args.global_crops_scale,
156 |         args.local_crops_scale,
157 |         args.local_crops_number,
158 |     )
159 |
160 |     dataset = datasets.ImageFolder(args.data_path, transform=transform)
161 |
162 |     sampler = torch.utils.data.DistributedSampler(dataset, shuffle=True)
163 |     data_loader = torch.utils.data.DataLoader(
164 |         dataset,
165 |         sampler=sampler,
166 |         batch_size=args.batch_size_per_gpu,
167 |         num_workers=args.num_workers,
168 |         pin_memory=True,
169 |         drop_last=True,
170 |     )
171 |     print(f"Loaded {len(dataset)} training images.")
172 |
173 |     ######################## building networks ... ########################
174 |     if args.arch_vit in vits.__dict__.keys():
175 |         student_vit = vits.__dict__[args.arch_vit](
176 |             patch_size=args.patch_size,
177 |             drop_path_rate=args.drop_path_rate,  # stochastic depth
178 |         )
179 |         teacher_vit = vits.__dict__[args.arch_vit](patch_size=args.patch_size)
180 |         embed_dim_vit = student_vit.embed_dim
181 |     else:
182 |         print(f"Unknown architecture: {args.arch_vit}")
183 |
184 |     if args.arch_cnn in resnets.__dict__.keys():
185 |         student_cnn = resnets.__dict__[args.arch_cnn]()
186 |         teacher_cnn = resnets.__dict__[args.arch_cnn]()
187 |         embed_dim_cnn = student_cnn.fc.weight.shape[1]
188 |     else:
189 |         print(f"Unknown architecture: {args.arch_cnn}")
190 |
191 |     # multi-crop wrapper handles forward with inputs of different resolutions
192 |     # use_bn = True if args.cnn_checkpoint is None else False
193 |     student_cnn = CNNStudentWrapper(student_cnn, DINOHead(
194 |         embed_dim_cnn,
195 |         args.out_dim,
196 |         use_bn=False,
197 |         norm_last_layer=True),
198 |         vits.THead_CNN(out_dim=args.out_dim, featuremap_size=args.input_size // 32, in_chans=embed_dim_cnn,
199 |             embed_dim=embed_dim_vit, depth=3, num_heads=6, mlp_ratio=4,
200 |             qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)),
201 |     )
202 |
203 |     teacher_cnn = CNNTeacherWrapper(
204 |         teacher_cnn,
205 |         DINOHead(embed_dim_cnn, args.out_dim, False),
206 |         vits.THead_CNN(out_dim=args.out_dim, featuremap_size=args.input_size // 32, in_chans=embed_dim_cnn,
207 |             embed_dim=embed_dim_vit, depth=3, num_heads=6, mlp_ratio=4,
208 |             qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)),
209 |         args.local_crops_number,
210 |     )
211 |
212 |     student_vit = ViTStudentWrapper(student_vit, DINOHead(
213 |         embed_dim_vit,
214 |         args.out_dim,
215 |         use_bn=False,
216 |         norm_last_layer=False),
217 |         vits.THead_ViT(out_dim=args.out_dim, featuremap_size=14,
218 |             embed_dim=embed_dim_vit, depth=3, num_heads=6, mlp_ratio=4,
219 |             qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)),
220 |     )
221 |     teacher_vit = ViTTeacherWrapper(
222 |         teacher_vit,
223 |         DINOHead(embed_dim_vit, args.out_dim, False),
224 |         vits.THead_ViT(out_dim=args.out_dim, featuremap_size=14,
225 |             embed_dim=embed_dim_vit, depth=3, num_heads=6, mlp_ratio=4,
226 |             qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)),
227 |         args.local_crops_number,
228 |     )
229 |
230 |     # move networks to gpu
231 |     student_cnn, teacher_cnn = student_cnn.cuda(), teacher_cnn.cuda()
232 |     # synchronize batch norms
233 |     if utils.has_batchnorms(student_cnn):
234 |         student_cnn = nn.SyncBatchNorm.convert_sync_batchnorm(student_cnn)
235 |         teacher_cnn = nn.SyncBatchNorm.convert_sync_batchnorm(teacher_cnn)
236 |         # use DDP wrapper to have synchro batch norms working...
237 | teacher_cnn = nn.parallel.DistributedDataParallel(teacher_cnn, device_ids=[args.gpu]) 238 | teacher_cnn_without_ddp = teacher_cnn.module 239 | else: 240 | teacher_cnn_without_ddp = teacher_cnn 241 | student_cnn = nn.parallel.DistributedDataParallel(student_cnn, device_ids=[args.gpu], find_unused_parameters=False) 242 | teacher_cnn_without_ddp.load_state_dict(student_cnn.module.state_dict(), strict=False) 243 | for p in teacher_cnn.parameters(): 244 | p.requires_grad = False 245 | print(f"CNN student and Teacher are built: they are both {args.arch_cnn} network.") 246 | 247 | student_vit, teacher_vit = student_vit.cuda(), teacher_vit.cuda() 248 | # synchronize batch norms 249 | if utils.has_batchnorms(student_vit): 250 | student_vit = nn.SyncBatchNorm.convert_sync_batchnorm(student_vit) 251 | teacher_vit = nn.SyncBatchNorm.convert_sync_batchnorm(teacher_vit) 252 | # use DDP wrapper to have synchro batch norms working... 253 | teacher_vit = nn.parallel.DistributedDataParallel(teacher_vit, device_ids=[args.gpu]) 254 | teacher_vit_without_ddp = teacher_vit.module 255 | else: 256 | teacher_vit_without_ddp = teacher_vit 257 | student_vit = nn.parallel.DistributedDataParallel(student_vit, device_ids=[args.gpu], find_unused_parameters=False) 258 | teacher_vit_without_ddp.load_state_dict(student_vit.module.state_dict(), strict=False) 259 | for p in teacher_vit.parameters(): 260 | p.requires_grad = False 261 | print(f"ViT student and Teacher are built: they are both {args.arch_vit} network.") 262 | 263 | ######################## preparing loss ... ######################## 264 | loss_cnn_fn = DINOLoss( 265 | args.out_dim, 266 | args.local_crops_number + 2, 267 | args.warmup_teacher_temp, 268 | args.teacher_temp, 269 | args.warmup_teacher_temp_epochs_cnn, # args.warmup_teacher_temp_epochs 50 270 | args.epochs, 271 | ).cuda() 272 | 273 | loss_vit_fn = DINOLoss( 274 | args.out_dim, 275 | args.local_crops_number + 2, 276 | args.warmup_teacher_temp, 277 | args.teacher_temp, 278 | args.warmup_teacher_temp_epochs_vit, # args.warmup_teacher_temp_epochs 30 279 | args.epochs, 280 | ).cuda() 281 | 282 | loss_cnn_ct_fn = DINOLoss( 283 | args.out_dim, 284 | args.local_crops_number + 2, 285 | args.warmup_teacher_temp, 286 | args.teacher_temp, 287 | args.warmup_teacher_temp_epochs_vit, # args.warmup_teacher_temp_epochs 30 288 | args.epochs, 289 | ).cuda() 290 | 291 | loss_vit_ct_fn = DINOLoss( 292 | args.out_dim, 293 | args.local_crops_number + 2, 294 | args.warmup_teacher_temp, 295 | args.teacher_temp, 296 | args.warmup_teacher_temp_epochs_vit, # args.warmup_teacher_temp_epochs 30 297 | args.epochs, 298 | ).cuda() 299 | 300 | loss_cnn_thead_fn = DINOLoss( 301 | args.out_dim, 302 | args.local_crops_number + 2, 303 | args.warmup_teacher_temp, 304 | args.teacher_temp, 305 | args.warmup_teacher_temp_epochs_cnn, # args.warmup_teacher_temp_epochs 50 306 | args.epochs, 307 | ).cuda() 308 | 309 | loss_vit_thead_fn = DINOLoss( 310 | args.out_dim, 311 | args.local_crops_number + 2, 312 | args.warmup_teacher_temp, 313 | args.teacher_temp, 314 | args.warmup_teacher_temp_epochs_vit, # args.warmup_teacher_temp_epochs 30 315 | args.epochs, 316 | ).cuda() 317 | 318 | loss_search_ct_fn = CTSearchLoss( 319 | args.warmup_teacher_temp, 320 | args.teacher_temp, 321 | args.warmup_teacher_temp_epochs_vit, 322 | args.epochs,) 323 | 324 | ######################## preparing optimizer ... 
######################## 325 | params_groups_cnn = utils.get_params_groups2(student_cnn, head_name="transhead") 326 | if args.optimizer == "adamw": 327 | optimizer_cnn = torch.optim.AdamW(params_groups_cnn) # to use with ViTs 328 | elif args.optimizer == "sgd": 329 | optimizer_cnn = torch.optim.SGD(params_groups_cnn, lr=0, momentum=0.9) # lr is set by scheduler 330 | elif args.optimizer == "lars": 331 | optimizer_cnn = optimizers.LARS(params_groups_cnn) # to use with convnet and large batches 332 | 333 | params_groups_vit = utils.get_params_groups(student_vit) 334 | optimizer_vit = torch.optim.AdamW(params_groups_vit) # to use with ViTs 335 | 336 | # for mixed precision training 337 | fp16_scaler_cnn = None 338 | fp16_scaler_vit = None 339 | if args.use_fp16: 340 | fp16_scaler_cnn = torch.cuda.amp.GradScaler() 341 | fp16_scaler_vit = torch.cuda.amp.GradScaler() 342 | 343 | ######################## init schedulers ... ######################## 344 | lr_schedule_cnn = utils.cosine_scheduler( 345 | args.lr_cnn * (args.batch_size_per_gpu * utils.get_world_size()) / 256., 346 | 0.0048, # args.min_lr 347 | args.epochs, len(data_loader), 348 | warmup_epochs=args.warmup_epochs, 349 | ) 350 | wd_schedule_cnn = utils.cosine_scheduler( 351 | 1e-4, # args.weight_decay 352 | 1e-4, # args.weight_decay_end 353 | args.epochs, len(data_loader), 354 | ) 355 | 356 | lr_schedule_vit = utils.cosine_scheduler( 357 | args.lr_vit * (args.batch_size_per_gpu * utils.get_world_size()) / 256., 358 | 1e-5, # args.min_lr 359 | args.epochs, len(data_loader), 360 | warmup_epochs=args.warmup_epochs, 361 | ) 362 | wd_schedule_vit = utils.cosine_scheduler( 363 | 0.04, # args.weight_decay 364 | 0.4, # args.weight_decay_end 365 | args.epochs, len(data_loader), 366 | ) 367 | 368 | # momentum parameter is increased to 1. during training with a cosine schedule 369 | momentum_schedule = utils.cosine_scheduler(args.momentum_teacher, 1, args.epochs, len(data_loader)) 370 | print(f"Loss, optimizer and schedulers ready.") 371 | 372 | summary_writer = SummaryWriter(log_dir=os.path.join(args.output_dir, 373 | "tb", "{}_{}.{}_pretrain_{}".format(method, args.arch_cnn, args.arch_vit, args.experiment))) if args.rank == 0 else None 374 | 375 | to_restore = {"epoch": 0} 376 | checkpoint_io.restart_from_checkpoint( 377 | os.path.join(args.output_dir, "{}_{}.{}_pretrain_{}_temp.pth".format(method, args.arch_cnn, 378 | args.arch_vit, args.experiment)), 379 | run_variables=to_restore, 380 | student_cnn=student_cnn, 381 | teacher_cnn=teacher_cnn, 382 | student_vit=student_vit, 383 | teacher_vit=teacher_vit, 384 | optimizer_cnn=optimizer_cnn, 385 | optimizer_vit=optimizer_vit, 386 | fp16_scaler_cnn=fp16_scaler_cnn, 387 | fp16_scaler_vit=fp16_scaler_vit, 388 | loss_cnn_fn=loss_cnn_fn, 389 | loss_vit_fn=loss_vit_fn, 390 | loss_cnn_ct_fn=loss_cnn_ct_fn, 391 | loss_vit_ct_fn=loss_vit_ct_fn, 392 | loss_cnn_thead_fn=loss_cnn_thead_fn, 393 | loss_vit_thead_fn=loss_vit_thead_fn, 394 | ) 395 | start_epoch = to_restore["epoch"] 396 | 397 | ######################## start training ######################## 398 | start_time = time.time() 399 | print("Starting {} training !".format(method)) 400 | for epoch in range(start_epoch, args.epochs): 401 | data_loader.sampler.set_epoch(epoch) 402 | 403 | ######################## training one epoch of DINO ... 
######################## 404 | train_stats = train_one_epoch( 405 | student_cnn, teacher_cnn, teacher_cnn_without_ddp, 406 | student_vit, teacher_vit, teacher_vit_without_ddp, 407 | loss_cnn_fn, loss_vit_fn, loss_cnn_ct_fn, loss_vit_ct_fn, loss_cnn_thead_fn, loss_vit_thead_fn, loss_search_ct_fn, 408 | data_loader, optimizer_cnn, lr_schedule_cnn, wd_schedule_cnn, 409 | optimizer_vit, lr_schedule_vit, wd_schedule_vit, 410 | momentum_schedule, epoch, 411 | fp16_scaler_cnn, fp16_scaler_vit, 412 | summary_writer, args) 413 | 414 | ########################writing logs ... ######################## 415 | save_dict = { 416 | 'student_cnn': student_cnn.state_dict(), 417 | 'teacher_cnn': teacher_cnn.state_dict(), 418 | 'student_vit': student_vit.state_dict(), 419 | 'teacher_vit': teacher_vit.state_dict(), 420 | 'optimizer_cnn': optimizer_cnn.state_dict(), 421 | 'optimizer_vit': optimizer_vit.state_dict(), 422 | 'epoch': epoch + 1, 423 | 'arch_cnn': args.arch_cnn, 424 | 'arch_vit': args.arch_vit, 425 | 'loss_cnn_fn': loss_cnn_fn.state_dict(), 426 | 'loss_vit_fn': loss_vit_fn.state_dict(), 427 | 'loss_cnn_ct_fn': loss_cnn_ct_fn.state_dict(), 428 | 'loss_vit_ct_fn': loss_vit_ct_fn.state_dict(), 429 | 'loss_cnn_thead_fn': loss_cnn_thead_fn.state_dict(), 430 | 'loss_vit_thead_fn': loss_vit_thead_fn.state_dict(), 431 | } 432 | if fp16_scaler_cnn is not None: 433 | save_dict['fp16_scaler_cnn'] = fp16_scaler_cnn.state_dict() 434 | save_dict['fp16_scaler_vit'] = fp16_scaler_vit.state_dict() 435 | 436 | utils.save_on_master(save_dict, os.path.join(args.output_dir, 437 | "{}_{}.{}_pretrain_{}_temp.pth".format(method, args.arch_cnn, args.arch_vit, args.experiment))) 438 | 439 | if (args.saveckp_freq and epoch % args.saveckp_freq == 0) or epoch == args.epochs - 1: 440 | utils.save_on_master(save_dict, os.path.join(args.output_dir, 441 | "{}_{}.{}_pretrain_{}_{:04d}.pth".format(method, args.arch_cnn, args.arch_vit, args.experiment, epoch))) 442 | 443 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 444 | 'epoch': epoch} 445 | 446 | if utils.is_main_process(): 447 | with (Path(args.output_dir) / "{}_{}.{}_pretrain_{}_log.txt".format(method, args.arch_cnn, args.arch_vit, args.experiment)).open("a") as f: 448 | f.write(json.dumps(log_stats) + "\n") 449 | 450 | if args.rank == 0: 451 | summary_writer.close() 452 | total_time = time.time() - start_time 453 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 454 | print('Training time {}'.format(total_time_str)) 455 | 456 | def train_one_epoch(student_cnn, teacher_cnn, teacher_cnn_without_ddp, 457 | student_vit, teacher_vit, teacher_vit_without_ddp, 458 | loss_cnn_fn, loss_vit_fn, loss_cnn_ct_fn, loss_vit_ct_fn, loss_cnn_thead_fn, loss_vit_thead_fn, loss_search_ct_fn, 459 | data_loader, optimizer_cnn, lr_schedule_cnn, wd_schedule_cnn, 460 | optimizer_vit, lr_schedule_vit, wd_schedule_vit, 461 | momentum_schedule, epoch, 462 | fp16_scaler_cnn, fp16_scaler_vit, 463 | summary_writer, args): 464 | metric_logger = utils.MetricLogger(delimiter=" ") 465 | header = 'Epoch: [{}/{}]'.format(epoch, args.epochs) 466 | iters_per_epoch = len(data_loader) 467 | for it, (images, _) in enumerate(metric_logger.log_every(data_loader, 10, header)): 468 | # update weight decay and learning rate 469 | it = len(data_loader) * epoch + it 470 | for i, param_group in enumerate(optimizer_cnn.param_groups): 471 | param_group["lr"] = lr_schedule_cnn[it] 472 | if i == 0: # only the first group is regularized 473 | param_group["weight_decay"] = wd_schedule_cnn[it] 
474 | if i == 2: # transhead 475 | param_group["lr"] = lr_schedule_vit[it] 476 | param_group["weight_decay"] = wd_schedule_vit[it] 477 | 478 | for i, param_group in enumerate(optimizer_vit.param_groups): 479 | param_group["lr"] = lr_schedule_vit[it] 480 | if i == 0: # only the first group is regularized 481 | param_group["weight_decay"] = wd_schedule_vit[it] 482 | 483 | # move images to gpu 484 | images = [im.cuda(non_blocking=True) for im in images] 485 | 486 | # student cnn 487 | with torch.cuda.amp.autocast(fp16_scaler_cnn is not None): 488 | s_cnn_m, s_cnn_t_g, s_cnn_out_g = student_cnn(images[:2]) 489 | s_cnn_m_l, s_cnn_t_l, s_cnn_out_l = student_cnn(images[2:]) 490 | s_cnn_m = torch.cat([s_cnn_m, s_cnn_m_l], dim=0) 491 | s_cnn_t = torch.cat([s_cnn_t_g, s_cnn_t_l], dim=0) 492 | 493 | # student vit 494 | with torch.cuda.amp.autocast(fp16_scaler_vit is not None): 495 | s_vit_m, s_vit_t_g, s_vit_out_g = student_vit(images[:2]) 496 | s_vit_m_l, s_vit_t_l, s_vit_out_l = student_vit(images[2:]) 497 | s_vit_m = torch.cat([s_vit_m, s_vit_m_l], dim=0) 498 | s_vit_t = torch.cat([s_vit_t_g, s_vit_t_l], dim=0) 499 | 500 | # teacher cnn 501 | with torch.cuda.amp.autocast(fp16_scaler_cnn is not None): 502 | t_cnn_m, t_cnn_t, t_cnn_search_l, _ = teacher_cnn(images[:2], local_token=s_vit_out_l.detach()) 503 | 504 | # teacher vit 505 | with torch.cuda.amp.autocast(fp16_scaler_vit is not None): 506 | t_vit_m, t_vit_t, t_vit_search_l, _ = teacher_vit(images[:2], local_token=s_cnn_out_l.detach()) 507 | 508 | # loss 509 | with torch.cuda.amp.autocast(fp16_scaler_cnn is not None): 510 | loss_cnn_m = loss_cnn_fn(s_cnn_m, t_cnn_m, epoch) 511 | loss_cnn_t = loss_cnn_thead_fn(s_cnn_t, t_cnn_t, epoch) 512 | loss_ct_cnn = loss_cnn_ct_fn(s_cnn_m, t_vit_m.detach(), epoch) 513 | loss_ct_search_cnn = loss_search_ct_fn(s_cnn_t_l, t_vit_search_l.detach(), epoch) 514 | 515 | loss_cnn_total = loss_cnn_m + loss_cnn_t + args.lamda_c * (loss_ct_cnn + loss_ct_search_cnn) 516 | 517 | with torch.cuda.amp.autocast(fp16_scaler_vit is not None): 518 | loss_vit_m = loss_vit_fn(s_vit_m, t_vit_m, epoch) 519 | loss_vit_t = loss_vit_thead_fn(s_vit_t, t_vit_t, epoch) 520 | loss_ct_vit = loss_vit_ct_fn(s_vit_m, t_cnn_m.detach(), epoch) 521 | loss_ct_search_vit = loss_search_ct_fn(s_vit_t_l, t_cnn_search_l.detach(), epoch) 522 | 523 | loss_vit_total = loss_vit_m + loss_vit_t + args.lamda_t * (loss_ct_vit + loss_ct_search_vit) 524 | 525 | if not math.isfinite(loss_cnn_total.item()): 526 | print("Loss is {}, stopping training".format(loss_cnn_total.item()), force=True) 527 | sys.exit(1) 528 | if not math.isfinite(loss_vit_total.item()): 529 | print("Loss is {}, stopping training".format(loss_vit_total.item()), force=True) 530 | sys.exit(1) 531 | 532 | optimizer_cnn.zero_grad() 533 | param_norms = None 534 | if fp16_scaler_cnn is None: 535 | loss_cnn_total.backward() 536 | if args.clip_grad_cnn: 537 | param_norms = utils.clip_gradients(student_cnn, args.clip_grad_cnn) 538 | utils.cancel_gradients_last_layer(epoch, student_cnn, args.freeze_last_layer) 539 | optimizer_cnn.step() 540 | else: 541 | fp16_scaler_cnn.scale(loss_cnn_total).backward() # retain_graph=True 542 | if args.clip_grad_cnn: 543 | fp16_scaler_cnn.unscale_(optimizer_cnn) # unscale the gradients of optimizer's assigned params in-place 544 | param_norms = utils.clip_gradients(student_cnn, args.clip_grad_cnn) 545 | utils.cancel_gradients_last_layer(epoch, student_cnn, args.freeze_last_layer) 546 | fp16_scaler_cnn.step(optimizer_cnn) 547 | fp16_scaler_cnn.update() 548 | 549 | # 
EMA update for the cnn teacher
550 |         with torch.no_grad():
551 |             m = momentum_schedule[it]
552 |             for param_q, param_k in zip(student_cnn.module.parameters(), teacher_cnn_without_ddp.parameters()):
553 |                 param_k.data.mul_(m).add_((1 - m) * param_q.detach().data)
554 |
555 |         optimizer_vit.zero_grad()
556 |         param_norms = None
557 |         if fp16_scaler_vit is None:
558 |             loss_vit_total.backward()
559 |             if args.clip_grad_vit:
560 |                 param_norms = utils.clip_gradients(student_vit, args.clip_grad_vit)
561 |             utils.cancel_gradients_last_layer(epoch, student_vit, args.freeze_last_layer)
562 |             optimizer_vit.step()
563 |         else:
564 |             fp16_scaler_vit.scale(loss_vit_total).backward()
565 |             if args.clip_grad_vit:
566 |                 fp16_scaler_vit.unscale_(optimizer_vit)  # unscale the gradients of optimizer's assigned params in-place
567 |                 param_norms = utils.clip_gradients(student_vit, args.clip_grad_vit)
568 |             utils.cancel_gradients_last_layer(epoch, student_vit, args.freeze_last_layer)
569 |             fp16_scaler_vit.step(optimizer_vit)
570 |             fp16_scaler_vit.update()
571 |
572 |         # EMA update for the vit teacher
573 |         with torch.no_grad():
574 |             m = momentum_schedule[it]
575 |             for param_q, param_k in zip(student_vit.module.parameters(), teacher_vit_without_ddp.parameters()):
576 |                 param_k.data.mul_(m).add_((1 - m) * param_q.detach().data)
577 |
578 |         # logging
579 |         if args.rank == 0:
580 |             summary_writer.add_scalar("loss_cnn", loss_cnn_total.item(), it)
581 |             summary_writer.add_scalar("loss_vit", loss_vit_total.item(), it)
582 |             summary_writer.add_scalar("lr_cnn", optimizer_cnn.param_groups[0]["lr"], it)
583 |             summary_writer.add_scalar("lr_vit", optimizer_vit.param_groups[0]["lr"], it)
584 |
585 |         torch.cuda.synchronize()
586 |         metric_logger.update(loss_cnn=loss_cnn_m.item())
587 |         metric_logger.update(loss_vit=loss_vit_m.item())
588 |
589 |     metric_logger.synchronize_between_processes()
590 |     print("Averaged stats:", metric_logger)
591 |     return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
592 |
593 | if __name__ == '__main__':
594 |     parser = argparse.ArgumentParser(method, parents=[get_args_parser()])
595 |     args = parser.parse_args()
596 |
597 |     Path(args.output_dir).mkdir(parents=True, exist_ok=True)
598 |
599 |     train(args)
600 |
--------------------------------------------------------------------------------
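
Note on the teacher update in main_mokd.py: the two teacher networks are never touched by the optimizers; each one tracks its student through an exponential moving average (EMA) whose momentum follows a cosine schedule from --momentum_teacher toward 1.0 (see momentum_schedule and the two torch.no_grad() blocks in train_one_epoch). The sketch below isolates that mechanism for readers who want to experiment with it on its own. It is a minimal illustration, not code from this repository: cosine_momentum_schedule and ema_update are hypothetical helper names, and the toy nn.Linear modules stand in for the real student/teacher wrappers.

import math
import torch
import torch.nn as nn

def cosine_momentum_schedule(base_m, final_m, total_steps):
    # Momentum rises from base_m to final_m along a cosine curve,
    # mirroring how the script schedules the teacher momentum per iteration.
    return [final_m - 0.5 * (final_m - base_m) * (1 + math.cos(math.pi * t / total_steps))
            for t in range(total_steps)]

@torch.no_grad()
def ema_update(student, teacher, m):
    # teacher <- m * teacher + (1 - m) * student, parameter by parameter.
    for p_s, p_t in zip(student.parameters(), teacher.parameters()):
        p_t.data.mul_(m).add_(p_s.detach().data, alpha=1 - m)

# Toy usage: the teacher starts as a copy of the student and receives no gradients.
student = nn.Linear(8, 4)
teacher = nn.Linear(8, 4)
teacher.load_state_dict(student.state_dict())
for p in teacher.parameters():
    p.requires_grad = False

schedule = cosine_momentum_schedule(0.996, 1.0, total_steps=100)
for step, m in enumerate(schedule):
    # ... an optimizer step on the student would happen here ...
    ema_update(student, teacher, m)

Disabling requires_grad on the teacher, as the script also does, keeps the EMA copy out of the autograd graph while still allowing it to be stored in checkpoints.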