├── data
│   ├── __init__.py
│   ├── samplers.py
│   ├── zipreader.py
│   ├── build.py
│   └── cached_image_folder.py
├── models
│   ├── __init__.py
│   ├── build.py
│   ├── inline_pvt.py
│   ├── inline_cswin.py
│   ├── inline_deit.py
│   └── inline_swin.py
├── figures
│   ├── fig2_cls.png
│   ├── fig3_speed.png
│   └── fig1_injectivity.png
├── cfgs
│   ├── inline_deit_t.yaml
│   ├── inline_deit_b.yaml
│   ├── inline_deit_s.yaml
│   ├── inline_swin_t.yaml
│   ├── inline_swin_s.yaml
│   ├── inline_swin_b.yaml
│   ├── inline_pvt_b.yaml
│   ├── inline_pvt_t.yaml
│   ├── inline_pvt_m.yaml
│   ├── inline_pvt_s.yaml
│   ├── inline_cswin_b.yaml
│   ├── inline_cswin_s.yaml
│   ├── inline_cswin_t.yaml
│   ├── inline_cswin_b_384.yaml
│   └── inline_swin_b_384.yaml
├── logger.py
├── optimizer.py
├── lr_scheduler.py
├── README.md
├── config.py
├── utils.py
├── utils_ema.py
├── main.py
└── main_ema.py
/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .build import build_loader
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .build import build_model
--------------------------------------------------------------------------------
/figures/fig2_cls.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LeapLabTHU/InLine/HEAD/figures/fig2_cls.png
--------------------------------------------------------------------------------
/figures/fig3_speed.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LeapLabTHU/InLine/HEAD/figures/fig3_speed.png
--------------------------------------------------------------------------------
/figures/fig1_injectivity.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LeapLabTHU/InLine/HEAD/figures/fig1_injectivity.png
--------------------------------------------------------------------------------
/cfgs/inline_deit_t.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | TYPE: inline_deit_tiny
3 | NAME: inline_deit_tiny
4 | DATA:
5 | BATCH_SIZE: 512
--------------------------------------------------------------------------------
/cfgs/inline_deit_b.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | TYPE: inline_deit_base
3 | NAME: inline_deit_base
4 | DATA:
5 | IMG_SIZE: 448
6 | BATCH_SIZE: 64
--------------------------------------------------------------------------------
/cfgs/inline_deit_s.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | TYPE: inline_deit_small
3 | NAME: inline_deit_small
4 | DATA:
5 | IMG_SIZE: 288
6 | BATCH_SIZE: 128
--------------------------------------------------------------------------------
/cfgs/inline_swin_t.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | TYPE: inline_swin
3 | NAME: inline_swin_tiny
4 | DROP_PATH_RATE: 0.2
5 | SWIN:
6 | EMBED_DIM: 96
7 | DEPTHS: [ 2, 2, 6, 2 ]
8 | NUM_HEADS: [ 3, 6, 12, 24 ]
9 | WINDOW_SIZE: 56
10 | INLINE:
11 | ATTN_TYPE: IIIS
--------------------------------------------------------------------------------
/cfgs/inline_swin_s.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | TYPE: inline_swin
3 | NAME: inline_swin_small
4 | DROP_PATH_RATE: 0.3
5 | SWIN:
6 | EMBED_DIM: 96
7 | DEPTHS: [ 2, 2, 18, 2 ]
8 | NUM_HEADS: [ 3, 6, 12, 24 ]
9 | WINDOW_SIZE: 56
10 | INLINE:
11 | ATTN_TYPE: IISS
--------------------------------------------------------------------------------
/cfgs/inline_swin_b.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | TYPE: inline_swin
3 | NAME: inline_swin_base
4 | DROP_PATH_RATE: 0.5
5 | SWIN:
6 | EMBED_DIM: 128
7 | DEPTHS: [ 2, 2, 18, 2 ]
8 | NUM_HEADS: [ 4, 8, 16, 32 ]
9 | WINDOW_SIZE: 56
10 | INLINE:
11 | ATTN_TYPE: IIMS2
12 | DATA:
13 | BATCH_SIZE: 64
--------------------------------------------------------------------------------
/cfgs/inline_pvt_b.yaml:
--------------------------------------------------------------------------------
1 | DATA:
2 | IMG_SIZE: 224
3 | BATCH_SIZE: 64
4 |
5 | TRAIN:
6 | WEIGHT_DECAY: 0.05
7 | EPOCHS: 300
8 | WARMUP_EPOCHS: 5
9 | COOLDOWN_EPOCHS: 10
10 | BASE_LR: 5e-4
11 | WARMUP_LR: 1e-6
12 | MIN_LR: 1e-5
13 | CLIP_GRAD: 1.0
14 |
15 | MODEL:
16 | TYPE: inline_pvt_large
17 | NAME: inline_pvt_large
18 | DROP_PATH_RATE: 0.3
19 |
--------------------------------------------------------------------------------
/cfgs/inline_pvt_t.yaml:
--------------------------------------------------------------------------------
1 | DATA:
2 | IMG_SIZE: 224
3 | BATCH_SIZE: 128
4 |
5 | TRAIN:
6 | WEIGHT_DECAY: 0.05
7 | EPOCHS: 300
8 | WARMUP_EPOCHS: 5
9 | COOLDOWN_EPOCHS: 10
10 | BASE_LR: 5e-4
11 | WARMUP_LR: 1e-6
12 | MIN_LR: 1e-5
13 | CLIP_GRAD: None
14 |
15 | MODEL:
16 | TYPE: inline_pvt_tiny
17 | NAME: inline_pvt_tiny
18 | DROP_PATH_RATE: 0.1
19 |
--------------------------------------------------------------------------------
/cfgs/inline_pvt_m.yaml:
--------------------------------------------------------------------------------
1 | DATA:
2 | IMG_SIZE: 224
3 | BATCH_SIZE: 128
4 |
5 | TRAIN:
6 | WEIGHT_DECAY: 0.05
7 | EPOCHS: 300
8 | WARMUP_EPOCHS: 5
9 | COOLDOWN_EPOCHS: 10
10 | BASE_LR: 5e-4
11 | WARMUP_LR: 1e-6
12 | MIN_LR: 1e-5
13 | CLIP_GRAD: 1.0
14 |
15 | MODEL:
16 | TYPE: inline_pvt_medium
17 | NAME: inline_pvt_medium
18 | DROP_PATH_RATE: 0.3
19 |
--------------------------------------------------------------------------------
/cfgs/inline_pvt_s.yaml:
--------------------------------------------------------------------------------
1 | DATA:
2 | IMG_SIZE: 224
3 | BATCH_SIZE: 128
4 |
5 | TRAIN:
6 | WEIGHT_DECAY: 0.05
7 | EPOCHS: 300
8 | WARMUP_EPOCHS: 5
9 | COOLDOWN_EPOCHS: 10
10 | BASE_LR: 5e-4
11 | WARMUP_LR: 1e-6
12 | MIN_LR: 1e-5
13 | CLIP_GRAD: None
14 |
15 | MODEL:
16 | TYPE: inline_pvt_small
17 | NAME: inline_pvt_small
18 | DROP_PATH_RATE: 0.1
19 |
--------------------------------------------------------------------------------
/cfgs/inline_cswin_b.yaml:
--------------------------------------------------------------------------------
1 | DATA:
2 | IMG_SIZE: 224
3 | BATCH_SIZE: 64
4 |
5 | TRAIN:
6 | WEIGHT_DECAY: 0.1
7 | EPOCHS: 300
8 | WARMUP_EPOCHS: 20
9 | COOLDOWN_EPOCHS: 10
10 | BASE_LR: 5e-4
11 | WARMUP_LR: 1e-6
12 | MIN_LR: 1e-5
13 | CLIP_GRAD: None
14 |
15 | MODEL:
16 | TYPE: inline_cswin_base
17 | NAME: inline_cswin_base
18 | DROP_PATH_RATE: 0.5
19 | INLINE:
20 | ATTN_TYPE: IISS
21 |
--------------------------------------------------------------------------------
/cfgs/inline_cswin_s.yaml:
--------------------------------------------------------------------------------
1 | DATA:
2 | IMG_SIZE: 224
3 | BATCH_SIZE: 128
4 |
5 | TRAIN:
6 | WEIGHT_DECAY: 0.05
7 | EPOCHS: 300
8 | WARMUP_EPOCHS: 20
9 | COOLDOWN_EPOCHS: 10
10 | BASE_LR: 5e-4
11 | WARMUP_LR: 1e-6
12 | MIN_LR: 1e-5
13 | CLIP_GRAD: None
14 |
15 | MODEL:
16 | TYPE: inline_cswin_small
17 | NAME: inline_cswin_small
18 | DROP_PATH_RATE: 0.4
19 | INLINE:
20 | ATTN_TYPE: IISS
--------------------------------------------------------------------------------
/cfgs/inline_cswin_t.yaml:
--------------------------------------------------------------------------------
1 | DATA:
2 | IMG_SIZE: 224
3 | BATCH_SIZE: 128
4 |
5 | TRAIN:
6 | WEIGHT_DECAY: 0.05
7 | EPOCHS: 300
8 | WARMUP_EPOCHS: 20
9 | COOLDOWN_EPOCHS: 10
10 | BASE_LR: 5e-4
11 | WARMUP_LR: 1e-6
12 | MIN_LR: 1e-5
13 | CLIP_GRAD: None
14 |
15 | MODEL:
16 | TYPE: inline_cswin_tiny
17 | NAME: inline_cswin_tiny
18 | DROP_PATH_RATE: 0.2
19 | INLINE:
20 | ATTN_TYPE: IISS
21 |
--------------------------------------------------------------------------------
/cfgs/inline_cswin_b_384.yaml:
--------------------------------------------------------------------------------
1 | DATA:
2 | IMG_SIZE: 384
3 | BATCH_SIZE: 32
4 |
5 | TRAIN:
6 | WEIGHT_DECAY: 1e-8
7 | EPOCHS: 20
8 | WARMUP_EPOCHS: 0
9 | COOLDOWN_EPOCHS: 10
10 | BASE_LR: 5e-6
11 | WARMUP_LR: 5e-6
12 | MIN_LR: 5e-7
13 | CLIP_GRAD: None
14 |
15 | MODEL:
16 | TYPE: inline_cswin_base_384
17 | NAME: inline_cswin_base_384
18 | DROP_PATH_RATE: 0.7
19 | INLINE:
20 | ATTN_TYPE: IISS
21 | CSWIN_LA_SPLIT_SIZE: 96-48-24-12
22 |
--------------------------------------------------------------------------------
/cfgs/inline_swin_b_384.yaml:
--------------------------------------------------------------------------------
1 | DATA:
2 | IMG_SIZE: 384
3 | BATCH_SIZE: 32
4 | MODEL:
5 | TYPE: inline_swin
6 | NAME: inline_swin_base_384
7 | DROP_PATH_RATE: 0.5
8 | SWIN:
9 | EMBED_DIM: 128
10 | DEPTHS: [ 2, 2, 18, 2 ]
11 | NUM_HEADS: [ 4, 8, 16, 32 ]
12 | WINDOW_SIZE: 96
13 | INLINE:
14 | ATTN_TYPE: IIMS2
15 | TRAIN:
16 | EPOCHS: 30
17 | WARMUP_EPOCHS: 5
18 | WEIGHT_DECAY: 1e-8
19 | BASE_LR: 2e-05
20 | WARMUP_LR: 2e-08
21 | MIN_LR: 2e-07
22 | TEST:
23 | CROP: False
--------------------------------------------------------------------------------
/data/samplers.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Swin Transformer
3 | # Copyright (c) 2021 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Ze Liu
6 | # --------------------------------------------------------
7 |
8 | import torch
9 |
10 |
11 | class SubsetRandomSampler(torch.utils.data.Sampler):
12 | r"""Samples elements randomly from a given list of indices, without replacement.
13 |
14 | Arguments:
15 | indices (sequence): a sequence of indices
16 | """
17 |
18 | def __init__(self, indices):
19 | self.epoch = 0
20 | self.indices = indices
21 |
22 | def __iter__(self):
23 | return (self.indices[i] for i in torch.randperm(len(self.indices)))
24 |
25 | def __len__(self):
26 | return len(self.indices)
27 |
28 | def set_epoch(self, epoch):
29 | self.epoch = epoch
30 |
--------------------------------------------------------------------------------
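
Note: `set_epoch` above records the epoch, but `__iter__` draws a fresh `torch.randperm` from the global RNG, so the stored epoch never influences the shuffle order. A minimal sketch (an illustration, not the repo's code) of how the epoch could seed the permutation deterministically:

```python
import torch

class SeededSubsetRandomSampler(torch.utils.data.Sampler):
    """Like SubsetRandomSampler above, but set_epoch controls the shuffle."""

    def __init__(self, indices):
        self.epoch = 0
        self.indices = indices

    def __iter__(self):
        g = torch.Generator()
        g.manual_seed(self.epoch)  # same permutation for a given epoch on every restart
        return (self.indices[i] for i in torch.randperm(len(self.indices), generator=g))

    def __len__(self):
        return len(self.indices)

    def set_epoch(self, epoch):
        self.epoch = epoch
```

With a seeded generator, resuming at epoch `e` replays the same order that an uninterrupted run would have used.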
/logger.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Swin Transformer
3 | # Copyright (c) 2021 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Ze Liu
6 | # --------------------------------------------------------
7 |
8 | import os
9 | import sys
10 | import logging
11 | import functools
12 | from termcolor import colored
13 |
14 |
15 | @functools.lru_cache()
16 | def create_logger(output_dir, dist_rank=0, name=''):
17 | # create logger
18 | logger = logging.getLogger(name)
19 | logger.setLevel(logging.DEBUG)
20 | logger.propagate = False
21 |
22 | # create formatter
23 | fmt = '[%(asctime)s %(name)s] (%(filename)s %(lineno)d): %(levelname)s %(message)s'
24 | color_fmt = colored('[%(asctime)s %(name)s]', 'green') + \
25 | colored('(%(filename)s %(lineno)d)', 'yellow') + ': %(levelname)s %(message)s'
26 |
27 | # create console handlers for master process
28 | if dist_rank == 0:
29 | console_handler = logging.StreamHandler(sys.stdout)
30 | console_handler.setLevel(logging.DEBUG)
31 | console_handler.setFormatter(
32 | logging.Formatter(fmt=color_fmt, datefmt='%Y-%m-%d %H:%M:%S'))
33 | logger.addHandler(console_handler)
34 | file_handler = logging.FileHandler(os.path.join(output_dir, f'log_rank{dist_rank}.txt'), mode='a')
35 | file_handler.setLevel(logging.DEBUG)
36 | file_handler.setFormatter(logging.Formatter(fmt=fmt, datefmt='%Y-%m-%d %H:%M:%S'))
37 | logger.addHandler(file_handler)
38 |
39 | return logger
40 |
--------------------------------------------------------------------------------
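
Hypothetical usage of `create_logger` (the output directory must exist before the `FileHandler` is created, and thanks to `lru_cache` repeated calls with the same arguments return the same logger):

```python
import os
from logger import create_logger

os.makedirs('./output', exist_ok=True)  # FileHandler does not create directories
logger = create_logger(output_dir='./output', dist_rank=0, name='inline')
logger.info('written to the console (rank 0 only) and to ./output/log_rank0.txt')
```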
/optimizer.py:
--------------------------------------------------------------------------------
1 | import torch.optim as optim
2 | 
3 |
4 | def build_optimizer(config, model):
5 | """
6 |     Build the optimizer; weight decay is disabled for normalization layers and bias terms by default.
7 | """
8 | skip = {}
9 | skip_keywords = {}
10 | if hasattr(model, 'no_weight_decay'):
11 | skip = model.no_weight_decay()
12 | if hasattr(model, 'no_weight_decay_keywords'):
13 | skip_keywords = model.no_weight_decay_keywords()
14 |
15 | if hasattr(model, 'lower_lr_kvs'):
16 | lower_lr_kvs = model.lower_lr_kvs
17 | else:
18 | lower_lr_kvs = {}
19 |
20 | parameters = set_weight_decay_and_lr(
21 | model, skip, skip_keywords, lower_lr_kvs, config.TRAIN.BASE_LR)
22 |
23 | opt_lower = config.TRAIN.OPTIMIZER.NAME.lower()
24 | optimizer = None
25 | if opt_lower == 'sgd':
26 | optimizer = optim.SGD(parameters, momentum=config.TRAIN.OPTIMIZER.MOMENTUM, nesterov=True,
27 | lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY)
28 | elif opt_lower == 'adamw':
29 | optimizer = optim.AdamW(parameters, eps=config.TRAIN.OPTIMIZER.EPS, betas=config.TRAIN.OPTIMIZER.BETAS,
30 | lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY)
31 |
32 | return optimizer
33 |
34 |
35 | def set_weight_decay_and_lr(
36 | model,
37 | skip_list=(), skip_keywords=(),
38 | lower_lr_kvs={}, base_lr=5e-4):
39 | 
40 | assert len(lower_lr_kvs) == 1 or len(lower_lr_kvs) == 0
41 | has_lower_lr = len(lower_lr_kvs) == 1
42 | if has_lower_lr:
43 | for k,v in lower_lr_kvs.items():
44 | lower_lr_key = k
45 | lower_lr = v * base_lr
46 |
47 | has_decay = []
48 | has_decay_low = []
49 | no_decay = []
50 | no_decay_low = []
51 |
52 | for name, param in model.named_parameters():
53 | if not param.requires_grad:
54 | continue # frozen weights
55 | if len(param.shape) == 1 or name.endswith(".bias") or (name in skip_list) or \
56 | check_keywords_in_name(name, skip_keywords):
57 |
58 | if has_lower_lr and check_keywords_in_name(name, (lower_lr_key,)):
59 | no_decay_low.append(param)
60 | else:
61 | no_decay.append(param)
62 |
63 | else:
64 |
65 | if has_lower_lr and check_keywords_in_name(name, (lower_lr_key,)):
66 | has_decay_low.append(param)
67 | else:
68 | has_decay.append(param)
69 |
70 | if has_lower_lr:
71 | result = [
72 | {'params': has_decay},
73 | {'params': has_decay_low, 'lr': lower_lr},
74 | {'params': no_decay, 'weight_decay': 0.},
75 | {'params': no_decay_low, 'weight_decay': 0., 'lr': lower_lr}
76 | ]
77 | else:
78 | result = [
79 | {'params': has_decay},
80 | {'params': no_decay, 'weight_decay': 0.}
81 | ]
82 | 
83 | return result
84 |
85 |
86 | def check_keywords_in_name(name, keywords=()):
87 | isin = False
88 | for keyword in keywords:
89 | if keyword in name:
90 | isin = True
91 | return isin
92 |
--------------------------------------------------------------------------------
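
A quick sanity check of the grouping logic with a toy model (the model is our illustration, not from the repo): 1-D parameters and `.bias` parameters land in the `weight_decay=0.` group, everything else keeps the optimizer's default decay.

```python
import torch.nn as nn
from optimizer import set_weight_decay_and_lr

model = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8))
groups = set_weight_decay_and_lr(model)
# groups[0]: the Linear weight (decayed)
# groups[1]: Linear bias + LayerNorm weight/bias (weight_decay=0.)
print([len(g['params']) for g in groups])  # [1, 3]
```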
/models/build.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Swin Transformer
3 | # Copyright (c) 2021 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Ze Liu
6 | # --------------------------------------------------------
7 |
8 | from .inline_swin import InLineSwin
9 | from .inline_deit import inline_deit_tiny, inline_deit_small, inline_deit_base
10 | from .inline_pvt import inline_pvt_tiny, inline_pvt_small, inline_pvt_medium, inline_pvt_large
11 | from .inline_cswin import inline_cswin_tiny, inline_cswin_small, inline_cswin_base, inline_cswin_base_384
12 |
13 |
14 | def build_model(config):
15 | model_type = config.MODEL.TYPE
16 | if model_type == 'inline_swin':
17 | model = InLineSwin(img_size=config.DATA.IMG_SIZE,
18 | patch_size=config.MODEL.SWIN.PATCH_SIZE,
19 | in_chans=config.MODEL.SWIN.IN_CHANS,
20 | num_classes=config.MODEL.NUM_CLASSES,
21 | embed_dim=config.MODEL.SWIN.EMBED_DIM,
22 | depths=config.MODEL.SWIN.DEPTHS,
23 | num_heads=config.MODEL.SWIN.NUM_HEADS,
24 | window_size=config.MODEL.SWIN.WINDOW_SIZE,
25 | mlp_ratio=config.MODEL.SWIN.MLP_RATIO,
26 | qkv_bias=config.MODEL.SWIN.QKV_BIAS,
27 | qk_scale=config.MODEL.SWIN.QK_SCALE,
28 | drop_rate=config.MODEL.DROP_RATE,
29 | drop_path_rate=config.MODEL.DROP_PATH_RATE,
30 | ape=config.MODEL.SWIN.APE,
31 | patch_norm=config.MODEL.SWIN.PATCH_NORM,
32 | use_checkpoint=config.TRAIN.USE_CHECKPOINT,
33 | attn_type=config.MODEL.INLINE.ATTN_TYPE)
34 |
35 | elif model_type in ['inline_deit_tiny', 'inline_deit_small', 'inline_deit_base']:
36 | model = eval(model_type + '(img_size=config.DATA.IMG_SIZE,'
37 | 'drop_path_rate=config.MODEL.DROP_PATH_RATE)')
38 |
39 | elif model_type in ['inline_pvt_tiny', 'inline_pvt_small', 'inline_pvt_medium', 'inline_pvt_large']:
40 | model = eval(model_type + '(img_size=config.DATA.IMG_SIZE,'
41 | 'drop_path_rate=config.MODEL.DROP_PATH_RATE,'
42 | 'attn_type=config.MODEL.INLINE.ATTN_TYPE,'
43 | 'la_sr_ratios=str(config.MODEL.INLINE.PVT_LA_SR_RATIOS))')
44 |
45 | elif model_type in ['inline_cswin_tiny', 'inline_cswin_small', 'inline_cswin_base', 'inline_cswin_base_384']:
46 | model = eval(model_type + '(img_size=config.DATA.IMG_SIZE,'
47 | 'in_chans=config.MODEL.SWIN.IN_CHANS,'
48 | 'num_classes=config.MODEL.NUM_CLASSES,'
49 | 'drop_rate=config.MODEL.DROP_RATE,'
50 | 'drop_path_rate=config.MODEL.DROP_PATH_RATE,'
51 | 'attn_type=config.MODEL.INLINE.ATTN_TYPE,'
52 | 'la_split_size=config.MODEL.INLINE.CSWIN_LA_SPLIT_SIZE)')
53 |
54 | else:
55 | raise NotImplementedError(f"Unkown model: {model_type}")
56 |
57 | return model
58 |
--------------------------------------------------------------------------------
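
The `eval`-based dispatch above works because the constructor names match `config.MODEL.TYPE`, but it is fragile to typos and opaque to linters. A dictionary registry is the usual eval-free alternative; a sketch for the DeiT variants (our illustration building on the imports at the top of build.py, not the repo's code):

```python
# Map config strings to the already-imported constructors and pass
# keyword arguments explicitly instead of assembling a source string.
_DEIT_MODELS = {
    'inline_deit_tiny': inline_deit_tiny,
    'inline_deit_small': inline_deit_small,
    'inline_deit_base': inline_deit_base,
}

def build_deit(config):
    ctor = _DEIT_MODELS[config.MODEL.TYPE]  # KeyError doubles as the "unknown model" check
    return ctor(img_size=config.DATA.IMG_SIZE,
                drop_path_rate=config.MODEL.DROP_PATH_RATE)
```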
/data/zipreader.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Swin Transformer
3 | # Copyright (c) 2021 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Ze Liu
6 | # --------------------------------------------------------
7 |
8 | import os
9 | import zipfile
10 | import io
11 | import numpy as np
12 | from PIL import Image
13 | from PIL import ImageFile
14 |
15 | ImageFile.LOAD_TRUNCATED_IMAGES = True
16 |
17 |
18 | def is_zip_path(img_or_path):
19 | """judge if this is a zip path"""
20 | return '.zip@' in img_or_path
21 |
22 |
23 | class ZipReader(object):
24 | """A class to read zipped files"""
25 | zip_bank = dict()
26 |
27 | def __init__(self):
28 | super(ZipReader, self).__init__()
29 |
30 | @staticmethod
31 | def get_zipfile(path):
32 | zip_bank = ZipReader.zip_bank
33 | if path not in zip_bank:
34 | zfile = zipfile.ZipFile(path, 'r')
35 | zip_bank[path] = zfile
36 | return zip_bank[path]
37 |
38 | @staticmethod
39 | def split_zip_style_path(path):
40 |         pos_at = path.find('@')
41 |         assert pos_at != -1, "character '@' is not found in the given path '%s'" % path
42 |
43 | zip_path = path[0: pos_at]
44 | folder_path = path[pos_at + 1:]
45 | folder_path = str.strip(folder_path, '/')
46 | return zip_path, folder_path
47 |
48 | @staticmethod
49 | def list_folder(path):
50 | zip_path, folder_path = ZipReader.split_zip_style_path(path)
51 |
52 | zfile = ZipReader.get_zipfile(zip_path)
53 | folder_list = []
54 |         for file_folder_name in zfile.namelist():
55 |             file_folder_name = str.strip(file_folder_name, '/')
56 |             if file_folder_name.startswith(folder_path) and \
57 |                     len(os.path.splitext(file_folder_name)[-1]) == 0 and \
58 |                     file_folder_name != folder_path:
59 |                 if len(folder_path) == 0:
60 |                     folder_list.append(file_folder_name)
61 |                 else:
62 |                     folder_list.append(file_folder_name[len(folder_path) + 1:])
63 |
64 | return folder_list
65 |
66 | @staticmethod
67 | def list_files(path, extension=None):
68 | if extension is None:
69 | extension = ['.*']
70 | zip_path, folder_path = ZipReader.split_zip_style_path(path)
71 |
72 | zfile = ZipReader.get_zipfile(zip_path)
73 | file_lists = []
74 |         for file_folder_name in zfile.namelist():
75 |             file_folder_name = str.strip(file_folder_name, '/')
76 |             if file_folder_name.startswith(folder_path) and \
77 |                     str.lower(os.path.splitext(file_folder_name)[-1]) in extension:
78 |                 if len(folder_path) == 0:
79 |                     file_lists.append(file_folder_name)
80 |                 else:
81 |                     file_lists.append(file_folder_name[len(folder_path) + 1:])
82 |
83 | return file_lists
84 |
85 | @staticmethod
86 | def read(path):
87 | zip_path, path_img = ZipReader.split_zip_style_path(path)
88 | zfile = ZipReader.get_zipfile(zip_path)
89 | data = zfile.read(path_img)
90 | return data
91 |
92 | @staticmethod
93 | def imread(path):
94 | zip_path, path_img = ZipReader.split_zip_style_path(path)
95 | zfile = ZipReader.get_zipfile(zip_path)
96 | data = zfile.read(path_img)
97 | try:
98 | im = Image.open(io.BytesIO(data))
99 |         except Exception:
100 | print("ERROR IMG LOADED: ", path_img)
101 | random_img = np.random.rand(224, 224, 3) * 255
102 | im = Image.fromarray(np.uint8(random_img))
103 | return im
104 |
--------------------------------------------------------------------------------
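
`ZipReader` addresses archive members with a `<archive>.zip@/<inner path>` convention, splitting on the first `@`. A hypothetical example (the paths are illustrative; run from the repo root):

```python
from data.zipreader import ZipReader, is_zip_path

path = '/data/imagenet/train.zip@/class1/img1.jpeg'  # hypothetical archive layout
assert is_zip_path(path)
# split_zip_style_path(path) -> ('/data/imagenet/train.zip', 'class1/img1.jpeg')
img = ZipReader.imread(path).convert('RGB')  # returns a PIL.Image
```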
/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Swin Transformer
3 | # Copyright (c) 2021 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Ze Liu
6 | # --------------------------------------------------------
7 |
8 | import torch
9 | from timm.scheduler.cosine_lr import CosineLRScheduler
10 | from timm.scheduler.step_lr import StepLRScheduler
11 | from timm.scheduler.scheduler import Scheduler
12 |
13 |
14 | def build_scheduler(config, optimizer, n_iter_per_epoch):
15 | num_steps = int(config.TRAIN.EPOCHS * n_iter_per_epoch)
16 | warmup_steps = int(config.TRAIN.WARMUP_EPOCHS * n_iter_per_epoch)
17 | decay_steps = int(config.TRAIN.LR_SCHEDULER.DECAY_EPOCHS * n_iter_per_epoch)
18 |
19 | lr_scheduler = None
20 | if config.TRAIN.LR_SCHEDULER.NAME == 'cosine':
21 | lr_scheduler = CosineLRScheduler(
22 | optimizer,
23 | t_initial=num_steps,
24 | lr_min=config.TRAIN.MIN_LR,
25 | warmup_lr_init=config.TRAIN.WARMUP_LR,
26 | warmup_t=warmup_steps,
27 | cycle_limit=1,
28 | t_in_epochs=False,
29 | )
30 | elif config.TRAIN.LR_SCHEDULER.NAME == 'linear':
31 | lr_scheduler = LinearLRScheduler(
32 | optimizer,
33 | t_initial=num_steps,
34 | lr_min_rate=0.01,
35 | warmup_lr_init=config.TRAIN.WARMUP_LR,
36 | warmup_t=warmup_steps,
37 | t_in_epochs=False,
38 | )
39 | elif config.TRAIN.LR_SCHEDULER.NAME == 'step':
40 | lr_scheduler = StepLRScheduler(
41 | optimizer,
42 | decay_t=decay_steps,
43 | decay_rate=config.TRAIN.LR_SCHEDULER.DECAY_RATE,
44 | warmup_lr_init=config.TRAIN.WARMUP_LR,
45 | warmup_t=warmup_steps,
46 | t_in_epochs=False,
47 | )
48 |
49 | return lr_scheduler
50 |
51 |
52 | class LinearLRScheduler(Scheduler):
53 | def __init__(self,
54 | optimizer: torch.optim.Optimizer,
55 | t_initial: int,
56 | lr_min_rate: float,
57 | warmup_t=0,
58 | warmup_lr_init=0.,
59 | t_in_epochs=True,
60 | noise_range_t=None,
61 | noise_pct=0.67,
62 | noise_std=1.0,
63 | noise_seed=42,
64 | initialize=True,
65 | ) -> None:
66 | super().__init__(
67 | optimizer, param_group_field="lr",
68 | noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
69 | initialize=initialize)
70 |
71 | self.t_initial = t_initial
72 | self.lr_min_rate = lr_min_rate
73 | self.warmup_t = warmup_t
74 | self.warmup_lr_init = warmup_lr_init
75 | self.t_in_epochs = t_in_epochs
76 | if self.warmup_t:
77 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
78 | super().update_groups(self.warmup_lr_init)
79 | else:
80 | self.warmup_steps = [1 for _ in self.base_values]
81 |
82 | def _get_lr(self, t):
83 | if t < self.warmup_t:
84 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
85 | else:
86 | t = t - self.warmup_t
87 | total_t = self.t_initial - self.warmup_t
88 | lrs = [v - ((v - v * self.lr_min_rate) * (t / total_t)) for v in self.base_values]
89 | return lrs
90 |
91 | def get_epoch_values(self, epoch: int):
92 | if self.t_in_epochs:
93 | return self._get_lr(epoch)
94 | else:
95 | return None
96 |
97 | def get_update_values(self, num_updates: int):
98 | if not self.t_in_epochs:
99 | return self._get_lr(num_updates)
100 | else:
101 | return None
102 |
--------------------------------------------------------------------------------
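
A quick numerical check of `LinearLRScheduler` (toy numbers, assuming the timm 0.4.12 `Scheduler` base class pinned in the README): with a base LR of 1e-3, `t_initial=100`, no warmup, and `lr_min_rate=0.01`, the LR at step `t` is `v - (v - 0.01*v) * t/100`, i.e. 5.05e-4 halfway through and 1e-5 at the end.

```python
import torch
from lr_scheduler import LinearLRScheduler

opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=1e-3)
sched = LinearLRScheduler(opt, t_initial=100, lr_min_rate=0.01, t_in_epochs=False)
print(sched.get_update_values(50))   # [0.000505]
print(sched.get_update_values(100))  # [1e-05] = lr_min_rate * base LR
```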
/data/build.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Swin on HF
3 | # --------------------------------------------------------
4 |
5 | import os
6 | import torch
7 | import numpy as np
8 | import torch.distributed as dist
9 | from torchvision import datasets, transforms
10 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
11 | from timm.data import Mixup
12 | from timm.data import create_transform
13 | from timm.data.transforms import _pil_interp
14 |
15 | from .cached_image_folder import CachedImageFolder
16 | from .samplers import SubsetRandomSampler
17 |
18 |
19 | def build_loader(config):
20 | config.defrost()
21 | dataset_train, config.MODEL.NUM_CLASSES = build_dataset(is_train=True, config=config)
22 | config.freeze()
23 | print(f"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} successfully build train dataset")
24 | dataset_val, _ = build_dataset(is_train=False, config=config)
25 | print(f"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} successfully build val dataset")
26 |
27 | num_tasks = dist.get_world_size()
28 | global_rank = dist.get_rank()
29 | if config.DATA.ZIP_MODE and config.DATA.CACHE_MODE == 'part':
30 | indices = np.arange(dist.get_rank(), len(dataset_train), dist.get_world_size())
31 | sampler_train = SubsetRandomSampler(indices)
32 | else:
33 | sampler_train = torch.utils.data.DistributedSampler(
34 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
35 | )
36 |
37 | indices = np.arange(dist.get_rank(), len(dataset_val), dist.get_world_size())
38 | sampler_val = SubsetRandomSampler(indices)
39 |
40 | data_loader_train = torch.utils.data.DataLoader(
41 | dataset_train, sampler=sampler_train,
42 | batch_size=config.DATA.BATCH_SIZE,
43 | num_workers=config.DATA.NUM_WORKERS,
44 | pin_memory=config.DATA.PIN_MEMORY,
45 | drop_last=True,
46 | )
47 |
48 | data_loader_val = torch.utils.data.DataLoader(
49 | dataset_val, sampler=sampler_val,
50 | batch_size=config.DATA.BATCH_SIZE,
51 | shuffle=False,
52 | num_workers=config.DATA.NUM_WORKERS,
53 | pin_memory=config.DATA.PIN_MEMORY,
54 | drop_last=False
55 | )
56 |
57 | # setup mixup / cutmix
58 | mixup_fn = None
59 | mixup_active = config.AUG.MIXUP > 0 or config.AUG.CUTMIX > 0. or config.AUG.CUTMIX_MINMAX is not None
60 | if mixup_active:
61 | mixup_fn = Mixup(
62 | mixup_alpha=config.AUG.MIXUP, cutmix_alpha=config.AUG.CUTMIX, cutmix_minmax=config.AUG.CUTMIX_MINMAX,
63 | prob=config.AUG.MIXUP_PROB, switch_prob=config.AUG.MIXUP_SWITCH_PROB, mode=config.AUG.MIXUP_MODE,
64 | label_smoothing=config.MODEL.LABEL_SMOOTHING, num_classes=config.MODEL.NUM_CLASSES)
65 |
66 | return dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn
67 |
68 |
69 | def build_dataset(is_train, config):
70 | transform = build_transform(is_train, config)
71 | if config.DATA.DATASET == 'imagenet':
72 | prefix = 'train' if is_train else 'val'
73 | if config.DATA.ZIP_MODE:
74 | ann_file = prefix + "_map.txt"
75 | prefix = prefix + ".zip@/"
76 | dataset = CachedImageFolder(config.DATA.DATA_PATH, ann_file, prefix, transform,
77 | cache_mode=config.DATA.CACHE_MODE if is_train else 'part')
78 | else:
79 | root = os.path.join(config.DATA.DATA_PATH, prefix)
80 | dataset = datasets.ImageFolder(root, transform=transform)
81 | nb_classes = 1000
82 | else:
83 | raise NotImplementedError("We only support ImageNet Now.")
84 |
85 | return dataset, nb_classes
86 |
87 |
88 | def build_transform(is_train, config):
89 | resize_im = config.DATA.IMG_SIZE > 32
90 | if is_train:
91 | # this should always dispatch to transforms_imagenet_train
92 | transform = create_transform(
93 | input_size=config.DATA.IMG_SIZE,
94 | is_training=True,
95 | color_jitter=config.AUG.COLOR_JITTER if config.AUG.COLOR_JITTER > 0 else None,
96 | auto_augment=config.AUG.AUTO_AUGMENT if config.AUG.AUTO_AUGMENT != 'none' else None,
97 | re_prob=config.AUG.REPROB,
98 | re_mode=config.AUG.REMODE,
99 | re_count=config.AUG.RECOUNT,
100 | interpolation=config.DATA.INTERPOLATION,
101 | )
102 | if not resize_im:
103 | # replace RandomResizedCropAndInterpolation with
104 | # RandomCrop
105 | transform.transforms[0] = transforms.RandomCrop(config.DATA.IMG_SIZE, padding=4)
106 | return transform
107 |
108 | t = []
109 | if resize_im:
110 | if config.TEST.CROP:
111 | size = int((256 / 224) * config.DATA.IMG_SIZE)
112 | t.append(
113 | transforms.Resize((size, size), interpolation=_pil_interp(config.DATA.INTERPOLATION)),
114 | # to maintain same ratio w.r.t. 224 images
115 | )
116 | t.append(transforms.CenterCrop(config.DATA.IMG_SIZE))
117 | else:
118 | t.append(
119 | transforms.Resize((config.DATA.IMG_SIZE, config.DATA.IMG_SIZE),
120 | interpolation=_pil_interp(config.DATA.INTERPOLATION))
121 | )
122 |
123 | t.append(transforms.ToTensor())
124 | t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))
125 | print(t)
126 | return transforms.Compose(t)
127 |
--------------------------------------------------------------------------------
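
The validation sampler above hands each rank a strided slice of the dataset, `np.arange(rank, N, world_size)`, which `SubsetRandomSampler` then shuffles. A toy run (illustrative numbers) shows the slices are disjoint and cover every index, though ranks receive unequal counts when `world_size` does not divide `N`:

```python
import numpy as np

# N=10 validation samples split across 4 ranks.
for rank in range(4):
    print(rank, np.arange(rank, 10, 4))
# 0 [0 4 8]
# 1 [1 5 9]
# 2 [2 6]
# 3 [3 7]
```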
/README.md:
--------------------------------------------------------------------------------
1 | # Bridging the Divide: Reconsidering Softmax and Linear Attention
2 |
3 | This repo contains the official PyTorch code and pre-trained models for **Injective Linear Attention (InLine)**.
4 |
5 | + [Bridging the divide: Reconsidering softmax and linear attention](https://arxiv.org/abs/2412.06590) [[中文讲解]](https://www.bilibili.com/video/BV1BAqCYnEag)
6 |
7 |
8 |
9 | ## News
10 |
11 | - November 12, 2024: Repository initialized.
12 |
13 | ## Abstract
14 |
15 | Widely adopted in modern Vision Transformer designs, Softmax attention can effectively capture long-range visual information; however, it incurs excessive computational cost when dealing with high-resolution inputs. In contrast, linear attention naturally enjoys linear complexity and has great potential to scale up to higher-resolution images. Nonetheless, the unsatisfactory performance of linear attention greatly limits its practical application in various scenarios. In this paper, we take a step forward to close the gap between linear and Softmax attention with novel theoretical analyses, which demystify the core factors behind the performance deviations. Specifically, we present two key perspectives to understand and alleviate the limitations of linear attention: the **injective property** and the **local modeling ability**. Firstly, we prove that linear attention is not injective and is prone to assign identical attention weights to different query vectors, causing severe semantic confusion since different queries then correspond to the same outputs. Secondly, we confirm that effective local modeling is essential for the success of Softmax attention, and that linear attention falls short here. These two fundamental differences account for much of the disparity between the two attention paradigms, as demonstrated by our substantial empirical validation in the paper. In addition, further experimental results indicate that linear attention, once endowed with these two properties, can outperform Softmax attention across various tasks while maintaining lower computational complexity.
16 |
17 | ## Injectivity of Attention Function
18 |
19 | We find that the injectivity of the attention function greatly affects the performance of the model. Specifically, *if the attention function is not injective, different queries will induce identical attention distributions, leading to severe semantic confusion within the feature space.* We prove that the Softmax attention function is injective, whereas the linear attention function is not. Therefore, linear attention is vulnerable to the semantic confusion problem, which largely explains its insufficient expressiveness.
20 |
21 | 
22 | ![Injectivity of Softmax and linear attention functions](figures/fig1_injectivity.png)
23 | 
24 | 
25 | Our method, **Injective Linear Attention (InLine)**:
26 |
27 | $$\mathrm{InL_K}(Q_i) = {\left[
28 | \phi(Q_i)^\top\phi(K_1),
29 | \cdots,
30 | \phi(Q_i)^\top\phi(K_N)
31 | \right]}^\top - \frac{1}{N}\sum_{s=1}^{N} \phi(Q_i)^\top\phi(K_s) + \frac{1}{N}.$$
32 |
33 |
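
A minimal PyTorch sketch of this attention function (our illustration, not the repo's implementation; single head, with `phi` standing in for the kernel feature map):

```python
import torch
import torch.nn.functional as F

def inline_attention(q, k, v, phi=F.relu):
    """InLine attention for one head. q, k, v: (N, d)."""
    scores = phi(q) @ phi(k).transpose(-2, -1)  # raw linear scores, (N, N)
    n = k.shape[-2]
    # Subtract the per-query mean over keys and add 1/N: every row of the
    # attention matrix then sums to exactly 1, as in Softmax attention.
    attn = scores - scores.mean(dim=-1, keepdim=True) + 1.0 / n
    return attn @ v
```

The quadratic `(N, N)` form above is written for clarity; the linear complexity comes from reassociating the products so that keys and values are aggregated before interacting with the queries.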
34 | ## Results
35 |
36 | - ImageNet-1K results.
37 |
38 | 
39 | ![ImageNet-1K classification results](figures/fig2_cls.png)
40 | 
41 | 
42 |
43 | - Real speed measurements. Benefiting from linear complexity and a simple design, our InLine attention delivers much higher inference speed than Softmax attention, especially in high-resolution scenarios.
44 |
45 | 
46 | ![Inference speed comparison](figures/fig3_speed.png)
47 | 
48 | 
49 |
50 | ## Dependencies
51 |
52 | - Python 3.9
53 | - PyTorch == 1.11.0
54 | - torchvision == 0.12.0
55 | - numpy
56 | - timm == 0.4.12
57 | - yacs
58 |
59 | The ImageNet dataset should be prepared as follows:
60 |
61 | ```
62 | imagenet
63 | ├── train
64 | │ ├── class1
65 | │ │ ├── img1.jpeg
66 | │ │ └── ...
67 | │ ├── class2
68 | │ │ ├── img2.jpeg
69 | │ │ └── ...
70 | │ └── ...
71 | └── val
72 | ├── class1
73 | │ ├── img3.jpeg
74 | │ └── ...
75 | ├── class2
76 | │ ├── img4.jpeg
77 | │ └── ...
78 | └── ...
79 | ```
80 |
81 | ## Pretrained Models
82 |
83 | | model | Resolution | #Params | FLOPs | acc@1 | config | pretrained weights |
84 | | ------ | :--------: | :-----: | :---: | :---: | :--------------------------: | :----------------------------------------------------------: |
85 | | InLine-DeiT-T | 224 | 6.5M | 1.1G | 74.5 | [config](./cfgs/inline_deit_t.yaml) | [TsinghuaCloud](https://cloud.tsinghua.edu.cn/f/1d6b8191ad6d4114b291/?dl=1) |
86 | | InLine-DeiT-S | 288 | 16.7M | 5.0G | 80.2 | [config](./cfgs/inline_deit_s.yaml) | [TsinghuaCloud](https://cloud.tsinghua.edu.cn/f/2f3898b07e9247f3beb3/?dl=1) |
87 | | InLine-DeiT-B | 448 | 23.8M | 17.2G | 82.3 | [config](./cfgs/inline_deit_b.yaml) | [TsinghuaCloud](https://cloud.tsinghua.edu.cn/f/10bdd726d4b0435eb34e/?dl=1) |
88 | | InLine-PVT-T | 224 | 12.0M | 2.0G | 78.2 | [config](./cfgs/inline_pvt_t.yaml) | [TsinghuaCloud](https://cloud.tsinghua.edu.cn/f/90ac52b1555b401eb5e6/?dl=1) |
89 | | InLine-PVT-S | 224 | 21.6M | 3.9G | 82.0 | [config](./cfgs/inline_pvt_s.yaml) | [TsinghuaCloud](https://cloud.tsinghua.edu.cn/f/1ab953b2479d433080a3/?dl=1) |
90 | | InLine-PVT-M | 224 | 37.6M | 6.9G | 83.2 | [config](./cfgs/inline_pvt_m.yaml) | [TsinghuaCloud](https://cloud.tsinghua.edu.cn/f/a72aec31e6084bc0a280/?dl=1) |
91 | | InLine-PVT-L | 224 | 50.2M | 10.2G | 83.6 | [config](./cfgs/inline_pvt_b.yaml) | [TsinghuaCloud](https://cloud.tsinghua.edu.cn/f/efd91318ba964f01b288/?dl=1) |
92 | | InLine-Swin-T | 224 | 30M | 4.5G | 82.4 | [config](./cfgs/inline_swin_t.yaml) | [TsinghuaCloud](https://cloud.tsinghua.edu.cn/f/32810869fcc34410966b/?dl=1) |
93 | | InLine-Swin-S | 224 | 50M | 8.7G | 83.6 | [config](./cfgs/inline_swin_s.yaml) | [TsinghuaCloud](https://cloud.tsinghua.edu.cn/f/e9657fd247c04c7cb1a1/?dl=1) |
94 | | InLine-Swin-B | 224 | 88M | 15.4G | 84.1 | [config](./cfgs/inline_swin_b.yaml) | [TsinghuaCloud](https://cloud.tsinghua.edu.cn/f/bf23564bb64c420aafe1/?dl=1) |
95 | | InLine-CSwin-T | 224 | 25M | 4.3G | 83.2 | [config](./cfgs/inline_cswin_t.yaml) | [TsinghuaCloud](https://cloud.tsinghua.edu.cn/f/49fa2ecc543647c4b970/?dl=1) |
96 | | InLine-CSwin-S | 224 | 43M | 6.8G | 83.8 | [config](./cfgs/inline_cswin_s.yaml) | [TsinghuaCloud](https://cloud.tsinghua.edu.cn/f/59f4e65f776f4052b93c/?dl=1) |
97 | | InLine-CSwin-B | 224 | 96M | 14.9G | 84.5 | [config](./cfgs/inline_cswin_b.yaml) | [TsinghuaCloud](https://cloud.tsinghua.edu.cn/f/91e17121df284ae38521/?dl=1) |
98 |
99 | ## Model Training and Inference
100 |
101 | - Evaluate InLine-DeiT/PVT/Swin on ImageNet:
102 |
103 | ```
104 | python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg <config-file> --data-path <imagenet-path> --output <output-folder> --eval --resume <checkpoint>
105 | ```
106 |
107 | - To train InLine-DeiT/PVT/Swin on ImageNet from scratch, run:
108 |
109 | ```
110 | python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg <config-file> --data-path <imagenet-path> --output <output-folder> --amp
111 | ```
112 |
113 | - Evaluate InLine-CSwin on ImageNet:
114 |
115 | ```
116 | python -m torch.distributed.launch --nproc_per_node=8 main_ema.py --cfg <config-file> --data-path <imagenet-path> --output <output-folder> --eval --resume <checkpoint>
117 | ```
118 |
119 | - To train InLine-CSwin on ImageNet from scratch, run:
120 |
121 | ```
122 | python -m torch.distributed.launch --nproc_per_node=8 main_ema.py --cfg <config-file> --data-path <imagenet-path> --output <output-folder> --amp
123 | ```
124 |
125 | ## Acknowledgements
126 |
127 | This code is developed on top of [Swin Transformer](https://github.com/microsoft/Swin-Transformer).
128 |
129 | ## Citation
130 |
131 | If you find this repo helpful, please consider citing us.
132 |
133 | ```latex
134 | @inproceedings{han2024inline,
135 |   title={Bridging the Divide: Reconsidering Softmax and Linear Attention},
136 |   author={Han, Dongchen and Pu, Yifan and Xia, Zhuofan and Han, Yizeng and Pan, Xuran and Li, Xiu and Lu, Jiwen and Song, Shiji and Huang, Gao},
137 |   booktitle={NeurIPS},
138 |   year={2024},
139 | }
141 | ```
142 |
143 | ## Contact
144 |
145 | If you have any questions, please feel free to contact the authors.
146 |
147 | Dongchen Han: [hdc23@mails.tsinghua.edu.cn](mailto:hdc23@mails.tsinghua.edu.cn)
148 |
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Swin Transformer
3 | # Copyright (c) 2021 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Ze Liu
6 | # --------------------------------------------------------
7 |
8 | import os
9 | import yaml
10 | from yacs.config import CfgNode as CN
11 |
12 | _C = CN()
13 |
14 | # Base config files
15 | _C.BASE = ['']
16 |
17 | # -----------------------------------------------------------------------------
18 | # Data settings
19 | # -----------------------------------------------------------------------------
20 | _C.DATA = CN()
21 | # Batch size for a single GPU, could be overwritten by command line argument
22 | _C.DATA.BATCH_SIZE = 128
23 | # Path to dataset, could be overwritten by command line argument
24 | _C.DATA.DATA_PATH = ''
25 | # Dataset name
26 | _C.DATA.DATASET = 'imagenet'
27 | # Input image size
28 | _C.DATA.IMG_SIZE = 224
29 | # Interpolation to resize image (random, bilinear, bicubic)
30 | _C.DATA.INTERPOLATION = 'bicubic'
31 | # Use zipped dataset instead of folder dataset
32 | # could be overwritten by command line argument
33 | _C.DATA.ZIP_MODE = False
34 | # Cache Data in Memory, could be overwritten by command line argument
35 | _C.DATA.CACHE_MODE = 'part'
36 | # Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.
37 | _C.DATA.PIN_MEMORY = True
38 | # Number of data loading threads
39 | _C.DATA.NUM_WORKERS = 8
40 |
41 | # -----------------------------------------------------------------------------
42 | # Model settings
43 | # -----------------------------------------------------------------------------
44 | _C.MODEL = CN()
45 | # Model type
46 | _C.MODEL.TYPE = 'swin'
47 | # Model name
48 | _C.MODEL.NAME = 'swin_tiny_patch4_window7_224'
49 | # Checkpoint to resume, could be overwritten by command line argument
50 | _C.MODEL.RESUME = ''
51 | # Number of classes, overwritten in data preparation
52 | _C.MODEL.NUM_CLASSES = 1000
53 | # Dropout rate
54 | _C.MODEL.DROP_RATE = 0.0
55 | # Drop path rate
56 | _C.MODEL.DROP_PATH_RATE = 0.1
57 | # Label Smoothing
58 | _C.MODEL.LABEL_SMOOTHING = 0.1
59 |
60 | # Swin Transformer parameters
61 | _C.MODEL.SWIN = CN()
62 | _C.MODEL.SWIN.PATCH_SIZE = 4
63 | _C.MODEL.SWIN.IN_CHANS = 3
64 | _C.MODEL.SWIN.EMBED_DIM = 96
65 | _C.MODEL.SWIN.DEPTHS = [2, 2, 6, 2]
66 | _C.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24]
67 | _C.MODEL.SWIN.WINDOW_SIZE = 7
68 | _C.MODEL.SWIN.MLP_RATIO = 4.
69 | _C.MODEL.SWIN.QKV_BIAS = True
70 | _C.MODEL.SWIN.QK_SCALE = None
71 | _C.MODEL.SWIN.KA = [7, 7, 7, 7]
72 | _C.MODEL.SWIN.DIM_REDUCTION = [4, 4, 4, 4]
73 | _C.MODEL.SWIN.STAGES = [True, True, True, True]
74 | _C.MODEL.SWIN.STAGES_NUM = [-1, -1, -1, -1]
75 | _C.MODEL.SWIN.RPB = True
76 | _C.MODEL.SWIN.PADDING_MODE = 'zeros'
77 | _C.MODEL.SWIN.SHARE_DWC_KERNEL = True
78 | _C.MODEL.SWIN.SHARE_QKV = False
79 | _C.MODEL.SWIN.APE = False
80 | _C.MODEL.SWIN.PATCH_NORM = True
81 | _C.MODEL.SWIN.LR_FACTOR = 2
82 | _C.MODEL.SWIN.DEPTHS_LR = [2, 2, 2, 2]
83 | _C.MODEL.SWIN.FUSION_TYPE = 'add'
84 | _C.MODEL.SWIN.STAGE_CFG = None
85 |
86 | _C.MODEL.SWIN_HR = CN(new_allowed=True)
87 | _C.MODEL.SWIN_LRVIT = CN(new_allowed=True)
88 | _C.MODEL.PVD = CN(new_allowed=True)
89 |
90 | # -----------------------------------------------------------------------------
91 | # InLine Attention options
92 | # -----------------------------------------------------------------------------
93 | _C.MODEL.INLINE = CN()
94 | _C.MODEL.INLINE.ATTN_TYPE = 'IIII'
95 | _C.MODEL.INLINE.PVT_LA_SR_RATIOS = 1111
96 | _C.MODEL.INLINE.CSWIN_LA_SPLIT_SIZE = '56-28-14-7'
97 |
98 | # -----------------------------------------------------------------------------
99 | # Training settings
100 | # -----------------------------------------------------------------------------
101 | _C.TRAIN = CN()
102 | _C.TRAIN.START_EPOCH = 0
103 | _C.TRAIN.EPOCHS = 300
104 | _C.TRAIN.WARMUP_EPOCHS = 20
105 | _C.TRAIN.COOLDOWN_EPOCHS = 0
106 | _C.TRAIN.WEIGHT_DECAY = 0.05
107 | _C.TRAIN.BASE_LR = 5e-4
108 | _C.TRAIN.WARMUP_LR = 5e-7
109 | _C.TRAIN.MIN_LR = 5e-6
110 | # Clip gradient norm
111 | _C.TRAIN.CLIP_GRAD = 5.0
112 | # Auto resume from latest checkpoint
113 | _C.TRAIN.AUTO_RESUME = True
114 | # Whether to use gradient checkpointing to save memory
115 | # could be overwritten by command line argument
116 | _C.TRAIN.USE_CHECKPOINT = False
117 |
118 | # LR scheduler
119 | _C.TRAIN.LR_SCHEDULER = CN()
120 | _C.TRAIN.LR_SCHEDULER.NAME = 'cosine'
121 | # Epoch interval to decay LR, used in StepLRScheduler
122 | _C.TRAIN.LR_SCHEDULER.DECAY_EPOCHS = 30
123 | # LR decay rate, used in StepLRScheduler
124 | _C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1
125 |
126 | # Optimizer
127 | _C.TRAIN.OPTIMIZER = CN()
128 | _C.TRAIN.OPTIMIZER.NAME = 'adamw'
129 | # Optimizer Epsilon
130 | _C.TRAIN.OPTIMIZER.EPS = 1e-8
131 | # Optimizer Betas
132 | _C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999)
133 | # SGD momentum
134 | _C.TRAIN.OPTIMIZER.MOMENTUM = 0.9
135 |
136 | # -----------------------------------------------------------------------------
137 | # Augmentation settings
138 | # -----------------------------------------------------------------------------
139 | _C.AUG = CN()
140 | # Color jitter factor
141 | _C.AUG.COLOR_JITTER = 0.4
142 | # Use AutoAugment policy. "v0" or "original"
143 | _C.AUG.AUTO_AUGMENT = 'rand-m9-mstd0.5-inc1'
144 | # Random erase prob
145 | _C.AUG.REPROB = 0.25
146 | # Random erase mode
147 | _C.AUG.REMODE = 'pixel'
148 | # Random erase count
149 | _C.AUG.RECOUNT = 1
150 | # Mixup alpha, mixup enabled if > 0
151 | _C.AUG.MIXUP = 0.8
152 | # Cutmix alpha, cutmix enabled if > 0
153 | _C.AUG.CUTMIX = 1.0
154 | # Cutmix min/max ratio, overrides alpha and enables cutmix if set
155 | _C.AUG.CUTMIX_MINMAX = None
156 | # Probability of performing mixup or cutmix when either/both is enabled
157 | _C.AUG.MIXUP_PROB = 1.0
158 | # Probability of switching to cutmix when both mixup and cutmix enabled
159 | _C.AUG.MIXUP_SWITCH_PROB = 0.5
160 | # How to apply mixup/cutmix params. Per "batch", "pair", or "elem"
161 | _C.AUG.MIXUP_MODE = 'batch'
162 |
163 | # -----------------------------------------------------------------------------
164 | # Testing settings
165 | # -----------------------------------------------------------------------------
166 | _C.TEST = CN()
167 | # Whether to use center crop when testing
168 | _C.TEST.CROP = True
169 |
170 | # -----------------------------------------------------------------------------
171 | # Misc
172 | # -----------------------------------------------------------------------------
173 |
174 | # overwritten by command line argument
175 | _C.AMP = False
176 | # Path to output folder, overwritten by command line argument
177 | _C.OUTPUT = ''
178 | # Tag of experiment, overwritten by command line argument
179 | _C.TAG = 'default'
180 | # Frequency to save checkpoint
181 | _C.SAVE_FREQ = 1
182 | # Frequency to logging info
183 | _C.PRINT_FREQ = 100
184 | # Fixed random seed
185 | _C.SEED = 0
186 | # Perform evaluation only, overwritten by command line argument
187 | _C.EVAL_MODE = False
188 | # Test throughput only, overwritten by command line argument
189 | _C.THROUGHPUT_MODE = False
190 | # local rank for DistributedDataParallel, given by command line argument
191 | _C.LOCAL_RANK = 0
192 |
193 |
194 | def _update_config_from_file(config, cfg_file):
195 | config.defrost()
196 | with open(cfg_file, 'r') as f:
197 | yaml_cfg = yaml.load(f, Loader=yaml.FullLoader)
198 |
199 | for cfg in yaml_cfg.setdefault('BASE', ['']):
200 | if cfg:
201 | _update_config_from_file(
202 | config, os.path.join(os.path.dirname(cfg_file), cfg)
203 | )
204 | print('=> merge config from {}'.format(cfg_file))
205 | config.merge_from_file(cfg_file)
206 | config.freeze()
207 |
208 |
209 | def update_config(config, args):
210 | _update_config_from_file(config, args.cfg)
211 |
212 | config.defrost()
213 | if args.opts:
214 | config.merge_from_list(args.opts)
215 |
216 | # merge from specific arguments
217 | if args.batch_size:
218 | config.DATA.BATCH_SIZE = args.batch_size
219 | if args.data_path:
220 | config.DATA.DATA_PATH = args.data_path
221 | if args.zip:
222 | config.DATA.ZIP_MODE = True
223 | if args.cache_mode:
224 | config.DATA.CACHE_MODE = args.cache_mode
225 | if args.resume:
226 | config.MODEL.RESUME = args.resume
227 | if args.use_checkpoint:
228 | config.TRAIN.USE_CHECKPOINT = True
229 | if args.amp:
230 | config.AMP = args.amp
231 | if args.output:
232 | config.OUTPUT = args.output
233 | if args.tag:
234 | config.TAG = args.tag
235 | if args.eval:
236 | config.EVAL_MODE = True
237 | if args.throughput:
238 | config.THROUGHPUT_MODE = True
239 |
240 | # set local rank for distributed training
241 | # config.LOCAL_RANK = args.local_rank
242 |
243 | # output folder
244 | config.OUTPUT = os.path.join(config.OUTPUT, config.MODEL.NAME, config.TAG)
245 |
246 | config.freeze()
247 |
248 |
249 | def get_config(args):
250 | """Get a yacs CfgNode object with default values."""
251 | # Return a clone so that the defaults will not be altered
252 | # This is for the "local variable" use pattern
253 | config = _C.clone()
254 | update_config(config, args)
255 |
256 | return config
257 |
--------------------------------------------------------------------------------
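
`get_config` expects an argparse-style namespace carrying the attributes that `update_config` reads; a minimal sketch (attribute set inferred from `update_config` above, config path illustrative):

```python
import argparse
from config import get_config

# A bare namespace with every attribute update_config touches.
args = argparse.Namespace(
    cfg='cfgs/inline_swin_t.yaml', opts=None, batch_size=None, data_path=None,
    zip=False, cache_mode=None, resume=None, use_checkpoint=False, amp=False,
    output='output', tag=None, eval=False, throughput=False)
config = get_config(args)
print(config.MODEL.NAME, config.MODEL.INLINE.ATTN_TYPE)  # inline_swin_tiny IIIS
```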
/utils.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Swin Transformer
3 | # Copyright (c) 2021 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Ze Liu
6 | # --------------------------------------------------------
7 |
8 | import os
9 | import torch
10 | import torch.distributed as dist
11 |
12 |
13 | def load_checkpoint(config, model, optimizer, lr_scheduler, logger):
14 | logger.info(f"==============> Resuming form {config.MODEL.RESUME}....................")
15 | if config.MODEL.RESUME.startswith('https'):
16 | checkpoint = torch.hub.load_state_dict_from_url(
17 | config.MODEL.RESUME, map_location='cpu', check_hash=True)
18 | else:
19 | checkpoint = torch.load(config.MODEL.RESUME, map_location='cpu')
20 | msg = model.load_state_dict(checkpoint['model'], strict=False)
21 | logger.info(msg)
22 | max_accuracy = 0.0
23 | if not config.EVAL_MODE and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
24 | optimizer.load_state_dict(checkpoint['optimizer'])
25 | lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
26 | config.defrost()
27 | config.TRAIN.START_EPOCH = checkpoint['epoch']
28 | config.freeze()
29 | logger.info(f"=> loaded successfully '{config.MODEL.RESUME}' (epoch {checkpoint['epoch']})")
30 | if 'max_accuracy' in checkpoint:
31 | max_accuracy = checkpoint['max_accuracy']
32 |
33 | del checkpoint
34 | torch.cuda.empty_cache()
35 | return max_accuracy
36 |
37 |
38 | def load_pretrained(ckpt_path, model, logger):
39 | logger.info(f"==============> Loading pretrained form {ckpt_path}....................")
40 | checkpoint = torch.load(ckpt_path, map_location='cpu')
46 | state_dict = checkpoint['model'] if 'model' in checkpoint.keys() else checkpoint
47 |
48 | # delete relative_position_index since we always re-init it
49 | relative_position_index_keys = [k for k in state_dict.keys() if "relative_position_index" in k]
50 | for k in relative_position_index_keys:
51 | del state_dict[k]
52 |
53 | # delete relative_coords_table since we always re-init it
54 | relative_position_index_keys = [k for k in state_dict.keys() if "relative_coords_table" in k]
55 | for k in relative_position_index_keys:
56 | del state_dict[k]
57 |
58 | # delete attn_mask since we always re-init it
59 | attn_mask_keys = [k for k in state_dict.keys() if "attn_mask" in k]
60 | for k in attn_mask_keys:
61 | del state_dict[k]
62 |
63 | # bicubic interpolate relative_position_bias_table if not match
64 | relative_position_bias_table_keys = [k for k in state_dict.keys() if "relative_position_bias_table" in k]
65 | for k in relative_position_bias_table_keys:
66 | relative_position_bias_table_pretrained = state_dict[k]
67 | relative_position_bias_table_current = model.state_dict()[k]
68 | L1, nH1 = relative_position_bias_table_pretrained.size()
69 | L2, nH2 = relative_position_bias_table_current.size()
70 | if nH1 != nH2:
71 | logger.warning(f"Error in loading {k}, passing......")
72 | else:
73 | if L1 != L2:
74 | # bicubic interpolate relative_position_bias_table if not match
75 | S1 = int(L1 ** 0.5)
76 | S2 = int(L2 ** 0.5)
77 | relative_position_bias_table_pretrained_resized = torch.nn.functional.interpolate(
78 | relative_position_bias_table_pretrained.permute(1, 0).view(1, nH1, S1, S1), size=(S2, S2),
79 | mode='bicubic')
80 | state_dict[k] = relative_position_bias_table_pretrained_resized.view(nH2, L2).permute(1, 0)
81 |
82 | # bicubic interpolate absolute_pos_embed if not match
83 | absolute_pos_embed_keys = [k for k in state_dict.keys() if "pos_embed" in k]
84 | for k in absolute_pos_embed_keys:
85 | # dpe
86 | absolute_pos_embed_pretrained = state_dict[k]
87 | absolute_pos_embed_current = model.state_dict()[k]
88 | _, L1, C1 = absolute_pos_embed_pretrained.size()
89 | _, L2, C2 = absolute_pos_embed_current.size()
90 |         if C1 != C2:
91 | logger.warning(f"Error in loading {k}, passing......")
92 | else:
93 | if L1 != L2:
94 | S1 = int(L1 ** 0.5)
95 | S2 = int(L2 ** 0.5)
96 | i, j = L1 - S1 ** 2, L2 - S2 ** 2
97 | absolute_pos_embed_pretrained_ = absolute_pos_embed_pretrained[:, i:, :].reshape(-1, S1, S1, C1)
98 | absolute_pos_embed_pretrained_ = absolute_pos_embed_pretrained_.permute(0, 3, 1, 2)
99 | absolute_pos_embed_pretrained_resized = torch.nn.functional.interpolate(
100 | absolute_pos_embed_pretrained_, size=(S2, S2), mode='bicubic')
101 | absolute_pos_embed_pretrained_resized = absolute_pos_embed_pretrained_resized.permute(0, 2, 3, 1)
102 | absolute_pos_embed_pretrained_resized = absolute_pos_embed_pretrained_resized.flatten(1, 2)
103 | state_dict[k] = torch.cat([absolute_pos_embed_pretrained[:, :j, :],
104 | absolute_pos_embed_pretrained_resized], dim=1)
105 |
106 | # check classifier, if not match, then re-init classifier to zero
107 | head_bias_pretrained = state_dict['head.bias']
108 | Nc1 = head_bias_pretrained.shape[0]
109 | Nc2 = model.head.bias.shape[0]
110 | if (Nc1 != Nc2):
111 | if Nc1 == 21841 and Nc2 == 1000:
112 | logger.info("loading ImageNet-22K weight to ImageNet-1K ......")
113 |             map22kto1k_path = 'data/map22kto1k.txt'
114 | with open(map22kto1k_path) as f:
115 | map22kto1k = f.readlines()
116 | map22kto1k = [int(id22k.strip()) for id22k in map22kto1k]
117 | state_dict['head.weight'] = state_dict['head.weight'][map22kto1k, :]
118 | state_dict['head.bias'] = state_dict['head.bias'][map22kto1k]
119 | else:
120 | torch.nn.init.constant_(model.head.bias, 0.)
121 | torch.nn.init.constant_(model.head.weight, 0.)
122 | del state_dict['head.weight']
123 | del state_dict['head.bias']
124 | logger.warning(f"Error in loading classifier head, re-init classifier head to 0")
125 |
126 | msg = model.load_state_dict(state_dict, strict=False)
127 | logger.warning(msg)
128 |
129 | logger.info(f"=> loaded successfully '{ckpt_path}'")
130 |
131 | del checkpoint
132 | torch.cuda.empty_cache()
133 |
134 |
135 | def save_checkpoint(config, epoch, model, max_accuracy, optimizer, lr_scheduler, logger):
136 | save_state = {'model': model.state_dict(),
137 | 'optimizer': optimizer.state_dict(),
138 | 'lr_scheduler': lr_scheduler.state_dict(),
139 | 'max_accuracy': max_accuracy,
140 | 'epoch': epoch,
141 | 'config': config}
142 |
143 | save_path = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth')
144 | logger.info(f"{save_path} saving......")
145 | torch.save(save_state, save_path)
146 | logger.info(f"{save_path} saved !!!")
147 |
148 | def save_checkpoint_new(config, epoch, model, max_accuracy, optimizer, lr_scheduler, logger, name=None):
149 | save_state = {'model': model.state_dict(),
150 | 'optimizer': optimizer.state_dict(),
151 | 'lr_scheduler': lr_scheduler.state_dict(),
152 | 'max_accuracy': max_accuracy,
153 | 'epoch': epoch,
154 | 'config': config}
155 |     if name is None:
156 | old_ckpt = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch-3}.pth')
157 | if os.path.exists(old_ckpt):
158 | os.remove(old_ckpt)
159 |
160 |     if name is not None:
161 | save_path = os.path.join(config.OUTPUT, f'{name}.pth')
162 | logger.info(f"{save_path} saving......")
163 | torch.save(save_state, save_path)
164 | logger.info(f"{save_path} saved !!!")
165 | else:
166 | save_path = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth')
167 | logger.info(f"{save_path} saving......")
168 | torch.save(save_state, save_path)
169 | logger.info(f"{save_path} saved !!!")
170 |
171 |
172 | def get_grad_norm(parameters, norm_type=2):
173 | if isinstance(parameters, torch.Tensor):
174 | parameters = [parameters]
175 | parameters = list(filter(lambda p: p.grad is not None, parameters))
176 | norm_type = float(norm_type)
177 | total_norm = 0
178 | for p in parameters:
179 | param_norm = p.grad.data.norm(norm_type)
180 | total_norm += param_norm.item() ** norm_type
181 | total_norm = total_norm ** (1. / norm_type)
182 | return total_norm
183 |
184 |
185 | def auto_resume_helper(output_dir):
186 | checkpoints = os.listdir(output_dir)
187 | checkpoints = [ckpt for ckpt in checkpoints if ckpt.endswith('pth')]
188 | print(f"All checkpoints founded in {output_dir}: {checkpoints}")
189 | if len(checkpoints) > 0:
190 | latest_checkpoint = max([os.path.join(output_dir, d) for d in checkpoints], key=os.path.getmtime)
191 | print(f"The latest checkpoint founded: {latest_checkpoint}")
192 | resume_file = latest_checkpoint
193 | else:
194 | resume_file = None
195 | return resume_file
196 |
197 |
198 | def reduce_tensor(tensor):
199 | rt = tensor.clone()
200 | dist.all_reduce(rt, op=dist.ReduceOp.SUM)
201 | rt /= dist.get_world_size()
202 | return rt
203 |
--------------------------------------------------------------------------------
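
`get_grad_norm` above reports the same total norm that `torch.nn.utils.clip_grad_norm_` returns (before any clipping takes effect); a quick check with a toy tensor (our illustration):

```python
import torch
from utils import get_grad_norm

p = torch.nn.Parameter(torch.ones(3))
(p.sum() * 2).backward()                 # grad = [2, 2, 2]
print(get_grad_norm([p]))                # sqrt(3 * 2**2) ≈ 3.4641
print(torch.nn.utils.clip_grad_norm_([p], max_norm=1e9).item())  # same value
```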
/data/cached_image_folder.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Swin Transformer
3 | # Copyright (c) 2021 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Ze Liu
6 | # --------------------------------------------------------
7 |
8 | import io
9 | import os
10 | import time
11 | import torch.distributed as dist
12 | import torch.utils.data as data
13 | from PIL import Image
14 |
15 | from .zipreader import is_zip_path, ZipReader
16 |
17 |
18 | def has_file_allowed_extension(filename, extensions):
19 | """Checks if a file is an allowed extension.
20 | Args:
21 | filename (string): path to a file
22 | Returns:
23 | bool: True if the filename ends with a known image extension
24 | """
25 | filename_lower = filename.lower()
26 | return any(filename_lower.endswith(ext) for ext in extensions)
27 |
28 |
29 | def find_classes(dir):
30 | classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
31 | classes.sort()
32 | class_to_idx = {classes[i]: i for i in range(len(classes))}
33 | return classes, class_to_idx
34 |
35 |
36 | def make_dataset(dir, class_to_idx, extensions):
37 | images = []
38 | dir = os.path.expanduser(dir)
39 | for target in sorted(os.listdir(dir)):
40 | d = os.path.join(dir, target)
41 | if not os.path.isdir(d):
42 | continue
43 |
44 | for root, _, fnames in sorted(os.walk(d)):
45 | for fname in sorted(fnames):
46 | if has_file_allowed_extension(fname, extensions):
47 | path = os.path.join(root, fname)
48 | item = (path, class_to_idx[target])
49 | images.append(item)
50 |
51 | return images
52 |
53 |
54 | def make_dataset_with_ann(ann_file, img_prefix, extensions):
55 | images = []
56 | with open(ann_file, "r") as f:
57 | contents = f.readlines()
58 | for line_str in contents:
59 | path_contents = line_str.split('\t')
60 | im_file_name = path_contents[0]
61 | class_index = int(path_contents[1])
62 |
63 | assert os.path.splitext(im_file_name)[-1].lower() in extensions
64 | item = (os.path.join(img_prefix, im_file_name), class_index)
65 |
66 | images.append(item)
67 |
68 | return images
69 |
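# Editorial note, inferred from the parser above: the annotation file holds one
# tab-separated record per line, a path relative to img_prefix followed by an
# integer class index (the file names below are hypothetical):
#
#   class_x/img_0001.jpeg<TAB>0
#   class_y/img_0002.jpeg<TAB>1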
70 |
71 | class DatasetFolder(data.Dataset):
72 | """A generic data loader where the samples are arranged in this way: ::
73 | root/class_x/xxx.ext
74 | root/class_x/xxy.ext
75 | root/class_x/xxz.ext
76 | root/class_y/123.ext
77 | root/class_y/nsdf3.ext
78 | root/class_y/asd932_.ext
79 | Args:
80 | root (string): Root directory path.
81 | loader (callable): A function to load a sample given its path.
82 | extensions (list[string]): A list of allowed extensions.
83 | transform (callable, optional): A function/transform that takes in
84 | a sample and returns a transformed version.
85 | E.g., ``transforms.RandomCrop`` for images.
86 | target_transform (callable, optional): A function/transform that takes
87 | in the target and transforms it.
88 | Attributes:
89 | samples (list): List of (sample path, class_index) tuples
90 | """
91 |
92 | def __init__(self, root, loader, extensions, ann_file='', img_prefix='', transform=None, target_transform=None,
93 | cache_mode="no"):
94 | # image folder mode
95 | if ann_file == '':
96 | _, class_to_idx = find_classes(root)
97 | samples = make_dataset(root, class_to_idx, extensions)
98 | # zip mode
99 | else:
100 | samples = make_dataset_with_ann(os.path.join(root, ann_file),
101 | os.path.join(root, img_prefix),
102 | extensions)
103 |
104 | if len(samples) == 0:
105 | raise RuntimeError("Found 0 files in subfolders of: " + root + "\n" +
106 | "Supported extensions are: " + ",".join(extensions))
107 |
108 | self.root = root
109 | self.loader = loader
110 | self.extensions = extensions
111 |
112 | self.samples = samples
113 | self.labels = [y_1k for _, y_1k in samples]
114 | self.classes = list(set(self.labels))
115 |
116 | self.transform = transform
117 | self.target_transform = target_transform
118 |
119 | self.cache_mode = cache_mode
120 | if self.cache_mode != "no":
121 | self.init_cache()
122 |
123 | def init_cache(self):
124 | assert self.cache_mode in ["part", "full"]
125 | n_sample = len(self.samples)
126 | global_rank = dist.get_rank()
127 | world_size = dist.get_world_size()
128 |
129 | samples_bytes = [None for _ in range(n_sample)]
130 | start_time = time.time()
131 | for index in range(n_sample):
132 | if index % max(n_sample // 10, 1) == 0:
133 | t = time.time() - start_time
134 | print(f'global_rank {global_rank} cached {index}/{n_sample} takes {t:.2f}s per block')
135 | start_time = time.time()
136 | path, target = self.samples[index]
137 | if self.cache_mode == "full":
138 | samples_bytes[index] = (ZipReader.read(path), target)
139 | elif self.cache_mode == "part" and index % world_size == global_rank:
140 | samples_bytes[index] = (ZipReader.read(path), target)
141 | else:
142 | samples_bytes[index] = (path, target)
143 | self.samples = samples_bytes
144 |
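# Editorial note on the cache modes above: "full" keeps the raw bytes of every
# sample in every rank's memory, while "part" caches only the indices with
# index % world_size == global_rank, i.e. roughly a 1/world_size shard per
# rank; the remaining entries stay as plain paths and are read from disk.
# pil_loader below accepts both bytes and paths, so __getitem__ is unchanged.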
145 | def __getitem__(self, index):
146 | """
147 | Args:
148 | index (int): Index
149 | Returns:
150 | tuple: (sample, target) where target is class_index of the target class.
151 | """
152 | path, target = self.samples[index]
153 | sample = self.loader(path)
154 | if self.transform is not None:
155 | sample = self.transform(sample)
156 | if self.target_transform is not None:
157 | target = self.target_transform(target)
158 |
159 | return sample, target
160 |
161 | def __len__(self):
162 | return len(self.samples)
163 |
164 | def __repr__(self):
165 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
166 | fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
167 | fmt_str += ' Root Location: {}\n'.format(self.root)
168 | tmp = ' Transforms (if any): '
169 | fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
170 | tmp = ' Target Transforms (if any): '
171 | fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
172 | return fmt_str
173 |
174 |
175 | IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif']
176 |
177 |
178 | def pil_loader(path):
179 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
180 | if isinstance(path, bytes):
181 | return Image.open(io.BytesIO(path)).convert('RGB')
182 | elif is_zip_path(path):
183 | zip_bytes = ZipReader.read(path)
184 | return Image.open(io.BytesIO(zip_bytes)).convert('RGB')
185 | else:
186 | with open(path, 'rb') as f:
187 | # decode before the file closes: Image.open reads lazily
188 | return Image.open(f).convert('RGB')
189 |
190 |
191 | def accimage_loader(path):
192 | import accimage
193 | try:
194 | return accimage.Image(path)
195 | except IOError:
196 | # Potentially a decoding problem, fall back to PIL.Image
197 | return pil_loader(path)
198 |
199 |
200 | def default_img_loader(path):
201 | from torchvision import get_image_backend
202 | if get_image_backend() == 'accimage':
203 | return accimage_loader(path)
204 | else:
205 | return pil_loader(path)
206 |
207 |
208 | class CachedImageFolder(DatasetFolder):
209 | """A generic data loader where the images are arranged in this way: ::
210 | root/dog/xxx.png
211 | root/dog/xxy.png
212 | root/dog/xxz.png
213 | root/cat/123.png
214 | root/cat/nsdf3.png
215 | root/cat/asd932_.png
216 | Args:
217 | root (string): Root directory path.
218 | transform (callable, optional): A function/transform that takes in a PIL image
219 | and returns a transformed version. E.g., ``transforms.RandomCrop``
220 | target_transform (callable, optional): A function/transform that takes in the
221 | target and transforms it.
222 | loader (callable, optional): A function to load an image given its path.
223 | Attributes:
224 | imgs (list): List of (image path, class_index) tuples
225 | """
226 |
227 | def __init__(self, root, ann_file='', img_prefix='', transform=None, target_transform=None,
228 | loader=default_img_loader, cache_mode="no"):
229 | super(CachedImageFolder, self).__init__(root, loader, IMG_EXTENSIONS,
230 | ann_file=ann_file, img_prefix=img_prefix,
231 | transform=transform, target_transform=target_transform,
232 | cache_mode=cache_mode)
233 | self.imgs = self.samples
234 |
235 | def __getitem__(self, index):
236 | """
237 | Args:
238 | index (int): Index
239 | Returns:
240 | tuple: (image, target) where target is class_index of the target class.
241 | """
242 | path, target = self.samples[index]
243 | image = self.loader(path)
244 | if self.transform is not None:
245 | img = self.transform(image)
246 | else:
247 | img = image
248 | if self.target_transform is not None:
249 | target = self.target_transform(target)
250 |
251 | return img, target
252 |
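# Editorial construction sketch (paths are hypothetical and the transform is
# illustrative; the zip-style img_prefix follows the ZipReader convention from
# data/zipreader.py, where '.zip@/' marks a path inside an archive):
#
#   from torchvision import transforms
#   dataset = CachedImageFolder(
#       root='imagenet',                  # data root (hypothetical)
#       ann_file='train_map.txt',         # <path><TAB><class_index> per line
#       img_prefix='train.zip@/',         # or ann_file='' for plain folders
#       transform=transforms.ToTensor(),
#       cache_mode='part')                # cache a 1/world_size shard per rank
#   img, target = dataset[0]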
--------------------------------------------------------------------------------
/utils_ema.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Swin Transformer
3 | # Copyright (c) 2021 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Ze Liu
6 | # --------------------------------------------------------
7 |
8 | import os
9 | import torch
10 | import torch.distributed as dist
11 | from timm.utils.model import unwrap_model, get_state_dict
12 |
13 |
14 | def load_checkpoint(config, model, optimizer, lr_scheduler, logger):
15 | logger.info(f"==============> Resuming from {config.MODEL.RESUME}....................")
16 | if config.MODEL.RESUME.startswith('https'):
17 | checkpoint = torch.hub.load_state_dict_from_url(
18 | config.MODEL.RESUME, map_location='cpu', check_hash=True)
19 | else:
20 | checkpoint = torch.load(config.MODEL.RESUME, map_location='cpu')
21 | msg = model.load_state_dict(checkpoint['model'], strict=False)
22 | logger.info(msg)
23 | max_accuracy = 0.0
24 | max_accuracy_e = 0.0
25 | if not config.EVAL_MODE and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
26 | optimizer.load_state_dict(checkpoint['optimizer'])
27 | lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
28 | config.defrost()
29 | config.TRAIN.START_EPOCH = checkpoint['epoch'] + 1
30 | config.freeze()
31 | logger.info(f"=> loaded successfully '{config.MODEL.RESUME}' (epoch {checkpoint['epoch']})")
32 | if 'max_accuracy' in checkpoint:
33 | max_accuracy = checkpoint['max_accuracy']
34 | max_accuracy_e = checkpoint.get('max_accuracy_e', 0.0)
35 |
36 | del checkpoint
37 | torch.cuda.empty_cache()
38 | return max_accuracy, max_accuracy_e
39 |
40 |
41 | def load_pretrained(ckpt_path, model, logger):
42 | logger.info(f"==============> Loading pretrained from {ckpt_path}....................")
43 | checkpoint = torch.load(ckpt_path, map_location='cpu')
44 | # msg = model.load_pretrained(checkpoint['model'])
45 | # logger.info(msg)
46 | # logger.info(f"=> Loaded successfully {ckpt_path} ")
47 | # del checkpoint
48 | # torch.cuda.empty_cache()
49 | state_dict = checkpoint['state_dict_ema'] if 'state_dict_ema' in checkpoint.keys() else checkpoint
50 |
51 | # delete relative_position_index since we always re-init it
52 | relative_position_index_keys = [k for k in state_dict.keys() if "relative_position_index" in k]
53 | for k in relative_position_index_keys:
54 | del state_dict[k]
55 |
56 | # delete relative_coords_table since we always re-init it
57 | relative_coords_table_keys = [k for k in state_dict.keys() if "relative_coords_table" in k]
58 | for k in relative_coords_table_keys:
59 | del state_dict[k]
60 |
61 | # delete attn_mask since we always re-init it
62 | attn_mask_keys = [k for k in state_dict.keys() if "attn_mask" in k]
63 | for k in attn_mask_keys:
64 | del state_dict[k]
65 |
66 | # bicubic interpolate relative_position_bias_table if not match
67 | relative_position_bias_table_keys = [k for k in state_dict.keys() if "relative_position_bias_table" in k]
68 | for k in relative_position_bias_table_keys:
69 | relative_position_bias_table_pretrained = state_dict[k]
70 | relative_position_bias_table_current = model.state_dict()[k]
71 | L1, nH1 = relative_position_bias_table_pretrained.size()
72 | L2, nH2 = relative_position_bias_table_current.size()
73 | if nH1 != nH2:
74 | logger.warning(f"Error in loading {k}, skipping......")
75 | else:
76 | if L1 != L2:
77 | # bicubic interpolate relative_position_bias_table if not match
78 | S1 = int(L1 ** 0.5)
79 | S2 = int(L2 ** 0.5)
80 | relative_position_bias_table_pretrained_resized = torch.nn.functional.interpolate(
81 | relative_position_bias_table_pretrained.permute(1, 0).view(1, nH1, S1, S1), size=(S2, S2),
82 | mode='bicubic')
83 | state_dict[k] = relative_position_bias_table_pretrained_resized.view(nH2, L2).permute(1, 0)
84 |
85 | # bicubic interpolate absolute_pos_embed if not match
86 | absolute_pos_embed_keys = [k for k in state_dict.keys() if "pos_embed" in k]
87 | for k in absolute_pos_embed_keys:
88 | # dpe
89 | absolute_pos_embed_pretrained = state_dict[k]
90 | absolute_pos_embed_current = model.state_dict()[k]
91 | _, L1, C1 = absolute_pos_embed_pretrained.size()
92 | _, L2, C2 = absolute_pos_embed_current.size()
93 | if C1 != C2:
94 | logger.warning(f"Error in loading {k}, skipping......")
95 | else:
96 | if L1 != L2:
97 | S1 = int(L1 ** 0.5)
98 | S2 = int(L2 ** 0.5)
99 | i, j = L1 - S1 ** 2, L2 - S2 ** 2
100 | absolute_pos_embed_pretrained_ = absolute_pos_embed_pretrained[:, i:, :].reshape(-1, S1, S1, C1)
101 | absolute_pos_embed_pretrained_ = absolute_pos_embed_pretrained_.permute(0, 3, 1, 2)
102 | absolute_pos_embed_pretrained_resized = torch.nn.functional.interpolate(
103 | absolute_pos_embed_pretrained_, size=(S2, S2), mode='bicubic')
104 | absolute_pos_embed_pretrained_resized = absolute_pos_embed_pretrained_resized.permute(0, 2, 3, 1)
105 | absolute_pos_embed_pretrained_resized = absolute_pos_embed_pretrained_resized.flatten(1, 2)
106 | state_dict[k] = torch.cat([absolute_pos_embed_pretrained[:, :j, :],
107 | absolute_pos_embed_pretrained_resized], dim=1)
108 |
109 | # check classifier, if not match, then re-init classifier to zero
110 | head_bias_pretrained = state_dict['head.bias']
111 | Nc1 = head_bias_pretrained.shape[0]
112 | Nc2 = model.head.bias.shape[0]
113 | if Nc1 != Nc2:
114 | if Nc1 == 21841 and Nc2 == 1000:
115 | logger.info("loading ImageNet-22K weight to ImageNet-1K ......")
116 | map22kto1k_path = 'data/map22kto1k.txt'
117 | with open(map22kto1k_path) as f:
118 | map22kto1k = f.readlines()
119 | map22kto1k = [int(id22k.strip()) for id22k in map22kto1k]
120 | state_dict['head.weight'] = state_dict['head.weight'][map22kto1k, :]
121 | state_dict['head.bias'] = state_dict['head.bias'][map22kto1k]
122 | else:
123 | torch.nn.init.constant_(model.head.bias, 0.)
124 | torch.nn.init.constant_(model.head.weight, 0.)
125 | del state_dict['head.weight']
126 | del state_dict['head.bias']
127 | logger.warning("Error in loading classifier head, re-init classifier head to 0")
128 |
129 | msg = model.load_state_dict(state_dict, strict=False)
130 | logger.warning(msg)
131 |
132 | logger.info(f"=> loaded successfully '{ckpt_path}'")
133 |
134 | del checkpoint
135 | torch.cuda.empty_cache()
136 |
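# Editorial summary of the interpolation above: a length-L sequence of position
# embeddings is treated as an S x S grid with S = int(L ** 0.5); the j = L2 - S2**2
# leading non-grid tokens (class/distillation tokens) are carried over as-is,
# while the grid part is reshaped to (N, C, S1, S1), bicubically resized to
# (S2, S2), flattened back, and re-concatenated. relative_position_bias_table
# is resized the same way, with num_heads playing the role of channels.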
137 |
138 | def save_checkpoint(config, epoch, model, max_accuracy, max_accuracy_e, optimizer, lr_scheduler, logger):
139 | save_state = {'model': model.state_dict(),
140 | # 'model_ema': model_ema.state_dict(),
141 | 'optimizer': optimizer.state_dict(),
142 | 'lr_scheduler': lr_scheduler.state_dict(),
143 | 'max_accuracy': max_accuracy,
144 | 'max_accuracy_e': max_accuracy_e,
145 | 'epoch': epoch,
146 | 'config': config}
147 |
148 | save_path = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth')
149 | logger.info(f"{save_path} saving......")
150 | torch.save(save_state, save_path)
151 | logger.info(f"{save_path} saved !!!")
152 |
153 | def save_checkpoint_ema(config, epoch, model, model_ema, max_accuracy, max_accuracy_e, optimizer, lr_scheduler, logger):
154 | save_state = {'model': model.state_dict(),
155 | # 'model_ema': model_ema.state_dict(),
156 | 'optimizer': optimizer.state_dict(),
157 | 'lr_scheduler': lr_scheduler.state_dict(),
158 | 'max_accuracy': max_accuracy,
159 | 'max_accuracy_e': max_accuracy_e,
160 | 'epoch': epoch,
161 | 'config': config}
162 | save_state['state_dict_ema'] = get_state_dict(model_ema, unwrap_model)
163 |
164 | save_path = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth')
165 | logger.info(f"{save_path} saving......")
166 | torch.save(save_state, save_path)
167 | logger.info(f"{save_path} saved !!!")
168 |
169 |
170 | def save_checkpoint_ema_new(config, epoch, model, model_ema, max_accuracy, max_accuracy_e, optimizer, lr_scheduler, logger, name=None):
171 | save_state = {'model': model.state_dict(),
172 | # 'model_ema': model_ema.state_dict(),
173 | 'optimizer': optimizer.state_dict(),
174 | 'lr_scheduler': lr_scheduler.state_dict(),
175 | 'max_accuracy': max_accuracy,
176 | 'max_accuracy_e': max_accuracy_e,
177 | 'epoch': epoch,
178 | 'config': config}
179 | save_state['state_dict_ema'] = get_state_dict(model_ema, unwrap_model)
180 |
181 | if name is None:
182 | old_ckpt = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch-3}.pth')
183 | if os.path.exists(old_ckpt):
184 | os.remove(old_ckpt)
185 |
186 | if name is not None:
187 | save_path = os.path.join(config.OUTPUT, f'{name}.pth')
188 | logger.info(f"{save_path} saving......")
189 | torch.save(save_state, save_path)
190 | logger.info(f"{save_path} saved !!!")
191 | else:
192 | save_path = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth')
193 | logger.info(f"{save_path} saving......")
194 | torch.save(save_state, save_path)
195 | logger.info(f"{save_path} saved !!!")
196 |
197 |
198 | def get_grad_norm(parameters, norm_type=2):
199 | if isinstance(parameters, torch.Tensor):
200 | parameters = [parameters]
201 | parameters = list(filter(lambda p: p.grad is not None, parameters))
202 | norm_type = float(norm_type)
203 | total_norm = 0
204 | for p in parameters:
205 | param_norm = p.grad.data.norm(norm_type)
206 | total_norm += param_norm.item() ** norm_type
207 | total_norm = total_norm ** (1. / norm_type)
208 | return total_norm
209 |
210 |
211 | def auto_resume_helper(output_dir):
212 | checkpoints = os.listdir(output_dir)
213 | checkpoints = [ckpt for ckpt in checkpoints if ckpt.endswith('pth')]
214 | print(f"All checkpoints found in {output_dir}: {checkpoints}")
215 | if len(checkpoints) > 0:
216 | latest_checkpoint = max([os.path.join(output_dir, d) for d in checkpoints], key=os.path.getmtime)
217 | print(f"The latest checkpoint found: {latest_checkpoint}")
218 | resume_file = latest_checkpoint
219 | else:
220 | resume_file = None
221 | return resume_file
222 |
223 |
224 | def reduce_tensor(tensor):
225 | rt = tensor.clone()
226 | dist.all_reduce(rt, op=dist.ReduceOp.SUM)
227 | rt /= dist.get_world_size()
228 | return rt
229 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import argparse
4 | import datetime
5 | import numpy as np
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.backends.cudnn as cudnn
10 | import torch.distributed as dist
11 | from torch.cuda.amp import autocast, GradScaler
12 |
13 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
14 | from timm.utils import accuracy, AverageMeter
15 |
16 | from config import get_config
17 | from models import build_model
18 | from data import build_loader
19 | from lr_scheduler import build_scheduler
20 | from optimizer import build_optimizer
21 | from logger import create_logger
22 | from utils import load_checkpoint, save_checkpoint, save_checkpoint_new, get_grad_norm, auto_resume_helper, reduce_tensor, load_pretrained
23 |
24 | import warnings
25 | warnings.filterwarnings('ignore')
26 |
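# Editorial note: autocast and GradScaler imported above drive mixed-precision
# training when --amp is set; a generic torch.cuda.amp step (a sketch, not
# necessarily this script's exact loop) looks like:
#
#   scaler = GradScaler()
#   with autocast():
#       loss = criterion(model(samples), targets)
#   scaler.scale(loss).backward()
#   scaler.step(optimizer)
#   scaler.update()
#   optimizer.zero_grad()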
27 | def parse_option():
28 | parser = argparse.ArgumentParser('InLine Attention training and evaluation script', add_help=False)
29 | parser.add_argument('--cfg', type=str, required=True, metavar="FILE", help='path to config file', )
30 | parser.add_argument(
31 | "--opts",
32 | help="Modify config options by adding 'KEY VALUE' pairs. ",
33 | default=None,
34 | nargs='+',
35 | )
36 |
37 | # easy config modification
38 | parser.add_argument('--batch-size', type=int, help="batch size for single GPU")
39 | parser.add_argument('--data-path', type=str, help='path to dataset')
40 | parser.add_argument('--zip', action='store_true', help='use zipped dataset instead of folder dataset')
41 | parser.add_argument('--cache-mode', type=str, default='part', choices=['no', 'full', 'part'],
42 | help='no: no cache, '
43 | 'full: cache all data, '
44 | 'part: shard the dataset into non-overlapping pieces and cache only one piece')
45 | parser.add_argument('--resume', help='resume from checkpoint')
46 | parser.add_argument('--use-checkpoint', action='store_true',
47 | help="whether to use gradient checkpointing to save memory")
48 | parser.add_argument('--amp', action='store_true', default=False, help='enable automatic mixed precision training')
49 | parser.add_argument('--output', default='output', type=str, metavar='PATH',
50 | help='root of output folder, the full path is