├── .gitignore ├── LICENSE ├── README.md ├── assets ├── data │ ├── gqa │ │ └── .gitkeep │ └── hico │ │ └── .gitkeep ├── dataset.md └── overview.png ├── cache └── .gitkeep ├── configs ├── train_gqa_mcan.yaml └── train_hico.yaml ├── datasets ├── __init__.py ├── datasets.py ├── image_gqa.py └── image_hicodet_bbox.py ├── models ├── __init__.py ├── classifier.py ├── mcan.py ├── models.py ├── pvt_v2.py ├── swin_transformer.py ├── transparent_encoder.py └── vit.py ├── requirements.txt ├── scripts ├── train_gqa_image.sh └── train_hico_image.sh ├── train_gqa.py ├── train_hico.py └── utils ├── __init__.py └── relvit.py /.gitignore: -------------------------------------------------------------------------------- 1 | # custom 2 | *.json 3 | *.jpg 4 | *.png 5 | *.pkl 6 | 7 | # Byte-compiled / optimized / DLL files 8 | __pycache__/ 9 | *.py[cod] 10 | *$py.class 11 | 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | pip-wheel-metadata/ 30 | share/python-wheels/ 31 | *.egg-info/ 32 | .installed.cfg 33 | *.egg 34 | MANIFEST 35 | 36 | # PyInstaller 37 | # Usually these files are written by a python script from a template 38 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 39 | *.manifest 40 | *.spec 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | htmlcov/ 48 | .tox/ 49 | .nox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | *.py,cover 57 | .hypothesis/ 58 | .pytest_cache/ 59 | 60 | # Translations 61 | *.mo 62 | *.pot 63 | 64 | # Django stuff: 65 | *.log 66 | local_settings.py 67 | db.sqlite3 68 | db.sqlite3-journal 69 | 70 | # Flask stuff: 71 | instance/ 72 | .webassets-cache 73 | 74 | # Scrapy stuff: 75 | .scrapy 76 | 77 | # Sphinx documentation 78 | docs/_build/ 79 | 80 | # PyBuilder 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # IPython 87 | profile_default/ 88 | ipython_config.py 89 | 90 | # pyenv 91 | .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 101 | __pypackages__/ 102 | 103 | # Celery stuff 104 | celerybeat-schedule 105 | celerybeat.pid 106 | 107 | # SageMath parsed files 108 | *.sage.py 109 | 110 | # Environments 111 | .env 112 | .venv 113 | env/ 114 | venv/ 115 | ENV/ 116 | env.bak/ 117 | venv.bak/ 118 | 119 | # Spyder project settings 120 | .spyderproject 121 | .spyproject 122 | 123 | # Rope project settings 124 | .ropeproject 125 | 126 | # mkdocs documentation 127 | /site 128 | 129 | # mypy 130 | .mypy_cache/ 131 | .dmypy.json 132 | dmypy.json 133 | 134 | # Pyre type checker 135 | .pyre/ 136 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | NVIDIA Source Code License for RelViT 2 | 3 | 1. Definitions 4 | 5 | “Licensor” means any person or entity that distributes its Work. 
6 | 7 | “Software” means the original work of authorship made available under this License. 8 | 9 | “Work” means the Software and any additions to or derivative works of the Software that are made 10 | available under this License. 11 | 12 | The terms “reproduce,” “reproduction,” “derivative works,” and “distribution” have the meaning as 13 | provided under U.S. copyright law; provided, however, that for the purposes of this License, 14 | derivative works shall not include works that remain separable from, or merely link (or bind by name) 15 | to the interfaces of, the Work. 16 | 17 | Works, including the Software, are “made available” under this License by including in or with the 18 | Work either (a) a copyright notice referencing the applicability of this License to the Work, or (b) a 19 | copy of this License. 20 | 21 | 2. License Grant 22 | 23 | 2.1 Copyright Grant. Subject to the terms and conditions of this License, each Licensor grants 24 | to you a perpetual, worldwide, non-exclusive, royalty-free, copyright license to reproduce, 25 | prepare derivative works of, publicly display, publicly perform, sublicense and distribute its 26 | Work and any resulting derivative works in any form. 27 | 28 | 3. Limitations 29 | 30 | 3.1 Redistribution. You may reproduce or distribute the Work only if (a) you do so under this 31 | License, (b) you include a complete copy of this License with your distribution, and (c) you 32 | retain without modification any copyright, patent, trademark, or attribution notices that are 33 | present in the Work. 34 | 35 | 3.2 Derivative Works. You may specify that additional or different terms apply to the use, 36 | reproduction, and distribution of your derivative works of the Work (“Your Terms”) only if (a) 37 | Your Terms provide that the use limitation in Section 3.3 applies to your derivative works, and 38 | (b) you identify the specific derivative works that are subject to Your Terms. Notwithstanding 39 | Your Terms, this License (including the redistribution requirements in Section 3.1) will 40 | continue to apply to the Work itself. 41 | 42 | 3.3 Use Limitation. The Work and any derivative works thereof only may be used or intended 43 | for use non-commercially. Notwithstanding the foregoing, NVIDIA and its affiliates may use 44 | the Work and any derivative works commercially. As used herein, “non-commercially” means 45 | for research or evaluation purposes only. 46 | 47 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim against any Licensor 48 | (including any claim, cross-claim or counterclaim in a lawsuit) to enforce any patents that you 49 | allege are infringed by any Work, then your rights under this License from such Licensor 50 | (including the grant in Section 2.1) will terminate immediately. 51 | 52 | 3.5 Trademarks. This License does not grant any rights to use any Licensor’s or its affiliates’ 53 | names, logos, or trademarks, except as necessary to reproduce the notices described in this 54 | License. 55 | 56 | 3.6 Termination. If you violate any term of this License, then your rights under this License 57 | (including the grant in Section 2.1) will terminate immediately. 58 | 59 | 4. Disclaimer of Warranty. 60 | 61 | THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, 62 | EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF 63 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON- 64 | INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER THIS 65 | LICENSE. 
66 | 67 | 5. Limitation of Liability. 68 | 69 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL 70 | THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE 71 | SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, 72 | INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF OR 73 | RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK (INCLUDING BUT 74 | NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, LOST PROFITS OR DATA, 75 | COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER COMMERCIAL DAMAGES OR 76 | LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH 77 | DAMAGES. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | RelViT 2 | === 3 | 4 |

5 | 6 | This repository hosts the code for the paper: 7 | 8 | [RelViT: Concept-guided Vision Transformer for Visual Relational Reasoning (ICLR 2022)](https://web.cs.ucla.edu/~xm/file/relvit_iclr22.pdf) 9 | 10 | by [Xiaojian Ma](http://web.cs.ucla.edu/~xm), [Weili Nie](https://weilinie.github.io/), [Zhiding Yu](https://chrisding.github.io/), [Huaizu Jiang](https://jianghz.me/), [Chaowei Xiao](https://xiaocw11.github.io/), [Yuke Zhu](https://www.cs.utexas.edu/~yukez/) and [Anima Anandkumar](http://tensorlab.cms.caltech.edu/users/anima/) 11 | 12 | [arXiv](https://arxiv.org/abs/2204.11167) | [Poster](https://web.cs.ucla.edu/~xm/file/relvit_iclr22_poster.pdf) | [Slides](https://web.cs.ucla.edu/~xm/file/relvit_iclr22_slides.pdf) 13 | 14 | ## News 15 | 16 | - 🔥🔥 **09/10/2022**: Pre-trained models on GQA are now released. 17 | 18 | 19 | ## Abstract 20 | 21 | Reasoning about visual relationships is central to how humans interpret the visual world. This task remains challenging for current deep learning algorithms since it requires addressing three key technical problems jointly: **1)** identifying object entities and their properties, **2)** inferring semantic relations between pairs of entities, and **3)** generalizing to novel object-relation combinations, i.e., systematic generalization. In this work, we use **vision transformers (ViTs)** as our base model for visual reasoning and make better use of concepts defined as object entities and their relations to improve the reasoning ability of ViTs. Specifically, we introduce a novel concept-feature dictionary to allow flexible image feature retrieval at training time with concept keys. This dictionary enables two new concept-guided auxiliary tasks: 1) a **global task** for promoting relational reasoning, and 2) a **local task** for facilitating semantic object-centric correspondence learning. To examine the systematic generalization of visual reasoning models, we introduce systematic splits for the standard HICO and GQA benchmarks. We show the resulting model, **Concept-guided Vision Transformer** (or **RelViT** for short) significantly outperforms prior approaches on HICO and GQA by **16%** and **13%** in the original split, and by **43%** and **18%** in the systematic split. Our ablation analyses also reveal our model's compatibility with multiple ViT variants and robustness to hyper-parameters. 22 | 23 | ## Installation 24 | 25 | - Install PyTorch: 26 | ```shell 27 | conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch 28 | ``` 29 | 30 | - Install the necessary packages with `requirements.txt`: 31 | ```shell 32 | pip install -r requirements.txt 33 | ``` 34 | 35 | The code has been tested with Python 3.8, PyTorch 1.11.0 and CUDA 11.6 on Ubuntu 20.04 36 | 37 | ## Data Preparation 38 | 39 | Please refer to [data preparation](assets/dataset.md) 40 | 41 | ## Training 42 | 43 |
<details><summary>HICO</summary> 44 | 45 | ```shell 46 | bash scripts/train_hico_image.sh configs/train_hico.yaml 47 | ``` 48 | 49 | ### Note 50 | 51 | - In `configs/train_hico.yaml` you may find some configurable options (see the config-editing sketch after this section): 52 | - use `eval_mode` to run different experiments: the original or the systematic generalization test 53 | - use `model_args.encoder_args.encoder` and `load_encoder` to select the vision backbone. There are five options available: `pvtv2_b2`, `pvtv2_b3`, `swin_small`, `swin_base` and `vit_small_16`. 54 | - use `relvit` to turn the RelViT auxiliary loss on or off 55 | - use `relvit_weight` to adjust the coefficient of the RelViT auxiliary loss 56 | - use `relvit_local_only` to use only the RelViT local task, only the global task, or both 57 | - use `relvit_mode` to control whether the [EsViT](https://github.com/microsoft/esvit) loss is included 58 | - use `relvit_sample_uniform` to choose between uniform and "most-recent" concept sampling 59 | - use `relvit_concept_use` and `relvit_num_concepts` to choose the concept type used by RelViT among `HOI`, `verb` and `object` 60 | 61 | In general, we don't recommend modifying other parameters. 62 | 63 | - All the GPUs will be used by default. To run with the recommended batch size, you may need one V100 32GB GPU. 64 | 65 | </details>
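The options above are ordinary YAML keys, so a variant run can be prepared by editing a copy of the config instead of the original file. The sketch below is not part of the released scripts; it assumes PyYAML is available, takes the key names and value conventions from the shipped `configs/train_hico.yaml`, and writes an arbitrarily named output file that is then passed to the launcher in place of the default config.

```python
# Minimal sketch (not from the repository): derive a variant of
# configs/train_hico.yaml with the RelViT options described above.
import yaml

with open('configs/train_hico.yaml') as f:
    cfg = yaml.safe_load(f)

cfg['eval_mode'] = 3                  # 1: systematic-easy, 2: systematic-hard, 3: original (two heads)
cfg['relvit'] = True                  # turn the RelViT auxiliary loss on/off
cfg['relvit_weight'] = 0.2            # coefficient of the auxiliary loss
cfg['relvit_local_only'] = 0          # 0: both tasks, 1: local only, 2: global only
cfg['relvit_mode'] = 1                # 0: EsViT, 1: EsViT+RelViT, 2: RelViT
cfg['relvit_sample_uniform'] = False  # False: most-recent sampling, True: uniform sampling
cfg['relvit_concept_use'] = 0         # 0: HOI, 1: verb, 2: object
cfg['relvit_num_concepts'] = 600      # 600 HOIs / 117 verbs / 80 objects

with open('configs/train_hico_variant.yaml', 'w') as f:  # arbitrary output path
    yaml.safe_dump(cfg, f)

# then: bash scripts/train_hico_image.sh configs/train_hico_variant.yaml
```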
66 | 67 |
<details><summary>GQA</summary> 68 | 69 | ```shell 70 | bash scripts/train_gqa_image.sh configs/train_gqa_mcan.yaml 71 | ``` 72 | 73 | ### Note 74 | 75 | - In `configs/train_gqa_mcan.yaml` you may find some configurable options: 76 | - use `eval_mode` to run different experiments: the original or the systematic generalization test 77 | - use `model_args.encoder_args.encoder` and `load_encoder` to select the vision backbone. There are five options available: `pvtv2_b2`, `pvtv2_b3`, `swin_small`, `swin_base` and `vit_small_16` (see the token-count note after this section). 78 | - use `relvit` to turn the RelViT auxiliary loss on or off 79 | - use `relvit_weight` to adjust the coefficient of the RelViT auxiliary loss 80 | - use `relvit_local_only` to use only the RelViT local task, only the global task, or both 81 | - use `relvit_mode` to control whether the [EsViT](https://github.com/microsoft/esvit) loss is included 82 | - use `relvit_sample_uniform` to choose between uniform and "most-recent" concept sampling 83 | 84 | In general, we don't recommend modifying other parameters. 85 | 86 | - All the GPUs will be used by default. To run with the recommended batch size, you may need up to 64 V100 32GB GPUs, because the vision backbone is fine-tuned during training. 87 | 88 | </details>
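If you switch backbones, also double-check `relvit_num_tokens` in the configs (49 for PVTv2 and Swin, 196 for `vit_small_16`). The value is simply the number of spatial tokens in the backbone's final feature map; a quick sanity check, assuming the 224×224 crops produced for the augmented views:

```python
# relvit_num_tokens = (crop size // overall stride of the final stage) ** 2,
# assuming the 224x224 crops used for the augmented views.
crop = 224
print((crop // 32) ** 2)  # PVTv2 / Swin: stride-32 final stage -> 7 * 7 = 49 tokens
print((crop // 16) ** 2)  # ViT-S/16: 16x16 patches, no downsampling -> 14 * 14 = 196 tokens
```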
89 | 90 | ## Testing 91 | 92 |
<details><summary>HICO</summary> 93 | 94 | ```shell 95 | bash scripts/train_hico_image.sh configs/train_hico.yaml --test_only --test_model 96 | ``` 97 | 98 | </details>
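For offline analysis of saved HICO predictions, `compute_map_hico` in `datasets/image_hicodet_bbox.py` turns stacked per-image HOI scores and multi-hot labels into mAP, optionally restricted to the systematic-easy, systematic-hard, or rare HOI subsets. A hypothetical usage sketch; the `.npy` dumps are placeholders you would save yourself, not files produced by the released scripts:

```python
# Hypothetical sketch: only compute_map_hico comes from this repository;
# the .npy files below are placeholder score/label dumps.
import numpy as np
from datasets.image_hicodet_bbox import compute_map_hico

y_true = np.load('hico_labels.npy')   # [N, 600] multi-hot HOI ground truth
y_score = np.load('hico_scores.npy')  # [N, 600] predicted HOI scores
print('full mAP:', compute_map_hico(y_true, y_score))
print('systematic-easy mAP:', compute_map_hico(y_true, y_score, easy=True))
print('systematic-hard mAP:', compute_map_hico(y_true, y_score, hard=True))
```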
99 | 100 |
<details><summary>GQA</summary> 101 | 102 | ```shell 103 | bash scripts/train_gqa_image.sh configs/train_gqa_mcan.yaml --test_only --test_model 104 | ``` 105 | 106 | </details>
107 | 108 | ## Pre-trained models 109 | 110 | | tag | encoder | experiment | result | URL | 111 | |---------------------|--------------|------------|--------|--------------------------------------------------------------------------------------------| 112 | | `swin-small-relvit` | `swin_small` | GQA (val) | 61.38 | [link](https://drive.google.com/file/d/1efkfy7bM6PmzKZy6ffQQq98OkLLgcZd3/view?usp=sharing) | 113 | | `swin-base-relvit` | `swin_base` | GQA (val) | 65.54 | [link](https://drive.google.com/file/d/1u66PqHJiqGJZsXWeLsjooVAFEdkjDVK3/view?usp=sharing) | 114 | 115 | 116 | ## License 117 | 118 | Please check the LICENSE file for both the code and the released pre-trained models. This work may be used non-commercially, meaning for research or evaluation purposes only. For business inquiries, please contact researchinquiries@nvidia.com. 119 | 120 | ## Acknowledgement 121 | 122 | The authors have referred the following projects: 123 | 124 | [SimCLR](https://github.com/google-research/simclr) 125 | 126 | [DenseCL](https://github.com/WXinlong/DenseCL) 127 | 128 | [EsViT](https://github.com/microsoft/esvit) 129 | 130 | [Swin-Transformer](https://github.com/microsoft/Swin-Transformer) 131 | 132 | [PVT](https://github.com/whai362/PVT) 133 | 134 | [HICODet](https://github.com/fredzzhang/hicodet) 135 | 136 | [MCAN](https://github.com/MILVLG/mcan-vqa) 137 | 138 | ## Citation 139 | 140 | Please consider citing our paper if you find our work helpful for your research: 141 | 142 | ``` 143 | @inproceedings{ma2022relvit, 144 | title={RelViT: Concept-guided Vision Transformer for Visual Relational Reasoning}, 145 | author={Xiaojian Ma and Weili Nie and Zhiding Yu and Huaizu Jiang and Chaowei Xiao and Yuke Zhu and Song-Chun Zhu and Anima Anandkumar}, 146 | booktitle={International Conference on Learning Representations}, 147 | year={2022}, 148 | url={https://openreview.net/forum?id=afoV8W3-IYp} 149 | } 150 | ``` 151 | -------------------------------------------------------------------------------- /assets/data/gqa/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/RelViT/ae34a63975738a7eef219fe3a2080286f7dfc753/assets/data/gqa/.gitkeep -------------------------------------------------------------------------------- /assets/data/hico/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/RelViT/ae34a63975738a7eef219fe3a2080286f7dfc753/assets/data/hico/.gitkeep -------------------------------------------------------------------------------- /assets/dataset.md: -------------------------------------------------------------------------------- 1 | Data Preparation 2 | === 3 | 4 | 1. Download the HICO dataset from [here](https://drive.google.com/file/d/1QZcJmGVlF9f4h-XLWe9Gkmnmj2z1gSnk/view) and extract it to `./assets/data/hico` 5 | 6 | 2. Download our customized HICO annotations from [here](https://drive.google.com/file/d/11RvCM0KIBB4pFw0NuaX6Z6IPvnkrb83H/view?usp=sharing) and extract them to `./assets/data/hico/hico_20160224_det` 7 | 8 | 3. Download the images of GQA from [here](https://downloads.cs.stanford.edu/nlp/data/gqa/images.zip) and extract them to `./assets/data/gqa` 9 | 10 | 4. Download our customized GQA annotations and meta files from [here](https://drive.google.com/file/d/1hKlRiikRkfZNB-St4kOzuMdLv8N5xk15/view?usp=sharing) and extract them to `./assets/data/gqa` 11 | 12 | 5. 
Download the pretrained vision backbones and other files from [here](https://drive.google.com/file/d/1pxmUxkk5t8Bg_cS_jdaQgugCqYddZInE/view?usp=sharing) and extract them to `./cache` 13 | 14 | 6. The file structure should look like 15 | ```plain 16 | data 17 | ├── gqa 18 | │   ├── dicts.json 19 | │   ├── gqa_dic.pkl 20 | │   ├── images 21 | │   └── raw 22 | │   └── questions1.2 23 | │   ├── test_balanced_questions.json 24 | │   ├── testdev_balanced_questions.json 25 | │   ├── train_balanced_concepts.json 26 | │   ├── train_balanced_questions.json 27 | │   ├── train_sys_reduced_concepts.json 28 | │   ├── train_sys_reduced_questions.json 29 | │   ├── val_balanced_questions.json 30 | │   └── val_sys_reduced_questions.json 31 | └── hico 32 | └── hico_20160224_det 33 | ├── images 34 | │   ├── test2015 35 | │   └── train2015 36 | ├── instances_test2015.json 37 | ├── instances_train2015.json 38 | ├── sys_vcl_nonrare_instances_test2015.json 39 | ├── sys_vcl_nonrare_instances_train2015.json 40 | ├── sys_vcl_rare_instances_test2015.json 41 | └── sys_vcl_rare_instances_train2015.json 42 | ``` -------------------------------------------------------------------------------- /assets/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/RelViT/ae34a63975738a7eef219fe3a2080286f7dfc753/assets/overview.png -------------------------------------------------------------------------------- /cache/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/RelViT/ae34a63975738a7eef219fe3a2080286f7dfc753/cache/.gitkeep -------------------------------------------------------------------------------- /configs/train_gqa_mcan.yaml: -------------------------------------------------------------------------------- 1 | train_dataset: gqa 2 | train_dataset_args: 3 | root_dir: ./assets/data/gqa 4 | split: train 5 | 6 | test_dataset: gqa 7 | test_dataset_args: 8 | root_dir: ./assets/data/gqa 9 | split: val 10 | 11 | # model: mcan 12 | # Note: Modifying the hyper-param of MCAN is not recommended. 
But you may find them in ./utils/__init__.py (MCAN_GQA_PARAMS) 13 | model: mcan-customized 14 | model_args: 15 | word_emb_path: ./cache/gqa_word_embed.npy 16 | encoder: transparent_superpixel_encoder 17 | encoder_args: {encoder: pvtv2_b2, use_boxes_dim: False} 18 | load_encoder: ./cache/pvtv2_b2-{}.pth 19 | encoder_pretrain: imagenet 20 | 21 | train_batches: 1000000 22 | ep_per_batch: 1 23 | max_epoch: 12 24 | 25 | # 0 -- original 26 | # 1 -- systematic 27 | eval_mode: 0 28 | 29 | ########### 30 | # RelViT 31 | ########### 32 | relvit: True 33 | relvit_weight: 1.0 34 | # temperature 35 | relvit_loss_tau: 0.04 36 | # 0 -- both 37 | # 1 -- local only 38 | # 2 -- global only 39 | relvit_local_only: 0 40 | # 0 -- EsViT 41 | # 1 -- EsViT+RelViT 42 | # 2 -- RelViT 43 | relvit_mode: 1 44 | # False -- Most recent 45 | # True -- Uniform 46 | relvit_sample_uniform: True 47 | relvit_num_concepts: 1615 48 | 49 | ########### 50 | # MoCo 51 | ########### 52 | relvit_moco_m: 0.999 53 | relvit_moco_use_queue: False 54 | relvit_moco_K: 10 55 | # 49 for pvtv2 and swin-s, 196 for vit-s-16 56 | relvit_num_tokens: 49 57 | 58 | optimizer: adamw 59 | optimizer_args: {lr: 0.0001, weight_decay: 0, 'milestones': [8, 10], eps: 1e-8} 60 | 61 | print_freq: 10 62 | save_epoch: 1 63 | eval_epoch: 1 64 | grad_norm: 0.5 -------------------------------------------------------------------------------- /configs/train_hico.yaml: -------------------------------------------------------------------------------- 1 | train_dataset: hicodet 2 | train_dataset_args: 3 | im_dir: ./assets/data/hico 4 | split: train 5 | augment_plus: true 6 | 7 | test_dataset: hicodet 8 | test_dataset_args: 9 | im_dir: ./assets/data/hico 10 | split: test 11 | 12 | model: classifier 13 | model_args: 14 | encoder: transparent_superpixel_encoder 15 | encoder_args: {encoder: swin_base, use_boxes_dim: False} 16 | classifier: max-pooling-classifier-twoheads 17 | classifier_args: {n_classes: 600, n_classes1: 117, n_classes2: 80} 18 | load_encoder: ./cache/swin_base-{}.pth 19 | encoder_pretrain: imagenet 20 | 21 | train_batches: 100000 22 | ep_per_batch: 16 23 | max_epoch: 30 24 | 25 | # 1 -- systematic-easy 26 | # 2 -- systematic-hard 27 | # 3 -- original (two heads) 28 | eval_mode: 3 29 | 30 | ########### 31 | # RelViT 32 | ########### 33 | relvit: True 34 | relvit_weight: 0.2 35 | # temperature 36 | relvit_loss_tau: 0.04 37 | # 0 -- both 38 | # 1 -- local only 39 | # 2 -- global only 40 | relvit_local_only: 0 41 | # 0 -- EsViT 42 | # 1 -- EsViT+RelViT 43 | # 2 -- RelViT 44 | relvit_mode: 1 45 | # False -- Most recent 46 | # True -- Uniform 47 | relvit_sample_uniform: False 48 | # 0 & 600 -- hoi 49 | # 1 & 117 -- verb 50 | # 2 & 80 -- object 51 | relvit_concept_use: 0 52 | relvit_num_concepts: 600 53 | 54 | ########### 55 | # MoCo 56 | ########### 57 | relvit_moco_m: 0.999 58 | relvit_moco_use_queue: False 59 | relvit_moco_K: 10 60 | # 49 for pvtv2 and swin-s, 196 for vit-s-16 61 | relvit_num_tokens: 49 62 | 63 | optimizer: adamw 64 | optimizer_args: {lr: 0.00015, weight_decay: 0, 'milestones': [15, 25]} 65 | 66 | print_freq: 10 67 | save_epoch: 1 68 | eval_epoch: 1 69 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # ---------------------------------------------------------------------- 2 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
3 | # 4 | # This work is licensed under the NVIDIA Source Code License 5 | # for RelViT. To view a copy of this license, see the LICENSE file. 6 | # ---------------------------------------------------------------------- 7 | 8 | from .datasets import make 9 | from . import image_hicodet_bbox 10 | from . import image_gqa 11 | -------------------------------------------------------------------------------- /datasets/datasets.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This file has been modified from a file in the Bongard-HOI library 5 | # which was released under the NVIDIA Source Code Licence. 6 | # 7 | # Source: 8 | # https://github.com/NVlabs/Bongard-HOI 9 | # 10 | # The license for the original version of this file can be 11 | # found in https://github.com/NVlabs/Bongard-HOI/blob/master/LICENSE 12 | # The modifications to this file are subject to the same NVIDIA Source Code Licence. 13 | # --------------------------------------------------------------- 14 | 15 | 16 | import os 17 | 18 | 19 | datasets = {} 20 | def register(name): 21 | def decorator(cls): 22 | datasets[name] = cls 23 | return cls 24 | return decorator 25 | 26 | 27 | def make(name, **kwargs): 28 | dataset = datasets[name](**kwargs) 29 | return dataset 30 | -------------------------------------------------------------------------------- /datasets/image_gqa.py: -------------------------------------------------------------------------------- 1 | # ---------------------------------------------------------------------- 2 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This work is licensed under the NVIDIA Source Code License 5 | # for RelViT. To view a copy of this license, see the LICENSE file. 
6 | # ---------------------------------------------------------------------- 7 | 8 | import json 9 | import re 10 | import glob 11 | import os.path as osp 12 | import collections 13 | 14 | import numpy as np 15 | from PIL import Image 16 | import torch 17 | from torch.utils.data import Dataset 18 | from torchvision import transforms 19 | from torch.utils.data._utils.collate import default_collate 20 | 21 | import utils 22 | from .datasets import register 23 | 24 | 25 | def label_to_onehot(label, num_class): 26 | # label: [N] 27 | onehot = torch.zeros(num_class).to(label) 28 | for i in label: 29 | onehot[i] = 1 30 | return onehot 31 | 32 | 33 | contractions = { 34 | "aint": "ain't", "arent": "aren't", "cant": "can't", "couldve": 35 | "could've", "couldnt": "couldn't", "couldn'tve": "couldn't've", 36 | "couldnt've": "couldn't've", "didnt": "didn't", "doesnt": 37 | "doesn't", "dont": "don't", "hadnt": "hadn't", "hadnt've": 38 | "hadn't've", "hadn'tve": "hadn't've", "hasnt": "hasn't", "havent": 39 | "haven't", "hed": "he'd", "hed've": "he'd've", "he'dve": 40 | "he'd've", "hes": "he's", "howd": "how'd", "howll": "how'll", 41 | "hows": "how's", "Id've": "I'd've", "I'dve": "I'd've", "Im": 42 | "I'm", "Ive": "I've", "isnt": "isn't", "itd": "it'd", "itd've": 43 | "it'd've", "it'dve": "it'd've", "itll": "it'll", "let's": "let's", 44 | "maam": "ma'am", "mightnt": "mightn't", "mightnt've": 45 | "mightn't've", "mightn'tve": "mightn't've", "mightve": "might've", 46 | "mustnt": "mustn't", "mustve": "must've", "neednt": "needn't", 47 | "notve": "not've", "oclock": "o'clock", "oughtnt": "oughtn't", 48 | "ow's'at": "'ow's'at", "'ows'at": "'ow's'at", "'ow'sat": 49 | "'ow's'at", "shant": "shan't", "shed've": "she'd've", "she'dve": 50 | "she'd've", "she's": "she's", "shouldve": "should've", "shouldnt": 51 | "shouldn't", "shouldnt've": "shouldn't've", "shouldn'tve": 52 | "shouldn't've", "somebody'd": "somebodyd", "somebodyd've": 53 | "somebody'd've", "somebody'dve": "somebody'd've", "somebodyll": 54 | "somebody'll", "somebodys": "somebody's", "someoned": "someone'd", 55 | "someoned've": "someone'd've", "someone'dve": "someone'd've", 56 | "someonell": "someone'll", "someones": "someone's", "somethingd": 57 | "something'd", "somethingd've": "something'd've", "something'dve": 58 | "something'd've", "somethingll": "something'll", "thats": 59 | "that's", "thered": "there'd", "thered've": "there'd've", 60 | "there'dve": "there'd've", "therere": "there're", "theres": 61 | "there's", "theyd": "they'd", "theyd've": "they'd've", "they'dve": 62 | "they'd've", "theyll": "they'll", "theyre": "they're", "theyve": 63 | "they've", "twas": "'twas", "wasnt": "wasn't", "wed've": 64 | "we'd've", "we'dve": "we'd've", "weve": "we've", "werent": 65 | "weren't", "whatll": "what'll", "whatre": "what're", "whats": 66 | "what's", "whatve": "what've", "whens": "when's", "whered": 67 | "where'd", "wheres": "where's", "whereve": "where've", "whod": 68 | "who'd", "whod've": "who'd've", "who'dve": "who'd've", "wholl": 69 | "who'll", "whos": "who's", "whove": "who've", "whyll": "why'll", 70 | "whyre": "why're", "whys": "why's", "wont": "won't", "wouldve": 71 | "would've", "wouldnt": "wouldn't", "wouldnt've": "wouldn't've", 72 | "wouldn'tve": "wouldn't've", "yall": "y'all", "yall'll": 73 | "y'all'll", "y'allll": "y'all'll", "yall'd've": "y'all'd've", 74 | "y'alld've": "y'all'd've", "y'all'dve": "y'all'd've", "youd": 75 | "you'd", "youd've": "you'd've", "you'dve": "you'd've", "youll": 76 | "you'll", "youre": "you're", "youve": "you've" 77 | } 78 
| 79 | manual_map = { 'none': '0', 80 | 'zero': '0', 81 | 'one': '1', 82 | 'two': '2', 83 | 'three': '3', 84 | 'four': '4', 85 | 'five': '5', 86 | 'six': '6', 87 | 'seven': '7', 88 | 'eight': '8', 89 | 'nine': '9', 90 | 'ten': '10'} 91 | articles = ['a', 'an', 'the'] 92 | period_strip = re.compile("(?!<=\d)(\.)(?!\d)") 93 | comma_strip = re.compile("(\d)(\,)(\d)") 94 | punct = [';', r"/", '[', ']', '"', '{', '}', 95 | '(', ')', '=', '+', '\\', '_', '-', 96 | '>', '<', '@', '`', ',', '?', '!'] 97 | 98 | def process_punctuation(inText): 99 | outText = inText 100 | for p in punct: 101 | if (p + ' ' in inText or ' ' + p in inText) \ 102 | or (re.search(comma_strip, inText) != None): 103 | outText = outText.replace(p, '') 104 | else: 105 | outText = outText.replace(p, ' ') 106 | outText = period_strip.sub("", outText, re.UNICODE) 107 | return outText 108 | 109 | 110 | def process_digit_article(inText): 111 | outText = [] 112 | tempText = inText.lower().split() 113 | for word in tempText: 114 | word = manual_map.setdefault(word, word) 115 | if word not in articles: 116 | outText.append(word) 117 | else: 118 | pass 119 | for wordId, word in enumerate(outText): 120 | if word in contractions: 121 | outText[wordId] = contractions[word] 122 | outText = ' '.join(outText) 123 | return outText 124 | 125 | 126 | def prep_ans(answer): 127 | answer = process_digit_article(process_punctuation(answer)) 128 | answer = answer.replace(',', '') 129 | return answer 130 | 131 | 132 | class Scale(object): 133 | """Rescale the input PIL.Image to the given size. 134 | Args: 135 | size (sequence or int): Desired output size. If size is a sequence like 136 | (w, h), output size will be matched to this. If size is an int, 137 | smaller edge of the image will be matched to this number. 138 | i.e, if height > width, then image will be rescaled to 139 | (size * height / width, size) 140 | interpolation (int, optional): Desired interpolation. Default is 141 | ``PIL.Image.BILINEAR`` 142 | """ 143 | 144 | def __init__(self, size, interpolation=Image.BILINEAR): 145 | assert isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2) 146 | self.size = size 147 | self.interpolation = interpolation 148 | 149 | def __call__(self, img): 150 | """ 151 | Args: 152 | img (PIL.Image): Image to be scaled. 153 | Returns: 154 | PIL.Image: Rescaled image. 
155 | """ 156 | if isinstance(self.size, int): 157 | w, h = img.size 158 | if (w <= h and w == self.size) or (h <= w and h == self.size): 159 | return img 160 | if w < h: 161 | ow = self.size 162 | oh = int(self.size * h / w) 163 | return img.resize((ow, oh), self.interpolation) 164 | else: 165 | oh = self.size 166 | ow = int(self.size * w / h) 167 | return img.resize((ow, oh), self.interpolation) 168 | else: 169 | return img.resize(self.size, self.interpolation) 170 | 171 | 172 | @register('gqa') 173 | class GQA(Dataset): 174 | def __init__(self, root_dir=None, split='train', eval_mode=0, **kwargs) -> None: 175 | self.eval_mode = eval_mode 176 | self.FRCN_FEAT_SIZE = utils.MCAN_GQA_PARAMS['FRCN_FEAT_SIZE'] 177 | self.GRID_FEAT_SIZE = utils.MCAN_GQA_PARAMS['GRID_FEAT_SIZE'] 178 | self.BBOX_FEAT_SIZE = utils.MCAN_GQA_PARAMS['BBOX_FEAT_SIZE'] 179 | self.root_dir = root_dir 180 | self.split = split 181 | self.dict_json = osp.join(self.root_dir, 'dicts.json') 182 | 183 | if self.eval_mode: 184 | # object/edge/degree 185 | fn = osp.join(self.root_dir, 'raw', 'questions1.2', '{}_sys_reduced_questions.json'.format(self.split)) 186 | fn_concept = osp.join(self.root_dir, 'raw', 'questions1.2', 'train_sys_reduced_concepts.json'.format(self.split)) 187 | else: 188 | if self.split is 'val': 189 | split = 'val' 190 | else: 191 | split = self.split 192 | fn = osp.join(self.root_dir, 'raw', 'questions1.2', '{}_balanced_questions.json'.format(split)) 193 | fn_concept = osp.join(self.root_dir, 'raw', 'questions1.2', 'train_balanced_concepts.json'.format(self.split)) 194 | with open(fn, 'r') as f: 195 | self.ques_dict = json.load(f) 196 | self.qid_list = list(self.ques_dict.keys()) 197 | with open(fn_concept, 'r') as f: 198 | self.concept_dict = json.load(f) 199 | 200 | def img_feat_path_load(path_list): 201 | iid_to_path = {} 202 | for ix, path in enumerate(path_list): 203 | iid = path.split('/')[-1].split('.')[0] 204 | iid_to_path[iid] = path 205 | 206 | return iid_to_path 207 | self.iid_to_img_path = img_feat_path_load(glob.glob(osp.join( 208 | self.root_dir, 209 | 'images', 210 | '*.jpg' 211 | ))) 212 | self.data_size = self.ques_dict.__len__() 213 | self.num_concept = 1615 214 | 215 | # Tokenize 216 | # self.token_to_ix, self.pretrained_emb, max_token = self.tokenize(self.dict_json, True) 217 | self.token_to_ix, _, max_token = self.tokenize(self.dict_json, True) 218 | self.token_size = self.token_to_ix.__len__() 219 | 220 | self.max_token = -1 221 | if self.max_token == -1: 222 | self.max_token = max_token 223 | 224 | # Answers statistic 225 | self.ans_to_ix, self.ix_to_ans = self.ans_stat(self.dict_json) 226 | self.ans_size = self.ans_to_ix.__len__() 227 | 228 | # For RGB 229 | self.pix_mean = (0.485, 0.456, 0.406) 230 | self.pix_std = (0.229, 0.224, 0.225) 231 | self.transform = transforms.Compose([ 232 | transforms.ToTensor(), 233 | transforms.Normalize(mean=self.pix_mean, std=self.pix_std) 234 | ]) 235 | # For two-view denseCL 236 | # https://github.com/WXinlong/DenseCL/blob/main/configs/selfsup/densecl/densecl_coco_800ep.py 237 | self.densecl_aug = transforms.Compose([ 238 | transforms.RandomResizedCrop(224, scale=(0.2, 1.)), 239 | transforms.RandomApply(torch.nn.ModuleList([ 240 | transforms.ColorJitter( 241 | brightness=0.4, 242 | contrast=0.4, 243 | saturation=0.4, 244 | hue=0.1, 245 | )]), 246 | p=0.8 247 | ), 248 | transforms.RandomGrayscale(p=0.2), 249 | transforms.RandomApply(torch.nn.ModuleList([ 250 | transforms.GaussianBlur( 251 | # 
https://github.com/sthalles/SimCLR/blob/master/data_aug/contrastive_learning_dataset.py#L20 252 | kernel_size=(int(0.1*224)//2)*2+1, 253 | sigma=(0.1, 2.0) 254 | )]), 255 | p=0.5 256 | ), 257 | transforms.RandomHorizontalFlip(), 258 | ]) 259 | 260 | 261 | def __len__(self): 262 | return self.data_size 263 | 264 | def __getitem__(self, idx): 265 | ques_ix_iter, ans_iter, iid = self.load_ques_ans(idx, raw_q=False) 266 | img = Image.open(self.iid_to_img_path[iid]).convert('RGB') 267 | img = self.transform(img) 268 | second_im = self.densecl_aug(img) 269 | third_im = self.densecl_aug(img) 270 | 271 | concept = np.array(self.load_concept(idx), dtype=np.int64) 272 | concept = label_to_onehot(torch.from_numpy(concept).long(), self.num_concept) 273 | 274 | # len(ret), ret[-3] 275 | # 7, True 276 | ret = [] 277 | ret.append(img) 278 | ret.append(second_im) 279 | ret.append(third_im) 280 | ret.append(torch.from_numpy(ques_ix_iter)) 281 | ret.append(True) 282 | ret.append(torch.from_numpy(ans_iter)) 283 | ret.append(concept) 284 | 285 | return tuple(ret) 286 | 287 | def tokenize(self, json_file, use_glove): 288 | token_to_ix, max_token = json.load(open(json_file, 'r'))[2:] 289 | # spacy_tool = None 290 | # if use_glove: 291 | # spacy_tool = en_vectors_web_lg.load() 292 | 293 | # pretrained_emb = [] 294 | # for word in token_to_ix: 295 | # if use_glove: 296 | # pretrained_emb.append(spacy_tool(word).vector) 297 | # pretrained_emb = np.array(pretrained_emb) 298 | 299 | pretrained_emb = None 300 | return token_to_ix, pretrained_emb, max_token 301 | 302 | def ans_stat(self, json_file): 303 | ans_to_ix, ix_to_ans = json.load(open(json_file, 'r'))[:2] 304 | 305 | return ans_to_ix, ix_to_ans 306 | 307 | def load_concept(self, idx): 308 | qid = self.qid_list[idx] 309 | if self.split == 'train': 310 | ret = self.concept_dict[qid] 311 | else: 312 | ret = [] 313 | return ret 314 | 315 | def load_ques_ans(self, idx, raw_q=False): 316 | 317 | qid = self.qid_list[idx] 318 | iid = self.ques_dict[qid]['imageId'] 319 | 320 | ques = self.ques_dict[qid]['question'] 321 | if raw_q: 322 | ques_ix_iter = ques 323 | else: 324 | ques_ix_iter = self.proc_ques(ques, self.token_to_ix, max_token=self.max_token) 325 | 326 | # process answers 327 | ans = self.ques_dict[qid]['answer'] 328 | ans_iter = self.proc_ans(ans, self.ans_to_ix) 329 | 330 | return ques_ix_iter, ans_iter, iid 331 | 332 | def proc_img_feat(self, img_feat, img_feat_pad_size): 333 | if img_feat.shape[0] > img_feat_pad_size: 334 | img_feat = img_feat[:img_feat_pad_size] 335 | 336 | img_feat = np.pad( 337 | img_feat, 338 | ((0, img_feat_pad_size - img_feat.shape[0]), (0, 0)), 339 | mode='constant', 340 | constant_values=0 341 | ) 342 | 343 | return img_feat 344 | 345 | def proc_bbox_feat(self, bbox, img_shape): 346 | bbox_feat = np.zeros((bbox.shape[0], 5), dtype=np.float32) 347 | 348 | bbox_feat[:, 0] = bbox[:, 0] / float(img_shape[1]) 349 | bbox_feat[:, 1] = bbox[:, 1] / float(img_shape[0]) 350 | bbox_feat[:, 2] = bbox[:, 2] / float(img_shape[1]) 351 | bbox_feat[:, 3] = bbox[:, 3] / float(img_shape[0]) 352 | bbox_feat[:, 4] = (bbox[:, 2] - bbox[:, 0]) * (bbox[:, 3] - bbox[:, 1]) / float(img_shape[0] * img_shape[1]) 353 | 354 | return bbox_feat 355 | 356 | def proc_ques(self, ques, token_to_ix, max_token): 357 | ques_ix = np.zeros(max_token, np.int64) 358 | 359 | words = re.sub( 360 | r"([.,'!?\"()*#:;])", 361 | '', 362 | ques.lower() 363 | ).replace('-', ' ').replace('/', ' ').split() 364 | 365 | for ix, word in enumerate(words): 366 | if word in token_to_ix: 
367 | ques_ix[ix] = token_to_ix[word] 368 | else: 369 | ques_ix[ix] = token_to_ix['UNK'] 370 | 371 | if ix + 1 == max_token: 372 | break 373 | 374 | return ques_ix 375 | 376 | def proc_ans(self, ans, ans_to_ix): 377 | ans_ix = np.zeros(1, np.int64) 378 | ans = prep_ans(ans) 379 | ans_ix[0] = ans_to_ix[ans] 380 | 381 | return ans_ix 382 | 383 | 384 | def collate_gqa(batch): 385 | def _pad_tensor(tensor_list): 386 | max_imh, max_imw = -1, -1 387 | for tensor_i in tensor_list: 388 | # import pdb; pdb.set_trace() 389 | imh, imw = tensor_i.shape[-2], tensor_i.shape[-1] 390 | max_imh = max(max_imh, imh) 391 | max_imw = max(max_imw, imw) 392 | 393 | for idx, tensor_i in enumerate(tensor_list): 394 | pad_tensor_i = tensor_i.new_full(list(tensor_i.shape[:-2]) + [max_imh, max_imw], 0) 395 | imh, imw = tensor_i.shape[-2], tensor_i.shape[-1] 396 | pad_tensor_i[..., :imh, :imw].copy_(tensor_i) 397 | tensor_list[idx] = pad_tensor_i 398 | return tensor_list 399 | 400 | # len(ret), ret[-3] 401 | # 7, True 402 | if len(batch[0][0].shape) == 2: 403 | return default_collate(batch) 404 | else: 405 | list_ims = [] 406 | list_second_ims = [] 407 | list_third_ims = [] 408 | list_qs = [] 409 | list_xs = [] 410 | list_as = [] 411 | list_cs = [] 412 | for b in batch: 413 | list_ims.append(b[0]) 414 | list_second_ims.append(b[1]) 415 | list_third_ims.append(b[2]) 416 | list_qs.append(b[3]) 417 | list_xs.append(b[4]) 418 | list_as.append(b[5]) 419 | list_cs.append(b[6]) 420 | list_ims = torch.stack(_pad_tensor(list_ims)) 421 | list_second_ims = torch.stack(_pad_tensor(list_second_ims)) 422 | list_third_ims = torch.stack(_pad_tensor(list_third_ims)) 423 | if list_xs[0]: 424 | list_qs = torch.stack(list_qs) 425 | list_as = torch.stack(list_as) 426 | list_cs = torch.stack(list_cs) 427 | return list_ims, list_second_ims, list_third_ims, list_qs, list_xs, list_as, list_cs 428 | -------------------------------------------------------------------------------- /datasets/image_hicodet_bbox.py: -------------------------------------------------------------------------------- 1 | # ---------------------------------------------------------------------- 2 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This work is licensed under the NVIDIA Source Code License 5 | # for RelViT. To view a copy of this license, see the LICENSE file. 
6 | # ---------------------------------------------------------------------- 7 | 8 | import os 9 | import os.path as osp 10 | import json 11 | import pickle 12 | import numpy as np 13 | from PIL import Image, ImageFilter 14 | import sklearn.metrics 15 | 16 | import cv2 17 | from typing import Any, Callable, List, Optional, Tuple 18 | import torch 19 | from torch.utils.data import Dataset 20 | from torchvision import transforms 21 | 22 | # from detectron2 23 | from detectron2.structures import Boxes 24 | from detectron2.data import transforms as T 25 | 26 | import utils 27 | from .datasets import register 28 | 29 | 30 | def label_to_onehot(label, num_class): 31 | # label: [N] 32 | onehot = torch.zeros(num_class).to(label) 33 | for i in label: 34 | onehot[i] = 1 35 | return onehot 36 | 37 | 38 | @register('hicodet') 39 | class HICODet(Dataset): 40 | def __init__(self, image_size=256, box_size=256, eval_mode=0, **kwargs) -> None: 41 | self.eval_mode = eval_mode 42 | im_dir = kwargs.get('im_dir') 43 | self.split = kwargs.get('split') # train or test 44 | self._root = osp.join(im_dir, 'hico_20160224_det', 'images', '{}2015'.format(self.split)) 45 | # rare 46 | if self.eval_mode == 1: 47 | anno_file = osp.join(im_dir, 'hico_20160224_det', 'sys_vcl_rare_instances_{}2015.json'.format(self.split)) 48 | # non-rare 49 | elif self.eval_mode == 2: 50 | anno_file = osp.join(im_dir, 'hico_20160224_det', 'sys_vcl_nonrare_instances_{}2015.json'.format(self.split)) 51 | else: 52 | anno_file = osp.join(im_dir, 'hico_20160224_det', 'instances_{}2015.json'.format(self.split)) 53 | 54 | with open(anno_file, 'r') as f: 55 | anno = json.load(f) 56 | 57 | self.num_object_cls = 80 58 | self.num_interaction_cls = 600 59 | self.num_action_cls = 117 60 | self._anno_file = anno_file 61 | 62 | # Load annotations 63 | self._load_annotation_and_metadata(anno) 64 | 65 | self.pix_mean = (0.485, 0.456, 0.406) 66 | self.pix_std = (0.229, 0.224, 0.225) 67 | # detectron2-style data augmentation 68 | sample_style = 'range' 69 | augmentations = [T.ResizeShortestEdge(image_size, int(image_size * 2), sample_style)] 70 | if kwargs.get('augment') or kwargs.get('augment_plus'): 71 | augmentations.append( 72 | T.RandomFlip( 73 | horizontal=True, 74 | vertical=False, 75 | ) 76 | ) 77 | if kwargs.get('augment_plus'): 78 | self.photo_aug = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5/3.14) 79 | else: 80 | self.photo_aug = None 81 | self.augmentations = T.AugmentationList(augmentations) 82 | 83 | # For two-view denseCL 84 | # https://github.com/WXinlong/DenseCL/blob/main/configs/selfsup/densecl/densecl_coco_800ep.py 85 | self.densecl_aug = transforms.Compose([ 86 | transforms.RandomResizedCrop(224, scale=(0.2, 1.)), 87 | transforms.RandomApply(torch.nn.ModuleList([ 88 | transforms.ColorJitter( 89 | brightness=0.4, 90 | contrast=0.4, 91 | saturation=0.4, 92 | hue=0.1, 93 | )]), 94 | p=0.8 95 | ), 96 | transforms.RandomGrayscale(p=0.2), 97 | transforms.RandomApply(torch.nn.ModuleList([ 98 | transforms.GaussianBlur( 99 | # https://github.com/sthalles/SimCLR/blob/master/data_aug/contrastive_learning_dataset.py#L20 100 | kernel_size=(int(0.1*224)//2)*2+1, 101 | sigma=(0.1, 2.0) 102 | )]), 103 | p=0.5 104 | ), 105 | transforms.RandomHorizontalFlip(), 106 | ]) 107 | 108 | def __len__(self) -> int: 109 | """Return the number of images""" 110 | return len(self._idx) 111 | 112 | def __getitem__(self, i: int) -> tuple: 113 | """ 114 | Arguments: 115 | i(int): Index to an image 116 | """ 117 | ind = self._idx[i] 118 | 
im = cv2.imread(osp.join(self._root, self._filenames[ind])) 119 | assert im is not None 120 | # BGR to RGB 121 | im = im[:, :, ::-1] 122 | # TODO: only keeps the first bbox pair 123 | sub_bbox = torch.Tensor([self._anno[ind]['boxes_h'][0]]) 124 | obj_bbox = torch.Tensor([self._anno[ind]['boxes_o'][0]]) 125 | union_bbox = torch.cat((torch.min(sub_bbox[0, :2], obj_bbox[0, :2]), torch.max(sub_bbox[0, 2:], obj_bbox[0, 2:]))).unsqueeze(0) 126 | boxes = torch.stack((sub_bbox, obj_bbox, union_bbox), dim=0) 127 | aug_input = T.AugInput(im, boxes=boxes) 128 | transforms = self.augmentations(aug_input) 129 | im, boxes = aug_input.image, torch.as_tensor(aug_input.boxes) 130 | im = im.astype(np.float32) 131 | 132 | def to_tensor(im, pix_mean, pix_std, normalize=True): 133 | if normalize: 134 | for i in range(3): 135 | im[:, :, i] = (im[:, :, i] / 255. - pix_mean[i]) / pix_std[i] 136 | im = torch.as_tensor(np.ascontiguousarray(im.transpose(2, 0, 1))).float() 137 | return im 138 | 139 | second_im = im.copy() 140 | third_im = im.copy() 141 | # Augment&tensorize the main view 142 | if self.photo_aug is not None: 143 | # color jittering of the input image 144 | im = np.array(self.photo_aug(Image.fromarray(im.astype(np.uint8))), dtype=np.float32) 145 | im = to_tensor(im, self.pix_mean, self.pix_std) 146 | 147 | # Augment&tensorize the second view 148 | second_im = np.array(self.densecl_aug(Image.fromarray(second_im.astype(np.uint8))), dtype=np.float32) 149 | second_im = to_tensor(second_im, self.pix_mean, self.pix_std) 150 | 151 | # Augment&tensorize the third view 152 | third_im = np.array(self.densecl_aug(Image.fromarray(third_im.astype(np.uint8))), dtype=np.float32) 153 | third_im = to_tensor(third_im, self.pix_mean, self.pix_std) 154 | 155 | hoi = label_to_onehot(torch.Tensor(self._anno[ind]['hoi']).long(), self.num_interaction_cls) 156 | verb = label_to_onehot(torch.Tensor(self._anno[ind]['verb']).long(), self.num_action_cls) 157 | object = label_to_onehot(torch.Tensor(self._anno[ind]['object']).long(), self.num_object_cls) 158 | return im, second_im, third_im, boxes, hoi, verb, object 159 | 160 | def _load_annotation_and_metadata(self, f: dict) -> None: 161 | """ 162 | Arguments: 163 | f(dict): Dictionary loaded from {anno_file}.json 164 | """ 165 | idx = list(range(len(f['filenames']))) 166 | for empty_idx in f['empty']: 167 | idx.remove(empty_idx) 168 | 169 | num_anno = [0 for _ in range(self.num_interaction_cls)] 170 | for anno in f['annotation']: 171 | for hoi in anno['hoi']: 172 | num_anno[hoi] += 1 173 | 174 | self._idx = idx 175 | self._num_anno = num_anno 176 | 177 | self._anno = f['annotation'] 178 | self._filenames = f['filenames'] 179 | self._image_sizes = f['size'] 180 | self._class_corr = f['correspondence'] 181 | self._empty_idx = f['empty'] 182 | self._objects = f['objects'] 183 | self._verbs = f['verbs'] 184 | 185 | def _ood_split(self): 186 | pass 187 | 188 | def compute_map_hico(y_true, y_score, easy=False, hard=False, rare_only=False): 189 | unseen_hoi_nonrare = np.array([38, 41, 20, 18, 245, 11, 19, 154, 459, 42, 155, 139, 60, 461, 577, 153, 582, 89, 141, 576, 75, 212, 472, 61, 190 | 457, 146, 208, 94, 471, 131, 248, 544, 515, 566, 370, 481, 226, 250, 470, 323, 169, 480, 479, 230, 385, 73, 191 | 159, 190, 377, 176, 249, 371, 284, 48, 583, 53, 162, 140, 185, 106, 294, 56, 320, 152, 374, 338, 29, 594, 346, 192 | 456, 589, 45, 23, 67, 478, 223, 493, 228, 240, 215, 91, 115, 337, 559, 7, 218, 518, 297, 191, 266, 304, 6, 572, 193 | 529, 312, 9, 308, 417, 197, 193, 163, 455, 25, 54, 575, 
446, 387, 483, 534, 340, 508, 110, 329, 246, 173, 506, 194 | 383, 93, 516, 64]) 195 | 196 | unseen_hoi_rare = np.array([509, 279, 280, 402, 504, 286, 499, 498, 289, 485, 303, 311, 325, 439, 351, 358, 66, 427, 379, 418, 70, 416, 197 | 389, 90, 395, 76, 397, 84, 135, 262, 401, 592, 560, 586, 548, 593, 526, 181, 257, 539, 535, 260, 596, 345, 189, 198 | 205, 206, 429, 179, 350, 405, 522, 449, 261, 255, 546, 547, 44, 22, 334, 599, 239, 315, 317, 229, 158, 195, 199 | 238, 364, 222, 281, 149, 399, 83, 127, 254, 398, 403, 555, 552, 520, 531, 440, 436, 482, 274, 8, 188, 216, 597, 200 | 77, 407, 556, 469, 474, 107, 390, 410, 27, 381, 463, 99, 184, 100, 292, 517, 80, 333, 62, 354, 104, 55, 50, 201 | 198, 168, 391, 192, 595, 136, 581]) 202 | rare_hoi = np.load('./cache/rare.npy') 203 | if easy: 204 | index = unseen_hoi_rare 205 | if hard: 206 | index = unseen_hoi_nonrare 207 | if rare_only: 208 | index = rare_hoi 209 | if easy or hard or rare_only: 210 | y_true = y_true[:, index] 211 | y_score = y_score[:, index] 212 | meter = utils.AveragePrecisionMeter(algorithm='AUC', output=torch.Tensor(y_score), labels=torch.Tensor(y_true)) 213 | return meter.eval().mean() 214 | 215 | 216 | def collate_images_boxes_dict(batch): 217 | def _pad_tensor(tensor_list): 218 | max_imh, max_imw = -1, -1 219 | for tensor_i in tensor_list: 220 | # import pdb; pdb.set_trace() 221 | imh, imw = tensor_i.shape[-2], tensor_i.shape[-1] 222 | max_imh = max(max_imh, imh) 223 | max_imw = max(max_imw, imw) 224 | 225 | for idx, tensor_i in enumerate(tensor_list): 226 | pad_tensor_i = tensor_i.new_full(list(tensor_i.shape[:-2]) + [max_imh, max_imw], 0) 227 | imh, imw = tensor_i.shape[-2], tensor_i.shape[-1] 228 | pad_tensor_i[..., :imh, :imw].copy_(tensor_i) 229 | tensor_list[idx] = pad_tensor_i 230 | return tensor_list 231 | list_ims = [] 232 | list_second_ims = [] 233 | list_third_ims = [] 234 | list_boxes = [] 235 | list_hois = [] 236 | list_verbs = [] 237 | list_objects = [] 238 | for b in batch: 239 | list_ims.append(b[0]) 240 | list_second_ims.append(b[1]) 241 | list_third_ims.append(b[2]) 242 | list_boxes.append(b[3]) 243 | list_hois.append(b[4]) 244 | list_verbs.append(b[5]) 245 | list_objects.append(b[6]) 246 | 247 | list_ims = _pad_tensor(list_ims) 248 | list_second_ims = _pad_tensor(list_second_ims) 249 | list_third_ims = _pad_tensor(list_third_ims) 250 | return torch.stack(list_ims), torch.stack(list_second_ims), torch.stack(list_third_ims), torch.stack(list_boxes), torch.stack(list_hois), torch.stack(list_verbs), torch.stack(list_objects) 251 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | # ---------------------------------------------------------------------- 2 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This work is licensed under the NVIDIA Source Code License 5 | # for RelViT. To view a copy of this license, see the LICENSE file. 6 | # ---------------------------------------------------------------------- 7 | 8 | from .models import make, load 9 | from . import classifier 10 | from . import transparent_encoder 11 | from . import pvt_v2 12 | from . import swin_transformer 13 | from . import vit 14 | from . 
import mcan -------------------------------------------------------------------------------- /models/classifier.py: -------------------------------------------------------------------------------- 1 | # ---------------------------------------------------------------------- 2 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This work is licensed under the NVIDIA Source Code License 5 | # for RelViT. To view a copy of this license, see the LICENSE file. 6 | # ---------------------------------------------------------------------- 7 | 8 | import math 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | 14 | import models 15 | import utils 16 | from .models import register 17 | 18 | 19 | @register('classifier') 20 | class Classifier(nn.Module): 21 | 22 | def __init__(self, encoder, encoder_args, 23 | classifier, classifier_args): 24 | super().__init__() 25 | self.encoder = models.make(encoder, **encoder_args) 26 | classifier_args['in_dim'] = self.encoder.out_dim 27 | self.classifier = models.make(classifier, **classifier_args) 28 | 29 | def forward(self, x, boxes=None, info_nce=False): 30 | if info_nce: 31 | x, attn_v = self.encoder(x, boxes, info_nce=info_nce) 32 | x = self.classifier(x) 33 | return x, attn_v 34 | else: 35 | x = self.encoder(x, boxes) 36 | x = self.classifier(x) 37 | return x 38 | 39 | 40 | @register('linear-classifier') 41 | class LinearClassifier(nn.Module): 42 | 43 | def __init__(self, in_dim, n_classes): 44 | super().__init__() 45 | self.linear = nn.Linear(in_dim, n_classes) 46 | 47 | def forward(self, x): 48 | return self.linear(x) 49 | 50 | # Assume input is [B, C, H, W] 51 | @register('max-pooling-classifier') 52 | class MaxPoolingClassifier(nn.Module): 53 | 54 | def __init__(self, in_dim, n_classes, **kwargs): 55 | super().__init__() 56 | self.proj = nn.Sequential( 57 | nn.LayerNorm(in_dim), 58 | nn.Linear(in_dim, n_classes) 59 | ) 60 | 61 | def forward(self, x): 62 | # x: [B, C, W, H] 63 | B, C = x.size(0), x.size(1) 64 | x = x.reshape(B, C, -1).max(-1)[0] 65 | # x = x.reshape(B, C, -1)[:, :, 0] 66 | return self.proj(x) 67 | 68 | 69 | # Assume input is [B, C, H, W] 70 | @register('max-pooling-classifier-twoheads') 71 | class MaxPoolingClassifierTwoHeads(nn.Module): 72 | 73 | def __init__(self, in_dim, n_classes1, n_classes2, **kwargs): 74 | super().__init__() 75 | self.proj1 = nn.Sequential( 76 | nn.LayerNorm(in_dim), 77 | nn.Linear(in_dim, n_classes1) 78 | ) 79 | self.proj2 = nn.Sequential( 80 | nn.LayerNorm(in_dim), 81 | nn.Linear(in_dim, n_classes2) 82 | ) 83 | 84 | def forward(self, x): 85 | # x: [B, C, W, H] 86 | B, C = x.size(0), x.size(1) 87 | x = x.reshape(B, C, -1).max(-1)[0] 88 | # x = x.reshape(B, C, -1)[:, :, 0] 89 | return (self.proj1(x), self.proj2(x)) 90 | 91 | 92 | @register('nn-classifier') 93 | class NNClassifier(nn.Module): 94 | 95 | def __init__(self, in_dim, n_classes, metric='cos', temp=None): 96 | super().__init__() 97 | self.proto = nn.Parameter(torch.empty(n_classes, in_dim)) 98 | nn.init.kaiming_uniform_(self.proto, a=math.sqrt(5)) 99 | if temp is None: 100 | if metric == 'cos': 101 | temp = nn.Parameter(torch.tensor(10.)) 102 | else: 103 | temp = 1.0 104 | self.metric = metric 105 | self.temp = temp 106 | 107 | def forward(self, x): 108 | return utils.compute_logits(x, self.proto, self.metric, self.temp) 109 | -------------------------------------------------------------------------------- /models/mcan.py: 
-------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This file has been modified from a file in the mcan-vqa library 5 | # which was released under the Apache Licence. 6 | # 7 | # Source: 8 | # https://github.com/MILVLG/mcan-vqa 9 | # 10 | # The license for the original version of this file can be 11 | # found in https://github.com/MILVLG/mcan-vqa/blob/master/LICENSE 12 | # The modifications to this file are subject to the same Apache Licence. 13 | # --------------------------------------------------------------- 14 | 15 | import math 16 | 17 | import numpy as np 18 | import torch 19 | import torch.nn as nn 20 | import torch.nn.functional as F 21 | 22 | import models 23 | import utils 24 | from .models import register 25 | 26 | 27 | def make_mask(feature): 28 | return (torch.sum( 29 | torch.abs(feature), 30 | dim=-1 31 | ) == 0).unsqueeze(1).unsqueeze(2) 32 | 33 | 34 | class FC(nn.Module): 35 | def __init__(self, in_size, out_size, dropout_r=0., use_relu=True): 36 | super(FC, self).__init__() 37 | self.dropout_r = dropout_r 38 | self.use_relu = use_relu 39 | 40 | self.linear = nn.Linear(in_size, out_size) 41 | 42 | if use_relu: 43 | self.relu = nn.ReLU(inplace=True) 44 | 45 | if dropout_r > 0: 46 | self.dropout = nn.Dropout(dropout_r) 47 | 48 | def forward(self, x): 49 | x = self.linear(x) 50 | 51 | if self.use_relu: 52 | x = self.relu(x) 53 | 54 | if self.dropout_r > 0: 55 | x = self.dropout(x) 56 | 57 | return x 58 | 59 | 60 | class MLP(nn.Module): 61 | def __init__(self, in_size, mid_size, out_size, dropout_r=0., use_relu=True): 62 | super(MLP, self).__init__() 63 | 64 | self.fc = FC(in_size, mid_size, dropout_r=dropout_r, use_relu=use_relu) 65 | self.linear = nn.Linear(mid_size, out_size) 66 | 67 | def forward(self, x): 68 | return self.linear(self.fc(x)) 69 | 70 | 71 | class LayerNorm(nn.Module): 72 | def __init__(self, size, eps=1e-6): 73 | super(LayerNorm, self).__init__() 74 | self.eps = eps 75 | 76 | self.a_2 = nn.Parameter(torch.ones(size)) 77 | self.b_2 = nn.Parameter(torch.zeros(size)) 78 | 79 | def forward(self, x): 80 | mean = x.mean(-1, keepdim=True) 81 | std = x.std(-1, keepdim=True) 82 | 83 | return self.a_2 * (x - mean) / (std + self.eps) + self.b_2 84 | 85 | 86 | # ------------------------------ 87 | # ---- Multi-Head Attention ---- 88 | # ------------------------------ 89 | 90 | class MHAtt(nn.Module): 91 | def __init__(self): 92 | super(MHAtt, self).__init__() 93 | 94 | self.linear_v = nn.Linear(utils.MCAN_GQA_PARAMS['HIDDEN_SIZE'], utils.MCAN_GQA_PARAMS['HIDDEN_SIZE']) 95 | self.linear_k = nn.Linear(utils.MCAN_GQA_PARAMS['HIDDEN_SIZE'], utils.MCAN_GQA_PARAMS['HIDDEN_SIZE']) 96 | self.linear_q = nn.Linear(utils.MCAN_GQA_PARAMS['HIDDEN_SIZE'], utils.MCAN_GQA_PARAMS['HIDDEN_SIZE']) 97 | self.linear_merge = nn.Linear(utils.MCAN_GQA_PARAMS['HIDDEN_SIZE'], utils.MCAN_GQA_PARAMS['HIDDEN_SIZE']) 98 | 99 | self.dropout = nn.Dropout(utils.MCAN_GQA_PARAMS['DROPOUT_R']) 100 | 101 | def forward(self, v, k, q, mask): 102 | n_batches = q.size(0) 103 | 104 | v = self.linear_v(v).view( 105 | n_batches, 106 | -1, 107 | utils.MCAN_GQA_PARAMS['MULTI_HEAD'], 108 | int(utils.MCAN_GQA_PARAMS['HIDDEN_SIZE'] / utils.MCAN_GQA_PARAMS['MULTI_HEAD']) 109 | ).transpose(1, 2) 110 | 111 | k = self.linear_k(k).view( 112 | n_batches, 113 | -1, 114 | utils.MCAN_GQA_PARAMS['MULTI_HEAD'], 115 | 
int(utils.MCAN_GQA_PARAMS['HIDDEN_SIZE'] / utils.MCAN_GQA_PARAMS['MULTI_HEAD']) 116 | ).transpose(1, 2) 117 | 118 | q = self.linear_q(q).view( 119 | n_batches, 120 | -1, 121 | utils.MCAN_GQA_PARAMS['MULTI_HEAD'], 122 | int(utils.MCAN_GQA_PARAMS['HIDDEN_SIZE'] / utils.MCAN_GQA_PARAMS['MULTI_HEAD']) 123 | ).transpose(1, 2) 124 | 125 | atted = self.att(v, k, q, mask) 126 | atted = atted.transpose(1, 2).contiguous().view( 127 | n_batches, 128 | -1, 129 | utils.MCAN_GQA_PARAMS['HIDDEN_SIZE'] 130 | ) 131 | 132 | atted = self.linear_merge(atted) 133 | 134 | return atted 135 | 136 | def att(self, value, key, query, mask): 137 | d_k = query.size(-1) 138 | 139 | scores = torch.matmul( 140 | query, key.transpose(-2, -1) 141 | ) / math.sqrt(d_k) 142 | 143 | if mask is not None: 144 | scores = scores.masked_fill(mask, -1e9) 145 | 146 | att_map = F.softmax(scores, dim=-1) 147 | att_map = self.dropout(att_map) 148 | 149 | return torch.matmul(att_map, value) 150 | 151 | 152 | # --------------------------- 153 | # ---- Feed Forward Nets ---- 154 | # --------------------------- 155 | 156 | class FFN(nn.Module): 157 | def __init__(self): 158 | super(FFN, self).__init__() 159 | 160 | self.mlp = MLP( 161 | in_size=utils.MCAN_GQA_PARAMS['HIDDEN_SIZE'], 162 | mid_size=utils.MCAN_GQA_PARAMS['FF_SIZE'], 163 | out_size=utils.MCAN_GQA_PARAMS['HIDDEN_SIZE'], 164 | dropout_r=utils.MCAN_GQA_PARAMS['DROPOUT_R'], 165 | use_relu=True 166 | ) 167 | 168 | def forward(self, x): 169 | return self.mlp(x) 170 | 171 | 172 | # ------------------------ 173 | # ---- Self Attention ---- 174 | # ------------------------ 175 | 176 | class SA(nn.Module): 177 | def __init__(self): 178 | super(SA, self).__init__() 179 | 180 | self.mhatt = MHAtt() 181 | self.ffn = FFN() 182 | 183 | self.dropout1 = nn.Dropout(utils.MCAN_GQA_PARAMS['DROPOUT_R']) 184 | self.norm1 = LayerNorm(utils.MCAN_GQA_PARAMS['HIDDEN_SIZE']) 185 | 186 | self.dropout2 = nn.Dropout(utils.MCAN_GQA_PARAMS['DROPOUT_R']) 187 | self.norm2 = LayerNorm(utils.MCAN_GQA_PARAMS['HIDDEN_SIZE']) 188 | 189 | def forward(self, y, y_mask): 190 | y = self.norm1(y + self.dropout1( 191 | self.mhatt(y, y, y, y_mask) 192 | )) 193 | 194 | y = self.norm2(y + self.dropout2( 195 | self.ffn(y) 196 | )) 197 | 198 | return y 199 | 200 | 201 | # ------------------------------- 202 | # ---- Self Guided Attention ---- 203 | # ------------------------------- 204 | 205 | class SGA(nn.Module): 206 | def __init__(self): 207 | super(SGA, self).__init__() 208 | 209 | self.mhatt1 = MHAtt() 210 | self.mhatt2 = MHAtt() 211 | self.ffn = FFN() 212 | 213 | self.dropout1 = nn.Dropout(utils.MCAN_GQA_PARAMS['DROPOUT_R']) 214 | self.norm1 = LayerNorm(utils.MCAN_GQA_PARAMS['HIDDEN_SIZE']) 215 | 216 | self.dropout2 = nn.Dropout(utils.MCAN_GQA_PARAMS['DROPOUT_R']) 217 | self.norm2 = LayerNorm(utils.MCAN_GQA_PARAMS['HIDDEN_SIZE']) 218 | 219 | self.dropout3 = nn.Dropout(utils.MCAN_GQA_PARAMS['DROPOUT_R']) 220 | self.norm3 = LayerNorm(utils.MCAN_GQA_PARAMS['HIDDEN_SIZE']) 221 | 222 | def forward(self, x, y, x_mask, y_mask): 223 | x = self.norm1(x + self.dropout1( 224 | self.mhatt1(v=x, k=x, q=x, mask=x_mask) 225 | )) 226 | 227 | x = self.norm2(x + self.dropout2( 228 | self.mhatt2(v=y, k=y, q=x, mask=y_mask) 229 | )) 230 | 231 | x = self.norm3(x + self.dropout3( 232 | self.ffn(x) 233 | )) 234 | 235 | return x 236 | 237 | 238 | # ------------------------------------------------ 239 | # ---- MAC Layers Cascaded by Encoder-Decoder ---- 240 | # ------------------------------------------------ 241 | 242 | class MCA_ED(nn.Module): 
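    # Encoder-decoder cascade: LAYER self-attention (SA) blocks refine the question
    # tokens (y), then LAYER self-guided attention (SGA) blocks refine the image
    # tokens (x); each SGA block self-attends x and then cross-attends it to the
    # fully encoded question, and forward() returns both streams for AttFlat pooling.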
243 | def __init__(self): 244 | super(MCA_ED, self).__init__() 245 | 246 | self.enc_list = nn.ModuleList([SA() for _ in range(utils.MCAN_GQA_PARAMS['LAYER'])]) 247 | self.dec_list = nn.ModuleList([SGA() for _ in range(utils.MCAN_GQA_PARAMS['LAYER'])]) 248 | 249 | def forward(self, y, x, y_mask, x_mask): 250 | # Get encoder last hidden vector 251 | for enc in self.enc_list: 252 | y = enc(y, y_mask) 253 | 254 | # Input encoder last hidden vector 255 | # And obtain decoder last hidden vectors 256 | for dec in self.dec_list: 257 | x = dec(x, y, x_mask, y_mask) 258 | 259 | return y, x 260 | 261 | 262 | def feat_filter(frcn_feat, grid_feat, bbox_feat): 263 | feat_dict = {} 264 | 265 | feat_dict['FRCN_FEAT'] = frcn_feat 266 | feat_dict['GRID_FEAT'] = grid_feat 267 | feat_dict['BBOX_FEAT'] = bbox_feat 268 | 269 | return feat_dict 270 | 271 | 272 | class BaseAdapter(nn.Module): 273 | def __init__(self): 274 | super(BaseAdapter, self).__init__() 275 | self.gqa_init() 276 | 277 | def gqa_init(self): 278 | raise NotImplementedError() 279 | 280 | def forward(self, frcn_feat, grid_feat, bbox_feat): 281 | feat_dict = feat_filter(frcn_feat, grid_feat, bbox_feat) 282 | 283 | return self.gqa_forward(feat_dict) 284 | 285 | def gqa_forward(self, feat_dict): 286 | raise NotImplementedError() 287 | 288 | 289 | class Adapter(BaseAdapter): 290 | def __init__(self): 291 | super(Adapter, self).__init__() 292 | 293 | def bbox_proc(self, bbox): 294 | area = (bbox[:, :, 2] - bbox[:, :, 0]) * (bbox[:, :, 3] - bbox[:, :, 1]) 295 | # return torch.cat((bbox, area), -1) 296 | ##### FIXME: possibly buggy 297 | return torch.cat((bbox, area.unsqueeze(2)), -1) 298 | ##### 299 | 300 | def gqa_init(self): 301 | imgfeat_linear_size = utils.MCAN_GQA_PARAMS['FRCN_FEAT_SIZE'][1] 302 | if utils.MCAN_GQA_PARAMS['USE_BBOX_FEAT']: 303 | self.bbox_linear = nn.Linear(5, utils.MCAN_GQA_PARAMS['BBOXFEAT_EMB_SIZE']) 304 | imgfeat_linear_size += utils.MCAN_GQA_PARAMS['BBOXFEAT_EMB_SIZE'] 305 | self.frcn_linear = nn.Linear(imgfeat_linear_size, utils.MCAN_GQA_PARAMS['HIDDEN_SIZE']) 306 | 307 | if utils.MCAN_GQA_PARAMS['USE_AUX_FEAT']: 308 | self.grid_linear = nn.Linear(utils.MCAN_GQA_PARAMS['GRID_FEAT_SIZE'][1], utils.MCAN_GQA_PARAMS['HIDDEN_SIZE']) 309 | 310 | 311 | def gqa_forward(self, feat_dict): 312 | frcn_feat = feat_dict['FRCN_FEAT'] 313 | bbox_feat = feat_dict['BBOX_FEAT'] 314 | grid_feat = feat_dict['GRID_FEAT'] 315 | 316 | img_feat_mask = make_mask(frcn_feat) 317 | 318 | if utils.MCAN_GQA_PARAMS['USE_BBOX_FEAT']: 319 | ##### FIXME: possibly buggy 320 | # bbox_feat = self.bbox_proc(bbox_feat) 321 | ##### 322 | bbox_feat = self.bbox_linear(bbox_feat) 323 | frcn_feat = torch.cat((frcn_feat, bbox_feat), dim=-1) 324 | img_feat = self.frcn_linear(frcn_feat) 325 | 326 | if utils.MCAN_GQA_PARAMS['USE_AUX_FEAT']: 327 | grid_feat_mask = make_mask(grid_feat) 328 | img_feat_mask = torch.cat((img_feat_mask, grid_feat_mask), dim=-1) 329 | grid_feat = self.grid_linear(grid_feat) 330 | img_feat = torch.cat((img_feat, grid_feat), dim=1) 331 | 332 | return img_feat, img_feat_mask 333 | 334 | 335 | # ------------------------------ 336 | # ---- Flatten the sequence ---- 337 | # ------------------------------ 338 | 339 | class AttFlat(nn.Module): 340 | def __init__(self): 341 | super(AttFlat, self).__init__() 342 | 343 | self.mlp = MLP( 344 | in_size=utils.MCAN_GQA_PARAMS['HIDDEN_SIZE'], 345 | mid_size=utils.MCAN_GQA_PARAMS['FLAT_MLP_SIZE'], 346 | out_size=utils.MCAN_GQA_PARAMS['FLAT_GLIMPSES'], 347 | dropout_r=utils.MCAN_GQA_PARAMS['DROPOUT_R'], 348 | 
use_relu=True 349 | ) 350 | 351 | self.linear_merge = nn.Linear( 352 | utils.MCAN_GQA_PARAMS['HIDDEN_SIZE'] * utils.MCAN_GQA_PARAMS['FLAT_GLIMPSES'], 353 | utils.MCAN_GQA_PARAMS['FLAT_OUT_SIZE'] 354 | ) 355 | 356 | def forward(self, x, x_mask): 357 | att = self.mlp(x) 358 | att = att.masked_fill( 359 | x_mask.squeeze(1).squeeze(1).unsqueeze(2), 360 | -1e9 361 | ) 362 | att = F.softmax(att, dim=1) 363 | 364 | att_list = [] 365 | for i in range(utils.MCAN_GQA_PARAMS['FLAT_GLIMPSES']): 366 | att_list.append( 367 | torch.sum(att[:, :, i: i + 1] * x, dim=1) 368 | ) 369 | 370 | x_atted = torch.cat(att_list, dim=1) 371 | x_atted = self.linear_merge(x_atted) 372 | 373 | return x_atted 374 | 375 | 376 | @register('mcan') 377 | class MCAN(nn.Module): 378 | def __init__(self, **kwargs): 379 | super(MCAN, self).__init__() 380 | 381 | self.embedding = nn.Embedding( 382 | num_embeddings=utils.MCAN_GQA_PARAMS['TOKEN_SIZE'], 383 | embedding_dim=utils.MCAN_GQA_PARAMS['WORD_EMBED_SIZE'] 384 | ) 385 | 386 | # Loading the GloVe embedding weights 387 | self.embedding.weight.data.copy_(torch.from_numpy(np.load(kwargs.get('word_emb_path')))) 388 | 389 | self.lstm = nn.LSTM( 390 | input_size=utils.MCAN_GQA_PARAMS['WORD_EMBED_SIZE'], 391 | hidden_size=utils.MCAN_GQA_PARAMS['HIDDEN_SIZE'], 392 | num_layers=1, 393 | batch_first=True 394 | ) 395 | 396 | self.adapter = Adapter() 397 | 398 | self.backbone = MCA_ED() 399 | 400 | # Flatten to vector 401 | self.attflat_img = AttFlat() 402 | self.attflat_lang = AttFlat() 403 | 404 | # Classification layers 405 | self.proj_norm = LayerNorm(utils.MCAN_GQA_PARAMS['FLAT_OUT_SIZE']) 406 | self.proj = nn.Linear(utils.MCAN_GQA_PARAMS['FLAT_OUT_SIZE'], utils.MCAN_GQA_PARAMS['ANSWER_SIZE']) 407 | 408 | 409 | def forward(self, frcn_feat, grid_feat, bbox_feat, ques_ix, info_nce=False, pretrain=False): 410 | 411 | # Pre-process Language Feature 412 | lang_feat_mask = make_mask(ques_ix.unsqueeze(2)) 413 | lang_feat = self.embedding(ques_ix) 414 | lang_feat, _ = self.lstm(lang_feat) 415 | 416 | img_feat, img_feat_mask = self.adapter(frcn_feat, grid_feat, bbox_feat) 417 | 418 | # Backbone Framework 419 | # lang_feat: [B, MAX_TOKEN_LENGTH, HIDDEN_SIZE] 420 | # img_feat: [B, FRCN_FEAT_SIZE[0]+GRID_FEAT_SIZE[0], HIDDEN_SIZE] 421 | lang_feat, img_feat = self.backbone( 422 | lang_feat, 423 | img_feat, 424 | lang_feat_mask, 425 | img_feat_mask 426 | ) 427 | 428 | # Flatten to vector 429 | lang_feat = self.attflat_lang( 430 | lang_feat, 431 | lang_feat_mask 432 | ) 433 | 434 | img_feat = self.attflat_img( 435 | img_feat, 436 | img_feat_mask 437 | ) 438 | 439 | # Classification layers 440 | proj_feat = lang_feat + img_feat 441 | proj_feat = self.proj_norm(proj_feat) 442 | proj_feat = self.proj(proj_feat) 443 | 444 | return proj_feat 445 | 446 | 447 | @register('mcan-customized') 448 | class MCANCustomized(nn.Module): 449 | def __init__(self, **kwargs): 450 | super(MCANCustomized, self).__init__() 451 | 452 | self.encoder = models.make(kwargs.get('encoder'), **kwargs.get('encoder_args')) 453 | self.connector = nn.Linear(self.encoder.out_dim, utils.MCAN_GQA_PARAMS['HIDDEN_SIZE']) 454 | 455 | self.embedding = nn.Embedding( 456 | num_embeddings=utils.MCAN_GQA_PARAMS['TOKEN_SIZE'], 457 | embedding_dim=utils.MCAN_GQA_PARAMS['WORD_EMBED_SIZE'] 458 | ) 459 | 460 | # Loading the GloVe embedding weights 461 | self.embedding.weight.data.copy_(torch.from_numpy(np.load(kwargs.get('word_emb_path')))) 462 | 463 | self.lstm = nn.LSTM( 464 | input_size=utils.MCAN_GQA_PARAMS['WORD_EMBED_SIZE'], 465 | 
hidden_size=utils.MCAN_GQA_PARAMS['HIDDEN_SIZE'], 466 | num_layers=1, 467 | batch_first=True 468 | ) 469 | 470 | self.backbone = MCA_ED() 471 | 472 | # Flatten to vector 473 | self.attflat_img = AttFlat() 474 | self.attflat_lang = AttFlat() 475 | 476 | # Classification layers 477 | self.proj_norm = LayerNorm(utils.MCAN_GQA_PARAMS['FLAT_OUT_SIZE']) 478 | self.proj = nn.Linear(utils.MCAN_GQA_PARAMS['FLAT_OUT_SIZE'], utils.MCAN_GQA_PARAMS['ANSWER_SIZE']) 479 | 480 | 481 | def forward(self, ims, ques_ix, info_nce=False, pretrain=False): 482 | if pretrain: 483 | B = ims.size(0) 484 | logits = torch.zeros(B, utils.MCAN_GQA_PARAMS['ANSWER_SIZE']).to(ims) 485 | if info_nce: 486 | img_feat, attn_v = self.encoder(ims, info_nce=info_nce) 487 | return logits, attn_v 488 | else: 489 | return logits 490 | else: 491 | if info_nce: 492 | img_feat, attn_v = self.encoder(ims, info_nce=info_nce) 493 | else: 494 | img_feat = self.encoder(ims) 495 | B, C = img_feat.size(0), img_feat.size(1) 496 | img_feat = self.connector(img_feat.reshape(B, C, -1).permute(0, 2, 1)) 497 | img_feat_mask = make_mask(img_feat) 498 | 499 | # Pre-process Language Feature 500 | lang_feat_mask = make_mask(ques_ix.unsqueeze(2)) 501 | lang_feat = self.embedding(ques_ix) 502 | self.lstm.flatten_parameters() 503 | lang_feat, _ = self.lstm(lang_feat) 504 | 505 | # Backbone Framework 506 | lang_feat, img_feat = self.backbone( 507 | lang_feat, 508 | img_feat, 509 | lang_feat_mask, 510 | img_feat_mask 511 | ) 512 | 513 | # Flatten to vector 514 | lang_feat = self.attflat_lang( 515 | lang_feat, 516 | lang_feat_mask 517 | ) 518 | 519 | img_feat = self.attflat_img( 520 | img_feat, 521 | img_feat_mask 522 | ) 523 | 524 | # Classification layers 525 | proj_feat = lang_feat + img_feat 526 | proj_feat = self.proj_norm(proj_feat) 527 | proj_feat = self.proj(proj_feat) 528 | 529 | if info_nce: 530 | return proj_feat, attn_v 531 | else: 532 | return proj_feat 533 | -------------------------------------------------------------------------------- /models/models.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This file has been modified from a file in the Bongard-HOI library 5 | # which was released under the NVIDIA Source Code Licence. 6 | # 7 | # Source: 8 | # https://github.com/NVlabs/Bongard-HOI 9 | # 10 | # The license for the original version of this file can be 11 | # found in https://github.com/NVlabs/Bongard-HOI/blob/master/LICENSE 12 | # The modifications to this file are subject to the same NVIDIA Source Code Licence. 
13 | # --------------------------------------------------------------- 14 | 15 | import torch 16 | 17 | 18 | models = {} 19 | def register(name): 20 | def decorator(cls): 21 | models[name] = cls 22 | return cls 23 | return decorator 24 | 25 | 26 | def make(name, **kwargs): 27 | if name is None: 28 | return None 29 | model = models[name](**kwargs) 30 | # if torch.cuda.is_available(): 31 | # model = model.cuda() 32 | return model 33 | 34 | 35 | def load(model_sv, name=None): 36 | if name is None: 37 | name = 'model' 38 | model = make(model_sv[name], **model_sv[name + '_args']) 39 | missing_keys, unexpected_keys = model.load_state_dict(model_sv[name + '_sd']) 40 | return model 41 | -------------------------------------------------------------------------------- /models/pvt_v2.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This file has been modified from a file in the PVT library 5 | # which was released under the Apache Licence. 6 | # 7 | # Source: 8 | # https://github.com/whai362/PVT 9 | # 10 | # The license for the original version of this file can be 11 | # found in https://github.com/whai362/PVT/blob/v2/LICENSE 12 | # The modifications to this file are subject to the same Apache Licence. 13 | # --------------------------------------------------------------- 14 | 15 | import torch 16 | import torch.nn as nn 17 | import math 18 | import warnings 19 | from functools import partial 20 | from itertools import repeat 21 | import collections.abc 22 | 23 | import utils 24 | from .models import register 25 | 26 | __all__ = [ 27 | 'pvt_v2_b0', 'pvt_v2_b1', 'pvt_v2_b2' 28 | ] 29 | 30 | 31 | def drop_path(x, drop_prob: float = 0., training: bool = False): 32 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 33 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, 34 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... 35 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for 36 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 37 | 'survival rate' as the argument. 38 | """ 39 | if drop_prob == 0. or not training: 40 | return x 41 | keep_prob = 1 - drop_prob 42 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 43 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) 44 | random_tensor.floor_() # binarize 45 | output = x.div(keep_prob) * random_tensor 46 | return output 47 | 48 | 49 | class DropPath(nn.Module): 50 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
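    During training, each sample keeps its residual branch with probability
    1 - drop_prob and kept outputs are rescaled by 1 / keep_prob; when
    drop_prob == 0 or the module is in eval mode, the input passes through
    unchanged (see drop_path() above).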
51 | """ 52 | def __init__(self, drop_prob=None): 53 | super(DropPath, self).__init__() 54 | self.drop_prob = drop_prob 55 | 56 | def forward(self, x): 57 | return drop_path(x, self.drop_prob, self.training) 58 | 59 | 60 | def _ntuple(n): 61 | def parse(x): 62 | if isinstance(x, collections.abc.Iterable): 63 | return x 64 | return tuple(repeat(x, n)) 65 | return parse 66 | 67 | 68 | to_1tuple = _ntuple(1) 69 | to_2tuple = _ntuple(2) 70 | to_3tuple = _ntuple(3) 71 | to_4tuple = _ntuple(4) 72 | to_ntuple = _ntuple 73 | 74 | 75 | def _no_grad_trunc_normal_(tensor, mean, std, a, b): 76 | # Cut & paste from PyTorch official master until it's in a few official releases - RW 77 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 78 | def norm_cdf(x): 79 | # Computes standard normal cumulative distribution function 80 | return (1. + math.erf(x / math.sqrt(2.))) / 2. 81 | 82 | if (mean < a - 2 * std) or (mean > b + 2 * std): 83 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " 84 | "The distribution of values may be incorrect.", 85 | stacklevel=2) 86 | 87 | with torch.no_grad(): 88 | # Values are generated by using a truncated uniform distribution and 89 | # then using the inverse CDF for the normal distribution. 90 | # Get upper and lower cdf values 91 | l = norm_cdf((a - mean) / std) 92 | u = norm_cdf((b - mean) / std) 93 | 94 | # Uniformly fill tensor with values from [l, u], then translate to 95 | # [2l-1, 2u-1]. 96 | tensor.uniform_(2 * l - 1, 2 * u - 1) 97 | 98 | # Use inverse cdf transform for normal distribution to get truncated 99 | # standard normal 100 | tensor.erfinv_() 101 | 102 | # Transform to proper mean, std 103 | tensor.mul_(std * math.sqrt(2.)) 104 | tensor.add_(mean) 105 | 106 | # Clamp to ensure it's in the proper range 107 | tensor.clamp_(min=a, max=b) 108 | return tensor 109 | 110 | 111 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): 112 | # type: (Tensor, float, float, float, float) -> Tensor 113 | r"""Fills the input Tensor with values drawn from a truncated 114 | normal distribution. The values are effectively drawn from the 115 | normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` 116 | with values outside :math:`[a, b]` redrawn until they are within 117 | the bounds. The method used for generating the random values works 118 | best when :math:`a \leq \text{mean} \leq b`. 
119 | Args: 120 | tensor: an n-dimensional `torch.Tensor` 121 | mean: the mean of the normal distribution 122 | std: the standard deviation of the normal distribution 123 | a: the minimum cutoff value 124 | b: the maximum cutoff value 125 | Examples: 126 | >>> w = torch.empty(3, 5) 127 | >>> nn.init.trunc_normal_(w) 128 | """ 129 | return _no_grad_trunc_normal_(tensor, mean, std, a, b) 130 | 131 | 132 | class Mlp(nn.Module): 133 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., linear=False): 134 | super().__init__() 135 | out_features = out_features or in_features 136 | hidden_features = hidden_features or in_features 137 | self.fc1 = nn.Linear(in_features, hidden_features) 138 | self.dwconv = DWConv(hidden_features) 139 | self.act = act_layer() 140 | self.fc2 = nn.Linear(hidden_features, out_features) 141 | self.drop = nn.Dropout(drop) 142 | self.linear = linear 143 | if self.linear: 144 | self.relu = nn.ReLU(inplace=True) 145 | self.apply(self._init_weights) 146 | 147 | def _init_weights(self, m): 148 | if isinstance(m, nn.Linear): 149 | trunc_normal_(m.weight, std=.02) 150 | if isinstance(m, nn.Linear) and m.bias is not None: 151 | nn.init.constant_(m.bias, 0) 152 | elif isinstance(m, nn.LayerNorm): 153 | nn.init.constant_(m.bias, 0) 154 | nn.init.constant_(m.weight, 1.0) 155 | elif isinstance(m, nn.Conv2d): 156 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 157 | fan_out //= m.groups 158 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 159 | if m.bias is not None: 160 | m.bias.data.zero_() 161 | 162 | def forward(self, x, H, W): 163 | x = self.fc1(x) 164 | if self.linear: 165 | x = self.relu(x) 166 | x = self.dwconv(x, H, W).contiguous() 167 | x = self.act(x) 168 | x = self.drop(x) 169 | x = self.fc2(x) 170 | x = self.drop(x) 171 | return x 172 | 173 | 174 | class Attention(nn.Module): 175 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1, linear=False): 176 | super().__init__() 177 | assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." 
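        # Spatial-reduction attention: queries are computed at full token resolution,
        # while keys/values come from a spatially reduced feature map -- a strided
        # sr_ratio x sr_ratio conv when sr_ratio > 1, or adaptive average pooling to
        # 7x7 plus a 1x1 conv and GELU in the `linear` variant -- which shortens the
        # key/value sequence and cheapens the q @ k^T attention in forward().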
178 | 179 | self.dim = dim 180 | self.num_heads = num_heads 181 | head_dim = dim // num_heads 182 | self.scale = qk_scale or head_dim ** -0.5 183 | 184 | self.q = nn.Linear(dim, dim, bias=qkv_bias) 185 | self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) 186 | self.attn_drop = nn.Dropout(attn_drop) 187 | self.proj = nn.Linear(dim, dim) 188 | self.proj_drop = nn.Dropout(proj_drop) 189 | 190 | self.linear = linear 191 | self.sr_ratio = sr_ratio 192 | if not linear: 193 | if sr_ratio > 1: 194 | self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) 195 | self.norm = nn.LayerNorm(dim) 196 | else: 197 | self.pool = nn.AdaptiveAvgPool2d(7) 198 | self.sr = nn.Conv2d(dim, dim, kernel_size=1, stride=1) 199 | self.norm = nn.LayerNorm(dim) 200 | self.act = nn.GELU() 201 | self.apply(self._init_weights) 202 | 203 | def _init_weights(self, m): 204 | if isinstance(m, nn.Linear): 205 | trunc_normal_(m.weight, std=.02) 206 | if isinstance(m, nn.Linear) and m.bias is not None: 207 | nn.init.constant_(m.bias, 0) 208 | elif isinstance(m, nn.LayerNorm): 209 | nn.init.constant_(m.bias, 0) 210 | nn.init.constant_(m.weight, 1.0) 211 | elif isinstance(m, nn.Conv2d): 212 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 213 | fan_out //= m.groups 214 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 215 | if m.bias is not None: 216 | m.bias.data.zero_() 217 | 218 | def forward(self, x, H, W, info_nce=False): 219 | B, N, C = x.shape 220 | q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 221 | 222 | if not self.linear: 223 | if self.sr_ratio > 1: 224 | x_ = x.permute(0, 2, 1).reshape(B, C, H, W) 225 | x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1) 226 | x_ = self.norm(x_) 227 | kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 228 | else: 229 | kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 230 | else: 231 | x_ = x.permute(0, 2, 1).reshape(B, C, H, W) 232 | x_ = self.sr(self.pool(x_)).reshape(B, C, -1).permute(0, 2, 1) 233 | x_ = self.norm(x_) 234 | x_ = self.act(x_) 235 | kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 236 | k, v = kv[0], kv[1] 237 | 238 | attn = (q @ k.transpose(-2, -1)) * self.scale 239 | attn = attn.softmax(dim=-1) 240 | attn = self.attn_drop(attn) 241 | if info_nce: 242 | attn_ret = attn.detach() 243 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 244 | x = self.proj(x) 245 | x = self.proj_drop(x) 246 | if info_nce: 247 | return x, (attn_ret, v) 248 | else: 249 | return x 250 | 251 | 252 | class Block(nn.Module): 253 | 254 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 255 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1, linear=False): 256 | super().__init__() 257 | self.norm1 = norm_layer(dim) 258 | self.attn = Attention( 259 | dim, 260 | num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, 261 | attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio, linear=linear) 262 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 263 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 264 | self.norm2 = norm_layer(dim) 265 | mlp_hidden_dim = int(dim * mlp_ratio) 266 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, linear=linear) 267 | 268 | self.apply(self._init_weights) 269 | 270 | def _init_weights(self, m): 271 | if isinstance(m, nn.Linear): 272 | trunc_normal_(m.weight, std=.02) 273 | if isinstance(m, nn.Linear) and m.bias is not None: 274 | nn.init.constant_(m.bias, 0) 275 | elif isinstance(m, nn.LayerNorm): 276 | nn.init.constant_(m.bias, 0) 277 | nn.init.constant_(m.weight, 1.0) 278 | elif isinstance(m, nn.Conv2d): 279 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 280 | fan_out //= m.groups 281 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 282 | if m.bias is not None: 283 | m.bias.data.zero_() 284 | 285 | def forward(self, x, H, W, info_nce=False): 286 | if info_nce: 287 | tmp, attn_v = self.attn(self.norm1(x), H, W, info_nce) 288 | else: 289 | tmp = self.attn(self.norm1(x), H, W) 290 | x = x + self.drop_path(tmp) 291 | x = x + self.drop_path(self.mlp(self.norm2(x), H, W)) 292 | if info_nce: 293 | return x, attn_v 294 | else: 295 | return x 296 | 297 | 298 | class OverlapPatchEmbed(nn.Module): 299 | """ Image to Patch Embedding 300 | """ 301 | 302 | def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768): 303 | super().__init__() 304 | img_size = to_2tuple(img_size) 305 | patch_size = to_2tuple(patch_size) 306 | 307 | self.img_size = img_size 308 | self.patch_size = patch_size 309 | self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] 310 | self.num_patches = self.H * self.W 311 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride, 312 | padding=(patch_size[0] // 2, patch_size[1] // 2)) 313 | self.norm = nn.LayerNorm(embed_dim) 314 | 315 | self.apply(self._init_weights) 316 | 317 | def _init_weights(self, m): 318 | if isinstance(m, nn.Linear): 319 | trunc_normal_(m.weight, std=.02) 320 | if isinstance(m, nn.Linear) and m.bias is not None: 321 | nn.init.constant_(m.bias, 0) 322 | elif isinstance(m, nn.LayerNorm): 323 | nn.init.constant_(m.bias, 0) 324 | nn.init.constant_(m.weight, 1.0) 325 | elif isinstance(m, nn.Conv2d): 326 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 327 | fan_out //= m.groups 328 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 329 | if m.bias is not None: 330 | m.bias.data.zero_() 331 | 332 | def forward(self, x): 333 | x = self.proj(x) 334 | _, _, H, W = x.shape 335 | x = x.flatten(2).transpose(1, 2).contiguous() 336 | x = self.norm(x) 337 | 338 | return x, H, W 339 | 340 | 341 | class PyramidVisionTransformerV2(nn.Module): 342 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512], 343 | num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0., 344 | attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, depths=[3, 4, 6, 3], 345 | sr_ratios=[8, 4, 2, 1], num_stages=4, linear=False): 346 | super().__init__() 347 | self.depths = depths 348 | self.num_stages = num_stages 349 | self.linear = linear 350 | self.embed_dims = embed_dims 351 | self.out_dim = embed_dims[-1] 352 | 353 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule 354 | cur = 0 355 | 356 | for i in range(num_stages): 357 | patch_embed = OverlapPatchEmbed(img_size=img_size if i == 0 else img_size // (2 ** (i + 1)), 358 | patch_size=7 
if i == 0 else 3, 359 | stride=4 if i == 0 else 2, 360 | in_chans=in_chans if i == 0 else embed_dims[i - 1], 361 | embed_dim=embed_dims[i]) 362 | 363 | block = nn.ModuleList([Block( 364 | dim=embed_dims[i], num_heads=num_heads[i], mlp_ratio=mlp_ratios[i], qkv_bias=qkv_bias, 365 | qk_scale=qk_scale, 366 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + j], norm_layer=norm_layer, 367 | sr_ratio=sr_ratios[i], linear=linear) 368 | for j in range(depths[i])]) 369 | norm = norm_layer(embed_dims[i]) 370 | cur += depths[i] 371 | 372 | setattr(self, f"patch_embed{i + 1}", patch_embed) 373 | setattr(self, f"block{i + 1}", block) 374 | setattr(self, f"norm{i + 1}", norm) 375 | 376 | # init weights (proj head has its own init process) 377 | self.apply(self._init_weights) 378 | 379 | # # Projection head for CL 380 | self.proj = utils.ProjHead(self.out_dim, self.out_dim) 381 | 382 | # Projection head for EsViT extra CL 383 | self.proj2 = utils.ProjHead(self.out_dim, self.out_dim) 384 | 385 | self.num_heads = num_heads[-1] 386 | 387 | def _init_weights(self, m): 388 | if isinstance(m, nn.Linear): 389 | trunc_normal_(m.weight, std=.02) 390 | if isinstance(m, nn.Linear) and m.bias is not None: 391 | nn.init.constant_(m.bias, 0) 392 | elif isinstance(m, nn.LayerNorm): 393 | nn.init.constant_(m.bias, 0) 394 | nn.init.constant_(m.weight, 1.0) 395 | elif isinstance(m, nn.Conv2d): 396 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 397 | fan_out //= m.groups 398 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 399 | if m.bias is not None: 400 | m.bias.data.zero_() 401 | 402 | def forward_features(self, x, info_nce=False): 403 | if info_nce: 404 | attn_vs = [] 405 | outs = [] 406 | B = x.shape[0] 407 | 408 | for i in range(self.num_stages): 409 | patch_embed = getattr(self, f"patch_embed{i + 1}") 410 | block = getattr(self, f"block{i + 1}") 411 | norm = getattr(self, f"norm{i + 1}") 412 | x, H, W = patch_embed(x) 413 | if info_nce: 414 | for blk in block: 415 | x, attn_v = blk(x, H, W, info_nce) 416 | attn_vs.append(attn_v) 417 | else: 418 | for blk in block: 419 | x = blk(x, H, W) 420 | x = norm(x) 421 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 422 | outs.append(x) 423 | 424 | # output is (B, C, H, W) 425 | B, C = x.size(0), x.size(1) 426 | if info_nce: 427 | x = x.reshape(B, C, -1).permute(0, 2, 1).contiguous() 428 | return outs[-1].contiguous(), (self.proj2(x.max(-2)[0].contiguous()).contiguous(), self.proj(x.contiguous()).contiguous(), x.contiguous()) 429 | else: 430 | return outs[-1].contiguous() 431 | 432 | def forward(self, x, info_nce=False): 433 | if info_nce: 434 | x, attn_v = self.forward_features(x, info_nce) 435 | return x, attn_v 436 | else: 437 | x = self.forward_features(x) 438 | return x 439 | 440 | 441 | class DWConv(nn.Module): 442 | def __init__(self, dim=768): 443 | super(DWConv, self).__init__() 444 | self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) 445 | 446 | def forward(self, x, H, W): 447 | B, N, C = x.shape 448 | x = x.transpose(1, 2).view(B, C, H, W).contiguous() 449 | x = self.dwconv(x).contiguous() 450 | x = x.flatten(2).transpose(1, 2).contiguous() 451 | 452 | return x 453 | 454 | 455 | def _conv_filter(state_dict, patch_size=16): 456 | """ convert patch embedding weight from manual patchify + linear proj to conv""" 457 | out_dict = {} 458 | for k, v in state_dict.items(): 459 | if 'patch_embed.proj.weight' in k: 460 | v = v.reshape((v.shape[0], 3, patch_size, patch_size)) 461 | out_dict[k] = v 462 | 463 | 
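    # Only 'patch_embed.proj.weight' entries are reshaped into conv kernels above;
    # all other checkpoint entries are copied through unchanged.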
return out_dict 464 | 465 | 466 | class ViTWrap(nn.Module): 467 | def __init__(self, vit): 468 | super().__init__() 469 | self.backbone = vit 470 | 471 | @register('pvtv2_b0') 472 | class pvt_v2_b0(PyramidVisionTransformerV2): 473 | def __init__(self, **kwargs): 474 | super(pvt_v2_b0, self).__init__( 475 | patch_size=4, embed_dims=[32, 64, 160, 256], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], 476 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], 477 | drop_rate=0.0, drop_path_rate=0.1) 478 | 479 | 480 | @register('pvtv2_b1') 481 | class pvt_v2_b1(PyramidVisionTransformerV2): 482 | def __init__(self, **kwargs): 483 | super(pvt_v2_b1, self).__init__( 484 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], 485 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], 486 | drop_rate=0.0, drop_path_rate=0.1) 487 | 488 | 489 | @register('pvtv2_b2') 490 | class pvt_v2_b2(PyramidVisionTransformerV2): 491 | def __init__(self, **kwargs): 492 | super(pvt_v2_b2, self).__init__( 493 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], 494 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], 495 | drop_rate=0.0, drop_path_rate=0.1) 496 | 497 | 498 | @register('pvtv2_b2_li') 499 | class pvt_v2_b2_li(PyramidVisionTransformerV2): 500 | def __init__(self, **kwargs): 501 | super(pvt_v2_b2_li, self).__init__( 502 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], 503 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], 504 | drop_rate=0.0, drop_path_rate=0.1, linear=True) 505 | 506 | 507 | @register('pvtv2_b3') 508 | class pvt_v2_b3(PyramidVisionTransformerV2): 509 | def __init__(self, **kwargs): 510 | super(pvt_v2_b3, self).__init__( 511 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], 512 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1], 513 | drop_rate=0.0, drop_path_rate=0.1) 514 | 515 | 516 | @register('pvtv2_b4') 517 | class pvt_v2_b4(PyramidVisionTransformerV2): 518 | def __init__(self, **kwargs): 519 | super(pvt_v2_b4, self).__init__( 520 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], 521 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1], 522 | drop_rate=0.0, drop_path_rate=0.1) 523 | 524 | 525 | @register('pvtv2_b4') 526 | class pvt_v2_b5(PyramidVisionTransformerV2): 527 | def __init__(self, **kwargs): 528 | super(pvt_v2_b5, self).__init__( 529 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], 530 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1], 531 | drop_rate=0.0, drop_path_rate=0.1) 532 | -------------------------------------------------------------------------------- /models/swin_transformer.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This file has been modified from a file in the Swin-Transformer library 5 | # which was released under the MIT Licence. 
6 | # 7 | # Source: 8 | # https://github.com/microsoft/Swin-Transformer 9 | # 10 | # The license for the original version of this file can be 11 | # found in https://github.com/microsoft/Swin-Transformer/blob/main/LICENSE 12 | # The modifications to this file are subject to the same Apache Licence. 13 | # --------------------------------------------------------------- 14 | 15 | import torch 16 | import torch.nn as nn 17 | import torch.nn.functional as F 18 | import torch.utils.checkpoint as checkpoint 19 | import numpy as np 20 | import collections.abc 21 | from itertools import repeat 22 | import math 23 | import warnings 24 | 25 | import utils 26 | from .models import register 27 | 28 | __all__ = [ 29 | 'swin_tiny', 'swin_small' 30 | ] 31 | 32 | 33 | def drop_path(x, drop_prob: float = 0., training: bool = False): 34 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 35 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, 36 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... 37 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for 38 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 39 | 'survival rate' as the argument. 40 | """ 41 | if drop_prob == 0. or not training: 42 | return x 43 | keep_prob = 1 - drop_prob 44 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 45 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) 46 | random_tensor.floor_() # binarize 47 | output = x.div(keep_prob) * random_tensor 48 | return output 49 | 50 | 51 | class DropPath(nn.Module): 52 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 53 | """ 54 | def __init__(self, drop_prob=None): 55 | super(DropPath, self).__init__() 56 | self.drop_prob = drop_prob 57 | 58 | def forward(self, x): 59 | return drop_path(x, self.drop_prob, self.training) 60 | 61 | 62 | def _ntuple(n): 63 | def parse(x): 64 | if isinstance(x, collections.abc.Iterable): 65 | return x 66 | return tuple(repeat(x, n)) 67 | return parse 68 | 69 | 70 | to_1tuple = _ntuple(1) 71 | to_2tuple = _ntuple(2) 72 | to_3tuple = _ntuple(3) 73 | to_4tuple = _ntuple(4) 74 | to_ntuple = _ntuple 75 | 76 | 77 | def _no_grad_trunc_normal_(tensor, mean, std, a, b): 78 | # Cut & paste from PyTorch official master until it's in a few official releases - RW 79 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 80 | def norm_cdf(x): 81 | # Computes standard normal cumulative distribution function 82 | return (1. + math.erf(x / math.sqrt(2.))) / 2. 83 | 84 | if (mean < a - 2 * std) or (mean > b + 2 * std): 85 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " 86 | "The distribution of values may be incorrect.", 87 | stacklevel=2) 88 | 89 | with torch.no_grad(): 90 | # Values are generated by using a truncated uniform distribution and 91 | # then using the inverse CDF for the normal distribution. 92 | # Get upper and lower cdf values 93 | l = norm_cdf((a - mean) / std) 94 | u = norm_cdf((b - mean) / std) 95 | 96 | # Uniformly fill tensor with values from [l, u], then translate to 97 | # [2l-1, 2u-1]. 
98 | tensor.uniform_(2 * l - 1, 2 * u - 1) 99 | 100 | # Use inverse cdf transform for normal distribution to get truncated 101 | # standard normal 102 | tensor.erfinv_() 103 | 104 | # Transform to proper mean, std 105 | tensor.mul_(std * math.sqrt(2.)) 106 | tensor.add_(mean) 107 | 108 | # Clamp to ensure it's in the proper range 109 | tensor.clamp_(min=a, max=b) 110 | return tensor 111 | 112 | 113 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): 114 | # type: (Tensor, float, float, float, float) -> Tensor 115 | r"""Fills the input Tensor with values drawn from a truncated 116 | normal distribution. The values are effectively drawn from the 117 | normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` 118 | with values outside :math:`[a, b]` redrawn until they are within 119 | the bounds. The method used for generating the random values works 120 | best when :math:`a \leq \text{mean} \leq b`. 121 | Args: 122 | tensor: an n-dimensional `torch.Tensor` 123 | mean: the mean of the normal distribution 124 | std: the standard deviation of the normal distribution 125 | a: the minimum cutoff value 126 | b: the maximum cutoff value 127 | Examples: 128 | >>> w = torch.empty(3, 5) 129 | >>> nn.init.trunc_normal_(w) 130 | """ 131 | return _no_grad_trunc_normal_(tensor, mean, std, a, b) 132 | 133 | 134 | class Mlp(nn.Module): 135 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 136 | super().__init__() 137 | out_features = out_features or in_features 138 | hidden_features = hidden_features or in_features 139 | self.fc1 = nn.Linear(in_features, hidden_features) 140 | self.act = act_layer() 141 | self.fc2 = nn.Linear(hidden_features, out_features) 142 | self.drop = nn.Dropout(drop) 143 | 144 | def forward(self, x): 145 | x = self.fc1(x) 146 | x = self.act(x) 147 | x = self.drop(x) 148 | x = self.fc2(x) 149 | x = self.drop(x) 150 | return x 151 | 152 | 153 | def window_partition(x, window_size): 154 | """ 155 | Args: 156 | x: (B, H, W, C) 157 | window_size (int): window size 158 | 159 | Returns: 160 | windows: (num_windows*B, window_size, window_size, C) 161 | """ 162 | B, H, W, C = x.shape 163 | x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) 164 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 165 | return windows 166 | 167 | 168 | def window_reverse(windows, window_size, H, W): 169 | """ 170 | Args: 171 | windows: (num_windows*B, window_size, window_size, C) 172 | window_size (int): Window size 173 | H (int): Height of image 174 | W (int): Width of image 175 | 176 | Returns: 177 | x: (B, H, W, C) 178 | """ 179 | B = int(windows.shape[0] / (H * W / window_size / window_size)) 180 | x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) 181 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) 182 | return x 183 | 184 | 185 | class WindowAttention(nn.Module): 186 | r""" Window based multi-head self attention (W-MSA) module with relative position bias. 187 | It supports both of shifted and non-shifted window. 188 | 189 | Args: 190 | dim (int): Number of input channels. 191 | window_size (tuple[int]): The height and width of the window. 192 | num_heads (int): Number of attention heads. 193 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. 
Default: True 194 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set 195 | attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 196 | proj_drop (float, optional): Dropout ratio of output. Default: 0.0 197 | """ 198 | 199 | def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): 200 | 201 | super().__init__() 202 | self.dim = dim 203 | self.window_size = window_size # Wh, Ww 204 | self.num_heads = num_heads 205 | head_dim = dim // num_heads 206 | self.scale = qk_scale or head_dim ** -0.5 207 | 208 | # define a parameter table of relative position bias 209 | self.relative_position_bias_table = nn.Parameter( 210 | torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH 211 | 212 | # get pair-wise relative position index for each token inside the window 213 | coords_h = torch.arange(self.window_size[0]) 214 | coords_w = torch.arange(self.window_size[1]) 215 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww 216 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww 217 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww 218 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 219 | relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 220 | relative_coords[:, :, 1] += self.window_size[1] - 1 221 | relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 222 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww 223 | self.register_buffer("relative_position_index", relative_position_index) 224 | 225 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 226 | self.attn_drop = nn.Dropout(attn_drop) 227 | self.proj = nn.Linear(dim, dim) 228 | self.proj_drop = nn.Dropout(proj_drop) 229 | 230 | trunc_normal_(self.relative_position_bias_table, std=.02) 231 | self.softmax = nn.Softmax(dim=-1) 232 | 233 | def forward(self, x, mask=None, infonce=False): 234 | """ 235 | Args: 236 | x: input features with shape of (num_windows*B, N, C) 237 | mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None 238 | """ 239 | B_, N, C = x.shape 240 | qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 241 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 242 | 243 | q = q * self.scale 244 | attn = (q @ k.transpose(-2, -1)) 245 | 246 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( 247 | self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH 248 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww 249 | attn = attn + relative_position_bias.unsqueeze(0) 250 | 251 | if mask is not None: 252 | nW = mask.shape[0] 253 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) 254 | attn = attn.view(-1, self.num_heads, N, N) 255 | attn = self.softmax(attn) 256 | else: 257 | attn = self.softmax(attn) 258 | 259 | attn = self.attn_drop(attn) 260 | 261 | if infonce: 262 | attn_ret = attn.detach() 263 | 264 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C) 265 | x = self.proj(x) 266 | x = self.proj_drop(x) 267 | 268 | if infonce: 269 | return x, (attn_ret, v) 270 | else: 271 | return x 272 | 273 | 274 | class SwinTransformerBlock(nn.Module): 275 | r""" Swin Transformer Block. 
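    Applies window-based multi-head self-attention (W-MSA) over non-overlapping local
    windows, or shifted-window attention (SW-MSA) when shift_size > 0, followed by a
    two-layer MLP; both sub-layers use pre-normalization, residual connections and
    DropPath.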
276 | 277 | Args: 278 | dim (int): Number of input channels. 279 | num_heads (int): Number of attention heads. 280 | window_size (int): Window size. 281 | shift_size (int): Shift size for SW-MSA. 282 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 283 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 284 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 285 | drop (float, optional): Dropout rate. Default: 0.0 286 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 287 | drop_path (float, optional): Stochastic depth rate. Default: 0.0 288 | act_layer (nn.Module, optional): Activation layer. Default: nn.GELU 289 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 290 | """ 291 | 292 | def __init__(self, dim, num_heads, window_size=7, shift_size=0, 293 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., 294 | act_layer=nn.GELU, norm_layer=nn.LayerNorm): 295 | super().__init__() 296 | self.dim = dim 297 | self.num_heads = num_heads 298 | self.window_size = window_size 299 | self.shift_size = shift_size 300 | self.mlp_ratio = mlp_ratio 301 | assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" 302 | 303 | self.norm1 = norm_layer(dim) 304 | self.attn = WindowAttention( 305 | dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, 306 | qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 307 | 308 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 309 | self.norm2 = norm_layer(dim) 310 | mlp_hidden_dim = int(dim * mlp_ratio) 311 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 312 | self.H, self.W = None, None 313 | 314 | def forward(self, x, mask_matrix, infonce=False): 315 | B, L, C = x.shape 316 | H, W = self.H, self.W 317 | assert L == H * W, "input feature has wrong size" 318 | 319 | shortcut = x 320 | x = self.norm1(x) 321 | x = x.view(B, H, W, C) 322 | 323 | # pad feature maps to multiples of window size 324 | pad_l = pad_t = 0 325 | pad_r = (self.window_size - W % self.window_size) % self.window_size 326 | pad_b = (self.window_size - H % self.window_size) % self.window_size 327 | x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) 328 | _, Hp, Wp, _ = x.shape 329 | 330 | # cyclic shift 331 | if self.shift_size > 0: 332 | shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) 333 | attn_mask = mask_matrix 334 | else: 335 | shifted_x = x 336 | attn_mask = None 337 | 338 | # partition windows 339 | x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C 340 | x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C 341 | 342 | # W-MSA/SW-MSA 343 | ret = self.attn(x_windows, mask=attn_mask, infonce=infonce) # nW*B, window_size*window_size, C 344 | if infonce: 345 | attn_windows, attn_v = ret 346 | else: 347 | attn_windows = ret 348 | 349 | # merge windows 350 | attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) 351 | shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C 352 | 353 | # reverse cyclic shift 354 | if self.shift_size > 0: 355 | x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) 356 | else: 357 | x = shifted_x 358 | 359 | if pad_r > 0 or pad_b > 0: 360 | x = x[:, :H, :W, 
:].contiguous() 361 | 362 | x = x.view(B, H * W, C) 363 | 364 | # FFN 365 | x = shortcut + self.drop_path(x) 366 | x = x + self.drop_path(self.mlp(self.norm2(x))) 367 | 368 | if infonce: 369 | return x, attn_v 370 | else: 371 | return x 372 | 373 | 374 | class PatchMerging(nn.Module): 375 | r""" Patch Merging Layer. 376 | 377 | Args: 378 | dim (int): Number of input channels. 379 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 380 | """ 381 | 382 | def __init__(self, dim, norm_layer=nn.LayerNorm): 383 | super().__init__() 384 | self.dim = dim 385 | self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) 386 | self.norm = norm_layer(4 * dim) 387 | 388 | def forward(self, x, H, W): 389 | """ 390 | x: B, H*W, C 391 | """ 392 | B, L, C = x.shape 393 | assert L == H * W, "input feature has wrong size" 394 | 395 | x = x.view(B, H, W, C) 396 | 397 | # padding 398 | pad_input = (H % 2 == 1) or (W % 2 == 1) 399 | if pad_input: 400 | x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) 401 | 402 | x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C 403 | x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C 404 | x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C 405 | x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C 406 | x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C 407 | x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C 408 | 409 | x = self.norm(x) 410 | x = self.reduction(x) 411 | 412 | return x 413 | 414 | 415 | class BasicLayer(nn.Module): 416 | """ A basic Swin Transformer layer for one stage. 417 | 418 | Args: 419 | dim (int): Number of input channels. 420 | depth (int): Number of blocks. 421 | num_heads (int): Number of attention heads. 422 | window_size (int): Local window size. 423 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 424 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 425 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 426 | drop (float, optional): Dropout rate. Default: 0.0 427 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 428 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 429 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 430 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None 431 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
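    Note that forward() rebuilds the SW-MSA attention mask for the current (H, W),
    so a single layer instance can handle inputs of varying spatial resolution.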
432 | """ 433 | 434 | def __init__(self, 435 | dim, 436 | depth, 437 | num_heads, 438 | window_size=7, 439 | mlp_ratio=4., 440 | qkv_bias=True, 441 | qk_scale=None, 442 | drop=0., 443 | attn_drop=0., 444 | drop_path=0., 445 | norm_layer=nn.LayerNorm, 446 | downsample=None, 447 | use_checkpoint=False): 448 | super().__init__() 449 | self.window_size = window_size 450 | self.shift_size = window_size // 2 451 | self.depth = depth 452 | self.use_checkpoint = use_checkpoint 453 | 454 | # build blocks 455 | self.blocks = nn.ModuleList([ 456 | SwinTransformerBlock( 457 | dim=dim, 458 | num_heads=num_heads, 459 | window_size=window_size, 460 | shift_size=0 if (i % 2 == 0) else window_size // 2, 461 | mlp_ratio=mlp_ratio, 462 | qkv_bias=qkv_bias, 463 | qk_scale=qk_scale, 464 | drop=drop, 465 | attn_drop=attn_drop, 466 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, 467 | norm_layer=norm_layer) 468 | for i in range(depth)]) 469 | 470 | # patch merging layer 471 | if downsample is not None: 472 | self.downsample = downsample(dim=dim, norm_layer=norm_layer) 473 | else: 474 | self.downsample = None 475 | 476 | def forward(self, x, H, W, infonce=False): 477 | # calculate attention mask for SW-MSA 478 | Hp = int(np.ceil(H / self.window_size)) * self.window_size 479 | Wp = int(np.ceil(W / self.window_size)) * self.window_size 480 | img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1 481 | h_slices = (slice(0, -self.window_size), 482 | slice(-self.window_size, -self.shift_size), 483 | slice(-self.shift_size, None)) 484 | w_slices = (slice(0, -self.window_size), 485 | slice(-self.window_size, -self.shift_size), 486 | slice(-self.shift_size, None)) 487 | cnt = 0 488 | for h in h_slices: 489 | for w in w_slices: 490 | img_mask[:, h, w, :] = cnt 491 | cnt += 1 492 | 493 | mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 494 | mask_windows = mask_windows.view(-1, self.window_size * self.window_size) 495 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) 496 | attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) 497 | 498 | attn_vs = [] 499 | for blk in self.blocks: 500 | blk.H, blk.W = H, W 501 | if self.use_checkpoint: 502 | assert infonce == False 503 | x = checkpoint.checkpoint(blk, x, attn_mask) 504 | else: 505 | ret = blk(x, attn_mask, infonce) 506 | if infonce: 507 | x, attn_v = ret 508 | attn_vs.append(attn_v) 509 | else: 510 | x = ret 511 | if self.downsample is not None: 512 | x_down = self.downsample(x, H, W) 513 | Wh, Ww = (H + 1) // 2, (W + 1) // 2 514 | if infonce: 515 | return x, H, W, x_down, Wh, Ww, attn_vs 516 | else: 517 | return x, H, W, x_down, Wh, Ww 518 | else: 519 | if infonce: 520 | return x, H, W, x, H, W, attn_vs 521 | else: 522 | return x, H, W, x, H, W 523 | 524 | 525 | class PatchEmbed(nn.Module): 526 | r""" Image to Patch Embedding 527 | 528 | Args: 529 | patch_size (int): Patch token size. Default: 4. 530 | in_chans (int): Number of input image channels. Default: 3. 531 | embed_dim (int): Number of linear projection output channels. Default: 96. 532 | norm_layer (nn.Module, optional): Normalization layer. 
Default: None 533 | """ 534 | 535 | def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): 536 | super().__init__() 537 | patch_size = to_2tuple(patch_size) 538 | self.patch_size = patch_size 539 | 540 | self.in_chans = in_chans 541 | self.embed_dim = embed_dim 542 | 543 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 544 | if norm_layer is not None: 545 | self.norm = norm_layer(embed_dim) 546 | else: 547 | self.norm = None 548 | 549 | def forward(self, x): 550 | # padding 551 | _, _, H, W = x.size() 552 | if W % self.patch_size[1] != 0: 553 | x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1])) 554 | if H % self.patch_size[0] != 0: 555 | x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0])) 556 | x = self.proj(x) # B C Wh Ww 557 | if self.norm is not None: 558 | Wh, Ww = x.size(2), x.size(3) 559 | x = x.flatten(2).transpose(1, 2) 560 | x = self.norm(x) 561 | x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww) 562 | return x 563 | 564 | 565 | class SwinTransformer(nn.Module): 566 | r""" Swin Transformer 567 | A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - 568 | https://arxiv.org/pdf/2103.14030 569 | 570 | Args: 571 | pretrain_img_size (int | tuple(int)): Input image size. Default 224 572 | patch_size (int | tuple(int)): Patch size. Default: 4 573 | in_chans (int): Number of input image channels. Default: 3 574 | embed_dim (int): Patch embedding dimension. Default: 96 575 | depths (tuple(int)): Depth of each Swin Transformer layer. 576 | num_heads (tuple(int)): Number of attention heads in different layers. 577 | window_size (int): Window size. Default: 7 578 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 579 | qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True 580 | qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None 581 | drop_rate (float): Dropout rate. Default: 0 582 | attn_drop_rate (float): Attention dropout rate. Default: 0 583 | drop_path_rate (float): Stochastic depth rate. Default: 0.1 584 | norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. 585 | ape (bool): If True, add absolute position embedding to the patch embedding. Default: False 586 | patch_norm (bool): If True, add normalization after patch embedding. Default: True 587 | use_checkpoint (bool): Whether to use checkpointing to save memory. 
Default: False 588 | """ 589 | 590 | def __init__(self, 591 | pretrain_img_size=224, 592 | patch_size=4, 593 | in_chans=3, 594 | embed_dim=96, 595 | depths=[2, 2, 6, 2], 596 | num_heads=[3, 6, 12, 24], 597 | window_size=7, 598 | mlp_ratio=4., 599 | qkv_bias=True, 600 | qk_scale=None, 601 | drop_rate=0., 602 | attn_drop_rate=0., 603 | drop_path_rate=0.2, 604 | norm_layer=nn.LayerNorm, 605 | ape=False, 606 | patch_norm=True, 607 | use_checkpoint=False, 608 | **kwargs): 609 | super().__init__() 610 | 611 | self.pretrain_img_size = pretrain_img_size 612 | self.num_layers = len(depths) 613 | self.embed_dim = embed_dim 614 | self.ape = ape 615 | self.patch_norm = patch_norm 616 | self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) 617 | self.out_dim = self.num_features 618 | 619 | # split image into non-overlapping patches 620 | self.patch_embed = PatchEmbed( 621 | patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, 622 | norm_layer=norm_layer if self.patch_norm else None) 623 | 624 | # absolute position embedding 625 | if self.ape: 626 | pretrain_img_size = to_2tuple(pretrain_img_size) 627 | patch_size = to_2tuple(patch_size) 628 | patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1]] 629 | 630 | self.absolute_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1])) 631 | trunc_normal_(self.absolute_pos_embed, std=.02) 632 | 633 | self.pos_drop = nn.Dropout(p=drop_rate) 634 | 635 | # stochastic depth 636 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule 637 | 638 | # build layers 639 | self.layers = nn.ModuleList() 640 | for i_layer in range(self.num_layers): 641 | layer = BasicLayer( 642 | dim=int(embed_dim * 2 ** i_layer), 643 | depth=depths[i_layer], 644 | num_heads=num_heads[i_layer], 645 | window_size=window_size, 646 | mlp_ratio=mlp_ratio, 647 | qkv_bias=qkv_bias, 648 | qk_scale=qk_scale, 649 | drop=drop_rate, 650 | attn_drop=attn_drop_rate, 651 | drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], 652 | norm_layer=norm_layer, 653 | downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, 654 | use_checkpoint=use_checkpoint) 655 | self.layers.append(layer) 656 | 657 | self.norm = norm_layer(self.num_features) 658 | 659 | # Projection head for CL 660 | self.proj = utils.ProjHead(self.out_dim, self.out_dim) 661 | 662 | # Projection head for EsViT extra CL 663 | self.proj2 = utils.ProjHead(self.out_dim, self.out_dim) 664 | 665 | @torch.jit.ignore 666 | def no_weight_decay(self): 667 | return {'absolute_pos_embed'} 668 | 669 | @torch.jit.ignore 670 | def no_weight_decay_keywords(self): 671 | return {'relative_position_bias_table'} 672 | 673 | def forward(self, x, info_nce=False): 674 | """Forward function.""" 675 | x = self.patch_embed(x) 676 | Wh, Ww = x.size(2), x.size(3) 677 | if self.ape: 678 | # interpolate the position embedding to the corresponding size 679 | absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic') 680 | x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C 681 | else: 682 | x = x.flatten(2).transpose(1, 2) 683 | x = self.pos_drop(x) 684 | 685 | for layer in self.layers: 686 | _, _, _, x, Wh, Ww = layer(x, Wh, Ww) 687 | 688 | x = self.norm(x) # B L C 689 | x = x.permute(0, 2, 1).contiguous() 690 | # output is (B, C, N) 691 | if info_nce: 692 | out = x 693 | B, C = out.size(0), out.size(1) 694 | x = out.reshape(B, C, -1).permute(0, 2, 
1).contiguous() 695 | return out.contiguous(), (self.proj2(x.max(-2)[0]), self.proj(x), x) 696 | else: 697 | return x 698 | 699 | def get_last_selfattention(self, x): 700 | x, attn_v = self.forward(x, True) 701 | return attn_v 702 | 703 | 704 | @register('swin_tiny') 705 | def swin_tiny(pretrained=False, **kwargs): 706 | model = SwinTransformer(drop_path_rate=0.2, embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], window_size=7, **kwargs) 707 | return model 708 | 709 | 710 | @register('swin_small') 711 | def swin_small(pretrained=False, **kwargs): 712 | model = SwinTransformer(drop_path_rate=0.3, embed_dim=96, depths=[2, 2, 18, 2], num_heads=[3, 6, 12, 24], window_size=7, **kwargs) 713 | return model 714 | 715 | 716 | @register('swin_base') 717 | def swin_base(pretrained=False, **kwargs): 718 | model = SwinTransformer(drop_path_rate=0.3, embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], window_size=12, **kwargs) 719 | return model 720 | -------------------------------------------------------------------------------- /models/transparent_encoder.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This file has been modified from a file in the Bongard-HOI library 5 | # which was released under the NVIDIA Source Code Licence. 6 | # 7 | # Source: 8 | # https://github.com/NVlabs/Bongard-HOI 9 | # 10 | # The license for the original version of this file can be 11 | # found in https://github.com/NVlabs/Bongard-HOI/blob/master/LICENSE 12 | # The modifications to this file are subject to the same NVIDIA Source Code Licence. 13 | # --------------------------------------------------------------- 14 | 15 | import torch 16 | import torch.nn as nn 17 | import torch.nn.functional as F 18 | 19 | from detectron2.modeling.poolers import ROIPooler 20 | 21 | import models 22 | import utils 23 | from .models import register 24 | 25 | 26 | @register('transparent_bbox_encoder') 27 | class TransparentBBoxNetworkEncoder(nn.Module): 28 | def __init__(self, encoder, **kwargs): 29 | raise NotImplementedError('Currently we have to make sure each image only produce one token.') 30 | super(TransparentBBoxNetworkEncoder, self).__init__() 31 | 32 | # image encoder 33 | encoder = models.make(encoder) 34 | self.conv1 = encoder.conv1 35 | self.bn1 = encoder.bn1 36 | self.relu = encoder.relu 37 | self.maxpool = encoder.maxpool 38 | self.layer1 = encoder.layer1 39 | self.layer2 = encoder.layer2 40 | self.layer3 = encoder.layer3 41 | self.layer4 = encoder.layer4 42 | 43 | self.proj = nn.Conv2d(encoder.out_dim, encoder.out_dim // 2, kernel_size=1) 44 | 45 | # ROI Pooler 46 | self.roi_pooler = ROIPooler( 47 | output_size=7, 48 | scales=(1/32,), 49 | sampling_ratio=0, 50 | pooler_type='ROIAlignV2', 51 | ) 52 | self.roi_processor = nn.Sequential( 53 | nn.Conv2d(encoder.out_dim // 2, 256, kernel_size=3, padding=1), 54 | nn.BatchNorm2d(256), 55 | nn.ReLU(), 56 | nn.Flatten(), 57 | nn.Linear(256*7*7, 512), 58 | nn.ReLU() 59 | ) 60 | self.roi_processor_ln = nn.LayerNorm(512) 61 | 62 | # bbox coord encoding 63 | self.roi_processor_box = nn.Linear(4, 128) 64 | self.roi_processor_box_ln = nn.LayerNorm(128) 65 | rn_in_planes = 512 + 128 66 | 67 | self.out_dim = rn_in_planes 68 | 69 | def forward(self, im, boxes, boxes_dim=None): 70 | # assert im.shape[0] == len(boxes), 'im: {} vs boxes: {}'.format(im.shape[0], len(boxes)) 71 | img_shape = 
im.shape 72 | im = im.view(-1, *img_shape[-3:]) 73 | num_im = im.size(0) 74 | # assert im.shape[0] == boxes_dim.shape[0], '{} vs {}'.format(im.shape, boxes_dim.shape) 75 | if boxes_dim is not None: 76 | boxes_dim_shape = boxes_dim.shape 77 | boxes_dim = boxes_dim.view(-1, *boxes_dim_shape[-1:]) 78 | 79 | # BxCxHxW 80 | x = self.conv1(im) 81 | x = self.bn1(x) 82 | x = self.relu(x) 83 | x = self.maxpool(x) 84 | 85 | x = self.layer1(x) 86 | x = self.layer2(x) 87 | x = self.layer3(x) 88 | x = self.layer4(x) 89 | 90 | x = self.proj(x) 91 | 92 | # RoI pooling/align 93 | # x_shape = list(img_shape[:-3]) + list(x.shape[-3:]) 94 | # x = x.view(x_shape) 95 | all_boxes = [] 96 | for boxes_i in boxes: 97 | all_boxes.extend(boxes_i) 98 | num_boxes = [boxes_i.tensor.shape[0] for boxes_i in all_boxes] 99 | 100 | # roi_feats = roi_align( 101 | # x, all_boxes, 102 | # output_size=(7, 7), 103 | # spatial_scale=1/32., 104 | # sampling_ratio=0, 105 | # aligned=True 106 | # ) 107 | roi_feats = self.roi_pooler([x], all_boxes) 108 | roi_feats = self.roi_processor(roi_feats) 109 | roi_feats = self.roi_processor_ln(roi_feats) 110 | # Add bbox pos features 111 | bbox_tensor = torch.cat([box.tensor for box in all_boxes]).to(roi_feats) 112 | # bbox coord normalization 113 | bbox_tensor[:, 0] = bbox_tensor[:, 0] / im.shape[3] 114 | bbox_tensor[:, 1] = bbox_tensor[:, 1] / im.shape[2] 115 | bbox_tensor[:, 2] = bbox_tensor[:, 2] / im.shape[3] 116 | bbox_tensor[:, 3] = bbox_tensor[:, 3] / im.shape[2] 117 | bbox_tensor = bbox_tensor * 2 - 1 118 | roi_box_feats = self.roi_processor_box_ln(self.roi_processor_box(bbox_tensor)) 119 | roi_feats = torch.cat([roi_feats, roi_box_feats], dim=-1) 120 | 121 | # TODO: This assumes all the images have the same number of bboxes. 122 | feat = roi_feats.reshape(num_im, -1, roi_feats.size(-1)) 123 | return feat 124 | 125 | # CNN backbone then superpixels as patches 126 | @register('transparent_superpixel_encoder') 127 | class TransparentSuperpixelEncoder(nn.Module): 128 | def __init__(self, encoder, **kwargs): 129 | super().__init__() 130 | 131 | # image encoder 132 | encoder = models.make(encoder) 133 | self.encoder = encoder 134 | self.out_dim = encoder.out_dim 135 | 136 | def forward(self, im, boxes=None, boxes_dim=None, info_nce=False): 137 | img_shape = im.shape 138 | im = im.view(-1, *img_shape[-3:]) 139 | num_im = im.size(0) 140 | 141 | # BxCxHxW 142 | if info_nce: 143 | feats, attn_v = self.encoder(im, info_nce) 144 | return feats, attn_v 145 | else: 146 | feats = self.encoder(im) 147 | return feats -------------------------------------------------------------------------------- /models/vit.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This file has been modified from a file in the pytorch-image-models library 5 | # which was released under the Apache Licence. 6 | # 7 | # Source: 8 | # https://github.com/rwightman/pytorch-image-models/ 9 | # 10 | # The license for the original version of this file can be 11 | # found in https://github.com/rwightman/pytorch-image-models/blob/master/LICENSE 12 | # The modifications to this file are subject to the same Apache Licence. 
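For reference, the two wrappers in models/transparent_encoder.py above are built through the same name-based registry as the backbones. The sketch below is illustrative only: it assumes models.make(name, **kwargs) instantiates registered entries (as train_gqa.py does for the top-level model) and that 'swin_tiny' is an acceptable encoder name; the shipped YAML configs may wire this differently.

import torch
import models  # registry populated by the @register decorators

# The bbox variant currently raises NotImplementedError in __init__, so only the
# superpixel wrapper is exercised here; it simply forwards images through the backbone.
enc = models.make('transparent_superpixel_encoder', encoder='swin_tiny')

im = torch.randn(2, 3, 224, 224)
feats = enc(im)                                       # (B, C, N) patch features, C == enc.out_dim
feats, (z_img, z_patch, x) = enc(im, info_nce=True)   # adds the projections used by the EsViT-style loss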
13 | # --------------------------------------------------------------- 14 | 15 | import logging 16 | from functools import partial 17 | from collections import OrderedDict 18 | 19 | import torch 20 | import torch.nn as nn 21 | import torch.nn.functional as F 22 | 23 | from timm.models.layers import Mlp, DropPath, to_2tuple 24 | 25 | import utils 26 | from .models import register 27 | 28 | _logger = logging.getLogger(__name__) 29 | 30 | 31 | __all__ = [ 32 | 'vit_tiny_16', 'vit_small_16', 'vit_base_16', 'vit_large_16' 33 | ] 34 | 35 | 36 | class PatchEmbed(nn.Module): 37 | """ 2D Image to Patch Embedding 38 | """ 39 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True): 40 | super().__init__() 41 | img_size = to_2tuple(img_size) 42 | patch_size = to_2tuple(patch_size) 43 | self.img_size = img_size 44 | self.patch_size = patch_size 45 | self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] 46 | self.num_patches = self.H * self.W 47 | self.flatten = flatten 48 | 49 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 50 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 51 | 52 | def forward(self, x): 53 | B, C, H, W = x.shape 54 | # assert H == self.img_size[0] and W == self.img_size[1], \ 55 | # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 56 | x = self.proj(x) 57 | H, W = x.size(2), x.size(3) 58 | if self.flatten: 59 | x = x.flatten(2).transpose(1, 2) # BCHW -> BNC 60 | x = self.norm(x) 61 | return x, (H, W) 62 | 63 | 64 | class Attention(nn.Module): 65 | def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): 66 | super().__init__() 67 | self.num_heads = num_heads 68 | head_dim = dim // num_heads 69 | self.scale = head_dim ** -0.5 70 | 71 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 72 | self.attn_drop = nn.Dropout(attn_drop) 73 | self.proj = nn.Linear(dim, dim) 74 | self.proj_drop = nn.Dropout(proj_drop) 75 | 76 | def forward(self, x, info_nce=False): 77 | B, N, C = x.shape 78 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 79 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 80 | 81 | attn = (q @ k.transpose(-2, -1)) * self.scale 82 | attn = attn.softmax(dim=-1) 83 | attn = self.attn_drop(attn) 84 | if info_nce: 85 | attn_ret = attn.detach() 86 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 87 | x = self.proj(x) 88 | x = self.proj_drop(x) 89 | if info_nce: 90 | return x, (attn_ret, v) 91 | else: 92 | return x 93 | 94 | 95 | class Block(nn.Module): 96 | 97 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., 98 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 99 | super().__init__() 100 | self.norm1 = norm_layer(dim) 101 | self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) 102 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 103 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 104 | self.norm2 = norm_layer(dim) 105 | mlp_hidden_dim = int(dim * mlp_ratio) 106 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 107 | 108 | def forward(self, x, info_nce=False): 109 | if info_nce: 110 | tmp, attn_v = self.attn(self.norm1(x), info_nce) 111 | else: 112 | tmp = self.attn(self.norm1(x)) 113 | x = x + self.drop_path(tmp) 114 | x = x + self.drop_path(self.mlp(self.norm2(x))) 115 | if info_nce: 116 | return x, attn_v 117 | else: 118 | return x 119 | 120 | 121 | class VisionTransformer(nn.Module): 122 | """ Vision Transformer 123 | 124 | A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` 125 | - https://arxiv.org/abs/2010.11929 126 | 127 | Includes distillation token & head support for `DeiT: Data-efficient Image Transformers` 128 | - https://arxiv.org/abs/2012.12877 129 | """ 130 | 131 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, 132 | num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, distilled=False, 133 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed, norm_layer=None, 134 | act_layer=None, weight_init=''): 135 | """ 136 | Args: 137 | img_size (int, tuple): input image size 138 | patch_size (int, tuple): patch size 139 | in_chans (int): number of input channels 140 | num_classes (int): number of classes for classification head 141 | embed_dim (int): embedding dimension 142 | depth (int): depth of transformer 143 | num_heads (int): number of attention heads 144 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim 145 | qkv_bias (bool): enable bias for qkv if True 146 | representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set 147 | distilled (bool): model includes a distillation token and head as in DeiT models 148 | drop_rate (float): dropout rate 149 | attn_drop_rate (float): attention dropout rate 150 | drop_path_rate (float): stochastic depth rate 151 | embed_layer (nn.Module): patch embedding layer 152 | norm_layer: (nn.Module): normalization layer 153 | weight_init: (str): weight init scheme 154 | """ 155 | super().__init__() 156 | self.num_classes = num_classes 157 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 158 | self.num_tokens = 2 if distilled else 1 159 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) 160 | act_layer = act_layer or nn.GELU 161 | self.out_dim = embed_dim 162 | 163 | self.patch_embed = embed_layer( 164 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 165 | num_patches = self.patch_embed.num_patches 166 | 167 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 168 | self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None 169 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) 170 | self.pos_drop = nn.Dropout(p=drop_rate) 171 | 172 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 173 | self.blocks = nn.Sequential(*[ 174 | Block( 175 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate, 176 | attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer) 177 | for i in range(depth)]) 178 | self.norm = norm_layer(embed_dim) 179 | 180 | # Representation layer 181 | if representation_size and not 
distilled: 182 | self.num_features = representation_size 183 | self.pre_logits = nn.Sequential(OrderedDict([ 184 | ('fc', nn.Linear(embed_dim, representation_size)), 185 | ('act', nn.Tanh()) 186 | ])) 187 | else: 188 | self.pre_logits = nn.Identity() 189 | 190 | # Classifier head(s) 191 | self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() 192 | self.head_dist = None 193 | if distilled: 194 | assert NotImplementedError 195 | self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() 196 | 197 | # Projection head for CL 198 | self.proj = utils.ProjHead(embed_dim, embed_dim) 199 | 200 | # Projection head for EsViT extra CL 201 | self.proj2 = utils.ProjHead(embed_dim, embed_dim) 202 | 203 | self.proj_v = utils.ProjHead(embed_dim//num_heads, embed_dim//num_heads) 204 | self.num_heads = num_heads 205 | 206 | @torch.jit.ignore 207 | def no_weight_decay(self): 208 | return {'pos_embed', 'cls_token', 'dist_token'} 209 | 210 | def get_classifier(self): 211 | if self.dist_token is None: 212 | return self.head 213 | else: 214 | return self.head, self.head_dist 215 | 216 | def reset_classifier(self, num_classes, global_pool=''): 217 | self.num_classes = num_classes 218 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 219 | if self.num_tokens == 2: 220 | self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() 221 | 222 | def _get_pos_embed(self, pos_embed, patch_embed, H, W): 223 | if H * W == patch_embed.num_patches: 224 | return pos_embed 225 | else: 226 | return F.interpolate( 227 | pos_embed.reshape(1, patch_embed.H, patch_embed.W, -1).permute(0, 3, 1, 2), 228 | size=(H, W), mode="bilinear", align_corners=False).reshape(1, -1, H * W).permute(0, 2, 1) 229 | 230 | def forward_features(self, x, info_nce=False): 231 | if info_nce: 232 | attn_vs = [] 233 | B = x.size(0) 234 | x, (H, W) = self.patch_embed(x) 235 | # skip extra tokens and resize pos_embed if needed 236 | tokens_pos_emb = self.pos_embed[:, :self.num_tokens] 237 | patches_pos_emb = self._get_pos_embed(self.pos_embed[:, self.num_tokens:], self.patch_embed, H, W) 238 | if info_nce in [1, 4]: 239 | x = torch.cat([self.cls_token.expand(B, -1, -1).detach(), x], dim=1) 240 | x = self.pos_drop(x + torch.cat([tokens_pos_emb, patches_pos_emb], dim=1)) 241 | else: 242 | x = self.pos_drop(x + patches_pos_emb) 243 | for block in self.blocks: 244 | if info_nce: 245 | x, attn_v = block(x, info_nce) 246 | attn_vs.append(attn_v) 247 | else: 248 | x = block(x) 249 | x = self.norm(x) 250 | # output is (B, C, N) 251 | if info_nce: 252 | return x.permute(0, 2, 1).contiguous()[:, :, [0]], (self.proj2(x[:, 0]), self.proj(x[:, 1:]), x[:, 1:]) 253 | else: 254 | return x.permute(0, 2, 1).contiguous() 255 | 256 | def forward(self, x, info_nce=False): 257 | if info_nce: 258 | x, attn_v = self.forward_features(x, info_nce) 259 | return x, attn_v 260 | else: 261 | x = self.forward_features(x) 262 | return x 263 | 264 | def get_last_selfattention(self, x): 265 | x, attn_v = self.forward_features(x, True) 266 | return attn_v[0] 267 | 268 | 269 | @register('vit_tiny_16') 270 | def vit_tiny_16(pretrained=False, **kwargs): 271 | """ ViT-Tiny (Vit-Ti/16) 272 | """ 273 | model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) 274 | model = VisionTransformer(**model_kwargs, **kwargs) 275 | return model 276 | 277 | 278 | @register('vit_small_16') 279 | def vit_small_16(pretrained=False, 
**kwargs): 280 | """ ViT-Small (ViT-S/16) 281 | NOTE I've replaced my previous 'small' model definition and weights with the small variant from the DeiT paper 282 | """ 283 | model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) 284 | model = VisionTransformer(**model_kwargs, **kwargs) 285 | return model 286 | 287 | 288 | @register('vit_base_16') 289 | def vit_base_16(pretrained=False, **kwargs): 290 | """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). 291 | ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. 292 | """ 293 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) 294 | model = VisionTransformer(**model_kwargs, **kwargs) 295 | return model 296 | 297 | 298 | @register('vit_large_16') 299 | def vit_large_16(pretrained=False, **kwargs): 300 | """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). 301 | ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. 302 | """ 303 | model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs) 304 | model = VisionTransformer(**model_kwargs, **kwargs) 305 | return model 306 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | pillow 3 | matplotlib 4 | tqdm 5 | pandas 6 | tensorboardX==2.4 7 | opencv-python==4.5.4.58 8 | protobuf==3.20.1 9 | scikit-learn 10 | git+https://github.com/facebookresearch/detectron2.git -------------------------------------------------------------------------------- /scripts/train_gqa_image.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | CFG=$1 4 | OPTS=${@:2} 5 | 6 | python train_gqa.py --config-file $CFG $OPTS 2>&1 7 | -------------------------------------------------------------------------------- /scripts/train_hico_image.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | CFG=$1 4 | OPTS=${@:2} 5 | 6 | python train_hico.py --config-file $CFG $OPTS 2>&1 7 | -------------------------------------------------------------------------------- /train_gqa.py: -------------------------------------------------------------------------------- 1 | # ---------------------------------------------------------------------- 2 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This work is licensed under the NVIDIA Source Code License 5 | # for RelViT. To view a copy of this license, see the LICENSE file. 
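The four ViT factories above differ only in width, depth and head count. A minimal shape check is sketched below (illustrative, defaults only). Two things are worth noting in the code as printed: each factory splats **kwargs into both model_kwargs and the constructor call, so passing any extra keyword argument would raise a duplicate-argument error, and the vit_large_16 docstring's "ViT-L/32" label looks like a leftover, since the registered model uses patch_size=16.

import torch
import models

vit = models.make('vit_small_16')        # embed_dim=384, depth=12, num_heads=6, 16x16 patches
x = torch.randn(2, 3, 224, 224)

feats = vit(x)                           # (2, 384, 196): (B, embed_dim, num_patches)
cls_feat, (z_cls, z_patch, tokens) = vit(x, info_nce=True)   # cls_feat is (2, 384, 1); the tuple feeds the auxiliary loss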
6 | # ---------------------------------------------------------------------- 7 | 8 | import argparse 9 | import os 10 | from re import L 11 | import yaml 12 | import numpy as np 13 | import random 14 | import time 15 | 16 | import torch 17 | import torch.backends.cudnn as cudnn 18 | import torch.nn as nn 19 | import torch.nn.functional as F 20 | import torch.nn.parallel 21 | import torch.optim 22 | import torch.utils.data 23 | import torch.multiprocessing as mp 24 | import torch.distributed as dist 25 | 26 | from tensorboardX import SummaryWriter 27 | 28 | from tqdm import tqdm 29 | 30 | import datasets 31 | import models 32 | import utils 33 | from datasets.image_gqa import collate_gqa 34 | 35 | 36 | def compute_acc_gqa(final_y_true, final_y_pred): 37 | count = len(final_y_true) 38 | correct = (final_y_true == final_y_pred).sum() 39 | return (correct / count).item() 40 | 41 | def gather_score_label(dist, loader, current_idx, y_true, y_pred, ddp=False): 42 | if ddp: 43 | # all_gather and compute Acc 44 | current_idx_list = [torch.zeros(1).long().cuda(non_blocking=True) for _ in range(args.world_size)] 45 | y_true_list = [torch.zeros((len(loader.dataset), 1)).long().cuda(non_blocking=True) for _ in range(args.world_size)] 46 | y_pred_list = [torch.zeros((len(loader.dataset), 1)).long().cuda(non_blocking=True) for _ in range(args.world_size)] 47 | dist.all_gather(current_idx_list, current_idx) 48 | dist.all_gather(y_true_list, y_true) 49 | dist.all_gather(y_pred_list, y_pred) 50 | 51 | final_y_true = np.empty((0, 1)) 52 | final_y_pred = np.empty((0, 1)) 53 | for idx, yt, ys in zip(current_idx_list, y_true_list, y_pred_list): 54 | idx = idx.item() 55 | yt = yt.detach().cpu().numpy() 56 | ys = ys.detach().cpu().numpy() 57 | final_y_true = np.vstack((final_y_true, yt[:idx])) 58 | final_y_pred = np.vstack((final_y_pred, ys[:idx])) 59 | else: 60 | final_y_true = y_true.detach().cpu().numpy()[:current_idx.item()] 61 | final_y_pred = y_pred.detach().cpu().numpy()[:current_idx.item()] 62 | return final_y_true, final_y_pred 63 | 64 | 65 | def main(config): 66 | args.gpu = ''#[i for i in range(torch.cuda.device_count())] 67 | args.train_gpu = [i for i in range(torch.cuda.device_count())] 68 | num_gpus = torch.cuda.device_count() 69 | for i in range(num_gpus - 1): 70 | args.gpu += '{},'.format(i) 71 | args.gpu += '{}'.format(num_gpus - 1) 72 | if len(args.gpu.split(',')) > 1: 73 | config['_parallel'] = True 74 | config['_gpu'] = args.gpu 75 | utils.set_gpu(args.gpu) 76 | args.config = config 77 | 78 | if args.seed is not None: 79 | random.seed(args.seed) 80 | np.random.seed(args.seed) 81 | torch.manual_seed(args.seed) 82 | torch.cuda.manual_seed(args.seed) 83 | torch.cuda.manual_seed_all(args.seed) 84 | cudnn.benchmark = False 85 | cudnn.deterministic = True 86 | if args.dist_url == "env://" and args.world_size == -1: 87 | args.world_size = int(os.environ["WORLD_SIZE"]) 88 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed 89 | args.ngpus_per_node = len(args.train_gpu) 90 | if len(args.train_gpu) == 1: 91 | args.sync_bn = False 92 | args.distributed = False 93 | args.multiprocessing_distributed = False 94 | if args.multiprocessing_distributed: 95 | args.sync_bn = True 96 | if args.dist_url[-2:] == '{}': 97 | port = utils.find_free_port() 98 | args.dist_url = args.dist_url.format(port) 99 | args.world_size = args.ngpus_per_node * args.world_size 100 | mp.spawn(main_worker, nprocs=args.ngpus_per_node, args=(args.ngpus_per_node, args)) 101 | else: 102 | 
main_worker(args.train_gpu, args.ngpus_per_node, args) 103 | 104 | 105 | def main_worker(gpu, ngpus_per_node, argss): 106 | global args 107 | args = argss 108 | if args.distributed: 109 | if args.dist_url == "env://" and args.rank == -1: 110 | args.rank = int(os.environ["RANK"]) 111 | if args.multiprocessing_distributed: 112 | args.rank = args.rank * ngpus_per_node + gpu 113 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank) 114 | 115 | config = args.config 116 | svname = args.svname 117 | if svname is None: 118 | config_name, _ = os.path.splitext(os.path.basename(args.config_file)) 119 | svname = 'hico' 120 | svname += '_' + config['model'] 121 | if config['model_args'].get('encoder'): 122 | svname += '-' + config['model_args']['encoder'] 123 | svname = os.path.join(config_name, config['train_dataset'], svname) 124 | if not args.test_only: 125 | svname += '-seed' + str(args.seed) 126 | if args.tag is not None: 127 | svname += '_' + args.tag 128 | 129 | sub_dir_name = 'default' 130 | if args.opts: 131 | sub_dir_name = args.opts[0] 132 | split = '#' 133 | for opt in args.opts[1:]: 134 | sub_dir_name += split + opt 135 | split = '#' if split == '_' else '_' 136 | svname = os.path.join(svname, sub_dir_name) 137 | 138 | if utils.is_main_process() and not args.test_only: 139 | save_path = os.path.join(args.save_dir, svname) 140 | utils.ensure_path(save_path, remove=False) 141 | utils.set_log_path(save_path) 142 | writer = SummaryWriter(os.path.join(save_path, 'tensorboard')) 143 | args.writer = writer 144 | 145 | yaml.dump(config, open(os.path.join(save_path, 'config.yaml'), 'w')) 146 | 147 | logger = utils.Logger(file_name=os.path.join(save_path, "log_sdout.txt"), file_mode="a+", should_flush=True) 148 | else: 149 | save_path = None 150 | writer = None 151 | args.writer = writer 152 | logger = None 153 | 154 | #### Dataset #### 155 | 156 | if config.get('ep_per_batch') is not None: 157 | ep_per_batch = config['ep_per_batch'] 158 | else: 159 | ep_per_batch = 1 160 | 161 | # train 162 | dataset_configs = config['train_dataset_args'] 163 | train_dataset = datasets.make(config['train_dataset'], eval_mode=config['eval_mode'], **dataset_configs) 164 | if utils.is_main_process(): 165 | utils.log('train dataset: {} samples'.format(len(train_dataset))) 166 | if args.distributed: 167 | args.batch_size = int(ep_per_batch / args.world_size) 168 | args.batch_size_test = int(ep_per_batch / args.world_size) 169 | args.workers = int((args.workers + args.world_size - 1) / args.world_size) 170 | else: 171 | args.batch_size = ep_per_batch 172 | args.batch_size_test = ep_per_batch 173 | args.workers = args.workers 174 | 175 | if args.distributed: 176 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 177 | else: 178 | train_sampler = None 179 | train_loader = torch.utils.data.DataLoader( 180 | train_dataset, 181 | batch_size=args.batch_size, 182 | shuffle=(train_sampler is None), 183 | num_workers=args.workers, 184 | pin_memory=True, 185 | sampler=train_sampler, 186 | drop_last=True, 187 | collate_fn=collate_gqa 188 | ) 189 | 190 | # testing 191 | dataset_configs = config['test_dataset_args'] 192 | test_dataset = datasets.make(config['test_dataset'], eval_mode=config['eval_mode'], **dataset_configs) 193 | if utils.is_main_process(): 194 | utils.log('test dataset: {} samples'.format(len(test_dataset))) 195 | if args.distributed: 196 | test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset) 
197 | else: 198 | test_sampler = None 199 | test_loader = torch.utils.data.DataLoader( 200 | test_dataset, 201 | batch_size=args.batch_size_test, 202 | shuffle=False, 203 | num_workers=args.workers, 204 | pin_memory=True, 205 | sampler=test_sampler, 206 | collate_fn=collate_gqa 207 | ) 208 | 209 | ######## 210 | 211 | #### Model and optimizer #### 212 | 213 | if config.get('load'): 214 | print('loading pretrained model: ', config['load']) 215 | model = models.load(torch.load(config['load'])) 216 | if config['relvit']: 217 | model_tea_k = models.load(torch.load(config['load'])) 218 | else: 219 | model = models.make(config['model'], **config['model_args']) 220 | if config['relvit']: 221 | model_tea_k = models.make(config['model'], **config['model_args']) 222 | 223 | if config.get('load_encoder'): 224 | pretrain = config.get('encoder_pretrain').lower() 225 | if pretrain != 'scratch': 226 | pretrain_model_path = config['load_encoder'].format(pretrain) 227 | state_dict = torch.load(pretrain_model_path, map_location='cpu') 228 | missing_keys, unexpected_keys = model.encoder.encoder.load_state_dict(state_dict, strict=False) 229 | for key in missing_keys: 230 | assert key.startswith('g_mlp.') \ 231 | or key.startswith('proj') \ 232 | or key.startswith('trans') \ 233 | or key.startswith('roi_processor') \ 234 | or key.startswith('roi_dim_processor') \ 235 | or key.startswith('classifier'), key 236 | for key in unexpected_keys: 237 | assert key.startswith('fc.') 238 | if utils.is_main_process(): 239 | utils.log('==> Successfully loaded {} for the enocder.'.format(pretrain_model_path)) 240 | if config['relvit']: 241 | model_tea_k.load_state_dict(model.state_dict()) 242 | 243 | if args.sync_bn: 244 | model = nn.SyncBatchNorm.convert_sync_batchnorm(model) 245 | if config['relvit']: 246 | model_tea_k = nn.SyncBatchNorm.convert_sync_batchnorm(model_tea_k) 247 | 248 | if utils.is_main_process(): 249 | utils.log(model) 250 | 251 | if args.distributed: 252 | torch.cuda.set_device(gpu) 253 | model = torch.nn.parallel.DistributedDataParallel(model.cuda(), device_ids=[gpu], find_unused_parameters=True, broadcast_buffers=False) 254 | if config['relvit']: 255 | model_tea_k = torch.nn.parallel.DistributedDataParallel(model_tea_k.cuda(), device_ids=[gpu], find_unused_parameters=True, broadcast_buffers=False) 256 | else: 257 | model = torch.nn.DataParallel(model.cuda()) 258 | if config['relvit']: 259 | model_tea_k = torch.nn.DataParallel(model_tea_k.cuda()) 260 | 261 | if utils.is_main_process() and not args.test_only: 262 | utils.log('num params: {}'.format(utils.compute_n_params(model))) 263 | utils.log('Results will be saved to {}'.format(save_path)) 264 | 265 | max_steps = min(len(train_loader), config['train_batches']) * config['max_epoch'] 266 | optimizer, lr_scheduler, update_lr_every_epoch = utils.make_optimizer( 267 | model.parameters(), 268 | config['optimizer'], max_steps, **config['optimizer_args'] 269 | ) 270 | assert lr_scheduler is not None 271 | args.update_lr_every_epoch = update_lr_every_epoch 272 | 273 | if args.test_only: 274 | filename = args.test_model 275 | assert os.path.exists(filename) 276 | ckpt = torch.load(filename, map_location='cpu') 277 | start_epoch = ckpt['epoch'] 278 | model.load_state_dict(ckpt['state_dict']) 279 | if config['relvit']: 280 | model_tea_k.load_state_dict(ckpt['tea_k_state_dict']) 281 | optimizer.load_state_dict(ckpt['optimizer']) 282 | lr_scheduler.load_state_dict(ckpt['lr_scheduler']) 283 | best_test_result = ckpt['best_test_result'] 284 | if 
utils.is_main_process(): 285 | utils.log('==> Sucessfully resumed from a checkpoint {}'.format(filename)) 286 | else: 287 | ckpt = None 288 | start_epoch = 0 289 | best_test_result = 0.0 290 | 291 | ######## MoCo 292 | if config['relvit']: 293 | cl_loss = 'esvit' 294 | feat_dim = model.module.encoder.encoder.out_dim 295 | moco = utils.relvit.MoCo( 296 | model_tea_k.module, 297 | model.module, 298 | config['relvit_moco_K'], 299 | config['relvit_moco_m'], 300 | feat_dim, 301 | num_concepts=config['relvit_num_concepts'], 302 | relvit_mode=config['relvit_mode'], 303 | num_tokens=config['relvit_num_tokens']) 304 | moco = moco.cuda() 305 | if ckpt is not None: 306 | try: 307 | moco.load_state_dict(ckpt['moco']) 308 | utils.log('==> MoCo sucessfully resumed from a checkpoint.') 309 | except: 310 | utils.log('==> MoCo is not resumed.') 311 | else: 312 | moco = None 313 | model_tea_k = None 314 | ######## 315 | 316 | ######## Training & Validation 317 | 318 | max_epoch = config['max_epoch'] 319 | save_epoch = config.get('save_epoch') 320 | max_va = best_test_result 321 | timer_used = utils.Timer() 322 | timer_epoch = utils.Timer() 323 | 324 | if args.test_only: 325 | ret = test(test_loader, model_tea_k if model_tea_k is not None else model, 0, args) 326 | if ret is None: 327 | return 0 328 | else: 329 | loss_test, acc_test = ret 330 | if utils.is_main_process(): 331 | print('Testing result: ', acc_test) 332 | return 0 333 | 334 | for epoch in range(start_epoch, max_epoch): 335 | # timer_epoch.s() 336 | # aves = {k: utils.Averager() for k in aves_keys} 337 | 338 | # TODO: save param of prev epoch to resume from NaN 339 | sd_model = model.state_dict() 340 | if config['relvit']: 341 | sd_model_tea_k = model_tea_k.state_dict() 342 | sd_moco = moco.state_dict() 343 | 344 | epoch_log = epoch + 1 345 | finish = False 346 | while finish is False: 347 | model.load_state_dict(sd_model) 348 | if config['relvit']: 349 | model_tea_k.load_state_dict(sd_model_tea_k) 350 | moco.load_state_dict(sd_moco) 351 | if args.distributed: 352 | train_sampler.set_epoch(epoch) 353 | ret = train(train_loader, model, optimizer, lr_scheduler, epoch_log, (moco, model_tea_k), writer, args) 354 | # autoresume 355 | if ret is None: 356 | return 0 357 | # NaN 358 | elif ret == -1: 359 | utils.synchronize() 360 | if utils.is_main_process(): 361 | utils.log('NaN detected, redo the current epoch...') 362 | continue 363 | else: 364 | finish = True 365 | (loss_train, aux_loss_train), acc_train = ret 366 | import gc; gc.collect() 367 | torch.cuda.empty_cache() 368 | if args.update_lr_every_epoch: 369 | lr_scheduler.step() 370 | if utils.is_main_process(): 371 | writer.add_scalar('loss_train', loss_train, epoch_log) 372 | writer.add_scalar('aux_loss_train', aux_loss_train, epoch_log) 373 | writer.add_scalar('Acc_train', acc_train, epoch_log) 374 | for name, param in model.named_parameters(): 375 | writer.add_histogram(name, param, epoch_log) 376 | 377 | if (epoch_log % config['save_epoch'] == 0 or epoch_log == config['max_epoch']) and utils.is_main_process(): 378 | filename = os.path.join(save_path, 'train.pth') 379 | utils.log('==> Saving checkpoint to: ' + filename) 380 | ckpt = { 381 | 'epoch': epoch_log, 382 | 'state_dict': model.state_dict(), 383 | 'tea_k_state_dict': model_tea_k.state_dict() if config['relvit'] else model.state_dict(), 384 | 'moco': moco.state_dict() if moco else None, 385 | 'optimizer': optimizer.state_dict(), 386 | 'lr_scheduler': lr_scheduler.state_dict(), 387 | 'best_test_result': best_test_result, 388 | } 389 | 
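            # The dictionary above mirrors what the resume/--test_only branch earlier in
            # main_worker reads back: 'epoch', 'state_dict', 'tea_k_state_dict', 'optimizer',
            # 'lr_scheduler', 'best_test_result', plus 'moco' (loaded separately when RelViT
            # is enabled). Illustrative inspection only; the path depends on svname:
            #   ckpt = torch.load('save_dist/<svname>/train.pth', map_location='cpu')
            #   model.load_state_dict(ckpt['state_dict'])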
torch.save(ckpt, filename) 390 | acc_test = 0 391 | if epoch_log % config['eval_epoch'] == 0: 392 | ret = test(test_loader, model_tea_k if model_tea_k is not None else model, epoch_log, args) 393 | if ret is None: 394 | return 0 395 | else: 396 | loss_test, acc_test = ret 397 | import gc; gc.collect() 398 | torch.cuda.empty_cache() 399 | if acc_test > best_test_result: 400 | best_test_result = acc_test 401 | if utils.is_main_process(): 402 | utils.log('test result: loss {:.4f}, Acc: {:.4f}.'.format(loss_test, acc_test)) 403 | writer.add_scalar('loss_test', loss_test, epoch_log) 404 | writer.add_scalar('Acc_test', acc_test, epoch_log) 405 | 406 | if utils.is_main_process(): 407 | utils.log('Best test results so far:') 408 | utils.log(best_test_result) 409 | 410 | if acc_test > max_va and utils.is_main_process(): 411 | max_va = acc_test 412 | filename = os.path.join(save_path, 'best_model.pth') 413 | ckpt = { 414 | 'epoch': epoch_log, 415 | 'state_dict': model.state_dict(), 416 | 'tea_k_state_dict': model_tea_k.state_dict() if config['relvit'] else model.state_dict(), 417 | 'optimizer': optimizer.state_dict(), 418 | 'lr_scheduler': lr_scheduler.state_dict(), 419 | 'best_test_result': best_test_result, 420 | } 421 | torch.save(ckpt, filename) 422 | if utils.is_main_process(): 423 | writer.flush() 424 | 425 | if utils.is_main_process(): 426 | logger.close() 427 | 428 | def train(train_loader, model, optimizer, lr_scheduler, epoch, moco_tuple, writer, args): 429 | batch_time = utils.AverageMeter() 430 | data_time = utils.AverageMeter() 431 | main_loss_meter = utils.AverageMeter() 432 | aux_loss_meter = utils.AverageMeter() 433 | loss_meter = utils.AverageMeter() 434 | intersection_meter = utils.AverageMeter() 435 | union_meter = utils.AverageMeter() 436 | target_meter = utils.AverageMeter() 437 | # Acc 438 | current_idx = torch.zeros(1).long() 439 | y_true = torch.zeros((len(train_loader.dataset), 1)).long() 440 | y_pred = torch.zeros((len(train_loader.dataset), 1)).long() 441 | 442 | config = args.config 443 | 444 | # train 445 | model.train() 446 | 447 | if utils.is_main_process(): 448 | args.writer.add_scalar('lr', optimizer.param_groups[0]['lr'], epoch) 449 | 450 | end = time.time() 451 | max_iter = config['max_epoch'] * len(train_loader) 452 | for batch_idx, batch in enumerate(train_loader): 453 | if batch_idx >= config['train_batches']: 454 | break 455 | 456 | feed_list = [] 457 | ims = batch[0].cuda(non_blocking=True) 458 | second_ims = batch[1].cuda(non_blocking=True) 459 | third_ims = batch[2].cuda(non_blocking=True) 460 | feed_list.append(ims) 461 | 462 | # For Two-view denseCL 463 | cl_ims = torch.cat([second_ims, third_ims], dim=0) 464 | ques_ix_iter = batch[3].cuda(non_blocking=True) 465 | feed_list.append(ques_ix_iter) 466 | ans_iter = batch[5].cuda(non_blocking=True) 467 | concepts = batch[6].cuda(non_blocking=True) 468 | 469 | # with torch.cuda.amp.autocast(enabled=args.amp), torch.autograd.detect_anomaly(): 470 | with torch.cuda.amp.autocast(enabled=args.amp): 471 | if config['relvit']: 472 | logits = model(*feed_list) 473 | B = cl_ims.size(0) // 2 474 | with torch.no_grad(): 475 | d1, attn_v_k = moco_tuple[1](cl_ims, ques_ix_iter.repeat(2, 1), True) 476 | d2, attn_v_q = model(cl_ims, ques_ix_iter.repeat(2, 1), True) 477 | aux_loss = moco_tuple[0](attn_v_k, attn_v_q, B, config, args, concepts) + d1.mean()*0 + d2.mean()*0 478 | else: 479 | logits = model(*feed_list) 480 | aux_loss = torch.zeros(1).to(logits) 481 | 482 | loss = F.cross_entropy(logits, ans_iter.squeeze(-1), 
reduction='sum') + config['relvit_weight'] * aux_loss 483 | y_pred[current_idx.item():current_idx.item()+logits.size(0)] = logits.detach().argmax(-1).unsqueeze(-1).cpu() 484 | y_true[current_idx.item():current_idx.item()+logits.size(0)] = ans_iter.detach().cpu() 485 | current_idx += logits.size(0) 486 | if torch.isnan(loss).all().item(): 487 | utils.synchronize() 488 | return -1 489 | optimizer.zero_grad() 490 | loss.backward() 491 | torch.nn.utils.clip_grad_norm_(model.parameters(), config['grad_norm']) 492 | optimizer.step() 493 | lrs = lr_scheduler.get_last_lr() 494 | if not args.update_lr_every_epoch: 495 | lr_scheduler.step() 496 | 497 | n = logits.size(0) 498 | if args.multiprocessing_distributed: 499 | # TODO: all gather y_score/true 500 | loss = loss * n # not considering ignore pixels 501 | count = ans_iter.new_tensor([n], dtype=torch.long) 502 | dist.all_reduce(loss) 503 | dist.all_reduce(count) 504 | n = count.item() 505 | loss = loss / n 506 | 507 | loss_meter.update(loss.item(), logits.size(0)) 508 | aux_loss_meter.update(aux_loss.item(), logits.size(0)) 509 | batch_time.update(time.time() - end) 510 | end = time.time() 511 | 512 | current_iter = epoch * len(train_loader) + batch_idx + 1 513 | remain_iter = max_iter - current_iter 514 | remain_time = remain_iter * batch_time.avg 515 | t_m, t_s = divmod(remain_time, 60) 516 | t_h, t_m = divmod(t_m, 60) 517 | remain_time = '{:02d}:{:02d}:{:02d}'.format(int(t_h), int(t_m), int(t_s)) 518 | 519 | if (batch_idx + 1) % config['print_freq'] == 0 and utils.is_main_process(): 520 | utils.log( 521 | '{} Epoch: [{}/{}][{}/{}] ' 522 | 'Data {data_time.val:.3f} ({data_time.avg:.3f}) ' 523 | 'Batch {batch_time.val:.3f} ({batch_time.avg:.3f}) ' 524 | 'Remain {remain_time} ' 525 | 'Loss {loss_meter.val:.4f} ' 526 | 'Aux Loss {aux_loss_meter.val:.4f} ' 527 | 'lr {lr:.6f}'.format( 528 | 'Train', 529 | epoch, config['max_epoch'], batch_idx + 1, len(train_loader), 530 | batch_time=batch_time, 531 | data_time=data_time, 532 | remain_time=remain_time, 533 | loss_meter=loss_meter, 534 | aux_loss_meter=aux_loss_meter, 535 | lr=lrs[0] 536 | ) 537 | ) 538 | final_y_true, final_y_pred = gather_score_label(dist, train_loader, current_idx.cuda(), y_true.cuda(), y_pred.cuda(), args.multiprocessing_distributed) 539 | acc = compute_acc_gqa(final_y_true, final_y_pred) 540 | if utils.is_main_process(): 541 | utils.log('{} result at epoch [{}/{}]: loss {:.4f}, Acc {:.4f}.'.format('Train', epoch, config['max_epoch'], loss_meter.avg, acc)) 542 | return (loss_meter.avg, aux_loss_meter.avg), acc 543 | 544 | def test(test_loader, model, epoch_log, args): 545 | # eval 546 | model.eval() 547 | 548 | config = args.config 549 | loss_meter = utils.AverageMeter() 550 | # Acc 551 | current_idx = torch.zeros(1).long() 552 | y_true = torch.zeros((len(test_loader.dataset), 1)).long() 553 | y_pred = torch.zeros((len(test_loader.dataset), 1)).long() 554 | 555 | config = args.config 556 | 557 | np.random.seed(0) 558 | for batch in tqdm(test_loader): 559 | feed_list = [] 560 | ims = batch[0].cuda(non_blocking=True) 561 | second_ims = batch[1].cuda(non_blocking=True) 562 | third_ims = batch[2].cuda(non_blocking=True) 563 | feed_list.append(ims) 564 | # For Two-view denseCL 565 | cl_ims = torch.cat([second_ims, third_ims], dim=0) 566 | ques_ix_iter = batch[3].cuda(non_blocking=True) 567 | feed_list.append(ques_ix_iter) 568 | ans_iter = batch[5].cuda(non_blocking=True) 569 | 570 | with torch.no_grad(): 571 | with torch.cuda.amp.autocast(enabled=args.amp): 572 | logits = 
model(*feed_list) 573 | loss = F.cross_entropy(logits, ans_iter.squeeze(-1), reduction='sum') 574 | y_pred[current_idx.item():current_idx.item()+logits.size(0)] = logits.detach().argmax(-1).unsqueeze(-1).cpu() 575 | y_true[current_idx.item():current_idx.item()+logits.size(0)] = ans_iter.detach().cpu() 576 | current_idx += logits.size(0) 577 | 578 | n = logits.size(0) 579 | if args.multiprocessing_distributed: 580 | loss = loss * n # not considering ignore pixels 581 | count = logits.new_tensor([n], dtype=torch.long) 582 | dist.all_reduce(loss) 583 | dist.all_reduce(count) 584 | n = count.item() 585 | loss = loss / n 586 | else: 587 | loss = torch.mean(loss) 588 | 589 | loss_meter.update(loss.item(), logits.size(0)) 590 | final_y_true, final_y_pred = gather_score_label(dist, test_loader, current_idx.cuda(), y_true.cuda(), y_pred.cuda(), args.multiprocessing_distributed) 591 | acc = compute_acc_gqa(final_y_true, final_y_pred) 592 | return loss_meter.avg, acc 593 | 594 | if __name__ == '__main__': 595 | parser = argparse.ArgumentParser() 596 | parser.add_argument('--config-file') 597 | parser.add_argument('--svname', default=None) 598 | parser.add_argument('--save_dir', default='./save_dist') 599 | parser.add_argument('--tag', default=None) 600 | # parser.add_argument('--gpu', default='0') 601 | parser.add_argument('--seed', type=int, default=123) 602 | parser.add_argument('--workers', type=int, default=8) 603 | parser.add_argument('--test_only', action='store_true') 604 | parser.add_argument('--test_model', default=None) 605 | 606 | # distributed training 607 | parser.add_argument('--world-size', default=1, type=int) 608 | parser.add_argument('--rank', default=0, type=int) 609 | parser.add_argument('--dist-backend', default='nccl') 610 | parser.add_argument('--amp', action='store_true') 611 | parser.add_argument( 612 | "--dist-url", 613 | default="tcp://127.0.0.1:{}", 614 | help="initialization URL for pytorch distributed backend. See " 615 | "https://pytorch.org/docs/stable/distributed.html for details.", 616 | ) 617 | parser.add_argument( 618 | "opts", 619 | help="Modify config options using the command-line", 620 | default=None, 621 | nargs=argparse.REMAINDER, 622 | ) 623 | args = parser.parse_args() 624 | args.multiprocessing_distributed = True 625 | 626 | config = yaml.load(open(args.config_file, 'r'), Loader=yaml.FullLoader) 627 | if args.opts is not None: 628 | config = utils.override_cfg_from_list(config, args.opts) 629 | print('config:') 630 | print(config) 631 | main(config) 632 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This file has been modified from a file in the Bongard-HOI library 5 | # which was released under the NVIDIA Source Code Licence. 6 | # 7 | # Source: 8 | # https://github.com/NVlabs/Bongard-HOI 9 | # 10 | # The license for the original version of this file can be 11 | # found in https://github.com/NVlabs/Bongard-HOI/blob/master/LICENSE 12 | # The modifications to this file are subject to the same NVIDIA Source Code Licence. 
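Before the shared utilities, a quick numeric check of the accuracy helpers defined at the top of train_gqa.py: in the single-process path (ddp=False), gather_score_label just trims the preallocated buffers to current_idx, and compute_acc_gqa compares them element-wise. The snippet below is illustrative and assumes train_gqa is importable in the current environment.

import torch
from train_gqa import compute_acc_gqa, gather_score_label

y_true = torch.tensor([[1], [2], [3], [0]])    # preallocated buffer; last row never filled
y_pred = torch.tensor([[1], [2], [0], [0]])
idx = torch.tensor([3])                        # three samples actually recorded

yt, yp = gather_score_label(None, None, idx, y_true, y_pred, ddp=False)
print(compute_acc_gqa(yt, yp))                 # 2 of 3 correct -> 0.666...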
13 | # --------------------------------------------------------------- 14 | 15 | import os 16 | import shutil 17 | import time 18 | 19 | import numpy as np 20 | import torch 21 | import torch.nn as nn 22 | import torch.nn.functional as F 23 | from torch.optim import SGD, Adam, AdamW 24 | from torch.optim.lr_scheduler import MultiStepLR 25 | 26 | from typing import Any 27 | import sys 28 | from . import relvit 29 | 30 | import glob 31 | import PIL.Image 32 | from torchvision import transforms 33 | import torch.distributed as dist 34 | import torch.nn as nn 35 | import functools 36 | from copy import deepcopy 37 | 38 | _log_path = None 39 | 40 | _LOCAL_PROCESS_GROUP = None 41 | """ 42 | A torch process group which only includes processes that on the same machine as the current process. 43 | This variable is set when processes are spawned by `launch()` in "engine/launch.py". 44 | """ 45 | 46 | import multiprocessing 47 | from torch import Tensor 48 | from typing import Optional, Iterable, Any, List, Union, Tuple 49 | 50 | 51 | def div(numerator: Tensor, denom: Union[Tensor, int, float]) -> Tensor: 52 | """Handle division by zero""" 53 | if type(denom) in [int, float]: 54 | if denom == 0: 55 | return torch.zeros_like(numerator) 56 | else: 57 | return numerator / denom 58 | elif type(denom) is Tensor: 59 | zero_idx = torch.nonzero(denom == 0).squeeze(1) 60 | denom[zero_idx] += 1e-8 61 | return numerator / denom 62 | else: 63 | raise TypeError("Unsupported data type ", type(denom)) 64 | 65 | 66 | def hico_logit_conversion_to_hoi(logits_verb, logits_object, corr): 67 | # logits_verb: (B, 117) 68 | # logits_object: (B, 80) 69 | 70 | # FIXME: how to compute the joint prob properly 71 | with torch.no_grad(): 72 | B = logits_object.size(0) 73 | logits_hoi = torch.zeros(B, 600).to(logits_verb) 74 | for ind, c in enumerate(corr): 75 | obj, verb = c[1], c[2] 76 | logits_hoi[:, ind] = torch.sigmoid(logits_verb[:, verb]) * torch.sigmoid(logits_object[:, obj]) 77 | # logits_hoi[:, ind] = logits_verb[:, verb] + logits_object[:, obj] 78 | return logits_hoi 79 | 80 | 81 | class AveragePrecisionMeter: 82 | """ 83 | Meter to compute average precision 84 | Arguments: 85 | num_gt(iterable): Number of ground truth instances for each class. When left 86 | as None, all positives are assumed to have been included in the collected 87 | results. As a result, full recall is guaranteed when the lowest scoring 88 | example is accounted for. 89 | algorithm(str, optional): AP evaluation algorithm 90 | '11P': 11-point interpolation algorithm prior to voc2010 91 | 'INT': Interpolation algorithm with all points used in voc2010 92 | 'AUC': Precisely as the area under precision-recall curve 93 | chunksize(int, optional): The approximate size the given iterable will be split 94 | into for each worker. Use -1 to make the argument adaptive to iterable size 95 | and number of workers 96 | precision(int, optional): Precision used for float-point operations. Choose 97 | amongst 64, 32 and 16. 
Default is 64 98 | output(tensor[N, K], optinoal): Network outputs with N examples and K classes 99 | labels(tensor[N, K], optinoal): Binary labels 100 | Usage: 101 | 102 | (1) Evalute AP using provided output scores and labels 103 | >>> # Given output(tensor[N, K]) and labels(tensor[N, K]) 104 | >>> meter = pocket.utils.AveragePrecisionMeter(output=output, labels=labels) 105 | >>> ap = meter.eval(); map_ = ap.mean() 106 | (2) Collect results on the fly and evaluate AP 107 | >>> meter = pocket.utils.AveragePrecisionMeter() 108 | >>> # Compute output(tensor[N, K]) during forward pass 109 | >>> meter.append(output, labels) 110 | >>> ap = meter.eval(); map_ = ap.mean() 111 | >>> # If you are to start new evaluation and want to reset the meter 112 | >>> meter.reset() 113 | """ 114 | def __init__(self, num_gt: Optional[Iterable] = None, 115 | algorithm: str = 'AUC', chunksize: int = -1, 116 | precision: int = 64, 117 | output: Optional[Tensor] = None, 118 | labels: Optional[Tensor] = None) -> None: 119 | self._dtype = eval('torch.float' + str(precision)) 120 | self.num_gt = torch.as_tensor(num_gt, dtype=self._dtype) \ 121 | if num_gt is not None else None 122 | self.algorithm = algorithm 123 | self._chunksize = chunksize 124 | 125 | is_none = (output is None, labels is None) 126 | if is_none == (True, True): 127 | self._output = torch.tensor([], dtype=self._dtype) 128 | self._labels = torch.tensor([], dtype=self._dtype) 129 | elif is_none == (False, False): 130 | self._output = output.detach().cpu().to(self._dtype) 131 | self._labels = labels.detach().cpu().to(self._dtype) 132 | else: 133 | raise AssertionError("Output and labels should both be given or None") 134 | 135 | self._output_temp = [torch.tensor([], dtype=self._dtype)] 136 | self._labels_temp = [torch.tensor([], dtype=self._dtype)] 137 | 138 | @staticmethod 139 | def compute_per_class_ap_as_auc(tuple_: Tuple[Tensor, Tensor]) -> Tensor: 140 | """ 141 | Arguments: 142 | tuple_(Tuple[Tensor, Tensor]): precision and recall 143 | Returns: 144 | ap(Tensor[1]) 145 | """ 146 | prec, rec = tuple_ 147 | ap = 0 148 | max_rec = rec[-1] 149 | for idx in range(prec.numel()): 150 | # Stop when maximum recall is reached 151 | if rec[idx] >= max_rec: 152 | break 153 | d_x = rec[idx] - rec[idx - 1] 154 | # Skip when negative example is registered 155 | if d_x == 0: 156 | continue 157 | ap += prec[idx] * rec[idx] if idx == 0 \ 158 | else 0.5 * (prec[idx] + prec[idx - 1]) * d_x 159 | return ap 160 | 161 | @staticmethod 162 | def compute_per_class_ap_with_interpolation(tuple_: Tuple[Tensor, Tensor]) -> Tensor: 163 | """ 164 | Arguments: 165 | tuple_(Tuple[Tensor, Tensor]): precision and recall 166 | Returns: 167 | ap(Tensor[1]) 168 | """ 169 | prec, rec = tuple_ 170 | ap = 0 171 | max_rec = rec[-1] 172 | for idx in range(prec.numel()): 173 | # Stop when maximum recall is reached 174 | if rec[idx] >= max_rec: 175 | break 176 | d_x = rec[idx] - rec[idx - 1] 177 | # Skip when negative example is registered 178 | if d_x == 0: 179 | continue 180 | # Compute interpolated precision 181 | max_ = prec[idx:].max() 182 | ap += max_ * rec[idx] if idx == 0 \ 183 | else 0.5 * (max_ + torch.max(prec[idx - 1], max_)) * d_x 184 | return ap 185 | 186 | @staticmethod 187 | def compute_per_class_ap_with_11_point_interpolation(tuple_: Tuple[Tensor, Tensor]) -> Tensor: 188 | """ 189 | Arguments: 190 | tuple_(Tuple[Tensor, Tensor]): precision and recall 191 | Returns: 192 | ap(Tensor[1]) 193 | """ 194 | prec, rec = tuple_ 195 | dtype = rec.dtype 196 | ap = 0 197 | for t in 
torch.linspace(0, 1, 11, dtype=dtype): 198 | inds = torch.nonzero(rec >= t).squeeze() 199 | if inds.numel(): 200 | ap += (prec[inds].max() / 11) 201 | return ap 202 | 203 | @classmethod 204 | def compute_ap(cls, output: Tensor, labels: Tensor, 205 | num_gt: Optional[Tensor] = None, 206 | algorithm: str = 'AUC', 207 | chunksize: int = -1) -> Tensor: 208 | """ 209 | Compute average precision under the classification setting. Scores of all 210 | classes are retained for each sample. 211 | Arguments: 212 | output(Tensor[N, K]) 213 | labels(Tensor[N, K]) 214 | num_gt(Tensor[K]): Number of ground truth instances for each class 215 | algorithm(str): AP evaluation algorithm 216 | chunksize(int, optional): The approximate size the given iterable will be split 217 | into for each worker. Use -1 to make the argument adaptive to iterable size 218 | and number of workers 219 | Returns: 220 | ap(Tensor[K]) 221 | """ 222 | prec, rec = cls.compute_precision_and_recall(output, labels, 223 | num_gt=num_gt) 224 | ap = torch.zeros(output.shape[1], dtype=prec.dtype) 225 | # Use the logic from pool._map_async to compute chunksize 226 | # https://github.com/python/cpython/blob/master/Lib/multiprocessing/pool.py 227 | # NOTE: Inappropriate chunksize will cause [Errno 24]Too many open files 228 | # Make changes with caution 229 | if chunksize == -1: 230 | chunksize, extra = divmod( 231 | output.shape[1], 232 | multiprocessing.cpu_count() * 4) 233 | if extra: 234 | chunksize += 1 235 | 236 | if algorithm == 'INT': 237 | algorithm_handle = cls.compute_per_class_ap_with_interpolation 238 | elif algorithm == '11P': 239 | algorithm_handle = cls.compute_per_class_ap_with_11_point_interpolation 240 | elif algorithm == 'AUC': 241 | algorithm_handle = cls.compute_per_class_ap_as_auc 242 | else: 243 | raise ValueError("Unknown algorithm option {}.".format(algorithm)) 244 | 245 | # with multiprocessing.get_context('spawn').Pool() as pool: 246 | # for idx, result in enumerate(pool.imap( 247 | # func=algorithm_handle, 248 | # # NOTE: Use transpose instead of T for compatibility 249 | # iterable=zip(prec.transpose(0,1), rec.transpose(0,1)), 250 | # chunksize=chunksize 251 | # )): 252 | # ap[idx] = algorithm_handle(prec[idx], rec[idx]) 253 | prec = prec.transpose(0,1) 254 | rec = rec.transpose(0,1) 255 | for idx in range(len(prec)): 256 | ap[idx] = algorithm_handle((prec[idx], rec[idx])) 257 | 258 | return ap 259 | 260 | @staticmethod 261 | def compute_precision_and_recall(output: Tensor, labels: Tensor, 262 | num_gt: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]: 263 | """ 264 | Arguments: 265 | output(Tensor[N, K]) 266 | labels(Tensor[N, K]) 267 | num_gt(Tensor[K]) 268 | Returns: 269 | prec(Tensor[N, K]) 270 | rec(Tensor[N, K]) 271 | """ 272 | order = output.argsort(0, descending=True) 273 | tp = labels[ 274 | order, 275 | torch.ones_like(order) * torch.arange(output.shape[1]) 276 | ] 277 | fp = 1 - tp 278 | tp = tp.cumsum(0) 279 | fp = fp.cumsum(0) 280 | 281 | prec = tp / (tp + fp) 282 | rec = div(tp, labels.sum(0)) if num_gt is None \ 283 | else div(tp, num_gt) 284 | 285 | return prec, rec 286 | 287 | def append(self, output: Tensor, labels: Tensor) -> None: 288 | """ 289 | Add new results to the meter 290 | Arguments: 291 | output(tensor[N, K]): Network output with N examples and K classes 292 | labels(tensor[N, K]): Binary labels 293 | """ 294 | if isinstance(output, torch.Tensor) and isinstance(labels, torch.Tensor): 295 | assert output.shape == labels.shape, \ 296 | "Output scores do not match the dimension of 
labelss" 297 | self._output_temp.append(output.detach().cpu().to(self._dtype)) 298 | self._labels_temp.append(labels.detach().cpu().to(self._dtype)) 299 | else: 300 | raise TypeError("Arguments should both be torch.Tensor") 301 | 302 | def reset(self, keep_old: bool = False) -> None: 303 | """ 304 | Clear saved statistics 305 | Arguments: 306 | keep_tracked(bool): If True, clear only the newly collected statistics 307 | since last evaluation 308 | """ 309 | if not keep_old: 310 | self._output = torch.tensor([], dtype=self._dtype) 311 | self._labels = torch.tensor([], dtype=self._dtype) 312 | self._output_temp = [torch.tensor([], dtype=self._dtype)] 313 | self._labels_temp = [torch.tensor([], dtype=self._dtype)] 314 | 315 | def eval(self) -> Tensor: 316 | """ 317 | Evaluate the average precision based on collected statistics 318 | Returns: 319 | torch.Tensor[K]: Average precisions for K classes 320 | """ 321 | self._output = torch.cat([ 322 | self._output, 323 | torch.cat(self._output_temp, 0) 324 | ], 0) 325 | self._labels = torch.cat([ 326 | self._labels, 327 | torch.cat(self._labels_temp, 0) 328 | ], 0) 329 | self.reset(keep_old=True) 330 | 331 | # Sanity check 332 | if self.num_gt is not None: 333 | self.num_gt = self.num_gt.to(dtype=self._labels.dtype) 334 | faulty_cls = torch.nonzero(self._labels.sum(0) > self.num_gt).squeeze(1) 335 | if len(faulty_cls): 336 | raise AssertionError("Class {}: ".format(faulty_cls.tolist())+ 337 | "Number of true positives larger than that of ground truth") 338 | if len(self._output) and len(self._labels): 339 | return self.compute_ap(self._output, self._labels, num_gt=self.num_gt, 340 | algorithm=self.algorithm, chunksize=self._chunksize) 341 | else: 342 | print("WARNING: Collected results are empty. " 343 | "Return zero AP for all class.") 344 | return torch.zeros(self._output.shape[1], dtype=self._dtype) 345 | 346 | class AverageMeter(object): 347 | """Computes and stores the average and current value""" 348 | def __init__(self): 349 | self.reset() 350 | 351 | def reset(self): 352 | self.val = 0 353 | self.avg = 0 354 | self.sum = 0 355 | self.count = 0 356 | 357 | def update(self, val, n=1): 358 | self.val = val 359 | self.sum += val * n 360 | self.count += n 361 | self.avg = self.sum / self.count 362 | 363 | 364 | def set_log_path(path): 365 | global _log_path 366 | _log_path = path 367 | 368 | 369 | def log(obj, filename='log.txt'): 370 | print(obj) 371 | if _log_path is not None: 372 | with open(os.path.join(_log_path, filename), 'a') as f: 373 | print(obj, file=f) 374 | 375 | 376 | class Averager(): 377 | 378 | def __init__(self): 379 | self.n = 0.0 380 | self.v = 0.0 381 | 382 | def add(self, v, n=1.0): 383 | self.v = (self.v * self.n + v * n) / (self.n + n) 384 | self.n += n 385 | 386 | def item(self): 387 | return self.v 388 | 389 | 390 | class Timer(): 391 | 392 | def __init__(self): 393 | self.v = time.time() 394 | 395 | def s(self): 396 | self.v = time.time() 397 | 398 | def t(self): 399 | return time.time() - self.v 400 | 401 | 402 | def set_gpu(gpu): 403 | print('set gpu:', gpu) 404 | os.environ['CUDA_VISIBLE_DEVICES'] = gpu 405 | 406 | 407 | def ensure_path(path, remove=True): 408 | basename = os.path.basename(path.rstrip('/')) 409 | if os.path.exists(path): 410 | if remove and (basename.startswith('_') 411 | or input('{} exists, remove? 
([y]/n): '.format(path)) != 'n'): 412 | shutil.rmtree(path) 413 | os.makedirs(path) 414 | else: 415 | os.makedirs(path) 416 | 417 | 418 | def time_str(t): 419 | if t >= 3600: 420 | return '{:.1f}h'.format(t / 3600) 421 | if t >= 60: 422 | return '{:.1f}m'.format(t / 60) 423 | return '{:.1f}s'.format(t) 424 | 425 | 426 | def compute_logits(feat, proto, metric='dot', temp=1.0): 427 | assert feat.dim() == proto.dim() 428 | 429 | if feat.dim() == 2: 430 | if metric == 'dot': 431 | logits = torch.mm(feat, proto.t()) 432 | elif metric == 'cos': 433 | logits = torch.mm(F.normalize(feat, dim=-1), 434 | F.normalize(proto, dim=-1).t()) 435 | elif metric == 'sqr': 436 | logits = -(feat.unsqueeze(1) - 437 | proto.unsqueeze(0)).pow(2).sum(dim=-1) 438 | 439 | elif feat.dim() == 3: 440 | if metric == 'dot': 441 | logits = torch.bmm(feat, proto.permute(0, 2, 1)) 442 | elif metric == 'cos': 443 | logits = torch.bmm(F.normalize(feat, dim=-1), 444 | F.normalize(proto, dim=-1).permute(0, 2, 1)) 445 | elif metric == 'sqr': 446 | logits = -(feat.unsqueeze(2) - 447 | proto.unsqueeze(1)).pow(2).sum(dim=-1) 448 | 449 | return logits * temp 450 | 451 | 452 | def compute_acc(logits, label, reduction='mean'): 453 | ret = (torch.argmax(logits, dim=1) == label).float() 454 | if reduction == 'none': 455 | return ret.detach() 456 | elif reduction == 'mean': 457 | return ret.mean() 458 | 459 | 460 | def compute_n_params(model, return_str=True): 461 | tot = 0 462 | for p in model.parameters(): 463 | w = 1 464 | for x in p.shape: 465 | w *= x 466 | tot += w 467 | if return_str: 468 | if tot >= 1e6: 469 | return '{:.1f}M'.format(tot / 1e6) 470 | else: 471 | return '{:.1f}K'.format(tot / 1e3) 472 | else: 473 | return tot 474 | 475 | def make_optimizer(params, name, max_steps, lr, weight_decay=None, milestones=None, scheduler='step', use_sam=False, sam_rho=0.005, eps=1e-8, **kwargs): 476 | if weight_decay is None: 477 | weight_decay = 0. 
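    # Returns (optimizer, lr_scheduler, update_lr_every_epoch). A typical call
    # from a training loop might look like the following sketch (the
    # hyper-parameter values are illustrative only, not taken from configs/*.yaml):
    #
    #   optimizer, lr_scheduler, update_lr_every_epoch = make_optimizer(
    #       model.parameters(), 'adamw', max_steps=num_epochs * len(loader),
    #       lr=1e-4, weight_decay=0.05, scheduler='onecycle')
    #
    # 'onecycle' is stepped every iteration (update_lr_every_epoch is False);
    # 'step' and 'warmup_cosine' are stepped once per epoch.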
478 | if use_sam: 479 | optimizer = SAM(params, AdamW, rho=sam_rho, lr=lr, weight_decay=weight_decay, eps=1e-08) 480 | else: 481 | if name == 'sgd': 482 | optimizer = SGD(params, lr, momentum=0.9, weight_decay=weight_decay) 483 | elif name == 'adam': 484 | optimizer = Adam(params, lr, weight_decay=weight_decay) 485 | elif name == 'adamw': 486 | optimizer = AdamW( 487 | params, float(lr), betas=(0.9, 0.999), eps=float(eps), 488 | weight_decay=weight_decay 489 | ) 490 | 491 | update_lr_every_epoch = True 492 | if scheduler == 'step': 493 | if milestones: 494 | lr_scheduler = MultiStepLR(optimizer, milestones) 495 | else: 496 | lr_scheduler = None 497 | elif scheduler == 'onecycle': 498 | lr_scheduler = torch.optim.lr_scheduler.OneCycleLR( 499 | optimizer, 500 | lr, 501 | max_steps + 100, 502 | pct_start=0.05, 503 | cycle_momentum=False, 504 | anneal_strategy='linear', 505 | final_div_factor=10000 506 | ) 507 | update_lr_every_epoch = False 508 | elif scheduler == 'warmup_cosine': 509 | import pl_bolts 510 | lr_scheduler = pl_bolts.optimizers.lr_scheduler.LinearWarmupCosineAnnealingLR(optimizer, kwargs['warmup_epochs'], kwargs['max_epochs'], warmup_start_lr=kwargs['warmup_start_lr'], eta_min=0.0, last_epoch=-1) 511 | return optimizer, lr_scheduler, update_lr_every_epoch 512 | 513 | def set_lr(optimizer, lr): 514 | s = optimizer.state_dict() 515 | s['param_groups'][0]['lr'] = lr 516 | optimizer.load_state_dict(s) 517 | 518 | def get_lr(optimizer): 519 | return optimizer.state_dict()['param_groups'][0]['lr'] 520 | 521 | class ProjHead(nn.Module): 522 | def __init__(self, in_dim, out_dim): 523 | super(ProjHead, self).__init__() 524 | self.in_dim = in_dim 525 | self.out_dim = out_dim 526 | 527 | self.proj = nn.Sequential( 528 | nn.Linear(self.in_dim, self.out_dim), 529 | nn.GELU(), 530 | nn.Linear(self.out_dim, self.out_dim), 531 | ) 532 | 533 | def forward(self, x): 534 | x = self.proj(x) 535 | return x 536 | 537 | MCAN_GQA_PARAMS = { 538 | 'FRCN_FEAT_SIZE': (100, 2048), 539 | 'GRID_FEAT_SIZE': (49, 2048), 540 | 'BBOX_FEAT_SIZE': (100, 5), 541 | 'BBOXFEAT_EMB_SIZE': 2048, 542 | 'HIDDEN_SIZE': 512, 543 | 'FLAT_MLP_SIZE': 512, 544 | 'FLAT_GLIMPSES': 1, 545 | 'FLAT_OUT_SIZE': 1024, 546 | 'DROPOUT_R': 0.1, 547 | 'LAYER': 6, 548 | 'FF_SIZE': 2048, 549 | 'MULTI_HEAD': 8, 550 | 'WORD_EMBED_SIZE': 300, 551 | 'TOKEN_SIZE': 2933, 552 | 'WORD_EMBED_SIZE': 300, 553 | 'ANSWER_SIZE': 1843, 554 | 'MAX_TOKEN_LENGTH': 29, 555 | 'USE_BBOX_FEAT': True, 556 | 'USE_AUX_FEAT': True, 557 | } 558 | 559 | def freeze_bn(model): 560 | for m in model.modules(): 561 | if isinstance(m, nn.BatchNorm2d): 562 | m.eval() 563 | 564 | class Logger(object): 565 | """ 566 | Redirect stderr to stdout, optionally print stdout to a file, 567 | and optionally force flushing on both stdout and the file. 
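    The class also implements the context-manager protocol, so it can be used
    as, e.g., `with Logger("some_log.txt"): ...`; on exit it restores the
    original sys.stdout / sys.stderr and closes the file.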
568 | """ 569 | 570 | def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True): 571 | self.file = None 572 | 573 | if file_name is not None: 574 | self.file = open(file_name, file_mode) 575 | 576 | self.should_flush = should_flush 577 | self.stdout = sys.stdout 578 | self.stderr = sys.stderr 579 | 580 | sys.stdout = self 581 | sys.stderr = self 582 | 583 | def __enter__(self) -> "Logger": 584 | return self 585 | 586 | def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: 587 | self.close() 588 | 589 | def write(self, text: str) -> None: 590 | """Write text to stdout (and a file) and optionally flush.""" 591 | if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash 592 | return 593 | 594 | if self.file is not None: 595 | self.file.write(text) 596 | 597 | self.stdout.write(text) 598 | 599 | if self.should_flush: 600 | self.flush() 601 | 602 | def flush(self) -> None: 603 | """Flush written text to both stdout and a file, if open.""" 604 | if self.file is not None: 605 | self.file.flush() 606 | 607 | self.stdout.flush() 608 | 609 | def close(self) -> None: 610 | """Flush, close possible files, and remove stdout/stderr mirroring.""" 611 | self.flush() 612 | 613 | # if using multiple loggers, prevent closing in wrong order 614 | if sys.stdout is self: 615 | sys.stdout = self.stdout 616 | if sys.stderr is self: 617 | sys.stderr = self.stderr 618 | 619 | if self.file is not None: 620 | self.file.close() 621 | 622 | def anytype2bool_dict(s): 623 | # check str 624 | if not isinstance(s, str): 625 | return s 626 | else: 627 | # try int 628 | try: 629 | ret = int(s) 630 | except: 631 | # try bool 632 | if s.lower() in ('true', 'false'): 633 | ret = s.lower() == 'true' 634 | # try float 635 | else: 636 | try: 637 | ret = float(s) 638 | except: 639 | ret = s 640 | return ret 641 | 642 | def parse_string_to_dict(field_name, value): 643 | fields = field_name.split('.') 644 | for fd in fields[::-1]: 645 | res = {fd: anytype2bool_dict(value)} 646 | value = res 647 | return res 648 | 649 | def merge_to_dicts(a, b): 650 | if isinstance(b, dict) and isinstance(a, dict): 651 | a_and_b = set(a.keys()) & set(b.keys()) 652 | every_key = set(a.keys()) | set(b.keys()) 653 | return {k: merge_to_dicts(a[k], b[k]) if k in a_and_b else 654 | deepcopy(a[k] if k in a else b[k]) for k in every_key} 655 | return deepcopy(type(a)(b)) 656 | 657 | def override_cfg_from_list(cfg, opts): 658 | assert len(opts) % 2 == 0, 'Paired input must be provided to override config, opts: {}'.format(opts) 659 | for ix in range(0, len(opts), 2): 660 | opts_dict = parse_string_to_dict(opts[ix], opts[ix + 1]) 661 | cfg = merge_to_dicts(cfg, opts_dict) 662 | return cfg 663 | 664 | # ---------------------------------------------------------------------------- 665 | 666 | def find_free_port(): 667 | import socket 668 | sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 669 | # Binding to port 0 will cause the OS to find an available port for us 670 | sock.bind(("", 0)) 671 | port = sock.getsockname()[1] 672 | sock.close() 673 | # NOTE: there is still a chance the port could be taken by other processes. 
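    # A common way to consume this value (not necessarily what the training
    # scripts in this repo do) is to export it before initializing the process
    # group, e.g. os.environ['MASTER_PORT'] = str(find_free_port()) ahead of
    # torch.distributed.init_process_group(...) with the env:// rendezvous.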
674 | return port 675 | 676 | def get_world_size() -> int: 677 | if not dist.is_available(): 678 | return 1 679 | if not dist.is_initialized(): 680 | return 1 681 | return dist.get_world_size() 682 | 683 | 684 | def get_rank() -> int: 685 | if not dist.is_available(): 686 | return 0 687 | if not dist.is_initialized(): 688 | return 0 689 | return dist.get_rank() 690 | 691 | 692 | def get_local_rank() -> int: 693 | """ 694 | Returns: 695 | The rank of the current process within the local (per-machine) process group. 696 | """ 697 | if not dist.is_available(): 698 | return 0 699 | if not dist.is_initialized(): 700 | return 0 701 | assert _LOCAL_PROCESS_GROUP is not None 702 | return dist.get_rank(group=_LOCAL_PROCESS_GROUP) 703 | 704 | 705 | def get_local_size() -> int: 706 | """ 707 | Returns: 708 | The size of the per-machine process group, 709 | i.e. the number of processes per machine. 710 | """ 711 | if not dist.is_available(): 712 | return 1 713 | if not dist.is_initialized(): 714 | return 1 715 | return dist.get_world_size(group=_LOCAL_PROCESS_GROUP) 716 | 717 | 718 | def is_main_process() -> bool: 719 | return get_rank() == 0 720 | 721 | 722 | def synchronize(): 723 | """ 724 | Helper function to synchronize (barrier) among all processes when 725 | using distributed training 726 | """ 727 | if not dist.is_available(): 728 | return 729 | if not dist.is_initialized(): 730 | return 731 | world_size = dist.get_world_size() 732 | if world_size == 1: 733 | return 734 | dist.barrier() 735 | -------------------------------------------------------------------------------- /utils/relvit.py: -------------------------------------------------------------------------------- 1 | # ---------------------------------------------------------------------- 2 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This work is licensed under the NVIDIA Source Code License 5 | # for RelViT. To view a copy of this license, see the LICENSE file. 6 | # ---------------------------------------------------------------------- 7 | 8 | import logging 9 | import math 10 | import warnings 11 | 12 | import numpy as np 13 | import torch 14 | import torch.nn as nn 15 | import torch.nn.functional as F 16 | import torch.distributed as dist 17 | 18 | 19 | def comp_sample_prob(size, pos, K, use_log=False): 20 | # case 1: size < K 21 | # then pos-1 > pos-2 > ... > 0 22 | if size < K: 23 | rank = np.arange(pos)+1 24 | # case 2: size == K 25 | # then (pos-1)%K > (pos-2)%K > ... > pos 26 | else: 27 | rank = np.roll(np.arange(size), pos)+1 28 | if use_log: 29 | rank = np.log(10*rank) 30 | assert len(rank) == size 31 | return torch.from_numpy(rank).float() 32 | 33 | def dequeue_with_concept(feat, concept, queue_dict, queue_grid_dict, queue_ptr_dict, sample_uniform=True): 34 | # concept: (B, num_concepts) 35 | # If there is no enough concepts in the queue or the current sample comes without any concepts, 36 | # just return the corresponding feat. 
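    # Buffer layout, as set up in MoCo.__init__ and written by enqueue_with_concept
    # below (K = queue length, N = number of region tokens, C = feature dim):
    #   queue_dict[c].buffer      -- (K, C):      CLS-level teacher features for concept c
    #   queue_grid_dict[c].buffer -- (K, 2*N, C): [:N] projected region tokens, [N:] un-projected ones
    #   queue_ptr_dict[c].buffer  -- (2,):        [last write position, number of valid entries (<= K)]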
37 | 38 | # (2*B, C), (2*B, N, C), (2*B, N, C) 39 | t_cls_out, t_region_out, t_fea = feat 40 | N = t_region_out.size(1) 41 | 42 | K = queue_dict[0].buffer.size(0) 43 | 44 | B = concept.size(0) 45 | with torch.no_grad(): 46 | # sanitize samples without any concept 47 | concept = concept.repeat(2, 1) 48 | mask = (concept.sum(dim=-1) == 0) 49 | concept[mask] += 1 50 | concept_sample = torch.distributions.Categorical(concept).sample() 51 | 52 | ret1, ret2, ret3 = [], [], [] 53 | for ind, (c, m) in enumerate(zip(concept_sample, mask)): 54 | cur_pos = queue_ptr_dict[c.item()].buffer[0].item() 55 | size = queue_ptr_dict[c.item()].buffer[1].item() 56 | 57 | if size != 0 and m == 0: 58 | if sample_uniform: 59 | # Equal prob 60 | pos = torch.distributions.Categorical(torch.ones(size)).sample().item() 61 | else: 62 | # "FIFO" prob 63 | prob = comp_sample_prob(size, cur_pos, K, use_log=True) 64 | pos = torch.distributions.Categorical(prob).sample().item() 65 | 66 | ret1.append(queue_dict[c.item()].buffer[pos]) 67 | ret2.append(queue_grid_dict[c.item()].buffer[pos, :N]) 68 | ret3.append(queue_grid_dict[c.item()].buffer[pos, N:]) 69 | else: 70 | ret1.append(feat[0][ind]) 71 | ret2.append(feat[1][ind]) 72 | ret3.append(feat[2][ind]) 73 | ret1 = torch.stack(ret1).to(t_region_out) 74 | ret2 = torch.stack(ret2).to(t_region_out) 75 | ret3 = torch.stack(ret3).to(t_region_out) 76 | assert ret1.shape == feat[0].shape 77 | assert ret2.shape == feat[1].shape 78 | assert ret3.shape == feat[2].shape 79 | return (ret1.contiguous(), ret2.contiguous(), ret3.contiguous()) 80 | 81 | def enqueue_with_concept(feat, concept, queue_dict, queue_grid_dict, queue_ptr_dict): 82 | # feat: (2*B, C) 83 | # concept: (B, num_concepts) 84 | # We only work on the first B instances and skip those without any concepts. 85 | 86 | # (2*B, C), (2*B, N, C), (2*B, N, C) 87 | t_cls_out, t_region_out, t_fea = feat 88 | N = t_region_out.size(1) 89 | 90 | K = queue_dict[0].buffer.size(0) 91 | 92 | with torch.no_grad(): 93 | # sanitize samples without any concept 94 | mask = (concept.sum(dim=-1) == 0) 95 | concept[mask] += 1 96 | concept_sample = torch.distributions.Categorical(concept).sample() 97 | 98 | for ind, (c, m) in enumerate(zip(concept_sample, mask)): 99 | if m == 0: 100 | # write pos and size 101 | pos = queue_ptr_dict[c.item()].buffer[0] 102 | size = queue_ptr_dict[c.item()].buffer[1] 103 | pos += 1 104 | size += 1 105 | # write pos should loop back to 0 106 | pos %= K 107 | queue_ptr_dict[c.item()].buffer[0] = pos 108 | # size should be clamped to K 109 | size = torch.clamp(size, 0, K) 110 | queue_ptr_dict[c.item()].buffer[1] = size 111 | 112 | queue_dict[c.item()].buffer[pos.item()] = feat[0][ind].detach() 113 | queue_grid_dict[c.item()].buffer[pos.item(), :N] = feat[1][ind].detach() 114 | queue_grid_dict[c.item()].buffer[pos.item(), N:] = feat[2][ind].detach() 115 | 116 | def RCL(attn_v_tea, attn_v_stu, concept, queue_dict, queue_grid_dict, queue_ptr_dict, center, center_grid, tau=0.04, local_only=1, tau_stu=0.1, sample_uniform=True): 117 | # only: 1 -- local; 2 -- global 118 | 119 | # FIXME: we should dequeue before enqueuing to avoid using the recently added samples. 
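    # Dequeue first: the teacher targets for this batch are drawn from features
    # stored in previous iterations, and only then are the current teacher
    # features pushed for later batches to sample. Enqueueing first could let a
    # sample be matched against a view of itself from the same batch.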
120 | target = dequeue_with_concept(attn_v_tea, concept, queue_dict, queue_grid_dict, queue_ptr_dict, sample_uniform) 121 | enqueue_with_concept(attn_v_tea, concept, queue_dict, queue_grid_dict, queue_ptr_dict) 122 | 123 | return token_level_esvit(target, attn_v_stu, center, center_grid, tau, local_only, tau_stu) 124 | 125 | def token_level_esvit(attn_v_tea, attn_v_stu, center, center_grid, tau=0.04, local_only=True, tau_stu=0.1): 126 | # only: 1 -- local; 2 -- global 127 | # FIXME: temporarily disable multi-crop 128 | ncrops = 2 129 | # (2*B, C), (2*B, N, C), (2*B, N, C) 130 | s_cls_out, s_region_out, s_fea = attn_v_stu 131 | t_cls_out, t_region_out, t_fea = attn_v_tea 132 | B, N = s_region_out.size(0), s_region_out.size(1) 133 | B = B // 2 134 | s_region_out = torch.cat([ 135 | s_region_out[:B].reshape(B*N, -1), 136 | s_region_out[B:].reshape(B*N, -1)] 137 | ).contiguous() 138 | s_fea = torch.cat([ 139 | s_fea[:B].reshape(B*N, -1), 140 | s_fea[B:].reshape(B*N, -1)] 141 | ).contiguous() 142 | t_region_out = torch.cat([ 143 | t_region_out[:B].reshape(B*N, -1), 144 | t_region_out[B:].reshape(B*N, -1)] 145 | ).contiguous() 146 | t_fea = torch.cat([ 147 | t_fea[:B].reshape(B*N, -1), 148 | t_fea[B:].reshape(B*N, -1)] 149 | ).contiguous() 150 | s_npatch = [N, N] 151 | t_npatch = [N, N] 152 | 153 | # teacher centering and sharpening 154 | temp = tau 155 | t_cls = F.softmax((t_cls_out - center) / temp, dim=-1) 156 | t_cls = t_cls.detach().chunk(2) 157 | 158 | t_region = F.softmax((t_region_out - center_grid) / temp, dim=-1) 159 | t_region = t_region.detach().chunk(2) 160 | t_fea = t_fea.chunk(2) 161 | 162 | 163 | N = t_npatch[0] # num of patches in the first view 164 | B = t_region[0].shape[0]//N # batch size, 165 | 166 | # student sharpening 167 | s_cls = s_cls_out / tau 168 | s_cls = s_cls.chunk(ncrops) 169 | 170 | s_region = s_region_out / tau_stu 171 | s_split_size = [s_npatch[0]] * 2 + [s_npatch[1]] * (ncrops -2) 172 | 173 | s_split_size_bs = [i * B for i in s_split_size] 174 | 175 | s_region = torch.split(s_region, s_split_size_bs, dim=0) 176 | s_fea = torch.split(s_fea, s_split_size_bs, dim=0) 177 | 178 | total_loss = 0 179 | n_loss_terms = 0 180 | for iq, q in enumerate(t_cls): 181 | for v in range(len(s_cls)): 182 | if v == iq: 183 | # we skip cases where student and teacher operate on the same view 184 | continue 185 | 186 | # view level prediction loss 187 | loss = 0.5 * torch.sum(-q * F.log_softmax(s_cls[v], dim=-1), dim=-1) 188 | if local_only == 1: 189 | loss *= 0 190 | 191 | # region level prediction loss 192 | s_region_cur, s_fea_cur = s_region[v].view(B, s_split_size[v], -1).contiguous(), s_fea[v].view(B, s_split_size[v], -1).contiguous() # B x T_s x K, B x T_s x P 193 | t_region_cur, t_fea_cur = t_region[iq].view(B, N, -1).contiguous(), t_fea[iq].view(B, N, -1).contiguous() # B x T_t x K, B x T_t x P, 194 | 195 | # similarity matrix between two sets of region features 196 | region_sim_matrix = torch.matmul(F.normalize(s_fea_cur, p=2, dim=-1) , F.normalize(t_fea_cur, p=2, dim=-1).permute(0, 2, 1).contiguous()) # B x T_s x T_t 197 | region_sim_ind = region_sim_matrix.max(dim=2)[1] # B x T_s; collect the argmax index in teacher for a given student feature 198 | 199 | t_indexed_region = torch.gather( t_region_cur, 1, region_sim_ind.unsqueeze(2).expand(-1, -1, t_region_cur.size(2)) ) # B x T_s x K (index matrix: B, T_s, 1) 200 | 201 | loss_grid = torch.sum(- t_indexed_region * F.log_softmax(s_region_cur, dim=-1), dim=[-1]).mean(-1) # B x T_s x K --> B 202 | 203 | if local_only == 2: 
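                # This branch (local_only == 2) keeps only the view-level term by
                # weighting the region loss with 0; with local_only == 1 the
                # view-level loss was already zeroed above, leaving only the
                # region-level term.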
204 | loss += 0. * loss_grid 205 | else: 206 | loss += 0.5 * loss_grid 207 | 208 | total_loss += loss.mean() 209 | n_loss_terms += 1 210 | total_loss /= n_loss_terms 211 | 212 | return total_loss, t_cls_out.contiguous(), t_region_out.contiguous() 213 | 214 | class Buffer(nn.Module): 215 | def __init__(self, tensor_cls, *args, **kwargs): 216 | super(Buffer, self).__init__() 217 | self.register_buffer('buffer', tensor_cls(*args, **kwargs)) 218 | 219 | class MoCo(nn.Module): 220 | """ 221 | Build a MoCo model with: a query encoder, a key encoder, and a queue 222 | https://arxiv.org/abs/1911.05722 223 | """ 224 | def __init__(self, encoder_tea_k, encoder_stu_q, K=65536, m=0.999, feat_dim=None, cl_loss='single-view', num_concepts=600, relvit_mode=1, num_tokens=49): 225 | """ 226 | dim: feature dimension (default: 128) 227 | K: queue size; number of negative keys (default: 65536) 228 | m: moco momentum of updating key encoder (default: 0.999) 229 | T: softmax temperature (default: 0.07) 230 | """ 231 | super(MoCo, self).__init__() 232 | 233 | self.K = K 234 | self.m = m 235 | 236 | # create the encoders 237 | # feature embedding size is the output fc dimension 238 | self.encoder_q = [encoder_stu_q] 239 | self.encoder_k = [encoder_tea_k] 240 | dim = self.encoder_q[0].out_dim if feat_dim is None else feat_dim 241 | 242 | for param_q, param_k in zip(self.encoder_q[0].parameters(), self.encoder_k[0].parameters()): 243 | param_k.data.copy_(param_q.data) # initialize 244 | param_k.requires_grad = False # not update by gradient 245 | 246 | self.relvit_mode = relvit_mode 247 | self.center_momentum = 0.9 248 | self.register_buffer("center", torch.zeros(1, dim)) 249 | self.register_buffer("center_grid", torch.zeros(1, dim)) 250 | if self.relvit_mode: 251 | # FIXME: https://github.com/microsoft/esvit/blob/main/main_esvit.py#L606 252 | self.register_buffer("center_rcl", torch.zeros(1, dim)) 253 | self.register_buffer("center_grid_rcl", torch.zeros(1, dim)) 254 | 255 | if self.relvit_mode: 256 | # assert self.K <= 10 257 | self.queue_dict = nn.ModuleList([ 258 | Buffer(torch.rand, self.K, dim) for i in range(num_concepts) 259 | ]) 260 | # FIXME: magic number; first projected and unprojected 261 | # 49x2 works with pvtv2b2 and swin_small 262 | # 196x2 works with vit_small_16 263 | self.queue_grid_dict = nn.ModuleList([ 264 | Buffer(torch.rand, self.K, num_tokens*2, dim) for i in range(num_concepts) 265 | ]) 266 | # current writing pos and size 267 | self.queue_ptr_dict = nn.ModuleList([ 268 | Buffer(torch.zeros, 2, dtype=torch.long) for i in range(num_concepts) 269 | ]) 270 | 271 | @torch.no_grad() 272 | def _momentum_update_key_encoder(self): 273 | """ 274 | Momentum update of the key encoder 275 | """ 276 | for param_q, param_k in zip(self.encoder_q[0].parameters(), self.encoder_k[0].parameters()): 277 | param_k.data = param_k.data * self.m + param_q.data * (1. 
- self.m) 278 | 279 | @torch.no_grad() 280 | def _dequeue_and_enqueue(self, keys): 281 | batch_size = keys.shape[0] 282 | 283 | ptr = int(self.queue_ptr) 284 | assert self.K % batch_size == 0, batch_size # for simplicity 285 | 286 | # replace the keys at ptr (dequeue and enqueue) 287 | self.queue[:, ptr:ptr + batch_size] = keys.transpose(0, 1) 288 | ptr = (ptr + batch_size) % self.K # move pointer 289 | 290 | self.queue_ptr[0] = ptr 291 | 292 | def forward(self, attn_v_k, attn_v_q, B, config, args, concepts=None): 293 | def cal_batch(teacher_output, teacher_grid_output): 294 | batch_center = torch.sum(teacher_output, dim=0, keepdim=True) 295 | if args.multiprocessing_distributed: 296 | dist.all_reduce(batch_center) 297 | world_size = dist.get_world_size() 298 | else: 299 | world_size = 1 300 | batch_center = batch_center / (len(teacher_output) * world_size) 301 | 302 | # region level center update 303 | batch_grid_center = torch.sum(teacher_grid_output, dim=0, keepdim=True) 304 | if args.multiprocessing_distributed: 305 | dist.all_reduce(batch_grid_center) 306 | world_size = dist.get_world_size() 307 | else: 308 | world_size = 1 309 | batch_grid_center = batch_grid_center / (len(teacher_grid_output) * world_size) 310 | return batch_center, batch_grid_center 311 | 312 | # compute key features 313 | with torch.no_grad(): # no gradient to keys 314 | self._momentum_update_key_encoder() # update the key encoder 315 | 316 | loss, teacher_output, teacher_grid_output = token_level_esvit( 317 | attn_v_k, 318 | attn_v_q, 319 | self.center, 320 | self.center_grid, 321 | tau=config['relvit_loss_tau'], 322 | local_only=config['relvit_local_only']) 323 | with torch.no_grad(): 324 | # ema update for esvit 325 | batch_center, batch_grid_center = cal_batch(teacher_output, teacher_grid_output) 326 | self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum) 327 | self.center_grid = self.center_grid * self.center_momentum + batch_grid_center * (1 - self.center_momentum) 328 | 329 | if self.relvit_mode: 330 | loss_rcl, teacher_output_rcl, teacher_grid_output_rcl = RCL( 331 | attn_v_k, 332 | attn_v_q, 333 | concepts, 334 | self.queue_dict, 335 | self.queue_grid_dict, 336 | self.queue_ptr_dict, 337 | self.center_rcl, 338 | self.center_grid_rcl, 339 | tau=config['relvit_loss_tau'], 340 | local_only=config['relvit_local_only'], 341 | sample_uniform=config['relvit_sample_uniform']) 342 | 343 | if self.relvit_mode == 2: 344 | loss = loss * 0 + loss_rcl 345 | elif self.relvit_mode == 1: 346 | loss += loss_rcl 347 | else: 348 | loss = loss 349 | with torch.no_grad(): 350 | # ema update for RCL 351 | batch_center_rcl, batch_grid_center_rcl = cal_batch(teacher_output_rcl, teacher_grid_output_rcl) 352 | self.center_rcl = self.center_rcl * self.center_momentum + batch_center_rcl * (1 - self.center_momentum) 353 | self.center_grid_rcl = self.center_grid_rcl * self.center_momentum + batch_grid_center_rcl * (1 - self.center_momentum) 354 | 355 | return loss --------------------------------------------------------------------------------
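For reference, a minimal smoke test of the MoCo wrapper in utils/relvit.py might look like the sketch below. The sizes (K=8, 10 concepts, 49 region tokens, 512-dim features, batch of 4) and the nn.Linear stand-ins for the student/teacher encoders are illustrative assumptions rather than values from the configs; only the tuple layout of the features and the config keys read by MoCo.forward ('relvit_loss_tau', 'relvit_local_only', 'relvit_sample_uniform') follow the code above.

import types

import torch
import torch.nn as nn

from utils.relvit import MoCo  # assumes the repository root is on PYTHONPATH

dim, num_tokens, B = 512, 49, 4                      # illustrative sizes
enc_q = nn.Linear(dim, dim)                          # stand-in student encoder
enc_k = nn.Linear(dim, dim)                          # stand-in teacher encoder
moco = MoCo(enc_k, enc_q, K=8, feat_dim=dim,
            num_concepts=10, relvit_mode=1, num_tokens=num_tokens)

def fake_views():
    # (2*B, C) CLS features, then projected / un-projected region tokens, each (2*B, N, C)
    return (torch.randn(2 * B, dim),
            torch.randn(2 * B, num_tokens, dim),
            torch.randn(2 * B, num_tokens, dim))

config = {'relvit_loss_tau': 0.04, 'relvit_local_only': 1,
          'relvit_sample_uniform': True}
args = types.SimpleNamespace(multiprocessing_distributed=False)
concepts = torch.randint(0, 2, (B, 10)).float()      # multi-hot concept labels

loss = moco(fake_views(), fake_views(), B, config, args, concepts=concepts)
print(float(loss))

In actual training the two feature tuples would come from the teacher and student backbones' CLS and region tokens rather than random tensors, and the returned loss would be added to the task loss before backpropagation.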