├── figs └── framework.png ├── .github └── FUNDING.yml ├── ModelCard.md ├── util ├── lr_sched.py ├── crop.py ├── lars.py ├── lr_decay.py ├── datasets.py ├── pos_embed.py └── misc.py ├── .gitignore ├── models_vit.py ├── engine_pretrain.py ├── README.md ├── engine_finetune.py ├── CODE_OF_CONDUCT.md ├── main_pretrain.py ├── LICENSE.md ├── main_linprobe.py ├── main_finetune.py └── models_pretrain.py /figs/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenDriveLab/maskalign/HEAD/figs/framework.png -------------------------------------------------------------------------------- /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: [OpenDriveLab] # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2] 4 | patreon: # Replace with a single Patreon username 5 | open_collective: # Replace with a single Open Collective username 6 | ko_fi: # Replace with a single Ko-fi username 7 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel 8 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry 9 | liberapay: # Replace with a single Liberapay username 10 | issuehunt: # Replace with a single IssueHunt username 11 | otechie: # Replace with a single Otechie username 12 | lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry 13 | custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] 14 | -------------------------------------------------------------------------------- /ModelCard.md: -------------------------------------------------------------------------------- 1 | # Model Card 2 | 3 | This page lists the MaskAlign model weights. CLIP-L/14* denotes input 196 × 196 resolution image to CLIP-L/14. This will keep the same feature map size as the student model. PT epochs and FT Acc denotes pre-training epochs and fine-tuning accuracy on ImageNet-1K, respectively. 4 | 5 | 6 | | Model | Teacher Model | PT epochs | Link | FT Acc. | 7 | | :----: | :----: | :----: |:----:| :----: | 8 | | ViT-B/16 | CLIP-B/16 | 200 | [gdrive](https://drive.google.com/file/d/1hu_dlzOxVqS1Zx6W41aOu7qL2XqXOuqx/view?usp=share_link) | 85.4 | 9 | | ViT-L/16 | CLIP-B/16 | 200 | [gdrive](https://drive.google.com/file/d/1hWdjhKso52K5M9xem0j81KJVhlg0oZov/view?usp=share_link) | 86.5 | 10 | | ViT-L/16 | CLIP-L/14* | 200 | [gdrive](https://drive.google.com/file/d/1NdwxvQkaHk8axmThxDrQKTLJ8CF_QgM5/view?usp=share_link) | 87.4 | -------------------------------------------------------------------------------- /util/lr_sched.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | 9 | def adjust_learning_rate(optimizer, epoch, args): 10 | """Decay the learning rate with half-cycle cosine after warmup""" 11 | if epoch < args.warmup_epochs: 12 | lr = args.lr * epoch / args.warmup_epochs 13 | else: 14 | lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \ 15 | (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs))) 16 | for param_group in optimizer.param_groups: 17 | if "lr_scale" in param_group: 18 | param_group["lr"] = lr * param_group["lr_scale"] 19 | else: 20 | param_group["lr"] = lr 21 | return lr 22 | -------------------------------------------------------------------------------- /util/crop.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | 9 | import torch 10 | 11 | from torchvision import transforms 12 | from torchvision.transforms import functional as F 13 | 14 | 15 | class RandomResizedCrop(transforms.RandomResizedCrop): 16 | """ 17 | RandomResizedCrop for matching TF/TPU implementation: no for-loop is used. 18 | This may lead to results different with torchvision's version. 19 | Following BYOL's TF code: 20 | https://github.com/deepmind/deepmind-research/blob/master/byol/utils/dataset.py#L206 21 | """ 22 | @staticmethod 23 | def get_params(img, scale, ratio): 24 | width, height = F._get_image_size(img) 25 | area = height * width 26 | 27 | target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item() 28 | log_ratio = torch.log(torch.tensor(ratio)) 29 | aspect_ratio = torch.exp( 30 | torch.empty(1).uniform_(log_ratio[0], log_ratio[1]) 31 | ).item() 32 | 33 | w = int(round(math.sqrt(target_area * aspect_ratio))) 34 | h = int(round(math.sqrt(target_area / aspect_ratio))) 35 | 36 | w = min(w, width) 37 | h = min(h, height) 38 | 39 | i = torch.randint(0, height - h + 1, size=(1,)).item() 40 | j = torch.randint(0, width - w + 1, size=(1,)).item() 41 | 42 | return i, j, h, w -------------------------------------------------------------------------------- /util/lars.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # LARS optimizer, implementation from MoCo v3: 8 | # https://github.com/facebookresearch/moco-v3 9 | # -------------------------------------------------------- 10 | 11 | import torch 12 | 13 | 14 | class LARS(torch.optim.Optimizer): 15 | """ 16 | LARS optimizer, no rate scaling or weight decay for parameters <= 1D. 17 | """ 18 | def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, trust_coefficient=0.001): 19 | defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, trust_coefficient=trust_coefficient) 20 | super().__init__(params, defaults) 21 | 22 | @torch.no_grad() 23 | def step(self): 24 | for g in self.param_groups: 25 | for p in g['params']: 26 | dp = p.grad 27 | 28 | if dp is None: 29 | continue 30 | 31 | if p.ndim > 1: # if not normalization gamma/beta or bias 32 | dp = dp.add(p, alpha=g['weight_decay']) 33 | param_norm = torch.norm(p) 34 | update_norm = torch.norm(dp) 35 | one = torch.ones_like(param_norm) 36 | q = torch.where(param_norm > 0., 37 | torch.where(update_norm > 0, 38 | (g['trust_coefficient'] * param_norm / update_norm), one), 39 | one) 40 | dp = dp.mul(q) 41 | 42 | param_state = self.state[p] 43 | if 'mu' not in param_state: 44 | param_state['mu'] = torch.zeros_like(p) 45 | mu = param_state['mu'] 46 | mu.mul_(g['momentum']).add_(dp) 47 | p.add_(mu, alpha=-g['lr']) -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | output_dir/ 2 | outputs/ 3 | selected/ 4 | 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | **/*.pyc 10 | 11 | # C extensions 12 | *.so 13 | 14 | # Distribution / packaging 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # pyenv 81 | .python-version 82 | 83 | # celery beat schedule file 84 | celerybeat-schedule 85 | 86 | # SageMath parsed files 87 | *.sage.py 88 | 89 | # Environments 90 | .env 91 | .venv 92 | env/ 93 | venv/ 94 | ENV/ 95 | env.bak/ 96 | venv.bak/ 97 | 98 | # Spyder project settings 99 | .spyderproject 100 | .spyproject 101 | 102 | # Rope project settings 103 | .ropeproject 104 | 105 | # mkdocs documentation 106 | /site 107 | 108 | # mypy 109 | .mypy_cache/ 110 | 111 | # custom 112 | /data 113 | .vscode 114 | .idea 115 | *.pkl 116 | *.pkl.json 117 | *.log.json 118 | benchlist.txt 119 | work_dirs/ 120 | 121 | # Pytorch 122 | *.pth 123 | 124 | # Profile 125 | *.prof 126 | 127 | # lmdb 128 | *.mdb 129 | 130 | # unignore some data file in tests/data 131 | !tests/data/**/*.pkl 132 | !tests/data/**/*.pkl.json 133 | !tests/data/**/*.log.json 134 | !tests/data/**/*.pth 135 | 136 | # avoid soft links created by MIM 137 | mmaction/configs/* 138 | mmaction/tools/* 139 | -------------------------------------------------------------------------------- /util/lr_decay.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # ELECTRA https://github.com/google-research/electra 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | 12 | import json 13 | 14 | 15 | def param_groups_lrd(model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=.75): 16 | """ 17 | Parameter groups for layer-wise lr decay 18 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58 19 | """ 20 | param_group_names = {} 21 | param_groups = {} 22 | 23 | num_layers = len(model.blocks) + 1 24 | 25 | layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1)) 26 | 27 | for n, p in model.named_parameters(): 28 | if not p.requires_grad: 29 | continue 30 | 31 | # no decay: all 1D parameters and model specific ones 32 | if p.ndim == 1 or n in no_weight_decay_list: 33 | g_decay = "no_decay" 34 | this_decay = 0. 35 | else: 36 | g_decay = "decay" 37 | this_decay = weight_decay 38 | 39 | layer_id = get_layer_id_for_vit(n, num_layers) 40 | group_name = "layer_%d_%s" % (layer_id, g_decay) 41 | 42 | if group_name not in param_group_names: 43 | this_scale = layer_scales[layer_id] 44 | 45 | param_group_names[group_name] = { 46 | "lr_scale": this_scale, 47 | "weight_decay": this_decay, 48 | "params": [], 49 | } 50 | param_groups[group_name] = { 51 | "lr_scale": this_scale, 52 | "weight_decay": this_decay, 53 | "params": [], 54 | } 55 | 56 | param_group_names[group_name]["params"].append(n) 57 | param_groups[group_name]["params"].append(p) 58 | 59 | # print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2)) 60 | 61 | return list(param_groups.values()) 62 | 63 | 64 | def get_layer_id_for_vit(name, num_layers): 65 | """ 66 | Assign a parameter with its layer id 67 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33 68 | """ 69 | if name in ['cls_token', 'pos_embed']: 70 | return 0 71 | elif name.startswith('patch_embed'): 72 | return 0 73 | elif name.startswith('blocks'): 74 | return int(name.split('.')[1]) + 1 75 | else: 76 | return num_layers -------------------------------------------------------------------------------- /util/datasets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # -------------------------------------------------------- 10 | 11 | import os 12 | import glob 13 | import PIL 14 | import torch 15 | from io import BytesIO 16 | from PIL import Image 17 | from torchvision import datasets, transforms 18 | 19 | from timm.data import create_transform 20 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 21 | 22 | 23 | class ImageNet1k_JPG(torch.utils.data.Dataset): 24 | ''' 25 | An ImageNet-1k dataset with caching support. 26 | ''' 27 | 28 | def __init__(self, image_root, meta_path, transform): 29 | self.transform = transform 30 | 31 | with open(meta_path) as f: 32 | self.data_list = f.read().splitlines() 33 | self.image_root = image_root 34 | 35 | def __len__(self): 36 | return len(self.data_list) 37 | 38 | def __getitem__(self, idx): 39 | line = self.data_list[idx] 40 | path, label = line.split(' ') 41 | 42 | path = os.path.join(self.image_root, path) 43 | label = int(label) 44 | 45 | image = Image.open(path).convert('RGB') 46 | image = self.transform(image) 47 | 48 | return image, label 49 | 50 | def build_dataset_jpg(is_train, args): 51 | transform = build_transform(is_train, args) 52 | data_root = args.data_path 53 | image_root = os.path.join(data_root, 'train' if is_train else 'val') 54 | meta_path = os.path.join(data_root, 'meta', 'train.txt' if is_train else 'val.txt') 55 | dataset = ImageNet1k_JPG(image_root, meta_path, transform) 56 | print(f"Dataset at {meta_path}. Length of {len(dataset)}") 57 | return dataset 58 | 59 | def build_transform(is_train, args): 60 | mean = IMAGENET_DEFAULT_MEAN 61 | std = IMAGENET_DEFAULT_STD 62 | # train transform 63 | if is_train: 64 | # this should always dispatch to transforms_imagenet_train 65 | transform = create_transform( 66 | input_size=args.input_size, 67 | is_training=True, 68 | color_jitter=args.color_jitter, 69 | auto_augment=args.aa, 70 | interpolation='bicubic', 71 | re_prob=args.reprob, 72 | re_mode=args.remode, 73 | re_count=args.recount, 74 | mean=mean, 75 | std=std, 76 | ) 77 | return transform 78 | 79 | # eval transform 80 | t = [] 81 | if args.input_size <= 224: 82 | crop_pct = 224 / 256 83 | else: 84 | crop_pct = 1.0 85 | size = int(args.input_size / crop_pct) 86 | t.append( 87 | transforms.Resize(size, interpolation=PIL.Image.BICUBIC), # to maintain same ratio w.r.t. 224 images 88 | ) 89 | t.append(transforms.CenterCrop(args.input_size)) 90 | 91 | t.append(transforms.ToTensor()) 92 | t.append(transforms.Normalize(mean, std)) 93 | return transforms.Compose(t) 94 | 95 | -------------------------------------------------------------------------------- /models_vit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm 9 | # DeiT: https://github.com/facebookresearch/deit 10 | # -------------------------------------------------------- 11 | 12 | from functools import partial 13 | 14 | import torch 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | import timm.models.vision_transformer 18 | 19 | 20 | class VisionTransformer(timm.models.vision_transformer.VisionTransformer): 21 | """ Vision Transformer with support for global average pooling 22 | """ 23 | def __init__(self, global_pool=False, **kwargs): 24 | super(VisionTransformer, self).__init__(**kwargs) 25 | 26 | self.global_pool = global_pool 27 | if self.global_pool: 28 | norm_layer = kwargs['norm_layer'] 29 | embed_dim = kwargs['embed_dim'] 30 | self.fc_norm = norm_layer(embed_dim) 31 | 32 | del self.norm # remove the original norm 33 | 34 | def forward_features(self, x): 35 | B = x.shape[0] 36 | x = self.patch_embed(x) 37 | 38 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 39 | x = torch.cat((cls_tokens, x), dim=1) 40 | x = x + self.pos_embed 41 | x = self.pos_drop(x) 42 | 43 | for blk in self.blocks: 44 | x = blk(x) 45 | 46 | if self.global_pool: 47 | x = x[:, 1:, :].mean(dim=1) # global pool without cls token 48 | outcome = self.fc_norm(x) 49 | else: 50 | x = self.norm(x) 51 | outcome = x[:, 0] 52 | 53 | return outcome 54 | 55 | def extract_features(self, x): 56 | B = x.shape[0] 57 | x = self.patch_embed(x) 58 | 59 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 60 | x = torch.cat((cls_tokens, x), dim=1) 61 | x = x + self.pos_embed 62 | x = self.pos_drop(x) 63 | 64 | for blk in self.blocks: 65 | x = blk(x) 66 | 67 | if self.global_pool: 68 | x = x[:, 1:, :].mean(dim=1) # global pool without cls token 69 | else: 70 | x = x[:, 0] 71 | 72 | return x 73 | 74 | 75 | def vit_base_patch16(**kwargs): 76 | model = VisionTransformer( 77 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 78 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 79 | return model 80 | 81 | 82 | def vit_large_patch16(**kwargs): 83 | model = VisionTransformer( 84 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, 85 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 86 | return model 87 | 88 | 89 | def vit_huge_patch14(**kwargs): 90 | model = VisionTransformer( 91 | patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, qkv_bias=True, 92 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 93 | return model 94 | 95 | 96 | 97 | -------------------------------------------------------------------------------- /engine_pretrain.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | import math 12 | import sys 13 | from typing import Iterable 14 | 15 | import torch 16 | 17 | import util.misc as misc 18 | import util.lr_sched as lr_sched 19 | 20 | 21 | def train_one_epoch(model: torch.nn.Module, 22 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 23 | device: torch.device, epoch: int, loss_scaler, 24 | log_writer=None, 25 | args=None): 26 | model.train(True) 27 | metric_logger = misc.MetricLogger(delimiter=" ") 28 | metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}')) 29 | header = 'Epoch: [{}]'.format(epoch) 30 | print_freq = 20 31 | 32 | accum_iter = args.accum_iter 33 | 34 | optimizer.zero_grad() 35 | 36 | if log_writer is not None: 37 | print('log_dir: {}'.format(log_writer.log_dir)) 38 | 39 | for data_iter_step, (samples, _) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 40 | 41 | # we use a per iteration (instead of per epoch) lr scheduler 42 | if data_iter_step % accum_iter == 0: 43 | lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args) 44 | 45 | samples = samples.to(device, non_blocking=True) 46 | 47 | with torch.cuda.amp.autocast(): 48 | loss = model(samples, mask_ratio=args.mask_ratio) 49 | 50 | # handle multiple losses 51 | if isinstance(loss, list): 52 | loss_list = [i.item() for i in loss] 53 | loss = sum(loss) 54 | else: 55 | loss_list = None 56 | 57 | loss_value = loss.item() 58 | 59 | if not math.isfinite(loss_value): 60 | print("Loss is {}, stopping training".format(loss_value)) 61 | sys.exit(1) 62 | 63 | loss /= accum_iter 64 | loss_scaler(loss, optimizer, parameters=model.parameters(), 65 | update_grad=(data_iter_step + 1) % accum_iter == 0) 66 | if (data_iter_step + 1) % accum_iter == 0: 67 | optimizer.zero_grad() 68 | 69 | # torch.cuda.synchronize() 70 | 71 | metric_logger.update(loss=loss_value) 72 | 73 | # handle multiple losses: 2 74 | if loss_list is not None: 75 | assert len(loss_list) == 2 76 | metric_logger.update(loss1=loss_list[0]) 77 | metric_logger.update(loss2=loss_list[1]) 78 | 79 | lr = optimizer.param_groups[0]["lr"] 80 | metric_logger.update(lr=lr) 81 | 82 | loss_value_reduce = misc.all_reduce_mean(loss_value) 83 | if log_writer is not None and (data_iter_step + 1) % accum_iter == 0: 84 | """ We use epoch_1000x as the x-axis in tensorboard. 85 | This calibrates different curves when batch size changes. 86 | """ 87 | epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000) 88 | log_writer.add_scalar('train_loss', loss_value_reduce, epoch_1000x) 89 | log_writer.add_scalar('lr', lr, epoch_1000x) 90 | 91 | 92 | # gather the stats from all processes 93 | metric_logger.synchronize_between_processes() 94 | print("Averaged stats:", metric_logger) 95 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} -------------------------------------------------------------------------------- /util/pos_embed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # Position embedding utils 8 | # -------------------------------------------------------- 9 | 10 | import numpy as np 11 | 12 | import torch 13 | 14 | # -------------------------------------------------------- 15 | # 2D sine-cosine position embedding 16 | # References: 17 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py 18 | # MoCo v3: https://github.com/facebookresearch/moco-v3 19 | # -------------------------------------------------------- 20 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): 21 | """ 22 | grid_size: int of the grid height and width 23 | return: 24 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 25 | """ 26 | grid_h = np.arange(grid_size, dtype=np.float32) 27 | grid_w = np.arange(grid_size, dtype=np.float32) 28 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 29 | grid = np.stack(grid, axis=0) 30 | 31 | grid = grid.reshape([2, 1, grid_size, grid_size]) 32 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 33 | if cls_token: 34 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 35 | return pos_embed 36 | 37 | 38 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 39 | assert embed_dim % 2 == 0 40 | 41 | # use half of dimensions to encode grid_h 42 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 43 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 44 | 45 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 46 | return emb 47 | 48 | 49 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 50 | """ 51 | embed_dim: output dimension for each position 52 | pos: a list of positions to be encoded: size (M,) 53 | out: (M, D) 54 | """ 55 | assert embed_dim % 2 == 0 56 | omega = np.arange(embed_dim // 2, dtype=np.float) 57 | omega /= embed_dim / 2. 58 | omega = 1. / 10000**omega # (D/2,) 59 | 60 | pos = pos.reshape(-1) # (M,) 61 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 62 | 63 | emb_sin = np.sin(out) # (M, D/2) 64 | emb_cos = np.cos(out) # (M, D/2) 65 | 66 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 67 | return emb 68 | 69 | 70 | # -------------------------------------------------------- 71 | # Interpolate position embeddings for high-resolution 72 | # References: 73 | # DeiT: https://github.com/facebookresearch/deit 74 | # -------------------------------------------------------- 75 | def interpolate_pos_embed(model, checkpoint_model): 76 | if 'pos_embed' in checkpoint_model: 77 | pos_embed_checkpoint = checkpoint_model['pos_embed'] 78 | embedding_size = pos_embed_checkpoint.shape[-1] 79 | num_patches = model.patch_embed.num_patches 80 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches 81 | # height (== width) for the checkpoint position embedding 82 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 83 | # height (== width) for the new position embedding 84 | new_size = int(num_patches ** 0.5) 85 | # class_token and dist_token are kept unchanged 86 | if orig_size != new_size: 87 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) 88 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 89 | # only the position tokens are interpolated 90 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 91 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 92 | pos_tokens = torch.nn.functional.interpolate( 93 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 94 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 95 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 96 | checkpoint_model['pos_embed'] = new_pos_embed 97 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | > [!IMPORTANT] 2 | > 🌟 Stay up to date at [opendrivelab.com](https://opendrivelab.com/#news)! 3 | 4 | # MaskAlign (CVPR 2023) 5 | 6 |

