├── data ├── __init__.py ├── samplers.py ├── zipreader.py ├── build.py └── cached_image_folder.py ├── models ├── __init__.py ├── build.py ├── inline_pvt.py ├── inline_cswin.py ├── inline_deit.py └── inline_swin.py ├── figures ├── fig2_cls.png ├── fig3_speed.png └── fig1_injectivity.png ├── cfgs ├── inline_deit_t.yaml ├── inline_deit_b.yaml ├── inline_deit_s.yaml ├── inline_swin_t.yaml ├── inline_swin_s.yaml ├── inline_swin_b.yaml ├── inline_pvt_b.yaml ├── inline_pvt_t.yaml ├── inline_pvt_m.yaml ├── inline_pvt_s.yaml ├── inline_cswin_b.yaml ├── inline_cswin_s.yaml ├── inline_cswin_t.yaml ├── inline_cswin_b_384.yaml └── inline_swin_b_384.yaml ├── logger.py ├── optimizer.py ├── lr_scheduler.py ├── README.md ├── config.py ├── utils.py ├── utils_ema.py ├── main.py └── main_ema.py /data/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import build_loader -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import build_model -------------------------------------------------------------------------------- /figures/fig2_cls.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LeapLabTHU/InLine/HEAD/figures/fig2_cls.png -------------------------------------------------------------------------------- /figures/fig3_speed.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LeapLabTHU/InLine/HEAD/figures/fig3_speed.png -------------------------------------------------------------------------------- /figures/fig1_injectivity.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LeapLabTHU/InLine/HEAD/figures/fig1_injectivity.png -------------------------------------------------------------------------------- /cfgs/inline_deit_t.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: inline_deit_tiny 3 | NAME: inline_deit_tiny 4 | DATA: 5 | BATCH_SIZE: 512 -------------------------------------------------------------------------------- /cfgs/inline_deit_b.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: inline_deit_base 3 | NAME: inline_deit_base 4 | DATA: 5 | IMG_SIZE: 448 6 | BATCH_SIZE: 64 -------------------------------------------------------------------------------- /cfgs/inline_deit_s.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: inline_deit_small 3 | NAME: inline_deit_small 4 | DATA: 5 | IMG_SIZE: 288 6 | BATCH_SIZE: 128 -------------------------------------------------------------------------------- /cfgs/inline_swin_t.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: inline_swin 3 | NAME: inline_swin_tiny 4 | DROP_PATH_RATE: 0.2 5 | SWIN: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 6, 2 ] 8 | NUM_HEADS: [ 3, 6, 12, 24 ] 9 | WINDOW_SIZE: 56 10 | INLINE: 11 | ATTN_TYPE: IIIS -------------------------------------------------------------------------------- /cfgs/inline_swin_s.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: inline_swin 3 | NAME: inline_swin_small 4 | DROP_PATH_RATE: 0.3 5 | SWIN: 6 | EMBED_DIM: 96 7 | 
DEPTHS: [ 2, 2, 18, 2 ] 8 | NUM_HEADS: [ 3, 6, 12, 24 ] 9 | WINDOW_SIZE: 56 10 | INLINE: 11 | ATTN_TYPE: IISS -------------------------------------------------------------------------------- /cfgs/inline_swin_b.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: inline_swin 3 | NAME: inline_swin_base 4 | DROP_PATH_RATE: 0.5 5 | SWIN: 6 | EMBED_DIM: 128 7 | DEPTHS: [ 2, 2, 18, 2 ] 8 | NUM_HEADS: [ 4, 8, 16, 32 ] 9 | WINDOW_SIZE: 56 10 | INLINE: 11 | ATTN_TYPE: IIMS2 12 | DATA: 13 | BATCH_SIZE: 64 -------------------------------------------------------------------------------- /cfgs/inline_pvt_b.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | IMG_SIZE: 224 3 | BATCH_SIZE: 64 4 | 5 | TRAIN: 6 | WEIGHT_DECAY: 0.05 7 | EPOCHS: 300 8 | WARMUP_EPOCHS: 5 9 | COOLDOWN_EPOCHS: 10 10 | BASE_LR: 5e-4 11 | WARMUP_LR: 1e-6 12 | MIN_LR: 1e-5 13 | CLIP_GRAD: 1.0 14 | 15 | MODEL: 16 | TYPE: inline_pvt_large 17 | NAME: inline_pvt_large 18 | DROP_PATH_RATE: 0.3 19 | -------------------------------------------------------------------------------- /cfgs/inline_pvt_t.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | IMG_SIZE: 224 3 | BATCH_SIZE: 128 4 | 5 | TRAIN: 6 | WEIGHT_DECAY: 0.05 7 | EPOCHS: 300 8 | WARMUP_EPOCHS: 5 9 | COOLDOWN_EPOCHS: 10 10 | BASE_LR: 5e-4 11 | WARMUP_LR: 1e-6 12 | MIN_LR: 1e-5 13 | CLIP_GRAD: None 14 | 15 | MODEL: 16 | TYPE: inline_pvt_tiny 17 | NAME: inline_pvt_tiny 18 | DROP_PATH_RATE: 0.1 19 | -------------------------------------------------------------------------------- /cfgs/inline_pvt_m.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | IMG_SIZE: 224 3 | BATCH_SIZE: 128 4 | 5 | TRAIN: 6 | WEIGHT_DECAY: 0.05 7 | EPOCHS: 300 8 | WARMUP_EPOCHS: 5 9 | COOLDOWN_EPOCHS: 10 10 | BASE_LR: 5e-4 11 | WARMUP_LR: 1e-6 12 | MIN_LR: 1e-5 13 | CLIP_GRAD: 1.0 14 | 15 | MODEL: 16 | TYPE: inline_pvt_medium 17 | NAME: inline_pvt_medium 18 | DROP_PATH_RATE: 0.3 19 | -------------------------------------------------------------------------------- /cfgs/inline_pvt_s.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | IMG_SIZE: 224 3 | BATCH_SIZE: 128 4 | 5 | TRAIN: 6 | WEIGHT_DECAY: 0.05 7 | EPOCHS: 300 8 | WARMUP_EPOCHS: 5 9 | COOLDOWN_EPOCHS: 10 10 | BASE_LR: 5e-4 11 | WARMUP_LR: 1e-6 12 | MIN_LR: 1e-5 13 | CLIP_GRAD: None 14 | 15 | MODEL: 16 | TYPE: inline_pvt_small 17 | NAME: inline_pvt_small 18 | DROP_PATH_RATE: 0.1 19 | -------------------------------------------------------------------------------- /cfgs/inline_cswin_b.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | IMG_SIZE: 224 3 | BATCH_SIZE: 64 4 | 5 | TRAIN: 6 | WEIGHT_DECAY: 0.1 7 | EPOCHS: 300 8 | WARMUP_EPOCHS: 20 9 | COOLDOWN_EPOCHS: 10 10 | BASE_LR: 5e-4 11 | WARMUP_LR: 1e-6 12 | MIN_LR: 1e-5 13 | CLIP_GRAD: None 14 | 15 | MODEL: 16 | TYPE: inline_cswin_base 17 | NAME: inline_cswin_base 18 | DROP_PATH_RATE: 0.5 19 | INLINE: 20 | ATTN_TYPE: IISS 21 | -------------------------------------------------------------------------------- /cfgs/inline_cswin_s.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | IMG_SIZE: 224 3 | BATCH_SIZE: 128 4 | 5 | TRAIN: 6 | WEIGHT_DECAY: 0.05 7 | EPOCHS: 300 8 | WARMUP_EPOCHS: 20 9 | COOLDOWN_EPOCHS: 10 10 | BASE_LR: 5e-4 11 | WARMUP_LR: 1e-6 12 | 
MIN_LR: 1e-5 13 | CLIP_GRAD: None 14 | 15 | MODEL: 16 | TYPE: inline_cswin_small 17 | NAME: inline_cswin_small 18 | DROP_PATH_RATE: 0.4 19 | INLINE: 20 | ATTN_TYPE: IISS -------------------------------------------------------------------------------- /cfgs/inline_cswin_t.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | IMG_SIZE: 224 3 | BATCH_SIZE: 128 4 | 5 | TRAIN: 6 | WEIGHT_DECAY: 0.05 7 | EPOCHS: 300 8 | WARMUP_EPOCHS: 20 9 | COOLDOWN_EPOCHS: 10 10 | BASE_LR: 5e-4 11 | WARMUP_LR: 1e-6 12 | MIN_LR: 1e-5 13 | CLIP_GRAD: None 14 | 15 | MODEL: 16 | TYPE: inline_cswin_tiny 17 | NAME: inline_cswin_tiny 18 | DROP_PATH_RATE: 0.2 19 | INLINE: 20 | ATTN_TYPE: IISS 21 | -------------------------------------------------------------------------------- /cfgs/inline_cswin_b_384.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | IMG_SIZE: 384 3 | BATCH_SIZE: 32 4 | 5 | TRAIN: 6 | WEIGHT_DECAY: 1e-8 7 | EPOCHS: 20 8 | WARMUP_EPOCHS: 0 9 | COOLDOWN_EPOCHS: 10 10 | BASE_LR: 5e-6 11 | WARMUP_LR: 5e-6 12 | MIN_LR: 5e-7 13 | CLIP_GRAD: None 14 | 15 | MODEL: 16 | TYPE: inline_cswin_base_384 17 | NAME: inline_cswin_base_384 18 | DROP_PATH_RATE: 0.7 19 | INLINE: 20 | ATTN_TYPE: IISS 21 | CSWIN_LA_SPLIT_SIZE: 96-48-24-12 22 | -------------------------------------------------------------------------------- /cfgs/inline_swin_b_384.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | IMG_SIZE: 384 3 | BATCH_SIZE: 32 4 | MODEL: 5 | TYPE: inline_swin 6 | NAME: inline_swin_base_384 7 | DROP_PATH_RATE: 0.5 8 | SWIN: 9 | EMBED_DIM: 128 10 | DEPTHS: [ 2, 2, 18, 2 ] 11 | NUM_HEADS: [ 4, 8, 16, 32 ] 12 | WINDOW_SIZE: 96 13 | INLINE: 14 | ATTN_TYPE: IIMS2 15 | TRAIN: 16 | EPOCHS: 30 17 | WARMUP_EPOCHS: 5 18 | WEIGHT_DECAY: 1e-8 19 | BASE_LR: 2e-05 20 | WARMUP_LR: 2e-08 21 | MIN_LR: 2e-07 22 | TEST: 23 | CROP: False -------------------------------------------------------------------------------- /data/samplers.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | import torch 9 | 10 | 11 | class SubsetRandomSampler(torch.utils.data.Sampler): 12 | r"""Samples elements randomly from a given list of indices, without replacement. 
13 | 14 | Arguments: 15 | indices (sequence): a sequence of indices 16 | """ 17 | 18 | def __init__(self, indices): 19 | self.epoch = 0 20 | self.indices = indices 21 | 22 | def __iter__(self): 23 | return (self.indices[i] for i in torch.randperm(len(self.indices))) 24 | 25 | def __len__(self): 26 | return len(self.indices) 27 | 28 | def set_epoch(self, epoch): 29 | self.epoch = epoch 30 | -------------------------------------------------------------------------------- /logger.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | import os 9 | import sys 10 | import logging 11 | import functools 12 | from termcolor import colored 13 | 14 | 15 | @functools.lru_cache() 16 | def create_logger(output_dir, dist_rank=0, name=''): 17 | # create logger 18 | logger = logging.getLogger(name) 19 | logger.setLevel(logging.DEBUG) 20 | logger.propagate = False 21 | 22 | # create formatter 23 | fmt = '[%(asctime)s %(name)s] (%(filename)s %(lineno)d): %(levelname)s %(message)s' 24 | color_fmt = colored('[%(asctime)s %(name)s]', 'green') + \ 25 | colored('(%(filename)s %(lineno)d)', 'yellow') + ': %(levelname)s %(message)s' 26 | 27 | # create console handlers for master process 28 | if dist_rank == 0: 29 | console_handler = logging.StreamHandler(sys.stdout) 30 | console_handler.setLevel(logging.DEBUG) 31 | console_handler.setFormatter( 32 | logging.Formatter(fmt=color_fmt, datefmt='%Y-%m-%d %H:%M:%S')) 33 | logger.addHandler(console_handler) 34 | file_handler = logging.FileHandler(os.path.join(output_dir, f'log_rank{dist_rank}.txt'), mode='a') 35 | file_handler.setLevel(logging.DEBUG) 36 | file_handler.setFormatter(logging.Formatter(fmt=fmt, datefmt='%Y-%m-%d %H:%M:%S')) 37 | logger.addHandler(file_handler) 38 | 39 | return logger 40 | -------------------------------------------------------------------------------- /optimizer.py: -------------------------------------------------------------------------------- 1 | from warnings import resetwarnings 2 | import torch.optim as optim 3 | 4 | def build_optimizer(config, model): 5 | """ 6 | Build optimizer, set weight decay of normalization to 0 by default. 
7 | """ 8 | skip = {} 9 | skip_keywords = {} 10 | if hasattr(model, 'no_weight_decay'): 11 | skip = model.no_weight_decay() 12 | if hasattr(model, 'no_weight_decay_keywords'): 13 | skip_keywords = model.no_weight_decay_keywords() 14 | 15 | if hasattr(model, 'lower_lr_kvs'): 16 | lower_lr_kvs = model.lower_lr_kvs 17 | else: 18 | lower_lr_kvs = {} 19 | 20 | parameters = set_weight_decay_and_lr( 21 | model, skip, skip_keywords, lower_lr_kvs, config.TRAIN.BASE_LR) 22 | 23 | opt_lower = config.TRAIN.OPTIMIZER.NAME.lower() 24 | optimizer = None 25 | if opt_lower == 'sgd': 26 | optimizer = optim.SGD(parameters, momentum=config.TRAIN.OPTIMIZER.MOMENTUM, nesterov=True, 27 | lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY) 28 | elif opt_lower == 'adamw': 29 | optimizer = optim.AdamW(parameters, eps=config.TRAIN.OPTIMIZER.EPS, betas=config.TRAIN.OPTIMIZER.BETAS, 30 | lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY) 31 | 32 | return optimizer 33 | 34 | 35 | def set_weight_decay_and_lr( 36 | model, 37 | skip_list=(), skip_keywords=(), 38 | lower_lr_kvs={}, base_lr=5e-4): 39 | # breakpoint() 40 | assert len(lower_lr_kvs) == 1 or len(lower_lr_kvs) == 0 41 | has_lower_lr = len(lower_lr_kvs) == 1 42 | if has_lower_lr: 43 | for k,v in lower_lr_kvs.items(): 44 | lower_lr_key = k 45 | lower_lr = v * base_lr 46 | 47 | has_decay = [] 48 | has_decay_low = [] 49 | no_decay = [] 50 | no_decay_low = [] 51 | 52 | for name, param in model.named_parameters(): 53 | if not param.requires_grad: 54 | continue # frozen weights 55 | if len(param.shape) == 1 or name.endswith(".bias") or (name in skip_list) or \ 56 | check_keywords_in_name(name, skip_keywords): 57 | 58 | if has_lower_lr and check_keywords_in_name(name, (lower_lr_key,)): 59 | no_decay_low.append(param) 60 | else: 61 | no_decay.append(param) 62 | 63 | else: 64 | 65 | if has_lower_lr and check_keywords_in_name(name, (lower_lr_key,)): 66 | has_decay_low.append(param) 67 | else: 68 | has_decay.append(param) 69 | 70 | if has_lower_lr: 71 | result = [ 72 | {'params': has_decay}, 73 | {'params': has_decay_low, 'lr': lower_lr}, 74 | {'params': no_decay, 'weight_decay': 0.}, 75 | {'params': no_decay_low, 'weight_decay': 0., 'lr': lower_lr} 76 | ] 77 | else: 78 | result = [ 79 | {'params': has_decay}, 80 | {'params': no_decay, 'weight_decay': 0.} 81 | ] 82 | # breakpoint() 83 | return result 84 | 85 | 86 | def check_keywords_in_name(name, keywords=()): 87 | isin = False 88 | for keyword in keywords: 89 | if keyword in name: 90 | isin = True 91 | return isin 92 | -------------------------------------------------------------------------------- /models/build.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | from .inline_swin import InLineSwin 9 | from .inline_deit import inline_deit_tiny, inline_deit_small, inline_deit_base 10 | from .inline_pvt import inline_pvt_tiny, inline_pvt_small, inline_pvt_medium, inline_pvt_large 11 | from .inline_cswin import inline_cswin_tiny, inline_cswin_small, inline_cswin_base, inline_cswin_base_384 12 | 13 | 14 | def build_model(config): 15 | model_type = config.MODEL.TYPE 16 | if model_type == 'inline_swin': 17 | model = InLineSwin(img_size=config.DATA.IMG_SIZE, 18 | 
patch_size=config.MODEL.SWIN.PATCH_SIZE, 19 | in_chans=config.MODEL.SWIN.IN_CHANS, 20 | num_classes=config.MODEL.NUM_CLASSES, 21 | embed_dim=config.MODEL.SWIN.EMBED_DIM, 22 | depths=config.MODEL.SWIN.DEPTHS, 23 | num_heads=config.MODEL.SWIN.NUM_HEADS, 24 | window_size=config.MODEL.SWIN.WINDOW_SIZE, 25 | mlp_ratio=config.MODEL.SWIN.MLP_RATIO, 26 | qkv_bias=config.MODEL.SWIN.QKV_BIAS, 27 | qk_scale=config.MODEL.SWIN.QK_SCALE, 28 | drop_rate=config.MODEL.DROP_RATE, 29 | drop_path_rate=config.MODEL.DROP_PATH_RATE, 30 | ape=config.MODEL.SWIN.APE, 31 | patch_norm=config.MODEL.SWIN.PATCH_NORM, 32 | use_checkpoint=config.TRAIN.USE_CHECKPOINT, 33 | attn_type=config.MODEL.INLINE.ATTN_TYPE) 34 | 35 | elif model_type in ['inline_deit_tiny', 'inline_deit_small', 'inline_deit_base']: 36 | model = eval(model_type + '(img_size=config.DATA.IMG_SIZE,' 37 | 'drop_path_rate=config.MODEL.DROP_PATH_RATE)') 38 | 39 | elif model_type in ['inline_pvt_tiny', 'inline_pvt_small', 'inline_pvt_medium', 'inline_pvt_large']: 40 | model = eval(model_type + '(img_size=config.DATA.IMG_SIZE,' 41 | 'drop_path_rate=config.MODEL.DROP_PATH_RATE,' 42 | 'attn_type=config.MODEL.INLINE.ATTN_TYPE,' 43 | 'la_sr_ratios=str(config.MODEL.INLINE.PVT_LA_SR_RATIOS))') 44 | 45 | elif model_type in ['inline_cswin_tiny', 'inline_cswin_small', 'inline_cswin_base', 'inline_cswin_base_384']: 46 | model = eval(model_type + '(img_size=config.DATA.IMG_SIZE,' 47 | 'in_chans=config.MODEL.SWIN.IN_CHANS,' 48 | 'num_classes=config.MODEL.NUM_CLASSES,' 49 | 'drop_rate=config.MODEL.DROP_RATE,' 50 | 'drop_path_rate=config.MODEL.DROP_PATH_RATE,' 51 | 'attn_type=config.MODEL.INLINE.ATTN_TYPE,' 52 | 'la_split_size=config.MODEL.INLINE.CSWIN_LA_SPLIT_SIZE)') 53 | 54 | else: 55 | raise NotImplementedError(f"Unkown model: {model_type}") 56 | 57 | return model 58 | -------------------------------------------------------------------------------- /data/zipreader.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | import os 9 | import zipfile 10 | import io 11 | import numpy as np 12 | from PIL import Image 13 | from PIL import ImageFile 14 | 15 | ImageFile.LOAD_TRUNCATED_IMAGES = True 16 | 17 | 18 | def is_zip_path(img_or_path): 19 | """judge if this is a zip path""" 20 | return '.zip@' in img_or_path 21 | 22 | 23 | class ZipReader(object): 24 | """A class to read zipped files""" 25 | zip_bank = dict() 26 | 27 | def __init__(self): 28 | super(ZipReader, self).__init__() 29 | 30 | @staticmethod 31 | def get_zipfile(path): 32 | zip_bank = ZipReader.zip_bank 33 | if path not in zip_bank: 34 | zfile = zipfile.ZipFile(path, 'r') 35 | zip_bank[path] = zfile 36 | return zip_bank[path] 37 | 38 | @staticmethod 39 | def split_zip_style_path(path): 40 | pos_at = path.index('@') 41 | assert pos_at != -1, "character '@' is not found from the given path '%s'" % path 42 | 43 | zip_path = path[0: pos_at] 44 | folder_path = path[pos_at + 1:] 45 | folder_path = str.strip(folder_path, '/') 46 | return zip_path, folder_path 47 | 48 | @staticmethod 49 | def list_folder(path): 50 | zip_path, folder_path = ZipReader.split_zip_style_path(path) 51 | 52 | zfile = ZipReader.get_zipfile(zip_path) 53 | folder_list = [] 54 | for file_foler_name in zfile.namelist(): 55 | file_foler_name = 
str.strip(file_foler_name, '/') 56 | if file_foler_name.startswith(folder_path) and \ 57 | len(os.path.splitext(file_foler_name)[-1]) == 0 and \ 58 | file_foler_name != folder_path: 59 | if len(folder_path) == 0: 60 | folder_list.append(file_foler_name) 61 | else: 62 | folder_list.append(file_foler_name[len(folder_path) + 1:]) 63 | 64 | return folder_list 65 | 66 | @staticmethod 67 | def list_files(path, extension=None): 68 | if extension is None: 69 | extension = ['.*'] 70 | zip_path, folder_path = ZipReader.split_zip_style_path(path) 71 | 72 | zfile = ZipReader.get_zipfile(zip_path) 73 | file_lists = [] 74 | for file_foler_name in zfile.namelist(): 75 | file_foler_name = str.strip(file_foler_name, '/') 76 | if file_foler_name.startswith(folder_path) and \ 77 | str.lower(os.path.splitext(file_foler_name)[-1]) in extension: 78 | if len(folder_path) == 0: 79 | file_lists.append(file_foler_name) 80 | else: 81 | file_lists.append(file_foler_name[len(folder_path) + 1:]) 82 | 83 | return file_lists 84 | 85 | @staticmethod 86 | def read(path): 87 | zip_path, path_img = ZipReader.split_zip_style_path(path) 88 | zfile = ZipReader.get_zipfile(zip_path) 89 | data = zfile.read(path_img) 90 | return data 91 | 92 | @staticmethod 93 | def imread(path): 94 | zip_path, path_img = ZipReader.split_zip_style_path(path) 95 | zfile = ZipReader.get_zipfile(zip_path) 96 | data = zfile.read(path_img) 97 | try: 98 | im = Image.open(io.BytesIO(data)) 99 | except: 100 | print("ERROR IMG LOADED: ", path_img) 101 | random_img = np.random.rand(224, 224, 3) * 255 102 | im = Image.fromarray(np.uint8(random_img)) 103 | return im 104 | -------------------------------------------------------------------------------- /lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | import torch 9 | from timm.scheduler.cosine_lr import CosineLRScheduler 10 | from timm.scheduler.step_lr import StepLRScheduler 11 | from timm.scheduler.scheduler import Scheduler 12 | 13 | 14 | def build_scheduler(config, optimizer, n_iter_per_epoch): 15 | num_steps = int(config.TRAIN.EPOCHS * n_iter_per_epoch) 16 | warmup_steps = int(config.TRAIN.WARMUP_EPOCHS * n_iter_per_epoch) 17 | decay_steps = int(config.TRAIN.LR_SCHEDULER.DECAY_EPOCHS * n_iter_per_epoch) 18 | 19 | lr_scheduler = None 20 | if config.TRAIN.LR_SCHEDULER.NAME == 'cosine': 21 | lr_scheduler = CosineLRScheduler( 22 | optimizer, 23 | t_initial=num_steps, 24 | lr_min=config.TRAIN.MIN_LR, 25 | warmup_lr_init=config.TRAIN.WARMUP_LR, 26 | warmup_t=warmup_steps, 27 | cycle_limit=1, 28 | t_in_epochs=False, 29 | ) 30 | elif config.TRAIN.LR_SCHEDULER.NAME == 'linear': 31 | lr_scheduler = LinearLRScheduler( 32 | optimizer, 33 | t_initial=num_steps, 34 | lr_min_rate=0.01, 35 | warmup_lr_init=config.TRAIN.WARMUP_LR, 36 | warmup_t=warmup_steps, 37 | t_in_epochs=False, 38 | ) 39 | elif config.TRAIN.LR_SCHEDULER.NAME == 'step': 40 | lr_scheduler = StepLRScheduler( 41 | optimizer, 42 | decay_t=decay_steps, 43 | decay_rate=config.TRAIN.LR_SCHEDULER.DECAY_RATE, 44 | warmup_lr_init=config.TRAIN.WARMUP_LR, 45 | warmup_t=warmup_steps, 46 | t_in_epochs=False, 47 | ) 48 | 49 | return lr_scheduler 50 | 51 | 52 | class LinearLRScheduler(Scheduler): 53 | def __init__(self, 54 | optimizer: 
torch.optim.Optimizer, 55 | t_initial: int, 56 | lr_min_rate: float, 57 | warmup_t=0, 58 | warmup_lr_init=0., 59 | t_in_epochs=True, 60 | noise_range_t=None, 61 | noise_pct=0.67, 62 | noise_std=1.0, 63 | noise_seed=42, 64 | initialize=True, 65 | ) -> None: 66 | super().__init__( 67 | optimizer, param_group_field="lr", 68 | noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed, 69 | initialize=initialize) 70 | 71 | self.t_initial = t_initial 72 | self.lr_min_rate = lr_min_rate 73 | self.warmup_t = warmup_t 74 | self.warmup_lr_init = warmup_lr_init 75 | self.t_in_epochs = t_in_epochs 76 | if self.warmup_t: 77 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] 78 | super().update_groups(self.warmup_lr_init) 79 | else: 80 | self.warmup_steps = [1 for _ in self.base_values] 81 | 82 | def _get_lr(self, t): 83 | if t < self.warmup_t: 84 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] 85 | else: 86 | t = t - self.warmup_t 87 | total_t = self.t_initial - self.warmup_t 88 | lrs = [v - ((v - v * self.lr_min_rate) * (t / total_t)) for v in self.base_values] 89 | return lrs 90 | 91 | def get_epoch_values(self, epoch: int): 92 | if self.t_in_epochs: 93 | return self._get_lr(epoch) 94 | else: 95 | return None 96 | 97 | def get_update_values(self, num_updates: int): 98 | if not self.t_in_epochs: 99 | return self._get_lr(num_updates) 100 | else: 101 | return None 102 | -------------------------------------------------------------------------------- /data/build.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin on HF 3 | # -------------------------------------------------------- 4 | 5 | import os 6 | import torch 7 | import numpy as np 8 | import torch.distributed as dist 9 | from torchvision import datasets, transforms 10 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 11 | from timm.data import Mixup 12 | from timm.data import create_transform 13 | from timm.data.transforms import _pil_interp 14 | 15 | from .cached_image_folder import CachedImageFolder 16 | from .samplers import SubsetRandomSampler 17 | 18 | 19 | def build_loader(config): 20 | config.defrost() 21 | dataset_train, config.MODEL.NUM_CLASSES = build_dataset(is_train=True, config=config) 22 | config.freeze() 23 | print(f"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} successfully build train dataset") 24 | dataset_val, _ = build_dataset(is_train=False, config=config) 25 | print(f"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} successfully build val dataset") 26 | 27 | num_tasks = dist.get_world_size() 28 | global_rank = dist.get_rank() 29 | if config.DATA.ZIP_MODE and config.DATA.CACHE_MODE == 'part': 30 | indices = np.arange(dist.get_rank(), len(dataset_train), dist.get_world_size()) 31 | sampler_train = SubsetRandomSampler(indices) 32 | else: 33 | sampler_train = torch.utils.data.DistributedSampler( 34 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 35 | ) 36 | 37 | indices = np.arange(dist.get_rank(), len(dataset_val), dist.get_world_size()) 38 | sampler_val = SubsetRandomSampler(indices) 39 | 40 | data_loader_train = torch.utils.data.DataLoader( 41 | dataset_train, sampler=sampler_train, 42 | batch_size=config.DATA.BATCH_SIZE, 43 | num_workers=config.DATA.NUM_WORKERS, 44 | pin_memory=config.DATA.PIN_MEMORY, 45 | drop_last=True, 46 | ) 47 | 48 | 
data_loader_val = torch.utils.data.DataLoader( 49 | dataset_val, sampler=sampler_val, 50 | batch_size=config.DATA.BATCH_SIZE, 51 | shuffle=False, 52 | num_workers=config.DATA.NUM_WORKERS, 53 | pin_memory=config.DATA.PIN_MEMORY, 54 | drop_last=False 55 | ) 56 | 57 | # setup mixup / cutmix 58 | mixup_fn = None 59 | mixup_active = config.AUG.MIXUP > 0 or config.AUG.CUTMIX > 0. or config.AUG.CUTMIX_MINMAX is not None 60 | if mixup_active: 61 | mixup_fn = Mixup( 62 | mixup_alpha=config.AUG.MIXUP, cutmix_alpha=config.AUG.CUTMIX, cutmix_minmax=config.AUG.CUTMIX_MINMAX, 63 | prob=config.AUG.MIXUP_PROB, switch_prob=config.AUG.MIXUP_SWITCH_PROB, mode=config.AUG.MIXUP_MODE, 64 | label_smoothing=config.MODEL.LABEL_SMOOTHING, num_classes=config.MODEL.NUM_CLASSES) 65 | 66 | return dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn 67 | 68 | 69 | def build_dataset(is_train, config): 70 | transform = build_transform(is_train, config) 71 | if config.DATA.DATASET == 'imagenet': 72 | prefix = 'train' if is_train else 'val' 73 | if config.DATA.ZIP_MODE: 74 | ann_file = prefix + "_map.txt" 75 | prefix = prefix + ".zip@/" 76 | dataset = CachedImageFolder(config.DATA.DATA_PATH, ann_file, prefix, transform, 77 | cache_mode=config.DATA.CACHE_MODE if is_train else 'part') 78 | else: 79 | root = os.path.join(config.DATA.DATA_PATH, prefix) 80 | dataset = datasets.ImageFolder(root, transform=transform) 81 | nb_classes = 1000 82 | else: 83 | raise NotImplementedError("We only support ImageNet Now.") 84 | 85 | return dataset, nb_classes 86 | 87 | 88 | def build_transform(is_train, config): 89 | resize_im = config.DATA.IMG_SIZE > 32 90 | if is_train: 91 | # this should always dispatch to transforms_imagenet_train 92 | transform = create_transform( 93 | input_size=config.DATA.IMG_SIZE, 94 | is_training=True, 95 | color_jitter=config.AUG.COLOR_JITTER if config.AUG.COLOR_JITTER > 0 else None, 96 | auto_augment=config.AUG.AUTO_AUGMENT if config.AUG.AUTO_AUGMENT != 'none' else None, 97 | re_prob=config.AUG.REPROB, 98 | re_mode=config.AUG.REMODE, 99 | re_count=config.AUG.RECOUNT, 100 | interpolation=config.DATA.INTERPOLATION, 101 | ) 102 | if not resize_im: 103 | # replace RandomResizedCropAndInterpolation with 104 | # RandomCrop 105 | transform.transforms[0] = transforms.RandomCrop(config.DATA.IMG_SIZE, padding=4) 106 | return transform 107 | 108 | t = [] 109 | if resize_im: 110 | if config.TEST.CROP: 111 | size = int((256 / 224) * config.DATA.IMG_SIZE) 112 | t.append( 113 | transforms.Resize((size, size), interpolation=_pil_interp(config.DATA.INTERPOLATION)), 114 | # to maintain same ratio w.r.t. 224 images 115 | ) 116 | t.append(transforms.CenterCrop(config.DATA.IMG_SIZE)) 117 | else: 118 | t.append( 119 | transforms.Resize((config.DATA.IMG_SIZE, config.DATA.IMG_SIZE), 120 | interpolation=_pil_interp(config.DATA.INTERPOLATION)) 121 | ) 122 | 123 | t.append(transforms.ToTensor()) 124 | t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)) 125 | print(t) 126 | return transforms.Compose(t) 127 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Bridging the divide: Reconsidering softmax and linear attention 2 | 3 | This repo contains the official PyTorch code and pre-trained models for **Injective Linear Attention (InLine)**. 
4 | 5 | + [Bridging the divide: Reconsidering softmax and linear attention](https://arxiv.org/abs/2412.06590) [[Video explanation (in Chinese)]](https://www.bilibili.com/video/BV1BAqCYnEag) 6 | 7 | 8 | 9 | ## News 10 | 11 | - November 12, 2024: Repository initialized. 12 | 13 | ## Abstract 14 | 15 | Widely adopted in modern Vision Transformer designs, Softmax attention can effectively capture long-range visual information; however, it incurs excessive computational cost when dealing with high-resolution inputs. In contrast, linear attention naturally enjoys linear complexity and has great potential to scale up to higher-resolution images. Nonetheless, the unsatisfactory performance of linear attention greatly limits its practical application in various scenarios. In this paper, we take a step forward to close the gap between linear and Softmax attention with novel theoretical analyses, which demystify the core factors behind the performance deviations. Specifically, we present two key perspectives to understand and alleviate the limitations of linear attention: the **injective property** and the **local modeling ability**. Firstly, we prove that linear attention is not injective, which makes it prone to assigning identical attention weights to different query vectors, thus leading to severe semantic confusion since different queries correspond to the same outputs. Secondly, we confirm that effective local modeling is essential for the success of Softmax attention, in which linear attention falls short. These two fundamental differences account for much of the disparity between the two attention paradigms, as demonstrated by our substantial empirical validation in the paper. In addition, further experimental results indicate that linear attention, as long as it is endowed with these two properties, can outperform Softmax attention across various tasks while maintaining lower computational complexity. 16 | 17 | ## Injectivity of the Attention Function 18 | 19 | We find that the injectivity of the attention function greatly affects model performance. Specifically, *if the attention function is not injective, different queries will induce identical attention distributions, leading to severe semantic confusion within the feature space.* We prove that the Softmax attention function is injective, whereas the linear attention function is not. Therefore, linear attention is vulnerable to this semantic confusion problem, which largely accounts for its insufficient expressiveness. 20 | 21 | 
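To make the non-injectivity concrete, here is a minimal, self-contained sketch rather than the repository's implementation (the actual attention modules live in `models/inline_swin.py`, `models/inline_pvt.py`, `models/inline_deit.py`, and `models/inline_cswin.py`). It assumes a toy setting with the kernel `phi` taken to be ReLU and a handful of random keys: two different queries on the same ray receive identical normalized linear-attention weights, while the InLine re-centering given by the formula below (and Softmax attention) tells them apart.

```python
# Toy illustration only; phi = ReLU and the shapes are assumptions, not the repo's settings.
import torch

torch.manual_seed(0)
phi = torch.relu

N, d = 6, 16                        # number of keys, head dimension
K = torch.randn(N, d)
q1 = torch.rand(d)                  # a non-negative query
q2 = 2.0 * q1                       # a different query on the same ray

def linear_attn_weights(q, K):
    s = phi(q) @ phi(K).T           # kernelized similarities, shape (N,)
    return s / s.sum()              # normalization makes the map scale-invariant

def inline_attn_weights(q, K):
    s = phi(q) @ phi(K).T
    return s - s.mean() + 1.0 / K.shape[0]   # InLine: subtract the mean, add 1/N

# Linear attention cannot distinguish q1 from q2: identical distributions.
print(torch.allclose(linear_attn_weights(q1, K), linear_attn_weights(q2, K)))    # True

# InLine attention separates them, and its weights still sum to 1.
w1, w2 = inline_attn_weights(q1, K), inline_attn_weights(q2, K)
print(torch.allclose(w1, w2))                     # False: q1 and q2 are distinguished
print(w1.sum().item(), w2.sum().item())           # both ~1.0, like a proper attention distribution

# Softmax attention also separates them, but at quadratic cost in sequence length.
print(torch.allclose(torch.softmax(q1 @ K.T, -1), torch.softmax(q2 @ K.T, -1)))  # False
```

In the full models, this InLine weighting is combined with the local modeling components discussed in the paper; the snippet above only isolates the injectivity argument.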

22 | 23 |

24 | 25 | Our method, **Injective Linear Attention (InLine)**: 26 | 27 | $$\mathrm{InL_K}(Q_i) = {\left[ 28 | \phi(Q_i)^\top\phi(K_1), 29 | \cdots, 30 | \phi(Q_i)^\top\phi(K_N) 31 | \right]}^\top - \frac{1}{N}\sum_{s=1}^{N} \phi(Q_i)^\top\phi(K_s) + \frac{1}{N}.$$ 32 | 33 | 34 | ## Results 35 | 36 | - ImageNet-1K results. 37 | 38 |

39 | 40 |

41 | 42 | 43 | - Real speed measurements. Benefited from linear complexity and simple design, our InLine attention delivers much higher inference speed than Softmax attention, especially in high-resolution scenarios. 44 | 45 |

46 | 47 |

48 | 49 | 50 | ## Dependencies 51 | 52 | - Python 3.9 53 | - PyTorch == 1.11.0 54 | - torchvision == 0.12.0 55 | - numpy 56 | - timm == 0.4.12 57 | - yacs 58 | 59 | The ImageNet dataset should be prepared as follows: 60 | 61 | ``` 62 | imagenet 63 | ├── train 64 | │ ├── class1 65 | │ │ ├── img1.jpeg 66 | │ │ └── ... 67 | │ ├── class2 68 | │ │ ├── img2.jpeg 69 | │ │ └── ... 70 | │ └── ... 71 | └── val 72 | ├── class1 73 | │ ├── img3.jpeg 74 | │ └── ... 75 | ├── class2 76 | │ ├── img4.jpeg 77 | │ └── ... 78 | └── ... 79 | ``` 80 | 81 | ## Pretrained Models 82 | 83 | | model | Resolution | #Params | FLOPs | acc@1 | config | pretrained weights | 84 | | ------ | :--------: | :-----: | :---: | :---: | :--------------------------: | :----------------------------------------------------------: | 85 | | InLine-DeiT-T | 224 | 6.5M | 1.1G | 74.5 | [config](./cfgs/inline_deit_t.yaml) | [TsinghuaCloud](https://cloud.tsinghua.edu.cn/f/1d6b8191ad6d4114b291/?dl=1) | 86 | | InLine-DeiT-S | 288 | 16.7M | 5.0G | 80.2 | [config](./cfgs/inline_deit_s.yaml) | [TsinghuaCloud](https://cloud.tsinghua.edu.cn/f/2f3898b07e9247f3beb3/?dl=1) | 87 | | InLine-DeiT-B | 448 | 23.8M | 17.2G | 82.3 | [config](./cfgs/inline_deit_b.yaml) | [TsinghuaCloud](https://cloud.tsinghua.edu.cn/f/10bdd726d4b0435eb34e/?dl=1) | 88 | | InLine-PVT-T | 224 | 12.0M | 2.0G | 78.2 | [config](./cfgs/inline_pvt_t.yaml) | [TsinghuaCloud](https://cloud.tsinghua.edu.cn/f/90ac52b1555b401eb5e6/?dl=1) | 89 | | InLine-PVT-S | 224 | 21.6M | 3.9G | 82.0 | [config](./cfgs/inline_pvt_s.yaml) | [TsinghuaCloud](https://cloud.tsinghua.edu.cn/f/1ab953b2479d433080a3/?dl=1) | 90 | | InLine-PVT-M | 224 | 37.6M | 6.9G | 83.2 | [config](./cfgs/inline_pvt_m.yaml) | [TsinghuaCloud](https://cloud.tsinghua.edu.cn/f/a72aec31e6084bc0a280/?dl=1) | 91 | | InLine-PVT-L | 224 | 50.2M | 10.2G | 83.6 | [config](./cfgs/inline_pvt_b.yaml) | [TsinghuaCloud](https://cloud.tsinghua.edu.cn/f/efd91318ba964f01b288/?dl=1) | 92 | | InLine-Swin-T | 224 | 30M | 4.5G | 82.4 | [config](./cfgs/inline_swin_t.yaml) | [TsinghuaCloud](https://cloud.tsinghua.edu.cn/f/32810869fcc34410966b/?dl=1) | 93 | | InLine-Swin-S | 224 | 50M | 8.7G | 83.6 | [config](./cfgs/inline_swin_s.yaml) | [TsinghuaCloud](https://cloud.tsinghua.edu.cn/f/e9657fd247c04c7cb1a1/?dl=1) | 94 | | InLine-Swin-B | 224 | 88M | 15.4G | 84.1 | [config](./cfgs/inline_swin_b.yaml) | [TsinghuaCloud](https://cloud.tsinghua.edu.cn/f/bf23564bb64c420aafe1/?dl=1) | 95 | | InLine-CSwin-T | 224 | 25M | 4.3G | 83.2 | [config](./cfgs/inline_cswin_t.yaml) | [TsinghuaCloud](https://cloud.tsinghua.edu.cn/f/49fa2ecc543647c4b970/?dl=1) | 96 | | InLine-CSwin-S | 224 | 43M | 6.8G | 83.8 | [config](./cfgs/inline_cswin_s.yaml) | [TsinghuaCloud](https://cloud.tsinghua.edu.cn/f/59f4e65f776f4052b93c/?dl=1) | 97 | | InLine-CSwin-B | 224 | 96M | 14.9G | 84.5 | [config](./cfgs/inline_cswin_b.yaml) | [TsinghuaCloud](https://cloud.tsinghua.edu.cn/f/91e17121df284ae38521/?dl=1) | 98 | 99 | ## Model Training and Inference 100 | 101 | - Evaluate InLine-DeiT/PVT/Swin on ImageNet: 102 | 103 | ``` 104 | python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg --data-path --output --eval --resume 105 | ``` 106 | 107 | - To train InLine-DeiT/PVT/Swin on ImageNet from scratch, run: 108 | 109 | ``` 110 | python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg --data-path --output --amp 111 | ``` 112 | 113 | - Evaluate InLine-CSwin on ImageNet: 114 | 115 | ``` 116 | python -m torch.distributed.launch --nproc_per_node=8 main_ema.py --cfg --data-path 
--output --eval --resume 117 | ``` 118 | 119 | - To train InLine-CSwin on ImageNet from scratch, run: 120 | 121 | ``` 122 | python -m torch.distributed.launch --nproc_per_node=8 main_ema.py --cfg --data-path --output --amp 123 | ``` 124 | 125 | ## Acknowledgements 126 | 127 | This code is developed on the top of [Swin Transformer](https://github.com/microsoft/Swin-Transformer). 128 | 129 | ## Citation 130 | 131 | If you find this repo helpful, please consider citing us. 132 | 133 | ```latex 134 | @inproceedings{han2024inline, 135 | title={Bridging the Divide: Reconsidering Softmax and Linear Attention 136 | }, 137 | author={Han, Dongchen and Pu, Yifan and Xia, Zhuofan and Han, Yizeng and Pan, Xuran and Li, Xiu and Lu, Jiwen and Song, Shiji and Huang, Gao}, 138 | booktitle={NeurIPS}, 139 | year={2024}, 140 | } 141 | ``` 142 | 143 | ## Contact 144 | 145 | If you have any questions, please feel free to contact the authors. 146 | 147 | Dongchen Han: [hdc23@mails.tsinghua.edu.cn](mailto:hdc23@mails.tsinghua.edu.cn) 148 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # --------------------------------------------------------' 7 | 8 | import os 9 | import yaml 10 | from yacs.config import CfgNode as CN 11 | 12 | _C = CN() 13 | 14 | # Base config files 15 | _C.BASE = [''] 16 | 17 | # ----------------------------------------------------------------------------- 18 | # Data settings 19 | # ----------------------------------------------------------------------------- 20 | _C.DATA = CN() 21 | # Batch size for a single GPU, could be overwritten by command line argument 22 | _C.DATA.BATCH_SIZE = 128 23 | # Path to dataset, could be overwritten by command line argument 24 | _C.DATA.DATA_PATH = '' 25 | # Dataset name 26 | _C.DATA.DATASET = 'imagenet' 27 | # Input image size 28 | _C.DATA.IMG_SIZE = 224 29 | # Interpolation to resize image (random, bilinear, bicubic) 30 | _C.DATA.INTERPOLATION = 'bicubic' 31 | # Use zipped dataset instead of folder dataset 32 | # could be overwritten by command line argument 33 | _C.DATA.ZIP_MODE = False 34 | # Cache Data in Memory, could be overwritten by command line argument 35 | _C.DATA.CACHE_MODE = 'part' 36 | # Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU. 
37 | _C.DATA.PIN_MEMORY = True 38 | # Number of data loading threads 39 | _C.DATA.NUM_WORKERS = 8 40 | 41 | # ----------------------------------------------------------------------------- 42 | # Model settings 43 | # ----------------------------------------------------------------------------- 44 | _C.MODEL = CN() 45 | # Model type 46 | _C.MODEL.TYPE = 'swin' 47 | # Model name 48 | _C.MODEL.NAME = 'swin_tiny_patch4_window7_224' 49 | # Checkpoint to resume, could be overwritten by command line argument 50 | _C.MODEL.RESUME = '' 51 | # Number of classes, overwritten in data preparation 52 | _C.MODEL.NUM_CLASSES = 1000 53 | # Dropout rate 54 | _C.MODEL.DROP_RATE = 0.0 55 | # Drop path rate 56 | _C.MODEL.DROP_PATH_RATE = 0.1 57 | # Label Smoothing 58 | _C.MODEL.LABEL_SMOOTHING = 0.1 59 | 60 | # Swin Transformer parameters 61 | _C.MODEL.SWIN = CN() 62 | _C.MODEL.SWIN.PATCH_SIZE = 4 63 | _C.MODEL.SWIN.IN_CHANS = 3 64 | _C.MODEL.SWIN.EMBED_DIM = 96 65 | _C.MODEL.SWIN.DEPTHS = [2, 2, 6, 2] 66 | _C.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24] 67 | _C.MODEL.SWIN.WINDOW_SIZE = 7 68 | _C.MODEL.SWIN.MLP_RATIO = 4. 69 | _C.MODEL.SWIN.QKV_BIAS = True 70 | _C.MODEL.SWIN.QK_SCALE = None 71 | _C.MODEL.SWIN.KA = [7, 7, 7, 7] 72 | _C.MODEL.SWIN.DIM_REDUCTION = [4, 4, 4, 4] 73 | _C.MODEL.SWIN.STAGES = [True, True, True, True] 74 | _C.MODEL.SWIN.STAGES_NUM = [-1, -1, -1, -1] 75 | _C.MODEL.SWIN.RPB = True 76 | _C.MODEL.SWIN.PADDING_MODE = 'zeros' 77 | _C.MODEL.SWIN.SHARE_DWC_KERNEL = True 78 | _C.MODEL.SWIN.SHARE_QKV = False 79 | _C.MODEL.SWIN.APE = False 80 | _C.MODEL.SWIN.PATCH_NORM = True 81 | _C.MODEL.SWIN.LR_FACTOR = 2 82 | _C.MODEL.SWIN.DEPTHS_LR = [2, 2, 2, 2] 83 | _C.MODEL.SWIN.FUSION_TYPE = 'add' 84 | _C.MODEL.SWIN.STAGE_CFG = None 85 | 86 | _C.MODEL.SWIN_HR = CN(new_allowed=True) 87 | _C.MODEL.SWIN_LRVIT = CN(new_allowed=True) 88 | _C.MODEL.PVD = CN(new_allowed=True) 89 | 90 | # ----------------------------------------------------------------------------- 91 | # InLine Attention options 92 | # ----------------------------------------------------------------------------- 93 | _C.MODEL.INLINE = CN() 94 | _C.MODEL.INLINE.ATTN_TYPE = 'IIII' 95 | _C.MODEL.INLINE.PVT_LA_SR_RATIOS = 1111 96 | _C.MODEL.INLINE.CSWIN_LA_SPLIT_SIZE = '56-28-14-7' 97 | 98 | # ----------------------------------------------------------------------------- 99 | # Training settings 100 | # ----------------------------------------------------------------------------- 101 | _C.TRAIN = CN() 102 | _C.TRAIN.START_EPOCH = 0 103 | _C.TRAIN.EPOCHS = 300 104 | _C.TRAIN.WARMUP_EPOCHS = 20 105 | _C.TRAIN.COOLDOWN_EPOCHS = 0 106 | _C.TRAIN.WEIGHT_DECAY = 0.05 107 | _C.TRAIN.BASE_LR = 5e-4 108 | _C.TRAIN.WARMUP_LR = 5e-7 109 | _C.TRAIN.MIN_LR = 5e-6 110 | # Clip gradient norm 111 | _C.TRAIN.CLIP_GRAD = 5.0 112 | # Auto resume from latest checkpoint 113 | _C.TRAIN.AUTO_RESUME = True 114 | # Whether to use gradient checkpointing to save memory 115 | # could be overwritten by command line argument 116 | _C.TRAIN.USE_CHECKPOINT = False 117 | 118 | # LR scheduler 119 | _C.TRAIN.LR_SCHEDULER = CN() 120 | _C.TRAIN.LR_SCHEDULER.NAME = 'cosine' 121 | # Epoch interval to decay LR, used in StepLRScheduler 122 | _C.TRAIN.LR_SCHEDULER.DECAY_EPOCHS = 30 123 | # LR decay rate, used in StepLRScheduler 124 | _C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1 125 | 126 | # Optimizer 127 | _C.TRAIN.OPTIMIZER = CN() 128 | _C.TRAIN.OPTIMIZER.NAME = 'adamw' 129 | # Optimizer Epsilon 130 | _C.TRAIN.OPTIMIZER.EPS = 1e-8 131 | # Optimizer Betas 132 | _C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999) 
133 | # SGD momentum 134 | _C.TRAIN.OPTIMIZER.MOMENTUM = 0.9 135 | 136 | # ----------------------------------------------------------------------------- 137 | # Augmentation settings 138 | # ----------------------------------------------------------------------------- 139 | _C.AUG = CN() 140 | # Color jitter factor 141 | _C.AUG.COLOR_JITTER = 0.4 142 | # Use AutoAugment policy. "v0" or "original" 143 | _C.AUG.AUTO_AUGMENT = 'rand-m9-mstd0.5-inc1' 144 | # Random erase prob 145 | _C.AUG.REPROB = 0.25 146 | # Random erase mode 147 | _C.AUG.REMODE = 'pixel' 148 | # Random erase count 149 | _C.AUG.RECOUNT = 1 150 | # Mixup alpha, mixup enabled if > 0 151 | _C.AUG.MIXUP = 0.8 152 | # Cutmix alpha, cutmix enabled if > 0 153 | _C.AUG.CUTMIX = 1.0 154 | # Cutmix min/max ratio, overrides alpha and enables cutmix if set 155 | _C.AUG.CUTMIX_MINMAX = None 156 | # Probability of performing mixup or cutmix when either/both is enabled 157 | _C.AUG.MIXUP_PROB = 1.0 158 | # Probability of switching to cutmix when both mixup and cutmix enabled 159 | _C.AUG.MIXUP_SWITCH_PROB = 0.5 160 | # How to apply mixup/cutmix params. Per "batch", "pair", or "elem" 161 | _C.AUG.MIXUP_MODE = 'batch' 162 | 163 | # ----------------------------------------------------------------------------- 164 | # Testing settings 165 | # ----------------------------------------------------------------------------- 166 | _C.TEST = CN() 167 | # Whether to use center crop when testing 168 | _C.TEST.CROP = True 169 | 170 | # ----------------------------------------------------------------------------- 171 | # Misc 172 | # ----------------------------------------------------------------------------- 173 | 174 | # overwritten by command line argument 175 | _C.AMP = False 176 | # Path to output folder, overwritten by command line argument 177 | _C.OUTPUT = '' 178 | # Tag of experiment, overwritten by command line argument 179 | _C.TAG = 'default' 180 | # Frequency to save checkpoint 181 | _C.SAVE_FREQ = 1 182 | # Frequency to logging info 183 | _C.PRINT_FREQ = 100 184 | # Fixed random seed 185 | _C.SEED = 0 186 | # Perform evaluation only, overwritten by command line argument 187 | _C.EVAL_MODE = False 188 | # Test throughput only, overwritten by command line argument 189 | _C.THROUGHPUT_MODE = False 190 | # local rank for DistributedDataParallel, given by command line argument 191 | _C.LOCAL_RANK = 0 192 | 193 | 194 | def _update_config_from_file(config, cfg_file): 195 | config.defrost() 196 | with open(cfg_file, 'r') as f: 197 | yaml_cfg = yaml.load(f, Loader=yaml.FullLoader) 198 | 199 | for cfg in yaml_cfg.setdefault('BASE', ['']): 200 | if cfg: 201 | _update_config_from_file( 202 | config, os.path.join(os.path.dirname(cfg_file), cfg) 203 | ) 204 | print('=> merge config from {}'.format(cfg_file)) 205 | config.merge_from_file(cfg_file) 206 | config.freeze() 207 | 208 | 209 | def update_config(config, args): 210 | _update_config_from_file(config, args.cfg) 211 | 212 | config.defrost() 213 | if args.opts: 214 | config.merge_from_list(args.opts) 215 | 216 | # merge from specific arguments 217 | if args.batch_size: 218 | config.DATA.BATCH_SIZE = args.batch_size 219 | if args.data_path: 220 | config.DATA.DATA_PATH = args.data_path 221 | if args.zip: 222 | config.DATA.ZIP_MODE = True 223 | if args.cache_mode: 224 | config.DATA.CACHE_MODE = args.cache_mode 225 | if args.resume: 226 | config.MODEL.RESUME = args.resume 227 | if args.use_checkpoint: 228 | config.TRAIN.USE_CHECKPOINT = True 229 | if args.amp: 230 | config.AMP = args.amp 231 | if 
args.output: 232 | config.OUTPUT = args.output 233 | if args.tag: 234 | config.TAG = args.tag 235 | if args.eval: 236 | config.EVAL_MODE = True 237 | if args.throughput: 238 | config.THROUGHPUT_MODE = True 239 | 240 | # set local rank for distributed training 241 | # config.LOCAL_RANK = args.local_rank 242 | 243 | # output folder 244 | config.OUTPUT = os.path.join(config.OUTPUT, config.MODEL.NAME, config.TAG) 245 | 246 | config.freeze() 247 | 248 | 249 | def get_config(args): 250 | """Get a yacs CfgNode object with default values.""" 251 | # Return a clone so that the defaults will not be altered 252 | # This is for the "local variable" use pattern 253 | config = _C.clone() 254 | update_config(config, args) 255 | 256 | return config 257 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | import os 9 | import torch 10 | import torch.distributed as dist 11 | 12 | 13 | def load_checkpoint(config, model, optimizer, lr_scheduler, logger): 14 | logger.info(f"==============> Resuming form {config.MODEL.RESUME}....................") 15 | if config.MODEL.RESUME.startswith('https'): 16 | checkpoint = torch.hub.load_state_dict_from_url( 17 | config.MODEL.RESUME, map_location='cpu', check_hash=True) 18 | else: 19 | checkpoint = torch.load(config.MODEL.RESUME, map_location='cpu') 20 | msg = model.load_state_dict(checkpoint['model'], strict=False) 21 | logger.info(msg) 22 | max_accuracy = 0.0 23 | if not config.EVAL_MODE and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint: 24 | optimizer.load_state_dict(checkpoint['optimizer']) 25 | lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) 26 | config.defrost() 27 | config.TRAIN.START_EPOCH = checkpoint['epoch'] 28 | config.freeze() 29 | logger.info(f"=> loaded successfully '{config.MODEL.RESUME}' (epoch {checkpoint['epoch']})") 30 | if 'max_accuracy' in checkpoint: 31 | max_accuracy = checkpoint['max_accuracy'] 32 | 33 | del checkpoint 34 | torch.cuda.empty_cache() 35 | return max_accuracy 36 | 37 | 38 | def load_pretrained(ckpt_path, model, logger): 39 | logger.info(f"==============> Loading pretrained form {ckpt_path}....................") 40 | checkpoint = torch.load(ckpt_path, map_location='cpu') 41 | # msg = model.load_pretrained(checkpoint['model']) 42 | # logger.info(msg) 43 | # logger.info(f"=> Loaded successfully {ckpt_path} ") 44 | # del checkpoint 45 | # torch.cuda.empty_cache() 46 | state_dict = checkpoint['model'] if 'model' in checkpoint.keys() else checkpoint 47 | 48 | # delete relative_position_index since we always re-init it 49 | relative_position_index_keys = [k for k in state_dict.keys() if "relative_position_index" in k] 50 | for k in relative_position_index_keys: 51 | del state_dict[k] 52 | 53 | # delete relative_coords_table since we always re-init it 54 | relative_position_index_keys = [k for k in state_dict.keys() if "relative_coords_table" in k] 55 | for k in relative_position_index_keys: 56 | del state_dict[k] 57 | 58 | # delete attn_mask since we always re-init it 59 | attn_mask_keys = [k for k in state_dict.keys() if "attn_mask" in k] 60 | for k in attn_mask_keys: 61 | del state_dict[k] 62 | 63 
| # bicubic interpolate relative_position_bias_table if not match 64 | relative_position_bias_table_keys = [k for k in state_dict.keys() if "relative_position_bias_table" in k] 65 | for k in relative_position_bias_table_keys: 66 | relative_position_bias_table_pretrained = state_dict[k] 67 | relative_position_bias_table_current = model.state_dict()[k] 68 | L1, nH1 = relative_position_bias_table_pretrained.size() 69 | L2, nH2 = relative_position_bias_table_current.size() 70 | if nH1 != nH2: 71 | logger.warning(f"Error in loading {k}, passing......") 72 | else: 73 | if L1 != L2: 74 | # bicubic interpolate relative_position_bias_table if not match 75 | S1 = int(L1 ** 0.5) 76 | S2 = int(L2 ** 0.5) 77 | relative_position_bias_table_pretrained_resized = torch.nn.functional.interpolate( 78 | relative_position_bias_table_pretrained.permute(1, 0).view(1, nH1, S1, S1), size=(S2, S2), 79 | mode='bicubic') 80 | state_dict[k] = relative_position_bias_table_pretrained_resized.view(nH2, L2).permute(1, 0) 81 | 82 | # bicubic interpolate absolute_pos_embed if not match 83 | absolute_pos_embed_keys = [k for k in state_dict.keys() if "pos_embed" in k] 84 | for k in absolute_pos_embed_keys: 85 | # dpe 86 | absolute_pos_embed_pretrained = state_dict[k] 87 | absolute_pos_embed_current = model.state_dict()[k] 88 | _, L1, C1 = absolute_pos_embed_pretrained.size() 89 | _, L2, C2 = absolute_pos_embed_current.size() 90 | if C1 != C2: 91 | logger.warning(f"Error in loading {k}, passing......") 92 | else: 93 | if L1 != L2: 94 | S1 = int(L1 ** 0.5) 95 | S2 = int(L2 ** 0.5) 96 | i, j = L1 - S1 ** 2, L2 - S2 ** 2 97 | absolute_pos_embed_pretrained_ = absolute_pos_embed_pretrained[:, i:, :].reshape(-1, S1, S1, C1) 98 | absolute_pos_embed_pretrained_ = absolute_pos_embed_pretrained_.permute(0, 3, 1, 2) 99 | absolute_pos_embed_pretrained_resized = torch.nn.functional.interpolate( 100 | absolute_pos_embed_pretrained_, size=(S2, S2), mode='bicubic') 101 | absolute_pos_embed_pretrained_resized = absolute_pos_embed_pretrained_resized.permute(0, 2, 3, 1) 102 | absolute_pos_embed_pretrained_resized = absolute_pos_embed_pretrained_resized.flatten(1, 2) 103 | state_dict[k] = torch.cat([absolute_pos_embed_pretrained[:, :j, :], 104 | absolute_pos_embed_pretrained_resized], dim=1) 105 | 106 | # check classifier, if not match, then re-init classifier to zero 107 | head_bias_pretrained = state_dict['head.bias'] 108 | Nc1 = head_bias_pretrained.shape[0] 109 | Nc2 = model.head.bias.shape[0] 110 | if (Nc1 != Nc2): 111 | if Nc1 == 21841 and Nc2 == 1000: 112 | logger.info("loading ImageNet-22K weight to ImageNet-1K ......") 113 | map22kto1k_path = f'data/map22kto1k.txt' 114 | with open(map22kto1k_path) as f: 115 | map22kto1k = f.readlines() 116 | map22kto1k = [int(id22k.strip()) for id22k in map22kto1k] 117 | state_dict['head.weight'] = state_dict['head.weight'][map22kto1k, :] 118 | state_dict['head.bias'] = state_dict['head.bias'][map22kto1k] 119 | else: 120 | torch.nn.init.constant_(model.head.bias, 0.) 121 | torch.nn.init.constant_(model.head.weight, 0.) 
122 | del state_dict['head.weight'] 123 | del state_dict['head.bias'] 124 | logger.warning(f"Error in loading classifier head, re-init classifier head to 0") 125 | 126 | msg = model.load_state_dict(state_dict, strict=False) 127 | logger.warning(msg) 128 | 129 | logger.info(f"=> loaded successfully '{ckpt_path}'") 130 | 131 | del checkpoint 132 | torch.cuda.empty_cache() 133 | 134 | 135 | def save_checkpoint(config, epoch, model, max_accuracy, optimizer, lr_scheduler, logger): 136 | save_state = {'model': model.state_dict(), 137 | 'optimizer': optimizer.state_dict(), 138 | 'lr_scheduler': lr_scheduler.state_dict(), 139 | 'max_accuracy': max_accuracy, 140 | 'epoch': epoch, 141 | 'config': config} 142 | 143 | save_path = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth') 144 | logger.info(f"{save_path} saving......") 145 | torch.save(save_state, save_path) 146 | logger.info(f"{save_path} saved !!!") 147 | 148 | def save_checkpoint_new(config, epoch, model, max_accuracy, optimizer, lr_scheduler, logger, name=None): 149 | save_state = {'model': model.state_dict(), 150 | 'optimizer': optimizer.state_dict(), 151 | 'lr_scheduler': lr_scheduler.state_dict(), 152 | 'max_accuracy': max_accuracy, 153 | 'epoch': epoch, 154 | 'config': config} 155 | if name == None: 156 | old_ckpt = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch-3}.pth') 157 | if os.path.exists(old_ckpt): 158 | os.remove(old_ckpt) 159 | 160 | if name != None: 161 | save_path = os.path.join(config.OUTPUT, f'{name}.pth') 162 | logger.info(f"{save_path} saving......") 163 | torch.save(save_state, save_path) 164 | logger.info(f"{save_path} saved !!!") 165 | else: 166 | save_path = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth') 167 | logger.info(f"{save_path} saving......") 168 | torch.save(save_state, save_path) 169 | logger.info(f"{save_path} saved !!!") 170 | 171 | 172 | def get_grad_norm(parameters, norm_type=2): 173 | if isinstance(parameters, torch.Tensor): 174 | parameters = [parameters] 175 | parameters = list(filter(lambda p: p.grad is not None, parameters)) 176 | norm_type = float(norm_type) 177 | total_norm = 0 178 | for p in parameters: 179 | param_norm = p.grad.data.norm(norm_type) 180 | total_norm += param_norm.item() ** norm_type 181 | total_norm = total_norm ** (1. 
/ norm_type) 182 | return total_norm 183 | 184 | 185 | def auto_resume_helper(output_dir): 186 | checkpoints = os.listdir(output_dir) 187 | checkpoints = [ckpt for ckpt in checkpoints if ckpt.endswith('pth')] 188 | print(f"All checkpoints founded in {output_dir}: {checkpoints}") 189 | if len(checkpoints) > 0: 190 | latest_checkpoint = max([os.path.join(output_dir, d) for d in checkpoints], key=os.path.getmtime) 191 | print(f"The latest checkpoint founded: {latest_checkpoint}") 192 | resume_file = latest_checkpoint 193 | else: 194 | resume_file = None 195 | return resume_file 196 | 197 | 198 | def reduce_tensor(tensor): 199 | rt = tensor.clone() 200 | dist.all_reduce(rt, op=dist.ReduceOp.SUM) 201 | rt /= dist.get_world_size() 202 | return rt 203 | -------------------------------------------------------------------------------- /data/cached_image_folder.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | import io 9 | import os 10 | import time 11 | import torch.distributed as dist 12 | import torch.utils.data as data 13 | from PIL import Image 14 | 15 | from .zipreader import is_zip_path, ZipReader 16 | 17 | 18 | def has_file_allowed_extension(filename, extensions): 19 | """Checks if a file is an allowed extension. 20 | Args: 21 | filename (string): path to a file 22 | Returns: 23 | bool: True if the filename ends with a known image extension 24 | """ 25 | filename_lower = filename.lower() 26 | return any(filename_lower.endswith(ext) for ext in extensions) 27 | 28 | 29 | def find_classes(dir): 30 | classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))] 31 | classes.sort() 32 | class_to_idx = {classes[i]: i for i in range(len(classes))} 33 | return classes, class_to_idx 34 | 35 | 36 | def make_dataset(dir, class_to_idx, extensions): 37 | images = [] 38 | dir = os.path.expanduser(dir) 39 | for target in sorted(os.listdir(dir)): 40 | d = os.path.join(dir, target) 41 | if not os.path.isdir(d): 42 | continue 43 | 44 | for root, _, fnames in sorted(os.walk(d)): 45 | for fname in sorted(fnames): 46 | if has_file_allowed_extension(fname, extensions): 47 | path = os.path.join(root, fname) 48 | item = (path, class_to_idx[target]) 49 | images.append(item) 50 | 51 | return images 52 | 53 | 54 | def make_dataset_with_ann(ann_file, img_prefix, extensions): 55 | images = [] 56 | with open(ann_file, "r") as f: 57 | contents = f.readlines() 58 | for line_str in contents: 59 | path_contents = [c for c in line_str.split('\t')] 60 | im_file_name = path_contents[0] 61 | class_index = int(path_contents[1]) 62 | 63 | assert str.lower(os.path.splitext(im_file_name)[-1]) in extensions 64 | item = (os.path.join(img_prefix, im_file_name), class_index) 65 | 66 | images.append(item) 67 | 68 | return images 69 | 70 | 71 | class DatasetFolder(data.Dataset): 72 | """A generic data loader where the samples are arranged in this way: :: 73 | root/class_x/xxx.ext 74 | root/class_x/xxy.ext 75 | root/class_x/xxz.ext 76 | root/class_y/123.ext 77 | root/class_y/nsdf3.ext 78 | root/class_y/asd932_.ext 79 | Args: 80 | root (string): Root directory path. 81 | loader (callable): A function to load a sample given its path. 82 | extensions (list[string]): A list of allowed extensions. 
83 | transform (callable, optional): A function/transform that takes in 84 | a sample and returns a transformed version. 85 | E.g, ``transforms.RandomCrop`` for images. 86 | target_transform (callable, optional): A function/transform that takes 87 | in the target and transforms it. 88 | Attributes: 89 | samples (list): List of (sample path, class_index) tuples 90 | """ 91 | 92 | def __init__(self, root, loader, extensions, ann_file='', img_prefix='', transform=None, target_transform=None, 93 | cache_mode="no"): 94 | # image folder mode 95 | if ann_file == '': 96 | _, class_to_idx = find_classes(root) 97 | samples = make_dataset(root, class_to_idx, extensions) 98 | # zip mode 99 | else: 100 | samples = make_dataset_with_ann(os.path.join(root, ann_file), 101 | os.path.join(root, img_prefix), 102 | extensions) 103 | 104 | if len(samples) == 0: 105 | raise (RuntimeError("Found 0 files in subfolders of: " + root + "\n" + 106 | "Supported extensions are: " + ",".join(extensions))) 107 | 108 | self.root = root 109 | self.loader = loader 110 | self.extensions = extensions 111 | 112 | self.samples = samples 113 | self.labels = [y_1k for _, y_1k in samples] 114 | self.classes = list(set(self.labels)) 115 | 116 | self.transform = transform 117 | self.target_transform = target_transform 118 | 119 | self.cache_mode = cache_mode 120 | if self.cache_mode != "no": 121 | self.init_cache() 122 | 123 | def init_cache(self): 124 | assert self.cache_mode in ["part", "full"] 125 | n_sample = len(self.samples) 126 | global_rank = dist.get_rank() 127 | world_size = dist.get_world_size() 128 | 129 | samples_bytes = [None for _ in range(n_sample)] 130 | start_time = time.time() 131 | for index in range(n_sample): 132 | if index % (n_sample // 10) == 0: 133 | t = time.time() - start_time 134 | print(f'global_rank {dist.get_rank()} cached {index}/{n_sample} takes {t:.2f}s per block') 135 | start_time = time.time() 136 | path, target = self.samples[index] 137 | if self.cache_mode == "full": 138 | samples_bytes[index] = (ZipReader.read(path), target) 139 | elif self.cache_mode == "part" and index % world_size == global_rank: 140 | samples_bytes[index] = (ZipReader.read(path), target) 141 | else: 142 | samples_bytes[index] = (path, target) 143 | self.samples = samples_bytes 144 | 145 | def __getitem__(self, index): 146 | """ 147 | Args: 148 | index (int): Index 149 | Returns: 150 | tuple: (sample, target) where target is class_index of the target class. 
151 | """ 152 | path, target = self.samples[index] 153 | sample = self.loader(path) 154 | if self.transform is not None: 155 | sample = self.transform(sample) 156 | if self.target_transform is not None: 157 | target = self.target_transform(target) 158 | 159 | return sample, target 160 | 161 | def __len__(self): 162 | return len(self.samples) 163 | 164 | def __repr__(self): 165 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' 166 | fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) 167 | fmt_str += ' Root Location: {}\n'.format(self.root) 168 | tmp = ' Transforms (if any): ' 169 | fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 170 | tmp = ' Target Transforms (if any): ' 171 | fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 172 | return fmt_str 173 | 174 | 175 | IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif'] 176 | 177 | 178 | def pil_loader(path): 179 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 180 | if isinstance(path, bytes): 181 | img = Image.open(io.BytesIO(path)) 182 | elif is_zip_path(path): 183 | data = ZipReader.read(path) 184 | img = Image.open(io.BytesIO(data)) 185 | else: 186 | with open(path, 'rb') as f: 187 | img = Image.open(f) 188 | return img.convert('RGB') 189 | 190 | 191 | def accimage_loader(path): 192 | import accimage 193 | try: 194 | return accimage.Image(path) 195 | except IOError: 196 | # Potentially a decoding problem, fall back to PIL.Image 197 | return pil_loader(path) 198 | 199 | 200 | def default_img_loader(path): 201 | from torchvision import get_image_backend 202 | if get_image_backend() == 'accimage': 203 | return accimage_loader(path) 204 | else: 205 | return pil_loader(path) 206 | 207 | 208 | class CachedImageFolder(DatasetFolder): 209 | """A generic data loader where the images are arranged in this way: :: 210 | root/dog/xxx.png 211 | root/dog/xxy.png 212 | root/dog/xxz.png 213 | root/cat/123.png 214 | root/cat/nsdf3.png 215 | root/cat/asd932_.png 216 | Args: 217 | root (string): Root directory path. 218 | transform (callable, optional): A function/transform that takes in an PIL image 219 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 220 | target_transform (callable, optional): A function/transform that takes in the 221 | target and transforms it. 222 | loader (callable, optional): A function to load an image given its path. 223 | Attributes: 224 | imgs (list): List of (image path, class_index) tuples 225 | """ 226 | 227 | def __init__(self, root, ann_file='', img_prefix='', transform=None, target_transform=None, 228 | loader=default_img_loader, cache_mode="no"): 229 | super(CachedImageFolder, self).__init__(root, loader, IMG_EXTENSIONS, 230 | ann_file=ann_file, img_prefix=img_prefix, 231 | transform=transform, target_transform=target_transform, 232 | cache_mode=cache_mode) 233 | self.imgs = self.samples 234 | 235 | def __getitem__(self, index): 236 | """ 237 | Args: 238 | index (int): Index 239 | Returns: 240 | tuple: (image, target) where target is class_index of the target class. 
241 | """ 242 | path, target = self.samples[index] 243 | image = self.loader(path) 244 | if self.transform is not None: 245 | img = self.transform(image) 246 | else: 247 | img = image 248 | if self.target_transform is not None: 249 | target = self.target_transform(target) 250 | 251 | return img, target 252 | -------------------------------------------------------------------------------- /utils_ema.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | import os 9 | import torch 10 | import torch.distributed as dist 11 | from timm.utils.model import unwrap_model, get_state_dict 12 | 13 | 14 | def load_checkpoint(config, model, optimizer, lr_scheduler, logger): 15 | logger.info(f"==============> Resuming form {config.MODEL.RESUME}....................") 16 | if config.MODEL.RESUME.startswith('https'): 17 | checkpoint = torch.hub.load_state_dict_from_url( 18 | config.MODEL.RESUME, map_location='cpu', check_hash=True) 19 | else: 20 | checkpoint = torch.load(config.MODEL.RESUME, map_location='cpu') 21 | msg = model.load_state_dict(checkpoint['model'], strict=False) 22 | logger.info(msg) 23 | max_accuracy = 0.0 24 | max_accuracy_e = 0.0 25 | if not config.EVAL_MODE and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint: 26 | optimizer.load_state_dict(checkpoint['optimizer']) 27 | lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) 28 | config.defrost() 29 | config.TRAIN.START_EPOCH = checkpoint['epoch'] 30 | config.freeze() 31 | logger.info(f"=> loaded successfully '{config.MODEL.RESUME}' (epoch {checkpoint['epoch']})") 32 | if 'max_accuracy' in checkpoint: 33 | max_accuracy = checkpoint['max_accuracy'] 34 | max_accuracy_e = checkpoint['max_accuracy_e'] 35 | 36 | del checkpoint 37 | torch.cuda.empty_cache() 38 | return max_accuracy, max_accuracy_e 39 | 40 | 41 | def load_pretrained(ckpt_path, model, logger): 42 | logger.info(f"==============> Loading pretrained form {ckpt_path}....................") 43 | checkpoint = torch.load(ckpt_path, map_location='cpu') 44 | # msg = model.load_pretrained(checkpoint['model']) 45 | # logger.info(msg) 46 | # logger.info(f"=> Loaded successfully {ckpt_path} ") 47 | # del checkpoint 48 | # torch.cuda.empty_cache() 49 | state_dict = checkpoint['state_dict_ema'] if 'state_dict_ema' in checkpoint.keys() else checkpoint 50 | 51 | # delete relative_position_index since we always re-init it 52 | relative_position_index_keys = [k for k in state_dict.keys() if "relative_position_index" in k] 53 | for k in relative_position_index_keys: 54 | del state_dict[k] 55 | 56 | # delete relative_coords_table since we always re-init it 57 | relative_position_index_keys = [k for k in state_dict.keys() if "relative_coords_table" in k] 58 | for k in relative_position_index_keys: 59 | del state_dict[k] 60 | 61 | # delete attn_mask since we always re-init it 62 | attn_mask_keys = [k for k in state_dict.keys() if "attn_mask" in k] 63 | for k in attn_mask_keys: 64 | del state_dict[k] 65 | 66 | # bicubic interpolate relative_position_bias_table if not match 67 | relative_position_bias_table_keys = [k for k in state_dict.keys() if "relative_position_bias_table" in k] 68 | for k in relative_position_bias_table_keys: 69 | 
relative_position_bias_table_pretrained = state_dict[k] 70 | relative_position_bias_table_current = model.state_dict()[k] 71 | L1, nH1 = relative_position_bias_table_pretrained.size() 72 | L2, nH2 = relative_position_bias_table_current.size() 73 | if nH1 != nH2: 74 | logger.warning(f"Error in loading {k}, passing......") 75 | else: 76 | if L1 != L2: 77 | # bicubic interpolate relative_position_bias_table if not match 78 | S1 = int(L1 ** 0.5) 79 | S2 = int(L2 ** 0.5) 80 | relative_position_bias_table_pretrained_resized = torch.nn.functional.interpolate( 81 | relative_position_bias_table_pretrained.permute(1, 0).view(1, nH1, S1, S1), size=(S2, S2), 82 | mode='bicubic') 83 | state_dict[k] = relative_position_bias_table_pretrained_resized.view(nH2, L2).permute(1, 0) 84 | 85 | # bicubic interpolate absolute_pos_embed if not match 86 | absolute_pos_embed_keys = [k for k in state_dict.keys() if "pos_embed" in k] 87 | for k in absolute_pos_embed_keys: 88 | # dpe 89 | absolute_pos_embed_pretrained = state_dict[k] 90 | absolute_pos_embed_current = model.state_dict()[k] 91 | _, L1, C1 = absolute_pos_embed_pretrained.size() 92 | _, L2, C2 = absolute_pos_embed_current.size() 93 | if C1 != C2: 94 | logger.warning(f"Error in loading {k}, passing......") 95 | else: 96 | if L1 != L2: 97 | S1 = int(L1 ** 0.5) 98 | S2 = int(L2 ** 0.5) 99 | i, j = L1 - S1 ** 2, L2 - S2 ** 2 100 | absolute_pos_embed_pretrained_ = absolute_pos_embed_pretrained[:, i:, :].reshape(-1, S1, S1, C1) 101 | absolute_pos_embed_pretrained_ = absolute_pos_embed_pretrained_.permute(0, 3, 1, 2) 102 | absolute_pos_embed_pretrained_resized = torch.nn.functional.interpolate( 103 | absolute_pos_embed_pretrained_, size=(S2, S2), mode='bicubic') 104 | absolute_pos_embed_pretrained_resized = absolute_pos_embed_pretrained_resized.permute(0, 2, 3, 1) 105 | absolute_pos_embed_pretrained_resized = absolute_pos_embed_pretrained_resized.flatten(1, 2) 106 | state_dict[k] = torch.cat([absolute_pos_embed_pretrained[:, :j, :], 107 | absolute_pos_embed_pretrained_resized], dim=1) 108 | 109 | # check classifier, if not match, then re-init classifier to zero 110 | head_bias_pretrained = state_dict['head.bias'] 111 | Nc1 = head_bias_pretrained.shape[0] 112 | Nc2 = model.head.bias.shape[0] 113 | if (Nc1 != Nc2): 114 | if Nc1 == 21841 and Nc2 == 1000: 115 | logger.info("loading ImageNet-22K weight to ImageNet-1K ......") 116 | map22kto1k_path = f'data/map22kto1k.txt' 117 | with open(map22kto1k_path) as f: 118 | map22kto1k = f.readlines() 119 | map22kto1k = [int(id22k.strip()) for id22k in map22kto1k] 120 | state_dict['head.weight'] = state_dict['head.weight'][map22kto1k, :] 121 | state_dict['head.bias'] = state_dict['head.bias'][map22kto1k] 122 | else: 123 | torch.nn.init.constant_(model.head.bias, 0.) 124 | torch.nn.init.constant_(model.head.weight, 0.)
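# NOTE: the pretrained classifier does not match the current head, so its weights are removed
# from the state dict below; with load_state_dict(strict=False) the zero-initialized head
# defined just above is kept instead.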
125 | del state_dict['head.weight'] 126 | del state_dict['head.bias'] 127 | logger.warning(f"Error in loading classifier head, re-init classifier head to 0") 128 | 129 | msg = model.load_state_dict(state_dict, strict=False) 130 | logger.warning(msg) 131 | 132 | logger.info(f"=> loaded successfully '{ckpt_path}'") 133 | 134 | del checkpoint 135 | torch.cuda.empty_cache() 136 | 137 | 138 | def save_checkpoint(config, epoch, model, max_accuracy, max_accuracy_e, optimizer, lr_scheduler, logger): 139 | save_state = {'model': model.state_dict(), 140 | # 'model_ema': model_ema.state_dict(), 141 | 'optimizer': optimizer.state_dict(), 142 | 'lr_scheduler': lr_scheduler.state_dict(), 143 | 'max_accuracy': max_accuracy, 144 | 'max_accuracy_e': max_accuracy_e, 145 | 'epoch': epoch, 146 | 'config': config} 147 | 148 | save_path = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth') 149 | logger.info(f"{save_path} saving......") 150 | torch.save(save_state, save_path) 151 | logger.info(f"{save_path} saved !!!") 152 | 153 | def save_checkpoint_ema(config, epoch, model, model_ema, max_accuracy, max_accuracy_e, optimizer, lr_scheduler, logger): 154 | save_state = {'model': model.state_dict(), 155 | # 'model_ema': model_ema.state_dict(), 156 | 'optimizer': optimizer.state_dict(), 157 | 'lr_scheduler': lr_scheduler.state_dict(), 158 | 'max_accuracy': max_accuracy, 159 | 'max_accuracy_e': max_accuracy_e, 160 | 'epoch': epoch, 161 | 'config': config} 162 | save_state['state_dict_ema'] = get_state_dict(model_ema, unwrap_model) 163 | 164 | save_path = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth') 165 | logger.info(f"{save_path} saving......") 166 | torch.save(save_state, save_path) 167 | logger.info(f"{save_path} saved !!!") 168 | 169 | 170 | def save_checkpoint_ema_new(config, epoch, model, model_ema, max_accuracy, max_accuracy_e, optimizer, lr_scheduler, logger, name=None): 171 | save_state = {'model': model.state_dict(), 172 | # 'model_ema': model_ema.state_dict(), 173 | 'optimizer': optimizer.state_dict(), 174 | 'lr_scheduler': lr_scheduler.state_dict(), 175 | 'max_accuracy': max_accuracy, 176 | 'max_accuracy_e': max_accuracy_e, 177 | 'epoch': epoch, 178 | 'config': config} 179 | save_state['state_dict_ema'] = get_state_dict(model_ema, unwrap_model) 180 | 181 | if name == None: 182 | old_ckpt = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch-3}.pth') 183 | if os.path.exists(old_ckpt): 184 | os.remove(old_ckpt) 185 | 186 | if name != None: 187 | save_path = os.path.join(config.OUTPUT, f'{name}.pth') 188 | logger.info(f"{save_path} saving......") 189 | torch.save(save_state, save_path) 190 | logger.info(f"{save_path} saved !!!") 191 | else: 192 | save_path = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth') 193 | logger.info(f"{save_path} saving......") 194 | torch.save(save_state, save_path) 195 | logger.info(f"{save_path} saved !!!") 196 | 197 | 198 | def get_grad_norm(parameters, norm_type=2): 199 | if isinstance(parameters, torch.Tensor): 200 | parameters = [parameters] 201 | parameters = list(filter(lambda p: p.grad is not None, parameters)) 202 | norm_type = float(norm_type) 203 | total_norm = 0 204 | for p in parameters: 205 | param_norm = p.grad.data.norm(norm_type) 206 | total_norm += param_norm.item() ** norm_type 207 | total_norm = total_norm ** (1. 
/ norm_type) 208 | return total_norm 209 | 210 | 211 | def auto_resume_helper(output_dir): 212 | checkpoints = os.listdir(output_dir) 213 | checkpoints = [ckpt for ckpt in checkpoints if ckpt.endswith('pth')] 214 | print(f"All checkpoints founded in {output_dir}: {checkpoints}") 215 | if len(checkpoints) > 0: 216 | latest_checkpoint = max([os.path.join(output_dir, d) for d in checkpoints], key=os.path.getmtime) 217 | print(f"The latest checkpoint founded: {latest_checkpoint}") 218 | resume_file = latest_checkpoint 219 | else: 220 | resume_file = None 221 | return resume_file 222 | 223 | 224 | def reduce_tensor(tensor): 225 | rt = tensor.clone() 226 | dist.all_reduce(rt, op=dist.ReduceOp.SUM) 227 | rt /= dist.get_world_size() 228 | return rt 229 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import argparse 4 | import datetime 5 | import numpy as np 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.backends.cudnn as cudnn 10 | import torch.distributed as dist 11 | from torch.cuda.amp import autocast, GradScaler 12 | 13 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy 14 | from timm.utils import accuracy, AverageMeter 15 | 16 | from config import get_config 17 | from models import build_model 18 | from data import build_loader 19 | from lr_scheduler import build_scheduler 20 | from optimizer import build_optimizer 21 | from logger import create_logger 22 | from utils import load_checkpoint, save_checkpoint, save_checkpoint_new, get_grad_norm, auto_resume_helper, reduce_tensor, load_pretrained 23 | 24 | import warnings 25 | warnings.filterwarnings('ignore') 26 | 27 | def parse_option(): 28 | parser = argparse.ArgumentParser('InLine Attention training and evaluation script', add_help=False) 29 | parser.add_argument('--cfg', type=str, required=True, metavar="FILE", help='path to config file', ) 30 | parser.add_argument( 31 | "--opts", 32 | help="Modify config options by adding 'KEY VALUE' pairs. 
", 33 | default=None, 34 | nargs='+', 35 | ) 36 | 37 | # easy config modification 38 | parser.add_argument('--batch-size', type=int, help="batch size for single GPU") 39 | parser.add_argument('--data-path', type=str, help='path to dataset') 40 | parser.add_argument('--zip', action='store_true', help='use zipped dataset instead of folder dataset') 41 | parser.add_argument('--cache-mode', type=str, default='part', choices=['no', 'full', 'part'], 42 | help='no: no cache, ' 43 | 'full: cache all data, ' 44 | 'part: sharding the dataset into nonoverlapping pieces and only cache one piece') 45 | parser.add_argument('--resume', help='resume from checkpoint') 46 | parser.add_argument('--use-checkpoint', action='store_true', 47 | help="whether to use gradient checkpointing to save memory") 48 | parser.add_argument('--amp', action='store_true', default=False) 49 | parser.add_argument('--output', default='output', type=str, metavar='PATH', 50 | help='root of output folder, the full path is // (default: output)') 51 | parser.add_argument('--tag', help='tag of experiment') 52 | parser.add_argument('--eval', action='store_true', help='Perform evaluation only') 53 | parser.add_argument('--throughput', action='store_true', help='Test throughput only') 54 | parser.add_argument('--pretrained', type=str, help='Finetune 384 initial checkpoint.', default='') 55 | parser.add_argument('--find-unused-params', action='store_true', default=False) 56 | 57 | args, unparsed = parser.parse_known_args() 58 | 59 | config = get_config(args) 60 | 61 | return args, config 62 | 63 | 64 | def main(): 65 | os.environ["NCCL_BLOCKING_WAIT"] = "1" 66 | args, config = parse_option() 67 | 68 | rank = int(os.environ["RANK"]) 69 | world_size = int(os.environ['WORLD_SIZE']) 70 | local_rank = int(os.environ['LOCAL_RANK']) 71 | torch.cuda.set_device(local_rank) 72 | dist.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank) 73 | 74 | seed = config.SEED + dist.get_rank() 75 | torch.manual_seed(seed) 76 | np.random.seed(seed) 77 | cudnn.enabled = True 78 | cudnn.benchmark = True 79 | 80 | # linear scale the learning rate according to total batch size, may not be optimal 81 | linear_scaled_lr = config.TRAIN.BASE_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0 82 | linear_scaled_warmup_lr = config.TRAIN.WARMUP_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0 83 | linear_scaled_min_lr = config.TRAIN.MIN_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0 84 | 85 | config.defrost() 86 | config.TRAIN.BASE_LR = linear_scaled_lr 87 | config.TRAIN.WARMUP_LR = linear_scaled_warmup_lr 88 | config.TRAIN.MIN_LR = linear_scaled_min_lr 89 | config.LOCAL_RANK = local_rank 90 | config.freeze() 91 | 92 | os.makedirs(config.OUTPUT, exist_ok=True) 93 | logger = create_logger(output_dir=config.OUTPUT, dist_rank=dist.get_rank(), name=f"{config.MODEL.NAME}") 94 | 95 | if dist.get_rank() == 0: 96 | path = os.path.join(config.OUTPUT, "config.json") 97 | with open(path, "w") as f: 98 | f.write(config.dump()) 99 | logger.info(f"Full config saved to {path}") 100 | 101 | # print config 102 | logger.info(config.dump()) 103 | 104 | _, dataset_val, data_loader_train, data_loader_val, mixup_fn = build_loader(config) 105 | 106 | logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}") 107 | model = build_model(config) 108 | model.cuda() 109 | logger.info(str(model)) 110 | 111 | optimizer = build_optimizer(config, model) 112 | 113 | model = nn.parallel.DistributedDataParallel(model, 
device_ids=[local_rank], broadcast_buffers=True, find_unused_parameters=args.find_unused_params) 114 | model_without_ddp = model.module 115 | 116 | lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train)) 117 | total_epochs = config.TRAIN.EPOCHS + config.TRAIN.COOLDOWN_EPOCHS 118 | 119 | if config.AUG.MIXUP > 0.: 120 | # smoothing is handled with mixup label transform 121 | criterion = SoftTargetCrossEntropy() 122 | elif config.MODEL.LABEL_SMOOTHING > 0.: 123 | criterion = LabelSmoothingCrossEntropy(smoothing=config.MODEL.LABEL_SMOOTHING) 124 | else: 125 | criterion = nn.CrossEntropyLoss() 126 | 127 | max_accuracy = 0.0 128 | 129 | if args.pretrained != '': 130 | load_pretrained(args.pretrained, model_without_ddp, logger) 131 | 132 | if config.TRAIN.AUTO_RESUME: 133 | resume_file = auto_resume_helper(config.OUTPUT) 134 | if resume_file: 135 | if config.MODEL.RESUME: 136 | logger.warning(f"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}") 137 | config.defrost() 138 | config.MODEL.RESUME = resume_file 139 | config.freeze() 140 | logger.info(f'auto resuming from {resume_file}') 141 | else: 142 | logger.info(f'no checkpoint found in {config.OUTPUT}, ignoring auto resume') 143 | 144 | if config.MODEL.RESUME: 145 | max_accuracy = load_checkpoint(config, model_without_ddp, optimizer, lr_scheduler, logger) 146 | acc1, acc5, loss = validate(config, data_loader_val, model, logger) 147 | max_accuracy = max(max_accuracy, acc1) 148 | torch.cuda.empty_cache() 149 | logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%") 150 | if config.EVAL_MODE: 151 | return 152 | 153 | if config.THROUGHPUT_MODE: 154 | throughput(data_loader_val, model, logger) 155 | return 156 | 157 | logger.info("Start training") 158 | start_time = time.time() 159 | for epoch in range(config.TRAIN.START_EPOCH, total_epochs): 160 | data_loader_train.sampler.set_epoch(epoch) 161 | 162 | train_one_epoch(config, model, criterion, data_loader_train, optimizer, epoch, mixup_fn, lr_scheduler, logger, total_epochs) 163 | acc1, acc5, loss = validate(config, data_loader_val, model, logger) 164 | logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%") 165 | 166 | if dist.get_rank() == 0 and ((epoch + 1) % config.SAVE_FREQ == 0 or (epoch + 1) == (total_epochs)): 167 | save_checkpoint_new(config, epoch + 1, model_without_ddp, max(max_accuracy, acc1), optimizer, lr_scheduler, logger) 168 | 169 | if dist.get_rank() == 0 and ((epoch + 1) % config.SAVE_FREQ == 0 or (epoch + 1) == (total_epochs)) and acc1 >= max_accuracy: 170 | save_checkpoint_new(config, epoch + 1, model_without_ddp, max(max_accuracy, acc1), optimizer, lr_scheduler, logger, name='max_acc') 171 | max_accuracy = max(max_accuracy, acc1) 172 | logger.info(f'Max accuracy: {max_accuracy:.2f}%') 173 | 174 | total_time = time.time() - start_time 175 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 176 | logger.info('Training time {}'.format(total_time_str)) 177 | 178 | 179 | def train_one_epoch(config, model, criterion, data_loader, optimizer, epoch, mixup_fn, lr_scheduler, logger, total_epochs): 180 | model.train() 181 | optimizer.zero_grad() 182 | 183 | num_steps = len(data_loader) 184 | batch_time = AverageMeter() 185 | loss_meter = AverageMeter() 186 | norm_meter = AverageMeter() 187 | 188 | start = time.time() 189 | end = time.time() 190 | 191 | scaler = GradScaler() 192 | 193 | for idx, (samples, targets) in enumerate(data_loader): 194 | 195 | 
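# One optimization step: zero the gradients, move the batch to the GPU, apply mixup if enabled,
# run the forward/backward pass (under autocast + GradScaler when config.AMP is set), optionally
# clip gradients, then step the optimizer and the per-iteration LR scheduler.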
optimizer.zero_grad() 196 | samples = samples.cuda(non_blocking=True) 197 | targets = targets.cuda(non_blocking=True) 198 | 199 | if mixup_fn is not None: 200 | samples, targets = mixup_fn(samples, targets) 201 | 202 | if config.AMP: 203 | with autocast(): 204 | outputs = model(samples) 205 | loss = criterion(outputs, targets) 206 | scaler.scale(loss).backward() 207 | if config.TRAIN.CLIP_GRAD: 208 | scaler.unscale_(optimizer) 209 | grad_norm = nn.utils.clip_grad_norm_(model.parameters(), config.TRAIN.CLIP_GRAD) 210 | scaler.step(optimizer) 211 | scaler.update() 212 | else: 213 | grad_norm = get_grad_norm(model.parameters()) 214 | scaler.step(optimizer) 215 | scaler.update() 216 | else: 217 | outputs = model(samples) 218 | loss = criterion(outputs, targets) 219 | loss.backward() 220 | if config.TRAIN.CLIP_GRAD: 221 | grad_norm = nn.utils.clip_grad_norm_(model.parameters(), config.TRAIN.CLIP_GRAD) 222 | else: 223 | grad_norm = get_grad_norm(model.parameters()) 224 | optimizer.step() 225 | 226 | lr_scheduler.step_update(epoch * num_steps + idx) 227 | 228 | torch.cuda.synchronize() 229 | 230 | loss_meter.update(loss.item(), targets.size(0)) 231 | norm_meter.update(grad_norm) 232 | batch_time.update(time.time() - end) 233 | end = time.time() 234 | 235 | if (idx + 1) % config.PRINT_FREQ == 0: 236 | lr = optimizer.param_groups[0]['lr'] 237 | memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0) 238 | etas = batch_time.avg * (num_steps - idx) 239 | logger.info( 240 | f'Train: [{epoch + 1}/{total_epochs}][{idx + 1}/{num_steps}]\t' 241 | f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t' 242 | f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t' 243 | f'loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t' 244 | f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\t' 245 | f'mem {memory_used:.0f}MB') 246 | epoch_time = time.time() - start 247 | logger.info(f"EPOCH {epoch + 1} training takes {datetime.timedelta(seconds=int(epoch_time))}") 248 | 249 | 250 | @torch.no_grad() 251 | def validate(config, data_loader, model, logger): 252 | criterion = nn.CrossEntropyLoss() 253 | model.eval() 254 | 255 | batch_time = AverageMeter() 256 | loss_meter = AverageMeter() 257 | acc1_meter = AverageMeter() 258 | acc5_meter = AverageMeter() 259 | 260 | end = time.time() 261 | for idx, (images, target) in enumerate(data_loader): 262 | images = images.cuda(non_blocking=True) 263 | target = target.cuda(non_blocking=True) 264 | 265 | # compute output 266 | output = model(images) 267 | 268 | # measure accuracy and record loss 269 | loss = criterion(output, target) 270 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 271 | 272 | acc1 = reduce_tensor(acc1) 273 | acc5 = reduce_tensor(acc5) 274 | loss = reduce_tensor(loss) 275 | 276 | loss_meter.update(loss.item(), target.size(0)) 277 | acc1_meter.update(acc1.item(), target.size(0)) 278 | acc5_meter.update(acc5.item(), target.size(0)) 279 | 280 | # measure elapsed time 281 | batch_time.update(time.time() - end) 282 | end = time.time() 283 | 284 | if (idx + 1) % config.PRINT_FREQ == 0: 285 | memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0) 286 | logger.info( 287 | f'Test: [{(idx + 1)}/{len(data_loader)}]\t' 288 | f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 289 | f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t' 290 | f'Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\t' 291 | f'Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})\t' 292 | f'Mem {memory_used:.0f}MB') 293 | logger.info(f' * Acc@1 {acc1_meter.avg:.3f} 
Acc@5 {acc5_meter.avg:.3f}') 294 | return acc1_meter.avg, acc5_meter.avg, loss_meter.avg 295 | 296 | 297 | @torch.no_grad() 298 | def throughput(data_loader, model, logger): 299 | model.eval() 300 | 301 | for _, (images, _) in enumerate(data_loader): 302 | images = images.cuda(non_blocking=True) 303 | batch_size = images.shape[0] 304 | for i in range(50): 305 | model(images) 306 | torch.cuda.synchronize() 307 | logger.info(f"throughput averaged with 30 times") 308 | tic1 = time.time() 309 | for i in range(30): 310 | model(images) 311 | torch.cuda.synchronize() 312 | tic2 = time.time() 313 | logger.info(f"batch_size {batch_size} throughput {30 * batch_size / (tic2 - tic1)}") 314 | return 315 | 316 | 317 | if __name__ == '__main__': 318 | 319 | main() 320 | -------------------------------------------------------------------------------- /models/inline_pvt.py: -------------------------------------------------------------------------------- 1 | # ----------------------------------------------------------------------- 2 | # Bridging the Divide: Reconsidering Softmax and Linear Attention 3 | # Modified by Dongchen Han 4 | # ----------------------------------------------------------------------- 5 | 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from functools import partial 11 | 12 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 13 | from timm.models.registry import register_model 14 | from timm.models.vision_transformer import _cfg 15 | 16 | __all__ = [ 17 | 'inline_pvt_tiny', 'inline_pvt_small', 'inline_pvt_medium', 'inline_pvt_large' 18 | ] 19 | 20 | 21 | class Mlp(nn.Module): 22 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 23 | super().__init__() 24 | out_features = out_features or in_features 25 | hidden_features = hidden_features or in_features 26 | self.fc1 = nn.Linear(in_features, hidden_features) 27 | self.act = act_layer() 28 | self.fc2 = nn.Linear(hidden_features, out_features) 29 | self.drop = nn.Dropout(drop) 30 | 31 | def forward(self, x): 32 | x = self.fc1(x) 33 | x = self.act(x) 34 | x = self.drop(x) 35 | x = self.fc2(x) 36 | x = self.drop(x) 37 | return x 38 | 39 | 40 | class Attention(nn.Module): 41 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1): 42 | super().__init__() 43 | assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." 
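# Standard softmax attention from PVT, used for stages whose ATTN_TYPE letter is 'S'.
# When sr_ratio > 1, keys and values are spatially downsampled by a strided convolution
# before attention to reduce the quadratic cost.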
44 | 45 | self.dim = dim 46 | self.num_heads = num_heads 47 | head_dim = dim // num_heads 48 | self.scale = qk_scale or head_dim ** -0.5 49 | 50 | self.q = nn.Linear(dim, dim, bias=qkv_bias) 51 | self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) 52 | self.attn_drop = nn.Dropout(attn_drop) 53 | self.proj = nn.Linear(dim, dim) 54 | self.proj_drop = nn.Dropout(proj_drop) 55 | 56 | self.sr_ratio = sr_ratio 57 | if sr_ratio > 1: 58 | self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) 59 | self.norm = nn.LayerNorm(dim) 60 | 61 | def forward(self, x, H, W): 62 | B, N, C = x.shape 63 | q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 64 | 65 | if self.sr_ratio > 1: 66 | x_ = x.permute(0, 2, 1).reshape(B, C, H, W) 67 | x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1) 68 | x_ = self.norm(x_) 69 | kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 70 | else: 71 | kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 72 | k, v = kv[0], kv[1] 73 | 74 | attn = (q @ k.transpose(-2, -1)) * self.scale 75 | attn = attn.softmax(dim=-1) 76 | attn = self.attn_drop(attn) 77 | 78 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 79 | x = self.proj(x) 80 | x = self.proj_drop(x) 81 | 82 | return x 83 | 84 | 85 | class InLineAttention(nn.Module): 86 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., 87 | sr_ratio=1, **kwargs): 88 | super().__init__() 89 | assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." 90 | 91 | self.dim = dim 92 | self.num_heads = num_heads 93 | head_dim = dim // num_heads 94 | self.scale = head_dim ** -0.5 95 | 96 | self.q = nn.Linear(dim, dim, bias=qkv_bias) 97 | self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) 98 | self.attn_drop = nn.Dropout(attn_drop) 99 | self.proj = nn.Linear(dim, dim) 100 | self.proj_drop = nn.Dropout(proj_drop) 101 | 102 | self.sr_ratio = sr_ratio 103 | if sr_ratio > 1: 104 | self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) 105 | self.norm = nn.LayerNorm(dim) 106 | 107 | self.residual = nn.Sequential( 108 | nn.Conv1d(dim, dim, kernel_size=1, groups=num_heads), 109 | nn.GELU(), 110 | nn.Conv1d(dim, dim * 9, kernel_size=1, groups=num_heads) 111 | ) 112 | 113 | def forward(self, x, H, W): 114 | b, n, c = x.shape 115 | num_heads = self.num_heads 116 | head_dim = c // num_heads 117 | q = self.q(x) 118 | 119 | if self.sr_ratio > 1: 120 | x_ = x.permute(0, 2, 1).reshape(b, c, H, W) 121 | x_ = self.sr(x_).reshape(b, c, -1).permute(0, 2, 1) 122 | x_ = self.norm(x_) 123 | kv = self.kv(x_).reshape(b, -1, 2, c).permute(2, 0, 1, 3) 124 | else: 125 | kv = self.kv(x).reshape(b, -1, 2, c).permute(2, 0, 1, 3) 126 | k, v = kv[0], kv[1] 127 | 128 | q = q.reshape(b, n, num_heads, head_dim).permute(0, 2, 1, 3) 129 | k = k.reshape(b, n, num_heads, head_dim).permute(0, 2, 1, 3) 130 | v = v.reshape(b, n, num_heads, head_dim).permute(0, 2, 1, 3) 131 | 132 | res_weight = self.residual(x.mean(dim=1).unsqueeze(dim=-1)).reshape(b * c, 1, 3, 3) 133 | 134 | # The self.scale / n = head_dim ** -0.5 / n is a scale factor used in InLine attention. 135 | # This factor can be equivalently achieved by scaling \phi(Q) = \phi(Q) * self.scale / n 136 | # Therefore, we omit it in eq. 5 of the paper for simplicity. 
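# Written out, the two lines below compute, per head, with s = self.scale / n and
# v_bar = v.mean(dim=2) the mean value vector:
#     out_i = s * sum_j (q_i . k_j) * v_j + (1 - s * sum_j (q_i . k_j)) * v_bar
# so each query's attention weights sum to 1, and computing k^T v first keeps the cost
# linear in the token number n instead of quadratic.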
137 | kv = (k.transpose(-2, -1) * (self.scale / n) ** 0.5) @ (v * (self.scale / n) ** 0.5) 138 | x = q @ kv + (1 - q @ k.mean(dim=2, keepdim=True).transpose(-2, -1) * self.scale) * v.mean(dim=2, keepdim=True) 139 | 140 | x = x.transpose(1, 2).reshape(b, n, c) 141 | v = v.transpose(1, 2).reshape(b, H, W, c).permute(0, 3, 1, 2).reshape(1, b * c, H, W) 142 | residual = F.conv2d(v, res_weight, None, padding=(1, 1), groups=b * c) 143 | x = x + residual.reshape(b, c, n).permute(0, 2, 1) 144 | 145 | x = self.proj(x) 146 | x = self.proj_drop(x) 147 | return x 148 | 149 | 150 | class Block(nn.Module): 151 | 152 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 153 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1, attn_type='I'): 154 | super().__init__() 155 | self.norm1 = norm_layer(dim) 156 | assert attn_type in ['I', 'S'] 157 | attn = InLineAttention if attn_type == 'I' else Attention 158 | self.attn = attn( 159 | dim, 160 | num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, 161 | attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio) 162 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 163 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 164 | self.norm2 = norm_layer(dim) 165 | mlp_hidden_dim = int(dim * mlp_ratio) 166 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 167 | 168 | def forward(self, x, H, W): 169 | x = x + self.drop_path(self.attn(self.norm1(x), H, W)) 170 | x = x + self.drop_path(self.mlp(self.norm2(x))) 171 | 172 | return x 173 | 174 | 175 | class PatchEmbed(nn.Module): 176 | """ Image to Patch Embedding 177 | """ 178 | 179 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): 180 | super().__init__() 181 | img_size = to_2tuple(img_size) 182 | patch_size = to_2tuple(patch_size) 183 | 184 | self.img_size = img_size 185 | self.patch_size = patch_size 186 | # assert img_size[0] % patch_size[0] == 0 and img_size[1] % patch_size[1] == 0, \ 187 | # f"img_size {img_size} should be divided by patch_size {patch_size}." 
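# Nominal patch grid for the configured image size; the positional embedding is created at this
# size and bilinearly resized in _get_pos_embed when the runtime resolution differs (which is
# why the divisibility assert above is left commented out).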
188 | self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] 189 | self.num_patches = self.H * self.W 190 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 191 | self.norm = nn.LayerNorm(embed_dim) 192 | 193 | def forward(self, x): 194 | B, C, H, W = x.shape 195 | 196 | x = self.proj(x).flatten(2).transpose(1, 2) 197 | x = self.norm(x) 198 | H, W = H // self.patch_size[0], W // self.patch_size[1] 199 | 200 | return x, (H, W) 201 | 202 | 203 | class PyramidVisionTransformer(nn.Module): 204 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512], 205 | num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0., 206 | attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, 207 | depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], la_sr_ratios='8421', num_stages=4, 208 | attn_type='IIII', **kwargs): 209 | super().__init__() 210 | self.num_classes = num_classes 211 | self.depths = depths 212 | self.num_stages = num_stages 213 | 214 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule 215 | cur = 0 216 | 217 | attn_type = 'IIII' if attn_type is None else attn_type 218 | for i in range(num_stages): 219 | patch_embed = PatchEmbed(img_size=img_size if i == 0 else img_size // (2 ** (i - 1) * patch_size), 220 | patch_size=patch_size if i == 0 else 2, 221 | in_chans=in_chans if i == 0 else embed_dims[i - 1], 222 | embed_dim=embed_dims[i]) 223 | num_patches = patch_embed.num_patches 224 | pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dims[i])) 225 | pos_drop = nn.Dropout(p=drop_rate) 226 | 227 | block = nn.ModuleList([Block( 228 | dim=embed_dims[i], num_heads=num_heads[i], mlp_ratio=mlp_ratios[i], qkv_bias=qkv_bias, 229 | qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + j], 230 | norm_layer=norm_layer, sr_ratio=sr_ratios[i] if attn_type[i] == 'S' else int(la_sr_ratios[i]), 231 | attn_type=attn_type[i]) 232 | for j in range(depths[i])]) 233 | cur += depths[i] 234 | 235 | setattr(self, f"patch_embed{i + 1}", patch_embed) 236 | setattr(self, f"pos_embed{i + 1}", pos_embed) 237 | setattr(self, f"pos_drop{i + 1}", pos_drop) 238 | setattr(self, f"block{i + 1}", block) 239 | 240 | # classification head 241 | self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity() 242 | 243 | # init weights 244 | for i in range(num_stages): 245 | pos_embed = getattr(self, f"pos_embed{i + 1}") 246 | trunc_normal_(pos_embed, std=.02) 247 | # trunc_normal_(self.cls_token, std=.02) 248 | self.apply(self._init_weights) 249 | 250 | def _init_weights(self, m): 251 | if isinstance(m, nn.Linear): 252 | trunc_normal_(m.weight, std=.02) 253 | if isinstance(m, nn.Linear) and m.bias is not None: 254 | nn.init.constant_(m.bias, 0) 255 | elif isinstance(m, nn.LayerNorm): 256 | nn.init.constant_(m.bias, 0) 257 | nn.init.constant_(m.weight, 1.0) 258 | 259 | @torch.jit.ignore 260 | def no_weight_decay(self): 261 | # return {'pos_embed', 'cls_token'} # has pos_embed may be better 262 | return {'cls_token'} 263 | 264 | def get_classifier(self): 265 | return self.head 266 | 267 | def reset_classifier(self, num_classes, global_pool=''): 268 | self.num_classes = num_classes 269 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 270 | 271 | def _get_pos_embed(self, pos_embed, patch_embed, H, W): 272 | if H * W == 
self.patch_embed1.num_patches: 273 | return pos_embed 274 | else: 275 | return F.interpolate( 276 | pos_embed.reshape(1, patch_embed.H, patch_embed.W, -1).permute(0, 3, 1, 2), 277 | size=(H, W), mode="bilinear").reshape(1, -1, H * W).permute(0, 2, 1) 278 | 279 | def forward_features(self, x): 280 | B = x.shape[0] 281 | 282 | for i in range(self.num_stages): 283 | patch_embed = getattr(self, f"patch_embed{i + 1}") 284 | pos_embed = getattr(self, f"pos_embed{i + 1}") 285 | pos_drop = getattr(self, f"pos_drop{i + 1}") 286 | block = getattr(self, f"block{i + 1}") 287 | x, (H, W) = patch_embed(x) 288 | 289 | pos_embed = self._get_pos_embed(pos_embed, patch_embed, H, W) 290 | 291 | x = pos_drop(x + pos_embed) 292 | for blk in block: 293 | x = blk(x, H, W) 294 | if i != self.num_stages - 1: 295 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 296 | 297 | return x.mean(dim=1) 298 | 299 | def forward(self, x): 300 | x = self.forward_features(x) 301 | x = self.head(x) 302 | 303 | return x 304 | 305 | 306 | def _conv_filter(state_dict, patch_size=16): 307 | """ convert patch embedding weight from manual patchify + linear proj to conv""" 308 | out_dict = {} 309 | for k, v in state_dict.items(): 310 | if 'patch_embed.proj.weight' in k: 311 | v = v.reshape((v.shape[0], 3, patch_size, patch_size)) 312 | out_dict[k] = v 313 | 314 | return out_dict 315 | 316 | 317 | def inline_pvt_tiny(pretrained=False, **kwargs): 318 | model = PyramidVisionTransformer( 319 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[2, 4, 10, 16], mlp_ratios=[8, 8, 4, 4], qkv_bias=True, 320 | norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], 321 | **kwargs) 322 | model.default_cfg = _cfg() 323 | 324 | return model 325 | 326 | 327 | def inline_pvt_small(pretrained=False, **kwargs): 328 | model = PyramidVisionTransformer( 329 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[2, 4, 10, 16], mlp_ratios=[8, 8, 4, 4], qkv_bias=True, 330 | norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], **kwargs) 331 | model.default_cfg = _cfg() 332 | 333 | return model 334 | 335 | 336 | def inline_pvt_medium(pretrained=False, **kwargs): 337 | model = PyramidVisionTransformer( 338 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[2, 4, 10, 16], mlp_ratios=[8, 8, 4, 4], qkv_bias=True, 339 | norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1], 340 | **kwargs) 341 | model.default_cfg = _cfg() 342 | 343 | return model 344 | 345 | 346 | def inline_pvt_large(pretrained=False, **kwargs): 347 | model = PyramidVisionTransformer( 348 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[2, 4, 10, 16], mlp_ratios=[8, 8, 4, 4], qkv_bias=True, 349 | norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1], 350 | **kwargs) 351 | model.default_cfg = _cfg() 352 | 353 | return model 354 | -------------------------------------------------------------------------------- /main_ema.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import argparse 4 | import datetime 5 | import numpy as np 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.backends.cudnn as cudnn 10 | import torch.distributed as dist 11 | from torch.cuda.amp import autocast, GradScaler 12 | 13 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy 14 | from timm.utils import accuracy, AverageMeter, ModelEma 15 | 16 | 
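# main_ema.py mirrors main.py but additionally tracks an exponential moving average of the
# model weights with timm's ModelEma, validates the EMA copy after each epoch, and stores it
# in checkpoints under 'state_dict_ema' (see utils_ema.save_checkpoint_ema_new).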
from config import get_config 17 | from models import build_model 18 | from data import build_loader 19 | from lr_scheduler import build_scheduler 20 | from optimizer import build_optimizer 21 | from logger import create_logger 22 | from utils_ema import load_checkpoint, save_checkpoint, get_grad_norm, auto_resume_helper, reduce_tensor, load_pretrained, save_checkpoint_ema, save_checkpoint_ema_new 23 | 24 | import warnings 25 | warnings.filterwarnings('ignore') 26 | 27 | def parse_option(): 28 | parser = argparse.ArgumentParser('InLine Attention training and evaluation script', add_help=False) 29 | parser.add_argument('--cfg', type=str, required=True, metavar="FILE", help='path to config file', ) 30 | parser.add_argument( 31 | "--opts", 32 | help="Modify config options by adding 'KEY VALUE' pairs. ", 33 | default=None, 34 | nargs='+', 35 | ) 36 | 37 | # easy config modification 38 | parser.add_argument('--batch-size', type=int, help="batch size for single GPU") 39 | parser.add_argument('--data-path', type=str, help='path to dataset') 40 | parser.add_argument('--zip', action='store_true', help='use zipped dataset instead of folder dataset') 41 | parser.add_argument('--cache-mode', type=str, default='part', choices=['no', 'full', 'part'], 42 | help='no: no cache, ' 43 | 'full: cache all data, ' 44 | 'part: sharding the dataset into nonoverlapping pieces and only cache one piece') 45 | parser.add_argument('--resume', help='resume from checkpoint') 46 | parser.add_argument('--use-checkpoint', action='store_true', 47 | help="whether to use gradient checkpointing to save memory") 48 | parser.add_argument('--amp', action='store_true', default=False) 49 | parser.add_argument('--output', default='output', type=str, metavar='PATH', 50 | help='root of output folder, the full path is // (default: output)') 51 | parser.add_argument('--tag', help='tag of experiment') 52 | parser.add_argument('--eval', action='store_true', help='Perform evaluation only') 53 | parser.add_argument('--throughput', action='store_true', help='Test throughput only') 54 | parser.add_argument('--pretrained', type=str, help='Finetune 384 initial checkpoint.', default='') 55 | parser.add_argument('--find-unused-params', action='store_true', default=False) 56 | 57 | # Model Exponential Moving Average 58 | parser.add_argument('--model-ema', action='store_true', default=True, 59 | help='Enable tracking moving average of model weights') 60 | parser.add_argument('--model-ema-force-cpu', action='store_true', default=False, 61 | help='Force ema to be tracked on CPU, rank=0 node only. 
Disables EMA validation.') 62 | parser.add_argument('--model-ema-decay', type=float, default=0.9996, 63 | help='decay factor for model weights moving average (default: 0.9996)') 64 | 65 | args, unparsed = parser.parse_known_args() 66 | 67 | config = get_config(args) 68 | 69 | return args, config 70 | 71 | 72 | def main(): 73 | os.environ["NCCL_BLOCKING_WAIT"] = "1" 74 | args, config = parse_option() 75 | 76 | rank = int(os.environ["RANK"]) 77 | world_size = int(os.environ['WORLD_SIZE']) 78 | local_rank = int(os.environ['LOCAL_RANK']) 79 | torch.cuda.set_device(local_rank) 80 | dist.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank) 81 | 82 | seed = config.SEED + dist.get_rank() 83 | torch.manual_seed(seed) 84 | np.random.seed(seed) 85 | cudnn.enabled = True 86 | cudnn.benchmark = True 87 | 88 | # linear scale the learning rate according to total batch size, may not be optimal 89 | linear_scaled_lr = config.TRAIN.BASE_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0 90 | linear_scaled_warmup_lr = config.TRAIN.WARMUP_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0 91 | linear_scaled_min_lr = config.TRAIN.MIN_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0 92 | 93 | config.defrost() 94 | config.TRAIN.BASE_LR = linear_scaled_lr 95 | config.TRAIN.WARMUP_LR = linear_scaled_warmup_lr 96 | config.TRAIN.MIN_LR = linear_scaled_min_lr 97 | config.LOCAL_RANK = local_rank 98 | config.freeze() 99 | 100 | # adjust ema decay according to total batch size, may not be optimal 101 | args.model_ema_decay = args.model_ema_decay ** (config.DATA.BATCH_SIZE * dist.get_world_size() / 4096.0) 102 | 103 | os.makedirs(config.OUTPUT, exist_ok=True) 104 | logger = create_logger(output_dir=config.OUTPUT, dist_rank=dist.get_rank(), name=f"{config.MODEL.NAME}") 105 | 106 | if dist.get_rank() == 0: 107 | path = os.path.join(config.OUTPUT, "config.json") 108 | with open(path, "w") as f: 109 | f.write(config.dump()) 110 | logger.info(f"Full config saved to {path}") 111 | 112 | # print config 113 | logger.info(config.dump()) 114 | 115 | _, dataset_val, data_loader_train, data_loader_val, mixup_fn = build_loader(config) 116 | 117 | logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}") 118 | model = build_model(config) 119 | model.cuda() 120 | logger.info(str(model)) 121 | 122 | optimizer = build_optimizer(config, model) 123 | 124 | model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], broadcast_buffers=True, find_unused_parameters=args.find_unused_params) 125 | model_without_ddp = model.module 126 | 127 | lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train)) 128 | total_epochs = config.TRAIN.EPOCHS + config.TRAIN.COOLDOWN_EPOCHS 129 | 130 | if config.AUG.MIXUP > 0.: 131 | # smoothing is handled with mixup label transform 132 | criterion = SoftTargetCrossEntropy() 133 | elif config.MODEL.LABEL_SMOOTHING > 0.: 134 | criterion = LabelSmoothingCrossEntropy(smoothing=config.MODEL.LABEL_SMOOTHING) 135 | else: 136 | criterion = nn.CrossEntropyLoss() 137 | 138 | max_accuracy = 0.0 139 | max_accuracy_e = 0.0 140 | 141 | if args.pretrained != '': 142 | load_pretrained(args.pretrained, model_without_ddp, logger) 143 | 144 | if config.TRAIN.AUTO_RESUME: 145 | resume_file = auto_resume_helper(config.OUTPUT) 146 | if resume_file: 147 | if config.MODEL.RESUME: 148 | logger.warning(f"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}") 149 | config.defrost() 150 | config.MODEL.RESUME = 
resume_file 151 | config.freeze() 152 | logger.info(f'auto resuming from {resume_file}') 153 | else: 154 | logger.info(f'no checkpoint found in {config.OUTPUT}, ignoring auto resume') 155 | 156 | if config.MODEL.RESUME: 157 | max_accuracy, max_accuracy_e = load_checkpoint(config, model_without_ddp, optimizer, lr_scheduler, logger) 158 | acc1, acc5, loss = validate(config, data_loader_val, model, logger) 159 | max_accuracy = max(max_accuracy, acc1) 160 | torch.cuda.empty_cache() 161 | logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%") 162 | if config.EVAL_MODE and not args.model_ema: 163 | return 164 | 165 | model_ema = None 166 | if args.model_ema: 167 | if not config.EVAL_MODE: 168 | logger.info(f'Model EMA decay {args.model_ema_decay}') 169 | # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper 170 | model_ema = ModelEma( 171 | model, 172 | decay=args.model_ema_decay, 173 | device='cpu' if args.model_ema_force_cpu else '', 174 | resume=config.MODEL.RESUME) 175 | acc1_e, acc5_e, loss_e = validate(config, data_loader_val, model_ema.ema, logger) 176 | torch.cuda.empty_cache() 177 | logger.info(f"Accuracy of the ema network on the {len(dataset_val)} test images: {acc1_e:.1f}%") 178 | if config.EVAL_MODE: 179 | return 180 | 181 | if config.THROUGHPUT_MODE: 182 | throughput(data_loader_val, model, logger) 183 | return 184 | 185 | logger.info("Start training") 186 | start_time = time.time() 187 | for epoch in range(config.TRAIN.START_EPOCH, total_epochs): 188 | data_loader_train.sampler.set_epoch(epoch) 189 | 190 | train_one_epoch(config, model, model_ema, criterion, data_loader_train, optimizer, epoch, mixup_fn, lr_scheduler, logger, total_epochs) 191 | acc1, acc5, loss = validate(config, data_loader_val, model, logger) 192 | logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%") 193 | if model_ema is not None and not args.model_ema_force_cpu: 194 | acc1_e, acc5_e, loss_e = validate(config, data_loader_val, model_ema.ema, logger) 195 | logger.info(f"Accuracy of the ema network on the {len(dataset_val)} test images: {acc1_e:.1f}%") 196 | else: 197 | acc1_e, acc5_e, loss_e = 0, 0, 0 198 | 199 | if dist.get_rank() == 0 and ((epoch + 1) % config.SAVE_FREQ == 0 or (epoch + 1) == (total_epochs)): 200 | save_checkpoint_ema_new(config, epoch + 1, model_without_ddp, model_ema, max(max_accuracy, acc1), max(max_accuracy_e, acc1_e), optimizer, lr_scheduler, logger) 201 | 202 | if dist.get_rank() == 0 and ((epoch + 1) % config.SAVE_FREQ == 0 or (epoch + 1) == (total_epochs)) and acc1 >= max_accuracy: 203 | save_checkpoint_ema_new(config, epoch + 1, model_without_ddp, model_ema, max(max_accuracy, acc1), max(max_accuracy_e, acc1_e), optimizer, lr_scheduler, logger, name='max_acc') 204 | max_accuracy = max(max_accuracy, acc1) 205 | logger.info(f'Max accuracy: {max_accuracy:.2f}%') 206 | 207 | if model_ema is not None and not args.model_ema_force_cpu: 208 | if dist.get_rank() == 0 and ((epoch + 1) % config.SAVE_FREQ == 0 or (epoch + 1) == (total_epochs)) and acc1_e >= max_accuracy_e: 209 | save_checkpoint_ema_new(config, epoch + 1, model_without_ddp, model_ema, max(max_accuracy, acc1), max(max_accuracy_e, acc1_e), optimizer, lr_scheduler, logger, name='max_ema_acc') 210 | max_accuracy_e = max(max_accuracy_e, acc1_e) 211 | logger.info(f'Max ema accuracy: {max_accuracy_e:.2f}%') 212 | 213 | total_time = time.time() - start_time 214 | total_time_str = 
str(datetime.timedelta(seconds=int(total_time))) 215 | logger.info('Training time {}'.format(total_time_str)) 216 | 217 | 218 | def train_one_epoch(config, model, model_ema, criterion, data_loader, optimizer, epoch, mixup_fn, lr_scheduler, logger, total_epochs): 219 | model.train() 220 | optimizer.zero_grad() 221 | 222 | num_steps = len(data_loader) 223 | batch_time = AverageMeter() 224 | loss_meter = AverageMeter() 225 | norm_meter = AverageMeter() 226 | 227 | start = time.time() 228 | end = time.time() 229 | 230 | scaler = GradScaler() 231 | 232 | for idx, (samples, targets) in enumerate(data_loader): 233 | 234 | optimizer.zero_grad() 235 | samples = samples.cuda(non_blocking=True) 236 | targets = targets.cuda(non_blocking=True) 237 | 238 | if mixup_fn is not None: 239 | samples, targets = mixup_fn(samples, targets) 240 | 241 | if config.AMP: 242 | with autocast(): 243 | outputs = model(samples) 244 | loss = criterion(outputs, targets) 245 | scaler.scale(loss).backward() 246 | if config.TRAIN.CLIP_GRAD: 247 | scaler.unscale_(optimizer) 248 | grad_norm = nn.utils.clip_grad_norm_(model.parameters(), config.TRAIN.CLIP_GRAD) 249 | scaler.step(optimizer) 250 | scaler.update() 251 | else: 252 | grad_norm = get_grad_norm(model.parameters()) 253 | scaler.step(optimizer) 254 | scaler.update() 255 | else: 256 | outputs = model(samples) 257 | loss = criterion(outputs, targets) 258 | loss.backward() 259 | if config.TRAIN.CLIP_GRAD: 260 | grad_norm = nn.utils.clip_grad_norm_(model.parameters(), config.TRAIN.CLIP_GRAD) 261 | else: 262 | grad_norm = get_grad_norm(model.parameters()) 263 | optimizer.step() 264 | 265 | lr_scheduler.step_update(epoch * num_steps + idx) 266 | 267 | torch.cuda.synchronize() 268 | if model_ema is not None: 269 | model_ema.update(model) 270 | 271 | loss_meter.update(loss.item(), targets.size(0)) 272 | norm_meter.update(grad_norm) 273 | batch_time.update(time.time() - end) 274 | end = time.time() 275 | 276 | if (idx + 1) % config.PRINT_FREQ == 0: 277 | lr = optimizer.param_groups[0]['lr'] 278 | memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0) 279 | etas = batch_time.avg * (num_steps - idx) 280 | logger.info( 281 | f'Train: [{epoch + 1}/{total_epochs}][{idx + 1}/{num_steps}]\t' 282 | f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t' 283 | f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t' 284 | f'loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t' 285 | f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\t' 286 | f'mem {memory_used:.0f}MB') 287 | epoch_time = time.time() - start 288 | logger.info(f"EPOCH {epoch + 1} training takes {datetime.timedelta(seconds=int(epoch_time))}") 289 | 290 | 291 | @torch.no_grad() 292 | def validate(config, data_loader, model, logger): 293 | criterion = nn.CrossEntropyLoss() 294 | model.eval() 295 | 296 | batch_time = AverageMeter() 297 | loss_meter = AverageMeter() 298 | acc1_meter = AverageMeter() 299 | acc5_meter = AverageMeter() 300 | 301 | end = time.time() 302 | for idx, (images, target) in enumerate(data_loader): 303 | images = images.cuda(non_blocking=True) 304 | target = target.cuda(non_blocking=True) 305 | 306 | # compute output 307 | output = model(images) 308 | 309 | # measure accuracy and record loss 310 | loss = criterion(output, target) 311 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 312 | 313 | acc1 = reduce_tensor(acc1) 314 | acc5 = reduce_tensor(acc5) 315 | loss = reduce_tensor(loss) 316 | 317 | loss_meter.update(loss.item(), target.size(0)) 318 | acc1_meter.update(acc1.item(), 
target.size(0)) 319 | acc5_meter.update(acc5.item(), target.size(0)) 320 | 321 | # measure elapsed time 322 | batch_time.update(time.time() - end) 323 | end = time.time() 324 | 325 | if (idx + 1) % config.PRINT_FREQ == 0: 326 | memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0) 327 | logger.info( 328 | f'Test: [{(idx + 1)}/{len(data_loader)}]\t' 329 | f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 330 | f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t' 331 | f'Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\t' 332 | f'Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})\t' 333 | f'Mem {memory_used:.0f}MB') 334 | logger.info(f' * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}') 335 | return acc1_meter.avg, acc5_meter.avg, loss_meter.avg 336 | 337 | 338 | @torch.no_grad() 339 | def throughput(data_loader, model, logger): 340 | model.eval() 341 | 342 | for _, (images, _) in enumerate(data_loader): 343 | images = images.cuda(non_blocking=True) 344 | batch_size = images.shape[0] 345 | for i in range(50): 346 | model(images) 347 | torch.cuda.synchronize() 348 | logger.info(f"throughput averaged with 30 times") 349 | tic1 = time.time() 350 | for i in range(30): 351 | model(images) 352 | torch.cuda.synchronize() 353 | tic2 = time.time() 354 | logger.info(f"batch_size {batch_size} throughput {30 * batch_size / (tic2 - tic1)}") 355 | return 356 | 357 | 358 | if __name__ == '__main__': 359 | 360 | main() 361 | -------------------------------------------------------------------------------- /models/inline_cswin.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------ 2 | # CSWin Transformer 3 | # Copyright (c) Microsoft Corporation. 4 | # Licensed under the MIT License. 
5 | # written By Xiaoyi Dong 6 | # ------------------------------------------ 7 | # Bridging the Divide: Reconsidering Softmax and Linear Attention 8 | # Modified by Dongchen Han 9 | # ----------------------------------------------------------------------- 10 | 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | from functools import partial 16 | 17 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 18 | from timm.models.helpers import load_pretrained 19 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 20 | from timm.models.registry import register_model 21 | from einops.layers.torch import Rearrange 22 | import torch.utils.checkpoint as checkpoint 23 | import numpy as np 24 | import time 25 | 26 | 27 | def _cfg(url='', **kwargs): 28 | return { 29 | 'url': url, 30 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 31 | 'crop_pct': .9, 'interpolation': 'bicubic', 32 | 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 33 | 'first_conv': 'patch_embed.proj', 'classifier': 'head', 34 | **kwargs 35 | } 36 | 37 | 38 | default_cfgs = { 39 | 'cswin_224': _cfg(), 40 | 'cswin_384': _cfg( 41 | crop_pct=1.0 42 | ), 43 | 44 | } 45 | 46 | 47 | class Mlp(nn.Module): 48 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 49 | super().__init__() 50 | out_features = out_features or in_features 51 | hidden_features = hidden_features or in_features 52 | self.fc1 = nn.Linear(in_features, hidden_features) 53 | self.act = act_layer() 54 | self.fc2 = nn.Linear(hidden_features, out_features) 55 | self.drop = nn.Dropout(drop) 56 | 57 | def forward(self, x): 58 | x = self.fc1(x) 59 | x = self.act(x) 60 | x = self.drop(x) 61 | x = self.fc2(x) 62 | x = self.drop(x) 63 | return x 64 | 65 | 66 | class LePEAttention(nn.Module): 67 | def __init__(self, dim, resolution, idx, split_size=7, dim_out=None, num_heads=8, attn_drop=0., proj_drop=0., 68 | qk_scale=None): 69 | super().__init__() 70 | self.dim = dim 71 | self.dim_out = dim_out or dim 72 | self.resolution = resolution 73 | self.split_size = split_size 74 | self.num_heads = num_heads 75 | head_dim = dim // num_heads 76 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 77 | self.scale = qk_scale or head_dim ** -0.5 78 | if idx == -1: 79 | H_sp, W_sp = self.resolution, self.resolution 80 | elif idx == 0: 81 | H_sp, W_sp = self.resolution, self.split_size 82 | elif idx == 1: 83 | W_sp, H_sp = self.resolution, self.split_size 84 | else: 85 | print("ERROR MODE", idx) 86 | exit(0) 87 | self.H_sp = H_sp 88 | self.W_sp = W_sp 89 | stride = 1 90 | self.get_v = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim) 91 | 92 | self.attn_drop = nn.Dropout(attn_drop) 93 | 94 | def im2cswin(self, x): 95 | B, N, C = x.shape 96 | H = W = int(np.sqrt(N)) 97 | x = x.transpose(-2, -1).contiguous().view(B, C, H, W) 98 | x = img2windows(x, self.H_sp, self.W_sp) 99 | x = x.reshape(-1, self.H_sp * self.W_sp, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3).contiguous() 100 | return x 101 | 102 | def get_lepe(self, x, func): 103 | B, N, C = x.shape 104 | H = W = int(np.sqrt(N)) 105 | x = x.transpose(-2, -1).contiguous().view(B, C, H, W) 106 | 107 | H_sp, W_sp = self.H_sp, self.W_sp 108 | x = x.view(B, C, H // H_sp, H_sp, W // W_sp, W_sp) 109 | x = x.permute(0, 2, 4, 1, 3, 5).contiguous().reshape(-1, C, H_sp, W_sp) ### B', C, H', W' 110 | 111 | lepe = func(x) ### 
B', C, H', W' 112 | lepe = lepe.reshape(-1, self.num_heads, C // self.num_heads, H_sp * W_sp).permute(0, 1, 3, 2).contiguous() 113 | 114 | x = x.reshape(-1, self.num_heads, C // self.num_heads, self.H_sp * self.W_sp).permute(0, 1, 3, 2).contiguous() 115 | return x, lepe 116 | 117 | def forward(self, qkv, x_mean): 118 | """ 119 | x: B L C 120 | """ 121 | q, k, v = qkv[0], qkv[1], qkv[2] 122 | 123 | ### Img2Window 124 | H = W = self.resolution 125 | B, L, C = q.shape 126 | assert L == H * W, "flatten img_tokens has wrong size" 127 | 128 | q = self.im2cswin(q) 129 | k = self.im2cswin(k) 130 | v, lepe = self.get_lepe(v, self.get_v) 131 | 132 | q = q * self.scale 133 | attn = (q @ k.transpose(-2, -1)) # B head N C @ B head C N --> B head N N 134 | attn = nn.functional.softmax(attn, dim=-1, dtype=attn.dtype) 135 | attn = self.attn_drop(attn) 136 | 137 | x = (attn @ v) + lepe 138 | x = x.transpose(1, 2).reshape(-1, self.H_sp * self.W_sp, C) # B head N N @ B head N C 139 | 140 | ### Window2Img 141 | x = windows2img(x, self.H_sp, self.W_sp, H, W).view(B, -1, C) # B H' W' C 142 | 143 | return x 144 | 145 | 146 | class InLineAttention(nn.Module): 147 | def __init__(self, dim, resolution, idx, split_size=7, dim_out=None, num_heads=8, attn_drop=0., proj_drop=0., **kwargs): 148 | super().__init__() 149 | self.dim = dim 150 | self.dim_out = dim_out or dim 151 | self.resolution = resolution 152 | self.split_size = split_size 153 | self.num_heads = num_heads 154 | head_dim = dim // num_heads 155 | self.scale = head_dim ** -0.5 156 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 157 | # self.scale = qk_scale or head_dim ** -0.5 158 | if idx == -1: 159 | H_sp, W_sp = self.resolution, self.resolution 160 | elif idx == 0: 161 | H_sp, W_sp = self.resolution, self.split_size 162 | elif idx == 1: 163 | W_sp, H_sp = self.resolution, self.split_size 164 | else: 165 | print("ERROR MODE", idx) 166 | exit(0) 167 | self.H_sp = H_sp 168 | self.W_sp = W_sp 169 | self.get_v = nn.Conv2d(dim, dim, kernel_size=(3, 3), stride=(1, 1), padding=1, groups=dim) 170 | 171 | self.attn_drop = nn.Dropout(attn_drop) 172 | 173 | self.residual = nn.Sequential( 174 | nn.Conv1d(dim, dim, kernel_size=1, groups=num_heads), 175 | nn.GELU(), 176 | nn.Conv1d(dim, dim * 9, kernel_size=1, groups=num_heads) 177 | ) 178 | 179 | def im2cswin(self, x): 180 | B, N, C = x.shape 181 | H = W = int(np.sqrt(N)) 182 | x = x.transpose(-2, -1).contiguous().view(B, C, H, W) 183 | x = img2windows(x, self.H_sp, self.W_sp) 184 | # x = x.reshape(-1, self.H_sp * self.W_sp, C).contiguous() 185 | return x 186 | 187 | def get_lepe(self, x, func): 188 | B, N, C = x.shape 189 | H = W = int(np.sqrt(N)) 190 | x = x.transpose(-2, -1).contiguous().view(B, C, H, W) 191 | 192 | H_sp, W_sp = self.H_sp, self.W_sp 193 | x = x.view(B, C, H // H_sp, H_sp, W // W_sp, W_sp) 194 | x = x.permute(0, 2, 4, 1, 3, 5).contiguous().reshape(-1, C, H_sp, W_sp) ### B', C, H', W' 195 | 196 | lepe = func(x) ### B', C, H', W' 197 | lepe = lepe.reshape(-1, C, H_sp * W_sp).permute(0, 2, 1).contiguous() 198 | 199 | x = x.reshape(-1, C, self.H_sp * self.W_sp).permute(0, 2, 1).contiguous() 200 | return x, lepe 201 | 202 | def forward(self, qkv, x_mean): 203 | """ 204 | x: B L C 205 | """ 206 | q, k, v = qkv[0], qkv[1], qkv[2] 207 | 208 | ### Img2Window 209 | H = W = self.resolution 210 | B, L, C = q.shape 211 | assert L == H * W, "flatten img_tokens has wrong size" 212 | 213 | q = self.im2cswin(q) 214 | k = self.im2cswin(k) 215 | v, lepe = 
self.get_lepe(v, self.get_v) 216 | # q, k, v = (rearrange(x, "b h n c -> b n (h c)", h=self.num_heads) for x in [q, k, v]) 217 | 218 | b, n, c = q.shape 219 | h, w = self.H_sp, self.W_sp 220 | num_heads, head_dim = self.num_heads, self.dim // self.num_heads 221 | 222 | q = q.reshape(b, n, num_heads, head_dim).permute(0, 2, 1, 3) 223 | k = k.reshape(b, n, num_heads, head_dim).permute(0, 2, 1, 3) 224 | v = v.reshape(b, n, num_heads, head_dim).permute(0, 2, 1, 3) 225 | 226 | res_weight = self.residual(x_mean.unsqueeze(dim=-1)).reshape(b * c, 1, 3, 3) 227 | 228 | # The self.scale / n = head_dim ** -0.5 / n is a scale factor used in InLine attention. 229 | # This factor can be equivalently achieved by scaling \phi(Q) = \phi(Q) * self.scale / n 230 | # Therefore, we omit it in eq. 5 of the paper for simplicity. 231 | kv = (k.transpose(-2, -1) * (self.scale / n) ** 0.5) @ (v * (self.scale / n) ** 0.5) 232 | x = q @ kv + (1 - q @ k.mean(dim=2, keepdim=True).transpose(-2, -1) * self.scale) * v.mean(dim=2, keepdim=True) 233 | 234 | x = x.transpose(1, 2).reshape(b, n, c) 235 | v = v.transpose(1, 2).reshape(b, h, w, c).permute(0, 3, 1, 2).reshape(1, b * c, h, w) 236 | residual = F.conv2d(v, res_weight, None, padding=(1, 1), groups=b * c) 237 | x = x + residual.reshape(b, c, n).permute(0, 2, 1) 238 | 239 | x = x + lepe 240 | x = windows2img(x, self.H_sp, self.W_sp, H, W).view(B, -1, C) 241 | 242 | return x 243 | 244 | 245 | class CSWinBlock(nn.Module): 246 | def __init__(self, dim, reso, num_heads, 247 | split_size=7, mlp_ratio=4., qkv_bias=False, qk_scale=None, 248 | drop=0., attn_drop=0., drop_path=0., 249 | act_layer=nn.GELU, norm_layer=nn.LayerNorm, 250 | last_stage=False, attn_type='I'): 251 | super().__init__() 252 | self.dim = dim 253 | self.num_heads = num_heads 254 | self.patches_resolution = reso 255 | self.split_size = split_size 256 | self.mlp_ratio = mlp_ratio 257 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 258 | self.norm1 = norm_layer(dim) 259 | 260 | if self.patches_resolution == split_size: 261 | last_stage = True 262 | if last_stage: 263 | self.branch_num = 1 264 | else: 265 | self.branch_num = 2 266 | self.proj = nn.Linear(dim, dim) 267 | self.proj_drop = nn.Dropout(drop) 268 | 269 | assert attn_type in ['I', 'S'] 270 | attn = InLineAttention if attn_type == 'I' else LePEAttention 271 | if last_stage: 272 | self.attns = nn.ModuleList([ 273 | attn( 274 | dim, resolution=self.patches_resolution, idx=-1, 275 | split_size=split_size, num_heads=num_heads, dim_out=dim, 276 | qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 277 | for i in range(self.branch_num)]) 278 | else: 279 | self.attns = nn.ModuleList([ 280 | attn( 281 | dim // 2, resolution=self.patches_resolution, idx=i, 282 | split_size=split_size, num_heads=num_heads // 2, dim_out=dim // 2, 283 | qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 284 | for i in range(self.branch_num)]) 285 | 286 | mlp_hidden_dim = int(dim * mlp_ratio) 287 | 288 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 289 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, out_features=dim, act_layer=act_layer, 290 | drop=drop) 291 | self.norm2 = norm_layer(dim) 292 | 293 | def forward(self, x): 294 | """ 295 | x: B, H*W, C 296 | """ 297 | 298 | H = W = self.patches_resolution 299 | B, L, C = x.shape 300 | assert L == H * W, "flatten img_tokens has wrong size" 301 | img = self.norm1(x) 302 | qkv = self.qkv(img).reshape(B, -1, 3, C).permute(2, 0, 1, 3) 303 | 304 | if self.branch_num == 2: 305 | x1 = self.attns[0](qkv[:, :, :, :C // 2], x.mean(dim=1)[:, :C // 2]) 306 | x2 = self.attns[1](qkv[:, :, :, C // 2:], x.mean(dim=1)[:, C // 2:]) 307 | attened_x = torch.cat([x1, x2], dim=2) 308 | else: 309 | attened_x = self.attns[0](qkv, x.mean(dim=1)) 310 | attened_x = self.proj(attened_x) 311 | x = x + self.drop_path(attened_x) 312 | x = x + self.drop_path(self.mlp(self.norm2(x))) 313 | 314 | return x 315 | 316 | 317 | def img2windows(img, H_sp, W_sp): 318 | """ 319 | img: B C H W 320 | """ 321 | B, C, H, W = img.shape 322 | img_reshape = img.view(B, C, H // H_sp, H_sp, W // W_sp, W_sp) 323 | img_perm = img_reshape.permute(0, 2, 4, 3, 5, 1).contiguous().reshape(-1, H_sp * W_sp, C) 324 | return img_perm 325 | 326 | 327 | def windows2img(img_splits_hw, H_sp, W_sp, H, W): 328 | """ 329 | img_splits_hw: B' H W C 330 | """ 331 | B = int(img_splits_hw.shape[0] / (H * W / H_sp / W_sp)) 332 | 333 | img = img_splits_hw.view(B, H // H_sp, W // W_sp, H_sp, W_sp, -1) 334 | img = img.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) 335 | return img 336 | 337 | 338 | class Merge_Block(nn.Module): 339 | def __init__(self, dim, dim_out, norm_layer=nn.LayerNorm): 340 | super().__init__() 341 | self.conv = nn.Conv2d(dim, dim_out, 3, 2, 1) 342 | self.norm = norm_layer(dim_out) 343 | 344 | def forward(self, x): 345 | B, new_HW, C = x.shape 346 | H = W = int(np.sqrt(new_HW)) 347 | x = x.transpose(-2, -1).contiguous().view(B, C, H, W) 348 | x = self.conv(x) 349 | B, C = x.shape[:2] 350 | x = x.view(B, C, -1).transpose(-2, -1).contiguous() 351 | x = self.norm(x) 352 | 353 | return x 354 | 355 | 356 | class CSWinTransformer(nn.Module): 357 | """ Vision Transformer with support for patch or hybrid CNN input stage 358 | """ 359 | 360 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=96, depth=[2, 2, 6, 2], 361 | split_size=[1, 2, 7, 7], la_split_size='1-2-7-7', 362 | num_heads=[2, 4, 8, 16], mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0., 363 | drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm, use_chk=False, 364 | attn_type='IIII'): 365 | super().__init__() 366 | 367 | # split_size = [1, 2, img_size // 32, img_size // 32] 368 | la_split_size = la_split_size.split('-') 369 | 370 | self.use_chk = use_chk 371 | self.num_classes = num_classes 372 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 373 | heads = num_heads 374 | 375 | self.stage1_conv_embed = nn.Sequential( 376 | nn.Conv2d(in_chans, embed_dim, 7, 4, 2), 377 | Rearrange('b c h w -> b (h w) c', h=img_size // 4, w=img_size // 4), 378 | nn.LayerNorm(embed_dim) 379 | ) 380 | 381 | curr_dim = embed_dim 382 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, np.sum(depth))] # stochastic depth decay rule 383 | attn_types = [(attn_type[0] if attn_type[0] != 'M' else ('I' if i < int(attn_type[4:]) else 'S')) for i in range(depth[0])] 384 | split_sizes = [(int(la_split_size[0]) if attn_types[i] == 'I' else 
split_size[0]) for i in range(depth[0])] 385 | self.stage1 = nn.ModuleList([ 386 | CSWinBlock( 387 | dim=curr_dim, num_heads=heads[0], reso=img_size // 4, mlp_ratio=mlp_ratio, 388 | qkv_bias=qkv_bias, qk_scale=qk_scale, 389 | split_size=split_sizes[i], 390 | drop=drop_rate, attn_drop=attn_drop_rate, 391 | drop_path=dpr[i], norm_layer=norm_layer, 392 | attn_type=attn_types[i]) 393 | for i in range(depth[0])]) 394 | 395 | self.merge1 = Merge_Block(curr_dim, curr_dim * 2) 396 | curr_dim = curr_dim * 2 397 | attn_types = [(attn_type[1] if attn_type[1] != 'M' else ('I' if i < int(attn_type[4:]) else 'S')) for i in range(depth[1])] 398 | split_sizes = [(int(la_split_size[1]) if attn_types[i] == 'I' else split_size[1]) for i in range(depth[1])] 399 | self.stage2 = nn.ModuleList( 400 | [CSWinBlock( 401 | dim=curr_dim, num_heads=heads[1], reso=img_size // 8, mlp_ratio=mlp_ratio, 402 | qkv_bias=qkv_bias, qk_scale=qk_scale, 403 | split_size=split_sizes[i], 404 | drop=drop_rate, attn_drop=attn_drop_rate, 405 | drop_path=dpr[np.sum(depth[:1]) + i], norm_layer=norm_layer, 406 | attn_type=attn_types[i]) 407 | for i in range(depth[1])]) 408 | 409 | self.merge2 = Merge_Block(curr_dim, curr_dim * 2) 410 | curr_dim = curr_dim * 2 411 | attn_types = [(attn_type[2] if attn_type[2] != 'M' else ('I' if i < int(attn_type[4:]) else 'S')) for i in range(depth[2])] 412 | split_sizes = [(int(la_split_size[2]) if attn_types[i] == 'I' else split_size[2]) for i in range(depth[2])] 413 | temp_stage3 = [] 414 | temp_stage3.extend( 415 | [CSWinBlock( 416 | dim=curr_dim, num_heads=heads[2], reso=img_size // 16, mlp_ratio=mlp_ratio, 417 | qkv_bias=qkv_bias, qk_scale=qk_scale, 418 | split_size=split_sizes[i], 419 | drop=drop_rate, attn_drop=attn_drop_rate, 420 | drop_path=dpr[np.sum(depth[:2]) + i], norm_layer=norm_layer, 421 | attn_type=attn_types[i]) 422 | for i in range(depth[2])]) 423 | 424 | self.stage3 = nn.ModuleList(temp_stage3) 425 | 426 | self.merge3 = Merge_Block(curr_dim, curr_dim * 2) 427 | curr_dim = curr_dim * 2 428 | attn_types = [(attn_type[3] if attn_type[3] != 'M' else ('I' if i < int(attn_type[4:]) else 'S')) for i in range(depth[3])] 429 | split_sizes = [(int(la_split_size[3]) if attn_types[i] == 'I' else split_size[3]) for i in range(depth[3])] 430 | self.stage4 = nn.ModuleList( 431 | [CSWinBlock( 432 | dim=curr_dim, num_heads=heads[3], reso=img_size // 32, mlp_ratio=mlp_ratio, 433 | qkv_bias=qkv_bias, qk_scale=qk_scale, 434 | split_size=split_sizes[i], 435 | drop=drop_rate, attn_drop=attn_drop_rate, 436 | drop_path=dpr[np.sum(depth[:-1]) + i], norm_layer=norm_layer, last_stage=True, 437 | attn_type=attn_types[i]) 438 | for i in range(depth[-1])]) 439 | 440 | self.norm = norm_layer(curr_dim) 441 | # Classifier head 442 | self.head = nn.Linear(curr_dim, num_classes) if num_classes > 0 else nn.Identity() 443 | 444 | trunc_normal_(self.head.weight, std=0.02) 445 | self.apply(self._init_weights) 446 | 447 | def _init_weights(self, m): 448 | if isinstance(m, nn.Linear): 449 | trunc_normal_(m.weight, std=.02) 450 | if isinstance(m, nn.Linear) and m.bias is not None: 451 | nn.init.constant_(m.bias, 0) 452 | elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)): 453 | nn.init.constant_(m.bias, 0) 454 | nn.init.constant_(m.weight, 1.0) 455 | 456 | @torch.jit.ignore 457 | def no_weight_decay(self): 458 | return {'pos_embed', 'cls_token'} 459 | 460 | def get_classifier(self): 461 | return self.head 462 | 463 | def reset_classifier(self, num_classes, global_pool=''): 464 | if self.num_classes != num_classes: 465 | 
print('reset head to', num_classes) 466 | self.num_classes = num_classes 467 | self.head = nn.Linear(self.out_dim, num_classes) if num_classes > 0 else nn.Identity() 468 | self.head = self.head.cuda() 469 | trunc_normal_(self.head.weight, std=.02) 470 | if self.head.bias is not None: 471 | nn.init.constant_(self.head.bias, 0) 472 | 473 | def forward_features(self, x): 474 | B = x.shape[0] 475 | x = self.stage1_conv_embed(x) 476 | for blk in self.stage1: 477 | if self.use_chk: 478 | x = checkpoint.checkpoint(blk, x) 479 | else: 480 | x = blk(x) 481 | for pre, blocks in zip([self.merge1, self.merge2, self.merge3], 482 | [self.stage2, self.stage3, self.stage4]): 483 | x = pre(x) 484 | for blk in blocks: 485 | if self.use_chk: 486 | x = checkpoint.checkpoint(blk, x) 487 | else: 488 | x = blk(x) 489 | x = self.norm(x) 490 | return torch.mean(x, dim=1) 491 | 492 | def forward(self, x): 493 | x = self.forward_features(x) 494 | x = self.head(x) 495 | return x 496 | 497 | 498 | def _conv_filter(state_dict, patch_size=16): 499 | """ convert patch embedding weight from manual patchify + linear proj to conv""" 500 | out_dict = {} 501 | for k, v in state_dict.items(): 502 | if 'patch_embed.proj.weight' in k: 503 | v = v.reshape((v.shape[0], 3, patch_size, patch_size)) 504 | out_dict[k] = v 505 | return out_dict 506 | 507 | 508 | ### 224 models 509 | 510 | def inline_cswin_tiny(pretrained=False, **kwargs): 511 | model = CSWinTransformer(patch_size=4, embed_dim=64, depth=[2, 4, 18, 1], 512 | split_size=[1, 2, 7, 7], num_heads=[2, 4, 8, 16], mlp_ratio=4., **kwargs) 513 | model.default_cfg = default_cfgs['cswin_224'] 514 | return model 515 | 516 | 517 | def inline_cswin_small(pretrained=False, **kwargs): 518 | model = CSWinTransformer(patch_size=4, embed_dim=64, depth=[3, 6, 29, 2], 519 | split_size=[1, 2, 7, 7], num_heads=[2, 4, 8, 16], mlp_ratio=4., **kwargs) 520 | model.default_cfg = default_cfgs['cswin_224'] 521 | return model 522 | 523 | 524 | def inline_cswin_base(pretrained=False, **kwargs): 525 | model = CSWinTransformer(patch_size=4, embed_dim=96, depth=[3, 6, 29, 2], 526 | split_size=[1, 2, 7, 7], num_heads=[4, 8, 16, 32], mlp_ratio=4., **kwargs) 527 | model.default_cfg = default_cfgs['cswin_224'] 528 | return model 529 | 530 | 531 | ### 384 models 532 | 533 | 534 | def inline_cswin_base_384(pretrained=False, **kwargs): 535 | model = CSWinTransformer(patch_size=4, embed_dim=96, depth=[3, 6, 29, 2], 536 | split_size=[1, 2, 12, 12], num_heads=[4, 8, 16, 32], mlp_ratio=4., **kwargs) 537 | model.default_cfg = default_cfgs['cswin_384'] 538 | return model -------------------------------------------------------------------------------- /models/inline_deit.py: -------------------------------------------------------------------------------- 1 | """ Vision Transformer (ViT) in PyTorch 2 | 3 | A PyTorch implement of Vision Transformers as described in: 4 | 5 | 'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' 6 | - https://arxiv.org/abs/2010.11929 7 | 8 | `How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers` 9 | - https://arxiv.org/abs/2106.10270 10 | 11 | The official jax code is released and available at https://github.com/google-research/vision_transformer 12 | 13 | DeiT model defs and weights from https://github.com/facebookresearch/deit, 14 | paper `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877 15 | 16 | Acknowledgments: 17 | * The paper authors for releasing code and weights, thanks! 
18 | * I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out 19 | for some einops/einsum fun 20 | * Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT 21 | * Bert reference code checks against Huggingface Transformers and Tensorflow Bert 22 | 23 | Hacked together by / Copyright 2021 Ross Wightman 24 | """ 25 | # ----------------------------------------------------------------------- 26 | # Bridging the Divide: Reconsidering Softmax and Linear Attention 27 | # Modified by Dongchen Han 28 | # ----------------------------------------------------------------------- 29 | 30 | 31 | import math 32 | import logging 33 | from functools import partial 34 | from collections import OrderedDict 35 | from copy import deepcopy 36 | 37 | import torch 38 | import torch.nn as nn 39 | import torch.nn.functional as F 40 | 41 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD 42 | from timm.models.helpers import build_model_with_cfg, named_apply, adapt_input_conv 43 | from timm.models.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_ 44 | from timm.models.registry import register_model 45 | 46 | _logger = logging.getLogger(__name__) 47 | 48 | 49 | def _cfg(url='', **kwargs): 50 | return { 51 | 'url': url, 52 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 53 | 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, 54 | 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, 55 | 'first_conv': 'patch_embed.proj', 'classifier': 'head', 56 | **kwargs 57 | } 58 | 59 | 60 | default_cfgs = { 61 | # patch models (weights from official Google JAX impl) 62 | 'vit_tiny_patch16_224': _cfg( 63 | url='https://storage.googleapis.com/vit_models/augreg/' 64 | 'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'), 65 | 'vit_tiny_patch16_384': _cfg( 66 | url='https://storage.googleapis.com/vit_models/augreg/' 67 | 'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', 68 | input_size=(3, 384, 384), crop_pct=1.0), 69 | 'vit_small_patch32_224': _cfg( 70 | url='https://storage.googleapis.com/vit_models/augreg/' 71 | 'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'), 72 | 'vit_small_patch32_384': _cfg( 73 | url='https://storage.googleapis.com/vit_models/augreg/' 74 | 'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', 75 | input_size=(3, 384, 384), crop_pct=1.0), 76 | 'vit_small_patch16_224': _cfg( 77 | url='https://storage.googleapis.com/vit_models/augreg/' 78 | 'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'), 79 | 'vit_small_patch16_384': _cfg( 80 | url='https://storage.googleapis.com/vit_models/augreg/' 81 | 'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', 82 | input_size=(3, 384, 384), crop_pct=1.0), 83 | 'vit_base_patch32_224': _cfg( 84 | url='https://storage.googleapis.com/vit_models/augreg/' 85 | 'B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'), 86 | 'vit_base_patch32_384': _cfg( 87 | url='https://storage.googleapis.com/vit_models/augreg/' 88 | 
'B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', 89 | input_size=(3, 384, 384), crop_pct=1.0), 90 | 'vit_base_patch16_224': _cfg( 91 | url='https://storage.googleapis.com/vit_models/augreg/' 92 | 'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz'), 93 | 'vit_base_patch16_384': _cfg( 94 | url='https://storage.googleapis.com/vit_models/augreg/' 95 | 'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz', 96 | input_size=(3, 384, 384), crop_pct=1.0), 97 | 'vit_large_patch32_224': _cfg( 98 | url='', # no official model weights for this combo, only for in21k 99 | ), 100 | 'vit_large_patch32_384': _cfg( 101 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth', 102 | input_size=(3, 384, 384), crop_pct=1.0), 103 | 'vit_large_patch16_224': _cfg( 104 | url='https://storage.googleapis.com/vit_models/augreg/' 105 | 'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz'), 106 | 'vit_large_patch16_384': _cfg( 107 | url='https://storage.googleapis.com/vit_models/augreg/' 108 | 'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz', 109 | input_size=(3, 384, 384), crop_pct=1.0), 110 | 111 | # patch models, imagenet21k (weights from official Google JAX impl) 112 | 'vit_tiny_patch16_224_in21k': _cfg( 113 | url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz', 114 | num_classes=21843), 115 | 'vit_small_patch32_224_in21k': _cfg( 116 | url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz', 117 | num_classes=21843), 118 | 'vit_small_patch16_224_in21k': _cfg( 119 | url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz', 120 | num_classes=21843), 121 | 'vit_base_patch32_224_in21k': _cfg( 122 | url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0.npz', 123 | num_classes=21843), 124 | 'vit_base_patch16_224_in21k': _cfg( 125 | url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz', 126 | num_classes=21843), 127 | 'vit_large_patch32_224_in21k': _cfg( 128 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth', 129 | num_classes=21843), 130 | 'vit_large_patch16_224_in21k': _cfg( 131 | url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1.npz', 132 | num_classes=21843), 133 | 'vit_huge_patch14_224_in21k': _cfg( 134 | url='https://storage.googleapis.com/vit_models/imagenet21k/ViT-H_14.npz', 135 | hf_hub='timm/vit_huge_patch14_224_in21k', 136 | num_classes=21843), 137 | 138 | # deit models (FB weights) 139 | 'deit_tiny_patch16_224': _cfg( 140 | url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth', 141 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), 142 | 'deit_small_patch16_224': _cfg( 143 | url='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth', 144 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), 145 | 'deit_base_patch16_224': _cfg( 146 | 
url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth', 147 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), 148 | 'deit_base_patch16_384': _cfg( 149 | url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth', 150 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(3, 384, 384), crop_pct=1.0), 151 | 'deit_tiny_distilled_patch16_224': _cfg( 152 | url='https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth', 153 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')), 154 | 'deit_small_distilled_patch16_224': _cfg( 155 | url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth', 156 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')), 157 | 'deit_base_distilled_patch16_224': _cfg( 158 | url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth', 159 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')), 160 | 'deit_base_distilled_patch16_384': _cfg( 161 | url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth', 162 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(3, 384, 384), crop_pct=1.0, 163 | classifier=('head', 'head_dist')), 164 | 165 | # ViT ImageNet-21K-P pretraining by MILL 166 | 'vit_base_patch16_224_miil_in21k': _cfg( 167 | url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/vit_base_patch16_224_in21k_miil.pth', 168 | mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear', num_classes=11221, 169 | ), 170 | 'vit_base_patch16_224_miil': _cfg( 171 | url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm' 172 | '/vit_base_patch16_224_1k_miil_84_4.pth', 173 | mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear', 174 | ), 175 | } 176 | 177 | 178 | class InLineAttention(nn.Module): 179 | def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., 180 | window=14, **kwargs): 181 | super().__init__() 182 | self.dim = dim 183 | self.num_heads = num_heads 184 | head_dim = dim // num_heads 185 | self.scale = head_dim ** -0.5 186 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 187 | self.attn_drop = nn.Dropout(attn_drop) 188 | self.proj = nn.Linear(dim, dim) 189 | self.proj_drop = nn.Dropout(proj_drop) 190 | self.softmax = nn.Softmax(dim=-1) 191 | self.window = window 192 | 193 | self.residual = nn.Sequential( 194 | nn.Conv1d(dim, dim, kernel_size=1, groups=num_heads), 195 | nn.GELU(), 196 | nn.Conv1d(dim, dim * 9, kernel_size=1, groups=num_heads) 197 | ) 198 | 199 | def forward(self, x): 200 | """ 201 | Args: 202 | x: input features with shape of (num_windows*B, N, C) 203 | mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None 204 | """ 205 | b, n, c = x.shape 206 | h = int(n ** 0.5) 207 | w = int(n ** 0.5) 208 | num_heads = self.num_heads 209 | head_dim = c // num_heads 210 | qkv = self.qkv(x).reshape(b, n, 3, c).permute(2, 0, 1, 3) 211 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 212 | # q, k, v: b, n, c 213 | 214 | q = q.reshape(b, n, num_heads, head_dim).permute(0, 2, 1, 3) 215 | k = k.reshape(b, n, num_heads, head_dim).permute(0, 2, 1, 3) 216 | v = v.reshape(b, n, num_heads, head_dim).permute(0, 2, 1, 3) 217 | 218 | res_weight = self.residual(x.mean(dim=1).unsqueeze(dim=-1)).reshape(b 
* c, 1, 3, 3) 219 | 220 | # The self.scale / n = head_dim ** -0.5 / n is a scale factor used in InLine attention. 221 | # This factor can be equivalently achieved by scaling \phi(Q) = \phi(Q) * self.scale / n 222 | # Therefore, we omit it in eq. 5 of the paper for simplicity. 223 | kv = (k.transpose(-2, -1) * (self.scale / n) ** 0.5) @ (v * (self.scale / n) ** 0.5) 224 | x = q @ kv + (1 - q @ k.mean(dim=2, keepdim=True).transpose(-2, -1) * self.scale) * v.mean(dim=2, keepdim=True) 225 | 226 | x = x.transpose(1, 2).reshape(b, n, c) 227 | v_ = v[:, :, 1:, :].transpose(1, 2).reshape(b, h, w, c).permute(0, 3, 1, 2).reshape(1, b * c, h, w) 228 | residual = F.conv2d(v_, res_weight, None, padding=(1, 1), groups=b * c) 229 | x[:, 1:, :] = x[:, 1:, :] + residual.reshape(b, c, n - 1).permute(0, 2, 1) 230 | 231 | x = self.proj(x) 232 | x = self.proj_drop(x) 233 | return x 234 | 235 | 236 | class InLineBlock(nn.Module): 237 | 238 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., 239 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, window=14): 240 | super().__init__() 241 | self.norm1 = norm_layer(dim) 242 | self.attn = InLineAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, 243 | window=window) 244 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 245 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 246 | self.norm2 = norm_layer(dim) 247 | mlp_hidden_dim = int(dim * mlp_ratio) 248 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 249 | 250 | def forward(self, x): 251 | x = x + self.drop_path(self.attn(self.norm1(x))) 252 | x = x + self.drop_path(self.mlp(self.norm2(x))) 253 | return x 254 | 255 | 256 | class VisionTransformer(nn.Module): 257 | """ Vision Transformer 258 | 259 | A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` 260 | - https://arxiv.org/abs/2010.11929 261 | 262 | Includes distillation token & head support for `DeiT: Data-efficient Image Transformers` 263 | - https://arxiv.org/abs/2012.12877 264 | """ 265 | 266 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, 267 | num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, distilled=False, 268 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed, norm_layer=None, 269 | act_layer=None, weight_init=''): 270 | """ 271 | Args: 272 | img_size (int, tuple): input image size 273 | patch_size (int, tuple): patch size 274 | in_chans (int): number of input channels 275 | num_classes (int): number of classes for classification head 276 | embed_dim (int): embedding dimension 277 | depth (int): depth of transformer 278 | num_heads (int): number of attention heads 279 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim 280 | qkv_bias (bool): enable bias for qkv if True 281 | representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set 282 | distilled (bool): model includes a distillation token and head as in DeiT models 283 | drop_rate (float): dropout rate 284 | attn_drop_rate (float): attention dropout rate 285 | drop_path_rate (float): stochastic depth rate 286 | embed_layer (nn.Module): patch embedding layer 287 | norm_layer: (nn.Module): normalization layer 288 | weight_init: (str): weight init scheme 289 | """ 290 | super().__init__() 291 | 
self.num_classes = num_classes 292 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 293 | self.num_tokens = 2 if distilled else 1 294 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) 295 | act_layer = act_layer or nn.GELU 296 | 297 | self.patch_embed = embed_layer( 298 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 299 | num_patches = self.patch_embed.num_patches 300 | 301 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 302 | self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None 303 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) 304 | self.pos_drop = nn.Dropout(p=drop_rate) 305 | 306 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 307 | self.blocks = nn.Sequential(*[ 308 | InLineBlock( 309 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate, 310 | attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer, 311 | window=img_size // patch_size) for i in range(depth)]) 312 | self.norm = norm_layer(embed_dim) 313 | 314 | # Representation layer 315 | if representation_size and not distilled: 316 | self.num_features = representation_size 317 | self.pre_logits = nn.Sequential(OrderedDict([ 318 | ('fc', nn.Linear(embed_dim, representation_size)), 319 | ('act', nn.Tanh()) 320 | ])) 321 | else: 322 | self.pre_logits = nn.Identity() 323 | 324 | # Classifier head(s) 325 | self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() 326 | self.head_dist = None 327 | if distilled: 328 | self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() 329 | 330 | self.init_weights(weight_init) 331 | 332 | def init_weights(self, mode=''): 333 | assert mode in ('jax', 'jax_nlhb', 'nlhb', '') 334 | head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0. 
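# 'nlhb' modes initialize the classification head bias to -log(num_classes) (a "negative log head bias"), so the untrained head starts out near chance-level confidence for each class.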
335 | trunc_normal_(self.pos_embed, std=.02) 336 | if self.dist_token is not None: 337 | trunc_normal_(self.dist_token, std=.02) 338 | if mode.startswith('jax'): 339 | # leave cls token as zeros to match jax impl 340 | named_apply(partial(_init_vit_weights, head_bias=head_bias, jax_impl=True), self) 341 | else: 342 | trunc_normal_(self.cls_token, std=.02) 343 | self.apply(_init_vit_weights) 344 | 345 | def _init_weights(self, m): 346 | # this fn left here for compat with downstream users 347 | _init_vit_weights(m) 348 | 349 | @torch.jit.ignore() 350 | def load_pretrained(self, checkpoint_path, prefix=''): 351 | _load_weights(self, checkpoint_path, prefix) 352 | 353 | @torch.jit.ignore 354 | def no_weight_decay(self): 355 | return {'pos_embed', 'cls_token', 'dist_token'} 356 | 357 | def get_classifier(self): 358 | if self.dist_token is None: 359 | return self.head 360 | else: 361 | return self.head, self.head_dist 362 | 363 | def reset_classifier(self, num_classes, global_pool=''): 364 | self.num_classes = num_classes 365 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 366 | if self.num_tokens == 2: 367 | self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() 368 | 369 | def forward_features(self, x): 370 | x = self.patch_embed(x) 371 | cls_token = self.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks 372 | if self.dist_token is None: 373 | x = torch.cat((cls_token, x), dim=1) 374 | else: 375 | x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1) 376 | x = self.pos_drop(x + self.pos_embed) 377 | x = self.blocks(x) 378 | x = self.norm(x) 379 | if self.dist_token is None: 380 | return self.pre_logits(x[:, 0]) 381 | else: 382 | return x[:, 0], x[:, 1] 383 | 384 | def forward(self, x): 385 | x = self.forward_features(x) 386 | if self.head_dist is not None: 387 | x, x_dist = self.head(x[0]), self.head_dist(x[1]) # x must be a tuple 388 | if self.training and not torch.jit.is_scripting(): 389 | # during inference, return the average of both classifier predictions 390 | return x, x_dist 391 | else: 392 | return (x + x_dist) / 2 393 | else: 394 | x = self.head(x) 395 | return x 396 | 397 | 398 | def _init_vit_weights(module: nn.Module, name: str = '', head_bias: float = 0., jax_impl: bool = False): 399 | """ ViT weight initialization 400 | * When called without n, head_bias, jax_impl args it will behave exactly the same 401 | as my original init for compatibility with prev hparam / downstream use cases (ie DeiT). 
402 | * When called w/ valid n (module name) and jax_impl=True, will (hopefully) match JAX impl 403 | """ 404 | if isinstance(module, nn.Linear): 405 | if name.startswith('head'): 406 | nn.init.zeros_(module.weight) 407 | nn.init.constant_(module.bias, head_bias) 408 | elif name.startswith('pre_logits'): 409 | lecun_normal_(module.weight) 410 | nn.init.zeros_(module.bias) 411 | else: 412 | if jax_impl: 413 | nn.init.xavier_uniform_(module.weight) 414 | if module.bias is not None: 415 | if 'mlp' in name: 416 | nn.init.normal_(module.bias, std=1e-6) 417 | else: 418 | nn.init.zeros_(module.bias) 419 | else: 420 | trunc_normal_(module.weight, std=.02) 421 | if module.bias is not None: 422 | nn.init.zeros_(module.bias) 423 | elif jax_impl and isinstance(module, nn.Conv2d): 424 | # NOTE conv was left to pytorch default in my original init 425 | lecun_normal_(module.weight) 426 | if module.bias is not None: 427 | nn.init.zeros_(module.bias) 428 | elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)): 429 | nn.init.zeros_(module.bias) 430 | nn.init.ones_(module.weight) 431 | 432 | 433 | @torch.no_grad() 434 | def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''): 435 | """ Load weights from .npz checkpoints for official Google Brain Flax implementation 436 | """ 437 | import numpy as np 438 | 439 | def _n2p(w, t=True): 440 | if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1: 441 | w = w.flatten() 442 | if t: 443 | if w.ndim == 4: 444 | w = w.transpose([3, 2, 0, 1]) 445 | elif w.ndim == 3: 446 | w = w.transpose([2, 0, 1]) 447 | elif w.ndim == 2: 448 | w = w.transpose([1, 0]) 449 | return torch.from_numpy(w) 450 | 451 | w = np.load(checkpoint_path) 452 | if not prefix and 'opt/target/embedding/kernel' in w: 453 | prefix = 'opt/target/' 454 | 455 | if hasattr(model.patch_embed, 'backbone'): 456 | # hybrid 457 | backbone = model.patch_embed.backbone 458 | stem_only = not hasattr(backbone, 'stem') 459 | stem = backbone if stem_only else backbone.stem 460 | stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel']))) 461 | stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale'])) 462 | stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias'])) 463 | if not stem_only: 464 | for i, stage in enumerate(backbone.stages): 465 | for j, block in enumerate(stage.blocks): 466 | bp = f'{prefix}block{i + 1}/unit{j + 1}/' 467 | for r in range(3): 468 | getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel'])) 469 | getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale'])) 470 | getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias'])) 471 | if block.downsample is not None: 472 | block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel'])) 473 | block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale'])) 474 | block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias'])) 475 | embed_conv_w = _n2p(w[f'{prefix}embedding/kernel']) 476 | else: 477 | embed_conv_w = adapt_input_conv( 478 | model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel'])) 479 | model.patch_embed.proj.weight.copy_(embed_conv_w) 480 | model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias'])) 481 | model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False)) 482 | pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False) 483 | if pos_embed_w.shape != model.pos_embed.shape: 484 | pos_embed_w = resize_pos_embed( # resize pos 
embedding when different size from pretrained weights 485 | pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size) 486 | model.pos_embed.copy_(pos_embed_w) 487 | model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale'])) 488 | model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias'])) 489 | if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]: 490 | model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel'])) 491 | model.head.bias.copy_(_n2p(w[f'{prefix}head/bias'])) 492 | if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w: 493 | model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel'])) 494 | model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias'])) 495 | for i, block in enumerate(model.blocks.children()): 496 | block_prefix = f'{prefix}Transformer/encoderblock_{i}/' 497 | mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/' 498 | block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) 499 | block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) 500 | block.attn.qkv.weight.copy_(torch.cat([ 501 | _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')])) 502 | block.attn.qkv.bias.copy_(torch.cat([ 503 | _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')])) 504 | block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) 505 | block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) 506 | for r in range(2): 507 | getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel'])) 508 | getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias'])) 509 | block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale'])) 510 | block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias'])) 511 | 512 | 513 | def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()): 514 | # Rescale the grid of position embeddings when loading from state_dict. 
Adapted from 515 | # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 516 | _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape) 517 | ntok_new = posemb_new.shape[1] 518 | if num_tokens: 519 | posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:] 520 | ntok_new -= num_tokens 521 | else: 522 | posemb_tok, posemb_grid = posemb[:, :0], posemb[0] 523 | gs_old = int(math.sqrt(len(posemb_grid))) 524 | if not len(gs_new): # backwards compatibility 525 | gs_new = [int(math.sqrt(ntok_new))] * 2 526 | assert len(gs_new) >= 2 527 | _logger.info('Position embedding grid-size from %s to %s', [gs_old, gs_old], gs_new) 528 | posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) 529 | posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode='bilinear') 530 | posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1) 531 | posemb = torch.cat([posemb_tok, posemb_grid], dim=1) 532 | return posemb 533 | 534 | 535 | def checkpoint_filter_fn(state_dict, model): 536 | """ convert patch embedding weight from manual patchify + linear proj to conv""" 537 | out_dict = {} 538 | if 'model' in state_dict: 539 | # For deit models 540 | state_dict = state_dict['model'] 541 | for k, v in state_dict.items(): 542 | if 'patch_embed.proj.weight' in k and len(v.shape) < 4: 543 | # For old models that I trained prior to conv based patchification 544 | O, I, H, W = model.patch_embed.proj.weight.shape 545 | v = v.reshape(O, -1, H, W) 546 | elif k == 'pos_embed' and v.shape != model.pos_embed.shape: 547 | # To resize pos embedding when using model at different size from pretrained weights 548 | v = resize_pos_embed( 549 | v, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size) 550 | out_dict[k] = v 551 | return out_dict 552 | 553 | 554 | def _create_vision_transformer(variant, pretrained=False, default_cfg=None, **kwargs): 555 | default_cfg = default_cfg or default_cfgs[variant] 556 | if kwargs.get('features_only', None): 557 | raise RuntimeError('features_only not implemented for Vision Transformer models.') 558 | 559 | # NOTE this extra code to support handling of repr size for in21k pretrained models 560 | default_num_classes = default_cfg['num_classes'] 561 | num_classes = kwargs.get('num_classes', default_num_classes) 562 | repr_size = kwargs.pop('representation_size', None) 563 | if repr_size is not None and num_classes != default_num_classes: 564 | # Remove representation layer if fine-tuning. This may not always be the desired action, 565 | # but I feel better than doing nothing by default for fine-tuning. Perhaps a better interface? 566 | _logger.warning("Removing representation layer for fine-tuning.") 567 | repr_size = None 568 | 569 | model = build_model_with_cfg( 570 | VisionTransformer, variant, pretrained, 571 | default_cfg=default_cfg, 572 | representation_size=repr_size, 573 | pretrained_filter_fn=checkpoint_filter_fn, 574 | pretrained_custom_load='npz' in default_cfg['url'], 575 | **kwargs) 576 | return model 577 | 578 | 579 | @register_model 580 | def inline_deit_tiny(pretrained=False, **kwargs): 581 | """ DeiT-tiny model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). 582 | ImageNet-1k weights from https://github.com/facebookresearch/deit. 
583 | """ 584 | model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=6, **kwargs) 585 | model = _create_vision_transformer('deit_tiny_patch16_224', pretrained=pretrained, **model_kwargs) 586 | return model 587 | 588 | 589 | @register_model 590 | def inline_deit_small(pretrained=False, **kwargs): 591 | """ DeiT-tiny model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). 592 | ImageNet-1k weights from https://github.com/facebookresearch/deit. 593 | """ 594 | model_kwargs = dict(patch_size=16, embed_dim=320, depth=12, num_heads=10, **kwargs) 595 | model = _create_vision_transformer('deit_tiny_patch16_224', pretrained=pretrained, **model_kwargs) 596 | return model 597 | 598 | 599 | @register_model 600 | def inline_deit_base(pretrained=False, **kwargs): 601 | """ DeiT-small model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). 602 | ImageNet-1k weights from https://github.com/facebookresearch/deit. 603 | """ 604 | model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=12, **kwargs) 605 | model = _create_vision_transformer('deit_small_patch16_224', pretrained=pretrained, **model_kwargs) 606 | return model 607 | -------------------------------------------------------------------------------- /models/inline_swin.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | # Bridging the Divide: Reconsidering Softmax and Linear Attention 8 | # Modified by Dongchen Han 9 | # ----------------------------------------------------------------------- 10 | 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.utils.checkpoint as checkpoint 15 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 16 | import torch.nn.functional as F 17 | 18 | 19 | class Mlp(nn.Module): 20 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 21 | super().__init__() 22 | out_features = out_features or in_features 23 | hidden_features = hidden_features or in_features 24 | self.fc1 = nn.Linear(in_features, hidden_features) 25 | self.act = act_layer() 26 | self.fc2 = nn.Linear(hidden_features, out_features) 27 | self.drop = nn.Dropout(drop) 28 | 29 | def forward(self, x): 30 | x = self.fc1(x) 31 | x = self.act(x) 32 | x = self.drop(x) 33 | x = self.fc2(x) 34 | x = self.drop(x) 35 | return x 36 | 37 | 38 | def window_partition(x, window_size): 39 | """ 40 | Args: 41 | x: (B, H, W, C) 42 | window_size (int): window size 43 | 44 | Returns: 45 | windows: (num_windows*B, window_size, window_size, C) 46 | """ 47 | B, H, W, C = x.shape 48 | x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) 49 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 50 | return windows 51 | 52 | 53 | def window_reverse(windows, window_size, H, W): 54 | """ 55 | Args: 56 | windows: (num_windows*B, window_size, window_size, C) 57 | window_size (int): Window size 58 | H (int): Height of image 59 | W (int): Width of image 60 | 61 | Returns: 62 | x: (B, H, W, C) 63 | """ 64 | B = int(windows.shape[0] / (H * W / window_size / window_size)) 65 | x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) 66 | x = x.permute(0, 1, 3, 2, 4, 
5).contiguous().view(B, H, W, -1) 67 | return x 68 | 69 | 70 | class PosCNN(nn.Module): 71 | def __init__(self, in_chans, embed_dim=768, s=1): 72 | super(PosCNN, self).__init__() 73 | self.proj = nn.Sequential(nn.Conv2d(in_chans, embed_dim, 3, s, 1, bias=True, groups=embed_dim), ) 74 | self.s = s 75 | 76 | def forward(self, x): 77 | B, N, C = x.shape 78 | H = int(N ** 0.5) 79 | W = int(N ** 0.5) 80 | feat_token = x 81 | cnn_feat = feat_token.transpose(1, 2).view(B, C, H, W) 82 | if self.s == 1: 83 | x = self.proj(cnn_feat) + cnn_feat 84 | else: 85 | x = self.proj(cnn_feat) 86 | x = x.flatten(2).transpose(1, 2) 87 | return x 88 | 89 | def no_weight_decay(self): 90 | return ['proj.%d.weight' % i for i in range(4)] 91 | 92 | 93 | class WindowAttention(nn.Module): 94 | r""" Window based multi-head self attention (W-MSA) module with relative position bias. 95 | It supports both of shifted and non-shifted window. 96 | 97 | Args: 98 | dim (int): Number of input channels. 99 | window_size (tuple[int]): The height and width of the window. 100 | num_heads (int): Number of attention heads. 101 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 102 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set 103 | attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 104 | proj_drop (float, optional): Dropout ratio of output. Default: 0.0 105 | """ 106 | 107 | def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): 108 | 109 | super().__init__() 110 | self.dim = dim 111 | self.window_size = window_size # Wh, Ww 112 | self.num_heads = num_heads 113 | head_dim = dim // num_heads 114 | self.scale = qk_scale or head_dim ** -0.5 115 | 116 | # define a parameter table of relative position bias 117 | self.relative_position_bias_table = nn.Parameter( 118 | torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH 119 | 120 | # get pair-wise relative position index for each token inside the window 121 | coords_h = torch.arange(self.window_size[0]) 122 | coords_w = torch.arange(self.window_size[1]) 123 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww 124 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww 125 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww 126 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 127 | relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 128 | relative_coords[:, :, 1] += self.window_size[1] - 1 129 | relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 130 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww 131 | self.register_buffer("relative_position_index", relative_position_index) 132 | 133 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 134 | self.attn_drop = nn.Dropout(attn_drop) 135 | self.proj = nn.Linear(dim, dim) 136 | self.proj_drop = nn.Dropout(proj_drop) 137 | 138 | trunc_normal_(self.relative_position_bias_table, std=.02) 139 | self.softmax = nn.Softmax(dim=-1) 140 | 141 | def forward(self, x, mask=None): 142 | """ 143 | Args: 144 | x: input features with shape of (num_windows*B, N, C) 145 | mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None 146 | """ 147 | B_, N, C = x.shape 148 | qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 149 | q, k, v = qkv[0], qkv[1], qkv[2] # 
make torchscript happy (cannot use tensor as tuple)
150 | 
151 | q = q * self.scale
152 | attn = (q @ k.transpose(-2, -1))
153 | 
154 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
155 | self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
156 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
157 | attn = attn + relative_position_bias.unsqueeze(0)
158 | 
159 | if mask is not None:
160 | nW = mask.shape[0]
161 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
162 | attn = attn.view(-1, self.num_heads, N, N)
163 | attn = self.softmax(attn)
164 | else:
165 | attn = self.softmax(attn)
166 | 
167 | attn = self.attn_drop(attn)
168 | 
169 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
170 | x = self.proj(x)
171 | x = self.proj_drop(x)
172 | return x
173 | 
174 | def extra_repr(self) -> str:
175 | return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
176 | 
177 | def flops(self, N):
178 | # calculate flops for 1 window with token length of N
179 | flops = 0
180 | # qkv = self.qkv(x)
181 | flops += N * self.dim * 3 * self.dim
182 | # attn = (q @ k.transpose(-2, -1))
183 | flops += self.num_heads * N * (self.dim // self.num_heads) * N
184 | # x = (attn @ v)
185 | flops += self.num_heads * N * N * (self.dim // self.num_heads)
186 | # x = self.proj(x)
187 | flops += N * self.dim * self.dim
188 | return flops
189 | 
190 | 
191 | def exp_kernel(x, t=0.2):
192 | return (x * t).exp()
193 | 
194 | 
195 | class InLineAttention(nn.Module):
196 | r""" InLine attention module, a linear-complexity replacement for window based multi-head self attention (W-MSA).
197 | An optional kernel function can be applied to queries and keys, and a depth-wise convolution residual branch acts on the values.
198 | 
199 | Args:
200 | dim (int): Number of input channels.
201 | num_heads (int): Number of attention heads.
202 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
203 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
204 | attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
205 | proj_drop (float, optional): Dropout ratio of output.
Default: 0.0 206 | """ 207 | 208 | def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., shift_size=0, 209 | kernel_func='identity', **kwargs): 210 | 211 | super().__init__() 212 | self.dim = dim 213 | self.window_size = window_size # Wh, Ww 214 | self.num_heads = num_heads 215 | head_dim = dim // num_heads 216 | self.scale = head_dim ** -0.5 217 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 218 | self.attn_drop = nn.Dropout(attn_drop) 219 | self.proj = nn.Linear(dim, dim) 220 | self.proj_drop = nn.Dropout(proj_drop) 221 | self.softmax = nn.Softmax(dim=-1) 222 | self.shift_size = shift_size 223 | 224 | assert kernel_func in ['identity', 'relu', 'leakyrelu', 'exp'] 225 | if kernel_func == 'identity': 226 | self.phi = None 227 | elif kernel_func == 'relu': 228 | self.phi = nn.ReLU() 229 | elif kernel_func == 'leakyrelu': 230 | self.phi = nn.LeakyReLU() 231 | elif kernel_func == 'exp': 232 | self.phi = exp_kernel 233 | else: 234 | self.phi = None 235 | self.kernel_func = kernel_func 236 | 237 | self.residual = nn.Sequential( 238 | nn.Conv1d(dim, dim, kernel_size=1, groups=num_heads), 239 | nn.GELU(), 240 | nn.Conv1d(dim, dim * 9, kernel_size=1, groups=num_heads) 241 | ) 242 | 243 | def forward(self, x, mask=None): 244 | """ 245 | Args: 246 | x: input features with shape of (num_windows*B, N, C) 247 | mask: kept for interface compatibility with WindowAttention; not used by InLine attention 248 | """ 249 | b, n, c = x.shape 250 | h = int(n ** 0.5) 251 | w = int(n ** 0.5) 252 | num_heads = self.num_heads 253 | head_dim = c // num_heads 254 | qkv = self.qkv(x).reshape(b, n, 3, c).permute(2, 0, 1, 3) 255 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 256 | # q, k, v: b, n, c 257 | 258 | if self.phi is not None: 259 | q = self.phi(q) 260 | k = self.phi(k) 261 | q = q.reshape(b, n, num_heads, head_dim).permute(0, 2, 1, 3) 262 | k = k.reshape(b, n, num_heads, head_dim).permute(0, 2, 1, 3) 263 | v = v.reshape(b, n, num_heads, head_dim).permute(0, 2, 1, 3) 264 | 265 | res_weight = self.residual(x.mean(dim=1).unsqueeze(dim=-1)).reshape(b * c, 1, 3, 3) 266 | 267 | # self.scale / n = head_dim ** -0.5 / n is the scale factor used in InLine attention. 268 | # It can be absorbed equivalently by rescaling \phi(Q) as \phi(Q) * self.scale / n, 269 | # so it is omitted from Eq. 5 of the paper for simplicity.
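        # --- Added commentary (not original code): an illustrative shape/complexity sketch,
        # --- assuming the first stage of the tiny model (head_dim = 32, one 56x56 window, n = 3136).
        #     q, k, v          : (b, num_heads, n, head_dim)
        #     kv = phi(K)^T V  : (b, num_heads, head_dim, head_dim)  -> costs O(n * head_dim^2)
        #     q @ kv           : (b, num_heads, n, head_dim)
        # so the block below is linear in n, unlike the O(n^2 * head_dim) softmax attention above.
        # The second term on the line computing `x`, (1 - q @ mean(k)^T * scale) * mean(v),
        # is the InLine correction added on top of plain linear attention (cf. Eq. 5 referenced above).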
270 | kv = (k.transpose(-2, -1) * (self.scale / n) ** 0.5) @ (v * (self.scale / n) ** 0.5) 271 | x = q @ kv + (1 - q @ k.mean(dim=2, keepdim=True).transpose(-2, -1) * self.scale) * v.mean(dim=2, keepdim=True) 272 | 273 | x = x.transpose(1, 2).reshape(b, n, c) 274 | v = v.transpose(1, 2).reshape(b, h, w, c).permute(0, 3, 1, 2).reshape(1, b * c, h, w) 275 | residual = F.conv2d(v, res_weight, None, padding=(1, 1), groups=b * c) 276 | x = x + residual.reshape(b, c, n).permute(0, 2, 1) 277 | 278 | x = self.proj(x) 279 | x = self.proj_drop(x) 280 | return x 281 | 282 | def extra_repr(self) -> str: 283 | return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' 284 | 285 | def flops(self, N): 286 | # calculate flops for 1 window with token length of N 287 | flops = 0 288 | # qkv = self.qkv(x) 289 | flops += N * self.dim * 3 * self.dim 290 | # attn = (q @ k.transpose(-2, -1)) 291 | flops += self.num_heads * N * (self.dim // self.num_heads) * N 292 | # x = (attn @ v) 293 | flops += self.num_heads * N * N * (self.dim // self.num_heads) 294 | # x = self.proj(x) 295 | flops += N * self.dim * self.dim 296 | return flops 297 | 298 | 299 | class SwinTransformerBlock(nn.Module): 300 | r""" Swin Transformer Block. 301 | 302 | Args: 303 | dim (int): Number of input channels. 304 | input_resolution (tuple[int]): Input resolution. 305 | num_heads (int): Number of attention heads. 306 | window_size (int): Window size. 307 | shift_size (int): Shift size for SW-MSA. 308 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 309 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 310 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 311 | drop (float, optional): Dropout rate. Default: 0.0 312 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 313 | drop_path (float, optional): Stochastic depth rate. Default: 0.0 314 | act_layer (nn.Module, optional): Activation layer. Default: nn.GELU 315 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 316 | """ 317 | 318 | def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, 319 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., 320 | act_layer=nn.GELU, norm_layer=nn.LayerNorm, attn_type='I'): 321 | super().__init__() 322 | self.dim = dim 323 | self.input_resolution = input_resolution 324 | self.num_heads = num_heads 325 | self.window_size = window_size 326 | self.shift_size = shift_size 327 | self.mlp_ratio = mlp_ratio 328 | if min(self.input_resolution) <= self.window_size: 329 | # if window size is larger than input resolution, we don't partition windows 330 | self.shift_size = 0 331 | self.window_size = min(self.input_resolution) 332 | assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)" 333 | 334 | self.norm1 = norm_layer(dim) 335 | assert attn_type in ['I', 'S'] 336 | attn = InLineAttention if attn_type == 'I' else WindowAttention 337 | self.attn = attn( 338 | dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, 339 | qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 340 | 341 | self.drop_path = DropPath(drop_path) if drop_path > 0.
else nn.Identity() 342 | self.norm2 = norm_layer(dim) 343 | mlp_hidden_dim = int(dim * mlp_ratio) 344 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 345 | 346 | if self.shift_size > 0: 347 | # calculate attention mask for SW-MSA 348 | H, W = self.input_resolution 349 | img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 350 | h_slices = (slice(0, -self.window_size), 351 | slice(-self.window_size, -self.shift_size), 352 | slice(-self.shift_size, None)) 353 | w_slices = (slice(0, -self.window_size), 354 | slice(-self.window_size, -self.shift_size), 355 | slice(-self.shift_size, None)) 356 | cnt = 0 357 | for h in h_slices: 358 | for w in w_slices: 359 | img_mask[:, h, w, :] = cnt 360 | cnt += 1 361 | 362 | mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 363 | mask_windows = mask_windows.view(-1, self.window_size * self.window_size) 364 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) 365 | attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) 366 | else: 367 | attn_mask = None 368 | 369 | self.register_buffer("attn_mask", attn_mask) 370 | 371 | def forward(self, x): 372 | H, W = self.input_resolution 373 | B, L, C = x.shape 374 | assert L == H * W, "input feature has wrong size" 375 | 376 | shortcut = x 377 | x = self.norm1(x) 378 | x = x.view(B, H, W, C) 379 | 380 | # cyclic shift 381 | if self.shift_size > 0: 382 | shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) 383 | else: 384 | shifted_x = x 385 | 386 | # partition windows 387 | x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C 388 | x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C 389 | 390 | # W-MSA/SW-MSA 391 | attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C 392 | 393 | # merge windows 394 | attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) 395 | shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C 396 | 397 | # reverse cyclic shift 398 | if self.shift_size > 0: 399 | x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) 400 | else: 401 | x = shifted_x 402 | x = x.view(B, H * W, C) 403 | 404 | # FFN 405 | x = shortcut + self.drop_path(x) 406 | x = x + self.drop_path(self.mlp(self.norm2(x))) 407 | 408 | return x 409 | 410 | def extra_repr(self) -> str: 411 | return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ 412 | f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" 413 | 414 | def flops(self): 415 | flops = 0 416 | H, W = self.input_resolution 417 | # norm1 418 | flops += self.dim * H * W 419 | # W-MSA/SW-MSA 420 | nW = H * W / self.window_size / self.window_size 421 | flops += nW * self.attn.flops(self.window_size * self.window_size) 422 | # mlp 423 | flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio 424 | # norm2 425 | flops += self.dim * H * W 426 | return flops 427 | 428 | 429 | class PatchMerging(nn.Module): 430 | r""" Patch Merging Layer. 431 | 432 | Args: 433 | input_resolution (tuple[int]): Resolution of input feature. 434 | dim (int): Number of input channels. 435 | norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm 436 | """ 437 | 438 | def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): 439 | super().__init__() 440 | self.input_resolution = input_resolution 441 | self.dim = dim 442 | self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) 443 | self.norm = norm_layer(4 * dim) 444 | 445 | def forward(self, x): 446 | """ 447 | x: B, H*W, C 448 | """ 449 | H, W = self.input_resolution 450 | B, L, C = x.shape 451 | assert L == H * W, "input feature has wrong size" 452 | assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." 453 | 454 | x = x.view(B, H, W, C) 455 | 456 | x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C 457 | x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C 458 | x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C 459 | x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C 460 | x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C 461 | x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C 462 | 463 | x = self.norm(x) 464 | x = self.reduction(x) 465 | 466 | return x 467 | 468 | def extra_repr(self) -> str: 469 | return f"input_resolution={self.input_resolution}, dim={self.dim}" 470 | 471 | def flops(self): 472 | H, W = self.input_resolution 473 | flops = H * W * self.dim 474 | flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim 475 | return flops 476 | 477 | 478 | class BasicLayer(nn.Module): 479 | """ A basic Swin Transformer layer for one stage. 480 | 481 | Args: 482 | dim (int): Number of input channels. 483 | input_resolution (tuple[int]): Input resolution. 484 | depth (int): Number of blocks. 485 | num_heads (int): Number of attention heads. 486 | window_size (int): Local window size. 487 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 488 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 489 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 490 | drop (float, optional): Dropout rate. Default: 0.0 491 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 492 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 493 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 494 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None 495 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
496 | """ 497 | 498 | def __init__(self, dim, input_resolution, depth, num_heads, window_size, 499 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., 500 | drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, attn_type='I'): 501 | 502 | super().__init__() 503 | self.dim = dim 504 | self.input_resolution = input_resolution 505 | self.depth = depth 506 | self.use_checkpoint = use_checkpoint 507 | 508 | # build blocks 509 | attn_types = [(attn_type if attn_type[0] != 'M' else ('I' if i < int(attn_type[1:]) else 'S')) for i in range(depth)] 510 | window_sizes = [(window_size if attn_types[i] == 'I' else max(7, (window_size // 8))) for i in range(depth)] 511 | self.blocks = nn.ModuleList([ 512 | SwinTransformerBlock(dim=dim, input_resolution=input_resolution, 513 | num_heads=num_heads, window_size=window_sizes[i], 514 | shift_size=0 if (i % 2 == 0) else window_sizes[i] // 2, 515 | mlp_ratio=mlp_ratio, 516 | qkv_bias=qkv_bias, qk_scale=qk_scale, 517 | drop=drop, attn_drop=attn_drop, 518 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, 519 | norm_layer=norm_layer, 520 | attn_type=attn_types[i]) 521 | for i in range(depth)]) 522 | 523 | # patch merging layer 524 | if downsample is not None: 525 | self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) 526 | else: 527 | self.downsample = None 528 | 529 | def forward(self, x): 530 | for blk in self.blocks: 531 | if self.use_checkpoint: 532 | x = checkpoint.checkpoint(blk, x) 533 | else: 534 | x = blk(x) 535 | if self.downsample is not None: 536 | x = self.downsample(x) 537 | return x 538 | 539 | def extra_repr(self) -> str: 540 | return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" 541 | 542 | def flops(self): 543 | flops = 0 544 | for blk in self.blocks: 545 | flops += blk.flops() 546 | if self.downsample is not None: 547 | flops += self.downsample.flops() 548 | return flops 549 | 550 | 551 | class PatchEmbed(nn.Module): 552 | r""" Image to Patch Embedding 553 | 554 | Args: 555 | img_size (int): Image size. Default: 224. 556 | patch_size (int): Patch token size. Default: 4. 557 | in_chans (int): Number of input image channels. Default: 3. 558 | embed_dim (int): Number of linear projection output channels. Default: 96. 559 | norm_layer (nn.Module, optional): Normalization layer. Default: None 560 | """ 561 | 562 | def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): 563 | super().__init__() 564 | img_size = to_2tuple(img_size) 565 | patch_size = to_2tuple(patch_size) 566 | patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] 567 | self.img_size = img_size 568 | self.patch_size = patch_size 569 | self.patches_resolution = patches_resolution 570 | self.num_patches = patches_resolution[0] * patches_resolution[1] 571 | 572 | self.in_chans = in_chans 573 | self.embed_dim = embed_dim 574 | 575 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 576 | if norm_layer is not None: 577 | self.norm = norm_layer(embed_dim) 578 | else: 579 | self.norm = None 580 | 581 | def forward(self, x): 582 | B, C, H, W = x.shape 583 | # FIXME look at relaxing size constraints 584 | assert H == self.img_size[0] and W == self.img_size[1], \ 585 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
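        # Added shape note (not original code), assuming the default img_size=224, patch_size=4, embed_dim=96:
        #   (B, 3, 224, 224) --proj (stride-4 conv)--> (B, 96, 56, 56) --flatten(2).transpose(1, 2)--> (B, 3136, 96)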
586 | x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C 587 | if self.norm is not None: 588 | x = self.norm(x) 589 | return x 590 | 591 | def flops(self): 592 | Ho, Wo = self.patches_resolution 593 | flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) 594 | if self.norm is not None: 595 | flops += Ho * Wo * self.embed_dim 596 | return flops 597 | 598 | 599 | class InLineSwin(nn.Module): 600 | r""" Swin Transformer 601 | A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - 602 | https://arxiv.org/pdf/2103.14030 603 | 604 | Args: 605 | img_size (int | tuple(int)): Input image size. Default 224 606 | patch_size (int | tuple(int)): Patch size. Default: 4 607 | in_chans (int): Number of input image channels. Default: 3 608 | num_classes (int): Number of classes for classification head. Default: 1000 609 | embed_dim (int): Patch embedding dimension. Default: 96 610 | depths (tuple(int)): Depth of each Swin Transformer layer. 611 | num_heads (tuple(int)): Number of attention heads in different layers. 612 | window_size (int): Window size. Default: 7 613 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 614 | qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True 615 | qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None 616 | drop_rate (float): Dropout rate. Default: 0 617 | attn_drop_rate (float): Attention dropout rate. Default: 0 618 | drop_path_rate (float): Stochastic depth rate. Default: 0.1 619 | norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. 620 | ape (bool): If True, add absolute position embedding to the patch embedding. Default: False 621 | patch_norm (bool): If True, add normalization after patch embedding. Default: True 622 | use_checkpoint (bool): Whether to use checkpointing to save memory. 
Default: False 623 | """ 624 | 625 | def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, 626 | embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], 627 | window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, 628 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, 629 | norm_layer=nn.LayerNorm, ape=False, patch_norm=True, 630 | use_checkpoint=False, attn_type='IIII', **kwargs): 631 | super().__init__() 632 | self.num_classes = num_classes 633 | self.num_layers = len(depths) 634 | self.embed_dim = embed_dim 635 | self.ape = ape 636 | self.patch_norm = patch_norm 637 | self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) 638 | self.mlp_ratio = mlp_ratio 639 | 640 | # split image into non-overlapping patches 641 | self.patch_embed = PatchEmbed( 642 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, 643 | norm_layer=norm_layer if self.patch_norm else None) 644 | num_patches = self.patch_embed.num_patches 645 | patches_resolution = self.patch_embed.patches_resolution 646 | self.patches_resolution = patches_resolution 647 | 648 | # absolute position embedding 649 | if self.ape: 650 | self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) 651 | trunc_normal_(self.absolute_pos_embed, std=.02) 652 | 653 | self.pos_drop = nn.Dropout(p=drop_rate) 654 | 655 | # stochastic depth 656 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule 657 | 658 | # build layers 659 | self.layers = nn.ModuleList() 660 | for i_layer in range(self.num_layers): 661 | layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), 662 | input_resolution=(patches_resolution[0] // (2 ** i_layer), 663 | patches_resolution[1] // (2 ** i_layer)), 664 | depth=depths[i_layer], 665 | num_heads=num_heads[i_layer], 666 | window_size=window_size, 667 | mlp_ratio=self.mlp_ratio, 668 | qkv_bias=qkv_bias, qk_scale=qk_scale, 669 | drop=drop_rate, attn_drop=attn_drop_rate, 670 | drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], 671 | norm_layer=norm_layer, 672 | downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, 673 | use_checkpoint=use_checkpoint, 674 | attn_type=attn_type[i_layer] + (attn_type[self.num_layers:] if attn_type[i_layer] == 'M' else '')) 675 | self.layers.append(layer) 676 | 677 | self.norm = norm_layer(self.num_features) 678 | self.avgpool = nn.AdaptiveAvgPool1d(1) 679 | self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() 680 | 681 | self.apply(self._init_weights) 682 | 683 | def _init_weights(self, m): 684 | if isinstance(m, nn.Linear): 685 | trunc_normal_(m.weight, std=.02) 686 | if isinstance(m, nn.Linear) and m.bias is not None: 687 | nn.init.constant_(m.bias, 0) 688 | elif isinstance(m, nn.LayerNorm): 689 | nn.init.constant_(m.bias, 0) 690 | nn.init.constant_(m.weight, 1.0) 691 | 692 | @torch.jit.ignore 693 | def no_weight_decay(self): 694 | return {'absolute_pos_embed'} 695 | 696 | @torch.jit.ignore 697 | def no_weight_decay_keywords(self): 698 | return {'relative_position_bias_table'} 699 | 700 | def forward_features(self, x): 701 | x = self.patch_embed(x) 702 | if self.ape: 703 | x = x + self.absolute_pos_embed 704 | x = self.pos_drop(x) 705 | 706 | for layer in self.layers: 707 | x = layer(x) 708 | 709 | x = self.norm(x) # B L C 710 | x = self.avgpool(x.transpose(1, 2)) # B C 1 711 | x = torch.flatten(x, 1) 712 | return x 713 | 714 | def forward(self, x): 715 | x = self.forward_features(x) 716 | x 
= self.head(x) 717 | return x 718 | 719 | def flops(self): 720 | flops = 0 721 | flops += self.patch_embed.flops() 722 | for i, layer in enumerate(self.layers): 723 | flops += layer.flops() 724 | flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers) 725 | flops += self.num_features * self.num_classes 726 | return flops 727 | --------------------------------------------------------------------------------
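A minimal smoke test for the InLineAttention module defined above; a sketch only, not part of the repository. It assumes the repository layout shown earlier (models/inline_swin.py importable) and mirrors the first stage of the tiny model: 96 channels, 3 heads, a single 56x56 window.

import torch
from models.inline_swin import InLineAttention

# one 56x56 window of 96-dim tokens, batch of 2 windows
attn = InLineAttention(dim=96, window_size=(56, 56), num_heads=3, kernel_func='relu')
x = torch.randn(2, 56 * 56, 96)    # (num_windows * B, N, C)
out = attn(x)                      # linear-attention output plus the depth-wise conv residual
print(out.shape)                   # expected: torch.Size([2, 3136, 96])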
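And a sketch of building the full backbone directly (the repository's training scripts construct it from the YAML configs instead); parameter values below mirror the tiny setting and are illustrative assumptions, not a prescribed recipe.

import torch
from models.inline_swin import InLineSwin

# attn_type has one character per stage: 'I' -> InLineAttention for every block of the stage,
# 'S' -> Swin window attention with a reduced window (max(7, window_size // 8)),
# 'M<k>' -> first k blocks InLine, remaining blocks window attention (see BasicLayer above).
model = InLineSwin(img_size=224, embed_dim=96, depths=[2, 2, 6, 2],
                   num_heads=[3, 6, 12, 24], window_size=56, attn_type='IIIS')
logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)                                  # torch.Size([1, 1000])
print(f'{model.flops() / 1e9:.1f} GFLOPs per the flops() accounting above')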