├── can.png
├── config
├── util
│   ├── lr_sched.py
│   ├── crop.py
│   ├── lars.py
│   ├── datasets.py
│   ├── lr_decay.py
│   ├── pos_embed.py
│   └── misc.py
├── util_contrastive.py
├── models_vit.py
├── README.md
├── can.yml
├── engine_pretrain.py
├── loss_contrastive.py
├── submitit_finetune.py
├── submitit_linprobe.py
├── submitit_pretrain.py
├── engine_finetune.py
├── main_pretrain.py
├── models_mae.py
├── main_linprobe.py
└── main_finetune.py

/can.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shlokk/mae-contrastive/HEAD/can.png
--------------------------------------------------------------------------------
/config:
--------------------------------------------------------------------------------
1 | [core]
2 |     repositoryformatversion = 0
3 |     filemode = true
4 |     bare = true
--------------------------------------------------------------------------------
/util/lr_sched.py:
--------------------------------------------------------------------------------
1 | import math
2 | 
3 | def adjust_learning_rate(optimizer, epoch, args):
4 |     """Decay the learning rate with half-cycle cosine after warmup"""
5 |     if epoch < args.warmup_epochs:
6 |         lr = args.lr * epoch / args.warmup_epochs
7 |     else:
8 |         lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \
9 |             (1.
+ math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs))) 10 | for param_group in optimizer.param_groups: 11 | if "lr_scale" in param_group: 12 | param_group["lr"] = lr * param_group["lr_scale"] 13 | else: 14 | param_group["lr"] = lr 15 | return lr 16 | -------------------------------------------------------------------------------- /util_contrastive.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import math 4 | import numpy as np 5 | import torch 6 | import torch.optim as optim 7 | from PIL import ImageFilter 8 | import random 9 | 10 | 11 | class TwoCropTransform: 12 | """Create two crops of the same image""" 13 | def __init__(self, transform): 14 | self.transform = transform 15 | 16 | def __call__(self, x): 17 | return [self.transform(x), self.transform(x)] 18 | 19 | class GaussianBlur(object): 20 | """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709""" 21 | 22 | def __init__(self, sigma=[.1, 2.]): 23 | self.sigma = sigma 24 | 25 | def __call__(self, x): 26 | sigma = random.uniform(self.sigma[0], self.sigma[1]) 27 | x = x.filter(ImageFilter.GaussianBlur(radius=sigma)) 28 | return x -------------------------------------------------------------------------------- /util/crop.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | 5 | from torchvision import transforms 6 | from torchvision.transforms import functional as F 7 | 8 | 9 | class RandomResizedCrop(transforms.RandomResizedCrop): 10 | """ 11 | RandomResizedCrop for matching TF/TPU implementation: no for-loop is used. 12 | This may lead to results different with torchvision's version. 13 | Following BYOL's TF code: 14 | https://github.com/deepmind/deepmind-research/blob/master/byol/utils/dataset.py#L206 15 | """ 16 | @staticmethod 17 | def get_params(img, scale, ratio): 18 | width, height = F._get_image_size(img) 19 | area = height * width 20 | 21 | target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item() 22 | log_ratio = torch.log(torch.tensor(ratio)) 23 | aspect_ratio = torch.exp( 24 | torch.empty(1).uniform_(log_ratio[0], log_ratio[1]) 25 | ).item() 26 | 27 | w = int(round(math.sqrt(target_area * aspect_ratio))) 28 | h = int(round(math.sqrt(target_area / aspect_ratio))) 29 | 30 | w = min(w, width) 31 | h = min(h, height) 32 | 33 | i = torch.randint(0, height - h + 1, size=(1,)).item() 34 | j = torch.randint(0, width - w + 1, size=(1,)).item() 35 | 36 | return i, j, h, w -------------------------------------------------------------------------------- /util/lars.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class LARS(torch.optim.Optimizer): 5 | """ 6 | LARS optimizer, no rate scaling or weight decay for parameters <= 1D. 
7 | """ 8 | def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, trust_coefficient=0.001): 9 | defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, trust_coefficient=trust_coefficient) 10 | super().__init__(params, defaults) 11 | 12 | @torch.no_grad() 13 | def step(self): 14 | for g in self.param_groups: 15 | for p in g['params']: 16 | dp = p.grad 17 | 18 | if dp is None: 19 | continue 20 | 21 | if p.ndim > 1: # if not normalization gamma/beta or bias 22 | dp = dp.add(p, alpha=g['weight_decay']) 23 | param_norm = torch.norm(p) 24 | update_norm = torch.norm(dp) 25 | one = torch.ones_like(param_norm) 26 | q = torch.where(param_norm > 0., 27 | torch.where(update_norm > 0, 28 | (g['trust_coefficient'] * param_norm / update_norm), one), 29 | one) 30 | dp = dp.mul(q) 31 | 32 | param_state = self.state[p] 33 | if 'mu' not in param_state: 34 | param_state['mu'] = torch.zeros_like(p) 35 | mu = param_state['mu'] 36 | mu.mul_(g['momentum']).add_(dp) 37 | p.add_(mu, alpha=-g['lr']) -------------------------------------------------------------------------------- /models_vit.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | import timm.models.vision_transformer 7 | 8 | 9 | class VisionTransformer(timm.models.vision_transformer.VisionTransformer): 10 | """ Vision Transformer with support for global average pooling 11 | """ 12 | def __init__(self, global_pool=False, **kwargs): 13 | super(VisionTransformer, self).__init__(**kwargs) 14 | 15 | self.global_pool = global_pool 16 | if self.global_pool: 17 | norm_layer = kwargs['norm_layer'] 18 | embed_dim = kwargs['embed_dim'] 19 | self.fc_norm = norm_layer(embed_dim) 20 | 21 | del self.norm # remove the original norm 22 | 23 | def forward_features(self, x): 24 | B = x.shape[0] 25 | x = self.patch_embed(x) 26 | 27 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 28 | x = torch.cat((cls_tokens, x), dim=1) 29 | x = x + self.pos_embed 30 | x = self.pos_drop(x) 31 | 32 | for blk in self.blocks: 33 | x = blk(x) 34 | 35 | if self.global_pool: 36 | x = x[:, 1:, :].mean(dim=1) # global pool without cls token 37 | outcome = self.fc_norm(x) 38 | else: 39 | x = self.norm(x) 40 | outcome = x[:, 0] 41 | 42 | return outcome 43 | 44 | 45 | def vit_base_patch16(**kwargs): 46 | model = VisionTransformer( 47 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 48 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 49 | return model 50 | 51 | 52 | def vit_large_patch16(**kwargs): 53 | model = VisionTransformer( 54 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, 55 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 56 | return model 57 | 58 | 59 | def vit_huge_patch14(**kwargs): 60 | model = VisionTransformer( 61 | patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, qkv_bias=True, 62 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 63 | return model -------------------------------------------------------------------------------- /util/datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import PIL 3 | 4 | from torchvision import datasets, transforms 5 | from util_contrastive import TwoCropTransform 6 | 7 | from timm.data import create_transform 8 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 9 | 10 | 
11 | def build_dataset(is_train, args):
12 |     transform = build_transform(is_train, args)
13 | 
14 |     root = os.path.join(args.data_path, 'train' if is_train else 'val')
15 |     if is_train:
16 |         dataset = datasets.ImageFolder(root, transform=TwoCropTransform(transform))
17 |     else:
18 |         dataset = datasets.ImageFolder(root, transform=transform)
19 | 
20 |     print(dataset)
21 | 
22 |     return dataset
23 | 
24 | 
25 | def build_transform(is_train, args):
26 |     mean = IMAGENET_DEFAULT_MEAN
27 |     std = IMAGENET_DEFAULT_STD
28 |     # train transform
29 |     if is_train:
30 |         # this would normally dispatch to transforms_imagenet_train:
31 |         # transform = create_transform(
32 |         #     input_size=args.input_size,
33 |         #     is_training=True,
34 |         #     color_jitter=args.color_jitter,
35 |         #     auto_augment=args.aa,
36 |         #     interpolation='bicubic',
37 |         #     re_prob=args.reprob,
38 |         #     re_mode=args.remode,
39 |         #     re_count=args.recount,
40 |         #     mean=mean,
41 |         #     std=std,
42 |         # )
43 |         # return transform
44 |         normalize = transforms.Normalize(mean=mean, std=std)
45 |         train_transform = transforms.Compose([
46 |             transforms.RandomResizedCrop(size=args.input_size, scale=(0.2, 1.)),
47 |             transforms.RandomHorizontalFlip(),
48 |             transforms.RandomApply([
49 |                 transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
50 |             ], p=0.8),
51 |             transforms.RandomGrayscale(p=0.2),
52 |             transforms.ToTensor(),
53 |             normalize,
54 |         ])
55 |         return train_transform
56 | 
57 |     # eval transform
58 |     t = []
59 |     if args.input_size <= 224:
60 |         crop_pct = 224 / 256
61 |     else:
62 |         crop_pct = 1.0
63 |     size = int(args.input_size / crop_pct)
64 |     t.append(
65 |         transforms.Resize(size, interpolation=PIL.Image.BICUBIC),  # to maintain same ratio w.r.t. 224 images
66 |     )
67 |     t.append(transforms.CenterCrop(args.input_size))
68 | 
69 |     t.append(transforms.ToTensor())
70 |     t.append(transforms.Normalize(mean, std))
71 |     return transforms.Compose(t)
--------------------------------------------------------------------------------
/util/lr_decay.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | 
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # References:
8 | # ELECTRA https://github.com/google-research/electra
9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit
10 | # --------------------------------------------------------
11 | 
12 | import json
13 | 
14 | 
15 | def param_groups_lrd(model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=.75):
16 |     """
17 |     Parameter groups for layer-wise lr decay
18 |     Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58
19 |     """
20 |     param_group_names = {}
21 |     param_groups = {}
22 | 
23 |     num_layers = len(model.blocks) + 1
24 | 
25 |     layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1))
26 | 
27 |     for n, p in model.named_parameters():
28 |         if not p.requires_grad:
29 |             continue
30 | 
31 |         # no decay: all 1D parameters and model specific ones
32 |         if p.ndim == 1 or n in no_weight_decay_list:
33 |             g_decay = "no_decay"
34 |             this_decay = 0.
35 | else: 36 | g_decay = "decay" 37 | this_decay = weight_decay 38 | 39 | layer_id = get_layer_id_for_vit(n, num_layers) 40 | group_name = "layer_%d_%s" % (layer_id, g_decay) 41 | 42 | if group_name not in param_group_names: 43 | this_scale = layer_scales[layer_id] 44 | 45 | param_group_names[group_name] = { 46 | "lr_scale": this_scale, 47 | "weight_decay": this_decay, 48 | "params": [], 49 | } 50 | param_groups[group_name] = { 51 | "lr_scale": this_scale, 52 | "weight_decay": this_decay, 53 | "params": [], 54 | } 55 | 56 | param_group_names[group_name]["params"].append(n) 57 | param_groups[group_name]["params"].append(p) 58 | 59 | # print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2)) 60 | 61 | return list(param_groups.values()) 62 | 63 | 64 | def get_layer_id_for_vit(name, num_layers): 65 | """ 66 | Assign a parameter with its layer id 67 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33 68 | """ 69 | if name in ['cls_token', 'pos_embed']: 70 | return 0 71 | elif name.startswith('patch_embed'): 72 | return 0 73 | elif name.startswith('blocks'): 74 | return int(name.split('.')[1]) + 1 75 | else: 76 | return num_layers -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CAN: A simple, efficient and scalable contrastive masked autoencoder for learning visual representations 2 | 3 | Official PyTorch implementation of ["A simple, efficient and scalable contrastive masked autoencoder for learning visual representations"](https://arxiv.org/abs/2210.16870). 4 | 5 |
6 | ![CAN method overview](can.png)
7 | 
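For reference, a minimal sketch of loading the released pretraining checkpoint (see "Pre-trained models" below) into the ViT encoder for evaluation; the local path is illustrative, and the checkpoint layout (a dict with a `'model'` key) is assumed to follow `util.misc.save_model`:

```python
# Hedged sketch: load a CAN pretraining checkpoint into models_vit for evaluation.
import torch

import models_vit
from util.pos_embed import interpolate_pos_embed

model = models_vit.vit_base_patch16(num_classes=1000, global_pool=False)

ckpt = torch.load('checkpoint-49.pth', map_location='cpu')  # assumed local path
state_dict = ckpt['model']
interpolate_pos_embed(model, state_dict)             # adjust pos_embed if input size differs
msg = model.load_state_dict(state_dict, strict=False)
print(msg.missing_keys)                              # the classifier head is expected to be missing
```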
8 | 
9 | - The original implementation was in JAX+TPU. This re-implementation is in PyTorch+GPU.
10 | 
11 | ## Requirements
12 | - Instructions for creating the conda environment:
13 | 14 | 15 | ``` 16 | conda env create -f can.yml 17 | conda activate can 18 | ``` 19 | 20 | ## Instructions for running CAN
21 | ```
22 | git clone https://github.com/shlokk/mae-contrastive.git
23 | cd mae-contrastive
24 | ```
25 | 
26 | 
27 | Script for running CAN:
28 | 
29 | ```
30 | OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=4 main_pretrain.py \
31 |     --data_path path_to_imagenet --output_dir can_noise_baseline --log_dir can_baseline_logs \
32 |     --num_workers 8 --blr 2.5e-4 --weight_decay 0.05 --model mae_vit_base_patch16 \
33 |     --batch_size 64 --dist_url 'tcp://localhost:10004' --epochs 50 --weight_simclr 0.03 \
34 |     --weight_mae 0.97 --accum_iter 4
35 | ```
36 | 
37 | Script for running MAE baseline:
38 | 
39 | ```
40 | OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=4 main_pretrain.py \
41 |     --data_path path_to_imagenet --output_dir mae_baseline --log_dir mae_baseline_logs \
42 |     --num_workers 8 --blr 1.5e-4 --weight_decay 0.05 --model mae_vit_base_patch16 \
43 |     --batch_size 64 --dist_url 'tcp://localhost:10004' --epochs 50 --weight_simclr 0 \
44 |     --weight_mae 1.0 --accum_iter 4
45 | ```
46 | 
47 | Script for running linear evaluation:
48 | ```
49 | OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=4 main_linprobe.py \
50 |     --data_path path_to_imagenet --batch_size 512 --model vit_base_patch16 --cls_token \
51 |     --finetune can_noise_baseline/checkpoint-49.pth --epochs 90 --blr 0.1 --weight_decay 0.0 \
52 |     --dist_eval --output_dir mae_baseline_lineval
53 | ```
54 | 
55 | ## Pre-trained models
56 | - We have released a model pretrained for 50 epochs [here](https://drive.google.com/file/d/18yVmZmKenM-cZh5o6hmcswvS2ePhuDk_/view?usp=sharing).
57 | - We will be releasing checkpoints from longer pretraining runs (800 and 1600 epochs) soon.
58 | 
59 | 
60 | This repo is heavily inspired by the MAE repo: https://github.com/facebookresearch/mae.
61 | 
62 | ## Citation
63 | ```bibtex
64 | @article{mishra2022simple,
65 |   title={A simple, efficient and scalable contrastive masked autoencoder for learning visual representations},
66 |   author={Mishra, Shlok and Robinson, Joshua and Chang, Huiwen and Jacobs, David and Sarna, Aaron and Maschinot, Aaron and Krishnan, Dilip},
67 |   journal={arXiv preprint arXiv:2210.16870},
68 |   year={2022}
69 | }
70 | ```
--------------------------------------------------------------------------------
/can.yml:
--------------------------------------------------------------------------------
1 | name: can
2 | channels:
3 |   - iopath
4 |   - pytorch
5 |   - vissl
6 |   - conda-forge
7 |   - defaults
8 | dependencies:
9 |   - _libgcc_mutex=0.1=main
10 |   - antlr-python-runtime=4.8=py38h32f6830_2
11 |   - apex=0.0=py38_cu102_pyt171
12 |   - blas=1.0=mkl
13 |   - ca-certificates=2021.4.13=h06a4308_1
14 |   - certifi=2020.12.5=py38h06a4308_0
15 |   - cudatoolkit=10.2.89=hfd86e86_1
16 |   - faiss-gpu=1.7.0=py3.8_h080d439_0_cuda10.2
17 |   - freetype=2.10.4=h5ab3b9f_0
18 |   - fvcore=0.1.3.post20210223=pyhd8ed1ab_0
19 |   - hydra-core=1.0.6=pyhd8ed1ab_1
20 |   - importlib_resources=5.1.2=py38h578d9bd_0
21 |   - intel-openmp=2020.2=254
22 |   - iopath=0.1.8=py38
23 |   - joblib=1.0.1=pyhd8ed1ab_0
24 |   - jpeg=9b=h024ee3a_2
25 |   - lcms2=2.12=h3be6417_0
26 |   - ld_impl_linux-64=2.33.1=h53a641e_7
27 |   - libfaiss=1.7.0=h4fe19ad_0_cuda10.2
28 |   - libffi=3.3=he6710b0_2
29 |   - libgcc-ng=9.1.0=hdf63c60_0
30 |   - libgfortran-ng=7.5.0=h14aa051_19
31 |   - libgfortran4=7.5.0=h14aa051_19
32 |   - libpng=1.6.37=hbc83047_0
33 |   - libstdcxx-ng=9.1.0=hdf63c60_0
34 |   - libtiff=4.1.0=h2733197_1
35 |   - libuv=1.40.0=h7b6447c_0
36 |   - lz4-c=1.9.3=h2531618_0
37 |   - mkl=2020.2=256
38 |   - mkl-service=2.3.0=py38he904b0f_0
39 |   - mkl_fft=1.3.0=py38h54f3939_0
40 |   - mkl_random=1.1.1=py38h0573a6f_0
41 |   - ncurses=6.2=he6710b0_1
42 |   - ninja=1.10.2=hff7bd54_1
43 |   - numpy=1.19.2=py38h54aff64_0
44 |   - numpy-base=1.19.2=py38hfa32c7d_0
45 |   - olefile=0.46=py_0
46 |   - omegaconf=2.0.6=py38h578d9bd_0
47 |   - openssl=1.1.1k=h27cfd23_0
48 |   - pandas=1.2.4=py38h2531618_0
49 |   - parameterized=0.8.1=pyhd3deb0d_0
50 |   - pillow=8.2.0=py38he98fc37_0
51 |   - pip=21.0.1=py38h06a4308_0
52 |   - portalocker=1.7.0=py38h578d9bd_1
53 |   - python=3.8.8=hdb3f193_5
54 |   - python-dateutil=2.8.1=pyhd3eb1b0_0
55 |   - python_abi=3.8=1_cp38
56 |   - pytorch=1.7.1=py3.8_cuda10.2.89_cudnn7.6.5_0
57 |   - pytz=2021.1=pyhd3eb1b0_0
58 |   - pyyaml=5.3.1=py38h8df0ef7_1
59 |   - readline=8.1=h27cfd23_0
60 |   - scikit-learn=0.24.1=py38ha9443f7_0
61 |   - scipy=1.6.2=py38h91f5cce_0
62 |   - setuptools=52.0.0=py38h06a4308_0
63 |   - six=1.15.0=py38h06a4308_0
64 |   - sqlite=3.35.4=hdfb4753_0
65 |   - tabulate=0.8.9=pyhd8ed1ab_0
66 |   - termcolor=1.1.0=py_2
67 |   - threadpoolctl=2.1.0=pyh5ca1d4c_0
68 |   - tk=8.6.10=hbc83047_0
69 |   - torchvision=0.8.2=py38_cu102
70 |   - tqdm=4.60.0=pyhd8ed1ab_0
71 |   - typing_extensions=3.7.4.3=pyha847dfd_0
72 |   - vissl=0.1.5=py38
73 |   - wheel=0.36.2=pyhd3eb1b0_0
74 |   - xz=5.2.5=h7b6447c_0
75 |   - yacs=0.1.6=py_0
76 |   - yaml=0.2.5=h516909a_0
77 |   - zlib=1.2.11=h7b6447c_3
78 |   - zstd=1.4.9=haebb681_0
79 |   - pip:
80 |     - absl-py==1.3.0
81 |     - cachetools==5.2.0
82 |     - charset-normalizer==2.1.1
83 |     - filelock==3.8.0
84 |     - google-auth==2.14.1
85 |     - google-auth-oauthlib==0.4.6
86 |     - grpcio==1.50.0
87 |     - huggingface-hub==0.10.1
88 |     - idna==3.4
89 |     - 
importlib-metadata==5.0.0 90 | - markdown==3.4.1 91 | - markupsafe==2.1.1 92 | - oauthlib==3.2.2 93 | - packaging==21.3 94 | - protobuf==3.20.3 95 | - pyasn1==0.4.8 96 | - pyasn1-modules==0.2.8 97 | - pyparsing==3.0.9 98 | - requests==2.28.1 99 | - requests-oauthlib==1.3.1 100 | - rsa==4.9 101 | - tensorboard==2.11.0 102 | - tensorboard-data-server==0.6.1 103 | - tensorboard-plugin-wit==1.8.1 104 | - timm==0.3.2 105 | - urllib3==1.26.12 106 | - werkzeug==2.2.2 107 | - zipp==3.10.0 108 | prefix: /vulcanscratch/shlokm/Ana/envs/vissl 109 | -------------------------------------------------------------------------------- /engine_pretrain.py: -------------------------------------------------------------------------------- 1 | import math 2 | import sys 3 | from typing import Iterable 4 | 5 | import torch 6 | 7 | import util.misc as misc 8 | import util.lr_sched as lr_sched 9 | 10 | 11 | def train_one_epoch(model: torch.nn.Module, 12 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 13 | device: torch.device, epoch: int, loss_scaler, 14 | log_writer=None, 15 | args=None): 16 | model.train(True) 17 | metric_logger = misc.MetricLogger(delimiter=" ") 18 | metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}')) 19 | metric_logger.add_meter('loss_contrastive', misc.SmoothedValue(window_size=1, fmt='{value:.6f}')) 20 | metric_logger.add_meter('loss_noise', misc.SmoothedValue(window_size=1, fmt='{value:.6f}')) 21 | header = 'Epoch: [{}]'.format(epoch) 22 | print_freq = 20 23 | 24 | accum_iter = args.accum_iter 25 | weight_mae = args.weight_mae 26 | weight_simclr = args.weight_simclr 27 | weight_noise = args.weight_noise 28 | 29 | optimizer.zero_grad() 30 | 31 | if log_writer is not None: 32 | print('log_dir: {}'.format(log_writer.log_dir)) 33 | 34 | for data_iter_step, (samples, _) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 35 | 36 | # we use a per iteration (instead of per epoch) lr scheduler 37 | if data_iter_step % accum_iter == 0: 38 | lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args) 39 | 40 | samples = torch.cat([samples[0], samples[1]], dim=0) # contrastive hack 41 | samples = samples.to(device, non_blocking=True) 42 | 43 | with torch.cuda.amp.autocast(): 44 | loss, loss_contrastive, loss_noise, _, _ = model(samples, mask_ratio=args.mask_ratio) 45 | 46 | loss_recon = weight_noise * loss_noise + (1-weight_noise) * loss 47 | loss = weight_mae * loss_recon + weight_simclr * loss_contrastive 48 | loss_value = loss.item() 49 | loss_contrastive_value = loss_contrastive.item() 50 | loss_noise_value = loss_noise.item() 51 | 52 | if not math.isfinite(loss_value): 53 | print("Loss is {}, stopping training".format(loss_value)) 54 | sys.exit(1) 55 | 56 | loss /= accum_iter 57 | loss_scaler(loss, optimizer, parameters=model.parameters(), 58 | update_grad=(data_iter_step + 1) % accum_iter == 0) 59 | if (data_iter_step + 1) % accum_iter == 0: 60 | optimizer.zero_grad() 61 | 62 | torch.cuda.synchronize() 63 | 64 | metric_logger.update(loss=loss_value) 65 | metric_logger.update(loss_contrastive=loss_contrastive_value) 66 | metric_logger.update(loss_noise=loss_noise_value) 67 | 68 | lr = optimizer.param_groups[0]["lr"] 69 | metric_logger.update(lr=lr) 70 | 71 | loss_value_reduce = misc.all_reduce_mean(loss_value) 72 | loss_contrastive_value_reduce = misc.all_reduce_mean(loss_contrastive_value) 73 | if log_writer is not None and (data_iter_step + 1) % accum_iter == 0: 74 | """ We use epoch_1000x as the x-axis 
in tensorboard. 75 | This calibrates different curves when batch size changes. 76 | """ 77 | epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000) 78 | log_writer.add_scalar('train_loss', loss_value_reduce, epoch_1000x) 79 | log_writer.add_scalar('lr', lr, epoch_1000x) 80 | 81 | 82 | # gather the stats from all processes 83 | metric_logger.synchronize_between_processes() 84 | print("Averaged stats:", metric_logger) 85 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} -------------------------------------------------------------------------------- /loss_contrastive.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class SupConLoss(nn.Module): 8 | """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf. 9 | It also supports the unsupervised contrastive loss in SimCLR""" 10 | def __init__(self, temperature=0.1, contrast_mode='all', 11 | base_temperature=0.1): 12 | super(SupConLoss, self).__init__() 13 | self.temperature = temperature 14 | self.contrast_mode = contrast_mode 15 | self.base_temperature = base_temperature 16 | 17 | def forward(self, features, labels=None, mask=None): 18 | """Compute loss for model. If both `labels` and `mask` are None, 19 | it degenerates to SimCLR unsupervised loss: 20 | https://arxiv.org/pdf/2002.05709.pdf 21 | 22 | Args: 23 | features: hidden vector of shape [bsz, n_views, ...]. 24 | labels: ground truth of shape [bsz]. 25 | mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j 26 | has the same class as sample i. Can be asymmetric. 27 | Returns: 28 | A loss scalar. 29 | """ 30 | device = (torch.device('cuda') 31 | if features.is_cuda 32 | else torch.device('cpu')) 33 | 34 | if len(features.shape) < 3: 35 | raise ValueError('`features` needs to be [bsz, n_views, ...],' 36 | 'at least 3 dimensions are required') 37 | if len(features.shape) > 3: 38 | features = features.view(features.shape[0], features.shape[1], -1) 39 | 40 | batch_size = features.shape[0] 41 | if labels is not None and mask is not None: 42 | raise ValueError('Cannot define both `labels` and `mask`') 43 | elif labels is None and mask is None: 44 | mask = torch.eye(batch_size, dtype=torch.float32).to(device) 45 | elif labels is not None: 46 | labels = labels.contiguous().view(-1, 1) 47 | if labels.shape[0] != batch_size: 48 | raise ValueError('Num of labels does not match num of features') 49 | mask = torch.eq(labels, labels.T).float().to(device) 50 | else: 51 | mask = mask.float().to(device) 52 | 53 | contrast_count = features.shape[1] 54 | contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0) 55 | if self.contrast_mode == 'one': 56 | anchor_feature = features[:, 0] 57 | anchor_count = 1 58 | elif self.contrast_mode == 'all': 59 | anchor_feature = contrast_feature 60 | anchor_count = contrast_count 61 | else: 62 | raise ValueError('Unknown mode: {}'.format(self.contrast_mode)) 63 | 64 | # compute logits 65 | anchor_dot_contrast = torch.div( 66 | torch.matmul(anchor_feature, contrast_feature.T), 67 | self.temperature) 68 | # for numerical stability 69 | logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True) 70 | logits = anchor_dot_contrast - logits_max.detach() 71 | # logits = anchor_dot_contrast 72 | 73 | # print(logits.mean()) 74 | 75 | # tile mask 76 | mask = mask.repeat(anchor_count, contrast_count) 77 | # mask-out self-contrast cases 78 | logits_mask = 
torch.scatter(
79 |             torch.ones_like(mask),
80 |             1,
81 |             torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
82 |             0
83 |         )
84 |         mask = mask * logits_mask
85 | 
86 |         # compute log_prob
87 |         exp_logits = torch.exp(logits) * logits_mask
88 |         log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))
89 | 
90 |         # compute mean of log-likelihood over positive
91 |         mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)
92 | 
93 |         # loss
94 |         loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
95 |         loss = loss.view(anchor_count, batch_size).mean()
96 | 
97 |         return loss
--------------------------------------------------------------------------------
/util/pos_embed.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | 
4 | import torch
5 | 
6 | def get_1d_sincos_pos_embed(x: torch.Tensor, dim: int):
7 |     """From: https://github.com/lucidrains/denoising-diffusion-pytorch/blob/main/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py"""
8 |     half_dim = dim // 2
9 |     emb = math.log(10000) / (half_dim - 1)
10 |     emb = torch.exp(torch.arange(half_dim, device=x.device) * -emb)
11 |     emb = x[:, None] * emb[None, :]
12 |     emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
13 |     return emb
14 | 
15 | # --------------------------------------------------------
16 | # 2D sine-cosine position embedding
17 | # References:
18 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
19 | # MoCo v3: https://github.com/facebookresearch/moco-v3
20 | # --------------------------------------------------------
21 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
22 |     """
23 |     grid_size: int of the grid height and width
24 |     return:
25 |     pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
26 |     """
27 |     grid_h = np.arange(grid_size, dtype=np.float32)
28 |     grid_w = np.arange(grid_size, dtype=np.float32)
29 |     grid = np.meshgrid(grid_w, grid_h)  # here w goes first
30 |     grid = np.stack(grid, axis=0)
31 | 
32 |     grid = grid.reshape([2, 1, grid_size, grid_size])
33 |     pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
34 |     if cls_token:
35 |         pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
36 |     return pos_embed
37 | 
38 | 
39 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
40 |     assert embed_dim % 2 == 0
41 | 
42 |     # use half of dimensions to encode grid_h
43 |     emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
44 |     emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)
45 | 
46 |     emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
47 |     return emb
48 | 
49 | 
50 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
51 |     """
52 |     embed_dim: output dimension for each position
53 |     pos: a list of positions to be encoded: size (M,)
54 |     out: (M, D)
55 |     """
56 |     assert embed_dim % 2 == 0
57 |     omega = np.arange(embed_dim // 2, dtype=np.float64)  # np.float alias is removed in NumPy >= 1.24
58 |     omega /= embed_dim / 2.
59 |     omega = 1.
/ 10000**omega # (D/2,) 60 | 61 | pos = pos.reshape(-1) # (M,) 62 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 63 | 64 | emb_sin = np.sin(out) # (M, D/2) 65 | emb_cos = np.cos(out) # (M, D/2) 66 | 67 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 68 | return emb 69 | 70 | 71 | # -------------------------------------------------------- 72 | # Interpolate position embeddings for high-resolution 73 | # References: 74 | # DeiT: https://github.com/facebookresearch/deit 75 | # -------------------------------------------------------- 76 | def interpolate_pos_embed(model, checkpoint_model): 77 | if 'pos_embed' in checkpoint_model: 78 | pos_embed_checkpoint = checkpoint_model['pos_embed'] 79 | embedding_size = pos_embed_checkpoint.shape[-1] 80 | num_patches = model.patch_embed.num_patches 81 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches 82 | # height (== width) for the checkpoint position embedding 83 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 84 | # height (== width) for the new position embedding 85 | new_size = int(num_patches ** 0.5) 86 | # class_token and dist_token are kept unchanged 87 | if orig_size != new_size: 88 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) 89 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 90 | # only the position tokens are interpolated 91 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 92 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 93 | pos_tokens = torch.nn.functional.interpolate( 94 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 95 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 96 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 97 | checkpoint_model['pos_embed'] = new_pos_embed 98 | -------------------------------------------------------------------------------- /submitit_finetune.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import uuid 4 | from pathlib import Path 5 | 6 | import main_finetune as classification 7 | import submitit 8 | 9 | 10 | def parse_args(): 11 | classification_parser = classification.get_args_parser() 12 | parser = argparse.ArgumentParser("Submitit for MAE finetune", parents=[classification_parser]) 13 | parser.add_argument("--ngpus", default=8, type=int, help="Number of gpus to request on each node") 14 | parser.add_argument("--nodes", default=2, type=int, help="Number of nodes to request") 15 | parser.add_argument("--timeout", default=4320, type=int, help="Duration of the job") 16 | parser.add_argument("--job_dir", default="", type=str, help="Job dir. Leave empty for automatic.") 17 | 18 | parser.add_argument("--partition", default="learnfair", type=str, help="Partition where to submit") 19 | parser.add_argument("--use_volta32", action='store_true', help="Request 32G V100 GPUs") 20 | parser.add_argument('--comment', default="", type=str, help="Comment to pass to scheduler") 21 | return parser.parse_args() 22 | 23 | 24 | def get_shared_folder() -> Path: 25 | user = os.getenv("USER") 26 | if Path("/checkpoint/").is_dir(): 27 | p = Path(f"/checkpoint/{user}/experiments") 28 | p.mkdir(exist_ok=True) 29 | return p 30 | raise RuntimeError("No shared folder available") 31 | 32 | 33 | def get_init_file(): 34 | # Init file must not exist, but it's parent dir must exist. 
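# (The returned path is converted to a file:// URI and assigned to args.dist_url,
#  so, assuming standard torch.distributed semantics, all ranks rendezvous through
#  the shared filesystem; a fresh uuid per submission prevents collisions with an
#  earlier job's init file.)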
35 | os.makedirs(str(get_shared_folder()), exist_ok=True) 36 | init_file = get_shared_folder() / f"{uuid.uuid4().hex}_init" 37 | if init_file.exists(): 38 | os.remove(str(init_file)) 39 | return init_file 40 | 41 | 42 | class Trainer(object): 43 | def __init__(self, args): 44 | self.args = args 45 | 46 | def __call__(self): 47 | import main_finetune as classification 48 | 49 | self._setup_gpu_args() 50 | classification.main(self.args) 51 | 52 | def checkpoint(self): 53 | import os 54 | import submitit 55 | 56 | self.args.dist_url = get_init_file().as_uri() 57 | checkpoint_file = os.path.join(self.args.output_dir, "checkpoint.pth") 58 | if os.path.exists(checkpoint_file): 59 | self.args.resume = checkpoint_file 60 | print("Requeuing ", self.args) 61 | empty_trainer = type(self)(self.args) 62 | return submitit.helpers.DelayedSubmission(empty_trainer) 63 | 64 | def _setup_gpu_args(self): 65 | import submitit 66 | from pathlib import Path 67 | 68 | job_env = submitit.JobEnvironment() 69 | self.args.output_dir = Path(str(self.args.output_dir).replace("%j", str(job_env.job_id))) 70 | self.args.log_dir = self.args.output_dir 71 | self.args.gpu = job_env.local_rank 72 | self.args.rank = job_env.global_rank 73 | self.args.world_size = job_env.num_tasks 74 | print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") 75 | 76 | 77 | def main(): 78 | args = parse_args() 79 | if args.job_dir == "": 80 | args.job_dir = get_shared_folder() / "%j" 81 | 82 | # Note that the folder will depend on the job_id, to easily track experiments 83 | executor = submitit.AutoExecutor(folder=args.job_dir, slurm_max_num_timeout=30) 84 | 85 | num_gpus_per_node = args.ngpus 86 | nodes = args.nodes 87 | timeout_min = args.timeout 88 | 89 | partition = args.partition 90 | kwargs = {} 91 | if args.use_volta32: 92 | kwargs['slurm_constraint'] = 'volta32gb' 93 | if args.comment: 94 | kwargs['slurm_comment'] = args.comment 95 | 96 | executor.update_parameters( 97 | mem_gb=40 * num_gpus_per_node, 98 | gpus_per_node=num_gpus_per_node, 99 | tasks_per_node=num_gpus_per_node, # one task per GPU 100 | cpus_per_task=10, 101 | nodes=nodes, 102 | timeout_min=timeout_min, 103 | # Below are cluster dependent parameters 104 | slurm_partition=partition, 105 | slurm_signal_delay_s=120, 106 | **kwargs 107 | ) 108 | 109 | executor.update_parameters(name="mae") 110 | 111 | args.dist_url = get_init_file().as_uri() 112 | args.output_dir = args.job_dir 113 | 114 | trainer = Trainer(args) 115 | job = executor.submit(trainer) 116 | 117 | # print("Submitted job_id:", job.job_id) 118 | print(job.job_id) 119 | 120 | 121 | if __name__ == "__main__": 122 | main() 123 | -------------------------------------------------------------------------------- /submitit_linprobe.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import uuid 4 | from pathlib import Path 5 | 6 | import main_linprobe as classification 7 | import submitit 8 | 9 | 10 | def parse_args(): 11 | classification_parser = classification.get_args_parser() 12 | parser = argparse.ArgumentParser("Submitit for MAE linear probe", parents=[classification_parser]) 13 | parser.add_argument("--ngpus", default=8, type=int, help="Number of gpus to request on each node") 14 | parser.add_argument("--nodes", default=2, type=int, help="Number of nodes to request") 15 | parser.add_argument("--timeout", default=4320, type=int, help="Duration of the job") 16 | parser.add_argument("--job_dir", default="", type=str, 
help="Job dir. Leave empty for automatic.") 17 | 18 | parser.add_argument("--partition", default="learnfair", type=str, help="Partition where to submit") 19 | parser.add_argument("--use_volta32", action='store_true', help="Request 32G V100 GPUs") 20 | parser.add_argument('--comment', default="", type=str, help="Comment to pass to scheduler") 21 | return parser.parse_args() 22 | 23 | 24 | def get_shared_folder() -> Path: 25 | user = os.getenv("USER") 26 | if Path("/checkpoint/").is_dir(): 27 | p = Path(f"/checkpoint/{user}/experiments") 28 | p.mkdir(exist_ok=True) 29 | return p 30 | raise RuntimeError("No shared folder available") 31 | 32 | 33 | def get_init_file(): 34 | # Init file must not exist, but it's parent dir must exist. 35 | os.makedirs(str(get_shared_folder()), exist_ok=True) 36 | init_file = get_shared_folder() / f"{uuid.uuid4().hex}_init" 37 | if init_file.exists(): 38 | os.remove(str(init_file)) 39 | return init_file 40 | 41 | 42 | class Trainer(object): 43 | def __init__(self, args): 44 | self.args = args 45 | 46 | def __call__(self): 47 | import main_linprobe as classification 48 | 49 | self._setup_gpu_args() 50 | classification.main(self.args) 51 | 52 | def checkpoint(self): 53 | import os 54 | import submitit 55 | 56 | self.args.dist_url = get_init_file().as_uri() 57 | checkpoint_file = os.path.join(self.args.output_dir, "checkpoint.pth") 58 | if os.path.exists(checkpoint_file): 59 | self.args.resume = checkpoint_file 60 | print("Requeuing ", self.args) 61 | empty_trainer = type(self)(self.args) 62 | return submitit.helpers.DelayedSubmission(empty_trainer) 63 | 64 | def _setup_gpu_args(self): 65 | import submitit 66 | from pathlib import Path 67 | 68 | job_env = submitit.JobEnvironment() 69 | self.args.output_dir = Path(str(self.args.output_dir).replace("%j", str(job_env.job_id))) 70 | self.args.log_dir = self.args.output_dir 71 | self.args.gpu = job_env.local_rank 72 | self.args.rank = job_env.global_rank 73 | self.args.world_size = job_env.num_tasks 74 | print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") 75 | 76 | 77 | def main(): 78 | args = parse_args() 79 | if args.job_dir == "": 80 | args.job_dir = get_shared_folder() / "%j" 81 | 82 | # Note that the folder will depend on the job_id, to easily track experiments 83 | executor = submitit.AutoExecutor(folder=args.job_dir, slurm_max_num_timeout=30) 84 | 85 | num_gpus_per_node = args.ngpus 86 | nodes = args.nodes 87 | timeout_min = args.timeout 88 | 89 | partition = args.partition 90 | kwargs = {} 91 | if args.use_volta32: 92 | kwargs['slurm_constraint'] = 'volta32gb' 93 | if args.comment: 94 | kwargs['slurm_comment'] = args.comment 95 | 96 | executor.update_parameters( 97 | mem_gb=40 * num_gpus_per_node, 98 | gpus_per_node=num_gpus_per_node, 99 | tasks_per_node=num_gpus_per_node, # one task per GPU 100 | cpus_per_task=10, 101 | nodes=nodes, 102 | timeout_min=timeout_min, 103 | # Below are cluster dependent parameters 104 | slurm_partition=partition, 105 | slurm_signal_delay_s=120, 106 | **kwargs 107 | ) 108 | 109 | executor.update_parameters(name="mae") 110 | 111 | args.dist_url = get_init_file().as_uri() 112 | args.output_dir = args.job_dir 113 | 114 | trainer = Trainer(args) 115 | job = executor.submit(trainer) 116 | 117 | # print("Submitted job_id:", job.job_id) 118 | print(job.job_id) 119 | 120 | 121 | if __name__ == "__main__": 122 | main() 123 | -------------------------------------------------------------------------------- /submitit_pretrain.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import uuid 4 | from pathlib import Path 5 | 6 | import main_pretrain as trainer 7 | import submitit 8 | 9 | 10 | def parse_args(): 11 | trainer_parser = trainer.get_args_parser() 12 | parser = argparse.ArgumentParser("Submitit for MAE pretrain", parents=[trainer_parser]) 13 | parser.add_argument("--ngpus", default=8, type=int, help="Number of gpus to request on each node") 14 | parser.add_argument("--nodes", default=2, type=int, help="Number of nodes to request") 15 | parser.add_argument("--timeout", default=4320, type=int, help="Duration of the job") 16 | parser.add_argument("--job_dir", default="", type=str, help="Job dir. Leave empty for automatic.") 17 | 18 | parser.add_argument("--partition", default="learnfair", type=str, help="Partition where to submit") 19 | parser.add_argument("--use_volta32", action='store_true', help="Request 32G V100 GPUs") 20 | parser.add_argument('--comment', default="", type=str, help="Comment to pass to scheduler") 21 | return parser.parse_args() 22 | 23 | 24 | def get_shared_folder() -> Path: 25 | user = os.getenv("USER") 26 | return "/fs/vulcan-projects/jigsaw_selfsup_shlokm/dv1/mae/checkpoint/" 27 | # if Path("/checkpoint/").is_dir(): 28 | # p = Path(f"/checkpoint/{user}/experiments") 29 | # p.mkdir(exist_ok=True) 30 | # return p 31 | raise RuntimeError("No shared folder available") 32 | 33 | 34 | def get_init_file(): 35 | # Init file must not exist, but it's parent dir must exist. 36 | os.makedirs(str(get_shared_folder()), exist_ok=True) 37 | init_file = get_shared_folder() / f"{uuid.uuid4().hex}_init" 38 | if init_file.exists(): 39 | os.remove(str(init_file)) 40 | return init_file 41 | 42 | 43 | class Trainer(object): 44 | def __init__(self, args): 45 | self.args = args 46 | 47 | def __call__(self): 48 | import main_pretrain as trainer 49 | 50 | self._setup_gpu_args() 51 | trainer.main(self.args) 52 | 53 | def checkpoint(self): 54 | import os 55 | import submitit 56 | 57 | self.args.dist_url = get_init_file().as_uri() 58 | checkpoint_file = os.path.join(self.args.output_dir, "checkpoint.pth") 59 | if os.path.exists(checkpoint_file): 60 | self.args.resume = checkpoint_file 61 | print("Requeuing ", self.args) 62 | empty_trainer = type(self)(self.args) 63 | return submitit.helpers.DelayedSubmission(empty_trainer) 64 | 65 | def _setup_gpu_args(self): 66 | import submitit 67 | from pathlib import Path 68 | 69 | job_env = submitit.JobEnvironment() 70 | self.args.output_dir = Path(str(self.args.output_dir).replace("%j", str(job_env.job_id))) 71 | self.args.log_dir = self.args.output_dir 72 | self.args.gpu = job_env.local_rank 73 | self.args.rank = job_env.global_rank 74 | self.args.world_size = job_env.num_tasks 75 | print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") 76 | 77 | 78 | def main(): 79 | args = parse_args() 80 | if args.job_dir == "": 81 | args.job_dir = get_shared_folder() / "%j" 82 | 83 | # Note that the folder will depend on the job_id, to easily track experiments 84 | executor = submitit.AutoExecutor(folder=args.job_dir, slurm_max_num_timeout=30) 85 | 86 | num_gpus_per_node = args.ngpus 87 | nodes = args.nodes 88 | timeout_min = args.timeout 89 | 90 | partition = args.partition 91 | kwargs = {} 92 | if args.use_volta32: 93 | kwargs['slurm_constraint'] = 'volta32gb' 94 | if args.comment: 95 | kwargs['slurm_comment'] = args.comment 96 | 97 | executor.update_parameters( 98 | mem_gb=40 * 
num_gpus_per_node, 99 | gpus_per_node=num_gpus_per_node, 100 | tasks_per_node=num_gpus_per_node, # one task per GPU 101 | cpus_per_task=10, 102 | nodes=nodes, 103 | timeout_min=timeout_min, # max is 60 * 72 104 | # Below are cluster dependent parameters 105 | slurm_partition=partition, 106 | slurm_signal_delay_s=120, 107 | **kwargs 108 | ) 109 | 110 | executor.update_parameters(name="mae") 111 | 112 | args.dist_url = get_init_file().as_uri() 113 | args.output_dir = args.job_dir 114 | 115 | trainer = Trainer(args) 116 | job = executor.submit(trainer) 117 | 118 | # print("Submitted job_id:", job.job_id) 119 | print(job.job_id) 120 | 121 | 122 | if __name__ == "__main__": 123 | main() 124 | -------------------------------------------------------------------------------- /engine_finetune.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | 12 | import math 13 | import sys 14 | from typing import Iterable, Optional 15 | 16 | import torch 17 | 18 | from timm.data import Mixup 19 | from timm.utils import accuracy 20 | 21 | import util.misc as misc 22 | import util.lr_sched as lr_sched 23 | 24 | 25 | def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module, 26 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 27 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0, 28 | mixup_fn: Optional[Mixup] = None, log_writer=None, 29 | args=None): 30 | model.train(True) 31 | metric_logger = misc.MetricLogger(delimiter=" ") 32 | metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}')) 33 | header = 'Epoch: [{}]'.format(epoch) 34 | print_freq = 20 35 | 36 | accum_iter = args.accum_iter 37 | 38 | optimizer.zero_grad() 39 | 40 | if log_writer is not None: 41 | print('log_dir: {}'.format(log_writer.log_dir)) 42 | 43 | for data_iter_step, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 44 | 45 | # we use a per iteration (instead of per epoch) lr scheduler 46 | if data_iter_step % accum_iter == 0: 47 | lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args) 48 | 49 | samples = samples.to(device, non_blocking=True) 50 | targets = targets.to(device, non_blocking=True) 51 | 52 | if mixup_fn is not None: 53 | samples, targets = mixup_fn(samples, targets) 54 | 55 | with torch.cuda.amp.autocast(): 56 | outputs = model(samples) 57 | loss = criterion(outputs, targets) 58 | 59 | loss_value = loss.item() 60 | if data_iter_step%10==0: 61 | print("Loss is {}".format(loss_value)) 62 | 63 | if not math.isfinite(loss_value): 64 | print("Loss is {}, stopping training".format(loss_value)) 65 | sys.exit(1) 66 | 67 | loss /= accum_iter 68 | loss_scaler(loss, optimizer, clip_grad=max_norm, 69 | parameters=model.parameters(), create_graph=False, 70 | update_grad=(data_iter_step + 1) % accum_iter == 0) 71 | if (data_iter_step + 1) % accum_iter == 0: 72 | optimizer.zero_grad() 73 | 74 | torch.cuda.synchronize() 75 | 76 | metric_logger.update(loss=loss_value) 77 | min_lr = 10. 78 | max_lr = 0. 
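# (With layer-wise lr decay from util/lr_decay.py, each param group carries its own
#  lr_scale, so learning rates differ across groups; the max over groups corresponds
#  to the scale-1.0 group, i.e. the topmost block/head.)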
79 | for group in optimizer.param_groups: 80 | min_lr = min(min_lr, group["lr"]) 81 | max_lr = max(max_lr, group["lr"]) 82 | 83 | metric_logger.update(lr=max_lr) 84 | 85 | loss_value_reduce = misc.all_reduce_mean(loss_value) 86 | if log_writer is not None and (data_iter_step + 1) % accum_iter == 0: 87 | """ We use epoch_1000x as the x-axis in tensorboard. 88 | This calibrates different curves when batch size changes. 89 | """ 90 | epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000) 91 | log_writer.add_scalar('loss', loss_value_reduce, epoch_1000x) 92 | log_writer.add_scalar('lr', max_lr, epoch_1000x) 93 | 94 | # gather the stats from all processes 95 | metric_logger.synchronize_between_processes() 96 | print("Averaged stats:", metric_logger) 97 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 98 | 99 | 100 | @torch.no_grad() 101 | def evaluate(data_loader, model, device): 102 | criterion = torch.nn.CrossEntropyLoss() 103 | 104 | metric_logger = misc.MetricLogger(delimiter=" ") 105 | header = 'Test:' 106 | 107 | # switch to evaluation mode 108 | model.eval() 109 | 110 | for batch in metric_logger.log_every(data_loader, 10, header): 111 | images = batch[0] 112 | target = batch[-1] 113 | images = images.to(device, non_blocking=True) 114 | target = target.to(device, non_blocking=True) 115 | 116 | # compute output 117 | with torch.cuda.amp.autocast(): 118 | output = model(images) 119 | loss = criterion(output, target) 120 | 121 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 122 | 123 | batch_size = images.shape[0] 124 | metric_logger.update(loss=loss.item()) 125 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) 126 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) 127 | # gather the stats from all processes 128 | metric_logger.synchronize_between_processes() 129 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}' 130 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss)) 131 | 132 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} -------------------------------------------------------------------------------- /main_pretrain.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import json 4 | import numpy as np 5 | import os 6 | import time 7 | from pathlib import Path 8 | 9 | import torch 10 | import torch.backends.cudnn as cudnn 11 | from torch.utils.tensorboard import SummaryWriter 12 | import torchvision.transforms as transforms 13 | import torchvision.datasets as datasets 14 | 15 | import timm 16 | 17 | #assert timm.__version__ == "0.3.2" # version check 18 | import timm.optim.optim_factory as optim_factory 19 | 20 | import util.misc as misc 21 | from util.misc import NativeScalerWithGradNormCount as NativeScaler 22 | from util_contrastive import TwoCropTransform 23 | from util_contrastive import GaussianBlur 24 | 25 | import models_mae 26 | 27 | from engine_pretrain import train_one_epoch 28 | 29 | 30 | def get_args_parser(): 31 | parser = argparse.ArgumentParser('CAN pre-training', add_help=False) 32 | parser.add_argument('--batch_size', default=64, type=int, 33 | help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus') 34 | parser.add_argument('--epochs', default=400, type=int) 35 | parser.add_argument('--accum_iter', default=1, type=int, 36 | help='Accumulate gradient iterations (for increasing the 
effective batch size under memory constraints)') 37 | 38 | # Model parameters 39 | parser.add_argument('--model', default='mae_vit_large_patch16', type=str, metavar='MODEL', 40 | help='Name of model to train') 41 | 42 | parser.add_argument('--input_size', default=224, type=int, 43 | help='images input size') 44 | 45 | parser.add_argument('--mask_ratio', default=0.75, type=float, 46 | help='Masking ratio (percentage of removed patches).') 47 | 48 | parser.add_argument('--norm_pix_loss', action='store_true', 49 | help='Use (per-patch) normalized pixels as targets for computing loss') 50 | parser.set_defaults(norm_pix_loss=False) 51 | 52 | parser.add_argument('--weight_mae', default=0.97, type=float, 53 | help='Loss weight of mae (default: 0.97).') 54 | parser.add_argument('--weight_simclr', default=0.03, type=float, 55 | help='Loss weight of simclr (default: 0.03).') 56 | 57 | 58 | parser.add_argument('--noise_loss', action='store_true') 59 | parser.add_argument('--std', default=0.05, type=float, 60 | help='Standard deviation of noise added to loss.') 61 | parser.add_argument('--weight_noise', default=0.3, type=float, 62 | help='Weight allocated to noise loss.') 63 | 64 | 65 | # Optimizer parameters 66 | parser.add_argument('--weight_decay', type=float, default=0.05, 67 | help='weight decay (default: 0.05)') 68 | 69 | parser.add_argument('--lr', type=float, default=None, metavar='LR', 70 | help='learning rate (absolute lr)') 71 | parser.add_argument('--blr', type=float, default=1e-3, metavar='LR', 72 | help='base learning rate: absolute_lr = base_lr * total_batch_size / 256') 73 | parser.add_argument('--min_lr', type=float, default=0., metavar='LR', 74 | help='lower lr bound for cyclic schedulers that hit 0') 75 | 76 | parser.add_argument('--warmup_epochs', type=int, default=40, metavar='N', 77 | help='epochs to warmup LR') 78 | 79 | # Dataset parameters 80 | parser.add_argument('--data_path', default='/data/scratch/joshrob/data/imagenet100/', type=str, 81 | help='dataset path') 82 | 83 | parser.add_argument('--output_dir', default='./output_dir', 84 | help='path where to save, empty for no saving') 85 | parser.add_argument('--log_dir', default='./output_dir', 86 | help='path where to tensorboard log') 87 | parser.add_argument('--device', default='cuda', 88 | help='device to use for training / testing') 89 | parser.add_argument('--seed', default=0, type=int) 90 | parser.add_argument('--resume', default='', 91 | help='resume from checkpoint') 92 | 93 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 94 | help='start epoch') 95 | parser.add_argument('--num_workers', default=10, type=int) 96 | parser.add_argument('--pin_mem', action='store_true', 97 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 98 | parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem') 99 | parser.set_defaults(pin_mem=True) 100 | 101 | # distributed training parameters 102 | parser.add_argument('--world_size', default=1, type=int, 103 | help='number of distributed processes') 104 | parser.add_argument('--local_rank', default=-1, type=int) 105 | parser.add_argument('--dist_on_itp', action='store_true') 106 | parser.add_argument('--dist_url', default='env://', 107 | help='url used to set up distributed training') 108 | 109 | return parser 110 | 111 | 112 | def main(args): 113 | misc.init_distributed_mode(args) 114 | 115 | print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__)))) 116 | print("{}".format(args).replace(', ', ',\n')) 
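# (Worked example of the lr scaling computed below: for the README's CAN command,
#  eff_batch_size = batch_size * accum_iter * num_gpus = 64 * 4 * 4 = 1024, so
#  absolute lr = blr * eff_batch_size / 256 = 2.5e-4 * 1024 / 256 = 1e-3.)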
117 | 118 | device = torch.device(args.device) 119 | 120 | # fix the seed for reproducibility 121 | seed = args.seed + misc.get_rank() 122 | torch.manual_seed(seed) 123 | np.random.seed(seed) 124 | 125 | cudnn.benchmark = True 126 | 127 | # simple augmentation 128 | # transform_train = transforms.Compose([ 129 | # transforms.RandomResizedCrop(args.input_size, scale=(0.2, 1.0), interpolation=3), # 3 is bicubic 130 | # transforms.RandomHorizontalFlip(), 131 | # transforms.ToTensor(), 132 | # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) 133 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 134 | transform_train = transforms.Compose([ 135 | transforms.RandomResizedCrop(size=224, scale=(0.2, 1.)), # hardcoded TODO 136 | transforms.RandomHorizontalFlip(), 137 | transforms.RandomApply([ 138 | transforms.ColorJitter(0.4, 0.4, 0.4, 0.1) 139 | ], p=0.8), 140 | transforms.RandomGrayscale(p=0.2), 141 | transforms.RandomApply([GaussianBlur([.1, 2.])], p=0.5), 142 | transforms.ToTensor(), 143 | normalize, 144 | ]) 145 | dataset_train = datasets.ImageFolder(os.path.join(args.data_path, 'train'), transform=TwoCropTransform(transform_train)) 146 | print(dataset_train) 147 | 148 | if True: # args.distributed: 149 | num_tasks = misc.get_world_size() 150 | global_rank = misc.get_rank() 151 | sampler_train = torch.utils.data.DistributedSampler( 152 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 153 | ) 154 | print("Sampler_train = %s" % str(sampler_train)) 155 | else: 156 | sampler_train = torch.utils.data.RandomSampler(dataset_train) 157 | 158 | if global_rank == 0 and args.log_dir is not None: 159 | os.makedirs(args.log_dir, exist_ok=True) 160 | log_writer = SummaryWriter(log_dir=args.log_dir) 161 | else: 162 | log_writer = None 163 | 164 | data_loader_train = torch.utils.data.DataLoader( 165 | dataset_train, sampler=sampler_train, 166 | batch_size=args.batch_size, 167 | num_workers=args.num_workers, 168 | pin_memory=args.pin_mem, 169 | drop_last=True, 170 | ) 171 | 172 | # define the model 173 | model = models_mae.__dict__[args.model](norm_pix_loss=args.norm_pix_loss, noise_loss=args.noise_loss) 174 | 175 | model.to(device) 176 | 177 | model_without_ddp = model 178 | print("Model = %s" % str(model_without_ddp)) 179 | 180 | eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size() 181 | 182 | if args.lr is None: # only base_lr is specified 183 | args.lr = args.blr * eff_batch_size / 256 184 | 185 | print("base lr: %.2e" % (args.lr * 256 / eff_batch_size)) 186 | print("actual lr: %.2e" % args.lr) 187 | 188 | print("accumulate grad iterations: %d" % args.accum_iter) 189 | print("effective batch size: %d" % eff_batch_size) 190 | 191 | if args.distributed: 192 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True) 193 | model_without_ddp = model.module 194 | 195 | # following timm: set wd as 0 for bias and norm layers 196 | param_groups = optim_factory.add_weight_decay(model_without_ddp, args.weight_decay) 197 | optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95)) 198 | print(optimizer) 199 | loss_scaler = NativeScaler() 200 | 201 | misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler) 202 | 203 | print(f"Start training for {args.epochs} epochs") 204 | start_time = time.time() 205 | for epoch in range(args.start_epoch, args.epochs): 206 | if args.distributed: 207 | 
data_loader_train.sampler.set_epoch(epoch) 208 | train_stats = train_one_epoch( 209 | model, data_loader_train, 210 | optimizer, device, epoch, loss_scaler, 211 | log_writer=log_writer, 212 | args=args 213 | ) 214 | if args.output_dir and (epoch % 1 == 0 or epoch + 1 == args.epochs): 215 | misc.save_model( 216 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, 217 | loss_scaler=loss_scaler, epoch=epoch) 218 | 219 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 220 | 'epoch': epoch,} 221 | 222 | if args.output_dir and misc.is_main_process(): 223 | if log_writer is not None: 224 | log_writer.flush() 225 | with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f: 226 | f.write(json.dumps(log_stats) + "\n") 227 | 228 | total_time = time.time() - start_time 229 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 230 | print('Training time {}'.format(total_time_str)) 231 | 232 | 233 | if __name__ == '__main__': 234 | args = get_args_parser() 235 | args = args.parse_args() 236 | #args.local_rank = os.environ['LOCAL_RANK'] 237 | if args.output_dir: 238 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 239 | main(args) 240 | -------------------------------------------------------------------------------- /util/misc.py: -------------------------------------------------------------------------------- 1 | import builtins 2 | import datetime 3 | import os 4 | import time 5 | from collections import defaultdict, deque 6 | from pathlib import Path 7 | 8 | import torch 9 | import torch.distributed as dist 10 | from torch._six import inf 11 | 12 | 13 | class SmoothedValue(object): 14 | """Track a series of values and provide access to smoothed values over a 15 | window or the global series average. 16 | """ 17 | 18 | def __init__(self, window_size=20, fmt=None): 19 | if fmt is None: 20 | fmt = "{median:.4f} ({global_avg:.4f})" 21 | self.deque = deque(maxlen=window_size) 22 | self.total = 0.0 23 | self.count = 0 24 | self.fmt = fmt 25 | 26 | def update(self, value, n=1): 27 | self.deque.append(value) 28 | self.count += n 29 | self.total += value * n 30 | 31 | def synchronize_between_processes(self): 32 | """ 33 | Warning: does not synchronize the deque! 
34 | """ 35 | if not is_dist_avail_and_initialized(): 36 | return 37 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 38 | dist.barrier() 39 | dist.all_reduce(t) 40 | t = t.tolist() 41 | self.count = int(t[0]) 42 | self.total = t[1] 43 | 44 | @property 45 | def median(self): 46 | d = torch.tensor(list(self.deque)) 47 | return d.median().item() 48 | 49 | @property 50 | def avg(self): 51 | d = torch.tensor(list(self.deque), dtype=torch.float32) 52 | return d.mean().item() 53 | 54 | @property 55 | def global_avg(self): 56 | return self.total / self.count 57 | 58 | @property 59 | def max(self): 60 | return max(self.deque) 61 | 62 | @property 63 | def value(self): 64 | return self.deque[-1] 65 | 66 | def __str__(self): 67 | return self.fmt.format( 68 | median=self.median, 69 | avg=self.avg, 70 | global_avg=self.global_avg, 71 | max=self.max, 72 | value=self.value) 73 | 74 | 75 | class MetricLogger(object): 76 | def __init__(self, delimiter="\t"): 77 | self.meters = defaultdict(SmoothedValue) 78 | self.delimiter = delimiter 79 | 80 | def update(self, **kwargs): 81 | for k, v in kwargs.items(): 82 | if v is None: 83 | continue 84 | if isinstance(v, torch.Tensor): 85 | v = v.item() 86 | assert isinstance(v, (float, int)) 87 | self.meters[k].update(v) 88 | 89 | def __getattr__(self, attr): 90 | if attr in self.meters: 91 | return self.meters[attr] 92 | if attr in self.__dict__: 93 | return self.__dict__[attr] 94 | raise AttributeError("'{}' object has no attribute '{}'".format( 95 | type(self).__name__, attr)) 96 | 97 | def __str__(self): 98 | loss_str = [] 99 | for name, meter in self.meters.items(): 100 | loss_str.append( 101 | "{}: {}".format(name, str(meter)) 102 | ) 103 | return self.delimiter.join(loss_str) 104 | 105 | def synchronize_between_processes(self): 106 | for meter in self.meters.values(): 107 | meter.synchronize_between_processes() 108 | 109 | def add_meter(self, name, meter): 110 | self.meters[name] = meter 111 | 112 | def log_every(self, iterable, print_freq, header=None): 113 | i = 0 114 | if not header: 115 | header = '' 116 | start_time = time.time() 117 | end = time.time() 118 | iter_time = SmoothedValue(fmt='{avg:.4f}') 119 | data_time = SmoothedValue(fmt='{avg:.4f}') 120 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 121 | log_msg = [ 122 | header, 123 | '[{0' + space_fmt + '}/{1}]', 124 | 'eta: {eta}', 125 | '{meters}', 126 | 'time: {time}', 127 | 'data: {data}' 128 | ] 129 | if torch.cuda.is_available(): 130 | log_msg.append('max mem: {memory:.0f}') 131 | log_msg = self.delimiter.join(log_msg) 132 | MB = 1024.0 * 1024.0 133 | for obj in iterable: 134 | data_time.update(time.time() - end) 135 | yield obj 136 | iter_time.update(time.time() - end) 137 | if i % print_freq == 0 or i == len(iterable) - 1: 138 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 139 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 140 | if torch.cuda.is_available(): 141 | print(log_msg.format( 142 | i, len(iterable), eta=eta_string, 143 | meters=str(self), 144 | time=str(iter_time), data=str(data_time), 145 | memory=torch.cuda.max_memory_allocated() / MB)) 146 | else: 147 | print(log_msg.format( 148 | i, len(iterable), eta=eta_string, 149 | meters=str(self), 150 | time=str(iter_time), data=str(data_time))) 151 | i += 1 152 | end = time.time() 153 | total_time = time.time() - start_time 154 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 155 | print('{} Total time: {} ({:.4f} s / it)'.format( 156 | header, 
total_time_str, total_time / len(iterable))) 157 | 158 | 159 | def setup_for_distributed(is_master): 160 | """ 161 | This function disables printing when not in master process 162 | """ 163 | builtin_print = builtins.print 164 | 165 | def print(*args, **kwargs): 166 | force = kwargs.pop('force', False) 167 | force = force or (get_world_size() > 8) 168 | if is_master or force: 169 | now = datetime.datetime.now().time() 170 | builtin_print('[{}] '.format(now), end='') # print with time stamp 171 | builtin_print(*args, **kwargs) 172 | 173 | builtins.print = print 174 | 175 | 176 | def is_dist_avail_and_initialized(): 177 | if not dist.is_available(): 178 | return False 179 | if not dist.is_initialized(): 180 | return False 181 | return True 182 | 183 | 184 | def get_world_size(): 185 | if not is_dist_avail_and_initialized(): 186 | return 1 187 | return dist.get_world_size() 188 | 189 | 190 | def get_rank(): 191 | if not is_dist_avail_and_initialized(): 192 | return 0 193 | return dist.get_rank() 194 | 195 | 196 | def is_main_process(): 197 | return get_rank() == 0 198 | 199 | 200 | def save_on_master(*args, **kwargs): 201 | if is_main_process(): 202 | torch.save(*args, **kwargs) 203 | 204 | 205 | def init_distributed_mode(args): 206 | if args.dist_on_itp: 207 | args.rank = int(os.environ['OMPI_COMM_WORLD_RANK']) 208 | args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) 209 | args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) 210 | args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT']) 211 | os.environ['LOCAL_RANK'] = str(args.gpu) 212 | os.environ['RANK'] = str(args.rank) 213 | os.environ['WORLD_SIZE'] = str(args.world_size) 214 | # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"] 215 | elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 216 | args.rank = int(os.environ["RANK"]) 217 | args.world_size = int(os.environ['WORLD_SIZE']) 218 | args.gpu = int(os.environ['LOCAL_RANK']) 219 | elif 'SLURM_PROCID' in os.environ: 220 | args.rank = int(os.environ['SLURM_PROCID']) 221 | args.gpu = args.rank % torch.cuda.device_count() 222 | else: 223 | print('Not using distributed mode') 224 | setup_for_distributed(is_master=True) # hack 225 | args.distributed = False 226 | return 227 | 228 | args.distributed = True 229 | print(args.gpu) 230 | torch.cuda.set_device(args.gpu) 231 | args.dist_backend = 'nccl' 232 | print('| distributed init (rank {}): {}, gpu {}'.format( 233 | args.rank, args.dist_url, args.gpu), flush=True) 234 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 235 | world_size=args.world_size, rank=args.rank) 236 | torch.distributed.barrier() 237 | setup_for_distributed(args.rank == 0) 238 | 239 | 240 | class NativeScalerWithGradNormCount: 241 | state_dict_key = "amp_scaler" 242 | 243 | def __init__(self): 244 | self._scaler = torch.cuda.amp.GradScaler() 245 | 246 | def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True): 247 | self._scaler.scale(loss).backward(create_graph=create_graph) 248 | if update_grad: 249 | if clip_grad is not None: 250 | assert parameters is not None 251 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place 252 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) 253 | else: 254 | self._scaler.unscale_(optimizer) 255 | norm = get_grad_norm_(parameters) 256 | self._scaler.step(optimizer) 257 | self._scaler.update() 258 | else: 259 | norm = None 260 | 
return norm 261 | 262 | def state_dict(self): 263 | return self._scaler.state_dict() 264 | 265 | def load_state_dict(self, state_dict): 266 | self._scaler.load_state_dict(state_dict) 267 | 268 | 269 | def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: 270 | if isinstance(parameters, torch.Tensor): 271 | parameters = [parameters] 272 | parameters = [p for p in parameters if p.grad is not None] 273 | norm_type = float(norm_type) 274 | if len(parameters) == 0: 275 | return torch.tensor(0.) 276 | device = parameters[0].grad.device 277 | if norm_type == inf: 278 | total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) 279 | else: 280 | total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type) 281 | return total_norm 282 | 283 | 284 | def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler): 285 | output_dir = Path(args.output_dir) 286 | epoch_name = str(epoch) 287 | if loss_scaler is not None: 288 | checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)] 289 | for checkpoint_path in checkpoint_paths: 290 | to_save = { 291 | 'model': model_without_ddp.state_dict(), 292 | 'optimizer': optimizer.state_dict(), 293 | 'epoch': epoch, 294 | 'scaler': loss_scaler.state_dict(), 295 | 'args': args, 296 | } 297 | 298 | save_on_master(to_save, checkpoint_path) 299 | else: 300 | client_state = {'epoch': epoch} 301 | model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-%s" % epoch_name, client_state=client_state) 302 | 303 | 304 | def load_model(args, model_without_ddp, optimizer, loss_scaler): 305 | if args.resume: 306 | if args.resume.startswith('https'): 307 | checkpoint = torch.hub.load_state_dict_from_url( 308 | args.resume, map_location='cpu', check_hash=True) 309 | else: 310 | checkpoint = torch.load(args.resume, map_location='cpu') 311 | model_without_ddp.load_state_dict(checkpoint['model']) 312 | print("Resume checkpoint %s" % args.resume) 313 | if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval): 314 | optimizer.load_state_dict(checkpoint['optimizer']) 315 | args.start_epoch = checkpoint['epoch'] + 1 316 | if 'scaler' in checkpoint: 317 | loss_scaler.load_state_dict(checkpoint['scaler']) 318 | print("With optim & sched!") 319 | 320 | 321 | def all_reduce_mean(x): 322 | world_size = get_world_size() 323 | if world_size > 1: 324 | x_reduce = torch.tensor(x).cuda() 325 | dist.all_reduce(x_reduce) 326 | x_reduce /= world_size 327 | return x_reduce.item() 328 | else: 329 | return x -------------------------------------------------------------------------------- /models_mae.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from timm.models.vision_transformer import PatchEmbed, Block 8 | from loss_contrastive import SupConLoss 9 | 10 | from util.pos_embed import get_2d_sincos_pos_embed, get_1d_sincos_pos_embed 11 | 12 | 13 | class MaskedAutoencoderViT(nn.Module): 14 | """ Masked Autoencoder with VisionTransformer backbone 15 | """ 16 | def __init__(self, img_size=224, patch_size=16, in_chans=3, 17 | embed_dim=1024, depth=24, num_heads=16, 18 | decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, 19 | mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False, 20 | noise_loss=False, std=0.1, pe_dims=128): 21 | super().__init__() 22 | 23 | # 
-------------------------------------------------------------------------- 24 | # MAE encoder specifics 25 | self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim) 26 | num_patches = self.patch_embed.num_patches 27 | 28 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 29 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False) # fixed sin-cos embedding 30 | 31 | self.blocks = nn.ModuleList([ 32 | Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer) 33 | for i in range(depth)]) 34 | self.norm = norm_layer(embed_dim) 35 | 36 | # projection head changes 37 | feat_dim = 128 38 | self.projection_head = nn.Sequential( 39 | nn.Linear(embed_dim, embed_dim), 40 | nn.ReLU(inplace=True), 41 | nn.Linear(embed_dim, feat_dim) 42 | ) 43 | 44 | # noise loss specifics 45 | self.noise_loss = noise_loss 46 | self.std = std 47 | self.pe_dims=pe_dims 48 | self.noise_pe_mlp = nn.Sequential( 49 | nn.Linear(pe_dims, embed_dim), 50 | nn.ReLU(inplace=True), 51 | nn.Linear(embed_dim, embed_dim) 52 | ) 53 | 54 | # -------------------------------------------------------------------------- 55 | 56 | # -------------------------------------------------------------------------- 57 | # MAE decoder specifics 58 | self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True) 59 | 60 | self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim)) 61 | 62 | self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False) # fixed sin-cos embedding 63 | 64 | self.decoder_blocks = nn.ModuleList([ 65 | Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer) 66 | for i in range(decoder_depth)]) 67 | 68 | self.decoder_norm = norm_layer(decoder_embed_dim) 69 | self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True) # decoder to patch 70 | # -------------------------------------------------------------------------- 71 | 72 | self.norm_pix_loss = norm_pix_loss 73 | 74 | self.initialize_weights() 75 | 76 | def initialize_weights(self): 77 | # initialization 78 | # initialize (and freeze) pos_embed by sin-cos embedding 79 | pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True) 80 | self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) 81 | 82 | decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True) 83 | self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0)) 84 | 85 | # initialize patch_embed like nn.Linear (instead of nn.Conv2d) 86 | w = self.patch_embed.proj.weight.data 87 | torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) 88 | 89 | # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.) 
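# (timm's trunc_normal_ truncates at the absolute bounds a=-2., b=2. by default; at std=0.02
# those bounds sit 100 standard deviations from the mean, so the truncation never triggers in
# practice and the draw below behaves like normal_(std=0.02). Likewise, with the default
# img_size=224 and patch_size=16 there are (224 // 16) ** 2 = 196 patches, so the frozen
# pos_embed and decoder_pos_embed initialized above have shape (1, 197, dim) including the
# cls token and keep requires_grad=False throughout training.)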
90 | torch.nn.init.normal_(self.cls_token, std=.02) 91 | torch.nn.init.normal_(self.mask_token, std=.02) 92 | 93 | # initialize nn.Linear and nn.LayerNorm 94 | self.apply(self._init_weights) 95 | 96 | def _init_weights(self, m): 97 | if isinstance(m, nn.Linear): 98 | # we use xavier_uniform following official JAX ViT: 99 | torch.nn.init.xavier_uniform_(m.weight) 100 | if isinstance(m, nn.Linear) and m.bias is not None: 101 | nn.init.constant_(m.bias, 0) 102 | elif isinstance(m, nn.LayerNorm): 103 | nn.init.constant_(m.bias, 0) 104 | nn.init.constant_(m.weight, 1.0) 105 | 106 | def patchify(self, imgs): 107 | """ 108 | imgs: (N, 3, H, W) 109 | x: (N, L, patch_size**2 *3) 110 | """ 111 | p = self.patch_embed.patch_size[0] 112 | assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0 113 | 114 | h = w = imgs.shape[2] // p 115 | x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p)) 116 | x = torch.einsum('nchpwq->nhwpqc', x) 117 | x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3)) 118 | return x 119 | 120 | def unpatchify(self, x): 121 | """ 122 | x: (N, L, patch_size**2 *3) 123 | imgs: (N, 3, H, W) 124 | """ 125 | p = self.patch_embed.patch_size[0] 126 | h = w = int(x.shape[1]**.5) 127 | assert h * w == x.shape[1] 128 | 129 | x = x.reshape(shape=(x.shape[0], h, w, p, p, 3)) 130 | x = torch.einsum('nhwpqc->nchpwq', x) 131 | imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p)) 132 | return imgs 133 | 134 | def random_masking(self, x, mask_ratio): 135 | """ 136 | Perform per-sample random masking by per-sample shuffling. 137 | Per-sample shuffling is done by argsort random noise. 138 | x: [N, L, D], sequence 139 | """ 140 | N, L, D = x.shape # batch, length, dim 141 | len_keep = int(L * (1 - mask_ratio)) 142 | 143 | noise = torch.rand(N, L, device=x.device) # noise in [0, 1] 144 | 145 | # sort noise for each sample 146 | ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove 147 | ids_restore = torch.argsort(ids_shuffle, dim=1) 148 | 149 | # keep the first subset 150 | ids_keep = ids_shuffle[:, :len_keep] 151 | x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) 152 | 153 | # generate the binary mask: 0 is keep, 1 is remove 154 | mask = torch.ones([N, L], device=x.device) 155 | mask[:, :len_keep] = 0 156 | # unshuffle to get the binary mask 157 | mask = torch.gather(mask, dim=1, index=ids_restore) 158 | 159 | return x_masked, mask, ids_restore 160 | 161 | def forward_encoder(self, x, mask_ratio): 162 | # embed patches 163 | x = self.patch_embed(x) 164 | 165 | # add pos embed w/o cls token 166 | x = x + self.pos_embed[:, 1:, :] 167 | 168 | # masking: length -> length * mask_ratio 169 | x, mask, ids_restore = self.random_masking(x, mask_ratio) 170 | 171 | # append cls token 172 | cls_token = self.cls_token + self.pos_embed[:, :1, :] 173 | cls_tokens = cls_token.expand(x.shape[0], -1, -1) 174 | x = torch.cat((cls_tokens, x), dim=1) 175 | 176 | 177 | # apply Transformer blocks 178 | for blk in self.blocks: 179 | x = blk(x) 180 | x = self.norm(x) 181 | 182 | return x, mask, ids_restore 183 | 184 | def forward_decoder(self, x, ids_restore): 185 | # embed tokens 186 | x = self.decoder_embed(x) 187 | 188 | # append mask tokens to sequence 189 | mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1) 190 | x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token 191 | x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle 192 | x = torch.cat([x[:, 
:1, :], x_], dim=1) # append cls token 193 | 194 | # add pos embed 195 | x = x + self.decoder_pos_embed 196 | 197 | # apply Transformer blocks 198 | for blk in self.decoder_blocks: 199 | x = blk(x) 200 | x = self.decoder_norm(x) 201 | 202 | # predictor projection 203 | x = self.decoder_pred(x) 204 | 205 | # remove cls token 206 | x = x[:, 1:, :] 207 | 208 | return x 209 | 210 | def forward_loss(self, imgs, pred, mask, noise=None): 211 | """ 212 | imgs: [N, 3, H, W] 213 | pred: [N, L, p*p*3] 214 | mask: [N, L], 0 is keep, 1 is remove 215 | """ 216 | target = self.patchify(imgs) 217 | if self.norm_pix_loss: 218 | mean = target.mean(dim=-1, keepdim=True) 219 | var = target.var(dim=-1, keepdim=True) 220 | target = (target - mean) / (var + 1.e-6)**.5 221 | 222 | loss = (pred - target) ** 2 223 | loss = loss.mean(dim=-1) # [N, L], mean loss per patch 224 | 225 | loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches 226 | 227 | loss_noise = -1 # sentinel returned when the denoising objective is disabled 228 | if self.noise_loss: 229 | noise = self.patchify(noise) 230 | loss_noise = (pred - noise) ** 2 231 | loss_noise = loss_noise.mean(dim=-1) # [N, L], mean loss per patch 232 | loss_noise = (loss_noise * (1-mask)).sum() / (1-mask).sum() # mean noise-prediction loss on visible (unmasked) patches 233 | 234 | return loss, loss_noise 235 | 236 | def forward(self, imgs, mask_ratio=0.75): 237 | noise = None # defined up front so forward_loss below also works when noise_loss is off 238 | if self.noise_loss: 239 | noise_level = self.std * torch.rand(imgs.shape[0]).to(imgs.device) 240 | noise = noise_level[:, None, None, None] * torch.randn(imgs.shape).to(imgs.device) 241 | imgs += noise # in-place: the reconstruction target below is the noisy image 242 | latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio) 243 | 244 | # Contrastive loss 245 | bsz = int(imgs.shape[0] / 2) # TwoCropTransform stacks two views of each image, so the contrastive batch size is half of imgs.shape[0]; SupConLoss fails with a dim-0 shape error otherwise 246 | 247 | latent_contrastive = latent.mean(dim=1, keepdim=False) 248 | latent_contrastive = self.projection_head(latent_contrastive) 249 | 250 | 251 | 252 | features = F.normalize(latent_contrastive, dim=-1) 253 | f1, f2 = torch.split(features, [bsz, bsz], dim=0) 254 | features = torch.cat([f1.unsqueeze(1), f2.unsqueeze(1)], dim=1) 255 | loss_contrastive = SupConLoss()(features) 256 | 257 | if self.noise_loss: 258 | noise_pe = get_1d_sincos_pos_embed(noise_level, dim=self.pe_dims) 259 | noise_pe = self.noise_pe_mlp(noise_pe) 260 | 261 | noise_pe = torch.cat(latent.shape[1] * [noise_pe.unsqueeze(1)], dim=1) 262 | latent += noise_pe 263 | 264 | # mae features 265 | pred = self.forward_decoder(latent, ids_restore) # [N, L, p*p*3] 266 | loss, loss_noise = self.forward_loss(imgs, pred, mask, noise=noise) 267 | return loss, loss_contrastive, loss_noise, pred, mask 268 | 269 | 270 | def mae_vit_base_patch16_dec512d8b(**kwargs): 271 | model = MaskedAutoencoderViT( 272 | patch_size=16, embed_dim=768, depth=12, num_heads=12, 273 | decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, 274 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 275 | return model 276 | 277 | 278 | def mae_vit_large_patch16_dec512d8b(**kwargs): 279 | model = MaskedAutoencoderViT( 280 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, 281 | decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, 282 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 283 | return model 284 | 285 | 286 | def mae_vit_huge_patch14_dec512d8b(**kwargs): 287 | model = MaskedAutoencoderViT( 288 | patch_size=14, embed_dim=1280, depth=32, num_heads=16, 289 | decoder_embed_dim=512,
decoder_depth=8, decoder_num_heads=16, 290 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 291 | return model 292 | 293 | 294 | # set recommended archs 295 | mae_vit_base_patch16 = mae_vit_base_patch16_dec512d8b # decoder: 512 dim, 8 blocks 296 | mae_vit_large_patch16 = mae_vit_large_patch16_dec512d8b # decoder: 512 dim, 8 blocks 297 | mae_vit_huge_patch14 = mae_vit_huge_patch14_dec512d8b # decoder: 512 dim, 8 blocks 298 | -------------------------------------------------------------------------------- /main_linprobe.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import json 4 | import numpy as np 5 | import os 6 | import time 7 | from pathlib import Path 8 | 9 | import torch 10 | import torch.backends.cudnn as cudnn 11 | from torch.utils.tensorboard import SummaryWriter 12 | import torchvision.transforms as transforms 13 | import torchvision.datasets as datasets 14 | 15 | import timm 16 | 17 | assert timm.__version__ == "0.3.2" # version check 18 | from timm.models.layers import trunc_normal_ 19 | 20 | import util.misc as misc 21 | from util.pos_embed import interpolate_pos_embed 22 | from util.misc import NativeScalerWithGradNormCount as NativeScaler 23 | from util.lars import LARS 24 | from util.crop import RandomResizedCrop 25 | 26 | import models_vit 27 | 28 | from engine_finetune import train_one_epoch, evaluate 29 | 30 | 31 | def get_args_parser(): 32 | parser = argparse.ArgumentParser('CAN linear probing for image classification', add_help=False) 33 | parser.add_argument('--batch_size', default=512, type=int, 34 | help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus') 35 | parser.add_argument('--epochs', default=90, type=int) 36 | parser.add_argument('--accum_iter', default=1, type=int, 37 | help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)') 38 | 39 | # Model parameters 40 | parser.add_argument('--model', default='vit_large_patch16', type=str, metavar='MODEL', 41 | help='Name of model to train') 42 | 43 | # Optimizer parameters 44 | parser.add_argument('--weight_decay', type=float, default=0, 45 | help='weight decay (default: 0 for linear probe following MoCo v1)') 46 | 47 | parser.add_argument('--lr', type=float, default=None, metavar='LR', 48 | help='learning rate (absolute lr)') 49 | parser.add_argument('--blr', type=float, default=0.1, metavar='LR', 50 | help='base learning rate: absolute_lr = base_lr * total_batch_size / 256') 51 | 52 | parser.add_argument('--min_lr', type=float, default=0., metavar='LR', 53 | help='lower lr bound for cyclic schedulers that hit 0') 54 | 55 | parser.add_argument('--warmup_epochs', type=int, default=10, metavar='N', 56 | help='epochs to warmup LR') 57 | 58 | # * Finetuning params 59 | parser.add_argument('--finetune', default='', 60 | help='finetune from checkpoint') 61 | parser.add_argument('--global_pool', action='store_true') 62 | parser.set_defaults(global_pool=False) 63 | parser.add_argument('--cls_token', action='store_false', dest='global_pool', 64 | help='Use class token instead of global pool for classification') 65 | 66 | # Dataset parameters 67 | parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str, 68 | help='dataset path') 69 | parser.add_argument('--nb_classes', default=1000, type=int, 70 | help='number of the classification types') 71 | 72 | parser.add_argument('--output_dir', default='./output_dir', 73 | 
help='path where to save checkpoints; empty for no saving') 74 | parser.add_argument('--log_dir', default='./output_dir', 75 | help='path for TensorBoard logs') 76 | parser.add_argument('--device', default='cuda', 77 | help='device to use for training / testing') 78 | parser.add_argument('--seed', default=0, type=int) 79 | parser.add_argument('--resume', default='', 80 | help='resume from checkpoint') 81 | 82 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 83 | help='start epoch') 84 | parser.add_argument('--eval', action='store_true', 85 | help='Perform evaluation only') 86 | parser.add_argument('--dist_eval', action='store_true', default=False, 87 | help='Enable distributed evaluation (recommended during training for faster monitoring)') 88 | parser.add_argument('--num_workers', default=10, type=int) 89 | parser.add_argument('--pin_mem', action='store_true', 90 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 91 | parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem') 92 | parser.set_defaults(pin_mem=True) 93 | 94 | # distributed training parameters 95 | parser.add_argument('--world_size', default=1, type=int, 96 | help='number of distributed processes') 97 | parser.add_argument('--local_rank', default=-1, type=int) 98 | parser.add_argument('--dist_on_itp', action='store_true') 99 | parser.add_argument('--dist_url', default='env://', 100 | help='url used to set up distributed training') 101 | 102 | return parser 103 | 104 | 105 | def main(args): 106 | misc.init_distributed_mode(args) 107 | 108 | print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__)))) 109 | print("{}".format(args).replace(', ', ',\n')) 110 | 111 | device = torch.device(args.device) 112 | 113 | # fix the seed for reproducibility 114 | seed = args.seed + misc.get_rank() 115 | torch.manual_seed(seed) 116 | np.random.seed(seed) 117 | 118 | cudnn.benchmark = True 119 | 120 | # linear probe: weak augmentation 121 | transform_train = transforms.Compose([ 122 | RandomResizedCrop(224, interpolation=3), 123 | transforms.RandomHorizontalFlip(), 124 | transforms.ToTensor(), 125 | # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) 126 | transforms.Normalize(mean=[0.485*255, 0.456*255, 0.406*255], std=[0.229*255, 0.224*255, 0.225*255])]) 127 | transform_val = transforms.Compose([ 128 | transforms.Resize(256, interpolation=3), 129 | transforms.CenterCrop(224), 130 | transforms.ToTensor(), 131 | # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) 132 | transforms.Normalize(mean=[0.485 * 255, 0.456 * 255, 0.406 * 255], 133 | std=[0.229 * 255, 0.224 * 255, 0.225 * 255])]) 134 | dataset_train = datasets.ImageFolder(os.path.join(args.data_path, 'train'), transform=transform_train) 135 | dataset_val = datasets.ImageFolder(os.path.join(args.data_path, 'val'), transform=transform_val) 136 | print(dataset_train) 137 | print(dataset_val) 138 | 139 | if True: # args.distributed: 140 | num_tasks = misc.get_world_size() 141 | global_rank = misc.get_rank() 142 | sampler_train = torch.utils.data.DistributedSampler( 143 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 144 | ) 145 | print("Sampler_train = %s" % str(sampler_train)) 146 | if args.dist_eval: 147 | if len(dataset_val) % num_tasks != 0: 148 | print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. 
' 149 | 'This will slightly alter validation results as extra duplicate entries are added to achieve ' 150 | 'equal num of samples per-process.') 151 | sampler_val = torch.utils.data.DistributedSampler( 152 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=True) # shuffle=True to reduce monitor bias 153 | else: 154 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 155 | else: 156 | sampler_train = torch.utils.data.RandomSampler(dataset_train) 157 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 158 | 159 | if global_rank == 0 and args.log_dir is not None and not args.eval: 160 | os.makedirs(args.log_dir, exist_ok=True) 161 | log_writer = SummaryWriter(log_dir=args.log_dir) 162 | else: 163 | log_writer = None 164 | 165 | data_loader_train = torch.utils.data.DataLoader( 166 | dataset_train, sampler=sampler_train, 167 | batch_size=args.batch_size, 168 | num_workers=args.num_workers, 169 | pin_memory=args.pin_mem, 170 | drop_last=True, 171 | ) 172 | 173 | data_loader_val = torch.utils.data.DataLoader( 174 | dataset_val, sampler=sampler_val, 175 | batch_size=args.batch_size, 176 | num_workers=args.num_workers, 177 | pin_memory=args.pin_mem, 178 | drop_last=False 179 | ) 180 | 181 | model = models_vit.__dict__[args.model]( 182 | num_classes=args.nb_classes, 183 | global_pool=args.global_pool, 184 | ) 185 | 186 | if args.finetune and not args.eval: 187 | checkpoint = torch.load(args.finetune, map_location='cpu') 188 | # import pdb 189 | # pdb.set_trace() 190 | 191 | print("Load pre-trained checkpoint from: %s" % args.finetune) 192 | checkpoint_model = checkpoint['model'] 193 | state_dict = model.state_dict() 194 | # print(checkpoint_model.keys()) 195 | for k in ['head.weight', 'head.bias']: 196 | if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape: 197 | print(f"Removing key {k} from pretrained checkpoint") 198 | del checkpoint_model[k] 199 | # if k.startwith('patch_embed'): 200 | # print(k) 201 | 202 | # interpolate position embedding 203 | interpolate_pos_embed(model, checkpoint_model) 204 | 205 | # load pre-trained model 206 | msg = model.load_state_dict(checkpoint_model, strict=False) 207 | print(msg) 208 | 209 | if args.global_pool: 210 | assert set(msg.missing_keys) == {'head.weight', 'head.bias', 'fc_norm.weight', 'fc_norm.bias'} 211 | else: 212 | assert set(msg.missing_keys) == {'head.weight', 'head.bias'} 213 | 214 | # manually initialize fc layer: following MoCo v3 215 | trunc_normal_(model.head.weight, std=0.01) 216 | 217 | # for linear prob only 218 | # hack: revise model's head with BN 219 | model.head = torch.nn.Sequential(torch.nn.BatchNorm1d(model.head.in_features, affine=False, eps=1e-6), model.head) 220 | # freeze all but the head 221 | for _, p in model.named_parameters(): 222 | p.requires_grad = False 223 | for _, p in model.head.named_parameters(): 224 | p.requires_grad = True 225 | 226 | model.to(device) 227 | 228 | model_without_ddp = model 229 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 230 | 231 | print("Model = %s" % str(model_without_ddp)) 232 | print('number of params (M): %.2f' % (n_parameters / 1.e6)) 233 | 234 | eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size() 235 | 236 | if args.lr is None: # only base_lr is specified 237 | args.lr = args.blr * eff_batch_size / 256 238 | 239 | print("base lr: %.2e" % (args.lr * 256 / eff_batch_size)) 240 | print("actual lr: %.2e" % args.lr) 241 | 242 | print("accumulate grad iterations: %d" % 
args.accum_iter) 243 | print("effective batch size: %d" % eff_batch_size) 244 | 245 | if args.distributed: 246 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 247 | model_without_ddp = model.module 248 | 249 | optimizer = LARS(model_without_ddp.head.parameters(), lr=args.lr, weight_decay=args.weight_decay) 250 | print(optimizer) 251 | loss_scaler = NativeScaler() 252 | 253 | criterion = torch.nn.CrossEntropyLoss() 254 | 255 | print("criterion = %s" % str(criterion)) 256 | 257 | misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler) 258 | 259 | if args.eval: 260 | test_stats = evaluate(data_loader_val, model, device) 261 | print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") 262 | exit(0) 263 | 264 | print(f"Start training for {args.epochs} epochs") 265 | start_time = time.time() 266 | max_accuracy = 0.0 267 | for epoch in range(args.start_epoch, args.epochs): 268 | if args.distributed: 269 | data_loader_train.sampler.set_epoch(epoch) 270 | train_stats = train_one_epoch( 271 | model, criterion, data_loader_train, 272 | optimizer, device, epoch, loss_scaler, 273 | max_norm=None, 274 | log_writer=log_writer, 275 | args=args 276 | ) 277 | if args.output_dir: 278 | misc.save_model( 279 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, 280 | loss_scaler=loss_scaler, epoch=epoch) 281 | 282 | test_stats = evaluate(data_loader_val, model, device) 283 | print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") 284 | max_accuracy = max(max_accuracy, test_stats["acc1"]) 285 | print(f'Max accuracy: {max_accuracy:.2f}%') 286 | 287 | if log_writer is not None: 288 | log_writer.add_scalar('perf/test_acc1', test_stats['acc1'], epoch) 289 | log_writer.add_scalar('perf/test_acc5', test_stats['acc5'], epoch) 290 | log_writer.add_scalar('perf/test_loss', test_stats['loss'], epoch) 291 | 292 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 293 | **{f'test_{k}': v for k, v in test_stats.items()}, 294 | 'epoch': epoch, 295 | 'n_parameters': n_parameters} 296 | 297 | if args.output_dir and misc.is_main_process(): 298 | if log_writer is not None: 299 | log_writer.flush() 300 | with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f: 301 | f.write(json.dumps(log_stats) + "\n") 302 | 303 | total_time = time.time() - start_time 304 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 305 | print('Training time {}'.format(total_time_str)) 306 | 307 | 308 | if __name__ == '__main__': 309 | args = get_args_parser() 310 | args = args.parse_args() 311 | if args.output_dir: 312 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 313 | main(args) 314 | -------------------------------------------------------------------------------- /main_finetune.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import json 4 | import numpy as np 5 | import os 6 | import time 7 | from pathlib import Path 8 | 9 | import torch 10 | import torch.backends.cudnn as cudnn 11 | from torch.utils.tensorboard import SummaryWriter 12 | 13 | import timm 14 | 15 | assert timm.__version__ == "0.3.2" # version check 16 | from timm.models.layers import trunc_normal_ 17 | from timm.data.mixup import Mixup 18 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy 19 | 20 | import 
util.lr_decay as lrd 21 | import util.misc as misc 22 | from util.datasets import build_dataset 23 | from util.pos_embed import interpolate_pos_embed 24 | from util.misc import NativeScalerWithGradNormCount as NativeScaler 25 | 26 | import models_vit 27 | 28 | from engine_finetune import train_one_epoch, evaluate 29 | 30 | 31 | def get_args_parser(): 32 | parser = argparse.ArgumentParser('CAN fine-tuning for image classification', add_help=False) 33 | parser.add_argument('--batch_size', default=64, type=int, 34 | help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus') 35 | parser.add_argument('--epochs', default=50, type=int) 36 | parser.add_argument('--accum_iter', default=1, type=int, 37 | help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)') 38 | 39 | # Model parameters 40 | parser.add_argument('--model', default='vit_large_patch16', type=str, metavar='MODEL', 41 | help='Name of model to train') 42 | 43 | parser.add_argument('--input_size', default=224, type=int, 44 | help='images input size') 45 | 46 | parser.add_argument('--drop_path', type=float, default=0.1, metavar='PCT', 47 | help='Drop path rate (default: 0.1)') 48 | 49 | # Optimizer parameters 50 | parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM', 51 | help='Clip gradient norm (default: None, no clipping)') 52 | parser.add_argument('--weight_decay', type=float, default=0.05, 53 | help='weight decay (default: 0.05)') 54 | 55 | parser.add_argument('--lr', type=float, default=None, metavar='LR', 56 | help='learning rate (absolute lr)') 57 | parser.add_argument('--blr', type=float, default=1e-3, metavar='LR', 58 | help='base learning rate: absolute_lr = base_lr * total_batch_size / 256') 59 | parser.add_argument('--layer_decay', type=float, default=0.75, 60 | help='layer-wise lr decay from ELECTRA/BEiT') 61 | 62 | parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR', 63 | help='lower lr bound for cyclic schedulers that hit 0') 64 | 65 | parser.add_argument('--warmup_epochs', type=int, default=5, metavar='N', 66 | help='epochs to warmup LR') 67 | 68 | # Augmentation parameters 69 | parser.add_argument('--color_jitter', type=float, default=None, metavar='PCT', 70 | help='Color jitter factor (enabled only when not using Auto/RandAug)') 71 | parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME', 72 | help='Use AutoAugment policy. "v0" or "original". 
" + "(default: rand-m9-mstd0.5-inc1)'), 73 | parser.add_argument('--smoothing', type=float, default=0.1, 74 | help='Label smoothing (default: 0.1)') 75 | 76 | # * Random Erase params 77 | parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT', 78 | help='Random erase prob (default: 0.25)') 79 | parser.add_argument('--remode', type=str, default='pixel', 80 | help='Random erase mode (default: "pixel")') 81 | parser.add_argument('--recount', type=int, default=1, 82 | help='Random erase count (default: 1)') 83 | parser.add_argument('--resplit', action='store_true', default=False, 84 | help='Do not random erase first (clean) augmentation split') 85 | 86 | # * Mixup params 87 | parser.add_argument('--mixup', type=float, default=0, 88 | help='mixup alpha, mixup enabled if > 0.') 89 | parser.add_argument('--cutmix', type=float, default=0, 90 | help='cutmix alpha, cutmix enabled if > 0.') 91 | parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None, 92 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)') 93 | parser.add_argument('--mixup_prob', type=float, default=1.0, 94 | help='Probability of performing mixup or cutmix when either/both is enabled') 95 | parser.add_argument('--mixup_switch_prob', type=float, default=0.5, 96 | help='Probability of switching to cutmix when both mixup and cutmix enabled') 97 | parser.add_argument('--mixup_mode', type=str, default='batch', 98 | help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"') 99 | 100 | # * Finetuning params 101 | parser.add_argument('--finetune', default='', 102 | help='finetune from checkpoint') 103 | parser.add_argument('--global_pool', action='store_true') 104 | parser.set_defaults(global_pool=True) 105 | parser.add_argument('--cls_token', action='store_false', dest='global_pool', 106 | help='Use class token instead of global pool for classification') 107 | 108 | # Dataset parameters 109 | parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str, 110 | help='dataset path') 111 | parser.add_argument('--nb_classes', default=1000, type=int, 112 | help='number of the classification types') 113 | 114 | parser.add_argument('--output_dir', default='./output_dir', 115 | help='path where to save, empty for no saving') 116 | parser.add_argument('--log_dir', default='./output_dir', 117 | help='path where to tensorboard log') 118 | parser.add_argument('--device', default='cuda', 119 | help='device to use for training / testing') 120 | parser.add_argument('--seed', default=0, type=int) 121 | parser.add_argument('--resume', default='', 122 | help='resume from checkpoint') 123 | 124 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 125 | help='start epoch') 126 | parser.add_argument('--eval', action='store_true', 127 | help='Perform evaluation only') 128 | parser.add_argument('--dist_eval', action='store_true', default=False, 129 | help='Enabling distributed evaluation (recommended during training for faster monitor') 130 | parser.add_argument('--num_workers', default=10, type=int) 131 | parser.add_argument('--pin_mem', action='store_true', 132 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 133 | parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem') 134 | parser.set_defaults(pin_mem=True) 135 | 136 | # distributed training parameters 137 | parser.add_argument('--world_size', default=1, type=int, 138 | help='number of distributed processes') 139 | 
parser.add_argument('--local_rank', default=-1, type=int) 140 | parser.add_argument('--dist_on_itp', action='store_true') 141 | parser.add_argument('--dist_url', default='env://', 142 | help='url used to set up distributed training') 143 | 144 | return parser 145 | 146 | 147 | def main(args): 148 | misc.init_distributed_mode(args) 149 | 150 | print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__)))) 151 | print("{}".format(args).replace(', ', ',\n')) 152 | 153 | device = torch.device(args.device) 154 | 155 | # fix the seed for reproducibility 156 | seed = args.seed + misc.get_rank() 157 | torch.manual_seed(seed) 158 | np.random.seed(seed) 159 | 160 | cudnn.benchmark = True 161 | 162 | dataset_train = build_dataset(is_train=True, args=args) 163 | dataset_val = build_dataset(is_train=False, args=args) 164 | 165 | if True: # args.distributed: 166 | num_tasks = misc.get_world_size() 167 | global_rank = misc.get_rank() 168 | sampler_train = torch.utils.data.DistributedSampler( 169 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 170 | ) 171 | print("Sampler_train = %s" % str(sampler_train)) 172 | if args.dist_eval: 173 | if len(dataset_val) % num_tasks != 0: 174 | print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. ' 175 | 'This will slightly alter validation results as extra duplicate entries are added to achieve ' 176 | 'equal num of samples per-process.') 177 | sampler_val = torch.utils.data.DistributedSampler( 178 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=True) # shuffle=True to reduce monitor bias 179 | else: 180 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 181 | else: 182 | sampler_train = torch.utils.data.RandomSampler(dataset_train) 183 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 184 | 185 | if global_rank == 0 and args.log_dir is not None and not args.eval: 186 | os.makedirs(args.log_dir, exist_ok=True) 187 | log_writer = SummaryWriter(log_dir=args.log_dir) 188 | else: 189 | log_writer = None 190 | 191 | data_loader_train = torch.utils.data.DataLoader( 192 | dataset_train, sampler=sampler_train, 193 | batch_size=args.batch_size, 194 | num_workers=args.num_workers, 195 | pin_memory=args.pin_mem, 196 | drop_last=True, 197 | ) 198 | 199 | data_loader_val = torch.utils.data.DataLoader( 200 | dataset_val, sampler=sampler_val, 201 | batch_size=args.batch_size, 202 | num_workers=args.num_workers, 203 | pin_memory=args.pin_mem, 204 | drop_last=False 205 | ) 206 | 207 | mixup_fn = None 208 | mixup_active = args.mixup > 0 or args.cutmix > 0. 
or args.cutmix_minmax is not None 209 | if mixup_active: 210 | print("Mixup is activated!") 211 | mixup_fn = Mixup( 212 | mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, 213 | prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, 214 | label_smoothing=args.smoothing, num_classes=args.nb_classes) 215 | 216 | model = models_vit.__dict__[args.model]( 217 | num_classes=args.nb_classes, 218 | drop_path_rate=args.drop_path, 219 | global_pool=args.global_pool, 220 | ) 221 | 222 | if args.finetune and not args.eval: 223 | checkpoint = torch.load(args.finetune, map_location='cpu') 224 | 225 | print("Load pre-trained checkpoint from: %s" % args.finetune) 226 | checkpoint_model = checkpoint['model'] 227 | state_dict = model.state_dict() 228 | for k in ['head.weight', 'head.bias']: 229 | if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape: 230 | print(f"Removing key {k} from pretrained checkpoint") 231 | del checkpoint_model[k] 232 | 233 | # interpolate position embedding 234 | interpolate_pos_embed(model, checkpoint_model) 235 | 236 | # load pre-trained model 237 | msg = model.load_state_dict(checkpoint_model, strict=False) 238 | print(msg) 239 | 240 | if args.global_pool: 241 | assert set(msg.missing_keys) == {'head.weight', 'head.bias', 'fc_norm.weight', 'fc_norm.bias'} 242 | else: 243 | assert set(msg.missing_keys) == {'head.weight', 'head.bias'} 244 | 245 | # manually initialize fc layer 246 | trunc_normal_(model.head.weight, std=2e-5) 247 | 248 | model.to(device) 249 | 250 | model_without_ddp = model 251 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 252 | 253 | print("Model = %s" % str(model_without_ddp)) 254 | print('number of params (M): %.2f' % (n_parameters / 1.e6)) 255 | 256 | eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size() 257 | 258 | if args.lr is None: # only base_lr is specified 259 | args.lr = args.blr * eff_batch_size / 256 260 | 261 | print("base lr: %.2e" % (args.lr * 256 / eff_batch_size)) 262 | print("actual lr: %.2e" % args.lr) 263 | 264 | print("accumulate grad iterations: %d" % args.accum_iter) 265 | print("effective batch size: %d" % eff_batch_size) 266 | 267 | if args.distributed: 268 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 269 | model_without_ddp = model.module 270 | 271 | # build optimizer with layer-wise lr decay (lrd) 272 | param_groups = lrd.param_groups_lrd(model_without_ddp, args.weight_decay, 273 | no_weight_decay_list=model_without_ddp.no_weight_decay(), 274 | layer_decay=args.layer_decay 275 | ) 276 | optimizer = torch.optim.AdamW(param_groups, lr=args.lr) 277 | loss_scaler = NativeScaler() 278 | 279 | if mixup_fn is not None: 280 | # smoothing is handled with mixup label transform 281 | criterion = SoftTargetCrossEntropy() 282 | elif args.smoothing > 0.: 283 | criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing) 284 | else: 285 | criterion = torch.nn.CrossEntropyLoss() 286 | 287 | print("criterion = %s" % str(criterion)) 288 | 289 | misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler) 290 | 291 | if args.eval: 292 | test_stats = evaluate(data_loader_val, model, device) 293 | print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") 294 | exit(0) 295 | 296 | print(f"Start training for {args.epochs} epochs") 297 | start_time = time.time() 298 | max_accuracy = 0.0 299 | for 
epoch in range(args.start_epoch, args.epochs): 300 | if args.distributed: 301 | data_loader_train.sampler.set_epoch(epoch) 302 | train_stats = train_one_epoch( 303 | model, criterion, data_loader_train, 304 | optimizer, device, epoch, loss_scaler, 305 | args.clip_grad, mixup_fn, 306 | log_writer=log_writer, 307 | args=args 308 | ) 309 | if args.output_dir: 310 | misc.save_model( 311 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, 312 | loss_scaler=loss_scaler, epoch=epoch) 313 | 314 | test_stats = evaluate(data_loader_val, model, device) 315 | print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") 316 | max_accuracy = max(max_accuracy, test_stats["acc1"]) 317 | print(f'Max accuracy: {max_accuracy:.2f}%') 318 | 319 | if log_writer is not None: 320 | log_writer.add_scalar('perf/test_acc1', test_stats['acc1'], epoch) 321 | log_writer.add_scalar('perf/test_acc5', test_stats['acc5'], epoch) 322 | log_writer.add_scalar('perf/test_loss', test_stats['loss'], epoch) 323 | 324 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 325 | **{f'test_{k}': v for k, v in test_stats.items()}, 326 | 'epoch': epoch, 327 | 'n_parameters': n_parameters} 328 | 329 | if args.output_dir and misc.is_main_process(): 330 | if log_writer is not None: 331 | log_writer.flush() 332 | with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f: 333 | f.write(json.dumps(log_stats) + "\n") 334 | 335 | total_time = time.time() - start_time 336 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 337 | print('Training time {}'.format(total_time_str)) 338 | 339 | 340 | if __name__ == '__main__': 341 | args = get_args_parser() 342 | args = args.parse_args() 343 | if args.output_dir: 344 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 345 | main(args) 346 | --------------------------------------------------------------------------------
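A minimal smoke test for the pretraining model defined in models_mae.py above (a sketch, not part of the repository: it assumes the repo root is on PYTHONPATH, the pinned timm==0.3.2 environment from the version asserts above, and that SupConLoss in loss_contrastive.py follows the reference SupCon implementation). The leading batch dimension must be even, since forward() splits it in half for the contrastive loss, mirroring the two views produced by TwoCropTransform:

    import torch
    import models_mae

    views = torch.randn(8, 3, 224, 224)  # 4 images x 2 augmented views each
    model = models_mae.mae_vit_base_patch16(norm_pix_loss=False, noise_loss=False)
    loss, loss_contrastive, loss_noise, pred, mask = model(views, mask_ratio=0.75)
    print(pred.shape)  # torch.Size([8, 196, 768]): 16*16*3 predicted pixels per patch
    print(mask.shape)  # torch.Size([8, 196]): 0 = visible patch, 1 = masked patch
    print(loss_noise)  # -1 sentinel while the denoising objective is disabled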