├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── beit_finetuning ├── modeling_finetune.py ├── optim_factory.py ├── requirements.txt ├── run_class_finetuning.py ├── run_with_submitit_finetune.py └── utils.py ├── bpe_simple_vocab_16e6.txt.gz ├── dataset_catalog.json ├── datasets.py ├── eval_zeroshot.py ├── labels.json ├── losses.py ├── main.py ├── main_linear.py ├── make_dataset.py ├── models.py ├── redcaps └── combine_captions.py ├── run_with_submitit.py ├── run_with_submitit_linear.py ├── slip.png ├── templates.json ├── tokenizer.py └── utils.py /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 
58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to SLIP 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `main`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Meta's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 28 | 29 | ## Coding Style 30 | * 2 spaces for indentation rather than tabs 31 | * 80 character line length 32 | 33 | ## License 34 | By contributing to SLIP, you agree that your contributions will be licensed 35 | under the LICENSE file in the root directory of this source tree. -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Meta Platforms, Inc. and affiliates. 
4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [SLIP: Self-supervision meets Language-Image Pre-training](https://arxiv.org/abs/2112.12750) 2 | 3 |

![SLIP framework](slip.png)

4 | 5 | 6 | ## What you can find in this repo: 7 | - Pre-trained models (with ViT-Small, Base, Large) and code to reproduce results from our paper: **[SLIP: Self-supervision meets Language-Image Pre-training](https://arxiv.org/abs/2112.12750).** *[Norman Mu](https://normanmu.com), [Alexander Kirillov](https://alexander-kirillov.github.io/), [David Wagner](http://people.eecs.berkeley.edu/~daw/) and [Saining Xie](https://sainingxie.com)*, arXiv 2021 8 | 9 | - An improved CLIP baseline (31.3% → 34.6% ImageNet 0-shot w/ Modified ResNet-50) on YFCC15M dataset. 10 | - Zero-shot transfer and linear classification evaluation scripts on **26** downstream datasets. 11 | 12 | ## Updates: 13 | 14 | Jan 18 2022: Added support for training on RedCaps 15 | 16 | Jan 17 2022: Released CC3M/CC12M CLIP/SLIP ViT-B checkpoints 17 | 18 | ## Results and Pre-trained Models 19 | The following models are pre-trained on YFCC15M and evaluated on ImageNet-1K (ILSVRC2012). 20 | 21 | ### ViT-Small (MoCo v3 version w/ 12 vs. 6 heads)

| Method | Epochs | 0-shot | Linear | Finetuned | Weights |
| --- | --- | --- | --- | --- | --- |
| CLIP | 25 | 32.7 | 59.3 | 78.2 | url |
| SimCLR | 25 | - | 58.1 | 79.9 | url |
| SLIP | 25 | 38.3 | 66.4 | 80.3 | url |
| SLIP | 50 | 39.3 | 67.6 | 80.7 | url |
| SLIP | 100 | 39.5 | 68.3 | 80.7 | url |
74 | 75 | ### ViT-Base

| Method | Epochs | 0-shot | Linear | Finetuned | Weights |
| --- | --- | --- | --- | --- | --- |
| CLIP | 25 | 37.6 | 66.5 | 80.5 | url |
| SimCLR | 25 | - | 64.0 | 82.5 | url |
| SLIP | 25 | 42.8 | 72.1 | 82.6 | url |
| SLIP | 50 | 44.1 | 73.0 | 82.9 | url |
| SLIP | 100 | 45.0 | 73.6 | 83.4 | url |
128 | 129 | ### ViT-Large

| Method | Epochs | 0-shot | Linear | Finetuned | Weights |
| --- | --- | --- | --- | --- | --- |
| CLIP | 25 | 40.4 | 70.5 | 81.0 | url |
| SimCLR | 25 | - | 66.7 | 84.0 | url |
| SLIP | 25 | 46.2 | 76.0 | 84.2 | url |
| SLIP | 50 | 47.4 | 75.8 | 84.7 | url |
| SLIP | 100 | 47.9 | 75.1 | 84.8 | url |
182 | 183 | ### Additional Datasets and Models

| Dataset | Method | Model | Epochs | 0-shot | Linear | Finetuned | Weights |
| --- | --- | --- | --- | --- | --- | --- | --- |
| CC3M | CLIP | ViT-B | 40 | 17.1 | 53.3 | 79.5 | url |
| CC3M | SLIP | ViT-B | 40 | 23.0 | 65.4 | 81.4 | url |
| CC12M | CLIP | ViT-B | 35 | 36.5 | 69.0 | 82.1 | url |
| CC12M | SLIP | ViT-B | 35 | 40.7 | 73.7 | 83.1 | url |
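
To experiment with a released checkpoint outside of the provided evaluation scripts, the weights can be loaded into the corresponding model from [models.py](models.py). Below is a minimal sketch, not the exact code used by our scripts: it assumes the checkpoint stores DistributedDataParallel weights under a `state_dict` key with a `module.` prefix (as suggested by the `--model_key state_dict --model_prefix module.visual.` flags in Section 5), and that the constructor name matches the `--model` strings used for pre-training (e.g. `SLIP_VITB16`); the constructor keyword arguments shown are assumptions and may differ.

```
import torch

import models  # model definitions from this repo (models.py)

# Load a checkpoint on CPU and strip the DistributedDataParallel "module." prefix
# from parameter names (assumed checkpoint layout; adjust the keys if yours differ).
ckpt = torch.load('/path/to/checkpoint.pt', map_location='cpu')
state_dict = {k.replace('module.', '', 1): v for k, v in ckpt['state_dict'].items()}

# Constructor name follows the --model argument used for pre-training (assumption);
# ssl_mlp_dim/ssl_emb_dim mirror the projection head sizes in the Section 2 commands.
model = models.SLIP_VITB16(ssl_mlp_dim=4096, ssl_emb_dim=256)
model.load_state_dict(state_dict, strict=False)  # strict=False tolerates unused heads
model.eval()
```

For zero-shot and linear evaluation, the provided scripts handle this loading for you via `--resume` (Section 3) and `--pretrained` (Section 4).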
239 | 240 | ## 1. Setup 241 | Install [PyTorch](https://pytorch.org) and [timm](https://github.com/rwightman/pytorch-image-models). 242 | The code has been tested with CUDA 11.3/CuDNN 8.2.0, PyTorch 1.10.0 and timm 0.5.0. 243 | 244 | ### 1.1. YFCC15M Setup 245 | Download the [YFCC100M dataset](https://multimediacommons.wordpress.com/yfcc100m-core-dataset/). 246 | Our dataloader expects the following dataset directory structure with 100 folders containing 1000 zip archives of 1000 images each. 247 | The concatenation of the folder, archive, and file names is the index of the image (i.e. image 12345678 is stored as `678.jpg` within `12/345.zip`): 248 | 249 | ``` 250 | /path/to/yfcc100m/ 251 | ├── images/ 252 | │   ├── 00/ 253 | │   │   └── 000.zip 254 | │   │   │   ├── 000.jpg 255 | │   │   │   │   ... 256 | │   │   │   └── 999.jpg 257 | │   │   ... 258 | │   │   └── 999.zip 259 | │   ... 260 | │   └── 99/ 261 | ... 262 | ``` 263 | 264 | Prepare the YFCC15M subset metadata pickle: 265 | 1. Download and compile a list of downloaded images to `flickr_unique_ids.npy` ([ours](https://dl.fbaipublicfiles.com/deepcluster/flickr_unique_ids.npy)) 266 | 2. Download OpenAI's list of captioned YFCC100M images according to instructions [here](https://github.com/openai/CLIP/blob/8cad3a736a833bc4c9b4dd34ef12b52ec0e68856/data/yfcc100m.md) 267 | 3. Run `python make_dataset.py` to create the `yfcc15m.pkl` metadata pickle 268 | 269 | When pre-training with YFCC15M, set `--dataset yfcc15m --root /path/to/yfcc100m --metadata /path/to/yfcc15m.pkl`. 270 | 271 | ### 1.2. COCO Captions Setup 272 | Download and unzip the 2017 Train [images](http://images.cocodataset.org/zips/train2017.zip) and [annotations](http://images.cocodataset.org/annotations/annotations_trainval2017.zip). 273 | When pre-training on COCO, set `--dataset coco --root /path/to/coco --metadata /path/to/captions_train2017.json`. 274 | 275 | ### 1.3. Conceptual Captions Setup 276 | [CC3M](https://ai.google.com/research/ConceptualCaptions/download) and [CC12M](https://github.com/google-research-datasets/conceptual-12m) are published as tsv files listing original image urls and processed captions. 277 | Download images and collect the captions of all available images (many will be missing due to broken links) into `cc3m.npy` and `cc12m.npy`. 278 | 279 | For CC3M our dataloader expects `cc3m.npy` to contain a NumPy array of dicts in the following format: 280 | 281 | ``` 282 | { 283 | 'image_id': 1510438788, # local file path relative to root 284 | 'captions': ['large field with pink tulips on a clear sunny summer day with a blue sky'] 285 | } 286 | ``` 287 | 288 | For CC12M our dataloader expects `cc12m.npy` to contain a NumPy array of dicts in the following format: 289 | 290 | ``` 291 | { 292 | 'image_name': '0.jpg', # local file path relative to root 293 | 'image_id': 0, 294 | 'captions': ['Metal Design Within Reach Ivory Slipper Chairs - a Pair For Sale - Image 7 of 10'] 295 | } 296 | ``` 297 | 298 | When pre-training on CC3M set `--dataset cc3m --root /path/to/cc3m --metadata /path/to/cc3m.npy`, and whe pre-training on CC12M set `--dataset cc12m --root /path/to/cc12m --metadata /path/to/cc12m.npy`. 299 | 300 | ### 1.4. RedCaps Setup 301 | [RedCaps](https://redcaps.xyz) is published as a list of JSON annotation files containing image urls and raw/processed captions. 302 | Images can be downloaded from these annotations with a helpful [downloader tool](https://github.com/redcaps-dataset/redcaps-downloader). 
303 | Then merge all per-subreddit annotations into a single file with the [combine_captions.py](redcaps/combine_captions.py) script: 304 | 305 | ``` 306 | python redcaps/combine_captions.py --input /path/to/redcaps/annotations --output /path/to/redcaps_v1.json 307 | ``` 308 | 309 | To pre-train on RedCaps set `--dataset redcaps --root /path/to/redcaps --metadata /path/to/redcaps_v1.json`. 310 | 311 | 312 | ### 1.4. Downstream Dataset Setup 313 | Zero-shot (in [main.py](main.py) and [eval_zeroshot.py](eval_zeroshot.py)) and linear (in [main_linear.py](main_linear.py)) evaluations read dataset paths from [dataset_catalog.json](dataset_catalog.json). 314 | Zero-shot evaluations read CLIP's class labels and caption templates from [labels.json](labels.json) and [templates.json](templates.json). 315 | If just pre-training models on YFCC15M, only the ImageNet path is required for model validation between training epochs. 316 | See Section 3 below on zero-shot transfer evaluation for dataset preparation details. 317 | 318 | ## 2. Pre-training 319 | We use the following pre-training recipes for SLIP, CLIP, and SimCLR. 320 | See [main.py](main.py) for the full list of default arguments. 321 | We use the same lr and wd settings for all model sizes within the same training framework, and different model sizes can be selected by passing in different strings to the `--model` argument such as `SLIP_VITS16` or `SLIP_VITL16`. 322 | 323 | In our workflow we use [submitit](https://github.com/facebookincubator/submitit), which interfaces nicely with Slurm. 324 | For local training with the [torchrun](https://pytorch.org/docs/stable/elastic/run.html) utility (supersedes `torch.distributed.launch`), replace `python run_with_submitit.py` with `torchrun --nproc_per_node=8 main.py`. 325 | Local multi-node training with `torchrun` should also be possible. 326 | 327 | We train most of our models on 8x 8-gpu nodes, but training with fewer gpus is possible by reducing the batch size and setting the `--update-freq` argument above 1 to enable gradient accumulation. 328 | Note that gradient accumulation will increase the variance of minibatch statistics and alter the training dynamics of batchnorm, which is used in SLIP and SimCLR. 329 | 330 | ### SLIP ViT-Base with 8-nodes (batch size 4096) 331 | ``` 332 | python run_with_submitit.py \ 333 | --root /path/to/yfcc100m \ 334 | --model SLIP_VITB16 \ 335 | --lr 3e-3 --wd 0.1 336 | ``` 337 | 338 | ### CLIP ViT-Base with 8-nodes (batch size 4096) 339 | ``` 340 | python run_with_submitit.py \ 341 | --root /path/to/yfcc100m \ 342 | --model CLIP_VITB16 \ 343 | --lr 5e-4 --wd 0.5 344 | ``` 345 | 346 | ### SimCLR ViT-Base with 8-nodes (batch size 4096) 347 | ``` 348 | python run_with_submitit.py \ 349 | --root /path/to/yfcc100m \ 350 | --model SIMCLR_VITB16 \ 351 | --ssl-mlp-dim 4096 --ssl-emb-dim 256 --ssl-temp 0.1 \ 352 | --lr 3.2e-3 --wd 0.1 353 | ``` 354 | 355 | Some important arguments: 356 | 357 | `--dataset`: pre-training dataset name. choices include `yfcc15m`, `cc12m`, `cc3m`, `coco`. 
358 | 359 | `--root`: path to dataset root 360 | 361 | `--metadata`: path to metadata file (see section 1 for details) 362 | 363 | `--ssl-mlp-dim`: hidden dim of SimCLR mlp projection head 364 | 365 | `--ssl-emb-dim`: output embed dim of SimCLR mlp projection head 366 | 367 | `--ssl-scale`: loss scale for SimCLR objective 368 | 369 | `--ssl-temp`: softmax temperature for SimCLR objective 370 | 371 | `--batch-size`: number of samples per-device/per-gpu 372 | 373 | `--lr-start`: initial warmup lr 374 | 375 | `--lr-end`: minimum final lr 376 | 377 | `--update-freq`: optimizer update frequency, i.e. gradient accumulation steps 378 | 379 | `--disable-amp`: disable mixed-precision training (requires more memory and compute) 380 | 381 | ## 3. Evaluation: Zero-shot Transfer 382 | First, prepare additional downstream classification datasets: 383 | - MNIST, CIFAR-10/100, STL-10: Automatic download via [torchvision datasets](https://pytorch.org/vision/stable/datasets.html) 384 | - HatefulMemes: Manual download from [official website](https://hatefulmemeschallenge.com/#download) and sort images according to `train.jsonl`/`dev.jsonl` into train/dev folder 385 | - Rendered SST2, Country211: Manual download from [CLIP repo](https://github.com/openai/CLIP/tree/main/data) 386 | - Other datasets: Use scripts from [VISSL](https://github.com/facebookresearch/vissl/tree/main/extra_scripts/datasets) 387 | 388 | Then set all dataset paths in [dataset_catalog.json](dataset_catalog.json). 389 | 390 | Evaluate zero-shot transfer to various classification benchmarks with [eval_zeroshot.py](eval_zeroshot.py), which reads labels and templates from [labels.json](labels.json)/[templates.json](templates.json) and dataset paths from [dataset_catalog.json](dataset_catalog.json). Inference is performed with a single gpu. By default, the script iterates through all datasets in [dataset_catalog.json](dataset_catalog.json) and evaluates zero-shot in order. Evaluation can be limited to a subset of datasets by replacing `for d in datasets:` with `for d in ['imagenet']:` on line 78. 391 | 392 | ``` 393 | python eval_zeroshot.py --resume /path/to/checkpoint.pt 394 | ``` 395 | 396 | ## 4. Evaluation: Linear Classification 397 | We use a modified version of the MoCo v3 ImageNet linear classification script, [main_linear.py](main_linear.py). 398 | We use the same single node 8-gpu recipe for all model sizes. 399 | See [main_linear.py](main_linear.py) for the full list of default arguments. 400 | As with pre-training, our workflow uses [submitit](https://github.com/facebookincubator/submitit). 401 | For local training with [torchrun](https://pytorch.org/docs/stable/elastic/run.html), replace `python run_with_submitit_linear.py` with `torchrun --nproc_per_node=8 main_linear.py`. 402 | This script reads the ImageNet dataset path from the dataset catalog ([dataset_catalog.json](dataset_catalog.json)), which must be set properly before training. 403 | 404 | ``` 405 | python run_with_submitit_linear.py \ 406 | --arch vit_base_patch16_224 --dataset imagenet \ 407 | --pretrained /path/to/checkpoint.pt 408 | ``` 409 | 410 | To evaluate linear classification on other datasets, set `--dataset` to the corresponding dataset name listed in [dataset_catalog.json](dataset_catalog.json). 411 | 412 | ## 5. Evaluation: End-to-End Finetuning 413 | We use a modified version of the ImageNet finetuning script from [BeiT](https://github.com/microsoft/unilm/tree/f8f3df80c65eb5e5fc6d6d3c9bd3137621795d1e/beit). 
414 | Our code has been tested with commit `f8f3df8`. 415 | We have removed the explicit torch, torchvision, and timm dependencies from [beit_finetuning/requirements.txt](beit_finetuning/requirements.txt), as they conflict with the versions used in our SLIP code (CUDA 11.3/CuDNN 8.2.0, PyTorch 1.10.0 and timm 0.5.0). 416 | The fintuning code has been modified and tested to work with these versions. 417 | 418 | ### 5.1. Setup 419 | To evaluate end-to-end finetuning on ImageNet, first clone the BeiT repo and checkout the correct commit: 420 | 421 | ``` 422 | git clone git@github.com:microsoft/unilm.git 423 | cd unilm/beit 424 | git checkout f8f3df8 425 | ``` 426 | 427 | Now copy over modified files from our [beit_finetuning](beit_finetuning) directory: 428 | 429 | ``` 430 | cp beit_finetuning/* unilm/beit 431 | cd unilm/beit 432 | ``` 433 | 434 | Install pip dependencies and Nvidia Apex: 435 | 436 | ``` 437 | pip install -r requirements.txt 438 | git clone https://github.com/NVIDIA/apex 439 | cd apex 440 | pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./ 441 | ``` 442 | 443 | 444 | ### 5.2. Commands 445 | As with pre-training, our workflow uses [submitit](https://github.com/facebookincubator/submitit). 446 | For local training with [torchrun](https://pytorch.org/docs/stable/elastic/run.html), replace `python run_with_submitit_finetune.py` with `torchrun --nproc_per_node=8 run_class_finetuning.py`. 447 | We established finetuning recipes based on the BeiT recipes with some light additional hyperparameter tuning. 448 | We increase regularization with model size: ViT-S uses drop_path=0 and layer_decay=0.65, ViT-B uses drop_path=0.1 and layer_decay=0.65, and ViT-L uses drop_path=0.1 and layer_decay=0.75. 449 | Note the use of the `--finetune` argument instead of `--resume`. 450 | 451 | ### ViT-Small (MoCo v3 version w/ 12 vs. 6 heads) 452 | 453 | ``` 454 | python run_with_submitit_finetune.py \ 455 | --batch_size 128 --enable_deepspeed \ 456 | --epochs 100 --warmup_epochs 20 \ 457 | --model beit_small_patch16_224 --nb_classes 1000 \ 458 | --imagenet_default_mean_and_std \ 459 | --model_key state_dict --model_prefix module.visual. \ 460 | --disable_rel_pos_bias --abs_pos_emb --use_cls \ 461 | --mixup 0.8 --cutmix 1 \ 462 | --layer_scale_init_value 0 \ 463 | --lr 4e-3 --drop_path 0 --layer_decay 0.65 \ 464 | --output_dir /path/to/output_dir --finetune /path/to/checkpoint.pt 465 | ``` 466 | 467 | ### ViT-Base 468 | 469 | ``` 470 | python run_with_submitit_finetune.py \ 471 | --batch_size 128 --enable_deepspeed \ 472 | --epochs 100 --warmup_epochs 20 \ 473 | --model beit_base_patch16_224 --nb_classes 1000 \ 474 | --imagenet_default_mean_and_std \ 475 | --model_key state_dict --model_prefix module.visual. \ 476 | --disable_rel_pos_bias --abs_pos_emb --use_cls \ 477 | --mixup 0.8 --cutmix 1 \ 478 | --layer_scale_init_value 0 \ 479 | --lr 4e-3 --drop_path 0.1 --layer_decay 0.65 \ 480 | --output_dir /path/to/output_dir --finetune /path/to/checkpoint.pt 481 | ``` 482 | 483 | ### ViT-Large 484 | 485 | ``` 486 | python run_with_submitit_finetune.py \ 487 | --batch_size 128 --enable_deepspeed \ 488 | --epochs 50 --warmup_epochs 5 \ 489 | --model beit_large_patch16_224 --nb_classes 1000 \ 490 | --imagenet_default_mean_and_std \ 491 | --model_key state_dict --model_prefix module.visual. 
\ 492 | --disable_rel_pos_bias --abs_pos_emb --use_cls \ 493 | --mixup 0.8 --cutmix 1 \ 494 | --layer_scale_init_value 0 \ 495 | --lr 4e-3 --drop_path 0.1 --layer_decay 0.75 \ 496 | --output_dir /path/to/output_dir --finetune /path/to/checkpoint.pt 497 | ``` 498 | 499 | 500 | ### License 501 | 502 | This project is under the MIT license. See [LICENSE](LICENSE) for details. 503 | 504 | ### Citation 505 | ``` 506 | @Article{mu2021slip, 507 | author = {Norman Mu and Alexander Kirillov and David Wagner and Saining Xie}, 508 | title = {SLIP: Self-supervision meets Language-Image Pre-training}, 509 | journal = {arXiv preprint arXiv:2112.12750}, 510 | year = {2021}, 511 | } 512 | ``` 513 | -------------------------------------------------------------------------------- /beit_finetuning/modeling_finetune.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # -------------------------------------------------------- 8 | # BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254) 9 | # Github source: https://github.com/microsoft/unilm/tree/master/beit 10 | # Copyright (c) 2021 Microsoft 11 | # Licensed under The MIT License [see LICENSE for details] 12 | # By Hangbo Bao 13 | # Based on timm and DeiT code bases 14 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm 15 | # https://github.com/facebookresearch/deit/ 16 | # https://github.com/facebookresearch/dino 17 | # --------------------------------------------------------' 18 | import math 19 | from functools import partial 20 | 21 | import torch 22 | import torch.nn as nn 23 | import torch.nn.functional as F 24 | from timm.models.layers import drop_path, to_2tuple, trunc_normal_ 25 | from timm.models.registry import register_model 26 | 27 | 28 | def _cfg(url='', **kwargs): 29 | return { 30 | 'url': url, 31 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 32 | 'crop_pct': .9, 'interpolation': 'bicubic', 33 | 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5), 34 | **kwargs 35 | } 36 | 37 | 38 | class DropPath(nn.Module): 39 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
40 | """ 41 | def __init__(self, drop_prob=None): 42 | super(DropPath, self).__init__() 43 | self.drop_prob = drop_prob 44 | 45 | def forward(self, x): 46 | return drop_path(x, self.drop_prob, self.training) 47 | 48 | def extra_repr(self) -> str: 49 | return 'p={}'.format(self.drop_prob) 50 | 51 | 52 | class Mlp(nn.Module): 53 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 54 | super().__init__() 55 | out_features = out_features or in_features 56 | hidden_features = hidden_features or in_features 57 | self.fc1 = nn.Linear(in_features, hidden_features) 58 | self.act = act_layer() 59 | self.fc2 = nn.Linear(hidden_features, out_features) 60 | self.drop = nn.Dropout(drop) 61 | 62 | def forward(self, x): 63 | x = self.fc1(x) 64 | x = self.act(x) 65 | # x = self.drop(x) 66 | # commit this for the orignal BERT implement 67 | x = self.fc2(x) 68 | x = self.drop(x) 69 | return x 70 | 71 | 72 | class Attention(nn.Module): 73 | def __init__( 74 | self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., 75 | proj_drop=0., window_size=None, attn_head_dim=None): 76 | super().__init__() 77 | self.num_heads = num_heads 78 | head_dim = dim // num_heads 79 | if attn_head_dim is not None: 80 | head_dim = attn_head_dim 81 | all_head_dim = head_dim * self.num_heads 82 | self.scale = qk_scale or head_dim ** -0.5 83 | 84 | self.qkv = nn.Linear(dim, all_head_dim * 3, bias=qkv_bias) 85 | 86 | if window_size: 87 | self.window_size = window_size 88 | self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 89 | self.relative_position_bias_table = nn.Parameter( 90 | torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH 91 | # cls to token & token 2 cls & cls to cls 92 | 93 | # get pair-wise relative position index for each token inside the window 94 | coords_h = torch.arange(window_size[0]) 95 | coords_w = torch.arange(window_size[1]) 96 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww 97 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww 98 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww 99 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 100 | relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 101 | relative_coords[:, :, 1] += window_size[1] - 1 102 | relative_coords[:, :, 0] *= 2 * window_size[1] - 1 103 | relative_position_index = \ 104 | torch.zeros(size=(window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype) 105 | relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww 106 | relative_position_index[0, 0:] = self.num_relative_distance - 3 107 | relative_position_index[0:, 0] = self.num_relative_distance - 2 108 | relative_position_index[0, 0] = self.num_relative_distance - 1 109 | 110 | self.register_buffer("relative_position_index", relative_position_index) 111 | else: 112 | self.window_size = None 113 | self.relative_position_bias_table = None 114 | self.relative_position_index = None 115 | 116 | self.attn_drop = nn.Dropout(attn_drop) 117 | self.proj = nn.Linear(all_head_dim, dim) 118 | self.proj_drop = nn.Dropout(proj_drop) 119 | 120 | def forward(self, x, rel_pos_bias=None): 121 | B, N, C = x.shape 122 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 123 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 124 | 125 | q = q * self.scale 126 | attn = 
(q @ k.transpose(-2, -1)) 127 | 128 | if self.relative_position_bias_table is not None: 129 | relative_position_bias = \ 130 | self.relative_position_bias_table[self.relative_position_index.view(-1)].view( 131 | self.window_size[0] * self.window_size[1] + 1, 132 | self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH 133 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww 134 | attn = attn + relative_position_bias.unsqueeze(0) 135 | 136 | if rel_pos_bias is not None: 137 | attn = attn + rel_pos_bias 138 | 139 | attn = attn.softmax(dim=-1) 140 | attn = self.attn_drop(attn) 141 | 142 | x = (attn @ v).transpose(1, 2).reshape(B, N, -1) 143 | x = self.proj(x) 144 | x = self.proj_drop(x) 145 | return x 146 | 147 | 148 | class Block(nn.Module): 149 | 150 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 151 | drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm, 152 | window_size=None, attn_head_dim=None): 153 | super().__init__() 154 | self.norm1 = norm_layer(dim) 155 | self.attn = Attention( 156 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, 157 | attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim) 158 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 159 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 160 | self.norm2 = norm_layer(dim) 161 | mlp_hidden_dim = int(dim * mlp_ratio) 162 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 163 | 164 | if init_values > 0: 165 | self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True) 166 | self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True) 167 | else: 168 | self.gamma_1, self.gamma_2 = None, None 169 | 170 | def forward(self, x, rel_pos_bias=None): 171 | if self.gamma_1 is None: 172 | x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias)) 173 | x = x + self.drop_path(self.mlp(self.norm2(x))) 174 | else: 175 | x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias)) 176 | x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) 177 | return x 178 | 179 | 180 | class PatchEmbed(nn.Module): 181 | """ Image to Patch Embedding 182 | """ 183 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): 184 | super().__init__() 185 | img_size = to_2tuple(img_size) 186 | patch_size = to_2tuple(patch_size) 187 | num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) 188 | self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) 189 | self.img_size = img_size 190 | self.patch_size = patch_size 191 | self.num_patches = num_patches 192 | 193 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 194 | 195 | def forward(self, x, **kwargs): 196 | B, C, H, W = x.shape 197 | # FIXME look at relaxing size constraints 198 | assert H == self.img_size[0] and W == self.img_size[1], \ 199 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
200 | x = self.proj(x).flatten(2).transpose(1, 2) 201 | return x 202 | 203 | 204 | class RelativePositionBias(nn.Module): 205 | 206 | def __init__(self, window_size, num_heads): 207 | super().__init__() 208 | self.window_size = window_size 209 | self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 210 | self.relative_position_bias_table = nn.Parameter( 211 | torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH 212 | # cls to token & token 2 cls & cls to cls 213 | 214 | # get pair-wise relative position index for each token inside the window 215 | coords_h = torch.arange(window_size[0]) 216 | coords_w = torch.arange(window_size[1]) 217 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww 218 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww 219 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww 220 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 221 | relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 222 | relative_coords[:, :, 1] += window_size[1] - 1 223 | relative_coords[:, :, 0] *= 2 * window_size[1] - 1 224 | relative_position_index = \ 225 | torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype) 226 | relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww 227 | relative_position_index[0, 0:] = self.num_relative_distance - 3 228 | relative_position_index[0:, 0] = self.num_relative_distance - 2 229 | relative_position_index[0, 0] = self.num_relative_distance - 1 230 | 231 | self.register_buffer("relative_position_index", relative_position_index) 232 | 233 | # trunc_normal_(self.relative_position_bias_table, std=.02) 234 | 235 | def forward(self): 236 | relative_position_bias = \ 237 | self.relative_position_bias_table[self.relative_position_index.view(-1)].view( 238 | self.window_size[0] * self.window_size[1] + 1, 239 | self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH 240 | return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww 241 | 242 | 243 | class VisionTransformer(nn.Module): 244 | """ Vision Transformer with support for patch or hybrid CNN input stage 245 | """ 246 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, 247 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., 248 | drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None, 249 | use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False, 250 | use_mean_pooling=True, init_scale=0.001): 251 | super().__init__() 252 | self.num_classes = num_classes 253 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 254 | 255 | self.patch_embed = PatchEmbed( 256 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 257 | num_patches = self.patch_embed.num_patches 258 | 259 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 260 | # self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 261 | if use_abs_pos_emb: 262 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) 263 | else: 264 | self.pos_embed = None 265 | self.pos_drop = nn.Dropout(p=drop_rate) 266 | 267 | if use_shared_rel_pos_bias: 268 | self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads) 269 | else: 270 | self.rel_pos_bias = None 
271 | 272 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 273 | self.use_rel_pos_bias = use_rel_pos_bias 274 | self.blocks = nn.ModuleList([ 275 | Block( 276 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 277 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, 278 | init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None) 279 | for i in range(depth)]) 280 | self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim) 281 | self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None 282 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() 283 | 284 | if self.pos_embed is not None: 285 | trunc_normal_(self.pos_embed, std=.02) 286 | trunc_normal_(self.cls_token, std=.02) 287 | # trunc_normal_(self.mask_token, std=.02) 288 | trunc_normal_(self.head.weight, std=.02) 289 | self.apply(self._init_weights) 290 | self.fix_init_weight() 291 | 292 | self.head.weight.data.mul_(init_scale) 293 | self.head.bias.data.mul_(init_scale) 294 | 295 | def fix_init_weight(self): 296 | def rescale(param, layer_id): 297 | param.div_(math.sqrt(2.0 * layer_id)) 298 | 299 | for layer_id, layer in enumerate(self.blocks): 300 | rescale(layer.attn.proj.weight.data, layer_id + 1) 301 | rescale(layer.mlp.fc2.weight.data, layer_id + 1) 302 | 303 | def _init_weights(self, m): 304 | if isinstance(m, nn.Linear): 305 | trunc_normal_(m.weight, std=.02) 306 | if isinstance(m, nn.Linear) and m.bias is not None: 307 | nn.init.constant_(m.bias, 0) 308 | elif isinstance(m, nn.LayerNorm): 309 | nn.init.constant_(m.bias, 0) 310 | nn.init.constant_(m.weight, 1.0) 311 | 312 | def get_num_layers(self): 313 | return len(self.blocks) 314 | 315 | @torch.jit.ignore 316 | def no_weight_decay(self): 317 | return {'pos_embed', 'cls_token'} 318 | 319 | def get_classifier(self): 320 | return self.head 321 | 322 | def reset_classifier(self, num_classes, global_pool=''): 323 | self.num_classes = num_classes 324 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 325 | 326 | def forward_features(self, x): 327 | x = self.patch_embed(x) 328 | batch_size, seq_len, _ = x.size() 329 | 330 | cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 331 | x = torch.cat((cls_tokens, x), dim=1) 332 | if self.pos_embed is not None: 333 | x = x + self.pos_embed 334 | x = self.pos_drop(x) 335 | 336 | rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None 337 | for blk in self.blocks: 338 | x = blk(x, rel_pos_bias=rel_pos_bias) 339 | 340 | x = self.norm(x) 341 | if self.fc_norm is not None: 342 | t = x[:, 1:, :] 343 | return self.fc_norm(t.mean(1)) 344 | else: 345 | return x[:, 0] 346 | 347 | def forward(self, x): 348 | x = self.forward_features(x) 349 | x = self.head(x) 350 | return x 351 | 352 | 353 | @register_model 354 | def beit_small_patch16_224(pretrained=False, **kwargs): 355 | model = VisionTransformer( 356 | patch_size=16, embed_dim=384, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 357 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 358 | model.default_cfg = _cfg() 359 | return model 360 | 361 | 362 | @register_model 363 | def beit_base_patch16_224(pretrained=False, **kwargs): 364 | model = VisionTransformer( 365 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 366 | 
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 367 | model.default_cfg = _cfg() 368 | return model 369 | 370 | 371 | @register_model 372 | def beit_base_patch16_384(pretrained=False, **kwargs): 373 | model = VisionTransformer( 374 | img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 375 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 376 | model.default_cfg = _cfg() 377 | return model 378 | 379 | 380 | @register_model 381 | def beit_large_patch16_224(pretrained=False, **kwargs): 382 | model = VisionTransformer( 383 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, 384 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 385 | model.default_cfg = _cfg() 386 | return model 387 | 388 | 389 | @register_model 390 | def beit_large_patch16_384(pretrained=False, **kwargs): 391 | model = VisionTransformer( 392 | img_size=384, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, 393 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 394 | model.default_cfg = _cfg() 395 | return model 396 | 397 | 398 | @register_model 399 | def beit_large_patch16_512(pretrained=False, **kwargs): 400 | model = VisionTransformer( 401 | img_size=512, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, 402 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 403 | model.default_cfg = _cfg() 404 | return model 405 | -------------------------------------------------------------------------------- /beit_finetuning/optim_factory.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | # -------------------------------------------------------- 8 | # BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254) 9 | # Github source: https://github.com/microsoft/unilm/tree/master/beit 10 | # Copyright (c) 2021 Microsoft 11 | # Licensed under The MIT License [see LICENSE for details] 12 | # By Hangbo Bao 13 | # Based on timm code bases 14 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm 15 | # --------------------------------------------------------' 16 | import torch 17 | from torch import optim as optim 18 | 19 | from timm.optim.adafactor import Adafactor 20 | from timm.optim.adahessian import Adahessian 21 | from timm.optim.adamp import AdamP 22 | from timm.optim.lookahead import Lookahead 23 | from timm.optim.nadam import Nadam 24 | # from timm.optim.novograd import NovoGrad 25 | from timm.optim.nvnovograd import NvNovoGrad 26 | from timm.optim.radam import RAdam 27 | from timm.optim.rmsprop_tf import RMSpropTF 28 | from timm.optim.sgdp import SGDP 29 | 30 | import json 31 | 32 | try: 33 | from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD 34 | has_apex = True 35 | except ImportError: 36 | has_apex = False 37 | 38 | 39 | def get_num_layer_for_vit(var_name, num_max_layer): 40 | if var_name in ("cls_token", "mask_token", "pos_embed"): 41 | return 0 42 | elif var_name.startswith("patch_embed"): 43 | return 0 44 | elif var_name.startswith("rel_pos_bias"): 45 | return num_max_layer - 1 46 | elif var_name.startswith("blocks"): 47 | layer_id = int(var_name.split('.')[1]) 48 | return layer_id + 1 49 | else: 50 | return num_max_layer - 1 51 | 52 | 53 | class LayerDecayValueAssigner(object): 54 | def __init__(self, values): 55 | self.values = values 56 | 57 | def get_scale(self, layer_id): 58 | return self.values[layer_id] 59 | 60 | def get_layer_id(self, var_name): 61 | return get_num_layer_for_vit(var_name, len(self.values)) 62 | 63 | 64 | def get_parameter_groups(model, weight_decay=1e-5, skip_list=(), get_num_layer=None, get_layer_scale=None): 65 | parameter_group_names = {} 66 | parameter_group_vars = {} 67 | 68 | for name, param in model.named_parameters(): 69 | if not param.requires_grad: 70 | continue # frozen weights 71 | if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list: 72 | group_name = "no_decay" 73 | this_weight_decay = 0. 74 | else: 75 | group_name = "decay" 76 | this_weight_decay = weight_decay 77 | if get_num_layer is not None: 78 | layer_id = get_num_layer(name) 79 | group_name = "layer_%d_%s" % (layer_id, group_name) 80 | else: 81 | layer_id = None 82 | 83 | if group_name not in parameter_group_names: 84 | if get_layer_scale is not None: 85 | scale = get_layer_scale(layer_id) 86 | else: 87 | scale = 1. 
88 | 89 | parameter_group_names[group_name] = { 90 | "weight_decay": this_weight_decay, 91 | "params": [], 92 | "lr_scale": scale 93 | } 94 | parameter_group_vars[group_name] = { 95 | "weight_decay": this_weight_decay, 96 | "params": [], 97 | "lr_scale": scale 98 | } 99 | 100 | parameter_group_vars[group_name]["params"].append(param) 101 | parameter_group_names[group_name]["params"].append(name) 102 | print("Param groups = %s" % json.dumps(parameter_group_names, indent=2)) 103 | return list(parameter_group_vars.values()) 104 | 105 | 106 | def create_optimizer(args, model, get_num_layer=None, get_layer_scale=None, filter_bias_and_bn=True, skip_list=None): 107 | opt_lower = args.opt.lower() 108 | weight_decay = args.weight_decay 109 | if weight_decay and filter_bias_and_bn: 110 | skip = {} 111 | if skip_list is not None: 112 | skip = skip_list 113 | elif hasattr(model, 'no_weight_decay'): 114 | skip = model.no_weight_decay() 115 | parameters = get_parameter_groups(model, weight_decay, skip, get_num_layer, get_layer_scale) 116 | weight_decay = 0. 117 | else: 118 | parameters = model.parameters() 119 | 120 | if 'fused' in opt_lower: 121 | assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers' 122 | 123 | opt_args = dict(lr=args.lr, weight_decay=weight_decay) 124 | if hasattr(args, 'opt_eps') and args.opt_eps is not None: 125 | opt_args['eps'] = args.opt_eps 126 | if hasattr(args, 'opt_betas') and args.opt_betas is not None: 127 | opt_args['betas'] = args.opt_betas 128 | 129 | opt_split = opt_lower.split('_') 130 | opt_lower = opt_split[-1] 131 | if opt_lower == 'sgd' or opt_lower == 'nesterov': 132 | opt_args.pop('eps', None) 133 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=True, **opt_args) 134 | elif opt_lower == 'momentum': 135 | opt_args.pop('eps', None) 136 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=False, **opt_args) 137 | elif opt_lower == 'adam': 138 | optimizer = optim.Adam(parameters, **opt_args) 139 | elif opt_lower == 'adamw': 140 | optimizer = optim.AdamW(parameters, **opt_args) 141 | elif opt_lower == 'nadam': 142 | optimizer = Nadam(parameters, **opt_args) 143 | elif opt_lower == 'radam': 144 | optimizer = RAdam(parameters, **opt_args) 145 | elif opt_lower == 'adamp': 146 | optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args) 147 | elif opt_lower == 'sgdp': 148 | optimizer = SGDP(parameters, momentum=args.momentum, nesterov=True, **opt_args) 149 | elif opt_lower == 'adadelta': 150 | optimizer = optim.Adadelta(parameters, **opt_args) 151 | elif opt_lower == 'adafactor': 152 | if not args.lr: 153 | opt_args['lr'] = None 154 | optimizer = Adafactor(parameters, **opt_args) 155 | elif opt_lower == 'adahessian': 156 | optimizer = Adahessian(parameters, **opt_args) 157 | elif opt_lower == 'rmsprop': 158 | optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=args.momentum, **opt_args) 159 | elif opt_lower == 'rmsproptf': 160 | optimizer = RMSpropTF(parameters, alpha=0.9, momentum=args.momentum, **opt_args) 161 | # elif opt_lower == 'novograd': 162 | # optimizer = NovoGrad(parameters, **opt_args) 163 | elif opt_lower == 'nvnovograd': 164 | optimizer = NvNovoGrad(parameters, **opt_args) 165 | elif opt_lower == 'fusedsgd': 166 | opt_args.pop('eps', None) 167 | optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=True, **opt_args) 168 | elif opt_lower == 'fusedmomentum': 169 | opt_args.pop('eps', None) 170 | optimizer = FusedSGD(parameters, momentum=args.momentum, 
nesterov=False, **opt_args) 171 | elif opt_lower == 'fusedadam': 172 | optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args) 173 | elif opt_lower == 'fusedadamw': 174 | optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args) 175 | elif opt_lower == 'fusedlamb': 176 | optimizer = FusedLAMB(parameters, **opt_args) 177 | elif opt_lower == 'fusednovograd': 178 | opt_args.setdefault('betas', (0.95, 0.98)) 179 | optimizer = FusedNovoGrad(parameters, **opt_args) 180 | else: 181 | assert False and "Invalid optimizer" 182 | raise ValueError 183 | 184 | if len(opt_split) > 1: 185 | if opt_split[0] == 'lookahead': 186 | optimizer = Lookahead(optimizer) 187 | 188 | return optimizer 189 | -------------------------------------------------------------------------------- /beit_finetuning/requirements.txt: -------------------------------------------------------------------------------- 1 | Pillow 2 | blobfile 3 | mypy 4 | numpy 5 | pytest 6 | requests 7 | einops 8 | tensorboardX 9 | deepspeed==0.4.0 10 | scipy -------------------------------------------------------------------------------- /beit_finetuning/run_class_finetuning.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # -------------------------------------------------------- 8 | # BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254) 9 | # Github source: https://github.com/microsoft/unilm/tree/master/beit 10 | # Copyright (c) 2021 Microsoft 11 | # Licensed under The MIT License [see LICENSE for details] 12 | # By Hangbo Bao 13 | # Based on timm, DINO and DeiT code bases 14 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm 15 | # https://github.com/facebookresearch/deit 16 | # https://github.com/facebookresearch/dino 17 | # --------------------------------------------------------' 18 | import argparse 19 | import datetime 20 | import numpy as np 21 | import time 22 | import torch 23 | import torch.backends.cudnn as cudnn 24 | import json 25 | import os 26 | 27 | from pathlib import Path 28 | 29 | from timm.data.mixup import Mixup 30 | from timm.models import create_model 31 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy 32 | from timm.utils import ModelEma 33 | from optim_factory import create_optimizer, get_parameter_groups, LayerDecayValueAssigner 34 | 35 | from datasets import build_dataset 36 | from engine_for_finetuning import train_one_epoch, evaluate 37 | from utils import NativeScalerWithGradNormCount as NativeScaler 38 | import utils 39 | from scipy import interpolate 40 | import modeling_finetune 41 | 42 | 43 | def get_args_parser(): 44 | parser = argparse.ArgumentParser('BEiT fine-tuning and evaluation script for image classification', add_help=False) 45 | parser.add_argument('--batch_size', default=64, type=int) 46 | parser.add_argument('--epochs', default=30, type=int) 47 | parser.add_argument('--update_freq', default=1, type=int) 48 | parser.add_argument('--save_ckpt_freq', default=5, type=int) 49 | 50 | # Model parameters 51 | parser.add_argument('--model', default='deit_base_patch16_224', type=str, metavar='MODEL', 52 | help='Name of model to train') 53 | parser.add_argument('--rel_pos_bias', action='store_true') 54 | parser.add_argument('--disable_rel_pos_bias', action='store_false', 
dest='rel_pos_bias') 55 | parser.set_defaults(rel_pos_bias=True) 56 | parser.add_argument('--abs_pos_emb', action='store_true') 57 | parser.set_defaults(abs_pos_emb=False) 58 | parser.add_argument('--layer_scale_init_value', default=0.1, type=float, 59 | help="0.1 for base, 1e-5 for large. set 0 to disable layer scale") 60 | 61 | parser.add_argument('--input_size', default=224, type=int, 62 | help='images input size') 63 | 64 | parser.add_argument('--drop', type=float, default=0.0, metavar='PCT', 65 | help='Dropout rate (default: 0.)') 66 | parser.add_argument('--attn_drop_rate', type=float, default=0.0, metavar='PCT', 67 | help='Attention dropout rate (default: 0.)') 68 | parser.add_argument('--drop_path', type=float, default=0.1, metavar='PCT', 69 | help='Drop path rate (default: 0.1)') 70 | 71 | parser.add_argument('--disable_eval_during_finetuning', action='store_true', default=False) 72 | 73 | parser.add_argument('--model_ema', action='store_true', default=False) 74 | parser.add_argument('--model_ema_decay', type=float, default=0.9999, help='') 75 | parser.add_argument('--model_ema_force_cpu', action='store_true', default=False, help='') 76 | 77 | # Optimizer parameters 78 | parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER', 79 | help='Optimizer (default: "adamw"') 80 | parser.add_argument('--opt_eps', default=1e-8, type=float, metavar='EPSILON', 81 | help='Optimizer Epsilon (default: 1e-8)') 82 | parser.add_argument('--opt_betas', default=None, type=float, nargs='+', metavar='BETA', 83 | help='Optimizer Betas (default: None, use opt default)') 84 | parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM', 85 | help='Clip gradient norm (default: None, no clipping)') 86 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', 87 | help='SGD momentum (default: 0.9)') 88 | parser.add_argument('--weight_decay', type=float, default=0.05, 89 | help='weight decay (default: 0.05)') 90 | parser.add_argument('--weight_decay_end', type=float, default=None, help="""Final value of the 91 | weight decay. We use a cosine schedule for WD and using a larger decay by 92 | the end of training improves performance for ViTs.""") 93 | 94 | parser.add_argument('--lr', type=float, default=5e-4, metavar='LR', 95 | help='learning rate (default: 5e-4)') 96 | parser.add_argument('--layer_decay', type=float, default=0.9) 97 | 98 | parser.add_argument('--warmup_lr', type=float, default=1e-6, metavar='LR', 99 | help='warmup learning rate (default: 1e-6)') 100 | parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR', 101 | help='lower lr bound for cyclic schedulers that hit 0 (1e-5)') 102 | 103 | parser.add_argument('--warmup_epochs', type=int, default=5, metavar='N', 104 | help='epochs to warmup LR, if scheduler supports') 105 | parser.add_argument('--warmup_steps', type=int, default=-1, metavar='N', 106 | help='num of steps to warmup LR, will overload warmup_epochs if set > 0') 107 | 108 | # Augmentation parameters 109 | parser.add_argument('--color_jitter', type=float, default=0.4, metavar='PCT', 110 | help='Color jitter factor (default: 0.4)') 111 | parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME', 112 | help='Use AutoAugment policy. "v0" or "original". 
" + "(default: rand-m9-mstd0.5-inc1)'), 113 | parser.add_argument('--smoothing', type=float, default=0.1, 114 | help='Label smoothing (default: 0.1)') 115 | parser.add_argument('--train_interpolation', type=str, default='bicubic', 116 | help='Training interpolation (random, bilinear, bicubic default: "bicubic")') 117 | 118 | # Evaluation parameters 119 | parser.add_argument('--crop_pct', type=float, default=None) 120 | 121 | # * Random Erase params 122 | parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT', 123 | help='Random erase prob (default: 0.25)') 124 | parser.add_argument('--remode', type=str, default='pixel', 125 | help='Random erase mode (default: "pixel")') 126 | parser.add_argument('--recount', type=int, default=1, 127 | help='Random erase count (default: 1)') 128 | parser.add_argument('--resplit', action='store_true', default=False, 129 | help='Do not random erase first (clean) augmentation split') 130 | 131 | # * Mixup params 132 | parser.add_argument('--mixup', type=float, default=0, 133 | help='mixup alpha, mixup enabled if > 0.') 134 | parser.add_argument('--cutmix', type=float, default=0, 135 | help='cutmix alpha, cutmix enabled if > 0.') 136 | parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None, 137 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)') 138 | parser.add_argument('--mixup_prob', type=float, default=1.0, 139 | help='Probability of performing mixup or cutmix when either/both is enabled') 140 | parser.add_argument('--mixup_switch_prob', type=float, default=0.5, 141 | help='Probability of switching to cutmix when both mixup and cutmix enabled') 142 | parser.add_argument('--mixup_mode', type=str, default='batch', 143 | help='How to apply mixup/cutmix params. 
Per "batch", "pair", or "elem"') 144 | 145 | # * Finetuning params 146 | parser.add_argument('--finetune', default='', 147 | help='finetune from checkpoint') 148 | parser.add_argument('--model_key', default='model|module', type=str) 149 | parser.add_argument('--model_prefix', default='', type=str) 150 | parser.add_argument('--init_scale', default=0.001, type=float) 151 | parser.add_argument('--use_mean_pooling', action='store_true') 152 | parser.set_defaults(use_mean_pooling=True) 153 | parser.add_argument('--use_cls', action='store_false', dest='use_mean_pooling') 154 | parser.add_argument('--disable_weight_decay_on_rel_pos_bias', action='store_true', default=False) 155 | 156 | # Dataset parameters 157 | parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str, 158 | help='dataset path') 159 | parser.add_argument('--eval_data_path', default=None, type=str, 160 | help='dataset path for evaluation') 161 | parser.add_argument('--nb_classes', default=0, type=int, 162 | help='number of the classification types') 163 | parser.add_argument('--imagenet_default_mean_and_std', default=False, action='store_true') 164 | 165 | parser.add_argument('--data_set', default='IMNET', choices=['CIFAR', 'IMNET', 'image_folder'], 166 | type=str, help='ImageNet dataset path') 167 | parser.add_argument('--output_dir', default='', 168 | help='path where to save, empty for no saving') 169 | parser.add_argument('--log_dir', default=None, 170 | help='path where to tensorboard log') 171 | parser.add_argument('--device', default='cuda', 172 | help='device to use for training / testing') 173 | parser.add_argument('--seed', default=0, type=int) 174 | parser.add_argument('--resume', default='', 175 | help='resume from checkpoint') 176 | parser.add_argument('--auto_resume', action='store_true') 177 | parser.add_argument('--no_auto_resume', action='store_false', dest='auto_resume') 178 | parser.set_defaults(auto_resume=True) 179 | 180 | parser.add_argument('--save_ckpt', action='store_true') 181 | parser.add_argument('--no_save_ckpt', action='store_false', dest='save_ckpt') 182 | parser.set_defaults(save_ckpt=True) 183 | 184 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 185 | help='start epoch') 186 | parser.add_argument('--eval', action='store_true', 187 | help='Perform evaluation only') 188 | parser.add_argument('--dist_eval', action='store_true', default=False, 189 | help='Enabling distributed evaluation') 190 | parser.add_argument('--num_workers', default=10, type=int) 191 | parser.add_argument('--pin_mem', action='store_true', 192 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 193 | parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem') 194 | parser.set_defaults(pin_mem=True) 195 | 196 | # distributed training parameters 197 | parser.add_argument('--world_size', default=1, type=int, 198 | help='number of distributed processes') 199 | parser.add_argument('--local_rank', default=-1, type=int) 200 | parser.add_argument('--dist_on_itp', action='store_true') 201 | parser.add_argument('--dist_url', default='env://', 202 | help='url used to set up distributed training') 203 | 204 | parser.add_argument('--enable_deepspeed', action='store_true', default=False) 205 | 206 | return parser 207 | 208 | 209 | def main(args, parser): 210 | if args.enable_deepspeed: 211 | try: 212 | import deepspeed 213 | from deepspeed import DeepSpeedConfig 214 | parser = deepspeed.add_config_arguments(parser) 215 | ds_init = 
deepspeed.initialize 216 | except: 217 | print("Please 'pip install deepspeed==0.4.0'") 218 | exit(0) 219 | else: 220 | ds_init = None 221 | 222 | utils.init_distributed_mode(args) 223 | 224 | if ds_init is not None: 225 | utils.create_ds_config(args) 226 | 227 | print(args) 228 | 229 | device = torch.device(args.device) 230 | 231 | # fix the seed for reproducibility 232 | seed = args.seed + utils.get_rank() 233 | torch.manual_seed(seed) 234 | np.random.seed(seed) 235 | # random.seed(seed) 236 | 237 | cudnn.benchmark = True 238 | 239 | dataset_train, args.nb_classes = build_dataset(is_train=True, args=args) 240 | if args.disable_eval_during_finetuning: 241 | dataset_val = None 242 | else: 243 | dataset_val, _ = build_dataset(is_train=False, args=args) 244 | 245 | if True: # args.distributed: 246 | num_tasks = utils.get_world_size() 247 | global_rank = utils.get_rank() 248 | sampler_train = torch.utils.data.DistributedSampler( 249 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 250 | ) 251 | print("Sampler_train = %s" % str(sampler_train)) 252 | if args.dist_eval: 253 | if len(dataset_val) % num_tasks != 0: 254 | print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. ' 255 | 'This will slightly alter validation results as extra duplicate entries are added to achieve ' 256 | 'equal num of samples per-process.') 257 | sampler_val = torch.utils.data.DistributedSampler( 258 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False) 259 | else: 260 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 261 | else: 262 | sampler_train = torch.utils.data.RandomSampler(dataset_train) 263 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 264 | 265 | if global_rank == 0 and args.log_dir is not None: 266 | os.makedirs(args.log_dir, exist_ok=True) 267 | log_writer = utils.TensorboardLogger(log_dir=args.log_dir) 268 | else: 269 | log_writer = None 270 | 271 | data_loader_train = torch.utils.data.DataLoader( 272 | dataset_train, sampler=sampler_train, 273 | batch_size=args.batch_size, 274 | num_workers=args.num_workers, 275 | pin_memory=args.pin_mem, 276 | drop_last=True, 277 | ) 278 | 279 | if dataset_val is not None: 280 | data_loader_val = torch.utils.data.DataLoader( 281 | dataset_val, sampler=sampler_val, 282 | batch_size=int(1.5 * args.batch_size), 283 | num_workers=args.num_workers, 284 | pin_memory=args.pin_mem, 285 | drop_last=False 286 | ) 287 | else: 288 | data_loader_val = None 289 | 290 | mixup_fn = None 291 | mixup_active = args.mixup > 0 or args.cutmix > 0. 
or args.cutmix_minmax is not None 292 | if mixup_active: 293 | print("Mixup is activated!") 294 | mixup_fn = Mixup( 295 | mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, 296 | prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, 297 | label_smoothing=args.smoothing, num_classes=args.nb_classes) 298 | 299 | model = create_model( 300 | args.model, 301 | pretrained=False, 302 | num_classes=args.nb_classes, 303 | drop_rate=args.drop, 304 | drop_path_rate=args.drop_path, 305 | attn_drop_rate=args.attn_drop_rate, 306 | drop_block_rate=None, 307 | use_mean_pooling=args.use_mean_pooling, 308 | init_scale=args.init_scale, 309 | use_rel_pos_bias=args.rel_pos_bias, 310 | use_abs_pos_emb=args.abs_pos_emb, 311 | init_values=args.layer_scale_init_value, 312 | ) 313 | 314 | patch_size = model.patch_embed.patch_size 315 | print("Patch size = %s" % str(patch_size)) 316 | args.window_size = (args.input_size // patch_size[0], args.input_size // patch_size[1]) 317 | args.patch_size = patch_size 318 | 319 | if args.finetune: 320 | if args.finetune.startswith('https'): 321 | checkpoint = torch.hub.load_state_dict_from_url( 322 | args.finetune, map_location='cpu', check_hash=True) 323 | else: 324 | checkpoint = torch.load(args.finetune, map_location='cpu') 325 | 326 | print("Load ckpt from %s" % args.finetune) 327 | checkpoint_model = None 328 | for model_key in args.model_key.split('|'): 329 | if model_key in checkpoint: 330 | checkpoint_model = checkpoint[model_key] 331 | print("Load state_dict by model_key = %s" % model_key) 332 | break 333 | if checkpoint_model is None: 334 | checkpoint_model = checkpoint 335 | state_dict = model.state_dict() 336 | for k in ['head.weight', 'head.bias']: 337 | if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape: 338 | print(f"Removing key {k} from pretrained checkpoint") 339 | del checkpoint_model[k] 340 | 341 | if model.use_rel_pos_bias and "rel_pos_bias.relative_position_bias_table" in checkpoint_model: 342 | print("Expand the shared relative position embedding to each transformer block. 
") 343 | num_layers = model.get_num_layers() 344 | rel_pos_bias = checkpoint_model["rel_pos_bias.relative_position_bias_table"] 345 | for i in range(num_layers): 346 | checkpoint_model["blocks.%d.attn.relative_position_bias_table" % i] = rel_pos_bias.clone() 347 | 348 | checkpoint_model.pop("rel_pos_bias.relative_position_bias_table") 349 | 350 | all_keys = list(checkpoint_model.keys()) 351 | for key in all_keys: 352 | if "relative_position_index" in key: 353 | checkpoint_model.pop(key) 354 | 355 | if "relative_position_bias_table" in key: 356 | rel_pos_bias = checkpoint_model[key] 357 | src_num_pos, num_attn_heads = rel_pos_bias.size() 358 | dst_num_pos, _ = model.state_dict()[key].size() 359 | dst_patch_shape = model.patch_embed.patch_shape 360 | if dst_patch_shape[0] != dst_patch_shape[1]: 361 | raise NotImplementedError() 362 | num_extra_tokens = dst_num_pos - (dst_patch_shape[0] * 2 - 1) * (dst_patch_shape[1] * 2 - 1) 363 | src_size = int((src_num_pos - num_extra_tokens) ** 0.5) 364 | dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5) 365 | if src_size != dst_size: 366 | print("Position interpolate for %s from %dx%d to %dx%d" % ( 367 | key, src_size, src_size, dst_size, dst_size)) 368 | extra_tokens = rel_pos_bias[-num_extra_tokens:, :] 369 | rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :] 370 | 371 | def geometric_progression(a, r, n): 372 | return a * (1.0 - r ** n) / (1.0 - r) 373 | 374 | left, right = 1.01, 1.5 375 | while right - left > 1e-6: 376 | q = (left + right) / 2.0 377 | gp = geometric_progression(1, q, src_size // 2) 378 | if gp > dst_size // 2: 379 | right = q 380 | else: 381 | left = q 382 | 383 | # if q > 1.090307: 384 | # q = 1.090307 385 | 386 | dis = [] 387 | cur = 1 388 | for i in range(src_size // 2): 389 | dis.append(cur) 390 | cur += q ** (i + 1) 391 | 392 | r_ids = [-_ for _ in reversed(dis)] 393 | 394 | x = r_ids + [0] + dis 395 | y = r_ids + [0] + dis 396 | 397 | t = dst_size // 2.0 398 | dx = np.arange(-t, t + 0.1, 1.0) 399 | dy = np.arange(-t, t + 0.1, 1.0) 400 | 401 | print("Original positions = %s" % str(x)) 402 | print("Target positions = %s" % str(dx)) 403 | 404 | all_rel_pos_bias = [] 405 | 406 | for i in range(num_attn_heads): 407 | z = rel_pos_bias[:, i].view(src_size, src_size).float().numpy() 408 | f = interpolate.interp2d(x, y, z, kind='cubic') 409 | all_rel_pos_bias.append( 410 | torch.Tensor(f(dx, dy)).contiguous().view(-1, 1).to(rel_pos_bias.device)) 411 | 412 | rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1) 413 | 414 | new_rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0) 415 | checkpoint_model[key] = new_rel_pos_bias 416 | 417 | # interpolate position embedding 418 | if 'pos_embed' in checkpoint_model: 419 | pos_embed_checkpoint = checkpoint_model['pos_embed'] 420 | embedding_size = pos_embed_checkpoint.shape[-1] 421 | num_patches = model.patch_embed.num_patches 422 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches 423 | # height (== width) for the checkpoint position embedding 424 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 425 | # height (== width) for the new position embedding 426 | new_size = int(num_patches ** 0.5) 427 | # class_token and dist_token are kept unchanged 428 | if orig_size != new_size: 429 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) 430 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 431 | # only the position tokens are interpolated 432 | pos_tokens = pos_embed_checkpoint[:, 
num_extra_tokens:] 433 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 434 | pos_tokens = torch.nn.functional.interpolate( 435 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 436 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 437 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 438 | checkpoint_model['pos_embed'] = new_pos_embed 439 | 440 | utils.load_state_dict(model, checkpoint_model, prefix=args.model_prefix) 441 | # model.load_state_dict(checkpoint_model, strict=False) 442 | 443 | model.to(device) 444 | 445 | model_ema = None 446 | if args.model_ema: 447 | # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper 448 | model_ema = ModelEma( 449 | model, 450 | decay=args.model_ema_decay, 451 | device='cpu' if args.model_ema_force_cpu else '', 452 | resume='') 453 | print("Using EMA with decay = %.8f" % args.model_ema_decay) 454 | 455 | model_without_ddp = model 456 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 457 | 458 | print("Model = %s" % str(model_without_ddp)) 459 | print('number of params:', n_parameters) 460 | 461 | total_batch_size = args.batch_size * args.update_freq * utils.get_world_size() 462 | num_training_steps_per_epoch = len(dataset_train) // total_batch_size 463 | print("LR = %.8f" % args.lr) 464 | print("Batch size = %d" % total_batch_size) 465 | print("Update frequent = %d" % args.update_freq) 466 | print("Number of training examples = %d" % len(dataset_train)) 467 | print("Number of training training per epoch = %d" % num_training_steps_per_epoch) 468 | 469 | num_layers = model_without_ddp.get_num_layers() 470 | if args.layer_decay < 1.0: 471 | assigner = LayerDecayValueAssigner(list(args.layer_decay ** (num_layers + 1 - i) for i in range(num_layers + 2))) 472 | else: 473 | assigner = None 474 | 475 | if assigner is not None: 476 | print("Assigned values = %s" % str(assigner.values)) 477 | 478 | skip_weight_decay_list = model.no_weight_decay() 479 | if args.disable_weight_decay_on_rel_pos_bias: 480 | for i in range(num_layers): 481 | skip_weight_decay_list.add("blocks.%d.attn.relative_position_bias_table" % i) 482 | 483 | if args.enable_deepspeed: 484 | loss_scaler = None 485 | optimizer_params = get_parameter_groups( 486 | model, args.weight_decay, skip_weight_decay_list, 487 | assigner.get_layer_id if assigner is not None else None, 488 | assigner.get_scale if assigner is not None else None) 489 | model, optimizer, _, _ = ds_init( 490 | args=args, model=model, model_parameters=optimizer_params, dist_init_required=not args.distributed, 491 | ) 492 | 493 | print("model.gradient_accumulation_steps() = %d" % model.gradient_accumulation_steps()) 494 | assert model.gradient_accumulation_steps() == args.update_freq 495 | else: 496 | if args.distributed: 497 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True) 498 | model_without_ddp = model.module 499 | 500 | optimizer = create_optimizer( 501 | args, model_without_ddp, skip_list=skip_weight_decay_list, 502 | get_num_layer=assigner.get_layer_id if assigner is not None else None, 503 | get_layer_scale=assigner.get_scale if assigner is not None else None) 504 | loss_scaler = NativeScaler() 505 | 506 | print("Use step level LR scheduler!") 507 | lr_schedule_values = utils.cosine_scheduler( 508 | args.lr, args.min_lr, args.epochs, num_training_steps_per_epoch, 509 | 
warmup_epochs=args.warmup_epochs, warmup_steps=args.warmup_steps, 510 | ) 511 | if args.weight_decay_end is None: 512 | args.weight_decay_end = args.weight_decay 513 | wd_schedule_values = utils.cosine_scheduler( 514 | args.weight_decay, args.weight_decay_end, args.epochs, num_training_steps_per_epoch) 515 | print("Max WD = %.7f, Min WD = %.7f" % (max(wd_schedule_values), min(wd_schedule_values))) 516 | 517 | if mixup_fn is not None: 518 | # smoothing is handled with mixup label transform 519 | criterion = SoftTargetCrossEntropy() 520 | elif args.smoothing > 0.: 521 | criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing) 522 | else: 523 | criterion = torch.nn.CrossEntropyLoss() 524 | 525 | print("criterion = %s" % str(criterion)) 526 | 527 | utils.auto_load_model( 528 | args=args, model=model, model_without_ddp=model_without_ddp, 529 | optimizer=optimizer, loss_scaler=loss_scaler, model_ema=model_ema) 530 | 531 | if args.eval: 532 | test_stats = evaluate(data_loader_val, model, device) 533 | print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") 534 | exit(0) 535 | 536 | print(f"Start training for {args.epochs} epochs") 537 | start_time = time.time() 538 | max_accuracy = 0.0 539 | for epoch in range(args.start_epoch, args.epochs): 540 | if args.distributed: 541 | data_loader_train.sampler.set_epoch(epoch) 542 | if log_writer is not None: 543 | log_writer.set_step(epoch * num_training_steps_per_epoch * args.update_freq) 544 | train_stats = train_one_epoch( 545 | model, criterion, data_loader_train, optimizer, 546 | device, epoch, loss_scaler, args.clip_grad, model_ema, mixup_fn, 547 | log_writer=log_writer, start_steps=epoch * num_training_steps_per_epoch, 548 | lr_schedule_values=lr_schedule_values, wd_schedule_values=wd_schedule_values, 549 | num_training_steps_per_epoch=num_training_steps_per_epoch, update_freq=args.update_freq, 550 | ) 551 | if args.output_dir and args.save_ckpt: 552 | if (epoch + 1) % args.save_ckpt_freq == 0 or epoch + 1 == args.epochs: 553 | utils.save_model( 554 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, 555 | loss_scaler=loss_scaler, epoch=epoch, model_ema=model_ema) 556 | if data_loader_val is not None: 557 | test_stats = evaluate(data_loader_val, model, device) 558 | print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") 559 | if max_accuracy < test_stats["acc1"]: 560 | max_accuracy = test_stats["acc1"] 561 | if args.output_dir and args.save_ckpt: 562 | utils.save_model( 563 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, 564 | loss_scaler=loss_scaler, epoch="best", model_ema=model_ema) 565 | 566 | print(f'Max accuracy: {max_accuracy:.2f}%') 567 | if log_writer is not None: 568 | log_writer.update(test_acc1=test_stats['acc1'], head="perf", step=epoch) 569 | log_writer.update(test_acc5=test_stats['acc5'], head="perf", step=epoch) 570 | log_writer.update(test_loss=test_stats['loss'], head="perf", step=epoch) 571 | 572 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 573 | **{f'test_{k}': v for k, v in test_stats.items()}, 574 | 'epoch': epoch, 575 | 'n_parameters': n_parameters} 576 | else: 577 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 578 | # **{f'test_{k}': v for k, v in test_stats.items()}, 579 | 'epoch': epoch, 580 | 'n_parameters': n_parameters} 581 | 582 | if args.output_dir and utils.is_main_process(): 583 | if log_writer is not None: 584 | 
log_writer.flush() 585 | with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f: 586 | f.write(json.dumps(log_stats) + "\n") 587 | 588 | total_time = time.time() - start_time 589 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 590 | print('Training time {}'.format(total_time_str)) 591 | 592 | 593 | if __name__ == '__main__': 594 | parser = get_args_parser() 595 | args, _ = parser.parse_known_args() 596 | if args.output_dir: 597 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 598 | main(args, parser) 599 | -------------------------------------------------------------------------------- /beit_finetuning/run_with_submitit_finetune.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """ 7 | A script to run multinode training with submitit. 8 | """ 9 | import argparse 10 | import os 11 | import uuid 12 | from pathlib import Path 13 | 14 | import run_class_finetuning 15 | import submitit 16 | 17 | 18 | def parse_args(): 19 | beit_parser = run_class_finetuning.get_args_parser() 20 | parser = argparse.ArgumentParser("Submitit for BEIT", parents=[beit_parser]) 21 | parser.add_argument("--ngpus", default=8, type=int, help="Number of gpus to request on each node") 22 | parser.add_argument("--nodes", default=1, type=int, help="Number of nodes to request") 23 | parser.add_argument("--timeout", default=2800, type=int, help="Duration of the job") 24 | parser.add_argument("--job_dir", default="", type=str, help="Job dir. Leave empty for automatic.") 25 | 26 | parser.add_argument("--partition", default="learnlab", type=str, help="Partition where to submit") 27 | parser.add_argument("--use_volta32", action='store_true', help="Big models? Use this") 28 | parser.add_argument('--comment', default="", type=str, 29 | help='Comment to pass to scheduler, e.g. priority message') 30 | args, _ = parser.parse_known_args() 31 | return args, parser 32 | 33 | 34 | def get_shared_folder() -> Path: 35 | user = os.getenv("USER") 36 | if Path("/checkpoint/").is_dir(): 37 | p = Path(f"/checkpoint/{user}/experiments/slip") 38 | p.mkdir(exist_ok=True) 39 | return p 40 | raise RuntimeError("No shared folder available") 41 | 42 | 43 | def get_init_file(): 44 | # Init file must not exist, but it's parent dir must exist. 
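# The file:// URI returned by get_init_file().as_uri() becomes args.dist_url in
# main() and in Trainer.checkpoint() below, so the torch.distributed rendezvous
# goes through this shared checkpoint folder rather than a TCP address.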
45 | os.makedirs(str(get_shared_folder()), exist_ok=True) 46 | init_file = get_shared_folder() / f"{uuid.uuid4().hex}_init" 47 | if init_file.exists(): 48 | os.remove(str(init_file)) 49 | return init_file 50 | 51 | 52 | class Trainer(object): 53 | def __init__(self, args, parser): 54 | self.args = args 55 | self.parser = parser 56 | 57 | def __call__(self): 58 | import run_class_finetuning 59 | 60 | self._setup_gpu_args() 61 | run_class_finetuning.main(self.args, self.parser) 62 | 63 | def checkpoint(self): 64 | import os 65 | import submitit 66 | 67 | self.args.dist_url = get_init_file().as_uri() 68 | print("Requeuing ", self.args) 69 | empty_trainer = type(self)(self.args) 70 | return submitit.helpers.DelayedSubmission(empty_trainer) 71 | 72 | def _setup_gpu_args(self): 73 | import submitit 74 | from pathlib import Path 75 | 76 | job_env = submitit.JobEnvironment() 77 | self.args.output_dir = Path(str(self.args.output_dir).replace("%j", str(job_env.job_id))) 78 | self.args.gpu = job_env.local_rank 79 | self.args.rank = job_env.global_rank 80 | self.args.world_size = job_env.num_tasks 81 | print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") 82 | 83 | 84 | def main(): 85 | args, parser = parse_args() 86 | 87 | if args.job_dir == "": 88 | args.job_dir = get_shared_folder() / "%j" 89 | 90 | # Note that the folder will depend on the job_id, to easily track experiments 91 | executor = submitit.AutoExecutor(folder=args.job_dir, slurm_max_num_timeout=30) 92 | 93 | num_gpus_per_node = args.ngpus 94 | nodes = args.nodes 95 | timeout_min = args.timeout 96 | 97 | partition = args.partition 98 | kwargs = {} 99 | if args.use_volta32: 100 | kwargs['slurm_constraint'] = 'volta32gb' 101 | if args.comment: 102 | kwargs['slurm_comment'] = args.comment 103 | 104 | executor.update_parameters( 105 | mem_gb=40 * num_gpus_per_node, 106 | gpus_per_node=num_gpus_per_node, 107 | tasks_per_node=num_gpus_per_node, # one task per GPU 108 | cpus_per_task=10, 109 | nodes=nodes, 110 | timeout_min=timeout_min, # max is 60 * 72 111 | # Below are cluster dependent parameters 112 | slurm_partition=partition, 113 | slurm_signal_delay_s=120, 114 | **kwargs 115 | ) 116 | 117 | executor.update_parameters(name="beit") 118 | 119 | args.dist_url = get_init_file().as_uri() 120 | args.output_dir = args.job_dir 121 | 122 | trainer = Trainer(args, parser) 123 | job = executor.submit(trainer) 124 | 125 | print("Submitted job_id:", job.job_id) 126 | 127 | 128 | if __name__ == "__main__": 129 | main() 130 | -------------------------------------------------------------------------------- /beit_finetuning/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
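# Training utilities used by run_class_finetuning.py: smoothed metric logging,
# a TensorBoard wrapper, distributed-mode setup, step-level cosine schedules
# for learning rate and weight decay, gradient scaling/clipping, and checkpoint
# save / auto-resume for both the torch.amp and DeepSpeed code paths.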
6 | 7 | # -------------------------------------------------------- 8 | # BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254) 9 | # Github source: https://github.com/microsoft/unilm/tree/master/beit 10 | # Copyright (c) 2021 Microsoft 11 | # Licensed under The MIT License [see LICENSE for details] 12 | # By Hangbo Bao 13 | # Based on timm, DINO and DeiT code bases 14 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm 15 | # https://github.com/facebookresearch/deit 16 | # https://github.com/facebookresearch/dino 17 | # --------------------------------------------------------' 18 | import io 19 | import os 20 | import math 21 | import time 22 | import json 23 | from collections import defaultdict, deque 24 | import datetime 25 | import numpy as np 26 | from timm.utils import get_state_dict 27 | 28 | from pathlib import Path 29 | 30 | import torch 31 | import torch.distributed as dist 32 | from torch._six import inf 33 | from modeling_discrete_vae import Dalle_VAE, DiscreteVAE 34 | 35 | from tensorboardX import SummaryWriter 36 | 37 | 38 | class SmoothedValue(object): 39 | """Track a series of values and provide access to smoothed values over a 40 | window or the global series average. 41 | """ 42 | 43 | def __init__(self, window_size=20, fmt=None): 44 | if fmt is None: 45 | fmt = "{median:.4f} ({global_avg:.4f})" 46 | self.deque = deque(maxlen=window_size) 47 | self.total = 0.0 48 | self.count = 0 49 | self.fmt = fmt 50 | 51 | def update(self, value, n=1): 52 | self.deque.append(value) 53 | self.count += n 54 | self.total += value * n 55 | 56 | def synchronize_between_processes(self): 57 | """ 58 | Warning: does not synchronize the deque! 59 | """ 60 | if not is_dist_avail_and_initialized(): 61 | return 62 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 63 | dist.barrier() 64 | dist.all_reduce(t) 65 | t = t.tolist() 66 | self.count = int(t[0]) 67 | self.total = t[1] 68 | 69 | @property 70 | def median(self): 71 | d = torch.tensor(list(self.deque)) 72 | return d.median().item() 73 | 74 | @property 75 | def avg(self): 76 | d = torch.tensor(list(self.deque), dtype=torch.float32) 77 | return d.mean().item() 78 | 79 | @property 80 | def global_avg(self): 81 | return self.total / self.count 82 | 83 | @property 84 | def max(self): 85 | return max(self.deque) 86 | 87 | @property 88 | def value(self): 89 | return self.deque[-1] 90 | 91 | def __str__(self): 92 | return self.fmt.format( 93 | median=self.median, 94 | avg=self.avg, 95 | global_avg=self.global_avg, 96 | max=self.max, 97 | value=self.value) 98 | 99 | 100 | class MetricLogger(object): 101 | def __init__(self, delimiter="\t"): 102 | self.meters = defaultdict(SmoothedValue) 103 | self.delimiter = delimiter 104 | 105 | def update(self, **kwargs): 106 | for k, v in kwargs.items(): 107 | if v is None: 108 | continue 109 | if isinstance(v, torch.Tensor): 110 | v = v.item() 111 | assert isinstance(v, (float, int)) 112 | self.meters[k].update(v) 113 | 114 | def __getattr__(self, attr): 115 | if attr in self.meters: 116 | return self.meters[attr] 117 | if attr in self.__dict__: 118 | return self.__dict__[attr] 119 | raise AttributeError("'{}' object has no attribute '{}'".format( 120 | type(self).__name__, attr)) 121 | 122 | def __str__(self): 123 | loss_str = [] 124 | for name, meter in self.meters.items(): 125 | loss_str.append( 126 | "{}: {}".format(name, str(meter)) 127 | ) 128 | return self.delimiter.join(loss_str) 129 | 130 | def 
synchronize_between_processes(self): 131 | for meter in self.meters.values(): 132 | meter.synchronize_between_processes() 133 | 134 | def add_meter(self, name, meter): 135 | self.meters[name] = meter 136 | 137 | def log_every(self, iterable, print_freq, header=None): 138 | i = 0 139 | if not header: 140 | header = '' 141 | start_time = time.time() 142 | end = time.time() 143 | iter_time = SmoothedValue(fmt='{avg:.4f}') 144 | data_time = SmoothedValue(fmt='{avg:.4f}') 145 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 146 | log_msg = [ 147 | header, 148 | '[{0' + space_fmt + '}/{1}]', 149 | 'eta: {eta}', 150 | '{meters}', 151 | 'time: {time}', 152 | 'data: {data}' 153 | ] 154 | if torch.cuda.is_available(): 155 | log_msg.append('max mem: {memory:.0f}') 156 | log_msg = self.delimiter.join(log_msg) 157 | MB = 1024.0 * 1024.0 158 | for obj in iterable: 159 | data_time.update(time.time() - end) 160 | yield obj 161 | iter_time.update(time.time() - end) 162 | if i % print_freq == 0 or i == len(iterable) - 1: 163 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 164 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 165 | if torch.cuda.is_available(): 166 | print(log_msg.format( 167 | i, len(iterable), eta=eta_string, 168 | meters=str(self), 169 | time=str(iter_time), data=str(data_time), 170 | memory=torch.cuda.max_memory_allocated() / MB)) 171 | else: 172 | print(log_msg.format( 173 | i, len(iterable), eta=eta_string, 174 | meters=str(self), 175 | time=str(iter_time), data=str(data_time))) 176 | i += 1 177 | end = time.time() 178 | total_time = time.time() - start_time 179 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 180 | print('{} Total time: {} ({:.4f} s / it)'.format( 181 | header, total_time_str, total_time / len(iterable))) 182 | 183 | 184 | class TensorboardLogger(object): 185 | def __init__(self, log_dir): 186 | self.writer = SummaryWriter(logdir=log_dir) 187 | self.step = 0 188 | 189 | def set_step(self, step=None): 190 | if step is not None: 191 | self.step = step 192 | else: 193 | self.step += 1 194 | 195 | def update(self, head='scalar', step=None, **kwargs): 196 | for k, v in kwargs.items(): 197 | if v is None: 198 | continue 199 | if isinstance(v, torch.Tensor): 200 | v = v.item() 201 | assert isinstance(v, (float, int)) 202 | self.writer.add_scalar(head + "/" + k, v, self.step if step is None else step) 203 | 204 | def flush(self): 205 | self.writer.flush() 206 | 207 | 208 | def _load_checkpoint_for_ema(model_ema, checkpoint): 209 | """ 210 | Workaround for ModelEma._load_checkpoint to accept an already-loaded object 211 | """ 212 | mem_file = io.BytesIO() 213 | torch.save(checkpoint, mem_file) 214 | mem_file.seek(0) 215 | model_ema._load_checkpoint(mem_file) 216 | 217 | 218 | def setup_for_distributed(is_master): 219 | """ 220 | This function disables printing when not in master process 221 | """ 222 | import builtins as __builtin__ 223 | builtin_print = __builtin__.print 224 | 225 | def print(*args, **kwargs): 226 | force = kwargs.pop('force', False) 227 | if is_master or force: 228 | builtin_print(*args, **kwargs) 229 | 230 | __builtin__.print = print 231 | 232 | 233 | def is_dist_avail_and_initialized(): 234 | if not dist.is_available(): 235 | return False 236 | if not dist.is_initialized(): 237 | return False 238 | return True 239 | 240 | 241 | def get_world_size(): 242 | if not is_dist_avail_and_initialized(): 243 | return 1 244 | return dist.get_world_size() 245 | 246 | 247 | def get_rank(): 248 | if not 
is_dist_avail_and_initialized(): 249 | return 0 250 | return dist.get_rank() 251 | 252 | 253 | def is_main_process(): 254 | return get_rank() == 0 255 | 256 | 257 | def save_on_master(*args, **kwargs): 258 | if is_main_process(): 259 | torch.save(*args, **kwargs) 260 | 261 | 262 | def init_distributed_mode(args): 263 | if args.dist_on_itp: 264 | args.rank = int(os.environ['OMPI_COMM_WORLD_RANK']) 265 | args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) 266 | args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) 267 | args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT']) 268 | os.environ['LOCAL_RANK'] = str(args.gpu) 269 | os.environ['RANK'] = str(args.rank) 270 | os.environ['WORLD_SIZE'] = str(args.world_size) 271 | # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"] 272 | elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 273 | args.rank = int(os.environ["RANK"]) 274 | args.world_size = int(os.environ['WORLD_SIZE']) 275 | args.gpu = int(os.environ['LOCAL_RANK']) 276 | elif 'SLURM_PROCID' in os.environ: 277 | args.rank = int(os.environ['SLURM_PROCID']) 278 | args.gpu = args.rank % torch.cuda.device_count() 279 | os.environ['LOCAL_RANK'] = str(args.gpu) 280 | os.environ['RANK'] = str(args.rank) 281 | os.environ['WORLD_SIZE'] = str(args.world_size) 282 | else: 283 | print('Not using distributed mode') 284 | args.distributed = False 285 | return 286 | 287 | args.distributed = True 288 | 289 | torch.cuda.set_device(args.gpu) 290 | args.dist_backend = 'nccl' 291 | print('| distributed init (rank {}): {}, gpu {}'.format( 292 | args.rank, args.dist_url, args.gpu), flush=True) 293 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 294 | world_size=args.world_size, rank=args.rank) 295 | torch.distributed.barrier() 296 | setup_for_distributed(args.rank == 0) 297 | 298 | 299 | def load_state_dict(model, state_dict, prefix='', ignore_missing="relative_position_index"): 300 | missing_keys = [] 301 | unexpected_keys = [] 302 | error_msgs = [] 303 | # copy state_dict so _load_from_state_dict can modify it 304 | metadata = getattr(state_dict, '_metadata', None) 305 | state_dict = state_dict.copy() 306 | if metadata is not None: 307 | state_dict._metadata = metadata 308 | 309 | def load(module, prefix=''): 310 | local_metadata = {} if metadata is None else metadata.get( 311 | prefix[:-1], {}) 312 | module._load_from_state_dict( 313 | state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) 314 | for name, child in module._modules.items(): 315 | if child is not None: 316 | load(child, prefix + name + '.') 317 | 318 | load(model, prefix=prefix) 319 | 320 | warn_missing_keys = [] 321 | ignore_missing_keys = [] 322 | for key in missing_keys: 323 | keep_flag = True 324 | for ignore_key in ignore_missing.split('|'): 325 | if ignore_key in key: 326 | keep_flag = False 327 | break 328 | if keep_flag: 329 | warn_missing_keys.append(key) 330 | else: 331 | ignore_missing_keys.append(key) 332 | 333 | missing_keys = warn_missing_keys 334 | 335 | if len(missing_keys) > 0: 336 | print("Weights of {} not initialized from pretrained model: {}".format( 337 | model.__class__.__name__, missing_keys)) 338 | if len(unexpected_keys) > 0: 339 | print("Weights from pretrained model not used in {}: {}".format( 340 | model.__class__.__name__, unexpected_keys)) 341 | if len(ignore_missing_keys) > 0: 342 | print("Ignored weights of {} not initialized from pretrained model: {}".format( 343 | 
model.__class__.__name__, ignore_missing_keys)) 344 | if len(error_msgs) > 0: 345 | print('\n'.join(error_msgs)) 346 | 347 | 348 | class NativeScalerWithGradNormCount: 349 | state_dict_key = "amp_scaler" 350 | 351 | def __init__(self): 352 | self._scaler = torch.cuda.amp.GradScaler() 353 | 354 | def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True): 355 | self._scaler.scale(loss).backward(create_graph=create_graph) 356 | if update_grad: 357 | if clip_grad is not None: 358 | assert parameters is not None 359 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place 360 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) 361 | else: 362 | self._scaler.unscale_(optimizer) 363 | norm = get_grad_norm_(parameters) 364 | self._scaler.step(optimizer) 365 | self._scaler.update() 366 | else: 367 | norm = None 368 | return norm 369 | 370 | def state_dict(self): 371 | return self._scaler.state_dict() 372 | 373 | def load_state_dict(self, state_dict): 374 | self._scaler.load_state_dict(state_dict) 375 | 376 | 377 | def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: 378 | if isinstance(parameters, torch.Tensor): 379 | parameters = [parameters] 380 | parameters = [p for p in parameters if p.grad is not None] 381 | norm_type = float(norm_type) 382 | if len(parameters) == 0: 383 | return torch.tensor(0.) 384 | device = parameters[0].grad.device 385 | if norm_type == inf: 386 | total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) 387 | else: 388 | total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type) 389 | return total_norm 390 | 391 | 392 | def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, 393 | start_warmup_value=0, warmup_steps=-1): 394 | warmup_schedule = np.array([]) 395 | warmup_iters = warmup_epochs * niter_per_ep 396 | if warmup_steps > 0: 397 | warmup_iters = warmup_steps 398 | print("Set warmup steps = %d" % warmup_iters) 399 | if warmup_epochs > 0: 400 | warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters) 401 | 402 | iters = np.arange(epochs * niter_per_ep - warmup_iters) 403 | schedule = np.array( 404 | [final_value + 0.5 * (base_value - final_value) * (1 + math.cos(math.pi * i / (len(iters)))) for i in iters]) 405 | 406 | schedule = np.concatenate((warmup_schedule, schedule)) 407 | 408 | assert len(schedule) == epochs * niter_per_ep 409 | return schedule 410 | 411 | 412 | def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler, model_ema=None): 413 | output_dir = Path(args.output_dir) 414 | epoch_name = str(epoch) 415 | if loss_scaler is not None: 416 | checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)] 417 | for checkpoint_path in checkpoint_paths: 418 | to_save = { 419 | 'model': model_without_ddp.state_dict(), 420 | 'optimizer': optimizer.state_dict(), 421 | 'epoch': epoch, 422 | 'scaler': loss_scaler.state_dict(), 423 | 'args': args, 424 | } 425 | 426 | if model_ema is not None: 427 | to_save['model_ema'] = get_state_dict(model_ema) 428 | 429 | save_on_master(to_save, checkpoint_path) 430 | else: 431 | client_state = {'epoch': epoch} 432 | if model_ema is not None: 433 | client_state['model_ema'] = get_state_dict(model_ema) 434 | model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-%s" % epoch_name, client_state=client_state) 435 | 436 | 437 | def auto_load_model(args, 
model, model_without_ddp, optimizer, loss_scaler, model_ema=None): 438 | output_dir = Path(args.output_dir) 439 | if loss_scaler is not None: 440 | # torch.amp 441 | if args.auto_resume and len(args.resume) == 0: 442 | import glob 443 | all_checkpoints = glob.glob(os.path.join(output_dir, 'checkpoint-*.pth')) 444 | latest_ckpt = -1 445 | for ckpt in all_checkpoints: 446 | t = ckpt.split('-')[-1].split('.')[0] 447 | if t.isdigit(): 448 | latest_ckpt = max(int(t), latest_ckpt) 449 | if latest_ckpt >= 0: 450 | args.resume = os.path.join(output_dir, 'checkpoint-%d.pth' % latest_ckpt) 451 | print("Auto resume checkpoint: %s" % args.resume) 452 | 453 | if args.resume: 454 | if args.resume.startswith('https'): 455 | checkpoint = torch.hub.load_state_dict_from_url( 456 | args.resume, map_location='cpu', check_hash=True) 457 | else: 458 | checkpoint = torch.load(args.resume, map_location='cpu') 459 | model_without_ddp.load_state_dict(checkpoint['model']) 460 | print("Resume checkpoint %s" % args.resume) 461 | if 'optimizer' in checkpoint and 'epoch' in checkpoint: 462 | optimizer.load_state_dict(checkpoint['optimizer']) 463 | args.start_epoch = checkpoint['epoch'] + 1 464 | if hasattr(args, 'model_ema') and args.model_ema: 465 | _load_checkpoint_for_ema(model_ema, checkpoint['model_ema']) 466 | if 'scaler' in checkpoint: 467 | loss_scaler.load_state_dict(checkpoint['scaler']) 468 | print("With optim & sched!") 469 | else: 470 | # deepspeed, only support '--auto_resume'. 471 | if args.auto_resume: 472 | import glob 473 | all_checkpoints = glob.glob(os.path.join(output_dir, 'checkpoint-*')) 474 | latest_ckpt = -1 475 | for ckpt in all_checkpoints: 476 | t = ckpt.split('-')[-1].split('.')[0] 477 | if t.isdigit(): 478 | latest_ckpt = max(int(t), latest_ckpt) 479 | if latest_ckpt >= 0: 480 | args.resume = os.path.join(output_dir, 'checkpoint-%d' % latest_ckpt) 481 | print("Auto resume checkpoint: %d" % latest_ckpt) 482 | _, client_states = model.load_checkpoint(args.output_dir, tag='checkpoint-%d' % latest_ckpt) 483 | args.start_epoch = client_states['epoch'] + 1 484 | if model_ema is not None: 485 | if args.model_ema: 486 | _load_checkpoint_for_ema(model_ema, client_states['model_ema']) 487 | 488 | 489 | def create_d_vae(weight_path, d_vae_type, image_size, device): 490 | if d_vae_type == "dall-e": 491 | return get_dalle_vae(weight_path, image_size, device) 492 | elif d_vae_type == "customized": 493 | return get_d_vae(weight_path, image_size, device) 494 | else: 495 | raise NotImplementedError() 496 | 497 | 498 | def get_dalle_vae(weight_path, image_size, device): 499 | vae = Dalle_VAE(image_size) 500 | vae.load_model(model_dir=weight_path, device=device) 501 | return vae 502 | 503 | 504 | def get_d_vae(weight_path, image_size, device): 505 | NUM_TOKENS = 8192 506 | NUM_LAYERS = 3 507 | EMB_DIM = 512 508 | HID_DIM = 256 509 | 510 | state_dict = torch.load(os.path.join(weight_path, "pytorch_model.bin"), map_location="cpu")["weights"] 511 | 512 | model = DiscreteVAE( 513 | image_size=image_size, 514 | num_layers=NUM_LAYERS, 515 | num_tokens=NUM_TOKENS, 516 | codebook_dim=EMB_DIM, 517 | hidden_dim=HID_DIM, 518 | ).to(device) 519 | 520 | model.load_state_dict(state_dict) 521 | return model 522 | 523 | 524 | def create_ds_config(args): 525 | args.deepspeed_config = os.path.join(args.output_dir, "deepspeed_config.json") 526 | with open(args.deepspeed_config, mode="w") as writer: 527 | ds_config = { 528 | "train_batch_size": args.batch_size * args.update_freq * get_world_size(), 529 | 
"train_micro_batch_size_per_gpu": args.batch_size, 530 | "steps_per_print": 1000, 531 | "optimizer": { 532 | "type": "Adam", 533 | "adam_w_mode": True, 534 | "params": { 535 | "lr": args.lr, 536 | "weight_decay": args.weight_decay, 537 | "bias_correction": True, 538 | "betas": [ 539 | 0.9, 540 | 0.999 541 | ], 542 | "eps": 1e-8 543 | } 544 | }, 545 | "fp16": { 546 | "enabled": True, 547 | "loss_scale": 0, 548 | "initial_scale_power": 7, 549 | "loss_scale_window": 128 550 | } 551 | } 552 | 553 | writer.write(json.dumps(ds_config, indent=2)) 554 | -------------------------------------------------------------------------------- /bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/SLIP/c6faf5d03cbfa7d529d210779f859cd3dddec09a/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /dataset_catalog.json: -------------------------------------------------------------------------------- 1 | { 2 | "food101": { 3 | "path": "/path/to/dataset", 4 | "type": "imagefolder", 5 | "train": "train", 6 | "test": "test" 7 | }, 8 | "cifar10": { 9 | "path": "./data/cifar10", 10 | "type": "special" 11 | }, 12 | "cifar100": { 13 | "path": "./data/cifar100", 14 | "type": "special" 15 | }, 16 | "cub200": { 17 | "path": "/path/to/dataset", 18 | "type": "imagefolder", 19 | "train": "train", 20 | "test": "val" 21 | }, 22 | "sun397": { 23 | "path": "/path/to/dataset", 24 | "type": "filelist", 25 | "train": "train", 26 | "test": "test" 27 | }, 28 | "cars": { 29 | "path": "/path/to/dataset", 30 | "type": "imagefolder", 31 | "train": "train", 32 | "test": "test" 33 | }, 34 | "aircraft": { 35 | "path": "/path/to/dataset", 36 | "type": "imagefolder", 37 | "train": "train", 38 | "test": "test" 39 | }, 40 | "dtd": { 41 | "path": "/path/to/dataset", 42 | "type": "imagefolder", 43 | "train": "train", 44 | "test": "test" 45 | }, 46 | "pets": { 47 | "path": "/path/to/dataset", 48 | "type": "imagefolder", 49 | "train": "train", 50 | "test": "test" 51 | }, 52 | "caltech101": { 53 | "path": "/path/to/dataset", 54 | "type": "imagefolder", 55 | "train": "train", 56 | "test": "test" 57 | }, 58 | "flowers": { 59 | "path": "/path/to/dataset", 60 | "type": "imagefolder", 61 | "train": "train", 62 | "test": "test" 63 | }, 64 | "mnist": { 65 | "path": "./data/mnist", 66 | "type": "special" 67 | }, 68 | "fer2013": { 69 | "path": "/path/to/dataset", 70 | "type": "imagefolder", 71 | "train": "train", 72 | "test": "test" 73 | }, 74 | "stl10": { 75 | "path": "./data/stl10", 76 | "type": "special" 77 | }, 78 | "eurosat": { 79 | "path": "/path/to/dataset", 80 | "type": "imagefolder", 81 | "train": "train", 82 | "test": "val" 83 | }, 84 | "resisc45": { 85 | "path": "/path/to/dataset", 86 | "type": "imagefolder", 87 | "train": "train", 88 | "test": "test" 89 | }, 90 | "gtsrb": { 91 | "path": "/path/to/dataset", 92 | "type": "imagefolder", 93 | "train": "train", 94 | "test": "test" 95 | }, 96 | "kitti_distance": { 97 | "path": "/path/to/dataset", 98 | "type": "imagefolder", 99 | "train": "train", 100 | "test": "val" 101 | }, 102 | "country211": { 103 | "path": "/path/to/dataset", 104 | "type": "imagefolder", 105 | "train": "train", 106 | "test": "test" 107 | }, 108 | "patch_camelyon": { 109 | "path": "/path/to/dataset", 110 | "type": "imagefolder", 111 | "train": "train", 112 | "test": "val" 113 | }, 114 | "ucf101_frames": { 115 | "path": "/path/to/dataset", 116 | "type": "imagefolder", 117 
| "train": "train", 118 | "test": "val" 119 | }, 120 | "kinetics700_frames": { 121 | "path": "/path/to/dataset", 122 | "type": "imagefolder", 123 | "train": "train_images", 124 | "test": "val_images" 125 | }, 126 | "clevr_counts": { 127 | "path": "/path/to/dataset", 128 | "type": "filelist", 129 | "train": "train", 130 | "test": "val" 131 | }, 132 | "hateful_memes": { 133 | "path": "/path/to/dataset", 134 | "type": "imagefolder", 135 | "train": "train", 136 | "test": "dev" 137 | }, 138 | "rendered_sst2": { 139 | "path": "/path/to/dataset", 140 | "type": "imagefolder", 141 | "train": "train", 142 | "test": "test" 143 | }, 144 | "imagenet": { 145 | "path": "/path/to/dataset", 146 | "type": "imagefolder", 147 | "train": "train", 148 | "test": "val" 149 | } 150 | } -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | from collections import defaultdict 7 | import json 8 | import os 9 | import pickle 10 | import zipfile 11 | 12 | import numpy as np 13 | from PIL import Image, ImageFile 14 | 15 | import torch 16 | from torchvision import transforms 17 | from torchvision import datasets as t_datasets 18 | 19 | import utils 20 | 21 | 22 | ImageFile.LOAD_TRUNCATED_IMAGES = True 23 | 24 | 25 | def pil_loader(path): 26 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 27 | with open(path, 'rb') as f: 28 | img = Image.open(f) 29 | return img.convert('RGB') 30 | 31 | 32 | def yfcc_loader(root, index): 33 | index = format(index, "0>8d") 34 | repo = index[:2] 35 | z = index[2: 5] 36 | file_img = index[5:] + '.jpg' 37 | path_zip = os.path.join(root, 'images', repo, z) + '.zip' 38 | with zipfile.ZipFile(path_zip, 'r') as myzip: 39 | img = Image.open(myzip.open(file_img)) 40 | return img.convert('RGB') 41 | 42 | 43 | class ImageCaptionDatasetBase(torch.utils.data.Dataset): 44 | def __init__(self, dataset, root, metadata): 45 | self.dataset = dataset 46 | self.root = root 47 | if self.dataset == 'yfcc15m': 48 | with open(metadata, 'rb') as f: 49 | self.samples = pickle.load(f) 50 | elif self.dataset == 'coco': 51 | samples = defaultdict(list) 52 | with open(metadata) as f: 53 | annotations = json.load(f)['annotations'] 54 | for ann in annotations: 55 | samples[ann['image_id']].append(ann['caption']) 56 | self.samples = [(k, v) for k, v in samples.items()] 57 | elif self.dataset == 'cc12m' or self.dataset == 'cc3m': 58 | self.samples = np.load(metadata, allow_pickle=True) 59 | elif self.dataset == 'redcaps': 60 | with open(metadata) as f: 61 | annotations = json.load(f) 62 | self.samples = [(ann['image_id'], ann['subreddit'], ann['caption']) for ann in annotations] 63 | 64 | def get_raw_item(self, i): 65 | if self.dataset == 'yfcc15m': 66 | index, title, desc = self.samples[i] 67 | caption = np.random.choice([title, desc]) 68 | img = yfcc_loader(self.root, index) 69 | elif self.dataset == 'coco': 70 | index, captions = self.samples[i] 71 | path = os.path.join(self.root, 'train2017', '{:012d}.jpg'.format(index)) 72 | img = pil_loader(path) 73 | caption = np.random.choice(captions) 74 | elif self.dataset == 'cc3m': 75 | ann = self.samples[i] 76 | filename, captions = ann['image_id'], ann['captions'] 77 | path = 
os.path.join(self.root, str(filename)) 78 | img = pil_loader(path) 79 | caption = np.random.choice(captions) 80 | elif self.dataset == 'cc12m': 81 | ann = self.samples[i] 82 | filename, captions = ann['image_name'], ann['captions'] 83 | path = os.path.join(self.root, filename) 84 | img = pil_loader(path) 85 | caption = np.random.choice(captions) 86 | elif self.dataset == 'redcaps': 87 | image_id, subreddit, caption = self.samples[i] 88 | path = os.path.join(self.root, subreddit, f"{image_id}.jpg") 89 | img = pil_loader(path) 90 | 91 | return img, caption 92 | 93 | def __getitem__(self, i): 94 | raise NotImplementedError 95 | 96 | def __len__(self): 97 | return len(self.samples) 98 | 99 | 100 | class ImageCaptionDatasetCLIP(ImageCaptionDatasetBase): 101 | def __init__(self, dataset, root, metadata, transform=None, tokenizer=None): 102 | super().__init__(dataset, root, metadata) 103 | 104 | self.transform = transform 105 | self.tokenizer = tokenizer 106 | 107 | def __getitem__(self, i): 108 | img, caption = self.get_raw_item(i) 109 | 110 | # apply transformation 111 | if self.transform is not None: 112 | image = self.transform(img) 113 | 114 | # tokenize caption 115 | if self.tokenizer is not None: 116 | caption = self.tokenizer(caption) 117 | 118 | return image, caption 119 | 120 | 121 | class ImageCaptionDatasetSLIP(ImageCaptionDatasetBase): 122 | def __init__(self, dataset, root, metadata, transform, augment, tokenizer=None): 123 | super().__init__(dataset, root, metadata) 124 | 125 | self.transform = transform 126 | self.augment = augment 127 | self.tokenizer = tokenizer 128 | 129 | def __getitem__(self, i): 130 | img, caption = self.get_raw_item(i) 131 | 132 | image = self.transform(img) 133 | aug1 = self.augment(img) 134 | aug2 = self.augment(img) 135 | 136 | # tokenize caption 137 | if self.tokenizer is not None: 138 | caption = self.tokenizer(caption) 139 | 140 | return image, caption, aug1, aug2 141 | 142 | 143 | class ImageCaptionDatasetSSL(ImageCaptionDatasetBase): 144 | def __init__(self, dataset, root, metadata, augment): 145 | super().__init__(dataset, root, metadata) 146 | 147 | self.augment = augment 148 | 149 | def __getitem__(self, i): 150 | img, _ = self.get_raw_item(i) 151 | 152 | aug1 = self.augment(img) 153 | aug2 = self.augment(img) 154 | 155 | return aug1, aug2 156 | 157 | 158 | class FileListDataset(torch.utils.data.Dataset): 159 | def __init__(self, images, labels, transform=None, target_transform=None): 160 | self.transform = transform 161 | self.target_transform = target_transform 162 | self.images = np.load(images) 163 | self.labels = np.load(labels) 164 | 165 | def __getitem__(self, index): 166 | img = pil_loader(self.images[index]) 167 | target = self.labels[index] 168 | 169 | if self.transform is not None: 170 | img = self.transform(img) 171 | 172 | if self.target_transform is not None: 173 | target = self.target_transform(target) 174 | 175 | return img, target 176 | 177 | def __len__(self): 178 | return len(self.images) 179 | 180 | 181 | def get_downstream_dataset(catalog, name, is_train, transform): 182 | entry = catalog[name] 183 | root = entry['path'] 184 | if entry['type'] == 'imagefolder': 185 | dataset = t_datasets.ImageFolder(os.path.join(root, entry['train'] if is_train else entry['test']), 186 | transform=transform) 187 | elif entry['type'] == 'special': 188 | if name == 'cifar10': 189 | dataset = t_datasets.CIFAR10(root, train=is_train, 190 | transform=transform, download=True) 191 | elif name == 'cifar100': 192 | dataset = t_datasets.CIFAR100(root, 
train=is_train, 193 | transform=transform, download=True) 194 | elif name == 'stl10': 195 | dataset = t_datasets.STL10(root, split='train' if is_train else 'test', 196 | transform=transform, download=True) 197 | elif name == 'mnist': 198 | dataset = t_datasets.MNIST(root, train=is_train, 199 | transform=transform, download=True) 200 | elif entry['type'] == 'filelist': 201 | path = entry['train'] if is_train else entry['test'] 202 | val_images = os.path.join(root, path + '_images.npy') 203 | val_labels = os.path.join(root, path + '_labels.npy') 204 | if name == 'clevr_counts': 205 | target_transform = lambda x: ['count_10', 'count_3', 'count_4', 'count_5', 'count_6', 'count_7', 'count_8', 'count_9'].index(x) 206 | else: 207 | target_transform = None 208 | dataset = FileListDataset(val_images, val_labels, transform, target_transform) 209 | else: 210 | raise Exception('Unknown dataset') 211 | 212 | return dataset 213 | 214 | 215 | def get_dataset(train_transform, tokenizer, args): 216 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 217 | std=[0.229, 0.224, 0.225]) 218 | augment = transforms.Compose([ 219 | transforms.RandomResizedCrop(224, scale=(0.08, 1.)), 220 | transforms.RandomApply([ 221 | transforms.ColorJitter(0.4, 0.4, 0.4, 0.1) # not strengthened 222 | ], p=0.8), 223 | transforms.RandomGrayscale(p=0.2), 224 | transforms.RandomApply([utils.GaussianBlur([.1, 2.])], p=0.5), 225 | transforms.RandomHorizontalFlip(), 226 | transforms.ToTensor(), 227 | normalize, 228 | ]) 229 | 230 | if args.model.startswith('SIMCLR'): 231 | return ImageCaptionDatasetSSL(args.dataset, args.root, args.metadata, augment) 232 | elif args.model.startswith('CLIP'): 233 | return ImageCaptionDatasetCLIP(args.dataset, args.root, args.metadata, train_transform, tokenizer) 234 | elif args.model.startswith('SLIP'): 235 | return ImageCaptionDatasetSLIP(args.dataset, args.root, args.metadata, train_transform, augment, tokenizer) -------------------------------------------------------------------------------- /eval_zeroshot.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
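# Zero-shot evaluation: for each dataset in dataset_catalog.json, class names
# are expanded with the prompt templates, encoded by the text tower, averaged
# and re-normalized into one embedding per class, and images are scored by
# cosine similarity; the reported metric is top-1 accuracy, mean per-class
# accuracy, (top1 + top5) / 2, or ROC AUC depending on the dataset.
#
# Hypothetical invocation (paths are placeholders, not taken from the repo):
#   python eval_zeroshot.py --resume /path/to/checkpoint_best.pt --output-dir ./out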
6 | import argparse 7 | from collections import OrderedDict 8 | import json 9 | import os 10 | from sklearn import metrics 11 | 12 | import torch 13 | import torch.backends.cudnn as cudnn 14 | import torch.utils.data 15 | import torchvision.transforms as transforms 16 | 17 | import datasets 18 | import models 19 | from tokenizer import SimpleTokenizer 20 | import utils 21 | 22 | 23 | def get_args_parser(): 24 | parser = argparse.ArgumentParser(description='SLIP 0-shot evaluations', add_help=False) 25 | parser.add_argument('--output-dir', default='./', type=str, help='output dir') 26 | parser.add_argument('--batch-size', default=256, type=int, help='batch_size') 27 | parser.add_argument('-j', '--workers', default=10, type=int, metavar='N', 28 | help='number of data loading workers per process') 29 | parser.add_argument('--resume', default='', type=str, help='path to latest checkpoint') 30 | return parser 31 | 32 | 33 | def main(args): 34 | # optionally resume from a checkpoint (takes precedence over autoresume) 35 | if args.resume: 36 | ckpt_path = args.resume 37 | elif os.path.isfile(os.path.join(args.output_dir, 'checkpoint_best.pt')): 38 | ckpt_path = os.path.join(args.output_dir, 'checkpoint_best.pt') 39 | else: 40 | raise Exception('no checkpoint found') 41 | 42 | ckpt = torch.load(ckpt_path, map_location='cpu') 43 | state_dict = OrderedDict() 44 | for k, v in ckpt['state_dict'].items(): 45 | state_dict[k.replace('module.', '')] = v 46 | 47 | # create model 48 | old_args = ckpt['args'] 49 | print("=> creating model: {}".format(old_args.model)) 50 | model = getattr(models, old_args.model)(rand_embed=False, 51 | ssl_mlp_dim=old_args.ssl_mlp_dim, ssl_emb_dim=old_args.ssl_emb_dim) 52 | model.cuda() 53 | model.load_state_dict(state_dict, strict=True) 54 | print("=> loaded resume checkpoint '{}' (epoch {})".format(args.resume, ckpt['epoch'])) 55 | 56 | cudnn.benchmark = True 57 | 58 | cwd = os.path.dirname(os.path.realpath(__file__)) 59 | with open(os.path.join(cwd, 'dataset_catalog.json')) as f: 60 | catalog = json.load(f) 61 | 62 | with open(os.path.join(cwd, 'templates.json')) as f: 63 | all_templates = json.load(f) 64 | 65 | with open(os.path.join(cwd, 'labels.json')) as f: 66 | all_labels = json.load(f) 67 | 68 | # Data loading code 69 | print("=> creating dataset") 70 | tokenizer = SimpleTokenizer() 71 | val_transform = transforms.Compose([ 72 | transforms.Resize(224), 73 | transforms.CenterCrop(224), 74 | lambda x: x.convert('RGB'), 75 | transforms.ToTensor(), 76 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 77 | std=[0.229, 0.224, 0.225]) 78 | ]) 79 | 80 | results = [] 81 | for d in catalog: 82 | print('Evaluating {}'.format(d)) 83 | val_dataset = datasets.get_downstream_dataset(catalog, name=d, is_train=False, transform=val_transform) 84 | 85 | val_loader = torch.utils.data.DataLoader( 86 | val_dataset, batch_size=args.batch_size, shuffle=False, 87 | num_workers=args.workers, pin_memory=True, drop_last=False) 88 | 89 | templates = all_templates[d] 90 | labels = all_labels[d] 91 | 92 | is_acc = d not in ['aircraft', 'pets', 'caltech101', 'flowers', 'kinetics700_frames', 'hateful_memes'] 93 | 94 | acc_or_outputs = validate_zeroshot(val_loader, templates, labels, model, tokenizer, is_acc) 95 | 96 | if d in ['aircraft', 'pets', 'caltech101', 'flowers']: 97 | metric = mean_per_class(*acc_or_outputs) 98 | elif d == 'kinetics700_frames': 99 | top1, top5 = accuracy(*acc_or_outputs, topk=(1, 5)) 100 | metric = (top1 + top5) / 2 101 | metric = metric.item() 102 | elif d == 
'hateful_memes': 103 | metric = roc_auc(*acc_or_outputs) 104 | else: 105 | metric = acc_or_outputs 106 | 107 | results.append(metric) 108 | 109 | print('metric:', metric) 110 | 111 | print('all results:') 112 | for x in results: 113 | print('{:.1f}'.format(x)) 114 | 115 | def validate_zeroshot(val_loader, templates, labels, model, tokenizer, is_acc): 116 | # switch to evaluate mode 117 | model.eval() 118 | total_top1 = 0 119 | total_images = 0 120 | 121 | all_outputs = [] 122 | all_targets = [] 123 | 124 | print('=> encoding captions') 125 | with torch.no_grad(): 126 | text_features = [] 127 | for label in labels: 128 | if isinstance(label, list): 129 | texts = [t.format(l) for t in templates for l in label] 130 | else: 131 | texts = [t.format(label) for t in templates] 132 | texts = tokenizer(texts).cuda(non_blocking=True) 133 | texts = texts.view(-1, 77).contiguous() 134 | class_embeddings = utils.get_model(model).encode_text(texts) 135 | class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True) 136 | class_embeddings = class_embeddings.mean(dim=0) 137 | class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True) 138 | text_features.append(class_embeddings) 139 | text_features = torch.stack(text_features, dim=0) 140 | 141 | for images, target in val_loader: 142 | images = images.cuda(non_blocking=True) 143 | target = target.cuda(non_blocking=True) 144 | 145 | # encode images 146 | image_features = utils.get_model(model).encode_image(images) 147 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 148 | 149 | # cosine similarity as logits 150 | logits_per_image = image_features @ text_features.t() 151 | 152 | if is_acc: 153 | # measure accuracy and record loss 154 | pred = logits_per_image.argmax(dim=1) 155 | correct = pred.eq(target).sum() 156 | total_top1 += correct.item() 157 | total_images += images.size(0) 158 | else: 159 | all_outputs.append(logits_per_image.cpu()) 160 | all_targets.append(target.cpu()) 161 | 162 | if is_acc: 163 | return 100 * total_top1 / total_images 164 | else: 165 | return torch.cat(all_outputs), torch.cat(all_targets) 166 | 167 | 168 | def accuracy(output, target, topk=(1,)): 169 | """Computes the accuracy over the k top predictions for the specified values of k""" 170 | with torch.no_grad(): 171 | maxk = max(topk) 172 | batch_size = target.size(0) 173 | 174 | _, pred = output.topk(maxk, 1, True, True) 175 | pred = pred.t() 176 | correct = pred.eq(target.reshape(1, -1).expand_as(pred)) 177 | 178 | res = [] 179 | for k in topk: 180 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 181 | res.append(correct_k.mul_(100.0 / batch_size)) 182 | return res 183 | 184 | 185 | def mean_per_class(outputs, targets): 186 | pred = outputs.argmax(1) 187 | confusion_matrix = metrics.confusion_matrix(targets, pred) 188 | per_classes = confusion_matrix.diagonal() / confusion_matrix.sum(axis=1) 189 | 190 | return 100 * per_classes.mean() 191 | 192 | 193 | def roc_auc(outputs, targets): 194 | pos_score = outputs[:, 1] - outputs[:, 0] 195 | metric = metrics.roc_auc_score(targets, pos_score) 196 | 197 | return 100 * metric 198 | 199 | 200 | if __name__ == '__main__': 201 | parser = argparse.ArgumentParser('SLIP 0-shot evaluations', parents=[get_args_parser()]) 202 | args = parser.parse_args() 203 | os.makedirs(args.output_dir, exist_ok=True) 204 | main(args) 205 | -------------------------------------------------------------------------------- /losses.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | import utils 11 | 12 | 13 | class CLIPLoss(nn.Module): 14 | def __init__(self): 15 | super().__init__() 16 | self.labels = None 17 | self.last_local_batch_size = None 18 | 19 | def forward(self, outputs): 20 | image_embed = outputs['image_embed'] 21 | text_embed = outputs['text_embed'] 22 | logit_scale = outputs['logit_scale'] 23 | local_batch_size = image_embed.size(0) 24 | 25 | if local_batch_size != self.last_local_batch_size: 26 | self.labels = local_batch_size * utils.get_rank() + torch.arange( 27 | local_batch_size, device=image_embed.device 28 | ) 29 | self.last_local_batch_size = local_batch_size 30 | 31 | # normalized features 32 | image_embed = F.normalize(image_embed, dim=-1, p=2) 33 | text_embed = F.normalize(text_embed, dim=-1, p=2) 34 | 35 | # gather features from all GPUs 36 | image_embed_all, text_embed_all = \ 37 | utils.all_gather_batch([image_embed, text_embed]) 38 | 39 | # cosine similarity as logits 40 | logits_per_image = logit_scale * image_embed @ text_embed_all.t() 41 | logits_per_text = logit_scale * text_embed @ image_embed_all.t() 42 | 43 | loss = (F.cross_entropy(logits_per_image, self.labels) + \ 44 | F.cross_entropy(logits_per_text, self.labels)) / 2 45 | 46 | # compute accuracy 47 | with torch.no_grad(): 48 | pred = torch.argmax(logits_per_image, dim=-1) 49 | correct = pred.eq(self.labels).sum() 50 | acc = 100 * correct / local_batch_size 51 | 52 | return {'loss': loss, 'clip_loss': loss, 'clip_acc': acc} 53 | 54 | 55 | class SIMCLRLoss(nn.Module): 56 | """ 57 | This is the SimCLR loss in https://arxiv.org/abs/2002.05709 58 | The embedding vectors are assumed to have size (2 x batch_size, embedding_dim) and 59 | the memory layout that can be reshaped into shape (2, batch_size, embedding_dim). 
60 | This memory layout is consistent with the SimCLR collator in 61 | https://github.com/facebookresearch/vissl/blob/master/vissl/data/collators/simclr_collator.py 62 | Config params: 63 | temperature (float): the temperature to be applied on the logits 64 | """ 65 | 66 | def __init__(self, temperature=0.1): 67 | super().__init__() 68 | self.tau = temperature 69 | self.labels = None 70 | self.masks = None 71 | self.last_local_batch_size = None 72 | 73 | def forward(self, outputs): 74 | q_a = outputs['aug1_embed'] 75 | q_b = outputs['aug2_embed'] 76 | 77 | q_a = F.normalize(q_a, dim=-1, p=2) 78 | q_b = F.normalize(q_b, dim=-1, p=2) 79 | 80 | local_batch_size = q_a.size(0) 81 | 82 | k_a, k_b = utils.all_gather_batch_with_grad([q_a, q_b]) 83 | 84 | if local_batch_size != self.last_local_batch_size: 85 | self.labels = local_batch_size * utils.get_rank() + torch.arange( 86 | local_batch_size, device=q_a.device 87 | ) 88 | total_batch_size = local_batch_size * utils.get_world_size() 89 | self.masks = F.one_hot(self.labels, total_batch_size) * 1e9 90 | self.last_local_batch_size = local_batch_size 91 | 92 | logits_aa = torch.matmul(q_a, k_a.transpose(0, 1)) / self.tau 93 | logits_aa = logits_aa - self.masks 94 | logits_bb = torch.matmul(q_b, k_b.transpose(0, 1)) / self.tau 95 | logits_bb = logits_bb - self.masks 96 | logits_ab = torch.matmul(q_a, k_b.transpose(0, 1)) / self.tau 97 | logits_ba = torch.matmul(q_b, k_a.transpose(0, 1)) / self.tau 98 | 99 | loss_a = F.cross_entropy(torch.cat([logits_ab, logits_aa], dim=1), self.labels) 100 | loss_b = F.cross_entropy(torch.cat([logits_ba, logits_bb], dim=1), self.labels) 101 | loss = (loss_a + loss_b) / 2 # divide by 2 to average over all samples 102 | 103 | # compute accuracy 104 | with torch.no_grad(): 105 | pred = torch.argmax(torch.cat([logits_ab, logits_aa], dim=1), dim=-1) 106 | correct = pred.eq(self.labels).sum() 107 | acc = 100 * correct / local_batch_size 108 | 109 | return {'loss': loss, 'ssl_loss': loss, 'ssl_acc': acc} 110 | 111 | 112 | class SLIPLoss(nn.Module): 113 | def __init__(self, ssl_loss, ssl_scale): 114 | super().__init__() 115 | self.clip_loss = CLIPLoss() 116 | self.ssl_loss = ssl_loss 117 | self.ssl_scale = ssl_scale 118 | 119 | def forward(self, outputs): 120 | clip_loss_dict = self.clip_loss(outputs) 121 | clip_loss = clip_loss_dict['clip_loss'] 122 | clip_acc = clip_loss_dict['clip_acc'] 123 | 124 | ssl_loss_dict = self.ssl_loss(outputs) 125 | ssl_loss = ssl_loss_dict['ssl_loss'] 126 | ssl_acc = ssl_loss_dict['ssl_acc'] 127 | 128 | return {'loss': clip_loss + self.ssl_scale * ssl_loss, 129 | 'clip_loss': clip_loss, 130 | 'clip_acc': clip_acc, 131 | 'ssl_loss': ssl_loss, 132 | 'ssl_acc': ssl_acc} 133 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
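# The training loop in this file drives the learning rate from a precomputed per-iteration
# schedule returned by utils.cosine_scheduler (implemented in utils.py, which is not shown
# here). The sketch below is only an assumption about the shape of that schedule, consistent
# with how it is called in main(): linear warmup from `start_warmup_value` to `base_lr`,
# then cosine decay down to `final_lr`.
def _sketch_cosine_schedule(base_lr, final_lr, epochs, iters_per_epoch,
                            warmup_epochs=1, start_warmup_value=1e-6):
    import math
    warmup_iters = warmup_epochs * iters_per_epoch
    total_iters = epochs * iters_per_epoch
    schedule = []
    for it in range(total_iters):
        if it < warmup_iters:
            # linear warmup
            lr = start_warmup_value + (base_lr - start_warmup_value) * it / max(1, warmup_iters)
        else:
            # cosine decay
            progress = (it - warmup_iters) / max(1, total_iters - warmup_iters)
            lr = final_lr + 0.5 * (base_lr - final_lr) * (1 + math.cos(math.pi * progress))
        schedule.append(lr)
    return schedule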
6 | import argparse 7 | from collections import OrderedDict 8 | import json 9 | import math 10 | import os 11 | import sys 12 | import time 13 | import wandb 14 | 15 | import numpy as np 16 | import torch 17 | import torch.cuda.amp as amp 18 | import torch.nn.parallel 19 | import torch.backends.cudnn as cudnn 20 | import torch.distributed as dist 21 | import torch.optim 22 | import torch.utils.data 23 | import torch.utils.data.distributed 24 | from torchvision.datasets import ImageFolder 25 | import torchvision.transforms as transforms 26 | 27 | import datasets 28 | import models 29 | from tokenizer import SimpleTokenizer 30 | import utils 31 | 32 | 33 | def get_args_parser(): 34 | parser = argparse.ArgumentParser(description='SLIP training and evaluation', add_help=False) 35 | # Data 36 | parser.add_argument('--dataset', default='yfcc15m', type=str, choices=['yfcc15m', 'cc3m', 'cc12m', 'coco', 'redcaps']) 37 | parser.add_argument('--root', default='', type=str, 38 | help='path to dataset root') 39 | parser.add_argument('--metadata', default='yfcc15m.pkl', type=str, 40 | help='path to metadata file (see README for details)') 41 | parser.add_argument('--output-dir', default='./', type=str, help='output dir') 42 | # Model 43 | parser.add_argument('--model', default='SLIP_VITB16', type=str) 44 | parser.add_argument('--ssl-mlp-dim', default=4096, type=int, 45 | help='hidden dim of SimCLR mlp projection head') 46 | parser.add_argument('--ssl-emb-dim', default=256, type=int, 47 | help='output embed dim of SimCLR mlp projection head') 48 | parser.add_argument('--ssl-scale', default=1.0, type=float, 49 | help='loss scale for SimCLR objective') 50 | parser.add_argument('--ssl-temp', default=0.1, type=float, 51 | help='softmax temperature for SimCLR objective') 52 | parser.add_argument('--resume', default='', type=str, help='path to resume from') 53 | # Training 54 | parser.add_argument('--epochs', default=25, type=int) 55 | parser.add_argument('--warmup-epochs', default=1, type=int) 56 | parser.add_argument('--start-epoch', default=0, type=int) 57 | parser.add_argument('--batch-size', default=64, type=int, 58 | help='number of samples per-device/per-gpu') 59 | parser.add_argument('--lr', default=3e-3, type=float) 60 | parser.add_argument('--lr-start', default=1e-6, type=float, 61 | help='initial warmup lr') 62 | parser.add_argument('--lr-end', default=1e-5, type=float, 63 | help='minimum final lr') 64 | parser.add_argument('--update-freq', default=1, type=int, 65 | help='optimizer update frequency (i.e. 
gradient accumulation steps)') 66 | parser.add_argument('--wd', default=0.1, type=float) 67 | parser.add_argument('--betas', default=(0.9, 0.98), nargs=2, type=float) 68 | parser.add_argument('--eps', default=1e-8, type=float) 69 | parser.add_argument('--eval-freq', default=1, type=int) 70 | parser.add_argument('--disable-amp', action='store_true', 71 | help='disable mixed-precision training (requires more memory and compute)') 72 | # System 73 | parser.add_argument('--print-freq', default=10, type=int, help='print frequency') 74 | parser.add_argument('-j', '--workers', default=10, type=int, metavar='N', 75 | help='number of data loading workers per process') 76 | parser.add_argument('--evaluate', action='store_true', help='eval only') 77 | parser.add_argument('--world-size', default=1, type=int, 78 | help='number of nodes for distributed training') 79 | parser.add_argument('--rank', default=0, type=int, 80 | help='node rank for distributed training') 81 | parser.add_argument("--local_rank", type=int, default=0) 82 | parser.add_argument('--dist-url', default='env://', type=str, 83 | help='url used to set up distributed training') 84 | parser.add_argument('--dist-backend', default='nccl', type=str) 85 | parser.add_argument('--seed', default=0, type=int) 86 | parser.add_argument('--gpu', default=None, type=int, help='GPU id to use.') 87 | parser.add_argument('--wandb', action='store_true', help='Enable WandB logging') 88 | return parser 89 | 90 | best_acc1 = 0 91 | 92 | 93 | def main(args): 94 | utils.init_distributed_mode(args) 95 | 96 | global best_acc1 97 | 98 | # fix the seed for reproducibility 99 | seed = args.seed + utils.get_rank() 100 | torch.manual_seed(seed) 101 | np.random.seed(seed) 102 | 103 | # create model 104 | print("=> creating model: {}".format(args.model)) 105 | model = getattr(models, args.model)(ssl_mlp_dim=args.ssl_mlp_dim, ssl_emb_dim=args.ssl_emb_dim) 106 | model.cuda(args.gpu) 107 | 108 | if args.distributed: 109 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], bucket_cap_mb=200) 110 | 111 | # define loss function (criterion) and optimizer 112 | criterion = models.get_loss(args.model, args.ssl_temp, args.ssl_scale).cuda(args.gpu) 113 | 114 | p_wd, p_non_wd = [], [] 115 | for n, p in model.named_parameters(): 116 | if not p.requires_grad: 117 | continue # frozen weights 118 | if p.ndim < 2 or 'bias' in n or 'ln' in n or 'bn' in n: 119 | p_non_wd.append(p) 120 | else: 121 | p_wd.append(p) 122 | 123 | optim_params = [{"params": p_wd, "weight_decay": args.wd}, 124 | {"params": p_non_wd, "weight_decay": 0}] 125 | 126 | optimizer = torch.optim.AdamW(optim_params, lr=args.lr, betas=args.betas, 127 | eps=args.eps, weight_decay=args.wd) 128 | scaler = amp.GradScaler(enabled=not args.disable_amp) 129 | 130 | # optionally resume from a checkpoint (takes precedence over autoresume) 131 | if args.resume: 132 | if os.path.isfile(args.resume): 133 | print("=> loading resume checkpoint '{}'".format(args.resume)) 134 | checkpoint = torch.load(args.resume, map_location='cpu') 135 | epoch = checkpoint['epoch'] if 'epoch' in checkpoint else 0 136 | args.start_epoch = epoch 137 | result = model.load_state_dict(checkpoint['state_dict'], strict=False) 138 | print(result) 139 | optimizer.load_state_dict(checkpoint['optimizer']) if 'optimizer' in checkpoint else () 140 | scaler.load_state_dict(checkpoint['scaler']) if 'scaler' in checkpoint else () 141 | best_acc1 = checkpoint['best_acc1'] 142 | print("=> loaded resume checkpoint '{}' (epoch {})" 143 | 
.format(args.resume, epoch)) 144 | else: 145 | print("=> no checkpoint found at '{}'".format(args.resume)) 146 | else: 147 | # auto-resume from latest checkpoint in output directory 148 | latest = os.path.join(args.output_dir, 'checkpoint.pt') 149 | if os.path.isfile(latest): 150 | print("=> loading latest checkpoint '{}'".format(latest)) 151 | latest_checkpoint = torch.load(latest, map_location='cpu') 152 | args.start_epoch = latest_checkpoint['epoch'] 153 | model.load_state_dict(latest_checkpoint['state_dict']) 154 | optimizer.load_state_dict(latest_checkpoint['optimizer']) 155 | scaler.load_state_dict(latest_checkpoint['scaler']) 156 | best_acc1 = latest_checkpoint['best_acc1'] 157 | print("=> loaded latest checkpoint '{}' (epoch {})" 158 | .format(latest, latest_checkpoint['epoch'])) 159 | 160 | cudnn.benchmark = True 161 | 162 | # Data loading code 163 | print("=> creating dataset") 164 | tokenizer = SimpleTokenizer() 165 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 166 | std=[0.229, 0.224, 0.225]) 167 | train_transform = transforms.Compose([ 168 | transforms.RandomResizedCrop(224, scale=(0.5, 1.0)), 169 | transforms.ToTensor(), 170 | normalize 171 | ]) 172 | val_transform = transforms.Compose([ 173 | transforms.Resize(224), 174 | transforms.CenterCrop(224), 175 | transforms.ToTensor(), 176 | normalize 177 | ]) 178 | 179 | train_dataset = datasets.get_dataset(train_transform, tokenizer, args) 180 | cwd = os.path.dirname(os.path.realpath(__file__)) 181 | with open(os.path.join(cwd, 'dataset_catalog.json')) as f: 182 | root = json.load(f)['imagenet']['path'] 183 | val_dataset = ImageFolder(os.path.join(root, 'val'), val_transform) 184 | 185 | # dist eval resamples data to pad uneven batch sizes 186 | # make sure num_samples = 0 mod num_gpus for exact acc 187 | if args.distributed: 188 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 189 | val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset) 190 | else: 191 | train_sampler = None 192 | val_sampler = None 193 | 194 | train_loader = torch.utils.data.DataLoader( 195 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), 196 | num_workers=args.workers, pin_memory=True, sampler=train_sampler, drop_last=True) 197 | 198 | val_loader = torch.utils.data.DataLoader( 199 | val_dataset, batch_size=args.batch_size, shuffle=(val_sampler is None), 200 | num_workers=args.workers, pin_memory=True, sampler=val_sampler, drop_last=False) 201 | 202 | if args.evaluate: 203 | if args.model.startswith('SIMCLR'): 204 | print('zero-shot evaluation not supported with ssl-only model.') 205 | return 206 | 207 | zero_stats = validate_zeroshot(val_loader, model, tokenizer, args) 208 | if utils.is_main_process(): 209 | with open(os.path.join(args.output_dir, 'eval_log.txt'), 'a') as f: 210 | f.write(json.dumps(zero_stats) + '\n') 211 | return 212 | 213 | lr_schedule = utils.cosine_scheduler(args.lr, args.lr_end, args.epochs, 214 | len(train_loader) // args.update_freq, warmup_epochs=args.warmup_epochs, start_warmup_value=args.lr_start) 215 | 216 | if utils.is_main_process() and args.wandb: 217 | wandb_id = os.path.split(args.output_dir)[-1] 218 | wandb.init(project='slip', id=wandb_id, config=args, resume='allow') 219 | 220 | print(args) 221 | 222 | print("=> beginning training") 223 | for epoch in range(args.start_epoch, args.epochs): 224 | if args.distributed: 225 | train_sampler.set_epoch(epoch) 226 | 227 | # train for one epoch 228 | train_stats = train(train_loader, 
model, criterion, optimizer, scaler, epoch, lr_schedule, args) 229 | 230 | if (epoch + 1) % args.eval_freq != 0: 231 | continue 232 | 233 | if args.model.startswith('SIMCLR'): 234 | val_stats = {'acc1': -1} 235 | acc1 = -1 236 | else: 237 | val_stats = validate_zeroshot(val_loader, model, tokenizer, args) 238 | acc1 = val_stats['acc1'] 239 | 240 | is_best = acc1 > best_acc1 241 | best_acc1 = max(acc1, best_acc1) 242 | 243 | print("=> saving checkpoint") 244 | utils.save_on_master({ 245 | 'epoch': epoch + 1, 246 | 'state_dict': model.state_dict(), 247 | 'optimizer' : optimizer.state_dict(), 248 | 'scaler': scaler.state_dict(), 249 | 'best_acc1': best_acc1, 250 | 'args': args, 251 | }, is_best, args.output_dir) 252 | 253 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 254 | **{f'test_{k}': v for k, v in val_stats.items()}, 255 | 'epoch': epoch} 256 | 257 | if utils.is_main_process(): 258 | if args.wandb: 259 | wandb.log(log_stats) 260 | with open(os.path.join(args.output_dir, 'log.txt'), 'a') as f: 261 | f.write(json.dumps(log_stats) + '\n') 262 | 263 | 264 | def train(train_loader, model, criterion, optimizer, scaler, epoch, lr_schedule, args): 265 | batch_time = AverageMeter('Time', ':6.2f') 266 | data_time = AverageMeter('Data', ':6.2f') 267 | mem = AverageMeter('Mem (GB)', ':6.1f') 268 | metric_names = models.get_metric_names(args.model) 269 | iters_per_epoch = len(train_loader) // args.update_freq 270 | metrics = OrderedDict([(name, AverageMeter(name, ':.2e')) for name in metric_names]) 271 | progress = ProgressMeter( 272 | iters_per_epoch, 273 | [batch_time, data_time, mem, *metrics.values()], 274 | prefix="Epoch: [{}]".format(epoch)) 275 | 276 | # switch to train mode 277 | model.train() 278 | 279 | end = time.time() 280 | for data_iter, inputs in enumerate(train_loader): 281 | optim_iter = data_iter // args.update_freq 282 | 283 | # measure data loading time 284 | data_time.update(time.time() - end) 285 | 286 | # update weight decay and learning rate according to their schedule 287 | it = iters_per_epoch * epoch + optim_iter # global training iteration 288 | for k, param_group in enumerate(optimizer.param_groups): 289 | param_group['lr'] = lr_schedule[it] 290 | 291 | inputs = [tensor.cuda(args.gpu, non_blocking=True) for tensor in inputs] 292 | 293 | # compute output 294 | with amp.autocast(enabled=not args.disable_amp): 295 | outputs = model(*inputs) 296 | loss_dict = criterion(outputs) 297 | loss = loss_dict['loss'] 298 | loss /= args.update_freq 299 | 300 | if not math.isfinite(loss.item()): 301 | print("Loss is {}, stopping training".format(loss.item())) 302 | sys.exit(1) 303 | 304 | scaler.scale(loss).backward() 305 | 306 | if (data_iter + 1) % args.update_freq != 0: 307 | continue 308 | 309 | # compute gradient and do SGD step 310 | scaler.step(optimizer) 311 | scaler.update() 312 | model.zero_grad(set_to_none=True) 313 | 314 | # clamp logit scale to [0, 100] 315 | if args.model.startswith('SIMCLR'): 316 | logit_scale = 0 317 | else: 318 | utils.get_model(model).logit_scale.data.clamp_(0, 4.6052) 319 | logit_scale = utils.get_model(model).logit_scale.exp().item() 320 | 321 | for k in loss_dict: 322 | metrics[k].update(loss_dict[k].item(), args.batch_size) 323 | 324 | # measure elapsed time 325 | batch_time.update(time.time() - end) 326 | end = time.time() 327 | 328 | mem.update(torch.cuda.max_memory_allocated() // 1e9) 329 | 330 | if optim_iter % args.print_freq == 0: 331 | if utils.is_main_process() and args.wandb: 332 | wandb.log({**{k: v.item() for k, v 
in loss_dict.items()}, 333 | 'scaler': scaler.get_scale(), 334 | 'logit': logit_scale}) 335 | progress.display(optim_iter) 336 | 337 | progress.synchronize() 338 | return {**{k: v.avg for k, v in metrics.items()}, 339 | 'lr': optimizer.param_groups[0]['lr'], 340 | 'logit_scale': logit_scale} 341 | 342 | 343 | def validate_zeroshot(val_loader, model, tokenizer, args): 344 | batch_time = AverageMeter('Time', ':6.3f') 345 | top1 = AverageMeter('Acc@1', ':6.2f') 346 | top5 = AverageMeter('Acc@5', ':6.2f') 347 | progress = ProgressMeter( 348 | len(val_loader), 349 | [batch_time, top1, top5], 350 | prefix='Test: ') 351 | 352 | # switch to evaluate mode 353 | model.eval() 354 | 355 | print('=> encoding captions') 356 | cwd = os.path.dirname(os.path.realpath(__file__)) 357 | with open(os.path.join(cwd, 'templates.json')) as f: 358 | templates = json.load(f)['imagenet'] 359 | 360 | with open(os.path.join(cwd, 'labels.json')) as f: 361 | labels = json.load(f)['imagenet'] 362 | 363 | with torch.no_grad(): 364 | text_features = [] 365 | for l in labels: 366 | texts = [t.format(l) for t in templates] 367 | texts = tokenizer(texts).cuda(args.gpu, non_blocking=True) 368 | class_embeddings = utils.get_model(model).encode_text(texts) 369 | class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True) 370 | class_embeddings = class_embeddings.mean(dim=0) 371 | class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True) 372 | text_features.append(class_embeddings) 373 | text_features = torch.stack(text_features, dim=0) 374 | 375 | end = time.time() 376 | for i, (images, target) in enumerate(val_loader): 377 | images = images.cuda(args.gpu, non_blocking=True) 378 | target = target.cuda(args.gpu, non_blocking=True) 379 | 380 | # encode images 381 | image_features = utils.get_model(model).encode_image(images) 382 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 383 | 384 | # cosine similarity as logits 385 | logits_per_image = image_features @ text_features.t() 386 | 387 | # measure accuracy and record loss 388 | acc1, acc5 = accuracy(logits_per_image, target, topk=(1, 5)) 389 | acc1, acc5 = utils.scaled_all_reduce([acc1, acc5]) 390 | top1.update(acc1.item(), images.size(0)) 391 | top5.update(acc5.item(), images.size(0)) 392 | 393 | # measure elapsed time 394 | batch_time.update(time.time() - end) 395 | end = time.time() 396 | 397 | if i % args.print_freq == 0: 398 | progress.display(i) 399 | 400 | progress.synchronize() 401 | print('0-shot * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}' 402 | .format(top1=top1, top5=top5)) 403 | return {'acc1': top1.avg, 'acc5': top5.avg} 404 | 405 | 406 | class AverageMeter(object): 407 | """Computes and stores the average and current value""" 408 | def __init__(self, name, fmt=':f'): 409 | self.name = name 410 | self.fmt = fmt 411 | self.reset() 412 | 413 | def reset(self): 414 | self.val = 0 415 | self.avg = 0 416 | self.sum = 0 417 | self.count = 0 418 | 419 | def update(self, val, n=1): 420 | self.val = val 421 | self.sum += val * n 422 | self.count += n 423 | self.avg = self.sum / self.count 424 | 425 | def synchronize(self): 426 | if not utils.is_dist_avail_and_initialized(): 427 | return 428 | t = torch.tensor([self.sum, self.count], dtype=torch.float64, device='cuda') 429 | dist.barrier() 430 | dist.all_reduce(t) 431 | t = t.tolist() 432 | self.sum = int(t[0]) 433 | self.count = t[1] 434 | self.avg = self.sum / self.count 435 | 436 | def __str__(self): 437 | fmtstr = '{name} {val' + self.fmt + '} 
({avg' + self.fmt + '})' 438 | return fmtstr.format(**self.__dict__) 439 | 440 | 441 | class ProgressMeter(object): 442 | def __init__(self, num_batches, meters, prefix=""): 443 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 444 | self.meters = meters 445 | self.prefix = prefix 446 | 447 | def display(self, batch): 448 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 449 | entries += [str(meter) for meter in self.meters] 450 | print('\t'.join(entries)) 451 | 452 | def synchronize(self): 453 | for meter in self.meters: 454 | meter.synchronize() 455 | 456 | def _get_batch_fmtstr(self, num_batches): 457 | num_digits = len(str(num_batches // 1)) 458 | fmt = '{:' + str(num_digits) + 'd}' 459 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 460 | 461 | 462 | def accuracy(output, target, topk=(1,)): 463 | """Computes the accuracy over the k top predictions for the specified values of k""" 464 | with torch.no_grad(): 465 | maxk = max(topk) 466 | batch_size = target.size(0) 467 | 468 | _, pred = output.topk(maxk, 1, True, True) 469 | pred = pred.t() 470 | correct = pred.eq(target.reshape(1, -1).expand_as(pred)) 471 | 472 | res = [] 473 | for k in topk: 474 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 475 | res.append(correct_k.mul_(100.0 / batch_size)) 476 | return res 477 | 478 | 479 | if __name__ == '__main__': 480 | parser = argparse.ArgumentParser('SLIP training and evaluation', parents=[get_args_parser()]) 481 | args = parser.parse_args() 482 | os.makedirs(args.output_dir, exist_ok=True) 483 | main(args) 484 | -------------------------------------------------------------------------------- /main_linear.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
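# Linear probing, as implemented in main() below, freezes every backbone parameter and trains
# only the final classification head (`linear_keyword = 'head'`). A minimal sketch of that
# freezing step with generic names (the helper name is illustrative, not part of this script):
def _sketch_freeze_all_but_head(model, head_name='head'):
    for name, param in model.named_parameters():
        # only the head's .weight and .bias remain trainable
        param.requires_grad = name in ('%s.weight' % head_name, '%s.bias' % head_name)
    # the optimizer should then receive only the surviving parameters
    return [p for p in model.parameters() if p.requires_grad]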
6 | import argparse 7 | import json 8 | import math 9 | import os 10 | import random 11 | import shutil 12 | import time 13 | import timm 14 | import warnings 15 | 16 | import torch 17 | import torch.nn as nn 18 | import torch.nn.parallel 19 | import torch.backends.cudnn as cudnn 20 | import torch.optim 21 | import torch.utils.data 22 | import torch.utils.data.distributed 23 | import torchvision.transforms as transforms 24 | 25 | import datasets 26 | import utils 27 | 28 | 29 | def get_args_parser(): 30 | parser = argparse.ArgumentParser(description='Linear probe evaluation', add_help=False) 31 | parser.add_argument('--dataset', default='imagenet', help='dataset name') 32 | parser.add_argument('--output-dir', default='./', type=str) 33 | parser.add_argument('-a', '--arch', metavar='ARCH', default='vit_base_patch16_224', 34 | help='model architecture: (default: ViT-B/16)') 35 | parser.add_argument('-j', '--workers', default=64, type=int, metavar='N', 36 | help='number of data loading workers (default: 64)') 37 | parser.add_argument('--epochs', default=90, type=int, metavar='N', 38 | help='number of total epochs to run') 39 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 40 | help='manual epoch number (useful on restarts)') 41 | parser.add_argument('-b', '--batch-size', default=128, type=int, 42 | metavar='N', 43 | help='number of samples per-device/per-gpu ') 44 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, 45 | metavar='LR', help='initial (base) learning rate', dest='lr') 46 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 47 | help='momentum') 48 | parser.add_argument('--wd', '--weight-decay', default=0., type=float, 49 | metavar='W', help='weight decay (default: 0.)', 50 | dest='weight_decay') 51 | parser.add_argument('-p', '--print-freq', default=10, type=int, 52 | metavar='N', help='print frequency (default: 10)') 53 | parser.add_argument('--eval-freq', default=10, type=int) 54 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 55 | help='path to latest checkpoint (default: none)') 56 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 57 | help='evaluate model on validation set') 58 | parser.add_argument('--world-size', default=1, type=int, 59 | help='number of nodes for distributed training') 60 | parser.add_argument('--rank', default=0, type=int, 61 | help='node rank for distributed training') 62 | parser.add_argument("--local_rank", type=int, default=0) 63 | parser.add_argument('--dist-url', default='env://', type=str, 64 | help='url used to set up distributed training') 65 | parser.add_argument('--dist-backend', default='nccl', type=str, 66 | help='distributed backend') 67 | parser.add_argument('--seed', default=None, type=int, 68 | help='seed for initializing training. ') 69 | parser.add_argument('--gpu', default=None, type=int, 70 | help='GPU id to use.') 71 | parser.add_argument('--pretrained', default='', type=str, 72 | help='path to CLIP pretrained checkpoint') 73 | return parser 74 | 75 | best_acc1 = 0 76 | 77 | 78 | def main(args): 79 | utils.init_distributed_mode(args) 80 | 81 | global best_acc1 82 | 83 | if args.seed is not None: 84 | random.seed(args.seed) 85 | torch.manual_seed(args.seed) 86 | cudnn.deterministic = True 87 | warnings.warn('You have chosen to seed training. ' 88 | 'This will turn on the CUDNN deterministic setting, ' 89 | 'which can slow down your training considerably! 
' 90 | 'You may see unexpected behavior when restarting ' 91 | 'from checkpoints.') 92 | 93 | linear_keyword = 'head' 94 | if os.path.isfile(args.pretrained): 95 | print("=> loading checkpoint '{}'".format(args.pretrained)) 96 | 97 | if args.gpu is None: 98 | checkpoint = torch.load(args.pretrained) 99 | else: 100 | # Map model to be loaded to specified single gpu. 101 | loc = 'cuda:{}'.format(args.gpu) 102 | checkpoint = torch.load(args.pretrained, map_location=loc) 103 | 104 | visual_keyword = 'module.visual.' 105 | 106 | # rename CLIP pre-trained keys 107 | state_dict = checkpoint['state_dict'] 108 | for k in list(state_dict.keys()): 109 | # retain only base_encoder up to before the embedding layer 110 | if k.startswith(visual_keyword) and not k.startswith(visual_keyword + linear_keyword): 111 | # remove prefix 112 | state_dict[k[len(visual_keyword):]] = state_dict[k] 113 | # delete renamed or unused k 114 | del state_dict[k] 115 | else: 116 | raise Exception('Missing pretrained model checkpoint: {}'.format(args.pretrained)) 117 | 118 | # create model 119 | print("=> creating model '{}'".format(args.arch)) 120 | model = timm.models.create_model(args.arch, num_classes=1000) 121 | 122 | args.start_epoch = 0 123 | msg = model.load_state_dict(state_dict, strict=False) 124 | assert set(msg.missing_keys) == {"%s.weight" % linear_keyword, "%s.bias" % linear_keyword} 125 | 126 | # freeze all layers but the last fc 127 | for name, param in model.named_parameters(): 128 | if name not in ['%s.weight' % linear_keyword, '%s.bias' % linear_keyword]: 129 | param.requires_grad = False 130 | # init the fc layer 131 | getattr(model, linear_keyword).weight.data.normal_(mean=0.0, std=0.01) 132 | getattr(model, linear_keyword).bias.data.zero_() 133 | 134 | init_lr = args.lr * int(args.batch_size / utils.get_world_size()) / 256 135 | args.workers = int((args.workers + utils.get_world_size() - 1) / utils.get_world_size()) 136 | 137 | model.cuda(args.gpu) 138 | 139 | if args.distributed: 140 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 141 | 142 | # define loss function (criterion) and optimizer 143 | criterion = nn.CrossEntropyLoss().cuda(args.gpu) 144 | 145 | # optimize only the linear classifier 146 | parameters = list(filter(lambda p: p.requires_grad, model.parameters())) 147 | assert len(parameters) == 2 # weight, bias 148 | 149 | optimizer = torch.optim.SGD(parameters, init_lr, 150 | momentum=args.momentum, 151 | weight_decay=args.weight_decay) 152 | 153 | # optionally resume from a checkpoint 154 | if args.resume: 155 | if os.path.isfile(args.resume): 156 | print("=> loading checkpoint '{}'".format(args.resume)) 157 | if args.gpu is None: 158 | checkpoint = torch.load(args.resume) 159 | else: 160 | # Map model to be loaded to specified single gpu. 
161 | loc = 'cuda:{}'.format(args.gpu) 162 | checkpoint = torch.load(args.resume, map_location=loc) 163 | args.start_epoch = checkpoint['epoch'] 164 | best_acc1 = checkpoint['best_acc1'] 165 | if args.gpu is not None: 166 | # best_acc1 may be from a checkpoint from a different GPU 167 | best_acc1 = best_acc1.to(args.gpu) 168 | model.load_state_dict(checkpoint['state_dict']) 169 | optimizer.load_state_dict(checkpoint['optimizer']) 170 | print("=> loaded checkpoint '{}' (epoch {})" 171 | .format(args.resume, checkpoint['epoch'])) 172 | else: 173 | print("=> no checkpoint found at '{}'".format(args.resume)) 174 | 175 | cudnn.benchmark = True 176 | 177 | # Data loading code 178 | cwd = os.path.dirname(os.path.realpath(__file__)) 179 | with open(os.path.join(cwd, 'dataset_catalog.json')) as f: 180 | catalog = json.load(f) 181 | 182 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 183 | std=[0.229, 0.224, 0.225]) 184 | 185 | train_transform = transforms.Compose([ 186 | transforms.RandomResizedCrop(224), 187 | transforms.RandomHorizontalFlip(), 188 | lambda x: x.convert('RGB'), 189 | transforms.ToTensor(), 190 | normalize, 191 | ]) 192 | val_transform = transforms.Compose([ 193 | transforms.Resize(256), 194 | transforms.CenterCrop(224), 195 | lambda x: x.convert('RGB'), 196 | transforms.ToTensor(), 197 | normalize, 198 | ]) 199 | 200 | train_dataset = datasets.get_downstream_dataset(catalog, args.dataset, is_train=True, transform=train_transform) 201 | val_dataset = datasets.get_downstream_dataset(catalog, args.dataset, is_train=False, transform=val_transform) 202 | 203 | if args.distributed: 204 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 205 | else: 206 | train_sampler = None 207 | 208 | train_loader = torch.utils.data.DataLoader( 209 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), 210 | num_workers=args.workers, pin_memory=True, sampler=train_sampler) 211 | 212 | val_loader = torch.utils.data.DataLoader( 213 | val_dataset, batch_size=256, shuffle=False, 214 | num_workers=args.workers, pin_memory=True) 215 | 216 | if args.evaluate: 217 | validate(val_loader, model, criterion, args) 218 | return 219 | 220 | print(args) 221 | 222 | for epoch in range(args.start_epoch, args.epochs): 223 | if args.distributed: 224 | train_sampler.set_epoch(epoch) 225 | adjust_learning_rate(optimizer, init_lr, epoch, args) 226 | 227 | # train for one epoch 228 | train_stats = train(train_loader, model, criterion, optimizer, epoch, args) 229 | 230 | if (epoch + 1) % args.eval_freq != 0: 231 | continue 232 | 233 | # evaluate on validation set 234 | val_stats = validate(val_loader, model, criterion, args) 235 | acc1 = val_stats['acc1'] 236 | 237 | # remember best acc@1 and save checkpoint 238 | is_best = acc1 > best_acc1 239 | best_acc1 = max(acc1, best_acc1) 240 | 241 | if utils.is_main_process(): # only the first GPU saves checkpoint 242 | save_checkpoint({ 243 | 'epoch': epoch + 1, 244 | 'arch': args.arch, 245 | 'state_dict': model.state_dict(), 246 | 'best_acc1': best_acc1, 247 | 'optimizer' : optimizer.state_dict(), 248 | }, is_best, args.output_dir) 249 | if epoch == args.start_epoch: 250 | sanity_check(model.state_dict(), args.pretrained, linear_keyword, visual_keyword) 251 | 252 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 253 | **{f'test_{k}': v for k, v in val_stats.items()}, 254 | 'epoch': epoch} 255 | 256 | if utils.is_main_process(): 257 | with open(os.path.join(args.output_dir, 
'linear_{}_lr={}_log.txt'.format(args.dataset, args.lr)), 'a') as f: 258 | f.write(json.dumps(log_stats) + '\n') 259 | 260 | 261 | def train(train_loader, model, criterion, optimizer, epoch, args): 262 | batch_time = AverageMeter('Time', ':6.3f') 263 | data_time = AverageMeter('Data', ':6.3f') 264 | losses = AverageMeter('Loss', ':.4e') 265 | top1 = AverageMeter('Acc@1', ':6.2f') 266 | top5 = AverageMeter('Acc@5', ':6.2f') 267 | progress = ProgressMeter( 268 | len(train_loader), 269 | [batch_time, data_time, losses, top1, top5], 270 | prefix="Epoch: [{}]".format(epoch)) 271 | 272 | """ 273 | Switch to eval mode: 274 | Under the protocol of linear classification on frozen features/models, 275 | it is not legitimate to change any part of the pre-trained model. 276 | BatchNorm in train mode may revise running mean/std (even if it receives 277 | no gradient), which are part of the model parameters too. 278 | """ 279 | model.eval() 280 | 281 | end = time.time() 282 | for i, (images, target) in enumerate(train_loader): 283 | # measure data loading time 284 | data_time.update(time.time() - end) 285 | 286 | if args.gpu is not None: 287 | images = images.cuda(args.gpu, non_blocking=True) 288 | if torch.cuda.is_available(): 289 | target = target.cuda(args.gpu, non_blocking=True) 290 | 291 | # compute output 292 | output = model(images) 293 | loss = criterion(output, target) 294 | 295 | # measure accuracy and record loss 296 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 297 | losses.update(loss.item(), images.size(0)) 298 | top1.update(acc1.item(), images.size(0)) 299 | top5.update(acc5.item(), images.size(0)) 300 | 301 | # compute gradient and do SGD step 302 | optimizer.zero_grad() 303 | loss.backward() 304 | optimizer.step() 305 | 306 | # measure elapsed time 307 | batch_time.update(time.time() - end) 308 | end = time.time() 309 | 310 | if i % args.print_freq == 0: 311 | progress.display(i) 312 | 313 | return {'acc1': top1.avg, 'acc5': top5.avg, 'loss': losses.avg} 314 | 315 | 316 | def validate(val_loader, model, criterion, args): 317 | batch_time = AverageMeter('Time', ':6.3f') 318 | losses = AverageMeter('Loss', ':.4e') 319 | top1 = AverageMeter('Acc@1', ':6.2f') 320 | top5 = AverageMeter('Acc@5', ':6.2f') 321 | progress = ProgressMeter( 322 | len(val_loader), 323 | [batch_time, losses, top1, top5], 324 | prefix='Test: ') 325 | 326 | # switch to evaluate mode 327 | model.eval() 328 | 329 | with torch.no_grad(): 330 | end = time.time() 331 | for i, (images, target) in enumerate(val_loader): 332 | if args.gpu is not None: 333 | images = images.cuda(args.gpu, non_blocking=True) 334 | if torch.cuda.is_available(): 335 | target = target.cuda(args.gpu, non_blocking=True) 336 | 337 | # compute output 338 | output = model(images) 339 | loss = criterion(output, target) 340 | 341 | # measure accuracy and record loss 342 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 343 | losses.update(loss.item(), images.size(0)) 344 | top1.update(acc1.item(), images.size(0)) 345 | top5.update(acc5.item(), images.size(0)) 346 | 347 | # measure elapsed time 348 | batch_time.update(time.time() - end) 349 | end = time.time() 350 | 351 | if i % args.print_freq == 0: 352 | progress.display(i) 353 | 354 | # TODO: this should also be done with the ProgressMeter 355 | print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}' 356 | .format(top1=top1, top5=top5)) 357 | 358 | return {'acc1': top1.avg, 'acc5': top5.avg, 'loss': losses.avg} 359 | 360 | 361 | def save_checkpoint(state, is_best, output_dir): 362 | ckpt_path = 
f'{output_dir}/linear_checkpoint.pt' 363 | best_path = f'{output_dir}/linear_best.pt' 364 | torch.save(state, ckpt_path) 365 | if is_best: 366 | shutil.copyfile(ckpt_path, best_path) 367 | 368 | 369 | def sanity_check(state_dict, pretrained_weights, linear_keyword, visual_keyword): 370 | """ 371 | Linear classifier should not change any weights other than the linear layer. 372 | This sanity check asserts nothing wrong happens (e.g., BN stats updated). 373 | """ 374 | print("=> loading '{}' for sanity check".format(pretrained_weights)) 375 | checkpoint = torch.load(pretrained_weights, map_location="cpu") 376 | state_dict_pre = checkpoint['state_dict'] 377 | 378 | for k in list(state_dict.keys()): 379 | # only ignore linear layer 380 | if '%s.weight' % linear_keyword in k or '%s.bias' % linear_keyword in k: 381 | continue 382 | 383 | # name in pretrained model 384 | k_pre = visual_keyword + k[len('module.'):] \ 385 | if k.startswith('module.') else visual_keyword + k 386 | 387 | assert ((state_dict[k].cpu() == state_dict_pre[k_pre]).all()), \ 388 | '{} is changed in linear classifier training.'.format(k) 389 | 390 | print("=> sanity check passed.") 391 | 392 | 393 | class AverageMeter(object): 394 | """Computes and stores the average and current value""" 395 | def __init__(self, name, fmt=':f'): 396 | self.name = name 397 | self.fmt = fmt 398 | self.reset() 399 | 400 | def reset(self): 401 | self.val = 0 402 | self.avg = 0 403 | self.sum = 0 404 | self.count = 0 405 | 406 | def update(self, val, n=1): 407 | self.val = val 408 | self.sum += val * n 409 | self.count += n 410 | self.avg = self.sum / self.count 411 | 412 | def __str__(self): 413 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 414 | return fmtstr.format(**self.__dict__) 415 | 416 | 417 | class ProgressMeter(object): 418 | def __init__(self, num_batches, meters, prefix=""): 419 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 420 | self.meters = meters 421 | self.prefix = prefix 422 | 423 | def display(self, batch): 424 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 425 | entries += [str(meter) for meter in self.meters] 426 | print('\t'.join(entries)) 427 | 428 | def _get_batch_fmtstr(self, num_batches): 429 | num_digits = len(str(num_batches // 1)) 430 | fmt = '{:' + str(num_digits) + 'd}' 431 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 432 | 433 | 434 | def adjust_learning_rate(optimizer, init_lr, epoch, args): 435 | """Decay the learning rate based on schedule""" 436 | cur_lr = init_lr * 0.5 * (1. 
+ math.cos(math.pi * epoch / args.epochs)) 437 | for param_group in optimizer.param_groups: 438 | param_group['lr'] = cur_lr 439 | 440 | 441 | def accuracy(output, target, topk=(1,)): 442 | """Computes the accuracy over the k top predictions for the specified values of k""" 443 | with torch.no_grad(): 444 | maxk = max(topk) 445 | batch_size = target.size(0) 446 | 447 | _, pred = output.topk(maxk, 1, True, True) 448 | pred = pred.t() 449 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 450 | 451 | res = [] 452 | for k in topk: 453 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 454 | res.append(correct_k.mul_(100.0 / batch_size)) 455 | return res 456 | 457 | 458 | if __name__ == '__main__': 459 | parser = argparse.ArgumentParser('Linear probe evaluation', parents=[get_args_parser()]) 460 | args = parser.parse_args() 461 | main(args) 462 | -------------------------------------------------------------------------------- /make_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | import numpy as np 7 | import pickle 8 | import re 9 | from urllib.parse import unquote 10 | from tqdm import tqdm 11 | 12 | 13 | DATASET = 'yfcc100m_dataset.txt' 14 | 15 | cleanhtml = re.compile('<a.*?>|</a>|<b>|</b>|<i>|</i>') 16 | cleanurl = re.compile('http\S+|www\S+') 17 | 18 | print('=> loading YFCC image ids') 19 | image_ids = np.load('flickr_unique_ids.npy') 20 | image_ids = set(image_ids) 21 | 22 | print('=> loading CLIP image ids') 23 | clip_ids = set() 24 | with open('yfcc100m_subset_data.tsv') as f: 25 | for l in tqdm(f.readlines()): 26 | row = l.strip().split('\t') 27 | clip_ids.add(int(row[0])) 28 | 29 | print('=> collecting and cleaning subset captions') 30 | captioned = [] 31 | uncaptioned = [] 32 | with open('yfcc100m_dataset.txt') as f: 33 | for l in tqdm(f.readlines()): 34 | row = l.strip().split('\t') 35 | if int(row[0]) in image_ids: 36 | uncaptioned.append(int(row[0])) 37 | if int(row[0]) in clip_ids: 38 | title = unquote(row[8]).replace('+', ' ') 39 | title = re.sub(cleanhtml, '', title) 40 | title = re.sub(cleanurl, '', title) 41 | 42 | desc = unquote(row[9]).replace('+', ' ') 43 | desc = re.sub(cleanhtml, '', desc) 44 | desc = re.sub(cleanurl, '', desc) 45 | 46 | captioned.append((int(row[0]), title, desc)) 47 | 48 | with open('yfcc15m.pkl', 'wb') as f: 49 | pickle.dump(captioned, f) 50 | 51 | with open('yfcc100m.pkl', 'wb') as f: 52 | pickle.dump(uncaptioned, f) 53 | 54 | print('Total captioned images:', len(captioned)) # 14689580 55 | print('Total uncaptioned images:', len(uncaptioned)) # 95920149 56 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
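# The text Transformer defined below uses an additive causal attention mask
# (see CLIP.build_attention_mask): entries strictly above the diagonal are -inf,
# so token i can attend only to tokens 0..i. For a context length of 3 the mask is
#   [[0., -inf, -inf],
#    [0.,   0., -inf],
#    [0.,   0.,   0.]]
# A minimal standalone sketch of that construction (equivalent to, not imported from,
# the method below; the helper name is illustrative):
def _sketch_causal_mask(context_length):
    import torch
    mask = torch.full((context_length, context_length), float('-inf'))
    mask.triu_(1)  # keep -inf strictly above the diagonal, zero on and below it
    return mask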
6 | 7 | # Modified from github.com/openai/CLIP 8 | from collections import OrderedDict 9 | 10 | import numpy as np 11 | import timm 12 | import torch 13 | from torch import nn 14 | 15 | import losses 16 | 17 | 18 | class LayerNorm(nn.LayerNorm): 19 | """Subclass torch's LayerNorm to handle fp16.""" 20 | 21 | def forward(self, x: torch.Tensor): 22 | orig_type = x.dtype 23 | ret = super().forward(x.type(torch.float32)) 24 | return ret.type(orig_type) 25 | 26 | 27 | class QuickGELU(nn.Module): 28 | def forward(self, x: torch.Tensor): 29 | return x * torch.sigmoid(1.702 * x) 30 | 31 | 32 | class ResidualAttentionBlock(nn.Module): 33 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): 34 | super().__init__() 35 | 36 | self.attn = nn.MultiheadAttention(d_model, n_head) 37 | self.ln_1 = LayerNorm(d_model) 38 | self.mlp = nn.Sequential(OrderedDict([ 39 | ("c_fc", nn.Linear(d_model, d_model * 4)), 40 | ("gelu", QuickGELU()), 41 | ("c_proj", nn.Linear(d_model * 4, d_model)) 42 | ])) 43 | self.ln_2 = LayerNorm(d_model) 44 | self.attn_mask = attn_mask 45 | 46 | def attention(self, x: torch.Tensor): 47 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None 48 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] 49 | 50 | def forward(self, x: torch.Tensor): 51 | x = x + self.attention(self.ln_1(x)) 52 | x = x + self.mlp(self.ln_2(x)) 53 | return x 54 | 55 | 56 | class Transformer(nn.Module): 57 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): 58 | super().__init__() 59 | self.width = width 60 | self.layers = layers 61 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) 62 | 63 | def forward(self, x: torch.Tensor): 64 | return self.resblocks(x) 65 | 66 | 67 | class CLIP(nn.Module): 68 | def __init__(self, 69 | embed_dim: int, 70 | # vision 71 | vision_width: int, 72 | vision_model: nn.Module, 73 | # text 74 | context_length: int, 75 | vocab_size: int, 76 | transformer_width: int, 77 | transformer_heads: int, 78 | transformer_layers: int, 79 | **kwargs, 80 | ): 81 | super().__init__() 82 | 83 | self.context_length = context_length 84 | self.vision_width = vision_width 85 | 86 | self.visual = vision_model 87 | 88 | self.transformer = Transformer( 89 | width=transformer_width, 90 | layers=transformer_layers, 91 | heads=transformer_heads, 92 | attn_mask=self.build_attention_mask(), 93 | ) 94 | 95 | self.vocab_size = vocab_size 96 | self.token_embedding = nn.Embedding(vocab_size, transformer_width) 97 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) 98 | self.ln_final = LayerNorm(transformer_width) 99 | 100 | self.image_projection = nn.Parameter(torch.empty(vision_width, embed_dim)) 101 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) 102 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 103 | 104 | self.initialize_parameters() 105 | 106 | def initialize_parameters(self): 107 | nn.init.normal_(self.token_embedding.weight, std=0.02) 108 | nn.init.normal_(self.positional_embedding, std=0.01) 109 | 110 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) 111 | attn_std = self.transformer.width ** -0.5 112 | fc_std = (2 * self.transformer.width) ** -0.5 113 | for block in self.transformer.resblocks: 114 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 115 | 
nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 116 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 117 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 118 | 119 | nn.init.normal_(self.image_projection, std=self.vision_width ** -0.5) 120 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) 121 | 122 | def build_attention_mask(self): 123 | # lazily create causal attention mask, with full attention between the vision tokens 124 | # pytorch uses additive attention mask; fill with -inf 125 | mask = torch.empty(self.context_length, self.context_length) 126 | mask.fill_(float("-inf")) 127 | mask.triu_(1) # zero out the lower diagonal 128 | return mask 129 | 130 | def encode_image(self, image): 131 | x = self.visual(image) 132 | x = x @ self.image_projection 133 | 134 | return x 135 | 136 | def encode_text(self, text): 137 | x = self.token_embedding(text) # [batch_size, n_ctx, d_model] 138 | x = x + self.positional_embedding 139 | x = x.permute(1, 0, 2) # NLD -> LND 140 | x = self.transformer(x) 141 | x = x.permute(1, 0, 2) # LND -> NLD 142 | x = self.ln_final(x) 143 | 144 | # x.shape = [batch_size, n_ctx, transformer.width] 145 | # take features from the eot embedding (eot_token is the highest number in each sequence) 146 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection 147 | 148 | return x 149 | 150 | def forward(self, image, text): 151 | image_embed = self.encode_image(image) 152 | text_embed = self.encode_text(text) 153 | 154 | return {'image_embed': image_embed, 155 | 'text_embed': text_embed, 156 | 'logit_scale': self.logit_scale.exp()} 157 | 158 | 159 | class SIMCLR(nn.Module): 160 | def __init__(self, 161 | # vision 162 | vision_width: int, 163 | vision_model: nn.Module, 164 | # ssl 165 | ssl_mlp_dim: int, 166 | ssl_emb_dim: int, 167 | **kwargs, 168 | ): 169 | super().__init__() 170 | 171 | self.vision_width = vision_width 172 | self.visual = vision_model 173 | 174 | self.image_mlp = self._build_mlp(in_dim=vision_width, mlp_dim=ssl_mlp_dim, out_dim=ssl_emb_dim) 175 | 176 | def _build_mlp(self, in_dim, mlp_dim, out_dim): 177 | return nn.Sequential(OrderedDict([ 178 | ("layer1", nn.Linear(in_dim, mlp_dim)), 179 | ("bn1", nn.SyncBatchNorm(mlp_dim)), 180 | ("relu1", nn.ReLU(inplace=True)), 181 | ("layer2", nn.Linear(mlp_dim, mlp_dim)), 182 | ("bn2", nn.SyncBatchNorm(mlp_dim)), 183 | ("relu2", nn.ReLU(inplace=True)), 184 | ("layer3", nn.Linear(mlp_dim, out_dim)), 185 | ])) 186 | 187 | def encode_image(self, image): 188 | x = self.visual(image) 189 | 190 | return x 191 | 192 | def forward(self, aug1, aug2): 193 | h1 = self.visual(aug1) 194 | h2 = self.visual(aug2) 195 | 196 | aug1_embed = self.image_mlp(h1) 197 | aug2_embed = self.image_mlp(h2) 198 | 199 | return {'aug1_embed': aug1_embed, 200 | 'aug2_embed': aug2_embed} 201 | 202 | 203 | class SLIP(CLIP): 204 | def __init__(self, 205 | ssl_mlp_dim: int, 206 | ssl_emb_dim: int, 207 | **kwargs, 208 | ): 209 | super().__init__(**kwargs) 210 | 211 | self.image_mlp = self._build_mlp(in_dim=self.vision_width, mlp_dim=ssl_mlp_dim, out_dim=ssl_emb_dim) 212 | 213 | def _build_mlp(self, in_dim, mlp_dim, out_dim): 214 | return nn.Sequential(OrderedDict([ 215 | ("layer1", nn.Linear(in_dim, mlp_dim)), 216 | ("bn1", nn.SyncBatchNorm(mlp_dim)), 217 | ("relu1", nn.ReLU(inplace=True)), 218 | ("layer2", nn.Linear(mlp_dim, mlp_dim)), 219 | ("bn2", nn.SyncBatchNorm(mlp_dim)), 220 | ("relu2", nn.ReLU(inplace=True)), 221 | ("layer3", nn.Linear(mlp_dim, out_dim)), 222 | ])) 223 | 224 | def 
forward(self, image, text, aug1, aug2): 225 | aug1_embed = self.image_mlp(self.visual(aug1)) 226 | aug2_embed = self.image_mlp(self.visual(aug2)) 227 | 228 | image_embed = self.encode_image(image) 229 | text_embed = self.encode_text(text) 230 | 231 | return {'image_embed': image_embed, 232 | 'text_embed': text_embed, 233 | 'logit_scale': self.logit_scale.exp(), 234 | 'aug1_embed': aug1_embed, 235 | 'aug2_embed': aug2_embed} 236 | 237 | 238 | def get_loss(model, ssl_temp, ssl_scale): 239 | if model.startswith('SLIP'): 240 | ssl_loss = losses.SIMCLRLoss(temperature=ssl_temp) 241 | return losses.SLIPLoss(ssl_loss, ssl_scale) 242 | if model.startswith('CLIP'): 243 | return losses.CLIPLoss() 244 | if model.startswith('SIMCLR'): 245 | return losses.SIMCLRLoss(temperature=ssl_temp) 246 | 247 | 248 | def get_metric_names(model): 249 | if model.startswith('SLIP'): 250 | return ['loss', 'clip_loss', 'ssl_loss', 'clip_acc', 'ssl_acc'] 251 | elif model.startswith('CLIP'): 252 | return ['loss', 'clip_loss', 'clip_acc'] 253 | else: 254 | return ['loss', 'ssl_loss', 'ssl_acc'] 255 | 256 | 257 | @timm.models.registry.register_model 258 | def vit_small_mocov3_patch16_224(**kwargs): 259 | model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=12, **kwargs) 260 | model = timm.models.vision_transformer._create_vision_transformer('vit_small_patch16_224', **model_kwargs) 261 | 262 | return model 263 | 264 | 265 | def CLIP_VITS16(**kwargs): 266 | vision_model = timm.create_model('vit_small_mocov3_patch16_224', num_classes=0) 267 | model = CLIP(embed_dim=512, vision_width=384, vision_model=vision_model, context_length=77, vocab_size=49408, 268 | transformer_width=512, transformer_heads=8, transformer_layers=12, **kwargs) 269 | 270 | return model 271 | 272 | 273 | def SIMCLR_VITS16(**kwargs): 274 | vision_model = timm.create_model('vit_small_mocov3_patch16_224', num_classes=0) 275 | model = SIMCLR(vision_width=384, vision_model=vision_model, **kwargs) 276 | 277 | return model 278 | 279 | 280 | def SLIP_VITS16(**kwargs): 281 | vision_model = timm.create_model('vit_small_mocov3_patch16_224', num_classes=0) 282 | model = SLIP(embed_dim=512, vision_width=384, vision_model=vision_model, context_length=77, vocab_size=49408, 283 | transformer_width=512, transformer_heads=8, transformer_layers=12, **kwargs) 284 | 285 | return model 286 | 287 | 288 | def CLIP_VITB16(**kwargs): 289 | vision_model = timm.create_model('vit_base_patch16_224', num_classes=0) 290 | model = CLIP(embed_dim=512, vision_width=768, vision_model=vision_model, context_length=77, vocab_size=49408, 291 | transformer_width=512, transformer_heads=8, transformer_layers=12, **kwargs) 292 | 293 | return model 294 | 295 | 296 | def SIMCLR_VITB16(**kwargs): 297 | vision_model = timm.create_model('vit_base_patch16_224', num_classes=0) 298 | model = SIMCLR(vision_width=768, vision_model=vision_model, **kwargs) 299 | 300 | return model 301 | 302 | 303 | def SLIP_VITB16(**kwargs): 304 | vision_model = timm.create_model('vit_base_patch16_224', num_classes=0) 305 | model = SLIP(embed_dim=512, vision_width=768, vision_model=vision_model, context_length=77, vocab_size=49408, 306 | transformer_width=512, transformer_heads=8, transformer_layers=12, **kwargs) 307 | 308 | return model 309 | 310 | 311 | def CLIP_VITL16(**kwargs): 312 | vision_model = timm.create_model('vit_large_patch16_224', num_classes=0) 313 | model = CLIP(embed_dim=512, vision_width=1024, vision_model=vision_model, context_length=77, vocab_size=49408, 314 | transformer_width=512, 
transformer_heads=8, transformer_layers=12, **kwargs) 315 | 316 | return model 317 | 318 | 319 | def SIMCLR_VITL16(**kwargs): 320 | vision_model = timm.create_model('vit_large_patch16_224', num_classes=0) 321 | model = SIMCLR(vision_width=1024, vision_model=vision_model, **kwargs) 322 | 323 | return model 324 | 325 | 326 | def SLIP_VITL16(**kwargs): 327 | vision_model = timm.create_model('vit_large_patch16_224', num_classes=0) 328 | model = SLIP(embed_dim=512, vision_width=1024, vision_model=vision_model, context_length=77, vocab_size=49408, 329 | transformer_width=512, transformer_heads=8, transformer_layers=12, **kwargs) 330 | 331 | return model 332 | -------------------------------------------------------------------------------- /redcaps/combine_captions.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | 5 | 6 | def get_args_parser(): 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument('--input', default='', type=str, help='path to redcaps annotations directory') 9 | parser.add_argument('--output', default='', type=str, help='output annotations file path') 10 | return parser 11 | 12 | 13 | def main(args): 14 | annos = [] 15 | for fname in os.listdir(args.input): 16 | if fname.endswith('json'): 17 | with open(os.path.join(args.input, fname)) as f: 18 | a = json.load(f) 19 | annos.extend(a['annotations']) 20 | 21 | with open(args.output, 'w') as f: 22 | json.dump(annos, f) 23 | 24 | 25 | if __name__ == '__main__': 26 | parser = get_args_parser() 27 | args = parser.parse_args() 28 | main(args) -------------------------------------------------------------------------------- /run_with_submitit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """ 7 | A script to run multinode training with submitit. 8 | """ 9 | import argparse 10 | import os 11 | import uuid 12 | from pathlib import Path 13 | 14 | import main as main_slip 15 | import submitit 16 | 17 | 18 | def parse_args(): 19 | parser = main_slip.get_args_parser() 20 | parser = argparse.ArgumentParser("Submitit for SLIP pre-training", parents=[parser]) 21 | parser.add_argument("--ngpus", default=8, type=int, help="Number of gpus to request on each node") 22 | parser.add_argument("--nodes", default=8, type=int, help="Number of nodes to request") 23 | parser.add_argument("--timeout", default=2800, type=int, help="Duration of the job") 24 | parser.add_argument("--job_dir", default="", type=str, help="Job dir. Leave empty for automatic.") 25 | 26 | parser.add_argument("--partition", default="learnlab", type=str, help="Partition where to submit") 27 | parser.add_argument("--use_volta32", action='store_true', help="Big models? Use this") 28 | parser.add_argument('--comment', default="", type=str, 29 | help='Comment to pass to scheduler, e.g. priority message') 30 | return parser.parse_args() 31 | 32 | 33 | def get_shared_folder() -> Path: 34 | user = os.getenv("USER") 35 | if Path("/checkpoint/").is_dir(): 36 | p = Path(f"/checkpoint/{user}/experiments/slip") 37 | p.mkdir(exist_ok=True) 38 | return p 39 | raise RuntimeError("No shared folder available") 40 | 41 | 42 | def get_init_file(): 43 | # Init file must not exist, but it's parent dir must exist. 
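# Note: the returned path is converted to a file:// URI via get_init_file().as_uri() and
# passed around as args.dist_url, which is then used as the init_method for the
# torch.distributed process-group initialization (see utils.init_distributed_mode).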
44 | os.makedirs(str(get_shared_folder()), exist_ok=True) 45 | init_file = get_shared_folder() / f"{uuid.uuid4().hex}_init" 46 | if init_file.exists(): 47 | os.remove(str(init_file)) 48 | return init_file 49 | 50 | 51 | class Trainer(object): 52 | def __init__(self, args): 53 | self.args = args 54 | 55 | def __call__(self): 56 | import main as main_slip 57 | 58 | self._setup_gpu_args() 59 | main_slip.main(self.args) 60 | 61 | def checkpoint(self): 62 | import os 63 | import submitit 64 | 65 | self.args.dist_url = get_init_file().as_uri() 66 | print("Requeuing ", self.args) 67 | empty_trainer = type(self)(self.args) 68 | return submitit.helpers.DelayedSubmission(empty_trainer) 69 | 70 | def _setup_gpu_args(self): 71 | import submitit 72 | from pathlib import Path 73 | 74 | job_env = submitit.JobEnvironment() 75 | self.args.output_dir = Path(str(self.args.output_dir).replace("%j", str(job_env.job_id))) 76 | self.args.gpu = job_env.local_rank 77 | self.args.rank = job_env.global_rank 78 | self.args.world_size = job_env.num_tasks 79 | print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") 80 | 81 | 82 | def main(): 83 | args = parse_args() 84 | if args.job_dir == "": 85 | args.job_dir = get_shared_folder() / "%j" 86 | 87 | # Note that the folder will depend on the job_id, to easily track experiments 88 | executor = submitit.AutoExecutor(folder=args.job_dir, slurm_max_num_timeout=30) 89 | 90 | num_gpus_per_node = args.ngpus 91 | nodes = args.nodes 92 | timeout_min = args.timeout 93 | 94 | partition = args.partition 95 | kwargs = {} 96 | if args.use_volta32: 97 | kwargs['slurm_constraint'] = 'volta32gb' 98 | if args.comment: 99 | kwargs['slurm_comment'] = args.comment 100 | 101 | executor.update_parameters( 102 | mem_gb=40 * num_gpus_per_node, 103 | gpus_per_node=num_gpus_per_node, 104 | tasks_per_node=num_gpus_per_node, # one task per GPU 105 | cpus_per_task=10, 106 | nodes=nodes, 107 | timeout_min=timeout_min, # max is 60 * 72 108 | # Below are cluster dependent parameters 109 | slurm_partition=partition, 110 | slurm_signal_delay_s=120, 111 | **kwargs 112 | ) 113 | 114 | executor.update_parameters(name="slip") 115 | 116 | args.dist_url = get_init_file().as_uri() 117 | args.output_dir = args.job_dir 118 | 119 | trainer = Trainer(args) 120 | job = executor.submit(trainer) 121 | 122 | print("Submitted job_id:", job.job_id) 123 | 124 | 125 | if __name__ == "__main__": 126 | main() 127 | -------------------------------------------------------------------------------- /run_with_submitit_linear.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """ 7 | A script to run multinode training with submitit. 
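This variant submits main_linear.py, i.e. linear-probe evaluation of a completed pre-training run, rather than the pre-training entry point.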
8 | """ 9 | import argparse 10 | import os 11 | import uuid 12 | from pathlib import Path 13 | 14 | import main_linear 15 | import submitit 16 | 17 | 18 | def parse_args(): 19 | parser = main_linear.get_args_parser() 20 | parser = argparse.ArgumentParser("Submitit for Linear Probe Eval", parents=[parser]) 21 | parser.add_argument("--ngpus", default=8, type=int, help="Number of gpus to request on each node") 22 | parser.add_argument("--nodes", default=2, type=int, help="Number of nodes to request") 23 | parser.add_argument("--timeout", default=1800, type=int, help="Duration of the job") 24 | parser.add_argument("--job_dir", default="", type=str, help="Job dir of miniclip training run to eval") 25 | 26 | parser.add_argument("--partition", default="learnlab", type=str, help="Partition where to submit") 27 | parser.add_argument("--use_volta32", action='store_true', help="Big models? Use this") 28 | parser.add_argument('--comment', default="", type=str, 29 | help='Comment to pass to scheduler, e.g. priority message') 30 | return parser.parse_args() 31 | 32 | 33 | def get_shared_folder() -> Path: 34 | user = os.getenv("USER") 35 | if Path("/checkpoint/").is_dir(): 36 | p = Path(f"/checkpoint/{user}/experiments/slip") 37 | p.mkdir(exist_ok=True) 38 | return p 39 | raise RuntimeError("No shared folder available") 40 | 41 | 42 | def get_init_file(): 43 | # Init file must not exist, but it's parent dir must exist. 44 | os.makedirs(str(get_shared_folder()), exist_ok=True) 45 | init_file = get_shared_folder() / f"{uuid.uuid4().hex}_init" 46 | if init_file.exists(): 47 | os.remove(str(init_file)) 48 | return init_file 49 | 50 | 51 | class Trainer(object): 52 | def __init__(self, args): 53 | self.args = args 54 | 55 | def __call__(self): 56 | import main_linear 57 | 58 | self._setup_gpu_args() 59 | main_linear.main(self.args) 60 | 61 | def checkpoint(self): 62 | import os 63 | import submitit 64 | 65 | self.args.dist_url = get_init_file().as_uri() 66 | print("Requeuing ", self.args) 67 | empty_trainer = type(self)(self.args) 68 | return submitit.helpers.DelayedSubmission(empty_trainer) 69 | 70 | def _setup_gpu_args(self): 71 | import submitit 72 | from pathlib import Path 73 | 74 | job_env = submitit.JobEnvironment() 75 | self.args.output_dir = Path(str(self.args.output_dir).replace("%j", str(job_env.job_id))) 76 | self.args.gpu = job_env.local_rank 77 | self.args.rank = job_env.global_rank 78 | self.args.world_size = job_env.num_tasks 79 | print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") 80 | 81 | 82 | def main(): 83 | args = parse_args() 84 | if args.job_dir == "": 85 | args.job_dir = get_shared_folder() / "%j" 86 | 87 | # Note that the folder will depend on the job_id, to easily track experiments 88 | executor = submitit.AutoExecutor(folder=args.job_dir, slurm_max_num_timeout=30) 89 | 90 | num_gpus_per_node = args.ngpus 91 | nodes = args.nodes 92 | timeout_min = args.timeout 93 | 94 | partition = args.partition 95 | kwargs = {} 96 | if args.use_volta32: 97 | kwargs['slurm_constraint'] = 'volta32gb' 98 | if args.comment: 99 | kwargs['slurm_comment'] = args.comment 100 | 101 | executor.update_parameters( 102 | mem_gb=40 * num_gpus_per_node, 103 | gpus_per_node=num_gpus_per_node, 104 | tasks_per_node=num_gpus_per_node, # one task per GPU 105 | cpus_per_task=10, 106 | nodes=nodes, 107 | timeout_min=timeout_min, # max is 60 * 72 108 | # Below are cluster dependent parameters 109 | slurm_partition=partition, 110 | slurm_signal_delay_s=120, 111 | **kwargs 112 | ) 113 | 114 | 
executor.update_parameters(name="linear") 115 | 116 | args.dist_url = get_init_file().as_uri() 117 | args.output_dir = args.job_dir 118 | 119 | trainer = Trainer(args) 120 | job = executor.submit(trainer) 121 | 122 | print("Submitted job_id:", job.job_id) 123 | 124 | 125 | if __name__ == "__main__": 126 | main() 127 | -------------------------------------------------------------------------------- /slip.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/SLIP/c6faf5d03cbfa7d529d210779f859cd3dddec09a/slip.png -------------------------------------------------------------------------------- /templates.json: -------------------------------------------------------------------------------- 1 | { 2 | "food101": [ 3 | "a photo of {}, a type of food." 4 | ], 5 | "cifar10": [ 6 | "a photo of a {}.", 7 | "a blurry photo of a {}.", 8 | "a black and white photo of a {}.", 9 | "a low contrast photo of a {}.", 10 | "a high contrast photo of a {}.", 11 | "a bad photo of a {}.", 12 | "a good photo of a {}.", 13 | "a photo of a small {}.", 14 | "a photo of a big {}.", 15 | "a photo of the {}.", 16 | "a blurry photo of the {}.", 17 | "a black and white photo of the {}.", 18 | "a low contrast photo of the {}.", 19 | "a high contrast photo of the {}.", 20 | "a bad photo of the {}.", 21 | "a good photo of the {}.", 22 | "a photo of the small {}.", 23 | "a photo of the big {}." 24 | ], 25 | "cifar100": [ 26 | "a photo of a {}.", 27 | "a blurry photo of a {}.", 28 | "a black and white photo of a {}.", 29 | "a low contrast photo of a {}.", 30 | "a high contrast photo of a {}.", 31 | "a bad photo of a {}.", 32 | "a good photo of a {}.", 33 | "a photo of a small {}.", 34 | "a photo of a big {}.", 35 | "a photo of the {}.", 36 | "a blurry photo of the {}.", 37 | "a black and white photo of the {}.", 38 | "a low contrast photo of the {}.", 39 | "a high contrast photo of the {}.", 40 | "a bad photo of the {}.", 41 | "a good photo of the {}.", 42 | "a photo of the small {}.", 43 | "a photo of the big {}." 44 | ], 45 | "birdsnap": [ 46 | "a photo of a {}, a type of bird." 47 | ], 48 | "cub200": [ 49 | "a photo of a {}, a type of bird." 50 | ], 51 | "imagenet": [ 52 | "itap of a {}.", 53 | "a bad photo of the {}.", 54 | "a origami {}.", 55 | "a photo of the large {}.", 56 | "a {} in a video game.", 57 | "art of the {}.", 58 | "a photo of the small {}." 59 | ], 60 | "rendered_sst2": [ 61 | "a {} review of a movie." 62 | ], 63 | "hateful_memes": [ 64 | "a {}." 65 | ], 66 | "clevr_counts": [ 67 | "a photo of {} objects." 
68 | ], 69 | "kinetics700_frames": [ 70 | "a photo of {}.", 71 | "a photo of a person {}.", 72 | "a photo of a person using {}.", 73 | "a photo of a person doing {}.", 74 | "a photo of a person during {}.", 75 | "a photo of a person performing {}.", 76 | "a photo of a person practicing {}.", 77 | "a video of {}.", 78 | "a video of a person {}.", 79 | "a video of a person using {}.", 80 | "a video of a person doing {}.", 81 | "a video of a person during {}.", 82 | "a video of a person performing {}.", 83 | "a video of a person practicing {}.", 84 | "a example of {}.", 85 | "a example of a person {}.", 86 | "a example of a person using {}.", 87 | "a example of a person doing {}.", 88 | "a example of a person during {}.", 89 | "a example of a person performing {}.", 90 | "a example of a person practicing {}.", 91 | "a demonstration of {}.", 92 | "a demonstration of a person {}.", 93 | "a demonstration of a person using {}.", 94 | "a demonstration of a person doing {}.", 95 | "a demonstration of a person during {}.", 96 | "a demonstration of a person performing {}.", 97 | "a demonstration of a person practicing {}." 98 | ], 99 | "ucf101_frames": [ 100 | "a photo of a person {}.", 101 | "a video of a person {}.", 102 | "a example of a person {}.", 103 | "a demonstration of a person {}.", 104 | "a photo of the person {}.", 105 | "a video of the person {}.", 106 | "a example of the person {}.", 107 | "a demonstration of the person {}.", 108 | "a photo of a person using {}.", 109 | "a video of a person using {}.", 110 | "a example of a person using {}.", 111 | "a demonstration of a person using {}.", 112 | "a photo of the person using {}.", 113 | "a video of the person using {}.", 114 | "a example of the person using {}.", 115 | "a demonstration of the person using {}.", 116 | "a photo of a person doing {}.", 117 | "a video of a person doing {}.", 118 | "a example of a person doing {}.", 119 | "a demonstration of a person doing {}.", 120 | "a photo of the person doing {}.", 121 | "a video of the person doing {}.", 122 | "a example of the person doing {}.", 123 | "a demonstration of the person doing {}.", 124 | "a photo of a person during {}.", 125 | "a video of a person during {}.", 126 | "a example of a person during {}.", 127 | "a demonstration of a person during {}.", 128 | "a photo of the person during {}.", 129 | "a video of the person during {}.", 130 | "a example of the person during {}.", 131 | "a demonstration of the person during {}.", 132 | "a photo of a person performing {}.", 133 | "a video of a person performing {}.", 134 | "a example of a person performing {}.", 135 | "a demonstration of a person performing {}.", 136 | "a photo of the person performing {}.", 137 | "a video of the person performing {}.", 138 | "a example of the person performing {}.", 139 | "a demonstration of the person performing {}.", 140 | "a photo of a person practicing {}.", 141 | "a video of a person practicing {}.", 142 | "a example of a person practicing {}.", 143 | "a demonstration of a person practicing {}.", 144 | "a photo of the person practicing {}.", 145 | "a video of the person practicing {}.", 146 | "a example of the person practicing {}.", 147 | "a demonstration of the person practicing {}." 148 | ], 149 | "patch_camelyon": [ 150 | "this is a photo of {}" 151 | ], 152 | "country211": [ 153 | "a photo i took in {}.", 154 | "a photo i took while visiting {}.", 155 | "a photo from my home country of {}.", 156 | "a photo from my visit to {}.", 157 | "a photo showing the country of {}." 
158 | ], 159 | "kitti_distance": [ 160 | "{}" 161 | ], 162 | "gtsrb": [ 163 | "a zoomed in photo of a \"{}\" traffic sign.", 164 | "a centered photo of a \"{}\" traffic sign.", 165 | "a close up photo of a \"{}\" traffic sign." 166 | ], 167 | "resisc45": [ 168 | "satellite imagery of {}.", 169 | "aerial imagery of {}.", 170 | "satellite photo of {}.", 171 | "aerial photo of {}.", 172 | "satellite view of {}.", 173 | "aerial view of {}.", 174 | "satellite imagery of a {}.", 175 | "aerial imagery of a {}.", 176 | "satellite photo of a {}.", 177 | "aerial photo of a {}.", 178 | "satellite view of a {}.", 179 | "aerial view of a {}.", 180 | "satellite imagery of the {}.", 181 | "aerial imagery of the {}.", 182 | "satellite photo of the {}.", 183 | "aerial photo of the {}.", 184 | "satellite view of the {}.", 185 | "aerial view of the {}." 186 | ], 187 | "eurosat": [ 188 | "a centered satellite photo of {}.", 189 | "a centered satellite photo of a {}.", 190 | "a centered satellite photo of the {}." 191 | ], 192 | "stl10": [ 193 | "a photo of a {}.", 194 | "a photo of the {}." 195 | ], 196 | "fer2013": [ 197 | "a photo of a {} looking face.", 198 | "a photo of a face showing the emotion: {}.", 199 | "a photo of a face looking {}.", 200 | "a face that looks {}.", 201 | "they look {}.", 202 | "look at how {} they are." 203 | ], 204 | "mnist": [ 205 | "a photo of the number: \"{}\"." 206 | ], 207 | "flowers": [ 208 | "a photo of a {}, a type of flower." 209 | ], 210 | "caltech101": [ 211 | "a photo of a {}.", 212 | "a painting of a {}.", 213 | "a plastic {}.", 214 | "a sculpture of a {}.", 215 | "a sketch of a {}.", 216 | "a tattoo of a {}.", 217 | "a toy {}.", 218 | "a rendition of a {}.", 219 | "a embroidered {}.", 220 | "a cartoon {}.", 221 | "a {} in a video game.", 222 | "a plushie {}.", 223 | "a origami {}.", 224 | "art of a {}.", 225 | "graffiti of a {}.", 226 | "a drawing of a {}.", 227 | "a doodle of a {}.", 228 | "a photo of the {}.", 229 | "a painting of the {}.", 230 | "the plastic {}.", 231 | "a sculpture of the {}.", 232 | "a sketch of the {}.", 233 | "a tattoo of the {}.", 234 | "the toy {}.", 235 | "a rendition of the {}.", 236 | "the embroidered {}.", 237 | "the cartoon {}.", 238 | "the {} in a video game.", 239 | "the plushie {}.", 240 | "the origami {}.", 241 | "art of the {}.", 242 | "graffiti of the {}.", 243 | "a drawing of the {}.", 244 | "a doodle of the {}." 245 | ], 246 | "pets": [ 247 | "a photo of a {}, a type of pet." 248 | ], 249 | "dtd": [ 250 | "a photo of a {} texture.", 251 | "a photo of a {} pattern.", 252 | "a photo of a {} thing.", 253 | "a photo of a {} object.", 254 | "a photo of the {} texture.", 255 | "a photo of the {} pattern.", 256 | "a photo of the {} thing.", 257 | "a photo of the {} object." 258 | ], 259 | "voc2007": [ 260 | "a photo of a {}." 261 | ], 262 | "aircraft": [ 263 | "a photo of a {}, a type of aircraft.", 264 | "a photo of the {}, a type of aircraft." 265 | ], 266 | "cars": [ 267 | "a photo of a {}.", 268 | "a photo of the {}.", 269 | "a photo of my {}.", 270 | "i love my {}!", 271 | "a photo of my dirty {}.", 272 | "a photo of my clean {}.", 273 | "a photo of my new {}.", 274 | "a photo of my old {}." 275 | ], 276 | "sun397": [ 277 | "a photo of a {}.", 278 | "a photo of the {}." 279 | ] 280 | } -------------------------------------------------------------------------------- /tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # Modified from github.com/openai/CLIP 8 | import gzip 9 | import html 10 | import os 11 | from functools import lru_cache 12 | 13 | import ftfy 14 | import regex as re 15 | import torch 16 | 17 | 18 | @lru_cache() 19 | def default_bpe(): 20 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 21 | 22 | 23 | @lru_cache() 24 | def bytes_to_unicode(): 25 | """ 26 | Returns a list of utf-8 bytes and a corresponding list of unicode strings. 27 | The reversible bpe codes work on unicode strings. 28 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 29 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 30 | This is a significant percentage of your normal, say, 32K bpe vocab. 31 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 32 | This also avoids mapping to whitespace/control characters that the bpe code barfs on. 33 | """ 34 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 35 | cs = bs[:] 36 | n = 0 37 | for b in range(2**8): 38 | if b not in bs: 39 | bs.append(b) 40 | cs.append(2**8+n) 41 | n += 1 42 | cs = [chr(n) for n in cs] 43 | return dict(zip(bs, cs)) 44 | 45 | 46 | def get_pairs(word): 47 | """Return set of symbol pairs in a word. 48 | Word is represented as tuple of symbols (symbols being variable-length strings). 49 | """ 50 | pairs = set() 51 | prev_char = word[0] 52 | for char in word[1:]: 53 | pairs.add((prev_char, char)) 54 | prev_char = char 55 | return pairs 56 | 57 | 58 | def basic_clean(text): 59 | text = ftfy.fix_text(text) 60 | text = html.unescape(html.unescape(text)) 61 | return text.strip() 62 | 63 | 64 | def whitespace_clean(text): 65 | text = re.sub(r'\s+', ' ', text) 66 | text = text.strip() 67 | return text 68 | 69 | 70 | class SimpleTokenizer(object): 71 | def __init__(self, bpe_path: str = default_bpe()): 72 | self.byte_encoder = bytes_to_unicode() 73 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 74 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 75 | merges = merges[1:49152-256-2+1] 76 | merges = [tuple(merge.split()) for merge in merges] 77 | vocab = list(bytes_to_unicode().values()) 78 | vocab = vocab + [v+'</w>' for v in vocab] 79 | for merge in merges: 80 | vocab.append(''.join(merge)) 81 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 82 | self.encoder = dict(zip(vocab, range(len(vocab)))) 83 | self.decoder = {v: k for k, v in self.encoder.items()} 84 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 85 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 86 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 87 | 88 | def bpe(self, token): 89 | if token in self.cache: 90 | return self.cache[token] 91 | word = tuple(token[:-1]) + ( token[-1] + '</w>',) 92 | pairs = get_pairs(word) 93 | 94 | if not pairs: 95 | return token+'</w>' 96 | 97 | while True: 98 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 99 | if bigram not in self.bpe_ranks: 100 | break 101 | first, second = bigram 102 | new_word = [] 103 | i = 0 104 | while i < len(word): 105 | try: 106 | j = word.index(first, i) 107 | 
new_word.extend(word[i:j]) 108 | i = j 109 | except ValueError: 110 | new_word.extend(word[i:]) 111 | break 112 | 113 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 114 | new_word.append(first+second) 115 | i += 2 116 | else: 117 | new_word.append(word[i]) 118 | i += 1 119 | new_word = tuple(new_word) 120 | word = new_word 121 | if len(word) == 1: 122 | break 123 | else: 124 | pairs = get_pairs(word) 125 | word = ' '.join(word) 126 | self.cache[token] = word 127 | return word 128 | 129 | def encode(self, text): 130 | bpe_tokens = [] 131 | text = whitespace_clean(basic_clean(text)).lower() 132 | for token in re.findall(self.pat, text): 133 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 134 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 135 | return bpe_tokens 136 | 137 | def decode(self, tokens): 138 | text = ''.join([self.decoder[token] for token in tokens]) 139 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ') 140 | return text 141 | 142 | def __call__(self, texts, context_length=77): 143 | if isinstance(texts, str): 144 | texts = [texts] 145 | 146 | sot_token = self.encoder["<|startoftext|>"] 147 | eot_token = self.encoder["<|endoftext|>"] 148 | all_tokens = [[sot_token] + self.encode(text) + [eot_token] for text in texts] 149 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 150 | 151 | for i, tokens in enumerate(all_tokens): 152 | tokens = tokens[:context_length] 153 | result[i, :len(tokens)] = torch.tensor(tokens) 154 | 155 | if len(result) == 1: 156 | return result[0] 157 | return result -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
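# Shared helpers: distributed setup and rank/world-size queries, scaled all-reduce,
# batch all-gather (with and without gradient flow), checkpointing on the main process,
# a cosine learning-rate schedule with linear warmup, and the GaussianBlur augmentation
# used for SimCLR-style augmented views.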
6 | import numpy as np 7 | import os 8 | import random 9 | import shutil 10 | import torch 11 | import torch.distributed as dist 12 | import torch.autograd as autograd 13 | 14 | from PIL import ImageFilter 15 | 16 | 17 | def get_model(model): 18 | if isinstance(model, torch.nn.DataParallel) \ 19 | or isinstance(model, torch.nn.parallel.DistributedDataParallel): 20 | return model.module 21 | else: 22 | return model 23 | 24 | 25 | def setup_for_distributed(is_master): 26 | """ 27 | This function disables printing when not in master process 28 | """ 29 | import builtins as __builtin__ 30 | builtin_print = __builtin__.print 31 | 32 | def print(*args, **kwargs): 33 | force = kwargs.pop('force', False) 34 | if is_master or force: 35 | builtin_print(*args, **kwargs) 36 | 37 | __builtin__.print = print 38 | 39 | 40 | def is_dist_avail_and_initialized(): 41 | if not dist.is_available(): 42 | return False 43 | if not dist.is_initialized(): 44 | return False 45 | return True 46 | 47 | 48 | def get_world_size(): 49 | if not is_dist_avail_and_initialized(): 50 | return 1 51 | return dist.get_world_size() 52 | 53 | 54 | def get_rank(): 55 | if not is_dist_avail_and_initialized(): 56 | return 0 57 | return dist.get_rank() 58 | 59 | 60 | def is_main_process(): 61 | return get_rank() == 0 62 | 63 | 64 | def save_on_master(state, is_best, output_dir): 65 | if is_main_process(): 66 | ckpt_path = f'{output_dir}/checkpoint.pt' 67 | best_path = f'{output_dir}/checkpoint_best.pt' 68 | torch.save(state, ckpt_path) 69 | if is_best: 70 | shutil.copyfile(ckpt_path, best_path) 71 | 72 | 73 | def init_distributed_mode(args): 74 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 75 | args.rank = int(os.environ["RANK"]) 76 | args.world_size = int(os.environ['WORLD_SIZE']) 77 | args.gpu = int(os.environ['LOCAL_RANK']) 78 | elif 'SLURM_PROCID' in os.environ: 79 | args.rank = int(os.environ['SLURM_PROCID']) 80 | args.gpu = args.rank % torch.cuda.device_count() 81 | else: 82 | print('Not using distributed mode') 83 | args.distributed = False 84 | return 85 | 86 | args.distributed = True 87 | 88 | torch.cuda.set_device(args.gpu) 89 | args.dist_backend = 'nccl' 90 | print('| distributed init (rank {}): {}'.format( 91 | args.rank, args.dist_url), flush=True) 92 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 93 | world_size=args.world_size, rank=args.rank) 94 | torch.distributed.barrier() 95 | setup_for_distributed(args.rank == 0) 96 | 97 | 98 | def scaled_all_reduce(tensors, is_scale=True): 99 | """Performs the scaled all_reduce operation on the provided tensors. 100 | The input tensors are modified in-place. Currently supports only the sum 101 | reduction operator. The reduced values are scaled by the inverse size of the 102 | world size. 103 | """ 104 | world_size = get_world_size() 105 | # There is no need for reduction in the single-proc case 106 | if world_size == 1: 107 | return tensors 108 | # Queue the reductions 109 | reductions = [] 110 | for tensor in tensors: 111 | reduction = dist.all_reduce(tensor, async_op=True) 112 | reductions.append(reduction) 113 | # Wait for reductions to finish 114 | for reduction in reductions: 115 | reduction.wait() 116 | # Scale the results 117 | if is_scale: 118 | for tensor in tensors: 119 | tensor.mul_(1.0 / world_size) 120 | return tensors 121 | 122 | 123 | def all_gather_batch(tensors): 124 | """ 125 | Performs all_gather operation on the provided tensors. 
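Gradients do not flow back through the gathered copies (plain dist.all_gather cuts the graph); use all_gather_batch_with_grad below when the gathered tensors must stay in the autograd graph.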
126 | """ 127 | # Queue the gathered tensors 128 | world_size = get_world_size() 129 | # There is no need for reduction in the single-proc case 130 | if world_size == 1: 131 | return tensors 132 | tensor_list = [] 133 | output_tensor = [] 134 | for tensor in tensors: 135 | tensor_all = [torch.ones_like(tensor) for _ in range(world_size)] 136 | dist.all_gather( 137 | tensor_all, 138 | tensor, 139 | async_op=False # performance opt 140 | ) 141 | 142 | tensor_list.append(tensor_all) 143 | 144 | for tensor_all in tensor_list: 145 | output_tensor.append(torch.cat(tensor_all, dim=0)) 146 | return output_tensor 147 | 148 | 149 | class GatherLayer(autograd.Function): 150 | """ 151 | Gather tensors from all workers with support for backward propagation: 152 | This implementation does not cut the gradients as torch.distributed.all_gather does. 153 | """ 154 | 155 | @staticmethod 156 | def forward(ctx, x): 157 | output = [torch.zeros_like(x) for _ in range(dist.get_world_size())] 158 | dist.all_gather(output, x) 159 | return tuple(output) 160 | 161 | @staticmethod 162 | def backward(ctx, *grads): 163 | all_gradients = torch.stack(grads) 164 | dist.all_reduce(all_gradients) 165 | return all_gradients[dist.get_rank()] 166 | 167 | 168 | def all_gather_batch_with_grad(tensors): 169 | """ 170 | Performs all_gather operation on the provided tensors. 171 | Graph remains connected for backward grad computation. 172 | """ 173 | # Queue the gathered tensors 174 | world_size = get_world_size() 175 | # There is no need for reduction in the single-proc case 176 | if world_size == 1: 177 | return tensors 178 | tensor_list = [] 179 | output_tensor = [] 180 | 181 | for tensor in tensors: 182 | tensor_all = GatherLayer.apply(tensor) 183 | tensor_list.append(tensor_all) 184 | 185 | for tensor_all in tensor_list: 186 | output_tensor.append(torch.cat(tensor_all, dim=0)) 187 | return output_tensor 188 | 189 | 190 | def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0): 191 | warmup_schedule = np.array([]) 192 | warmup_iters = warmup_epochs * niter_per_ep 193 | if warmup_epochs > 0: 194 | warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters) 195 | 196 | iters = np.arange(epochs * niter_per_ep - warmup_iters) 197 | schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters))) 198 | 199 | schedule = np.concatenate((warmup_schedule, schedule)) 200 | assert len(schedule) == epochs * niter_per_ep 201 | return schedule 202 | 203 | 204 | class GaussianBlur(object): 205 | """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709""" 206 | 207 | def __init__(self, sigma=[.1, 2.]): 208 | self.sigma = sigma 209 | 210 | def __call__(self, x): 211 | sigma = random.uniform(self.sigma[0], self.sigma[1]) 212 | x = x.filter(ImageFilter.GaussianBlur(radius=sigma)) 213 | return x 214 | --------------------------------------------------------------------------------