├── .gitignore
├── CODE_OF_CONDUCT.md
├── LICENSE
├── README.md
├── SECURITY.md
├── SUPPORT.md
├── bpe_simple_vocab_16e6.txt.gz
├── dataset_catalog.json
├── datasets.py
├── docs
│   ├── aclip.png
│   ├── attMask.png
│   ├── masking_Page_1.png
│   └── vis.png
├── eval_zeroshot.py
├── losses.py
├── main.py
├── main_linear.py
├── make_dataset.py
├── models.py
├── templates.json
├── tokenizer.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Microsoft Open Source Code of Conduct
2 |
3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
4 |
5 | Resources:
6 |
7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns
10 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) Microsoft Corporation.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # A-CLIP
2 |
3 |
4 | ## Introduction
5 | This is the official implementation of Attentive Mask CLIP ([A-CLIP](https://arxiv.org/abs/2212.08653v2), ICCV 2023). A-CLIP improves the efficiency of CLIP training by introducing an image augmentation approach called image token removal.
6 |
7 | This work improves the training efficiency of CLIP through an image token removal strategy, which has proven effective in tasks such as masked image modeling. However, random masking in CLIP can cause semantic mismatch between the masked image and its text, which hurts performance. To address this, we propose an attentive masking strategy that removes tokens while retaining the image regions most relevant to the text.
8 |
9 |
10 |
11 | For example, for the Ferrari image in the top-left corner, it is difficult to identify the car and align it correctly with the text under random masking. With our attentive selection, only irrelevant regions are removed, preserving as much of the semantics as possible and avoiding ambiguity.
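As a rough illustration of the selection step (a minimal sketch only, not the implementation in `models.py`; the helper name and the use of head-averaged [CLS]-to-patch attention from the EMA encoder are assumptions for exposition):

```
import torch

def attentive_keep_indices(cls_attn: torch.Tensor, keep_ratio: float = 0.5) -> torch.Tensor:
    """Pick the patch tokens to keep, given [CLS]-to-patch attention scores.

    cls_attn: (batch, num_patches) attention of the [CLS] token to each patch,
              e.g. from an EMA vision encoder, averaged over heads.
    Returns:  (batch, num_keep) indices of the highest-scoring patches.
    """
    num_keep = int(cls_attn.shape[1] * keep_ratio)
    return cls_attn.topk(num_keep, dim=1).indices

# Toy example: 2 images, 196 patches (14x14 for ViT-B/16 at 224px), keep 50%.
scores = torch.rand(2, 196)
print(attentive_keep_indices(scores).shape)  # torch.Size([2, 98])
```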
12 |
13 |
14 | | Mask Strategy (View × Ratio) | IN 1K 0-shot | Flickr30K I2T | Flickr30K T2I | MS COCO I2T | MS COCO T2I |
15 | |------------------------------|--------------|---------------|---------------|-------------|-------------|
16 | | w/o mask | | | | | |
17 | | 1×100% | 37.6 | 51.4 | 32.6 | 27.9 | 17.6 |
18 | | +random mask | | | | | |
19 | | 1×50% | 35.0 | 48.8 | 32.5 | 28.9 | 16.6 |
20 | | 2×50% | 38.0 | 54.6 | 34.4 | 31.1 | 18.7 |
21 | | +attentive mask | | | | | |
22 | | 1×50% | 39.5 | 57.6 | 36.6 | 34.2 | 19.8 |
23 | | **2×50%** | **41.3** | **59.3** | **38.4** | **35.1** | **21.3** |
24 |
25 | Experiments show that attentive masking avoids the semantic mismatch caused by random masking and brings the following benefits:
26 | 1. **Training efficiency**: Eliminates redundant tokens and improves CLIP's training efficiency.
27 | 2. **Feature stability**: Retains semantically relevant areas, which stabilizes features and alleviates the ambiguity caused by semantic mismatch. Together, these properties allow A-CLIP to significantly improve both the training efficiency and the performance of CLIP.
28 |
29 | We compare our attentive mask CLIP (A-CLIP) with CLIP, SLIP, and MaskCLIP. A-CLIP outperforms CLIP by +6.3% on ImageNet-1K zero-shot classification, and by +11.3/+9.5 and +10.1/+5.6 on Flickr30K and MS COCO I2T/T2I retrieval, respectively. An efficient variant, A-CLIP-eff, outperforms CLIP by +5.3%, +11.3/+8.0, and +9.5/+4.9 on the same benchmarks while reducing training time to 0.86×.
30 |
31 | | Methods | Training Time | GPU Memory | IN 1K 0-shot | Flickr30K I2T/T2I | MS COCO I2T/T2I |
32 | |---------|---------------|------------|--------------|------------------|-----------------|
33 | | CLIP | 1.00× | 14G | 37.6 | 51.4/32.6 | 27.9/17.6 |
34 | | SLIP | 2.67× | 30G | 42.8 | 57.2/41.2 | 33.6/21.9 |
35 | | MaskCLIP| 1.56× | 16G | 42.7 | 60.0/38.8 | 34.1/21.2 |
36 | | A-CLIP | 1.16× | 14G | **43.9** | **62.7/42.1** | **38.0/23.2** |
37 | | A-CLIP-eff | 0.86× | 13G | 42.9 | **62.7**/40.6 | 37.4/22.5 |
38 |
39 | > Note: The full training wall clock time and GPU memory footprint are measured on the same device. We report the training cost relative to the original CLIP.
40 |
41 | Zero-shot evaluation on a variety of classification benchmarks. The **Epochs** column indicates the number of training epochs. A-CLIP significantly outperforms the other methods at every epoch setting, both in average accuracy and in the number of winning tracks across the downstream tasks listed below.
42 |
43 | | Epochs | Methods | Food-101 | CIFAR-10 | CIFAR-100 | CUB | SUN397 | Cars | Aircraft | DTD | Pets | Caltech-101 | Flowers | MNIST | FER-2013 | STL-10 | EuroSAT | RESISC45 | GTSRB | KITTI | Country211 | PCAM | UCF101 | Kinetics700 | CLEVR | HatefulMemes | SST2 | ImageNet | Average |
44 | |:------:|:--------:|:--------:|:--------:|:---------:|:-----:|:------:|:-----:|:--------:|:-----:|:-----:|:-----------:|:-------:|:-----:|:-------:|:------:|:-------:|:-------:|:-----:|:-----:|:---------:|:----:|:------:|:----------:|:-----:|:-----------:|:----:|:--:|:-------:|
45 | | 25 | CLIP | 50\.6 | 66\.0 | 34\.5 | 38\.8 | 51\.1 | 4\.0 | 5\.4 | 21\.2 | 28\.5 | 60\.9 | 53\.3 | 8\.4 | 17\.3 | 90\.5 | 30\.2 | 21\.5 | 6\.1 | 35\.1 | 10\.5 | 53\.5 | 28\.5 | 22\.1 | 10\.8 | 52\.4 | 50\.7 | 37\.6 | 34\.2 |
46 | | | SLIP | 59\.5 | 78\.6 | 45\.2 | 38\.7 | 53\.4 | **5\.4** | 5\.7 | **26\.1** | 31\.1 | 71\.0 | 56\.6 | 9\.8 | 19\.6 | 94\.4 | 20\.3 | 28\.9 | **14\.5** | 34\.0 | 11\.6 | **55\.4** | 37\.7 | 26\.9 | **17\.5** | **52\.8** | **51\.1** | 42\.8 | 38\.0 |
47 | | | MaskCLIP | **60\.6** | 70\.1 | 41\.6 | **43\.3** | 54\.0 | 4\.9 | **8\.2** | 25\.5 | **36\.8** | 68\.9 | 53\.6 | **11\.2** | 22\.4 | 93\.9 | **35\.1** | 24\.8 | 10\.1 | 30\.5 | 12\.5 | 51\.2 | 37\.0 | 28\.1 | 12\.9 | **52\.8** | 50\.0 | 42\.7 | 37\.8 |
48 | | | A-CLIP | 58\.3 | **82\.8** | **51\.0** | 43\.0 | **57\.0** | **5\.4** | 7\.6 | 26\.0 | 32\.0 | **71\.6** | **57\.7** | 9\.8 | **29\.7** | **95\.4** | 29\.3 | **30\.3** | 13\.1 | **35\.2** | **13\.5** | 51\.6 | **38\.3** | **29\.6** | 14\.1 | **52\.8** | 49\.9 | **43\.9** | **39\.6** |
49 | | 50 | CLIP | 55\.2 | 77\.0 | 43\.8 | 38\.9 | 49\.0 | 4\.7 | 6\.3 | 23\.5 | 27\.2 | 63\.5 | 56\.2 | **12\.5** | 30\.2 | 92\.1 | 21\.0 | 31\.9 | 7\.4 | 33\.6 | 10\.9 | 50\.8 | 35\.5 | 24\.8 | 14\.0 | 49\.9 | 50\.1 | 39\.4 | 36\.5 |
50 | | | SLIP | 61\.9 | 76\.8 | 48\.9 | 39\.2 | 54\.8 | 7\.3 | 9\.0 | **29\.8** | 31\.9 | **75\.0** | 57\.7 | 9\.8 | 24\.9 | **95\.6** | 37\.8 | 32\.5 | 9\.0 | **35\.1** | 12\.7 | 54\.4 | 41\.1 | 30\.3 | 13\.8 | 49\.5 | 49\.9 | 44\.1 | 39\.7 |
51 | | | A-CLIP | **62\.2** | **81\.5** | **53\.7** | **48\.2** | **58\.7** | **8\.3** | **10\.2** | 27\.7 | **40\.5** | 73\.3 | **61\.0** | 11\.3 | **32\.9** | 95\.5 | **39\.7** | **37\.5** | **9\.4** | 23\.3 | **14\.4** | **63\.7** | **42\.5** | **31\.6** | **19\.6** | **50\.8** | **52\.3** | **46\.3** | **42\.2** |
52 | | 100 | CLIP | 60\.4 | 79\.4 | 44\.6 | 43\.3 | 53\.0 | 8\.5 | 8\.2 | 26\.2 | 34\.7 | 68\.9 | 59\.2 | 11\.4 | 20\.4 | 93\.2 | 23\.3 | 27\.3 | 10\.3 | 23\.1 | 12\.0 | 54\.0 | 36\.7 | 27\.7 | 13\.0 | 50\.9 | 50\.1 | 42\.7 | 37\.8 |
53 | | | SLIP | 63\.0 | 83\.1 | 50\.4 | 43\.0 | 52\.0 | 8\.3 | 8\.3 | 26\.2 | 34\.0 | 74\.6 | 61\.1 | **16\.1** | 32\.4 | 95\.1 | 22\.6 | 28\.5 | 10\.5 | **34\.8** | 11\.5 | 52\.1 | 37\.3 | 28\.3 | 13\.7 | **55\.2** | 49\.9 | 45\.0 | 39\.9 |
54 | | | A-CLIP | **66\.7** | **86\.6** | **58\.6** | **51\.4** | **58\.6** | **10\.5** | **11\.9** | **33\.1** | **48\.5** | **74\.9** | **64\.3** | 7\.8 | **31\.2** | **96\.7** | **35\.6** | **35\.8** | **12\.9** | 30\.5 | **15\.7** | **57\.1** | **44\.1** | **33\.1** | **22\.9** | 52\.7 | **50\.7** | **48\.1** | **43\.8** |
55 |
56 | The following table shows results with longer training schedules and a larger model (ViT-L).
57 |
58 | | Methods | IN 1K - 0-shot | Flickr30K - I2T | Flickr30K - T2I | MS COCO - I2T | MS COCO - T2I |
59 | |----------|----------------|----------------|----------------|---------------|---------------|
60 | | CLIP(25ep) | 37.6 | 51.4 | 32.6 | 27.9 | 17.6 |
61 | | SLIP(25ep) | 42.8 | 57.2 | 41.2 | 33.6 | 21.9 |
62 | | A-CLIP(25ep) | **43.9** | **62.7** | **42.1** | **38.0** | **23.2** |
63 | | CLIP(50ep) | 39.4 | 53.9 | 35.8 | 30.2 | 19.2 |
64 | | SLIP(50ep) | 44.1 | 60.6 | 41.1 | 33.2 | 22.3 |
65 | | A-CLIP(50ep) | **46.3** | **66.7** | **43.2** | **39.8** | **24.4** |
66 | | CLIP(100ep) | 42.7 | 61.0 | 37.9 | 34.4 | 20.9 |
67 | | SLIP(100ep) | 45.0 | 59.3 | 41.4 | 34.6 | 22.7 |
68 | | A-CLIP(100ep) | **48.0** | **66.3** | **45.7** | **40.7** | **25.1** |
69 | | CLIP(VIT-L) | 40.4 | 51.4 | 35.2 | 28.9 | 18.5 |
70 | | SLIP(VIT-L) | 46.2 | 60.6 | 43.7 | 35.3 | 23.5 |
71 | | A-CLIP(VIT-L) | **48.9** | **64.1** | **48.2** | **39.1** | **26.9** |
72 |
73 |
74 | ## Setup
75 | Install [PyTorch](https://pytorch.org) and [timm](https://github.com/rwightman/pytorch-image-models).
76 | The code has been tested with CUDA 11.6, PyTorch 1.13.0 and timm 0.5.4.
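For reference, a matching environment can be set up roughly as follows (a sketch; pick the PyTorch wheel that matches your CUDA version from the [PyTorch previous-versions page](https://pytorch.org/get-started/previous-versions/)):

```
pip install torch==1.13.0 torchvision==0.14.0 timm==0.5.4
```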
77 |
78 | ### YFCC15M Setup
79 | For data preparation, refer to [SLIP](https://github.com/facebookresearch/SLIP/tree/main).
80 |
81 | ## Pre-training
82 | ### A-CLIP ViT-Base with 8 nodes (global batch size 4096)
83 | ```
84 | python -m torch.distributed.launch --nproc_per_node=8 --nnodes=8 main.py \
85 | --root /path/to/yfcc100m --dataset yfcc15m --metadata /path/to/yfcc15m.pkl \
86 | --model ACLIP_VITB16 --batch-size 64 \
87 | --lr 5e-4 --wd 0.5
88 | ```
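For a quick single-node run (8 GPUs, global batch size 8×64 = 512 rather than 4096, so the learning rate may need to be scaled down), the same entry point can be launched as below. This is a debugging configuration, not a setting reported in the paper:

```
python -m torch.distributed.launch --nproc_per_node=8 main.py \
  --root /path/to/yfcc100m --dataset yfcc15m --metadata /path/to/yfcc15m.pkl \
  --model ACLIP_VITB16 --batch-size 64 \
  --lr 5e-4 --wd 0.5
```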
89 |
90 | ## Visualization
91 | The examples below show how the attentive mask preserves the content described by the text while filtering out redundant background.
92 |
93 |
94 |
95 | ## Citation
96 | If the code and paper help your research, please kindly cite:
97 | ```
98 | @InProceedings{Yang_2023_ICCV, author = {Yang, Yifan and Huang, Weiquan and Wei, Yixuan and Peng, Houwen and Jiang, Xinyang and Jiang, Huiqiang and Wei, Fangyun and Wang, Yin and Hu, Han and Qiu, Lili and Yang, Yuqing}, title = {Attentive Mask CLIP}, booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)}, month = {October}, year = {2023}, pages = {2771-2781} }
99 | ```
--------------------------------------------------------------------------------
/SECURITY.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## Security
4 |
5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/).
6 |
7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/opensource/security/definition), please report it to us as described below.
8 |
9 | ## Reporting Security Issues
10 |
11 | **Please do not report security vulnerabilities through public GitHub issues.**
12 |
13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/opensource/security/create-report).
14 |
15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/opensource/security/pgpkey).
16 |
17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc).
18 |
19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
20 |
21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
22 | * Full paths of source file(s) related to the manifestation of the issue
23 | * The location of the affected source code (tag/branch/commit or direct URL)
24 | * Any special configuration required to reproduce the issue
25 | * Step-by-step instructions to reproduce the issue
26 | * Proof-of-concept or exploit code (if possible)
27 | * Impact of the issue, including how an attacker might exploit the issue
28 |
29 | This information will help us triage your report more quickly.
30 |
31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs.
32 |
33 | ## Preferred Languages
34 |
35 | We prefer all communications to be in English.
36 |
37 | ## Policy
38 |
39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd).
40 |
41 |
42 |
--------------------------------------------------------------------------------
/SUPPORT.md:
--------------------------------------------------------------------------------
1 | # TODO: The maintainer of this repo has not yet edited this file
2 |
3 | **REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project?
4 |
5 | - **No CSS support:** Fill out this template with information about how to file issues and get help.
6 | - **Yes CSS support:** Fill out an intake form at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). CSS will work with/help you to determine next steps.
7 | - **Not sure?** Fill out an intake as though the answer were "Yes". CSS will help you decide.
8 |
9 | *Then remove this first heading from this SUPPORT.MD file before publishing your repo.*
10 |
11 | # Support
12 |
13 | ## How to file issues and get help
14 |
15 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing
16 | issues before filing new issues to avoid duplicates. For new issues, file your bug or
17 | feature request as a new Issue.
18 |
19 | For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE
20 | FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER
21 | CHANNEL. WHERE WILL YOU HELP PEOPLE?**.
22 |
23 | ## Microsoft Support Policy
24 |
25 | Support for this **PROJECT or PRODUCT** is limited to the resources listed above.
26 |
--------------------------------------------------------------------------------
/bpe_simple_vocab_16e6.txt.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/A-CLIP/0e51942fb677f6d26accfb1d1d19f56219563f0e/bpe_simple_vocab_16e6.txt.gz
--------------------------------------------------------------------------------
/dataset_catalog.json:
--------------------------------------------------------------------------------
1 | {
2 | "food101": {
3 | "path": "/path/to/dataset",
4 | "type": "imagefolder",
5 | "train": "train",
6 | "test": "test"
7 | },
8 | "cifar10": {
9 | "path": "./data/cifar10",
10 | "type": "special"
11 | },
12 | "cifar100": {
13 | "path": "./data/cifar100",
14 | "type": "special"
15 | },
16 | "cub200": {
17 | "path": "/path/to/dataset",
18 | "type": "imagefolder",
19 | "train": "train",
20 | "test": "val"
21 | },
22 | "sun397": {
23 | "path": "/path/to/dataset",
24 | "type": "filelist",
25 | "train": "train",
26 | "test": "test"
27 | },
28 | "cars": {
29 | "path": "/path/to/dataset",
30 | "type": "imagefolder",
31 | "train": "train",
32 | "test": "test"
33 | },
34 | "aircraft": {
35 | "path": "/path/to/dataset",
36 | "type": "imagefolder",
37 | "train": "train",
38 | "test": "test"
39 | },
40 | "dtd": {
41 | "path": "/path/to/dataset",
42 | "type": "imagefolder",
43 | "train": "train",
44 | "test": "test"
45 | },
46 | "pets": {
47 | "path": "/path/to/dataset",
48 | "type": "imagefolder",
49 | "train": "train",
50 | "test": "test"
51 | },
52 | "caltech101": {
53 | "path": "/path/to/dataset",
54 | "type": "imagefolder",
55 | "train": "train",
56 | "test": "test"
57 | },
58 | "flowers": {
59 | "path": "/path/to/dataset",
60 | "type": "imagefolder",
61 | "train": "train",
62 | "test": "test"
63 | },
64 | "mnist": {
65 | "path": "./data/mnist",
66 | "type": "special"
67 | },
68 | "fer2013": {
69 | "path": "/path/to/dataset",
70 | "type": "imagefolder",
71 | "train": "train",
72 | "test": "test"
73 | },
74 | "stl10": {
75 | "path": "./data/stl10",
76 | "type": "special"
77 | },
78 | "eurosat": {
79 | "path": "/path/to/dataset",
80 | "type": "imagefolder",
81 | "train": "train",
82 | "test": "val"
83 | },
84 | "resisc45": {
85 | "path": "/path/to/dataset",
86 | "type": "imagefolder",
87 | "train": "train",
88 | "test": "test"
89 | },
90 | "gtsrb": {
91 | "path": "/path/to/dataset",
92 | "type": "imagefolder",
93 | "train": "train",
94 | "test": "test"
95 | },
96 | "kitti_distance": {
97 | "path": "/path/to/dataset",
98 | "type": "imagefolder",
99 | "train": "train",
100 | "test": "val"
101 | },
102 | "country211": {
103 | "path": "/path/to/dataset",
104 | "type": "imagefolder",
105 | "train": "train",
106 | "test": "test"
107 | },
108 | "patch_camelyon": {
109 | "path": "/path/to/dataset",
110 | "type": "imagefolder",
111 | "train": "train",
112 | "test": "val"
113 | },
114 | "ucf101_frames": {
115 | "path": "/path/to/dataset",
116 | "type": "imagefolder",
117 | "train": "train",
118 | "test": "val"
119 | },
120 | "kinetics700_frames": {
121 | "path": "/path/to/dataset",
122 | "type": "imagefolder",
123 | "train": "train_images",
124 | "test": "val_images"
125 | },
126 | "clevr_counts": {
127 | "path": "/path/to/dataset",
128 | "type": "filelist",
129 | "train": "train",
130 | "test": "val"
131 | },
132 | "hateful_memes": {
133 | "path": "/path/to/dataset",
134 | "type": "imagefolder",
135 | "train": "train",
136 | "test": "dev"
137 | },
138 | "rendered_sst2": {
139 | "path": "/path/to/dataset",
140 | "type": "imagefolder",
141 | "train": "train",
142 | "test": "test"
143 | },
144 | "imagenet": {
145 | "path": "/path/to/dataset",
146 | "type": "imagefolder",
147 | "train": "train",
148 | "test": "val"
149 | }
150 | }
151 |
--------------------------------------------------------------------------------
/datasets.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | from collections import defaultdict
7 | import json
8 | import os
9 | import pickle
10 | import zipfile
11 |
12 | import numpy as np
13 | from PIL import Image, ImageFile
14 |
15 | import torch
16 | from torchvision import transforms
17 | from torchvision import datasets as t_datasets
18 | from torchvision.transforms import functional as F
19 |
20 | import utils
21 |
22 |
23 | ImageFile.LOAD_TRUNCATED_IMAGES = True
24 |
25 |
26 | def pil_loader(path):
27 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
28 | with open(path, 'rb') as f:
29 | img = Image.open(f)
30 | return img.convert('RGB')
31 |
32 |
33 | def yfcc_loader(root, index):
34 | index = format(index, "0>8d")
35 | repo = index[:2]
36 | z = index[2: 5]
37 | file_img = index[5:] + '.jpg'
38 | path_zip = os.path.join(root, 'images', repo, z) + '.zip'
39 | with zipfile.ZipFile(path_zip, 'r') as myzip:
40 | img = Image.open(myzip.open(file_img))
41 | return img.convert('RGB')
42 |
43 |
44 | class ImageCaptionDatasetBase(torch.utils.data.Dataset):
45 | def __init__(self, dataset, root, metadata):
46 | self.dataset = dataset
47 | self.root = root
48 | if self.dataset == 'yfcc15m':
49 | with open(metadata, 'rb') as f:
50 | self.samples = pickle.load(f)
51 | elif self.dataset == 'coco':
52 | samples = defaultdict(list)
53 | with open(metadata) as f:
54 | annotations = json.load(f)['annotations']
55 | for ann in annotations:
56 | samples[ann['image_id']].append(ann['caption'])
57 | self.samples = [(k, v) for k, v in samples.items()]
58 | elif self.dataset == 'cc12m' or self.dataset == 'cc3m':
59 | self.samples = np.load(metadata, allow_pickle=True)
60 | elif self.dataset == 'redcaps':
61 | with open(metadata) as f:
62 | annotations = json.load(f)
63 | self.samples = [(ann['image_id'], ann['subreddit'], ann['caption']) for ann in annotations]
64 |
65 | def get_raw_item(self, i):
66 | if self.dataset == 'yfcc15m':
67 | index, title, desc = self.samples[i]
68 | caption = np.random.choice([title, desc])
69 | img = yfcc_loader(self.root, index)
70 | elif self.dataset == 'coco':
71 | index, captions = self.samples[i]
72 | path = os.path.join(self.root, 'train2017', '{:012d}.jpg'.format(index))
73 | img = pil_loader(path)
74 | caption = np.random.choice(captions)
75 | elif self.dataset == 'cc3m':
76 | ann = self.samples[i]
77 | filename, captions = ann['image_id'], ann['captions']
78 | path = os.path.join(self.root, str(filename))
79 | img = pil_loader(path)
80 | caption = np.random.choice(captions)
81 | elif self.dataset == 'cc12m':
82 | ann = self.samples[i]
83 | filename, captions = ann['image_name'], ann['captions']
84 | path = os.path.join(self.root, filename)
85 | img = pil_loader(path)
86 | caption = np.random.choice(captions)
87 | elif self.dataset == 'redcaps':
88 | image_id, subreddit, caption = self.samples[i]
89 | path = os.path.join(self.root, subreddit, f"{image_id}.jpg")
90 | img = pil_loader(path)
91 |
92 | return img, caption
93 |
94 | def __getitem__(self, i):
95 | raise NotImplementedError
96 |
97 | def __len__(self):
98 | return len(self.samples)
99 |
100 |
101 | class ImageCaptionDatasetCLIP(ImageCaptionDatasetBase):
102 | def __init__(self, dataset, root, metadata, transform=None, tokenizer=None):
103 | super().__init__(dataset, root, metadata)
104 |
105 | self.transform = transform
106 | self.tokenizer = tokenizer
107 |
108 | def __getitem__(self, i):
109 | img, caption = self.get_raw_item(i)
110 |
111 | # apply transformation
112 | if self.transform is not None:
113 | image = self.transform(img)
114 |
115 | # tokenize caption
116 | if self.tokenizer is not None:
117 | caption = self.tokenizer(caption)
118 |
119 | return image, caption
120 |
121 |
122 | class ImageCaptionDatasetSLIP(ImageCaptionDatasetBase):
123 | def __init__(self, dataset, root, metadata, transform, augment, tokenizer=None):
124 | super().__init__(dataset, root, metadata)
125 |
126 | self.transform = transform
127 | self.augment = augment
128 | self.tokenizer = tokenizer
129 |
130 | def __getitem__(self, i):
131 | img, caption = self.get_raw_item(i)
132 |
133 | image = self.transform(img)
134 | aug1 = self.augment(img)
135 | aug2 = self.augment(img)
136 |
137 | # tokenize caption
138 | if self.tokenizer is not None:
139 | caption = self.tokenizer(caption)
140 |
141 | return image, caption, aug1, aug2
142 |
143 | class ImageCaptionDatasetACLIP(ImageCaptionDatasetBase):
144 | def __init__(self, dataset, root, metadata, transform, ema_transform, tokenizer=None):
145 | super().__init__(dataset, root, metadata)
146 | self.transform = transform
147 | self.ema_transform = ema_transform
148 | self.tokenizer = tokenizer
149 | self.get_three_crop = GetThreeRandomResizedCrop(224, scale=(0.5, 1.0))
150 |
151 | def __getitem__(self, i):
152 | img, caption = self.get_raw_item(i)
153 | res = self.get_three_crop(img)
154 |
155 | im1, ret1 = res[0]
156 | im2, ret2 = res[1]
157 | im3, ret3 = res[2]
158 |
159 | im1 = self.transform(im1)
160 | im2 = self.transform(im2)
161 | im3 = self.ema_transform(im3)
162 |
163 | pos = np.array([ret1,ret2,ret3])
164 | if self.tokenizer is not None:
165 | caption = self.tokenizer(caption)
166 |
167 | return [im1, im2, im3], pos, caption
168 |
169 | class ImageCaptionDatasetSSL(ImageCaptionDatasetBase):
170 | def __init__(self, dataset, root, metadata, augment):
171 | super().__init__(dataset, root, metadata)
172 |
173 | self.augment = augment
174 |
175 | def __getitem__(self, i):
176 | img, _ = self.get_raw_item(i)
177 |
178 | aug1 = self.augment(img)
179 | aug2 = self.augment(img)
180 |
181 | return aug1, aug2
182 |
183 |
184 | class FileListDataset(torch.utils.data.Dataset):
185 | def __init__(self, images, labels, transform=None, target_transform=None):
186 | self.transform = transform
187 | self.target_transform = target_transform
188 | self.images = np.load(images)
189 | self.labels = np.load(labels)
190 |
191 | def __getitem__(self, index):
192 | img = pil_loader(self.images[index])
193 | target = self.labels[index]
194 |
195 | if self.transform is not None:
196 | img = self.transform(img)
197 |
198 | if self.target_transform is not None:
199 | target = self.target_transform(target)
200 |
201 | return img, target
202 |
203 | def __len__(self):
204 | return len(self.images)
205 |
206 |
207 | def get_downstream_dataset(catalog, name, is_train, transform):
208 | entry = catalog[name]
209 | root = entry['path']
210 | if entry['type'] == 'imagefolder':
211 | dataset = t_datasets.ImageFolder(os.path.join(root, entry['train'] if is_train else entry['test']),
212 | transform=transform)
213 | elif entry['type'] == 'special':
214 | if name == 'cifar10':
215 | dataset = t_datasets.CIFAR10(root, train=is_train,
216 | transform=transform, download=True)
217 | elif name == 'cifar100':
218 | dataset = t_datasets.CIFAR100(root, train=is_train,
219 | transform=transform, download=True)
220 | elif name == 'stl10':
221 | dataset = t_datasets.STL10(root, split='train' if is_train else 'test',
222 | transform=transform, download=True)
223 | elif name == 'mnist':
224 | dataset = t_datasets.MNIST(root, train=is_train,
225 | transform=transform, download=True)
226 | elif entry['type'] == 'filelist':
227 | path = entry['train'] if is_train else entry['test']
228 | val_images = os.path.join(root, path + '_images.npy')
229 | val_labels = os.path.join(root, path + '_labels.npy')
230 | if name == 'clevr_counts':
231 | target_transform = lambda x: ['count_10', 'count_3', 'count_4', 'count_5', 'count_6', 'count_7', 'count_8', 'count_9'].index(x)
232 | else:
233 | target_transform = None
234 | dataset = FileListDataset(val_images, val_labels, transform, target_transform)
235 | else:
236 | raise Exception('Unknown dataset')
237 |
238 | return dataset
239 |
240 | class GetThreeRandomResizedCrop(transforms.RandomResizedCrop):
241 | def forward(self, img):
242 | """
243 | Args:
244 | img (PIL Image or Tensor): Image to be cropped and resized.
245 |
246 | Returns:
247 | List[(cropped image, ret)] *3.
248 | The scale of the last image is larger than the first two.
249 | """
250 | ret1 = self.get_params(img, self.scale, self.ratio)
251 | ret2 = self.get_params(img, self.scale, self.ratio)
252 |
253 | try:
254 | _, height, width = F.get_dimensions(img)
255 | except AttributeError:  # older torchvision without F.get_dimensions
256 | width, height = F.get_image_size(img)
257 |
258 | im1 = F.resized_crop(img, *ret1, self.size, self.interpolation)
259 | im2 = F.resized_crop(img, *ret2, self.size, self.interpolation)
260 |
261 | # zoom out
262 | ret3 = [0, 0, 0, 0]
263 | ret3[0], ret3[1] = min(ret1[0], ret2[0]), min(ret1[1], ret2[1])  # top-left corner of the union of the two crops
264 |
265 | rh = max(ret1[0] + ret1[2], ret2[0] + ret2[2])
266 | rw = max(ret1[1] + ret1[3], ret2[1] + ret2[3])
267 | ret3[2], ret3[3] = rh - ret3[0], rw - ret3[1]
268 |
269 | ret3[0] = torch.randint(0, ret3[0] + 1, size=(1,)).item() if ret3[0] > 0 else ret3[0]
270 | ret3[1] = torch.randint(0, ret3[1] + 1, size=(1,)).item() if ret3[1] > 0 else ret3[1]
271 |
272 | ret3[2] = torch.randint(ret3[2], height - ret3[0] + 1, size=(1,)).item() if ret3[2] < height else ret3[2]
273 | ret3[3] = torch.randint(ret3[3], width - ret3[1] + 1, size=(1,)).item() if ret3[3] < width else ret3[3]
274 |
275 | im3 = F.resized_crop(img, *ret3, self.size, self.interpolation)
276 |
277 | return [(im1, ret1), (im2, ret2), (im3, ret3)]
278 |
279 | def get_dataset(train_transform, tokenizer, args):
280 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
281 | std=[0.229, 0.224, 0.225])
282 | augment = transforms.Compose([
283 | transforms.RandomResizedCrop(224, scale=(0.08, 1.)),
284 | transforms.RandomApply([
285 | transforms.ColorJitter(0.4, 0.4, 0.4, 0.1) # not strengthened
286 | ], p=0.8),
287 | transforms.RandomGrayscale(p=0.2),
288 | transforms.RandomApply([utils.GaussianBlur([.1, 2.])], p=0.5),
289 | transforms.RandomHorizontalFlip(),
290 | transforms.ToTensor(),
291 | normalize,
292 | ])
293 |
294 | if args.model.startswith('SIMCLR'):
295 | return ImageCaptionDatasetSSL(args.dataset, args.root, args.metadata, augment)
296 | elif args.model.startswith('CLIP'):
297 | return ImageCaptionDatasetCLIP(args.dataset, args.root, args.metadata, train_transform, tokenizer)
298 | elif args.model.startswith('SLIP'):
299 | return ImageCaptionDatasetSLIP(args.dataset, args.root, args.metadata, train_transform, augment, tokenizer)
300 | elif args.model.startswith('ACLIP'):
301 | ema_transform = transforms.Compose([
302 | transforms.ToTensor(),
303 | normalize
304 | ])
305 | return ImageCaptionDatasetACLIP(args.dataset, args.root, args.metadata, train_transform, ema_transform, tokenizer)
306 |
307 |
308 |
309 |
--------------------------------------------------------------------------------
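The `GetThreeRandomResizedCrop` transform in `datasets.py` above returns two random crops plus a larger "zoom-out" crop whose box covers both. A hypothetical usage sketch (it assumes the repository root is on `PYTHONPATH` so that `datasets.py` and its `utils` import resolve):

```
from PIL import Image
from datasets import GetThreeRandomResizedCrop

three_crop = GetThreeRandomResizedCrop(224, scale=(0.5, 1.0))
img = Image.new('RGB', (640, 480))  # stand-in for a training image
(im1, box1), (im2, box2), (im3, box3) = three_crop(img)

# im1/im2 are independent 224x224 random resized crops; im3 is resized from a
# larger region containing both box1 and box2, matching the EMA branch's view.
print(box1, box2, box3)  # each box is (top, left, height, width) in the original image
```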
/docs/aclip.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/A-CLIP/0e51942fb677f6d26accfb1d1d19f56219563f0e/docs/aclip.png
--------------------------------------------------------------------------------
/docs/attMask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/A-CLIP/0e51942fb677f6d26accfb1d1d19f56219563f0e/docs/attMask.png
--------------------------------------------------------------------------------
/docs/masking_Page_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/A-CLIP/0e51942fb677f6d26accfb1d1d19f56219563f0e/docs/masking_Page_1.png
--------------------------------------------------------------------------------
/docs/vis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/A-CLIP/0e51942fb677f6d26accfb1d1d19f56219563f0e/docs/vis.png
--------------------------------------------------------------------------------
/eval_zeroshot.py:
--------------------------------------------------------------------------------
1 | # Based on SLIP code bases
2 | # https://github.com/facebookresearch/SLIP
3 | # --------------------------------------------------------
4 | import argparse
5 | from collections import OrderedDict
6 | import json
7 | import os
8 | from sklearn import metrics
9 |
10 | import torch
11 | import torch.backends.cudnn as cudnn
12 | import torch.utils.data
13 | import torchvision.transforms as transforms
14 |
15 | import datasets
16 | import models
17 | from tokenizer import SimpleTokenizer
18 | import utils
19 |
20 |
21 | def get_args_parser():
22 | parser = argparse.ArgumentParser(description='A-CLIP 0-shot evaluations', add_help=False)
23 | parser.add_argument('--output-dir', default='./', type=str, help='output dir')
24 | parser.add_argument('--batch-size', default=256, type=int, help='batch_size')
25 | parser.add_argument('-j', '--workers', default=10, type=int, metavar='N',
26 | help='number of data loading workers per process')
27 | parser.add_argument('--resume', default='', type=str, help='path to latest checkpoint')
28 | return parser
29 |
30 |
31 | def main(args):
32 | # optionally resume from a checkpoint (takes precedence over autoresume)
33 | if args.resume:
34 | ckpt_path = args.resume
35 | elif os.path.isfile(os.path.join(args.output_dir, 'checkpoint_best.pt')):
36 | ckpt_path = os.path.join(args.output_dir, 'checkpoint_best.pt')
37 | else:
38 | raise Exception('no checkpoint found')
39 |
40 | ckpt = torch.load(ckpt_path, map_location='cpu')
41 | state_dict = OrderedDict()
42 | for k, v in ckpt['state_dict'].items():
43 | state_dict[k.replace('module.', '')] = v
44 |
45 | # create model
46 | old_args = ckpt['args']
47 | print("=> creating model: {}".format(old_args.model))
48 | model = getattr(models, old_args.model)()
49 | model.cuda()
50 | model.load_state_dict(state_dict, strict=True)
51 | print("=> loaded resume checkpoint '{}' (epoch {})".format(args.resume, ckpt['epoch']))
52 |
53 | cudnn.benchmark = True
54 |
55 | cwd = os.path.dirname(os.path.realpath(__file__))
56 | with open(os.path.join(cwd, 'dataset_catalog.json')) as f:
57 | catalog = json.load(f)
58 |
59 | with open(os.path.join(cwd, 'templates.json')) as f:
60 | all_templates = json.load(f)
61 |
62 | with open(os.path.join(cwd, 'labels.json')) as f:
63 | all_labels = json.load(f)
64 |
65 | # Data loading code
66 | print("=> creating dataset")
67 | tokenizer = SimpleTokenizer()
68 | val_transform = transforms.Compose([
69 | transforms.Resize(224),
70 | transforms.CenterCrop(224),
71 | lambda x: x.convert('RGB'),
72 | transforms.ToTensor(),
73 | transforms.Normalize(mean=[0.485, 0.456, 0.406],
74 | std=[0.229, 0.224, 0.225])
75 | ])
76 |
77 | results = []
78 | for d in catalog:
79 | print('Evaluating {}'.format(d))
80 | val_dataset = datasets.get_downstream_dataset(catalog, name=d, is_train=False, transform=val_transform)
81 |
82 | val_loader = torch.utils.data.DataLoader(
83 | val_dataset, batch_size=args.batch_size, shuffle=False,
84 | num_workers=args.workers, pin_memory=True, drop_last=False)
85 |
86 | templates = all_templates[d]
87 | labels = all_labels[d]
88 |
89 | is_acc = d not in ['aircraft', 'pets', 'caltech101', 'flowers', 'kinetics700_frames', 'hateful_memes']
90 |
91 | acc_or_outputs = validate_zeroshot(val_loader, templates, labels, model, tokenizer, is_acc)
92 |
93 | if d in ['aircraft', 'pets', 'caltech101', 'flowers']:
94 | metric = mean_per_class(*acc_or_outputs)
95 | elif d == 'kinetics700_frames':
96 | top1, top5 = accuracy(*acc_or_outputs, topk=(1, 5))
97 | metric = (top1 + top5) / 2
98 | metric = metric.item()
99 | elif d == 'hateful_memes':
100 | metric = roc_auc(*acc_or_outputs)
101 | else:
102 | metric = acc_or_outputs
103 |
104 | results.append(metric)
105 |
106 | print('metric:', metric)
107 |
108 | print('all results:')
109 | for x in results:
110 | print('{:.1f}'.format(x))
111 |
112 | def validate_zeroshot(val_loader, templates, labels, model, tokenizer, is_acc):
113 | # switch to evaluate mode
114 | model.eval()
115 | total_top1 = 0
116 | total_images = 0
117 |
118 | all_outputs = []
119 | all_targets = []
120 |
121 | print('=> encoding captions')
122 | with torch.no_grad():
123 | text_features = []
124 | for label in labels:
125 | if isinstance(label, list):
126 | texts = [t.format(l) for t in templates for l in label]
127 | else:
128 | texts = [t.format(label) for t in templates]
129 | texts = tokenizer(texts).cuda(non_blocking=True)
130 | texts = texts.view(-1, 77).contiguous()
131 | class_embeddings = utils.get_model(model).encode_text(texts, ema=True)
132 | class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True)
133 | class_embeddings = class_embeddings.mean(dim=0)
134 | class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True)
135 | text_features.append(class_embeddings)
136 | text_features = torch.stack(text_features, dim=0)
137 |
138 | for images, target in val_loader:
139 | images = images.cuda(non_blocking=True)
140 | target = target.cuda(non_blocking=True)
141 |
142 | # encode images
143 | image_features = utils.get_model(model).encode_image(images, ema=True)
144 | image_features = image_features / image_features.norm(dim=-1, keepdim=True)
145 |
146 | # cosine similarity as logits
147 | logits_per_image = image_features @ text_features.t()
148 |
149 | if is_acc:
150 | # measure accuracy and record loss
151 | pred = logits_per_image.argmax(dim=1)
152 | correct = pred.eq(target).sum()
153 | total_top1 += correct.item()
154 | total_images += images.size(0)
155 | else:
156 | all_outputs.append(logits_per_image.cpu())
157 | all_targets.append(target.cpu())
158 |
159 | if is_acc:
160 | return 100 * total_top1 / total_images
161 | else:
162 | return torch.cat(all_outputs), torch.cat(all_targets)
163 |
164 |
165 | def accuracy(output, target, topk=(1,)):
166 | """Computes the accuracy over the k top predictions for the specified values of k"""
167 | with torch.no_grad():
168 | maxk = max(topk)
169 | batch_size = target.size(0)
170 |
171 | _, pred = output.topk(maxk, 1, True, True)
172 | pred = pred.t()
173 | correct = pred.eq(target.reshape(1, -1).expand_as(pred))
174 |
175 | res = []
176 | for k in topk:
177 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
178 | res.append(correct_k.mul_(100.0 / batch_size))
179 | return res
180 |
181 |
182 | def mean_per_class(outputs, targets):
183 | pred = outputs.argmax(1)
184 | confusion_matrix = metrics.confusion_matrix(targets, pred)
185 | per_classes = confusion_matrix.diagonal() / confusion_matrix.sum(axis=1)
186 |
187 | return 100 * per_classes.mean()
188 |
189 |
190 | def roc_auc(outputs, targets):
191 | pos_score = outputs[:, 1] - outputs[:, 0]
192 | metric = metrics.roc_auc_score(targets, pos_score)
193 |
194 | return 100 * metric
195 |
196 |
197 | if __name__ == '__main__':
198 | parser = argparse.ArgumentParser('A-CLIP 0-shot evaluations', parents=[get_args_parser()])
199 | args = parser.parse_args()
200 | os.makedirs(args.output_dir, exist_ok=True)
201 | main(args)
202 |
--------------------------------------------------------------------------------
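The evaluation script above loads `checkpoint_best.pt` from `--output-dir` (or an explicit `--resume` path) and reads `dataset_catalog.json`, `templates.json`, and a `labels.json` from the script directory (the last of these is not shown in the tree above). A typical invocation might look like:

```
python eval_zeroshot.py --output-dir /path/to/pretrain_output --batch-size 256 -j 10
```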
/losses.py:
--------------------------------------------------------------------------------
1 | # Based on SLIP code bases
2 | # https://github.com/facebookresearch/SLIP
3 | # --------------------------------------------------------
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 | import utils
9 |
10 | def get_metric_names():
11 | metrics = ["loss"]
12 | metrics.extend(["simclr_loss", "im_byol_loss", "contra_loss_1", "contra_loss_2", "clip_acc"])
13 |
14 | return metrics
15 |
16 |
17 | def cal_simsiam_loss(p, z, version="simplified"): # negative cosine similarity
18 | if version == "original":
19 | z = z.detach() # stop gradient
20 | p = F.normalize(p, dim=1) # l2-normalize
21 | z = F.normalize(z, dim=1) # l2-normalize
22 | return -(p * z).sum(dim=1).mean()
23 |
24 | elif (
25 | version == "simplified"
26 | ): # same thing, much faster. Scroll down, speed test in __main__
27 | return -F.cosine_similarity(p, z.detach(), dim=-1).mean()
28 | else:
29 | raise Exception
30 |
31 |
32 | class ACLIPLoss(nn.Module):
33 | def __init__(self, temperature=0.1):
34 | super().__init__()
35 | self.labels = None
36 | self.last_local_batch_size = None
37 | self.simclr_loss = SIMCLRLoss(temperature=temperature)
38 |
39 | def forward(self, outputs):
40 | image_embed = outputs["image_embed"]
41 | text_embed = outputs["text_embed"]
42 | logit_scale = outputs["logit_scale"]
43 |
44 | # cal simclr_loss
45 | bs = text_embed.shape[0]
46 | image_ssl_embed = outputs["image_ssl_embed"]
47 | inputs = {}
48 | inputs["aug1_embed"] = image_ssl_embed[:bs]
49 | inputs["aug2_embed"] = image_ssl_embed[:bs]
50 | simclr_loss_dict = self.simclr_loss(inputs)
51 |
52 | def loss_fn(x, y):
53 | x = F.normalize(x, dim=-1, p=2)
54 | y = F.normalize(y, dim=-1, p=2)
55 | return 2 - 2 * (x * y).sum(dim=-1)
56 |
57 | im_features = outputs["byol_feats"]
58 | im_features_e = outputs["byol_feats_e"]
59 | im_features_e = torch.cat([im_features_e, im_features_e], dim=0)
60 | im_byol_loss = loss_fn(im_features, im_features_e).mean()
61 |
62 | local_batch_size = text_embed.size(0)
63 |
64 | if local_batch_size != self.last_local_batch_size:
65 | self.labels = local_batch_size * utils.get_rank() + torch.arange(
66 | local_batch_size, device=image_embed.device
67 | )
68 | self.last_local_batch_size = local_batch_size
69 |
70 | image_embed = F.normalize(image_embed, dim=-1, p=2)
71 | text_embed = F.normalize(text_embed, dim=-1, p=2)
72 |
73 | image_embed_1 = image_embed[:local_batch_size]
74 | image_embed_2 = image_embed[local_batch_size:]
75 |
76 | (
77 | image_embed_all_1,
78 | image_embed_all_2,
79 | text_embed_all,
80 | ) = utils.all_gather_batch_with_grad([image_embed_1, image_embed_2, text_embed])
81 |
82 | # cosine similarity as logits
83 | logits_per_image = logit_scale * image_embed_1 @ text_embed_all.t()
84 | logits_per_text = logit_scale * text_embed @ image_embed_all_1.t()
85 |
86 | contra_loss_1 = (
87 | F.cross_entropy(logits_per_image, self.labels)
88 | + F.cross_entropy(logits_per_text, self.labels)
89 | ) / 2
90 |
91 | logits_per_image = logit_scale * image_embed_2 @ text_embed_all.t()
92 | logits_per_text = logit_scale * text_embed @ image_embed_all_2.t()
93 |
94 | contra_loss_2 = (
95 | F.cross_entropy(logits_per_image, self.labels)
96 | + F.cross_entropy(logits_per_text, self.labels)
97 | ) / 2
98 |
99 |
100 | loss = (
101 | 0.5 * contra_loss_1
102 | + 0.5 * contra_loss_2
103 | + simclr_loss_dict["ssl_loss"]
104 | + 2 * im_byol_loss
105 | )
106 |
107 | # compute accuracy
108 | with torch.no_grad():
109 | pred = torch.argmax(logits_per_image, dim=-1)
110 | correct = pred.eq(self.labels).sum()
111 | acc = 100 * correct / local_batch_size
112 |
113 | return {
114 | "loss": loss,
115 | "simclr_loss": simclr_loss_dict["ssl_loss"],
116 | "im_byol_loss": im_byol_loss,
117 | "contra_loss_1": contra_loss_1,
118 | "contra_loss_2": contra_loss_2,
119 | "clip_acc": acc,
120 | }
121 |
122 |
123 | class SIMCLRLoss(nn.Module):
124 | """
125 | This is the SimCLR loss in https://arxiv.org/abs/2002.05709
126 | The embedding vectors are assumed to have size (2 x batch_size, embedding_dim) and
127 | the memory layout that can be reshaped into shape (2, batch_size, embedding_dim).
128 | This memory layout is consistent with the SimCLR collator in
129 | https://github.com/facebookresearch/vissl/blob/master/vissl/data/collators/simclr_collator.py
130 | Config params:
131 | temperature (float): the temperature to be applied on the logits
132 | """
133 |
134 | def __init__(self, temperature=0.1):
135 | super().__init__()
136 | self.tau = temperature
137 | self.labels = None
138 | self.masks = None
139 | self.last_local_batch_size = None
140 |
141 | def forward(self, outputs):
142 | q_a = outputs["aug1_embed"]
143 | q_b = outputs["aug2_embed"]
144 |
145 | q_a = F.normalize(q_a, dim=-1, p=2)
146 | q_b = F.normalize(q_b, dim=-1, p=2)
147 |
148 | local_batch_size = q_a.size(0)
149 |
150 | k_a, k_b = utils.all_gather_batch_with_grad([q_a, q_b])
151 |
152 | if local_batch_size != self.last_local_batch_size:
153 | self.labels = local_batch_size * utils.get_rank() + torch.arange(
154 | local_batch_size, device=q_a.device
155 | )
156 | total_batch_size = local_batch_size * utils.get_world_size()
157 | self.masks = F.one_hot(self.labels, total_batch_size) * 1e9
158 | self.last_local_batch_size = local_batch_size
159 |
160 | logits_aa = torch.matmul(q_a, k_a.transpose(0, 1)) / self.tau
161 | logits_aa = logits_aa - self.masks
162 | logits_bb = torch.matmul(q_b, k_b.transpose(0, 1)) / self.tau
163 | logits_bb = logits_bb - self.masks
164 | logits_ab = torch.matmul(q_a, k_b.transpose(0, 1)) / self.tau
165 | logits_ba = torch.matmul(q_b, k_a.transpose(0, 1)) / self.tau
166 |
167 | loss_a = F.cross_entropy(torch.cat([logits_ab, logits_aa], dim=1), self.labels)
168 | loss_b = F.cross_entropy(torch.cat([logits_ba, logits_bb], dim=1), self.labels)
169 | loss = (loss_a + loss_b) / 2 # divide by 2 to average over all samples
170 |
171 | # compute accuracy
172 | with torch.no_grad():
173 | pred = torch.argmax(torch.cat([logits_ab, logits_aa], dim=1), dim=-1)
174 | correct = pred.eq(self.labels).sum()
175 | acc = 100 * correct / local_batch_size
176 |
177 | return {"loss": loss, "ssl_loss": loss, "ssl_acc": acc}
178 |
--------------------------------------------------------------------------------
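For reference, the shapes that `ACLIPLoss` above expects in its `outputs` dict can be read off the code: `text_embed` is (B, D); `image_embed`, `image_ssl_embed`, and `byol_feats` stack the two masked views into (2B, ·); `byol_feats_e` holds the EMA features of one view, (B, ·); `logit_scale` is a scalar. A single-process smoke test is sketched below, assuming the `utils` helpers fall back to rank 0 / world size 1 when `torch.distributed` is not initialized (as in the SLIP utilities this code is based on); the dimensions are made up for illustration:

```
import torch
from losses import ACLIPLoss

B, D, D_SSL, D_BYOL = 4, 512, 256, 256  # hypothetical embedding dims
outputs = {
    "image_embed": torch.randn(2 * B, D),       # two masked views, stacked
    "text_embed": torch.randn(B, D),
    "logit_scale": torch.tensor(100.0),
    "image_ssl_embed": torch.randn(2 * B, D_SSL),
    "byol_feats": torch.randn(2 * B, D_BYOL),
    "byol_feats_e": torch.randn(B, D_BYOL),     # EMA branch features
}
losses = ACLIPLoss(temperature=0.1)(outputs)
print({k: float(v) for k, v in losses.items()})
```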
/main.py:
--------------------------------------------------------------------------------
1 | # Based on SLIP code bases
2 | # https://github.com/facebookresearch/SLIP
3 | # --------------------------------------------------------
4 | import argparse
5 | from collections import OrderedDict
6 | import json
7 | import math
8 | import os
9 | import sys
10 | import time
11 | try:
12 | import wandb
13 | except ImportError:
14 | wandb = None
15 |
16 | import numpy as np
17 | import torch
18 | import torch.cuda.amp as amp
19 | import torch.nn.parallel
20 | import torch.backends.cudnn as cudnn
21 | import torch.optim
22 | import torch.utils.data
23 | import torchvision.transforms as transforms
24 |
25 | from datasets import get_dataset
26 | import models
27 | from tokenizer import SimpleTokenizer
28 | from utils import AverageMeter, ProgressMeter, accuracy
29 | import utils
30 | from torchvision.datasets import ImageFolder
31 | from utils import GaussianBlur, Solarize
32 | from losses import ACLIPLoss, get_metric_names
33 |
34 |
35 |
36 | def get_args_parser():
37 | parser = argparse.ArgumentParser(description='A-CLIP pre-training and evaluation', add_help=False)
38 | # Data
39 | parser.add_argument('--dataset', default='yfcc15m', type=str, choices=['yfcc15m', 'cc3m', 'cc12m', 'coco', 'redcaps'])
40 | parser.add_argument('--metadata', default='yfcc15m.pkl', type=str,
41 | help='path to metadata file (see README for details)')
42 | parser.add_argument('--root', default='', type=str,
43 | help='path to dataset root')
44 | parser.add_argument('--output-dir', default='./', type=str, help='path where to save, empty for no saving')
45 | # Model
46 | parser.add_argument('--model', default='ACLIP_VITB16', type=str)
47 | parser.add_argument('--mask-ratio', default=0., type=float)
48 | parser.add_argument('--ssl-mlp-dim', default=4096, type=int,
49 | help='hidden dim of SimCLR mlp projection head')
50 | parser.add_argument('--ssl-emb-dim', default=256, type=int,
51 | help='output embed dim of SimCLR mlp projection head')
52 | parser.add_argument('--ssl-scale', default=1.0, type=float,
53 | help='loss scale for SimCLR objective')
54 | parser.add_argument('--ssl-temp', default=0.1, type=float,
55 | help='softmax temperature for SimCLR objective')
56 | parser.add_argument('--resume', default='', type=str, help='path to resume from')
57 | # Training
58 | parser.add_argument('--momentum-ema', default=0.996, type=float, help="""Base EMA
59 | parameter. The value is increased to 1 during training with cosine schedule.""")
60 | parser.add_argument('--epochs', default=25, type=int)
61 | parser.add_argument('--warmup-epochs', default=1, type=int)
62 | parser.add_argument('--start-epoch', default=0, type=int)
63 | parser.add_argument('--batch-size', default=2, type=int,
64 | help='number of samples per-device/per-gpu')
65 | parser.add_argument('--lr', default=3e-3, type=float)
66 | parser.add_argument('--base-lr', default=3e-3, type=float)
67 | parser.add_argument('--lr-start', default=1e-6, type=float,
68 | help='initial warmup lr')
69 | parser.add_argument('--lr-end', default=1e-5, type=float,
70 | help='minimum final lr')
71 | parser.add_argument('--update-freq', default=1, type=int,
72 | help='optimizer update frequency (i.e. gradient accumulation steps)')
73 | parser.add_argument('--wd', default=0.1, type=float)
74 | parser.add_argument('--betas', default=(0.9, 0.98), nargs=2, type=float)
75 | parser.add_argument('--eps', default=1e-8, type=float)
76 | parser.add_argument('--eval-freq', default=1, type=int)
77 | parser.add_argument('--disable-amp', action='store_true',
78 | help='disable mixed-precision training (requires more memory and compute)')
79 | # System
80 | parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
81 | parser.add_argument('-j', '--workers', default=10, type=int, metavar='N',
82 | help='number of data loading workers per process')
83 | parser.add_argument('--evaluate', action='store_true', help='eval only')
84 | parser.add_argument('--world-size', default=1, type=int,
85 | help='number of nodes for distributed training')
86 | parser.add_argument('--rank', default=0, type=int,
87 | help='node rank for distributed training')
88 | parser.add_argument("--local_rank", type=int, default=0)
89 | parser.add_argument('--dist-url', default='env://', type=str,
90 | help='url used to set up distributed training')
91 | parser.add_argument('--dist-backend', default='nccl', type=str)
92 | parser.add_argument('--seed', default=0, type=int)
93 | parser.add_argument('--gpu', default=None, type=int, help='GPU id to use.')
94 | parser.add_argument('--wandb', action='store_true', help='Enable WandB logging')
95 | parser.add_argument('--descriptions', default='training', type=str)
96 |
97 | return parser
98 |
99 | def get_model(args):
100 | print("=> creating model: {}".format(args.model))
101 | model = getattr(models, args.model)(mask_ratio=args.mask_ratio)
102 | model.cuda(args.gpu)
103 |
104 | if args.distributed:
105 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
106 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], bucket_cap_mb=200,find_unused_parameters=False)
107 |
108 | return model
109 |
110 | def get_optim(args, model):
111 | p_wd, p_non_wd = [], []
112 | for n, p in model.named_parameters():
113 | if not p.requires_grad:
114 | continue # frozen weights
115 | if p.ndim < 2 or 'bias' in n or 'ln' in n or 'bn' in n:
116 | p_non_wd.append(p)
117 | else:
118 | p_wd.append(p)
119 |
120 | optim_params = [{"params": p_wd, "weight_decay": args.wd},
121 | {"params": p_non_wd, "weight_decay": 0},
122 | ]
123 |
124 | optimizer = torch.optim.AdamW(optim_params, lr=args.lr, betas=args.betas,
125 | eps=args.eps, weight_decay=args.wd)
126 | return optimizer
127 |
128 | def load_ckpt(args, model, optimizer, scaler):
129 | # optionally resume from a checkpoint (takes precedence over autoresume)
130 | if args.resume:
131 | if os.path.isfile(args.resume):
132 | print("=> loading resume checkpoint '{}'".format(args.resume))
133 | checkpoint = torch.load(args.resume, map_location='cpu')
134 | epoch = checkpoint['epoch'] if 'epoch' in checkpoint else 0
135 | args.start_epoch = epoch
136 | result = model.load_state_dict(checkpoint['state_dict'], strict=False)
137 | print(result)
138 | optimizer.load_state_dict(checkpoint['optimizer']) if 'optimizer' in checkpoint else ()
139 | scaler.load_state_dict(checkpoint['scaler']) if 'scaler' in checkpoint else ()
140 | args.best_acc = checkpoint['best_acc']
141 | print("=> loaded resume checkpoint '{}' (epoch {})"
142 | .format(args.resume, epoch))
143 | else:
144 | print("=> no checkpoint found at '{}'".format(args.resume))
145 | else:
146 | # auto-resume from latest checkpoint in output directory
147 | latest = os.path.join(args.output_dir, 'checkpoint.pt')
148 | if os.path.isfile(latest):
149 | print("=> loading latest checkpoint '{}'".format(latest))
150 | latest_checkpoint = torch.load(latest, map_location='cpu')
151 | args.start_epoch = latest_checkpoint['epoch']
152 | model.load_state_dict(latest_checkpoint['state_dict'])
153 | optimizer.load_state_dict(latest_checkpoint['optimizer'])
154 | scaler.load_state_dict(latest_checkpoint['scaler'])
155 | args.best_acc = latest_checkpoint['best_acc']
156 | print("=> loaded latest checkpoint '{}' (epoch {})"
157 | .format(latest, latest_checkpoint['epoch']))
158 |
159 | def get_loader(args, tokenizer):
160 | print("=> creating dataset")
161 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
162 | std=[0.229, 0.224, 0.225])
163 |
164 | val_transform = transforms.Compose([
165 | transforms.Resize(224),
166 | transforms.CenterCrop(224),
167 | transforms.ToTensor(),
168 | normalize
169 | ])
170 |
171 | train_transform = transforms.Compose([
172 | transforms.RandomApply([
173 | transforms.ColorJitter(0.4, 0.4, 0.2, 0.1) # not strengthened
174 | ], p=0.8),
175 | transforms.RandomGrayscale(p=0.2),
176 | transforms.RandomApply([GaussianBlur([.1, 2.])], p=0.1),
177 | transforms.RandomApply([Solarize()], p=0.2),
178 | transforms.ToTensor(),
179 | normalize
180 | ])
181 |
182 | train_dataset = get_dataset(train_transform, tokenizer, args)
183 | cwd = os.path.dirname(os.path.realpath(__file__))
184 | with open(os.path.join(cwd, 'dataset_catalog.json')) as f:
185 | root = json.load(f)['imagenet']['path']
186 | val_dataset = ImageFolder(os.path.join(root, 'val'), val_transform)
187 |
188 | # dist eval resamples data to pad uneven batch sizes
189 | # make sure num_samples = 0 mod num_gpus for exact acc
190 | if args.distributed:
191 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
192 | val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)
193 | else:
194 | train_sampler = None
195 | val_sampler = None
196 |
197 | train_loader = torch.utils.data.DataLoader(
198 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
199 | num_workers=args.workers, pin_memory=True, sampler=train_sampler, drop_last=True)
200 |
201 | val_loader = torch.utils.data.DataLoader(
202 | val_dataset, batch_size=args.batch_size, shuffle=(val_sampler is None),
203 | num_workers=args.workers, pin_memory=True, sampler=val_sampler, drop_last=False)
204 |
205 | return train_loader, train_sampler, val_loader
206 |
207 | def main(args):
208 | utils.init_distributed_mode(args)
209 | cudnn.benchmark = True
210 |
211 | args.best_acc = 0
212 |
213 | # fix the seed for reproducibility
214 | seed = args.seed + utils.get_rank()
215 | torch.manual_seed(seed)
216 | np.random.seed(seed)
217 |
218 | # create model
219 | model = get_model(args)
220 |
221 | # define loss function (criterion) and optimizer
222 | criterion = ACLIPLoss(args.ssl_temp).cuda(args.gpu)
223 |
224 | optimizer = get_optim(args, model)
225 | scaler = amp.GradScaler(enabled=not args.disable_amp)
226 |
227 | load_ckpt(args, model, optimizer, scaler)
228 |
229 | # Data loading
230 | tokenizer = SimpleTokenizer()
231 | train_loader, train_sampler, val_loader = get_loader(args, tokenizer)
232 |
233 | if args.evaluate:
234 | zero_stats = validate_zeroshot(val_loader, model, tokenizer, args)
235 | if utils.is_main_process():
236 | with open(os.path.join(args.output_dir, 'eval_log.txt'), 'a') as f:
237 | f.write(json.dumps(zero_stats) + '\n')
238 | return
239 |
240 | lr_schedule = utils.cosine_scheduler(args.lr, args.lr_end, args.epochs,
241 | len(train_loader) // args.update_freq, warmup_epochs=args.warmup_epochs, start_warmup_value=args.lr_start)
242 |
243 | momentum_schedule = utils.cosine_scheduler(args.momentum_ema, 1, args.epochs, len(train_loader), 0)
244 |
245 | if utils.is_main_process() and args.wandb:
246 | wandb_id = os.path.split(args.output_dir)[-1]
247 | wandb.init(project='ACLIP', id=wandb_id, config=args, resume='resume')
248 |
249 | print(args)
250 |
251 | print("=> beginning training")
252 | for epoch in range(args.start_epoch, args.epochs):
253 | if args.distributed:
254 | train_sampler.set_epoch(epoch)
255 |
256 | # train for one epoch
257 |         train_stats = train(train_loader, model, criterion, optimizer, scaler, epoch, lr_schedule, momentum_schedule, args)
258 |
259 | if (epoch + 1) % args.eval_freq != 0:
260 | continue
261 |
262 | val_stats = validate_zeroshot(val_loader, model, tokenizer, args)
263 | ema_val_stats = validate_zeroshot(val_loader, model, tokenizer, args, ema=True)
264 | acc1 = val_stats['acc1']
265 |
266 | is_best = acc1 > args.best_acc
267 | args.best_acc = max(acc1, args.best_acc)
268 |
269 | print("=> saving checkpoint")
270 | utils.save_on_master({
271 | 'epoch': epoch + 1,
272 | 'state_dict': model.state_dict(),
273 | 'optimizer' : optimizer.state_dict(),
274 | 'scaler': scaler.state_dict(),
275 | 'best_acc': args.best_acc,
276 | 'args': args,
277 |         }, is_best, args.output_dir, epoch + 1, args.epochs)
278 |
279 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
280 | **{f'test_{k}': v for k, v in val_stats.items()},
281 | **{f'ema_test_{k}': v for k, v in ema_val_stats.items()},
282 | 'epoch': epoch}
283 |
284 | if utils.is_main_process():
285 | if args.wandb:
286 | wandb.log(log_stats)
287 | with open(os.path.join(args.output_dir, 'log.txt'), 'a') as f:
288 | f.write(json.dumps(log_stats) + '\n')
289 |
290 |
291 | def train(train_loader, model, criterion, optimizer, scaler, epoch, lr_schedule, momentum_schedule, args):
292 | batch_time = AverageMeter('Time', ':6.2f')
293 | data_time = AverageMeter('Data', ':6.2f')
294 | mem = AverageMeter('Mem (GB)', ':6.1f')
295 | metric_names = get_metric_names()
296 | iters_per_epoch = len(train_loader) // args.update_freq
297 | metrics = OrderedDict([(name, AverageMeter(name, ':.2e')) for name in metric_names])
298 | progress = ProgressMeter(
299 | iters_per_epoch,
300 | [batch_time, data_time, mem, *metrics.values()],
301 | prefix="Epoch: [{}]".format(epoch))
302 |
303 | # switch to train mode
304 | model.train()
305 |
306 | end = time.time()
307 | for data_iter, inputs in enumerate(train_loader):
308 | optim_iter = data_iter // args.update_freq
309 | data_time.update(time.time() - end)
310 |
311 | # update weight decay and learning rate according to their schedule
312 | it = iters_per_epoch * epoch + optim_iter # global training iteration
313 | for k, param_group in enumerate(optimizer.param_groups):
314 | param_group['lr'] = lr_schedule[it]
315 |
316 | online_inputs = [torch.cat([inputs[0][0], inputs[0][1]], dim=0), inputs[0][2], inputs[2]]
317 | online_inputs = [tensor.cuda(args.gpu, non_blocking=True) for tensor in online_inputs]
318 | positions = inputs[1]
319 |
320 | m = momentum_schedule[it] # momentum parameter
321 | # compute output
322 | with amp.autocast(enabled=not args.disable_amp):
323 | outputs = model(*online_inputs, positions, m)
324 | loss_dict = criterion(outputs)
325 |
326 | loss = loss_dict['loss']
327 | loss /= args.update_freq
328 |
329 | if not math.isfinite(loss.item()):
330 | print("Loss is {}, stopping training".format(loss.item()))
331 | sys.exit(1)
332 |
333 | scaler.scale(loss).backward()
334 |
335 | if (data_iter + 1) % args.update_freq != 0:
336 | continue
337 |
338 | # compute gradient and do SGD step
339 | scaler.step(optimizer)
340 | scaler.update()
341 | model.zero_grad(set_to_none=True)
342 |
343 |         # clamp the log temperature to [0, ln(100) ~ 4.6052] so exp(logit_scale) stays in [1, 100]
344 | logit_scale_e = 0
345 |
346 | utils.get_model(model).logit_scale.data.clamp_(0, 4.6052)
347 | if hasattr(utils.get_model(model),'logit_scale_e'):
348 | utils.get_model(model).logit_scale_e.data.clamp_(0, 4.6052)
349 | logit_scale_e = utils.get_model(model).logit_scale_e.exp().item()
350 |
351 | logit_scale = utils.get_model(model).logit_scale.exp().item()
352 |
353 | for k in loss_dict:
354 | metrics[k].update(loss_dict[k].item(), args.batch_size)
355 |
356 | # measure elapsed time
357 | batch_time.update(time.time() - end)
358 |
359 | end = time.time()
360 |
361 | mem.update(torch.cuda.max_memory_allocated() // 1e9)
362 |
363 | if optim_iter % args.print_freq == 0:
364 | if utils.is_main_process() and args.wandb:
365 | wandb.log({**{k: v.item() for k, v in loss_dict.items()},
366 | 'scaler': scaler.get_scale(),
367 | 'logit': logit_scale,
368 | 'logit_e': logit_scale_e,
369 | })
370 | progress.display(optim_iter)
371 |
372 | progress.synchronize()
373 | return {**{k: v.avg for k, v in metrics.items()},
374 | 'lr': optimizer.param_groups[0]['lr'],
375 | 'logit_scale': logit_scale}
376 |
377 |
378 | def validate_zeroshot(val_loader, model, tokenizer, args, ema=False):
379 | batch_time = AverageMeter('Time', ':6.3f')
380 | top1 = AverageMeter('Acc@1', ':6.2f')
381 | top5 = AverageMeter('Acc@5', ':6.2f')
382 | progress = ProgressMeter(
383 | len(val_loader),
384 | [batch_time, top1, top5],
385 | prefix='Test: ')
386 |
387 | # switch to evaluate mode
388 | model.eval()
389 |
390 | print('=> encoding captions')
391 | cwd = os.path.dirname(os.path.realpath(__file__))
392 | with open(os.path.join(cwd, 'templates.json')) as f:
393 | templates = json.load(f)['imagenet']
394 |
395 | with open(os.path.join(cwd, 'labels.json')) as f:
396 | labels = json.load(f)['imagenet']
397 |
398 | with torch.no_grad():
399 | text_features = []
400 | for l in labels:
401 | texts = [t.format(l) for t in templates]
402 | texts = tokenizer(texts).cuda(args.gpu, non_blocking=True)
403 | class_embeddings = utils.get_model(model).encode_text(texts, ema=ema)
404 | class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True)
405 | class_embeddings = class_embeddings.mean(dim=0)
406 | class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True)
407 | text_features.append(class_embeddings)
408 | text_features = torch.stack(text_features, dim=0)
409 |
410 | end = time.time()
411 | for i, (images, target) in enumerate(val_loader):
412 | images = images.cuda(args.gpu, non_blocking=True)
413 | target = target.cuda(args.gpu, non_blocking=True)
414 |
415 | # encode images
416 | image_features = utils.get_model(model).encode_image(images, ema=ema)
417 | image_features = image_features / image_features.norm(dim=-1, keepdim=True)
418 |
419 | # cosine similarity as logits
420 | logits_per_image = image_features @ text_features.t()
421 |
422 | # measure accuracy and record loss
423 | acc1, acc5 = accuracy(logits_per_image, target, topk=(1, 5))
424 | acc1, acc5 = utils.scaled_all_reduce([acc1, acc5])
425 | top1.update(acc1.item(), images.size(0))
426 | top5.update(acc5.item(), images.size(0))
427 |
428 | # measure elapsed time
429 | batch_time.update(time.time() - end)
430 | end = time.time()
431 |
432 | if i % args.print_freq == 0:
433 | progress.display(i)
434 |
435 | progress.synchronize()
436 | print('0-shot * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
437 | .format(top1=top1, top5=top5))
438 | return {'acc1': top1.avg, 'acc5': top5.avg}
439 |
440 |
441 | if __name__ == '__main__':
442 | parser = argparse.ArgumentParser('A-CLIP training and evaluation', parents=[get_args_parser()])
443 | args = parser.parse_args()
444 | if args.output_dir:
445 | os.makedirs(args.output_dir, exist_ok=True)
446 | main(args)
447 |
448 |
449 |
450 |
--------------------------------------------------------------------------------
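Note on the schedules used in main.py: lr_schedule and momentum_schedule are one-value-per-iteration arrays returned by utils.cosine_scheduler and indexed with it = iters_per_epoch * epoch + optim_iter inside train(). utils.py is not reproduced here; the sketch below is a common DINO-style implementation of such a scheduler (linear warmup followed by cosine decay) and is only illustrative — the repo's own utils.cosine_scheduler may differ in detail.

    import numpy as np

    def cosine_scheduler(base_value, final_value, epochs, niter_per_ep,
                         warmup_epochs=0, start_warmup_value=0):
        # linear warmup from start_warmup_value to base_value, then cosine decay to final_value
        warmup_iters = int(warmup_epochs * niter_per_ep)
        warmup = np.linspace(start_warmup_value, base_value, warmup_iters)
        iters = np.arange(epochs * niter_per_ep - warmup_iters)
        cosine = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters)))
        schedule = np.concatenate((warmup, cosine))
        assert len(schedule) == epochs * niter_per_ep
        return schedule

    # e.g. 25 epochs at 1000 optimizer steps per epoch, 1 warmup epoch (illustrative numbers)
    lr_schedule = cosine_scheduler(5e-4, 1e-5, 25, 1000, warmup_epochs=1)
    print(lr_schedule[0], lr_schedule[999], lr_schedule[-1])  # 0.0, 5e-4, ~1e-5

The momentum schedule in main() is built the same way but runs from args.momentum_ema up to 1, so the EMA encoder's update rate shrinks over training.
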
/main_linear.py:
--------------------------------------------------------------------------------
1 | # Based on SLIP code bases
2 | # https://github.com/facebookresearch/SLIP
3 | # --------------------------------------------------------
4 | import argparse
5 | import json
6 | import math
7 | import os
8 | import random
9 | import shutil
10 | import time
11 | import timm
12 | import warnings
13 |
14 | import torch
15 | import torch.nn as nn
16 | import torch.nn.parallel
17 | import torch.backends.cudnn as cudnn
18 | import torch.optim
19 | import torch.utils.data
20 | import torch.utils.data.distributed
21 | import torchvision.transforms as transforms
22 |
23 | import datasets
24 | import utils
25 |
26 |
27 | def get_args_parser():
28 | parser = argparse.ArgumentParser(description='Linear probe evaluation', add_help=False)
29 | parser.add_argument('--dataset', default='imagenet', help='dataset name')
30 | parser.add_argument('--output-dir', default='./', type=str)
31 | parser.add_argument('-a', '--arch', metavar='ARCH', default='vit_base_patch16_224',
32 | help='model architecture: (default: ViT-B/16)')
33 | parser.add_argument('-j', '--workers', default=64, type=int, metavar='N',
34 | help='number of data loading workers (default: 64)')
35 | parser.add_argument('--epochs', default=90, type=int, metavar='N',
36 | help='number of total epochs to run')
37 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
38 | help='manual epoch number (useful on restarts)')
39 | parser.add_argument('-b', '--batch-size', default=128, type=int,
40 | metavar='N',
41 |                         help='number of samples per device (per GPU)')
42 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
43 | metavar='LR', help='initial (base) learning rate', dest='lr')
44 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
45 | help='momentum')
46 | parser.add_argument('--wd', '--weight-decay', default=0., type=float,
47 | metavar='W', help='weight decay (default: 0.)',
48 | dest='weight_decay')
49 | parser.add_argument('-p', '--print-freq', default=10, type=int,
50 | metavar='N', help='print frequency (default: 10)')
51 | parser.add_argument('--eval-freq', default=10, type=int)
52 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
53 | help='path to latest checkpoint (default: none)')
54 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
55 | help='evaluate model on validation set')
56 | parser.add_argument('--world-size', default=1, type=int,
57 | help='number of nodes for distributed training')
58 | parser.add_argument('--rank', default=0, type=int,
59 | help='node rank for distributed training')
60 | parser.add_argument("--local_rank", type=int, default=0)
61 | parser.add_argument('--dist-url', default='env://', type=str,
62 | help='url used to set up distributed training')
63 | parser.add_argument('--dist-backend', default='nccl', type=str,
64 | help='distributed backend')
65 | parser.add_argument('--seed', default=None, type=int,
66 | help='seed for initializing training. ')
67 | parser.add_argument('--gpu', default=None, type=int,
68 | help='GPU id to use.')
69 | parser.add_argument('--pretrained', default='', type=str,
70 | help='path to CLIP pretrained checkpoint')
71 | return parser
72 |
73 | best_acc1 = 0
74 |
75 |
76 | def main(args):
77 | utils.init_distributed_mode(args)
78 |
79 | global best_acc1
80 |
81 | if args.seed is not None:
82 | random.seed(args.seed)
83 | torch.manual_seed(args.seed)
84 | cudnn.deterministic = True
85 | warnings.warn('You have chosen to seed training. '
86 | 'This will turn on the CUDNN deterministic setting, '
87 | 'which can slow down your training considerably! '
88 | 'You may see unexpected behavior when restarting '
89 | 'from checkpoints.')
90 |
91 | linear_keyword = 'head'
92 | if os.path.isfile(args.pretrained):
93 | print("=> loading checkpoint '{}'".format(args.pretrained))
94 |
95 | if args.gpu is None:
96 | checkpoint = torch.load(args.pretrained)
97 | else:
98 | # Map model to be loaded to specified single gpu.
99 | loc = 'cuda:{}'.format(args.gpu)
100 | checkpoint = torch.load(args.pretrained, map_location=loc)
101 |
102 | visual_keyword = 'module.visual.'
103 |
104 | # rename CLIP pre-trained keys
105 | state_dict = checkpoint['state_dict']
106 | for k in list(state_dict.keys()):
107 | # retain only base_encoder up to before the embedding layer
108 | if k.startswith(visual_keyword) and not k.startswith(visual_keyword + linear_keyword):
109 | # remove prefix
110 | state_dict[k[len(visual_keyword):]] = state_dict[k]
111 | # delete renamed or unused k
112 | del state_dict[k]
113 | else:
114 | raise Exception('Missing pretrained model checkpoint: {}'.format(args.pretrained))
115 |
116 | # create model
117 | print("=> creating model '{}'".format(args.arch))
118 | model = timm.models.create_model(args.arch, num_classes=1000)
119 |
120 | args.start_epoch = 0
121 | msg = model.load_state_dict(state_dict, strict=False)
122 | assert set(msg.missing_keys) == {"%s.weight" % linear_keyword, "%s.bias" % linear_keyword}
123 |
124 | # freeze all layers but the last fc
125 | for name, param in model.named_parameters():
126 | if name not in ['%s.weight' % linear_keyword, '%s.bias' % linear_keyword]:
127 | param.requires_grad = False
128 | # init the fc layer
129 | getattr(model, linear_keyword).weight.data.normal_(mean=0.0, std=0.01)
130 | getattr(model, linear_keyword).bias.data.zero_()
131 |
132 | init_lr = args.lr * int(args.batch_size / utils.get_world_size()) / 256
133 | args.workers = int((args.workers + utils.get_world_size() - 1) / utils.get_world_size())
134 |
135 | model.cuda(args.gpu)
136 |
137 | if args.distributed:
138 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
139 |
140 | # define loss function (criterion) and optimizer
141 | criterion = nn.CrossEntropyLoss().cuda(args.gpu)
142 |
143 | # optimize only the linear classifier
144 | parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
145 | assert len(parameters) == 2 # weight, bias
146 |
147 | optimizer = torch.optim.SGD(parameters, init_lr,
148 | momentum=args.momentum,
149 | weight_decay=args.weight_decay)
150 |
151 | # optionally resume from a checkpoint
152 | if args.resume:
153 | if os.path.isfile(args.resume):
154 | print("=> loading checkpoint '{}'".format(args.resume))
155 | if args.gpu is None:
156 | checkpoint = torch.load(args.resume)
157 | else:
158 | # Map model to be loaded to specified single gpu.
159 | loc = 'cuda:{}'.format(args.gpu)
160 | checkpoint = torch.load(args.resume, map_location=loc)
161 | args.start_epoch = checkpoint['epoch']
162 | best_acc1 = checkpoint['best_acc1']
163 | if args.gpu is not None:
164 | # best_acc1 may be from a checkpoint from a different GPU
165 | best_acc1 = best_acc1.to(args.gpu)
166 | model.load_state_dict(checkpoint['state_dict'])
167 | optimizer.load_state_dict(checkpoint['optimizer'])
168 | print("=> loaded checkpoint '{}' (epoch {})"
169 | .format(args.resume, checkpoint['epoch']))
170 | else:
171 | print("=> no checkpoint found at '{}'".format(args.resume))
172 |
173 | cudnn.benchmark = True
174 |
175 | # Data loading code
176 | cwd = os.path.dirname(os.path.realpath(__file__))
177 | with open(os.path.join(cwd, 'dataset_catalog.json')) as f:
178 | catalog = json.load(f)
179 |
180 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
181 | std=[0.229, 0.224, 0.225])
182 |
183 | train_transform = transforms.Compose([
184 | transforms.RandomResizedCrop(224),
185 | transforms.RandomHorizontalFlip(),
186 | lambda x: x.convert('RGB'),
187 | transforms.ToTensor(),
188 | normalize,
189 | ])
190 | val_transform = transforms.Compose([
191 | transforms.Resize(256),
192 | transforms.CenterCrop(224),
193 | lambda x: x.convert('RGB'),
194 | transforms.ToTensor(),
195 | normalize,
196 | ])
197 |
198 | train_dataset = datasets.get_downstream_dataset(catalog, args.dataset, is_train=True, transform=train_transform)
199 | val_dataset = datasets.get_downstream_dataset(catalog, args.dataset, is_train=False, transform=val_transform)
200 |
201 | if args.distributed:
202 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
203 | else:
204 | train_sampler = None
205 |
206 | train_loader = torch.utils.data.DataLoader(
207 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
208 | num_workers=args.workers, pin_memory=True, sampler=train_sampler)
209 |
210 | val_loader = torch.utils.data.DataLoader(
211 | val_dataset, batch_size=256, shuffle=False,
212 | num_workers=args.workers, pin_memory=True)
213 |
214 | if args.evaluate:
215 | validate(val_loader, model, criterion, args)
216 | return
217 |
218 | print(args)
219 |
220 | for epoch in range(args.start_epoch, args.epochs):
221 | if args.distributed:
222 | train_sampler.set_epoch(epoch)
223 | adjust_learning_rate(optimizer, init_lr, epoch, args)
224 |
225 | # train for one epoch
226 | train_stats = train(train_loader, model, criterion, optimizer, epoch, args)
227 |
228 | if (epoch + 1) % args.eval_freq != 0:
229 | continue
230 |
231 | # evaluate on validation set
232 | val_stats = validate(val_loader, model, criterion, args)
233 | acc1 = val_stats['acc1']
234 |
235 | # remember best acc@1 and save checkpoint
236 | is_best = acc1 > best_acc1
237 | best_acc1 = max(acc1, best_acc1)
238 |
239 | if utils.is_main_process(): # only the first GPU saves checkpoint
240 | save_checkpoint({
241 | 'epoch': epoch + 1,
242 | 'arch': args.arch,
243 | 'state_dict': model.state_dict(),
244 | 'best_acc1': best_acc1,
245 | 'optimizer' : optimizer.state_dict(),
246 | }, is_best, args.output_dir)
247 | if epoch == args.start_epoch:
248 | sanity_check(model.state_dict(), args.pretrained, linear_keyword, visual_keyword)
249 |
250 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
251 | **{f'test_{k}': v for k, v in val_stats.items()},
252 | 'epoch': epoch}
253 |
254 | if utils.is_main_process():
255 | with open(os.path.join(args.output_dir, 'linear_{}_lr={}_log.txt'.format(args.dataset, args.lr)), 'a') as f:
256 | f.write(json.dumps(log_stats) + '\n')
257 |
258 |
259 | def train(train_loader, model, criterion, optimizer, epoch, args):
260 | batch_time = AverageMeter('Time', ':6.3f')
261 | data_time = AverageMeter('Data', ':6.3f')
262 | losses = AverageMeter('Loss', ':.4e')
263 | top1 = AverageMeter('Acc@1', ':6.2f')
264 | top5 = AverageMeter('Acc@5', ':6.2f')
265 | progress = ProgressMeter(
266 | len(train_loader),
267 | [batch_time, data_time, losses, top1, top5],
268 | prefix="Epoch: [{}]".format(epoch))
269 |
270 | """
271 | Switch to eval mode:
272 | Under the protocol of linear classification on frozen features/models,
273 | it is not legitimate to change any part of the pre-trained model.
274 | BatchNorm in train mode may revise running mean/std (even if it receives
275 | no gradient), which are part of the model parameters too.
276 | """
277 | model.eval()
278 |
279 | end = time.time()
280 | for i, (images, target) in enumerate(train_loader):
281 | # measure data loading time
282 | data_time.update(time.time() - end)
283 |
284 | if args.gpu is not None:
285 | images = images.cuda(args.gpu, non_blocking=True)
286 | if torch.cuda.is_available():
287 | target = target.cuda(args.gpu, non_blocking=True)
288 |
289 | # compute output
290 | output = model(images)
291 | loss = criterion(output, target)
292 |
293 | # measure accuracy and record loss
294 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
295 | losses.update(loss.item(), images.size(0))
296 | top1.update(acc1.item(), images.size(0))
297 | top5.update(acc5.item(), images.size(0))
298 |
299 | # compute gradient and do SGD step
300 | optimizer.zero_grad()
301 | loss.backward()
302 | optimizer.step()
303 |
304 | # measure elapsed time
305 | batch_time.update(time.time() - end)
306 | end = time.time()
307 |
308 | if i % args.print_freq == 0:
309 | progress.display(i)
310 |
311 | return {'acc1': top1.avg, 'acc5': top5.avg, 'loss': losses.avg}
312 |
313 |
314 | def validate(val_loader, model, criterion, args):
315 | batch_time = AverageMeter('Time', ':6.3f')
316 | losses = AverageMeter('Loss', ':.4e')
317 | top1 = AverageMeter('Acc@1', ':6.2f')
318 | top5 = AverageMeter('Acc@5', ':6.2f')
319 | progress = ProgressMeter(
320 | len(val_loader),
321 | [batch_time, losses, top1, top5],
322 | prefix='Test: ')
323 |
324 | # switch to evaluate mode
325 | model.eval()
326 |
327 | with torch.no_grad():
328 | end = time.time()
329 | for i, (images, target) in enumerate(val_loader):
330 | if args.gpu is not None:
331 | images = images.cuda(args.gpu, non_blocking=True)
332 | if torch.cuda.is_available():
333 | target = target.cuda(args.gpu, non_blocking=True)
334 |
335 | # compute output
336 | output = model(images)
337 | loss = criterion(output, target)
338 |
339 | # measure accuracy and record loss
340 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
341 | losses.update(loss.item(), images.size(0))
342 | top1.update(acc1.item(), images.size(0))
343 | top5.update(acc5.item(), images.size(0))
344 |
345 | # measure elapsed time
346 | batch_time.update(time.time() - end)
347 | end = time.time()
348 |
349 | if i % args.print_freq == 0:
350 | progress.display(i)
351 |
352 | # TODO: this should also be done with the ProgressMeter
353 | print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
354 | .format(top1=top1, top5=top5))
355 |
356 | return {'acc1': top1.avg, 'acc5': top5.avg, 'loss': losses.avg}
357 |
358 |
359 | def save_checkpoint(state, is_best, output_dir):
360 | ckpt_path = f'{output_dir}/linear_checkpoint.pt'
361 | best_path = f'{output_dir}/linear_best.pt'
362 | torch.save(state, ckpt_path)
363 | if is_best:
364 | shutil.copyfile(ckpt_path, best_path)
365 |
366 |
367 | def sanity_check(state_dict, pretrained_weights, linear_keyword, visual_keyword):
368 | """
369 | Linear classifier should not change any weights other than the linear layer.
370 | This sanity check asserts nothing wrong happens (e.g., BN stats updated).
371 | """
372 | print("=> loading '{}' for sanity check".format(pretrained_weights))
373 | checkpoint = torch.load(pretrained_weights, map_location="cpu")
374 | state_dict_pre = checkpoint['state_dict']
375 |
376 | for k in list(state_dict.keys()):
377 | # only ignore linear layer
378 | if '%s.weight' % linear_keyword in k or '%s.bias' % linear_keyword in k:
379 | continue
380 |
381 | # name in pretrained model
382 | k_pre = visual_keyword + k[len('module.'):] \
383 | if k.startswith('module.') else visual_keyword + k
384 |
385 | assert ((state_dict[k].cpu() == state_dict_pre[k_pre]).all()), \
386 | '{} is changed in linear classifier training.'.format(k)
387 |
388 | print("=> sanity check passed.")
389 |
390 |
391 | class AverageMeter(object):
392 | """Computes and stores the average and current value"""
393 | def __init__(self, name, fmt=':f'):
394 | self.name = name
395 | self.fmt = fmt
396 | self.reset()
397 |
398 | def reset(self):
399 | self.val = 0
400 | self.avg = 0
401 | self.sum = 0
402 | self.count = 0
403 |
404 | def update(self, val, n=1):
405 | self.val = val
406 | self.sum += val * n
407 | self.count += n
408 | self.avg = self.sum / self.count
409 |
410 | def __str__(self):
411 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
412 | return fmtstr.format(**self.__dict__)
413 |
414 |
415 | class ProgressMeter(object):
416 | def __init__(self, num_batches, meters, prefix=""):
417 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
418 | self.meters = meters
419 | self.prefix = prefix
420 |
421 | def display(self, batch):
422 | entries = [self.prefix + self.batch_fmtstr.format(batch)]
423 | entries += [str(meter) for meter in self.meters]
424 | print('\t'.join(entries))
425 |
426 | def _get_batch_fmtstr(self, num_batches):
427 |         num_digits = len(str(num_batches))
428 | fmt = '{:' + str(num_digits) + 'd}'
429 | return '[' + fmt + '/' + fmt.format(num_batches) + ']'
430 |
431 |
432 | def adjust_learning_rate(optimizer, init_lr, epoch, args):
433 | """Decay the learning rate based on schedule"""
434 | cur_lr = init_lr * 0.5 * (1. + math.cos(math.pi * epoch / args.epochs))
435 | for param_group in optimizer.param_groups:
436 | param_group['lr'] = cur_lr
437 |
438 |
439 | def accuracy(output, target, topk=(1,)):
440 | """Computes the accuracy over the k top predictions for the specified values of k"""
441 | with torch.no_grad():
442 | maxk = max(topk)
443 | batch_size = target.size(0)
444 |
445 | _, pred = output.topk(maxk, 1, True, True)
446 | pred = pred.t()
447 | correct = pred.eq(target.view(1, -1).expand_as(pred))
448 |
449 | res = []
450 | for k in topk:
451 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
452 | res.append(correct_k.mul_(100.0 / batch_size))
453 | return res
454 |
455 |
456 | if __name__ == '__main__':
457 | parser = argparse.ArgumentParser('Linear probe evaluation', parents=[get_args_parser()])
458 | args = parser.parse_args()
459 | main(args)
--------------------------------------------------------------------------------
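The accuracy() helper above implements the standard top-k metric used by both training scripts. A minimal, self-contained check (assuming the repo modules import cleanly — importing main_linear also pulls in timm, datasets and utils):

    import torch
    from main_linear import accuracy

    # 4 samples, 5 classes; targets chosen so top-1 is right for rows 0 and 3 only,
    # and row 1's target still appears among the top-3 predictions
    logits = torch.tensor([[0.10, 0.70, 0.10, 0.05, 0.05],
                           [0.60, 0.20, 0.15, 0.03, 0.02],
                           [0.50, 0.30, 0.10, 0.06, 0.04],
                           [0.05, 0.05, 0.10, 0.70, 0.10]])
    target = torch.tensor([1, 2, 4, 3])

    acc1, acc3 = accuracy(logits, target, topk=(1, 3))
    print(acc1.item(), acc3.item())  # 50.0 75.0
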
/make_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | import numpy as np
7 | import pickle
8 | import re
9 | from urllib.parse import unquote
10 | from tqdm import tqdm
11 |
12 |
13 | DATASET = 'yfcc100m_dataset.txt'
14 |
15 | cleanhtml = re.compile('<a.*?>|</a>|<b>|</b>|<i>|</i>')
16 | cleanurl = re.compile(r'http\S+|www\S+')
17 |
18 | print('=> loading YFCC image ids')
19 | image_ids = np.load('flickr_unique_ids.npy')
20 | image_ids = set(image_ids)
21 |
22 | print('=> loading CLIP image ids')
23 | clip_ids = set()
24 | with open('yfcc100m_subset_data.tsv') as f:
25 | for l in tqdm(f.readlines()):
26 | row = l.strip().split('\t')
27 | clip_ids.add(int(row[0]))
28 |
29 | print('=> collecting and cleaning subset captions')
30 | captioned = []
31 | uncaptioned = []
32 | with open('yfcc100m_dataset.txt') as f:
33 | for l in tqdm(f.readlines()):
34 | row = l.strip().split('\t')
35 | if int(row[0]) in image_ids:
36 | uncaptioned.append(int(row[0]))
37 | if int(row[0]) in clip_ids:
38 | title = unquote(row[8]).replace('+', ' ')
39 | title = re.sub(cleanhtml, '', title)
40 | title = re.sub(cleanurl, '', title)
41 |
42 | desc = unquote(row[9]).replace('+', ' ')
43 | desc = re.sub(cleanhtml, '', desc)
44 | desc = re.sub(cleanurl, '', desc)
45 |
46 | captioned.append((int(row[0]), title, desc))
47 |
48 | with open('yfcc15m.pkl', 'wb') as f:
49 | pickle.dump(captioned, f)
50 |
51 | with open('yfcc100m.pkl', 'wb') as f:
52 | pickle.dump(uncaptioned, f)
53 |
54 | print('Total captioned images:', len(captioned)) # 14689580
55 | print('Total uncaptioned images:', len(uncaptioned)) # 95920149
56 |
--------------------------------------------------------------------------------
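For reference, the two pickles written by make_dataset.py have a simple structure: yfcc15m.pkl is a list of (photo_id, title, description) tuples for the captioned subset, and yfcc100m.pkl is a list of photo ids. A quick way to inspect them (a sketch; the paths are relative to wherever the script was run):

    import pickle

    with open('yfcc15m.pkl', 'rb') as f:
        captioned = pickle.load(f)        # list of (photo_id, title, description) tuples

    with open('yfcc100m.pkl', 'rb') as f:
        uncaptioned = pickle.load(f)      # list of photo ids

    photo_id, title, desc = captioned[0]
    print(len(captioned), len(uncaptioned))
    print(photo_id, repr(title), repr(desc))
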
/models.py:
--------------------------------------------------------------------------------
1 |
2 | # Based on CLIP code bases
3 | # Modified from github.com/openai/CLIP
4 | # --------------------------------------------------------'
5 |
6 | from collections import OrderedDict
7 | import numpy as np
8 | import timm
9 | import torch
10 | from torch import nn
11 | import torchvision.transforms.functional_tensor as F_t
12 |
13 | from timm.models.vision_transformer import VisionTransformer
14 | from timm.models.registry import register_model
15 | from timm.models.vision_transformer import (
16 | default_cfgs,
17 | build_model_with_cfg,
18 | checkpoint_filter_fn,
19 | )
20 |
21 | from timm.models.layers import trunc_normal_ as __call_trunc_normal_
22 |
23 |
24 | def get_att_mask(attention, ratio=0.5):
25 | bs = attention.shape[0]
26 |     masks = torch.ones((bs, 49), dtype=torch.bool, device=attention.device)
27 | attention = attention.reshape((-1, 14, 14))
28 | attention = torch.nn.functional.interpolate(attention.unsqueeze(1), (7, 7), mode='bilinear').squeeze()
29 |     attention = attention.reshape(bs, -1)
30 | N = int(attention.shape[1] * ratio)
31 |
32 | reservation = torch.argsort(attention, descending=True)
33 |     reservation = reservation[:, :N + 1]
34 | masks = masks.scatter_(1, reservation, False)
35 |
36 | full_mask = torch.zeros((bs, 14, 14), dtype=torch.bool, device=attention.device)
37 | full_mask[:, 0::2, 0::2] = masks.reshape(bs, 7, 7)
38 | full_mask[:, 0::2, 1::2] = masks.reshape(bs, 7, 7)
39 | full_mask[:, 1::2, 0::2] = masks.reshape(bs, 7, 7)
40 | full_mask[:, 1::2, 1::2] = masks.reshape(bs, 7, 7)
41 | full_mask = full_mask.reshape(bs, -1)
42 |
43 | return full_mask
44 |
45 |
46 | def trunc_normal_(tensor, mean=0.0, std=1.0):
47 | __call_trunc_normal_(tensor, mean=mean, std=std, a=-std, b=std)
48 |
49 | class LayerNorm(nn.LayerNorm):
50 | """Subclass torch's LayerNorm to handle fp16."""
51 |
52 | def forward(self, x: torch.Tensor):
53 | orig_type = x.dtype
54 | ret = super().forward(x.type(torch.float32))
55 | return ret.type(orig_type)
56 |
57 |
58 | class QuickGELU(nn.Module):
59 | def forward(self, x: torch.Tensor):
60 | return x * torch.sigmoid(1.702 * x)
61 |
62 |
63 | class ResidualAttentionBlock(nn.Module):
64 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
65 | super().__init__()
66 |
67 | self.attn = nn.MultiheadAttention(d_model, n_head)
68 | self.ln_1 = LayerNorm(d_model)
69 | self.mlp = nn.Sequential(
70 | OrderedDict(
71 | [
72 | ("c_fc", nn.Linear(d_model, d_model * 4)),
73 | ("gelu", QuickGELU()),
74 | ("c_proj", nn.Linear(d_model * 4, d_model)),
75 | ]
76 | )
77 | )
78 | self.ln_2 = LayerNorm(d_model)
79 | self.attn_mask = attn_mask
80 |
81 | def attention(self, x: torch.Tensor):
82 | self.attn_mask = (
83 | self.attn_mask.to(dtype=x.dtype, device=x.device)
84 | if self.attn_mask is not None
85 | else None
86 | )
87 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
88 |
89 | def forward(self, x: torch.Tensor):
90 | x = x + self.attention(self.ln_1(x))
91 | x = x + self.mlp(self.ln_2(x))
92 | return x
93 |
94 |
95 | class Transformer(nn.Module):
96 | def __init__(
97 | self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None
98 | ):
99 | super().__init__()
100 | self.width = width
101 | self.layers = layers
102 | self.resblocks = nn.Sequential(
103 | *[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]
104 | )
105 |
106 | def forward(self, x: torch.Tensor):
107 | return self.resblocks(x)
108 |
109 |
110 | class CLIP(nn.Module):
111 | def __init__(
112 | self,
113 | embed_dim: int,
114 | # vision
115 | vision_width: int,
116 | vision_model: nn.Module,
117 | # text
118 | context_length: int,
119 | vocab_size: int,
120 | transformer_width: int,
121 | transformer_heads: int,
122 | transformer_layers: int,
123 | **kwargs,
124 | ):
125 | super().__init__()
126 |
127 | self.context_length = context_length
128 | self.vision_width = vision_width
129 |
130 | self.visual = vision_model
131 |
132 | self.transformer = Transformer(
133 | width=transformer_width,
134 | layers=transformer_layers,
135 | heads=transformer_heads,
136 | attn_mask=self.build_attention_mask(),
137 | )
138 |
139 | self.vocab_size = vocab_size
140 | self.token_embedding = nn.Embedding(vocab_size, transformer_width)
141 | self.positional_embedding = nn.Parameter(
142 | torch.empty(self.context_length, transformer_width)
143 | )
144 | self.ln_final = LayerNorm(transformer_width)
145 |
146 | self.image_projection = nn.Parameter(torch.empty(vision_width, embed_dim))
147 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
148 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
149 |
150 | self.initialize_parameters()
151 |
152 | def initialize_parameters(self):
153 | nn.init.normal_(self.token_embedding.weight, std=0.02)
154 | nn.init.normal_(self.positional_embedding, std=0.01)
155 |
156 | proj_std = (self.transformer.width**-0.5) * (
157 | (2 * self.transformer.layers) ** -0.5
158 | )
159 | attn_std = self.transformer.width**-0.5
160 | fc_std = (2 * self.transformer.width) ** -0.5
161 | for block in self.transformer.resblocks:
162 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
163 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
164 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
165 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
166 |
167 | nn.init.normal_(self.image_projection, std=self.vision_width**-0.5)
168 | nn.init.normal_(self.text_projection, std=self.transformer.width**-0.5)
169 |
170 | def build_attention_mask(self):
171 | # lazily create causal attention mask, with full attention between the vision tokens
172 | # pytorch uses additive attention mask; fill with -inf
173 | mask = torch.empty(self.context_length, self.context_length)
174 | mask.fill_(float("-inf"))
175 | mask.triu_(1) # zero out the lower diagonal
176 | return mask
177 |
178 | def encode_image(self, image):
179 | x = self.visual(image)
180 | x = x @ self.image_projection
181 |
182 | return x
183 |
184 | def encode_text(self, text):
185 | x = self.token_embedding(text) # [batch_size, n_ctx, d_model]
186 | x = x + self.positional_embedding
187 | x = x.permute(1, 0, 2) # NLD -> LND
188 | x = self.transformer(x)
189 | x = x.permute(1, 0, 2) # LND -> NLD
190 | x = self.ln_final(x)
191 |
192 | # x.shape = [batch_size, n_ctx, transformer.width]
193 | # take features from the eot embedding (eot_token is the highest number in each sequence)
194 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
195 |
196 | return x
197 |
198 | def forward(self, image, text):
199 | image_embed = self.encode_image(image)
200 | text_embed = self.encode_text(text)
201 |
202 | return {
203 | "image_embed": image_embed,
204 | "text_embed": text_embed,
205 | "logit_scale": self.logit_scale.exp(),
206 | }
207 |
208 |
209 | class Proj_Head(nn.Module):
210 | def __init__(self, in_channels, mlp_hidden_size, projection_size):
211 | super().__init__()
212 |
213 | self.net = nn.Sequential(
214 | nn.Linear(in_channels, mlp_hidden_size, bias=False),
215 | LayerNorm(mlp_hidden_size),
216 | nn.GELU(),
217 | nn.Linear(mlp_hidden_size, mlp_hidden_size, bias=False),
218 | LayerNorm(mlp_hidden_size),
219 | nn.GELU(),
220 | nn.Linear(mlp_hidden_size, projection_size, bias=False),
221 | LayerNorm(projection_size),
222 | )
223 |
224 | def forward(self, x):
225 | return self.net(x)
226 |
227 |
228 | class Pred_Head(nn.Module):
229 | def __init__(self, dim, mlp_hidden_size, projection_size):
230 | super().__init__()
231 |
232 | self.net = nn.Sequential(
233 | nn.Linear(dim, mlp_hidden_size),
234 | LayerNorm(mlp_hidden_size),
235 | nn.GELU(),
236 | nn.Linear(mlp_hidden_size, projection_size),
237 | )
238 |
239 | def forward(self, x):
240 | return self.net(x)
241 |
242 |
243 | class projection_MLP(nn.Module):
244 | def __init__(self, in_dim, hidden_dim=2048, out_dim=2048):
245 | super().__init__()
246 | """ page 3 baseline setting
247 | Projection MLP. The projection MLP (in f) has BN ap-
248 | plied to each fully-connected (fc) layer, including its out-
249 | put fc. Its output fc has no ReLU. The hidden fc is 2048-d.
250 | This MLP has 3 layers.
251 | """
252 | self.layer1 = nn.Sequential(
253 | nn.Linear(in_dim, hidden_dim),
254 | nn.BatchNorm1d(hidden_dim),
255 | nn.ReLU(inplace=True),
256 | )
257 | self.layer2 = nn.Sequential(
258 | nn.Linear(hidden_dim, hidden_dim),
259 | nn.BatchNorm1d(hidden_dim),
260 | nn.ReLU(inplace=True),
261 | )
262 | self.layer3 = nn.Sequential(
263 |             nn.Linear(hidden_dim, out_dim), nn.BatchNorm1d(out_dim)
264 | )
265 | self.num_layers = 3
266 |
267 | def set_layers(self, num_layers):
268 | self.num_layers = num_layers
269 |
270 | def forward(self, x):
271 | if self.num_layers == 3:
272 | x = self.layer1(x)
273 | x = self.layer2(x)
274 | x = self.layer3(x)
275 | elif self.num_layers == 2:
276 | x = self.layer1(x)
277 | x = self.layer3(x)
278 | else:
279 |             raise Exception('num_layers must be 2 or 3, got {}'.format(self.num_layers))
280 | return x
281 |
282 |
283 | class prediction_MLP(nn.Module):
284 | def __init__(
285 | self, in_dim=2048, hidden_dim=512, out_dim=2048
286 | ): # bottleneck structure
287 | super().__init__()
288 | """ page 3 baseline setting
289 | Prediction MLP. The prediction MLP (h) has BN applied
290 | to its hidden fc layers. Its output fc does not have BN
291 | (ablation in Sec. 4.4) or ReLU. This MLP has 2 layers.
292 | The dimension of h’s input and output (z and p) is d = 2048,
293 | and h’s hidden layer’s dimension is 512, making h a
294 | bottleneck structure (ablation in supplement).
295 | """
296 | self.layer1 = nn.Sequential(
297 | nn.Linear(in_dim, hidden_dim),
298 | nn.BatchNorm1d(hidden_dim),
299 | nn.ReLU(inplace=True),
300 | )
301 | self.layer2 = nn.Linear(hidden_dim, out_dim)
302 | """
303 | Adding BN to the output of the prediction MLP h does not work
304 | well (Table 3d). We find that this is not about collapsing.
305 | The training is unstable and the loss oscillates.
306 | """
307 |
308 | def forward(self, x):
309 | x = self.layer1(x)
310 | x = self.layer2(x)
311 | return x
312 |
313 |
314 | class ACLIP(CLIP):
315 | def __init__(
316 | self,
317 | **kwargs,
318 | ):
319 | super().__init__(**kwargs)
320 | self.visual_ema = kwargs['vision_model_ema']
321 | vision_width = kwargs["vision_width"]
322 | embed_dim = kwargs['embed_dim']
323 |
324 | self.image_mlp = self._build_mlp(vision_width, 4096, 256) #for simclr
325 | self.im_proj_byol = self._build_mlp_byol(2, vision_width, 4096, 256, False)
326 | self.im_pred_byol = self._build_mlp_byol(2, 256, 4096, 256, False)
327 |
328 | transformer_width = kwargs['transformer_width']
329 | transformer_layers = kwargs['transformer_layers']
330 | transformer_heads = kwargs['transformer_heads']
331 | # text ema
332 | self.transformer_e = Transformer(
333 | width=transformer_width,
334 | layers=transformer_layers,
335 | heads=transformer_heads,
336 | attn_mask=self.build_attention_mask(),
337 | )
338 | self.image_projection_e = nn.Parameter(torch.empty(vision_width, embed_dim))
339 | self.text_projection_e = nn.Parameter(torch.empty(transformer_width, embed_dim))
340 |
341 | self.im_proj_byol_e = self._build_mlp_byol(2, vision_width, 4096, 256, False)
342 |
343 | for param_m, param_b in zip(self.visual_ema.parameters(), self.visual.parameters()):
344 | param_m.data.copy_(param_b.data) # initialize
345 | param_m.requires_grad = False # not update by gradient
346 | for param_m, param_b in zip(self.transformer_e.parameters(), self.transformer.parameters()):
347 | param_m.data.copy_(param_b.data) # initialize
348 | param_m.requires_grad = False # not update by gradient
349 | for param_m, param_b in zip(self.im_proj_byol_e.parameters(), self.im_proj_byol.parameters()):
350 | param_m.data.copy_(param_b.data) # initialize
351 | param_m.requires_grad = False # not update by gradient
352 |
353 | self.image_projection_e.requires_grad = False
354 | self.image_projection_e.data.copy_(self.image_projection.data)
355 | self.text_projection_e.requires_grad = False
356 | self.text_projection_e.data.copy_(self.text_projection.data)
357 |
358 |
359 | @torch.no_grad()
360 | def _update_momentum_encoder(self, m):
361 | """Momentum update of the momentum encoder"""
362 | for param_b, param_m in zip(
363 | self.visual.parameters(), self.visual_ema.parameters()
364 | ):
365 | param_m.data = param_m.data * m + param_b.data * (1.0 - m)
366 |
367 | for param_b, param_m in zip(
368 | self.transformer.parameters(), self.transformer_e.parameters()
369 | ):
370 | param_m.data = param_m.data * m + param_b.data * (1.0 - m)
371 | for param_b, param_m in zip(
372 | self.im_proj_byol.parameters(), self.im_proj_byol_e.parameters()
373 | ):
374 | param_m.data = param_m.data * m + param_b.data * (1.0 - m)
375 |
376 |         self.image_projection_e.data = self.image_projection_e.data * m + self.image_projection.data * (1.0 - m)
377 |         self.text_projection_e.data = self.text_projection_e.data * m + self.text_projection.data * (1.0 - m)
378 |
379 | def _build_mlp_byol(self, num_layers, input_dim, mlp_dim, output_dim, last_bn=True):
380 | mlp = []
381 | for l in range(num_layers):
382 | dim1 = input_dim if l == 0 else mlp_dim
383 | dim2 = output_dim if l == num_layers - 1 else mlp_dim
384 |
385 | mlp.append(nn.Linear(dim1, dim2, bias=False))
386 |
387 | if l < num_layers - 1:
388 | mlp.append(nn.BatchNorm1d(dim2))
389 | mlp.append(nn.ReLU(inplace=True))
390 | elif last_bn:
391 | # follow SimCLR's design: https://github.com/google-research/simclr/blob/master/model_util.py#L157
392 | # for simplicity, we further removed gamma in BN
393 | mlp.append(nn.BatchNorm1d(dim2, affine=False))
394 |
395 | return nn.Sequential(*mlp)
396 |
397 | # from simclr
398 | def _build_mlp(self, in_dim, mlp_dim, out_dim):
399 | return nn.Sequential(
400 | OrderedDict(
401 | [
402 | ("layer1", nn.Linear(in_dim, mlp_dim)),
403 | ("bn1", nn.BatchNorm1d(mlp_dim)),
404 | ("relu1", nn.ReLU(inplace=True)),
405 | ("layer2", nn.Linear(mlp_dim, mlp_dim)),
406 | ("bn2", nn.BatchNorm1d(mlp_dim)),
407 | ("relu2", nn.ReLU(inplace=True)),
408 | ("layer3", nn.Linear(mlp_dim, out_dim)),
409 | ]
410 | )
411 | )
412 |
413 | def encode_image(self, image, mask=None, ret=False, ema=False):
414 |         if not ema:
415 | x, attn, _ = self.visual(image, mask=mask, need_attn=False)
416 | tokens = x
417 | x = x[:, 0] @ self.image_projection
418 | else:
419 | x, attn, _ = self.visual_ema(image, mask=mask, need_attn=True)
420 | tokens = x
421 | x = x[:, 0] @ self.image_projection_e
422 |
423 | if ret:
424 | return x, attn, tokens
425 | return x
426 |
427 | def encode_text(self, text, ema=False):
428 | if ema:
429 | text_projection = self.text_projection_e
430 | transformer = self.transformer_e
431 | else:
432 | text_projection = self.text_projection
433 | transformer = self.transformer
434 |
435 | x = self.token_embedding(text) # [batch_size, n_ctx, d_model]
436 | x = x + self.positional_embedding
437 | x = x.permute(1, 0, 2) # NLD -> LND
438 | x = transformer(x)
439 | x = x.permute(1, 0, 2) # LND -> NLD
440 | x = self.ln_final(x)
441 |
442 | # x.shape = [batch_size, n_ctx, transformer.width]
443 | # take features from the eot embedding (eot_token is the highest number in each sequence)
444 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ text_projection
445 |
446 | return x
447 |
448 | def get_mask(self, mask, positions, e_positions):
449 | # top, left, width, height = pos
450 |
451 | mask = mask.reshape((-1, 14, 14))
452 | cmask = []
453 |
454 | for i in range(mask.shape[0]):
455 | m = mask[i]
456 | m = m.unsqueeze(0)
457 | m = m.unsqueeze(0)
458 | o_pos = positions[i]
459 | e_pos = e_positions[i]
460 | m = torch.nn.functional.interpolate(m, (e_pos[2], e_pos[3]), mode='bilinear')
461 |
462 | top = o_pos[0] - e_pos[0]
463 | left = o_pos[1] - e_pos[1]
464 | m = F_t.crop(m, top, left, o_pos[2], o_pos[3])
465 | m = torch.nn.functional.interpolate(m, (14, 14), mode='bilinear')
466 | cmask.append(m)
467 |
468 | cmask = torch.stack(cmask).squeeze()
469 | cmask = cmask.reshape(mask.shape[0], -1)
470 | return cmask
471 |
472 | def forward(self, im1, im2, text, pos, momentum):
473 | with torch.no_grad():
474 | self._update_momentum_encoder(momentum)
475 | x, attn, _ = self.visual_ema(im2, need_attn=True)
476 | cls_token = x[:, 0]
477 | byol_feats_e = self.im_proj_byol_e(cls_token)
478 |
479 | attention_map = attn
480 |             attention_map_1 = self.get_mask(attention_map, pos[:, 0], pos[:, 2])
481 |             mask_1 = get_att_mask(attention_map_1)
482 |             attention_map_2 = self.get_mask(attention_map, pos[:, 1], pos[:, 2])
483 |             mask_2 = get_att_mask(attention_map_2)
484 |             mask = torch.cat([mask_1, mask_2], dim=0)
485 |
486 | image_embed, _, tokens = self.encode_image(im1, mask=mask, ret=True)
487 | bs = text.shape[0]
488 | cls_token = tokens[:, 0]
489 | image_ssl_embed = self.image_mlp(cls_token)
490 | byol_feats = self.im_proj_byol(cls_token)
491 | byol_feats = self.im_pred_byol(byol_feats)
492 | text_embed = self.encode_text(text)
493 |
494 | return {
495 | "image_embed": image_embed,
496 | "text_embed": text_embed,
497 | "image_ssl_embed": image_ssl_embed,
498 | "byol_feats": byol_feats,
499 | "byol_feats_e": byol_feats_e,
500 | "logit_scale": self.logit_scale.exp(),
501 | }
502 |
503 |
504 |
505 | def forward_attn(self, x):
506 | B, N, C = x.shape
507 | qkv = (
508 | self.qkv(x)
509 | .reshape(B, N, 3, self.num_heads, C // self.num_heads)
510 | .permute(2, 0, 3, 1, 4)
511 | )
512 | q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
513 |
514 | attn = (q @ k.transpose(-2, -1)) * self.scale
515 | attn = attn.softmax(dim=-1)
516 | attn = self.attn_drop(attn)
517 |
518 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
519 | x = self.proj(x)
520 | x = self.proj_drop(x)
521 | return x, attn.detach()
522 |
523 |
524 | def forward_block(self, x):
525 | attn_x, attn = forward_attn(self.attn, self.norm1(x))
526 | x = x + self.drop_path(attn_x)
527 | x = x + self.drop_path(self.mlp(self.norm2(x)))
528 |
529 | return x, attn
530 |
531 | class MaskVisionTransformer(VisionTransformer):
532 | def __init__(self, mask_ratio=0, **kwargs):
533 | super(MaskVisionTransformer, self).__init__(**kwargs)
534 | self.mask_ratio = mask_ratio
535 | for param in self.patch_embed.proj.parameters():
536 | param.requires_grad = False
537 |
538 | def random_masking(self, x, mask_ratio):
539 | """
540 | Perform per-sample random masking by per-sample shuffling.
541 | Per-sample shuffling is done by argsort random noise.
542 | x: [N, L, D], sequence
543 | """
544 | N, L, D = x.shape # batch, length, dim
545 | len_keep = int(L * (1 - mask_ratio))
546 |
547 | noise = torch.rand(N, L, device=x.device) # noise in [0, 1]
548 |
549 | # sort noise for each sample
550 | ids_shuffle = torch.argsort(
551 | noise, dim=1
552 | ) # ascend: small is keep, large is remove
553 | ids_restore = torch.argsort(ids_shuffle, dim=1)
554 |
555 | # keep the first subset
556 | ids_keep = ids_shuffle[:, :len_keep]
557 | x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
558 |
559 | # generate the binary mask: 0 is keep, 1 is remove
560 | mask = torch.ones([N, L], device=x.device)
561 | mask[:, :len_keep] = 0
562 | # unshuffle to get the binary mask
563 | mask = torch.gather(mask, dim=1, index=ids_restore)
564 | return x_masked, mask, ids_restore
565 |
566 | def mask_model(self, x, mask):
567 | N, L, D = x.shape # batch, length, dim
568 | ids = torch.argsort(mask.long(), dim=1) # ascend
569 | mask_len = mask[0].sum()
570 | ids_keep = ids[:, : L - mask_len]
571 | x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
572 | return x_masked
573 |
574 | def forward_features(self, x, mask=None, need_attn=False):
575 | x = self.patch_embed(x)
576 |
577 | # add pos embed w/o cls token
578 | x = x + self.pos_embed[:, 1:, :]
579 | ids_restore = None
580 | if self.mask_ratio > 0 and self.training is True:
581 | if mask is None:
582 | x, mask, ids_restore = self.random_masking(x, self.mask_ratio)
583 | else:
584 | x = self.mask_model(x, mask)
585 |
586 | # x, mask , ids_restore = self.random_masking(x, self.mask_ratio)
587 | # add pos embed and cls token
588 | cls_token = self.cls_token + self.pos_embed[:, :1, :]
589 | cls_token = cls_token.expand(
590 | x.shape[0], -1, -1
591 | ) # stole cls_tokens impl from Phil Wang, thanks
592 |
593 | x = torch.cat((cls_token, x), dim=1)
594 | x = self.pos_drop(x)
595 | attn_list = []
596 | if need_attn:
597 | for b in self.blocks:
598 | x, attn_now = forward_block(b, x)
599 | attn_list.append(attn_now)
600 | attn = torch.stack(attn_list, dim=0)
601 | attn = torch.mean(attn, dim=0)
602 | attn = attn[:, :, 0, 1:].mean(1).detach().clone()
603 | x = self.norm(x)
604 | return x, attn, ids_restore
605 | else:
606 | x = self.blocks(x)
607 | x = self.norm(x)
608 | attn = None
609 | return x, attn, ids_restore
610 |
611 | def forward(self, x, mask=None, need_attn=False):
612 | x = self.forward_features(x, mask=mask, need_attn=need_attn)
613 | if self.head_dist is not None:
614 | x, x_dist = self.head(x[0]), self.head_dist(x[1]) # x must be a tuple
615 | if self.training and not torch.jit.is_scripting():
616 | # during inference, return the average of both classifier predictions
617 | return x, x_dist
618 | else:
619 | return (x + x_dist) / 2
620 | else:
621 | x = self.head(x)
622 | return x
623 |
624 | @register_model
625 | def mask_vit_small_patch16_224(pretrained=False, **kwargs):
626 |     """ViT-Small (ViT-S/16) from original paper (https://arxiv.org/abs/2010.11929).
627 | ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
628 | """
629 | model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=12, **kwargs)
630 | model = _create_vision_transformer(
631 | "vit_small_patch16_224",
632 | MaskVisionTransformer,
633 | pretrained=pretrained,
634 | **model_kwargs,
635 | )
636 | return model
637 |
638 | @register_model
639 | def mask_vit_base_patch16_224(pretrained=False, **kwargs):
640 | """ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
641 | ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
642 | """
643 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
644 | model = _create_vision_transformer(
645 | "vit_base_patch16_224",
646 | MaskVisionTransformer,
647 | pretrained=pretrained,
648 | **model_kwargs,
649 | )
650 | return model
651 |
652 | @register_model
653 | def mask_vit_large_patch16_224(pretrained=False, **kwargs):
654 | """ViT-Large (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
655 | ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
656 | """
657 | model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
658 | model = _create_vision_transformer(
659 | "vit_large_patch16_224",
660 | MaskVisionTransformer,
661 | pretrained=pretrained,
662 | **model_kwargs,
663 | )
664 | return model
665 |
666 |
667 | @register_model
668 | def mask_vit_base_patch32_224(pretrained=False, **kwargs):
669 | """ViT-Base (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
670 | ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
671 | """
672 | model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs)
673 | model = _create_vision_transformer(
674 | "vit_base_patch32_224",
675 | MaskVisionTransformer,
676 | pretrained=pretrained,
677 | **model_kwargs,
678 | )
679 | return model
680 |
681 | def _create_vision_transformer(
682 | variant,
683 | transformer=MaskVisionTransformer,
684 | pretrained=False,
685 | default_cfg=None,
686 | **kwargs,
687 | ):
688 | default_cfg = default_cfg or default_cfgs[variant]
689 | if kwargs.get("features_only", None):
690 | raise RuntimeError(
691 | "features_only not implemented for Vision Transformer models."
692 | )
693 |
694 | # NOTE this extra code to support handling of repr size for in21k pretrained models
695 | default_num_classes = default_cfg["num_classes"]
696 | num_classes = kwargs.get("num_classes", default_num_classes)
697 | repr_size = kwargs.pop("representation_size", None)
698 | if repr_size is not None and num_classes != default_num_classes:
699 | # Remove representation layer if fine-tuning. This may not always be the desired action,
700 |         # but it seems better than doing nothing by default when fine-tuning. Perhaps a better interface?
701 | print("Removing representation layer for fine-tuning.")
702 | repr_size = None
703 |
704 | model = build_model_with_cfg(
705 | transformer,
706 | variant,
707 | pretrained,
708 | default_cfg=default_cfg,
709 | representation_size=repr_size,
710 | pretrained_filter_fn=checkpoint_filter_fn,
711 | pretrained_custom_load="npz" in default_cfg["url"],
712 | **kwargs,
713 | )
714 | return model
715 |
716 |
717 | def ACLIP_VITB16(mask_ratio=0, **kwargs):
718 | vision_model = timm.create_model(
719 | "mask_vit_base_patch16_224", num_classes=0, mask_ratio=mask_ratio
720 | )
721 | vision_model_ema = timm.create_model(
722 | 'mask_vit_base_patch16_224', num_classes=0, mask_ratio=0
723 | )
724 | model = ACLIP(
725 | embed_dim=512,
726 | vision_width=768,
727 | vision_model=vision_model,
728 | vision_model_ema=vision_model_ema,
729 | context_length=77,
730 | vocab_size=49408,
731 | transformer_width=512,
732 | transformer_heads=8,
733 | transformer_layers=12,
734 | **kwargs,
735 | )
736 | return model
737 |
738 |
739 | def ACLIP_VITL16(mask_ratio=0, **kwargs):
740 | vision_model = timm.create_model(
741 | "mask_vit_large_patch16_224", num_classes=0, mask_ratio=mask_ratio
742 | )
743 | vision_model_ema = timm.create_model(
744 | 'mask_vit_large_patch16_224', num_classes=0, mask_ratio=0
745 | )
746 | model = ACLIP(
747 | embed_dim=512,
748 | vision_width=1024,
749 | vision_model=vision_model,
750 | vision_model_ema=vision_model_ema,
751 | context_length=77,
752 | vocab_size=49408,
753 | transformer_width=512,
754 | transformer_heads=8,
755 | transformer_layers=12,
756 | **kwargs,
757 | )
758 | return model
759 |
760 | # SMALL
761 | def ACLIP_VITS16(mask_ratio=0, **kwargs):
762 | vision_model = timm.create_model(
763 | "mask_vit_small_patch16_224", num_classes=0, mask_ratio=mask_ratio
764 | )
765 | vision_model_ema = timm.create_model(
766 |         "mask_vit_small_patch16_224", num_classes=0, mask_ratio=0
767 | )
768 | model = ACLIP(
769 | embed_dim=512,
770 | vision_width=384,
771 | vision_model=vision_model,
772 | vision_model_ema=vision_model_ema,
773 | context_length=77,
774 | vocab_size=49408,
775 | transformer_width=512,
776 | transformer_heads=8,
777 | transformer_layers=12,
778 | **kwargs,
779 | )
780 | return model
781 |
782 |
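As a quick sanity check of the factory path above, here is a minimal editorial sketch (assuming this models.py, timm, and torch are importable from the repository root and that the ACLIP constructor runs on CPU without extra arguments; the 0.5 mask ratio is only an illustrative value):

    import torch
    from models import ACLIP_VITB16

    # Build A-CLIP with a ViT-B/16 online encoder that drops 50% of its patch tokens;
    # the EMA encoder created inside the factory keeps mask_ratio=0 and sees the full image.
    model = ACLIP_VITB16(mask_ratio=0.5)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"A-CLIP ViT-B/16: {n_params / 1e6:.1f}M parameters")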
--------------------------------------------------------------------------------
/templates.json:
--------------------------------------------------------------------------------
1 | {
2 | "food101": [
3 | "a photo of {}, a type of food."
4 | ],
5 | "cifar10": [
6 | "a photo of a {}.",
7 | "a blurry photo of a {}.",
8 | "a black and white photo of a {}.",
9 | "a low contrast photo of a {}.",
10 | "a high contrast photo of a {}.",
11 | "a bad photo of a {}.",
12 | "a good photo of a {}.",
13 | "a photo of a small {}.",
14 | "a photo of a big {}.",
15 | "a photo of the {}.",
16 | "a blurry photo of the {}.",
17 | "a black and white photo of the {}.",
18 | "a low contrast photo of the {}.",
19 | "a high contrast photo of the {}.",
20 | "a bad photo of the {}.",
21 | "a good photo of the {}.",
22 | "a photo of the small {}.",
23 | "a photo of the big {}."
24 | ],
25 | "cifar100": [
26 | "a photo of a {}.",
27 | "a blurry photo of a {}.",
28 | "a black and white photo of a {}.",
29 | "a low contrast photo of a {}.",
30 | "a high contrast photo of a {}.",
31 | "a bad photo of a {}.",
32 | "a good photo of a {}.",
33 | "a photo of a small {}.",
34 | "a photo of a big {}.",
35 | "a photo of the {}.",
36 | "a blurry photo of the {}.",
37 | "a black and white photo of the {}.",
38 | "a low contrast photo of the {}.",
39 | "a high contrast photo of the {}.",
40 | "a bad photo of the {}.",
41 | "a good photo of the {}.",
42 | "a photo of the small {}.",
43 | "a photo of the big {}."
44 | ],
45 | "birdsnap": [
46 | "a photo of a {}, a type of bird."
47 | ],
48 | "cub200": [
49 | "a photo of a {}, a type of bird."
50 | ],
51 | "imagenet": [
52 | "itap of a {}.",
53 | "a bad photo of the {}.",
54 | "a origami {}.",
55 | "a photo of the large {}.",
56 | "a {} in a video game.",
57 | "art of the {}.",
58 | "a photo of the small {}."
59 | ],
60 | "rendered_sst2": [
61 | "a {} review of a movie."
62 | ],
63 | "hateful_memes": [
64 | "a {}."
65 | ],
66 | "clevr_counts": [
67 | "a photo of {} objects."
68 | ],
69 | "kinetics700_frames": [
70 | "a photo of {}.",
71 | "a photo of a person {}.",
72 | "a photo of a person using {}.",
73 | "a photo of a person doing {}.",
74 | "a photo of a person during {}.",
75 | "a photo of a person performing {}.",
76 | "a photo of a person practicing {}.",
77 | "a video of {}.",
78 | "a video of a person {}.",
79 | "a video of a person using {}.",
80 | "a video of a person doing {}.",
81 | "a video of a person during {}.",
82 | "a video of a person performing {}.",
83 | "a video of a person practicing {}.",
84 | "a example of {}.",
85 | "a example of a person {}.",
86 | "a example of a person using {}.",
87 | "a example of a person doing {}.",
88 | "a example of a person during {}.",
89 | "a example of a person performing {}.",
90 | "a example of a person practicing {}.",
91 | "a demonstration of {}.",
92 | "a demonstration of a person {}.",
93 | "a demonstration of a person using {}.",
94 | "a demonstration of a person doing {}.",
95 | "a demonstration of a person during {}.",
96 | "a demonstration of a person performing {}.",
97 | "a demonstration of a person practicing {}."
98 | ],
99 | "ucf101_frames": [
100 | "a photo of a person {}.",
101 | "a video of a person {}.",
102 | "a example of a person {}.",
103 | "a demonstration of a person {}.",
104 | "a photo of the person {}.",
105 | "a video of the person {}.",
106 | "a example of the person {}.",
107 | "a demonstration of the person {}.",
108 | "a photo of a person using {}.",
109 | "a video of a person using {}.",
110 | "a example of a person using {}.",
111 | "a demonstration of a person using {}.",
112 | "a photo of the person using {}.",
113 | "a video of the person using {}.",
114 | "a example of the person using {}.",
115 | "a demonstration of the person using {}.",
116 | "a photo of a person doing {}.",
117 | "a video of a person doing {}.",
118 | "a example of a person doing {}.",
119 | "a demonstration of a person doing {}.",
120 | "a photo of the person doing {}.",
121 | "a video of the person doing {}.",
122 | "a example of the person doing {}.",
123 | "a demonstration of the person doing {}.",
124 | "a photo of a person during {}.",
125 | "a video of a person during {}.",
126 | "a example of a person during {}.",
127 | "a demonstration of a person during {}.",
128 | "a photo of the person during {}.",
129 | "a video of the person during {}.",
130 | "a example of the person during {}.",
131 | "a demonstration of the person during {}.",
132 | "a photo of a person performing {}.",
133 | "a video of a person performing {}.",
134 | "a example of a person performing {}.",
135 | "a demonstration of a person performing {}.",
136 | "a photo of the person performing {}.",
137 | "a video of the person performing {}.",
138 | "a example of the person performing {}.",
139 | "a demonstration of the person performing {}.",
140 | "a photo of a person practicing {}.",
141 | "a video of a person practicing {}.",
142 | "a example of a person practicing {}.",
143 | "a demonstration of a person practicing {}.",
144 | "a photo of the person practicing {}.",
145 | "a video of the person practicing {}.",
146 | "a example of the person practicing {}.",
147 | "a demonstration of the person practicing {}."
148 | ],
149 | "patch_camelyon": [
150 | "this is a photo of {}"
151 | ],
152 | "country211": [
153 | "a photo i took in {}.",
154 | "a photo i took while visiting {}.",
155 | "a photo from my home country of {}.",
156 | "a photo from my visit to {}.",
157 | "a photo showing the country of {}."
158 | ],
159 | "kitti_distance": [
160 | "{}"
161 | ],
162 | "gtsrb": [
163 | "a zoomed in photo of a \"{}\" traffic sign.",
164 | "a centered photo of a \"{}\" traffic sign.",
165 | "a close up photo of a \"{}\" traffic sign."
166 | ],
167 | "resisc45": [
168 | "satellite imagery of {}.",
169 | "aerial imagery of {}.",
170 | "satellite photo of {}.",
171 | "aerial photo of {}.",
172 | "satellite view of {}.",
173 | "aerial view of {}.",
174 | "satellite imagery of a {}.",
175 | "aerial imagery of a {}.",
176 | "satellite photo of a {}.",
177 | "aerial photo of a {}.",
178 | "satellite view of a {}.",
179 | "aerial view of a {}.",
180 | "satellite imagery of the {}.",
181 | "aerial imagery of the {}.",
182 | "satellite photo of the {}.",
183 | "aerial photo of the {}.",
184 | "satellite view of the {}.",
185 | "aerial view of the {}."
186 | ],
187 | "eurosat": [
188 | "a centered satellite photo of {}.",
189 | "a centered satellite photo of a {}.",
190 | "a centered satellite photo of the {}."
191 | ],
192 | "stl10": [
193 | "a photo of a {}.",
194 | "a photo of the {}."
195 | ],
196 | "fer2013": [
197 | "a photo of a {} looking face.",
198 | "a photo of a face showing the emotion: {}.",
199 | "a photo of a face looking {}.",
200 | "a face that looks {}.",
201 | "they look {}.",
202 | "look at how {} they are."
203 | ],
204 | "mnist": [
205 | "a photo of the number: \"{}\"."
206 | ],
207 | "flowers": [
208 | "a photo of a {}, a type of flower."
209 | ],
210 | "caltech101": [
211 | "a photo of a {}.",
212 | "a painting of a {}.",
213 | "a plastic {}.",
214 | "a sculpture of a {}.",
215 | "a sketch of a {}.",
216 | "a tattoo of a {}.",
217 | "a toy {}.",
218 | "a rendition of a {}.",
219 | "a embroidered {}.",
220 | "a cartoon {}.",
221 | "a {} in a video game.",
222 | "a plushie {}.",
223 | "a origami {}.",
224 | "art of a {}.",
225 | "graffiti of a {}.",
226 | "a drawing of a {}.",
227 | "a doodle of a {}.",
228 | "a photo of the {}.",
229 | "a painting of the {}.",
230 | "the plastic {}.",
231 | "a sculpture of the {}.",
232 | "a sketch of the {}.",
233 | "a tattoo of the {}.",
234 | "the toy {}.",
235 | "a rendition of the {}.",
236 | "the embroidered {}.",
237 | "the cartoon {}.",
238 | "the {} in a video game.",
239 | "the plushie {}.",
240 | "the origami {}.",
241 | "art of the {}.",
242 | "graffiti of the {}.",
243 | "a drawing of the {}.",
244 | "a doodle of the {}."
245 | ],
246 | "pets": [
247 | "a photo of a {}, a type of pet."
248 | ],
249 | "dtd": [
250 | "a photo of a {} texture.",
251 | "a photo of a {} pattern.",
252 | "a photo of a {} thing.",
253 | "a photo of a {} object.",
254 | "a photo of the {} texture.",
255 | "a photo of the {} pattern.",
256 | "a photo of the {} thing.",
257 | "a photo of the {} object."
258 | ],
259 | "voc2007": [
260 | "a photo of a {}."
261 | ],
262 | "aircraft": [
263 | "a photo of a {}, a type of aircraft.",
264 | "a photo of the {}, a type of aircraft."
265 | ],
266 | "cars": [
267 | "a photo of a {}.",
268 | "a photo of the {}.",
269 | "a photo of my {}.",
270 | "i love my {}!",
271 | "a photo of my dirty {}.",
272 | "a photo of my clean {}.",
273 | "a photo of my new {}.",
274 | "a photo of my old {}."
275 | ],
276 | "sun397": [
277 | "a photo of a {}.",
278 | "a photo of the {}."
279 | ]
280 | }
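For context, these templates are format strings with a single `{}` placeholder for the class name. A minimal editorial sketch of how they are typically expanded and tokenized for zero-shot classification (the class name below is hypothetical; the actual consumer in this repo is presumably eval_zeroshot.py):

    import json
    from tokenizer import SimpleTokenizer

    tokenizer = SimpleTokenizer()
    with open("templates.json") as f:
        templates = json.load(f)

    classname = "golden retriever"  # illustrative label, not taken from this repo
    prompts = [t.format(classname) for t in templates["imagenet"]]
    tokens = tokenizer(prompts)     # LongTensor of shape (len(prompts), 77)
    print(tokens.shape)             # torch.Size([7, 77]) for the 7 imagenet templates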
--------------------------------------------------------------------------------
/tokenizer.py:
--------------------------------------------------------------------------------
1 | # From github.com/openai/CLIP
2 |
3 | import gzip
4 | import html
5 | import os
6 | from functools import lru_cache
7 |
8 | import ftfy
9 | import regex as re
10 | import torch
11 |
12 |
13 | @lru_cache()
14 | def default_bpe():
15 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
16 |
17 |
18 | @lru_cache()
19 | def bytes_to_unicode():
20 | """
21 |     Returns a list of utf-8 bytes and a corresponding list of unicode strings.
22 |     The reversible bpe codes work on unicode strings.
23 |     This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
24 |     When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
25 |     This is a significant percentage of your normal, say, 32K bpe vocab.
26 |     To avoid that, we want lookup tables between utf-8 bytes and unicode strings,
27 |     and we avoid mapping to whitespace/control characters the bpe code barfs on.
28 | """
29 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
30 | cs = bs[:]
31 | n = 0
32 | for b in range(2**8):
33 | if b not in bs:
34 | bs.append(b)
35 | cs.append(2**8+n)
36 | n += 1
37 | cs = [chr(n) for n in cs]
38 | return dict(zip(bs, cs))
39 |
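A small illustrative check of the mapping described in the docstring (editorial, not part of the original file): printable bytes map to themselves, while whitespace/control bytes are shifted to unused code points above 255.

    from tokenizer import bytes_to_unicode

    b2u = bytes_to_unicode()
    print(b2u[ord("A")])   # 'A'  -- printable ASCII is kept as-is
    print(b2u[ord(" ")])   # 'Ġ'  -- the space byte (0x20) is remapped above 255
    print(len(b2u))        # 256  -- every byte value gets a unique unicode character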
40 |
41 | def get_pairs(word):
42 | """Return set of symbol pairs in a word.
43 | Word is represented as tuple of symbols (symbols being variable-length strings).
44 | """
45 | pairs = set()
46 | prev_char = word[0]
47 | for char in word[1:]:
48 | pairs.add((prev_char, char))
49 | prev_char = char
50 | return pairs
51 |
52 |
53 | def basic_clean(text):
54 | text = ftfy.fix_text(text)
55 | text = html.unescape(html.unescape(text))
56 | return text.strip()
57 |
58 |
59 | def whitespace_clean(text):
60 | text = re.sub(r'\s+', ' ', text)
61 | text = text.strip()
62 | return text
63 |
64 |
65 | class SimpleTokenizer(object):
66 | def __init__(self, bpe_path: str = default_bpe()):
67 | self.byte_encoder = bytes_to_unicode()
68 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
69 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
70 | merges = merges[1:49152-256-2+1]
71 | merges = [tuple(merge.split()) for merge in merges]
72 | vocab = list(bytes_to_unicode().values())
73 |         vocab = vocab + [v+'</w>' for v in vocab]
74 | for merge in merges:
75 | vocab.append(''.join(merge))
76 | vocab.extend(['<|startoftext|>', '<|endoftext|>'])
77 | self.encoder = dict(zip(vocab, range(len(vocab))))
78 | self.decoder = {v: k for k, v in self.encoder.items()}
79 | self.bpe_ranks = dict(zip(merges, range(len(merges))))
80 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
81 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
82 |
83 | def bpe(self, token):
84 | if token in self.cache:
85 | return self.cache[token]
86 |         word = tuple(token[:-1]) + (token[-1] + '</w>',)
87 | pairs = get_pairs(word)
88 |
89 | if not pairs:
90 |             return token+'</w>'
91 |
92 | while True:
93 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
94 | if bigram not in self.bpe_ranks:
95 | break
96 | first, second = bigram
97 | new_word = []
98 | i = 0
99 | while i < len(word):
100 | try:
101 | j = word.index(first, i)
102 | new_word.extend(word[i:j])
103 | i = j
104 |                 except ValueError:  # `first` does not occur again in `word`
105 | new_word.extend(word[i:])
106 | break
107 |
108 | if word[i] == first and i < len(word)-1 and word[i+1] == second:
109 | new_word.append(first+second)
110 | i += 2
111 | else:
112 | new_word.append(word[i])
113 | i += 1
114 | new_word = tuple(new_word)
115 | word = new_word
116 | if len(word) == 1:
117 | break
118 | else:
119 | pairs = get_pairs(word)
120 | word = ' '.join(word)
121 | self.cache[token] = word
122 | return word
123 |
124 | def encode(self, text):
125 | bpe_tokens = []
126 | text = whitespace_clean(basic_clean(text)).lower()
127 | for token in re.findall(self.pat, text):
128 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
129 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
130 | return bpe_tokens
131 |
132 | def decode(self, tokens):
133 | text = ''.join([self.decoder[token] for token in tokens])
134 |         text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
135 | return text
136 |
137 | def __call__(self, texts, context_length=77):
138 | if isinstance(texts, str):
139 | texts = [texts]
140 |
141 | sot_token = self.encoder["<|startoftext|>"]
142 | eot_token = self.encoder["<|endoftext|>"]
143 | all_tokens = [[sot_token] + self.encode(text) + [eot_token] for text in texts]
144 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
145 |
146 | for i, tokens in enumerate(all_tokens):
147 | if len(tokens) > context_length:
148 | print(f"Input {texts[i]} is too long for context length {context_length}")
149 | tokens = tokens[:context_length]
150 | result[i, :len(tokens)] = torch.tensor(tokens)
151 |
152 | if len(result) == 1:
153 | return result[0]
154 |
155 | return result
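A brief usage sketch (editorial, not part of tokenizer.py): encode lower-cases and BPE-splits the text, and decode rejoins words at the '</w>' markers.

    from tokenizer import SimpleTokenizer

    tok = SimpleTokenizer()
    ids = tok.encode("A photo of a dog")   # BPE ids, without the SOT/EOT tokens
    print(ids)
    print(tok.decode(ids))                 # roughly recovers "a photo of a dog"

    batch = tok(["a photo of a dog", "a photo of a cat"])
    print(batch.shape)                     # torch.Size([2, 77]) -- fixed context length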
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Based on BEiT v2, SLIP code bases
3 | # https://github.com/microsoft/unilm/tree/master/beit2
4 | # https://github.com/facebookresearch/SLIP
5 | # --------------------------------------------------------
6 | import numpy as np
7 | import os
8 | import random
9 | import shutil
10 | import torch
11 | import torch.distributed as dist
12 | import torch.autograd as autograd
13 |
14 | from PIL import ImageFilter
15 | from datetime import datetime, timedelta
16 |
17 | def unwrap_model(model):
18 | if hasattr(model, 'module'):
19 | return model.module
20 | else:
21 | return model
22 |
23 | def get_model(model):
24 | if isinstance(model, torch.nn.DataParallel) \
25 | or isinstance(model, torch.nn.parallel.DistributedDataParallel):
26 | return model.module
27 | else:
28 | return model
29 |
30 | class AverageMeter(object):
31 | """Computes and stores the average and current value"""
32 | def __init__(self, name, fmt=':f'):
33 | self.name = name
34 | self.fmt = fmt
35 | self.reset()
36 |
37 | def reset(self):
38 | self.val = 0
39 | self.avg = 0
40 | self.sum = 0
41 | self.count = 0
42 |
43 | def update(self, val, n=1):
44 | self.val = val
45 | self.sum += val * n
46 | self.count += n
47 | self.avg = self.sum / self.count
48 |
49 | def synchronize(self):
50 | if not is_dist_avail_and_initialized():
51 | return
52 | t = torch.tensor([self.sum, self.count], dtype=torch.float64, device='cuda')
53 | dist.barrier()
54 | dist.all_reduce(t)
55 | t = t.tolist()
56 |         self.sum = t[0]  # keep the float sum; int() here would truncate loss-style metrics
57 | self.count = t[1]
58 | self.avg = self.sum / self.count
59 |
60 | def __str__(self):
61 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
62 | return fmtstr.format(**self.__dict__)
63 |
64 |
65 | class ProgressMeter(object):
66 | def __init__(self, num_batches, meters, prefix=""):
67 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
68 | self.meters = meters
69 | self.prefix = prefix
70 |
71 | def display(self, batch):
72 | nowstr = "["+ datetime.now().strftime("%H:%M:%S")+ "] "
73 | entries = [nowstr + self.prefix + self.batch_fmtstr.format(batch)]
74 | entries += [str(meter) for meter in self.meters]
75 |
76 | print('\t'.join(entries))
77 |
78 | def synchronize(self):
79 | for meter in self.meters:
80 | meter.synchronize()
81 |
82 | def _get_batch_fmtstr(self, num_batches):
83 |         num_digits = len(str(num_batches))
84 | fmt = '{:' + str(num_digits) + 'd}'
85 | return '[' + fmt + '/' + fmt.format(num_batches) + ']'
86 |
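A short usage sketch for the two meters above (editorial; the loss values and batch size are made up):

    from utils import AverageMeter, ProgressMeter

    loss_meter = AverageMeter("loss", ":.4f")
    progress = ProgressMeter(num_batches=100, meters=[loss_meter], prefix="Epoch [0]")

    for it in range(100):
        fake_loss = 1.0 / (it + 1)          # stand-in for a real training loss
        loss_meter.update(fake_loss, n=32)  # n is the batch size
        if it % 20 == 0:
            progress.display(it)            # prints "[HH:MM:SS] Epoch [0][ it/100]  loss ..."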
87 |
88 | def accuracy(output, target, topk=(1,)):
89 | """Computes the accuracy over the k top predictions for the specified values of k"""
90 | with torch.no_grad():
91 | maxk = max(topk)
92 | batch_size = target.size(0)
93 |
94 | _, pred = output.topk(maxk, 1, True, True)
95 | pred = pred.t()
96 | correct = pred.eq(target.reshape(1, -1).expand_as(pred))
97 |
98 | res = []
99 | for k in topk:
100 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
101 | res.append(correct_k.mul_(100.0 / batch_size))
102 | return res
103 |
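A worked example for accuracy() (editorial; the logits are made up):

    import torch
    from utils import accuracy

    logits = torch.tensor([[2.0, 1.0, 0.1],
                           [0.2, 0.3, 3.0]])   # two samples, three classes
    targets = torch.tensor([0, 2])
    top1, = accuracy(logits, targets, topk=(1,))
    print(top1.item())   # 100.0 -- both argmax predictions match the targets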
104 |
105 | def setup_for_distributed(is_master):
106 | """
107 |     This function disables printing when not in the master process.
108 | """
109 | import builtins as __builtin__
110 | builtin_print = __builtin__.print
111 |
112 | def print(*args, **kwargs):
113 | force = kwargs.pop('force', False)
114 | if is_master or force:
115 | builtin_print(*args, **kwargs)
116 |
117 | __builtin__.print = print
118 |
119 |
120 | def is_dist_avail_and_initialized():
121 | if not dist.is_available():
122 | return False
123 | if not dist.is_initialized():
124 | return False
125 | return True
126 |
127 |
128 | def get_world_size():
129 | if not is_dist_avail_and_initialized():
130 | return 1
131 | return dist.get_world_size()
132 |
133 |
134 | def get_rank():
135 | if not is_dist_avail_and_initialized():
136 | return 0
137 | return dist.get_rank()
138 |
139 |
140 | def is_main_process():
141 | return get_rank() == 0
142 |
143 |
144 | def save_on_master(state, is_best, output_dir, epoch=None, total_epoch=25):
145 | if is_main_process():
146 | ckpt_path = f'{output_dir}/checkpoint.pt'
147 | best_path = f'{output_dir}/checkpoint_best.pt'
148 | torch.save(state, ckpt_path)
149 | if epoch is not None:
150 |             save_freq = total_epoch * 0.2
151 |             if epoch % save_freq == 0:
152 | each_ckpt_path = f'{output_dir}/checkpoint_{epoch}.pt'
153 | shutil.copyfile(ckpt_path, each_ckpt_path)
154 | if is_best:
155 | shutil.copyfile(ckpt_path, best_path)
156 |
157 |
158 | def init_distributed_mode(args):
159 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
160 | args.rank = int(os.environ["RANK"])
161 | args.world_size = int(os.environ['WORLD_SIZE'])
162 | args.gpu = int(os.environ['LOCAL_RANK'])
163 | elif 'SLURM_PROCID' in os.environ:
164 | args.rank = int(os.environ['SLURM_PROCID'])
165 | args.gpu = args.rank % torch.cuda.device_count()
166 | else:
167 | print('Not using distributed mode')
168 | args.distributed = False
169 | return
170 |
171 | args.distributed = True
172 |
173 | torch.cuda.set_device(args.gpu)
174 | args.dist_backend = 'nccl'
175 | print('| distributed init (rank {}): {}'.format(
176 | args.rank, args.dist_url), flush=True)
177 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
178 |                                          world_size=args.world_size, rank=args.rank, timeout=timedelta(minutes=90))
179 | torch.distributed.barrier()
180 | setup_for_distributed(args.rank == 0)
181 |
182 |
183 | def scaled_all_reduce(tensors, is_scale=True):
184 | """Performs the scaled all_reduce operation on the provided tensors.
185 | The input tensors are modified in-place. Currently supports only the sum
186 |     reduction operator. The reduced values are scaled by the inverse of the
187 |     world size.
188 | """
189 | world_size = get_world_size()
190 | # There is no need for reduction in the single-proc case
191 | if world_size == 1:
192 | return tensors
193 | # Queue the reductions
194 | reductions = []
195 | for tensor in tensors:
196 | reduction = dist.all_reduce(tensor, async_op=True)
197 | reductions.append(reduction)
198 | # Wait for reductions to finish
199 | for reduction in reductions:
200 | reduction.wait()
201 | # Scale the results
202 | if is_scale:
203 | for tensor in tensors:
204 | tensor.mul_(1.0 / world_size)
205 | return tensors
206 |
207 |
208 | class GatherLayer(autograd.Function):
209 | """
210 |     Gather tensors from all workers with support for backward propagation:
211 |     unlike torch.distributed.all_gather, this implementation does not cut the gradients.
212 | """
213 |
214 | @staticmethod
215 | def forward(ctx, x):
216 | output = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
217 | dist.all_gather(output, x)
218 | return tuple(output)
219 |
220 | @staticmethod
221 | def backward(ctx, *grads):
222 | all_gradients = torch.stack(grads)
223 | dist.all_reduce(all_gradients)
224 | return all_gradients[dist.get_rank()]
225 |
226 |
227 | def all_gather_batch_with_grad(tensors):
228 | """
229 | Performs all_gather operation on the provided tensors.
230 | Graph remains connected for backward grad computation.
231 | """
232 | # Queue the gathered tensors
233 | world_size = get_world_size()
234 |     # There is no need to gather in the single-proc case
235 | if world_size == 1:
236 | return tensors
237 | tensor_list = []
238 | output_tensor = []
239 |
240 | for tensor in tensors:
241 | tensor_all = GatherLayer.apply(tensor)
242 | tensor_list.append(tensor_all)
243 |
244 | for tensor_all in tensor_list:
245 | output_tensor.append(torch.cat(tensor_all, dim=0))
246 | return output_tensor
247 |
248 |
249 | def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0):
250 | warmup_schedule = np.array([])
251 | warmup_iters = warmup_epochs * niter_per_ep
252 | if warmup_epochs > 0:
253 | warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
254 |
255 | iters = np.arange(epochs * niter_per_ep - warmup_iters)
256 | schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters)))
257 |
258 | schedule = np.concatenate((warmup_schedule, schedule))
259 | assert len(schedule) == epochs * niter_per_ep
260 | return schedule
261 |
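An illustrative call (editorial; the hyperparameters are placeholders, not this repo's defaults): the returned array holds one value per training iteration, linear warmup first, then cosine decay.

    from utils import cosine_scheduler

    lr_schedule = cosine_scheduler(base_value=5e-4, final_value=1e-6,
                                   epochs=25, niter_per_ep=100, warmup_epochs=1)
    assert len(lr_schedule) == 25 * 100
    print(lr_schedule[0], lr_schedule[99], lr_schedule[-1])
    # 0.0 (warmup start), 5e-4 (end of warmup), ~1e-6 (approaches final_value)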
262 |
263 | class GaussianBlur(object):
264 | """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709"""
265 |
266 | def __init__(self, sigma=[.1, 2.]):
267 | self.sigma = sigma
268 |
269 | def __call__(self, x):
270 | sigma = random.uniform(self.sigma[0], self.sigma[1])
271 | x = x.filter(ImageFilter.GaussianBlur(radius=sigma))
272 | return x
273 |
274 |
275 | # from https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
276 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
277 | """
278 | grid_size: int of the grid height and width
279 | return:
280 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
281 | """
282 | grid_h = np.arange(grid_size, dtype=np.float32)
283 | grid_w = np.arange(grid_size, dtype=np.float32)
284 | grid = np.meshgrid(grid_w, grid_h) # here w goes first
285 | grid = np.stack(grid, axis=0)
286 |
287 | grid = grid.reshape([2, 1, grid_size, grid_size])
288 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
289 | if cls_token:
290 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
291 | return pos_embed
292 |
293 |
294 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
295 | assert embed_dim % 2 == 0
296 |
297 | # use half of dimensions to encode grid_h
298 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
299 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
300 |
301 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
302 | return emb
303 |
304 |
305 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
306 | """
307 | embed_dim: output dimension for each position
308 | pos: a list of positions to be encoded: size (M,)
309 | out: (M, D)
310 | """
311 | assert embed_dim % 2 == 0
312 |     omega = np.arange(embed_dim // 2, dtype=np.float64)  # np.float alias was removed in NumPy >= 1.24
313 | omega /= embed_dim / 2.
314 | omega = 1. / 10000**omega # (D/2,)
315 |
316 | pos = pos.reshape(-1) # (M,)
317 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
318 |
319 | emb_sin = np.sin(out) # (M, D/2)
320 | emb_cos = np.cos(out) # (M, D/2)
321 |
322 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
323 | return emb
324 |
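A quick shape check for the helpers above (editorial): for a ViT-B/16 at 224x224 the patch grid is 14x14 and the embedding width is 768.

    from utils import get_2d_sincos_pos_embed

    pos_embed = get_2d_sincos_pos_embed(embed_dim=768, grid_size=14, cls_token=True)
    print(pos_embed.shape)   # (197, 768): one row per patch plus the zero-initialized cls token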
325 |
326 | # --------------------------------------------------------
327 | # Interpolate position embeddings for high-resolution
328 | # References:
329 | # DeiT: https://github.com/facebookresearch/deit
330 | # --------------------------------------------------------
331 | def interpolate_pos_embed(model, checkpoint_model):
332 | if 'pos_embed' in checkpoint_model:
333 | pos_embed_checkpoint = checkpoint_model['pos_embed']
334 | embedding_size = pos_embed_checkpoint.shape[-1]
335 | num_patches = model.patch_embed.num_patches
336 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches
337 | # height (== width) for the checkpoint position embedding
338 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
339 | # height (== width) for the new position embedding
340 | new_size = int(num_patches ** 0.5)
341 | # class_token and dist_token are kept unchanged
342 | if orig_size != new_size:
343 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
344 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
345 | # only the position tokens are interpolated
346 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
347 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
348 | pos_tokens = torch.nn.functional.interpolate(
349 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
350 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
351 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
352 | checkpoint_model['pos_embed'] = new_pos_embed
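A hypothetical usage sketch (editorial; it uses stock timm ViT names rather than this repo's masked variants) showing how a 224-resolution state dict can be adapted to a 384-resolution model:

    import timm
    from utils import interpolate_pos_embed

    src = timm.create_model("vit_base_patch16_224", pretrained=False)   # 14x14 patch grid
    dst = timm.create_model("vit_base_patch16_384", pretrained=False)   # 24x24 patch grid

    state_dict = src.state_dict()
    interpolate_pos_embed(dst, state_dict)   # resizes state_dict['pos_embed'] in place if the grids differ
    dst.load_state_dict(state_dict, strict=False)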
--------------------------------------------------------------------------------