├── .gitignore
├── LICENSE
├── README.md
├── config.py
├── configs
│   ├── finetune
│   │   ├── fd_finetune__clip_vit_base__img224__100ep.yaml
│   │   ├── fd_finetune__clip_vit_base__img224__300ep.yaml
│   │   ├── fd_finetune__clip_vit_large__img224__300ep.yaml
│   │   ├── fd_finetune__deit_vit_base__img224__300ep.yaml
│   │   ├── fd_finetune__dino_vit_base__img224__300ep.yaml
│   │   └── fd_finetune__esvit_swin_base__img224__300ep.yaml
│   └── pretrain
│       ├── fd_pretrain__clip_vit_base__img224__100ep.yaml
│       ├── fd_pretrain__clip_vit_base__img224__300ep.yaml
│       ├── fd_pretrain__clip_vit_large__img224__300ep.yaml
│       ├── fd_pretrain__deit_vit_base__img224__300ep.yaml
│       ├── fd_pretrain__dino_vit_base__img224__300ep.yaml
│       └── fd_pretrain__esvit_swin_base__img224__300ep.yaml
├── data
│   ├── __init__.py
│   ├── cached_image_folder.py
│   ├── data_fd.py
│   ├── data_finetune.py
│   ├── data_linear.py
│   └── utils.py
├── figures
│   └── teaser.jpg
├── logger.py
├── lr_scheduler.py
├── main_fd.py
├── main_finetune.py
├── main_linear.py
├── models
│   ├── __init__.py
│   ├── build.py
│   ├── clip
│   │   ├── __init__.py
│   │   ├── clip.py
│   │   ├── model.py
│   │   ├── simple_tokenizer.py
│   │   ├── utils.py
│   │   └── vit.py
│   ├── deit.py
│   ├── dino.py
│   ├── esvit.py
│   ├── feature_distillation.py
│   ├── swin_transformer.py
│   ├── swin_transformer_v2.py
│   ├── utils.py
│   └── vision_transformer.py
├── optimizer.py
├── requirements.txt
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | # dev files
132 | wandb/
133 | output/
134 | visualize/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) Microsoft Corporation.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Feature-Distillation
2 |
3 | By [Yixuan Wei](https://scholar.google.com/citations?user=xwudKb4AAAAJ&hl=en)\*, [Han Hu](https://ancientmooner.github.io/)\*, [Zhenda Xie](https://zdaxie.github.io), [Zheng Zhang](https://stupidzz.github.io/), [Yue Cao](http://yue-cao.me), [Jianmin Bao](https://jianminbao.github.io/), [Dong Chen](http://www.dongchen.pro) and [Baining Guo](https://scholar.google.com/citations?user=h4kYmRYAAAAJ&hl=en&oi=ao).
4 |
5 | This repo is the official implementation of ["Contrastive Learning Rivals Masked Image Modeling in Fine-tuning via Feature Distillation"](https://arxiv.org/abs/2205.14141).
6 |
7 | ## Updates
8 | ***11/30/2022***
9 |
10 | 1. Distilled and fine-tuned models on ImageNet-1K (`ViT Large`) are provided.
11 |
12 | ***11/28/2022***
13 |
14 | Initial commits:
15 |
16 | 1. Distilled and fine-tuned models on ImageNet-1K (`Swin Base` and `ViT Base`) are provided.
17 | 2. The supported code for ImageNet-1K distillation and fine-tuning is provided.
18 |
19 | ## Introduction
20 |
21 | **FD** is initially described in [arxiv](https://arxiv.org/abs/2205.14141). It is a simple framework that converts traditional pre-trained models, such as image classification (DeiT), instance contrastive learning (DINO), and image-text alignment (CLIP), into new models with better fine-tuning performance. Through a set of diagnosing tools, we find that models distilled with feature maps are endowed with the following good properties, which are also observed in masked image modeling models: 1) more diverse attention heads; 2) more diagonal attention patterns; 3) flatter loss landscapes.
22 |
23 |
24 | ![teaser](figures/teaser.jpg)
25 |
26 |
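Below is a minimal sketch of the feature-distillation objective for intuition (illustrative only, assuming a smooth-L1 loss between projected student features and whitened teacher features; the class and names here are hypothetical, not this repo's exact API):

```python
import torch.nn as nn
import torch.nn.functional as F


class FDLoss(nn.Module):
    """Match student feature maps to a frozen teacher's feature maps."""

    def __init__(self, student_dim, teacher_dim):
        super().__init__()
        # 1x1 projection mapping student features to the teacher's dimension
        self.proj = nn.Linear(student_dim, teacher_dim)
        # whitening: LayerNorm without affine parameters on the teacher targets
        self.whiten = nn.LayerNorm(teacher_dim, elementwise_affine=False)

    def forward(self, feat_student, feat_teacher):
        # feat_*: (batch, num_tokens, dim) token feature maps
        target = self.whiten(feat_teacher.detach())  # teacher is not updated
        return F.smooth_l1_loss(self.proj(feat_student), target)
```
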
27 | ## Main Results on ImageNet
28 |
29 | ### Swin Transformer
30 |
31 | **ImageNet-1K Distilled and Fine-tuned Models**
32 |
33 | | name | distillation epochs | teacher model | image resolution | acc@1 | distilled model | fine-tuned model |
34 | | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
35 | | Swin-Base | 300 | [EsViT-Base](https://github.com/microsoft/esvit) | 224x224 | 85.1 | [google](https://drive.google.com/file/d/11_GQUHgcrUO8PMzl73eJmLSa7f3c5dZY/view?usp=sharing)/[config](configs/pretrain/fd_pretrain__esvit_swin_base__img224__300ep.yaml) | [google](https://drive.google.com/file/d/1criliGcjpEJxqlsYRGBERBAMYrFYFW--/view?usp=sharing)/[config](configs/finetune/fd_finetune__esvit_swin_base__img224__300ep.yaml) |
36 |
37 | ### Vision Transformer
38 |
39 | **ImageNet-1K Distilled and Fine-tuned Models**
40 |
41 | | name | distillation epochs | teacher model | image resolution | acc@1 | distilled model | fine-tuned model |
42 | | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
43 | | ViT-Base | 300 | [CLIP-Base](https://github.com/openai/CLIP) | 224x224 | 84.9 | [google](https://drive.google.com/file/d/1XFOZ6rJkv5X08Bu5d04_Xy3iJOj6SLc7/view?usp=sharing)/[config](configs/pretrain/fd_pretrain__clip_vit_base__img224__300ep.yaml) | [google](https://drive.google.com/file/d/1mP_JESmcdFeIkpB4aYyFzALtkydy_9qN/view?usp=sharing)/[config](configs/finetune/fd_finetune__clip_vit_base__img224__300ep.yaml) |
44 | | ViT-Base | 300 | [DINO-Base](https://github.com/facebookresearch/dino) | 224x224 | 83.8 | [google](https://drive.google.com/file/d/1fwBINMxpv5zFOI7Ye6l9msI8GzocpA3z/view?usp=sharing)/[config](configs/pretrain/fd_pretrain__dino_vit_base__img224__300ep.yaml) | [google](https://drive.google.com/file/d/1Mn_GgepfZXOe7W0UqEQMFo5MjJpMwM_i/view?usp=sharing)/[config](configs/finetune/fd_finetune__dino_vit_base__img224__300ep.yaml) |
45 | | ViT-Base | 300 | [DeiT-Base](https://github.com/facebookresearch/deit) | 224x224 | 83.0 | [google](https://drive.google.com/file/d/1yPezioDc4O6hdfD6VSAIU9DvJiXG4ZSJ/view?usp=sharing)/[config](configs/pretrain/fd_pretrain__deit_vit_base__img224__300ep.yaml) | [google](https://drive.google.com/file/d/1pb0KUlVcCaEGT-xnx6ookrqcC-88Ori5/view?usp=sharing)/[config](configs/finetune/fd_finetune__deit_vit_base__img224__300ep.yaml) |
46 | | ViT-Large | 300 | [CLIP-Large](https://github.com/openai/CLIP) | 224x224 | 87.7 | [google](https://drive.google.com/file/d/1H5USyzqwoS31JHDX874q8a70LdVD9zNY/view?usp=sharing)/[config](configs/pretrain/fd_pretrain__clip_vit_large__img224__300ep.yaml) | [google](https://drive.google.com/file/d/1XDDbDl9jzt8H2Fy6iZNfNA7Yjepf_MGx/view?usp=sharing)/[config](configs/finetune/fd_finetune__clip_vit_large__img224__300ep.yaml) |
47 |
48 | ## Citation
49 |
50 | If you find our work useful in your research, please cite:
51 |
52 | ```
53 | @article{wei2022FD,
54 | title={Contrastive Learning Rivals Masked Image Modeling in Fine-tuning via Feature Distillation},
55 | author={Yixuan Wei and Han Hu and Zhenda Xie and Zheng Zhang and Yue Cao and Jianmin Bao and Dong Chen and Baining Guo},
56 | journal={Tech Report},
57 | year={2022}
58 | }
59 | ```
60 |
61 | ## Getting Started
62 |
63 | ### Installation
64 |
65 | - Install `CUDA 11.3` with `cuDNN 8` following the official installation guide of [CUDA](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html) and [cuDNN](https://developer.nvidia.com/rdp/cudnn-archive).
66 |
67 | - Set up the conda environment:
68 | ```bash
69 | # Create environment
70 | conda create -n FD python=3.8 -y
71 | conda activate FD
72 |
73 | # Install requirements
74 | pip install torch==1.12.0+cu113 torchvision==0.13.0+cu113 torchaudio==0.12.0 --extra-index-url https://download.pytorch.org/whl/cu113
75 |
76 | # Clone codes
77 | git clone https://github.com/SwinTransformer/Feature-Distillation
78 | cd Feature-Distillation
79 |
80 | # Install other requirements
81 | pip install -r requirements.txt
82 | ```
83 |
84 | ### Feature-Distillation
85 | To distill models, run:
86 | ```bash
87 | python -m torch.distributed.launch --nproc_per_node <num-of-gpus> main_fd.py \
88 | --cfg <config-file> --data-path <imagenet-path>/train [--batch-size <batch-size-per-gpu> --output <output-directory> --tag <job-tag>]
89 | ```
90 |
91 | For example, to distill `CLIP-Base` for 300 epochs on one DGX-2 server, run:
92 | ```bash
93 | python -m torch.distributed.launch --nproc_per_node=16 main_fd.py --cfg configs/pretrain/fd_pretrain__clip_vit_base__img224__300ep.yaml --batch-size 128 --data-path <imagenet-path>/train [--output <output-directory> --tag <job-tag>]
94 | ```
95 |
96 | To reduce GPU memory consumption, add `--use-checkpoint`.
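Gradient checkpointing trades compute for memory by recomputing intermediate activations during the backward pass instead of storing them. A generic sketch of what such a flag typically toggles (hypothetical helper, not this repo's exact code):

```python
from torch.utils.checkpoint import checkpoint

def forward_blocks(blocks, x, use_checkpoint=False):
    for blk in blocks:
        # recompute blk's activations in backward instead of caching them
        x = checkpoint(blk, x) if use_checkpoint else blk(x)
    return x
```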
97 |
98 | ### Fine-tuning distilled models
99 | To fine-tune distilled models, run:
100 | ```bash
101 | python -m torch.distributed.launch --nproc_per_node <num-of-gpus> main_finetune.py \
102 | --cfg <config-file> --data-path <imagenet-path> --pretrained <pretrained-model-path> [--batch-size <batch-size-per-gpu> --output <output-directory> --tag <job-tag>]
103 | ```
104 |
105 | For example, to fine-tune `Distilled-CLIP-Base` on one DGX-2 server, run:
106 | ```bash
107 | python -m torch.distributed.launch --nproc_per_node 16 main_finetune.py \
108 | --cfg configs/finetune/fd_finetune__clip_vit_base__img224__300ep.yaml --batch-size 128 --data-path <imagenet-path> --pretrained <pretrained-model-path> [--output <output-directory> --tag <job-tag>]
109 | ```
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Copyright (c) 2021 Microsoft
3 | # Licensed under The MIT License [see LICENSE for details]
4 | # Written by Ze Liu
5 | # Modified by Zhenda Xie
6 | # Modified by Yixuan Wei
7 | # --------------------------------------------------------
8 |
9 | import os
10 | import yaml
11 | from yacs.config import CfgNode as CN
12 |
13 | _C = CN()
14 |
15 | # Base config files
16 | _C.BASE = ['']
17 |
18 | # -----------------------------------------------------------------------------
19 | # Dev settings
20 | # -----------------------------------------------------------------------------
21 | _C.DEV = CN()
22 | # Relative Coords Table Type
23 | _C.DEV.RCT_TYPE = 'norm8_log'
24 | _C.DEV.CHECKPOINT_BLOCKS = [0,0,0,0]
25 |
26 | # [Feature Distillation]
27 | _C.DEV.PRED_FEAT = ''
28 | _C.DEV.PRED_FEAT_AFTERNORM = False # whether to use feature after norm
29 | _C.DEV.PRED_FEAT_S3 = False # when using Swin as the teacher and ViT as the student, use the stage-3 feature as the target so the token number matches 14*14
30 | _C.DEV.VIT_WITHKBIAS = False
31 | _C.DEV.FT_SKIP_REMAP = False
32 |
33 | # -----------------------------------------------------------------------------
34 | # Data settings
35 | # -----------------------------------------------------------------------------
36 | _C.DATA = CN()
37 | # Batch size for a single GPU, could be overwritten by command line argument
38 | _C.DATA.BATCH_SIZE = 128
39 | # Path to dataset, could be overwritten by command line argument
40 | _C.DATA.DATA_PATH = ''
41 | # Dataset name
42 | _C.DATA.DATASET = 'imagenet'
43 | # Input image size
44 | _C.DATA.IMG_SIZE = 224
45 | # Interpolation to resize image (random, bilinear, bicubic)
46 | _C.DATA.INTERPOLATION = 'bicubic'
47 | # Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.
48 | _C.DATA.PIN_MEMORY = True
49 | # Number of data loading threads
50 | _C.DATA.NUM_WORKERS = 8
51 | # Zip Mode as in Swin Transformer
52 | _C.DATA.ZIP_MODE = False
53 |
54 | # -----------------------------------------------------------------------------
55 | # Model settings
56 | # -----------------------------------------------------------------------------
57 | _C.MODEL = CN()
58 | # Model type
59 | _C.MODEL.TYPE = 'swin'
60 | # Model name
61 | _C.MODEL.NAME = 'swin_tiny_patch4_window7_224'
62 | # Checkpoint to resume, could be overwritten by command line argument
63 | _C.MODEL.RESUME = ''
64 | # Number of classes, overwritten in data preparation
65 | _C.MODEL.NUM_CLASSES = 1000
66 | # Dropout rate
67 | _C.MODEL.DROP_RATE = 0.0
68 | # Drop path rate
69 | _C.MODEL.DROP_PATH_RATE = 0.1
70 | # Label Smoothing
71 | _C.MODEL.LABEL_SMOOTHING = 0.1
72 |
73 | # Swin Transformer parameters
74 | _C.MODEL.SWIN = CN()
75 | _C.MODEL.SWIN.PATCH_SIZE = 4
76 | _C.MODEL.SWIN.IN_CHANS = 3
77 | _C.MODEL.SWIN.EMBED_DIM = 96
78 | _C.MODEL.SWIN.DEPTHS = [2, 2, 6, 2]
79 | _C.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24]
80 | _C.MODEL.SWIN.WINDOW_SIZE = 7
81 | _C.MODEL.SWIN.MLP_RATIO = 4.
82 | _C.MODEL.SWIN.QKV_BIAS = True
83 | _C.MODEL.SWIN.QK_SCALE = None
84 | _C.MODEL.SWIN.USE_SHARED_RPB = False
85 | _C.MODEL.SWIN.APE = False
86 | _C.MODEL.SWIN.PATCH_NORM = True
87 |
88 | # Vision Transformer parameters
89 | _C.MODEL.VIT = CN()
90 | _C.MODEL.VIT.PATCH_SIZE = 16
91 | _C.MODEL.VIT.IN_CHANS = 3
92 | _C.MODEL.VIT.EMBED_DIM = 768
93 | _C.MODEL.VIT.DEPTH = 12
94 | _C.MODEL.VIT.NUM_HEADS = 12
95 | _C.MODEL.VIT.MLP_RATIO = 4
96 | _C.MODEL.VIT.QKV_BIAS = True
97 | _C.MODEL.VIT.INIT_VALUES = 0.1
98 | _C.MODEL.VIT.USE_APE = False
99 | _C.MODEL.VIT.USE_RPB = False
100 | _C.MODEL.VIT.USE_SHARED_RPB = True
101 | _C.MODEL.VIT.USE_MEAN_POOLING = False
102 | _C.MODEL.VIT.ATTN_TYPE = 'normal'
103 | _C.MODEL.VIT.WITH_CLS_TOKEN = True
104 |
105 |
106 | # -----------------------------------------------------------------------------
107 | # Training settings
108 | # -----------------------------------------------------------------------------
109 | _C.TRAIN = CN()
110 | _C.TRAIN.START_EPOCH = 0
111 | _C.TRAIN.EPOCHS = 300
112 | _C.TRAIN.WARMUP_EPOCHS = 20
113 | _C.TRAIN.WARMUP_EPOCHS_FINE = 0.0 # in case of less than 1 epoch of warmup
114 | _C.TRAIN.WEIGHT_DECAY = 0.05
115 | _C.TRAIN.BASE_LR = 5e-4
116 | _C.TRAIN.WARMUP_LR = 5e-7
117 | _C.TRAIN.MIN_LR = 5e-6
118 | # Clip gradient norm
119 | _C.TRAIN.CLIP_GRAD = 5.0
120 | # Auto resume from latest checkpoint
121 | _C.TRAIN.AUTO_RESUME = True
122 | # Gradient accumulation steps
123 | # could be overwritten by command line argument
124 | _C.TRAIN.ACCUMULATION_STEPS = 0
125 | # Whether to use gradient checkpointing to save memory
126 | # could be overwritten by command line argument
127 | _C.TRAIN.USE_CHECKPOINT = False
128 |
129 | # LR scheduler
130 | _C.TRAIN.LR_SCHEDULER = CN()
131 | _C.TRAIN.LR_SCHEDULER.NAME = 'cosine'
132 | # Epoch interval to decay LR, used in StepLRScheduler
133 | _C.TRAIN.LR_SCHEDULER.DECAY_EPOCHS = 30
134 | # LR decay rate, used in StepLRScheduler
135 | _C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1
136 | # Gamma / Multi steps value, used in MultiStepLRScheduler
137 | _C.TRAIN.LR_SCHEDULER.GAMMA = 0.1
138 | _C.TRAIN.LR_SCHEDULER.MULTISTEPS = []
139 |
140 | # Optimizer
141 | _C.TRAIN.OPTIMIZER = CN()
142 | _C.TRAIN.OPTIMIZER.NAME = 'adamw'
143 | # Optimizer Epsilon
144 | _C.TRAIN.OPTIMIZER.EPS = 1e-8
145 | # Optimizer Betas
146 | _C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999)
147 | # SGD momentum
148 | _C.TRAIN.OPTIMIZER.MOMENTUM = 0.9
149 |
150 | # Layer decay for fine-tuning
151 | _C.TRAIN.LAYER_DECAY = 1.0
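# Illustration (hypothetical numbers, not part of this config): with layer-wise
# decay d and depth L, parameters in transformer block i are typically trained
# with an lr scale of d ** (L - i), so the last block runs near the base lr
# while, e.g., d = 0.65 and L = 12 give the earliest layers roughly 0.65 ** 12
# of it.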
152 |
153 | # -----------------------------------------------------------------------------
154 | # Augmentation settings
155 | # -----------------------------------------------------------------------------
156 | _C.AUG = CN()
157 | # Color jitter factor
158 | _C.AUG.COLOR_JITTER = 0.4
159 | # Use AutoAugment policy. "v0" or "original"
160 | _C.AUG.AUTO_AUGMENT = 'rand-m9-mstd0.5-inc1'
161 | # Random erase prob
162 | _C.AUG.REPROB = 0.25
163 | # Random erase mode
164 | _C.AUG.REMODE = 'pixel'
165 | # Random erase count
166 | _C.AUG.RECOUNT = 1
167 | # Mixup alpha, mixup enabled if > 0
168 | _C.AUG.MIXUP = 0.8
169 | # Cutmix alpha, cutmix enabled if > 0
170 | _C.AUG.CUTMIX = 1.0
171 | # Cutmix min/max ratio, overrides alpha and enables cutmix if set
172 | _C.AUG.CUTMIX_MINMAX = None
173 | # Probability of performing mixup or cutmix when either/both is enabled
174 | _C.AUG.MIXUP_PROB = 1.0
175 | # Probability of switching to cutmix when both mixup and cutmix enabled
176 | _C.AUG.MIXUP_SWITCH_PROB = 0.5
177 | # How to apply mixup/cutmix params. Per "batch", "pair", or "elem"
178 | _C.AUG.MIXUP_MODE = 'batch'
179 | _C.AUG.MAX_SCALE = 1.0
180 | _C.AUG.MIN_SCALE = 0.67
181 |
182 | # -----------------------------------------------------------------------------
183 | # Testing settings
184 | # -----------------------------------------------------------------------------
185 | _C.TEST = CN()
186 | # Whether to use center crop when testing
187 | _C.TEST.CROP = True
188 |
189 | # -----------------------------------------------------------------------------
190 | # Misc
191 | # -----------------------------------------------------------------------------
192 | # Whether to enable pytorch amp, overwritten by command line argument
193 | _C.ENABLE_AMP = False
194 | # Path to output folder, overwritten by command line argument
195 | _C.OUTPUT = ''
196 | # Tag of experiment, overwritten by command line argument
197 | _C.TAG = 'default'
198 | # Frequency to save checkpoint
199 | _C.SAVE_FREQ = 1
200 | # Frequency of logging info
201 | _C.PRINT_FREQ = 10
202 | # Fixed random seed
203 | _C.SEED = 0
204 | # Perform evaluation only, overwritten by command line argument
205 | _C.EVAL_MODE = False
206 | # Test throughput only, overwritten by command line argument
207 | _C.THROUGHPUT_MODE = False
208 | # local rank for DistributedDataParallel, given by command line argument
209 | _C.LOCAL_RANK = 0
210 |
211 | # path to pre-trained model
212 | _C.PRETRAINED = ''
213 |
214 |
215 | def _update_config_from_file(config, cfg_file):
216 | config.defrost()
217 | with open(cfg_file, 'r') as f:
218 | yaml_cfg = yaml.load(f, Loader=yaml.FullLoader)
219 |
220 | for cfg in yaml_cfg.setdefault('BASE', ['']):
221 | if cfg:
222 | _update_config_from_file(
223 | config, os.path.join(os.path.dirname(cfg_file), cfg)
224 | )
225 | print('=> merge config from {}'.format(cfg_file))
226 | config.merge_from_file(cfg_file)
227 | config.freeze()
228 |
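# Example of the BASE mechanism (hypothetical files): if child.yaml contains
#   BASE: ['base.yaml']
#   TRAIN:
#     EPOCHS: 100
# then _update_config_from_file first merges base.yaml (recursively honoring
# its own BASE list), and only then merges child.yaml, so the child's
# TRAIN.EPOCHS overrides the base value.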
229 |
230 | def update_config(config, args):
231 | _update_config_from_file(config, args.cfg)
232 |
233 | config.defrost()
234 | if args.opts:
235 | config.merge_from_list(args.opts)
236 |
237 | def _check_args(name):
238 | if hasattr(args, name) and getattr(args, name):
239 | return True
240 | return False
241 |
242 | # merge from specific arguments
243 | if _check_args('batch_size'):
244 | config.DATA.BATCH_SIZE = args.batch_size
245 | if _check_args('data_path'):
246 | config.DATA.DATA_PATH = args.data_path
247 | if _check_args('resume'):
248 | config.MODEL.RESUME = args.resume
249 | if _check_args('pretrained'):
250 | config.PRETRAINED = args.pretrained
251 | if _check_args('accumulation_steps'):
252 | config.TRAIN.ACCUMULATION_STEPS = args.accumulation_steps
253 | if _check_args('use_checkpoint'):
254 | config.TRAIN.USE_CHECKPOINT = True
255 | if _check_args('enable_amp'):
256 | config.ENABLE_AMP = args.enable_amp
257 | if _check_args('output'):
258 | config.OUTPUT = args.output
259 | if _check_args('tag'):
260 | config.TAG = args.tag
261 | if _check_args('eval'):
262 | config.EVAL_MODE = True
263 | if _check_args('throughput'):
264 | config.THROUGHPUT_MODE = True
265 |
266 | # set local rank for distributed training
267 | config.LOCAL_RANK = args.local_rank
268 |
269 | # output folder
270 | config.OUTPUT = os.path.join(config.OUTPUT, config.TAG)
271 |
272 | config.freeze()
273 |
274 |
275 | def get_config(args):
276 | """Get a yacs CfgNode object with default values."""
277 | # Return a clone so that the defaults will not be altered
278 | # This is for the "local variable" use pattern
279 | config = _C.clone()
280 | update_config(config, args)
281 |
282 | return config
283 |
--------------------------------------------------------------------------------
/configs/finetune/fd_finetune__clip_vit_base__img224__100ep.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | TYPE: vit
3 | NAME: fd_finetune
4 | DROP_PATH_RATE: 0.1
5 | VIT:
6 | EMBED_DIM: 768
7 | DEPTH: 12
8 | NUM_HEADS: 12
9 | USE_APE: False
10 | USE_RPB: True
11 | USE_SHARED_RPB: False
12 | USE_MEAN_POOLING: True
13 | WITH_CLS_TOKEN: True
14 | DATA:
15 | IMG_SIZE: 224
16 | BATCH_SIZE: 128
17 | TRAIN:
18 | EPOCHS: 100
19 | WARMUP_EPOCHS: 20
20 | BASE_LR: 1.25e-3
21 | WARMUP_LR: 2.5e-7
22 | MIN_LR: 2.5e-7
23 | WEIGHT_DECAY: 0.05
24 | LAYER_DECAY: 0.65
25 | PRINT_FREQ: 100
26 | SAVE_FREQ: 5
27 | TAG: fd_finetune__clip_vit_base__img224__100ep
--------------------------------------------------------------------------------
/configs/finetune/fd_finetune__clip_vit_base__img224__300ep.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | TYPE: vit
3 | NAME: fd_finetune
4 | DROP_PATH_RATE: 0.3
5 | VIT:
6 | EMBED_DIM: 768
7 | DEPTH: 12
8 | NUM_HEADS: 12
9 | USE_APE: False
10 | USE_RPB: True
11 | USE_SHARED_RPB: False
12 | USE_MEAN_POOLING: True
13 | WITH_CLS_TOKEN: True
14 | DATA:
15 | IMG_SIZE: 224
16 | BATCH_SIZE: 128
17 | TRAIN:
18 | EPOCHS: 100
19 | WARMUP_EPOCHS: 20
20 | BASE_LR: 1.25e-3
21 | WARMUP_LR: 2.5e-7
22 | MIN_LR: 2.5e-7
23 | WEIGHT_DECAY: 0.05
24 | LAYER_DECAY: 0.6
25 | PRINT_FREQ: 100
26 | SAVE_FREQ: 5
27 | TAG: fd_finetune__clip_vit_base__img224__300ep
--------------------------------------------------------------------------------
/configs/finetune/fd_finetune__clip_vit_large__img224__300ep.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | TYPE: vit
3 | NAME: fd_finetune
4 | DROP_PATH_RATE: 0.4
5 | VIT:
6 | EMBED_DIM: 1024
7 | DEPTH: 24
8 | NUM_HEADS: 16
9 | USE_APE: False
10 | USE_RPB: True
11 | USE_SHARED_RPB: False
12 | USE_MEAN_POOLING: True
13 | WITH_CLS_TOKEN: True
14 | PATCH_SIZE: 14
15 | DATA:
16 | IMG_SIZE: 224
17 | BATCH_SIZE: 128
18 | TRAIN:
19 | EPOCHS: 50
20 | WARMUP_EPOCHS: 5
21 | BASE_LR: 2.5e-4
22 | WARMUP_LR: 5.0e-7
23 | MIN_LR: 5.0e-7
24 | WEIGHT_DECAY: 0.05
25 | LAYER_DECAY: 0.75
26 | PRINT_FREQ: 100
27 | SAVE_FREQ: 5
28 | TAG: fd_finetune__clip_vit_large__img224__300ep
--------------------------------------------------------------------------------
/configs/finetune/fd_finetune__deit_vit_base__img224__300ep.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | TYPE: vit
3 | NAME: fd_finetune
4 | DROP_PATH_RATE: 0.3
5 | VIT:
6 | EMBED_DIM: 768
7 | DEPTH: 12
8 | NUM_HEADS: 12
9 | USE_APE: False
10 | USE_RPB: True
11 | USE_SHARED_RPB: False
12 | USE_MEAN_POOLING: True
13 | WITH_CLS_TOKEN: True
14 | DATA:
15 | IMG_SIZE: 224
16 | BATCH_SIZE: 128
17 | TRAIN:
18 | EPOCHS: 100
19 | WARMUP_EPOCHS: 20
20 | BASE_LR: 1.25e-3
21 | WARMUP_LR: 2.5e-7
22 | MIN_LR: 2.5e-7
23 | WEIGHT_DECAY: 0.05
24 | LAYER_DECAY: 0.65
25 | PRINT_FREQ: 100
26 | SAVE_FREQ: 5
27 | TAG: fd_finetune__deit_vit_base__img224__300ep
--------------------------------------------------------------------------------
/configs/finetune/fd_finetune__dino_vit_base__img224__300ep.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | TYPE: vit
3 | NAME: fd_finetune
4 | DROP_PATH_RATE: 0.2
5 | VIT:
6 | EMBED_DIM: 768
7 | DEPTH: 12
8 | NUM_HEADS: 12
9 | USE_APE: False
10 | USE_RPB: True
11 | USE_SHARED_RPB: False
12 | USE_MEAN_POOLING: True
13 | WITH_CLS_TOKEN: True
14 | DATA:
15 | IMG_SIZE: 224
16 | BATCH_SIZE: 128
17 | TRAIN:
18 | EPOCHS: 100
19 | WARMUP_EPOCHS: 20
20 | BASE_LR: 1.5e-3
21 | WARMUP_LR: 2.5e-7
22 | MIN_LR: 2.5e-7
23 | WEIGHT_DECAY: 0.05
24 | LAYER_DECAY: 0.6
25 | PRINT_FREQ: 100
26 | SAVE_FREQ: 5
27 | TAG: fd_finetune__dino_vit_base__img224__300ep
--------------------------------------------------------------------------------
/configs/finetune/fd_finetune__esvit_swin_base__img224__300ep.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | TYPE: swin_v2
3 | NAME: fd_finetune
4 | DROP_PATH_RATE: 0.4
5 | SWIN:
6 | EMBED_DIM: 128
7 | DEPTHS: [ 2, 2, 18, 2 ]
8 | NUM_HEADS: [ 4, 8, 16, 32 ]
9 | WINDOW_SIZE: 14
10 | DATA:
11 | IMG_SIZE: 224
12 | BATCH_SIZE: 64 # this smaller batch size was an accidental setting in our experiments; we believe 128*16=2048 would lead to similar results
13 | TRAIN:
14 | EPOCHS: 100
15 | WARMUP_EPOCHS: 20
16 | BASE_LR: 1.25e-3
17 | WARMUP_LR: 2.5e-7
18 | MIN_LR: 2.5e-7
19 | WEIGHT_DECAY: 0.05
20 | LAYER_DECAY: 0.8
21 | PRINT_FREQ: 100
22 | SAVE_FREQ: 5
23 | TAG: fd_finetune__esvit_swin_base__img224__300ep
--------------------------------------------------------------------------------
/configs/pretrain/fd_pretrain__clip_vit_base__img224__100ep.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | TYPE: vit
3 | NAME: fd_pretrain
4 | DROP_PATH_RATE: 0.1
5 | VIT:
6 | EMBED_DIM: 768
7 | DEPTH: 12
8 | NUM_HEADS: 12
9 | USE_APE: False
10 | USE_RPB: False
11 | USE_SHARED_RPB: True
12 | USE_MEAN_POOLING: False
13 | WITH_CLS_TOKEN: True
14 | DATA:
15 | IMG_SIZE: 224
16 | BATCH_SIZE: 128
17 | TRAIN:
18 | EPOCHS: 100
19 | WARMUP_EPOCHS: 10
20 | BASE_LR: 3e-4
21 | WARMUP_LR: 5e-7
22 | MIN_LR: 5e-6
23 | WEIGHT_DECAY: 0.05
24 | CLIP_GRAD: 3.0
25 | DEV:
26 | PRED_FEAT: CLIP_400M
27 | PRED_FEAT_AFTERNORM: True
28 | PRINT_FREQ: 100
29 | SAVE_FREQ: 5
30 | TAG: fd_pretrain__clip_vit_base__img224__100ep
31 |
--------------------------------------------------------------------------------
/configs/pretrain/fd_pretrain__clip_vit_base__img224__300ep.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | TYPE: vit
3 | NAME: fd_pretrain
4 | DROP_PATH_RATE: 0.2
5 | VIT:
6 | EMBED_DIM: 768
7 | DEPTH: 12
8 | NUM_HEADS: 12
9 | USE_APE: False
10 | USE_RPB: False
11 | USE_SHARED_RPB: True
12 | USE_MEAN_POOLING: False
13 | WITH_CLS_TOKEN: True
14 | DATA:
15 | IMG_SIZE: 224
16 | BATCH_SIZE: 128
17 | AUG:
18 | MIN_SCALE: 0.08
19 | TRAIN:
20 | EPOCHS: 300
21 | WARMUP_EPOCHS: 10
22 | BASE_LR: 3e-4
23 | WARMUP_LR: 5e-7
24 | MIN_LR: 5e-6
25 | WEIGHT_DECAY: 0.05
26 | CLIP_GRAD: 3.0
27 | DEV:
28 | PRED_FEAT: CLIP_400M
29 | PRED_FEAT_AFTERNORM: True
30 | PRINT_FREQ: 100
31 | SAVE_FREQ: 5
32 | TAG: fd_pretrain__clip_vit_base__img224__300ep
33 |
--------------------------------------------------------------------------------
/configs/pretrain/fd_pretrain__clip_vit_large__img224__300ep.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | TYPE: vit
3 | NAME: fd_pretrain
4 | DROP_PATH_RATE: 0.3
5 | VIT:
6 | EMBED_DIM: 1024
7 | DEPTH: 24
8 | NUM_HEADS: 16
9 | USE_APE: False
10 | USE_RPB: False
11 | USE_SHARED_RPB: True
12 | USE_MEAN_POOLING: False
13 | WITH_CLS_TOKEN: True
14 | PATCH_SIZE: 14
15 | DATA:
16 | IMG_SIZE: 224
17 | BATCH_SIZE: 128
18 | AUG:
19 | MIN_SCALE: 0.08
20 | TRAIN:
21 | EPOCHS: 300
22 | WARMUP_EPOCHS: 10
23 | BASE_LR: 3e-4
24 | WARMUP_LR: 5e-7
25 | MIN_LR: 5e-6
26 | WEIGHT_DECAY: 0.05
27 | CLIP_GRAD: 3.0
28 | DEV:
29 | PRED_FEAT: CLIP_400M_Large
30 | PRED_FEAT_AFTERNORM: True
31 | PRINT_FREQ: 100
32 | SAVE_FREQ: 5
33 | TAG: fd_pretrain__clip_vit_large__img224__300ep
34 |
--------------------------------------------------------------------------------
/configs/pretrain/fd_pretrain__deit_vit_base__img224__300ep.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | TYPE: vit
3 | NAME: fd_pretrain
4 | DROP_PATH_RATE: 0.3
5 | VIT:
6 | EMBED_DIM: 768
7 | DEPTH: 12
8 | NUM_HEADS: 12
9 | USE_APE: False
10 | USE_RPB: False
11 | USE_SHARED_RPB: True
12 | USE_MEAN_POOLING: False
13 | WITH_CLS_TOKEN: True
14 | DATA:
15 | IMG_SIZE: 224
16 | BATCH_SIZE: 128
17 | AUG:
18 | MIN_SCALE: 0.08
19 | TRAIN:
20 | EPOCHS: 300
21 | WARMUP_EPOCHS: 10
22 | BASE_LR: 3e-4
23 | WARMUP_LR: 5e-7
24 | MIN_LR: 5e-6
25 | WEIGHT_DECAY: 0.05
26 | CLIP_GRAD: 3.0
27 | DEV:
28 | PRED_FEAT: DEIT
29 | PRED_FEAT_AFTERNORM: True
30 | PRINT_FREQ: 100
31 | SAVE_FREQ: 5
32 | TAG: fd_pretrain__deit_vit_base__img224__300ep
33 |
--------------------------------------------------------------------------------
/configs/pretrain/fd_pretrain__dino_vit_base__img224__300ep.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | TYPE: vit
3 | NAME: fd_pretrain
4 | DROP_PATH_RATE: 0.3
5 | VIT:
6 | EMBED_DIM: 768
7 | DEPTH: 12
8 | NUM_HEADS: 12
9 | USE_APE: False
10 | USE_RPB: False
11 | USE_SHARED_RPB: True
12 | USE_MEAN_POOLING: False
13 | WITH_CLS_TOKEN: True
14 | DATA:
15 | IMG_SIZE: 224
16 | BATCH_SIZE: 128
17 | AUG:
18 | MIN_SCALE: 0.08
19 | TRAIN:
20 | EPOCHS: 300
21 | WARMUP_EPOCHS: 10
22 | BASE_LR: 3e-4
23 | WARMUP_LR: 5e-7
24 | MIN_LR: 5e-6
25 | WEIGHT_DECAY: 0.05
26 | CLIP_GRAD: 3.0
27 | DEV:
28 | PRED_FEAT: DINO
29 | PRED_FEAT_AFTERNORM: True
30 | PRINT_FREQ: 100
31 | SAVE_FREQ: 5
32 | TAG: fd_pretrain__dino_vit_base__img224__300ep
33 |
--------------------------------------------------------------------------------
/configs/pretrain/fd_pretrain__esvit_swin_base__img224__300ep.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | TYPE: swin_v2
3 | NAME: fd_pretrain
4 | DROP_PATH_RATE: 0.3
5 | SWIN:
6 | EMBED_DIM: 128
7 | DEPTHS: [ 2, 2, 18, 2 ]
8 | NUM_HEADS: [ 4, 8, 16, 32 ]
9 | WINDOW_SIZE: 14
10 | DATA:
11 | IMG_SIZE: 224
12 | BATCH_SIZE: 128
13 | AUG:
14 | MIN_SCALE: 0.08
15 | TRAIN:
16 | EPOCHS: 300
17 | WARMUP_EPOCHS: 10
18 | BASE_LR: 2e-4
19 | WARMUP_LR: 1e-6
20 | MIN_LR: 1e-5
21 | WEIGHT_DECAY: 0.05
22 | CLIP_GRAD: 3.0
23 | DEV:
24 | PRED_FEAT: ESVIT
25 | PRED_FEAT_AFTERNORM: True
26 | PRINT_FREQ: 100
27 | SAVE_FREQ: 5
28 | TAG: fd_pretrain__esvit_swin_base__img224__300ep
29 |
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .data_fd import build_loader_fd
2 | from .data_finetune import build_loader_finetune
3 |
4 | def build_loader(config, logger, is_pretrain):
5 | if is_pretrain:
6 | return build_loader_fd(config, logger)
7 | else:
8 | return build_loader_finetune(config, logger)
--------------------------------------------------------------------------------
/data/cached_image_folder.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Swin Transformer
3 | # Copyright (c) 2021 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Ze Liu
6 | # --------------------------------------------------------
7 |
8 | import io
9 | import os
10 | import time
11 | import torch.distributed as dist
12 | import torch.utils.data as data
13 | from PIL import Image
14 |
15 | from .utils import is_zip_path, ZipReader
16 |
17 |
18 | def has_file_allowed_extension(filename, extensions):
19 | """Checks if a file is an allowed extension.
20 | Args:
21 | filename (string): path to a file
22 | Returns:
23 | bool: True if the filename ends with a known image extension
24 | """
25 | filename_lower = filename.lower()
26 | return any(filename_lower.endswith(ext) for ext in extensions)
27 |
28 |
29 | def find_classes(dir):
30 | classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
31 | classes.sort()
32 | class_to_idx = {classes[i]: i for i in range(len(classes))}
33 | return classes, class_to_idx
34 |
35 |
36 | def make_dataset(dir, class_to_idx, extensions):
37 | images = []
38 | dir = os.path.expanduser(dir)
39 | for target in sorted(os.listdir(dir)):
40 | d = os.path.join(dir, target)
41 | if not os.path.isdir(d):
42 | continue
43 |
44 | for root, _, fnames in sorted(os.walk(d)):
45 | for fname in sorted(fnames):
46 | if has_file_allowed_extension(fname, extensions):
47 | path = os.path.join(root, fname)
48 | item = (path, class_to_idx[target])
49 | images.append(item)
50 |
51 | return images
52 |
53 |
54 | def make_dataset_with_ann(ann_file, img_prefix, extensions):
55 | images = []
56 | with open(ann_file, "r") as f:
57 | contents = f.readlines()
58 | for line_str in contents:
59 | path_contents = [c for c in line_str.split('\t')]
60 | im_file_name = path_contents[0]
61 | class_index = int(path_contents[1])
62 |
63 | assert str.lower(os.path.splitext(im_file_name)[-1]) in extensions
64 | item = (os.path.join(img_prefix, im_file_name), class_index)
65 |
66 | images.append(item)
67 |
68 | return images
69 |
70 |
71 | class DatasetFolder(data.Dataset):
72 | """A generic data loader where the samples are arranged in this way: ::
73 | root/class_x/xxx.ext
74 | root/class_x/xxy.ext
75 | root/class_x/xxz.ext
76 | root/class_y/123.ext
77 | root/class_y/nsdf3.ext
78 | root/class_y/asd932_.ext
79 | Args:
80 | root (string): Root directory path.
81 | loader (callable): A function to load a sample given its path.
82 | extensions (list[string]): A list of allowed extensions.
83 | transform (callable, optional): A function/transform that takes in
84 | a sample and returns a transformed version.
85 | E.g, ``transforms.RandomCrop`` for images.
86 | target_transform (callable, optional): A function/transform that takes
87 | in the target and transforms it.
88 | Attributes:
89 | samples (list): List of (sample path, class_index) tuples
90 | """
91 |
92 | def __init__(self, root, loader, extensions, ann_file='', img_prefix='', transform=None, target_transform=None,
93 | cache_mode="no"):
94 | # image folder mode
95 | if ann_file == '':
96 | _, class_to_idx = find_classes(root)
97 | samples = make_dataset(root, class_to_idx, extensions)
98 | # zip mode
99 | else:
100 | samples = make_dataset_with_ann(os.path.join(root, ann_file),
101 | os.path.join(root, img_prefix),
102 | extensions)
103 |
104 | if len(samples) == 0:
105 | raise (RuntimeError("Found 0 files in subfolders of: " + root + "\n" +
106 | "Supported extensions are: " + ",".join(extensions)))
107 |
108 | self.root = root
109 | self.loader = loader
110 | self.extensions = extensions
111 |
112 | self.samples = samples
113 | self.labels = [y_1k for _, y_1k in samples]
114 | self.classes = list(set(self.labels))
115 |
116 | self.transform = transform
117 | self.target_transform = target_transform
118 |
119 | self.cache_mode = cache_mode
120 | if self.cache_mode != "no":
121 | self.init_cache()
122 |
123 | def init_cache(self):
124 | assert self.cache_mode in ["part", "full"]
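# "full": every rank caches all sample bytes in memory;
# "part": each rank caches only indices with index % world_size == rank,
# keeping per-rank memory roughly 1/world_size of the dataset.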
125 | n_sample = len(self.samples)
126 | global_rank = dist.get_rank()
127 | world_size = dist.get_world_size()
128 |
129 | samples_bytes = [None for _ in range(n_sample)]
130 | start_time = time.time()
131 | for index in range(n_sample):
132 | if index % (n_sample // 10) == 0:
133 | t = time.time() - start_time
134 | print(f'global_rank {dist.get_rank()} cached {index}/{n_sample} takes {t:.2f}s per block')
135 | start_time = time.time()
136 | path, target = self.samples[index]
137 | if self.cache_mode == "full":
138 | samples_bytes[index] = (ZipReader.read(path), target)
139 | elif self.cache_mode == "part" and index % world_size == global_rank:
140 | samples_bytes[index] = (ZipReader.read(path), target)
141 | else:
142 | samples_bytes[index] = (path, target)
143 | self.samples = samples_bytes
144 |
145 | def __getitem__(self, index):
146 | """
147 | Args:
148 | index (int): Index
149 | Returns:
150 | tuple: (sample, target) where target is class_index of the target class.
151 | """
152 | path, target = self.samples[index]
153 | sample = self.loader(path)
154 | if self.transform is not None:
155 | sample = self.transform(sample)
156 | if self.target_transform is not None:
157 | target = self.target_transform(target)
158 |
159 | return sample, target
160 |
161 | def __len__(self):
162 | return len(self.samples)
163 |
164 | def __repr__(self):
165 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
166 | fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
167 | fmt_str += ' Root Location: {}\n'.format(self.root)
168 | tmp = ' Transforms (if any): '
169 | fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
170 | tmp = ' Target Transforms (if any): '
171 | fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
172 | return fmt_str
173 |
174 |
175 | IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif']
176 |
177 |
178 | def pil_loader(path):
179 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
180 | if isinstance(path, bytes):
181 | img = Image.open(io.BytesIO(path))
182 | elif is_zip_path(path):
183 | data = ZipReader.read(path)
184 | img = Image.open(io.BytesIO(data))
185 | else:
186 | with open(path, 'rb') as f:
187 | img = Image.open(f)
188 | return img.convert('RGB')
189 |
190 |
191 | def accimage_loader(path):
192 | import accimage
193 | try:
194 | return accimage.Image(path)
195 | except IOError:
196 | # Potentially a decoding problem, fall back to PIL.Image
197 | return pil_loader(path)
198 |
199 |
200 | def default_img_loader(path):
201 | from torchvision import get_image_backend
202 | if get_image_backend() == 'accimage':
203 | return accimage_loader(path)
204 | else:
205 | return pil_loader(path)
206 |
207 |
208 | class CachedImageFolder(DatasetFolder):
209 | """A generic data loader where the images are arranged in this way: ::
210 | root/dog/xxx.png
211 | root/dog/xxy.png
212 | root/dog/xxz.png
213 | root/cat/123.png
214 | root/cat/nsdf3.png
215 | root/cat/asd932_.png
216 | Args:
217 | root (string): Root directory path.
218 | transform (callable, optional): A function/transform that takes in an PIL image
219 | and returns a transformed version. E.g, ``transforms.RandomCrop``
220 | target_transform (callable, optional): A function/transform that takes in the
221 | target and transforms it.
222 | loader (callable, optional): A function to load an image given its path.
223 | Attributes:
224 | imgs (list): List of (image path, class_index) tuples
225 | """
226 |
227 | def __init__(self, root, ann_file='', img_prefix='', transform=None, target_transform=None,
228 | loader=default_img_loader, cache_mode="no"):
229 | super(CachedImageFolder, self).__init__(root, loader, IMG_EXTENSIONS,
230 | ann_file=ann_file, img_prefix=img_prefix,
231 | transform=transform, target_transform=target_transform,
232 | cache_mode=cache_mode)
233 | self.imgs = self.samples
234 |
235 | def __getitem__(self, index):
236 | """
237 | Args:
238 | index (int): Index
239 | Returns:
240 | tuple: (image, target) where target is class_index of the target class.
241 | """
242 | path, target = self.samples[index]
243 | image = self.loader(path)
244 | if self.transform is not None and not isinstance(self.transform, list):
245 | img = self.transform(image)
246 | elif self.transform is not None and isinstance(self.transform, list):
247 | img = []
248 | for i in range(len(self.transform)):
249 | _img = self.transform[i](image)
250 | if isinstance(_img, list) or isinstance(_img, tuple):
251 | img.extend(_img)
252 | else:
253 | img.append(_img)
254 | else:
255 | img = image
256 | if self.target_transform is not None:
257 | target = self.target_transform(target)
258 |
259 | return img, target
260 |
--------------------------------------------------------------------------------
/data/data_fd.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Feature Distillation
3 | # Copyright (c) 2021 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Zhenda Xie
6 | # Modified by Yixuan Wei
7 | # --------------------------------------------------------
8 |
9 | import numpy as np
10 |
11 | import torch
12 | import torch.distributed as dist
13 | import torchvision.transforms as T
14 | from torch.utils.data import DataLoader, DistributedSampler
15 | from torchvision.datasets import ImageFolder
16 | from torchvision.datasets.folder import IMG_EXTENSIONS, has_file_allowed_extension
17 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
18 | from .utils import SubsetRandomSampler
19 | from .cached_image_folder import CachedImageFolder
20 |
21 | from PIL import ImageFile
22 | ImageFile.LOAD_TRUNCATED_IMAGES = True
23 |
24 | class FDTransform:
25 | def __init__(self, config):
26 | self.config = config
27 |
28 | crop_size = config.DATA.IMG_SIZE
29 | self.transform_img = T.Compose([
30 | T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
31 | T.RandomResizedCrop(crop_size, scale=(config.AUG.MIN_SCALE, config.AUG.MAX_SCALE), ratio=(3. / 4., 4. / 3.)),
32 | T.RandomHorizontalFlip(),
33 | T.ToTensor(),
34 | T.Normalize(mean=torch.tensor(IMAGENET_DEFAULT_MEAN),std=torch.tensor(IMAGENET_DEFAULT_STD)),
35 | ])
36 |
37 | def __call__(self, img):
38 | img = self.transform_img(img)
39 | return img
40 |
41 |
42 | def is_valid_file(x: str) -> bool:
43 | unvalid_file_list = """
44 | n01678043_6448.JPEG
45 | n01896844_997.JPEG
46 | n02368116_318.JPEG
47 | n02428089_710.JPEG
48 | n02487347_1956.JPEG
49 | n02597972_5463.JPEG
50 | n03957420_30695.JPEG
51 | n03957420_33553.JPEG
52 | n03957420_8296.JPEG
53 | n04135315_8814.JPEG
54 | n04135315_9318.JPEG
55 | n04257684_9033.JPEG
56 | n04427559_2974.JPEG
57 | n06470073_47249.JPEG
58 | n07930062_4147.JPEG
59 | n09224725_3995.JPEG
60 | n09359803_8155.JPEG
61 | n09620794_5529.JPEG
62 | n09789566_3522.JPEG
63 | n09894445_7463.JPEG
64 | n10175248_583.JPEG
65 | n10316360_4246.JPEG
66 | n10368624_12550.JPEG
67 | n10585217_8484.JPEG
68 | n10721819_1131.JPEG
69 | n12353203_3849.JPEG
70 | n12630763_8018.JPEG
71 | """
72 | unvalid_file_list = tuple([i.strip() for i in unvalid_file_list.split('\n') if len(i.strip()) > 0])
73 | assert len(unvalid_file_list) == 27
74 |
75 | return has_file_allowed_extension(x, IMG_EXTENSIONS) and not x.endswith(unvalid_file_list)
76 |
77 |
78 | def build_loader_fd(config, logger):
79 | transform = FDTransform(config)
80 | logger.info(f'Pre-train data transform:\n{transform}')
81 |
82 | if config.DATA.DATASET == 'imagenet':
83 | prefix = 'train'
84 | if config.DATA.ZIP_MODE:
85 | ann_file = prefix + "_map.txt"
86 | prefix = prefix + ".zip@/"
87 | dataset = CachedImageFolder(config.DATA.DATA_PATH, ann_file, prefix, transform,
88 | cache_mode='part')
89 | else:
90 | dataset = ImageFolder(config.DATA.DATA_PATH, transform, is_valid_file=is_valid_file)
91 | elif config.DATA.DATASET == 'imagenet22k':
92 | dataset = ImageFolder(config.DATA.DATA_PATH, transform, is_valid_file=is_valid_file)
93 |
94 | if config.DATA.DATASET == 'imagenet' and config.DATA.ZIP_MODE:
95 | indices = np.arange(dist.get_rank(), len(dataset), dist.get_world_size())
96 | sampler = SubsetRandomSampler(indices)
97 | else:
98 | sampler = DistributedSampler(dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=True)
99 | dataloader = DataLoader(dataset, config.DATA.BATCH_SIZE, sampler=sampler, num_workers=config.DATA.NUM_WORKERS, pin_memory=True, drop_last=True)
100 |
101 | return dataloader
--------------------------------------------------------------------------------
/data/data_finetune.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Feature Distillation
3 | # Copyright (c) 2021 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Zhenda Xie
6 | # Modified by Yixuan Wei
7 | # --------------------------------------------------------
8 |
9 | import os
10 | import torch.distributed as dist
11 | from torch.utils.data import DataLoader, DistributedSampler
12 | from torchvision import datasets, transforms
13 | from torchvision.datasets import ImageFolder
14 | from torchvision.datasets.folder import IMG_EXTENSIONS, has_file_allowed_extension
15 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
16 | from timm.data import Mixup
17 | from timm.data import create_transform
18 | from timm.data.transforms import _pil_interp
19 |
20 | import numpy as np
21 | from .cached_image_folder import CachedImageFolder
22 | from .utils import SubsetRandomSampler
23 | IMAGENET_CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
24 | IMAGENET_CLIP_STD = (0.26862954, 0.26130258, 0.27577711)
25 |
26 | import random
27 | from PIL import ImageFilter, ImageOps
28 |
29 |
30 | class GaussianBlur(object):
31 | """
32 | Apply Gaussian Blur to the PIL image.
33 | """
34 | def __init__(self, p=0.1, radius_min=0.1, radius_max=2.):
35 | self.prob = p
36 | self.radius_min = radius_min
37 | self.radius_max = radius_max
38 |
39 | def __call__(self, img):
40 | do_it = random.random() <= self.prob
41 | if not do_it:
42 | return img
43 |
44 | img = img.filter(
45 | ImageFilter.GaussianBlur(
46 | radius=random.uniform(self.radius_min, self.radius_max)
47 | )
48 | )
49 | return img
50 |
51 | class Solarization(object):
52 | """
53 | Apply Solarization to the PIL image.
54 | """
55 | def __init__(self, p=0.2):
56 | self.p = p
57 |
58 | def __call__(self, img):
59 | if random.random() < self.p:
60 | return ImageOps.solarize(img)
61 | else:
62 | return img
63 |
64 | class gray_scale(object):
65 | """
66 | Apply Grayscale to the PIL image.
67 | """
68 | def __init__(self, p=0.2):
69 | self.p = p
70 | self.transf = transforms.Grayscale(3)
71 |
72 | def __call__(self, img):
73 | if random.random() < self.p:
74 | return self.transf(img)
75 | else:
76 | return img
77 |
78 |
79 | def build_loader_finetune(config, logger):
80 | config.defrost()
81 | dataset_train, config.MODEL.NUM_CLASSES = build_dataset(is_train=True, config=config, logger=logger)
82 | config.freeze()
83 | dataset_val, _ = build_dataset(is_train=False, config=config, logger=logger)
84 | logger.info(f"Build dataset: train images = {len(dataset_train)}, val images = {len(dataset_val)}")
85 |
86 | num_tasks = dist.get_world_size()
87 | global_rank = dist.get_rank()
88 | if config.DATA.ZIP_MODE:
89 | indices = np.arange(dist.get_rank(), len(dataset_train), dist.get_world_size())
90 | sampler_train = SubsetRandomSampler(indices)
91 | else:
92 | sampler_train = DistributedSampler(
93 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
94 | )
95 |
96 | if config.DATA.ZIP_MODE:
97 | indices = np.arange(dist.get_rank(), len(dataset_val), dist.get_world_size())
98 | sampler_val = SubsetRandomSampler(indices)
99 | else:
100 | sampler_val = DistributedSampler(
101 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=True
102 | )
103 |
104 | data_loader_train = DataLoader(
105 | dataset_train, sampler=sampler_train,
106 | batch_size=config.DATA.BATCH_SIZE,
107 | num_workers=config.DATA.NUM_WORKERS,
108 | pin_memory=config.DATA.PIN_MEMORY,
109 | drop_last=True,
110 | )
111 |
112 | data_loader_val = DataLoader(
113 | dataset_val, sampler=sampler_val,
114 | batch_size=config.DATA.BATCH_SIZE,
115 | num_workers=config.DATA.NUM_WORKERS,
116 | pin_memory=config.DATA.PIN_MEMORY,
117 | drop_last=False,
118 | )
119 |
120 | # setup mixup / cutmix
121 | mixup_fn = None
122 | mixup_active = config.AUG.MIXUP > 0 or config.AUG.CUTMIX > 0. or config.AUG.CUTMIX_MINMAX is not None
123 | if mixup_active:
124 | mixup_fn = Mixup(
125 | mixup_alpha=config.AUG.MIXUP, cutmix_alpha=config.AUG.CUTMIX, cutmix_minmax=config.AUG.CUTMIX_MINMAX,
126 | prob=config.AUG.MIXUP_PROB, switch_prob=config.AUG.MIXUP_SWITCH_PROB, mode=config.AUG.MIXUP_MODE,
127 | label_smoothing=config.MODEL.LABEL_SMOOTHING, num_classes=config.MODEL.NUM_CLASSES)
128 |
129 | return dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn
130 |
131 |
132 | def is_valid_file(x: str) -> bool:
133 | unvalid_file_list = """
134 | n01678043_6448.JPEG
135 | n01896844_997.JPEG
136 | n02368116_318.JPEG
137 | n02428089_710.JPEG
138 | n02487347_1956.JPEG
139 | n02597972_5463.JPEG
140 | n03957420_30695.JPEG
141 | n03957420_33553.JPEG
142 | n03957420_8296.JPEG
143 | n04135315_8814.JPEG
144 | n04135315_9318.JPEG
145 | n04257684_9033.JPEG
146 | n04427559_2974.JPEG
147 | n06470073_47249.JPEG
148 | n07930062_4147.JPEG
149 | n09224725_3995.JPEG
150 | n09359803_8155.JPEG
151 | n09620794_5529.JPEG
152 | n09789566_3522.JPEG
153 | n09894445_7463.JPEG
154 | n10175248_583.JPEG
155 | n10316360_4246.JPEG
156 | n10368624_12550.JPEG
157 | n10585217_8484.JPEG
158 | n10721819_1131.JPEG
159 | n12353203_3849.JPEG
160 | n12630763_8018.JPEG
161 | """
162 | unvalid_file_list = tuple([i.strip() for i in unvalid_file_list.split('\n') if len(i.strip()) > 0])
163 | assert len(unvalid_file_list) == 27
164 |
165 | return has_file_allowed_extension(x, IMG_EXTENSIONS) and not x.endswith(unvalid_file_list)
166 |
167 |
168 | def build_dataset(is_train, config, logger):
169 | transform = build_transform(is_train, config)
170 | logger.info(f'Fine-tune data transform, is_train={is_train}:\n{transform}')
171 |
172 | if config.DATA.DATASET == 'imagenet':
173 | prefix = 'train' if is_train else 'val'
174 | if config.DATA.ZIP_MODE:
175 | ann_file = prefix + "_map.txt"
176 | prefix = prefix + ".zip@/"
177 | dataset = CachedImageFolder(config.DATA.DATA_PATH, ann_file, prefix, transform,
178 | cache_mode='part')
179 | else:
180 | root = os.path.join(config.DATA.DATA_PATH, prefix)
181 | dataset = datasets.ImageFolder(root, transform=transform)
182 | nb_classes = 1000
183 | elif config.DATA.DATASET == 'imagenet22k':
184 | if is_train:
185 | dataset = ImageFolder(config.DATA.DATA_PATH, transform, is_valid_file=is_valid_file)
186 | nb_classes = 21841
187 | else:
188 | nb_classes = 1000
189 | prefix = 'val'
190 | if config.DATA.ZIP_MODE:
191 | ann_file = prefix + "_map.txt"
192 | prefix = prefix + ".zip@/"
193 | dataset = CachedImageFolder(config.DATA.DATA_PATH, ann_file, prefix, transform,
194 | cache_mode='part')
195 | else:
196 | root = os.path.join(config.DATA.DATA_PATH, prefix)
197 | dataset = datasets.ImageFolder(root, transform=transform)
198 | else:
199 | raise NotImplementedError("We only support ImageNet Now.")
200 |
201 | return dataset, nb_classes
202 |
203 |
204 | def build_transform(is_train, config):
205 | resize_im = config.DATA.IMG_SIZE > 32
206 | if is_train:
207 | # this should always dispatch to transforms_imagenet_train
208 | transform = create_transform(
209 | input_size=config.DATA.IMG_SIZE,
210 | is_training=True,
211 | color_jitter=config.AUG.COLOR_JITTER if config.AUG.COLOR_JITTER > 0 else None,
212 | auto_augment=config.AUG.AUTO_AUGMENT if config.AUG.AUTO_AUGMENT != 'none' else None,
213 | re_prob=config.AUG.REPROB,
214 | re_mode=config.AUG.REMODE,
215 | re_count=config.AUG.RECOUNT,
216 | interpolation=config.DATA.INTERPOLATION,
217 | mean=IMAGENET_DEFAULT_MEAN,
218 | std=IMAGENET_DEFAULT_STD
219 | )
220 | if not resize_im:
221 | # replace RandomResizedCropAndInterpolation with
222 | # RandomCrop
223 | transform.transforms[0] = transforms.RandomCrop(config.DATA.IMG_SIZE, padding=4)
224 | return transform
225 |
226 | t = []
227 | if resize_im:
228 | if config.TEST.CROP:
229 | size = int((256 / 224) * config.DATA.IMG_SIZE)
230 | t.append(
231 | transforms.Resize(size, interpolation=_pil_interp(config.DATA.INTERPOLATION)),
232 | # to maintain same ratio w.r.t. 224 images
233 | )
234 | t.append(transforms.CenterCrop(config.DATA.IMG_SIZE))
235 | else:
236 | t.append(
237 | transforms.Resize((config.DATA.IMG_SIZE, config.DATA.IMG_SIZE),
238 | interpolation=_pil_interp(config.DATA.INTERPOLATION))
239 | )
240 |
241 | t.append(transforms.ToTensor())
242 | t.append(transforms.Normalize(IMAGENET_CLIP_MEAN, IMAGENET_CLIP_STD))
243 | return transforms.Compose(t)
--------------------------------------------------------------------------------
/data/data_linear.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Feature Distillation
3 | # Copyright (c) 2021 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Zhenda Xie
6 | # Modified by Yixuan Wei
7 | # --------------------------------------------------------
8 |
9 | import os
10 | import math
11 | import torch
12 | import torch.distributed as dist
13 | from torchvision.transforms import functional as F
14 | from torch.utils.data import DataLoader, DistributedSampler
15 | from torchvision import datasets, transforms
16 | from torchvision.datasets.folder import IMG_EXTENSIONS, has_file_allowed_extension
17 | from timm.data.transforms import _pil_interp
18 |
19 | import numpy as np
20 | from .cached_image_folder import CachedImageFolder
21 | from .utils import SubsetRandomSampler
22 | IMAGENET_CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
23 | IMAGENET_CLIP_STD = (0.26862954, 0.26130258, 0.27577711)
24 |
25 |
26 | def build_loader(config, logger):
27 | config.defrost()
28 | dataset_train, config.MODEL.NUM_CLASSES = build_dataset(is_train=True, config=config, logger=logger)
29 | config.freeze()
30 | dataset_val, _ = build_dataset(is_train=False, config=config, logger=logger)
31 | logger.info(f"Build dataset: train images = {len(dataset_train)}, val images = {len(dataset_val)}")
32 |
33 | num_tasks = dist.get_world_size()
34 | global_rank = dist.get_rank()
35 | if config.DATA.ZIP_MODE:
36 | indices = np.arange(dist.get_rank(), len(dataset_train), dist.get_world_size())
37 | sampler_train = SubsetRandomSampler(indices)
38 | else:
39 | sampler_train = DistributedSampler(
40 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
41 | )
42 |
43 | if config.DATA.ZIP_MODE:
44 | indices = np.arange(dist.get_rank(), len(dataset_val), dist.get_world_size())
45 | sampler_val = SubsetRandomSampler(indices)
46 | else:
47 | sampler_val = DistributedSampler(
48 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=True
49 | )
50 |
51 | data_loader_train = DataLoader(
52 | dataset_train, sampler=sampler_train,
53 | batch_size=config.DATA.BATCH_SIZE,
54 | num_workers=config.DATA.NUM_WORKERS,
55 | pin_memory=config.DATA.PIN_MEMORY,
56 | drop_last=True,
57 | )
58 |
59 | data_loader_val = DataLoader(
60 | dataset_val, sampler=sampler_val,
61 | batch_size=config.DATA.BATCH_SIZE,
62 | num_workers=config.DATA.NUM_WORKERS,
63 | pin_memory=config.DATA.PIN_MEMORY,
64 | drop_last=False,
65 | )
66 |
67 | mixup_fn = None
68 |
69 | return dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn
70 |
71 |
72 | def is_valid_file(x: str) -> bool:
 73 |     invalid_file_list = """
74 | n01678043_6448.JPEG
75 | n01896844_997.JPEG
76 | n02368116_318.JPEG
77 | n02428089_710.JPEG
78 | n02487347_1956.JPEG
79 | n02597972_5463.JPEG
80 | n03957420_30695.JPEG
81 | n03957420_33553.JPEG
82 | n03957420_8296.JPEG
83 | n04135315_8814.JPEG
84 | n04135315_9318.JPEG
85 | n04257684_9033.JPEG
86 | n04427559_2974.JPEG
87 | n06470073_47249.JPEG
88 | n07930062_4147.JPEG
89 | n09224725_3995.JPEG
90 | n09359803_8155.JPEG
91 | n09620794_5529.JPEG
92 | n09789566_3522.JPEG
93 | n09894445_7463.JPEG
94 | n10175248_583.JPEG
95 | n10316360_4246.JPEG
96 | n10368624_12550.JPEG
97 | n10585217_8484.JPEG
98 | n10721819_1131.JPEG
99 | n12353203_3849.JPEG
100 | n12630763_8018.JPEG
101 | """
102 |     invalid_file_list = tuple([i.strip() for i in invalid_file_list.split('\n') if len(i.strip()) > 0])
103 |     assert len(invalid_file_list) == 27  # the known-corrupted ImageNet files listed above
104 | 
105 |     return has_file_allowed_extension(x, IMG_EXTENSIONS) and not x.endswith(invalid_file_list)
106 |
107 |
108 | def build_dataset(is_train, config, logger):
109 | transform = build_transform(is_train, config)
110 |     logger.info(f'Linear-probe data transform, is_train={is_train}:\n{transform}')
111 |
112 | if config.DATA.DATASET == 'imagenet':
113 | prefix = 'train' if is_train else 'val'
114 | if config.DATA.ZIP_MODE:
115 | ann_file = prefix + "_map.txt"
116 | prefix = prefix + ".zip@/"
117 | dataset = CachedImageFolder(config.DATA.DATA_PATH, ann_file, prefix, transform,
118 | cache_mode='part')
119 | else:
120 | root = os.path.join(config.DATA.DATA_PATH, prefix)
121 | dataset = datasets.ImageFolder(root, transform=transform)
122 | nb_classes = 1000
123 | else:
124 |         raise NotImplementedError("We only support ImageNet for now.")
125 |
126 | return dataset, nb_classes
127 |
128 |
129 | class RandomResizedCrop(transforms.RandomResizedCrop):
130 | """
131 |     RandomResizedCrop matching the TF/TPU implementation: no for-loop is used.
132 |     This may lead to results that differ from torchvision's version.
133 | Following BYOL's TF code:
134 | https://github.com/deepmind/deepmind-research/blob/master/byol/utils/dataset.py#L206
135 | """
136 | @staticmethod
137 | def get_params(img, scale, ratio):
138 | width, height = F._get_image_size(img)
139 | area = height * width
140 |
141 | target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
142 | log_ratio = torch.log(torch.tensor(ratio))
143 | aspect_ratio = torch.exp(
144 | torch.empty(1).uniform_(log_ratio[0], log_ratio[1])
145 | ).item()
146 |
147 | w = int(round(math.sqrt(target_area * aspect_ratio)))
148 | h = int(round(math.sqrt(target_area / aspect_ratio)))
149 |
150 | w = min(w, width)
151 | h = min(h, height)
152 |
153 | i = torch.randint(0, height - h + 1, size=(1,)).item()
154 | j = torch.randint(0, width - w + 1, size=(1,)).item()
155 |
156 | return i, j, h, w
157 |
158 |
159 | def build_transform(is_train, config):
160 | resize_im = config.DATA.IMG_SIZE > 32
161 | if is_train:
162 | # linear probe: weak augmentation
163 | transform = transforms.Compose([
164 | RandomResizedCrop(config.DATA.IMG_SIZE, interpolation=3),
165 | transforms.RandomHorizontalFlip(),
166 | transforms.ToTensor(),
167 |             transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])  # ImageNet stats here; the eval path below uses CLIP stats
168 | return transform
169 |
170 | t = []
171 | if resize_im:
172 |         if True:  # config.TEST.CROP is bypassed: linear-probe eval always resizes then center-crops
173 | size = int((256 / 224) * config.DATA.IMG_SIZE)
174 | t.append(
175 | transforms.Resize(size, interpolation=_pil_interp(config.DATA.INTERPOLATION)),
176 | # to maintain same ratio w.r.t. 224 images
177 | )
178 | t.append(transforms.CenterCrop(config.DATA.IMG_SIZE))
179 | else:
180 | t.append(
181 | transforms.Resize((config.DATA.IMG_SIZE, config.DATA.IMG_SIZE),
182 | interpolation=_pil_interp(config.DATA.INTERPOLATION))
183 | )
184 |
185 | t.append(transforms.ToTensor())
186 | t.append(transforms.Normalize(IMAGENET_CLIP_MEAN, IMAGENET_CLIP_STD))
187 | return transforms.Compose(t)
--------------------------------------------------------------------------------
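In zip mode, `build_loader` above shards the dataset by assigning each rank the strided index set `np.arange(rank, len(dataset), world_size)` and sampling from it with `SubsetRandomSampler`. A minimal single-process sketch of that partitioning (sizes are illustrative; no distributed setup is assumed):

import numpy as np

dataset_len, world_size = 10, 4                                    # illustrative
shards = [np.arange(r, dataset_len, world_size) for r in range(world_size)]
print(shards)  # [array([0, 4, 8]), array([1, 5, 9]), array([2, 6]), array([3, 7])]
# Every index lands in exactly one shard, so ranks never read the same image;
# unlike DistributedSampler, shard lengths may differ by one since no padding is applied.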
/data/utils.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Feature Distillation
3 | # Copyright (c) 2021 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Yixuan Wei
6 | # --------------------------------------------------------
7 |
8 | import torch
9 |
10 | import os
11 | import io
12 | import zipfile
13 | from PIL import Image
14 |
15 |
16 | def is_zip_path(img_or_path):
17 |     """Return True if the path points into a zip archive (contains '.zip@')."""
18 | return '.zip@' in img_or_path
19 |
20 |
21 | class ZipReader(object):
22 | zip_bank = dict()
23 |
24 | def __init__(self):
25 | super(ZipReader, self).__init__()
26 |
27 | @staticmethod
28 | def get_zipfile(path):
29 | zip_bank = ZipReader.zip_bank
30 | if path in zip_bank:
31 | return zip_bank[path]
32 | else:
33 | zfile = zipfile.ZipFile(path, 'r')
34 | zip_bank[path] = zfile
35 | return zip_bank[path]
36 |
37 | @staticmethod
38 | def split_zip_style_path(path):
39 |         pos_zip_at = path.find('.zip@')
40 |         if pos_zip_at == -1:
41 |             # str.index would raise anyway; raise a clearer error for malformed paths
42 |             raise ValueError("separator '.zip@' is not found in the given path '%s'" % path)
43 | pos_at = pos_zip_at + len('.zip@') - 1
44 |
45 | zip_path = path[0: pos_at]
46 | folder_path = path[pos_at + 1:]
47 | folder_path = str.strip(folder_path, '/')
48 | return zip_path, folder_path
49 |
50 | @staticmethod
51 | def list_folder(path):
52 | zip_path, folder_path = ZipReader.split_zip_style_path(path)
53 |
54 | zfile = ZipReader.get_zipfile(zip_path)
55 | folder_list = []
56 |         for file_folder_name in zfile.namelist():
57 |             file_folder_name = str.strip(file_folder_name, '/')
58 |             if file_folder_name.startswith(folder_path) and \
59 |                     len(os.path.splitext(file_folder_name)[-1]) == 0 and \
60 |                     file_folder_name != folder_path:
61 |                 if len(folder_path) == 0:
62 |                     folder_list.append(file_folder_name)
63 |                 else:
64 |                     folder_list.append(file_folder_name[len(folder_path)+1:])
65 |
66 | return folder_list
67 |
68 | @staticmethod
69 | def list_files(path, extension=['.*']):
70 | zip_path, folder_path = ZipReader.split_zip_style_path(path)
71 |
72 | zfile = ZipReader.get_zipfile(zip_path)
73 | file_lists = []
74 |         for file_folder_name in zfile.namelist():
75 |             file_folder_name = str.strip(file_folder_name, '/')
76 |             if file_folder_name.startswith(folder_path) and str.lower(os.path.splitext(file_folder_name)[-1]) in extension:
77 |                 if len(folder_path) == 0:
78 |                     file_lists.append(file_folder_name)
79 |                 else:
80 |                     file_lists.append(file_folder_name[len(folder_path)+1:])
81 |
82 | return file_lists
83 |
84 | @staticmethod
85 | def list_files_fullpath(path, extension=['.*']):
86 | zip_path, folder_path = ZipReader.split_zip_style_path(path)
87 |
88 | zfile = ZipReader.get_zipfile(zip_path)
89 | file_lists = []
90 |         for file_folder_name in zfile.namelist():
91 |             if file_folder_name.startswith(folder_path) and str.lower(os.path.splitext(file_folder_name)[-1]) in extension:
92 |                 file_lists.append(file_folder_name)
93 |
94 | return file_lists
95 |
96 | @staticmethod
97 | def imread(path):
98 | zip_path, path_img = ZipReader.split_zip_style_path(path)
99 | zfile = ZipReader.get_zipfile(zip_path)
100 | data = zfile.read(path_img)
101 | im = Image.open(io.BytesIO(data))
102 | return im
103 |
104 | @staticmethod
105 | def read(path):
106 | zip_path, path_img = ZipReader.split_zip_style_path(path)
107 | zfile = ZipReader.get_zipfile(zip_path)
108 | data = zfile.read(path_img)
109 | return data
110 |
111 |
112 | class SubsetRandomSampler(torch.utils.data.Sampler):
113 | r"""Samples elements randomly from a given list of indices, without replacement.
114 |
115 | Arguments:
116 | indices (sequence): a sequence of indices
117 | """
118 |
119 | def __init__(self, indices):
120 | self.epoch = 0
121 | self.indices = indices
122 |
123 | def __iter__(self):
124 | return (self.indices[i] for i in torch.randperm(len(self.indices)))
125 |
126 | def __len__(self):
127 | return len(self.indices)
128 |
129 |     def set_epoch(self, epoch):
130 |         self.epoch = epoch  # kept for API parity with DistributedSampler; not used by __iter__
--------------------------------------------------------------------------------
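`ZipReader` above resolves paths of the form `<archive>.zip@/<member>`: everything through `.zip` names the archive (whose open handle is cached in `zip_bank`), and the remainder names a member inside it. A hedged usage sketch; the archive and member names are made up for illustration:

from data.utils import ZipReader, is_zip_path

path = 'train.zip@/n01440764/sample.JPEG'        # hypothetical zip-style path
assert is_zip_path(path)                         # True: the path contains '.zip@'
zip_path, member = ZipReader.split_zip_style_path(path)
print(zip_path, member)                          # 'train.zip' 'n01440764/sample.JPEG'
img = ZipReader.imread(path)                     # PIL image read straight from the archive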
/figures/teaser.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SwinTransformer/Feature-Distillation/2115145a388822bba14c183f9ae74fdf479f7df9/figures/teaser.jpg
--------------------------------------------------------------------------------
/logger.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Copyright (c) 2021 Microsoft
3 | # Licensed under The MIT License [see LICENSE for details]
4 | # Written by Ze Liu
5 | # Modified by Zhenda Xie
6 | # --------------------------------------------------------
7 |
8 | import os
9 | import sys
10 | import logging
11 | import functools
12 | from termcolor import colored
13 |
14 |
15 | @functools.lru_cache()
16 | def create_logger(output_dir, dist_rank=0, name=''):
17 | # create logger
18 | logger = logging.getLogger(name)
19 | logger.setLevel(logging.DEBUG)
20 | logger.propagate = False
21 |
22 | # create formatter
23 | fmt = '[%(asctime)s %(name)s] (%(filename)s %(lineno)d): %(levelname)s %(message)s'
24 | color_fmt = colored('[%(asctime)s %(name)s]', 'green') + \
25 | colored('(%(filename)s %(lineno)d)', 'yellow') + ': %(levelname)s %(message)s'
26 |
27 | # create console handlers for master process
28 | if dist_rank == 0:
29 | console_handler = logging.StreamHandler(sys.stdout)
30 | console_handler.setLevel(logging.DEBUG)
31 | console_handler.setFormatter(
32 | logging.Formatter(fmt=color_fmt, datefmt='%Y-%m-%d %H:%M:%S'))
33 | logger.addHandler(console_handler)
34 |
35 | # create file handlers
36 | file_handler = logging.FileHandler(os.path.join(output_dir, f'log_rank{dist_rank}.txt'), mode='a')
37 | file_handler.setLevel(logging.DEBUG)
38 | file_handler.setFormatter(logging.Formatter(fmt=fmt, datefmt='%Y-%m-%d %H:%M:%S'))
39 | logger.addHandler(file_handler)
40 |
41 | return logger
42 |
--------------------------------------------------------------------------------
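Because `create_logger` is decorated with `functools.lru_cache`, repeated calls with the same `(output_dir, dist_rank, name)` return the already-configured logger rather than attaching duplicate handlers. A small usage sketch (the output directory is illustrative):

import os
from logger import create_logger

os.makedirs('output_demo', exist_ok=True)        # illustrative directory
logger = create_logger('output_demo', dist_rank=0, name='fd')
logger.info('hello')                             # stdout (rank 0 only) + log_rank0.txt
assert create_logger('output_demo', dist_rank=0, name='fd') is logger  # cached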
/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Copyright (c) 2021 Microsoft
3 | # Licensed under The MIT License [see LICENSE for details]
4 | # Written by Ze Liu
5 | # Modified by Zhenda Xie
6 | # --------------------------------------------------------
7 |
8 | from collections import Counter
9 | from bisect import bisect_right
10 |
11 | import torch
12 | from timm.scheduler.cosine_lr import CosineLRScheduler
13 | from timm.scheduler.step_lr import StepLRScheduler
14 | from timm.scheduler.scheduler import Scheduler
15 |
16 |
17 | def build_scheduler(config, optimizer, n_iter_per_epoch):
18 | num_steps = int(config.TRAIN.EPOCHS * n_iter_per_epoch)
19 |     if config.TRAIN.WARMUP_EPOCHS < 0 and config.TRAIN.WARMUP_EPOCHS_FINE != 0:  # negative WARMUP_EPOCHS: fall back to the fine-tune warmup length
20 | warmup_steps = int(config.TRAIN.WARMUP_EPOCHS_FINE * n_iter_per_epoch)
21 | else:
22 | warmup_steps = int(config.TRAIN.WARMUP_EPOCHS * n_iter_per_epoch)
23 | decay_steps = int(config.TRAIN.LR_SCHEDULER.DECAY_EPOCHS * n_iter_per_epoch)
24 | multi_steps = [i * n_iter_per_epoch for i in config.TRAIN.LR_SCHEDULER.MULTISTEPS]
25 |
26 | lr_scheduler = None
27 | if config.TRAIN.LR_SCHEDULER.NAME == 'cosine':
28 | lr_scheduler = CosineLRScheduler(
29 | optimizer,
30 | t_initial=num_steps - warmup_steps,
31 | t_mul=1.,
32 | lr_min=config.TRAIN.MIN_LR,
33 | warmup_lr_init=config.TRAIN.WARMUP_LR,
34 | warmup_prefix=True,
35 | warmup_t=warmup_steps,
36 | cycle_limit=1,
37 | t_in_epochs=False,
38 | )
39 | elif config.TRAIN.LR_SCHEDULER.NAME == 'linear':
40 | lr_scheduler = LinearLRScheduler(
41 | optimizer,
42 | t_initial=num_steps,
43 | lr_min_rate=0.01,
44 | warmup_lr_init=config.TRAIN.WARMUP_LR,
45 | warmup_t=warmup_steps,
46 | t_in_epochs=False,
47 | )
48 | elif config.TRAIN.LR_SCHEDULER.NAME == 'step':
49 | lr_scheduler = StepLRScheduler(
50 | optimizer,
51 | decay_t=decay_steps,
52 | decay_rate=config.TRAIN.LR_SCHEDULER.DECAY_RATE,
53 | warmup_lr_init=config.TRAIN.WARMUP_LR,
54 | warmup_t=warmup_steps,
55 | t_in_epochs=False,
56 | )
57 | elif config.TRAIN.LR_SCHEDULER.NAME == 'multistep':
58 | lr_scheduler = MultiStepLRScheduler(
59 | optimizer,
60 | milestones=multi_steps,
61 | gamma=config.TRAIN.LR_SCHEDULER.GAMMA,
62 | warmup_lr_init=config.TRAIN.WARMUP_LR,
63 | warmup_t=warmup_steps,
64 | t_in_epochs=False,
65 | )
66 |
67 | return lr_scheduler
68 |
69 |
70 | class LinearLRScheduler(Scheduler):
71 | def __init__(self,
72 | optimizer: torch.optim.Optimizer,
73 | t_initial: int,
74 | lr_min_rate: float,
75 | warmup_t=0,
76 | warmup_lr_init=0.,
77 | t_in_epochs=True,
78 | noise_range_t=None,
79 | noise_pct=0.67,
80 | noise_std=1.0,
81 | noise_seed=42,
82 | initialize=True,
83 | ) -> None:
84 | super().__init__(
85 | optimizer, param_group_field="lr",
86 | noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
87 | initialize=initialize)
88 |
89 | self.t_initial = t_initial
90 | self.lr_min_rate = lr_min_rate
91 | self.warmup_t = warmup_t
92 | self.warmup_lr_init = warmup_lr_init
93 | self.t_in_epochs = t_in_epochs
94 | if self.warmup_t:
95 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
96 | super().update_groups(self.warmup_lr_init)
97 | else:
98 | self.warmup_steps = [1 for _ in self.base_values]
99 |
100 | def _get_lr(self, t):
101 | if t < self.warmup_t:
102 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
103 | else:
104 | t = t - self.warmup_t
105 | total_t = self.t_initial - self.warmup_t
106 | lrs = [v - ((v - v * self.lr_min_rate) * (t / total_t)) for v in self.base_values]
107 | return lrs
108 |
109 | def get_epoch_values(self, epoch: int):
110 | if self.t_in_epochs:
111 | return self._get_lr(epoch)
112 | else:
113 | return None
114 |
115 | def get_update_values(self, num_updates: int):
116 | if not self.t_in_epochs:
117 | return self._get_lr(num_updates)
118 | else:
119 | return None
120 |
121 |
122 | class MultiStepLRScheduler(Scheduler):
123 | def __init__(self, optimizer: torch.optim.Optimizer, milestones, gamma=0.1, warmup_t=0, warmup_lr_init=0, t_in_epochs=True) -> None:
124 | super().__init__(optimizer, param_group_field="lr")
125 |
126 | self.milestones = milestones
127 | self.gamma = gamma
128 | self.warmup_t = warmup_t
129 | self.warmup_lr_init = warmup_lr_init
130 | self.t_in_epochs = t_in_epochs
131 | if self.warmup_t:
132 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
133 | super().update_groups(self.warmup_lr_init)
134 | else:
135 | self.warmup_steps = [1 for _ in self.base_values]
136 |
137 | assert self.warmup_t <= min(self.milestones)
138 |
139 | def _get_lr(self, t):
140 | if t < self.warmup_t:
141 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
142 | else:
143 | lrs = [v * (self.gamma ** bisect_right(self.milestones, t)) for v in self.base_values]
144 | return lrs
145 |
146 | def get_epoch_values(self, epoch: int):
147 | if self.t_in_epochs:
148 | return self._get_lr(epoch)
149 | else:
150 | return None
151 |
152 | def get_update_values(self, num_updates: int):
153 | if not self.t_in_epochs:
154 | return self._get_lr(num_updates)
155 | else:
156 | return None
--------------------------------------------------------------------------------
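`LinearLRScheduler._get_lr` above decays each base value `v` linearly from `v` at the end of warmup down to `v * lr_min_rate` at `t = t_initial`. A standalone numeric sketch of that rule, restated without timm or torch (all values illustrative):

def linear_lr(t, base_lr, t_initial, lr_min_rate, warmup_t=0, warmup_lr_init=0.0):
    # Mirrors LinearLRScheduler._get_lr for a single parameter group.
    if t < warmup_t:
        return warmup_lr_init + t * (base_lr - warmup_lr_init) / warmup_t
    t, total_t = t - warmup_t, t_initial - warmup_t
    return base_lr - (base_lr - base_lr * lr_min_rate) * (t / total_t)

print(linear_lr(0, 1.0, 100, 0.01))    # 1.0
print(linear_lr(50, 1.0, 100, 0.01))   # 0.505
print(linear_lr(100, 1.0, 100, 0.01))  # 0.01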
/main_fd.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Feature Distillation
3 | # Copyright (c) 2021 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Ze Liu
6 | # Modified by Yixuan Wei
7 | # --------------------------------------------------------
8 |
9 | import os
10 | import time
11 | import argparse
12 | import datetime
13 | import numpy as np
14 |
15 | import torch
16 | import torch.backends.cudnn as cudnn
17 | import torch.distributed as dist
18 | import torch.cuda.amp as amp
19 | from timm.utils import AverageMeter
20 |
21 | from config import get_config
22 | from models import build_model
23 | from data import build_loader
24 | from lr_scheduler import build_scheduler
25 | from optimizer import build_optimizer
26 | from logger import create_logger
27 | from utils import load_checkpoint, save_checkpoint, get_grad_norm, auto_resume_helper, reduce_tensor
28 |
29 | import wandb
30 | global no_wandb
31 | no_wandb = True
32 |
33 | def wandb_log(*args, **kwargs):
34 | if dist.get_rank() == 0 and not no_wandb:
35 | wandb.log(*args, **kwargs)
36 |
37 |
38 | def parse_option():
39 | global no_wandb
40 | parser = argparse.ArgumentParser('Feature Distillation pre-training script', add_help=False)
41 | parser.add_argument('--cfg', type=str, required=True, metavar="FILE", help='path to config file', )
42 | parser.add_argument(
43 | "--opts",
44 | help="Modify config options by adding 'KEY VALUE' pairs. ",
45 | default=None,
46 | nargs='+',
47 | )
48 |
49 | # easy config modification
50 | parser.add_argument('--batch-size', type=int, help="batch size for single GPU")
51 | parser.add_argument('--data-path', type=str, help='path to dataset')
52 | parser.add_argument('--resume', help='resume from checkpoint')
53 | parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps")
54 | parser.add_argument('--use-checkpoint', action='store_true',
55 | help="whether to use gradient checkpointing to save memory")
56 | parser.add_argument('--enable-amp', action='store_true')
57 | parser.add_argument('--disable-amp', action='store_false', dest='enable_amp')
58 | parser.set_defaults(enable_amp=True)
59 | parser.add_argument('--output', default='output', type=str, metavar='PATH',
60 | help='root of output folder, the full path is