├── .gitignore ├── requirements.txt ├── util ├── lr_sched.py ├── crop.py ├── lars.py ├── datasets.py ├── lr_decay.py ├── pos_embed.py └── misc.py ├── models_vit.py ├── engine_pretrain.py ├── CODE_OF_CONDUCT.md ├── submitit_pretrain.py ├── submitit_finetune.py ├── submitit_linprobe.py ├── engine_finetune.py ├── README.md ├── transformer_utils.py ├── models_mae.py ├── main_pretrain.py ├── main_linprobe.py ├── models_cross.py ├── main_finetune.py └── LICENSE /.gitignore: -------------------------------------------------------------------------------- 1 | output_dir/* 2 | *.yaml 3 | .idea/* 4 | *__pycache__* 5 | util/__pycache__/* 6 | *_ignored* 7 | output 8 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | --extra-index-url https://download.pytorch.org/whl/cu118 2 | torch==2.0.0 3 | torchvision 4 | flash-attn==2.3.1.post1 5 | timm==0.9.7 6 | scipy 7 | matplotlib 8 | tensorboard 9 | packaging 10 | -------------------------------------------------------------------------------- /util/lr_sched.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | 9 | def adjust_learning_rate(optimizer, epoch, args): 10 | """Decay the learning rate with half-cycle cosine after warmup""" 11 | if epoch < args.warmup_epochs: 12 | lr = args.lr * epoch / args.warmup_epochs 13 | else: 14 | lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \ 15 | (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs))) 16 | for param_group in optimizer.param_groups: 17 | if "lr_scale" in param_group: 18 | param_group["lr"] = lr * param_group["lr_scale"] 19 | else: 20 | param_group["lr"] = lr 21 | return lr 22 | -------------------------------------------------------------------------------- /util/crop.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | 9 | import torch 10 | torch_version = torch.__version__ 11 | is_torch2 = torch_version.startswith('2.') 12 | 13 | from torchvision import transforms 14 | from torchvision.transforms import functional as F 15 | 16 | 17 | class RandomResizedCrop(transforms.RandomResizedCrop): 18 | """ 19 | RandomResizedCrop for matching TF/TPU implementation: no for-loop is used. 20 | This may lead to results different with torchvision's version. 
21 | Following BYOL's TF code: 22 | https://github.com/deepmind/deepmind-research/blob/master/byol/utils/dataset.py#L206 23 | """ 24 | @staticmethod 25 | def get_params(img, scale, ratio): 26 | if is_torch2: 27 | width, height = F.get_image_size(img) 28 | else: 29 | width, height = F._get_image_size(img) 30 | area = height * width 31 | 32 | target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item() 33 | log_ratio = torch.log(torch.tensor(ratio)) 34 | aspect_ratio = torch.exp( 35 | torch.empty(1).uniform_(log_ratio[0], log_ratio[1]) 36 | ).item() 37 | 38 | w = int(round(math.sqrt(target_area * aspect_ratio))) 39 | h = int(round(math.sqrt(target_area / aspect_ratio))) 40 | 41 | w = min(w, width) 42 | h = min(h, height) 43 | 44 | i = torch.randint(0, height - h + 1, size=(1,)).item() 45 | j = torch.randint(0, width - w + 1, size=(1,)).item() 46 | 47 | return i, j, h, w -------------------------------------------------------------------------------- /util/lars.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # LARS optimizer, implementation from MoCo v3: 8 | # https://github.com/facebookresearch/moco-v3 9 | # -------------------------------------------------------- 10 | 11 | import torch 12 | 13 | 14 | class LARS(torch.optim.Optimizer): 15 | """ 16 | LARS optimizer, no rate scaling or weight decay for parameters <= 1D. 17 | """ 18 | def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, trust_coefficient=0.001): 19 | defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, trust_coefficient=trust_coefficient) 20 | super().__init__(params, defaults) 21 | 22 | @torch.no_grad() 23 | def step(self): 24 | for g in self.param_groups: 25 | for p in g['params']: 26 | dp = p.grad 27 | 28 | if dp is None: 29 | continue 30 | 31 | if p.ndim > 1: # if not normalization gamma/beta or bias 32 | dp = dp.add(p, alpha=g['weight_decay']) 33 | param_norm = torch.norm(p) 34 | update_norm = torch.norm(dp) 35 | one = torch.ones_like(param_norm) 36 | q = torch.where(param_norm > 0., 37 | torch.where(update_norm > 0, 38 | (g['trust_coefficient'] * param_norm / update_norm), one), 39 | one) 40 | dp = dp.mul(q) 41 | 42 | param_state = self.state[p] 43 | if 'mu' not in param_state: 44 | param_state['mu'] = torch.zeros_like(p) 45 | mu = param_state['mu'] 46 | mu.mul_(g['momentum']).add_(dp) 47 | p.add_(mu, alpha=-g['lr']) -------------------------------------------------------------------------------- /util/datasets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # -------------------------------------------------------- 10 | 11 | import os 12 | import PIL 13 | 14 | from torchvision import datasets, transforms 15 | 16 | from timm.data import create_transform 17 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 18 | 19 | 20 | def build_dataset(is_train, args): 21 | transform = build_transform(is_train, args) 22 | 23 | root = os.path.join(args.data_path, 'train' if is_train else 'val') 24 | dataset = datasets.ImageFolder(root, transform=transform) 25 | 26 | print(dataset) 27 | 28 | return dataset 29 | 30 | 31 | def build_transform(is_train, args): 32 | mean = IMAGENET_DEFAULT_MEAN 33 | std = IMAGENET_DEFAULT_STD 34 | # train transform 35 | if is_train: 36 | # this should always dispatch to transforms_imagenet_train 37 | transform = create_transform( 38 | input_size=args.input_size, 39 | is_training=True, 40 | color_jitter=args.color_jitter, 41 | auto_augment=args.aa, 42 | interpolation='bicubic', 43 | re_prob=args.reprob, 44 | re_mode=args.remode, 45 | re_count=args.recount, 46 | mean=mean, 47 | std=std, 48 | ) 49 | return transform 50 | 51 | # eval transform 52 | t = [] 53 | if args.input_size <= 224: 54 | crop_pct = 224 / 256 55 | else: 56 | crop_pct = 1.0 57 | size = int(args.input_size / crop_pct) 58 | t.append( 59 | transforms.Resize(size, interpolation=PIL.Image.BICUBIC), # to maintain same ratio w.r.t. 224 images 60 | ) 61 | t.append(transforms.CenterCrop(args.input_size)) 62 | 63 | t.append(transforms.ToTensor()) 64 | t.append(transforms.Normalize(mean, std)) 65 | return transforms.Compose(t) 66 | -------------------------------------------------------------------------------- /util/lr_decay.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # ELECTRA https://github.com/google-research/electra 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | 12 | import json 13 | 14 | 15 | def param_groups_lrd(model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=.75): 16 | """ 17 | Parameter groups for layer-wise lr decay 18 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58 19 | """ 20 | param_group_names = {} 21 | param_groups = {} 22 | 23 | num_layers = len(model.blocks) + 1 24 | 25 | layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1)) 26 | 27 | for n, p in model.named_parameters(): 28 | if not p.requires_grad: 29 | continue 30 | 31 | # no decay: all 1D parameters and model specific ones 32 | if p.ndim == 1 or n in no_weight_decay_list: 33 | g_decay = "no_decay" 34 | this_decay = 0. 
35 | else: 36 | g_decay = "decay" 37 | this_decay = weight_decay 38 | 39 | layer_id = get_layer_id_for_vit(n, num_layers) 40 | group_name = "layer_%d_%s" % (layer_id, g_decay) 41 | 42 | if group_name not in param_group_names: 43 | this_scale = layer_scales[layer_id] 44 | 45 | param_group_names[group_name] = { 46 | "lr_scale": this_scale, 47 | "weight_decay": this_decay, 48 | "params": [], 49 | } 50 | param_groups[group_name] = { 51 | "lr_scale": this_scale, 52 | "weight_decay": this_decay, 53 | "params": [], 54 | } 55 | 56 | param_group_names[group_name]["params"].append(n) 57 | param_groups[group_name]["params"].append(p) 58 | 59 | # print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2)) 60 | 61 | return list(param_groups.values()) 62 | 63 | 64 | def get_layer_id_for_vit(name, num_layers): 65 | """ 66 | Assign a parameter with its layer id 67 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33 68 | """ 69 | if name in ['cls_token', 'pos_embed']: 70 | return 0 71 | elif name.startswith('patch_embed'): 72 | return 0 73 | elif name.startswith('blocks'): 74 | return int(name.split('.')[1]) + 1 75 | else: 76 | return num_layers -------------------------------------------------------------------------------- /models_vit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm 9 | # DeiT: https://github.com/facebookresearch/deit 10 | # -------------------------------------------------------- 11 | 12 | from functools import partial 13 | 14 | import torch 15 | import torch.nn as nn 16 | 17 | import timm 18 | import timm.models.vision_transformer 19 | 20 | new_timm = '0.9' in timm.__version__ 21 | 22 | class VisionTransformer(timm.models.vision_transformer.VisionTransformer): 23 | """ Vision Transformer with support for global average pooling 24 | """ 25 | def __init__(self, global_pool=False, **kwargs): 26 | global_pool = "avg" if global_pool else "token" 27 | super(VisionTransformer, self).__init__( 28 | global_pool=global_pool, 29 | **kwargs 30 | ) 31 | 32 | if global_pool == "avg": 33 | norm_layer = kwargs['norm_layer'] 34 | embed_dim = kwargs['embed_dim'] 35 | self.fc_norm = norm_layer(embed_dim) 36 | 37 | self.norm = nn.Identity() # remove the original norm 38 | 39 | def forward_features(self, x): 40 | if new_timm: 41 | x = super().forward_features(x) 42 | return x 43 | 44 | B = x.shape[0] 45 | x = self.patch_embed(x) 46 | 47 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 48 | x = torch.cat((cls_tokens, x), dim=1) 49 | x = x + self.pos_embed 50 | x = self.pos_drop(x) 51 | 52 | for blk in self.blocks: 53 | x = blk(x) 54 | 55 | if self.global_pool == "avg": 56 | x = x[:, 1:, :].mean(dim=1) # global pool without cls token 57 | outcome = self.fc_norm(x) 58 | else: 59 | x = self.norm(x) 60 | outcome = x[:, 0] 61 | 62 | return outcome 63 | 64 | 65 | def vit_small_patch16(**kwargs): 66 | model = VisionTransformer( 67 | patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True, 68 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 69 | return model 70 | 71 | 72 | def 
vit_base_patch16(**kwargs): 73 | model = VisionTransformer( 74 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 75 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 76 | return model 77 | 78 | 79 | def vit_large_patch16(**kwargs): 80 | model = VisionTransformer( 81 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, 82 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 83 | return model 84 | 85 | 86 | def vit_huge_patch14(**kwargs): 87 | model = VisionTransformer( 88 | patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, qkv_bias=True, 89 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 90 | return model -------------------------------------------------------------------------------- /engine_pretrain.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | import math 12 | import sys 13 | from typing import Iterable 14 | 15 | import torch 16 | 17 | import util.misc as misc 18 | import util.lr_sched as lr_sched 19 | 20 | 21 | def train_one_epoch(model: torch.nn.Module, 22 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 23 | device: torch.device, epoch: int, loss_scaler, 24 | log_writer=None, 25 | args=None): 26 | model.train(True) 27 | metric_logger = misc.MetricLogger(delimiter=" ") 28 | metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}')) 29 | header = 'Epoch: [{}]'.format(epoch) 30 | print_freq = 20 31 | 32 | accum_iter = args.accum_iter 33 | 34 | optimizer.zero_grad() 35 | 36 | if log_writer is not None: 37 | print('log_dir: {}'.format(log_writer.log_dir)) 38 | 39 | for data_iter_step, (samples, _) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 40 | 41 | # we use a per iteration (instead of per epoch) lr scheduler 42 | if data_iter_step % accum_iter == 0: 43 | lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args) 44 | 45 | samples = samples.to(device, non_blocking=True) 46 | 47 | loss = model(samples, mask_ratio=args.mask_ratio, kept_mask_ratio=args.kept_mask_ratio) 48 | 49 | loss_value = loss.item() 50 | 51 | if not math.isfinite(loss_value): 52 | print("Loss is {}, stopping training".format(loss_value)) 53 | sys.exit(1) 54 | 55 | loss /= accum_iter 56 | loss_scaler(loss, optimizer, parameters=model.parameters(), 57 | update_grad=(data_iter_step + 1) % accum_iter == 0) 58 | if (data_iter_step + 1) % accum_iter == 0: 59 | optimizer.zero_grad() 60 | 61 | torch.cuda.synchronize() 62 | 63 | metric_logger.update(loss=loss_value) 64 | 65 | lr = optimizer.param_groups[0]["lr"] 66 | metric_logger.update(lr=lr) 67 | 68 | loss_value_reduce = misc.all_reduce_mean(loss_value) 69 | if log_writer is not None and (data_iter_step + 1) % accum_iter == 0: 70 | """ We use epoch_1000x as the x-axis in tensorboard. 71 | This calibrates different curves when batch size changes. 
72 | """ 73 | epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000) 74 | log_writer.add_scalar('train_loss', loss_value_reduce, epoch_1000x) 75 | log_writer.add_scalar('lr', lr, epoch_1000x) 76 | 77 | 78 | # gather the stats from all processes 79 | metric_logger.synchronize_between_processes() 80 | print("Averaged stats:", metric_logger) 81 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 
58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq 81 | -------------------------------------------------------------------------------- /submitit_pretrain.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # A script to run multinode training with submitit. 8 | # -------------------------------------------------------- 9 | 10 | import argparse 11 | import os 12 | import uuid 13 | from pathlib import Path 14 | 15 | import main_pretrain as trainer 16 | import submitit 17 | 18 | 19 | def parse_args(): 20 | trainer_parser = trainer.get_args_parser() 21 | parser = argparse.ArgumentParser("Submitit for MAE pretrain", parents=[trainer_parser]) 22 | parser.add_argument("--ngpus", default=8, type=int, help="Number of gpus to request on each node") 23 | parser.add_argument("--nodes", default=2, type=int, help="Number of nodes to request") 24 | parser.add_argument("--timeout", default=4320, type=int, help="Duration of the job") 25 | parser.add_argument("--job_dir", default="", type=str, help="Job dir. Leave empty for automatic.") 26 | 27 | parser.add_argument("--partition", default="learnfair", type=str, help="Partition where to submit") 28 | parser.add_argument("--use_volta32", action='store_true', help="Request 32G V100 GPUs") 29 | parser.add_argument('--comment', default="", type=str, help="Comment to pass to scheduler") 30 | return parser.parse_args() 31 | 32 | 33 | def get_shared_folder() -> Path: 34 | user = os.getenv("USER") 35 | if Path("/checkpoint/").is_dir(): 36 | p = Path(f"/checkpoint/{user}/experiments") 37 | p.mkdir(exist_ok=True) 38 | return p 39 | raise RuntimeError("No shared folder available") 40 | 41 | 42 | def get_init_file(): 43 | # Init file must not exist, but it's parent dir must exist. 
44 | os.makedirs(str(get_shared_folder()), exist_ok=True) 45 | init_file = get_shared_folder() / f"{uuid.uuid4().hex}_init" 46 | if init_file.exists(): 47 | os.remove(str(init_file)) 48 | return init_file 49 | 50 | 51 | class Trainer(object): 52 | def __init__(self, args): 53 | self.args = args 54 | 55 | def __call__(self): 56 | import main_pretrain as trainer 57 | 58 | self._setup_gpu_args() 59 | trainer.main(self.args) 60 | 61 | def checkpoint(self): 62 | import os 63 | import submitit 64 | 65 | self.args.dist_url = get_init_file().as_uri() 66 | checkpoint_file = os.path.join(self.args.output_dir, "checkpoint.pth") 67 | if os.path.exists(checkpoint_file): 68 | self.args.resume = checkpoint_file 69 | print("Requeuing ", self.args) 70 | empty_trainer = type(self)(self.args) 71 | return submitit.helpers.DelayedSubmission(empty_trainer) 72 | 73 | def _setup_gpu_args(self): 74 | import submitit 75 | from pathlib import Path 76 | 77 | job_env = submitit.JobEnvironment() 78 | self.args.output_dir = Path(str(self.args.output_dir).replace("%j", str(job_env.job_id))) 79 | self.args.log_dir = self.args.output_dir 80 | self.args.gpu = job_env.local_rank 81 | self.args.rank = job_env.global_rank 82 | self.args.world_size = job_env.num_tasks 83 | print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") 84 | 85 | 86 | def main(): 87 | args = parse_args() 88 | if args.job_dir == "": 89 | args.job_dir = get_shared_folder() / "%j" 90 | 91 | # Note that the folder will depend on the job_id, to easily track experiments 92 | executor = submitit.AutoExecutor(folder=args.job_dir, slurm_max_num_timeout=30) 93 | 94 | num_gpus_per_node = args.ngpus 95 | nodes = args.nodes 96 | timeout_min = args.timeout 97 | 98 | partition = args.partition 99 | kwargs = {} 100 | if args.use_volta32: 101 | kwargs['slurm_constraint'] = 'volta32gb' 102 | if args.comment: 103 | kwargs['slurm_comment'] = args.comment 104 | 105 | executor.update_parameters( 106 | mem_gb=40 * num_gpus_per_node, 107 | gpus_per_node=num_gpus_per_node, 108 | tasks_per_node=num_gpus_per_node, # one task per GPU 109 | cpus_per_task=10, 110 | nodes=nodes, 111 | timeout_min=timeout_min, # max is 60 * 72 112 | # Below are cluster dependent parameters 113 | slurm_partition=partition, 114 | slurm_signal_delay_s=120, 115 | **kwargs 116 | ) 117 | 118 | executor.update_parameters(name="mae") 119 | 120 | args.dist_url = get_init_file().as_uri() 121 | args.output_dir = args.job_dir 122 | 123 | trainer = Trainer(args) 124 | job = executor.submit(trainer) 125 | 126 | # print("Submitted job_id:", job.job_id) 127 | print(job.job_id) 128 | 129 | 130 | if __name__ == "__main__": 131 | main() 132 | -------------------------------------------------------------------------------- /submitit_finetune.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # A script to run multinode training with submitit. 
8 | # -------------------------------------------------------- 9 | 10 | import argparse 11 | import os 12 | import uuid 13 | from pathlib import Path 14 | 15 | import main_finetune as classification 16 | import submitit 17 | 18 | 19 | def parse_args(): 20 | classification_parser = classification.get_args_parser() 21 | parser = argparse.ArgumentParser("Submitit for MAE finetune", parents=[classification_parser]) 22 | parser.add_argument("--ngpus", default=8, type=int, help="Number of gpus to request on each node") 23 | parser.add_argument("--nodes", default=2, type=int, help="Number of nodes to request") 24 | parser.add_argument("--timeout", default=4320, type=int, help="Duration of the job") 25 | parser.add_argument("--job_dir", default="", type=str, help="Job dir. Leave empty for automatic.") 26 | 27 | parser.add_argument("--partition", default="learnfair", type=str, help="Partition where to submit") 28 | parser.add_argument("--use_volta32", action='store_true', help="Request 32G V100 GPUs") 29 | parser.add_argument('--comment', default="", type=str, help="Comment to pass to scheduler") 30 | return parser.parse_args() 31 | 32 | 33 | def get_shared_folder() -> Path: 34 | user = os.getenv("USER") 35 | if Path("/checkpoint/").is_dir(): 36 | p = Path(f"/checkpoint/{user}/experiments") 37 | p.mkdir(exist_ok=True) 38 | return p 39 | raise RuntimeError("No shared folder available") 40 | 41 | 42 | def get_init_file(): 43 | # Init file must not exist, but it's parent dir must exist. 44 | os.makedirs(str(get_shared_folder()), exist_ok=True) 45 | init_file = get_shared_folder() / f"{uuid.uuid4().hex}_init" 46 | if init_file.exists(): 47 | os.remove(str(init_file)) 48 | return init_file 49 | 50 | 51 | class Trainer(object): 52 | def __init__(self, args): 53 | self.args = args 54 | 55 | def __call__(self): 56 | import main_finetune as classification 57 | 58 | self._setup_gpu_args() 59 | classification.main(self.args) 60 | 61 | def checkpoint(self): 62 | import os 63 | import submitit 64 | 65 | self.args.dist_url = get_init_file().as_uri() 66 | checkpoint_file = os.path.join(self.args.output_dir, "checkpoint.pth") 67 | if os.path.exists(checkpoint_file): 68 | self.args.resume = checkpoint_file 69 | print("Requeuing ", self.args) 70 | empty_trainer = type(self)(self.args) 71 | return submitit.helpers.DelayedSubmission(empty_trainer) 72 | 73 | def _setup_gpu_args(self): 74 | import submitit 75 | from pathlib import Path 76 | 77 | job_env = submitit.JobEnvironment() 78 | self.args.output_dir = Path(str(self.args.output_dir).replace("%j", str(job_env.job_id))) 79 | self.args.log_dir = self.args.output_dir 80 | self.args.gpu = job_env.local_rank 81 | self.args.rank = job_env.global_rank 82 | self.args.world_size = job_env.num_tasks 83 | print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") 84 | 85 | 86 | def main(): 87 | args = parse_args() 88 | if args.job_dir == "": 89 | args.job_dir = get_shared_folder() / "%j" 90 | 91 | # Note that the folder will depend on the job_id, to easily track experiments 92 | executor = submitit.AutoExecutor(folder=args.job_dir, slurm_max_num_timeout=30) 93 | 94 | num_gpus_per_node = args.ngpus 95 | nodes = args.nodes 96 | timeout_min = args.timeout 97 | 98 | partition = args.partition 99 | kwargs = {} 100 | if args.use_volta32: 101 | kwargs['slurm_constraint'] = 'volta32gb' 102 | if args.comment: 103 | kwargs['slurm_comment'] = args.comment 104 | 105 | executor.update_parameters( 106 | mem_gb=40 * num_gpus_per_node, 107 | 
gpus_per_node=num_gpus_per_node, 108 | tasks_per_node=num_gpus_per_node, # one task per GPU 109 | cpus_per_task=10, 110 | nodes=nodes, 111 | timeout_min=timeout_min, 112 | # Below are cluster dependent parameters 113 | slurm_partition=partition, 114 | slurm_signal_delay_s=120, 115 | **kwargs 116 | ) 117 | 118 | executor.update_parameters(name="mae") 119 | 120 | args.dist_url = get_init_file().as_uri() 121 | args.output_dir = args.job_dir 122 | 123 | trainer = Trainer(args) 124 | job = executor.submit(trainer) 125 | 126 | # print("Submitted job_id:", job.job_id) 127 | print(job.job_id) 128 | 129 | 130 | if __name__ == "__main__": 131 | main() 132 | -------------------------------------------------------------------------------- /submitit_linprobe.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # A script to run multinode training with submitit. 8 | # -------------------------------------------------------- 9 | 10 | import argparse 11 | import os 12 | import uuid 13 | from pathlib import Path 14 | 15 | import main_linprobe as classification 16 | import submitit 17 | 18 | 19 | def parse_args(): 20 | classification_parser = classification.get_args_parser() 21 | parser = argparse.ArgumentParser("Submitit for MAE linear probe", parents=[classification_parser]) 22 | parser.add_argument("--ngpus", default=8, type=int, help="Number of gpus to request on each node") 23 | parser.add_argument("--nodes", default=2, type=int, help="Number of nodes to request") 24 | parser.add_argument("--timeout", default=4320, type=int, help="Duration of the job") 25 | parser.add_argument("--job_dir", default="", type=str, help="Job dir. Leave empty for automatic.") 26 | 27 | parser.add_argument("--partition", default="learnfair", type=str, help="Partition where to submit") 28 | parser.add_argument("--use_volta32", action='store_true', help="Request 32G V100 GPUs") 29 | parser.add_argument('--comment', default="", type=str, help="Comment to pass to scheduler") 30 | return parser.parse_args() 31 | 32 | 33 | def get_shared_folder() -> Path: 34 | user = os.getenv("USER") 35 | if Path("/checkpoint/").is_dir(): 36 | p = Path(f"/checkpoint/{user}/experiments") 37 | p.mkdir(exist_ok=True) 38 | return p 39 | raise RuntimeError("No shared folder available") 40 | 41 | 42 | def get_init_file(): 43 | # Init file must not exist, but it's parent dir must exist. 
44 | os.makedirs(str(get_shared_folder()), exist_ok=True) 45 | init_file = get_shared_folder() / f"{uuid.uuid4().hex}_init" 46 | if init_file.exists(): 47 | os.remove(str(init_file)) 48 | return init_file 49 | 50 | 51 | class Trainer(object): 52 | def __init__(self, args): 53 | self.args = args 54 | 55 | def __call__(self): 56 | import main_linprobe as classification 57 | 58 | self._setup_gpu_args() 59 | classification.main(self.args) 60 | 61 | def checkpoint(self): 62 | import os 63 | import submitit 64 | 65 | self.args.dist_url = get_init_file().as_uri() 66 | checkpoint_file = os.path.join(self.args.output_dir, "checkpoint.pth") 67 | if os.path.exists(checkpoint_file): 68 | self.args.resume = checkpoint_file 69 | print("Requeuing ", self.args) 70 | empty_trainer = type(self)(self.args) 71 | return submitit.helpers.DelayedSubmission(empty_trainer) 72 | 73 | def _setup_gpu_args(self): 74 | import submitit 75 | from pathlib import Path 76 | 77 | job_env = submitit.JobEnvironment() 78 | self.args.output_dir = Path(str(self.args.output_dir).replace("%j", str(job_env.job_id))) 79 | self.args.log_dir = self.args.output_dir 80 | self.args.gpu = job_env.local_rank 81 | self.args.rank = job_env.global_rank 82 | self.args.world_size = job_env.num_tasks 83 | print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") 84 | 85 | 86 | def main(): 87 | args = parse_args() 88 | if args.job_dir == "": 89 | args.job_dir = get_shared_folder() / "%j" 90 | 91 | # Note that the folder will depend on the job_id, to easily track experiments 92 | executor = submitit.AutoExecutor(folder=args.job_dir, slurm_max_num_timeout=30) 93 | 94 | num_gpus_per_node = args.ngpus 95 | nodes = args.nodes 96 | timeout_min = args.timeout 97 | 98 | partition = args.partition 99 | kwargs = {} 100 | if args.use_volta32: 101 | kwargs['slurm_constraint'] = 'volta32gb' 102 | if args.comment: 103 | kwargs['slurm_comment'] = args.comment 104 | 105 | executor.update_parameters( 106 | mem_gb=40 * num_gpus_per_node, 107 | gpus_per_node=num_gpus_per_node, 108 | tasks_per_node=num_gpus_per_node, # one task per GPU 109 | cpus_per_task=10, 110 | nodes=nodes, 111 | timeout_min=timeout_min, 112 | # Below are cluster dependent parameters 113 | slurm_partition=partition, 114 | slurm_signal_delay_s=120, 115 | **kwargs 116 | ) 117 | 118 | executor.update_parameters(name="mae") 119 | 120 | args.dist_url = get_init_file().as_uri() 121 | args.output_dir = args.job_dir 122 | 123 | trainer = Trainer(args) 124 | job = executor.submit(trainer) 125 | 126 | # print("Submitted job_id:", job.job_id) 127 | print(job.job_id) 128 | 129 | 130 | if __name__ == "__main__": 131 | main() 132 | -------------------------------------------------------------------------------- /engine_finetune.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | 12 | import math 13 | import sys 14 | from typing import Iterable, Optional 15 | 16 | import torch 17 | 18 | from timm.data import Mixup 19 | from timm.utils import accuracy 20 | 21 | import util.misc as misc 22 | import util.lr_sched as lr_sched 23 | 24 | 25 | def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module, 26 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 27 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0, 28 | mixup_fn: Optional[Mixup] = None, log_writer=None, 29 | args=None): 30 | model.train(True) 31 | metric_logger = misc.MetricLogger(delimiter=" ") 32 | metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}')) 33 | header = 'Epoch: [{}]'.format(epoch) 34 | print_freq = 20 35 | 36 | accum_iter = args.accum_iter 37 | 38 | optimizer.zero_grad() 39 | 40 | if log_writer is not None: 41 | print('log_dir: {}'.format(log_writer.log_dir)) 42 | 43 | for data_iter_step, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 44 | 45 | # we use a per iteration (instead of per epoch) lr scheduler 46 | if data_iter_step % accum_iter == 0: 47 | lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args) 48 | 49 | samples = samples.to(device, non_blocking=True) 50 | targets = targets.to(device, non_blocking=True) 51 | 52 | if mixup_fn is not None: 53 | samples, targets = mixup_fn(samples, targets) 54 | 55 | with torch.cuda.amp.autocast(): 56 | outputs = model(samples) 57 | loss = criterion(outputs, targets) 58 | 59 | loss_value = loss.item() 60 | 61 | if not math.isfinite(loss_value): 62 | print("Loss is {}, stopping training".format(loss_value)) 63 | sys.exit(1) 64 | 65 | loss /= accum_iter 66 | loss_scaler(loss, optimizer, clip_grad=max_norm, 67 | parameters=model.parameters(), create_graph=False, 68 | update_grad=(data_iter_step + 1) % accum_iter == 0) 69 | if (data_iter_step + 1) % accum_iter == 0: 70 | optimizer.zero_grad() 71 | 72 | torch.cuda.synchronize() 73 | 74 | metric_logger.update(loss=loss_value) 75 | min_lr = 10. 76 | max_lr = 0. 77 | for group in optimizer.param_groups: 78 | min_lr = min(min_lr, group["lr"]) 79 | max_lr = max(max_lr, group["lr"]) 80 | 81 | metric_logger.update(lr=max_lr) 82 | 83 | loss_value_reduce = misc.all_reduce_mean(loss_value) 84 | if log_writer is not None and (data_iter_step + 1) % accum_iter == 0: 85 | """ We use epoch_1000x as the x-axis in tensorboard. 86 | This calibrates different curves when batch size changes. 
87 | """ 88 | epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000) 89 | log_writer.add_scalar('loss', loss_value_reduce, epoch_1000x) 90 | log_writer.add_scalar('lr', max_lr, epoch_1000x) 91 | 92 | # gather the stats from all processes 93 | metric_logger.synchronize_between_processes() 94 | print("Averaged stats:", metric_logger) 95 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 96 | 97 | 98 | @torch.no_grad() 99 | def evaluate(data_loader, model, device): 100 | criterion = torch.nn.CrossEntropyLoss() 101 | 102 | metric_logger = misc.MetricLogger(delimiter=" ") 103 | header = 'Test:' 104 | 105 | # switch to evaluation mode 106 | model.eval() 107 | 108 | for batch in metric_logger.log_every(data_loader, 10, header): 109 | images = batch[0] 110 | target = batch[-1] 111 | images = images.to(device, non_blocking=True) 112 | target = target.to(device, non_blocking=True) 113 | 114 | # compute output 115 | with torch.cuda.amp.autocast(): 116 | output = model(images) 117 | loss = criterion(output, target) 118 | 119 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 120 | 121 | batch_size = images.shape[0] 122 | metric_logger.update(loss=loss.item()) 123 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) 124 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) 125 | # gather the stats from all processes 126 | metric_logger.synchronize_between_processes() 127 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}' 128 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss)) 129 | 130 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} -------------------------------------------------------------------------------- /util/pos_embed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # -------------------------------------------------------- 7 | # Position embedding utils 8 | # -------------------------------------------------------- 9 | 10 | import numpy as np 11 | 12 | import torch 13 | 14 | def get_2d_sincos_pos_embed_for_coords(embed_dim, coords, image_size=224, patch_size=16): 15 | """ 16 | embed_dim: int, the embedding dimension 17 | coords: an array of shape [num_pixels, 2], containing the x, y coordinates of the pixels (on the raw image) 18 | image_size: size of the rgb image 19 | patch size: size of the patch 20 | return: 21 | pos_embed: [num_pixels, embed_dim] 22 | """ 23 | assert embed_dim % 2 == 0 24 | 25 | # Separate the coordinates into x and y 26 | grid_x = coords[:, 0] - patch_size // 2 27 | grid_y = coords[:, 1] - patch_size // 2 28 | 29 | # Use half of dimensions to encode each coordinate 30 | emb_x = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid_x, idx_range=image_size - patch_size) # (num_pixels, D/2) 31 | emb_y = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid_y, idx_range=image_size - patch_size) # (num_pixels, D/2) 32 | 33 | emb = np.concatenate([emb_x, emb_y], axis=1) # (num_pixels, D) 34 | return emb 35 | 36 | # -------------------------------------------------------- 37 | # 2D sine-cosine position embedding 38 | # References: 39 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py 40 | # MoCo v3: https://github.com/facebookresearch/moco-v3 41 | # -------------------------------------------------------- 42 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): 43 | """ 44 | grid_size: int of the grid height and width 45 | return: 46 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 47 | """ 48 | grid_h = np.arange(grid_size, dtype=np.float32) 49 | grid_w = np.arange(grid_size, dtype=np.float32) 50 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 51 | grid = np.stack(grid, axis=0) 52 | 53 | grid = grid.reshape([2, 1, grid_size, grid_size]) 54 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 55 | if cls_token: 56 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 57 | return pos_embed 58 | 59 | 60 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 61 | assert embed_dim % 2 == 0 62 | 63 | # use half of dimensions to encode grid_h 64 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 65 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 66 | 67 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 68 | return emb 69 | 70 | 71 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, idx_range=14): 72 | """ 73 | embed_dim: output dimension for each position 74 | pos: a list of positions to be encoded: size (M,) 75 | out: (M, D) 76 | """ 77 | assert embed_dim % 2 == 0 78 | omega = np.arange(embed_dim // 2, dtype=np.float32) 79 | omega /= embed_dim / 2. 80 | # scale per OG MAE-ViT (14x14, min index 0, max index 13) 81 | omega = (1. 
/ 10000**omega) / (idx_range - 1) * 13 # (D/2,) 82 | 83 | pos = pos.reshape(-1) # (M,) 84 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 85 | 86 | emb_sin = np.sin(out) # (M, D/2) 87 | emb_cos = np.cos(out) # (M, D/2) 88 | 89 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 90 | return emb 91 | 92 | 93 | # -------------------------------------------------------- 94 | # Interpolate position embeddings for high-resolution 95 | # References: 96 | # DeiT: https://github.com/facebookresearch/deit 97 | # -------------------------------------------------------- 98 | def interpolate_pos_embed(model, checkpoint_model): 99 | if 'pos_embed' in checkpoint_model: 100 | pos_embed_checkpoint = checkpoint_model['pos_embed'] 101 | embedding_size = pos_embed_checkpoint.shape[-1] 102 | num_patches = model.patch_embed.num_patches 103 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches 104 | # height (== width) for the checkpoint position embedding 105 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 106 | # height (== width) for the new position embedding 107 | new_size = int(num_patches ** 0.5) 108 | # class_token and dist_token are kept unchanged 109 | if orig_size != new_size: 110 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) 111 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 112 | # only the position tokens are interpolated 113 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 114 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 115 | pos_tokens = torch.nn.functional.interpolate( 116 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 117 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 118 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 119 | checkpoint_model['pos_embed'] = new_pos_embed 120 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## CrossMAE: Rethinking Patch Dependence for Masked Autoencoders 2 | by Letian Fu*, Long Lian*, Renhao Wang, Baifeng Shi, Xudong Wang, Adam Yala†, Trevor Darrell†, Alexei A. Efros†, Ken Goldberg† at UC Berkeley and UCSF 3 | 4 | [[Paper](https://openreview.net/forum?id=JT2KMuo2BV)] | [[Project Page](https://crossmae.github.io/)] | [[Citation](#citation)] 5 | 6 | 7 |
8 | 9 | 
10 | 11 | This is a PyTorch implementation of the CrossMAE paper [Rethinking Patch Dependence for Masked Autoencoders](https://crossmae.github.io/). The code is based on the original [MAE](https://github.com/facebookresearch/mae) repo. The codebase supports CrossMAE and MAE, with `timm==0.9.7`, `torch==2.0.0`, and flash-attn 2. 12 | 13 | ## Models 14 | The encoder part of CrossMAE matches exactly with MAE. Therefore, we use the same code for fine-tuning. We also encourage you to try CrossMAE checkpoints in your downstream applications. These models are trained on ImageNet-1k for 800 epochs (except that 448 models are trained for 400 epochs), with masking ratio and kept mask ratio both set to 0.75, except that ViT-H is with masking ratio 0.75 and kept mask ratio 0.25. 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 |
|  | ViT-Small | ViT-Base | ViT-Base448 | ViT-Large | ViT-Huge |
| --- | --- | --- | --- | --- | --- |
| pretrained checkpoint | download | download | download | download | download |
| fine-tuned checkpoint | download | download | download | download | download |
| Reference ImageNet accuracy (ours) | 79.318 | 83.722 | 84.598 | 85.432 | 86.256 |
| MAE ImageNet accuracy (baseline) |  |  | 84.8 | 85.9 |  |
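
Since the encoder matches MAE, the pretrained checkpoints above can be loaded directly into the encoder definitions in `models_vit.py` for downstream tasks. The following is a minimal sketch, assuming a locally downloaded checkpoint (the filename is a placeholder) that stores its weights under the `model` key as in the MAE repo:

```python
# Minimal sketch: load a pretrained CrossMAE encoder for downstream use.
# Assumes the checkpoint stores its weights under the 'model' key (as in the MAE repo);
# the filename below is a placeholder for a downloaded checkpoint.
import torch

import models_vit
from util.pos_embed import interpolate_pos_embed

# build a ViT-B/16 encoder with global average pooling, as used for fine-tuning
model = models_vit.vit_base_patch16(num_classes=1000, global_pool=True)

checkpoint = torch.load('crossmae_vitb_pretrain.pth', map_location='cpu')
checkpoint_model = checkpoint['model']

# drop classifier weights whose shapes do not match the randomly initialized head
state_dict = model.state_dict()
for k in ['head.weight', 'head.bias']:
    if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
        del checkpoint_model[k]

# interpolate position embeddings if the input resolution differs (e.g. 448)
interpolate_pos_embed(model, checkpoint_model)

msg = model.load_state_dict(checkpoint_model, strict=False)
print(msg.missing_keys)  # typically the new head (and fc_norm) weights
```

For end-to-end fine-tuning, the `--finetune` flag of `main_finetune.py` (see the commands below) covers this step as part of the full recipe.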
55 | 56 | ## Train CrossMAE on **one single RTX 4090** 57 | Thanks to the efficiency of CrossMAE, it is possible to train CrossMAE on **one single RTX 4090** in a personal computer with an i9-14900K CPU and 96 GB of RAM. 58 | 59 | 
60 | Instructions and trained models 61 | 62 | The training and fine-tuning command (with `${IMAGENET_DIR}` the directory for imagenet, ViT-S as an example): 63 | ```sh 64 | CUDA_VISIBLE_DEVICES=0 OMP_NUM_THREADS=1 torchrun --nproc_per_node=1 --master_port 2780 main_pretrain.py --batch_size 512 --accum_iter 8 --model mae_vit_small_patch16 --norm_pix_loss --blr 1.5e-4 --weight_decay 0.05 --data_path ${IMAGENET_DIR} --num_workers 16 --multi_epochs_dataloader --output_dir output/imagenet-crossmae-vits-pretrain-wfm-mr0.75-kmr0.25-dd12-ep800 --cross_mae --weight_fm --decoder_depth 12 --mask_ratio 0.75 --kept_mask_ratio 0.75 --epochs 800 --warmup_epochs 40 --use_input 65 | 66 | CUDA_VISIBLE_DEVICES=0 OMP_NUM_THREADS=1 torchrun --nproc_per_node=1 --master_port 2860 main_finetune.py --batch_size 512 --accum_iter 2 --model vit_small_patch16 --finetune output/imagenet-crossmae-vits-pretrain-wfm-mr0.75-kmr0.25-dd12-ep800/checkpoint.pth --epoch 100 --blr 5e-4 --layer_decay 0.65 --weight_decay 0.05 --drop_path 0.1 --reprob 0.25 --mixup 0.8 --cutmix 1.0 --dist_eval --data_path ${IMAGENET_DIR} --num_workers 12 --output_dir output/imagenet-crossmae-vits-finetune-wfm-mr0.75-kmr0.25-dd12-ep800 --multi_epochs_dataloader 67 | # Reference results: 68 | # * Acc@1 79.462 Acc@5 94.864 loss 0.907 69 | ``` 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 |
| pretrained checkpoint | fine-tuned checkpoint | reference ImageNet accuracy |
| --- | --- | --- |
| download | download | 79.462 |
88 | 89 |
90 | 91 | ## Instructions 92 | Please install the dependencies in `requirements.txt`: 93 | ```sh 94 | # Optionally create a conda environment 95 | conda create -n crossmae python=3.10 -y 96 | conda activate crossmae 97 | # Install dependencies 98 | pip install -r requirements.txt 99 | ``` 100 | 101 | ### Pre-training CrossMAE 102 | To pre-train ViT-Base, run the following on 4 GPUs: 103 | ```sh 104 | CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 --master_port 1234 main_pretrain.py --batch_size 1024 --model mae_vit_base_patch16 --norm_pix_loss --blr 1.5e-4 --weight_decay 0.05 --data_path ${IMAGENET_DIR} --num_workers 20 --enable_flash_attention2 --multi_epochs_dataloader --output_dir output/imagenet-crossmae-vitb-pretrain-wfm-mr0.75-kmr0.25-dd12-ep800 --cross_mae --weight_fm --decoder_depth 12 --mask_ratio 0.75 --kept_mask_ratio 0.25 --epochs 800 --warmup_epochs 40 --use_input 105 | ``` 106 | 107 | To train ViT-Small or ViT-Large, set `--model mae_vit_small_patch16` or `--model mae_vit_large_patch16`. You can use `--accum_iter` to perform gradient accumulation if your hardware could not fit the batch size. [FlashAttention 2](https://github.com/Dao-AILab/flash-attention) should be installed with `pip install flash-attn --no-build-isolation`. 108 | 109 | ### Fine-tuning CrossMAE 110 | To pre-train ViT-Base, run the following on 4 GPUs: 111 | ```sh 112 | CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 --master_port 1234 main_finetune.py --batch_size 256 --model vit_base_patch16 --finetune output/imagenet-crossmae-vitb-pretrain-wfm-mr0.75-kmr0.25-dd12-ep800/checkpoint.pth --epoch 100 --blr 5e-4 --layer_decay 0.65 --weight_decay 0.05 --drop_path 0.1 --reprob 0.25 --mixup 0.8 --cutmix 1.0 --dist_eval --data_path ${IMAGENET_DIR} --output_dir output/imagenet-crossmae-vitb-finetune-wfm-mr0.75-kmr0.25-dd12-ep800 --enable_flash_attention2 --multi_epochs_dataloader 113 | ``` 114 | 115 | ## Evaluation 116 | Evaluate ViT-Base in a single GPU (`${IMAGENET_DIR}` is a directory containing `{train, val}` sets of ImageNet). `${FINETUNED_CHECKPOINT_PATH}` is the path to the fine-tuned checkpoint: 117 | ```sh 118 | python main_finetune.py --eval --resume ${FINETUNED_CHECKPOINT_PATH} --model vit_base_patch16 --batch_size 16 --data_path ${IMAGENET_DIR} 119 | ``` 120 | This should give: 121 | ``` 122 | * Acc@1 83.722 Acc@5 96.686 loss 0.729 123 | ``` 124 | 125 | You could replace `vit_base_patch16` with `vit_small_patch16` or `vit_large_patch16` to evaluate ViT-S or ViT-L. To work with 448 input resolution, please append `--input_size 448` to the command line. 126 | 127 | ### License 128 | 129 | This project is under the CC-BY-NC 4.0 license. See [LICENSE](LICENSE) for details. 130 | 131 | ## Citation 132 | Please give us a star 🌟 on Github to support us! 
133 | 134 | Please cite our work if you find our work inspiring or use our code in your work: 135 | ``` 136 | @article{ 137 | fu2025rethinking, 138 | title={Rethinking Patch Dependence for Masked Autoencoders}, 139 | author={Letian Fu and Long Lian and Renhao Wang and Baifeng Shi and XuDong Wang and Adam Yala and Trevor Darrell and Alexei A Efros and Ken Goldberg}, 140 | journal={Transactions on Machine Learning Research}, 141 | issn={2835-8856}, 142 | year={2025}, 143 | url={https://openreview.net/forum?id=JT2KMuo2BV}, 144 | note={} 145 | } 146 | ``` 147 | -------------------------------------------------------------------------------- /transformer_utils.py: -------------------------------------------------------------------------------- 1 | # This file is largely from timm 2 | # The functions from timm (https://github.com/huggingface/pytorch-image-models/tree/main) adheres to the original license 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from functools import partial 8 | from timm.layers import DropPath 9 | from timm.layers.helpers import to_2tuple 10 | 11 | torch_version = torch.__version__ 12 | is_torch2 = torch_version.startswith('2.') 13 | 14 | class Mlp(nn.Module): 15 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 16 | super().__init__() 17 | out_features = out_features or in_features 18 | hidden_features = hidden_features or in_features 19 | self.fc1 = nn.Linear(in_features, hidden_features) 20 | self.act = act_layer() 21 | self.fc2 = nn.Linear(hidden_features, out_features) 22 | self.drop = nn.Dropout(drop) 23 | 24 | def forward(self, x): 25 | x = self.fc1(x) 26 | x = self.act(x) 27 | x = self.drop(x) 28 | x = self.fc2(x) 29 | x = self.drop(x) 30 | return x 31 | 32 | 33 | class Attention(nn.Module): 34 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 35 | super().__init__() 36 | self.num_heads = num_heads 37 | head_dim = dim // num_heads 38 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 39 | self.scale = qk_scale or head_dim ** -0.5 40 | 41 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 42 | if is_torch2: 43 | self.attn_drop = attn_drop 44 | else: 45 | self.attn_drop = nn.Dropout(attn_drop) 46 | self.proj = nn.Linear(dim, dim) 47 | self.proj_drop = nn.Dropout(proj_drop) 48 | 49 | def forward(self, x): 50 | B, N, C = x.shape 51 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 52 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 53 | 54 | if is_torch2: 55 | attn = F.scaled_dot_product_attention( 56 | q, k, v, dropout_p=self.attn_drop, 57 | ) 58 | x = attn.transpose(1, 2).reshape(B, N, C) 59 | else: 60 | attn = (q @ k.transpose(-2, -1)) * self.scale 61 | attn = attn.softmax(dim=-1) 62 | attn = self.attn_drop(attn) 63 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 64 | 65 | x = self.proj(x) 66 | x = self.proj_drop(x) 67 | return x 68 | 69 | class CrossAttention(nn.Module): 70 | def __init__(self, encoder_dim, decoder_dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 71 | super().__init__() 72 | self.num_heads = num_heads 73 | head_dim = decoder_dim // num_heads 74 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 75 | self.scale = qk_scale or head_dim ** -0.5 76 | self.q = nn.Linear(decoder_dim, 
decoder_dim, bias=qkv_bias) 77 | self.kv = nn.Linear(encoder_dim, decoder_dim * 2, bias=qkv_bias) 78 | if is_torch2: 79 | self.attn_drop = attn_drop 80 | else: 81 | self.attn_drop = nn.Dropout(attn_drop) 82 | self.proj = nn.Linear(decoder_dim, decoder_dim) 83 | self.proj_drop = nn.Dropout(proj_drop) 84 | 85 | def forward(self, x, y): 86 | """ 87 | query from decoder (x), key and value from encoder (y) 88 | """ 89 | B, N, C = x.shape 90 | Ny = y.shape[1] 91 | q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 92 | kv = self.kv(y).reshape(B, Ny, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 93 | k, v = kv[0], kv[1] 94 | 95 | if is_torch2: 96 | attn = F.scaled_dot_product_attention( 97 | q, k, v, dropout_p=self.attn_drop, 98 | ) 99 | x = attn.transpose(1, 2).reshape(B, N, C) 100 | else: 101 | attn = (q @ k.transpose(-2, -1)) * self.scale 102 | attn = attn.softmax(dim=-1) 103 | attn = self.attn_drop(attn) 104 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 105 | 106 | x = self.proj(x) 107 | x = self.proj_drop(x) 108 | return x 109 | 110 | class Block(nn.Module): 111 | 112 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 113 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 114 | super().__init__() 115 | self.norm1 = norm_layer(dim) 116 | self.attn = Attention( 117 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 118 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 119 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 120 | self.norm2 = norm_layer(dim) 121 | mlp_hidden_dim = int(dim * mlp_ratio) 122 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 123 | 124 | def forward(self, x): 125 | x = x + self.drop_path(self.attn(self.norm1(x))) 126 | x = x + self.drop_path(self.mlp(self.norm2(x))) 127 | return x 128 | 129 | class CrossAttentionBlock(nn.Module): 130 | 131 | def __init__(self, encoder_dim, decoder_dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 132 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, self_attn=False): 133 | super().__init__() 134 | self.self_attn = self_attn 135 | if self.self_attn: 136 | self.norm0 = norm_layer(decoder_dim) 137 | self.self_attn = Attention( 138 | decoder_dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 139 | self.norm1 = norm_layer(decoder_dim) 140 | self.cross_attn = CrossAttention( 141 | encoder_dim, decoder_dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 142 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 143 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 144 | self.norm2 = norm_layer(decoder_dim) 145 | mlp_hidden_dim = int(decoder_dim * mlp_ratio) 146 | self.mlp = Mlp(in_features=decoder_dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 147 | 148 | def forward(self, x, y): 149 | """ 150 | x: decoder feature; y: encoder feature (after layernorm) 151 | """ 152 | if self.self_attn: 153 | x = x + self.drop_path(self.self_attn(self.norm0(x))) 154 | x = x + self.drop_path(self.cross_attn(self.norm1(x), y)) 155 | x = x + self.drop_path(self.mlp(self.norm2(x))) 156 | return x 157 | 158 | class PatchEmbed(nn.Module): 159 | """ 2D Image to Patch Embedding 160 | """ 161 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True): 162 | super().__init__() 163 | img_size = to_2tuple(img_size) 164 | patch_size = to_2tuple(patch_size) 165 | self.img_size = img_size 166 | self.patch_size = patch_size 167 | self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) 168 | self.num_patches = self.grid_size[0] * self.grid_size[1] 169 | self.flatten = flatten 170 | 171 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 172 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 173 | 174 | def forward(self, x, random_sample=False): 175 | B, C, H, W = x.shape 176 | assert random_sample or (H == self.img_size[0] and W == self.img_size[1]), \ 177 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 178 | x = self.proj(x) 179 | if self.flatten: 180 | x = x.flatten(2).transpose(1, 2) # BCHW -> BNC 181 | x = self.norm(x) 182 | return x 183 | 184 | def handle_flash_attn(args): 185 | sm = torch.cuda.get_device_capability(0) 186 | # https://arnon.dk/matching-sm-architectures-arch-and-gencode-for-various-nvidia-cards/ 187 | enable_flashattn = sm[0] >= 8 or (sm[0] == 7 and sm[1] >= 5) 188 | 189 | print(f"enable_flashattn: {enable_flashattn}") 190 | 191 | if args.enable_flash_attention2: 192 | print("Flash attention 2 enabled") 193 | 194 | # This requies installing https://github.com/Dao-AILab/flash-attention/tree/v2.2.3 195 | 196 | assert enable_flashattn, "Flash attn requires compute capabilities" 197 | 198 | from flash_attn import flash_attn_func 199 | 200 | torch_scaled_dot_product_attention = F.scaled_dot_product_attention 201 | 202 | def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False): 203 | # torch convention: B, num heads, seq len, C 204 | # print(f"Using flash attention, query: {query.shape}, key: {key.shape}, value: {value.shape}") 205 | assert attn_mask is None, attn_mask 206 | if query.shape[-1] > 256: 207 | return torch_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal) 208 | return torch.permute(flash_attn_func(torch.permute(query, [0, 2, 1, 3]), torch.permute(key, [0, 2, 1, 3]), torch.permute(value, [0, 2, 1, 3]), dropout_p=dropout_p, causal=is_causal), [0, 2, 1, 3]) 209 | 210 | F.scaled_dot_product_attention = scaled_dot_product_attention 211 | 212 | # Use memory efficient attention as a fallback 213 | torch.backends.cuda.enable_flash_sdp(False) 214 | torch.backends.cuda.enable_mem_efficient_sdp(True) 215 | torch.backends.cuda.enable_math_sdp(False) 216 | else: 217 | print("Flash attention 2 is not enabled. 
Using built-in attention implementation.") 218 | torch.backends.cuda.enable_flash_sdp(enable_flashattn) 219 | torch.backends.cuda.enable_mem_efficient_sdp(not enable_flashattn) 220 | torch.backends.cuda.enable_math_sdp(False) 221 | -------------------------------------------------------------------------------- /models_mae.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm 9 | # DeiT: https://github.com/facebookresearch/deit 10 | # -------------------------------------------------------- 11 | 12 | from functools import partial 13 | 14 | import torch 15 | import torch.nn as nn 16 | 17 | from transformer_utils import Block, PatchEmbed 18 | 19 | from util.pos_embed import get_2d_sincos_pos_embed 20 | 21 | 22 | class MaskedAutoencoderViT(nn.Module): 23 | """ Masked Autoencoder with VisionTransformer backbone 24 | """ 25 | def __init__(self, img_size=224, patch_size=16, in_chans=3, 26 | embed_dim=1024, depth=24, num_heads=16, 27 | decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, 28 | mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False): 29 | super().__init__() 30 | 31 | # -------------------------------------------------------------------------- 32 | # MAE encoder specifics 33 | self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim) 34 | num_patches = self.patch_embed.num_patches 35 | 36 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 37 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False) # fixed sin-cos embedding 38 | 39 | self.blocks = nn.ModuleList([ 40 | Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer) 41 | for i in range(depth)]) 42 | self.norm = norm_layer(embed_dim) 43 | # -------------------------------------------------------------------------- 44 | 45 | # -------------------------------------------------------------------------- 46 | # MAE decoder specifics 47 | self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True) 48 | 49 | self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim)) 50 | 51 | self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False) # fixed sin-cos embedding 52 | 53 | self.decoder_blocks = nn.ModuleList([ 54 | Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer) 55 | for i in range(decoder_depth)]) 56 | 57 | self.decoder_norm = norm_layer(decoder_embed_dim) 58 | self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True) # decoder to patch 59 | # -------------------------------------------------------------------------- 60 | 61 | self.norm_pix_loss = norm_pix_loss 62 | 63 | self.initialize_weights() 64 | 65 | def initialize_weights(self): 66 | # initialization 67 | # initialize (and freeze) pos_embed by sin-cos embedding 68 | pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True) 69 | self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) 70 | 71 | decoder_pos_embed = 
get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True) 72 | self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0)) 73 | 74 | # initialize patch_embed like nn.Linear (instead of nn.Conv2d) 75 | w = self.patch_embed.proj.weight.data 76 | torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) 77 | 78 | # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.) 79 | torch.nn.init.normal_(self.cls_token, std=.02) 80 | torch.nn.init.normal_(self.mask_token, std=.02) 81 | 82 | # initialize nn.Linear and nn.LayerNorm 83 | self.apply(self._init_weights) 84 | 85 | def _init_weights(self, m): 86 | if isinstance(m, nn.Linear): 87 | # we use xavier_uniform following official JAX ViT: 88 | torch.nn.init.xavier_uniform_(m.weight) 89 | if isinstance(m, nn.Linear) and m.bias is not None: 90 | nn.init.constant_(m.bias, 0) 91 | elif isinstance(m, nn.LayerNorm): 92 | nn.init.constant_(m.bias, 0) 93 | nn.init.constant_(m.weight, 1.0) 94 | 95 | def patchify(self, imgs): 96 | """ 97 | imgs: (N, 3, H, W) 98 | x: (N, L, patch_size**2 *3) 99 | """ 100 | p = self.patch_embed.patch_size[0] 101 | assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0 102 | 103 | h = w = imgs.shape[2] // p 104 | x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p)) 105 | x = torch.einsum('nchpwq->nhwpqc', x) 106 | x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3)) 107 | return x 108 | 109 | def unpatchify(self, x): 110 | """ 111 | x: (N, L, patch_size**2 *3) 112 | imgs: (N, 3, H, W) 113 | """ 114 | p = self.patch_embed.patch_size[0] 115 | h = w = int(x.shape[1]**.5) 116 | assert h * w == x.shape[1] 117 | 118 | x = x.reshape(shape=(x.shape[0], h, w, p, p, 3)) 119 | x = torch.einsum('nhwpqc->nchpwq', x) 120 | imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p)) 121 | return imgs 122 | 123 | def random_masking(self, x, mask_ratio): 124 | """ 125 | Perform per-sample random masking by per-sample shuffling. 126 | Per-sample shuffling is done by argsort random noise. 
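        For example, with the default 224x224 input and 16x16 patches (L = 196)
        and mask_ratio=0.75: len_keep = 49, so x_masked is [N, 49, D], mask is
        [N, 196] with exactly 147 ones (removed patches), and ids_restore is
        [N, 196].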
127 | x: [N, L, D], sequence 128 | """ 129 | N, L, D = x.shape # batch, length, dim 130 | len_keep = int(L * (1 - mask_ratio)) 131 | 132 | noise = torch.rand(N, L, device=x.device) # noise in [0, 1] 133 | 134 | # sort noise for each sample 135 | ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove 136 | ids_restore = torch.argsort(ids_shuffle, dim=1) 137 | 138 | # keep the first subset 139 | ids_keep = ids_shuffle[:, :len_keep] 140 | x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) 141 | 142 | # generate the binary mask: 0 is keep, 1 is remove 143 | mask = torch.ones([N, L], device=x.device) 144 | mask[:, :len_keep] = 0 145 | # unshuffle to get the binary mask 146 | mask = torch.gather(mask, dim=1, index=ids_restore) 147 | 148 | return x_masked, mask, ids_restore 149 | 150 | def forward_encoder(self, x, mask_ratio): 151 | # embed patches 152 | x = self.patch_embed(x) 153 | 154 | # add pos embed w/o cls token 155 | x = x + self.pos_embed[:, 1:, :] 156 | 157 | # masking: length -> length * mask_ratio 158 | x, mask, ids_restore = self.random_masking(x, mask_ratio) 159 | 160 | # append cls token 161 | cls_token = self.cls_token + self.pos_embed[:, :1, :] 162 | cls_tokens = cls_token.expand(x.shape[0], -1, -1) 163 | x = torch.cat((cls_tokens, x), dim=1) 164 | 165 | # apply Transformer blocks 166 | for blk in self.blocks: 167 | x = blk(x) 168 | x = self.norm(x) 169 | 170 | return x, mask, ids_restore 171 | 172 | def forward_decoder(self, x, ids_restore): 173 | # embed tokens 174 | x = self.decoder_embed(x) 175 | 176 | # append mask tokens to sequence 177 | mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1) 178 | x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token 179 | x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle 180 | x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token 181 | 182 | # add pos embed 183 | x = x + self.decoder_pos_embed 184 | 185 | # apply Transformer blocks 186 | for blk in self.decoder_blocks: 187 | x = blk(x) 188 | x = self.decoder_norm(x) 189 | 190 | # predictor projection 191 | x = self.decoder_pred(x) 192 | 193 | # remove cls token 194 | x = x[:, 1:, :] 195 | 196 | return x 197 | 198 | def forward_loss(self, imgs, pred, mask): 199 | """ 200 | imgs: [N, 3, H, W] 201 | pred: [N, L, p*p*3] 202 | mask: [N, L], 0 is keep, 1 is remove, 203 | """ 204 | target = self.patchify(imgs) 205 | if self.norm_pix_loss: 206 | mean = target.mean(dim=-1, keepdim=True) 207 | var = target.var(dim=-1, keepdim=True) 208 | target = (target - mean) / (var + 1.e-6)**.5 209 | 210 | loss = (pred - target) ** 2 211 | loss = loss.mean(dim=-1) # [N, L], mean loss per patch 212 | 213 | loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches 214 | return loss 215 | 216 | def forward(self, imgs, mask_ratio=0.75, **kwargs): 217 | with torch.cuda.amp.autocast(): 218 | latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio) 219 | pred = self.forward_decoder(latent, ids_restore) # [N, L, p*p*3] 220 | loss = self.forward_loss(imgs, pred, mask) 221 | # return loss, pred, mask 222 | return loss 223 | 224 | 225 | def mae_vit_small_patch16_dec512d8b(**kwargs): 226 | model = MaskedAutoencoderViT( 227 | patch_size=16, embed_dim=384, depth=12, num_heads=6, 228 | decoder_embed_dim=256, decoder_num_heads=8, 229 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 230 | return model 231 | 232 | 233 | def 
mae_vit_base_patch16_dec512d8b(**kwargs): 234 | model = MaskedAutoencoderViT( 235 | patch_size=16, embed_dim=768, depth=12, num_heads=12, 236 | decoder_embed_dim=512, decoder_num_heads=16, 237 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 238 | return model 239 | 240 | 241 | def mae_vit_large_patch16_dec512d8b(**kwargs): 242 | model = MaskedAutoencoderViT( 243 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, 244 | decoder_embed_dim=512, decoder_num_heads=16, 245 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 246 | return model 247 | 248 | 249 | def mae_vit_huge_patch14_dec512d8b(**kwargs): 250 | model = MaskedAutoencoderViT( 251 | patch_size=14, embed_dim=1280, depth=32, num_heads=16, 252 | decoder_embed_dim=512, decoder_num_heads=16, 253 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 254 | return model 255 | 256 | 257 | # set recommended archs 258 | mae_vit_small_patch16 = mae_vit_small_patch16_dec512d8b 259 | mae_vit_base_patch16 = mae_vit_base_patch16_dec512d8b # decoder: 512 dim, 8 blocks 260 | mae_vit_large_patch16 = mae_vit_large_patch16_dec512d8b # decoder: 512 dim, 8 blocks 261 | mae_vit_huge_patch14 = mae_vit_huge_patch14_dec512d8b # decoder: 512 dim, 8 blocks 262 | -------------------------------------------------------------------------------- /util/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | 12 | import builtins 13 | import datetime 14 | import os 15 | import time 16 | from collections import defaultdict, deque 17 | from pathlib import Path 18 | 19 | import torch 20 | import torch.distributed as dist 21 | if torch.__version__.startswith('2.'): 22 | from torch import inf 23 | else: 24 | from torch._six import inf 25 | 26 | 27 | class SmoothedValue(object): 28 | """Track a series of values and provide access to smoothed values over a 29 | window or the global series average. 30 | """ 31 | 32 | def __init__(self, window_size=20, fmt=None): 33 | if fmt is None: 34 | fmt = "{median:.4f} ({global_avg:.4f})" 35 | self.deque = deque(maxlen=window_size) 36 | self.total = 0.0 37 | self.count = 0 38 | self.fmt = fmt 39 | 40 | def update(self, value, n=1): 41 | self.deque.append(value) 42 | self.count += n 43 | self.total += value * n 44 | 45 | def synchronize_between_processes(self): 46 | """ 47 | Warning: does not synchronize the deque! 
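        Only count and total are all-reduced below, so global_avg becomes
        consistent across processes while median, avg, max, and value remain
        local to each rank.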
48 | """ 49 | if not is_dist_avail_and_initialized(): 50 | return 51 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 52 | dist.barrier() 53 | dist.all_reduce(t) 54 | t = t.tolist() 55 | self.count = int(t[0]) 56 | self.total = t[1] 57 | 58 | @property 59 | def median(self): 60 | d = torch.tensor(list(self.deque)) 61 | return d.median().item() 62 | 63 | @property 64 | def avg(self): 65 | d = torch.tensor(list(self.deque), dtype=torch.float32) 66 | return d.mean().item() 67 | 68 | @property 69 | def global_avg(self): 70 | return self.total / self.count 71 | 72 | @property 73 | def max(self): 74 | return max(self.deque) 75 | 76 | @property 77 | def value(self): 78 | return self.deque[-1] 79 | 80 | def __str__(self): 81 | return self.fmt.format( 82 | median=self.median, 83 | avg=self.avg, 84 | global_avg=self.global_avg, 85 | max=self.max, 86 | value=self.value) 87 | 88 | 89 | class MetricLogger(object): 90 | def __init__(self, delimiter="\t"): 91 | self.meters = defaultdict(SmoothedValue) 92 | self.delimiter = delimiter 93 | 94 | def update(self, **kwargs): 95 | for k, v in kwargs.items(): 96 | if v is None: 97 | continue 98 | if isinstance(v, torch.Tensor): 99 | v = v.item() 100 | assert isinstance(v, (float, int)) 101 | self.meters[k].update(v) 102 | 103 | def __getattr__(self, attr): 104 | if attr in self.meters: 105 | return self.meters[attr] 106 | if attr in self.__dict__: 107 | return self.__dict__[attr] 108 | raise AttributeError("'{}' object has no attribute '{}'".format( 109 | type(self).__name__, attr)) 110 | 111 | def __str__(self): 112 | loss_str = [] 113 | for name, meter in self.meters.items(): 114 | loss_str.append( 115 | "{}: {}".format(name, str(meter)) 116 | ) 117 | return self.delimiter.join(loss_str) 118 | 119 | def synchronize_between_processes(self): 120 | for meter in self.meters.values(): 121 | meter.synchronize_between_processes() 122 | 123 | def add_meter(self, name, meter): 124 | self.meters[name] = meter 125 | 126 | def log_every(self, iterable, print_freq, header=None): 127 | i = 0 128 | if not header: 129 | header = '' 130 | start_time = time.time() 131 | end = time.time() 132 | iter_time = SmoothedValue(fmt='{avg:.4f}') 133 | data_time = SmoothedValue(fmt='{avg:.4f}') 134 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 135 | log_msg = [ 136 | header, 137 | '[{0' + space_fmt + '}/{1}]', 138 | 'eta: {eta}', 139 | '{meters}', 140 | 'time: {time}', 141 | 'data: {data}' 142 | ] 143 | if torch.cuda.is_available(): 144 | log_msg.append('max mem: {memory:.0f}') 145 | log_msg = self.delimiter.join(log_msg) 146 | MB = 1024.0 * 1024.0 147 | for obj in iterable: 148 | data_time.update(time.time() - end) 149 | yield obj 150 | iter_time.update(time.time() - end) 151 | if i % print_freq == 0 or i == len(iterable) - 1: 152 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 153 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 154 | if torch.cuda.is_available(): 155 | print(log_msg.format( 156 | i, len(iterable), eta=eta_string, 157 | meters=str(self), 158 | time=str(iter_time), data=str(data_time), 159 | memory=torch.cuda.max_memory_allocated() / MB)) 160 | else: 161 | print(log_msg.format( 162 | i, len(iterable), eta=eta_string, 163 | meters=str(self), 164 | time=str(iter_time), data=str(data_time))) 165 | i += 1 166 | end = time.time() 167 | total_time = time.time() - start_time 168 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 169 | print('{} Total time: {} ({:.4f} s / it)'.format( 170 | 
header, total_time_str, total_time / len(iterable))) 171 | 172 | 173 | def setup_for_distributed(is_master): 174 | """ 175 | This function disables printing when not in master process 176 | """ 177 | builtin_print = builtins.print 178 | 179 | def print(*args, **kwargs): 180 | force = kwargs.pop('force', False) 181 | force = force or (get_world_size() > 8) 182 | if is_master or force: 183 | now = datetime.datetime.now().time() 184 | builtin_print('[{}] '.format(now), end='') # print with time stamp 185 | builtin_print(*args, **kwargs) 186 | 187 | builtins.print = print 188 | 189 | 190 | def is_dist_avail_and_initialized(): 191 | if not dist.is_available(): 192 | return False 193 | if not dist.is_initialized(): 194 | return False 195 | return True 196 | 197 | 198 | def get_world_size(): 199 | if not is_dist_avail_and_initialized(): 200 | return 1 201 | return dist.get_world_size() 202 | 203 | 204 | def get_rank(): 205 | if not is_dist_avail_and_initialized(): 206 | return 0 207 | return dist.get_rank() 208 | 209 | 210 | def is_main_process(): 211 | return get_rank() == 0 212 | 213 | 214 | def save_on_master(*args, **kwargs): 215 | if is_main_process(): 216 | torch.save(*args, **kwargs) 217 | 218 | 219 | def init_distributed_mode(args): 220 | if args.dist_on_itp: 221 | args.rank = int(os.environ['OMPI_COMM_WORLD_RANK']) 222 | args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) 223 | args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) 224 | args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT']) 225 | os.environ['LOCAL_RANK'] = str(args.gpu) 226 | os.environ['RANK'] = str(args.rank) 227 | os.environ['WORLD_SIZE'] = str(args.world_size) 228 | # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"] 229 | elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 230 | args.rank = int(os.environ["RANK"]) 231 | args.world_size = int(os.environ['WORLD_SIZE']) 232 | args.gpu = int(os.environ['LOCAL_RANK']) 233 | elif 'SLURM_PROCID' in os.environ: 234 | args.rank = int(os.environ['SLURM_PROCID']) 235 | args.gpu = args.rank % torch.cuda.device_count() 236 | else: 237 | print('Not using distributed mode') 238 | setup_for_distributed(is_master=True) # hack 239 | args.distributed = False 240 | return 241 | 242 | args.distributed = True 243 | 244 | torch.cuda.set_device(args.gpu) 245 | args.dist_backend = 'nccl' 246 | print('| distributed init (rank {}): {}, gpu {}'.format( 247 | args.rank, args.dist_url, args.gpu), flush=True) 248 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 249 | world_size=args.world_size, rank=args.rank) 250 | torch.distributed.barrier() 251 | setup_for_distributed(args.rank == 0) 252 | 253 | 254 | class NativeScalerWithGradNormCount: 255 | state_dict_key = "amp_scaler" 256 | 257 | def __init__(self): 258 | self._scaler = torch.cuda.amp.GradScaler() 259 | 260 | def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True): 261 | self._scaler.scale(loss).backward(create_graph=create_graph) 262 | if update_grad: 263 | if clip_grad is not None: 264 | assert parameters is not None 265 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place 266 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) 267 | else: 268 | self._scaler.unscale_(optimizer) 269 | norm = get_grad_norm_(parameters) 270 | self._scaler.step(optimizer) 271 | self._scaler.update() 272 | else: 273 | norm = None 274 | return 
norm 275 | 276 | def state_dict(self): 277 | return self._scaler.state_dict() 278 | 279 | def load_state_dict(self, state_dict): 280 | self._scaler.load_state_dict(state_dict) 281 | 282 | 283 | def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: 284 | if isinstance(parameters, torch.Tensor): 285 | parameters = [parameters] 286 | parameters = [p for p in parameters if p.grad is not None] 287 | norm_type = float(norm_type) 288 | if len(parameters) == 0: 289 | return torch.tensor(0.) 290 | device = parameters[0].grad.device 291 | if norm_type == inf: 292 | total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) 293 | else: 294 | total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type) 295 | return total_norm 296 | 297 | 298 | def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler, save_latest_model_only=False): 299 | output_dir = Path(args.output_dir) 300 | epoch_name = str(epoch) 301 | if loss_scaler is not None: 302 | if save_latest_model_only: 303 | checkpoint_paths = [output_dir / ('checkpoint.pth')] 304 | else: 305 | checkpoint_paths = [output_dir / ('checkpoint.pth'), output_dir / ('checkpoint-%s.pth' % epoch_name)] 306 | for checkpoint_path in checkpoint_paths: 307 | to_save = { 308 | 'model': model_without_ddp.state_dict(), 309 | 'optimizer': optimizer.state_dict(), 310 | 'epoch': epoch, 311 | 'scaler': loss_scaler.state_dict(), 312 | 'args': args, 313 | } 314 | 315 | save_on_master(to_save, checkpoint_path) 316 | else: 317 | client_state = {'epoch': epoch} 318 | model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint", client_state=client_state) 319 | 320 | 321 | def load_model(args, model_without_ddp, optimizer, loss_scaler): 322 | if args.resume: 323 | if args.resume.startswith('https'): 324 | checkpoint = torch.hub.load_state_dict_from_url( 325 | args.resume, map_location='cpu', check_hash=True) 326 | else: 327 | checkpoint = torch.load(args.resume, map_location='cpu') 328 | model_without_ddp.load_state_dict(checkpoint['model']) 329 | print("Resume checkpoint %s" % args.resume) 330 | if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval): 331 | optimizer.load_state_dict(checkpoint['optimizer']) 332 | args.start_epoch = checkpoint['epoch'] + 1 333 | if 'scaler' in checkpoint: 334 | loss_scaler.load_state_dict(checkpoint['scaler']) 335 | print("With optim & sched!") 336 | 337 | 338 | def all_reduce_mean(x): 339 | world_size = get_world_size() 340 | if world_size > 1: 341 | x_reduce = torch.tensor(x).cuda() 342 | dist.all_reduce(x_reduce) 343 | x_reduce /= world_size 344 | return x_reduce.item() 345 | else: 346 | return x -------------------------------------------------------------------------------- /main_pretrain.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
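Before the pre-training entry point below, a minimal, illustrative sketch of how the utilities defined above (MetricLogger, NativeScalerWithGradNormCount, and util.lr_sched.adjust_learning_rate) fit together in a single pre-training epoch. This is not the repository's actual loop (that lives in engine_pretrain.py, not shown here); model, data_loader, optimizer, device, epoch, and args are assumed placeholders, and gradient accumulation is omitted:

import util.misc as misc
import util.lr_sched as lr_sched
from util.misc import NativeScalerWithGradNormCount as NativeScaler

loss_scaler = NativeScaler()
metric_logger = misc.MetricLogger(delimiter="  ")

for step, (samples, _) in enumerate(metric_logger.log_every(data_loader, 20, header=f'Epoch: [{epoch}]')):
    # per-iteration half-cycle cosine schedule, indexed by fractional epoch
    lr = lr_sched.adjust_learning_rate(optimizer, step / len(data_loader) + epoch, args)
    samples = samples.to(device, non_blocking=True)
    loss = model(samples, mask_ratio=args.mask_ratio)  # MAE forward returns the loss directly
    # scale the loss, backprop, (optionally clip,) and step the optimizer
    loss_scaler(loss, optimizer, parameters=model.parameters(), update_grad=True)
    optimizer.zero_grad()
    metric_logger.update(loss=loss.item(), lr=lr)

metric_logger.synchronize_between_processes()
print("Averaged stats:", metric_logger)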
6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | import argparse 12 | import datetime 13 | import json 14 | import numpy as np 15 | import os 16 | import time 17 | from pathlib import Path 18 | 19 | import torch 20 | import torch.backends.cudnn as cudnn 21 | from torch.utils.tensorboard import SummaryWriter 22 | import torchvision.transforms as transforms 23 | import torchvision.datasets as datasets 24 | 25 | import timm 26 | from timm.data.loader import MultiEpochsDataLoader 27 | 28 | # assert timm.__version__ == "0.3.2" # version check 29 | import timm.optim.optim_factory as optim_factory 30 | 31 | import util.misc as misc 32 | from util.misc import NativeScalerWithGradNormCount as NativeScaler 33 | from transformer_utils import handle_flash_attn 34 | 35 | import models_mae 36 | import models_cross 37 | 38 | from engine_pretrain import train_one_epoch 39 | 40 | 41 | def get_args_parser(): 42 | parser = argparse.ArgumentParser('MAE pre-training', add_help=False) 43 | parser.add_argument('--batch_size', default=64, type=int, 44 | help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus') 45 | parser.add_argument('--epochs', default=400, type=int) 46 | parser.add_argument('--accum_iter', default=1, type=int, 47 | help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)') 48 | 49 | # Model parameters 50 | parser.add_argument('--model', default='mae_vit_large_patch16', type=str, metavar='MODEL', 51 | help='Name of model to train') 52 | 53 | parser.add_argument('--input_size', default=224, type=int, 54 | help='images input size') 55 | 56 | parser.add_argument('--decoder_depth', default=8, type=int, 57 | help='depth of decoder') 58 | 59 | parser.add_argument('--mask_ratio', default=0.75, type=float, 60 | help='Masking ratio (1 - percentage of remained patches).') 61 | 62 | parser.add_argument('--kept_mask_ratio', default=0.75, type=float, 63 | help='Amongst the all tokens, the percentage of the mask that are kept') 64 | parser.add_argument('--inverse_lr', action='store_true', default=False, help='Use inverse lr scheduler') 65 | parser.add_argument('--no_lr_scale', action='store_true', default=False, help='Do not scale lr by mask_ratio') 66 | 67 | parser.add_argument('--norm_pix_loss', action='store_true', 68 | help='Use (per-patch) normalized pixels as targets for computing loss') 69 | parser.set_defaults(norm_pix_loss=False) 70 | 71 | parser.add_argument( 72 | '--find_unused_parameters', action='store_true', 73 | help="distributed ddp find unused parameters") 74 | parser.set_defaults(find_unused_parameters=False) 75 | 76 | # Optimizer parameters 77 | parser.add_argument('--weight_decay', type=float, default=0.05, 78 | help='weight decay (default: 0.05)') 79 | 80 | parser.add_argument('--lr', type=float, default=None, metavar='LR', 81 | help='learning rate (absolute lr)') 82 | parser.add_argument('--blr', type=float, default=1e-3, metavar='LR', 83 | help='base learning rate: absolute_lr = base_lr * total_batch_size / 256') 84 | parser.add_argument('--min_lr', type=float, default=0., metavar='LR', 85 | help='lower lr bound for cyclic schedulers that hit 0') 86 | 87 | parser.add_argument('--warmup_epochs', type=int, default=40, metavar='N', 88 | help='epochs to warmup LR') 89 | 90 | # Dataset parameters 91 | 
parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str, 92 | help='dataset path') 93 | 94 | parser.add_argument('--output_dir', default='./output_dir', 95 | help='path where to save, empty for no saving') 96 | parser.add_argument('--log_dir', default=None, 97 | help='path where to tensorboard log') 98 | parser.add_argument('--device', default='cuda', 99 | help='device to use for training / testing') 100 | parser.add_argument('--seed', default=0, type=int) 101 | parser.add_argument('--resume', default='', 102 | help='resume from checkpoint') 103 | 104 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 105 | help='start epoch') 106 | parser.add_argument('--num_workers', default=10, type=int) 107 | parser.add_argument('--pin_mem', action='store_true', 108 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 109 | parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem') 110 | parser.set_defaults(pin_mem=True) 111 | 112 | parser.add_argument('--multi_epochs_dataloader', action='store_true', help='Use MultiEpochsDataLoader to prevent reinitializing dataloader per epoch') 113 | 114 | # distributed training parameters 115 | parser.add_argument('--world_size', default=1, type=int, 116 | help='number of distributed processes') 117 | parser.add_argument('--local_rank', default=-1, type=int) 118 | parser.add_argument('--dist_on_itp', action='store_true') 119 | parser.add_argument('--dist_url', default='env://', 120 | help='url used to set up distributed training') 121 | 122 | # MAE or cross-MAE 123 | parser.add_argument('--cross_mae', action='store_true', default=False) 124 | parser.add_argument('--weight_fm', action='store_true', default=False, 125 | help='Weight the feature maps for decoder when running cross-mae') 126 | parser.add_argument('--use_fm', nargs='+', type=int, default=[-1], 127 | help='Feature maps to use for decoder') 128 | parser.add_argument('--use_input', action='store_true', default=False, 129 | help="use input as a feature map") 130 | parser.add_argument('--self_attn', action='store_true', default=False, help="use self attention in decoder") 131 | 132 | parser.add_argument('--enable_flash_attention2', action='store_true', default=False, help="Use flash attntion 2") 133 | 134 | return parser 135 | 136 | 137 | def main(args): 138 | misc.init_distributed_mode(args) 139 | 140 | print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__)))) 141 | print("{}".format(args).replace(', ', ',\n')) 142 | 143 | device = torch.device(args.device) 144 | 145 | # fix the seed for reproducibility 146 | seed = args.seed + misc.get_rank() 147 | torch.manual_seed(seed) 148 | np.random.seed(seed) 149 | 150 | cudnn.benchmark = True 151 | 152 | handle_flash_attn(args) 153 | 154 | # simple augmentation 155 | transform_train = transforms.Compose([ 156 | transforms.RandomResizedCrop(args.input_size, scale=(0.2, 1.0), interpolation=3), # 3 is bicubic 157 | transforms.RandomHorizontalFlip(), 158 | transforms.ToTensor(), 159 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) 160 | dataset_train = datasets.ImageFolder(os.path.join(args.data_path, 'train'), transform=transform_train) 161 | print(dataset_train) 162 | 163 | if True: # args.distributed: 164 | num_tasks = misc.get_world_size() 165 | global_rank = misc.get_rank() 166 | sampler_train = torch.utils.data.DistributedSampler( 167 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 168 | 
) 169 | print("Sampler_train = %s" % str(sampler_train)) 170 | else: 171 | sampler_train = torch.utils.data.RandomSampler(dataset_train) 172 | 173 | if global_rank == 0 and args.log_dir is not None: 174 | os.makedirs(args.log_dir, exist_ok=True) 175 | log_writer = SummaryWriter(log_dir=args.log_dir) 176 | else: 177 | log_writer = None 178 | 179 | dataloader_cls = MultiEpochsDataLoader if args.multi_epochs_dataloader else torch.utils.data.DataLoader 180 | 181 | data_loader_train = dataloader_cls( 182 | dataset_train, sampler=sampler_train, 183 | batch_size=args.batch_size, 184 | num_workers=args.num_workers, 185 | pin_memory=args.pin_mem, 186 | drop_last=True, 187 | ) 188 | 189 | # define the model 190 | if args.cross_mae: 191 | model = models_cross.__dict__[args.model]( 192 | norm_pix_loss=args.norm_pix_loss, 193 | weight_fm=args.weight_fm, 194 | decoder_depth=args.decoder_depth, 195 | use_fm=args.use_fm, 196 | use_input=args.use_input, 197 | self_attn=args.self_attn, 198 | img_size=args.input_size, 199 | ) 200 | else: 201 | model = models_mae.__dict__[args.model]( 202 | norm_pix_loss=args.norm_pix_loss, 203 | decoder_depth=args.decoder_depth, 204 | ) 205 | 206 | model.to(device) 207 | 208 | model_without_ddp = model 209 | print("Model = %s" % str(model_without_ddp)) 210 | 211 | eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size() 212 | 213 | if args.lr is None: # only base_lr is specified 214 | base_ratio = args.kept_mask_ratio / args.mask_ratio # base ratio for MAE 215 | if args.no_lr_scale: 216 | scale_kmr = 1 217 | elif args.inverse_lr: 218 | scale_kmr = 1 / base_ratio 219 | else: 220 | scale_kmr = base_ratio 221 | args.lr = scale_kmr * args.blr * eff_batch_size / 256 222 | 223 | print("base lr: %.2e" % (args.lr * 256 / eff_batch_size)) 224 | print("actual lr: %.2e" % args.lr) 225 | 226 | print("accumulate grad iterations: %d" % args.accum_iter) 227 | print("effective batch size: %d" % eff_batch_size) 228 | 229 | if args.distributed: 230 | model = torch.nn.parallel.DistributedDataParallel( 231 | model, device_ids=[args.gpu], 232 | find_unused_parameters=args.find_unused_parameters 233 | ) 234 | model_without_ddp = model.module 235 | 236 | # following timm: set wd as 0 for bias and norm layers 237 | param_groups = optim_factory.param_groups_weight_decay(model_without_ddp, args.weight_decay) 238 | optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95)) 239 | print(optimizer) 240 | loss_scaler = NativeScaler() 241 | 242 | misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler) 243 | 244 | print(f"Start training for {args.epochs} epochs") 245 | start_time = time.time() 246 | for epoch in range(args.start_epoch, args.epochs): 247 | if args.distributed: 248 | data_loader_train.sampler.set_epoch(epoch) 249 | train_stats = train_one_epoch( 250 | model, data_loader_train, 251 | optimizer, device, epoch, loss_scaler, 252 | log_writer=log_writer, 253 | args=args 254 | ) 255 | if args.output_dir: 256 | if epoch % 200 == 0: 257 | misc.save_model( 258 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, 259 | loss_scaler=loss_scaler, epoch=epoch, save_latest_model_only=False) 260 | elif epoch % 20 == 0: 261 | misc.save_model( 262 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, 263 | loss_scaler=loss_scaler, epoch=epoch, save_latest_model_only=True) 264 | if epoch + 1 == args.epochs: 265 | misc.save_model( 266 | args=args, model=model, 
model_without_ddp=model_without_ddp, optimizer=optimizer, 267 | loss_scaler=loss_scaler, epoch=epoch, save_latest_model_only=True) 268 | 269 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 270 | 'epoch': epoch,} 271 | 272 | if args.output_dir and misc.is_main_process(): 273 | if log_writer is not None: 274 | log_writer.flush() 275 | with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f: 276 | f.write(json.dumps(log_stats) + "\n") 277 | 278 | total_time = time.time() - start_time 279 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 280 | print('Training time {}'.format(total_time_str)) 281 | 282 | 283 | if __name__ == '__main__': 284 | args = get_args_parser() 285 | args = args.parse_args() 286 | assert args.kept_mask_ratio <= args.mask_ratio, "Cannot reconstruct more than what is masked" 287 | if args.log_dir is None: 288 | args.log_dir = args.output_dir 289 | if args.output_dir: 290 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 291 | main(args) 292 | -------------------------------------------------------------------------------- /main_linprobe.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # MoCo v3: https://github.com/facebookresearch/moco-v3 10 | # -------------------------------------------------------- 11 | 12 | import argparse 13 | import datetime 14 | import json 15 | import numpy as np 16 | import os 17 | import time 18 | from pathlib import Path 19 | 20 | import torch 21 | import torch.backends.cudnn as cudnn 22 | from torch.utils.tensorboard import SummaryWriter 23 | import torchvision.transforms as transforms 24 | import torchvision.datasets as datasets 25 | 26 | import timm 27 | from timm.data.loader import MultiEpochsDataLoader 28 | 29 | # assert timm.__version__ == "0.3.2" # version check 30 | from timm.models.layers import trunc_normal_ 31 | 32 | import util.misc as misc 33 | from util.pos_embed import interpolate_pos_embed 34 | from util.misc import NativeScalerWithGradNormCount as NativeScaler 35 | from util.lars import LARS 36 | from util.crop import RandomResizedCrop 37 | 38 | from transformer_utils import handle_flash_attn 39 | 40 | import models_vit 41 | 42 | from engine_finetune import train_one_epoch, evaluate 43 | 44 | 45 | def get_args_parser(): 46 | parser = argparse.ArgumentParser('MAE linear probing for image classification', add_help=False) 47 | parser.add_argument('--batch_size', default=512, type=int, 48 | help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus') 49 | parser.add_argument('--epochs', default=90, type=int) 50 | parser.add_argument('--accum_iter', default=1, type=int, 51 | help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)') 52 | 53 | # Model parameters 54 | parser.add_argument('--model', default='vit_large_patch16', type=str, metavar='MODEL', 55 | help='Name of model to train') 56 | 57 | # Optimizer parameters 58 | parser.add_argument('--weight_decay', type=float, default=0, 59 | help='weight decay (default: 0 for linear probe following MoCo v1)') 60 | 61 | parser.add_argument('--lr', type=float, 
default=None, metavar='LR', 62 | help='learning rate (absolute lr)') 63 | parser.add_argument('--blr', type=float, default=0.1, metavar='LR', 64 | help='base learning rate: absolute_lr = base_lr * total_batch_size / 256') 65 | 66 | parser.add_argument('--min_lr', type=float, default=0., metavar='LR', 67 | help='lower lr bound for cyclic schedulers that hit 0') 68 | 69 | parser.add_argument('--warmup_epochs', type=int, default=10, metavar='N', 70 | help='epochs to warmup LR') 71 | 72 | # * Finetuning params 73 | parser.add_argument('--finetune', default='', 74 | help='finetune from checkpoint') 75 | parser.add_argument('--global_pool', action='store_true') 76 | parser.set_defaults(global_pool=False) 77 | parser.add_argument('--cls_token', action='store_false', dest='global_pool', 78 | help='Use class token instead of global pool for classification') 79 | 80 | # Dataset parameters 81 | parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str, 82 | help='dataset path') 83 | parser.add_argument('--nb_classes', default=1000, type=int, 84 | help='number of the classification types') 85 | 86 | parser.add_argument('--output_dir', default='./output_dir', 87 | help='path where to save, empty for no saving') 88 | parser.add_argument('--log_dir', default=None, 89 | help='path where to tensorboard log') 90 | parser.add_argument('--device', default='cuda', 91 | help='device to use for training / testing') 92 | parser.add_argument('--seed', default=0, type=int) 93 | parser.add_argument('--resume', default='', 94 | help='resume from checkpoint') 95 | 96 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 97 | help='start epoch') 98 | parser.add_argument('--eval', action='store_true', 99 | help='Perform evaluation only') 100 | parser.add_argument('--dist_eval', action='store_true', default=False, 101 | help='Enabling distributed evaluation (recommended during training for faster monitor') 102 | parser.add_argument('--num_workers', default=10, type=int) 103 | parser.add_argument('--pin_mem', action='store_true', 104 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 105 | parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem') 106 | parser.set_defaults(pin_mem=True) 107 | 108 | parser.add_argument('--multi_epochs_dataloader', action='store_true', help='Use MultiEpochsDataLoader to prevent reinitializing dataloader per epoch') 109 | 110 | # distributed training parameters 111 | parser.add_argument('--world_size', default=1, type=int, 112 | help='number of distributed processes') 113 | parser.add_argument('--local_rank', default=-1, type=int) 114 | parser.add_argument('--dist_on_itp', action='store_true') 115 | parser.add_argument('--dist_url', default='env://', 116 | help='url used to set up distributed training') 117 | 118 | parser.add_argument('--enable_flash_attention2', action='store_true', default=False, help="Use flash attntion 2") 119 | 120 | return parser 121 | 122 | 123 | def main(args): 124 | misc.init_distributed_mode(args) 125 | 126 | print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__)))) 127 | print("{}".format(args).replace(', ', ',\n')) 128 | 129 | device = torch.device(args.device) 130 | 131 | # fix the seed for reproducibility 132 | seed = args.seed + misc.get_rank() 133 | torch.manual_seed(seed) 134 | np.random.seed(seed) 135 | 136 | cudnn.benchmark = True 137 | 138 | handle_flash_attn(args) 139 | 140 | # linear probe: weak augmentation 141 | transform_train = 
transforms.Compose([ 142 | RandomResizedCrop(224, interpolation=3), 143 | transforms.RandomHorizontalFlip(), 144 | transforms.ToTensor(), 145 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) 146 | transform_val = transforms.Compose([ 147 | transforms.Resize(256, interpolation=3), 148 | transforms.CenterCrop(224), 149 | transforms.ToTensor(), 150 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) 151 | dataset_train = datasets.ImageFolder(os.path.join(args.data_path, 'train'), transform=transform_train) 152 | dataset_val = datasets.ImageFolder(os.path.join(args.data_path, 'val'), transform=transform_val) 153 | print(dataset_train) 154 | print(dataset_val) 155 | 156 | if True: # args.distributed: 157 | num_tasks = misc.get_world_size() 158 | global_rank = misc.get_rank() 159 | sampler_train = torch.utils.data.DistributedSampler( 160 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 161 | ) 162 | print("Sampler_train = %s" % str(sampler_train)) 163 | if args.dist_eval: 164 | if len(dataset_val) % num_tasks != 0: 165 | print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. ' 166 | 'This will slightly alter validation results as extra duplicate entries are added to achieve ' 167 | 'equal num of samples per-process.') 168 | sampler_val = torch.utils.data.DistributedSampler( 169 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=True) # shuffle=True to reduce monitor bias 170 | else: 171 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 172 | else: 173 | sampler_train = torch.utils.data.RandomSampler(dataset_train) 174 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 175 | 176 | if global_rank == 0 and args.log_dir is not None and not args.eval: 177 | os.makedirs(args.log_dir, exist_ok=True) 178 | log_writer = SummaryWriter(log_dir=args.log_dir) 179 | else: 180 | log_writer = None 181 | 182 | dataloader_cls = MultiEpochsDataLoader if args.multi_epochs_dataloader else torch.utils.data.DataLoader 183 | 184 | data_loader_train = dataloader_cls( 185 | dataset_train, sampler=sampler_train, 186 | batch_size=args.batch_size, 187 | num_workers=args.num_workers, 188 | pin_memory=args.pin_mem, 189 | drop_last=True, 190 | ) 191 | 192 | data_loader_val = torch.utils.data.DataLoader( 193 | dataset_val, sampler=sampler_val, 194 | batch_size=args.batch_size, 195 | num_workers=args.num_workers, 196 | pin_memory=args.pin_mem, 197 | drop_last=False 198 | ) 199 | 200 | model = models_vit.__dict__[args.model]( 201 | num_classes=args.nb_classes, 202 | global_pool=args.global_pool, 203 | ) 204 | 205 | if args.finetune and not args.eval: 206 | checkpoint = torch.load(args.finetune, map_location='cpu') 207 | 208 | print("Load pre-trained checkpoint from: %s" % args.finetune) 209 | checkpoint_model = checkpoint['model'] 210 | state_dict = model.state_dict() 211 | for k in ['head.weight', 'head.bias']: 212 | if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape: 213 | print(f"Removing key {k} from pretrained checkpoint") 214 | del checkpoint_model[k] 215 | 216 | # interpolate position embedding 217 | interpolate_pos_embed(model, checkpoint_model) 218 | 219 | # load pre-trained model 220 | msg = model.load_state_dict(checkpoint_model, strict=False) 221 | print(msg) 222 | 223 | if args.global_pool: 224 | assert set(msg.missing_keys) == {'head.weight', 'head.bias', 'fc_norm.weight', 'fc_norm.bias'} 225 | else: 226 | assert 
set(msg.missing_keys) == {'head.weight', 'head.bias'} 227 | 228 | # manually initialize fc layer: following MoCo v3 229 | trunc_normal_(model.head.weight, std=0.01) 230 | 231 | # for linear prob only 232 | # hack: revise model's head with BN 233 | model.head = torch.nn.Sequential(torch.nn.BatchNorm1d(model.head.in_features, affine=False, eps=1e-6), model.head) 234 | # freeze all but the head 235 | for _, p in model.named_parameters(): 236 | p.requires_grad = False 237 | for _, p in model.head.named_parameters(): 238 | p.requires_grad = True 239 | 240 | model.to(device) 241 | 242 | model_without_ddp = model 243 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 244 | 245 | print("Model = %s" % str(model_without_ddp)) 246 | print('number of params (M): %.2f' % (n_parameters / 1.e6)) 247 | 248 | eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size() 249 | 250 | if args.lr is None: # only base_lr is specified 251 | args.lr = args.blr * eff_batch_size / 256 252 | 253 | print("base lr: %.2e" % (args.lr * 256 / eff_batch_size)) 254 | print("actual lr: %.2e" % args.lr) 255 | 256 | print("accumulate grad iterations: %d" % args.accum_iter) 257 | print("effective batch size: %d" % eff_batch_size) 258 | 259 | if args.distributed: 260 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 261 | model_without_ddp = model.module 262 | 263 | optimizer = LARS(model_without_ddp.head.parameters(), lr=args.lr, weight_decay=args.weight_decay) 264 | print(optimizer) 265 | loss_scaler = NativeScaler() 266 | 267 | criterion = torch.nn.CrossEntropyLoss() 268 | 269 | print("criterion = %s" % str(criterion)) 270 | 271 | misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler) 272 | 273 | if args.eval: 274 | test_stats = evaluate(data_loader_val, model, device) 275 | print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") 276 | exit(0) 277 | 278 | print(f"Start training for {args.epochs} epochs") 279 | start_time = time.time() 280 | max_accuracy = 0.0 281 | for epoch in range(args.start_epoch, args.epochs): 282 | if args.distributed: 283 | data_loader_train.sampler.set_epoch(epoch) 284 | train_stats = train_one_epoch( 285 | model, criterion, data_loader_train, 286 | optimizer, device, epoch, loss_scaler, 287 | max_norm=None, 288 | log_writer=log_writer, 289 | args=args 290 | ) 291 | 292 | test_stats = evaluate(data_loader_val, model, device) 293 | print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") 294 | 295 | if args.output_dir and test_stats['acc1'] >= max_accuracy: 296 | misc.save_model( 297 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, 298 | loss_scaler=loss_scaler, epoch=epoch, save_latest_model_only=True) 299 | 300 | max_accuracy = max(max_accuracy, test_stats["acc1"]) 301 | print(f'Max accuracy: {max_accuracy:.2f}%') 302 | 303 | if log_writer is not None: 304 | log_writer.add_scalar('perf/test_acc1', test_stats['acc1'], epoch) 305 | log_writer.add_scalar('perf/test_acc5', test_stats['acc5'], epoch) 306 | log_writer.add_scalar('perf/test_loss', test_stats['loss'], epoch) 307 | 308 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 309 | **{f'test_{k}': v for k, v in test_stats.items()}, 310 | 'epoch': epoch, 311 | 'n_parameters': n_parameters} 312 | 313 | if args.output_dir and misc.is_main_process(): 314 | if log_writer is not None: 315 | 
log_writer.flush() 316 | with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f: 317 | f.write(json.dumps(log_stats) + "\n") 318 | 319 | total_time = time.time() - start_time 320 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 321 | print('Training time {}'.format(total_time_str)) 322 | 323 | 324 | if __name__ == '__main__': 325 | args = get_args_parser() 326 | args = args.parse_args() 327 | if args.log_dir is None: 328 | args.log_dir = args.output_dir 329 | if args.output_dir: 330 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 331 | main(args) 332 | -------------------------------------------------------------------------------- /models_cross.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm 9 | # DeiT: https://github.com/facebookresearch/deit 10 | # -------------------------------------------------------- 11 | 12 | from functools import partial 13 | 14 | import numpy as np 15 | import math 16 | import torch 17 | import torch.nn as nn 18 | 19 | from transformer_utils import Block, CrossAttentionBlock, PatchEmbed 20 | 21 | from util.pos_embed import get_2d_sincos_pos_embed 22 | 23 | class WeightedFeatureMaps(nn.Module): 24 | def __init__(self, k, embed_dim, *, norm_layer=nn.LayerNorm, decoder_depth): 25 | super(WeightedFeatureMaps, self).__init__() 26 | self.linear = nn.Linear(k, decoder_depth, bias=False) 27 | 28 | std_dev = 1. 
/ math.sqrt(k) 29 | nn.init.normal_(self.linear.weight, mean=0., std=std_dev) 30 | 31 | def forward(self, feature_maps): 32 | # Ensure the input is a list 33 | assert isinstance(feature_maps, list), "Input should be a list of feature maps" 34 | # Ensure the list has the same length as the number of weights 35 | assert len(feature_maps) == (self.linear.weight.shape[1]), "Number of feature maps and weights should match" 36 | stacked_feature_maps = torch.stack(feature_maps, dim=-1) # shape: (B, L, C, k) 37 | # compute a weighted average of the feature maps 38 | # decoder_depth is denoted as j 39 | output = self.linear(stacked_feature_maps) 40 | return output 41 | 42 | class MaskedAutoencoderViT(nn.Module): 43 | """ Masked Autoencoder with VisionTransformer backbone 44 | """ 45 | def __init__(self, img_size=224, patch_size=16, in_chans=3, 46 | embed_dim=1024, depth=24, num_heads=16, 47 | decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, 48 | mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False, 49 | weight_fm=False, 50 | use_fm=[-1], use_input=False, self_attn=False, 51 | ): 52 | super().__init__() 53 | 54 | # -------------------------------------------------------------------------- 55 | # MAE encoder specifics 56 | self.img_size = img_size 57 | self.patch_size = patch_size 58 | self.embed_dim = embed_dim 59 | self.decoder_embed_dim = decoder_embed_dim 60 | 61 | self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim) # these are needed regardless of the patch sampling method 62 | num_patches = self.patch_embed.num_patches 63 | 64 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 65 | 66 | self.blocks = nn.ModuleList([ 67 | Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer) 68 | for i in range(depth)]) 69 | # -------------------------------------------------------------------------- 70 | # weighted feature maps for cross attention 71 | self.weight_fm = weight_fm 72 | self.use_input = use_input # use input as one of the feature maps 73 | if len(use_fm) == 1 and use_fm[0] == -1: 74 | self.use_fm = list(range(depth)) 75 | else: 76 | self.use_fm = [i if i >= 0 else depth + i for i in use_fm] 77 | if self.weight_fm: 78 | # print("Weighting feature maps!") 79 | # print("using feature maps: ", self.use_fm) 80 | dec_norms = [] 81 | for i in range(decoder_depth): 82 | norm_layer_i = norm_layer(embed_dim) 83 | dec_norms.append(norm_layer_i) 84 | self.dec_norms = nn.ModuleList(dec_norms) 85 | 86 | # feature weighting 87 | self.wfm = WeightedFeatureMaps(len(self.use_fm) + (1 if self.use_input else 0), embed_dim, norm_layer=norm_layer, decoder_depth=decoder_depth) 88 | 89 | # -------------------------------------------------------------------------- 90 | # MAE decoder specifics 91 | self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim)) 92 | print("use self attention: ", self_attn) 93 | self.decoder_blocks = nn.ModuleList([ 94 | CrossAttentionBlock(embed_dim, decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer, self_attn=self_attn) 95 | for i in range(decoder_depth)]) 96 | 97 | self.decoder_norm = norm_layer(decoder_embed_dim) 98 | self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True) # decoder to patch 99 | # -------------------------------------------------------------------------- 100 | # Dealing with positional embedding, patch sampling 101 | # encoder 102 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, 
embed_dim), requires_grad=False) # fixed sin-cos embedding 103 | # decoder 104 | self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False) # fixed sin-cos embedding 105 | # -------------------------------------------------------------------------- 106 | 107 | self.norm_pix_loss = norm_pix_loss 108 | 109 | self.initialize_weights() 110 | 111 | def initialize_weights(self): 112 | # initialization 113 | # initialize (and freeze) pos_embed by sin-cos embedding 114 | pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True) 115 | self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) 116 | 117 | decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True) 118 | self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0)) 119 | 120 | # initialize patch_embed like nn.Linear (instead of nn.Conv2d) 121 | w = self.patch_embed.proj.weight.data 122 | torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) 123 | 124 | # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.) 125 | torch.nn.init.normal_(self.cls_token, std=.02) 126 | torch.nn.init.normal_(self.mask_token, std=.02) 127 | 128 | # initialize nn.Linear and nn.LayerNorm 129 | self.apply(self._init_weights) 130 | 131 | def _init_weights(self, m): 132 | if isinstance(m, nn.Linear): 133 | # we use xavier_uniform following official JAX ViT: 134 | torch.nn.init.xavier_uniform_(m.weight) 135 | if isinstance(m, nn.Linear) and m.bias is not None: 136 | nn.init.constant_(m.bias, 0) 137 | elif isinstance(m, nn.LayerNorm): 138 | nn.init.constant_(m.bias, 0) 139 | nn.init.constant_(m.weight, 1.0) 140 | 141 | def patchify(self, imgs): 142 | """ 143 | imgs: (N, 3, H, W) 144 | x: (N, L, patch_size**2 *3) 145 | """ 146 | p = self.patch_embed.patch_size[0] 147 | assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0 148 | 149 | h = w = imgs.shape[2] // p 150 | x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p)) 151 | x = torch.einsum('nchpwq->nhwpqc', x) 152 | x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3)) 153 | return x 154 | 155 | def unpatchify(self, x): 156 | """ 157 | x: (N, L, patch_size**2 *3) 158 | imgs: (N, 3, H, W) 159 | """ 160 | p = self.patch_embed.patch_size[0] 161 | h = w = int(x.shape[1]**.5) 162 | assert h * w == x.shape[1] 163 | 164 | x = x.reshape(shape=(x.shape[0], h, w, p, p, 3)) 165 | x = torch.einsum('nhwpqc->nchpwq', x) 166 | imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p)) 167 | return imgs 168 | 169 | def random_masking(self, x, mask_ratio, kept_mask_ratio): 170 | """ 171 | Perform per-sample random masking by per-sample shuffling. 172 | Per-sample shuffling is done by argsort random noise. 
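        For example, with L = 196 patches, mask_ratio=0.75 and
        kept_mask_ratio=0.25: 49 patches stay visible to the encoder, 98 of
        the masked patches are dropped from the loss entirely, and the
        remaining 49 (mask == 1) become the decoder's reconstruction targets.
        With the default kept_mask_ratio == mask_ratio, every masked patch is
        reconstructed, as in the original MAE.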
173 | x: [N, L, D], sequence 174 | """ 175 | N, L, D = x.shape # batch, length, dim 176 | len_keep = int(L * (1 - mask_ratio)) 177 | len_masked = int(L * (mask_ratio - kept_mask_ratio)) 178 | 179 | noise = torch.rand(N, L, device=x.device) # noise in [0, 1] 180 | 181 | # sort noise for each sample 182 | ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove 183 | ids_restore = torch.argsort(ids_shuffle, dim=1) 184 | 185 | # keep the first subset 186 | ids_keep = ids_shuffle[:, :len_keep] 187 | x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) 188 | 189 | # generate the binary mask: 0 is keep, 1 is remove 190 | mask = torch.ones([N, L], device=x.device) 191 | mask[:, :(len_keep + len_masked)] = 0 192 | # unshuffle to get the binary mask 193 | mask = torch.gather(mask, dim=1, index=ids_restore) 194 | 195 | return x_masked, mask, ids_restore 196 | 197 | def grid_patchify(self, x): 198 | # embed patches 199 | x = self.patch_embed(x) 200 | 201 | # add pos embed w/o cls token 202 | x = x + self.pos_embed[:, 1:, :] 203 | return x 204 | 205 | def forward_encoder(self, x, mask_ratio, kept_mask_ratio): 206 | x = self.grid_patchify(x) 207 | coords = None 208 | 209 | # masking: length -> length * mask_ratio 210 | x, mask, ids_restore = self.random_masking(x, mask_ratio, kept_mask_ratio) 211 | 212 | # append cls token 213 | # cls_token = self.cls_token + self.pos_embed[:, :1, :] # pos embed for cls token is 0 214 | cls_tokens = self.cls_token.expand(x.shape[0], -1, -1) 215 | x = torch.cat((cls_tokens, x), dim=1) 216 | 217 | # apply Transformer blocks 218 | x_feats = [] 219 | if self.use_input: 220 | x_feats.append(x) 221 | for idx, blk in enumerate(self.blocks): 222 | x = blk(x) 223 | if self.weight_fm and idx in self.use_fm: 224 | x_feats.append(x) 225 | 226 | if self.weight_fm: 227 | return x_feats, mask, ids_restore, coords 228 | else: 229 | x = self.norm(x) 230 | return x, mask, ids_restore, coords 231 | 232 | def mask_tokens_grid(self, mask, ids_restore): 233 | N, L = ids_restore.shape 234 | 235 | # contruct mask tokens 236 | x = self.decoder_pos_embed[:, 1:].masked_select(mask.bool().unsqueeze(-1)).reshape(N, -1, self.mask_token.shape[-1]) 237 | x = x + self.mask_token 238 | return x 239 | 240 | def forward_decoder(self, y, mask, ids_restore, coords, mask_ratio, kept_mask_ratio): 241 | x = self.mask_tokens_grid(mask, ids_restore) 242 | 243 | if self.weight_fm: 244 | # y input: a list of Tensors (B, C, D) 245 | y = self.wfm(y) 246 | 247 | for i, blk in enumerate(self.decoder_blocks): 248 | if self.weight_fm: 249 | x = blk(x, self.dec_norms[i](y[..., i])) 250 | else: 251 | x = blk(x, y) 252 | 253 | x = self.decoder_norm(x) 254 | x = self.decoder_pred(x) # N, L, patch_size**2 *3 255 | 256 | return x, None 257 | 258 | def forward_loss(self, imgs, pred, mask, coords): 259 | """ 260 | imgs: [N, 3, H, W] 261 | pred: [N, L, p*p*3] 262 | mask: [N, L], 0 is keep, 1 is remove, 263 | """ 264 | target = self.patchify(imgs) 265 | target = target.masked_select(mask.bool().unsqueeze(-1)).reshape(target.shape[0], -1, target.shape[-1]) 266 | if self.norm_pix_loss: 267 | mean = target.mean(dim=-1, keepdim=True) 268 | var = target.var(dim=-1, keepdim=True) 269 | target = (target - mean) / (var + 1.e-6)**.5 270 | 271 | loss = (pred - target) ** 2 272 | loss = loss.mean() 273 | return loss, target 274 | 275 | def forward(self, imgs, mask_ratio=0.75, kept_mask_ratio=0.75, vis=False): 276 | with torch.cuda.amp.autocast(): 277 | latent, mask, ids_restore, coords = 
self.forward_encoder(imgs, mask_ratio, kept_mask_ratio) 278 | pred, combined = self.forward_decoder(latent, mask, ids_restore, coords, mask_ratio, kept_mask_ratio) # [N, L, p*p*3] 279 | loss, target = self.forward_loss(imgs, pred, mask, coords) 280 | if vis: 281 | # assumes mask ratio is the same as kept_mask_ratio for visualizations 282 | assert mask_ratio == kept_mask_ratio, "mask_ratio needs to be the same as kept_mask_ratio for visualizations. Otherwise we have unpredicted patches." 283 | # create some zero tensors 284 | with torch.no_grad(): 285 | N, L = mask.shape[0], mask.shape[1] 286 | 287 | combined = torch.zeros(N, L, pred.shape[2], device=pred.device, dtype=pred.dtype) 288 | combined[mask.bool()] = pred.view(-1, pred.shape[2]) 289 | pred_combined = combined 290 | 291 | combined = torch.zeros(N, L, pred.shape[2], device=pred.device, dtype=pred.dtype) 292 | combined[mask.bool()] = target.view(-1, target.shape[2]) 293 | target_combined = combined 294 | 295 | return loss, pred_combined, target_combined, mask 296 | else: 297 | return loss 298 | 299 | 300 | def mae_vit_small_patch16_dec512d8b(**kwargs): 301 | model = MaskedAutoencoderViT( 302 | patch_size=16, embed_dim=384, depth=12, num_heads=6, 303 | decoder_embed_dim=256, decoder_num_heads=8, 304 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 305 | return model 306 | 307 | 308 | def mae_vit_base_patch16_dec512d8b(**kwargs): 309 | model = MaskedAutoencoderViT( 310 | patch_size=16, embed_dim=768, depth=12, num_heads=12, 311 | decoder_embed_dim=512, decoder_num_heads=16, 312 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 313 | return model 314 | 315 | 316 | def mae_vit_large_patch16_dec512d8b(**kwargs): 317 | model = MaskedAutoencoderViT( 318 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, 319 | decoder_embed_dim=512, decoder_num_heads=16, 320 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 321 | return model 322 | 323 | 324 | def mae_vit_huge_patch14_dec512d8b(**kwargs): 325 | model = MaskedAutoencoderViT( 326 | patch_size=14, embed_dim=1280, depth=32, num_heads=16, 327 | decoder_embed_dim=512, decoder_num_heads=16, 328 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 329 | return model 330 | 331 | 332 | # set recommended archs 333 | mae_vit_small_patch16 = mae_vit_small_patch16_dec512d8b 334 | mae_vit_base_patch16 = mae_vit_base_patch16_dec512d8b # decoder: 512 dim, 8 blocks 335 | mae_vit_large_patch16 = mae_vit_large_patch16_dec512d8b # decoder: 512 dim, 8 blocks 336 | mae_vit_huge_patch14 = mae_vit_huge_patch14_dec512d8b # decoder: 512 dim, 8 blocks 337 | -------------------------------------------------------------------------------- /main_finetune.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | 12 | import argparse 13 | import datetime 14 | import json 15 | import numpy as np 16 | import os 17 | import time 18 | from pathlib import Path 19 | 20 | import torch 21 | import torch.backends.cudnn as cudnn 22 | from torch.utils.tensorboard import SummaryWriter 23 | 24 | import timm 25 | from timm.data.loader import MultiEpochsDataLoader 26 | 27 | # assert timm.__version__ == "0.3.2" # version check 28 | from timm.models.layers import trunc_normal_ 29 | from timm.data.mixup import Mixup 30 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy 31 | 32 | import util.lr_decay as lrd 33 | import util.misc as misc 34 | from util.datasets import build_dataset 35 | from util.pos_embed import interpolate_pos_embed 36 | from util.misc import NativeScalerWithGradNormCount as NativeScaler 37 | 38 | from transformer_utils import handle_flash_attn 39 | 40 | import models_vit 41 | 42 | from engine_finetune import train_one_epoch, evaluate 43 | 44 | 45 | def get_args_parser(): 46 | parser = argparse.ArgumentParser('MAE fine-tuning for image classification', add_help=False) 47 | parser.add_argument('--batch_size', default=64, type=int, 48 | help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus') 49 | parser.add_argument('--epochs', default=50, type=int) 50 | parser.add_argument('--accum_iter', default=1, type=int, 51 | help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)') 52 | 53 | # Model parameters 54 | parser.add_argument('--model', default='vit_large_patch16', type=str, metavar='MODEL', 55 | help='Name of model to train') 56 | 57 | parser.add_argument('--input_size', default=224, type=int, 58 | help='images input size') 59 | 60 | parser.add_argument('--drop_path', type=float, default=0.1, metavar='PCT', 61 | help='Drop path rate (default: 0.1)') 62 | 63 | # Optimizer parameters 64 | parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM', 65 | help='Clip gradient norm (default: None, no clipping)') 66 | parser.add_argument('--weight_decay', type=float, default=0.05, 67 | help='weight decay (default: 0.05)') 68 | 69 | parser.add_argument('--lr', type=float, default=None, metavar='LR', 70 | help='learning rate (absolute lr)') 71 | parser.add_argument('--blr', type=float, default=1e-3, metavar='LR', 72 | help='base learning rate: absolute_lr = base_lr * total_batch_size / 256') 73 | parser.add_argument('--layer_decay', type=float, default=0.75, 74 | help='layer-wise lr decay from ELECTRA/BEiT') 75 | 76 | parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR', 77 | help='lower lr bound for cyclic schedulers that hit 0') 78 | 79 | parser.add_argument('--warmup_epochs', type=int, default=5, metavar='N', 80 | help='epochs to warmup LR') 81 | 82 | # Augmentation parameters 83 | parser.add_argument('--color_jitter', type=float, default=None, metavar='PCT', 84 | help='Color jitter factor (enabled only when not using Auto/RandAug)') 85 | parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME', 86 | help='Use AutoAugment policy. "v0" or "original". 
" + "(default: rand-m9-mstd0.5-inc1)'), 87 | parser.add_argument('--smoothing', type=float, default=0.1, 88 | help='Label smoothing (default: 0.1)') 89 | 90 | # * Random Erase params 91 | parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT', 92 | help='Random erase prob (default: 0.25)') 93 | parser.add_argument('--remode', type=str, default='pixel', 94 | help='Random erase mode (default: "pixel")') 95 | parser.add_argument('--recount', type=int, default=1, 96 | help='Random erase count (default: 1)') 97 | parser.add_argument('--resplit', action='store_true', default=False, 98 | help='Do not random erase first (clean) augmentation split') 99 | 100 | # * Mixup params 101 | parser.add_argument('--mixup', type=float, default=0, 102 | help='mixup alpha, mixup enabled if > 0.') 103 | parser.add_argument('--cutmix', type=float, default=0, 104 | help='cutmix alpha, cutmix enabled if > 0.') 105 | parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None, 106 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)') 107 | parser.add_argument('--mixup_prob', type=float, default=1.0, 108 | help='Probability of performing mixup or cutmix when either/both is enabled') 109 | parser.add_argument('--mixup_switch_prob', type=float, default=0.5, 110 | help='Probability of switching to cutmix when both mixup and cutmix enabled') 111 | parser.add_argument('--mixup_mode', type=str, default='batch', 112 | help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"') 113 | 114 | # * Finetuning params 115 | parser.add_argument('--finetune', default='', 116 | help='finetune from checkpoint') 117 | parser.add_argument('--global_pool', action='store_true') 118 | parser.set_defaults(global_pool=True) 119 | parser.add_argument('--cls_token', action='store_false', dest='global_pool', 120 | help='Use class token instead of global pool for classification') 121 | 122 | # Dataset parameters 123 | parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str, 124 | help='dataset path') 125 | parser.add_argument('--nb_classes', default=1000, type=int, 126 | help='number of the classification types') 127 | 128 | parser.add_argument('--output_dir', default='./output_dir', 129 | help='path where to save, empty for no saving') 130 | parser.add_argument('--log_dir', default=None, 131 | help='path where to tensorboard log') 132 | parser.add_argument('--device', default='cuda', 133 | help='device to use for training / testing') 134 | parser.add_argument('--seed', default=0, type=int) 135 | parser.add_argument('--resume', default='', 136 | help='resume from checkpoint') 137 | 138 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 139 | help='start epoch') 140 | parser.add_argument('--eval', action='store_true', 141 | help='Perform evaluation only') 142 | parser.add_argument('--dist_eval', action='store_true', default=False, 143 | help='Enabling distributed evaluation (recommended during training for faster monitor') 144 | parser.add_argument('--num_workers', default=10, type=int) 145 | parser.add_argument('--pin_mem', action='store_true', 146 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 147 | parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem') 148 | parser.set_defaults(pin_mem=True) 149 | 150 | parser.add_argument('--multi_epochs_dataloader', action='store_true', help='Use MultiEpochsDataLoader to prevent reinitializing dataloader per 
epoch') 151 | 152 | # distributed training parameters 153 | parser.add_argument('--world_size', default=1, type=int, 154 | help='number of distributed processes') 155 | parser.add_argument('--local_rank', default=-1, type=int) 156 | parser.add_argument('--dist_on_itp', action='store_true') 157 | parser.add_argument('--dist_url', default='env://', 158 | help='url used to set up distributed training') 159 | 160 | parser.add_argument('--enable_flash_attention2', action='store_true', default=False, help="Use flash attntion 2") 161 | 162 | return parser 163 | 164 | 165 | def main(args): 166 | misc.init_distributed_mode(args) 167 | 168 | print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__)))) 169 | print("{}".format(args).replace(', ', ',\n')) 170 | 171 | device = torch.device(args.device) 172 | 173 | # fix the seed for reproducibility 174 | seed = args.seed + misc.get_rank() 175 | torch.manual_seed(seed) 176 | np.random.seed(seed) 177 | 178 | cudnn.benchmark = True 179 | 180 | handle_flash_attn(args) 181 | 182 | dataset_train = build_dataset(is_train=True, args=args) 183 | dataset_val = build_dataset(is_train=False, args=args) 184 | 185 | if True: # args.distributed: 186 | num_tasks = misc.get_world_size() 187 | global_rank = misc.get_rank() 188 | sampler_train = torch.utils.data.DistributedSampler( 189 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 190 | ) 191 | print("Sampler_train = %s" % str(sampler_train)) 192 | if args.dist_eval: 193 | if len(dataset_val) % num_tasks != 0: 194 | print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. ' 195 | 'This will slightly alter validation results as extra duplicate entries are added to achieve ' 196 | 'equal num of samples per-process.') 197 | sampler_val = torch.utils.data.DistributedSampler( 198 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=True) # shuffle=True to reduce monitor bias 199 | else: 200 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 201 | else: 202 | sampler_train = torch.utils.data.RandomSampler(dataset_train) 203 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 204 | 205 | if global_rank == 0 and args.log_dir is not None and not args.eval: 206 | os.makedirs(args.log_dir, exist_ok=True) 207 | log_writer = SummaryWriter(log_dir=args.log_dir) 208 | else: 209 | log_writer = None 210 | 211 | dataloader_cls = MultiEpochsDataLoader if args.multi_epochs_dataloader else torch.utils.data.DataLoader 212 | 213 | data_loader_train = dataloader_cls( 214 | dataset_train, sampler=sampler_train, 215 | batch_size=args.batch_size, 216 | num_workers=args.num_workers, 217 | pin_memory=args.pin_mem, 218 | drop_last=True, 219 | ) 220 | 221 | data_loader_val = torch.utils.data.DataLoader( 222 | dataset_val, sampler=sampler_val, 223 | batch_size=args.batch_size, 224 | num_workers=args.num_workers, 225 | pin_memory=args.pin_mem, 226 | drop_last=False 227 | ) 228 | 229 | mixup_fn = None 230 | mixup_active = args.mixup > 0 or args.cutmix > 0. 
or args.cutmix_minmax is not None 231 | if mixup_active: 232 | print("Mixup is activated!") 233 | mixup_fn = Mixup( 234 | mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, 235 | prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, 236 | label_smoothing=args.smoothing, num_classes=args.nb_classes) 237 | 238 | model = models_vit.__dict__[args.model]( 239 | num_classes=args.nb_classes, 240 | drop_path_rate=args.drop_path, 241 | global_pool=args.global_pool, 242 | img_size=args.input_size, 243 | ) 244 | 245 | if args.finetune and not args.eval: 246 | checkpoint = torch.load(args.finetune, map_location='cpu') 247 | 248 | print("Load pre-trained checkpoint from: %s" % args.finetune) 249 | checkpoint_model = checkpoint['model'] 250 | state_dict = model.state_dict() 251 | for k in ['head.weight', 'head.bias']: 252 | if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape: 253 | print(f"Removing key {k} from pretrained checkpoint") 254 | del checkpoint_model[k] 255 | 256 | # interpolate position embedding 257 | interpolate_pos_embed(model, checkpoint_model) 258 | 259 | # load pre-trained model 260 | msg = model.load_state_dict(checkpoint_model, strict=False) 261 | print(msg) 262 | 263 | if args.global_pool: 264 | assert set(msg.missing_keys) == {'head.weight', 'head.bias', 'fc_norm.weight', 'fc_norm.bias'} 265 | else: 266 | assert set(msg.missing_keys) == {'head.weight', 'head.bias'} 267 | 268 | # manually initialize fc layer 269 | trunc_normal_(model.head.weight, std=2e-5) 270 | 271 | model.to(device) 272 | 273 | model_without_ddp = model 274 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 275 | 276 | print("Model = %s" % str(model_without_ddp)) 277 | print('number of params (M): %.2f' % (n_parameters / 1.e6)) 278 | 279 | eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size() 280 | 281 | if args.lr is None: # only base_lr is specified 282 | args.lr = args.blr * eff_batch_size / 256 283 | 284 | print("base lr: %.2e" % (args.lr * 256 / eff_batch_size)) 285 | print("actual lr: %.2e" % args.lr) 286 | 287 | print("accumulate grad iterations: %d" % args.accum_iter) 288 | print("effective batch size: %d" % eff_batch_size) 289 | 290 | if args.distributed: 291 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 292 | model_without_ddp = model.module 293 | 294 | # build optimizer with layer-wise lr decay (lrd) 295 | param_groups = lrd.param_groups_lrd(model_without_ddp, args.weight_decay, 296 | no_weight_decay_list=model_without_ddp.no_weight_decay(), 297 | layer_decay=args.layer_decay 298 | ) 299 | optimizer = torch.optim.AdamW(param_groups, lr=args.lr) 300 | loss_scaler = NativeScaler() 301 | 302 | if mixup_fn is not None: 303 | # smoothing is handled with mixup label transform 304 | criterion = SoftTargetCrossEntropy() 305 | elif args.smoothing > 0.: 306 | criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing) 307 | else: 308 | criterion = torch.nn.CrossEntropyLoss() 309 | 310 | print("criterion = %s" % str(criterion)) 311 | 312 | misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler) 313 | 314 | if args.eval: 315 | test_stats = evaluate(data_loader_val, model, device) 316 | print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") 317 | exit(0) 318 | 319 | print(f"Start training for {args.epochs} epochs") 320 | start_time = time.time() 
321 | max_accuracy = 0.0 322 | for epoch in range(args.start_epoch, args.epochs): 323 | if args.distributed: 324 | data_loader_train.sampler.set_epoch(epoch) 325 | train_stats = train_one_epoch( 326 | model, criterion, data_loader_train, 327 | optimizer, device, epoch, loss_scaler, 328 | args.clip_grad, mixup_fn, 329 | log_writer=log_writer, 330 | args=args 331 | ) 332 | 333 | test_stats = evaluate(data_loader_val, model, device) 334 | print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") 335 | 336 | if args.output_dir and test_stats["acc1"] >= max_accuracy: 337 | misc.save_model( 338 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, 339 | loss_scaler=loss_scaler, epoch=epoch, save_latest_model_only=True) 340 | 341 | max_accuracy = max(max_accuracy, test_stats["acc1"]) 342 | print(f'Max accuracy: {max_accuracy:.2f}%') 343 | 344 | if log_writer is not None: 345 | log_writer.add_scalar('perf/test_acc1', test_stats['acc1'], epoch) 346 | log_writer.add_scalar('perf/test_acc5', test_stats['acc5'], epoch) 347 | log_writer.add_scalar('perf/test_loss', test_stats['loss'], epoch) 348 | 349 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 350 | **{f'test_{k}': v for k, v in test_stats.items()}, 351 | 'epoch': epoch, 352 | 'n_parameters': n_parameters} 353 | 354 | if args.output_dir and misc.is_main_process(): 355 | if log_writer is not None: 356 | log_writer.flush() 357 | with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f: 358 | f.write(json.dumps(log_stats) + "\n") 359 | 360 | total_time = time.time() - start_time 361 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 362 | print('Training time {}'.format(total_time_str)) 363 | 364 | 365 | if __name__ == '__main__': 366 | args = get_args_parser() 367 | args = args.parse_args() 368 | if args.log_dir is None: 369 | args.log_dir = args.output_dir 370 | if args.output_dir: 371 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 372 | main(args) 373 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Attribution-NonCommercial 4.0 International 3 | 4 | ======================================================================= 5 | 6 | Creative Commons Corporation ("Creative Commons") is not a law firm and 7 | does not provide legal services or legal advice. Distribution of 8 | Creative Commons public licenses does not create a lawyer-client or 9 | other relationship. Creative Commons makes its licenses and related 10 | information available on an "as-is" basis. Creative Commons gives no 11 | warranties regarding its licenses, any material licensed under their 12 | terms and conditions, or any related information. Creative Commons 13 | disclaims all liability for damages resulting from their use to the 14 | fullest extent possible. 15 | 16 | Using Creative Commons Public Licenses 17 | 18 | Creative Commons public licenses provide a standard set of terms and 19 | conditions that creators and other rights holders may use to share 20 | original works of authorship and other material subject to copyright 21 | and certain other rights specified in the public license below. The 22 | following considerations are for informational purposes only, are not 23 | exhaustive, and do not form part of our licenses. 
24 | 25 | Considerations for licensors: Our public licenses are 26 | intended for use by those authorized to give the public 27 | permission to use material in ways otherwise restricted by 28 | copyright and certain other rights. Our licenses are 29 | irrevocable. Licensors should read and understand the terms 30 | and conditions of the license they choose before applying it. 31 | Licensors should also secure all rights necessary before 32 | applying our licenses so that the public can reuse the 33 | material as expected. Licensors should clearly mark any 34 | material not subject to the license. This includes other CC- 35 | licensed material, or material used under an exception or 36 | limitation to copyright. More considerations for licensors: 37 | wiki.creativecommons.org/Considerations_for_licensors 38 | 39 | Considerations for the public: By using one of our public 40 | licenses, a licensor grants the public permission to use the 41 | licensed material under specified terms and conditions. If 42 | the licensor's permission is not necessary for any reason--for 43 | example, because of any applicable exception or limitation to 44 | copyright--then that use is not regulated by the license. Our 45 | licenses grant only permissions under copyright and certain 46 | other rights that a licensor has authority to grant. Use of 47 | the licensed material may still be restricted for other 48 | reasons, including because others have copyright or other 49 | rights in the material. A licensor may make special requests, 50 | such as asking that all changes be marked or described. 51 | Although not required by our licenses, you are encouraged to 52 | respect those requests where reasonable. More_considerations 53 | for the public: 54 | wiki.creativecommons.org/Considerations_for_licensees 55 | 56 | ======================================================================= 57 | 58 | Creative Commons Attribution-NonCommercial 4.0 International Public 59 | License 60 | 61 | By exercising the Licensed Rights (defined below), You accept and agree 62 | to be bound by the terms and conditions of this Creative Commons 63 | Attribution-NonCommercial 4.0 International Public License ("Public 64 | License"). To the extent this Public License may be interpreted as a 65 | contract, You are granted the Licensed Rights in consideration of Your 66 | acceptance of these terms and conditions, and the Licensor grants You 67 | such rights in consideration of benefits the Licensor receives from 68 | making the Licensed Material available under these terms and 69 | conditions. 70 | 71 | Section 1 -- Definitions. 72 | 73 | a. Adapted Material means material subject to Copyright and Similar 74 | Rights that is derived from or based upon the Licensed Material 75 | and in which the Licensed Material is translated, altered, 76 | arranged, transformed, or otherwise modified in a manner requiring 77 | permission under the Copyright and Similar Rights held by the 78 | Licensor. For purposes of this Public License, where the Licensed 79 | Material is a musical work, performance, or sound recording, 80 | Adapted Material is always produced where the Licensed Material is 81 | synched in timed relation with a moving image. 82 | 83 | b. Adapter's License means the license You apply to Your Copyright 84 | and Similar Rights in Your contributions to Adapted Material in 85 | accordance with the terms and conditions of this Public License. 86 | 87 | c. 
Copyright and Similar Rights means copyright and/or similar rights 88 | closely related to copyright including, without limitation, 89 | performance, broadcast, sound recording, and Sui Generis Database 90 | Rights, without regard to how the rights are labeled or 91 | categorized. For purposes of this Public License, the rights 92 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 93 | Rights. 94 | d. Effective Technological Measures means those measures that, in the 95 | absence of proper authority, may not be circumvented under laws 96 | fulfilling obligations under Article 11 of the WIPO Copyright 97 | Treaty adopted on December 20, 1996, and/or similar international 98 | agreements. 99 | 100 | e. Exceptions and Limitations means fair use, fair dealing, and/or 101 | any other exception or limitation to Copyright and Similar Rights 102 | that applies to Your use of the Licensed Material. 103 | 104 | f. Licensed Material means the artistic or literary work, database, 105 | or other material to which the Licensor applied this Public 106 | License. 107 | 108 | g. Licensed Rights means the rights granted to You subject to the 109 | terms and conditions of this Public License, which are limited to 110 | all Copyright and Similar Rights that apply to Your use of the 111 | Licensed Material and that the Licensor has authority to license. 112 | 113 | h. Licensor means the individual(s) or entity(ies) granting rights 114 | under this Public License. 115 | 116 | i. NonCommercial means not primarily intended for or directed towards 117 | commercial advantage or monetary compensation. For purposes of 118 | this Public License, the exchange of the Licensed Material for 119 | other material subject to Copyright and Similar Rights by digital 120 | file-sharing or similar means is NonCommercial provided there is 121 | no payment of monetary compensation in connection with the 122 | exchange. 123 | 124 | j. Share means to provide material to the public by any means or 125 | process that requires permission under the Licensed Rights, such 126 | as reproduction, public display, public performance, distribution, 127 | dissemination, communication, or importation, and to make material 128 | available to the public including in ways that members of the 129 | public may access the material from a place and at a time 130 | individually chosen by them. 131 | 132 | k. Sui Generis Database Rights means rights other than copyright 133 | resulting from Directive 96/9/EC of the European Parliament and of 134 | the Council of 11 March 1996 on the legal protection of databases, 135 | as amended and/or succeeded, as well as other essentially 136 | equivalent rights anywhere in the world. 137 | 138 | l. You means the individual or entity exercising the Licensed Rights 139 | under this Public License. Your has a corresponding meaning. 140 | 141 | Section 2 -- Scope. 142 | 143 | a. License grant. 144 | 145 | 1. Subject to the terms and conditions of this Public License, 146 | the Licensor hereby grants You a worldwide, royalty-free, 147 | non-sublicensable, non-exclusive, irrevocable license to 148 | exercise the Licensed Rights in the Licensed Material to: 149 | 150 | a. reproduce and Share the Licensed Material, in whole or 151 | in part, for NonCommercial purposes only; and 152 | 153 | b. produce, reproduce, and Share Adapted Material for 154 | NonCommercial purposes only. 155 | 156 | 2. Exceptions and Limitations. 
For the avoidance of doubt, where 157 | Exceptions and Limitations apply to Your use, this Public 158 | License does not apply, and You do not need to comply with 159 | its terms and conditions. 160 | 161 | 3. Term. The term of this Public License is specified in Section 162 | 6(a). 163 | 164 | 4. Media and formats; technical modifications allowed. The 165 | Licensor authorizes You to exercise the Licensed Rights in 166 | all media and formats whether now known or hereafter created, 167 | and to make technical modifications necessary to do so. The 168 | Licensor waives and/or agrees not to assert any right or 169 | authority to forbid You from making technical modifications 170 | necessary to exercise the Licensed Rights, including 171 | technical modifications necessary to circumvent Effective 172 | Technological Measures. For purposes of this Public License, 173 | simply making modifications authorized by this Section 2(a) 174 | (4) never produces Adapted Material. 175 | 176 | 5. Downstream recipients. 177 | 178 | a. Offer from the Licensor -- Licensed Material. Every 179 | recipient of the Licensed Material automatically 180 | receives an offer from the Licensor to exercise the 181 | Licensed Rights under the terms and conditions of this 182 | Public License. 183 | 184 | b. No downstream restrictions. You may not offer or impose 185 | any additional or different terms or conditions on, or 186 | apply any Effective Technological Measures to, the 187 | Licensed Material if doing so restricts exercise of the 188 | Licensed Rights by any recipient of the Licensed 189 | Material. 190 | 191 | 6. No endorsement. Nothing in this Public License constitutes or 192 | may be construed as permission to assert or imply that You 193 | are, or that Your use of the Licensed Material is, connected 194 | with, or sponsored, endorsed, or granted official status by, 195 | the Licensor or others designated to receive attribution as 196 | provided in Section 3(a)(1)(A)(i). 197 | 198 | b. Other rights. 199 | 200 | 1. Moral rights, such as the right of integrity, are not 201 | licensed under this Public License, nor are publicity, 202 | privacy, and/or other similar personality rights; however, to 203 | the extent possible, the Licensor waives and/or agrees not to 204 | assert any such rights held by the Licensor to the limited 205 | extent necessary to allow You to exercise the Licensed 206 | Rights, but not otherwise. 207 | 208 | 2. Patent and trademark rights are not licensed under this 209 | Public License. 210 | 211 | 3. To the extent possible, the Licensor waives any right to 212 | collect royalties from You for the exercise of the Licensed 213 | Rights, whether directly or through a collecting society 214 | under any voluntary or waivable statutory or compulsory 215 | licensing scheme. In all other cases the Licensor expressly 216 | reserves any right to collect such royalties, including when 217 | the Licensed Material is used other than for NonCommercial 218 | purposes. 219 | 220 | Section 3 -- License Conditions. 221 | 222 | Your exercise of the Licensed Rights is expressly made subject to the 223 | following conditions. 224 | 225 | a. Attribution. 226 | 227 | 1. If You Share the Licensed Material (including in modified 228 | form), You must: 229 | 230 | a. retain the following if it is supplied by the Licensor 231 | with the Licensed Material: 232 | 233 | i. 
identification of the creator(s) of the Licensed 234 | Material and any others designated to receive 235 | attribution, in any reasonable manner requested by 236 | the Licensor (including by pseudonym if 237 | designated); 238 | 239 | ii. a copyright notice; 240 | 241 | iii. a notice that refers to this Public License; 242 | 243 | iv. a notice that refers to the disclaimer of 244 | warranties; 245 | 246 | v. a URI or hyperlink to the Licensed Material to the 247 | extent reasonably practicable; 248 | 249 | b. indicate if You modified the Licensed Material and 250 | retain an indication of any previous modifications; and 251 | 252 | c. indicate the Licensed Material is licensed under this 253 | Public License, and include the text of, or the URI or 254 | hyperlink to, this Public License. 255 | 256 | 2. You may satisfy the conditions in Section 3(a)(1) in any 257 | reasonable manner based on the medium, means, and context in 258 | which You Share the Licensed Material. For example, it may be 259 | reasonable to satisfy the conditions by providing a URI or 260 | hyperlink to a resource that includes the required 261 | information. 262 | 263 | 3. If requested by the Licensor, You must remove any of the 264 | information required by Section 3(a)(1)(A) to the extent 265 | reasonably practicable. 266 | 267 | 4. If You Share Adapted Material You produce, the Adapter's 268 | License You apply must not prevent recipients of the Adapted 269 | Material from complying with this Public License. 270 | 271 | Section 4 -- Sui Generis Database Rights. 272 | 273 | Where the Licensed Rights include Sui Generis Database Rights that 274 | apply to Your use of the Licensed Material: 275 | 276 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 277 | to extract, reuse, reproduce, and Share all or a substantial 278 | portion of the contents of the database for NonCommercial purposes 279 | only; 280 | 281 | b. if You include all or a substantial portion of the database 282 | contents in a database in which You have Sui Generis Database 283 | Rights, then the database in which You have Sui Generis Database 284 | Rights (but not its individual contents) is Adapted Material; and 285 | 286 | c. You must comply with the conditions in Section 3(a) if You Share 287 | all or a substantial portion of the contents of the database. 288 | 289 | For the avoidance of doubt, this Section 4 supplements and does not 290 | replace Your obligations under this Public License where the Licensed 291 | Rights include other Copyright and Similar Rights. 292 | 293 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 294 | 295 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 296 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 297 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 298 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 299 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 300 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 301 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 302 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 303 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 304 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 305 | 306 | b. 
TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 307 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 308 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 309 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 310 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 311 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 312 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 313 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 314 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 315 | 316 | c. The disclaimer of warranties and limitation of liability provided 317 | above shall be interpreted in a manner that, to the extent 318 | possible, most closely approximates an absolute disclaimer and 319 | waiver of all liability. 320 | 321 | Section 6 -- Term and Termination. 322 | 323 | a. This Public License applies for the term of the Copyright and 324 | Similar Rights licensed here. However, if You fail to comply with 325 | this Public License, then Your rights under this Public License 326 | terminate automatically. 327 | 328 | b. Where Your right to use the Licensed Material has terminated under 329 | Section 6(a), it reinstates: 330 | 331 | 1. automatically as of the date the violation is cured, provided 332 | it is cured within 30 days of Your discovery of the 333 | violation; or 334 | 335 | 2. upon express reinstatement by the Licensor. 336 | 337 | For the avoidance of doubt, this Section 6(b) does not affect any 338 | right the Licensor may have to seek remedies for Your violations 339 | of this Public License. 340 | 341 | c. For the avoidance of doubt, the Licensor may also offer the 342 | Licensed Material under separate terms or conditions or stop 343 | distributing the Licensed Material at any time; however, doing so 344 | will not terminate this Public License. 345 | 346 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 347 | License. 348 | 349 | Section 7 -- Other Terms and Conditions. 350 | 351 | a. The Licensor shall not be bound by any additional or different 352 | terms or conditions communicated by You unless expressly agreed. 353 | 354 | b. Any arrangements, understandings, or agreements regarding the 355 | Licensed Material not stated herein are separate from and 356 | independent of the terms and conditions of this Public License. 357 | 358 | Section 8 -- Interpretation. 359 | 360 | a. For the avoidance of doubt, this Public License does not, and 361 | shall not be interpreted to, reduce, limit, restrict, or impose 362 | conditions on any use of the Licensed Material that could lawfully 363 | be made without permission under this Public License. 364 | 365 | b. To the extent possible, if any provision of this Public License is 366 | deemed unenforceable, it shall be automatically reformed to the 367 | minimum extent necessary to make it enforceable. If the provision 368 | cannot be reformed, it shall be severed from this Public License 369 | without affecting the enforceability of the remaining terms and 370 | conditions. 371 | 372 | c. No term or condition of this Public License will be waived and no 373 | failure to comply consented to unless expressly agreed to by the 374 | Licensor. 375 | 376 | d. 
Nothing in this Public License constitutes or may be interpreted 377 | as a limitation upon, or waiver of, any privileges and immunities 378 | that apply to the Licensor or You, including from the legal 379 | processes of any jurisdiction or authority. 380 | 381 | ======================================================================= 382 | 383 | Creative Commons is not a party to its public 384 | licenses. Notwithstanding, Creative Commons may elect to apply one of 385 | its public licenses to material it publishes and in those instances 386 | will be considered the “Licensor.” The text of the Creative Commons 387 | public licenses is dedicated to the public domain under the CC0 Public 388 | Domain Dedication. Except for the limited purpose of indicating that 389 | material is shared under a Creative Commons public license or as 390 | otherwise permitted by the Creative Commons policies published at 391 | creativecommons.org/policies, Creative Commons does not authorize the 392 | use of the trademark "Creative Commons" or any other trademark or logo 393 | of Creative Commons without its prior written consent including, 394 | without limitation, in connection with any unauthorized modifications 395 | to any of its public licenses or any other arrangements, 396 | understandings, or agreements concerning use of licensed material. For 397 | the avoidance of doubt, this paragraph does not form part of the 398 | public licenses. 399 | 400 | Creative Commons may be contacted at creativecommons.org. --------------------------------------------------------------------------------
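The random_masking routine earlier in this listing builds its per-sample mask by sorting random noise: argsort of i.i.d. noise yields a random permutation, the first len_keep indices are kept, and a second argsort inverts the permutation so the binary mask can be scattered back into the original token order. Below is a minimal standalone sketch of that trick; the batch size, sequence length, feature dimension, and mask ratio are illustrative assumptions, and kept_mask_ratio is taken equal to mask_ratio so no extra masked-but-dropped bookkeeping is needed.

import torch

N, L, D = 2, 8, 4          # demo shapes (assumed): 2 samples, 8 tokens, 4-dim features
mask_ratio = 0.75
x = torch.arange(N * L * D, dtype=torch.float32).reshape(N, L, D)

len_keep = int(L * (1 - mask_ratio))             # 2 tokens survive per sample
noise = torch.rand(N, L)                         # i.i.d. noise in [0, 1]
ids_shuffle = torch.argsort(noise, dim=1)        # random permutation; smallest noise = keep
ids_restore = torch.argsort(ids_shuffle, dim=1)  # inverse permutation

ids_keep = ids_shuffle[:, :len_keep]
x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

mask = torch.ones(N, L)                          # 1 = removed, 0 = kept (shuffled order)
mask[:, :len_keep] = 0
mask = torch.gather(mask, dim=1, index=ids_restore)  # back to original token order

print(x_masked.shape)  # torch.Size([2, 2, 4])
print(mask)            # one 0 per kept token, aligned with the original sequence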
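For orientation, here is a minimal sketch of driving the MaskedAutoencoderViT defined above through one pretraining-style step. It assumes the class and its mae_vit_* factory functions are importable as a module named models_cross (the module name is an assumption) and that the repository's dependencies and helper modules are installed; the dummy input shape and the chosen ratios are illustrative, not prescribed.

import torch

import models_cross  # assumed module name for the MaskedAutoencoderViT definitions above

# ViT-Base encoder with the 512-dim, 8-block cross-attention decoder (a "recommended arch").
# weight_fm=True routes every encoder block's output to the decoder through the learned
# per-decoder-layer weighting (WeightedFeatureMaps) instead of only the final feature map.
model = models_cross.mae_vit_base_patch16(weight_fm=True, norm_pix_loss=False)

imgs = torch.randn(2, 3, 224, 224)  # dummy batch; shape is an assumption for the sketch

# Pretraining-style step: mask 75% of patches, reconstruct them, get the mean-squared loss.
loss = model(imgs, mask_ratio=0.75, kept_mask_ratio=0.75)
loss.backward()
print(float(loss))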