├── .gitignore
├── LICENSE
├── README.md
├── assets
│   ├── data
│   │   ├── gqa
│   │   │   └── .gitkeep
│   │   └── hico
│   │       └── .gitkeep
│   ├── dataset.md
│   └── overview.png
├── cache
│   └── .gitkeep
├── configs
│   ├── train_gqa_mcan.yaml
│   └── train_hico.yaml
├── datasets
│   ├── __init__.py
│   ├── datasets.py
│   ├── image_gqa.py
│   └── image_hicodet_bbox.py
├── models
│   ├── __init__.py
│   ├── classifier.py
│   ├── mcan.py
│   ├── models.py
│   ├── pvt_v2.py
│   ├── swin_transformer.py
│   ├── transparent_encoder.py
│   └── vit.py
├── requirements.txt
├── scripts
│   ├── train_gqa_image.sh
│   └── train_hico_image.sh
├── train_gqa.py
├── train_hico.py
└── utils
    ├── __init__.py
    └── relvit.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # custom
2 | *.json
3 | *.jpg
4 | *.png
5 | *.pkl
6 |
7 | # Byte-compiled / optimized / DLL files
8 | __pycache__/
9 | *.py[cod]
10 | *$py.class
11 |
12 | # C extensions
13 | *.so
14 |
15 | # Distribution / packaging
16 | .Python
17 | build/
18 | develop-eggs/
19 | dist/
20 | downloads/
21 | eggs/
22 | .eggs/
23 | lib/
24 | lib64/
25 | parts/
26 | sdist/
27 | var/
28 | wheels/
29 | pip-wheel-metadata/
30 | share/python-wheels/
31 | *.egg-info/
32 | .installed.cfg
33 | *.egg
34 | MANIFEST
35 |
36 | # PyInstaller
37 | # Usually these files are written by a python script from a template
38 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
39 | *.manifest
40 | *.spec
41 |
42 | # Installer logs
43 | pip-log.txt
44 | pip-delete-this-directory.txt
45 |
46 | # Unit test / coverage reports
47 | htmlcov/
48 | .tox/
49 | .nox/
50 | .coverage
51 | .coverage.*
52 | .cache
53 | nosetests.xml
54 | coverage.xml
55 | *.cover
56 | *.py,cover
57 | .hypothesis/
58 | .pytest_cache/
59 |
60 | # Translations
61 | *.mo
62 | *.pot
63 |
64 | # Django stuff:
65 | *.log
66 | local_settings.py
67 | db.sqlite3
68 | db.sqlite3-journal
69 |
70 | # Flask stuff:
71 | instance/
72 | .webassets-cache
73 |
74 | # Scrapy stuff:
75 | .scrapy
76 |
77 | # Sphinx documentation
78 | docs/_build/
79 |
80 | # PyBuilder
81 | target/
82 |
83 | # Jupyter Notebook
84 | .ipynb_checkpoints
85 |
86 | # IPython
87 | profile_default/
88 | ipython_config.py
89 |
90 | # pyenv
91 | .python-version
92 |
93 | # pipenv
94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
97 | # install all needed dependencies.
98 | #Pipfile.lock
99 |
100 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
101 | __pypackages__/
102 |
103 | # Celery stuff
104 | celerybeat-schedule
105 | celerybeat.pid
106 |
107 | # SageMath parsed files
108 | *.sage.py
109 |
110 | # Environments
111 | .env
112 | .venv
113 | env/
114 | venv/
115 | ENV/
116 | env.bak/
117 | venv.bak/
118 |
119 | # Spyder project settings
120 | .spyderproject
121 | .spyproject
122 |
123 | # Rope project settings
124 | .ropeproject
125 |
126 | # mkdocs documentation
127 | /site
128 |
129 | # mypy
130 | .mypy_cache/
131 | .dmypy.json
132 | dmypy.json
133 |
134 | # Pyre type checker
135 | .pyre/
136 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | NVIDIA Source Code License for RelViT
2 |
3 | 1. Definitions
4 |
5 | “Licensor” means any person or entity that distributes its Work.
6 |
7 | “Software” means the original work of authorship made available under this License.
8 |
9 | “Work” means the Software and any additions to or derivative works of the Software that are made
10 | available under this License.
11 |
12 | The terms “reproduce,” “reproduction,” “derivative works,” and “distribution” have the meaning as
13 | provided under U.S. copyright law; provided, however, that for the purposes of this License,
14 | derivative works shall not include works that remain separable from, or merely link (or bind by name)
15 | to the interfaces of, the Work.
16 |
17 | Works, including the Software, are “made available” under this License by including in or with the
18 | Work either (a) a copyright notice referencing the applicability of this License to the Work, or (b) a
19 | copy of this License.
20 |
21 | 2. License Grant
22 |
23 | 2.1 Copyright Grant. Subject to the terms and conditions of this License, each Licensor grants
24 | to you a perpetual, worldwide, non-exclusive, royalty-free, copyright license to reproduce,
25 | prepare derivative works of, publicly display, publicly perform, sublicense and distribute its
26 | Work and any resulting derivative works in any form.
27 |
28 | 3. Limitations
29 |
30 | 3.1 Redistribution. You may reproduce or distribute the Work only if (a) you do so under this
31 | License, (b) you include a complete copy of this License with your distribution, and (c) you
32 | retain without modification any copyright, patent, trademark, or attribution notices that are
33 | present in the Work.
34 |
35 | 3.2 Derivative Works. You may specify that additional or different terms apply to the use,
36 | reproduction, and distribution of your derivative works of the Work (“Your Terms”) only if (a)
37 | Your Terms provide that the use limitation in Section 3.3 applies to your derivative works, and
38 | (b) you identify the specific derivative works that are subject to Your Terms. Notwithstanding
39 | Your Terms, this License (including the redistribution requirements in Section 3.1) will
40 | continue to apply to the Work itself.
41 |
42 | 3.3 Use Limitation. The Work and any derivative works thereof only may be used or intended
43 | for use non-commercially. Notwithstanding the foregoing, NVIDIA and its affiliates may use
44 | the Work and any derivative works commercially. As used herein, “non-commercially” means
45 | for research or evaluation purposes only.
46 |
47 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim against any Licensor
48 | (including any claim, cross-claim or counterclaim in a lawsuit) to enforce any patents that you
49 | allege are infringed by any Work, then your rights under this License from such Licensor
50 | (including the grant in Section 2.1) will terminate immediately.
51 |
52 | 3.5 Trademarks. This License does not grant any rights to use any Licensor’s or its affiliates’
53 | names, logos, or trademarks, except as necessary to reproduce the notices described in this
54 | License.
55 |
56 | 3.6 Termination. If you violate any term of this License, then your rights under this License
57 | (including the grant in Section 2.1) will terminate immediately.
58 |
59 | 4. Disclaimer of Warranty.
60 |
61 | THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
62 | EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
63 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-
64 | INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER THIS
65 | LICENSE.
66 |
67 | 5. Limitation of Liability.
68 |
69 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL
70 | THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE
71 | SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT,
72 | INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF OR
73 | RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK (INCLUDING BUT
74 | NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, LOST PROFITS OR DATA,
75 | COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER COMMERCIAL DAMAGES OR
76 | LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH
77 | DAMAGES.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | RelViT
2 | ===
3 |
4 | ![RelViT overview](assets/overview.png)
5 |
6 | This repository hosts the code for the paper:
7 |
8 | [RelViT: Concept-guided Vision Transformer for Visual Relational Reasoning (ICLR 2022)](https://web.cs.ucla.edu/~xm/file/relvit_iclr22.pdf)
9 |
10 | by [Xiaojian Ma](http://web.cs.ucla.edu/~xm), [Weili Nie](https://weilinie.github.io/), [Zhiding Yu](https://chrisding.github.io/), [Huaizu Jiang](https://jianghz.me/), [Chaowei Xiao](https://xiaocw11.github.io/), [Yuke Zhu](https://www.cs.utexas.edu/~yukez/) and [Anima Anandkumar](http://tensorlab.cms.caltech.edu/users/anima/)
11 |
12 | [arXiv](https://arxiv.org/abs/2204.11167) | [Poster](https://web.cs.ucla.edu/~xm/file/relvit_iclr22_poster.pdf) | [Slides](https://web.cs.ucla.edu/~xm/file/relvit_iclr22_slides.pdf)
13 |
14 | ## News
15 |
16 | - 🔥🔥 **09/10/2022**: Pre-trained models on GQA are now released.
17 |
18 |
19 | ## Abstract
20 |
21 | Reasoning about visual relationships is central to how humans interpret the visual world. This task remains challenging for current deep learning algorithms since it requires addressing three key technical problems jointly: **1)** identifying object entities and their properties, **2)** inferring semantic relations between pairs of entities, and **3)** generalizing to novel object-relation combinations, i.e., systematic generalization. In this work, we use **vision transformers (ViTs)** as our base model for visual reasoning and make better use of concepts defined as object entities and their relations to improve the reasoning ability of ViTs. Specifically, we introduce a novel concept-feature dictionary to allow flexible image feature retrieval at training time with concept keys. This dictionary enables two new concept-guided auxiliary tasks: 1) a **global task** for promoting relational reasoning, and 2) a **local task** for facilitating semantic object-centric correspondence learning. To examine the systematic generalization of visual reasoning models, we introduce systematic splits for the standard HICO and GQA benchmarks. We show the resulting model, **Concept-guided Vision Transformer** (or **RelViT** for short) significantly outperforms prior approaches on HICO and GQA by **16%** and **13%** in the original split, and by **43%** and **18%** in the systematic split. Our ablation analyses also reveal our model's compatibility with multiple ViT variants and robustness to hyper-parameters.
22 |
23 | ## Installation
24 |
25 | - Install PyTorch:
26 | ```shell
27 | conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch
28 | ```
29 |
30 | - Install the necessary packages with `requirements.txt`:
31 | ```shell
32 | pip install -r requirements.txt
33 | ```
34 |
35 | The code has been tested with Python 3.8, PyTorch 1.11.0, and CUDA 11.6 on Ubuntu 20.04.
36 |
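A quick way to confirm your environment roughly matches the tested setup (a minimal sketch; versions on your machine may differ):

```python
# Sanity-check the Python / PyTorch / CUDA setup before training.
import sys
import torch
import torchvision

print(sys.version.split()[0])     # e.g. 3.8.x
print(torch.__version__)          # e.g. 1.11.0
print(torchvision.__version__)
print(torch.cuda.is_available(), torch.version.cuda)
```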
37 | ## Data Preparation
38 |
39 | Please refer to [data preparation](assets/dataset.md)
40 |
41 | ## Training
42 |
43 | HICO
44 |
45 | ```shell
46 | bash scripts/train_hico_image.sh configs/train_hico.yaml
47 | ```
48 |
49 | ### Note
50 |
51 | - In `configs/train_hico.yaml` you may find some configurable options (a snippet for inspecting them follows this note):
52 |   - use `eval_mode` to run different experiments: original or systematic generalization test
53 |   - use `model_args.encoder_args.encoder` and `load_encoder` to select the vision backbone. There are five options available: `pvtv2_b2`, `pvtv2_b3`, `swin_small`, `swin_base` and `vit_small_16`.
54 |   - use `relvit` to turn the RelViT auxiliary loss on/off
55 |   - use `relvit_weight` to adjust the coefficient of the RelViT auxiliary loss
56 |   - use `relvit_local_only` to choose whether to use only the RelViT local task, only the global task, or both
57 |   - use `relvit_mode` to control whether to include the [EsViT](https://github.com/microsoft/esvit) loss
58 |   - use `relvit_sample_uniform` to choose between uniform and "most-recent" concept sampling
59 |   - use `relvit_concept_use` and `relvit_num_concepts` to choose the concept used by RelViT among `HOI`, `verb` and `object`
60 |
61 | In general, we don't recommend modifying other parameters.
62 |
63 | - All the GPUs will be used by default. To run with the recommended batch size, you may need 1 V100 32G GPU.
64 |
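Below is a minimal sketch for inspecting or tweaking these options programmatically before a run. It assumes PyYAML is available; `train_hico_custom.yaml` is a hypothetical output path, and all keys follow `configs/train_hico.yaml`:

```python
import yaml  # PyYAML, assumed to be installed

with open('configs/train_hico.yaml') as f:
    cfg = yaml.safe_load(f)

# Keys documented in the note above; value meanings follow the comments in the config.
print(cfg['eval_mode'], cfg['relvit'], cfg['relvit_weight'], cfg['relvit_mode'])

# Example tweak: systematic-easy split with uniform concept sampling.
cfg['eval_mode'] = 1
cfg['relvit_sample_uniform'] = True

with open('configs/train_hico_custom.yaml', 'w') as f:  # hypothetical output path
    yaml.safe_dump(cfg, f)
```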
65 |
66 |
67 | GQA
68 |
69 | ```shell
70 | bash scripts/train_gqa_image.sh configs/train_gqa_mcan.yaml
71 | ```
72 |
73 | ### Note
74 |
75 | - In `configs/train_gqa_mcan.yaml` you may find some configurable options:
76 |   - use `eval_mode` to run different experiments: original or systematic generalization test
77 |   - use `model_args.encoder_args.encoder` and `load_encoder` to select the vision backbone. There are five options available: `pvtv2_b2`, `pvtv2_b3`, `swin_small`, `swin_base` and `vit_small_16`.
78 |   - use `relvit` to turn the RelViT auxiliary loss on/off
79 |   - use `relvit_weight` to adjust the coefficient of the RelViT auxiliary loss
80 |   - use `relvit_local_only` to choose whether to use only the RelViT local task, only the global task, or both
81 |   - use `relvit_mode` to control whether to include the [EsViT](https://github.com/microsoft/esvit) loss
82 |   - use `relvit_sample_uniform` to choose between uniform and "most-recent" concept sampling
83 |
84 | In general, we don't recommend modifying other parameters.
85 |
86 | - All the GPUs will be used by default. To run with the recommended batch size, you may need up to 64 V100 32G GPUs. This is because we need to fine-tune the vision backbone during training.
87 |
88 |
89 |
90 | ## Testing
91 |
92 | HICO
93 |
94 | ```shell
95 | bash scripts/train_hico_image.sh configs/train_hico.yaml --test_only --test_model
96 | ```
97 |
98 |
99 |
100 | GQA
101 |
102 | ```shell
103 | bash scripts/train_gqa_image.sh configs/train_gqa_mcan.yaml --test_only --test_model
104 | ```
105 |
106 |
107 |
108 | ## Pre-trained models
109 |
110 | | tag | encoder | experiment | result | URL |
111 | |---------------------|--------------|------------|--------|--------------------------------------------------------------------------------------------|
112 | | `swin-small-relvit` | `swin_small` | GQA (val) | 61.38 | [link](https://drive.google.com/file/d/1efkfy7bM6PmzKZy6ffQQq98OkLLgcZd3/view?usp=sharing) |
113 | | `swin-base-relvit` | `swin_base` | GQA (val) | 65.54 | [link](https://drive.google.com/file/d/1u66PqHJiqGJZsXWeLsjooVAFEdkjDVK3/view?usp=sharing) |
114 |
115 |
116 | ## License
117 |
118 | Please check the LICENSE file for both the code and the released pre-trained models. This work may be used non-commercially, meaning for research or evaluation purposes only. For business inquiries, please contact researchinquiries@nvidia.com.
119 |
120 | ## Acknowledgement
121 |
122 | The authors have referred to the following projects:
123 |
124 | [SimCLR](https://github.com/google-research/simclr)
125 |
126 | [DenseCL](https://github.com/WXinlong/DenseCL)
127 |
128 | [EsViT](https://github.com/microsoft/esvit)
129 |
130 | [Swin-Transformer](https://github.com/microsoft/Swin-Transformer)
131 |
132 | [PVT](https://github.com/whai362/PVT)
133 |
134 | [HICODet](https://github.com/fredzzhang/hicodet)
135 |
136 | [MCAN](https://github.com/MILVLG/mcan-vqa)
137 |
138 | ## Citation
139 |
140 | Please consider citing our paper if you find our work helpful for your research:
141 |
142 | ```
143 | @inproceedings{ma2022relvit,
144 | title={RelViT: Concept-guided Vision Transformer for Visual Relational Reasoning},
145 | author={Xiaojian Ma and Weili Nie and Zhiding Yu and Huaizu Jiang and Chaowei Xiao and Yuke Zhu and Song-Chun Zhu and Anima Anandkumar},
146 | booktitle={International Conference on Learning Representations},
147 | year={2022},
148 | url={https://openreview.net/forum?id=afoV8W3-IYp}
149 | }
150 | ```
151 |
--------------------------------------------------------------------------------
/assets/data/gqa/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/RelViT/ae34a63975738a7eef219fe3a2080286f7dfc753/assets/data/gqa/.gitkeep
--------------------------------------------------------------------------------
/assets/data/hico/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/RelViT/ae34a63975738a7eef219fe3a2080286f7dfc753/assets/data/hico/.gitkeep
--------------------------------------------------------------------------------
/assets/dataset.md:
--------------------------------------------------------------------------------
1 | Data Preparation
2 | ===
3 |
4 | 1. Download the HICO dataset from [here](https://drive.google.com/file/d/1QZcJmGVlF9f4h-XLWe9Gkmnmj2z1gSnk/view) and extract it to `./assets/data/hico`
5 |
6 | 2. Download our customized HICO annotations from [here](https://drive.google.com/file/d/11RvCM0KIBB4pFw0NuaX6Z6IPvnkrb83H/view?usp=sharing) and extract them to `./assets/data/hico/hico_20160224_det`
7 |
8 | 3. Download the images of GQA from [here](https://downloads.cs.stanford.edu/nlp/data/gqa/images.zip) and extract them to `./assets/data/gqa`
9 |
10 | 4. Download our customized GQA annotations and meta files from [here](https://drive.google.com/file/d/1hKlRiikRkfZNB-St4kOzuMdLv8N5xk15/view?usp=sharing) and extract them to `./assets/data/gqa`
11 |
12 | 5. Download the pretrained vision backbones and other files from [here](https://drive.google.com/file/d/1pxmUxkk5t8Bg_cS_jdaQgugCqYddZInE/view?usp=sharing) and extract them to `./cache`
13 |
14 | 6. The file structure should look like the listing below (a quick path check follows it)
15 | ```plain
16 | data
17 | ├── gqa
18 | │ ├── dicts.json
19 | │ ├── gqa_dic.pkl
20 | │ ├── images
21 | │ └── raw
22 | │ └── questions1.2
23 | │ ├── test_balanced_questions.json
24 | │ ├── testdev_balanced_questions.json
25 | │ ├── train_balanced_concepts.json
26 | │ ├── train_balanced_questions.json
27 | │ ├── train_sys_reduced_concepts.json
28 | │ ├── train_sys_reduced_questions.json
29 | │ ├── val_balanced_questions.json
30 | │ └── val_sys_reduced_questions.json
31 | └── hico
32 | └── hico_20160224_det
33 | ├── images
34 | │ ├── test2015
35 | │ └── train2015
36 | ├── instances_test2015.json
37 | ├── instances_train2015.json
38 | ├── sys_vcl_nonrare_instances_test2015.json
39 | ├── sys_vcl_nonrare_instances_train2015.json
40 | ├── sys_vcl_rare_instances_test2015.json
41 | └── sys_vcl_rare_instances_train2015.json
42 | ```
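A small sketch to verify the layout above after extraction (run from the repository root; paths follow the listing):

```python
import os.path as osp

expected = [
    './assets/data/gqa/dicts.json',
    './assets/data/gqa/raw/questions1.2/train_balanced_questions.json',
    './assets/data/gqa/raw/questions1.2/train_balanced_concepts.json',
    './assets/data/hico/hico_20160224_det/instances_train2015.json',
    './assets/data/hico/hico_20160224_det/images/train2015',
    './cache',
]
for p in expected:
    print('OK  ' if osp.exists(p) else 'MISS', p)
```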
--------------------------------------------------------------------------------
/assets/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/RelViT/ae34a63975738a7eef219fe3a2080286f7dfc753/assets/overview.png
--------------------------------------------------------------------------------
/cache/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/RelViT/ae34a63975738a7eef219fe3a2080286f7dfc753/cache/.gitkeep
--------------------------------------------------------------------------------
/configs/train_gqa_mcan.yaml:
--------------------------------------------------------------------------------
1 | train_dataset: gqa
2 | train_dataset_args:
3 | root_dir: ./assets/data/gqa
4 | split: train
5 |
6 | test_dataset: gqa
7 | test_dataset_args:
8 | root_dir: ./assets/data/gqa
9 | split: val
10 |
11 | # model: mcan
12 | # Note: Modifying the hyper-param of MCAN is not recommended. But you may find them in ./utils/__init__.py (MCAN_GQA_PARAMS)
13 | model: mcan-customized
14 | model_args:
15 | word_emb_path: ./cache/gqa_word_embed.npy
16 | encoder: transparent_superpixel_encoder
17 | encoder_args: {encoder: pvtv2_b2, use_boxes_dim: False}
18 | load_encoder: ./cache/pvtv2_b2-{}.pth
19 | encoder_pretrain: imagenet
20 |
21 | train_batches: 1000000
22 | ep_per_batch: 1
23 | max_epoch: 12
24 |
25 | # 0 -- original
26 | # 1 -- systematic
27 | eval_mode: 0
28 |
29 | ###########
30 | # RelViT
31 | ###########
32 | relvit: True
33 | relvit_weight: 1.0
34 | # temperature
35 | relvit_loss_tau: 0.04
36 | # 0 -- both
37 | # 1 -- local only
38 | # 2 -- global only
39 | relvit_local_only: 0
40 | # 0 -- EsViT
41 | # 1 -- EsViT+RelViT
42 | # 2 -- RelViT
43 | relvit_mode: 1
44 | # False -- Most recent
45 | # True -- Uniform
46 | relvit_sample_uniform: True
47 | relvit_num_concepts: 1615
48 |
49 | ###########
50 | # MoCo
51 | ###########
52 | relvit_moco_m: 0.999
53 | relvit_moco_use_queue: False
54 | relvit_moco_K: 10
55 | # 49 for pvtv2 and swin-s, 196 for vit-s-16
56 | relvit_num_tokens: 49
57 |
58 | optimizer: adamw
59 | optimizer_args: {lr: 0.0001, weight_decay: 0, 'milestones': [8, 10], eps: 1e-8}
60 |
61 | print_freq: 10
62 | save_epoch: 1
63 | eval_epoch: 1
64 | grad_norm: 0.5
--------------------------------------------------------------------------------
/configs/train_hico.yaml:
--------------------------------------------------------------------------------
1 | train_dataset: hicodet
2 | train_dataset_args:
3 | im_dir: ./assets/data/hico
4 | split: train
5 | augment_plus: true
6 |
7 | test_dataset: hicodet
8 | test_dataset_args:
9 | im_dir: ./assets/data/hico
10 | split: test
11 |
12 | model: classifier
13 | model_args:
14 | encoder: transparent_superpixel_encoder
15 | encoder_args: {encoder: swin_base, use_boxes_dim: False}
16 | classifier: max-pooling-classifier-twoheads
17 | classifier_args: {n_classes: 600, n_classes1: 117, n_classes2: 80}
18 | load_encoder: ./cache/swin_base-{}.pth
19 | encoder_pretrain: imagenet
20 |
21 | train_batches: 100000
22 | ep_per_batch: 16
23 | max_epoch: 30
24 |
25 | # 1 -- systematic-easy
26 | # 2 -- systematic-hard
27 | # 3 -- original (two heads)
28 | eval_mode: 3
29 |
30 | ###########
31 | # RelViT
32 | ###########
33 | relvit: True
34 | relvit_weight: 0.2
35 | # temperature
36 | relvit_loss_tau: 0.04
37 | # 0 -- both
38 | # 1 -- local only
39 | # 2 -- global only
40 | relvit_local_only: 0
41 | # 0 -- EsViT
42 | # 1 -- EsViT+RelViT
43 | # 2 -- RelViT
44 | relvit_mode: 1
45 | # False -- Most recent
46 | # True -- Uniform
47 | relvit_sample_uniform: False
48 | # 0 & 600 -- hoi
49 | # 1 & 117 -- verb
50 | # 2 & 80 -- object
51 | relvit_concept_use: 0
52 | relvit_num_concepts: 600
53 |
54 | ###########
55 | # MoCo
56 | ###########
57 | relvit_moco_m: 0.999
58 | relvit_moco_use_queue: False
59 | relvit_moco_K: 10
60 | # 49 for pvtv2 and swin-s, 196 for vit-s-16
61 | relvit_num_tokens: 49
62 |
63 | optimizer: adamw
64 | optimizer_args: {lr: 0.00015, weight_decay: 0, 'milestones': [15, 25]}
65 |
66 | print_freq: 10
67 | save_epoch: 1
68 | eval_epoch: 1
69 |
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | # ----------------------------------------------------------------------
2 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # This work is licensed under the NVIDIA Source Code License
5 | # for RelViT. To view a copy of this license, see the LICENSE file.
6 | # ----------------------------------------------------------------------
7 |
8 | from .datasets import make
9 | from . import image_hicodet_bbox
10 | from . import image_gqa
11 |
--------------------------------------------------------------------------------
/datasets/datasets.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------
2 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # This file has been modified from a file in the Bongard-HOI library
5 | # which was released under the NVIDIA Source Code Licence.
6 | #
7 | # Source:
8 | # https://github.com/NVlabs/Bongard-HOI
9 | #
10 | # The license for the original version of this file can be
11 | # found in https://github.com/NVlabs/Bongard-HOI/blob/master/LICENSE
12 | # The modifications to this file are subject to the same NVIDIA Source Code Licence.
13 | # ---------------------------------------------------------------
14 |
15 |
16 | import os
17 |
18 |
19 | datasets = {}
20 | def register(name):
21 | def decorator(cls):
22 | datasets[name] = cls
23 | return cls
24 | return decorator
25 |
26 |
27 | def make(name, **kwargs):
28 | dataset = datasets[name](**kwargs)
29 | return dataset
30 |
--------------------------------------------------------------------------------
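The file above implements a tiny name-to-class registry: `@register('name')` records a dataset class in the module-level `datasets` dict and `make(name, **kwargs)` instantiates it. A self-contained sketch of the pattern (the `toy` entry below is illustrative only; the real entries registered in this repo are `'gqa'` and `'hicodet'`):

```python
from datasets.datasets import register, make

@register('toy')                 # records Toy under the key 'toy'
class Toy:
    def __init__(self, split='train'):
        self.split = split

ds = make('toy', split='val')    # equivalent to Toy(split='val')
print(type(ds).__name__, ds.split)
```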
/datasets/image_gqa.py:
--------------------------------------------------------------------------------
1 | # ----------------------------------------------------------------------
2 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # This work is licensed under the NVIDIA Source Code License
5 | # for RelViT. To view a copy of this license, see the LICENSE file.
6 | # ----------------------------------------------------------------------
7 |
8 | import json
9 | import re
10 | import glob
11 | import os.path as osp
12 | import collections
13 |
14 | import numpy as np
15 | from PIL import Image
16 | import torch
17 | from torch.utils.data import Dataset
18 | from torchvision import transforms
19 | from torch.utils.data._utils.collate import default_collate
20 |
21 | import utils
22 | from .datasets import register
23 |
24 |
25 | def label_to_onehot(label, num_class):
26 | # label: [N]
27 | onehot = torch.zeros(num_class).to(label)
28 | for i in label:
29 | onehot[i] = 1
30 | return onehot
31 |
32 |
33 | contractions = {
34 | "aint": "ain't", "arent": "aren't", "cant": "can't", "couldve":
35 | "could've", "couldnt": "couldn't", "couldn'tve": "couldn't've",
36 | "couldnt've": "couldn't've", "didnt": "didn't", "doesnt":
37 | "doesn't", "dont": "don't", "hadnt": "hadn't", "hadnt've":
38 | "hadn't've", "hadn'tve": "hadn't've", "hasnt": "hasn't", "havent":
39 | "haven't", "hed": "he'd", "hed've": "he'd've", "he'dve":
40 | "he'd've", "hes": "he's", "howd": "how'd", "howll": "how'll",
41 | "hows": "how's", "Id've": "I'd've", "I'dve": "I'd've", "Im":
42 | "I'm", "Ive": "I've", "isnt": "isn't", "itd": "it'd", "itd've":
43 | "it'd've", "it'dve": "it'd've", "itll": "it'll", "let's": "let's",
44 | "maam": "ma'am", "mightnt": "mightn't", "mightnt've":
45 | "mightn't've", "mightn'tve": "mightn't've", "mightve": "might've",
46 | "mustnt": "mustn't", "mustve": "must've", "neednt": "needn't",
47 | "notve": "not've", "oclock": "o'clock", "oughtnt": "oughtn't",
48 | "ow's'at": "'ow's'at", "'ows'at": "'ow's'at", "'ow'sat":
49 | "'ow's'at", "shant": "shan't", "shed've": "she'd've", "she'dve":
50 | "she'd've", "she's": "she's", "shouldve": "should've", "shouldnt":
51 | "shouldn't", "shouldnt've": "shouldn't've", "shouldn'tve":
52 | "shouldn't've", "somebody'd": "somebodyd", "somebodyd've":
53 | "somebody'd've", "somebody'dve": "somebody'd've", "somebodyll":
54 | "somebody'll", "somebodys": "somebody's", "someoned": "someone'd",
55 | "someoned've": "someone'd've", "someone'dve": "someone'd've",
56 | "someonell": "someone'll", "someones": "someone's", "somethingd":
57 | "something'd", "somethingd've": "something'd've", "something'dve":
58 | "something'd've", "somethingll": "something'll", "thats":
59 | "that's", "thered": "there'd", "thered've": "there'd've",
60 | "there'dve": "there'd've", "therere": "there're", "theres":
61 | "there's", "theyd": "they'd", "theyd've": "they'd've", "they'dve":
62 | "they'd've", "theyll": "they'll", "theyre": "they're", "theyve":
63 | "they've", "twas": "'twas", "wasnt": "wasn't", "wed've":
64 | "we'd've", "we'dve": "we'd've", "weve": "we've", "werent":
65 | "weren't", "whatll": "what'll", "whatre": "what're", "whats":
66 | "what's", "whatve": "what've", "whens": "when's", "whered":
67 | "where'd", "wheres": "where's", "whereve": "where've", "whod":
68 | "who'd", "whod've": "who'd've", "who'dve": "who'd've", "wholl":
69 | "who'll", "whos": "who's", "whove": "who've", "whyll": "why'll",
70 | "whyre": "why're", "whys": "why's", "wont": "won't", "wouldve":
71 | "would've", "wouldnt": "wouldn't", "wouldnt've": "wouldn't've",
72 | "wouldn'tve": "wouldn't've", "yall": "y'all", "yall'll":
73 | "y'all'll", "y'allll": "y'all'll", "yall'd've": "y'all'd've",
74 | "y'alld've": "y'all'd've", "y'all'dve": "y'all'd've", "youd":
75 | "you'd", "youd've": "you'd've", "you'dve": "you'd've", "youll":
76 | "you'll", "youre": "you're", "youve": "you've"
77 | }
78 |
79 | manual_map = { 'none': '0',
80 | 'zero': '0',
81 | 'one': '1',
82 | 'two': '2',
83 | 'three': '3',
84 | 'four': '4',
85 | 'five': '5',
86 | 'six': '6',
87 | 'seven': '7',
88 | 'eight': '8',
89 | 'nine': '9',
90 | 'ten': '10'}
91 | articles = ['a', 'an', 'the']
92 | period_strip = re.compile("(?!<=\d)(\.)(?!\d)")
93 | comma_strip = re.compile("(\d)(\,)(\d)")
94 | punct = [';', r"/", '[', ']', '"', '{', '}',
95 | '(', ')', '=', '+', '\\', '_', '-',
96 | '>', '<', '@', '`', ',', '?', '!']
97 |
98 | def process_punctuation(inText):
99 | outText = inText
100 | for p in punct:
101 | if (p + ' ' in inText or ' ' + p in inText) \
102 | or (re.search(comma_strip, inText) != None):
103 | outText = outText.replace(p, '')
104 | else:
105 | outText = outText.replace(p, ' ')
106 | outText = period_strip.sub("", outText, re.UNICODE)
107 | return outText
108 |
109 |
110 | def process_digit_article(inText):
111 | outText = []
112 | tempText = inText.lower().split()
113 | for word in tempText:
114 | word = manual_map.setdefault(word, word)
115 | if word not in articles:
116 | outText.append(word)
117 | else:
118 | pass
119 | for wordId, word in enumerate(outText):
120 | if word in contractions:
121 | outText[wordId] = contractions[word]
122 | outText = ' '.join(outText)
123 | return outText
124 |
125 |
126 | def prep_ans(answer):
127 | answer = process_digit_article(process_punctuation(answer))
128 | answer = answer.replace(',', '')
129 | return answer
130 |
131 |
132 | class Scale(object):
133 | """Rescale the input PIL.Image to the given size.
134 | Args:
135 | size (sequence or int): Desired output size. If size is a sequence like
136 | (w, h), output size will be matched to this. If size is an int,
137 | smaller edge of the image will be matched to this number.
138 | i.e, if height > width, then image will be rescaled to
139 | (size * height / width, size)
140 | interpolation (int, optional): Desired interpolation. Default is
141 | ``PIL.Image.BILINEAR``
142 | """
143 |
144 | def __init__(self, size, interpolation=Image.BILINEAR):
145 |         assert isinstance(size, int) or (isinstance(size, collections.abc.Iterable) and len(size) == 2)
146 | self.size = size
147 | self.interpolation = interpolation
148 |
149 | def __call__(self, img):
150 | """
151 | Args:
152 | img (PIL.Image): Image to be scaled.
153 | Returns:
154 | PIL.Image: Rescaled image.
155 | """
156 | if isinstance(self.size, int):
157 | w, h = img.size
158 | if (w <= h and w == self.size) or (h <= w and h == self.size):
159 | return img
160 | if w < h:
161 | ow = self.size
162 | oh = int(self.size * h / w)
163 | return img.resize((ow, oh), self.interpolation)
164 | else:
165 | oh = self.size
166 | ow = int(self.size * w / h)
167 | return img.resize((ow, oh), self.interpolation)
168 | else:
169 | return img.resize(self.size, self.interpolation)
170 |
171 |
172 | @register('gqa')
173 | class GQA(Dataset):
174 | def __init__(self, root_dir=None, split='train', eval_mode=0, **kwargs) -> None:
175 | self.eval_mode = eval_mode
176 | self.FRCN_FEAT_SIZE = utils.MCAN_GQA_PARAMS['FRCN_FEAT_SIZE']
177 | self.GRID_FEAT_SIZE = utils.MCAN_GQA_PARAMS['GRID_FEAT_SIZE']
178 | self.BBOX_FEAT_SIZE = utils.MCAN_GQA_PARAMS['BBOX_FEAT_SIZE']
179 | self.root_dir = root_dir
180 | self.split = split
181 | self.dict_json = osp.join(self.root_dir, 'dicts.json')
182 |
183 | if self.eval_mode:
184 | # object/edge/degree
185 | fn = osp.join(self.root_dir, 'raw', 'questions1.2', '{}_sys_reduced_questions.json'.format(self.split))
186 |             fn_concept = osp.join(self.root_dir, 'raw', 'questions1.2', 'train_sys_reduced_concepts.json')
187 | else:
188 |             if self.split == 'val':
189 | split = 'val'
190 | else:
191 | split = self.split
192 | fn = osp.join(self.root_dir, 'raw', 'questions1.2', '{}_balanced_questions.json'.format(split))
193 |             fn_concept = osp.join(self.root_dir, 'raw', 'questions1.2', 'train_balanced_concepts.json')
194 | with open(fn, 'r') as f:
195 | self.ques_dict = json.load(f)
196 | self.qid_list = list(self.ques_dict.keys())
197 | with open(fn_concept, 'r') as f:
198 | self.concept_dict = json.load(f)
199 |
200 | def img_feat_path_load(path_list):
201 | iid_to_path = {}
202 | for ix, path in enumerate(path_list):
203 | iid = path.split('/')[-1].split('.')[0]
204 | iid_to_path[iid] = path
205 |
206 | return iid_to_path
207 | self.iid_to_img_path = img_feat_path_load(glob.glob(osp.join(
208 | self.root_dir,
209 | 'images',
210 | '*.jpg'
211 | )))
212 | self.data_size = self.ques_dict.__len__()
213 | self.num_concept = 1615
214 |
215 | # Tokenize
216 | # self.token_to_ix, self.pretrained_emb, max_token = self.tokenize(self.dict_json, True)
217 | self.token_to_ix, _, max_token = self.tokenize(self.dict_json, True)
218 | self.token_size = self.token_to_ix.__len__()
219 |
220 | self.max_token = -1
221 | if self.max_token == -1:
222 | self.max_token = max_token
223 |
224 | # Answers statistic
225 | self.ans_to_ix, self.ix_to_ans = self.ans_stat(self.dict_json)
226 | self.ans_size = self.ans_to_ix.__len__()
227 |
228 | # For RGB
229 | self.pix_mean = (0.485, 0.456, 0.406)
230 | self.pix_std = (0.229, 0.224, 0.225)
231 | self.transform = transforms.Compose([
232 | transforms.ToTensor(),
233 | transforms.Normalize(mean=self.pix_mean, std=self.pix_std)
234 | ])
235 | # For two-view denseCL
236 | # https://github.com/WXinlong/DenseCL/blob/main/configs/selfsup/densecl/densecl_coco_800ep.py
237 | self.densecl_aug = transforms.Compose([
238 | transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
239 | transforms.RandomApply(torch.nn.ModuleList([
240 | transforms.ColorJitter(
241 | brightness=0.4,
242 | contrast=0.4,
243 | saturation=0.4,
244 | hue=0.1,
245 | )]),
246 | p=0.8
247 | ),
248 | transforms.RandomGrayscale(p=0.2),
249 | transforms.RandomApply(torch.nn.ModuleList([
250 | transforms.GaussianBlur(
251 | # https://github.com/sthalles/SimCLR/blob/master/data_aug/contrastive_learning_dataset.py#L20
252 | kernel_size=(int(0.1*224)//2)*2+1,
253 | sigma=(0.1, 2.0)
254 | )]),
255 | p=0.5
256 | ),
257 | transforms.RandomHorizontalFlip(),
258 | ])
259 |
260 |
261 | def __len__(self):
262 | return self.data_size
263 |
264 | def __getitem__(self, idx):
265 | ques_ix_iter, ans_iter, iid = self.load_ques_ans(idx, raw_q=False)
266 | img = Image.open(self.iid_to_img_path[iid]).convert('RGB')
267 | img = self.transform(img)
268 | second_im = self.densecl_aug(img)
269 | third_im = self.densecl_aug(img)
270 |
271 | concept = np.array(self.load_concept(idx), dtype=np.int64)
272 | concept = label_to_onehot(torch.from_numpy(concept).long(), self.num_concept)
273 |
274 | # len(ret), ret[-3]
275 | # 7, True
276 | ret = []
277 | ret.append(img)
278 | ret.append(second_im)
279 | ret.append(third_im)
280 | ret.append(torch.from_numpy(ques_ix_iter))
281 | ret.append(True)
282 | ret.append(torch.from_numpy(ans_iter))
283 | ret.append(concept)
284 |
285 | return tuple(ret)
286 |
287 | def tokenize(self, json_file, use_glove):
288 | token_to_ix, max_token = json.load(open(json_file, 'r'))[2:]
289 | # spacy_tool = None
290 | # if use_glove:
291 | # spacy_tool = en_vectors_web_lg.load()
292 |
293 | # pretrained_emb = []
294 | # for word in token_to_ix:
295 | # if use_glove:
296 | # pretrained_emb.append(spacy_tool(word).vector)
297 | # pretrained_emb = np.array(pretrained_emb)
298 |
299 | pretrained_emb = None
300 | return token_to_ix, pretrained_emb, max_token
301 |
302 | def ans_stat(self, json_file):
303 | ans_to_ix, ix_to_ans = json.load(open(json_file, 'r'))[:2]
304 |
305 | return ans_to_ix, ix_to_ans
306 |
307 | def load_concept(self, idx):
308 | qid = self.qid_list[idx]
309 | if self.split == 'train':
310 | ret = self.concept_dict[qid]
311 | else:
312 | ret = []
313 | return ret
314 |
315 | def load_ques_ans(self, idx, raw_q=False):
316 |
317 | qid = self.qid_list[idx]
318 | iid = self.ques_dict[qid]['imageId']
319 |
320 | ques = self.ques_dict[qid]['question']
321 | if raw_q:
322 | ques_ix_iter = ques
323 | else:
324 | ques_ix_iter = self.proc_ques(ques, self.token_to_ix, max_token=self.max_token)
325 |
326 | # process answers
327 | ans = self.ques_dict[qid]['answer']
328 | ans_iter = self.proc_ans(ans, self.ans_to_ix)
329 |
330 | return ques_ix_iter, ans_iter, iid
331 |
332 | def proc_img_feat(self, img_feat, img_feat_pad_size):
333 | if img_feat.shape[0] > img_feat_pad_size:
334 | img_feat = img_feat[:img_feat_pad_size]
335 |
336 | img_feat = np.pad(
337 | img_feat,
338 | ((0, img_feat_pad_size - img_feat.shape[0]), (0, 0)),
339 | mode='constant',
340 | constant_values=0
341 | )
342 |
343 | return img_feat
344 |
345 | def proc_bbox_feat(self, bbox, img_shape):
346 | bbox_feat = np.zeros((bbox.shape[0], 5), dtype=np.float32)
347 |
348 | bbox_feat[:, 0] = bbox[:, 0] / float(img_shape[1])
349 | bbox_feat[:, 1] = bbox[:, 1] / float(img_shape[0])
350 | bbox_feat[:, 2] = bbox[:, 2] / float(img_shape[1])
351 | bbox_feat[:, 3] = bbox[:, 3] / float(img_shape[0])
352 | bbox_feat[:, 4] = (bbox[:, 2] - bbox[:, 0]) * (bbox[:, 3] - bbox[:, 1]) / float(img_shape[0] * img_shape[1])
353 |
354 | return bbox_feat
355 |
356 | def proc_ques(self, ques, token_to_ix, max_token):
357 | ques_ix = np.zeros(max_token, np.int64)
358 |
359 | words = re.sub(
360 | r"([.,'!?\"()*#:;])",
361 | '',
362 | ques.lower()
363 | ).replace('-', ' ').replace('/', ' ').split()
364 |
365 | for ix, word in enumerate(words):
366 | if word in token_to_ix:
367 | ques_ix[ix] = token_to_ix[word]
368 | else:
369 | ques_ix[ix] = token_to_ix['UNK']
370 |
371 | if ix + 1 == max_token:
372 | break
373 |
374 | return ques_ix
375 |
376 | def proc_ans(self, ans, ans_to_ix):
377 | ans_ix = np.zeros(1, np.int64)
378 | ans = prep_ans(ans)
379 | ans_ix[0] = ans_to_ix[ans]
380 |
381 | return ans_ix
382 |
383 |
384 | def collate_gqa(batch):
385 | def _pad_tensor(tensor_list):
386 | max_imh, max_imw = -1, -1
387 | for tensor_i in tensor_list:
388 | # import pdb; pdb.set_trace()
389 | imh, imw = tensor_i.shape[-2], tensor_i.shape[-1]
390 | max_imh = max(max_imh, imh)
391 | max_imw = max(max_imw, imw)
392 |
393 | for idx, tensor_i in enumerate(tensor_list):
394 | pad_tensor_i = tensor_i.new_full(list(tensor_i.shape[:-2]) + [max_imh, max_imw], 0)
395 | imh, imw = tensor_i.shape[-2], tensor_i.shape[-1]
396 | pad_tensor_i[..., :imh, :imw].copy_(tensor_i)
397 | tensor_list[idx] = pad_tensor_i
398 | return tensor_list
399 |
400 | # len(ret), ret[-3]
401 | # 7, True
402 | if len(batch[0][0].shape) == 2:
403 | return default_collate(batch)
404 | else:
405 | list_ims = []
406 | list_second_ims = []
407 | list_third_ims = []
408 | list_qs = []
409 | list_xs = []
410 | list_as = []
411 | list_cs = []
412 | for b in batch:
413 | list_ims.append(b[0])
414 | list_second_ims.append(b[1])
415 | list_third_ims.append(b[2])
416 | list_qs.append(b[3])
417 | list_xs.append(b[4])
418 | list_as.append(b[5])
419 | list_cs.append(b[6])
420 | list_ims = torch.stack(_pad_tensor(list_ims))
421 | list_second_ims = torch.stack(_pad_tensor(list_second_ims))
422 | list_third_ims = torch.stack(_pad_tensor(list_third_ims))
423 | if list_xs[0]:
424 | list_qs = torch.stack(list_qs)
425 | list_as = torch.stack(list_as)
426 | list_cs = torch.stack(list_cs)
427 | return list_ims, list_second_ims, list_third_ims, list_qs, list_xs, list_as, list_cs
428 |
--------------------------------------------------------------------------------
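A sketch of how `GQA` and `collate_gqa` above are typically wired into a `DataLoader`. Constructor arguments follow `configs/train_gqa_mcan.yaml`, and it assumes the GQA files from `assets/dataset.md` are already in place:

```python
from torch.utils.data import DataLoader
from datasets.image_gqa import GQA, collate_gqa

ds = GQA(root_dir='./assets/data/gqa', split='train', eval_mode=0)
loader = DataLoader(ds, batch_size=2, shuffle=True, collate_fn=collate_gqa)

# __getitem__ returns a 7-tuple; collate_gqa zero-pads the three image views to the
# largest H/W in the batch and stacks questions, answers and concept targets.
img, img_v2, img_v3, ques, flags, ans, concepts = next(iter(loader))
print(img.shape, ques.shape, ans.shape, concepts.shape)
```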
/datasets/image_hicodet_bbox.py:
--------------------------------------------------------------------------------
1 | # ----------------------------------------------------------------------
2 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # This work is licensed under the NVIDIA Source Code License
5 | # for RelViT. To view a copy of this license, see the LICENSE file.
6 | # ----------------------------------------------------------------------
7 |
8 | import os
9 | import os.path as osp
10 | import json
11 | import pickle
12 | import numpy as np
13 | from PIL import Image, ImageFilter
14 | import sklearn.metrics
15 |
16 | import cv2
17 | from typing import Any, Callable, List, Optional, Tuple
18 | import torch
19 | from torch.utils.data import Dataset
20 | from torchvision import transforms
21 |
22 | # from detectron2
23 | from detectron2.structures import Boxes
24 | from detectron2.data import transforms as T
25 |
26 | import utils
27 | from .datasets import register
28 |
29 |
30 | def label_to_onehot(label, num_class):
31 | # label: [N]
32 | onehot = torch.zeros(num_class).to(label)
33 | for i in label:
34 | onehot[i] = 1
35 | return onehot
36 |
37 |
38 | @register('hicodet')
39 | class HICODet(Dataset):
40 | def __init__(self, image_size=256, box_size=256, eval_mode=0, **kwargs) -> None:
41 | self.eval_mode = eval_mode
42 | im_dir = kwargs.get('im_dir')
43 | self.split = kwargs.get('split') # train or test
44 | self._root = osp.join(im_dir, 'hico_20160224_det', 'images', '{}2015'.format(self.split))
45 | # rare
46 | if self.eval_mode == 1:
47 | anno_file = osp.join(im_dir, 'hico_20160224_det', 'sys_vcl_rare_instances_{}2015.json'.format(self.split))
48 | # non-rare
49 | elif self.eval_mode == 2:
50 | anno_file = osp.join(im_dir, 'hico_20160224_det', 'sys_vcl_nonrare_instances_{}2015.json'.format(self.split))
51 | else:
52 | anno_file = osp.join(im_dir, 'hico_20160224_det', 'instances_{}2015.json'.format(self.split))
53 |
54 | with open(anno_file, 'r') as f:
55 | anno = json.load(f)
56 |
57 | self.num_object_cls = 80
58 | self.num_interaction_cls = 600
59 | self.num_action_cls = 117
60 | self._anno_file = anno_file
61 |
62 | # Load annotations
63 | self._load_annotation_and_metadata(anno)
64 |
65 | self.pix_mean = (0.485, 0.456, 0.406)
66 | self.pix_std = (0.229, 0.224, 0.225)
67 | # detectron2-style data augmentation
68 | sample_style = 'range'
69 | augmentations = [T.ResizeShortestEdge(image_size, int(image_size * 2), sample_style)]
70 | if kwargs.get('augment') or kwargs.get('augment_plus'):
71 | augmentations.append(
72 | T.RandomFlip(
73 | horizontal=True,
74 | vertical=False,
75 | )
76 | )
77 | if kwargs.get('augment_plus'):
78 | self.photo_aug = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5/3.14)
79 | else:
80 | self.photo_aug = None
81 | self.augmentations = T.AugmentationList(augmentations)
82 |
83 | # For two-view denseCL
84 | # https://github.com/WXinlong/DenseCL/blob/main/configs/selfsup/densecl/densecl_coco_800ep.py
85 | self.densecl_aug = transforms.Compose([
86 | transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
87 | transforms.RandomApply(torch.nn.ModuleList([
88 | transforms.ColorJitter(
89 | brightness=0.4,
90 | contrast=0.4,
91 | saturation=0.4,
92 | hue=0.1,
93 | )]),
94 | p=0.8
95 | ),
96 | transforms.RandomGrayscale(p=0.2),
97 | transforms.RandomApply(torch.nn.ModuleList([
98 | transforms.GaussianBlur(
99 | # https://github.com/sthalles/SimCLR/blob/master/data_aug/contrastive_learning_dataset.py#L20
100 | kernel_size=(int(0.1*224)//2)*2+1,
101 | sigma=(0.1, 2.0)
102 | )]),
103 | p=0.5
104 | ),
105 | transforms.RandomHorizontalFlip(),
106 | ])
107 |
108 | def __len__(self) -> int:
109 | """Return the number of images"""
110 | return len(self._idx)
111 |
112 | def __getitem__(self, i: int) -> tuple:
113 | """
114 | Arguments:
115 | i(int): Index to an image
116 | """
117 | ind = self._idx[i]
118 | im = cv2.imread(osp.join(self._root, self._filenames[ind]))
119 | assert im is not None
120 | # BGR to RGB
121 | im = im[:, :, ::-1]
122 | # TODO: only keeps the first bbox pair
123 | sub_bbox = torch.Tensor([self._anno[ind]['boxes_h'][0]])
124 | obj_bbox = torch.Tensor([self._anno[ind]['boxes_o'][0]])
125 | union_bbox = torch.cat((torch.min(sub_bbox[0, :2], obj_bbox[0, :2]), torch.max(sub_bbox[0, 2:], obj_bbox[0, 2:]))).unsqueeze(0)
126 | boxes = torch.stack((sub_bbox, obj_bbox, union_bbox), dim=0)
127 | aug_input = T.AugInput(im, boxes=boxes)
128 | transforms = self.augmentations(aug_input)
129 | im, boxes = aug_input.image, torch.as_tensor(aug_input.boxes)
130 | im = im.astype(np.float32)
131 |
132 | def to_tensor(im, pix_mean, pix_std, normalize=True):
133 | if normalize:
134 | for i in range(3):
135 | im[:, :, i] = (im[:, :, i] / 255. - pix_mean[i]) / pix_std[i]
136 | im = torch.as_tensor(np.ascontiguousarray(im.transpose(2, 0, 1))).float()
137 | return im
138 |
139 | second_im = im.copy()
140 | third_im = im.copy()
141 | # Augment&tensorize the main view
142 | if self.photo_aug is not None:
143 | # color jittering of the input image
144 | im = np.array(self.photo_aug(Image.fromarray(im.astype(np.uint8))), dtype=np.float32)
145 | im = to_tensor(im, self.pix_mean, self.pix_std)
146 |
147 | # Augment&tensorize the second view
148 | second_im = np.array(self.densecl_aug(Image.fromarray(second_im.astype(np.uint8))), dtype=np.float32)
149 | second_im = to_tensor(second_im, self.pix_mean, self.pix_std)
150 |
151 | # Augment&tensorize the third view
152 | third_im = np.array(self.densecl_aug(Image.fromarray(third_im.astype(np.uint8))), dtype=np.float32)
153 | third_im = to_tensor(third_im, self.pix_mean, self.pix_std)
154 |
155 | hoi = label_to_onehot(torch.Tensor(self._anno[ind]['hoi']).long(), self.num_interaction_cls)
156 | verb = label_to_onehot(torch.Tensor(self._anno[ind]['verb']).long(), self.num_action_cls)
157 | object = label_to_onehot(torch.Tensor(self._anno[ind]['object']).long(), self.num_object_cls)
158 | return im, second_im, third_im, boxes, hoi, verb, object
159 |
160 | def _load_annotation_and_metadata(self, f: dict) -> None:
161 | """
162 | Arguments:
163 | f(dict): Dictionary loaded from {anno_file}.json
164 | """
165 | idx = list(range(len(f['filenames'])))
166 | for empty_idx in f['empty']:
167 | idx.remove(empty_idx)
168 |
169 | num_anno = [0 for _ in range(self.num_interaction_cls)]
170 | for anno in f['annotation']:
171 | for hoi in anno['hoi']:
172 | num_anno[hoi] += 1
173 |
174 | self._idx = idx
175 | self._num_anno = num_anno
176 |
177 | self._anno = f['annotation']
178 | self._filenames = f['filenames']
179 | self._image_sizes = f['size']
180 | self._class_corr = f['correspondence']
181 | self._empty_idx = f['empty']
182 | self._objects = f['objects']
183 | self._verbs = f['verbs']
184 |
185 | def _ood_split(self):
186 | pass
187 |
188 | def compute_map_hico(y_true, y_score, easy=False, hard=False, rare_only=False):
189 | unseen_hoi_nonrare = np.array([38, 41, 20, 18, 245, 11, 19, 154, 459, 42, 155, 139, 60, 461, 577, 153, 582, 89, 141, 576, 75, 212, 472, 61,
190 | 457, 146, 208, 94, 471, 131, 248, 544, 515, 566, 370, 481, 226, 250, 470, 323, 169, 480, 479, 230, 385, 73,
191 | 159, 190, 377, 176, 249, 371, 284, 48, 583, 53, 162, 140, 185, 106, 294, 56, 320, 152, 374, 338, 29, 594, 346,
192 | 456, 589, 45, 23, 67, 478, 223, 493, 228, 240, 215, 91, 115, 337, 559, 7, 218, 518, 297, 191, 266, 304, 6, 572,
193 | 529, 312, 9, 308, 417, 197, 193, 163, 455, 25, 54, 575, 446, 387, 483, 534, 340, 508, 110, 329, 246, 173, 506,
194 | 383, 93, 516, 64])
195 |
196 | unseen_hoi_rare = np.array([509, 279, 280, 402, 504, 286, 499, 498, 289, 485, 303, 311, 325, 439, 351, 358, 66, 427, 379, 418, 70, 416,
197 | 389, 90, 395, 76, 397, 84, 135, 262, 401, 592, 560, 586, 548, 593, 526, 181, 257, 539, 535, 260, 596, 345, 189,
198 | 205, 206, 429, 179, 350, 405, 522, 449, 261, 255, 546, 547, 44, 22, 334, 599, 239, 315, 317, 229, 158, 195,
199 | 238, 364, 222, 281, 149, 399, 83, 127, 254, 398, 403, 555, 552, 520, 531, 440, 436, 482, 274, 8, 188, 216, 597,
200 | 77, 407, 556, 469, 474, 107, 390, 410, 27, 381, 463, 99, 184, 100, 292, 517, 80, 333, 62, 354, 104, 55, 50,
201 | 198, 168, 391, 192, 595, 136, 581])
202 | rare_hoi = np.load('./cache/rare.npy')
203 | if easy:
204 | index = unseen_hoi_rare
205 | if hard:
206 | index = unseen_hoi_nonrare
207 | if rare_only:
208 | index = rare_hoi
209 | if easy or hard or rare_only:
210 | y_true = y_true[:, index]
211 | y_score = y_score[:, index]
212 | meter = utils.AveragePrecisionMeter(algorithm='AUC', output=torch.Tensor(y_score), labels=torch.Tensor(y_true))
213 | return meter.eval().mean()
214 |
215 |
216 | def collate_images_boxes_dict(batch):
217 | def _pad_tensor(tensor_list):
218 | max_imh, max_imw = -1, -1
219 | for tensor_i in tensor_list:
220 | # import pdb; pdb.set_trace()
221 | imh, imw = tensor_i.shape[-2], tensor_i.shape[-1]
222 | max_imh = max(max_imh, imh)
223 | max_imw = max(max_imw, imw)
224 |
225 | for idx, tensor_i in enumerate(tensor_list):
226 | pad_tensor_i = tensor_i.new_full(list(tensor_i.shape[:-2]) + [max_imh, max_imw], 0)
227 | imh, imw = tensor_i.shape[-2], tensor_i.shape[-1]
228 | pad_tensor_i[..., :imh, :imw].copy_(tensor_i)
229 | tensor_list[idx] = pad_tensor_i
230 | return tensor_list
231 | list_ims = []
232 | list_second_ims = []
233 | list_third_ims = []
234 | list_boxes = []
235 | list_hois = []
236 | list_verbs = []
237 | list_objects = []
238 | for b in batch:
239 | list_ims.append(b[0])
240 | list_second_ims.append(b[1])
241 | list_third_ims.append(b[2])
242 | list_boxes.append(b[3])
243 | list_hois.append(b[4])
244 | list_verbs.append(b[5])
245 | list_objects.append(b[6])
246 |
247 | list_ims = _pad_tensor(list_ims)
248 | list_second_ims = _pad_tensor(list_second_ims)
249 | list_third_ims = _pad_tensor(list_third_ims)
250 | return torch.stack(list_ims), torch.stack(list_second_ims), torch.stack(list_third_ims), torch.stack(list_boxes), torch.stack(list_hois), torch.stack(list_verbs), torch.stack(list_objects)
251 |
--------------------------------------------------------------------------------
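`collate_images_boxes_dict` above zero-pads every image view to the largest spatial size in the batch before stacking. A self-contained sketch with dummy tensors, assuming the repo's dependencies (e.g. detectron2) are installed so the module imports; shapes follow `HICODet.__getitem__` (three views, a [3, 1, 4] box stack, and 600/117/80-way multi-hot targets):

```python
import torch
from datasets.image_hicodet_bbox import collate_images_boxes_dict

def fake_item(h, w):
    im = torch.rand(3, h, w)
    boxes = torch.rand(3, 1, 4)   # subject / object / union box
    hoi, verb, obj = torch.zeros(600), torch.zeros(117), torch.zeros(80)
    return im, im.clone(), im.clone(), boxes, hoi, verb, obj

batch = [fake_item(256, 320), fake_item(288, 256)]
ims, _, _, boxes, hois, verbs, objects = collate_images_boxes_dict(batch)
print(ims.shape)    # torch.Size([2, 3, 288, 320]) -- padded to the max H/W in the batch
print(boxes.shape)  # torch.Size([2, 3, 1, 4])
```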
/models/__init__.py:
--------------------------------------------------------------------------------
1 | # ----------------------------------------------------------------------
2 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # This work is licensed under the NVIDIA Source Code License
5 | # for RelViT. To view a copy of this license, see the LICENSE file.
6 | # ----------------------------------------------------------------------
7 |
8 | from .models import make, load
9 | from . import classifier
10 | from . import transparent_encoder
11 | from . import pvt_v2
12 | from . import swin_transformer
13 | from . import vit
14 | from . import mcan
--------------------------------------------------------------------------------
/models/classifier.py:
--------------------------------------------------------------------------------
1 | # ----------------------------------------------------------------------
2 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # This work is licensed under the NVIDIA Source Code License
5 | # for RelViT. To view a copy of this license, see the LICENSE file.
6 | # ----------------------------------------------------------------------
7 |
8 | import math
9 |
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 |
14 | import models
15 | import utils
16 | from .models import register
17 |
18 |
19 | @register('classifier')
20 | class Classifier(nn.Module):
21 |
22 | def __init__(self, encoder, encoder_args,
23 | classifier, classifier_args):
24 | super().__init__()
25 | self.encoder = models.make(encoder, **encoder_args)
26 | classifier_args['in_dim'] = self.encoder.out_dim
27 | self.classifier = models.make(classifier, **classifier_args)
28 |
29 | def forward(self, x, boxes=None, info_nce=False):
30 | if info_nce:
31 | x, attn_v = self.encoder(x, boxes, info_nce=info_nce)
32 | x = self.classifier(x)
33 | return x, attn_v
34 | else:
35 | x = self.encoder(x, boxes)
36 | x = self.classifier(x)
37 | return x
38 |
39 |
40 | @register('linear-classifier')
41 | class LinearClassifier(nn.Module):
42 |
43 | def __init__(self, in_dim, n_classes):
44 | super().__init__()
45 | self.linear = nn.Linear(in_dim, n_classes)
46 |
47 | def forward(self, x):
48 | return self.linear(x)
49 |
50 | # Assume input is [B, C, H, W]
51 | @register('max-pooling-classifier')
52 | class MaxPoolingClassifier(nn.Module):
53 |
54 | def __init__(self, in_dim, n_classes, **kwargs):
55 | super().__init__()
56 | self.proj = nn.Sequential(
57 | nn.LayerNorm(in_dim),
58 | nn.Linear(in_dim, n_classes)
59 | )
60 |
61 | def forward(self, x):
62 | # x: [B, C, W, H]
63 | B, C = x.size(0), x.size(1)
64 | x = x.reshape(B, C, -1).max(-1)[0]
65 | # x = x.reshape(B, C, -1)[:, :, 0]
66 | return self.proj(x)
67 |
68 |
69 | # Assume input is [B, C, H, W]
70 | @register('max-pooling-classifier-twoheads')
71 | class MaxPoolingClassifierTwoHeads(nn.Module):
72 |
73 | def __init__(self, in_dim, n_classes1, n_classes2, **kwargs):
74 | super().__init__()
75 | self.proj1 = nn.Sequential(
76 | nn.LayerNorm(in_dim),
77 | nn.Linear(in_dim, n_classes1)
78 | )
79 | self.proj2 = nn.Sequential(
80 | nn.LayerNorm(in_dim),
81 | nn.Linear(in_dim, n_classes2)
82 | )
83 |
84 | def forward(self, x):
85 | # x: [B, C, W, H]
86 | B, C = x.size(0), x.size(1)
87 | x = x.reshape(B, C, -1).max(-1)[0]
88 | # x = x.reshape(B, C, -1)[:, :, 0]
89 | return (self.proj1(x), self.proj2(x))
90 |
91 |
92 | @register('nn-classifier')
93 | class NNClassifier(nn.Module):
94 |
95 | def __init__(self, in_dim, n_classes, metric='cos', temp=None):
96 | super().__init__()
97 | self.proto = nn.Parameter(torch.empty(n_classes, in_dim))
98 | nn.init.kaiming_uniform_(self.proto, a=math.sqrt(5))
99 | if temp is None:
100 | if metric == 'cos':
101 | temp = nn.Parameter(torch.tensor(10.))
102 | else:
103 | temp = 1.0
104 | self.metric = metric
105 | self.temp = temp
106 |
107 | def forward(self, x):
108 | return utils.compute_logits(x, self.proto, self.metric, self.temp)
109 |
--------------------------------------------------------------------------------
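A quick shape check for the two-headed classifier above (a sketch; `n_classes1`/`n_classes2` follow `classifier_args` in `configs/train_hico.yaml`, `in_dim=1024` is assumed to match the `swin_base` feature dimension, and the repo's model dependencies are assumed to be installed):

```python
import torch
from models.classifier import MaxPoolingClassifierTwoHeads

head = MaxPoolingClassifierTwoHeads(in_dim=1024, n_classes1=117, n_classes2=80)
feat = torch.rand(2, 1024, 7, 7)            # [B, C, H, W] feature map from the encoder
verb_logits, obj_logits = head(feat)        # max-pool over H*W, then two linear heads
print(verb_logits.shape, obj_logits.shape)  # torch.Size([2, 117]) torch.Size([2, 80])
```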
/models/mcan.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------
2 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # This file has been modified from a file in the mcan-vqa library
5 | # which was released under the Apache Licence.
6 | #
7 | # Source:
8 | # https://github.com/MILVLG/mcan-vqa
9 | #
10 | # The license for the original version of this file can be
11 | # found in https://github.com/MILVLG/mcan-vqa/blob/master/LICENSE
12 | # The modifications to this file are subject to the same Apache Licence.
13 | # ---------------------------------------------------------------
14 |
15 | import math
16 |
17 | import numpy as np
18 | import torch
19 | import torch.nn as nn
20 | import torch.nn.functional as F
21 |
22 | import models
23 | import utils
24 | from .models import register
25 |
26 |
27 | def make_mask(feature):
28 | return (torch.sum(
29 | torch.abs(feature),
30 | dim=-1
31 | ) == 0).unsqueeze(1).unsqueeze(2)
32 |
33 |
34 | class FC(nn.Module):
35 | def __init__(self, in_size, out_size, dropout_r=0., use_relu=True):
36 | super(FC, self).__init__()
37 | self.dropout_r = dropout_r
38 | self.use_relu = use_relu
39 |
40 | self.linear = nn.Linear(in_size, out_size)
41 |
42 | if use_relu:
43 | self.relu = nn.ReLU(inplace=True)
44 |
45 | if dropout_r > 0:
46 | self.dropout = nn.Dropout(dropout_r)
47 |
48 | def forward(self, x):
49 | x = self.linear(x)
50 |
51 | if self.use_relu:
52 | x = self.relu(x)
53 |
54 | if self.dropout_r > 0:
55 | x = self.dropout(x)
56 |
57 | return x
58 |
59 |
60 | class MLP(nn.Module):
61 | def __init__(self, in_size, mid_size, out_size, dropout_r=0., use_relu=True):
62 | super(MLP, self).__init__()
63 |
64 | self.fc = FC(in_size, mid_size, dropout_r=dropout_r, use_relu=use_relu)
65 | self.linear = nn.Linear(mid_size, out_size)
66 |
67 | def forward(self, x):
68 | return self.linear(self.fc(x))
69 |
70 |
71 | class LayerNorm(nn.Module):
72 | def __init__(self, size, eps=1e-6):
73 | super(LayerNorm, self).__init__()
74 | self.eps = eps
75 |
76 | self.a_2 = nn.Parameter(torch.ones(size))
77 | self.b_2 = nn.Parameter(torch.zeros(size))
78 |
79 | def forward(self, x):
80 | mean = x.mean(-1, keepdim=True)
81 | std = x.std(-1, keepdim=True)
82 |
83 | return self.a_2 * (x - mean) / (std + self.eps) + self.b_2
84 |
85 |
86 | # ------------------------------
87 | # ---- Multi-Head Attention ----
88 | # ------------------------------
89 |
90 | class MHAtt(nn.Module):
91 | def __init__(self):
92 | super(MHAtt, self).__init__()
93 |
94 | self.linear_v = nn.Linear(utils.MCAN_GQA_PARAMS['HIDDEN_SIZE'], utils.MCAN_GQA_PARAMS['HIDDEN_SIZE'])
95 | self.linear_k = nn.Linear(utils.MCAN_GQA_PARAMS['HIDDEN_SIZE'], utils.MCAN_GQA_PARAMS['HIDDEN_SIZE'])
96 | self.linear_q = nn.Linear(utils.MCAN_GQA_PARAMS['HIDDEN_SIZE'], utils.MCAN_GQA_PARAMS['HIDDEN_SIZE'])
97 | self.linear_merge = nn.Linear(utils.MCAN_GQA_PARAMS['HIDDEN_SIZE'], utils.MCAN_GQA_PARAMS['HIDDEN_SIZE'])
98 |
99 | self.dropout = nn.Dropout(utils.MCAN_GQA_PARAMS['DROPOUT_R'])
100 |
101 | def forward(self, v, k, q, mask):
102 | n_batches = q.size(0)
103 |
104 | v = self.linear_v(v).view(
105 | n_batches,
106 | -1,
107 | utils.MCAN_GQA_PARAMS['MULTI_HEAD'],
108 | int(utils.MCAN_GQA_PARAMS['HIDDEN_SIZE'] / utils.MCAN_GQA_PARAMS['MULTI_HEAD'])
109 | ).transpose(1, 2)
110 |
111 | k = self.linear_k(k).view(
112 | n_batches,
113 | -1,
114 | utils.MCAN_GQA_PARAMS['MULTI_HEAD'],
115 | int(utils.MCAN_GQA_PARAMS['HIDDEN_SIZE'] / utils.MCAN_GQA_PARAMS['MULTI_HEAD'])
116 | ).transpose(1, 2)
117 |
118 | q = self.linear_q(q).view(
119 | n_batches,
120 | -1,
121 | utils.MCAN_GQA_PARAMS['MULTI_HEAD'],
122 | int(utils.MCAN_GQA_PARAMS['HIDDEN_SIZE'] / utils.MCAN_GQA_PARAMS['MULTI_HEAD'])
123 | ).transpose(1, 2)
124 |
125 | atted = self.att(v, k, q, mask)
126 | atted = atted.transpose(1, 2).contiguous().view(
127 | n_batches,
128 | -1,
129 | utils.MCAN_GQA_PARAMS['HIDDEN_SIZE']
130 | )
131 |
132 | atted = self.linear_merge(atted)
133 |
134 | return atted
135 |
136 | def att(self, value, key, query, mask):
137 | d_k = query.size(-1)
138 |
139 | scores = torch.matmul(
140 | query, key.transpose(-2, -1)
141 | ) / math.sqrt(d_k)
142 |
143 | if mask is not None:
144 | scores = scores.masked_fill(mask, -1e9)
145 |
146 | att_map = F.softmax(scores, dim=-1)
147 | att_map = self.dropout(att_map)
148 |
149 | return torch.matmul(att_map, value)
150 |
151 |
152 | # ---------------------------
153 | # ---- Feed Forward Nets ----
154 | # ---------------------------
155 |
156 | class FFN(nn.Module):
157 | def __init__(self):
158 | super(FFN, self).__init__()
159 |
160 | self.mlp = MLP(
161 | in_size=utils.MCAN_GQA_PARAMS['HIDDEN_SIZE'],
162 | mid_size=utils.MCAN_GQA_PARAMS['FF_SIZE'],
163 | out_size=utils.MCAN_GQA_PARAMS['HIDDEN_SIZE'],
164 | dropout_r=utils.MCAN_GQA_PARAMS['DROPOUT_R'],
165 | use_relu=True
166 | )
167 |
168 | def forward(self, x):
169 | return self.mlp(x)
170 |
171 |
172 | # ------------------------
173 | # ---- Self Attention ----
174 | # ------------------------
175 |
176 | class SA(nn.Module):
177 | def __init__(self):
178 | super(SA, self).__init__()
179 |
180 | self.mhatt = MHAtt()
181 | self.ffn = FFN()
182 |
183 | self.dropout1 = nn.Dropout(utils.MCAN_GQA_PARAMS['DROPOUT_R'])
184 | self.norm1 = LayerNorm(utils.MCAN_GQA_PARAMS['HIDDEN_SIZE'])
185 |
186 | self.dropout2 = nn.Dropout(utils.MCAN_GQA_PARAMS['DROPOUT_R'])
187 | self.norm2 = LayerNorm(utils.MCAN_GQA_PARAMS['HIDDEN_SIZE'])
188 |
189 | def forward(self, y, y_mask):
190 | y = self.norm1(y + self.dropout1(
191 | self.mhatt(y, y, y, y_mask)
192 | ))
193 |
194 | y = self.norm2(y + self.dropout2(
195 | self.ffn(y)
196 | ))
197 |
198 | return y
199 |
200 |
201 | # -------------------------------
202 | # ---- Self Guided Attention ----
203 | # -------------------------------
204 |
205 | class SGA(nn.Module):
206 | def __init__(self):
207 | super(SGA, self).__init__()
208 |
209 | self.mhatt1 = MHAtt()
210 | self.mhatt2 = MHAtt()
211 | self.ffn = FFN()
212 |
213 | self.dropout1 = nn.Dropout(utils.MCAN_GQA_PARAMS['DROPOUT_R'])
214 | self.norm1 = LayerNorm(utils.MCAN_GQA_PARAMS['HIDDEN_SIZE'])
215 |
216 | self.dropout2 = nn.Dropout(utils.MCAN_GQA_PARAMS['DROPOUT_R'])
217 | self.norm2 = LayerNorm(utils.MCAN_GQA_PARAMS['HIDDEN_SIZE'])
218 |
219 | self.dropout3 = nn.Dropout(utils.MCAN_GQA_PARAMS['DROPOUT_R'])
220 | self.norm3 = LayerNorm(utils.MCAN_GQA_PARAMS['HIDDEN_SIZE'])
221 |
222 | def forward(self, x, y, x_mask, y_mask):
223 | x = self.norm1(x + self.dropout1(
224 | self.mhatt1(v=x, k=x, q=x, mask=x_mask)
225 | ))
226 |
227 | x = self.norm2(x + self.dropout2(
228 | self.mhatt2(v=y, k=y, q=x, mask=y_mask)
229 | ))
230 |
231 | x = self.norm3(x + self.dropout3(
232 | self.ffn(x)
233 | ))
234 |
235 | return x
236 |
237 |
238 | # ------------------------------------------------
239 | # ---- MCA Layers Cascaded by Encoder-Decoder ----
240 | # ------------------------------------------------
241 |
242 | class MCA_ED(nn.Module):
243 | def __init__(self):
244 | super(MCA_ED, self).__init__()
245 |
246 | self.enc_list = nn.ModuleList([SA() for _ in range(utils.MCAN_GQA_PARAMS['LAYER'])])
247 | self.dec_list = nn.ModuleList([SGA() for _ in range(utils.MCAN_GQA_PARAMS['LAYER'])])
248 |
249 | def forward(self, y, x, y_mask, x_mask):
250 | # Get encoder last hidden vector
251 | for enc in self.enc_list:
252 | y = enc(y, y_mask)
253 |
254 | # Input encoder last hidden vector
255 | # And obtain decoder last hidden vectors
256 | for dec in self.dec_list:
257 | x = dec(x, y, x_mask, y_mask)
258 |
259 | return y, x
260 |
261 |
262 | def feat_filter(frcn_feat, grid_feat, bbox_feat):
263 | feat_dict = {}
264 |
265 | feat_dict['FRCN_FEAT'] = frcn_feat
266 | feat_dict['GRID_FEAT'] = grid_feat
267 | feat_dict['BBOX_FEAT'] = bbox_feat
268 |
269 | return feat_dict
270 |
271 |
272 | class BaseAdapter(nn.Module):
273 | def __init__(self):
274 | super(BaseAdapter, self).__init__()
275 | self.gqa_init()
276 |
277 | def gqa_init(self):
278 | raise NotImplementedError()
279 |
280 | def forward(self, frcn_feat, grid_feat, bbox_feat):
281 | feat_dict = feat_filter(frcn_feat, grid_feat, bbox_feat)
282 |
283 | return self.gqa_forward(feat_dict)
284 |
285 | def gqa_forward(self, feat_dict):
286 | raise NotImplementedError()
287 |
288 |
289 | class Adapter(BaseAdapter):
290 | def __init__(self):
291 | super(Adapter, self).__init__()
292 |
293 | def bbox_proc(self, bbox):
294 | area = (bbox[:, :, 2] - bbox[:, :, 0]) * (bbox[:, :, 3] - bbox[:, :, 1])
295 | # return torch.cat((bbox, area), -1)
296 | ##### FIXME: possibly buggy
297 | return torch.cat((bbox, area.unsqueeze(2)), -1)
298 | #####
299 |
300 | def gqa_init(self):
301 | imgfeat_linear_size = utils.MCAN_GQA_PARAMS['FRCN_FEAT_SIZE'][1]
302 | if utils.MCAN_GQA_PARAMS['USE_BBOX_FEAT']:
303 | self.bbox_linear = nn.Linear(5, utils.MCAN_GQA_PARAMS['BBOXFEAT_EMB_SIZE'])
304 | imgfeat_linear_size += utils.MCAN_GQA_PARAMS['BBOXFEAT_EMB_SIZE']
305 | self.frcn_linear = nn.Linear(imgfeat_linear_size, utils.MCAN_GQA_PARAMS['HIDDEN_SIZE'])
306 |
307 | if utils.MCAN_GQA_PARAMS['USE_AUX_FEAT']:
308 | self.grid_linear = nn.Linear(utils.MCAN_GQA_PARAMS['GRID_FEAT_SIZE'][1], utils.MCAN_GQA_PARAMS['HIDDEN_SIZE'])
309 |
310 |
311 | def gqa_forward(self, feat_dict):
312 | frcn_feat = feat_dict['FRCN_FEAT']
313 | bbox_feat = feat_dict['BBOX_FEAT']
314 | grid_feat = feat_dict['GRID_FEAT']
315 |
316 | img_feat_mask = make_mask(frcn_feat)
317 |
318 | if utils.MCAN_GQA_PARAMS['USE_BBOX_FEAT']:
319 | ##### FIXME: possibly buggy
320 | # bbox_feat = self.bbox_proc(bbox_feat)
321 | #####
322 | bbox_feat = self.bbox_linear(bbox_feat)
323 | frcn_feat = torch.cat((frcn_feat, bbox_feat), dim=-1)
324 | img_feat = self.frcn_linear(frcn_feat)
325 |
326 | if utils.MCAN_GQA_PARAMS['USE_AUX_FEAT']:
327 | grid_feat_mask = make_mask(grid_feat)
328 | img_feat_mask = torch.cat((img_feat_mask, grid_feat_mask), dim=-1)
329 | grid_feat = self.grid_linear(grid_feat)
330 | img_feat = torch.cat((img_feat, grid_feat), dim=1)
331 |
332 | return img_feat, img_feat_mask
333 |
334 |
335 | # ------------------------------
336 | # ---- Flatten the sequence ----
337 | # ------------------------------
338 |
339 | class AttFlat(nn.Module):
340 | def __init__(self):
341 | super(AttFlat, self).__init__()
342 |
343 | self.mlp = MLP(
344 | in_size=utils.MCAN_GQA_PARAMS['HIDDEN_SIZE'],
345 | mid_size=utils.MCAN_GQA_PARAMS['FLAT_MLP_SIZE'],
346 | out_size=utils.MCAN_GQA_PARAMS['FLAT_GLIMPSES'],
347 | dropout_r=utils.MCAN_GQA_PARAMS['DROPOUT_R'],
348 | use_relu=True
349 | )
350 |
351 | self.linear_merge = nn.Linear(
352 | utils.MCAN_GQA_PARAMS['HIDDEN_SIZE'] * utils.MCAN_GQA_PARAMS['FLAT_GLIMPSES'],
353 | utils.MCAN_GQA_PARAMS['FLAT_OUT_SIZE']
354 | )
355 |
356 | def forward(self, x, x_mask):
357 | att = self.mlp(x)
358 | att = att.masked_fill(
359 | x_mask.squeeze(1).squeeze(1).unsqueeze(2),
360 | -1e9
361 | )
362 | att = F.softmax(att, dim=1)
363 |
364 | att_list = []
365 | for i in range(utils.MCAN_GQA_PARAMS['FLAT_GLIMPSES']):
366 | att_list.append(
367 | torch.sum(att[:, :, i: i + 1] * x, dim=1)
368 | )
369 |
370 | x_atted = torch.cat(att_list, dim=1)
371 | x_atted = self.linear_merge(x_atted)
372 |
373 | return x_atted
374 |
375 |
376 | @register('mcan')
377 | class MCAN(nn.Module):
378 | def __init__(self, **kwargs):
379 | super(MCAN, self).__init__()
380 |
381 | self.embedding = nn.Embedding(
382 | num_embeddings=utils.MCAN_GQA_PARAMS['TOKEN_SIZE'],
383 | embedding_dim=utils.MCAN_GQA_PARAMS['WORD_EMBED_SIZE']
384 | )
385 |
386 | # Loading the GloVe embedding weights
387 | self.embedding.weight.data.copy_(torch.from_numpy(np.load(kwargs.get('word_emb_path'))))
388 |
389 | self.lstm = nn.LSTM(
390 | input_size=utils.MCAN_GQA_PARAMS['WORD_EMBED_SIZE'],
391 | hidden_size=utils.MCAN_GQA_PARAMS['HIDDEN_SIZE'],
392 | num_layers=1,
393 | batch_first=True
394 | )
395 |
396 | self.adapter = Adapter()
397 |
398 | self.backbone = MCA_ED()
399 |
400 | # Flatten to vector
401 | self.attflat_img = AttFlat()
402 | self.attflat_lang = AttFlat()
403 |
404 | # Classification layers
405 | self.proj_norm = LayerNorm(utils.MCAN_GQA_PARAMS['FLAT_OUT_SIZE'])
406 | self.proj = nn.Linear(utils.MCAN_GQA_PARAMS['FLAT_OUT_SIZE'], utils.MCAN_GQA_PARAMS['ANSWER_SIZE'])
407 |
408 |
409 | def forward(self, frcn_feat, grid_feat, bbox_feat, ques_ix, info_nce=False, pretrain=False):
410 |
411 | # Pre-process Language Feature
412 | lang_feat_mask = make_mask(ques_ix.unsqueeze(2))
413 | lang_feat = self.embedding(ques_ix)
414 | lang_feat, _ = self.lstm(lang_feat)
415 |
416 | img_feat, img_feat_mask = self.adapter(frcn_feat, grid_feat, bbox_feat)
417 |
418 | # Backbone Framework
419 | # lang_feat: [B, MAX_TOKEN_LENGTH, HIDDEN_SIZE]
420 | # img_feat: [B, FRCN_FEAT_SIZE[0]+GRID_FEAT_SIZE[0], HIDDEN_SIZE]
421 | lang_feat, img_feat = self.backbone(
422 | lang_feat,
423 | img_feat,
424 | lang_feat_mask,
425 | img_feat_mask
426 | )
427 |
428 | # Flatten to vector
429 | lang_feat = self.attflat_lang(
430 | lang_feat,
431 | lang_feat_mask
432 | )
433 |
434 | img_feat = self.attflat_img(
435 | img_feat,
436 | img_feat_mask
437 | )
438 |
439 | # Classification layers
440 | proj_feat = lang_feat + img_feat
441 | proj_feat = self.proj_norm(proj_feat)
442 | proj_feat = self.proj(proj_feat)
443 |
444 | return proj_feat
445 |
446 |
447 | @register('mcan-customized')
448 | class MCANCustomized(nn.Module):
449 | def __init__(self, **kwargs):
450 | super(MCANCustomized, self).__init__()
451 |
452 | self.encoder = models.make(kwargs.get('encoder'), **kwargs.get('encoder_args'))
453 | self.connector = nn.Linear(self.encoder.out_dim, utils.MCAN_GQA_PARAMS['HIDDEN_SIZE'])
454 |
455 | self.embedding = nn.Embedding(
456 | num_embeddings=utils.MCAN_GQA_PARAMS['TOKEN_SIZE'],
457 | embedding_dim=utils.MCAN_GQA_PARAMS['WORD_EMBED_SIZE']
458 | )
459 |
460 | # Loading the GloVe embedding weights
461 | self.embedding.weight.data.copy_(torch.from_numpy(np.load(kwargs.get('word_emb_path'))))
462 |
463 | self.lstm = nn.LSTM(
464 | input_size=utils.MCAN_GQA_PARAMS['WORD_EMBED_SIZE'],
465 | hidden_size=utils.MCAN_GQA_PARAMS['HIDDEN_SIZE'],
466 | num_layers=1,
467 | batch_first=True
468 | )
469 |
470 | self.backbone = MCA_ED()
471 |
472 | # Flatten to vector
473 | self.attflat_img = AttFlat()
474 | self.attflat_lang = AttFlat()
475 |
476 | # Classification layers
477 | self.proj_norm = LayerNorm(utils.MCAN_GQA_PARAMS['FLAT_OUT_SIZE'])
478 | self.proj = nn.Linear(utils.MCAN_GQA_PARAMS['FLAT_OUT_SIZE'], utils.MCAN_GQA_PARAMS['ANSWER_SIZE'])
479 |
480 |
481 | def forward(self, ims, ques_ix, info_nce=False, pretrain=False):
482 | if pretrain:
483 | B = ims.size(0)
484 | logits = torch.zeros(B, utils.MCAN_GQA_PARAMS['ANSWER_SIZE']).to(ims)
485 | if info_nce:
486 | img_feat, attn_v = self.encoder(ims, info_nce=info_nce)
487 | return logits, attn_v
488 | else:
489 | return logits
490 | else:
491 | if info_nce:
492 | img_feat, attn_v = self.encoder(ims, info_nce=info_nce)
493 | else:
494 | img_feat = self.encoder(ims)
495 | B, C = img_feat.size(0), img_feat.size(1)
496 | img_feat = self.connector(img_feat.reshape(B, C, -1).permute(0, 2, 1))
497 | img_feat_mask = make_mask(img_feat)
498 |
499 | # Pre-process Language Feature
500 | lang_feat_mask = make_mask(ques_ix.unsqueeze(2))
501 | lang_feat = self.embedding(ques_ix)
502 | self.lstm.flatten_parameters()
503 | lang_feat, _ = self.lstm(lang_feat)
504 |
505 | # Backbone Framework
506 | lang_feat, img_feat = self.backbone(
507 | lang_feat,
508 | img_feat,
509 | lang_feat_mask,
510 | img_feat_mask
511 | )
512 |
513 | # Flatten to vector
514 | lang_feat = self.attflat_lang(
515 | lang_feat,
516 | lang_feat_mask
517 | )
518 |
519 | img_feat = self.attflat_img(
520 | img_feat,
521 | img_feat_mask
522 | )
523 |
524 | # Classification layers
525 | proj_feat = lang_feat + img_feat
526 | proj_feat = self.proj_norm(proj_feat)
527 | proj_feat = self.proj(proj_feat)
528 |
529 | if info_nce:
530 | return proj_feat, attn_v
531 | else:
532 | return proj_feat
533 |
--------------------------------------------------------------------------------
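The masking convention in mcan.py is easy to miss: make_mask flags token positions whose feature vector is all zeros (i.e. padding), and MHAtt.att fills those key positions with -1e9 before the softmax so they receive essentially no attention. A self-contained sketch of that behaviour with illustrative shapes:

import math
import torch
import torch.nn.functional as F

def make_mask(feature):
    # True where the feature row is all zeros (a padded token)
    return (feature.abs().sum(dim=-1) == 0).unsqueeze(1).unsqueeze(2)

B, N, C = 2, 5, 8
feat = torch.randn(B, N, C)
feat[:, -2:] = 0                         # pretend the last two tokens are padding
mask = make_mask(feat)                   # (B, 1, 1, N)

q = k = v = feat.unsqueeze(1)            # (B, 1 head, N, C)
scores = (q @ k.transpose(-2, -1)) / math.sqrt(C)
scores = scores.masked_fill(mask, -1e9)  # padded keys are suppressed for every query
attn = F.softmax(scores, dim=-1)
print((attn @ v).shape)                  # torch.Size([2, 1, 5, 8])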
/models/models.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------
2 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # This file has been modified from a file in the Bongard-HOI library
5 | # which was released under the NVIDIA Source Code Licence.
6 | #
7 | # Source:
8 | # https://github.com/NVlabs/Bongard-HOI
9 | #
10 | # The license for the original version of this file can be
11 | # found in https://github.com/NVlabs/Bongard-HOI/blob/master/LICENSE
12 | # The modifications to this file are subject to the same NVIDIA Source Code Licence.
13 | # ---------------------------------------------------------------
14 |
15 | import torch
16 |
17 |
18 | models = {}
19 | def register(name):
20 | def decorator(cls):
21 | models[name] = cls
22 | return cls
23 | return decorator
24 |
25 |
26 | def make(name, **kwargs):
27 | if name is None:
28 | return None
29 | model = models[name](**kwargs)
30 | # if torch.cuda.is_available():
31 | # model = model.cuda()
32 | return model
33 |
34 |
35 | def load(model_sv, name=None):
36 | if name is None:
37 | name = 'model'
38 | model = make(model_sv[name], **model_sv[name + '_args'])
39 | missing_keys, unexpected_keys = model.load_state_dict(model_sv[name + '_sd'])
40 | return model
41 |
--------------------------------------------------------------------------------
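models.py implements a plain name-to-class registry: @register('name') stores a class under a string key, and make('name', **kwargs) instantiates it. A hedged usage sketch (the key 'toy-linear' and class ToyLinear are illustrative only; it assumes the repository root is on PYTHONPATH and its dependencies are installed):

import torch
import torch.nn as nn
from models.models import register, make

@register('toy-linear')
class ToyLinear(nn.Module):
    def __init__(self, in_dim=16, out_dim=4):
        super().__init__()
        self.fc = nn.Linear(in_dim, out_dim)

    def forward(self, x):
        return self.fc(x)

model = make('toy-linear', in_dim=16, out_dim=4)
print(model(torch.randn(2, 16)).shape)  # torch.Size([2, 4])

This is the same mechanism the configs rely on when they refer to an encoder or model by its registered key.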
/models/pvt_v2.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------
2 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # This file has been modified from a file in the PVT library
5 | # which was released under the Apache Licence.
6 | #
7 | # Source:
8 | # https://github.com/whai362/PVT
9 | #
10 | # The license for the original version of this file can be
11 | # found in https://github.com/whai362/PVT/blob/v2/LICENSE
12 | # The modifications to this file are subject to the same Apache Licence.
13 | # ---------------------------------------------------------------
14 |
15 | import torch
16 | import torch.nn as nn
17 | import math
18 | import warnings
19 | from functools import partial
20 | from itertools import repeat
21 | import collections.abc
22 |
23 | import utils
24 | from .models import register
25 |
26 | __all__ = [
27 |     'pvt_v2_b0', 'pvt_v2_b1', 'pvt_v2_b2', 'pvt_v2_b2_li', 'pvt_v2_b3', 'pvt_v2_b4', 'pvt_v2_b5'
28 | ]
29 |
30 |
31 | def drop_path(x, drop_prob: float = 0., training: bool = False):
32 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
33 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
34 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
35 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
36 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
37 | 'survival rate' as the argument.
38 | """
39 | if drop_prob == 0. or not training:
40 | return x
41 | keep_prob = 1 - drop_prob
42 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
43 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
44 | random_tensor.floor_() # binarize
45 | output = x.div(keep_prob) * random_tensor
46 | return output
47 |
48 |
49 | class DropPath(nn.Module):
50 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
51 | """
52 | def __init__(self, drop_prob=None):
53 | super(DropPath, self).__init__()
54 | self.drop_prob = drop_prob
55 |
56 | def forward(self, x):
57 | return drop_path(x, self.drop_prob, self.training)
58 |
59 |
60 | def _ntuple(n):
61 | def parse(x):
62 | if isinstance(x, collections.abc.Iterable):
63 | return x
64 | return tuple(repeat(x, n))
65 | return parse
66 |
67 |
68 | to_1tuple = _ntuple(1)
69 | to_2tuple = _ntuple(2)
70 | to_3tuple = _ntuple(3)
71 | to_4tuple = _ntuple(4)
72 | to_ntuple = _ntuple
73 |
74 |
75 | def _no_grad_trunc_normal_(tensor, mean, std, a, b):
76 | # Cut & paste from PyTorch official master until it's in a few official releases - RW
77 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
78 | def norm_cdf(x):
79 | # Computes standard normal cumulative distribution function
80 | return (1. + math.erf(x / math.sqrt(2.))) / 2.
81 |
82 | if (mean < a - 2 * std) or (mean > b + 2 * std):
83 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
84 | "The distribution of values may be incorrect.",
85 | stacklevel=2)
86 |
87 | with torch.no_grad():
88 | # Values are generated by using a truncated uniform distribution and
89 | # then using the inverse CDF for the normal distribution.
90 | # Get upper and lower cdf values
91 | l = norm_cdf((a - mean) / std)
92 | u = norm_cdf((b - mean) / std)
93 |
94 | # Uniformly fill tensor with values from [l, u], then translate to
95 | # [2l-1, 2u-1].
96 | tensor.uniform_(2 * l - 1, 2 * u - 1)
97 |
98 | # Use inverse cdf transform for normal distribution to get truncated
99 | # standard normal
100 | tensor.erfinv_()
101 |
102 | # Transform to proper mean, std
103 | tensor.mul_(std * math.sqrt(2.))
104 | tensor.add_(mean)
105 |
106 | # Clamp to ensure it's in the proper range
107 | tensor.clamp_(min=a, max=b)
108 | return tensor
109 |
110 |
111 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
112 | # type: (Tensor, float, float, float, float) -> Tensor
113 | r"""Fills the input Tensor with values drawn from a truncated
114 | normal distribution. The values are effectively drawn from the
115 | normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
116 | with values outside :math:`[a, b]` redrawn until they are within
117 | the bounds. The method used for generating the random values works
118 | best when :math:`a \leq \text{mean} \leq b`.
119 | Args:
120 | tensor: an n-dimensional `torch.Tensor`
121 | mean: the mean of the normal distribution
122 | std: the standard deviation of the normal distribution
123 | a: the minimum cutoff value
124 | b: the maximum cutoff value
125 | Examples:
126 | >>> w = torch.empty(3, 5)
127 | >>> nn.init.trunc_normal_(w)
128 | """
129 | return _no_grad_trunc_normal_(tensor, mean, std, a, b)
130 |
131 |
132 | class Mlp(nn.Module):
133 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., linear=False):
134 | super().__init__()
135 | out_features = out_features or in_features
136 | hidden_features = hidden_features or in_features
137 | self.fc1 = nn.Linear(in_features, hidden_features)
138 | self.dwconv = DWConv(hidden_features)
139 | self.act = act_layer()
140 | self.fc2 = nn.Linear(hidden_features, out_features)
141 | self.drop = nn.Dropout(drop)
142 | self.linear = linear
143 | if self.linear:
144 | self.relu = nn.ReLU(inplace=True)
145 | self.apply(self._init_weights)
146 |
147 | def _init_weights(self, m):
148 | if isinstance(m, nn.Linear):
149 | trunc_normal_(m.weight, std=.02)
150 | if isinstance(m, nn.Linear) and m.bias is not None:
151 | nn.init.constant_(m.bias, 0)
152 | elif isinstance(m, nn.LayerNorm):
153 | nn.init.constant_(m.bias, 0)
154 | nn.init.constant_(m.weight, 1.0)
155 | elif isinstance(m, nn.Conv2d):
156 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
157 | fan_out //= m.groups
158 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
159 | if m.bias is not None:
160 | m.bias.data.zero_()
161 |
162 | def forward(self, x, H, W):
163 | x = self.fc1(x)
164 | if self.linear:
165 | x = self.relu(x)
166 | x = self.dwconv(x, H, W).contiguous()
167 | x = self.act(x)
168 | x = self.drop(x)
169 | x = self.fc2(x)
170 | x = self.drop(x)
171 | return x
172 |
173 |
174 | class Attention(nn.Module):
175 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1, linear=False):
176 | super().__init__()
177 | assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
178 |
179 | self.dim = dim
180 | self.num_heads = num_heads
181 | head_dim = dim // num_heads
182 | self.scale = qk_scale or head_dim ** -0.5
183 |
184 | self.q = nn.Linear(dim, dim, bias=qkv_bias)
185 | self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
186 | self.attn_drop = nn.Dropout(attn_drop)
187 | self.proj = nn.Linear(dim, dim)
188 | self.proj_drop = nn.Dropout(proj_drop)
189 |
190 | self.linear = linear
191 | self.sr_ratio = sr_ratio
192 | if not linear:
193 | if sr_ratio > 1:
194 | self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
195 | self.norm = nn.LayerNorm(dim)
196 | else:
197 | self.pool = nn.AdaptiveAvgPool2d(7)
198 | self.sr = nn.Conv2d(dim, dim, kernel_size=1, stride=1)
199 | self.norm = nn.LayerNorm(dim)
200 | self.act = nn.GELU()
201 | self.apply(self._init_weights)
202 |
203 | def _init_weights(self, m):
204 | if isinstance(m, nn.Linear):
205 | trunc_normal_(m.weight, std=.02)
206 | if isinstance(m, nn.Linear) and m.bias is not None:
207 | nn.init.constant_(m.bias, 0)
208 | elif isinstance(m, nn.LayerNorm):
209 | nn.init.constant_(m.bias, 0)
210 | nn.init.constant_(m.weight, 1.0)
211 | elif isinstance(m, nn.Conv2d):
212 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
213 | fan_out //= m.groups
214 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
215 | if m.bias is not None:
216 | m.bias.data.zero_()
217 |
218 | def forward(self, x, H, W, info_nce=False):
219 | B, N, C = x.shape
220 | q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
221 |
222 | if not self.linear:
223 | if self.sr_ratio > 1:
224 | x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
225 | x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
226 | x_ = self.norm(x_)
227 | kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
228 | else:
229 | kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
230 | else:
231 | x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
232 | x_ = self.sr(self.pool(x_)).reshape(B, C, -1).permute(0, 2, 1)
233 | x_ = self.norm(x_)
234 | x_ = self.act(x_)
235 | kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
236 | k, v = kv[0], kv[1]
237 |
238 | attn = (q @ k.transpose(-2, -1)) * self.scale
239 | attn = attn.softmax(dim=-1)
240 | attn = self.attn_drop(attn)
241 | if info_nce:
242 | attn_ret = attn.detach()
243 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
244 | x = self.proj(x)
245 | x = self.proj_drop(x)
246 | if info_nce:
247 | return x, (attn_ret, v)
248 | else:
249 | return x
250 |
251 |
252 | class Block(nn.Module):
253 |
254 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
255 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1, linear=False):
256 | super().__init__()
257 | self.norm1 = norm_layer(dim)
258 | self.attn = Attention(
259 | dim,
260 | num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
261 | attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio, linear=linear)
262 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
263 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
264 | self.norm2 = norm_layer(dim)
265 | mlp_hidden_dim = int(dim * mlp_ratio)
266 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, linear=linear)
267 |
268 | self.apply(self._init_weights)
269 |
270 | def _init_weights(self, m):
271 | if isinstance(m, nn.Linear):
272 | trunc_normal_(m.weight, std=.02)
273 | if isinstance(m, nn.Linear) and m.bias is not None:
274 | nn.init.constant_(m.bias, 0)
275 | elif isinstance(m, nn.LayerNorm):
276 | nn.init.constant_(m.bias, 0)
277 | nn.init.constant_(m.weight, 1.0)
278 | elif isinstance(m, nn.Conv2d):
279 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
280 | fan_out //= m.groups
281 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
282 | if m.bias is not None:
283 | m.bias.data.zero_()
284 |
285 | def forward(self, x, H, W, info_nce=False):
286 | if info_nce:
287 | tmp, attn_v = self.attn(self.norm1(x), H, W, info_nce)
288 | else:
289 | tmp = self.attn(self.norm1(x), H, W)
290 | x = x + self.drop_path(tmp)
291 | x = x + self.drop_path(self.mlp(self.norm2(x), H, W))
292 | if info_nce:
293 | return x, attn_v
294 | else:
295 | return x
296 |
297 |
298 | class OverlapPatchEmbed(nn.Module):
299 | """ Image to Patch Embedding
300 | """
301 |
302 | def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):
303 | super().__init__()
304 | img_size = to_2tuple(img_size)
305 | patch_size = to_2tuple(patch_size)
306 |
307 | self.img_size = img_size
308 | self.patch_size = patch_size
309 | self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
310 | self.num_patches = self.H * self.W
311 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
312 | padding=(patch_size[0] // 2, patch_size[1] // 2))
313 | self.norm = nn.LayerNorm(embed_dim)
314 |
315 | self.apply(self._init_weights)
316 |
317 | def _init_weights(self, m):
318 | if isinstance(m, nn.Linear):
319 | trunc_normal_(m.weight, std=.02)
320 | if isinstance(m, nn.Linear) and m.bias is not None:
321 | nn.init.constant_(m.bias, 0)
322 | elif isinstance(m, nn.LayerNorm):
323 | nn.init.constant_(m.bias, 0)
324 | nn.init.constant_(m.weight, 1.0)
325 | elif isinstance(m, nn.Conv2d):
326 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
327 | fan_out //= m.groups
328 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
329 | if m.bias is not None:
330 | m.bias.data.zero_()
331 |
332 | def forward(self, x):
333 | x = self.proj(x)
334 | _, _, H, W = x.shape
335 | x = x.flatten(2).transpose(1, 2).contiguous()
336 | x = self.norm(x)
337 |
338 | return x, H, W
339 |
340 |
341 | class PyramidVisionTransformerV2(nn.Module):
342 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512],
343 | num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0.,
344 | attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, depths=[3, 4, 6, 3],
345 | sr_ratios=[8, 4, 2, 1], num_stages=4, linear=False):
346 | super().__init__()
347 | self.depths = depths
348 | self.num_stages = num_stages
349 | self.linear = linear
350 | self.embed_dims = embed_dims
351 | self.out_dim = embed_dims[-1]
352 |
353 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
354 | cur = 0
355 |
356 | for i in range(num_stages):
357 | patch_embed = OverlapPatchEmbed(img_size=img_size if i == 0 else img_size // (2 ** (i + 1)),
358 | patch_size=7 if i == 0 else 3,
359 | stride=4 if i == 0 else 2,
360 | in_chans=in_chans if i == 0 else embed_dims[i - 1],
361 | embed_dim=embed_dims[i])
362 |
363 | block = nn.ModuleList([Block(
364 | dim=embed_dims[i], num_heads=num_heads[i], mlp_ratio=mlp_ratios[i], qkv_bias=qkv_bias,
365 | qk_scale=qk_scale,
366 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + j], norm_layer=norm_layer,
367 | sr_ratio=sr_ratios[i], linear=linear)
368 | for j in range(depths[i])])
369 | norm = norm_layer(embed_dims[i])
370 | cur += depths[i]
371 |
372 | setattr(self, f"patch_embed{i + 1}", patch_embed)
373 | setattr(self, f"block{i + 1}", block)
374 | setattr(self, f"norm{i + 1}", norm)
375 |
376 | # init weights (proj head has its own init process)
377 | self.apply(self._init_weights)
378 |
379 |         # Projection head for CL
380 | self.proj = utils.ProjHead(self.out_dim, self.out_dim)
381 |
382 | # Projection head for EsViT extra CL
383 | self.proj2 = utils.ProjHead(self.out_dim, self.out_dim)
384 |
385 | self.num_heads = num_heads[-1]
386 |
387 | def _init_weights(self, m):
388 | if isinstance(m, nn.Linear):
389 | trunc_normal_(m.weight, std=.02)
390 | if isinstance(m, nn.Linear) and m.bias is not None:
391 | nn.init.constant_(m.bias, 0)
392 | elif isinstance(m, nn.LayerNorm):
393 | nn.init.constant_(m.bias, 0)
394 | nn.init.constant_(m.weight, 1.0)
395 | elif isinstance(m, nn.Conv2d):
396 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
397 | fan_out //= m.groups
398 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
399 | if m.bias is not None:
400 | m.bias.data.zero_()
401 |
402 | def forward_features(self, x, info_nce=False):
403 | if info_nce:
404 | attn_vs = []
405 | outs = []
406 | B = x.shape[0]
407 |
408 | for i in range(self.num_stages):
409 | patch_embed = getattr(self, f"patch_embed{i + 1}")
410 | block = getattr(self, f"block{i + 1}")
411 | norm = getattr(self, f"norm{i + 1}")
412 | x, H, W = patch_embed(x)
413 | if info_nce:
414 | for blk in block:
415 | x, attn_v = blk(x, H, W, info_nce)
416 | attn_vs.append(attn_v)
417 | else:
418 | for blk in block:
419 | x = blk(x, H, W)
420 | x = norm(x)
421 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
422 | outs.append(x)
423 |
424 | # output is (B, C, H, W)
425 | B, C = x.size(0), x.size(1)
426 | if info_nce:
427 | x = x.reshape(B, C, -1).permute(0, 2, 1).contiguous()
428 | return outs[-1].contiguous(), (self.proj2(x.max(-2)[0].contiguous()).contiguous(), self.proj(x.contiguous()).contiguous(), x.contiguous())
429 | else:
430 | return outs[-1].contiguous()
431 |
432 | def forward(self, x, info_nce=False):
433 | if info_nce:
434 | x, attn_v = self.forward_features(x, info_nce)
435 | return x, attn_v
436 | else:
437 | x = self.forward_features(x)
438 | return x
439 |
440 |
441 | class DWConv(nn.Module):
442 | def __init__(self, dim=768):
443 | super(DWConv, self).__init__()
444 | self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
445 |
446 | def forward(self, x, H, W):
447 | B, N, C = x.shape
448 | x = x.transpose(1, 2).view(B, C, H, W).contiguous()
449 | x = self.dwconv(x).contiguous()
450 | x = x.flatten(2).transpose(1, 2).contiguous()
451 |
452 | return x
453 |
454 |
455 | def _conv_filter(state_dict, patch_size=16):
456 | """ convert patch embedding weight from manual patchify + linear proj to conv"""
457 | out_dict = {}
458 | for k, v in state_dict.items():
459 | if 'patch_embed.proj.weight' in k:
460 | v = v.reshape((v.shape[0], 3, patch_size, patch_size))
461 | out_dict[k] = v
462 |
463 | return out_dict
464 |
465 |
466 | class ViTWrap(nn.Module):
467 | def __init__(self, vit):
468 | super().__init__()
469 | self.backbone = vit
470 |
471 | @register('pvtv2_b0')
472 | class pvt_v2_b0(PyramidVisionTransformerV2):
473 | def __init__(self, **kwargs):
474 | super(pvt_v2_b0, self).__init__(
475 | patch_size=4, embed_dims=[32, 64, 160, 256], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4],
476 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1],
477 | drop_rate=0.0, drop_path_rate=0.1)
478 |
479 |
480 | @register('pvtv2_b1')
481 | class pvt_v2_b1(PyramidVisionTransformerV2):
482 | def __init__(self, **kwargs):
483 | super(pvt_v2_b1, self).__init__(
484 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4],
485 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1],
486 | drop_rate=0.0, drop_path_rate=0.1)
487 |
488 |
489 | @register('pvtv2_b2')
490 | class pvt_v2_b2(PyramidVisionTransformerV2):
491 | def __init__(self, **kwargs):
492 | super(pvt_v2_b2, self).__init__(
493 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4],
494 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1],
495 | drop_rate=0.0, drop_path_rate=0.1)
496 |
497 |
498 | @register('pvtv2_b2_li')
499 | class pvt_v2_b2_li(PyramidVisionTransformerV2):
500 | def __init__(self, **kwargs):
501 | super(pvt_v2_b2_li, self).__init__(
502 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4],
503 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1],
504 | drop_rate=0.0, drop_path_rate=0.1, linear=True)
505 |
506 |
507 | @register('pvtv2_b3')
508 | class pvt_v2_b3(PyramidVisionTransformerV2):
509 | def __init__(self, **kwargs):
510 | super(pvt_v2_b3, self).__init__(
511 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4],
512 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1],
513 | drop_rate=0.0, drop_path_rate=0.1)
514 |
515 |
516 | @register('pvtv2_b4')
517 | class pvt_v2_b4(PyramidVisionTransformerV2):
518 | def __init__(self, **kwargs):
519 | super(pvt_v2_b4, self).__init__(
520 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4],
521 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1],
522 | drop_rate=0.0, drop_path_rate=0.1)
523 |
524 |
525 | @register('pvtv2_b5')
526 | class pvt_v2_b5(PyramidVisionTransformerV2):
527 | def __init__(self, **kwargs):
528 | super(pvt_v2_b5, self).__init__(
529 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
530 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1],
531 | drop_rate=0.0, drop_path_rate=0.1)
532 |
--------------------------------------------------------------------------------
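drop_path in pvt_v2.py (and its twin in swin_transformer.py) implements stochastic depth: each sample's residual branch is dropped with probability drop_prob, and surviving samples are rescaled by 1/keep_prob so the expected output matches the input. A standalone sanity check of that scaling (the numbers are illustrative):

import torch

def drop_path(x, drop_prob=0., training=False):
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    random_tensor = (keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)).floor_()
    return x.div(keep_prob) * random_tensor

torch.manual_seed(0)
x = torch.ones(10000, 8)
y = drop_path(x, drop_prob=0.3, training=True)
print(round(y.mean().item(), 2))  # ~1.0: roughly 30% of rows are zeroed, the rest scaled by 1/0.7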
/models/swin_transformer.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------
2 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # This file has been modified from a file in the Swin-Transformer library
5 | # which was released under the MIT Licence.
6 | #
7 | # Source:
8 | # https://github.com/microsoft/Swin-Transformer
9 | #
10 | # The license for the original version of this file can be
11 | # found in https://github.com/microsoft/Swin-Transformer/blob/main/LICENSE
12 | # The modifications to this file are subject to the same MIT Licence.
13 | # ---------------------------------------------------------------
14 |
15 | import torch
16 | import torch.nn as nn
17 | import torch.nn.functional as F
18 | import torch.utils.checkpoint as checkpoint
19 | import numpy as np
20 | import collections.abc
21 | from itertools import repeat
22 | import math
23 | import warnings
24 |
25 | import utils
26 | from .models import register
27 |
28 | __all__ = [
29 | 'swin_tiny', 'swin_small'
30 | ]
31 |
32 |
33 | def drop_path(x, drop_prob: float = 0., training: bool = False):
34 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
35 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
36 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
37 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
38 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
39 | 'survival rate' as the argument.
40 | """
41 | if drop_prob == 0. or not training:
42 | return x
43 | keep_prob = 1 - drop_prob
44 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
45 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
46 | random_tensor.floor_() # binarize
47 | output = x.div(keep_prob) * random_tensor
48 | return output
49 |
50 |
51 | class DropPath(nn.Module):
52 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
53 | """
54 | def __init__(self, drop_prob=None):
55 | super(DropPath, self).__init__()
56 | self.drop_prob = drop_prob
57 |
58 | def forward(self, x):
59 | return drop_path(x, self.drop_prob, self.training)
60 |
61 |
62 | def _ntuple(n):
63 | def parse(x):
64 | if isinstance(x, collections.abc.Iterable):
65 | return x
66 | return tuple(repeat(x, n))
67 | return parse
68 |
69 |
70 | to_1tuple = _ntuple(1)
71 | to_2tuple = _ntuple(2)
72 | to_3tuple = _ntuple(3)
73 | to_4tuple = _ntuple(4)
74 | to_ntuple = _ntuple
75 |
76 |
77 | def _no_grad_trunc_normal_(tensor, mean, std, a, b):
78 | # Cut & paste from PyTorch official master until it's in a few official releases - RW
79 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
80 | def norm_cdf(x):
81 | # Computes standard normal cumulative distribution function
82 | return (1. + math.erf(x / math.sqrt(2.))) / 2.
83 |
84 | if (mean < a - 2 * std) or (mean > b + 2 * std):
85 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
86 | "The distribution of values may be incorrect.",
87 | stacklevel=2)
88 |
89 | with torch.no_grad():
90 | # Values are generated by using a truncated uniform distribution and
91 | # then using the inverse CDF for the normal distribution.
92 | # Get upper and lower cdf values
93 | l = norm_cdf((a - mean) / std)
94 | u = norm_cdf((b - mean) / std)
95 |
96 | # Uniformly fill tensor with values from [l, u], then translate to
97 | # [2l-1, 2u-1].
98 | tensor.uniform_(2 * l - 1, 2 * u - 1)
99 |
100 | # Use inverse cdf transform for normal distribution to get truncated
101 | # standard normal
102 | tensor.erfinv_()
103 |
104 | # Transform to proper mean, std
105 | tensor.mul_(std * math.sqrt(2.))
106 | tensor.add_(mean)
107 |
108 | # Clamp to ensure it's in the proper range
109 | tensor.clamp_(min=a, max=b)
110 | return tensor
111 |
112 |
113 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
114 | # type: (Tensor, float, float, float, float) -> Tensor
115 | r"""Fills the input Tensor with values drawn from a truncated
116 | normal distribution. The values are effectively drawn from the
117 | normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
118 | with values outside :math:`[a, b]` redrawn until they are within
119 | the bounds. The method used for generating the random values works
120 | best when :math:`a \leq \text{mean} \leq b`.
121 | Args:
122 | tensor: an n-dimensional `torch.Tensor`
123 | mean: the mean of the normal distribution
124 | std: the standard deviation of the normal distribution
125 | a: the minimum cutoff value
126 | b: the maximum cutoff value
127 | Examples:
128 | >>> w = torch.empty(3, 5)
129 | >>> nn.init.trunc_normal_(w)
130 | """
131 | return _no_grad_trunc_normal_(tensor, mean, std, a, b)
132 |
133 |
134 | class Mlp(nn.Module):
135 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
136 | super().__init__()
137 | out_features = out_features or in_features
138 | hidden_features = hidden_features or in_features
139 | self.fc1 = nn.Linear(in_features, hidden_features)
140 | self.act = act_layer()
141 | self.fc2 = nn.Linear(hidden_features, out_features)
142 | self.drop = nn.Dropout(drop)
143 |
144 | def forward(self, x):
145 | x = self.fc1(x)
146 | x = self.act(x)
147 | x = self.drop(x)
148 | x = self.fc2(x)
149 | x = self.drop(x)
150 | return x
151 |
152 |
153 | def window_partition(x, window_size):
154 | """
155 | Args:
156 | x: (B, H, W, C)
157 | window_size (int): window size
158 |
159 | Returns:
160 | windows: (num_windows*B, window_size, window_size, C)
161 | """
162 | B, H, W, C = x.shape
163 | x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
164 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
165 | return windows
166 |
167 |
168 | def window_reverse(windows, window_size, H, W):
169 | """
170 | Args:
171 | windows: (num_windows*B, window_size, window_size, C)
172 | window_size (int): Window size
173 | H (int): Height of image
174 | W (int): Width of image
175 |
176 | Returns:
177 | x: (B, H, W, C)
178 | """
179 | B = int(windows.shape[0] / (H * W / window_size / window_size))
180 | x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
181 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
182 | return x
183 |
184 |
185 | class WindowAttention(nn.Module):
186 | r""" Window based multi-head self attention (W-MSA) module with relative position bias.
187 | It supports both of shifted and non-shifted window.
188 |
189 | Args:
190 | dim (int): Number of input channels.
191 | window_size (tuple[int]): The height and width of the window.
192 | num_heads (int): Number of attention heads.
193 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
194 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
195 | attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
196 | proj_drop (float, optional): Dropout ratio of output. Default: 0.0
197 | """
198 |
199 | def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
200 |
201 | super().__init__()
202 | self.dim = dim
203 | self.window_size = window_size # Wh, Ww
204 | self.num_heads = num_heads
205 | head_dim = dim // num_heads
206 | self.scale = qk_scale or head_dim ** -0.5
207 |
208 | # define a parameter table of relative position bias
209 | self.relative_position_bias_table = nn.Parameter(
210 | torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
211 |
212 | # get pair-wise relative position index for each token inside the window
213 | coords_h = torch.arange(self.window_size[0])
214 | coords_w = torch.arange(self.window_size[1])
215 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
216 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
217 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
218 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
219 | relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
220 | relative_coords[:, :, 1] += self.window_size[1] - 1
221 | relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
222 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
223 | self.register_buffer("relative_position_index", relative_position_index)
224 |
225 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
226 | self.attn_drop = nn.Dropout(attn_drop)
227 | self.proj = nn.Linear(dim, dim)
228 | self.proj_drop = nn.Dropout(proj_drop)
229 |
230 | trunc_normal_(self.relative_position_bias_table, std=.02)
231 | self.softmax = nn.Softmax(dim=-1)
232 |
233 | def forward(self, x, mask=None, infonce=False):
234 | """
235 | Args:
236 | x: input features with shape of (num_windows*B, N, C)
237 | mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
238 | """
239 | B_, N, C = x.shape
240 | qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
241 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
242 |
243 | q = q * self.scale
244 | attn = (q @ k.transpose(-2, -1))
245 |
246 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
247 | self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
248 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
249 | attn = attn + relative_position_bias.unsqueeze(0)
250 |
251 | if mask is not None:
252 | nW = mask.shape[0]
253 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
254 | attn = attn.view(-1, self.num_heads, N, N)
255 | attn = self.softmax(attn)
256 | else:
257 | attn = self.softmax(attn)
258 |
259 | attn = self.attn_drop(attn)
260 |
261 | if infonce:
262 | attn_ret = attn.detach()
263 |
264 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
265 | x = self.proj(x)
266 | x = self.proj_drop(x)
267 |
268 | if infonce:
269 | return x, (attn_ret, v)
270 | else:
271 | return x
272 |
273 |
274 | class SwinTransformerBlock(nn.Module):
275 | r""" Swin Transformer Block.
276 |
277 | Args:
278 | dim (int): Number of input channels.
279 | num_heads (int): Number of attention heads.
280 | window_size (int): Window size.
281 | shift_size (int): Shift size for SW-MSA.
282 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
283 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
284 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
285 | drop (float, optional): Dropout rate. Default: 0.0
286 | attn_drop (float, optional): Attention dropout rate. Default: 0.0
287 | drop_path (float, optional): Stochastic depth rate. Default: 0.0
288 | act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
289 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
290 | """
291 |
292 | def __init__(self, dim, num_heads, window_size=7, shift_size=0,
293 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
294 | act_layer=nn.GELU, norm_layer=nn.LayerNorm):
295 | super().__init__()
296 | self.dim = dim
297 | self.num_heads = num_heads
298 | self.window_size = window_size
299 | self.shift_size = shift_size
300 | self.mlp_ratio = mlp_ratio
301 |         assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
302 |
303 | self.norm1 = norm_layer(dim)
304 | self.attn = WindowAttention(
305 | dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
306 | qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
307 |
308 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
309 | self.norm2 = norm_layer(dim)
310 | mlp_hidden_dim = int(dim * mlp_ratio)
311 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
312 | self.H, self.W = None, None
313 |
314 | def forward(self, x, mask_matrix, infonce=False):
315 | B, L, C = x.shape
316 | H, W = self.H, self.W
317 | assert L == H * W, "input feature has wrong size"
318 |
319 | shortcut = x
320 | x = self.norm1(x)
321 | x = x.view(B, H, W, C)
322 |
323 | # pad feature maps to multiples of window size
324 | pad_l = pad_t = 0
325 | pad_r = (self.window_size - W % self.window_size) % self.window_size
326 | pad_b = (self.window_size - H % self.window_size) % self.window_size
327 | x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
328 | _, Hp, Wp, _ = x.shape
329 |
330 | # cyclic shift
331 | if self.shift_size > 0:
332 | shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
333 | attn_mask = mask_matrix
334 | else:
335 | shifted_x = x
336 | attn_mask = None
337 |
338 | # partition windows
339 | x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
340 | x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
341 |
342 | # W-MSA/SW-MSA
343 | ret = self.attn(x_windows, mask=attn_mask, infonce=infonce) # nW*B, window_size*window_size, C
344 | if infonce:
345 | attn_windows, attn_v = ret
346 | else:
347 | attn_windows = ret
348 |
349 | # merge windows
350 | attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
351 | shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
352 |
353 | # reverse cyclic shift
354 | if self.shift_size > 0:
355 | x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
356 | else:
357 | x = shifted_x
358 |
359 | if pad_r > 0 or pad_b > 0:
360 | x = x[:, :H, :W, :].contiguous()
361 |
362 | x = x.view(B, H * W, C)
363 |
364 | # FFN
365 | x = shortcut + self.drop_path(x)
366 | x = x + self.drop_path(self.mlp(self.norm2(x)))
367 |
368 | if infonce:
369 | return x, attn_v
370 | else:
371 | return x
372 |
373 |
374 | class PatchMerging(nn.Module):
375 | r""" Patch Merging Layer.
376 |
377 | Args:
378 | dim (int): Number of input channels.
379 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
380 | """
381 |
382 | def __init__(self, dim, norm_layer=nn.LayerNorm):
383 | super().__init__()
384 | self.dim = dim
385 | self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
386 | self.norm = norm_layer(4 * dim)
387 |
388 | def forward(self, x, H, W):
389 | """
390 | x: B, H*W, C
391 | """
392 | B, L, C = x.shape
393 | assert L == H * W, "input feature has wrong size"
394 |
395 | x = x.view(B, H, W, C)
396 |
397 | # padding
398 | pad_input = (H % 2 == 1) or (W % 2 == 1)
399 | if pad_input:
400 | x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
401 |
402 | x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
403 | x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
404 | x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
405 | x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
406 | x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
407 | x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
408 |
409 | x = self.norm(x)
410 | x = self.reduction(x)
411 |
412 | return x
413 |
414 |
415 | class BasicLayer(nn.Module):
416 | """ A basic Swin Transformer layer for one stage.
417 |
418 | Args:
419 | dim (int): Number of input channels.
420 | depth (int): Number of blocks.
421 | num_heads (int): Number of attention heads.
422 | window_size (int): Local window size.
423 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
424 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
425 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
426 | drop (float, optional): Dropout rate. Default: 0.0
427 | attn_drop (float, optional): Attention dropout rate. Default: 0.0
428 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
429 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
430 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
431 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
432 | """
433 |
434 | def __init__(self,
435 | dim,
436 | depth,
437 | num_heads,
438 | window_size=7,
439 | mlp_ratio=4.,
440 | qkv_bias=True,
441 | qk_scale=None,
442 | drop=0.,
443 | attn_drop=0.,
444 | drop_path=0.,
445 | norm_layer=nn.LayerNorm,
446 | downsample=None,
447 | use_checkpoint=False):
448 | super().__init__()
449 | self.window_size = window_size
450 | self.shift_size = window_size // 2
451 | self.depth = depth
452 | self.use_checkpoint = use_checkpoint
453 |
454 | # build blocks
455 | self.blocks = nn.ModuleList([
456 | SwinTransformerBlock(
457 | dim=dim,
458 | num_heads=num_heads,
459 | window_size=window_size,
460 | shift_size=0 if (i % 2 == 0) else window_size // 2,
461 | mlp_ratio=mlp_ratio,
462 | qkv_bias=qkv_bias,
463 | qk_scale=qk_scale,
464 | drop=drop,
465 | attn_drop=attn_drop,
466 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
467 | norm_layer=norm_layer)
468 | for i in range(depth)])
469 |
470 | # patch merging layer
471 | if downsample is not None:
472 | self.downsample = downsample(dim=dim, norm_layer=norm_layer)
473 | else:
474 | self.downsample = None
475 |
476 | def forward(self, x, H, W, infonce=False):
477 | # calculate attention mask for SW-MSA
478 | Hp = int(np.ceil(H / self.window_size)) * self.window_size
479 | Wp = int(np.ceil(W / self.window_size)) * self.window_size
480 | img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
481 | h_slices = (slice(0, -self.window_size),
482 | slice(-self.window_size, -self.shift_size),
483 | slice(-self.shift_size, None))
484 | w_slices = (slice(0, -self.window_size),
485 | slice(-self.window_size, -self.shift_size),
486 | slice(-self.shift_size, None))
487 | cnt = 0
488 | for h in h_slices:
489 | for w in w_slices:
490 | img_mask[:, h, w, :] = cnt
491 | cnt += 1
492 |
493 | mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
494 | mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
495 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
496 | attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
497 |
498 | attn_vs = []
499 | for blk in self.blocks:
500 | blk.H, blk.W = H, W
501 | if self.use_checkpoint:
502 | assert infonce == False
503 | x = checkpoint.checkpoint(blk, x, attn_mask)
504 | else:
505 | ret = blk(x, attn_mask, infonce)
506 | if infonce:
507 | x, attn_v = ret
508 | attn_vs.append(attn_v)
509 | else:
510 | x = ret
511 | if self.downsample is not None:
512 | x_down = self.downsample(x, H, W)
513 | Wh, Ww = (H + 1) // 2, (W + 1) // 2
514 | if infonce:
515 | return x, H, W, x_down, Wh, Ww, attn_vs
516 | else:
517 | return x, H, W, x_down, Wh, Ww
518 | else:
519 | if infonce:
520 | return x, H, W, x, H, W, attn_vs
521 | else:
522 | return x, H, W, x, H, W
523 |
524 |
525 | class PatchEmbed(nn.Module):
526 | r""" Image to Patch Embedding
527 |
528 | Args:
529 | patch_size (int): Patch token size. Default: 4.
530 | in_chans (int): Number of input image channels. Default: 3.
531 | embed_dim (int): Number of linear projection output channels. Default: 96.
532 | norm_layer (nn.Module, optional): Normalization layer. Default: None
533 | """
534 |
535 | def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
536 | super().__init__()
537 | patch_size = to_2tuple(patch_size)
538 | self.patch_size = patch_size
539 |
540 | self.in_chans = in_chans
541 | self.embed_dim = embed_dim
542 |
543 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
544 | if norm_layer is not None:
545 | self.norm = norm_layer(embed_dim)
546 | else:
547 | self.norm = None
548 |
549 | def forward(self, x):
550 | # padding
551 | _, _, H, W = x.size()
552 | if W % self.patch_size[1] != 0:
553 | x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
554 | if H % self.patch_size[0] != 0:
555 | x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
556 | x = self.proj(x) # B C Wh Ww
557 | if self.norm is not None:
558 | Wh, Ww = x.size(2), x.size(3)
559 | x = x.flatten(2).transpose(1, 2)
560 | x = self.norm(x)
561 | x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
562 | return x
563 |
564 |
565 | class SwinTransformer(nn.Module):
566 | r""" Swin Transformer
567 | A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
568 | https://arxiv.org/pdf/2103.14030
569 |
570 | Args:
571 | pretrain_img_size (int | tuple(int)): Input image size. Default 224
572 | patch_size (int | tuple(int)): Patch size. Default: 4
573 | in_chans (int): Number of input image channels. Default: 3
574 | embed_dim (int): Patch embedding dimension. Default: 96
575 | depths (tuple(int)): Depth of each Swin Transformer layer.
576 | num_heads (tuple(int)): Number of attention heads in different layers.
577 | window_size (int): Window size. Default: 7
578 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
579 | qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
580 | qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
581 | drop_rate (float): Dropout rate. Default: 0
582 | attn_drop_rate (float): Attention dropout rate. Default: 0
583 | drop_path_rate (float): Stochastic depth rate. Default: 0.1
584 | norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
585 | ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
586 | patch_norm (bool): If True, add normalization after patch embedding. Default: True
587 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
588 | """
589 |
590 | def __init__(self,
591 | pretrain_img_size=224,
592 | patch_size=4,
593 | in_chans=3,
594 | embed_dim=96,
595 | depths=[2, 2, 6, 2],
596 | num_heads=[3, 6, 12, 24],
597 | window_size=7,
598 | mlp_ratio=4.,
599 | qkv_bias=True,
600 | qk_scale=None,
601 | drop_rate=0.,
602 | attn_drop_rate=0.,
603 | drop_path_rate=0.2,
604 | norm_layer=nn.LayerNorm,
605 | ape=False,
606 | patch_norm=True,
607 | use_checkpoint=False,
608 | **kwargs):
609 | super().__init__()
610 |
611 | self.pretrain_img_size = pretrain_img_size
612 | self.num_layers = len(depths)
613 | self.embed_dim = embed_dim
614 | self.ape = ape
615 | self.patch_norm = patch_norm
616 | self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
617 | self.out_dim = self.num_features
618 |
619 | # split image into non-overlapping patches
620 | self.patch_embed = PatchEmbed(
621 | patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
622 | norm_layer=norm_layer if self.patch_norm else None)
623 |
624 | # absolute position embedding
625 | if self.ape:
626 | pretrain_img_size = to_2tuple(pretrain_img_size)
627 | patch_size = to_2tuple(patch_size)
628 | patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1]]
629 |
630 | self.absolute_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1]))
631 | trunc_normal_(self.absolute_pos_embed, std=.02)
632 |
633 | self.pos_drop = nn.Dropout(p=drop_rate)
634 |
635 | # stochastic depth
636 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
637 |
638 | # build layers
639 | self.layers = nn.ModuleList()
640 | for i_layer in range(self.num_layers):
641 | layer = BasicLayer(
642 | dim=int(embed_dim * 2 ** i_layer),
643 | depth=depths[i_layer],
644 | num_heads=num_heads[i_layer],
645 | window_size=window_size,
646 | mlp_ratio=mlp_ratio,
647 | qkv_bias=qkv_bias,
648 | qk_scale=qk_scale,
649 | drop=drop_rate,
650 | attn_drop=attn_drop_rate,
651 | drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
652 | norm_layer=norm_layer,
653 | downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
654 | use_checkpoint=use_checkpoint)
655 | self.layers.append(layer)
656 |
657 | self.norm = norm_layer(self.num_features)
658 |
659 | # Projection head for CL
660 | self.proj = utils.ProjHead(self.out_dim, self.out_dim)
661 |
662 | # Projection head for EsViT extra CL
663 | self.proj2 = utils.ProjHead(self.out_dim, self.out_dim)
664 |
665 | @torch.jit.ignore
666 | def no_weight_decay(self):
667 | return {'absolute_pos_embed'}
668 |
669 | @torch.jit.ignore
670 | def no_weight_decay_keywords(self):
671 | return {'relative_position_bias_table'}
672 |
673 | def forward(self, x, info_nce=False):
674 | """Forward function."""
675 | x = self.patch_embed(x)
676 | Wh, Ww = x.size(2), x.size(3)
677 | if self.ape:
678 | # interpolate the position embedding to the corresponding size
679 | absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic')
680 | x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C
681 | else:
682 | x = x.flatten(2).transpose(1, 2)
683 | x = self.pos_drop(x)
684 |
685 | for layer in self.layers:
686 | _, _, _, x, Wh, Ww = layer(x, Wh, Ww)
687 |
688 | x = self.norm(x) # B L C
689 | x = x.permute(0, 2, 1).contiguous()
690 | # output is (B, C, N)
691 | if info_nce:
692 | out = x
693 | B, C = out.size(0), out.size(1)
694 | x = out.reshape(B, C, -1).permute(0, 2, 1).contiguous()
695 | return out.contiguous(), (self.proj2(x.max(-2)[0]), self.proj(x), x)
696 | else:
697 | return x
698 |
699 | def get_last_selfattention(self, x):
700 | x, attn_v = self.forward(x, True)
701 | return attn_v
702 |
703 |
704 | @register('swin_tiny')
705 | def swin_tiny(pretrained=False, **kwargs):
706 | model = SwinTransformer(drop_path_rate=0.2, embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], window_size=7, **kwargs)
707 | return model
708 |
709 |
710 | @register('swin_small')
711 | def swin_small(pretrained=False, **kwargs):
712 | model = SwinTransformer(drop_path_rate=0.3, embed_dim=96, depths=[2, 2, 18, 2], num_heads=[3, 6, 12, 24], window_size=7, **kwargs)
713 | return model
714 |
715 |
716 | @register('swin_base')
717 | def swin_base(pretrained=False, **kwargs):
718 | model = SwinTransformer(drop_path_rate=0.3, embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], window_size=12, **kwargs)
719 | return model
720 |
--------------------------------------------------------------------------------
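
A minimal usage sketch for the Swin backbone defined above (illustrative only, not a file in this repository). It assumes the repository root is on PYTHONPATH and that the packages the model files import (torch, timm and, via the models package, detectron2) are installed; `swin_tiny` and the (B, C, N) output layout come directly from the code above, while the printed shapes follow from embed_dim=96 and a 224x224 input.

import torch
from models.swin_transformer import swin_tiny  # assumes the repo root is importable

model = swin_tiny().eval()
images = torch.randn(2, 3, 224, 224)

with torch.no_grad():
    # Plain forward: token features in (B, C, N) layout, C = 8 * embed_dim = 768, N = 7 * 7 = 49.
    feats = model(images)
    # info_nce=True additionally returns the outputs of the two ProjHead modules
    # (global and dense projections) together with the raw (B, N, C) tokens.
    _, (glob_proj, dense_proj, tokens) = model(images, info_nce=True)

print(feats.shape)    # torch.Size([2, 768, 49])
print(tokens.shape)   # torch.Size([2, 49, 768])
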
/models/transparent_encoder.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------
2 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # This file has been modified from a file in the Bongard-HOI library
5 | # which was released under the NVIDIA Source Code Licence.
6 | #
7 | # Source:
8 | # https://github.com/NVlabs/Bongard-HOI
9 | #
10 | # The license for the original version of this file can be
11 | # found in https://github.com/NVlabs/Bongard-HOI/blob/master/LICENSE
12 | # The modifications to this file are subject to the same NVIDIA Source Code Licence.
13 | # ---------------------------------------------------------------
14 |
15 | import torch
16 | import torch.nn as nn
17 | import torch.nn.functional as F
18 |
19 | from detectron2.modeling.poolers import ROIPooler
20 |
21 | import models
22 | import utils
23 | from .models import register
24 |
25 |
26 | @register('transparent_bbox_encoder')
27 | class TransparentBBoxNetworkEncoder(nn.Module):
28 | def __init__(self, encoder, **kwargs):
29 | raise NotImplementedError('Currently we have to make sure each image only produces one token.')
30 | super(TransparentBBoxNetworkEncoder, self).__init__()
31 |
32 | # image encoder
33 | encoder = models.make(encoder)
34 | self.conv1 = encoder.conv1
35 | self.bn1 = encoder.bn1
36 | self.relu = encoder.relu
37 | self.maxpool = encoder.maxpool
38 | self.layer1 = encoder.layer1
39 | self.layer2 = encoder.layer2
40 | self.layer3 = encoder.layer3
41 | self.layer4 = encoder.layer4
42 |
43 | self.proj = nn.Conv2d(encoder.out_dim, encoder.out_dim // 2, kernel_size=1)
44 |
45 | # ROI Pooler
46 | self.roi_pooler = ROIPooler(
47 | output_size=7,
48 | scales=(1/32,),
49 | sampling_ratio=0,
50 | pooler_type='ROIAlignV2',
51 | )
52 | self.roi_processor = nn.Sequential(
53 | nn.Conv2d(encoder.out_dim // 2, 256, kernel_size=3, padding=1),
54 | nn.BatchNorm2d(256),
55 | nn.ReLU(),
56 | nn.Flatten(),
57 | nn.Linear(256*7*7, 512),
58 | nn.ReLU()
59 | )
60 | self.roi_processor_ln = nn.LayerNorm(512)
61 |
62 | # bbox coord encoding
63 | self.roi_processor_box = nn.Linear(4, 128)
64 | self.roi_processor_box_ln = nn.LayerNorm(128)
65 | rn_in_planes = 512 + 128
66 |
67 | self.out_dim = rn_in_planes
68 |
69 | def forward(self, im, boxes, boxes_dim=None):
70 | # assert im.shape[0] == len(boxes), 'im: {} vs boxes: {}'.format(im.shape[0], len(boxes))
71 | img_shape = im.shape
72 | im = im.view(-1, *img_shape[-3:])
73 | num_im = im.size(0)
74 | # assert im.shape[0] == boxes_dim.shape[0], '{} vs {}'.format(im.shape, boxes_dim.shape)
75 | if boxes_dim is not None:
76 | boxes_dim_shape = boxes_dim.shape
77 | boxes_dim = boxes_dim.view(-1, *boxes_dim_shape[-1:])
78 |
79 | # BxCxHxW
80 | x = self.conv1(im)
81 | x = self.bn1(x)
82 | x = self.relu(x)
83 | x = self.maxpool(x)
84 |
85 | x = self.layer1(x)
86 | x = self.layer2(x)
87 | x = self.layer3(x)
88 | x = self.layer4(x)
89 |
90 | x = self.proj(x)
91 |
92 | # RoI pooling/align
93 | # x_shape = list(img_shape[:-3]) + list(x.shape[-3:])
94 | # x = x.view(x_shape)
95 | all_boxes = []
96 | for boxes_i in boxes:
97 | all_boxes.extend(boxes_i)
98 | num_boxes = [boxes_i.tensor.shape[0] for boxes_i in all_boxes]
99 |
100 | # roi_feats = roi_align(
101 | # x, all_boxes,
102 | # output_size=(7, 7),
103 | # spatial_scale=1/32.,
104 | # sampling_ratio=0,
105 | # aligned=True
106 | # )
107 | roi_feats = self.roi_pooler([x], all_boxes)
108 | roi_feats = self.roi_processor(roi_feats)
109 | roi_feats = self.roi_processor_ln(roi_feats)
110 | # Add bbox pos features
111 | bbox_tensor = torch.cat([box.tensor for box in all_boxes]).to(roi_feats)
112 | # bbox coord normalization
113 | bbox_tensor[:, 0] = bbox_tensor[:, 0] / im.shape[3]
114 | bbox_tensor[:, 1] = bbox_tensor[:, 1] / im.shape[2]
115 | bbox_tensor[:, 2] = bbox_tensor[:, 2] / im.shape[3]
116 | bbox_tensor[:, 3] = bbox_tensor[:, 3] / im.shape[2]
117 | bbox_tensor = bbox_tensor * 2 - 1
118 | roi_box_feats = self.roi_processor_box_ln(self.roi_processor_box(bbox_tensor))
119 | roi_feats = torch.cat([roi_feats, roi_box_feats], dim=-1)
120 |
121 | # TODO: This assumes all the images have the same number of bboxes.
122 | feat = roi_feats.reshape(num_im, -1, roi_feats.size(-1))
123 | return feat
124 |
125 | # CNN backbone then superpixels as patches
126 | @register('transparent_superpixel_encoder')
127 | class TransparentSuperpixelEncoder(nn.Module):
128 | def __init__(self, encoder, **kwargs):
129 | super().__init__()
130 |
131 | # image encoder
132 | encoder = models.make(encoder)
133 | self.encoder = encoder
134 | self.out_dim = encoder.out_dim
135 |
136 | def forward(self, im, boxes=None, boxes_dim=None, info_nce=False):
137 | img_shape = im.shape
138 | im = im.view(-1, *img_shape[-3:])
139 | num_im = im.size(0)
140 |
141 | # BxCxHxW
142 | if info_nce:
143 | feats, attn_v = self.encoder(im, info_nce)
144 | return feats, attn_v
145 | else:
146 | feats = self.encoder(im)
147 | return feats
--------------------------------------------------------------------------------
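
The box branch of TransparentBBoxNetworkEncoder above normalizes absolute (x1, y1, x2, y2) pixel coordinates by the image width and height and then maps them to [-1, 1] before the linear box encoder. Below is a standalone sketch of that normalization (illustrative only; `normalize_boxes` is a hypothetical helper, not part of the repository).

import torch

def normalize_boxes(boxes_xyxy: torch.Tensor, img_h: int, img_w: int) -> torch.Tensor:
    """Scale (x1, y1, x2, y2) pixel boxes to [0, 1] by image size, then map to [-1, 1]."""
    scale = boxes_xyxy.new_tensor([img_w, img_h, img_w, img_h])
    return boxes_xyxy / scale * 2 - 1

boxes = torch.tensor([[10.0, 20.0, 110.0, 220.0]])
print(normalize_boxes(boxes, img_h=240, img_w=320))
# tensor([[-0.9375, -0.8333, -0.3125,  0.8333]])
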
/models/vit.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------
2 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # This file has been modified from a file in the pytorch-image-models library
5 | # which was released under the Apache Licence.
6 | #
7 | # Source:
8 | # https://github.com/rwightman/pytorch-image-models/
9 | #
10 | # The license for the original version of this file can be
11 | # found in https://github.com/rwightman/pytorch-image-models/blob/master/LICENSE
12 | # The modifications to this file are subject to the same Apache Licence.
13 | # ---------------------------------------------------------------
14 |
15 | import logging
16 | from functools import partial
17 | from collections import OrderedDict
18 |
19 | import torch
20 | import torch.nn as nn
21 | import torch.nn.functional as F
22 |
23 | from timm.models.layers import Mlp, DropPath, to_2tuple
24 |
25 | import utils
26 | from .models import register
27 |
28 | _logger = logging.getLogger(__name__)
29 |
30 |
31 | __all__ = [
32 | 'vit_tiny_16', 'vit_small_16', 'vit_base_16', 'vit_large_16'
33 | ]
34 |
35 |
36 | class PatchEmbed(nn.Module):
37 | """ 2D Image to Patch Embedding
38 | """
39 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
40 | super().__init__()
41 | img_size = to_2tuple(img_size)
42 | patch_size = to_2tuple(patch_size)
43 | self.img_size = img_size
44 | self.patch_size = patch_size
45 | self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
46 | self.num_patches = self.H * self.W
47 | self.flatten = flatten
48 |
49 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
50 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
51 |
52 | def forward(self, x):
53 | B, C, H, W = x.shape
54 | # assert H == self.img_size[0] and W == self.img_size[1], \
55 | # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
56 | x = self.proj(x)
57 | H, W = x.size(2), x.size(3)
58 | if self.flatten:
59 | x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
60 | x = self.norm(x)
61 | return x, (H, W)
62 |
63 |
64 | class Attention(nn.Module):
65 | def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
66 | super().__init__()
67 | self.num_heads = num_heads
68 | head_dim = dim // num_heads
69 | self.scale = head_dim ** -0.5
70 |
71 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
72 | self.attn_drop = nn.Dropout(attn_drop)
73 | self.proj = nn.Linear(dim, dim)
74 | self.proj_drop = nn.Dropout(proj_drop)
75 |
76 | def forward(self, x, info_nce=False):
77 | B, N, C = x.shape
78 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
79 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
80 |
81 | attn = (q @ k.transpose(-2, -1)) * self.scale
82 | attn = attn.softmax(dim=-1)
83 | attn = self.attn_drop(attn)
84 | if info_nce:
85 | attn_ret = attn.detach()
86 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
87 | x = self.proj(x)
88 | x = self.proj_drop(x)
89 | if info_nce:
90 | return x, (attn_ret, v)
91 | else:
92 | return x
93 |
94 |
95 | class Block(nn.Module):
96 |
97 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
98 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
99 | super().__init__()
100 | self.norm1 = norm_layer(dim)
101 | self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
102 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
103 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
104 | self.norm2 = norm_layer(dim)
105 | mlp_hidden_dim = int(dim * mlp_ratio)
106 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
107 |
108 | def forward(self, x, info_nce=False):
109 | if info_nce:
110 | tmp, attn_v = self.attn(self.norm1(x), info_nce)
111 | else:
112 | tmp = self.attn(self.norm1(x))
113 | x = x + self.drop_path(tmp)
114 | x = x + self.drop_path(self.mlp(self.norm2(x)))
115 | if info_nce:
116 | return x, attn_v
117 | else:
118 | return x
119 |
120 |
121 | class VisionTransformer(nn.Module):
122 | """ Vision Transformer
123 |
124 | A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
125 | - https://arxiv.org/abs/2010.11929
126 |
127 | Includes distillation token & head support for `DeiT: Data-efficient Image Transformers`
128 | - https://arxiv.org/abs/2012.12877
129 | """
130 |
131 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
132 | num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, distilled=False,
133 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed, norm_layer=None,
134 | act_layer=None, weight_init=''):
135 | """
136 | Args:
137 | img_size (int, tuple): input image size
138 | patch_size (int, tuple): patch size
139 | in_chans (int): number of input channels
140 | num_classes (int): number of classes for classification head
141 | embed_dim (int): embedding dimension
142 | depth (int): depth of transformer
143 | num_heads (int): number of attention heads
144 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim
145 | qkv_bias (bool): enable bias for qkv if True
146 | representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
147 | distilled (bool): model includes a distillation token and head as in DeiT models
148 | drop_rate (float): dropout rate
149 | attn_drop_rate (float): attention dropout rate
150 | drop_path_rate (float): stochastic depth rate
151 | embed_layer (nn.Module): patch embedding layer
152 | norm_layer (nn.Module): normalization layer
153 | weight_init (str): weight init scheme
154 | """
155 | super().__init__()
156 | self.num_classes = num_classes
157 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
158 | self.num_tokens = 2 if distilled else 1
159 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
160 | act_layer = act_layer or nn.GELU
161 | self.out_dim = embed_dim
162 |
163 | self.patch_embed = embed_layer(
164 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
165 | num_patches = self.patch_embed.num_patches
166 |
167 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
168 | self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
169 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
170 | self.pos_drop = nn.Dropout(p=drop_rate)
171 |
172 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
173 | self.blocks = nn.Sequential(*[
174 | Block(
175 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate,
176 | attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer)
177 | for i in range(depth)])
178 | self.norm = norm_layer(embed_dim)
179 |
180 | # Representation layer
181 | if representation_size and not distilled:
182 | self.num_features = representation_size
183 | self.pre_logits = nn.Sequential(OrderedDict([
184 | ('fc', nn.Linear(embed_dim, representation_size)),
185 | ('act', nn.Tanh())
186 | ]))
187 | else:
188 | self.pre_logits = nn.Identity()
189 |
190 | # Classifier head(s)
191 | self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
192 | self.head_dist = None
193 | if distilled:
194 | raise NotImplementedError('Distilled models are not supported in this implementation.')
195 | self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
196 |
197 | # Projection head for CL
198 | self.proj = utils.ProjHead(embed_dim, embed_dim)
199 |
200 | # Projection head for EsViT extra CL
201 | self.proj2 = utils.ProjHead(embed_dim, embed_dim)
202 |
203 | self.proj_v = utils.ProjHead(embed_dim//num_heads, embed_dim//num_heads)
204 | self.num_heads = num_heads
205 |
206 | @torch.jit.ignore
207 | def no_weight_decay(self):
208 | return {'pos_embed', 'cls_token', 'dist_token'}
209 |
210 | def get_classifier(self):
211 | if self.dist_token is None:
212 | return self.head
213 | else:
214 | return self.head, self.head_dist
215 |
216 | def reset_classifier(self, num_classes, global_pool=''):
217 | self.num_classes = num_classes
218 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
219 | if self.num_tokens == 2:
220 | self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
221 |
222 | def _get_pos_embed(self, pos_embed, patch_embed, H, W):
223 | if H * W == patch_embed.num_patches:
224 | return pos_embed
225 | else:
226 | return F.interpolate(
227 | pos_embed.reshape(1, patch_embed.H, patch_embed.W, -1).permute(0, 3, 1, 2),
228 | size=(H, W), mode="bilinear", align_corners=False).reshape(1, -1, H * W).permute(0, 2, 1)
229 |
230 | def forward_features(self, x, info_nce=False):
231 | if info_nce:
232 | attn_vs = []
233 | B = x.size(0)
234 | x, (H, W) = self.patch_embed(x)
235 | # skip extra tokens and resize pos_embed if needed
236 | tokens_pos_emb = self.pos_embed[:, :self.num_tokens]
237 | patches_pos_emb = self._get_pos_embed(self.pos_embed[:, self.num_tokens:], self.patch_embed, H, W)
238 | if info_nce in [1, 4]:
239 | x = torch.cat([self.cls_token.expand(B, -1, -1).detach(), x], dim=1)
240 | x = self.pos_drop(x + torch.cat([tokens_pos_emb, patches_pos_emb], dim=1))
241 | else:
242 | x = self.pos_drop(x + patches_pos_emb)
243 | for block in self.blocks:
244 | if info_nce:
245 | x, attn_v = block(x, info_nce)
246 | attn_vs.append(attn_v)
247 | else:
248 | x = block(x)
249 | x = self.norm(x)
250 | # output is (B, C, N)
251 | if info_nce:
252 | return x.permute(0, 2, 1).contiguous()[:, :, [0]], (self.proj2(x[:, 0]), self.proj(x[:, 1:]), x[:, 1:])
253 | else:
254 | return x.permute(0, 2, 1).contiguous()
255 |
256 | def forward(self, x, info_nce=False):
257 | if info_nce:
258 | x, attn_v = self.forward_features(x, info_nce)
259 | return x, attn_v
260 | else:
261 | x = self.forward_features(x)
262 | return x
263 |
264 | def get_last_selfattention(self, x):
265 | x, attn_v = self.forward_features(x, True)
266 | return attn_v[0]
267 |
268 |
269 | @register('vit_tiny_16')
270 | def vit_tiny_16(pretrained=False, **kwargs):
271 | """ ViT-Tiny (Vit-Ti/16)
272 | """
273 | model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
274 | model = VisionTransformer(**model_kwargs)
275 | return model
276 |
277 |
278 | @register('vit_small_16')
279 | def vit_small_16(pretrained=False, **kwargs):
280 | """ ViT-Small (ViT-S/16)
281 | NOTE I've replaced my previous 'small' model definition and weights with the small variant from the DeiT paper
282 | """
283 | model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
284 | model = VisionTransformer(**model_kwargs)
285 | return model
286 |
287 |
288 | @register('vit_base_16')
289 | def vit_base_16(pretrained=False, **kwargs):
290 | """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
291 | ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
292 | """
293 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
294 | model = VisionTransformer(**model_kwargs)
295 | return model
296 |
297 |
298 | @register('vit_large_16')
299 | def vit_large_16(pretrained=False, **kwargs):
300 | """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
301 | ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
302 | """
303 | model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
304 | model = VisionTransformer(**model_kwargs)
305 | return model
306 |
--------------------------------------------------------------------------------
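
VisionTransformer._get_pos_embed above bilinearly resizes the patch position embedding whenever the input token grid differs from the pretraining grid. Below is a standalone sketch of the same resizing (illustrative only; `resize_pos_embed` is a hypothetical helper, not part of the repository).

import torch
import torch.nn.functional as F

def resize_pos_embed(pos_embed: torch.Tensor, old_hw, new_hw) -> torch.Tensor:
    """pos_embed: (1, old_h * old_w, C) -> (1, new_h * new_w, C)."""
    old_h, old_w = old_hw
    new_h, new_w = new_hw
    if (old_h, old_w) == (new_h, new_w):
        return pos_embed
    grid = pos_embed.reshape(1, old_h, old_w, -1).permute(0, 3, 1, 2)   # (1, C, old_h, old_w)
    grid = F.interpolate(grid, size=(new_h, new_w), mode="bilinear", align_corners=False)
    return grid.reshape(1, -1, new_h * new_w).permute(0, 2, 1)          # (1, new_h * new_w, C)

pos = torch.randn(1, 14 * 14, 384)                         # e.g. a ViT-S/16 grid at 224x224
print(resize_pos_embed(pos, (14, 14), (20, 20)).shape)     # torch.Size([1, 400, 384])
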
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | pillow
3 | matplotlib
4 | tqdm
5 | pandas
6 | tensorboardX==2.4
7 | opencv-python==4.5.4.58
8 | protobuf==3.20.1
9 | scikit-learn
10 | git+https://github.com/facebookresearch/detectron2.git
--------------------------------------------------------------------------------
/scripts/train_gqa_image.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | CFG=$1
4 | OPTS=${@:2}
5 |
6 | python train_gqa.py --config-file $CFG $OPTS 2>&1
7 |
--------------------------------------------------------------------------------
/scripts/train_hico_image.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | CFG=$1
4 | OPTS=${@:2}
5 |
6 | python train_hico.py --config-file $CFG $OPTS 2>&1
7 |
--------------------------------------------------------------------------------
/train_gqa.py:
--------------------------------------------------------------------------------
1 | # ----------------------------------------------------------------------
2 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # This work is licensed under the NVIDIA Source Code License
5 | # for RelViT. To view a copy of this license, see the LICENSE file.
6 | # ----------------------------------------------------------------------
7 |
8 | import argparse
9 | import os
10 |
11 | import yaml
12 | import numpy as np
13 | import random
14 | import time
15 |
16 | import torch
17 | import torch.backends.cudnn as cudnn
18 | import torch.nn as nn
19 | import torch.nn.functional as F
20 | import torch.nn.parallel
21 | import torch.optim
22 | import torch.utils.data
23 | import torch.multiprocessing as mp
24 | import torch.distributed as dist
25 |
26 | from tensorboardX import SummaryWriter
27 |
28 | from tqdm import tqdm
29 |
30 | import datasets
31 | import models
32 | import utils
33 | from datasets.image_gqa import collate_gqa
34 |
35 |
36 | def compute_acc_gqa(final_y_true, final_y_pred):
37 | count = len(final_y_true)
38 | correct = (final_y_true == final_y_pred).sum()
39 | return (correct / count).item()
40 |
41 | def gather_score_label(dist, loader, current_idx, y_true, y_pred, ddp=False):
42 | if ddp:
43 | # all_gather and compute Acc
44 | current_idx_list = [torch.zeros(1).long().cuda(non_blocking=True) for _ in range(args.world_size)]
45 | y_true_list = [torch.zeros((len(loader.dataset), 1)).long().cuda(non_blocking=True) for _ in range(args.world_size)]
46 | y_pred_list = [torch.zeros((len(loader.dataset), 1)).long().cuda(non_blocking=True) for _ in range(args.world_size)]
47 | dist.all_gather(current_idx_list, current_idx)
48 | dist.all_gather(y_true_list, y_true)
49 | dist.all_gather(y_pred_list, y_pred)
50 |
51 | final_y_true = np.empty((0, 1))
52 | final_y_pred = np.empty((0, 1))
53 | for idx, yt, ys in zip(current_idx_list, y_true_list, y_pred_list):
54 | idx = idx.item()
55 | yt = yt.detach().cpu().numpy()
56 | ys = ys.detach().cpu().numpy()
57 | final_y_true = np.vstack((final_y_true, yt[:idx]))
58 | final_y_pred = np.vstack((final_y_pred, ys[:idx]))
59 | else:
60 | final_y_true = y_true.detach().cpu().numpy()[:current_idx.item()]
61 | final_y_pred = y_pred.detach().cpu().numpy()[:current_idx.item()]
62 | return final_y_true, final_y_pred
63 |
64 |
65 | def main(config):
66 | args.gpu = ''#[i for i in range(torch.cuda.device_count())]
67 | args.train_gpu = [i for i in range(torch.cuda.device_count())]
68 | num_gpus = torch.cuda.device_count()
69 | for i in range(num_gpus - 1):
70 | args.gpu += '{},'.format(i)
71 | args.gpu += '{}'.format(num_gpus - 1)
72 | if len(args.gpu.split(',')) > 1:
73 | config['_parallel'] = True
74 | config['_gpu'] = args.gpu
75 | utils.set_gpu(args.gpu)
76 | args.config = config
77 |
78 | if args.seed is not None:
79 | random.seed(args.seed)
80 | np.random.seed(args.seed)
81 | torch.manual_seed(args.seed)
82 | torch.cuda.manual_seed(args.seed)
83 | torch.cuda.manual_seed_all(args.seed)
84 | cudnn.benchmark = False
85 | cudnn.deterministic = True
86 | if args.dist_url == "env://" and args.world_size == -1:
87 | args.world_size = int(os.environ["WORLD_SIZE"])
88 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed
89 | args.ngpus_per_node = len(args.train_gpu)
90 | if len(args.train_gpu) == 1:
91 | args.sync_bn = False
92 | args.distributed = False
93 | args.multiprocessing_distributed = False
94 | if args.multiprocessing_distributed:
95 | args.sync_bn = True
96 | if args.dist_url[-2:] == '{}':
97 | port = utils.find_free_port()
98 | args.dist_url = args.dist_url.format(port)
99 | args.world_size = args.ngpus_per_node * args.world_size
100 | mp.spawn(main_worker, nprocs=args.ngpus_per_node, args=(args.ngpus_per_node, args))
101 | else:
102 | main_worker(args.train_gpu, args.ngpus_per_node, args)
103 |
104 |
105 | def main_worker(gpu, ngpus_per_node, argss):
106 | global args
107 | args = argss
108 | if args.distributed:
109 | if args.dist_url == "env://" and args.rank == -1:
110 | args.rank = int(os.environ["RANK"])
111 | if args.multiprocessing_distributed:
112 | args.rank = args.rank * ngpus_per_node + gpu
113 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank)
114 |
115 | config = args.config
116 | svname = args.svname
117 | if svname is None:
118 | config_name, _ = os.path.splitext(os.path.basename(args.config_file))
119 | svname = 'gqa'
120 | svname += '_' + config['model']
121 | if config['model_args'].get('encoder'):
122 | svname += '-' + config['model_args']['encoder']
123 | svname = os.path.join(config_name, config['train_dataset'], svname)
124 | if not args.test_only:
125 | svname += '-seed' + str(args.seed)
126 | if args.tag is not None:
127 | svname += '_' + args.tag
128 |
129 | sub_dir_name = 'default'
130 | if args.opts:
131 | sub_dir_name = args.opts[0]
132 | split = '#'
133 | for opt in args.opts[1:]:
134 | sub_dir_name += split + opt
135 | split = '#' if split == '_' else '_'
136 | svname = os.path.join(svname, sub_dir_name)
137 |
138 | if utils.is_main_process() and not args.test_only:
139 | save_path = os.path.join(args.save_dir, svname)
140 | utils.ensure_path(save_path, remove=False)
141 | utils.set_log_path(save_path)
142 | writer = SummaryWriter(os.path.join(save_path, 'tensorboard'))
143 | args.writer = writer
144 |
145 | yaml.dump(config, open(os.path.join(save_path, 'config.yaml'), 'w'))
146 |
147 | logger = utils.Logger(file_name=os.path.join(save_path, "log_sdout.txt"), file_mode="a+", should_flush=True)
148 | else:
149 | save_path = None
150 | writer = None
151 | args.writer = writer
152 | logger = None
153 |
154 | #### Dataset ####
155 |
156 | if config.get('ep_per_batch') is not None:
157 | ep_per_batch = config['ep_per_batch']
158 | else:
159 | ep_per_batch = 1
160 |
161 | # train
162 | dataset_configs = config['train_dataset_args']
163 | train_dataset = datasets.make(config['train_dataset'], eval_mode=config['eval_mode'], **dataset_configs)
164 | if utils.is_main_process():
165 | utils.log('train dataset: {} samples'.format(len(train_dataset)))
166 | if args.distributed:
167 | args.batch_size = int(ep_per_batch / args.world_size)
168 | args.batch_size_test = int(ep_per_batch / args.world_size)
169 | args.workers = int((args.workers + args.world_size - 1) / args.world_size)
170 | else:
171 | args.batch_size = ep_per_batch
172 | args.batch_size_test = ep_per_batch
173 | args.workers = args.workers
174 |
175 | if args.distributed:
176 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
177 | else:
178 | train_sampler = None
179 | train_loader = torch.utils.data.DataLoader(
180 | train_dataset,
181 | batch_size=args.batch_size,
182 | shuffle=(train_sampler is None),
183 | num_workers=args.workers,
184 | pin_memory=True,
185 | sampler=train_sampler,
186 | drop_last=True,
187 | collate_fn=collate_gqa
188 | )
189 |
190 | # testing
191 | dataset_configs = config['test_dataset_args']
192 | test_dataset = datasets.make(config['test_dataset'], eval_mode=config['eval_mode'], **dataset_configs)
193 | if utils.is_main_process():
194 | utils.log('test dataset: {} samples'.format(len(test_dataset)))
195 | if args.distributed:
196 | test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset)
197 | else:
198 | test_sampler = None
199 | test_loader = torch.utils.data.DataLoader(
200 | test_dataset,
201 | batch_size=args.batch_size_test,
202 | shuffle=False,
203 | num_workers=args.workers,
204 | pin_memory=True,
205 | sampler=test_sampler,
206 | collate_fn=collate_gqa
207 | )
208 |
209 | ########
210 |
211 | #### Model and optimizer ####
212 |
213 | if config.get('load'):
214 | print('loading pretrained model: ', config['load'])
215 | model = models.load(torch.load(config['load']))
216 | if config['relvit']:
217 | model_tea_k = models.load(torch.load(config['load']))
218 | else:
219 | model = models.make(config['model'], **config['model_args'])
220 | if config['relvit']:
221 | model_tea_k = models.make(config['model'], **config['model_args'])
222 |
223 | if config.get('load_encoder'):
224 | pretrain = config.get('encoder_pretrain').lower()
225 | if pretrain != 'scratch':
226 | pretrain_model_path = config['load_encoder'].format(pretrain)
227 | state_dict = torch.load(pretrain_model_path, map_location='cpu')
228 | missing_keys, unexpected_keys = model.encoder.encoder.load_state_dict(state_dict, strict=False)
229 | for key in missing_keys:
230 | assert key.startswith('g_mlp.') \
231 | or key.startswith('proj') \
232 | or key.startswith('trans') \
233 | or key.startswith('roi_processor') \
234 | or key.startswith('roi_dim_processor') \
235 | or key.startswith('classifier'), key
236 | for key in unexpected_keys:
237 | assert key.startswith('fc.')
238 | if utils.is_main_process():
239 | utils.log('==> Successfully loaded {} for the encoder.'.format(pretrain_model_path))
240 | if config['relvit']:
241 | model_tea_k.load_state_dict(model.state_dict())
242 |
243 | if args.sync_bn:
244 | model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
245 | if config['relvit']:
246 | model_tea_k = nn.SyncBatchNorm.convert_sync_batchnorm(model_tea_k)
247 |
248 | if utils.is_main_process():
249 | utils.log(model)
250 |
251 | if args.distributed:
252 | torch.cuda.set_device(gpu)
253 | model = torch.nn.parallel.DistributedDataParallel(model.cuda(), device_ids=[gpu], find_unused_parameters=True, broadcast_buffers=False)
254 | if config['relvit']:
255 | model_tea_k = torch.nn.parallel.DistributedDataParallel(model_tea_k.cuda(), device_ids=[gpu], find_unused_parameters=True, broadcast_buffers=False)
256 | else:
257 | model = torch.nn.DataParallel(model.cuda())
258 | if config['relvit']:
259 | model_tea_k = torch.nn.DataParallel(model_tea_k.cuda())
260 |
261 | if utils.is_main_process() and not args.test_only:
262 | utils.log('num params: {}'.format(utils.compute_n_params(model)))
263 | utils.log('Results will be saved to {}'.format(save_path))
264 |
265 | max_steps = min(len(train_loader), config['train_batches']) * config['max_epoch']
266 | optimizer, lr_scheduler, update_lr_every_epoch = utils.make_optimizer(
267 | model.parameters(),
268 | config['optimizer'], max_steps, **config['optimizer_args']
269 | )
270 | assert lr_scheduler is not None
271 | args.update_lr_every_epoch = update_lr_every_epoch
272 |
273 | if args.test_only:
274 | filename = args.test_model
275 | assert os.path.exists(filename)
276 | ckpt = torch.load(filename, map_location='cpu')
277 | start_epoch = ckpt['epoch']
278 | model.load_state_dict(ckpt['state_dict'])
279 | if config['relvit']:
280 | model_tea_k.load_state_dict(ckpt['tea_k_state_dict'])
281 | optimizer.load_state_dict(ckpt['optimizer'])
282 | lr_scheduler.load_state_dict(ckpt['lr_scheduler'])
283 | best_test_result = ckpt['best_test_result']
284 | if utils.is_main_process():
285 | utils.log('==> Successfully resumed from a checkpoint {}'.format(filename))
286 | else:
287 | ckpt = None
288 | start_epoch = 0
289 | best_test_result = 0.0
290 |
291 | ######## MoCo
292 | if config['relvit']:
293 | cl_loss = 'esvit'
294 | feat_dim = model.module.encoder.encoder.out_dim
295 | moco = utils.relvit.MoCo(
296 | model_tea_k.module,
297 | model.module,
298 | config['relvit_moco_K'],
299 | config['relvit_moco_m'],
300 | feat_dim,
301 | num_concepts=config['relvit_num_concepts'],
302 | relvit_mode=config['relvit_mode'],
303 | num_tokens=config['relvit_num_tokens'])
304 | moco = moco.cuda()
305 | if ckpt is not None:
306 | try:
307 | moco.load_state_dict(ckpt['moco'])
308 | utils.log('==> MoCo successfully resumed from a checkpoint.')
309 | except Exception:
310 | utils.log('==> MoCo is not resumed.')
311 | else:
312 | moco = None
313 | model_tea_k = None
314 | ########
315 |
316 | ######## Training & Validation
317 |
318 | max_epoch = config['max_epoch']
319 | save_epoch = config.get('save_epoch')
320 | max_va = best_test_result
321 | timer_used = utils.Timer()
322 | timer_epoch = utils.Timer()
323 |
324 | if args.test_only:
325 | ret = test(test_loader, model_tea_k if model_tea_k is not None else model, 0, args)
326 | if ret is None:
327 | return 0
328 | else:
329 | loss_test, acc_test = ret
330 | if utils.is_main_process():
331 | print('Testing result: ', acc_test)
332 | return 0
333 |
334 | for epoch in range(start_epoch, max_epoch):
335 | # timer_epoch.s()
336 | # aves = {k: utils.Averager() for k in aves_keys}
337 |
338 | # Snapshot params (cloned, since state_dict() returns references to the live tensors) so the epoch can be redone if the loss goes NaN
339 | sd_model = {k: v.clone() for k, v in model.state_dict().items()}
340 | if config['relvit']:
341 | sd_model_tea_k = {k: v.clone() for k, v in model_tea_k.state_dict().items()}
342 | sd_moco = {k: v.clone() for k, v in moco.state_dict().items()}
343 |
344 | epoch_log = epoch + 1
345 | finish = False
346 | while finish is False:
347 | model.load_state_dict(sd_model)
348 | if config['relvit']:
349 | model_tea_k.load_state_dict(sd_model_tea_k)
350 | moco.load_state_dict(sd_moco)
351 | if args.distributed:
352 | train_sampler.set_epoch(epoch)
353 | ret = train(train_loader, model, optimizer, lr_scheduler, epoch_log, (moco, model_tea_k), writer, args)
354 | # autoresume
355 | if ret is None:
356 | return 0
357 | # NaN
358 | elif ret == -1:
359 | utils.synchronize()
360 | if utils.is_main_process():
361 | utils.log('NaN detected, redo the current epoch...')
362 | continue
363 | else:
364 | finish = True
365 | (loss_train, aux_loss_train), acc_train = ret
366 | import gc; gc.collect()
367 | torch.cuda.empty_cache()
368 | if args.update_lr_every_epoch:
369 | lr_scheduler.step()
370 | if utils.is_main_process():
371 | writer.add_scalar('loss_train', loss_train, epoch_log)
372 | writer.add_scalar('aux_loss_train', aux_loss_train, epoch_log)
373 | writer.add_scalar('Acc_train', acc_train, epoch_log)
374 | for name, param in model.named_parameters():
375 | writer.add_histogram(name, param, epoch_log)
376 |
377 | if (epoch_log % config['save_epoch'] == 0 or epoch_log == config['max_epoch']) and utils.is_main_process():
378 | filename = os.path.join(save_path, 'train.pth')
379 | utils.log('==> Saving checkpoint to: ' + filename)
380 | ckpt = {
381 | 'epoch': epoch_log,
382 | 'state_dict': model.state_dict(),
383 | 'tea_k_state_dict': model_tea_k.state_dict() if config['relvit'] else model.state_dict(),
384 | 'moco': moco.state_dict() if moco else None,
385 | 'optimizer': optimizer.state_dict(),
386 | 'lr_scheduler': lr_scheduler.state_dict(),
387 | 'best_test_result': best_test_result,
388 | }
389 | torch.save(ckpt, filename)
390 | acc_test = 0
391 | if epoch_log % config['eval_epoch'] == 0:
392 | ret = test(test_loader, model_tea_k if model_tea_k is not None else model, epoch_log, args)
393 | if ret is None:
394 | return 0
395 | else:
396 | loss_test, acc_test = ret
397 | import gc; gc.collect()
398 | torch.cuda.empty_cache()
399 | if acc_test > best_test_result:
400 | best_test_result = acc_test
401 | if utils.is_main_process():
402 | utils.log('test result: loss {:.4f}, Acc: {:.4f}.'.format(loss_test, acc_test))
403 | writer.add_scalar('loss_test', loss_test, epoch_log)
404 | writer.add_scalar('Acc_test', acc_test, epoch_log)
405 |
406 | if utils.is_main_process():
407 | utils.log('Best test results so far:')
408 | utils.log(best_test_result)
409 |
410 | if acc_test > max_va and utils.is_main_process():
411 | max_va = acc_test
412 | filename = os.path.join(save_path, 'best_model.pth')
413 | ckpt = {
414 | 'epoch': epoch_log,
415 | 'state_dict': model.state_dict(),
416 | 'tea_k_state_dict': model_tea_k.state_dict() if config['relvit'] else model.state_dict(),
417 | 'optimizer': optimizer.state_dict(),
418 | 'lr_scheduler': lr_scheduler.state_dict(),
419 | 'best_test_result': best_test_result,
420 | }
421 | torch.save(ckpt, filename)
422 | if utils.is_main_process():
423 | writer.flush()
424 |
425 | if utils.is_main_process():
426 | logger.close()
427 |
428 | def train(train_loader, model, optimizer, lr_scheduler, epoch, moco_tuple, writer, args):
429 | batch_time = utils.AverageMeter()
430 | data_time = utils.AverageMeter()
431 | main_loss_meter = utils.AverageMeter()
432 | aux_loss_meter = utils.AverageMeter()
433 | loss_meter = utils.AverageMeter()
434 | intersection_meter = utils.AverageMeter()
435 | union_meter = utils.AverageMeter()
436 | target_meter = utils.AverageMeter()
437 | # Acc
438 | current_idx = torch.zeros(1).long()
439 | y_true = torch.zeros((len(train_loader.dataset), 1)).long()
440 | y_pred = torch.zeros((len(train_loader.dataset), 1)).long()
441 |
442 | config = args.config
443 |
444 | # train
445 | model.train()
446 |
447 | if utils.is_main_process():
448 | args.writer.add_scalar('lr', optimizer.param_groups[0]['lr'], epoch)
449 |
450 | end = time.time()
451 | max_iter = config['max_epoch'] * len(train_loader)
452 | for batch_idx, batch in enumerate(train_loader):
453 | if batch_idx >= config['train_batches']:
454 | break
455 |
456 | feed_list = []
457 | ims = batch[0].cuda(non_blocking=True)
458 | second_ims = batch[1].cuda(non_blocking=True)
459 | third_ims = batch[2].cuda(non_blocking=True)
460 | feed_list.append(ims)
461 |
462 | # For Two-view denseCL
463 | cl_ims = torch.cat([second_ims, third_ims], dim=0)
464 | ques_ix_iter = batch[3].cuda(non_blocking=True)
465 | feed_list.append(ques_ix_iter)
466 | ans_iter = batch[5].cuda(non_blocking=True)
467 | concepts = batch[6].cuda(non_blocking=True)
468 |
469 | # with torch.cuda.amp.autocast(enabled=args.amp), torch.autograd.detect_anomaly():
470 | with torch.cuda.amp.autocast(enabled=args.amp):
471 | if config['relvit']:
472 | logits = model(*feed_list)
473 | B = cl_ims.size(0) // 2
474 | with torch.no_grad():
475 | d1, attn_v_k = moco_tuple[1](cl_ims, ques_ix_iter.repeat(2, 1), True)
476 | d2, attn_v_q = model(cl_ims, ques_ix_iter.repeat(2, 1), True)
477 | aux_loss = moco_tuple[0](attn_v_k, attn_v_q, B, config, args, concepts) + d1.mean()*0 + d2.mean()*0
478 | else:
479 | logits = model(*feed_list)
480 | aux_loss = torch.zeros(1).to(logits)
481 |
482 | loss = F.cross_entropy(logits, ans_iter.squeeze(-1), reduction='sum') + config['relvit_weight'] * aux_loss
483 | y_pred[current_idx.item():current_idx.item()+logits.size(0)] = logits.detach().argmax(-1).unsqueeze(-1).cpu()
484 | y_true[current_idx.item():current_idx.item()+logits.size(0)] = ans_iter.detach().cpu()
485 | current_idx += logits.size(0)
486 | if torch.isnan(loss).all().item():
487 | utils.synchronize()
488 | return -1
489 | optimizer.zero_grad()
490 | loss.backward()
491 | torch.nn.utils.clip_grad_norm_(model.parameters(), config['grad_norm'])
492 | optimizer.step()
493 | lrs = lr_scheduler.get_last_lr()
494 | if not args.update_lr_every_epoch:
495 | lr_scheduler.step()
496 |
497 | n = logits.size(0)
498 | if args.multiprocessing_distributed:
499 | # TODO: all gather y_score/true
500 | loss = loss * n  # weight by the local batch size before all_reduce
501 | count = ans_iter.new_tensor([n], dtype=torch.long)
502 | dist.all_reduce(loss)
503 | dist.all_reduce(count)
504 | n = count.item()
505 | loss = loss / n
506 |
507 | loss_meter.update(loss.item(), logits.size(0))
508 | aux_loss_meter.update(aux_loss.item(), logits.size(0))
509 | batch_time.update(time.time() - end)
510 | end = time.time()
511 |
512 | current_iter = epoch * len(train_loader) + batch_idx + 1
513 | remain_iter = max_iter - current_iter
514 | remain_time = remain_iter * batch_time.avg
515 | t_m, t_s = divmod(remain_time, 60)
516 | t_h, t_m = divmod(t_m, 60)
517 | remain_time = '{:02d}:{:02d}:{:02d}'.format(int(t_h), int(t_m), int(t_s))
518 |
519 | if (batch_idx + 1) % config['print_freq'] == 0 and utils.is_main_process():
520 | utils.log(
521 | '{} Epoch: [{}/{}][{}/{}] '
522 | 'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
523 | 'Batch {batch_time.val:.3f} ({batch_time.avg:.3f}) '
524 | 'Remain {remain_time} '
525 | 'Loss {loss_meter.val:.4f} '
526 | 'Aux Loss {aux_loss_meter.val:.4f} '
527 | 'lr {lr:.6f}'.format(
528 | 'Train',
529 | epoch, config['max_epoch'], batch_idx + 1, len(train_loader),
530 | batch_time=batch_time,
531 | data_time=data_time,
532 | remain_time=remain_time,
533 | loss_meter=loss_meter,
534 | aux_loss_meter=aux_loss_meter,
535 | lr=lrs[0]
536 | )
537 | )
538 | final_y_true, final_y_pred = gather_score_label(dist, train_loader, current_idx.cuda(), y_true.cuda(), y_pred.cuda(), args.multiprocessing_distributed)
539 | acc = compute_acc_gqa(final_y_true, final_y_pred)
540 | if utils.is_main_process():
541 | utils.log('{} result at epoch [{}/{}]: loss {:.4f}, Acc {:.4f}.'.format('Train', epoch, config['max_epoch'], loss_meter.avg, acc))
542 | return (loss_meter.avg, aux_loss_meter.avg), acc
543 |
544 | def test(test_loader, model, epoch_log, args):
545 | # eval
546 | model.eval()
547 |
548 | config = args.config
549 | loss_meter = utils.AverageMeter()
550 | # Acc
551 | current_idx = torch.zeros(1).long()
552 | y_true = torch.zeros((len(test_loader.dataset), 1)).long()
553 | y_pred = torch.zeros((len(test_loader.dataset), 1)).long()
554 |
555 | config = args.config
556 |
557 | np.random.seed(0)
558 | for batch in tqdm(test_loader):
559 | feed_list = []
560 | ims = batch[0].cuda(non_blocking=True)
561 | second_ims = batch[1].cuda(non_blocking=True)
562 | third_ims = batch[2].cuda(non_blocking=True)
563 | feed_list.append(ims)
564 | # For Two-view denseCL
565 | cl_ims = torch.cat([second_ims, third_ims], dim=0)
566 | ques_ix_iter = batch[3].cuda(non_blocking=True)
567 | feed_list.append(ques_ix_iter)
568 | ans_iter = batch[5].cuda(non_blocking=True)
569 |
570 | with torch.no_grad():
571 | with torch.cuda.amp.autocast(enabled=args.amp):
572 | logits = model(*feed_list)
573 | loss = F.cross_entropy(logits, ans_iter.squeeze(-1), reduction='sum')
574 | y_pred[current_idx.item():current_idx.item()+logits.size(0)] = logits.detach().argmax(-1).unsqueeze(-1).cpu()
575 | y_true[current_idx.item():current_idx.item()+logits.size(0)] = ans_iter.detach().cpu()
576 | current_idx += logits.size(0)
577 |
578 | n = logits.size(0)
579 | if args.multiprocessing_distributed:
580 | loss = loss * n  # weight by the local batch size before all_reduce
581 | count = logits.new_tensor([n], dtype=torch.long)
582 | dist.all_reduce(loss)
583 | dist.all_reduce(count)
584 | n = count.item()
585 | loss = loss / n
586 | else:
587 | loss = torch.mean(loss)
588 |
589 | loss_meter.update(loss.item(), logits.size(0))
590 | final_y_true, final_y_pred = gather_score_label(dist, test_loader, current_idx.cuda(), y_true.cuda(), y_pred.cuda(), args.multiprocessing_distributed)
591 | acc = compute_acc_gqa(final_y_true, final_y_pred)
592 | return loss_meter.avg, acc
593 |
594 | if __name__ == '__main__':
595 | parser = argparse.ArgumentParser()
596 | parser.add_argument('--config-file')
597 | parser.add_argument('--svname', default=None)
598 | parser.add_argument('--save_dir', default='./save_dist')
599 | parser.add_argument('--tag', default=None)
600 | # parser.add_argument('--gpu', default='0')
601 | parser.add_argument('--seed', type=int, default=123)
602 | parser.add_argument('--workers', type=int, default=8)
603 | parser.add_argument('--test_only', action='store_true')
604 | parser.add_argument('--test_model', default=None)
605 |
606 | # distributed training
607 | parser.add_argument('--world-size', default=1, type=int)
608 | parser.add_argument('--rank', default=0, type=int)
609 | parser.add_argument('--dist-backend', default='nccl')
610 | parser.add_argument('--amp', action='store_true')
611 | parser.add_argument(
612 | "--dist-url",
613 | default="tcp://127.0.0.1:{}",
614 | help="initialization URL for pytorch distributed backend. See "
615 | "https://pytorch.org/docs/stable/distributed.html for details.",
616 | )
617 | parser.add_argument(
618 | "opts",
619 | help="Modify config options using the command-line",
620 | default=None,
621 | nargs=argparse.REMAINDER,
622 | )
623 | args = parser.parse_args()
624 | args.multiprocessing_distributed = True
625 |
626 | config = yaml.load(open(args.config_file, 'r'), Loader=yaml.FullLoader)
627 | if args.opts is not None:
628 | config = utils.override_cfg_from_list(config, args.opts)
629 | print('config:')
630 | print(config)
631 | main(config)
632 |
--------------------------------------------------------------------------------
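
Both train and test in train_gqa.py above preallocate y_true/y_pred buffers of dataset length, fill them batch by batch, and compute accuracy over the filled prefix (via gather_score_label and compute_acc_gqa). Below is a single-process sketch of that bookkeeping (illustrative only; `epoch_accuracy` is a hypothetical helper, not part of the repository).

import torch

def epoch_accuracy(logits_batches, labels_batches, dataset_len):
    current = 0
    y_true = torch.zeros((dataset_len, 1), dtype=torch.long)
    y_pred = torch.zeros((dataset_len, 1), dtype=torch.long)
    for logits, labels in zip(logits_batches, labels_batches):
        n = logits.size(0)
        y_pred[current:current + n] = logits.argmax(-1, keepdim=True)
        y_true[current:current + n] = labels.view(-1, 1)
        current += n
    # Only the filled prefix counts; the buffers were sized for the full dataset.
    y_true, y_pred = y_true[:current], y_pred[:current]
    return (y_true == y_pred).float().mean().item()

logits = [torch.randn(4, 10) for _ in range(2)]
labels = [torch.randint(0, 10, (4, 1)) for _ in range(2)]
print(epoch_accuracy(logits, labels, dataset_len=16))
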
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------
2 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # This file has been modified from a file in the Bongard-HOI library
5 | # which was released under the NVIDIA Source Code Licence.
6 | #
7 | # Source:
8 | # https://github.com/NVlabs/Bongard-HOI
9 | #
10 | # The license for the original version of this file can be
11 | # found in https://github.com/NVlabs/Bongard-HOI/blob/master/LICENSE
12 | # The modifications to this file are subject to the same NVIDIA Source Code Licence.
13 | # ---------------------------------------------------------------
14 |
15 | import os
16 | import shutil
17 | import time
18 |
19 | import numpy as np
20 | import torch
21 | import torch.nn as nn
22 | import torch.nn.functional as F
23 | from torch.optim import SGD, Adam, AdamW
24 | from torch.optim.lr_scheduler import MultiStepLR
25 |
26 | from typing import Any
27 | import sys
28 | from . import relvit
29 |
30 | import glob
31 | import PIL.Image
32 | from torchvision import transforms
33 | import torch.distributed as dist
34 | import torch.nn as nn
35 | import functools
36 | from copy import deepcopy
37 |
38 | _log_path = None
39 |
40 | _LOCAL_PROCESS_GROUP = None
41 | """
42 | A torch process group which only includes processes that on the same machine as the current process.
43 | This variable is set when processes are spawned by `launch()` in "engine/launch.py".
44 | """
45 |
46 | import multiprocessing
47 | from torch import Tensor
48 | from typing import Optional, Iterable, Any, List, Union, Tuple
49 |
50 |
51 | def div(numerator: Tensor, denom: Union[Tensor, int, float]) -> Tensor:
52 | """Handle division by zero"""
53 | if type(denom) in [int, float]:
54 | if denom == 0:
55 | return torch.zeros_like(numerator)
56 | else:
57 | return numerator / denom
58 | elif type(denom) is Tensor:
59 | zero_idx = torch.nonzero(denom == 0).squeeze(1)
60 | denom[zero_idx] += 1e-8
61 | return numerator / denom
62 | else:
63 | raise TypeError("Unsupported data type ", type(denom))
64 |
65 |
66 | def hico_logit_conversion_to_hoi(logits_verb, logits_object, corr):
67 | # logits_verb: (B, 117)
68 | # logits_object: (B, 80)
69 |
70 | # FIXME: how to compute the joint prob properly
71 | with torch.no_grad():
72 | B = logits_object.size(0)
73 | logits_hoi = torch.zeros(B, 600).to(logits_verb)
74 | for ind, c in enumerate(corr):
75 | obj, verb = c[1], c[2]
76 | logits_hoi[:, ind] = torch.sigmoid(logits_verb[:, verb]) * torch.sigmoid(logits_object[:, obj])
77 | # logits_hoi[:, ind] = logits_verb[:, verb] + logits_object[:, obj]
78 | return logits_hoi
79 |
80 |
81 | class AveragePrecisionMeter:
82 | """
83 | Meter to compute average precision
84 | Arguments:
85 | num_gt(iterable): Number of ground truth instances for each class. When left
86 | as None, all positives are assumed to have been included in the collected
87 | results. As a result, full recall is guaranteed when the lowest scoring
88 | example is accounted for.
89 | algorithm(str, optional): AP evaluation algorithm
90 | '11P': 11-point interpolation algorithm prior to voc2010
91 | 'INT': Interpolation algorithm with all points used in voc2010
92 | 'AUC': Precisely as the area under precision-recall curve
93 | chunksize(int, optional): The approximate size the given iterable will be split
94 | into for each worker. Use -1 to make the argument adaptive to iterable size
95 | and number of workers
96 | precision(int, optional): Precision used for float-point operations. Choose
97 | amongst 64, 32 and 16. Default is 64
98 | output(tensor[N, K], optional): Network outputs with N examples and K classes
99 | labels(tensor[N, K], optional): Binary labels
100 | Usage:
101 |
102 | (1) Evaluate AP using provided output scores and labels
103 | >>> # Given output(tensor[N, K]) and labels(tensor[N, K])
104 | >>> meter = utils.AveragePrecisionMeter(output=output, labels=labels)
105 | >>> ap = meter.eval(); map_ = ap.mean()
106 | (2) Collect results on the fly and evaluate AP
107 | >>> meter = utils.AveragePrecisionMeter()
108 | >>> # Compute output(tensor[N, K]) during forward pass
109 | >>> meter.append(output, labels)
110 | >>> ap = meter.eval(); map_ = ap.mean()
111 | >>> # If you are to start new evaluation and want to reset the meter
112 | >>> meter.reset()
113 | """
114 | def __init__(self, num_gt: Optional[Iterable] = None,
115 | algorithm: str = 'AUC', chunksize: int = -1,
116 | precision: int = 64,
117 | output: Optional[Tensor] = None,
118 | labels: Optional[Tensor] = None) -> None:
119 | self._dtype = getattr(torch, 'float' + str(precision))
120 | self.num_gt = torch.as_tensor(num_gt, dtype=self._dtype) \
121 | if num_gt is not None else None
122 | self.algorithm = algorithm
123 | self._chunksize = chunksize
124 |
125 | is_none = (output is None, labels is None)
126 | if is_none == (True, True):
127 | self._output = torch.tensor([], dtype=self._dtype)
128 | self._labels = torch.tensor([], dtype=self._dtype)
129 | elif is_none == (False, False):
130 | self._output = output.detach().cpu().to(self._dtype)
131 | self._labels = labels.detach().cpu().to(self._dtype)
132 | else:
133 | raise AssertionError("Output and labels should both be given or None")
134 |
135 | self._output_temp = [torch.tensor([], dtype=self._dtype)]
136 | self._labels_temp = [torch.tensor([], dtype=self._dtype)]
137 |
138 | @staticmethod
139 | def compute_per_class_ap_as_auc(tuple_: Tuple[Tensor, Tensor]) -> Tensor:
140 | """
141 | Arguments:
142 | tuple_(Tuple[Tensor, Tensor]): precision and recall
143 | Returns:
144 | ap(Tensor[1])
145 | """
146 | prec, rec = tuple_
147 | ap = 0
148 | max_rec = rec[-1]
149 | for idx in range(prec.numel()):
150 | # Stop when maximum recall is reached
151 | if rec[idx] >= max_rec:
152 | break
153 | d_x = rec[idx] - rec[idx - 1]
154 | # Skip when negative example is registered
155 | if d_x == 0:
156 | continue
157 | ap += prec[idx] * rec[idx] if idx == 0 \
158 | else 0.5 * (prec[idx] + prec[idx - 1]) * d_x
159 | return ap
160 |
161 | @staticmethod
162 | def compute_per_class_ap_with_interpolation(tuple_: Tuple[Tensor, Tensor]) -> Tensor:
163 | """
164 | Arguments:
165 | tuple_(Tuple[Tensor, Tensor]): precision and recall
166 | Returns:
167 | ap(Tensor[1])
168 | """
169 | prec, rec = tuple_
170 | ap = 0
171 | max_rec = rec[-1]
172 | for idx in range(prec.numel()):
173 | # Stop when maximum recall is reached
174 | if rec[idx] >= max_rec:
175 | break
176 | d_x = rec[idx] - rec[idx - 1]
177 | # Skip when negative example is registered
178 | if d_x == 0:
179 | continue
180 | # Compute interpolated precision
181 | max_ = prec[idx:].max()
182 | ap += max_ * rec[idx] if idx == 0 \
183 | else 0.5 * (max_ + torch.max(prec[idx - 1], max_)) * d_x
184 | return ap
185 |
186 | @staticmethod
187 | def compute_per_class_ap_with_11_point_interpolation(tuple_: Tuple[Tensor, Tensor]) -> Tensor:
188 | """
189 | Arguments:
190 | tuple_(Tuple[Tensor, Tensor]): precision and recall
191 | Returns:
192 | ap(Tensor[1])
193 | """
194 | prec, rec = tuple_
195 | dtype = rec.dtype
196 | ap = 0
197 | for t in torch.linspace(0, 1, 11, dtype=dtype):
198 | inds = torch.nonzero(rec >= t).squeeze()
199 | if inds.numel():
200 | ap += (prec[inds].max() / 11)
201 | return ap
202 |
203 | @classmethod
204 | def compute_ap(cls, output: Tensor, labels: Tensor,
205 | num_gt: Optional[Tensor] = None,
206 | algorithm: str = 'AUC',
207 | chunksize: int = -1) -> Tensor:
208 | """
209 | Compute average precision under the classification setting. Scores of all
210 | classes are retained for each sample.
211 | Arguments:
212 | output(Tensor[N, K])
213 | labels(Tensor[N, K])
214 | num_gt(Tensor[K]): Number of ground truth instances for each class
215 | algorithm(str): AP evaluation algorithm
216 | chunksize(int, optional): The approximate size the given iterable will be split
217 | into for each worker. Use -1 to make the argument adaptive to iterable size
218 | and number of workers
219 | Returns:
220 | ap(Tensor[K])
221 | """
222 | prec, rec = cls.compute_precision_and_recall(output, labels,
223 | num_gt=num_gt)
224 | ap = torch.zeros(output.shape[1], dtype=prec.dtype)
225 | # Use the logic from pool._map_async to compute chunksize
226 | # https://github.com/python/cpython/blob/master/Lib/multiprocessing/pool.py
227 | # NOTE: Inappropriate chunksize will cause [Errno 24]Too many open files
228 | # Make changes with caution
229 | if chunksize == -1:
230 | chunksize, extra = divmod(
231 | output.shape[1],
232 | multiprocessing.cpu_count() * 4)
233 | if extra:
234 | chunksize += 1
235 |
236 | if algorithm == 'INT':
237 | algorithm_handle = cls.compute_per_class_ap_with_interpolation
238 | elif algorithm == '11P':
239 | algorithm_handle = cls.compute_per_class_ap_with_11_point_interpolation
240 | elif algorithm == 'AUC':
241 | algorithm_handle = cls.compute_per_class_ap_as_auc
242 | else:
243 | raise ValueError("Unknown algorithm option {}.".format(algorithm))
244 |
245 | # with multiprocessing.get_context('spawn').Pool() as pool:
246 | # for idx, result in enumerate(pool.imap(
247 | # func=algorithm_handle,
248 | # # NOTE: Use transpose instead of T for compatibility
249 | # iterable=zip(prec.transpose(0,1), rec.transpose(0,1)),
250 | # chunksize=chunksize
251 | # )):
252 | # ap[idx] = algorithm_handle(prec[idx], rec[idx])
253 | prec = prec.transpose(0,1)
254 | rec = rec.transpose(0,1)
255 | for idx in range(len(prec)):
256 | ap[idx] = algorithm_handle((prec[idx], rec[idx]))
257 |
258 | return ap
259 |
260 | @staticmethod
261 | def compute_precision_and_recall(output: Tensor, labels: Tensor,
262 | num_gt: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
263 | """
264 | Arguments:
265 | output(Tensor[N, K])
266 | labels(Tensor[N, K])
267 | num_gt(Tensor[K])
268 | Returns:
269 | prec(Tensor[N, K])
270 | rec(Tensor[N, K])
271 | """
272 | order = output.argsort(0, descending=True)
273 | tp = labels[
274 | order,
275 | torch.ones_like(order) * torch.arange(output.shape[1])
276 | ]
277 | fp = 1 - tp
278 | tp = tp.cumsum(0)
279 | fp = fp.cumsum(0)
280 |
281 | prec = tp / (tp + fp)
282 | rec = div(tp, labels.sum(0)) if num_gt is None \
283 | else div(tp, num_gt)
284 |
285 | return prec, rec
286 |
287 | def append(self, output: Tensor, labels: Tensor) -> None:
288 | """
289 | Add new results to the meter
290 | Arguments:
291 | output(tensor[N, K]): Network output with N examples and K classes
292 | labels(tensor[N, K]): Binary labels
293 | """
294 | if isinstance(output, torch.Tensor) and isinstance(labels, torch.Tensor):
295 | assert output.shape == labels.shape, \
296 |                 "Output scores do not match the dimension of labels"
297 | self._output_temp.append(output.detach().cpu().to(self._dtype))
298 | self._labels_temp.append(labels.detach().cpu().to(self._dtype))
299 | else:
300 | raise TypeError("Arguments should both be torch.Tensor")
301 |
302 | def reset(self, keep_old: bool = False) -> None:
303 | """
304 | Clear saved statistics
305 | Arguments:
306 |             keep_old(bool): If True, clear only the newly collected statistics
307 | since last evaluation
308 | """
309 | if not keep_old:
310 | self._output = torch.tensor([], dtype=self._dtype)
311 | self._labels = torch.tensor([], dtype=self._dtype)
312 | self._output_temp = [torch.tensor([], dtype=self._dtype)]
313 | self._labels_temp = [torch.tensor([], dtype=self._dtype)]
314 |
315 | def eval(self) -> Tensor:
316 | """
317 | Evaluate the average precision based on collected statistics
318 | Returns:
319 | torch.Tensor[K]: Average precisions for K classes
320 | """
321 | self._output = torch.cat([
322 | self._output,
323 | torch.cat(self._output_temp, 0)
324 | ], 0)
325 | self._labels = torch.cat([
326 | self._labels,
327 | torch.cat(self._labels_temp, 0)
328 | ], 0)
329 | self.reset(keep_old=True)
330 |
331 | # Sanity check
332 | if self.num_gt is not None:
333 | self.num_gt = self.num_gt.to(dtype=self._labels.dtype)
334 | faulty_cls = torch.nonzero(self._labels.sum(0) > self.num_gt).squeeze(1)
335 | if len(faulty_cls):
336 | raise AssertionError("Class {}: ".format(faulty_cls.tolist())+
337 | "Number of true positives larger than that of ground truth")
338 | if len(self._output) and len(self._labels):
339 | return self.compute_ap(self._output, self._labels, num_gt=self.num_gt,
340 | algorithm=self.algorithm, chunksize=self._chunksize)
341 | else:
342 | print("WARNING: Collected results are empty. "
343 |                 "Returning zero AP for all classes.")
344 | return torch.zeros(self._output.shape[1], dtype=self._dtype)
345 |
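# Illustrative sketch only (added for exposition; not called by any training
# script). It exercises the AP utilities above on a hand-made 4x3 score/label
# pair. The enclosing meter class is assumed to be `AveragePrecisionMeter` with
# an `algorithm` keyword in its constructor; adjust the name if it differs.
def _example_average_precision():
    output = torch.rand(4, 3)                      # N=4 samples, K=3 classes
    labels = torch.tensor([[1., 0., 1.],
                           [0., 1., 1.],
                           [1., 1., 0.],
                           [0., 0., 1.]])          # every class has >= 1 positive
    meter = AveragePrecisionMeter(algorithm='AUC')
    meter.append(output, labels)
    ap_auc = meter.eval()                          # Tensor[K]
    # The static helpers can also be applied to a single precision/recall pair.
    prec, rec = AveragePrecisionMeter.compute_precision_and_recall(output, labels)
    ap_11p = AveragePrecisionMeter.compute_per_class_ap_with_11_point_interpolation(
        (prec[:, 0], rec[:, 0]))
    return ap_auc, ap_11p
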
346 | class AverageMeter(object):
347 | """Computes and stores the average and current value"""
348 | def __init__(self):
349 | self.reset()
350 |
351 | def reset(self):
352 | self.val = 0
353 | self.avg = 0
354 | self.sum = 0
355 | self.count = 0
356 |
357 | def update(self, val, n=1):
358 | self.val = val
359 | self.sum += val * n
360 | self.count += n
361 | self.avg = self.sum / self.count
362 |
363 |
364 | def set_log_path(path):
365 | global _log_path
366 | _log_path = path
367 |
368 |
369 | def log(obj, filename='log.txt'):
370 | print(obj)
371 | if _log_path is not None:
372 | with open(os.path.join(_log_path, filename), 'a') as f:
373 | print(obj, file=f)
374 |
375 |
376 | class Averager():
377 |
378 | def __init__(self):
379 | self.n = 0.0
380 | self.v = 0.0
381 |
382 | def add(self, v, n=1.0):
383 | self.v = (self.v * self.n + v * n) / (self.n + n)
384 | self.n += n
385 |
386 | def item(self):
387 | return self.v
388 |
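# Worked example for the weighted running mean kept by Averager (illustrative
# only): add(2.0) gives 2.0, then add(4.0, n=3) gives (2.0*1 + 4.0*3) / (1 + 3) = 3.5.
def _example_averager():
    avg = Averager()
    avg.add(2.0)
    avg.add(4.0, n=3.0)
    return avg.item()   # 3.5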
389 |
390 | class Timer():
391 |
392 | def __init__(self):
393 | self.v = time.time()
394 |
395 | def s(self):
396 | self.v = time.time()
397 |
398 | def t(self):
399 | return time.time() - self.v
400 |
401 |
402 | def set_gpu(gpu):
403 | print('set gpu:', gpu)
404 | os.environ['CUDA_VISIBLE_DEVICES'] = gpu
405 |
406 |
407 | def ensure_path(path, remove=True):
408 | basename = os.path.basename(path.rstrip('/'))
409 | if os.path.exists(path):
410 | if remove and (basename.startswith('_')
411 | or input('{} exists, remove? ([y]/n): '.format(path)) != 'n'):
412 | shutil.rmtree(path)
413 | os.makedirs(path)
414 | else:
415 | os.makedirs(path)
416 |
417 |
418 | def time_str(t):
419 | if t >= 3600:
420 | return '{:.1f}h'.format(t / 3600)
421 | if t >= 60:
422 | return '{:.1f}m'.format(t / 60)
423 | return '{:.1f}s'.format(t)
424 |
425 |
426 | def compute_logits(feat, proto, metric='dot', temp=1.0):
427 | assert feat.dim() == proto.dim()
428 |
429 | if feat.dim() == 2:
430 | if metric == 'dot':
431 | logits = torch.mm(feat, proto.t())
432 | elif metric == 'cos':
433 | logits = torch.mm(F.normalize(feat, dim=-1),
434 | F.normalize(proto, dim=-1).t())
435 | elif metric == 'sqr':
436 | logits = -(feat.unsqueeze(1) -
437 | proto.unsqueeze(0)).pow(2).sum(dim=-1)
438 |
439 | elif feat.dim() == 3:
440 | if metric == 'dot':
441 | logits = torch.bmm(feat, proto.permute(0, 2, 1))
442 | elif metric == 'cos':
443 | logits = torch.bmm(F.normalize(feat, dim=-1),
444 | F.normalize(proto, dim=-1).permute(0, 2, 1))
445 | elif metric == 'sqr':
446 | logits = -(feat.unsqueeze(2) -
447 | proto.unsqueeze(1)).pow(2).sum(dim=-1)
448 |
449 | return logits * temp
450 |
451 |
452 | def compute_acc(logits, label, reduction='mean'):
453 | ret = (torch.argmax(logits, dim=1) == label).float()
454 | if reduction == 'none':
455 | return ret.detach()
456 | elif reduction == 'mean':
457 | return ret.mean()
458 |
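# Illustrative sketch only: cosine-similarity logits between query features and
# class prototypes, followed by top-1 accuracy. The shapes and temperature below
# are made up for the example.
def _example_prototype_logits():
    feat = torch.randn(5, 16)        # (num_queries, feat_dim)
    proto = torch.randn(3, 16)       # (num_classes, feat_dim)
    logits = compute_logits(feat, proto, metric='cos', temp=10.0)   # (5, 3)
    label = torch.randint(0, 3, (5,))
    acc = compute_acc(logits, label)                                # scalar in [0, 1]
    return logits, acc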
459 |
460 | def compute_n_params(model, return_str=True):
461 | tot = 0
462 | for p in model.parameters():
463 | w = 1
464 | for x in p.shape:
465 | w *= x
466 | tot += w
467 | if return_str:
468 | if tot >= 1e6:
469 | return '{:.1f}M'.format(tot / 1e6)
470 | else:
471 | return '{:.1f}K'.format(tot / 1e3)
472 | else:
473 | return tot
474 |
475 | def make_optimizer(params, name, max_steps, lr, weight_decay=None, milestones=None, scheduler='step', use_sam=False, sam_rho=0.005, eps=1e-8, **kwargs):
476 | if weight_decay is None:
477 | weight_decay = 0.
478 | if use_sam:
479 | optimizer = SAM(params, AdamW, rho=sam_rho, lr=lr, weight_decay=weight_decay, eps=1e-08)
480 | else:
481 | if name == 'sgd':
482 | optimizer = SGD(params, lr, momentum=0.9, weight_decay=weight_decay)
483 | elif name == 'adam':
484 | optimizer = Adam(params, lr, weight_decay=weight_decay)
485 | elif name == 'adamw':
486 | optimizer = AdamW(
487 | params, float(lr), betas=(0.9, 0.999), eps=float(eps),
488 | weight_decay=weight_decay
489 | )
490 |
491 | update_lr_every_epoch = True
492 | if scheduler == 'step':
493 | if milestones:
494 | lr_scheduler = MultiStepLR(optimizer, milestones)
495 | else:
496 | lr_scheduler = None
497 | elif scheduler == 'onecycle':
498 | lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
499 | optimizer,
500 | lr,
501 | max_steps + 100,
502 | pct_start=0.05,
503 | cycle_momentum=False,
504 | anneal_strategy='linear',
505 | final_div_factor=10000
506 | )
507 | update_lr_every_epoch = False
508 | elif scheduler == 'warmup_cosine':
509 | import pl_bolts
510 | lr_scheduler = pl_bolts.optimizers.lr_scheduler.LinearWarmupCosineAnnealingLR(optimizer, kwargs['warmup_epochs'], kwargs['max_epochs'], warmup_start_lr=kwargs['warmup_start_lr'], eta_min=0.0, last_epoch=-1)
511 | return optimizer, lr_scheduler, update_lr_every_epoch
512 |
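# Illustrative sketch only: AdamW plus a OneCycle schedule for a toy module.
# With scheduler='onecycle' the third return value is False, signalling that the
# scheduler should be stepped every iteration rather than once per epoch.
def _example_make_optimizer():
    model = nn.Linear(8, 2)
    optimizer, lr_scheduler, update_lr_every_epoch = make_optimizer(
        model.parameters(), 'adamw', max_steps=1000, lr=1e-4,
        weight_decay=0.05, scheduler='onecycle')
    assert update_lr_every_epoch is False
    return optimizer, lr_scheduler
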
513 | def set_lr(optimizer, lr):
514 | s = optimizer.state_dict()
515 | s['param_groups'][0]['lr'] = lr
516 | optimizer.load_state_dict(s)
517 |
518 | def get_lr(optimizer):
519 | return optimizer.state_dict()['param_groups'][0]['lr']
520 |
521 | class ProjHead(nn.Module):
522 | def __init__(self, in_dim, out_dim):
523 | super(ProjHead, self).__init__()
524 | self.in_dim = in_dim
525 | self.out_dim = out_dim
526 |
527 | self.proj = nn.Sequential(
528 | nn.Linear(self.in_dim, self.out_dim),
529 | nn.GELU(),
530 | nn.Linear(self.out_dim, self.out_dim),
531 | )
532 |
533 | def forward(self, x):
534 | x = self.proj(x)
535 | return x
536 |
537 | MCAN_GQA_PARAMS = {
538 | 'FRCN_FEAT_SIZE': (100, 2048),
539 | 'GRID_FEAT_SIZE': (49, 2048),
540 | 'BBOX_FEAT_SIZE': (100, 5),
541 | 'BBOXFEAT_EMB_SIZE': 2048,
542 | 'HIDDEN_SIZE': 512,
543 | 'FLAT_MLP_SIZE': 512,
544 | 'FLAT_GLIMPSES': 1,
545 | 'FLAT_OUT_SIZE': 1024,
546 | 'DROPOUT_R': 0.1,
547 | 'LAYER': 6,
548 | 'FF_SIZE': 2048,
549 | 'MULTI_HEAD': 8,
550 | 'WORD_EMBED_SIZE': 300,
551 | 'TOKEN_SIZE': 2933,
553 | 'ANSWER_SIZE': 1843,
554 | 'MAX_TOKEN_LENGTH': 29,
555 | 'USE_BBOX_FEAT': True,
556 | 'USE_AUX_FEAT': True,
557 | }
558 |
559 | def freeze_bn(model):
560 | for m in model.modules():
561 | if isinstance(m, nn.BatchNorm2d):
562 | m.eval()
563 |
564 | class Logger(object):
565 | """
566 | Redirect stderr to stdout, optionally print stdout to a file,
567 | and optionally force flushing on both stdout and the file.
568 | """
569 |
570 | def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True):
571 | self.file = None
572 |
573 | if file_name is not None:
574 | self.file = open(file_name, file_mode)
575 |
576 | self.should_flush = should_flush
577 | self.stdout = sys.stdout
578 | self.stderr = sys.stderr
579 |
580 | sys.stdout = self
581 | sys.stderr = self
582 |
583 | def __enter__(self) -> "Logger":
584 | return self
585 |
586 | def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
587 | self.close()
588 |
589 | def write(self, text: str) -> None:
590 | """Write text to stdout (and a file) and optionally flush."""
591 | if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
592 | return
593 |
594 | if self.file is not None:
595 | self.file.write(text)
596 |
597 | self.stdout.write(text)
598 |
599 | if self.should_flush:
600 | self.flush()
601 |
602 | def flush(self) -> None:
603 | """Flush written text to both stdout and a file, if open."""
604 | if self.file is not None:
605 | self.file.flush()
606 |
607 | self.stdout.flush()
608 |
609 | def close(self) -> None:
610 | """Flush, close possible files, and remove stdout/stderr mirroring."""
611 | self.flush()
612 |
613 | # if using multiple loggers, prevent closing in wrong order
614 | if sys.stdout is self:
615 | sys.stdout = self.stdout
616 | if sys.stderr is self:
617 | sys.stderr = self.stderr
618 |
619 | if self.file is not None:
620 | self.file.close()
621 |
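# Illustrative sketch only (the file name is made up): mirror stdout/stderr to a
# file for the duration of a with-block; on exit the streams are restored and
# the file is closed.
def _example_logger():
    with Logger(file_name='run_stdout.txt', file_mode='a'):
        print('this line goes to the console and to run_stdout.txt')
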
622 | def anytype2bool_dict(s):
623 | # check str
624 | if not isinstance(s, str):
625 | return s
626 | else:
627 | # try int
628 | try:
629 | ret = int(s)
630 | except:
631 | # try bool
632 | if s.lower() in ('true', 'false'):
633 | ret = s.lower() == 'true'
634 | # try float
635 | else:
636 | try:
637 | ret = float(s)
638 | except:
639 | ret = s
640 | return ret
641 |
642 | def parse_string_to_dict(field_name, value):
643 | fields = field_name.split('.')
644 | for fd in fields[::-1]:
645 | res = {fd: anytype2bool_dict(value)}
646 | value = res
647 | return res
648 |
649 | def merge_to_dicts(a, b):
650 | if isinstance(b, dict) and isinstance(a, dict):
651 | a_and_b = set(a.keys()) & set(b.keys())
652 | every_key = set(a.keys()) | set(b.keys())
653 | return {k: merge_to_dicts(a[k], b[k]) if k in a_and_b else
654 | deepcopy(a[k] if k in a else b[k]) for k in every_key}
655 | return deepcopy(type(a)(b))
656 |
657 | def override_cfg_from_list(cfg, opts):
658 | assert len(opts) % 2 == 0, 'Paired input must be provided to override config, opts: {}'.format(opts)
659 | for ix in range(0, len(opts), 2):
660 | opts_dict = parse_string_to_dict(opts[ix], opts[ix + 1])
661 | cfg = merge_to_dicts(cfg, opts_dict)
662 | return cfg
663 |
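# Illustrative sketch only (the config keys below are made up): dotted key/value
# pairs, as passed on the command line, are parsed into nested dicts and merged
# into the config; values are coerced by anytype2bool_dict, so '64' -> 64,
# 'false' -> False, '0.05' -> 0.05.
def _example_override_cfg():
    cfg = {'train': {'batch_size': 32, 'use_sam': True}, 'model': {'name': 'pvt_v2'}}
    opts = ['train.batch_size', '64', 'train.use_sam', 'false', 'train.lr', '0.05']
    cfg = override_cfg_from_list(cfg, opts)
    # cfg['train'] == {'batch_size': 64, 'use_sam': False, 'lr': 0.05}
    return cfg
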
664 | # ----------------------------------------------------------------------------
665 |
666 | def find_free_port():
667 | import socket
668 | sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
669 | # Binding to port 0 will cause the OS to find an available port for us
670 | sock.bind(("", 0))
671 | port = sock.getsockname()[1]
672 | sock.close()
673 | # NOTE: there is still a chance the port could be taken by other processes.
674 | return port
675 |
676 | def get_world_size() -> int:
677 | if not dist.is_available():
678 | return 1
679 | if not dist.is_initialized():
680 | return 1
681 | return dist.get_world_size()
682 |
683 |
684 | def get_rank() -> int:
685 | if not dist.is_available():
686 | return 0
687 | if not dist.is_initialized():
688 | return 0
689 | return dist.get_rank()
690 |
691 |
692 | def get_local_rank() -> int:
693 | """
694 | Returns:
695 | The rank of the current process within the local (per-machine) process group.
696 | """
697 | if not dist.is_available():
698 | return 0
699 | if not dist.is_initialized():
700 | return 0
701 | assert _LOCAL_PROCESS_GROUP is not None
702 | return dist.get_rank(group=_LOCAL_PROCESS_GROUP)
703 |
704 |
705 | def get_local_size() -> int:
706 | """
707 | Returns:
708 | The size of the per-machine process group,
709 | i.e. the number of processes per machine.
710 | """
711 | if not dist.is_available():
712 | return 1
713 | if not dist.is_initialized():
714 | return 1
715 | return dist.get_world_size(group=_LOCAL_PROCESS_GROUP)
716 |
717 |
718 | def is_main_process() -> bool:
719 | return get_rank() == 0
720 |
721 |
722 | def synchronize():
723 | """
724 | Helper function to synchronize (barrier) among all processes when
725 | using distributed training
726 | """
727 | if not dist.is_available():
728 | return
729 | if not dist.is_initialized():
730 | return
731 | world_size = dist.get_world_size()
732 | if world_size == 1:
733 | return
734 | dist.barrier()
735 |
--------------------------------------------------------------------------------
/utils/relvit.py:
--------------------------------------------------------------------------------
1 | # ----------------------------------------------------------------------
2 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # This work is licensed under the NVIDIA Source Code License
5 | # for RelViT. To view a copy of this license, see the LICENSE file.
6 | # ----------------------------------------------------------------------
7 |
8 | import logging
9 | import math
10 | import warnings
11 |
12 | import numpy as np
13 | import torch
14 | import torch.nn as nn
15 | import torch.nn.functional as F
16 | import torch.distributed as dist
17 |
18 |
19 | def comp_sample_prob(size, pos, K, use_log=False):
20 | # case 1: size < K
21 | # then pos-1 > pos-2 > ... > 0
22 | if size < K:
23 | rank = np.arange(pos)+1
24 | # case 2: size == K
25 | # then (pos-1)%K > (pos-2)%K > ... > pos
26 | else:
27 | rank = np.roll(np.arange(size), pos)+1
28 | if use_log:
29 | rank = np.log(10*rank)
30 | assert len(rank) == size
31 | return torch.from_numpy(rank).float()
32 |
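# Illustrative sketch only: for a full queue (size == K == 4) with the write
# pointer at pos = 1, the unnormalised weights are an arange rolled by pos, so
# slot (pos - 1) % K receives the largest weight and slot pos the smallest:
#   comp_sample_prob(4, 1, 4) -> tensor([4., 1., 2., 3.])
# dequeue_with_concept calls this with use_log=True, i.e. log(10 * rank), which
# keeps the same ordering but flattens the bias.
def _example_comp_sample_prob():
    w = comp_sample_prob(4, 1, 4)
    return w / w.sum()   # torch.distributions.Categorical normalises identically
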
33 | def dequeue_with_concept(feat, concept, queue_dict, queue_grid_dict, queue_ptr_dict, sample_uniform=True):
34 | # concept: (B, num_concepts)
35 |     # If there are not enough concepts in the queue, or the current sample comes without any concept,
36 | # just return the corresponding feat.
37 |
38 | # (2*B, C), (2*B, N, C), (2*B, N, C)
39 | t_cls_out, t_region_out, t_fea = feat
40 | N = t_region_out.size(1)
41 |
42 | K = queue_dict[0].buffer.size(0)
43 |
44 | B = concept.size(0)
45 | with torch.no_grad():
46 | # sanitize samples without any concept
47 | concept = concept.repeat(2, 1)
48 | mask = (concept.sum(dim=-1) == 0)
49 | concept[mask] += 1
50 | concept_sample = torch.distributions.Categorical(concept).sample()
51 |
52 | ret1, ret2, ret3 = [], [], []
53 | for ind, (c, m) in enumerate(zip(concept_sample, mask)):
54 | cur_pos = queue_ptr_dict[c.item()].buffer[0].item()
55 | size = queue_ptr_dict[c.item()].buffer[1].item()
56 |
57 | if size != 0 and m == 0:
58 | if sample_uniform:
59 | # Equal prob
60 | pos = torch.distributions.Categorical(torch.ones(size)).sample().item()
61 | else:
62 | # "FIFO" prob
63 | prob = comp_sample_prob(size, cur_pos, K, use_log=True)
64 | pos = torch.distributions.Categorical(prob).sample().item()
65 |
66 | ret1.append(queue_dict[c.item()].buffer[pos])
67 | ret2.append(queue_grid_dict[c.item()].buffer[pos, :N])
68 | ret3.append(queue_grid_dict[c.item()].buffer[pos, N:])
69 | else:
70 | ret1.append(feat[0][ind])
71 | ret2.append(feat[1][ind])
72 | ret3.append(feat[2][ind])
73 | ret1 = torch.stack(ret1).to(t_region_out)
74 | ret2 = torch.stack(ret2).to(t_region_out)
75 | ret3 = torch.stack(ret3).to(t_region_out)
76 | assert ret1.shape == feat[0].shape
77 | assert ret2.shape == feat[1].shape
78 | assert ret3.shape == feat[2].shape
79 | return (ret1.contiguous(), ret2.contiguous(), ret3.contiguous())
80 |
81 | def enqueue_with_concept(feat, concept, queue_dict, queue_grid_dict, queue_ptr_dict):
82 | # feat: (2*B, C)
83 | # concept: (B, num_concepts)
84 | # We only work on the first B instances and skip those without any concepts.
85 |
86 | # (2*B, C), (2*B, N, C), (2*B, N, C)
87 | t_cls_out, t_region_out, t_fea = feat
88 | N = t_region_out.size(1)
89 |
90 | K = queue_dict[0].buffer.size(0)
91 |
92 | with torch.no_grad():
93 | # sanitize samples without any concept
94 | mask = (concept.sum(dim=-1) == 0)
95 | concept[mask] += 1
96 | concept_sample = torch.distributions.Categorical(concept).sample()
97 |
98 | for ind, (c, m) in enumerate(zip(concept_sample, mask)):
99 | if m == 0:
100 | # write pos and size
101 | pos = queue_ptr_dict[c.item()].buffer[0]
102 | size = queue_ptr_dict[c.item()].buffer[1]
103 | pos += 1
104 | size += 1
105 | # write pos should loop back to 0
106 | pos %= K
107 | queue_ptr_dict[c.item()].buffer[0] = pos
108 | # size should be clamped to K
109 | size = torch.clamp(size, 0, K)
110 | queue_ptr_dict[c.item()].buffer[1] = size
111 |
112 | queue_dict[c.item()].buffer[pos.item()] = feat[0][ind].detach()
113 | queue_grid_dict[c.item()].buffer[pos.item(), :N] = feat[1][ind].detach()
114 | queue_grid_dict[c.item()].buffer[pos.item(), N:] = feat[2][ind].detach()
115 |
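# Illustrative smoke test only (all sizes and names below are made up): enqueue
# the teacher features of a 2-image batch into per-concept buffers, then draw
# concept-matched targets back out. `Buffer` is defined further down this file.
def _example_concept_queue_roundtrip():
    B, N, C, K, num_concepts = 2, 4, 8, 5, 3
    queue_dict = [Buffer(torch.rand, K, C) for _ in range(num_concepts)]
    queue_grid_dict = [Buffer(torch.rand, K, N * 2, C) for _ in range(num_concepts)]
    queue_ptr_dict = [Buffer(torch.zeros, 2, dtype=torch.long) for _ in range(num_concepts)]
    feat = (torch.randn(2 * B, C), torch.randn(2 * B, N, C), torch.randn(2 * B, N, C))
    # Multi-hot concept labels; the second image has no concept and is skipped.
    concept = torch.tensor([[1., 0., 1.],
                            [0., 0., 0.]])
    # clone() because the helpers sanitise all-zero rows in place.
    enqueue_with_concept(feat, concept.clone(), queue_dict, queue_grid_dict, queue_ptr_dict)
    targets = dequeue_with_concept(feat, concept.clone(), queue_dict, queue_grid_dict,
                                   queue_ptr_dict, sample_uniform=True)
    # Targets keep exactly the shapes of the inputs.
    assert all(t.shape == f.shape for t, f in zip(targets, feat))
    return targets
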
116 | def RCL(attn_v_tea, attn_v_stu, concept, queue_dict, queue_grid_dict, queue_ptr_dict, center, center_grid, tau=0.04, local_only=1, tau_stu=0.1, sample_uniform=True):
117 |     # local_only: 1 -- keep only the region-level (local) loss; 2 -- keep only the view-level (global) loss
118 |
119 |     # NOTE: dequeue before enqueuing so the sampled targets never include features just added from the current batch.
120 | target = dequeue_with_concept(attn_v_tea, concept, queue_dict, queue_grid_dict, queue_ptr_dict, sample_uniform)
121 | enqueue_with_concept(attn_v_tea, concept, queue_dict, queue_grid_dict, queue_ptr_dict)
122 |
123 | return token_level_esvit(target, attn_v_stu, center, center_grid, tau, local_only, tau_stu)
124 |
125 | def token_level_esvit(attn_v_tea, attn_v_stu, center, center_grid, tau=0.04, local_only=True, tau_stu=0.1):
126 |     # local_only: 1 -- keep only the region-level (local) loss; 2 -- keep only the view-level (global) loss
127 | # FIXME: temporarily disable multi-crop
128 | ncrops = 2
129 | # (2*B, C), (2*B, N, C), (2*B, N, C)
130 | s_cls_out, s_region_out, s_fea = attn_v_stu
131 | t_cls_out, t_region_out, t_fea = attn_v_tea
132 | B, N = s_region_out.size(0), s_region_out.size(1)
133 | B = B // 2
134 | s_region_out = torch.cat([
135 | s_region_out[:B].reshape(B*N, -1),
136 | s_region_out[B:].reshape(B*N, -1)]
137 | ).contiguous()
138 | s_fea = torch.cat([
139 | s_fea[:B].reshape(B*N, -1),
140 | s_fea[B:].reshape(B*N, -1)]
141 | ).contiguous()
142 | t_region_out = torch.cat([
143 | t_region_out[:B].reshape(B*N, -1),
144 | t_region_out[B:].reshape(B*N, -1)]
145 | ).contiguous()
146 | t_fea = torch.cat([
147 | t_fea[:B].reshape(B*N, -1),
148 | t_fea[B:].reshape(B*N, -1)]
149 | ).contiguous()
150 | s_npatch = [N, N]
151 | t_npatch = [N, N]
152 |
153 | # teacher centering and sharpening
154 | temp = tau
155 | t_cls = F.softmax((t_cls_out - center) / temp, dim=-1)
156 | t_cls = t_cls.detach().chunk(2)
157 |
158 | t_region = F.softmax((t_region_out - center_grid) / temp, dim=-1)
159 | t_region = t_region.detach().chunk(2)
160 | t_fea = t_fea.chunk(2)
161 |
162 |
163 | N = t_npatch[0] # num of patches in the first view
164 | B = t_region[0].shape[0]//N # batch size,
165 |
166 | # student sharpening
167 | s_cls = s_cls_out / tau
168 | s_cls = s_cls.chunk(ncrops)
169 |
170 | s_region = s_region_out / tau_stu
171 | s_split_size = [s_npatch[0]] * 2 + [s_npatch[1]] * (ncrops -2)
172 |
173 | s_split_size_bs = [i * B for i in s_split_size]
174 |
175 | s_region = torch.split(s_region, s_split_size_bs, dim=0)
176 | s_fea = torch.split(s_fea, s_split_size_bs, dim=0)
177 |
178 | total_loss = 0
179 | n_loss_terms = 0
180 | for iq, q in enumerate(t_cls):
181 | for v in range(len(s_cls)):
182 | if v == iq:
183 | # we skip cases where student and teacher operate on the same view
184 | continue
185 |
186 | # view level prediction loss
187 | loss = 0.5 * torch.sum(-q * F.log_softmax(s_cls[v], dim=-1), dim=-1)
188 | if local_only == 1:
189 | loss *= 0
190 |
191 | # region level prediction loss
192 | s_region_cur, s_fea_cur = s_region[v].view(B, s_split_size[v], -1).contiguous(), s_fea[v].view(B, s_split_size[v], -1).contiguous() # B x T_s x K, B x T_s x P
193 | t_region_cur, t_fea_cur = t_region[iq].view(B, N, -1).contiguous(), t_fea[iq].view(B, N, -1).contiguous() # B x T_t x K, B x T_t x P,
194 |
195 | # similarity matrix between two sets of region features
196 | region_sim_matrix = torch.matmul(F.normalize(s_fea_cur, p=2, dim=-1) , F.normalize(t_fea_cur, p=2, dim=-1).permute(0, 2, 1).contiguous()) # B x T_s x T_t
197 | region_sim_ind = region_sim_matrix.max(dim=2)[1] # B x T_s; collect the argmax index in teacher for a given student feature
198 |
199 | t_indexed_region = torch.gather( t_region_cur, 1, region_sim_ind.unsqueeze(2).expand(-1, -1, t_region_cur.size(2)) ) # B x T_s x K (index matrix: B, T_s, 1)
200 |
201 | loss_grid = torch.sum(- t_indexed_region * F.log_softmax(s_region_cur, dim=-1), dim=[-1]).mean(-1) # B x T_s x K --> B
202 |
203 | if local_only == 2:
204 | loss += 0. * loss_grid
205 | else:
206 | loss += 0.5 * loss_grid
207 |
208 | total_loss += loss.mean()
209 | n_loss_terms += 1
210 | total_loss /= n_loss_terms
211 |
212 | return total_loss, t_cls_out.contiguous(), t_region_out.contiguous()
213 |
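# Illustrative smoke test only (all shapes are made up): B=2 images, two views
# each, N=4 region tokens, feature dim C=8, zero centers. The returned teacher
# outputs are what MoCo.forward uses for its EMA center updates.
def _example_token_level_esvit():
    B, N, C = 2, 4, 8
    def make():
        return (torch.randn(2 * B, C),      # view-level (CLS) features
                torch.randn(2 * B, N, C),   # projected region features
                torch.randn(2 * B, N, C))   # unprojected region features
    teacher, student = make(), make()
    center, center_grid = torch.zeros(1, C), torch.zeros(1, C)
    loss, t_cls, t_region = token_level_esvit(
        teacher, student, center, center_grid, tau=0.04, local_only=0, tau_stu=0.1)
    return loss   # scalar
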
214 | class Buffer(nn.Module):
215 | def __init__(self, tensor_cls, *args, **kwargs):
216 | super(Buffer, self).__init__()
217 | self.register_buffer('buffer', tensor_cls(*args, **kwargs))
218 |
219 | class MoCo(nn.Module):
220 | """
221 | Build a MoCo model with: a query encoder, a key encoder, and a queue
222 | https://arxiv.org/abs/1911.05722
223 | """
224 | def __init__(self, encoder_tea_k, encoder_stu_q, K=65536, m=0.999, feat_dim=None, cl_loss='single-view', num_concepts=600, relvit_mode=1, num_tokens=49):
225 | """
226 |         K: queue size; number of negative keys (default: 65536)
227 |         m: moco momentum of updating key encoder (default: 0.999)
228 |         feat_dim: feature dimension; defaults to the student encoder's out_dim
229 |         relvit_mode: 0 -- EsViT loss only; 1 -- EsViT + RCL; 2 -- RCL only
230 | """
231 | super(MoCo, self).__init__()
232 |
233 | self.K = K
234 | self.m = m
235 |
236 | # create the encoders
237 | # feature embedding size is the output fc dimension
238 | self.encoder_q = [encoder_stu_q]
239 | self.encoder_k = [encoder_tea_k]
240 | dim = self.encoder_q[0].out_dim if feat_dim is None else feat_dim
241 |
242 | for param_q, param_k in zip(self.encoder_q[0].parameters(), self.encoder_k[0].parameters()):
243 | param_k.data.copy_(param_q.data) # initialize
244 | param_k.requires_grad = False # not update by gradient
245 |
246 | self.relvit_mode = relvit_mode
247 | self.center_momentum = 0.9
248 | self.register_buffer("center", torch.zeros(1, dim))
249 | self.register_buffer("center_grid", torch.zeros(1, dim))
250 | if self.relvit_mode:
251 | # FIXME: https://github.com/microsoft/esvit/blob/main/main_esvit.py#L606
252 | self.register_buffer("center_rcl", torch.zeros(1, dim))
253 | self.register_buffer("center_grid_rcl", torch.zeros(1, dim))
254 |
255 | if self.relvit_mode:
256 | # assert self.K <= 10
257 | self.queue_dict = nn.ModuleList([
258 | Buffer(torch.rand, self.K, dim) for i in range(num_concepts)
259 | ])
260 | # FIXME: magic number; first projected and unprojected
261 | # 49x2 works with pvtv2b2 and swin_small
262 | # 196x2 works with vit_small_16
263 | self.queue_grid_dict = nn.ModuleList([
264 | Buffer(torch.rand, self.K, num_tokens*2, dim) for i in range(num_concepts)
265 | ])
266 | # current writing pos and size
267 | self.queue_ptr_dict = nn.ModuleList([
268 | Buffer(torch.zeros, 2, dtype=torch.long) for i in range(num_concepts)
269 | ])
270 |
271 | @torch.no_grad()
272 | def _momentum_update_key_encoder(self):
273 | """
274 | Momentum update of the key encoder
275 | """
276 | for param_q, param_k in zip(self.encoder_q[0].parameters(), self.encoder_k[0].parameters()):
277 | param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)
278 |
279 | @torch.no_grad()
280 | def _dequeue_and_enqueue(self, keys):
281 | batch_size = keys.shape[0]
282 |
283 | ptr = int(self.queue_ptr)
284 | assert self.K % batch_size == 0, batch_size # for simplicity
285 |
286 | # replace the keys at ptr (dequeue and enqueue)
287 | self.queue[:, ptr:ptr + batch_size] = keys.transpose(0, 1)
288 | ptr = (ptr + batch_size) % self.K # move pointer
289 |
290 | self.queue_ptr[0] = ptr
291 |
292 | def forward(self, attn_v_k, attn_v_q, B, config, args, concepts=None):
293 | def cal_batch(teacher_output, teacher_grid_output):
294 | batch_center = torch.sum(teacher_output, dim=0, keepdim=True)
295 | if args.multiprocessing_distributed:
296 | dist.all_reduce(batch_center)
297 | world_size = dist.get_world_size()
298 | else:
299 | world_size = 1
300 | batch_center = batch_center / (len(teacher_output) * world_size)
301 |
302 | # region level center update
303 | batch_grid_center = torch.sum(teacher_grid_output, dim=0, keepdim=True)
304 | if args.multiprocessing_distributed:
305 | dist.all_reduce(batch_grid_center)
306 | world_size = dist.get_world_size()
307 | else:
308 | world_size = 1
309 | batch_grid_center = batch_grid_center / (len(teacher_grid_output) * world_size)
310 | return batch_center, batch_grid_center
311 |
312 | # compute key features
313 | with torch.no_grad(): # no gradient to keys
314 | self._momentum_update_key_encoder() # update the key encoder
315 |
316 | loss, teacher_output, teacher_grid_output = token_level_esvit(
317 | attn_v_k,
318 | attn_v_q,
319 | self.center,
320 | self.center_grid,
321 | tau=config['relvit_loss_tau'],
322 | local_only=config['relvit_local_only'])
323 | with torch.no_grad():
324 | # ema update for esvit
325 | batch_center, batch_grid_center = cal_batch(teacher_output, teacher_grid_output)
326 | self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum)
327 | self.center_grid = self.center_grid * self.center_momentum + batch_grid_center * (1 - self.center_momentum)
328 |
329 | if self.relvit_mode:
330 | loss_rcl, teacher_output_rcl, teacher_grid_output_rcl = RCL(
331 | attn_v_k,
332 | attn_v_q,
333 | concepts,
334 | self.queue_dict,
335 | self.queue_grid_dict,
336 | self.queue_ptr_dict,
337 | self.center_rcl,
338 | self.center_grid_rcl,
339 | tau=config['relvit_loss_tau'],
340 | local_only=config['relvit_local_only'],
341 | sample_uniform=config['relvit_sample_uniform'])
342 |
343 | if self.relvit_mode == 2:
344 | loss = loss * 0 + loss_rcl
345 | elif self.relvit_mode == 1:
346 | loss += loss_rcl
347 | else:
348 | loss = loss
349 | with torch.no_grad():
350 | # ema update for RCL
351 | batch_center_rcl, batch_grid_center_rcl = cal_batch(teacher_output_rcl, teacher_grid_output_rcl)
352 | self.center_rcl = self.center_rcl * self.center_momentum + batch_center_rcl * (1 - self.center_momentum)
353 | self.center_grid_rcl = self.center_grid_rcl * self.center_momentum + batch_grid_center_rcl * (1 - self.center_momentum)
354 |
355 | return loss
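
# Illustrative smoke test only (encoders, config keys and sizes below are made
# up): one auxiliary-loss step with relvit_mode=1 (EsViT + RCL). The encoders
# are never called inside forward, so any modules with matching parameters work.
def _example_moco_relvit_step():
    import types
    B, N, C = 2, 4, 8
    student, teacher = nn.Linear(C, C), nn.Linear(C, C)
    moco = MoCo(teacher, student, K=5, m=0.999, feat_dim=C,
                num_concepts=3, relvit_mode=1, num_tokens=N)
    def feats():
        return (torch.randn(2 * B, C), torch.randn(2 * B, N, C), torch.randn(2 * B, N, C))
    config = {'relvit_loss_tau': 0.04, 'relvit_local_only': 0,
              'relvit_sample_uniform': True}
    args = types.SimpleNamespace(multiprocessing_distributed=False)
    concepts = torch.tensor([[1., 0., 0.],
                             [0., 1., 1.]])
    loss = moco(feats(), feats(), B, config, args, concepts=concepts)
    return loss
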
--------------------------------------------------------------------------------