├── .gitignore ├── smalldatavit.png ├── requirements.txt ├── util ├── lr_sched.py ├── crop.py ├── lars.py ├── lr_decay.py ├── datasets.py ├── pos_embed.py └── misc.py ├── environment.yml ├── models_vit.py ├── README.md ├── engine_two_branch.py ├── models_mae.py ├── model_mae_image_loss.py └── main_two_branch.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.sh 2 | output_dir/ 3 | data/ 4 | __pycache__/ 5 | *.pth 6 | -------------------------------------------------------------------------------- /smalldatavit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dominickrei/Limited-data-vits/HEAD/smalldatavit.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.21.5 2 | submitit==1.4.5 3 | timm==0.6.12 4 | torch==1.8.1+cu111 5 | torchvision==0.9.1+cu111 6 | tensorboard==2.11.2 7 | wandb==0.13.4 8 | scipy==1.7.3 9 | -------------------------------------------------------------------------------- /util/lr_sched.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | 9 | def adjust_learning_rate(optimizer, epoch, args): 10 | """Decay the learning rate with half-cycle cosine after warmup""" 11 | if epoch < args.warmup_epochs: 12 | lr = args.lr * epoch / args.warmup_epochs 13 | else: 14 | lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \ 15 | (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs))) 16 | for param_group in optimizer.param_groups: 17 | if "lr_scale" in param_group: 18 | param_group["lr"] = lr * param_group["lr_scale"] 19 | else: 20 | param_group["lr"] = lr 21 | return lr 22 | -------------------------------------------------------------------------------- /util/crop.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | 9 | import torch 10 | 11 | from torchvision import transforms 12 | from torchvision.transforms import functional as F 13 | 14 | 15 | class RandomResizedCrop(transforms.RandomResizedCrop): 16 | """ 17 | RandomResizedCrop for matching TF/TPU implementation: no for-loop is used. 18 | This may lead to results different with torchvision's version. 
19 | Following BYOL's TF code: 20 | https://github.com/deepmind/deepmind-research/blob/master/byol/utils/dataset.py#L206 21 | """ 22 | @staticmethod 23 | def get_params(img, scale, ratio): 24 | width, height = F._get_image_size(img) 25 | area = height * width 26 | 27 | target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item() 28 | log_ratio = torch.log(torch.tensor(ratio)) 29 | aspect_ratio = torch.exp( 30 | torch.empty(1).uniform_(log_ratio[0], log_ratio[1]) 31 | ).item() 32 | 33 | w = int(round(math.sqrt(target_area * aspect_ratio))) 34 | h = int(round(math.sqrt(target_area / aspect_ratio))) 35 | 36 | w = min(w, width) 37 | h = min(h, height) 38 | 39 | i = torch.randint(0, height - h + 1, size=(1,)).item() 40 | j = torch.randint(0, width - w + 1, size=(1,)).item() 41 | 42 | return i, j, h, w -------------------------------------------------------------------------------- /util/lars.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # LARS optimizer, implementation from MoCo v3: 8 | # https://github.com/facebookresearch/moco-v3 9 | # -------------------------------------------------------- 10 | 11 | import torch 12 | 13 | 14 | class LARS(torch.optim.Optimizer): 15 | """ 16 | LARS optimizer, no rate scaling or weight decay for parameters <= 1D. 17 | """ 18 | def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, trust_coefficient=0.001): 19 | defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, trust_coefficient=trust_coefficient) 20 | super().__init__(params, defaults) 21 | 22 | @torch.no_grad() 23 | def step(self): 24 | for g in self.param_groups: 25 | for p in g['params']: 26 | dp = p.grad 27 | 28 | if dp is None: 29 | continue 30 | 31 | if p.ndim > 1: # if not normalization gamma/beta or bias 32 | dp = dp.add(p, alpha=g['weight_decay']) 33 | param_norm = torch.norm(p) 34 | update_norm = torch.norm(dp) 35 | one = torch.ones_like(param_norm) 36 | q = torch.where(param_norm > 0., 37 | torch.where(update_norm > 0, 38 | (g['trust_coefficient'] * param_norm / update_norm), one), 39 | one) 40 | dp = dp.mul(q) 41 | 42 | param_state = self.state[p] 43 | if 'mu' not in param_state: 44 | param_state['mu'] = torch.zeros_like(p) 45 | mu = param_state['mu'] 46 | mu.mul_(g['momentum']).add_(dp) 47 | p.add_(mu, alpha=-g['lr']) -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: limiteddatavit 2 | channels: 3 | - defaults 4 | dependencies: 5 | - _libgcc_mutex=0.1=main 6 | - _openmp_mutex=5.1=1_gnu 7 | - ca-certificates=2023.12.12=h06a4308_0 8 | - certifi=2022.12.7=py37h06a4308_0 9 | - ld_impl_linux-64=2.38=h1181459_1 10 | - libffi=3.4.4=h6a678d5_0 11 | - libgcc-ng=11.2.0=h1234567_1 12 | - libgomp=11.2.0=h1234567_1 13 | - libstdcxx-ng=11.2.0=h1234567_1 14 | - ncurses=6.4=h6a678d5_0 15 | - openssl=1.1.1w=h7f8727e_0 16 | - pip=22.3.1=py37h06a4308_0 17 | - python=3.7.16=h7a1cb2a_0 18 | - readline=8.2=h5eee18b_0 19 | - setuptools=65.6.3=py37h06a4308_0 20 | - sqlite=3.41.2=h5eee18b_0 21 | - tk=8.6.12=h1ccaba5_0 22 | - wheel=0.38.4=py37h06a4308_0 23 | - xz=5.4.5=h5eee18b_0 24 | - zlib=1.2.13=h5eee18b_0 25 | - pip: 26 | - 
absl-py==2.0.0 27 | - cachetools==5.3.2 28 | - charset-normalizer==3.3.2 29 | - click==8.1.7 30 | - cloudpickle==2.2.1 31 | - docker-pycreds==0.4.0 32 | - filelock==3.12.2 33 | - fsspec==2023.1.0 34 | - gitdb==4.0.11 35 | - gitpython==3.1.40 36 | - google-auth==2.25.2 37 | - google-auth-oauthlib==0.4.6 38 | - grpcio==1.60.0 39 | - huggingface-hub==0.16.4 40 | - idna==3.6 41 | - importlib-metadata==6.7.0 42 | - markdown==3.4.4 43 | - markupsafe==2.1.3 44 | - numpy==1.21.5 45 | - oauthlib==3.2.2 46 | - packaging==23.2 47 | - pathtools==0.1.2 48 | - pillow==9.2.0 49 | - promise==2.3 50 | - protobuf==3.20.3 51 | - psutil==5.9.7 52 | - pyasn1==0.5.1 53 | - pyasn1-modules==0.3.0 54 | - pyyaml==6.0.1 55 | - requests==2.31.0 56 | - requests-oauthlib==1.3.1 57 | - rsa==4.9 58 | - sentry-sdk==1.39.1 59 | - setproctitle==1.3.3 60 | - shortuuid==1.0.11 61 | - six==1.16.0 62 | - smmap==5.0.1 63 | - submitit==1.4.5 64 | - tensorboard==2.11.2 65 | - tensorboard-data-server==0.6.1 66 | - tensorboard-plugin-wit==1.8.1 67 | - timm==0.6.12 68 | - torch==1.8.1+cu111 69 | - torchvision==0.9.1+cu111 70 | - tqdm==4.66.1 71 | - typing-extensions==4.7.1 72 | - urllib3==2.0.7 73 | - wandb==0.13.4 74 | - werkzeug==2.2.3 75 | - zipp==3.15.0 76 | -------------------------------------------------------------------------------- /util/lr_decay.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # ELECTRA https://github.com/google-research/electra 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | 12 | import json 13 | 14 | 15 | def param_groups_lrd(model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=.75): 16 | """ 17 | Parameter groups for layer-wise lr decay 18 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58 19 | """ 20 | param_group_names = {} 21 | param_groups = {} 22 | 23 | num_layers = len(model.blocks) + 1 24 | 25 | layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1)) 26 | 27 | for n, p in model.named_parameters(): 28 | if not p.requires_grad: 29 | continue 30 | 31 | # no decay: all 1D parameters and model specific ones 32 | if p.ndim == 1 or n in no_weight_decay_list: 33 | g_decay = "no_decay" 34 | this_decay = 0. 
35 | else: 36 | g_decay = "decay" 37 | this_decay = weight_decay 38 | 39 | layer_id = get_layer_id_for_vit(n, num_layers) 40 | group_name = "layer_%d_%s" % (layer_id, g_decay) 41 | 42 | if group_name not in param_group_names: 43 | this_scale = layer_scales[layer_id] 44 | 45 | param_group_names[group_name] = { 46 | "lr_scale": this_scale, 47 | "weight_decay": this_decay, 48 | "params": [], 49 | } 50 | param_groups[group_name] = { 51 | "lr_scale": this_scale, 52 | "weight_decay": this_decay, 53 | "params": [], 54 | } 55 | 56 | param_group_names[group_name]["params"].append(n) 57 | param_groups[group_name]["params"].append(p) 58 | 59 | # print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2)) 60 | 61 | return list(param_groups.values()) 62 | 63 | 64 | def get_layer_id_for_vit(name, num_layers): 65 | """ 66 | Assign a parameter with its layer id 67 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33 68 | """ 69 | if name in ['cls_token', 'pos_embed']: 70 | return 0 71 | elif name.startswith('patch_embed'): 72 | return 0 73 | elif name.startswith('blocks'): 74 | return int(name.split('.')[1]) + 1 75 | else: 76 | return num_layers -------------------------------------------------------------------------------- /models_vit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm 9 | # DeiT: https://github.com/facebookresearch/deit 10 | # -------------------------------------------------------- 11 | 12 | from functools import partial 13 | 14 | import torch 15 | import torch.nn as nn 16 | import timm.models.vision_transformer 17 | 18 | 19 | class VisionTransformer(timm.models.vision_transformer.VisionTransformer): 20 | """ Vision Transformer with support for global average pooling 21 | """ 22 | def __init__(self, global_pool=False, **kwargs): 23 | super(VisionTransformer, self).__init__(**kwargs) 24 | 25 | def forward_features(self, x): 26 | B = x.shape[0] 27 | x = self.patch_embed(x) 28 | 29 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 30 | x = torch.cat((cls_tokens, x), dim=1) 31 | x = x + self.pos_embed 32 | x = self.pos_drop(x) 33 | 34 | for blk in self.blocks: 35 | x = blk(x) 36 | 37 | 38 | x = self.norm(x) 39 | outcome = x[:, 0] 40 | return outcome 41 | 42 | def forward_head(self, x): 43 | return self.head(x) 44 | 45 | 46 | def vit_base_patch16(**kwargs): 47 | model = VisionTransformer( 48 | embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 49 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 50 | return model 51 | 52 | 53 | def vit_large_patch16(**kwargs): 54 | model = VisionTransformer( 55 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, 56 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 57 | return model 58 | 59 | 60 | def vit_huge_patch14(**kwargs): 61 | model = VisionTransformer( 62 | patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, qkv_bias=True, 63 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 64 | return model 65 | 66 | 67 | def vit_tiny(**kwargs): 68 | model = VisionTransformer( 69 | # 
patch_size=16 , 70 | embed_dim=192, depth=12, num_heads=3, 71 | mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs 72 | ) 73 | return model 74 | 75 | def vit_small(**kwargs): 76 | model = VisionTransformer( 77 | embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True, 78 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 79 | 80 | return model 81 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [WACV 2024] Limited Data, Unlimited Potential: A Study on ViTs Augmented by Masked Autoencoders 2 | This is the official codebase for our paper "Limited Data, Unlimited Potential: A Study on ViTs Augmented by Masked Autoencoders" presented at WACV 2024. The paper can be viewed at [this link](https://arxiv.org/abs/2310.20704). 3 | 4 | ![Overview of self-supervised auxiliary task (SSAT)](smalldatavit.png) 5 | 6 | ## Installation 7 | 8 | Create the conda environment and install the necessary packages: 9 | 10 | ``` 11 | conda env create -f environment.yml -n limiteddatavit 12 | ``` 13 | 14 | or alternatively 15 | 16 | ``` 17 | conda create -n limiteddatavit python=3.7 -y 18 | pip install -r requirements.txt -f https://download.pytorch.org/whl/torch_stable.html 19 | ``` 20 | 21 | ## Data preparation 22 | 23 | We provide code for training on ImageNet, CIFAR10, and CIFAR100. CIFAR10 and CIFAR100 are downloaded automatically via torchvision; ImageNet must be downloaded separately. 24 | 25 | Download and extract the ImageNet train and val images from http://image-net.org/. 26 | The directory structure is the standard layout for the torchvision [`datasets.ImageFolder`](https://pytorch.org/docs/stable/torchvision/datasets.html#imagefolder), and the training and validation data are expected to be in the `train/` and `val/` folders respectively: 27 | 28 | ``` 29 | /path/to/imagenet/ 30 | train/ 31 | class1/ 32 | img1.jpeg 33 | class2/ 34 | img2.jpeg 35 | val/ 36 | class1/ 37 | img3.jpeg 38 | class2/ 39 | img4.jpeg 40 | ``` 41 | 42 | ## Pretrained model weights 43 | | Model | Dataset | Evaluation Command | 44 | | ------------- | ------------- | ------------- | 45 | | ViT-T + SSAT ([weights](https://drive.google.com/file/d/1zD4t6m98UckQkk8f2V1PLIaPIH_0HqWS/view?usp=sharing)) | ImageNet-1k | `python main_two_branch.py --data_path /path/to/imagenet/ --resume vittiny-ssat_imagenet1k_weights.pth --eval --model mae_vit_tiny` | 46 | | ViT-S + SSAT ([weights](https://drive.google.com/file/d/1Z6ynVVyxZavUjoRtRnIuYzLIW0zQJE4C/view?usp=sharing)) | ImageNet-1k | `python main_two_branch.py --data_path /path/to/imagenet/ --resume vitsmall-ssat_imagenet1k_weights.pth --eval --model mae_vit_small` | 47 | 48 | 49 | ## Training models 50 | To train ViT-Tiny with the Self-Supervised Auxiliary Task (SSAT) on ImageNet-1k using 8 GPUs, run the following command: 51 | ``` 52 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --nproc_per_node=8 main_two_branch.py --data_path /path/to/imagenet/ --output_dir ./output_dir --epochs 100 --model mae_vit_tiny 53 | ``` 54 | 55 | Available arguments for `--data_path` are `/path/to/imagenet`, `c10`, `c100`. Other datasets can be added in `util/datasets.py`. 56 | 57 | Available arguments for `--model` are `mae_vit_tiny`, `mae_vit_small`, `mae_vit_base`, `mae_vit_large`, `mae_vit_huge`. 
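For reference, the sketch below shows what a single SSAT training step computes: the two-branch model returns both a masked-reconstruction loss and class predictions, and the two losses are combined with a weighting factor, mirroring `engine_two_branch.py`. The weight and mask ratio below are illustrative values only, not the training defaults.

```
import torch
import models_mae

# Illustrative values; during actual training these come from the
# command-line arguments consumed by engine_two_branch.py.
lambda_weight, mask_ratio = 0.5, 0.75

model = models_mae.mae_vit_tiny(num_classes=1000)
criterion = torch.nn.CrossEntropyLoss()

samples = torch.randn(2, 3, 224, 224)   # toy batch of images
targets = torch.randint(0, 1000, (2,))  # toy labels

# Two-branch forward: MAE reconstruction loss + classification logits
mae_loss, outputs = model(samples, mask_ratio)
loss = (1 - lambda_weight) * mae_loss + lambda_weight * criterion(outputs, targets)
loss.backward()
```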
58 | 59 | ## Citation & Acknowledgement 60 | ``` 61 | @article{das-limiteddatavit-wacv2024, 62 | title={Limited Data, Unlimited Potential: A Study on ViTs Augmented by Masked Autoencoders}, 63 | author={Srijan Das and Tanmay Jain and Dominick Reilly and Pranav Balaji and Soumyajit Karmakar and Shyam Marjit and Xiang Li and Abhijit Das and Michael Ryoo}, 64 | journal={2024 IEEE/CVF Winter Conference on Applications of Computer Vision (WACV)}, 65 | year={2024} 66 | } 67 | ``` 68 | 69 | This repository is built on top of the [code](https://github.com/facebookresearch/mae) for the paper [Masked Autoencoders Are Scalable Vision Learners](https://arxiv.org/abs/2111.06377) from Meta Research. 70 | -------------------------------------------------------------------------------- /util/datasets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # -------------------------------------------------------- 10 | 11 | import os 12 | import PIL 13 | 14 | from torchvision import datasets, transforms 15 | import torch 16 | 17 | from timm.data import create_transform 18 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 19 | 20 | import numpy as np 21 | 22 | 23 | def build_dataset(is_train, args): 24 | transform = build_transform(is_train, args) 25 | 26 | if args.data_path == 'c10': 27 | dataset = datasets.CIFAR10( 28 | root='./data', train=is_train, download=True, transform=transform 29 | ) 30 | elif args.data_path == 'c100': 31 | dataset = datasets.CIFAR100( 32 | root='./data', train=is_train, download=True, transform=transform 33 | ) 34 | elif args.data_path == 'svhn': 35 | split = 'train' if is_train else 'test' 36 | dataset = datasets.SVHN(root='./data', split=split, download=True, transform=transform) 37 | else: 38 | root = os.path.join(args.data_path, 'train' if is_train else 'val') 39 | dataset = datasets.ImageFolder(root, transform=transform) 40 | 41 | if is_train and args.subset_size != 1: 42 | np.random.seed(127) 43 | 44 | dsize = len(dataset) 45 | idxs = np.random.choice(dsize, int(dsize*args.subset_size), replace=False) 46 | 47 | dataset = torch.utils.data.Subset(dataset, idxs) 48 | print('Subset dataset size: ', len(dataset)) 49 | 50 | print(f'{dataset} ({len(dataset)})') 51 | 52 | return dataset 53 | 54 | 55 | def build_transform(is_train, args): 56 | if args.data_path == 'c10': 57 | mean = (0.4914, 0.4822, 0.4465) 58 | std = (0.2023, 0.1994, 0.2010) 59 | else: 60 | mean = IMAGENET_DEFAULT_MEAN 61 | std = IMAGENET_DEFAULT_STD 62 | 63 | # train transform 64 | if is_train: 65 | # this should always dispatch to transforms_imagenet_train 66 | transform = create_transform( 67 | input_size=args.input_size, 68 | is_training=True, 69 | color_jitter=args.color_jitter, 70 | auto_augment=args.aa, 71 | interpolation='bicubic', 72 | re_prob=args.reprob, 73 | re_mode=args.remode, 74 | re_count=args.recount, 75 | mean=mean, 76 | std=std, 77 | ) 78 | return transform 79 | 80 | # eval transform 81 | t = [] 82 | if args.input_size <= 224: 83 | crop_pct = 224 / 256 84 | else: 85 | crop_pct = 1.0 86 | size = int(args.input_size / crop_pct) 87 | t.append( 88 | transforms.Resize(size, 
interpolation=PIL.Image.BICUBIC), # to maintain same ratio w.r.t. 224 images 89 | ) 90 | t.append(transforms.CenterCrop(args.input_size)) 91 | 92 | if 'perturb_perspective' in vars(args) and args.perturb_perspective: 93 | print('[Log] Perturbing perspective of images in dataset') 94 | t.append(transforms.RandomPerspective(distortion_scale=0.5, p=1.0)) 95 | else: 96 | print('[Log] Not perturbing perspective of images in dataset') 97 | 98 | t.append(transforms.ToTensor()) 99 | t.append(transforms.Normalize(mean, std)) 100 | return transforms.Compose(t) 101 | -------------------------------------------------------------------------------- /util/pos_embed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # Position embedding utils 8 | # -------------------------------------------------------- 9 | 10 | import numpy as np 11 | 12 | import torch 13 | 14 | # -------------------------------------------------------- 15 | # 2D sine-cosine position embedding 16 | # References: 17 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py 18 | # MoCo v3: https://github.com/facebookresearch/moco-v3 19 | # -------------------------------------------------------- 20 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): 21 | """ 22 | grid_size: int of the grid height and width 23 | return: 24 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 25 | """ 26 | grid_h = np.arange(grid_size, dtype=np.float32) 27 | grid_w = np.arange(grid_size, dtype=np.float32) 28 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 29 | grid = np.stack(grid, axis=0) 30 | 31 | grid = grid.reshape([2, 1, grid_size, grid_size]) 32 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 33 | if cls_token: 34 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 35 | return pos_embed 36 | 37 | 38 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 39 | assert embed_dim % 2 == 0 40 | 41 | # use half of dimensions to encode grid_h 42 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 43 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 44 | 45 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 46 | return emb 47 | 48 | 49 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 50 | """ 51 | embed_dim: output dimension for each position 52 | pos: a list of positions to be encoded: size (M,) 53 | out: (M, D) 54 | """ 55 | assert embed_dim % 2 == 0 56 | omega = np.arange(embed_dim // 2, dtype=np.float) 57 | omega /= embed_dim / 2. 58 | omega = 1. 
/ 10000**omega # (D/2,) 59 | 60 | pos = pos.reshape(-1) # (M,) 61 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 62 | 63 | emb_sin = np.sin(out) # (M, D/2) 64 | emb_cos = np.cos(out) # (M, D/2) 65 | 66 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 67 | return emb 68 | 69 | 70 | # -------------------------------------------------------- 71 | # Interpolate position embeddings for high-resolution 72 | # References: 73 | # DeiT: https://github.com/facebookresearch/deit 74 | # -------------------------------------------------------- 75 | def interpolate_pos_embed(model, checkpoint_model): 76 | if 'pos_embed' in checkpoint_model: 77 | pos_embed_checkpoint = checkpoint_model['pos_embed'] 78 | embedding_size = pos_embed_checkpoint.shape[-1] 79 | num_patches = model.patch_embed.num_patches 80 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches 81 | # height (== width) for the checkpoint position embedding 82 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 83 | # height (== width) for the new position embedding 84 | new_size = int(num_patches ** 0.5) 85 | # class_token and dist_token are kept unchanged 86 | if orig_size != new_size: 87 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) 88 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 89 | # only the position tokens are interpolated 90 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 91 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 92 | pos_tokens = torch.nn.functional.interpolate( 93 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 94 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 95 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 96 | checkpoint_model['pos_embed'] = new_pos_embed 97 | -------------------------------------------------------------------------------- /engine_two_branch.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | 12 | import math 13 | import sys 14 | from typing import Iterable, Optional 15 | import torch 16 | import wandb 17 | from timm.data import Mixup 18 | from timm.utils import accuracy 19 | 20 | import util.misc as misc 21 | import util.lr_sched as lr_sched 22 | 23 | 24 | def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module, 25 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 26 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0, 27 | mixup_fn: Optional[Mixup] = None, log_writer=None, 28 | args=None): 29 | model.train(True) 30 | metric_logger = misc.MetricLogger(delimiter=" ") 31 | metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}')) 32 | header = 'Epoch: [{}]'.format(epoch) 33 | print_freq = 20 34 | 35 | accum_iter = args.accum_iter 36 | 37 | optimizer.zero_grad() 38 | 39 | if log_writer is not None: 40 | print('log_dir: {}'.format(log_writer.log_dir)) 41 | 42 | for data_iter_step, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 43 | 44 | # we use a per iteration (instead of per epoch) lr scheduler 45 | if data_iter_step % accum_iter == 0: 46 | lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args) 47 | 48 | samples = samples.to(device, non_blocking=True) 49 | targets = targets.to(device, non_blocking=True) 50 | 51 | if mixup_fn is not None: 52 | samples, targets = mixup_fn(samples, targets) 53 | loss = 0.0 54 | with torch.cuda.amp.autocast(): 55 | loss_twobranch , outputs = model(samples , args.mask_ratio) 56 | classification_loss = args.lambda_weight * criterion(outputs, targets) 57 | loss_twobranch = (1-args.lambda_weight) * loss_twobranch 58 | loss = loss_twobranch + classification_loss 59 | 60 | 61 | loss_value = loss.item() 62 | 63 | if not math.isfinite(loss_value): 64 | print("Loss is {}, stopping training".format(loss_value)) 65 | sys.exit(1) 66 | 67 | loss /= accum_iter 68 | loss_scaler(loss, optimizer, clip_grad=max_norm, 69 | parameters=model.parameters(), create_graph=False, 70 | update_grad=(data_iter_step + 1) % accum_iter == 0) 71 | if (data_iter_step + 1) % accum_iter == 0: 72 | optimizer.zero_grad() 73 | 74 | torch.cuda.synchronize() 75 | 76 | metric_logger.update(training_loss=loss_value) 77 | metric_logger.update(mae_loss = loss_twobranch) 78 | metric_logger.update(classification_loss = classification_loss) 79 | min_lr = 10. 80 | max_lr = 0. 81 | for group in optimizer.param_groups: 82 | min_lr = min(min_lr, group["lr"]) 83 | max_lr = max(max_lr, group["lr"]) 84 | 85 | metric_logger.update(lr=max_lr) 86 | 87 | loss_value_reduce = misc.all_reduce_mean(loss_value) 88 | if log_writer is not None and (data_iter_step + 1) % accum_iter == 0: 89 | """ We use epoch_1000x as the x-axis in tensorboard. 90 | This calibrates different curves when batch size changes. 
91 | """ 92 | epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000) 93 | log_writer.add_scalar('loss', loss_value_reduce, epoch_1000x) 94 | log_writer.add_scalar('lr', max_lr, epoch_1000x) 95 | 96 | # gather the stats from all processes 97 | metric_logger.synchronize_between_processes() 98 | print("Averaged stats:", metric_logger) 99 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 100 | 101 | 102 | @torch.no_grad() 103 | def evaluate(data_loader, model, device): 104 | criterion = torch.nn.CrossEntropyLoss() 105 | 106 | metric_logger = misc.MetricLogger(delimiter=" ") 107 | header = 'Test:' 108 | 109 | # switch to evaluation mode 110 | model.eval() 111 | 112 | for batch in metric_logger.log_every(data_loader, 10, header): 113 | images = batch[0] 114 | target = batch[-1] 115 | images = images.to(device, non_blocking=True) 116 | target = target.to(device, non_blocking=True) 117 | 118 | # compute output 119 | with torch.cuda.amp.autocast(): 120 | output = model.forward_test(images) 121 | loss = criterion(output, target) 122 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 123 | 124 | 125 | batch_size = images.shape[0] 126 | metric_logger.update(testing_loss=loss.item()) 127 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) 128 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) 129 | # gather the stats from all processes 130 | metric_logger.synchronize_between_processes() 131 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}' 132 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.testing_loss)) 133 | 134 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 135 | -------------------------------------------------------------------------------- /models_mae.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # -------------------------------------------------------- 7 | # References: 8 | # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm 9 | # DeiT: https://github.com/facebookresearch/deit 10 | # -------------------------------------------------------- 11 | 12 | from functools import partial 13 | 14 | import torch 15 | import torch.nn as nn 16 | 17 | from timm.models.vision_transformer import PatchEmbed, Block 18 | 19 | from util.pos_embed import get_2d_sincos_pos_embed 20 | 21 | 22 | class MaskedAutoencoderViT(nn.Module): 23 | """ Masked Autoencoder with VisionTransformer backbone 24 | """ 25 | def __init__(self, img_size=224, patch_size=16, in_chans=3, 26 | embed_dim=1024, depth=24, num_heads=16, 27 | decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, 28 | mlp_ratio=4., norm_layer=nn.LayerNorm, num_classes =10 ,norm_pix_loss=False): 29 | super().__init__() 30 | 31 | print(f'Patch size of MaskedAutoencoderViT: {patch_size}') 32 | 33 | # -------------------------------------------------------------------------- 34 | # MAE encoder specifics 35 | self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim) 36 | num_patches = self.patch_embed.num_patches 37 | 38 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 39 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False) # fixed sin-cos embedding 40 | 41 | self.blocks = nn.ModuleList([ 42 | Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer) 43 | for i in range(depth)]) 44 | self.norm = norm_layer(embed_dim) 45 | # -------------------------------------------------------------------------- 46 | 47 | # -------------------------------------------------------------------------- 48 | # MAE decoder specifics 49 | self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True) 50 | self.decoder_embed_vanilla = nn.Linear(embed_dim, decoder_embed_dim, bias=True) 51 | 52 | 53 | self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim)) 54 | 55 | self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False) # fixed sin-cos embedding 56 | 57 | self.decoder_blocks = nn.ModuleList([ 58 | Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer) 59 | for i in range(decoder_depth)]) 60 | 61 | self.decoder_norm = norm_layer(decoder_embed_dim) 62 | # -------------------------------------------------------------------------- 63 | 64 | 65 | 66 | 67 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() 68 | self.norm_pix_loss = norm_pix_loss 69 | 70 | self.initialize_weights() 71 | 72 | def initialize_weights(self): 73 | # initialization 74 | # initialize (and freeze) pos_embed by sin-cos embedding 75 | pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True) 76 | self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) 77 | 78 | decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True) 79 | self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0)) 80 | 81 | # initialize patch_embed like nn.Linear (instead of nn.Conv2d) 82 | w = self.patch_embed.proj.weight.data 83 | torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) 84 | 85 | # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.) 
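# cls_token and mask_token below therefore use a plain normal_(std=.02) init; pos_embed and decoder_pos_embed were already filled above with frozen sin-cos tables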
86 | torch.nn.init.normal_(self.cls_token, std=.02) 87 | torch.nn.init.normal_(self.mask_token, std=.02) 88 | 89 | # initialize nn.Linear and nn.LayerNorm 90 | self.apply(self._init_weights) 91 | 92 | def _init_weights(self, m): 93 | if isinstance(m, nn.Linear): 94 | # we use xavier_uniform following official JAX ViT: 95 | torch.nn.init.xavier_uniform_(m.weight) 96 | if isinstance(m, nn.Linear) and m.bias is not None: 97 | nn.init.constant_(m.bias, 0) 98 | elif isinstance(m, nn.LayerNorm): 99 | nn.init.constant_(m.bias, 0) 100 | nn.init.constant_(m.weight, 1.0) 101 | 102 | def patchify(self, imgs): 103 | """ 104 | imgs: (N, 3, H, W) 105 | x: (N, L, patch_size**2 *3) 106 | """ 107 | p = self.patch_embed.patch_size[0] 108 | assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0 109 | 110 | h = w = imgs.shape[2] // p 111 | x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p)) 112 | x = torch.einsum('nchpwq->nhwpqc', x) 113 | x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3)) 114 | return x 115 | 116 | def unpatchify(self, x): 117 | """ 118 | x: (N, L, patch_size**2 *3) 119 | imgs: (N, 3, H, W) 120 | """ 121 | p = self.patch_embed.patch_size[0] 122 | h = w = int(x.shape[1]**.5) 123 | assert h * w == x.shape[1] 124 | 125 | x = x.reshape(shape=(x.shape[0], h, w, p, p, 3)) 126 | x = torch.einsum('nhwpqc->nchpwq', x) 127 | imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p)) 128 | return imgs 129 | 130 | def random_masking(self, x, mask_ratio): 131 | """ 132 | Perform per-sample random masking by per-sample shuffling. 133 | Per-sample shuffling is done by argsort random noise. 134 | x: [N, L, D], sequence 135 | """ 136 | N, L, D = x.shape # batch, length, dim 137 | len_keep = int(L * (1 - mask_ratio)) 138 | 139 | noise = torch.rand(N, L, device=x.device) # noise in [0, 1] 140 | 141 | # sort noise for each sample 142 | ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove 143 | ids_restore = torch.argsort(ids_shuffle, dim=1) 144 | 145 | # keep the first subset 146 | ids_keep = ids_shuffle[:, :len_keep] 147 | x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) 148 | 149 | # generate the binary mask: 0 is keep, 1 is remove 150 | mask = torch.ones([N, L], device=x.device) 151 | mask[:, :len_keep] = 0 152 | # unshuffle to get the binary mask 153 | mask = torch.gather(mask, dim=1, index=ids_restore) 154 | 155 | return x_masked, mask, ids_restore 156 | 157 | def forward_vanilla(self, x): 158 | # embed patches 159 | x = self.patch_embed(x) 160 | 161 | # add pos embed w/o cls token 162 | x = x + self.pos_embed[:, 1:, :] 163 | 164 | 165 | # append cls token 166 | cls_token = self.cls_token + self.pos_embed[:, :1, :] 167 | cls_tokens = cls_token.expand(x.shape[0], -1, -1) 168 | x = torch.cat((cls_tokens, x), dim=1) 169 | 170 | # apply Transformer blocks 171 | for blk in self.blocks: 172 | x = blk(x) 173 | x = self.norm(x) 174 | 175 | return x 176 | 177 | 178 | def forward_encoder(self, x, mask_ratio): 179 | # embed patches 180 | x = self.patch_embed(x) 181 | 182 | # add pos embed w/o cls token 183 | x = x + self.pos_embed[:, 1:, :] 184 | 185 | # masking: length -> length * mask_ratio 186 | x, mask, ids_restore = self.random_masking(x, mask_ratio) 187 | 188 | # append cls token 189 | cls_token = self.cls_token + self.pos_embed[:, :1, :] 190 | cls_tokens = cls_token.expand(x.shape[0], -1, -1) 191 | x = torch.cat((cls_tokens, x), dim=1) 192 | 193 | # apply Transformer blocks 194 | for blk in self.blocks: 195 | x = blk(x) 196 | x = 
self.norm(x) 197 | 198 | return x, mask, ids_restore 199 | 200 | def forward_decoder(self, x, ids_restore): 201 | # embed tokens 202 | x = self.decoder_embed(x) 203 | 204 | # append mask tokens to sequence 205 | mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1) 206 | x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token 207 | x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle 208 | x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token 209 | 210 | # add pos embed 211 | x = x + self.decoder_pos_embed 212 | 213 | # apply Transformer blocks 214 | for blk in self.decoder_blocks: 215 | x = blk(x) 216 | x = self.decoder_norm(x) 217 | 218 | # remove cls token 219 | x = x[:, 1:, :] 220 | 221 | return x 222 | 223 | def forward_loss(self, img_vanilla, decoder_predicted , mask): 224 | if self.norm_pix_loss: 225 | mean = img_vanilla.mean(dim=-1, keepdim=True) 226 | var = img_vanilla.var(dim=-1, keepdim=True) 227 | img_vanilla = (img_vanilla - mean) / (var + 1.e-6)**.5 228 | loss = (img_vanilla - decoder_predicted) ** 2 229 | loss = loss.mean(dim=-1) # [N, L], mean loss per patch 230 | loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches 231 | return loss 232 | 233 | def forward(self, imgs, mask_ratio=0.75): 234 | # This has the class token appended to it 235 | latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio) 236 | # This also has the class token 237 | latent_vanilla = self.forward_vanilla(imgs) 238 | 239 | cls_vanilla = latent_vanilla[: , 0 , :] 240 | predicted_class = self.head(cls_vanilla) # Class predictions by the network 241 | 242 | img_patch_vanilla = latent_vanilla[: , 1: , :] 243 | img_patch_vanilla = self.decoder_embed_vanilla(img_patch_vanilla) 244 | #This doesnt have class token 245 | pred_decoder = self.forward_decoder(latent, ids_restore) 246 | 247 | loss_twobranch = self.forward_loss(img_patch_vanilla , pred_decoder , mask) 248 | return loss_twobranch, predicted_class 249 | 250 | def forward_test(self, imgs): 251 | output = self.forward_vanilla(imgs) 252 | class_token = output[: , 0 , :] 253 | predicted_class = self.head(class_token) 254 | return predicted_class 255 | 256 | 257 | def mae_vit_tiny_dec128d2b(**kwargs): 258 | model = MaskedAutoencoderViT( 259 | embed_dim=192, depth=12, num_heads=3, 260 | decoder_embed_dim=128, decoder_depth=2, decoder_num_heads=16, 261 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs 262 | ) 263 | return model 264 | 265 | 266 | def mae_vit_base_dec512d8b(**kwargs): 267 | model = MaskedAutoencoderViT( 268 | embed_dim=768, depth=12, num_heads=12, 269 | decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, 270 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 271 | return model 272 | 273 | 274 | def mae_vit_large_dec512d8b(**kwargs): 275 | model = MaskedAutoencoderViT( 276 | embed_dim=1024, depth=24, num_heads=16, 277 | decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, 278 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 279 | return model 280 | 281 | 282 | def mae_vit_huge_dec512d8b(**kwargs): 283 | model = MaskedAutoencoderViT( 284 | embed_dim=1280, depth=32, num_heads=16, 285 | decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, 286 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 287 | return model 288 | 289 | 290 | 291 | 292 | def mae_vit_small_dec128d2b(**kwargs): 293 | model = MaskedAutoencoderViT( 294 | 
embed_dim=384, depth=12, num_heads=6, 295 | decoder_embed_dim=128, decoder_depth=2, decoder_num_heads=16, 296 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs 297 | ) 298 | return model 299 | 300 | 301 | # set recommended archs 302 | mae_vit_tiny = mae_vit_tiny_dec128d2b 303 | mae_vit_small = mae_vit_small_dec128d2b 304 | mae_vit_base = mae_vit_base_dec512d8b # decoder: 512 dim, 8 blocks 305 | mae_vit_large = mae_vit_large_dec512d8b # decoder: 512 dim, 8 blocks 306 | mae_vit_huge = mae_vit_huge_dec512d8b # decoder: 512 dim, 8 blocks 307 | -------------------------------------------------------------------------------- /util/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | 12 | import builtins 13 | import datetime 14 | import os 15 | import time 16 | from collections import defaultdict, deque 17 | from pathlib import Path 18 | 19 | import torch 20 | import torch.distributed as dist 21 | from torch._six import inf 22 | 23 | 24 | class SmoothedValue(object): 25 | """Track a series of values and provide access to smoothed values over a 26 | window or the global series average. 27 | """ 28 | 29 | def __init__(self, window_size=20, fmt=None): 30 | if fmt is None: 31 | fmt = "{median:.4f} ({global_avg:.4f})" 32 | self.deque = deque(maxlen=window_size) 33 | self.total = 0.0 34 | self.count = 0 35 | self.fmt = fmt 36 | 37 | def update(self, value, n=1): 38 | self.deque.append(value) 39 | self.count += n 40 | self.total += value * n 41 | 42 | def synchronize_between_processes(self): 43 | """ 44 | Warning: does not synchronize the deque! 
45 | """ 46 | if not is_dist_avail_and_initialized(): 47 | return 48 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 49 | dist.barrier() 50 | dist.all_reduce(t) 51 | t = t.tolist() 52 | self.count = int(t[0]) 53 | self.total = t[1] 54 | 55 | @property 56 | def median(self): 57 | d = torch.tensor(list(self.deque)) 58 | return d.median().item() 59 | 60 | @property 61 | def avg(self): 62 | d = torch.tensor(list(self.deque), dtype=torch.float32) 63 | return d.mean().item() 64 | 65 | @property 66 | def global_avg(self): 67 | return self.total / self.count 68 | 69 | @property 70 | def max(self): 71 | return max(self.deque) 72 | 73 | @property 74 | def value(self): 75 | return self.deque[-1] 76 | 77 | def __str__(self): 78 | return self.fmt.format( 79 | median=self.median, 80 | avg=self.avg, 81 | global_avg=self.global_avg, 82 | max=self.max, 83 | value=self.value) 84 | 85 | 86 | class MetricLogger(object): 87 | def __init__(self, delimiter="\t"): 88 | self.meters = defaultdict(SmoothedValue) 89 | self.delimiter = delimiter 90 | 91 | def update(self, **kwargs): 92 | for k, v in kwargs.items(): 93 | if v is None: 94 | continue 95 | if isinstance(v, torch.Tensor): 96 | v = v.item() 97 | assert isinstance(v, (float, int)) 98 | self.meters[k].update(v) 99 | 100 | def __getattr__(self, attr): 101 | if attr in self.meters: 102 | return self.meters[attr] 103 | if attr in self.__dict__: 104 | return self.__dict__[attr] 105 | raise AttributeError("'{}' object has no attribute '{}'".format( 106 | type(self).__name__, attr)) 107 | 108 | def __str__(self): 109 | loss_str = [] 110 | for name, meter in self.meters.items(): 111 | loss_str.append( 112 | "{}: {}".format(name, str(meter)) 113 | ) 114 | return self.delimiter.join(loss_str) 115 | 116 | def synchronize_between_processes(self): 117 | for meter in self.meters.values(): 118 | meter.synchronize_between_processes() 119 | 120 | def add_meter(self, name, meter): 121 | self.meters[name] = meter 122 | 123 | def log_every(self, iterable, print_freq, header=None): 124 | i = 0 125 | if not header: 126 | header = '' 127 | start_time = time.time() 128 | end = time.time() 129 | iter_time = SmoothedValue(fmt='{avg:.4f}') 130 | data_time = SmoothedValue(fmt='{avg:.4f}') 131 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 132 | log_msg = [ 133 | header, 134 | '[{0' + space_fmt + '}/{1}]', 135 | 'eta: {eta}', 136 | '{meters}', 137 | 'time: {time}', 138 | 'data: {data}' 139 | ] 140 | if torch.cuda.is_available(): 141 | log_msg.append('max mem: {memory:.0f}') 142 | log_msg = self.delimiter.join(log_msg) 143 | MB = 1024.0 * 1024.0 144 | for obj in iterable: 145 | data_time.update(time.time() - end) 146 | yield obj 147 | iter_time.update(time.time() - end) 148 | if i % print_freq == 0 or i == len(iterable) - 1: 149 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 150 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 151 | if torch.cuda.is_available(): 152 | print(log_msg.format( 153 | i, len(iterable), eta=eta_string, 154 | meters=str(self), 155 | time=str(iter_time), data=str(data_time), 156 | memory=torch.cuda.max_memory_allocated() / MB)) 157 | else: 158 | print(log_msg.format( 159 | i, len(iterable), eta=eta_string, 160 | meters=str(self), 161 | time=str(iter_time), data=str(data_time))) 162 | i += 1 163 | end = time.time() 164 | total_time = time.time() - start_time 165 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 166 | print('{} Total time: {} ({:.4f} s / it)'.format( 167 | 
header, total_time_str, total_time / len(iterable))) 168 | 169 | 170 | def setup_for_distributed(is_master): 171 | """ 172 | This function disables printing when not in master process 173 | """ 174 | builtin_print = builtins.print 175 | 176 | def print(*args, **kwargs): 177 | force = kwargs.pop('force', False) 178 | force = force or (get_world_size() > 8) 179 | if is_master or force: 180 | now = datetime.datetime.now().time() 181 | builtin_print('[{}] '.format(now), end='') # print with time stamp 182 | builtin_print(*args, **kwargs) 183 | 184 | builtins.print = print 185 | 186 | 187 | def is_dist_avail_and_initialized(): 188 | if not dist.is_available(): 189 | return False 190 | if not dist.is_initialized(): 191 | return False 192 | return True 193 | 194 | 195 | def get_world_size(): 196 | if not is_dist_avail_and_initialized(): 197 | return 1 198 | return dist.get_world_size() 199 | 200 | 201 | def get_rank(): 202 | if not is_dist_avail_and_initialized(): 203 | return 0 204 | return dist.get_rank() 205 | 206 | 207 | def is_main_process(): 208 | return get_rank() == 0 209 | 210 | 211 | def save_on_master(*args, **kwargs): 212 | if is_main_process(): 213 | torch.save(*args, **kwargs) 214 | 215 | 216 | def init_distributed_mode(args): 217 | if args.dist_on_itp: 218 | args.rank = int(os.environ['OMPI_COMM_WORLD_RANK']) 219 | args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) 220 | args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) 221 | args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT']) 222 | os.environ['LOCAL_RANK'] = str(args.gpu) 223 | os.environ['RANK'] = str(args.rank) 224 | os.environ['WORLD_SIZE'] = str(args.world_size) 225 | # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"] 226 | elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 227 | args.rank = int(os.environ["RANK"]) 228 | args.world_size = int(os.environ['WORLD_SIZE']) 229 | args.gpu = int(os.environ['LOCAL_RANK']) 230 | elif 'SLURM_PROCID' in os.environ: 231 | args.rank = int(os.environ['SLURM_PROCID']) 232 | args.gpu = args.rank % torch.cuda.device_count() 233 | else: 234 | print('Not using distributed mode') 235 | setup_for_distributed(is_master=True) # hack 236 | args.distributed = False 237 | return 238 | 239 | args.distributed = True 240 | 241 | torch.cuda.set_device(args.gpu) 242 | args.dist_backend = 'nccl' 243 | print('| distributed init (rank {}): {}, gpu {}'.format( 244 | args.rank, args.dist_url, args.gpu), flush=True) 245 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 246 | world_size=args.world_size, rank=args.rank) 247 | torch.distributed.barrier() 248 | setup_for_distributed(args.rank == 0) 249 | 250 | 251 | class NativeScalerWithGradNormCount: 252 | state_dict_key = "amp_scaler" 253 | 254 | def __init__(self): 255 | self._scaler = torch.cuda.amp.GradScaler() 256 | 257 | def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True): 258 | self._scaler.scale(loss).backward(create_graph=create_graph) 259 | if update_grad: 260 | if clip_grad is not None: 261 | assert parameters is not None 262 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place 263 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) 264 | else: 265 | self._scaler.unscale_(optimizer) 266 | norm = get_grad_norm_(parameters) 267 | self._scaler.step(optimizer) 268 | self._scaler.update() 269 | else: 270 | norm = None 271 | return 
norm 272 | 273 | def state_dict(self): 274 | return self._scaler.state_dict() 275 | 276 | def load_state_dict(self, state_dict): 277 | self._scaler.load_state_dict(state_dict) 278 | 279 | 280 | def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: 281 | if isinstance(parameters, torch.Tensor): 282 | parameters = [parameters] 283 | parameters = [p for p in parameters if p.grad is not None] 284 | norm_type = float(norm_type) 285 | if len(parameters) == 0: 286 | return torch.tensor(0.) 287 | device = parameters[0].grad.device 288 | if norm_type == inf: 289 | total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) 290 | else: 291 | total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type) 292 | return total_norm 293 | 294 | 295 | def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler): 296 | output_dir = Path(args.output_dir) 297 | epoch_name = str(epoch) 298 | if loss_scaler is not None: 299 | checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)] 300 | for checkpoint_path in checkpoint_paths: 301 | to_save = { 302 | 'model': model_without_ddp.state_dict(), 303 | 'optimizer': optimizer.state_dict(), 304 | 'epoch': epoch, 305 | 'scaler': loss_scaler.state_dict(), 306 | 'args': args, 307 | } 308 | 309 | save_on_master(to_save, checkpoint_path) 310 | else: 311 | client_state = {'epoch': epoch} 312 | model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-%s" % epoch_name, client_state=client_state) 313 | 314 | 315 | def load_model(args, model_without_ddp, optimizer, loss_scaler): 316 | if args.resume: 317 | if args.resume.startswith('https'): 318 | checkpoint = torch.hub.load_state_dict_from_url( 319 | args.resume, map_location='cpu', check_hash=True) 320 | else: 321 | checkpoint = torch.load(args.resume, map_location='cpu') 322 | model_without_ddp.load_state_dict(checkpoint['model']) 323 | print("Resume checkpoint %s" % args.resume) 324 | if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval): 325 | optimizer.load_state_dict(checkpoint['optimizer']) 326 | args.start_epoch = checkpoint['epoch'] + 1 327 | if 'scaler' in checkpoint: 328 | loss_scaler.load_state_dict(checkpoint['scaler']) 329 | print("With optim & sched!") 330 | 331 | 332 | def all_reduce_mean(x): 333 | world_size = get_world_size() 334 | if world_size > 1: 335 | x_reduce = torch.tensor(x).cuda() 336 | dist.all_reduce(x_reduce) 337 | x_reduce /= world_size 338 | return x_reduce.item() 339 | else: 340 | return x -------------------------------------------------------------------------------- /model_mae_image_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # -------------------------------------------------------- 7 | # References: 8 | # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm 9 | # DeiT: https://github.com/facebookresearch/deit 10 | # -------------------------------------------------------- 11 | 12 | from functools import partial 13 | 14 | import torch 15 | import torch.nn as nn 16 | 17 | from timm.models.vision_transformer import PatchEmbed, DropPath, Mlp 18 | 19 | from util.pos_embed import get_2d_sincos_pos_embed 20 | 21 | 22 | class Attention(nn.Module): 23 | def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): 24 | super().__init__() 25 | assert dim % num_heads == 0, 'dim should be divisible by num_heads' 26 | self.num_heads = num_heads 27 | head_dim = dim // num_heads 28 | self.scale = head_dim ** -0.5 29 | 30 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 31 | self.attn_drop = nn.Dropout(attn_drop) 32 | self.proj = nn.Linear(dim, dim) 33 | self.proj_drop = nn.Dropout(proj_drop) 34 | 35 | def forward(self, x , attention_block=False): 36 | if attention_block == False: 37 | B, N, C = x.shape 38 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 39 | q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) 40 | 41 | attn = (q @ k.transpose(-2, -1)) * self.scale 42 | attn = attn.softmax(dim=-1) 43 | attn = self.attn_drop(attn) 44 | 45 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 46 | x = self.proj(x) 47 | x = self.proj_drop(x) 48 | return x 49 | elif attention_block == True: 50 | B, N, C = x.shape 51 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 52 | q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) 53 | 54 | attn = (q @ k.transpose(-2, -1)) * self.scale 55 | attn = attn.softmax(dim=-1) 56 | return attn 57 | 58 | class LayerScale(nn.Module): 59 | def __init__(self, dim, init_values=1e-5, inplace=False): 60 | super().__init__() 61 | self.inplace = inplace 62 | self.gamma = nn.Parameter(init_values * torch.ones(dim)) 63 | 64 | def forward(self, x): 65 | return x.mul_(self.gamma) if self.inplace else x * self.gamma 66 | 67 | 68 | class Block(nn.Module): 69 | 70 | def __init__( 71 | self, 72 | dim, 73 | num_heads, 74 | mlp_ratio=4., 75 | qkv_bias=False, 76 | drop=0., 77 | attn_drop=0., 78 | init_values=None, 79 | drop_path=0., 80 | act_layer=nn.GELU, 81 | norm_layer=nn.LayerNorm 82 | ): 83 | super().__init__() 84 | self.norm1 = norm_layer(dim) 85 | self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) 86 | self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() 87 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 88 | self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() 89 | 90 | self.norm2 = norm_layer(dim) 91 | self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop) 92 | self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() 93 | self.drop_path2 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 94 | 95 | def forward(self, x , attention_block = False): 96 | if attention_block == False: 97 | x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x)))) 98 | x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) 99 | return x 100 | elif attention_block == True: 101 | attention_matrix = self.attn(self.norm1(x),True) 102 | return attention_matrix 103 | 104 | 105 | 106 | class MaskedAutoencoderViT(nn.Module): 107 | """ Masked Autoencoder with VisionTransformer backbone 108 | """ 109 | def __init__(self, img_size=224, patch_size=16, in_chans=3, 110 | embed_dim=1024, depth=24, num_heads=16, 111 | decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, 112 | mlp_ratio=4., norm_layer=nn.LayerNorm, num_classes =10 ,norm_pix_loss=False): 113 | super().__init__() 114 | 115 | # -------------------------------------------------------------------------- 116 | # MAE encoder specifics 117 | self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim) 118 | num_patches = self.patch_embed.num_patches 119 | 120 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 121 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False) # fixed sin-cos embedding 122 | 123 | self.blocks = nn.ModuleList([ 124 | Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer) 125 | for i in range(depth)]) 126 | self.norm = norm_layer(embed_dim) 127 | # -------------------------------------------------------------------------- 128 | 129 | # -------------------------------------------------------------------------- 130 | # MAE decoder specifics 131 | self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True) 132 | 133 | self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim)) 134 | 135 | self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False) # fixed sin-cos embedding 136 | 137 | self.decoder_blocks = nn.ModuleList([ 138 | Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer) 139 | for i in range(decoder_depth)]) 140 | 141 | self.decoder_norm = norm_layer(decoder_embed_dim) 142 | self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True) # decoder to patch 143 | # -------------------------------------------------------------------------- 144 | 145 | 146 | 147 | 148 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() 149 | self.norm_pix_loss = norm_pix_loss 150 | 151 | self.initialize_weights() 152 | 153 | def initialize_weights(self): 154 | # initialization 155 | # initialize (and freeze) pos_embed by sin-cos embedding 156 | pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True) 157 | self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) 158 | 159 | decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True) 160 | self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0)) 161 | 162 | # initialize patch_embed like nn.Linear (instead of nn.Conv2d) 163 | w = self.patch_embed.proj.weight.data 164 | torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) 165 | 166 | # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.) 
167 | torch.nn.init.normal_(self.cls_token, std=.02) 168 | torch.nn.init.normal_(self.mask_token, std=.02) 169 | 170 | # initialize nn.Linear and nn.LayerNorm 171 | self.apply(self._init_weights) 172 | 173 | def _init_weights(self, m): 174 | if isinstance(m, nn.Linear): 175 | # we use xavier_uniform following official JAX ViT: 176 | torch.nn.init.xavier_uniform_(m.weight) 177 | if isinstance(m, nn.Linear) and m.bias is not None: 178 | nn.init.constant_(m.bias, 0) 179 | elif isinstance(m, nn.LayerNorm): 180 | nn.init.constant_(m.bias, 0) 181 | nn.init.constant_(m.weight, 1.0) 182 | 183 | def patchify(self, imgs): 184 | """ 185 | imgs: (N, 3, H, W) 186 | x: (N, L, patch_size**2 *3) 187 | """ 188 | p = self.patch_embed.patch_size[0] 189 | assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0 190 | 191 | h = w = imgs.shape[2] // p 192 | x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p)) 193 | x = torch.einsum('nchpwq->nhwpqc', x) 194 | x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3)) 195 | return x 196 | 197 | def unpatchify(self, x): 198 | """ 199 | x: (N, L, patch_size**2 *3) 200 | imgs: (N, 3, H, W) 201 | """ 202 | p = self.patch_embed.patch_size[0] 203 | h = w = int(x.shape[1]**.5) 204 | assert h * w == x.shape[1] 205 | 206 | x = x.reshape(shape=(x.shape[0], h, w, p, p, 3)) 207 | x = torch.einsum('nhwpqc->nchpwq', x) 208 | imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p)) 209 | return imgs 210 | 211 | def random_masking(self, x, mask_ratio): 212 | """ 213 | Perform per-sample random masking by per-sample shuffling. 214 | Per-sample shuffling is done by argsort random noise. 215 | x: [N, L, D], sequence 216 | """ 217 | N, L, D = x.shape # batch, length, dim 218 | len_keep = int(L * (1 - mask_ratio)) 219 | 220 | noise = torch.rand(N, L, device=x.device) # noise in [0, 1] 221 | 222 | # sort noise for each sample 223 | ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove 224 | ids_restore = torch.argsort(ids_shuffle, dim=1) 225 | 226 | # keep the first subset 227 | ids_keep = ids_shuffle[:, :len_keep] 228 | x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) 229 | 230 | # generate the binary mask: 0 is keep, 1 is remove 231 | mask = torch.ones([N, L], device=x.device) 232 | mask[:, :len_keep] = 0 233 | # unshuffle to get the binary mask 234 | mask = torch.gather(mask, dim=1, index=ids_restore) 235 | 236 | return x_masked, mask, ids_restore 237 | 238 | def forward_vanilla(self, x): 239 | # embed patches 240 | x = self.patch_embed(x) 241 | 242 | # add pos embed w/o cls token 243 | x = x + self.pos_embed[:, 1:, :] 244 | 245 | 246 | # append cls token 247 | cls_token = self.cls_token + self.pos_embed[:, :1, :] 248 | cls_tokens = cls_token.expand(x.shape[0], -1, -1) 249 | x = torch.cat((cls_tokens, x), dim=1) 250 | 251 | # apply Transformer blocks 252 | for blk in self.blocks: 253 | x = blk(x) 254 | x = self.norm(x) 255 | 256 | return x 257 | 258 | 259 | def forward_encoder(self, x, mask_ratio): 260 | # embed patches 261 | x = self.patch_embed(x) 262 | 263 | # add pos embed w/o cls token 264 | x = x + self.pos_embed[:, 1:, :] 265 | 266 | # masking: length -> length * mask_ratio 267 | x, mask, ids_restore = self.random_masking(x, mask_ratio) 268 | 269 | # append cls token 270 | cls_token = self.cls_token + self.pos_embed[:, :1, :] 271 | cls_tokens = cls_token.expand(x.shape[0], -1, -1) 272 | x = torch.cat((cls_tokens, x), dim=1) 273 | 274 | # apply Transformer blocks 275 | for blk in self.blocks: 276 | x = blk(x) 277 | 
x = self.norm(x) 278 | 279 | return x, mask, ids_restore 280 | 281 | def forward_decoder(self, x, ids_restore): 282 | # embed tokens 283 | x = self.decoder_embed(x) 284 | 285 | # append mask tokens to sequence 286 | mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1) 287 | x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token 288 | x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle 289 | x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token 290 | 291 | # add pos embed 292 | x = x + self.decoder_pos_embed 293 | 294 | # apply Transformer blocks 295 | for blk in self.decoder_blocks: 296 | x = blk(x) 297 | x = self.decoder_norm(x) 298 | x = self.decoder_pred(x) 299 | # remove cls token 300 | x = x[:, 1:, :] 301 | 302 | return x 303 | 304 | def forward_loss(self, imgs, pred, mask): 305 | target = self.patchify(imgs) 306 | if self.norm_pix_loss: 307 | mean = target.mean(dim=-1, keepdim=True) 308 | var = target.var(dim=-1, keepdim=True) 309 | target = (target - mean) / (var + 1.e-6)**.5 310 | 311 | loss = (pred - target) ** 2 312 | loss = loss.mean(dim=-1) # [N, L], mean loss per patch 313 | 314 | loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches 315 | return loss 316 | 317 | def forward(self, imgs, mask_ratio=0.75): 318 | # This has the class token appended to it 319 | latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio) 320 | # This also has the class token 321 | latent_vanilla = self.forward_vanilla(imgs) 322 | 323 | cls_vanilla = latent_vanilla[: , 0 , :] 324 | predicted_class = self.head(cls_vanilla) # Class predictions by the network 325 | 326 | #This doesnt have class token 327 | pred_decoder = self.forward_decoder(latent, ids_restore) 328 | 329 | loss_twobranch = self.forward_loss(imgs, pred_decoder , mask) 330 | return loss_twobranch, predicted_class 331 | 332 | def forward_test(self, imgs): 333 | output = self.forward_vanilla(imgs) 334 | class_token = output[: , 0 , :] 335 | predicted_class = self.head(class_token) 336 | return predicted_class 337 | 338 | def forward_attention(self,img ,depth): 339 | x = self.patch_embed(img) 340 | x = x + self.pos_embed[:, 1:, :] 341 | cls_token = self.cls_token + self.pos_embed[:, :1, :] 342 | cls_tokens = cls_token.expand(x.shape[0], -1, -1) 343 | x = torch.cat((cls_tokens, x), dim=1) 344 | count = 0 345 | # apply Transformer blocks 346 | for blk in self.blocks: 347 | if depth == count: 348 | attention_matrix = blk(x , True) 349 | break 350 | else: 351 | x = blk(x) 352 | count = count + 1 353 | 354 | 355 | return attention_matrix 356 | 357 | 358 | 359 | 360 | 361 | def mae_vit_tiny_dec128d2b(**kwargs): 362 | model = MaskedAutoencoderViT( 363 | embed_dim=192, depth=12, num_heads=3, 364 | decoder_embed_dim=128, decoder_depth=2, decoder_num_heads=16, 365 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs 366 | ) 367 | return model 368 | 369 | 370 | def mae_vit_base_dec512d8b(**kwargs): 371 | model = MaskedAutoencoderViT( 372 | embed_dim=768, depth=12, num_heads=12, 373 | decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, 374 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 375 | return model 376 | 377 | 378 | def mae_vit_large_dec512d8b(**kwargs): 379 | model = MaskedAutoencoderViT( 380 | embed_dim=1024, depth=24, num_heads=16, 381 | decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, 382 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 383 
| return model 384 | 385 | 386 | def mae_vit_huge_dec512d8b(**kwargs): 387 | model = MaskedAutoencoderViT( 388 | embed_dim=1280, depth=32, num_heads=16, 389 | decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, 390 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 391 | return model 392 | 393 | 394 | 395 | 396 | def mae_vit_small_dec128d2b(**kwargs): 397 | model = MaskedAutoencoderViT( 398 | embed_dim=384, depth=12, num_heads=6, 399 | decoder_embed_dim=128, decoder_depth=2, decoder_num_heads=16, 400 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs 401 | ) 402 | return model 403 | 404 | 405 | # set recommended archs 406 | mae_vit_tiny = mae_vit_tiny_dec128d2b 407 | mae_vit_small = mae_vit_small_dec128d2b 408 | mae_vit_base = mae_vit_base_dec512d8b # decoder: 512 dim, 8 blocks 409 | mae_vit_large = mae_vit_large_dec512d8b # decoder: 512 dim, 8 blocks 410 | mae_vit_huge = mae_vit_huge_dec512d8b # decoder: 512 dim, 8 blocks 411 | -------------------------------------------------------------------------------- /main_two_branch.py: -------------------------------------------------------------------------------- 1 | 2 | import argparse 3 | import datetime 4 | import json 5 | import numpy as np 6 | import os 7 | import time 8 | import wandb 9 | from pathlib import Path 10 | 11 | import torch 12 | import torch.backends.cudnn as cudnn 13 | from torch.utils.tensorboard import SummaryWriter 14 | 15 | import timm 16 | 17 | assert timm.__version__ == "0.6.12" # version check 18 | from timm.models.layers import trunc_normal_ 19 | from timm.data.mixup import Mixup 20 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy 21 | import timm.optim.optim_factory as optim_factory 22 | 23 | import util.lr_decay as lrd 24 | import util.misc as misc 25 | from util.datasets import build_dataset 26 | from util.pos_embed import interpolate_pos_embed 27 | from util.misc import NativeScalerWithGradNormCount as NativeScaler 28 | 29 | import model_mae_image_loss as models_mae 30 | from engine_two_branch import train_one_epoch, evaluate 31 | 32 | 33 | def get_args_parser(): 34 | parser = argparse.ArgumentParser('MAE fine-tuning for image classification', add_help=False) 35 | parser.add_argument('--batch_size', default=64, type=int, 36 | help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus') 37 | parser.add_argument('--epochs', default=100, type=int) 38 | parser.add_argument('--accum_iter', default=1, type=int, 39 | help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)') 40 | 41 | # Model parameters 42 | parser.add_argument('--model', default='mae_vit_tiny', type=str, metavar='MODEL', 43 | help='Name of model to train') 44 | parser.add_argument('--norm_pix_loss', action='store_true', 45 | help='Use (per-patch) normalized pixels as targets for computing loss') 46 | parser.set_defaults(norm_pix_loss=False) 47 | 48 | parser.add_argument('--input_size', default=224, type=int, 49 | help='images input size') 50 | parser.add_argument('--patch_size', default=16, type=int, 51 | help='images input size') 52 | parser.add_argument('--mask_ratio', default=0.75, type=float, 53 | help='Masking ratio (percentage of removed patches).') 54 | 55 | parser.add_argument('--lambda_weight', default=0.1, type=float, 56 | help='Loss weightage .') 57 | 58 | parser.add_argument('--drop_path', type=float, default=0.1, metavar='PCT', 59 | help='Drop path rate (default: 0.1)') 60 | 61 | # Optimizer 
parameters 62 | parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM', 63 | help='Clip gradient norm (default: None, no clipping)') 64 | parser.add_argument('--weight_decay', type=float, default=0.05, 65 | help='weight decay (default: 0.05)') 66 | 67 | parser.add_argument('--lr', type=float, default=None, metavar='LR', 68 | help='learning rate (absolute lr)') 69 | parser.add_argument('--blr', type=float, default=1e-3, metavar='LR', 70 | help='base learning rate: absolute_lr = base_lr * total_batch_size / 256') 71 | parser.add_argument('--layer_decay', type=float, default=0.75, 72 | help='layer-wise lr decay from ELECTRA/BEiT') 73 | 74 | parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR', 75 | help='lower lr bound for cyclic schedulers that hit 0') 76 | 77 | parser.add_argument('--warmup_epochs', type=int, default=5, metavar='N', 78 | help='epochs to warmup LR') 79 | 80 | # Augmentation parameters 81 | parser.add_argument('--color_jitter', type=float, default=None, metavar='PCT', 82 | help='Color jitter factor (enabled only when not using Auto/RandAug)') 83 | parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME', 84 | help='Use AutoAugment policy. "v0" or "original". " + "(default: rand-m9-mstd0.5-inc1)'), 85 | parser.add_argument('--smoothing', type=float, default=0.1, 86 | help='Label smoothing (default: 0.1)') 87 | 88 | # * Random Erase params 89 | parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT', 90 | help='Random erase prob (default: 0.25)') 91 | parser.add_argument('--remode', type=str, default='pixel', 92 | help='Random erase mode (default: "pixel")') 93 | parser.add_argument('--recount', type=int, default=1, 94 | help='Random erase count (default: 1)') 95 | parser.add_argument('--resplit', action='store_true', default=False, 96 | help='Do not random erase first (clean) augmentation split') 97 | 98 | # * Mixup params 99 | parser.add_argument('--mixup', type=float, default=0, 100 | help='mixup alpha, mixup enabled if > 0.') 101 | parser.add_argument('--cutmix', type=float, default=0, 102 | help='cutmix alpha, cutmix enabled if > 0.') 103 | parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None, 104 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)') 105 | parser.add_argument('--mixup_prob', type=float, default=1.0, 106 | help='Probability of performing mixup or cutmix when either/both is enabled') 107 | parser.add_argument('--mixup_switch_prob', type=float, default=0.5, 108 | help='Probability of switching to cutmix when both mixup and cutmix enabled') 109 | parser.add_argument('--mixup_mode', type=str, default='batch', 110 | help='How to apply mixup/cutmix params. 
Per "batch", "pair", or "elem"') 111 | 112 | # * Finetuning params 113 | parser.add_argument('--finetune', default='', 114 | help='finetune from checkpoint') 115 | parser.add_argument('--global_pool', action='store_true') 116 | parser.set_defaults(global_pool=True) 117 | parser.add_argument('--cls_token', action='store_false', dest='global_pool', 118 | help='Use class token instead of global pool for classification') 119 | 120 | # Dataset parameters 121 | parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str, 122 | help='dataset path') 123 | parser.add_argument('--nb_classes', default=1000, type=int, 124 | help='number of the classification types') 125 | parser.add_argument('--subset_size', default=1, type=float, 126 | help='percentage (as decimal) of imagenet subset to use') 127 | parser.add_argument('--perturb_perspective', action='store_true', 128 | help='whether to perform random perspective transform on images') 129 | 130 | parser.add_argument('--output_dir', default='./output_dir', 131 | help='path where to save, empty for no saving') 132 | parser.add_argument('--log_dir', default='./output_dir', 133 | help='path where to tensorboard log') 134 | parser.add_argument('--device', default='cuda', 135 | help='device to use for training / testing') 136 | parser.add_argument('--seed', default=0, type=int) 137 | parser.add_argument('--resume', default='', 138 | help='resume from checkpoint') 139 | 140 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 141 | help='start epoch') 142 | parser.add_argument('--eval', action='store_true', 143 | help='Perform evaluation only') 144 | parser.add_argument('--dist_eval', action='store_true', default=False, 145 | help='Enabling distributed evaluation (recommended during training for faster monitor') 146 | parser.add_argument('--num_workers', default=10, type=int) 147 | parser.add_argument('--pin_mem', action='store_true', 148 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 149 | parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem') 150 | parser.set_defaults(pin_mem=True) 151 | 152 | # distributed training parameters 153 | parser.add_argument('--world_size', default=1, type=int, 154 | help='number of distributed processes') 155 | parser.add_argument('--local_rank', default=-1, type=int) 156 | parser.add_argument('--dist_on_itp', action='store_true') 157 | parser.add_argument('--dist_url', default='env://', 158 | help='url used to set up distributed training') 159 | 160 | return parser 161 | 162 | 163 | 164 | def main(args): 165 | # wandb.init(project='MAE-Project', entity='tanmayj2020') 166 | # config = wandb.config 167 | # config.dataset = args.dataset 168 | # config.model = args.model 169 | # config.epoch = args.epochs 170 | # config.classification_loss_ratio= args.lambda_weight 171 | 172 | misc.init_distributed_mode(args) 173 | 174 | print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__)))) 175 | print("{}".format(args).replace(', ', ',\n')) 176 | 177 | device = torch.device(args.device) 178 | 179 | # fix the seed for reproducibility 180 | seed = args.seed + misc.get_rank() 181 | torch.manual_seed(seed) 182 | np.random.seed(seed) 183 | 184 | cudnn.benchmark = True 185 | 186 | dataset_train = build_dataset(is_train=True, args=args) 187 | dataset_val = build_dataset(is_train=False, args=args) 188 | 189 | 190 | if True: # args.distributed: 191 | num_tasks = misc.get_world_size() 192 | global_rank = misc.get_rank() 193 | 
sampler_train = torch.utils.data.DistributedSampler( 194 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 195 | ) 196 | print("Sampler_train = %s" % str(sampler_train)) 197 | if args.dist_eval: 198 | if len(dataset_val) % num_tasks != 0: 199 | print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. ' 200 | 'This will slightly alter validation results as extra duplicate entries are added to achieve ' 201 | 'equal num of samples per-process.') 202 | sampler_val = torch.utils.data.DistributedSampler( 203 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=True) # shuffle=True to reduce monitor bias 204 | else: 205 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 206 | else: 207 | sampler_train = torch.utils.data.RandomSampler(dataset_train) 208 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 209 | 210 | if global_rank == 0 and args.log_dir is not None and not args.eval: 211 | os.makedirs(args.log_dir, exist_ok=True) 212 | log_writer = SummaryWriter(log_dir=args.log_dir) 213 | else: 214 | log_writer = None 215 | 216 | data_loader_train = torch.utils.data.DataLoader( 217 | dataset_train, sampler=sampler_train, 218 | batch_size=args.batch_size, 219 | num_workers=args.num_workers, 220 | pin_memory=args.pin_mem, 221 | drop_last=True, 222 | ) 223 | 224 | data_loader_val = torch.utils.data.DataLoader( 225 | dataset_val, sampler=sampler_val, 226 | batch_size=args.batch_size, 227 | num_workers=args.num_workers, 228 | pin_memory=args.pin_mem, 229 | drop_last=False 230 | ) 231 | 232 | mixup_fn = None 233 | mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None 234 | if mixup_active: 235 | print("Mixup is activated!") 236 | mixup_fn = Mixup( 237 | mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, 238 | prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, 239 | label_smoothing=args.smoothing, num_classes=args.nb_classes) 240 | 241 | model = models_mae.__dict__[args.model]( 242 | patch_size= args.patch_size, 243 | img_size= args.input_size, 244 | num_classes=args.nb_classes, 245 | norm_pix_loss = args.norm_pix_loss 246 | ) 247 | 248 | if args.finetune and not args.eval: 249 | checkpoint = torch.load(args.finetune, map_location='cpu') 250 | 251 | print("Load pre-trained checkpoint from: %s" % args.finetune) 252 | checkpoint_model = checkpoint['model'] 253 | state_dict = model.state_dict() 254 | for k in ['head.weight', 'head.bias']: 255 | if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape: 256 | print(f"Removing key {k} from pretrained checkpoint") 257 | del checkpoint_model[k] 258 | 259 | # interpolate position embedding 260 | interpolate_pos_embed(model, checkpoint_model) 261 | 262 | # load pre-trained model 263 | msg = model.load_state_dict(checkpoint_model, strict=False) 264 | print(msg) 265 | 266 | if args.global_pool: 267 | assert set(msg.missing_keys) == {'head.weight', 'head.bias', 'fc_norm.weight', 'fc_norm.bias'} 268 | else: 269 | assert set(msg.missing_keys) == {'head.weight', 'head.bias'} 270 | 271 | # manually initialize fc layer 272 | trunc_normal_(model.head.weight, std=2e-5) 273 | 274 | model.to(device) 275 | 276 | model_without_ddp = model 277 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 278 | 279 | print("Model = %s" % str(model_without_ddp)) 280 | print('number of params (M): %.2f' % (n_parameters / 1.e6)) 281 | 282 | 
eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size() 283 | 284 | if args.lr is None: # only base_lr is specified 285 | args.lr = args.blr * eff_batch_size / 256 286 | 287 | print("base lr: %.2e" % (args.lr * 256 / eff_batch_size)) 288 | print("actual lr: %.2e" % args.lr) 289 | 290 | print("accumulate grad iterations: %d" % args.accum_iter) 291 | print("effective batch size: %d" % eff_batch_size) 292 | 293 | if args.distributed: 294 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 295 | model_without_ddp = model.module 296 | 297 | # build optimizer with layer-wise lr decay (lrd) 298 | param_groups = optim_factory.param_groups_weight_decay(model_without_ddp, args.weight_decay) 299 | optimizer = torch.optim.AdamW(param_groups, lr=args.lr) 300 | loss_scaler = NativeScaler() 301 | 302 | if mixup_fn is not None: 303 | # smoothing is handled with mixup label transform 304 | criterion = SoftTargetCrossEntropy() 305 | elif args.smoothing > 0.: 306 | criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing) 307 | else: 308 | criterion = torch.nn.CrossEntropyLoss() 309 | 310 | print("criterion = %s" % str(criterion)) 311 | 312 | misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler) 313 | 314 | if args.eval: 315 | test_stats = evaluate(data_loader_val, model, device) 316 | print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") 317 | exit(0) 318 | 319 | print(f"Start training for {args.epochs} epochs") 320 | start_time = time.time() 321 | max_accuracy = 0.0 322 | for epoch in range(args.start_epoch, args.epochs): 323 | if args.distributed: 324 | data_loader_train.sampler.set_epoch(epoch) 325 | train_stats = train_one_epoch( 326 | model, criterion, data_loader_train, 327 | optimizer, device, epoch, loss_scaler, 328 | args.clip_grad, mixup_fn, 329 | log_writer=log_writer, 330 | args=args 331 | ) 332 | if args.output_dir: 333 | if (epoch +1)% 10 == 0: 334 | misc.save_model( 335 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, 336 | loss_scaler=loss_scaler, epoch=epoch) 337 | 338 | # test_stats = evaluate(data_loader_val, model, device) # Orig line 339 | test_stats = evaluate(data_loader_val, model_without_ddp, device) 340 | print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") 341 | max_accuracy = max(max_accuracy, test_stats["acc1"]) 342 | print(f'Max accuracy: {max_accuracy:.2f}%') 343 | 344 | if log_writer is not None: 345 | log_writer.add_scalar('perf/test_acc1', test_stats['acc1'], epoch) 346 | log_writer.add_scalar('perf/test_acc5', test_stats['acc5'], epoch) 347 | log_writer.add_scalar('perf/test_loss', test_stats['testing_loss'], epoch) 348 | 349 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 350 | **{f'test_{k}': v for k, v in test_stats.items()}, 351 | 'epoch': epoch, 352 | 'n_parameters': n_parameters} 353 | # wandb.log({"epoch" : epoch , **train_stats , **test_stats}) 354 | 355 | if args.output_dir and misc.is_main_process(): 356 | if log_writer is not None: 357 | log_writer.flush() 358 | with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f: 359 | f.write(json.dumps(log_stats) + "\n") 360 | 361 | total_time = time.time() - start_time 362 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 363 | print('Training time {}'.format(total_time_str)) 364 | 365 | 366 | if __name__ == '__main__': 
367 | # Build the argument parser 368 | parser = get_args_parser() 369 | # Parse command-line arguments 370 | args = parser.parse_args() 371 | if args.output_dir: 372 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 373 | # Launch training / evaluation 374 | main(args) 375 | --------------------------------------------------------------------------------
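
Usage sketch (illustrative only, not a file in the repository): the snippet below shows one way to exercise the two-branch MaskedAutoencoderViT defined in model_mae_image_loss.py, which main_two_branch.py imports as models_mae. How engine_two_branch.py actually combines the reconstruction and classification losses is not shown in this listing, so the weighting with lambda_weight below is an assumption based on the --lambda_weight help string; the dummy tensors and shape comments assume the default 224x224 input with 16x16 patches.

import torch
import torch.nn as nn

import model_mae_image_loss as models_mae

# Recommended small-data architecture: ViT-Tiny encoder, 128-dim / 2-block decoder.
model = models_mae.mae_vit_tiny(img_size=224, patch_size=16, num_classes=1000)
model.train()

imgs = torch.randn(2, 3, 224, 224)       # dummy image batch
labels = torch.randint(0, 1000, (2,))    # dummy class targets
criterion = nn.CrossEntropyLoss()

# forward() runs both branches: the masked encoder/decoder for reconstruction,
# and the unmasked encoder whose [CLS] token feeds the classification head.
recon_loss, logits = model(imgs, mask_ratio=0.75)    # logits: (2, 1000)
cls_loss = criterion(logits, labels)

lambda_weight = 0.1                                   # mirrors the --lambda_weight default
total_loss = cls_loss + lambda_weight * recon_loss    # assumed combination; see engine_two_branch.py
total_loss.backward()

# Evaluation uses only the unmasked branch.
model.eval()
with torch.no_grad():
    eval_logits = model.forward_test(imgs)            # (2, 1000)

# Attention map of a chosen block (0-indexed), e.g. the last of ViT-Tiny's 12 blocks.
attn = model.forward_attention(imgs, depth=11)        # (2, num_heads, 197, 197)

A plausible single-process launch of the training script, using only arguments defined in get_args_parser (distributed flags such as --world_size and --dist_url are handled by util.misc.init_distributed_mode and omitted here; the dataset path is a placeholder):

python main_two_branch.py --model mae_vit_tiny --data_path /path/to/imagenet --nb_classes 1000 --batch_size 64 --epochs 100 --lambda_weight 0.1 --output_dir ./output_dir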