├── .gitignore ├── CODE_OF_CONDUCT.md ├── DATA.md ├── INSTALL.md ├── LICENSE ├── README.md ├── SECURITY.md ├── SUPPORT.md ├── config.py ├── configs ├── unicl_swin_base.yaml └── unicl_swin_tiny.yaml ├── data ├── __init__.py ├── build.py ├── cached_image_folder.py ├── imagenet │ ├── __init__.py │ ├── constants.py │ └── real_labels.py ├── language │ ├── __init__.py │ └── prompt_engineering.py ├── samplers.py ├── zipdata.py └── zipreader.py ├── figures ├── unicl_comparison.png └── unified_cv.png ├── logger.py ├── lr_scheduler.py ├── main.py ├── model ├── __init__.py ├── image_encoder │ ├── __init__.py │ ├── build.py │ ├── focalnet.py │ └── swin_transformer.py ├── model.py └── text_encoder │ ├── __init__.py │ ├── build.py │ ├── hf_model.py │ ├── registry.py │ └── transformer.py ├── optimizer.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Microsoft Open Source Code of Conduct 2 | 3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 4 | 5 | Resources: 6 | 7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) 8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns 10 | -------------------------------------------------------------------------------- /DATA.md: -------------------------------------------------------------------------------- 1 | ## ImageNet-1K Preparation 2 | 3 | We use standard ImageNet dataset, you can download it from http://image-net.org/. We provide the following two ways to 4 | load data: 5 | 6 | - For standard folder dataset, move validation images to labeled sub-folders. The file structure should look like: 7 | ```bash 8 | $ tree data 9 | imagenet 10 | ├── train 11 | │ ├── class1 12 | │ │ ├── img1.jpeg 13 | │ │ ├── img2.jpeg 14 | │ │ └── ... 15 | │ ├── class2 16 | │ │ ├── img3.jpeg 17 | │ │ └── ... 18 | │ └── ... 19 | └── val 20 | ├── class1 21 | │ ├── img4.jpeg 22 | │ ├── img5.jpeg 23 | │ └── ... 24 | ├── class2 25 | │ ├── img6.jpeg 26 | │ └── ... 27 | └── ... 28 | 29 | ``` 30 | - To boost the slow speed when reading images from massive small files, we also support zipped ImageNet, which includes 31 | four files: 32 | - `train.zip`, `val.zip`: which store the zipped folder for train and validate splits. 33 | - `train_map.txt`, `val_map.txt`: which store the relative path in the corresponding zip file and ground truth 34 | label. 
Make sure the data folder looks like this: 35 | 36 | ```bash 37 | $ tree data 38 | data 39 | └── ImageNet-Zip 40 | ├── train_map.txt 41 | ├── train.zip 42 | ├── val_map.txt 43 | └── val.zip 44 | 45 | $ head -n 5 data/ImageNet-Zip/val_map.txt 46 | ILSVRC2012_val_00000001.JPEG 65 47 | ILSVRC2012_val_00000002.JPEG 970 48 | ILSVRC2012_val_00000003.JPEG 230 49 | ILSVRC2012_val_00000004.JPEG 809 50 | ILSVRC2012_val_00000005.JPEG 516 51 | 52 | $ head -n 5 data/ImageNet-Zip/train_map.txt 53 | n01440764/n01440764_10026.JPEG 0 54 | n01440764/n01440764_10027.JPEG 0 55 | n01440764/n01440764_10029.JPEG 0 56 | n01440764/n01440764_10040.JPEG 0 57 | n01440764/n01440764_10042.JPEG 0 58 | ``` -------------------------------------------------------------------------------- /INSTALL.md: -------------------------------------------------------------------------------- 1 | 2 | ## Installation Instructions 3 | 4 | - Clone this repo: 5 | 6 | ```bash 7 | git clone https://github.com/microsoft/UniCL 8 | cd UniCL 9 | ``` 10 | 11 | - Create a conda virtual environment and activate it (optional): 12 | 13 | ```bash 14 | conda create -n unicl python=3.7 -y 15 | conda activate unicl 16 | ``` 17 | 18 | - Install `CUDA==10.1` with `cudnn7` following 19 | the [official installation instructions](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html) 20 | - Install `PyTorch==1.7.1` and `torchvision==0.8.2` with `CUDA==10.1`: 21 | 22 | ```bash 23 | conda install pytorch==1.7.1 torchvision==0.8.2 cudatoolkit=10.1 -c pytorch 24 | ``` 25 | 26 | - Install `timm==0.3.2`: 27 | 28 | ```bash 29 | pip install timm==0.3.2 30 | ``` 31 | 32 | - Install `Apex`: 33 | 34 | ```bash 35 | git clone https://github.com/NVIDIA/apex 36 | cd apex 37 | pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./ 38 | ``` 39 | 40 | - Install other requirements: 41 | 42 | ```bash 43 | pip install opencv-python==4.4.0.46 termcolor==1.1.0 yacs==0.1.8 44 | ``` -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [Unified Contrastive Learning in Image-Text-Label Space](https://arxiv.org/abs/2204.03610) 2 | 3 | This is the official PyTorch implementation of UniCL: 4 | 5 | ["**Unified Contrastive Learning in Image-Text-Label Space. CVPR 2022**"](https://arxiv.org/abs/2204.03610) by [Jianwei Yang*](https://jwyang.github.io/), [Chunyuan Li*](https://chunyuan.li/), [Pengchuan Zhang*](https://pzzhang.github.io/pzzhang/), [Bin Xiao*](https://www.microsoft.com/en-us/research/people/bixi/), [Ce Liu](http://people.csail.mit.edu/celiu/), [Lu Yuan](https://scholar.google.com/citations?user=k9TsUVsAAAAJ&hl=en) and [Jianfeng Gao](https://www.microsoft.com/en-us/research/people/jfgao/?from=http%3A%2F%2Fresearch.microsoft.com%2Fen-us%2Fum%2Fpeople%2Fjfgao%2F). 6 | 7 | ## Introduction 8 | 9 |
![Unified contrastive learning in the image-text-label space](figures/unified_cv.png) 10 | 11 | 
12 | 13 | In this paper, we introduce a new perspective on commonly used image-label and image-text data by placing them in a common image-text-label space. In this space, a new learning paradigm, called **Unified Contrastive Learning (UniCL)**, is proposed with a single learning objective to seamlessly promote the synergy between the two data types. We demonstrate that UniCL is an effective way of learning **semantically rich yet discriminative representations**, universally for image recognition in zero-shot, linear-probe, full-finetuning and transfer-learning scenarios. When scaled up to billions of samples, UniCL can exclusively learn a powerful visual-semantic representation supporting dozens of downstream tasks, as shown in [Florence](https://arxiv.org/pdf/2111.11432v1.pdf). 14 | 15 | We compare UniCL with conventional learning methods below: 16 | 17 | 
![Comparison between UniCL and conventional learning methods](figures/unicl_comparison.png) 18 | 19 | 
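For readers who want to see the objective concretely, here is a rough, self-contained sketch of the unified contrastive loss described in the paper (not the repo's exact implementation, which lives in `model/model.py`): image and text features are contrasted across the batch as in CLIP, but the targets are all image-text pairs that share a label rather than only the diagonal.

```python
import torch
import torch.nn.functional as F

def unicl_loss_sketch(image_feat, text_feat, labels, logit_scale=100.0):
    """Rough sketch of bidirectional contrastive learning in image-text-label space.

    image_feat, text_feat: (B, D) L2-normalized features; labels: (B,) int64 class ids.
    Every image-text pair sharing a label counts as a positive, so image-label and
    image-text data can be mixed in one batch (image-text pairs get unique labels).
    """
    logits = logit_scale * image_feat @ text_feat.t()                 # (B, B) similarities
    positives = (labels.unsqueeze(0) == labels.unsqueeze(1)).float()  # (B, B) label-match mask
    targets = positives / positives.sum(dim=1, keepdim=True)          # normalize per anchor
    loss_i2t = (-targets * F.log_softmax(logits, dim=1)).sum(dim=1).mean()
    loss_t2i = (-targets * F.log_softmax(logits.t(), dim=1)).sum(dim=1).mean()
    return loss_i2t + loss_t2i

# Toy usage: 4 samples, two of which share class 0.
img = F.normalize(torch.randn(4, 512), dim=-1)
txt = F.normalize(torch.randn(4, 512), dim=-1)
print(unicl_loss_sketch(img, txt, torch.tensor([0, 1, 0, 2])))
```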
20 | 21 | ## :collision: All previous links are broken. Please find all checkpoints here: [https://github.com/microsoft/UniCL/releases/tag/v1.0](https://github.com/microsoft/UniCL/releases/tag/v1.0) 22 | ## Updates 23 | * [11/24/2022] [KLITE](https://arxiv.org/abs/2204.09222), the knowledge-augmented version of UniCL, is publicly released on [Github](https://github.com/microsoft/klite). 24 | * :collision: [10/05/2022] How do we use the pretrained UniCL checkpoints? Beyond the zero-shot classification shown in our paper, we can use them for object detection. [RegionCLIP](https://github.com/microsoft/RegionCLIP/tree/zero-shot) now supports using pretrained UniCL transformer models, such as Swin and ViT, for open-vocabulary object detection without any finetuning. Check it out! 25 | * [08/19/2022] Organizing the ECCV Workshop [*Computer Vision in the Wild (CVinW)*](https://computer-vision-in-the-wild.github.io/eccv-2022/), where two challenges are hosted to evaluate the zero-shot, few-shot and full-shot performance of pre-trained vision models on downstream tasks: 26 | - [``*Image Classification in the Wild (ICinW)*''](https://eval.ai/web/challenges/challenge-page/1832/overview) Challenge evaluates on 20 image classification tasks. 27 | - [``*Object Detection in the Wild (ODinW)*''](https://eval.ai/web/challenges/challenge-page/1839/overview) Challenge evaluates on 35 object detection tasks. 28 | 29 | $\qquad$ [ [Workshop]](https://computer-vision-in-the-wild.github.io/eccv-2022/) $\qquad$ [ [IC Challenge] ](https://eval.ai/web/challenges/challenge-page/1832/overview) 30 | $\qquad$ [ [OD Challenge] ](https://eval.ai/web/challenges/challenge-page/1839/overview) 31 | 32 | 33 | 34 | * [06/19/2022] Released the evaluation benchmark used in UniCL, [ELEVATER](https://arxiv.org/abs/2204.08790), which contains 20 downstream image classification tasks. More info: [[Benchmark]](https://computer-vision-in-the-wild.github.io/ELEVATER/) [[Toolkit]](https://github.com/Computer-Vision-in-the-Wild/Elevater_Toolkit_IC) [[Paper]](https://arxiv.org/abs/2204.08790) 35 | * [06/04/2022] Check out our [Huggingface Gradio demo](https://huggingface.co/spaces/CVPR/unicl-zero-shot-img-recog). 36 | * [05/21/2022] Released pretrained models and zero-shot evaluation on ImageNet-1K. 37 | 38 | ## Benchmarking 39 | 40 | ### Image-label training augmented by image-text pairs 41 | 42 | | Model | Training Set | Top-1 on IN-1K | ZS on 14 datasets | Download 43 | | :----: | :---: | :---: | :---: | :---: | 44 | | Swin-T | IN-1K | 79.9 | 30.2 | [ckpt](https://projects4jw.blob.core.windows.net/unicl/release/in1k.pth)/[config](configs/unicl_swin_tiny.yaml) 45 | | Swin-T | IN-1K + GCC-3M | 80.2 | 39.0 | [ckpt](https://projects4jw.blob.core.windows.net/unicl/release/in1k_gcc3m.pth)/[config](configs/unicl_swin_tiny.yaml) 46 | | Swin-T | IN-1K + YFCC-14M | 81.1 | 40.0 | [ckpt](https://projects4jw.blob.core.windows.net/unicl/release/in1k_yfcc14m.pth)/[config](configs/unicl_swin_tiny.yaml) 47 | | Swin-T | IN-1K + GCC-15M | 81.8 | 45.1 | [ckpt](https://projects4jw.blob.core.windows.net/unicl/release/in1k_gcc15m.pth)/[config](configs/unicl_swin_tiny.yaml) 48 | 49 | Note that all the above models are trained **without** strong data augmentations like mixup and cutmix.
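The checkpoint links above point to ordinary PyTorch files that `main.py` loads via `--resume`. As a quick sanity check after downloading one, a minimal sketch (the local filename is hypothetical, and the `'model'` key is an assumption based on the usual Swin-style checkpoint layout):

```python
import torch

# Hypothetical local path; download one of the released checkpoints first.
ckpt = torch.load("in1k_yfcc14m.pth", map_location="cpu")
print(sorted(ckpt.keys()) if isinstance(ckpt, dict) else type(ckpt))

# If the file uses the common {'model': state_dict, ...} layout, report its size.
if isinstance(ckpt, dict) and "model" in ckpt:
    n_params = sum(p.numel() for p in ckpt["model"].values())
    print(f"{n_params / 1e6:.1f}M parameters in the saved state dict")
```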
50 | 51 | ### Image-text learning augmented by image-label data 52 | 53 | | Model | Training Set | ZS on IN-1K | ZS on 14 datasets | ZS on 20 datasets | Download 54 | | :----: | :---: | :---: | :---: | :---: | :---: | 55 | | Swin-T | YFCC-14M | 30.1 | 36.3 | - | [ckpt](https://projects4jw.blob.core.windows.net/unicl/release/yfcc14m.pth)/[config](configs/unicl_swin_tiny.yaml) 56 | | Swin-T | IN-21K | 28.5 | 37.8 | - | [ckpt](https://projects4jw.blob.core.windows.net/unicl/release/in21k.pth)/[config](configs/unicl_swin_tiny.yaml) 57 | | Swin-T | IN-22K | 66.8 | 38.9 | - | [ckpt](https://projects4jw.blob.core.windows.net/unicl/release/in22k.pth)/[config](configs/unicl_swin_tiny.yaml) 58 | | Swin-T | IN-21K (half) + YFCC-14M (half) | 36.4 | 45.5 | - | [ckpt](https://projects4jw.blob.core.windows.net/unicl/release/in21k_yfcc14m_half.pth)/[config](configs/unicl_swin_tiny.yaml) 59 | | Swin-T | IN-21K + YFCC-14M | 40.5 | 49.1 | - | [ckpt](https://projects4jw.blob.core.windows.net/unicl/release/in21k_yfcc14m.pth)/[config](configs/unicl_swin_tiny.yaml) 60 | | Swin-B | IN-21K | 29.9 | 42.4 | - | [ckpt](https://projects4jw.blob.core.windows.net/unicl/release/in21k_swin_base.pth)/[config](configs/unicl_swin_base.yaml) 61 | | Swin-B | IN-21K (half) + YFCC-14M (half) | 41.1 | 48.5 | - | [ckpt](https://projects4jw.blob.core.windows.net/unicl/release/in21k_yfcc14m_half_swin_base.pth)/[config](configs/unicl_swin_base.yaml) 62 | | Swin-B | IN-21K + YFCC-14M | 44.3 | 52.2 | - | [ckpt](https://projects4jw.blob.core.windows.net/unicl/release/in21k_yfcc14m_swin_base.pth)/[config](configs/unicl_swin_base.yaml) 63 | | Swin-B | IN-21K + GCC-15M + YFCC-14M | 52.2 | - | 43.2 | [ckpt]()/[config](configs/unicl_swin_base.yaml) 64 | | Focal-B | IN-21K + GCC-15M + YFCC-14M | 54.2 | - | 44.0 | [ckpt]()/[config]() 65 | 66 | **NOTE**: The "ZS on 20 datasets" setting is the one used in the ICinW benchmark. 67 | 68 | ## Getting Started 69 | 70 | ### Installation 71 | 72 | Please follow [INSTALL.md](./INSTALL.md) for installation. 73 | 74 | ### Data preparation 75 | 76 | Please follow [DATA.md](./DATA.md) for data preparation. 77 | 78 | ### Evaluation 79 | 80 | To evaluate a pre-trained UniCL on ImageNet val, run: 81 | 82 | ```bash 83 | python -m torch.distributed.launch --nproc_per_node <num-of-gpus-to-use> --master_port 12345 main.py --eval \ 84 | --cfg <config-file> --resume <checkpoint> --data-path <imagenet-path> 85 | ``` 86 | 87 | For example, to evaluate the UniCL-Swin-Tiny trained on YFCC-14M with a single GPU: 88 | 89 | ```bash 90 | python -m torch.distributed.launch --nproc_per_node 1 --master_port 12345 main.py --eval \ 91 | --cfg configs/unicl_swin_tiny.yaml --resume yfcc14m.pth --data-path <imagenet-path> 92 | ``` 93 | 94 | ### The Image Classification in the Wild Benchmark 95 | Interested in evaluating UniCL on downstream image classification tasks and comparing performance on the same task suite? We released the [ELEVATER](https://computer-vision-in-the-wild.github.io/ELEVATER/) benchmark, which has 20 downstream image classification tasks. The [software toolkit](https://github.com/Computer-Vision-in-the-Wild/Elevater_Toolkit_IC) is also released to ease the process of onboarding new models. It will be hosted as a challenge at the [CV in the Wild Workshop @ ECCV 2022](https://computer-vision-in-the-wild.github.io/eccv-2022/). We hope our benchmark and toolkit can encourage the community to solve the challenge of image classification in the wild!
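The command-line flags in the evaluation examples above are merged into a single yacs config by `config.py`. Below is a minimal sketch of building the same config programmatically; the placeholder paths are hypothetical, it assumes the repo root is importable and the script is run from there so the yaml path resolves, and every attribute read by `update_config()` must be present.

```python
from types import SimpleNamespace

from config import get_config  # assumes the UniCL repo root is on PYTHONPATH

# Mirror of the CLI flags used in the single-GPU evaluation example above.
args = SimpleNamespace(
    cfg="configs/unicl_swin_tiny.yaml", opts=None, batch_size=None, dataset=None,
    data_path="/path/to/imagenet", zip=False, cache_mode=None, resume="yfcc14m.pth",
    accumulation_steps=None, use_checkpoint=False, amp_opt_level=None,
    output="output", tag="eval", eval=True, throughput=False, debug=False, local_rank=0,
)
config = get_config(args)
print(config.MODEL.NAME, config.EVAL_MODE, config.DATA.IMG_SIZE)  # unicl_swin_tiny True 224
```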
96 | 97 | Please see more instructions: [[Benchmark]](https://computer-vision-in-the-wild.github.io/ELEVATER/) [[Toolkit]](https://github.com/Computer-Vision-in-the-Wild/Elevater_Toolkit_IC) [[Paper]](https://arxiv.org/abs/2204.08790) 98 | 99 | ## Citation 100 | 101 | If you find this repo useful for your project, please consider citing it with the following bib: 102 | 103 | @misc{yang2022unified, 104 | title={Unified Contrastive Learning in Image-Text-Label Space}, 105 | author={Jianwei Yang and Chunyuan Li and Pengchuan Zhang and Bin Xiao and Ce Liu and Lu Yuan and Jianfeng Gao}, 106 | year={2022}, 107 | eprint={2204.03610}, 108 | archivePrefix={arXiv}, 109 | primaryClass={cs.CV} 110 | } 111 | 112 | ## Acknowledgement 113 | 114 | Our codebase is built on [Swin Transformer](https://github.com/microsoft/Swin-Transformer), [Focal Transformer](https://github.com/microsoft/Focal-Transformer) and [FocalNet](https://github.com/microsoft/FocalNet). 115 | 116 | ## Contributing 117 | 118 | This project welcomes contributions and suggestions. Most contributions require you to agree to a 119 | Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us 120 | the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com. 121 | 122 | When you submit a pull request, a CLA bot will automatically determine whether you need to provide 123 | a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions 124 | provided by the bot. You will only need to do this once across all repos using our CLA. 125 | 126 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 127 | For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or 128 | contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments. 129 | 130 | ## Trademarks 131 | 132 | This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft 133 | trademarks or logos is subject to and must follow 134 | [Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general). 135 | Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship. 136 | Any use of third-party trademarks or logos is subject to those third parties' policies. 137 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://docs.microsoft.com/en-us/previous-versions/tn-archive/cc751383(v=technet.10)), please report it to us as described below.
8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://msrc.microsoft.com/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://www.microsoft.com/en-us/msrc/pgp-key-msrc). 16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://microsoft.com/msrc/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd). 40 | 41 | -------------------------------------------------------------------------------- /SUPPORT.md: -------------------------------------------------------------------------------- 1 | # TODO: The maintainer of this repo has not yet edited this file 2 | 3 | **REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project? 4 | 5 | - **No CSS support:** Fill out this template with information about how to file issues and get help. 6 | - **Yes CSS support:** Fill out an intake form at [aka.ms/spot](https://aka.ms/spot). CSS will work with/help you to determine next steps. More details also available at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). 7 | - **Not sure?** Fill out a SPOT intake as though the answer were "Yes". CSS will help you decide. 8 | 9 | *Then remove this first heading from this SUPPORT.MD file before publishing your repo.* 10 | 11 | # Support 12 | 13 | ## How to file issues and get help 14 | 15 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing 16 | issues before filing new issues to avoid duplicates. For new issues, file your bug or 17 | feature request as a new Issue. 18 | 19 | For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE 20 | FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER 21 | CHANNEL. WHERE WILL YOU HELP PEOPLE?**. 
22 | 23 | ## Microsoft Support Policy 24 | 25 | Support for this **PROJECT or PRODUCT** is limited to the resources listed above. 26 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Unified Contrastive Learning (UniCL) 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Jianwei Yang (jianwyan@microsoft.com) 6 | # Based on Swin Transformer written by Zhe Liu 7 | # -------------------------------------------------------- 8 | 9 | import os 10 | import yaml 11 | from yacs.config import CfgNode as CN 12 | 13 | _C = CN() 14 | _C.VERBOSE = False 15 | 16 | # Base config files 17 | _C.BASE = [''] 18 | 19 | # ----------------------------------------------------------------------------- 20 | # Data settings 21 | # ----------------------------------------------------------------------------- 22 | _C.DATA = CN() 23 | # Batch size for a single GPU, could be overwritten by command line argument 24 | _C.DATA.BATCH_SIZE = 128 25 | # Path to dataset, could be overwritten by command line argument 26 | _C.DATA.DATA_PATH = '' 27 | # Dataset name 28 | _C.DATA.DATASET = 'imagenet' 29 | # Input image size 30 | _C.DATA.IMG_SIZE = 224 31 | # Interpolation to resize image (random, bilinear, bicubic) 32 | _C.DATA.INTERPOLATION = 'bicubic' 33 | # Use zipped dataset instead of folder dataset 34 | # could be overwritten by command line argument 35 | _C.DATA.ZIP_MODE = False 36 | # Cache Data in Memory, could be overwritten by command line argument 37 | _C.DATA.CACHE_MODE = 'part' 38 | # Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU. 39 | _C.DATA.PIN_MEMORY = True 40 | # Number of data loading threads 41 | _C.DATA.NUM_WORKERS = 8 42 | 43 | # ----------------------------------------------------------------------------- 44 | # Model settings 45 | # ----------------------------------------------------------------------------- 46 | _C.MODEL = CN() 47 | # Model name 48 | _C.MODEL.NAME = '' 49 | # Checkpoint to resume, could be overwritten by command line argument 50 | _C.MODEL.RESUME = '' 51 | # Number of classes, overwritten in data preparation 52 | _C.MODEL.NUM_CLASSES = 0 53 | # Label Smoothing 54 | _C.MODEL.LABEL_SMOOTHING = 0.1 55 | # Whether load pretrained model 56 | _C.MODEL.PRETRAINED = '' 57 | # Projection dimension 58 | _C.MODEL.DIM_PROJECTION = 512 59 | # Mode specific 60 | _C.MODEL.SPEC = CN(new_allowed=True) 61 | # ----------------------------------------------------------------------------- 62 | # Build Image Encoder 63 | # ----------------------------------------------------------------------------- 64 | _C.MODEL.IMAGE_ENCODER = CN() 65 | # Image encoder type 66 | _C.MODEL.IMAGE_ENCODER.TYPE = 'swin' 67 | # Input image size 68 | _C.MODEL.IMAGE_ENCODER.IMG_SIZE = 224 69 | # Dropout rate 70 | _C.MODEL.IMAGE_ENCODER.DROP_RATE = 0.0 71 | # Drop path rate 72 | _C.MODEL.IMAGE_ENCODER.DROP_PATH_RATE = 0.1 73 | 74 | # Swin Transformer parameters 75 | _C.MODEL.IMAGE_ENCODER.SWIN = CN() 76 | _C.MODEL.IMAGE_ENCODER.SWIN.PATCH_SIZE = 4 77 | _C.MODEL.IMAGE_ENCODER.SWIN.IN_CHANS = 3 78 | _C.MODEL.IMAGE_ENCODER.SWIN.EMBED_DIM = 96 79 | _C.MODEL.IMAGE_ENCODER.SWIN.DEPTHS = [2, 2, 6, 2] 80 | _C.MODEL.IMAGE_ENCODER.SWIN.NUM_HEADS = [3, 6, 12, 24] 81 | _C.MODEL.IMAGE_ENCODER.SWIN.WINDOW_SIZE = 7 82 | _C.MODEL.IMAGE_ENCODER.SWIN.MLP_RATIO = 4. 
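# Note (descriptive comment): the SWIN defaults in this block match the Swin-Tiny setting in
# configs/unicl_swin_tiny.yaml (EMBED_DIM=96, DEPTHS=[2, 2, 6, 2], NUM_HEADS=[3, 6, 12, 24]);
# configs/unicl_swin_base.yaml overrides them (EMBED_DIM=128, DEPTHS=[2, 2, 18, 2],
# NUM_HEADS=[4, 8, 16, 32]) through the yaml merge performed in update_config() below.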
83 | _C.MODEL.IMAGE_ENCODER.SWIN.QKV_BIAS = True 84 | _C.MODEL.IMAGE_ENCODER.SWIN.QK_SCALE = None 85 | _C.MODEL.IMAGE_ENCODER.SWIN.APE = False 86 | _C.MODEL.IMAGE_ENCODER.SWIN.PATCH_NORM = True 87 | 88 | # FocalNet parameters 89 | _C.MODEL.IMAGE_ENCODER.FOCAL = CN() 90 | _C.MODEL.IMAGE_ENCODER.FOCAL.PATCH_SIZE = 4 91 | _C.MODEL.IMAGE_ENCODER.FOCAL.IN_CHANS = 3 92 | _C.MODEL.IMAGE_ENCODER.FOCAL.EMBED_DIM = 96 93 | _C.MODEL.IMAGE_ENCODER.FOCAL.DEPTHS = [2, 2, 6, 2] 94 | _C.MODEL.IMAGE_ENCODER.FOCAL.MLP_RATIO = 4. 95 | _C.MODEL.IMAGE_ENCODER.FOCAL.PATCH_NORM = True 96 | _C.MODEL.IMAGE_ENCODER.FOCAL.FOCAL_LEVELS = [2, 2, 2, 2] 97 | _C.MODEL.IMAGE_ENCODER.FOCAL.FOCAL_WINDOWS = [3, 3, 3, 3] 98 | _C.MODEL.IMAGE_ENCODER.FOCAL.FOCAL_FACTORS = [2, 2, 2, 2] 99 | _C.MODEL.IMAGE_ENCODER.FOCAL.USE_CONV_EMBED = False 100 | _C.MODEL.IMAGE_ENCODER.FOCAL.USE_LAYERSCALE = False 101 | _C.MODEL.IMAGE_ENCODER.FOCAL.USE_POSTLN = False 102 | 103 | # ----------------------------------------------------------------------------- 104 | # Build Text Encoder 105 | # ----------------------------------------------------------------------------- 106 | _C.MODEL.TEXT_ENCODER = CN() 107 | 108 | _C.MODEL.TEXT_ENCODER.NAME = 'transformer' 109 | _C.MODEL.TEXT_ENCODER.LOAD_PRETRAINED = False 110 | _C.MODEL.TEXT_ENCODER.PRETRAINED = '' 111 | _C.MODEL.TEXT_ENCODER.TOKENIZER = 'clip' 112 | _C.MODEL.TEXT_ENCODER.CONTEXT_LENGTH = 77 113 | _C.MODEL.TEXT_ENCODER.WIDTH = 1024 114 | _C.MODEL.TEXT_ENCODER.HEADS = 16 115 | _C.MODEL.TEXT_ENCODER.LAYERS = 12 116 | _C.MODEL.TEXT_ENCODER.AUTOGRESSIVE = True 117 | 118 | # ----------------------------------------------------------------------------- 119 | # Training settings 120 | # ----------------------------------------------------------------------------- 121 | _C.TRAIN = CN() 122 | _C.TRAIN.START_EPOCH = 0 123 | _C.TRAIN.EPOCHS = 32 124 | _C.TRAIN.WARMUP_EPOCHS = 5 125 | _C.TRAIN.WEIGHT_DECAY = 0.1 126 | _C.TRAIN.BASE_LR = 5e-4 127 | _C.TRAIN.WARMUP_LR = 5e-7 128 | _C.TRAIN.MIN_LR = 5e-6 129 | # Clip gradient norm 130 | _C.TRAIN.CLIP_GRAD = 5.0 131 | # Auto resume from latest checkpoint 132 | _C.TRAIN.AUTO_RESUME = True 133 | # Gradient accumulation steps 134 | # could be overwritten by command line argument 135 | _C.TRAIN.ACCUMULATION_STEPS = 0 136 | # Whether to use gradient checkpointing to save memory 137 | # could be overwritten by command line argument 138 | _C.TRAIN.USE_CHECKPOINT = False 139 | 140 | # LR scheduler 141 | _C.TRAIN.LR_SCHEDULER = CN() 142 | _C.TRAIN.LR_SCHEDULER.NAME = 'cosine' 143 | # Epoch interval to decay LR, used in StepLRScheduler 144 | _C.TRAIN.LR_SCHEDULER.DECAY_EPOCHS = 30 145 | # LR decay rate, used in StepLRScheduler 146 | _C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1 147 | 148 | # Optimizer 149 | _C.TRAIN.OPTIMIZER = CN() 150 | _C.TRAIN.OPTIMIZER.NAME = 'adamw' 151 | # Optimizer Epsilon 152 | _C.TRAIN.OPTIMIZER.EPS = 1e-8 153 | # Optimizer Betas 154 | _C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999) 155 | # SGD momentum 156 | _C.TRAIN.OPTIMIZER.MOMENTUM = 0.9 157 | 158 | # ----------------------------------------------------------------------------- 159 | # Augmentation settings 160 | # ----------------------------------------------------------------------------- 161 | _C.AUG = CN() 162 | # Color jitter factor 163 | _C.AUG.COLOR_JITTER = 0.4 164 | # Use AutoAugment policy. 
"v0" or "original" 165 | _C.AUG.AUTO_AUGMENT = 'rand-m9-mstd0.5-inc1' 166 | # Random erase prob 167 | _C.AUG.REPROB = 0.25 168 | # Random erase mode 169 | _C.AUG.REMODE = 'pixel' 170 | # Random erase count 171 | _C.AUG.RECOUNT = 1 172 | # Mixup alpha, mixup enabled if > 0 173 | _C.AUG.MIXUP = 0.8 174 | # Cutmix alpha, cutmix enabled if > 0 175 | _C.AUG.CUTMIX = 1.0 176 | # Cutmix min/max ratio, overrides alpha and enables cutmix if set 177 | _C.AUG.CUTMIX_MINMAX = None 178 | # Probability of performing mixup or cutmix when either/both is enabled 179 | _C.AUG.MIXUP_PROB = 1.0 180 | # Probability of switching to cutmix when both mixup and cutmix enabled 181 | _C.AUG.MIXUP_SWITCH_PROB = 0.5 182 | # How to apply mixup/cutmix params. Per "batch", "pair", or "elem" 183 | _C.AUG.MIXUP_MODE = 'batch' 184 | 185 | # ----------------------------------------------------------------------------- 186 | # Testing settings 187 | # ----------------------------------------------------------------------------- 188 | _C.TEST = CN() 189 | # Whether to use center crop when testing 190 | _C.TEST.CROP = True 191 | 192 | # ----------------------------------------------------------------------------- 193 | # Misc 194 | # ----------------------------------------------------------------------------- 195 | # Mixed precision opt level, if O0, no amp is used ('O0', 'O1', 'O2') 196 | # overwritten by command line argument 197 | _C.AMP_OPT_LEVEL = '' 198 | # Path to output folder, overwritten by command line argument 199 | _C.OUTPUT = '' 200 | # Tag of experiment, overwritten by command line argument 201 | _C.TAG = 'default' 202 | # Frequency to save checkpoint 203 | _C.SAVE_FREQ = 1 204 | # Frequency to logging info 205 | _C.PRINT_FREQ = 100 206 | # Fixed random seed 207 | _C.SEED = 0 208 | # Perform evaluation only, overwritten by command line argument 209 | _C.EVAL_MODE = False 210 | # Test throughput only, overwritten by command line argument 211 | _C.THROUGHPUT_MODE = False 212 | # Debug only so that skip dataloader initialization, overwritten by command line argument 213 | _C.DEBUG_MODE = False 214 | # local rank for DistributedDataParallel, given by command line argument 215 | _C.LOCAL_RANK = 0 216 | 217 | 218 | def _update_config_from_file(config, cfg_file): 219 | config.defrost() 220 | with open(cfg_file, 'r') as f: 221 | yaml_cfg = yaml.load(f, Loader=yaml.FullLoader) 222 | 223 | for cfg in yaml_cfg.setdefault('BASE', ['']): 224 | if cfg: 225 | _update_config_from_file( 226 | config, os.path.join(os.path.dirname(cfg_file), cfg) 227 | ) 228 | print('=> merge config from {}'.format(cfg_file)) 229 | config.merge_from_file(cfg_file) 230 | config.freeze() 231 | 232 | 233 | def update_config(config, args): 234 | _update_config_from_file(config, args.cfg) 235 | 236 | config.defrost() 237 | if args.opts: 238 | config.merge_from_list(args.opts) 239 | 240 | # merge from specific arguments 241 | if args.batch_size: 242 | config.DATA.BATCH_SIZE = args.batch_size 243 | if args.dataset: 244 | config.DATA.DATASET = args.dataset 245 | if args.data_path: 246 | config.DATA.DATA_PATH = args.data_path 247 | if args.zip: 248 | config.DATA.ZIP_MODE = True 249 | if args.cache_mode: 250 | config.DATA.CACHE_MODE = args.cache_mode 251 | if args.resume: 252 | config.MODEL.RESUME = args.resume 253 | if args.accumulation_steps: 254 | config.TRAIN.ACCUMULATION_STEPS = args.accumulation_steps 255 | if args.use_checkpoint: 256 | config.TRAIN.USE_CHECKPOINT = True 257 | if args.amp_opt_level: 258 | config.AMP_OPT_LEVEL = args.amp_opt_level 259 
| if args.output: 260 | config.OUTPUT = args.output 261 | if args.tag: 262 | config.TAG = args.tag 263 | if args.eval: 264 | config.EVAL_MODE = True 265 | if args.throughput: 266 | config.THROUGHPUT_MODE = True 267 | if args.debug: 268 | config.DEBUG_MODE = True 269 | 270 | if config.DATA.DATASET == 'imagewoof': 271 | config.MODEL.NUM_CLASSES = 10 272 | 273 | # set local rank for distributed training 274 | config.LOCAL_RANK = args.local_rank 275 | 276 | # output folder 277 | config.OUTPUT = os.path.join(config.OUTPUT, config.MODEL.NAME, config.TAG) 278 | 279 | config.freeze() 280 | 281 | 282 | def get_config(args): 283 | """Get a yacs CfgNode object with default values.""" 284 | # Return a clone so that the defaults will not be altered 285 | # This is for the "local variable" use pattern 286 | config = _C.clone() 287 | update_config(config, args) 288 | 289 | return config 290 | -------------------------------------------------------------------------------- /configs/unicl_swin_base.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | NAME: unicl_swin_base 3 | DIM_PROJECTION: 512 4 | IMAGE_ENCODER: 5 | TYPE: swin 6 | DROP_PATH_RATE: 0.5 7 | SWIN: 8 | EMBED_DIM: 128 9 | DEPTHS: [ 2, 2, 18, 2 ] 10 | NUM_HEADS: [ 4, 8, 16, 32 ] 11 | WINDOW_SIZE: 7 12 | TEXT_ENCODER: 13 | NAME: 'transformer' 14 | WIDTH: 512 15 | HEADS: 8 16 | LAYERS: 12 17 | -------------------------------------------------------------------------------- /configs/unicl_swin_tiny.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | NAME: unicl_swin_tiny 3 | DIM_PROJECTION: 512 4 | IMAGE_ENCODER: 5 | TYPE: swin 6 | DROP_PATH_RATE: 0.2 7 | SWIN: 8 | EMBED_DIM: 96 9 | DEPTHS: [ 2, 2, 6, 2 ] 10 | NUM_HEADS: [ 3, 6, 12, 24 ] 11 | WINDOW_SIZE: 7 12 | TEXT_ENCODER: 13 | NAME: 'transformer' 14 | WIDTH: 512 15 | HEADS: 8 16 | LAYERS: 12 17 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import build_loader -------------------------------------------------------------------------------- /data/build.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | import os 9 | import torch 10 | import numpy as np 11 | import torch.distributed as dist 12 | from torchvision import datasets, transforms 13 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 14 | from timm.data import Mixup 15 | from timm.data import create_transform 16 | from timm.data.transforms import _pil_interp 17 | 18 | from .cached_image_folder import CachedImageFolder 19 | from .samplers import SubsetRandomSampler 20 | from model.text_encoder import build_tokenizer 21 | 22 | def build_loader(config): 23 | config.defrost() 24 | dataset_train, config.MODEL.NUM_CLASSES = build_dataset(is_train=True, config=config) 25 | config.freeze() 26 | print(f"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} successfully build train dataset") 27 | dataset_val, _ = build_dataset(is_train=False, config=config) 28 | print(f"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} successfully build val 
dataset") 29 | 30 | num_tasks = dist.get_world_size() 31 | global_rank = dist.get_rank() 32 | if config.DATA.ZIP_MODE and config.DATA.CACHE_MODE == 'part': 33 | indices = np.arange(dist.get_rank(), len(dataset_train), dist.get_world_size()) 34 | sampler_train = SubsetRandomSampler(indices) 35 | else: 36 | sampler_train = torch.utils.data.DistributedSampler( 37 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 38 | ) 39 | 40 | indices = np.arange(dist.get_rank(), len(dataset_val), dist.get_world_size()) 41 | sampler_val = SubsetRandomSampler(indices) 42 | 43 | data_loader_train = torch.utils.data.DataLoader( 44 | dataset_train, sampler=sampler_train, 45 | batch_size=config.DATA.BATCH_SIZE, 46 | num_workers=config.DATA.NUM_WORKERS, 47 | pin_memory=config.DATA.PIN_MEMORY, 48 | drop_last=True, 49 | ) 50 | 51 | data_loader_val = torch.utils.data.DataLoader( 52 | dataset_val, sampler=sampler_val, 53 | batch_size=config.DATA.BATCH_SIZE, 54 | shuffle=False, 55 | num_workers=config.DATA.NUM_WORKERS, 56 | pin_memory=config.DATA.PIN_MEMORY, 57 | drop_last=False 58 | ) 59 | 60 | # setup mixup / cutmix 61 | mixup_fn = None 62 | mixup_active = config.AUG.MIXUP > 0 or config.AUG.CUTMIX > 0. or config.AUG.CUTMIX_MINMAX is not None 63 | if mixup_active: 64 | mixup_fn = Mixup( 65 | mixup_alpha=config.AUG.MIXUP, cutmix_alpha=config.AUG.CUTMIX, cutmix_minmax=config.AUG.CUTMIX_MINMAX, 66 | prob=config.AUG.MIXUP_PROB, switch_prob=config.AUG.MIXUP_SWITCH_PROB, mode=config.AUG.MIXUP_MODE, 67 | label_smoothing=config.MODEL.LABEL_SMOOTHING, num_classes=config.MODEL.NUM_CLASSES) 68 | 69 | return dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn 70 | 71 | 72 | def build_dataset(is_train, config): 73 | transform = build_transform(is_train, config) 74 | tokenizer = build_tokenizer(config['MODEL']['TEXT_ENCODER']) 75 | 76 | if config.DATA.DATASET == 'imagenet': 77 | prefix = 'train' if is_train else 'val' 78 | if config.DATA.ZIP_MODE: 79 | from .zipdata import ZipData 80 | if is_train: 81 | datapath = os.path.join(config.DATA.DATA_PATH, 'train.zip') 82 | data_map = os.path.join(config.DATA.DATA_PATH, 'train_map.txt') 83 | else: 84 | datapath = os.path.join(config.DATA.DATA_PATH, 'val.zip') 85 | data_map = os.path.join(config.DATA.DATA_PATH, 'val_map.txt') 86 | dataset = ZipData( 87 | datapath, data_map, 88 | tokenizer, 89 | transform, 90 | ) 91 | else: 92 | root = os.path.join(config.DATA.DATA_PATH, prefix) 93 | dataset = datasets.ImageFolder(root, transform=transform) 94 | nb_classes = 1000 95 | elif config.DATA.DATASET == 'imagewoof': 96 | prefix = 'train' if is_train else 'val' 97 | root = os.path.join(config.DATA.DATA_PATH, prefix) 98 | dataset = datasets.ImageFolder(root, transform=transform) 99 | nb_classes = 1000 100 | else: 101 | raise NotImplementedError("We only support ImageNet Now.") 102 | 103 | return dataset, nb_classes 104 | 105 | 106 | def build_transform(is_train, config): 107 | resize_im = config.DATA.IMG_SIZE > 32 108 | if is_train: 109 | # this should always dispatch to transforms_imagenet_train 110 | transform = create_transform( 111 | input_size=config.DATA.IMG_SIZE, 112 | is_training=True, 113 | color_jitter=config.AUG.COLOR_JITTER if config.AUG.COLOR_JITTER > 0 else None, 114 | auto_augment=config.AUG.AUTO_AUGMENT if config.AUG.AUTO_AUGMENT != 'none' else None, 115 | re_prob=config.AUG.REPROB, 116 | re_mode=config.AUG.REMODE, 117 | re_count=config.AUG.RECOUNT, 118 | interpolation=config.DATA.INTERPOLATION, 119 | ) 120 | if not resize_im: 121 | # 
replace RandomResizedCropAndInterpolation with 122 | # RandomCrop 123 | transform.transforms[0] = transforms.RandomCrop(config.DATA.IMG_SIZE, padding=4) 124 | return transform 125 | 126 | t = [] 127 | if resize_im: 128 | if config.TEST.CROP: 129 | size = int((256 / 224) * config.DATA.IMG_SIZE) 130 | t.append( 131 | transforms.Resize(size, interpolation=_pil_interp(config.DATA.INTERPOLATION)), 132 | # to maintain same ratio w.r.t. 224 images 133 | ) 134 | t.append(transforms.CenterCrop(config.DATA.IMG_SIZE)) 135 | else: 136 | t.append( 137 | transforms.Resize((config.DATA.IMG_SIZE, config.DATA.IMG_SIZE), 138 | interpolation=_pil_interp(config.DATA.INTERPOLATION)) 139 | ) 140 | 141 | t.append(transforms.ToTensor()) 142 | t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)) 143 | return transforms.Compose(t) 144 | -------------------------------------------------------------------------------- /data/cached_image_folder.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | import io 9 | import os 10 | import time 11 | import torch.distributed as dist 12 | import torch.utils.data as data 13 | from PIL import Image 14 | 15 | from .zipreader import is_zip_path, ZipReader 16 | 17 | 18 | def has_file_allowed_extension(filename, extensions): 19 | """Checks if a file is an allowed extension. 20 | Args: 21 | filename (string): path to a file 22 | Returns: 23 | bool: True if the filename ends with a known image extension 24 | """ 25 | filename_lower = filename.lower() 26 | return any(filename_lower.endswith(ext) for ext in extensions) 27 | 28 | 29 | def find_classes(dir): 30 | classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))] 31 | classes.sort() 32 | class_to_idx = {classes[i]: i for i in range(len(classes))} 33 | return classes, class_to_idx 34 | 35 | 36 | def make_dataset(dir, class_to_idx, extensions): 37 | images = [] 38 | dir = os.path.expanduser(dir) 39 | for target in sorted(os.listdir(dir)): 40 | d = os.path.join(dir, target) 41 | if not os.path.isdir(d): 42 | continue 43 | 44 | for root, _, fnames in sorted(os.walk(d)): 45 | for fname in sorted(fnames): 46 | if has_file_allowed_extension(fname, extensions): 47 | path = os.path.join(root, fname) 48 | item = (path, class_to_idx[target]) 49 | images.append(item) 50 | 51 | return images 52 | 53 | 54 | def make_dataset_with_ann(ann_file, img_prefix, extensions): 55 | images = [] 56 | with open(ann_file, "r") as f: 57 | contents = f.readlines() 58 | for line_str in contents: 59 | path_contents = [c for c in line_str.split('\t')] 60 | im_file_name = path_contents[0] 61 | class_index = int(path_contents[1]) 62 | 63 | assert str.lower(os.path.splitext(im_file_name)[-1]) in extensions 64 | item = (os.path.join(img_prefix, im_file_name), class_index) 65 | 66 | images.append(item) 67 | 68 | return images 69 | 70 | 71 | class DatasetFolder(data.Dataset): 72 | """A generic data loader where the samples are arranged in this way: :: 73 | root/class_x/xxx.ext 74 | root/class_x/xxy.ext 75 | root/class_x/xxz.ext 76 | root/class_y/123.ext 77 | root/class_y/nsdf3.ext 78 | root/class_y/asd932_.ext 79 | Args: 80 | root (string): Root directory path. 
81 | loader (callable): A function to load a sample given its path. 82 | extensions (list[string]): A list of allowed extensions. 83 | transform (callable, optional): A function/transform that takes in 84 | a sample and returns a transformed version. 85 | E.g, ``transforms.RandomCrop`` for images. 86 | target_transform (callable, optional): A function/transform that takes 87 | in the target and transforms it. 88 | Attributes: 89 | samples (list): List of (sample path, class_index) tuples 90 | """ 91 | 92 | def __init__(self, root, loader, extensions, ann_file='', img_prefix='', transform=None, target_transform=None, 93 | cache_mode="no"): 94 | # image folder mode 95 | if ann_file == '': 96 | _, class_to_idx = find_classes(root) 97 | samples = make_dataset(root, class_to_idx, extensions) 98 | # zip mode 99 | else: 100 | samples = make_dataset_with_ann(os.path.join(root, ann_file), 101 | os.path.join(root, img_prefix), 102 | extensions) 103 | 104 | if len(samples) == 0: 105 | raise (RuntimeError("Found 0 files in subfolders of: " + root + "\n" + 106 | "Supported extensions are: " + ",".join(extensions))) 107 | 108 | self.root = root 109 | self.loader = loader 110 | self.extensions = extensions 111 | 112 | self.samples = samples 113 | self.labels = [y_1k for _, y_1k in samples] 114 | self.classes = list(set(self.labels)) 115 | 116 | self.transform = transform 117 | self.target_transform = target_transform 118 | 119 | self.cache_mode = cache_mode 120 | if self.cache_mode != "no": 121 | self.init_cache() 122 | 123 | def init_cache(self): 124 | assert self.cache_mode in ["part", "full"] 125 | n_sample = len(self.samples) 126 | global_rank = dist.get_rank() 127 | world_size = dist.get_world_size() 128 | 129 | samples_bytes = [None for _ in range(n_sample)] 130 | start_time = time.time() 131 | for index in range(n_sample): 132 | if index % (n_sample // 10) == 0: 133 | t = time.time() - start_time 134 | print(f'global_rank {dist.get_rank()} cached {index}/{n_sample} takes {t:.2f}s per block') 135 | start_time = time.time() 136 | path, target = self.samples[index] 137 | if self.cache_mode == "full": 138 | samples_bytes[index] = (ZipReader.read(path), target) 139 | elif self.cache_mode == "part" and index % world_size == global_rank: 140 | samples_bytes[index] = (ZipReader.read(path), target) 141 | else: 142 | samples_bytes[index] = (path, target) 143 | self.samples = samples_bytes 144 | 145 | def __getitem__(self, index): 146 | """ 147 | Args: 148 | index (int): Index 149 | Returns: 150 | tuple: (sample, target) where target is class_index of the target class. 
151 | """ 152 | path, target = self.samples[index] 153 | sample = self.loader(path) 154 | if self.transform is not None: 155 | sample = self.transform(sample) 156 | if self.target_transform is not None: 157 | target = self.target_transform(target) 158 | 159 | return sample, target 160 | 161 | def __len__(self): 162 | return len(self.samples) 163 | 164 | def __repr__(self): 165 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' 166 | fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) 167 | fmt_str += ' Root Location: {}\n'.format(self.root) 168 | tmp = ' Transforms (if any): ' 169 | fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 170 | tmp = ' Target Transforms (if any): ' 171 | fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 172 | return fmt_str 173 | 174 | 175 | IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif'] 176 | 177 | 178 | def pil_loader(path): 179 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 180 | if isinstance(path, bytes): 181 | img = Image.open(io.BytesIO(path)) 182 | elif is_zip_path(path): 183 | data = ZipReader.read(path) 184 | img = Image.open(io.BytesIO(data)) 185 | else: 186 | with open(path, 'rb') as f: 187 | img = Image.open(f) 188 | return img.convert('RGB') 189 | 190 | 191 | def accimage_loader(path): 192 | import accimage 193 | try: 194 | return accimage.Image(path) 195 | except IOError: 196 | # Potentially a decoding problem, fall back to PIL.Image 197 | return pil_loader(path) 198 | 199 | 200 | def default_img_loader(path): 201 | from torchvision import get_image_backend 202 | if get_image_backend() == 'accimage': 203 | return accimage_loader(path) 204 | else: 205 | return pil_loader(path) 206 | 207 | 208 | class CachedImageFolder(DatasetFolder): 209 | """A generic data loader where the images are arranged in this way: :: 210 | root/dog/xxx.png 211 | root/dog/xxy.png 212 | root/dog/xxz.png 213 | root/cat/123.png 214 | root/cat/nsdf3.png 215 | root/cat/asd932_.png 216 | Args: 217 | root (string): Root directory path. 218 | transform (callable, optional): A function/transform that takes in an PIL image 219 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 220 | target_transform (callable, optional): A function/transform that takes in the 221 | target and transforms it. 222 | loader (callable, optional): A function to load an image given its path. 223 | Attributes: 224 | imgs (list): List of (image path, class_index) tuples 225 | """ 226 | 227 | def __init__(self, root, ann_file='', img_prefix='', transform=None, target_transform=None, 228 | loader=default_img_loader, cache_mode="no"): 229 | super(CachedImageFolder, self).__init__(root, loader, IMG_EXTENSIONS, 230 | ann_file=ann_file, img_prefix=img_prefix, 231 | transform=transform, target_transform=target_transform, 232 | cache_mode=cache_mode) 233 | self.imgs = self.samples 234 | 235 | def __getitem__(self, index): 236 | """ 237 | Args: 238 | index (int): Index 239 | Returns: 240 | tuple: (image, target) where target is class_index of the target class. 
241 | """ 242 | path, target = self.samples[index] 243 | image = self.loader(path) 244 | if self.transform is not None: 245 | img = self.transform(image) 246 | else: 247 | img = image 248 | if self.target_transform is not None: 249 | target = self.target_transform(target) 250 | 251 | return img, target 252 | -------------------------------------------------------------------------------- /data/imagenet/__init__.py: -------------------------------------------------------------------------------- 1 | from .constants import IMAGENET_CLASSES, IMAGENET_DEFAULT_TEMPLATES -------------------------------------------------------------------------------- /data/imagenet/constants.py: -------------------------------------------------------------------------------- 1 | IMAGENET_CLASSES = ["tench", "goldfish", "great white shark", "tiger shark", "hammerhead shark", "electric ray", "stingray", "rooster", "hen", "ostrich", "brambling", "goldfinch", "house finch", "junco", "indigo bunting", "American robin", "bulbul", "jay", "magpie", "chickadee", "American dipper", "kite (bird of prey)", "bald eagle", "vulture", "great grey owl", "fire salamander", "smooth newt", "newt", "spotted salamander", "axolotl", "American bullfrog", "tree frog", "tailed frog", "loggerhead sea turtle", "leatherback sea turtle", "mud turtle", "terrapin", "box turtle", "banded gecko", "green iguana", "Carolina anole", "desert grassland whiptail lizard", "agama", "frilled-necked lizard", "alligator lizard", "Gila monster", "European green lizard", "chameleon", "Komodo dragon", "Nile crocodile", "American alligator", "triceratops", "worm snake", "ring-necked snake", "eastern hog-nosed snake", "smooth green snake", "kingsnake", "garter snake", "water snake", "vine snake", "night snake", "boa constrictor", "African rock python", "Indian cobra", "green mamba", "sea snake", "Saharan horned viper", "eastern diamondback rattlesnake", "sidewinder rattlesnake", "trilobite", "harvestman", "scorpion", "yellow garden spider", "barn spider", "European garden spider", "southern black widow", "tarantula", "wolf spider", "tick", "centipede", "black grouse", "ptarmigan", "ruffed grouse", "prairie grouse", "peafowl", "quail", "partridge", "african grey parrot", "macaw", "sulphur-crested cockatoo", "lorikeet", "coucal", "bee eater", "hornbill", "hummingbird", "jacamar", "toucan", "duck", "red-breasted merganser", "goose", "black swan", "tusker", "echidna", "platypus", "wallaby", "koala", "wombat", "jellyfish", "sea anemone", "brain coral", "flatworm", "nematode", "conch", "snail", "slug", "sea slug", "chiton", "chambered nautilus", "Dungeness crab", "rock crab", "fiddler crab", "red king crab", "American lobster", "spiny lobster", "crayfish", "hermit crab", "isopod", "white stork", "black stork", "spoonbill", "flamingo", "little blue heron", "great egret", "bittern bird", "crane bird", "limpkin", "common gallinule", "American coot", "bustard", "ruddy turnstone", "dunlin", "common redshank", "dowitcher", "oystercatcher", "pelican", "king penguin", "albatross", "grey whale", "killer whale", "dugong", "sea lion", "Chihuahua", "Japanese Chin", "Maltese", "Pekingese", "Shih Tzu", "King Charles Spaniel", "Papillon", "toy terrier", "Rhodesian Ridgeback", "Afghan Hound", "Basset Hound", "Beagle", "Bloodhound", "Bluetick Coonhound", "Black and Tan Coonhound", "Treeing Walker Coonhound", "English foxhound", "Redbone Coonhound", "borzoi", "Irish Wolfhound", "Italian Greyhound", "Whippet", "Ibizan Hound", "Norwegian Elkhound", "Otterhound", "Saluki", "Scottish 
Deerhound", "Weimaraner", "Staffordshire Bull Terrier", "American Staffordshire Terrier", "Bedlington Terrier", "Border Terrier", "Kerry Blue Terrier", "Irish Terrier", "Norfolk Terrier", "Norwich Terrier", "Yorkshire Terrier", "Wire Fox Terrier", "Lakeland Terrier", "Sealyham Terrier", "Airedale Terrier", "Cairn Terrier", "Australian Terrier", "Dandie Dinmont Terrier", "Boston Terrier", "Miniature Schnauzer", "Giant Schnauzer", "Standard Schnauzer", "Scottish Terrier", "Tibetan Terrier", "Australian Silky Terrier", "Soft-coated Wheaten Terrier", "West Highland White Terrier", "Lhasa Apso", "Flat-Coated Retriever", "Curly-coated Retriever", "Golden Retriever", "Labrador Retriever", "Chesapeake Bay Retriever", "German Shorthaired Pointer", "Vizsla", "English Setter", "Irish Setter", "Gordon Setter", "Brittany dog", "Clumber Spaniel", "English Springer Spaniel", "Welsh Springer Spaniel", "Cocker Spaniel", "Sussex Spaniel", "Irish Water Spaniel", "Kuvasz", "Schipperke", "Groenendael dog", "Malinois", "Briard", "Australian Kelpie", "Komondor", "Old English Sheepdog", "Shetland Sheepdog", "collie", "Border Collie", "Bouvier des Flandres dog", "Rottweiler", "German Shepherd Dog", "Dobermann", "Miniature Pinscher", "Greater Swiss Mountain Dog", "Bernese Mountain Dog", "Appenzeller Sennenhund", "Entlebucher Sennenhund", "Boxer", "Bullmastiff", "Tibetan Mastiff", "French Bulldog", "Great Dane", "St. Bernard", "husky", "Alaskan Malamute", "Siberian Husky", "Dalmatian", "Affenpinscher", "Basenji", "pug", "Leonberger", "Newfoundland dog", "Great Pyrenees dog", "Samoyed", "Pomeranian", "Chow Chow", "Keeshond", "brussels griffon", "Pembroke Welsh Corgi", "Cardigan Welsh Corgi", "Toy Poodle", "Miniature Poodle", "Standard Poodle", "Mexican hairless dog (xoloitzcuintli)", "grey wolf", "Alaskan tundra wolf", "red wolf or maned wolf", "coyote", "dingo", "dhole", "African wild dog", "hyena", "red fox", "kit fox", "Arctic fox", "grey fox", "tabby cat", "tiger cat", "Persian cat", "Siamese cat", "Egyptian Mau", "cougar", "lynx", "leopard", "snow leopard", "jaguar", "lion", "tiger", "cheetah", "brown bear", "American black bear", "polar bear", "sloth bear", "mongoose", "meerkat", "tiger beetle", "ladybug", "ground beetle", "longhorn beetle", "leaf beetle", "dung beetle", "rhinoceros beetle", "weevil", "fly", "bee", "ant", "grasshopper", "cricket insect", "stick insect", "cockroach", "praying mantis", "cicada", "leafhopper", "lacewing", "dragonfly", "damselfly", "red admiral butterfly", "ringlet butterfly", "monarch butterfly", "small white butterfly", "sulphur butterfly", "gossamer-winged butterfly", "starfish", "sea urchin", "sea cucumber", "cottontail rabbit", "hare", "Angora rabbit", "hamster", "porcupine", "fox squirrel", "marmot", "beaver", "guinea pig", "common sorrel horse", "zebra", "pig", "wild boar", "warthog", "hippopotamus", "ox", "water buffalo", "bison", "ram (adult male sheep)", "bighorn sheep", "Alpine ibex", "hartebeest", "impala (antelope)", "gazelle", "arabian camel", "llama", "weasel", "mink", "European polecat", "black-footed ferret", "otter", "skunk", "badger", "armadillo", "three-toed sloth", "orangutan", "gorilla", "chimpanzee", "gibbon", "siamang", "guenon", "patas monkey", "baboon", "macaque", "langur", "black-and-white colobus", "proboscis monkey", "marmoset", "white-headed capuchin", "howler monkey", "titi monkey", "Geoffroy's spider monkey", "common squirrel monkey", "ring-tailed lemur", "indri", "Asian elephant", "African bush elephant", "red panda", "giant panda", "snoek fish", 
"eel", "silver salmon", "rock beauty fish", "clownfish", "sturgeon", "gar fish", "lionfish", "pufferfish", "abacus", "abaya", "academic gown", "accordion", "acoustic guitar", "aircraft carrier", "airliner", "airship", "altar", "ambulance", "amphibious vehicle", "analog clock", "apiary", "apron", "trash can", "assault rifle", "backpack", "bakery", "balance beam", "balloon", "ballpoint pen", "Band-Aid", "banjo", "baluster / handrail", "barbell", "barber chair", "barbershop", "barn", "barometer", "barrel", "wheelbarrow", "baseball", "basketball", "bassinet", "bassoon", "swimming cap", "bath towel", "bathtub", "station wagon", "lighthouse", "beaker", "military hat (bearskin or shako)", "beer bottle", "beer glass", "bell tower", "baby bib", "tandem bicycle", "bikini", "ring binder", "binoculars", "birdhouse", "boathouse", "bobsleigh", "bolo tie", "poke bonnet", "bookcase", "bookstore", "bottle cap", "hunting bow", "bow tie", "brass memorial plaque", "bra", "breakwater", "breastplate", "broom", "bucket", "buckle", "bulletproof vest", "high-speed train", "butcher shop", "taxicab", "cauldron", "candle", "cannon", "canoe", "can opener", "cardigan", "car mirror", "carousel", "tool kit", "cardboard box / carton", "car wheel", "automated teller machine", "cassette", "cassette player", "castle", "catamaran", "CD player", "cello", "mobile phone", "chain", "chain-link fence", "chain mail", "chainsaw", "storage chest", "chiffonier", "bell or wind chime", "china cabinet", "Christmas stocking", "church", "movie theater", "cleaver", "cliff dwelling", "cloak", "clogs", "cocktail shaker", "coffee mug", "coffeemaker", "spiral or coil", "combination lock", "computer keyboard", "candy store", "container ship", "convertible", "corkscrew", "cornet", "cowboy boot", "cowboy hat", "cradle", "construction crane", "crash helmet", "crate", "infant bed", "Crock Pot", "croquet ball", "crutch", "cuirass", "dam", "desk", "desktop computer", "rotary dial telephone", "diaper", "digital clock", "digital watch", "dining table", "dishcloth", "dishwasher", "disc brake", "dock", "dog sled", "dome", "doormat", "drilling rig", "drum", "drumstick", "dumbbell", "Dutch oven", "electric fan", "electric guitar", "electric locomotive", "entertainment center", "envelope", "espresso machine", "face powder", "feather boa", "filing cabinet", "fireboat", "fire truck", "fire screen", "flagpole", "flute", "folding chair", "football helmet", "forklift", "fountain", "fountain pen", "four-poster bed", "freight car", "French horn", "frying pan", "fur coat", "garbage truck", "gas mask or respirator", "gas pump", "goblet", "go-kart", "golf ball", "golf cart", "gondola", "gong", "gown", "grand piano", "greenhouse", "radiator grille", "grocery store", "guillotine", "hair clip", "hair spray", "half-track", "hammer", "hamper", "hair dryer", "hand-held computer", "handkerchief", "hard disk drive", "harmonica", "harp", "combine harvester", "hatchet", "holster", "home theater", "honeycomb", "hook", "hoop skirt", "gymnastic horizontal bar", "horse-drawn vehicle", "hourglass", "iPod", "clothes iron", "carved pumpkin", "jeans", "jeep", "T-shirt", "jigsaw puzzle", "rickshaw", "joystick", "kimono", "knee pad", "knot", "lab coat", "ladle", "lampshade", "laptop computer", "lawn mower", "lens cap", "letter opener", "library", "lifeboat", "lighter", "limousine", "ocean liner", "lipstick", "slip-on shoe", "lotion", "music speaker", "loupe magnifying glass", "sawmill", "magnetic compass", "messenger bag", "mailbox", "tights", "one-piece bathing suit", "manhole cover", 
"maraca", "marimba", "mask", "matchstick", "maypole", "maze", "measuring cup", "medicine cabinet", "megalith", "microphone", "microwave oven", "military uniform", "milk can", "minibus", "miniskirt", "minivan", "missile", "mitten", "mixing bowl", "mobile home", "ford model t", "modem", "monastery", "monitor", "moped", "mortar and pestle", "graduation cap", "mosque", "mosquito net", "vespa", "mountain bike", "tent", "computer mouse", "mousetrap", "moving van", "muzzle", "metal nail", "neck brace", "necklace", "baby pacifier", "notebook computer", "obelisk", "oboe", "ocarina", "odometer", "oil filter", "pipe organ", "oscilloscope", "overskirt", "bullock cart", "oxygen mask", "product packet / packaging", "paddle", "paddle wheel", "padlock", "paintbrush", "pajamas", "palace", "pan flute", "paper towel", "parachute", "parallel bars", "park bench", "parking meter", "railroad car", "patio", "payphone", "pedestal", "pencil case", "pencil sharpener", "perfume", "Petri dish", "photocopier", "plectrum", "Pickelhaube", "picket fence", "pickup truck", "pier", "piggy bank", "pill bottle", "pillow", "ping-pong ball", "pinwheel", "pirate ship", "drink pitcher", "block plane", "planetarium", "plastic bag", "plate rack", "farm plow", "plunger", "Polaroid camera", "pole", "police van", "poncho", "pool table", "soda bottle", "plant pot", "potter's wheel", "power drill", "prayer rug", "printer", "prison", "projectile", "projector", "hockey puck", "punching bag", "purse", "quill", "quilt", "race car", "racket", "radiator", "radio", "radio telescope", "rain barrel", "recreational vehicle", "fishing casting reel", "reflex camera", "refrigerator", "remote control", "restaurant", "revolver", "rifle", "rocking chair", "rotisserie", "eraser", "rugby ball", "ruler measuring stick", "sneaker", "safe", "safety pin", "salt shaker", "sandal", "sarong", "saxophone", "scabbard", "weighing scale", "school bus", "schooner", "scoreboard", "CRT monitor", "screw", "screwdriver", "seat belt", "sewing machine", "shield", "shoe store", "shoji screen / room divider", "shopping basket", "shopping cart", "shovel", "shower cap", "shower curtain", "ski", "balaclava ski mask", "sleeping bag", "slide rule", "sliding door", "slot machine", "snorkel", "snowmobile", "snowplow", "soap dispenser", "soccer ball", "sock", "solar thermal collector", "sombrero", "soup bowl", "keyboard space bar", "space heater", "space shuttle", "spatula", "motorboat", "spider web", "spindle", "sports car", "spotlight", "stage", "steam locomotive", "through arch bridge", "steel drum", "stethoscope", "scarf", "stone wall", "stopwatch", "stove", "strainer", "tram", "stretcher", "couch", "stupa", "submarine", "suit", "sundial", "sunglasses", "dark glasses", "sunscreen", "suspension bridge", "mop", "sweatshirt", "swim trunks / shorts", "swing", "electrical switch", "syringe", "table lamp", "tank", "tape player", "teapot", "teddy bear", "television", "tennis ball", "thatched roof", "front curtain", "thimble", "threshing machine", "throne", "tile roof", "toaster", "tobacco shop", "toilet seat", "torch", "totem pole", "tow truck", "toy store", "tractor", "semi-trailer truck", "tray", "trench coat", "tricycle", "trimaran", "tripod", "triumphal arch", "trolleybus", "trombone", "hot tub", "turnstile", "typewriter keyboard", "umbrella", "unicycle", "upright piano", "vacuum cleaner", "vase", "vaulted or arched ceiling", "velvet fabric", "vending machine", "vestment", "viaduct", "violin", "volleyball", "waffle iron", "wall clock", "wallet", "wardrobe", "military aircraft", 
"sink", "washing machine", "water bottle", "water jug", "water tower", "whiskey jug", "whistle", "hair wig", "window screen", "window shade", "Windsor tie", "wine bottle", "airplane wing", "wok", "wooden spoon", "wool", "split-rail fence", "shipwreck", "sailboat", "yurt", "website", "comic book", "crossword", "traffic or street sign", "traffic light", "dust jacket", "menu", "plate", "guacamole", "consomme", "hot pot", "trifle", "ice cream", "popsicle", "baguette", "bagel", "pretzel", "cheeseburger", "hot dog", "mashed potatoes", "cabbage", "broccoli", "cauliflower", "zucchini", "spaghetti squash", "acorn squash", "butternut squash", "cucumber", "artichoke", "bell pepper", "cardoon", "mushroom", "Granny Smith apple", "strawberry", "orange", "lemon", "fig", "pineapple", "banana", "jackfruit", "cherimoya (custard apple)", "pomegranate", "hay", "carbonara", "chocolate syrup", "dough", "meatloaf", "pizza", "pot pie", "burrito", "red wine", "espresso", "tea cup", "eggnog", "mountain", "bubble", "cliff", "coral reef", "geyser", "lakeshore", "promontory", "sandbar", "beach", "valley", "volcano", "baseball player", "bridegroom", "scuba diver", "rapeseed", "daisy", "yellow lady's slipper", "corn", "acorn", "rose hip", "horse chestnut seed", "coral fungus", "agaric", "gyromitra", "stinkhorn mushroom", "earth star fungus", "hen of the woods mushroom", "bolete", "corn cob", "toilet paper"] 2 | 3 | IMAGENET_DEFAULT_TEMPLATES = [ 4 | '{}.', 5 | 'a bad photo of a {}.', 6 | 'a photo of many {}.', 7 | 'a sculpture of a {}.', 8 | 'a photo of the hard to see {}.', 9 | 'a low resolution photo of the {}.', 10 | 'a rendering of a {}.', 11 | 'graffiti of a {}.', 12 | 'a bad photo of the {}.', 13 | 'a cropped photo of the {}.', 14 | 'a tattoo of a {}.', 15 | 'the embroidered {}.', 16 | 'a photo of a hard to see {}.', 17 | 'a bright photo of a {}.', 18 | 'a photo of a clean {}.', 19 | 'a photo of a dirty {}.', 20 | 'a dark photo of the {}.', 21 | 'a drawing of a {}.', 22 | 'a photo of my {}.', 23 | 'the plastic {}.', 24 | 'a photo of the cool {}.', 25 | 'a close-up photo of a {}.', 26 | 'a black and white photo of the {}.', 27 | 'a painting of the {}.', 28 | 'a painting of a {}.', 29 | 'a pixelated photo of the {}.', 30 | 'a sculpture of the {}.', 31 | 'a bright photo of the {}.', 32 | 'a cropped photo of a {}.', 33 | 'a plastic {}.', 34 | 'a photo of the dirty {}.', 35 | 'a jpeg corrupted photo of a {}.', 36 | 'a blurry photo of the {}.', 37 | 'a photo of the {}.', 38 | 'a good photo of the {}.', 39 | 'a rendering of the {}.', 40 | 'a {} in a video game.', 41 | 'a photo of one {}.', 42 | 'a doodle of a {}.', 43 | 'a close-up photo of the {}.', 44 | 'a photo of a {}.', 45 | 'the origami {}.', 46 | 'the {} in a video game.', 47 | 'a sketch of a {}.', 48 | 'a doodle of the {}.', 49 | 'a origami {}.', 50 | 'a low resolution photo of a {}.', 51 | 'the toy {}.', 52 | 'a rendition of the {}.', 53 | 'a photo of the clean {}.', 54 | 'a photo of a large {}.', 55 | 'a rendition of a {}.', 56 | 'a photo of a nice {}.', 57 | 'a photo of a weird {}.', 58 | 'a blurry photo of a {}.', 59 | 'a cartoon {}.', 60 | 'art of a {}.', 61 | 'a sketch of the {}.', 62 | 'a embroidered {}.', 63 | 'a pixelated photo of a {}.', 64 | 'itap of the {}.', 65 | 'a jpeg corrupted photo of the {}.', 66 | 'a good photo of a {}.', 67 | 'a plushie {}.', 68 | 'a photo of the nice {}.', 69 | 'a photo of the small {}.', 70 | 'a photo of the weird {}.', 71 | 'the cartoon {}.', 72 | 'art of the {}.', 73 | 'a drawing of the {}.', 74 | 'a photo of the large 
{}.', 75 | 'a black and white photo of a {}.', 76 | 'the plushie {}.', 77 | 'a dark photo of a {}.', 78 | 'itap of a {}.', 79 | 'graffiti of the {}.', 80 | 'a toy {}.', 81 | 'itap of my {}.', 82 | 'a photo of a cool {}.', 83 | 'a photo of a small {}.', 84 | 'a tattoo of the {}.', 85 | ] 86 | -------------------------------------------------------------------------------- /data/imagenet/real_labels.py: -------------------------------------------------------------------------------- 1 | """ Real labels evaluator for ImageNet 2 | Paper: `Are we done with ImageNet?` - https://arxiv.org/abs/2006.07159 3 | Based on Numpy example at https://github.com/google-research/reassessed-imagenet 4 | 5 | Hacked together by / Copyright 2020 Ross Wightman 6 | """ 7 | import os 8 | import json 9 | import numpy as np 10 | 11 | 12 | class RealLabelsImagenet: 13 | 14 | def __init__(self, filenames, real_json='real.json', topk=(1, 5)): 15 | with open(real_json) as real_labels: 16 | real_labels = json.load(real_labels) 17 | real_labels = { 18 | f'ILSVRC2012_val_{i + 1:08d}.JPEG': labels 19 | for i, labels in enumerate(real_labels) 20 | } 21 | self.real_labels = real_labels 22 | self.filenames = filenames 23 | assert len(self.filenames) == len(self.real_labels) 24 | self.topk = topk 25 | self.is_correct = {k: [] for k in topk} 26 | self.sample_idx = 0 27 | 28 | def add_result(self, output): 29 | maxk = max(self.topk) 30 | _, pred_batch = output.topk(maxk, 1, True, True) 31 | pred_batch = pred_batch.cpu().numpy() 32 | for pred in pred_batch: 33 | filename = self.filenames[self.sample_idx] 34 | filename = os.path.basename(filename) 35 | if self.real_labels[filename]: 36 | for k in self.topk: 37 | self.is_correct[k].append( 38 | any([p in self.real_labels[filename] for p in pred[:k]])) 39 | self.sample_idx += 1 40 | 41 | def get_accuracy(self, k=None): 42 | if k is None: 43 | return {k: float(np.mean(self.is_correct[k])) * 100 for k in self.topk} 44 | else: 45 | return float(np.mean(self.is_correct[k])) * 100 46 | -------------------------------------------------------------------------------- /data/language/__init__.py: -------------------------------------------------------------------------------- 1 | from .prompt_engineering import prompt_engineering-------------------------------------------------------------------------------- /data/language/prompt_engineering.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | 4 | 5 | def get_prompt_templates(): 6 | prompt_templates = [ 7 | '{}.', 8 | 'a photo of a {}.', 9 | 'a bad photo of a {}.', 10 | 'a photo of many {}.', 11 | 'a sculpture of a {}.', 12 | 'a photo of the hard to see {}.', 13 | 'a low resolution photo of the {}.', 14 | 'a rendering of a {}.', 15 | 'graffiti of a {}.', 16 | 'a bad photo of the {}.', 17 | 'a cropped photo of the {}.', 18 | 'a tattoo of a {}.', 19 | 'the embroidered {}.', 20 | 'a photo of a hard to see {}.', 21 | 'a bright photo of a {}.', 22 | 'a photo of a clean {}.', 23 | 'a photo of a dirty {}.', 24 | 'a dark photo of the {}.', 25 | 'a drawing of a {}.', 26 | 'a photo of my {}.', 27 | 'the plastic {}.', 28 | 'a photo of the cool {}.', 29 | 'a close-up photo of a {}.', 30 | 'a black and white photo of the {}.', 31 | 'a painting of the {}.', 32 | 'a painting of a {}.', 33 | 'a pixelated photo of the {}.', 34 | 'a sculpture of the {}.', 35 | 'a bright photo of the {}.', 36 | 'a cropped photo of a {}.', 37 | 'a plastic {}.', 38 | 'a photo of the dirty {}.', 39 | 'a jpeg 
corrupted photo of a {}.', 40 | 'a blurry photo of the {}.', 41 | 'a photo of the {}.', 42 | 'a good photo of the {}.', 43 | 'a rendering of the {}.', 44 | 'a {} in a video game.', 45 | 'a photo of one {}.', 46 | 'a doodle of a {}.', 47 | 'a close-up photo of the {}.', 48 | 'the origami {}.', 49 | 'the {} in a video game.', 50 | 'a sketch of a {}.', 51 | 'a doodle of the {}.', 52 | 'a origami {}.', 53 | 'a low resolution photo of a {}.', 54 | 'the toy {}.', 55 | 'a rendition of the {}.', 56 | 'a photo of the clean {}.', 57 | 'a photo of a large {}.', 58 | 'a rendition of a {}.', 59 | 'a photo of a nice {}.', 60 | 'a photo of a weird {}.', 61 | 'a blurry photo of a {}.', 62 | 'a cartoon {}.', 63 | 'art of a {}.', 64 | 'a sketch of the {}.', 65 | 'a embroidered {}.', 66 | 'a pixelated photo of a {}.', 67 | 'itap of the {}.', 68 | 'a jpeg corrupted photo of the {}.', 69 | 'a good photo of a {}.', 70 | 'a plushie {}.', 71 | 'a photo of the nice {}.', 72 | 'a photo of the small {}.', 73 | 'a photo of the weird {}.', 74 | 'the cartoon {}.', 75 | 'art of the {}.', 76 | 'a drawing of the {}.', 77 | 'a photo of the large {}.', 78 | 'a black and white photo of a {}.', 79 | 'the plushie {}.', 80 | 'a dark photo of a {}.', 81 | 'itap of a {}.', 82 | 'graffiti of the {}.', 83 | 'a toy {}.', 84 | 'itap of my {}.', 85 | 'a photo of a cool {}.', 86 | 'a photo of a small {}.', 87 | 'a tattoo of the {}.', 88 | ] 89 | return prompt_templates 90 | 91 | 92 | def prompt_engineering(classnames): 93 | prompt_templates = get_prompt_templates() 94 | temp_idx = np.random.randint(len(prompt_templates)) 95 | 96 | if isinstance(classnames, list): 97 | classname = random.choice(classnames) 98 | else: 99 | classname = classnames 100 | 101 | return prompt_templates[temp_idx].format(classname.replace(',', '').replace('+', ' ')) 102 | -------------------------------------------------------------------------------- /data/samplers.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | import torch 9 | 10 | 11 | class SubsetRandomSampler(torch.utils.data.Sampler): 12 | r"""Samples elements randomly from a given list of indices, without replacement. 
13 | 14 | Arguments: 15 | indices (sequence): a sequence of indices 16 | """ 17 | 18 | def __init__(self, indices): 19 | self.epoch = 0 20 | self.indices = indices 21 | 22 | def __iter__(self): 23 | return (self.indices[i] for i in torch.randperm(len(self.indices))) 24 | 25 | def __len__(self): 26 | return len(self.indices) 27 | 28 | def set_epoch(self, epoch): 29 | self.epoch = epoch 30 | -------------------------------------------------------------------------------- /data/zipdata.py: -------------------------------------------------------------------------------- 1 | import os.path as op 2 | from zipfile import ZipFile, BadZipFile 3 | import torch.utils.data as data 4 | from PIL import Image 5 | from io import BytesIO 6 | import multiprocessing 7 | from .imagenet import IMAGENET_CLASSES 8 | from .language import prompt_engineering 9 | 10 | _VALID_IMAGE_TYPES = ['.jpg', '.jpeg', '.tiff', '.bmp', '.png'] 11 | 12 | class ZipData(data.Dataset): 13 | _IGNORE_ATTRS = {'_zip_file'} 14 | 15 | def __init__(self, path, map_file, tokenizer, 16 | transform=None, target_transform=None, 17 | extensions=None): 18 | self._path = path 19 | if not extensions: 20 | extensions = _VALID_IMAGE_TYPES 21 | self._zip_file = ZipFile(path) 22 | self.zip_dict = {} 23 | self.samples = [] 24 | assert tokenizer is not None 25 | self.tokenizer = tokenizer 26 | self.transform = transform 27 | self.target_transform = target_transform 28 | 29 | self.class_to_idx = {} 30 | with open(map_file, 'r') as f: 31 | for line in iter(f.readline, ""): 32 | line = line.strip() 33 | if not line: 34 | continue 35 | cls_idx = [l for l in line.split('\t') if l] 36 | if not cls_idx: 37 | continue 38 | assert len(cls_idx) >= 2, "invalid line: {}".format(line) 39 | idx = int(cls_idx[1]) 40 | cls = cls_idx[0] 41 | del cls_idx 42 | at_idx = cls.find('@') 43 | assert at_idx >= 0, "invalid class: {}".format(cls) 44 | cls = cls[at_idx + 1:] 45 | if cls.startswith('/'): 46 | # Python ZipFile expects no root 47 | cls = cls[1:] 48 | assert cls, "invalid class in line {}".format(line) 49 | prev_idx = self.class_to_idx.get(cls) 50 | assert prev_idx is None or prev_idx == idx, "class: {} idx: {} previously had idx: {}".format( 51 | cls, idx, prev_idx 52 | ) 53 | self.class_to_idx[cls] = idx 54 | 55 | for fst in self._zip_file.infolist(): 56 | fname = fst.filename 57 | target = self.class_to_idx.get(fname) 58 | if target is None: 59 | continue 60 | if fname.endswith('/') or fname.startswith('.') or fst.file_size == 0: 61 | continue 62 | ext = op.splitext(fname)[1].lower() 63 | if ext in extensions: 64 | self.samples.append((fname, target)) 65 | assert len(self), "No images found in: {} with map: {}".format(self._path, map_file) 66 | 67 | def __repr__(self): 68 | return 'ZipData({}, size={})'.format(self._path, len(self)) 69 | 70 | def __getstate__(self): 71 | return { 72 | key: val if key not in self._IGNORE_ATTRS else None 73 | for key, val in self.__dict__.items() 74 | } 75 | 76 | def __getitem__(self, index): 77 | proc = multiprocessing.current_process() 78 | pid = proc.pid # get pid of this process. 
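        # Note on the per-process cache used just below: a zipfile.ZipFile wraps an
        # open file descriptor and is not safe to share across the worker processes a
        # PyTorch DataLoader creates, so each worker (keyed by its pid) lazily opens
        # its own handle on first access instead of reusing self._zip_file from the
        # parent process (which __getstate__ above deliberately drops when pickling).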
79 | if pid not in self.zip_dict: 80 | self.zip_dict[pid] = ZipFile(self._path) 81 | zip_file = self.zip_dict[pid] 82 | 83 | if index >= len(self) or index < 0: 84 | raise KeyError("{} is invalid".format(index)) 85 | path, target = self.samples[index] 86 | try: 87 | sample = Image.open(BytesIO(zip_file.read(path))).convert('RGB') 88 | except BadZipFile: 89 | print("bad zip file") 90 | return None, None 91 | if self.transform is not None: 92 | sample = self.transform(sample) 93 | 94 | # convert label to text with prompt 95 | text = self._decode_text_from_label(target) 96 | tokens = self.tokenizer( 97 | text, padding='max_length', 98 | truncation=True, 99 | max_length=77, 100 | return_tensors='pt' 101 | ) if self.tokenizer else text 102 | tokens['input_ids'].squeeze_() 103 | tokens['attention_mask'].squeeze_() 104 | 105 | if self.target_transform is not None: 106 | target = self.target_transform(target) 107 | return sample, tokens, target 108 | 109 | def __len__(self): 110 | return len(self.samples) 111 | 112 | def _decode_text_from_label(self, label): 113 | concept = IMAGENET_CLASSES[int(label)] 114 | text = prompt_engineering(concept) 115 | return text 116 | 117 | -------------------------------------------------------------------------------- /data/zipreader.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | import os 9 | import zipfile 10 | import io 11 | import numpy as np 12 | from PIL import Image 13 | from PIL import ImageFile 14 | 15 | ImageFile.LOAD_TRUNCATED_IMAGES = True 16 | 17 | 18 | def is_zip_path(img_or_path): 19 | """judge if this is a zip path""" 20 | return '.zip@' in img_or_path 21 | 22 | 23 | class ZipReader(object): 24 | """A class to read zipped files""" 25 | zip_bank = dict() 26 | 27 | def __init__(self): 28 | super(ZipReader, self).__init__() 29 | 30 | @staticmethod 31 | def get_zipfile(path): 32 | zip_bank = ZipReader.zip_bank 33 | if path not in zip_bank: 34 | zfile = zipfile.ZipFile(path, 'r') 35 | zip_bank[path] = zfile 36 | return zip_bank[path] 37 | 38 | @staticmethod 39 | def split_zip_style_path(path): 40 | pos_at = path.index('@') 41 | assert pos_at != -1, "character '@' is not found from the given path '%s'" % path 42 | 43 | zip_path = path[0: pos_at] 44 | folder_path = path[pos_at + 1:] 45 | folder_path = str.strip(folder_path, '/') 46 | return zip_path, folder_path 47 | 48 | @staticmethod 49 | def list_folder(path): 50 | zip_path, folder_path = ZipReader.split_zip_style_path(path) 51 | 52 | zfile = ZipReader.get_zipfile(zip_path) 53 | folder_list = [] 54 | for file_foler_name in zfile.namelist(): 55 | file_foler_name = str.strip(file_foler_name, '/') 56 | if file_foler_name.startswith(folder_path) and \ 57 | len(os.path.splitext(file_foler_name)[-1]) == 0 and \ 58 | file_foler_name != folder_path: 59 | if len(folder_path) == 0: 60 | folder_list.append(file_foler_name) 61 | else: 62 | folder_list.append(file_foler_name[len(folder_path) + 1:]) 63 | 64 | return folder_list 65 | 66 | @staticmethod 67 | def list_files(path, extension=None): 68 | if extension is None: 69 | extension = ['.*'] 70 | zip_path, folder_path = ZipReader.split_zip_style_path(path) 71 | 72 | zfile = ZipReader.get_zipfile(zip_path) 73 | file_lists = [] 74 | for file_foler_name in 
zfile.namelist(): 75 | file_foler_name = str.strip(file_foler_name, '/') 76 | if file_foler_name.startswith(folder_path) and \ 77 | str.lower(os.path.splitext(file_foler_name)[-1]) in extension: 78 | if len(folder_path) == 0: 79 | file_lists.append(file_foler_name) 80 | else: 81 | file_lists.append(file_foler_name[len(folder_path) + 1:]) 82 | 83 | return file_lists 84 | 85 | @staticmethod 86 | def read(path): 87 | zip_path, path_img = ZipReader.split_zip_style_path(path) 88 | zfile = ZipReader.get_zipfile(zip_path) 89 | data = zfile.read(path_img) 90 | return data 91 | 92 | @staticmethod 93 | def imread(path): 94 | zip_path, path_img = ZipReader.split_zip_style_path(path) 95 | zfile = ZipReader.get_zipfile(zip_path) 96 | data = zfile.read(path_img) 97 | try: 98 | im = Image.open(io.BytesIO(data)) 99 | except: 100 | print("ERROR IMG LOADED: ", path_img) 101 | random_img = np.random.rand(224, 224, 3) * 255 102 | im = Image.fromarray(np.uint8(random_img)) 103 | return im 104 | -------------------------------------------------------------------------------- /figures/unicl_comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/UniCL/9983bf89660f16dc923530a14d56a83b4699dcbc/figures/unicl_comparison.png -------------------------------------------------------------------------------- /figures/unified_cv.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/UniCL/9983bf89660f16dc923530a14d56a83b4699dcbc/figures/unified_cv.png -------------------------------------------------------------------------------- /logger.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | import os 9 | import sys 10 | import logging 11 | import functools 12 | from termcolor import colored 13 | 14 | 15 | @functools.lru_cache() 16 | def create_logger(output_dir, dist_rank=0, name=''): 17 | # create logger 18 | logger = logging.getLogger(name) 19 | logger.setLevel(logging.DEBUG) 20 | logger.propagate = False 21 | 22 | # create formatter 23 | fmt = '[%(asctime)s %(name)s] (%(filename)s %(lineno)d): %(levelname)s %(message)s' 24 | color_fmt = colored('[%(asctime)s %(name)s]', 'green') + \ 25 | colored('(%(filename)s %(lineno)d)', 'yellow') + ': %(levelname)s %(message)s' 26 | 27 | # create console handlers for master process 28 | if dist_rank == 0: 29 | console_handler = logging.StreamHandler(sys.stdout) 30 | console_handler.setLevel(logging.DEBUG) 31 | console_handler.setFormatter( 32 | logging.Formatter(fmt=color_fmt, datefmt='%Y-%m-%d %H:%M:%S')) 33 | logger.addHandler(console_handler) 34 | 35 | # create file handlers 36 | file_handler = logging.FileHandler(os.path.join(output_dir, f'log_rank{dist_rank}.txt'), mode='a') 37 | file_handler.setLevel(logging.DEBUG) 38 | file_handler.setFormatter(logging.Formatter(fmt=fmt, datefmt='%Y-%m-%d %H:%M:%S')) 39 | logger.addHandler(file_handler) 40 | 41 | return logger 42 | -------------------------------------------------------------------------------- /lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # 
Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | import torch 9 | from timm.scheduler.cosine_lr import CosineLRScheduler 10 | from timm.scheduler.step_lr import StepLRScheduler 11 | from timm.scheduler.scheduler import Scheduler 12 | 13 | 14 | def build_scheduler(config, optimizer, n_iter_per_epoch): 15 | num_steps = int(config.TRAIN.EPOCHS * n_iter_per_epoch) 16 | warmup_steps = int(config.TRAIN.WARMUP_EPOCHS * n_iter_per_epoch) 17 | decay_steps = int(config.TRAIN.LR_SCHEDULER.DECAY_EPOCHS * n_iter_per_epoch) 18 | 19 | lr_scheduler = None 20 | if config.TRAIN.LR_SCHEDULER.NAME == 'cosine': 21 | lr_scheduler = CosineLRScheduler( 22 | optimizer, 23 | t_initial=num_steps, 24 | t_mul=1., 25 | lr_min=config.TRAIN.MIN_LR, 26 | warmup_lr_init=config.TRAIN.WARMUP_LR, 27 | warmup_t=warmup_steps, 28 | cycle_limit=1, 29 | t_in_epochs=False, 30 | ) 31 | elif config.TRAIN.LR_SCHEDULER.NAME == 'linear': 32 | lr_scheduler = LinearLRScheduler( 33 | optimizer, 34 | t_initial=num_steps, 35 | lr_min_rate=0.01, 36 | warmup_lr_init=config.TRAIN.WARMUP_LR, 37 | warmup_t=warmup_steps, 38 | t_in_epochs=False, 39 | ) 40 | elif config.TRAIN.LR_SCHEDULER.NAME == 'step': 41 | lr_scheduler = StepLRScheduler( 42 | optimizer, 43 | decay_t=decay_steps, 44 | decay_rate=config.TRAIN.LR_SCHEDULER.DECAY_RATE, 45 | warmup_lr_init=config.TRAIN.WARMUP_LR, 46 | warmup_t=warmup_steps, 47 | t_in_epochs=False, 48 | ) 49 | 50 | return lr_scheduler 51 | 52 | 53 | class LinearLRScheduler(Scheduler): 54 | def __init__(self, 55 | optimizer: torch.optim.Optimizer, 56 | t_initial: int, 57 | lr_min_rate: float, 58 | warmup_t=0, 59 | warmup_lr_init=0., 60 | t_in_epochs=True, 61 | noise_range_t=None, 62 | noise_pct=0.67, 63 | noise_std=1.0, 64 | noise_seed=42, 65 | initialize=True, 66 | ) -> None: 67 | super().__init__( 68 | optimizer, param_group_field="lr", 69 | noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed, 70 | initialize=initialize) 71 | 72 | self.t_initial = t_initial 73 | self.lr_min_rate = lr_min_rate 74 | self.warmup_t = warmup_t 75 | self.warmup_lr_init = warmup_lr_init 76 | self.t_in_epochs = t_in_epochs 77 | if self.warmup_t: 78 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] 79 | super().update_groups(self.warmup_lr_init) 80 | else: 81 | self.warmup_steps = [1 for _ in self.base_values] 82 | 83 | def _get_lr(self, t): 84 | if t < self.warmup_t: 85 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] 86 | else: 87 | t = t - self.warmup_t 88 | total_t = self.t_initial - self.warmup_t 89 | lrs = [v - ((v - v * self.lr_min_rate) * (t / total_t)) for v in self.base_values] 90 | return lrs 91 | 92 | def get_epoch_values(self, epoch: int): 93 | if self.t_in_epochs: 94 | return self._get_lr(epoch) 95 | else: 96 | return None 97 | 98 | def get_update_values(self, num_updates: int): 99 | if not self.t_in_epochs: 100 | return self._get_lr(num_updates) 101 | else: 102 | return None 103 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Unified Contrastive Learning (UniCL) in Image-Text-Label Space 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | 
# Written by Jianwei Yang (jianwyan@microsoft.com) 6 | # -------------------------------------------------------- 7 | 8 | import os 9 | import time 10 | import argparse 11 | import datetime 12 | import numpy as np 13 | 14 | import torch 15 | import torch.backends.cudnn as cudnn 16 | import torch.distributed as dist 17 | 18 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy 19 | from timm.utils import accuracy, AverageMeter 20 | 21 | from config import get_config 22 | from model import build_model 23 | from data import build_loader 24 | from lr_scheduler import build_scheduler 25 | from optimizer import build_optimizer 26 | from logger import create_logger 27 | from utils import load_checkpoint, save_checkpoint, get_grad_norm, auto_resume_helper, reduce_tensor 28 | 29 | try: 30 | # noinspection PyUnresolvedReferences 31 | from apex import amp 32 | except ImportError: 33 | amp = None 34 | 35 | 36 | def parse_option(): 37 | parser = argparse.ArgumentParser('UniCL training and evaluation script', add_help=False) 38 | parser.add_argument('--cfg', type=str, required=True, metavar="FILE", help='path to config file', ) 39 | parser.add_argument( 40 | "--opts", 41 | help="Modify config options by adding 'KEY VALUE' pairs. ", 42 | default=None, 43 | nargs='+', 44 | ) 45 | 46 | # easy config modification 47 | parser.add_argument('--batch-size', type=int, help="batch size for single GPU") 48 | parser.add_argument('--dataset', type=str, default='imagenet', help='dataset name') 49 | parser.add_argument('--data-path', type=str, help='path to dataset') 50 | parser.add_argument('--zip', action='store_true', help='use zipped dataset instead of folder dataset') 51 | parser.add_argument('--cache-mode', type=str, default='part', choices=['no', 'full', 'part'], 52 | help='no: no cache, ' 53 | 'full: cache all data, ' 54 | 'part: sharding the dataset into nonoverlapping pieces and only cache one piece') 55 | parser.add_argument('--resume', help='resume from checkpoint') 56 | parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps") 57 | parser.add_argument('--use-checkpoint', action='store_true', 58 | help="whether to use gradient checkpointing to save memory") 59 | parser.add_argument('--amp-opt-level', type=str, default='O1', choices=['O0', 'O1', 'O2'], 60 | help='mixed precision opt level, if O0, no amp is used') 61 | parser.add_argument('--output', default='output', type=str, metavar='PATH', 62 | help='root of output folder, the full path is // (default: output)') 63 | parser.add_argument('--tag', help='tag of experiment') 64 | parser.add_argument('--eval', action='store_true', help='Perform evaluation only') 65 | parser.add_argument('--throughput', action='store_true', help='Test throughput only') 66 | parser.add_argument('--debug', action='store_true', help='Perform debug only') 67 | 68 | # distributed training 69 | parser.add_argument("--local_rank", type=int, required=True, help='local rank for DistributedDataParallel') 70 | 71 | args, unparsed = parser.parse_known_args() 72 | 73 | config = get_config(args) 74 | 75 | return args, config 76 | 77 | 78 | def main(config): 79 | 80 | if not config.DEBUG_MODE: 81 | dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn = build_loader(config) 82 | 83 | logger.info(f"Creating model:{config.MODEL.NAME}") 84 | 85 | model = build_model(config) 86 | model.cuda() 87 | logger.info(str(model)) 88 | 89 | optimizer = build_optimizer(config, model) 90 | if config.AMP_OPT_LEVEL != "O0": 91 | model, 
optimizer = amp.initialize(model, optimizer, opt_level=config.AMP_OPT_LEVEL) 92 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config.LOCAL_RANK], broadcast_buffers=False) 93 | model_without_ddp = model.module 94 | 95 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 96 | logger.info(f"number of params: {n_parameters}") 97 | if hasattr(model_without_ddp, 'flops'): 98 | flops = model_without_ddp.flops() 99 | logger.info(f"number of GFLOPs: {flops / 1e9}") 100 | 101 | lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train)) 102 | 103 | if config.AUG.MIXUP > 0.: 104 | # smoothing is handled with mixup label transform 105 | criterion = SoftTargetCrossEntropy() 106 | elif config.MODEL.LABEL_SMOOTHING > 0.: 107 | criterion = LabelSmoothingCrossEntropy(smoothing=config.MODEL.LABEL_SMOOTHING) 108 | else: 109 | criterion = torch.nn.CrossEntropyLoss() 110 | 111 | max_accuracy = 0.0 112 | 113 | if config.TRAIN.AUTO_RESUME: 114 | resume_file = auto_resume_helper(config.OUTPUT) 115 | if resume_file: 116 | if config.MODEL.RESUME: 117 | logger.warning(f"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}") 118 | config.defrost() 119 | config.MODEL.RESUME = resume_file 120 | config.freeze() 121 | logger.info(f'auto resuming from {resume_file}') 122 | else: 123 | logger.info(f'no checkpoint found in {config.OUTPUT}, ignoring auto resume') 124 | 125 | if config.MODEL.RESUME: 126 | max_accuracy = load_checkpoint(config, model_without_ddp, optimizer, lr_scheduler, logger) 127 | acc1, acc5, loss = validate(config, data_loader_val, model) 128 | logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%") 129 | if config.EVAL_MODE: 130 | return 131 | 132 | if config.THROUGHPUT_MODE: 133 | throughput(data_loader_val, model, logger) 134 | return 135 | 136 | logger.info("Start training") 137 | start_time = time.time() 138 | for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS): 139 | data_loader_train.sampler.set_epoch(epoch) 140 | 141 | train_one_epoch(config, model, criterion, data_loader_train, optimizer, epoch, mixup_fn, lr_scheduler) 142 | if dist.get_rank() == 0 and (epoch % config.SAVE_FREQ == 0 or epoch == (config.TRAIN.EPOCHS - 1)): 143 | save_checkpoint(config, epoch, model_without_ddp, max_accuracy, optimizer, lr_scheduler, logger) 144 | 145 | acc1, acc5, loss = validate(config, data_loader_val, model) 146 | logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%") 147 | max_accuracy = max(max_accuracy, acc1) 148 | logger.info(f'Max accuracy: {max_accuracy:.2f}%') 149 | 150 | total_time = time.time() - start_time 151 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 152 | logger.info('Training time {}'.format(total_time_str)) 153 | 154 | 155 | def train_one_epoch(config, model, criterion, data_loader, optimizer, epoch, mixup_fn, lr_scheduler): 156 | model.train() 157 | optimizer.zero_grad() 158 | 159 | num_steps = len(data_loader) 160 | batch_time = AverageMeter() 161 | loss_meter = AverageMeter() 162 | norm_meter = AverageMeter() 163 | 164 | start = time.time() 165 | end = time.time() 166 | for idx, (samples, tokens, targets) in enumerate(data_loader): 167 | samples = samples.cuda(non_blocking=True) 168 | targets = targets.cuda(non_blocking=True) 169 | tokens = {key:val.cuda(non_blocking=True) for key,val in tokens.items()} 170 | 171 | if mixup_fn is not None: 172 | samples, targets = mixup_fn(samples, targets) 173 | 
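        # Forward pass below: the model returns one embedding per image, one embedding
        # per text prompt, and a learned temperature tau; scaling the image-text
        # similarity matrix (features_image @ features_text.t()) by tau produces the
        # logits that the (soft) cross-entropy criterion consumes.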
174 | features_image, features_text, tau = model(samples, tokens) 175 | outputs = tau * features_image @ features_text.t() 176 | 177 | # gather features from all gpus 178 | 179 | if config.TRAIN.ACCUMULATION_STEPS > 1: 180 | loss = criterion(outputs, targets) 181 | loss = loss / config.TRAIN.ACCUMULATION_STEPS 182 | if config.AMP_OPT_LEVEL != "O0": 183 | with amp.scale_loss(loss, optimizer) as scaled_loss: 184 | scaled_loss.backward() 185 | if config.TRAIN.CLIP_GRAD: 186 | grad_norm = torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), config.TRAIN.CLIP_GRAD) 187 | else: 188 | grad_norm = get_grad_norm(amp.master_params(optimizer)) 189 | else: 190 | loss.backward() 191 | if config.TRAIN.CLIP_GRAD: 192 | grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.TRAIN.CLIP_GRAD) 193 | else: 194 | grad_norm = get_grad_norm(model.parameters()) 195 | if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0: 196 | optimizer.step() 197 | optimizer.zero_grad() 198 | lr_scheduler.step_update(epoch * num_steps + idx) 199 | else: 200 | loss = criterion(outputs, targets) 201 | optimizer.zero_grad() 202 | if config.AMP_OPT_LEVEL != "O0": 203 | with amp.scale_loss(loss, optimizer) as scaled_loss: 204 | scaled_loss.backward() 205 | if config.TRAIN.CLIP_GRAD: 206 | grad_norm = torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), config.TRAIN.CLIP_GRAD) 207 | else: 208 | grad_norm = get_grad_norm(amp.master_params(optimizer)) 209 | else: 210 | loss.backward() 211 | if config.TRAIN.CLIP_GRAD: 212 | grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.TRAIN.CLIP_GRAD) 213 | else: 214 | grad_norm = get_grad_norm(model.parameters()) 215 | 216 | optimizer.step() 217 | lr_scheduler.step_update(epoch * num_steps + idx) 218 | 219 | torch.cuda.synchronize() 220 | 221 | loss_meter.update(loss.item(), targets.size(0)) 222 | norm_meter.update(grad_norm) 223 | batch_time.update(time.time() - end) 224 | end = time.time() 225 | 226 | if idx % config.PRINT_FREQ == 0: 227 | lr = optimizer.param_groups[0]['lr'] 228 | memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0) 229 | etas = batch_time.avg * (num_steps - idx) 230 | logger.info( 231 | f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t' 232 | f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t' 233 | f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t' 234 | f'loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t' 235 | f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\t' 236 | f'mem {memory_used:.0f}MB') 237 | epoch_time = time.time() - start 238 | logger.info(f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}") 239 | 240 | 241 | @torch.no_grad() 242 | def validate(config, data_loader, model): 243 | criterion = torch.nn.CrossEntropyLoss() 244 | model.eval() 245 | imnet_text_embeddings = model.module.get_imnet_embeddings() 246 | 247 | batch_time = AverageMeter() 248 | loss_meter = AverageMeter() 249 | acc1_meter = AverageMeter() 250 | acc5_meter = AverageMeter() 251 | 252 | end = time.time() 253 | for idx, (images, tokens, target) in enumerate(data_loader): 254 | 255 | images = images.cuda(non_blocking=True) 256 | target = target.cuda(non_blocking=True) 257 | tokens = {key:val.cuda(non_blocking=True) for key,val in tokens.items()} 258 | 259 | # compute output 260 | feat_img = model.module.encode_image(images) 261 | output = model.module.logit_scale.exp() * feat_img @ imnet_text_embeddings.t() 262 | 263 | # measure accuracy and record loss 264 | loss = 
criterion(output, target) 265 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 266 | 267 | acc1 = reduce_tensor(acc1) 268 | acc5 = reduce_tensor(acc5) 269 | loss = reduce_tensor(loss) 270 | 271 | loss_meter.update(loss.item(), target.size(0)) 272 | acc1_meter.update(acc1.item(), target.size(0)) 273 | acc5_meter.update(acc5.item(), target.size(0)) 274 | 275 | # measure elapsed time 276 | batch_time.update(time.time() - end) 277 | end = time.time() 278 | 279 | if idx % config.PRINT_FREQ == 0: 280 | memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0) 281 | logger.info( 282 | f'Test: [{idx}/{len(data_loader)}]\t' 283 | f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 284 | f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t' 285 | f'Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\t' 286 | f'Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})\t' 287 | f'Mem {memory_used:.0f}MB') 288 | logger.info(f' * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}') 289 | 290 | return acc1_meter.avg, acc5_meter.avg, loss_meter.avg 291 | 292 | 293 | @torch.no_grad() 294 | def throughput(data_loader, model, logger): 295 | model.eval() 296 | 297 | for idx, (images, _) in enumerate(data_loader): 298 | images = images.cuda(non_blocking=True) 299 | batch_size = images.shape[0] 300 | for i in range(50): 301 | model(images) 302 | torch.cuda.synchronize() 303 | logger.info(f"throughput averaged with 30 times") 304 | tic1 = time.time() 305 | for i in range(30): 306 | model(images) 307 | torch.cuda.synchronize() 308 | tic2 = time.time() 309 | logger.info(f"batch_size {batch_size} throughput {30 * batch_size / (tic2 - tic1)}") 310 | return 311 | 312 | 313 | if __name__ == '__main__': 314 | _, config = parse_option() 315 | 316 | if config.AMP_OPT_LEVEL != "O0": 317 | assert amp is not None, "amp not installed!" 
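    # The env-var based setup below expects the RANK / WORLD_SIZE variables and the
    # --local_rank argument that torch.distributed.launch provides; an illustrative
    # single-node launch (GPU count and data path are placeholders, not from the repo):
    #   python -m torch.distributed.launch --nproc_per_node 8 main.py \
    #       --cfg configs/unicl_swin_tiny.yaml --data-path /path/to/imagenet --batch-size 128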
318 | 319 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 320 | rank = int(os.environ["RANK"]) 321 | world_size = int(os.environ['WORLD_SIZE']) 322 | print(f"RANK and WORLD_SIZE in environ: {rank}/{world_size}") 323 | else: 324 | rank = -1 325 | world_size = -1 326 | torch.cuda.set_device(config.LOCAL_RANK) 327 | torch.distributed.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank) 328 | torch.distributed.barrier() 329 | 330 | seed = config.SEED + dist.get_rank() 331 | torch.manual_seed(seed) 332 | np.random.seed(seed) 333 | cudnn.benchmark = True 334 | 335 | # linear scale the learning rate according to total batch size, may not be optimal 336 | linear_scaled_lr = config.TRAIN.BASE_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0 337 | linear_scaled_warmup_lr = config.TRAIN.WARMUP_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0 338 | linear_scaled_min_lr = config.TRAIN.MIN_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0 339 | # gradient accumulation also need to scale the learning rate 340 | if config.TRAIN.ACCUMULATION_STEPS > 1: 341 | linear_scaled_lr = linear_scaled_lr * config.TRAIN.ACCUMULATION_STEPS 342 | linear_scaled_warmup_lr = linear_scaled_warmup_lr * config.TRAIN.ACCUMULATION_STEPS 343 | linear_scaled_min_lr = linear_scaled_min_lr * config.TRAIN.ACCUMULATION_STEPS 344 | config.defrost() 345 | config.TRAIN.BASE_LR = linear_scaled_lr 346 | config.TRAIN.WARMUP_LR = linear_scaled_warmup_lr 347 | config.TRAIN.MIN_LR = linear_scaled_min_lr 348 | config.freeze() 349 | 350 | config.defrost() 351 | config.OUTPUT = os.getenv('PT_OUTPUT_DIR') if os.getenv('PT_OUTPUT_DIR') else config.OUTPUT 352 | config.freeze() 353 | 354 | os.makedirs(config.OUTPUT, exist_ok=True) 355 | logger = create_logger(output_dir=config.OUTPUT, dist_rank=dist.get_rank(), name=f"{config.MODEL.NAME}") 356 | 357 | if dist.get_rank() == 0: 358 | path = os.path.join(config.OUTPUT, "config.json") 359 | with open(path, "w") as f: 360 | f.write(config.dump()) 361 | logger.info(f"Full config saved to {path}") 362 | 363 | # print config 364 | logger.info(config.dump()) 365 | 366 | main(config) 367 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import build_unicl_model as build_model -------------------------------------------------------------------------------- /model/image_encoder/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import build_model as build_image_encoder -------------------------------------------------------------------------------- /model/image_encoder/build.py: -------------------------------------------------------------------------------- 1 | from timm.models import create_model 2 | from .swin_transformer import SwinTransformer 3 | 4 | def build_model(config): 5 | model_type = config.TYPE 6 | print(f"Creating model: {model_type}") 7 | 8 | if "swin" in model_type: 9 | model = SwinTransformer( 10 | num_classes=0, 11 | img_size=config.IMG_SIZE, 12 | patch_size=config.SWIN.PATCH_SIZE, 13 | in_chans=config.SWIN.IN_CHANS, 14 | embed_dim=config.SWIN.EMBED_DIM, 15 | depths=config.SWIN.DEPTHS, 16 | num_heads=config.SWIN.NUM_HEADS, 17 | window_size=config.SWIN.WINDOW_SIZE, 18 | mlp_ratio=config.SWIN.MLP_RATIO, 19 | qkv_bias=config.SWIN.QKV_BIAS, 20 | qk_scale=config.SWIN.QK_SCALE, 21 | drop_rate=config.DROP_RATE, 22 | 
drop_path_rate=config.DROP_PATH_RATE, 23 | ape=config.SWIN.APE, 24 | patch_norm=config.SWIN.PATCH_NORM, 25 | use_checkpoint=False 26 | ) 27 | elif "vit" in model_type: 28 | model = create_model( 29 | model_type, 30 | pretrained=is_pretrained, 31 | img_size=config.DATA.IMG_SIZE, 32 | num_classes=config.MODEL.NUM_CLASSES, 33 | ) 34 | elif "resnet" in model_type: 35 | model = create_model( 36 | model_type, 37 | pretrained=is_pretrained, 38 | num_classes=config.MODEL.NUM_CLASSES 39 | ) 40 | else: 41 | model = create_model( 42 | model_type, 43 | pretrained=is_pretrained, 44 | num_classes=config.MODEL.NUM_CLASSES 45 | ) 46 | return model 47 | -------------------------------------------------------------------------------- /model/image_encoder/focalnet.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # FocalNets -- Focal Modulation Networks 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Jianwei Yang (jianwyan@microsoft.com) 6 | # -------------------------------------------------------- 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import torch.utils.checkpoint as checkpoint 12 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 13 | from timm.models.registry import register_model 14 | 15 | from torchvision import transforms 16 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 17 | from timm.data import create_transform 18 | from timm.data.transforms import _pil_interp 19 | 20 | class Mlp(nn.Module): 21 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 22 | super().__init__() 23 | out_features = out_features or in_features 24 | hidden_features = hidden_features or in_features 25 | self.fc1 = nn.Linear(in_features, hidden_features) 26 | self.act = act_layer() 27 | self.fc2 = nn.Linear(hidden_features, out_features) 28 | self.drop = nn.Dropout(drop) 29 | 30 | def forward(self, x): 31 | x = self.fc1(x) 32 | x = self.act(x) 33 | x = self.drop(x) 34 | x = self.fc2(x) 35 | x = self.drop(x) 36 | return x 37 | 38 | class FocalModulation(nn.Module): 39 | def __init__(self, dim, focal_window, focal_level, focal_factor=2, bias=True, proj_drop=0., use_postln=False): 40 | super().__init__() 41 | 42 | self.dim = dim 43 | self.focal_window = focal_window 44 | self.focal_level = focal_level 45 | self.focal_factor = focal_factor 46 | self.use_postln = use_postln 47 | 48 | self.f = nn.Linear(dim, 2*dim + (self.focal_level+1), bias=bias) 49 | self.h = nn.Conv2d(dim, dim, kernel_size=1, stride=1, bias=bias) 50 | 51 | self.act = nn.GELU() 52 | self.proj = nn.Linear(dim, dim) 53 | self.proj_drop = nn.Dropout(proj_drop) 54 | self.focal_layers = nn.ModuleList() 55 | 56 | self.kernel_sizes = [] 57 | for k in range(self.focal_level): 58 | kernel_size = self.focal_factor*k + self.focal_window 59 | self.focal_layers.append( 60 | nn.Sequential( 61 | nn.Conv2d(dim, dim, kernel_size=kernel_size, stride=1, 62 | groups=dim, padding=kernel_size//2, bias=False), 63 | nn.GELU(), 64 | ) 65 | ) 66 | self.kernel_sizes.append(kernel_size) 67 | if self.use_postln: 68 | self.ln = nn.LayerNorm(dim) 69 | 70 | def forward(self, x): 71 | """ 72 | Args: 73 | x: input features with shape of (B, H, W, C) 74 | """ 75 | C = x.shape[-1] 76 | 77 | # pre linear projection 78 | x = self.f(x).permute(0, 3, 1, 2).contiguous() 79 | q, ctx, 
self.gates = torch.split(x, (C, C, self.focal_level+1), 1) 80 | 81 | # context aggreation 82 | ctx_all = 0 83 | for l in range(self.focal_level): 84 | ctx = self.focal_layers[l](ctx) 85 | ctx_all = ctx_all + ctx*self.gates[:, l:l+1] 86 | ctx_global = self.act(ctx.mean(2, keepdim=True).mean(3, keepdim=True)) 87 | ctx_all = ctx_all + ctx_global*self.gates[:,self.focal_level:] 88 | 89 | # focal modulation 90 | self.modulator = self.h(ctx_all) 91 | x_out = q*self.modulator 92 | x_out = x_out.permute(0, 2, 3, 1).contiguous() 93 | if self.use_postln: 94 | x_out = self.ln(x_out) 95 | 96 | # post linear porjection 97 | x_out = self.proj(x_out) 98 | x_out = self.proj_drop(x_out) 99 | return x_out 100 | 101 | def extra_repr(self) -> str: 102 | return f'dim={self.dim}' 103 | 104 | def flops(self, N): 105 | # calculate flops for 1 window with token length of N 106 | flops = 0 107 | 108 | flops += N * self.dim * (self.dim * 2 + (self.focal_level+1)) 109 | 110 | # focal convolution 111 | for k in range(self.focal_level): 112 | flops += N * (self.kernel_sizes[k]**2+1) * self.dim 113 | 114 | # global gating 115 | flops += N * 1 * self.dim 116 | 117 | # self.linear 118 | flops += N * self.dim * (self.dim + 1) 119 | 120 | # x = self.proj(x) 121 | flops += N * self.dim * self.dim 122 | return flops 123 | 124 | class FocalNetBlock(nn.Module): 125 | r""" Focal Modulation Network Block. 126 | 127 | Args: 128 | dim (int): Number of input channels. 129 | input_resolution (tuple[int]): Input resulotion. 130 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 131 | drop (float, optional): Dropout rate. Default: 0.0 132 | drop_path (float, optional): Stochastic depth rate. Default: 0.0 133 | act_layer (nn.Module, optional): Activation layer. Default: nn.GELU 134 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 135 | focal_level (int): Number of focal levels. 136 | focal_window (int): Focal window size at first focal level 137 | use_layerscale (bool): Whether use layerscale 138 | layerscale_value (float): Initial layerscale value 139 | use_postln (bool): Whether use layernorm after modulation 140 | """ 141 | 142 | def __init__(self, dim, input_resolution, mlp_ratio=4., drop=0., drop_path=0., 143 | act_layer=nn.GELU, norm_layer=nn.LayerNorm, 144 | focal_level=1, focal_window=3, 145 | use_layerscale=False, layerscale_value=1e-4, 146 | use_postln=False): 147 | super().__init__() 148 | self.dim = dim 149 | self.input_resolution = input_resolution 150 | self.mlp_ratio = mlp_ratio 151 | 152 | self.focal_window = focal_window 153 | self.focal_level = focal_level 154 | 155 | self.norm1 = norm_layer(dim) 156 | self.modulation = FocalModulation(dim, proj_drop=drop, focal_window=focal_window, focal_level=self.focal_level, use_postln=use_postln) 157 | 158 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 159 | self.norm2 = norm_layer(dim) 160 | mlp_hidden_dim = int(dim * mlp_ratio) 161 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 162 | 163 | self.gamma_1 = 1.0 164 | self.gamma_2 = 1.0 165 | if use_layerscale: 166 | self.gamma_1 = nn.Parameter(layerscale_value * torch.ones((dim)), requires_grad=True) 167 | self.gamma_2 = nn.Parameter(layerscale_value * torch.ones((dim)), requires_grad=True) 168 | 169 | self.H = None 170 | self.W = None 171 | 172 | def forward(self, x): 173 | H, W = self.H, self.W 174 | B, L, C = x.shape 175 | shortcut = x 176 | 177 | # Focal Modulation 178 | x = self.norm1(x) 179 | x = x.view(B, H, W, C) 180 | x = self.modulation(x).view(B, H * W, C) 181 | 182 | # FFN 183 | x = shortcut + self.drop_path(self.gamma_1 * x) 184 | x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) 185 | 186 | return x 187 | 188 | def extra_repr(self) -> str: 189 | return f"dim={self.dim}, input_resolution={self.input_resolution}, " \ 190 | f"mlp_ratio={self.mlp_ratio}" 191 | 192 | def flops(self): 193 | flops = 0 194 | H, W = self.input_resolution 195 | # norm1 196 | flops += self.dim * H * W 197 | 198 | # W-MSA/SW-MSA 199 | flops += self.modulation.flops(H*W) 200 | 201 | # mlp 202 | flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio 203 | # norm2 204 | flops += self.dim * H * W 205 | return flops 206 | 207 | class BasicLayer(nn.Module): 208 | """ A basic Focal Transformer layer for one stage. 209 | 210 | Args: 211 | dim (int): Number of input channels. 212 | input_resolution (tuple[int]): Input resolution. 213 | depth (int): Number of blocks. 214 | window_size (int): Local window size. 215 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 216 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 217 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 218 | drop (float, optional): Dropout rate. Default: 0.0 219 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 220 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 221 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None 222 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
223 | focal_level (int): Number of focal levels 224 | focal_window (int): Focal window size at first focal level 225 | use_layerscale (bool): Whether use layerscale 226 | layerscale_value (float): Initial layerscale value 227 | use_postln (bool): Whether use layernorm after modulation 228 | """ 229 | 230 | def __init__(self, dim, out_dim, input_resolution, depth, 231 | mlp_ratio=4., drop=0., drop_path=0., norm_layer=nn.LayerNorm, 232 | downsample=None, use_checkpoint=False, 233 | focal_level=1, focal_window=1, 234 | use_conv_embed=False, 235 | use_layerscale=False, layerscale_value=1e-4, use_postln=False): 236 | 237 | super().__init__() 238 | self.dim = dim 239 | self.input_resolution = input_resolution 240 | self.depth = depth 241 | self.use_checkpoint = use_checkpoint 242 | 243 | # build blocks 244 | self.blocks = nn.ModuleList([ 245 | FocalNetBlock( 246 | dim=dim, 247 | input_resolution=input_resolution, 248 | mlp_ratio=mlp_ratio, 249 | drop=drop, 250 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, 251 | norm_layer=norm_layer, 252 | focal_level=focal_level, 253 | focal_window=focal_window, 254 | use_layerscale=use_layerscale, 255 | layerscale_value=layerscale_value, 256 | use_postln=use_postln, 257 | ) 258 | for i in range(depth)]) 259 | 260 | if downsample is not None: 261 | self.downsample = downsample( 262 | img_size=input_resolution, 263 | patch_size=2, 264 | in_chans=dim, 265 | embed_dim=out_dim, 266 | use_conv_embed=use_conv_embed, 267 | norm_layer=norm_layer, 268 | is_stem=False 269 | ) 270 | else: 271 | self.downsample = None 272 | 273 | def forward(self, x, H, W): 274 | for blk in self.blocks: 275 | blk.H, blk.W = H, W 276 | if self.use_checkpoint: 277 | x = checkpoint.checkpoint(blk, x) 278 | else: 279 | x = blk(x) 280 | 281 | if self.downsample is not None: 282 | x = x.transpose(1, 2).reshape(x.shape[0], -1, H, W) 283 | x, Ho, Wo = self.downsample(x) 284 | else: 285 | Ho, Wo = H, W 286 | return x, Ho, Wo 287 | 288 | def extra_repr(self) -> str: 289 | return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" 290 | 291 | def flops(self): 292 | flops = 0 293 | for blk in self.blocks: 294 | flops += blk.flops() 295 | if self.downsample is not None: 296 | flops += self.downsample.flops() 297 | return flops 298 | 299 | class PatchEmbed(nn.Module): 300 | r""" Image to Patch Embedding 301 | 302 | Args: 303 | img_size (int): Image size. Default: 224. 304 | patch_size (int): Patch token size. Default: 4. 305 | in_chans (int): Number of input image channels. Default: 3. 306 | embed_dim (int): Number of linear projection output channels. Default: 96. 307 | norm_layer (nn.Module, optional): Normalization layer. 
Default: None 308 | """ 309 | 310 | def __init__(self, img_size=(224, 224), patch_size=4, in_chans=3, embed_dim=96, use_conv_embed=False, norm_layer=None, is_stem=False): 311 | super().__init__() 312 | patch_size = to_2tuple(patch_size) 313 | patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] 314 | self.img_size = img_size 315 | self.patch_size = patch_size 316 | self.patches_resolution = patches_resolution 317 | self.num_patches = patches_resolution[0] * patches_resolution[1] 318 | 319 | self.in_chans = in_chans 320 | self.embed_dim = embed_dim 321 | 322 | if use_conv_embed: 323 | # if we choose to use conv embedding, then we treat the stem and non-stem differently 324 | if is_stem: 325 | kernel_size = 7; padding = 2; stride = 4 326 | else: 327 | kernel_size = 3; padding = 1; stride = 2 328 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding) 329 | else: 330 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 331 | 332 | if norm_layer is not None: 333 | self.norm = norm_layer(embed_dim) 334 | else: 335 | self.norm = None 336 | 337 | def forward(self, x): 338 | B, C, H, W = x.shape 339 | 340 | x = self.proj(x) 341 | H, W = x.shape[2:] 342 | x = x.flatten(2).transpose(1, 2) # B Ph*Pw C 343 | if self.norm is not None: 344 | x = self.norm(x) 345 | return x, H, W 346 | 347 | def flops(self): 348 | Ho, Wo = self.patches_resolution 349 | flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) 350 | if self.norm is not None: 351 | flops += Ho * Wo * self.embed_dim 352 | return flops 353 | 354 | class FocalNet(nn.Module): 355 | r""" Focal Modulation Networks (FocalNets) 356 | 357 | Args: 358 | img_size (int | tuple(int)): Input image size. Default 224 359 | patch_size (int | tuple(int)): Patch size. Default: 4 360 | in_chans (int): Number of input image channels. Default: 3 361 | num_classes (int): Number of classes for classification head. Default: 1000 362 | embed_dim (int): Patch embedding dimension. Default: 96 363 | depths (tuple(int)): Depth of each Focal Transformer layer. 364 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 365 | drop_rate (float): Dropout rate. Default: 0 366 | drop_path_rate (float): Stochastic depth rate. Default: 0.1 367 | norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. 368 | patch_norm (bool): If True, add normalization after patch embedding. Default: True 369 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False 370 | focal_levels (list): How many focal levels at all stages. Note that this excludes the finest-grain level. Default: [1, 1, 1, 1] 371 | focal_windows (list): The focal window size at all stages. Default: [7, 5, 3, 1] 372 | use_conv_embed (bool): Whether use convolutional embedding. We noted that using convolutional embedding usually improve the performance, but we do not use it by default. Default: False 373 | use_layerscale (bool): Whether use layerscale proposed in CaiT. Default: False 374 | layerscale_value (float): Value for layer scale. 
Default: 1e-4 375 | use_postln (bool): Whether use layernorm after modulation (it helps stablize training of large models) 376 | """ 377 | def __init__(self, 378 | img_size=224, 379 | patch_size=4, 380 | in_chans=3, 381 | num_classes=1000, 382 | embed_dim=96, 383 | depths=[2, 2, 6, 2], 384 | mlp_ratio=4., 385 | drop_rate=0., 386 | drop_path_rate=0.1, 387 | norm_layer=nn.LayerNorm, 388 | patch_norm=True, 389 | use_checkpoint=False, 390 | focal_levels=[2, 2, 2, 2], 391 | focal_windows=[3, 3, 3, 3], 392 | use_conv_embed=False, 393 | use_layerscale=False, 394 | layerscale_value=1e-4, 395 | use_postln=False, 396 | **kwargs): 397 | super().__init__() 398 | 399 | self.num_layers = len(depths) 400 | embed_dim = [embed_dim * (2 ** i) for i in range(self.num_layers)] 401 | 402 | self.num_classes = num_classes 403 | self.embed_dim = embed_dim 404 | self.patch_norm = patch_norm 405 | self.num_features = embed_dim[-1] 406 | self.mlp_ratio = mlp_ratio 407 | 408 | # split image into patches using either non-overlapped embedding or overlapped embedding 409 | self.patch_embed = PatchEmbed( 410 | img_size=to_2tuple(img_size), 411 | patch_size=patch_size, 412 | in_chans=in_chans, 413 | embed_dim=embed_dim[0], 414 | use_conv_embed=use_conv_embed, 415 | norm_layer=norm_layer if self.patch_norm else None, 416 | is_stem=True) 417 | 418 | num_patches = self.patch_embed.num_patches 419 | patches_resolution = self.patch_embed.patches_resolution 420 | self.patches_resolution = patches_resolution 421 | self.pos_drop = nn.Dropout(p=drop_rate) 422 | 423 | # stochastic depth 424 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule 425 | 426 | # build layers 427 | self.layers = nn.ModuleList() 428 | for i_layer in range(self.num_layers): 429 | layer = BasicLayer(dim=embed_dim[i_layer], 430 | out_dim=embed_dim[i_layer+1] if (i_layer < self.num_layers - 1) else None, 431 | input_resolution=(patches_resolution[0] // (2 ** i_layer), 432 | patches_resolution[1] // (2 ** i_layer)), 433 | depth=depths[i_layer], 434 | mlp_ratio=self.mlp_ratio, 435 | drop=drop_rate, 436 | drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], 437 | norm_layer=norm_layer, 438 | downsample=PatchEmbed if (i_layer < self.num_layers - 1) else None, 439 | focal_level=focal_levels[i_layer], 440 | focal_window=focal_windows[i_layer], 441 | use_conv_embed=use_conv_embed, 442 | use_checkpoint=use_checkpoint, 443 | use_layerscale=use_layerscale, 444 | layerscale_value=layerscale_value, 445 | use_postln=use_postln, 446 | ) 447 | self.layers.append(layer) 448 | 449 | self.norm = norm_layer(self.num_features) 450 | self.avgpool = nn.AdaptiveAvgPool1d(1) 451 | self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() 452 | self.dim_out = self.num_features 453 | 454 | self.apply(self._init_weights) 455 | 456 | def _init_weights(self, m): 457 | if isinstance(m, nn.Linear): 458 | trunc_normal_(m.weight, std=.02) 459 | if isinstance(m, nn.Linear) and m.bias is not None: 460 | nn.init.constant_(m.bias, 0) 461 | elif isinstance(m, nn.LayerNorm): 462 | nn.init.constant_(m.bias, 0) 463 | nn.init.constant_(m.weight, 1.0) 464 | 465 | @torch.jit.ignore 466 | def no_weight_decay(self): 467 | return {''} 468 | 469 | @torch.jit.ignore 470 | def no_weight_decay_keywords(self): 471 | return {''} 472 | 473 | def forward_features(self, x): 474 | x, H, W = self.patch_embed(x) 475 | x = self.pos_drop(x) 476 | 477 | for layer in self.layers: 478 | x, H, W = layer(x, H, W) 479 | x = 
self.norm(x) # B L C 480 | x = self.avgpool(x.transpose(1, 2)) # B C 1 481 | x = torch.flatten(x, 1) 482 | return x 483 | 484 | def forward(self, x): 485 | x = self.forward_features(x) 486 | x = self.head(x) 487 | return x 488 | 489 | def flops(self): 490 | flops = 0 491 | flops += self.patch_embed.flops() 492 | for i, layer in enumerate(self.layers): 493 | flops += layer.flops() 494 | flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers) 495 | flops += self.num_features * self.num_classes 496 | return flops 497 | 498 | def build_transforms(img_size, center_crop=False): 499 | t = [] 500 | if center_crop: 501 | size = int((256 / 224) * img_size) 502 | t.append( 503 | transforms.Resize(size, interpolation=_pil_interp('bicubic')) 504 | ) 505 | t.append( 506 | transforms.CenterCrop(img_size) 507 | ) 508 | else: 509 | t.append( 510 | transforms.Resize(img_size, interpolation=_pil_interp('bicubic')) 511 | ) 512 | t.append(transforms.ToTensor()) 513 | t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)) 514 | return transforms.Compose(t) 515 | 516 | def build_transforms4display(img_size, center_crop=False): 517 | t = [] 518 | if center_crop: 519 | size = int((256 / 224) * img_size) 520 | t.append( 521 | transforms.Resize(size, interpolation=_pil_interp('bicubic')) 522 | ) 523 | t.append( 524 | transforms.CenterCrop(img_size) 525 | ) 526 | else: 527 | t.append( 528 | transforms.Resize(img_size, interpolation=_pil_interp('bicubic')) 529 | ) 530 | t.append(transforms.ToTensor()) 531 | return transforms.Compose(t) 532 | 533 | model_urls = { 534 | "focalnet_tiny_srf": "", 535 | "focalnet_small_srf": "", 536 | "focalnet_base_srf": "", 537 | "focalnet_tiny_lrf": "", 538 | "focalnet_small_lrf": "", 539 | "focalnet_base_lrf": "", 540 | } 541 | 542 | @register_model 543 | def focalnet_tiny_srf(pretrained=False, **kwargs): 544 | model = FocalNet(depths=[2, 2, 6, 2], embed_dim=96, **kwargs) 545 | if pretrained: 546 | url = model_urls['focalnet_tiny_srf'] 547 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True) 548 | model.load_state_dict(checkpoint["model"]) 549 | return model 550 | 551 | @register_model 552 | def focalnet_small_srf(pretrained=False, **kwargs): 553 | model = FocalNet(depths=[2, 2, 18, 2], embed_dim=96, **kwargs) 554 | if pretrained: 555 | url = model_urls['focalnet_small_srf'] 556 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu") 557 | model.load_state_dict(checkpoint["model"]) 558 | return model 559 | 560 | @register_model 561 | def focalnet_base_srf(pretrained=False, **kwargs): 562 | model = FocalNet(depths=[2, 2, 18, 2], embed_dim=128, **kwargs) 563 | if pretrained: 564 | url = model_urls['focalnet_base_srf'] 565 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu") 566 | model.load_state_dict(checkpoint["model"]) 567 | return model 568 | 569 | @register_model 570 | def focalnet_tiny_lrf(pretrained=False, **kwargs): 571 | model = FocalNet(depths=[2, 2, 6, 2], embed_dim=96, focal_levels=[3, 3, 3, 3], **kwargs) 572 | if pretrained: 573 | url = model_urls['focalnet_tiny_lrf'] 574 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True) 575 | model.load_state_dict(checkpoint["model"]) 576 | return model 577 | 578 | @register_model 579 | def focalnet_small_lrf(pretrained=False, **kwargs): 580 | model = FocalNet(depths=[2, 2, 18, 2], embed_dim=96, focal_levels=[3, 3, 3, 3], 
**kwargs) 581 | if pretrained: 582 | url = model_urls['focalnet_small_lrf'] 583 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu") 584 | model.load_state_dict(checkpoint["model"]) 585 | return model 586 | 587 | @register_model 588 | def focalnet_base_lrf(pretrained=False, **kwargs): 589 | model = FocalNet(depths=[2, 2, 18, 2], embed_dim=128, focal_levels=[3, 3, 3, 3], **kwargs) 590 | if pretrained: 591 | url = model_urls['focalnet_base_lrf'] 592 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu") 593 | model.load_state_dict(checkpoint["model"]) 594 | return model 595 | 596 | @register_model 597 | def focalnet_tiny_iso_16(pretrained=False, **kwargs): 598 | model = FocalNet(depths=[12], patch_size=16, embed_dim=192, focal_levels=[3], focal_windows=[3], **kwargs) 599 | if pretrained: 600 | url = model_urls['focalnet_tiny_iso_16'] 601 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True) 602 | model.load_state_dict(checkpoint["model"]) 603 | return model 604 | 605 | @register_model 606 | def focalnet_small_iso_16(pretrained=False, **kwargs): 607 | model = FocalNet(depths=[12], patch_size=16, embed_dim=384, focal_levels=[3], focal_windows=[3], **kwargs) 608 | if pretrained: 609 | url = model_urls['focalnet_small_iso_16'] 610 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu") 611 | model.load_state_dict(checkpoint["model"]) 612 | return model 613 | 614 | @register_model 615 | def focalnet_base_iso_16(pretrained=False, **kwargs): 616 | model = FocalNet(depths=[12], patch_size=16, embed_dim=768, focal_levels=[3], focal_windows=[3], use_layerscale=True, use_postln=True, **kwargs) 617 | if pretrained: 618 | url = model_urls['focalnet_base_iso_16'] 619 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu") 620 | model.load_state_dict(checkpoint["model"]) 621 | return model 622 | 623 | if __name__ == '__main__': 624 | img_size = 224 625 | x = torch.rand(16, 3, img_size, img_size).cuda() 626 | # model = FocalNet(depths=[2, 2, 6, 2], embed_dim=96) 627 | # model = FocalNet(depths=[12], patch_size=16, embed_dim=768, focal_levels=[3], focal_windows=[3], focal_factors=[2]) 628 | model = FocalNet(depths=[2, 2, 6, 2], embed_dim=96, focal_levels=[3, 3, 3, 3]).cuda() 629 | print(model); model(x) 630 | 631 | flops = model.flops() 632 | print(f"number of GFLOPs: {flops / 1e9}") 633 | 634 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 635 | print(f"number of params: {n_parameters}") 636 | -------------------------------------------------------------------------------- /model/image_encoder/swin_transformer.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.utils.checkpoint as checkpoint 11 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 12 | 13 | 14 | class Mlp(nn.Module): 15 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 16 | super().__init__() 17 | out_features = out_features or in_features 18 | hidden_features = hidden_features or in_features 19 | self.fc1 = nn.Linear(in_features, 
hidden_features) 20 | self.act = act_layer() 21 | self.fc2 = nn.Linear(hidden_features, out_features) 22 | self.drop = nn.Dropout(drop) 23 | 24 | def forward(self, x): 25 | x = self.fc1(x) 26 | x = self.act(x) 27 | x = self.drop(x) 28 | x = self.fc2(x) 29 | x = self.drop(x) 30 | return x 31 | 32 | 33 | def window_partition(x, window_size): 34 | """ 35 | Args: 36 | x: (B, H, W, C) 37 | window_size (int): window size 38 | 39 | Returns: 40 | windows: (num_windows*B, window_size, window_size, C) 41 | """ 42 | B, H, W, C = x.shape 43 | x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) 44 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 45 | return windows 46 | 47 | 48 | def window_reverse(windows, window_size, H, W): 49 | """ 50 | Args: 51 | windows: (num_windows*B, window_size, window_size, C) 52 | window_size (int): Window size 53 | H (int): Height of image 54 | W (int): Width of image 55 | 56 | Returns: 57 | x: (B, H, W, C) 58 | """ 59 | B = int(windows.shape[0] / (H * W / window_size / window_size)) 60 | x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) 61 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) 62 | return x 63 | 64 | 65 | class WindowAttention(nn.Module): 66 | r""" Window based multi-head self attention (W-MSA) module with relative position bias. 67 | It supports both of shifted and non-shifted window. 68 | 69 | Args: 70 | dim (int): Number of input channels. 71 | window_size (tuple[int]): The height and width of the window. 72 | num_heads (int): Number of attention heads. 73 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 74 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set 75 | attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 76 | proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 77 | """ 78 | 79 | def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): 80 | 81 | super().__init__() 82 | self.dim = dim 83 | self.window_size = window_size # Wh, Ww 84 | self.num_heads = num_heads 85 | head_dim = dim // num_heads 86 | self.scale = qk_scale or head_dim ** -0.5 87 | 88 | # define a parameter table of relative position bias 89 | self.relative_position_bias_table = nn.Parameter( 90 | torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH 91 | 92 | # get pair-wise relative position index for each token inside the window 93 | coords_h = torch.arange(self.window_size[0]) 94 | coords_w = torch.arange(self.window_size[1]) 95 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww 96 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww 97 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww 98 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 99 | relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 100 | relative_coords[:, :, 1] += self.window_size[1] - 1 101 | relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 102 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww 103 | self.register_buffer("relative_position_index", relative_position_index) 104 | 105 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 106 | self.attn_drop = nn.Dropout(attn_drop) 107 | self.proj = nn.Linear(dim, dim) 108 | self.proj_drop = nn.Dropout(proj_drop) 109 | 110 | trunc_normal_(self.relative_position_bias_table, std=.02) 111 | self.softmax = nn.Softmax(dim=-1) 112 | 113 | def forward(self, x, mask=None): 114 | """ 115 | Args: 116 | x: input features with shape of (num_windows*B, N, C) 117 | mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None 118 | """ 119 | B_, N, C = x.shape 120 | qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 121 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 122 | 123 | q = q * self.scale 124 | attn = (q @ k.transpose(-2, -1)) 125 | 126 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( 127 | self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH 128 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww 129 | attn = attn + relative_position_bias.unsqueeze(0) 130 | 131 | if mask is not None: 132 | nW = mask.shape[0] 133 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) 134 | attn = attn.view(-1, self.num_heads, N, N) 135 | attn = self.softmax(attn) 136 | else: 137 | attn = self.softmax(attn) 138 | 139 | attn = self.attn_drop(attn) 140 | 141 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C) 142 | x = self.proj(x) 143 | x = self.proj_drop(x) 144 | return x 145 | 146 | def extra_repr(self) -> str: 147 | return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' 148 | 149 | def flops(self, N): 150 | # calculate flops for 1 window with token length of N 151 | flops = 0 152 | # qkv = self.qkv(x) 153 | flops += N * self.dim * 3 * self.dim 154 | # attn = (q @ k.transpose(-2, -1)) 155 | flops += self.num_heads * N * (self.dim // self.num_heads) * N 156 | # x = (attn @ v) 157 | flops += self.num_heads * N * N * (self.dim // 
self.num_heads) 158 | # x = self.proj(x) 159 | flops += N * self.dim * self.dim 160 | return flops 161 | 162 | 163 | class SwinTransformerBlock(nn.Module): 164 | r""" Swin Transformer Block. 165 | 166 | Args: 167 | dim (int): Number of input channels. 168 | input_resolution (tuple[int]): Input resolution. 169 | num_heads (int): Number of attention heads. 170 | window_size (int): Window size. 171 | shift_size (int): Shift size for SW-MSA. 172 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 173 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 174 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 175 | drop (float, optional): Dropout rate. Default: 0.0 176 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 177 | drop_path (float, optional): Stochastic depth rate. Default: 0.0 178 | act_layer (nn.Module, optional): Activation layer. Default: nn.GELU 179 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 180 | """ 181 | 182 | def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, 183 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., 184 | act_layer=nn.GELU, norm_layer=nn.LayerNorm): 185 | super().__init__() 186 | self.dim = dim 187 | self.input_resolution = input_resolution 188 | self.num_heads = num_heads 189 | self.window_size = window_size 190 | self.shift_size = shift_size 191 | self.mlp_ratio = mlp_ratio 192 | if min(self.input_resolution) <= self.window_size: 193 | # if window size is larger than input resolution, we don't partition windows 194 | self.shift_size = 0 195 | self.window_size = min(self.input_resolution) 196 | assert 0 <= self.shift_size < self.window_size, "shift_size must be in 0-window_size" 197 | 198 | self.norm1 = norm_layer(dim) 199 | self.attn = WindowAttention( 200 | dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, 201 | qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 202 | 203 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 204 | self.norm2 = norm_layer(dim) 205 | mlp_hidden_dim = int(dim * mlp_ratio) 206 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 207 | 208 | if self.shift_size > 0: 209 | # calculate attention mask for SW-MSA 210 | H, W = self.input_resolution 211 | img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 212 | h_slices = (slice(0, -self.window_size), 213 | slice(-self.window_size, -self.shift_size), 214 | slice(-self.shift_size, None)) 215 | w_slices = (slice(0, -self.window_size), 216 | slice(-self.window_size, -self.shift_size), 217 | slice(-self.shift_size, None)) 218 | cnt = 0 219 | for h in h_slices: 220 | for w in w_slices: 221 | img_mask[:, h, w, :] = cnt 222 | cnt += 1 223 | 224 | mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 225 | mask_windows = mask_windows.view(-1, self.window_size * self.window_size) 226 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) 227 | attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) 228 | else: 229 | attn_mask = None 230 | 231 | self.register_buffer("attn_mask", attn_mask) 232 | 233 | def forward(self, x): 234 | H, W = self.input_resolution 235 | B, L, C = x.shape 236 | assert L == H * W, "input feature has wrong size" 237 | 238 | shortcut = x 239 | x = self.norm1(x) 240 | x = x.view(B, H, W, C) 241 | 242 | # cyclic shift 243 | if self.shift_size > 0: 244 | shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) 245 | else: 246 | shifted_x = x 247 | 248 | # partition windows 249 | x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C 250 | x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C 251 | 252 | # W-MSA/SW-MSA 253 | attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C 254 | 255 | # merge windows 256 | attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) 257 | shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C 258 | 259 | # reverse cyclic shift 260 | if self.shift_size > 0: 261 | x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) 262 | else: 263 | x = shifted_x 264 | x = x.view(B, H * W, C) 265 | 266 | # FFN 267 | x = shortcut + self.drop_path(x) 268 | x = x + self.drop_path(self.mlp(self.norm2(x))) 269 | 270 | return x 271 | 272 | def extra_repr(self) -> str: 273 | return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ 274 | f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" 275 | 276 | def flops(self): 277 | flops = 0 278 | H, W = self.input_resolution 279 | # norm1 280 | flops += self.dim * H * W 281 | # W-MSA/SW-MSA 282 | nW = H * W / self.window_size / self.window_size 283 | flops += nW * self.attn.flops(self.window_size * self.window_size) 284 | # mlp 285 | flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio 286 | # norm2 287 | flops += self.dim * H * W 288 | return flops 289 | 290 | 291 | class PatchMerging(nn.Module): 292 | r""" Patch Merging Layer. 293 | 294 | Args: 295 | input_resolution (tuple[int]): Resolution of input feature. 296 | dim (int): Number of input channels. 297 | norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm 298 | """ 299 | 300 | def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): 301 | super().__init__() 302 | self.input_resolution = input_resolution 303 | self.dim = dim 304 | self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) 305 | self.norm = norm_layer(4 * dim) 306 | 307 | def forward(self, x): 308 | """ 309 | x: B, H*W, C 310 | """ 311 | H, W = self.input_resolution 312 | B, L, C = x.shape 313 | assert L == H * W, "input feature has wrong size" 314 | assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." 315 | 316 | x = x.view(B, H, W, C) 317 | 318 | x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C 319 | x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C 320 | x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C 321 | x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C 322 | x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C 323 | x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C 324 | 325 | x = self.norm(x) 326 | x = self.reduction(x) 327 | 328 | return x 329 | 330 | def extra_repr(self) -> str: 331 | return f"input_resolution={self.input_resolution}, dim={self.dim}" 332 | 333 | def flops(self): 334 | H, W = self.input_resolution 335 | flops = H * W * self.dim 336 | flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim 337 | return flops 338 | 339 | 340 | class BasicLayer(nn.Module): 341 | """ A basic Swin Transformer layer for one stage. 342 | 343 | Args: 344 | dim (int): Number of input channels. 345 | input_resolution (tuple[int]): Input resolution. 346 | depth (int): Number of blocks. 347 | num_heads (int): Number of attention heads. 348 | window_size (int): Local window size. 349 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 350 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 351 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 352 | drop (float, optional): Dropout rate. Default: 0.0 353 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 354 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 355 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 356 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None 357 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
358 | """ 359 | 360 | def __init__(self, dim, input_resolution, depth, num_heads, window_size, 361 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., 362 | drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): 363 | 364 | super().__init__() 365 | self.dim = dim 366 | self.input_resolution = input_resolution 367 | self.depth = depth 368 | self.use_checkpoint = use_checkpoint 369 | 370 | # build blocks 371 | self.blocks = nn.ModuleList([ 372 | SwinTransformerBlock(dim=dim, input_resolution=input_resolution, 373 | num_heads=num_heads, window_size=window_size, 374 | shift_size=0 if (i % 2 == 0) else window_size // 2, 375 | mlp_ratio=mlp_ratio, 376 | qkv_bias=qkv_bias, qk_scale=qk_scale, 377 | drop=drop, attn_drop=attn_drop, 378 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, 379 | norm_layer=norm_layer) 380 | for i in range(depth)]) 381 | 382 | # patch merging layer 383 | if downsample is not None: 384 | self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) 385 | else: 386 | self.downsample = None 387 | 388 | def forward(self, x): 389 | for blk in self.blocks: 390 | if self.use_checkpoint: 391 | x = checkpoint.checkpoint(blk, x) 392 | else: 393 | x = blk(x) 394 | if self.downsample is not None: 395 | x = self.downsample(x) 396 | return x 397 | 398 | def extra_repr(self) -> str: 399 | return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" 400 | 401 | def flops(self): 402 | flops = 0 403 | for blk in self.blocks: 404 | flops += blk.flops() 405 | if self.downsample is not None: 406 | flops += self.downsample.flops() 407 | return flops 408 | 409 | 410 | class PatchEmbed(nn.Module): 411 | r""" Image to Patch Embedding 412 | 413 | Args: 414 | img_size (int): Image size. Default: 224. 415 | patch_size (int): Patch token size. Default: 4. 416 | in_chans (int): Number of input image channels. Default: 3. 417 | embed_dim (int): Number of linear projection output channels. Default: 96. 418 | norm_layer (nn.Module, optional): Normalization layer. Default: None 419 | """ 420 | 421 | def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): 422 | super().__init__() 423 | img_size = to_2tuple(img_size) 424 | patch_size = to_2tuple(patch_size) 425 | patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] 426 | self.img_size = img_size 427 | self.patch_size = patch_size 428 | self.patches_resolution = patches_resolution 429 | self.num_patches = patches_resolution[0] * patches_resolution[1] 430 | 431 | self.in_chans = in_chans 432 | self.embed_dim = embed_dim 433 | 434 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 435 | if norm_layer is not None: 436 | self.norm = norm_layer(embed_dim) 437 | else: 438 | self.norm = None 439 | 440 | def forward(self, x): 441 | B, C, H, W = x.shape 442 | # FIXME look at relaxing size constraints 443 | assert H == self.img_size[0] and W == self.img_size[1], \ 444 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
445 | x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C 446 | if self.norm is not None: 447 | x = self.norm(x) 448 | return x 449 | 450 | def flops(self): 451 | Ho, Wo = self.patches_resolution 452 | flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) 453 | if self.norm is not None: 454 | flops += Ho * Wo * self.embed_dim 455 | return flops 456 | 457 | 458 | class SwinTransformer(nn.Module): 459 | r""" Swin Transformer 460 | A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - 461 | https://arxiv.org/pdf/2103.14030 462 | 463 | Args: 464 | img_size (int | tuple(int)): Input image size. Default 224 465 | patch_size (int | tuple(int)): Patch size. Default: 4 466 | in_chans (int): Number of input image channels. Default: 3 467 | num_classes (int): Number of classes for classification head. Default: 1000 468 | embed_dim (int): Patch embedding dimension. Default: 96 469 | depths (tuple(int)): Depth of each Swin Transformer layer. 470 | num_heads (tuple(int)): Number of attention heads in different layers. 471 | window_size (int): Window size. Default: 7 472 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 473 | qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True 474 | qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None 475 | drop_rate (float): Dropout rate. Default: 0 476 | attn_drop_rate (float): Attention dropout rate. Default: 0 477 | drop_path_rate (float): Stochastic depth rate. Default: 0.1 478 | norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. 479 | ape (bool): If True, add absolute position embedding to the patch embedding. Default: False 480 | patch_norm (bool): If True, add normalization after patch embedding. Default: True 481 | use_checkpoint (bool): Whether to use checkpointing to save memory. 
Default: False 482 | """ 483 | 484 | def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, 485 | embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], 486 | window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, 487 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, 488 | norm_layer=nn.LayerNorm, ape=False, patch_norm=True, 489 | use_checkpoint=False, **kwargs): 490 | super().__init__() 491 | 492 | self.num_classes = num_classes 493 | self.num_layers = len(depths) 494 | self.embed_dim = embed_dim 495 | self.ape = ape 496 | self.patch_norm = patch_norm 497 | self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) 498 | self.mlp_ratio = mlp_ratio 499 | 500 | # split image into non-overlapping patches 501 | self.patch_embed = PatchEmbed( 502 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, 503 | norm_layer=norm_layer if self.patch_norm else None) 504 | num_patches = self.patch_embed.num_patches 505 | patches_resolution = self.patch_embed.patches_resolution 506 | self.patches_resolution = patches_resolution 507 | 508 | # absolute position embedding 509 | if self.ape: 510 | self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) 511 | trunc_normal_(self.absolute_pos_embed, std=.02) 512 | 513 | self.pos_drop = nn.Dropout(p=drop_rate) 514 | 515 | # stochastic depth 516 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule 517 | 518 | # build layers 519 | self.layers = nn.ModuleList() 520 | for i_layer in range(self.num_layers): 521 | layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), 522 | input_resolution=(patches_resolution[0] // (2 ** i_layer), 523 | patches_resolution[1] // (2 ** i_layer)), 524 | depth=depths[i_layer], 525 | num_heads=num_heads[i_layer], 526 | window_size=window_size, 527 | mlp_ratio=self.mlp_ratio, 528 | qkv_bias=qkv_bias, qk_scale=qk_scale, 529 | drop=drop_rate, attn_drop=attn_drop_rate, 530 | drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], 531 | norm_layer=norm_layer, 532 | downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, 533 | use_checkpoint=use_checkpoint) 534 | self.layers.append(layer) 535 | 536 | self.norm = norm_layer(self.num_features) 537 | self.avgpool = nn.AdaptiveAvgPool1d(1) 538 | self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() 539 | self.dim_out = self.num_features 540 | 541 | self.apply(self._init_weights) 542 | 543 | def _init_weights(self, m): 544 | if isinstance(m, nn.Linear): 545 | trunc_normal_(m.weight, std=.02) 546 | if isinstance(m, nn.Linear) and m.bias is not None: 547 | nn.init.constant_(m.bias, 0) 548 | elif isinstance(m, nn.LayerNorm): 549 | nn.init.constant_(m.bias, 0) 550 | nn.init.constant_(m.weight, 1.0) 551 | 552 | @torch.jit.ignore 553 | def no_weight_decay(self): 554 | return {'absolute_pos_embed'} 555 | 556 | @torch.jit.ignore 557 | def no_weight_decay_keywords(self): 558 | return {'relative_position_bias_table'} 559 | 560 | def forward_features(self, x): 561 | x = self.patch_embed(x) 562 | if self.ape: 563 | x = x + self.absolute_pos_embed 564 | x = self.pos_drop(x) 565 | 566 | for layer in self.layers: 567 | x = layer(x) 568 | 569 | x = self.norm(x) # B L C 570 | x = self.avgpool(x.transpose(1, 2)) # B C 1 571 | x = torch.flatten(x, 1) 572 | return x 573 | 574 | def forward(self, x): 575 | x = self.forward_features(x) 576 | x = self.head(x) 577 | return x 578 | 579 | def flops(self): 580 | flops = 0 
581 | flops += self.patch_embed.flops() 582 | for i, layer in enumerate(self.layers): 583 | flops += layer.flops() 584 | flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers) 585 | flops += self.num_features * self.num_classes 586 | return flops 587 | -------------------------------------------------------------------------------- /model/model.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | import tempfile 3 | from collections import OrderedDict 4 | from typing import Tuple, Union 5 | import logging 6 | import os 7 | 8 | import numpy as np 9 | import torch 10 | import torch.nn.functional as F 11 | from torch import nn 12 | 13 | from timm.models.layers import DropPath, trunc_normal_ 14 | 15 | from .image_encoder import build_image_encoder 16 | from .text_encoder import build_text_encoder 17 | from .text_encoder import build_tokenizer 18 | from data.imagenet import IMAGENET_CLASSES, IMAGENET_DEFAULT_TEMPLATES 19 | 20 | logger = logging.getLogger(__name__) 21 | 22 | 23 | class UniCLModel(nn.Module): 24 | def __init__(self, config: dict,): 25 | super().__init__() 26 | 27 | self.conf_lang_encoder = config['MODEL']['TEXT_ENCODER'] 28 | self.tokenizer = build_tokenizer(self.conf_lang_encoder) 29 | 30 | self.text_encoder = build_text_encoder(self.conf_lang_encoder, self.tokenizer, config['VERBOSE']) 31 | 32 | dim_projection = config['MODEL']['DIM_PROJECTION'] 33 | if hasattr(self.text_encoder, 'dim_out'): 34 | dim_out = self.text_encoder.dim_out 35 | else: 36 | with torch.no_grad(): 37 | dim_out = self.text_encoder( 38 | torch.zeros(1,1).type(torch.LongTensor) 39 | )['last_hidden_state'].size(2) 40 | 41 | self.text_projection = nn.Parameter(torch.empty(dim_out, dim_projection)) 42 | 43 | self.conf_image_encoder = config['MODEL']['IMAGE_ENCODER'] 44 | self.image_encoder = build_image_encoder(self.conf_image_encoder) 45 | 46 | self.image_projection = nn.Parameter( 47 | torch.empty(self.image_encoder.dim_out, dim_projection) 48 | ) 49 | 50 | self.logit_scale = nn.Parameter(torch.ones([])) 51 | 52 | trunc_normal_(self.text_projection, std=.02) 53 | trunc_normal_(self.image_projection, std=.02) 54 | 55 | def _convert_old_weights(self, model_dict): 56 | model_dict_updated = {} 57 | for k, v in model_dict.items(): 58 | if k.startswith('visual.'): 59 | model_dict_updated['image_encoder.'+k[7:]] = v 60 | elif k.startswith('text.'): 61 | model_dict_updated['lang_encoder.'+k[5:]] = v 62 | elif k == 'vision_projection': 63 | model_dict_updated['image_projection'] = v 64 | elif k == 'text_projection': 65 | model_dict_updated['text_projection'] = v 66 | else: 67 | model_dict_updated[k] = v 68 | 69 | return model_dict_updated 70 | 71 | def from_pretrained(self, pretrained='', pretrained_layers=[], verbose=True): 72 | if not os.path.isfile(pretrained): 73 | logger.warning(f'=> Pretrained model ({pretrained}) is not a file, skip init weight') 74 | return 75 | 76 | pretrained_dict = torch.load(pretrained, map_location='cpu') 77 | logger.info(f'=> Loading pretrained model {pretrained}') 78 | pretrained_dict = self._convert_old_weights(pretrained_dict) 79 | model_dict = self.state_dict() 80 | pretrained_dict = { 81 | k: v for k, v in pretrained_dict.items() 82 | if k in model_dict.keys() 83 | } 84 | need_init_state_dict = {} 85 | image_encoder_state_dict = {} 86 | for k, v in pretrained_dict.items(): 87 | need_init = ( 88 | k.split('.')[0] in pretrained_layers 89 | or pretrained_layers[0] == '*' 90 | ) 
91 | 92 | if need_init: 93 | if k.startswith('image_encoder.'): 94 | image_encoder_state_dict[k] = v 95 | else: 96 | if verbose: 97 | logger.info(f'=> init {k} from {pretrained}') 98 | 99 | need_init_state_dict[k] = v 100 | self.image_encoder.from_state_dict(image_encoder_state_dict, ['*'], verbose) 101 | self.load_state_dict(need_init_state_dict, strict=False) 102 | 103 | @torch.jit.ignore 104 | def no_weight_decay(self): 105 | no_weight_decay = {'logit_scale'} 106 | if hasattr(self.text_encoder, 'no_weight_decay'): 107 | for k in self.text_encoder.no_weight_decay(): 108 | no_weight_decay.add('lang_encoder.'+k) 109 | 110 | if hasattr(self.image_encoder, 'no_weight_decay'): 111 | for k in self.image_encoder.no_weight_decay(): 112 | no_weight_decay.add('image_encoder.'+k) 113 | 114 | return no_weight_decay 115 | 116 | @property 117 | def dtype(self): 118 | return self.logit_scale.dtype 119 | 120 | def get_imnet_embeddings(self): 121 | templates = IMAGENET_DEFAULT_TEMPLATES 122 | clss_embeddings = [] 123 | for clss in IMAGENET_CLASSES: 124 | txts = [template.format(clss) for template in templates] 125 | 126 | tokens = self.tokenizer( 127 | txts, padding='max_length', truncation=True, max_length=77, return_tensors='pt' 128 | ) 129 | tokens = {key:val.cuda() for key,val in tokens.items()} 130 | 131 | clss_embedding = self.encode_text(tokens) 132 | clss_embedding = clss_embedding.mean(dim=0) 133 | clss_embedding /= clss_embedding.norm() 134 | clss_embeddings.append(clss_embedding) 135 | imnet_text_embeddings = torch.stack(clss_embeddings, dim=0) 136 | return imnet_text_embeddings 137 | 138 | def encode_image(self, image, norm=True): 139 | x = self.image_encoder.forward_features(image) 140 | x = x @ self.image_projection 141 | 142 | if norm: 143 | x = x / x.norm(dim=-1, keepdim=True) 144 | 145 | return x 146 | 147 | def encode_text(self, text, norm=True): 148 | x = self.text_encoder(**text) 149 | x = x['last_hidden_state'] 150 | 151 | if self.conf_lang_encoder['TOKENIZER'] == 'clip': 152 | x = x[torch.arange(x.size(0)), text['input_ids'].argmax(dim=-1)] 153 | else: 154 | x = x[:, 0] 155 | 156 | x = x @ self.text_projection 157 | 158 | if norm: 159 | x = x / x.norm(dim=-1, keepdim=True) 160 | 161 | return x 162 | 163 | def forward(self, image, text): 164 | features_image = self.encode_image(image) 165 | features_text = self.encode_text(text) 166 | 167 | # cosine similarity as logits 168 | T = self.logit_scale.exp() 169 | 170 | return features_image, features_text, T 171 | 172 | 173 | def build_unicl_model(config, **kwargs): 174 | model = UniCLModel(config) 175 | if config['MODEL']['PRETRAINED'] != '': 176 | pretrained_path = config['MODEL']['PRETRAINED'] 177 | from ..Utils.Utils import is_valid_url, download_file 178 | if is_valid_url(pretrained_path): 179 | with tempfile.TemporaryDirectory() as tmp_path: 180 | file_local_path = pathlib.Path(tmp_path) / 'base_model.pt' 181 | download_file(pretrained_path, file_local_path) 182 | model.from_pretrained(str(file_local_path), config['MODEL']['PRETRAINED_LAYERS'], config['VERBOSE']) 183 | else: 184 | model.from_pretrained(pretrained_path, config['MODEL']['PRETRAINED_LAYERS'], config['VERBOSE']) 185 | 186 | return model 187 | -------------------------------------------------------------------------------- /model/text_encoder/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from 
.build import build_lang_encoder as build_text_encoder 6 | from .build import build_tokenizer 7 | 8 | from .transformer import * 9 | from .hf_model import * 10 | -------------------------------------------------------------------------------- /model/text_encoder/build.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from transformers import CLIPTokenizer 4 | from transformers import AutoTokenizer 5 | 6 | from .registry import lang_encoders 7 | from .registry import is_lang_encoder 8 | 9 | 10 | def build_lang_encoder(config_encoder, tokenizer, verbose, **kwargs): 11 | model_name = config_encoder['NAME'] 12 | 13 | if not is_lang_encoder(model_name): 14 | raise ValueError(f'Unknown model: {model_name}') 15 | 16 | return lang_encoders(model_name)(config_encoder, tokenizer, verbose, **kwargs) 17 | 18 | 19 | def build_tokenizer(config_encoder): 20 | tokenizer = None 21 | os.environ['TOKENIZERS_PARALLELISM'] = 'true' 22 | if config_encoder['TOKENIZER'] == 'clip': 23 | pretrained_tokenizer = config_encoder.get( 24 | 'PRETRAINED_TOKENIZER', 'openai/clip-vit-base-patch32' 25 | ) 26 | tokenizer = CLIPTokenizer.from_pretrained(pretrained_tokenizer) 27 | tokenizer.add_special_tokens({'cls_token': tokenizer.eos_token}) 28 | else: 29 | tokenizer = AutoTokenizer.from_pretrained(config_encoder['TOKENIZER']) 30 | 31 | return tokenizer 32 | -------------------------------------------------------------------------------- /model/text_encoder/hf_model.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from transformers import AutoConfig 4 | from transformers import AutoModel 5 | 6 | from .registry import register_lang_encoder 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | @register_lang_encoder 12 | def lang_encoder(config_encoder, tokenizer, verbose, **kwargs): 13 | 14 | hf_model = None 15 | if config_encoder['LOAD_PRETRAINED']: 16 | hf_model = AutoModel.from_pretrained(config_encoder['HF_MODEL']) 17 | else: 18 | hf_config = AutoConfig.from_pretrained(config_encoder['HF_MODEL']) 19 | 20 | if 'CONFIG_OVERRIDE' in config_encoder: 21 | logger.warning(f'Override config: {config_encoder["CONFIG_OVERRIDE"]}') 22 | hf_config.update(config_encoder['CONFIG_OVERRIDE']) 23 | 24 | logger.info(f'HF model config: {hf_config}') 25 | hf_model = AutoModel.from_config(hf_config) 26 | 27 | return hf_model 28 | -------------------------------------------------------------------------------- /model/text_encoder/registry.py: -------------------------------------------------------------------------------- 1 | _lang_encoders = {} 2 | 3 | 4 | def register_lang_encoder(fn): 5 | module_name_split = fn.__module__.split('.') 6 | model_name = module_name_split[-1] 7 | 8 | _lang_encoders[model_name] = fn 9 | 10 | return fn 11 | 12 | 13 | def lang_encoders(model_name): 14 | return _lang_encoders[model_name] 15 | 16 | 17 | def is_lang_encoder(model_name): 18 | return model_name in _lang_encoders 19 | -------------------------------------------------------------------------------- /model/text_encoder/transformer.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Tuple, Union 3 | import logging 4 | import os 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as F 9 | from torch import nn 10 | 11 | from timm.models.layers import DropPath, trunc_normal_ 12 | 13 | from .registry import 
register_lang_encoder 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | class LayerNorm(nn.Module): 18 | def __init__(self, hidden_size, eps=1e-12): 19 | """Construct a layernorm module in the TF style (epsilon inside the square root). 20 | """ 21 | super(LayerNorm, self).__init__() 22 | self.weight = nn.Parameter(torch.ones(hidden_size)) 23 | self.bias = nn.Parameter(torch.zeros(hidden_size)) 24 | self.variance_epsilon = eps 25 | 26 | def forward(self, x): 27 | pdtype = x.dtype 28 | x = x.float() 29 | u = x.mean(-1, keepdim=True) 30 | s = (x - u).pow(2).mean(-1, keepdim=True) 31 | x = (x - u) / torch.sqrt(s + self.variance_epsilon) 32 | return self.weight * x.to(pdtype) + self.bias 33 | 34 | 35 | class QuickGELU(nn.Module): 36 | def forward(self, x: torch.Tensor): 37 | return x * torch.sigmoid(1.702 * x) 38 | 39 | 40 | class ResidualAttentionBlock(nn.Module): 41 | def __init__(self, 42 | d_model: int, 43 | n_head: int, 44 | attn_mask: torch.Tensor = None, 45 | drop_path: float = 0.0): 46 | super().__init__() 47 | 48 | self.attn = nn.MultiheadAttention(d_model, n_head) 49 | self.ln_1 = LayerNorm(d_model) 50 | self.mlp = nn.Sequential(OrderedDict([ 51 | ("c_fc", nn.Linear(d_model, d_model * 4)), 52 | ("gelu", QuickGELU()), 53 | ("c_proj", nn.Linear(d_model * 4, d_model)) 54 | ])) 55 | self.ln_2 = LayerNorm(d_model) 56 | self.attn_mask = attn_mask 57 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 58 | 59 | def attention(self, x: torch.Tensor, key_padding_mask: torch.Tensor = None): 60 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) \ 61 | if self.attn_mask is not None else None 62 | 63 | 64 | return self.attn( 65 | x, x, x, 66 | key_padding_mask=key_padding_mask, 67 | need_weights=False, 68 | attn_mask=self.attn_mask 69 | )[0] 70 | 71 | def forward(self, x: torch.Tensor, key_padding_mask: torch.Tensor = None): 72 | x = x + self.drop_path(self.attention(self.ln_1(x), key_padding_mask=key_padding_mask)) 73 | x = x + self.drop_path(self.mlp(self.ln_2(x))) 74 | return x 75 | 76 | 77 | class Transformer(nn.Module): 78 | def __init__(self, 79 | context_length: int, 80 | vocab_size: int, 81 | width: int, 82 | layers: int, 83 | heads: int, 84 | drop_path: float = 0.0, 85 | autogressive: bool =True): 86 | super().__init__() 87 | 88 | self.token_embedding = nn.Embedding(vocab_size, width) 89 | 90 | self.context_length = context_length 91 | self.positional_embedding = nn.Parameter( 92 | torch.empty(self.context_length, width) 93 | ) 94 | 95 | self.width = width 96 | self.layers = layers 97 | self.autogressive = autogressive 98 | attn_mask = self.build_attention_mask() if autogressive else None 99 | dpr = [x.item() for x in torch.linspace(0, drop_path, layers)] # stochastic depth decay rule 100 | self.resblocks = nn.ModuleList( 101 | [ 102 | ResidualAttentionBlock(width, heads, attn_mask, dpr[i]) 103 | for i in range(layers) 104 | ] 105 | ) 106 | 107 | self.ln_final = LayerNorm(width) 108 | 109 | trunc_normal_(self.positional_embedding, std=.02) 110 | # nn.init.normal_(self.token_embedding, std=.02) 111 | trunc_normal_(self.token_embedding.weight, std=.02) 112 | self.apply(self._init_weights) 113 | 114 | @property 115 | def dim_out(self): 116 | return self.width 117 | 118 | def build_attention_mask(self): 119 | # lazily create causal attention mask, with full attention between the vision tokens 120 | # pytorch uses additive attention mask; fill with -inf 121 | mask = torch.empty(self.context_length, self.context_length) 122 | 
mask.fill_(float("-inf")) 123 | mask.triu_(1) # zero out the lower diagonal 124 | return mask 125 | 126 | def _init_weights(self, m): 127 | if isinstance(m, (nn.Linear, nn.Conv2d)): 128 | logger.info('=> init weight of Linear/Conv2d from trunc norm') 129 | trunc_normal_(m.weight, std=0.02) 130 | if m.bias is not None: 131 | logger.info('=> init bias of Linear/Conv2d to zeros') 132 | nn.init.constant_(m.bias, 0) 133 | elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)): 134 | nn.init.constant_(m.bias, 0) 135 | 136 | def load_pretrained(self, pretrained='', pretrained_layers=[], verbose=True): 137 | if os.path.isfile(pretrained): 138 | pretrained_dict = torch.load(pretrained, map_location='cpu') 139 | logging.info(f'=> loading pretrained model {pretrained}') 140 | model_dict = self.state_dict() 141 | pretrained_dict = { 142 | k: v for k, v in pretrained_dict.items() 143 | if k in model_dict.keys() 144 | } 145 | need_init_state_dict = {} 146 | for k, v in pretrained_dict.items(): 147 | need_init = ( 148 | k.split('.')[0] in pretrained_layers 149 | or pretrained_layers[0] == '*' 150 | ) 151 | if need_init: 152 | if verbose: 153 | logging.info(f'=> init {k} from {pretrained}') 154 | 155 | need_init_state_dict[k] = v 156 | self.load_state_dict(need_init_state_dict, strict=False) 157 | 158 | 159 | @torch.jit.ignore 160 | def no_weight_decay(self): 161 | return { 162 | 'positional_embedding', 163 | 'token_embedding', 164 | } 165 | 166 | def forward(self, input_ids, attention_mask=None): 167 | key_padding_mask = (input_ids == 0) if not self.autogressive else None 168 | x = self.token_embedding(input_ids) # [batch_size, n_ctx, d_model] 169 | x = x + self.positional_embedding 170 | x = x.permute(1, 0, 2) # NLD -> LND 171 | for block in self.resblocks: 172 | x = block(x, key_padding_mask) 173 | x = x.permute(1, 0, 2) # LND -> NLD 174 | 175 | x = self.ln_final(x) 176 | 177 | return {'last_hidden_state': x} 178 | 179 | 180 | @register_lang_encoder 181 | def lang_encoder(config_encoder, tokenizer, verbose, **kwargs): 182 | transformer = Transformer( 183 | context_length=config_encoder['CONTEXT_LENGTH'], 184 | vocab_size=tokenizer.vocab_size, 185 | width=config_encoder['WIDTH'], 186 | layers=config_encoder['LAYERS'], 187 | heads=config_encoder['HEADS'], 188 | autogressive=config_encoder.get('AUTOGRESSIVE', True) 189 | ) 190 | 191 | if config_encoder['LOAD_PRETRAINED']: 192 | transformer.load_pretrained() 193 | 194 | return transformer 195 | -------------------------------------------------------------------------------- /optimizer.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Focal Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Modified by Jianwei Yang (jianwyan@microsoft.com) 6 | # Based on Swin Transformer written by Zhe Liu 7 | # -------------------------------------------------------- 8 | 9 | from torch import optim as optim 10 | 11 | 12 | def build_optimizer(config, model): 13 | """ 14 | Build optimizer, set weight decay of normalization to 0 by default. 
15 | """ 16 | skip = {} 17 | skip_keywords = {} 18 | if hasattr(model, 'no_weight_decay'): 19 | skip = model.no_weight_decay() 20 | if hasattr(model, 'no_weight_decay_keywords'): 21 | skip_keywords = model.no_weight_decay_keywords() 22 | parameters = set_weight_decay(model, skip, skip_keywords) 23 | 24 | opt_lower = config.TRAIN.OPTIMIZER.NAME.lower() 25 | optimizer = None 26 | if opt_lower == 'sgd': 27 | optimizer = optim.SGD(parameters, momentum=config.TRAIN.OPTIMIZER.MOMENTUM, nesterov=True, 28 | lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY) 29 | elif opt_lower == 'adamw': 30 | optimizer = optim.AdamW(parameters, eps=config.TRAIN.OPTIMIZER.EPS, betas=config.TRAIN.OPTIMIZER.BETAS, 31 | lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY) 32 | 33 | return optimizer 34 | 35 | 36 | def set_weight_decay(model, skip_list=(), skip_keywords=()): 37 | has_decay = [] 38 | no_decay = [] 39 | 40 | for name, param in model.named_parameters(): 41 | if not param.requires_grad: 42 | continue # frozen weights 43 | if len(param.shape) == 1 or name.endswith(".bias") or (name in skip_list) or \ 44 | check_keywords_in_name(name, skip_keywords): 45 | no_decay.append(param) 46 | # print(f"{name} has no weight decay") 47 | else: 48 | has_decay.append(param) 49 | return [{'params': has_decay}, 50 | {'params': no_decay, 'weight_decay': 0.}] 51 | 52 | 53 | def check_keywords_in_name(name, keywords=()): 54 | isin = False 55 | for keyword in keywords: 56 | if keyword in name: 57 | isin = True 58 | return isin 59 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Focal Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Modified by Jianwei Yang (jianwyan@microsoft.com) 6 | # Based on Swin Transformer written by Zhe Liu 7 | # -------------------------------------------------------- 8 | 9 | import os 10 | import torch 11 | import torch.distributed as dist 12 | from timm.models.layers import trunc_normal_ 13 | 14 | try: 15 | # noinspection PyUnresolvedReferences 16 | from apex import amp 17 | except ImportError: 18 | amp = None 19 | 20 | 21 | def load_checkpoint(config, model, optimizer, lr_scheduler, logger): 22 | logger.info(f"==============> Resuming form {config.MODEL.RESUME}....................") 23 | if config.MODEL.RESUME.startswith('https'): 24 | checkpoint = torch.hub.load_state_dict_from_url( 25 | config.MODEL.RESUME, map_location='cpu', check_hash=True) 26 | elif os.path.exists(config.MODEL.RESUME): 27 | checkpoint = torch.load(config.MODEL.RESUME, map_location='cpu') 28 | else: 29 | logger.info(f"==============> Cannot find {config.MODEL.RESUME}....................") 30 | return None 31 | 32 | msg = model.load_state_dict(checkpoint['model'], strict=False) 33 | logger.info(msg) 34 | max_accuracy = 0.0 35 | if not config.EVAL_MODE and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint: 36 | optimizer.load_state_dict(checkpoint['optimizer']) 37 | lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) 38 | config.defrost() 39 | config.TRAIN.START_EPOCH = checkpoint['epoch'] + 1 40 | config.freeze() 41 | if 'amp' in checkpoint and config.AMP_OPT_LEVEL != "O0" and checkpoint['config'].AMP_OPT_LEVEL != "O0": 42 | amp.load_state_dict(checkpoint['amp']) 43 | logger.info(f"=> loaded successfully 
'{config.MODEL.RESUME}' (epoch {checkpoint['epoch']})") 44 | if 'max_accuracy' in checkpoint: 45 | max_accuracy = checkpoint['max_accuracy'] 46 | 47 | del checkpoint 48 | torch.cuda.empty_cache() 49 | return max_accuracy 50 | 51 | 52 | def save_checkpoint(config, epoch, model, max_accuracy, optimizer, lr_scheduler, logger): 53 | save_state = {'model': model.state_dict(), 54 | 'optimizer': optimizer.state_dict(), 55 | 'lr_scheduler': lr_scheduler.state_dict(), 56 | 'max_accuracy': max_accuracy, 57 | 'epoch': epoch, 58 | 'config': config} 59 | if config.AMP_OPT_LEVEL != "O0": 60 | save_state['amp'] = amp.state_dict() 61 | 62 | save_path = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth') 63 | logger.info(f"{save_path} saving......") 64 | torch.save(save_state, save_path) 65 | logger.info(f"{save_path} saved !!!") 66 | 67 | 68 | def get_grad_norm(parameters, norm_type=2): 69 | if isinstance(parameters, torch.Tensor): 70 | parameters = [parameters] 71 | parameters = list(filter(lambda p: p.grad is not None, parameters)) 72 | norm_type = float(norm_type) 73 | total_norm = 0 74 | for p in parameters: 75 | param_norm = p.grad.data.norm(norm_type) 76 | total_norm += param_norm.item() ** norm_type 77 | total_norm = total_norm ** (1. / norm_type) 78 | return total_norm 79 | 80 | 81 | def auto_resume_helper(output_dir): 82 | checkpoints = os.listdir(output_dir) 83 | checkpoints = [ckpt for ckpt in checkpoints if ckpt.endswith('pth')] 84 | print(f"All checkpoints found in {output_dir}: {checkpoints}") 85 | if len(checkpoints) > 0: 86 | latest_checkpoint = max([os.path.join(output_dir, d) for d in checkpoints], key=os.path.getmtime) 87 | print(f"The latest checkpoint found: {latest_checkpoint}") 88 | resume_file = latest_checkpoint 89 | else: 90 | resume_file = None 91 | return resume_file 92 | 93 | 94 | def reduce_tensor(tensor): 95 | rt = tensor.clone() 96 | dist.all_reduce(rt, op=dist.ReduceOp.SUM) 97 | rt /= dist.get_world_size() 98 | return rt 99 | --------------------------------------------------------------------------------
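
As a rough illustration of how the pieces above fit together, the sketch below runs zero-shot ImageNet classification with a UniCL model. It is a minimal example under stated assumptions, not code from the repository: `cfg` stands for a config produced by `config.py` (e.g. from `configs/unicl_swin_tiny.yaml`) with `MODEL.PRETRAINED` pointing at a UniCL checkpoint, `photo.jpg` is a placeholder image path, and a CUDA device is assumed because `get_imnet_embeddings()` moves the tokenized prompts to the GPU.

```python
# Minimal zero-shot classification sketch. Assumptions (not from the repository):
# `cfg` is a config object built by config.py (e.g. from configs/unicl_swin_tiny.yaml)
# whose MODEL.PRETRAINED points at a UniCL checkpoint; "photo.jpg" is a placeholder;
# a CUDA device is required because get_imnet_embeddings() tokenizes onto the GPU.
import torch
from PIL import Image

from model.model import build_unicl_model
from model.image_encoder.focalnet import build_transforms
from data.imagenet import IMAGENET_CLASSES

model = build_unicl_model(cfg).cuda().eval()

# Evaluation-style preprocessing from build_transforms: resize, center crop, normalize.
transform = build_transforms(img_size=224, center_crop=True)
image = transform(Image.open("photo.jpg").convert("RGB")).unsqueeze(0).cuda()

with torch.no_grad():
    text_emb = model.get_imnet_embeddings()    # (1000, DIM_PROJECTION), L2-normalized
    image_emb = model.encode_image(image)      # (1, DIM_PROJECTION), L2-normalized
    logits = model.logit_scale.exp() * (image_emb @ text_emb.t())
    pred = logits.argmax(dim=-1).item()

print(IMAGENET_CLASSES[pred])
```

Because `encode_image` and `get_imnet_embeddings` both return L2-normalized embeddings, the matrix product is a cosine similarity, scaled by the learned temperature `logit_scale.exp()` exactly as in `UniCLModel.forward`.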