├── .gitignore ├── CODE_OF_CONDUCT.md ├── LICENSE ├── README.md ├── SECURITY.md ├── SUPPORT.md ├── bpe_simple_vocab_16e6.txt.gz ├── dataset_catalog.json ├── datasets.py ├── docs ├── aclip.png ├── attMask.png ├── masking_Page_1.png └── vis.png ├── eval_zeroshot.py ├── losses.py ├── main.py ├── main_linear.py ├── make_dataset.py ├── models.py ├── templates.json ├── tokenizer.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Microsoft Open Source Code of Conduct 2 | 3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 4 | 5 | Resources: 6 | 7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) 8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # A-CLIP 2 | 3 | 4 | ## Introduction 5 | This is the official implementation of Attentive Mask CLIP ([A-CLIP](https://arxiv.org/abs/2212.08653v2), ICCV2023). 
A-CLIP aims to improve the efficiency of CLIP training by introducing an efficient image augmentation approach called image token removal. 6 | 7 | The purpose of this work is to improve the training efficiency of CLIP by using the image token removal strategy. This method has proven effective in tasks such as masked image modeling. However, random masking in CLIP may lead to semantic mismatch issues that could affect performance. To address this concern, we propose an attentive masking strategy that retains the areas most relevant to the text while removing the remaining tokens. 8 | 9 | Visualization 10 | 11 | For example, for the Ferrari image in the top-left corner, random masking makes it difficult to identify the car and align it correctly with the text. With our attentive selection, only irrelevant areas are removed, preserving as much of the semantics as possible and avoiding ambiguity. 12 | 13 | 14 | | Mask Strategy (View × Ratio) | IN 1K 0-shot | Flickr30K I2T | Flickr30K T2I | MS COCO I2T | MS COCO T2I | 15 | |------------------------------|--------------|---------------|---------------|-------------|-------------| 16 | | w/o mask | | | | | | 17 | | 1×100% | 37.6 | 51.4 | 32.6 | 27.9 | 17.6 | 18 | | +random mask | | | | | | 19 | | 1×50% | 35.0 | 48.8 | 32.5 | 28.9 | 16.6 | 20 | | 2×50% | 38.0 | 54.6 | 34.4 | 31.1 | 18.7 | 21 | | +attentive mask | | | | | | 22 | | 1×50% | 39.5 | 57.6 | 36.6 | 34.2 | 19.8 | 23 | | 2×50% | **41.3** | **59.3** | **38.4** | **35.1** | **21.3** | 24 | 25 | Experiments show that attentive masking avoids the semantic mismatch problems caused by random masking and brings the following benefits: 26 | 1. **Training efficiency**: Eliminates redundant tokens and improves CLIP's training efficiency. 27 | 2. **Feature stability**: Retains semantically relevant areas, enhances feature stability, and alleviates the ambiguity caused by semantic mismatch. Together, these allow A-CLIP to significantly improve both the training efficiency and the performance of CLIP. 28 | 29 | We compare our attentive mask CLIP (A-CLIP) with CLIP, SLIP, and MaskCLIP. A-CLIP outperforms CLIP by +6.3% on ImageNet-1K zero-shot classification and by +11.3/+9.5 and +10.1/+5.6 on Flickr30K and MS COCO I2T/T2I retrieval. An efficient variant termed A-CLIP-eff outperforms CLIP by +5.3%, +11.3/+8.0, and +9.5/+4.9 on these benchmarks, while reducing the training time to 0.86×. 30 | 31 | | Methods | Training Time | GPU Memory | IN 1K 0-shot | Flickr30K I2T/T2I | MS COCO I2T/T2I | 32 | |---------|---------------|------------|--------------|------------------|-----------------| 33 | | CLIP | 1.00× | 14G | 37.6 | 51.4/32.6 | 27.9/17.6 | 34 | | SLIP | 2.67× | 30G | 42.8 | 57.2/41.2 | 33.6/21.9 | 35 | | MaskCLIP| 1.56× | 16G | 42.7 | 60.0/38.8 | 34.1/21.2 | 36 | | A-CLIP | 1.16× | 14G | **43.9** | **62.7/42.1** | **38.0/23.2** | 37 | | A-CLIP-eff | 0.86× | 13G | 42.9 | **62.7**/40.6 | 37.4/22.5 | 38 | 39 | > Note: The full training wall-clock time and GPU memory footprint are measured on the same device. We report the training cost relative to the original CLIP. 40 | 41 | The table below reports zero-shot evaluation on a variety of classification benchmarks; the **Epochs** column indicates the number of training epochs. A-CLIP significantly outperforms the other methods at every epoch setting, both in average accuracy and in the number of tasks won across the more than 25 downstream benchmarks.
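Before the detailed per-task numbers, here is a minimal sketch of the attentive token selection described above, assuming patch tokens are scored by the [CLS] attention of the EMA vision encoder. This is an illustration only, not the implementation in `models.py`; the function and tensor names are made up for this example.

```
import torch

def attentive_token_selection(patch_tokens, cls_attention, keep_ratio=0.5):
    """Keep the patch tokens that the [CLS] token attends to most.

    patch_tokens:  (B, N, D) patch embeddings of one image view
    cls_attention: (B, N) attention of [CLS] over the N patches, e.g. averaged
                   over heads/layers of the EMA (momentum) vision encoder
    """
    B, N, D = patch_tokens.shape
    num_keep = max(1, int(N * keep_ratio))
    keep_idx = cls_attention.topk(num_keep, dim=1).indices              # most text-relevant patches
    batch_idx = torch.arange(B, device=patch_tokens.device).unsqueeze(1)
    return patch_tokens[batch_idx, keep_idx]                            # (B, num_keep, D)

# toy usage: 196 patches (14x14 for ViT-B/16 at 224x224), keep 50% of them
tokens = torch.randn(4, 196, 768)
scores = torch.rand(4, 196)
kept = attentive_token_selection(tokens, scores)                        # (4, 98, 768)
```

Because only about half of the image tokens are passed to the online visual encoder, each masked view costs roughly half the compute of a full view, which is where the training-efficiency gain comes from.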
42 | 43 | | Epochs | Methods | Food-101 | CIFAR-10 | CIFAR-100 | CUB | SUN397 | Cars | Aircraft | DTD | Pets | Caltech-101 | Flowers | MNIST | FER-2013 | STL-10 | EuroSAT | RESISC45 | GTSRB | KITTI | Country211 | PCAM | UCF101 | Kinetics700 | CLEVR | HatefulMemes | SST2 | ImageNet | Average | 44 | |:------:|:--------:|:--------:|:--------:|:---------:|:-----:|:------:|:-----:|:--------:|:-----:|:-----:|:-----------:|:-------:|:-----:|:-------:|:------:|:-------:|:-------:|:-----:|:-----:|:---------:|:----:|:------:|:----------:|:-----:|:-----------:|:----:|:--:|:-------:| 45 | | 25 | CLIP | 50\.6 | 66\.0 | 34\.5 | 38\.8 | 51\.1 | 4\.0 | 5\.4 | 21\.2 | 28\.5 | 60\.9 | 53\.3 | 8\.4 | 17\.3 | 90\.5 | 30\.2 | 21\.5 | 6\.1 | 35\.1 | 10\.5 | 53\.5 | 28\.5 | 22\.1 | 10\.8 | 52\.4 | 50\.7 | 37\.6 | 34\.2 | 46 | | | SLIP | 59\.5 | 78\.6 | 45\.2 | 38\.7 | 53\.4 | **5\.4** | 5\.7 | **26\.1** | 31\.1 | 71\.0 | 56\.6 | 9\.8 | 19\.6 | 94\.4 | 20\.3 | 28\.9 | **14\.5** | 34\.0 | 11\.6 | **55\.4** | 37\.7 | 26\.9 | **17\.5** | **52\.8** | **51\.1** | 42\.8 | 38\.0 | 47 | | | MaskCLIP | **60\.6** | 70\.1 | 41\.6 | **43\.3** | 54\.0 | 4\.9 | **8\.2** | 25\.5 | **36\.8** | 68\.9 | 53\.6 | **11\.2** | 22\.4 | 93\.9 | **35\.1** | 24\.8 | 10\.1 | 30\.5 | 12\.5 | 51\.2 | 37\.0 | 28\.1 | 12\.9 | **52\.8** | 50\.0 | 42\.7 | 37\.8 | 48 | | | A-CLIP | 58\.3 | **82\.8** | **51\.0** | 43\.0 | **57\.0** | **5\.4** | 7\.6 | 26\.0 | 32\.0 | **71\.6** | **57\.7** | 9\.8 | **29\.7** | **95\.4** | 29\.3 | **30\.3** | 13\.1 | **35\.2** | **13\.5** | 51\.6 | **38\.3** | **29\.6** | 14\.1 | **52\.8** | 49\.9 | **43\.9** | **39\.6** | 49 | | 50 | CLIP | 55\.2 | 77\.0 | 43\.8 | 38\.9 | 49\.0 | 4\.7 | 6\.3 | 23\.5 | 27\.2 | 63\.5 | 56\.2 | **12\.5** | 30\.2 | 92\.1 | 21\.0 | 31\.9 | 7\.4 | 33\.6 | 10\.9 | 50\.8 | 35\.5 | 24\.8 | 14\.0 | 49\.9 | 50\.1 | 39\.4 | 36\.5 | 50 | | | SLIP | 61\.9 | 76\.8 | 48\.9 | 39\.2 | 54\.8 | 7\.3 | 9\.0 | **29\.8** | 31\.9 | **75\.0** | 57\.7 | 9\.8 | 24\.9 | **95\.6** | 37\.8 | 32\.5 | 9\.0 | **35\.1** | 12\.7 | 54\.4 | 41\.1 | 30\.3 | 13\.8 | 49\.5 | 49\.9 | 44\.1 | 39\.7 | 51 | | | A-CLIP | **62\.2** | **81\.5** | **53\.7** | **48\.2** | **58\.7** | **8\.3** | **10\.2** | 27\.7 | **40\.5** | 73\.3 | **61\.0** | 11\.3 | **32\.9** | 95\.5 | **39\.7** | **37\.5** | **9\.4** | 23\.3 | **14\.4** | **63\.7** | **42\.5** | **31\.6** | **19\.6** | **50\.8** | **52\.3** | **46\.3** | **42\.2** | 52 | | 100 | CLIP | 60\.4 | 79\.4 | 44\.6 | 43\.3 | 53\.0 | 8\.5 | 8\.2 | 26\.2 | 34\.7 | 68\.9 | 59\.2 | 11\.4 | 20\.4 | 93\.2 | 23\.3 | 27\.3 | 10\.3 | 23\.1 | 12\.0 | 54\.0 | 36\.7 | 27\.7 | 13\.0 | 50\.9 | 50\.1 | 42\.7 | 37\.8 | 53 | | | SLIP | 63\.0 | 83\.1 | 50\.4 | 43\.0 | 52\.0 | 8\.3 | 8\.3 | 26\.2 | 34\.0 | 74\.6 | 61\.1 | **16\.1** | 32\.4 | 95\.1 | 22\.6 | 28\.5 | 10\.5 | **34\.8** | 11\.5 | 52\.1 | 37\.3 | 28\.3 | 13\.7 | **55\.2** | 49\.9 | 45\.0 | 39\.9 | 54 | | | A-CLIP | **66\.7** | **86\.6** | **58\.6** | **51\.4** | **58\.6** | **10\.5** | **11\.9** | **33\.1** | **48\.5** | **74\.9** | **64\.3** | 7\.8 | **31\.2** | **96\.7** | **35\.6** | **35\.8** | **12\.9** | 30\.5 | **15\.7** | **57\.1** | **44\.1** | **33\.1** | **22\.9** | 52\.7 | **50\.7** | **48\.1** | **43\.8** | 55 | 56 | The table shows the results of using longer training schedulers and a larger model size. 
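The zero-shot numbers throughout these tables follow the standard CLIP prompt-ensembling protocol implemented in `eval_zeroshot.py`. The condensed sketch below restates that protocol for a generic CLIP-style `model` exposing `encode_image`/`encode_text`; in this repository the encoders additionally accept an `ema=True` flag, which is omitted here for brevity.

```
import torch
import torch.nn.functional as F

@torch.no_grad()
def zeroshot_classifier(model, tokenizer, classnames, templates):
    """Build one text embedding per class by prompt ensembling."""
    weights = []
    for name in classnames:
        texts = tokenizer([t.format(name) for t in templates]).cuda()
        emb = model.encode_text(texts)                  # (num_templates, D)
        emb = F.normalize(emb, dim=-1).mean(dim=0)      # average the prompt ensemble
        weights.append(F.normalize(emb, dim=-1))        # re-normalize the mean
    return torch.stack(weights, dim=0)                  # (num_classes, D)

@torch.no_grad()
def zeroshot_predict(model, images, text_features):
    image_features = F.normalize(model.encode_image(images.cuda()), dim=-1)
    logits = image_features @ text_features.t()         # cosine similarity
    return logits.argmax(dim=1)                         # predicted class indices
```

For a few datasets the evaluation reports mean per-class accuracy or ROC AUC instead of top-1 accuracy (e.g. aircraft, pets, caltech101, flowers, hateful_memes), as handled in `eval_zeroshot.py`.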
57 | 58 | | Methods | IN 1K - 0-shot | Flickr30K - I2T | Flickr30K - T2I | MS COCO - I2T | MS COCO - T2I | 59 | |----------|----------------|----------------|----------------|---------------|---------------| 60 | | CLIP(25ep) | 37.6 | 51.4 | 32.6 | 27.9 | 17.6 | 61 | | SLIP(25ep) | 42.8 | 57.2 | 41.2 | 33.6 | 21.9 | 62 | | A-CLIP(25ep) | **43.9** | **62.7** | **42.1** | **38.0** | **23.2** | 63 | | CLIP(50ep) | 39.4 | 53.9 | 35.8 | 30.2 | 19.2 | 64 | | SLIP(50ep) | 44.1 | 60.6 | 41.1 | 33.2 | 22.3 | 65 | | A-CLIP(50ep) | **46.3** | **66.7** | **43.2** | **39.8** | **24.4** | 66 | | CLIP(100ep) | 42.7 | 61.0 | 37.9 | 34.4 | 20.9 | 67 | | SLIP(100ep) | 45.0 | 59.3 | 41.4 | 34.6 | 22.7 | 68 | | A-CLIP(100ep) | **48.0** | **66.3** | **45.7** | **40.7** | **25.1** | 69 | | CLIP(VIT-L) | 40.4 | 51.4 | 35.2 | 28.9 | 18.5 | 70 | | SLIP(VIT-L) | 46.2 | 60.6 | 43.7 | 35.3 | 23.5 | 71 | | A-CLIP(VIT-L) | **48.9** | **64.1** | **48.2** | **39.1** | **26.9** | 72 | 73 | 74 | ## Setup 75 | Install [PyTorch](https://pytorch.org) and [timm](https://github.com/rwightman/pytorch-image-models). 76 | The code has been tested with CUDA 11.6, PyTorch 1.13.0 and timm 0.5.4. 77 | 78 | ### YFCC15M Setup 79 | For data preparation, refer to [SLIP](https://github.com/facebookresearch/SLIP/tree/main). 80 | 81 | ## Pre-training 82 | ### A-CLIP ViT-Base with 8 nodes (total batch size 4096) 83 | ``` 84 | python -m torch.distributed.launch --nproc_per_node=8 --nnodes=8 main.py \ 85 | --root /path/to/yfcc100m --dataset yfcc15m --metadata /path/to/yfcc15m.pkl \ 86 | --model ACLIP_VITB16 --batch-size 64 \ 87 | --lr 5e-4 --wd 0.5 88 | ``` 89 | 90 | ## Visualization 91 | Below are some examples showing how the attentive mask preserves the content described by the text while filtering out redundant background. 92 | 93 | Visualization 94 | 95 | ## Citation 96 | If the code and paper help your research, please cite: 97 | ``` 98 | @InProceedings{Yang_2023_ICCV, author = {Yang, Yifan and Huang, Weiquan and Wei, Yixuan and Peng, Houwen and Jiang, Xinyang and Jiang, Huiqiang and Wei, Fangyun and Wang, Yin and Hu, Han and Qiu, Lili and Yang, Yuqing}, title = {Attentive Mask CLIP}, booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)}, month = {October}, year = {2023}, pages = {2771-2781} } 99 | ``` -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/opensource/security/definition), please report it to us as described below. 8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/opensource/security/create-report).
14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/opensource/security/pgpkey). 16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd). 40 | 41 | 42 | -------------------------------------------------------------------------------- /SUPPORT.md: -------------------------------------------------------------------------------- 1 | # TODO: The maintainer of this repo has not yet edited this file 2 | 3 | **REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project? 4 | 5 | - **No CSS support:** Fill out this template with information about how to file issues and get help. 6 | - **Yes CSS support:** Fill out an intake form at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). CSS will work with/help you to determine next steps. 7 | - **Not sure?** Fill out an intake as though the answer were "Yes". CSS will help you decide. 8 | 9 | *Then remove this first heading from this SUPPORT.MD file before publishing your repo.* 10 | 11 | # Support 12 | 13 | ## How to file issues and get help 14 | 15 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing 16 | issues before filing new issues to avoid duplicates. For new issues, file your bug or 17 | feature request as a new Issue. 18 | 19 | For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE 20 | FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER 21 | CHANNEL. WHERE WILL YOU HELP PEOPLE?**. 22 | 23 | ## Microsoft Support Policy 24 | 25 | Support for this **PROJECT or PRODUCT** is limited to the resources listed above. 
26 | -------------------------------------------------------------------------------- /bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/A-CLIP/0e51942fb677f6d26accfb1d1d19f56219563f0e/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /dataset_catalog.json: -------------------------------------------------------------------------------- 1 | { 2 | "food101": { 3 | "path": "/path/to/dataset", 4 | "type": "imagefolder", 5 | "train": "train", 6 | "test": "test" 7 | }, 8 | "cifar10": { 9 | "path": "./data/cifar10", 10 | "type": "special" 11 | }, 12 | "cifar100": { 13 | "path": "./data/cifar100", 14 | "type": "special" 15 | }, 16 | "cub200": { 17 | "path": "/path/to/dataset", 18 | "type": "imagefolder", 19 | "train": "train", 20 | "test": "val" 21 | }, 22 | "sun397": { 23 | "path": "/path/to/dataset", 24 | "type": "filelist", 25 | "train": "train", 26 | "test": "test" 27 | }, 28 | "cars": { 29 | "path": "/path/to/dataset", 30 | "type": "imagefolder", 31 | "train": "train", 32 | "test": "test" 33 | }, 34 | "aircraft": { 35 | "path": "/path/to/dataset", 36 | "type": "imagefolder", 37 | "train": "train", 38 | "test": "test" 39 | }, 40 | "dtd": { 41 | "path": "/path/to/dataset", 42 | "type": "imagefolder", 43 | "train": "train", 44 | "test": "test" 45 | }, 46 | "pets": { 47 | "path": "/path/to/dataset", 48 | "type": "imagefolder", 49 | "train": "train", 50 | "test": "test" 51 | }, 52 | "caltech101": { 53 | "path": "/path/to/dataset", 54 | "type": "imagefolder", 55 | "train": "train", 56 | "test": "test" 57 | }, 58 | "flowers": { 59 | "path": "/path/to/dataset", 60 | "type": "imagefolder", 61 | "train": "train", 62 | "test": "test" 63 | }, 64 | "mnist": { 65 | "path": "./data/mnist", 66 | "type": "special" 67 | }, 68 | "fer2013": { 69 | "path": "/path/to/dataset", 70 | "type": "imagefolder", 71 | "train": "train", 72 | "test": "test" 73 | }, 74 | "stl10": { 75 | "path": "./data/stl10", 76 | "type": "special" 77 | }, 78 | "eurosat": { 79 | "path": "/path/to/dataset", 80 | "type": "imagefolder", 81 | "train": "train", 82 | "test": "val" 83 | }, 84 | "resisc45": { 85 | "path": "/path/to/dataset", 86 | "type": "imagefolder", 87 | "train": "train", 88 | "test": "test" 89 | }, 90 | "gtsrb": { 91 | "path": "/path/to/dataset", 92 | "type": "imagefolder", 93 | "train": "train", 94 | "test": "test" 95 | }, 96 | "kitti_distance": { 97 | "path": "/path/to/dataset", 98 | "type": "imagefolder", 99 | "train": "train", 100 | "test": "val" 101 | }, 102 | "country211": { 103 | "path": "/path/to/dataset", 104 | "type": "imagefolder", 105 | "train": "train", 106 | "test": "test" 107 | }, 108 | "patch_camelyon": { 109 | "path": "/path/to/dataset", 110 | "type": "imagefolder", 111 | "train": "train", 112 | "test": "val" 113 | }, 114 | "ucf101_frames": { 115 | "path": "/path/to/dataset", 116 | "type": "imagefolder", 117 | "train": "train", 118 | "test": "val" 119 | }, 120 | "kinetics700_frames": { 121 | "path": "/path/to/dataset", 122 | "type": "imagefolder", 123 | "train": "train_images", 124 | "test": "val_images" 125 | }, 126 | "clevr_counts": { 127 | "path": "/path/to/dataset", 128 | "type": "filelist", 129 | "train": "train", 130 | "test": "val" 131 | }, 132 | "hateful_memes": { 133 | "path": "/path/to/dataset", 134 | "type": "imagefolder", 135 | "train": "train", 136 | "test": "dev" 137 | }, 138 | "rendered_sst2": { 139 | "path": "/path/to/dataset", 140 | 
"type": "imagefolder", 141 | "train": "train", 142 | "test": "test" 143 | }, 144 | "imagenet": { 145 | "path": "/path/to/dataset", 146 | "type": "imagefolder", 147 | "train": "train", 148 | "test": "val" 149 | } 150 | } 151 | -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | from collections import defaultdict 7 | import json 8 | import os 9 | import pickle 10 | import zipfile 11 | 12 | import numpy as np 13 | from PIL import Image, ImageFile 14 | 15 | import torch 16 | from torchvision import transforms 17 | from torchvision import datasets as t_datasets 18 | from torchvision.transforms import functional as F 19 | 20 | import utils 21 | 22 | 23 | ImageFile.LOAD_TRUNCATED_IMAGES = True 24 | 25 | 26 | def pil_loader(path): 27 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 28 | with open(path, 'rb') as f: 29 | img = Image.open(f) 30 | return img.convert('RGB') 31 | 32 | 33 | def yfcc_loader(root, index): 34 | index = format(index, "0>8d") 35 | repo = index[:2] 36 | z = index[2: 5] 37 | file_img = index[5:] + '.jpg' 38 | path_zip = os.path.join(root, 'images', repo, z) + '.zip' 39 | with zipfile.ZipFile(path_zip, 'r') as myzip: 40 | img = Image.open(myzip.open(file_img)) 41 | return img.convert('RGB') 42 | 43 | 44 | class ImageCaptionDatasetBase(torch.utils.data.Dataset): 45 | def __init__(self, dataset, root, metadata): 46 | self.dataset = dataset 47 | self.root = root 48 | if self.dataset == 'yfcc15m': 49 | with open(metadata, 'rb') as f: 50 | self.samples = pickle.load(f) 51 | elif self.dataset == 'coco': 52 | samples = defaultdict(list) 53 | with open(metadata) as f: 54 | annotations = json.load(f)['annotations'] 55 | for ann in annotations: 56 | samples[ann['image_id']].append(ann['caption']) 57 | self.samples = [(k, v) for k, v in samples.items()] 58 | elif self.dataset == 'cc12m' or self.dataset == 'cc3m': 59 | self.samples = np.load(metadata, allow_pickle=True) 60 | elif self.dataset == 'redcaps': 61 | with open(metadata) as f: 62 | annotations = json.load(f) 63 | self.samples = [(ann['image_id'], ann['subreddit'], ann['caption']) for ann in annotations] 64 | 65 | def get_raw_item(self, i): 66 | if self.dataset == 'yfcc15m': 67 | index, title, desc = self.samples[i] 68 | caption = np.random.choice([title, desc]) 69 | img = yfcc_loader(self.root, index) 70 | elif self.dataset == 'coco': 71 | index, captions = self.samples[i] 72 | path = os.path.join(self.root, 'train2017', '{:012d}.jpg'.format(index)) 73 | img = pil_loader(path) 74 | caption = np.random.choice(captions) 75 | elif self.dataset == 'cc3m': 76 | ann = self.samples[i] 77 | filename, captions = ann['image_id'], ann['captions'] 78 | path = os.path.join(self.root, str(filename)) 79 | img = pil_loader(path) 80 | caption = np.random.choice(captions) 81 | elif self.dataset == 'cc12m': 82 | ann = self.samples[i] 83 | filename, captions = ann['image_name'], ann['captions'] 84 | path = os.path.join(self.root, filename) 85 | img = pil_loader(path) 86 | caption = np.random.choice(captions) 87 | elif self.dataset == 'redcaps': 88 | image_id, subreddit, caption = self.samples[i] 89 | path = os.path.join(self.root, subreddit, f"{image_id}.jpg") 
90 | img = pil_loader(path) 91 | 92 | return img, caption 93 | 94 | def __getitem__(self, i): 95 | raise NotImplementedError 96 | 97 | def __len__(self): 98 | return len(self.samples) 99 | 100 | 101 | class ImageCaptionDatasetCLIP(ImageCaptionDatasetBase): 102 | def __init__(self, dataset, root, metadata, transform=None, tokenizer=None): 103 | super().__init__(dataset, root, metadata) 104 | 105 | self.transform = transform 106 | self.tokenizer = tokenizer 107 | 108 | def __getitem__(self, i): 109 | img, caption = self.get_raw_item(i) 110 | 111 | # apply transformation 112 | if self.transform is not None: 113 | image = self.transform(img) 114 | 115 | # tokenize caption 116 | if self.tokenizer is not None: 117 | caption = self.tokenizer(caption) 118 | 119 | return image, caption 120 | 121 | 122 | class ImageCaptionDatasetSLIP(ImageCaptionDatasetBase): 123 | def __init__(self, dataset, root, metadata, transform, augment, tokenizer=None): 124 | super().__init__(dataset, root, metadata) 125 | 126 | self.transform = transform 127 | self.augment = augment 128 | self.tokenizer = tokenizer 129 | 130 | def __getitem__(self, i): 131 | img, caption = self.get_raw_item(i) 132 | 133 | image = self.transform(img) 134 | aug1 = self.augment(img) 135 | aug2 = self.augment(img) 136 | 137 | # tokenize caption 138 | if self.tokenizer is not None: 139 | caption = self.tokenizer(caption) 140 | 141 | return image, caption, aug1, aug2 142 | 143 | class ImageCaptionDatasetACLIP(ImageCaptionDatasetBase): 144 | def __init__(self, dataset, root, metadata, transform, ema_transform, tokenizer=None): 145 | super().__init__(dataset, root, metadata) 146 | self.transform = transform 147 | self.ema_transform = ema_transform 148 | self.tokenizer = tokenizer 149 | self.get_three_crop = GetThreeRandomResizedCrop(224, scale=(0.5, 1.0)) 150 | 151 | def __getitem__(self, i): 152 | img, caption = self.get_raw_item(i) 153 | res = self.get_three_crop(img) 154 | 155 | im1, ret1 = res[0] 156 | im2, ret2 = res[1] 157 | im3, ret3 = res[2] 158 | 159 | im1 = self.transform(im1) 160 | im2 = self.transform(im2) 161 | im3 = self.ema_transform(im3) 162 | 163 | pos = np.array([ret1,ret2,ret3]) 164 | if self.tokenizer is not None: 165 | caption = self.tokenizer(caption) 166 | 167 | return [im1, im2, im3], pos, caption 168 | 169 | class ImageCaptionDatasetSSL(ImageCaptionDatasetBase): 170 | def __init__(self, dataset, root, metadata, augment): 171 | super().__init__(dataset, root, metadata) 172 | 173 | self.augment = augment 174 | 175 | def __getitem__(self, i): 176 | img, _ = self.get_raw_item(i) 177 | 178 | aug1 = self.augment(img) 179 | aug2 = self.augment(img) 180 | 181 | return aug1, aug2 182 | 183 | 184 | class FileListDataset(torch.utils.data.Dataset): 185 | def __init__(self, images, labels, transform=None, target_transform=None): 186 | self.transform = transform 187 | self.target_transform = target_transform 188 | self.images = np.load(images) 189 | self.labels = np.load(labels) 190 | 191 | def __getitem__(self, index): 192 | img = pil_loader(self.images[index]) 193 | target = self.labels[index] 194 | 195 | if self.transform is not None: 196 | img = self.transform(img) 197 | 198 | if self.target_transform is not None: 199 | target = self.target_transform(target) 200 | 201 | return img, target 202 | 203 | def __len__(self): 204 | return len(self.images) 205 | 206 | 207 | def get_downstream_dataset(catalog, name, is_train, transform): 208 | entry = catalog[name] 209 | root = entry['path'] 210 | if entry['type'] == 'imagefolder': 211 | 
dataset = t_datasets.ImageFolder(os.path.join(root, entry['train'] if is_train else entry['test']), 212 | transform=transform) 213 | elif entry['type'] == 'special': 214 | if name == 'cifar10': 215 | dataset = t_datasets.CIFAR10(root, train=is_train, 216 | transform=transform, download=True) 217 | elif name == 'cifar100': 218 | dataset = t_datasets.CIFAR100(root, train=is_train, 219 | transform=transform, download=True) 220 | elif name == 'stl10': 221 | dataset = t_datasets.STL10(root, split='train' if is_train else 'test', 222 | transform=transform, download=True) 223 | elif name == 'mnist': 224 | dataset = t_datasets.MNIST(root, train=is_train, 225 | transform=transform, download=True) 226 | elif entry['type'] == 'filelist': 227 | path = entry['train'] if is_train else entry['test'] 228 | val_images = os.path.join(root, path + '_images.npy') 229 | val_labels = os.path.join(root, path + '_labels.npy') 230 | if name == 'clevr_counts': 231 | target_transform = lambda x: ['count_10', 'count_3', 'count_4', 'count_5', 'count_6', 'count_7', 'count_8', 'count_9'].index(x) 232 | else: 233 | target_transform = None 234 | dataset = FileListDataset(val_images, val_labels, transform, target_transform) 235 | else: 236 | raise Exception('Unknown dataset') 237 | 238 | return dataset 239 | 240 | class GetThreeRandomResizedCrop(transforms.RandomResizedCrop): 241 | def forward(self, img): 242 | """ 243 | Args: 244 | img (PIL Image or Tensor): Image to be cropped and resized. 245 | 246 | Returns: 247 | List[(cropped image, ret)] *3. 248 | The scale of the last image is larger than the first two. 249 | """ 250 | ret1 = self.get_params(img, self.scale, self.ratio) 251 | ret2 = self.get_params(img, self.scale, self.ratio) 252 | 253 | try: 254 | _, height, width = F.get_dimensions(img) 255 | except: 256 | width, height = F.get_image_size(img) 257 | 258 | im1 = F.resized_crop(img, *ret1, self.size, self.interpolation) 259 | im2 = F.resized_crop(img, *ret2, self.size, self.interpolation) 260 | 261 | # zoom out 262 | ret3 = [0, 0, 0, 0] 263 | ret3[0], ret3[1] = min(ret1[0], ret2[0]), min(ret1[1], ret2[1]) # top-left corner of the union of the two crops 264 | 265 | rh = max(ret1[0] + ret1[2], ret2[0] + ret2[2]) 266 | rw = max(ret1[1] + ret1[3], ret2[1] + ret2[3]) 267 | ret3[2], ret3[3] = rh - ret3[0], rw - ret3[1] 268 | 269 | ret3[0] = torch.randint(0, ret3[0] + 1, size=(1,)).item() if ret3[0] > 0 else ret3[0] 270 | ret3[1] = torch.randint(0, ret3[1] + 1, size=(1,)).item() if ret3[1] > 0 else ret3[1] 271 | 272 | ret3[2] = torch.randint(ret3[2], height - ret3[0] + 1, size=(1,)).item() if ret3[2] < height else ret3[2] 273 | ret3[3] = torch.randint(ret3[3], width - ret3[1] + 1, size=(1,)).item() if ret3[3] < width else ret3[3] 274 | 275 | im3 = F.resized_crop(img, *ret3, self.size, self.interpolation) 276 | 277 | return [(im1, ret1), (im2, ret2), (im3, ret3)] 278 | 279 | def get_dataset(train_transform, tokenizer, args): 280 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 281 | std=[0.229, 0.224, 0.225]) 282 | augment = transforms.Compose([ 283 | transforms.RandomResizedCrop(224, scale=(0.08, 1.)), 284 | transforms.RandomApply([ 285 | transforms.ColorJitter(0.4, 0.4, 0.4, 0.1) # not strengthened 286 | ], p=0.8), 287 | transforms.RandomGrayscale(p=0.2), 288 | transforms.RandomApply([utils.GaussianBlur([.1, 2.])], p=0.5), 289 | transforms.RandomHorizontalFlip(), 290 | transforms.ToTensor(), 291 | normalize, 292 | ]) 293 | 294 | if args.model.startswith('SIMCLR'): 295 | return ImageCaptionDatasetSSL(args.dataset, args.root, args.metadata, augment) 296 
| elif args.model.startswith('CLIP'): 297 | return ImageCaptionDatasetCLIP(args.dataset, args.root, args.metadata, train_transform, tokenizer) 298 | elif args.model.startswith('SLIP'): 299 | return ImageCaptionDatasetSLIP(args.dataset, args.root, args.metadata, train_transform, augment, tokenizer) 300 | elif args.model.startswith('ACLIP'): 301 | ema_transform = transforms.Compose([ 302 | transforms.ToTensor(), 303 | normalize 304 | ]) 305 | return ImageCaptionDatasetACLIP(args.dataset, args.root, args.metadata, train_transform, ema_transform, tokenizer) 306 | 307 | 308 | 309 | -------------------------------------------------------------------------------- /docs/aclip.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/A-CLIP/0e51942fb677f6d26accfb1d1d19f56219563f0e/docs/aclip.png -------------------------------------------------------------------------------- /docs/attMask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/A-CLIP/0e51942fb677f6d26accfb1d1d19f56219563f0e/docs/attMask.png -------------------------------------------------------------------------------- /docs/masking_Page_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/A-CLIP/0e51942fb677f6d26accfb1d1d19f56219563f0e/docs/masking_Page_1.png -------------------------------------------------------------------------------- /docs/vis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/A-CLIP/0e51942fb677f6d26accfb1d1d19f56219563f0e/docs/vis.png -------------------------------------------------------------------------------- /eval_zeroshot.py: -------------------------------------------------------------------------------- 1 | # Based on SLIP code bases 2 | # https://github.com/facebookresearch/SLIP 3 | # -------------------------------------------------------- 4 | import argparse 5 | from collections import OrderedDict 6 | import json 7 | import os 8 | from sklearn import metrics 9 | 10 | import torch 11 | import torch.backends.cudnn as cudnn 12 | import torch.utils.data 13 | import torchvision.transforms as transforms 14 | 15 | import datasets 16 | import models 17 | from tokenizer import SimpleTokenizer 18 | import utils 19 | 20 | 21 | def get_args_parser(): 22 | parser = argparse.ArgumentParser(description='SLIP 0-shot evaluations', add_help=False) 23 | parser.add_argument('--output-dir', default='./', type=str, help='output dir') 24 | parser.add_argument('--batch-size', default=256, type=int, help='batch_size') 25 | parser.add_argument('-j', '--workers', default=10, type=int, metavar='N', 26 | help='number of data loading workers per process') 27 | parser.add_argument('--resume', default='', type=str, help='path to latest checkpoint') 28 | return parser 29 | 30 | 31 | def main(args): 32 | # optionally resume from a checkpoint (takes precedence over autoresume) 33 | if args.resume: 34 | ckpt_path = args.resume 35 | elif os.path.isfile(os.path.join(args.output_dir, 'checkpoint_best.pt')): 36 | ckpt_path = os.path.join(args.output_dir, 'checkpoint_best.pt') 37 | else: 38 | raise Exception('no checkpoint found') 39 | 40 | ckpt = torch.load(ckpt_path, map_location='cpu') 41 | state_dict = OrderedDict() 42 | for k, v in ckpt['state_dict'].items(): 43 | state_dict[k.replace('module.', '')] = v 44 | 45 | # create model 46 
| old_args = ckpt['args'] 47 | print("=> creating model: {}".format(old_args.model)) 48 | model = getattr(models, old_args.model)() 49 | model.cuda() 50 | model.load_state_dict(state_dict, strict=True) 51 | print("=> loaded resume checkpoint '{}' (epoch {})".format(args.resume, ckpt['epoch'])) 52 | 53 | cudnn.benchmark = True 54 | 55 | cwd = os.path.dirname(os.path.realpath(__file__)) 56 | with open(os.path.join(cwd, 'dataset_catalog.json')) as f: 57 | catalog = json.load(f) 58 | 59 | with open(os.path.join(cwd, 'templates.json')) as f: 60 | all_templates = json.load(f) 61 | 62 | with open(os.path.join(cwd, 'labels.json')) as f: 63 | all_labels = json.load(f) 64 | 65 | # Data loading code 66 | print("=> creating dataset") 67 | tokenizer = SimpleTokenizer() 68 | val_transform = transforms.Compose([ 69 | transforms.Resize(224), 70 | transforms.CenterCrop(224), 71 | lambda x: x.convert('RGB'), 72 | transforms.ToTensor(), 73 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 74 | std=[0.229, 0.224, 0.225]) 75 | ]) 76 | 77 | results = [] 78 | for d in catalog: 79 | print('Evaluating {}'.format(d)) 80 | val_dataset = datasets.get_downstream_dataset(catalog, name=d, is_train=False, transform=val_transform) 81 | 82 | val_loader = torch.utils.data.DataLoader( 83 | val_dataset, batch_size=args.batch_size, shuffle=False, 84 | num_workers=args.workers, pin_memory=True, drop_last=False) 85 | 86 | templates = all_templates[d] 87 | labels = all_labels[d] 88 | 89 | is_acc = d not in ['aircraft', 'pets', 'caltech101', 'flowers', 'kinetics700_frames', 'hateful_memes'] 90 | 91 | acc_or_outputs = validate_zeroshot(val_loader, templates, labels, model, tokenizer, is_acc) 92 | 93 | if d in ['aircraft', 'pets', 'caltech101', 'flowers']: 94 | metric = mean_per_class(*acc_or_outputs) 95 | elif d == 'kinetics700_frames': 96 | top1, top5 = accuracy(*acc_or_outputs, topk=(1, 5)) 97 | metric = (top1 + top5) / 2 98 | metric = metric.item() 99 | elif d == 'hateful_memes': 100 | metric = roc_auc(*acc_or_outputs) 101 | else: 102 | metric = acc_or_outputs 103 | 104 | results.append(metric) 105 | 106 | print('metric:', metric) 107 | 108 | print('all results:') 109 | for x in results: 110 | print('{:.1f}'.format(x)) 111 | 112 | def validate_zeroshot(val_loader, templates, labels, model, tokenizer, is_acc): 113 | # switch to evaluate mode 114 | model.eval() 115 | total_top1 = 0 116 | total_images = 0 117 | 118 | all_outputs = [] 119 | all_targets = [] 120 | 121 | print('=> encoding captions') 122 | with torch.no_grad(): 123 | text_features = [] 124 | for label in labels: 125 | if isinstance(label, list): 126 | texts = [t.format(l) for t in templates for l in label] 127 | else: 128 | texts = [t.format(label) for t in templates] 129 | texts = tokenizer(texts).cuda(non_blocking=True) 130 | texts = texts.view(-1, 77).contiguous() 131 | class_embeddings = utils.get_model(model).encode_text(texts, ema=True) 132 | class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True) 133 | class_embeddings = class_embeddings.mean(dim=0) 134 | class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True) 135 | text_features.append(class_embeddings) 136 | text_features = torch.stack(text_features, dim=0) 137 | 138 | for images, target in val_loader: 139 | images = images.cuda(non_blocking=True) 140 | target = target.cuda(non_blocking=True) 141 | 142 | # encode images 143 | image_features = utils.get_model(model).encode_image(images, ema=True) 144 | image_features = image_features / 
image_features.norm(dim=-1, keepdim=True) 145 | 146 | # cosine similarity as logits 147 | logits_per_image = image_features @ text_features.t() 148 | 149 | if is_acc: 150 | # measure accuracy and record loss 151 | pred = logits_per_image.argmax(dim=1) 152 | correct = pred.eq(target).sum() 153 | total_top1 += correct.item() 154 | total_images += images.size(0) 155 | else: 156 | all_outputs.append(logits_per_image.cpu()) 157 | all_targets.append(target.cpu()) 158 | 159 | if is_acc: 160 | return 100 * total_top1 / total_images 161 | else: 162 | return torch.cat(all_outputs), torch.cat(all_targets) 163 | 164 | 165 | def accuracy(output, target, topk=(1,)): 166 | """Computes the accuracy over the k top predictions for the specified values of k""" 167 | with torch.no_grad(): 168 | maxk = max(topk) 169 | batch_size = target.size(0) 170 | 171 | _, pred = output.topk(maxk, 1, True, True) 172 | pred = pred.t() 173 | correct = pred.eq(target.reshape(1, -1).expand_as(pred)) 174 | 175 | res = [] 176 | for k in topk: 177 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 178 | res.append(correct_k.mul_(100.0 / batch_size)) 179 | return res 180 | 181 | 182 | def mean_per_class(outputs, targets): 183 | pred = outputs.argmax(1) 184 | confusion_matrix = metrics.confusion_matrix(targets, pred) 185 | per_classes = confusion_matrix.diagonal() / confusion_matrix.sum(axis=1) 186 | 187 | return 100 * per_classes.mean() 188 | 189 | 190 | def roc_auc(outputs, targets): 191 | pos_score = outputs[:, 1] - outputs[:, 0] 192 | metric = metrics.roc_auc_score(targets, pos_score) 193 | 194 | return 100 * metric 195 | 196 | 197 | if __name__ == '__main__': 198 | parser = argparse.ArgumentParser('A-CLIP 0-shot evaluations', parents=[get_args_parser()]) 199 | args = parser.parse_args() 200 | os.makedirs(args.output_dir, exist_ok=True) 201 | main(args) 202 | -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | # Based on SLIP code bases 2 | # https://github.com/facebookresearch/SLIP 3 | # -------------------------------------------------------- 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | import utils 9 | 10 | def get_metric_names(): 11 | metics = ["loss"] 12 | metics.extend(["simclr_loss","im_byol_loss","contra_loss_1","contra_loss_2","clip_acc"]) 13 | 14 | return metics 15 | 16 | 17 | def cal_simsiam_loss(p, z, version="simplified"): # negative cosine similarity 18 | if version == "original": 19 | z = z.detach() # stop gradient 20 | p = F.normalize(p, dim=1) # l2-normalize 21 | z = F.normalize(z, dim=1) # l2-normalize 22 | return -(p * z).sum(dim=1).mean() 23 | 24 | elif ( 25 | version == "simplified" 26 | ): # same thing, much faster. 
Scroll down, speed test in __main__ 27 | return -F.cosine_similarity(p, z.detach(), dim=-1).mean() 28 | else: 29 | raise Exception 30 | 31 | 32 | class ACLIPLoss(nn.Module): 33 | def __init__(self, temperature=0.1): 34 | super().__init__() 35 | self.labels = None 36 | self.last_local_batch_size = None 37 | self.simclr_loss = SIMCLRLoss(temperature=temperature) 38 | 39 | def forward(self, outputs): 40 | image_embed = outputs["image_embed"] 41 | text_embed = outputs["text_embed"] 42 | logit_scale = outputs["logit_scale"] 43 | 44 | # cal simclr_loss 45 | bs = text_embed.shape[0] 46 | image_ssl_embed = outputs["image_ssl_embed"] 47 | inputs = {} 48 | inputs["aug1_embed"] = image_ssl_embed[:bs] 49 | inputs["aug2_embed"] = image_ssl_embed[:bs] 50 | simclr_loss_dict = self.simclr_loss(inputs) 51 | 52 | def loss_fn(x, y): 53 | x = F.normalize(x, dim=-1, p=2) 54 | y = F.normalize(y, dim=-1, p=2) 55 | return 2 - 2 * (x * y).sum(dim=-1) 56 | 57 | im_features = outputs["byol_feats"] 58 | im_features_e = outputs["byol_feats_e"] 59 | im_features_e = torch.cat([im_features_e, im_features_e], dim=0) 60 | im_byol_loss = loss_fn(im_features, im_features_e).mean() 61 | 62 | local_batch_size = text_embed.size(0) 63 | 64 | if local_batch_size != self.last_local_batch_size: 65 | self.labels = local_batch_size * utils.get_rank() + torch.arange( 66 | local_batch_size, device=image_embed.device 67 | ) 68 | self.last_local_batch_size = local_batch_size 69 | 70 | image_embed = F.normalize(image_embed, dim=-1, p=2) 71 | text_embed = F.normalize(text_embed, dim=-1, p=2) 72 | 73 | image_embed_1 = image_embed[:local_batch_size] 74 | image_embed_2 = image_embed[local_batch_size:] 75 | 76 | ( 77 | image_embed_all_1, 78 | image_embed_all_2, 79 | text_embed_all, 80 | ) = utils.all_gather_batch_with_grad([image_embed_1, image_embed_2, text_embed]) 81 | 82 | # cosine similarity as logits 83 | logits_per_image = logit_scale * image_embed_1 @ text_embed_all.t() 84 | logits_per_text = logit_scale * text_embed @ image_embed_all_1.t() 85 | 86 | contra_loss_1 = ( 87 | F.cross_entropy(logits_per_image, self.labels) 88 | + F.cross_entropy(logits_per_text, self.labels) 89 | ) / 2 90 | 91 | logits_per_image = logit_scale * image_embed_2 @ text_embed_all.t() 92 | logits_per_text = logit_scale * text_embed @ image_embed_all_2.t() 93 | 94 | contra_loss_2 = ( 95 | F.cross_entropy(logits_per_image, self.labels) 96 | + F.cross_entropy(logits_per_text, self.labels) 97 | ) / 2 98 | 99 | 100 | loss = ( 101 | 0.5 * contra_loss_1 102 | + 0.5 * contra_loss_2 103 | + simclr_loss_dict["ssl_loss"] 104 | + 2 * im_byol_loss 105 | ) 106 | 107 | # compute accuracy 108 | with torch.no_grad(): 109 | pred = torch.argmax(logits_per_image, dim=-1) 110 | correct = pred.eq(self.labels).sum() 111 | acc = 100 * correct / local_batch_size 112 | 113 | return { 114 | "loss": loss, 115 | "simclr_loss": simclr_loss_dict["ssl_loss"], 116 | "im_byol_loss": im_byol_loss, 117 | "contra_loss_1": contra_loss_1, 118 | "contra_loss_2": contra_loss_2, 119 | "clip_acc": acc, 120 | } 121 | 122 | 123 | class SIMCLRLoss(nn.Module): 124 | """ 125 | This is the SimCLR loss in https://arxiv.org/abs/2002.05709 126 | The embedding vectors are assumed to have size (2 x batch_size, embedding_dim) and 127 | the memory layout that can be reshaped into shape (2, batch_size, embedding_dim). 
128 | This memory layout is consistent with the SimCLR collator in 129 | https://github.com/facebookresearch/vissl/blob/master/vissl/data/collators/simclr_collator.py 130 | Config params: 131 | temperature (float): the temperature to be applied on the logits 132 | """ 133 | 134 | def __init__(self, temperature=0.1): 135 | super().__init__() 136 | self.tau = temperature 137 | self.labels = None 138 | self.masks = None 139 | self.last_local_batch_size = None 140 | 141 | def forward(self, outputs): 142 | q_a = outputs["aug1_embed"] 143 | q_b = outputs["aug2_embed"] 144 | 145 | q_a = F.normalize(q_a, dim=-1, p=2) 146 | q_b = F.normalize(q_b, dim=-1, p=2) 147 | 148 | local_batch_size = q_a.size(0) 149 | 150 | k_a, k_b = utils.all_gather_batch_with_grad([q_a, q_b]) 151 | 152 | if local_batch_size != self.last_local_batch_size: 153 | self.labels = local_batch_size * utils.get_rank() + torch.arange( 154 | local_batch_size, device=q_a.device 155 | ) 156 | total_batch_size = local_batch_size * utils.get_world_size() 157 | self.masks = F.one_hot(self.labels, total_batch_size) * 1e9 158 | self.last_local_batch_size = local_batch_size 159 | 160 | logits_aa = torch.matmul(q_a, k_a.transpose(0, 1)) / self.tau 161 | logits_aa = logits_aa - self.masks 162 | logits_bb = torch.matmul(q_b, k_b.transpose(0, 1)) / self.tau 163 | logits_bb = logits_bb - self.masks 164 | logits_ab = torch.matmul(q_a, k_b.transpose(0, 1)) / self.tau 165 | logits_ba = torch.matmul(q_b, k_a.transpose(0, 1)) / self.tau 166 | 167 | loss_a = F.cross_entropy(torch.cat([logits_ab, logits_aa], dim=1), self.labels) 168 | loss_b = F.cross_entropy(torch.cat([logits_ba, logits_bb], dim=1), self.labels) 169 | loss = (loss_a + loss_b) / 2 # divide by 2 to average over all samples 170 | 171 | # compute accuracy 172 | with torch.no_grad(): 173 | pred = torch.argmax(torch.cat([logits_ab, logits_aa], dim=1), dim=-1) 174 | correct = pred.eq(self.labels).sum() 175 | acc = 100 * correct / local_batch_size 176 | 177 | return {"loss": loss, "ssl_loss": loss, "ssl_acc": acc} 178 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # Based on SLIP code bases 2 | # https://github.com/facebookresearch/SLIP 3 | # --------------------------------------------------------' 4 | import argparse 5 | from collections import OrderedDict 6 | import json 7 | import math 8 | import os 9 | import sys 10 | import time 11 | try: 12 | import wandb 13 | except ImportError: 14 | wandb = None 15 | 16 | import numpy as np 17 | import torch 18 | import torch.cuda.amp as amp 19 | import torch.nn.parallel 20 | import torch.backends.cudnn as cudnn 21 | import torch.optim 22 | import torch.utils.data 23 | import torchvision.transforms as transforms 24 | 25 | from datasets import get_dataset 26 | import models 27 | from tokenizer import SimpleTokenizer 28 | from utils import AverageMeter, ProgressMeter, accuracy 29 | import utils 30 | from torchvision.datasets import ImageFolder 31 | from data import GaussianBlur,Solarize 32 | from losses import ACLIPLoss, get_metric_names 33 | 34 | 35 | 36 | def get_args_parser(): 37 | parser = argparse.ArgumentParser(description='A-CLIP pre-training and evaluation', add_help=False) 38 | # Data 39 | parser.add_argument('--dataset', default='yfcc15m', type=str, choices=['yfcc15m', 'cc3m', 'cc12m', 'coco', 'redcaps']) 40 | parser.add_argument('--metadata', default='yfcc15m.pkl', type=str, 41 | help='path to metadata file (see 
README for details)') 42 | parser.add_argument('--root', default='', type=str, 43 | help='path to dataset root') 44 | parser.add_argument('--output-dir', default='./', type=str, help='path where to save, empty for no saving') 45 | # Model 46 | parser.add_argument('--model', default='ACLIP_VITB16', type=str) 47 | parser.add_argument('--mask-ratio', default=0., type=float) 48 | parser.add_argument('--ssl-mlp-dim', default=4096, type=int, 49 | help='hidden dim of SimCLR mlp projection head') 50 | parser.add_argument('--ssl-emb-dim', default=256, type=int, 51 | help='output embed dim of SimCLR mlp projection head') 52 | parser.add_argument('--ssl-scale', default=1.0, type=float, 53 | help='loss scale for SimCLR objective') 54 | parser.add_argument('--ssl-temp', default=0.1, type=float, 55 | help='softmax temperature for SimCLR objective') 56 | parser.add_argument('--resume', default='', type=str, help='path to resume from') 57 | # Training 58 | parser.add_argument('--momentum-ema', default=0.996, type=float, help="""Base EMA 59 | parameter. The value is increased to 1 during training with cosine schedule.""") 60 | parser.add_argument('--epochs', default=25, type=int) 61 | parser.add_argument('--warmup-epochs', default=1, type=int) 62 | parser.add_argument('--start-epoch', default=0, type=int) 63 | parser.add_argument('--batch-size', default=2, type=int, 64 | help='number of samples per-device/per-gpu') 65 | parser.add_argument('--lr', default=3e-3, type=float) 66 | parser.add_argument('--base-lr', default=3e-3, type=float) 67 | parser.add_argument('--lr-start', default=1e-6, type=float, 68 | help='initial warmup lr') 69 | parser.add_argument('--lr-end', default=1e-5, type=float, 70 | help='minimum final lr') 71 | parser.add_argument('--update-freq', default=1, type=int, 72 | help='optimizer update frequency (i.e. 
gradient accumulation steps)') 73 | parser.add_argument('--wd', default=0.1, type=float) 74 | parser.add_argument('--betas', default=(0.9, 0.98), nargs=2, type=float) 75 | parser.add_argument('--eps', default=1e-8, type=float) 76 | parser.add_argument('--eval-freq', default=1, type=int) 77 | parser.add_argument('--disable-amp', action='store_true', 78 | help='disable mixed-precision training (requires more memory and compute)') 79 | # System 80 | parser.add_argument('--print-freq', default=10, type=int, help='print frequency') 81 | parser.add_argument('-j', '--workers', default=10, type=int, metavar='N', 82 | help='number of data loading workers per process') 83 | parser.add_argument('--evaluate', action='store_true', help='eval only') 84 | parser.add_argument('--world-size', default=1, type=int, 85 | help='number of nodes for distributed training') 86 | parser.add_argument('--rank', default=0, type=int, 87 | help='node rank for distributed training') 88 | parser.add_argument("--local_rank", type=int, default=0) 89 | parser.add_argument('--dist-url', default='env://', type=str, 90 | help='url used to set up distributed training') 91 | parser.add_argument('--dist-backend', default='nccl', type=str) 92 | parser.add_argument('--seed', default=0, type=int) 93 | parser.add_argument('--gpu', default=None, type=int, help='GPU id to use.') 94 | parser.add_argument('--wandb', action='store_true', help='Enable WandB logging') 95 | parser.add_argument('--descriptions', default='training', type=str) 96 | 97 | return parser 98 | 99 | def get_model(args): 100 | print("=> creating model: {}".format(args.model)) 101 | model = getattr(models, args.model)(mask_ratio=args.mask_ratio) 102 | model.cuda(args.gpu) 103 | 104 | if args.distributed: 105 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) 106 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], bucket_cap_mb=200,find_unused_parameters=False) 107 | 108 | return model 109 | 110 | def get_optim(args, model): 111 | p_wd, p_non_wd = [], [] 112 | for n, p in model.named_parameters(): 113 | if not p.requires_grad: 114 | continue # frozen weights 115 | if p.ndim < 2 or 'bias' in n or 'ln' in n or 'bn' in n: 116 | p_non_wd.append(p) 117 | else: 118 | p_wd.append(p) 119 | 120 | optim_params = [{"params": p_wd, "weight_decay": args.wd}, 121 | {"params": p_non_wd, "weight_decay": 0}, 122 | ] 123 | 124 | optimizer = torch.optim.AdamW(optim_params, lr=args.lr, betas=args.betas, 125 | eps=args.eps, weight_decay=args.wd) 126 | return optimizer 127 | 128 | def load_ckpt(args, model, optimizer, scaler): 129 | # optionally resume from a checkpoint (takes precedence over autoresume) 130 | if args.resume: 131 | if os.path.isfile(args.resume): 132 | print("=> loading resume checkpoint '{}'".format(args.resume)) 133 | checkpoint = torch.load(args.resume, map_location='cpu') 134 | epoch = checkpoint['epoch'] if 'epoch' in checkpoint else 0 135 | args.start_epoch = epoch 136 | result = model.load_state_dict(checkpoint['state_dict'], strict=False) 137 | print(result) 138 | optimizer.load_state_dict(checkpoint['optimizer']) if 'optimizer' in checkpoint else () 139 | scaler.load_state_dict(checkpoint['scaler']) if 'scaler' in checkpoint else () 140 | args.best_acc = checkpoint['best_acc'] 141 | print("=> loaded resume checkpoint '{}' (epoch {})" 142 | .format(args.resume, epoch)) 143 | else: 144 | print("=> no checkpoint found at '{}'".format(args.resume)) 145 | else: 146 | # auto-resume from latest checkpoint in output directory 147 | 
latest = os.path.join(args.output_dir, 'checkpoint.pt') 148 | if os.path.isfile(latest): 149 | print("=> loading latest checkpoint '{}'".format(latest)) 150 | latest_checkpoint = torch.load(latest, map_location='cpu') 151 | args.start_epoch = latest_checkpoint['epoch'] 152 | model.load_state_dict(latest_checkpoint['state_dict']) 153 | optimizer.load_state_dict(latest_checkpoint['optimizer']) 154 | scaler.load_state_dict(latest_checkpoint['scaler']) 155 | args.best_acc = latest_checkpoint['best_acc'] 156 | print("=> loaded latest checkpoint '{}' (epoch {})" 157 | .format(latest, latest_checkpoint['epoch'])) 158 | 159 | def get_loader(args, tokenizer): 160 | print("=> creating dataset") 161 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 162 | std=[0.229, 0.224, 0.225]) 163 | 164 | val_transform = transforms.Compose([ 165 | transforms.Resize(224), 166 | transforms.CenterCrop(224), 167 | transforms.ToTensor(), 168 | normalize 169 | ]) 170 | 171 | train_transform = transforms.Compose([ 172 | transforms.RandomApply([ 173 | transforms.ColorJitter(0.4, 0.4, 0.2, 0.1) # not strengthened 174 | ], p=0.8), 175 | transforms.RandomGrayscale(p=0.2), 176 | transforms.RandomApply([GaussianBlur([.1, 2.])], p=0.1), 177 | transforms.RandomApply([Solarize()], p=0.2), 178 | transforms.ToTensor(), 179 | normalize 180 | ]) 181 | 182 | train_dataset = get_dataset(train_transform, tokenizer, args) 183 | cwd = os.path.dirname(os.path.realpath(__file__)) 184 | with open(os.path.join(cwd, 'dataset_catalog.json')) as f: 185 | root = json.load(f)['imagenet']['path'] 186 | val_dataset = ImageFolder(os.path.join(root, 'val'), val_transform) 187 | 188 | # dist eval resamples data to pad uneven batch sizes 189 | # make sure num_samples = 0 mod num_gpus for exact acc 190 | if args.distributed: 191 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 192 | val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset) 193 | else: 194 | train_sampler = None 195 | val_sampler = None 196 | 197 | train_loader = torch.utils.data.DataLoader( 198 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), 199 | num_workers=args.workers, pin_memory=True, sampler=train_sampler, drop_last=True) 200 | 201 | val_loader = torch.utils.data.DataLoader( 202 | val_dataset, batch_size=args.batch_size, shuffle=(val_sampler is None), 203 | num_workers=args.workers, pin_memory=True, sampler=val_sampler, drop_last=False) 204 | 205 | return train_loader, train_sampler, val_loader 206 | 207 | def main(args): 208 | utils.init_distributed_mode(args) 209 | cudnn.benchmark = True 210 | 211 | args.best_acc = 0 212 | 213 | # fix the seed for reproducibility 214 | seed = args.seed + utils.get_rank() 215 | torch.manual_seed(seed) 216 | np.random.seed(seed) 217 | 218 | # create model 219 | model = get_model(args) 220 | 221 | # define loss function (criterion) and optimizer 222 | criterion = ACLIPLoss(args.ssl_temp).cuda(args.gpu) 223 | 224 | optimizer = get_optim(args, model) 225 | scaler = amp.GradScaler(enabled=not args.disable_amp) 226 | 227 | load_ckpt(args, model, optimizer, scaler) 228 | 229 | # Data loading 230 | tokenizer = SimpleTokenizer() 231 | train_loader, train_sampler, val_loader = get_loader(args, tokenizer) 232 | 233 | if args.evaluate: 234 | zero_stats = validate_zeroshot(val_loader, model, tokenizer, args) 235 | if utils.is_main_process(): 236 | with open(os.path.join(args.output_dir, 'eval_log.txt'), 'a') as f: 237 | f.write(json.dumps(zero_stats) + '\n') 238 | 
return 239 | 240 | lr_schedule = utils.cosine_scheduler(args.lr, args.lr_end, args.epochs, 241 | len(train_loader) // args.update_freq, warmup_epochs=args.warmup_epochs, start_warmup_value=args.lr_start) 242 | 243 | momentum_schedule = utils.cosine_scheduler(args.momentum_ema, 1, args.epochs, len(train_loader), 0) 244 | 245 | if utils.is_main_process() and args.wandb: 246 | wandb_id = os.path.split(args.output_dir)[-1] 247 | wandb.init(project='ACLIP', id=wandb_id, config=args, resume='resume') 248 | 249 | print(args) 250 | 251 | print("=> beginning training") 252 | for epoch in range(args.start_epoch, args.epochs): 253 | if args.distributed: 254 | train_sampler.set_epoch(epoch) 255 | 256 | # train for one epoch 257 | train_stats = train(train_loader, model, criterion, optimizer, scaler, epoch, lr_schedule, momentum_schedule ,args) 258 | 259 | if (epoch + 1) % args.eval_freq != 0: 260 | continue 261 | 262 | val_stats = validate_zeroshot(val_loader, model, tokenizer, args) 263 | ema_val_stats = validate_zeroshot(val_loader, model, tokenizer, args, ema=True) 264 | acc1 = val_stats['acc1'] 265 | 266 | is_best = acc1 > args.best_acc 267 | args.best_acc = max(acc1, args.best_acc) 268 | 269 | print("=> saving checkpoint") 270 | utils.save_on_master({ 271 | 'epoch': epoch + 1, 272 | 'state_dict': model.state_dict(), 273 | 'optimizer' : optimizer.state_dict(), 274 | 'scaler': scaler.state_dict(), 275 | 'best_acc': args.best_acc, 276 | 'args': args, 277 | }, is_best, args.output_dir,epoch+1,args.epochs) 278 | 279 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 280 | **{f'test_{k}': v for k, v in val_stats.items()}, 281 | **{f'ema_test_{k}': v for k, v in ema_val_stats.items()}, 282 | 'epoch': epoch} 283 | 284 | if utils.is_main_process(): 285 | if args.wandb: 286 | wandb.log(log_stats) 287 | with open(os.path.join(args.output_dir, 'log.txt'), 'a') as f: 288 | f.write(json.dumps(log_stats) + '\n') 289 | 290 | 291 | def train(train_loader, model, criterion, optimizer, scaler, epoch, lr_schedule, momentum_schedule, args): 292 | batch_time = AverageMeter('Time', ':6.2f') 293 | data_time = AverageMeter('Data', ':6.2f') 294 | mem = AverageMeter('Mem (GB)', ':6.1f') 295 | metric_names = get_metric_names() 296 | iters_per_epoch = len(train_loader) // args.update_freq 297 | metrics = OrderedDict([(name, AverageMeter(name, ':.2e')) for name in metric_names]) 298 | progress = ProgressMeter( 299 | iters_per_epoch, 300 | [batch_time, data_time, mem, *metrics.values()], 301 | prefix="Epoch: [{}]".format(epoch)) 302 | 303 | # switch to train mode 304 | model.train() 305 | 306 | end = time.time() 307 | for data_iter, inputs in enumerate(train_loader): 308 | optim_iter = data_iter // args.update_freq 309 | data_time.update(time.time() - end) 310 | 311 | # update weight decay and learning rate according to their schedule 312 | it = iters_per_epoch * epoch + optim_iter # global training iteration 313 | for k, param_group in enumerate(optimizer.param_groups): 314 | param_group['lr'] = lr_schedule[it] 315 | 316 | online_inputs = [torch.cat([inputs[0][0], inputs[0][1]], dim=0), inputs[0][2], inputs[2]] 317 | online_inputs = [tensor.cuda(args.gpu, non_blocking=True) for tensor in online_inputs] 318 | positions = inputs[1] 319 | 320 | m = momentum_schedule[it] # momentum parameter 321 | # compute output 322 | with amp.autocast(enabled=not args.disable_amp): 323 | outputs = model(*online_inputs, positions, m) 324 | loss_dict = criterion(outputs) 325 | 326 | loss = loss_dict['loss'] 327 | loss /= 
args.update_freq 328 | 329 | if not math.isfinite(loss.item()): 330 | print("Loss is {}, stopping training".format(loss.item())) 331 | sys.exit(1) 332 | 333 | scaler.scale(loss).backward() 334 | 335 | if (data_iter + 1) % args.update_freq != 0: 336 | continue 337 | 338 | # compute gradient and do SGD step 339 | scaler.step(optimizer) 340 | scaler.update() 341 | model.zero_grad(set_to_none=True) 342 | 343 | # clamp logit scale to [0, 100] 344 | logit_scale_e = 0 345 | 346 | utils.get_model(model).logit_scale.data.clamp_(0, 4.6052) 347 | if hasattr(utils.get_model(model),'logit_scale_e'): 348 | utils.get_model(model).logit_scale_e.data.clamp_(0, 4.6052) 349 | logit_scale_e = utils.get_model(model).logit_scale_e.exp().item() 350 | 351 | logit_scale = utils.get_model(model).logit_scale.exp().item() 352 | 353 | for k in loss_dict: 354 | metrics[k].update(loss_dict[k].item(), args.batch_size) 355 | 356 | # measure elapsed time 357 | batch_time.update(time.time() - end) 358 | 359 | end = time.time() 360 | 361 | mem.update(torch.cuda.max_memory_allocated() // 1e9) 362 | 363 | if optim_iter % args.print_freq == 0: 364 | if utils.is_main_process() and args.wandb: 365 | wandb.log({**{k: v.item() for k, v in loss_dict.items()}, 366 | 'scaler': scaler.get_scale(), 367 | 'logit': logit_scale, 368 | 'logit_e': logit_scale_e, 369 | }) 370 | progress.display(optim_iter) 371 | 372 | progress.synchronize() 373 | return {**{k: v.avg for k, v in metrics.items()}, 374 | 'lr': optimizer.param_groups[0]['lr'], 375 | 'logit_scale': logit_scale} 376 | 377 | 378 | def validate_zeroshot(val_loader, model, tokenizer, args, ema=False): 379 | batch_time = AverageMeter('Time', ':6.3f') 380 | top1 = AverageMeter('Acc@1', ':6.2f') 381 | top5 = AverageMeter('Acc@5', ':6.2f') 382 | progress = ProgressMeter( 383 | len(val_loader), 384 | [batch_time, top1, top5], 385 | prefix='Test: ') 386 | 387 | # switch to evaluate mode 388 | model.eval() 389 | 390 | print('=> encoding captions') 391 | cwd = os.path.dirname(os.path.realpath(__file__)) 392 | with open(os.path.join(cwd, 'templates.json')) as f: 393 | templates = json.load(f)['imagenet'] 394 | 395 | with open(os.path.join(cwd, 'labels.json')) as f: 396 | labels = json.load(f)['imagenet'] 397 | 398 | with torch.no_grad(): 399 | text_features = [] 400 | for l in labels: 401 | texts = [t.format(l) for t in templates] 402 | texts = tokenizer(texts).cuda(args.gpu, non_blocking=True) 403 | class_embeddings = utils.get_model(model).encode_text(texts, ema=ema) 404 | class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True) 405 | class_embeddings = class_embeddings.mean(dim=0) 406 | class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True) 407 | text_features.append(class_embeddings) 408 | text_features = torch.stack(text_features, dim=0) 409 | 410 | end = time.time() 411 | for i, (images, target) in enumerate(val_loader): 412 | images = images.cuda(args.gpu, non_blocking=True) 413 | target = target.cuda(args.gpu, non_blocking=True) 414 | 415 | # encode images 416 | image_features = utils.get_model(model).encode_image(images, ema=ema) 417 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 418 | 419 | # cosine similarity as logits 420 | logits_per_image = image_features @ text_features.t() 421 | 422 | # measure accuracy and record loss 423 | acc1, acc5 = accuracy(logits_per_image, target, topk=(1, 5)) 424 | acc1, acc5 = utils.scaled_all_reduce([acc1, acc5]) 425 | top1.update(acc1.item(), images.size(0)) 426 | 
top5.update(acc5.item(), images.size(0)) 427 | 428 | # measure elapsed time 429 | batch_time.update(time.time() - end) 430 | end = time.time() 431 | 432 | if i % args.print_freq == 0: 433 | progress.display(i) 434 | 435 | progress.synchronize() 436 | print('0-shot * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}' 437 | .format(top1=top1, top5=top5)) 438 | return {'acc1': top1.avg, 'acc5': top5.avg} 439 | 440 | 441 | if __name__ == '__main__': 442 | parser = argparse.ArgumentParser('A-CLIP training and evaluation', parents=[get_args_parser()]) 443 | args = parser.parse_args() 444 | if args.output_dir: 445 | os.makedirs(args.output_dir, exist_ok=True) 446 | main(args) 447 | 448 | 449 | 450 | -------------------------------------------------------------------------------- /main_linear.py: -------------------------------------------------------------------------------- 1 | # Based on SLIP code bases 2 | # https://github.com/facebookresearch/SLIP 3 | # -------------------------------------------------------- 4 | import argparse 5 | import json 6 | import math 7 | import os 8 | import random 9 | import shutil 10 | import time 11 | import timm 12 | import warnings 13 | 14 | import torch 15 | import torch.nn as nn 16 | import torch.nn.parallel 17 | import torch.backends.cudnn as cudnn 18 | import torch.optim 19 | import torch.utils.data 20 | import torch.utils.data.distributed 21 | import torchvision.transforms as transforms 22 | 23 | import datasets 24 | import utils 25 | 26 | 27 | def get_args_parser(): 28 | parser = argparse.ArgumentParser(description='Linear probe evaluation', add_help=False) 29 | parser.add_argument('--dataset', default='imagenet', help='dataset name') 30 | parser.add_argument('--output-dir', default='./', type=str) 31 | parser.add_argument('-a', '--arch', metavar='ARCH', default='vit_base_patch16_224', 32 | help='model architecture: (default: ViT-B/16)') 33 | parser.add_argument('-j', '--workers', default=64, type=int, metavar='N', 34 | help='number of data loading workers (default: 64)') 35 | parser.add_argument('--epochs', default=90, type=int, metavar='N', 36 | help='number of total epochs to run') 37 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 38 | help='manual epoch number (useful on restarts)') 39 | parser.add_argument('-b', '--batch-size', default=128, type=int, 40 | metavar='N', 41 | help='number of samples per-device/per-gpu ') 42 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, 43 | metavar='LR', help='initial (base) learning rate', dest='lr') 44 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 45 | help='momentum') 46 | parser.add_argument('--wd', '--weight-decay', default=0., type=float, 47 | metavar='W', help='weight decay (default: 0.)', 48 | dest='weight_decay') 49 | parser.add_argument('-p', '--print-freq', default=10, type=int, 50 | metavar='N', help='print frequency (default: 10)') 51 | parser.add_argument('--eval-freq', default=10, type=int) 52 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 53 | help='path to latest checkpoint (default: none)') 54 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 55 | help='evaluate model on validation set') 56 | parser.add_argument('--world-size', default=1, type=int, 57 | help='number of nodes for distributed training') 58 | parser.add_argument('--rank', default=0, type=int, 59 | help='node rank for distributed training') 60 | parser.add_argument("--local_rank", type=int, default=0) 61 | 
parser.add_argument('--dist-url', default='env://', type=str, 62 | help='url used to set up distributed training') 63 | parser.add_argument('--dist-backend', default='nccl', type=str, 64 | help='distributed backend') 65 | parser.add_argument('--seed', default=None, type=int, 66 | help='seed for initializing training. ') 67 | parser.add_argument('--gpu', default=None, type=int, 68 | help='GPU id to use.') 69 | parser.add_argument('--pretrained', default='', type=str, 70 | help='path to CLIP pretrained checkpoint') 71 | return parser 72 | 73 | best_acc1 = 0 74 | 75 | 76 | def main(args): 77 | utils.init_distributed_mode(args) 78 | 79 | global best_acc1 80 | 81 | if args.seed is not None: 82 | random.seed(args.seed) 83 | torch.manual_seed(args.seed) 84 | cudnn.deterministic = True 85 | warnings.warn('You have chosen to seed training. ' 86 | 'This will turn on the CUDNN deterministic setting, ' 87 | 'which can slow down your training considerably! ' 88 | 'You may see unexpected behavior when restarting ' 89 | 'from checkpoints.') 90 | 91 | linear_keyword = 'head' 92 | if os.path.isfile(args.pretrained): 93 | print("=> loading checkpoint '{}'".format(args.pretrained)) 94 | 95 | if args.gpu is None: 96 | checkpoint = torch.load(args.pretrained) 97 | else: 98 | # Map model to be loaded to specified single gpu. 99 | loc = 'cuda:{}'.format(args.gpu) 100 | checkpoint = torch.load(args.pretrained, map_location=loc) 101 | 102 | visual_keyword = 'module.visual.' 103 | 104 | # rename CLIP pre-trained keys 105 | state_dict = checkpoint['state_dict'] 106 | for k in list(state_dict.keys()): 107 | # retain only base_encoder up to before the embedding layer 108 | if k.startswith(visual_keyword) and not k.startswith(visual_keyword + linear_keyword): 109 | # remove prefix 110 | state_dict[k[len(visual_keyword):]] = state_dict[k] 111 | # delete renamed or unused k 112 | del state_dict[k] 113 | else: 114 | raise Exception('Missing pretrained model checkpoint: {}'.format(args.pretrained)) 115 | 116 | # create model 117 | print("=> creating model '{}'".format(args.arch)) 118 | model = timm.models.create_model(args.arch, num_classes=1000) 119 | 120 | args.start_epoch = 0 121 | msg = model.load_state_dict(state_dict, strict=False) 122 | assert set(msg.missing_keys) == {"%s.weight" % linear_keyword, "%s.bias" % linear_keyword} 123 | 124 | # freeze all layers but the last fc 125 | for name, param in model.named_parameters(): 126 | if name not in ['%s.weight' % linear_keyword, '%s.bias' % linear_keyword]: 127 | param.requires_grad = False 128 | # init the fc layer 129 | getattr(model, linear_keyword).weight.data.normal_(mean=0.0, std=0.01) 130 | getattr(model, linear_keyword).bias.data.zero_() 131 | 132 | init_lr = args.lr * int(args.batch_size / utils.get_world_size()) / 256 133 | args.workers = int((args.workers + utils.get_world_size() - 1) / utils.get_world_size()) 134 | 135 | model.cuda(args.gpu) 136 | 137 | if args.distributed: 138 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 139 | 140 | # define loss function (criterion) and optimizer 141 | criterion = nn.CrossEntropyLoss().cuda(args.gpu) 142 | 143 | # optimize only the linear classifier 144 | parameters = list(filter(lambda p: p.requires_grad, model.parameters())) 145 | assert len(parameters) == 2 # weight, bias 146 | 147 | optimizer = torch.optim.SGD(parameters, init_lr, 148 | momentum=args.momentum, 149 | weight_decay=args.weight_decay) 150 | 151 | # optionally resume from a checkpoint 152 | if args.resume: 153 | if 
os.path.isfile(args.resume): 154 | print("=> loading checkpoint '{}'".format(args.resume)) 155 | if args.gpu is None: 156 | checkpoint = torch.load(args.resume) 157 | else: 158 | # Map model to be loaded to specified single gpu. 159 | loc = 'cuda:{}'.format(args.gpu) 160 | checkpoint = torch.load(args.resume, map_location=loc) 161 | args.start_epoch = checkpoint['epoch'] 162 | best_acc1 = checkpoint['best_acc1'] 163 | if args.gpu is not None: 164 | # best_acc1 may be from a checkpoint from a different GPU 165 | best_acc1 = best_acc1.to(args.gpu) 166 | model.load_state_dict(checkpoint['state_dict']) 167 | optimizer.load_state_dict(checkpoint['optimizer']) 168 | print("=> loaded checkpoint '{}' (epoch {})" 169 | .format(args.resume, checkpoint['epoch'])) 170 | else: 171 | print("=> no checkpoint found at '{}'".format(args.resume)) 172 | 173 | cudnn.benchmark = True 174 | 175 | # Data loading code 176 | cwd = os.path.dirname(os.path.realpath(__file__)) 177 | with open(os.path.join(cwd, 'dataset_catalog.json')) as f: 178 | catalog = json.load(f) 179 | 180 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 181 | std=[0.229, 0.224, 0.225]) 182 | 183 | train_transform = transforms.Compose([ 184 | transforms.RandomResizedCrop(224), 185 | transforms.RandomHorizontalFlip(), 186 | lambda x: x.convert('RGB'), 187 | transforms.ToTensor(), 188 | normalize, 189 | ]) 190 | val_transform = transforms.Compose([ 191 | transforms.Resize(256), 192 | transforms.CenterCrop(224), 193 | lambda x: x.convert('RGB'), 194 | transforms.ToTensor(), 195 | normalize, 196 | ]) 197 | 198 | train_dataset = datasets.get_downstream_dataset(catalog, args.dataset, is_train=True, transform=train_transform) 199 | val_dataset = datasets.get_downstream_dataset(catalog, args.dataset, is_train=False, transform=val_transform) 200 | 201 | if args.distributed: 202 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 203 | else: 204 | train_sampler = None 205 | 206 | train_loader = torch.utils.data.DataLoader( 207 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), 208 | num_workers=args.workers, pin_memory=True, sampler=train_sampler) 209 | 210 | val_loader = torch.utils.data.DataLoader( 211 | val_dataset, batch_size=256, shuffle=False, 212 | num_workers=args.workers, pin_memory=True) 213 | 214 | if args.evaluate: 215 | validate(val_loader, model, criterion, args) 216 | return 217 | 218 | print(args) 219 | 220 | for epoch in range(args.start_epoch, args.epochs): 221 | if args.distributed: 222 | train_sampler.set_epoch(epoch) 223 | adjust_learning_rate(optimizer, init_lr, epoch, args) 224 | 225 | # train for one epoch 226 | train_stats = train(train_loader, model, criterion, optimizer, epoch, args) 227 | 228 | if (epoch + 1) % args.eval_freq != 0: 229 | continue 230 | 231 | # evaluate on validation set 232 | val_stats = validate(val_loader, model, criterion, args) 233 | acc1 = val_stats['acc1'] 234 | 235 | # remember best acc@1 and save checkpoint 236 | is_best = acc1 > best_acc1 237 | best_acc1 = max(acc1, best_acc1) 238 | 239 | if utils.is_main_process(): # only the first GPU saves checkpoint 240 | save_checkpoint({ 241 | 'epoch': epoch + 1, 242 | 'arch': args.arch, 243 | 'state_dict': model.state_dict(), 244 | 'best_acc1': best_acc1, 245 | 'optimizer' : optimizer.state_dict(), 246 | }, is_best, args.output_dir) 247 | if epoch == args.start_epoch: 248 | sanity_check(model.state_dict(), args.pretrained, linear_keyword, visual_keyword) 249 | 250 | log_stats = 
{**{f'train_{k}': v for k, v in train_stats.items()}, 251 | **{f'test_{k}': v for k, v in val_stats.items()}, 252 | 'epoch': epoch} 253 | 254 | if utils.is_main_process(): 255 | with open(os.path.join(args.output_dir, 'linear_{}_lr={}_log.txt'.format(args.dataset, args.lr)), 'a') as f: 256 | f.write(json.dumps(log_stats) + '\n') 257 | 258 | 259 | def train(train_loader, model, criterion, optimizer, epoch, args): 260 | batch_time = AverageMeter('Time', ':6.3f') 261 | data_time = AverageMeter('Data', ':6.3f') 262 | losses = AverageMeter('Loss', ':.4e') 263 | top1 = AverageMeter('Acc@1', ':6.2f') 264 | top5 = AverageMeter('Acc@5', ':6.2f') 265 | progress = ProgressMeter( 266 | len(train_loader), 267 | [batch_time, data_time, losses, top1, top5], 268 | prefix="Epoch: [{}]".format(epoch)) 269 | 270 | """ 271 | Switch to eval mode: 272 | Under the protocol of linear classification on frozen features/models, 273 | it is not legitimate to change any part of the pre-trained model. 274 | BatchNorm in train mode may revise running mean/std (even if it receives 275 | no gradient), which are part of the model parameters too. 276 | """ 277 | model.eval() 278 | 279 | end = time.time() 280 | for i, (images, target) in enumerate(train_loader): 281 | # measure data loading time 282 | data_time.update(time.time() - end) 283 | 284 | if args.gpu is not None: 285 | images = images.cuda(args.gpu, non_blocking=True) 286 | if torch.cuda.is_available(): 287 | target = target.cuda(args.gpu, non_blocking=True) 288 | 289 | # compute output 290 | output = model(images) 291 | loss = criterion(output, target) 292 | 293 | # measure accuracy and record loss 294 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 295 | losses.update(loss.item(), images.size(0)) 296 | top1.update(acc1.item(), images.size(0)) 297 | top5.update(acc5.item(), images.size(0)) 298 | 299 | # compute gradient and do SGD step 300 | optimizer.zero_grad() 301 | loss.backward() 302 | optimizer.step() 303 | 304 | # measure elapsed time 305 | batch_time.update(time.time() - end) 306 | end = time.time() 307 | 308 | if i % args.print_freq == 0: 309 | progress.display(i) 310 | 311 | return {'acc1': top1.avg, 'acc5': top5.avg, 'loss': losses.avg} 312 | 313 | 314 | def validate(val_loader, model, criterion, args): 315 | batch_time = AverageMeter('Time', ':6.3f') 316 | losses = AverageMeter('Loss', ':.4e') 317 | top1 = AverageMeter('Acc@1', ':6.2f') 318 | top5 = AverageMeter('Acc@5', ':6.2f') 319 | progress = ProgressMeter( 320 | len(val_loader), 321 | [batch_time, losses, top1, top5], 322 | prefix='Test: ') 323 | 324 | # switch to evaluate mode 325 | model.eval() 326 | 327 | with torch.no_grad(): 328 | end = time.time() 329 | for i, (images, target) in enumerate(val_loader): 330 | if args.gpu is not None: 331 | images = images.cuda(args.gpu, non_blocking=True) 332 | if torch.cuda.is_available(): 333 | target = target.cuda(args.gpu, non_blocking=True) 334 | 335 | # compute output 336 | output = model(images) 337 | loss = criterion(output, target) 338 | 339 | # measure accuracy and record loss 340 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 341 | losses.update(loss.item(), images.size(0)) 342 | top1.update(acc1.item(), images.size(0)) 343 | top5.update(acc5.item(), images.size(0)) 344 | 345 | # measure elapsed time 346 | batch_time.update(time.time() - end) 347 | end = time.time() 348 | 349 | if i % args.print_freq == 0: 350 | progress.display(i) 351 | 352 | # TODO: this should also be done with the ProgressMeter 353 | print(' * Acc@1 {top1.avg:.3f} 
Acc@5 {top5.avg:.3f}' 354 | .format(top1=top1, top5=top5)) 355 | 356 | return {'acc1': top1.avg, 'acc5': top5.avg, 'loss': losses.avg} 357 | 358 | 359 | def save_checkpoint(state, is_best, output_dir): 360 | ckpt_path = f'{output_dir}/linear_checkpoint.pt' 361 | best_path = f'{output_dir}/linear_best.pt' 362 | torch.save(state, ckpt_path) 363 | if is_best: 364 | shutil.copyfile(ckpt_path, best_path) 365 | 366 | 367 | def sanity_check(state_dict, pretrained_weights, linear_keyword, visual_keyword): 368 | """ 369 | Linear classifier should not change any weights other than the linear layer. 370 | This sanity check asserts nothing wrong happens (e.g., BN stats updated). 371 | """ 372 | print("=> loading '{}' for sanity check".format(pretrained_weights)) 373 | checkpoint = torch.load(pretrained_weights, map_location="cpu") 374 | state_dict_pre = checkpoint['state_dict'] 375 | 376 | for k in list(state_dict.keys()): 377 | # only ignore linear layer 378 | if '%s.weight' % linear_keyword in k or '%s.bias' % linear_keyword in k: 379 | continue 380 | 381 | # name in pretrained model 382 | k_pre = visual_keyword + k[len('module.'):] \ 383 | if k.startswith('module.') else visual_keyword + k 384 | 385 | assert ((state_dict[k].cpu() == state_dict_pre[k_pre]).all()), \ 386 | '{} is changed in linear classifier training.'.format(k) 387 | 388 | print("=> sanity check passed.") 389 | 390 | 391 | class AverageMeter(object): 392 | """Computes and stores the average and current value""" 393 | def __init__(self, name, fmt=':f'): 394 | self.name = name 395 | self.fmt = fmt 396 | self.reset() 397 | 398 | def reset(self): 399 | self.val = 0 400 | self.avg = 0 401 | self.sum = 0 402 | self.count = 0 403 | 404 | def update(self, val, n=1): 405 | self.val = val 406 | self.sum += val * n 407 | self.count += n 408 | self.avg = self.sum / self.count 409 | 410 | def __str__(self): 411 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 412 | return fmtstr.format(**self.__dict__) 413 | 414 | 415 | class ProgressMeter(object): 416 | def __init__(self, num_batches, meters, prefix=""): 417 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 418 | self.meters = meters 419 | self.prefix = prefix 420 | 421 | def display(self, batch): 422 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 423 | entries += [str(meter) for meter in self.meters] 424 | print('\t'.join(entries)) 425 | 426 | def _get_batch_fmtstr(self, num_batches): 427 | num_digits = len(str(num_batches // 1)) 428 | fmt = '{:' + str(num_digits) + 'd}' 429 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 430 | 431 | 432 | def adjust_learning_rate(optimizer, init_lr, epoch, args): 433 | """Decay the learning rate based on schedule""" 434 | cur_lr = init_lr * 0.5 * (1. 
+ math.cos(math.pi * epoch / args.epochs)) 435 | for param_group in optimizer.param_groups: 436 | param_group['lr'] = cur_lr 437 | 438 | 439 | def accuracy(output, target, topk=(1,)): 440 | """Computes the accuracy over the k top predictions for the specified values of k""" 441 | with torch.no_grad(): 442 | maxk = max(topk) 443 | batch_size = target.size(0) 444 | 445 | _, pred = output.topk(maxk, 1, True, True) 446 | pred = pred.t() 447 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 448 | 449 | res = [] 450 | for k in topk: 451 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 452 | res.append(correct_k.mul_(100.0 / batch_size)) 453 | return res 454 | 455 | 456 | if __name__ == '__main__': 457 | parser = argparse.ArgumentParser('Linear probe evaluation', parents=[get_args_parser()]) 458 | args = parser.parse_args() 459 | main(args) -------------------------------------------------------------------------------- /make_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | import numpy as np 7 | import pickle 8 | import re 9 | from urllib.parse import unquote 10 | from tqdm import tqdm 11 | 12 | 13 | DATASET = 'yfcc100m_dataset.txt' 14 | 15 | cleanhtml = re.compile('<a.*?>|</a>|<b>|</b>|<i>|</i>') 16 | cleanurl = re.compile('http\S+|www\S+') 17 | 18 | print('=> loading YFCC image ids') 19 | image_ids = np.load('flickr_unique_ids.npy') 20 | image_ids = set(image_ids) 21 | 22 | print('=> loading CLIP image ids') 23 | clip_ids = set() 24 | with open('yfcc100m_subset_data.tsv') as f: 25 | for l in tqdm(f.readlines()): 26 | row = l.strip().split('\t') 27 | clip_ids.add(int(row[0])) 28 | 29 | print('=> collecting and cleaning subset captions') 30 | captioned = [] 31 | uncaptioned = [] 32 | with open('yfcc100m_dataset.txt') as f: 33 | for l in tqdm(f.readlines()): 34 | row = l.strip().split('\t') 35 | if int(row[0]) in image_ids: 36 | uncaptioned.append(int(row[0])) 37 | if int(row[0]) in clip_ids: 38 | title = unquote(row[8]).replace('+', ' ') 39 | title = re.sub(cleanhtml, '', title) 40 | title = re.sub(cleanurl, '', title) 41 | 42 | desc = unquote(row[9]).replace('+', ' ') 43 | desc = re.sub(cleanhtml, '', desc) 44 | desc = re.sub(cleanurl, '', desc) 45 | 46 | captioned.append((int(row[0]), title, desc)) 47 | 48 | with open('yfcc15m.pkl', 'wb') as f: 49 | pickle.dump(captioned, f) 50 | 51 | with open('yfcc100m.pkl', 'wb') as f: 52 | pickle.dump(uncaptioned, f) 53 | 54 | print('Total captioned images:', len(captioned)) # 14689580 55 | print('Total uncaptioned images:', len(uncaptioned)) # 95920149 56 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | 2 | # Based on CLIP code bases 3 | # Modified from github.com/openai/CLIP 4 | # --------------------------------------------------------' 5 | 6 | from collections import OrderedDict 7 | import numpy as np 8 | import timm 9 | import torch 10 | from torch import nn 11 | import torchvision.transforms.functional_tensor as F_t 12 | 13 | from timm.models.vision_transformer import VisionTransformer 14 | from timm.models.registry import register_model 15 | from timm.models.vision_transformer import ( 16 | default_cfgs, 17 | build_model_with_cfg, 18 | checkpoint_filter_fn, 19 |
) 20 | 21 | from timm.models.layers import trunc_normal_ as __call_trunc_normal_ 22 | 23 | 24 | def get_att_mask(attention, ratio=0.5): 25 | bs = attention.shape[0] 26 | masks = torch.ones((bs,49), dtype=torch.bool, device=attention.device) 27 | attention = attention.reshape((-1, 14, 14)) 28 | attention = torch.nn.functional.interpolate(attention.unsqueeze(1), (7, 7), mode='bilinear').squeeze() 29 | attention = attention.reshape(bs,-1) 30 | N = int(attention.shape[1] * ratio) 31 | 32 | reservation = torch.argsort(attention, descending=True) 33 | reservation = reservation[:,:N+1] 34 | masks = masks.scatter_(1, reservation, False) 35 | 36 | full_mask = torch.zeros((bs, 14, 14), dtype=torch.bool, device=attention.device) 37 | full_mask[:, 0::2, 0::2] = masks.reshape(bs, 7, 7) 38 | full_mask[:, 0::2, 1::2] = masks.reshape(bs, 7, 7) 39 | full_mask[:, 1::2, 0::2] = masks.reshape(bs, 7, 7) 40 | full_mask[:, 1::2, 1::2] = masks.reshape(bs, 7, 7) 41 | full_mask = full_mask.reshape(bs, -1) 42 | 43 | return full_mask 44 | 45 | 46 | def trunc_normal_(tensor, mean=0.0, std=1.0): 47 | __call_trunc_normal_(tensor, mean=mean, std=std, a=-std, b=std) 48 | 49 | class LayerNorm(nn.LayerNorm): 50 | """Subclass torch's LayerNorm to handle fp16.""" 51 | 52 | def forward(self, x: torch.Tensor): 53 | orig_type = x.dtype 54 | ret = super().forward(x.type(torch.float32)) 55 | return ret.type(orig_type) 56 | 57 | 58 | class QuickGELU(nn.Module): 59 | def forward(self, x: torch.Tensor): 60 | return x * torch.sigmoid(1.702 * x) 61 | 62 | 63 | class ResidualAttentionBlock(nn.Module): 64 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): 65 | super().__init__() 66 | 67 | self.attn = nn.MultiheadAttention(d_model, n_head) 68 | self.ln_1 = LayerNorm(d_model) 69 | self.mlp = nn.Sequential( 70 | OrderedDict( 71 | [ 72 | ("c_fc", nn.Linear(d_model, d_model * 4)), 73 | ("gelu", QuickGELU()), 74 | ("c_proj", nn.Linear(d_model * 4, d_model)), 75 | ] 76 | ) 77 | ) 78 | self.ln_2 = LayerNorm(d_model) 79 | self.attn_mask = attn_mask 80 | 81 | def attention(self, x: torch.Tensor): 82 | self.attn_mask = ( 83 | self.attn_mask.to(dtype=x.dtype, device=x.device) 84 | if self.attn_mask is not None 85 | else None 86 | ) 87 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] 88 | 89 | def forward(self, x: torch.Tensor): 90 | x = x + self.attention(self.ln_1(x)) 91 | x = x + self.mlp(self.ln_2(x)) 92 | return x 93 | 94 | 95 | class Transformer(nn.Module): 96 | def __init__( 97 | self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None 98 | ): 99 | super().__init__() 100 | self.width = width 101 | self.layers = layers 102 | self.resblocks = nn.Sequential( 103 | *[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)] 104 | ) 105 | 106 | def forward(self, x: torch.Tensor): 107 | return self.resblocks(x) 108 | 109 | 110 | class CLIP(nn.Module): 111 | def __init__( 112 | self, 113 | embed_dim: int, 114 | # vision 115 | vision_width: int, 116 | vision_model: nn.Module, 117 | # text 118 | context_length: int, 119 | vocab_size: int, 120 | transformer_width: int, 121 | transformer_heads: int, 122 | transformer_layers: int, 123 | **kwargs, 124 | ): 125 | super().__init__() 126 | 127 | self.context_length = context_length 128 | self.vision_width = vision_width 129 | 130 | self.visual = vision_model 131 | 132 | self.transformer = Transformer( 133 | width=transformer_width, 134 | layers=transformer_layers, 135 | heads=transformer_heads, 136 | 
attn_mask=self.build_attention_mask(), 137 | ) 138 | 139 | self.vocab_size = vocab_size 140 | self.token_embedding = nn.Embedding(vocab_size, transformer_width) 141 | self.positional_embedding = nn.Parameter( 142 | torch.empty(self.context_length, transformer_width) 143 | ) 144 | self.ln_final = LayerNorm(transformer_width) 145 | 146 | self.image_projection = nn.Parameter(torch.empty(vision_width, embed_dim)) 147 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) 148 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 149 | 150 | self.initialize_parameters() 151 | 152 | def initialize_parameters(self): 153 | nn.init.normal_(self.token_embedding.weight, std=0.02) 154 | nn.init.normal_(self.positional_embedding, std=0.01) 155 | 156 | proj_std = (self.transformer.width**-0.5) * ( 157 | (2 * self.transformer.layers) ** -0.5 158 | ) 159 | attn_std = self.transformer.width**-0.5 160 | fc_std = (2 * self.transformer.width) ** -0.5 161 | for block in self.transformer.resblocks: 162 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 163 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 164 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 165 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 166 | 167 | nn.init.normal_(self.image_projection, std=self.vision_width**-0.5) 168 | nn.init.normal_(self.text_projection, std=self.transformer.width**-0.5) 169 | 170 | def build_attention_mask(self): 171 | # lazily create causal attention mask, with full attention between the vision tokens 172 | # pytorch uses additive attention mask; fill with -inf 173 | mask = torch.empty(self.context_length, self.context_length) 174 | mask.fill_(float("-inf")) 175 | mask.triu_(1) # zero out the lower diagonal 176 | return mask 177 | 178 | def encode_image(self, image): 179 | x = self.visual(image) 180 | x = x @ self.image_projection 181 | 182 | return x 183 | 184 | def encode_text(self, text): 185 | x = self.token_embedding(text) # [batch_size, n_ctx, d_model] 186 | x = x + self.positional_embedding 187 | x = x.permute(1, 0, 2) # NLD -> LND 188 | x = self.transformer(x) 189 | x = x.permute(1, 0, 2) # LND -> NLD 190 | x = self.ln_final(x) 191 | 192 | # x.shape = [batch_size, n_ctx, transformer.width] 193 | # take features from the eot embedding (eot_token is the highest number in each sequence) 194 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection 195 | 196 | return x 197 | 198 | def forward(self, image, text): 199 | image_embed = self.encode_image(image) 200 | text_embed = self.encode_text(text) 201 | 202 | return { 203 | "image_embed": image_embed, 204 | "text_embed": text_embed, 205 | "logit_scale": self.logit_scale.exp(), 206 | } 207 | 208 | 209 | class Proj_Head(nn.Module): 210 | def __init__(self, in_channels, mlp_hidden_size, projection_size): 211 | super().__init__() 212 | 213 | self.net = nn.Sequential( 214 | nn.Linear(in_channels, mlp_hidden_size, bias=False), 215 | LayerNorm(mlp_hidden_size), 216 | nn.GELU(), 217 | nn.Linear(mlp_hidden_size, mlp_hidden_size, bias=False), 218 | LayerNorm(mlp_hidden_size), 219 | nn.GELU(), 220 | nn.Linear(mlp_hidden_size, projection_size, bias=False), 221 | LayerNorm(projection_size), 222 | ) 223 | 224 | def forward(self, x): 225 | return self.net(x) 226 | 227 | 228 | class Pred_Head(nn.Module): 229 | def __init__(self, dim, mlp_hidden_size, projection_size): 230 | super().__init__() 231 | 232 | self.net = nn.Sequential( 233 | nn.Linear(dim, mlp_hidden_size), 234 | 
LayerNorm(mlp_hidden_size), 235 | nn.GELU(), 236 | nn.Linear(mlp_hidden_size, projection_size), 237 | ) 238 | 239 | def forward(self, x): 240 | return self.net(x) 241 | 242 | 243 | class projection_MLP(nn.Module): 244 | def __init__(self, in_dim, hidden_dim=2048, out_dim=2048): 245 | super().__init__() 246 | """ page 3 baseline setting 247 | Projection MLP. The projection MLP (in f) has BN ap- 248 | plied to each fully-connected (fc) layer, including its out- 249 | put fc. Its output fc has no ReLU. The hidden fc is 2048-d. 250 | This MLP has 3 layers. 251 | """ 252 | self.layer1 = nn.Sequential( 253 | nn.Linear(in_dim, hidden_dim), 254 | nn.BatchNorm1d(hidden_dim), 255 | nn.ReLU(inplace=True), 256 | ) 257 | self.layer2 = nn.Sequential( 258 | nn.Linear(hidden_dim, hidden_dim), 259 | nn.BatchNorm1d(hidden_dim), 260 | nn.ReLU(inplace=True), 261 | ) 262 | self.layer3 = nn.Sequential( 263 | nn.Linear(hidden_dim, out_dim), nn.BatchNorm1d(hidden_dim) 264 | ) 265 | self.num_layers = 3 266 | 267 | def set_layers(self, num_layers): 268 | self.num_layers = num_layers 269 | 270 | def forward(self, x): 271 | if self.num_layers == 3: 272 | x = self.layer1(x) 273 | x = self.layer2(x) 274 | x = self.layer3(x) 275 | elif self.num_layers == 2: 276 | x = self.layer1(x) 277 | x = self.layer3(x) 278 | else: 279 | raise Exception 280 | return x 281 | 282 | 283 | class prediction_MLP(nn.Module): 284 | def __init__( 285 | self, in_dim=2048, hidden_dim=512, out_dim=2048 286 | ): # bottleneck structure 287 | super().__init__() 288 | """ page 3 baseline setting 289 | Prediction MLP. The prediction MLP (h) has BN applied 290 | to its hidden fc layers. Its output fc does not have BN 291 | (ablation in Sec. 4.4) or ReLU. This MLP has 2 layers. 292 | The dimension of h’s input and output (z and p) is d = 2048, 293 | and h’s hidden layer’s dimension is 512, making h a 294 | bottleneck structure (ablation in supplement). 295 | """ 296 | self.layer1 = nn.Sequential( 297 | nn.Linear(in_dim, hidden_dim), 298 | nn.BatchNorm1d(hidden_dim), 299 | nn.ReLU(inplace=True), 300 | ) 301 | self.layer2 = nn.Linear(hidden_dim, out_dim) 302 | """ 303 | Adding BN to the output of the prediction MLP h does not work 304 | well (Table 3d). We find that this is not about collapsing. 305 | The training is unstable and the loss oscillates. 
306 | """ 307 | 308 | def forward(self, x): 309 | x = self.layer1(x) 310 | x = self.layer2(x) 311 | return x 312 | 313 | 314 | class ACLIP(CLIP): 315 | def __init__( 316 | self, 317 | **kwargs, 318 | ): 319 | super().__init__(**kwargs) 320 | self.visual_ema = kwargs['vision_model_ema'] 321 | vision_width = kwargs["vision_width"] 322 | embed_dim = kwargs['embed_dim'] 323 | 324 | self.image_mlp = self._build_mlp(vision_width, 4096, 256) #for simclr 325 | self.im_proj_byol = self._build_mlp_byol(2, vision_width, 4096, 256, False) 326 | self.im_pred_byol = self._build_mlp_byol(2, 256, 4096, 256, False) 327 | 328 | transformer_width = kwargs['transformer_width'] 329 | transformer_layers = kwargs['transformer_layers'] 330 | transformer_heads = kwargs['transformer_heads'] 331 | # text ema 332 | self.transformer_e = Transformer( 333 | width=transformer_width, 334 | layers=transformer_layers, 335 | heads=transformer_heads, 336 | attn_mask=self.build_attention_mask(), 337 | ) 338 | self.image_projection_e = nn.Parameter(torch.empty(vision_width, embed_dim)) 339 | self.text_projection_e = nn.Parameter(torch.empty(transformer_width, embed_dim)) 340 | 341 | self.im_proj_byol_e = self._build_mlp_byol(2, vision_width, 4096, 256, False) 342 | 343 | for param_m, param_b in zip(self.visual_ema.parameters(), self.visual.parameters()): 344 | param_m.data.copy_(param_b.data) # initialize 345 | param_m.requires_grad = False # not update by gradient 346 | for param_m, param_b in zip(self.transformer_e.parameters(), self.transformer.parameters()): 347 | param_m.data.copy_(param_b.data) # initialize 348 | param_m.requires_grad = False # not update by gradient 349 | for param_m, param_b in zip(self.im_proj_byol_e.parameters(), self.im_proj_byol.parameters()): 350 | param_m.data.copy_(param_b.data) # initialize 351 | param_m.requires_grad = False # not update by gradient 352 | 353 | self.image_projection_e.requires_grad = False 354 | self.image_projection_e.data.copy_(self.image_projection.data) 355 | self.text_projection_e.requires_grad = False 356 | self.text_projection_e.data.copy_(self.text_projection.data) 357 | 358 | 359 | @torch.no_grad() 360 | def _update_momentum_encoder(self, m): 361 | """Momentum update of the momentum encoder""" 362 | for param_b, param_m in zip( 363 | self.visual.parameters(), self.visual_ema.parameters() 364 | ): 365 | param_m.data = param_m.data * m + param_b.data * (1.0 - m) 366 | 367 | for param_b, param_m in zip( 368 | self.transformer.parameters(), self.transformer_e.parameters() 369 | ): 370 | param_m.data = param_m.data * m + param_b.data * (1.0 - m) 371 | for param_b, param_m in zip( 372 | self.im_proj_byol.parameters(), self.im_proj_byol_e.parameters() 373 | ): 374 | param_m.data = param_m.data * m + param_b.data * (1.0 - m) 375 | 376 | self.image_projection_e.data = self.image_projection_e.data * m + self.image_projection * (1.0 - m) 377 | self.text_projection_e.data = self.text_projection_e.data * m + self.text_projection * (1.0 - m) 378 | 379 | def _build_mlp_byol(self, num_layers, input_dim, mlp_dim, output_dim, last_bn=True): 380 | mlp = [] 381 | for l in range(num_layers): 382 | dim1 = input_dim if l == 0 else mlp_dim 383 | dim2 = output_dim if l == num_layers - 1 else mlp_dim 384 | 385 | mlp.append(nn.Linear(dim1, dim2, bias=False)) 386 | 387 | if l < num_layers - 1: 388 | mlp.append(nn.BatchNorm1d(dim2)) 389 | mlp.append(nn.ReLU(inplace=True)) 390 | elif last_bn: 391 | # follow SimCLR's design: https://github.com/google-research/simclr/blob/master/model_util.py#L157 
392 | # for simplicity, we further removed gamma in BN 393 | mlp.append(nn.BatchNorm1d(dim2, affine=False)) 394 | 395 | return nn.Sequential(*mlp) 396 | 397 | # from simclr 398 | def _build_mlp(self, in_dim, mlp_dim, out_dim): 399 | return nn.Sequential( 400 | OrderedDict( 401 | [ 402 | ("layer1", nn.Linear(in_dim, mlp_dim)), 403 | ("bn1", nn.BatchNorm1d(mlp_dim)), 404 | ("relu1", nn.ReLU(inplace=True)), 405 | ("layer2", nn.Linear(mlp_dim, mlp_dim)), 406 | ("bn2", nn.BatchNorm1d(mlp_dim)), 407 | ("relu2", nn.ReLU(inplace=True)), 408 | ("layer3", nn.Linear(mlp_dim, out_dim)), 409 | ] 410 | ) 411 | ) 412 | 413 | def encode_image(self, image, mask=None, ret=False, ema=False): 414 | if ema == False: 415 | x, attn, _ = self.visual(image, mask=mask, need_attn=False) 416 | tokens = x 417 | x = x[:, 0] @ self.image_projection 418 | else: 419 | x, attn, _ = self.visual_ema(image, mask=mask, need_attn=True) 420 | tokens = x 421 | x = x[:, 0] @ self.image_projection_e 422 | 423 | if ret: 424 | return x, attn, tokens 425 | return x 426 | 427 | def encode_text(self, text, ema=False): 428 | if ema: 429 | text_projection = self.text_projection_e 430 | transformer = self.transformer_e 431 | else: 432 | text_projection = self.text_projection 433 | transformer = self.transformer 434 | 435 | x = self.token_embedding(text) # [batch_size, n_ctx, d_model] 436 | x = x + self.positional_embedding 437 | x = x.permute(1, 0, 2) # NLD -> LND 438 | x = transformer(x) 439 | x = x.permute(1, 0, 2) # LND -> NLD 440 | x = self.ln_final(x) 441 | 442 | # x.shape = [batch_size, n_ctx, transformer.width] 443 | # take features from the eot embedding (eot_token is the highest number in each sequence) 444 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ text_projection 445 | 446 | return x 447 | 448 | def get_mask(self, mask, positions, e_positions): 449 | # top, left, width, height = pos 450 | 451 | mask = mask.reshape((-1, 14, 14)) 452 | cmask = [] 453 | 454 | for i in range(mask.shape[0]): 455 | m = mask[i] 456 | m = m.unsqueeze(0) 457 | m = m.unsqueeze(0) 458 | o_pos = positions[i] 459 | e_pos = e_positions[i] 460 | m = torch.nn.functional.interpolate(m, (e_pos[2], e_pos[3]), mode='bilinear') 461 | 462 | top = o_pos[0] - e_pos[0] 463 | left = o_pos[1] - e_pos[1] 464 | m = F_t.crop(m, top, left, o_pos[2], o_pos[3]) 465 | m = torch.nn.functional.interpolate(m, (14, 14), mode='bilinear') 466 | cmask.append(m) 467 | 468 | cmask = torch.stack(cmask).squeeze() 469 | cmask = cmask.reshape(mask.shape[0], -1) 470 | return cmask 471 | 472 | def forward(self, im1, im2, text, pos, momentum): 473 | with torch.no_grad(): 474 | self._update_momentum_encoder(momentum) 475 | x, attn, _ = self.visual_ema(im2, need_attn=True) 476 | cls_token = x[:, 0] 477 | byol_feats_e = self.im_proj_byol_e(cls_token) 478 | 479 | attention_map = attn 480 | attention_map_1 = self.get_mask(attention_map,pos[:,0],pos[:,2]) 481 | mask_1 = get_att_mask(attention_map_1) 482 | attention_map_2 = self.get_mask(attention_map,pos[:,1],pos[:,2]) 483 | mask_2 = get_att_mask(attention_map_2) 484 | mask = torch.cat([mask_1,mask_2],dim=0) 485 | 486 | image_embed, _, tokens = self.encode_image(im1, mask=mask, ret=True) 487 | bs = text.shape[0] 488 | cls_token = tokens[:, 0] 489 | image_ssl_embed = self.image_mlp(cls_token) 490 | byol_feats = self.im_proj_byol(cls_token) 491 | byol_feats = self.im_pred_byol(byol_feats) 492 | text_embed = self.encode_text(text) 493 | 494 | return { 495 | "image_embed": image_embed, 496 | "text_embed": text_embed, 497 | 
"image_ssl_embed": image_ssl_embed, 498 | "byol_feats": byol_feats, 499 | "byol_feats_e": byol_feats_e, 500 | "logit_scale": self.logit_scale.exp(), 501 | } 502 | 503 | 504 | 505 | def forward_attn(self, x): 506 | B, N, C = x.shape 507 | qkv = ( 508 | self.qkv(x) 509 | .reshape(B, N, 3, self.num_heads, C // self.num_heads) 510 | .permute(2, 0, 3, 1, 4) 511 | ) 512 | q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) 513 | 514 | attn = (q @ k.transpose(-2, -1)) * self.scale 515 | attn = attn.softmax(dim=-1) 516 | attn = self.attn_drop(attn) 517 | 518 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 519 | x = self.proj(x) 520 | x = self.proj_drop(x) 521 | return x, attn.detach() 522 | 523 | 524 | def forward_block(self, x): 525 | attn_x, attn = forward_attn(self.attn, self.norm1(x)) 526 | x = x + self.drop_path(attn_x) 527 | x = x + self.drop_path(self.mlp(self.norm2(x))) 528 | 529 | return x, attn 530 | 531 | class MaskVisionTransformer(VisionTransformer): 532 | def __init__(self, mask_ratio=0, **kwargs): 533 | super(MaskVisionTransformer, self).__init__(**kwargs) 534 | self.mask_ratio = mask_ratio 535 | for param in self.patch_embed.proj.parameters(): 536 | param.requires_grad = False 537 | 538 | def random_masking(self, x, mask_ratio): 539 | """ 540 | Perform per-sample random masking by per-sample shuffling. 541 | Per-sample shuffling is done by argsort random noise. 542 | x: [N, L, D], sequence 543 | """ 544 | N, L, D = x.shape # batch, length, dim 545 | len_keep = int(L * (1 - mask_ratio)) 546 | 547 | noise = torch.rand(N, L, device=x.device) # noise in [0, 1] 548 | 549 | # sort noise for each sample 550 | ids_shuffle = torch.argsort( 551 | noise, dim=1 552 | ) # ascend: small is keep, large is remove 553 | ids_restore = torch.argsort(ids_shuffle, dim=1) 554 | 555 | # keep the first subset 556 | ids_keep = ids_shuffle[:, :len_keep] 557 | x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) 558 | 559 | # generate the binary mask: 0 is keep, 1 is remove 560 | mask = torch.ones([N, L], device=x.device) 561 | mask[:, :len_keep] = 0 562 | # unshuffle to get the binary mask 563 | mask = torch.gather(mask, dim=1, index=ids_restore) 564 | return x_masked, mask, ids_restore 565 | 566 | def mask_model(self, x, mask): 567 | N, L, D = x.shape # batch, length, dim 568 | ids = torch.argsort(mask.long(), dim=1) # ascend 569 | mask_len = mask[0].sum() 570 | ids_keep = ids[:, : L - mask_len] 571 | x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) 572 | return x_masked 573 | 574 | def forward_features(self, x, mask=None, need_attn=False): 575 | x = self.patch_embed(x) 576 | 577 | # add pos embed w/o cls token 578 | x = x + self.pos_embed[:, 1:, :] 579 | ids_restore = None 580 | if self.mask_ratio > 0 and self.training is True: 581 | if mask is None: 582 | x, mask, ids_restore = self.random_masking(x, self.mask_ratio) 583 | else: 584 | x = self.mask_model(x, mask) 585 | 586 | # x, mask , ids_restore = self.random_masking(x, self.mask_ratio) 587 | # add pos embed and cls token 588 | cls_token = self.cls_token + self.pos_embed[:, :1, :] 589 | cls_token = cls_token.expand( 590 | x.shape[0], -1, -1 591 | ) # stole cls_tokens impl from Phil Wang, thanks 592 | 593 | x = torch.cat((cls_token, x), dim=1) 594 | x = self.pos_drop(x) 595 | attn_list = [] 596 | if need_attn: 597 | for b in self.blocks: 598 | x, attn_now = forward_block(b, x) 599 | attn_list.append(attn_now) 600 | attn = torch.stack(attn_list, dim=0) 601 | attn = 
torch.mean(attn, dim=0) 602 | attn = attn[:, :, 0, 1:].mean(1).detach().clone() 603 | x = self.norm(x) 604 | return x, attn, ids_restore 605 | else: 606 | x = self.blocks(x) 607 | x = self.norm(x) 608 | attn = None 609 | return x, attn, ids_restore 610 | 611 | def forward(self, x, mask=None, need_attn=False): 612 | x = self.forward_features(x, mask=mask, need_attn=need_attn) 613 | if self.head_dist is not None: 614 | x, x_dist = self.head(x[0]), self.head_dist(x[1]) # x must be a tuple 615 | if self.training and not torch.jit.is_scripting(): 616 | # during inference, return the average of both classifier predictions 617 | return x, x_dist 618 | else: 619 | return (x + x_dist) / 2 620 | else: 621 | x = self.head(x) 622 | return x 623 | 624 | @register_model 625 | def mask_vit_small_patch16_224(pretrained=False, **kwargs): 626 | """ViT-Small (ViT-S/16) from original paper (https://arxiv.org/abs/2010.11929). 627 | ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. 628 | """ 629 | model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=12, **kwargs) 630 | model = _create_vision_transformer( 631 | "vit_small_patch16_224", 632 | MaskVisionTransformer, 633 | pretrained=pretrained, 634 | **model_kwargs, 635 | ) 636 | return model 637 | 638 | @register_model 639 | def mask_vit_base_patch16_224(pretrained=False, **kwargs): 640 | """ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). 641 | ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. 642 | """ 643 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) 644 | model = _create_vision_transformer( 645 | "vit_base_patch16_224", 646 | MaskVisionTransformer, 647 | pretrained=pretrained, 648 | **model_kwargs, 649 | ) 650 | return model 651 | 652 | @register_model 653 | def mask_vit_large_patch16_224(pretrained=False, **kwargs): 654 | """ViT-Large (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929). 655 | ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. 656 | """ 657 | model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs) 658 | model = _create_vision_transformer( 659 | "vit_large_patch16_224", 660 | MaskVisionTransformer, 661 | pretrained=pretrained, 662 | **model_kwargs, 663 | ) 664 | return model 665 | 666 | 667 | @register_model 668 | def mask_vit_base_patch32_224(pretrained=False, **kwargs): 669 | """ViT-Base (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). 670 | ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. 671 | """ 672 | model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs) 673 | model = _create_vision_transformer( 674 | "vit_base_patch32_224", 675 | MaskVisionTransformer, 676 | pretrained=pretrained, 677 | **model_kwargs, 678 | ) 679 | return model 680 | 681 | def _create_vision_transformer( 682 | variant, 683 | transformer=MaskVisionTransformer, 684 | pretrained=False, 685 | default_cfg=None, 686 | **kwargs, 687 | ): 688 | default_cfg = default_cfg or default_cfgs[variant] 689 | if kwargs.get("features_only", None): 690 | raise RuntimeError( 691 | "features_only not implemented for Vision Transformer models."
692 | ) 693 | 694 | # NOTE this extra code to support handling of repr size for in21k pretrained models 695 | default_num_classes = default_cfg["num_classes"] 696 | num_classes = kwargs.get("num_classes", default_num_classes) 697 | repr_size = kwargs.pop("representation_size", None) 698 | if repr_size is not None and num_classes != default_num_classes: 699 | # Remove representation layer if fine-tuning. This may not always be the desired action, 700 | # but I feel better than doing nothing by default for fine-tuning. Perhaps a better interface? 701 | print("Removing representation layer for fine-tuning.") 702 | repr_size = None 703 | 704 | model = build_model_with_cfg( 705 | transformer, 706 | variant, 707 | pretrained, 708 | default_cfg=default_cfg, 709 | representation_size=repr_size, 710 | pretrained_filter_fn=checkpoint_filter_fn, 711 | pretrained_custom_load="npz" in default_cfg["url"], 712 | **kwargs, 713 | ) 714 | return model 715 | 716 | 717 | def ACLIP_VITB16(mask_ratio=0, **kwargs): 718 | vision_model = timm.create_model( 719 | "mask_vit_base_patch16_224", num_classes=0, mask_ratio=mask_ratio 720 | ) 721 | vision_model_ema = timm.create_model( 722 | 'mask_vit_base_patch16_224', num_classes=0, mask_ratio=0 723 | ) 724 | model = ACLIP( 725 | embed_dim=512, 726 | vision_width=768, 727 | vision_model=vision_model, 728 | vision_model_ema=vision_model_ema, 729 | context_length=77, 730 | vocab_size=49408, 731 | transformer_width=512, 732 | transformer_heads=8, 733 | transformer_layers=12, 734 | **kwargs, 735 | ) 736 | return model 737 | 738 | 739 | def ACLIP_VITL16(mask_ratio=0, **kwargs): 740 | vision_model = timm.create_model( 741 | "mask_vit_large_patch16_224", num_classes=0, mask_ratio=mask_ratio 742 | ) 743 | vision_model_ema = timm.create_model( 744 | 'mask_vit_large_patch16_224', num_classes=0, mask_ratio=0 745 | ) 746 | model = ACLIP( 747 | embed_dim=512, 748 | vision_width=1024, 749 | vision_model=vision_model, 750 | vision_model_ema=vision_model_ema, 751 | context_length=77, 752 | vocab_size=49408, 753 | transformer_width=512, 754 | transformer_heads=8, 755 | transformer_layers=12, 756 | **kwargs, 757 | ) 758 | return model 759 | 760 | # SMALL 761 | def ACLIP_VITS16(mask_ratio=0, **kwargs): 762 | vision_model = timm.create_model( 763 | "mask_vit_small_patch16_224", num_classes=0, mask_ratio=mask_ratio 764 | ) 765 | vision_model_ema = timm.create_model( 766 | "mask_vit_small_patch16_224", num_classes=0, mask_ratio=mask_ratio 767 | ) 768 | model = ACLIP( 769 | embed_dim=512, 770 | vision_width=384, 771 | vision_model=vision_model, 772 | vision_model_ema=vision_model_ema, 773 | context_length=77, 774 | vocab_size=49408, 775 | transformer_width=512, 776 | transformer_heads=8, 777 | transformer_layers=12, 778 | **kwargs, 779 | ) 780 | return model 781 | 782 | -------------------------------------------------------------------------------- /templates.json: -------------------------------------------------------------------------------- 1 | { 2 | "food101": [ 3 | "a photo of {}, a type of food." 
4 | ], 5 | "cifar10": [ 6 | "a photo of a {}.", 7 | "a blurry photo of a {}.", 8 | "a black and white photo of a {}.", 9 | "a low contrast photo of a {}.", 10 | "a high contrast photo of a {}.", 11 | "a bad photo of a {}.", 12 | "a good photo of a {}.", 13 | "a photo of a small {}.", 14 | "a photo of a big {}.", 15 | "a photo of the {}.", 16 | "a blurry photo of the {}.", 17 | "a black and white photo of the {}.", 18 | "a low contrast photo of the {}.", 19 | "a high contrast photo of the {}.", 20 | "a bad photo of the {}.", 21 | "a good photo of the {}.", 22 | "a photo of the small {}.", 23 | "a photo of the big {}." 24 | ], 25 | "cifar100": [ 26 | "a photo of a {}.", 27 | "a blurry photo of a {}.", 28 | "a black and white photo of a {}.", 29 | "a low contrast photo of a {}.", 30 | "a high contrast photo of a {}.", 31 | "a bad photo of a {}.", 32 | "a good photo of a {}.", 33 | "a photo of a small {}.", 34 | "a photo of a big {}.", 35 | "a photo of the {}.", 36 | "a blurry photo of the {}.", 37 | "a black and white photo of the {}.", 38 | "a low contrast photo of the {}.", 39 | "a high contrast photo of the {}.", 40 | "a bad photo of the {}.", 41 | "a good photo of the {}.", 42 | "a photo of the small {}.", 43 | "a photo of the big {}." 44 | ], 45 | "birdsnap": [ 46 | "a photo of a {}, a type of bird." 47 | ], 48 | "cub200": [ 49 | "a photo of a {}, a type of bird." 50 | ], 51 | "imagenet": [ 52 | "itap of a {}.", 53 | "a bad photo of the {}.", 54 | "a origami {}.", 55 | "a photo of the large {}.", 56 | "a {} in a video game.", 57 | "art of the {}.", 58 | "a photo of the small {}." 59 | ], 60 | "rendered_sst2": [ 61 | "a {} review of a movie." 62 | ], 63 | "hateful_memes": [ 64 | "a {}." 65 | ], 66 | "clevr_counts": [ 67 | "a photo of {} objects." 68 | ], 69 | "kinetics700_frames": [ 70 | "a photo of {}.", 71 | "a photo of a person {}.", 72 | "a photo of a person using {}.", 73 | "a photo of a person doing {}.", 74 | "a photo of a person during {}.", 75 | "a photo of a person performing {}.", 76 | "a photo of a person practicing {}.", 77 | "a video of {}.", 78 | "a video of a person {}.", 79 | "a video of a person using {}.", 80 | "a video of a person doing {}.", 81 | "a video of a person during {}.", 82 | "a video of a person performing {}.", 83 | "a video of a person practicing {}.", 84 | "a example of {}.", 85 | "a example of a person {}.", 86 | "a example of a person using {}.", 87 | "a example of a person doing {}.", 88 | "a example of a person during {}.", 89 | "a example of a person performing {}.", 90 | "a example of a person practicing {}.", 91 | "a demonstration of {}.", 92 | "a demonstration of a person {}.", 93 | "a demonstration of a person using {}.", 94 | "a demonstration of a person doing {}.", 95 | "a demonstration of a person during {}.", 96 | "a demonstration of a person performing {}.", 97 | "a demonstration of a person practicing {}." 
98 | ], 99 | "ucf101_frames": [ 100 | "a photo of a person {}.", 101 | "a video of a person {}.", 102 | "a example of a person {}.", 103 | "a demonstration of a person {}.", 104 | "a photo of the person {}.", 105 | "a video of the person {}.", 106 | "a example of the person {}.", 107 | "a demonstration of the person {}.", 108 | "a photo of a person using {}.", 109 | "a video of a person using {}.", 110 | "a example of a person using {}.", 111 | "a demonstration of a person using {}.", 112 | "a photo of the person using {}.", 113 | "a video of the person using {}.", 114 | "a example of the person using {}.", 115 | "a demonstration of the person using {}.", 116 | "a photo of a person doing {}.", 117 | "a video of a person doing {}.", 118 | "a example of a person doing {}.", 119 | "a demonstration of a person doing {}.", 120 | "a photo of the person doing {}.", 121 | "a video of the person doing {}.", 122 | "a example of the person doing {}.", 123 | "a demonstration of the person doing {}.", 124 | "a photo of a person during {}.", 125 | "a video of a person during {}.", 126 | "a example of a person during {}.", 127 | "a demonstration of a person during {}.", 128 | "a photo of the person during {}.", 129 | "a video of the person during {}.", 130 | "a example of the person during {}.", 131 | "a demonstration of the person during {}.", 132 | "a photo of a person performing {}.", 133 | "a video of a person performing {}.", 134 | "a example of a person performing {}.", 135 | "a demonstration of a person performing {}.", 136 | "a photo of the person performing {}.", 137 | "a video of the person performing {}.", 138 | "a example of the person performing {}.", 139 | "a demonstration of the person performing {}.", 140 | "a photo of a person practicing {}.", 141 | "a video of a person practicing {}.", 142 | "a example of a person practicing {}.", 143 | "a demonstration of a person practicing {}.", 144 | "a photo of the person practicing {}.", 145 | "a video of the person practicing {}.", 146 | "a example of the person practicing {}.", 147 | "a demonstration of the person practicing {}." 148 | ], 149 | "patch_camelyon": [ 150 | "this is a photo of {}" 151 | ], 152 | "country211": [ 153 | "a photo i took in {}.", 154 | "a photo i took while visiting {}.", 155 | "a photo from my home country of {}.", 156 | "a photo from my visit to {}.", 157 | "a photo showing the country of {}." 158 | ], 159 | "kitti_distance": [ 160 | "{}" 161 | ], 162 | "gtsrb": [ 163 | "a zoomed in photo of a \"{}\" traffic sign.", 164 | "a centered photo of a \"{}\" traffic sign.", 165 | "a close up photo of a \"{}\" traffic sign." 166 | ], 167 | "resisc45": [ 168 | "satellite imagery of {}.", 169 | "aerial imagery of {}.", 170 | "satellite photo of {}.", 171 | "aerial photo of {}.", 172 | "satellite view of {}.", 173 | "aerial view of {}.", 174 | "satellite imagery of a {}.", 175 | "aerial imagery of a {}.", 176 | "satellite photo of a {}.", 177 | "aerial photo of a {}.", 178 | "satellite view of a {}.", 179 | "aerial view of a {}.", 180 | "satellite imagery of the {}.", 181 | "aerial imagery of the {}.", 182 | "satellite photo of the {}.", 183 | "aerial photo of the {}.", 184 | "satellite view of the {}.", 185 | "aerial view of the {}." 186 | ], 187 | "eurosat": [ 188 | "a centered satellite photo of {}.", 189 | "a centered satellite photo of a {}.", 190 | "a centered satellite photo of the {}." 191 | ], 192 | "stl10": [ 193 | "a photo of a {}.", 194 | "a photo of the {}." 
195 | ], 196 | "fer2013": [ 197 | "a photo of a {} looking face.", 198 | "a photo of a face showing the emotion: {}.", 199 | "a photo of a face looking {}.", 200 | "a face that looks {}.", 201 | "they look {}.", 202 | "look at how {} they are." 203 | ], 204 | "mnist": [ 205 | "a photo of the number: \"{}\"." 206 | ], 207 | "flowers": [ 208 | "a photo of a {}, a type of flower." 209 | ], 210 | "caltech101": [ 211 | "a photo of a {}.", 212 | "a painting of a {}.", 213 | "a plastic {}.", 214 | "a sculpture of a {}.", 215 | "a sketch of a {}.", 216 | "a tattoo of a {}.", 217 | "a toy {}.", 218 | "a rendition of a {}.", 219 | "a embroidered {}.", 220 | "a cartoon {}.", 221 | "a {} in a video game.", 222 | "a plushie {}.", 223 | "a origami {}.", 224 | "art of a {}.", 225 | "graffiti of a {}.", 226 | "a drawing of a {}.", 227 | "a doodle of a {}.", 228 | "a photo of the {}.", 229 | "a painting of the {}.", 230 | "the plastic {}.", 231 | "a sculpture of the {}.", 232 | "a sketch of the {}.", 233 | "a tattoo of the {}.", 234 | "the toy {}.", 235 | "a rendition of the {}.", 236 | "the embroidered {}.", 237 | "the cartoon {}.", 238 | "the {} in a video game.", 239 | "the plushie {}.", 240 | "the origami {}.", 241 | "art of the {}.", 242 | "graffiti of the {}.", 243 | "a drawing of the {}.", 244 | "a doodle of the {}." 245 | ], 246 | "pets": [ 247 | "a photo of a {}, a type of pet." 248 | ], 249 | "dtd": [ 250 | "a photo of a {} texture.", 251 | "a photo of a {} pattern.", 252 | "a photo of a {} thing.", 253 | "a photo of a {} object.", 254 | "a photo of the {} texture.", 255 | "a photo of the {} pattern.", 256 | "a photo of the {} thing.", 257 | "a photo of the {} object." 258 | ], 259 | "voc2007": [ 260 | "a photo of a {}." 261 | ], 262 | "aircraft": [ 263 | "a photo of a {}, a type of aircraft.", 264 | "a photo of the {}, a type of aircraft." 265 | ], 266 | "cars": [ 267 | "a photo of a {}.", 268 | "a photo of the {}.", 269 | "a photo of my {}.", 270 | "i love my {}!", 271 | "a photo of my dirty {}.", 272 | "a photo of my clean {}.", 273 | "a photo of my new {}.", 274 | "a photo of my old {}." 275 | ], 276 | "sun397": [ 277 | "a photo of a {}.", 278 | "a photo of the {}." 279 | ] 280 | } -------------------------------------------------------------------------------- /tokenizer.py: -------------------------------------------------------------------------------- 1 | # From github.com/openai/CLIP 2 | 3 | import gzip 4 | import html 5 | import os 6 | from functools import lru_cache 7 | 8 | import ftfy 9 | import regex as re 10 | import torch 11 | 12 | 13 | @lru_cache() 14 | def default_bpe(): 15 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 16 | 17 | 18 | @lru_cache() 19 | def bytes_to_unicode(): 20 | """ 21 | Returns list of utf-8 byte and a corresponding list of unicode strings. 22 | The reversible bpe codes work on unicode strings. 23 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 24 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 25 | This is a signficant percentage of your normal, say, 32K bpe vocab. 26 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 27 | And avoids mapping to whitespace/control characters the bpe code barfs on. 
28 | """ 29 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 30 | cs = bs[:] 31 | n = 0 32 | for b in range(2**8): 33 | if b not in bs: 34 | bs.append(b) 35 | cs.append(2**8+n) 36 | n += 1 37 | cs = [chr(n) for n in cs] 38 | return dict(zip(bs, cs)) 39 | 40 | 41 | def get_pairs(word): 42 | """Return set of symbol pairs in a word. 43 | Word is represented as tuple of symbols (symbols being variable-length strings). 44 | """ 45 | pairs = set() 46 | prev_char = word[0] 47 | for char in word[1:]: 48 | pairs.add((prev_char, char)) 49 | prev_char = char 50 | return pairs 51 | 52 | 53 | def basic_clean(text): 54 | text = ftfy.fix_text(text) 55 | text = html.unescape(html.unescape(text)) 56 | return text.strip() 57 | 58 | 59 | def whitespace_clean(text): 60 | text = re.sub(r'\s+', ' ', text) 61 | text = text.strip() 62 | return text 63 | 64 | 65 | class SimpleTokenizer(object): 66 | def __init__(self, bpe_path: str = default_bpe()): 67 | self.byte_encoder = bytes_to_unicode() 68 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 69 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 70 | merges = merges[1:49152-256-2+1] 71 | merges = [tuple(merge.split()) for merge in merges] 72 | vocab = list(bytes_to_unicode().values()) 73 | vocab = vocab + [v+'</w>' for v in vocab]  # '</w>' marks the end of a word 74 | for merge in merges: 75 | vocab.append(''.join(merge)) 76 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 77 | self.encoder = dict(zip(vocab, range(len(vocab)))) 78 | self.decoder = {v: k for k, v in self.encoder.items()} 79 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 80 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 81 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 82 | 83 | def bpe(self, token): 84 | if token in self.cache: 85 | return self.cache[token] 86 | word = tuple(token[:-1]) + (token[-1] + '</w>',) 87 | pairs = get_pairs(word) 88 | 89 | if not pairs: 90 | return token+'</w>' 91 | 92 | while True: 93 | bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf'))) 94 | if bigram not in self.bpe_ranks: 95 | break 96 | first, second = bigram 97 | new_word = [] 98 | i = 0 99 | while i < len(word): 100 | try: 101 | j = word.index(first, i) 102 | new_word.extend(word[i:j]) 103 | i = j 104 | except ValueError:  # `first` does not occur in the rest of the word 105 | new_word.extend(word[i:]) 106 | break 107 | 108 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 109 | new_word.append(first+second) 110 | i += 2 111 | else: 112 | new_word.append(word[i]) 113 | i += 1 114 | new_word = tuple(new_word) 115 | word = new_word 116 | if len(word) == 1: 117 | break 118 | else: 119 | pairs = get_pairs(word) 120 | word = ' '.join(word) 121 | self.cache[token] = word 122 | return word 123 | 124 | def encode(self, text): 125 | bpe_tokens = [] 126 | text = whitespace_clean(basic_clean(text)).lower() 127 | for token in re.findall(self.pat, text): 128 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 129 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 130 | return bpe_tokens 131 | 132 | def decode(self, tokens): 133 | text = ''.join([self.decoder[token] for token in tokens]) 134 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ') 135 | return text 136 | 137 | def __call__(self, texts, context_length=77): 138 | if
isinstance(texts, str): 139 | texts = [texts] 140 | 141 | sot_token = self.encoder["<|startoftext|>"] 142 | eot_token = self.encoder["<|endoftext|>"] 143 | all_tokens = [[sot_token] + self.encode(text) + [eot_token] for text in texts] 144 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 145 | 146 | for i, tokens in enumerate(all_tokens): 147 | if len(tokens) > context_length: 148 | print(f"Input {texts[i]} is too long for context length {context_length}") 149 | tokens = tokens[:context_length] 150 | result[i, :len(tokens)] = torch.tensor(tokens) 151 | 152 | if len(result) == 1: 153 | return result[0] 154 | 155 | return result -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Based on BEiT v2, SLIP code bases 3 | # https://github.com/microsoft/unilm/tree/master/beit2 4 | # https://github.com/facebookresearch/SLIP 5 | # --------------------------------------------------------' 6 | import numpy as np 7 | import os 8 | import random 9 | import shutil 10 | import torch 11 | import torch.distributed as dist 12 | import torch.autograd as autograd 13 | 14 | from PIL import ImageFilter 15 | from datetime import datetime 16 | 17 | def unwrap_model(model): 18 | if hasattr(model, 'module'): 19 | return model.module 20 | else: 21 | return model 22 | 23 | def get_model(model): 24 | if isinstance(model, torch.nn.DataParallel) \ 25 | or isinstance(model, torch.nn.parallel.DistributedDataParallel): 26 | return model.module 27 | else: 28 | return model 29 | 30 | class AverageMeter(object): 31 | """Computes and stores the average and current value""" 32 | def __init__(self, name, fmt=':f'): 33 | self.name = name 34 | self.fmt = fmt 35 | self.reset() 36 | 37 | def reset(self): 38 | self.val = 0 39 | self.avg = 0 40 | self.sum = 0 41 | self.count = 0 42 | 43 | def update(self, val, n=1): 44 | self.val = val 45 | self.sum += val * n 46 | self.count += n 47 | self.avg = self.sum / self.count 48 | 49 | def synchronize(self): 50 | if not is_dist_avail_and_initialized(): 51 | return 52 | t = torch.tensor([self.sum, self.count], dtype=torch.float64, device='cuda') 53 | dist.barrier() 54 | dist.all_reduce(t) 55 | t = t.tolist() 56 | self.sum = int(t[0]) 57 | self.count = t[1] 58 | self.avg = self.sum / self.count 59 | 60 | def __str__(self): 61 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 62 | return fmtstr.format(**self.__dict__) 63 | 64 | 65 | class ProgressMeter(object): 66 | def __init__(self, num_batches, meters, prefix=""): 67 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 68 | self.meters = meters 69 | self.prefix = prefix 70 | 71 | def display(self, batch): 72 | nowstr = "["+ datetime.now().strftime("%H:%M:%S")+ "] " 73 | entries = [nowstr + self.prefix + self.batch_fmtstr.format(batch)] 74 | entries += [str(meter) for meter in self.meters] 75 | 76 | print('\t'.join(entries)) 77 | 78 | def synchronize(self): 79 | for meter in self.meters: 80 | meter.synchronize() 81 | 82 | def _get_batch_fmtstr(self, num_batches): 83 | num_digits = len(str(num_batches // 1)) 84 | fmt = '{:' + str(num_digits) + 'd}' 85 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 86 | 87 | 88 | def accuracy(output, target, topk=(1,)): 89 | """Computes the accuracy over the k top predictions for the specified values of k""" 90 | with torch.no_grad(): 91 | maxk = max(topk) 92 | batch_size 
= target.size(0) 93 | 94 | _, pred = output.topk(maxk, 1, True, True) 95 | pred = pred.t() 96 | correct = pred.eq(target.reshape(1, -1).expand_as(pred)) 97 | 98 | res = [] 99 | for k in topk: 100 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 101 | res.append(correct_k.mul_(100.0 / batch_size)) 102 | return res 103 | 104 | 105 | def setup_for_distributed(is_master): 106 | """ 107 | This function disables printing when not in the master process 108 | """ 109 | import builtins as __builtin__ 110 | builtin_print = __builtin__.print 111 | 112 | def print(*args, **kwargs): 113 | force = kwargs.pop('force', False) 114 | if is_master or force: 115 | builtin_print(*args, **kwargs) 116 | 117 | __builtin__.print = print 118 | 119 | 120 | def is_dist_avail_and_initialized(): 121 | if not dist.is_available(): 122 | return False 123 | if not dist.is_initialized(): 124 | return False 125 | return True 126 | 127 | 128 | def get_world_size(): 129 | if not is_dist_avail_and_initialized(): 130 | return 1 131 | return dist.get_world_size() 132 | 133 | 134 | def get_rank(): 135 | if not is_dist_avail_and_initialized(): 136 | return 0 137 | return dist.get_rank() 138 | 139 | 140 | def is_main_process(): 141 | return get_rank() == 0 142 | 143 | 144 | def save_on_master(state, is_best, output_dir, epoch=None, toal_epoch=25): 145 | if is_main_process(): 146 | ckpt_path = f'{output_dir}/checkpoint.pt' 147 | best_path = f'{output_dir}/checkpoint_best.pt' 148 | torch.save(state, ckpt_path) 149 | if epoch is not None: 150 | save_freq = toal_epoch * 0.2  # keep a snapshot every 20% of training 151 | if epoch % save_freq == 0: 152 | each_ckpt_path = f'{output_dir}/checkpoint_{epoch}.pt' 153 | shutil.copyfile(ckpt_path, each_ckpt_path) 154 | if is_best: 155 | shutil.copyfile(ckpt_path, best_path) 156 | 157 | 158 | def init_distributed_mode(args): 159 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 160 | args.rank = int(os.environ["RANK"]) 161 | args.world_size = int(os.environ['WORLD_SIZE']) 162 | args.gpu = int(os.environ['LOCAL_RANK']) 163 | elif 'SLURM_PROCID' in os.environ: 164 | args.rank = int(os.environ['SLURM_PROCID']) 165 | args.gpu = args.rank % torch.cuda.device_count() 166 | else: 167 | print('Not using distributed mode') 168 | args.distributed = False 169 | return 170 | 171 | args.distributed = True 172 | 173 | torch.cuda.set_device(args.gpu) 174 | args.dist_backend = 'nccl' 175 | print('| distributed init (rank {}): {}'.format( 176 | args.rank, args.dist_url), flush=True) 177 | from datetime import timedelta  # the module-level import only brings in the datetime class, not timedelta 178 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank, timeout=timedelta(minutes=90)) 179 | torch.distributed.barrier() 180 | setup_for_distributed(args.rank == 0) 181 | 182 | 183 | def scaled_all_reduce(tensors, is_scale=True): 184 | """Performs the scaled all_reduce operation on the provided tensors. 185 | The input tensors are modified in-place. Currently supports only the sum 186 | reduction operator. The reduced values are scaled by the inverse of the 187 | world size.
188 | """ 189 | world_size = get_world_size() 190 | # There is no need for reduction in the single-proc case 191 | if world_size == 1: 192 | return tensors 193 | # Queue the reductions 194 | reductions = [] 195 | for tensor in tensors: 196 | reduction = dist.all_reduce(tensor, async_op=True) 197 | reductions.append(reduction) 198 | # Wait for reductions to finish 199 | for reduction in reductions: 200 | reduction.wait() 201 | # Scale the results 202 | if is_scale: 203 | for tensor in tensors: 204 | tensor.mul_(1.0 / world_size) 205 | return tensors 206 | 207 | 208 | class GatherLayer(autograd.Function): 209 | """ 210 | Gather tensors from all workers with support for backward propagation: 211 | This implementation does not cut the gradients as torch.distributed.all_gather does. 212 | """ 213 | 214 | @staticmethod 215 | def forward(ctx, x): 216 | output = [torch.zeros_like(x) for _ in range(dist.get_world_size())] 217 | dist.all_gather(output, x) 218 | return tuple(output) 219 | 220 | @staticmethod 221 | def backward(ctx, *grads): 222 | all_gradients = torch.stack(grads) 223 | dist.all_reduce(all_gradients) 224 | return all_gradients[dist.get_rank()] 225 | 226 | 227 | def all_gather_batch_with_grad(tensors): 228 | """ 229 | Performs all_gather operation on the provided tensors. 230 | Graph remains connected for backward grad computation. 231 | """ 232 | # Queue the gathered tensors 233 | world_size = get_world_size() 234 | # There is no need for reduction in the single-proc case 235 | if world_size == 1: 236 | return tensors 237 | tensor_list = [] 238 | output_tensor = [] 239 | 240 | for tensor in tensors: 241 | tensor_all = GatherLayer.apply(tensor) 242 | tensor_list.append(tensor_all) 243 | 244 | for tensor_all in tensor_list: 245 | output_tensor.append(torch.cat(tensor_all, dim=0)) 246 | return output_tensor 247 | 248 | 249 | def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0): 250 | warmup_schedule = np.array([]) 251 | warmup_iters = warmup_epochs * niter_per_ep 252 | if warmup_epochs > 0: 253 | warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters) 254 | 255 | iters = np.arange(epochs * niter_per_ep - warmup_iters) 256 | schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters))) 257 | 258 | schedule = np.concatenate((warmup_schedule, schedule)) 259 | assert len(schedule) == epochs * niter_per_ep 260 | return schedule 261 | 262 | 263 | class GaussianBlur(object): 264 | """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709""" 265 | 266 | def __init__(self, sigma=[.1, 2.]): 267 | self.sigma = sigma 268 | 269 | def __call__(self, x): 270 | sigma = random.uniform(self.sigma[0], self.sigma[1]) 271 | x = x.filter(ImageFilter.GaussianBlur(radius=sigma)) 272 | return x 273 | 274 | 275 | # from https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py 276 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): 277 | """ 278 | grid_size: int of the grid height and width 279 | return: 280 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 281 | """ 282 | grid_h = np.arange(grid_size, dtype=np.float32) 283 | grid_w = np.arange(grid_size, dtype=np.float32) 284 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 285 | grid = np.stack(grid, axis=0) 286 | 287 | grid = grid.reshape([2, 1, grid_size, grid_size]) 288 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 289 | 
if cls_token: 290 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 291 | return pos_embed 292 | 293 | 294 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 295 | assert embed_dim % 2 == 0 296 | 297 | # use half of dimensions to encode grid_h 298 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 299 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 300 | 301 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 302 | return emb 303 | 304 | 305 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 306 | """ 307 | embed_dim: output dimension for each position 308 | pos: a list of positions to be encoded: size (M,) 309 | out: (M, D) 310 | """ 311 | assert embed_dim % 2 == 0 312 | omega = np.arange(embed_dim // 2, dtype=np.float64)  # the np.float alias was removed in NumPy 1.24 313 | omega /= embed_dim / 2. 314 | omega = 1. / 10000**omega # (D/2,) 315 | 316 | pos = pos.reshape(-1) # (M,) 317 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 318 | 319 | emb_sin = np.sin(out) # (M, D/2) 320 | emb_cos = np.cos(out) # (M, D/2) 321 | 322 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 323 | return emb 324 | 325 | 326 | # -------------------------------------------------------- 327 | # Interpolate position embeddings for high-resolution 328 | # References: 329 | # DeiT: https://github.com/facebookresearch/deit 330 | # -------------------------------------------------------- 331 | def interpolate_pos_embed(model, checkpoint_model): 332 | if 'pos_embed' in checkpoint_model: 333 | pos_embed_checkpoint = checkpoint_model['pos_embed'] 334 | embedding_size = pos_embed_checkpoint.shape[-1] 335 | num_patches = model.patch_embed.num_patches 336 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches 337 | # height (== width) for the checkpoint position embedding 338 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 339 | # height (== width) for the new position embedding 340 | new_size = int(num_patches ** 0.5) 341 | # class_token and dist_token are kept unchanged 342 | if orig_size != new_size: 343 | print("Interpolating position embedding from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) 344 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 345 | # only the position tokens are interpolated 346 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 347 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 348 | pos_tokens = torch.nn.functional.interpolate( 349 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 350 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 351 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 352 | checkpoint_model['pos_embed'] = new_pos_embed --------------------------------------------------------------------------------
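
The `ACLIP_VIT*16` factories in models.py pair an online ViT whose `mask_ratio` controls attentive token masking with an unmasked EMA ViT, and wrap both, together with a CLIP-style text transformer configuration (context length 77, vocab size 49408, width 512, 12 layers), in the `ACLIP` module defined earlier in that file. A construction-only sketch; the 0.5 mask ratio is an illustrative value, not necessarily the training configuration used in the paper:

from models import ACLIP_VITB16

# Online branch masks patch tokens; the EMA branch always sees the full image.
model = ACLIP_VITB16(mask_ratio=0.5)
print(sum(p.numel() for p in model.parameters()) / 1e6, "M parameters")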
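
tokenizer.py is the standard CLIP BPE tokenizer: `encode` cleans and lower-cases the text, splits it with the regex pattern, byte-encodes each piece, applies the BPE merges, and maps the resulting subwords to ids; `__call__` adds the `<|startoftext|>` / `<|endoftext|>` tokens and pads everything into a fixed-length tensor. A minimal usage sketch (the captions are arbitrary examples, not taken from the repository):

from tokenizer import SimpleTokenizer

tokenizer = SimpleTokenizer()                        # loads bpe_simple_vocab_16e6.txt.gz next to tokenizer.py
captions = ["a photo of a dog.", "a photo of a cat."]
tokens = tokenizer(captions, context_length=77)      # torch.LongTensor of shape (2, 77), zero-padded
print(tokens.shape)                                  # torch.Size([2, 77])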
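
`cosine_scheduler` in utils.py precomputes one value per training iteration: a linear warm-up from `start_warmup_value` to `base_value` over `warmup_epochs`, followed by a half-cosine decay to `final_value`, and asserts that the schedule covers exactly `epochs * niter_per_ep` steps. A sketch of how such a per-iteration schedule is typically consumed; the hyper-parameters below are illustrative, not the paper's settings, and the optimizer is hypothetical:

from utils import cosine_scheduler

epochs, niter_per_ep = 25, 1000                      # illustrative values
lr_schedule = cosine_scheduler(base_value=5e-4, final_value=1e-6,
                               epochs=epochs, niter_per_ep=niter_per_ep, warmup_epochs=2)

for epoch in range(epochs):
    for step in range(niter_per_ep):
        it = epoch * niter_per_ep + step             # global iteration index
        lr = lr_schedule[it]                         # look up the precomputed learning rate
        # for group in optimizer.param_groups:       # hypothetical optimizer, shown for context only
        #     group["lr"] = lr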
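
`get_2d_sincos_pos_embed` builds the fixed 2-D sine/cosine position table used by MAE-style ViTs: half of the channels encode the row index and half the column index with the 1-D sin/cos scheme, and `cls_token=True` prepends an all-zero row for the class token. For 224x224 inputs with 16x16 patches the grid is 14x14, so the table has 196 rows (197 with the class token). A small sanity-check sketch:

import torch
from utils import get_2d_sincos_pos_embed

embed_dim, grid_size = 768, 224 // 16                # ViT-B/16 on 224x224 images -> 14x14 patch grid
pos_embed = get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=True)
print(pos_embed.shape)                               # (197, 768)

# Typically copied into a frozen position-embedding parameter of shape (1, 197, 768):
pos_embed_t = torch.from_numpy(pos_embed).float().unsqueeze(0)
print(pos_embed_t.shape)                             # torch.Size([1, 197, 768])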