├── util ├── __init__.py ├── lr_sched.py ├── datasets.py ├── model_ema.py ├── lr_decay.py ├── pos_embed.py └── misc.py ├── attention_transfer.png ├── environment.yml ├── README.md ├── models_vit.py ├── submitit_finetune.py ├── engine_finetune.py ├── models_dual_vit.py ├── LICENSE └── main_finetune.py /util/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /attention_transfer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alexlioralexli/attention-transfer/HEAD/attention_transfer.png -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: attn-transfer 2 | channels: 3 | - pytorch 4 | - nvidia 5 | dependencies: 6 | - python=3.8 7 | - torchvision 8 | - timm 9 | - torchaudio 10 | - pytorch-cuda=12.1 11 | - pytorch 12 | - tensorboard -------------------------------------------------------------------------------- /util/lr_sched.py: -------------------------------------------------------------------------------- 1 | # This source code is licensed under the license found in the 2 | # LICENSE file in the root directory of this source tree. 3 | # -------------------------------------------------------- 4 | # References: 5 | # MAE: https://github.com/facebookresearch/mae 6 | # -------------------------------------------------------- 7 | import math 8 | 9 | def adjust_learning_rate(optimizer, epoch, args): 10 | """Decay the learning rate with half-cycle cosine after warmup""" 11 | if epoch < args.warmup_epochs: 12 | lr = args.lr * epoch / args.warmup_epochs 13 | else: 14 | lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \ 15 | (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs))) 16 | for param_group in optimizer.param_groups: 17 | if "lr_scale" in param_group: 18 | param_group["lr"] = lr * param_group["lr_scale"] 19 | else: 20 | param_group["lr"] = lr 21 | return lr 22 | -------------------------------------------------------------------------------- /util/datasets.py: -------------------------------------------------------------------------------- 1 | # This source code is licensed under the license found in the 2 | # LICENSE file in the root directory of this source tree. 
3 | # -------------------------------------------------------- 4 | # References: 5 | # MAE: https://github.com/facebookresearch/mae 6 | # DeiT: https://github.com/facebookresearch/deit 7 | # -------------------------------------------------------- 8 | 9 | import os 10 | import PIL 11 | 12 | from torchvision import datasets, transforms 13 | 14 | from timm.data import create_transform 15 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 16 | 17 | def build_dataset(is_train, args): 18 | transform = build_transform(is_train, args) 19 | 20 | if args.dataset_name == 'imagenet': 21 | root = os.path.join(args.data_path, 'train' if is_train else 'val') 22 | dataset = datasets.ImageFolder(root, transform=transform) 23 | else: 24 | raise NotImplementedError 25 | 26 | print(dataset) 27 | 28 | return dataset 29 | 30 | 31 | def build_transform(is_train, args): 32 | mean = IMAGENET_DEFAULT_MEAN 33 | std = IMAGENET_DEFAULT_STD 34 | # train transform 35 | if is_train: 36 | # this should always dispatch to transforms_imagenet_train 37 | transform = create_transform( 38 | input_size=args.input_size, 39 | is_training=True, 40 | color_jitter=args.color_jitter, 41 | auto_augment=args.aa, 42 | interpolation='bicubic', 43 | re_prob=args.reprob, 44 | re_mode=args.remode, 45 | re_count=args.recount, 46 | mean=mean, 47 | std=std, 48 | ) 49 | return transform 50 | 51 | # eval transform 52 | t = [] 53 | if args.input_size <= 224: 54 | crop_pct = 224 / 256 55 | else: 56 | crop_pct = 1.0 57 | size = int(args.input_size / crop_pct) 58 | t.append( 59 | transforms.Resize(size, interpolation=PIL.Image.BICUBIC), # to maintain same ratio w.r.t. 224 images 60 | ) 61 | t.append(transforms.CenterCrop(args.input_size)) 62 | 63 | t.append(transforms.ToTensor()) 64 | t.append(transforms.Normalize(mean, std)) 65 | return transforms.Compose(t) 66 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Attention Transfer: A PyTorch Implementation 2 | 3 |

4 | ![attention transfer](attention_transfer.png)
5 | 
6 | 7 | 8 | This is a PyTorch/GPU reimplementation of the NeurIPS 2024 paper [On the Surprising Effectiveness of Attention Transfer for Vision Transformers](https://arxiv.org/abs/2411.09702): 9 | ``` 10 | @inproceedings{AttentionTransfer2024, 11 | title = {On the Surprising Effectiveness of Attention Transfer for Vision Transformers}, 12 | author = {Li, Alexander Cong and Tian, Yuandong and Chen, Beidi and Pathak, Deepak and Chen, Xinlei}, 13 | booktitle = {The Thirty-eighth Annual Conference on Neural Information Processing Systems}, 14 | year = {2024} 15 | } 16 | ``` 17 | 18 | * The original implementation was in Jax+TPU. This reimplementation is in PyTorch+GPU. 19 | 20 | * This repo is a modification of the [MAE repo](https://github.com/facebookresearch/mae). Refer to that repo for detailed installation and setup. 21 | 22 | ### Installation 23 | ``` 24 | conda env create -f environment.yml 25 | ``` 26 | 27 | ### Training student with pre-trained teacher 28 | Obtain pre-trained MAE checkpoints from [here](https://github.com/facebookresearch/mae). 29 | 30 | **Attention Distillation** 31 | 32 | To train with multi-node distributed training, run the following on 8 nodes with 8 GPUs each: 33 | ``` 34 | python submitit_finetune.py \ 35 | --job_dir ${JOB_DIR} \ 36 | --nodes 8 \ 37 | --batch_size 32 \ 38 | --model dual_vit_large_patch16 --mode distill --end_layer -6 --atd_weight 3.0 \ 39 | --finetune mae_pretrain_vit_large.pth --resume allow \ 40 | --epochs 200 --ema 0.9999 \ 41 | --blr 1e-4 --layer_decay 1 --beta2 0.95 --warmup_epochs 20 \ 42 | --weight_decay 0.3 --drop_path 0.2 --reprob 0.25 --mixup 0.8 --cutmix 1.0 \ 43 | --dist_eval --data_path ${IMAGENET_PATH} 44 | ``` 45 | 46 | **Attention Copy** 47 | ``` 48 | python submitit_finetune.py \ 49 | --job_dir ${JOB_DIR} \ 50 | --nodes 8 \ 51 | --batch_size 32 \ 52 | --model dual_vit_large_patch16 --mode copy \ 53 | --finetune mae_pretrain_vit_large.pth --resume allow \ 54 | --epochs 100 --ema 0.9999 \ 55 | --blr 1e-3 --min_lr 2e-3 --layer_decay 0.75 --beta2 0.999 \ 56 | --weight_decay 0.05 --drop_path 0 --reprob 0.25 --mixup 0.8 --cutmix 1.0 \ 57 | --dist_eval --data_path ${IMAGENET_PATH} 58 | ``` 59 | 60 | 61 | 62 | 63 | 64 | 65 | ### License 66 | 67 | This project is under the CC-BY-NC 4.0 license. See [LICENSE](LICENSE) for details. 68 | -------------------------------------------------------------------------------- /util/model_ema.py: -------------------------------------------------------------------------------- 1 | # from Xinlei Chen 2 | from copy import deepcopy 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | class ModelEmaV2(nn.Module): 9 | """Model Exponential Moving Average V2 10 | 11 | Keep a moving average of everything in the model state_dict (parameters and buffers). 12 | V2 of this module is simpler, it does not match params/buffers based on name but simply 13 | iterates in order. It works with torchscript (JIT of full model). 14 | 15 | This is intended to allow functionality like 16 | https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage 17 | 18 | A smoothed version of the weights is necessary for some training schemes to perform well. 19 | E.g. Google's hyper-params for training MNASNet, MobileNet-V3, EfficientNet, etc that use 20 | RMSprop with a short 2.4-3 epoch decay period and slow LR decay rate of .96-.99 requires EMA 21 | smoothing of weights to match results. Pay attention to the decay constant you are using 22 | relative to your update count per epoch. 
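(As a rough guide, the EMA half-life is ln(2) / (1 - decay) updates, i.e. about 6.9k optimizer steps at decay=0.9999.)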
23 | 24 | To keep EMA from using GPU resources, set device='cpu'. This will save a bit of memory but 25 | disable validation of the EMA weights. Validation will have to be done manually in a separate 26 | process, or after the training stops converging. 27 | 28 | This class is sensitive where it is initialized in the sequence of model init, 29 | GPU assignment and distributed training wrappers. 30 | """ 31 | 32 | def __init__(self, model, decay=0.9999, device=None): 33 | super(ModelEmaV2, self).__init__() 34 | # make a copy of the model for accumulating moving average of weights 35 | self.module = deepcopy(model) 36 | self.module.eval() 37 | self.decay = decay 38 | self.device = device # perform ema on different device from model if set 39 | if self.device is not None: 40 | self.module.to(device=device) 41 | 42 | def _update(self, model, update_fn): 43 | with torch.no_grad(): 44 | for ema_v, model_v in zip( 45 | self.module.state_dict().values(), model.state_dict().values() 46 | ): 47 | if self.device is not None: 48 | model_v = model_v.to(device=self.device) 49 | ema_v.copy_(update_fn(ema_v, model_v)) 50 | 51 | def update(self, model): 52 | self._update( 53 | model, update_fn=lambda e, m: self.decay * e + (1.0 - self.decay) * m 54 | ) 55 | 56 | def set(self, model): 57 | self._update(model, update_fn=lambda e, m: m) 58 | -------------------------------------------------------------------------------- /util/lr_decay.py: -------------------------------------------------------------------------------- 1 | # This source code is licensed under the license found in the 2 | # LICENSE file in the root directory of this source tree. 3 | # -------------------------------------------------------- 4 | # References: 5 | # MAE: https://github.com/facebookresearch/mae 6 | # ELECTRA https://github.com/google-research/electra 7 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 8 | # -------------------------------------------------------- 9 | 10 | import models_dual_vit 11 | 12 | def param_groups_lrd(model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=.75): 13 | """ 14 | Parameter groups for layer-wise lr decay 15 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58 16 | """ 17 | param_group_names = {} 18 | param_groups = {} 19 | 20 | if isinstance(model, models_dual_vit.DualVisionTransformer): 21 | num_layers = model.student_depth 22 | else: 23 | num_layers = len(model.blocks) + 1 24 | 25 | layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1)) 26 | 27 | for n, p in model.named_parameters(): 28 | if not p.requires_grad: 29 | continue 30 | 31 | # no decay: all 1D parameters and model specific ones 32 | if p.ndim == 1 or n in no_weight_decay_list: 33 | g_decay = "no_decay" 34 | this_decay = 0. 
35 | else: 36 | g_decay = "decay" 37 | this_decay = weight_decay 38 | 39 | layer_id = get_layer_id_for_vit(n, num_layers) 40 | group_name = "layer_%d_%s" % (layer_id, g_decay) 41 | 42 | if group_name not in param_group_names: 43 | this_scale = layer_scales[layer_id] 44 | 45 | param_group_names[group_name] = { 46 | "lr_scale": this_scale, 47 | "weight_decay": this_decay, 48 | "params": [], 49 | } 50 | param_groups[group_name] = { 51 | "lr_scale": this_scale, 52 | "weight_decay": this_decay, 53 | "params": [], 54 | } 55 | 56 | param_group_names[group_name]["params"].append(n) 57 | param_groups[group_name]["params"].append(p) 58 | 59 | return list(param_groups.values()) 60 | 61 | 62 | def get_layer_id_for_vit(name, num_layers): 63 | """ 64 | Assign a parameter with its layer id 65 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33 66 | """ 67 | 68 | if 'cls_token' in name or 'pos_embed' in name: 69 | return 0 70 | elif 'patch_embed' in name: 71 | return 0 72 | elif 'blocks' in name: 73 | if 'student' in name: 74 | return int(name.split('.')[2]) + 1 75 | else: 76 | return int(name.split('.')[1]) + 1 77 | else: 78 | return num_layers -------------------------------------------------------------------------------- /models_vit.py: -------------------------------------------------------------------------------- 1 | # This source code is licensed under the license found in the 2 | # LICENSE file in the root directory of this source tree. 3 | # -------------------------------------------------------- 4 | # References: 5 | # MAE: https://github.com/facebookresearch/mae 6 | # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm 7 | # DeiT: https://github.com/facebookresearch/deit 8 | # -------------------------------------------------------- 9 | 10 | from functools import partial 11 | 12 | import torch 13 | import torch.nn as nn 14 | 15 | import timm.models.vision_transformer 16 | 17 | 18 | class VisionTransformer(timm.models.vision_transformer.VisionTransformer): 19 | """ Vision Transformer with support for global average pooling 20 | """ 21 | def __init__(self, global_pool=False, **kwargs): 22 | super(VisionTransformer, self).__init__(**kwargs) 23 | 24 | self.global_pool = global_pool 25 | if self.global_pool: 26 | norm_layer = kwargs['norm_layer'] 27 | embed_dim = kwargs['embed_dim'] 28 | self.fc_norm = norm_layer(embed_dim) 29 | 30 | del self.norm # remove the original norm 31 | 32 | def forward_features(self, x, layer_to_return=None): 33 | # return features up until layer_to_return 34 | intermediate_feat = [] 35 | B = x.shape[0] 36 | x = self.patch_embed(x) 37 | 38 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 39 | x = torch.cat((cls_tokens, x), dim=1) 40 | x = x + self.pos_embed 41 | x = self.pos_drop(x) 42 | 43 | for i, blk in enumerate(self.blocks): 44 | x = blk(x) 45 | if layer_to_return is not None and i < layer_to_return: 46 | intermediate_feat.append(x) 47 | 48 | if self.global_pool: 49 | x = x[:, 1:, :].mean(dim=1) # global pool without cls token 50 | outcome = self.fc_norm(x) 51 | else: 52 | x = self.norm(x) 53 | outcome = x[:, 0] 54 | 55 | if layer_to_return is not None: 56 | return outcome, intermediate_feat 57 | else: 58 | return outcome 59 | 60 | def forward(self, x, return_features=False, layer_to_return=None): 61 | if layer_to_return is not None: 62 | final_feats, intermediate_feats = self.forward_features(x, layer_to_return=layer_to_return) 63 | else: 64 | final_feats = 
self.forward_features(x) 65 | pred = self.head(final_feats) 66 | if return_features and layer_to_return is None: 67 | return pred, final_feats 68 | elif return_features and layer_to_return is not None: 69 | return pred, intermediate_feats 70 | else: 71 | return pred 72 | 73 | 74 | def vit_base_patch16(**kwargs): 75 | model = VisionTransformer( 76 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 77 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 78 | return model 79 | 80 | 81 | def vit_large_patch16(**kwargs): 82 | model = VisionTransformer( 83 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, 84 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 85 | return model 86 | 87 | 88 | def vit_huge_patch14(**kwargs): 89 | model = VisionTransformer( 90 | patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, qkv_bias=True, 91 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 92 | return model -------------------------------------------------------------------------------- /util/pos_embed.py: -------------------------------------------------------------------------------- 1 | # This source code is licensed under the license found in the 2 | # LICENSE file in the root directory of this source tree. 3 | # -------------------------------------------------------- 4 | # References: 5 | # MAE: https://github.com/facebookresearch/mae 6 | # -------------------------------------------------------- 7 | # Position embedding utils 8 | # -------------------------------------------------------- 9 | 10 | import numpy as np 11 | 12 | import torch 13 | 14 | # -------------------------------------------------------- 15 | # 2D sine-cosine position embedding 16 | # References: 17 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py 18 | # MoCo v3: https://github.com/facebookresearch/moco-v3 19 | # -------------------------------------------------------- 20 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): 21 | """ 22 | grid_size: int of the grid height and width 23 | return: 24 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 25 | """ 26 | grid_h = np.arange(grid_size, dtype=np.float32) 27 | grid_w = np.arange(grid_size, dtype=np.float32) 28 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 29 | grid = np.stack(grid, axis=0) 30 | 31 | grid = grid.reshape([2, 1, grid_size, grid_size]) 32 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 33 | if cls_token: 34 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 35 | return pos_embed 36 | 37 | 38 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 39 | assert embed_dim % 2 == 0 40 | 41 | # use half of dimensions to encode grid_h 42 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 43 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 44 | 45 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 46 | return emb 47 | 48 | 49 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 50 | """ 51 | embed_dim: output dimension for each position 52 | pos: a list of positions to be encoded: size (M,) 53 | out: (M, D) 54 | """ 55 | assert embed_dim % 2 == 0 56 | omega = np.arange(embed_dim // 2, dtype=np.float) 57 | omega /= embed_dim / 2. 58 | omega = 1. 
/ 10000**omega # (D/2,) 59 | 60 | pos = pos.reshape(-1) # (M,) 61 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 62 | 63 | emb_sin = np.sin(out) # (M, D/2) 64 | emb_cos = np.cos(out) # (M, D/2) 65 | 66 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 67 | return emb 68 | 69 | 70 | # -------------------------------------------------------- 71 | # Interpolate position embeddings for high-resolution 72 | # References: 73 | # DeiT: https://github.com/facebookresearch/deit 74 | # -------------------------------------------------------- 75 | def interpolate_pos_embed(model, checkpoint_model): 76 | if 'pos_embed' in checkpoint_model: 77 | pos_embed_checkpoint = checkpoint_model['pos_embed'] 78 | embedding_size = pos_embed_checkpoint.shape[-1] 79 | num_patches = model.patch_embed.num_patches 80 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches 81 | # height (== width) for the checkpoint position embedding 82 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 83 | # height (== width) for the new position embedding 84 | new_size = int(num_patches ** 0.5) 85 | # class_token and dist_token are kept unchanged 86 | if orig_size != new_size: 87 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) 88 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 89 | # only the position tokens are interpolated 90 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 91 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 92 | pos_tokens = torch.nn.functional.interpolate( 93 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 94 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 95 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 96 | checkpoint_model['pos_embed'] = new_pos_embed 97 | -------------------------------------------------------------------------------- /submitit_finetune.py: -------------------------------------------------------------------------------- 1 | # This source code is licensed under the license found in the 2 | # LICENSE file in the root directory of this source tree. 3 | -------------------------------------------------------- 4 | # References: 5 | # MAE: https://github.com/facebookresearch/mae 6 | # -------------------------------------------------------- 7 | # A script to run multinode training with submitit. 8 | # -------------------------------------------------------- 9 | 10 | import argparse 11 | import os 12 | import uuid 13 | from pathlib import Path 14 | 15 | import main_finetune as classification 16 | import submitit 17 | 18 | 19 | def parse_args(): 20 | classification_parser = classification.get_args_parser() 21 | parser = argparse.ArgumentParser("Submitit for MAE finetune", parents=[classification_parser]) 22 | parser.add_argument("--ngpus", default=8, type=int, help="Number of gpus to request on each node") 23 | parser.add_argument("--ncpus", default=10, type=int, help="Number of cpus per gpu") 24 | parser.add_argument("--nodes", default=2, type=int, help="Number of nodes to request") 25 | parser.add_argument("--timeout", default=4320, type=int, help="Duration of the job") 26 | parser.add_argument("--job_dir", default="", type=str, help="Job dir. 
Leave empty for automatic.") 27 | 28 | parser.add_argument("--partition", default="learnfair", type=str, help="Partition where to submit") 29 | parser.add_argument("--use_volta32", action='store_true', help="Request 32G V100 GPUs") 30 | parser.add_argument('--comment', default="", type=str, help="Comment to pass to scheduler") 31 | parser.add_argument('--name', default="mae", type=str, help="Name for job") 32 | return parser.parse_args() 33 | 34 | 35 | def get_shared_folder() -> Path: 36 | user = os.getenv("USER") 37 | if Path("/checkpoint/").is_dir(): 38 | p = Path(f"/checkpoint/{user}/experiments") 39 | p.mkdir(exist_ok=True) 40 | return p 41 | raise RuntimeError("No shared folder available") 42 | 43 | 44 | def get_init_file(): 45 | # Init file must not exist, but it's parent dir must exist. 46 | os.makedirs(str(get_shared_folder()), exist_ok=True) 47 | init_file = get_shared_folder() / f"{uuid.uuid4().hex}_init" 48 | if init_file.exists(): 49 | os.remove(str(init_file)) 50 | return init_file 51 | 52 | 53 | class Trainer(object): 54 | def __init__(self, args): 55 | self.args = args 56 | 57 | def __call__(self): 58 | import main_finetune as classification 59 | 60 | self._setup_gpu_args() 61 | self._setup_fair() 62 | classification.main(self.args) 63 | 64 | def checkpoint(self): 65 | import os 66 | import submitit 67 | 68 | self.args.dist_url = get_init_file().as_uri() 69 | checkpoint_file = os.path.join(self.args.output_dir, "checkpoint.pth") 70 | if os.path.exists(checkpoint_file): 71 | self.args.resume = checkpoint_file 72 | print("Requeuing ", self.args) 73 | empty_trainer = type(self)(self.args) 74 | return submitit.helpers.DelayedSubmission(empty_trainer) 75 | 76 | def _setup_gpu_args(self): 77 | import submitit 78 | from pathlib import Path 79 | 80 | job_env = submitit.JobEnvironment() 81 | self.args.output_dir = Path(str(self.args.output_dir).replace("%j", str(job_env.job_id))) 82 | self.args.log_dir = self.args.output_dir 83 | self.args.gpu = job_env.local_rank 84 | self.args.rank = job_env.global_rank 85 | self.args.world_size = job_env.num_tasks 86 | print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") 87 | 88 | def _setup_fair(self): 89 | os.environ["GLOO_SOCKET_IFNAME"] = "" 90 | os.environ["NCCL_SOCKET_IFNAME"] = "" 91 | os.environ["NCCL_DEBUG"] = "INFO" 92 | # os.environ["NCCL_BLOCKING_WAIT"] = '1' 93 | os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "1" 94 | # os.environ["CUDA_LAUNCH_BLOCKING"] = '1' 95 | return 96 | 97 | 98 | def main(): 99 | args = parse_args() 100 | if args.job_dir == "": 101 | args.job_dir = get_shared_folder() / "%j" 102 | 103 | # Note that the folder will depend on the job_id, to easily track experiments 104 | executor = submitit.AutoExecutor(folder=args.job_dir, slurm_max_num_timeout=30) 105 | 106 | num_gpus_per_node = args.ngpus 107 | num_cpus_per_gpu = args.ncpus 108 | nodes = args.nodes 109 | timeout_min = args.timeout 110 | 111 | partition = args.partition 112 | kwargs = {} 113 | if args.use_volta32: 114 | kwargs['slurm_constraint'] = 'volta32gb' 115 | if args.comment: 116 | kwargs['slurm_comment'] = args.comment 117 | 118 | executor.update_parameters( 119 | mem_gb=40 * num_gpus_per_node, 120 | gpus_per_node=num_gpus_per_node, 121 | tasks_per_node=num_gpus_per_node, # one task per GPU 122 | cpus_per_task=num_cpus_per_gpu, 123 | nodes=nodes, 124 | timeout_min=timeout_min, 125 | # Below are cluster dependent parameters 126 | slurm_partition=partition, 127 | slurm_signal_delay_s=120, 128 | **kwargs 129 | ) 130 | 131 | 
executor.update_parameters(name=args.name) 132 | 133 | args.dist_url = get_init_file().as_uri() 134 | args.output_dir = args.job_dir 135 | 136 | trainer = Trainer(args) 137 | job = executor.submit(trainer) 138 | 139 | print(job.job_id) 140 | 141 | 142 | if __name__ == "__main__": 143 | main() 144 | -------------------------------------------------------------------------------- /engine_finetune.py: -------------------------------------------------------------------------------- 1 | # This source code is licensed under the license found in the 2 | # LICENSE file in the root directory of this source tree. 3 | # -------------------------------------------------------- 4 | # References: 5 | # MAE: https://github.com/facebookresearch/mae 6 | # DeiT: https://github.com/facebookresearch/deit 7 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 8 | # -------------------------------------------------------- 9 | 10 | import math 11 | import sys 12 | from typing import Iterable, Optional 13 | 14 | import torch 15 | 16 | from timm.data import Mixup 17 | from timm.utils import accuracy 18 | 19 | import util.misc as misc 20 | import util.lr_sched as lr_sched 21 | 22 | 23 | def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module, 24 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 25 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0, 26 | mixup_fn: Optional[Mixup] = None, log_writer=None, 27 | args=None, model_ema=None): 28 | model.train(True) 29 | metric_logger = misc.MetricLogger(delimiter=" ") 30 | metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}')) 31 | header = 'Epoch: [{}]'.format(epoch) 32 | print_freq = 20 33 | 34 | accum_iter = args.accum_iter 35 | 36 | optimizer.zero_grad() 37 | 38 | if log_writer is not None: 39 | print('log_dir: {}'.format(log_writer.log_dir)) 40 | 41 | for data_iter_step, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 42 | # we use a per iteration (instead of per epoch) lr scheduler 43 | if data_iter_step % accum_iter == 0: 44 | lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args) 45 | 46 | samples = samples.to(device, non_blocking=True) 47 | targets = targets.to(device, non_blocking=True) 48 | 49 | if mixup_fn is not None: 50 | samples, targets = mixup_fn(samples, targets) 51 | 52 | if args.model.startswith("dual_vit") and 'distill' in args.mode: 53 | outputs, distill_loss = model(samples) 54 | metric_logger.update(distill_loss=distill_loss.item()) 55 | loss = criterion(outputs, targets) + args.atd_weight * distill_loss 56 | else: 57 | # with torch.cuda.amp.autocast(): 58 | outputs = model(samples) 59 | loss = criterion(outputs, targets) 60 | 61 | loss_value = loss.item() 62 | 63 | if not math.isfinite(loss_value): 64 | print("Loss is {}, stopping training".format(loss_value)) 65 | sys.exit(1) 66 | 67 | loss /= accum_iter 68 | loss_scaler(loss, optimizer, clip_grad=max_norm, 69 | parameters=model.parameters(), create_graph=False, 70 | update_grad=(data_iter_step + 1) % accum_iter == 0) 71 | if (data_iter_step + 1) % accum_iter == 0: 72 | optimizer.zero_grad() 73 | if model_ema is not None: 74 | model_ema.update(model) 75 | 76 | torch.cuda.synchronize() 77 | 78 | metric_logger.update(loss=loss_value) 79 | min_lr = 10. 80 | max_lr = 0. 
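# layer-wise lr decay gives each param group its own scaled lr; track the range and log the largest (typically the head / final blocks)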
81 | for group in optimizer.param_groups: 82 | min_lr = min(min_lr, group["lr"]) 83 | max_lr = max(max_lr, group["lr"]) 84 | 85 | metric_logger.update(lr=max_lr) 86 | 87 | loss_value_reduce = misc.all_reduce_mean(loss_value) 88 | if log_writer is not None and (data_iter_step + 1) % accum_iter == 0: 89 | """ We use epoch_1000x as the x-axis in tensorboard. 90 | This calibrates different curves when batch size changes. 91 | """ 92 | epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000) 93 | log_writer.add_scalar('loss', loss_value_reduce, epoch_1000x) 94 | log_writer.add_scalar('lr', max_lr, epoch_1000x) 95 | if args.model.startswith("dual_vit") and 'distill' in args.mode: 96 | log_writer.add_scalar('distill_loss', distill_loss.item(), epoch_1000x) 97 | 98 | # gather the stats from all processes 99 | metric_logger.synchronize_between_processes() 100 | print("Averaged stats:", metric_logger) 101 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 102 | 103 | 104 | @torch.no_grad() 105 | def evaluate(data_loader, model, device): 106 | criterion = torch.nn.CrossEntropyLoss() 107 | 108 | metric_logger = misc.MetricLogger(delimiter=" ") 109 | header = 'Test:' 110 | 111 | # switch to evaluation mode 112 | model.eval() 113 | 114 | for batch in metric_logger.log_every(data_loader, 10, header): 115 | images = batch[0] 116 | target = batch[-1] 117 | images = images.to(device, non_blocking=True) 118 | target = target.to(device, non_blocking=True) 119 | 120 | # compute output 121 | # with torch.cuda.amp.autocast(): 122 | output = model(images) 123 | loss = criterion(output, target) 124 | 125 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 126 | 127 | batch_size = images.shape[0] 128 | metric_logger.update(loss=loss.item()) 129 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) 130 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) 131 | # gather the stats from all processes 132 | metric_logger.synchronize_between_processes() 133 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}' 134 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss)) 135 | 136 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} -------------------------------------------------------------------------------- /util/misc.py: -------------------------------------------------------------------------------- 1 | # This source code is licensed under the license found in the 2 | # LICENSE file in the root directory of this source tree. 3 | # -------------------------------------------------------- 4 | # References: 5 | # MAE: https://github.com/facebookresearch/mae 6 | # DeiT: https://github.com/facebookresearch/deit 7 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 8 | # -------------------------------------------------------- 9 | 10 | import builtins 11 | import datetime 12 | import os 13 | import time 14 | from collections import defaultdict, deque 15 | from pathlib import Path 16 | 17 | import torch 18 | import torch.distributed as dist 19 | from torch import inf 20 | 21 | 22 | class SmoothedValue(object): 23 | """Track a series of values and provide access to smoothed values over a 24 | window or the global series average. 
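(With the default window_size=20, median/avg reflect the most recent 20 updates, while global_avg is total/count over all updates.)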
25 | """ 26 | 27 | def __init__(self, window_size=20, fmt=None): 28 | if fmt is None: 29 | fmt = "{median:.4f} ({global_avg:.4f})" 30 | self.deque = deque(maxlen=window_size) 31 | self.total = 0.0 32 | self.count = 0 33 | self.fmt = fmt 34 | 35 | def update(self, value, n=1): 36 | self.deque.append(value) 37 | self.count += n 38 | self.total += value * n 39 | 40 | def synchronize_between_processes(self): 41 | """ 42 | Warning: does not synchronize the deque! 43 | """ 44 | if not is_dist_avail_and_initialized(): 45 | return 46 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 47 | dist.barrier() 48 | dist.all_reduce(t) 49 | t = t.tolist() 50 | self.count = int(t[0]) 51 | self.total = t[1] 52 | 53 | @property 54 | def median(self): 55 | d = torch.tensor(list(self.deque)) 56 | return d.median().item() 57 | 58 | @property 59 | def avg(self): 60 | d = torch.tensor(list(self.deque), dtype=torch.float32) 61 | return d.mean().item() 62 | 63 | @property 64 | def global_avg(self): 65 | return self.total / self.count 66 | 67 | @property 68 | def max(self): 69 | return max(self.deque) 70 | 71 | @property 72 | def value(self): 73 | return self.deque[-1] 74 | 75 | def __str__(self): 76 | return self.fmt.format( 77 | median=self.median, 78 | avg=self.avg, 79 | global_avg=self.global_avg, 80 | max=self.max, 81 | value=self.value) 82 | 83 | 84 | class MetricLogger(object): 85 | def __init__(self, delimiter="\t"): 86 | self.meters = defaultdict(SmoothedValue) 87 | self.delimiter = delimiter 88 | 89 | def update(self, **kwargs): 90 | for k, v in kwargs.items(): 91 | if v is None: 92 | continue 93 | if isinstance(v, torch.Tensor): 94 | v = v.item() 95 | assert isinstance(v, (float, int)) 96 | self.meters[k].update(v) 97 | 98 | def __getattr__(self, attr): 99 | if attr in self.meters: 100 | return self.meters[attr] 101 | if attr in self.__dict__: 102 | return self.__dict__[attr] 103 | raise AttributeError("'{}' object has no attribute '{}'".format( 104 | type(self).__name__, attr)) 105 | 106 | def __str__(self): 107 | loss_str = [] 108 | for name, meter in self.meters.items(): 109 | loss_str.append( 110 | "{}: {}".format(name, str(meter)) 111 | ) 112 | return self.delimiter.join(loss_str) 113 | 114 | def synchronize_between_processes(self): 115 | for meter in self.meters.values(): 116 | meter.synchronize_between_processes() 117 | 118 | def add_meter(self, name, meter): 119 | self.meters[name] = meter 120 | 121 | def log_every(self, iterable, print_freq, header=None): 122 | i = 0 123 | if not header: 124 | header = '' 125 | start_time = time.time() 126 | end = time.time() 127 | iter_time = SmoothedValue(fmt='{avg:.4f}') 128 | data_time = SmoothedValue(fmt='{avg:.4f}') 129 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 130 | log_msg = [ 131 | header, 132 | '[{0' + space_fmt + '}/{1}]', 133 | 'eta: {eta}', 134 | '{meters}', 135 | 'time: {time}', 136 | 'data: {data}' 137 | ] 138 | if torch.cuda.is_available(): 139 | log_msg.append('max mem: {memory:.0f}') 140 | log_msg = self.delimiter.join(log_msg) 141 | MB = 1024.0 * 1024.0 142 | for obj in iterable: 143 | data_time.update(time.time() - end) 144 | yield obj 145 | iter_time.update(time.time() - end) 146 | if i % print_freq == 0 or i == len(iterable) - 1: 147 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 148 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 149 | if torch.cuda.is_available(): 150 | print(log_msg.format( 151 | i, len(iterable), eta=eta_string, 152 | meters=str(self), 153 | 
time=str(iter_time), data=str(data_time), 154 | memory=torch.cuda.max_memory_allocated() / MB)) 155 | else: 156 | print(log_msg.format( 157 | i, len(iterable), eta=eta_string, 158 | meters=str(self), 159 | time=str(iter_time), data=str(data_time))) 160 | i += 1 161 | end = time.time() 162 | total_time = time.time() - start_time 163 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 164 | print('{} Total time: {} ({:.4f} s / it)'.format( 165 | header, total_time_str, total_time / len(iterable))) 166 | 167 | 168 | def setup_for_distributed(is_master): 169 | """ 170 | This function disables printing when not in master process 171 | """ 172 | builtin_print = builtins.print 173 | 174 | def print(*args, **kwargs): 175 | force = kwargs.pop('force', False) 176 | force = force or (get_world_size() > 8) 177 | if is_master or force: 178 | now = datetime.datetime.now().time() 179 | builtin_print('[{}] '.format(now), end='') # print with time stamp 180 | builtin_print(*args, **kwargs) 181 | 182 | builtins.print = print 183 | 184 | 185 | def is_dist_avail_and_initialized(): 186 | if not dist.is_available(): 187 | return False 188 | if not dist.is_initialized(): 189 | return False 190 | return True 191 | 192 | 193 | def get_world_size(): 194 | if not is_dist_avail_and_initialized(): 195 | return 1 196 | return dist.get_world_size() 197 | 198 | 199 | def get_rank(): 200 | if not is_dist_avail_and_initialized(): 201 | return 0 202 | return dist.get_rank() 203 | 204 | 205 | def is_main_process(): 206 | return get_rank() == 0 207 | 208 | 209 | def save_on_master(*args, **kwargs): 210 | if is_main_process(): 211 | torch.save(*args, **kwargs) 212 | 213 | 214 | def init_distributed_mode(args): 215 | if args.dist_on_itp: 216 | args.rank = int(os.environ['OMPI_COMM_WORLD_RANK']) 217 | args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) 218 | args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) 219 | args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT']) 220 | os.environ['LOCAL_RANK'] = str(args.gpu) 221 | os.environ['RANK'] = str(args.rank) 222 | os.environ['WORLD_SIZE'] = str(args.world_size) 223 | elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 224 | args.rank = int(os.environ["RANK"]) 225 | args.world_size = int(os.environ['WORLD_SIZE']) 226 | args.gpu = int(os.environ['LOCAL_RANK']) 227 | elif 'SLURM_PROCID' in os.environ: 228 | args.rank = int(os.environ['SLURM_PROCID']) 229 | args.gpu = args.rank % torch.cuda.device_count() 230 | else: 231 | print('Not using distributed mode') 232 | setup_for_distributed(is_master=True) # hack 233 | args.distributed = False 234 | return 235 | 236 | args.distributed = True 237 | 238 | torch.cuda.set_device(args.gpu) 239 | args.dist_backend = 'nccl' 240 | print('| distributed init (rank {}): {}, gpu {}'.format( 241 | args.rank, args.dist_url, args.gpu), flush=True) 242 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 243 | world_size=args.world_size, rank=args.rank) 244 | torch.distributed.barrier() 245 | setup_for_distributed(args.rank == 0) 246 | 247 | 248 | class NativeScalerWithGradNormCount: 249 | state_dict_key = "amp_scaler" 250 | 251 | def __init__(self): 252 | self._scaler = torch.cuda.amp.GradScaler() 253 | 254 | def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True): 255 | self._scaler.scale(loss).backward(create_graph=create_graph) 256 | if update_grad: 257 | if clip_grad is not None: 258 | assert 
parameters is not None 259 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place 260 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) 261 | else: 262 | self._scaler.unscale_(optimizer) 263 | norm = get_grad_norm_(parameters) 264 | self._scaler.step(optimizer) 265 | self._scaler.update() 266 | else: 267 | norm = None 268 | return norm 269 | 270 | def state_dict(self): 271 | return self._scaler.state_dict() 272 | 273 | def load_state_dict(self, state_dict): 274 | self._scaler.load_state_dict(state_dict) 275 | 276 | 277 | def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: 278 | if isinstance(parameters, torch.Tensor): 279 | parameters = [parameters] 280 | parameters = [p for p in parameters if p.grad is not None] 281 | norm_type = float(norm_type) 282 | if len(parameters) == 0: 283 | return torch.tensor(0.) 284 | device = parameters[0].grad.device 285 | if norm_type == inf: 286 | total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) 287 | else: 288 | total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type) 289 | return total_norm 290 | 291 | 292 | def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler, last_only=False, model_ema=None): 293 | output_dir = Path(args.output_dir) 294 | epoch_name = str(epoch) 295 | if last_only: 296 | tags = ["checkpoint"] 297 | else: 298 | tags = [f"checkpoint-{epoch_name}", "checkpoint"] 299 | if loss_scaler is not None: 300 | checkpoint_paths = [output_dir / (tag + ".pth") for tag in tags] 301 | for checkpoint_path in checkpoint_paths: 302 | to_save = { 303 | 'model': model_without_ddp.state_dict(), 304 | 'optimizer': optimizer.state_dict(), 305 | 'epoch': epoch, 306 | 'scaler': loss_scaler.state_dict(), 307 | 'args': args, 308 | } 309 | if model_ema is not None: 310 | to_save['model_ema'] = model_ema.module.state_dict() 311 | save_on_master(to_save, checkpoint_path) 312 | else: 313 | client_state = {'epoch': epoch} 314 | for tag in tags: 315 | model.save_checkpoint(save_dir=args.output_dir, tag=tag, client_state=client_state) 316 | 317 | 318 | 319 | def load_model(args, model_without_ddp, model_ema, optimizer, loss_scaler): 320 | if args.resume: 321 | if args.resume.startswith('https'): 322 | checkpoint = torch.hub.load_state_dict_from_url( 323 | args.resume, map_location='cpu', check_hash=True) 324 | elif args.resume == 'allow': 325 | path = os.path.join(args.output_dir, 'checkpoint.pth') 326 | if not os.path.exists(path): 327 | return 328 | checkpoint = torch.load(path, map_location='cpu') 329 | else: 330 | checkpoint = torch.load(args.resume, map_location='cpu') 331 | model_without_ddp.load_state_dict(checkpoint['model']) 332 | if model_ema is not None and "model_ema" in checkpoint: 333 | model_ema_incompatible_keys = model_ema.module.load_state_dict( 334 | checkpoint["model_ema"] 335 | ) 336 | print("Loaded model_ema:", model_ema_incompatible_keys) 337 | print("Resume checkpoint %s" % args.resume) 338 | if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval): 339 | optimizer.load_state_dict(checkpoint['optimizer']) 340 | args.start_epoch = checkpoint['epoch'] + 1 341 | if 'scaler' in checkpoint: 342 | loss_scaler.load_state_dict(checkpoint['scaler']) 343 | print("With optim & sched!") 344 | 345 | 346 | def all_reduce_mean(x): 347 | world_size = get_world_size() 348 | if world_size > 1: 349 | x_reduce = torch.tensor(x).cuda() 
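# all_reduce sums the scalar across ranks; dividing by world_size gives the mean over processes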
350 | dist.all_reduce(x_reduce) 351 | x_reduce /= world_size 352 | return x_reduce.item() 353 | else: 354 | return x -------------------------------------------------------------------------------- /models_dual_vit.py: -------------------------------------------------------------------------------- 1 | # This source code is licensed under the license found in the 2 | # LICENSE file in the root directory of this source tree. 3 | # -------------------------------------------------------- 4 | # References: 5 | # MAE: https://github.com/facebookresearch/mae 6 | # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm 7 | # DeiT: https://github.com/facebookresearch/deit 8 | # -------------------------------------------------------- 9 | 10 | """ Vision Transformer (ViT) in PyTorch 11 | 12 | A PyTorch implement of Vision Transformers as described in 13 | 'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' - https://arxiv.org/abs/2010.11929 14 | 15 | The official jax code is released and available at https://github.com/google-research/vision_transformer 16 | 17 | Status/TODO: 18 | * Models updated to be compatible with official impl. Args added to support backward compat for old PyTorch weights. 19 | * Weights ported from official jax impl for 384x384 base and small models, 16x16 and 32x32 patches. 20 | * Trained (supervised on ImageNet-1k) my custom 'small' patch model to 77.9, 'base' to 79.4 top-1 with this code. 21 | * Hopefully find time and GPUs for SSL or unsupervised pretraining on OpenImages w/ ImageNet fine-tune in future. 22 | 23 | Acknowledgments: 24 | * The paper authors for releasing code and weights, thanks! 25 | * I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... 
check it out 26 | for some einops/einsum fun 27 | * Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT 28 | * Bert reference code checks against Huggingface Transformers and Tensorflow Bert 29 | 30 | Hacked together by / Copyright 2020 Ross Wightman 31 | """ 32 | import torch 33 | import torch.nn as nn 34 | from functools import partial 35 | 36 | from timm.models.vision_transformer import Mlp, PatchEmbed 37 | from timm.models.layers import trunc_normal_ 38 | 39 | 40 | class Attention(nn.Module): 41 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, mode=None): 42 | super().__init__() 43 | self.num_heads = num_heads 44 | head_dim = dim // num_heads 45 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 46 | self.scale = qk_scale or head_dim ** -0.5 47 | 48 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 49 | self.proj = nn.Linear(dim, dim) 50 | self.mode = mode 51 | 52 | 53 | def forward(self, x, teacher_act=None, return_act=None): 54 | B, N, C = x.shape 55 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 56 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 57 | 58 | if self.mode == 'copy' and teacher_act is not None: 59 | attn = teacher_act.softmax(dim=-1) 60 | elif self.mode == 'copy_q' and teacher_act is not None: 61 | teacher_q = teacher_act 62 | attn_logits = (teacher_q @ k.transpose(-2, -1)) * self.scale 63 | attn = attn_logits.softmax(dim=-1) 64 | elif self.mode == 'copy_k' and teacher_act is not None: 65 | teacher_k = teacher_act 66 | attn_logits = (q @ teacher_k.transpose(-2, -1)) * self.scale 67 | attn = attn_logits.softmax(dim=-1) 68 | else: 69 | attn_logits = (q @ k.transpose(-2, -1)) * self.scale 70 | attn = attn_logits.softmax(dim=-1) 71 | if self.mode == 'copy_v' and teacher_act is not None: 72 | teacher_v = teacher_act 73 | x = (attn @ teacher_v).transpose(1, 2).reshape(B, N, C) 74 | else: 75 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 76 | x = self.proj(x) 77 | if return_act == 'attention': 78 | return x, attn_logits 79 | elif return_act == 'q': 80 | return x, q 81 | elif return_act == 'k': 82 | return x, k 83 | elif return_act == 'v': 84 | return x, v 85 | return x 86 | 87 | 88 | class Block(nn.Module): 89 | 90 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, 91 | act_layer=nn.GELU, norm_layer=nn.LayerNorm, mode=None): 92 | super().__init__() 93 | self.norm1 = norm_layer(dim) 94 | self.attn = Attention( 95 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, mode=mode) 96 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 97 | self.norm2 = norm_layer(dim) 98 | mlp_hidden_dim = int(dim * mlp_ratio) 99 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=0) 100 | 101 | def forward(self, x, drop_masks=None, teacher_act=None, return_act=None): 102 | if drop_masks is None: 103 | drop_masks = (1, 1) 104 | else: 105 | shape = (len(x),) + (1,) * (x.ndim - 1) 106 | drop_masks = (drop_masks[0].view(*shape), drop_masks[1].view(*shape)) 107 | attn_outputs = self.attn(self.norm1(x), teacher_act=teacher_act, return_act=return_act) 108 | if return_act is not None: 109 | attn_result, act = attn_outputs 110 | else: 111 | attn_result = attn_outputs 112 | 113 | x = x + drop_masks[0] * attn_result 114 | x = x + drop_masks[1] * self.mlp(self.norm2(x)) 115 | if 
return_act is not None: 116 | return x, act 117 | return x 118 | 119 | 120 | class VisionTransformer(nn.Module): 121 | """ Vision Transformer with support for patch or hybrid CNN input stage 122 | """ 123 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, 124 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, 125 | norm_layer=nn.LayerNorm, mode=None, global_pool=False,): 126 | super().__init__() 127 | self.num_classes = num_classes 128 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 129 | self.mode = mode 130 | 131 | self.patch_embed = PatchEmbed( 132 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 133 | num_patches = self.patch_embed.num_patches 134 | 135 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 136 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) 137 | # self.pos_drop = nn.Dropout(p=drop_rate) 138 | 139 | # dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 140 | self.blocks = nn.ModuleList([ 141 | Block( 142 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, 143 | qk_scale=qk_scale, norm_layer=norm_layer, mode=mode) 144 | for i in range(depth)]) 145 | self.norm = norm_layer(embed_dim) 146 | 147 | # NOTE as per official impl, we could have a pre-logits representation dense layer + tanh here 148 | #self.repr = nn.Linear(embed_dim, representation_size) 149 | #self.repr_act = nn.Tanh() 150 | 151 | # Classifier head 152 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() 153 | 154 | trunc_normal_(self.pos_embed, std=.02) 155 | trunc_normal_(self.cls_token, std=.02) 156 | self.apply(self._init_weights) 157 | 158 | self.global_pool = global_pool 159 | if self.global_pool: 160 | self.fc_norm = norm_layer(embed_dim) 161 | del self.norm # remove the original norm 162 | 163 | def _init_weights(self, m): 164 | if isinstance(m, nn.Linear): 165 | trunc_normal_(m.weight, std=.02) 166 | if isinstance(m, nn.Linear) and m.bias is not None: 167 | nn.init.constant_(m.bias, 0) 168 | elif isinstance(m, nn.LayerNorm): 169 | nn.init.constant_(m.bias, 0) 170 | nn.init.constant_(m.weight, 1.0) 171 | 172 | @torch.jit.ignore 173 | def no_weight_decay(self): 174 | return {'pos_embed', 'cls_token'} 175 | 176 | def get_classifier(self): 177 | return self.head 178 | 179 | def reset_classifier(self, num_classes, global_pool=''): 180 | self.num_classes = num_classes 181 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 182 | 183 | def forward_features(self, x, drop_masks=None, teacher_act=None, return_act=None): 184 | if teacher_act is not None: 185 | assert 'copy' in self.mode 186 | # pad if we copy fewer blocks 187 | if len(teacher_act) < len(self.blocks): 188 | teacher_act = teacher_act + [None] * (len(self.blocks) - len(teacher_act)) 189 | else: 190 | teacher_act = [None] * len(self.blocks) 191 | 192 | B = x.shape[0] 193 | x = self.patch_embed(x) 194 | 195 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 196 | x = torch.cat((cls_tokens, x), dim=1) 197 | x = x + self.pos_embed 198 | 199 | attns = [] 200 | for i, blk in enumerate(self.blocks): 201 | if return_act is not None: 202 | x, act = blk(x, return_act=return_act, drop_masks=drop_masks[i]) 203 | attns.append(act) 204 | else: 205 | x = blk(x, teacher_act=teacher_act[i], 
drop_masks=drop_masks[i]) 206 | 207 | if self.global_pool: 208 | x = x[:, 1:, :].mean(dim=1) # global pool without cls token 209 | outcome = self.fc_norm(x) 210 | else: 211 | x = self.norm(x) 212 | outcome = x[:, 0] 213 | 214 | if return_act is not None: 215 | return outcome, attns 216 | return outcome 217 | 218 | 219 | def forward(self, x, drop_masks=None, teacher_act=None, return_act=None): 220 | x = self.forward_features(x, drop_masks=drop_masks, teacher_act=teacher_act, return_act=return_act) 221 | if return_act is not None: 222 | x, act = x 223 | return self.head(x), act 224 | else: 225 | return self.head(x) 226 | 227 | 228 | class DualVisionTransformer(nn.Module): 229 | """ 230 | Vision Transformer with support for global average pooling 231 | Has two streams (one is a teacher, the other is a student) 232 | """ 233 | def __init__(self, mode='distill', drop_path_rate=0, 234 | teacher_kwargs=None, student_kwargs=None, end_layer=-3): 235 | super().__init__() 236 | assert mode in {'copy', 'copy_q', 'copy_k', 'copy_v', 'distill', 'distill_q', 'distill_k', 'distill_v'} 237 | self.mode = mode 238 | self.drop_path_rate = drop_path_rate 239 | self.teacher_depth = teacher_kwargs['depth'] 240 | self.student_depth = student_kwargs['depth'] 241 | self.teacher = VisionTransformer(mode='teacher', **teacher_kwargs) 242 | self.student = VisionTransformer(mode=mode, **student_kwargs) 243 | self.dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, self.teacher_depth)] 244 | self.end_layer = end_layer # for distillation 245 | 246 | def forward(self, x): 247 | drop_masks = self.get_drop_path_mask(len(x), x.dtype, x.device) # depends on self.training 248 | # attention activation to get from the teacher 249 | if self.mode in {'copy', 'distill'}: 250 | return_act = 'attention' 251 | elif self.mode in {'copy_q', 'distill_q'}: 252 | return_act = 'q' 253 | elif self.mode in {'copy_k', 'distill_k'}: 254 | return_act = 'k' 255 | elif self.mode in {'copy_v', 'distill_v'}: 256 | return_act = 'v' 257 | else: 258 | raise NotImplementedError 259 | with torch.no_grad(): 260 | if self.training or 'copy' in self.mode: 261 | _, teacher_act = self.teacher.forward_features(x, 262 | drop_masks=drop_masks, 263 | return_act=return_act) 264 | teacher_act = [act.detach() for act in teacher_act] 265 | 266 | # forward student 267 | if 'copy' in self.mode: 268 | # teacher_act to copy 269 | act_to_copy = teacher_act[:self.teacher_depth + self.end_layer] 270 | return self.student(x, drop_masks=drop_masks, teacher_act=act_to_copy) 271 | elif 'distill' in self.mode and self.training: 272 | student_out, student_act = self.student(x, drop_masks=drop_masks, return_act=return_act) 273 | distill_loss = 0 274 | 275 | if self.mode == 'distill': 276 | def distill_loss_fn(teacher_map, student_map): 277 | return - (teacher_map.softmax(dim=-1) * torch.log_softmax(student_map, dim=-1)).sum(dim=-1).mean() 278 | else: 279 | def distill_loss_fn(teacher_map, student_map): 280 | return torch.nn.functional.mse_loss(teacher_map, student_map) 281 | 282 | for i in range(0, self.teacher_depth + self.end_layer): 283 | distill_loss += distill_loss_fn(teacher_act[i], student_act[i]) 284 | return student_out, distill_loss 285 | else: 286 | return self.student(x, drop_masks=drop_masks) 287 | 288 | 289 | def get_drop_path_mask(self, batch_size, dtype, device): 290 | if not self.training: 291 | return [None] * self.teacher_depth 292 | drop_masks = [] 293 | shape = (batch_size,) 294 | for i in range(self.teacher_depth): 295 | curr_layer_masks = [] 296 | 
for _ in range(2): 297 | keep_prob = 1 - self.dpr[i] 298 | random_tensor = keep_prob + torch.rand(shape, dtype=dtype, device=device) 299 | random_tensor.floor_() # binarize 300 | output = random_tensor / keep_prob 301 | curr_layer_masks.append(output) 302 | drop_masks.append(curr_layer_masks) 303 | return drop_masks 304 | 305 | def no_weight_decay(self): 306 | return {'student.' + k for k in self.student.no_weight_decay()} 307 | 308 | 309 | def dual_vit_base_patch16(mode='distill', drop_path_rate=0, end_layer=-3, **kwargs): 310 | kwargs = dict( 311 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 312 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 313 | 314 | model = DualVisionTransformer(mode=mode, drop_path_rate=drop_path_rate, end_layer=end_layer, 315 | teacher_kwargs=kwargs, student_kwargs=kwargs) 316 | return model 317 | 318 | 319 | def dual_vit_large_patch16(mode='distill', drop_path_rate=0, end_layer=-3, **kwargs): 320 | kwargs = dict( 321 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, 322 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 323 | 324 | model = DualVisionTransformer(mode=mode, drop_path_rate=drop_path_rate, end_layer=end_layer, 325 | teacher_kwargs=kwargs, student_kwargs=kwargs) 326 | return model 327 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Attribution-NonCommercial 4.0 International 3 | 4 | ======================================================================= 5 | 6 | Creative Commons Corporation ("Creative Commons") is not a law firm and 7 | does not provide legal services or legal advice. Distribution of 8 | Creative Commons public licenses does not create a lawyer-client or 9 | other relationship. Creative Commons makes its licenses and related 10 | information available on an "as-is" basis. Creative Commons gives no 11 | warranties regarding its licenses, any material licensed under their 12 | terms and conditions, or any related information. Creative Commons 13 | disclaims all liability for damages resulting from their use to the 14 | fullest extent possible. 15 | 16 | Using Creative Commons Public Licenses 17 | 18 | Creative Commons public licenses provide a standard set of terms and 19 | conditions that creators and other rights holders may use to share 20 | original works of authorship and other material subject to copyright 21 | and certain other rights specified in the public license below. The 22 | following considerations are for informational purposes only, are not 23 | exhaustive, and do not form part of our licenses. 24 | 25 | Considerations for licensors: Our public licenses are 26 | intended for use by those authorized to give the public 27 | permission to use material in ways otherwise restricted by 28 | copyright and certain other rights. Our licenses are 29 | irrevocable. Licensors should read and understand the terms 30 | and conditions of the license they choose before applying it. 31 | Licensors should also secure all rights necessary before 32 | applying our licenses so that the public can reuse the 33 | material as expected. Licensors should clearly mark any 34 | material not subject to the license. This includes other CC- 35 | licensed material, or material used under an exception or 36 | limitation to copyright. 
More considerations for licensors: 37 | wiki.creativecommons.org/Considerations_for_licensors 38 | 39 | Considerations for the public: By using one of our public 40 | licenses, a licensor grants the public permission to use the 41 | licensed material under specified terms and conditions. If 42 | the licensor's permission is not necessary for any reason--for 43 | example, because of any applicable exception or limitation to 44 | copyright--then that use is not regulated by the license. Our 45 | licenses grant only permissions under copyright and certain 46 | other rights that a licensor has authority to grant. Use of 47 | the licensed material may still be restricted for other 48 | reasons, including because others have copyright or other 49 | rights in the material. A licensor may make special requests, 50 | such as asking that all changes be marked or described. 51 | Although not required by our licenses, you are encouraged to 52 | respect those requests where reasonable. More_considerations 53 | for the public: 54 | wiki.creativecommons.org/Considerations_for_licensees 55 | 56 | ======================================================================= 57 | 58 | Creative Commons Attribution-NonCommercial 4.0 International Public 59 | License 60 | 61 | By exercising the Licensed Rights (defined below), You accept and agree 62 | to be bound by the terms and conditions of this Creative Commons 63 | Attribution-NonCommercial 4.0 International Public License ("Public 64 | License"). To the extent this Public License may be interpreted as a 65 | contract, You are granted the Licensed Rights in consideration of Your 66 | acceptance of these terms and conditions, and the Licensor grants You 67 | such rights in consideration of benefits the Licensor receives from 68 | making the Licensed Material available under these terms and 69 | conditions. 70 | 71 | Section 1 -- Definitions. 72 | 73 | a. Adapted Material means material subject to Copyright and Similar 74 | Rights that is derived from or based upon the Licensed Material 75 | and in which the Licensed Material is translated, altered, 76 | arranged, transformed, or otherwise modified in a manner requiring 77 | permission under the Copyright and Similar Rights held by the 78 | Licensor. For purposes of this Public License, where the Licensed 79 | Material is a musical work, performance, or sound recording, 80 | Adapted Material is always produced where the Licensed Material is 81 | synched in timed relation with a moving image. 82 | 83 | b. Adapter's License means the license You apply to Your Copyright 84 | and Similar Rights in Your contributions to Adapted Material in 85 | accordance with the terms and conditions of this Public License. 86 | 87 | c. Copyright and Similar Rights means copyright and/or similar rights 88 | closely related to copyright including, without limitation, 89 | performance, broadcast, sound recording, and Sui Generis Database 90 | Rights, without regard to how the rights are labeled or 91 | categorized. For purposes of this Public License, the rights 92 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 93 | Rights. 94 | d. Effective Technological Measures means those measures that, in the 95 | absence of proper authority, may not be circumvented under laws 96 | fulfilling obligations under Article 11 of the WIPO Copyright 97 | Treaty adopted on December 20, 1996, and/or similar international 98 | agreements. 99 | 100 | e. 
Exceptions and Limitations means fair use, fair dealing, and/or 101 | any other exception or limitation to Copyright and Similar Rights 102 | that applies to Your use of the Licensed Material. 103 | 104 | f. Licensed Material means the artistic or literary work, database, 105 | or other material to which the Licensor applied this Public 106 | License. 107 | 108 | g. Licensed Rights means the rights granted to You subject to the 109 | terms and conditions of this Public License, which are limited to 110 | all Copyright and Similar Rights that apply to Your use of the 111 | Licensed Material and that the Licensor has authority to license. 112 | 113 | h. Licensor means the individual(s) or entity(ies) granting rights 114 | under this Public License. 115 | 116 | i. NonCommercial means not primarily intended for or directed towards 117 | commercial advantage or monetary compensation. For purposes of 118 | this Public License, the exchange of the Licensed Material for 119 | other material subject to Copyright and Similar Rights by digital 120 | file-sharing or similar means is NonCommercial provided there is 121 | no payment of monetary compensation in connection with the 122 | exchange. 123 | 124 | j. Share means to provide material to the public by any means or 125 | process that requires permission under the Licensed Rights, such 126 | as reproduction, public display, public performance, distribution, 127 | dissemination, communication, or importation, and to make material 128 | available to the public including in ways that members of the 129 | public may access the material from a place and at a time 130 | individually chosen by them. 131 | 132 | k. Sui Generis Database Rights means rights other than copyright 133 | resulting from Directive 96/9/EC of the European Parliament and of 134 | the Council of 11 March 1996 on the legal protection of databases, 135 | as amended and/or succeeded, as well as other essentially 136 | equivalent rights anywhere in the world. 137 | 138 | l. You means the individual or entity exercising the Licensed Rights 139 | under this Public License. Your has a corresponding meaning. 140 | 141 | Section 2 -- Scope. 142 | 143 | a. License grant. 144 | 145 | 1. Subject to the terms and conditions of this Public License, 146 | the Licensor hereby grants You a worldwide, royalty-free, 147 | non-sublicensable, non-exclusive, irrevocable license to 148 | exercise the Licensed Rights in the Licensed Material to: 149 | 150 | a. reproduce and Share the Licensed Material, in whole or 151 | in part, for NonCommercial purposes only; and 152 | 153 | b. produce, reproduce, and Share Adapted Material for 154 | NonCommercial purposes only. 155 | 156 | 2. Exceptions and Limitations. For the avoidance of doubt, where 157 | Exceptions and Limitations apply to Your use, this Public 158 | License does not apply, and You do not need to comply with 159 | its terms and conditions. 160 | 161 | 3. Term. The term of this Public License is specified in Section 162 | 6(a). 163 | 164 | 4. Media and formats; technical modifications allowed. The 165 | Licensor authorizes You to exercise the Licensed Rights in 166 | all media and formats whether now known or hereafter created, 167 | and to make technical modifications necessary to do so. 
The 168 | Licensor waives and/or agrees not to assert any right or 169 | authority to forbid You from making technical modifications 170 | necessary to exercise the Licensed Rights, including 171 | technical modifications necessary to circumvent Effective 172 | Technological Measures. For purposes of this Public License, 173 | simply making modifications authorized by this Section 2(a) 174 | (4) never produces Adapted Material. 175 | 176 | 5. Downstream recipients. 177 | 178 | a. Offer from the Licensor -- Licensed Material. Every 179 | recipient of the Licensed Material automatically 180 | receives an offer from the Licensor to exercise the 181 | Licensed Rights under the terms and conditions of this 182 | Public License. 183 | 184 | b. No downstream restrictions. You may not offer or impose 185 | any additional or different terms or conditions on, or 186 | apply any Effective Technological Measures to, the 187 | Licensed Material if doing so restricts exercise of the 188 | Licensed Rights by any recipient of the Licensed 189 | Material. 190 | 191 | 6. No endorsement. Nothing in this Public License constitutes or 192 | may be construed as permission to assert or imply that You 193 | are, or that Your use of the Licensed Material is, connected 194 | with, or sponsored, endorsed, or granted official status by, 195 | the Licensor or others designated to receive attribution as 196 | provided in Section 3(a)(1)(A)(i). 197 | 198 | b. Other rights. 199 | 200 | 1. Moral rights, such as the right of integrity, are not 201 | licensed under this Public License, nor are publicity, 202 | privacy, and/or other similar personality rights; however, to 203 | the extent possible, the Licensor waives and/or agrees not to 204 | assert any such rights held by the Licensor to the limited 205 | extent necessary to allow You to exercise the Licensed 206 | Rights, but not otherwise. 207 | 208 | 2. Patent and trademark rights are not licensed under this 209 | Public License. 210 | 211 | 3. To the extent possible, the Licensor waives any right to 212 | collect royalties from You for the exercise of the Licensed 213 | Rights, whether directly or through a collecting society 214 | under any voluntary or waivable statutory or compulsory 215 | licensing scheme. In all other cases the Licensor expressly 216 | reserves any right to collect such royalties, including when 217 | the Licensed Material is used other than for NonCommercial 218 | purposes. 219 | 220 | Section 3 -- License Conditions. 221 | 222 | Your exercise of the Licensed Rights is expressly made subject to the 223 | following conditions. 224 | 225 | a. Attribution. 226 | 227 | 1. If You Share the Licensed Material (including in modified 228 | form), You must: 229 | 230 | a. retain the following if it is supplied by the Licensor 231 | with the Licensed Material: 232 | 233 | i. identification of the creator(s) of the Licensed 234 | Material and any others designated to receive 235 | attribution, in any reasonable manner requested by 236 | the Licensor (including by pseudonym if 237 | designated); 238 | 239 | ii. a copyright notice; 240 | 241 | iii. a notice that refers to this Public License; 242 | 243 | iv. a notice that refers to the disclaimer of 244 | warranties; 245 | 246 | v. a URI or hyperlink to the Licensed Material to the 247 | extent reasonably practicable; 248 | 249 | b. indicate if You modified the Licensed Material and 250 | retain an indication of any previous modifications; and 251 | 252 | c. 
indicate the Licensed Material is licensed under this 253 | Public License, and include the text of, or the URI or 254 | hyperlink to, this Public License. 255 | 256 | 2. You may satisfy the conditions in Section 3(a)(1) in any 257 | reasonable manner based on the medium, means, and context in 258 | which You Share the Licensed Material. For example, it may be 259 | reasonable to satisfy the conditions by providing a URI or 260 | hyperlink to a resource that includes the required 261 | information. 262 | 263 | 3. If requested by the Licensor, You must remove any of the 264 | information required by Section 3(a)(1)(A) to the extent 265 | reasonably practicable. 266 | 267 | 4. If You Share Adapted Material You produce, the Adapter's 268 | License You apply must not prevent recipients of the Adapted 269 | Material from complying with this Public License. 270 | 271 | Section 4 -- Sui Generis Database Rights. 272 | 273 | Where the Licensed Rights include Sui Generis Database Rights that 274 | apply to Your use of the Licensed Material: 275 | 276 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 277 | to extract, reuse, reproduce, and Share all or a substantial 278 | portion of the contents of the database for NonCommercial purposes 279 | only; 280 | 281 | b. if You include all or a substantial portion of the database 282 | contents in a database in which You have Sui Generis Database 283 | Rights, then the database in which You have Sui Generis Database 284 | Rights (but not its individual contents) is Adapted Material; and 285 | 286 | c. You must comply with the conditions in Section 3(a) if You Share 287 | all or a substantial portion of the contents of the database. 288 | 289 | For the avoidance of doubt, this Section 4 supplements and does not 290 | replace Your obligations under this Public License where the Licensed 291 | Rights include other Copyright and Similar Rights. 292 | 293 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 294 | 295 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 296 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 297 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 298 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 299 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 300 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 301 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 302 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 303 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 304 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 305 | 306 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 307 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 308 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 309 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 310 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 311 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 312 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 313 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 314 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 315 | 316 | c. 
The disclaimer of warranties and limitation of liability provided 317 | above shall be interpreted in a manner that, to the extent 318 | possible, most closely approximates an absolute disclaimer and 319 | waiver of all liability. 320 | 321 | Section 6 -- Term and Termination. 322 | 323 | a. This Public License applies for the term of the Copyright and 324 | Similar Rights licensed here. However, if You fail to comply with 325 | this Public License, then Your rights under this Public License 326 | terminate automatically. 327 | 328 | b. Where Your right to use the Licensed Material has terminated under 329 | Section 6(a), it reinstates: 330 | 331 | 1. automatically as of the date the violation is cured, provided 332 | it is cured within 30 days of Your discovery of the 333 | violation; or 334 | 335 | 2. upon express reinstatement by the Licensor. 336 | 337 | For the avoidance of doubt, this Section 6(b) does not affect any 338 | right the Licensor may have to seek remedies for Your violations 339 | of this Public License. 340 | 341 | c. For the avoidance of doubt, the Licensor may also offer the 342 | Licensed Material under separate terms or conditions or stop 343 | distributing the Licensed Material at any time; however, doing so 344 | will not terminate this Public License. 345 | 346 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 347 | License. 348 | 349 | Section 7 -- Other Terms and Conditions. 350 | 351 | a. The Licensor shall not be bound by any additional or different 352 | terms or conditions communicated by You unless expressly agreed. 353 | 354 | b. Any arrangements, understandings, or agreements regarding the 355 | Licensed Material not stated herein are separate from and 356 | independent of the terms and conditions of this Public License. 357 | 358 | Section 8 -- Interpretation. 359 | 360 | a. For the avoidance of doubt, this Public License does not, and 361 | shall not be interpreted to, reduce, limit, restrict, or impose 362 | conditions on any use of the Licensed Material that could lawfully 363 | be made without permission under this Public License. 364 | 365 | b. To the extent possible, if any provision of this Public License is 366 | deemed unenforceable, it shall be automatically reformed to the 367 | minimum extent necessary to make it enforceable. If the provision 368 | cannot be reformed, it shall be severed from this Public License 369 | without affecting the enforceability of the remaining terms and 370 | conditions. 371 | 372 | c. No term or condition of this Public License will be waived and no 373 | failure to comply consented to unless expressly agreed to by the 374 | Licensor. 375 | 376 | d. Nothing in this Public License constitutes or may be interpreted 377 | as a limitation upon, or waiver of, any privileges and immunities 378 | that apply to the Licensor or You, including from the legal 379 | processes of any jurisdiction or authority. 380 | 381 | ======================================================================= 382 | 383 | Creative Commons is not a party to its public 384 | licenses. Notwithstanding, Creative Commons may elect to apply one of 385 | its public licenses to material it publishes and in those instances 386 | will be considered the “Licensor.” The text of the Creative Commons 387 | public licenses is dedicated to the public domain under the CC0 Public 388 | Domain Dedication. 
Except for the limited purpose of indicating that 389 | material is shared under a Creative Commons public license or as 390 | otherwise permitted by the Creative Commons policies published at 391 | creativecommons.org/policies, Creative Commons does not authorize the 392 | use of the trademark "Creative Commons" or any other trademark or logo 393 | of Creative Commons without its prior written consent including, 394 | without limitation, in connection with any unauthorized modifications 395 | to any of its public licenses or any other arrangements, 396 | understandings, or agreements concerning use of licensed material. For 397 | the avoidance of doubt, this paragraph does not form part of the 398 | public licenses. 399 | 400 | Creative Commons may be contacted at creativecommons.org. -------------------------------------------------------------------------------- /main_finetune.py: -------------------------------------------------------------------------------- 1 | # This source code is licensed under the license found in the 2 | # LICENSE file in the root directory of this source tree. 3 | # -------------------------------------------------------- 4 | # References: 5 | # MAE: https://github.com/facebookresearch/mae 6 | # DeiT: https://github.com/facebookresearch/deit 7 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 8 | # -------------------------------------------------------- 9 | 10 | import argparse 11 | import datetime 12 | import json 13 | import numpy as np 14 | import os 15 | import time 16 | from pathlib import Path 17 | 18 | import torch 19 | import torch.backends.cudnn as cudnn 20 | from torch.utils.tensorboard import SummaryWriter 21 | 22 | from timm.models.layers import trunc_normal_ 23 | from timm.data.mixup import Mixup 24 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy 25 | 26 | import util.lr_decay as lrd 27 | import util.misc as misc 28 | from util.datasets import build_dataset 29 | from util.pos_embed import interpolate_pos_embed 30 | from util.misc import NativeScalerWithGradNormCount as NativeScaler 31 | from util.model_ema import ModelEmaV2 32 | 33 | import models_vit 34 | import models_dual_vit 35 | 36 | from engine_finetune import train_one_epoch, evaluate 37 | 38 | 39 | def get_args_parser(): 40 | parser = argparse.ArgumentParser('MAE fine-tuning for image classification', add_help=False) 41 | parser.add_argument('--batch_size', default=64, type=int, 42 | help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus') 43 | parser.add_argument('--epochs', default=50, type=int) 44 | parser.add_argument('--accum_iter', default=1, type=int, 45 | help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)') 46 | 47 | # Model parameters 48 | parser.add_argument('--model', default='vit_large_patch16', type=str, metavar='MODEL', 49 | help='Name of model to train') 50 | 51 | parser.add_argument('--input_size', default=224, type=int, 52 | help='images input size') 53 | 54 | parser.add_argument('--drop_path', type=float, default=0.1, metavar='PCT', 55 | help='Drop path rate (default: 0.1)') 56 | 57 | # Optimizer parameters 58 | parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM', 59 | help='Clip gradient norm (default: None, no clipping)') 60 | parser.add_argument('--weight_decay', type=float, default=0.05, 61 | help='weight decay (default: 0.05)') 62 | 63 | parser.add_argument('--lr', type=float, default=None, metavar='LR', 64 | help='learning 
rate (absolute lr)') 65 | parser.add_argument('--blr', type=float, default=1e-3, metavar='LR', 66 | help='base learning rate: absolute_lr = base_lr * total_batch_size / 256') 67 | parser.add_argument('--beta2', type=float, default=0.999, metavar='BETA2', 68 | help='beta_2 for optimizer') 69 | parser.add_argument('--layer_decay', type=float, default=0.75, 70 | help='layer-wise lr decay from ELECTRA/BEiT') 71 | 72 | parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR', 73 | help='lower lr bound for cyclic schedulers that hit 0') 74 | 75 | parser.add_argument('--warmup_epochs', type=int, default=5, metavar='N', 76 | help='epochs to warmup LR') 77 | 78 | # Augmentation parameters 79 | parser.add_argument('--color_jitter', type=float, default=None, metavar='PCT', 80 | help='Color jitter factor (enabled only when not using Auto/RandAug)') 81 | parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME', 82 | help='Use AutoAugment policy. "v0" or "original" (default: rand-m9-mstd0.5-inc1)') 83 | parser.add_argument('--smoothing', type=float, default=0.1, 84 | help='Label smoothing (default: 0.1)') 85 | 86 | # * Random Erase params 87 | parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT', 88 | help='Random erase prob (default: 0.25)') 89 | parser.add_argument('--remode', type=str, default='pixel', 90 | help='Random erase mode (default: "pixel")') 91 | parser.add_argument('--recount', type=int, default=1, 92 | help='Random erase count (default: 1)') 93 | parser.add_argument('--resplit', action='store_true', default=False, 94 | help='Do not random erase first (clean) augmentation split') 95 | 96 | # * Mixup params 97 | parser.add_argument('--mixup', type=float, default=0, 98 | help='mixup alpha, mixup enabled if > 0.') 99 | parser.add_argument('--cutmix', type=float, default=0, 100 | help='cutmix alpha, cutmix enabled if > 0.') 101 | parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None, 102 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)') 103 | parser.add_argument('--mixup_prob', type=float, default=1.0, 104 | help='Probability of performing mixup or cutmix when either/both is enabled') 105 | parser.add_argument('--mixup_switch_prob', type=float, default=0.5, 106 | help='Probability of switching to cutmix when both mixup and cutmix enabled') 107 | parser.add_argument('--mixup_mode', type=str, default='batch', 108 | help='How to apply mixup/cutmix params.
Per "batch", "pair", or "elem"') 109 | 110 | # * Finetuning params 111 | parser.add_argument('--finetune', default='', 112 | help='finetune from checkpoint') 113 | parser.add_argument('--use_teacher_ema', action='store_true', help="Use the EMA teacher model.") 114 | parser.add_argument('--global_pool', action='store_true') 115 | parser.set_defaults(global_pool=True) 116 | parser.add_argument('--cls_token', action='store_false', dest='global_pool', 117 | help='Use class token instead of global pool for classification') 118 | 119 | # Dataset parameters 120 | parser.add_argument('--dataset_name', default='imagenet', type=str, metavar='DATASET') 121 | parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str, 122 | help='dataset path') 123 | parser.add_argument('--nb_classes', default=1000, type=int, 124 | help='number of the classification types') 125 | 126 | parser.add_argument('--output_dir', default='./output_dir', 127 | help='path where to save, empty for no saving') 128 | parser.add_argument('--log_dir', default='./output_dir', 129 | help='path where to tensorboard log') 130 | parser.add_argument('--device', default='cuda', 131 | help='device to use for training / testing') 132 | parser.add_argument('--seed', default=0, type=int) 133 | parser.add_argument('--resume', default='', 134 | help='resume from checkpoint') 135 | 136 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 137 | help='start epoch') 138 | parser.add_argument('--eval', action='store_true', 139 | help='Perform evaluation only') 140 | parser.add_argument('--dist_eval', action='store_true', default=False, 141 | help='Enabling distributed evaluation (recommended during training for faster monitor') 142 | parser.add_argument('--num_workers', default=10, type=int) 143 | parser.add_argument('--pin_mem', action='store_true', 144 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 145 | parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem') 146 | parser.set_defaults(pin_mem=True) 147 | 148 | # distributed training parameters 149 | parser.add_argument('--world_size', default=1, type=int, 150 | help='number of distributed processes') 151 | parser.add_argument('--local_rank', default=-1, type=int) 152 | parser.add_argument('--dist_on_itp', action='store_true') 153 | parser.add_argument('--dist_url', default='env://', 154 | help='url used to set up distributed training') 155 | 156 | parser.add_argument('--ema', default=None, type=float, metavar='ALPHA', 157 | help='ema decay (default: None, no ema)') 158 | 159 | parser.add_argument('--train_qkv', action='store_true') 160 | # for attn transfer 161 | parser.add_argument('--mode', default=None, type=str, 162 | choices=[None, 'copy', 'distill', 'copy_q', 'copy_k', 'copy_v', 'distill_q', 'distill_k', 'distill_v'], 163 | help='mode for attention transfer') 164 | parser.add_argument('--end_layer', default=0, type=int) 165 | parser.add_argument('--atd_weight', default=3, type=float) 166 | 167 | 168 | return parser 169 | 170 | 171 | def main(args): 172 | misc.init_distributed_mode(args) 173 | 174 | print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__)))) 175 | print("{}".format(args).replace(', ', ',\n')) 176 | 177 | device = torch.device(args.device) 178 | 179 | # fix the seed for reproducibility 180 | seed = args.seed + misc.get_rank() 181 | torch.manual_seed(seed) 182 | np.random.seed(seed) 183 | 184 | cudnn.benchmark = True 185 | 186 | dataset_train = 
build_dataset(is_train=True, args=args) 187 | dataset_val = build_dataset(is_train=False, args=args) 188 | 189 | if True: # args.distributed: 190 | num_tasks = misc.get_world_size() 191 | global_rank = misc.get_rank() 192 | sampler_train = torch.utils.data.DistributedSampler( 193 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 194 | ) 195 | print("Sampler_train = %s" % str(sampler_train)) 196 | if args.dist_eval: 197 | if len(dataset_val) % num_tasks != 0: 198 | print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. ' 199 | 'This will slightly alter validation results as extra duplicate entries are added to achieve ' 200 | 'equal num of samples per-process.') 201 | sampler_val = torch.utils.data.DistributedSampler( 202 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=True) # shuffle=True to reduce monitor bias 203 | else: 204 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 205 | else: 206 | sampler_train = torch.utils.data.RandomSampler(dataset_train) 207 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 208 | 209 | if global_rank == 0 and args.log_dir is not None and not args.eval: 210 | os.makedirs(args.log_dir, exist_ok=True) 211 | log_writer = SummaryWriter(log_dir=args.log_dir) 212 | else: 213 | log_writer = None 214 | 215 | data_loader_train = torch.utils.data.DataLoader( 216 | dataset_train, sampler=sampler_train, 217 | batch_size=args.batch_size, 218 | num_workers=args.num_workers, 219 | pin_memory=args.pin_mem, 220 | drop_last=True, 221 | ) 222 | 223 | data_loader_val = torch.utils.data.DataLoader( 224 | dataset_val, sampler=sampler_val, 225 | batch_size=args.batch_size, 226 | num_workers=args.num_workers, 227 | pin_memory=args.pin_mem, 228 | drop_last=False 229 | ) 230 | 231 | mixup_fn = None 232 | mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None 233 | if mixup_active: 234 | print("Mixup is activated!") 235 | mixup_fn = Mixup( 236 | mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, 237 | prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, 238 | label_smoothing=args.smoothing, num_classes=args.nb_classes) 239 | 240 | if args.model.startswith("dual_vit"): 241 | model = models_dual_vit.__dict__[args.model]( 242 | mode=args.mode, 243 | num_classes=args.nb_classes, 244 | drop_path_rate=args.drop_path, 245 | global_pool=args.global_pool, 246 | ) 247 | else: 248 | model = models_vit.__dict__[args.model]( 249 | num_classes=args.nb_classes, 250 | drop_path_rate=args.drop_path, 251 | global_pool=args.global_pool, 252 | ) 253 | 254 | if args.finetune and not args.eval: 255 | checkpoint = torch.load(args.finetune, map_location='cpu') 256 | 257 | print("Load pre-trained checkpoint from: %s" % args.finetune) 258 | checkpoint_model = checkpoint['model_ema' if args.use_teacher_ema else 'model'] 259 | state_dict = model.state_dict() 260 | for k in ['head.weight', 'head.bias']: 261 | if k in checkpoint_model and (k not in state_dict or checkpoint_model[k].shape != state_dict[k].shape): 262 | print(f"Removing key {k} from pretrained checkpoint") 263 | del checkpoint_model[k] 264 | 265 | # interpolate position embedding 266 | model_to_interp = model.teacher if 'dual_vit' in args.model else model 267 | interpolate_pos_embed(model_to_interp, checkpoint_model) 268 | 269 | # load pre-trained model 270 | if 'dual_vit' in args.model: 271 | new_checkpoint = {'teacher.' 
+ k: v for k, v in checkpoint_model.items()} 272 | checkpoint_model = new_checkpoint 273 | msg = model.load_state_dict(checkpoint_model, strict=False) 274 | print(msg) 275 | 276 | missing_keys = set(msg.missing_keys) 277 | if 'dual_vit' in args.model: 278 | # remove student. from msg missing keys 279 | for k in msg.missing_keys: 280 | if k.startswith('student.'): 281 | missing_keys.remove(k) 282 | if k.startswith('teacher.'): 283 | missing_keys.add(k[8:]) 284 | missing_keys.remove(k) 285 | assert {'head.weight', 'head.bias'} <= set(missing_keys) 286 | if args.global_pool: 287 | assert set(missing_keys) <= {'head.weight', 'head.bias', 'fc_norm.weight', 'fc_norm.bias'} 288 | 289 | # manually initialize fc layer 290 | if 'dual_vit' in args.model: 291 | # initialize student 292 | trunc_normal_(model.student.head.weight, std=0.02) 293 | else: 294 | trunc_normal_(model.head.weight, std=2e-5) 295 | 296 | model.to(device) 297 | 298 | model_without_ddp = model 299 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 300 | 301 | print("Model = %s" % str(model_without_ddp)) 302 | print('number of params (M): %.2f' % (n_parameters / 1.e6)) 303 | 304 | eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size() 305 | 306 | if args.lr is None: # only base_lr is specified 307 | args.lr = args.blr * eff_batch_size / 256 308 | 309 | print("base lr: %.2e" % (args.lr * 256 / eff_batch_size)) 310 | print("actual lr: %.2e" % args.lr) 311 | 312 | print("accumulate grad iterations: %d" % args.accum_iter) 313 | print("effective batch size: %d" % eff_batch_size) 314 | 315 | if 'dual_vit' in args.model: 316 | for n, p in model.teacher.named_parameters(): 317 | p.requires_grad = False 318 | if args.train_qkv: 319 | for n, p in model.named_parameters(): 320 | if not ('qkv' in n or 'head' in n or 'fc_norm' in n): 321 | p.requires_grad = False 322 | if args.distributed: 323 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 324 | model_without_ddp = model.module 325 | 326 | # build optimizer with layer-wise lr decay (lrd) 327 | param_groups = lrd.param_groups_lrd(model_without_ddp, args.weight_decay, 328 | no_weight_decay_list=model_without_ddp.no_weight_decay(), 329 | layer_decay=args.layer_decay 330 | ) 331 | optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, args.beta2)) 332 | loss_scaler = NativeScaler() 333 | 334 | if mixup_fn is not None: 335 | # smoothing is handled with mixup label transform 336 | criterion = SoftTargetCrossEntropy() 337 | elif args.smoothing > 0.: 338 | criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing) 339 | else: 340 | criterion = torch.nn.CrossEntropyLoss() 341 | 342 | print("criterion = %s" % str(criterion)) 343 | 344 | model_ema = None 345 | if args.ema is not None: 346 | model_ema = ModelEmaV2( 347 | model_without_ddp, 348 | decay=args.ema, 349 | device=None 350 | ) 351 | 352 | misc.load_model(args=args, model_without_ddp=model_without_ddp, model_ema=model_ema, 353 | optimizer=optimizer, loss_scaler=loss_scaler) 354 | 355 | if args.eval: 356 | test_stats = evaluate(data_loader_val, model, device) 357 | print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") 358 | if model_ema is not None: 359 | ema_test_stats = evaluate(data_loader_val, model_ema.module, device) 360 | print( 361 | f"Accuracy of the network (EMA) on the {len(dataset_val)} test images: {ema_test_stats['acc1']:.1f}%" 362 | ) 363 | exit(0) 364 | 365 | print(f"Start training for 
{args.epochs} epochs") 366 | start_time = time.time() 367 | max_accuracy = 0.0 368 | ema_max_accuracy = 0.0 369 | for epoch in range(args.start_epoch, args.epochs): 370 | if args.distributed: 371 | data_loader_train.sampler.set_epoch(epoch) 372 | train_stats = train_one_epoch( 373 | model, criterion, data_loader_train, 374 | optimizer, device, epoch, loss_scaler, 375 | args.clip_grad, mixup_fn, 376 | log_writer=log_writer, 377 | model_ema=model_ema, 378 | args=args 379 | ) 380 | if args.output_dir: 381 | misc.save_model( 382 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, 383 | loss_scaler=loss_scaler, epoch=epoch, model_ema=model_ema) 384 | 385 | test_stats = evaluate(data_loader_val, model, device) 386 | print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") 387 | if model_ema is not None: 388 | ema_test_stats = evaluate(data_loader_val, model_ema.module, device) 389 | print(f"Accuracy of the EMA network on the {len(dataset_val)} test images: {ema_test_stats['acc1']:.1f}%") 390 | ema_max_accuracy = max(ema_max_accuracy, ema_test_stats["acc1"]) 391 | print(f"Max accuracy (EMA): {ema_max_accuracy:.2f}%") 392 | max_accuracy = max(max_accuracy, test_stats["acc1"]) 393 | print(f'Max accuracy: {max_accuracy:.2f}%') 394 | 395 | if log_writer is not None: 396 | log_writer.add_scalar('perf/test_acc1', test_stats['acc1'], epoch) 397 | log_writer.add_scalar('perf/test_acc5', test_stats['acc5'], epoch) 398 | log_writer.add_scalar('perf/test_loss', test_stats['loss'], epoch) 399 | if model_ema is not None: 400 | log_writer.add_scalar('perf/ema_test_acc1', ema_test_stats['acc1'], epoch) 401 | log_writer.add_scalar('perf/ema_test_acc5', ema_test_stats['acc5'], epoch) 402 | log_writer.add_scalar('perf/ema_test_loss', ema_test_stats['loss'], epoch) 403 | 404 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 405 | **{f'test_{k}': v for k, v in test_stats.items()}, 406 | 'epoch': epoch, 407 | 'n_parameters': n_parameters} 408 | if model_ema is not None: 409 | log_stats.update({f'ema_test_{k}': v for k, v in ema_test_stats.items()}) 410 | 411 | if args.output_dir and misc.is_main_process(): 412 | if log_writer is not None: 413 | log_writer.flush() 414 | with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f: 415 | f.write(json.dumps(log_stats) + "\n") 416 | 417 | total_time = time.time() - start_time 418 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 419 | print('Training time {}'.format(total_time_str)) 420 | 421 | 422 | if __name__ == '__main__': 423 | args = get_args_parser() 424 | args = args.parse_args() 425 | if args.output_dir: 426 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 427 | main(args) 428 | --------------------------------------------------------------------------------