├── can.png
├── util
│   ├── lr_sched.py
│   ├── crop.py
│   ├── lars.py
│   ├── datasets.py
│   ├── lr_decay.py
│   ├── pos_embed.py
│   └── misc.py
├── util_contrastive.py
├── models_vit.py
├── README.md
├── can.yml
├── engine_pretrain.py
├── loss_contrastive.py
├── submitit_finetune.py
├── submitit_linprobe.py
├── submitit_pretrain.py
├── engine_finetune.py
├── main_pretrain.py
├── models_mae.py
├── main_linprobe.py
└── main_finetune.py
/can.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shlokk/mae-contrastive/HEAD/can.png
--------------------------------------------------------------------------------
/util/lr_sched.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | def adjust_learning_rate(optimizer, epoch, args):
4 | """Decay the learning rate with half-cycle cosine after warmup"""
5 | if epoch < args.warmup_epochs:
6 | lr = args.lr * epoch / args.warmup_epochs
7 | else:
8 | lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \
9 | (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs)))
10 | for param_group in optimizer.param_groups:
11 | if "lr_scale" in param_group:
12 | param_group["lr"] = lr * param_group["lr_scale"]
13 | else:
14 | param_group["lr"] = lr
15 | return lr
16 |
--------------------------------------------------------------------------------
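A minimal sketch of how this schedule behaves, assuming a bare namespace with the fields the function reads (`lr`, `min_lr`, `epochs`, `warmup_epochs`); the fractional epochs mirror the per-iteration calls made in `engine_pretrain.py`:

```python
# Hypothetical sketch: trace the warmup + half-cycle cosine schedule.
from types import SimpleNamespace

import torch

from util.lr_sched import adjust_learning_rate

args = SimpleNamespace(lr=1e-3, min_lr=0.0, epochs=50, warmup_epochs=5)
opt = torch.optim.SGD([torch.zeros(1, requires_grad=True)], lr=args.lr)

for epoch in [0.0, 2.5, 5.0, 27.5, 50.0]:
    lr = adjust_learning_rate(opt, epoch, args)
    print(f"epoch {epoch:5.1f} -> lr {lr:.2e}")
# Linear warmup to lr at epoch 5, then cosine decay toward min_lr at epoch 50.
```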
/util_contrastive.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import math
4 | import numpy as np
5 | import torch
6 | import torch.optim as optim
7 | from PIL import ImageFilter
8 | import random
9 |
10 |
11 | class TwoCropTransform:
12 | """Create two crops of the same image"""
13 | def __init__(self, transform):
14 | self.transform = transform
15 |
16 | def __call__(self, x):
17 | return [self.transform(x), self.transform(x)]
18 |
19 | class GaussianBlur(object):
20 | """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709"""
21 |
22 | def __init__(self, sigma=[.1, 2.]):
23 | self.sigma = sigma
24 |
25 | def __call__(self, x):
26 | sigma = random.uniform(self.sigma[0], self.sigma[1])
27 | x = x.filter(ImageFilter.GaussianBlur(radius=sigma))
28 | return x
--------------------------------------------------------------------------------
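A quick sketch of the two-view pipeline these helpers enable, assuming a PIL image as input (the blur probability of 0.5 matches the augmentation stack built in `main_pretrain.py`):

```python
# Sketch: wrap an augmentation pipeline so each sample yields two views.
from PIL import Image
from torchvision import transforms

from util_contrastive import TwoCropTransform, GaussianBlur

aug = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.2, 1.0)),
    transforms.RandomApply([GaussianBlur([0.1, 2.0])], p=0.5),
    transforms.ToTensor(),
])
two_views = TwoCropTransform(aug)

img = Image.new("RGB", (256, 256))  # stand-in for a real image
v1, v2 = two_views(img)             # two independently augmented views
print(v1.shape, v2.shape)           # torch.Size([3, 224, 224]) twice
```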
/util/crop.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 |
5 | from torchvision import transforms
6 | from torchvision.transforms import functional as F
7 |
8 |
9 | class RandomResizedCrop(transforms.RandomResizedCrop):
10 | """
11 | RandomResizedCrop for matching TF/TPU implementation: no for-loop is used.
12 | This may lead to results different with torchvision's version.
13 | Following BYOL's TF code:
14 | https://github.com/deepmind/deepmind-research/blob/master/byol/utils/dataset.py#L206
15 | """
16 | @staticmethod
17 | def get_params(img, scale, ratio):
18 | width, height = F._get_image_size(img)
19 | area = height * width
20 |
21 | target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
22 | log_ratio = torch.log(torch.tensor(ratio))
23 | aspect_ratio = torch.exp(
24 | torch.empty(1).uniform_(log_ratio[0], log_ratio[1])
25 | ).item()
26 |
27 | w = int(round(math.sqrt(target_area * aspect_ratio)))
28 | h = int(round(math.sqrt(target_area / aspect_ratio)))
29 |
30 | w = min(w, width)
31 | h = min(h, height)
32 |
33 | i = torch.randint(0, height - h + 1, size=(1,)).item()
34 | j = torch.randint(0, width - w + 1, size=(1,)).item()
35 |
36 | return i, j, h, w
--------------------------------------------------------------------------------
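A small usage sketch; note that `F._get_image_size` is a private torchvision helper (present in the 0.8.x release pinned by `can.yml`, renamed `get_image_size` in later releases), so this assumes a matching torchvision version:

```python
# Sketch: the TF/BYOL-style crop draws a single (scale, ratio) sample, no retry loop.
from PIL import Image

from util.crop import RandomResizedCrop

crop = RandomResizedCrop(224, scale=(0.2, 1.0), interpolation=3)  # 3 == PIL bicubic
img = Image.new("RGB", (500, 375))  # stand-in for a real image

i, j, h, w = crop.get_params(img, crop.scale, crop.ratio)
print(i, j, h, w)  # top-left corner and crop size, clamped to the image bounds
out = crop(img)    # full transform: sample params, crop, resize to 224x224
```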
/util/lars.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class LARS(torch.optim.Optimizer):
5 | """
6 | LARS optimizer, no rate scaling or weight decay for parameters <= 1D.
7 | """
8 | def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, trust_coefficient=0.001):
9 | defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, trust_coefficient=trust_coefficient)
10 | super().__init__(params, defaults)
11 |
12 | @torch.no_grad()
13 | def step(self):
14 | for g in self.param_groups:
15 | for p in g['params']:
16 | dp = p.grad
17 |
18 | if dp is None:
19 | continue
20 |
21 | if p.ndim > 1: # if not normalization gamma/beta or bias
22 | dp = dp.add(p, alpha=g['weight_decay'])
23 | param_norm = torch.norm(p)
24 | update_norm = torch.norm(dp)
25 | one = torch.ones_like(param_norm)
26 | q = torch.where(param_norm > 0.,
27 | torch.where(update_norm > 0,
28 | (g['trust_coefficient'] * param_norm / update_norm), one),
29 | one)
30 | dp = dp.mul(q)
31 |
32 | param_state = self.state[p]
33 | if 'mu' not in param_state:
34 | param_state['mu'] = torch.zeros_like(p)
35 | mu = param_state['mu']
36 | mu.mul_(g['momentum']).add_(dp)
37 | p.add_(mu, alpha=-g['lr'])
--------------------------------------------------------------------------------
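A minimal usage sketch in the linear-probe setting this optimizer targets (note that, unlike stock optimizers, `step()` here accepts no closure):

```python
# Sketch: LARS over a linear classifier; <=1D params skip the trust-ratio scaling.
import torch
import torch.nn as nn

from util.lars import LARS

head = nn.Linear(768, 1000)
opt = LARS(head.parameters(), lr=0.1, weight_decay=0.0, momentum=0.9)

x, y = torch.randn(8, 768), torch.randint(0, 1000, (8,))
loss = nn.functional.cross_entropy(head(x), y)
opt.zero_grad()
loss.backward()
opt.step()  # the weight matrix gets LARS scaling; the bias takes a plain momentum step
```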
/models_vit.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | import timm.models.vision_transformer
7 |
8 |
9 | class VisionTransformer(timm.models.vision_transformer.VisionTransformer):
10 | """ Vision Transformer with support for global average pooling
11 | """
12 | def __init__(self, global_pool=False, **kwargs):
13 | super(VisionTransformer, self).__init__(**kwargs)
14 |
15 | self.global_pool = global_pool
16 | if self.global_pool:
17 | norm_layer = kwargs['norm_layer']
18 | embed_dim = kwargs['embed_dim']
19 | self.fc_norm = norm_layer(embed_dim)
20 |
21 | del self.norm # remove the original norm
22 |
23 | def forward_features(self, x):
24 | B = x.shape[0]
25 | x = self.patch_embed(x)
26 |
27 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
28 | x = torch.cat((cls_tokens, x), dim=1)
29 | x = x + self.pos_embed
30 | x = self.pos_drop(x)
31 |
32 | for blk in self.blocks:
33 | x = blk(x)
34 |
35 | if self.global_pool:
36 | x = x[:, 1:, :].mean(dim=1) # global pool without cls token
37 | outcome = self.fc_norm(x)
38 | else:
39 | x = self.norm(x)
40 | outcome = x[:, 0]
41 |
42 | return outcome
43 |
44 |
45 | def vit_base_patch16(**kwargs):
46 | model = VisionTransformer(
47 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
48 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
49 | return model
50 |
51 |
52 | def vit_large_patch16(**kwargs):
53 | model = VisionTransformer(
54 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
55 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
56 | return model
57 |
58 |
59 | def vit_huge_patch14(**kwargs):
60 | model = VisionTransformer(
61 | patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, qkv_bias=True,
62 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
63 | return model
--------------------------------------------------------------------------------
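A sketch of building one of these backbones for classification, assuming the pinned `timm==0.3.2` (whose `VisionTransformer` accepts these keyword arguments):

```python
# Sketch: ViT-B/16 with global average pooling instead of the [CLS] token.
import torch

import models_vit

model = models_vit.vit_base_patch16(num_classes=1000, global_pool=True)
x = torch.randn(2, 3, 224, 224)
feats = model.forward_features(x)  # [2, 768]: mean over patch tokens, then fc_norm
logits = model(x)                  # [2, 1000] via the classification head
```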
/util/datasets.py:
--------------------------------------------------------------------------------
1 | import os
2 | import PIL
3 |
4 | from torchvision import datasets, transforms
5 | from util_contrastive import TwoCropTransform
6 |
7 | from timm.data import create_transform
8 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
9 |
10 |
11 | def build_dataset(is_train, args):
12 | transform = build_transform(is_train, args)
13 |
14 | root = os.path.join(args.data_path, 'train' if is_train else 'val')
15 | if is_train:
16 | dataset = datasets.ImageFolder(root, transform=TwoCropTransform(transform))
17 | else:
18 | dataset = datasets.ImageFolder(root, transform=transform)
19 |
20 | print(dataset)
21 |
22 | return dataset
23 |
24 |
25 | def build_transform(is_train, args):
26 | mean = IMAGENET_DEFAULT_MEAN
27 | std = IMAGENET_DEFAULT_STD
28 | # train transform
29 | if is_train:
30 | # this should always dispatch to transforms_imagenet_train
31 | # transform = create_transform(
32 | # input_size=args.input_size,
33 | # is_training=True,
34 | # color_jitter=args.color_jitter,
35 | # auto_augment=args.aa,
36 | # interpolation='bicubic',
37 | # re_prob=args.reprob,
38 | # re_mode=args.remode,
39 | # re_count=args.recount,
40 | # mean=mean,
41 | # std=std,
42 | # )
43 | # return transform
44 | normalize = transforms.Normalize(mean=mean, std=std)
45 | train_transform = transforms.Compose([
46 | transforms.RandomResizedCrop(size=args.input_size, scale=(0.2, 1.)),
47 | transforms.RandomHorizontalFlip(),
48 | transforms.RandomApply([
49 | transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
50 | ], p=0.8),
51 | transforms.RandomGrayscale(p=0.2),
52 | transforms.ToTensor(),
53 | normalize,
54 | ])
55 | return train_transform
56 |
57 | # eval transform
58 | t = []
59 | if args.input_size <= 224:
60 | crop_pct = 224 / 256
61 | else:
62 | crop_pct = 1.0
63 | size = int(args.input_size / crop_pct)
64 | t.append(
65 | transforms.Resize(size, interpolation=PIL.Image.BICUBIC), # to maintain same ratio w.r.t. 224 images
66 | )
67 | t.append(transforms.CenterCrop(args.input_size))
68 |
69 | t.append(transforms.ToTensor())
70 | t.append(transforms.Normalize(mean, std))
71 | return transforms.Compose(t)
72 |
--------------------------------------------------------------------------------
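A usage sketch, assuming an ImageNet-style folder layout under `args.data_path` (`train/`, `val/`) and a namespace with the two fields these functions read; the train split yields a two-element list of views per sample because of `TwoCropTransform`:

```python
# Hypothetical sketch: the path and namespace below are illustrative only.
from types import SimpleNamespace

from util.datasets import build_dataset

args = SimpleNamespace(data_path="/path/to/imagenet", input_size=224)

train_set = build_dataset(is_train=True, args=args)
views, label = train_set[0]   # views: [view1, view2], each a [3, 224, 224] tensor
val_set = build_dataset(is_train=False, args=args)
img, label = val_set[0]       # single [3, 224, 224] tensor
```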
/util/lr_decay.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # References:
8 | # ELECTRA https://github.com/google-research/electra
9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit
10 | # --------------------------------------------------------
11 |
12 | import json
13 |
14 |
15 | def param_groups_lrd(model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=.75):
16 | """
17 | Parameter groups for layer-wise lr decay
18 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58
19 | """
20 | param_group_names = {}
21 | param_groups = {}
22 |
23 | num_layers = len(model.blocks) + 1
24 |
25 | layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1))
26 |
27 | for n, p in model.named_parameters():
28 | if not p.requires_grad:
29 | continue
30 |
31 | # no decay: all 1D parameters and model specific ones
32 | if p.ndim == 1 or n in no_weight_decay_list:
33 | g_decay = "no_decay"
34 | this_decay = 0.
35 | else:
36 | g_decay = "decay"
37 | this_decay = weight_decay
38 |
39 | layer_id = get_layer_id_for_vit(n, num_layers)
40 | group_name = "layer_%d_%s" % (layer_id, g_decay)
41 |
42 | if group_name not in param_group_names:
43 | this_scale = layer_scales[layer_id]
44 |
45 | param_group_names[group_name] = {
46 | "lr_scale": this_scale,
47 | "weight_decay": this_decay,
48 | "params": [],
49 | }
50 | param_groups[group_name] = {
51 | "lr_scale": this_scale,
52 | "weight_decay": this_decay,
53 | "params": [],
54 | }
55 |
56 | param_group_names[group_name]["params"].append(n)
57 | param_groups[group_name]["params"].append(p)
58 |
59 | # print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2))
60 |
61 | return list(param_groups.values())
62 |
63 |
64 | def get_layer_id_for_vit(name, num_layers):
65 | """
66 | Assign a parameter with its layer id
67 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33
68 | """
69 | if name in ['cls_token', 'pos_embed']:
70 | return 0
71 | elif name.startswith('patch_embed'):
72 | return 0
73 | elif name.startswith('blocks'):
74 | return int(name.split('.')[1]) + 1
75 | else:
76 | return num_layers
--------------------------------------------------------------------------------
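A sketch of wiring these layer-decayed groups into AdamW, assuming a ViT from `models_vit.py` (the `lr_scale` field is later consumed by `adjust_learning_rate` in `util/lr_sched.py`):

```python
# Sketch: per-layer lr scales decay by 0.75 from the head back to the patch embed.
import torch

import models_vit
from util.lr_decay import param_groups_lrd

model = models_vit.vit_base_patch16(num_classes=1000)
groups = param_groups_lrd(model, weight_decay=0.05,
                          no_weight_decay_list=model.no_weight_decay(),
                          layer_decay=0.75)
optimizer = torch.optim.AdamW(groups, lr=1e-3)
print(sorted({g["lr_scale"] for g in groups}))  # 0.75**13 ... 0.75**0
```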
/README.md:
--------------------------------------------------------------------------------
1 | # CAN: A simple, efficient and scalable contrastive masked autoencoder for learning visual representations
2 |
3 | Official PyTorch implementation of ["A simple, efficient and scalable contrastive masked autoencoder for learning visual representations"](https://arxiv.org/abs/2210.16870).
4 |
5 |
6 |
7 |
8 |
9 | - The original implementation was in JAX+TPU. This re-implementation is in PyTorch+GPU.
10 |
11 | ## Requirements
12 | - Instructions for creating the conda environment:
13 |
14 |
15 | ```
16 | conda env create -f can.yml
17 | conda activate can
18 | ```
19 |
20 | ## Instructions for running CAN
21 | ```
22 | git clone https://github.com/shlokk/mae-contrastive.git
23 | cd mae-contrastive
24 | ```
25 |
26 |
27 | Script for running CAN:
28 |
29 | ```
30 | OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=4 main_pretrain.py \
31 | --data_path path_to_imagenet --output_dir can_noise_baseline --log_dir can_baseline_logs \
32 | --num_workers 8 --blr 2.5e-4 --weight_decay 0.05 --model mae_vit_base_patch16 \
33 | --batch_size 64 --dist_url 'tcp://localhost:10004' --epochs 50 --weight_simclr 0.03 \
34 | --weight_mae 0.97 --accum_iter 4
35 | ```
36 |
37 | Script for running MAE baseline:
38 |
39 | ```
40 | OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=4 main_pretrain.py \
41 | --data_path path_to_imagenet --output_dir mae_baseline --log_dir mae_baseline_logs \
42 | --num_workers 8 --blr 1.5e-4 --weight_decay 0.05 --model mae_vit_base_patch16 \
43 | --batch_size 64 --dist_url 'tcp://localhost:10004' --epochs 50 --weight_simclr 0 \
44 | --weight_mae 1.0 --accum_iter 4
45 | ```
46 |
47 | Script for running linear evaluation:
48 | ```
49 | OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=4 main_linprobe.py \
50 | --data_path path_to_imagenet --batch_size 512 --model vit_base_patch16 --cls_token \
51 | --finetune can_noise_baseline/checkpoint-49.pth --epochs 90 --blr 0.1 --weight_decay 0.0 \
52 | --dist_eval --output_dir mae_baseline_lineval
53 | ```
54 |
55 | ## Pre-trained models
56 | - We have released pretrained models for 50-epoch pretraining [here](https://drive.google.com/file/d/18yVmZmKenM-cZh5o6hmcswvS2ePhuDk_/view?usp=sharing).
57 | - We will release checkpoints from longer pretraining runs (800 and 1600 epochs) soon.
58 |
59 |
60 | This repo is heavily inspired by the MAE repo: https://github.com/facebookresearch/mae.
61 |
62 | ## Citation
63 | ```bibtex
64 | @article{mishra2022simple,
65 | title={A simple, efficient and scalable contrastive masked autoencoder for learning visual representations},
66 | author={Mishra, Shlok and Robinson, Joshua and Chang, Huiwen and Jacobs, David and Sarna, Aaron and Maschinot, Aaron and Krishnan, Dilip},
67 | journal={arXiv preprint arXiv:2210.16870},
68 | year={2022}
69 | }
70 | ```
--------------------------------------------------------------------------------
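A note on the learning rates in the README scripts above: `main_pretrain.py` computes the absolute rate as `lr = blr * effective_batch_size / 256`, where the effective batch size is `batch_size * accum_iter * num_gpus`. For the CAN script (4 GPUs, `--batch_size 64`, `--accum_iter 4`), the effective batch size is 64 * 4 * 4 = 1024, so `--blr 2.5e-4` yields an absolute lr of 1e-3.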
/can.yml:
--------------------------------------------------------------------------------
1 | name: can
2 | channels:
3 | - iopath
4 | - pytorch
5 | - vissl
6 | - conda-forge
7 | - defaults
8 | dependencies:
9 | - _libgcc_mutex=0.1=main
10 | - antlr-python-runtime=4.8=py38h32f6830_2
11 | - apex=0.0=py38_cu102_pyt171
12 | - blas=1.0=mkl
13 | - ca-certificates=2021.4.13=h06a4308_1
14 | - certifi=2020.12.5=py38h06a4308_0
15 | - cudatoolkit=10.2.89=hfd86e86_1
16 | - faiss-gpu=1.7.0=py3.8_h080d439_0_cuda10.2
17 | - freetype=2.10.4=h5ab3b9f_0
18 | - fvcore=0.1.3.post20210223=pyhd8ed1ab_0
19 | - hydra-core=1.0.6=pyhd8ed1ab_1
20 | - importlib_resources=5.1.2=py38h578d9bd_0
21 | - intel-openmp=2020.2=254
22 | - iopath=0.1.8=py38
23 | - joblib=1.0.1=pyhd8ed1ab_0
24 | - jpeg=9b=h024ee3a_2
25 | - lcms2=2.12=h3be6417_0
26 | - ld_impl_linux-64=2.33.1=h53a641e_7
27 | - libfaiss=1.7.0=h4fe19ad_0_cuda10.2
28 | - libffi=3.3=he6710b0_2
29 | - libgcc-ng=9.1.0=hdf63c60_0
30 | - libgfortran-ng=7.5.0=h14aa051_19
31 | - libgfortran4=7.5.0=h14aa051_19
32 | - libpng=1.6.37=hbc83047_0
33 | - libstdcxx-ng=9.1.0=hdf63c60_0
34 | - libtiff=4.1.0=h2733197_1
35 | - libuv=1.40.0=h7b6447c_0
36 | - lz4-c=1.9.3=h2531618_0
37 | - mkl=2020.2=256
38 | - mkl-service=2.3.0=py38he904b0f_0
39 | - mkl_fft=1.3.0=py38h54f3939_0
40 | - mkl_random=1.1.1=py38h0573a6f_0
41 | - ncurses=6.2=he6710b0_1
42 | - ninja=1.10.2=hff7bd54_1
43 | - numpy=1.19.2=py38h54aff64_0
44 | - numpy-base=1.19.2=py38hfa32c7d_0
45 | - olefile=0.46=py_0
46 | - omegaconf=2.0.6=py38h578d9bd_0
47 | - openssl=1.1.1k=h27cfd23_0
48 | - pandas=1.2.4=py38h2531618_0
49 | - parameterized=0.8.1=pyhd3deb0d_0
50 | - pillow=8.2.0=py38he98fc37_0
51 | - pip=21.0.1=py38h06a4308_0
52 | - portalocker=1.7.0=py38h578d9bd_1
53 | - python=3.8.8=hdb3f193_5
54 | - python-dateutil=2.8.1=pyhd3eb1b0_0
55 | - python_abi=3.8=1_cp38
56 | - pytorch=1.7.1=py3.8_cuda10.2.89_cudnn7.6.5_0
57 | - pytz=2021.1=pyhd3eb1b0_0
58 | - pyyaml=5.3.1=py38h8df0ef7_1
59 | - readline=8.1=h27cfd23_0
60 | - scikit-learn=0.24.1=py38ha9443f7_0
61 | - scipy=1.6.2=py38h91f5cce_0
62 | - setuptools=52.0.0=py38h06a4308_0
63 | - six=1.15.0=py38h06a4308_0
64 | - sqlite=3.35.4=hdfb4753_0
65 | - tabulate=0.8.9=pyhd8ed1ab_0
66 | - termcolor=1.1.0=py_2
67 | - threadpoolctl=2.1.0=pyh5ca1d4c_0
68 | - tk=8.6.10=hbc83047_0
69 | - torchvision=0.8.2=py38_cu102
70 | - tqdm=4.60.0=pyhd8ed1ab_0
71 | - typing_extensions=3.7.4.3=pyha847dfd_0
72 | - vissl=0.1.5=py38
73 | - wheel=0.36.2=pyhd3eb1b0_0
74 | - xz=5.2.5=h7b6447c_0
75 | - yacs=0.1.6=py_0
76 | - yaml=0.2.5=h516909a_0
77 | - zlib=1.2.11=h7b6447c_3
78 | - zstd=1.4.9=haebb681_0
79 | - pip:
80 | - absl-py==1.3.0
81 | - cachetools==5.2.0
82 | - charset-normalizer==2.1.1
83 | - filelock==3.8.0
84 | - google-auth==2.14.1
85 | - google-auth-oauthlib==0.4.6
86 | - grpcio==1.50.0
87 | - huggingface-hub==0.10.1
88 | - idna==3.4
89 | - importlib-metadata==5.0.0
90 | - markdown==3.4.1
91 | - markupsafe==2.1.1
92 | - oauthlib==3.2.2
93 | - packaging==21.3
94 | - protobuf==3.20.3
95 | - pyasn1==0.4.8
96 | - pyasn1-modules==0.2.8
97 | - pyparsing==3.0.9
98 | - requests==2.28.1
99 | - requests-oauthlib==1.3.1
100 | - rsa==4.9
101 | - tensorboard==2.11.0
102 | - tensorboard-data-server==0.6.1
103 | - tensorboard-plugin-wit==1.8.1
104 | - timm==0.3.2
105 | - urllib3==1.26.12
106 | - werkzeug==2.2.2
107 | - zipp==3.10.0
108 | prefix: /vulcanscratch/shlokm/Ana/envs/vissl
109 |
--------------------------------------------------------------------------------
/engine_pretrain.py:
--------------------------------------------------------------------------------
1 | import math
2 | import sys
3 | from typing import Iterable
4 |
5 | import torch
6 |
7 | import util.misc as misc
8 | import util.lr_sched as lr_sched
9 |
10 |
11 | def train_one_epoch(model: torch.nn.Module,
12 | data_loader: Iterable, optimizer: torch.optim.Optimizer,
13 | device: torch.device, epoch: int, loss_scaler,
14 | log_writer=None,
15 | args=None):
16 | model.train(True)
17 | metric_logger = misc.MetricLogger(delimiter=" ")
18 | metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
19 | metric_logger.add_meter('loss_contrastive', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
20 | metric_logger.add_meter('loss_noise', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
21 | header = 'Epoch: [{}]'.format(epoch)
22 | print_freq = 20
23 |
24 | accum_iter = args.accum_iter
25 | weight_mae = args.weight_mae
26 | weight_simclr = args.weight_simclr
27 | weight_noise = args.weight_noise
28 |
29 | optimizer.zero_grad()
30 |
31 | if log_writer is not None:
32 | print('log_dir: {}'.format(log_writer.log_dir))
33 |
34 | for data_iter_step, (samples, _) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
35 |
36 | # we use a per iteration (instead of per epoch) lr scheduler
37 | if data_iter_step % accum_iter == 0:
38 | lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)
39 |
40 | samples = torch.cat([samples[0], samples[1]], dim=0) # contrastive hack
41 | samples = samples.to(device, non_blocking=True)
42 |
43 | with torch.cuda.amp.autocast():
44 | loss, loss_contrastive, loss_noise, _, _ = model(samples, mask_ratio=args.mask_ratio)
45 |
46 | loss_recon = weight_noise * loss_noise + (1-weight_noise) * loss
47 | loss = weight_mae * loss_recon + weight_simclr * loss_contrastive
48 | loss_value = loss.item()
49 | loss_contrastive_value = loss_contrastive.item()
50 | loss_noise_value = loss_noise.item()
51 |
52 | if not math.isfinite(loss_value):
53 | print("Loss is {}, stopping training".format(loss_value))
54 | sys.exit(1)
55 |
56 | loss /= accum_iter
57 | loss_scaler(loss, optimizer, parameters=model.parameters(),
58 | update_grad=(data_iter_step + 1) % accum_iter == 0)
59 | if (data_iter_step + 1) % accum_iter == 0:
60 | optimizer.zero_grad()
61 |
62 | torch.cuda.synchronize()
63 |
64 | metric_logger.update(loss=loss_value)
65 | metric_logger.update(loss_contrastive=loss_contrastive_value)
66 | metric_logger.update(loss_noise=loss_noise_value)
67 |
68 | lr = optimizer.param_groups[0]["lr"]
69 | metric_logger.update(lr=lr)
70 |
71 | loss_value_reduce = misc.all_reduce_mean(loss_value)
72 | loss_contrastive_value_reduce = misc.all_reduce_mean(loss_contrastive_value)
73 | if log_writer is not None and (data_iter_step + 1) % accum_iter == 0:
74 | """ We use epoch_1000x as the x-axis in tensorboard.
75 | This calibrates different curves when batch size changes.
76 | """
77 | epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
78 | log_writer.add_scalar('train_loss', loss_value_reduce, epoch_1000x)
79 | log_writer.add_scalar('lr', lr, epoch_1000x)
80 |
81 |
82 | # gather the stats from all processes
83 | metric_logger.synchronize_between_processes()
84 | print("Averaged stats:", metric_logger)
85 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
--------------------------------------------------------------------------------
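The loss mixing above is a fixed convex combination; a small sketch of the same arithmetic with made-up loss values (not from any real run):

```python
# Sketch: how the three losses combine, following train_one_epoch above.
weight_mae, weight_simclr, weight_noise = 0.97, 0.03, 0.3  # defaults in main_pretrain.py
loss_mae, loss_noise, loss_contrastive = 0.50, 0.20, 4.00  # illustrative values only

loss_recon = weight_noise * loss_noise + (1 - weight_noise) * loss_mae  # 0.41
total = weight_mae * loss_recon + weight_simclr * loss_contrastive      # 0.5177
print(loss_recon, total)
```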
/loss_contrastive.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 |
7 | class SupConLoss(nn.Module):
8 | """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
9 | It also supports the unsupervised contrastive loss in SimCLR"""
10 | def __init__(self, temperature=0.1, contrast_mode='all',
11 | base_temperature=0.1):
12 | super(SupConLoss, self).__init__()
13 | self.temperature = temperature
14 | self.contrast_mode = contrast_mode
15 | self.base_temperature = base_temperature
16 |
17 | def forward(self, features, labels=None, mask=None):
18 | """Compute loss for model. If both `labels` and `mask` are None,
19 | it degenerates to SimCLR unsupervised loss:
20 | https://arxiv.org/pdf/2002.05709.pdf
21 |
22 | Args:
23 | features: hidden vector of shape [bsz, n_views, ...].
24 | labels: ground truth of shape [bsz].
25 | mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
26 | has the same class as sample i. Can be asymmetric.
27 | Returns:
28 | A loss scalar.
29 | """
30 | device = (torch.device('cuda')
31 | if features.is_cuda
32 | else torch.device('cpu'))
33 |
34 | if len(features.shape) < 3:
35 | raise ValueError('`features` needs to be [bsz, n_views, ...],'
36 | 'at least 3 dimensions are required')
37 | if len(features.shape) > 3:
38 | features = features.view(features.shape[0], features.shape[1], -1)
39 |
40 | batch_size = features.shape[0]
41 | if labels is not None and mask is not None:
42 | raise ValueError('Cannot define both `labels` and `mask`')
43 | elif labels is None and mask is None:
44 | mask = torch.eye(batch_size, dtype=torch.float32).to(device)
45 | elif labels is not None:
46 | labels = labels.contiguous().view(-1, 1)
47 | if labels.shape[0] != batch_size:
48 | raise ValueError('Num of labels does not match num of features')
49 | mask = torch.eq(labels, labels.T).float().to(device)
50 | else:
51 | mask = mask.float().to(device)
52 |
53 | contrast_count = features.shape[1]
54 | contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
55 | if self.contrast_mode == 'one':
56 | anchor_feature = features[:, 0]
57 | anchor_count = 1
58 | elif self.contrast_mode == 'all':
59 | anchor_feature = contrast_feature
60 | anchor_count = contrast_count
61 | else:
62 | raise ValueError('Unknown mode: {}'.format(self.contrast_mode))
63 |
64 | # compute logits
65 | anchor_dot_contrast = torch.div(
66 | torch.matmul(anchor_feature, contrast_feature.T),
67 | self.temperature)
68 | # for numerical stability
69 | logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
70 | logits = anchor_dot_contrast - logits_max.detach()
71 | # logits = anchor_dot_contrast
72 |
73 | # print(logits.mean())
74 |
75 | # tile mask
76 | mask = mask.repeat(anchor_count, contrast_count)
77 | # mask-out self-contrast cases
78 | logits_mask = torch.scatter(
79 | torch.ones_like(mask),
80 | 1,
81 | torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
82 | 0
83 | )
84 | mask = mask * logits_mask
85 |
86 | # compute log_prob
87 | exp_logits = torch.exp(logits) * logits_mask
88 | log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))
89 |
90 | # compute mean of log-likelihood over positive
91 | mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)
92 | # import pdb
93 | # pdb.set_trace()
94 |
95 | # loss
96 | loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
97 | loss = loss.view(anchor_count, batch_size).mean()
98 |
99 | return loss
--------------------------------------------------------------------------------
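A minimal sketch of the unsupervised (SimCLR) mode of this loss, assuming L2-normalized projections shaped `[bsz, n_views, dim]`; the two views per image correspond to the concatenated batch built in `engine_pretrain.py`:

```python
# Sketch: SimCLR-style use of SupConLoss with two views per image.
import torch
import torch.nn.functional as F

from loss_contrastive import SupConLoss

bsz, dim = 8, 128
z = F.normalize(torch.randn(2 * bsz, dim), dim=1)   # projections for both views
f1, f2 = torch.split(z, [bsz, bsz], dim=0)
features = torch.stack([f1, f2], dim=1)             # [bsz, 2, dim]

criterion = SupConLoss(temperature=0.1)
loss = criterion(features)  # labels=None, mask=None -> unsupervised SimCLR loss
print(loss.item())
```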
/util/pos_embed.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 |
4 | import torch
5 |
6 | def get_1d_sincos_pos_embed(x: torch.Tensor, dim: int):
7 | """From: https://github.com/lucidrains/denoising-diffusion-pytorch/blob/main/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py"""
8 | half_dim = dim // 2
9 | emb = math.log(10000) / (half_dim - 1)
10 | emb = torch.exp(torch.arange(half_dim, device=x.device) * -emb)
11 | emb = x[:, None] * emb[None, :]
12 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
13 | return emb
14 |
15 | # --------------------------------------------------------
16 | # 2D sine-cosine position embedding
17 | # References:
18 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
19 | # MoCo v3: https://github.com/facebookresearch/moco-v3
20 | # --------------------------------------------------------
21 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
22 | """
23 | grid_size: int of the grid height and width
24 | return:
25 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
26 | """
27 | grid_h = np.arange(grid_size, dtype=np.float32)
28 | grid_w = np.arange(grid_size, dtype=np.float32)
29 | grid = np.meshgrid(grid_w, grid_h) # here w goes first
30 | grid = np.stack(grid, axis=0)
31 |
32 | grid = grid.reshape([2, 1, grid_size, grid_size])
33 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
34 | if cls_token:
35 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
36 | return pos_embed
37 |
38 |
39 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
40 | assert embed_dim % 2 == 0
41 |
42 | # use half of dimensions to encode grid_h
43 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
44 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
45 |
46 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
47 | return emb
48 |
49 |
50 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
51 | """
52 | embed_dim: output dimension for each position
53 | pos: a list of positions to be encoded: size (M,)
54 | out: (M, D)
55 | """
56 | assert embed_dim % 2 == 0
57 | omega = np.arange(embed_dim // 2, dtype=np.float64)  # np.float was removed in NumPy 1.24
58 | omega /= embed_dim / 2.
59 | omega = 1. / 10000**omega # (D/2,)
60 |
61 | pos = pos.reshape(-1) # (M,)
62 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
63 |
64 | emb_sin = np.sin(out) # (M, D/2)
65 | emb_cos = np.cos(out) # (M, D/2)
66 |
67 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
68 | return emb
69 |
70 |
71 | # --------------------------------------------------------
72 | # Interpolate position embeddings for high-resolution
73 | # References:
74 | # DeiT: https://github.com/facebookresearch/deit
75 | # --------------------------------------------------------
76 | def interpolate_pos_embed(model, checkpoint_model):
77 | if 'pos_embed' in checkpoint_model:
78 | pos_embed_checkpoint = checkpoint_model['pos_embed']
79 | embedding_size = pos_embed_checkpoint.shape[-1]
80 | num_patches = model.patch_embed.num_patches
81 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches
82 | # height (== width) for the checkpoint position embedding
83 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
84 | # height (== width) for the new position embedding
85 | new_size = int(num_patches ** 0.5)
86 | # class_token and dist_token are kept unchanged
87 | if orig_size != new_size:
88 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
89 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
90 | # only the position tokens are interpolated
91 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
92 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
93 | pos_tokens = torch.nn.functional.interpolate(
94 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
95 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
96 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
97 | checkpoint_model['pos_embed'] = new_pos_embed
98 |
--------------------------------------------------------------------------------
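A shape sketch for the 2D table used by the ViT models (a 224-pixel image with 16-pixel patches gives a 14x14 grid):

```python
# Sketch: sin-cos table for a 14x14 patch grid, with a zero row for the cls token.
from util.pos_embed import get_2d_sincos_pos_embed

pos = get_2d_sincos_pos_embed(embed_dim=768, grid_size=14, cls_token=True)
print(pos.shape)     # (197, 768): 1 cls row + 14*14 patch rows
print(pos[0].sum())  # 0.0 -- the cls position is all zeros
```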
/submitit_finetune.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import uuid
4 | from pathlib import Path
5 |
6 | import main_finetune as classification
7 | import submitit
8 |
9 |
10 | def parse_args():
11 | classification_parser = classification.get_args_parser()
12 | parser = argparse.ArgumentParser("Submitit for MAE finetune", parents=[classification_parser])
13 | parser.add_argument("--ngpus", default=8, type=int, help="Number of gpus to request on each node")
14 | parser.add_argument("--nodes", default=2, type=int, help="Number of nodes to request")
15 | parser.add_argument("--timeout", default=4320, type=int, help="Duration of the job")
16 | parser.add_argument("--job_dir", default="", type=str, help="Job dir. Leave empty for automatic.")
17 |
18 | parser.add_argument("--partition", default="learnfair", type=str, help="Partition where to submit")
19 | parser.add_argument("--use_volta32", action='store_true', help="Request 32G V100 GPUs")
20 | parser.add_argument('--comment', default="", type=str, help="Comment to pass to scheduler")
21 | return parser.parse_args()
22 |
23 |
24 | def get_shared_folder() -> Path:
25 | user = os.getenv("USER")
26 | if Path("/checkpoint/").is_dir():
27 | p = Path(f"/checkpoint/{user}/experiments")
28 | p.mkdir(exist_ok=True)
29 | return p
30 | raise RuntimeError("No shared folder available")
31 |
32 |
33 | def get_init_file():
34 | # Init file must not exist, but its parent dir must exist.
35 | os.makedirs(str(get_shared_folder()), exist_ok=True)
36 | init_file = get_shared_folder() / f"{uuid.uuid4().hex}_init"
37 | if init_file.exists():
38 | os.remove(str(init_file))
39 | return init_file
40 |
41 |
42 | class Trainer(object):
43 | def __init__(self, args):
44 | self.args = args
45 |
46 | def __call__(self):
47 | import main_finetune as classification
48 |
49 | self._setup_gpu_args()
50 | classification.main(self.args)
51 |
52 | def checkpoint(self):
53 | import os
54 | import submitit
55 |
56 | self.args.dist_url = get_init_file().as_uri()
57 | checkpoint_file = os.path.join(self.args.output_dir, "checkpoint.pth")
58 | if os.path.exists(checkpoint_file):
59 | self.args.resume = checkpoint_file
60 | print("Requeuing ", self.args)
61 | empty_trainer = type(self)(self.args)
62 | return submitit.helpers.DelayedSubmission(empty_trainer)
63 |
64 | def _setup_gpu_args(self):
65 | import submitit
66 | from pathlib import Path
67 |
68 | job_env = submitit.JobEnvironment()
69 | self.args.output_dir = Path(str(self.args.output_dir).replace("%j", str(job_env.job_id)))
70 | self.args.log_dir = self.args.output_dir
71 | self.args.gpu = job_env.local_rank
72 | self.args.rank = job_env.global_rank
73 | self.args.world_size = job_env.num_tasks
74 | print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}")
75 |
76 |
77 | def main():
78 | args = parse_args()
79 | if args.job_dir == "":
80 | args.job_dir = get_shared_folder() / "%j"
81 |
82 | # Note that the folder will depend on the job_id, to easily track experiments
83 | executor = submitit.AutoExecutor(folder=args.job_dir, slurm_max_num_timeout=30)
84 |
85 | num_gpus_per_node = args.ngpus
86 | nodes = args.nodes
87 | timeout_min = args.timeout
88 |
89 | partition = args.partition
90 | kwargs = {}
91 | if args.use_volta32:
92 | kwargs['slurm_constraint'] = 'volta32gb'
93 | if args.comment:
94 | kwargs['slurm_comment'] = args.comment
95 |
96 | executor.update_parameters(
97 | mem_gb=40 * num_gpus_per_node,
98 | gpus_per_node=num_gpus_per_node,
99 | tasks_per_node=num_gpus_per_node, # one task per GPU
100 | cpus_per_task=10,
101 | nodes=nodes,
102 | timeout_min=timeout_min,
103 | # Below are cluster dependent parameters
104 | slurm_partition=partition,
105 | slurm_signal_delay_s=120,
106 | **kwargs
107 | )
108 |
109 | executor.update_parameters(name="mae")
110 |
111 | args.dist_url = get_init_file().as_uri()
112 | args.output_dir = args.job_dir
113 |
114 | trainer = Trainer(args)
115 | job = executor.submit(trainer)
116 |
117 | # print("Submitted job_id:", job.job_id)
118 | print(job.job_id)
119 |
120 |
121 | if __name__ == "__main__":
122 | main()
123 |
--------------------------------------------------------------------------------
/submitit_linprobe.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import uuid
4 | from pathlib import Path
5 |
6 | import main_linprobe as classification
7 | import submitit
8 |
9 |
10 | def parse_args():
11 | classification_parser = classification.get_args_parser()
12 | parser = argparse.ArgumentParser("Submitit for MAE linear probe", parents=[classification_parser])
13 | parser.add_argument("--ngpus", default=8, type=int, help="Number of gpus to request on each node")
14 | parser.add_argument("--nodes", default=2, type=int, help="Number of nodes to request")
15 | parser.add_argument("--timeout", default=4320, type=int, help="Duration of the job")
16 | parser.add_argument("--job_dir", default="", type=str, help="Job dir. Leave empty for automatic.")
17 |
18 | parser.add_argument("--partition", default="learnfair", type=str, help="Partition where to submit")
19 | parser.add_argument("--use_volta32", action='store_true', help="Request 32G V100 GPUs")
20 | parser.add_argument('--comment', default="", type=str, help="Comment to pass to scheduler")
21 | return parser.parse_args()
22 |
23 |
24 | def get_shared_folder() -> Path:
25 | user = os.getenv("USER")
26 | if Path("/checkpoint/").is_dir():
27 | p = Path(f"/checkpoint/{user}/experiments")
28 | p.mkdir(exist_ok=True)
29 | return p
30 | raise RuntimeError("No shared folder available")
31 |
32 |
33 | def get_init_file():
34 | # Init file must not exist, but its parent dir must exist.
35 | os.makedirs(str(get_shared_folder()), exist_ok=True)
36 | init_file = get_shared_folder() / f"{uuid.uuid4().hex}_init"
37 | if init_file.exists():
38 | os.remove(str(init_file))
39 | return init_file
40 |
41 |
42 | class Trainer(object):
43 | def __init__(self, args):
44 | self.args = args
45 |
46 | def __call__(self):
47 | import main_linprobe as classification
48 |
49 | self._setup_gpu_args()
50 | classification.main(self.args)
51 |
52 | def checkpoint(self):
53 | import os
54 | import submitit
55 |
56 | self.args.dist_url = get_init_file().as_uri()
57 | checkpoint_file = os.path.join(self.args.output_dir, "checkpoint.pth")
58 | if os.path.exists(checkpoint_file):
59 | self.args.resume = checkpoint_file
60 | print("Requeuing ", self.args)
61 | empty_trainer = type(self)(self.args)
62 | return submitit.helpers.DelayedSubmission(empty_trainer)
63 |
64 | def _setup_gpu_args(self):
65 | import submitit
66 | from pathlib import Path
67 |
68 | job_env = submitit.JobEnvironment()
69 | self.args.output_dir = Path(str(self.args.output_dir).replace("%j", str(job_env.job_id)))
70 | self.args.log_dir = self.args.output_dir
71 | self.args.gpu = job_env.local_rank
72 | self.args.rank = job_env.global_rank
73 | self.args.world_size = job_env.num_tasks
74 | print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}")
75 |
76 |
77 | def main():
78 | args = parse_args()
79 | if args.job_dir == "":
80 | args.job_dir = get_shared_folder() / "%j"
81 |
82 | # Note that the folder will depend on the job_id, to easily track experiments
83 | executor = submitit.AutoExecutor(folder=args.job_dir, slurm_max_num_timeout=30)
84 |
85 | num_gpus_per_node = args.ngpus
86 | nodes = args.nodes
87 | timeout_min = args.timeout
88 |
89 | partition = args.partition
90 | kwargs = {}
91 | if args.use_volta32:
92 | kwargs['slurm_constraint'] = 'volta32gb'
93 | if args.comment:
94 | kwargs['slurm_comment'] = args.comment
95 |
96 | executor.update_parameters(
97 | mem_gb=40 * num_gpus_per_node,
98 | gpus_per_node=num_gpus_per_node,
99 | tasks_per_node=num_gpus_per_node, # one task per GPU
100 | cpus_per_task=10,
101 | nodes=nodes,
102 | timeout_min=timeout_min,
103 | # Below are cluster dependent parameters
104 | slurm_partition=partition,
105 | slurm_signal_delay_s=120,
106 | **kwargs
107 | )
108 |
109 | executor.update_parameters(name="mae")
110 |
111 | args.dist_url = get_init_file().as_uri()
112 | args.output_dir = args.job_dir
113 |
114 | trainer = Trainer(args)
115 | job = executor.submit(trainer)
116 |
117 | # print("Submitted job_id:", job.job_id)
118 | print(job.job_id)
119 |
120 |
121 | if __name__ == "__main__":
122 | main()
123 |
--------------------------------------------------------------------------------
/submitit_pretrain.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import uuid
4 | from pathlib import Path
5 |
6 | import main_pretrain as trainer
7 | import submitit
8 |
9 |
10 | def parse_args():
11 | trainer_parser = trainer.get_args_parser()
12 | parser = argparse.ArgumentParser("Submitit for MAE pretrain", parents=[trainer_parser])
13 | parser.add_argument("--ngpus", default=8, type=int, help="Number of gpus to request on each node")
14 | parser.add_argument("--nodes", default=2, type=int, help="Number of nodes to request")
15 | parser.add_argument("--timeout", default=4320, type=int, help="Duration of the job")
16 | parser.add_argument("--job_dir", default="", type=str, help="Job dir. Leave empty for automatic.")
17 |
18 | parser.add_argument("--partition", default="learnfair", type=str, help="Partition where to submit")
19 | parser.add_argument("--use_volta32", action='store_true', help="Request 32G V100 GPUs")
20 | parser.add_argument('--comment', default="", type=str, help="Comment to pass to scheduler")
21 | return parser.parse_args()
22 |
23 |
24 | def get_shared_folder() -> Path:
25 | user = os.getenv("USER")
26 | return Path("/fs/vulcan-projects/jigsaw_selfsup_shlokm/dv1/mae/checkpoint/")  # must be a Path: callers apply the "/" operator to it
27 | # if Path("/checkpoint/").is_dir():
28 | # p = Path(f"/checkpoint/{user}/experiments")
29 | # p.mkdir(exist_ok=True)
30 | # return p
31 | raise RuntimeError("No shared folder available")
32 |
33 |
34 | def get_init_file():
35 | # Init file must not exist, but its parent dir must exist.
36 | os.makedirs(str(get_shared_folder()), exist_ok=True)
37 | init_file = get_shared_folder() / f"{uuid.uuid4().hex}_init"
38 | if init_file.exists():
39 | os.remove(str(init_file))
40 | return init_file
41 |
42 |
43 | class Trainer(object):
44 | def __init__(self, args):
45 | self.args = args
46 |
47 | def __call__(self):
48 | import main_pretrain as trainer
49 |
50 | self._setup_gpu_args()
51 | trainer.main(self.args)
52 |
53 | def checkpoint(self):
54 | import os
55 | import submitit
56 |
57 | self.args.dist_url = get_init_file().as_uri()
58 | checkpoint_file = os.path.join(self.args.output_dir, "checkpoint.pth")
59 | if os.path.exists(checkpoint_file):
60 | self.args.resume = checkpoint_file
61 | print("Requeuing ", self.args)
62 | empty_trainer = type(self)(self.args)
63 | return submitit.helpers.DelayedSubmission(empty_trainer)
64 |
65 | def _setup_gpu_args(self):
66 | import submitit
67 | from pathlib import Path
68 |
69 | job_env = submitit.JobEnvironment()
70 | self.args.output_dir = Path(str(self.args.output_dir).replace("%j", str(job_env.job_id)))
71 | self.args.log_dir = self.args.output_dir
72 | self.args.gpu = job_env.local_rank
73 | self.args.rank = job_env.global_rank
74 | self.args.world_size = job_env.num_tasks
75 | print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}")
76 |
77 |
78 | def main():
79 | args = parse_args()
80 | if args.job_dir == "":
81 | args.job_dir = get_shared_folder() / "%j"
82 |
83 | # Note that the folder will depend on the job_id, to easily track experiments
84 | executor = submitit.AutoExecutor(folder=args.job_dir, slurm_max_num_timeout=30)
85 |
86 | num_gpus_per_node = args.ngpus
87 | nodes = args.nodes
88 | timeout_min = args.timeout
89 |
90 | partition = args.partition
91 | kwargs = {}
92 | if args.use_volta32:
93 | kwargs['slurm_constraint'] = 'volta32gb'
94 | if args.comment:
95 | kwargs['slurm_comment'] = args.comment
96 |
97 | executor.update_parameters(
98 | mem_gb=40 * num_gpus_per_node,
99 | gpus_per_node=num_gpus_per_node,
100 | tasks_per_node=num_gpus_per_node, # one task per GPU
101 | cpus_per_task=10,
102 | nodes=nodes,
103 | timeout_min=timeout_min, # max is 60 * 72
104 | # Below are cluster dependent parameters
105 | slurm_partition=partition,
106 | slurm_signal_delay_s=120,
107 | **kwargs
108 | )
109 |
110 | executor.update_parameters(name="mae")
111 |
112 | args.dist_url = get_init_file().as_uri()
113 | args.output_dir = args.job_dir
114 |
115 | trainer = Trainer(args)
116 | job = executor.submit(trainer)
117 |
118 | # print("Submitted job_id:", job.job_id)
119 | print(job.job_id)
120 |
121 |
122 | if __name__ == "__main__":
123 | main()
124 |
--------------------------------------------------------------------------------
/engine_finetune.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # References:
8 | # DeiT: https://github.com/facebookresearch/deit
9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit
10 | # --------------------------------------------------------
11 |
12 | import math
13 | import sys
14 | from typing import Iterable, Optional
15 |
16 | import torch
17 |
18 | from timm.data import Mixup
19 | from timm.utils import accuracy
20 |
21 | import util.misc as misc
22 | import util.lr_sched as lr_sched
23 |
24 |
25 | def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
26 | data_loader: Iterable, optimizer: torch.optim.Optimizer,
27 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
28 | mixup_fn: Optional[Mixup] = None, log_writer=None,
29 | args=None):
30 | model.train(True)
31 | metric_logger = misc.MetricLogger(delimiter=" ")
32 | metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
33 | header = 'Epoch: [{}]'.format(epoch)
34 | print_freq = 20
35 |
36 | accum_iter = args.accum_iter
37 |
38 | optimizer.zero_grad()
39 |
40 | if log_writer is not None:
41 | print('log_dir: {}'.format(log_writer.log_dir))
42 |
43 | for data_iter_step, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
44 |
45 | # we use a per iteration (instead of per epoch) lr scheduler
46 | if data_iter_step % accum_iter == 0:
47 | lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)
48 |
49 | samples = samples.to(device, non_blocking=True)
50 | targets = targets.to(device, non_blocking=True)
51 |
52 | if mixup_fn is not None:
53 | samples, targets = mixup_fn(samples, targets)
54 |
55 | with torch.cuda.amp.autocast():
56 | outputs = model(samples)
57 | loss = criterion(outputs, targets)
58 |
59 | loss_value = loss.item()
60 | if data_iter_step % 10 == 0:
61 | print("Loss is {}".format(loss_value))
62 |
63 | if not math.isfinite(loss_value):
64 | print("Loss is {}, stopping training".format(loss_value))
65 | sys.exit(1)
66 |
67 | loss /= accum_iter
68 | loss_scaler(loss, optimizer, clip_grad=max_norm,
69 | parameters=model.parameters(), create_graph=False,
70 | update_grad=(data_iter_step + 1) % accum_iter == 0)
71 | if (data_iter_step + 1) % accum_iter == 0:
72 | optimizer.zero_grad()
73 |
74 | torch.cuda.synchronize()
75 |
76 | metric_logger.update(loss=loss_value)
77 | min_lr = 10.
78 | max_lr = 0.
79 | for group in optimizer.param_groups:
80 | min_lr = min(min_lr, group["lr"])
81 | max_lr = max(max_lr, group["lr"])
82 |
83 | metric_logger.update(lr=max_lr)
84 |
85 | loss_value_reduce = misc.all_reduce_mean(loss_value)
86 | if log_writer is not None and (data_iter_step + 1) % accum_iter == 0:
87 | """ We use epoch_1000x as the x-axis in tensorboard.
88 | This calibrates different curves when batch size changes.
89 | """
90 | epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
91 | log_writer.add_scalar('loss', loss_value_reduce, epoch_1000x)
92 | log_writer.add_scalar('lr', max_lr, epoch_1000x)
93 |
94 | # gather the stats from all processes
95 | metric_logger.synchronize_between_processes()
96 | print("Averaged stats:", metric_logger)
97 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
98 |
99 |
100 | @torch.no_grad()
101 | def evaluate(data_loader, model, device):
102 | criterion = torch.nn.CrossEntropyLoss()
103 |
104 | metric_logger = misc.MetricLogger(delimiter=" ")
105 | header = 'Test:'
106 |
107 | # switch to evaluation mode
108 | model.eval()
109 |
110 | for batch in metric_logger.log_every(data_loader, 10, header):
111 | images = batch[0]
112 | target = batch[-1]
113 | images = images.to(device, non_blocking=True)
114 | target = target.to(device, non_blocking=True)
115 |
116 | # compute output
117 | with torch.cuda.amp.autocast():
118 | output = model(images)
119 | loss = criterion(output, target)
120 |
121 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
122 |
123 | batch_size = images.shape[0]
124 | metric_logger.update(loss=loss.item())
125 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
126 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
127 | # gather the stats from all processes
128 | metric_logger.synchronize_between_processes()
129 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
130 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))
131 |
132 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
--------------------------------------------------------------------------------
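A self-contained sketch of calling `evaluate` on dummy data; a CUDA device is assumed because of the `torch.cuda.amp.autocast` context, and the tiny linear model is purely illustrative:

```python
# Hypothetical sketch: evaluate() aggregates loss / top-1 / top-5 over a loader.
import torch
from torch.utils.data import DataLoader, TensorDataset

from engine_finetune import evaluate

device = torch.device("cuda")
images, labels = torch.randn(16, 3, 224, 224), torch.randint(0, 10, (16,))
loader = DataLoader(TensorDataset(images, labels), batch_size=8)

model = torch.nn.Sequential(
    torch.nn.Flatten(), torch.nn.Linear(3 * 224 * 224, 10)).to(device)

stats = evaluate(loader, model, device)
print(stats["acc1"], stats["acc5"], stats["loss"])
```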
/main_pretrain.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import datetime
3 | import json
4 | import numpy as np
5 | import os
6 | import time
7 | from pathlib import Path
8 |
9 | import torch
10 | import torch.backends.cudnn as cudnn
11 | from torch.utils.tensorboard import SummaryWriter
12 | import torchvision.transforms as transforms
13 | import torchvision.datasets as datasets
14 |
15 | import timm
16 |
17 | #assert timm.__version__ == "0.3.2" # version check
18 | import timm.optim.optim_factory as optim_factory
19 |
20 | import util.misc as misc
21 | from util.misc import NativeScalerWithGradNormCount as NativeScaler
22 | from util_contrastive import TwoCropTransform
23 | from util_contrastive import GaussianBlur
24 |
25 | import models_mae
26 |
27 | from engine_pretrain import train_one_epoch
28 |
29 |
30 | def get_args_parser():
31 | parser = argparse.ArgumentParser('CAN pre-training', add_help=False)
32 | parser.add_argument('--batch_size', default=64, type=int,
33 | help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus)')
34 | parser.add_argument('--epochs', default=400, type=int)
35 | parser.add_argument('--accum_iter', default=1, type=int,
36 | help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')
37 |
38 | # Model parameters
39 | parser.add_argument('--model', default='mae_vit_large_patch16', type=str, metavar='MODEL',
40 | help='Name of model to train')
41 |
42 | parser.add_argument('--input_size', default=224, type=int,
43 | help='images input size')
44 |
45 | parser.add_argument('--mask_ratio', default=0.75, type=float,
46 | help='Masking ratio (percentage of removed patches).')
47 |
48 | parser.add_argument('--norm_pix_loss', action='store_true',
49 | help='Use (per-patch) normalized pixels as targets for computing loss')
50 | parser.set_defaults(norm_pix_loss=False)
51 |
52 | parser.add_argument('--weight_mae', default=0.97, type=float,
53 | help='Loss weight of mae (default: 0.97).')
54 | parser.add_argument('--weight_simclr', default=0.03, type=float,
55 | help='Loss weight of simclr (default: 0.03).')
56 |
57 |
58 | parser.add_argument('--noise_loss', action='store_true')
59 | parser.add_argument('--std', default=0.05, type=float,
60 | help='Standard deviation of noise added to loss.')
61 | parser.add_argument('--weight_noise', default=0.3, type=float,
62 | help='Weight allocated to noise loss.')
63 |
64 |
65 | # Optimizer parameters
66 | parser.add_argument('--weight_decay', type=float, default=0.05,
67 | help='weight decay (default: 0.05)')
68 |
69 | parser.add_argument('--lr', type=float, default=None, metavar='LR',
70 | help='learning rate (absolute lr)')
71 | parser.add_argument('--blr', type=float, default=1e-3, metavar='LR',
72 | help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
73 | parser.add_argument('--min_lr', type=float, default=0., metavar='LR',
74 | help='lower lr bound for cyclic schedulers that hit 0')
75 |
76 | parser.add_argument('--warmup_epochs', type=int, default=40, metavar='N',
77 | help='epochs to warmup LR')
78 |
79 | # Dataset parameters
80 | parser.add_argument('--data_path', default='/data/scratch/joshrob/data/imagenet100/', type=str,
81 | help='dataset path')
82 |
83 | parser.add_argument('--output_dir', default='./output_dir',
84 | help='path where to save, empty for no saving')
85 | parser.add_argument('--log_dir', default='./output_dir',
86 | help='path where to tensorboard log')
87 | parser.add_argument('--device', default='cuda',
88 | help='device to use for training / testing')
89 | parser.add_argument('--seed', default=0, type=int)
90 | parser.add_argument('--resume', default='',
91 | help='resume from checkpoint')
92 |
93 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
94 | help='start epoch')
95 | parser.add_argument('--num_workers', default=10, type=int)
96 | parser.add_argument('--pin_mem', action='store_true',
97 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
98 | parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
99 | parser.set_defaults(pin_mem=True)
100 |
101 | # distributed training parameters
102 | parser.add_argument('--world_size', default=1, type=int,
103 | help='number of distributed processes')
104 | parser.add_argument('--local_rank', default=-1, type=int)
105 | parser.add_argument('--dist_on_itp', action='store_true')
106 | parser.add_argument('--dist_url', default='env://',
107 | help='url used to set up distributed training')
108 |
109 | return parser
110 |
111 |
112 | def main(args):
113 | misc.init_distributed_mode(args)
114 |
115 | print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
116 | print("{}".format(args).replace(', ', ',\n'))
117 |
118 | device = torch.device(args.device)
119 |
120 | # fix the seed for reproducibility
121 | seed = args.seed + misc.get_rank()
122 | torch.manual_seed(seed)
123 | np.random.seed(seed)
124 |
125 | cudnn.benchmark = True
126 |
127 | # simple augmentation
128 | # transform_train = transforms.Compose([
129 | # transforms.RandomResizedCrop(args.input_size, scale=(0.2, 1.0), interpolation=3), # 3 is bicubic
130 | # transforms.RandomHorizontalFlip(),
131 | # transforms.ToTensor(),
132 | # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
133 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
134 | transform_train = transforms.Compose([
135 | transforms.RandomResizedCrop(size=args.input_size, scale=(0.2, 1.)),
136 | transforms.RandomHorizontalFlip(),
137 | transforms.RandomApply([
138 | transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
139 | ], p=0.8),
140 | transforms.RandomGrayscale(p=0.2),
141 | transforms.RandomApply([GaussianBlur([.1, 2.])], p=0.5),
142 | transforms.ToTensor(),
143 | normalize,
144 | ])
145 | dataset_train = datasets.ImageFolder(os.path.join(args.data_path, 'train'), transform=TwoCropTransform(transform_train))
146 | print(dataset_train)
147 |
148 | if True: # args.distributed:
149 | num_tasks = misc.get_world_size()
150 | global_rank = misc.get_rank()
151 | sampler_train = torch.utils.data.DistributedSampler(
152 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
153 | )
154 | print("Sampler_train = %s" % str(sampler_train))
155 | else:
156 | sampler_train = torch.utils.data.RandomSampler(dataset_train)
157 |
158 | if global_rank == 0 and args.log_dir is not None:
159 | os.makedirs(args.log_dir, exist_ok=True)
160 | log_writer = SummaryWriter(log_dir=args.log_dir)
161 | else:
162 | log_writer = None
163 |
164 | data_loader_train = torch.utils.data.DataLoader(
165 | dataset_train, sampler=sampler_train,
166 | batch_size=args.batch_size,
167 | num_workers=args.num_workers,
168 | pin_memory=args.pin_mem,
169 | drop_last=True,
170 | )
171 |
172 | # define the model
173 | model = models_mae.__dict__[args.model](norm_pix_loss=args.norm_pix_loss, noise_loss=args.noise_loss)
174 |
175 | model.to(device)
176 |
177 | model_without_ddp = model
178 | print("Model = %s" % str(model_without_ddp))
179 |
180 | eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()
181 |
182 | if args.lr is None: # only base_lr is specified
183 | args.lr = args.blr * eff_batch_size / 256
184 |
185 | print("base lr: %.2e" % (args.lr * 256 / eff_batch_size))
186 | print("actual lr: %.2e" % args.lr)
187 |
188 | print("accumulate grad iterations: %d" % args.accum_iter)
189 | print("effective batch size: %d" % eff_batch_size)
190 |
191 | if args.distributed:
192 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
193 | model_without_ddp = model.module
194 |
195 | # following timm: set wd as 0 for bias and norm layers
196 | param_groups = optim_factory.add_weight_decay(model_without_ddp, args.weight_decay)
197 | optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95))
198 | print(optimizer)
199 | loss_scaler = NativeScaler()
200 |
201 | misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler)
202 |
203 | print(f"Start training for {args.epochs} epochs")
204 | start_time = time.time()
205 | for epoch in range(args.start_epoch, args.epochs):
206 | if args.distributed:
207 | data_loader_train.sampler.set_epoch(epoch)
208 | train_stats = train_one_epoch(
209 | model, data_loader_train,
210 | optimizer, device, epoch, loss_scaler,
211 | log_writer=log_writer,
212 | args=args
213 | )
214 | if args.output_dir and (epoch % 1 == 0 or epoch + 1 == args.epochs): # epoch % 1 == 0 is always true, so a checkpoint is saved every epoch
215 | misc.save_model(
216 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
217 | loss_scaler=loss_scaler, epoch=epoch)
218 |
219 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
220 | 'epoch': epoch,}
221 |
222 | if args.output_dir and misc.is_main_process():
223 | if log_writer is not None:
224 | log_writer.flush()
225 | with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
226 | f.write(json.dumps(log_stats) + "\n")
227 |
228 | total_time = time.time() - start_time
229 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
230 | print('Training time {}'.format(total_time_str))
231 |
232 |
233 | if __name__ == '__main__':
234 | args = get_args_parser()
235 | args = args.parse_args()
236 | #args.local_rank = os.environ['LOCAL_RANK']
237 | if args.output_dir:
238 | Path(args.output_dir).mkdir(parents=True, exist_ok=True)
239 | main(args)
240 |
--------------------------------------------------------------------------------
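How the two-crop batches built above are assumed to reach the model: datasets.ImageFolder with TwoCropTransform yields [view1, view2] per sample, and the model's contrastive branch expects the two views concatenated along the batch dimension. A minimal collate sketch under that assumption (the actual collation happens in engine_pretrain.py, which is not part of this listing):

import torch

def collate_two_crops(batch):
    # batch: list of ([view1, view2], target) pairs from ImageFolder + TwoCropTransform
    views, targets = zip(*batch)
    v1 = torch.stack([v[0] for v in views], dim=0)  # [B, 3, 224, 224]
    v2 = torch.stack([v[1] for v in views], dim=0)  # [B, 3, 224, 224]
    # [2B, 3, 224, 224]: the layout MaskedAutoencoderViT.forward splits for SupConLoss
    return torch.cat([v1, v2], dim=0), torch.tensor(targets)

--------------------------------------------------------------------------------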
/util/misc.py:
--------------------------------------------------------------------------------
1 | import builtins
2 | import datetime
3 | import os
4 | import time
5 | from collections import defaultdict, deque
6 | from pathlib import Path
7 |
8 | import torch
9 | import torch.distributed as dist
10 | from math import inf  # equivalent to torch._six.inf, which is removed in newer PyTorch
11 |
12 |
13 | class SmoothedValue(object):
14 | """Track a series of values and provide access to smoothed values over a
15 | window or the global series average.
16 | """
17 |
18 | def __init__(self, window_size=20, fmt=None):
19 | if fmt is None:
20 | fmt = "{median:.4f} ({global_avg:.4f})"
21 | self.deque = deque(maxlen=window_size)
22 | self.total = 0.0
23 | self.count = 0
24 | self.fmt = fmt
25 |
26 | def update(self, value, n=1):
27 | self.deque.append(value)
28 | self.count += n
29 | self.total += value * n
30 |
31 | def synchronize_between_processes(self):
32 | """
33 | Warning: does not synchronize the deque!
34 | """
35 | if not is_dist_avail_and_initialized():
36 | return
37 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
38 | dist.barrier()
39 | dist.all_reduce(t)
40 | t = t.tolist()
41 | self.count = int(t[0])
42 | self.total = t[1]
43 |
44 | @property
45 | def median(self):
46 | d = torch.tensor(list(self.deque))
47 | return d.median().item()
48 |
49 | @property
50 | def avg(self):
51 | d = torch.tensor(list(self.deque), dtype=torch.float32)
52 | return d.mean().item()
53 |
54 | @property
55 | def global_avg(self):
56 | return self.total / self.count
57 |
58 | @property
59 | def max(self):
60 | return max(self.deque)
61 |
62 | @property
63 | def value(self):
64 | return self.deque[-1]
65 |
66 | def __str__(self):
67 | return self.fmt.format(
68 | median=self.median,
69 | avg=self.avg,
70 | global_avg=self.global_avg,
71 | max=self.max,
72 | value=self.value)
73 |
74 |
75 | class MetricLogger(object):
76 | def __init__(self, delimiter="\t"):
77 | self.meters = defaultdict(SmoothedValue)
78 | self.delimiter = delimiter
79 |
80 | def update(self, **kwargs):
81 | for k, v in kwargs.items():
82 | if v is None:
83 | continue
84 | if isinstance(v, torch.Tensor):
85 | v = v.item()
86 | assert isinstance(v, (float, int))
87 | self.meters[k].update(v)
88 |
89 | def __getattr__(self, attr):
90 | if attr in self.meters:
91 | return self.meters[attr]
92 | if attr in self.__dict__:
93 | return self.__dict__[attr]
94 | raise AttributeError("'{}' object has no attribute '{}'".format(
95 | type(self).__name__, attr))
96 |
97 | def __str__(self):
98 | loss_str = []
99 | for name, meter in self.meters.items():
100 | loss_str.append(
101 | "{}: {}".format(name, str(meter))
102 | )
103 | return self.delimiter.join(loss_str)
104 |
105 | def synchronize_between_processes(self):
106 | for meter in self.meters.values():
107 | meter.synchronize_between_processes()
108 |
109 | def add_meter(self, name, meter):
110 | self.meters[name] = meter
111 |
112 | def log_every(self, iterable, print_freq, header=None):
113 | i = 0
114 | if not header:
115 | header = ''
116 | start_time = time.time()
117 | end = time.time()
118 | iter_time = SmoothedValue(fmt='{avg:.4f}')
119 | data_time = SmoothedValue(fmt='{avg:.4f}')
120 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
121 | log_msg = [
122 | header,
123 | '[{0' + space_fmt + '}/{1}]',
124 | 'eta: {eta}',
125 | '{meters}',
126 | 'time: {time}',
127 | 'data: {data}'
128 | ]
129 | if torch.cuda.is_available():
130 | log_msg.append('max mem: {memory:.0f}')
131 | log_msg = self.delimiter.join(log_msg)
132 | MB = 1024.0 * 1024.0
133 | for obj in iterable:
134 | data_time.update(time.time() - end)
135 | yield obj
136 | iter_time.update(time.time() - end)
137 | if i % print_freq == 0 or i == len(iterable) - 1:
138 | eta_seconds = iter_time.global_avg * (len(iterable) - i)
139 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
140 | if torch.cuda.is_available():
141 | print(log_msg.format(
142 | i, len(iterable), eta=eta_string,
143 | meters=str(self),
144 | time=str(iter_time), data=str(data_time),
145 | memory=torch.cuda.max_memory_allocated() / MB))
146 | else:
147 | print(log_msg.format(
148 | i, len(iterable), eta=eta_string,
149 | meters=str(self),
150 | time=str(iter_time), data=str(data_time)))
151 | i += 1
152 | end = time.time()
153 | total_time = time.time() - start_time
154 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
155 | print('{} Total time: {} ({:.4f} s / it)'.format(
156 | header, total_time_str, total_time / len(iterable)))
157 |
158 |
159 | def setup_for_distributed(is_master):
160 | """
161 | This function disables printing when not in the master process
162 | """
163 | builtin_print = builtins.print
164 |
165 | def print(*args, **kwargs):
166 | force = kwargs.pop('force', False)
167 | force = force or (get_world_size() > 8)
168 | if is_master or force:
169 | now = datetime.datetime.now().time()
170 | builtin_print('[{}] '.format(now), end='') # print with time stamp
171 | builtin_print(*args, **kwargs)
172 |
173 | builtins.print = print
174 |
175 |
176 | def is_dist_avail_and_initialized():
177 | if not dist.is_available():
178 | return False
179 | if not dist.is_initialized():
180 | return False
181 | return True
182 |
183 |
184 | def get_world_size():
185 | if not is_dist_avail_and_initialized():
186 | return 1
187 | return dist.get_world_size()
188 |
189 |
190 | def get_rank():
191 | if not is_dist_avail_and_initialized():
192 | return 0
193 | return dist.get_rank()
194 |
195 |
196 | def is_main_process():
197 | return get_rank() == 0
198 |
199 |
200 | def save_on_master(*args, **kwargs):
201 | if is_main_process():
202 | torch.save(*args, **kwargs)
203 |
204 |
205 | def init_distributed_mode(args):
206 | if args.dist_on_itp:
207 | args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
208 | args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
209 | args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
210 | args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
211 | os.environ['LOCAL_RANK'] = str(args.gpu)
212 | os.environ['RANK'] = str(args.rank)
213 | os.environ['WORLD_SIZE'] = str(args.world_size)
214 | # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
215 | elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
216 | args.rank = int(os.environ["RANK"])
217 | args.world_size = int(os.environ['WORLD_SIZE'])
218 | args.gpu = int(os.environ['LOCAL_RANK'])
219 | elif 'SLURM_PROCID' in os.environ:
220 | args.rank = int(os.environ['SLURM_PROCID'])
221 | args.gpu = args.rank % torch.cuda.device_count()
222 | else:
223 | print('Not using distributed mode')
224 | setup_for_distributed(is_master=True) # hack
225 | args.distributed = False
226 | return
227 |
228 | args.distributed = True
229 | print(args.gpu)
230 | torch.cuda.set_device(args.gpu)
231 | args.dist_backend = 'nccl'
232 | print('| distributed init (rank {}): {}, gpu {}'.format(
233 | args.rank, args.dist_url, args.gpu), flush=True)
234 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
235 | world_size=args.world_size, rank=args.rank)
236 | torch.distributed.barrier()
237 | setup_for_distributed(args.rank == 0)
238 |
239 |
240 | class NativeScalerWithGradNormCount:
241 | state_dict_key = "amp_scaler"
242 |
243 | def __init__(self):
244 | self._scaler = torch.cuda.amp.GradScaler()
245 |
246 | def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
247 | self._scaler.scale(loss).backward(create_graph=create_graph)
248 | if update_grad:
249 | if clip_grad is not None:
250 | assert parameters is not None
251 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place
252 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
253 | else:
254 | self._scaler.unscale_(optimizer)
255 | norm = get_grad_norm_(parameters)
256 | self._scaler.step(optimizer)
257 | self._scaler.update()
258 | else:
259 | norm = None
260 | return norm
261 |
262 | def state_dict(self):
263 | return self._scaler.state_dict()
264 |
265 | def load_state_dict(self, state_dict):
266 | self._scaler.load_state_dict(state_dict)
267 |
268 |
269 | def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
270 | if isinstance(parameters, torch.Tensor):
271 | parameters = [parameters]
272 | parameters = [p for p in parameters if p.grad is not None]
273 | norm_type = float(norm_type)
274 | if len(parameters) == 0:
275 | return torch.tensor(0.)
276 | device = parameters[0].grad.device
277 | if norm_type == inf:
278 | total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
279 | else:
280 | total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
281 | return total_norm
282 |
283 |
284 | def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler):
285 | output_dir = Path(args.output_dir)
286 | epoch_name = str(epoch)
287 | if loss_scaler is not None:
288 | checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)]
289 | for checkpoint_path in checkpoint_paths:
290 | to_save = {
291 | 'model': model_without_ddp.state_dict(),
292 | 'optimizer': optimizer.state_dict(),
293 | 'epoch': epoch,
294 | 'scaler': loss_scaler.state_dict(),
295 | 'args': args,
296 | }
297 |
298 | save_on_master(to_save, checkpoint_path)
299 | else:
300 | client_state = {'epoch': epoch}
301 | model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-%s" % epoch_name, client_state=client_state)
302 |
303 |
304 | def load_model(args, model_without_ddp, optimizer, loss_scaler):
305 | if args.resume:
306 | if args.resume.startswith('https'):
307 | checkpoint = torch.hub.load_state_dict_from_url(
308 | args.resume, map_location='cpu', check_hash=True)
309 | else:
310 | checkpoint = torch.load(args.resume, map_location='cpu')
311 | model_without_ddp.load_state_dict(checkpoint['model'])
312 | print("Resume checkpoint %s" % args.resume)
313 | if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval):
314 | optimizer.load_state_dict(checkpoint['optimizer'])
315 | args.start_epoch = checkpoint['epoch'] + 1
316 | if 'scaler' in checkpoint:
317 | loss_scaler.load_state_dict(checkpoint['scaler'])
318 | print("With optim & sched!")
319 |
320 |
321 | def all_reduce_mean(x):
322 | world_size = get_world_size()
323 | if world_size > 1:
324 | x_reduce = torch.tensor(x).cuda()
325 | dist.all_reduce(x_reduce)
326 | x_reduce /= world_size
327 | return x_reduce.item()
328 | else:
329 | return x
--------------------------------------------------------------------------------
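Usage sketch for MetricLogger and SmoothedValue as defined above; the loop body is illustrative, and any sized iterable works with log_every:

import torch

logger = MetricLogger(delimiter="  ")
logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
for _ in logger.log_every(range(100), print_freq=20, header='Epoch: [0]'):
    loss = torch.rand(1).item()            # stand-in for a real training loss
    logger.update(loss=loss, lr=1e-4)
logger.synchronize_between_processes()     # no-op unless torch.distributed is initialized
print("Averaged stats:", logger)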
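A minimal AMP training step using NativeScalerWithGradNormCount (aliased to NativeScaler in the entry scripts); model, optimizer, and images are assumed to exist:

import torch

loss_scaler = NativeScalerWithGradNormCount()
with torch.cuda.amp.autocast():
    loss = model(images)[0]                # hypothetical forward returning the loss first
optimizer.zero_grad()
# scales the loss, backprops, unscales, optionally clips, then steps and updates the scaler
grad_norm = loss_scaler(loss, optimizer, clip_grad=1.0, parameters=model.parameters())

--------------------------------------------------------------------------------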
/models_mae.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | from timm.models.vision_transformer import PatchEmbed, Block
8 | from loss_contrastive import SupConLoss
9 |
10 | from util.pos_embed import get_2d_sincos_pos_embed, get_1d_sincos_pos_embed
11 |
12 |
13 | class MaskedAutoencoderViT(nn.Module):
14 | """ Masked Autoencoder with VisionTransformer backbone
15 | """
16 | def __init__(self, img_size=224, patch_size=16, in_chans=3,
17 | embed_dim=1024, depth=24, num_heads=16,
18 | decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
19 | mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False,
20 | noise_loss=False, std=0.1, pe_dims=128):
21 | super().__init__()
22 |
23 | # --------------------------------------------------------------------------
24 | # MAE encoder specifics
25 | self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
26 | num_patches = self.patch_embed.num_patches
27 |
28 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
29 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False) # fixed sin-cos embedding
30 |
31 | self.blocks = nn.ModuleList([
32 | Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer)
33 | for i in range(depth)])
34 | self.norm = norm_layer(embed_dim)
35 |
36 | # projection head changes
37 | feat_dim = 128
38 | self.projection_head = nn.Sequential(
39 | nn.Linear(embed_dim, embed_dim),
40 | nn.ReLU(inplace=True),
41 | nn.Linear(embed_dim, feat_dim)
42 | )
43 |
44 | # noise loss specifics
45 | self.noise_loss = noise_loss
46 | self.std = std
47 | self.pe_dims = pe_dims
48 | self.noise_pe_mlp = nn.Sequential(
49 | nn.Linear(pe_dims, embed_dim),
50 | nn.ReLU(inplace=True),
51 | nn.Linear(embed_dim, embed_dim)
52 | )
53 |
54 | # --------------------------------------------------------------------------
55 |
56 | # --------------------------------------------------------------------------
57 | # MAE decoder specifics
58 | self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)
59 |
60 | self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
61 |
62 | self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False) # fixed sin-cos embedding
63 |
64 | self.decoder_blocks = nn.ModuleList([
65 | Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer)
66 | for i in range(decoder_depth)])
67 |
68 | self.decoder_norm = norm_layer(decoder_embed_dim)
69 | self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True) # decoder to patch
70 | # --------------------------------------------------------------------------
71 |
72 | self.norm_pix_loss = norm_pix_loss
73 |
74 | self.initialize_weights()
75 |
76 | def initialize_weights(self):
77 | # initialization
78 | # initialize (and freeze) pos_embed by sin-cos embedding
79 | pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)
80 | self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
81 |
82 | decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)
83 | self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))
84 |
85 | # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
86 | w = self.patch_embed.proj.weight.data
87 | torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
88 |
89 | # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
90 | torch.nn.init.normal_(self.cls_token, std=.02)
91 | torch.nn.init.normal_(self.mask_token, std=.02)
92 |
93 | # initialize nn.Linear and nn.LayerNorm
94 | self.apply(self._init_weights)
95 |
96 | def _init_weights(self, m):
97 | if isinstance(m, nn.Linear):
98 | # we use xavier_uniform following official JAX ViT:
99 | torch.nn.init.xavier_uniform_(m.weight)
100 | if isinstance(m, nn.Linear) and m.bias is not None:
101 | nn.init.constant_(m.bias, 0)
102 | elif isinstance(m, nn.LayerNorm):
103 | nn.init.constant_(m.bias, 0)
104 | nn.init.constant_(m.weight, 1.0)
105 |
106 | def patchify(self, imgs):
107 | """
108 | imgs: (N, 3, H, W)
109 | x: (N, L, patch_size**2 *3)
110 | """
111 | p = self.patch_embed.patch_size[0]
112 | assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0
113 |
114 | h = w = imgs.shape[2] // p
115 | x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
116 | x = torch.einsum('nchpwq->nhwpqc', x)
117 | x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
118 | return x
119 |
120 | def unpatchify(self, x):
121 | """
122 | x: (N, L, patch_size**2 *3)
123 | imgs: (N, 3, H, W)
124 | """
125 | p = self.patch_embed.patch_size[0]
126 | h = w = int(x.shape[1]**.5)
127 | assert h * w == x.shape[1]
128 |
129 | x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
130 | x = torch.einsum('nhwpqc->nchpwq', x)
131 | imgs = x.reshape(shape=(x.shape[0], 3, h * p, w * p))
132 | return imgs
133 |
134 | def random_masking(self, x, mask_ratio):
135 | """
136 | Perform per-sample random masking by per-sample shuffling.
137 | Per-sample shuffling is done by argsort random noise.
138 | x: [N, L, D], sequence
139 | """
140 | N, L, D = x.shape # batch, length, dim
141 | len_keep = int(L * (1 - mask_ratio))
142 |
143 | noise = torch.rand(N, L, device=x.device) # noise in [0, 1]
144 |
145 | # sort noise for each sample
146 | ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove
147 | ids_restore = torch.argsort(ids_shuffle, dim=1)
148 |
149 | # keep the first subset
150 | ids_keep = ids_shuffle[:, :len_keep]
151 | x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
152 |
153 | # generate the binary mask: 0 is keep, 1 is remove
154 | mask = torch.ones([N, L], device=x.device)
155 | mask[:, :len_keep] = 0
156 | # unshuffle to get the binary mask
157 | mask = torch.gather(mask, dim=1, index=ids_restore)
158 |
159 | return x_masked, mask, ids_restore
160 |
161 | def forward_encoder(self, x, mask_ratio):
162 | # embed patches
163 | x = self.patch_embed(x)
164 |
165 | # add pos embed w/o cls token
166 | x = x + self.pos_embed[:, 1:, :]
167 |
168 | # masking: length -> length * mask_ratio
169 | x, mask, ids_restore = self.random_masking(x, mask_ratio)
170 |
171 | # append cls token
172 | cls_token = self.cls_token + self.pos_embed[:, :1, :]
173 | cls_tokens = cls_token.expand(x.shape[0], -1, -1)
174 | x = torch.cat((cls_tokens, x), dim=1)
175 |
176 |
177 | # apply Transformer blocks
178 | for blk in self.blocks:
179 | x = blk(x)
180 | x = self.norm(x)
181 |
182 | return x, mask, ids_restore
183 |
184 | def forward_decoder(self, x, ids_restore):
185 | # embed tokens
186 | x = self.decoder_embed(x)
187 |
188 | # append mask tokens to sequence
189 | mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
190 | x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token
191 | x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle
192 | x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token
193 |
194 | # add pos embed
195 | x = x + self.decoder_pos_embed
196 |
197 | # apply Transformer blocks
198 | for blk in self.decoder_blocks:
199 | x = blk(x)
200 | x = self.decoder_norm(x)
201 |
202 | # predictor projection
203 | x = self.decoder_pred(x)
204 |
205 | # remove cls token
206 | x = x[:, 1:, :]
207 |
208 | return x
209 |
210 | def forward_loss(self, imgs, pred, mask, noise=None):
211 | """
212 | imgs: [N, 3, H, W]
213 | pred: [N, L, p*p*3]
214 | mask: [N, L], 0 is keep, 1 is remove,
215 | """
216 | target = self.patchify(imgs)
217 | if self.norm_pix_loss:
218 | mean = target.mean(dim=-1, keepdim=True)
219 | var = target.var(dim=-1, keepdim=True)
220 | target = (target - mean) / (var + 1.e-6)**.5
221 |
222 | loss = (pred - target) ** 2
223 | loss = loss.mean(dim=-1) # [N, L], mean loss per patch
224 |
225 | loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches
226 |
227 | loss_noise = -1 # placeholder when the denoising loss is disabled
228 | if self.noise_loss:
229 | noise = self.patchify(noise)
230 | loss_noise = (pred - noise) ** 2
231 | loss_noise = loss_noise.mean(dim=-1) # [N, L], mean loss per patch
232 | loss_noise = (loss_noise * (1-mask)).sum() / (1-mask).sum() # mean loss on visible (kept) patches
233 |
234 | return loss, loss_noise
235 |
236 | def forward(self, imgs, mask_ratio=0.75):
237 | if self.noise_loss:
238 | noise_level = self.std * torch.rand(imgs.shape[0]).to(imgs.device)
239 | noise = noise_level[:, None, None, None] * torch.randn(imgs.shape).to(imgs.device)
240 | imgs = imgs + noise # avoid in-place mutation of the caller's batch
241 | latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio)
242 |
243 | # Contrastive loss
244 | bsz = imgs.shape[0] // 2 # the batch holds two crops per image concatenated along dim 0; split into view pairs for SupConLoss below
245 |
246 |
247 | latent_contrastive = latent.mean(dim=1, keepdim=False)
248 | latent_contrastive = self.projection_head(latent_contrastive)
249 |
250 | # import pdb
251 | # pdb.set_trace()
252 | features = F.normalize(latent_contrastive, dim=-1)
253 | f1, f2 = torch.split(features, [bsz, bsz], dim=0)
254 | features = torch.cat([f1.unsqueeze(1), f2.unsqueeze(1)], dim=1)
255 | loss_contrastive = SupConLoss()(features)
256 | # print(loss_contrastive)
257 | if self.noise_loss:
258 | noise_pe = get_1d_sincos_pos_embed(noise_level, dim=self.pe_dims)
259 | noise_pe = self.noise_pe_mlp(noise_pe)
260 |
261 | noise_pe = torch.cat(latent.shape[1] * [noise_pe.unsqueeze(1)], dim=1)
262 | latent += noise_pe
263 |
264 | # mae features
265 | pred = self.forward_decoder(latent, ids_restore) # [N, L, p*p*3]
266 | loss, loss_noise = self.forward_loss(imgs, pred, mask, noise=noise if self.noise_loss else None)
267 | return loss, loss_contrastive, loss_noise, pred, mask
268 |
269 |
270 | def mae_vit_base_patch16_dec512d8b(**kwargs):
271 | model = MaskedAutoencoderViT(
272 | patch_size=16, embed_dim=768, depth=12, num_heads=12,
273 | decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
274 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
275 | return model
276 |
277 |
278 | def mae_vit_large_patch16_dec512d8b(**kwargs):
279 | model = MaskedAutoencoderViT(
280 | patch_size=16, embed_dim=1024, depth=24, num_heads=16,
281 | decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
282 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
283 | return model
284 |
285 |
286 | def mae_vit_huge_patch14_dec512d8b(**kwargs):
287 | model = MaskedAutoencoderViT(
288 | patch_size=14, embed_dim=1280, depth=32, num_heads=16,
289 | decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
290 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
291 | return model
292 |
293 |
294 | # set recommended archs
295 | mae_vit_base_patch16 = mae_vit_base_patch16_dec512d8b # decoder: 512 dim, 8 blocks
296 | mae_vit_large_patch16 = mae_vit_large_patch16_dec512d8b # decoder: 512 dim, 8 blocks
297 | mae_vit_huge_patch14 = mae_vit_huge_patch14_dec512d8b # decoder: 512 dim, 8 blocks
298 |
--------------------------------------------------------------------------------
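A worked illustration of the argsort-based masking in random_masking above, on a tiny example (shapes follow the docstring):

import torch

x = torch.arange(8, dtype=torch.float32).reshape(1, 4, 2)  # N=1, L=4 patches, D=2
noise = torch.tensor([[0.9, 0.1, 0.4, 0.7]])               # one random score per patch
ids_shuffle = torch.argsort(noise, dim=1)                  # [[1, 2, 3, 0]], ascending
ids_keep = ids_shuffle[:, :2]                              # mask_ratio=0.5 keeps patches 1 and 2
x_masked = torch.gather(x, 1, ids_keep.unsqueeze(-1).repeat(1, 1, 2))
ids_restore = torch.argsort(ids_shuffle, dim=1)            # [[3, 0, 1, 2]]: undoes the shuffle for the decoder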
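End-to-end sketch of the objective exposed by forward(): reconstruction, contrastive, and (optionally) denoising losses. The combination below is a placeholder; the real weighting is applied in engine_pretrain.py, outside this listing:

import torch

model = mae_vit_base_patch16(norm_pix_loss=False, noise_loss=False)
imgs = torch.randn(8, 3, 224, 224)   # 2 views x 4 images, concatenated along dim 0
loss_rec, loss_con, loss_noise, pred, mask = model(imgs, mask_ratio=0.75)
# with noise_loss=False, loss_noise is the placeholder -1; enabling noise_loss adds
# the denoising term computed on the visible patches
total_loss = loss_rec + loss_con     # illustrative; actual weights live in engine_pretrain.py

--------------------------------------------------------------------------------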
/main_linprobe.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import datetime
3 | import json
4 | import numpy as np
5 | import os
6 | import time
7 | from pathlib import Path
8 |
9 | import torch
10 | import torch.backends.cudnn as cudnn
11 | from torch.utils.tensorboard import SummaryWriter
12 | import torchvision.transforms as transforms
13 | import torchvision.datasets as datasets
14 |
15 | import timm
16 |
17 | assert timm.__version__ == "0.3.2" # version check
18 | from timm.models.layers import trunc_normal_
19 |
20 | import util.misc as misc
21 | from util.pos_embed import interpolate_pos_embed
22 | from util.misc import NativeScalerWithGradNormCount as NativeScaler
23 | from util.lars import LARS
24 | from util.crop import RandomResizedCrop
25 |
26 | import models_vit
27 |
28 | from engine_finetune import train_one_epoch, evaluate
29 |
30 |
31 | def get_args_parser():
32 | parser = argparse.ArgumentParser('CAN linear probing for image classification', add_help=False)
33 | parser.add_argument('--batch_size', default=512, type=int,
34 | help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus)')
35 | parser.add_argument('--epochs', default=90, type=int)
36 | parser.add_argument('--accum_iter', default=1, type=int,
37 | help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')
38 |
39 | # Model parameters
40 | parser.add_argument('--model', default='vit_large_patch16', type=str, metavar='MODEL',
41 | help='Name of model to train')
42 |
43 | # Optimizer parameters
44 | parser.add_argument('--weight_decay', type=float, default=0,
45 | help='weight decay (default: 0 for linear probe following MoCo v1)')
46 |
47 | parser.add_argument('--lr', type=float, default=None, metavar='LR',
48 | help='learning rate (absolute lr)')
49 | parser.add_argument('--blr', type=float, default=0.1, metavar='LR',
50 | help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
51 |
52 | parser.add_argument('--min_lr', type=float, default=0., metavar='LR',
53 | help='lower lr bound for cyclic schedulers that hit 0')
54 |
55 | parser.add_argument('--warmup_epochs', type=int, default=10, metavar='N',
56 | help='epochs to warmup LR')
57 |
58 | # * Finetuning params
59 | parser.add_argument('--finetune', default='',
60 | help='finetune from checkpoint')
61 | parser.add_argument('--global_pool', action='store_true')
62 | parser.set_defaults(global_pool=False)
63 | parser.add_argument('--cls_token', action='store_false', dest='global_pool',
64 | help='Use class token instead of global pool for classification')
65 |
66 | # Dataset parameters
67 | parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str,
68 | help='dataset path')
69 | parser.add_argument('--nb_classes', default=1000, type=int,
70 | help='number of the classification types')
71 |
72 | parser.add_argument('--output_dir', default='./output_dir',
73 | help='path where to save, empty for no saving')
74 | parser.add_argument('--log_dir', default='./output_dir',
75 | help='path where to tensorboard log')
76 | parser.add_argument('--device', default='cuda',
77 | help='device to use for training / testing')
78 | parser.add_argument('--seed', default=0, type=int)
79 | parser.add_argument('--resume', default='',
80 | help='resume from checkpoint')
81 |
82 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
83 | help='start epoch')
84 | parser.add_argument('--eval', action='store_true',
85 | help='Perform evaluation only')
86 | parser.add_argument('--dist_eval', action='store_true', default=False,
87 | help='Enable distributed evaluation (recommended during training for faster monitoring)')
88 | parser.add_argument('--num_workers', default=10, type=int)
89 | parser.add_argument('--pin_mem', action='store_true',
90 | help='Pin CPU memory in DataLoader for (sometimes) more efficient transfer to GPU.')
91 | parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
92 | parser.set_defaults(pin_mem=True)
93 |
94 | # distributed training parameters
95 | parser.add_argument('--world_size', default=1, type=int,
96 | help='number of distributed processes')
97 | parser.add_argument('--local_rank', default=-1, type=int)
98 | parser.add_argument('--dist_on_itp', action='store_true')
99 | parser.add_argument('--dist_url', default='env://',
100 | help='url used to set up distributed training')
101 |
102 | return parser
103 |
104 |
105 | def main(args):
106 | misc.init_distributed_mode(args)
107 |
108 | print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
109 | print("{}".format(args).replace(', ', ',\n'))
110 |
111 | device = torch.device(args.device)
112 |
113 | # fix the seed for reproducibility
114 | seed = args.seed + misc.get_rank()
115 | torch.manual_seed(seed)
116 | np.random.seed(seed)
117 |
118 | cudnn.benchmark = True
119 |
120 | # linear probe: weak augmentation
121 | transform_train = transforms.Compose([
122 | RandomResizedCrop(224, interpolation=3),
123 | transforms.RandomHorizontalFlip(),
124 | transforms.ToTensor(),
125 | # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
126 | transforms.Normalize(mean=[0.485*255, 0.456*255, 0.406*255], std=[0.229*255, 0.224*255, 0.225*255])])
127 | transform_val = transforms.Compose([
128 | transforms.Resize(256, interpolation=3),
129 | transforms.CenterCrop(224),
130 | transforms.ToTensor(),
131 | # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
132 | transforms.Normalize(mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
133 | std=[0.229 * 255, 0.224 * 255, 0.225 * 255])])
134 | dataset_train = datasets.ImageFolder(os.path.join(args.data_path, 'train'), transform=transform_train)
135 | dataset_val = datasets.ImageFolder(os.path.join(args.data_path, 'val'), transform=transform_val)
136 | print(dataset_train)
137 | print(dataset_val)
138 |
139 | if True: # args.distributed: always use the distributed sampler path; valid even when world_size == 1
140 | num_tasks = misc.get_world_size()
141 | global_rank = misc.get_rank()
142 | sampler_train = torch.utils.data.DistributedSampler(
143 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
144 | )
145 | print("Sampler_train = %s" % str(sampler_train))
146 | if args.dist_eval:
147 | if len(dataset_val) % num_tasks != 0:
148 | print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
149 | 'This will slightly alter validation results as extra duplicate entries are added to achieve '
150 | 'equal num of samples per-process.')
151 | sampler_val = torch.utils.data.DistributedSampler(
152 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=True) # shuffle=True to reduce monitor bias
153 | else:
154 | sampler_val = torch.utils.data.SequentialSampler(dataset_val)
155 | else:
156 | sampler_train = torch.utils.data.RandomSampler(dataset_train)
157 | sampler_val = torch.utils.data.SequentialSampler(dataset_val)
158 |
159 | if global_rank == 0 and args.log_dir is not None and not args.eval:
160 | os.makedirs(args.log_dir, exist_ok=True)
161 | log_writer = SummaryWriter(log_dir=args.log_dir)
162 | else:
163 | log_writer = None
164 |
165 | data_loader_train = torch.utils.data.DataLoader(
166 | dataset_train, sampler=sampler_train,
167 | batch_size=args.batch_size,
168 | num_workers=args.num_workers,
169 | pin_memory=args.pin_mem,
170 | drop_last=True,
171 | )
172 |
173 | data_loader_val = torch.utils.data.DataLoader(
174 | dataset_val, sampler=sampler_val,
175 | batch_size=args.batch_size,
176 | num_workers=args.num_workers,
177 | pin_memory=args.pin_mem,
178 | drop_last=False
179 | )
180 |
181 | model = models_vit.__dict__[args.model](
182 | num_classes=args.nb_classes,
183 | global_pool=args.global_pool,
184 | )
185 |
186 | if args.finetune and not args.eval:
187 | checkpoint = torch.load(args.finetune, map_location='cpu')
188 | # import pdb
189 | # pdb.set_trace()
190 |
191 | print("Load pre-trained checkpoint from: %s" % args.finetune)
192 | checkpoint_model = checkpoint['model']
193 | state_dict = model.state_dict()
194 | # print(checkpoint_model.keys())
195 | for k in ['head.weight', 'head.bias']:
196 | if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
197 | print(f"Removing key {k} from pretrained checkpoint")
198 | del checkpoint_model[k]
199 | # if k.startwith('patch_embed'):
200 | # print(k)
201 |
202 | # interpolate position embedding
203 | interpolate_pos_embed(model, checkpoint_model)
204 |
205 | # load pre-trained model
206 | msg = model.load_state_dict(checkpoint_model, strict=False)
207 | print(msg)
208 |
209 | if args.global_pool:
210 | assert set(msg.missing_keys) == {'head.weight', 'head.bias', 'fc_norm.weight', 'fc_norm.bias'}
211 | else:
212 | assert set(msg.missing_keys) == {'head.weight', 'head.bias'}
213 |
214 | # manually initialize fc layer: following MoCo v3
215 | trunc_normal_(model.head.weight, std=0.01)
216 |
217 | # for linear prob only
218 | # hack: revise model's head with BN
219 | model.head = torch.nn.Sequential(torch.nn.BatchNorm1d(model.head.in_features, affine=False, eps=1e-6), model.head)
220 | # freeze all but the head
221 | for _, p in model.named_parameters():
222 | p.requires_grad = False
223 | for _, p in model.head.named_parameters():
224 | p.requires_grad = True
225 |
226 | model.to(device)
227 |
228 | model_without_ddp = model
229 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
230 |
231 | print("Model = %s" % str(model_without_ddp))
232 | print('number of params (M): %.2f' % (n_parameters / 1.e6))
233 |
234 | eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()
235 |
236 | if args.lr is None: # only base_lr is specified
237 | args.lr = args.blr * eff_batch_size / 256
238 |
239 | print("base lr: %.2e" % (args.lr * 256 / eff_batch_size))
240 | print("actual lr: %.2e" % args.lr)
241 |
242 | print("accumulate grad iterations: %d" % args.accum_iter)
243 | print("effective batch size: %d" % eff_batch_size)
244 |
245 | if args.distributed:
246 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
247 | model_without_ddp = model.module
248 |
249 | optimizer = LARS(model_without_ddp.head.parameters(), lr=args.lr, weight_decay=args.weight_decay)
250 | print(optimizer)
251 | loss_scaler = NativeScaler()
252 |
253 | criterion = torch.nn.CrossEntropyLoss()
254 |
255 | print("criterion = %s" % str(criterion))
256 |
257 | misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler)
258 |
259 | if args.eval:
260 | test_stats = evaluate(data_loader_val, model, device)
261 | print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
262 | exit(0)
263 |
264 | print(f"Start training for {args.epochs} epochs")
265 | start_time = time.time()
266 | max_accuracy = 0.0
267 | for epoch in range(args.start_epoch, args.epochs):
268 | if args.distributed:
269 | data_loader_train.sampler.set_epoch(epoch)
270 | train_stats = train_one_epoch(
271 | model, criterion, data_loader_train,
272 | optimizer, device, epoch, loss_scaler,
273 | max_norm=None,
274 | log_writer=log_writer,
275 | args=args
276 | )
277 | if args.output_dir:
278 | misc.save_model(
279 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
280 | loss_scaler=loss_scaler, epoch=epoch)
281 |
282 | test_stats = evaluate(data_loader_val, model, device)
283 | print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
284 | max_accuracy = max(max_accuracy, test_stats["acc1"])
285 | print(f'Max accuracy: {max_accuracy:.2f}%')
286 |
287 | if log_writer is not None:
288 | log_writer.add_scalar('perf/test_acc1', test_stats['acc1'], epoch)
289 | log_writer.add_scalar('perf/test_acc5', test_stats['acc5'], epoch)
290 | log_writer.add_scalar('perf/test_loss', test_stats['loss'], epoch)
291 |
292 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
293 | **{f'test_{k}': v for k, v in test_stats.items()},
294 | 'epoch': epoch,
295 | 'n_parameters': n_parameters}
296 |
297 | if args.output_dir and misc.is_main_process():
298 | if log_writer is not None:
299 | log_writer.flush()
300 | with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
301 | f.write(json.dumps(log_stats) + "\n")
302 |
303 | total_time = time.time() - start_time
304 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
305 | print('Training time {}'.format(total_time_str))
306 |
307 |
308 | if __name__ == '__main__':
309 | args = get_args_parser()
310 | args = args.parse_args()
311 | if args.output_dir:
312 | Path(args.output_dir).mkdir(parents=True, exist_ok=True)
313 | main(args)
314 |
--------------------------------------------------------------------------------
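The linear-probe recipe above, condensed: a no-affine BatchNorm in front of the classifier head, backbone frozen, and LARS applied to the head only. A self-contained sketch (the model name and lr are placeholders):

import torch.nn as nn
import models_vit
from util.lars import LARS

model = models_vit.__dict__['vit_large_patch16'](num_classes=1000, global_pool=False)
model.head = nn.Sequential(
    nn.BatchNorm1d(model.head.in_features, affine=False, eps=1e-6),  # hack: BN before the fc head
    model.head,
)
for p in model.parameters():
    p.requires_grad = False          # freeze everything...
for p in model.head.parameters():
    p.requires_grad = True           # ...except the (BN + fc) head
optimizer = LARS(model.head.parameters(), lr=0.1, weight_decay=0.0)  # absolute lr = blr * eff_batch_size / 256 in the script

--------------------------------------------------------------------------------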
/main_finetune.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import datetime
3 | import json
4 | import numpy as np
5 | import os
6 | import time
7 | from pathlib import Path
8 |
9 | import torch
10 | import torch.backends.cudnn as cudnn
11 | from torch.utils.tensorboard import SummaryWriter
12 |
13 | import timm
14 |
15 | assert timm.__version__ == "0.3.2" # version check
16 | from timm.models.layers import trunc_normal_
17 | from timm.data.mixup import Mixup
18 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
19 |
20 | import util.lr_decay as lrd
21 | import util.misc as misc
22 | from util.datasets import build_dataset
23 | from util.pos_embed import interpolate_pos_embed
24 | from util.misc import NativeScalerWithGradNormCount as NativeScaler
25 |
26 | import models_vit
27 |
28 | from engine_finetune import train_one_epoch, evaluate
29 |
30 |
31 | def get_args_parser():
32 | parser = argparse.ArgumentParser('CAN fine-tuning for image classification', add_help=False)
33 | parser.add_argument('--batch_size', default=64, type=int,
34 | help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus)')
35 | parser.add_argument('--epochs', default=50, type=int)
36 | parser.add_argument('--accum_iter', default=1, type=int,
37 | help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')
38 |
39 | # Model parameters
40 | parser.add_argument('--model', default='vit_large_patch16', type=str, metavar='MODEL',
41 | help='Name of model to train')
42 |
43 | parser.add_argument('--input_size', default=224, type=int,
44 | help='images input size')
45 |
46 | parser.add_argument('--drop_path', type=float, default=0.1, metavar='PCT',
47 | help='Drop path rate (default: 0.1)')
48 |
49 | # Optimizer parameters
50 | parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM',
51 | help='Clip gradient norm (default: None, no clipping)')
52 | parser.add_argument('--weight_decay', type=float, default=0.05,
53 | help='weight decay (default: 0.05)')
54 |
55 | parser.add_argument('--lr', type=float, default=None, metavar='LR',
56 | help='learning rate (absolute lr)')
57 | parser.add_argument('--blr', type=float, default=1e-3, metavar='LR',
58 | help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
59 | parser.add_argument('--layer_decay', type=float, default=0.75,
60 | help='layer-wise lr decay from ELECTRA/BEiT')
61 |
62 | parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR',
63 | help='lower lr bound for cyclic schedulers that hit 0')
64 |
65 | parser.add_argument('--warmup_epochs', type=int, default=5, metavar='N',
66 | help='epochs to warmup LR')
67 |
68 | # Augmentation parameters
69 | parser.add_argument('--color_jitter', type=float, default=None, metavar='PCT',
70 | help='Color jitter factor (enabled only when not using Auto/RandAug)')
71 | parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
72 | help='Use AutoAugment policy. "v0" or "original" (default: rand-m9-mstd0.5-inc1)')
73 | parser.add_argument('--smoothing', type=float, default=0.1,
74 | help='Label smoothing (default: 0.1)')
75 |
76 | # * Random Erase params
77 | parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
78 | help='Random erase prob (default: 0.25)')
79 | parser.add_argument('--remode', type=str, default='pixel',
80 | help='Random erase mode (default: "pixel")')
81 | parser.add_argument('--recount', type=int, default=1,
82 | help='Random erase count (default: 1)')
83 | parser.add_argument('--resplit', action='store_true', default=False,
84 | help='Do not random erase first (clean) augmentation split')
85 |
86 | # * Mixup params
87 | parser.add_argument('--mixup', type=float, default=0,
88 | help='mixup alpha, mixup enabled if > 0.')
89 | parser.add_argument('--cutmix', type=float, default=0,
90 | help='cutmix alpha, cutmix enabled if > 0.')
91 | parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None,
92 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
93 | parser.add_argument('--mixup_prob', type=float, default=1.0,
94 | help='Probability of performing mixup or cutmix when either/both is enabled')
95 | parser.add_argument('--mixup_switch_prob', type=float, default=0.5,
96 | help='Probability of switching to cutmix when both mixup and cutmix enabled')
97 | parser.add_argument('--mixup_mode', type=str, default='batch',
98 | help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
99 |
100 | # * Finetuning params
101 | parser.add_argument('--finetune', default='',
102 | help='finetune from checkpoint')
103 | parser.add_argument('--global_pool', action='store_true')
104 | parser.set_defaults(global_pool=True)
105 | parser.add_argument('--cls_token', action='store_false', dest='global_pool',
106 | help='Use class token instead of global pool for classification')
107 |
108 | # Dataset parameters
109 | parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str,
110 | help='dataset path')
111 | parser.add_argument('--nb_classes', default=1000, type=int,
112 | help='number of the classification types')
113 |
114 | parser.add_argument('--output_dir', default='./output_dir',
115 | help='path where to save, empty for no saving')
116 | parser.add_argument('--log_dir', default='./output_dir',
117 | help='path where to tensorboard log')
118 | parser.add_argument('--device', default='cuda',
119 | help='device to use for training / testing')
120 | parser.add_argument('--seed', default=0, type=int)
121 | parser.add_argument('--resume', default='',
122 | help='resume from checkpoint')
123 |
124 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
125 | help='start epoch')
126 | parser.add_argument('--eval', action='store_true',
127 | help='Perform evaluation only')
128 | parser.add_argument('--dist_eval', action='store_true', default=False,
129 | help='Enable distributed evaluation (recommended during training for faster monitoring)')
130 | parser.add_argument('--num_workers', default=10, type=int)
131 | parser.add_argument('--pin_mem', action='store_true',
132 | help='Pin CPU memory in DataLoader for (sometimes) more efficient transfer to GPU.')
133 | parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
134 | parser.set_defaults(pin_mem=True)
135 |
136 | # distributed training parameters
137 | parser.add_argument('--world_size', default=1, type=int,
138 | help='number of distributed processes')
139 | parser.add_argument('--local_rank', default=-1, type=int)
140 | parser.add_argument('--dist_on_itp', action='store_true')
141 | parser.add_argument('--dist_url', default='env://',
142 | help='url used to set up distributed training')
143 |
144 | return parser
145 |
146 |
147 | def main(args):
148 | misc.init_distributed_mode(args)
149 |
150 | print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
151 | print("{}".format(args).replace(', ', ',\n'))
152 |
153 | device = torch.device(args.device)
154 |
155 | # fix the seed for reproducibility
156 | seed = args.seed + misc.get_rank()
157 | torch.manual_seed(seed)
158 | np.random.seed(seed)
159 |
160 | cudnn.benchmark = True
161 |
162 | dataset_train = build_dataset(is_train=True, args=args)
163 | dataset_val = build_dataset(is_train=False, args=args)
164 |
165 | if True: # args.distributed: always use the distributed sampler path; valid even when world_size == 1
166 | num_tasks = misc.get_world_size()
167 | global_rank = misc.get_rank()
168 | sampler_train = torch.utils.data.DistributedSampler(
169 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
170 | )
171 | print("Sampler_train = %s" % str(sampler_train))
172 | if args.dist_eval:
173 | if len(dataset_val) % num_tasks != 0:
174 | print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
175 | 'This will slightly alter validation results as extra duplicate entries are added to achieve '
176 | 'equal num of samples per-process.')
177 | sampler_val = torch.utils.data.DistributedSampler(
178 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=True) # shuffle=True to reduce monitor bias
179 | else:
180 | sampler_val = torch.utils.data.SequentialSampler(dataset_val)
181 | else:
182 | sampler_train = torch.utils.data.RandomSampler(dataset_train)
183 | sampler_val = torch.utils.data.SequentialSampler(dataset_val)
184 |
185 | if global_rank == 0 and args.log_dir is not None and not args.eval:
186 | os.makedirs(args.log_dir, exist_ok=True)
187 | log_writer = SummaryWriter(log_dir=args.log_dir)
188 | else:
189 | log_writer = None
190 |
191 | data_loader_train = torch.utils.data.DataLoader(
192 | dataset_train, sampler=sampler_train,
193 | batch_size=args.batch_size,
194 | num_workers=args.num_workers,
195 | pin_memory=args.pin_mem,
196 | drop_last=True,
197 | )
198 |
199 | data_loader_val = torch.utils.data.DataLoader(
200 | dataset_val, sampler=sampler_val,
201 | batch_size=args.batch_size,
202 | num_workers=args.num_workers,
203 | pin_memory=args.pin_mem,
204 | drop_last=False
205 | )
206 |
207 | mixup_fn = None
208 | mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
209 | if mixup_active:
210 | print("Mixup is activated!")
211 | mixup_fn = Mixup(
212 | mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
213 | prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
214 | label_smoothing=args.smoothing, num_classes=args.nb_classes)
215 |
216 | model = models_vit.__dict__[args.model](
217 | num_classes=args.nb_classes,
218 | drop_path_rate=args.drop_path,
219 | global_pool=args.global_pool,
220 | )
221 |
222 | if args.finetune and not args.eval:
223 | checkpoint = torch.load(args.finetune, map_location='cpu')
224 |
225 | print("Load pre-trained checkpoint from: %s" % args.finetune)
226 | checkpoint_model = checkpoint['model']
227 | state_dict = model.state_dict()
228 | for k in ['head.weight', 'head.bias']:
229 | if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
230 | print(f"Removing key {k} from pretrained checkpoint")
231 | del checkpoint_model[k]
232 |
233 | # interpolate position embedding
234 | interpolate_pos_embed(model, checkpoint_model)
235 |
236 | # load pre-trained model
237 | msg = model.load_state_dict(checkpoint_model, strict=False)
238 | print(msg)
239 |
240 | if args.global_pool:
241 | assert set(msg.missing_keys) == {'head.weight', 'head.bias', 'fc_norm.weight', 'fc_norm.bias'}
242 | else:
243 | assert set(msg.missing_keys) == {'head.weight', 'head.bias'}
244 |
245 | # manually initialize fc layer
246 | trunc_normal_(model.head.weight, std=2e-5)
247 |
248 | model.to(device)
249 |
250 | model_without_ddp = model
251 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
252 |
253 | print("Model = %s" % str(model_without_ddp))
254 | print('number of params (M): %.2f' % (n_parameters / 1.e6))
255 |
256 | eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()
257 |
258 | if args.lr is None: # only base_lr is specified
259 | args.lr = args.blr * eff_batch_size / 256
260 |
261 | print("base lr: %.2e" % (args.lr * 256 / eff_batch_size))
262 | print("actual lr: %.2e" % args.lr)
263 |
264 | print("accumulate grad iterations: %d" % args.accum_iter)
265 | print("effective batch size: %d" % eff_batch_size)
266 |
267 | if args.distributed:
268 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
269 | model_without_ddp = model.module
270 |
271 | # build optimizer with layer-wise lr decay (lrd)
272 | param_groups = lrd.param_groups_lrd(model_without_ddp, args.weight_decay,
273 | no_weight_decay_list=model_without_ddp.no_weight_decay(),
274 | layer_decay=args.layer_decay
275 | )
276 | optimizer = torch.optim.AdamW(param_groups, lr=args.lr)
277 | loss_scaler = NativeScaler()
278 |
279 | if mixup_fn is not None:
280 | # smoothing is handled with mixup label transform
281 | criterion = SoftTargetCrossEntropy()
282 | elif args.smoothing > 0.:
283 | criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
284 | else:
285 | criterion = torch.nn.CrossEntropyLoss()
286 |
287 | print("criterion = %s" % str(criterion))
288 |
289 | misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler)
290 |
291 | if args.eval:
292 | test_stats = evaluate(data_loader_val, model, device)
293 | print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
294 | exit(0)
295 |
296 | print(f"Start training for {args.epochs} epochs")
297 | start_time = time.time()
298 | max_accuracy = 0.0
299 | for epoch in range(args.start_epoch, args.epochs):
300 | if args.distributed:
301 | data_loader_train.sampler.set_epoch(epoch)
302 | train_stats = train_one_epoch(
303 | model, criterion, data_loader_train,
304 | optimizer, device, epoch, loss_scaler,
305 | args.clip_grad, mixup_fn,
306 | log_writer=log_writer,
307 | args=args
308 | )
309 | if args.output_dir:
310 | misc.save_model(
311 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
312 | loss_scaler=loss_scaler, epoch=epoch)
313 |
314 | test_stats = evaluate(data_loader_val, model, device)
315 | print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
316 | max_accuracy = max(max_accuracy, test_stats["acc1"])
317 | print(f'Max accuracy: {max_accuracy:.2f}%')
318 |
319 | if log_writer is not None:
320 | log_writer.add_scalar('perf/test_acc1', test_stats['acc1'], epoch)
321 | log_writer.add_scalar('perf/test_acc5', test_stats['acc5'], epoch)
322 | log_writer.add_scalar('perf/test_loss', test_stats['loss'], epoch)
323 |
324 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
325 | **{f'test_{k}': v for k, v in test_stats.items()},
326 | 'epoch': epoch,
327 | 'n_parameters': n_parameters}
328 |
329 | if args.output_dir and misc.is_main_process():
330 | if log_writer is not None:
331 | log_writer.flush()
332 | with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
333 | f.write(json.dumps(log_stats) + "\n")
334 |
335 | total_time = time.time() - start_time
336 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
337 | print('Training time {}'.format(total_time_str))
338 |
339 |
340 | if __name__ == '__main__':
341 | args = get_args_parser()
342 | args = args.parse_args()
343 | if args.output_dir:
344 | Path(args.output_dir).mkdir(parents=True, exist_ok=True)
345 | main(args)
346 |
--------------------------------------------------------------------------------
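A worked example of the linear lr scaling shared by all three entry points (pretrain, linear probe, fine-tune): absolute_lr = blr * eff_batch_size / 256, with eff_batch_size = batch_size * accum_iter * world_size:

batch_size, accum_iter, world_size = 64, 2, 8
blr = 1e-3
eff_batch_size = batch_size * accum_iter * world_size   # 1024
lr = blr * eff_batch_size / 256                         # 4.00e-03
print("actual lr: %.2e" % lr)                           # matches the scripts' printout

--------------------------------------------------------------------------------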