├── models
├── __init__.py
├── build.py
└── lipsformer_swin.py
├── data
├── __init__.py
├── samplers.py
├── zipreader.py
├── cached_image_folder.py
└── build.py
├── figures
└── teaser.png
├── configs
├── swin_small_patch4_window7_224.yaml
├── swin_tiny_patch4_window7_224.yaml
├── cswin_tiny_patch4_window7_224.yaml
├── cswin_base_patch4_window7_224.yaml
├── swin_mlp_base_patch4_window7_224.yaml
├── swin_tiny_c24_patch4_window8_256.yaml
├── swin_mlp_tiny_c12_patch4_window8_256.yaml
├── swin_mlp_tiny_c24_patch4_window8_256.yaml
├── swin_mlp_tiny_c6_patch4_window8_256.yaml
├── swin_base_patch4_window12_384.yaml
├── swin_large_patch4_window12_384.yaml
├── swin_base_patch4_window7_224.yaml
├── swin_large_patch4_window7_224.yaml
├── swin_large_plus_patch4_window7_224_22k.yaml
├── swin_large_patch4_window12_384_22kto1k_finetune_swin_finetune.yaml
└── swin_large_patch4_window12_384_22kto1k_finetune.yaml
├── CODE_OF_CONDUCT.md
├── logger.py
├── optimizer.py
├── .gitignore
├── README.md
├── lr_scheduler.py
├── get_started.md
├── utils.py
├── config.py
├── LICENSE
└── main.py
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .build import build_model
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .build import build_loader, build_22k_loader
--------------------------------------------------------------------------------
/figures/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IDEA-Research/LipsFormer/HEAD/figures/teaser.png
--------------------------------------------------------------------------------
/configs/swin_small_patch4_window7_224.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | TYPE: swin
3 | NAME: swin_small_patch4_window7_224
4 | DROP_PATH_RATE: 0.3
5 | SWIN:
6 | EMBED_DIM: 96
7 | DEPTHS: [ 2, 2, 18, 2 ]
8 | NUM_HEADS: [ 3, 6, 12, 24 ]
9 | WINDOW_SIZE: 7
--------------------------------------------------------------------------------
/configs/swin_tiny_patch4_window7_224.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | TYPE: swin
3 | NAME: swin_tiny_patch4_window7_224
4 | DROP_PATH_RATE: 0.2
5 | SWIN:
6 | EMBED_DIM: 96
7 | DEPTHS: [ 2, 2, 6, 2 ]
8 | NUM_HEADS: [ 3, 6, 12, 24 ]
9 | WINDOW_SIZE: 7
--------------------------------------------------------------------------------
/configs/cswin_tiny_patch4_window7_224.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | TYPE: cswin
3 | NAME: cswin_tiny_patch4_window7_224
4 | DROP_PATH_RATE: 0.2
5 | SWIN:
6 | EMBED_DIM: 96
7 | DEPTHS: [ 2, 2, 6, 2 ]
8 | NUM_HEADS: [ 3, 6, 12, 24 ]
9 | WINDOW_SIZE: 7
10 |
--------------------------------------------------------------------------------
/configs/cswin_base_patch4_window7_224.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | TYPE: cswin
3 | NAME: swin_base_patch4_window7_224
4 | DROP_PATH_RATE: 0.5
5 | SWIN:
6 | EMBED_DIM: 128
7 | DEPTHS: [ 2, 2, 18, 2 ]
8 | NUM_HEADS: [ 4, 8, 16, 32 ]
9 | WINDOW_SIZE: 7
10 |
--------------------------------------------------------------------------------
/configs/swin_mlp_base_patch4_window7_224.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | TYPE: swin_mlp
3 | NAME: swin_mlp_base_patch4_window7_224
4 | DROP_PATH_RATE: 0.5
5 | SWIN_MLP:
6 | EMBED_DIM: 128
7 | DEPTHS: [ 2, 2, 18, 2 ]
8 | NUM_HEADS: [ 4, 8, 16, 32 ]
9 | WINDOW_SIZE: 7
10 |
--------------------------------------------------------------------------------
/configs/swin_tiny_c24_patch4_window8_256.yaml:
--------------------------------------------------------------------------------
1 | DATA:
2 | IMG_SIZE: 256
3 | MODEL:
4 | TYPE: swin
5 | NAME: swin_tiny_c24_patch4_window8_256
6 | DROP_PATH_RATE: 0.2
7 | SWIN:
8 | EMBED_DIM: 96
9 | DEPTHS: [ 2, 2, 6, 2 ]
10 | NUM_HEADS: [ 4, 8, 16, 32 ]
11 | WINDOW_SIZE: 8
--------------------------------------------------------------------------------
/configs/swin_mlp_tiny_c12_patch4_window8_256.yaml:
--------------------------------------------------------------------------------
1 | DATA:
2 | IMG_SIZE: 256
3 | MODEL:
4 | TYPE: swin_mlp
5 | NAME: swin_mlp_tiny_c12_patch4_window8_256
6 | DROP_PATH_RATE: 0.2
7 | SWIN_MLP:
8 | EMBED_DIM: 96
9 | DEPTHS: [ 2, 2, 6, 2 ]
10 | NUM_HEADS: [ 8, 16, 32, 64 ]
11 | WINDOW_SIZE: 8
--------------------------------------------------------------------------------
/configs/swin_mlp_tiny_c24_patch4_window8_256.yaml:
--------------------------------------------------------------------------------
1 | DATA:
2 | IMG_SIZE: 256
3 | MODEL:
4 | TYPE: swin_mlp
5 | NAME: swin_mlp_tiny_c24_patch4_window8_256
6 | DROP_PATH_RATE: 0.2
7 | SWIN_MLP:
8 | EMBED_DIM: 96
9 | DEPTHS: [ 2, 2, 6, 2 ]
10 | NUM_HEADS: [ 4, 8, 16, 32 ]
11 | WINDOW_SIZE: 8
--------------------------------------------------------------------------------
/configs/swin_mlp_tiny_c6_patch4_window8_256.yaml:
--------------------------------------------------------------------------------
1 | DATA:
2 | IMG_SIZE: 256
3 | MODEL:
4 | TYPE: swin_mlp
5 | NAME: swin_mlp_tiny_c6_patch4_window8_256
6 | DROP_PATH_RATE: 0.2
7 | SWIN_MLP:
8 | EMBED_DIM: 96
9 | DEPTHS: [ 2, 2, 6, 2 ]
10 | NUM_HEADS: [ 16, 32, 64, 128 ]
11 | WINDOW_SIZE: 8
--------------------------------------------------------------------------------
/configs/swin_base_patch4_window12_384.yaml:
--------------------------------------------------------------------------------
1 | # only for evaluation
2 | DATA:
3 | IMG_SIZE: 384
4 | MODEL:
5 | TYPE: swin
6 | NAME: swin_base_patch4_window12_384
7 | SWIN:
8 | EMBED_DIM: 128
9 | DEPTHS: [ 2, 2, 18, 2 ]
10 | NUM_HEADS: [ 4, 8, 16, 32 ]
11 | WINDOW_SIZE: 12
12 | TEST:
13 | CROP: False
--------------------------------------------------------------------------------
/configs/swin_large_patch4_window12_384.yaml:
--------------------------------------------------------------------------------
1 | # only for evaluation
2 | DATA:
3 | IMG_SIZE: 384
4 | MODEL:
5 | TYPE: swin
6 | NAME: swin_large_patch4_window12_384
7 | SWIN:
8 | EMBED_DIM: 192
9 | DEPTHS: [ 2, 2, 18, 2 ]
10 | NUM_HEADS: [ 6, 12, 24, 48 ]
11 | WINDOW_SIZE: 12
12 | TEST:
13 | CROP: False
--------------------------------------------------------------------------------
/configs/swin_base_patch4_window7_224.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | TYPE: swin
3 | NAME: swin_base_patch4_window7_224
4 | DROP_PATH_RATE: 0.5
5 | SWIN:
6 | EMBED_DIM: 128
7 | DEPTHS: [ 2, 2, 18, 2 ]
8 | NUM_HEADS: [ 4, 8, 16, 32 ]
9 | WINDOW_SIZE: 7
10 |
11 | TRAIN:
12 | BASE_LR: 2.5e-4
13 | OPTIMIZER:
14 | BETAS: (0.9, 0.998)
--------------------------------------------------------------------------------
/configs/swin_large_patch4_window7_224.yaml:
--------------------------------------------------------------------------------
1 | # only for evaluation
2 | MODEL:
3 | TYPE: swin
4 | NAME: swin_large_patch4_window7_224
5 | DROP_PATH_RATE: 0.5
6 | SWIN:
7 | EMBED_DIM: 192
8 | DEPTHS: [ 2, 2, 18, 2 ]
9 | NUM_HEADS: [ 6, 12, 24, 48 ]
10 | WINDOW_SIZE: 7
11 |
12 | TRAIN:
13 | BASE_LR: 1.25e-4
14 | OPTIMIZER:
15 | BETAS: (0.9, 0.998)
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Microsoft Open Source Code of Conduct
2 |
3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
4 |
5 | Resources:
6 |
7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns
10 |
--------------------------------------------------------------------------------
/configs/swin_large_plus_patch4_window7_224_22k.yaml:
--------------------------------------------------------------------------------
1 | # only for evaluation
2 | DATA:
3 | DATASET: imagenet22K
4 | MODEL:
5 | TYPE: swin
6 | NAME: swin_plus_large_patch4_window7_224_22k
7 | DROP_PATH_RATE: 0.5
8 | LABEL_SMOOTHING: 0.1
9 | SWIN:
10 | EMBED_DIM: 288
11 | DEPTHS: [ 2, 2, 18, 2 ]
12 | NUM_HEADS: [ 6, 12, 24, 48 ]
13 | WINDOW_SIZE: 7
14 |
15 | TRAIN:
16 | EPOCHS: 30
17 | BASE_LR: 1.25e-4
18 | WEIGHT_DECAY: 0.1
19 | OPTIMIZER:
20 | BETAS: (0.9, 0.98)
--------------------------------------------------------------------------------
/configs/swin_large_patch4_window12_384_22kto1k_finetune_swin_finetune.yaml:
--------------------------------------------------------------------------------
1 | DATA:
2 | IMG_SIZE: 384
3 |
4 | MODEL:
5 | TYPE: swin
6 | NAME: swin_large_patch4_window12_384_22kto1k_finetune
7 | PRETRAINED: /comp_robot/workspace/chenyihao/SwinLips_Lalpha_2/output/swin_plus_large_patch4_window7_224_22k/large_plus_22k_30_beta_2_0998/ckpt_epoch_28.pth
8 | DROP_PATH_RATE: 0.2
9 | SWIN:
10 | EMBED_DIM: 288
11 | DEPTHS: [ 2, 2, 18, 2 ]
12 | NUM_HEADS: [ 6, 12, 24, 48 ]
13 | WINDOW_SIZE: 12
14 | TRAIN:
15 | EPOCHS: 30
16 | WARMUP_EPOCHS: 5
17 | WEIGHT_DECAY: 1e-8
18 | BASE_LR: 2e-05
19 | WARMUP_LR: 2e-08
20 | MIN_LR: 2e-07
21 | OPTIMIZER:
22 | BETAS: (0.9, 0.998)
23 | TEST:
24 | CROP: False
--------------------------------------------------------------------------------
/configs/swin_large_patch4_window12_384_22kto1k_finetune.yaml:
--------------------------------------------------------------------------------
1 | DATA:
2 | IMG_SIZE: 384
3 | TRANSFORM_TYPE: easy
4 |
5 | MODEL:
6 | TYPE: swin
7 | NAME: swin_large_patch4_window12_384_22kto1k_finetune
8 | PRETRAINED: /comp_robot/workspace/chenyihao/SwinLips_Lalpha_2/output/swin_plus_large_patch4_window7_224_22k/large_plus_22k_30_beta_2_0998/ckpt_epoch_28.pth
9 | DROP_PATH_RATE: 0.2
10 | SWIN:
11 | EMBED_DIM: 288
12 | DEPTHS: [ 2, 2, 18, 2 ]
13 | NUM_HEADS: [ 6, 12, 24, 48 ]
14 | WINDOW_SIZE: 12
15 | TRAIN:
16 | OPTIMIZER:
17 | NAME: SGD
18 | MOMENTUM: 0.9
19 |
20 | EPOCHS: 8
21 | WEIGHT_DECAY: 0.0
22 | BASE_LR: 3e-02
23 | WARMUP_LR: 2e-08
24 | MIN_LR: 2e-07
25 | TEST:
26 | CROP: False
27 | AUG:
28 | MIXUP: 0.1
29 | CUTMIX: 0.0
30 | CUTMIX_MINMAX: None
31 | MIXUP_PROB: 1.0
32 | MIXUP_SWITCH_PROB: 0.0
--------------------------------------------------------------------------------
/data/samplers.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Swin Transformer
3 | # Copyright (c) 2021 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Ze Liu
6 | # --------------------------------------------------------
7 |
8 | import torch
9 |
10 |
class SubsetRandomSampler(torch.utils.data.Sampler):
    r"""Yield the supplied indices in a freshly shuffled order.

    Each pass produces every index exactly once (sampling without
    replacement); the order is a new ``torch.randperm`` permutation.

    Arguments:
        indices (sequence): a sequence of indices
    """

    def __init__(self, indices):
        self.indices = indices
        self.epoch = 0

    def __len__(self):
        return len(self.indices)

    def __iter__(self):
        # The permutation is drawn as soon as an iterator is requested;
        # the lookup into ``indices`` then happens lazily per element.
        shuffled_positions = torch.randperm(len(self.indices))
        return (self.indices[position] for position in shuffled_positions)

    def set_epoch(self, epoch):
        # Only stored on the instance; __iter__ does not read it.
        self.epoch = epoch
30 |
--------------------------------------------------------------------------------
/models/build.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Swin Transformer
3 | # Copyright (c) 2021 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Ze Liu
6 | # --------------------------------------------------------
7 |
8 | from .lipsformer_swin import LipsFormerSwin
9 |
def build_model(config):
    """Instantiate a ``LipsFormerSwin`` from the given config node.

    Architecture hyper-parameters are taken from ``config.MODEL.SWIN``,
    the input size from ``config.DATA.IMG_SIZE``, and the checkpointing
    flag from ``config.TRAIN.USE_CHECKPOINT``.

    Args:
        config: configuration object exposing the attributes read below.

    Returns:
        The constructed ``LipsFormerSwin`` model.
    """
    # NOTE(review): MODEL.TYPE is read but never used to choose a class;
    # every type ends up as LipsFormerSwin configured from MODEL.SWIN.
    model_type = config.MODEL.TYPE

    swin_cfg = config.MODEL.SWIN
    model_kwargs = {
        'img_size': config.DATA.IMG_SIZE,
        'patch_size': swin_cfg.PATCH_SIZE,
        'in_chans': swin_cfg.IN_CHANS,
        'num_classes': config.MODEL.NUM_CLASSES,
        'embed_dim': swin_cfg.EMBED_DIM,
        'depths': swin_cfg.DEPTHS,
        'num_heads': swin_cfg.NUM_HEADS,
        'window_size': swin_cfg.WINDOW_SIZE,
        'mlp_ratio': swin_cfg.MLP_RATIO,
        'qkv_bias': swin_cfg.QKV_BIAS,
        'qk_scale': swin_cfg.QK_SCALE,
        'drop_rate': config.MODEL.DROP_RATE,
        'drop_path_rate': config.MODEL.DROP_PATH_RATE,
        'ape': swin_cfg.APE,
        'patch_norm': swin_cfg.PATCH_NORM,
        'use_checkpoint': config.TRAIN.USE_CHECKPOINT,
    }
    return LipsFormerSwin(**model_kwargs)
29 |
--------------------------------------------------------------------------------
/logger.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Swin Transformer
3 | # Copyright (c) 2021 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Ze Liu
6 | # --------------------------------------------------------
7 |
8 | import os
9 | import sys
10 | import logging
11 | import functools
12 | from termcolor import colored
13 |
14 |
@functools.lru_cache()
def create_logger(output_dir, dist_rank=0, name=''):
    """Create a DEBUG-level logger that writes to a per-rank log file.

    The call is memoised with ``functools.lru_cache``, so calling it again
    with the same arguments returns the cached logger rather than adding
    further handlers.

    Args:
        output_dir: directory in which ``log_rank{dist_rank}.txt`` is
            opened in append mode.
        dist_rank: process rank; only rank 0 also logs to stdout (with
            colored prefixes).
        name: name handed to ``logging.getLogger``.

    Returns:
        The configured ``logging.Logger`` (propagation disabled).
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    logger.propagate = False

    date_format = '%Y-%m-%d %H:%M:%S'
    plain_format = '[%(asctime)s %(name)s] (%(filename)s %(lineno)d): %(levelname)s %(message)s'
    colored_format = ''.join((
        colored('[%(asctime)s %(name)s]', 'green'),
        colored('(%(filename)s %(lineno)d)', 'yellow'),
        ': %(levelname)s %(message)s',
    ))

    def attach(handler, message_format):
        # Shared handler setup: DEBUG level, common date format.
        handler.setLevel(logging.DEBUG)
        handler.setFormatter(logging.Formatter(fmt=message_format, datefmt=date_format))
        logger.addHandler(handler)

    # Console output only for the master process.
    if dist_rank == 0:
        attach(logging.StreamHandler(sys.stdout), colored_format)

    # Every rank gets its own log file.
    log_path = os.path.join(output_dir, f'log_rank{dist_rank}.txt')
    attach(logging.FileHandler(log_path, mode='a'), plain_format)

    return logger
42 |
--------------------------------------------------------------------------------
/optimizer.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Swin Transformer
3 | # Copyright (c) 2021 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Ze Liu
6 | # --------------------------------------------------------
7 |
8 | from torch import optim as optim
9 |
10 |
def build_optimizer(config, model):
    """Build the optimizer named by ``config.TRAIN.OPTIMIZER.NAME``.

    Parameters are split into groups by ``set_weight_decay``; names the
    model reports through ``no_weight_decay()`` and
    ``no_weight_decay_keywords()`` (when it defines them) are forwarded
    to that split.

    Args:
        config: configuration object providing the ``TRAIN`` settings.
        model: module whose parameters are optimized.

    Returns:
        An ``optim.SGD`` (Nesterov momentum) for ``'sgd'``, an
        ``optim.AdamW`` for ``'adamw'`` (name compared case-insensitively),
        or ``None`` for any other name.
    """
    skip = model.no_weight_decay() if hasattr(model, 'no_weight_decay') else {}
    skip_keywords = (model.no_weight_decay_keywords()
                     if hasattr(model, 'no_weight_decay_keywords') else {})
    parameters = set_weight_decay(model, skip, skip_keywords)

    train_cfg = config.TRAIN
    opt_lower = train_cfg.OPTIMIZER.NAME.lower()

    if opt_lower == 'sgd':
        return optim.SGD(parameters, momentum=train_cfg.OPTIMIZER.MOMENTUM, nesterov=True,
                         lr=train_cfg.BASE_LR, weight_decay=train_cfg.WEIGHT_DECAY)
    if opt_lower == 'adamw':
        return optim.AdamW(parameters, eps=train_cfg.OPTIMIZER.EPS, betas=train_cfg.OPTIMIZER.BETAS,
                           lr=train_cfg.BASE_LR, weight_decay=train_cfg.WEIGHT_DECAY)
    # Unrecognised optimizer names fall through and yield None.
    return None
33 |
34 |
def set_weight_decay(model, skip_list=(), skip_keywords=()):
    """Split the trainable parameters of ``model`` into two optimizer groups.

    A parameter goes to the second ("exempt") group when it is
    one-dimensional, its name ends with ``.bias``, its name is in
    ``skip_list``, or its name contains one of ``skip_keywords``.
    Parameters with ``requires_grad`` false are left out entirely.

    Args:
        model: object exposing ``named_parameters()``.
        skip_list: exact parameter names to place in the exempt group.
        skip_keywords: substrings that place a parameter in the exempt group.

    Returns:
        ``[{'params': regular}, {'params': exempt, 'weight_decay': 0.0005}]``;
        the first group carries no ``weight_decay`` key of its own.
    """
    regular_params = []
    exempt_params = []

    trainable = ((name, param) for name, param in model.named_parameters()
                 if param.requires_grad)  # frozen weights are ignored
    for name, param in trainable:
        is_exempt = (
            len(param.shape) == 1
            or name.endswith(".bias")
            or (name in skip_list)
            or check_keywords_in_name(name, skip_keywords)
        )
        if is_exempt:
            exempt_params.append(param)
            print(f"{name} has no weight decay")
        else:
            regular_params.append(param)

    # NOTE(review): the message above says "no weight decay", yet this group
    # is given weight_decay=0.0005 rather than 0 -- confirm this is intended.
    return [{'params': regular_params},
            {'params': exempt_params, 'weight_decay': 0.0005}]
50 |
51 |
def check_keywords_in_name(name, keywords=()):
    """Return True if any entry of ``keywords`` occurs as a substring of ``name``.

    Args:
        name: the string to search (a parameter name in this module).
        keywords: iterable of substrings to look for.

    Returns:
        ``True`` on the first match, ``False`` when nothing matches or
        ``keywords`` is empty.
    """
    # any() stops at the first hit instead of scanning every keyword.
    return any(keyword in name for keyword in keywords)
58 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | output/
132 |
133 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # LipsFormer
2 |
3 | By Xianbiao Qi, Jianan Wang, Yihao Chen, Yukai Shi, and Lei Zhang.
4 |
5 | This repo is the official implementation of ["LipsFormer: Introducing Lipschitz Continuity to Vision Transformers"](https://openreview.net/pdf?id=cHf1DcCwcH3).
6 |
7 | Initial commits:
8 | We release one pretrained model below. LipsFormer-Swin-Tiny achieves a result of 82.70, and the model was trained without any warmup.
9 |
10 | 1. Pretrained models on ImageNet-1K ([LipsFormer-Swin-T-IN1K](https://github.com/cyh1112/LipsFormer/releases/download/checkpoint/lipsformer-swin-tiny.pth)).
11 |
12 | In our paper, we compile two versions of LipsFormer: one is built on Swin and the other is based on CSwin. We do not merge these two code bases. We thank the authors of CSwin and Swin for releasing their code.
13 | In this repo, we release the code base built on Swin. You can easily adapt it to the CSwin code.
14 |
15 |
16 |
17 | ## Introduction
18 |
19 | **LipsFormer** introduces a Lipschitz continuous Transformer to pursue training stability both theoretically and empirically for Transformer-based models. In contrast to previous practical tricks that address training instability by learning rate warmup, layer normalization, attention formulation, and weight initialization, we show that Lipschitz continuity is a more essential property to ensure training stability. In LipsFormer, we replace unstable Transformer component modules with Lipschitz continuous counterparts: CenterNorm instead of LayerNorm, spectral initialization instead of Xavier initialization, scaled cosine similarity attention instead of dot-product attention, and weighted residual shortcut. We prove that these introduced modules are Lipschitz continuous and derive an upper bound on the Lipschitz constant of LipsFormer.
20 |
21 |
22 |
23 | ## Getting Started
24 | For Image Classification, please see [get_started.md](https://github.com/IDEA-Research/LipsFormer/blob/main/get_started.md) for detailed instructions.
25 |
26 | ## Training
27 | ```
28 | python -m torch.distributed.launch --nproc_per_node 8 --master_port 12345 main.py \
29 | --cfg configs/swin_tiny_patch4_window7_224.yaml --data-path <imagenet-path> --batch-size 32
30 | ```
31 |
32 |
33 |
34 | ## References
35 | ```
36 | @misc{qi2023lipsformer,
37 | title={LipsFormer: Introducing Lipschitz Continuity to Vision Transformers},
38 | author={Xianbiao Qi and Jianan Wang and Yihao Chen and Yukai Shi and Lei Zhang},
39 | year={2023},
40 | eprint={2304.09856},
41 | archivePrefix={arXiv},
42 | primaryClass={cs.CV}
43 | }
44 |
45 | @misc{liu2021swin,
46 | title={Swin Transformer: Hierarchical Vision Transformer using Shifted Windows},
47 | author={Ze Liu and Yutong Lin and Yue Cao and Han Hu and Yixuan Wei and Zheng Zhang and Stephen Lin and Baining Guo},
48 | year={2021},
49 | eprint={2103.14030},
50 | archivePrefix={arXiv},
51 | primaryClass={cs.CV}
52 | }
53 |
54 | @misc{dong2022cswin,
55 | title={CSWin Transformer: A General Vision Transformer Backbone with Cross-Shaped Windows},
56 | author={Xiaoyi Dong and Jianmin Bao and Dongdong Chen and Weiming Zhang and Nenghai Yu and Lu Yuan and Dong Chen and Baining Guo},
57 | year={2022},
58 | eprint={2107.00652},
59 | archivePrefix={arXiv},
60 | primaryClass={cs.CV}
61 | }
62 | ```
63 |
64 |
65 |
--------------------------------------------------------------------------------
/data/zipreader.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Swin Transformer
3 | # Copyright (c) 2021 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Ze Liu
6 | # --------------------------------------------------------
7 |
8 | import os
9 | import zipfile
10 | import io
11 | import numpy as np
12 | from PIL import Image
13 | from PIL import ImageFile
14 |
15 | ImageFile.LOAD_TRUNCATED_IMAGES = True
16 |
17 |
def is_zip_path(img_or_path):
    """Tell whether the argument contains the ``.zip@`` archive/member marker."""
    zip_marker = '.zip@'
    return zip_marker in img_or_path
21 |
22 |
class ZipReader(object):
    """Read entries of zip archives addressed as ``<archive path>@<member path>``.

    Archives opened through ``get_zipfile`` are kept in the class-level
    ``zip_bank`` dict (keyed by archive path) and reused by later calls;
    nothing in this class closes them.
    """
    zip_bank = dict()

    def __init__(self):
        super(ZipReader, self).__init__()

    @staticmethod
    def get_zipfile(path):
        """Return the ``zipfile.ZipFile`` for ``path``, opening it on first use."""
        zip_bank = ZipReader.zip_bank
        if path not in zip_bank:
            zip_bank[path] = zipfile.ZipFile(path, 'r')
        return zip_bank[path]

    @staticmethod
    def split_zip_style_path(path):
        """Split ``archive@member`` at the first ``@``.

        Returns:
            ``(zip_path, folder_path)`` where ``folder_path`` has leading and
            trailing ``/`` removed.

        Raises:
            ValueError: if ``path`` contains no ``@``.
        """
        # str.find reports a miss as -1, so the descriptive error below is
        # reachable. The previous str.index call raised a generic ValueError
        # before the check could run.
        pos_at = path.find('@')
        if pos_at == -1:
            raise ValueError("character '@' is not found from the given path '%s'" % path)

        zip_path = path[0: pos_at]
        folder_path = path[pos_at + 1:].strip('/')
        return zip_path, folder_path

    @staticmethod
    def list_folder(path):
        """List archive entries under the given folder that have no file extension.

        Entries are matched by string prefix against the folder part of
        ``path`` and returned relative to it; the folder itself is excluded.
        """
        zip_path, folder_path = ZipReader.split_zip_style_path(path)

        zfile = ZipReader.get_zipfile(zip_path)
        folder_list = []
        for file_foler_name in zfile.namelist():
            file_foler_name = file_foler_name.strip('/')
            if file_foler_name.startswith(folder_path) and \
                    len(os.path.splitext(file_foler_name)[-1]) == 0 and \
                    file_foler_name != folder_path:
                if len(folder_path) == 0:
                    folder_list.append(file_foler_name)
                else:
                    # drop "<folder_path>/" from the front
                    folder_list.append(file_foler_name[len(folder_path) + 1:])

        return folder_list

    @staticmethod
    def list_files(path, extension=None):
        """List archive entries under the given folder whose extension is allowed.

        Args:
            path: ``archive@folder`` style path.
            extension: collection of lower-case extensions (with the dot)
                compared against each entry's lower-cased extension;
                defaults to ``['.*']``.

        Returns:
            Entry names relative to the folder part of ``path``.
        """
        if extension is None:
            extension = ['.*']
        zip_path, folder_path = ZipReader.split_zip_style_path(path)

        zfile = ZipReader.get_zipfile(zip_path)
        file_lists = []
        for file_foler_name in zfile.namelist():
            file_foler_name = file_foler_name.strip('/')
            if file_foler_name.startswith(folder_path) and \
                    os.path.splitext(file_foler_name)[-1].lower() in extension:
                if len(folder_path) == 0:
                    file_lists.append(file_foler_name)
                else:
                    # drop "<folder_path>/" from the front
                    file_lists.append(file_foler_name[len(folder_path) + 1:])

        return file_lists

    @staticmethod
    def read(path):
        """Return the raw bytes of the archive member addressed by ``path``."""
        zip_path, path_img = ZipReader.split_zip_style_path(path)
        zfile = ZipReader.get_zipfile(zip_path)
        return zfile.read(path_img)

    @staticmethod
    def imread(path):
        """Open the archive member addressed by ``path`` as a PIL image.

        If ``Image.open`` fails, the error is reported on stdout and a random
        224x224 RGB image is returned instead, so the caller keeps running.
        """
        zip_path, path_img = ZipReader.split_zip_style_path(path)
        zfile = ZipReader.get_zipfile(zip_path)
        data = zfile.read(path_img)
        try:
            im = Image.open(io.BytesIO(data))
        except Exception:
            # Deliberate best-effort fallback; a bare ``except`` would also
            # swallow KeyboardInterrupt / SystemExit.
            print("ERROR IMG LOADED: ", path_img)
            random_img = np.random.rand(224, 224, 3) * 255
            im = Image.fromarray(np.uint8(random_img))
        return im
104 |
--------------------------------------------------------------------------------
/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Swin Transformer
3 | # Copyright (c) 2021 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Ze Liu
6 | # --------------------------------------------------------
7 |
8 | import torch
9 | from timm.scheduler.cosine_lr import CosineLRScheduler
10 | from timm.scheduler.step_lr import StepLRScheduler
11 | from timm.scheduler.scheduler import Scheduler
12 |
13 |
def build_scheduler(config, optimizer, n_iter_per_epoch):
    """Create the per-iteration LR scheduler chosen by ``config.TRAIN.LR_SCHEDULER.NAME``.

    Epoch-based settings are converted to iteration counts with
    ``n_iter_per_epoch``, and every scheduler is built with
    ``t_in_epochs=False``.

    Args:
        config: configuration object providing the ``TRAIN`` settings.
        optimizer: optimizer whose learning rate is scheduled.
        n_iter_per_epoch: number of optimizer iterations in one epoch.

    Returns:
        A ``CosineLRScheduler`` (``'cosine'``), ``LinearLRScheduler``
        (``'linear'``) or ``StepLRScheduler`` (``'step'``); ``None`` for
        any other name.
    """
    train_cfg = config.TRAIN
    num_steps = int(train_cfg.EPOCHS * n_iter_per_epoch)
    warmup_steps = int(train_cfg.WARMUP_EPOCHS * n_iter_per_epoch)
    decay_steps = int(train_cfg.LR_SCHEDULER.DECAY_EPOCHS * n_iter_per_epoch)

    schedule_name = train_cfg.LR_SCHEDULER.NAME

    if schedule_name == 'cosine':
        return CosineLRScheduler(
            optimizer,
            t_initial=num_steps,
            t_mul=1.,
            lr_min=train_cfg.MIN_LR,
            warmup_lr_init=train_cfg.WARMUP_LR,
            warmup_t=warmup_steps,
            cycle_limit=1,
            t_in_epochs=False,
        )

    if schedule_name == 'linear':
        return LinearLRScheduler(
            optimizer,
            t_initial=num_steps,
            lr_min_rate=0.01,
            warmup_lr_init=train_cfg.WARMUP_LR,
            warmup_t=warmup_steps,
            t_in_epochs=False,
        )

    if schedule_name == 'step':
        return StepLRScheduler(
            optimizer,
            decay_t=decay_steps,
            decay_rate=train_cfg.LR_SCHEDULER.DECAY_RATE,
            warmup_lr_init=train_cfg.WARMUP_LR,
            warmup_t=warmup_steps,
            t_in_epochs=False,
        )

    # Unknown scheduler names yield None.
    return None
51 |
52 |
class LinearLRScheduler(Scheduler):
    """Linear warmup followed by a linear decay of the ``lr`` param-group field.

    During the first ``warmup_t`` steps the value rises from
    ``warmup_lr_init`` by a constant per-step increment towards each base
    value. Afterwards it moves linearly from the base value ``v`` to
    ``v * lr_min_rate``, reaching that at ``t_initial``.
    """

    def __init__(self,
                 optimizer: torch.optim.Optimizer,
                 t_initial: int,
                 lr_min_rate: float,
                 warmup_t=0,
                 warmup_lr_init=0.,
                 t_in_epochs=True,
                 noise_range_t=None,
                 noise_pct=0.67,
                 noise_std=1.0,
                 noise_seed=42,
                 initialize=True,
                 ) -> None:
        super().__init__(
            optimizer,
            param_group_field="lr",
            noise_range_t=noise_range_t,
            noise_pct=noise_pct,
            noise_std=noise_std,
            noise_seed=noise_seed,
            initialize=initialize,
        )

        self.t_initial = t_initial
        self.lr_min_rate = lr_min_rate
        self.warmup_t = warmup_t
        self.warmup_lr_init = warmup_lr_init
        self.t_in_epochs = t_in_epochs

        if not self.warmup_t:
            # No warmup phase: placeholder increments, never applied by _get_lr.
            self.warmup_steps = [1 for _ in self.base_values]
        else:
            # Per-step increment that takes each group from warmup_lr_init
            # to its base value over warmup_t steps.
            self.warmup_steps = [(base - warmup_lr_init) / self.warmup_t
                                 for base in self.base_values]
            super().update_groups(self.warmup_lr_init)

    def _get_lr(self, t):
        """Return one learning rate per base value for step/epoch ``t``."""
        if t < self.warmup_t:
            return [self.warmup_lr_init + t * increment for increment in self.warmup_steps]

        # Position within the decay phase, measured after warmup.
        t = t - self.warmup_t
        total_t = self.t_initial - self.warmup_t
        return [base - ((base - base * self.lr_min_rate) * (t / total_t))
                for base in self.base_values]

    def get_epoch_values(self, epoch: int):
        """Values for ``epoch`` when scheduling per epoch, otherwise ``None``."""
        return self._get_lr(epoch) if self.t_in_epochs else None

    def get_update_values(self, num_updates: int):
        """Values for ``num_updates`` when scheduling per update, otherwise ``None``."""
        return None if self.t_in_epochs else self._get_lr(num_updates)
103 |
--------------------------------------------------------------------------------
/get_started.md:
--------------------------------------------------------------------------------
1 | # LipsFormer-Swin for Image Classification
2 |
3 | This folder contains the implementation of the LipsFormer-Swin for image classification.
4 |
5 | ## Usage
6 |
7 | ### Install
8 |
9 | - Clone this repo:
10 |
11 | ```bash
12 | git clone https://github.com/IDEA-Research/LipsFormer.git
13 | cd LipsFormer
14 | ```
15 |
16 | - Create a conda virtual environment and activate it:
17 |
18 | ```bash
19 | conda create -n lipsformer python=3.7 -y
20 | conda activate lipsformer
21 | ```
22 |
23 | - Install `CUDA==10.1` with `cudnn7` following
24 | the [official installation instructions](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html)
25 | - Install `PyTorch==1.7.1` and `torchvision==0.8.2` with `CUDA==10.1`:
26 |
27 | ```bash
28 | conda install pytorch==1.7.1 torchvision==0.8.2 cudatoolkit=10.1 -c pytorch
29 | ```
30 |
31 | - Install `timm==0.3.2`:
32 |
33 | ```bash
34 | pip install timm==0.3.2
35 | ```
36 |
37 | - Install `Apex`:
38 |
39 | ```bash
40 | git clone https://github.com/NVIDIA/apex
41 | cd apex
42 | pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
43 | ```
44 |
45 | - Install other requirements:
46 |
47 | ```bash
48 | pip install opencv-python==4.4.0.46 termcolor==1.1.0 yacs==0.1.8
49 | ```
50 |
51 | ### Data preparation
52 |
53 | We use the standard ImageNet dataset, which you can download from http://image-net.org/. We provide the following two ways to
54 | load data:
55 |
56 | - For standard folder dataset, move validation images to labeled sub-folders. The file structure should look like:
57 | ```bash
58 | $ tree data
59 | imagenet
60 | ├── train
61 | │ ├── class1
62 | │ │ ├── img1.jpeg
63 | │ │ ├── img2.jpeg
64 | │ │ └── ...
65 | │ ├── class2
66 | │ │ ├── img3.jpeg
67 | │ │ └── ...
68 | │ └── ...
69 | └── val
70 | ├── class1
71 | │ ├── img4.jpeg
72 | │ ├── img5.jpeg
73 | │ └── ...
74 | ├── class2
75 | │ ├── img6.jpeg
76 | │ └── ...
77 | └── ...
78 |
79 | ```
80 | - To speed up reading images from a massive number of small files, we also support zipped ImageNet, which includes
81 | four files:
82 | - `train.zip`, `val.zip`: which store the zipped folder for train and validate splits.
83 | - `train_map.txt`, `val_map.txt`: which store the relative path in the corresponding zip file and ground truth
84 | label. Make sure the data folder looks like this:
85 |
86 | ```bash
87 | $ tree data
88 | data
89 | └── ImageNet-Zip
90 | ├── train_map.txt
91 | ├── train.zip
92 | ├── val_map.txt
93 | └── val.zip
94 |
95 | $ head -n 5 data/ImageNet-Zip/val_map.txt
96 | ILSVRC2012_val_00000001.JPEG 65
97 | ILSVRC2012_val_00000002.JPEG 970
98 | ILSVRC2012_val_00000003.JPEG 230
99 | ILSVRC2012_val_00000004.JPEG 809
100 | ILSVRC2012_val_00000005.JPEG 516
101 |
102 | $ head -n 5 data/ImageNet-Zip/train_map.txt
103 | n01440764/n01440764_10026.JPEG 0
104 | n01440764/n01440764_10027.JPEG 0
105 | n01440764/n01440764_10029.JPEG 0
106 | n01440764/n01440764_10040.JPEG 0
107 | n01440764/n01440764_10042.JPEG 0
108 | ```
109 |
110 | ### Evaluation
111 |
112 | To evaluate a pre-trained `LipsFormer` on ImageNet val, run:
113 |
114 | ```bash
115 | python -m torch.distributed.launch --nproc_per_node <num-of-gpus-to-use> --master_port 12345 main.py --eval \
116 | --cfg <config-file> --resume <checkpoint> --data-path <imagenet-path>
117 | ```
118 |
119 | For example, to evaluate the `LipsFormer-Swin-B` with a single GPU:
120 |
121 | ```bash
122 | python -m torch.distributed.launch --nproc_per_node 1 --master_port 12345 main.py --eval \
123 | --cfg configs/swin_base_patch4_window7_224.yaml --resume swin_base_patch4_window7_224.pth --data-path <imagenet-path>
124 | ```
125 |
126 | ### Training from scratch
127 |
128 | To train a `LipsFormer-Swin` on ImageNet from scratch, run:
129 |
130 | ```bash
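131 | python -m torch.distributed.launch --nproc_per_node <num-of-gpus-to-use> --master_port 12345 main.py \
132 | --cfg <config-file> --data-path <imagenet-path> [--batch-size <batch-size-per-gpu> --output <output-directory> --tag <job-tag>]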
133 | ```
134 |
135 | **Notes**:
136 |
137 | - To use zipped ImageNet instead of folder dataset, add `--zip` to the parameters.
138 | - To cache the dataset in the memory instead of reading from files every time, add `--cache-mode part`, which will
139 | shard the dataset into non-overlapping pieces for different GPUs and only load the corresponding one for each GPU.
140 | - When GPU memory is not enough, you can try the following suggestions:
141 | - Use gradient accumulation by adding `--accumulation-steps <steps>`, and set an appropriate `<steps>` according to your needs.
142 | - Use gradient checkpointing by adding `--use-checkpoint`, e.g., it saves about 60% memory when training `Swin-B`.
143 | Please refer to [this page](https://pytorch.org/docs/stable/checkpoint.html) for more details.
144 | - We recommend using multi-node with more GPUs for training very large models, a tutorial can be found
145 | in [this page](https://pytorch.org/tutorials/intermediate/dist_tuto.html).
146 | - To change config options in general, you can use `--opts KEY1 VALUE1 KEY2 VALUE2`, e.g.,
147 | `--opts TRAIN.EPOCHS 100 TRAIN.WARMUP_EPOCHS 5` will change total epochs to 100 and warm-up epochs to 5.
148 | - For additional options, see [config](config.py) and run `python main.py --help` to get a detailed help message.
149 |
150 | For example, to train `LipsFormer` with 8 GPUs on a single node for 300 epochs, run:
151 |
152 | `Swin-T`:
153 |
154 | ```bash
155 | python -m torch.distributed.launch --nproc_per_node 8 --master_port 12345 main.py \
156 | --cfg configs/swin_tiny_patch4_window7_224.yaml --data-path <imagenet-path> --batch-size 128
157 | ```
158 |
159 | `Swin-S`:
160 |
161 | ```bash
162 | python -m torch.distributed.launch --nproc_per_node 8 --master_port 12345 main.py \
163 | --cfg configs/swin_small_patch4_window7_224.yaml --data-path <imagenet-path> --batch-size 128
164 | ```
165 |
166 | `Swin-B`:
167 |
168 | ```bash
169 | python -m torch.distributed.launch --nproc_per_node 8 --master_port 12345 main.py \
170 | --cfg configs/swin_base_patch4_window7_224.yaml --data-path <imagenet-path> --batch-size 64 \
171 | --accumulation-steps 2 [--use-checkpoint]
172 | ```
173 |
174 | ### Throughput
175 |
176 | To measure the throughput, run:
177 |
178 | ```bash
179 | python -m torch.distributed.launch --nproc_per_node 1 --master_port 12345 main.py \
180 | --cfg <config-file> --data-path <imagenet-path> --batch-size 64 --throughput --amp-opt-level O0
181 | ```
182 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Swin Transformer
3 | # Copyright (c) 2021 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Ze Liu
6 | # --------------------------------------------------------
7 |
8 | import os
9 | import torch
10 | import torch.distributed as dist
11 |
12 | try:
13 | # noinspection PyUnresolvedReferences
14 | from apex import amp
15 | except ImportError:
16 | amp = None
17 |
18 |
def load_checkpoint(config, model, optimizer, lr_scheduler, logger):
    """Restore state from the checkpoint named by ``config.MODEL.RESUME``.

    The model weights are always loaded (non-strictly). Optimizer, LR
    scheduler, start epoch, amp state and best accuracy are restored only
    when not in eval mode and the checkpoint contains ``optimizer``,
    ``lr_scheduler`` and ``epoch`` entries.

    Args:
        config: config node; ``MODEL.RESUME``, ``EVAL_MODE`` and
            ``AMP_OPT_LEVEL`` are read, ``TRAIN.START_EPOCH`` is written.
        model: module whose ``load_state_dict`` receives ``checkpoint['model']``.
        optimizer: optimizer to restore (untouched in eval mode).
        lr_scheduler: scheduler to restore (untouched in eval mode).
        logger: object providing ``info`` and ``warning``.

    Returns:
        The ``max_accuracy`` stored in the checkpoint, or ``0.0`` when it is
        absent or training state was not restored.
    """
    logger.info(f"==============> Resuming form {config.MODEL.RESUME}....................")
    if config.MODEL.RESUME.startswith('https'):
        checkpoint = torch.hub.load_state_dict_from_url(
            config.MODEL.RESUME, map_location='cpu', check_hash=True)
    else:
        checkpoint = torch.load(config.MODEL.RESUME, map_location='cpu')
    msg = model.load_state_dict(checkpoint['model'], strict=False)
    logger.info(msg)
    max_accuracy = 0.0
    has_training_state = all(key in checkpoint for key in ('optimizer', 'lr_scheduler', 'epoch'))
    if not config.EVAL_MODE and has_training_state:
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        # The config is frozen during normal operation; unfreeze just long
        # enough to record the epoch to continue from.
        config.defrost()
        config.TRAIN.START_EPOCH = checkpoint['epoch'] + 1
        config.freeze()
        if 'amp' in checkpoint and config.AMP_OPT_LEVEL != "O0" and checkpoint['config'].AMP_OPT_LEVEL != "O0":
            if amp is None:
                # apex failed to import at module load, so there is nothing to
                # restore the amp state into; carry on without it.
                logger.warning("Checkpoint contains amp state but apex is not installed; skipping amp state")
            else:
                amp.load_state_dict(checkpoint['amp'])
        logger.info(f"=> loaded successfully '{config.MODEL.RESUME}' (epoch {checkpoint['epoch']})")
        if 'max_accuracy' in checkpoint:
            max_accuracy = checkpoint['max_accuracy']

    # Drop the (potentially large) checkpoint before training continues.
    del checkpoint
    torch.cuda.empty_cache()
    return max_accuracy
44 |
45 |
def load_pretrained(config, model, logger):
    """Load fine-tuning weights from ``config.MODEL.PRETRAINED`` into ``model``.

    The checkpoint's ``'model'`` state dict is adapted before loading:

    * ``relative_position_index``, ``relative_coords_table`` and ``attn_mask``
      entries are dropped (the model re-initialises them);
    * ``relative_position_bias_table`` and ``absolute_pos_embed`` entries are
      bicubically resized when only their length differs from the model's;
    * the classifier head is remapped (22K -> 1K via ``data/map22kto1k.txt``)
      or zero-initialised when the number of classes differs.

    Args:
        config: config node; only ``MODEL.PRETRAINED`` is read.
        model: target module; must expose ``head.weight`` / ``head.bias``.
        logger: object providing ``info`` and ``warning``.
    """
    logger.info(f"==============> Loading weight {config.MODEL.PRETRAINED} for fine-tuning......")
    checkpoint = torch.load(config.MODEL.PRETRAINED, map_location='cpu')
    state_dict = checkpoint['model']
    # Snapshot once instead of rebuilding the model's state dict per key.
    current_state = model.state_dict()

    # delete relative_position_index, relative_coords_table and attn_mask
    # since we always re-init them
    for pattern in ("relative_position_index", "relative_coords_table", "attn_mask"):
        for k in [k for k in state_dict.keys() if pattern in k]:
            del state_dict[k]

    # bicubic interpolate relative_position_bias_table if not match
    relative_position_bias_table_keys = [k for k in state_dict.keys() if "relative_position_bias_table" in k]
    for k in relative_position_bias_table_keys:
        relative_position_bias_table_pretrained = state_dict[k]
        relative_position_bias_table_current = current_state[k]
        L1, nH1 = relative_position_bias_table_pretrained.size()
        L2, nH2 = relative_position_bias_table_current.size()
        if nH1 != nH2:
            # NOTE(review): the entry stays in state_dict, so load_state_dict
            # below will still reject the mismatched shape - confirm intent.
            logger.warning(f"Error in loading {k}, passing......")
        else:
            if L1 != L2:
                # Treat the table as an S x S grid per head and resize it.
                S1 = int(L1 ** 0.5)
                S2 = int(L2 ** 0.5)
                relative_position_bias_table_pretrained_resized = torch.nn.functional.interpolate(
                    relative_position_bias_table_pretrained.permute(1, 0).view(1, nH1, S1, S1), size=(S2, S2),
                    mode='bicubic')
                state_dict[k] = relative_position_bias_table_pretrained_resized.view(nH2, L2).permute(1, 0)

    # bicubic interpolate absolute_pos_embed if not match
    absolute_pos_embed_keys = [k for k in state_dict.keys() if "absolute_pos_embed" in k]
    for k in absolute_pos_embed_keys:
        absolute_pos_embed_pretrained = state_dict[k]
        absolute_pos_embed_current = current_state[k]
        _, L1, C1 = absolute_pos_embed_pretrained.size()
        _, L2, C2 = absolute_pos_embed_current.size()
        # Compare the pretrained channel count against the model's (the
        # previous `C1 != C1` could never be true).
        if C1 != C2:
            # NOTE(review): as above, the mismatched entry is not removed.
            logger.warning(f"Error in loading {k}, passing......")
        else:
            if L1 != L2:
                S1 = int(L1 ** 0.5)
                S2 = int(L2 ** 0.5)
                absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.reshape(-1, S1, S1, C1)
                absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.permute(0, 3, 1, 2)
                absolute_pos_embed_pretrained_resized = torch.nn.functional.interpolate(
                    absolute_pos_embed_pretrained, size=(S2, S2), mode='bicubic')
                absolute_pos_embed_pretrained_resized = absolute_pos_embed_pretrained_resized.permute(0, 2, 3, 1)
                absolute_pos_embed_pretrained_resized = absolute_pos_embed_pretrained_resized.flatten(1, 2)
                state_dict[k] = absolute_pos_embed_pretrained_resized

    # check classifier, if not match, then re-init classifier to zero
    head_bias_pretrained = state_dict['head.bias']
    Nc1 = head_bias_pretrained.shape[0]
    Nc2 = model.head.bias.shape[0]
    if (Nc1 != Nc2):
        if Nc1 == 21841 and Nc2 == 1000:
            logger.info("loading ImageNet-22K weight to ImageNet-1K ......")
            map22kto1k_path = f'data/map22kto1k.txt'
            with open(map22kto1k_path) as f:
                map22kto1k = f.readlines()
            # One 22K class id per line; select those rows of the 22K head.
            map22kto1k = [int(id22k.strip()) for id22k in map22kto1k]
            state_dict['head.weight'] = state_dict['head.weight'][map22kto1k, :]
            state_dict['head.bias'] = state_dict['head.bias'][map22kto1k]
        else:
            torch.nn.init.constant_(model.head.bias, 0.)
            torch.nn.init.constant_(model.head.weight, 0.)
            del state_dict['head.weight']
            del state_dict['head.bias']
            logger.warning(f"Error in loading classifier head, re-init classifier head to 0")

    msg = model.load_state_dict(state_dict, strict=False)
    logger.warning(msg)

    logger.info(f"=> loaded successfully '{config.MODEL.PRETRAINED}'")

    del checkpoint
    torch.cuda.empty_cache()
134 |
def save_checkpoint(config, epoch, model, max_accuracy, optimizer, lr_scheduler, logger):
    """Write a training checkpoint to ``<config.OUTPUT>/ckpt_epoch_<epoch>.pth``.

    The file stores the model, optimizer and LR-scheduler state dicts together
    with ``max_accuracy``, ``epoch`` and the config object itself. Progress is
    reported through ``logger.info`` before and after the write.
    """
    target = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth')
    # Note: amp state is not stored in the checkpoint.
    payload = dict(
        model=model.state_dict(),
        optimizer=optimizer.state_dict(),
        lr_scheduler=lr_scheduler.state_dict(),
        max_accuracy=max_accuracy,
        epoch=epoch,
        config=config,
    )

    logger.info(f"{target} saving......")
    torch.save(payload, target)
    logger.info(f"{target} saved !!!")
149 |
150 |
def get_grad_norm(parameters, norm_type=2):
    """Return the total ``norm_type``-norm of the gradients of ``parameters``.

    Args:
        parameters: a single tensor or an iterable of tensors; entries whose
            ``.grad`` is ``None`` are ignored.
        norm_type: order of the norm. ``float('inf')`` yields the largest
            per-parameter max-norm instead of going through the power/root
            formula, which degenerates for an infinite exponent.

    Returns:
        The combined gradient norm as a Python number (zero when no parameter
        has a gradient).
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    grads = [p.grad.detach() for p in parameters if p.grad is not None]
    norm_type = float(norm_type)
    if norm_type == float('inf'):
        # x ** inf followed by ** (1 / inf) would collapse to a meaningless
        # value; the infinity norm of the concatenation is simply the max.
        return max((g.norm(norm_type).item() for g in grads), default=0.0)
    total_norm = 0
    for g in grads:
        total_norm += g.norm(norm_type).item() ** norm_type
    return total_norm ** (1. / norm_type)
162 |
163 |
def auto_resume_helper(output_dir):
    """Return the most recently modified ``*pth`` file in ``output_dir``.

    Entries whose name ends with ``pth`` are considered checkpoints; the one
    with the newest modification time wins. Returns ``None`` when the
    directory contains no such file.
    """
    checkpoints = [name for name in os.listdir(output_dir) if name.endswith('pth')]
    print(f"All checkpoints founded in {output_dir}: {checkpoints}")
    if not checkpoints:
        return None

    candidates = [os.path.join(output_dir, name) for name in checkpoints]
    latest_checkpoint = max(candidates, key=os.path.getmtime)
    print(f"The latest checkpoint founded: {latest_checkpoint}")
    return latest_checkpoint
175 |
176 |
def reduce_tensor(tensor):
    """Return a copy of ``tensor`` averaged over all processes.

    The copy is summed across the process group with ``all_reduce`` and then
    divided by the world size; the input tensor itself is left unchanged.
    """
    averaged = tensor.clone()
    dist.all_reduce(averaged, op=dist.ReduceOp.SUM)
    averaged.div_(dist.get_world_size())
    return averaged
182 |
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # LipsFormer Swin
3 | # Copyright (c) 2021 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Ze Liu
6 | # --------------------------------------------------------'
7 |
8 | import os
9 | import yaml
10 | from yacs.config import CfgNode as CN
11 |
# Root node holding the default configuration; get_config() hands out clones
# of it so these defaults are never modified in place.
_C = CN()

# Base config files
# (each non-empty entry is merged first, resolved relative to the including file)
_C.BASE = ['']

# -----------------------------------------------------------------------------
# Data settings
# -----------------------------------------------------------------------------
_C.DATA = CN()
# Batch size for a single GPU, could be overwritten by command line argument
_C.DATA.BATCH_SIZE = 128
# Path to dataset, could be overwritten by command line argument
_C.DATA.DATA_PATH = ''
# Dataset name
_C.DATA.DATASET = 'imagenet'
# Input image size
_C.DATA.IMG_SIZE = 224
# Interpolation to resize image (random, bilinear, bicubic)
_C.DATA.INTERPOLATION = 'bicubic'
# Use zipped dataset instead of folder dataset
# could be overwritten by command line argument
_C.DATA.ZIP_MODE = False
# Cache Data in Memory, could be overwritten by command line argument
_C.DATA.CACHE_MODE = 'part'
# Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.
_C.DATA.PIN_MEMORY = True
# Number of data loading threads
_C.DATA.NUM_WORKERS = 8

# NOTE(review): not read anywhere in this file; presumably selects the
# transform pipeline in the data builder - confirm against data/build.py.
_C.DATA.TRANSFORM_TYPE = None

# -----------------------------------------------------------------------------
# Model settings
# -----------------------------------------------------------------------------
_C.MODEL = CN()
# Model type
_C.MODEL.TYPE = 'swin'
# Model name
# (also used as a path component of the final OUTPUT directory)
_C.MODEL.NAME = 'swin_tiny_patch4_window7_224'
# Checkpoint to resume, could be overwritten by command line argument
_C.MODEL.RESUME = ''
# Number of classes, overwritten in data preparation
_C.MODEL.NUM_CLASSES = 1000
# Dropout rate
_C.MODEL.DROP_RATE = 0.0
# Drop path rate
_C.MODEL.DROP_PATH_RATE = 0.1
# Label Smoothing
_C.MODEL.LABEL_SMOOTHING = 0.1

# Path of the pretrained checkpoint read by utils.load_pretrained for fine-tuning
_C.MODEL.PRETRAINED = None


# Swin Transformer parameters
_C.MODEL.SWIN = CN()
_C.MODEL.SWIN.PATCH_SIZE = 4
_C.MODEL.SWIN.IN_CHANS = 3
_C.MODEL.SWIN.EMBED_DIM = 96
_C.MODEL.SWIN.DEPTHS = [2, 2, 6, 2]
_C.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24]
_C.MODEL.SWIN.WINDOW_SIZE = 7
_C.MODEL.SWIN.MLP_RATIO = 4.
_C.MODEL.SWIN.QKV_BIAS = True
_C.MODEL.SWIN.QK_SCALE = None
_C.MODEL.SWIN.APE = False
_C.MODEL.SWIN.PATCH_NORM = True

# Swin MLP parameters
_C.MODEL.SWIN_MLP = CN()
_C.MODEL.SWIN_MLP.PATCH_SIZE = 4
_C.MODEL.SWIN_MLP.IN_CHANS = 3
_C.MODEL.SWIN_MLP.EMBED_DIM = 96
_C.MODEL.SWIN_MLP.DEPTHS = [2, 2, 6, 2]
_C.MODEL.SWIN_MLP.NUM_HEADS = [3, 6, 12, 24]
_C.MODEL.SWIN_MLP.WINDOW_SIZE = 7
_C.MODEL.SWIN_MLP.MLP_RATIO = 4.
_C.MODEL.SWIN_MLP.APE = False
_C.MODEL.SWIN_MLP.PATCH_NORM = True

# -----------------------------------------------------------------------------
# Training settings
# -----------------------------------------------------------------------------
_C.TRAIN = CN()
# First epoch to run; utils.load_checkpoint sets it to the checkpoint epoch + 1 on resume
_C.TRAIN.START_EPOCH = 0
# Total number of epochs, could be overwritten by command line argument
_C.TRAIN.EPOCHS = 100
_C.TRAIN.WARMUP_EPOCHS = 0
_C.TRAIN.WEIGHT_DECAY = 0.05
_C.TRAIN.BASE_LR = 5e-4
_C.TRAIN.WARMUP_LR = 5e-4
_C.TRAIN.MIN_LR = 5e-6
# Clip gradient norm
_C.TRAIN.CLIP_GRAD = 5.0
# Auto resume from latest checkpoint
_C.TRAIN.AUTO_RESUME = True
# Gradient accumulation steps
# could be overwritten by command line argument
_C.TRAIN.ACCUMULATION_STEPS = 0
# Whether to use gradient checkpointing to save memory
# could be overwritten by command line argument
_C.TRAIN.USE_CHECKPOINT = False

# LR scheduler
_C.TRAIN.LR_SCHEDULER = CN()
_C.TRAIN.LR_SCHEDULER.NAME = 'cosine'
# Epoch interval to decay LR, used in StepLRScheduler
_C.TRAIN.LR_SCHEDULER.DECAY_EPOCHS = 30
# LR decay rate, used in StepLRScheduler
_C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1

# Optimizer
_C.TRAIN.OPTIMIZER = CN()
_C.TRAIN.OPTIMIZER.NAME = 'adamw'
# Optimizer Epsilon
_C.TRAIN.OPTIMIZER.EPS = 1e-8
# Optimizer Betas
#_C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999)
_C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.98)
# SGD momentum
_C.TRAIN.OPTIMIZER.MOMENTUM = 0.9

# -----------------------------------------------------------------------------
# Augmentation settings
# -----------------------------------------------------------------------------
_C.AUG = CN()
# Color jitter factor
_C.AUG.COLOR_JITTER = 0.4
# Use AutoAugment policy. "v0" or "original"
_C.AUG.AUTO_AUGMENT = 'rand-m9-mstd0.5-inc1'
# Random erase prob
_C.AUG.REPROB = 0.25
# Random erase mode
_C.AUG.REMODE = 'pixel'
# Random erase count
_C.AUG.RECOUNT = 1
# Mixup alpha, mixup enabled if > 0
_C.AUG.MIXUP = 0.8
# Cutmix alpha, cutmix enabled if > 0
_C.AUG.CUTMIX = 1.0
# Cutmix min/max ratio, overrides alpha and enables cutmix if set
_C.AUG.CUTMIX_MINMAX = None
# Probability of performing mixup or cutmix when either/both is enabled
_C.AUG.MIXUP_PROB = 1.0
# Probability of switching to cutmix when both mixup and cutmix enabled
_C.AUG.MIXUP_SWITCH_PROB = 0.5
# How to apply mixup/cutmix params. Per "batch", "pair", or "elem"
_C.AUG.MIXUP_MODE = 'batch'

# -----------------------------------------------------------------------------
# Testing settings
# -----------------------------------------------------------------------------
_C.TEST = CN()
# Whether to use center crop when testing
_C.TEST.CROP = True

# -----------------------------------------------------------------------------
# Misc
# -----------------------------------------------------------------------------
# Mixed precision opt level, if O0, no amp is used ('O0', 'O1', 'O2')
# overwritten by command line argument
_C.AMP_OPT_LEVEL = ''
# Path to output folder, overwritten by command line argument
# (update_config appends MODEL.NAME and TAG to it)
_C.OUTPUT = ''
# Tag of experiment, overwritten by command line argument
_C.TAG = 'default'
# Frequency to save checkpoint
_C.SAVE_FREQ = 1
# Frequency to logging info
_C.PRINT_FREQ = 10
# Fixed random seed
_C.SEED = 0
# Perform evaluation only, overwritten by command line argument
_C.EVAL_MODE = False
# Test throughput only, overwritten by command line argument
_C.THROUGHPUT_MODE = False
# local rank for DistributedDataParallel, given by command line argument
_C.LOCAL_RANK = 0
188 |
189 |
def _update_config_from_file(config, cfg_file):
    """Merge the YAML file ``cfg_file`` into ``config``, base files first.

    Every non-empty entry of the file's ``BASE`` list is resolved relative to
    the directory of ``cfg_file`` and merged recursively before the file
    itself, so the including file wins over its bases. The config is left
    frozen on return.
    """
    config.defrost()
    with open(cfg_file, 'r') as stream:
        loaded = yaml.load(stream, Loader=yaml.FullLoader)

    base_dir = os.path.dirname(cfg_file)
    for base_cfg in loaded.setdefault('BASE', ['']):
        if not base_cfg:
            continue
        _update_config_from_file(config, os.path.join(base_dir, base_cfg))

    print(f'=> merge config from {cfg_file}')
    config.merge_from_file(cfg_file)
    config.freeze()
203 |
204 |
def update_config(config, args):
    """Apply the config file and command-line overrides from ``args`` to ``config``.

    Order of precedence (later wins): the YAML file ``args.cfg`` (with its
    bases), the ``args.opts`` key/value list, then the dedicated command-line
    arguments. Finally the local rank is recorded and ``OUTPUT`` is extended
    to ``<OUTPUT>/<MODEL.NAME>/<TAG>``. The config is frozen on return.
    """
    _update_config_from_file(config, args.cfg)

    config.defrost()
    if args.opts:
        config.merge_from_list(args.opts)

    # merge from specific arguments
    # Value-carrying arguments: copied into the config only when truthy.
    value_options = (
        ('epochs', config.TRAIN, 'EPOCHS'),
        ('batch_size', config.DATA, 'BATCH_SIZE'),
        ('data_path', config.DATA, 'DATA_PATH'),
        ('cache_mode', config.DATA, 'CACHE_MODE'),
        ('resume', config.MODEL, 'RESUME'),
        ('accumulation_steps', config.TRAIN, 'ACCUMULATION_STEPS'),
        ('amp_opt_level', config, 'AMP_OPT_LEVEL'),
        ('output', config, 'OUTPUT'),
        ('tag', config, 'TAG'),
    )
    for arg_name, node, key in value_options:
        value = getattr(args, arg_name)
        if value:
            setattr(node, key, value)

    # Switches: a truthy argument turns the corresponding option on.
    flag_options = (
        ('zip', config.DATA, 'ZIP_MODE'),
        ('use_checkpoint', config.TRAIN, 'USE_CHECKPOINT'),
        ('eval', config, 'EVAL_MODE'),
        ('throughput', config, 'THROUGHPUT_MODE'),
    )
    for arg_name, node, key in flag_options:
        if getattr(args, arg_name):
            setattr(node, key, True)

    # set local rank for distributed training
    config.LOCAL_RANK = args.local_rank

    # output folder
    config.OUTPUT = os.path.join(config.OUTPUT, config.MODEL.NAME, config.TAG)

    config.freeze()
247 |
248 |
def get_config(args):
    """Return a config built from the defaults and updated from ``args``."""
    # Work on a clone so the module-level defaults in ``_C`` stay untouched.
    cfg = _C.clone()
    update_config(cfg, args)
    return cfg
257 |
--------------------------------------------------------------------------------
/data/cached_image_folder.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Swin Transformer
3 | # Copyright (c) 2021 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Ze Liu
6 | # --------------------------------------------------------
7 |
8 | import io
9 | import os
10 | import time
11 | import torch.distributed as dist
12 | import torch.utils.data as data
13 | from PIL import Image
14 |
15 | from .zipreader import is_zip_path, ZipReader
16 |
17 |
def has_file_allowed_extension(filename, extensions):
    """Tell whether ``filename`` ends with one of ``extensions`` (case-insensitive).

    Args:
        filename (string): path to a file
        extensions (iterable of string): suffixes to accept; compared against
            the lower-cased file name
    Returns:
        bool: True if the lower-cased filename ends with any given extension
    """
    lowered = filename.lower()
    for ext in extensions:
        if lowered.endswith(ext):
            return True
    return False
27 |
28 |
def find_classes(dir):
    """List the class sub-directories of ``dir`` and index them.

    Returns:
        tuple: (sorted list of sub-directory names, dict mapping each name to
        its position in that list)
    """
    classes = sorted(
        entry for entry in os.listdir(dir)
        if os.path.isdir(os.path.join(dir, entry))
    )
    class_to_idx = {name: idx for idx, name in enumerate(classes)}
    return classes, class_to_idx
34 |
35 |
def make_dataset(dir, class_to_idx, extensions):
    """Collect ``(path, class_index)`` pairs for every image under ``dir``.

    Each sub-directory of ``dir`` is treated as one class and walked
    recursively; files whose name passes ``has_file_allowed_extension`` are
    kept. Directories and file names are visited in sorted order.
    """
    dir = os.path.expanduser(dir)
    samples = []
    for class_name in sorted(os.listdir(dir)):
        class_dir = os.path.join(dir, class_name)
        if os.path.isdir(class_dir):
            for root, _, fnames in sorted(os.walk(class_dir)):
                samples.extend(
                    (os.path.join(root, fname), class_to_idx[class_name])
                    for fname in sorted(fnames)
                    if has_file_allowed_extension(fname, extensions)
                )
    return samples
52 |
53 |
def make_dataset_with_ann(ann_file, img_prefix, extensions):
    """Build ``(path, class_index)`` samples from a tab-separated annotation file.

    Each non-blank line of ``ann_file`` is ``<image file name>\\t<class index>``.
    Blank lines (e.g. a trailing empty line) are skipped instead of failing.

    Args:
        ann_file: path of the annotation file.
        img_prefix: prefix joined in front of every image file name.
        extensions: allowed lower-case extensions; each image's extension is
            asserted to be one of them.

    Returns:
        list of ``(os.path.join(img_prefix, file_name), class_index)`` tuples
        in file order.
    """
    images = []
    with open(ann_file, "r") as f:
        # Stream the file rather than materialising every line up front.
        for line_str in f:
            if not line_str.strip():
                continue
            path_contents = line_str.split('\t')
            im_file_name = path_contents[0]
            # int() tolerates the trailing newline left on the last field.
            class_index = int(path_contents[1])

            assert str.lower(os.path.splitext(im_file_name)[-1]) in extensions
            images.append((os.path.join(img_prefix, im_file_name), class_index))

    return images
69 |
70 |
class DatasetFolder(data.Dataset):
    """A generic data loader where the samples are arranged in this way: ::
        root/class_x/xxx.ext
        root/class_x/xxy.ext
        root/class_x/xxz.ext
        root/class_y/123.ext
        root/class_y/nsdf3.ext
        root/class_y/asd932_.ext
    Alternatively, when ``ann_file`` is given, samples are listed in that
    annotation file (relative to ``root``) instead of being discovered on disk.
    Args:
        root (string): Root directory path.
        loader (callable): A function to load a sample given its path.
        extensions (list[string]): A list of allowed extensions.
        ann_file (string, optional): Annotation file name under ``root``;
            empty selects the folder layout above.
        img_prefix (string, optional): Prefix under ``root`` joined in front of
            the file names from ``ann_file``.
        transform (callable, optional): A function/transform that takes in
            a sample and returns a transformed version.
            E.g, ``transforms.RandomCrop`` for images.
        target_transform (callable, optional): A function/transform that takes
            in the target and transforms it.
        cache_mode (string, optional): "no", "part" or "full"; anything other
            than "no" pre-reads sample bytes via ``init_cache``.
    Attributes:
        samples (list): List of (sample path, class_index) tuples
    """

    def __init__(self, root, loader, extensions, ann_file='', img_prefix='', transform=None, target_transform=None,
                 cache_mode="no"):
        # image folder mode
        if ann_file == '':
            _, class_to_idx = find_classes(root)
            samples = make_dataset(root, class_to_idx, extensions)
        # zip mode
        else:
            samples = make_dataset_with_ann(os.path.join(root, ann_file),
                                            os.path.join(root, img_prefix),
                                            extensions)

        if len(samples) == 0:
            raise (RuntimeError("Found 0 files in subfolders of: " + root + "\n" +
                                "Supported extensions are: " + ",".join(extensions)))

        self.root = root
        self.loader = loader
        self.extensions = extensions

        self.samples = samples
        self.labels = [y_1k for _, y_1k in samples]
        self.classes = list(set(self.labels))

        self.transform = transform
        self.target_transform = target_transform

        self.cache_mode = cache_mode
        if self.cache_mode != "no":
            self.init_cache()

    def init_cache(self):
        """Replace sample paths by their raw bytes according to ``cache_mode``.

        "full" reads every sample; "part" reads only the samples whose index
        modulo the world size equals this process's rank and keeps the path
        for the rest.
        """
        assert self.cache_mode in ["part", "full"]
        n_sample = len(self.samples)
        global_rank = dist.get_rank()
        world_size = dist.get_world_size()

        # Report progress about every 10% of the data. Clamp the interval to 1
        # so datasets with fewer than 10 samples do not divide by zero.
        report_every = max(n_sample // 10, 1)

        samples_bytes = [None for _ in range(n_sample)]
        start_time = time.time()
        for index in range(n_sample):
            if index % report_every == 0:
                t = time.time() - start_time
                print(f'global_rank {global_rank} cached {index}/{n_sample} takes {t:.2f}s per block')
                start_time = time.time()
            path, target = self.samples[index]
            if self.cache_mode == "full":
                samples_bytes[index] = (ZipReader.read(path), target)
            elif self.cache_mode == "part" and index % world_size == global_rank:
                samples_bytes[index] = (ZipReader.read(path), target)
            else:
                samples_bytes[index] = (path, target)
        self.samples = samples_bytes

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (sample, target) where target is class_index of the target class.
        """
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target

    def __len__(self):
        return len(self.samples)

    def __repr__(self):
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
        fmt_str += ' Root Location: {}\n'.format(self.root)
        tmp = ' Transforms (if any): '
        # Indent continuation lines of a multi-line transform repr under the label.
        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        tmp = ' Target Transforms (if any): '
        fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        return fmt_str
173 |
174 |
# Lower-case file suffixes accepted as images; CachedImageFolder passes this
# list to DatasetFolder as its ``extensions`` argument.
IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif']
176 |
177 |
def pil_loader(path):
    """Load an image with PIL and return it converted to RGB.

    Args:
        path: one of
            * ``bytes`` - already-read encoded image data (cached samples),
            * a zip path accepted by ``is_zip_path`` - read through ``ZipReader``,
            * a regular file-system path.

    Returns:
        The decoded image in ``'RGB'`` mode.
    """
    if isinstance(path, bytes):
        img = Image.open(io.BytesIO(path))
    elif is_zip_path(path):
        img = Image.open(io.BytesIO(ZipReader.read(path)))
    else:
        # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
        with open(path, 'rb') as f:
            # Image.open is lazy: the pixel data is only read by convert(), so
            # the conversion has to happen while the file is still open.
            return Image.open(f).convert('RGB')
    return img.convert('RGB')
189 |
190 |
def accimage_loader(path):
    """Load ``path`` with accimage, falling back to ``pil_loader`` on IOError."""
    import accimage
    try:
        image = accimage.Image(path)
    except IOError:
        # Potentially a decoding problem, fall back to PIL.Image
        return pil_loader(path)
    return image
198 |
199 |
def default_img_loader(path):
    """Load ``path`` with the loader matching torchvision's image backend.

    Uses ``accimage_loader`` when the backend is ``'accimage'`` and
    ``pil_loader`` otherwise.
    """
    from torchvision import get_image_backend
    loader = accimage_loader if get_image_backend() == 'accimage' else pil_loader
    return loader(path)
206 |
207 |
class CachedImageFolder(DatasetFolder):
    """Image dataset built on ``DatasetFolder`` with ``IMG_EXTENSIONS`` as filter.

    Images are expected either in a class-per-folder layout: ::
        root/dog/xxx.png
        root/dog/xxy.png
        root/dog/xxz.png
        root/cat/123.png
        root/cat/nsdf3.png
        root/cat/asd932_.png
    or listed in an annotation file when ``ann_file`` is given.
    Args:
        root (string): Root directory path.
        ann_file (string, optional): Annotation file, forwarded to ``DatasetFolder``.
        img_prefix (string, optional): Image prefix, forwarded to ``DatasetFolder``.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        loader (callable, optional): A function to load an image given its path.
        cache_mode (string, optional): Caching mode, forwarded to ``DatasetFolder``.
    Attributes:
        imgs (list): List of (image path, class_index) tuples
    """

    def __init__(self, root, ann_file='', img_prefix='', transform=None, target_transform=None,
                 loader=default_img_loader, cache_mode="no"):
        super().__init__(
            root,
            loader,
            IMG_EXTENSIONS,
            ann_file=ann_file,
            img_prefix=img_prefix,
            transform=transform,
            target_transform=target_transform,
            cache_mode=cache_mode,
        )
        # Alias kept alongside ``samples``; both names refer to the same list.
        self.imgs = self.samples

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where target is class_index of the target class.
        """
        path, target = self.samples[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target
252 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright 2023 IDEA
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
203 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Swin Transformer
3 | # Copyright (c) 2021 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Ze Liu
6 | # --------------------------------------------------------
7 |
8 | import os
9 | import time
10 | import argparse
11 | import datetime
12 | import numpy as np
13 |
14 | import torch
15 | import torch.backends.cudnn as cudnn
16 | import torch.distributed as dist
17 |
18 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
19 | from timm.utils import accuracy, AverageMeter
20 |
21 | from config import get_config
22 | from models import build_model
23 | from data import build_loader, build_22k_loader
24 | from lr_scheduler import build_scheduler
25 | from optimizer import build_optimizer
26 | from logger import create_logger
27 | from utils import load_checkpoint, save_checkpoint, load_pretrained, get_grad_norm, auto_resume_helper, reduce_tensor
28 |
29 | import bai
30 |
31 | # try:
32 | # # noinspection PyUnresolvedReferences
33 | # from apex import amp
34 | # except ImportError:
35 | # amp = None
36 | from torch.cuda import amp
37 | from torch.cuda.amp import autocast as autocast
38 |
39 |
40 | def parse_option():
41 | parser = argparse.ArgumentParser('Swin Transformer training and evaluation script', add_help=False)
42 | parser.add_argument('--cfg', type=str, required=True, metavar="FILE", help='path to config file', )
43 | parser.add_argument(
44 | "--opts",
45 | help="Modify config options by adding 'KEY VALUE' pairs. ",
46 | default=None,
47 | nargs='+',
48 | )
49 |
50 | # easy config modification
51 | parser.add_argument('--batch-size', type=int, help="batch size for single GPU")
52 | parser.add_argument('--data-path', type=str, help='path to dataset')
53 | parser.add_argument('--zip', action='store_true', help='use zipped dataset instead of folder dataset')
54 | parser.add_argument('--cache-mode', type=str, default='part', choices=['no', 'full', 'part'],
55 | help='no: no cache, '
56 | 'full: cache all data, '
57 | 'part: sharding the dataset into nonoverlapping pieces and only cache one piece')
58 | parser.add_argument('--resume', help='resume from checkpoint')
59 | parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps")
60 | parser.add_argument('--use-checkpoint', action='store_true',
61 | help="whether to use gradient checkpointing to save memory")
62 | parser.add_argument('--amp-opt-level', type=str, default='O1', choices=['O0', 'O1', 'O2'],
63 | help='mixed precision opt level, if O0, no amp is used')
64 | parser.add_argument('--output', default='output', type=str, metavar='PATH',
65 | help='root of output folder, the full path is