7 | statistics 8 |

9 | 10 | 11 | This is the official PyTorch repository for CVPR 2023 paper [Stare at What You See: Masked Image Modeling without Reconstruction](https://arxiv.org/abs/2211.08887): 12 | ``` 13 | @article{xue2022stare, 14 | title={Stare at What You See: Masked Image Modeling without Reconstruction}, 15 | author={Xue, Hongwei and Gao, Peng and Li, Hongyang and Qiao, Yu and Sun, Hao and Li, Houqiang and Luo, Jiebo}, 16 | journal={arXiv preprint arXiv:2211.08887}, 17 | year={2022} 18 | } 19 | ``` 20 | 21 | * This repo is a modification on the [MAE repo](https://github.com/facebookresearch/mae). Installation and preparation follow that repo. 22 | 23 | * The teacher models in this repo are called from [Huggingface](https://huggingface.co/). Please install transformers package by running:
`pip install transformers`. 24 | 25 | ## Pre-training 26 | 27 | To pre-train ViT-base (recommended default) with **distributed training**, run the following on 8 GPUs: 28 | 29 | ``` 30 | python -m torch.distributed.launch --nproc_per_node=8 main_pretrain.py \ 31 | --batch_size 128 \ 32 | --model mae_vit_base_patch16 \ 33 | --blr 1.5e-4 \ 34 | --min_lr 1e-5 \ 35 | --data_path ${IMAGENET_DIR} \ 36 | --output_dir ${OUTPUT_DIR} \ 37 | --target_norm whiten \ 38 | --loss_type smoothl1 \ 39 | --drop_path 0.1 \ 40 | --head_type linear \ 41 | --epochs 200 \ 42 | --warmup_epochs 20 \ 43 | --mask_type attention \ 44 | --mask_ratio 0.7 \ 45 | --loss_weights top5 \ 46 | --fusion_type linear \ 47 | --teacher_model openai/clip-vit-base-patch16 48 | ``` 49 | 50 | - Here the effective batch size is 128 (`batch_size` per gpu) * 8 (gpus) = 1024. If memory or # gpus is limited, use `--accum_iter` to maintain the effective batch size, which is `batch_size` (per gpu) * `nodes` * 8 (gpus) * `accum_iter`. 51 | - `blr` is the base learning rate. The actual `lr` is computed by the [linear scaling rule](https://arxiv.org/abs/1706.02677): `lr` = `blr` * effective batch size / 256. 52 | - This repo will automatically resume the checkpoints by keeping a "latest checkpoint". 53 | 54 | To train ViT-Large, please set `--model mae_vit_large_patch16` and `--drop_path 0.2`. Currently, this repo supports three teacher models: `--teacher_model ${TEACHER}`, where `${TEACHER} in openai/clip-vit-base-patch16, openai/clip-vit-large-patch14 and facebook/dino-vitb16`. 55 | 56 | ## Fine-tuning 57 | 58 | Get our pre-trained checkpoints from [here](ModelCard.md). 59 | 60 | To fine-tune ViT-base (recommended default) with **distributed training**, run the following on 8 GPUs: 61 | ``` 62 | python -m torch.distributed.launch --nproc_per_node=8 main_finetune.py \ 63 | --epochs 100 \ 64 | --batch_size 128 \ 65 | --model vit_base_patch16 \ 66 | --blr 3e-4 \ 67 | --layer_decay 0.55 \ 68 | --weight_decay 0.05 \ 69 | --drop_path 0.2 \ 70 | --reprob 0.25 \ 71 | --mixup 0.8 \ 72 | --cutmix 1.0 \ 73 | --dist_eval \ 74 | --finetune ${PT_CHECKPOINT} \ 75 | --data_path ${IMAGENET_DIR} \ 76 | --output_dir ${OUTPUT_DIR} 77 | ``` 78 | 79 | - Here the effective batch size is 128 (`batch_size` per gpu) * 8 (gpus) = 1024. 80 | - `blr` is the base learning rate. The actual `lr` is computed by the [linear scaling rule](https://arxiv.org/abs/1706.02677): `lr` = `blr` * effective batch size / 256. 81 | 82 | To fine-tune ViT-Large, please set `--model vit_large_patch16 --epochs 50 --drop_path 0.4 --layer_decay 0.75 --blr 3e-4`. 83 | 84 | 85 | ## Linear Probing 86 | 87 | Run the following on 8 GPUs: 88 | ``` 89 | python -m torch.distributed.launch --nproc_per_node=8 main_linprobe.py \ 90 | --epochs 90 \ 91 | --batch_size 2048 \ 92 | --model vit_base_patch16 \ 93 | --blr 0.025 \ 94 | --weight_decay 0.0 \ 95 | --dist_eval \ 96 | --finetune ${PT_CHECKPOINT} \ 97 | --data_path ${IMAGENET_DIR} \ 98 | --output_dir ${OUTPUT_DIR} 99 | ``` 100 | - Here the effective batch size is 2048 (`batch_size` per gpu) * 8 (gpus) = 16384. 101 | - `blr` is the base learning rate. The actual `lr` is computed by the [linear scaling rule](https://arxiv.org/abs/1706.02677): `lr` = `blr` * effective batch size / 256. 102 | 103 | -------------------------------------------------------------------------------- /engine_finetune.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | 12 | import math 13 | import sys 14 | from typing import Iterable, Optional 15 | 16 | import torch 17 | 18 | from timm.data import Mixup 19 | from timm.utils import accuracy 20 | 21 | import util.misc as misc 22 | import util.lr_sched as lr_sched 23 | 24 | 25 | def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module, 26 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 27 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0, 28 | mixup_fn: Optional[Mixup] = None, log_writer=None, 29 | args=None): 30 | model.train(True) 31 | metric_logger = misc.MetricLogger(delimiter=" ") 32 | metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}')) 33 | header = 'Epoch: [{}]'.format(epoch) 34 | print_freq = 20 35 | 36 | accum_iter = args.accum_iter 37 | 38 | optimizer.zero_grad() 39 | 40 | if log_writer is not None: 41 | print('log_dir: {}'.format(log_writer.log_dir)) 42 | 43 | for data_iter_step, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 44 | 45 | # we use a per iteration (instead of per epoch) lr scheduler 46 | if data_iter_step % accum_iter == 0: 47 | lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args) 48 | 49 | samples = samples.to(device, non_blocking=True) 50 | targets = targets.to(device, non_blocking=True) 51 | 52 | if mixup_fn is not None: 53 | samples, targets = mixup_fn(samples, targets) 54 | 55 | with torch.cuda.amp.autocast(): 56 | outputs = model(samples) 57 | loss = criterion(outputs, targets) 58 | 59 | loss_value = loss.item() 60 | 61 | if not math.isfinite(loss_value): 62 | print("Loss is {}, stopping training".format(loss_value)) 63 | sys.exit(1) 64 | 65 | loss /= accum_iter 66 | loss_scaler(loss, optimizer, clip_grad=max_norm, 67 | parameters=model.parameters(), create_graph=False, 68 | update_grad=(data_iter_step + 1) % accum_iter == 0) 69 | if (data_iter_step + 1) % accum_iter == 0: 70 | optimizer.zero_grad() 71 | 72 | # torch.cuda.synchronize() 73 | 74 | metric_logger.update(loss=loss_value) 75 | min_lr = 10. 76 | max_lr = 0. 77 | for group in optimizer.param_groups: 78 | min_lr = min(min_lr, group["lr"]) 79 | max_lr = max(max_lr, group["lr"]) 80 | 81 | metric_logger.update(lr=max_lr) 82 | 83 | loss_value_reduce = misc.all_reduce_mean(loss_value) 84 | if log_writer is not None and (data_iter_step + 1) % accum_iter == 0: 85 | """ We use epoch_1000x as the x-axis in tensorboard. 86 | This calibrates different curves when batch size changes. 87 | """ 88 | epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000) 89 | log_writer.add_scalar('loss', loss_value_reduce, epoch_1000x) 90 | log_writer.add_scalar('lr', max_lr, epoch_1000x) 91 | 92 | # gather the stats from all processes 93 | metric_logger.synchronize_between_processes() 94 | print("Averaged stats:", metric_logger) 95 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 96 | 97 | 98 | @torch.no_grad() 99 | def evaluate(data_loader, model, device): 100 | criterion = torch.nn.CrossEntropyLoss() 101 | 102 | metric_logger = misc.MetricLogger(delimiter=" ") 103 | header = 'Test:' 104 | 105 | # switch to evaluation mode 106 | model.eval() 107 | 108 | for batch in metric_logger.log_every(data_loader, 10, header): 109 | images = batch[0] 110 | target = batch[-1] 111 | images = images.to(device, non_blocking=True) 112 | target = target.to(device, non_blocking=True) 113 | 114 | # compute output 115 | with torch.cuda.amp.autocast(): 116 | output = model(images) 117 | loss = criterion(output, target) 118 | 119 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 120 | 121 | batch_size = images.shape[0] 122 | metric_logger.update(loss=loss.item()) 123 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) 124 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) 125 | # gather the stats from all processes 126 | metric_logger.synchronize_between_processes() 127 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}' 128 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss)) 129 | 130 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 131 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | We as members, contributors, and leaders pledge to make participation in our 6 | community a harassment-free experience for everyone, regardless of age, body 7 | size, visible or invisible disability, ethnicity, sex characteristics, gender 8 | identity and expression, level of experience, education, socio-economic status, 9 | nationality, personal appearance, race, religion, or sexual identity 10 | and orientation. 11 | 12 | We pledge to act and interact in ways that contribute to an open, welcoming, 13 | diverse, inclusive, and healthy community. 14 | 15 | ## Our Standards 16 | 17 | Examples of behavior that contributes to a positive environment for our 18 | community include: 19 | 20 | * Demonstrating empathy and kindness toward other people 21 | * Being respectful of differing opinions, viewpoints, and experiences 22 | * Giving and gracefully accepting constructive feedback 23 | * Accepting responsibility and apologizing to those affected by our mistakes, 24 | and learning from the experience 25 | * Focusing on what is best not just for us as individuals, but for the 26 | overall community 27 | 28 | Examples of unacceptable behavior include: 29 | 30 | * The use of sexualized language or imagery, and sexual attention or 31 | advances of any kind 32 | * Trolling, insulting or derogatory comments, and personal or political attacks 33 | * Public or private harassment 34 | * Publishing others' private information, such as a physical or email 35 | address, without their explicit permission 36 | * Other conduct which could reasonably be considered inappropriate in a 37 | professional setting 38 | 39 | ## Enforcement Responsibilities 40 | 41 | Community leaders are responsible for clarifying and enforcing our standards of 42 | acceptable behavior and will take appropriate and fair corrective action in 43 | response to any behavior that they deem inappropriate, threatening, offensive, 44 | or harmful. 45 | 46 | Community leaders have the right and responsibility to remove, edit, or reject 47 | comments, commits, code, wiki edits, issues, and other contributions that are 48 | not aligned to this Code of Conduct, and will communicate reasons for moderation 49 | decisions when appropriate. 50 | 51 | ## Scope 52 | 53 | This Code of Conduct applies within all community spaces, and also applies when 54 | an individual is officially representing the community in public spaces. 55 | Examples of representing our community include using an official e-mail address, 56 | posting via an official social media account, or acting as an appointed 57 | representative at an online or offline event. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported to the community leaders responsible for enforcement at 63 | . 64 | All complaints will be reviewed and investigated promptly and fairly. 65 | 66 | All community leaders are obligated to respect the privacy and security of the 67 | reporter of any incident. 68 | 69 | ## Enforcement Guidelines 70 | 71 | Community leaders will follow these Community Impact Guidelines in determining 72 | the consequences for any action they deem in violation of this Code of Conduct: 73 | 74 | ### 1. Correction 75 | 76 | **Community Impact**: Use of inappropriate language or other behavior deemed 77 | unprofessional or unwelcome in the community. 78 | 79 | **Consequence**: A private, written warning from community leaders, providing 80 | clarity around the nature of the violation and an explanation of why the 81 | behavior was inappropriate. A public apology may be requested. 82 | 83 | ### 2. Warning 84 | 85 | **Community Impact**: A violation through a single incident or series 86 | of actions. 87 | 88 | **Consequence**: A warning with consequences for continued behavior. No 89 | interaction with the people involved, including unsolicited interaction with 90 | those enforcing the Code of Conduct, for a specified period of time. This 91 | includes avoiding interactions in community spaces as well as external channels 92 | like social media. Violating these terms may lead to a temporary or 93 | permanent ban. 94 | 95 | ### 3. Temporary Ban 96 | 97 | **Community Impact**: A serious violation of community standards, including 98 | sustained inappropriate behavior. 99 | 100 | **Consequence**: A temporary ban from any sort of interaction or public 101 | communication with the community for a specified period of time. No public or 102 | private interaction with the people involved, including unsolicited interaction 103 | with those enforcing the Code of Conduct, is allowed during this period. 104 | Violating these terms may lead to a permanent ban. 105 | 106 | ### 4. Permanent Ban 107 | 108 | **Community Impact**: Demonstrating a pattern of violation of community 109 | standards, including sustained inappropriate behavior, harassment of an 110 | individual, or aggression toward or disparagement of classes of individuals. 111 | 112 | **Consequence**: A permanent ban from any sort of public interaction within 113 | the community. 114 | 115 | ## Attribution 116 | 117 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], 118 | version 2.0, available at 119 | https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. 120 | 121 | Community Impact Guidelines were inspired by [Mozilla's code of conduct 122 | enforcement ladder](https://github.com/mozilla/diversity). 123 | 124 | [homepage]: https://www.contributor-covenant.org 125 | 126 | For answers to common questions about this code of conduct, see the FAQ at 127 | https://www.contributor-covenant.org/faq. Translations are available at 128 | https://www.contributor-covenant.org/translations. 129 | -------------------------------------------------------------------------------- /main_pretrain.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | import argparse 12 | import datetime 13 | import json 14 | import numpy as np 15 | import os 16 | import time 17 | from pathlib import Path 18 | 19 | import torch 20 | import torch.backends.cudnn as cudnn 21 | from torch.utils.tensorboard import SummaryWriter 22 | import torchvision.transforms as transforms 23 | import torchvision.datasets as datasets 24 | 25 | import timm 26 | 27 | assert timm.__version__ == "0.3.2" # version check 28 | import timm.optim.optim_factory as optim_factory 29 | 30 | import util.misc as misc 31 | from util.misc import NativeScalerWithGradNormCount as NativeScaler 32 | 33 | import models_pretrain as models_mae 34 | 35 | from engine_pretrain import train_one_epoch 36 | from util.datasets import ImageNet1k_JPG 37 | 38 | 39 | def get_args_parser(): 40 | parser = argparse.ArgumentParser('MAE pre-training', add_help=False) 41 | 42 | # Add new args 43 | parser.add_argument('--loss_weights', default="mean", type=str, 44 | help='Loss weights of each block in ViT.') 45 | parser.add_argument('--mask_type', default="random", type=str, 46 | help='Mask type in random, attention.') 47 | parser.add_argument('--fusion_type', default="simple", type=str, 48 | help='Fusion type in distillation.') 49 | parser.add_argument('--target_norm', default="none", type=str, 50 | help='target norm type in teacher model.') 51 | parser.add_argument('--loss_type', default="l2", type=str, 52 | help='loss type for feature reconstruction.') 53 | parser.add_argument('--head_type', default="linear", type=str, 54 | help='head type for feature reconstruction.') 55 | parser.add_argument('--teacher_model', default="openai/clip-vit-base-patch16", type=str, 56 | help='teacher model for feature reconstruction.') 57 | 58 | parser.add_argument('--batch_size', default=64, type=int, 59 | help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus') 60 | parser.add_argument('--epochs', default=400, type=int) 61 | parser.add_argument('--accum_iter', default=1, type=int, 62 | help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)') 63 | 64 | # Model parameters 65 | parser.add_argument('--model', default='mae_vit_large_patch16', type=str, metavar='MODEL', 66 | help='Name of model to train') 67 | 68 | parser.add_argument('--input_size', default=224, type=int, 69 | help='images input size') 70 | 71 | parser.add_argument('--mask_ratio', default=0.75, type=float, 72 | help='Masking ratio (percentage of removed patches).') 73 | 74 | parser.add_argument('--norm_pix_loss', action='store_true', 75 | help='Use (per-patch) normalized pixels as targets for computing loss') 76 | parser.set_defaults(norm_pix_loss=False) 77 | 78 | parser.add_argument('--drop_path', type=float, default=0., 79 | help='drop path rate (default: 0.)') 80 | 81 | # Optimizer parameters 82 | parser.add_argument('--weight_decay', type=float, default=0.05, 83 | help='weight decay (default: 0.05)') 84 | 85 | parser.add_argument('--lr', type=float, default=None, metavar='LR', 86 | help='learning rate (absolute lr)') 87 | parser.add_argument('--blr', type=float, default=1e-3, metavar='LR', 88 | help='base learning rate: absolute_lr = base_lr * total_batch_size / 256') 89 | parser.add_argument('--min_lr', type=float, default=0., metavar='LR', 90 | help='lower lr bound for cyclic schedulers that hit 0') 91 | 92 | parser.add_argument('--warmup_epochs', type=int, default=40, metavar='N', 93 | help='epochs to warmup LR') 94 | 95 | # Dataset parameters 96 | parser.add_argument('--data_path', default='/mnt/petrelfs/share/imagenet/images', type=str, 97 | help='dataset path') 98 | 99 | parser.add_argument('--output_dir', default='', 100 | help='path where to save, empty for no saving') 101 | parser.add_argument('--log_dir', default='', 102 | help='path where to tensorboard log') 103 | parser.add_argument('--device', default='cuda', 104 | help='device to use for training / testing') 105 | parser.add_argument('--seed', default=0, type=int) 106 | parser.add_argument('--resume', default='', 107 | help='resume from checkpoint') 108 | parser.add_argument('--auto_resume', action='store_true') 109 | parser.set_defaults(auto_resume=True) 110 | 111 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 112 | help='start epoch') 113 | parser.add_argument('--num_workers', default=10, type=int) 114 | parser.add_argument('--pin_mem', action='store_true', 115 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 116 | parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem') 117 | parser.set_defaults(pin_mem=True) 118 | 119 | # distributed training parameters 120 | parser.add_argument('--world_size', default=1, type=int, 121 | help='number of distributed processes') 122 | parser.add_argument('--local_rank', default=-1, type=int) 123 | parser.add_argument('--dist_on_itp', action='store_true') 124 | parser.add_argument('--dist_url', default='env://', 125 | help='url used to set up distributed training') 126 | 127 | return parser 128 | 129 | 130 | def main(args): 131 | misc.init_distributed_mode(args) 132 | 133 | print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__)))) 134 | print("{}".format(args).replace(', ', ',\n')) 135 | 136 | device = torch.device(args.device) 137 | 138 | # fix the seed for reproducibility 139 | seed = args.seed + misc.get_rank() 140 | torch.manual_seed(seed) 141 | np.random.seed(seed) 142 | 143 | cudnn.benchmark = True 144 | 145 | # simple augmentation 146 | transform_train = transforms.Compose([ 147 | transforms.RandomResizedCrop(args.input_size, scale=(0.2, 1.0), interpolation=3), # 3 is bicubic 148 | transforms.RandomHorizontalFlip(), 149 | transforms.ToTensor(), 150 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) 151 | 152 | dataset_train = ImageNet1k_JPG(image_root=os.path.join(args.data_path, 'train'), meta_path=os.path.join(args.data_path, 'meta', 'train.txt'), transform=transform_train) 153 | print(dataset_train) 154 | 155 | if True: # args.distributed: 156 | num_tasks = misc.get_world_size() 157 | global_rank = misc.get_rank() 158 | sampler_train = torch.utils.data.DistributedSampler( 159 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 160 | ) 161 | print("Sampler_train = %s" % str(sampler_train)) 162 | else: 163 | sampler_train = torch.utils.data.RandomSampler(dataset_train) 164 | 165 | if global_rank == 0 and args.log_dir is not None and len(args.log_dir) > 0: 166 | os.makedirs(args.log_dir, exist_ok=True) 167 | log_writer = SummaryWriter(log_dir=args.log_dir) 168 | else: 169 | log_writer = None 170 | 171 | data_loader_train = torch.utils.data.DataLoader( 172 | dataset_train, sampler=sampler_train, 173 | batch_size=args.batch_size, 174 | num_workers=args.num_workers, 175 | pin_memory=args.pin_mem, 176 | drop_last=True, 177 | ) 178 | 179 | # define the model 180 | model = models_mae.__dict__[args.model](norm_pix_loss=args.norm_pix_loss, drop_path_rate=args.drop_path, \ 181 | loss_weights=args.loss_weights, loss_type=args.loss_type, \ 182 | mask_type=args.mask_type, fusion_type=args.fusion_type, target_norm=args.target_norm, 183 | head_type=args.head_type, teacher_model=args.teacher_model) 184 | 185 | model.to(device) 186 | 187 | model_without_ddp = model 188 | print("Model = %s" % str(model_without_ddp)) 189 | 190 | eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size() 191 | 192 | if args.lr is None: # only base_lr is specified 193 | args.lr = args.blr * eff_batch_size / 256 194 | 195 | print("base lr: %.2e" % (args.lr * 256 / eff_batch_size)) 196 | print("actual lr: %.2e" % args.lr) 197 | 198 | print("accumulate grad iterations: %d" % args.accum_iter) 199 | print("effective batch size: %d" % eff_batch_size) 200 | 201 | if args.distributed: 202 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True) 203 | model_without_ddp = model.module 204 | 205 | # following timm: set wd as 0 for bias and norm layers 206 | param_groups = optim_factory.add_weight_decay(model_without_ddp, args.weight_decay, skip_list=["distill_weights"]) 207 | optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95)) 208 | print(optimizer) 209 | loss_scaler = NativeScaler() 210 | 211 | misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler) 212 | 213 | print(f"Start training for {args.epochs} epochs") 214 | start_time = time.time() 215 | for epoch in range(args.start_epoch, args.epochs): 216 | if args.distributed: 217 | data_loader_train.sampler.set_epoch(epoch) 218 | train_stats = train_one_epoch( 219 | model, data_loader_train, 220 | optimizer, device, epoch, loss_scaler, 221 | log_writer=log_writer, 222 | args=args 223 | ) 224 | if args.output_dir: 225 | misc.save_model_latest( 226 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, 227 | loss_scaler=loss_scaler, epoch=epoch) 228 | 229 | if args.output_dir and (epoch % 50 == 0 or epoch + 1 == args.epochs): 230 | misc.save_model( 231 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, 232 | loss_scaler=loss_scaler, epoch=epoch) 233 | 234 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 235 | 'epoch': epoch,} 236 | 237 | if args.output_dir and misc.is_main_process(): 238 | if log_writer is not None: 239 | log_writer.flush() 240 | with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f: 241 | f.write(json.dumps(log_stats) + "\n") 242 | 243 | total_time = time.time() - start_time 244 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 245 | print('Training time {}'.format(total_time_str)) 246 | 247 | 248 | if __name__ == '__main__': 249 | args = get_args_parser() 250 | args = args.parse_args() 251 | if args.output_dir: 252 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 253 | main(args) 254 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /util/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | 12 | import builtins 13 | import datetime 14 | import os 15 | import glob 16 | import time 17 | from collections import defaultdict, deque 18 | from pathlib import Path 19 | 20 | import torch 21 | import torch.distributed as dist 22 | from torch._six import inf 23 | import subprocess 24 | 25 | class SmoothedValue(object): 26 | """Track a series of values and provide access to smoothed values over a 27 | window or the global series average. 28 | """ 29 | 30 | def __init__(self, window_size=20, fmt=None): 31 | if fmt is None: 32 | fmt = "{median:.4f} ({global_avg:.4f})" 33 | self.deque = deque(maxlen=window_size) 34 | self.total = 0.0 35 | self.count = 0 36 | self.fmt = fmt 37 | 38 | def update(self, value, n=1): 39 | self.deque.append(value) 40 | self.count += n 41 | self.total += value * n 42 | 43 | def synchronize_between_processes(self): 44 | """ 45 | Warning: does not synchronize the deque! 46 | """ 47 | if not is_dist_avail_and_initialized(): 48 | return 49 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 50 | dist.barrier() 51 | dist.all_reduce(t) 52 | t = t.tolist() 53 | self.count = int(t[0]) 54 | self.total = t[1] 55 | 56 | @property 57 | def median(self): 58 | d = torch.tensor(list(self.deque)) 59 | return d.median().item() 60 | 61 | @property 62 | def avg(self): 63 | d = torch.tensor(list(self.deque), dtype=torch.float32) 64 | return d.mean().item() 65 | 66 | @property 67 | def global_avg(self): 68 | return self.total / self.count 69 | 70 | @property 71 | def max(self): 72 | return max(self.deque) 73 | 74 | @property 75 | def value(self): 76 | return self.deque[-1] 77 | 78 | def __str__(self): 79 | return self.fmt.format( 80 | median=self.median, 81 | avg=self.avg, 82 | global_avg=self.global_avg, 83 | max=self.max, 84 | value=self.value) 85 | 86 | 87 | class MetricLogger(object): 88 | def __init__(self, delimiter="\t"): 89 | self.meters = defaultdict(SmoothedValue) 90 | self.delimiter = delimiter 91 | 92 | def update(self, **kwargs): 93 | for k, v in kwargs.items(): 94 | if v is None: 95 | continue 96 | if isinstance(v, torch.Tensor): 97 | v = v.item() 98 | assert isinstance(v, (float, int)) 99 | self.meters[k].update(v) 100 | 101 | def __getattr__(self, attr): 102 | if attr in self.meters: 103 | return self.meters[attr] 104 | if attr in self.__dict__: 105 | return self.__dict__[attr] 106 | raise AttributeError("'{}' object has no attribute '{}'".format( 107 | type(self).__name__, attr)) 108 | 109 | def __str__(self): 110 | loss_str = [] 111 | for name, meter in self.meters.items(): 112 | loss_str.append( 113 | "{}: {}".format(name, str(meter)) 114 | ) 115 | return self.delimiter.join(loss_str) 116 | 117 | def synchronize_between_processes(self): 118 | for meter in self.meters.values(): 119 | meter.synchronize_between_processes() 120 | 121 | def add_meter(self, name, meter): 122 | self.meters[name] = meter 123 | 124 | def log_every(self, iterable, print_freq, header=None): 125 | i = 0 126 | if not header: 127 | header = '' 128 | start_time = time.time() 129 | end = time.time() 130 | iter_time = SmoothedValue(fmt='{avg:.4f}') 131 | data_time = SmoothedValue(fmt='{avg:.4f}') 132 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 133 | log_msg = [ 134 | header, 135 | '[{0' + space_fmt + '}/{1}]', 136 | 'eta: {eta}', 137 | '{meters}', 138 | 'time: {time}', 139 | 'data: {data}' 140 | ] 141 | if torch.cuda.is_available(): 142 | log_msg.append('max mem: {memory:.0f}') 143 | log_msg = self.delimiter.join(log_msg) 144 | MB = 1024.0 * 1024.0 145 | for obj in iterable: 146 | data_time.update(time.time() - end) 147 | yield obj 148 | iter_time.update(time.time() - end) 149 | if i % print_freq == 0 or i == len(iterable) - 1: 150 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 151 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 152 | if torch.cuda.is_available(): 153 | print(log_msg.format( 154 | i, len(iterable), eta=eta_string, 155 | meters=str(self), 156 | time=str(iter_time), data=str(data_time), 157 | memory=torch.cuda.max_memory_allocated() / MB)) 158 | else: 159 | print(log_msg.format( 160 | i, len(iterable), eta=eta_string, 161 | meters=str(self), 162 | time=str(iter_time), data=str(data_time))) 163 | i += 1 164 | end = time.time() 165 | total_time = time.time() - start_time 166 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 167 | print('{} Total time: {} ({:.4f} s / it)'.format( 168 | header, total_time_str, total_time / len(iterable))) 169 | 170 | 171 | def setup_for_distributed(is_master): 172 | """ 173 | This function disables printing when not in master process 174 | """ 175 | builtin_print = builtins.print 176 | 177 | def print(*args, **kwargs): 178 | force = kwargs.pop('force', False) 179 | force = force or (get_world_size() > 8) 180 | if is_master or force: 181 | now = datetime.datetime.now().time() 182 | builtin_print('[{}] '.format(now), end='') # print with time stamp 183 | builtin_print(*args, **kwargs) 184 | 185 | builtins.print = print 186 | 187 | 188 | def is_dist_avail_and_initialized(): 189 | if not dist.is_available(): 190 | return False 191 | if not dist.is_initialized(): 192 | return False 193 | return True 194 | 195 | 196 | def get_world_size(): 197 | if not is_dist_avail_and_initialized(): 198 | return 1 199 | return dist.get_world_size() 200 | 201 | 202 | def get_rank(): 203 | if not is_dist_avail_and_initialized(): 204 | return 0 205 | return dist.get_rank() 206 | 207 | 208 | def is_main_process(): 209 | return get_rank() == 0 210 | 211 | 212 | def save_on_master(*args, **kwargs): 213 | if is_main_process(): 214 | torch.save(*args, **kwargs) 215 | 216 | 217 | def init_distributed_mode(args): 218 | if args.dist_on_itp: 219 | args.rank = int(os.environ['OMPI_COMM_WORLD_RANK']) 220 | args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) 221 | args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) 222 | args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT']) 223 | os.environ['LOCAL_RANK'] = str(args.gpu) 224 | os.environ['RANK'] = str(args.rank) 225 | os.environ['WORLD_SIZE'] = str(args.world_size) 226 | # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"] 227 | elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 228 | args.rank = int(os.environ["RANK"]) 229 | args.world_size = int(os.environ['WORLD_SIZE']) 230 | args.gpu = int(os.environ['LOCAL_RANK']) 231 | else: 232 | print('Not using distributed mode') 233 | setup_for_distributed(is_master=True) # hack 234 | args.distributed = False 235 | return 236 | 237 | args.distributed = True 238 | 239 | torch.cuda.set_device(args.gpu) 240 | args.dist_backend = 'nccl' 241 | print('| distributed init (rank {}): {}, gpu {}'.format( 242 | args.rank, args.dist_url, args.gpu), flush=True) 243 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 244 | world_size=args.world_size, rank=args.rank) 245 | torch.distributed.barrier() 246 | setup_for_distributed(args.rank == 0) 247 | 248 | 249 | class NativeScalerWithGradNormCount: 250 | state_dict_key = "amp_scaler" 251 | 252 | def __init__(self): 253 | self._scaler = torch.cuda.amp.GradScaler() 254 | 255 | def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True): 256 | self._scaler.scale(loss).backward(create_graph=create_graph) 257 | if update_grad: 258 | if clip_grad is not None: 259 | assert parameters is not None 260 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place 261 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) 262 | else: 263 | self._scaler.unscale_(optimizer) 264 | norm = get_grad_norm_(parameters) 265 | self._scaler.step(optimizer) 266 | self._scaler.update() 267 | else: 268 | norm = None 269 | return norm 270 | 271 | def state_dict(self): 272 | return self._scaler.state_dict() 273 | 274 | def load_state_dict(self, state_dict): 275 | self._scaler.load_state_dict(state_dict) 276 | 277 | 278 | def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: 279 | if isinstance(parameters, torch.Tensor): 280 | parameters = [parameters] 281 | parameters = [p for p in parameters if p.grad is not None] 282 | norm_type = float(norm_type) 283 | if len(parameters) == 0: 284 | return torch.tensor(0.) 285 | device = parameters[0].grad.device 286 | if norm_type == inf: 287 | total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) 288 | else: 289 | total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type) 290 | return total_norm 291 | 292 | 293 | def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler): 294 | output_dir = Path(args.output_dir) 295 | epoch_name = str(epoch) 296 | if loss_scaler is not None: 297 | checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)] 298 | for checkpoint_path in checkpoint_paths: 299 | to_save = { 300 | 'model': model_without_ddp.state_dict(), 301 | 'optimizer': optimizer.state_dict(), 302 | 'epoch': epoch, 303 | 'scaler': loss_scaler.state_dict(), 304 | 'args': args, 305 | } 306 | 307 | save_on_master(to_save, checkpoint_path) 308 | else: 309 | client_state = {'epoch': epoch} 310 | model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-%s" % epoch_name, client_state=client_state) 311 | 312 | def save_model_latest(args, epoch, model, model_without_ddp, optimizer, loss_scaler): 313 | output_dir = Path(args.output_dir) 314 | epoch_name = str(epoch) 315 | if loss_scaler is not None: 316 | checkpoint_paths = [output_dir / 'checkpoint-latest.pth'] 317 | for checkpoint_path in checkpoint_paths: 318 | to_save = { 319 | 'model': model_without_ddp.state_dict(), 320 | 'optimizer': optimizer.state_dict(), 321 | 'epoch': epoch, 322 | 'scaler': loss_scaler.state_dict(), 323 | 'args': args, 324 | } 325 | 326 | save_on_master(to_save, checkpoint_path) 327 | else: 328 | client_state = {'epoch': epoch} 329 | model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-latest", client_state=client_state) 330 | print("Latest checkpoint saved.") 331 | 332 | def load_model(args, model_without_ddp, optimizer, loss_scaler): 333 | output_dir = Path(args.output_dir) 334 | if args.auto_resume and len(args.resume) == 0: 335 | if os.path.exists(os.path.join(output_dir, 'checkpoint-latest.pth')): 336 | args.resume = os.path.join(output_dir, 'checkpoint-latest.pth') 337 | else: 338 | all_checkpoints = glob.glob(os.path.join(output_dir, 'checkpoint-*.pth')) 339 | latest_ckpt = -1 340 | for ckpt in all_checkpoints: 341 | t = ckpt.split('-')[-1].split('.')[0] 342 | if t.isdigit(): 343 | latest_ckpt = max(int(t), latest_ckpt) 344 | if latest_ckpt >= 0: 345 | args.resume = os.path.join(output_dir, 'checkpoint-%d.pth' % latest_ckpt) 346 | 347 | print("Auto resume checkpoint: %s" % args.resume) 348 | 349 | if args.resume: 350 | if args.resume.startswith('https'): 351 | checkpoint = torch.hub.load_state_dict_from_url( 352 | args.resume, map_location='cpu', check_hash=True) 353 | else: 354 | checkpoint = torch.load(args.resume, map_location='cpu') 355 | model_without_ddp.load_state_dict(checkpoint['model']) 356 | print("Resume checkpoint %s" % args.resume) 357 | if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval): 358 | optimizer.load_state_dict(checkpoint['optimizer']) 359 | args.start_epoch = checkpoint['epoch'] + 1 360 | if 'scaler' in checkpoint: 361 | loss_scaler.load_state_dict(checkpoint['scaler']) 362 | print("With optim & sched!") 363 | 364 | 365 | def all_reduce_mean(x): 366 | world_size = get_world_size() 367 | if world_size > 1: 368 | x_reduce = torch.tensor(x).cuda() 369 | dist.all_reduce(x_reduce) 370 | x_reduce /= world_size 371 | return x_reduce.item() 372 | else: 373 | return x -------------------------------------------------------------------------------- /main_linprobe.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # MoCo v3: https://github.com/facebookresearch/moco-v3 10 | # -------------------------------------------------------- 11 | 12 | import argparse 13 | import datetime 14 | import json 15 | import numpy as np 16 | import os 17 | import time 18 | from pathlib import Path 19 | 20 | import torch 21 | import torch.backends.cudnn as cudnn 22 | from torch.utils.tensorboard import SummaryWriter 23 | import torchvision.transforms as transforms 24 | import torchvision.datasets as datasets 25 | 26 | import timm 27 | 28 | assert timm.__version__ == "0.3.2" # version check 29 | from timm.models.layers import trunc_normal_ 30 | 31 | import util.misc as misc 32 | from util.pos_embed import interpolate_pos_embed 33 | from util.misc import NativeScalerWithGradNormCount as NativeScaler 34 | from util.lars import LARS 35 | from util.crop import RandomResizedCrop 36 | 37 | import models_vit 38 | 39 | from util.datasets import ImageNet1k_JPG 40 | from engine_finetune import train_one_epoch, evaluate 41 | 42 | 43 | def get_args_parser(): 44 | parser = argparse.ArgumentParser('MAE linear probing for image classification', add_help=False) 45 | parser.add_argument('--batch_size', default=512, type=int, 46 | help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus') 47 | parser.add_argument('--epochs', default=90, type=int) 48 | parser.add_argument('--accum_iter', default=1, type=int, 49 | help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)') 50 | 51 | # Model parameters 52 | parser.add_argument('--model', default='vit_large_patch16', type=str, metavar='MODEL', 53 | help='Name of model to train') 54 | 55 | # Optimizer parameters 56 | parser.add_argument('--weight_decay', type=float, default=0, 57 | help='weight decay (default: 0 for linear probe following MoCo v1)') 58 | 59 | parser.add_argument('--lr', type=float, default=None, metavar='LR', 60 | help='learning rate (absolute lr)') 61 | parser.add_argument('--blr', type=float, default=0.1, metavar='LR', 62 | help='base learning rate: absolute_lr = base_lr * total_batch_size / 256') 63 | 64 | parser.add_argument('--min_lr', type=float, default=0., metavar='LR', 65 | help='lower lr bound for cyclic schedulers that hit 0') 66 | 67 | parser.add_argument('--warmup_epochs', type=int, default=10, metavar='N', 68 | help='epochs to warmup LR') 69 | 70 | # * Finetuning params 71 | parser.add_argument('--finetune', default='', 72 | help='finetune from checkpoint') 73 | parser.add_argument('--global_pool', action='store_true') 74 | parser.set_defaults(global_pool=False) 75 | parser.add_argument('--cls_token', action='store_false', dest='global_pool', 76 | help='Use class token instead of global pool for classification') 77 | 78 | # Dataset parameters 79 | parser.add_argument('--data_path', default='/mnt/petrelfs/share/imagenet/images', type=str, 80 | help='dataset path') 81 | parser.add_argument('--nb_classes', default=1000, type=int, 82 | help='number of the classification types') 83 | 84 | parser.add_argument('--output_dir', default='', 85 | help='path where to save, empty for no saving') 86 | parser.add_argument('--log_dir', default='', 87 | help='path where to tensorboard log') 88 | parser.add_argument('--device', default='cuda', 89 | help='device to use for training / testing') 90 | parser.add_argument('--seed', default=0, type=int) 91 | parser.add_argument('--resume', default='', 92 | help='resume from checkpoint') 93 | parser.add_argument('--auto_resume', action='store_true') 94 | parser.set_defaults(auto_resume=True) 95 | 96 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 97 | help='start epoch') 98 | parser.add_argument('--eval', action='store_true', 99 | help='Perform evaluation only') 100 | parser.add_argument('--dist_eval', action='store_true', default=False, 101 | help='Enabling distributed evaluation (recommended during training for faster monitor') 102 | parser.add_argument('--num_workers', default=4, type=int) 103 | parser.add_argument('--pin_mem', action='store_true', 104 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 105 | parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem') 106 | parser.set_defaults(pin_mem=True) 107 | 108 | # distributed training parameters 109 | parser.add_argument('--world_size', default=1, type=int, 110 | help='number of distributed processes') 111 | parser.add_argument('--local_rank', default=-1, type=int) 112 | parser.add_argument('--dist_on_itp', action='store_true') 113 | parser.add_argument('--dist_url', default='env://', 114 | help='url used to set up distributed training') 115 | 116 | return parser 117 | 118 | 119 | def main(args): 120 | misc.init_distributed_mode(args) 121 | 122 | print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__)))) 123 | print("{}".format(args).replace(', ', ',\n')) 124 | 125 | device = torch.device(args.device) 126 | 127 | # fix the seed for reproducibility 128 | seed = args.seed + misc.get_rank() 129 | torch.manual_seed(seed) 130 | np.random.seed(seed) 131 | 132 | cudnn.benchmark = True 133 | 134 | # linear probe: weak augmentation 135 | transform_train = transforms.Compose([ 136 | RandomResizedCrop(224, interpolation=3), 137 | transforms.RandomHorizontalFlip(), 138 | transforms.ToTensor(), 139 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) 140 | transform_val = transforms.Compose([ 141 | transforms.Resize(256, interpolation=3), 142 | transforms.CenterCrop(224), 143 | transforms.ToTensor(), 144 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) 145 | # dataset_train = datasets.ImageFolder(os.path.join(args.data_path, 'train'), transform=transform_train) 146 | # dataset_val = datasets.ImageFolder(os.path.join(args.data_path, 'val'), transform=transform_val) 147 | dataset_train = ImageNet1k_JPG(image_root=os.path.join(args.data_path, 'train'), meta_path=os.path.join(args.data_path, 'meta', 'train.txt'), transform=transform_train) 148 | dataset_val = ImageNet1k_JPG(image_root=os.path.join(args.data_path, 'val'), meta_path=os.path.join(args.data_path, 'meta', 'val.txt'), transform=transform_val) 149 | print(dataset_train) 150 | print(dataset_val) 151 | 152 | if True: # args.distributed: 153 | num_tasks = misc.get_world_size() 154 | global_rank = misc.get_rank() 155 | sampler_train = torch.utils.data.DistributedSampler( 156 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 157 | ) 158 | print("Sampler_train = %s" % str(sampler_train)) 159 | if args.dist_eval: 160 | if len(dataset_val) % num_tasks != 0: 161 | print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. ' 162 | 'This will slightly alter validation results as extra duplicate entries are added to achieve ' 163 | 'equal num of samples per-process.') 164 | sampler_val = torch.utils.data.DistributedSampler( 165 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=True) # shuffle=True to reduce monitor bias 166 | else: 167 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 168 | else: 169 | sampler_train = torch.utils.data.RandomSampler(dataset_train) 170 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 171 | 172 | if global_rank == 0 and args.log_dir is not None and len(args.log_dir) > 0 and not args.eval: 173 | os.makedirs(args.log_dir, exist_ok=True) 174 | log_writer = SummaryWriter(log_dir=args.log_dir) 175 | else: 176 | log_writer = None 177 | 178 | data_loader_train = torch.utils.data.DataLoader( 179 | dataset_train, sampler=sampler_train, 180 | batch_size=args.batch_size, 181 | num_workers=args.num_workers, 182 | pin_memory=args.pin_mem, 183 | drop_last=True, 184 | ) 185 | 186 | data_loader_val = torch.utils.data.DataLoader( 187 | dataset_val, sampler=sampler_val, 188 | batch_size=args.batch_size, 189 | num_workers=args.num_workers, 190 | pin_memory=args.pin_mem, 191 | drop_last=False 192 | ) 193 | 194 | model = models_vit.__dict__[args.model]( 195 | num_classes=args.nb_classes, 196 | global_pool=args.global_pool, 197 | ) 198 | 199 | if args.finetune and not args.eval: 200 | checkpoint = torch.load(args.finetune, map_location='cpu') 201 | 202 | print("Load pre-trained checkpoint from: %s" % args.finetune) 203 | checkpoint_model = checkpoint['model'] 204 | state_dict = model.state_dict() 205 | for k in ['head.weight', 'head.bias']: 206 | if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape: 207 | print(f"Removing key {k} from pretrained checkpoint") 208 | del checkpoint_model[k] 209 | 210 | # interpolate position embedding 211 | interpolate_pos_embed(model, checkpoint_model) 212 | 213 | # load pre-trained model 214 | msg = model.load_state_dict(checkpoint_model, strict=False) 215 | print(msg) 216 | 217 | if args.global_pool: 218 | assert set(msg.missing_keys) == {'head.weight', 'head.bias', 'fc_norm.weight', 'fc_norm.bias'} 219 | else: 220 | assert set(msg.missing_keys) == {'head.weight', 'head.bias'} 221 | 222 | # manually initialize fc layer: following MoCo v3 223 | trunc_normal_(model.head.weight, std=0.01) 224 | 225 | # for linear prob only 226 | # hack: revise model's head with BN 227 | model.head = torch.nn.Sequential(torch.nn.BatchNorm1d(model.head.in_features, affine=False, eps=1e-6), model.head) 228 | # freeze all but the head 229 | for _, p in model.named_parameters(): 230 | p.requires_grad = False 231 | for _, p in model.head.named_parameters(): 232 | p.requires_grad = True 233 | 234 | model.to(device) 235 | 236 | model_without_ddp = model 237 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 238 | 239 | print("Model = %s" % str(model_without_ddp)) 240 | print('number of params (M): %.2f' % (n_parameters / 1.e6)) 241 | 242 | eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size() 243 | 244 | if args.lr is None: # only base_lr is specified 245 | args.lr = args.blr * eff_batch_size / 256 246 | 247 | print("base lr: %.2e" % (args.lr * 256 / eff_batch_size)) 248 | print("actual lr: %.2e" % args.lr) 249 | 250 | print("accumulate grad iterations: %d" % args.accum_iter) 251 | print("effective batch size: %d" % eff_batch_size) 252 | 253 | if args.distributed: 254 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 255 | model_without_ddp = model.module 256 | 257 | optimizer = LARS(model_without_ddp.head.parameters(), lr=args.lr, weight_decay=args.weight_decay) 258 | print(optimizer) 259 | loss_scaler = NativeScaler() 260 | 261 | criterion = torch.nn.CrossEntropyLoss() 262 | 263 | print("criterion = %s" % str(criterion)) 264 | 265 | misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler) 266 | 267 | if args.eval: 268 | test_stats = evaluate(data_loader_val, model, device) 269 | print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") 270 | exit(0) 271 | 272 | print(f"Start training for {args.epochs} epochs") 273 | start_time = time.time() 274 | max_accuracy = 0.0 275 | for epoch in range(args.start_epoch, args.epochs): 276 | if args.distributed: 277 | data_loader_train.sampler.set_epoch(epoch) 278 | train_stats = train_one_epoch( 279 | model, criterion, data_loader_train, 280 | optimizer, device, epoch, loss_scaler, 281 | max_norm=None, 282 | log_writer=log_writer, 283 | args=args 284 | ) 285 | if args.output_dir: 286 | misc.save_model_latest( 287 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, 288 | loss_scaler=loss_scaler, epoch=epoch) 289 | 290 | if args.output_dir and (epoch % 10 == 0 or epoch + 1 == args.epochs): 291 | misc.save_model( 292 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, 293 | loss_scaler=loss_scaler, epoch=epoch) 294 | 295 | test_stats = evaluate(data_loader_val, model, device) 296 | print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") 297 | max_accuracy = max(max_accuracy, test_stats["acc1"]) 298 | print(f'Max accuracy: {max_accuracy:.2f}%') 299 | 300 | if log_writer is not None: 301 | log_writer.add_scalar('perf/test_acc1', test_stats['acc1'], epoch) 302 | log_writer.add_scalar('perf/test_acc5', test_stats['acc5'], epoch) 303 | log_writer.add_scalar('perf/test_loss', test_stats['loss'], epoch) 304 | 305 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 306 | **{f'test_{k}': v for k, v in test_stats.items()}, 307 | 'epoch': epoch, 308 | 'n_parameters': n_parameters} 309 | 310 | if args.output_dir and misc.is_main_process(): 311 | if log_writer is not None: 312 | log_writer.flush() 313 | with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f: 314 | f.write(json.dumps(log_stats) + "\n") 315 | 316 | total_time = time.time() - start_time 317 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 318 | print('Training time {}'.format(total_time_str)) 319 | 320 | 321 | if __name__ == '__main__': 322 | args = get_args_parser() 323 | args = args.parse_args() 324 | if args.output_dir: 325 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 326 | main(args) 327 | -------------------------------------------------------------------------------- /main_finetune.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | 12 | import argparse 13 | import datetime 14 | import json 15 | import numpy as np 16 | import os 17 | import time 18 | from pathlib import Path 19 | 20 | import torch 21 | import torch.backends.cudnn as cudnn 22 | from torch.utils.tensorboard import SummaryWriter 23 | 24 | import timm 25 | 26 | assert timm.__version__ == "0.3.2" # version check 27 | from timm.models.layers import trunc_normal_ 28 | from timm.data.mixup import Mixup 29 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy 30 | 31 | import util.lr_decay as lrd 32 | import util.misc as misc 33 | from util.datasets import build_dataset_jpg 34 | from util.pos_embed import interpolate_pos_embed 35 | from util.misc import NativeScalerWithGradNormCount as NativeScaler 36 | 37 | import models_vit 38 | 39 | from engine_finetune import train_one_epoch, evaluate 40 | 41 | 42 | def get_args_parser(): 43 | parser = argparse.ArgumentParser('MAE fine-tuning for image classification', add_help=False) 44 | parser.add_argument('--batch_size', default=64, type=int, 45 | help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus') 46 | parser.add_argument('--epochs', default=50, type=int) 47 | parser.add_argument('--accum_iter', default=1, type=int, 48 | help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)') 49 | 50 | # Model parameters 51 | parser.add_argument('--model', default='vit_large_patch16', type=str, metavar='MODEL', 52 | help='Name of model to train') 53 | 54 | parser.add_argument('--input_size', default=224, type=int, 55 | help='images input size') 56 | 57 | parser.add_argument('--drop_path', type=float, default=0.1, metavar='PCT', 58 | help='Drop path rate (default: 0.1)') 59 | 60 | # Optimizer parameters 61 | parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM', 62 | help='Clip gradient norm (default: None, no clipping)') 63 | parser.add_argument('--weight_decay', type=float, default=0.05, 64 | help='weight decay (default: 0.05)') 65 | 66 | parser.add_argument('--lr', type=float, default=None, metavar='LR', 67 | help='learning rate (absolute lr)') 68 | parser.add_argument('--blr', type=float, default=1e-3, metavar='LR', 69 | help='base learning rate: absolute_lr = base_lr * total_batch_size / 256') 70 | parser.add_argument('--layer_decay', type=float, default=0.75, 71 | help='layer-wise lr decay from ELECTRA/BEiT') 72 | 73 | parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR', 74 | help='lower lr bound for cyclic schedulers that hit 0') 75 | 76 | parser.add_argument('--warmup_epochs', type=int, default=5, metavar='N', 77 | help='epochs to warmup LR') 78 | 79 | # Augmentation parameters 80 | parser.add_argument('--color_jitter', type=float, default=None, metavar='PCT', 81 | help='Color jitter factor (enabled only when not using Auto/RandAug)') 82 | parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME', 83 | help='Use AutoAugment policy. "v0" or "original". " + "(default: rand-m9-mstd0.5-inc1)'), 84 | parser.add_argument('--smoothing', type=float, default=0.1, 85 | help='Label smoothing (default: 0.1)') 86 | 87 | # * Random Erase params 88 | parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT', 89 | help='Random erase prob (default: 0.25)') 90 | parser.add_argument('--remode', type=str, default='pixel', 91 | help='Random erase mode (default: "pixel")') 92 | parser.add_argument('--recount', type=int, default=1, 93 | help='Random erase count (default: 1)') 94 | parser.add_argument('--resplit', action='store_true', default=False, 95 | help='Do not random erase first (clean) augmentation split') 96 | 97 | # * Mixup params 98 | parser.add_argument('--mixup', type=float, default=0, 99 | help='mixup alpha, mixup enabled if > 0.') 100 | parser.add_argument('--cutmix', type=float, default=0, 101 | help='cutmix alpha, cutmix enabled if > 0.') 102 | parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None, 103 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)') 104 | parser.add_argument('--mixup_prob', type=float, default=1.0, 105 | help='Probability of performing mixup or cutmix when either/both is enabled') 106 | parser.add_argument('--mixup_switch_prob', type=float, default=0.5, 107 | help='Probability of switching to cutmix when both mixup and cutmix enabled') 108 | parser.add_argument('--mixup_mode', type=str, default='batch', 109 | help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"') 110 | 111 | # * Finetuning params 112 | parser.add_argument('--finetune', default='', 113 | help='finetune from checkpoint') 114 | parser.add_argument('--global_pool', action='store_true') 115 | parser.set_defaults(global_pool=True) 116 | parser.add_argument('--cls_token', action='store_false', dest='global_pool', 117 | help='Use class token instead of global pool for classification') 118 | 119 | # Dataset parameters 120 | parser.add_argument('--data_path', default='/mnt/petrelfs/share/imagenet/images', type=str, 121 | help='dataset path') 122 | parser.add_argument('--nb_classes', default=1000, type=int, 123 | help='number of the classification types') 124 | 125 | parser.add_argument('--output_dir', default='', 126 | help='path where to save, empty for no saving') 127 | parser.add_argument('--log_dir', default='', 128 | help='path where to tensorboard log') 129 | parser.add_argument('--device', default='cuda', 130 | help='device to use for training / testing') 131 | parser.add_argument('--seed', default=0, type=int) 132 | parser.add_argument('--resume', default='', 133 | help='resume from checkpoint') 134 | parser.add_argument('--auto_resume', action='store_true') 135 | parser.set_defaults(auto_resume=True) 136 | 137 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 138 | help='start epoch') 139 | parser.add_argument('--eval', action='store_true', 140 | help='Perform evaluation only') 141 | parser.add_argument('--dist_eval', action='store_true', default=False, 142 | help='Enabling distributed evaluation (recommended during training for faster monitor') 143 | parser.add_argument('--num_workers', default=10, type=int) 144 | parser.add_argument('--pin_mem', action='store_true', 145 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 146 | parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem') 147 | parser.set_defaults(pin_mem=True) 148 | 149 | # distributed training parameters 150 | parser.add_argument('--world_size', default=1, type=int, 151 | help='number of distributed processes') 152 | parser.add_argument('--local_rank', default=-1, type=int) 153 | parser.add_argument('--dist_on_itp', action='store_true') 154 | parser.add_argument('--dist_url', default='env://', 155 | help='url used to set up distributed training') 156 | 157 | return parser 158 | 159 | 160 | def main(args): 161 | misc.init_distributed_mode(args) 162 | 163 | print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__)))) 164 | print("{}".format(args).replace(', ', ',\n')) 165 | 166 | device = torch.device(args.device) 167 | 168 | # fix the seed for reproducibility 169 | seed = args.seed + misc.get_rank() 170 | torch.manual_seed(seed) 171 | np.random.seed(seed) 172 | 173 | cudnn.benchmark = True 174 | 175 | dataset_train = build_dataset_jpg(is_train=True, args=args) 176 | dataset_val = build_dataset_jpg(is_train=False, args=args) 177 | 178 | if True: # args.distributed: 179 | num_tasks = misc.get_world_size() 180 | global_rank = misc.get_rank() 181 | sampler_train = torch.utils.data.DistributedSampler( 182 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 183 | ) 184 | print("Sampler_train = %s" % str(sampler_train)) 185 | if args.dist_eval: 186 | if len(dataset_val) % num_tasks != 0: 187 | print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. ' 188 | 'This will slightly alter validation results as extra duplicate entries are added to achieve ' 189 | 'equal num of samples per-process.') 190 | sampler_val = torch.utils.data.DistributedSampler( 191 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=True) # shuffle=True to reduce monitor bias 192 | else: 193 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 194 | else: 195 | sampler_train = torch.utils.data.RandomSampler(dataset_train) 196 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 197 | 198 | if global_rank == 0 and args.log_dir is not None and len(args.log_dir) > 0 and not args.eval: 199 | os.makedirs(args.log_dir, exist_ok=True) 200 | log_writer = SummaryWriter(log_dir=args.log_dir) 201 | else: 202 | log_writer = None 203 | 204 | data_loader_train = torch.utils.data.DataLoader( 205 | dataset_train, sampler=sampler_train, 206 | batch_size=args.batch_size, 207 | num_workers=args.num_workers, 208 | pin_memory=args.pin_mem, 209 | drop_last=True, 210 | ) 211 | 212 | data_loader_val = torch.utils.data.DataLoader( 213 | dataset_val, sampler=sampler_val, 214 | batch_size=args.batch_size, 215 | num_workers=args.num_workers, 216 | pin_memory=args.pin_mem, 217 | drop_last=False 218 | ) 219 | 220 | mixup_fn = None 221 | mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None 222 | if mixup_active: 223 | print("Mixup is activated!") 224 | mixup_fn = Mixup( 225 | mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, 226 | prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, 227 | label_smoothing=args.smoothing, num_classes=args.nb_classes) 228 | 229 | model = models_vit.__dict__[args.model]( 230 | num_classes=args.nb_classes, 231 | drop_path_rate=args.drop_path, 232 | global_pool=args.global_pool, 233 | ) 234 | 235 | if args.finetune and not args.eval: 236 | checkpoint = torch.load(args.finetune, map_location='cpu') 237 | 238 | print("Load pre-trained checkpoint from: %s" % args.finetune) 239 | checkpoint_model = checkpoint['model'] 240 | state_dict = model.state_dict() 241 | for k in ['head.weight', 'head.bias']: 242 | if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape: 243 | print(f"Removing key {k} from pretrained checkpoint") 244 | del checkpoint_model[k] 245 | 246 | # interpolate position embedding 247 | interpolate_pos_embed(model, checkpoint_model) 248 | 249 | # load pre-trained model 250 | msg = model.load_state_dict(checkpoint_model, strict=False) 251 | print(msg) 252 | 253 | if args.global_pool: 254 | assert set(msg.missing_keys) == {'head.weight', 'head.bias', 'fc_norm.weight', 'fc_norm.bias'} 255 | else: 256 | assert set(msg.missing_keys) == {'head.weight', 'head.bias'} 257 | 258 | # manually initialize fc layer 259 | trunc_normal_(model.head.weight, std=2e-5) 260 | 261 | model.to(device) 262 | 263 | model_without_ddp = model 264 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 265 | 266 | print("Model = %s" % str(model_without_ddp)) 267 | print('number of params (M): %.2f' % (n_parameters / 1.e6)) 268 | 269 | eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size() 270 | 271 | if args.lr is None: # only base_lr is specified 272 | args.lr = args.blr * eff_batch_size / 256 273 | 274 | print("base lr: %.2e" % (args.lr * 256 / eff_batch_size)) 275 | print("actual lr: %.2e" % args.lr) 276 | 277 | print("accumulate grad iterations: %d" % args.accum_iter) 278 | print("effective batch size: %d" % eff_batch_size) 279 | 280 | if args.distributed: 281 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 282 | model_without_ddp = model.module 283 | 284 | # build optimizer with layer-wise lr decay (lrd) 285 | param_groups = lrd.param_groups_lrd(model_without_ddp, args.weight_decay, 286 | no_weight_decay_list=model_without_ddp.no_weight_decay(), 287 | layer_decay=args.layer_decay 288 | ) 289 | optimizer = torch.optim.AdamW(param_groups, lr=args.lr) 290 | loss_scaler = NativeScaler() 291 | 292 | if mixup_fn is not None: 293 | # smoothing is handled with mixup label transform 294 | criterion = SoftTargetCrossEntropy() 295 | elif args.smoothing > 0.: 296 | criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing) 297 | else: 298 | criterion = torch.nn.CrossEntropyLoss() 299 | 300 | print("criterion = %s" % str(criterion)) 301 | 302 | misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler) 303 | 304 | if args.eval: 305 | test_stats = evaluate(data_loader_val, model, device) 306 | print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") 307 | exit(0) 308 | 309 | print(f"Start training for {args.epochs} epochs") 310 | start_time = time.time() 311 | max_accuracy = 0.0 312 | for epoch in range(args.start_epoch, args.epochs): 313 | if args.distributed: 314 | data_loader_train.sampler.set_epoch(epoch) 315 | train_stats = train_one_epoch( 316 | model, criterion, data_loader_train, 317 | optimizer, device, epoch, loss_scaler, 318 | args.clip_grad, mixup_fn, 319 | log_writer=log_writer, 320 | args=args 321 | ) 322 | if args.output_dir: 323 | misc.save_model_latest( 324 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, 325 | loss_scaler=loss_scaler, epoch=epoch) 326 | 327 | if args.output_dir and (epoch % 99 == 0 or epoch + 1 == args.epochs): 328 | misc.save_model( 329 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, 330 | loss_scaler=loss_scaler, epoch=epoch) 331 | 332 | test_stats = evaluate(data_loader_val, model, device) 333 | print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") 334 | max_accuracy = max(max_accuracy, test_stats["acc1"]) 335 | print(f'Max accuracy: {max_accuracy:.2f}%') 336 | 337 | if log_writer is not None: 338 | log_writer.add_scalar('perf/test_acc1', test_stats['acc1'], epoch) 339 | log_writer.add_scalar('perf/test_acc5', test_stats['acc5'], epoch) 340 | log_writer.add_scalar('perf/test_loss', test_stats['loss'], epoch) 341 | 342 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 343 | **{f'test_{k}': v for k, v in test_stats.items()}, 344 | 'epoch': epoch, 345 | 'n_parameters': n_parameters} 346 | 347 | if args.output_dir and misc.is_main_process(): 348 | if log_writer is not None: 349 | log_writer.flush() 350 | with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f: 351 | f.write(json.dumps(log_stats) + "\n") 352 | 353 | total_time = time.time() - start_time 354 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 355 | print('Training time {}'.format(total_time_str)) 356 | 357 | 358 | if __name__ == '__main__': 359 | args = get_args_parser() 360 | args = args.parse_args() 361 | if args.output_dir: 362 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 363 | main(args) 364 | -------------------------------------------------------------------------------- /models_pretrain.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm 9 | # DeiT: https://github.com/facebookresearch/deit 10 | # -------------------------------------------------------- 11 | 12 | from functools import partial 13 | 14 | import torch 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | from timm.models.vision_transformer import PatchEmbed, Block 18 | 19 | from util.pos_embed import get_2d_sincos_pos_embed 20 | from transformers import CLIPVisionModel, ViTModel 21 | import pdb 22 | 23 | 24 | def resize_pos_embed(x): 25 | # [256, C] -> [196, C] 26 | C = x.shape[-1] 27 | x = x.reshape(1, 16, 16, C).permute(0, 3, 1, 2) 28 | x = F.interpolate(x, (14, 14), mode='bicubic', align_corners=False) 29 | x = x.permute(0, 2, 3, 1).reshape(196, C) 30 | return x 31 | 32 | 33 | class MaskedAutoencoderViT(nn.Module): 34 | """ Masked Autoencoder with VisionTransformer backbone 35 | """ 36 | def __init__(self, img_size=224, patch_size=16, in_chans=3, 37 | embed_dim=1024, depth=24, num_heads=16, drop_path_rate=0., 38 | mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False, 39 | loss_weights="mean", mask_type="random", fusion_type="simple", target_norm="none", loss_type="l2", 40 | head_type="linear", teacher_model="openai/clip-vit-base-patch16"): 41 | super().__init__() 42 | 43 | assert loss_weights in ["mean", "out", "linear_decay"] or "top" in loss_weights or "mid" in loss_weights 44 | self.loss_weights = loss_weights 45 | assert mask_type in ["random", "attention"] 46 | self.mask_type = mask_type 47 | assert fusion_type in ["simple", "linear", "sum"] 48 | self.fusion_type = fusion_type 49 | assert target_norm in ["none", "l2", "whiten", "bn"] 50 | self.target_norm = target_norm 51 | assert loss_type in ["l2", "l1", "smoothl1"] 52 | self.loss_type = loss_type 53 | assert head_type in ["linear", "norm_linear", "mlp", "mlp2"] 54 | self.head_type= head_type 55 | # assert "clip" in teacher_model or "dino" in teacher_model 56 | self.teacher_model_name = teacher_model 57 | 58 | # -------------------------------------------------------------------------- 59 | # MAE encoder specifics 60 | self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim) 61 | num_patches = self.patch_embed.num_patches 62 | 63 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 64 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False) # fixed sin-cos embedding 65 | 66 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 67 | self.blocks = nn.ModuleList([ 68 | Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer, drop_path=dpr[i]) 69 | for i in range(depth)]) 70 | self.norm = norm_layer(embed_dim) 71 | 72 | if "clip-vit-base-patch16" in self.teacher_model_name or "dino-vitb16" in self.teacher_model_name: 73 | target_dim = 768 74 | teacher_depth = 12 75 | else: 76 | target_dim = 1024 77 | teacher_depth = 24 78 | 79 | if self.head_type == "linear": 80 | self.distill_heads = nn.ModuleList([nn.Linear(embed_dim, target_dim) for i in range(teacher_depth)]) 81 | elif self.head_type == "norm_linear": 82 | self.distill_heads = nn.ModuleList([nn.Sequential( 83 | norm_layer(embed_dim), 84 | nn.Linear(embed_dim, target_dim) 85 | ) 86 | for i in range(teacher_depth)]) 87 | elif self.head_type == "mlp": 88 | self.distill_heads = nn.ModuleList([nn.Sequential( 89 | nn.Linear(embed_dim, embed_dim), 90 | nn.GELU(), 91 | nn.Linear(embed_dim, target_dim) 92 | ) 93 | for i in range(teacher_depth)]) 94 | elif self.head_type == "mlp2": 95 | self.distill_heads = nn.ModuleList([nn.Sequential( 96 | nn.Linear(embed_dim, embed_dim), 97 | norm_layer(embed_dim), 98 | nn.Linear(embed_dim, target_dim) 99 | ) 100 | for i in range(teacher_depth)]) 101 | 102 | if self.fusion_type == "linear": 103 | # only len(student) == len(teacher) 104 | self.distill_weights = nn.Parameter(torch.eye(len(self.blocks)) + 0.01, requires_grad=True) 105 | elif self.fusion_type == "sum": 106 | self.distill_weights = nn.Parameter(torch.ones(teacher_depth, len(self.blocks)) / len(self.blocks), requires_grad=True) 107 | 108 | self.initialize_weights() 109 | 110 | if "clip" in self.teacher_model_name: 111 | self.clip_model = CLIPVisionModel.from_pretrained(self.teacher_model_name) 112 | for name, param in self.clip_model.named_parameters(): 113 | param.requires_grad = False 114 | if "clip-vit-large-patch14" in self.teacher_model_name and "position_embedding" in name: 115 | param.data = torch.cat([param.data[:1], resize_pos_embed(param.data[1:])], dim=0) 116 | if "clip-vit-large-patch14" in self.teacher_model_name: 117 | self.clip_model.vision_model.embeddings.position_ids = torch.arange(197).expand((1, -1)) 118 | 119 | elif "dino" in self.teacher_model_name: 120 | self.dino_model = ViTModel.from_pretrained(self.teacher_model_name) 121 | for param in self.dino_model.parameters(): 122 | param.requires_grad = False 123 | 124 | def initialize_weights(self): 125 | # initialization 126 | # initialize (and freeze) pos_embed by sin-cos embedding 127 | pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True) 128 | self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) 129 | 130 | # initialize patch_embed like nn.Linear (instead of nn.Conv2d) 131 | w = self.patch_embed.proj.weight.data 132 | torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) 133 | 134 | # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.) 135 | torch.nn.init.normal_(self.cls_token, std=.02) 136 | # torch.nn.init.normal_(self.mask_token, std=.02) 137 | 138 | # initialize nn.Linear and nn.LayerNorm 139 | self.apply(self._init_weights) 140 | 141 | def _init_weights(self, m): 142 | if isinstance(m, nn.Linear): 143 | # we use xavier_uniform following official JAX ViT: 144 | torch.nn.init.xavier_uniform_(m.weight) 145 | if isinstance(m, nn.Linear) and m.bias is not None: 146 | nn.init.constant_(m.bias, 0) 147 | elif isinstance(m, nn.LayerNorm): 148 | nn.init.constant_(m.bias, 0) 149 | nn.init.constant_(m.weight, 1.0) 150 | 151 | def denormalize(self, images, type="imagenet"): 152 | # sr_images [B, 3, H, W] 153 | mean = torch.tensor([0.485, 0.456, 0.406], device=images.device).view(1, 3, 1, 1).type_as(images) 154 | std = torch.tensor([0.229, 0.224, 0.225], device=images.device).view(1, 3, 1, 1).type_as(images) 155 | return std*images + mean 156 | 157 | def normalize(self, images, type="clip"): 158 | # images [B, 3, h, w] 159 | mean = torch.tensor([0.48145466, 0.4578275, 0.40821073], device=images.device).view(1, 3, 1, 1).type_as(images) 160 | std = torch.tensor([0.26862954, 0.26130258, 0.27577711], device=images.device).view(1, 3, 1, 1).type_as(images) 161 | return (images - mean) / std 162 | 163 | def patchify(self, imgs): 164 | """ 165 | imgs: (N, 3, H, W) 166 | x: (N, L, patch_size**2 *3) 167 | """ 168 | p = self.patch_embed.patch_size[0] 169 | assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0 170 | 171 | h = w = imgs.shape[2] // p 172 | x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p)) 173 | x = torch.einsum('nchpwq->nhwpqc', x) 174 | x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3)) 175 | return x 176 | 177 | def unpatchify(self, x): 178 | """ 179 | x: (N, L, patch_size**2 *3) 180 | imgs: (N, 3, H, W) 181 | """ 182 | p = self.patch_embed.patch_size[0] 183 | h = w = int(x.shape[1]**.5) 184 | assert h * w == x.shape[1] 185 | 186 | x = x.reshape(shape=(x.shape[0], h, w, p, p, 3)) 187 | x = torch.einsum('nhwpqc->nchpwq', x) 188 | imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p)) 189 | return imgs 190 | 191 | def random_masking(self, x, mask_ratio): 192 | """ 193 | Perform per-sample random masking by per-sample shuffling. 194 | Per-sample shuffling is done by argsort random noise. 195 | x: [N, L, D], sequence 196 | """ 197 | N, L, D = x.shape # batch, length, dim 198 | len_keep = int(L * (1 - mask_ratio)) 199 | 200 | noise = torch.rand(N, L, device=x.device) # noise in [0, 1] 201 | 202 | # sort noise for each sample 203 | ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove 204 | ids_restore = torch.argsort(ids_shuffle, dim=1) 205 | 206 | # keep the first subset 207 | ids_keep = ids_shuffle[:, :len_keep] 208 | x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) 209 | 210 | # generate the binary mask: 0 is keep, 1 is remove 211 | mask = torch.ones([N, L], device=x.device) 212 | mask[:, :len_keep] = 0 213 | # unshuffle to get the binary mask 214 | mask = torch.gather(mask, dim=1, index=ids_restore) 215 | 216 | return x_masked, ids_keep 217 | 218 | def attention_masking(self, x, mask_ratio, importance): 219 | """ 220 | Perform per-sample random masking by per-sample shuffling. 221 | Per-sample shuffling is done by argsort random noise. 222 | x: [N, L, D], sequence 223 | """ 224 | N, L, D = x.shape # batch, length, dim 225 | len_keep = int(L * (1 - mask_ratio)) 226 | 227 | noise = importance.to(x.device) # large is keep, small is remove 228 | 229 | # sort noise for each sample 230 | ids_shuffle = torch.multinomial(noise, L, replacement=False) 231 | ids_restore = torch.argsort(ids_shuffle, dim=1) 232 | 233 | # keep the first subset 234 | ids_keep = ids_shuffle[:, :len_keep] 235 | ids_dump = ids_shuffle[:, len_keep:] 236 | x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) 237 | 238 | # generate the binary mask: 0 is keep, 1 is remove 239 | mask = torch.ones([N, L], device=x.device) 240 | mask[:, :len_keep] = 0 241 | # unshuffle to get the binary mask 242 | mask = torch.gather(mask, dim=1, index=ids_restore) 243 | 244 | return x_masked, ids_keep 245 | 246 | def forward_encoder(self, x, mask_ratio, attentions): 247 | # embed patches 248 | x = self.patch_embed(x) 249 | 250 | # add pos embed w/o cls token 251 | x = x + self.pos_embed[:, 1:, :] 252 | 253 | # masking: length -> length * mask_ratio 254 | if self.mask_type == "attention": 255 | importance = attentions[-1][:, :, 0, 1:].mean(1) 256 | x, ids_keep = self.attention_masking(x, mask_ratio, importance) 257 | else: 258 | x, ids_keep = self.random_masking(x, mask_ratio) 259 | 260 | cls_token = self.cls_token + self.pos_embed[:, :1, :] 261 | cls_tokens = cls_token.expand(x.shape[0], -1, -1) 262 | x = torch.cat((cls_tokens, x), dim=1) 263 | 264 | hidden_states = [] 265 | # apply Transformer blocks 266 | for blk in self.blocks: 267 | x = blk(x) 268 | hidden_states.append(x) 269 | x = self.norm(x) 270 | 271 | return hidden_states, ids_keep 272 | 273 | @torch.no_grad() 274 | def forward_clip(self, x): 275 | if "clip-vit-large-patch14" in self.teacher_model_name: 276 | x = F.interpolate(x, (196, 196), mode='bicubic', align_corners=False) 277 | 278 | x = self.normalize(self.denormalize(x)) 279 | input = { 280 | "pixel_values": x, 281 | "output_hidden_states": True, 282 | "output_attentions": True 283 | } 284 | outputs = self.clip_model(**input) 285 | 286 | last_hidden_state, pooler_output, hidden_states, attentions = outputs[0], outputs[1], outputs[2], outputs[3] 287 | return last_hidden_state, pooler_output, hidden_states, attentions 288 | 289 | @torch.no_grad() 290 | def forward_dino(self, x): 291 | input = { 292 | "pixel_values": x, 293 | "output_hidden_states": True, 294 | "output_attentions": True 295 | } 296 | outputs = self.dino_model(**input) 297 | 298 | last_hidden_state, pooler_output, hidden_states, attentions = outputs[0], outputs[1], outputs[2], outputs[3] 299 | return last_hidden_state, pooler_output, hidden_states, attentions 300 | 301 | 302 | def get_student(self, hidden_states): 303 | student = hidden_states 304 | if self.fusion_type != "simple": 305 | student = [x.unsqueeze(0) for x in student] 306 | student = torch.cat(student, dim=0) 307 | student = torch.einsum('ab,bcde->acde', self.distill_weights, student) 308 | student = torch.chunk(student, student.shape[0], dim=0) 309 | student = [x.squeeze(0) for x in student] 310 | student = [self.distill_heads[i](x) for i, x in enumerate(student)] 311 | return student 312 | 313 | def get_teacher(self, hidden_states, ids_keep): 314 | teacher = [] 315 | for i in range(1, len(hidden_states)): 316 | y = hidden_states[i] 317 | if self.target_norm == "l2": 318 | y = F.normalize(y, dim=-1) 319 | elif self.target_norm == "whiten": 320 | y = F.layer_norm(y, (y.shape[-1],)) 321 | elif self.target_norm == "bn": 322 | y = (y - y.mean()) / (y.var() + 1.e-6)**.5 323 | cls = y[:, :1, :] 324 | y = y[:, 1:, :] 325 | y = torch.gather(y, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, y.shape[-1])) 326 | teacher.append(torch.cat([cls, y], dim=1)) 327 | return teacher 328 | 329 | def forward_loss(self, student, teacher): 330 | """ 331 | student: ([B*4, L//4, C]...) 332 | teacher: ([B, 1+L, C]...) 333 | ids_shuffle: [B, L] 334 | """ 335 | loss = torch.tensor(0., device=student[0].device) 336 | 337 | if self.loss_weights == "mean": 338 | weight_list = [1/len(student)]*len(student) 339 | elif self.loss_weights == "out": 340 | weight_list = [0.]*(len(student)-1) + [1.] 341 | elif self.loss_weights == "linear_decay": 342 | weight_list_ = list(range(len(student))) 343 | weight_list = [i / sum(weight_list_) for i in weight_list_] 344 | elif "top" in self.loss_weights: # topk 345 | topk = int(self.loss_weights[3:]) 346 | weight_list = [0.] * (len(student)-topk) + [1/topk] * topk 347 | elif "mid" in self.loss_weights: 348 | mid = int(self.loss_weights[3:]) 349 | weight_list = [0.] * mid + [1.] + [0.] * (len(student) - mid - 1) 350 | 351 | for i, x in enumerate(student): 352 | y = teacher[i] 353 | if weight_list[i] > 0: 354 | if self.loss_type == "l2": 355 | loss = loss + weight_list[i] * ((y - x) ** 2).mean() 356 | elif self.loss_type == "smoothl1": 357 | loss = loss + weight_list[i] * 2 * F.smooth_l1_loss(y, x) 358 | elif self.loss_type == "l1": 359 | loss = loss + weight_list[i] * F.l1_loss(y, x) 360 | return loss 361 | 362 | def forward(self, imgs, mask_ratio=0.75): 363 | if "clip" in self.teacher_model_name: 364 | _, _, hidden_states_teacher, attentions = self.forward_clip(imgs) 365 | elif "dino" in self.teacher_model_name: 366 | _, _, hidden_states_teacher, attentions = self.forward_dino(imgs) 367 | hidden_states, ids_keep = self.forward_encoder(imgs, mask_ratio, attentions) 368 | student = self.get_student(hidden_states) 369 | teacher = self.get_teacher(hidden_states_teacher, ids_keep) 370 | loss = self.forward_loss(student, teacher) 371 | return loss 372 | 373 | 374 | def mae_vit_base_patch16(**kwargs): 375 | model = MaskedAutoencoderViT( 376 | patch_size=16, embed_dim=768, depth=12, num_heads=12, 377 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 378 | return model 379 | 380 | 381 | def mae_vit_large_patch16(**kwargs): 382 | model = MaskedAutoencoderViT( 383 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, 384 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 385 | return model 386 | 387 | 388 | 389 | # set recommended archs 390 | mae_vit_base_patch16 = mae_vit_base_patch16 391 | mae_vit_large_patch16 = mae_vit_large_patch16 392 | --------------------------------------------------------------------------------