├── swin ├── models │ ├── __init__.py │ └── build.py ├── configs │ └── swin │ │ ├── swin_base_patch4_window7_224.yaml │ │ ├── swin_small_patch4_window7_224.yaml │ │ ├── swin_tiny_patch4_window7_224.yaml │ │ ├── swin_tiny_c24_patch4_window8_256.yaml │ │ ├── swin_base_patch4_window7_224_22kto1k_finetune.yaml │ │ ├── swin_small_patch4_window7_224_22kto1k_finetune.yaml │ │ ├── swin_tiny_patch4_window7_224_22kto1k_finetune.yaml │ │ ├── swin_large_patch4_window7_224_22kto1k_finetune.yaml │ │ ├── swin_tiny_patch4_window7_224_22k.yaml │ │ ├── swin_base_patch4_window7_224_22k.yaml │ │ ├── swin_large_patch4_window7_224_22k.yaml │ │ ├── swin_small_patch4_window7_224_22k.yaml │ │ ├── swin_base_patch4_window12_384_finetune.yaml │ │ ├── swin_base_patch4_window12_384_22kto1k_finetune.yaml │ │ └── swin_large_patch4_window12_384_22kto1k_finetune.yaml ├── kernels │ └── window_process │ │ ├── setup.py │ │ ├── window_process.py │ │ ├── swin_window_process.cpp │ │ ├── unit_test.py │ │ └── swin_window_process_kernel.cu ├── data │ ├── __init__.py │ ├── samplers.py │ ├── imagenet22k_dataset.py │ ├── zipreader.py │ ├── data_simmim_pt.py │ ├── data_simmim_ft.py │ ├── build.py │ ├── cached_image_folder.py │ └── map22kto1k.txt ├── README.md ├── logger.py ├── lr_scheduler.py ├── optimizer.py ├── utils.py └── config.py ├── preview_picture.jpg ├── .gitignore ├── mae ├── util │ ├── lr_sched.py │ ├── crop.py │ ├── lars.py │ ├── datasets.py │ ├── lr_decay.py │ ├── pos_embed.py │ └── misc.py ├── README.md ├── models_vit.py └── engine_finetune.py ├── resnet_rsb ├── README.md ├── utils.py └── models.py ├── deit ├── samplers.py ├── losses.py ├── augment.py ├── datasets.py ├── engine.py ├── README.md └── utils.py └── README.md /swin/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import build_model -------------------------------------------------------------------------------- /preview_picture.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/augsub/HEAD/preview_picture.jpg -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | */nsml_logger.py 2 | .idea 3 | .DS_store 4 | mae/.* 5 | swin/.* 6 | deit/.* 7 | resnet_rsb/.* 8 | -------------------------------------------------------------------------------- /swin/configs/swin/swin_base_patch4_window7_224.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: swin 3 | NAME: swin_base_patch4_window7_224 4 | DROP_PATH_RATE: 0.5 5 | SWIN: 6 | EMBED_DIM: 128 7 | DEPTHS: [ 2, 2, 18, 2 ] 8 | NUM_HEADS: [ 4, 8, 16, 32 ] 9 | WINDOW_SIZE: 7 -------------------------------------------------------------------------------- /swin/configs/swin/swin_small_patch4_window7_224.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: swin 3 | NAME: swin_small_patch4_window7_224 4 | DROP_PATH_RATE: 0.3 5 | SWIN: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 18, 2 ] 8 | NUM_HEADS: [ 3, 6, 12, 24 ] 9 | WINDOW_SIZE: 7 -------------------------------------------------------------------------------- /swin/configs/swin/swin_tiny_patch4_window7_224.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: swin 3 | NAME: swin_tiny_patch4_window7_224 4 | DROP_PATH_RATE: 0.2 5 | SWIN: 6 | EMBED_DIM: 
96 7 | DEPTHS: [ 2, 2, 6, 2 ] 8 | NUM_HEADS: [ 3, 6, 12, 24 ] 9 | WINDOW_SIZE: 7 -------------------------------------------------------------------------------- /swin/configs/swin/swin_tiny_c24_patch4_window8_256.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | IMG_SIZE: 256 3 | MODEL: 4 | TYPE: swin 5 | NAME: swin_tiny_c24_patch4_window8_256 6 | DROP_PATH_RATE: 0.2 7 | SWIN: 8 | EMBED_DIM: 96 9 | DEPTHS: [ 2, 2, 6, 2 ] 10 | NUM_HEADS: [ 4, 8, 16, 32 ] 11 | WINDOW_SIZE: 8 -------------------------------------------------------------------------------- /swin/kernels/window_process/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | 4 | 5 | setup(name='swin_window_process', 6 | ext_modules=[ 7 | CUDAExtension('swin_window_process', [ 8 | 'swin_window_process.cpp', 9 | 'swin_window_process_kernel.cu', 10 | ]) 11 | ], 12 | cmdclass={'build_ext': BuildExtension}) -------------------------------------------------------------------------------- /swin/configs/swin/swin_base_patch4_window7_224_22kto1k_finetune.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: swin 3 | NAME: swin_base_patch4_window7_224_22kto1k_finetune 4 | DROP_PATH_RATE: 0.2 5 | SWIN: 6 | EMBED_DIM: 128 7 | DEPTHS: [ 2, 2, 18, 2 ] 8 | NUM_HEADS: [ 4, 8, 16, 32 ] 9 | WINDOW_SIZE: 7 10 | TRAIN: 11 | EPOCHS: 30 12 | WARMUP_EPOCHS: 5 13 | WEIGHT_DECAY: 1e-8 14 | BASE_LR: 2e-05 15 | WARMUP_LR: 2e-08 16 | MIN_LR: 2e-07 -------------------------------------------------------------------------------- /swin/configs/swin/swin_small_patch4_window7_224_22kto1k_finetune.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: swin 3 | NAME: swin_small_patch4_window7_224_22kto1k_finetune 4 | DROP_PATH_RATE: 0.2 5 | SWIN: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 18, 2 ] 8 | NUM_HEADS: [ 3, 6, 12, 24 ] 9 | WINDOW_SIZE: 7 10 | TRAIN: 11 | EPOCHS: 30 12 | WARMUP_EPOCHS: 5 13 | WEIGHT_DECAY: 1e-8 14 | BASE_LR: 2e-05 15 | WARMUP_LR: 2e-08 16 | MIN_LR: 2e-07 -------------------------------------------------------------------------------- /swin/configs/swin/swin_tiny_patch4_window7_224_22kto1k_finetune.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: swin 3 | NAME: swin_tiny_patch4_window7_224_22kto1k_finetune 4 | DROP_PATH_RATE: 0.1 5 | SWIN: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 6, 2 ] 8 | NUM_HEADS: [ 3, 6, 12, 24 ] 9 | WINDOW_SIZE: 7 10 | TRAIN: 11 | EPOCHS: 30 12 | WARMUP_EPOCHS: 5 13 | WEIGHT_DECAY: 1e-8 14 | BASE_LR: 2e-05 15 | WARMUP_LR: 2e-08 16 | MIN_LR: 2e-07 -------------------------------------------------------------------------------- /swin/configs/swin/swin_large_patch4_window7_224_22kto1k_finetune.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: swin 3 | NAME: swin_large_patch4_window7_224_22kto1k_finetune 4 | DROP_PATH_RATE: 0.2 5 | SWIN: 6 | EMBED_DIM: 192 7 | DEPTHS: [ 2, 2, 18, 2 ] 8 | NUM_HEADS: [ 6, 12, 24, 48 ] 9 | WINDOW_SIZE: 7 10 | TRAIN: 11 | EPOCHS: 30 12 | WARMUP_EPOCHS: 5 13 | WEIGHT_DECAY: 1e-8 14 | BASE_LR: 2e-05 15 | WARMUP_LR: 2e-08 16 | MIN_LR: 2e-07 -------------------------------------------------------------------------------- /swin/configs/swin/swin_tiny_patch4_window7_224_22k.yaml: 
-------------------------------------------------------------------------------- 1 | DATA: 2 | DATASET: imagenet22K 3 | MODEL: 4 | TYPE: swin 5 | NAME: swin_tiny_patch4_window7_224_22k 6 | DROP_PATH_RATE: 0.1 7 | SWIN: 8 | EMBED_DIM: 96 9 | DEPTHS: [ 2, 2, 6, 2 ] 10 | NUM_HEADS: [ 3, 6, 12, 24 ] 11 | WINDOW_SIZE: 7 12 | TRAIN: 13 | EPOCHS: 90 14 | WARMUP_EPOCHS: 5 15 | WEIGHT_DECAY: 0.05 16 | BASE_LR: 1.25e-4 # 4096 batch-size 17 | WARMUP_LR: 1.25e-7 18 | MIN_LR: 1.25e-6 -------------------------------------------------------------------------------- /swin/configs/swin/swin_base_patch4_window7_224_22k.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | DATASET: imagenet22K 3 | MODEL: 4 | TYPE: swin 5 | NAME: swin_base_patch4_window7_224_22k 6 | DROP_PATH_RATE: 0.2 7 | SWIN: 8 | EMBED_DIM: 128 9 | DEPTHS: [ 2, 2, 18, 2 ] 10 | NUM_HEADS: [ 4, 8, 16, 32 ] 11 | WINDOW_SIZE: 7 12 | TRAIN: 13 | EPOCHS: 90 14 | WARMUP_EPOCHS: 5 15 | WEIGHT_DECAY: 0.05 16 | BASE_LR: 1.25e-4 # 4096 batch-size 17 | WARMUP_LR: 1.25e-7 18 | MIN_LR: 1.25e-6 -------------------------------------------------------------------------------- /swin/configs/swin/swin_large_patch4_window7_224_22k.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | DATASET: imagenet22K 3 | MODEL: 4 | TYPE: swin 5 | NAME: swin_large_patch4_window7_224_22k 6 | DROP_PATH_RATE: 0.2 7 | SWIN: 8 | EMBED_DIM: 192 9 | DEPTHS: [ 2, 2, 18, 2 ] 10 | NUM_HEADS: [ 6, 12, 24, 48 ] 11 | WINDOW_SIZE: 7 12 | TRAIN: 13 | EPOCHS: 90 14 | WARMUP_EPOCHS: 5 15 | WEIGHT_DECAY: 0.05 16 | BASE_LR: 1.25e-4 # 4096 batch-size 17 | WARMUP_LR: 1.25e-7 18 | MIN_LR: 1.25e-6 -------------------------------------------------------------------------------- /swin/configs/swin/swin_small_patch4_window7_224_22k.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | DATASET: imagenet22K 3 | MODEL: 4 | TYPE: swin 5 | NAME: swin_small_patch4_window7_224_22k 6 | DROP_PATH_RATE: 0.2 7 | SWIN: 8 | EMBED_DIM: 96 9 | DEPTHS: [ 2, 2, 18, 2 ] 10 | NUM_HEADS: [ 3, 6, 12, 24 ] 11 | WINDOW_SIZE: 7 12 | TRAIN: 13 | EPOCHS: 90 14 | WARMUP_EPOCHS: 5 15 | WEIGHT_DECAY: 0.05 16 | BASE_LR: 1.25e-4 # 4096 batch-size 17 | WARMUP_LR: 1.25e-7 18 | MIN_LR: 1.25e-6 -------------------------------------------------------------------------------- /swin/configs/swin/swin_base_patch4_window12_384_finetune.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | IMG_SIZE: 384 3 | MODEL: 4 | TYPE: swin 5 | NAME: swin_base_patch4_window12_384_finetune 6 | DROP_PATH_RATE: 0.5 7 | SWIN: 8 | EMBED_DIM: 128 9 | DEPTHS: [ 2, 2, 18, 2 ] 10 | NUM_HEADS: [ 4, 8, 16, 32 ] 11 | WINDOW_SIZE: 12 12 | TRAIN: 13 | EPOCHS: 30 14 | WARMUP_EPOCHS: 5 15 | WEIGHT_DECAY: 1e-8 16 | BASE_LR: 2e-05 17 | WARMUP_LR: 2e-08 18 | MIN_LR: 2e-07 19 | TEST: 20 | CROP: False -------------------------------------------------------------------------------- /swin/configs/swin/swin_base_patch4_window12_384_22kto1k_finetune.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | IMG_SIZE: 384 3 | MODEL: 4 | TYPE: swin 5 | NAME: swin_base_patch4_window12_384_22kto1k_finetune 6 | DROP_PATH_RATE: 0.2 7 | SWIN: 8 | EMBED_DIM: 128 9 | DEPTHS: [ 2, 2, 18, 2 ] 10 | NUM_HEADS: [ 4, 8, 16, 32 ] 11 | WINDOW_SIZE: 12 12 | TRAIN: 13 | EPOCHS: 30 14 | WARMUP_EPOCHS: 5 15 | WEIGHT_DECAY: 1e-8 16 | BASE_LR: 2e-05 17 | WARMUP_LR: 2e-08 
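# Note (assumption, based on the upstream Swin-Transformer training script): BASE_LR/WARMUP_LR/MIN_LR are nominal values that main.py linearly rescales by the actual total batch size, cf. the '# 4096 batch-size' comments in the 22k pretraining configs above.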
18 | MIN_LR: 2e-07 19 | TEST: 20 | CROP: False -------------------------------------------------------------------------------- /swin/configs/swin/swin_large_patch4_window12_384_22kto1k_finetune.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | IMG_SIZE: 384 3 | MODEL: 4 | TYPE: swin 5 | NAME: swin_large_patch4_window12_384_22kto1k_finetune 6 | DROP_PATH_RATE: 0.2 7 | SWIN: 8 | EMBED_DIM: 192 9 | DEPTHS: [ 2, 2, 18, 2 ] 10 | NUM_HEADS: [ 6, 12, 24, 48 ] 11 | WINDOW_SIZE: 12 12 | TRAIN: 13 | EPOCHS: 30 14 | WARMUP_EPOCHS: 5 15 | WEIGHT_DECAY: 1e-8 16 | BASE_LR: 2e-05 17 | WARMUP_LR: 2e-08 18 | MIN_LR: 2e-07 19 | TEST: 20 | CROP: False -------------------------------------------------------------------------------- /swin/data/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code was originally obtained from: 3 | https://github.com/microsoft/Swin-Transformer 4 | """ 5 | 6 | from .build import build_loader as _build_loader 7 | from .data_simmim_pt import build_loader_simmim 8 | from .data_simmim_ft import build_loader_finetune 9 | 10 | 11 | def build_loader(config, simmim=False, is_pretrain=False): 12 | if not simmim: 13 | return _build_loader(config) 14 | if is_pretrain: 15 | return build_loader_simmim(config) 16 | else: 17 | return build_loader_finetune(config) 18 | -------------------------------------------------------------------------------- /mae/util/lr_sched.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code was originally obtained from: 3 | https://github.com/facebookresearch/mae 4 | """ 5 | 6 | # Copyright (c) Meta Platforms, Inc. and affiliates. 7 | # All rights reserved. 8 | 9 | # This source code is licensed under the license found in the 10 | # LICENSE file in the root directory of this source tree. 11 | 12 | import math 13 | 14 | def adjust_learning_rate(optimizer, epoch, args): 15 | """Decay the learning rate with half-cycle cosine after warmup""" 16 | if epoch < args.warmup_epochs: 17 | lr = args.lr * epoch / args.warmup_epochs 18 | else: 19 | lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \ 20 | (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs))) 21 | for param_group in optimizer.param_groups: 22 | if "lr_scale" in param_group: 23 | param_group["lr"] = lr * param_group["lr_scale"] 24 | else: 25 | param_group["lr"] = lr 26 | return lr 27 | -------------------------------------------------------------------------------- /swin/data/samplers.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code was originally obtained from: 3 | https://github.com/microsoft/Swin-Transformer 4 | """ 5 | 6 | # -------------------------------------------------------- 7 | # Swin Transformer 8 | # Copyright (c) 2021 Microsoft 9 | # Licensed under The MIT License [see LICENSE for details] 10 | # Written by Ze Liu 11 | # -------------------------------------------------------- 12 | 13 | import torch 14 | 15 | 16 | class SubsetRandomSampler(torch.utils.data.Sampler): 17 | r"""Samples elements randomly from a given list of indices, without replacement. 
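Unlike torch.utils.data.SubsetRandomSampler, this variant also exposes set_epoch() for API compatibility with DistributedSampler; note that __iter__ below does not seed torch.randperm with the stored epoch, so shuffling is not epoch-deterministic.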
18 | 19 | Arguments: 20 | indices (sequence): a sequence of indices 21 | """ 22 | 23 | def __init__(self, indices): 24 | self.epoch = 0 25 | self.indices = indices 26 | 27 | def __iter__(self): 28 | return (self.indices[i] for i in torch.randperm(len(self.indices))) 29 | 30 | def __len__(self): 31 | return len(self.indices) 32 | 33 | def set_epoch(self, epoch): 34 | self.epoch = epoch 35 | -------------------------------------------------------------------------------- /swin/README.md: -------------------------------------------------------------------------------- 1 | # Swin Transformer 2 | 3 | The code originates from [https://github.com/microsoft/Swin-Transformer](https://github.com/microsoft/Swin-Transformer) 4 | 5 | 6 | ### Requirements 7 | ``` 8 | torch==1.11.0 9 | torchvision==0.11.0 10 | timm==0.3.2 11 | ``` 12 | 13 | ### Performance 14 | 15 | | Architecture | # Params | FLOPs | Baseline | + MaskSub | 16 | | :---: | :---: | :---: | :---: |:---------------:| 17 | | Swin-T | 28.3 M | 4.5 G | 81.3 | **81.4 (+0.1)** | 18 | | Swin-S | 49.6 M | 8.7 G | 83.0 | **83.4 (+0.4)** | 19 | | Swin-B | 87.9 M | 15.4 G | 83.5 | **83.9 (+0.4)** | 20 | 21 | ### MaskSub training commands 22 | - Environment variables 23 | ```bash 24 | data_path=/your/path/to/imagenet 25 | save_path=/your/path/to/save 26 | # Use the config file of the target model 27 | config_file=swin_tiny_patch4_window7_224.yaml 28 | config_file=swin_small_patch4_window7_224.yaml 29 | config_file=swin_base_patch4_window7_224.yaml 30 | ``` 31 | 32 | - Command 33 | ```bash 34 | OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 main.py \ 35 | --cfg configs/swin/${config_file} --data-path ${data_path} --output ${save_path} --batch-size 128 \ 36 | --augsub masking --augsub-ratio 0.5 37 | ``` 38 | -------------------------------------------------------------------------------- /mae/util/crop.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code was originally obtained from: 3 | https://github.com/facebookresearch/mae 4 | """ 5 | 6 | # Copyright (c) Meta Platforms, Inc. and affiliates. 7 | # All rights reserved. 8 | 9 | # This source code is licensed under the license found in the 10 | # LICENSE file in the root directory of this source tree. 11 | 12 | import math 13 | 14 | import torch 15 | 16 | from torchvision import transforms 17 | from torchvision.transforms import functional as F 18 | 19 | 20 | class RandomResizedCrop(transforms.RandomResizedCrop): 21 | """ 22 | RandomResizedCrop for matching TF/TPU implementation: no for-loop is used. 23 | This may lead to results different from torchvision's version.
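(For comparison: torchvision's get_params retries the (area, aspect-ratio) sampling up to ten times and falls back to a center crop, while this version draws a single sample and clamps it to the image bounds.)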
24 | Following BYOL's TF code: 25 | https://github.com/deepmind/deepmind-research/blob/master/byol/utils/dataset.py#L206 26 | """ 27 | @staticmethod 28 | def get_params(img, scale, ratio): 29 | width, height = F._get_image_size(img) 30 | area = height * width 31 | 32 | target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item() 33 | log_ratio = torch.log(torch.tensor(ratio)) 34 | aspect_ratio = torch.exp( 35 | torch.empty(1).uniform_(log_ratio[0], log_ratio[1]) 36 | ).item() 37 | 38 | w = int(round(math.sqrt(target_area * aspect_ratio))) 39 | h = int(round(math.sqrt(target_area / aspect_ratio))) 40 | 41 | w = min(w, width) 42 | h = min(h, height) 43 | 44 | i = torch.randint(0, height - h + 1, size=(1,)).item() 45 | j = torch.randint(0, width - w + 1, size=(1,)).item() 46 | 47 | return i, j, h, w -------------------------------------------------------------------------------- /swin/logger.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code was originally obtained from: 3 | https://github.com/microsoft/Swin-Transformer 4 | """ 5 | 6 | # -------------------------------------------------------- 7 | # Swin Transformer 8 | # Copyright (c) 2021 Microsoft 9 | # Licensed under The MIT License [see LICENSE for details] 10 | # Written by Ze Liu 11 | # -------------------------------------------------------- 12 | 13 | import os 14 | import sys 15 | import logging 16 | import functools 17 | from termcolor import colored 18 | 19 | 20 | @functools.lru_cache() 21 | def create_logger(output_dir, dist_rank=0, name=''): 22 | # create logger 23 | logger = logging.getLogger(name) 24 | logger.setLevel(logging.DEBUG) 25 | logger.propagate = False 26 | 27 | # create formatter 28 | fmt = '[%(asctime)s %(name)s] (%(filename)s %(lineno)d): %(levelname)s %(message)s' 29 | color_fmt = colored('[%(asctime)s %(name)s]', 'green') + \ 30 | colored('(%(filename)s %(lineno)d)', 'yellow') + ': %(levelname)s %(message)s' 31 | 32 | # create console handlers for master process 33 | if dist_rank == 0: 34 | console_handler = logging.StreamHandler(sys.stdout) 35 | console_handler.setLevel(logging.DEBUG) 36 | console_handler.setFormatter( 37 | logging.Formatter(fmt=color_fmt, datefmt='%Y-%m-%d %H:%M:%S')) 38 | logger.addHandler(console_handler) 39 | 40 | # create file handlers 41 | file_handler = logging.FileHandler(os.path.join(output_dir, f'log_rank{dist_rank}.txt'), mode='a') 42 | file_handler.setLevel(logging.DEBUG) 43 | file_handler.setFormatter(logging.Formatter(fmt=fmt, datefmt='%Y-%m-%d %H:%M:%S')) 44 | logger.addHandler(file_handler) 45 | 46 | return logger 47 | -------------------------------------------------------------------------------- /resnet_rsb/README.md: -------------------------------------------------------------------------------- 1 | # ResNet strikes back: An improved training procedure in timm 2 | 3 | The code originates from [https://github.com/huggingface/pytorch-image-models](https://github.com/huggingface/pytorch-image-models) 4 | 5 | ### Requirements 6 | ``` 7 | torch==1.11.0 8 | torchvision==0.11.0 9 | timm==0.5.4 10 | ``` 11 | 12 | ### Performance 13 | 14 | | Architecture | # Params | FLOPs | Baseline | + MaskSub | 15 | | :---: | :---: | :---: | :---: |:---------------:| 16 | | ResNet50 | 25.6 M | 4.1 G | 79.7 | **80.0 (+0.3)** | 17 | | ResNet101 | 44.5 M | 7.9 G | 81.4 | **82.1 (+0.7)** | 18 | | ResNet152 | 60.2 M | 11.6 G | 81.8 | **82.8 (+1.0)** | 19 | 20 | 21 | ### MaskSub training commands 22 | - Environment
variables 23 | ```bash 24 | data_path=/your/path/to/imagenet 25 | save_path=/your/path/to/save 26 | # Use target model name 27 | model_name=resnet50 28 | model_name=resnet101 29 | model_name=resnet152 30 | ``` 31 | 32 | - Command 33 | ```bash 34 | OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 train.py \ 35 | ${data_path} \ 36 | --model ${model_name} \ 37 | --output ${save_path} \ 38 | --img-size 224 \ 39 | --epochs 300 \ 40 | --batch-size 256 \ 41 | --opt lamb \ 42 | --lr 5e-3 \ 43 | --sched cosine \ 44 | --weight-decay 0.02\ 45 | --warmup-epochs 5 \ 46 | --cooldown-epochs 0 \ 47 | --smoothing 0.0 \ 48 | --drop 0.0 \ 49 | --drop-path 0.05 \ 50 | --aug-repeats 3 \ 51 | --aa rand-m7-mstd0.5 \ 52 | --mixup 0.1 \ 53 | --cutmix 1.0 \ 54 | --reprob 0.0 \ 55 | --color-jitter 0.0 \ 56 | --crop-pct 0.95 \ 57 | --bce-loss \ 58 | --native-amp \ 59 | --log-interval 400 \ 60 | --augsub masking \ 61 | --augsub-ratio 0.5 62 | ``` -------------------------------------------------------------------------------- /swin/data/imagenet22k_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code was originally obtained from: 3 | https://github.com/microsoft/Swin-Transformer 4 | """ 5 | 6 | import os 7 | import json 8 | import torch.utils.data as data 9 | import numpy as np 10 | from PIL import Image 11 | 12 | import warnings 13 | 14 | warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning) 15 | 16 | 17 | class IN22KDATASET(data.Dataset): 18 | def __init__(self, root, ann_file='', transform=None, target_transform=None): 19 | super(IN22KDATASET, self).__init__() 20 | 21 | self.data_path = root 22 | self.ann_path = os.path.join(self.data_path, ann_file) 23 | self.transform = transform 24 | self.target_transform = target_transform 25 | # id & label: https://github.com/google-research/big_transfer/issues/7 26 | # total: 21843; only 21841 class have images: map 21841->9205; 21842->15027 27 | self.database = json.load(open(self.ann_path)) 28 | 29 | def _load_image(self, path): 30 | try: 31 | im = Image.open(path) 32 | except: 33 | print("ERROR IMG LOADED: ", path) 34 | random_img = np.random.rand(224, 224, 3) * 255 35 | im = Image.fromarray(np.uint8(random_img)) 36 | return im 37 | 38 | def __getitem__(self, index): 39 | """ 40 | Args: 41 | index (int): Index 42 | Returns: 43 | tuple: (image, target) where target is class_index of the target class. 
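Example (illustrative; the root and annotation filename are hypothetical)::
    >>> dataset = IN22KDATASET(root='/data/imagenet22k', ann_file='in22k_ann.json')
    >>> image, target = dataset[0]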
44 | """ 45 | idb = self.database[index] 46 | 47 | # images 48 | images = self._load_image(self.data_path + '/' + idb[0]).convert('RGB') 49 | if self.transform is not None: 50 | images = self.transform(images) 51 | 52 | # target 53 | target = int(idb[1]) 54 | if self.target_transform is not None: 55 | target = self.target_transform(target) 56 | 57 | return images, target 58 | 59 | def __len__(self): 60 | return len(self.database) 61 | -------------------------------------------------------------------------------- /swin/models/build.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code was originally obtained from: 3 | https://github.com/microsoft/Swin-Transformer 4 | """ 5 | 6 | # -------------------------------------------------------- 7 | # Swin Transformer 8 | # Copyright (c) 2021 Microsoft 9 | # Licensed under The MIT License [see LICENSE for details] 10 | # Written by Ze Liu 11 | # -------------------------------------------------------- 12 | 13 | from .swin_transformer import SwinTransformer, augsub_SwinTransformer 14 | 15 | 16 | def build_model(config, is_pretrain=False): 17 | model_type = config.MODEL.TYPE 18 | 19 | # accelerate layernorm 20 | if config.FUSED_LAYERNORM: 21 | try: 22 | import apex as amp 23 | layernorm = amp.normalization.FusedLayerNorm 24 | except: 25 | layernorm = None 26 | print("To use FusedLayerNorm, please install apex.") 27 | else: 28 | import torch.nn as nn 29 | layernorm = nn.LayerNorm 30 | 31 | 32 | if model_type == 'swin': 33 | model = augsub_SwinTransformer( 34 | img_size=config.DATA.IMG_SIZE, 35 | patch_size=config.MODEL.SWIN.PATCH_SIZE, 36 | in_chans=config.MODEL.SWIN.IN_CHANS, 37 | num_classes=config.MODEL.NUM_CLASSES, 38 | embed_dim=config.MODEL.SWIN.EMBED_DIM, 39 | depths=config.MODEL.SWIN.DEPTHS, 40 | num_heads=config.MODEL.SWIN.NUM_HEADS, 41 | window_size=config.MODEL.SWIN.WINDOW_SIZE, 42 | mlp_ratio=config.MODEL.SWIN.MLP_RATIO, 43 | qkv_bias=config.MODEL.SWIN.QKV_BIAS, 44 | qk_scale=config.MODEL.SWIN.QK_SCALE, 45 | drop_rate=config.MODEL.DROP_RATE, 46 | drop_path_rate=config.MODEL.DROP_PATH_RATE, 47 | ape=config.MODEL.SWIN.APE, 48 | norm_layer=layernorm, 49 | patch_norm=config.MODEL.SWIN.PATCH_NORM, 50 | use_checkpoint=config.TRAIN.USE_CHECKPOINT, 51 | fused_window_process=config.FUSED_WINDOW_PROCESS) 52 | else: 53 | raise NotImplementedError(f"Unkown model: {model_type}") 54 | 55 | return model 56 | -------------------------------------------------------------------------------- /mae/util/lars.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code was originally obtained from: 3 | https://github.com/facebookresearch/mae 4 | """ 5 | 6 | # Copyright (c) Meta Platforms, Inc. and affiliates. 7 | # All rights reserved. 8 | 9 | # This source code is licensed under the license found in the 10 | # LICENSE file in the root directory of this source tree. 11 | # -------------------------------------------------------- 12 | # LARS optimizer, implementation from MoCo v3: 13 | # https://github.com/facebookresearch/moco-v3 14 | # -------------------------------------------------------- 15 | 16 | import torch 17 | 18 | 19 | class LARS(torch.optim.Optimizer): 20 | """ 21 | LARS optimizer, no rate scaling or weight decay for parameters <= 1D. 
22 | """ 23 | def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, trust_coefficient=0.001): 24 | defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, trust_coefficient=trust_coefficient) 25 | super().__init__(params, defaults) 26 | 27 | @torch.no_grad() 28 | def step(self): 29 | for g in self.param_groups: 30 | for p in g['params']: 31 | dp = p.grad 32 | 33 | if dp is None: 34 | continue 35 | 36 | if p.ndim > 1: # if not normalization gamma/beta or bias 37 | dp = dp.add(p, alpha=g['weight_decay']) 38 | param_norm = torch.norm(p) 39 | update_norm = torch.norm(dp) 40 | one = torch.ones_like(param_norm) 41 | q = torch.where(param_norm > 0., 42 | torch.where(update_norm > 0, 43 | (g['trust_coefficient'] * param_norm / update_norm), one), 44 | one) 45 | dp = dp.mul(q) 46 | 47 | param_state = self.state[p] 48 | if 'mu' not in param_state: 49 | param_state['mu'] = torch.zeros_like(p) 50 | mu = param_state['mu'] 51 | mu.mul_(g['momentum']).add_(dp) 52 | p.add_(mu, alpha=-g['lr']) -------------------------------------------------------------------------------- /resnet_rsb/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code was originally obtained from: 3 | https://github.com/facebookresearch/mae 4 | """ 5 | 6 | import torch 7 | import torch.distributed as dist 8 | from timm.utils import get_state_dict 9 | from torch._six import inf 10 | 11 | def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: 12 | if isinstance(parameters, torch.Tensor): 13 | parameters = [parameters] 14 | parameters = [p for p in parameters if p.grad is not None] 15 | norm_type = float(norm_type) 16 | if len(parameters) == 0: 17 | return torch.tensor(0.) 18 | device = parameters[0].grad.device 19 | if norm_type == inf: 20 | total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) 21 | else: 22 | total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type) 23 | return total_norm 24 | 25 | class NativeScalerWithGradNormCount: 26 | state_dict_key = "amp_scaler" 27 | 28 | def __init__(self): 29 | self._scaler = torch.cuda.amp.GradScaler() 30 | 31 | def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True, retain_graph=False): 32 | if retain_graph: 33 | self._scaler.scale(loss).backward(create_graph=create_graph, retain_graph=retain_graph) 34 | else: 35 | self._scaler.scale(loss).backward(create_graph=create_graph) 36 | 37 | if update_grad: 38 | if clip_grad is not None: 39 | assert parameters is not None 40 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place 41 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) 42 | else: 43 | self._scaler.unscale_(optimizer) 44 | norm = get_grad_norm_(parameters) 45 | self._scaler.step(optimizer) 46 | self._scaler.update() 47 | else: 48 | norm = None 49 | return norm 50 | 51 | def state_dict(self): 52 | return self._scaler.state_dict() 53 | 54 | def load_state_dict(self, state_dict): 55 | self._scaler.load_state_dict(state_dict) -------------------------------------------------------------------------------- /mae/util/datasets.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code was originally obtained from: 3 | https://github.com/facebookresearch/mae 4 | """ 5 | 6 | # Copyright (c) Meta Platforms, Inc. and affiliates. 7 | # All rights reserved. 
8 | 9 | # This source code is licensed under the license found in the 10 | # LICENSE file in the root directory of this source tree. 11 | # -------------------------------------------------------- 12 | # References: 13 | # DeiT: https://github.com/facebookresearch/deit 14 | # -------------------------------------------------------- 15 | 16 | import os 17 | import PIL 18 | 19 | from torchvision import datasets, transforms 20 | 21 | from timm.data import create_transform 22 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 23 | 24 | 25 | def build_dataset(is_train, args): 26 | transform = build_transform(is_train, args) 27 | 28 | root = os.path.join(args.data_path, 'train' if is_train else 'val') 29 | dataset = datasets.ImageFolder(root, transform=transform) 30 | 31 | print(dataset) 32 | 33 | return dataset 34 | 35 | 36 | def build_transform(is_train, args): 37 | mean = IMAGENET_DEFAULT_MEAN 38 | std = IMAGENET_DEFAULT_STD 39 | # train transform 40 | if is_train: 41 | # this should always dispatch to transforms_imagenet_train 42 | transform = create_transform( 43 | input_size=args.input_size, 44 | is_training=True, 45 | color_jitter=args.color_jitter, 46 | auto_augment=args.aa, 47 | interpolation='bicubic', 48 | re_prob=args.reprob, 49 | re_mode=args.remode, 50 | re_count=args.recount, 51 | mean=mean, 52 | std=std, 53 | ) 54 | return transform 55 | 56 | # eval transform 57 | t = [] 58 | if args.input_size <= 224: 59 | crop_pct = 224 / 256 60 | else: 61 | crop_pct = 1.0 62 | size = int(args.input_size / crop_pct) 63 | t.append( 64 | transforms.Resize(size, interpolation=PIL.Image.BICUBIC), # to maintain same ratio w.r.t. 224 images 65 | ) 66 | t.append(transforms.CenterCrop(args.input_size)) 67 | 68 | t.append(transforms.ToTensor()) 69 | t.append(transforms.Normalize(mean, std)) 70 | return transforms.Compose(t) 71 | -------------------------------------------------------------------------------- /swin/kernels/window_process/window_process.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code was originally obtained from: 3 | https://github.com/microsoft/Swin-Transformer 4 | """ 5 | 6 | # -------------------------------------------------------- 7 | # Fused kernel for window process for SwinTransformer 8 | # Copyright (c) 2022 Nvidia 9 | # Licensed under The MIT License [see LICENSE for details] 10 | # -------------------------------------------------------- 11 | 12 | import torch 13 | import swin_window_process 14 | 15 | 16 | class WindowProcess(torch.autograd.Function): 17 | @staticmethod 18 | def forward(ctx, input, B, H, W, C, shift_size, window_size): 19 | output = swin_window_process.roll_and_window_partition_forward(input, B, H, W, C, shift_size, window_size) 20 | 21 | ctx.B = B 22 | ctx.H = H 23 | ctx.W = W 24 | ctx.C = C 25 | ctx.shift_size = shift_size 26 | ctx.window_size = window_size 27 | return output 28 | 29 | @staticmethod 30 | def backward(ctx, grad_in): 31 | B = ctx.B 32 | H = ctx.H 33 | W = ctx.W 34 | C = ctx.C 35 | shift_size = ctx.shift_size 36 | window_size = ctx.window_size 37 | 38 | grad_out = swin_window_process.roll_and_window_partition_backward(grad_in, B, H, W, C, shift_size, window_size) 39 | return grad_out, None, None, None, None, None, None, None 40 | 41 | 42 | class WindowProcessReverse(torch.autograd.Function): 43 | @staticmethod 44 | def forward(ctx, input, B, H, W, C, shift_size, window_size): 45 | output = swin_window_process.window_merge_and_roll_forward(input, B, H, W, C, 
shift_size, window_size) 46 | 47 | ctx.B = B 48 | ctx.H = H 49 | ctx.W = W 50 | ctx.C = C 51 | ctx.shift_size = shift_size 52 | ctx.window_size = window_size 53 | 54 | return output 55 | 56 | @staticmethod 57 | def backward(ctx, grad_in): 58 | B = ctx.B 59 | H = ctx.H 60 | W = ctx.W 61 | C = ctx.C 62 | shift_size = ctx.shift_size 63 | window_size = ctx.window_size 64 | 65 | #grad_out = ctx.saved_tensors[0] 66 | #grad_out = torch.zeros((B, H, W, C), dtype=dtype).cuda() 67 | grad_out = swin_window_process.window_merge_and_roll_backward(grad_in, B, H, W, C, shift_size, window_size) 68 | return grad_out, None, None, None, None, None, None, None 69 | -------------------------------------------------------------------------------- /mae/README.md: -------------------------------------------------------------------------------- 1 | ## MAE finetuning 2 | 3 | The code originates from [https://github.com/facebookresearch/mae](https://github.com/facebookresearch/mae) 4 | 5 | 6 | ### Requirements 7 | ``` 8 | torch==1.11.0 9 | torchvision==0.11.0 10 | timm==0.3.2 11 | ``` 12 | 13 | ### Performance 14 | 15 | | Architecture | Finetuning Epochs | Baseline | + MaskSub | 16 | |:------------:|:-----------------:|:--------:|:---------------:| 17 | | ViT-B/16 | 100 | 83.6 | **83.9 (+0.3)** | 18 | | ViT-L/16 | 50 | 85.9 | **86.1 (+0.2)** | 19 | | ViT-H/14 | 50 | 86.9 | **87.2 (+0.3)** | 20 | 21 | ### MaskSub finetuning commands 22 | 23 | Finetuning requires MAE pretrained weights. Please download the MAE weights from the [original repository](https://github.com/facebookresearch/mae) 24 | 25 | - Environment variables 26 | ```bash 27 | data_path=/your/path/to/imagenet 28 | save_path=/your/path/to/save 29 | weight_path=/your/path/to/mae_pretrain_vit_base.pth 30 | ``` 31 | - ViT-B 32 | ```bash 33 | OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 main_finetune.py \ 34 | --model vit_base_patch16 \ 35 | --data_path ${data_path} \ 36 | --finetune ${weight_path} \ 37 | --output_dir ${save_path} \ 38 | --batch_size 128 \ 39 | --accum_iter 1 \ 40 | --epochs 100 \ 41 | --blr 5e-4 --layer_decay 0.65 \ 42 | --weight_decay 0.05 --drop_path 0.1 --reprob 0.25 --mixup 0.8 --cutmix 1.0 \ 43 | --augsub masking --augsub_ratio 0.5 \ 44 | --dist_eval 45 | ``` 46 | 47 | - ViT-L 48 | ```bash 49 | OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 main_finetune.py \ 50 | --model vit_large_patch16 \ 51 | --data_path ${data_path} \ 52 | --finetune ${weight_path} \ 53 | --output_dir ${save_path} \ 54 | --batch_size 32 \ 55 | --accum_iter 4 \ 56 | --epochs 50 \ 57 | --blr 1e-3 --layer_decay 0.75 \ 58 | --weight_decay 0.05 --drop_path 0.2 --reprob 0.25 --mixup 0.8 --cutmix 1.0 \ 59 | --augsub masking --augsub_ratio 0.5 \ 60 | --dist_eval 61 | ``` 62 | 63 | - ViT-H (note: the model flag is given as `vit_huge_patch14` here to match the ViT-H/14 row above) 64 | ```bash 65 | OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 --nnodes=8 --use_env main_finetune.py \ 66 | --model vit_huge_patch14 \ 67 | --data_path ${data_path} \ 68 | --finetune ${weight_path} \ 69 | --output_dir ${save_path} \ 70 | --batch_size 16 \ 71 | --accum_iter 1 \ 72 | --epochs 50 \ 73 | --blr 1e-3 --layer_decay 0.75 \ 74 | --weight_decay 0.05 --drop_path 0.3 --reprob 0.25 --mixup 0.8 --cutmix 1.0 \ 75 | --augsub masking --augsub_ratio 0.5 \ 76 | --dist_eval 77 | ``` 78 | -------------------------------------------------------------------------------- /mae/util/lr_decay.py: -------------------------------------------------------------------------------- 1 | """ 2 | This
code was originally obtained from: 3 | https://github.com/facebookresearch/mae 4 | """ 5 | 6 | # Copyright (c) Meta Platforms, Inc. and affiliates. 7 | # All rights reserved. 8 | 9 | # This source code is licensed under the license found in the 10 | # LICENSE file in the root directory of this source tree. 11 | # -------------------------------------------------------- 12 | # References: 13 | # ELECTRA https://github.com/google-research/electra 14 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 15 | # -------------------------------------------------------- 16 | 17 | import json 18 | 19 | 20 | def param_groups_lrd(model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=.75): 21 | """ 22 | Parameter groups for layer-wise lr decay 23 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58 24 | """ 25 | param_group_names = {} 26 | param_groups = {} 27 | 28 | num_layers = len(model.blocks) + 1 29 | 30 | layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1)) 31 | 32 | for n, p in model.named_parameters(): 33 | if not p.requires_grad: 34 | continue 35 | 36 | # no decay: all 1D parameters and model specific ones 37 | if p.ndim == 1 or n in no_weight_decay_list: 38 | g_decay = "no_decay" 39 | this_decay = 0. 40 | else: 41 | g_decay = "decay" 42 | this_decay = weight_decay 43 | 44 | layer_id = get_layer_id_for_vit(n, num_layers) 45 | group_name = "layer_%d_%s" % (layer_id, g_decay) 46 | 47 | if group_name not in param_group_names: 48 | this_scale = layer_scales[layer_id] 49 | 50 | param_group_names[group_name] = { 51 | "lr_scale": this_scale, 52 | "weight_decay": this_decay, 53 | "params": [], 54 | } 55 | param_groups[group_name] = { 56 | "lr_scale": this_scale, 57 | "weight_decay": this_decay, 58 | "params": [], 59 | } 60 | 61 | param_group_names[group_name]["params"].append(n) 62 | param_groups[group_name]["params"].append(p) 63 | 64 | # print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2)) 65 | 66 | return list(param_groups.values()) 67 | 68 | 69 | def get_layer_id_for_vit(name, num_layers): 70 | """ 71 | Assign a parameter with its layer id 72 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33 73 | """ 74 | if name in ['cls_token', 'pos_embed']: 75 | return 0 76 | elif name.startswith('patch_embed'): 77 | return 0 78 | elif name.startswith('blocks'): 79 | return int(name.split('.')[1]) + 1 80 | else: 81 | return num_layers -------------------------------------------------------------------------------- /deit/samplers.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code was originally obtained from: 3 | https://github.com/facebookresearch/deit 4 | """ 5 | 6 | # Copyright (c) 2015-present, Facebook, Inc. 7 | # All rights reserved. 8 | import torch 9 | import torch.distributed as dist 10 | import math 11 | 12 | 13 | class RASampler(torch.utils.data.Sampler): 14 | """Sampler that restricts data loading to a subset of the dataset for distributed, 15 | with repeated augmentation. 
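Each dataset index is repeated num_repeats times (3 by default) before being partitioned across processes.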
16 | It ensures that each augmented version of a sample is visible to a 17 | different process (GPU). 18 | Heavily based on torch.utils.data.DistributedSampler 19 | """ 20 | 21 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, num_repeats: int = 3): 22 | if num_replicas is None: 23 | if not dist.is_available(): 24 | raise RuntimeError("Requires distributed package to be available") 25 | num_replicas = dist.get_world_size() 26 | if rank is None: 27 | if not dist.is_available(): 28 | raise RuntimeError("Requires distributed package to be available") 29 | rank = dist.get_rank() 30 | if num_repeats < 1: 31 | raise ValueError("num_repeats should be greater than 0") 32 | self.dataset = dataset 33 | self.num_replicas = num_replicas 34 | self.rank = rank 35 | self.num_repeats = num_repeats 36 | self.epoch = 0 37 | self.num_samples = int(math.ceil(len(self.dataset) * self.num_repeats / self.num_replicas)) 38 | self.total_size = self.num_samples * self.num_replicas 39 | # self.num_selected_samples = int(math.ceil(len(self.dataset) / self.num_replicas)) 40 | self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas)) 41 | self.shuffle = shuffle 42 | 43 | def __iter__(self): 44 | if self.shuffle: 45 | # deterministically shuffle based on epoch 46 | g = torch.Generator() 47 | g.manual_seed(self.epoch) 48 | indices = torch.randperm(len(self.dataset), generator=g) 49 | else: 50 | indices = torch.arange(start=0, end=len(self.dataset)) 51 | 52 | # add extra samples to make it evenly divisible 53 | indices = torch.repeat_interleave(indices, repeats=self.num_repeats, dim=0).tolist() 54 | padding_size: int = self.total_size - len(indices) 55 | if padding_size > 0: 56 | indices += indices[:padding_size] 57 | assert len(indices) == self.total_size 58 | 59 | # subsample 60 | indices = indices[self.rank:self.total_size:self.num_replicas] 61 | assert len(indices) == self.num_samples 62 | 63 | return iter(indices[:self.num_selected_samples]) 64 | 65 | def __len__(self): 66 | return self.num_selected_samples 67 | 68 | def set_epoch(self, epoch): 69 | self.epoch = epoch 70 | -------------------------------------------------------------------------------- /deit/losses.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code was originally obtained from: 3 | https://github.com/facebookresearch/deit 4 | """ 5 | 6 | # Copyright (c) 2015-present, Facebook, Inc. 7 | # All rights reserved. 8 | """ 9 | Implements the knowledge distillation loss 10 | """ 11 | import torch 12 | from torch.nn import functional as F 13 | 14 | 15 | class DistillationLoss(torch.nn.Module): 16 | """ 17 | This module wraps a standard criterion and adds an extra knowledge distillation loss by 18 | taking a teacher model prediction and using it as additional supervision. 19 | """ 20 | def __init__(self, base_criterion: torch.nn.Module, teacher_model: torch.nn.Module, 21 | distillation_type: str, alpha: float, tau: float): 22 | super().__init__() 23 | self.base_criterion = base_criterion 24 | self.teacher_model = teacher_model 25 | assert distillation_type in ['none', 'soft', 'hard'] 26 | self.distillation_type = distillation_type 27 | self.alpha = alpha 28 | self.tau = tau 29 | 30 | def forward(self, inputs, outputs, labels): 31 | """ 32 | Args: 33 | inputs: The original inputs that are fed to the teacher model 34 | outputs: the outputs of the model to be trained.
It is expected to be 35 | either a Tensor, or a Tuple[Tensor, Tensor], with the original output 36 | in the first position and the distillation predictions as the second output 37 | labels: the labels for the base criterion 38 | """ 39 | outputs_kd = None 40 | if not isinstance(outputs, torch.Tensor): 41 | # assume that the model outputs a tuple of [outputs, outputs_kd] 42 | outputs, outputs_kd = outputs 43 | base_loss = self.base_criterion(outputs, labels) 44 | if self.distillation_type == 'none': 45 | return base_loss 46 | 47 | if outputs_kd is None: 48 | raise ValueError("When knowledge distillation is enabled, the model is " 49 | "expected to return a Tuple[Tensor, Tensor] with the output of the " 50 | "class_token and the dist_token") 51 | # don't backprop through the teacher 52 | with torch.no_grad(): 53 | teacher_outputs = self.teacher_model(inputs) 54 | 55 | if self.distillation_type == 'soft': 56 | T = self.tau 57 | # taken from https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100 58 | # with slight modifications 59 | distillation_loss = F.kl_div( 60 | F.log_softmax(outputs_kd / T, dim=1), 61 | #We provide the teacher's targets in log probability because we use log_target=True 62 | #(as recommended in pytorch https://github.com/pytorch/pytorch/blob/9324181d0ac7b4f7949a574dbc3e8be30abe7041/torch/nn/functional.py#L2719) 63 | #but it is possible to give just the probabilities and set log_target=False. In our experiments we tried both. 64 | F.log_softmax(teacher_outputs / T, dim=1), 65 | reduction='sum', 66 | log_target=True 67 | ) * (T * T) / outputs_kd.numel() 68 | #We divide by outputs_kd.numel() to have the legacy PyTorch behavior. 69 | #But we also experimented with output_kd.size(0) 70 | #see issue 61 (https://github.com/facebookresearch/deit/issues/61) for more details 71 | elif self.distillation_type == 'hard': 72 | distillation_loss = F.cross_entropy(outputs_kd, teacher_outputs.argmax(dim=1)) 73 | 74 | loss = base_loss * (1 - self.alpha) + distillation_loss * self.alpha 75 | return loss 76 | -------------------------------------------------------------------------------- /swin/data/zipreader.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code was originally obtained from: 3 | https://github.com/microsoft/Swin-Transformer 4 | """ 5 | 6 | # -------------------------------------------------------- 7 | # Swin Transformer 8 | # Copyright (c) 2021 Microsoft 9 | # Licensed under The MIT License [see LICENSE for details] 10 | # Written by Ze Liu 11 | # -------------------------------------------------------- 12 | 13 | import os 14 | import zipfile 15 | import io 16 | import numpy as np 17 | from PIL import Image 18 | from PIL import ImageFile 19 | 20 | ImageFile.LOAD_TRUNCATED_IMAGES = True 21 | 22 | 23 | def is_zip_path(img_or_path): 24 | """judge if this is a zip path""" 25 | return '.zip@' in img_or_path 26 | 27 | 28 | class ZipReader(object): 29 | """A class to read zipped files""" 30 | zip_bank = dict() 31 | 32 | def __init__(self): 33 | super(ZipReader, self).__init__() 34 | 35 | @staticmethod 36 | def get_zipfile(path): 37 | zip_bank = ZipReader.zip_bank 38 | if path not in zip_bank: 39 | zfile = zipfile.ZipFile(path, 'r') 40 | zip_bank[path] = zfile 41 | return zip_bank[path] 42 | 43 | @staticmethod 44 | def split_zip_style_path(path): 45 | pos_at = path.index('@') 46 | assert pos_at != -1, "character '@' is not found from the given path '%s'" % path 47 | 48 | zip_path = path[0: pos_at] 49 |
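# Illustrative example (hypothetical path): 'data/imagenet.zip@train/n01440764/img_001.JPEG' splits into zip_path 'data/imagenet.zip' and folder_path 'train/n01440764/img_001.JPEG'.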
folder_path = path[pos_at + 1:] 50 | folder_path = str.strip(folder_path, '/') 51 | return zip_path, folder_path 52 | 53 | @staticmethod 54 | def list_folder(path): 55 | zip_path, folder_path = ZipReader.split_zip_style_path(path) 56 | 57 | zfile = ZipReader.get_zipfile(zip_path) 58 | folder_list = [] 59 | for file_foler_name in zfile.namelist(): 60 | file_foler_name = str.strip(file_foler_name, '/') 61 | if file_foler_name.startswith(folder_path) and \ 62 | len(os.path.splitext(file_foler_name)[-1]) == 0 and \ 63 | file_foler_name != folder_path: 64 | if len(folder_path) == 0: 65 | folder_list.append(file_foler_name) 66 | else: 67 | folder_list.append(file_foler_name[len(folder_path) + 1:]) 68 | 69 | return folder_list 70 | 71 | @staticmethod 72 | def list_files(path, extension=None): 73 | if extension is None: 74 | extension = ['.*'] 75 | zip_path, folder_path = ZipReader.split_zip_style_path(path) 76 | 77 | zfile = ZipReader.get_zipfile(zip_path) 78 | file_lists = [] 79 | for file_foler_name in zfile.namelist(): 80 | file_foler_name = str.strip(file_foler_name, '/') 81 | if file_foler_name.startswith(folder_path) and \ 82 | str.lower(os.path.splitext(file_foler_name)[-1]) in extension: 83 | if len(folder_path) == 0: 84 | file_lists.append(file_foler_name) 85 | else: 86 | file_lists.append(file_foler_name[len(folder_path) + 1:]) 87 | 88 | return file_lists 89 | 90 | @staticmethod 91 | def read(path): 92 | zip_path, path_img = ZipReader.split_zip_style_path(path) 93 | zfile = ZipReader.get_zipfile(zip_path) 94 | data = zfile.read(path_img) 95 | return data 96 | 97 | @staticmethod 98 | def imread(path): 99 | zip_path, path_img = ZipReader.split_zip_style_path(path) 100 | zfile = ZipReader.get_zipfile(zip_path) 101 | data = zfile.read(path_img) 102 | try: 103 | im = Image.open(io.BytesIO(data)) 104 | except: 105 | print("ERROR IMG LOADED: ", path_img) 106 | random_img = np.random.rand(224, 224, 3) * 255 107 | im = Image.fromarray(np.uint8(random_img)) 108 | return im 109 | -------------------------------------------------------------------------------- /deit/augment.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code was originally obtained from: 3 | https://github.com/facebookresearch/deit 4 | """ 5 | 6 | # Copyright (c) Meta Platforms, Inc. and affiliates. 7 | # All rights reserved. 8 | 9 | """ 10 | 3Augment implementation 11 | Data-augmentation (DA) based on dino DA (https://github.com/facebookresearch/dino) 12 | and timm DA(https://github.com/rwightman/pytorch-image-models) 13 | """ 14 | import torch 15 | from torchvision import transforms 16 | 17 | from timm.data.transforms import _pil_interp, RandomResizedCropAndInterpolation, ToNumpy, ToTensor 18 | 19 | import numpy as np 20 | from torchvision import datasets, transforms 21 | import random 22 | 23 | 24 | 25 | from PIL import ImageFilter, ImageOps 26 | import torchvision.transforms.functional as TF 27 | 28 | 29 | class GaussianBlur(object): 30 | """ 31 | Apply Gaussian Blur to the PIL image. 
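The blur is applied with probability p; the radius is drawn uniformly from [radius_min, radius_max].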
32 | """ 33 | def __init__(self, p=0.1, radius_min=0.1, radius_max=2.): 34 | self.prob = p 35 | self.radius_min = radius_min 36 | self.radius_max = radius_max 37 | 38 | def __call__(self, img): 39 | do_it = random.random() <= self.prob 40 | if not do_it: 41 | return img 42 | 43 | img = img.filter( 44 | ImageFilter.GaussianBlur( 45 | radius=random.uniform(self.radius_min, self.radius_max) 46 | ) 47 | ) 48 | return img 49 | 50 | class Solarization(object): 51 | """ 52 | Apply Solarization to the PIL image. 53 | """ 54 | def __init__(self, p=0.2): 55 | self.p = p 56 | 57 | def __call__(self, img): 58 | if random.random() < self.p: 59 | return ImageOps.solarize(img) 60 | else: 61 | return img 62 | 63 | class gray_scale(object): 64 | """ 65 | Apply Solarization to the PIL image. 66 | """ 67 | def __init__(self, p=0.2): 68 | self.p = p 69 | self.transf = transforms.Grayscale(3) 70 | 71 | def __call__(self, img): 72 | if random.random() < self.p: 73 | return self.transf(img) 74 | else: 75 | return img 76 | 77 | 78 | 79 | class horizontal_flip(object): 80 | """ 81 | Apply Solarization to the PIL image. 82 | """ 83 | def __init__(self, p=0.2,activate_pred=False): 84 | self.p = p 85 | self.transf = transforms.RandomHorizontalFlip(p=1.0) 86 | 87 | def __call__(self, img): 88 | if random.random() < self.p: 89 | return self.transf(img) 90 | else: 91 | return img 92 | 93 | 94 | 95 | def new_data_aug_generator(args = None): 96 | img_size = args.input_size 97 | remove_random_resized_crop = args.src 98 | mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225] 99 | primary_tfl = [] 100 | scale=(0.08, 1.0) 101 | interpolation='bicubic' 102 | if remove_random_resized_crop: 103 | primary_tfl = [ 104 | transforms.Resize(img_size, interpolation=3), 105 | transforms.RandomCrop(img_size, padding=4,padding_mode='reflect'), 106 | transforms.RandomHorizontalFlip() 107 | ] 108 | else: 109 | primary_tfl = [ 110 | RandomResizedCropAndInterpolation( 111 | img_size, scale=scale, interpolation=interpolation), 112 | transforms.RandomHorizontalFlip() 113 | ] 114 | 115 | 116 | secondary_tfl = [transforms.RandomChoice([gray_scale(p=1.0), 117 | Solarization(p=1.0), 118 | GaussianBlur(p=1.0)])] 119 | 120 | if args.color_jitter is not None and not args.color_jitter==0: 121 | secondary_tfl.append(transforms.ColorJitter(args.color_jitter, args.color_jitter, args.color_jitter)) 122 | final_tfl = [ 123 | transforms.ToTensor(), 124 | transforms.Normalize( 125 | mean=torch.tensor(mean), 126 | std=torch.tensor(std)) 127 | ] 128 | return transforms.Compose(primary_tfl+secondary_tfl+final_tfl) 129 | -------------------------------------------------------------------------------- /swin/data/data_simmim_pt.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code was originally obtained from: 3 | https://github.com/microsoft/Swin-Transformer 4 | """ 5 | 6 | # -------------------------------------------------------- 7 | # SimMIM 8 | # Copyright (c) 2021 Microsoft 9 | # Licensed under The MIT License [see LICENSE for details] 10 | # Written by Zhenda Xie 11 | # -------------------------------------------------------- 12 | 13 | import math 14 | import random 15 | import numpy as np 16 | 17 | import torch 18 | import torch.distributed as dist 19 | import torchvision.transforms as T 20 | from torch.utils.data import DataLoader, DistributedSampler 21 | from torch.utils.data._utils.collate import default_collate 22 | from torchvision.datasets import ImageFolder 23 | from timm.data import 
IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 24 | 25 | 26 | class MaskGenerator: 27 | def __init__(self, input_size=192, mask_patch_size=32, model_patch_size=4, mask_ratio=0.6): 28 | self.input_size = input_size 29 | self.mask_patch_size = mask_patch_size 30 | self.model_patch_size = model_patch_size 31 | self.mask_ratio = mask_ratio 32 | 33 | assert self.input_size % self.mask_patch_size == 0 34 | assert self.mask_patch_size % self.model_patch_size == 0 35 | 36 | self.rand_size = self.input_size // self.mask_patch_size 37 | self.scale = self.mask_patch_size // self.model_patch_size 38 | 39 | self.token_count = self.rand_size ** 2 40 | self.mask_count = int(np.ceil(self.token_count * self.mask_ratio)) 41 | 42 | def __call__(self): 43 | mask_idx = np.random.permutation(self.token_count)[:self.mask_count] 44 | mask = np.zeros(self.token_count, dtype=int) 45 | mask[mask_idx] = 1 46 | 47 | mask = mask.reshape((self.rand_size, self.rand_size)) 48 | mask = mask.repeat(self.scale, axis=0).repeat(self.scale, axis=1) 49 | 50 | return mask 51 | 52 | 53 | class SimMIMTransform: 54 | def __init__(self, config): 55 | self.transform_img = T.Compose([ 56 | T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img), 57 | T.RandomResizedCrop(config.DATA.IMG_SIZE, scale=(0.67, 1.), ratio=(3. / 4., 4. / 3.)), 58 | T.RandomHorizontalFlip(), 59 | T.ToTensor(), 60 | T.Normalize(mean=torch.tensor(IMAGENET_DEFAULT_MEAN),std=torch.tensor(IMAGENET_DEFAULT_STD)), 61 | ]) 62 | 63 | if config.MODEL.TYPE in ['swin', 'swinv2']: 64 | model_patch_size=config.MODEL.SWIN.PATCH_SIZE 65 | else: 66 | raise NotImplementedError 67 | 68 | self.mask_generator = MaskGenerator( 69 | input_size=config.DATA.IMG_SIZE, 70 | mask_patch_size=config.DATA.MASK_PATCH_SIZE, 71 | model_patch_size=model_patch_size, 72 | mask_ratio=config.DATA.MASK_RATIO, 73 | ) 74 | 75 | def __call__(self, img): 76 | img = self.transform_img(img) 77 | mask = self.mask_generator() 78 | 79 | return img, mask 80 | 81 | 82 | def collate_fn(batch): 83 | if not isinstance(batch[0][0], tuple): 84 | return default_collate(batch) 85 | else: 86 | batch_num = len(batch) 87 | ret = [] 88 | for item_idx in range(len(batch[0][0])): 89 | if batch[0][0][item_idx] is None: 90 | ret.append(None) 91 | else: 92 | ret.append(default_collate([batch[i][0][item_idx] for i in range(batch_num)])) 93 | ret.append(default_collate([batch[i][1] for i in range(batch_num)])) 94 | return ret 95 | 96 | 97 | def build_loader_simmim(config): 98 | transform = SimMIMTransform(config) 99 | dataset = ImageFolder(config.DATA.DATA_PATH, transform) 100 | 101 | sampler = DistributedSampler(dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=True) 102 | dataloader = DataLoader(dataset, config.DATA.BATCH_SIZE, sampler=sampler, num_workers=config.DATA.NUM_WORKERS, pin_memory=True, drop_last=True, collate_fn=collate_fn) 103 | 104 | return dataloader -------------------------------------------------------------------------------- /resnet_rsb/models.py: -------------------------------------------------------------------------------- 1 | # AugSub 2 | # Copyright (c) 2023-present NAVER Cloud Corp. 
3 | # CC BY-NC-4.0 (https://creativecommons.org/licenses/by-nc/4.0/) 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | import numpy as np 10 | 11 | import timm 12 | from timm.models.registry import register_model 13 | from timm.models.resnet import Bottleneck, default_cfgs 14 | from timm.models.helpers import build_model_with_cfg 15 | 16 | 17 | class augsub_ResNet(timm.models.ResNet): 18 | 19 | def patchify(self, imgs): 20 | """ 21 | imgs: (N, 3, H, W) 22 | x: (N, L, patch_size**2 *3) 23 | """ 24 | p = 32 25 | assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0 26 | 27 | h = w = imgs.shape[2] // p 28 | n = imgs.shape[0] 29 | x = imgs.reshape(shape=(n, 3, h, p, w, p)) 30 | x = torch.einsum('nchpwq->nhwpqc', x) 31 | x = x.reshape(shape=(imgs.shape[0], h * w, p ** 2 * 3)) 32 | 33 | return x 34 | 35 | def unpatchify(self, x): 36 | """ 37 | x: (N, L, patch_size**2 *3) 38 | imgs: (N, 3, H, W) 39 | """ 40 | p = 32 41 | h = w = int(x.shape[1] ** .5) 42 | n = x.shape[0] 43 | assert h * w == x.shape[1] 44 | 45 | x = x.reshape(shape=(n, h, w, p, p, 3)) 46 | x = torch.einsum('nhwpqc->nchpwq', x) 47 | imgs = x.reshape(shape=(n, 3, h * p, h * p)) 48 | 49 | return imgs 50 | 51 | def random_masking(self, x, mask_ratio): 52 | """ 53 | Perform per-sample random masking by per-sample shuffling. 54 | Per-sample shuffling is done by argsort random noise. 55 | x: [N, L, D], sequence 56 | """ 57 | N, L, D = x.shape # batch, length, dim 58 | len_keep = int(L * (1 - mask_ratio)) 59 | 60 | # Normalize mask_noise to [0, 1] 61 | noise = torch.rand([N, L], device=x.device) 62 | 63 | # sort noise for each sample 64 | ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove 65 | ids_restore = torch.argsort(ids_shuffle, dim=1) 66 | 67 | # generate the binary mask: 0 is keep, 1 is remove 68 | mask = torch.ones([N, L], device=x.device) 69 | mask[:, :len_keep] = 0 70 | # unshuffle to get the binary mask 71 | mask = torch.gather(mask, dim=1, index=ids_restore) 72 | 73 | # Zero-out the masked regions 74 | x_masked = x * mask.unsqueeze(-1) 75 | 76 | return x_masked 77 | 78 | def forward(self, x, augsub_type='none', augsub_ratio=0.0): 79 | if augsub_type == 'masking': 80 | if augsub_ratio > 0.0: 81 | x = self.patchify(x) 82 | x = self.random_masking(x, augsub_ratio) 83 | x = self.unpatchify(x) 84 | elif augsub_type != 'none': 85 | raise NotImplementedError('Only support augsub_type == masking') 86 | x = self.forward_features(x) 87 | x = self.global_pool(x) 88 | if self.drop_rate: 89 | x = F.dropout(x, p=float(self.drop_rate), training=self.training) 90 | x = self.fc(x) 91 | return x 92 | 93 | def _create_augsub_resnet(variant, pretrained=False, **kwargs): 94 | return build_model_with_cfg( 95 | augsub_ResNet, variant, pretrained, 96 | default_cfg=default_cfgs[variant], 97 | **kwargs) 98 | 99 | @register_model 100 | def resnet50(pretrained=False, **kwargs): 101 | """Constructs a ResNet-50 model. 102 | """ 103 | model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], **kwargs) 104 | return _create_augsub_resnet('resnet50', pretrained, **model_args) 105 | 106 | @register_model 107 | def resnet101(pretrained=False, **kwargs): 108 | """Constructs a ResNet-101 model. 109 | """ 110 | model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], **kwargs) 111 | return _create_augsub_resnet('resnet101', pretrained, **model_args) 112 | 113 | @register_model 114 | def resnet152(pretrained=False, **kwargs): 115 | """Constructs a ResNet-152 model. 
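Example (illustrative; MaskSub's patchify uses 32x32 patches, so inputs must be square with a side that is a multiple of 32)::
    >>> model = resnet152()
    >>> logits = model(torch.randn(2, 3, 224, 224), augsub_type='masking', augsub_ratio=0.5)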
116 | """ 117 | model_args = dict(block=Bottleneck, layers=[3, 8, 36, 3], **kwargs) 118 | return _create_augsub_resnet('resnet152', pretrained, **model_args) 119 | -------------------------------------------------------------------------------- /swin/kernels/window_process/swin_window_process.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | This code was originally obtained from: 3 | https://github.com/microsoft/Swin-Transformer 4 | */ 5 | 6 | /* 7 | * Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 8 | * 9 | * Licensed under the Apache License, Version 2.0 (the "License"); 10 | * you may not use this file except in compliance with the License. 11 | * You may obtain a copy of the License at 12 | * 13 | * http://www.apache.org/licenses/LICENSE-2.0 14 | * 15 | * Unless required by applicable law or agreed to in writing, software 16 | * distributed under the License is distributed on an "AS IS" BASIS, 17 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 18 | * See the License for the specific language governing permissions and 19 | * limitations under the License. 20 | */ 21 | 22 | #include 23 | #include 24 | 25 | 26 | at::Tensor roll_and_window_partition_forward_cuda( 27 | at::Tensor & input, 28 | //at::Tensor & output, 29 | const int B, 30 | const int H, 31 | const int W, 32 | const int C, 33 | const int shift_size, 34 | const int window_size); 35 | 36 | 37 | at::Tensor roll_and_window_partition_backward_cuda( 38 | at::Tensor & grad_in, 39 | //at::Tensor & grad_out, 40 | const int B, 41 | const int H, 42 | const int W, 43 | const int C, 44 | const int shift_size, 45 | const int window_size); 46 | 47 | 48 | at::Tensor window_merge_and_roll_forward_cuda( 49 | at::Tensor & input, 50 | //at::Tensor & output, 51 | const int B, 52 | const int H, 53 | const int W, 54 | const int C, 55 | const int shift_size, 56 | const int window_size); 57 | 58 | at::Tensor window_merge_and_roll_backward_cuda( 59 | at::Tensor & grad_in, 60 | //at::Tensor & grad_out, 61 | const int B, 62 | const int H, 63 | const int W, 64 | const int C, 65 | const int shift_size, 66 | const int window_size); 67 | 68 | 69 | #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") 70 | #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") 71 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 72 | 73 | 74 | 75 | at::Tensor roll_and_window_partition_forward( 76 | at::Tensor & input, 77 | //at::Tensor & output, 78 | const int B, 79 | const int H, 80 | const int W, 81 | const int C, 82 | const int shift_size, 83 | const int window_size){ 84 | CHECK_INPUT(input); 85 | return roll_and_window_partition_forward_cuda(input, B, H, W, C, shift_size, window_size); 86 | } 87 | 88 | 89 | at::Tensor roll_and_window_partition_backward( 90 | at::Tensor & grad_in, 91 | //at::Tensor & grad_out, 92 | const int B, 93 | const int H, 94 | const int W, 95 | const int C, 96 | const int shift_size, 97 | const int window_size){ 98 | CHECK_INPUT(grad_in); 99 | return roll_and_window_partition_backward_cuda(grad_in, B, H, W, C, shift_size, window_size); 100 | } 101 | 102 | 103 | at::Tensor window_merge_and_roll_forward( 104 | at::Tensor & input, 105 | //at::Tensor & output, 106 | const int B, 107 | const int H, 108 | const int W, 109 | const int C, 110 | const int shift_size, 111 | const int window_size){ 112 | CHECK_INPUT(input); 113 | return window_merge_and_roll_forward_cuda(input, B, H, W, C, shift_size, 
window_size); 114 | } 115 | 116 | 117 | at::Tensor window_merge_and_roll_backward( 118 | at::Tensor & grad_in, 119 | //at::Tensor & grad_out, 120 | const int B, 121 | const int H, 122 | const int W, 123 | const int C, 124 | const int shift_size, 125 | const int window_size){ 126 | CHECK_INPUT(grad_in); 127 | return window_merge_and_roll_backward_cuda(grad_in, B, H, W, C, shift_size, window_size); 128 | } 129 | 130 | 131 | 132 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 133 | m.def("roll_and_window_partition_forward", &roll_and_window_partition_forward, "torch.roll and window_partition."); 134 | m.def("roll_and_window_partition_backward", &roll_and_window_partition_backward, "torch.roll and window_partition."); 135 | m.def("window_merge_and_roll_forward", &window_merge_and_roll_forward, "window merge and torch.roll."); 136 | m.def("window_merge_and_roll_backward", &window_merge_and_roll_backward, "window merge and torch.roll."); 137 | } -------------------------------------------------------------------------------- /mae/util/pos_embed.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code was originally obtained from: 3 | https://github.com/facebookresearch/mae 4 | """ 5 | 6 | # Copyright (c) Meta Platforms, Inc. and affiliates. 7 | # All rights reserved. 8 | 9 | # This source code is licensed under the license found in the 10 | # LICENSE file in the root directory of this source tree. 11 | # -------------------------------------------------------- 12 | # Position embedding utils 13 | # -------------------------------------------------------- 14 | 15 | import numpy as np 16 | 17 | import torch 18 | 19 | # -------------------------------------------------------- 20 | # 2D sine-cosine position embedding 21 | # References: 22 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py 23 | # MoCo v3: https://github.com/facebookresearch/moco-v3 24 | # -------------------------------------------------------- 25 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): 26 | """ 27 | grid_size: int of the grid height and width 28 | return: 29 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 30 | """ 31 | grid_h = np.arange(grid_size, dtype=np.float32) 32 | grid_w = np.arange(grid_size, dtype=np.float32) 33 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 34 | grid = np.stack(grid, axis=0) 35 | 36 | grid = grid.reshape([2, 1, grid_size, grid_size]) 37 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 38 | if cls_token: 39 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 40 | return pos_embed 41 | 42 | 43 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 44 | assert embed_dim % 2 == 0 45 | 46 | # use half of dimensions to encode grid_h 47 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 48 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 49 | 50 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 51 | return emb 52 | 53 | 54 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 55 | """ 56 | embed_dim: output dimension for each position 57 | pos: a list of positions to be encoded: size (M,) 58 | out: (M, D) 59 | """ 60 | assert embed_dim % 2 == 0 61 | omega = np.arange(embed_dim // 2, dtype=float) # np.float alias was removed in NumPy 1.24; builtin float keeps the original float64 behavior 62 | omega /= embed_dim / 2. 63 | omega = 1.
/ 10000**omega # (D/2,) 64 | 65 | pos = pos.reshape(-1) # (M,) 66 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 67 | 68 | emb_sin = np.sin(out) # (M, D/2) 69 | emb_cos = np.cos(out) # (M, D/2) 70 | 71 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 72 | return emb 73 | 74 | 75 | # -------------------------------------------------------- 76 | # Interpolate position embeddings for high-resolution 77 | # References: 78 | # DeiT: https://github.com/facebookresearch/deit 79 | # -------------------------------------------------------- 80 | def interpolate_pos_embed(model, checkpoint_model): 81 | if 'pos_embed' in checkpoint_model: 82 | pos_embed_checkpoint = checkpoint_model['pos_embed'] 83 | embedding_size = pos_embed_checkpoint.shape[-1] 84 | num_patches = model.patch_embed.num_patches 85 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches 86 | # height (== width) for the checkpoint position embedding 87 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 88 | # height (== width) for the new position embedding 89 | new_size = int(num_patches ** 0.5) 90 | # class_token and dist_token are kept unchanged 91 | if orig_size != new_size: 92 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) 93 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 94 | # only the position tokens are interpolated 95 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 96 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 97 | pos_tokens = torch.nn.functional.interpolate( 98 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 99 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 100 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 101 | checkpoint_model['pos_embed'] = new_pos_embed 102 | -------------------------------------------------------------------------------- /deit/datasets.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code was originally obtained from: 3 | https://github.com/facebookresearch/deit 4 | """ 5 | 6 | # Copyright (c) 2015-present, Facebook, Inc. 7 | # All rights reserved. 
8 | import os 9 | import json 10 | 11 | from torchvision import datasets, transforms 12 | from torchvision.datasets.folder import ImageFolder, default_loader 13 | 14 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 15 | from timm.data import create_transform 16 | 17 | 18 | class INatDataset(ImageFolder): 19 | def __init__(self, root, train=True, year=2018, transform=None, target_transform=None, 20 | category='name', loader=default_loader): 21 | self.transform = transform 22 | self.loader = loader 23 | self.target_transform = target_transform 24 | self.year = year 25 | # assert category in ['kingdom','phylum','class','order','supercategory','family','genus','name'] 26 | path_json = os.path.join(root, f'{"train" if train else "val"}{year}.json') 27 | with open(path_json) as json_file: 28 | data = json.load(json_file) 29 | 30 | with open(os.path.join(root, 'categories.json')) as json_file: 31 | data_catg = json.load(json_file) 32 | 33 | path_json_for_targeter = os.path.join(root, f"train{year}.json") 34 | 35 | with open(path_json_for_targeter) as json_file: 36 | data_for_targeter = json.load(json_file) 37 | 38 | targeter = {} 39 | indexer = 0 40 | for elem in data_for_targeter['annotations']: 41 | king = [] 42 | king.append(data_catg[int(elem['category_id'])][category]) 43 | if king[0] not in targeter.keys(): 44 | targeter[king[0]] = indexer 45 | indexer += 1 46 | self.nb_classes = len(targeter) 47 | 48 | self.samples = [] 49 | for elem in data['images']: 50 | cut = elem['file_name'].split('/') 51 | target_current = int(cut[2]) 52 | path_current = os.path.join(root, cut[0], cut[2], cut[3]) 53 | 54 | categors = data_catg[target_current] 55 | target_current_true = targeter[categors[category]] 56 | self.samples.append((path_current, target_current_true)) 57 | 58 | # __getitem__ and __len__ inherited from ImageFolder 59 | 60 | 61 | def build_dataset(is_train, args): 62 | transform = build_transform(is_train, args) 63 | 64 | if args.data_set == 'CIFAR': 65 | dataset = datasets.CIFAR100(args.data_path, train=is_train, transform=transform) 66 | nb_classes = 100 67 | elif args.data_set == 'IMNET': 68 | root = os.path.join(args.data_path, 'train' if is_train else 'val') 69 | dataset = datasets.ImageFolder(root, transform=transform) 70 | nb_classes = 1000 71 | elif args.data_set == 'INAT': 72 | dataset = INatDataset(args.data_path, train=is_train, year=2018, 73 | category=args.inat_category, transform=transform) 74 | nb_classes = dataset.nb_classes 75 | elif args.data_set == 'INAT19': 76 | dataset = INatDataset(args.data_path, train=is_train, year=2019, 77 | category=args.inat_category, transform=transform) 78 | nb_classes = dataset.nb_classes 79 | 80 | return dataset, nb_classes 81 | 82 | 83 | def build_transform(is_train, args): 84 | resize_im = args.input_size > 32 85 | if is_train: 86 | # this should always dispatch to transforms_imagenet_train 87 | transform = create_transform( 88 | input_size=args.input_size, 89 | is_training=True, 90 | color_jitter=args.color_jitter, 91 | auto_augment=args.aa, 92 | interpolation=args.train_interpolation, 93 | re_prob=args.reprob, 94 | re_mode=args.remode, 95 | re_count=args.recount, 96 | ) 97 | if not resize_im: 98 | # replace RandomResizedCropAndInterpolation with 99 | # RandomCrop 100 | transform.transforms[0] = transforms.RandomCrop( 101 | args.input_size, padding=4) 102 | return transform 103 | 104 | t = [] 105 | if resize_im: 106 | size = int(args.input_size / args.eval_crop_ratio) 107 | t.append( 108 | transforms.Resize(size, 
interpolation=3), # to maintain same ratio w.r.t. 224 images 109 | ) 110 | t.append(transforms.CenterCrop(args.input_size)) 111 | 112 | t.append(transforms.ToTensor()) 113 | t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)) 114 | return transforms.Compose(t) 115 | -------------------------------------------------------------------------------- /deit/engine.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code was originally obtained from: 3 | https://github.com/facebookresearch/deit 4 | """ 5 | 6 | # Copyright (c) 2015-present, Facebook, Inc. 7 | # All rights reserved. 8 | """ 9 | Train and eval functions used in main.py 10 | """ 11 | import math 12 | import sys 13 | from typing import Iterable, Optional 14 | 15 | import torch 16 | 17 | from timm.data import Mixup 18 | from timm.utils import accuracy, ModelEma 19 | import torch.nn.functional as F 20 | 21 | from losses import DistillationLoss 22 | import utils 23 | 24 | 25 | def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss, 26 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 27 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0, 28 | model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None, 29 | set_training_mode=True, args = None): 30 | model.train(set_training_mode) 31 | metric_logger = utils.MetricLogger(delimiter=" ") 32 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 33 | header = 'Epoch: [{}]'.format(epoch) 34 | print_freq = 100 35 | 36 | for samples, targets in metric_logger.log_every(data_loader, print_freq, header): 37 | samples = samples.to(device, non_blocking=True) 38 | targets = targets.to(device, non_blocking=True) 39 | 40 | if mixup_fn is not None: 41 | samples, targets = mixup_fn(samples, targets) 42 | 43 | if args.bce_loss: 44 | targets = targets.gt(0.0).type(targets.dtype) 45 | 46 | with torch.cuda.amp.autocast(): 47 | outputs = model(samples) 48 | loss = criterion(samples, outputs, targets) 49 | loss_value = loss.item() 50 | 51 | optimizer.zero_grad() 52 | 53 | if args.augsub != 'none': 54 | # Main model backward 55 | loss_scaler(loss/2, optimizer, clip_grad=max_norm, parameters=model.parameters(), 56 | create_graph=False, update_grad=False) 57 | 58 | # Sub-model forward 59 | outputs_sub = model(samples, augsub_type=args.augsub, augsub_ratio=args.augsub_ratio) 60 | target_sub = F.sigmoid(outputs.detach()) if args.bce_loss else F.softmax(outputs.detach(), dim=-1) 61 | loss = criterion(samples, outputs_sub, target_sub) 62 | 63 | # Sub-model backward 64 | loss_scaler(loss/2, optimizer, clip_grad=max_norm, parameters=model.parameters(), 65 | create_graph=False, update_grad=True) 66 | else: 67 | loss_scaler(loss, optimizer, clip_grad=max_norm, parameters=model.parameters(), 68 | create_graph=False, update_grad=True) 69 | 70 | torch.cuda.synchronize() 71 | if model_ema is not None: 72 | model_ema.update(model) 73 | 74 | metric_logger.update(loss=loss_value) 75 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 76 | # gather the stats from all processes 77 | metric_logger.synchronize_between_processes() 78 | print("Averaged stats:", metric_logger) 79 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 80 | 81 | 82 | @torch.no_grad() 83 | def evaluate(data_loader, model, device): 84 | criterion = torch.nn.CrossEntropyLoss() 85 | 86 | metric_logger = utils.MetricLogger(delimiter=" ") 87 | header = 'Test:' 88 | 89 | # 
switch to evaluation mode 90 | model.eval() 91 | 92 | for images, target in metric_logger.log_every(data_loader, 10, header): 93 | images = images.to(device, non_blocking=True) 94 | target = target.to(device, non_blocking=True) 95 | 96 | # compute output 97 | with torch.cuda.amp.autocast(): 98 | output = model(images) 99 | loss = criterion(output, target) 100 | 101 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 102 | 103 | batch_size = images.shape[0] 104 | metric_logger.update(loss=loss.item()) 105 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) 106 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) 107 | # gather the stats from all processes 108 | metric_logger.synchronize_between_processes() 109 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}' 110 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss)) 111 | 112 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 113 | -------------------------------------------------------------------------------- /swin/data/data_simmim_ft.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code was originally obtained from: 3 | https://github.com/microsoft/Swin-Transformer 4 | """ 5 | 6 | # -------------------------------------------------------- 7 | # SimMIM 8 | # Copyright (c) 2021 Microsoft 9 | # Licensed under The MIT License [see LICENSE for details] 10 | # Written by Zhenda Xie 11 | # -------------------------------------------------------- 12 | 13 | import os 14 | import torch.distributed as dist 15 | from torch.utils.data import DataLoader, DistributedSampler 16 | from torchvision import datasets, transforms 17 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 18 | from timm.data import Mixup 19 | from timm.data import create_transform 20 | from timm.data.transforms import _pil_interp 21 | 22 | 23 | def build_loader_finetune(config): 24 | config.defrost() 25 | dataset_train, config.MODEL.NUM_CLASSES = build_dataset(is_train=True, config=config) 26 | config.freeze() 27 | dataset_val, _ = build_dataset(is_train=False, config=config) 28 | 29 | num_tasks = dist.get_world_size() 30 | global_rank = dist.get_rank() 31 | sampler_train = DistributedSampler( 32 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 33 | ) 34 | sampler_val = DistributedSampler( 35 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False 36 | ) 37 | 38 | data_loader_train = DataLoader( 39 | dataset_train, sampler=sampler_train, 40 | batch_size=config.DATA.BATCH_SIZE, 41 | num_workers=config.DATA.NUM_WORKERS, 42 | pin_memory=config.DATA.PIN_MEMORY, 43 | drop_last=True, 44 | ) 45 | 46 | data_loader_val = DataLoader( 47 | dataset_val, sampler=sampler_val, 48 | batch_size=config.DATA.BATCH_SIZE, 49 | num_workers=config.DATA.NUM_WORKERS, 50 | pin_memory=config.DATA.PIN_MEMORY, 51 | drop_last=False, 52 | ) 53 | 54 | # setup mixup / cutmix 55 | mixup_fn = None 56 | mixup_active = config.AUG.MIXUP > 0 or config.AUG.CUTMIX > 0. 
or config.AUG.CUTMIX_MINMAX is not None 57 | if mixup_active: 58 | mixup_fn = Mixup( 59 | mixup_alpha=config.AUG.MIXUP, cutmix_alpha=config.AUG.CUTMIX, cutmix_minmax=config.AUG.CUTMIX_MINMAX, 60 | prob=config.AUG.MIXUP_PROB, switch_prob=config.AUG.MIXUP_SWITCH_PROB, mode=config.AUG.MIXUP_MODE, 61 | label_smoothing=config.MODEL.LABEL_SMOOTHING, num_classes=config.MODEL.NUM_CLASSES) 62 | 63 | return dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn 64 | 65 | 66 | def build_dataset(is_train, config): 67 | transform = build_transform(is_train, config) 68 | 69 | if config.DATA.DATASET == 'imagenet': 70 | prefix = 'train' if is_train else 'val' 71 | root = os.path.join(config.DATA.DATA_PATH, prefix) 72 | dataset = datasets.ImageFolder(root, transform=transform) 73 | nb_classes = 1000 74 | else: 75 | raise NotImplementedError("We only support ImageNet Now.") 76 | 77 | return dataset, nb_classes 78 | 79 | 80 | def build_transform(is_train, config): 81 | resize_im = config.DATA.IMG_SIZE > 32 82 | if is_train: 83 | # this should always dispatch to transforms_imagenet_train 84 | transform = create_transform( 85 | input_size=config.DATA.IMG_SIZE, 86 | is_training=True, 87 | color_jitter=config.AUG.COLOR_JITTER if config.AUG.COLOR_JITTER > 0 else None, 88 | auto_augment=config.AUG.AUTO_AUGMENT if config.AUG.AUTO_AUGMENT != 'none' else None, 89 | re_prob=config.AUG.REPROB, 90 | re_mode=config.AUG.REMODE, 91 | re_count=config.AUG.RECOUNT, 92 | interpolation=config.DATA.INTERPOLATION, 93 | ) 94 | if not resize_im: 95 | # replace RandomResizedCropAndInterpolation with 96 | # RandomCrop 97 | transform.transforms[0] = transforms.RandomCrop(config.DATA.IMG_SIZE, padding=4) 98 | return transform 99 | 100 | t = [] 101 | if resize_im: 102 | if config.TEST.CROP: 103 | size = int((256 / 224) * config.DATA.IMG_SIZE) 104 | t.append( 105 | transforms.Resize(size, interpolation=_pil_interp(config.DATA.INTERPOLATION)), 106 | # to maintain same ratio w.r.t. 224 images 107 | ) 108 | t.append(transforms.CenterCrop(config.DATA.IMG_SIZE)) 109 | else: 110 | t.append( 111 | transforms.Resize((config.DATA.IMG_SIZE, config.DATA.IMG_SIZE), 112 | interpolation=_pil_interp(config.DATA.INTERPOLATION)) 113 | ) 114 | 115 | t.append(transforms.ToTensor()) 116 | t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)) 117 | return transforms.Compose(t) 118 | -------------------------------------------------------------------------------- /mae/models_vit.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code was originally obtained from: 3 | https://github.com/facebookresearch/mae 4 | """ 5 | 6 | # Copyright (c) Meta Platforms, Inc. and affiliates. 7 | # All rights reserved. 8 | 9 | # This source code is licensed under the license found in the 10 | # LICENSE file in the root directory of this source tree. 
11 | # -------------------------------------------------------- 12 | # References: 13 | # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm 14 | # DeiT: https://github.com/facebookresearch/deit 15 | # -------------------------------------------------------- 16 | 17 | from functools import partial 18 | 19 | import torch 20 | import torch.nn as nn 21 | import torch.nn.functional as F 22 | 23 | import timm.models.vision_transformer 24 | from timm.models.layers import DropPath 25 | 26 | 27 | class VisionTransformer(timm.models.vision_transformer.VisionTransformer): 28 | """ Vision Transformer with support for global average pooling 29 | """ 30 | def __init__(self, global_pool=False, **kwargs): 31 | super(VisionTransformer, self).__init__(**kwargs) 32 | 33 | self.global_pool = global_pool 34 | if self.global_pool: 35 | norm_layer = kwargs['norm_layer'] 36 | embed_dim = kwargs['embed_dim'] 37 | self.fc_norm = norm_layer(embed_dim) 38 | 39 | del self.norm # remove the original norm 40 | 41 | def forward_features(self, x): 42 | B = x.shape[0] 43 | x = self.patch_embed(x) 44 | 45 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 46 | x = torch.cat((cls_tokens, x), dim=1) 47 | x = x + self.pos_embed 48 | x = self.pos_drop(x) 49 | 50 | for blk in self.blocks: 51 | x = blk(x) 52 | 53 | if self.global_pool: 54 | x = x[:, 1:, :].mean(dim=1) # global pool without cls token 55 | outcome = self.fc_norm(x) 56 | else: 57 | x = self.norm(x) 58 | outcome = x[:, 0] 59 | 60 | return outcome 61 | 62 | class augsub_VisionTransformer(VisionTransformer): 63 | 64 | def __init__(self, **kwargs): 65 | super(augsub_VisionTransformer, self).__init__(**kwargs) 66 | 67 | # Do not use nn.Identity for all transformer blocks 68 | for block in self.blocks: 69 | if isinstance(block.drop_path, nn.Identity): 70 | block.drop_path = DropPath(0.) 
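# The loop above swaps nn.Identity for DropPath(0.) so every block carries a real DropPath module whose drop probability can be adjusted later (e.g., for the 'droppath' AugSub variant that forward() below reserves but does not yet implement).
# random_masking below is the MAE-style shuffle-and-gather: the cls token is set aside, patch tokens are ranked by per-sample uniform noise, and only the first len_keep tokens are gathered back, so the sub-branch runs on a shortened token sequence rather than on zeroed-out tokens (unlike the ResNet variant, which keeps the full spatial grid and zeroes masked patches).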
71 | 72 | def random_masking(self, x, mask_ratio): 73 | cls_token = x[:, :1, :] 74 | x = x[:, 1:, :] 75 | N, L, D = x.shape # batch, length, dim 76 | len_keep = int(L * (1 - mask_ratio)) 77 | 78 | noise = torch.rand([N, L], device=x.device) 79 | 80 | # sort noise for each sample 81 | ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove 82 | 83 | # keep the first subset 84 | ids_keep = ids_shuffle[:, :len_keep] 85 | x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) 86 | x_masked = torch.cat([cls_token, x_masked], dim=1) 87 | 88 | return x_masked 89 | 90 | def forward_features(self, x, mask_ratio=0.0): 91 | B = x.shape[0] 92 | x = self.patch_embed(x) 93 | 94 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 95 | x = torch.cat((cls_tokens, x), dim=1) 96 | x = x + self.pos_embed 97 | 98 | if mask_ratio > 0.0: 99 | x = self.random_masking(x, mask_ratio) 100 | 101 | x = self.pos_drop(x) 102 | 103 | for blk in self.blocks: 104 | x = blk(x) 105 | 106 | if self.global_pool: 107 | x = x[:, 1:, :].mean(dim=1) # global pool without cls token 108 | outcome = self.fc_norm(x) 109 | else: 110 | x = self.norm(x) 111 | outcome = x[:, 0] 112 | 113 | return outcome 114 | 115 | def forward(self, x, augsub_type='none', augsub_ratio=0.0): 116 | if augsub_type == 'dropout': 117 | raise NotImplementedError('Augdrop is not implemented yet') 118 | elif augsub_type == 'droppath': 119 | raise NotImplementedError('Augpath is not implemented yet') 120 | elif augsub_type == 'masking': 121 | x = self.forward_features(x, augsub_ratio) 122 | else: 123 | x = self.forward_features(x) 124 | x = self.head(x) 125 | 126 | return x 127 | 128 | def vit_base_patch16(**kwargs): 129 | model = augsub_VisionTransformer( 130 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 131 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 132 | return model 133 | 134 | 135 | def vit_large_patch16(**kwargs): 136 | model = augsub_VisionTransformer( 137 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, 138 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 139 | return model 140 | 141 | 142 | def vit_huge_patch14(**kwargs): 143 | model = augsub_VisionTransformer( 144 | patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, qkv_bias=True, 145 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 146 | return model -------------------------------------------------------------------------------- /mae/engine_finetune.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code was originally obtained from: 3 | https://github.com/facebookresearch/mae 4 | """ 5 | 6 | # Copyright (c) Meta Platforms, Inc. and affiliates. 7 | # All rights reserved. 8 | 9 | # This source code is licensed under the license found in the 10 | # LICENSE file in the root directory of this source tree. 
11 | # -------------------------------------------------------- 12 | # References: 13 | # DeiT: https://github.com/facebookresearch/deit 14 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 15 | # -------------------------------------------------------- 16 | 17 | import math 18 | import sys 19 | from typing import Iterable, Optional 20 | 21 | import torch 22 | import torch.nn.functional as F 23 | 24 | from timm.data import Mixup 25 | from timm.utils import accuracy 26 | 27 | import util.misc as misc 28 | import util.lr_sched as lr_sched 29 | 30 | 31 | def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module, 32 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 33 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0, 34 | mixup_fn: Optional[Mixup] = None, log_writer=None, 35 | args=None): 36 | model.train(True) 37 | metric_logger = misc.MetricLogger(delimiter=" ") 38 | metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}')) 39 | header = 'Epoch: [{}]'.format(epoch) 40 | print_freq = 100 41 | 42 | accum_iter = args.accum_iter 43 | 44 | optimizer.zero_grad() 45 | 46 | if log_writer is not None: 47 | print('log_dir: {}'.format(log_writer.log_dir)) 48 | 49 | for data_iter_step, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 50 | 51 | # we use a per iteration (instead of per epoch) lr scheduler 52 | if data_iter_step % accum_iter == 0: 53 | lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args) 54 | 55 | samples = samples.to(device, non_blocking=True) 56 | targets = targets.to(device, non_blocking=True) 57 | 58 | if mixup_fn is not None: 59 | samples, targets = mixup_fn(samples, targets) 60 | 61 | with torch.cuda.amp.autocast(): 62 | outputs = model(samples) 63 | loss = criterion(outputs, targets) 64 | loss_value = loss.item() 65 | 66 | if args.augsub != 'none': 67 | # Main model backward 68 | loss /= accum_iter 69 | loss_scaler(loss/2, optimizer, clip_grad=max_norm, parameters=model.parameters(), 70 | create_graph=False, update_grad=False) 71 | 72 | # Sub-model forward 73 | outputs_sub = model(samples, augsub_type=args.augsub, augsub_ratio=args.augsub_ratio) 74 | loss = criterion(outputs_sub, F.softmax(outputs.detach(), dim=-1)) 75 | 76 | # Sub-model backward 77 | loss /= accum_iter 78 | loss_scaler(loss/2, optimizer, clip_grad=max_norm, parameters=model.parameters(), 79 | create_graph=False, update_grad=(data_iter_step + 1) % accum_iter == 0) 80 | else: 81 | loss /= accum_iter 82 | loss_scaler(loss, optimizer, clip_grad=max_norm, 83 | parameters=model.parameters(), create_graph=False, 84 | update_grad=(data_iter_step + 1) % accum_iter == 0) 85 | 86 | if (data_iter_step + 1) % accum_iter == 0: 87 | optimizer.zero_grad() 88 | 89 | torch.cuda.synchronize() 90 | 91 | metric_logger.update(loss=loss_value) 92 | min_lr = 10. 93 | max_lr = 0. 94 | for group in optimizer.param_groups: 95 | min_lr = min(min_lr, group["lr"]) 96 | max_lr = max(max_lr, group["lr"]) 97 | 98 | metric_logger.update(lr=max_lr) 99 | 100 | loss_value_reduce = misc.all_reduce_mean(loss_value) 101 | if log_writer is not None and (data_iter_step + 1) % accum_iter == 0: 102 | """ We use epoch_1000x as the x-axis in tensorboard. 103 | This calibrates different curves when batch size changes. 
104 | """ 105 | epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000) 106 | log_writer.add_scalar('loss', loss_value_reduce, epoch_1000x) 107 | log_writer.add_scalar('lr', max_lr, epoch_1000x) 108 | 109 | # gather the stats from all processes 110 | metric_logger.synchronize_between_processes() 111 | print("Averaged stats:", metric_logger) 112 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 113 | 114 | 115 | @torch.no_grad() 116 | def evaluate(data_loader, model, device): 117 | criterion = torch.nn.CrossEntropyLoss() 118 | 119 | metric_logger = misc.MetricLogger(delimiter=" ") 120 | header = 'Test:' 121 | 122 | # switch to evaluation mode 123 | model.eval() 124 | 125 | for batch in metric_logger.log_every(data_loader, 10, header): 126 | images = batch[0] 127 | target = batch[-1] 128 | images = images.to(device, non_blocking=True) 129 | target = target.to(device, non_blocking=True) 130 | 131 | # compute output 132 | with torch.cuda.amp.autocast(): 133 | output = model(images) 134 | loss = criterion(output, target) 135 | 136 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 137 | 138 | batch_size = images.shape[0] 139 | metric_logger.update(loss=loss.item()) 140 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) 141 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) 142 | # gather the stats from all processes 143 | metric_logger.synchronize_between_processes() 144 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}' 145 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss)) 146 | 147 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} -------------------------------------------------------------------------------- /swin/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code was originally obtained from: 3 | https://github.com/microsoft/Swin-Transformer 4 | """ 5 | 6 | # -------------------------------------------------------- 7 | # Swin Transformer 8 | # Copyright (c) 2021 Microsoft 9 | # Licensed under The MIT License [see LICENSE for details] 10 | # Written by Ze Liu 11 | # -------------------------------------------------------- 12 | 13 | import bisect 14 | 15 | import torch 16 | from timm.scheduler.cosine_lr import CosineLRScheduler 17 | from timm.scheduler.step_lr import StepLRScheduler 18 | from timm.scheduler.scheduler import Scheduler 19 | 20 | 21 | def build_scheduler(config, optimizer, n_iter_per_epoch): 22 | num_steps = int(config.TRAIN.EPOCHS * n_iter_per_epoch) 23 | warmup_steps = int(config.TRAIN.WARMUP_EPOCHS * n_iter_per_epoch) 24 | decay_steps = int(config.TRAIN.LR_SCHEDULER.DECAY_EPOCHS * n_iter_per_epoch) 25 | multi_steps = [i * n_iter_per_epoch for i in config.TRAIN.LR_SCHEDULER.MULTISTEPS] 26 | 27 | lr_scheduler = None 28 | if config.TRAIN.LR_SCHEDULER.NAME == 'cosine': 29 | lr_scheduler = CosineLRScheduler( 30 | optimizer, 31 | t_initial=(num_steps - warmup_steps) if config.TRAIN.LR_SCHEDULER.WARMUP_PREFIX else num_steps, 32 | t_mul=1., 33 | lr_min=config.TRAIN.MIN_LR, 34 | warmup_lr_init=config.TRAIN.WARMUP_LR, 35 | warmup_t=warmup_steps, 36 | cycle_limit=1, 37 | t_in_epochs=False, 38 | warmup_prefix=config.TRAIN.LR_SCHEDULER.WARMUP_PREFIX, 39 | ) 40 | elif config.TRAIN.LR_SCHEDULER.NAME == 'linear': 41 | lr_scheduler = LinearLRScheduler( 42 | optimizer, 43 | t_initial=num_steps, 44 | lr_min_rate=0.01, 45 | 
warmup_lr_init=config.TRAIN.WARMUP_LR, 46 | warmup_t=warmup_steps, 47 | t_in_epochs=False, 48 | ) 49 | elif config.TRAIN.LR_SCHEDULER.NAME == 'step': 50 | lr_scheduler = StepLRScheduler( 51 | optimizer, 52 | decay_t=decay_steps, 53 | decay_rate=config.TRAIN.LR_SCHEDULER.DECAY_RATE, 54 | warmup_lr_init=config.TRAIN.WARMUP_LR, 55 | warmup_t=warmup_steps, 56 | t_in_epochs=False, 57 | ) 58 | elif config.TRAIN.LR_SCHEDULER.NAME == 'multistep': 59 | lr_scheduler = MultiStepLRScheduler( 60 | optimizer, 61 | milestones=multi_steps, 62 | gamma=config.TRAIN.LR_SCHEDULER.GAMMA, 63 | warmup_lr_init=config.TRAIN.WARMUP_LR, 64 | warmup_t=warmup_steps, 65 | t_in_epochs=False, 66 | ) 67 | 68 | return lr_scheduler 69 | 70 | 71 | class LinearLRScheduler(Scheduler): 72 | def __init__(self, 73 | optimizer: torch.optim.Optimizer, 74 | t_initial: int, 75 | lr_min_rate: float, 76 | warmup_t=0, 77 | warmup_lr_init=0., 78 | t_in_epochs=True, 79 | noise_range_t=None, 80 | noise_pct=0.67, 81 | noise_std=1.0, 82 | noise_seed=42, 83 | initialize=True, 84 | ) -> None: 85 | super().__init__( 86 | optimizer, param_group_field="lr", 87 | noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed, 88 | initialize=initialize) 89 | 90 | self.t_initial = t_initial 91 | self.lr_min_rate = lr_min_rate 92 | self.warmup_t = warmup_t 93 | self.warmup_lr_init = warmup_lr_init 94 | self.t_in_epochs = t_in_epochs 95 | if self.warmup_t: 96 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] 97 | super().update_groups(self.warmup_lr_init) 98 | else: 99 | self.warmup_steps = [1 for _ in self.base_values] 100 | 101 | def _get_lr(self, t): 102 | if t < self.warmup_t: 103 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] 104 | else: 105 | t = t - self.warmup_t 106 | total_t = self.t_initial - self.warmup_t 107 | lrs = [v - ((v - v * self.lr_min_rate) * (t / total_t)) for v in self.base_values] 108 | return lrs 109 | 110 | def get_epoch_values(self, epoch: int): 111 | if self.t_in_epochs: 112 | return self._get_lr(epoch) 113 | else: 114 | return None 115 | 116 | def get_update_values(self, num_updates: int): 117 | if not self.t_in_epochs: 118 | return self._get_lr(num_updates) 119 | else: 120 | return None 121 | 122 | 123 | class MultiStepLRScheduler(Scheduler): 124 | def __init__(self, optimizer: torch.optim.Optimizer, milestones, gamma=0.1, warmup_t=0, warmup_lr_init=0, t_in_epochs=True) -> None: 125 | super().__init__(optimizer, param_group_field="lr") 126 | 127 | self.milestones = milestones 128 | self.gamma = gamma 129 | self.warmup_t = warmup_t 130 | self.warmup_lr_init = warmup_lr_init 131 | self.t_in_epochs = t_in_epochs 132 | if self.warmup_t: 133 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] 134 | super().update_groups(self.warmup_lr_init) 135 | else: 136 | self.warmup_steps = [1 for _ in self.base_values] 137 | 138 | assert self.warmup_t <= min(self.milestones) 139 | 140 | def _get_lr(self, t): 141 | if t < self.warmup_t: 142 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] 143 | else: 144 | lrs = [v * (self.gamma ** bisect.bisect_right(self.milestones, t)) for v in self.base_values] 145 | return lrs 146 | 147 | def get_epoch_values(self, epoch: int): 148 | if self.t_in_epochs: 149 | return self._get_lr(epoch) 150 | else: 151 | return None 152 | 153 | def get_update_values(self, num_updates: int): 154 | if not self.t_in_epochs: 155 | return self._get_lr(num_updates) 156 | else: 157 | return 
None 158 | -------------------------------------------------------------------------------- /swin/optimizer.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code was originally obtained from: 3 | https://github.com/microsoft/Swin-Transformer 4 | """ 5 | 6 | # -------------------------------------------------------- 7 | # Swin Transformer 8 | # Copyright (c) 2021 Microsoft 9 | # Licensed under The MIT License [see LICENSE for details] 10 | # Written by Ze Liu 11 | # -------------------------------------------------------- 12 | 13 | from functools import partial 14 | from torch import optim as optim 15 | 16 | try: 17 | from apex.optimizers import FusedAdam, FusedLAMB 18 | except: 19 | FusedAdam = None 20 | FusedLAMB = None 21 | print("To use FusedLAMB or FusedAdam, please install apex.") 22 | 23 | 24 | def build_optimizer(config, model, simmim=False, is_pretrain=False): 25 | """ 26 | Build optimizer, set weight decay of normalization to 0 by default. 27 | """ 28 | skip = {} 29 | skip_keywords = {} 30 | if hasattr(model, 'no_weight_decay'): 31 | skip = model.no_weight_decay() 32 | if hasattr(model, 'no_weight_decay_keywords'): 33 | skip_keywords = model.no_weight_decay_keywords() 34 | if simmim: 35 | if is_pretrain: 36 | parameters = get_pretrain_param_groups(model, skip, skip_keywords) 37 | else: 38 | depths = config.MODEL.SWIN.DEPTHS if config.MODEL.TYPE == 'swin' else config.MODEL.SWINV2.DEPTHS 39 | num_layers = sum(depths) 40 | get_layer_func = partial(get_swin_layer, num_layers=num_layers + 2, depths=depths) 41 | scales = list(config.TRAIN.LAYER_DECAY ** i for i in reversed(range(num_layers + 2))) 42 | parameters = get_finetune_param_groups(model, config.TRAIN.BASE_LR, config.TRAIN.WEIGHT_DECAY, get_layer_func, scales, skip, skip_keywords) 43 | else: 44 | parameters = set_weight_decay(model, skip, skip_keywords) 45 | 46 | opt_lower = config.TRAIN.OPTIMIZER.NAME.lower() 47 | optimizer = None 48 | if opt_lower == 'sgd': 49 | optimizer = optim.SGD(parameters, momentum=config.TRAIN.OPTIMIZER.MOMENTUM, nesterov=True, 50 | lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY) 51 | elif opt_lower == 'adamw': 52 | optimizer = optim.AdamW(parameters, eps=config.TRAIN.OPTIMIZER.EPS, betas=config.TRAIN.OPTIMIZER.BETAS, 53 | lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY) 54 | elif opt_lower == 'fused_adam': 55 | optimizer = FusedAdam(parameters, eps=config.TRAIN.OPTIMIZER.EPS, betas=config.TRAIN.OPTIMIZER.BETAS, 56 | lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY) 57 | elif opt_lower == 'fused_lamb': 58 | optimizer = FusedLAMB(parameters, eps=config.TRAIN.OPTIMIZER.EPS, betas=config.TRAIN.OPTIMIZER.BETAS, 59 | lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY) 60 | 61 | return optimizer 62 | 63 | 64 | def set_weight_decay(model, skip_list=(), skip_keywords=()): 65 | has_decay = [] 66 | no_decay = [] 67 | 68 | for name, param in model.named_parameters(): 69 | if not param.requires_grad: 70 | continue # frozen weights 71 | if len(param.shape) == 1 or name.endswith(".bias") or (name in skip_list) or \ 72 | check_keywords_in_name(name, skip_keywords): 73 | no_decay.append(param) 74 | # print(f"{name} has no weight decay") 75 | else: 76 | has_decay.append(param) 77 | return [{'params': has_decay}, 78 | {'params': no_decay, 'weight_decay': 0.}] 79 | 80 | 81 | def check_keywords_in_name(name, keywords=()): 82 | isin = False 83 | for keyword in keywords: 84 | if keyword in name: 85 | isin = True 
86 | return isin 87 | 88 | 89 | def get_pretrain_param_groups(model, skip_list=(), skip_keywords=()): 90 | has_decay = [] 91 | no_decay = [] 92 | has_decay_name = [] 93 | no_decay_name = [] 94 | 95 | for name, param in model.named_parameters(): 96 | if not param.requires_grad: 97 | continue 98 | if len(param.shape) == 1 or name.endswith(".bias") or (name in skip_list) or \ 99 | check_keywords_in_name(name, skip_keywords): 100 | no_decay.append(param) 101 | no_decay_name.append(name) 102 | else: 103 | has_decay.append(param) 104 | has_decay_name.append(name) 105 | return [{'params': has_decay}, 106 | {'params': no_decay, 'weight_decay': 0.}] 107 | 108 | 109 | def get_swin_layer(name, num_layers, depths): 110 | if name in ("mask_token",): # trailing comma makes this a tuple; a bare string would do a substring check 111 | return 0 112 | elif name.startswith("patch_embed"): 113 | return 0 114 | elif name.startswith("layers"): 115 | layer_id = int(name.split('.')[1]) 116 | block_id = name.split('.')[3] 117 | if block_id == 'reduction' or block_id == 'norm': 118 | return sum(depths[:layer_id + 1]) 119 | layer_id = sum(depths[:layer_id]) + int(block_id) 120 | return layer_id + 1 121 | else: 122 | return num_layers - 1 123 | 124 | 125 | def get_finetune_param_groups(model, lr, weight_decay, get_layer_func, scales, skip_list=(), skip_keywords=()): 126 | parameter_group_names = {} 127 | parameter_group_vars = {} 128 | 129 | for name, param in model.named_parameters(): 130 | if not param.requires_grad: 131 | continue 132 | if len(param.shape) == 1 or name.endswith(".bias") or (name in skip_list) or \ 133 | check_keywords_in_name(name, skip_keywords): 134 | group_name = "no_decay" 135 | this_weight_decay = 0. 136 | else: 137 | group_name = "decay" 138 | this_weight_decay = weight_decay 139 | if get_layer_func is not None: 140 | layer_id = get_layer_func(name) 141 | group_name = "layer_%d_%s" % (layer_id, group_name) 142 | else: 143 | layer_id = None 144 | 145 | if group_name not in parameter_group_names: 146 | if scales is not None: 147 | scale = scales[layer_id] 148 | else: 149 | scale = 1.
150 | 151 | parameter_group_names[group_name] = { 152 | "group_name": group_name, 153 | "weight_decay": this_weight_decay, 154 | "params": [], 155 | "lr": lr * scale, 156 | "lr_scale": scale, 157 | } 158 | parameter_group_vars[group_name] = { 159 | "group_name": group_name, 160 | "weight_decay": this_weight_decay, 161 | "params": [], 162 | "lr": lr * scale, 163 | "lr_scale": scale 164 | } 165 | 166 | parameter_group_vars[group_name]["params"].append(param) 167 | parameter_group_names[group_name]["params"].append(name) 168 | return list(parameter_group_vars.values()) 169 | -------------------------------------------------------------------------------- /deit/README.md: -------------------------------------------------------------------------------- 1 | # Data-Efficient architectures and training for Image classification 2 | 3 | The code originates from [https://github.com/facebookresearch/deit](https://github.com/facebookresearch/deit) 4 | 5 | ### Requirements 6 | ``` 7 | torch==1.11.0 8 | torchvision==0.11.0a0 9 | timm==0.3.2 10 | ``` 11 | 12 | ### Performance 13 | 14 | | Architecture | # params | FLOPs | 400 epochs | + MaskSub | 800 epochs | + MaskSub | 15 | |:------------:|:--------:|:------:|:-----------:|:---------------:|:----------:|:---------------:| 16 | | ViT-S/16 | 22.0 M | 4.6 G | 80.4 | **81.1 (+0.7)** | 81.4 | **81.7 (+0.3)** | 17 | | ViT-B/16 | 86.6 M | 17.5 G | 83.5 | **84.1 (+0.6)** | 83.8 | **84.2 (+0.4)** | 18 | | ViT-L/16 | 304.4 M | 61.6 G | 84.5 | **85.2 (+0.7)** | 84.9 | **85.3 (+0.4)** | 19 | | ViT-H/14 | 632.1 M | 167.4 G | 85.1 | **85.7 (+0.6)** | 85.2 | **85.7 (+0.5)** | 20 | 21 | ### MaskSub training commands 22 | - Environment variables 23 | ```bash 24 | data_path=/your/path/to/imagenet 25 | save_path=/your/path/to/save 26 | ``` 27 | 28 | - ViT-S 29 | - 400 epochs 30 | ```bash 31 | OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --use_env main.py --model deit_small_patch16_LS --augsub masking --augsub-ratio 0.5 --data-path ${data_path} --output_dir ${save_path} --batch-size 256 --epochs 400 --smoothing 0.0 --reprob 0.0 --opt fusedlamb --color-jitter 0.3 --lr 4e-3 --weight-decay 0.03 --input-size 224 --drop 0.0 --drop-path 0.0 --unscale-lr --repeated-aug --bce-loss --ThreeAugment --eval-crop-ratio 1.0 --dist-eval 32 | ``` 33 | - 800 epochs 34 | ```bash 35 | OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --use_env main.py --model deit_small_patch16_LS --augsub masking --augsub-ratio 0.5 --data-path ${data_path} --output_dir ${save_path} --batch-size 256 --epochs 800 --smoothing 0.0 --reprob 0.0 --opt fusedlamb --color-jitter 0.3 --lr 4e-3 --weight-decay 0.05 --input-size 224 --drop 0.0 --drop-path 0.05 --unscale-lr --repeated-aug --bce-loss --ThreeAugment --eval-crop-ratio 1.0 --dist-eval 36 | ``` 37 | - ViT-B 38 | - 400 epochs 39 | ```bash 40 | OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --use_env main.py --model deit_base_patch16_LS --augsub masking --augsub-ratio 0.5 --data-path ${data_path} --output_dir ${save_path}/pretrain --batch-size 256 --epochs 400 --smoothing 0.0 --reprob 0.0 --opt fusedlamb --color-jitter 0.3 --lr 3e-3 --weight-decay 0.03 --input-size 192 --drop 0.0 --drop-path 0.1 --unscale-lr --repeated-aug --bce-loss --ThreeAugment --eval-crop-ratio 1.0 --dist-eval 41 | ``` 42 | - 800 epochs 43 | ```bash 44 | OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --use_env main.py --model
deit_base_patch16_LS --augsub masking --augsub-ratio 0.5 --data-path ${data_path} --output_dir ${save_path}/pretrain --batch-size 256 --epochs 800 --smoothing 0.0 --reprob 0.0 --opt fusedlamb --color-jitter 0.3 --lr 3e-3 --weight-decay 0.05 --input-size 192 --drop 0.0 --drop-path 0.2 --unscale-lr --repeated-aug --bce-loss --ThreeAugment --eval-crop-ratio 1.0 --dist-eval 45 | ``` 46 | - Finetune 47 | ```bash 48 | OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --use_env main.py --model deit_base_patch16_LS --augsub masking --augsub-ratio 0.5 --data-path ${data_path} --finetune ${save_path}/pretrain/checkpoint.pth --output_dir ${save_path}/finetune --batch-size 64 --epochs 20 --smoothing 0.1 --reprob 0.0 --opt adamw --lr 1e-5 --weight-decay 0.1 --input-size 224 --drop 0.0 --drop-path 0.2 --mixup 0.8 --cutmix 1.0 --unscale-lr --no-repeated-aug --aa rand-m9-mstd0.5-inc1 --eval-crop-ratio 1.0 --dist-eval 49 | ``` 50 | - ViT-L 51 | - 400 epochs 52 | ```bash 53 | OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 --nnodes=8 --use_env main.py --model deit_large_patch16_LS --augsub masking --augsub-ratio 0.5 --data-path ${data_path} --output_dir ${save_path}/pretrain --batch-size 32 --epochs 400 --smoothing 0.0 --reprob 0.0 --opt fusedlamb --color-jitter 0.3 --lr 3e-3 --weight-decay 0.03 --input-size 192 --drop 0.0 --drop-path 0.4 --unscale-lr --repeated-aug --bce-loss --ThreeAugment --eval-crop-ratio 1.0 --dist-eval 54 | ``` 55 | - 800 epochs 56 | ```bash 57 | OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 --nnodes=8 --use_env main.py --model deit_large_patch16_LS --augsub masking --augsub-ratio 0.5 --data-path ${data_path} --output_dir ${save_path}/pretrain --batch-size 32 --epochs 800 --smoothing 0.0 --reprob 0.0 --opt fusedlamb --color-jitter 0.3 --lr 3e-3 --weight-decay 0.05 --input-size 192 --drop 0.0 --drop-path 0.45 --unscale-lr --repeated-aug --bce-loss --ThreeAugment --eval-crop-ratio 1.0 --dist-eval 58 | ``` 59 | - Finetune 60 | ```bash 61 | OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 --nnodes=8 --use_env main.py --model deit_large_patch16_LS --augsub masking --augsub-ratio 0.5 --data-path ${data_path} --finetune ${save_path}/pretrain/checkpoint.pth --output_dir ${save_path}/finetune --batch-size 8 --epochs 20 --smoothing 0.1 --reprob 0.0 --opt adamw --lr 1e-5 --weight-decay 0.1 --input-size 224 --drop 0.0 --drop-path 0.45 --mixup 0.8 --cutmix 1.0 --unscale-lr --no-repeated-aug --aa rand-m9-mstd0.5-inc1 --eval-crop-ratio 1.0 --dist-eval 62 | ``` 63 | - ViT-H 64 | - 400 epochs 65 | ```bash 66 | OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 --nnodes=8 --use_env main.py --model deit_huge_patch14_LS --augsub masking --augsub-ratio 0.5 --data-path ${data_path} --output_dir ${save_path}/pretrain --batch-size 32 --epochs 400 --smoothing 0.0 --reprob 0.0 --opt fusedlamb --color-jitter 0.3 --lr 3e-3 --weight-decay 0.03 --input-size 160 --drop 0.0 --drop-path 0.5 --unscale-lr --repeated-aug --bce-loss --ThreeAugment --eval-crop-ratio 1.0 --dist-eval 67 | ``` 68 | - 800 epochs 69 | ```bash 70 | OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 --nnodes=8 --use_env main.py --model deit_huge_patch14_LS --augsub masking --augsub-ratio 0.5 --data-path ${data_path} --output_dir ${save_path}/pretrain --batch-size 32 --epochs 800 --smoothing 0.0 --reprob 0.0 --opt fusedlamb --color-jitter 0.3 --lr 3e-3 --weight-decay 0.05 --input-size 160 --drop 0.0 
--drop-path 0.6 --unscale-lr --repeated-aug --bce-loss --ThreeAugment --eval-crop-ratio 1.0 --dist-eval 71 | ``` 72 | - Finetune 73 | ```bash 74 | OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 --nnodes=8 --use_env main.py --model deit_huge_patch14_LS --augsub masking --augsub-ratio 0.5 --data-path ${data_path} --finetune ${save_path}/pretrain/checkpoint.pth --output_dir ${save_path}/finetune --batch-size 8 --epochs 20 --smoothing 0.1 --reprob 0.0 --opt adamw --lr 1e-5 --weight-decay 0.1 --input-size 224 --drop 0.0 --drop-path 0.55 --mixup 0.8 --cutmix 1.0 --unscale-lr --no-repeated-aug --aa rand-m9-mstd0.5-inc1 --eval-crop-ratio 1.0 --dist-eval 75 | ``` 76 | -------------------------------------------------------------------------------- /swin/data/build.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code was originally obtained from: 3 | https://github.com/microsoft/Swin-Transformer 4 | """ 5 | 6 | # -------------------------------------------------------- 7 | # Swin Transformer 8 | # Copyright (c) 2021 Microsoft 9 | # Licensed under The MIT License [see LICENSE for details] 10 | # Written by Ze Liu 11 | # -------------------------------------------------------- 12 | 13 | import os 14 | import torch 15 | import numpy as np 16 | import torch.distributed as dist 17 | from torchvision import datasets, transforms 18 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 19 | from timm.data import Mixup 20 | from timm.data import create_transform 21 | 22 | from .cached_image_folder import CachedImageFolder 23 | from .imagenet22k_dataset import IN22KDATASET 24 | from .samplers import SubsetRandomSampler 25 | 26 | try: 27 | from torchvision.transforms import InterpolationMode 28 | 29 | 30 | def _pil_interp(method): 31 | if method == 'bicubic': 32 | return InterpolationMode.BICUBIC 33 | elif method == 'lanczos': 34 | return InterpolationMode.LANCZOS 35 | elif method == 'hamming': 36 | return InterpolationMode.HAMMING 37 | else: 38 | # default bilinear, do we want to allow nearest? 
39 | return InterpolationMode.BILINEAR 40 | 41 | 42 | import timm.data.transforms as timm_transforms 43 | 44 | timm_transforms._pil_interp = _pil_interp 45 | except: 46 | from timm.data.transforms import _pil_interp 47 | 48 | 49 | def build_loader(config): 50 | config.defrost() 51 | dataset_train, config.MODEL.NUM_CLASSES = build_dataset(is_train=True, config=config) 52 | config.freeze() 53 | print(f"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} successfully build train dataset") 54 | dataset_val, _ = build_dataset(is_train=False, config=config) 55 | print(f"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} successfully build val dataset") 56 | 57 | num_tasks = dist.get_world_size() 58 | global_rank = dist.get_rank() 59 | if config.DATA.ZIP_MODE and config.DATA.CACHE_MODE == 'part': 60 | indices = np.arange(dist.get_rank(), len(dataset_train), dist.get_world_size()) 61 | sampler_train = SubsetRandomSampler(indices) 62 | else: 63 | sampler_train = torch.utils.data.DistributedSampler( 64 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 65 | ) 66 | 67 | if config.TEST.SEQUENTIAL: 68 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 69 | else: 70 | sampler_val = torch.utils.data.distributed.DistributedSampler( 71 | dataset_val, shuffle=config.TEST.SHUFFLE 72 | ) 73 | 74 | data_loader_train = torch.utils.data.DataLoader( 75 | dataset_train, sampler=sampler_train, 76 | batch_size=config.DATA.BATCH_SIZE, 77 | num_workers=config.DATA.NUM_WORKERS, 78 | pin_memory=config.DATA.PIN_MEMORY, 79 | drop_last=True, 80 | ) 81 | 82 | data_loader_val = torch.utils.data.DataLoader( 83 | dataset_val, sampler=sampler_val, 84 | batch_size=config.DATA.BATCH_SIZE, 85 | shuffle=False, 86 | num_workers=config.DATA.NUM_WORKERS, 87 | pin_memory=config.DATA.PIN_MEMORY, 88 | drop_last=False 89 | ) 90 | 91 | # setup mixup / cutmix 92 | mixup_fn = None 93 | mixup_active = config.AUG.MIXUP > 0 or config.AUG.CUTMIX > 0. 
or config.AUG.CUTMIX_MINMAX is not None 94 | if mixup_active: 95 | mixup_fn = Mixup( 96 | mixup_alpha=config.AUG.MIXUP, cutmix_alpha=config.AUG.CUTMIX, cutmix_minmax=config.AUG.CUTMIX_MINMAX, 97 | prob=config.AUG.MIXUP_PROB, switch_prob=config.AUG.MIXUP_SWITCH_PROB, mode=config.AUG.MIXUP_MODE, 98 | label_smoothing=config.MODEL.LABEL_SMOOTHING, num_classes=config.MODEL.NUM_CLASSES) 99 | 100 | return dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn 101 | 102 | 103 | def build_dataset(is_train, config): 104 | transform = build_transform(is_train, config) 105 | if config.DATA.DATASET == 'imagenet': 106 | prefix = 'train' if is_train else 'val' 107 | if config.DATA.ZIP_MODE: 108 | ann_file = prefix + "_map.txt" 109 | prefix = prefix + ".zip@/" 110 | dataset = CachedImageFolder(config.DATA.DATA_PATH, ann_file, prefix, transform, 111 | cache_mode=config.DATA.CACHE_MODE if is_train else 'part') 112 | else: 113 | root = os.path.join(config.DATA.DATA_PATH, prefix) 114 | dataset = datasets.ImageFolder(root, transform=transform) 115 | nb_classes = 1000 116 | elif config.DATA.DATASET == 'imagenet22K': 117 | prefix = 'ILSVRC2011fall_whole' 118 | if is_train: 119 | ann_file = prefix + "_map_train.txt" 120 | else: 121 | ann_file = prefix + "_map_val.txt" 122 | dataset = IN22KDATASET(config.DATA.DATA_PATH, ann_file, transform) 123 | nb_classes = 21841 124 | else: 125 | raise NotImplementedError("We only support ImageNet Now.") 126 | 127 | return dataset, nb_classes 128 | 129 | 130 | def build_transform(is_train, config): 131 | resize_im = config.DATA.IMG_SIZE > 32 132 | if is_train: 133 | # this should always dispatch to transforms_imagenet_train 134 | transform = create_transform( 135 | input_size=config.DATA.IMG_SIZE, 136 | is_training=True, 137 | color_jitter=config.AUG.COLOR_JITTER if config.AUG.COLOR_JITTER > 0 else None, 138 | auto_augment=config.AUG.AUTO_AUGMENT if config.AUG.AUTO_AUGMENT != 'none' else None, 139 | re_prob=config.AUG.REPROB, 140 | re_mode=config.AUG.REMODE, 141 | re_count=config.AUG.RECOUNT, 142 | interpolation=config.DATA.INTERPOLATION, 143 | ) 144 | if not resize_im: 145 | # replace RandomResizedCropAndInterpolation with 146 | # RandomCrop 147 | transform.transforms[0] = transforms.RandomCrop(config.DATA.IMG_SIZE, padding=4) 148 | return transform 149 | 150 | t = [] 151 | if resize_im: 152 | if config.TEST.CROP: 153 | size = int((256 / 224) * config.DATA.IMG_SIZE) 154 | t.append( 155 | transforms.Resize(size, interpolation=_pil_interp(config.DATA.INTERPOLATION)), 156 | # to maintain same ratio w.r.t. 224 images 157 | ) 158 | t.append(transforms.CenterCrop(config.DATA.IMG_SIZE)) 159 | else: 160 | t.append( 161 | transforms.Resize((config.DATA.IMG_SIZE, config.DATA.IMG_SIZE), 162 | interpolation=_pil_interp(config.DATA.INTERPOLATION)) 163 | ) 164 | 165 | t.append(transforms.ToTensor()) 166 | t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)) 167 | return transforms.Compose(t) 168 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # Masking meets Supervision: A Strong Learning Alliance 4 | 5 | **[Byeongho Heo](https://sites.google.com/view/byeongho-heo/home), [Taekyung Kim](https://tkkim93.github.io/), [Sangdoo Yun](https://sangdooyun.github.io/), [Dongyoon Han](https://sites.google.com/site/dyhan0920/)**
6 | 7 | [NAVER AI LAB](https://naver-career.gitbook.io/en/teams/clova-cic/ai-lab) 8 | 9 | [![Paper](https://img.shields.io/badge/Paper-arxiv-green)](https://arxiv.org/abs/2306.11339) 10 | [![Paper](https://img.shields.io/badge/Paper-CVPR_2025-blue)](https://cvpr.thecvf.com/virtual/2025/poster/35257) 11 | [![CC BY-NC 4.0](https://img.shields.io/badge/License-CC%20BY--NC%204.0-lightgrey.svg)](https://github.com/naver-ai/augsub/blob/main/LICENSE) 12 | 13 | 14 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/augmenting-sub-model-to-improve-main-model/self-supervised-image-classification-on-1)](https://paperswithcode.com/sota/self-supervised-image-classification-on-1?p=augmenting-sub-model-to-improve-main-model) 15 | 16 |
17 | 18 | Official PyTorch implementation of MaskSub "Masking meets Supervision: A Strong Learning Alliance" | [arxiv](https://arxiv.org/abs/2306.11339). 19 | 20 | ### Abstract 21 | 22 | Pre-training with random masked inputs has emerged as a novel trend in self-supervised training. However, supervised learning still faces a challenge in adopting masking augmentations, primarily due to unstable training. In this paper, we propose a novel way to incorporate masking augmentations, dubbed Masked Sub-branch (MaskSub). MaskSub consists of the main-branch and sub-branch, the latter being a part of the former. The main-branch undergoes conventional training recipes, while the sub-branch receives intensive masking augmentations during training. MaskSub tackles the challenge by mitigating adverse effects through a relaxed loss function similar to a self-distillation loss. Our analysis shows that MaskSub improves performance, with the training loss converging faster than in standard training, which suggests our method stabilizes the training process. We further validate MaskSub across diverse training scenarios and models, including DeiT-III training, MAE finetuning, CLIP finetuning, BERT training, and hierarchical architectures (ResNet and Swin Transformer). Our results show that MaskSub consistently achieves impressive performance gains across all cases. MaskSub provides a practical and effective solution for introducing additional regularization under various training recipes. 23 | 24 | ## Updates 25 | 26 | - **Mar 25, 2025**: Arxiv & README update 27 | - **Feb 27, 2025**: Accepted to CVPR 2025 28 | - **Feb 28, 2024**: Arxiv paper update 29 | - **Jun 21, 2023**: Codes for deit, mae, swin, and resnet are released 30 | - **Jun 21, 2023**: Arxiv paper is released 31 | 32 | ## Getting Started 33 | 34 | You can find the MaskSub training commands in each folder. 35 | 36 | - `deit/` : DeiT-III training *"DeiT III: Revenge of the ViT"* [original repo](https://github.com/facebookresearch/deit) 37 | - `mae/` : MAE finetuning *"Masked Autoencoders Are Scalable Vision Learners"* [original repo](https://github.com/facebookresearch/mae) 38 | - `swin/` : Swin Transformer training *"Swin Transformer: Hierarchical Vision Transformer using Shifted Windows"* [original repo](https://github.com/microsoft/Swin-Transformer) 39 | - `resnet_rsb/` : ResNet training with RSB recipe *"ResNet strikes back: An improved training procedure in timm"* [original repo](https://github.com/huggingface/pytorch-image-models/tree/v0.5.4) 40 | 41 | ## Method preview 42 | 43 | ![Preview](preview_picture.jpg) 44 | 45 | ### Pseudo-code for MaskSub 46 | The snippet below shows the basic mechanism of MaskSub in simplified code; a sketch of the masking operation itself follows. 47 | ```python 48 | # For drop probability p 49 | for (x, y) in data_loader: 50 | o1, o2 = model(x, drop_prob=0), model(x, drop_prob=p) 51 | loss = CrossEntropy(o1, y) 52 | loss += CrossEntropy(o2, softmax(o1.detach())) 53 | (loss/2).backward() 54 | optimizer.step() 55 | ``` 56 |
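The sub-branch's masking drops a random subset of patch tokens for its forward pass, in the spirit of MAE. The sketch below is illustrative only; the function name and the `tokens`/`mask_ratio` arguments are not this repo's API:

```python
import torch

def random_token_masking(tokens: torch.Tensor, mask_ratio: float) -> torch.Tensor:
    # tokens: (B, N, C) patch embeddings; keep a random (1 - mask_ratio) subset per sample
    B, N, C = tokens.shape
    n_keep = max(1, int(N * (1.0 - mask_ratio)))
    noise = torch.rand(B, N, device=tokens.device)   # one random score per token
    keep = noise.argsort(dim=1)[:, :n_keep]          # indices of the tokens to keep
    return torch.gather(tokens, 1, keep.unsqueeze(-1).expand(-1, -1, C))
```

57 | ### Practical code for MaskSub 50\% 58 | In practice, we use gradient accumulation to avoid GPU memory issues. We also use `loss_scaler` for mixed-precision training.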
59 | ```python 60 | for (x, y) in data_loader: 61 | optimizer.zero_grad() 62 | 63 | # Main model 64 | outputs = model(x) 65 | loss = criterion(outputs, y) 66 | loss_scaler(loss/2, optimizer, retain_graph=False, update_grad=False) 67 | 68 | # Sub-model with masking 69 | outputs_sub = model(x, augsub='masking', augsub_ratio=0.5) 70 | loss = criterion(outputs_sub, F.softmax(outputs.detach(), dim=-1)) 71 | loss_scaler(loss/2, optimizer, retain_graph=False, update_grad=True) 72 | ``` The first `loss_scaler` call (`update_grad=False`) only accumulates scaled gradients; the second (`update_grad=True`) unscales them, steps the optimizer, and updates the AMP scaler. 73 | 74 | 75 | ## Performances 76 | 77 | ### DeiT-III 78 | 79 | | Architecture | # params | FLOPs | 400 epochs | + MaskSub | 800 epochs | + MaskSub | 80 | |:------------:|:--------:|:------:|:-----------:|:---------------:|:----------:|:---------------:| 81 | | ViT-S/16 | 22.0 M | 4.6 G | 80.4 | **81.1 (+0.7)** | 81.4 | **81.7 (+0.3)** | 82 | | ViT-B/16 | 86.6 M | 17.5 G | 83.5 | **84.1 (+0.6)** | 83.8 | **84.2 (+0.4)** | 83 | | ViT-L/16 | 304.4 M | 61.6 G | 84.5 | **85.2 (+0.7)** | 84.9 | **85.3 (+0.4)** | 84 | | ViT-H/14 | 632.1 M | 167.4 G| 85.1 | **85.7 (+0.6)** | 85.2 | **85.7 (+0.5)** | 85 | 86 | 87 | ### MAE finetuning 88 | 89 | | Architecture | Finetuning Epochs | Baseline | + MaskSub | 90 | |:------------:|:-----------------:|:--------:|:---------------:| 91 | | ViT-B/16 | 100 | 83.6 | **83.9 (+0.3)** | 92 | | ViT-L/16 | 50 | 85.9 | **86.1 (+0.2)** | 93 | | ViT-H/14 | 50 | 86.9 | **87.2 (+0.3)** | 94 | 95 | 96 | ### Swin Transformer 97 | 98 | | Architecture | # Params | FLOPs | Baseline | + MaskSub | 99 | | :---: | :---: | :---: | :---: |:---------------:| 100 | | Swin-T | 28.3 M | 4.5 G | 81.3 | **81.4 (+0.1)** | 101 | | Swin-S | 49.6 M | 8.7 G | 83.0 | **83.4 (+0.4)** | 102 | | Swin-B | 87.9 M | 15.4 G | 83.5 | **83.9 (+0.4)** | 103 | 104 | 105 | ### ResNet 106 | 107 | | Architecture | # Params | FLOPs | Baseline | + MaskSub | 108 | | :---: | :---: | :---: | :---: |:---------------:| 109 | | ResNet50 | 25.6 M | 4.1 G | 79.7 | **80.0 (+0.3)** | 110 | | ResNet101 | 44.5 M | 7.9 G | 81.4 | **82.1 (+0.7)** | 111 | | ResNet152 | 60.2 M | 11.6 G | 81.8 | **82.8 (+1.0)** | 112 | 113 | 114 | ## License 115 | 116 | Licensed under [CC BY-NC 4.0](LICENSE) 117 | 118 | ``` 119 | AugSub 120 | Copyright (c) 2023-present NAVER Cloud Corp.
121 | CC BY-NC-4.0 (https://creativecommons.org/licenses/by-nc/4.0/) 122 | ``` 123 | 124 | ## How to cite 125 | 126 | ``` 127 | @inproceedings{heo2023masksub, 128 | title={Masking meets Supervision: A Strong Learning Alliance}, 129 | author={Heo, Byeongho and Kim, Taekyung and Yun, Sangdoo and Han, Dongyoon}, 130 | year={2025}, 131 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 132 | } 133 | ``` 134 | -------------------------------------------------------------------------------- /swin/data/cached_image_folder.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code was originally obtained from: 3 | https://github.com/microsoft/Swin-Transformer 4 | """ 5 | 6 | # -------------------------------------------------------- 7 | # Swin Transformer 8 | # Copyright (c) 2021 Microsoft 9 | # Licensed under The MIT License [see LICENSE for details] 10 | # Written by Ze Liu 11 | # -------------------------------------------------------- 12 | 13 | import io 14 | import os 15 | import time 16 | import torch.distributed as dist 17 | import torch.utils.data as data 18 | from PIL import Image 19 | 20 | from .zipreader import is_zip_path, ZipReader 21 | 22 | 23 | def has_file_allowed_extension(filename, extensions): 24 | """Checks if a file is an allowed extension. 25 | Args: 26 | filename (string): path to a file 27 | Returns: 28 | bool: True if the filename ends with a known image extension 29 | """ 30 | filename_lower = filename.lower() 31 | return any(filename_lower.endswith(ext) for ext in extensions) 32 | 33 | 34 | def find_classes(dir): 35 | classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))] 36 | classes.sort() 37 | class_to_idx = {classes[i]: i for i in range(len(classes))} 38 | return classes, class_to_idx 39 | 40 | 41 | def make_dataset(dir, class_to_idx, extensions): 42 | images = [] 43 | dir = os.path.expanduser(dir) 44 | for target in sorted(os.listdir(dir)): 45 | d = os.path.join(dir, target) 46 | if not os.path.isdir(d): 47 | continue 48 | 49 | for root, _, fnames in sorted(os.walk(d)): 50 | for fname in sorted(fnames): 51 | if has_file_allowed_extension(fname, extensions): 52 | path = os.path.join(root, fname) 53 | item = (path, class_to_idx[target]) 54 | images.append(item) 55 | 56 | return images 57 | 58 | 59 | def make_dataset_with_ann(ann_file, img_prefix, extensions): 60 | images = [] 61 | with open(ann_file, "r") as f: 62 | contents = f.readlines() 63 | for line_str in contents: 64 | path_contents = [c for c in line_str.split('\t')] 65 | im_file_name = path_contents[0] 66 | class_index = int(path_contents[1]) 67 | 68 | assert str.lower(os.path.splitext(im_file_name)[-1]) in extensions 69 | item = (os.path.join(img_prefix, im_file_name), class_index) 70 | 71 | images.append(item) 72 | 73 | return images 74 | 75 | 76 | class DatasetFolder(data.Dataset): 77 | """A generic data loader where the samples are arranged in this way: :: 78 | root/class_x/xxx.ext 79 | root/class_x/xxy.ext 80 | root/class_x/xxz.ext 81 | root/class_y/123.ext 82 | root/class_y/nsdf3.ext 83 | root/class_y/asd932_.ext 84 | Args: 85 | root (string): Root directory path. 86 | loader (callable): A function to load a sample given its path. 87 | extensions (list[string]): A list of allowed extensions. 88 | transform (callable, optional): A function/transform that takes in 89 | a sample and returns a transformed version. 90 | E.g., ``transforms.RandomCrop`` for images.
91 | target_transform (callable, optional): A function/transform that takes 92 | in the target and transforms it. 93 | Attributes: 94 | samples (list): List of (sample path, class_index) tuples 95 | """ 96 | 97 | def __init__(self, root, loader, extensions, ann_file='', img_prefix='', transform=None, target_transform=None, 98 | cache_mode="no"): 99 | # image folder mode 100 | if ann_file == '': 101 | _, class_to_idx = find_classes(root) 102 | samples = make_dataset(root, class_to_idx, extensions) 103 | # zip mode 104 | else: 105 | samples = make_dataset_with_ann(os.path.join(root, ann_file), 106 | os.path.join(root, img_prefix), 107 | extensions) 108 | 109 | if len(samples) == 0: 110 | raise (RuntimeError("Found 0 files in subfolders of: " + root + "\n" + 111 | "Supported extensions are: " + ",".join(extensions))) 112 | 113 | self.root = root 114 | self.loader = loader 115 | self.extensions = extensions 116 | 117 | self.samples = samples 118 | self.labels = [y_1k for _, y_1k in samples] 119 | self.classes = list(set(self.labels)) 120 | 121 | self.transform = transform 122 | self.target_transform = target_transform 123 | 124 | self.cache_mode = cache_mode 125 | if self.cache_mode != "no": 126 | self.init_cache() 127 | 128 | def init_cache(self): 129 | assert self.cache_mode in ["part", "full"] 130 | n_sample = len(self.samples) 131 | global_rank = dist.get_rank() 132 | world_size = dist.get_world_size() 133 | 134 | samples_bytes = [None for _ in range(n_sample)] 135 | start_time = time.time() 136 | for index in range(n_sample): 137 | if index % (n_sample // 10) == 0: 138 | t = time.time() - start_time 139 | print(f'global_rank {dist.get_rank()} cached {index}/{n_sample} takes {t:.2f}s per block') 140 | start_time = time.time() 141 | path, target = self.samples[index] 142 | if self.cache_mode == "full": 143 | samples_bytes[index] = (ZipReader.read(path), target) 144 | elif self.cache_mode == "part" and index % world_size == global_rank: 145 | samples_bytes[index] = (ZipReader.read(path), target) 146 | else: 147 | samples_bytes[index] = (path, target) 148 | self.samples = samples_bytes 149 | 150 | def __getitem__(self, index): 151 | """ 152 | Args: 153 | index (int): Index 154 | Returns: 155 | tuple: (sample, target) where target is class_index of the target class. 
156 | """ 157 | path, target = self.samples[index] 158 | sample = self.loader(path) 159 | if self.transform is not None: 160 | sample = self.transform(sample) 161 | if self.target_transform is not None: 162 | target = self.target_transform(target) 163 | 164 | return sample, target 165 | 166 | def __len__(self): 167 | return len(self.samples) 168 | 169 | def __repr__(self): 170 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' 171 | fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) 172 | fmt_str += ' Root Location: {}\n'.format(self.root) 173 | tmp = ' Transforms (if any): ' 174 | fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 175 | tmp = ' Target Transforms (if any): ' 176 | fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 177 | return fmt_str 178 | 179 | 180 | IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif'] 181 | 182 | 183 | def pil_loader(path): 184 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 185 | if isinstance(path, bytes): 186 | img = Image.open(io.BytesIO(path)) 187 | elif is_zip_path(path): 188 | data = ZipReader.read(path) 189 | img = Image.open(io.BytesIO(data)) 190 | else: 191 | with open(path, 'rb') as f: 192 | img = Image.open(f) 193 | return img.convert('RGB') 194 | return img.convert('RGB') 195 | 196 | 197 | def accimage_loader(path): 198 | import accimage 199 | try: 200 | return accimage.Image(path) 201 | except IOError: 202 | # Potentially a decoding problem, fall back to PIL.Image 203 | return pil_loader(path) 204 | 205 | 206 | def default_img_loader(path): 207 | from torchvision import get_image_backend 208 | if get_image_backend() == 'accimage': 209 | return accimage_loader(path) 210 | else: 211 | return pil_loader(path) 212 | 213 | 214 | class CachedImageFolder(DatasetFolder): 215 | """A generic data loader where the images are arranged in this way: :: 216 | root/dog/xxx.png 217 | root/dog/xxy.png 218 | root/dog/xxz.png 219 | root/cat/123.png 220 | root/cat/nsdf3.png 221 | root/cat/asd932_.png 222 | Args: 223 | root (string): Root directory path. 224 | transform (callable, optional): A function/transform that takes in an PIL image 225 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 226 | target_transform (callable, optional): A function/transform that takes in the 227 | target and transforms it. 228 | loader (callable, optional): A function to load an image given its path. 229 | Attributes: 230 | imgs (list): List of (image path, class_index) tuples 231 | """ 232 | 233 | def __init__(self, root, ann_file='', img_prefix='', transform=None, target_transform=None, 234 | loader=default_img_loader, cache_mode="no"): 235 | super(CachedImageFolder, self).__init__(root, loader, IMG_EXTENSIONS, 236 | ann_file=ann_file, img_prefix=img_prefix, 237 | transform=transform, target_transform=target_transform, 238 | cache_mode=cache_mode) 239 | self.imgs = self.samples 240 | 241 | def __getitem__(self, index): 242 | """ 243 | Args: 244 | index (int): Index 245 | Returns: 246 | tuple: (image, target) where target is class_index of the target class. 
247 | """ 248 | path, target = self.samples[index] 249 | image = self.loader(path) 250 | if self.transform is not None: 251 | img = self.transform(image) 252 | else: 253 | img = image 254 | if self.target_transform is not None: 255 | target = self.target_transform(target) 256 | 257 | return img, target 258 | -------------------------------------------------------------------------------- /deit/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code was originally obtained from: 3 | https://github.com/facebookresearch/deit 4 | """ 5 | 6 | # Copyright (c) 2015-present, Facebook, Inc. 7 | # All rights reserved. 8 | """ 9 | Misc functions, including distributed helpers. 10 | 11 | Mostly copy-paste from torchvision references. 12 | """ 13 | import io 14 | import os 15 | import time 16 | from collections import defaultdict, deque 17 | import datetime 18 | 19 | import torch 20 | import torch.distributed as dist 21 | from torch._six import inf 22 | 23 | 24 | class SmoothedValue(object): 25 | """Track a series of values and provide access to smoothed values over a 26 | window or the global series average. 27 | """ 28 | 29 | def __init__(self, window_size=20, fmt=None): 30 | if fmt is None: 31 | fmt = "{median:.4f} ({global_avg:.4f})" 32 | self.deque = deque(maxlen=window_size) 33 | self.total = 0.0 34 | self.count = 0 35 | self.fmt = fmt 36 | 37 | def update(self, value, n=1): 38 | self.deque.append(value) 39 | self.count += n 40 | self.total += value * n 41 | 42 | def synchronize_between_processes(self): 43 | """ 44 | Warning: does not synchronize the deque! 45 | """ 46 | if not is_dist_avail_and_initialized(): 47 | return 48 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 49 | dist.barrier() 50 | dist.all_reduce(t) 51 | t = t.tolist() 52 | self.count = int(t[0]) 53 | self.total = t[1] 54 | 55 | @property 56 | def median(self): 57 | d = torch.tensor(list(self.deque)) 58 | return d.median().item() 59 | 60 | @property 61 | def avg(self): 62 | d = torch.tensor(list(self.deque), dtype=torch.float32) 63 | return d.mean().item() 64 | 65 | @property 66 | def global_avg(self): 67 | return self.total / self.count 68 | 69 | @property 70 | def max(self): 71 | return max(self.deque) 72 | 73 | @property 74 | def value(self): 75 | return self.deque[-1] 76 | 77 | def __str__(self): 78 | return self.fmt.format( 79 | median=self.median, 80 | avg=self.avg, 81 | global_avg=self.global_avg, 82 | max=self.max, 83 | value=self.value) 84 | 85 | 86 | class MetricLogger(object): 87 | def __init__(self, delimiter="\t"): 88 | self.meters = defaultdict(SmoothedValue) 89 | self.delimiter = delimiter 90 | 91 | def update(self, **kwargs): 92 | for k, v in kwargs.items(): 93 | if isinstance(v, torch.Tensor): 94 | v = v.item() 95 | assert isinstance(v, (float, int)) 96 | self.meters[k].update(v) 97 | 98 | def __getattr__(self, attr): 99 | if attr in self.meters: 100 | return self.meters[attr] 101 | if attr in self.__dict__: 102 | return self.__dict__[attr] 103 | raise AttributeError("'{}' object has no attribute '{}'".format( 104 | type(self).__name__, attr)) 105 | 106 | def __str__(self): 107 | loss_str = [] 108 | for name, meter in self.meters.items(): 109 | loss_str.append( 110 | "{}: {}".format(name, str(meter)) 111 | ) 112 | return self.delimiter.join(loss_str) 113 | 114 | def synchronize_between_processes(self): 115 | for meter in self.meters.values(): 116 | meter.synchronize_between_processes() 117 | 118 | def add_meter(self, name, 
meter): 119 | self.meters[name] = meter 120 | 121 | def log_every(self, iterable, print_freq, header=None): 122 | i = 0 123 | if not header: 124 | header = '' 125 | start_time = time.time() 126 | end = time.time() 127 | iter_time = SmoothedValue(fmt='{avg:.4f}') 128 | data_time = SmoothedValue(fmt='{avg:.4f}') 129 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 130 | log_msg = [ 131 | header, 132 | '[{0' + space_fmt + '}/{1}]', 133 | 'eta: {eta}', 134 | '{meters}', 135 | 'time: {time}', 136 | 'data: {data}' 137 | ] 138 | if torch.cuda.is_available(): 139 | log_msg.append('max mem: {memory:.0f}') 140 | log_msg = self.delimiter.join(log_msg) 141 | MB = 1024.0 * 1024.0 142 | for obj in iterable: 143 | data_time.update(time.time() - end) 144 | yield obj 145 | iter_time.update(time.time() - end) 146 | if i % print_freq == 0 or i == len(iterable) - 1: 147 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 148 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 149 | if torch.cuda.is_available(): 150 | print(log_msg.format( 151 | i, len(iterable), eta=eta_string, 152 | meters=str(self), 153 | time=str(iter_time), data=str(data_time), 154 | memory=torch.cuda.max_memory_allocated() / MB)) 155 | else: 156 | print(log_msg.format( 157 | i, len(iterable), eta=eta_string, 158 | meters=str(self), 159 | time=str(iter_time), data=str(data_time))) 160 | i += 1 161 | end = time.time() 162 | total_time = time.time() - start_time 163 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 164 | print('{} Total time: {} ({:.4f} s / it)'.format( 165 | header, total_time_str, total_time / len(iterable))) 166 | 167 | 168 | def _load_checkpoint_for_ema(model_ema, checkpoint): 169 | """ 170 | Workaround for ModelEma._load_checkpoint to accept an already-loaded object 171 | """ 172 | mem_file = io.BytesIO() 173 | torch.save({'state_dict_ema':checkpoint}, mem_file) 174 | mem_file.seek(0) 175 | model_ema._load_checkpoint(mem_file) 176 | 177 | 178 | def setup_for_distributed(is_master): 179 | """ 180 | This function disables printing when not in master process 181 | """ 182 | import builtins as __builtin__ 183 | builtin_print = __builtin__.print 184 | 185 | def print(*args, **kwargs): 186 | force = kwargs.pop('force', False) 187 | if is_master or force: 188 | builtin_print(*args, **kwargs) 189 | 190 | __builtin__.print = print 191 | 192 | 193 | def is_dist_avail_and_initialized(): 194 | if not dist.is_available(): 195 | return False 196 | if not dist.is_initialized(): 197 | return False 198 | return True 199 | 200 | 201 | def get_world_size(): 202 | if not is_dist_avail_and_initialized(): 203 | return 1 204 | return dist.get_world_size() 205 | 206 | 207 | def get_rank(): 208 | if not is_dist_avail_and_initialized(): 209 | return 0 210 | return dist.get_rank() 211 | 212 | 213 | def is_main_process(): 214 | return get_rank() == 0 215 | 216 | 217 | def save_on_master(*args, **kwargs): 218 | if is_main_process(): 219 | torch.save(*args, **kwargs) 220 | 221 | 222 | def init_distributed_mode(args): 223 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 224 | args.rank = int(os.environ["RANK"]) 225 | args.world_size = int(os.environ['WORLD_SIZE']) 226 | args.gpu = int(os.environ['LOCAL_RANK']) 227 | elif 'SLURM_PROCID' in os.environ: 228 | args.rank = int(os.environ['SLURM_PROCID']) 229 | args.gpu = args.rank % torch.cuda.device_count() 230 | else: 231 | print('Not using distributed mode') 232 | args.distributed = False 233 | return 234 | 235 | args.distributed = True 236 | 237 | 
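# bind this process to its local GPU before initializing the NCCL process group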
torch.cuda.set_device(args.gpu) 238 | args.dist_backend = 'nccl' 239 | print('| distributed init (rank {}): {}'.format( 240 | args.rank, args.dist_url), flush=True) 241 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 242 | world_size=args.world_size, rank=args.rank) 243 | torch.distributed.barrier() 244 | setup_for_distributed(args.rank == 0) 245 | 246 | def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: 247 | if isinstance(parameters, torch.Tensor): 248 | parameters = [parameters] 249 | parameters = [p for p in parameters if p.grad is not None] 250 | norm_type = float(norm_type) 251 | if len(parameters) == 0: 252 | return torch.tensor(0.) 253 | device = parameters[0].grad.device 254 | if norm_type == inf: 255 | total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) 256 | else: 257 | total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type) 258 | return total_norm 259 | 260 | class NativeScalerWithGradNormCount: 261 | state_dict_key = "amp_scaler" 262 | 263 | def __init__(self): 264 | self._scaler = torch.cuda.amp.GradScaler() 265 | 266 | def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True, retain_graph=False): 267 | if retain_graph: # keep the autograd graph alive if a later backward pass will reuse it 268 | self._scaler.scale(loss).backward(create_graph=create_graph, retain_graph=retain_graph) 269 | else: 270 | self._scaler.scale(loss).backward(create_graph=create_graph) 271 | 272 | if update_grad: # final accumulation pass: unscale, optionally clip, then step and update 273 | if clip_grad is not None: 274 | assert parameters is not None 275 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place 276 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) 277 | else: 278 | self._scaler.unscale_(optimizer) 279 | norm = get_grad_norm_(parameters) 280 | self._scaler.step(optimizer) 281 | self._scaler.update() 282 | else: 283 | norm = None 284 | return norm 285 | 286 | def state_dict(self): 287 | return self._scaler.state_dict() 288 | 289 | def load_state_dict(self, state_dict): 290 | self._scaler.load_state_dict(state_dict)
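The `update_grad` and `retain_graph` flags above are what let MaskSub's two backward passes per batch accumulate into the same gradients. A minimal usage sketch follows; it is an illustration, not code from this file (`model`, `criterion`, `optimizer`, and `data_loader` are assumed to exist, and the `augsub` keyword follows the README's practical example):

```python
import torch.nn.functional as F

loss_scaler = NativeScalerWithGradNormCount()
for x, y in data_loader:
    optimizer.zero_grad()
    outputs = model(x)                                        # main branch
    loss_scaler(criterion(outputs, y) / 2, optimizer,
                parameters=model.parameters(), update_grad=False)   # accumulate only
    outputs_sub = model(x, augsub='masking', augsub_ratio=0.5)      # masked sub-branch
    loss_sub = criterion(outputs_sub, F.softmax(outputs.detach(), dim=-1))
    loss_scaler(loss_sub / 2, optimizer,
                parameters=model.parameters(), update_grad=True)    # step + scaler update
```

-------------------------------------------------------------------------------- /swin/kernels/window_process/unit_test.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code was originally obtained from: 3 | https://github.com/microsoft/Swin-Transformer 4 | """ 5 | 6 | # -------------------------------------------------------- 7 | # Fused kernel for window process for SwinTransformer 8 | # Copyright (c) 2022 Nvidia 9 | # Licensed under The MIT License [see LICENSE for details] 10 | # -------------------------------------------------------- 11 | 12 | import torch 13 | import swin_window_process 14 | import random 15 | import time 16 | import unittest 17 | 18 | 19 | class WindowProcess(torch.autograd.Function): 20 | @staticmethod 21 | def forward(ctx, input, B, H, W, C, shift_size, window_size): 22 | output = swin_window_process.roll_and_window_partition_forward(input, B, H, W, C, shift_size, window_size) 23 | 24 | ctx.B = B 25 | ctx.H = H 26 | ctx.W = W 27 | ctx.C = C 28 | ctx.shift_size = shift_size 29 | ctx.window_size = window_size 30 | return output 31 | 32 | @staticmethod 33 | def backward(ctx, grad_in): 34 | B = ctx.B 35 | H = ctx.H 36 | W = ctx.W 37 | C = ctx.C 38 | shift_size = ctx.shift_size 39 | window_size = ctx.window_size 40 | 41 | grad_out =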
swin_window_process.roll_and_window_partition_backward(grad_in, B, H, W, C, shift_size, window_size) 42 | return grad_out, None, None, None, None, None, None, None 43 | 44 | 45 | class WindowProcessReverse(torch.autograd.Function): 46 | @staticmethod 47 | def forward(ctx, input, B, H, W, C, shift_size, window_size): 48 | output = swin_window_process.window_merge_and_roll_forward(input, B, H, W, C, shift_size, window_size) 49 | 50 | ctx.B = B 51 | ctx.H = H 52 | ctx.W = W 53 | ctx.C = C 54 | ctx.shift_size = shift_size 55 | ctx.window_size = window_size 56 | 57 | return output 58 | 59 | @staticmethod 60 | def backward(ctx, grad_in): 61 | B = ctx.B 62 | H = ctx.H 63 | W = ctx.W 64 | C = ctx.C 65 | shift_size = ctx.shift_size 66 | window_size = ctx.window_size 67 | 68 | grad_out = swin_window_process.window_merge_and_roll_backward(grad_in, B, H, W, C, shift_size, window_size) 69 | return grad_out, None, None, None, None, None, None, None 70 | 71 | 72 | def window_partition(x, window_size): 73 | """ 74 | Args: 75 | x: (B, H, W, C) 76 | window_size (int): window size 77 | Returns: 78 | windows: (num_windows*B, window_size, window_size, C) 79 | """ 80 | B, H, W, C = x.shape 81 | x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) 82 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 83 | return windows 84 | 85 | def window_reverse(windows, window_size, H, W): 86 | """ 87 | Args: 88 | windows: (num_windows*B, window_size, window_size, C) 89 | window_size (int): Window size 90 | H (int): Height of image 91 | W (int): Width of image 92 | Returns: 93 | x: (B, H, W, C) 94 | """ 95 | B = int(windows.shape[0] / (H * W / window_size / window_size)) 96 | x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) 97 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) 98 | return x 99 | 100 | 101 | def pyt_forward(x, shift_size, window_size): 102 | # x in shape(B, H, W, C) 103 | # cyclic shift 104 | if shift_size > 0: 105 | shifted_x = torch.roll(x, shifts=(-shift_size, -shift_size), dims=(1, 2)) 106 | else: 107 | shifted_x = x 108 | # partition windows 109 | x_windows = window_partition(shifted_x, window_size) 110 | return x_windows 111 | 112 | 113 | def reverse_pyt_forward(attn_windows, shift_size, window_size, H, W): 114 | # x in shape(B*nH*nW, window_size, window_size, C) 115 | shifted_x = window_reverse(attn_windows, window_size, H, W) 116 | if shift_size > 0: 117 | x = torch.roll(shifted_x, shifts=(shift_size, shift_size), dims=(1, 2)) 118 | else: 119 | x = shifted_x 120 | return x 121 | 122 | 123 | def copy_one_tensor(input, requires_grad=True): 124 | input1 = input.clone().detach().requires_grad_(requires_grad).cuda() 125 | return input1 126 | 127 | class Test_WindowProcess(unittest.TestCase): 128 | def setUp(self): 129 | self.B = 192 130 | self.H = 56 131 | self.W = 56 132 | self.C = 96 133 | self.shift_size = 2 134 | self.window_size = 7 135 | self.nH = self.H // self.window_size 136 | self.nW = self.W // self.window_size 137 | 138 | def test_roll_and_window_partition_forward(self, dtype=torch.float32): 139 | input = torch.randn((self.B, self.H, self.W, self.C), dtype=dtype, requires_grad=True).cuda() 140 | 141 | input1 = copy_one_tensor(input, True) 142 | input2 = copy_one_tensor(input, True) 143 | 144 | with torch.no_grad(): 145 | # ori 146 | expected = pyt_forward(input1, self.shift_size, self.window_size) 147 | # fused kernel 148 | fused_output = WindowProcess.apply(input2, self.B, 
self.H, self.W, self.C, -self.shift_size, self.window_size) 149 | 150 | self.assertTrue(torch.equal(expected, fused_output)) 151 | #self.assertTrue(torch.allclose(expected, fused_output, rtol=1e-05, atol=1e-08)) 152 | 153 | def test_roll_and_window_partition_backward(self, dtype=torch.float32): 154 | input = torch.randn((self.B, self.H, self.W, self.C), dtype=dtype, requires_grad=True).cuda() 155 | d_loss_tensor = torch.randn((self.B*self.nW*self.nH, self.window_size, self.window_size, self.C), dtype=dtype).cuda() 156 | 157 | input1 = copy_one_tensor(input, True) 158 | input2 = copy_one_tensor(input, True) 159 | 160 | # ori 161 | expected = pyt_forward(input1, self.shift_size, self.window_size) 162 | expected.backward(d_loss_tensor) 163 | # fused kernel 164 | fused_output = WindowProcess.apply(input2, self.B, self.H, self.W, self.C, -self.shift_size, self.window_size) 165 | fused_output.backward(d_loss_tensor) 166 | 167 | self.assertTrue(torch.equal(expected, fused_output)) 168 | #self.assertTrue(torch.allclose(expected, fused_output, rtol=1e-05, atol=1e-08)) 169 | 170 | def test_window_merge_and_roll_forward(self, dtype=torch.float32): 171 | input = torch.randn((self.B*self.nH*self.nW, self.window_size, self.window_size, self.C), dtype=dtype, requires_grad=True).cuda() 172 | 173 | input1 = copy_one_tensor(input, True) 174 | input2 = copy_one_tensor(input, True) 175 | 176 | with torch.no_grad(): 177 | # ori 178 | expected = reverse_pyt_forward(input1, self.shift_size, self.window_size, self.H, self.W) 179 | # fused kernel 180 | fused_output = WindowProcessReverse.apply(input2, self.B, self.H, self.W, self.C, self.shift_size, self.window_size) 181 | 182 | self.assertTrue(torch.equal(expected, fused_output)) 183 | #self.assertTrue(torch.allclose(expected, fused_output, rtol=1e-05, atol=1e-08)) 184 | 185 | 186 | def test_window_merge_and_roll_backward(self, dtype=torch.float32): 187 | input = torch.randn((self.B*self.nH*self.nW, self.window_size, self.window_size, self.C), dtype=dtype, requires_grad=True).cuda() 188 | d_loss_tensor = torch.randn((self.B, self.H, self.W, self.C), dtype=dtype, requires_grad=True).cuda() 189 | 190 | input1 = copy_one_tensor(input, True) 191 | input2 = copy_one_tensor(input, True) 192 | 193 | # ori 194 | expected = reverse_pyt_forward(input1, self.shift_size, self.window_size, self.H, self.W) 195 | expected.backward(d_loss_tensor) 196 | # fused kernel 197 | fused_output = WindowProcessReverse.apply(input2, self.B, self.H, self.W, self.C, self.shift_size, self.window_size) 198 | fused_output.backward(d_loss_tensor) 199 | 200 | self.assertTrue(torch.equal(expected, fused_output)) 201 | #self.assertTrue(torch.allclose(expected, fused_output, rtol=1e-05, atol=1e-08)) 202 | 203 | def test_forward_backward_speed(self, dtype=torch.float32, times=1000): 204 | input = torch.randn((self.B*self.nH*self.nW, self.window_size, self.window_size, self.C), dtype=dtype, requires_grad=True).cuda() 205 | d_loss_tensor = torch.randn((self.B, self.H, self.W, self.C), dtype=dtype, requires_grad=True).cuda() 206 | 207 | input1 = copy_one_tensor(input, True) 208 | input2 = copy_one_tensor(input, True) 209 | 210 | # SwinTransformer official 211 | def run_pyt(t=1000): 212 | for _ in range(t): 213 | expected = reverse_pyt_forward(input1, self.shift_size, self.window_size, self.H, self.W) 214 | expected.backward(d_loss_tensor) 215 | 216 | # my op 217 | def run_fusedop(t=1000): 218 | for _ in range(t): 219 | fused_output = WindowProcessReverse.apply(input2, self.B, self.H, self.W, self.C, 
self.shift_size, self.window_size) 220 | fused_output.backward(d_loss_tensor) 221 | 222 | torch.cuda.synchronize() 223 | t1 = time.time() 224 | run_pyt(t=times) 225 | torch.cuda.synchronize() 226 | t2 = time.time() 227 | run_fusedop(t=times) 228 | torch.cuda.synchronize() 229 | t3 = time.time() 230 | self.assertTrue((t3 - t2) < (t2 - t1)) 231 | 232 | print('Run {} times'.format(times)) 233 | print('Original time cost: {}'.format(t2 - t1)) 234 | print('Fused op time cost: {}'.format(t3 - t2)) 235 | 236 | def test_roll_and_window_partition_forward_fp16(self, dtype=torch.float16): 237 | self.test_roll_and_window_partition_forward(dtype=dtype) 238 | 239 | def test_roll_and_window_partition_backward_fp16(self, dtype=torch.float16): 240 | self.test_roll_and_window_partition_backward(dtype=dtype) 241 | 242 | def test_window_merge_and_roll_forward_fp16(self, dtype=torch.float16): 243 | self.test_window_merge_and_roll_forward(dtype=dtype) 244 | 245 | def test_window_merge_and_roll_backward_fp16(self, dtype=torch.float16): 246 | self.test_window_merge_and_roll_backward(dtype=dtype) 247 | 248 | def test_forward_backward_speed_fp16(self, dtype=torch.float16, times=1000): 249 | self.test_forward_backward_speed(dtype=dtype, times=times) 250 | 251 | 252 | if __name__ == '__main__': 253 | print('Pass only if the two tensors are exactly the same (using torch.equal).\n') 254 | torch.manual_seed(0) 255 | unittest.main(verbosity=2) 256 | -------------------------------------------------------------------------------- /swin/kernels/window_process/swin_window_process_kernel.cu: -------------------------------------------------------------------------------- 1 | /* 2 | This code was originally obtained from: 3 | https://github.com/microsoft/Swin-Transformer 4 | */ 5 | 6 | /* 7 | * Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 8 | * 9 | * Licensed under the Apache License, Version 2.0 (the "License"); 10 | * you may not use this file except in compliance with the License. 11 | * You may obtain a copy of the License at 12 | * 13 | * http://www.apache.org/licenses/LICENSE-2.0 14 | * 15 | * Unless required by applicable law or agreed to in writing, software 16 | * distributed under the License is distributed on an "AS IS" BASIS, 17 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 18 | * See the License for the specific language governing permissions and 19 | * limitations under the License.
20 | */ 21 | 22 | #include <torch/extension.h> 23 | #include <cuda.h> 24 | #include <cuda_runtime.h> 25 | #include <vector> 26 | #include <ATen/ATen.h> 27 | #include <ATen/AccumulateType.h> 28 | 29 | int best_block_dim(int feat_dim){ 30 | int best_dim; 31 | if (feat_dim < 384){ 32 | best_dim = 64; 33 | } 34 | else{ 35 | if (feat_dim < 1024){ 36 | best_dim = 128; 37 | } 38 | else{ 39 | best_dim = 256; 40 | } 41 | } 42 | return best_dim; 43 | } 44 | 45 | 46 | template <typename T> 47 | __global__ void roll_and_window_partition_forward_cuda_kernel( 48 | T* input, 49 | T* output, 50 | const int B, 51 | const int H, 52 | const int W, 53 | const int C, 54 | const int shift_size, 55 | const int window_size, 56 | const int nH, 57 | const int nW){ 58 | // start 59 | //bool qual = threadIdx.x < C; 60 | int index = threadIdx.x; 61 | int offset; 62 | for (int i = index; i < C; i += blockDim.x) { 63 | offset = ((blockIdx.z * gridDim.y + blockIdx.y) * gridDim.x + blockIdx.x) * C + i; // C = blocksize 64 | int input_offset = blockIdx.z / (nH * nW) * H * W * C + 65 | (blockIdx.z % (nH * nW) / nW * window_size + blockIdx.y - shift_size + H) % H * W * C + 66 | (blockIdx.z % nW * window_size + blockIdx.x - shift_size + W) % W * C + 67 | i; 68 | output[offset] = (T)(__ldg(input + input_offset)); 69 | } 70 | } 71 | 72 | 73 | template <typename T> 74 | __global__ void roll_and_window_partition_backward_cuda_kernel( 75 | T* grad_in, 76 | T* grad_out, 77 | const int B, 78 | const int H, 79 | const int W, 80 | const int C, 81 | const int shift_size, 82 | const int window_size, 83 | const int nH, 84 | const int nW){ 85 | // start 86 | int index = threadIdx.x; 87 | int offset; 88 | for (int i = index; i < C; i += blockDim.x) { 89 | offset = ((blockIdx.z * gridDim.y + blockIdx.y) * gridDim.x + blockIdx.x) * C + i; // C = blocksize 90 | int input_offset = 91 | (blockIdx.z * nH * nW + (blockIdx.y + shift_size + H) % H / window_size * nW + (blockIdx.x + shift_size + W) % W / window_size) * window_size * window_size * C + 92 | (blockIdx.y + shift_size + H ) % H % window_size * window_size * C + 93 | (blockIdx.x + shift_size + W ) % W % window_size * C + 94 | i; 95 | grad_out[offset] = (T)(__ldg(grad_in + input_offset)); 96 | } 97 | } 98 | 99 | 100 | template <typename T> 101 | __global__ void window_merge_and_roll_forward_cuda_kernel( 102 | T* input, 103 | T* output, 104 | const int B, 105 | const int H, 106 | const int W, 107 | const int C, 108 | const int shift_size, 109 | const int window_size, 110 | const int nH, 111 | const int nW){ 112 | // start 113 | int index = threadIdx.x; 114 | int offset; 115 | for (int i = index; i < C; i += blockDim.x) { 116 | offset = ((blockIdx.z * gridDim.y + blockIdx.y) * gridDim.x + blockIdx.x) * C + i; // C = blocksize 117 | int input_offset = 118 | (blockIdx.z * nH * nW + (blockIdx.y - shift_size + H) % H / window_size * nH + (blockIdx.x - shift_size + W) % W / window_size) * window_size * window_size * C + 119 | (blockIdx.y - shift_size + H) % window_size * window_size * C + 120 | (blockIdx.x - shift_size + W) % window_size * C + 121 | i; 122 | output[offset] = (T)(__ldg(input + input_offset)); 123 | } 124 | } 125 | 126 | 127 | 128 | template <typename T> 129 | __global__ void window_merge_and_roll_backward_cuda_kernel( 130 | T* grad_in, 131 | T* grad_out, 132 | const int B, 133 | const int H, 134 | const int W, 135 | const int C, 136 | const int shift_size, 137 | const int window_size, 138 | const int nH, 139 | const int nW){ 140 | // start 141 | int index = threadIdx.x; 142 | int offset; 143 | for (int i = index; i < C; i += blockDim.x) { 144 | offset = ((blockIdx.z * gridDim.y + blockIdx.y) * gridDim.x + blockIdx.x) * C +
i; // C = blocksize 145 | int input_offset = 146 | (blockIdx.z / (nH * nW)) * H * W * C + 147 | (blockIdx.z % (nH * nW) / nW * window_size + blockIdx.y + shift_size + H) % H * W * C + 148 | (blockIdx.z % nW * window_size + blockIdx.x + shift_size + W) % W * C + 149 | i; 150 | grad_out[offset] = (T)(__ldg(grad_in + input_offset)); 151 | } 152 | } 153 | 154 | // input: [B, H, W, C] 155 | // output: [B*nH*nW, window_size, window_size, C] 156 | at::Tensor roll_and_window_partition_forward_cuda( 157 | at::Tensor & input, 158 | //at::Tensor & output, 159 | const int B, 160 | const int H, 161 | const int W, 162 | const int C, 163 | const int shift_size, 164 | const int window_size){ 165 | 166 | int nH = H / window_size; 167 | int nW = W / window_size; 168 | 169 | dim3 grid(window_size, window_size, B * nH * nW); 170 | //dim3 block((C + 31) / 32 * 32); 171 | int blocknum = best_block_dim(C); 172 | dim3 block(blocknum); 173 | 174 | at::Tensor output; 175 | if (input.scalar_type() == torch::kFloat16){ 176 | output = torch::empty({B*nH*nW, window_size, window_size, C}, torch::dtype(torch::kFloat16).device(torch::kCUDA).requires_grad(true)); 177 | } 178 | else{ 179 | output = torch::empty({B*nH*nW, window_size, window_size, C}, torch::dtype(torch::kFloat32).device(torch::kCUDA).requires_grad(true)); 180 | } 181 | 182 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "roll_and_window_partition_forward_cuda_kernel", ([&] { 183 | roll_and_window_partition_forward_cuda_kernel<scalar_t><<<grid, block>>>( 184 | input.data<scalar_t>(), 185 | output.data<scalar_t>(), 186 | B, 187 | H, 188 | W, 189 | C, 190 | shift_size, 191 | window_size, 192 | nH, 193 | nW); 194 | })); 195 | return output; 196 | } 197 | 198 | 199 | // grad_in: [B*nH*nW, window_size, window_size, C] 200 | // grad_out: [B, H, W, C] 201 | at::Tensor roll_and_window_partition_backward_cuda( 202 | at::Tensor & grad_in, 203 | const int B, 204 | const int H, 205 | const int W, 206 | const int C, 207 | const int shift_size, 208 | const int window_size){ 209 | 210 | int nH = H / window_size; 211 | int nW = W / window_size; 212 | 213 | dim3 grid(W, H, B); 214 | //dim3 block((C + 31) / 32 * 32); 215 | int blocknum = best_block_dim(C); 216 | dim3 block(blocknum); 217 | 218 | at::Tensor grad_out; 219 | if (grad_in.scalar_type() == torch::kFloat16){ 220 | grad_out = torch::empty({B, H, W, C}, torch::dtype(torch::kFloat16).device(torch::kCUDA).requires_grad(false)); 221 | } 222 | else{ 223 | grad_out = torch::empty({B, H, W, C}, torch::dtype(torch::kFloat32).device(torch::kCUDA).requires_grad(false)); 224 | } 225 | 226 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad_in.type(), "roll_and_window_partition_backward_cuda_kernel", ([&] { 227 | roll_and_window_partition_backward_cuda_kernel<scalar_t><<<grid, block>>>( 228 | grad_in.data<scalar_t>(), 229 | grad_out.data<scalar_t>(), 230 | B, 231 | H, 232 | W, 233 | C, 234 | shift_size, 235 | window_size, 236 | nH, 237 | nW); 238 | })); 239 | return grad_out; 240 | } 241 | 242 | 243 | // input: [B*nH*nW, window_size, window_size, C] 244 | // output: [B, H, W, C] 245 | at::Tensor window_merge_and_roll_forward_cuda( 246 | at::Tensor & input, 247 | //at::Tensor & output, 248 | const int B, 249 | const int H, 250 | const int W, 251 | const int C, 252 | const int shift_size, 253 | const int window_size){ 254 | 255 | int nH = H / window_size; 256 | int nW = W / window_size; 257 | 258 | dim3 grid(W, H, B); 259 | //dim3 block((C + 31) / 32 * 32); 260 | int blocknum = best_block_dim(C); 261 | dim3 block(blocknum); 262 | 263 | //generate output tensor inside 264 | at::Tensor output; 265 | if
(input.scalar_type() == torch::kFloat16){ 266 | output = torch::empty({B, H, W, C}, torch::dtype(torch::kFloat16).device(torch::kCUDA).requires_grad(true)); 267 | } 268 | else{ 269 | output = torch::empty({B, H, W, C}, torch::dtype(torch::kFloat32).device(torch::kCUDA).requires_grad(true)); 270 | } 271 | 272 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "window_merge_and_roll_forward_cuda_kernel", ([&] { 273 | window_merge_and_roll_forward_cuda_kernel<scalar_t><<<grid, block>>>( 274 | input.data<scalar_t>(), 275 | output.data<scalar_t>(), 276 | B, 277 | H, 278 | W, 279 | C, 280 | shift_size, 281 | window_size, 282 | nH, 283 | nW); 284 | })); 285 | return output; 286 | } 287 | 288 | 289 | at::Tensor window_merge_and_roll_backward_cuda( 290 | at::Tensor & grad_in, 291 | const int B, 292 | const int H, 293 | const int W, 294 | const int C, 295 | const int shift_size, 296 | const int window_size){ 297 | 298 | int nH = H / window_size; 299 | int nW = W / window_size; 300 | 301 | dim3 grid(window_size, window_size, B * nH * nW); 302 | //dim3 block((C + 31) / 32 * 32); 303 | int blocknum = best_block_dim(C); 304 | dim3 block(blocknum); 305 | 306 | at::Tensor grad_out; 307 | if (grad_in.scalar_type() == torch::kFloat16){ 308 | grad_out = torch::empty({B*nH*nW, window_size, window_size, C}, torch::dtype(torch::kFloat16).device(torch::kCUDA).requires_grad(false)); 309 | } 310 | else{ 311 | grad_out = torch::empty({B*nH*nW, window_size, window_size, C}, torch::dtype(torch::kFloat32).device(torch::kCUDA).requires_grad(false)); 312 | } 313 | 314 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad_in.type(), "window_merge_and_roll_backward_cuda_kernel", ([&] { 315 | window_merge_and_roll_backward_cuda_kernel<scalar_t><<<grid, block>>>( 316 | grad_in.data<scalar_t>(), 317 | grad_out.data<scalar_t>(), 318 | B, 319 | H, 320 | W, 321 | C, 322 | shift_size, 323 | window_size, 324 | nH, 325 | nW); 326 | })); 327 | return grad_out; 328 | } -------------------------------------------------------------------------------- /swin/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code was originally obtained from: 3 | https://github.com/microsoft/Swin-Transformer 4 | """ 5 | 6 | # -------------------------------------------------------- 7 | # Swin Transformer 8 | # Copyright (c) 2021 Microsoft 9 | # Licensed under The MIT License [see LICENSE for details] 10 | # Written by Ze Liu 11 | # -------------------------------------------------------- 12 | 13 | import os 14 | import torch 15 | import torch.distributed as dist 16 | from torch._six import inf 17 | from timm.utils import AverageMeter 18 | 19 | 20 | def load_checkpoint(config, model, optimizer, lr_scheduler, loss_scaler, logger): 21 | logger.info(f"==============> Resuming from {config.MODEL.RESUME}....................") 22 | if config.MODEL.RESUME.startswith('https'): 23 | checkpoint = torch.hub.load_state_dict_from_url( 24 | config.MODEL.RESUME, map_location='cpu', check_hash=True) 25 | else: 26 | checkpoint = torch.load(config.MODEL.RESUME, map_location='cpu') 27 | msg = model.load_state_dict(checkpoint['model'], strict=False) 28 | logger.info(msg) 29 | max_accuracy = 0.0 30 | if not config.EVAL_MODE and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint: 31 | optimizer.load_state_dict(checkpoint['optimizer']) 32 | lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) 33 | config.defrost() 34 | config.TRAIN.START_EPOCH = checkpoint['epoch'] + 1 35 | config.freeze() 36 | if 'scaler' in checkpoint: 37 | loss_scaler.load_state_dict(checkpoint['scaler']) 38
| logger.info(f"=> loaded successfully '{config.MODEL.RESUME}' (epoch {checkpoint['epoch']})") 39 | if 'max_accuracy' in checkpoint: 40 | max_accuracy = checkpoint['max_accuracy'] 41 | 42 | del checkpoint 43 | torch.cuda.empty_cache() 44 | return max_accuracy 45 | 46 | 47 | def load_pretrained(config, model, logger): 48 | logger.info(f"==============> Loading weight {config.MODEL.PRETRAINED} for fine-tuning......") 49 | checkpoint = torch.load(config.MODEL.PRETRAINED, map_location='cpu') 50 | state_dict = checkpoint['model'] 51 | 52 | # delete relative_position_index since we always re-init it 53 | relative_position_index_keys = [k for k in state_dict.keys() if "relative_position_index" in k] 54 | for k in relative_position_index_keys: 55 | del state_dict[k] 56 | 57 | # delete relative_coords_table since we always re-init it 58 | relative_position_index_keys = [k for k in state_dict.keys() if "relative_coords_table" in k] 59 | for k in relative_position_index_keys: 60 | del state_dict[k] 61 | 62 | # delete attn_mask since we always re-init it 63 | attn_mask_keys = [k for k in state_dict.keys() if "attn_mask" in k] 64 | for k in attn_mask_keys: 65 | del state_dict[k] 66 | 67 | # bicubic interpolate relative_position_bias_table if not match 68 | relative_position_bias_table_keys = [k for k in state_dict.keys() if "relative_position_bias_table" in k] 69 | for k in relative_position_bias_table_keys: 70 | relative_position_bias_table_pretrained = state_dict[k] 71 | relative_position_bias_table_current = model.state_dict()[k] 72 | L1, nH1 = relative_position_bias_table_pretrained.size() 73 | L2, nH2 = relative_position_bias_table_current.size() 74 | if nH1 != nH2: 75 | logger.warning(f"Error in loading {k}, passing......") 76 | else: 77 | if L1 != L2: 78 | # bicubic interpolate relative_position_bias_table if not match 79 | S1 = int(L1 ** 0.5) 80 | S2 = int(L2 ** 0.5) 81 | relative_position_bias_table_pretrained_resized = torch.nn.functional.interpolate( 82 | relative_position_bias_table_pretrained.permute(1, 0).view(1, nH1, S1, S1), size=(S2, S2), 83 | mode='bicubic') 84 | state_dict[k] = relative_position_bias_table_pretrained_resized.view(nH2, L2).permute(1, 0) 85 | 86 | # bicubic interpolate absolute_pos_embed if not match 87 | absolute_pos_embed_keys = [k for k in state_dict.keys() if "absolute_pos_embed" in k] 88 | for k in absolute_pos_embed_keys: 89 | # dpe 90 | absolute_pos_embed_pretrained = state_dict[k] 91 | absolute_pos_embed_current = model.state_dict()[k] 92 | _, L1, C1 = absolute_pos_embed_pretrained.size() 93 | _, L2, C2 = absolute_pos_embed_current.size() 94 | if C1 != C1: 95 | logger.warning(f"Error in loading {k}, passing......") 96 | else: 97 | if L1 != L2: 98 | S1 = int(L1 ** 0.5) 99 | S2 = int(L2 ** 0.5) 100 | absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.reshape(-1, S1, S1, C1) 101 | absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.permute(0, 3, 1, 2) 102 | absolute_pos_embed_pretrained_resized = torch.nn.functional.interpolate( 103 | absolute_pos_embed_pretrained, size=(S2, S2), mode='bicubic') 104 | absolute_pos_embed_pretrained_resized = absolute_pos_embed_pretrained_resized.permute(0, 2, 3, 1) 105 | absolute_pos_embed_pretrained_resized = absolute_pos_embed_pretrained_resized.flatten(1, 2) 106 | state_dict[k] = absolute_pos_embed_pretrained_resized 107 | 108 | # check classifier, if not match, then re-init classifier to zero 109 | head_bias_pretrained = state_dict['head.bias'] 110 | Nc1 = head_bias_pretrained.shape[0] 111 | Nc2 = 
model.head.bias.shape[0] 112 | if (Nc1 != Nc2): 113 | if Nc1 == 21841 and Nc2 == 1000: 114 | logger.info("loading ImageNet-22K weight to ImageNet-1K ......") 115 | map22kto1k_path = f'data/map22kto1k.txt' 116 | with open(map22kto1k_path) as f: 117 | map22kto1k = f.readlines() 118 | map22kto1k = [int(id22k.strip()) for id22k in map22kto1k] 119 | state_dict['head.weight'] = state_dict['head.weight'][map22kto1k, :] 120 | state_dict['head.bias'] = state_dict['head.bias'][map22kto1k] 121 | else: 122 | torch.nn.init.constant_(model.head.bias, 0.) 123 | torch.nn.init.constant_(model.head.weight, 0.) 124 | del state_dict['head.weight'] 125 | del state_dict['head.bias'] 126 | logger.warning(f"Error in loading classifier head, re-init classifier head to 0") 127 | 128 | msg = model.load_state_dict(state_dict, strict=False) 129 | logger.warning(msg) 130 | 131 | logger.info(f"=> loaded successfully '{config.MODEL.PRETRAINED}'") 132 | 133 | del checkpoint 134 | torch.cuda.empty_cache() 135 | 136 | 137 | def save_checkpoint(config, epoch, model, max_accuracy, optimizer, lr_scheduler, loss_scaler, logger): 138 | save_state = {'model': model.state_dict(), 139 | 'optimizer': optimizer.state_dict(), 140 | 'lr_scheduler': lr_scheduler.state_dict(), 141 | 'max_accuracy': max_accuracy, 142 | 'scaler': loss_scaler.state_dict(), 143 | 'epoch': epoch, 144 | 'config': config} 145 | 146 | save_path = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth') 147 | logger.info(f"{save_path} saving......") 148 | torch.save(save_state, save_path) 149 | logger.info(f"{save_path} saved !!!") 150 | 151 | 152 | def get_grad_norm(parameters, norm_type=2): 153 | if isinstance(parameters, torch.Tensor): 154 | parameters = [parameters] 155 | parameters = list(filter(lambda p: p.grad is not None, parameters)) 156 | norm_type = float(norm_type) 157 | total_norm = 0 158 | for p in parameters: 159 | param_norm = p.grad.data.norm(norm_type) 160 | total_norm += param_norm.item() ** norm_type 161 | total_norm = total_norm ** (1. / norm_type) 162 | return total_norm 163 | 164 | 165 | def auto_resume_helper(output_dir): 166 | checkpoints = os.listdir(output_dir) 167 | checkpoints = [ckpt for ckpt in checkpoints if ckpt.endswith('pth')] 168 | print(f"All checkpoints found in {output_dir}: {checkpoints}") 169 | if len(checkpoints) > 0: 170 | latest_checkpoint = max([os.path.join(output_dir, d) for d in checkpoints], key=os.path.getmtime) 171 | print(f"The latest checkpoint found: {latest_checkpoint}") 172 | resume_file = latest_checkpoint 173 | else: 174 | resume_file = None 175 | return resume_file 176 | 177 | 178 | def reduce_tensor(tensor): 179 | rt = tensor.clone() 180 | dist.all_reduce(rt, op=dist.ReduceOp.SUM) 181 | rt /= dist.get_world_size() 182 | return rt 183 | 184 | 185 | def ampscaler_get_grad_norm(parameters, norm_type: float = 2.0) -> torch.Tensor: 186 | if isinstance(parameters, torch.Tensor): 187 | parameters = [parameters] 188 | parameters = [p for p in parameters if p.grad is not None] 189 | norm_type = float(norm_type) 190 | if len(parameters) == 0: 191 | return torch.tensor(0.)
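# infinity norm: take the max absolute gradient entry over all parameters; otherwise stack per-parameter norms and take the overall p-norm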
192 | device = parameters[0].grad.device 193 | if norm_type == inf: 194 | total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) 195 | else: 196 | total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), 197 | norm_type).to(device) for p in parameters]), norm_type) 198 | return total_norm 199 | 200 | 201 | class NativeScalerWithGradNormCount: 202 | state_dict_key = "amp_scaler" 203 | 204 | def __init__(self): 205 | self._scaler = torch.cuda.amp.GradScaler() 206 | 207 | def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True): 208 | self._scaler.scale(loss).backward(create_graph=create_graph) 209 | if update_grad: 210 | if clip_grad is not None: 211 | assert parameters is not None 212 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place 213 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) 214 | else: 215 | self._scaler.unscale_(optimizer) 216 | norm = ampscaler_get_grad_norm(parameters) 217 | self._scaler.step(optimizer) 218 | self._scaler.update() 219 | else: 220 | norm = None 221 | return norm 222 | 223 | def state_dict(self): 224 | return self._scaler.state_dict() 225 | 226 | def load_state_dict(self, state_dict): 227 | self._scaler.load_state_dict(state_dict) 228 | 229 | 230 | def synchronize_between_processes(meter): 231 | """ 232 | only for timm AverageMeter 233 | """ 234 | if not is_dist_avail_and_initialized(): 235 | return meter 236 | assert isinstance(meter, AverageMeter), 'sychronize only available with timm.utils.AverageMeter' 237 | t = torch.tensor([meter.val, meter.avg, meter.sum, meter.count], dtype=torch.float64, device='cuda') 238 | dist.barrier() 239 | dist.all_reduce(t) 240 | t = t.tolist() 241 | meter.val = t[0] 242 | meter.avg = t[1] 243 | meter.sum = t[2] 244 | meter.count = int(t[3]) 245 | 246 | def is_dist_avail_and_initialized(): 247 | if not dist.is_available(): 248 | return False 249 | if not dist.is_initialized(): 250 | return False 251 | return True -------------------------------------------------------------------------------- /swin/data/map22kto1k.txt: -------------------------------------------------------------------------------- 1 | 359 2 | 368 3 | 460 4 | 475 5 | 486 6 | 492 7 | 496 8 | 514 9 | 516 10 | 525 11 | 547 12 | 548 13 | 556 14 | 563 15 | 575 16 | 641 17 | 648 18 | 723 19 | 733 20 | 765 21 | 801 22 | 826 23 | 852 24 | 858 25 | 878 26 | 896 27 | 900 28 | 905 29 | 908 30 | 910 31 | 935 32 | 946 33 | 947 34 | 994 35 | 999 36 | 1003 37 | 1005 38 | 1010 39 | 1027 40 | 1029 41 | 1048 42 | 1055 43 | 1064 44 | 1065 45 | 1069 46 | 1075 47 | 1079 48 | 1081 49 | 1085 50 | 1088 51 | 1093 52 | 1106 53 | 1143 54 | 1144 55 | 1145 56 | 1147 57 | 1168 58 | 1171 59 | 1178 60 | 1187 61 | 1190 62 | 1197 63 | 1205 64 | 1216 65 | 1223 66 | 1230 67 | 1236 68 | 1241 69 | 1245 70 | 1257 71 | 1259 72 | 1260 73 | 1267 74 | 1268 75 | 1269 76 | 1271 77 | 1272 78 | 1273 79 | 1277 80 | 1303 81 | 1344 82 | 1349 83 | 1355 84 | 1357 85 | 1384 86 | 1388 87 | 1391 88 | 1427 89 | 1429 90 | 1432 91 | 1437 92 | 1450 93 | 1461 94 | 1462 95 | 1474 96 | 1502 97 | 1503 98 | 1512 99 | 1552 100 | 1555 101 | 1577 102 | 1584 103 | 1587 104 | 1589 105 | 1599 106 | 1615 107 | 1616 108 | 1681 109 | 1692 110 | 1701 111 | 1716 112 | 1729 113 | 1757 114 | 1759 115 | 1764 116 | 1777 117 | 1786 118 | 1822 119 | 1841 120 | 1842 121 | 1848 122 | 1850 123 | 1856 124 | 1860 125 | 1861 126 | 1864 127 | 1876 128 | 1897 129 | 1898 130 | 1910 131 | 1913 132 | 1918 133 | 
1922 134 | 1928 135 | 1932 136 | 1935 137 | 1947 138 | 1951 139 | 1953 140 | 1970 141 | 1977 142 | 1979 143 | 2001 144 | 2017 145 | 2067 146 | 2081 147 | 2087 148 | 2112 149 | 2128 150 | 2135 151 | 2147 152 | 2174 153 | 2175 154 | 2176 155 | 2177 156 | 2178 157 | 2181 158 | 2183 159 | 2184 160 | 2187 161 | 2189 162 | 2190 163 | 2191 164 | 2192 165 | 2193 166 | 2197 167 | 2202 168 | 2203 169 | 2206 170 | 2208 171 | 2209 172 | 2211 173 | 2212 174 | 2213 175 | 2214 176 | 2215 177 | 2216 178 | 2217 179 | 2219 180 | 2222 181 | 2223 182 | 2224 183 | 2225 184 | 2226 185 | 2227 186 | 2228 187 | 2229 188 | 2230 189 | 2236 190 | 2238 191 | 2240 192 | 2241 193 | 2242 194 | 2243 195 | 2244 196 | 2245 197 | 2247 198 | 2248 199 | 2249 200 | 2250 201 | 2251 202 | 2252 203 | 2255 204 | 2256 205 | 2257 206 | 2262 207 | 2263 208 | 2264 209 | 2265 210 | 2266 211 | 2268 212 | 2270 213 | 2271 214 | 2272 215 | 2273 216 | 2275 217 | 2276 218 | 2279 219 | 2280 220 | 2281 221 | 2282 222 | 2285 223 | 2289 224 | 2292 225 | 2295 226 | 2296 227 | 2297 228 | 2298 229 | 2299 230 | 2300 231 | 2301 232 | 2302 233 | 2303 234 | 2304 235 | 2305 236 | 2306 237 | 2309 238 | 2310 239 | 2312 240 | 2313 241 | 2314 242 | 2315 243 | 2316 244 | 2318 245 | 2319 246 | 2321 247 | 2322 248 | 2326 249 | 2329 250 | 2330 251 | 2331 252 | 2332 253 | 2334 254 | 2335 255 | 2336 256 | 2337 257 | 2338 258 | 2339 259 | 2341 260 | 2342 261 | 2343 262 | 2344 263 | 2346 264 | 2348 265 | 2349 266 | 2351 267 | 2352 268 | 2353 269 | 2355 270 | 2357 271 | 2358 272 | 2359 273 | 2360 274 | 2364 275 | 2365 276 | 2368 277 | 2369 278 | 2377 279 | 2382 280 | 2383 281 | 2385 282 | 2397 283 | 2398 284 | 2400 285 | 2402 286 | 2405 287 | 2412 288 | 2421 289 | 2428 290 | 2431 291 | 2432 292 | 2433 293 | 2436 294 | 2441 295 | 2445 296 | 2450 297 | 2453 298 | 2454 299 | 2465 300 | 2469 301 | 2532 302 | 2533 303 | 2538 304 | 2544 305 | 2547 306 | 2557 307 | 2565 308 | 2578 309 | 2612 310 | 2658 311 | 2702 312 | 2722 313 | 2731 314 | 2738 315 | 2741 316 | 2747 317 | 2810 318 | 2818 319 | 2833 320 | 2844 321 | 2845 322 | 2867 323 | 2874 324 | 2882 325 | 2884 326 | 2888 327 | 2889 328 | 3008 329 | 3012 330 | 3019 331 | 3029 332 | 3033 333 | 3042 334 | 3091 335 | 3106 336 | 3138 337 | 3159 338 | 3164 339 | 3169 340 | 3280 341 | 3296 342 | 3311 343 | 3318 344 | 3320 345 | 3324 346 | 3330 347 | 3366 348 | 3375 349 | 3381 350 | 3406 351 | 3419 352 | 3432 353 | 3434 354 | 3435 355 | 3493 356 | 3495 357 | 3503 358 | 3509 359 | 3511 360 | 3513 361 | 3517 362 | 3521 363 | 3526 364 | 3546 365 | 3554 366 | 3600 367 | 3601 368 | 3606 369 | 3612 370 | 3613 371 | 3616 372 | 3622 373 | 3623 374 | 3627 375 | 3632 376 | 3634 377 | 3636 378 | 3638 379 | 3644 380 | 3646 381 | 3649 382 | 3650 383 | 3651 384 | 3656 385 | 3663 386 | 3673 387 | 3674 388 | 3689 389 | 3690 390 | 3702 391 | 3733 392 | 3769 393 | 3971 394 | 3974 395 | 4065 396 | 4068 397 | 4073 398 | 4102 399 | 4136 400 | 4140 401 | 4151 402 | 4159 403 | 4165 404 | 4207 405 | 4219 406 | 4226 407 | 4249 408 | 4256 409 | 4263 410 | 4270 411 | 4313 412 | 4321 413 | 4378 414 | 4386 415 | 4478 416 | 4508 417 | 4512 418 | 4536 419 | 4542 420 | 4550 421 | 4560 422 | 4562 423 | 4570 424 | 4571 425 | 4572 426 | 4583 427 | 4588 428 | 4594 429 | 4604 430 | 4608 431 | 4623 432 | 4634 433 | 4636 434 | 4646 435 | 4651 436 | 4652 437 | 4686 438 | 4688 439 | 4691 440 | 4699 441 | 4724 442 | 4727 443 | 4737 444 | 4770 445 | 4774 446 | 4789 447 | 4802 448 | 4807 449 | 4819 450 | 4880 451 | 4886 452 | 4908 453 | 4927 454 | 4931 455 | 4936 456 | 
4964 457 | 4976 458 | 4993 459 | 5028 460 | 5033 461 | 5043 462 | 5046 463 | 5096 464 | 5111 465 | 5114 466 | 5131 467 | 5132 468 | 5183 469 | 5199 470 | 5235 471 | 5275 472 | 5291 473 | 5293 474 | 5294 475 | 5343 476 | 5360 477 | 5362 478 | 5364 479 | 5390 480 | 5402 481 | 5418 482 | 5428 483 | 5430 484 | 5437 485 | 5443 486 | 5473 487 | 5484 488 | 5486 489 | 5505 490 | 5507 491 | 5508 492 | 5510 493 | 5567 494 | 5578 495 | 5580 496 | 5584 497 | 5606 498 | 5613 499 | 5629 500 | 5672 501 | 5676 502 | 5692 503 | 5701 504 | 5760 505 | 5769 506 | 5770 507 | 5779 508 | 5814 509 | 5850 510 | 5871 511 | 5893 512 | 5911 513 | 5949 514 | 5954 515 | 6005 516 | 6006 517 | 6012 518 | 6017 519 | 6023 520 | 6024 521 | 6040 522 | 6050 523 | 6054 524 | 6087 525 | 6105 526 | 6157 527 | 6235 528 | 6237 529 | 6256 530 | 6259 531 | 6286 532 | 6291 533 | 6306 534 | 6339 535 | 6341 536 | 6343 537 | 6379 538 | 6383 539 | 6393 540 | 6405 541 | 6479 542 | 6511 543 | 6517 544 | 6541 545 | 6561 546 | 6608 547 | 6611 548 | 6615 549 | 6678 550 | 6682 551 | 6707 552 | 6752 553 | 6798 554 | 6850 555 | 6880 556 | 6885 557 | 6890 558 | 6920 559 | 6981 560 | 7000 561 | 7009 562 | 7038 563 | 7049 564 | 7050 565 | 7052 566 | 7073 567 | 7078 568 | 7098 569 | 7111 570 | 7165 571 | 7198 572 | 7204 573 | 7280 574 | 7283 575 | 7286 576 | 7287 577 | 7293 578 | 7294 579 | 7305 580 | 7318 581 | 7341 582 | 7346 583 | 7354 584 | 7382 585 | 7427 586 | 7428 587 | 7435 588 | 7445 589 | 7450 590 | 7455 591 | 7467 592 | 7469 593 | 7497 594 | 7502 595 | 7506 596 | 7514 597 | 7523 598 | 7651 599 | 7661 600 | 7664 601 | 7672 602 | 7679 603 | 7685 604 | 7696 605 | 7730 606 | 7871 607 | 7873 608 | 7895 609 | 7914 610 | 7915 611 | 7920 612 | 7934 613 | 7935 614 | 7949 615 | 8009 616 | 8036 617 | 8051 618 | 8065 619 | 8074 620 | 8090 621 | 8112 622 | 8140 623 | 8164 624 | 8168 625 | 8178 626 | 8182 627 | 8198 628 | 8212 629 | 8216 630 | 8230 631 | 8242 632 | 8288 633 | 8289 634 | 8295 635 | 8318 636 | 8352 637 | 8368 638 | 8371 639 | 8375 640 | 8376 641 | 8401 642 | 8416 643 | 8419 644 | 8436 645 | 8460 646 | 8477 647 | 8478 648 | 8482 649 | 8498 650 | 8500 651 | 8539 652 | 8543 653 | 8552 654 | 8555 655 | 8580 656 | 8584 657 | 8586 658 | 8594 659 | 8598 660 | 8601 661 | 8606 662 | 8610 663 | 8611 664 | 8622 665 | 8627 666 | 8639 667 | 8649 668 | 8650 669 | 8653 670 | 8654 671 | 8667 672 | 8672 673 | 8673 674 | 8674 675 | 8676 676 | 8684 677 | 8720 678 | 8723 679 | 8750 680 | 8753 681 | 8801 682 | 8815 683 | 8831 684 | 8835 685 | 8842 686 | 8845 687 | 8858 688 | 8897 689 | 8916 690 | 8951 691 | 8954 692 | 8959 693 | 8970 694 | 8976 695 | 8981 696 | 8983 697 | 8989 698 | 8991 699 | 8993 700 | 9019 701 | 9039 702 | 9042 703 | 9043 704 | 9056 705 | 9057 706 | 9070 707 | 9087 708 | 9098 709 | 9106 710 | 9130 711 | 9131 712 | 9155 713 | 9171 714 | 9183 715 | 9198 716 | 9199 717 | 9201 718 | 9204 719 | 9212 720 | 9221 721 | 9225 722 | 9229 723 | 9250 724 | 9260 725 | 9271 726 | 9279 727 | 9295 728 | 9300 729 | 9310 730 | 9322 731 | 9345 732 | 9352 733 | 9376 734 | 9377 735 | 9382 736 | 9392 737 | 9401 738 | 9405 739 | 9441 740 | 9449 741 | 9464 742 | 9475 743 | 9502 744 | 9505 745 | 9514 746 | 9515 747 | 9545 748 | 9567 749 | 9576 750 | 9608 751 | 9609 752 | 9624 753 | 9633 754 | 9639 755 | 9643 756 | 9656 757 | 9674 758 | 9740 759 | 9752 760 | 9760 761 | 9767 762 | 9778 763 | 9802 764 | 9820 765 | 9839 766 | 9879 767 | 9924 768 | 9956 769 | 9961 770 | 9963 771 | 9970 772 | 9997 773 | 10010 774 | 10031 775 | 10040 776 | 10052 777 | 10073 778 | 10075 
779 | 10078 780 | 10094 781 | 10097 782 | 10109 783 | 10118 784 | 10121 785 | 10124 786 | 10158 787 | 10226 788 | 10276 789 | 10304 790 | 10307 791 | 10314 792 | 10315 793 | 10332 794 | 10337 795 | 10338 796 | 10413 797 | 10423 798 | 10451 799 | 10463 800 | 10465 801 | 10487 802 | 10519 803 | 10522 804 | 10523 805 | 10532 806 | 10534 807 | 10535 808 | 10551 809 | 10559 810 | 10574 811 | 10583 812 | 10586 813 | 10589 814 | 10612 815 | 10626 816 | 10635 817 | 10638 818 | 10677 819 | 10683 820 | 10726 821 | 10776 822 | 10782 823 | 10783 824 | 10807 825 | 10837 826 | 10840 827 | 10848 828 | 10859 829 | 10871 830 | 10881 831 | 10884 832 | 10908 833 | 10914 834 | 10921 835 | 10936 836 | 10947 837 | 10951 838 | 10952 839 | 10957 840 | 10999 841 | 11003 842 | 11018 843 | 11023 844 | 11025 845 | 11027 846 | 11045 847 | 11055 848 | 11095 849 | 11110 850 | 11137 851 | 5564 852 | 11168 853 | 11186 854 | 11221 855 | 11223 856 | 11242 857 | 11255 858 | 11259 859 | 11279 860 | 11306 861 | 11311 862 | 11331 863 | 11367 864 | 11377 865 | 11389 866 | 11392 867 | 11401 868 | 11407 869 | 11437 870 | 11449 871 | 11466 872 | 11469 873 | 11473 874 | 11478 875 | 11483 876 | 11484 877 | 11507 878 | 11536 879 | 11558 880 | 11566 881 | 11575 882 | 11584 883 | 11594 884 | 11611 885 | 11612 886 | 11619 887 | 11621 888 | 11640 889 | 11643 890 | 11664 891 | 11674 892 | 11689 893 | 11709 894 | 11710 895 | 11716 896 | 11721 897 | 11726 898 | 11729 899 | 11743 900 | 11760 901 | 11771 902 | 11837 903 | 11839 904 | 11856 905 | 11876 906 | 11878 907 | 11884 908 | 11889 909 | 11896 910 | 11917 911 | 11923 912 | 11930 913 | 11944 914 | 11952 915 | 11980 916 | 11984 917 | 12214 918 | 12229 919 | 12239 920 | 12241 921 | 12242 922 | 12247 923 | 12283 924 | 12349 925 | 12369 926 | 12373 927 | 12422 928 | 12560 929 | 12566 930 | 12575 931 | 12688 932 | 12755 933 | 12768 934 | 12778 935 | 12780 936 | 12812 937 | 12832 938 | 12835 939 | 12836 940 | 12843 941 | 12847 942 | 12849 943 | 12850 944 | 12856 945 | 12858 946 | 12873 947 | 12938 948 | 12971 949 | 13017 950 | 13038 951 | 13046 952 | 13059 953 | 13085 954 | 13086 955 | 13088 956 | 13094 957 | 13134 958 | 13182 959 | 13230 960 | 13406 961 | 13444 962 | 13614 963 | 13690 964 | 13698 965 | 13709 966 | 13749 967 | 13804 968 | 13982 969 | 14051 970 | 14059 971 | 14219 972 | 14246 973 | 14256 974 | 14264 975 | 14294 976 | 14324 977 | 14367 978 | 14389 979 | 14394 980 | 14438 981 | 14442 982 | 14965 983 | 15732 984 | 16744 985 | 18037 986 | 18205 987 | 18535 988 | 18792 989 | 19102 990 | 20019 991 | 20462 992 | 21026 993 | 21045 994 | 21163 995 | 21171 996 | 21181 997 | 21196 998 | 21200 999 | 21369 1000 | 21817 -------------------------------------------------------------------------------- /mae/util/misc.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code was originally obtained from: 3 | https://github.com/facebookresearch/mae 4 | """ 5 | 6 | # Copyright (c) Meta Platforms, Inc. and affiliates. 7 | # All rights reserved. 8 | 9 | # This source code is licensed under the license found in the 10 | # LICENSE file in the root directory of this source tree. 
11 | # -------------------------------------------------------- 12 | # References: 13 | # DeiT: https://github.com/facebookresearch/deit 14 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 15 | # -------------------------------------------------------- 16 | 17 | import builtins 18 | import datetime 19 | import os 20 | import time 21 | from collections import defaultdict, deque 22 | from pathlib import Path 23 | 24 | import torch 25 | import torch.distributed as dist 26 | from torch._six import inf 27 | 28 | 29 | class SmoothedValue(object): 30 | """Track a series of values and provide access to smoothed values over a 31 | window or the global series average. 32 | """ 33 | 34 | def __init__(self, window_size=20, fmt=None): 35 | if fmt is None: 36 | fmt = "{median:.4f} ({global_avg:.4f})" 37 | self.deque = deque(maxlen=window_size) 38 | self.total = 0.0 39 | self.count = 0 40 | self.fmt = fmt 41 | 42 | def update(self, value, n=1): 43 | self.deque.append(value) 44 | self.count += n 45 | self.total += value * n 46 | 47 | def synchronize_between_processes(self): 48 | """ 49 | Warning: does not synchronize the deque! 50 | """ 51 | if not is_dist_avail_and_initialized(): 52 | return 53 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 54 | dist.barrier() 55 | dist.all_reduce(t) 56 | t = t.tolist() 57 | self.count = int(t[0]) 58 | self.total = t[1] 59 | 60 | @property 61 | def median(self): 62 | d = torch.tensor(list(self.deque)) 63 | return d.median().item() 64 | 65 | @property 66 | def avg(self): 67 | d = torch.tensor(list(self.deque), dtype=torch.float32) 68 | return d.mean().item() 69 | 70 | @property 71 | def global_avg(self): 72 | return self.total / self.count 73 | 74 | @property 75 | def max(self): 76 | return max(self.deque) 77 | 78 | @property 79 | def value(self): 80 | return self.deque[-1] 81 | 82 | def __str__(self): 83 | return self.fmt.format( 84 | median=self.median, 85 | avg=self.avg, 86 | global_avg=self.global_avg, 87 | max=self.max, 88 | value=self.value) 89 | 90 | 91 | class MetricLogger(object): 92 | def __init__(self, delimiter="\t"): 93 | self.meters = defaultdict(SmoothedValue) 94 | self.delimiter = delimiter 95 | 96 | def update(self, **kwargs): 97 | for k, v in kwargs.items(): 98 | if v is None: 99 | continue 100 | if isinstance(v, torch.Tensor): 101 | v = v.item() 102 | assert isinstance(v, (float, int)) 103 | self.meters[k].update(v) 104 | 105 | def __getattr__(self, attr): 106 | if attr in self.meters: 107 | return self.meters[attr] 108 | if attr in self.__dict__: 109 | return self.__dict__[attr] 110 | raise AttributeError("'{}' object has no attribute '{}'".format( 111 | type(self).__name__, attr)) 112 | 113 | def __str__(self): 114 | loss_str = [] 115 | for name, meter in self.meters.items(): 116 | loss_str.append( 117 | "{}: {}".format(name, str(meter)) 118 | ) 119 | return self.delimiter.join(loss_str) 120 | 121 | def synchronize_between_processes(self): 122 | for meter in self.meters.values(): 123 | meter.synchronize_between_processes() 124 | 125 | def add_meter(self, name, meter): 126 | self.meters[name] = meter 127 | 128 | def log_every(self, iterable, print_freq, header=None): 129 | i = 0 130 | if not header: 131 | header = '' 132 | start_time = time.time() 133 | end = time.time() 134 | iter_time = SmoothedValue(fmt='{avg:.4f}') 135 | data_time = SmoothedValue(fmt='{avg:.4f}') 136 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 137 | log_msg = [ 138 | header, 139 | '[{0' + space_fmt + '}/{1}]', 140 | 'eta: 
{eta}', 141 | '{meters}', 142 | 'time: {time}', 143 | 'data: {data}' 144 | ] 145 | if torch.cuda.is_available(): 146 | log_msg.append('max mem: {memory:.0f}') 147 | log_msg = self.delimiter.join(log_msg) 148 | MB = 1024.0 * 1024.0 149 | for obj in iterable: 150 | data_time.update(time.time() - end) 151 | yield obj 152 | iter_time.update(time.time() - end) 153 | if i % print_freq == 0 or i == len(iterable) - 1: 154 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 155 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 156 | if torch.cuda.is_available(): 157 | print(log_msg.format( 158 | i, len(iterable), eta=eta_string, 159 | meters=str(self), 160 | time=str(iter_time), data=str(data_time), 161 | memory=torch.cuda.max_memory_allocated() / MB)) 162 | else: 163 | print(log_msg.format( 164 | i, len(iterable), eta=eta_string, 165 | meters=str(self), 166 | time=str(iter_time), data=str(data_time))) 167 | i += 1 168 | end = time.time() 169 | total_time = time.time() - start_time 170 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 171 | print('{} Total time: {} ({:.4f} s / it)'.format( 172 | header, total_time_str, total_time / len(iterable))) 173 | 174 | 175 | def setup_for_distributed(is_master): 176 | """ 177 | This function disables printing when not in master process 178 | """ 179 | builtin_print = builtins.print 180 | 181 | def print(*args, **kwargs): 182 | force = kwargs.pop('force', False) 183 | force = force or (get_world_size() > 8) 184 | if is_master or force: 185 | now = datetime.datetime.now().time() 186 | builtin_print('[{}] '.format(now), end='') # print with time stamp 187 | builtin_print(*args, **kwargs) 188 | 189 | builtins.print = print 190 | 191 | 192 | def is_dist_avail_and_initialized(): 193 | if not dist.is_available(): 194 | return False 195 | if not dist.is_initialized(): 196 | return False 197 | return True 198 | 199 | 200 | def get_world_size(): 201 | if not is_dist_avail_and_initialized(): 202 | return 1 203 | return dist.get_world_size() 204 | 205 | 206 | def get_rank(): 207 | if not is_dist_avail_and_initialized(): 208 | return 0 209 | return dist.get_rank() 210 | 211 | 212 | def is_main_process(): 213 | return get_rank() == 0 214 | 215 | 216 | def save_on_master(*args, **kwargs): 217 | if is_main_process(): 218 | torch.save(*args, **kwargs) 219 | 220 | 221 | def init_distributed_mode(args): 222 | if args.dist_on_itp: 223 | args.rank = int(os.environ['OMPI_COMM_WORLD_RANK']) 224 | args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) 225 | args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) 226 | args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT']) 227 | os.environ['LOCAL_RANK'] = str(args.gpu) 228 | os.environ['RANK'] = str(args.rank) 229 | os.environ['WORLD_SIZE'] = str(args.world_size) 230 | # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"] 231 | elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 232 | args.rank = int(os.environ["RANK"]) 233 | args.world_size = int(os.environ['WORLD_SIZE']) 234 | args.gpu = int(os.environ['LOCAL_RANK']) 235 | elif 'SLURM_PROCID' in os.environ: 236 | args.rank = int(os.environ['SLURM_PROCID']) 237 | args.gpu = args.rank % torch.cuda.device_count() 238 | else: 239 | print('Not using distributed mode') 240 | setup_for_distributed(is_master=True) # hack 241 | args.distributed = False 242 | return 243 | 244 | args.distributed = True 245 | 246 | torch.cuda.set_device(args.gpu) 247 | args.dist_backend = 'nccl' 248 | print('| 
distributed init (rank {}): {}, gpu {}'.format( 249 | args.rank, args.dist_url, args.gpu), flush=True) 250 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 251 | world_size=args.world_size, rank=args.rank) 252 | torch.distributed.barrier() 253 | setup_for_distributed(args.rank == 0) 254 | 255 | 256 | class NativeScalerWithGradNormCount: 257 | state_dict_key = "amp_scaler" 258 | 259 | def __init__(self): 260 | self._scaler = torch.cuda.amp.GradScaler() 261 | 262 | def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True): 263 | self._scaler.scale(loss).backward(create_graph=create_graph) 264 | if update_grad: 265 | if clip_grad is not None: 266 | assert parameters is not None 267 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place 268 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) 269 | else: 270 | self._scaler.unscale_(optimizer) 271 | norm = get_grad_norm_(parameters) 272 | self._scaler.step(optimizer) 273 | self._scaler.update() 274 | else: 275 | norm = None 276 | return norm 277 | 278 | def state_dict(self): 279 | return self._scaler.state_dict() 280 | 281 | def load_state_dict(self, state_dict): 282 | self._scaler.load_state_dict(state_dict) 283 | 284 | 285 | def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: 286 | if isinstance(parameters, torch.Tensor): 287 | parameters = [parameters] 288 | parameters = [p for p in parameters if p.grad is not None] 289 | norm_type = float(norm_type) 290 | if len(parameters) == 0: 291 | return torch.tensor(0.) 292 | device = parameters[0].grad.device 293 | if norm_type == inf: 294 | total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) 295 | else: 296 | total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type) 297 | return total_norm 298 | 299 | 300 | def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler): 301 | output_dir = Path(args.output_dir) 302 | epoch_name = str(epoch) 303 | if loss_scaler is not None: 304 | checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)] 305 | for checkpoint_path in checkpoint_paths: 306 | to_save = { 307 | 'model': model_without_ddp.state_dict(), 308 | 'optimizer': optimizer.state_dict(), 309 | 'epoch': epoch, 310 | 'scaler': loss_scaler.state_dict(), 311 | 'args': args, 312 | } 313 | 314 | save_on_master(to_save, checkpoint_path) 315 | else: 316 | client_state = {'epoch': epoch} 317 | model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-%s" % epoch_name, client_state=client_state) 318 | 319 | 320 | def load_model(args, model_without_ddp, optimizer, loss_scaler): 321 | if args.resume: 322 | if args.resume.startswith('https'): 323 | checkpoint = torch.hub.load_state_dict_from_url( 324 | args.resume, map_location='cpu', check_hash=True) 325 | else: 326 | checkpoint = torch.load(args.resume, map_location='cpu') 327 | model_without_ddp.load_state_dict(checkpoint['model']) 328 | print("Resume checkpoint %s" % args.resume) 329 | if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval): 330 | optimizer.load_state_dict(checkpoint['optimizer']) 331 | args.start_epoch = checkpoint['epoch'] + 1 332 | if 'scaler' in checkpoint: 333 | loss_scaler.load_state_dict(checkpoint['scaler']) 334 | print("With optim & sched!") 335 | 336 | 337 | def all_reduce_mean(x): 338 | world_size = 
get_world_size() 339 | if world_size > 1: 340 | x_reduce = torch.tensor(x).cuda() 341 | dist.all_reduce(x_reduce) 342 | x_reduce /= world_size 343 | return x_reduce.item() 344 | else: 345 | return x -------------------------------------------------------------------------------- /swin/config.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code was originally obtained from: 3 | https://github.com/microsoft/Swin-Transformer 4 | """ 5 | 6 | # -------------------------------------------------------- 7 | # Swin Transformer 8 | # Copyright (c) 2021 Microsoft 9 | # Licensed under The MIT License [see LICENSE for details] 10 | # Written by Ze Liu 11 | # --------------------------------------------------------' 12 | 13 | import os 14 | import yaml 15 | from yacs.config import CfgNode as CN 16 | 17 | _C = CN() 18 | 19 | # Base config files 20 | _C.BASE = [''] 21 | 22 | # ----------------------------------------------------------------------------- 23 | # Data settings 24 | # ----------------------------------------------------------------------------- 25 | _C.DATA = CN() 26 | # Batch size for a single GPU, could be overwritten by command line argument 27 | _C.DATA.BATCH_SIZE = 128 28 | # Path to dataset, could be overwritten by command line argument 29 | _C.DATA.DATA_PATH = '' 30 | # Dataset name 31 | _C.DATA.DATASET = 'imagenet' 32 | # Input image size 33 | _C.DATA.IMG_SIZE = 224 34 | # Interpolation to resize image (random, bilinear, bicubic) 35 | _C.DATA.INTERPOLATION = 'bicubic' 36 | # Use zipped dataset instead of folder dataset 37 | # could be overwritten by command line argument 38 | _C.DATA.ZIP_MODE = False 39 | # Cache Data in Memory, could be overwritten by command line argument 40 | _C.DATA.CACHE_MODE = 'part' 41 | # Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU. 42 | _C.DATA.PIN_MEMORY = True 43 | # Number of data loading threads 44 | _C.DATA.NUM_WORKERS = 8 45 | 46 | # [SimMIM] Mask patch size for MaskGenerator 47 | _C.DATA.MASK_PATCH_SIZE = 32 48 | # [SimMIM] Mask ratio for MaskGenerator 49 | _C.DATA.MASK_RATIO = 0.6 50 | 51 | # ----------------------------------------------------------------------------- 52 | # Model settings 53 | # ----------------------------------------------------------------------------- 54 | _C.MODEL = CN() 55 | # Model type 56 | _C.MODEL.TYPE = 'swin' 57 | # Model name 58 | _C.MODEL.NAME = 'swin_tiny_patch4_window7_224' 59 | # Pretrained weight from checkpoint, could be imagenet22k pretrained weight 60 | # could be overwritten by command line argument 61 | _C.MODEL.PRETRAINED = '' 62 | # Checkpoint to resume, could be overwritten by command line argument 63 | _C.MODEL.RESUME = '' 64 | # Number of classes, overwritten in data preparation 65 | _C.MODEL.NUM_CLASSES = 1000 66 | # Dropout rate 67 | _C.MODEL.DROP_RATE = 0.0 68 | # Drop path rate 69 | _C.MODEL.DROP_PATH_RATE = 0.1 70 | # Label Smoothing 71 | _C.MODEL.LABEL_SMOOTHING = 0.1 72 | 73 | # Swin Transformer parameters 74 | _C.MODEL.SWIN = CN() 75 | _C.MODEL.SWIN.PATCH_SIZE = 4 76 | _C.MODEL.SWIN.IN_CHANS = 3 77 | _C.MODEL.SWIN.EMBED_DIM = 96 78 | _C.MODEL.SWIN.DEPTHS = [2, 2, 6, 2] 79 | _C.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24] 80 | _C.MODEL.SWIN.WINDOW_SIZE = 7 81 | _C.MODEL.SWIN.MLP_RATIO = 4. 
82 | _C.MODEL.SWIN.QKV_BIAS = True 83 | _C.MODEL.SWIN.QK_SCALE = None 84 | _C.MODEL.SWIN.APE = False 85 | _C.MODEL.SWIN.PATCH_NORM = True 86 | 87 | # Swin Transformer V2 parameters 88 | _C.MODEL.SWINV2 = CN() 89 | _C.MODEL.SWINV2.PATCH_SIZE = 4 90 | _C.MODEL.SWINV2.IN_CHANS = 3 91 | _C.MODEL.SWINV2.EMBED_DIM = 96 92 | _C.MODEL.SWINV2.DEPTHS = [2, 2, 6, 2] 93 | _C.MODEL.SWINV2.NUM_HEADS = [3, 6, 12, 24] 94 | _C.MODEL.SWINV2.WINDOW_SIZE = 7 95 | _C.MODEL.SWINV2.MLP_RATIO = 4. 96 | _C.MODEL.SWINV2.QKV_BIAS = True 97 | _C.MODEL.SWINV2.APE = False 98 | _C.MODEL.SWINV2.PATCH_NORM = True 99 | _C.MODEL.SWINV2.PRETRAINED_WINDOW_SIZES = [0, 0, 0, 0] 100 | 101 | # Swin Transformer MoE parameters 102 | _C.MODEL.SWIN_MOE = CN() 103 | _C.MODEL.SWIN_MOE.PATCH_SIZE = 4 104 | _C.MODEL.SWIN_MOE.IN_CHANS = 3 105 | _C.MODEL.SWIN_MOE.EMBED_DIM = 96 106 | _C.MODEL.SWIN_MOE.DEPTHS = [2, 2, 6, 2] 107 | _C.MODEL.SWIN_MOE.NUM_HEADS = [3, 6, 12, 24] 108 | _C.MODEL.SWIN_MOE.WINDOW_SIZE = 7 109 | _C.MODEL.SWIN_MOE.MLP_RATIO = 4. 110 | _C.MODEL.SWIN_MOE.QKV_BIAS = True 111 | _C.MODEL.SWIN_MOE.QK_SCALE = None 112 | _C.MODEL.SWIN_MOE.APE = False 113 | _C.MODEL.SWIN_MOE.PATCH_NORM = True 114 | _C.MODEL.SWIN_MOE.MLP_FC2_BIAS = True 115 | _C.MODEL.SWIN_MOE.INIT_STD = 0.02 116 | _C.MODEL.SWIN_MOE.PRETRAINED_WINDOW_SIZES = [0, 0, 0, 0] 117 | _C.MODEL.SWIN_MOE.MOE_BLOCKS = [[-1], [-1], [-1], [-1]] 118 | _C.MODEL.SWIN_MOE.NUM_LOCAL_EXPERTS = 1 119 | _C.MODEL.SWIN_MOE.TOP_VALUE = 1 120 | _C.MODEL.SWIN_MOE.CAPACITY_FACTOR = 1.25 121 | _C.MODEL.SWIN_MOE.COSINE_ROUTER = False 122 | _C.MODEL.SWIN_MOE.NORMALIZE_GATE = False 123 | _C.MODEL.SWIN_MOE.USE_BPR = True 124 | _C.MODEL.SWIN_MOE.IS_GSHARD_LOSS = False 125 | _C.MODEL.SWIN_MOE.GATE_NOISE = 1.0 126 | _C.MODEL.SWIN_MOE.COSINE_ROUTER_DIM = 256 127 | _C.MODEL.SWIN_MOE.COSINE_ROUTER_INIT_T = 0.5 128 | _C.MODEL.SWIN_MOE.MOE_DROP = 0.0 129 | _C.MODEL.SWIN_MOE.AUX_LOSS_WEIGHT = 0.01 130 | 131 | # Swin MLP parameters 132 | _C.MODEL.SWIN_MLP = CN() 133 | _C.MODEL.SWIN_MLP.PATCH_SIZE = 4 134 | _C.MODEL.SWIN_MLP.IN_CHANS = 3 135 | _C.MODEL.SWIN_MLP.EMBED_DIM = 96 136 | _C.MODEL.SWIN_MLP.DEPTHS = [2, 2, 6, 2] 137 | _C.MODEL.SWIN_MLP.NUM_HEADS = [3, 6, 12, 24] 138 | _C.MODEL.SWIN_MLP.WINDOW_SIZE = 7 139 | _C.MODEL.SWIN_MLP.MLP_RATIO = 4. 
140 | _C.MODEL.SWIN_MLP.APE = False 141 | _C.MODEL.SWIN_MLP.PATCH_NORM = True 142 | 143 | # [SimMIM] Norm target during training 144 | _C.MODEL.SIMMIM = CN() 145 | _C.MODEL.SIMMIM.NORM_TARGET = CN() 146 | _C.MODEL.SIMMIM.NORM_TARGET.ENABLE = False 147 | _C.MODEL.SIMMIM.NORM_TARGET.PATCH_SIZE = 47 148 | 149 | # ----------------------------------------------------------------------------- 150 | # Training settings 151 | # ----------------------------------------------------------------------------- 152 | _C.TRAIN = CN() 153 | _C.TRAIN.START_EPOCH = 0 154 | _C.TRAIN.EPOCHS = 300 155 | _C.TRAIN.WARMUP_EPOCHS = 20 156 | _C.TRAIN.WEIGHT_DECAY = 0.05 157 | _C.TRAIN.BASE_LR = 5e-4 158 | _C.TRAIN.WARMUP_LR = 5e-7 159 | _C.TRAIN.MIN_LR = 5e-6 160 | # Clip gradient norm 161 | _C.TRAIN.CLIP_GRAD = 5.0 162 | # Auto resume from latest checkpoint 163 | _C.TRAIN.AUTO_RESUME = True 164 | # Gradient accumulation steps 165 | # could be overwritten by command line argument 166 | _C.TRAIN.ACCUMULATION_STEPS = 1 167 | # Whether to use gradient checkpointing to save memory 168 | # could be overwritten by command line argument 169 | _C.TRAIN.USE_CHECKPOINT = False 170 | 171 | # LR scheduler 172 | _C.TRAIN.LR_SCHEDULER = CN() 173 | _C.TRAIN.LR_SCHEDULER.NAME = 'cosine' 174 | # Epoch interval to decay LR, used in StepLRScheduler 175 | _C.TRAIN.LR_SCHEDULER.DECAY_EPOCHS = 30 176 | # LR decay rate, used in StepLRScheduler 177 | _C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1 178 | # warmup_prefix used in CosineLRScheduler 179 | _C.TRAIN.LR_SCHEDULER.WARMUP_PREFIX = False 180 | # [SimMIM] Gamma / Multi steps value, used in MultiStepLRScheduler 181 | _C.TRAIN.LR_SCHEDULER.GAMMA = 0.1 182 | _C.TRAIN.LR_SCHEDULER.MULTISTEPS = [] 183 | 184 | # Optimizer 185 | _C.TRAIN.OPTIMIZER = CN() 186 | _C.TRAIN.OPTIMIZER.NAME = 'adamw' 187 | # Optimizer Epsilon 188 | _C.TRAIN.OPTIMIZER.EPS = 1e-8 189 | # Optimizer Betas 190 | _C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999) 191 | # SGD momentum 192 | _C.TRAIN.OPTIMIZER.MOMENTUM = 0.9 193 | 194 | # [SimMIM] Layer decay for fine-tuning 195 | _C.TRAIN.LAYER_DECAY = 1.0 196 | 197 | # MoE 198 | _C.TRAIN.MOE = CN() 199 | # Only save model on master device 200 | _C.TRAIN.MOE.SAVE_MASTER = False 201 | # ----------------------------------------------------------------------------- 202 | # Augmentation settings 203 | # ----------------------------------------------------------------------------- 204 | _C.AUG = CN() 205 | # Color jitter factor 206 | _C.AUG.COLOR_JITTER = 0.4 207 | # Use AutoAugment policy. "v0" or "original" 208 | _C.AUG.AUTO_AUGMENT = 'rand-m9-mstd0.5-inc1' 209 | # Random erase prob 210 | _C.AUG.REPROB = 0.25 211 | # Random erase mode 212 | _C.AUG.REMODE = 'pixel' 213 | # Random erase count 214 | _C.AUG.RECOUNT = 1 215 | # Mixup alpha, mixup enabled if > 0 216 | _C.AUG.MIXUP = 0.8 217 | # Cutmix alpha, cutmix enabled if > 0 218 | _C.AUG.CUTMIX = 1.0 219 | # Cutmix min/max ratio, overrides alpha and enables cutmix if set 220 | _C.AUG.CUTMIX_MINMAX = None 221 | # Probability of performing mixup or cutmix when either/both is enabled 222 | _C.AUG.MIXUP_PROB = 1.0 223 | # Probability of switching to cutmix when both mixup and cutmix enabled 224 | _C.AUG.MIXUP_SWITCH_PROB = 0.5 225 | # How to apply mixup/cutmix params. 
Per "batch", "pair", or "elem" 226 | _C.AUG.MIXUP_MODE = 'batch' 227 | # AugSub method 228 | _C.AUG.AUGSUB = 'none' 229 | # AugSub drop-ratio 230 | _C.AUG.AUGSUB_RATIO = 0.5 231 | 232 | # ----------------------------------------------------------------------------- 233 | # Testing settings 234 | # ----------------------------------------------------------------------------- 235 | _C.TEST = CN() 236 | # Whether to use center crop when testing 237 | _C.TEST.CROP = True 238 | # Whether to use SequentialSampler as validation sampler 239 | _C.TEST.SEQUENTIAL = False 240 | _C.TEST.SHUFFLE = False 241 | 242 | # ----------------------------------------------------------------------------- 243 | # Misc 244 | # ----------------------------------------------------------------------------- 245 | # [SimMIM] Whether to enable PyTorch amp, overwritten by command line argument 246 | _C.ENABLE_AMP = False 247 | 248 | # Enable PyTorch automatic mixed precision (amp). 249 | _C.AMP_ENABLE = True 250 | # [Deprecated] Mixed precision opt level of apex; if 'O0', no apex amp is used ('O0', 'O1', 'O2') 251 | _C.AMP_OPT_LEVEL = '' 252 | # Path to output folder, overwritten by command line argument 253 | _C.OUTPUT = '' 254 | # Tag of experiment, overwritten by command line argument 255 | _C.TAG = 'default' 256 | # Frequency to save checkpoint 257 | _C.SAVE_FREQ = 1 258 | # Frequency to log info 259 | _C.PRINT_FREQ = 10 260 | # Fixed random seed 261 | _C.SEED = 0 262 | # Perform evaluation only, overwritten by command line argument 263 | _C.EVAL_MODE = False 264 | # Test throughput only, overwritten by command line argument 265 | _C.THROUGHPUT_MODE = False 266 | # local rank for DistributedDataParallel, given by command line argument 267 | _C.LOCAL_RANK = 0 268 | # for acceleration 269 | _C.FUSED_WINDOW_PROCESS = False 270 | _C.FUSED_LAYERNORM = False 271 | 272 | 273 | def _update_config_from_file(config, cfg_file): 274 | config.defrost() 275 | with open(cfg_file, 'r') as f: 276 | yaml_cfg = yaml.load(f, Loader=yaml.FullLoader) 277 | 278 | for cfg in yaml_cfg.setdefault('BASE', ['']): 279 | if cfg: 280 | _update_config_from_file( 281 | config, os.path.join(os.path.dirname(cfg_file), cfg) 282 | ) 283 | print('=> merge config from {}'.format(cfg_file)) 284 | config.merge_from_file(cfg_file) 285 | config.freeze() 286 | 287 | 288 | def update_config(config, args): 289 | _update_config_from_file(config, args.cfg) 290 | 291 | config.defrost() 292 | if args.opts: 293 | config.merge_from_list(args.opts) 294 | 295 | def _check_args(name): 296 | if hasattr(args, name) and eval(f'args.{name}'): 297 | return True 298 | return False 299 | 300 | # merge from specific arguments 301 | if _check_args('batch_size'): 302 | config.DATA.BATCH_SIZE = args.batch_size 303 | if _check_args('data_path'): 304 | config.DATA.DATA_PATH = args.data_path 305 | if _check_args('zip'): 306 | config.DATA.ZIP_MODE = True 307 | if _check_args('cache_mode'): 308 | config.DATA.CACHE_MODE = args.cache_mode 309 | if _check_args('pretrained'): 310 | config.MODEL.PRETRAINED = args.pretrained 311 | if _check_args('resume'): 312 | config.MODEL.RESUME = args.resume 313 | if _check_args('accumulation_steps'): 314 | config.TRAIN.ACCUMULATION_STEPS = args.accumulation_steps 315 | if _check_args('use_checkpoint'): 316 | config.TRAIN.USE_CHECKPOINT = True 317 | if _check_args('amp_opt_level'): 318 | print("[warning] Apex amp has been deprecated, please use pytorch amp instead!") 319 | if args.amp_opt_level == 'O0': 320 | config.AMP_ENABLE = False 321 | if _check_args('disable_amp'): 322 | config.AMP_ENABLE = False 323 | if _check_args('output'): 324 | config.OUTPUT = args.output 325 | if _check_args('tag'): 326 | config.TAG = args.tag 327 | if _check_args('eval'): 328 | config.EVAL_MODE = True 329 | if _check_args('throughput'): 330 | config.THROUGHPUT_MODE = True 331 | 332 | # [SimMIM] 333 | if _check_args('enable_amp'): 334 | config.ENABLE_AMP = args.enable_amp 335 | 336 | # for acceleration 337 | if _check_args('fused_window_process'): 338 | config.FUSED_WINDOW_PROCESS = True 339 | if _check_args('fused_layernorm'): 340 | config.FUSED_LAYERNORM = True 341 | # Overwrite optimizer if given; currently used for [fused_adam, fused_lamb] 342 | if _check_args('optim'): 343 | config.TRAIN.OPTIMIZER.NAME = args.optim 344 | 345 | # AugSub 346 | if _check_args('augsub'): 347 | config.AUG.AUGSUB = args.augsub 348 | if _check_args('augsub_ratio'): 349 | config.AUG.AUGSUB_RATIO = args.augsub_ratio 350 | 351 | # set local rank for distributed training 352 | config.LOCAL_RANK = args.local_rank 353 | 354 | # output folder 355 | config.OUTPUT = os.path.join(config.OUTPUT, config.MODEL.NAME, config.TAG) 356 | 357 | config.freeze() 358 | 359 | 360 | def get_config(args): 361 | """Get a yacs CfgNode object with default values.""" 362 | # Return a clone so that the defaults will not be altered 363 | # This is for the "local variable" use pattern 364 | config = _C.clone() 365 | update_config(config, args) 366 | 367 | return config 368 | --------------------------------------------------------------------------------
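
A few usage sketches for the modules collected above follow; every name not defined in this dump (loaders, paths, CLI flags) is an illustrative assumption. First, the synchronize_between_processes helper near the top of this section all-reduces a timm AverageMeter across ranks. Since dist.all_reduce defaults to SUM, the reduced sum and count become global totals, while val and avg hold sums of the per-rank values. A minimal sketch, assuming the helper is importable from the utils module that defines it:

from timm.utils import AverageMeter
from utils import synchronize_between_processes  # import path assumed; adjust to the utils module above

loss_meter = AverageMeter()
for loss in (0.9, 0.7, 0.5):       # stand-in per-batch losses
    loss_meter.update(loss, n=32)  # n = number of samples in the batch
synchronize_between_processes(loss_meter)  # in-place all-reduce when distributed, no-op otherwise
print(f'global sum={loss_meter.sum:.1f}, global count={loss_meter.count}')
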
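swin/data/map22kto1k.txt carries 1000 lines, one per ImageNet-1K class, each giving that class's index within the 21841-class ImageNet-22K label space. The code that consumes it is not part of this section, so the following is only a sketch of the obvious use, selecting the 1K-aligned entries out of a 22K prediction:

import torch

# Hypothetical consumer sketch: line i (0-based) holds the ImageNet-22K
# class index corresponding to ImageNet-1K class i.
with open('swin/data/map22kto1k.txt') as f:
    map22kto1k = [int(line) for line in f]
assert len(map22kto1k) == 1000

logits_22k = torch.randn(4, 21841)     # stand-in for a 22K model's output
logits_1k = logits_22k[:, map22kto1k]  # keep only the 1K-aligned logits
pred_1k = logits_1k.argmax(dim=1)
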
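mae/util/misc.py imports inf from torch._six, a private module that newer PyTorch releases have removed. If that import fails in your environment, math.inf is a drop-in replacement for the norm_type == inf comparison in get_grad_norm_; a hedged compatibility shim:

try:
    from torch._six import inf  # older PyTorch, as pinned by the original MAE code
except ImportError:
    from math import inf        # math.inf == float('inf'), so the comparison behaves identically
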
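MetricLogger keeps a dict of SmoothedValue meters and, through log_every, wraps any sized iterable in a progress-printing generator. A self-contained sketch with stand-in values (run from the mae/ directory, since util is not an installed package):

from util.misc import MetricLogger, SmoothedValue

metric_logger = MetricLogger(delimiter="  ")
metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))

for step in metric_logger.log_every(range(100), print_freq=20, header='Epoch: [0]'):
    metric_logger.update(loss=1.0 / (step + 1), lr=1e-3)  # stand-in loss/lr values

metric_logger.synchronize_between_processes()  # no-op unless torch.distributed is initialized
print('Averaged stats:', metric_logger)
print('global avg loss:', metric_logger.loss.global_avg)
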
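NativeScalerWithGradNormCount (defined both in the utils file at the top of this section and in mae/util/misc.py) folds backward, optional unscale-and-clip, the optimizer step, and the scaler update into a single call, returning the gradient norm. A minimal AMP training step with gradient accumulation, assuming a CUDA device and stand-in data:

import torch
from util.misc import NativeScalerWithGradNormCount  # run from mae/

model = torch.nn.Linear(16, 10).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
loss_scaler = NativeScalerWithGradNormCount()
accum_iter = 2  # assumed gradient-accumulation factor

batches = [(torch.randn(8, 16), torch.randint(0, 10, (8,))) for _ in range(4)]
for step, (x, y) in enumerate(batches):
    with torch.cuda.amp.autocast():
        loss = torch.nn.functional.cross_entropy(model(x.cuda()), y.cuda())
    update = (step + 1) % accum_iter == 0
    # update_grad=False only accumulates scaled gradients; on the last
    # micro-batch the scaler unscales, clips to 1.0, steps, and reports the norm.
    grad_norm = loss_scaler(loss / accum_iter, optimizer, clip_grad=1.0,
                            parameters=model.parameters(), update_grad=update)
    if update:
        optimizer.zero_grad()
        print(f'step {step}: grad_norm={grad_norm:.3f}')
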
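save_model and load_model round-trip model, optimizer, epoch, and scaler state through checkpoint-<epoch>.pth under args.output_dir, and load_model bumps args.start_epoch past the saved epoch. A sketch with types.SimpleNamespace standing in for the real argparse namespace (only the fields the two functions read are set). One caveat: the upstream load_model calls torch.load without weights_only=False, which the stricter default introduced in PyTorch 2.6 rejects for checkpoints like these that embed args:

import os
import types
import torch
from util import misc  # run from mae/

model = torch.nn.Linear(16, 10)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
loss_scaler = misc.NativeScalerWithGradNormCount()

args = types.SimpleNamespace(output_dir='out', resume='out/checkpoint-0.pth',
                             start_epoch=0, eval=False)
os.makedirs(args.output_dir, exist_ok=True)

misc.save_model(args, epoch=0, model=model, model_without_ddp=model,
                optimizer=optimizer, loss_scaler=loss_scaler)
misc.load_model(args, model_without_ddp=model, optimizer=optimizer,
                loss_scaler=loss_scaler)
print(args.start_epoch)  # -> 1: resume from the epoch after the checkpoint
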
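Finally, get_config in swin/config.py clones the _C defaults, merges the yaml file (recursively following its BASE list), then any --opts key/value pairs, then the specific flags probed via _check_args, and finally rewrites OUTPUT to <output>/<MODEL.NAME>/<TAG> before freezing. The training script's CLI is not shown in this section, so the parser below is only a plausible reconstruction of the attribute names update_config reads (opts and local_rank are accessed unconditionally and must exist):

import argparse
from config import get_config  # run from the swin/ directory

parser = argparse.ArgumentParser()
parser.add_argument('--cfg', default='configs/swin/swin_tiny_patch4_window7_224.yaml')
parser.add_argument('--opts', nargs='+', default=None)  # e.g. TRAIN.BASE_LR 1e-3
parser.add_argument('--batch-size', dest='batch_size', type=int)
parser.add_argument('--data-path', dest='data_path')
parser.add_argument('--output', default='output')
parser.add_argument('--tag')
parser.add_argument('--augsub')  # AugSub method name; valid values live in the training code
parser.add_argument('--augsub-ratio', dest='augsub_ratio', type=float)
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()

config = get_config(args)  # returns a frozen yacs CfgNode
print(config.MODEL.NAME, config.AUG.AUGSUB, config.OUTPUT)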