├── models ├── __init__.py ├── build.py └── lipsformer_swin.py ├── data ├── __init__.py ├── samplers.py ├── zipreader.py ├── cached_image_folder.py └── build.py ├── figures └── teaser.png ├── configs ├── swin_small_patch4_window7_224.yaml ├── swin_tiny_patch4_window7_224.yaml ├── cswin_tiny_patch4_window7_224.yaml ├── cswin_base_patch4_window7_224.yaml ├── swin_mlp_base_patch4_window7_224.yaml ├── swin_tiny_c24_patch4_window8_256.yaml ├── swin_mlp_tiny_c12_patch4_window8_256.yaml ├── swin_mlp_tiny_c24_patch4_window8_256.yaml ├── swin_mlp_tiny_c6_patch4_window8_256.yaml ├── swin_base_patch4_window12_384.yaml ├── swin_large_patch4_window12_384.yaml ├── swin_base_patch4_window7_224.yaml ├── swin_large_patch4_window7_224.yaml ├── swin_large_plus_patch4_window7_224_22k.yaml ├── swin_large_patch4_window12_384_22kto1k_finetune_swin_finetune.yaml └── swin_large_patch4_window12_384_22kto1k_finetune.yaml ├── CODE_OF_CONDUCT.md ├── logger.py ├── optimizer.py ├── .gitignore ├── README.md ├── lr_scheduler.py ├── get_started.md ├── utils.py ├── config.py ├── LICENSE └── main.py /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import build_model -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import build_loader, build_22k_loader -------------------------------------------------------------------------------- /figures/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-Research/LipsFormer/HEAD/figures/teaser.png -------------------------------------------------------------------------------- /configs/swin_small_patch4_window7_224.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: swin 3 | NAME: swin_small_patch4_window7_224 4 | DROP_PATH_RATE: 0.3 5 | SWIN: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 18, 2 ] 8 | NUM_HEADS: [ 3, 6, 12, 24 ] 9 | WINDOW_SIZE: 7 -------------------------------------------------------------------------------- /configs/swin_tiny_patch4_window7_224.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: swin 3 | NAME: swin_tiny_patch4_window7_224 4 | DROP_PATH_RATE: 0.2 5 | SWIN: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 6, 2 ] 8 | NUM_HEADS: [ 3, 6, 12, 24 ] 9 | WINDOW_SIZE: 7 -------------------------------------------------------------------------------- /configs/cswin_tiny_patch4_window7_224.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: cswin 3 | NAME: cswin_tiny_patch4_window7_224 4 | DROP_PATH_RATE: 0.2 5 | SWIN: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 6, 2 ] 8 | NUM_HEADS: [ 3, 6, 12, 24 ] 9 | WINDOW_SIZE: 7 10 | -------------------------------------------------------------------------------- /configs/cswin_base_patch4_window7_224.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: cswin 3 | NAME: swin_base_patch4_window7_224 4 | DROP_PATH_RATE: 0.5 5 | SWIN: 6 | EMBED_DIM: 128 7 | DEPTHS: [ 2, 2, 18, 2 ] 8 | NUM_HEADS: [ 4, 8, 16, 32 ] 9 | WINDOW_SIZE: 7 10 | -------------------------------------------------------------------------------- /configs/swin_mlp_base_patch4_window7_224.yaml: -------------------------------------------------------------------------------- 1 
| MODEL: 2 | TYPE: swin_mlp 3 | NAME: swin_mlp_base_patch4_window7_224 4 | DROP_PATH_RATE: 0.5 5 | SWIN_MLP: 6 | EMBED_DIM: 128 7 | DEPTHS: [ 2, 2, 18, 2 ] 8 | NUM_HEADS: [ 4, 8, 16, 32 ] 9 | WINDOW_SIZE: 7 10 | -------------------------------------------------------------------------------- /configs/swin_tiny_c24_patch4_window8_256.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | IMG_SIZE: 256 3 | MODEL: 4 | TYPE: swin 5 | NAME: swin_tiny_c24_patch4_window8_256 6 | DROP_PATH_RATE: 0.2 7 | SWIN: 8 | EMBED_DIM: 96 9 | DEPTHS: [ 2, 2, 6, 2 ] 10 | NUM_HEADS: [ 4, 8, 16, 32 ] 11 | WINDOW_SIZE: 8 -------------------------------------------------------------------------------- /configs/swin_mlp_tiny_c12_patch4_window8_256.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | IMG_SIZE: 256 3 | MODEL: 4 | TYPE: swin_mlp 5 | NAME: swin_mlp_tiny_c12_patch4_window8_256 6 | DROP_PATH_RATE: 0.2 7 | SWIN_MLP: 8 | EMBED_DIM: 96 9 | DEPTHS: [ 2, 2, 6, 2 ] 10 | NUM_HEADS: [ 8, 16, 32, 64 ] 11 | WINDOW_SIZE: 8 -------------------------------------------------------------------------------- /configs/swin_mlp_tiny_c24_patch4_window8_256.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | IMG_SIZE: 256 3 | MODEL: 4 | TYPE: swin_mlp 5 | NAME: swin_mlp_tiny_c24_patch4_window8_256 6 | DROP_PATH_RATE: 0.2 7 | SWIN_MLP: 8 | EMBED_DIM: 96 9 | DEPTHS: [ 2, 2, 6, 2 ] 10 | NUM_HEADS: [ 4, 8, 16, 32 ] 11 | WINDOW_SIZE: 8 -------------------------------------------------------------------------------- /configs/swin_mlp_tiny_c6_patch4_window8_256.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | IMG_SIZE: 256 3 | MODEL: 4 | TYPE: swin_mlp 5 | NAME: swin_mlp_tiny_c6_patch4_window8_256 6 | DROP_PATH_RATE: 0.2 7 | SWIN_MLP: 8 | EMBED_DIM: 96 9 | DEPTHS: [ 2, 2, 6, 2 ] 10 | NUM_HEADS: [ 16, 32, 64, 128 ] 11 | WINDOW_SIZE: 8 -------------------------------------------------------------------------------- /configs/swin_base_patch4_window12_384.yaml: -------------------------------------------------------------------------------- 1 | # only for evaluation 2 | DATA: 3 | IMG_SIZE: 384 4 | MODEL: 5 | TYPE: swin 6 | NAME: swin_base_patch4_window12_384 7 | SWIN: 8 | EMBED_DIM: 128 9 | DEPTHS: [ 2, 2, 18, 2 ] 10 | NUM_HEADS: [ 4, 8, 16, 32 ] 11 | WINDOW_SIZE: 12 12 | TEST: 13 | CROP: False -------------------------------------------------------------------------------- /configs/swin_large_patch4_window12_384.yaml: -------------------------------------------------------------------------------- 1 | # only for evaluation 2 | DATA: 3 | IMG_SIZE: 384 4 | MODEL: 5 | TYPE: swin 6 | NAME: swin_large_patch4_window12_384 7 | SWIN: 8 | EMBED_DIM: 192 9 | DEPTHS: [ 2, 2, 18, 2 ] 10 | NUM_HEADS: [ 6, 12, 24, 48 ] 11 | WINDOW_SIZE: 12 12 | TEST: 13 | CROP: False -------------------------------------------------------------------------------- /configs/swin_base_patch4_window7_224.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: swin 3 | NAME: swin_base_patch4_window7_224 4 | DROP_PATH_RATE: 0.5 5 | SWIN: 6 | EMBED_DIM: 128 7 | DEPTHS: [ 2, 2, 18, 2 ] 8 | NUM_HEADS: [ 4, 8, 16, 32 ] 9 | WINDOW_SIZE: 7 10 | 11 | TRAIN: 12 | BASE_LR: 2.5e-4 13 | OPTIMIZER: 14 | BETAS: (0.9, 0.998) -------------------------------------------------------------------------------- 
/configs/swin_large_patch4_window7_224.yaml: -------------------------------------------------------------------------------- 1 | # only for evaluation 2 | MODEL: 3 | TYPE: swin 4 | NAME: swin_large_patch4_window7_224 5 | DROP_PATH_RATE: 0.5 6 | SWIN: 7 | EMBED_DIM: 192 8 | DEPTHS: [ 2, 2, 18, 2 ] 9 | NUM_HEADS: [ 6, 12, 24, 48 ] 10 | WINDOW_SIZE: 7 11 | 12 | TRAIN: 13 | BASE_LR: 1.25e-4 14 | OPTIMIZER: 15 | BETAS: (0.9, 0.998) -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Microsoft Open Source Code of Conduct 2 | 3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 4 | 5 | Resources: 6 | 7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) 8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns 10 | -------------------------------------------------------------------------------- /configs/swin_large_plus_patch4_window7_224_22k.yaml: -------------------------------------------------------------------------------- 1 | # only for evaluation 2 | DATA: 3 | DATASET: imagenet22K 4 | MODEL: 5 | TYPE: swin 6 | NAME: swin_plus_large_patch4_window7_224_22k 7 | DROP_PATH_RATE: 0.5 8 | LABEL_SMOOTHING: 0.1 9 | SWIN: 10 | EMBED_DIM: 288 11 | DEPTHS: [ 2, 2, 18, 2 ] 12 | NUM_HEADS: [ 6, 12, 24, 48 ] 13 | WINDOW_SIZE: 7 14 | 15 | TRAIN: 16 | EPOCHS: 30 17 | BASE_LR: 1.25e-4 18 | WEIGHT_DECAY: 0.1 19 | OPTIMIZER: 20 | BETAS: (0.9, 0.98) -------------------------------------------------------------------------------- /configs/swin_large_patch4_window12_384_22kto1k_finetune_swin_finetune.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | IMG_SIZE: 384 3 | 4 | MODEL: 5 | TYPE: swin 6 | NAME: swin_large_patch4_window12_384_22kto1k_finetune 7 | PRETRAINED: /comp_robot/workspace/chenyihao/SwinLips_Lalpha_2/output/swin_plus_large_patch4_window7_224_22k/large_plus_22k_30_beta_2_0998/ckpt_epoch_28.pth 8 | DROP_PATH_RATE: 0.2 9 | SWIN: 10 | EMBED_DIM: 288 11 | DEPTHS: [ 2, 2, 18, 2 ] 12 | NUM_HEADS: [ 6, 12, 24, 48 ] 13 | WINDOW_SIZE: 12 14 | TRAIN: 15 | EPOCHS: 30 16 | WARMUP_EPOCHS: 5 17 | WEIGHT_DECAY: 1e-8 18 | BASE_LR: 2e-05 19 | WARMUP_LR: 2e-08 20 | MIN_LR: 2e-07 21 | OPTIMIZER: 22 | BETAS: (0.9, 0.998) 23 | TEST: 24 | CROP: False -------------------------------------------------------------------------------- /configs/swin_large_patch4_window12_384_22kto1k_finetune.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | IMG_SIZE: 384 3 | TRANSFORM_TYPE: easy 4 | 5 | MODEL: 6 | TYPE: swin 7 | NAME: swin_large_patch4_window12_384_22kto1k_finetune 8 | PRETRAINED: /comp_robot/workspace/chenyihao/SwinLips_Lalpha_2/output/swin_plus_large_patch4_window7_224_22k/large_plus_22k_30_beta_2_0998/ckpt_epoch_28.pth 9 | DROP_PATH_RATE: 0.2 10 | SWIN: 11 | EMBED_DIM: 288 12 | DEPTHS: [ 2, 2, 18, 2 ] 13 | NUM_HEADS: [ 6, 12, 24, 48 ] 14 | WINDOW_SIZE: 12 15 | TRAIN: 16 | OPTIMIZER: 17 | NAME: SGD 18 | MOMENTUM: 0.9 19 | 20 | EPOCHS: 8 21 | WEIGHT_DECAY: 0.0 22 | BASE_LR: 3e-02 23 | WARMUP_LR: 2e-08 24 | MIN_LR: 2e-07 25 | TEST: 26 | CROP: False 27 | AUG: 28 | MIXUP: 0.1 29 | CUTMIX: 0.0 30 | CUTMIX_MINMAX: None 31 | MIXUP_PROB: 1.0 32 | 
MIXUP_SWITCH_PROB: 0.0 -------------------------------------------------------------------------------- /data/samplers.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | import torch 9 | 10 | 11 | class SubsetRandomSampler(torch.utils.data.Sampler): 12 | r"""Samples elements randomly from a given list of indices, without replacement. 13 | 14 | Arguments: 15 | indices (sequence): a sequence of indices 16 | """ 17 | 18 | def __init__(self, indices): 19 | self.epoch = 0 20 | self.indices = indices 21 | 22 | def __iter__(self): 23 | return (self.indices[i] for i in torch.randperm(len(self.indices))) 24 | 25 | def __len__(self): 26 | return len(self.indices) 27 | 28 | def set_epoch(self, epoch): 29 | self.epoch = epoch 30 | -------------------------------------------------------------------------------- /models/build.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | from .lipsformer_swin import LipsFormerSwin 9 | 10 | def build_model(config): 11 | model_type = config.MODEL.TYPE 12 | model = LipsFormerSwin(img_size=config.DATA.IMG_SIZE, 13 | patch_size=config.MODEL.SWIN.PATCH_SIZE, 14 | in_chans=config.MODEL.SWIN.IN_CHANS, 15 | num_classes=config.MODEL.NUM_CLASSES, 16 | embed_dim=config.MODEL.SWIN.EMBED_DIM, 17 | depths=config.MODEL.SWIN.DEPTHS, 18 | num_heads=config.MODEL.SWIN.NUM_HEADS, 19 | window_size=config.MODEL.SWIN.WINDOW_SIZE, 20 | mlp_ratio=config.MODEL.SWIN.MLP_RATIO, 21 | qkv_bias=config.MODEL.SWIN.QKV_BIAS, 22 | qk_scale=config.MODEL.SWIN.QK_SCALE, 23 | drop_rate=config.MODEL.DROP_RATE, 24 | drop_path_rate=config.MODEL.DROP_PATH_RATE, 25 | ape=config.MODEL.SWIN.APE, 26 | patch_norm=config.MODEL.SWIN.PATCH_NORM, 27 | use_checkpoint=config.TRAIN.USE_CHECKPOINT) 28 | return model 29 | -------------------------------------------------------------------------------- /logger.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | import os 9 | import sys 10 | import logging 11 | import functools 12 | from termcolor import colored 13 | 14 | 15 | @functools.lru_cache() 16 | def create_logger(output_dir, dist_rank=0, name=''): 17 | # create logger 18 | logger = logging.getLogger(name) 19 | logger.setLevel(logging.DEBUG) 20 | logger.propagate = False 21 | 22 | # create formatter 23 | fmt = '[%(asctime)s %(name)s] (%(filename)s %(lineno)d): %(levelname)s %(message)s' 24 | color_fmt = colored('[%(asctime)s %(name)s]', 'green') + \ 25 | colored('(%(filename)s %(lineno)d)', 'yellow') + ': %(levelname)s %(message)s' 26 | 27 | # create console handlers for master process 28 | if dist_rank == 0: 29 | console_handler = logging.StreamHandler(sys.stdout) 30 | 
console_handler.setLevel(logging.DEBUG) 31 | console_handler.setFormatter( 32 | logging.Formatter(fmt=color_fmt, datefmt='%Y-%m-%d %H:%M:%S')) 33 | logger.addHandler(console_handler) 34 | 35 | # create file handlers 36 | file_handler = logging.FileHandler(os.path.join(output_dir, f'log_rank{dist_rank}.txt'), mode='a') 37 | file_handler.setLevel(logging.DEBUG) 38 | file_handler.setFormatter(logging.Formatter(fmt=fmt, datefmt='%Y-%m-%d %H:%M:%S')) 39 | logger.addHandler(file_handler) 40 | 41 | return logger 42 | -------------------------------------------------------------------------------- /optimizer.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | from torch import optim as optim 9 | 10 | 11 | def build_optimizer(config, model): 12 | """ 13 | Build optimizer, set weight decay of normalization to 0 by default. 14 | """ 15 | skip = {} 16 | skip_keywords = {} 17 | if hasattr(model, 'no_weight_decay'): 18 | skip = model.no_weight_decay() 19 | if hasattr(model, 'no_weight_decay_keywords'): 20 | skip_keywords = model.no_weight_decay_keywords() 21 | parameters = set_weight_decay(model, skip, skip_keywords) 22 | 23 | opt_lower = config.TRAIN.OPTIMIZER.NAME.lower() 24 | optimizer = None 25 | if opt_lower == 'sgd': 26 | optimizer = optim.SGD(parameters, momentum=config.TRAIN.OPTIMIZER.MOMENTUM, nesterov=True, 27 | lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY) 28 | elif opt_lower == 'adamw': 29 | optimizer = optim.AdamW(parameters, eps=config.TRAIN.OPTIMIZER.EPS, betas=config.TRAIN.OPTIMIZER.BETAS, 30 | lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY) 31 | 32 | return optimizer 33 | 34 | 35 | def set_weight_decay(model, skip_list=(), skip_keywords=()): 36 | has_decay = [] 37 | no_decay = [] 38 | 39 | for name, param in model.named_parameters(): 40 | if not param.requires_grad: 41 | continue # frozen weights 42 | if len(param.shape) == 1 or name.endswith(".bias") or (name in skip_list) or \ 43 | check_keywords_in_name(name, skip_keywords): 44 | no_decay.append(param) 45 | print(f"{name} has no weight decay") 46 | else: 47 | has_decay.append(param) 48 | return [{'params': has_decay}, 49 | {'params': no_decay, 'weight_decay': 0.0005}] 50 | 51 | 52 | def check_keywords_in_name(name, keywords=()): 53 | isin = False 54 | for keyword in keywords: 55 | if keyword in name: 56 | isin = True 57 | return isin 58 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | output/ 132 | 133 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LipsFormer 2 | 3 | By Xianbiao Qi, Jianan Wang, Yihao Chen, Yukai Shi, and Lei Zhang. 4 | 5 | This repo is the official implementation of ["LipsFormer: Introducing Lipschitz Continuity to Vision Transformers"](https://openreview.net/pdf?id=cHf1DcCwcH3). 6 | 7 | Initial commits: 8 | We release one pretrained model below; LipsFormer-Swin-Tiny reaches 82.70 top-1 accuracy on ImageNet-1K. We train the model without using any warmup. 9 | 10 | 1. Pretrained models on ImageNet-1K ([LipsFormer-Swin-T-IN1K](https://github.com/cyh1112/LipsFormer/releases/download/checkpoint/lipsformer-swin-tiny.pth)). 11 | 12 | In our paper, we implement two versions of LipsFormer: one is built on Swin and the other is based on CSwin. We do not merge these two code bases. We thank the authors of Swin and CSwin for releasing their code. 13 | In this repo, we release the code base built on Swin. You can easily adapt it to CSwin. 14 | 15 |
16 | 17 | ## Introduction 18 | 19 | **LipsFormer** introduces a Lipschitz continuous Transformer to pursue training stability both theoretically and empirically for Transformer-based models. In contrast to previous practical tricks that address training instability by learning rate warmup, layer normalization, attention formulation, and weight initialization, we show that Lipschitz continuity is a more essential property to ensure training stability. In LipsFormer, we replace unstable Transformer component modules with Lipschitz continuous counterparts: CenterNorm instead of LayerNorm, spectral initialization instead of Xavier initialization, scaled cosine similarity attention instead of dot-product attention, and weighted residual shortcut. We prove that these introduced modules are Lipschitz continuous and derive an upper bound on the Lipschitz constant of LipsFormer. 20 | 21 |
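To make two of these components concrete, here is a minimal, illustrative PyTorch sketch of CenterNorm and the weighted residual shortcut. It is not the released `models/lipsformer_swin.py` implementation; in particular, the `dim / (dim - 1)` scaling and the initial residual weight of 0.1 are assumptions made for this example.

```python
import torch
import torch.nn as nn


class CenterNorm(nn.Module):
    """Sketch of CenterNorm: mean-centering plus a learnable affine transform.

    Unlike LayerNorm, there is no division by the input-dependent standard
    deviation, so the mapping stays Lipschitz continuous.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.weight = nn.Parameter(torch.ones(dim))
        self.bias = nn.Parameter(torch.zeros(dim))

    def forward(self, x):
        # Center along the channel dimension only; the dim / (dim - 1) factor
        # (an assumption in this sketch) compensates for the mean subtraction.
        x = (x - x.mean(dim=-1, keepdim=True)) * (self.dim / (self.dim - 1))
        return x * self.weight + self.bias


class WeightedResidual(nn.Module):
    """Sketch of the weighted residual shortcut: the residual branch is scaled
    by a small learnable per-channel weight, so each block starts near identity."""

    def __init__(self, fn, dim, init_alpha=0.1):
        super().__init__()
        self.fn = fn
        self.alpha = nn.Parameter(init_alpha * torch.ones(dim))

    def forward(self, x):
        return x + self.alpha * self.fn(x)


if __name__ == "__main__":
    block = WeightedResidual(nn.Sequential(CenterNorm(96), nn.Linear(96, 96)), dim=96)
    print(block(torch.randn(2, 49, 96)).shape)  # torch.Size([2, 49, 96])
```

The other components follow the same spirit: scaled cosine similarity attention and spectral initialization likewise keep the Lipschitz constant of each sub-module bounded.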
22 | 23 | ## Getting Started 24 | For image classification, please see [get_started.md](https://github.com/IDEA-Research/LipsFormer/blob/main/get_started.md) for detailed instructions. 25 | 26 | ## Training 27 | ``` 28 | python -m torch.distributed.launch --nproc_per_node 8 --master_port 12345 main.py \ 29 | --cfg configs/swin_tiny_patch4_window7_224.yaml --data-path <imagenet-path> --batch-size 32 30 | ``` 31 | 32 |
33 | 34 | ## References 35 | ``` 36 | @misc{qi2023lipsformer, 37 | title={LipsFormer: Introducing Lipschitz Continuity to Vision Transformers}, 38 | author={Xianbiao Qi and Jianan Wang and Yihao Chen and Yukai Shi and Lei Zhang}, 39 | year={2023}, 40 | eprint={2304.09856}, 41 | archivePrefix={arXiv}, 42 | primaryClass={cs.CV} 43 | } 44 | 45 | @misc{liu2021swin, 46 | title={Swin Transformer: Hierarchical Vision Transformer using Shifted Windows}, 47 | author={Ze Liu and Yutong Lin and Yue Cao and Han Hu and Yixuan Wei and Zheng Zhang and Stephen Lin and Baining Guo}, 48 | year={2021}, 49 | eprint={2103.14030}, 50 | archivePrefix={arXiv}, 51 | primaryClass={cs.CV} 52 | } 53 | 54 | @misc{dong2022cswin, 55 | title={CSWin Transformer: A General Vision Transformer Backbone with Cross-Shaped Windows}, 56 | author={Xiaoyi Dong and Jianmin Bao and Dongdong Chen and Weiming Zhang and Nenghai Yu and Lu Yuan and Dong Chen and Baining Guo}, 57 | year={2022}, 58 | eprint={2107.00652}, 59 | archivePrefix={arXiv}, 60 | primaryClass={cs.CV} 61 | } 62 | ``` 63 | 64 | 65 | -------------------------------------------------------------------------------- /data/zipreader.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | import os 9 | import zipfile 10 | import io 11 | import numpy as np 12 | from PIL import Image 13 | from PIL import ImageFile 14 | 15 | ImageFile.LOAD_TRUNCATED_IMAGES = True 16 | 17 | 18 | def is_zip_path(img_or_path): 19 | """judge if this is a zip path""" 20 | return '.zip@' in img_or_path 21 | 22 | 23 | class ZipReader(object): 24 | """A class to read zipped files""" 25 | zip_bank = dict() 26 | 27 | def __init__(self): 28 | super(ZipReader, self).__init__() 29 | 30 | @staticmethod 31 | def get_zipfile(path): 32 | zip_bank = ZipReader.zip_bank 33 | if path not in zip_bank: 34 | zfile = zipfile.ZipFile(path, 'r') 35 | zip_bank[path] = zfile 36 | return zip_bank[path] 37 | 38 | @staticmethod 39 | def split_zip_style_path(path): 40 | pos_at = path.index('@') 41 | assert pos_at != -1, "character '@' is not found from the given path '%s'" % path 42 | 43 | zip_path = path[0: pos_at] 44 | folder_path = path[pos_at + 1:] 45 | folder_path = str.strip(folder_path, '/') 46 | return zip_path, folder_path 47 | 48 | @staticmethod 49 | def list_folder(path): 50 | zip_path, folder_path = ZipReader.split_zip_style_path(path) 51 | 52 | zfile = ZipReader.get_zipfile(zip_path) 53 | folder_list = [] 54 | for file_foler_name in zfile.namelist(): 55 | file_foler_name = str.strip(file_foler_name, '/') 56 | if file_foler_name.startswith(folder_path) and \ 57 | len(os.path.splitext(file_foler_name)[-1]) == 0 and \ 58 | file_foler_name != folder_path: 59 | if len(folder_path) == 0: 60 | folder_list.append(file_foler_name) 61 | else: 62 | folder_list.append(file_foler_name[len(folder_path) + 1:]) 63 | 64 | return folder_list 65 | 66 | @staticmethod 67 | def list_files(path, extension=None): 68 | if extension is None: 69 | extension = ['.*'] 70 | zip_path, folder_path = ZipReader.split_zip_style_path(path) 71 | 72 | zfile = ZipReader.get_zipfile(zip_path) 73 | file_lists = [] 74 | for file_foler_name in zfile.namelist(): 75 | file_foler_name = str.strip(file_foler_name, '/') 76 | if 
file_foler_name.startswith(folder_path) and \ 77 | str.lower(os.path.splitext(file_foler_name)[-1]) in extension: 78 | if len(folder_path) == 0: 79 | file_lists.append(file_foler_name) 80 | else: 81 | file_lists.append(file_foler_name[len(folder_path) + 1:]) 82 | 83 | return file_lists 84 | 85 | @staticmethod 86 | def read(path): 87 | zip_path, path_img = ZipReader.split_zip_style_path(path) 88 | zfile = ZipReader.get_zipfile(zip_path) 89 | data = zfile.read(path_img) 90 | return data 91 | 92 | @staticmethod 93 | def imread(path): 94 | zip_path, path_img = ZipReader.split_zip_style_path(path) 95 | zfile = ZipReader.get_zipfile(zip_path) 96 | data = zfile.read(path_img) 97 | try: 98 | im = Image.open(io.BytesIO(data)) 99 | except: 100 | print("ERROR IMG LOADED: ", path_img) 101 | random_img = np.random.rand(224, 224, 3) * 255 102 | im = Image.fromarray(np.uint8(random_img)) 103 | return im 104 | -------------------------------------------------------------------------------- /lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | import torch 9 | from timm.scheduler.cosine_lr import CosineLRScheduler 10 | from timm.scheduler.step_lr import StepLRScheduler 11 | from timm.scheduler.scheduler import Scheduler 12 | 13 | 14 | def build_scheduler(config, optimizer, n_iter_per_epoch): 15 | num_steps = int(config.TRAIN.EPOCHS * n_iter_per_epoch) 16 | warmup_steps = int(config.TRAIN.WARMUP_EPOCHS * n_iter_per_epoch) 17 | decay_steps = int(config.TRAIN.LR_SCHEDULER.DECAY_EPOCHS * n_iter_per_epoch) 18 | 19 | lr_scheduler = None 20 | if config.TRAIN.LR_SCHEDULER.NAME == 'cosine': 21 | lr_scheduler = CosineLRScheduler( 22 | optimizer, 23 | t_initial=num_steps, 24 | t_mul=1., 25 | lr_min=config.TRAIN.MIN_LR, 26 | warmup_lr_init=config.TRAIN.WARMUP_LR, 27 | warmup_t=warmup_steps, 28 | cycle_limit=1, 29 | t_in_epochs=False, 30 | ) 31 | elif config.TRAIN.LR_SCHEDULER.NAME == 'linear': 32 | lr_scheduler = LinearLRScheduler( 33 | optimizer, 34 | t_initial=num_steps, 35 | lr_min_rate=0.01, 36 | warmup_lr_init=config.TRAIN.WARMUP_LR, 37 | warmup_t=warmup_steps, 38 | t_in_epochs=False, 39 | ) 40 | elif config.TRAIN.LR_SCHEDULER.NAME == 'step': 41 | lr_scheduler = StepLRScheduler( 42 | optimizer, 43 | decay_t=decay_steps, 44 | decay_rate=config.TRAIN.LR_SCHEDULER.DECAY_RATE, 45 | warmup_lr_init=config.TRAIN.WARMUP_LR, 46 | warmup_t=warmup_steps, 47 | t_in_epochs=False, 48 | ) 49 | 50 | return lr_scheduler 51 | 52 | 53 | class LinearLRScheduler(Scheduler): 54 | def __init__(self, 55 | optimizer: torch.optim.Optimizer, 56 | t_initial: int, 57 | lr_min_rate: float, 58 | warmup_t=0, 59 | warmup_lr_init=0., 60 | t_in_epochs=True, 61 | noise_range_t=None, 62 | noise_pct=0.67, 63 | noise_std=1.0, 64 | noise_seed=42, 65 | initialize=True, 66 | ) -> None: 67 | super().__init__( 68 | optimizer, param_group_field="lr", 69 | noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed, 70 | initialize=initialize) 71 | 72 | self.t_initial = t_initial 73 | self.lr_min_rate = lr_min_rate 74 | self.warmup_t = warmup_t 75 | self.warmup_lr_init = warmup_lr_init 76 | self.t_in_epochs = t_in_epochs 77 | if self.warmup_t: 78 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t 
for v in self.base_values] 79 | super().update_groups(self.warmup_lr_init) 80 | else: 81 | self.warmup_steps = [1 for _ in self.base_values] 82 | 83 | def _get_lr(self, t): 84 | if t < self.warmup_t: 85 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] 86 | else: 87 | t = t - self.warmup_t 88 | total_t = self.t_initial - self.warmup_t 89 | lrs = [v - ((v - v * self.lr_min_rate) * (t / total_t)) for v in self.base_values] 90 | return lrs 91 | 92 | def get_epoch_values(self, epoch: int): 93 | if self.t_in_epochs: 94 | return self._get_lr(epoch) 95 | else: 96 | return None 97 | 98 | def get_update_values(self, num_updates: int): 99 | if not self.t_in_epochs: 100 | return self._get_lr(num_updates) 101 | else: 102 | return None 103 | -------------------------------------------------------------------------------- /get_started.md: -------------------------------------------------------------------------------- 1 | # LipsFormer-Swin for Image Classification 2 | 3 | This folder contains the implementation of the LipsFormer-Swin for image classification. 4 | 5 | ## Usage 6 | 7 | ### Install 8 | 9 | - Clone this repo: 10 | 11 | ```bash 12 | git clone https://github.com/IDEA-Research/LipsFormer.git 13 | cd LipsFormer 14 | ``` 15 | 16 | - Create a conda virtual environment and activate it: 17 | 18 | ```bash 19 | conda create -n lipsformer python=3.7 -y 20 | conda activate lipsformer 21 | ``` 22 | 23 | - Install `CUDA==10.1` with `cudnn7` following 24 | the [official installation instructions](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html) 25 | - Install `PyTorch==1.7.1` and `torchvision==0.8.2` with `CUDA==10.1`: 26 | 27 | ```bash 28 | conda install pytorch==1.7.1 torchvision==0.8.2 cudatoolkit=10.1 -c pytorch 29 | ``` 30 | 31 | - Install `timm==0.3.2`: 32 | 33 | ```bash 34 | pip install timm==0.3.2 35 | ``` 36 | 37 | - Install `Apex`: 38 | 39 | ```bash 40 | git clone https://github.com/NVIDIA/apex 41 | cd apex 42 | pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./ 43 | ``` 44 | 45 | - Install other requirements: 46 | 47 | ```bash 48 | pip install opencv-python==4.4.0.46 termcolor==1.1.0 yacs==0.1.8 49 | ``` 50 | 51 | ### Data preparation 52 | 53 | We use standard ImageNet dataset, you can download it from http://image-net.org/. We provide the following two ways to 54 | load data: 55 | 56 | - For standard folder dataset, move validation images to labeled sub-folders. The file structure should look like: 57 | ```bash 58 | $ tree data 59 | imagenet 60 | ├── train 61 | │ ├── class1 62 | │ │ ├── img1.jpeg 63 | │ │ ├── img2.jpeg 64 | │ │ └── ... 65 | │ ├── class2 66 | │ │ ├── img3.jpeg 67 | │ │ └── ... 68 | │ └── ... 69 | └── val 70 | ├── class1 71 | │ ├── img4.jpeg 72 | │ ├── img5.jpeg 73 | │ └── ... 74 | ├── class2 75 | │ ├── img6.jpeg 76 | │ └── ... 77 | └── ... 78 | 79 | ``` 80 | - To boost the slow speed when reading images from massive small files, we also support zipped ImageNet, which includes 81 | four files: 82 | - `train.zip`, `val.zip`: which store the zipped folder for train and validate splits. 83 | - `train_map.txt`, `val_map.txt`: which store the relative path in the corresponding zip file and ground truth 84 | label. 
Make sure the data folder looks like this: 85 | 86 | ```bash 87 | $ tree data 88 | data 89 | └── ImageNet-Zip 90 | ├── train_map.txt 91 | ├── train.zip 92 | ├── val_map.txt 93 | └── val.zip 94 | 95 | $ head -n 5 data/ImageNet-Zip/val_map.txt 96 | ILSVRC2012_val_00000001.JPEG 65 97 | ILSVRC2012_val_00000002.JPEG 970 98 | ILSVRC2012_val_00000003.JPEG 230 99 | ILSVRC2012_val_00000004.JPEG 809 100 | ILSVRC2012_val_00000005.JPEG 516 101 | 102 | $ head -n 5 data/ImageNet-Zip/train_map.txt 103 | n01440764/n01440764_10026.JPEG 0 104 | n01440764/n01440764_10027.JPEG 0 105 | n01440764/n01440764_10029.JPEG 0 106 | n01440764/n01440764_10040.JPEG 0 107 | n01440764/n01440764_10042.JPEG 0 108 | ``` 109 | 110 | ### Evaluation 111 | 112 | To evaluate a pre-trained `LipsFormer` on ImageNet val, run: 113 | 114 | ```bash 115 | python -m torch.distributed.launch --nproc_per_node <num-of-gpus> --master_port 12345 main.py --eval \ 116 | --cfg <config-file> --resume <checkpoint> --data-path <imagenet-path> 117 | ``` 118 | 119 | For example, to evaluate the `LipsFormer-Swin-B` with a single GPU: 120 | 121 | ```bash 122 | python -m torch.distributed.launch --nproc_per_node 1 --master_port 12345 main.py --eval \ 123 | --cfg configs/swin_base_patch4_window7_224.yaml --resume swin_base_patch4_window7_224.pth --data-path <imagenet-path> 124 | ``` 125 | 126 | ### Training from scratch 127 | 128 | To train a `LipsFormer-Swin` on ImageNet from scratch, run: 129 | 130 | ```bash 131 | python -m torch.distributed.launch --nproc_per_node <num-of-gpus> --master_port 12345 main.py \ 132 | --cfg <config-file> --data-path <imagenet-path> [--batch-size <batch-size-per-gpu> --output <output-directory> --tag <job-tag>] 133 | ``` 134 | 135 | **Notes**: 136 | 137 | - To use zipped ImageNet instead of folder dataset, add `--zip` to the parameters. 138 | - To cache the dataset in the memory instead of reading from files every time, add `--cache-mode part`, which will 139 | shard the dataset into non-overlapping pieces for different GPUs and only load the corresponding one for each GPU. 140 | - When GPU memory is not enough, you can try the following suggestions: 141 | - Use gradient accumulation by adding `--accumulation-steps <accumulation-steps>`, set an appropriate `<accumulation-steps>` value according to your need (a minimal sketch of this pattern follows these notes). 142 | - Use gradient checkpointing by adding `--use-checkpoint`, e.g., it saves about 60% memory when training `Swin-B`. 143 | Please refer to [this page](https://pytorch.org/docs/stable/checkpoint.html) for more details. 144 | - We recommend using multi-node with more GPUs for training very large models, a tutorial can be found 145 | in [this page](https://pytorch.org/tutorials/intermediate/dist_tuto.html). 146 | - To change config options in general, you can use `--opts KEY1 VALUE1 KEY2 VALUE2`, e.g., 147 | `--opts TRAIN.EPOCHS 100 TRAIN.WARMUP_EPOCHS 5` will change total epochs to 100 and warm-up epochs to 5. 148 | - For additional options, see [config](config.py) and run `python main.py --help` to get detailed messages.
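The gradient-accumulation note above can be made concrete with a short, illustrative sketch. It shows the generic accumulation pattern only; it is not the actual loop in `main.py` (which also handles distributed training, AMP, and gradient clipping), and the function and variable names here are hypothetical.

```python
import torch
import torch.nn as nn


def train_one_epoch(model, criterion, data_loader, optimizer, accumulation_steps=2):
    """Accumulate gradients over several micro-batches before each optimizer step."""
    model.train()
    optimizer.zero_grad()
    for step, (images, targets) in enumerate(data_loader):
        loss = criterion(model(images), targets)
        # Scale the loss so the summed gradients match one large-batch update.
        (loss / accumulation_steps).backward()
        if (step + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()


# Toy usage: effective batch size = 2 (micro-batch) x 2 (accumulation steps) = 4.
model = nn.Linear(8, 2)
loader = [(torch.randn(2, 8), torch.randint(0, 2, (2,))) for _ in range(4)]
train_one_epoch(model, nn.CrossEntropyLoss(), loader, torch.optim.SGD(model.parameters(), lr=0.1))
```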
149 | 150 | For example, to train `LipsFormer` with 8 GPUs on a single node for 300 epochs, run: 151 | 152 | `Swin-T`: 153 | 154 | ```bash 155 | python -m torch.distributed.launch --nproc_per_node 8 --master_port 12345 main.py \ 156 | --cfg configs/swin_tiny_patch4_window7_224.yaml --data-path <imagenet-path> --batch-size 128 157 | ``` 158 | 159 | `Swin-S`: 160 | 161 | ```bash 162 | python -m torch.distributed.launch --nproc_per_node 8 --master_port 12345 main.py \ 163 | --cfg configs/swin_small_patch4_window7_224.yaml --data-path <imagenet-path> --batch-size 128 164 | ``` 165 | 166 | `Swin-B`: 167 | 168 | ```bash 169 | python -m torch.distributed.launch --nproc_per_node 8 --master_port 12345 main.py \ 170 | --cfg configs/swin_base_patch4_window7_224.yaml --data-path <imagenet-path> --batch-size 64 \ 171 | --accumulation-steps 2 [--use-checkpoint] 172 | ``` 173 | 174 | ### Throughput 175 | 176 | To measure the throughput, run: 177 | 178 | ```bash 179 | python -m torch.distributed.launch --nproc_per_node 1 --master_port 12345 main.py \ 180 | --cfg <config-file> --data-path <imagenet-path> --batch-size 64 --throughput --amp-opt-level O0 181 | ``` 182 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | import os 9 | import torch 10 | import torch.distributed as dist 11 | 12 | try: 13 | # noinspection PyUnresolvedReferences 14 | from apex import amp 15 | except ImportError: 16 | amp = None 17 | 18 | 19 | def load_checkpoint(config, model, optimizer, lr_scheduler, logger): 20 | logger.info(f"==============> Resuming from {config.MODEL.RESUME}....................") 21 | if config.MODEL.RESUME.startswith('https'): 22 | checkpoint = torch.hub.load_state_dict_from_url( 23 | config.MODEL.RESUME, map_location='cpu', check_hash=True) 24 | else: 25 | checkpoint = torch.load(config.MODEL.RESUME, map_location='cpu') 26 | msg = model.load_state_dict(checkpoint['model'], strict=False) 27 | logger.info(msg) 28 | max_accuracy = 0.0 29 | if not config.EVAL_MODE and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint: 30 | optimizer.load_state_dict(checkpoint['optimizer']) 31 | lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) 32 | config.defrost() 33 | config.TRAIN.START_EPOCH = checkpoint['epoch'] + 1 34 | config.freeze() 35 | if 'amp' in checkpoint and config.AMP_OPT_LEVEL != "O0" and checkpoint['config'].AMP_OPT_LEVEL != "O0": 36 | amp.load_state_dict(checkpoint['amp']) 37 | logger.info(f"=> loaded successfully '{config.MODEL.RESUME}' (epoch {checkpoint['epoch']})") 38 | if 'max_accuracy' in checkpoint: 39 | max_accuracy = checkpoint['max_accuracy'] 40 | 41 | del checkpoint 42 | torch.cuda.empty_cache() 43 | return max_accuracy 44 | 45 | 46 | def load_pretrained(config, model, logger): 47 | logger.info(f"==============> Loading weight {config.MODEL.PRETRAINED} for fine-tuning......") 48 | checkpoint = torch.load(config.MODEL.PRETRAINED, map_location='cpu') 49 | state_dict = checkpoint['model'] 50 | 51 | # delete relative_position_index since we always re-init it 52 | relative_position_index_keys = [k for k in state_dict.keys() if "relative_position_index" in k] 53 | for k in relative_position_index_keys: 54 | del state_dict[k] 55 | 56 | # delete
relative_coords_table since we always re-init it 57 | relative_position_index_keys = [k for k in state_dict.keys() if "relative_coords_table" in k] 58 | for k in relative_position_index_keys: 59 | del state_dict[k] 60 | 61 | # delete attn_mask since we always re-init it 62 | attn_mask_keys = [k for k in state_dict.keys() if "attn_mask" in k] 63 | for k in attn_mask_keys: 64 | del state_dict[k] 65 | 66 | # bicubic interpolate relative_position_bias_table if not match 67 | relative_position_bias_table_keys = [k for k in state_dict.keys() if "relative_position_bias_table" in k] 68 | for k in relative_position_bias_table_keys: 69 | relative_position_bias_table_pretrained = state_dict[k] 70 | relative_position_bias_table_current = model.state_dict()[k] 71 | L1, nH1 = relative_position_bias_table_pretrained.size() 72 | L2, nH2 = relative_position_bias_table_current.size() 73 | if nH1 != nH2: 74 | logger.warning(f"Error in loading {k}, passing......") 75 | else: 76 | if L1 != L2: 77 | # bicubic interpolate relative_position_bias_table if not match 78 | S1 = int(L1 ** 0.5) 79 | S2 = int(L2 ** 0.5) 80 | relative_position_bias_table_pretrained_resized = torch.nn.functional.interpolate( 81 | relative_position_bias_table_pretrained.permute(1, 0).view(1, nH1, S1, S1), size=(S2, S2), 82 | mode='bicubic') 83 | state_dict[k] = relative_position_bias_table_pretrained_resized.view(nH2, L2).permute(1, 0) 84 | 85 | # bicubic interpolate absolute_pos_embed if not match 86 | absolute_pos_embed_keys = [k for k in state_dict.keys() if "absolute_pos_embed" in k] 87 | for k in absolute_pos_embed_keys: 88 | # dpe 89 | absolute_pos_embed_pretrained = state_dict[k] 90 | absolute_pos_embed_current = model.state_dict()[k] 91 | _, L1, C1 = absolute_pos_embed_pretrained.size() 92 | _, L2, C2 = absolute_pos_embed_current.size() 93 | if C1 != C1: 94 | logger.warning(f"Error in loading {k}, passing......") 95 | else: 96 | if L1 != L2: 97 | S1 = int(L1 ** 0.5) 98 | S2 = int(L2 ** 0.5) 99 | absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.reshape(-1, S1, S1, C1) 100 | absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.permute(0, 3, 1, 2) 101 | absolute_pos_embed_pretrained_resized = torch.nn.functional.interpolate( 102 | absolute_pos_embed_pretrained, size=(S2, S2), mode='bicubic') 103 | absolute_pos_embed_pretrained_resized = absolute_pos_embed_pretrained_resized.permute(0, 2, 3, 1) 104 | absolute_pos_embed_pretrained_resized = absolute_pos_embed_pretrained_resized.flatten(1, 2) 105 | state_dict[k] = absolute_pos_embed_pretrained_resized 106 | 107 | # check classifier, if not match, then re-init classifier to zero 108 | head_bias_pretrained = state_dict['head.bias'] 109 | Nc1 = head_bias_pretrained.shape[0] 110 | Nc2 = model.head.bias.shape[0] 111 | if (Nc1 != Nc2): 112 | if Nc1 == 21841 and Nc2 == 1000: 113 | logger.info("loading ImageNet-22K weight to ImageNet-1K ......") 114 | map22kto1k_path = f'data/map22kto1k.txt' 115 | with open(map22kto1k_path) as f: 116 | map22kto1k = f.readlines() 117 | map22kto1k = [int(id22k.strip()) for id22k in map22kto1k] 118 | state_dict['head.weight'] = state_dict['head.weight'][map22kto1k, :] 119 | state_dict['head.bias'] = state_dict['head.bias'][map22kto1k] 120 | else: 121 | torch.nn.init.constant_(model.head.bias, 0.) 122 | torch.nn.init.constant_(model.head.weight, 0.) 
123 | del state_dict['head.weight'] 124 | del state_dict['head.bias'] 125 | logger.warning(f"Error in loading classifier head, re-init classifier head to 0") 126 | 127 | msg = model.load_state_dict(state_dict, strict=False) 128 | logger.warning(msg) 129 | 130 | logger.info(f"=> loaded successfully '{config.MODEL.PRETRAINED}'") 131 | 132 | del checkpoint 133 | torch.cuda.empty_cache() 134 | 135 | def save_checkpoint(config, epoch, model, max_accuracy, optimizer, lr_scheduler, logger): 136 | save_state = {'model': model.state_dict(), 137 | 'optimizer': optimizer.state_dict(), 138 | 'lr_scheduler': lr_scheduler.state_dict(), 139 | 'max_accuracy': max_accuracy, 140 | 'epoch': epoch, 141 | 'config': config} 142 | # if config.AMP_OPT_LEVEL != "O0": 143 | # save_state['amp'] = amp.state_dict() 144 | 145 | save_path = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth') 146 | logger.info(f"{save_path} saving......") 147 | torch.save(save_state, save_path) 148 | logger.info(f"{save_path} saved !!!") 149 | 150 | 151 | def get_grad_norm(parameters, norm_type=2): 152 | if isinstance(parameters, torch.Tensor): 153 | parameters = [parameters] 154 | parameters = list(filter(lambda p: p.grad is not None, parameters)) 155 | norm_type = float(norm_type) 156 | total_norm = 0 157 | for p in parameters: 158 | param_norm = p.grad.data.norm(norm_type) 159 | total_norm += param_norm.item() ** norm_type 160 | total_norm = total_norm ** (1. / norm_type) 161 | return total_norm 162 | 163 | 164 | def auto_resume_helper(output_dir): 165 | checkpoints = os.listdir(output_dir) 166 | checkpoints = [ckpt for ckpt in checkpoints if ckpt.endswith('pth')] 167 | print(f"All checkpoints founded in {output_dir}: {checkpoints}") 168 | if len(checkpoints) > 0: 169 | latest_checkpoint = max([os.path.join(output_dir, d) for d in checkpoints], key=os.path.getmtime) 170 | print(f"The latest checkpoint founded: {latest_checkpoint}") 171 | resume_file = latest_checkpoint 172 | else: 173 | resume_file = None 174 | return resume_file 175 | 176 | 177 | def reduce_tensor(tensor): 178 | rt = tensor.clone() 179 | dist.all_reduce(rt, op=dist.ReduceOp.SUM) 180 | rt /= dist.get_world_size() 181 | return rt 182 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # LipsFormer Swin 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # --------------------------------------------------------' 7 | 8 | import os 9 | import yaml 10 | from yacs.config import CfgNode as CN 11 | 12 | _C = CN() 13 | 14 | # Base config files 15 | _C.BASE = [''] 16 | 17 | # ----------------------------------------------------------------------------- 18 | # Data settings 19 | # ----------------------------------------------------------------------------- 20 | _C.DATA = CN() 21 | # Batch size for a single GPU, could be overwritten by command line argument 22 | _C.DATA.BATCH_SIZE = 128 23 | # Path to dataset, could be overwritten by command line argument 24 | _C.DATA.DATA_PATH = '' 25 | # Dataset name 26 | _C.DATA.DATASET = 'imagenet' 27 | # Input image size 28 | _C.DATA.IMG_SIZE = 224 29 | # Interpolation to resize image (random, bilinear, bicubic) 30 | _C.DATA.INTERPOLATION = 'bicubic' 31 | # Use zipped dataset instead of folder dataset 32 | # could be overwritten by command line argument 33 | _C.DATA.ZIP_MODE = 
False 34 | # Cache Data in Memory, could be overwritten by command line argument 35 | _C.DATA.CACHE_MODE = 'part' 36 | # Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU. 37 | _C.DATA.PIN_MEMORY = True 38 | # Number of data loading threads 39 | _C.DATA.NUM_WORKERS = 8 40 | 41 | _C.DATA.TRANSFORM_TYPE = None 42 | 43 | # ----------------------------------------------------------------------------- 44 | # Model settings 45 | # ----------------------------------------------------------------------------- 46 | _C.MODEL = CN() 47 | # Model type 48 | _C.MODEL.TYPE = 'swin' 49 | # Model name 50 | _C.MODEL.NAME = 'swin_tiny_patch4_window7_224' 51 | # Checkpoint to resume, could be overwritten by command line argument 52 | _C.MODEL.RESUME = '' 53 | # Number of classes, overwritten in data preparation 54 | _C.MODEL.NUM_CLASSES = 1000 55 | # Dropout rate 56 | _C.MODEL.DROP_RATE = 0.0 57 | # Drop path rate 58 | _C.MODEL.DROP_PATH_RATE = 0.1 59 | # Label Smoothing 60 | _C.MODEL.LABEL_SMOOTHING = 0.1 61 | 62 | _C.MODEL.PRETRAINED = None 63 | 64 | 65 | # Swin Transformer parameters 66 | _C.MODEL.SWIN = CN() 67 | _C.MODEL.SWIN.PATCH_SIZE = 4 68 | _C.MODEL.SWIN.IN_CHANS = 3 69 | _C.MODEL.SWIN.EMBED_DIM = 96 70 | _C.MODEL.SWIN.DEPTHS = [2, 2, 6, 2] 71 | _C.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24] 72 | _C.MODEL.SWIN.WINDOW_SIZE = 7 73 | _C.MODEL.SWIN.MLP_RATIO = 4. 74 | _C.MODEL.SWIN.QKV_BIAS = True 75 | _C.MODEL.SWIN.QK_SCALE = None 76 | _C.MODEL.SWIN.APE = False 77 | _C.MODEL.SWIN.PATCH_NORM = True 78 | 79 | # Swin MLP parameters 80 | _C.MODEL.SWIN_MLP = CN() 81 | _C.MODEL.SWIN_MLP.PATCH_SIZE = 4 82 | _C.MODEL.SWIN_MLP.IN_CHANS = 3 83 | _C.MODEL.SWIN_MLP.EMBED_DIM = 96 84 | _C.MODEL.SWIN_MLP.DEPTHS = [2, 2, 6, 2] 85 | _C.MODEL.SWIN_MLP.NUM_HEADS = [3, 6, 12, 24] 86 | _C.MODEL.SWIN_MLP.WINDOW_SIZE = 7 87 | _C.MODEL.SWIN_MLP.MLP_RATIO = 4. 
88 | _C.MODEL.SWIN_MLP.APE = False 89 | _C.MODEL.SWIN_MLP.PATCH_NORM = True 90 | 91 | # ----------------------------------------------------------------------------- 92 | # Training settings 93 | # ----------------------------------------------------------------------------- 94 | _C.TRAIN = CN() 95 | _C.TRAIN.START_EPOCH = 0 96 | _C.TRAIN.EPOCHS = 100 97 | _C.TRAIN.WARMUP_EPOCHS = 0 98 | _C.TRAIN.WEIGHT_DECAY = 0.05 99 | _C.TRAIN.BASE_LR = 5e-4 100 | _C.TRAIN.WARMUP_LR = 5e-4 101 | _C.TRAIN.MIN_LR = 5e-6 102 | # Clip gradient norm 103 | _C.TRAIN.CLIP_GRAD = 5.0 104 | # Auto resume from latest checkpoint 105 | _C.TRAIN.AUTO_RESUME = True 106 | # Gradient accumulation steps 107 | # could be overwritten by command line argument 108 | _C.TRAIN.ACCUMULATION_STEPS = 0 109 | # Whether to use gradient checkpointing to save memory 110 | # could be overwritten by command line argument 111 | _C.TRAIN.USE_CHECKPOINT = False 112 | 113 | # LR scheduler 114 | _C.TRAIN.LR_SCHEDULER = CN() 115 | _C.TRAIN.LR_SCHEDULER.NAME = 'cosine' 116 | # Epoch interval to decay LR, used in StepLRScheduler 117 | _C.TRAIN.LR_SCHEDULER.DECAY_EPOCHS = 30 118 | # LR decay rate, used in StepLRScheduler 119 | _C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1 120 | 121 | # Optimizer 122 | _C.TRAIN.OPTIMIZER = CN() 123 | _C.TRAIN.OPTIMIZER.NAME = 'adamw' 124 | # Optimizer Epsilon 125 | _C.TRAIN.OPTIMIZER.EPS = 1e-8 126 | # Optimizer Betas 127 | #_C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999) 128 | _C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.98) 129 | # SGD momentum 130 | _C.TRAIN.OPTIMIZER.MOMENTUM = 0.9 131 | 132 | # ----------------------------------------------------------------------------- 133 | # Augmentation settings 134 | # ----------------------------------------------------------------------------- 135 | _C.AUG = CN() 136 | # Color jitter factor 137 | _C.AUG.COLOR_JITTER = 0.4 138 | # Use AutoAugment policy. "v0" or "original" 139 | _C.AUG.AUTO_AUGMENT = 'rand-m9-mstd0.5-inc1' 140 | # Random erase prob 141 | _C.AUG.REPROB = 0.25 142 | # Random erase mode 143 | _C.AUG.REMODE = 'pixel' 144 | # Random erase count 145 | _C.AUG.RECOUNT = 1 146 | # Mixup alpha, mixup enabled if > 0 147 | _C.AUG.MIXUP = 0.8 148 | # Cutmix alpha, cutmix enabled if > 0 149 | _C.AUG.CUTMIX = 1.0 150 | # Cutmix min/max ratio, overrides alpha and enables cutmix if set 151 | _C.AUG.CUTMIX_MINMAX = None 152 | # Probability of performing mixup or cutmix when either/both is enabled 153 | _C.AUG.MIXUP_PROB = 1.0 154 | # Probability of switching to cutmix when both mixup and cutmix enabled 155 | _C.AUG.MIXUP_SWITCH_PROB = 0.5 156 | # How to apply mixup/cutmix params. 
Per "batch", "pair", or "elem" 157 | _C.AUG.MIXUP_MODE = 'batch' 158 | 159 | # ----------------------------------------------------------------------------- 160 | # Testing settings 161 | # ----------------------------------------------------------------------------- 162 | _C.TEST = CN() 163 | # Whether to use center crop when testing 164 | _C.TEST.CROP = True 165 | 166 | # ----------------------------------------------------------------------------- 167 | # Misc 168 | # ----------------------------------------------------------------------------- 169 | # Mixed precision opt level, if O0, no amp is used ('O0', 'O1', 'O2') 170 | # overwritten by command line argument 171 | _C.AMP_OPT_LEVEL = '' 172 | # Path to output folder, overwritten by command line argument 173 | _C.OUTPUT = '' 174 | # Tag of experiment, overwritten by command line argument 175 | _C.TAG = 'default' 176 | # Frequency to save checkpoint 177 | _C.SAVE_FREQ = 1 178 | # Frequency to logging info 179 | _C.PRINT_FREQ = 10 180 | # Fixed random seed 181 | _C.SEED = 0 182 | # Perform evaluation only, overwritten by command line argument 183 | _C.EVAL_MODE = False 184 | # Test throughput only, overwritten by command line argument 185 | _C.THROUGHPUT_MODE = False 186 | # local rank for DistributedDataParallel, given by command line argument 187 | _C.LOCAL_RANK = 0 188 | 189 | 190 | def _update_config_from_file(config, cfg_file): 191 | config.defrost() 192 | with open(cfg_file, 'r') as f: 193 | yaml_cfg = yaml.load(f, Loader=yaml.FullLoader) 194 | 195 | for cfg in yaml_cfg.setdefault('BASE', ['']): 196 | if cfg: 197 | _update_config_from_file( 198 | config, os.path.join(os.path.dirname(cfg_file), cfg) 199 | ) 200 | print('=> merge config from {}'.format(cfg_file)) 201 | config.merge_from_file(cfg_file) 202 | config.freeze() 203 | 204 | 205 | def update_config(config, args): 206 | _update_config_from_file(config, args.cfg) 207 | 208 | config.defrost() 209 | if args.opts: 210 | config.merge_from_list(args.opts) 211 | 212 | # merge from specific arguments 213 | if args.epochs: 214 | config.TRAIN.EPOCHS = args.epochs 215 | if args.batch_size: 216 | config.DATA.BATCH_SIZE = args.batch_size 217 | if args.data_path: 218 | config.DATA.DATA_PATH = args.data_path 219 | if args.zip: 220 | config.DATA.ZIP_MODE = True 221 | if args.cache_mode: 222 | config.DATA.CACHE_MODE = args.cache_mode 223 | if args.resume: 224 | config.MODEL.RESUME = args.resume 225 | if args.accumulation_steps: 226 | config.TRAIN.ACCUMULATION_STEPS = args.accumulation_steps 227 | if args.use_checkpoint: 228 | config.TRAIN.USE_CHECKPOINT = True 229 | if args.amp_opt_level: 230 | config.AMP_OPT_LEVEL = args.amp_opt_level 231 | if args.output: 232 | config.OUTPUT = args.output 233 | if args.tag: 234 | config.TAG = args.tag 235 | if args.eval: 236 | config.EVAL_MODE = True 237 | if args.throughput: 238 | config.THROUGHPUT_MODE = True 239 | 240 | # set local rank for distributed training 241 | config.LOCAL_RANK = args.local_rank 242 | 243 | # output folder 244 | config.OUTPUT = os.path.join(config.OUTPUT, config.MODEL.NAME, config.TAG) 245 | 246 | config.freeze() 247 | 248 | 249 | def get_config(args): 250 | """Get a yacs CfgNode object with default values.""" 251 | # Return a clone so that the defaults will not be altered 252 | # This is for the "local variable" use pattern 253 | config = _C.clone() 254 | update_config(config, args) 255 | 256 | return config 257 | -------------------------------------------------------------------------------- /data/cached_image_folder.py: 
-------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | import io 9 | import os 10 | import time 11 | import torch.distributed as dist 12 | import torch.utils.data as data 13 | from PIL import Image 14 | 15 | from .zipreader import is_zip_path, ZipReader 16 | 17 | 18 | def has_file_allowed_extension(filename, extensions): 19 | """Checks if a file is an allowed extension. 20 | Args: 21 | filename (string): path to a file 22 | Returns: 23 | bool: True if the filename ends with a known image extension 24 | """ 25 | filename_lower = filename.lower() 26 | return any(filename_lower.endswith(ext) for ext in extensions) 27 | 28 | 29 | def find_classes(dir): 30 | classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))] 31 | classes.sort() 32 | class_to_idx = {classes[i]: i for i in range(len(classes))} 33 | return classes, class_to_idx 34 | 35 | 36 | def make_dataset(dir, class_to_idx, extensions): 37 | images = [] 38 | dir = os.path.expanduser(dir) 39 | for target in sorted(os.listdir(dir)): 40 | d = os.path.join(dir, target) 41 | if not os.path.isdir(d): 42 | continue 43 | 44 | for root, _, fnames in sorted(os.walk(d)): 45 | for fname in sorted(fnames): 46 | if has_file_allowed_extension(fname, extensions): 47 | path = os.path.join(root, fname) 48 | item = (path, class_to_idx[target]) 49 | images.append(item) 50 | 51 | return images 52 | 53 | 54 | def make_dataset_with_ann(ann_file, img_prefix, extensions): 55 | images = [] 56 | with open(ann_file, "r") as f: 57 | contents = f.readlines() 58 | for line_str in contents: 59 | path_contents = [c for c in line_str.split('\t')] 60 | im_file_name = path_contents[0] 61 | class_index = int(path_contents[1]) 62 | 63 | assert str.lower(os.path.splitext(im_file_name)[-1]) in extensions 64 | item = (os.path.join(img_prefix, im_file_name), class_index) 65 | 66 | images.append(item) 67 | 68 | return images 69 | 70 | 71 | class DatasetFolder(data.Dataset): 72 | """A generic data loader where the samples are arranged in this way: :: 73 | root/class_x/xxx.ext 74 | root/class_x/xxy.ext 75 | root/class_x/xxz.ext 76 | root/class_y/123.ext 77 | root/class_y/nsdf3.ext 78 | root/class_y/asd932_.ext 79 | Args: 80 | root (string): Root directory path. 81 | loader (callable): A function to load a sample given its path. 82 | extensions (list[string]): A list of allowed extensions. 83 | transform (callable, optional): A function/transform that takes in 84 | a sample and returns a transformed version. 85 | E.g, ``transforms.RandomCrop`` for images. 86 | target_transform (callable, optional): A function/transform that takes 87 | in the target and transforms it. 
88 | Attributes: 89 | samples (list): List of (sample path, class_index) tuples 90 | """ 91 | 92 | def __init__(self, root, loader, extensions, ann_file='', img_prefix='', transform=None, target_transform=None, 93 | cache_mode="no"): 94 | # image folder mode 95 | if ann_file == '': 96 | _, class_to_idx = find_classes(root) 97 | samples = make_dataset(root, class_to_idx, extensions) 98 | # zip mode 99 | else: 100 | samples = make_dataset_with_ann(os.path.join(root, ann_file), 101 | os.path.join(root, img_prefix), 102 | extensions) 103 | 104 | if len(samples) == 0: 105 | raise (RuntimeError("Found 0 files in subfolders of: " + root + "\n" + 106 | "Supported extensions are: " + ",".join(extensions))) 107 | 108 | self.root = root 109 | self.loader = loader 110 | self.extensions = extensions 111 | 112 | self.samples = samples 113 | self.labels = [y_1k for _, y_1k in samples] 114 | self.classes = list(set(self.labels)) 115 | 116 | self.transform = transform 117 | self.target_transform = target_transform 118 | 119 | self.cache_mode = cache_mode 120 | if self.cache_mode != "no": 121 | self.init_cache() 122 | 123 | def init_cache(self): 124 | assert self.cache_mode in ["part", "full"] 125 | n_sample = len(self.samples) 126 | global_rank = dist.get_rank() 127 | world_size = dist.get_world_size() 128 | 129 | samples_bytes = [None for _ in range(n_sample)] 130 | start_time = time.time() 131 | for index in range(n_sample): 132 | if index % (n_sample // 10) == 0: 133 | t = time.time() - start_time 134 | print(f'global_rank {dist.get_rank()} cached {index}/{n_sample} takes {t:.2f}s per block') 135 | start_time = time.time() 136 | path, target = self.samples[index] 137 | if self.cache_mode == "full": 138 | samples_bytes[index] = (ZipReader.read(path), target) 139 | elif self.cache_mode == "part" and index % world_size == global_rank: 140 | samples_bytes[index] = (ZipReader.read(path), target) 141 | else: 142 | samples_bytes[index] = (path, target) 143 | self.samples = samples_bytes 144 | 145 | def __getitem__(self, index): 146 | """ 147 | Args: 148 | index (int): Index 149 | Returns: 150 | tuple: (sample, target) where target is class_index of the target class. 
151 | """ 152 | path, target = self.samples[index] 153 | sample = self.loader(path) 154 | if self.transform is not None: 155 | sample = self.transform(sample) 156 | if self.target_transform is not None: 157 | target = self.target_transform(target) 158 | 159 | return sample, target 160 | 161 | def __len__(self): 162 | return len(self.samples) 163 | 164 | def __repr__(self): 165 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' 166 | fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) 167 | fmt_str += ' Root Location: {}\n'.format(self.root) 168 | tmp = ' Transforms (if any): ' 169 | fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 170 | tmp = ' Target Transforms (if any): ' 171 | fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 172 | return fmt_str 173 | 174 | 175 | IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif'] 176 | 177 | 178 | def pil_loader(path): 179 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 180 | if isinstance(path, bytes): 181 | img = Image.open(io.BytesIO(path)) 182 | elif is_zip_path(path): 183 | data = ZipReader.read(path) 184 | img = Image.open(io.BytesIO(data)) 185 | else: 186 | with open(path, 'rb') as f: 187 | img = Image.open(f) 188 | return img.convert('RGB') 189 | 190 | 191 | def accimage_loader(path): 192 | import accimage 193 | try: 194 | return accimage.Image(path) 195 | except IOError: 196 | # Potentially a decoding problem, fall back to PIL.Image 197 | return pil_loader(path) 198 | 199 | 200 | def default_img_loader(path): 201 | from torchvision import get_image_backend 202 | if get_image_backend() == 'accimage': 203 | return accimage_loader(path) 204 | else: 205 | return pil_loader(path) 206 | 207 | 208 | class CachedImageFolder(DatasetFolder): 209 | """A generic data loader where the images are arranged in this way: :: 210 | root/dog/xxx.png 211 | root/dog/xxy.png 212 | root/dog/xxz.png 213 | root/cat/123.png 214 | root/cat/nsdf3.png 215 | root/cat/asd932_.png 216 | Args: 217 | root (string): Root directory path. 218 | transform (callable, optional): A function/transform that takes in an PIL image 219 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 220 | target_transform (callable, optional): A function/transform that takes in the 221 | target and transforms it. 222 | loader (callable, optional): A function to load an image given its path. 223 | Attributes: 224 | imgs (list): List of (image path, class_index) tuples 225 | """ 226 | 227 | def __init__(self, root, ann_file='', img_prefix='', transform=None, target_transform=None, 228 | loader=default_img_loader, cache_mode="no"): 229 | super(CachedImageFolder, self).__init__(root, loader, IMG_EXTENSIONS, 230 | ann_file=ann_file, img_prefix=img_prefix, 231 | transform=transform, target_transform=target_transform, 232 | cache_mode=cache_mode) 233 | self.imgs = self.samples 234 | 235 | def __getitem__(self, index): 236 | """ 237 | Args: 238 | index (int): Index 239 | Returns: 240 | tuple: (image, target) where target is class_index of the target class. 
241 | """ 242 | path, target = self.samples[index] 243 | image = self.loader(path) 244 | if self.transform is not None: 245 | img = self.transform(image) 246 | else: 247 | img = image 248 | if self.target_transform is not None: 249 | target = self.target_transform(target) 250 | 251 | return img, target 252 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2023 IDEA 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | 203 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | import os 9 | import time 10 | import argparse 11 | import datetime 12 | import numpy as np 13 | 14 | import torch 15 | import torch.backends.cudnn as cudnn 16 | import torch.distributed as dist 17 | 18 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy 19 | from timm.utils import accuracy, AverageMeter 20 | 21 | from config import get_config 22 | from models import build_model 23 | from data import build_loader, build_22k_loader 24 | from lr_scheduler import build_scheduler 25 | from optimizer import build_optimizer 26 | from logger import create_logger 27 | from utils import load_checkpoint, save_checkpoint, load_pretrained, get_grad_norm, auto_resume_helper, reduce_tensor 28 | 29 | import bai 30 | 31 | # try: 32 | # # noinspection PyUnresolvedReferences 33 | # from apex import amp 34 | # except ImportError: 35 | # amp = None 36 | from torch.cuda import amp 37 | from torch.cuda.amp import autocast as autocast 38 | 39 | 40 | def parse_option(): 41 | parser = argparse.ArgumentParser('Swin Transformer training and evaluation script', add_help=False) 42 | parser.add_argument('--cfg', type=str, required=True, metavar="FILE", help='path to config file', ) 43 | parser.add_argument( 44 | "--opts", 45 | help="Modify config options by adding 'KEY VALUE' pairs. ", 46 | default=None, 47 | nargs='+', 48 | ) 49 | 50 | # easy config modification 51 | parser.add_argument('--batch-size', type=int, help="batch size for single GPU") 52 | parser.add_argument('--data-path', type=str, help='path to dataset') 53 | parser.add_argument('--zip', action='store_true', help='use zipped dataset instead of folder dataset') 54 | parser.add_argument('--cache-mode', type=str, default='part', choices=['no', 'full', 'part'], 55 | help='no: no cache, ' 56 | 'full: cache all data, ' 57 | 'part: sharding the dataset into nonoverlapping pieces and only cache one piece') 58 | parser.add_argument('--resume', help='resume from checkpoint') 59 | parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps") 60 | parser.add_argument('--use-checkpoint', action='store_true', 61 | help="whether to use gradient checkpointing to save memory") 62 | parser.add_argument('--amp-opt-level', type=str, default='O1', choices=['O0', 'O1', 'O2'], 63 | help='mixed precision opt level, if O0, no amp is used') 64 | parser.add_argument('--output', default='output', type=str, metavar='PATH', 65 | help='root of output folder, the full path is // (default: output)') 66 | parser.add_argument('--tag', help='tag of experiment') 67 | parser.add_argument('--eval', action='store_true', help='Perform evaluation only') 68 | parser.add_argument('--epochs', default=100, type=int, help='Perform evaluation only') 69 | parser.add_argument('--throughput', action='store_true', help='Test throughput only') 70 | 71 | # distributed training 72 | parser.add_argument("--local_rank", type=int, required=True, help='local rank for DistributedDataParallel') 73 | 74 | args, unparsed = parser.parse_known_args() 75 | 76 | config = get_config(args) 77 | 78 | return args, 
config 79 | 80 | 81 | def main(config): 82 | dataset_val = None 83 | data_loader_val = None 84 | 85 | dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn = build_loader(config) 86 | 87 | # dataset_train, data_loader_train, mixup_fn = build_22k_loader(config) 88 | 89 | 90 | logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}") 91 | model = build_model(config) 92 | 93 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 94 | logger.info(f"number of params: {n_parameters}") 95 | 96 | model.cuda() 97 | logger.info(str(model)) 98 | 99 | optimizer = build_optimizer(config, model) 100 | if config.AMP_OPT_LEVEL != "O0": 101 | pass 102 | # model, optimizer = amp.initialize(model, optimizer, opt_level=config.AMP_OPT_LEVEL) 103 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config.LOCAL_RANK], broadcast_buffers=False) 104 | model_without_ddp = model.module 105 | 106 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 107 | logger.info(f"number of params: {n_parameters}") 108 | if hasattr(model_without_ddp, 'flops'): 109 | flops = model_without_ddp.flops() 110 | logger.info(f"number of GFLOPs: {flops / 1e9}") 111 | 112 | lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train)) 113 | 114 | if config.AUG.MIXUP > 0.: 115 | # smoothing is handled with mixup label transform 116 | criterion = SoftTargetCrossEntropy() 117 | elif config.MODEL.LABEL_SMOOTHING > 0.: 118 | criterion = LabelSmoothingCrossEntropy(smoothing=config.MODEL.LABEL_SMOOTHING) 119 | else: 120 | criterion = torch.nn.CrossEntropyLoss() 121 | 122 | max_accuracy = 0.0 123 | 124 | if config.TRAIN.AUTO_RESUME: 125 | resume_file = auto_resume_helper(config.OUTPUT) 126 | if resume_file: 127 | if config.MODEL.RESUME: 128 | logger.warning(f"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}") 129 | config.defrost() 130 | config.MODEL.RESUME = resume_file 131 | config.freeze() 132 | logger.info(f'auto resuming from {resume_file}') 133 | else: 134 | logger.info(f'no checkpoint found in {config.OUTPUT}, ignoring auto resume') 135 | 136 | if config.MODEL.RESUME: 137 | max_accuracy = load_checkpoint(config, model_without_ddp, optimizer, lr_scheduler, logger) 138 | 139 | if data_loader_val: 140 | acc1, acc5, loss = validate(config, data_loader_val, model) 141 | logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%") 142 | if config.EVAL_MODE: 143 | return 144 | 145 | if config.MODEL.PRETRAINED and (not config.MODEL.RESUME): 146 | load_pretrained(config, model_without_ddp, logger) 147 | 148 | 149 | if config.THROUGHPUT_MODE: 150 | throughput(data_loader_val, model, logger) 151 | return 152 | 153 | logger.info("Start training") 154 | start_time = time.time() 155 | for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS): 156 | data_loader_train.sampler.set_epoch(epoch) 157 | 158 | train_one_epoch(config, model, criterion, data_loader_train, optimizer, epoch, mixup_fn, lr_scheduler) 159 | if dist.get_rank() == 0 and (epoch % config.SAVE_FREQ == 0 or epoch == (config.TRAIN.EPOCHS - 1)): 160 | save_checkpoint(config, epoch, model_without_ddp, max_accuracy, optimizer, lr_scheduler, logger) 161 | 162 | if data_loader_val: 163 | acc1, acc5, loss = validate(config, data_loader_val, model, epoch) 164 | logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%") 165 | max_accuracy = max(max_accuracy, acc1) 166 | logger.info(f'Max 
accuracy: {max_accuracy:.2f}%') 167 | 168 | total_time = time.time() - start_time 169 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 170 | logger.info('Training time {}'.format(total_time_str)) 171 | 172 | 173 | def train_one_epoch(config, model, criterion, data_loader, optimizer, epoch, mixup_fn, lr_scheduler): 174 | model.train() 175 | optimizer.zero_grad() 176 | 177 | num_steps = len(data_loader) 178 | batch_time = AverageMeter() 179 | loss_meter = AverageMeter() 180 | norm_meter = AverageMeter() 181 | 182 | # torch.amp scaler 183 | if config.AMP_OPT_LEVEL != "O0": 184 | scaler = torch.cuda.amp.GradScaler() 185 | 186 | start = time.time() 187 | end = time.time() 188 | for idx, (samples, targets) in enumerate(data_loader): 189 | samples = samples.cuda(non_blocking=True) 190 | targets = targets.cuda(non_blocking=True) 191 | 192 | if mixup_fn is not None: 193 | samples, targets = mixup_fn(samples, targets) 194 | 195 | with autocast(): 196 | outputs = model(samples) 197 | loss = criterion(outputs, targets) 198 | 199 | if config.TRAIN.ACCUMULATION_STEPS > 1: 200 | #loss = criterion(outputs, targets) 201 | loss = loss / config.TRAIN.ACCUMULATION_STEPS 202 | if config.AMP_OPT_LEVEL != "O0": 203 | #with amp.scale_loss(loss, optimizer) as scaled_loss: 204 | # scaled_loss.backward() 205 | scaler.scale(loss).backward() 206 | if config.TRAIN.CLIP_GRAD: 207 | #grad_norm = torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), config.TRAIN.CLIP_GRAD) 208 | 209 | # Unscales the gradients of optimizer's assigned params in-place 210 | scaler.unscale_(optimizer) 211 | grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.TRAIN.CLIP_GRAD) 212 | else: 213 | #grad_norm = get_grad_norm(amp.master_params(optimizer)) 214 | grad_norm = get_grad_norm(model.parameters()) 215 | else: 216 | loss.backward() 217 | if config.TRAIN.CLIP_GRAD: 218 | grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.TRAIN.CLIP_GRAD) 219 | else: 220 | grad_norm = get_grad_norm(model.parameters()) 221 | if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0: 222 | #optimizer.step() 223 | scaler.step(optimizer) 224 | scaler.update() 225 | optimizer.zero_grad() 226 | lr_scheduler.step_update(epoch * num_steps + idx) 227 | else: 228 | #loss = criterion(outputs, targets) 229 | optimizer.zero_grad() 230 | if config.AMP_OPT_LEVEL != "O0": 231 | #with amp.scale_loss(loss, optimizer) as scaled_loss: 232 | # scaled_loss.backward() 233 | scaler.scale(loss).backward() 234 | if config.TRAIN.CLIP_GRAD: 235 | #grad_norm = torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), config.TRAIN.CLIP_GRAD) 236 | 237 | # Unscales the gradients of optimizer's assigned params in-place 238 | scaler.unscale_(optimizer) 239 | grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.TRAIN.CLIP_GRAD) 240 | else: 241 | #grad_norm = get_grad_norm(amp.master_params(optimizer)) 242 | grad_norm = get_grad_norm(model.parameters()) 243 | else: 244 | loss.backward() 245 | if config.TRAIN.CLIP_GRAD: 246 | grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.TRAIN.CLIP_GRAD) 247 | else: 248 | grad_norm = get_grad_norm(model.parameters()) 249 | #optimizer.step() 250 | scaler.step(optimizer) 251 | scaler.update() 252 | lr_scheduler.step_update(epoch * num_steps + idx) 253 | 254 | torch.cuda.synchronize() 255 | 256 | loss_meter.update(loss.item(), targets.size(0)) 257 | norm_meter.update(grad_norm) 258 | batch_time.update(time.time() - end) 259 | end = time.time() 260 | 261 | if idx % 
config.PRINT_FREQ == 0: 262 | lr = optimizer.param_groups[0]['lr'] 263 | memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0) 264 | etas = batch_time.avg * (num_steps - idx) 265 | logger.info( 266 | f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t' 267 | f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t' 268 | f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t' 269 | f'loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t' 270 | f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\t' 271 | f'mem {memory_used:.0f}MB') 272 | epoch_time = time.time() - start 273 | 274 | if dist.get_rank() == 0: 275 | bai.text(f"{config.TAG} \n epoch: {epoch} loss: {loss_meter.avg:.4f}, grad norm: {norm_meter.avg:.4f} ") 276 | 277 | logger.info(f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}") 278 | 279 | @torch.no_grad() 280 | def validate(config, data_loader, model, epoch=0): 281 | criterion = torch.nn.CrossEntropyLoss() 282 | model.eval() 283 | 284 | batch_time = AverageMeter() 285 | loss_meter = AverageMeter() 286 | acc1_meter = AverageMeter() 287 | acc5_meter = AverageMeter() 288 | 289 | end = time.time() 290 | for idx, (images, target) in enumerate(data_loader): 291 | images = images.cuda(non_blocking=True) 292 | target = target.cuda(non_blocking=True) 293 | 294 | # compute output 295 | output = model(images) 296 | 297 | # measure accuracy and record loss 298 | loss = criterion(output, target) 299 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 300 | 301 | acc1 = reduce_tensor(acc1) 302 | acc5 = reduce_tensor(acc5) 303 | loss = reduce_tensor(loss) 304 | 305 | loss_meter.update(loss.item(), target.size(0)) 306 | acc1_meter.update(acc1.item(), target.size(0)) 307 | acc5_meter.update(acc5.item(), target.size(0)) 308 | 309 | # measure elapsed time 310 | batch_time.update(time.time() - end) 311 | end = time.time() 312 | 313 | if idx % config.PRINT_FREQ == 0: 314 | memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0) 315 | logger.info( 316 | f'Test: [{idx}/{len(data_loader)}]\t' 317 | f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 318 | f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t' 319 | f'Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\t' 320 | f'Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})\t' 321 | f'Mem {memory_used:.0f}MB') 322 | logger.info(f' * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}') 323 | 324 | if epoch % 10 == 0 and dist.get_rank() == 0: 325 | bai.text(f"{config.TAG} \n epoch: {epoch} * Acc@1 {acc1_meter.avg:.3f}") 326 | 327 | return acc1_meter.avg, acc5_meter.avg, loss_meter.avg 328 | 329 | 330 | @torch.no_grad() 331 | def throughput(data_loader, model, logger): 332 | model.eval() 333 | 334 | for idx, (images, _) in enumerate(data_loader): 335 | images = images.cuda(non_blocking=True) 336 | batch_size = images.shape[0] 337 | for i in range(50): 338 | model(images) 339 | torch.cuda.synchronize() 340 | logger.info(f"throughput averaged with 30 times") 341 | tic1 = time.time() 342 | for i in range(30): 343 | model(images) 344 | torch.cuda.synchronize() 345 | tic2 = time.time() 346 | logger.info(f"batch_size {batch_size} throughput {30 * batch_size / (tic2 - tic1)}") 347 | return 348 | 349 | 350 | if __name__ == '__main__': 351 | _, config = parse_option() 352 | 353 | if config.AMP_OPT_LEVEL != "O0": 354 | assert amp is not None, "amp not installed!" 
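    # Note on the learning-rate setup a few lines below: BASE_LR, WARMUP_LR and MIN_LR are
    # linearly rescaled by the total batch size (per-GPU batch size * world size) relative
    # to a reference batch of 512, and additionally by the number of gradient-accumulation
    # steps when accumulation is enabled. Illustrative arithmetic with hypothetical values:
    # a per-GPU batch of 128 on 8 GPUs with no accumulation gives a factor of
    # 128 * 8 / 512 = 2, so the configured BASE_LR is doubled before the optimizer and
    # scheduler are built in main().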
355 | 356 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 357 | rank = int(os.environ["RANK"]) 358 | world_size = int(os.environ['WORLD_SIZE']) 359 | print(f"RANK and WORLD_SIZE in environ: {rank}/{world_size}") 360 | else: 361 | rank = -1 362 | world_size = -1 363 | 364 | torch.cuda.set_device(config.LOCAL_RANK) 365 | torch.distributed.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank) 366 | torch.distributed.barrier() 367 | 368 | seed = config.SEED + dist.get_rank() 369 | torch.manual_seed(seed) 370 | np.random.seed(seed) 371 | cudnn.benchmark = True 372 | 373 | # linear scale the learning rate according to total batch size, may not be optimal 374 | linear_scaled_lr = config.TRAIN.BASE_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0 375 | linear_scaled_warmup_lr = config.TRAIN.WARMUP_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0 376 | linear_scaled_min_lr = config.TRAIN.MIN_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0 377 | # gradient accumulation also need to scale the learning rate 378 | if config.TRAIN.ACCUMULATION_STEPS > 1: 379 | linear_scaled_lr = linear_scaled_lr * config.TRAIN.ACCUMULATION_STEPS 380 | linear_scaled_warmup_lr = linear_scaled_warmup_lr * config.TRAIN.ACCUMULATION_STEPS 381 | linear_scaled_min_lr = linear_scaled_min_lr * config.TRAIN.ACCUMULATION_STEPS 382 | config.defrost() 383 | config.TRAIN.BASE_LR = linear_scaled_lr 384 | config.TRAIN.WARMUP_LR = linear_scaled_warmup_lr 385 | config.TRAIN.MIN_LR = linear_scaled_min_lr 386 | config.freeze() 387 | 388 | os.makedirs(config.OUTPUT, exist_ok=True) 389 | logger = create_logger(output_dir=config.OUTPUT, dist_rank=dist.get_rank(), name=f"{config.MODEL.NAME}") 390 | 391 | if dist.get_rank() == 0: 392 | path = os.path.join(config.OUTPUT, "config.json") 393 | with open(path, "w") as f: 394 | f.write(config.dump()) 395 | logger.info(f"Full config saved to {path}") 396 | 397 | # print config 398 | logger.info(config.dump()) 399 | 400 | main(config) 401 | -------------------------------------------------------------------------------- /data/build.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | import os 9 | import torch 10 | import numpy as np 11 | import torch.distributed as dist 12 | from torchvision import datasets, transforms 13 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 14 | from timm.data import Mixup 15 | from timm.data import create_transform 16 | from timm.data.transforms import _pil_interp 17 | 18 | from .cached_image_folder import CachedImageFolder 19 | from .samplers import SubsetRandomSampler 20 | 21 | from PIL import Image 22 | import random 23 | 24 | img_list = ["/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n02850732/n02850732_2993.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n01620735/n01620735_1406.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n10235024/n10235024_13754.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n03221351/n03221351_4401.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n03171228/n03171228_397.JPEG", 
"/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n02119477/n02119477_1663.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n04963307/n04963307_2301.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n10746931/n10746931_8057.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n07931001/n07931001_6055.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n03046802/n03046802_1837.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n03246933/n03246933_3240.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n03397266/n03397266_5192.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n10559288/n10559288_6980.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n07723559/n07723559_4116.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n03242506/n03242506_1505.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n07938149/n07938149_6354.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n03914438/n03914438_2967.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n02692877/n02692877_4397.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n04371563/n04371563_10681.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n03012897/n03012897_6751.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n04474035/n04474035_16466.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n03358380/n03358380_15087.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n02063224/n02063224_1716.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n04477548/n04477548_4044.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n03956922/n03956922_29989.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n03026907/n03026907_3732.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n11968704/n11968704_2797.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n03683708/n03683708_878.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n03070059/n03070059_6553.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n09475179/n09475179_24992.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n12154773/n12154773_6166.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n04067818/n04067818_9697.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n12277578/n12277578_9559.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n11924445/n11924445_21.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n02867715/n02867715_2009.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n09895561/n09895561_3625.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n03503233/n03503233_6198.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n12749049/n12749049_5193.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n11552133/n11552133_3360.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n11815918/n11815918_3130.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n00466273/n00466273_6177.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n11879722/n11879722_12028.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n07722485/n07722485_9567.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n07917272/n07917272_12.JPEG", 
"/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n01636352/n01636352_2089.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n04293119/n04293119_1043.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n10215623/n10215623_8286.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n12273114/n12273114_8238.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n03965456/n03965456_36579.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n02942699/n02942699_2539.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n04008634/n04008634_13730.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n02535258/n02535258_3381.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n02639605/n02639605_2470.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n12174311/n12174311_27887.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n02233943/n02233943_10227.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n03235042/n03235042_5201.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n03978966/n03978966_5819.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n10055410/n10055410_56807.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n02987492/n02987492_19243.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n02120079/n02120079_850.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n01665541/n01665541_15659.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n00441073/n00441073_8055.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n02933649/n02933649_3546.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n02398521/n02398521_2232.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n04590021/n04590021_6496.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n09901502/n09901502_4343.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n10120330/n10120330_3443.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n04354182/n04354182_419.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n03308152/n03308152_18425.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n04467307/n04467307_111.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n04312654/n04312654_984.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n03782006/n03782006_904.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n12822955/n12822955_7952.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n02514041/n02514041_7676.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n12282737/n12282737_3800.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n04045397/n04045397_2680.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n01730563/n01730563_731.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n12886600/n12886600_290.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n02095050/n02095050_1395.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n09838370/n09838370_10936.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n09335809/n09335809_3543.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n10369095/n10369095_3450.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n03173387/n03173387_1289.JPEG", 
"/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n03448956/n03448956_5799.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n01756291/n01756291_7691.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n01494041/n01494041_1193.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n12704343/n12704343_7244.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n03131574/n03131574_1447.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n10686885/n10686885_7054.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n02379630/n02379630_2603.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n03356858/n03356858_35752.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n02166229/n02166229_2974.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n04028764/n04028764_13572.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n03236217/n03236217_2899.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n11707827/n11707827_8522.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n03488438/n03488438_8538.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n07696403/n07696403_5887.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n03773835/n03773835_6601.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n10628644/n10628644_24240.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n15091846/n15091846_20229.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n03168107/n03168107_11756.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n04161981/n04161981_12086.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n09691729/n09691729_11429.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n02274024/n02274024_1911.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n03349599/n03349599_7829.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n07685730/n07685730_5140.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n09606527/n09606527_11698.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n02565324/n02565324_7513.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n06273986/n06273986_1617.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n02516188/n02516188_6015.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n10663315/n10663315_84228.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n02776205/n02776205_20359.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n04457767/n04457767_2794.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n03294833/n03294833_13951.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n12290748/n12290748_9816.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n11872146/n11872146_2301.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n03650551/n03650551_14330.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n02863536/n02863536_4946.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n07897975/n07897975_468.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n03176594/n03176594_5221.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n03041632/n03041632_1585.JPEG", 
"/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n03927539/n03927539_3213.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n10205231/n10205231_12889.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n03684143/n03684143_9146.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n03121431/n03121431_8120.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n02966687/n02966687_366.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n04258438/n04258438_1540.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n07694659/n07694659_2709.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n03819448/n03819448_7383.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n08596076/n08596076_4388.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n04328946/n04328946_6914.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n13209808/n13209808_4896.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n04097866/n04097866_1798.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n03488887/n03488887_7862.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n03607923/n03607923_4351.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n09326662/n09326662_57025.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n11794519/n11794519_2950.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n01847089/n01847089_845.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n03871724/n03871724_6918.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n12089320/n12089320_3691.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n12340755/n12340755_12550.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n11733548/n11733548_919.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n09752023/n09752023_3639.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n04173907/n04173907_9518.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n03489162/n03489162_2605.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n11915214/n11915214_13218.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n04230808/n04230808_5712.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n09688804/n09688804_8222.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n09303528/n09303528_4109.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n09899671/n09899671_8994.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n03766322/n03766322_9922.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n03735963/n03735963_895.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n02028900/n02028900_10457.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n07825972/n07825972_20127.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n11823436/n11823436_5492.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n04325704/n04325704_4446.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n07852614/n07852614_12381.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n03291819/n03291819_4406.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n02820556/n02820556_1251.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n09824135/n09824135_20248.JPEG", 
"/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n01538630/n01538630_4997.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n02086646/n02086646_1757.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n07928367/n07928367_3911.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n02588286/n02588286_5861.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n03201776/n03201776_49961.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n03238586/n03238586_4664.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n09384106/n09384106_836.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n03522634/n03522634_10351.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n12204032/n12204032_944.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n07693048/n07693048_7247.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n11939491/n11939491_2793.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n00447540/n00447540_18292.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n04104770/n04104770_11746.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n10634849/n10634849_3228.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n12018271/n12018271_1114.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n02794474/n02794474_4249.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n02607201/n02607201_7531.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n04607869/n04607869_5381.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n04005630/n04005630_27217.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n02180427/n02180427_1380.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n11974888/n11974888_1542.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n12276477/n12276477_2656.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n04557648/n04557648_4273.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n11753700/n11753700_17593.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n11937360/n11937360_186.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n07891726/n07891726_29350.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n03560430/n03560430_19067.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n11925898/n11925898_5267.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n07804543/n07804543_3998.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n07727048/n07727048_4931.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n03490884/n03490884_3506.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n07842605/n07842605_6852.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n09410224/n09410224_6048.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n03415252/n03415252_96.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n02588286/n02588286_9216.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n04398951/n04398951_5196.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n09836160/n09836160_8048.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n04338963/n04338963_5267.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n07732168/n07732168_3666.JPEG", 
"/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n07609840/n07609840_5111.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n04209613/n04209613_13205.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n12320010/n12320010_4231.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n09715427/n09715427_5520.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n09764201/n09764201_572.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n12951835/n12951835_6002.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n02094258/n02094258_817.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n09742315/n09742315_2444.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n03326795/n03326795_4299.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n07609632/n07609632_16687.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n02842809/n02842809_4163.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n04525305/n04525305_1580.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n00463543/n00463543_7946.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n04284572/n04284572_2500.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n01950731/n01950731_10364.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n13128582/n13128582_3534.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n03696301/n03696301_11386.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n11903671/n11903671_27291.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n11723770/n11723770_9131.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n03587205/n03587205_8466.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n02816768/n02816768_2802.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n03391770/n03391770_15492.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n04289576/n04289576_1682.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n00449054/n00449054_2120.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n03488438/n03488438_4214.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n01664492/n01664492_11610.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n00475273/n00475273_1660.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n04185529/n04185529_5241.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n09337253/n09337253_6577.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n04056180/n04056180_906.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n03279153/n03279153_2136.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n12306089/n12306089_2006.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n02475078/n02475078_9503.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n01880716/n01880716_472.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n10147262/n10147262_4054.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n03016953/n03016953_3504.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n03994614/n03994614_6041.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n03390786/n03390786_761.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n04186268/n04186268_7177.JPEG", 
"/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n02756977/n02756977_4583.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n11939491/n11939491_5343.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n07907161/n07907161_1421.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n09991867/n09991867_97625.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n04968139/n04968139_10250.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n07597145/n07597145_11944.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n07771891/n07771891_7157.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n04546194/n04546194_8836.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n02296276/n02296276_3774.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n07861983/n07861983_3133.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n03610418/n03610418_6995.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n03652729/n03652729_7068.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n04403524/n04403524_7010.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n04045397/n04045397_19732.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n03003091/n03003091_8141.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n02086753/n02086753_7482.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n12162181/n12162181_5310.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n03914337/n03914337_14346.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n12256920/n12256920_2574.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n02107574/n02107574_3335.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n02802544/n02802544_9124.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n02907082/n02907082_20650.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n15019030/n15019030_18034.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n01445429/n01445429_246.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n10493419/n10493419_8350.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n10262445/n10262445_12908.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n03783430/n03783430_520.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n11857875/n11857875_5315.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n04118635/n04118635_285.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n04395106/n04395106_5902.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n07768694/n07768694_2282.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n02836035/n02836035_8908.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n02213788/n02213788_2171.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n03898395/n03898395_2471.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n07723330/n07723330_5132.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n03543394/n03543394_1381.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n11757653/n11757653_2602.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n03369276/n03369276_868.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n09732170/n09732170_5420.JPEG", 
"/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n03793850/n03793850_7494.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n02952485/n02952485_17123.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n02754656/n02754656_11408.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n02084732/n02084732_15806.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n03105306/n03105306_738.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n13061348/n13061348_1633.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n11736694/n11736694_2448.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n03766044/n03766044_15114.JPEG", "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image/n02275560/n02275560_786.JPEG"] 25 | 26 | def img_loader(img_path): 27 | try: 28 | return Image.open(img_path).convert('RGB') 29 | except: 30 | img_path = random.choice(img_list) 31 | return Image.open(img_path).convert('RGB') 32 | 33 | 34 | 35 | def build_22k_loader(config): 36 | config.defrost() 37 | 38 | transform = build_transform(True, config) 39 | root = "/comp_robot/cv_public_dataset/imagenet22k/imgnet22k-image" 40 | dataset_train = datasets.ImageFolder(root, transform=transform, loader=img_loader) 41 | config.MODEL.NUM_CLASSES = 21842 42 | config.freeze() 43 | print(f"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} successfully build train dataset") 44 | 45 | num_tasks = dist.get_world_size() 46 | global_rank = dist.get_rank() 47 | if config.DATA.ZIP_MODE and config.DATA.CACHE_MODE == 'part': 48 | indices = np.arange(dist.get_rank(), len(dataset_train), dist.get_world_size()) 49 | sampler_train = SubsetRandomSampler(indices) 50 | else: 51 | sampler_train = torch.utils.data.DistributedSampler( 52 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 53 | ) 54 | 55 | data_loader_train = torch.utils.data.DataLoader( 56 | dataset_train, sampler=sampler_train, 57 | batch_size=config.DATA.BATCH_SIZE, 58 | num_workers=config.DATA.NUM_WORKERS, 59 | pin_memory=config.DATA.PIN_MEMORY, 60 | drop_last=True, 61 | ) 62 | 63 | # setup mixup / cutmix 64 | mixup_fn = None 65 | mixup_active = config.AUG.MIXUP > 0 or config.AUG.CUTMIX > 0. 
or config.AUG.CUTMIX_MINMAX is not None 66 | if mixup_active: 67 | mixup_fn = Mixup( 68 | mixup_alpha=config.AUG.MIXUP, cutmix_alpha=config.AUG.CUTMIX, cutmix_minmax=config.AUG.CUTMIX_MINMAX, 69 | prob=config.AUG.MIXUP_PROB, switch_prob=config.AUG.MIXUP_SWITCH_PROB, mode=config.AUG.MIXUP_MODE, 70 | label_smoothing=config.MODEL.LABEL_SMOOTHING, num_classes=config.MODEL.NUM_CLASSES) 71 | 72 | return dataset_train, data_loader_train, mixup_fn 73 | 74 | 75 | def build_loader(config): 76 | config.defrost() 77 | dataset_train, config.MODEL.NUM_CLASSES = build_dataset(is_train=True, config=config) 78 | config.freeze() 79 | print(f"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} successfully build train dataset") 80 | dataset_val, _ = build_dataset(is_train=False, config=config) 81 | print(f"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} successfully build val dataset") 82 | 83 | num_tasks = dist.get_world_size() 84 | global_rank = dist.get_rank() 85 | if config.DATA.ZIP_MODE and config.DATA.CACHE_MODE == 'part': 86 | indices = np.arange(dist.get_rank(), len(dataset_train), dist.get_world_size()) 87 | sampler_train = SubsetRandomSampler(indices) 88 | else: 89 | sampler_train = torch.utils.data.DistributedSampler( 90 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 91 | ) 92 | 93 | indices = np.arange(dist.get_rank(), len(dataset_val), dist.get_world_size()) 94 | sampler_val = SubsetRandomSampler(indices) 95 | 96 | data_loader_train = torch.utils.data.DataLoader( 97 | dataset_train, sampler=sampler_train, 98 | batch_size=config.DATA.BATCH_SIZE, 99 | num_workers=config.DATA.NUM_WORKERS, 100 | pin_memory=config.DATA.PIN_MEMORY, 101 | drop_last=True, 102 | ) 103 | 104 | data_loader_val = torch.utils.data.DataLoader( 105 | dataset_val, sampler=sampler_val, 106 | batch_size=config.DATA.BATCH_SIZE, 107 | shuffle=False, 108 | num_workers=config.DATA.NUM_WORKERS, 109 | pin_memory=config.DATA.PIN_MEMORY, 110 | drop_last=False 111 | ) 112 | 113 | # setup mixup / cutmix 114 | mixup_fn = None 115 | mixup_active = config.AUG.MIXUP > 0 or config.AUG.CUTMIX > 0. 
or config.AUG.CUTMIX_MINMAX is not None 116 | if mixup_active: 117 | mixup_fn = Mixup( 118 | mixup_alpha=config.AUG.MIXUP, cutmix_alpha=config.AUG.CUTMIX, cutmix_minmax=config.AUG.CUTMIX_MINMAX, 119 | prob=config.AUG.MIXUP_PROB, switch_prob=config.AUG.MIXUP_SWITCH_PROB, mode=config.AUG.MIXUP_MODE, 120 | label_smoothing=config.MODEL.LABEL_SMOOTHING, num_classes=config.MODEL.NUM_CLASSES) 121 | 122 | return dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn 123 | 124 | 125 | def build_dataset(is_train, config): 126 | transform = build_transform(is_train, config) 127 | if config.DATA.DATASET == 'imagenet': 128 | prefix = 'train' if is_train else 'val' 129 | if config.DATA.ZIP_MODE: 130 | ann_file = prefix + "_map.txt" 131 | prefix = prefix + ".zip@/" 132 | dataset = CachedImageFolder(config.DATA.DATA_PATH, ann_file, prefix, transform, 133 | cache_mode=config.DATA.CACHE_MODE if is_train else 'part') 134 | else: 135 | root = os.path.join(config.DATA.DATA_PATH, prefix) 136 | dataset = datasets.ImageFolder(root, transform=transform) 137 | nb_classes = 1000 138 | else: 139 | raise NotImplementedError("We only support ImageNet Now.") 140 | 141 | return dataset, nb_classes 142 | 143 | 144 | def build_transform(is_train, config): 145 | resize_im = config.DATA.IMG_SIZE > 32 146 | if is_train: 147 | # if config.DATA.TRANSFORM_TYPE != 'easy': 148 | # this should always dispatch to transforms_imagenet_train 149 | transform = create_transform( 150 | input_size=config.DATA.IMG_SIZE, 151 | is_training=True, 152 | color_jitter=config.AUG.COLOR_JITTER if config.AUG.COLOR_JITTER > 0 else None, 153 | auto_augment=config.AUG.AUTO_AUGMENT if config.AUG.AUTO_AUGMENT != 'none' else None, 154 | re_prob=config.AUG.REPROB, 155 | re_mode=config.AUG.REMODE, 156 | re_count=config.AUG.RECOUNT, 157 | interpolation=config.DATA.INTERPOLATION, 158 | ) 159 | if not resize_im: 160 | # replace RandomResizedCropAndInterpolation with 161 | # RandomCrop 162 | transform.transforms[0] = transforms.RandomCrop(config.DATA.IMG_SIZE, padding=4) 163 | return transform 164 | # else: 165 | # t = [] 166 | # t.append( 167 | # transforms.Resize(config.DATA.IMG_SIZE + 32, interpolation=_pil_interp(config.DATA.INTERPOLATION))) 168 | # t.append(transforms.RandomCrop(config.DATA.IMG_SIZE)) 169 | # t.append(transforms.RandomHorizontalFlip()) 170 | # t.append(transforms.ToTensor()) 171 | # t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)) 172 | # return transforms.Compose(t) 173 | 174 | t = [] 175 | if resize_im: 176 | if config.TEST.CROP: 177 | size = int((256 / 224) * config.DATA.IMG_SIZE) 178 | t.append( 179 | transforms.Resize(size, interpolation=_pil_interp(config.DATA.INTERPOLATION)), 180 | # to maintain same ratio w.r.t. 
224 images 181 | ) 182 | t.append(transforms.CenterCrop(config.DATA.IMG_SIZE)) 183 | else: 184 | t.append( 185 | transforms.Resize((config.DATA.IMG_SIZE, config.DATA.IMG_SIZE), 186 | interpolation=_pil_interp(config.DATA.INTERPOLATION)) 187 | ) 188 | 189 | t.append(transforms.ToTensor()) 190 | t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)) 191 | return transforms.Compose(t) 192 | -------------------------------------------------------------------------------- /models/lipsformer_swin.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import torch.utils.checkpoint as checkpoint 12 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 13 | 14 | 15 | class ConvProjection(nn.Module): 16 | def __init__(self, dim, out_features=None): 17 | super().__init__() 18 | self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv 19 | if out_features is None: 20 | out_features = dim 21 | self.pwconv1 = nn.Conv2d(dim, out_features, kernel_size=1) # pointwise/1x1 convs, implemented with linear layers 22 | 23 | def forward(self, x): 24 | out = self.dwconv(x) 25 | out = self.pwconv1(out) 26 | return out 27 | 28 | class DWConvLayer(nn.Module): 29 | def __init__(self, dim, out_features=None): 30 | super().__init__() 31 | self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # 32 | def forward(self, x): 33 | out = self.dwconv(x) 34 | out = out + x 35 | return out 36 | 37 | class Mlp(nn.Module): 38 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 39 | super().__init__() 40 | out_features = out_features or in_features 41 | hidden_features = hidden_features or in_features 42 | self.fc1 = nn.Linear(in_features, hidden_features) 43 | self.act = act_layer() 44 | self.fc2 = nn.Linear(hidden_features, out_features) 45 | self.drop = nn.Dropout(drop) 46 | 47 | def forward(self, x): 48 | x = self.fc1(x) 49 | x = self.act(x) 50 | x = self.drop(x) 51 | x = self.fc2(x) 52 | x = self.drop(x) 53 | return x 54 | 55 | 56 | 57 | 58 | class ScaleLayer(nn.Module): 59 | def __init__(self, alpha=0.2, learnable=True, dim=1): 60 | super().__init__() 61 | self.alpha = alpha 62 | self.learnable = learnable 63 | self.dim = dim 64 | if self.learnable: 65 | #self.scale = nn.Parameter((torch.fmod(torch.arange(1, dim+1), 2, out=None)-torch.fmod(torch.arange(0, dim), 2, out=None))*init_t) 66 | self.scale = nn.Parameter(torch.ones(dim) * self.alpha) 67 | else: 68 | self.scale = self.alpha 69 | 70 | def forward(self, x): 71 | if self.learnable: 72 | y = self.scale[None, None, :]*x 73 | else: 74 | y = self.scale*x 75 | return y 76 | 77 | def __repr__(self): 78 | return f"ScaleLayer(alpha={self.alpha}, learnable={self.learnable}, dim={self.dim})" 79 | 80 | 81 | 82 | class CenterNorm(nn.Module): 83 | r""" CenterNorm that supports two data formats: channels_last (default) or channels_first. 84 | The ordering of the dimensions in the inputs. 
channels_last corresponds to inputs with 85 | shape (batch_size, height, width, channels) while channels_first corresponds to inputs 86 | with shape (batch_size, channels, height, width). 87 | """ 88 | def __init__(self, normalized_shape, eps=1e-6): 89 | super().__init__() 90 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 91 | self.bias = nn.Parameter(torch.zeros(normalized_shape)) 92 | self.scale = normalized_shape/(normalized_shape-1.0) 93 | def forward(self, x): 94 | u = x.mean(-1, keepdim=True) 95 | x = self.scale*(x - u) 96 | x = self.weight[None, None, :] * x + self.bias[None, None, :] 97 | return x 98 | 99 | def __repr__(self): 100 | return "CenterNorm()" 101 | 102 | 103 | def window_partition(x, window_size): 104 | """ 105 | Args: 106 | x: (B, H, W, C) 107 | window_size (int): window size 108 | 109 | Returns: 110 | windows: (num_windows*B, window_size, window_size, C) 111 | """ 112 | B, H, W, C = x.shape 113 | x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) 114 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 115 | return windows 116 | 117 | 118 | def window_reverse(windows, window_size, H, W): 119 | """ 120 | Args: 121 | windows: (num_windows*B, window_size, window_size, C) 122 | window_size (int): Window size 123 | H (int): Height of image 124 | W (int): Width of image 125 | 126 | Returns: 127 | x: (B, H, W, C) 128 | """ 129 | B = int(windows.shape[0] / (H * W / window_size / window_size)) 130 | x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) 131 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) 132 | return x 133 | 134 | 135 | class WindowAttention(nn.Module): 136 | r""" Window based multi-head self attention (W-MSA) module with relative position bias. 137 | It supports both of shifted and non-shifted window. 138 | 139 | Args: 140 | dim (int): Number of input channels. 141 | window_size (tuple[int]): The height and width of the window. 142 | num_heads (int): Number of attention heads. 143 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 144 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set 145 | attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 146 | proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 147 | """ 148 | 149 | def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): 150 | 151 | super().__init__() 152 | self.dim = dim 153 | self.window_size = window_size # Wh, Ww 154 | self.num_heads = num_heads 155 | head_dim = dim // num_heads 156 | self.scale = qk_scale or head_dim ** -0.5 157 | 158 | # define a parameter table of relative position bias 159 | self.relative_position_bias_table = nn.Parameter( 160 | torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH 161 | 162 | # get pair-wise relative position index for each token inside the window 163 | coords_h = torch.arange(self.window_size[0]) 164 | coords_w = torch.arange(self.window_size[1]) 165 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww 166 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww 167 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww 168 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 169 | relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 170 | relative_coords[:, :, 1] += self.window_size[1] - 1 171 | relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 172 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww 173 | self.register_buffer("relative_position_index", relative_position_index) 174 | 175 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 176 | 177 | #self.q = ConvProjection(dim) 178 | #self.k = ConvProjection(dim) 179 | #self.v = ConvProjection(dim) 180 | 181 | 182 | self.attn_drop = nn.Dropout(attn_drop) 183 | self.proj = nn.Linear(dim, dim) 184 | self.proj_drop = nn.Dropout(proj_drop) 185 | 186 | trunc_normal_(self.relative_position_bias_table, std=.02) 187 | self.softmax = nn.Softmax(dim=-1) 188 | 189 | def forward(self, x, mask=None): 190 | """ 191 | Args: 192 | x: input features with shape of (num_windows*B, N, C) 193 | mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None 194 | """ 195 | B_, N, C = x.shape 196 | qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 197 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 198 | #q = self.q(x) 199 | #k = self.k(x) 200 | #v = self.v(x) 201 | 202 | 203 | #q = q * self.scale 204 | 205 | q = F.normalize(q, p=2.0, dim=-1) 206 | k = F.normalize(k, p=2.0, dim=-1) 207 | 208 | 209 | attn = 10.0*(q @ k.transpose(-2, -1)) 210 | 211 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( 212 | self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH 213 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww 214 | attn = attn + relative_position_bias.unsqueeze(0) 215 | 216 | if mask is not None: 217 | nW = mask.shape[0] 218 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) 219 | attn = attn.view(-1, self.num_heads, N, N) 220 | #attn = self.softmax(attn) 221 | attn = checkpoint.checkpoint(self.softmax, attn) 222 | 223 | else: 224 | #attn = self.softmax(attn) 225 | attn = checkpoint.checkpoint(self.softmax, attn) 226 | 227 | 228 | attn = self.attn_drop(attn) 229 | 230 | # The setting of v is slightly different from our paper. In our paper, we use a normalized and then scaled v, but here we directly use a non-normalized v. 
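# (A hedged illustrative sketch, not the repository's code: the paper's normalized-and-scaled variant
#  would treat v the same way as q and k above, with a hypothetical extra learnable scale self.nu, e.g.
#      v = F.normalize(v, p=2.0, dim=-1)
#      x = (attn @ (self.nu * v)).transpose(1, 2).reshape(B_, N, C)
#  )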
We tested both versions, both are ok. 231 | # Here, to avoid adding one more parameter $\nu$, we use the current version. 232 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C) 233 | x = self.proj(x) 234 | x = self.proj_drop(x) 235 | return x 236 | 237 | def extra_repr(self) -> str: 238 | return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' 239 | 240 | def flops(self, N): 241 | # calculate flops for 1 window with token length of N 242 | flops = 0 243 | # qkv = self.qkv(x) 244 | flops += N * self.dim * 3 * self.dim 245 | # attn = (q @ k.transpose(-2, -1)) 246 | flops += self.num_heads * N * (self.dim // self.num_heads) * N 247 | # x = (attn @ v) 248 | flops += self.num_heads * N * N * (self.dim // self.num_heads) 249 | # x = self.proj(x) 250 | flops += N * self.dim * self.dim 251 | return flops 252 | 253 | 254 | class SwinTransformerBlock(nn.Module): 255 | r""" Swin Transformer Block. 256 | 257 | Args: 258 | dim (int): Number of input channels. 259 | input_resolution (tuple[int]): Input resulotion. 260 | num_heads (int): Number of attention heads. 261 | window_size (int): Window size. 262 | shift_size (int): Shift size for SW-MSA. 263 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 264 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 265 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 266 | drop (float, optional): Dropout rate. Default: 0.0 267 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 268 | drop_path (float, optional): Stochastic depth rate. Default: 0.0 269 | act_layer (nn.Module, optional): Activation layer. Default: nn.GELU 270 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 271 | """ 272 | 273 | def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, 274 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., 275 | act_layer=nn.GELU, norm_layer=CenterNorm, dw_conv_layer=2, num_layers=12): 276 | super().__init__() 277 | self.dim = dim 278 | self.input_resolution = input_resolution 279 | self.num_heads = num_heads 280 | self.window_size = window_size 281 | self.shift_size = shift_size 282 | self.mlp_ratio = mlp_ratio 283 | if min(self.input_resolution) <= self.window_size: 284 | # if window size is larger than input resolution, we don't partition windows 285 | self.shift_size = 0 286 | self.window_size = min(self.input_resolution) 287 | assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" 288 | 289 | self.attn = WindowAttention( 290 | dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, 291 | qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 292 | 293 | 294 | 295 | 296 | 297 | mlp_hidden_dim = int(dim * mlp_ratio) 298 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 299 | 300 | if self.shift_size > 0: 301 | # calculate attention mask for SW-MSA 302 | H, W = self.input_resolution 303 | img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 304 | h_slices = (slice(0, -self.window_size), 305 | slice(-self.window_size, -self.shift_size), 306 | slice(-self.shift_size, None)) 307 | w_slices = (slice(0, -self.window_size), 308 | slice(-self.window_size, -self.shift_size), 309 | slice(-self.shift_size, None)) 310 | cnt = 0 311 | for h in h_slices: 312 | for w in w_slices: 313 | img_mask[:, h, w, :] = cnt 314 | cnt += 1 315 | 316 | mask_windows = 
window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 317 | mask_windows = mask_windows.view(-1, self.window_size * self.window_size) 318 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) 319 | attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) 320 | else: 321 | attn_mask = None 322 | 323 | 324 | mlp_hidden_dim = int(dim * mlp_ratio) 325 | 326 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 327 | self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() 328 | self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() 329 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, out_features=dim, act_layer=act_layer, drop=drop) 330 | self.norm1 = norm_layer(dim) 331 | self.norm2 = norm_layer(dim) 332 | 333 | self.dwconv_layers = nn.ModuleList() 334 | #self.norm_layers = nn.ModuleList() 335 | self.scale_layers = nn.ModuleList() 336 | self.gelus = nn.ModuleList() 337 | 338 | 339 | for i in range(dw_conv_layer): 340 | self.dwconv_layers.append(nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)) 341 | #self.norm_layers.append(norm_layer(dim)) 342 | self.scale_layers.append(ScaleLayer(dim=dim, alpha=1./num_layers))#**alpha_cfg)) 343 | self.gelus.append(nn.GELU()) 344 | self.alpha1 = ScaleLayer(dim=dim, alpha=1.0/num_layers) #**alpha_cfg) 345 | self.alpha2 = ScaleLayer(dim=dim, alpha=1.0/num_layers) #**alpha_cfg) 346 | 347 | self.nin = nn.Conv2d(dim, dim, kernel_size=1, stride=1, padding=0) 348 | self.gelu2 = nn.GELU() 349 | 350 | 351 | 352 | 353 | self.register_buffer("attn_mask", attn_mask) 354 | 355 | def forward(self, x): 356 | ''' 357 | H, W = self.input_resolution 358 | B, L, C = x.shape 359 | assert L == H * W, "input feature has wrong size" 360 | 361 | x = x.view(B, H, W, C) 362 | x = x.permute(0, 3, 1, 2) 363 | x = self.dwconvlayer(x) 364 | 365 | x = x.permute(0, 2, 3, 1) 366 | x = x.view(B, L, C) 367 | 368 | shortcut = x 369 | x = self.norm1(x) 370 | x = x.view(B, H, W, C) 371 | x = x.permute(0, 3, 1, 2) 372 | x = self.dwconv(x) 373 | x = x.permute(0, 2, 3, 1) 374 | ''' 375 | 376 | 377 | H, W = self.input_resolution 378 | B, L, C = x.shape 379 | assert L == H * W, "flatten img_tokens has wrong size" 380 | 381 | i = 0 382 | if len(self.dwconv_layers) > 0 : 383 | for deconv, act, scale in zip(self.dwconv_layers, self.gelus, self.scale_layers): 384 | tmp_x = x 385 | x = x.transpose(-2, -1).contiguous().view(B, C, H, W) 386 | x = deconv(x) 387 | if i == 0: 388 | #x = self.gelu2(x) 389 | x = checkpoint.checkpoint(self.gelu2, x) 390 | x = self.nin(x) 391 | i = i + 1 392 | x = x.view(B, C, -1).transpose(-2, -1).contiguous() 393 | #x = tmp_x + self.drop_path2(scale(act(x))) 394 | x = tmp_x + self.drop_path2(scale(checkpoint.checkpoint(act, x))) 395 | 396 | shortcut = x 397 | x = x.view(B, H, W, C) 398 | 399 | 400 | 401 | # cyclic shift 402 | if self.shift_size > 0: 403 | shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) 404 | else: 405 | shifted_x = x 406 | 407 | # partition windows 408 | x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C 409 | x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C 410 | 411 | # W-MSA/SW-MSA 412 | attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C 413 | 414 | # merge windows 415 | attn_windows = attn_windows.view(-1, 
self.window_size, self.window_size, C) 416 | shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C 417 | 418 | # reverse cyclic shift 419 | if self.shift_size > 0: 420 | x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) 421 | else: 422 | x = shifted_x 423 | x = x.view(B, H * W, C) 424 | 425 | 426 | x = shortcut + self.drop_path(self.alpha1(x)) 427 | x = self.norm1(x) 428 | 429 | #x = x + self.drop_path1(self.alpha2(self.mlp(x))) 430 | x = x + self.drop_path1(self.alpha2(checkpoint.checkpoint(self.mlp, x))) 431 | 432 | x = self.norm2(x) 433 | 434 | return x 435 | 436 | def extra_repr(self) -> str: 437 | return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ 438 | f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" 439 | 440 | def flops(self): 441 | flops = 0 442 | H, W = self.input_resolution 443 | # norm1 444 | flops += self.dim * H * W 445 | # W-MSA/SW-MSA 446 | nW = H * W / self.window_size / self.window_size 447 | flops += nW * self.attn.flops(self.window_size * self.window_size) 448 | # mlp 449 | flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio 450 | # norm2 451 | flops += self.dim * H * W 452 | 453 | # dwconv1 454 | flops += self.dim * H * W * 49 455 | # dwconv2 456 | flops += self.dim * H * W * 49 457 | 458 | # 1 * 1 conv 459 | flops += self.dim * H * W * self.dim 460 | 461 | # alpha1 462 | flops += self.dim * H * W 463 | # alpha2 464 | flops += self.dim * H * W 465 | return flops 466 | 467 | 468 | class PatchMerging(nn.Module): 469 | r""" Patch Merging Layer. 470 | 471 | Args: 472 | input_resolution (tuple[int]): Resolution of input feature. 473 | dim (int): Number of input channels. 474 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 475 | """ 476 | 477 | def __init__(self, input_resolution, dim, norm_layer=CenterNorm): 478 | super().__init__() 479 | self.input_resolution = input_resolution 480 | self.dim = dim 481 | self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) 482 | self.norm = norm_layer(4 * dim) 483 | 484 | def forward(self, x): 485 | """ 486 | x: B, H*W, C 487 | """ 488 | H, W = self.input_resolution 489 | B, L, C = x.shape 490 | assert L == H * W, "input feature has wrong size" 491 | assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." 492 | 493 | x = x.view(B, H, W, C) 494 | 495 | x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C 496 | x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C 497 | x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C 498 | x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C 499 | x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C 500 | x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C 501 | 502 | x = self.norm(x) 503 | x = self.reduction(x) 504 | 505 | return x 506 | 507 | def extra_repr(self) -> str: 508 | return f"input_resolution={self.input_resolution}, dim={self.dim}" 509 | 510 | def flops(self): 511 | H, W = self.input_resolution 512 | flops = H * W * self.dim 513 | flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim 514 | return flops 515 | 516 | 517 | class BasicLayer(nn.Module): 518 | """ A basic Swin Transformer layer for one stage. 519 | 520 | Args: 521 | dim (int): Number of input channels. 522 | input_resolution (tuple[int]): Input resolution. 523 | depth (int): Number of blocks. 524 | num_heads (int): Number of attention heads. 525 | window_size (int): Local window size. 526 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 
527 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 528 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 529 | drop (float, optional): Dropout rate. Default: 0.0 530 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 531 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 532 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 533 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None 534 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 535 | """ 536 | 537 | def __init__(self, dim, input_resolution, depth, num_heads, window_size, 538 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., 539 | drop_path=0., norm_layer=CenterNorm, downsample=None, use_checkpoint=False, num_layers=12): 540 | 541 | super().__init__() 542 | self.dim = dim 543 | self.input_resolution = input_resolution 544 | self.depth = depth 545 | self.use_checkpoint = use_checkpoint 546 | 547 | 548 | #L = sum(depth) 549 | # build blocks 550 | self.blocks = nn.ModuleList([ 551 | SwinTransformerBlock(dim=dim, input_resolution=input_resolution, 552 | num_heads=num_heads, window_size=window_size, 553 | shift_size=0, #if (i % 2 == 0) else window_size // 2 554 | mlp_ratio=mlp_ratio, 555 | qkv_bias=qkv_bias, qk_scale=qk_scale, 556 | drop=drop, attn_drop=attn_drop, 557 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, 558 | norm_layer=norm_layer, num_layers=num_layers) 559 | for i in range(depth)]) 560 | 561 | # patch merging layer 562 | if downsample is not None: 563 | self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) 564 | else: 565 | self.downsample = None 566 | 567 | def forward(self, x): 568 | for blk in self.blocks: 569 | if self.use_checkpoint: 570 | x = checkpoint.checkpoint(blk, x) 571 | else: 572 | x = blk(x) 573 | if self.downsample is not None: 574 | x = self.downsample(x) 575 | return x 576 | 577 | def extra_repr(self) -> str: 578 | return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" 579 | 580 | def flops(self): 581 | flops = 0 582 | for blk in self.blocks: 583 | flops += blk.flops() 584 | if self.downsample is not None: 585 | flops += self.downsample.flops() 586 | return flops 587 | 588 | 589 | class PatchEmbed(nn.Module): 590 | r""" Image to Patch Embedding 591 | 592 | Args: 593 | img_size (int): Image size. Default: 224. 594 | patch_size (int): Patch token size. Default: 4. 595 | in_chans (int): Number of input image channels. Default: 3. 596 | embed_dim (int): Number of linear projection output channels. Default: 96. 597 | norm_layer (nn.Module, optional): Normalization layer. 
Default: None 598 | """ 599 | 600 | def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): 601 | super().__init__() 602 | img_size = to_2tuple(img_size) 603 | patch_size = to_2tuple(patch_size) 604 | patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] 605 | self.img_size = img_size 606 | self.patch_size = patch_size 607 | self.patches_resolution = patches_resolution 608 | self.num_patches = patches_resolution[0] * patches_resolution[1] 609 | 610 | self.in_chans = in_chans 611 | self.embed_dim = embed_dim 612 | 613 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 614 | if norm_layer is not None: 615 | self.norm = norm_layer(embed_dim) 616 | else: 617 | self.norm = None 618 | 619 | def forward(self, x): 620 | B, C, H, W = x.shape 621 | # FIXME look at relaxing size constraints 622 | assert H == self.img_size[0] and W == self.img_size[1], \ 623 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 624 | x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C 625 | if self.norm is not None: 626 | x = self.norm(x) 627 | return x 628 | 629 | def flops(self): 630 | Ho, Wo = self.patches_resolution 631 | flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) 632 | if self.norm is not None: 633 | flops += Ho * Wo * self.embed_dim 634 | return flops 635 | 636 | 637 | class LipsFormerSwin(nn.Module): 638 | r""" Swin Transformer 639 | A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - 640 | https://arxiv.org/pdf/2103.14030 641 | 642 | Args: 643 | img_size (int | tuple(int)): Input image size. Default 224 644 | patch_size (int | tuple(int)): Patch size. Default: 4 645 | in_chans (int): Number of input image channels. Default: 3 646 | num_classes (int): Number of classes for classification head. Default: 1000 647 | embed_dim (int): Patch embedding dimension. Default: 96 648 | depths (tuple(int)): Depth of each Swin Transformer layer. 649 | num_heads (tuple(int)): Number of attention heads in different layers. 650 | window_size (int): Window size. Default: 7 651 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 652 | qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True 653 | qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None 654 | drop_rate (float): Dropout rate. Default: 0 655 | attn_drop_rate (float): Attention dropout rate. Default: 0 656 | drop_path_rate (float): Stochastic depth rate. Default: 0.1 657 | norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. 658 | ape (bool): If True, add absolute position embedding to the patch embedding. Default: False 659 | patch_norm (bool): If True, add normalization after patch embedding. Default: True 660 | use_checkpoint (bool): Whether to use checkpointing to save memory. 
Default: False 661 | """ 662 | 663 | def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, 664 | embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], 665 | window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, 666 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, 667 | norm_layer=CenterNorm, ape=False, patch_norm=True, 668 | use_checkpoint=False, **kwargs): 669 | super().__init__() 670 | 671 | self.num_classes = num_classes 672 | self.num_layers = len(depths) 673 | self.embed_dim = embed_dim 674 | self.ape = ape 675 | self.patch_norm = patch_norm 676 | self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) 677 | self.mlp_ratio = mlp_ratio 678 | 679 | # split image into non-overlapping patches 680 | self.patch_embed = PatchEmbed( 681 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, 682 | norm_layer=norm_layer if self.patch_norm else None) 683 | num_patches = self.patch_embed.num_patches 684 | patches_resolution = self.patch_embed.patches_resolution 685 | self.patches_resolution = patches_resolution 686 | 687 | # absolute position embedding 688 | if self.ape: 689 | self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) 690 | trunc_normal_(self.absolute_pos_embed, std=.02) 691 | 692 | self.pos_drop = nn.Dropout(p=drop_rate) 693 | 694 | 695 | LL = sum(depths) 696 | # stochastic depth 697 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule 698 | 699 | # build layers 700 | self.layers = nn.ModuleList() 701 | for i_layer in range(self.num_layers): 702 | layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), 703 | input_resolution=(patches_resolution[0] // (2 ** i_layer), 704 | patches_resolution[1] // (2 ** i_layer)), 705 | depth=depths[i_layer], 706 | num_heads=num_heads[i_layer], 707 | window_size=window_size, 708 | mlp_ratio=self.mlp_ratio, 709 | qkv_bias=qkv_bias, qk_scale=qk_scale, 710 | drop=drop_rate, attn_drop=attn_drop_rate, 711 | drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], 712 | norm_layer=norm_layer, 713 | downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, 714 | use_checkpoint=use_checkpoint, num_layers=LL) 715 | self.layers.append(layer) 716 | 717 | self.norm = norm_layer(self.num_features) 718 | self.avgpool = nn.AdaptiveAvgPool1d(1) 719 | self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() 720 | 721 | self.apply(self._spectral_init) 722 | 723 | def _spectral_init(self, m): 724 | if isinstance(m, nn.Linear): 725 | #torch.nn.init.orthogonal_(m.weight, gain=1) 726 | torch.nn.init.xavier_normal_(m.weight) 727 | if isinstance(m, nn.Linear) and m.bias is not None: 728 | nn.init.constant_(m.bias, 0) 729 | 730 | u,s,v = torch.svd(m.weight) 731 | m.weight.data = 1.0*m.weight.data/s[0] 732 | 733 | elif isinstance(m, (nn.Conv2d)): 734 | #torch.nn.init.orthogonal_(m.weight, gain=1) 735 | torch.nn.init.xavier_normal_(m.weight) 736 | weight = torch.reshape(m.weight.data, (m.weight.data.shape[0], -1)) 737 | u,s,v = torch.svd(weight) 738 | m.weight.data = m.weight.data/s[0] 739 | 740 | elif isinstance(m, (nn.LayerNorm, CenterNorm, nn.BatchNorm2d)): 741 | nn.init.constant_(m.bias, 0) 742 | nn.init.constant_(m.weight, 1.0) 743 | 744 | def _init_weights(self, m): 745 | if isinstance(m, nn.Linear): 746 | trunc_normal_(m.weight, std=.02) 747 | if isinstance(m, nn.Linear) and m.bias is not None: 748 | nn.init.constant_(m.bias, 0) 749 | elif isinstance(m, nn.LayerNorm): 
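# (Descriptive note: this branch mirrors the standard Swin LayerNorm initialization; in the code shown,
#  __init__ applies self._spectral_init via self.apply, so _init_weights only runs if called explicitly.)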
750 | nn.init.constant_(m.bias, 0) 751 | nn.init.constant_(m.weight, 1.0) 752 | 753 | @torch.jit.ignore 754 | def no_weight_decay(self): 755 | return {'absolute_pos_embed'} 756 | 757 | @torch.jit.ignore 758 | def no_weight_decay_keywords(self): 759 | return {'relative_position_bias_table'} 760 | 761 | def forward_features(self, x): 762 | x = self.patch_embed(x) 763 | if self.ape: 764 | x = x + self.absolute_pos_embed 765 | x = self.pos_drop(x) 766 | 767 | for layer in self.layers: 768 | x = layer(x) 769 | 770 | x = self.norm(x) # B L C 771 | x = self.avgpool(x.transpose(1, 2)) # B C 1 772 | x = torch.flatten(x, 1) 773 | return x 774 | 775 | def forward(self, x): 776 | x = self.forward_features(x) 777 | x = self.head(x) 778 | return x 779 | 780 | def flops(self): 781 | flops = 0 782 | flops += self.patch_embed.flops() 783 | for i, layer in enumerate(self.layers): 784 | flops += layer.flops() 785 | flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers) 786 | flops += self.num_features * self.num_classes 787 | return flops 788 | 789 | 790 | 791 | --------------------------------------------------------------------------------
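A minimal usage sketch for the LipsFormerSwin model defined above; this snippet is illustrative and not part of the repository. It builds the model with the constructor defaults visible in the signature (embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], window_size=7, 1000 classes at 224x224 input) and runs a dummy forward pass. The import path follows the repository layout, while the batch size and the no-grad evaluation are assumptions.

import torch
from models.lipsformer_swin import LipsFormerSwin

# Construct LipsFormer-Swin with the default (Swin-Tiny-sized) arguments.
model = LipsFormerSwin(img_size=224, patch_size=4, in_chans=3, num_classes=1000)
model.eval()

# Dummy ImageNet-sized batch: (batch, channels, height, width).
x = torch.randn(2, 3, 224, 224)
with torch.no_grad():
    logits = model(x)   # patch embed -> 4 stages -> CenterNorm -> average pool -> linear head
print(logits.shape)     # expected: torch.Size([2, 1000])
print(f"analytic FLOPs per image: {model.flops() / 1e9:.2f} G")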