├── util
│   ├── __init__.py
│   ├── lr_sched.py
│   ├── datasets.py
│   ├── model_ema.py
│   ├── lr_decay.py
│   ├── pos_embed.py
│   └── misc.py
├── attention_transfer.png
├── environment.yml
├── README.md
├── models_vit.py
├── submitit_finetune.py
├── engine_finetune.py
├── models_dual_vit.py
├── LICENSE
└── main_finetune.py
/util/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/attention_transfer.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alexlioralexli/attention-transfer/HEAD/attention_transfer.png
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: attn-transfer
2 | channels:
3 | - pytorch
4 | - nvidia
5 | dependencies:
6 | - python=3.8
7 | - torchvision
8 | - timm
9 | - torchaudio
10 | - pytorch-cuda=12.1
11 | - pytorch
12 | - tensorboard
--------------------------------------------------------------------------------
/util/lr_sched.py:
--------------------------------------------------------------------------------
1 | # This source code is licensed under the license found in the
2 | # LICENSE file in the root directory of this source tree.
3 | # --------------------------------------------------------
4 | # References:
5 | # MAE: https://github.com/facebookresearch/mae
6 | # --------------------------------------------------------
7 | import math
8 |
9 | def adjust_learning_rate(optimizer, epoch, args):
10 | """Decay the learning rate with half-cycle cosine after warmup"""
11 | if epoch < args.warmup_epochs:
12 | lr = args.lr * epoch / args.warmup_epochs
13 | else:
14 | lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \
15 | (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs)))
16 | for param_group in optimizer.param_groups:
17 | if "lr_scale" in param_group:
18 | param_group["lr"] = lr * param_group["lr_scale"]
19 | else:
20 | param_group["lr"] = lr
21 | return lr
22 |
--------------------------------------------------------------------------------
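
A minimal usage sketch for `adjust_learning_rate` (the hyper-parameter values below are illustrative; in this repo they come from `main_finetune.py`'s argument parser):

```python
# Illustrative sketch: drive the warmup + cosine schedule with a fractional epoch,
# the same way engine_finetune.py calls it inside the data-loader loop.
import argparse

import torch

from util.lr_sched import adjust_learning_rate

args = argparse.Namespace(lr=1e-3, min_lr=1e-6, warmup_epochs=5, epochs=100)  # assumed values
model = torch.nn.Linear(10, 10)
optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)

steps_per_epoch = 4  # assumed; normally len(data_loader)
for epoch in range(args.epochs):
    for step in range(steps_per_epoch):
        # a fractional epoch keeps the decay smooth within an epoch
        lr = adjust_learning_rate(optimizer, step / steps_per_epoch + epoch, args)
```
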
/util/datasets.py:
--------------------------------------------------------------------------------
1 | # This source code is licensed under the license found in the
2 | # LICENSE file in the root directory of this source tree.
3 | # --------------------------------------------------------
4 | # References:
5 | # MAE: https://github.com/facebookresearch/mae
6 | # DeiT: https://github.com/facebookresearch/deit
7 | # --------------------------------------------------------
8 |
9 | import os
10 | import PIL
11 |
12 | from torchvision import datasets, transforms
13 |
14 | from timm.data import create_transform
15 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
16 |
17 | def build_dataset(is_train, args):
18 | transform = build_transform(is_train, args)
19 |
20 | if args.dataset_name == 'imagenet':
21 | root = os.path.join(args.data_path, 'train' if is_train else 'val')
22 | dataset = datasets.ImageFolder(root, transform=transform)
23 | else:
24 | raise NotImplementedError
25 |
26 | print(dataset)
27 |
28 | return dataset
29 |
30 |
31 | def build_transform(is_train, args):
32 | mean = IMAGENET_DEFAULT_MEAN
33 | std = IMAGENET_DEFAULT_STD
34 | # train transform
35 | if is_train:
36 | # this should always dispatch to transforms_imagenet_train
37 | transform = create_transform(
38 | input_size=args.input_size,
39 | is_training=True,
40 | color_jitter=args.color_jitter,
41 | auto_augment=args.aa,
42 | interpolation='bicubic',
43 | re_prob=args.reprob,
44 | re_mode=args.remode,
45 | re_count=args.recount,
46 | mean=mean,
47 | std=std,
48 | )
49 | return transform
50 |
51 | # eval transform
52 | t = []
53 | if args.input_size <= 224:
54 | crop_pct = 224 / 256
55 | else:
56 | crop_pct = 1.0
57 | size = int(args.input_size / crop_pct)
58 | t.append(
59 | transforms.Resize(size, interpolation=PIL.Image.BICUBIC), # to maintain same ratio w.r.t. 224 images
60 | )
61 | t.append(transforms.CenterCrop(args.input_size))
62 |
63 | t.append(transforms.ToTensor())
64 | t.append(transforms.Normalize(mean, std))
65 | return transforms.Compose(t)
66 |
--------------------------------------------------------------------------------
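
A minimal usage sketch for `build_dataset` (paths and augmentation values are illustrative; the real ones come from the fine-tuning arguments):

```python
# Illustrative sketch: build eval and train datasets from an ImageFolder layout.
import argparse

from util.datasets import build_dataset

args = argparse.Namespace(
    dataset_name='imagenet',
    data_path='/path/to/imagenet',          # expects train/ and val/ subfolders
    input_size=224,
    # fields below are only consumed when is_train=True
    color_jitter=None,
    aa='rand-m9-mstd0.5-inc1',
    reprob=0.25, remode='pixel', recount=1,
)

val_dataset = build_dataset(is_train=False, args=args)
train_dataset = build_dataset(is_train=True, args=args)
```
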
/README.md:
--------------------------------------------------------------------------------
1 | ## Attention Transfer: A PyTorch Implementation
2 |
3 |
4 |
5 |
6 |
7 |
8 | This is a PyTorch/GPU reimplementation of the NeurIPS 2024 paper [On the Surprising Effectiveness of Attention Transfer for Vision Transformers](https://arxiv.org/abs/2411.09702):
9 | ```
10 | @inproceedings{AttentionTransfer2024,
11 | title = {On the Surprising Effectiveness of Attention Transfer for Vision Transformers},
12 | author = {Li, Alexander Cong and Tian, Yuandong and Chen, Beidi and Pathak, Deepak and Chen, Xinlei},
13 | booktitle = {The Thirty-eighth Annual Conference on Neural Information Processing Systems},
14 | year = {2024}
15 | }
16 | ```
17 |
18 | * The original implementation was in Jax+TPU. This reimplementation is in PyTorch+GPU.
19 |
20 | * This repo is a modification of the [MAE repo](https://github.com/facebookresearch/mae). Refer to that repo for detailed installation and setup.
21 |
22 | ### Installation
23 | ```
24 | conda env create -f environment.yml
25 | ```
26 |
27 | ### Training student with pre-trained teacher
28 | Obtain pre-trained MAE checkpoints from [here](https://github.com/facebookresearch/mae).
29 |
30 | **Attention Distillation**
31 |
32 | To train with multi-node distributed training, run the following on 8 nodes with 8 GPUs each:
33 | ```
34 | python submitit_finetune.py \
35 | --job_dir ${JOB_DIR} \
36 | --nodes 8 \
37 | --batch_size 32 \
38 | --model dual_vit_large_patch16 --mode distill --end_layer -6 --atd_weight 3.0 \
39 | --finetune mae_pretrain_vit_large.pth --resume allow \
40 | --epochs 200 --ema 0.9999 \
41 | --blr 1e-4 --layer_decay 1 --beta2 0.95 --warmup_epochs 20 \
42 | --weight_decay 0.3 --drop_path 0.2 --reprob 0.25 --mixup 0.8 --cutmix 1.0 \
43 | --dist_eval --data_path ${IMAGENET_PATH}
44 | ```
45 |
46 | **Attention Copy**
47 | ```
48 | python submitit_finetune.py \
49 | --job_dir ${JOB_DIR} \
50 | --nodes 8 \
51 | --batch_size 32 \
52 | --model dual_vit_large_patch16 --mode copy \
53 | --finetune mae_pretrain_vit_large.pth --resume allow \
54 | --epochs 100 --ema 0.9999 \
55 | --blr 1e-3 --min_lr 2e-3 --layer_decay 0.75 --beta2 0.999 \
56 | --weight_decay 0.05 --drop_path 0 --reprob 0.25 --mixup 0.8 --cutmix 1.0 \
57 | --dist_eval --data_path ${IMAGENET_PATH}
58 | ```
59 |
60 |
61 |
62 |
63 |
64 |
65 | ### License
66 |
67 | This project is under the CC-BY-NC 4.0 license. See [LICENSE](LICENSE) for details.
68 |
--------------------------------------------------------------------------------
/util/model_ema.py:
--------------------------------------------------------------------------------
1 | # from Xinlei Chen
2 | from copy import deepcopy
3 |
4 | import torch
5 | import torch.nn as nn
6 |
7 |
8 | class ModelEmaV2(nn.Module):
9 | """Model Exponential Moving Average V2
10 |
11 | Keep a moving average of everything in the model state_dict (parameters and buffers).
12 | V2 of this module is simpler, it does not match params/buffers based on name but simply
13 | iterates in order. It works with torchscript (JIT of full model).
14 |
15 | This is intended to allow functionality like
16 | https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
17 |
18 | A smoothed version of the weights is necessary for some training schemes to perform well.
19 | E.g. Google's hyper-params for training MNASNet, MobileNet-V3, EfficientNet, etc that use
20 | RMSprop with a short 2.4-3 epoch decay period and slow LR decay rate of .96-.99 requires EMA
21 | smoothing of weights to match results. Pay attention to the decay constant you are using
22 | relative to your update count per epoch.
23 |
24 | To keep EMA from using GPU resources, set device='cpu'. This will save a bit of memory but
25 | disable validation of the EMA weights. Validation will have to be done manually in a separate
26 | process, or after the training stops converging.
27 |
28 | This class is sensitive to where it is initialized in the sequence of model init,
29 | GPU assignment and distributed training wrappers.
30 | """
31 |
32 | def __init__(self, model, decay=0.9999, device=None):
33 | super(ModelEmaV2, self).__init__()
34 | # make a copy of the model for accumulating moving average of weights
35 | self.module = deepcopy(model)
36 | self.module.eval()
37 | self.decay = decay
38 | self.device = device # perform ema on different device from model if set
39 | if self.device is not None:
40 | self.module.to(device=device)
41 |
42 | def _update(self, model, update_fn):
43 | with torch.no_grad():
44 | for ema_v, model_v in zip(
45 | self.module.state_dict().values(), model.state_dict().values()
46 | ):
47 | if self.device is not None:
48 | model_v = model_v.to(device=self.device)
49 | ema_v.copy_(update_fn(ema_v, model_v))
50 |
51 | def update(self, model):
52 | self._update(
53 | model, update_fn=lambda e, m: self.decay * e + (1.0 - self.decay) * m
54 | )
55 |
56 | def set(self, model):
57 | self._update(model, update_fn=lambda e, m: m)
58 |
--------------------------------------------------------------------------------
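
A minimal usage sketch for `ModelEmaV2` (model and decay are placeholders): update the EMA copy after each optimizer step and evaluate with `model_ema.module`.

```python
# Illustrative sketch: keep an EMA shadow of the weights during training.
import torch
import torch.nn as nn

from util.model_ema import ModelEmaV2

model = nn.Linear(10, 2)
model_ema = ModelEmaV2(model, decay=0.9999)      # EMA copy starts as a deepcopy of model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

for _ in range(10):
    x, y = torch.randn(8, 10), torch.randint(0, 2, (8,))
    loss = nn.functional.cross_entropy(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    model_ema.update(model)                      # ema = decay * ema + (1 - decay) * model

# the smoothed weights live in model_ema.module (used for evaluation / checkpointing)
```
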
/util/lr_decay.py:
--------------------------------------------------------------------------------
1 | # This source code is licensed under the license found in the
2 | # LICENSE file in the root directory of this source tree.
3 | # --------------------------------------------------------
4 | # References:
5 | # MAE: https://github.com/facebookresearch/mae
6 | # ELECTRA https://github.com/google-research/electra
7 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit
8 | # --------------------------------------------------------
9 |
10 | import models_dual_vit
11 |
12 | def param_groups_lrd(model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=.75):
13 | """
14 | Parameter groups for layer-wise lr decay
15 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58
16 | """
17 | param_group_names = {}
18 | param_groups = {}
19 |
20 | if isinstance(model, models_dual_vit.DualVisionTransformer):
21 | num_layers = model.student_depth
22 | else:
23 | num_layers = len(model.blocks) + 1
24 |
25 | layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1))
26 |
27 | for n, p in model.named_parameters():
28 | if not p.requires_grad:
29 | continue
30 |
31 | # no decay: all 1D parameters and model specific ones
32 | if p.ndim == 1 or n in no_weight_decay_list:
33 | g_decay = "no_decay"
34 | this_decay = 0.
35 | else:
36 | g_decay = "decay"
37 | this_decay = weight_decay
38 |
39 | layer_id = get_layer_id_for_vit(n, num_layers)
40 | group_name = "layer_%d_%s" % (layer_id, g_decay)
41 |
42 | if group_name not in param_group_names:
43 | this_scale = layer_scales[layer_id]
44 |
45 | param_group_names[group_name] = {
46 | "lr_scale": this_scale,
47 | "weight_decay": this_decay,
48 | "params": [],
49 | }
50 | param_groups[group_name] = {
51 | "lr_scale": this_scale,
52 | "weight_decay": this_decay,
53 | "params": [],
54 | }
55 |
56 | param_group_names[group_name]["params"].append(n)
57 | param_groups[group_name]["params"].append(p)
58 |
59 | return list(param_groups.values())
60 |
61 |
62 | def get_layer_id_for_vit(name, num_layers):
63 | """
64 | Assign a parameter with its layer id
65 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33
66 | """
67 |
68 | if 'cls_token' in name or 'pos_embed' in name:
69 | return 0
70 | elif 'patch_embed' in name:
71 | return 0
72 | elif 'blocks' in name:
73 | if 'student' in name:
74 | return int(name.split('.')[2]) + 1
75 | else:
76 | return int(name.split('.')[1]) + 1
77 | else:
78 | return num_layers
--------------------------------------------------------------------------------
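
A minimal usage sketch for `param_groups_lrd` (hyper-parameters are illustrative): the returned groups carry a per-group `lr_scale` that `util.lr_sched.adjust_learning_rate` multiplies into the scheduled learning rate.

```python
# Illustrative sketch: layer-wise lr decay groups fed to AdamW.
import torch

import models_vit
from util.lr_decay import param_groups_lrd

model = models_vit.vit_base_patch16(num_classes=1000, global_pool=True)
param_groups = param_groups_lrd(
    model,
    weight_decay=0.05,
    no_weight_decay_list=model.no_weight_decay(),
    layer_decay=0.75,
)
optimizer = torch.optim.AdamW(param_groups, lr=1e-3)
```
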
/models_vit.py:
--------------------------------------------------------------------------------
1 | # This source code is licensed under the license found in the
2 | # LICENSE file in the root directory of this source tree.
3 | # --------------------------------------------------------
4 | # References:
5 | # MAE: https://github.com/facebookresearch/mae
6 | # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
7 | # DeiT: https://github.com/facebookresearch/deit
8 | # --------------------------------------------------------
9 |
10 | from functools import partial
11 |
12 | import torch
13 | import torch.nn as nn
14 |
15 | import timm.models.vision_transformer
16 |
17 |
18 | class VisionTransformer(timm.models.vision_transformer.VisionTransformer):
19 | """ Vision Transformer with support for global average pooling
20 | """
21 | def __init__(self, global_pool=False, **kwargs):
22 | super(VisionTransformer, self).__init__(**kwargs)
23 |
24 | self.global_pool = global_pool
25 | if self.global_pool:
26 | norm_layer = kwargs['norm_layer']
27 | embed_dim = kwargs['embed_dim']
28 | self.fc_norm = norm_layer(embed_dim)
29 |
30 | del self.norm # remove the original norm
31 |
32 | def forward_features(self, x, layer_to_return=None):
33 | # return features up until layer_to_return
34 | intermediate_feat = []
35 | B = x.shape[0]
36 | x = self.patch_embed(x)
37 |
38 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
39 | x = torch.cat((cls_tokens, x), dim=1)
40 | x = x + self.pos_embed
41 | x = self.pos_drop(x)
42 |
43 | for i, blk in enumerate(self.blocks):
44 | x = blk(x)
45 | if layer_to_return is not None and i < layer_to_return:
46 | intermediate_feat.append(x)
47 |
48 | if self.global_pool:
49 | x = x[:, 1:, :].mean(dim=1) # global pool without cls token
50 | outcome = self.fc_norm(x)
51 | else:
52 | x = self.norm(x)
53 | outcome = x[:, 0]
54 |
55 | if layer_to_return is not None:
56 | return outcome, intermediate_feat
57 | else:
58 | return outcome
59 |
60 | def forward(self, x, return_features=False, layer_to_return=None):
61 | if layer_to_return is not None:
62 | final_feats, intermediate_feats = self.forward_features(x, layer_to_return=layer_to_return)
63 | else:
64 | final_feats = self.forward_features(x)
65 | pred = self.head(final_feats)
66 | if return_features and layer_to_return is None:
67 | return pred, final_feats
68 | elif return_features and layer_to_return is not None:
69 | return pred, intermediate_feats
70 | else:
71 | return pred
72 |
73 |
74 | def vit_base_patch16(**kwargs):
75 | model = VisionTransformer(
76 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
77 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
78 | return model
79 |
80 |
81 | def vit_large_patch16(**kwargs):
82 | model = VisionTransformer(
83 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
84 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
85 | return model
86 |
87 |
88 | def vit_huge_patch14(**kwargs):
89 | model = VisionTransformer(
90 | patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, qkv_bias=True,
91 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
92 | return model
--------------------------------------------------------------------------------
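
A minimal usage sketch for the `layer_to_return` hook added here on top of timm's `VisionTransformer` (shapes assume a 224x224 input and patch size 16):

```python
# Illustrative sketch: grab intermediate block outputs alongside the prediction.
import torch

import models_vit

model = models_vit.vit_base_patch16(num_classes=1000)
model.eval()

x = torch.randn(2, 3, 224, 224)
with torch.no_grad():
    pred = model(x)                                           # (2, 1000) logits
    pred, feats = model(x, return_features=True, layer_to_return=6)
# feats holds the token sequences after blocks 0..5, each of shape (2, 197, 768)
```
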
/util/pos_embed.py:
--------------------------------------------------------------------------------
1 | # This source code is licensed under the license found in the
2 | # LICENSE file in the root directory of this source tree.
3 | # --------------------------------------------------------
4 | # References:
5 | # MAE: https://github.com/facebookresearch/mae
6 | # --------------------------------------------------------
7 | # Position embedding utils
8 | # --------------------------------------------------------
9 |
10 | import numpy as np
11 |
12 | import torch
13 |
14 | # --------------------------------------------------------
15 | # 2D sine-cosine position embedding
16 | # References:
17 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
18 | # MoCo v3: https://github.com/facebookresearch/moco-v3
19 | # --------------------------------------------------------
20 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
21 | """
22 | grid_size: int of the grid height and width
23 | return:
24 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
25 | """
26 | grid_h = np.arange(grid_size, dtype=np.float32)
27 | grid_w = np.arange(grid_size, dtype=np.float32)
28 | grid = np.meshgrid(grid_w, grid_h) # here w goes first
29 | grid = np.stack(grid, axis=0)
30 |
31 | grid = grid.reshape([2, 1, grid_size, grid_size])
32 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
33 | if cls_token:
34 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
35 | return pos_embed
36 |
37 |
38 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
39 | assert embed_dim % 2 == 0
40 |
41 | # use half of dimensions to encode grid_h
42 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
43 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
44 |
45 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
46 | return emb
47 |
48 |
49 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
50 | """
51 | embed_dim: output dimension for each position
52 | pos: a list of positions to be encoded: size (M,)
53 | out: (M, D)
54 | """
55 | assert embed_dim % 2 == 0
56 | omega = np.arange(embed_dim // 2, dtype=np.float64)
57 | omega /= embed_dim / 2.
58 | omega = 1. / 10000**omega # (D/2,)
59 |
60 | pos = pos.reshape(-1) # (M,)
61 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
62 |
63 | emb_sin = np.sin(out) # (M, D/2)
64 | emb_cos = np.cos(out) # (M, D/2)
65 |
66 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
67 | return emb
68 |
69 |
70 | # --------------------------------------------------------
71 | # Interpolate position embeddings for high-resolution
72 | # References:
73 | # DeiT: https://github.com/facebookresearch/deit
74 | # --------------------------------------------------------
75 | def interpolate_pos_embed(model, checkpoint_model):
76 | if 'pos_embed' in checkpoint_model:
77 | pos_embed_checkpoint = checkpoint_model['pos_embed']
78 | embedding_size = pos_embed_checkpoint.shape[-1]
79 | num_patches = model.patch_embed.num_patches
80 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches
81 | # height (== width) for the checkpoint position embedding
82 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
83 | # height (== width) for the new position embedding
84 | new_size = int(num_patches ** 0.5)
85 | # class_token and dist_token are kept unchanged
86 | if orig_size != new_size:
87 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
88 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
89 | # only the position tokens are interpolated
90 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
91 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
92 | pos_tokens = torch.nn.functional.interpolate(
93 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
94 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
95 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
96 | checkpoint_model['pos_embed'] = new_pos_embed
97 |
--------------------------------------------------------------------------------
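
A minimal usage sketch for `get_2d_sincos_pos_embed` (dimensions are illustrative): build a fixed sin-cos table and copy it into a frozen `pos_embed` parameter, as MAE-style models do at init time.

```python
# Illustrative sketch: fixed 2D sin-cos position embedding for a 14x14 patch grid.
import torch

from util.pos_embed import get_2d_sincos_pos_embed

embed_dim, grid_size = 768, 14          # 224px image, 16px patches -> 14x14 grid
pos_embed = torch.nn.Parameter(
    torch.zeros(1, grid_size * grid_size + 1, embed_dim), requires_grad=False)

table = get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=True)   # (197, 768) ndarray
pos_embed.data.copy_(torch.from_numpy(table).float().unsqueeze(0))
```
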
/submitit_finetune.py:
--------------------------------------------------------------------------------
1 | # This source code is licensed under the license found in the
2 | # LICENSE file in the root directory of this source tree.
3 | # --------------------------------------------------------
4 | # References:
5 | # MAE: https://github.com/facebookresearch/mae
6 | # --------------------------------------------------------
7 | # A script to run multinode training with submitit.
8 | # --------------------------------------------------------
9 |
10 | import argparse
11 | import os
12 | import uuid
13 | from pathlib import Path
14 |
15 | import main_finetune as classification
16 | import submitit
17 |
18 |
19 | def parse_args():
20 | classification_parser = classification.get_args_parser()
21 | parser = argparse.ArgumentParser("Submitit for MAE finetune", parents=[classification_parser])
22 | parser.add_argument("--ngpus", default=8, type=int, help="Number of gpus to request on each node")
23 | parser.add_argument("--ncpus", default=10, type=int, help="Number of cpus per gpu")
24 | parser.add_argument("--nodes", default=2, type=int, help="Number of nodes to request")
25 | parser.add_argument("--timeout", default=4320, type=int, help="Duration of the job")
26 | parser.add_argument("--job_dir", default="", type=str, help="Job dir. Leave empty for automatic.")
27 |
28 | parser.add_argument("--partition", default="learnfair", type=str, help="Partition where to submit")
29 | parser.add_argument("--use_volta32", action='store_true', help="Request 32G V100 GPUs")
30 | parser.add_argument('--comment', default="", type=str, help="Comment to pass to scheduler")
31 | parser.add_argument('--name', default="mae", type=str, help="Name for job")
32 | return parser.parse_args()
33 |
34 |
35 | def get_shared_folder() -> Path:
36 | user = os.getenv("USER")
37 | if Path("/checkpoint/").is_dir():
38 | p = Path(f"/checkpoint/{user}/experiments")
39 | p.mkdir(exist_ok=True)
40 | return p
41 | raise RuntimeError("No shared folder available")
42 |
43 |
44 | def get_init_file():
45 | # Init file must not exist, but its parent dir must exist.
46 | os.makedirs(str(get_shared_folder()), exist_ok=True)
47 | init_file = get_shared_folder() / f"{uuid.uuid4().hex}_init"
48 | if init_file.exists():
49 | os.remove(str(init_file))
50 | return init_file
51 |
52 |
53 | class Trainer(object):
54 | def __init__(self, args):
55 | self.args = args
56 |
57 | def __call__(self):
58 | import main_finetune as classification
59 |
60 | self._setup_gpu_args()
61 | self._setup_fair()
62 | classification.main(self.args)
63 |
64 | def checkpoint(self):
65 | import os
66 | import submitit
67 |
68 | self.args.dist_url = get_init_file().as_uri()
69 | checkpoint_file = os.path.join(self.args.output_dir, "checkpoint.pth")
70 | if os.path.exists(checkpoint_file):
71 | self.args.resume = checkpoint_file
72 | print("Requeuing ", self.args)
73 | empty_trainer = type(self)(self.args)
74 | return submitit.helpers.DelayedSubmission(empty_trainer)
75 |
76 | def _setup_gpu_args(self):
77 | import submitit
78 | from pathlib import Path
79 |
80 | job_env = submitit.JobEnvironment()
81 | self.args.output_dir = Path(str(self.args.output_dir).replace("%j", str(job_env.job_id)))
82 | self.args.log_dir = self.args.output_dir
83 | self.args.gpu = job_env.local_rank
84 | self.args.rank = job_env.global_rank
85 | self.args.world_size = job_env.num_tasks
86 | print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}")
87 |
88 | def _setup_fair(self):
89 | os.environ["GLOO_SOCKET_IFNAME"] = ""
90 | os.environ["NCCL_SOCKET_IFNAME"] = ""
91 | os.environ["NCCL_DEBUG"] = "INFO"
92 | # os.environ["NCCL_BLOCKING_WAIT"] = '1'
93 | os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "1"
94 | # os.environ["CUDA_LAUNCH_BLOCKING"] = '1'
95 | return
96 |
97 |
98 | def main():
99 | args = parse_args()
100 | if args.job_dir == "":
101 | args.job_dir = get_shared_folder() / "%j"
102 |
103 | # Note that the folder will depend on the job_id, to easily track experiments
104 | executor = submitit.AutoExecutor(folder=args.job_dir, slurm_max_num_timeout=30)
105 |
106 | num_gpus_per_node = args.ngpus
107 | num_cpus_per_gpu = args.ncpus
108 | nodes = args.nodes
109 | timeout_min = args.timeout
110 |
111 | partition = args.partition
112 | kwargs = {}
113 | if args.use_volta32:
114 | kwargs['slurm_constraint'] = 'volta32gb'
115 | if args.comment:
116 | kwargs['slurm_comment'] = args.comment
117 |
118 | executor.update_parameters(
119 | mem_gb=40 * num_gpus_per_node,
120 | gpus_per_node=num_gpus_per_node,
121 | tasks_per_node=num_gpus_per_node, # one task per GPU
122 | cpus_per_task=num_cpus_per_gpu,
123 | nodes=nodes,
124 | timeout_min=timeout_min,
125 | # Below are cluster dependent parameters
126 | slurm_partition=partition,
127 | slurm_signal_delay_s=120,
128 | **kwargs
129 | )
130 |
131 | executor.update_parameters(name=args.name)
132 |
133 | args.dist_url = get_init_file().as_uri()
134 | args.output_dir = args.job_dir
135 |
136 | trainer = Trainer(args)
137 | job = executor.submit(trainer)
138 |
139 | print(job.job_id)
140 |
141 |
142 | if __name__ == "__main__":
143 | main()
144 |
--------------------------------------------------------------------------------
/engine_finetune.py:
--------------------------------------------------------------------------------
1 | # This source code is licensed under the license found in the
2 | # LICENSE file in the root directory of this source tree.
3 | # --------------------------------------------------------
4 | # References:
5 | # MAE: https://github.com/facebookresearch/mae
6 | # DeiT: https://github.com/facebookresearch/deit
7 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit
8 | # --------------------------------------------------------
9 |
10 | import math
11 | import sys
12 | from typing import Iterable, Optional
13 |
14 | import torch
15 |
16 | from timm.data import Mixup
17 | from timm.utils import accuracy
18 |
19 | import util.misc as misc
20 | import util.lr_sched as lr_sched
21 |
22 |
23 | def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
24 | data_loader: Iterable, optimizer: torch.optim.Optimizer,
25 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
26 | mixup_fn: Optional[Mixup] = None, log_writer=None,
27 | args=None, model_ema=None):
28 | model.train(True)
29 | metric_logger = misc.MetricLogger(delimiter=" ")
30 | metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
31 | header = 'Epoch: [{}]'.format(epoch)
32 | print_freq = 20
33 |
34 | accum_iter = args.accum_iter
35 |
36 | optimizer.zero_grad()
37 |
38 | if log_writer is not None:
39 | print('log_dir: {}'.format(log_writer.log_dir))
40 |
41 | for data_iter_step, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
42 | # we use a per iteration (instead of per epoch) lr scheduler
43 | if data_iter_step % accum_iter == 0:
44 | lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)
45 |
46 | samples = samples.to(device, non_blocking=True)
47 | targets = targets.to(device, non_blocking=True)
48 |
49 | if mixup_fn is not None:
50 | samples, targets = mixup_fn(samples, targets)
51 |
52 | if args.model.startswith("dual_vit") and 'distill' in args.mode:
53 | outputs, distill_loss = model(samples)
54 | metric_logger.update(distill_loss=distill_loss.item())
55 | loss = criterion(outputs, targets) + args.atd_weight * distill_loss
56 | else:
57 | # with torch.cuda.amp.autocast():
58 | outputs = model(samples)
59 | loss = criterion(outputs, targets)
60 |
61 | loss_value = loss.item()
62 |
63 | if not math.isfinite(loss_value):
64 | print("Loss is {}, stopping training".format(loss_value))
65 | sys.exit(1)
66 |
67 | loss /= accum_iter
68 | loss_scaler(loss, optimizer, clip_grad=max_norm,
69 | parameters=model.parameters(), create_graph=False,
70 | update_grad=(data_iter_step + 1) % accum_iter == 0)
71 | if (data_iter_step + 1) % accum_iter == 0:
72 | optimizer.zero_grad()
73 | if model_ema is not None:
74 | model_ema.update(model)
75 |
76 | torch.cuda.synchronize()
77 |
78 | metric_logger.update(loss=loss_value)
79 | min_lr = 10.
80 | max_lr = 0.
81 | for group in optimizer.param_groups:
82 | min_lr = min(min_lr, group["lr"])
83 | max_lr = max(max_lr, group["lr"])
84 |
85 | metric_logger.update(lr=max_lr)
86 |
87 | loss_value_reduce = misc.all_reduce_mean(loss_value)
88 | if log_writer is not None and (data_iter_step + 1) % accum_iter == 0:
89 | """ We use epoch_1000x as the x-axis in tensorboard.
90 | This calibrates different curves when batch size changes.
91 | """
92 | epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
93 | log_writer.add_scalar('loss', loss_value_reduce, epoch_1000x)
94 | log_writer.add_scalar('lr', max_lr, epoch_1000x)
95 | if args.model.startswith("dual_vit") and 'distill' in args.mode:
96 | log_writer.add_scalar('distill_loss', distill_loss.item(), epoch_1000x)
97 |
98 | # gather the stats from all processes
99 | metric_logger.synchronize_between_processes()
100 | print("Averaged stats:", metric_logger)
101 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
102 |
103 |
104 | @torch.no_grad()
105 | def evaluate(data_loader, model, device):
106 | criterion = torch.nn.CrossEntropyLoss()
107 |
108 | metric_logger = misc.MetricLogger(delimiter=" ")
109 | header = 'Test:'
110 |
111 | # switch to evaluation mode
112 | model.eval()
113 |
114 | for batch in metric_logger.log_every(data_loader, 10, header):
115 | images = batch[0]
116 | target = batch[-1]
117 | images = images.to(device, non_blocking=True)
118 | target = target.to(device, non_blocking=True)
119 |
120 | # compute output
121 | # with torch.cuda.amp.autocast():
122 | output = model(images)
123 | loss = criterion(output, target)
124 |
125 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
126 |
127 | batch_size = images.shape[0]
128 | metric_logger.update(loss=loss.item())
129 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
130 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
131 | # gather the stats from all processes
132 | metric_logger.synchronize_between_processes()
133 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
134 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))
135 |
136 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
--------------------------------------------------------------------------------
/util/misc.py:
--------------------------------------------------------------------------------
1 | # This source code is licensed under the license found in the
2 | # LICENSE file in the root directory of this source tree.
3 | # --------------------------------------------------------
4 | # References:
5 | # MAE: https://github.com/facebookresearch/mae
6 | # DeiT: https://github.com/facebookresearch/deit
7 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit
8 | # --------------------------------------------------------
9 |
10 | import builtins
11 | import datetime
12 | import os
13 | import time
14 | from collections import defaultdict, deque
15 | from pathlib import Path
16 |
17 | import torch
18 | import torch.distributed as dist
19 | from torch import inf
20 |
21 |
22 | class SmoothedValue(object):
23 | """Track a series of values and provide access to smoothed values over a
24 | window or the global series average.
25 | """
26 |
27 | def __init__(self, window_size=20, fmt=None):
28 | if fmt is None:
29 | fmt = "{median:.4f} ({global_avg:.4f})"
30 | self.deque = deque(maxlen=window_size)
31 | self.total = 0.0
32 | self.count = 0
33 | self.fmt = fmt
34 |
35 | def update(self, value, n=1):
36 | self.deque.append(value)
37 | self.count += n
38 | self.total += value * n
39 |
40 | def synchronize_between_processes(self):
41 | """
42 | Warning: does not synchronize the deque!
43 | """
44 | if not is_dist_avail_and_initialized():
45 | return
46 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
47 | dist.barrier()
48 | dist.all_reduce(t)
49 | t = t.tolist()
50 | self.count = int(t[0])
51 | self.total = t[1]
52 |
53 | @property
54 | def median(self):
55 | d = torch.tensor(list(self.deque))
56 | return d.median().item()
57 |
58 | @property
59 | def avg(self):
60 | d = torch.tensor(list(self.deque), dtype=torch.float32)
61 | return d.mean().item()
62 |
63 | @property
64 | def global_avg(self):
65 | return self.total / self.count
66 |
67 | @property
68 | def max(self):
69 | return max(self.deque)
70 |
71 | @property
72 | def value(self):
73 | return self.deque[-1]
74 |
75 | def __str__(self):
76 | return self.fmt.format(
77 | median=self.median,
78 | avg=self.avg,
79 | global_avg=self.global_avg,
80 | max=self.max,
81 | value=self.value)
82 |
83 |
84 | class MetricLogger(object):
85 | def __init__(self, delimiter="\t"):
86 | self.meters = defaultdict(SmoothedValue)
87 | self.delimiter = delimiter
88 |
89 | def update(self, **kwargs):
90 | for k, v in kwargs.items():
91 | if v is None:
92 | continue
93 | if isinstance(v, torch.Tensor):
94 | v = v.item()
95 | assert isinstance(v, (float, int))
96 | self.meters[k].update(v)
97 |
98 | def __getattr__(self, attr):
99 | if attr in self.meters:
100 | return self.meters[attr]
101 | if attr in self.__dict__:
102 | return self.__dict__[attr]
103 | raise AttributeError("'{}' object has no attribute '{}'".format(
104 | type(self).__name__, attr))
105 |
106 | def __str__(self):
107 | loss_str = []
108 | for name, meter in self.meters.items():
109 | loss_str.append(
110 | "{}: {}".format(name, str(meter))
111 | )
112 | return self.delimiter.join(loss_str)
113 |
114 | def synchronize_between_processes(self):
115 | for meter in self.meters.values():
116 | meter.synchronize_between_processes()
117 |
118 | def add_meter(self, name, meter):
119 | self.meters[name] = meter
120 |
121 | def log_every(self, iterable, print_freq, header=None):
122 | i = 0
123 | if not header:
124 | header = ''
125 | start_time = time.time()
126 | end = time.time()
127 | iter_time = SmoothedValue(fmt='{avg:.4f}')
128 | data_time = SmoothedValue(fmt='{avg:.4f}')
129 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
130 | log_msg = [
131 | header,
132 | '[{0' + space_fmt + '}/{1}]',
133 | 'eta: {eta}',
134 | '{meters}',
135 | 'time: {time}',
136 | 'data: {data}'
137 | ]
138 | if torch.cuda.is_available():
139 | log_msg.append('max mem: {memory:.0f}')
140 | log_msg = self.delimiter.join(log_msg)
141 | MB = 1024.0 * 1024.0
142 | for obj in iterable:
143 | data_time.update(time.time() - end)
144 | yield obj
145 | iter_time.update(time.time() - end)
146 | if i % print_freq == 0 or i == len(iterable) - 1:
147 | eta_seconds = iter_time.global_avg * (len(iterable) - i)
148 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
149 | if torch.cuda.is_available():
150 | print(log_msg.format(
151 | i, len(iterable), eta=eta_string,
152 | meters=str(self),
153 | time=str(iter_time), data=str(data_time),
154 | memory=torch.cuda.max_memory_allocated() / MB))
155 | else:
156 | print(log_msg.format(
157 | i, len(iterable), eta=eta_string,
158 | meters=str(self),
159 | time=str(iter_time), data=str(data_time)))
160 | i += 1
161 | end = time.time()
162 | total_time = time.time() - start_time
163 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
164 | print('{} Total time: {} ({:.4f} s / it)'.format(
165 | header, total_time_str, total_time / len(iterable)))
166 |
167 |
168 | def setup_for_distributed(is_master):
169 | """
170 | This function disables printing when not in master process
171 | """
172 | builtin_print = builtins.print
173 |
174 | def print(*args, **kwargs):
175 | force = kwargs.pop('force', False)
176 | force = force or (get_world_size() > 8)
177 | if is_master or force:
178 | now = datetime.datetime.now().time()
179 | builtin_print('[{}] '.format(now), end='') # print with time stamp
180 | builtin_print(*args, **kwargs)
181 |
182 | builtins.print = print
183 |
184 |
185 | def is_dist_avail_and_initialized():
186 | if not dist.is_available():
187 | return False
188 | if not dist.is_initialized():
189 | return False
190 | return True
191 |
192 |
193 | def get_world_size():
194 | if not is_dist_avail_and_initialized():
195 | return 1
196 | return dist.get_world_size()
197 |
198 |
199 | def get_rank():
200 | if not is_dist_avail_and_initialized():
201 | return 0
202 | return dist.get_rank()
203 |
204 |
205 | def is_main_process():
206 | return get_rank() == 0
207 |
208 |
209 | def save_on_master(*args, **kwargs):
210 | if is_main_process():
211 | torch.save(*args, **kwargs)
212 |
213 |
214 | def init_distributed_mode(args):
215 | if args.dist_on_itp:
216 | args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
217 | args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
218 | args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
219 | args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
220 | os.environ['LOCAL_RANK'] = str(args.gpu)
221 | os.environ['RANK'] = str(args.rank)
222 | os.environ['WORLD_SIZE'] = str(args.world_size)
223 | elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
224 | args.rank = int(os.environ["RANK"])
225 | args.world_size = int(os.environ['WORLD_SIZE'])
226 | args.gpu = int(os.environ['LOCAL_RANK'])
227 | elif 'SLURM_PROCID' in os.environ:
228 | args.rank = int(os.environ['SLURM_PROCID'])
229 | args.gpu = args.rank % torch.cuda.device_count()
230 | else:
231 | print('Not using distributed mode')
232 | setup_for_distributed(is_master=True) # hack
233 | args.distributed = False
234 | return
235 |
236 | args.distributed = True
237 |
238 | torch.cuda.set_device(args.gpu)
239 | args.dist_backend = 'nccl'
240 | print('| distributed init (rank {}): {}, gpu {}'.format(
241 | args.rank, args.dist_url, args.gpu), flush=True)
242 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
243 | world_size=args.world_size, rank=args.rank)
244 | torch.distributed.barrier()
245 | setup_for_distributed(args.rank == 0)
246 |
247 |
248 | class NativeScalerWithGradNormCount:
249 | state_dict_key = "amp_scaler"
250 |
251 | def __init__(self):
252 | self._scaler = torch.cuda.amp.GradScaler()
253 |
254 | def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
255 | self._scaler.scale(loss).backward(create_graph=create_graph)
256 | if update_grad:
257 | if clip_grad is not None:
258 | assert parameters is not None
259 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place
260 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
261 | else:
262 | self._scaler.unscale_(optimizer)
263 | norm = get_grad_norm_(parameters)
264 | self._scaler.step(optimizer)
265 | self._scaler.update()
266 | else:
267 | norm = None
268 | return norm
269 |
270 | def state_dict(self):
271 | return self._scaler.state_dict()
272 |
273 | def load_state_dict(self, state_dict):
274 | self._scaler.load_state_dict(state_dict)
275 |
276 |
277 | def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
278 | if isinstance(parameters, torch.Tensor):
279 | parameters = [parameters]
280 | parameters = [p for p in parameters if p.grad is not None]
281 | norm_type = float(norm_type)
282 | if len(parameters) == 0:
283 | return torch.tensor(0.)
284 | device = parameters[0].grad.device
285 | if norm_type == inf:
286 | total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
287 | else:
288 | total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
289 | return total_norm
290 |
291 |
292 | def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler, last_only=False, model_ema=None):
293 | output_dir = Path(args.output_dir)
294 | epoch_name = str(epoch)
295 | if last_only:
296 | tags = ["checkpoint"]
297 | else:
298 | tags = [f"checkpoint-{epoch_name}", "checkpoint"]
299 | if loss_scaler is not None:
300 | checkpoint_paths = [output_dir / (tag + ".pth") for tag in tags]
301 | for checkpoint_path in checkpoint_paths:
302 | to_save = {
303 | 'model': model_without_ddp.state_dict(),
304 | 'optimizer': optimizer.state_dict(),
305 | 'epoch': epoch,
306 | 'scaler': loss_scaler.state_dict(),
307 | 'args': args,
308 | }
309 | if model_ema is not None:
310 | to_save['model_ema'] = model_ema.module.state_dict()
311 | save_on_master(to_save, checkpoint_path)
312 | else:
313 | client_state = {'epoch': epoch}
314 | for tag in tags:
315 | model.save_checkpoint(save_dir=args.output_dir, tag=tag, client_state=client_state)
316 |
317 |
318 |
319 | def load_model(args, model_without_ddp, model_ema, optimizer, loss_scaler):
320 | if args.resume:
321 | if args.resume.startswith('https'):
322 | checkpoint = torch.hub.load_state_dict_from_url(
323 | args.resume, map_location='cpu', check_hash=True)
324 | elif args.resume == 'allow':
325 | path = os.path.join(args.output_dir, 'checkpoint.pth')
326 | if not os.path.exists(path):
327 | return
328 | checkpoint = torch.load(path, map_location='cpu')
329 | else:
330 | checkpoint = torch.load(args.resume, map_location='cpu')
331 | model_without_ddp.load_state_dict(checkpoint['model'])
332 | if model_ema is not None and "model_ema" in checkpoint:
333 | model_ema_incompatible_keys = model_ema.module.load_state_dict(
334 | checkpoint["model_ema"]
335 | )
336 | print("Loaded model_ema:", model_ema_incompatible_keys)
337 | print("Resume checkpoint %s" % args.resume)
338 | if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval):
339 | optimizer.load_state_dict(checkpoint['optimizer'])
340 | args.start_epoch = checkpoint['epoch'] + 1
341 | if 'scaler' in checkpoint:
342 | loss_scaler.load_state_dict(checkpoint['scaler'])
343 | print("With optim & sched!")
344 |
345 |
346 | def all_reduce_mean(x):
347 | world_size = get_world_size()
348 | if world_size > 1:
349 | x_reduce = torch.tensor(x).cuda()
350 | dist.all_reduce(x_reduce)
351 | x_reduce /= world_size
352 | return x_reduce.item()
353 | else:
354 | return x
--------------------------------------------------------------------------------
/models_dual_vit.py:
--------------------------------------------------------------------------------
1 | # This source code is licensed under the license found in the
2 | # LICENSE file in the root directory of this source tree.
3 | # --------------------------------------------------------
4 | # References:
5 | # MAE: https://github.com/facebookresearch/mae
6 | # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
7 | # DeiT: https://github.com/facebookresearch/deit
8 | # --------------------------------------------------------
9 |
10 | """ Vision Transformer (ViT) in PyTorch
11 |
12 | A PyTorch implement of Vision Transformers as described in
13 | 'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' - https://arxiv.org/abs/2010.11929
14 |
15 | The official jax code is released and available at https://github.com/google-research/vision_transformer
16 |
17 | Status/TODO:
18 | * Models updated to be compatible with official impl. Args added to support backward compat for old PyTorch weights.
19 | * Weights ported from official jax impl for 384x384 base and small models, 16x16 and 32x32 patches.
20 | * Trained (supervised on ImageNet-1k) my custom 'small' patch model to 77.9, 'base' to 79.4 top-1 with this code.
21 | * Hopefully find time and GPUs for SSL or unsupervised pretraining on OpenImages w/ ImageNet fine-tune in future.
22 |
23 | Acknowledgments:
24 | * The paper authors for releasing code and weights, thanks!
25 | * I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out
26 | for some einops/einsum fun
27 | * Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT
28 | * Bert reference code checks against Huggingface Transformers and Tensorflow Bert
29 |
30 | Hacked together by / Copyright 2020 Ross Wightman
31 | """
32 | import torch
33 | import torch.nn as nn
34 | from functools import partial
35 |
36 | from timm.models.vision_transformer import Mlp, PatchEmbed
37 | from timm.models.layers import trunc_normal_
38 |
39 |
40 | class Attention(nn.Module):
41 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, mode=None):
42 | super().__init__()
43 | self.num_heads = num_heads
44 | head_dim = dim // num_heads
45 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
46 | self.scale = qk_scale or head_dim ** -0.5
47 |
48 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
49 | self.proj = nn.Linear(dim, dim)
50 | self.mode = mode
51 |
52 |
53 | def forward(self, x, teacher_act=None, return_act=None):
54 | B, N, C = x.shape
55 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
56 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
57 |
58 | if self.mode == 'copy' and teacher_act is not None:
59 | attn = teacher_act.softmax(dim=-1)
60 | elif self.mode == 'copy_q' and teacher_act is not None:
61 | teacher_q = teacher_act
62 | attn_logits = (teacher_q @ k.transpose(-2, -1)) * self.scale
63 | attn = attn_logits.softmax(dim=-1)
64 | elif self.mode == 'copy_k' and teacher_act is not None:
65 | teacher_k = teacher_act
66 | attn_logits = (q @ teacher_k.transpose(-2, -1)) * self.scale
67 | attn = attn_logits.softmax(dim=-1)
68 | else:
69 | attn_logits = (q @ k.transpose(-2, -1)) * self.scale
70 | attn = attn_logits.softmax(dim=-1)
71 | if self.mode == 'copy_v' and teacher_act is not None:
72 | teacher_v = teacher_act
73 | x = (attn @ teacher_v).transpose(1, 2).reshape(B, N, C)
74 | else:
75 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
76 | x = self.proj(x)
77 | if return_act == 'attention':
78 | return x, attn_logits
79 | elif return_act == 'q':
80 | return x, q
81 | elif return_act == 'k':
82 | return x, k
83 | elif return_act == 'v':
84 | return x, v
85 | return x
86 |
87 |
88 | class Block(nn.Module):
89 |
90 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None,
91 | act_layer=nn.GELU, norm_layer=nn.LayerNorm, mode=None):
92 | super().__init__()
93 | self.norm1 = norm_layer(dim)
94 | self.attn = Attention(
95 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, mode=mode)
96 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
97 | self.norm2 = norm_layer(dim)
98 | mlp_hidden_dim = int(dim * mlp_ratio)
99 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=0)
100 |
101 | def forward(self, x, drop_masks=None, teacher_act=None, return_act=None):
102 | if drop_masks is None:
103 | drop_masks = (1, 1)
104 | else:
105 | shape = (len(x),) + (1,) * (x.ndim - 1)
106 | drop_masks = (drop_masks[0].view(*shape), drop_masks[1].view(*shape))
107 | attn_outputs = self.attn(self.norm1(x), teacher_act=teacher_act, return_act=return_act)
108 | if return_act is not None:
109 | attn_result, act = attn_outputs
110 | else:
111 | attn_result = attn_outputs
112 |
113 | x = x + drop_masks[0] * attn_result
114 | x = x + drop_masks[1] * self.mlp(self.norm2(x))
115 | if return_act is not None:
116 | return x, act
117 | return x
118 |
119 |
120 | class VisionTransformer(nn.Module):
121 | """ Vision Transformer with support for patch or hybrid CNN input stage
122 | """
123 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
124 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None,
125 | norm_layer=nn.LayerNorm, mode=None, global_pool=False,):
126 | super().__init__()
127 | self.num_classes = num_classes
128 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
129 | self.mode = mode
130 |
131 | self.patch_embed = PatchEmbed(
132 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
133 | num_patches = self.patch_embed.num_patches
134 |
135 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
136 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
137 | # self.pos_drop = nn.Dropout(p=drop_rate)
138 |
139 | # dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
140 | self.blocks = nn.ModuleList([
141 | Block(
142 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
143 | qk_scale=qk_scale, norm_layer=norm_layer, mode=mode)
144 | for i in range(depth)])
145 | self.norm = norm_layer(embed_dim)
146 |
147 | # NOTE as per official impl, we could have a pre-logits representation dense layer + tanh here
148 | #self.repr = nn.Linear(embed_dim, representation_size)
149 | #self.repr_act = nn.Tanh()
150 |
151 | # Classifier head
152 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
153 |
154 | trunc_normal_(self.pos_embed, std=.02)
155 | trunc_normal_(self.cls_token, std=.02)
156 | self.apply(self._init_weights)
157 |
158 | self.global_pool = global_pool
159 | if self.global_pool:
160 | self.fc_norm = norm_layer(embed_dim)
161 | del self.norm # remove the original norm
162 |
163 | def _init_weights(self, m):
164 | if isinstance(m, nn.Linear):
165 | trunc_normal_(m.weight, std=.02)
166 | if isinstance(m, nn.Linear) and m.bias is not None:
167 | nn.init.constant_(m.bias, 0)
168 | elif isinstance(m, nn.LayerNorm):
169 | nn.init.constant_(m.bias, 0)
170 | nn.init.constant_(m.weight, 1.0)
171 |
172 | @torch.jit.ignore
173 | def no_weight_decay(self):
174 | return {'pos_embed', 'cls_token'}
175 |
176 | def get_classifier(self):
177 | return self.head
178 |
179 | def reset_classifier(self, num_classes, global_pool=''):
180 | self.num_classes = num_classes
181 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
182 |
183 | def forward_features(self, x, drop_masks=None, teacher_act=None, return_act=None):
184 | if teacher_act is not None:
185 | assert 'copy' in self.mode
186 | # pad if we copy fewer blocks
187 | if len(teacher_act) < len(self.blocks):
188 | teacher_act = teacher_act + [None] * (len(self.blocks) - len(teacher_act))
189 | else:
190 | teacher_act = [None] * len(self.blocks)
191 |
192 | B = x.shape[0]
193 | x = self.patch_embed(x)
194 |
195 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
196 | x = torch.cat((cls_tokens, x), dim=1)
197 | x = x + self.pos_embed
198 |
199 | attns = []
200 | for i, blk in enumerate(self.blocks):
201 | if return_act is not None:
202 | x, act = blk(x, return_act=return_act, drop_masks=drop_masks[i])
203 | attns.append(act)
204 | else:
205 | x = blk(x, teacher_act=teacher_act[i], drop_masks=drop_masks[i])
206 |
207 | if self.global_pool:
208 | x = x[:, 1:, :].mean(dim=1) # global pool without cls token
209 | outcome = self.fc_norm(x)
210 | else:
211 | x = self.norm(x)
212 | outcome = x[:, 0]
213 |
214 | if return_act is not None:
215 | return outcome, attns
216 | return outcome
217 |
218 |
219 | def forward(self, x, drop_masks=None, teacher_act=None, return_act=None):
220 | x = self.forward_features(x, drop_masks=drop_masks, teacher_act=teacher_act, return_act=return_act)
221 | if return_act is not None:
222 | x, act = x
223 | return self.head(x), act
224 | else:
225 | return self.head(x)
226 |
227 |
228 | class DualVisionTransformer(nn.Module):
229 | """
230 | Vision Transformer with support for global average pooling
231 | Has two streams (one is a teacher, the other is a student)
232 | """
233 | def __init__(self, mode='distill', drop_path_rate=0,
234 | teacher_kwargs=None, student_kwargs=None, end_layer=-3):
235 | super().__init__()
236 | assert mode in {'copy', 'copy_q', 'copy_k', 'copy_v', 'distill', 'distill_q', 'distill_k', 'distill_v'}
237 | self.mode = mode
238 | self.drop_path_rate = drop_path_rate
239 | self.teacher_depth = teacher_kwargs['depth']
240 | self.student_depth = student_kwargs['depth']
241 | self.teacher = VisionTransformer(mode='teacher', **teacher_kwargs)
242 | self.student = VisionTransformer(mode=mode, **student_kwargs)
243 | self.dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, self.teacher_depth)]
244 | self.end_layer = end_layer # for distillation
245 |
246 | def forward(self, x):
247 | drop_masks = self.get_drop_path_mask(len(x), x.dtype, x.device) # depends on self.training
248 | # attention activation to get from the teacher
249 | if self.mode in {'copy', 'distill'}:
250 | return_act = 'attention'
251 | elif self.mode in {'copy_q', 'distill_q'}:
252 | return_act = 'q'
253 | elif self.mode in {'copy_k', 'distill_k'}:
254 | return_act = 'k'
255 | elif self.mode in {'copy_v', 'distill_v'}:
256 | return_act = 'v'
257 | else:
258 | raise NotImplementedError
259 | with torch.no_grad():
260 | if self.training or 'copy' in self.mode:
261 | _, teacher_act = self.teacher.forward_features(x,
262 | drop_masks=drop_masks,
263 | return_act=return_act)
264 | teacher_act = [act.detach() for act in teacher_act]
265 |
266 | # forward student
267 | if 'copy' in self.mode:
268 | # teacher_act to copy
269 | act_to_copy = teacher_act[:self.teacher_depth + self.end_layer]
270 | return self.student(x, drop_masks=drop_masks, teacher_act=act_to_copy)
271 | elif 'distill' in self.mode and self.training:
272 | student_out, student_act = self.student(x, drop_masks=drop_masks, return_act=return_act)
273 | distill_loss = 0
274 |
275 | if self.mode == 'distill':
276 | def distill_loss_fn(teacher_map, student_map):
277 | return - (teacher_map.softmax(dim=-1) * torch.log_softmax(student_map, dim=-1)).sum(dim=-1).mean()
278 | else:
279 | def distill_loss_fn(teacher_map, student_map):
280 | return torch.nn.functional.mse_loss(teacher_map, student_map)
281 |
282 | for i in range(0, self.teacher_depth + self.end_layer):
283 | distill_loss += distill_loss_fn(teacher_act[i], student_act[i])
284 | return student_out, distill_loss
285 | else:
286 | return self.student(x, drop_masks=drop_masks)
287 |
288 |
289 | def get_drop_path_mask(self, batch_size, dtype, device):
290 | if not self.training:
291 | return [None] * self.teacher_depth
292 | drop_masks = []
293 | shape = (batch_size,)
294 | for i in range(self.teacher_depth):
295 | curr_layer_masks = []
296 | for _ in range(2):
297 | keep_prob = 1 - self.dpr[i]
298 | random_tensor = keep_prob + torch.rand(shape, dtype=dtype, device=device)
299 | random_tensor.floor_() # binarize
300 | output = random_tensor / keep_prob
301 | curr_layer_masks.append(output)
302 | drop_masks.append(curr_layer_masks)
303 | return drop_masks
304 |
305 | def no_weight_decay(self):
306 | return {'student.' + k for k in self.student.no_weight_decay()}
307 |
308 |
309 | def dual_vit_base_patch16(mode='distill', drop_path_rate=0, end_layer=-3, **kwargs):
310 | kwargs = dict(
311 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
312 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
313 |
314 | model = DualVisionTransformer(mode=mode, drop_path_rate=drop_path_rate, end_layer=end_layer,
315 | teacher_kwargs=kwargs, student_kwargs=kwargs)
316 | return model
317 |
318 |
319 | def dual_vit_large_patch16(mode='distill', drop_path_rate=0, end_layer=-3, **kwargs):
320 | kwargs = dict(
321 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
322 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
323 |
324 | model = DualVisionTransformer(mode=mode, drop_path_rate=drop_path_rate, end_layer=end_layer,
325 | teacher_kwargs=kwargs, student_kwargs=kwargs)
326 | return model
327 |
--------------------------------------------------------------------------------
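
A minimal usage sketch for `DualVisionTransformer` in the two regimes the README trains (classifier size, drop path, and the 3.0 weight mirror `--atd_weight` and are illustrative):

```python
# Illustrative sketch: attention distillation vs. attention copy with the dual model.
import torch

import models_dual_vit

x = torch.randn(2, 3, 224, 224)
labels = torch.randint(0, 1000, (2,))

# distill: student trains with CE plus a weighted attention-map distillation loss
model = models_dual_vit.dual_vit_base_patch16(
    mode='distill', drop_path_rate=0.1, end_layer=-3, num_classes=1000)
model.train()
logits, distill_loss = model(x)
loss = torch.nn.functional.cross_entropy(logits, labels) + 3.0 * distill_loss

# copy: student reuses the teacher's attention maps (teacher runs under no_grad); only CE is needed
copy_model = models_dual_vit.dual_vit_base_patch16(mode='copy', num_classes=1000)
copy_model.train()
logits = copy_model(x)
loss = torch.nn.functional.cross_entropy(logits, labels)
```
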
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Attribution-NonCommercial 4.0 International
3 |
4 | =======================================================================
5 |
6 | Creative Commons Corporation ("Creative Commons") is not a law firm and
7 | does not provide legal services or legal advice. Distribution of
8 | Creative Commons public licenses does not create a lawyer-client or
9 | other relationship. Creative Commons makes its licenses and related
10 | information available on an "as-is" basis. Creative Commons gives no
11 | warranties regarding its licenses, any material licensed under their
12 | terms and conditions, or any related information. Creative Commons
13 | disclaims all liability for damages resulting from their use to the
14 | fullest extent possible.
15 |
16 | Using Creative Commons Public Licenses
17 |
18 | Creative Commons public licenses provide a standard set of terms and
19 | conditions that creators and other rights holders may use to share
20 | original works of authorship and other material subject to copyright
21 | and certain other rights specified in the public license below. The
22 | following considerations are for informational purposes only, are not
23 | exhaustive, and do not form part of our licenses.
24 |
25 | Considerations for licensors: Our public licenses are
26 | intended for use by those authorized to give the public
27 | permission to use material in ways otherwise restricted by
28 | copyright and certain other rights. Our licenses are
29 | irrevocable. Licensors should read and understand the terms
30 | and conditions of the license they choose before applying it.
31 | Licensors should also secure all rights necessary before
32 | applying our licenses so that the public can reuse the
33 | material as expected. Licensors should clearly mark any
34 | material not subject to the license. This includes other CC-
35 | licensed material, or material used under an exception or
36 | limitation to copyright. More considerations for licensors:
37 | wiki.creativecommons.org/Considerations_for_licensors
38 |
39 | Considerations for the public: By using one of our public
40 | licenses, a licensor grants the public permission to use the
41 | licensed material under specified terms and conditions. If
42 | the licensor's permission is not necessary for any reason--for
43 | example, because of any applicable exception or limitation to
44 | copyright--then that use is not regulated by the license. Our
45 | licenses grant only permissions under copyright and certain
46 | other rights that a licensor has authority to grant. Use of
47 | the licensed material may still be restricted for other
48 | reasons, including because others have copyright or other
49 | rights in the material. A licensor may make special requests,
50 | such as asking that all changes be marked or described.
51 | Although not required by our licenses, you are encouraged to
52 | respect those requests where reasonable. More_considerations
53 | for the public:
54 | wiki.creativecommons.org/Considerations_for_licensees
55 |
56 | =======================================================================
57 |
58 | Creative Commons Attribution-NonCommercial 4.0 International Public
59 | License
60 |
61 | By exercising the Licensed Rights (defined below), You accept and agree
62 | to be bound by the terms and conditions of this Creative Commons
63 | Attribution-NonCommercial 4.0 International Public License ("Public
64 | License"). To the extent this Public License may be interpreted as a
65 | contract, You are granted the Licensed Rights in consideration of Your
66 | acceptance of these terms and conditions, and the Licensor grants You
67 | such rights in consideration of benefits the Licensor receives from
68 | making the Licensed Material available under these terms and
69 | conditions.
70 |
71 | Section 1 -- Definitions.
72 |
73 | a. Adapted Material means material subject to Copyright and Similar
74 | Rights that is derived from or based upon the Licensed Material
75 | and in which the Licensed Material is translated, altered,
76 | arranged, transformed, or otherwise modified in a manner requiring
77 | permission under the Copyright and Similar Rights held by the
78 | Licensor. For purposes of this Public License, where the Licensed
79 | Material is a musical work, performance, or sound recording,
80 | Adapted Material is always produced where the Licensed Material is
81 | synched in timed relation with a moving image.
82 |
83 | b. Adapter's License means the license You apply to Your Copyright
84 | and Similar Rights in Your contributions to Adapted Material in
85 | accordance with the terms and conditions of this Public License.
86 |
87 | c. Copyright and Similar Rights means copyright and/or similar rights
88 | closely related to copyright including, without limitation,
89 | performance, broadcast, sound recording, and Sui Generis Database
90 | Rights, without regard to how the rights are labeled or
91 | categorized. For purposes of this Public License, the rights
92 | specified in Section 2(b)(1)-(2) are not Copyright and Similar
93 | Rights.
94 | d. Effective Technological Measures means those measures that, in the
95 | absence of proper authority, may not be circumvented under laws
96 | fulfilling obligations under Article 11 of the WIPO Copyright
97 | Treaty adopted on December 20, 1996, and/or similar international
98 | agreements.
99 |
100 | e. Exceptions and Limitations means fair use, fair dealing, and/or
101 | any other exception or limitation to Copyright and Similar Rights
102 | that applies to Your use of the Licensed Material.
103 |
104 | f. Licensed Material means the artistic or literary work, database,
105 | or other material to which the Licensor applied this Public
106 | License.
107 |
108 | g. Licensed Rights means the rights granted to You subject to the
109 | terms and conditions of this Public License, which are limited to
110 | all Copyright and Similar Rights that apply to Your use of the
111 | Licensed Material and that the Licensor has authority to license.
112 |
113 | h. Licensor means the individual(s) or entity(ies) granting rights
114 | under this Public License.
115 |
116 | i. NonCommercial means not primarily intended for or directed towards
117 | commercial advantage or monetary compensation. For purposes of
118 | this Public License, the exchange of the Licensed Material for
119 | other material subject to Copyright and Similar Rights by digital
120 | file-sharing or similar means is NonCommercial provided there is
121 | no payment of monetary compensation in connection with the
122 | exchange.
123 |
124 | j. Share means to provide material to the public by any means or
125 | process that requires permission under the Licensed Rights, such
126 | as reproduction, public display, public performance, distribution,
127 | dissemination, communication, or importation, and to make material
128 | available to the public including in ways that members of the
129 | public may access the material from a place and at a time
130 | individually chosen by them.
131 |
132 | k. Sui Generis Database Rights means rights other than copyright
133 | resulting from Directive 96/9/EC of the European Parliament and of
134 | the Council of 11 March 1996 on the legal protection of databases,
135 | as amended and/or succeeded, as well as other essentially
136 | equivalent rights anywhere in the world.
137 |
138 | l. You means the individual or entity exercising the Licensed Rights
139 | under this Public License. Your has a corresponding meaning.
140 |
141 | Section 2 -- Scope.
142 |
143 | a. License grant.
144 |
145 | 1. Subject to the terms and conditions of this Public License,
146 | the Licensor hereby grants You a worldwide, royalty-free,
147 | non-sublicensable, non-exclusive, irrevocable license to
148 | exercise the Licensed Rights in the Licensed Material to:
149 |
150 | a. reproduce and Share the Licensed Material, in whole or
151 | in part, for NonCommercial purposes only; and
152 |
153 | b. produce, reproduce, and Share Adapted Material for
154 | NonCommercial purposes only.
155 |
156 | 2. Exceptions and Limitations. For the avoidance of doubt, where
157 | Exceptions and Limitations apply to Your use, this Public
158 | License does not apply, and You do not need to comply with
159 | its terms and conditions.
160 |
161 | 3. Term. The term of this Public License is specified in Section
162 | 6(a).
163 |
164 | 4. Media and formats; technical modifications allowed. The
165 | Licensor authorizes You to exercise the Licensed Rights in
166 | all media and formats whether now known or hereafter created,
167 | and to make technical modifications necessary to do so. The
168 | Licensor waives and/or agrees not to assert any right or
169 | authority to forbid You from making technical modifications
170 | necessary to exercise the Licensed Rights, including
171 | technical modifications necessary to circumvent Effective
172 | Technological Measures. For purposes of this Public License,
173 | simply making modifications authorized by this Section 2(a)
174 | (4) never produces Adapted Material.
175 |
176 | 5. Downstream recipients.
177 |
178 | a. Offer from the Licensor -- Licensed Material. Every
179 | recipient of the Licensed Material automatically
180 | receives an offer from the Licensor to exercise the
181 | Licensed Rights under the terms and conditions of this
182 | Public License.
183 |
184 | b. No downstream restrictions. You may not offer or impose
185 | any additional or different terms or conditions on, or
186 | apply any Effective Technological Measures to, the
187 | Licensed Material if doing so restricts exercise of the
188 | Licensed Rights by any recipient of the Licensed
189 | Material.
190 |
191 | 6. No endorsement. Nothing in this Public License constitutes or
192 | may be construed as permission to assert or imply that You
193 | are, or that Your use of the Licensed Material is, connected
194 | with, or sponsored, endorsed, or granted official status by,
195 | the Licensor or others designated to receive attribution as
196 | provided in Section 3(a)(1)(A)(i).
197 |
198 | b. Other rights.
199 |
200 | 1. Moral rights, such as the right of integrity, are not
201 | licensed under this Public License, nor are publicity,
202 | privacy, and/or other similar personality rights; however, to
203 | the extent possible, the Licensor waives and/or agrees not to
204 | assert any such rights held by the Licensor to the limited
205 | extent necessary to allow You to exercise the Licensed
206 | Rights, but not otherwise.
207 |
208 | 2. Patent and trademark rights are not licensed under this
209 | Public License.
210 |
211 | 3. To the extent possible, the Licensor waives any right to
212 | collect royalties from You for the exercise of the Licensed
213 | Rights, whether directly or through a collecting society
214 | under any voluntary or waivable statutory or compulsory
215 | licensing scheme. In all other cases the Licensor expressly
216 | reserves any right to collect such royalties, including when
217 | the Licensed Material is used other than for NonCommercial
218 | purposes.
219 |
220 | Section 3 -- License Conditions.
221 |
222 | Your exercise of the Licensed Rights is expressly made subject to the
223 | following conditions.
224 |
225 | a. Attribution.
226 |
227 | 1. If You Share the Licensed Material (including in modified
228 | form), You must:
229 |
230 | a. retain the following if it is supplied by the Licensor
231 | with the Licensed Material:
232 |
233 | i. identification of the creator(s) of the Licensed
234 | Material and any others designated to receive
235 | attribution, in any reasonable manner requested by
236 | the Licensor (including by pseudonym if
237 | designated);
238 |
239 | ii. a copyright notice;
240 |
241 | iii. a notice that refers to this Public License;
242 |
243 | iv. a notice that refers to the disclaimer of
244 | warranties;
245 |
246 | v. a URI or hyperlink to the Licensed Material to the
247 | extent reasonably practicable;
248 |
249 | b. indicate if You modified the Licensed Material and
250 | retain an indication of any previous modifications; and
251 |
252 | c. indicate the Licensed Material is licensed under this
253 | Public License, and include the text of, or the URI or
254 | hyperlink to, this Public License.
255 |
256 | 2. You may satisfy the conditions in Section 3(a)(1) in any
257 | reasonable manner based on the medium, means, and context in
258 | which You Share the Licensed Material. For example, it may be
259 | reasonable to satisfy the conditions by providing a URI or
260 | hyperlink to a resource that includes the required
261 | information.
262 |
263 | 3. If requested by the Licensor, You must remove any of the
264 | information required by Section 3(a)(1)(A) to the extent
265 | reasonably practicable.
266 |
267 | 4. If You Share Adapted Material You produce, the Adapter's
268 | License You apply must not prevent recipients of the Adapted
269 | Material from complying with this Public License.
270 |
271 | Section 4 -- Sui Generis Database Rights.
272 |
273 | Where the Licensed Rights include Sui Generis Database Rights that
274 | apply to Your use of the Licensed Material:
275 |
276 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right
277 | to extract, reuse, reproduce, and Share all or a substantial
278 | portion of the contents of the database for NonCommercial purposes
279 | only;
280 |
281 | b. if You include all or a substantial portion of the database
282 | contents in a database in which You have Sui Generis Database
283 | Rights, then the database in which You have Sui Generis Database
284 | Rights (but not its individual contents) is Adapted Material; and
285 |
286 | c. You must comply with the conditions in Section 3(a) if You Share
287 | all or a substantial portion of the contents of the database.
288 |
289 | For the avoidance of doubt, this Section 4 supplements and does not
290 | replace Your obligations under this Public License where the Licensed
291 | Rights include other Copyright and Similar Rights.
292 |
293 | Section 5 -- Disclaimer of Warranties and Limitation of Liability.
294 |
295 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
296 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
297 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
298 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
299 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
300 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
301 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
302 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
303 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
304 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
305 |
306 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
307 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
308 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
309 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
310 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
311 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
312 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
313 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
314 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
315 |
316 | c. The disclaimer of warranties and limitation of liability provided
317 | above shall be interpreted in a manner that, to the extent
318 | possible, most closely approximates an absolute disclaimer and
319 | waiver of all liability.
320 |
321 | Section 6 -- Term and Termination.
322 |
323 | a. This Public License applies for the term of the Copyright and
324 | Similar Rights licensed here. However, if You fail to comply with
325 | this Public License, then Your rights under this Public License
326 | terminate automatically.
327 |
328 | b. Where Your right to use the Licensed Material has terminated under
329 | Section 6(a), it reinstates:
330 |
331 | 1. automatically as of the date the violation is cured, provided
332 | it is cured within 30 days of Your discovery of the
333 | violation; or
334 |
335 | 2. upon express reinstatement by the Licensor.
336 |
337 | For the avoidance of doubt, this Section 6(b) does not affect any
338 | right the Licensor may have to seek remedies for Your violations
339 | of this Public License.
340 |
341 | c. For the avoidance of doubt, the Licensor may also offer the
342 | Licensed Material under separate terms or conditions or stop
343 | distributing the Licensed Material at any time; however, doing so
344 | will not terminate this Public License.
345 |
346 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
347 | License.
348 |
349 | Section 7 -- Other Terms and Conditions.
350 |
351 | a. The Licensor shall not be bound by any additional or different
352 | terms or conditions communicated by You unless expressly agreed.
353 |
354 | b. Any arrangements, understandings, or agreements regarding the
355 | Licensed Material not stated herein are separate from and
356 | independent of the terms and conditions of this Public License.
357 |
358 | Section 8 -- Interpretation.
359 |
360 | a. For the avoidance of doubt, this Public License does not, and
361 | shall not be interpreted to, reduce, limit, restrict, or impose
362 | conditions on any use of the Licensed Material that could lawfully
363 | be made without permission under this Public License.
364 |
365 | b. To the extent possible, if any provision of this Public License is
366 | deemed unenforceable, it shall be automatically reformed to the
367 | minimum extent necessary to make it enforceable. If the provision
368 | cannot be reformed, it shall be severed from this Public License
369 | without affecting the enforceability of the remaining terms and
370 | conditions.
371 |
372 | c. No term or condition of this Public License will be waived and no
373 | failure to comply consented to unless expressly agreed to by the
374 | Licensor.
375 |
376 | d. Nothing in this Public License constitutes or may be interpreted
377 | as a limitation upon, or waiver of, any privileges and immunities
378 | that apply to the Licensor or You, including from the legal
379 | processes of any jurisdiction or authority.
380 |
381 | =======================================================================
382 |
383 | Creative Commons is not a party to its public
384 | licenses. Notwithstanding, Creative Commons may elect to apply one of
385 | its public licenses to material it publishes and in those instances
386 | will be considered the “Licensor.” The text of the Creative Commons
387 | public licenses is dedicated to the public domain under the CC0 Public
388 | Domain Dedication. Except for the limited purpose of indicating that
389 | material is shared under a Creative Commons public license or as
390 | otherwise permitted by the Creative Commons policies published at
391 | creativecommons.org/policies, Creative Commons does not authorize the
392 | use of the trademark "Creative Commons" or any other trademark or logo
393 | of Creative Commons without its prior written consent including,
394 | without limitation, in connection with any unauthorized modifications
395 | to any of its public licenses or any other arrangements,
396 | understandings, or agreements concerning use of licensed material. For
397 | the avoidance of doubt, this paragraph does not form part of the
398 | public licenses.
399 |
400 | Creative Commons may be contacted at creativecommons.org.
--------------------------------------------------------------------------------
/main_finetune.py:
--------------------------------------------------------------------------------
1 | # This source code is licensed under the license found in the
2 | # LICENSE file in the root directory of this source tree.
3 | # --------------------------------------------------------
4 | # References:
5 | # MAE: https://github.com/facebookresearch/mae
6 | # DeiT: https://github.com/facebookresearch/deit
7 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit
8 | # --------------------------------------------------------
9 |
10 | import argparse
11 | import datetime
12 | import json
13 | import numpy as np
14 | import os
15 | import time
16 | from pathlib import Path
17 |
18 | import torch
19 | import torch.backends.cudnn as cudnn
20 | from torch.utils.tensorboard import SummaryWriter
21 |
22 | from timm.models.layers import trunc_normal_
23 | from timm.data.mixup import Mixup
24 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
25 |
26 | import util.lr_decay as lrd
27 | import util.misc as misc
28 | from util.datasets import build_dataset
29 | from util.pos_embed import interpolate_pos_embed
30 | from util.misc import NativeScalerWithGradNormCount as NativeScaler
31 | from util.model_ema import ModelEmaV2
32 |
33 | import models_vit
34 | import models_dual_vit
35 |
36 | from engine_finetune import train_one_epoch, evaluate
37 |
38 |
39 | def get_args_parser():
40 | parser = argparse.ArgumentParser('MAE fine-tuning for image classification', add_help=False)
41 | parser.add_argument('--batch_size', default=64, type=int,
42 |                         help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus)')
43 | parser.add_argument('--epochs', default=50, type=int)
44 | parser.add_argument('--accum_iter', default=1, type=int,
45 | help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')
46 |
47 | # Model parameters
48 | parser.add_argument('--model', default='vit_large_patch16', type=str, metavar='MODEL',
49 | help='Name of model to train')
50 |
51 | parser.add_argument('--input_size', default=224, type=int,
52 | help='images input size')
53 |
54 | parser.add_argument('--drop_path', type=float, default=0.1, metavar='PCT',
55 | help='Drop path rate (default: 0.1)')
56 |
57 | # Optimizer parameters
58 | parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM',
59 | help='Clip gradient norm (default: None, no clipping)')
60 | parser.add_argument('--weight_decay', type=float, default=0.05,
61 | help='weight decay (default: 0.05)')
62 |
63 | parser.add_argument('--lr', type=float, default=None, metavar='LR',
64 | help='learning rate (absolute lr)')
65 | parser.add_argument('--blr', type=float, default=1e-3, metavar='LR',
66 | help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
67 | parser.add_argument('--beta2', type=float, default=0.999, metavar='BETA2',
68 | help='beta_2 for optimizer')
69 | parser.add_argument('--layer_decay', type=float, default=0.75,
70 | help='layer-wise lr decay from ELECTRA/BEiT')
71 |
72 | parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR',
73 | help='lower lr bound for cyclic schedulers that hit 0')
74 |
75 | parser.add_argument('--warmup_epochs', type=int, default=5, metavar='N',
76 | help='epochs to warmup LR')
77 |
78 | # Augmentation parameters
79 | parser.add_argument('--color_jitter', type=float, default=None, metavar='PCT',
80 | help='Color jitter factor (enabled only when not using Auto/RandAug)')
81 | parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
82 |                         help='Use AutoAugment policy. "v0" or "original" (default: rand-m9-mstd0.5-inc1)')
83 | parser.add_argument('--smoothing', type=float, default=0.1,
84 | help='Label smoothing (default: 0.1)')
85 |
86 | # * Random Erase params
87 | parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
88 | help='Random erase prob (default: 0.25)')
89 | parser.add_argument('--remode', type=str, default='pixel',
90 | help='Random erase mode (default: "pixel")')
91 | parser.add_argument('--recount', type=int, default=1,
92 | help='Random erase count (default: 1)')
93 | parser.add_argument('--resplit', action='store_true', default=False,
94 | help='Do not random erase first (clean) augmentation split')
95 |
96 | # * Mixup params
97 | parser.add_argument('--mixup', type=float, default=0,
98 | help='mixup alpha, mixup enabled if > 0.')
99 | parser.add_argument('--cutmix', type=float, default=0,
100 | help='cutmix alpha, cutmix enabled if > 0.')
101 | parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None,
102 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
103 | parser.add_argument('--mixup_prob', type=float, default=1.0,
104 | help='Probability of performing mixup or cutmix when either/both is enabled')
105 | parser.add_argument('--mixup_switch_prob', type=float, default=0.5,
106 | help='Probability of switching to cutmix when both mixup and cutmix enabled')
107 | parser.add_argument('--mixup_mode', type=str, default='batch',
108 | help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
109 |
110 | # * Finetuning params
111 | parser.add_argument('--finetune', default='',
112 | help='finetune from checkpoint')
113 | parser.add_argument('--use_teacher_ema', action='store_true', help="Use the EMA teacher model.")
114 | parser.add_argument('--global_pool', action='store_true')
115 | parser.set_defaults(global_pool=True)
116 | parser.add_argument('--cls_token', action='store_false', dest='global_pool',
117 | help='Use class token instead of global pool for classification')
118 |
119 | # Dataset parameters
120 | parser.add_argument('--dataset_name', default='imagenet', type=str, metavar='DATASET')
121 | parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str,
122 | help='dataset path')
123 | parser.add_argument('--nb_classes', default=1000, type=int,
124 | help='number of the classification types')
125 |
126 | parser.add_argument('--output_dir', default='./output_dir',
127 | help='path where to save, empty for no saving')
128 | parser.add_argument('--log_dir', default='./output_dir',
129 | help='path where to tensorboard log')
130 | parser.add_argument('--device', default='cuda',
131 | help='device to use for training / testing')
132 | parser.add_argument('--seed', default=0, type=int)
133 | parser.add_argument('--resume', default='',
134 | help='resume from checkpoint')
135 |
136 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
137 | help='start epoch')
138 | parser.add_argument('--eval', action='store_true',
139 | help='Perform evaluation only')
140 | parser.add_argument('--dist_eval', action='store_true', default=False,
141 |                         help='Enable distributed evaluation (recommended during training for faster monitoring)')
142 | parser.add_argument('--num_workers', default=10, type=int)
143 | parser.add_argument('--pin_mem', action='store_true',
144 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
145 | parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
146 | parser.set_defaults(pin_mem=True)
147 |
148 | # distributed training parameters
149 | parser.add_argument('--world_size', default=1, type=int,
150 | help='number of distributed processes')
151 | parser.add_argument('--local_rank', default=-1, type=int)
152 | parser.add_argument('--dist_on_itp', action='store_true')
153 | parser.add_argument('--dist_url', default='env://',
154 | help='url used to set up distributed training')
155 |
156 | parser.add_argument('--ema', default=None, type=float, metavar='ALPHA',
157 | help='ema decay (default: None, no ema)')
158 |
159 | parser.add_argument('--train_qkv', action='store_true')
160 | # for attn transfer
161 | parser.add_argument('--mode', default=None, type=str,
162 | choices=[None, 'copy', 'distill', 'copy_q', 'copy_k', 'copy_v', 'distill_q', 'distill_k', 'distill_v'],
163 | help='mode for attention transfer')
164 |     parser.add_argument('--end_layer', default=0, type=int,
                        help='non-positive offset from the last teacher block; activations from the first (depth + end_layer) blocks are transferred')
165 |     parser.add_argument('--atd_weight', default=3, type=float,
                        help='weight on the attention-transfer distillation loss')
166 |
167 |
168 | return parser
169 |
170 |
171 | def main(args):
172 | misc.init_distributed_mode(args)
173 |
174 | print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
175 | print("{}".format(args).replace(', ', ',\n'))
176 |
177 | device = torch.device(args.device)
178 |
179 | # fix the seed for reproducibility
180 | seed = args.seed + misc.get_rank()
181 | torch.manual_seed(seed)
182 | np.random.seed(seed)
183 |
184 | cudnn.benchmark = True
185 |
186 | dataset_train = build_dataset(is_train=True, args=args)
187 | dataset_val = build_dataset(is_train=False, args=args)
188 |
189 | if True: # args.distributed:
190 | num_tasks = misc.get_world_size()
191 | global_rank = misc.get_rank()
192 | sampler_train = torch.utils.data.DistributedSampler(
193 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
194 | )
195 | print("Sampler_train = %s" % str(sampler_train))
196 | if args.dist_eval:
197 | if len(dataset_val) % num_tasks != 0:
198 | print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
199 | 'This will slightly alter validation results as extra duplicate entries are added to achieve '
200 | 'equal num of samples per-process.')
201 | sampler_val = torch.utils.data.DistributedSampler(
202 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=True) # shuffle=True to reduce monitor bias
203 | else:
204 | sampler_val = torch.utils.data.SequentialSampler(dataset_val)
205 | else:
206 | sampler_train = torch.utils.data.RandomSampler(dataset_train)
207 | sampler_val = torch.utils.data.SequentialSampler(dataset_val)
208 |
209 | if global_rank == 0 and args.log_dir is not None and not args.eval:
210 | os.makedirs(args.log_dir, exist_ok=True)
211 | log_writer = SummaryWriter(log_dir=args.log_dir)
212 | else:
213 | log_writer = None
214 |
215 | data_loader_train = torch.utils.data.DataLoader(
216 | dataset_train, sampler=sampler_train,
217 | batch_size=args.batch_size,
218 | num_workers=args.num_workers,
219 | pin_memory=args.pin_mem,
220 | drop_last=True,
221 | )
222 |
223 | data_loader_val = torch.utils.data.DataLoader(
224 | dataset_val, sampler=sampler_val,
225 | batch_size=args.batch_size,
226 | num_workers=args.num_workers,
227 | pin_memory=args.pin_mem,
228 | drop_last=False
229 | )
230 |
231 | mixup_fn = None
232 | mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
233 | if mixup_active:
234 | print("Mixup is activated!")
235 | mixup_fn = Mixup(
236 | mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
237 | prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
238 | label_smoothing=args.smoothing, num_classes=args.nb_classes)
239 |
240 | if args.model.startswith("dual_vit"):
241 | model = models_dual_vit.__dict__[args.model](
242 | mode=args.mode,
243 | num_classes=args.nb_classes,
244 | drop_path_rate=args.drop_path,
245 | global_pool=args.global_pool,
246 | )
247 | else:
248 | model = models_vit.__dict__[args.model](
249 | num_classes=args.nb_classes,
250 | drop_path_rate=args.drop_path,
251 | global_pool=args.global_pool,
252 | )
253 |
254 | if args.finetune and not args.eval:
255 | checkpoint = torch.load(args.finetune, map_location='cpu')
256 |
257 | print("Load pre-trained checkpoint from: %s" % args.finetune)
258 | checkpoint_model = checkpoint['model_ema' if args.use_teacher_ema else 'model']
259 | state_dict = model.state_dict()
260 | for k in ['head.weight', 'head.bias']:
261 | if k in checkpoint_model and (k not in state_dict or checkpoint_model[k].shape != state_dict[k].shape):
262 | print(f"Removing key {k} from pretrained checkpoint")
263 | del checkpoint_model[k]
264 |
265 | # interpolate position embedding
266 | model_to_interp = model.teacher if 'dual_vit' in args.model else model
267 | interpolate_pos_embed(model_to_interp, checkpoint_model)
268 |
269 | # load pre-trained model
270 | if 'dual_vit' in args.model:
271 | new_checkpoint = {'teacher.' + k: v for k, v in checkpoint_model.items()}
272 | checkpoint_model = new_checkpoint
273 | msg = model.load_state_dict(checkpoint_model, strict=False)
274 | print(msg)
275 |
276 | missing_keys = set(msg.missing_keys)
277 | if 'dual_vit' in args.model:
278 |             # drop 'student.*' entries and strip the 'teacher.' prefix from the remaining missing keys
279 | for k in msg.missing_keys:
280 | if k.startswith('student.'):
281 | missing_keys.remove(k)
282 | if k.startswith('teacher.'):
283 | missing_keys.add(k[8:])
284 | missing_keys.remove(k)
285 | assert {'head.weight', 'head.bias'} <= set(missing_keys)
286 | if args.global_pool:
287 | assert set(missing_keys) <= {'head.weight', 'head.bias', 'fc_norm.weight', 'fc_norm.bias'}
288 |
289 | # manually initialize fc layer
290 | if 'dual_vit' in args.model:
291 | # initialize student
292 | trunc_normal_(model.student.head.weight, std=0.02)
293 | else:
294 | trunc_normal_(model.head.weight, std=2e-5)
295 |
296 | model.to(device)
297 |
298 | model_without_ddp = model
299 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
300 |
301 | print("Model = %s" % str(model_without_ddp))
302 | print('number of params (M): %.2f' % (n_parameters / 1.e6))
303 |
304 | eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()
305 |
306 | if args.lr is None: # only base_lr is specified
307 | args.lr = args.blr * eff_batch_size / 256
308 |
309 | print("base lr: %.2e" % (args.lr * 256 / eff_batch_size))
310 | print("actual lr: %.2e" % args.lr)
311 |
312 | print("accumulate grad iterations: %d" % args.accum_iter)
313 | print("effective batch size: %d" % eff_batch_size)
314 |
315 | if 'dual_vit' in args.model:
316 | for n, p in model.teacher.named_parameters():
317 | p.requires_grad = False
318 | if args.train_qkv:
319 | for n, p in model.named_parameters():
320 | if not ('qkv' in n or 'head' in n or 'fc_norm' in n):
321 | p.requires_grad = False
322 | if args.distributed:
323 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
324 | model_without_ddp = model.module
325 |
326 | # build optimizer with layer-wise lr decay (lrd)
327 | param_groups = lrd.param_groups_lrd(model_without_ddp, args.weight_decay,
328 | no_weight_decay_list=model_without_ddp.no_weight_decay(),
329 | layer_decay=args.layer_decay
330 | )
331 | optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, args.beta2))
332 | loss_scaler = NativeScaler()
333 |
334 | if mixup_fn is not None:
335 | # smoothing is handled with mixup label transform
336 | criterion = SoftTargetCrossEntropy()
337 | elif args.smoothing > 0.:
338 | criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
339 | else:
340 | criterion = torch.nn.CrossEntropyLoss()
341 |
342 | print("criterion = %s" % str(criterion))
343 |
344 | model_ema = None
345 | if args.ema is not None:
346 | model_ema = ModelEmaV2(
347 | model_without_ddp,
348 | decay=args.ema,
349 | device=None
350 | )
351 |
352 | misc.load_model(args=args, model_without_ddp=model_without_ddp, model_ema=model_ema,
353 | optimizer=optimizer, loss_scaler=loss_scaler)
354 |
355 | if args.eval:
356 | test_stats = evaluate(data_loader_val, model, device)
357 | print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
358 | if model_ema is not None:
359 | ema_test_stats = evaluate(data_loader_val, model_ema.module, device)
360 | print(
361 | f"Accuracy of the network (EMA) on the {len(dataset_val)} test images: {ema_test_stats['acc1']:.1f}%"
362 | )
363 | exit(0)
364 |
365 | print(f"Start training for {args.epochs} epochs")
366 | start_time = time.time()
367 | max_accuracy = 0.0
368 | ema_max_accuracy = 0.0
369 | for epoch in range(args.start_epoch, args.epochs):
370 | if args.distributed:
371 | data_loader_train.sampler.set_epoch(epoch)
372 | train_stats = train_one_epoch(
373 | model, criterion, data_loader_train,
374 | optimizer, device, epoch, loss_scaler,
375 | args.clip_grad, mixup_fn,
376 | log_writer=log_writer,
377 | model_ema=model_ema,
378 | args=args
379 | )
380 | if args.output_dir:
381 | misc.save_model(
382 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
383 | loss_scaler=loss_scaler, epoch=epoch, model_ema=model_ema)
384 |
385 | test_stats = evaluate(data_loader_val, model, device)
386 | print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
387 | if model_ema is not None:
388 | ema_test_stats = evaluate(data_loader_val, model_ema.module, device)
389 | print(f"Accuracy of the EMA network on the {len(dataset_val)} test images: {ema_test_stats['acc1']:.1f}%")
390 | ema_max_accuracy = max(ema_max_accuracy, ema_test_stats["acc1"])
391 | print(f"Max accuracy (EMA): {ema_max_accuracy:.2f}%")
392 | max_accuracy = max(max_accuracy, test_stats["acc1"])
393 | print(f'Max accuracy: {max_accuracy:.2f}%')
394 |
395 | if log_writer is not None:
396 | log_writer.add_scalar('perf/test_acc1', test_stats['acc1'], epoch)
397 | log_writer.add_scalar('perf/test_acc5', test_stats['acc5'], epoch)
398 | log_writer.add_scalar('perf/test_loss', test_stats['loss'], epoch)
399 | if model_ema is not None:
400 | log_writer.add_scalar('perf/ema_test_acc1', ema_test_stats['acc1'], epoch)
401 | log_writer.add_scalar('perf/ema_test_acc5', ema_test_stats['acc5'], epoch)
402 | log_writer.add_scalar('perf/ema_test_loss', ema_test_stats['loss'], epoch)
403 |
404 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
405 | **{f'test_{k}': v for k, v in test_stats.items()},
406 | 'epoch': epoch,
407 | 'n_parameters': n_parameters}
408 | if model_ema is not None:
409 | log_stats.update({f'ema_test_{k}': v for k, v in ema_test_stats.items()})
410 |
411 | if args.output_dir and misc.is_main_process():
412 | if log_writer is not None:
413 | log_writer.flush()
414 | with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
415 | f.write(json.dumps(log_stats) + "\n")
416 |
417 | total_time = time.time() - start_time
418 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
419 | print('Training time {}'.format(total_time_str))
420 |
421 |
422 | if __name__ == '__main__':
423 | args = get_args_parser()
424 | args = args.parse_args()
425 | if args.output_dir:
426 | Path(args.output_dir).mkdir(parents=True, exist_ok=True)
427 | main(args)
428 |
--------------------------------------------------------------------------------
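
A hypothetical single-node launch of the script above, for illustration only: paths, GPU count, and hyperparameter values are placeholders rather than recommended settings, and it assumes misc.init_distributed_mode follows the usual MAE convention of reading the env:// variables set by torchrun.

    torchrun --nproc_per_node=8 main_finetune.py \
        --model dual_vit_base_patch16 --mode distill \
        --finetune /path/to/pretrained_checkpoint.pth \
        --dataset_name imagenet --data_path /path/to/imagenet \
        --batch_size 64 --accum_iter 1 \
        --epochs 100 --blr 1e-3 --layer_decay 0.75 \
        --weight_decay 0.05 --drop_path 0.1 \
        --output_dir ./output_dir --log_dir ./output_dir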