├── .gitignore
├── CODE_OF_CONDUCT.md
├── DATA.md
├── INSTALL.md
├── LICENSE
├── README.md
├── SECURITY.md
├── SUPPORT.md
├── config.py
├── configs
│   ├── unicl_swin_base.yaml
│   └── unicl_swin_tiny.yaml
├── data
│   ├── __init__.py
│   ├── build.py
│   ├── cached_image_folder.py
│   ├── imagenet
│   │   ├── __init__.py
│   │   ├── constants.py
│   │   └── real_labels.py
│   ├── language
│   │   ├── __init__.py
│   │   └── prompt_engineering.py
│   ├── samplers.py
│   ├── zipdata.py
│   └── zipreader.py
├── figures
│   ├── unicl_comparison.png
│   └── unified_cv.png
├── logger.py
├── lr_scheduler.py
├── main.py
├── model
│   ├── __init__.py
│   ├── image_encoder
│   │   ├── __init__.py
│   │   ├── build.py
│   │   ├── focalnet.py
│   │   └── swin_transformer.py
│   ├── model.py
│   └── text_encoder
│       ├── __init__.py
│       ├── build.py
│       ├── hf_model.py
│       ├── registry.py
│       └── transformer.py
├── optimizer.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Microsoft Open Source Code of Conduct
2 |
3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
4 |
5 | Resources:
6 |
7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns
10 |
--------------------------------------------------------------------------------
/DATA.md:
--------------------------------------------------------------------------------
1 | ## ImageNet-1K Preparation
2 |
3 | We use the standard ImageNet dataset, which you can download from http://image-net.org/. We provide the following
4 | two ways to load data:
5 |
6 | - For standard folder dataset, move validation images to labeled sub-folders. The file structure should look like:
7 | ```bash
8 | $ tree data
9 | imagenet
10 | ├── train
11 | │ ├── class1
12 | │ │ ├── img1.jpeg
13 | │ │ ├── img2.jpeg
14 | │ │ └── ...
15 | │ ├── class2
16 | │ │ ├── img3.jpeg
17 | │ │ └── ...
18 | │ └── ...
19 | └── val
20 | ├── class1
21 | │ ├── img4.jpeg
22 | │ ├── img5.jpeg
23 | │ └── ...
24 | ├── class2
25 | │ ├── img6.jpeg
26 | │ └── ...
27 | └── ...
28 |
29 | ```
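With this layout, class sub-folders map to label indices in sorted order, following the convention of torchvision's `ImageFolder`. As a minimal stdlib-only sketch of that convention (the function name here is illustrative, not part of this repo):

```python
import os

def list_folder_samples(root):
    """Collect (image_path, label) pairs from an ImageFolder-style tree.

    Class sub-folders are sorted alphabetically and mapped to indices
    0, 1, 2, ..., matching torchvision's ImageFolder behavior.
    """
    classes = sorted(d for d in os.listdir(root)
                     if os.path.isdir(os.path.join(root, d)))
    class_to_idx = {c: i for i, c in enumerate(classes)}
    samples = []
    for c in classes:
        cdir = os.path.join(root, c)
        for fname in sorted(os.listdir(cdir)):
            samples.append((os.path.join(cdir, fname), class_to_idx[c]))
    return samples
```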
30 | - To avoid slow reads from massive numbers of small files, we also support zipped ImageNet, which consists of
31 | four files:
32 | - `train.zip`, `val.zip`: the zipped image folders for the train and validation splits.
33 | - `train_map.txt`, `val_map.txt`: the relative path of each image in the corresponding zip file and its ground-truth
34 | label. Make sure the data folder looks like this:
35 |
36 | ```bash
37 | $ tree data
38 | data
39 | └── ImageNet-Zip
40 | ├── train_map.txt
41 | ├── train.zip
42 | ├── val_map.txt
43 | └── val.zip
44 |
45 | $ head -n 5 data/ImageNet-Zip/val_map.txt
46 | ILSVRC2012_val_00000001.JPEG 65
47 | ILSVRC2012_val_00000002.JPEG 970
48 | ILSVRC2012_val_00000003.JPEG 230
49 | ILSVRC2012_val_00000004.JPEG 809
50 | ILSVRC2012_val_00000005.JPEG 516
51 |
52 | $ head -n 5 data/ImageNet-Zip/train_map.txt
53 | n01440764/n01440764_10026.JPEG 0
54 | n01440764/n01440764_10027.JPEG 0
55 | n01440764/n01440764_10029.JPEG 0
56 | n01440764/n01440764_10040.JPEG 0
57 | n01440764/n01440764_10042.JPEG 0
58 | ```
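Each map-file line pairs an entry's relative path inside the zip with its integer label, and images can be read straight out of the archive without extraction. A minimal sketch of consuming this format (helper names are illustrative; the repo's actual readers live in `data/zipdata.py` and `data/zipreader.py`):

```python
import zipfile

def parse_map_file(path):
    """Each line is '<relative-path> <label>'; return (path, label) pairs."""
    samples = []
    with open(path) as f:
        for line in f:
            rel_path, label = line.strip().rsplit(maxsplit=1)
            samples.append((rel_path, int(label)))
    return samples

def read_image_bytes(zip_path, rel_path):
    """Read one image's raw bytes from the zip without extracting it."""
    with zipfile.ZipFile(zip_path) as zf:
        return zf.read(rel_path)
```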
--------------------------------------------------------------------------------
/INSTALL.md:
--------------------------------------------------------------------------------
1 |
2 | ## Installation Instructions
3 |
4 | - Clone this repo:
5 |
6 | ```bash
7 | git clone https://github.com/microsoft/UniCL
8 | cd UniCL
9 | ```
10 |
11 | - Create a conda virtual environment and activate it (optional):
12 |
13 | ```bash
14 | conda create -n unicl python=3.7 -y
15 | conda activate unicl
16 | ```
17 |
18 | - Install `CUDA==10.1` with `cudnn7` following
19 | the [official installation instructions](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html)
20 | - Install `PyTorch==1.7.1` and `torchvision==0.8.2` with `CUDA==10.1`:
21 |
22 | ```bash
23 | conda install pytorch==1.7.1 torchvision==0.8.2 cudatoolkit=10.1 -c pytorch
24 | ```
25 |
26 | - Install `timm==0.3.2`:
27 |
28 | ```bash
29 | pip install timm==0.3.2
30 | ```
31 |
32 | - Install `Apex`:
33 |
34 | ```bash
35 | git clone https://github.com/NVIDIA/apex
36 | cd apex
37 | pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
38 | ```
39 |
40 | - Install other requirements:
41 |
42 | ```bash
43 | pip install opencv-python==4.4.0.46 termcolor==1.1.0 yacs==0.1.8
44 | ```
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) Microsoft Corporation.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # [Unified Contrastive Learning in Image-Text-Label Space](https://arxiv.org/abs/2204.03610)
2 |
3 | This is the official PyTorch implementation of UniCL:
4 |
5 | ["**Unified Contrastive Learning in Image-Text-Label Space. CVPR 2022**"](https://arxiv.org/abs/2204.03610) by [Jianwei Yang*](https://jwyang.github.io/), [Chunyuan Li*](https://chunyuan.li/), [Pengchuan Zhang*](https://pzzhang.github.io/pzzhang/), [Bin Xiao*](https://www.microsoft.com/en-us/research/people/bixi/), [Ce Liu](http://people.csail.mit.edu/celiu/), [Lu Yuan](https://scholar.google.com/citations?user=k9TsUVsAAAAJ&hl=en) and [Jianfeng Gao](https://www.microsoft.com/en-us/research/people/jfgao/?from=http%3A%2F%2Fresearch.microsoft.com%2Fen-us%2Fum%2Fpeople%2Fjfgao%2F).
6 |
7 | ## Introduction
8 |
9 |
10 |
11 |
12 |
13 | In this paper, we introduce a new perspective on commonly used image-label and image-text data by placing them in a shared image-text-label space. In this space, we propose a new learning paradigm, called **Unified Contrastive Learning (UniCL)**, with a single learning objective that seamlessly exploits the synergy between the two data types. We demonstrate that UniCL is an effective way of learning **semantically rich yet discriminative representations**, universally for image recognition in zero-shot, linear-probe, full finetuning and transfer learning scenarios. When scaled up to billions of samples, UniCL can learn a powerful visual-semantic representation supporting dozens of downstream tasks, as shown in [Florence](https://arxiv.org/pdf/2111.11432v1.pdf).
14 |
15 | We compare UniCL with conventional learning methods below:
16 |
17 |
18 |
19 |
20 |
21 | ## :collision: All previous links are broken. Please find all checkpoints here: [https://github.com/microsoft/UniCL/releases/tag/v1.0](https://github.com/microsoft/UniCL/releases/tag/v1.0)
22 | ## Updates
23 | * [11/24/2022] [KLITE](https://arxiv.org/abs/2204.09222), the knowledge-augmented version of UniCL, is publicly released on [Github](https://github.com/microsoft/klite).
24 | * :collision: [10/05/2022] How do we use the pretrained UniCL checkpoints? Beyond the zero-shot classification shown in our paper, we can use them for object detection. Now [RegionCLIP](https://github.com/microsoft/RegionCLIP/tree/zero-shot) supports using pretrained UniCL transformer models, such as Swin and ViT, for open-vocabulary object detection without any finetuning. Check it out!
25 | * [08/19/2022] Organizing ECCV Workshop [*Computer Vision in the Wild (CVinW)*](https://computer-vision-in-the-wild.github.io/eccv-2022/), where two challenges are hosted to evaluate the zero-shot, few-shot and full-shot performance of pre-trained vision models in downstream tasks:
26 | - [``*Image Classification in the Wild (ICinW)*''](https://eval.ai/web/challenges/challenge-page/1832/overview) Challenge evaluates on 20 image classification tasks.
27 | - [``*Object Detection in the Wild (ODinW)*''](https://eval.ai/web/challenges/challenge-page/1839/overview) Challenge evaluates on 35 object detection tasks.
28 |
29 | $\qquad$ [[Workshop]](https://computer-vision-in-the-wild.github.io/eccv-2022/) $\qquad$ [[IC Challenge]](https://eval.ai/web/challenges/challenge-page/1832/overview) $\qquad$ [[OD Challenge]](https://eval.ai/web/challenges/challenge-page/1839/overview)
30 |
31 |
32 |
33 |
34 | * [06/19/2022] Released the evaluation benchmark used in UniCL, [ELEVATER](https://arxiv.org/abs/2204.08790), which contains 20 downstream image classification tasks. More info: [[Benchmark]](https://computer-vision-in-the-wild.github.io/ELEVATER/) [[Toolkit]](https://github.com/Computer-Vision-in-the-Wild/Elevater_Toolkit_IC) [[Paper]](https://arxiv.org/abs/2204.08790)
35 | * [06/04/2022] Check out our [Huggingface Gradio demo](https://huggingface.co/spaces/CVPR/unicl-zero-shot-img-recog).
36 | * [05/21/2022] Released pretrained model and zero-shot evaluation on ImageNet-1k.
37 |
38 | ## Benchmarking
39 |
40 | ### Image-label training augmented by image-text pairs
41 |
42 | | Model | Training Set | Top-1 on IN-1K | ZS on 14 datasets | Download
43 | | :----: | :---: | :---: | :---: | :---: |
44 | | Swin-T | IN-1K | 79.9 | 30.2 | [ckpt](https://projects4jw.blob.core.windows.net/unicl/release/in1k.pth)/[config](configs/unicl_swin_tiny.yaml)
45 | | Swin-T | IN-1K + GCC-3M | 80.2 | 39.0 | [ckpt](https://projects4jw.blob.core.windows.net/unicl/release/in1k_gcc3m.pth)/[config](configs/unicl_swin_tiny.yaml)
46 | | Swin-T | IN-1K + GYFCC-14M | 81.1 | 40.0 | [ckpt](https://projects4jw.blob.core.windows.net/unicl/release/in1k_yfcc14m.pth)/[config](configs/unicl_swin_tiny.yaml)
47 | | Swin-T | IN-1K + GCC-15M | 81.8 | 45.1 | [ckpt](https://projects4jw.blob.core.windows.net/unicl/release/in1k_gcc15m.pth)/[config](configs/unicl_swin_tiny.yaml)
48 |
49 | Note that all the above models are trained **without** strong data augmentations like mixup and cutmix.
50 |
51 | ### Image-text learning augmented by image-label data
52 |
53 | | Model | Training Set | ZS on IN-1K | ZS on 14 datasets | ZS on 20 datasets | Download
54 | | :----: | :---: | :---: | :---: | :---: | :---: |
55 | | Swin-T | YFCC-14M | 30.1 | 36.3 | - | [ckpt](https://projects4jw.blob.core.windows.net/unicl/release/yfcc14m.pth)/[config](configs/unicl_swin_tiny.yaml)
56 | | Swin-T | IN-21K | 28.5 | 37.8 | - | [ckpt](https://projects4jw.blob.core.windows.net/unicl/release/in21k.pth)/[config](configs/unicl_swin_tiny.yaml)
57 | | Swin-T | IN-22K | 66.8 | 38.9 | - | [ckpt](https://projects4jw.blob.core.windows.net/unicl/release/in22k.pth)/[config](configs/unicl_swin_tiny.yaml)
58 | | Swin-T | IN-21K (half) + YFCC-14M (half) | 36.4 | 45.5 | - | [ckpt](https://projects4jw.blob.core.windows.net/unicl/release/in21k_yfcc14m_half.pth)/[config](configs/unicl_swin_tiny.yaml)
59 | | Swin-T | IN-21K + YFCC-14M | 40.5 | 49.1 | - | [ckpt](https://projects4jw.blob.core.windows.net/unicl/release/in21k_yfcc14m.pth)/[config](configs/unicl_swin_tiny.yaml)
60 | | Swin-B | IN-21K | 29.9 | 42.4 | - | [ckpt](https://projects4jw.blob.core.windows.net/unicl/release/in21k_swin_base.pth)/[config](configs/unicl_swin_base.yaml)
61 | | Swin-B | IN-21K (half) + YFCC-14M (half) | 41.1 | 48.5 | - | [ckpt](https://projects4jw.blob.core.windows.net/unicl/release/in21k_yfcc14m_half_swin_base.pth)/[config](configs/unicl_swin_base.yaml)
62 | | Swin-B | IN-21K + YFCC-14M | 44.3 | 52.2 | - | [ckpt](https://projects4jw.blob.core.windows.net/unicl/release/in21k_yfcc14m_swin_base.pth)/[config](configs/unicl_swin_base.yaml)
63 | | Swin-B | IN-21K + GCC-15M + YFCC-14M | 52.2 | - | 43.2 | [ckpt]()/[config](configs/unicl_swin_base.yaml)
64 | | Focal-B | IN-21K + GCC-15M + YFCC-14M | 54.2 | - | 44.0 | [ckpt]()/[config]()
65 |
66 | **NOTE**: Setting "ZS on 20 datasets" is used in the ICinW benchmark.
67 |
68 | ## Getting Started
69 |
70 | ### Installation
71 |
72 | Please follow [INSTALL.md](./INSTALL.md) for installation.
73 |
74 | ### Data preparation
75 |
76 | Please follow [DATA.md](./DATA.md) for data preparation.
77 |
78 | ### Evaluation
79 |
80 | To evaluate a pre-trained UniCL on ImageNet val, run:
81 |
82 | ```bash
83 | python -m torch.distributed.launch --nproc_per_node <num-of-gpus> --master_port 12345 main.py --eval \
84 | --cfg <config-file> --resume <checkpoint> --data-path <imagenet-path>
85 | ```
86 |
87 | For example, to evaluate the UniCL-Swin-Tiny trained on YFCC-14M with a single GPU:
88 |
89 | ```bash
90 | python -m torch.distributed.launch --nproc_per_node 1 --master_port 12345 main.py --eval \
91 | --cfg configs/unicl_swin_tiny.yaml --resume yfcc14m.pth --data-path <imagenet-path>
92 | ```
93 |
94 | ### The Image Classification in the Wild Benchmark
95 | Interested in evaluating UniCL on downstream image classification tasks and comparing performance on the same task suite? We released the [ELEVATER](https://computer-vision-in-the-wild.github.io/ELEVATER/) benchmark, which contains 20 downstream image classification tasks. The [software toolkit](https://github.com/Computer-Vision-in-the-Wild/Elevater_Toolkit_IC) is also released to ease the process of onboarding new models. It will be hosted as a challenge at the [CV in the Wild Workshop @ ECCV 2022](https://computer-vision-in-the-wild.github.io/eccv-2022/). We hope our benchmark and toolkit can encourage the community to solve the challenge of image classification in the wild!
96 |
97 | Please see more instructions: [[Benchmark]](https://computer-vision-in-the-wild.github.io/ELEVATER/) [[Toolkit]](https://github.com/Computer-Vision-in-the-Wild/Elevater_Toolkit_IC) [[Paper]](https://arxiv.org/abs/2204.08790)
98 |
99 | ## Citation
100 |
101 | If you find this repo useful for your project, please consider citing it with the following BibTeX entry:
102 |
103 | @misc{yang2022unified,
104 | title={Unified Contrastive Learning in Image-Text-Label Space},
105 | author={Jianwei Yang and Chunyuan Li and Pengchuan Zhang and Bin Xiao and Ce Liu and Lu Yuan and Jianfeng Gao},
106 | year={2022},
107 | eprint={2204.03610},
108 | archivePrefix={arXiv},
109 | primaryClass={cs.CV}
110 | }
111 |
112 | ## Acknowledgement
113 |
114 | Our codebase is built on [Swin Transformer](https://github.com/microsoft/Swin-Transformer), [Focal Transformer](https://github.com/microsoft/Focal-Transformer) and [FocalNet](https://github.com/microsoft/FocalNet).
115 |
116 | ## Contributing
117 |
118 | This project welcomes contributions and suggestions. Most contributions require you to agree to a
119 | Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
120 | the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.
121 |
122 | When you submit a pull request, a CLA bot will automatically determine whether you need to provide
123 | a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions
124 | provided by the bot. You will only need to do this once across all repos using our CLA.
125 |
126 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
127 | For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
128 | contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
129 |
130 | ## Trademarks
131 |
132 | This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft
133 | trademarks or logos is subject to and must follow
134 | [Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general).
135 | Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship.
136 | Any use of third-party trademarks or logos is subject to those third parties' policies.
137 |
--------------------------------------------------------------------------------
/SECURITY.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## Security
4 |
5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/).
6 |
7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://docs.microsoft.com/en-us/previous-versions/tn-archive/cc751383(v=technet.10)), please report it to us as described below.
8 |
9 | ## Reporting Security Issues
10 |
11 | **Please do not report security vulnerabilities through public GitHub issues.**
12 |
13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://msrc.microsoft.com/create-report).
14 |
15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://www.microsoft.com/en-us/msrc/pgp-key-msrc).
16 |
17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc).
18 |
19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
20 |
21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
22 | * Full paths of source file(s) related to the manifestation of the issue
23 | * The location of the affected source code (tag/branch/commit or direct URL)
24 | * Any special configuration required to reproduce the issue
25 | * Step-by-step instructions to reproduce the issue
26 | * Proof-of-concept or exploit code (if possible)
27 | * Impact of the issue, including how an attacker might exploit the issue
28 |
29 | This information will help us triage your report more quickly.
30 |
31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://microsoft.com/msrc/bounty) page for more details about our active programs.
32 |
33 | ## Preferred Languages
34 |
35 | We prefer all communications to be in English.
36 |
37 | ## Policy
38 |
39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd).
40 |
41 |
--------------------------------------------------------------------------------
/SUPPORT.md:
--------------------------------------------------------------------------------
1 | # TODO: The maintainer of this repo has not yet edited this file
2 |
3 | **REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project?
4 |
5 | - **No CSS support:** Fill out this template with information about how to file issues and get help.
6 | - **Yes CSS support:** Fill out an intake form at [aka.ms/spot](https://aka.ms/spot). CSS will work with/help you to determine next steps. More details also available at [aka.ms/onboardsupport](https://aka.ms/onboardsupport).
7 | - **Not sure?** Fill out a SPOT intake as though the answer were "Yes". CSS will help you decide.
8 |
9 | *Then remove this first heading from this SUPPORT.MD file before publishing your repo.*
10 |
11 | # Support
12 |
13 | ## How to file issues and get help
14 |
15 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing
16 | issues before filing new issues to avoid duplicates. For new issues, file your bug or
17 | feature request as a new Issue.
18 |
19 | For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE
20 | FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER
21 | CHANNEL. WHERE WILL YOU HELP PEOPLE?**.
22 |
23 | ## Microsoft Support Policy
24 |
25 | Support for this **PROJECT or PRODUCT** is limited to the resources listed above.
26 |
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Unified Contrastive Learning (UniCL)
3 | # Copyright (c) 2022 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Jianwei Yang (jianwyan@microsoft.com)
6 | # Based on Swin Transformer written by Ze Liu
7 | # --------------------------------------------------------
8 |
9 | import os
10 | import yaml
11 | from yacs.config import CfgNode as CN
12 |
13 | _C = CN()
14 | _C.VERBOSE = False
15 |
16 | # Base config files
17 | _C.BASE = ['']
18 |
19 | # -----------------------------------------------------------------------------
20 | # Data settings
21 | # -----------------------------------------------------------------------------
22 | _C.DATA = CN()
23 | # Batch size for a single GPU, could be overwritten by command line argument
24 | _C.DATA.BATCH_SIZE = 128
25 | # Path to dataset, could be overwritten by command line argument
26 | _C.DATA.DATA_PATH = ''
27 | # Dataset name
28 | _C.DATA.DATASET = 'imagenet'
29 | # Input image size
30 | _C.DATA.IMG_SIZE = 224
31 | # Interpolation to resize image (random, bilinear, bicubic)
32 | _C.DATA.INTERPOLATION = 'bicubic'
33 | # Use zipped dataset instead of folder dataset
34 | # could be overwritten by command line argument
35 | _C.DATA.ZIP_MODE = False
36 | # Cache Data in Memory, could be overwritten by command line argument
37 | _C.DATA.CACHE_MODE = 'part'
38 | # Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.
39 | _C.DATA.PIN_MEMORY = True
40 | # Number of data loading threads
41 | _C.DATA.NUM_WORKERS = 8
42 |
43 | # -----------------------------------------------------------------------------
44 | # Model settings
45 | # -----------------------------------------------------------------------------
46 | _C.MODEL = CN()
47 | # Model name
48 | _C.MODEL.NAME = ''
49 | # Checkpoint to resume, could be overwritten by command line argument
50 | _C.MODEL.RESUME = ''
51 | # Number of classes, overwritten in data preparation
52 | _C.MODEL.NUM_CLASSES = 0
53 | # Label Smoothing
54 | _C.MODEL.LABEL_SMOOTHING = 0.1
55 | # Whether load pretrained model
56 | _C.MODEL.PRETRAINED = ''
57 | # Projection dimension
58 | _C.MODEL.DIM_PROJECTION = 512
59 | # Mode specific
60 | _C.MODEL.SPEC = CN(new_allowed=True)
61 | # -----------------------------------------------------------------------------
62 | # Build Image Encoder
63 | # -----------------------------------------------------------------------------
64 | _C.MODEL.IMAGE_ENCODER = CN()
65 | # Image encoder type
66 | _C.MODEL.IMAGE_ENCODER.TYPE = 'swin'
67 | # Input image size
68 | _C.MODEL.IMAGE_ENCODER.IMG_SIZE = 224
69 | # Dropout rate
70 | _C.MODEL.IMAGE_ENCODER.DROP_RATE = 0.0
71 | # Drop path rate
72 | _C.MODEL.IMAGE_ENCODER.DROP_PATH_RATE = 0.1
73 |
74 | # Swin Transformer parameters
75 | _C.MODEL.IMAGE_ENCODER.SWIN = CN()
76 | _C.MODEL.IMAGE_ENCODER.SWIN.PATCH_SIZE = 4
77 | _C.MODEL.IMAGE_ENCODER.SWIN.IN_CHANS = 3
78 | _C.MODEL.IMAGE_ENCODER.SWIN.EMBED_DIM = 96
79 | _C.MODEL.IMAGE_ENCODER.SWIN.DEPTHS = [2, 2, 6, 2]
80 | _C.MODEL.IMAGE_ENCODER.SWIN.NUM_HEADS = [3, 6, 12, 24]
81 | _C.MODEL.IMAGE_ENCODER.SWIN.WINDOW_SIZE = 7
82 | _C.MODEL.IMAGE_ENCODER.SWIN.MLP_RATIO = 4.
83 | _C.MODEL.IMAGE_ENCODER.SWIN.QKV_BIAS = True
84 | _C.MODEL.IMAGE_ENCODER.SWIN.QK_SCALE = None
85 | _C.MODEL.IMAGE_ENCODER.SWIN.APE = False
86 | _C.MODEL.IMAGE_ENCODER.SWIN.PATCH_NORM = True
87 |
88 | # FocalNet parameters
89 | _C.MODEL.IMAGE_ENCODER.FOCAL = CN()
90 | _C.MODEL.IMAGE_ENCODER.FOCAL.PATCH_SIZE = 4
91 | _C.MODEL.IMAGE_ENCODER.FOCAL.IN_CHANS = 3
92 | _C.MODEL.IMAGE_ENCODER.FOCAL.EMBED_DIM = 96
93 | _C.MODEL.IMAGE_ENCODER.FOCAL.DEPTHS = [2, 2, 6, 2]
94 | _C.MODEL.IMAGE_ENCODER.FOCAL.MLP_RATIO = 4.
95 | _C.MODEL.IMAGE_ENCODER.FOCAL.PATCH_NORM = True
96 | _C.MODEL.IMAGE_ENCODER.FOCAL.FOCAL_LEVELS = [2, 2, 2, 2]
97 | _C.MODEL.IMAGE_ENCODER.FOCAL.FOCAL_WINDOWS = [3, 3, 3, 3]
98 | _C.MODEL.IMAGE_ENCODER.FOCAL.FOCAL_FACTORS = [2, 2, 2, 2]
99 | _C.MODEL.IMAGE_ENCODER.FOCAL.USE_CONV_EMBED = False
100 | _C.MODEL.IMAGE_ENCODER.FOCAL.USE_LAYERSCALE = False
101 | _C.MODEL.IMAGE_ENCODER.FOCAL.USE_POSTLN = False
102 |
103 | # -----------------------------------------------------------------------------
104 | # Build Text Encoder
105 | # -----------------------------------------------------------------------------
106 | _C.MODEL.TEXT_ENCODER = CN()
107 |
108 | _C.MODEL.TEXT_ENCODER.NAME = 'transformer'
109 | _C.MODEL.TEXT_ENCODER.LOAD_PRETRAINED = False
110 | _C.MODEL.TEXT_ENCODER.PRETRAINED = ''
111 | _C.MODEL.TEXT_ENCODER.TOKENIZER = 'clip'
112 | _C.MODEL.TEXT_ENCODER.CONTEXT_LENGTH = 77
113 | _C.MODEL.TEXT_ENCODER.WIDTH = 1024
114 | _C.MODEL.TEXT_ENCODER.HEADS = 16
115 | _C.MODEL.TEXT_ENCODER.LAYERS = 12
116 | _C.MODEL.TEXT_ENCODER.AUTOGRESSIVE = True
117 |
118 | # -----------------------------------------------------------------------------
119 | # Training settings
120 | # -----------------------------------------------------------------------------
121 | _C.TRAIN = CN()
122 | _C.TRAIN.START_EPOCH = 0
123 | _C.TRAIN.EPOCHS = 32
124 | _C.TRAIN.WARMUP_EPOCHS = 5
125 | _C.TRAIN.WEIGHT_DECAY = 0.1
126 | _C.TRAIN.BASE_LR = 5e-4
127 | _C.TRAIN.WARMUP_LR = 5e-7
128 | _C.TRAIN.MIN_LR = 5e-6
129 | # Clip gradient norm
130 | _C.TRAIN.CLIP_GRAD = 5.0
131 | # Auto resume from latest checkpoint
132 | _C.TRAIN.AUTO_RESUME = True
133 | # Gradient accumulation steps
134 | # could be overwritten by command line argument
135 | _C.TRAIN.ACCUMULATION_STEPS = 0
136 | # Whether to use gradient checkpointing to save memory
137 | # could be overwritten by command line argument
138 | _C.TRAIN.USE_CHECKPOINT = False
139 |
140 | # LR scheduler
141 | _C.TRAIN.LR_SCHEDULER = CN()
142 | _C.TRAIN.LR_SCHEDULER.NAME = 'cosine'
143 | # Epoch interval to decay LR, used in StepLRScheduler
144 | _C.TRAIN.LR_SCHEDULER.DECAY_EPOCHS = 30
145 | # LR decay rate, used in StepLRScheduler
146 | _C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1
147 |
148 | # Optimizer
149 | _C.TRAIN.OPTIMIZER = CN()
150 | _C.TRAIN.OPTIMIZER.NAME = 'adamw'
151 | # Optimizer Epsilon
152 | _C.TRAIN.OPTIMIZER.EPS = 1e-8
153 | # Optimizer Betas
154 | _C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999)
155 | # SGD momentum
156 | _C.TRAIN.OPTIMIZER.MOMENTUM = 0.9
157 |
158 | # -----------------------------------------------------------------------------
159 | # Augmentation settings
160 | # -----------------------------------------------------------------------------
161 | _C.AUG = CN()
162 | # Color jitter factor
163 | _C.AUG.COLOR_JITTER = 0.4
164 | # Use AutoAugment policy. "v0" or "original"
165 | _C.AUG.AUTO_AUGMENT = 'rand-m9-mstd0.5-inc1'
166 | # Random erase prob
167 | _C.AUG.REPROB = 0.25
168 | # Random erase mode
169 | _C.AUG.REMODE = 'pixel'
170 | # Random erase count
171 | _C.AUG.RECOUNT = 1
172 | # Mixup alpha, mixup enabled if > 0
173 | _C.AUG.MIXUP = 0.8
174 | # Cutmix alpha, cutmix enabled if > 0
175 | _C.AUG.CUTMIX = 1.0
176 | # Cutmix min/max ratio, overrides alpha and enables cutmix if set
177 | _C.AUG.CUTMIX_MINMAX = None
178 | # Probability of performing mixup or cutmix when either/both is enabled
179 | _C.AUG.MIXUP_PROB = 1.0
180 | # Probability of switching to cutmix when both mixup and cutmix enabled
181 | _C.AUG.MIXUP_SWITCH_PROB = 0.5
182 | # How to apply mixup/cutmix params. Per "batch", "pair", or "elem"
183 | _C.AUG.MIXUP_MODE = 'batch'
184 |
185 | # -----------------------------------------------------------------------------
186 | # Testing settings
187 | # -----------------------------------------------------------------------------
188 | _C.TEST = CN()
189 | # Whether to use center crop when testing
190 | _C.TEST.CROP = True
191 |
192 | # -----------------------------------------------------------------------------
193 | # Misc
194 | # -----------------------------------------------------------------------------
195 | # Mixed precision opt level ('O0', 'O1', 'O2'); 'O0' disables amp
196 | # overwritten by command line argument
197 | _C.AMP_OPT_LEVEL = ''
198 | # Path to output folder, overwritten by command line argument
199 | _C.OUTPUT = ''
200 | # Tag of experiment, overwritten by command line argument
201 | _C.TAG = 'default'
202 | # Frequency to save checkpoint
203 | _C.SAVE_FREQ = 1
204 | # Frequency to log info
205 | _C.PRINT_FREQ = 100
206 | # Fixed random seed
207 | _C.SEED = 0
208 | # Perform evaluation only, overwritten by command line argument
209 | _C.EVAL_MODE = False
210 | # Test throughput only, overwritten by command line argument
211 | _C.THROUGHPUT_MODE = False
212 | # Debug mode: skip dataloader initialization, overwritten by command line argument
213 | _C.DEBUG_MODE = False
214 | # local rank for DistributedDataParallel, given by command line argument
215 | _C.LOCAL_RANK = 0
216 |
217 |
218 | def _update_config_from_file(config, cfg_file):
219 | config.defrost()
220 | with open(cfg_file, 'r') as f:
221 | yaml_cfg = yaml.load(f, Loader=yaml.FullLoader)
222 |
223 | for cfg in yaml_cfg.setdefault('BASE', ['']):
224 | if cfg:
225 | _update_config_from_file(
226 | config, os.path.join(os.path.dirname(cfg_file), cfg)
227 | )
228 | print('=> merge config from {}'.format(cfg_file))
229 | config.merge_from_file(cfg_file)
230 | config.freeze()
231 |
232 |
233 | def update_config(config, args):
234 | _update_config_from_file(config, args.cfg)
235 |
236 | config.defrost()
237 | if args.opts:
238 | config.merge_from_list(args.opts)
239 |
240 | # merge from specific arguments
241 | if args.batch_size:
242 | config.DATA.BATCH_SIZE = args.batch_size
243 | if args.dataset:
244 | config.DATA.DATASET = args.dataset
245 | if args.data_path:
246 | config.DATA.DATA_PATH = args.data_path
247 | if args.zip:
248 | config.DATA.ZIP_MODE = True
249 | if args.cache_mode:
250 | config.DATA.CACHE_MODE = args.cache_mode
251 | if args.resume:
252 | config.MODEL.RESUME = args.resume
253 | if args.accumulation_steps:
254 | config.TRAIN.ACCUMULATION_STEPS = args.accumulation_steps
255 | if args.use_checkpoint:
256 | config.TRAIN.USE_CHECKPOINT = True
257 | if args.amp_opt_level:
258 | config.AMP_OPT_LEVEL = args.amp_opt_level
259 | if args.output:
260 | config.OUTPUT = args.output
261 | if args.tag:
262 | config.TAG = args.tag
263 | if args.eval:
264 | config.EVAL_MODE = True
265 | if args.throughput:
266 | config.THROUGHPUT_MODE = True
267 | if args.debug:
268 | config.DEBUG_MODE = True
269 |
270 | if config.DATA.DATASET == 'imagewoof':
271 | config.MODEL.NUM_CLASSES = 10
272 |
273 | # set local rank for distributed training
274 | config.LOCAL_RANK = args.local_rank
275 |
276 | # output folder
277 | config.OUTPUT = os.path.join(config.OUTPUT, config.MODEL.NAME, config.TAG)
278 |
279 | config.freeze()
280 |
281 |
282 | def get_config(args):
283 | """Get a yacs CfgNode object with default values."""
284 | # Return a clone so that the defaults will not be altered
285 | # This is for the "local variable" use pattern
286 | config = _C.clone()
287 | update_config(config, args)
288 |
289 | return config
290 |
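The `_update_config_from_file` / `update_config` pair above resolves `BASE` includes depth-first before merging the leaf file, so settings in a child file override its parents. A minimal stdlib-only sketch of that merge order, with plain dicts standing in for yacs `CfgNode` objects and an in-memory table standing in for YAML files (the file names here are made up for illustration):

```python
# Sketch of the depth-first BASE merge; FILES stands in for YAML on disk.
FILES = {
    'base.yaml': {'TRAIN': {'EPOCHS': 300, 'LR': 1e-3}},
    'leaf.yaml': {'BASE': ['base.yaml'], 'TRAIN': {'LR': 5e-4}},
}

def deep_merge(dst, src):
    # Recursively merge src into dst; src values take precedence.
    for k, v in src.items():
        if isinstance(v, dict) and isinstance(dst.get(k), dict):
            deep_merge(dst[k], v)
        else:
            dst[k] = v

def update_from_file(config, name):
    cfg = FILES[name]
    for base in cfg.get('BASE', []):   # recurse into parents first
        if base:
            update_from_file(config, base)
    deep_merge(config, {k: v for k, v in cfg.items() if k != 'BASE'})

config = {}
update_from_file(config, 'leaf.yaml')
# Parent supplies EPOCHS; the leaf file overrides LR.
```

Because parents are merged first, `config` ends up with `EPOCHS` from `base.yaml` and `LR` from `leaf.yaml`, mirroring how `config.merge_from_file(cfg_file)` runs after the recursive `BASE` pass.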
--------------------------------------------------------------------------------
/configs/unicl_swin_base.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | NAME: unicl_swin_base
3 | DIM_PROJECTION: 512
4 | IMAGE_ENCODER:
5 | TYPE: swin
6 | DROP_PATH_RATE: 0.5
7 | SWIN:
8 | EMBED_DIM: 128
9 | DEPTHS: [ 2, 2, 18, 2 ]
10 | NUM_HEADS: [ 4, 8, 16, 32 ]
11 | WINDOW_SIZE: 7
12 | TEXT_ENCODER:
13 | NAME: 'transformer'
14 | WIDTH: 512
15 | HEADS: 8
16 | LAYERS: 12
17 |
--------------------------------------------------------------------------------
/configs/unicl_swin_tiny.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | NAME: unicl_swin_tiny
3 | DIM_PROJECTION: 512
4 | IMAGE_ENCODER:
5 | TYPE: swin
6 | DROP_PATH_RATE: 0.2
7 | SWIN:
8 | EMBED_DIM: 96
9 | DEPTHS: [ 2, 2, 6, 2 ]
10 | NUM_HEADS: [ 3, 6, 12, 24 ]
11 | WINDOW_SIZE: 7
12 | TEXT_ENCODER:
13 | NAME: 'transformer'
14 | WIDTH: 512
15 | HEADS: 8
16 | LAYERS: 12
17 |
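Both configs follow the standard Swin scaling, in which the channel dimension doubles after each patch-merging stage (the doubling itself is implemented in `swin_transformer.py`; the arithmetic below just illustrates it). A quick sketch of the per-stage widths implied by the two `EMBED_DIM` values:

```python
def stage_dims(embed_dim, num_stages=4):
    # Swin doubles the embedding dim after each patch-merging stage.
    return [embed_dim * 2 ** i for i in range(num_stages)]

print(stage_dims(96))   # tiny:  [96, 192, 384, 768]
print(stage_dims(128))  # base:  [128, 256, 512, 1024]
```

`NUM_HEADS` scales the same way in both files (tiny: 3/6/12/24, base: 4/8/16/32), which keeps a fixed 32-dim attention head at every stage.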
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .build import build_loader
--------------------------------------------------------------------------------
/data/build.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Swin Transformer
3 | # Copyright (c) 2021 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Ze Liu
6 | # --------------------------------------------------------
7 |
8 | import os
9 | import torch
10 | import numpy as np
11 | import torch.distributed as dist
12 | from torchvision import datasets, transforms
13 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
14 | from timm.data import Mixup
15 | from timm.data import create_transform
16 | from timm.data.transforms import _pil_interp
17 |
18 | from .cached_image_folder import CachedImageFolder
19 | from .samplers import SubsetRandomSampler
20 | from model.text_encoder import build_tokenizer
21 |
22 | def build_loader(config):
23 | config.defrost()
24 | dataset_train, config.MODEL.NUM_CLASSES = build_dataset(is_train=True, config=config)
25 | config.freeze()
26 |     print(f"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} successfully built train dataset")
27 | dataset_val, _ = build_dataset(is_train=False, config=config)
28 |     print(f"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} successfully built val dataset")
29 |
30 | num_tasks = dist.get_world_size()
31 | global_rank = dist.get_rank()
32 | if config.DATA.ZIP_MODE and config.DATA.CACHE_MODE == 'part':
33 | indices = np.arange(dist.get_rank(), len(dataset_train), dist.get_world_size())
34 | sampler_train = SubsetRandomSampler(indices)
35 | else:
36 | sampler_train = torch.utils.data.DistributedSampler(
37 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
38 | )
39 |
40 | indices = np.arange(dist.get_rank(), len(dataset_val), dist.get_world_size())
41 | sampler_val = SubsetRandomSampler(indices)
42 |
43 | data_loader_train = torch.utils.data.DataLoader(
44 | dataset_train, sampler=sampler_train,
45 | batch_size=config.DATA.BATCH_SIZE,
46 | num_workers=config.DATA.NUM_WORKERS,
47 | pin_memory=config.DATA.PIN_MEMORY,
48 | drop_last=True,
49 | )
50 |
51 | data_loader_val = torch.utils.data.DataLoader(
52 | dataset_val, sampler=sampler_val,
53 | batch_size=config.DATA.BATCH_SIZE,
54 | shuffle=False,
55 | num_workers=config.DATA.NUM_WORKERS,
56 | pin_memory=config.DATA.PIN_MEMORY,
57 | drop_last=False
58 | )
59 |
60 | # setup mixup / cutmix
61 | mixup_fn = None
62 | mixup_active = config.AUG.MIXUP > 0 or config.AUG.CUTMIX > 0. or config.AUG.CUTMIX_MINMAX is not None
63 | if mixup_active:
64 | mixup_fn = Mixup(
65 | mixup_alpha=config.AUG.MIXUP, cutmix_alpha=config.AUG.CUTMIX, cutmix_minmax=config.AUG.CUTMIX_MINMAX,
66 | prob=config.AUG.MIXUP_PROB, switch_prob=config.AUG.MIXUP_SWITCH_PROB, mode=config.AUG.MIXUP_MODE,
67 | label_smoothing=config.MODEL.LABEL_SMOOTHING, num_classes=config.MODEL.NUM_CLASSES)
68 |
69 | return dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn
70 |
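In the `ZIP_MODE` + `'part'` cache branch above, each rank builds a strided index subset via `np.arange(rank, len(dataset), world_size)` instead of using `DistributedSampler`, so every sample is owned by exactly one rank. A stdlib sketch of that partitioning (the world size and dataset length are made up for illustration):

```python
def rank_indices(rank, world_size, n):
    # Same striding as np.arange(rank, n, world_size) in build_loader.
    return list(range(rank, n, world_size))

n, world_size = 10, 4
parts = [rank_indices(r, world_size, n) for r in range(world_size)]
# parts == [[0, 4, 8], [1, 5, 9], [2, 6], [3, 7]]
```

The strides are disjoint and together cover every index exactly once, which matches the `'part'` caching rule in `DatasetFolder.init_cache` (each rank holds the bytes only for its own stride).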
71 |
72 | def build_dataset(is_train, config):
73 | transform = build_transform(is_train, config)
74 | tokenizer = build_tokenizer(config['MODEL']['TEXT_ENCODER'])
75 |
76 | if config.DATA.DATASET == 'imagenet':
77 | prefix = 'train' if is_train else 'val'
78 | if config.DATA.ZIP_MODE:
79 | from .zipdata import ZipData
80 | if is_train:
81 | datapath = os.path.join(config.DATA.DATA_PATH, 'train.zip')
82 | data_map = os.path.join(config.DATA.DATA_PATH, 'train_map.txt')
83 | else:
84 | datapath = os.path.join(config.DATA.DATA_PATH, 'val.zip')
85 | data_map = os.path.join(config.DATA.DATA_PATH, 'val_map.txt')
86 | dataset = ZipData(
87 | datapath, data_map,
88 | tokenizer,
89 | transform,
90 | )
91 | else:
92 | root = os.path.join(config.DATA.DATA_PATH, prefix)
93 | dataset = datasets.ImageFolder(root, transform=transform)
94 | nb_classes = 1000
95 | elif config.DATA.DATASET == 'imagewoof':
96 | prefix = 'train' if is_train else 'val'
97 | root = os.path.join(config.DATA.DATA_PATH, prefix)
98 | dataset = datasets.ImageFolder(root, transform=transform)
99 |         nb_classes = 10  # Imagewoof is a 10-class subset of ImageNet
100 | else:
101 |         raise NotImplementedError("Only the 'imagenet' and 'imagewoof' datasets are supported.")
102 |
103 | return dataset, nb_classes
104 |
105 |
106 | def build_transform(is_train, config):
107 | resize_im = config.DATA.IMG_SIZE > 32
108 | if is_train:
109 | # this should always dispatch to transforms_imagenet_train
110 | transform = create_transform(
111 | input_size=config.DATA.IMG_SIZE,
112 | is_training=True,
113 | color_jitter=config.AUG.COLOR_JITTER if config.AUG.COLOR_JITTER > 0 else None,
114 | auto_augment=config.AUG.AUTO_AUGMENT if config.AUG.AUTO_AUGMENT != 'none' else None,
115 | re_prob=config.AUG.REPROB,
116 | re_mode=config.AUG.REMODE,
117 | re_count=config.AUG.RECOUNT,
118 | interpolation=config.DATA.INTERPOLATION,
119 | )
120 | if not resize_im:
121 | # replace RandomResizedCropAndInterpolation with
122 | # RandomCrop
123 | transform.transforms[0] = transforms.RandomCrop(config.DATA.IMG_SIZE, padding=4)
124 | return transform
125 |
126 | t = []
127 | if resize_im:
128 | if config.TEST.CROP:
129 | size = int((256 / 224) * config.DATA.IMG_SIZE)
130 | t.append(
131 | transforms.Resize(size, interpolation=_pil_interp(config.DATA.INTERPOLATION)),
132 | # to maintain same ratio w.r.t. 224 images
133 | )
134 | t.append(transforms.CenterCrop(config.DATA.IMG_SIZE))
135 | else:
136 | t.append(
137 | transforms.Resize((config.DATA.IMG_SIZE, config.DATA.IMG_SIZE),
138 | interpolation=_pil_interp(config.DATA.INTERPOLATION))
139 | )
140 |
141 | t.append(transforms.ToTensor())
142 | t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))
143 | return transforms.Compose(t)
144 |
--------------------------------------------------------------------------------
/data/cached_image_folder.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Swin Transformer
3 | # Copyright (c) 2021 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Ze Liu
6 | # --------------------------------------------------------
7 |
8 | import io
9 | import os
10 | import time
11 | import torch.distributed as dist
12 | import torch.utils.data as data
13 | from PIL import Image
14 |
15 | from .zipreader import is_zip_path, ZipReader
16 |
17 |
18 | def has_file_allowed_extension(filename, extensions):
19 |     """Checks if a file has an allowed extension.
20 |     Args:
21 |         filename (string): path to a file
22 |         extensions (iterable of string): allowed extensions, lowercase with leading dot
23 |     Returns:
24 |         bool: True if the filename ends with one of the allowed extensions
25 |     """
25 | filename_lower = filename.lower()
26 | return any(filename_lower.endswith(ext) for ext in extensions)
27 |
28 |
29 | def find_classes(dir):
30 | classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
31 | classes.sort()
32 | class_to_idx = {classes[i]: i for i in range(len(classes))}
33 | return classes, class_to_idx
34 |
35 |
36 | def make_dataset(dir, class_to_idx, extensions):
37 | images = []
38 | dir = os.path.expanduser(dir)
39 | for target in sorted(os.listdir(dir)):
40 | d = os.path.join(dir, target)
41 | if not os.path.isdir(d):
42 | continue
43 |
44 | for root, _, fnames in sorted(os.walk(d)):
45 | for fname in sorted(fnames):
46 | if has_file_allowed_extension(fname, extensions):
47 | path = os.path.join(root, fname)
48 | item = (path, class_to_idx[target])
49 | images.append(item)
50 |
51 | return images
52 |
53 |
54 | def make_dataset_with_ann(ann_file, img_prefix, extensions):
55 | images = []
56 | with open(ann_file, "r") as f:
57 | contents = f.readlines()
58 | for line_str in contents:
59 | path_contents = [c for c in line_str.split('\t')]
60 | im_file_name = path_contents[0]
61 | class_index = int(path_contents[1])
62 |
63 | assert str.lower(os.path.splitext(im_file_name)[-1]) in extensions
64 | item = (os.path.join(img_prefix, im_file_name), class_index)
65 |
66 | images.append(item)
67 |
68 | return images
69 |
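`make_dataset_with_ann` expects one tab-separated `path<TAB>class_index` record per line of the annotation file. A sketch of the same parsing against an in-memory annotation string (the synset-style file names below are hypothetical), loosely mirroring the loop above but accepting any file-like object:

```python
import os
from io import StringIO

def parse_ann(f, img_prefix, extensions):
    # Loosely mirrors make_dataset_with_ann: path<TAB>class_index per line.
    items = []
    for line in f:
        name, idx = line.rstrip('\n').split('\t')[:2]
        assert os.path.splitext(name)[-1].lower() in extensions
        items.append((os.path.join(img_prefix, name), int(idx)))
    return items

ann = StringIO("n01440764/a.JPEG\t0\nn01443537/b.JPEG\t1\n")
items = parse_ann(ann, 'train', ['.jpeg'])
```

Note the extension check compares the lowercased suffix, so mixed-case names like `.JPEG` pass against the lowercase `IMG_EXTENSIONS` list, just as in the original.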
70 |
71 | class DatasetFolder(data.Dataset):
72 | """A generic data loader where the samples are arranged in this way: ::
73 | root/class_x/xxx.ext
74 | root/class_x/xxy.ext
75 | root/class_x/xxz.ext
76 | root/class_y/123.ext
77 | root/class_y/nsdf3.ext
78 | root/class_y/asd932_.ext
79 | Args:
80 | root (string): Root directory path.
81 | loader (callable): A function to load a sample given its path.
82 | extensions (list[string]): A list of allowed extensions.
83 | transform (callable, optional): A function/transform that takes in
84 | a sample and returns a transformed version.
85 |             E.g., ``transforms.RandomCrop`` for images.
86 | target_transform (callable, optional): A function/transform that takes
87 | in the target and transforms it.
88 | Attributes:
89 | samples (list): List of (sample path, class_index) tuples
90 | """
91 |
92 | def __init__(self, root, loader, extensions, ann_file='', img_prefix='', transform=None, target_transform=None,
93 | cache_mode="no"):
94 | # image folder mode
95 | if ann_file == '':
96 | _, class_to_idx = find_classes(root)
97 | samples = make_dataset(root, class_to_idx, extensions)
98 | # zip mode
99 | else:
100 | samples = make_dataset_with_ann(os.path.join(root, ann_file),
101 | os.path.join(root, img_prefix),
102 | extensions)
103 |
104 | if len(samples) == 0:
105 | raise (RuntimeError("Found 0 files in subfolders of: " + root + "\n" +
106 | "Supported extensions are: " + ",".join(extensions)))
107 |
108 | self.root = root
109 | self.loader = loader
110 | self.extensions = extensions
111 |
112 | self.samples = samples
113 | self.labels = [y_1k for _, y_1k in samples]
114 | self.classes = list(set(self.labels))
115 |
116 | self.transform = transform
117 | self.target_transform = target_transform
118 |
119 | self.cache_mode = cache_mode
120 | if self.cache_mode != "no":
121 | self.init_cache()
122 |
123 | def init_cache(self):
124 | assert self.cache_mode in ["part", "full"]
125 | n_sample = len(self.samples)
126 | global_rank = dist.get_rank()
127 | world_size = dist.get_world_size()
128 |
129 | samples_bytes = [None for _ in range(n_sample)]
130 | start_time = time.time()
131 | for index in range(n_sample):
132 |             if index % max(n_sample // 10, 1) == 0:  # guard tiny datasets against modulo-by-zero
133 | t = time.time() - start_time
134 | print(f'global_rank {dist.get_rank()} cached {index}/{n_sample} takes {t:.2f}s per block')
135 | start_time = time.time()
136 | path, target = self.samples[index]
137 | if self.cache_mode == "full":
138 | samples_bytes[index] = (ZipReader.read(path), target)
139 | elif self.cache_mode == "part" and index % world_size == global_rank:
140 | samples_bytes[index] = (ZipReader.read(path), target)
141 | else:
142 | samples_bytes[index] = (path, target)
143 | self.samples = samples_bytes
144 |
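In `'part'` mode, `init_cache` keeps decoded bytes only for the indices this rank owns (`index % world_size == global_rank`) and leaves plain paths everywhere else, trading memory for occasional re-reads. A sketch of that ownership rule with stand-in data (no real `ZipReader` reads; the tuples just tag what would be stored):

```python
def cache_plan(paths, rank, world_size, mode):
    # 'full' caches every entry; 'part' caches only this rank's stride.
    plan = []
    for index, path in enumerate(paths):
        if mode == 'full' or (mode == 'part' and index % world_size == rank):
            plan.append(('bytes', path))   # would hold ZipReader.read(path)
        else:
            plan.append(('path', path))    # left on disk, read lazily
    return plan

plan = cache_plan(['a', 'b', 'c', 'd'], rank=1, world_size=2, mode='part')
# [('path', 'a'), ('bytes', 'b'), ('path', 'c'), ('bytes', 'd')]
```

This pairs with the `SubsetRandomSampler` striding in `build_loader`: because each rank samples exactly the indices it cached, the byte cache is always hit in `'part'` mode.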
145 | def __getitem__(self, index):
146 | """
147 | Args:
148 | index (int): Index
149 | Returns:
150 | tuple: (sample, target) where target is class_index of the target class.
151 | """
152 | path, target = self.samples[index]
153 | sample = self.loader(path)
154 | if self.transform is not None:
155 | sample = self.transform(sample)
156 | if self.target_transform is not None:
157 | target = self.target_transform(target)
158 |
159 | return sample, target
160 |
161 | def __len__(self):
162 | return len(self.samples)
163 |
164 | def __repr__(self):
165 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
166 | fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
167 | fmt_str += ' Root Location: {}\n'.format(self.root)
168 | tmp = ' Transforms (if any): '
169 | fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
170 | tmp = ' Target Transforms (if any): '
171 | fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
172 | return fmt_str
173 |
174 |
175 | IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif']
176 |
177 |
178 | def pil_loader(path):
179 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
180 | if isinstance(path, bytes):
181 | img = Image.open(io.BytesIO(path))
182 | elif is_zip_path(path):
183 | data = ZipReader.read(path)
184 | img = Image.open(io.BytesIO(data))
185 | else:
186 | with open(path, 'rb') as f:
187 | img = Image.open(f)
188 | return img.convert('RGB')
189 |
190 |
191 | def accimage_loader(path):
192 | import accimage
193 | try:
194 | return accimage.Image(path)
195 | except IOError:
196 | # Potentially a decoding problem, fall back to PIL.Image
197 | return pil_loader(path)
198 |
199 |
200 | def default_img_loader(path):
201 | from torchvision import get_image_backend
202 | if get_image_backend() == 'accimage':
203 | return accimage_loader(path)
204 | else:
205 | return pil_loader(path)
206 |
207 |
208 | class CachedImageFolder(DatasetFolder):
209 | """A generic data loader where the images are arranged in this way: ::
210 | root/dog/xxx.png
211 | root/dog/xxy.png
212 | root/dog/xxz.png
213 | root/cat/123.png
214 | root/cat/nsdf3.png
215 | root/cat/asd932_.png
216 | Args:
217 | root (string): Root directory path.
218 | transform (callable, optional): A function/transform that takes in an PIL image
219 |             and returns a transformed version. E.g., ``transforms.RandomCrop``
220 | target_transform (callable, optional): A function/transform that takes in the
221 | target and transforms it.
222 | loader (callable, optional): A function to load an image given its path.
223 | Attributes:
224 | imgs (list): List of (image path, class_index) tuples
225 | """
226 |
227 | def __init__(self, root, ann_file='', img_prefix='', transform=None, target_transform=None,
228 | loader=default_img_loader, cache_mode="no"):
229 | super(CachedImageFolder, self).__init__(root, loader, IMG_EXTENSIONS,
230 | ann_file=ann_file, img_prefix=img_prefix,
231 | transform=transform, target_transform=target_transform,
232 | cache_mode=cache_mode)
233 | self.imgs = self.samples
234 |
235 | def __getitem__(self, index):
236 | """
237 | Args:
238 | index (int): Index
239 | Returns:
240 | tuple: (image, target) where target is class_index of the target class.
241 | """
242 | path, target = self.samples[index]
243 | image = self.loader(path)
244 | if self.transform is not None:
245 | img = self.transform(image)
246 | else:
247 | img = image
248 | if self.target_transform is not None:
249 | target = self.target_transform(target)
250 |
251 | return img, target
252 |
--------------------------------------------------------------------------------
/data/imagenet/__init__.py:
--------------------------------------------------------------------------------
1 | from .constants import IMAGENET_CLASSES, IMAGENET_DEFAULT_TEMPLATES
--------------------------------------------------------------------------------
/data/imagenet/constants.py:
--------------------------------------------------------------------------------
1 | IMAGENET_CLASSES = ["tench", "goldfish", "great white shark", "tiger shark", "hammerhead shark", "electric ray", "stingray", "rooster", "hen", "ostrich", "brambling", "goldfinch", "house finch", "junco", "indigo bunting", "American robin", "bulbul", "jay", "magpie", "chickadee", "American dipper", "kite (bird of prey)", "bald eagle", "vulture", "great grey owl", "fire salamander", "smooth newt", "newt", "spotted salamander", "axolotl", "American bullfrog", "tree frog", "tailed frog", "loggerhead sea turtle", "leatherback sea turtle", "mud turtle", "terrapin", "box turtle", "banded gecko", "green iguana", "Carolina anole", "desert grassland whiptail lizard", "agama", "frilled-necked lizard", "alligator lizard", "Gila monster", "European green lizard", "chameleon", "Komodo dragon", "Nile crocodile", "American alligator", "triceratops", "worm snake", "ring-necked snake", "eastern hog-nosed snake", "smooth green snake", "kingsnake", "garter snake", "water snake", "vine snake", "night snake", "boa constrictor", "African rock python", "Indian cobra", "green mamba", "sea snake", "Saharan horned viper", "eastern diamondback rattlesnake", "sidewinder rattlesnake", "trilobite", "harvestman", "scorpion", "yellow garden spider", "barn spider", "European garden spider", "southern black widow", "tarantula", "wolf spider", "tick", "centipede", "black grouse", "ptarmigan", "ruffed grouse", "prairie grouse", "peafowl", "quail", "partridge", "african grey parrot", "macaw", "sulphur-crested cockatoo", "lorikeet", "coucal", "bee eater", "hornbill", "hummingbird", "jacamar", "toucan", "duck", "red-breasted merganser", "goose", "black swan", "tusker", "echidna", "platypus", "wallaby", "koala", "wombat", "jellyfish", "sea anemone", "brain coral", "flatworm", "nematode", "conch", "snail", "slug", "sea slug", "chiton", "chambered nautilus", "Dungeness crab", "rock crab", "fiddler crab", "red king crab", "American lobster", "spiny lobster", "crayfish", "hermit crab", "isopod", "white 
stork", "black stork", "spoonbill", "flamingo", "little blue heron", "great egret", "bittern bird", "crane bird", "limpkin", "common gallinule", "American coot", "bustard", "ruddy turnstone", "dunlin", "common redshank", "dowitcher", "oystercatcher", "pelican", "king penguin", "albatross", "grey whale", "killer whale", "dugong", "sea lion", "Chihuahua", "Japanese Chin", "Maltese", "Pekingese", "Shih Tzu", "King Charles Spaniel", "Papillon", "toy terrier", "Rhodesian Ridgeback", "Afghan Hound", "Basset Hound", "Beagle", "Bloodhound", "Bluetick Coonhound", "Black and Tan Coonhound", "Treeing Walker Coonhound", "English foxhound", "Redbone Coonhound", "borzoi", "Irish Wolfhound", "Italian Greyhound", "Whippet", "Ibizan Hound", "Norwegian Elkhound", "Otterhound", "Saluki", "Scottish Deerhound", "Weimaraner", "Staffordshire Bull Terrier", "American Staffordshire Terrier", "Bedlington Terrier", "Border Terrier", "Kerry Blue Terrier", "Irish Terrier", "Norfolk Terrier", "Norwich Terrier", "Yorkshire Terrier", "Wire Fox Terrier", "Lakeland Terrier", "Sealyham Terrier", "Airedale Terrier", "Cairn Terrier", "Australian Terrier", "Dandie Dinmont Terrier", "Boston Terrier", "Miniature Schnauzer", "Giant Schnauzer", "Standard Schnauzer", "Scottish Terrier", "Tibetan Terrier", "Australian Silky Terrier", "Soft-coated Wheaten Terrier", "West Highland White Terrier", "Lhasa Apso", "Flat-Coated Retriever", "Curly-coated Retriever", "Golden Retriever", "Labrador Retriever", "Chesapeake Bay Retriever", "German Shorthaired Pointer", "Vizsla", "English Setter", "Irish Setter", "Gordon Setter", "Brittany dog", "Clumber Spaniel", "English Springer Spaniel", "Welsh Springer Spaniel", "Cocker Spaniel", "Sussex Spaniel", "Irish Water Spaniel", "Kuvasz", "Schipperke", "Groenendael dog", "Malinois", "Briard", "Australian Kelpie", "Komondor", "Old English Sheepdog", "Shetland Sheepdog", "collie", "Border Collie", "Bouvier des Flandres dog", "Rottweiler", "German Shepherd Dog", "Dobermann", 
"Miniature Pinscher", "Greater Swiss Mountain Dog", "Bernese Mountain Dog", "Appenzeller Sennenhund", "Entlebucher Sennenhund", "Boxer", "Bullmastiff", "Tibetan Mastiff", "French Bulldog", "Great Dane", "St. Bernard", "husky", "Alaskan Malamute", "Siberian Husky", "Dalmatian", "Affenpinscher", "Basenji", "pug", "Leonberger", "Newfoundland dog", "Great Pyrenees dog", "Samoyed", "Pomeranian", "Chow Chow", "Keeshond", "brussels griffon", "Pembroke Welsh Corgi", "Cardigan Welsh Corgi", "Toy Poodle", "Miniature Poodle", "Standard Poodle", "Mexican hairless dog (xoloitzcuintli)", "grey wolf", "Alaskan tundra wolf", "red wolf or maned wolf", "coyote", "dingo", "dhole", "African wild dog", "hyena", "red fox", "kit fox", "Arctic fox", "grey fox", "tabby cat", "tiger cat", "Persian cat", "Siamese cat", "Egyptian Mau", "cougar", "lynx", "leopard", "snow leopard", "jaguar", "lion", "tiger", "cheetah", "brown bear", "American black bear", "polar bear", "sloth bear", "mongoose", "meerkat", "tiger beetle", "ladybug", "ground beetle", "longhorn beetle", "leaf beetle", "dung beetle", "rhinoceros beetle", "weevil", "fly", "bee", "ant", "grasshopper", "cricket insect", "stick insect", "cockroach", "praying mantis", "cicada", "leafhopper", "lacewing", "dragonfly", "damselfly", "red admiral butterfly", "ringlet butterfly", "monarch butterfly", "small white butterfly", "sulphur butterfly", "gossamer-winged butterfly", "starfish", "sea urchin", "sea cucumber", "cottontail rabbit", "hare", "Angora rabbit", "hamster", "porcupine", "fox squirrel", "marmot", "beaver", "guinea pig", "common sorrel horse", "zebra", "pig", "wild boar", "warthog", "hippopotamus", "ox", "water buffalo", "bison", "ram (adult male sheep)", "bighorn sheep", "Alpine ibex", "hartebeest", "impala (antelope)", "gazelle", "arabian camel", "llama", "weasel", "mink", "European polecat", "black-footed ferret", "otter", "skunk", "badger", "armadillo", "three-toed sloth", "orangutan", "gorilla", "chimpanzee", "gibbon", 
"siamang", "guenon", "patas monkey", "baboon", "macaque", "langur", "black-and-white colobus", "proboscis monkey", "marmoset", "white-headed capuchin", "howler monkey", "titi monkey", "Geoffroy's spider monkey", "common squirrel monkey", "ring-tailed lemur", "indri", "Asian elephant", "African bush elephant", "red panda", "giant panda", "snoek fish", "eel", "silver salmon", "rock beauty fish", "clownfish", "sturgeon", "gar fish", "lionfish", "pufferfish", "abacus", "abaya", "academic gown", "accordion", "acoustic guitar", "aircraft carrier", "airliner", "airship", "altar", "ambulance", "amphibious vehicle", "analog clock", "apiary", "apron", "trash can", "assault rifle", "backpack", "bakery", "balance beam", "balloon", "ballpoint pen", "Band-Aid", "banjo", "baluster / handrail", "barbell", "barber chair", "barbershop", "barn", "barometer", "barrel", "wheelbarrow", "baseball", "basketball", "bassinet", "bassoon", "swimming cap", "bath towel", "bathtub", "station wagon", "lighthouse", "beaker", "military hat (bearskin or shako)", "beer bottle", "beer glass", "bell tower", "baby bib", "tandem bicycle", "bikini", "ring binder", "binoculars", "birdhouse", "boathouse", "bobsleigh", "bolo tie", "poke bonnet", "bookcase", "bookstore", "bottle cap", "hunting bow", "bow tie", "brass memorial plaque", "bra", "breakwater", "breastplate", "broom", "bucket", "buckle", "bulletproof vest", "high-speed train", "butcher shop", "taxicab", "cauldron", "candle", "cannon", "canoe", "can opener", "cardigan", "car mirror", "carousel", "tool kit", "cardboard box / carton", "car wheel", "automated teller machine", "cassette", "cassette player", "castle", "catamaran", "CD player", "cello", "mobile phone", "chain", "chain-link fence", "chain mail", "chainsaw", "storage chest", "chiffonier", "bell or wind chime", "china cabinet", "Christmas stocking", "church", "movie theater", "cleaver", "cliff dwelling", "cloak", "clogs", "cocktail shaker", "coffee mug", "coffeemaker", "spiral or coil", 
"combination lock", "computer keyboard", "candy store", "container ship", "convertible", "corkscrew", "cornet", "cowboy boot", "cowboy hat", "cradle", "construction crane", "crash helmet", "crate", "infant bed", "Crock Pot", "croquet ball", "crutch", "cuirass", "dam", "desk", "desktop computer", "rotary dial telephone", "diaper", "digital clock", "digital watch", "dining table", "dishcloth", "dishwasher", "disc brake", "dock", "dog sled", "dome", "doormat", "drilling rig", "drum", "drumstick", "dumbbell", "Dutch oven", "electric fan", "electric guitar", "electric locomotive", "entertainment center", "envelope", "espresso machine", "face powder", "feather boa", "filing cabinet", "fireboat", "fire truck", "fire screen", "flagpole", "flute", "folding chair", "football helmet", "forklift", "fountain", "fountain pen", "four-poster bed", "freight car", "French horn", "frying pan", "fur coat", "garbage truck", "gas mask or respirator", "gas pump", "goblet", "go-kart", "golf ball", "golf cart", "gondola", "gong", "gown", "grand piano", "greenhouse", "radiator grille", "grocery store", "guillotine", "hair clip", "hair spray", "half-track", "hammer", "hamper", "hair dryer", "hand-held computer", "handkerchief", "hard disk drive", "harmonica", "harp", "combine harvester", "hatchet", "holster", "home theater", "honeycomb", "hook", "hoop skirt", "gymnastic horizontal bar", "horse-drawn vehicle", "hourglass", "iPod", "clothes iron", "carved pumpkin", "jeans", "jeep", "T-shirt", "jigsaw puzzle", "rickshaw", "joystick", "kimono", "knee pad", "knot", "lab coat", "ladle", "lampshade", "laptop computer", "lawn mower", "lens cap", "letter opener", "library", "lifeboat", "lighter", "limousine", "ocean liner", "lipstick", "slip-on shoe", "lotion", "music speaker", "loupe magnifying glass", "sawmill", "magnetic compass", "messenger bag", "mailbox", "tights", "one-piece bathing suit", "manhole cover", "maraca", "marimba", "mask", "matchstick", "maypole", "maze", "measuring cup", "medicine 
cabinet", "megalith", "microphone", "microwave oven", "military uniform", "milk can", "minibus", "miniskirt", "minivan", "missile", "mitten", "mixing bowl", "mobile home", "ford model t", "modem", "monastery", "monitor", "moped", "mortar and pestle", "graduation cap", "mosque", "mosquito net", "vespa", "mountain bike", "tent", "computer mouse", "mousetrap", "moving van", "muzzle", "metal nail", "neck brace", "necklace", "baby pacifier", "notebook computer", "obelisk", "oboe", "ocarina", "odometer", "oil filter", "pipe organ", "oscilloscope", "overskirt", "bullock cart", "oxygen mask", "product packet / packaging", "paddle", "paddle wheel", "padlock", "paintbrush", "pajamas", "palace", "pan flute", "paper towel", "parachute", "parallel bars", "park bench", "parking meter", "railroad car", "patio", "payphone", "pedestal", "pencil case", "pencil sharpener", "perfume", "Petri dish", "photocopier", "plectrum", "Pickelhaube", "picket fence", "pickup truck", "pier", "piggy bank", "pill bottle", "pillow", "ping-pong ball", "pinwheel", "pirate ship", "drink pitcher", "block plane", "planetarium", "plastic bag", "plate rack", "farm plow", "plunger", "Polaroid camera", "pole", "police van", "poncho", "pool table", "soda bottle", "plant pot", "potter's wheel", "power drill", "prayer rug", "printer", "prison", "projectile", "projector", "hockey puck", "punching bag", "purse", "quill", "quilt", "race car", "racket", "radiator", "radio", "radio telescope", "rain barrel", "recreational vehicle", "fishing casting reel", "reflex camera", "refrigerator", "remote control", "restaurant", "revolver", "rifle", "rocking chair", "rotisserie", "eraser", "rugby ball", "ruler measuring stick", "sneaker", "safe", "safety pin", "salt shaker", "sandal", "sarong", "saxophone", "scabbard", "weighing scale", "school bus", "schooner", "scoreboard", "CRT monitor", "screw", "screwdriver", "seat belt", "sewing machine", "shield", "shoe store", "shoji screen / room divider", "shopping basket", "shopping 
cart", "shovel", "shower cap", "shower curtain", "ski", "balaclava ski mask", "sleeping bag", "slide rule", "sliding door", "slot machine", "snorkel", "snowmobile", "snowplow", "soap dispenser", "soccer ball", "sock", "solar thermal collector", "sombrero", "soup bowl", "keyboard space bar", "space heater", "space shuttle", "spatula", "motorboat", "spider web", "spindle", "sports car", "spotlight", "stage", "steam locomotive", "through arch bridge", "steel drum", "stethoscope", "scarf", "stone wall", "stopwatch", "stove", "strainer", "tram", "stretcher", "couch", "stupa", "submarine", "suit", "sundial", "sunglasses", "dark glasses", "sunscreen", "suspension bridge", "mop", "sweatshirt", "swim trunks / shorts", "swing", "electrical switch", "syringe", "table lamp", "tank", "tape player", "teapot", "teddy bear", "television", "tennis ball", "thatched roof", "front curtain", "thimble", "threshing machine", "throne", "tile roof", "toaster", "tobacco shop", "toilet seat", "torch", "totem pole", "tow truck", "toy store", "tractor", "semi-trailer truck", "tray", "trench coat", "tricycle", "trimaran", "tripod", "triumphal arch", "trolleybus", "trombone", "hot tub", "turnstile", "typewriter keyboard", "umbrella", "unicycle", "upright piano", "vacuum cleaner", "vase", "vaulted or arched ceiling", "velvet fabric", "vending machine", "vestment", "viaduct", "violin", "volleyball", "waffle iron", "wall clock", "wallet", "wardrobe", "military aircraft", "sink", "washing machine", "water bottle", "water jug", "water tower", "whiskey jug", "whistle", "hair wig", "window screen", "window shade", "Windsor tie", "wine bottle", "airplane wing", "wok", "wooden spoon", "wool", "split-rail fence", "shipwreck", "sailboat", "yurt", "website", "comic book", "crossword", "traffic or street sign", "traffic light", "dust jacket", "menu", "plate", "guacamole", "consomme", "hot pot", "trifle", "ice cream", "popsicle", "baguette", "bagel", "pretzel", "cheeseburger", "hot dog", "mashed potatoes", 
"cabbage", "broccoli", "cauliflower", "zucchini", "spaghetti squash", "acorn squash", "butternut squash", "cucumber", "artichoke", "bell pepper", "cardoon", "mushroom", "Granny Smith apple", "strawberry", "orange", "lemon", "fig", "pineapple", "banana", "jackfruit", "cherimoya (custard apple)", "pomegranate", "hay", "carbonara", "chocolate syrup", "dough", "meatloaf", "pizza", "pot pie", "burrito", "red wine", "espresso", "tea cup", "eggnog", "mountain", "bubble", "cliff", "coral reef", "geyser", "lakeshore", "promontory", "sandbar", "beach", "valley", "volcano", "baseball player", "bridegroom", "scuba diver", "rapeseed", "daisy", "yellow lady's slipper", "corn", "acorn", "rose hip", "horse chestnut seed", "coral fungus", "agaric", "gyromitra", "stinkhorn mushroom", "earth star fungus", "hen of the woods mushroom", "bolete", "corn cob", "toilet paper"]
2 |
3 | IMAGENET_DEFAULT_TEMPLATES = [
4 | '{}.',
5 | 'a bad photo of a {}.',
6 | 'a photo of many {}.',
7 | 'a sculpture of a {}.',
8 | 'a photo of the hard to see {}.',
9 | 'a low resolution photo of the {}.',
10 | 'a rendering of a {}.',
11 | 'graffiti of a {}.',
12 | 'a bad photo of the {}.',
13 | 'a cropped photo of the {}.',
14 | 'a tattoo of a {}.',
15 | 'the embroidered {}.',
16 | 'a photo of a hard to see {}.',
17 | 'a bright photo of a {}.',
18 | 'a photo of a clean {}.',
19 | 'a photo of a dirty {}.',
20 | 'a dark photo of the {}.',
21 | 'a drawing of a {}.',
22 | 'a photo of my {}.',
23 | 'the plastic {}.',
24 | 'a photo of the cool {}.',
25 | 'a close-up photo of a {}.',
26 | 'a black and white photo of the {}.',
27 | 'a painting of the {}.',
28 | 'a painting of a {}.',
29 | 'a pixelated photo of the {}.',
30 | 'a sculpture of the {}.',
31 | 'a bright photo of the {}.',
32 | 'a cropped photo of a {}.',
33 | 'a plastic {}.',
34 | 'a photo of the dirty {}.',
35 | 'a jpeg corrupted photo of a {}.',
36 | 'a blurry photo of the {}.',
37 | 'a photo of the {}.',
38 | 'a good photo of the {}.',
39 | 'a rendering of the {}.',
40 | 'a {} in a video game.',
41 | 'a photo of one {}.',
42 | 'a doodle of a {}.',
43 | 'a close-up photo of the {}.',
44 | 'a photo of a {}.',
45 | 'the origami {}.',
46 | 'the {} in a video game.',
47 | 'a sketch of a {}.',
48 | 'a doodle of the {}.',
49 | 'a origami {}.',
50 | 'a low resolution photo of a {}.',
51 | 'the toy {}.',
52 | 'a rendition of the {}.',
53 | 'a photo of the clean {}.',
54 | 'a photo of a large {}.',
55 | 'a rendition of a {}.',
56 | 'a photo of a nice {}.',
57 | 'a photo of a weird {}.',
58 | 'a blurry photo of a {}.',
59 | 'a cartoon {}.',
60 | 'art of a {}.',
61 | 'a sketch of the {}.',
62 | 'a embroidered {}.',
63 | 'a pixelated photo of a {}.',
64 | 'itap of the {}.',
65 | 'a jpeg corrupted photo of the {}.',
66 | 'a good photo of a {}.',
67 | 'a plushie {}.',
68 | 'a photo of the nice {}.',
69 | 'a photo of the small {}.',
70 | 'a photo of the weird {}.',
71 | 'the cartoon {}.',
72 | 'art of the {}.',
73 | 'a drawing of the {}.',
74 | 'a photo of the large {}.',
75 | 'a black and white photo of a {}.',
76 | 'the plushie {}.',
77 | 'a dark photo of a {}.',
78 | 'itap of a {}.',
79 | 'graffiti of the {}.',
80 | 'a toy {}.',
81 | 'itap of my {}.',
82 | 'a photo of a cool {}.',
83 | 'a photo of a small {}.',
84 | 'a tattoo of the {}.',
85 | ]
86 |
--------------------------------------------------------------------------------
/data/imagenet/real_labels.py:
--------------------------------------------------------------------------------
1 | """ Real labels evaluator for ImageNet
2 | Paper: `Are we done with ImageNet?` - https://arxiv.org/abs/2006.07159
3 | Based on Numpy example at https://github.com/google-research/reassessed-imagenet
4 |
5 | Hacked together by / Copyright 2020 Ross Wightman
6 | """
7 | import os
8 | import json
9 | import numpy as np
10 |
11 |
12 | class RealLabelsImagenet:
13 |
14 | def __init__(self, filenames, real_json='real.json', topk=(1, 5)):
15 | with open(real_json) as real_labels:
16 | real_labels = json.load(real_labels)
17 | real_labels = {
18 | f'ILSVRC2012_val_{i + 1:08d}.JPEG': labels
19 | for i, labels in enumerate(real_labels)
20 | }
21 | self.real_labels = real_labels
22 | self.filenames = filenames
23 | assert len(self.filenames) == len(self.real_labels)
24 | self.topk = topk
25 | self.is_correct = {k: [] for k in topk}
26 | self.sample_idx = 0
27 |
28 | def add_result(self, output):
29 | maxk = max(self.topk)
30 | _, pred_batch = output.topk(maxk, 1, True, True)
31 | pred_batch = pred_batch.cpu().numpy()
32 | for pred in pred_batch:
33 | filename = self.filenames[self.sample_idx]
34 | filename = os.path.basename(filename)
35 | if self.real_labels[filename]:
36 | for k in self.topk:
37 | self.is_correct[k].append(
38 | any([p in self.real_labels[filename] for p in pred[:k]]))
39 | self.sample_idx += 1
40 |
41 | def get_accuracy(self, k=None):
42 | if k is None:
43 |             return {k: float(np.mean(self.is_correct[k])) * 100 for k in self.topk}
44 | else:
45 | return float(np.mean(self.is_correct[k])) * 100
46 |
--------------------------------------------------------------------------------
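The evaluator above accumulates per-image top-k hits against the multi-label ReaL annotations, skipping images whose label set is empty. A torch-free sketch of that accuracy logic (the filenames, labels, and logits below are made up for illustration):

```python
import numpy as np

# Hypothetical ReaL annotations: each image may have several valid labels;
# an empty list means the image is excluded from evaluation.
real_labels = {"img_0": [3, 7], "img_1": [2], "img_2": []}

# Fake logits for three images over four classes.
logits = np.array([
    [0.1, 0.2, 0.1, 0.9],    # top-1 prediction: class 3 -> correct for img_0
    [0.8, 0.1, 0.05, 0.05],  # top-1 prediction: class 0 -> wrong for img_1
    [0.3, 0.3, 0.2, 0.2],    # img_2 has no labels -> skipped
])

def real_topk_accuracy(logits, filenames, real_labels, k=1):
    correct = []
    for row, name in zip(logits, filenames):
        labels = real_labels[name]
        if not labels:
            continue  # excluded image, same as the `if self.real_labels[filename]` guard
        topk = np.argsort(row)[::-1][:k]
        correct.append(any(int(p) in labels for p in topk))
    return float(np.mean(correct)) * 100

print(real_topk_accuracy(logits, ["img_0", "img_1", "img_2"], real_labels))  # 50.0
```

Only img_0 and img_1 count; img_0 is a top-1 hit and img_1 a miss, hence 50%.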
/data/language/__init__.py:
--------------------------------------------------------------------------------
1 | from .prompt_engineering import prompt_engineering
--------------------------------------------------------------------------------
/data/language/prompt_engineering.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import random
3 |
4 |
5 | def get_prompt_templates():
6 | prompt_templates = [
7 | '{}.',
8 | 'a photo of a {}.',
9 | 'a bad photo of a {}.',
10 | 'a photo of many {}.',
11 | 'a sculpture of a {}.',
12 | 'a photo of the hard to see {}.',
13 | 'a low resolution photo of the {}.',
14 | 'a rendering of a {}.',
15 | 'graffiti of a {}.',
16 | 'a bad photo of the {}.',
17 | 'a cropped photo of the {}.',
18 | 'a tattoo of a {}.',
19 | 'the embroidered {}.',
20 | 'a photo of a hard to see {}.',
21 | 'a bright photo of a {}.',
22 | 'a photo of a clean {}.',
23 | 'a photo of a dirty {}.',
24 | 'a dark photo of the {}.',
25 | 'a drawing of a {}.',
26 | 'a photo of my {}.',
27 | 'the plastic {}.',
28 | 'a photo of the cool {}.',
29 | 'a close-up photo of a {}.',
30 | 'a black and white photo of the {}.',
31 | 'a painting of the {}.',
32 | 'a painting of a {}.',
33 | 'a pixelated photo of the {}.',
34 | 'a sculpture of the {}.',
35 | 'a bright photo of the {}.',
36 | 'a cropped photo of a {}.',
37 | 'a plastic {}.',
38 | 'a photo of the dirty {}.',
39 | 'a jpeg corrupted photo of a {}.',
40 | 'a blurry photo of the {}.',
41 | 'a photo of the {}.',
42 | 'a good photo of the {}.',
43 | 'a rendering of the {}.',
44 | 'a {} in a video game.',
45 | 'a photo of one {}.',
46 | 'a doodle of a {}.',
47 | 'a close-up photo of the {}.',
48 | 'the origami {}.',
49 | 'the {} in a video game.',
50 | 'a sketch of a {}.',
51 | 'a doodle of the {}.',
52 | 'a origami {}.',
53 | 'a low resolution photo of a {}.',
54 | 'the toy {}.',
55 | 'a rendition of the {}.',
56 | 'a photo of the clean {}.',
57 | 'a photo of a large {}.',
58 | 'a rendition of a {}.',
59 | 'a photo of a nice {}.',
60 | 'a photo of a weird {}.',
61 | 'a blurry photo of a {}.',
62 | 'a cartoon {}.',
63 | 'art of a {}.',
64 | 'a sketch of the {}.',
65 | 'a embroidered {}.',
66 | 'a pixelated photo of a {}.',
67 | 'itap of the {}.',
68 | 'a jpeg corrupted photo of the {}.',
69 | 'a good photo of a {}.',
70 | 'a plushie {}.',
71 | 'a photo of the nice {}.',
72 | 'a photo of the small {}.',
73 | 'a photo of the weird {}.',
74 | 'the cartoon {}.',
75 | 'art of the {}.',
76 | 'a drawing of the {}.',
77 | 'a photo of the large {}.',
78 | 'a black and white photo of a {}.',
79 | 'the plushie {}.',
80 | 'a dark photo of a {}.',
81 | 'itap of a {}.',
82 | 'graffiti of the {}.',
83 | 'a toy {}.',
84 | 'itap of my {}.',
85 | 'a photo of a cool {}.',
86 | 'a photo of a small {}.',
87 | 'a tattoo of the {}.',
88 | ]
89 | return prompt_templates
90 |
91 |
92 | def prompt_engineering(classnames):
93 | prompt_templates = get_prompt_templates()
94 | temp_idx = np.random.randint(len(prompt_templates))
95 |
96 | if isinstance(classnames, list):
97 | classname = random.choice(classnames)
98 | else:
99 | classname = classnames
100 |
101 | return prompt_templates[temp_idx].format(classname.replace(',', '').replace('+', ' '))
102 |
--------------------------------------------------------------------------------
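`prompt_engineering` samples one template at random and fills it with a cleaned class name (commas dropped, `'+'` turned into spaces). The cleaning and fill step in isolation, with a made-up class name and two templates taken from the list above:

```python
import random

# Two templates from the list above; the classname is a hypothetical example
# in the '+'-joined form the cleaning step is designed to undo.
templates = ['a photo of a {}.', 'a drawing of a {}.']
classname = 'tiger+cat'

# Same cleaning as prompt_engineering: drop commas, turn '+' into spaces.
cleaned = classname.replace(',', '').replace('+', ' ')
prompt = random.choice(templates).format(cleaned)
print(prompt)
```

Because the template is chosen at random, only the filled class name is deterministic; every template ends with `'{}.'`, so the prompt always ends with the cleaned name.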
/data/samplers.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Swin Transformer
3 | # Copyright (c) 2021 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Ze Liu
6 | # --------------------------------------------------------
7 |
8 | import torch
9 |
10 |
11 | class SubsetRandomSampler(torch.utils.data.Sampler):
12 | r"""Samples elements randomly from a given list of indices, without replacement.
13 |
14 | Arguments:
15 | indices (sequence): a sequence of indices
16 | """
17 |
18 | def __init__(self, indices):
19 | self.epoch = 0
20 | self.indices = indices
21 |
22 | def __iter__(self):
23 | return (self.indices[i] for i in torch.randperm(len(self.indices)))
24 |
25 | def __len__(self):
26 | return len(self.indices)
27 |
28 | def set_epoch(self, epoch):
29 | self.epoch = epoch
30 |
--------------------------------------------------------------------------------
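`SubsetRandomSampler` yields the given indices in a fresh random permutation on every pass, without replacement. A torch-free sketch of the same behavior (the index values are arbitrary examples):

```python
import random

# The subset to sample from (e.g. the indices a data shard owns).
indices = [10, 20, 30, 40]

def subset_random_iter(indices):
    # A new permutation each call: every index appears exactly once.
    perm = list(range(len(indices)))
    random.shuffle(perm)
    return (indices[i] for i in perm)

epoch_order = list(subset_random_iter(indices))
print(sorted(epoch_order))  # [10, 20, 30, 40] -- same elements, shuffled order
```

The order differs between calls, but sorting always recovers the original subset.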
/data/zipdata.py:
--------------------------------------------------------------------------------
1 | import os.path as op
2 | from zipfile import ZipFile, BadZipFile
3 | import torch.utils.data as data
4 | from PIL import Image
5 | from io import BytesIO
6 | import multiprocessing
7 | from .imagenet import IMAGENET_CLASSES
8 | from .language import prompt_engineering
9 |
10 | _VALID_IMAGE_TYPES = ['.jpg', '.jpeg', '.tiff', '.bmp', '.png']
11 |
12 | class ZipData(data.Dataset):
13 | _IGNORE_ATTRS = {'_zip_file'}
14 |
15 | def __init__(self, path, map_file, tokenizer,
16 | transform=None, target_transform=None,
17 | extensions=None):
18 | self._path = path
19 | if not extensions:
20 | extensions = _VALID_IMAGE_TYPES
21 | self._zip_file = ZipFile(path)
22 | self.zip_dict = {}
23 | self.samples = []
24 | assert tokenizer is not None
25 | self.tokenizer = tokenizer
26 | self.transform = transform
27 | self.target_transform = target_transform
28 |
29 | self.class_to_idx = {}
30 | with open(map_file, 'r') as f:
31 | for line in iter(f.readline, ""):
32 | line = line.strip()
33 | if not line:
34 | continue
35 | cls_idx = [l for l in line.split('\t') if l]
36 | if not cls_idx:
37 | continue
38 | assert len(cls_idx) >= 2, "invalid line: {}".format(line)
39 | idx = int(cls_idx[1])
40 | cls = cls_idx[0]
41 | del cls_idx
42 | at_idx = cls.find('@')
43 | assert at_idx >= 0, "invalid class: {}".format(cls)
44 | cls = cls[at_idx + 1:]
45 | if cls.startswith('/'):
46 | # Python ZipFile expects no root
47 | cls = cls[1:]
48 | assert cls, "invalid class in line {}".format(line)
49 | prev_idx = self.class_to_idx.get(cls)
50 | assert prev_idx is None or prev_idx == idx, "class: {} idx: {} previously had idx: {}".format(
51 | cls, idx, prev_idx
52 | )
53 | self.class_to_idx[cls] = idx
54 |
55 | for fst in self._zip_file.infolist():
56 | fname = fst.filename
57 | target = self.class_to_idx.get(fname)
58 | if target is None:
59 | continue
60 | if fname.endswith('/') or fname.startswith('.') or fst.file_size == 0:
61 | continue
62 | ext = op.splitext(fname)[1].lower()
63 | if ext in extensions:
64 | self.samples.append((fname, target))
65 | assert len(self), "No images found in: {} with map: {}".format(self._path, map_file)
66 |
67 | def __repr__(self):
68 | return 'ZipData({}, size={})'.format(self._path, len(self))
69 |
70 | def __getstate__(self):
71 | return {
72 | key: val if key not in self._IGNORE_ATTRS else None
73 |             for key, val in self.__dict__.items()  # Python 3: dict has no iteritems()
74 | }
75 |
76 | def __getitem__(self, index):
77 | proc = multiprocessing.current_process()
78 | pid = proc.pid # get pid of this process.
79 | if pid not in self.zip_dict:
80 | self.zip_dict[pid] = ZipFile(self._path)
81 | zip_file = self.zip_dict[pid]
82 |
83 | if index >= len(self) or index < 0:
84 | raise KeyError("{} is invalid".format(index))
85 | path, target = self.samples[index]
86 | try:
87 | sample = Image.open(BytesIO(zip_file.read(path))).convert('RGB')
88 | except BadZipFile:
89 |             print("bad zip file: {}".format(self._path))
90 | return None, None
91 | if self.transform is not None:
92 | sample = self.transform(sample)
93 |
94 | # convert label to text with prompt
95 | text = self._decode_text_from_label(target)
96 | tokens = self.tokenizer(
97 | text, padding='max_length',
98 | truncation=True,
99 | max_length=77,
100 | return_tensors='pt'
101 | ) if self.tokenizer else text
102 | tokens['input_ids'].squeeze_()
103 | tokens['attention_mask'].squeeze_()
104 |
105 | if self.target_transform is not None:
106 | target = self.target_transform(target)
107 | return sample, tokens, target
108 |
109 | def __len__(self):
110 | return len(self.samples)
111 |
112 | def _decode_text_from_label(self, label):
113 | concept = IMAGENET_CLASSES[int(label)]
114 | text = prompt_engineering(concept)
115 | return text
116 |
117 |
--------------------------------------------------------------------------------
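`ZipData`'s map file pairs an `'@'`-marked in-archive path with an integer class index, tab-separated. A sketch of how one line is parsed, mirroring the `__init__` loop above (the line content is hypothetical):

```python
# Hypothetical map-file line: "<prefix>@<path inside the zip>\t<class index>"
line = "n01440764@/train/n01440764\t0\n"

fields = [f for f in line.strip().split('\t') if f]
cls, idx = fields[0], int(fields[1])
# Keep everything after '@'; ZipFile member names carry no leading '/'.
cls = cls[cls.find('@') + 1:].lstrip('/')
print(cls, idx)  # train/n01440764 0
```

The stripped leading slash matters: `ZipFile.read` looks up member names without a root `/`, so `class_to_idx` keys must match `infolist()` filenames exactly.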
/data/zipreader.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Swin Transformer
3 | # Copyright (c) 2021 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Ze Liu
6 | # --------------------------------------------------------
7 |
8 | import os
9 | import zipfile
10 | import io
11 | import numpy as np
12 | from PIL import Image
13 | from PIL import ImageFile
14 |
15 | ImageFile.LOAD_TRUNCATED_IMAGES = True
16 |
17 |
18 | def is_zip_path(img_or_path):
19 |     """Return True if the path refers to a member inside a zip archive."""
20 | return '.zip@' in img_or_path
21 |
22 |
23 | class ZipReader(object):
24 | """A class to read zipped files"""
25 | zip_bank = dict()
26 |
27 | def __init__(self):
28 | super(ZipReader, self).__init__()
29 |
30 | @staticmethod
31 | def get_zipfile(path):
32 | zip_bank = ZipReader.zip_bank
33 | if path not in zip_bank:
34 | zfile = zipfile.ZipFile(path, 'r')
35 | zip_bank[path] = zfile
36 | return zip_bank[path]
37 |
38 | @staticmethod
39 | def split_zip_style_path(path):
40 |         pos_at = path.find('@')
41 | assert pos_at != -1, "character '@' is not found from the given path '%s'" % path
42 |
43 | zip_path = path[0: pos_at]
44 | folder_path = path[pos_at + 1:]
45 | folder_path = str.strip(folder_path, '/')
46 | return zip_path, folder_path
47 |
48 | @staticmethod
49 | def list_folder(path):
50 | zip_path, folder_path = ZipReader.split_zip_style_path(path)
51 |
52 | zfile = ZipReader.get_zipfile(zip_path)
53 | folder_list = []
54 |         for file_folder_name in zfile.namelist():
55 |             file_folder_name = file_folder_name.strip('/')
56 |             if file_folder_name.startswith(folder_path) and \
57 |                 len(os.path.splitext(file_folder_name)[-1]) == 0 and \
58 |                 file_folder_name != folder_path:
59 |                 if len(folder_path) == 0:
60 |                     folder_list.append(file_folder_name)
61 |                 else:
62 |                     folder_list.append(file_folder_name[len(folder_path) + 1:])
63 |
64 | return folder_list
65 |
66 | @staticmethod
67 | def list_files(path, extension=None):
68 | if extension is None:
69 | extension = ['.*']
70 | zip_path, folder_path = ZipReader.split_zip_style_path(path)
71 |
72 | zfile = ZipReader.get_zipfile(zip_path)
73 | file_lists = []
74 |         for file_folder_name in zfile.namelist():
75 |             file_folder_name = file_folder_name.strip('/')
76 |             if file_folder_name.startswith(folder_path) and \
77 |                 os.path.splitext(file_folder_name)[-1].lower() in extension:
78 |                 if len(folder_path) == 0:
79 |                     file_lists.append(file_folder_name)
80 |                 else:
81 |                     file_lists.append(file_folder_name[len(folder_path) + 1:])
82 |
83 | return file_lists
84 |
85 | @staticmethod
86 | def read(path):
87 | zip_path, path_img = ZipReader.split_zip_style_path(path)
88 | zfile = ZipReader.get_zipfile(zip_path)
89 | data = zfile.read(path_img)
90 | return data
91 |
92 | @staticmethod
93 | def imread(path):
94 | zip_path, path_img = ZipReader.split_zip_style_path(path)
95 | zfile = ZipReader.get_zipfile(zip_path)
96 | data = zfile.read(path_img)
97 | try:
98 | im = Image.open(io.BytesIO(data))
99 |         except Exception:
100 | print("ERROR IMG LOADED: ", path_img)
101 | random_img = np.random.rand(224, 224, 3) * 255
102 | im = Image.fromarray(np.uint8(random_img))
103 | return im
104 |
--------------------------------------------------------------------------------
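The `'.zip@'` convention packs an archive path and a member path into one string. Reproducing `split_zip_style_path` on a made-up path:

```python
# A made-up zip-style path: archive on the left of '@', member on the right.
path = "/data/imagenet/train.zip@/n01440764/img_001.jpg"

pos_at = path.find('@')
assert pos_at != -1, "character '@' not found in path"
zip_path = path[:pos_at]
member = path[pos_at + 1:].strip('/')   # ZipFile members have no leading '/'
print(zip_path)  # /data/imagenet/train.zip
print(member)    # n01440764/img_001.jpg
```

`ZipReader.get_zipfile` then caches one open `ZipFile` per archive path, so repeated reads from the same archive reuse the handle.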
/figures/unicl_comparison.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/UniCL/9983bf89660f16dc923530a14d56a83b4699dcbc/figures/unicl_comparison.png
--------------------------------------------------------------------------------
/figures/unified_cv.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/UniCL/9983bf89660f16dc923530a14d56a83b4699dcbc/figures/unified_cv.png
--------------------------------------------------------------------------------
/logger.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Swin Transformer
3 | # Copyright (c) 2021 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Ze Liu
6 | # --------------------------------------------------------
7 |
8 | import os
9 | import sys
10 | import logging
11 | import functools
12 | from termcolor import colored
13 |
14 |
15 | @functools.lru_cache()
16 | def create_logger(output_dir, dist_rank=0, name=''):
17 | # create logger
18 | logger = logging.getLogger(name)
19 | logger.setLevel(logging.DEBUG)
20 | logger.propagate = False
21 |
22 | # create formatter
23 | fmt = '[%(asctime)s %(name)s] (%(filename)s %(lineno)d): %(levelname)s %(message)s'
24 | color_fmt = colored('[%(asctime)s %(name)s]', 'green') + \
25 | colored('(%(filename)s %(lineno)d)', 'yellow') + ': %(levelname)s %(message)s'
26 |
27 | # create console handlers for master process
28 | if dist_rank == 0:
29 | console_handler = logging.StreamHandler(sys.stdout)
30 | console_handler.setLevel(logging.DEBUG)
31 | console_handler.setFormatter(
32 | logging.Formatter(fmt=color_fmt, datefmt='%Y-%m-%d %H:%M:%S'))
33 | logger.addHandler(console_handler)
34 |
35 | # create file handlers
36 | file_handler = logging.FileHandler(os.path.join(output_dir, f'log_rank{dist_rank}.txt'), mode='a')
37 | file_handler.setLevel(logging.DEBUG)
38 | file_handler.setFormatter(logging.Formatter(fmt=fmt, datefmt='%Y-%m-%d %H:%M:%S'))
39 | logger.addHandler(file_handler)
40 |
41 | return logger
42 |
--------------------------------------------------------------------------------
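The pattern in `create_logger` is: every rank writes its own log file, but only rank 0 attaches a console handler, so a multi-GPU run prints each message once. A stripped-down sketch of that rank gating (logger names here are illustrative, and the file handler is omitted):

```python
import logging
import sys

def make_rank_logger(rank):
    # One logger per rank; only rank 0 echoes to the console.
    logger = logging.getLogger(f'demo_rank{rank}')
    logger.setLevel(logging.DEBUG)
    logger.propagate = False
    if rank == 0:
        handler = logging.StreamHandler(sys.stdout)
        handler.setFormatter(logging.Formatter('[%(name)s] %(levelname)s %(message)s'))
        logger.addHandler(handler)
    return logger

rank0 = make_rank_logger(0)
rank1 = make_rank_logger(1)
print(len(rank0.handlers), len(rank1.handlers))  # 1 0
```

Note `@functools.lru_cache()` on the real `create_logger`: it guards against calling the factory twice and stacking duplicate handlers on the same logger.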
/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Swin Transformer
3 | # Copyright (c) 2021 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Ze Liu
6 | # --------------------------------------------------------
7 |
8 | import torch
9 | from timm.scheduler.cosine_lr import CosineLRScheduler
10 | from timm.scheduler.step_lr import StepLRScheduler
11 | from timm.scheduler.scheduler import Scheduler
12 |
13 |
14 | def build_scheduler(config, optimizer, n_iter_per_epoch):
15 | num_steps = int(config.TRAIN.EPOCHS * n_iter_per_epoch)
16 | warmup_steps = int(config.TRAIN.WARMUP_EPOCHS * n_iter_per_epoch)
17 | decay_steps = int(config.TRAIN.LR_SCHEDULER.DECAY_EPOCHS * n_iter_per_epoch)
18 |
19 | lr_scheduler = None
20 | if config.TRAIN.LR_SCHEDULER.NAME == 'cosine':
21 | lr_scheduler = CosineLRScheduler(
22 | optimizer,
23 | t_initial=num_steps,
24 | t_mul=1.,
25 | lr_min=config.TRAIN.MIN_LR,
26 | warmup_lr_init=config.TRAIN.WARMUP_LR,
27 | warmup_t=warmup_steps,
28 | cycle_limit=1,
29 | t_in_epochs=False,
30 | )
31 | elif config.TRAIN.LR_SCHEDULER.NAME == 'linear':
32 | lr_scheduler = LinearLRScheduler(
33 | optimizer,
34 | t_initial=num_steps,
35 | lr_min_rate=0.01,
36 | warmup_lr_init=config.TRAIN.WARMUP_LR,
37 | warmup_t=warmup_steps,
38 | t_in_epochs=False,
39 | )
40 | elif config.TRAIN.LR_SCHEDULER.NAME == 'step':
41 | lr_scheduler = StepLRScheduler(
42 | optimizer,
43 | decay_t=decay_steps,
44 | decay_rate=config.TRAIN.LR_SCHEDULER.DECAY_RATE,
45 | warmup_lr_init=config.TRAIN.WARMUP_LR,
46 | warmup_t=warmup_steps,
47 | t_in_epochs=False,
48 | )
49 |
50 | return lr_scheduler
51 |
52 |
53 | class LinearLRScheduler(Scheduler):
54 | def __init__(self,
55 | optimizer: torch.optim.Optimizer,
56 | t_initial: int,
57 | lr_min_rate: float,
58 | warmup_t=0,
59 | warmup_lr_init=0.,
60 | t_in_epochs=True,
61 | noise_range_t=None,
62 | noise_pct=0.67,
63 | noise_std=1.0,
64 | noise_seed=42,
65 | initialize=True,
66 | ) -> None:
67 | super().__init__(
68 | optimizer, param_group_field="lr",
69 | noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
70 | initialize=initialize)
71 |
72 | self.t_initial = t_initial
73 | self.lr_min_rate = lr_min_rate
74 | self.warmup_t = warmup_t
75 | self.warmup_lr_init = warmup_lr_init
76 | self.t_in_epochs = t_in_epochs
77 | if self.warmup_t:
78 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
79 | super().update_groups(self.warmup_lr_init)
80 | else:
81 | self.warmup_steps = [1 for _ in self.base_values]
82 |
83 | def _get_lr(self, t):
84 | if t < self.warmup_t:
85 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
86 | else:
87 | t = t - self.warmup_t
88 | total_t = self.t_initial - self.warmup_t
89 | lrs = [v - ((v - v * self.lr_min_rate) * (t / total_t)) for v in self.base_values]
90 | return lrs
91 |
92 | def get_epoch_values(self, epoch: int):
93 | if self.t_in_epochs:
94 | return self._get_lr(epoch)
95 | else:
96 | return None
97 |
98 | def get_update_values(self, num_updates: int):
99 | if not self.t_in_epochs:
100 | return self._get_lr(num_updates)
101 | else:
102 | return None
103 |
--------------------------------------------------------------------------------
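`LinearLRScheduler` ramps linearly from `warmup_lr_init` to the base LR over `warmup_t` steps, then decays linearly to `base_lr * lr_min_rate` at `t_initial`. The same schedule as a plain function, with illustrative hyperparameters:

```python
def linear_lr(t, base_lr=1e-3, t_initial=100, warmup_t=10,
              warmup_lr_init=1e-6, lr_min_rate=0.01):
    if t < warmup_t:
        # Warmup: linear ramp from warmup_lr_init up to base_lr.
        step = (base_lr - warmup_lr_init) / warmup_t
        return warmup_lr_init + t * step
    # Decay: linear descent from base_lr down to base_lr * lr_min_rate.
    t, total = t - warmup_t, t_initial - warmup_t
    return base_lr - (base_lr - base_lr * lr_min_rate) * (t / total)

print(linear_lr(0))    # warmup start: 1e-06
print(linear_lr(10))   # warmup done: 0.001 (the base lr)
print(linear_lr(100))  # end of schedule: ~1e-05, i.e. base_lr * lr_min_rate
```

With `t_in_epochs=False`, as in `build_scheduler`, `t` counts optimizer updates rather than epochs, which is why the builder multiplies epoch counts by `n_iter_per_epoch`.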
/main.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Unified Contrastive Learning (UniCL) in Image-Text-Label Space
3 | # Copyright (c) 2022 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Jianwei Yang (jianwyan@microsoft.com)
6 | # --------------------------------------------------------
7 |
8 | import os
9 | import time
10 | import argparse
11 | import datetime
12 | import numpy as np
13 |
14 | import torch
15 | import torch.backends.cudnn as cudnn
16 | import torch.distributed as dist
17 |
18 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
19 | from timm.utils import accuracy, AverageMeter
20 |
21 | from config import get_config
22 | from model import build_model
23 | from data import build_loader
24 | from lr_scheduler import build_scheduler
25 | from optimizer import build_optimizer
26 | from logger import create_logger
27 | from utils import load_checkpoint, save_checkpoint, get_grad_norm, auto_resume_helper, reduce_tensor
28 |
29 | try:
30 | # noinspection PyUnresolvedReferences
31 | from apex import amp
32 | except ImportError:
33 | amp = None
34 |
35 |
36 | def parse_option():
37 | parser = argparse.ArgumentParser('UniCL training and evaluation script', add_help=False)
38 | parser.add_argument('--cfg', type=str, required=True, metavar="FILE", help='path to config file', )
39 | parser.add_argument(
40 | "--opts",
41 | help="Modify config options by adding 'KEY VALUE' pairs. ",
42 | default=None,
43 | nargs='+',
44 | )
45 |
46 | # easy config modification
47 | parser.add_argument('--batch-size', type=int, help="batch size for single GPU")
48 | parser.add_argument('--dataset', type=str, default='imagenet', help='dataset name')
49 | parser.add_argument('--data-path', type=str, help='path to dataset')
50 | parser.add_argument('--zip', action='store_true', help='use zipped dataset instead of folder dataset')
51 | parser.add_argument('--cache-mode', type=str, default='part', choices=['no', 'full', 'part'],
52 | help='no: no cache, '
53 | 'full: cache all data, '
54 |                         'part: shard the dataset into non-overlapping pieces and cache only one piece')
55 | parser.add_argument('--resume', help='resume from checkpoint')
56 | parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps")
57 | parser.add_argument('--use-checkpoint', action='store_true',
58 | help="whether to use gradient checkpointing to save memory")
59 | parser.add_argument('--amp-opt-level', type=str, default='O1', choices=['O0', 'O1', 'O2'],
60 | help='mixed precision opt level, if O0, no amp is used')
61 | parser.add_argument('--output', default='output', type=str, metavar='PATH',
62 | help='root of output folder, the full path is