├── .gitignore
├── LICENSE
├── README.md
├── configs
│   ├── COCO-Detection
│   │   ├── Base-RCNN-DilatedC5.yaml
│   │   ├── WSOVOD_MRRP_V_16_DC5_1x.yaml
│   │   ├── WSOVOD_MRRP_WSR_18_DC5_1x.yaml
│   │   ├── WSOVOD_MRRP_WSR_50_DC5_1x.yaml
│   │   ├── WSOVOD_V_16_DC5_1x.yaml
│   │   ├── WSOVOD_WSR_18_DC5_1x.yaml
│   │   └── WSOVOD_WSR_50_DC5_1x.yaml
│   ├── ImageNet-Detection
│   │   ├── Base-RCNN-DilatedC5.yaml
│   │   └── WSOVOD_WSR_18_DC5_1x.yaml
│   ├── MixedDatasets-Detection
│   │   ├── Base-RCNN-DilatedC5.yaml
│   │   ├── WSOVOD_MRRP_WSR_18_DC5_1x_voc07+coco.yaml
│   │   ├── WSOVOD_MRRP_WSR_50_DC5_1x_voc07+coco.yaml
│   │   ├── WSOVOD_WSR_18_DC5_1x_voc07+coco.yaml
│   │   └── WSOVOD_WSR_50_DC5_1x_voc07+coco.yaml
│   └── PascalVOC-Detection
│       ├── Base-RCNN-DilatedC5.yaml
│       ├── WSOVOD_MRRP_V_16_DC5_1x.yaml
│       ├── WSOVOD_MRRP_V_16_DC5_VOC12_1x.yaml
│       ├── WSOVOD_MRRP_WSR_18_DC5_1x.yaml
│       ├── WSOVOD_MRRP_WSR_18_DC5_VOC12_1x.yaml
│       ├── WSOVOD_MRRP_WSR_50_DC5_1x.yaml
│       ├── WSOVOD_MRRP_WSR_50_DC5_VOC12_1x.yaml
│       ├── WSOVOD_V_16_DC5_1x.yaml
│       ├── WSOVOD_V_16_DC5_VOC12_1x.yaml
│       ├── WSOVOD_WSR_18_DC5_1x.yaml
│       ├── WSOVOD_WSR_18_DC5_VOC12_1x.yaml
│       ├── WSOVOD_WSR_50_DC5_1x.yaml
│       └── WSOVOD_WSR_50_DC5_VOC12_1x.yaml
├── datasets
│   └── README.md
├── requirements.txt
├── scripts
│   ├── extract_ilsvrc.sh
│   ├── generate_sam_proposals_cuda.sh
│   ├── prepare_ilsvrc.sh
│   └── train_script.sh
├── setup.py
├── teaser
│   └── framework.png
├── tools
│   ├── convert_ilsvrc_classes_name.py
│   ├── generate_class_text_embedding.py
│   ├── generate_class_text_embedding_cuda.py
│   ├── generate_sam_proposals_cuda.py
│   ├── ilsvrc2012_classes_name.txt
│   ├── ilsvrc_folder.py
│   ├── ilsvrc_info.py
│   ├── train_net.py
│   ├── train_net_debug.py
│   └── train_net_eval_open_vocabulary.py
└── wsovod
    ├── __init__.py
    ├── config
    │   ├── __init__.py
    │   └── defaults.py
    ├── data
    │   ├── __init__.py
    │   ├── build.py
    │   ├── build_multi_dataset.py
    │   ├── common.py
    │   ├── dataset_mapper.py
    │   ├── datasets
    │   │   ├── __init__.py
    │   │   ├── builtin.py
    │   │   ├── builtin_meta.py
    │   │   └── pascal_voc.py
    │   ├── detection_utils.py
    │   └── samplers
    │       ├── __init__.py
    │       └── distributed_sampler_multi_dataset.py
    ├── engine
    │   ├── __init__.py
    │   ├── defaults.py
    │   ├── hooks.py
    │   └── trainer.py
    ├── evaluation
    │   ├── __init__.py
    │   ├── coco_evaluation.py
    │   ├── ov_coco_evaluation.py
    │   └── pascal_voc_evaluation.py
    ├── layers
    │   ├── ROILoopPool
    │   │   ├── ROILoopPool.h
    │   │   ├── ROILoopPool_cpu.cpp
    │   │   ├── ROILoopPool_cuda.cu
    │   │   └── cuda_helpers.h
    │   ├── __init__.py
    │   ├── csc.py
    │   ├── csc
    │   │   ├── csc.h
    │   │   └── csc_cuda.cu
    │   ├── roi_loop_pool.py
    │   └── vision.cpp
    ├── modeling
    │   ├── __init__.py
    │   ├── backbone
    │   │   ├── __init__.py
    │   │   ├── mrrp_conv.py
    │   │   ├── resnet_wsl.py
    │   │   ├── resnet_wsl_mrrp.py
    │   │   ├── swin_transformer.py
    │   │   ├── vgg.py
    │   │   └── vgg_mrrp.py
    │   ├── class_heads
    │   │   ├── __init__.py
    │   │   ├── data_aware_features_head.py
    │   │   └── open_vocabulary_classifier.py
    │   ├── meta_arch
    │   │   ├── __init__.py
    │   │   ├── rcnn_wsovod.py
    │   │   └── rcnn_wsovod_mixed_datasets.py
    │   ├── poolers.py
    │   ├── postprocessing.py
    │   ├── proposal_generator
    │   │   ├── __init__.py
    │   │   ├── proposal_utils.py
    │   │   └── rpn.py
    │   ├── roi_heads
    │   │   ├── __init__.py
    │   │   ├── box_head.py
    │   │   ├── fast_rcnn_open_vocabulary.py
    │   │   └── roi_heads.py
    │   ├── test_time_augmentation_avg.py
    │   └── test_time_augmentation_union.py
    ├── solver
    │   ├── __init__.py
    │   ├── build.py
    │   └── lr_scheduler.py
    └── utils
        ├── __init__.py
        └── sam_predictor_with_buffer.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # output dir
2 | output*
3 |
4 | *.pth
5 |
6 | # compilation and distribution
7 | __pycache__
8 | _ext
9 | *.pyc
10 | *.so
11 | wsovod.egg-info/
12 | build/
13 | dist/
14 |
15 | # pytorch/python/numpy formats
16 | *.pth
17 | *.pkl
18 | *.npy
19 |
20 | # ipython/jupyter notebooks
21 | *.ipynb
22 | **/.ipynb_checkpoints/
23 |
24 | # Editor temporaries
25 | *.swn
26 | *.swo
27 | *.swp
28 | *~
29 |
30 | # Pycharm editor settings
31 | .idea
32 |
33 | # project dirs
34 | /datasets/coco
35 | /datasets/lvis
36 | /datasets/ILSVRC2012
37 | /datasets/proposals
38 | /datasets/VOC*
39 | /models
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Weakly Supervised Open-Vocabulary Object Detection
2 | This is the official implementation of the AAAI 2024 paper "Weakly Supervised Open-Vocabulary Object Detection".
3 |
4 | ## 📋 Table of content
5 | 1. [📎 Paper Link](#1)
6 | 2. [💡 Abstract](#2)
7 | 3. [📖 Method](#3)
8 | 4. [🛠️ Install](#4)
9 | 5. [✏️ Usage](#5)
10 | 1. [Start](#51)
11 | 2. [Prepare Datasets](#52)
12 | 3. [Training](#53)
13 | 4. [Inference](#54)
14 | 6. [🔍 Citation](#6)
15 | 7. [❤️ Acknowledgement](#7)
16 |
17 | ## 📎 Paper Link
18 | [Read our arXiv Paper](https://arxiv.org/abs/2312.12437)
19 |
20 | ## 💡 Abstract
21 | Despite weakly supervised object detection (WSOD) being a promising step toward evading strong instance-level annotations, its capability is confined to closed-set categories within a single training dataset. In this paper, we propose a novel weakly supervised open-vocabulary object detection framework, namely WSOVOD, to extend traditional WSOD to detect novel concepts and utilize diverse datasets with only image-level annotations. To achieve this, we explore three vital strategies, including dataset-level feature adaptation, image-level salient object localization, and region-level vision-language alignment. First, we perform data-aware feature extraction to produce an input-conditional coefficient, which is leveraged into dataset attribute prototypes to identify dataset bias and help achieve cross-dataset generalization. Second, a customized location-oriented weakly supervised region proposal network is proposed to utilize high-level semantic layouts from the category-agnostic segment anything model to distinguish object boundaries. Lastly, we introduce a proposal-concept synchronized multiple-instance network, i.e., object mining and refinement with visual-semantic alignment, to discover objects matched to the text embeddings of concepts. Extensive experiments on Pascal VOC and MS COCO demonstrate that the proposed WSOVOD achieves new state-of-the-art compared with previous WSOD methods in both close-set object localization and detection tasks. Meanwhile, WSOVOD enables cross-dataset and open-vocabulary learning to achieve on-par or even better performance than well-established fully-supervised open-vocabulary object detection (FSOVOD).
22 |
23 | ## 📖 Method
24 |
25 | The overall framework of our **WSOVOD** is shown below.
26 | 
27 | ![WSOVOD framework](teaser/framework.png)
28 |
29 |
30 | ## 🛠️ Install
31 | ```
32 | conda create --name wsovod python=3.9
33 | conda activate wsovod
34 | pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113
35 | pip install -r requirements.txt
36 | pip install -e .
37 | ```
38 |
39 | ## ✏️ Usage
40 | 1. Please follow [this](datasets/README.md) to prepare the datasets for training.
41 |
42 | 2. Download the SAM checkpoints.
43 | ```
44 | mkdir -p tools/sam_checkpoints && cd tools/sam_checkpoints
45 | wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
46 | wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth
47 | wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth
48 | ```
49 |
50 | 3. Prepare SAM proposals for WSOVOD, taking voc_2007_train as an example.
51 | ```
52 | bash scripts/generate_sam_proposals_cuda.sh 4 --checkpoint tools/sam_checkpoints/sam_vit_h_4b8939.pth --model-type vit_h --points-per-side 32 --pred-iou-thresh 0.86 --stability-score-thresh 0.92 --crop-n-layers 1 --crop-n-points-downscale-factor 2 --min-mask-region-area 20.0 --dataset-name voc_2007_train --output datasets/proposals/sam_voc_2007_train_d2.pkl
53 | ```
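
For reference, here is a minimal sketch of the `segment_anything` calls that the flags above map onto, assuming they are passed straight through to `SamAutomaticMaskGenerator`. This is not the repository script itself (`tools/generate_sam_proposals_cuda.py` presumably also converts the masks into the detectron2-style proposal pickle), and the image path is only an example:

```python
import cv2
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry

# Build SAM (ViT-H) and the automatic mask generator with the same
# hyper-parameters as the command above.
sam = sam_model_registry["vit_h"](checkpoint="tools/sam_checkpoints/sam_vit_h_4b8939.pth").to("cuda")
mask_generator = SamAutomaticMaskGenerator(
    sam,
    points_per_side=32,
    pred_iou_thresh=0.86,
    stability_score_thresh=0.92,
    crop_n_layers=1,
    crop_n_points_downscale_factor=2,
    min_mask_region_area=20.0,
)

image = cv2.cvtColor(cv2.imread("path/to/image.jpg"), cv2.COLOR_BGR2RGB)
masks = mask_generator.generate(image)
# Each mask dict carries a class-agnostic box in XYWH format plus quality scores.
boxes_xywh = [m["bbox"] for m in masks]
scores = [m["predicted_iou"] for m in masks]
```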
54 |
55 | 4. Prepare class text embeddings for WSOVOD, taking COCO as an example.
56 | ```
57 | python tools/generate_class_text_embedding_cuda.py --dataset-name coco_2017_val --mode-type ViT-L/14/32 --prompt-type single --output models/coco_text_embedding_single_prompt.pkl
58 | ```
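
To illustrate what such an embedding file contains, below is a hedged sketch of encoding class names with OpenAI CLIP. It assumes a 512-d ViT-B/32 text encoder (matching `WEIGHT_DIM: 512` in the configs) and a single prompt template; it is not the exact logic of `tools/generate_class_text_embedding_cuda.py`, and the output filename here is only an example:

```python
import pickle

import clip
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model, _ = clip.load("ViT-B/32", device=device)  # 512-d text embeddings

class_names = ["person", "bicycle", "car"]  # e.g. the 80 COCO category names
prompts = [f"a photo of a {name}" for name in class_names]  # single-prompt template

with torch.no_grad():
    tokens = clip.tokenize(prompts).to(device)
    text_embeddings = model.encode_text(tokens).float()
    text_embeddings /= text_embeddings.norm(dim=-1, keepdim=True)  # L2-normalize

# One embedding row per class; the real script's output format may differ.
with open("coco_text_embedding_single_prompt_example.pkl", "wb") as f:
    pickle.dump(text_embeddings.cpu().numpy(), f)
```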
59 |
60 | 5. Download the pretrained backbone weights from [here](https://1drv.ms/f/s!Am1oWgo9554dgRQ8RE1SRGvK7HW2).
61 |
62 | 6. Train on one dataset and test on another, taking COCO and VOC as an example.
63 | ```
64 | bash scripts/train_script.sh tools/train_net.py configs/COCO-Detection/WSOVOD_WSR_18_DC5_1x.yaml 4 20240301
65 |
66 | python tools/train_net.py --config-file configs/PascalVOC-Detection/WSOVOD_WSR_18_DC5_1x.yaml --num-gpus 4 --eval-only MODEL.WEIGHTS output/configs/COCO-Detection/WSOVOD_WSR_18_DC5_1x_20240301/model_final.pth
67 | ```
68 | 7. Train on mixed datasets, taking VOC and COCO as an example (the command below follows the same launcher pattern as step 6, using one of the mixed-dataset configs).
69 | ```
70 | bash scripts/train_script.sh tools/train_net.py configs/MixedDatasets-Detection/WSOVOD_WSR_18_DC5_1x_voc07+coco.yaml 4 20240301
71 | ```
72 |
73 | ## 🔍 Citation
74 | If you find WSOVOD useful in your research, please consider citing:
75 |
76 | ```
77 | @InProceedings{WSOVOD_2024_AAAI,
78 | author = {Lin, Jianghang and Shen, Yunhang and Wang, Bingquan and Lin, Shaohui and Li, Ke and Cao, Liujuan},
79 | title = {Weakly Supervised Open-Vocabulary Object Detection},
80 | booktitle = {Proceedings of the AAAI Conference on Artificial Intelligence},
81 | year = {2024},
82 | }
83 | ```
84 |
85 |
86 |
87 | ## License
88 |
89 | WSOVOD is released under the [Apache 2.0 license](LICENSE).
90 |
91 | ## ❤️ Acknowledgement
92 | - [UWSOD](https://github.com/shenyunhang/UWSOD)
93 | - [detectron2](https://github.com/facebookresearch/detectron2)
94 |
--------------------------------------------------------------------------------
/configs/COCO-Detection/Base-RCNN-DilatedC5.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | META_ARCHITECTURE: "GeneralizedRCNN_WSOVOD"
3 | PIXEL_MEAN: [102.9801, 115.9465, 122.7717]
4 | LOAD_PROPOSALS: True
5 | PROPOSAL_GENERATOR:
6 | NAME: "WSOVODRPN_V2"
7 | MIN_SIZE: 40
8 | RPN:
9 | IN_FEATURES: ["res5"]
10 | PRE_NMS_TOPK_TRAIN: 2048
11 | PRE_NMS_TOPK_TEST: 2048
12 | POST_NMS_TOPK_TRAIN: 1024
13 | POST_NMS_TOPK_TEST: 1024
14 | NMS_THRESH: 0.7
15 | BATCH_SIZE_PER_IMAGE: 512
16 | # BATCH_SIZE_PER_IMAGE: 256 # default rpn
17 | POSITIVE_FRACTION: 0.5
18 | BBOX_REG_LOSS_TYPE: "smooth_l1"
19 | IOU_THRESHOLDS: [0.2, 0.6]
20 | # IOU_THRESHOLDS: [0.3, 0.7] # default rpn
21 | ROI_HEADS:
22 | NAME: "WSOVODROIHeads"
23 | NUM_CLASSES: 80
24 | SCORE_THRESH_TEST: 0.00001
25 | NMS_THRESH_TEST: 0.3
26 | BATCH_SIZE_PER_IMAGE: 4096
27 | POSITIVE_FRACTION: 1.0
28 | PROPOSAL_APPEND_GT: False
29 | ROI_BOX_HEAD:
30 | NAME: "DiscriminativeAdaptationNeck"
31 | POOLER_TYPE: "ROIPool"
32 | NUM_CONV: 0
33 | NUM_FC: 2
34 | DAN_DIM: [4096, 4096]
35 | POOLER_RESOLUTION: 7
36 | BBOX_REG_LOSS_TYPE: "smooth_l1_weighted"
37 | OPEN_VOCABULARY:
38 | WEIGHT_PATH_TRAIN: "models/coco_text_embedding_single_prompt.pkl"
39 | WEIGHT_PATH_TEST: "models/coco_text_embedding_single_prompt.pkl"
40 | WEIGHT_DIM: 512
41 | USE_BIAS: 0.0
42 | NORM_WEIGHT: True
43 | NORM_TEMP: 50.0
44 | DATA_AWARE: True
45 | WSOVOD:
46 | ITER_SIZE: 1
47 | BBOX_REFINE:
48 | ENABLE: True
49 | OBJECT_MINING:
50 | MEAN_LOSS: True
51 | INSTANCE_REFINEMENT:
52 | REFINE_NUM: 1
53 | REFINE_REG: [True]
54 | SAMPLING:
55 | SAMPLING_ON: True
56 | IOU_THRESHOLDS: [[0.5],]
57 | IOU_LABELS: [[0, 1],]
58 | BATCH_SIZE_PER_IMAGE: [4096,]
59 | POSITIVE_FRACTION: [1.0,]
60 | SOLVER:
61 | IMS_PER_BATCH: 4
62 | BASE_LR: 0.01
63 | STEPS: (140000,)
64 | MAX_ITER: 200000
65 | REFERENCE_WORLD_SIZE: 4
66 | INPUT:
67 | MIN_SIZE_TRAIN: (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800, 832, 864, 896, 928, 960, 992, 1024, 1056, 1088, 1120, 1152, 1184, 1216)
68 | MAX_SIZE_TRAIN: 2000
69 | MIN_SIZE_TEST: 688
70 | MAX_SIZE_TEST: 4000
71 | CROP:
72 | ENABLED: True
73 | TEST:
74 | AUG:
75 | ENABLED: True
76 | MIN_SIZES: (480, 576, 672, 768, 864, 960, 1056, 1152)
77 | MAX_SIZE: 4000
78 | FLIP: True
79 | EVAL_PERIOD: 5000
80 | EVAL_TRAIN: False
81 | DATASETS:
82 | TRAIN: ('coco_2017_train',)
83 | PROPOSAL_FILES_TRAIN: ('datasets/proposals/sam_coco_2017_train_d2.pkl', )
84 | PRECOMPUTED_PROPOSAL_TOPK_TRAIN: 4000
85 | TEST: ('coco_2017_val',)
86 | PROPOSAL_FILES_TEST: ('datasets/proposals/sam_coco_2017_val_d2.pkl', )
87 | PRECOMPUTED_PROPOSAL_TOPK_TEST: 4000
88 | VIS_PERIOD: 10000
89 | VERSION: 2
--------------------------------------------------------------------------------
/configs/COCO-Detection/WSOVOD_MRRP_V_16_DC5_1x.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "Base-RCNN-DilatedC5.yaml"
2 | MODEL:
3 | WEIGHTS: "models/VGG/VGG_ILSVRC_16_layers_v1_d2.pkl"
4 | PIXEL_MEAN: [103.939, 116.779, 123.68]
5 | BACKBONE:
6 | NAME: "build_mrrp_vgg_backbone"
7 | FREEZE_AT: 5
8 | MRRP:
9 | MRRP_ON: True
10 | NUM_BRANCH: 3
11 | BRANCH_DILATIONS: [1, 2, 4]
12 | TEST_BRANCH_IDX: -1
13 | MRRP_STAGE: "plain5"
14 | VGG:
15 | DEPTH: 16
16 | CONV5_DILATION: 2
17 | ANCHOR_GENERATOR:
18 | SIZES: [[32, 64], [128, 256], [512, 768]]
19 | ASPECT_RATIOS: [[1.0, 2.0, 0.5]]
20 | ROI_HEADS:
21 | IN_FEATURES: ["plain5"]
22 | SOLVER:
23 | STEPS: (140000,)
24 | MAX_ITER: 200000
25 | WARMUP_ITERS: 200
26 | IMS_PER_BATCH: 4
27 | BASE_LR: 0.001
28 | WEIGHT_DECAY: 0.0005
29 | BIAS_LR_FACTOR: 2.0
30 | WEIGHT_DECAY_BIAS: 0.0
--------------------------------------------------------------------------------
/configs/COCO-Detection/WSOVOD_MRRP_WSR_18_DC5_1x.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "Base-RCNN-DilatedC5.yaml"
2 | MODEL:
3 | WEIGHTS: "models/DRN-WSOD/resnet18_ws_model_120_d2.pkl"
4 | BACKBONE:
5 | NAME: "build_mrrp_wsl_resnet_backbone"
6 | FREEZE_AT: 5
7 | MRRP:
8 | MRRP_ON: True
9 | NUM_BRANCH: 3
10 | BRANCH_DILATIONS: [1, 2, 4]
11 | TEST_BRANCH_IDX: -1
12 | MRRP_STAGE: "res5"
13 | RESNETS:
14 | DEPTH: 18
15 | RES5_DILATION: 2
16 | RES2_OUT_CHANNELS: 64
17 | OUT_FEATURES: ["res5"]
18 | ANCHOR_GENERATOR:
19 | SIZES: [[32, 64], [128, 256], [512, 768]]
20 | ASPECT_RATIOS: [[1.0, 2.0, 0.5]]
21 | ROI_HEADS:
22 | IN_FEATURES: ["res5"]
23 | ROI_BOX_HEAD:
24 | POOLER_TYPE: "ROILoopPool"
25 | NUM_CONV: 0
26 | NUM_FC: 2
27 | DAN_DIM: [4096, 4096]
28 | BBOX_REG_LOSS_TYPE: "smooth_l1_weighted"
29 | SOLVER:
30 | STEPS: (140000,)
31 | MAX_ITER: 200000
32 | WARMUP_ITERS: 200
33 | IMS_PER_BATCH: 4
34 | BASE_LR: 0.01
35 | WEIGHT_DECAY: 0.0005
36 | BIAS_LR_FACTOR: 2.0
37 | WEIGHT_DECAY_BIAS: 0.0
--------------------------------------------------------------------------------
/configs/COCO-Detection/WSOVOD_MRRP_WSR_50_DC5_1x.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "Base-RCNN-DilatedC5.yaml"
2 | MODEL:
3 | WEIGHTS: "models/DRN-WSOD/resnet50_ws_model_120_d2.pkl"
4 | BACKBONE:
5 | NAME: "build_mrrp_wsl_resnet_backbone"
6 | FREEZE_AT: 5
7 | MRRP:
8 | MRRP_ON: True
9 | NUM_BRANCH: 3
10 | BRANCH_DILATIONS: [1, 2, 4]
11 | TEST_BRANCH_IDX: -1
12 | MRRP_STAGE: "res5"
13 | RESNETS:
14 | DEPTH: 50
15 | RES5_DILATION: 2
16 | OUT_FEATURES: ["res5"]
17 | ANCHOR_GENERATOR:
18 | SIZES: [[32, 64], [128, 256], [512, 768]]
19 | ASPECT_RATIOS: [[1.0, 2.0, 0.5]]
20 | ROI_HEADS:
21 | IN_FEATURES: ["res5"]
22 | ROI_BOX_HEAD:
23 | POOLER_TYPE: "ROILoopPool"
24 | NUM_CONV: 0
25 | NUM_FC: 2
26 | DAN_DIM: [4096, 4096]
27 | BBOX_REG_LOSS_TYPE: "smooth_l1_weighted"
28 | SOLVER:
29 | STEPS: (140000,)
30 | MAX_ITER: 200000
31 | WARMUP_ITERS: 200
32 | IMS_PER_BATCH: 4
33 | BASE_LR: 0.01
34 | WEIGHT_DECAY: 0.0005
35 | BIAS_LR_FACTOR: 2.0
36 | WEIGHT_DECAY_BIAS: 0.0
--------------------------------------------------------------------------------
/configs/COCO-Detection/WSOVOD_V_16_DC5_1x.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "Base-RCNN-DilatedC5.yaml"
2 | MODEL:
3 | WEIGHTS: "models/VGG/VGG_ILSVRC_16_layers_v1_d2.pkl"
4 | PIXEL_MEAN: [103.939, 116.779, 123.68]
5 | BACKBONE:
6 | NAME: "build_vgg_backbone"
7 | FREEZE_AT: 5
8 | VGG:
9 | DEPTH: 16
10 | CONV5_DILATION: 2
11 | ANCHOR_GENERATOR:
12 | SIZES: [32, 64, 128, 256, 512, 768]
13 | ASPECT_RATIOS: [[1.0, 2.0, 0.5]]
14 | ROI_HEADS:
15 | IN_FEATURES: ["plain5"]
16 | SOLVER:
17 | STEPS: (140000,)
18 | MAX_ITER: 200000
19 | WARMUP_ITERS: 200
20 | IMS_PER_BATCH: 4
21 | BASE_LR: 0.001
22 | WEIGHT_DECAY: 0.0005
23 | BIAS_LR_FACTOR: 2.0
24 | WEIGHT_DECAY_BIAS: 0.0
--------------------------------------------------------------------------------
/configs/COCO-Detection/WSOVOD_WSR_18_DC5_1x.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "Base-RCNN-DilatedC5.yaml"
2 | MODEL:
3 | WEIGHTS: "models/DRN-WSOD/resnet18_ws_model_120_d2.pkl"
4 | BACKBONE:
5 | NAME: "build_wsl_resnet_backbone"
6 | FREEZE_AT: 5
7 | RESNETS:
8 | DEPTH: 18
9 | RES5_DILATION: 2
10 | RES2_OUT_CHANNELS: 64
11 | OUT_FEATURES: ["res5"]
12 | ANCHOR_GENERATOR:
13 | SIZES: [32, 64, 128, 256, 512, 768]
14 | ASPECT_RATIOS: [[1.0, 2.0, 0.5]]
15 | ROI_HEADS:
16 | IN_FEATURES: ["res5"]
17 | SOLVER:
18 | STEPS: (140000,)
19 | MAX_ITER: 200000
20 | WARMUP_ITERS: 200
21 | IMS_PER_BATCH: 4
22 | BASE_LR: 0.01
23 | WEIGHT_DECAY: 0.0005
24 | BIAS_LR_FACTOR: 2.0
25 | WEIGHT_DECAY_BIAS: 0.0
26 |
--------------------------------------------------------------------------------
/configs/COCO-Detection/WSOVOD_WSR_50_DC5_1x.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "Base-RCNN-DilatedC5.yaml"
2 | MODEL:
3 | WEIGHTS: "models/DRN-WSOD/resnet50_ws_model_120_d2.pkl"
4 | BACKBONE:
5 | NAME: "build_wsl_resnet_backbone"
6 | FREEZE_AT: 5
7 | RESNETS:
8 | DEPTH: 50
9 | RES5_DILATION: 2
10 | OUT_FEATURES: ["res5"]
11 | ANCHOR_GENERATOR:
12 | SIZES: [32, 64, 128, 256, 512, 768]
13 | ASPECT_RATIOS: [[1.0, 2.0, 0.5]]
14 | ROI_HEADS:
15 | IN_FEATURES: ["res5"]
16 | SOLVER:
17 | STEPS: (140000,)
18 | MAX_ITER: 200000
19 | WARMUP_ITERS: 200
20 | IMS_PER_BATCH: 4
21 | BASE_LR: 0.01
22 | WEIGHT_DECAY: 0.0005
23 | BIAS_LR_FACTOR: 2.0
24 | WEIGHT_DECAY_BIAS: 0.0
25 |
--------------------------------------------------------------------------------
/configs/ImageNet-Detection/Base-RCNN-DilatedC5.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | META_ARCHITECTURE: "GeneralizedRCNN_WSOVOD"
3 | PIXEL_MEAN: [102.9801, 115.9465, 122.7717]
4 | LOAD_PROPOSALS: True
5 | PROPOSAL_GENERATOR:
6 | NAME: "WSOVODRPN_V2"
7 | MIN_SIZE: 100
8 | RPN:
9 | IN_FEATURES: ["res5"]
10 | PRE_NMS_TOPK_TRAIN: 512
11 | PRE_NMS_TOPK_TEST: 512
12 | POST_NMS_TOPK_TRAIN: 16
13 | POST_NMS_TOPK_TEST: 256
14 | NMS_THRESH: 0.7
15 | BATCH_SIZE_PER_IMAGE: 512
16 | # BATCH_SIZE_PER_IMAGE: 256 # default rpn
17 | POSITIVE_FRACTION: 0.5
18 | BBOX_REG_LOSS_TYPE: "smooth_l1"
19 | IOU_THRESHOLDS: [0.3, 0.7]
20 | # IOU_THRESHOLDS: [0.3, 0.7] # default rpn
21 | ROI_HEADS:
22 | NAME: "WSOVODROIHeads"
23 | NUM_CLASSES: 1000
24 | SCORE_THRESH_TEST: 0.00001
25 | NMS_THRESH_TEST: 0.3
26 | BATCH_SIZE_PER_IMAGE: 4096
27 | POSITIVE_FRACTION: 1.0
28 | PROPOSAL_APPEND_GT: False
29 | ROI_BOX_HEAD:
30 | NAME: "DiscriminativeAdaptationNeck"
31 | POOLER_TYPE: "ROIPool"
32 | NUM_CONV: 0
33 | NUM_FC: 2
34 | DAN_DIM: [4096, 4096]
35 | POOLER_RESOLUTION: 7
36 | BBOX_REG_LOSS_TYPE: "smooth_l1_weighted"
37 | OPEN_VOCABULARY:
38 | WEIGHT_PATH_TRAIN: "models/in1k_text_embedding_single_prompt.pkl"
39 | WEIGHT_PATH_TEST: "models/coco_text_embedding_single_prompt.pkl"
40 | WEIGHT_DIM: 512
41 | USE_BIAS: 0.0
42 | NORM_WEIGHT: True
43 | NORM_TEMP: 50.0
44 | DATA_AWARE: True
45 | WSOVOD:
46 | ITER_SIZE: 1
47 | BBOX_REFINE:
48 | ENABLE: True
49 | OBJECT_MINING:
50 | MEAN_LOSS: True
51 | INSTANCE_REFINEMENT:
52 | REFINE_NUM: 1
53 | REFINE_REG: [True]
54 | SAMPLING:
55 | SAMPLING_ON: True
56 | IOU_THRESHOLDS: [[0.5],]
57 | IOU_LABELS: [[0, 1],]
58 | BATCH_SIZE_PER_IMAGE: [4096,]
59 | POSITIVE_FRACTION: [1.0,]
60 | SOLVER:
61 | IMS_PER_BATCH: 4
62 | BASE_LR: 0.01
63 | STEPS: (140000,)
64 | MAX_ITER: 200000
65 | REFERENCE_WORLD_SIZE: 4
66 | INPUT:
67 | MIN_SIZE_TRAIN: (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800, 832, 864, 896)
68 | MAX_SIZE_TRAIN: 2000
69 | MIN_SIZE_TEST: 512
70 | MAX_SIZE_TEST: 2000
71 | CROP:
72 | ENABLED: True
73 | TEST:
74 | AUG:
75 | ENABLED: True
76 | MIN_SIZES: (480, 576, 672, 768, 864, 960, 1056)
77 | MAX_SIZE: 2000
78 | FLIP: True
79 | EVAL_PERIOD: 5000
80 | EVAL_TRAIN: False
81 | DATASETS:
82 | TRAIN: ('ilsvrc_2012_val',)
83 | PROPOSAL_FILES_TRAIN: ('datasets/proposals/sam_ilsvrc_2012_val_d2.pkl',)
84 | PRECOMPUTED_PROPOSAL_TOPK_TRAIN: 10
85 | TEST: ('coco_2017_val',)
86 | PROPOSAL_FILES_TEST: ('datasets/proposals/sam_coco_2017_val_d2.pkl', )
87 | PRECOMPUTED_PROPOSAL_TOPK_TEST: 4000
88 | VIS_PERIOD: 10000
89 | VERSION: 2
--------------------------------------------------------------------------------
/configs/ImageNet-Detection/WSOVOD_WSR_18_DC5_1x.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "Base-RCNN-DilatedC5.yaml"
2 | MODEL:
3 | WEIGHTS: "models/DRN-WSOD/resnet18_ws_model_120_d2.pkl"
4 | BACKBONE:
5 | NAME: "build_wsl_resnet_backbone"
6 | FREEZE_AT: 5
7 | RESNETS:
8 | DEPTH: 18
9 | RES5_DILATION: 2
10 | RES2_OUT_CHANNELS: 64
11 | OUT_FEATURES: ["res5"]
12 | ANCHOR_GENERATOR:
13 | SIZES: [32, 64, 128, 256, 512, 768]
14 | ASPECT_RATIOS: [[1.0, 2.0, 0.5]]
15 | ROI_HEADS:
16 | IN_FEATURES: ["res5"]
17 | SOLVER:
18 | STEPS: (70000,)
19 | MAX_ITER: 100000
20 | WARMUP_ITERS: 200
21 | IMS_PER_BATCH: 4
22 | BASE_LR: 0.01
23 | WEIGHT_DECAY: 0.0005
24 | BIAS_LR_FACTOR: 2.0
25 | WEIGHT_DECAY_BIAS: 0.0
26 |
--------------------------------------------------------------------------------
/configs/MixedDatasets-Detection/Base-RCNN-DilatedC5.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | META_ARCHITECTURE: "GeneralizedRCNN_WSOVOD_MixedDatasets"
3 | PIXEL_MEAN: [102.9801, 115.9465, 122.7717]
4 | LOAD_PROPOSALS: True
5 | PROPOSAL_GENERATOR:
6 | NAME: "WSOVODRPN_V2"
7 | MIN_SIZE: 40
8 | RPN:
9 | IN_FEATURES: ["res5"]
10 | PRE_NMS_TOPK_TRAIN: 2048
11 | PRE_NMS_TOPK_TEST: 2048
12 | POST_NMS_TOPK_TRAIN: 1024
13 | POST_NMS_TOPK_TEST: 1024
14 | NMS_THRESH: 0.7
15 | BATCH_SIZE_PER_IMAGE: 512
16 | # BATCH_SIZE_PER_IMAGE: 256 # default rpn
17 | POSITIVE_FRACTION: 0.5
18 | BBOX_REG_LOSS_TYPE: "smooth_l1"
19 | IOU_THRESHOLDS: [0.2, 0.6]
20 | # IOU_THRESHOLDS: [0.3, 0.7] # default rpn
21 | ROI_HEADS:
22 | NAME: "WSOVODMixedDatasetsROIHeads"
23 | NUM_CLASSES: 80
24 | SCORE_THRESH_TEST: 0.00001
25 | NMS_THRESH_TEST: 0.3
26 | BATCH_SIZE_PER_IMAGE: 4096
27 | POSITIVE_FRACTION: 1.0
28 | PROPOSAL_APPEND_GT: False
29 | ROI_BOX_HEAD:
30 | NAME: "DiscriminativeAdaptationNeck"
31 | POOLER_TYPE: "ROIPool"
32 | NUM_CONV: 0
33 | NUM_FC: 2
34 | DAN_DIM: [4096, 4096]
35 | POOLER_RESOLUTION: 7
36 | BBOX_REG_LOSS_TYPE: "smooth_l1_weighted"
37 | OPEN_VOCABULARY:
38 | WEIGHT_PATH_TRAIN: "models/coco_text_embedding_single_prompt.pkl"
39 | WEIGHT_PATH_TEST: "models/coco_text_embedding_single_prompt.pkl"
40 | WEIGHT_DIM: 512
41 | USE_BIAS: 0.0
42 | NORM_WEIGHT: True
43 | NORM_TEMP: 50.0
44 | DATA_AWARE: True
45 | WSOVOD:
46 | ITER_SIZE: 1
47 | BBOX_REFINE:
48 | ENABLE: True
49 | OBJECT_MINING:
50 | MEAN_LOSS: True
51 | INSTANCE_REFINEMENT:
52 | REFINE_NUM: 1
53 | REFINE_REG: [True]
54 | SAMPLING:
55 | SAMPLING_ON: True
56 | IOU_THRESHOLDS: [[0.5],]
57 | IOU_LABELS: [[0, 1],]
58 | BATCH_SIZE_PER_IMAGE: [4096,]
59 | POSITIVE_FRACTION: [1.0,]
60 | SOLVER:
61 | IMS_PER_BATCH: 4
62 | IMS_PER_BATCH_LIST: [4,4,4]
63 | BASE_LR: 0.01
64 | STEPS: (140000,)
65 | MAX_ITER: 200000
66 | REFERENCE_WORLD_SIZE: 4
67 | INPUT:
68 | MIN_SIZE_TRAIN: (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800, 832, 864, 896, 928, 960, 992, 1024, 1056, 1088, 1120, 1152, 1184, 1216)
69 | MAX_SIZE_TRAIN: 2000
70 | MIN_SIZE_TEST: 688
71 | MAX_SIZE_TEST: 4000
72 | CROP:
73 | ENABLED: True
74 | TEST:
75 | AUG:
76 | ENABLED: True
77 | MIN_SIZES: (480, 576, 672, 768, 864, 960, 1056, 1152)
78 | MAX_SIZE: 4000
79 | FLIP: True
80 | EVAL_PERIOD: 5000
81 | EVAL_TRAIN: False
82 | DATASETS:
83 | PRECOMPUTED_PROPOSAL_TOPK_TRAIN: 4000
84 | MIXED_DATASETS:
85 | WEIGHT_PATH_TRAINS: ["models/voc_text_embedding_single_prompt.pkl","models/voc_text_embedding_single_prompt.pkl","models/coco_text_embedding_single_prompt.pkl"]
86 | NAMES: ('voc_2007_train', 'voc_2007_val', 'coco_2017_train')
87 | PROPOSAL_FILES: ('datasets/proposals/sam_voc_2007_train_d2.pkl', 'datasets/proposals/sam_voc_2007_val_d2.pkl','datasets/proposals/sam_coco_2017_train_d2.pkl')
88 | NUM_CLASSES: [20,20,80]
89 | FILTER_EMPTY_ANNOTATIONS: [True, True, True]
90 | RATIOS: [1,1,20]
91 | USE_CAS: [False,False,False]
92 | USE_RFS: [False,False,False]
93 | CAS_LAMBDA: 1.0
94 | TEST: ('coco_2017_val',)
95 | PROPOSAL_FILES_TEST: ('datasets/proposals/sam_coco_2017_val_d2.pkl', )
96 | PRECOMPUTED_PROPOSAL_TOPK_TEST: 4000
97 | DATALOADER:
98 | SAMPLER_TRAIN: "MultiDatasetTrainingSampler"
99 | VIS_PERIOD: 10000
100 | VERSION: 2
--------------------------------------------------------------------------------
/configs/MixedDatasets-Detection/WSOVOD_MRRP_WSR_18_DC5_1x_voc07+coco.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "Base-RCNN-DilatedC5.yaml"
2 | MODEL:
3 | WEIGHTS: "models/DRN-WSOD/resnet18_ws_model_120_d2.pkl"
4 | BACKBONE:
5 | NAME: "build_mrrp_wsl_resnet_backbone"
6 | FREEZE_AT: 5
7 | MRRP:
8 | MRRP_ON: True
9 | NUM_BRANCH: 3
10 | BRANCH_DILATIONS: [1, 2, 4]
11 | TEST_BRANCH_IDX: -1
12 | MRRP_STAGE: "res5"
13 | RESNETS:
14 | DEPTH: 18
15 | RES5_DILATION: 2
16 | RES2_OUT_CHANNELS: 64
17 | OUT_FEATURES: ["res5"]
18 | ANCHOR_GENERATOR:
19 | SIZES: [[32, 64], [128, 256], [512, 768]]
20 | ASPECT_RATIOS: [[1.0, 2.0, 0.5]]
21 | ROI_HEADS:
22 | IN_FEATURES: ["res5"]
23 | ROI_BOX_HEAD:
24 | OPEN_VOCABULARY:
25 | DATA_AWARE: True
26 | WEIGHT_PATH_TEST: "models/coco_text_embedding_single_prompt.pkl"
27 | POOLER_TYPE: "ROILoopPool"
28 | NUM_CONV: 0
29 | NUM_FC: 2
30 | DAN_DIM: [4096, 4096]
31 | BBOX_REG_LOSS_TYPE: "smooth_l1_weighted"
32 | SOLVER:
33 | STEPS: (140000,)
34 | MAX_ITER: 200000
35 | WARMUP_ITERS: 200
36 | IMS_PER_BATCH: 4
37 | BASE_LR: 0.01
38 | WEIGHT_DECAY: 0.0005
39 | BIAS_LR_FACTOR: 2.0
40 | WEIGHT_DECAY_BIAS: 0.0
41 | WSOVOD:
42 | BBOX_REFINE:
43 | ENABLE: True
44 | INSTANCE_REFINEMENT:
45 | REFINE_NUM: 1
46 | REFINE_REG: [True]
47 | DATASETS:
48 | PRECOMPUTED_PROPOSAL_TOPK_TRAIN: 4000
49 | MIXED_DATASETS:
50 | WEIGHT_PATH_TRAINS: ["models/voc_text_embedding_single_prompt.pkl","models/voc_text_embedding_single_prompt.pkl","models/coco_text_embedding_single_prompt.pkl"]
51 | NAMES: ('voc_2007_train', 'voc_2007_val', 'coco_2017_train')
52 | PROPOSAL_FILES: ('datasets/proposals/sam_voc_2007_train_d2.pkl', 'datasets/proposals/sam_voc_2007_val_d2.pkl','datasets/proposals/sam_coco_2017_train_d2.pkl')
53 | NUM_CLASSES: [20,20,80]
54 | FILTER_EMPTY_ANNOTATIONS: [True, True, True]
55 | RATIOS: [1,1,20]
56 | USE_CAS: [False,False,False]
57 | USE_RFS: [False,False,False]
58 | CAS_LAMBDA: 1.0
59 | TEST: ('coco_2017_val',)
60 | PROPOSAL_FILES_TEST: ('datasets/proposals/sam_coco_2017_val_d2.pkl', )
61 | PRECOMPUTED_PROPOSAL_TOPK_TEST: 4000
--------------------------------------------------------------------------------
/configs/MixedDatasets-Detection/WSOVOD_MRRP_WSR_50_DC5_1x_voc07+coco.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "Base-RCNN-DilatedC5.yaml"
2 | MODEL:
3 | WEIGHTS: "models/DRN-WSOD/resnet50_ws_model_120_d2.pkl"
4 | BACKBONE:
5 | NAME: "build_mrrp_wsl_resnet_backbone"
6 | FREEZE_AT: 5
7 | MRRP:
8 | MRRP_ON: True
9 | NUM_BRANCH: 3
10 | BRANCH_DILATIONS: [1, 2, 4]
11 | TEST_BRANCH_IDX: -1
12 | MRRP_STAGE: "res5"
13 | RESNETS:
14 | DEPTH: 50
15 | RES5_DILATION: 2
16 | OUT_FEATURES: ["res5"]
17 | ANCHOR_GENERATOR:
18 | SIZES: [[32, 64], [128, 256], [512, 768]]
19 | ASPECT_RATIOS: [[1.0, 2.0, 0.5]]
20 | ROI_HEADS:
21 | IN_FEATURES: ["res5"]
22 | ROI_BOX_HEAD:
23 | OPEN_VOCABULARY:
24 | DATA_AWARE: True
25 | WEIGHT_PATH_TEST: "models/coco_text_embedding_single_prompt.pkl"
26 | POOLER_TYPE: "ROILoopPool"
27 | NUM_CONV: 0
28 | NUM_FC: 2
29 | DAN_DIM: [4096, 4096]
30 | BBOX_REG_LOSS_TYPE: "smooth_l1_weighted"
31 | SOLVER:
32 | STEPS: (140000,)
33 | MAX_ITER: 200000
34 | WARMUP_ITERS: 200
35 | IMS_PER_BATCH: 4
36 | BASE_LR: 0.01
37 | WEIGHT_DECAY: 0.0005
38 | BIAS_LR_FACTOR: 2.0
39 | WEIGHT_DECAY_BIAS: 0.0
40 | WSOVOD:
41 | BBOX_REFINE:
42 | ENABLE: True
43 | INSTANCE_REFINEMENT:
44 | REFINE_NUM: 1
45 | REFINE_REG: [True]
46 | DATASETS:
47 | PRECOMPUTED_PROPOSAL_TOPK_TRAIN: 4000
48 | MIXED_DATASETS:
49 | WEIGHT_PATH_TRAINS: ["models/voc_text_embedding_single_prompt.pkl","models/voc_text_embedding_single_prompt.pkl","models/coco_text_embedding_single_prompt.pkl"]
50 | NAMES: ('voc_2007_train', 'voc_2007_val', 'coco_2017_train')
51 | PROPOSAL_FILES: ('datasets/proposals/sam_voc_2007_train_d2.pkl', 'datasets/proposals/sam_voc_2007_val_d2.pkl','datasets/proposals/sam_coco_2017_train_d2.pkl')
52 | NUM_CLASSES: [20,20,80]
53 | FILTER_EMPTY_ANNOTATIONS: [True, True, True]
54 | RATIOS: [1,1,20]
55 | USE_CAS: [False,False,False]
56 | USE_RFS: [False,False,False]
57 | CAS_LAMBDA: 1.0
58 | TEST: ('coco_2017_val',)
59 | PROPOSAL_FILES_TEST: ('datasets/proposals/sam_coco_2017_val_d2.pkl', )
60 | PRECOMPUTED_PROPOSAL_TOPK_TEST: 4000
--------------------------------------------------------------------------------
/configs/MixedDatasets-Detection/WSOVOD_WSR_18_DC5_1x_voc07+coco.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "Base-RCNN-DilatedC5.yaml"
2 | MODEL:
3 | WEIGHTS: "models/DRN-WSOD/resnet18_ws_model_120_d2.pkl"
4 | BACKBONE:
5 | NAME: "build_wsl_resnet_backbone"
6 | FREEZE_AT: 5
7 | RESNETS:
8 | DEPTH: 18
9 | RES5_DILATION: 2
10 | RES2_OUT_CHANNELS: 64
11 | OUT_FEATURES: ["res5"]
12 | ANCHOR_GENERATOR:
13 | SIZES: [32, 64, 128, 256, 512, 768]
14 | ASPECT_RATIOS: [[1.0, 2.0, 0.5]]
15 | ROI_HEADS:
16 | IN_FEATURES: ["res5"]
17 | ROI_BOX_HEAD:
18 | OPEN_VOCABULARY:
19 | DATA_AWARE: True
20 | WEIGHT_PATH_TEST: "models/coco_text_embedding_single_prompt.pkl"
21 | SOLVER:
22 | STEPS: (140000,)
23 | MAX_ITER: 200000
24 | WARMUP_ITERS: 200
25 | IMS_PER_BATCH: 4
26 | BASE_LR: 0.01
27 | WEIGHT_DECAY: 0.0005
28 | BIAS_LR_FACTOR: 2.0
29 | WEIGHT_DECAY_BIAS: 0.0
30 | WSOVOD:
31 | BBOX_REFINE:
32 | ENABLE: True
33 | INSTANCE_REFINEMENT:
34 | REFINE_NUM: 1
35 | REFINE_REG: [True]
36 | DATASETS:
37 | PRECOMPUTED_PROPOSAL_TOPK_TRAIN: 4000
38 | MIXED_DATASETS:
39 | WEIGHT_PATH_TRAINS: ["models/voc_text_embedding_single_prompt.pkl","models/voc_text_embedding_single_prompt.pkl","models/coco_text_embedding_single_prompt.pkl"]
40 | NAMES: ('voc_2007_train', 'voc_2007_val', 'coco_2017_train')
41 | PROPOSAL_FILES: ('datasets/proposals/sam_voc_2007_train_d2.pkl', 'datasets/proposals/sam_voc_2007_val_d2.pkl','datasets/proposals/sam_coco_2017_train_d2.pkl')
42 | NUM_CLASSES: [20,20,80]
43 | FILTER_EMPTY_ANNOTATIONS: [True, True, True]
44 | RATIOS: [1,1,20]
45 | USE_CAS: [False,False,False]
46 | USE_RFS: [False,False,False]
47 | CAS_LAMBDA: 1.0
48 | TEST: ('coco_2017_val',)
49 | PROPOSAL_FILES_TEST: ('datasets/proposals/sam_coco_2017_val_d2.pkl', )
50 | PRECOMPUTED_PROPOSAL_TOPK_TEST: 4000
--------------------------------------------------------------------------------
/configs/MixedDatasets-Detection/WSOVOD_WSR_50_DC5_1x_voc07+coco.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "Base-RCNN-DilatedC5.yaml"
2 | MODEL:
3 | WEIGHTS: "models/DRN-WSOD/resnet50_ws_model_120_d2.pkl"
4 | BACKBONE:
5 | NAME: "build_wsl_resnet_backbone"
6 | FREEZE_AT: 5
7 | RESNETS:
8 | DEPTH: 50
9 | RES5_DILATION: 2
10 | OUT_FEATURES: ["res5"]
11 | ANCHOR_GENERATOR:
12 | SIZES: [32, 64, 128, 256, 512, 768]
13 | ASPECT_RATIOS: [[1.0, 2.0, 0.5]]
14 | ROI_HEADS:
15 | IN_FEATURES: ["res5"]
16 | ROI_BOX_HEAD:
17 | OPEN_VOCABULARY:
18 | DATA_AWARE: True
19 | WEIGHT_PATH_TEST: "models/coco_text_embedding_single_prompt.pkl"
20 | SOLVER:
21 | STEPS: (140000,)
22 | MAX_ITER: 200000
23 | WARMUP_ITERS: 200
24 | IMS_PER_BATCH: 4
25 | BASE_LR: 0.01
26 | WEIGHT_DECAY: 0.0005
27 | BIAS_LR_FACTOR: 2.0
28 | WEIGHT_DECAY_BIAS: 0.0
29 | WSOVOD:
30 | BBOX_REFINE:
31 | ENABLE: True
32 | INSTANCE_REFINEMENT:
33 | REFINE_NUM: 1
34 | REFINE_REG: [True]
35 | DATASETS:
36 | PRECOMPUTED_PROPOSAL_TOPK_TRAIN: 4000
37 | MIXED_DATASETS:
38 | WEIGHT_PATH_TRAINS: ["models/voc_text_embedding_single_prompt.pkl","models/voc_text_embedding_single_prompt.pkl","models/coco_text_embedding_single_prompt.pkl"]
39 | NAMES: ('voc_2007_train', 'voc_2007_val', 'coco_2017_train')
40 | PROPOSAL_FILES: ('datasets/proposals/sam_voc_2007_train_d2.pkl', 'datasets/proposals/sam_voc_2007_val_d2.pkl','datasets/proposals/sam_coco_2017_train_d2.pkl')
41 | NUM_CLASSES: [20,20,80]
42 | FILTER_EMPTY_ANNOTATIONS: [True, True, True]
43 | RATIOS: [1,1,20]
44 | USE_CAS: [False,False,False]
45 | USE_RFS: [False,False,False]
46 | CAS_LAMBDA: 1.0
47 | TEST: ('coco_2017_val',)
48 | PROPOSAL_FILES_TEST: ('datasets/proposals/sam_coco_2017_val_d2.pkl', )
49 | PRECOMPUTED_PROPOSAL_TOPK_TEST: 4000
--------------------------------------------------------------------------------
/configs/PascalVOC-Detection/Base-RCNN-DilatedC5.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | META_ARCHITECTURE: "GeneralizedRCNN_WSOVOD"
3 | PIXEL_MEAN: [102.9801, 115.9465, 122.7717]
4 | LOAD_PROPOSALS: True
5 | PROPOSAL_GENERATOR:
6 | NAME: "WSOVODRPN_V2"
7 | MIN_SIZE: 40
8 | RPN:
9 | IN_FEATURES: ["res5"]
10 | PRE_NMS_TOPK_TRAIN: 2048
11 | PRE_NMS_TOPK_TEST: 2048
12 | POST_NMS_TOPK_TRAIN: 1024
13 | POST_NMS_TOPK_TEST: 1024
14 | NMS_THRESH: 0.7
15 | BATCH_SIZE_PER_IMAGE: 512
16 | # BATCH_SIZE_PER_IMAGE: 256 # default rpn
17 | POSITIVE_FRACTION: 0.5
18 | BBOX_REG_LOSS_TYPE: "smooth_l1"
19 | IOU_THRESHOLDS: [0.2, 0.6]
20 | # IOU_THRESHOLDS: [0.3, 0.7] # default rpn
21 | ROI_HEADS:
22 | NAME: "WSOVODROIHeads"
23 | NUM_CLASSES: 20
24 | SCORE_THRESH_TEST: 0.00001
25 | NMS_THRESH_TEST: 0.3
26 | BATCH_SIZE_PER_IMAGE: 4096
27 | POSITIVE_FRACTION: 1.0
28 | PROPOSAL_APPEND_GT: False
29 | ROI_BOX_HEAD:
30 | NAME: "DiscriminativeAdaptationNeck"
31 | POOLER_TYPE: "ROIPool"
32 | NUM_CONV: 0
33 | NUM_FC: 2
34 | DAN_DIM: [4096, 4096]
35 | POOLER_RESOLUTION: 7
36 | BBOX_REG_LOSS_TYPE: "smooth_l1_weighted"
37 | OPEN_VOCABULARY:
38 | WEIGHT_PATH_TRAIN: "models/voc_text_embedding_single_prompt.pkl"
39 | WEIGHT_PATH_TEST: "models/voc_text_embedding_single_prompt.pkl"
40 | WEIGHT_DIM: 512
41 | USE_BIAS: 0.0
42 | NORM_WEIGHT: True
43 | NORM_TEMP: 50.0
44 | DATA_AWARE: True
45 | WSOVOD:
46 | ITER_SIZE: 1
47 | BBOX_REFINE:
48 | ENABLE: True
49 | OBJECT_MINING:
50 | MEAN_LOSS: True
51 | INSTANCE_REFINEMENT:
52 | REFINE_NUM: 1
53 | REFINE_REG: [True]
54 | SAMPLING:
55 | SAMPLING_ON: True
56 | IOU_THRESHOLDS: [[0.5],]
57 | IOU_LABELS: [[0, 1],]
58 | BATCH_SIZE_PER_IMAGE: [4096,]
59 | POSITIVE_FRACTION: [1.0,]
60 | SOLVER:
61 | IMS_PER_BATCH: 4
62 | BASE_LR: 0.02
63 | STEPS: (60000, 80000)
64 | MAX_ITER: 90000
65 | REFERENCE_WORLD_SIZE: 4
66 | INPUT:
67 | MIN_SIZE_TRAIN: (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800, 832, 864, 896, 928, 960, 992, 1024, 1056, 1088, 1120, 1152, 1184, 1216)
68 | MAX_SIZE_TRAIN: 2000
69 | MIN_SIZE_TEST: 688
70 | MAX_SIZE_TEST: 4000
71 | CROP:
72 | ENABLED: True
73 | TEST:
74 | AUG:
75 | ENABLED: True
76 | MIN_SIZES: (480, 576, 672, 768, 864, 960, 1056, 1152)
77 | MAX_SIZE: 4000
78 | FLIP: True
79 | EVAL_PERIOD: 5000
80 | EVAL_TRAIN: True
81 | DATASETS:
82 | TRAIN: ('voc_2007_train', 'voc_2007_val')
83 | PROPOSAL_FILES_TRAIN: ('datasets/proposals/sam_voc_2007_train_d2.pkl', 'datasets/proposals/sam_voc_2007_val_d2.pkl')
84 | PRECOMPUTED_PROPOSAL_TOPK_TRAIN: 4000
85 | TEST: ('voc_2007_test',)
86 | PROPOSAL_FILES_TEST: ('datasets/proposals/sam_voc_2007_test_d2.pkl', )
87 | PRECOMPUTED_PROPOSAL_TOPK_TEST: 4000
88 | VIS_PERIOD: 5000
89 | VERSION: 2
--------------------------------------------------------------------------------
/configs/PascalVOC-Detection/WSOVOD_MRRP_V_16_DC5_1x.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "Base-RCNN-DilatedC5.yaml"
2 | MODEL:
3 | WEIGHTS: "models/VGG/VGG_ILSVRC_16_layers_v1_d2.pkl"
4 | PIXEL_MEAN: [103.939, 116.779, 123.68]
5 | BACKBONE:
6 | NAME: "build_mrrp_vgg_backbone"
7 | FREEZE_AT: 5
8 | MRRP:
9 | MRRP_ON: True
10 | NUM_BRANCH: 3
11 | BRANCH_DILATIONS: [1, 2, 4]
12 | TEST_BRANCH_IDX: -1
13 | MRRP_STAGE: "plain5"
14 | VGG:
15 | DEPTH: 16
16 | CONV5_DILATION: 2
17 | ANCHOR_GENERATOR:
18 | SIZES: [[32, 64], [128, 256], [512, 768]]
19 | ASPECT_RATIOS: [[1.0, 2.0, 0.5]]
20 | ROI_HEADS:
21 | IN_FEATURES: ["plain5"]
22 | SOLVER:
23 | STEPS: (70000,)
24 | MAX_ITER: 100000
25 | WARMUP_ITERS: 200
26 | IMS_PER_BATCH: 4
27 | BASE_LR: 0.001
28 | WEIGHT_DECAY: 0.0005
29 | BIAS_LR_FACTOR: 2.0
30 | WEIGHT_DECAY_BIAS: 0.0
31 |
--------------------------------------------------------------------------------
/configs/PascalVOC-Detection/WSOVOD_MRRP_V_16_DC5_VOC12_1x.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "WSOVOD_MRRP_V_16_DC5_1x.yaml"
2 | DATASETS:
3 | TRAIN: ('voc_2012_train', 'voc_2012_val')
4 | PROPOSAL_FILES_TRAIN: ('datasets/proposals/sam_voc_2012_train_d2.pkl', 'datasets/proposals/sam_voc_2012_val_d2.pkl')
5 | TEST: ('voc_2007_test',)
6 | PROPOSAL_FILES_TEST: ('datasets/proposals/sam_voc_2007_test_d2.pkl', )
--------------------------------------------------------------------------------
/configs/PascalVOC-Detection/WSOVOD_MRRP_WSR_18_DC5_1x.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "Base-RCNN-DilatedC5.yaml"
2 | MODEL:
3 | WEIGHTS: "models/DRN-WSOD/resnet18_ws_model_120_d2.pkl"
4 | BACKBONE:
5 | NAME: "build_mrrp_wsl_resnet_backbone"
6 | FREEZE_AT: 5
7 | MRRP:
8 | MRRP_ON: True
9 | NUM_BRANCH: 3
10 | BRANCH_DILATIONS: [1, 2, 4]
11 | TEST_BRANCH_IDX: -1
12 | MRRP_STAGE: "res5"
13 | RESNETS:
14 | DEPTH: 18
15 | RES5_DILATION: 2
16 | RES2_OUT_CHANNELS: 64
17 | OUT_FEATURES: ["res5"]
18 | ANCHOR_GENERATOR:
19 | SIZES: [[32, 64], [128, 256], [512, 768]]
20 | ASPECT_RATIOS: [[1.0, 2.0, 0.5]]
21 | ROI_HEADS:
22 | IN_FEATURES: ["res5"]
23 | ROI_BOX_HEAD:
24 | POOLER_TYPE: "ROILoopPool"
25 | NUM_CONV: 0
26 | NUM_FC: 2
27 | DAN_DIM: [4096, 4096]
28 | BBOX_REG_LOSS_TYPE: "smooth_l1_weighted"
29 | SOLVER:
30 | STEPS: (70000,)
31 | MAX_ITER: 100000
32 | WARMUP_ITERS: 200
33 | IMS_PER_BATCH: 4
34 | BASE_LR: 0.01
35 | WEIGHT_DECAY: 0.0005
36 | BIAS_LR_FACTOR: 2.0
37 | WEIGHT_DECAY_BIAS: 0.0
38 |
--------------------------------------------------------------------------------
/configs/PascalVOC-Detection/WSOVOD_MRRP_WSR_18_DC5_VOC12_1x.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "WSOVOD_MRRP_WSR_18_DC5_1x.yaml"
2 | DATASETS:
3 | TRAIN: ('voc_2012_train', 'voc_2012_val')
4 | PROPOSAL_FILES_TRAIN: ('datasets/proposals/sam_voc_2012_train_d2.pkl', 'datasets/proposals/sam_voc_2012_val_d2.pkl')
5 | TEST: ('voc_2007_test',)
6 | PROPOSAL_FILES_TEST: ('datasets/proposals/sam_voc_2007_test_d2.pkl', )
--------------------------------------------------------------------------------
/configs/PascalVOC-Detection/WSOVOD_MRRP_WSR_50_DC5_1x.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "Base-RCNN-DilatedC5.yaml"
2 | MODEL:
3 | WEIGHTS: "models/DRN-WSOD/resnet50_ws_model_120_d2.pkl"
4 | BACKBONE:
5 | NAME: "build_mrrp_wsl_resnet_backbone"
6 | FREEZE_AT: 5
7 | MRRP:
8 | MRRP_ON: True
9 | NUM_BRANCH: 3
10 | BRANCH_DILATIONS: [1, 2, 4]
11 | TEST_BRANCH_IDX: -1
12 | MRRP_STAGE: "res5"
13 | RESNETS:
14 | DEPTH: 50
15 | RES5_DILATION: 2
16 | OUT_FEATURES: ["res5"]
17 | ANCHOR_GENERATOR:
18 | SIZES: [[32, 64], [128, 256], [512, 768]]
19 | ASPECT_RATIOS: [[1.0, 2.0, 0.5]]
20 | ROI_HEADS:
21 | IN_FEATURES: ["res5"]
22 | ROI_BOX_HEAD:
23 | POOLER_TYPE: "ROILoopPool"
24 | NUM_CONV: 0
25 | NUM_FC: 2
26 | DAN_DIM: [4096, 4096]
27 | BBOX_REG_LOSS_TYPE: "smooth_l1_weighted"
28 | SOLVER:
29 | STEPS: (70000,)
30 | MAX_ITER: 100000
31 | WARMUP_ITERS: 200
32 | IMS_PER_BATCH: 4
33 | BASE_LR: 0.01
34 | WEIGHT_DECAY: 0.0005
35 | BIAS_LR_FACTOR: 2.0
36 | WEIGHT_DECAY_BIAS: 0.0
37 |
--------------------------------------------------------------------------------
/configs/PascalVOC-Detection/WSOVOD_MRRP_WSR_50_DC5_VOC12_1x.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "WSOVOD_MRRP_WSR_50_DC5_1x.yaml"
2 | DATASETS:
3 | TRAIN: ('voc_2012_train', 'voc_2012_val')
4 | PROPOSAL_FILES_TRAIN: ('datasets/proposals/sam_voc_2012_train_d2.pkl', 'datasets/proposals/sam_voc_2012_val_d2.pkl')
5 | TEST: ('voc_2007_test',)
6 | PROPOSAL_FILES_TEST: ('datasets/proposals/sam_voc_2007_test_d2.pkl', )
--------------------------------------------------------------------------------
/configs/PascalVOC-Detection/WSOVOD_V_16_DC5_1x.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "Base-RCNN-DilatedC5.yaml"
2 | MODEL:
3 | WEIGHTS: "models/VGG/VGG_ILSVRC_16_layers_v1_d2.pkl"
4 | PIXEL_MEAN: [103.939, 116.779, 123.68]
5 | BACKBONE:
6 | NAME: "build_vgg_backbone"
7 | FREEZE_AT: 5
8 | VGG:
9 | DEPTH: 16
10 | CONV5_DILATION: 2
11 | ANCHOR_GENERATOR:
12 | SIZES: [32, 64, 128, 256, 512, 768]
13 | ASPECT_RATIOS: [[1.0, 2.0, 0.5]]
14 | ROI_HEADS:
15 | IN_FEATURES: ["plain5"]
16 | SOLVER:
17 | STEPS: (70000,)
18 | MAX_ITER: 100000
19 | WARMUP_ITERS: 200
20 | IMS_PER_BATCH: 4
21 | BASE_LR: 0.001
22 | WEIGHT_DECAY: 0.0005
23 | BIAS_LR_FACTOR: 2.0
24 | WEIGHT_DECAY_BIAS: 0.0
25 |
--------------------------------------------------------------------------------
/configs/PascalVOC-Detection/WSOVOD_V_16_DC5_VOC12_1x.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "WSOVOD_V_16_DC5_1x.yaml"
2 | DATASETS:
3 | TRAIN: ('voc_2012_train', 'voc_2012_val')
4 | PROPOSAL_FILES_TRAIN: ('datasets/proposals/sam_voc_2012_train_d2.pkl', 'datasets/proposals/sam_voc_2012_val_d2.pkl')
5 | TEST: ('voc_2007_test',)
6 | PROPOSAL_FILES_TEST: ('datasets/proposals/sam_voc_2007_test_d2.pkl', )
--------------------------------------------------------------------------------
/configs/PascalVOC-Detection/WSOVOD_WSR_18_DC5_1x.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "Base-RCNN-DilatedC5.yaml"
2 | MODEL:
3 | WEIGHTS: "models/DRN-WSOD/resnet18_ws_model_120_d2.pkl"
4 | BACKBONE:
5 | NAME: "build_wsl_resnet_backbone"
6 | FREEZE_AT: 5
7 | RESNETS:
8 | DEPTH: 18
9 | RES5_DILATION: 2
10 | RES2_OUT_CHANNELS: 64
11 | OUT_FEATURES: ["res5"]
12 | ANCHOR_GENERATOR:
13 | SIZES: [32, 64, 128, 256, 512, 768]
14 | ASPECT_RATIOS: [[1.0, 2.0, 0.5]]
15 | ROI_HEADS:
16 | IN_FEATURES: ["res5"]
17 | SOLVER:
18 | STEPS: (70000,)
19 | MAX_ITER: 100000
20 | WARMUP_ITERS: 200
21 | IMS_PER_BATCH: 4
22 | BASE_LR: 0.01
23 | WEIGHT_DECAY: 0.0005
24 | BIAS_LR_FACTOR: 2.0
25 | WEIGHT_DECAY_BIAS: 0.0
26 |
--------------------------------------------------------------------------------
/configs/PascalVOC-Detection/WSOVOD_WSR_18_DC5_VOC12_1x.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "WSOVOD_WSR_18_DC5_1x.yaml"
2 | DATASETS:
3 | TRAIN: ('voc_2012_train', 'voc_2012_val')
4 | PROPOSAL_FILES_TRAIN: ('datasets/proposals/sam_voc_2012_train_d2.pkl', 'datasets/proposals/sam_voc_2012_val_d2.pkl')
5 | TEST: ('voc_2007_test',)
6 | PROPOSAL_FILES_TEST: ('datasets/proposals/sam_voc_2007_test_d2.pkl', )
--------------------------------------------------------------------------------
/configs/PascalVOC-Detection/WSOVOD_WSR_50_DC5_1x.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "Base-RCNN-DilatedC5.yaml"
2 | MODEL:
3 | WEIGHTS: "models/DRN-WSOD/resnet50_ws_model_120_d2.pkl"
4 | BACKBONE:
5 | NAME: "build_wsl_resnet_backbone"
6 | FREEZE_AT: 5
7 | RESNETS:
8 | DEPTH: 50
9 | RES5_DILATION: 2
10 | OUT_FEATURES: ["res5"]
11 | ANCHOR_GENERATOR:
12 | SIZES: [32, 64, 128, 256, 512, 768]
13 | ASPECT_RATIOS: [[1.0, 2.0, 0.5]]
14 | ROI_HEADS:
15 | IN_FEATURES: ["res5"]
16 | SOLVER:
17 | STEPS: (70000,)
18 | MAX_ITER: 100000
19 | WARMUP_ITERS: 200
20 | IMS_PER_BATCH: 4
21 | BASE_LR: 0.01
22 | WEIGHT_DECAY: 0.0005
23 | BIAS_LR_FACTOR: 2.0
24 | WEIGHT_DECAY_BIAS: 0.0
25 |
--------------------------------------------------------------------------------
/configs/PascalVOC-Detection/WSOVOD_WSR_50_DC5_VOC12_1x.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "WSOVOD_WSR_50_DC5_1x.yaml"
2 | DATASETS:
3 | TRAIN: ('voc_2012_train', 'voc_2012_val')
4 | PROPOSAL_FILES_TRAIN: ('datasets/proposals/sam_voc_2012_train_d2.pkl', 'datasets/proposals/sam_voc_2012_val_d2.pkl')
5 | TEST: ('voc_2007_test',)
6 | PROPOSAL_FILES_TEST: ('datasets/proposals/sam_voc_2007_test_d2.pkl', )
--------------------------------------------------------------------------------
/datasets/README.md:
--------------------------------------------------------------------------------
1 | # Use Builtin Datasets
2 |
3 | A dataset can be used by accessing [DatasetCatalog](https://detectron2.readthedocs.io/modules/data.html#detectron2.data.DatasetCatalog)
4 | for its data, or [MetadataCatalog](https://detectron2.readthedocs.io/modules/data.html#detectron2.data.MetadataCatalog) for its metadata (class names, etc).
5 | This document explains how to set up the builtin datasets so they can be used by the above APIs.
6 | [Use Custom Datasets](https://detectron2.readthedocs.io/tutorials/datasets.html) gives a deeper dive on how to use `DatasetCatalog` and `MetadataCatalog`,
7 | and how to add new datasets to them.
8 |
9 | Detectron2 has builtin support for a few datasets.
10 | The datasets are assumed to exist in a directory specified by the environment variable
11 | `DETECTRON2_DATASETS`.
12 | Under this directory, detectron2 will look for datasets in the structure described below, if needed.
13 | ```
14 | $DETECTRON2_DATASETS/
15 | coco/
16 | lvis/
17 | ILSVRC2012/
18 | VOC20{07,12}/
19 | ```
20 |
21 | You can set the location for builtin datasets by `export DETECTRON2_DATASETS=/path/to/datasets`.
22 | If left unset, the default is `./datasets` relative to your current working directory.
23 |
24 | ## Expected dataset structure for COCO instance/keypoint detection:
25 |
26 | ```
27 | coco/
28 | annotations/
29 | instances_{train,val}2017.json
30 | person_keypoints_{train,val}2017.json
31 | {train,val}2017/
32 | # image files that are mentioned in the corresponding json
33 | ```
34 |
35 | You can use the 2014 version of the dataset as well.
36 |
37 | Some of the builtin tests (`dev/run_*_tests.sh`) use a tiny version of the COCO dataset,
38 | which you can download with `./prepare_for_tests.sh`.
39 |
40 |
41 | ## Expected dataset structure for LVIS instance segmentation:
42 | ```
43 | coco/
44 | {train,val,test}2017/
45 | lvis/
46 | lvis_v0.5_{train,val}.json
47 | lvis_v0.5_image_info_test.json
48 | lvis_v1_{train,val}.json
49 | lvis_v1_image_info_test{,_challenge}.json
50 | ```
51 |
52 | Install lvis-api by:
53 | ```
54 | pip install git+https://github.com/lvis-dataset/lvis-api.git
55 | ```
56 |
57 | To evaluate models trained on the COCO dataset using LVIS annotations,
58 | run `python prepare_cocofied_lvis.py` to prepare "cocofied" LVIS annotations.
59 |
60 | ## Expected dataset structure for Pascal VOC:
61 | ```
62 | VOC20{07,12}/
63 | Annotations/
64 | ImageSets/
65 | Main/
66 | trainval.txt
67 | test.txt
68 | # train.txt or val.txt, if you use these splits
69 | JPEGImages/
70 | ```
71 |
72 | ## Expected dataset structure for ILSVRC2012:
73 | Go to [this link](https://www.image-net.org/download-images.php) to download tar files (for training and validation)
74 | ```
75 | ├── ILSVRC2012_img_train.tar
76 | └── ILSVRC2012_img_val.tar
77 | ```
78 | Run `bash scripts/extract_ilsvrc.sh` to extract the above archives.
79 | Make sure they are arranged as below:
80 | ```
81 | ./train
82 | ├── n07693725
83 | ├── ...
84 | └── n07614500
85 | ./val
86 | ├── n01440764
87 | ├── ...
88 | └── n04458633
89 | ```
90 | Run the scripts below to generate the JSON annotations.
91 | ```
92 | bash scripts/prepare_ilsvrc.sh datasets/ILSVRC2012/val/ output/temp/ilsvrc_2012_val_info.json datasets/ILSVRC2012/ILSVRC2012_img_val.json tools/ilsvrc2012_classes_name.txt datasets/ILSVRC2012/ILSVRC2012_img_val_converted.json
93 |
94 | bash scripts/prepare_ilsvrc.sh datasets/ILSVRC2012/train/ output/temp/ilsvrc_2012_train_info.json datasets/ILSVRC2012/ILSVRC2012_img_train.json tools/ilsvrc2012_classes_name.txt datasets/ILSVRC2012/ILSVRC2012_img_train_converted.json
95 | ```
96 |
97 | ```
98 | ILSVRC2012/
99 | ILSVRC2012_img_train_converted.json
100 | ILSVRC2012_img_val_converted.json
101 | {train,val}/
102 | n01440764/*.JPEG # image files that are mentioned in the corresponding json
103 | ......
104 | n15075141/*.JPEG # image files that are mentioned in the corresponding json
105 |
106 | ```
--------------------------------------------------------------------------------
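As a quick check of the layout above, the snippet below (a minimal sketch, not part of the repository) queries the two catalogs mentioned at the top of `datasets/README.md`. It assumes `DETECTRON2_DATASETS` points at the directory structure described there and that the `wsovod` package has been installed, so that importing `wsovod.data` registers the builtin datasets as its `__init__.py` states.

```python
# Minimal sketch: inspect a registered dataset and its metadata.
# Assumes DETECTRON2_DATASETS points at the layout described in datasets/README.md.
from detectron2.data import DatasetCatalog, MetadataCatalog

import wsovod.data  # importing this package registers the builtin datasets

dataset_dicts = DatasetCatalog.get("voc_2007_test")  # list of per-image dicts
metadata = MetadataCatalog.get("voc_2007_test")      # class names, evaluator type, ...

print(len(dataset_dicts), "images")
print(metadata.thing_classes[:5])
```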
/requirements.txt:
--------------------------------------------------------------------------------
1 | git+https://github.com/openai/CLIP.git
2 | git+https://github.com/facebookresearch/segment-anything.git
3 | git+https://github.com/facebookresearch/detectron2.git
4 | opencv-python
5 | scikit-image
6 | shapely
7 | graphviz
8 | timm
9 | opencv-contrib-python
10 | Pillow==9.5.0
--------------------------------------------------------------------------------
/scripts/extract_ilsvrc.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #
3 | # script to extract ImageNet dataset
4 | # ILSVRC2012_img_train.tar (about 138 GB)
5 | # ILSVRC2012_img_val.tar (about 6.3 GB)
6 | # make sure ILSVRC2012_img_train.tar & ILSVRC2012_img_val.tar are in your current directory
7 | #
8 | # https://github.com/facebook/fb.resnet.torch/blob/master/INSTALL.md
9 | #
10 | # train/
11 | # ├── n01440764
12 | # │ ├── n01440764_10026.JPEG
13 | # │ ├── n01440764_10027.JPEG
14 | # │ ├── ......
15 | # ├── ......
16 | # val/
17 | # ├── n01440764
18 | # │ ├── ILSVRC2012_val_00000293.JPEG
19 | # │ ├── ILSVRC2012_val_00002138.JPEG
20 | # │ ├── ......
21 | # ├── ......
22 | #
23 | #
24 | # Extract the training data:
25 | #
26 | mkdir train && mv ILSVRC2012_img_train.tar train/ && cd train
27 | tar -xvf ILSVRC2012_img_train.tar && rm -f ILSVRC2012_img_train.tar
28 | find . -name "*.tar" | while read NAME ; do mkdir -p "${NAME%.tar}"; tar -xvf "${NAME}" -C "${NAME%.tar}"; rm -f "${NAME}"; done
29 | cd ..
30 | #
31 | # Extract the validation data and move images to subfolders:
32 | #
33 | mkdir val && mv ILSVRC2012_img_val.tar val/ && cd val && tar -xvf ILSVRC2012_img_val.tar
34 | wget -qO- https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh | bash
35 | #
36 | # Check the total number of files after extraction
37 | #
38 | # $ find train/ -name "*.JPEG" | wc -l
39 | # 1281167
40 | # $ find val/ -name "*.JPEG" | wc -l
41 | # 50000
42 | #
--------------------------------------------------------------------------------
/scripts/generate_sam_proposals_cuda.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # Assign the first argument to GPUS, the number of GPUs to use.
4 | GPUS=$1
5 |
6 | # Set the number of nodes, defaulting to 1 if not set.
7 | NNODES=${NNODES:-1}
8 |
9 | # Set the rank of the node, defaulting to 0 if not set.
10 | NODE_RANK=${NODE_RANK:-0}
11 |
12 | # Set the port, defaulting to a random value within a specific range if not set.
13 | PORT=${PORT:-$((28500 + $RANDOM % 2000))}
14 |
15 | # Set the master address, defaulting to localhost if not set.
16 | MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
17 |
18 | # Check if torchrun is available.
19 | if command -v torchrun &> /dev/null
20 | then
21 | echo "Using torchrun mode."
22 | # Set environment variables for Python path and thread settings, then execute the Python script with torchrun.
23 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH OMP_NUM_THREADS=1 MKL_NUM_THREADS=1 \
24 | torchrun --nnodes=${NNODES} \
25 | --node_rank=${NODE_RANK} \
26 | --master_addr=${MASTER_ADDR} \
27 | --master_port=${PORT} \
28 | --nproc_per_node=${GPUS} \
29 | tools/generate_sam_proposals_cuda.py "${@:2}"
30 | else
31 | echo "Using launch mode."
32 | # Fallback to using python -m torch.distributed.launch if torchrun is not available.
33 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH OMP_NUM_THREADS=1 MKL_NUM_THREADS=1 \
34 | python -m torch.distributed.launch \
35 | --nnodes=${NNODES} \
36 | --node_rank=${NODE_RANK} \
37 | --master_addr=${MASTER_ADDR} \
38 | --master_port=${PORT} \
39 | --nproc_per_node=${GPUS} \
40 | tools/generate_sam_proposals_cuda.py "${@:2}"
41 | fi
--------------------------------------------------------------------------------
/scripts/prepare_ilsvrc.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | set -e
4 | set -x
5 | temp="output/temp"
6 | if [ ! -d "$temp" ]; then
7 |   # If the directory does not exist, create it recursively with mkdir -p
8 | mkdir -p "$temp"
9 | echo "Directory $temp created."
10 | else
11 | echo "Directory $temp already exists."
12 | fi
13 |
14 | img_root="$1"
15 | info_json="$2"
16 | out_file="$3"
17 | class_name="$4"
18 | converted_file="$5"
19 |
20 | python tools/ilsvrc_info.py --img-root ${img_root} --out-file ${info_json}
21 | python tools/ilsvrc_folder.py --out-file ${out_file} --info-json ${info_json}
22 | python tools/convert_ilsvrc_classes_name.py --ann ${out_file} --f ${class_name} --output ${converted_file}
23 |
--------------------------------------------------------------------------------
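The three tools chained above emit a COCO-style JSON. The sketch below (illustrative only; it uses the validation-set path from `datasets/README.md`) loads the converted file and prints a few counts to confirm the conversion worked.

```python
# Minimal sanity check of the converted ILSVRC annotations
# (COCO-style JSON with "images", "annotations" and "categories" keys).
import json

with open("datasets/ILSVRC2012/ILSVRC2012_img_val_converted.json") as f:
    ann = json.load(f)

print(len(ann["images"]), "images")
print(len(ann["annotations"]), "annotations")
print(len(ann["categories"]), "classes, e.g.", ann["categories"][:3])
```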
/scripts/train_script.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | set -e
4 | set -x
5 |
6 | train_file_path="$1"
7 | config_file_path="$2"
8 | GPU_NUM="$3"
9 | timestamp="$4"
10 | rest_args="${@:5}"
11 | PORT=${PORT:-$((28500 + $RANDOM % 2000))}
12 |
13 | if [ -z "$timestamp" ]
14 | then
15 | timestamp="`date +'%Y%m%d_%H%M%S'`"
16 | fi
17 |
18 | python ${train_file_path} --dist-url="tcp://127.0.0.1:${PORT}" --num-gpus ${GPU_NUM} --resume --config-file ${config_file_path} OUTPUT_DIR output/${config_file_path:0:-5}_${timestamp} ${rest_args}
19 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 |
4 | import glob
5 | import os
6 | import shutil
7 | from os import path
8 | from typing import List
9 |
10 | import torch
11 | from setuptools import find_packages, setup
12 | from torch.utils.cpp_extension import CUDA_HOME, CppExtension, CUDAExtension
13 |
14 | torch_ver = [int(x) for x in torch.__version__.split(".")[:2]]
15 | assert torch_ver >= [1, 8], "Requires PyTorch >= 1.8"
16 |
17 |
18 | def get_version():
19 | init_py_path = path.join(path.abspath(path.dirname(__file__)), "wsovod", "__init__.py")
20 | init_py = open(init_py_path, "r").readlines()
21 | version_line = [l.strip() for l in init_py if l.startswith("__version__")][0]
22 | version = version_line.split("=")[-1].strip().strip("'\"")
23 |
24 | # The following is used to build release packages.
25 | # Users should never use it.
26 | suffix = os.getenv("D2_VERSION_SUFFIX", "")
27 | version = version + suffix
28 | if os.getenv("BUILD_NIGHTLY", "0") == "1":
29 | from datetime import datetime
30 |
31 | date_str = datetime.today().strftime("%y%m%d")
32 | version = version + ".dev" + date_str
33 |
34 | new_init_py = [l for l in init_py if not l.startswith("__version__")]
35 | new_init_py.append('__version__ = "{}"\n'.format(version))
36 | with open(init_py_path, "w") as f:
37 | f.write("".join(new_init_py))
38 | return version
39 |
40 |
41 | def get_extensions():
42 | this_dir = path.dirname(path.abspath(__file__))
43 | extensions_dir = path.join(this_dir, "wsovod", "layers")
44 |
45 | main_source = path.join(extensions_dir, "vision.cpp")
46 | sources = glob.glob(path.join(extensions_dir, "**", "*.cpp"))
47 |
48 | from torch.utils.cpp_extension import ROCM_HOME
49 |
50 | is_rocm_pytorch = (
51 | True if ((torch.version.hip is not None) and (ROCM_HOME is not None)) else False
52 | )
53 | if is_rocm_pytorch:
54 | assert torch_ver >= [1, 8], "ROCM support requires PyTorch >= 1.8!"
55 |
56 | # common code between cuda and rocm platforms, for hipify version [1,0,0] and later.
57 | source_cuda = glob.glob(path.join(extensions_dir, "**", "*.cu")) + glob.glob(
58 | path.join(extensions_dir, "*.cu")
59 | )
60 | sources = [main_source] + sources
61 |
62 | extension = CppExtension
63 |
64 | extra_compile_args = {"cxx": []}
65 | define_macros = []
66 |
67 | if (torch.cuda.is_available() and ((CUDA_HOME is not None) or is_rocm_pytorch)) or os.getenv(
68 | "FORCE_CUDA", "0"
69 | ) == "1":
70 | extension = CUDAExtension
71 | sources += source_cuda
72 |
73 | if not is_rocm_pytorch:
74 | define_macros += [("WITH_CUDA", None)]
75 | extra_compile_args["nvcc"] = [
76 | "-O3",
77 | "-DCUDA_HAS_FP16=1",
78 | "-D__CUDA_NO_HALF_OPERATORS__",
79 | "-D__CUDA_NO_HALF_CONVERSIONS__",
80 | "-D__CUDA_NO_HALF2_OPERATORS__",
81 | ]
82 | else:
83 | define_macros += [("WITH_HIP", None)]
84 | extra_compile_args["nvcc"] = []
85 |
86 | if torch_ver < [1, 7]:
87 | # supported by https://github.com/pytorch/pytorch/pull/43931
88 | CC = os.environ.get("CC", None)
89 | if CC is not None:
90 | extra_compile_args["nvcc"].append("-ccbin={}".format(CC))
91 |
92 | include_dirs = [extensions_dir]
93 |
94 | ext_modules = [
95 | extension(
96 | "wsovod._C",
97 | sources,
98 | include_dirs=include_dirs,
99 | define_macros=define_macros,
100 | extra_compile_args=extra_compile_args,
101 | )
102 | ]
103 |
104 | return ext_modules
105 |
106 |
107 | # For projects that are relatively small and provide features that are very close
108 | # to detectron2's core functionalities, we install them under detectron2.projects
109 | PROJECTS = {}
110 |
111 | setup(
112 | name="wsovod",
113 | version=get_version(),
114 | author="Hunterj Lin",
115 | url="https://github.com/HunterJ-Lin/WSOVOD",
116 |     description="WSOVOD is a next-generation research "
117 | "platform for open-vocabulary object detection and segmentation.",
118 | packages=find_packages(exclude=("configs", "tests*")) + list(PROJECTS.keys()),
119 | package_dir=PROJECTS,
120 | package_data={},
121 | python_requires=">=3.8",
122 | install_requires=["graphviz"],
123 | ext_modules=get_extensions(),
124 | cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
125 | )
--------------------------------------------------------------------------------
/teaser/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HunterJ-Lin/WSOVOD/74e7647faaa0ab5de37a351532f65c57654f66be/teaser/framework.png
--------------------------------------------------------------------------------
/tools/convert_ilsvrc_classes_name.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 |
4 | if __name__ == '__main__':
5 | parser = argparse.ArgumentParser()
6 | parser.add_argument('--ann', type=str, default='datasets/ILSVRC2012/ILSVRC2012_img_val.json')
7 | parser.add_argument('--f', type=str, default='tools/ilsvrc2012_classes_name.txt')
8 | parser.add_argument('--output',type=str, default='datasets/ILSVRC2012/ILSVRC2012_img_val_converted.json')
9 | args = parser.parse_args()
10 | with open(args.f,'r') as f:
11 | lines = f.readlines()
12 | d = {}
13 | for line in lines:
14 | k,v = line.split(':')
15 | d[k.strip()] = v.split(',')[0].strip()
16 |
17 | print('Loading', args.ann)
18 | data = json.load(open(args.ann, 'r'))
19 | data['categories'] = [{'id':cat['id'],'name':d[cat['name']]} for cat in data['categories']]
20 | print('Saving to', args.output)
21 | json.dump(data, open(args.output, 'w'))
--------------------------------------------------------------------------------
/tools/generate_class_text_embedding.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import cv2
4 | from six.moves import cPickle as pickle
5 | from tqdm import tqdm
6 | import numpy as np
7 | from detectron2.data.catalog import DatasetCatalog
8 | from detectron2.utils.file_io import PathManager
9 | import os
10 | import clip
11 | from detectron2.data import MetadataCatalog
12 | import torch
13 |
14 | PROMPTS = [
15 | 'There is a {category} in the scene.',
16 | 'There is my {category} in the scene.',
17 | 'There is the {category} in the scene.',
18 | 'There is one {category} in the scene.',
19 | 'a photo of a {category} in the scene.',
20 | 'a photo of my {category} in the scene.',
21 | 'a photo of the {category} in the scene.',
22 | 'a photo of one {category} in the scene.',
23 | 'itap of a {category}.',
24 | 'itap of my {category}.',
25 | 'itap of the {category}.',
26 | 'itap of one {category}.',
27 | 'a photo of a {category}.',
28 | 'a photo of my {category}.',
29 | 'a photo of the {category}.',
30 | 'a photo of one {category}.',
31 | 'a good photo of a {category}.',
32 | 'a good photo of the {category}.',
33 | 'a bad photo of a {category}.',
34 | 'a bad photo of the {category}.',
35 | 'a photo of a nice {category}.',
36 | 'a photo of the nice {category}.',
37 | 'a photo of a cool {category}.',
38 | 'a photo of the cool {category}.',
39 | 'a photo of a weird {category}.',
40 | 'a photo of the weird {category}.',
41 | 'a photo of a small {category}.',
42 | 'a photo of the small {category}.',
43 | 'a photo of a large {category}.',
44 | 'a photo of the large {category}.',
45 | 'a photo of a clean {category}.',
46 | 'a photo of the clean {category}.',
47 | 'a photo of a dirty {category}.',
48 | 'a photo of the dirty {category}.',
49 | 'a bright photo of a {category}.',
50 | 'a bright photo of the {category}.',
51 | 'a dark photo of a {category}.',
52 | 'a dark photo of the {category}.',
53 | 'a photo of a hard to see {category}.',
54 | 'a photo of the hard to see {category}.',
55 | 'a low resolution photo of a {category}.',
56 | 'a low resolution photo of the {category}.',
57 | 'a cropped photo of a {category}.',
58 | 'a cropped photo of the {category}.',
59 | 'a close-up photo of a {category}.',
60 | 'a close-up photo of the {category}.',
61 | 'a jpeg corrupted photo of a {category}.',
62 | 'a jpeg corrupted photo of the {category}.',
63 | 'a blurry photo of a {category}.',
64 | 'a blurry photo of the {category}.',
65 | 'a pixelated photo of a {category}.',
66 | 'a pixelated photo of the {category}.',
67 | ]
68 |
69 | if __name__ == '__main__':
70 | parser = argparse.ArgumentParser()
71 | parser.add_argument('--dataset-name', type=str, default='coco_2017_val')
72 | parser.add_argument('--categories', type=str, default='')
73 | # parser.add_argument('--mode-type', type=str, default='ViT-L/14/32')
74 | parser.add_argument('--model-type', type=str, default='ViT-B/32')
75 | parser.add_argument('--prompt-type', type=str, default='single')
76 | parser.add_argument('--output', type=str, default='models/coco_text_embedding_single_prompt.pkl')
77 | args = parser.parse_args()
78 | # load clip model
79 | clip_model, clip_preprocess = clip.load(args.model_type, 'cpu', jit=False)
80 | clip_model = clip_model.eval()
81 |     d = {}  # category -> averaged prompt embedding
82 |
83 | thing_classes = MetadataCatalog.get(args.dataset_name).thing_classes
84 | if thing_classes is None:
85 | thing_classes = args.categories.split(',')
86 | assert len(thing_classes)>0
87 |
88 | if args.prompt_type == 'mutiple':
89 | for category in thing_classes:
90 | text_inputs = torch.cat([clip.tokenize(prompt.format(**{'category':category})) for prompt in PROMPTS]).to('cpu')
91 | text_embedding = clip_model.encode_text(text_inputs).float()
92 | d[category] = text_embedding.mean(0)
93 |         text_embeddings = torch.stack([d[c] for c in thing_classes])  # (num_classes, embed_dim)
94 | else:
95 | text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}.") for c in thing_classes]).to('cpu')
96 | text_embeddings = clip_model.encode_text(text_inputs).float()
97 |
98 | print(text_embeddings.shape)
99 | print('save to '+args.output)
100 | with open(args.output, "wb") as f:
101 | pickle.dump(text_embeddings, f, pickle.HIGHEST_PROTOCOL)
--------------------------------------------------------------------------------
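The script saves a single tensor of CLIP text embeddings, one row per class (512-dimensional for ViT-B/32). The sketch below (illustrative, not part of the repository) loads and L2-normalizes it; the default output path matches the default of `DATASETS.MIXED_DATASETS.WEIGHT_PATH_TRAINS` in `wsovod/config/defaults.py`, which is where such files are plugged in.

```python
# Minimal sketch: load the saved class-text embeddings and L2-normalize them.
import pickle
import torch

with open("models/coco_text_embedding_single_prompt.pkl", "rb") as f:
    text_embeddings = pickle.load(f)  # tensor of shape (num_classes, 512) for ViT-B/32

normalized = torch.nn.functional.normalize(text_embeddings, dim=-1)
print(text_embeddings.shape, normalized.norm(dim=-1)[:3])  # rows now have unit norm
```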
/tools/generate_class_text_embedding_cuda.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import cv2
4 | from six.moves import cPickle as pickle
5 | from tqdm import tqdm
6 | import numpy as np
7 | from detectron2.data.catalog import DatasetCatalog
8 | from detectron2.utils.file_io import PathManager
9 | import os
10 | import clip
11 | from detectron2.data import MetadataCatalog
12 | from detectron2.utils.comm import is_main_process, get_world_size, get_local_rank
13 | import torch
14 | import torch.distributed as dist
15 |
16 | PROMPTS = [
17 | 'There is a {category} in the scene.',
18 | 'There is my {category} in the scene.',
19 | 'There is the {category} in the scene.',
20 | 'There is one {category} in the scene.',
21 | 'a photo of a {category} in the scene.',
22 | 'a photo of my {category} in the scene.',
23 | 'a photo of the {category} in the scene.',
24 | 'a photo of one {category} in the scene.',
25 | 'itap of a {category}.',
26 | 'itap of my {category}.',
27 | 'itap of the {category}.',
28 | 'itap of one {category}.',
29 | 'a photo of a {category}.',
30 | 'a photo of my {category}.',
31 | 'a photo of the {category}.',
32 | 'a photo of one {category}.',
33 | 'a good photo of a {category}.',
34 | 'a good photo of the {category}.',
35 | 'a bad photo of a {category}.',
36 | 'a bad photo of the {category}.',
37 | 'a photo of a nice {category}.',
38 | 'a photo of the nice {category}.',
39 | 'a photo of a cool {category}.',
40 | 'a photo of the cool {category}.',
41 | 'a photo of a weird {category}.',
42 | 'a photo of the weird {category}.',
43 | 'a photo of a small {category}.',
44 | 'a photo of the small {category}.',
45 | 'a photo of a large {category}.',
46 | 'a photo of the large {category}.',
47 | 'a photo of a clean {category}.',
48 | 'a photo of the clean {category}.',
49 | 'a photo of a dirty {category}.',
50 | 'a photo of the dirty {category}.',
51 | 'a bright photo of a {category}.',
52 | 'a bright photo of the {category}.',
53 | 'a dark photo of a {category}.',
54 | 'a dark photo of the {category}.',
55 | 'a photo of a hard to see {category}.',
56 | 'a photo of the hard to see {category}.',
57 | 'a low resolution photo of a {category}.',
58 | 'a low resolution photo of the {category}.',
59 | 'a cropped photo of a {category}.',
60 | 'a cropped photo of the {category}.',
61 | 'a close-up photo of a {category}.',
62 | 'a close-up photo of the {category}.',
63 | 'a jpeg corrupted photo of a {category}.',
64 | 'a jpeg corrupted photo of the {category}.',
65 | 'a blurry photo of a {category}.',
66 | 'a blurry photo of the {category}.',
67 | 'a pixelated photo of a {category}.',
68 | 'a pixelated photo of the {category}.',
69 | ]
70 |
71 | if __name__ == '__main__':
72 | parser = argparse.ArgumentParser()
73 | parser.add_argument('--dataset-name', type=str, default='voc_2007_val')
74 | parser.add_argument('--categories', type=str, default='')
75 | parser.add_argument('--bs', type=int, default=32)
76 | # parser.add_argument('--mode-type', type=str, default='ViT-L/14/32')
77 | parser.add_argument('--model-type', type=str, default='ViT-B/32')
78 | parser.add_argument('--prompt-type', type=str, default='single')
79 | parser.add_argument('--output', type=str, default='models/voc_text_embedding_single_prompt_cuda.pkl')
80 | args = parser.parse_args()
81 |
82 | # load clip model
83 | clip_model, clip_preprocess = clip.load(args.model_type, 'cuda', jit=False)
84 | clip_model = clip_model.eval()
85 |
86 | thing_classes = MetadataCatalog.get(args.dataset_name).thing_classes
87 | if thing_classes is None:
88 | thing_classes = args.categories.split(',')
89 | assert len(thing_classes)>0
90 |
91 | descriptions = []
92 | candidates = []
93 | for cls_name in thing_classes:
94 | if args.prompt_type == 'mutiple':
95 | candidates.append(len(PROMPTS))
96 | for template in PROMPTS:
97 | description = template.format(**{'category':cls_name})
98 | descriptions.append(description)
99 | else:
100 | candidates.append(1)
101 | descriptions.append(f"a photo of a {cls_name}.")
102 |
103 | with torch.no_grad():
104 | tot = len(descriptions)
105 | bs = args.bs
106 | nb = tot // bs
107 | if tot % bs != 0:
108 | nb += 1
109 | text_embeddings_list = []
110 | for i in range(nb):
111 | local_descriptions = descriptions[i * bs: (i + 1) * bs]
112 | text_inputs = torch.cat([clip.tokenize(ds) for ds in local_descriptions]).to('cuda')
113 | local_text_embeddings = clip_model.encode_text(text_inputs).to(device='cpu').float()
114 | text_embeddings_list.append(local_text_embeddings)
115 | text_embeddings = torch.cat(text_embeddings_list)
116 |
117 | dim = text_embeddings.shape[-1]
118 | candidate_tot = sum(candidates)
119 | text_embeddings = text_embeddings.split(candidates, dim=0)
120 | if args.prompt_type == 'mutiple':
121 | text_embeddings = [text_embedding.mean(0).unsqueeze(0) for text_embedding in text_embeddings]
122 |
123 | text_embeddings = torch.cat(text_embeddings)
124 | print('save to '+args.output)
125 | with open(args.output, "wb") as f:
126 | pickle.dump(text_embeddings, f, pickle.HIGHEST_PROTOCOL)
--------------------------------------------------------------------------------
/tools/generate_sam_proposals_cuda.py:
--------------------------------------------------------------------------------
1 | from math import e
2 | import cv2
3 | import os
4 | import argparse
5 | import numpy as np
6 | import pickle
7 | from detectron2.data.catalog import DatasetCatalog
8 | import torch
9 | import torch.distributed as dist
10 | from detectron2.utils.comm import synchronize, get_world_size, get_rank
11 | from wsovod.data.datasets import builtin
12 |
13 | # Import DistributedDataParallel
14 | from torch.nn.parallel import DistributedDataParallel as DDP
15 |
16 | def process_image(mask_generator, image_info):
17 | image_path = image_info['file_name']
18 | print("processing ", image_path)
19 | image = cv2.imread(image_path)
20 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
21 | try:
22 | masks = mask_generator.generate(image)
23 | except Exception as e:
24 | print(f"Error processing image {image_path}: {e}")
25 | mask_generator.predictor.model.to('cpu')
26 | masks = mask_generator.generate(image)
27 | mask_generator.predictor.model.to(f'cuda:{get_rank()}')
28 | proposals = []
29 | scores = []
30 | for instance in masks:
31 | score = instance['predicted_iou'] * instance['stability_score']
32 | if score > 1.0:
33 | score = 1.0
34 | bbox = instance['bbox']
35 | if bbox[2] <= 0 or bbox[3] <= 0:
36 | continue
37 | bbox[2] = bbox[0] + bbox[2]
38 | bbox[3] = bbox[1] + bbox[3]
39 | scores.append(score)
40 | proposals.append(bbox)
41 | proposals = np.array(proposals)
42 | scores = np.array(scores)
43 | return proposals, scores, image_info['image_id']
44 |
45 | def main():
46 | parser = argparse.ArgumentParser()
47 | parser.add_argument('--checkpoint', type=str, default='tools/sam_checkpoints/sam_vit_h_4b8939.pth')
48 | parser.add_argument('--model-type', type=str, default='vit_h')
49 | parser.add_argument('--dataset-name', type=str, default='coco_2017_val')
50 | parser.add_argument('--output', type=str, default='datasets/proposals/sam_coco_2017_val_d2.pkl')
51 | parser.add_argument('--points-per-side', type=int, default=32)
52 | parser.add_argument('--pred-iou-thresh', type=float, default=0.86)
53 | parser.add_argument('--stability-score-thresh', type=float, default=0.92)
54 | parser.add_argument('--crop-n-layers', type=int, default=1)
55 | parser.add_argument('--crop-n-points-downscale-factor', type=int, default=2)
56 | parser.add_argument('--min-mask-region-area', type=float, default=20.0)
57 | # Add an argument for the local rank that will be set by torchrun or the PyTorch launcher
58 | parser.add_argument('--local_rank', type=int, default=0)
59 | args = parser.parse_args()
60 | import os
61 | print("Local rank (from environment):", os.environ.get("LOCAL_RANK"))
62 | torch.distributed.init_process_group(backend='nccl', init_method='env://')
63 | dataset_dicts = DatasetCatalog.get(args.dataset_name)
64 | # dataset_dicts = dataset_dicts[:100] for debugging
65 | rank = get_rank()
66 | world_size = get_world_size()
67 | device = torch.device(f'cuda:{rank}')
68 | print(f"Rank: {rank}, World Size: {world_size}, Device: {device}")
69 | from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
70 | sam = sam_model_registry[args.model_type](checkpoint=args.checkpoint)
71 | sam.to(device=device)
72 | mask_generator = SamAutomaticMaskGenerator(
73 | model=sam,
74 | points_per_side=args.points_per_side,
75 | pred_iou_thresh=args.pred_iou_thresh,
76 | stability_score_thresh=args.stability_score_thresh,
77 | crop_n_layers=args.crop_n_layers,
78 | crop_n_points_downscale_factor=args.crop_n_points_downscale_factor,
79 | min_mask_region_area=args.min_mask_region_area, # Requires open-cv to run post-processing
80 | )
81 |
82 | # Now, adjust the data loading and processing to only work on a subset of the data based on the rank of the process
83 | # This is a simple way to split the dataset, more sophisticated methods might be needed for your use case
84 | subset_size = len(dataset_dicts) // world_size
85 | start_idx = rank * subset_size
86 | end_idx = start_idx + subset_size if rank < world_size - 1 else len(dataset_dicts)
87 | local_dataset_dicts = dataset_dicts[start_idx:end_idx]
88 |
89 | all_boxes = []
90 | all_scores = []
91 | all_indexes = []
92 | if rank == 0:
93 | from tqdm import tqdm
94 | local_dataset_dicts = tqdm(local_dataset_dicts)
95 | with torch.no_grad():
96 | for image_info in local_dataset_dicts:
97 | boxes, scores, indexes = process_image(mask_generator, image_info)
98 | all_boxes.append(boxes)
99 | all_scores.append(scores)
100 | all_indexes.append(indexes)
101 |
102 | # Gather results from all processes
103 | gathered_boxes = [None] * world_size
104 | gathered_scores = [None] * world_size
105 | gathered_indexes = [None] * world_size
106 | torch.cuda.set_device(rank)
107 | dist.barrier()
108 | print(f"Rank: {rank} gathering boxes results...")
109 | dist.all_gather_object(gathered_boxes, all_boxes)
110 | print(f"Rank: {rank} gathered boxes.")
111 | dist.barrier()
112 | print(f"Rank: {rank} gathering scores results...")
113 | dist.all_gather_object(gathered_scores, all_scores)
114 | print(f"Rank: {rank} gathered scores.")
115 | dist.barrier()
116 | print(f"Rank: {rank} gathering indexes results...")
117 | dist.all_gather_object(gathered_indexes, all_indexes)
118 | print(f"Rank: {rank} gathered indexes.")
119 | dist.barrier()
120 | # Collecting and saving the results should be done by one process
121 | if rank == 0:
122 | print("All results gathered.")
123 | # Flatten lists from all processes
124 | all_boxes_flat = [item for sublist in gathered_boxes for item in sublist]
125 | all_scores_flat = [item for sublist in gathered_scores for item in sublist]
126 | all_indexes_flat = [item for sublist in gathered_indexes for item in sublist]
127 | assert len(all_boxes_flat) == len(all_scores_flat) == len(all_indexes_flat) == len(dataset_dicts)
128 | assert set(all_indexes_flat) == set([image_info['image_id'] for image_info in dataset_dicts])
129 | # Save only in the master process
130 | output_file = args.output
131 | with open(output_file, 'wb') as f:
132 | pickle.dump({'boxes': all_boxes_flat, 'scores': all_scores_flat, 'indexes': all_indexes_flat}, f)
133 | print("Proposal generation and saving completed.")
134 | # Ensure all processes have finished before exiting
135 | dist.barrier()
136 |
137 | if __name__ == '__main__':
138 | main()
139 |
--------------------------------------------------------------------------------
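The proposal files referenced by `PROPOSAL_FILES_TRAIN`/`PROPOSAL_FILES_TEST` in the configs are the pickles written by this script: per image, an array of XYXY boxes, the matching SAM-derived scores, and the `image_id`. The sketch below (illustrative only; the file name is one of the paths used in the VOC configs) shows how to inspect one.

```python
# Minimal sketch: inspect a SAM proposal file produced by generate_sam_proposals_cuda.py.
import pickle

with open("datasets/proposals/sam_voc_2007_test_d2.pkl", "rb") as f:
    proposals = pickle.load(f)

print(sorted(proposals.keys()))      # ['boxes', 'indexes', 'scores']
print(len(proposals["boxes"]), "images")
print(proposals["boxes"][0].shape,   # (N, 4) XYXY boxes for the first image
      proposals["scores"][0].shape,  # (N,) predicted_iou * stability_score, clipped to 1.0
      proposals["indexes"][0])       # image_id of the first image
```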
/tools/ilsvrc_folder.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | import os.path as osp
5 | import pathlib
6 | import random
7 | import xml.dom.minidom
8 | from distutils.util import strtobool
9 |
10 | import numpy as np
11 | from detectron2.data.detection_utils import read_image
12 | # from nltk.corpus import wordnet as wn
13 | from PIL import Image
14 | from tqdm import tqdm
15 |
16 | # import nltk
17 | # nltk.download('wordnet')
18 |
19 | miss_xml = []
20 | folder_classes = []
21 |
22 | def parse_xml(file_path):
23 | print("processing: ", file_path)
24 |
25 | dom = xml.dom.minidom.parse(file_path)
26 | document = dom.documentElement
27 |
28 | size = document.getElementsByTagName('size')
29 | bboxes = []
30 | labels = []
31 | for item in size:
32 | width = item.getElementsByTagName('width')
33 | width = int(width[0].firstChild.data)
34 | height = item.getElementsByTagName('height')
35 | height = int(height[0].firstChild.data)
36 |
37 | # try:
38 | if True:
39 | object = document.getElementsByTagName('object')
40 | for item in object:
41 | label = item.getElementsByTagName('name')
42 | label = str(label[0].firstChild.data)
43 |
44 | xmin = item.getElementsByTagName('xmin')
45 | xmin = float(xmin[0].firstChild.data)
46 | ymin = item.getElementsByTagName('ymin')
47 | ymin = float(ymin[0].firstChild.data)
48 | xmax = item.getElementsByTagName('xmax')
49 | xmax = float(xmax[0].firstChild.data)
50 | ymax = item.getElementsByTagName('ymax')
51 | ymax = float(ymax[0].firstChild.data)
52 |
53 | bboxes.append([xmin, ymin, xmax, ymax])
54 | labels.append(label)
55 | # except:
56 | # continue
57 |
58 | return bboxes, labels
59 |
60 |
61 | def cvt_annotations(path_h_w, out_file, has_instance, has_segmentation, ann_root):
62 | label_ids = {name: i for i, name in enumerate(folder_classes)}
63 | print('cvt annotations')
64 | print("label_ids: ", label_ids)
65 |
66 | annotations = []
67 | pbar = enumerate(path_h_w)
68 | pbar = tqdm(pbar)
69 | for i, [img_path, height, width] in pbar:
70 | # if i > 10000:
71 | # break
72 | # print(i, img_path)
73 |
74 | wnid = pathlib.PurePath(img_path).parent.name
75 |
76 | if has_instance:
77 | xml_path = ann_root + img_path[:-5] + ".xml"
78 | if os.path.exists(xml_path):
79 | bboxes, wnids = parse_xml(xml_path)
80 | ## check error wnid
81 | # for wnid_ in wnids:
82 | # assert wnid_ == wnid, [wnids, wnid]
83 | # labels = [label_ids[wnid_] for wnid_ in wnids]
84 | labels = [label_ids[wnid] for wnid_ in wnids]
85 | else:
86 | print("xml file not found", xml_path)
87 | miss_xml.append(xml_path)
88 | # continue
89 |                 print("generate pseudo annotation for", img_path)
90 | bboxes = [[1, 1, width, height]]
91 | labels = [label_ids[wnid]]
92 |
93 | bboxes = np.array(bboxes, ndmin=2) - 1
94 | labels = np.array(labels)
95 | else:
96 | bboxes = np.array(np.zeros((0, 4), dtype=np.float32))
97 | labels = np.array(np.zeros((0,), dtype=np.int64))
98 |
99 | annotation = {
100 | "filename": img_path,
101 | "width": width,
102 | "height": height,
103 | "ann": {
104 | "bboxes": bboxes.astype(np.float32),
105 | "labels": labels.astype(np.int64),
106 | "bboxes_ignore": np.zeros((0, 4), dtype=np.float32),
107 | "labels_ignore": np.zeros((0,), dtype=np.int64),
108 | },
109 | }
110 |
111 | annotations.append(annotation)
112 | annotations = cvt_to_coco_json(annotations, has_segmentation)
113 |
114 | with open(out_file, "w") as f:
115 | json.dump(annotations, f)
116 |
117 | return annotations
118 |
119 |
120 | def cvt_to_coco_json(annotations, has_segmentation):
121 | image_id = 0
122 | annotation_id = 0
123 | coco = dict()
124 | coco["images"] = []
125 | coco["type"] = "instance"
126 | coco["categories"] = []
127 | coco["annotations"] = []
128 | image_set = set()
129 | print('cvt to coco json')
130 | def addAnnItem(annotation_id, image_id, category_id, bbox, difficult_flag):
131 | annotation_item = dict()
132 | if has_segmentation:
133 | annotation_item["segmentation"] = []
134 |
135 | seg = []
136 | # bbox[] is x1,y1,x2,y2
137 | # left_top
138 | seg.append(int(bbox[0]))
139 | seg.append(int(bbox[1]))
140 | # left_bottom
141 | seg.append(int(bbox[0]))
142 | seg.append(int(bbox[3]))
143 | # right_bottom
144 | seg.append(int(bbox[2]))
145 | seg.append(int(bbox[3]))
146 | # right_top
147 | seg.append(int(bbox[2]))
148 | seg.append(int(bbox[1]))
149 |
150 | annotation_item["segmentation"].append(seg)
151 |
152 | xywh = np.array([bbox[0], bbox[1], bbox[2] - bbox[0], bbox[3] - bbox[1]])
153 | annotation_item["area"] = int(xywh[2] * xywh[3])
154 | if difficult_flag == 1:
155 | annotation_item["ignore"] = 0
156 | annotation_item["iscrowd"] = 1
157 | else:
158 | annotation_item["ignore"] = 0
159 | annotation_item["iscrowd"] = 0
160 | annotation_item["image_id"] = int(image_id)
161 | annotation_item["bbox"] = xywh.astype(int).tolist()
162 | annotation_item["category_id"] = int(category_id)
163 | annotation_item["id"] = int(annotation_id)
164 | coco["annotations"].append(annotation_item)
165 | return annotation_id + 1
166 |
167 | for category_id, name in enumerate(folder_classes):
168 | category_item = dict()
169 | category_item["supercategory"] = str("none")
170 | category_item["id"] = int(category_id)
171 | category_item["name"] = str(name)
172 | coco["categories"].append(category_item)
173 |
174 | pbar = tqdm(annotations)
175 | for ann_dict in pbar:
176 | file_name = ann_dict["filename"]
177 | ann = ann_dict["ann"]
178 | assert file_name not in image_set
179 | image_item = dict()
180 | image_item["id"] = int(image_id)
181 | image_item["file_name"] = str(file_name)
182 | image_item["height"] = int(ann_dict["height"])
183 | image_item["width"] = int(ann_dict["width"])
184 | coco["images"].append(image_item)
185 | image_set.add(file_name)
186 |
187 | bboxes = ann["bboxes"][:, :4]
188 | labels = ann["labels"]
189 | for bbox_id in range(len(bboxes)):
190 | bbox = bboxes[bbox_id]
191 | label = labels[bbox_id]
192 | annotation_id = addAnnItem(annotation_id, image_id, label, bbox, difficult_flag=0)
193 |
194 | bboxes_ignore = ann["bboxes_ignore"][:, :4]
195 | labels_ignore = ann["labels_ignore"]
196 | for bbox_id in range(len(bboxes_ignore)):
197 | bbox = bboxes_ignore[bbox_id]
198 | label = labels_ignore[bbox_id]
199 | annotation_id = addAnnItem(annotation_id, image_id, label, bbox, difficult_flag=1)
200 |
201 | image_id += 1
202 |
203 | return coco
204 |
205 |
206 | def _strtobool(x):
207 | return bool(strtobool(x))
208 |
209 |
210 | def parse_args():
211 | parser = argparse.ArgumentParser(description="Convert image list to coco format")
212 | parser.add_argument("--ann-root", help="ann root", type=str, default='')
213 | parser.add_argument("--out-file", help="output path", required=True)
214 | parser.add_argument("--info-json", help="output path", required=True)
215 | parser.add_argument('--has-instance', nargs='?', const=True, type=_strtobool, default=True, help="has instance")
216 | parser.add_argument('--has-segmentation', nargs='?', const=True, type=_strtobool, default=True, help="has segmentation")
217 | args = parser.parse_args()
218 | return args
219 |
220 |
221 | def main():
222 |
223 | args = parse_args()
224 | print(args)
225 | ann_root = args.ann_root
226 | out_file = args.out_file
227 |
228 | has_instance = args.has_instance
229 | has_segmentation = args.has_segmentation
230 |
231 | global folder_classes
232 | with open(args.info_json,'r') as f:
233 | d = json.load(f)
234 | folder_classes = d['categories']
235 | path_h_w = d['path_h_w']
236 | annotations = cvt_annotations(path_h_w, out_file, has_instance, has_segmentation, ann_root)
237 | print("Done!")
238 | print('miss xml file:{}'.format(len(miss_xml)))
239 | print(miss_xml)
240 |
241 | with open(out_file, "w") as f:
242 | json.dump(annotations, f)
243 |
244 | print(args)
245 |
246 |
247 | if __name__ == "__main__":
248 | main()
--------------------------------------------------------------------------------
/tools/ilsvrc_info.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | import os.path as osp
5 | import pathlib
6 | import random
7 | import xml.dom.minidom
8 | from distutils.util import strtobool
9 |
10 | import numpy as np
11 | from detectron2.data.detection_utils import read_image
12 | # from nltk.corpus import wordnet as wn
13 | from PIL import Image
14 | from tqdm import tqdm
15 |
16 |
17 | def get_filename_key(x):
18 | basename = os.path.basename(x)
19 | if basename[:-4].isdigit():
20 | return int(basename[:-4])
21 |
22 | return basename
23 |
24 | def _strtobool(x):
25 | return bool(strtobool(x))
26 |
27 | def parse_args():
28 | parser = argparse.ArgumentParser(description="Statistical ILSVRC")
29 | parser.add_argument("--img-root", help="img root", required=True)
30 | parser.add_argument("--out-file", help="output path", required=True)
31 | parser.add_argument('--max-per-dir', nargs='?', const=1e10, type=int, default=1e10, help="max per dir")
32 | parser.add_argument('--min-per-dir', nargs='?', const=0, type=int, default=0, help="min per dir")
33 | parser.add_argument('--has-shuffle', nargs='?', const=False, type=_strtobool, default=False, help="shuffle or sort")
34 | args = parser.parse_args()
35 | return args
36 |
37 |
38 | def main():
39 | args = parse_args()
40 | print(args)
41 | img_root = args.img_root
42 | out_file = args.out_file
43 |
44 | max_per_dir = args.max_per_dir
45 | min_per_dir = args.min_per_dir
46 |
47 | has_shuffle = args.has_shuffle
48 |
49 | folder_classes = []
50 | path_h_w = []
51 | pbar = tqdm(os.walk(img_root))
52 | for root, dirs, files in pbar:
53 | if not os.path.basename(root).startswith("n"):
54 | print("\tskip folder: ", root)
55 | continue
56 |
57 | files = [f for f in files if f.endswith(('.jpg', '.JPG', '.png', '.PNG', '.jpeg', '.JPEG'))]
58 |
59 | if min_per_dir > 0 and len(files) < min_per_dir:
60 | print("\tskip folder: ", root, " #image: ", len(files))
61 | continue
62 |
63 | if has_shuffle:
64 | random.shuffle(files)
65 | else:
66 | files = sorted(files, key=lambda x: get_filename_key(x), reverse=False)
67 |
68 |
69 | folder_classes_this = []
70 | path_h_w_this = []
71 | for name in files:
72 | if max_per_dir > 0 and len(path_h_w_this) >= max_per_dir:
73 | break
74 |
75 | path = os.path.join(root, name)
76 | print('parsing:',path)
77 | rpath = path.replace(img_root, "")
78 |
79 | wnid = pathlib.PurePath(path).parent.name
80 |
81 | if wnid not in folder_classes_this:
82 | folder_classes_this.append(wnid)
83 |
84 | try:
85 | img = read_image(path, format="BGR")
86 | height, width, _ = img.shape
87 | except Exception as e:
88 | print("*" * 60)
89 | print("fail to open image: ", e)
90 | print("*" * 60)
91 | continue
92 |
93 | # if width < 300 or height < 300:
94 | # continue
95 |
96 | # if width < 224 or height < 224:
97 | # continue
98 |
99 | # if width > 1000 or height > 1000:
100 | # continue
101 |
102 | path_h_w_this.append([rpath, height, width])
103 |
104 | if len(path_h_w_this) >= min_per_dir:
105 | pass
106 | else:
107 | print("\tskip folder: ", root, " #image: ", len(path_h_w_this))
108 | continue
109 |
110 | folder_classes.extend(folder_classes_this)
111 | path_h_w.extend(path_h_w_this)
112 |
113 | print("folder: ", root, " #image: ", len(path_h_w_this))
114 | print("#folder: ", len(folder_classes), " #image: ", len(path_h_w))
115 |
116 | folder_classes = list(set(folder_classes))
117 | folder_classes.sort()
118 |
119 | print("#category: ", len(folder_classes), " categories: ", folder_classes)
120 | print("#image: ", len(path_h_w))
121 | print("")
122 | with open(out_file,'w') as f:
123 | d = {}
124 | d['categories'] = folder_classes
125 | d['path_h_w'] = path_h_w
126 | json.dump(d,f)
127 | print(args)
128 |
129 | if __name__ == "__main__":
130 | main()
--------------------------------------------------------------------------------
/tools/train_net.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | """
4 | A main training script.
5 |
6 | This script reads a given config file and runs the training or evaluation.
7 | It is an entry point that is made to train standard models in detectron2.
8 |
9 | In order to let one script support training of many models,
10 | this script contains logic that is specific to these built-in models and therefore
11 | may not be suitable for your own project.
12 | For example, your research project perhaps only needs a single "evaluator".
13 |
14 | Therefore, we recommend using detectron2 as a library and taking
15 | this file as an example of how to use the library.
16 | You may want to write your own script with your datasets and other customizations.
17 | """
18 |
19 | import logging
20 | import os
21 |
22 | import detectron2.utils.comm as comm
23 | from detectron2.checkpoint import DetectionCheckpointer
24 | from detectron2.config import get_cfg
25 | from detectron2.engine import (default_argument_parser, default_setup,hooks,launch)
26 | from detectron2.utils.logger import setup_logger
27 | from detectron2.evaluation import verify_results
28 | from wsovod.config import add_wsovod_config
29 |
30 |
31 | def setup(args):
32 | """
33 | Create configs and perform basic setups.
34 | """
35 | cfg = get_cfg()
36 | add_wsovod_config(cfg)
37 | cfg.merge_from_file(args.config_file)
38 | cfg.merge_from_list(args.opts)
39 | cfg.freeze()
40 | default_setup(cfg, args)
41 | setup_logger(cfg.OUTPUT_DIR, distributed_rank=comm.get_rank(), name="wsovod")
42 | return cfg
43 |
44 |
45 | def main(args):
46 | cfg = setup(args)
47 | if "MixedDatasets" in args.config_file:
48 | from wsovod.engine import DefaultTrainer_WSOVOD_MixedDatasets as Trainer
49 | else:
50 | from wsovod.engine import DefaultTrainer_WSOVOD as Trainer
51 |
52 | if args.eval_only:
53 | model = Trainer.build_model(cfg)
54 | DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
55 | cfg.MODEL.WEIGHTS, resume=args.resume
56 | )
57 | res = {}
58 | res = Trainer.test_WSL(cfg, model)
59 | if cfg.TEST.AUG.ENABLED:
60 | res.update(Trainer.test_with_TTA_WSL(cfg, model))
61 | if comm.is_main_process():
62 | verify_results(cfg, res)
63 | return res
64 |
65 | """
66 | If you'd like to do anything fancier than the standard training logic,
67 | consider writing your own training loop (see plain_train_net.py) or
68 | subclassing the trainer.
69 | """
70 | trainer = Trainer(cfg)
71 | trainer.resume_or_load(resume=args.resume)
72 | if cfg.TEST.AUG.ENABLED:
73 | # trainer.register_hooks([hooks.EvalHook(0, lambda: trainer.test_WSL(cfg, trainer.model))])
74 | trainer.register_hooks(
75 | [hooks.EvalHook(0, lambda: trainer.test_with_TTA_WSL(cfg, trainer.model))]
76 | )
77 | return trainer.train()
78 |
79 |
80 | if __name__ == "__main__":
81 | args = default_argument_parser().parse_args()
82 | print("Command Line Args:", args)
83 | launch(
84 | main,
85 | args.num_gpus,
86 | num_machines=args.num_machines,
87 | machine_rank=args.machine_rank,
88 | dist_url=args.dist_url,
89 | args=(args,),
90 | )
91 |
--------------------------------------------------------------------------------
/tools/train_net_debug.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | """
4 | A main training script.
5 |
6 | This script reads a given config file and runs the training or evaluation.
7 | It is an entry point that is made to train standard models in detectron2.
8 |
9 | In order to let one script support training of many models,
10 | this script contains logic that is specific to these built-in models and therefore
11 | may not be suitable for your own project.
12 | For example, your research project perhaps only needs a single "evaluator".
13 |
14 | Therefore, we recommend using detectron2 as a library and taking
15 | this file as an example of how to use the library.
16 | You may want to write your own script with your datasets and other customizations.
17 | """
18 |
19 | import logging
20 | import os
21 |
22 | import detectron2.utils.comm as comm
23 | from detectron2.checkpoint import DetectionCheckpointer
24 | from detectron2.config import get_cfg
25 | from detectron2.engine import (default_argument_parser, default_setup,hooks,launch)
26 | from detectron2.utils.logger import setup_logger
27 | from detectron2.evaluation import verify_results
28 | from wsovod.config import add_wsovod_config
29 |
30 |
31 | def setup(args):
32 | """
33 | Create configs and perform basic setups.
34 | """
35 | cfg = get_cfg()
36 | add_wsovod_config(cfg)
37 | cfg.merge_from_file(args.config_file)
38 | cfg.merge_from_list(args.opts)
39 | cfg.freeze()
40 | default_setup(cfg, args)
41 | setup_logger(cfg.OUTPUT_DIR, distributed_rank=comm.get_rank(), name="wsovod")
42 | return cfg
43 |
44 |
45 | def main(args):
46 | cfg = setup(args)
47 | if "MixedDatasets" in args.config_file:
48 | from wsovod.engine import DefaultTrainer_WSOVOD_MixedDatasets as Trainer
49 | else:
50 | from wsovod.engine import DefaultTrainer_WSOVOD as Trainer
51 |
52 | if args.eval_only:
53 | model = Trainer.build_model(cfg)
54 | DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
55 | cfg.MODEL.WEIGHTS, resume=args.resume
56 | )
57 | res = {}
58 | res = Trainer.test_WSL(cfg, model)
59 | if cfg.TEST.AUG.ENABLED:
60 | res.update(Trainer.test_with_TTA_WSL(cfg, model))
61 | if comm.is_main_process():
62 | verify_results(cfg, res)
63 | return res
64 |
65 | """
66 | If you'd like to do anything fancier than the standard training logic,
67 | consider writing your own training loop (see plain_train_net.py) or
68 | subclassing the trainer.
69 | """
70 | trainer = Trainer(cfg)
71 | trainer.resume_or_load(resume=args.resume)
72 | if cfg.TEST.AUG.ENABLED:
73 | # trainer.register_hooks([hooks.EvalHook(0, lambda: trainer.test_WSL(cfg, trainer.model))])
74 | trainer.register_hooks(
75 | [hooks.EvalHook(0, lambda: trainer.test_with_TTA_WSL(cfg, trainer.model))]
76 | )
77 | return trainer.train()
78 |
79 |
80 | if __name__ == "__main__":
81 | args = default_argument_parser().parse_args()
82 | args.config_file = 'configs/MixedDatasets-Detection/WSOVOD_MRRP_WSR_18_DC5_1x.yaml'
83 | # args.num_gpus = 1
84 | # # # args.eval_only = True
85 | # args.opts = ['OUTPUT_DIR','output/temp','SOLVER.IMS_PER_BATCH',1,'SOLVER.REFERENCE_WORLD_SIZE',1,]
86 | # 'MODEL.WEIGHTS','output/configs/COCO-Detection/WSOVOD_WSR_18_DC5_1x_20230621_003554/model_final.pth']
87 | print("Command Line Args:", args)
88 | launch(
89 | main,
90 | args.num_gpus,
91 | num_machines=args.num_machines,
92 | machine_rank=args.machine_rank,
93 | dist_url=args.dist_url,
94 | args=(args,),
95 | )
96 |
--------------------------------------------------------------------------------
/tools/train_net_eval_open_vocabulary.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | """
4 | A main training script.
5 |
6 | This script reads a given config file and runs the training or evaluation.
7 | It is an entry point that is made to train standard models in detectron2.
8 |
9 | In order to let one script support training of many models,
10 | this script contains logic that is specific to these built-in models and therefore
11 | may not be suitable for your own project.
12 | For example, your research project perhaps only needs a single "evaluator".
13 |
14 | Therefore, we recommend using detectron2 as a library and taking
15 | this file as an example of how to use the library.
16 | You may want to write your own script with your datasets and other customizations.
17 | """
18 |
19 | import logging
20 | import os
21 |
22 | import detectron2.utils.comm as comm
23 | from detectron2.checkpoint import DetectionCheckpointer
24 | from detectron2.config import get_cfg
25 | from detectron2.engine import (default_argument_parser, default_setup,hooks,launch)
26 | from detectron2.utils.logger import setup_logger
27 | from detectron2.evaluation import verify_results
28 | from wsovod.config import add_wsovod_config
29 |
30 |
31 | def setup(args):
32 | """
33 | Create configs and perform basic setups.
34 | """
35 | cfg = get_cfg()
36 | add_wsovod_config(cfg)
37 | cfg.merge_from_file(args.config_file)
38 | cfg.merge_from_list(args.opts)
39 | cfg.freeze()
40 | default_setup(cfg, args)
41 | setup_logger(cfg.OUTPUT_DIR, distributed_rank=comm.get_rank(), name="wsovod")
42 | return cfg
43 |
44 |
45 | def main(args):
46 | cfg = setup(args)
47 | if "MixedDatasets" in args.config_file:
48 | from wsovod.engine import DefaultTrainer_WSOVOD_MixedDatasets as Trainer_
49 | else:
50 | from wsovod.engine import DefaultTrainer_WSOVOD as Trainer_
51 | from detectron2.evaluation import DatasetEvaluators
52 | from wsovod.evaluation.ov_coco_evaluation import OVCOCOEvaluator
53 |
54 | class Trainer(Trainer_):
55 | @classmethod
56 | def build_evaluator(cls, cfg, dataset_name, output_folder=None):
57 | """
58 | Create evaluator(s) for a given dataset.
59 | This uses the special metadata "evaluator_type" associated with each builtin dataset.
60 | For your own dataset, you can simply create an evaluator manually in your
61 | script and do not have to worry about the hacky if-else logic here.
62 | """
63 | if output_folder is None:
64 | output_folder = os.path.join(cfg.OUTPUT_DIR, "inference_" + dataset_name)
65 | evaluator_list = []
66 |
67 | evaluator_list.append(OVCOCOEvaluator(dataset_name, output_dir=output_folder))
68 | return DatasetEvaluators(evaluator_list)
69 |
70 | if args.eval_only:
71 | model = Trainer.build_model(cfg)
72 | DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
73 | cfg.MODEL.WEIGHTS, resume=args.resume
74 | )
75 | res = {}
76 | res = Trainer.test_WSL(cfg, model)
77 | if cfg.TEST.AUG.ENABLED:
78 | res.update(Trainer.test_with_TTA_WSL(cfg, model))
79 | if comm.is_main_process():
80 | verify_results(cfg, res)
81 | return res
82 |
83 | """
84 | If you'd like to do anything fancier than the standard training logic,
85 | consider writing your own training loop (see plain_train_net.py) or
86 | subclassing the trainer.
87 | """
88 | trainer = Trainer(cfg)
89 | trainer.resume_or_load(resume=args.resume)
90 | if cfg.TEST.AUG.ENABLED:
91 | # trainer.register_hooks([hooks.EvalHook(0, lambda: trainer.test_WSL(cfg, trainer.model))])
92 | trainer.register_hooks(
93 | [hooks.EvalHook(0, lambda: trainer.test_with_TTA_WSL(cfg, trainer.model))]
94 | )
95 | return trainer.train()
96 |
97 |
98 | if __name__ == "__main__":
99 | args = default_argument_parser().parse_args()
100 | print("Command Line Args:", args)
101 | launch(
102 | main,
103 | args.num_gpus,
104 | num_machines=args.num_machines,
105 | machine_rank=args.machine_rank,
106 | dist_url=args.dist_url,
107 | args=(args,),
108 | )
109 |
--------------------------------------------------------------------------------
/wsovod/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | from .config import add_wsovod_config
3 | from .data import *
4 | from .evaluation import *
5 | from .modeling import *
6 | from .solver import *
7 |
8 | # This line will be programmatically read/written by setup.py.
9 | # Leave them at the bottom of this file and don't touch them.
10 | __version__ = "0.4"
--------------------------------------------------------------------------------
/wsovod/config/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | from .defaults import add_wsovod_config
3 |
--------------------------------------------------------------------------------
/wsovod/config/defaults.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
3 |
4 | from detectron2.config import CfgNode as CN
5 |
6 |
7 | def add_wsovod_config(cfg):
8 | """
9 | Add config for wsovod.
10 | """
11 | _C = cfg
12 | _C.WSOVOD = CN()
13 | _C.WSOVOD.ITER_SIZE = 1
14 | _C.WSOVOD.CLS_AGNOSTIC_BBOX_KNOWN = False
15 | _C.WSOVOD.SAMPLING = CN()
16 | _C.WSOVOD.SAMPLING.SAMPLING_ON = False
17 | _C.WSOVOD.SAMPLING.IOU_THRESHOLDS = [[0.5], [0.5], [0.5], [0.5]]
18 | _C.WSOVOD.SAMPLING.IOU_LABELS = [[0, 1], [0, 1], [0, 1], [0, 1]]
19 | _C.WSOVOD.SAMPLING.BATCH_SIZE_PER_IMAGE = [4096, 4096, 4096, 4096]
20 | _C.WSOVOD.SAMPLING.POSITIVE_FRACTION = [1.0, 1.0, 1.0, 1.0]
21 | _C.WSOVOD.OBJECT_MINING = CN()
22 | _C.WSOVOD.OBJECT_MINING.WEIGHT = 1.0
23 | _C.WSOVOD.OBJECT_MINING.MEAN_LOSS = True
24 | _C.WSOVOD.INSTANCE_REFINEMENT = CN()
25 | _C.WSOVOD.INSTANCE_REFINEMENT.WEIGHT = 1.0
26 | _C.WSOVOD.INSTANCE_REFINEMENT.REFINE_NUM = 3
27 | _C.WSOVOD.INSTANCE_REFINEMENT.REFINE_REG = [False, False, False]
28 | _C.WSOVOD.INSTANCE_REFINEMENT.REFINE_MIST = False
29 | _C.WSOVOD.INSTANCE_REFINEMENT.CROSS_ENTROPY_WEIGHTED = True
30 | _C.WSOVOD.BBOX_REFINE = CN()
31 | _C.WSOVOD.BBOX_REFINE.ENABLE = False
32 | _C.WSOVOD.BBOX_REFINE.MODEL_TYPE = "vit_b" # vit_h vit_l vit_b
33 | _C.WSOVOD.BBOX_REFINE.MODEL_CHECKPOINT = "tools/sam_checkpoints/sam_vit_b_01ec64.pth" # sam_checkpoints/sam_vit_h_4b8939.pth sam_checkpoints/sam_vit_l_0b3195.pth sam_checkpoints/sam_vit_b_01ec64.pth
34 |
35 | # VGG
36 | _C.MODEL.VGG = CN()
37 | _C.MODEL.VGG.DEPTH = 16
38 | _C.MODEL.VGG.OUT_FEATURES = ["plain5"]
39 | _C.MODEL.VGG.CONV5_DILATION = 1
40 | # Swin
41 | _C.MODEL.SWIN = CN()
42 | _C.MODEL.SWIN.EMBED_DIM = 96
43 | _C.MODEL.SWIN.OUT_FEATURES = ["stage2", "stage3", "stage4", "stage5"]
44 | _C.MODEL.SWIN.DEPTHS = [2, 2, 6, 2]
45 | _C.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24]
46 | _C.MODEL.SWIN.WINDOW_SIZE = 7
47 | _C.MODEL.SWIN.MLP_RATIO = 4
48 | _C.MODEL.SWIN.DROP_PATH_RATE = 0.2
49 | _C.MODEL.SWIN.APE = False
50 | _C.MODEL.SWIN.PATH_NORM = True
51 |
52 | _C.MODEL.ROI_BOX_HEAD.DAN_DIM = [4096, 4096]
53 | _C.MODEL.ROI_BOX_HEAD.OPEN_VOCABULARY = CN()
54 | _C.MODEL.ROI_BOX_HEAD.OPEN_VOCABULARY.WEIGHT_PATH_TRAIN = ''
55 | _C.MODEL.ROI_BOX_HEAD.OPEN_VOCABULARY.WEIGHT_PATH_TEST = ''
56 | _C.MODEL.ROI_BOX_HEAD.OPEN_VOCABULARY.WEIGHT_DIM = 512
57 | _C.MODEL.ROI_BOX_HEAD.OPEN_VOCABULARY.USE_BIAS = 0.0
58 | _C.MODEL.ROI_BOX_HEAD.OPEN_VOCABULARY.NORM_WEIGHT = True
59 | _C.MODEL.ROI_BOX_HEAD.OPEN_VOCABULARY.NORM_TEMP = 100.0
60 | _C.MODEL.ROI_BOX_HEAD.OPEN_VOCABULARY.DATA_AWARE = False
61 | _C.MODEL.ROI_BOX_HEAD.OPEN_VOCABULARY.PROTOTYPE_NUM = 5
62 |
63 | _C.MODEL.RPN.SCORE_THRESH_TRAIN = 0.2
64 | _C.MODEL.RPN.SCORE_THRESH_TEST = 0.2
65 | _C.MODEL.RPN.TOPK_CANDIDATES_TRAIN = 2000
66 | _C.MODEL.RPN.TOPK_CANDIDATES_TEST = 1000
67 |
68 | _C.MODEL.MRRP = CN()
69 | _C.MODEL.MRRP.MRRP_ON = False
70 | _C.MODEL.MRRP.NUM_BRANCH = 3
71 | _C.MODEL.MRRP.BRANCH_DILATIONS = [1, 2, 3]
72 | _C.MODEL.MRRP.MRRP_STAGE = "res4"
73 | _C.MODEL.MRRP.TEST_BRANCH_IDX = 1
74 |
75 | _C.DATALOADER.CLASS_ASPECT_RATIO_GROUPING = False
76 | _C.DATALOADER.GROUP_WAIT = 5
77 |
78 | _C.TEST.EVAL_TRAIN = False
79 | _C.VIS_TEST = False
80 |
81 | _C.SOLVER.OPTIMIZER = "SGD"
82 | _C.SOLVER.BACKBONE_MULTIPLIER = 1.0
83 | _C.SOLVER.BASE_LR_END = 0.1
84 | _C.SOLVER.IMS_PER_BATCH_LIST = [4]
85 |
86 | _C.DATASETS.MIXED_DATASETS = CN()
87 | _C.DATASETS.MIXED_DATASETS.NAMES = ['coco_2017_train']
88 | _C.DATASETS.MIXED_DATASETS.WEIGHT_PATH_TRAINS = ['models/coco_text_embedding_single_prompt.pkl']
89 | _C.DATASETS.MIXED_DATASETS.NUM_CLASSES = [80]
90 | _C.DATASETS.MIXED_DATASETS.PROPOSAL_FILES = ['']
91 | _C.DATASETS.MIXED_DATASETS.RATIOS = [1]
92 | _C.DATASETS.MIXED_DATASETS.USE_CAS = [False]
93 | _C.DATASETS.MIXED_DATASETS.USE_RFS = [True]
94 | _C.DATASETS.MIXED_DATASETS.FILTER_EMPTY_ANNOTATIONS = [True]
95 | _C.DATASETS.MIXED_DATASETS.CAS_LAMBDA = 1.0
96 | _C.DATASETS.MIXED_DATASETS.REPEAT_THRESHOLD = 0.001
--------------------------------------------------------------------------------
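The `add_wsovod_config` defaults above are meant to be applied on top of a standard detectron2 config before merging one of the YAMLs in `configs/`. A minimal sketch of that flow (assumes detectron2 is installed and the `wsovod` package is importable; the printed values are just the defaults defined above):

    from detectron2.config import get_cfg
    from wsovod.config import add_wsovod_config

    cfg = get_cfg()                # detectron2 defaults
    add_wsovod_config(cfg)         # register the WSOVOD-specific keys defined above
    # cfg.merge_from_file(...) with a project YAML would normally follow here.
    print(cfg.WSOVOD.INSTANCE_REFINEMENT.REFINE_NUM)   # 3
    print(cfg.MODEL.MRRP.BRANCH_DILATIONS)             # [1, 2, 3]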
/wsovod/data/__init__.py:
--------------------------------------------------------------------------------
1 | # ensure the builtin datasets are registered
2 | from . import datasets # isort:skip
3 |
4 | from .build import build_detection_test_loader, build_detection_train_loader
5 |
6 | __all__ = [k for k in globals().keys() if not k.startswith("_")]
7 |
--------------------------------------------------------------------------------
/wsovod/data/common.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | import copy
3 | import itertools
4 | import logging
5 | import pickle
6 | import random
7 |
8 | import numpy as np
9 | import torch.utils.data as data
10 | from detectron2.utils.serialize import PicklableWrapper
11 | from torch.utils.data.sampler import Sampler
12 | from detectron2.data.common import MapDataset, AspectRatioGroupedDataset
13 |
14 | __all__ = [
15 | "ClassAspectRatioGroupedDataset",
16 | "AspectRatioGroupedMixedDatasets",
17 | ]
18 |
19 |
20 | class ClassAspectRatioGroupedDataset(data.IterableDataset):
21 |     """
22 |     Batch data that share a ground-truth class and have a similar aspect
23 |     ratio together. Images whose aspect ratio is > 1 (wide) or < 1 (tall)
24 |     go into separate buckets, one pair of buckets per class.
25 |     This improves training speed because the images then need less padding
26 |     to form a batch.
27 |
28 |     It assumes the underlying dataset produces dicts with "width", "height"
29 |     and "instances" keys. It will then produce a list of original dicts with
30 |     length = batch_size, all drawn from the same bucket.
31 |     """
32 |
33 | def __init__(self, dataset, batch_size):
34 | """
35 | Args:
36 | dataset: an iterable. Each element must be a dict with keys
37 | "width" and "height", which will be used to batch data.
38 | batch_size (int):
39 | """
40 | self.dataset = dataset
41 | self.batch_size = batch_size
42 | self._buckets = [[] for _ in range(2 * 2000)]
43 |
44 | def __iter__(self):
45 | for d in self.dataset:
46 | w, h = d["width"], d["height"]
47 | bucket_id = 0 if w > h else 1
48 |
49 | classes = list(set(d["instances"].gt_classes.tolist()))
50 |
51 | # for c in classes:
52 | # bucket = self._buckets[2 * c + bucket_id]
53 | # bucket.append(copy.deepcopy(d))
54 | # if len(bucket) == self.batch_size:
55 | # yield bucket[:]
56 | # del bucket[:]
57 | # continue
58 |
59 | if len(classes) > 0:
60 | c = random.choice(classes)
61 | else:
62 | c = 0
63 |
64 | bucket_id = bucket_id + c * 2
65 |
66 | bucket = self._buckets[bucket_id]
67 | bucket.append(d)
68 | if len(bucket) == self.batch_size:
69 | yield bucket[:]
70 | del bucket[:]
71 |
72 |
--------------------------------------------------------------------------------
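A toy illustration of how `ClassAspectRatioGroupedDataset` above forms batches. The data stream below is hypothetical and only mimics the dicts produced by the real dataloader (which also carry an `Instances` object):

    import torch
    from detectron2.structures import Instances
    from wsovod.data.common import ClassAspectRatioGroupedDataset

    def toy_stream():
        # (width, height, class) triples; real dicts come from the mapper/dataloader.
        for w, h, c in [(640, 480, 3), (640, 480, 3), (480, 640, 7), (640, 480, 3)]:
            inst = Instances((h, w))
            inst.gt_classes = torch.tensor([c])
            yield {"width": w, "height": h, "instances": inst}

    grouped = ClassAspectRatioGroupedDataset(toy_stream(), batch_size=2)
    first_batch = next(iter(grouped))
    # Both items are landscape (w > h) and were assigned class 3, i.e. the same bucket.
    print([(d["width"], d["height"]) for d in first_batch])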
/wsovod/data/dataset_mapper.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | import copy
3 | import logging
4 | from typing import List, Optional, Union
5 |
6 | import numpy as np
7 | import torch
8 | from detectron2.config import configurable
9 | from detectron2.data import transforms as T
10 |
11 | from . import detection_utils as utils
12 |
13 | """
14 | This file contains the default mapping that's applied to "dataset dicts".
15 | """
16 |
17 | __all__ = ["DatasetMapper"]
18 |
19 |
20 | class DatasetMapper:
21 |     """
22 |     A callable which takes a dataset dict in Detectron2 Dataset format,
23 |     and maps it into a format used by the model.
24 |
25 |     This is the default callable used to map a dataset dict into training data.
26 |     You may use it as a reference to implement your own callable for customized
27 |     logic, such as a different way to read or transform images.
28 |     See :doc:`/tutorials/data_loading` for details.
29 |
30 |     The callable currently does the following:
31 |
32 |     1. Read the image from "file_name"
33 |     2. Apply cropping/geometric transforms to the image and annotations
34 |     3. Convert data and annotations to Tensor and :class:`Instances`
35 |     """
36 |
37 | @configurable
38 | def __init__(
39 | self,
40 | is_train: bool,
41 | *,
42 | augmentations: List[Union[T.Augmentation, T.Transform]],
43 | image_format: str,
44 | use_instance_mask: bool = False,
45 | use_keypoint: bool = False,
46 | instance_mask_format: str = "polygon",
47 | keypoint_hflip_indices: Optional[np.ndarray] = None,
48 | precomputed_proposal_topk: Optional[int] = None,
49 | recompute_boxes: bool = False,
50 | ):
51 | """
52 | NOTE: this interface is experimental.
53 |
54 | Args:
55 | is_train: whether it's used in training or inference
56 | augmentations: a list of augmentations or deterministic transforms to apply
57 | image_format: an image format supported by :func:`detection_utils.read_image`.
58 | use_instance_mask: whether to process instance segmentation annotations, if available
59 | use_keypoint: whether to process keypoint annotations if available
60 | instance_mask_format: one of "polygon" or "bitmask". Process instance segmentation
61 | masks into this format.
62 | keypoint_hflip_indices: see :func:`detection_utils.create_keypoint_hflip_indices`
63 | precomputed_proposal_topk: if given, will load pre-computed
64 | proposals from dataset_dict and keep the top k proposals for each image.
65 | recompute_boxes: whether to overwrite bounding box annotations
66 | by computing tight bounding boxes from instance mask annotations.
67 | """
68 | if recompute_boxes:
69 | assert use_instance_mask, "recompute_boxes requires instance masks"
70 | # fmt: off
71 | self.is_train = is_train
72 | self.augmentations = T.AugmentationList(augmentations)
73 | self.image_format = image_format
74 | self.use_instance_mask = use_instance_mask
75 | self.instance_mask_format = instance_mask_format
76 | self.use_keypoint = use_keypoint
77 | self.keypoint_hflip_indices = keypoint_hflip_indices
78 | self.proposal_topk = precomputed_proposal_topk
79 | self.recompute_boxes = recompute_boxes
80 | # fmt: on
81 | logger = logging.getLogger(__name__)
82 | mode = "training" if is_train else "inference"
83 | logger.info(f"[DatasetMapper] Augmentations used in {mode}: {augmentations}")
84 |
85 | @classmethod
86 | def from_config(cls, cfg, is_train: bool = True):
87 | augs = utils.build_augmentation(cfg, is_train)
88 | if cfg.INPUT.CROP.ENABLED and is_train:
89 | augs.insert(0, T.RandomCrop(cfg.INPUT.CROP.TYPE, cfg.INPUT.CROP.SIZE))
90 | recompute_boxes = cfg.MODEL.MASK_ON
91 | else:
92 | recompute_boxes = False
93 |
94 | ret = {
95 | "is_train": is_train,
96 | "augmentations": augs,
97 | "image_format": cfg.INPUT.FORMAT,
98 | "use_instance_mask": cfg.MODEL.MASK_ON,
99 | "instance_mask_format": cfg.INPUT.MASK_FORMAT,
100 | "use_keypoint": cfg.MODEL.KEYPOINT_ON,
101 | "recompute_boxes": recompute_boxes,
102 | }
103 |
104 | if cfg.MODEL.KEYPOINT_ON:
105 | ret["keypoint_hflip_indices"] = utils.create_keypoint_hflip_indices(cfg.DATASETS.TRAIN)
106 |
107 | if cfg.MODEL.LOAD_PROPOSALS:
108 | ret["precomputed_proposal_topk"] = (
109 | cfg.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TRAIN
110 | if is_train
111 | else cfg.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TEST
112 | )
113 | return ret
114 |
115 | def _transform_annotations(self, dataset_dict, transforms, image_shape):
116 | # USER: Modify this if you want to keep them for some reason.
117 | for anno in dataset_dict["annotations"]:
118 | if not self.use_instance_mask:
119 | anno.pop("segmentation", None)
120 | if not self.use_keypoint:
121 | anno.pop("keypoints", None)
122 |
123 | # USER: Implement additional transformations if you have other types of data
124 | annos = [
125 | utils.transform_instance_annotations(
126 | obj, transforms, image_shape, keypoint_hflip_indices=self.keypoint_hflip_indices
127 | )
128 | for obj in dataset_dict.pop("annotations")
129 | if obj.get("iscrowd", 0) == 0
130 | ]
131 | instances = utils.annotations_to_instances(
132 | annos, image_shape, mask_format=self.instance_mask_format
133 | )
134 |
135 | # After transforms such as cropping are applied, the bounding box may no longer
136 | # tightly bound the object. As an example, imagine a triangle object
137 | # [(0,0), (2,0), (0,2)] cropped by a box [(1,0),(2,2)] (XYXY format). The tight
138 | # bounding box of the cropped triangle should be [(1,0),(2,1)], which is not equal to
139 | # the intersection of original bounding box and the cropping box.
140 | if self.recompute_boxes:
141 | instances.gt_boxes = instances.gt_masks.get_bounding_boxes()
142 | dataset_dict["instances"] = utils.filter_empty_instances(instances)
143 |
144 | def __call__(self, dataset_dict):
145 | """
146 | Args:
147 | dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.
148 |
149 | Returns:
150 | dict: a format that builtin models in detectron2 accept
151 | """
152 | dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below
153 | # USER: Write your own image loading if it's not from a file
154 | image = utils.read_image(dataset_dict["file_name"], format=self.image_format)
155 | utils.check_image_size(dataset_dict, image)
156 |
157 | # USER: Remove if you don't do semantic/panoptic segmentation.
158 | if "sem_seg_file_name" in dataset_dict:
159 | sem_seg_gt = utils.read_image(dataset_dict.pop("sem_seg_file_name"), "L").squeeze(2)
160 | else:
161 | sem_seg_gt = None
162 |
163 | aug_input = T.AugInput(image, sem_seg=sem_seg_gt)
164 | transforms = self.augmentations(aug_input)
165 | image, sem_seg_gt = aug_input.image, aug_input.sem_seg
166 |
167 | image_shape = image.shape[:2] # h, w
168 | # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
169 | # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
170 | # Therefore it's important to use torch.Tensor.
171 | dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
172 | if sem_seg_gt is not None:
173 | dataset_dict["sem_seg"] = torch.as_tensor(sem_seg_gt.astype("long"))
174 |
175 | # USER: Remove if you don't use pre-computed proposals.
176 | # Most users would not need this feature.
177 | if self.proposal_topk is not None:
178 | utils.transform_proposals(
179 | dataset_dict, image_shape, transforms, proposal_topk=self.proposal_topk
180 | )
181 |
182 | if not self.is_train:
183 | # USER: Modify this if you want to keep them for some reason.
184 | dataset_dict.pop("annotations", None)
185 | dataset_dict.pop("sem_seg_file_name", None)
186 | return dataset_dict
187 |
188 | if "annotations" in dataset_dict:
189 | self._transform_annotations(dataset_dict, transforms, image_shape)
190 |
191 | return dataset_dict
192 |
193 |
194 |
--------------------------------------------------------------------------------
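A hedged usage sketch for the `DatasetMapper` above. It assumes the builtin datasets are registered (see `wsovod/data/datasets/builtin.py` below) and that `datasets/VOC2007` exists on disk:

    from detectron2.config import get_cfg
    from detectron2.data import DatasetCatalog
    from wsovod.config import add_wsovod_config
    from wsovod.data.dataset_mapper import DatasetMapper
    import wsovod.data.datasets  # noqa: F401  (registers the builtin datasets)

    cfg = get_cfg()
    add_wsovod_config(cfg)
    mapper = DatasetMapper(cfg, is_train=True)  # @configurable expands cfg via from_config

    dicts = DatasetCatalog.get("voc_2007_trainval")
    example = mapper(dicts[0])
    print(example["image"].shape)   # CHW uint8 tensor after augmentation
    print(example["instances"])     # transformed boxes/classes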
/wsovod/data/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from . import builtin # ensure the builtin datasets are registered
2 |
--------------------------------------------------------------------------------
/wsovod/data/datasets/builtin.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 |
4 |
5 | """
6 | This file registers pre-defined datasets at hard-coded paths, and their metadata.
7 |
8 | We hard-code metadata for common datasets. This will enable:
9 | 1. Consistency check when loading the datasets
10 | 2. Use models on these standard datasets directly and run demos,
11 | without having to download the dataset annotations
12 |
13 | We hard-code some paths to the dataset that's assumed to
14 | exist in "./datasets/".
15 |
16 | Users SHOULD NOT use this file to create new dataset / metadata for new dataset.
17 | To add new dataset, refer to the tutorial "docs/DATASETS.md".
18 | """
19 |
20 | import os
21 |
22 | from detectron2.data import DatasetCatalog, MetadataCatalog
23 | from detectron2.data.datasets.coco import load_coco_json
24 | from detectron2.data.datasets.coco_panoptic import (
25 | register_coco_panoptic, register_coco_panoptic_separated)
26 | from detectron2.data.datasets.lvis import (get_lvis_instances_meta,
27 | register_lvis_instances)
28 | from detectron2.data.datasets.register_coco import register_coco_instances
29 |
30 | from .builtin_meta import _get_builtin_metadata
31 | from .pascal_voc import register_pascal_voc
32 |
33 |
34 | # ==== Predefined splits for PASCAL VOC ===========
35 | def register_all_pascal_voc(root):
36 | SPLITS = [
37 | ("voc_2007_trainval", "VOC2007", "trainval"),
38 | ("voc_2007_train", "VOC2007", "train"),
39 | ("voc_2007_val", "VOC2007", "val"),
40 | ("voc_2007_test", "VOC2007", "test"),
41 | ("voc_2012_trainval", "VOC2012", "trainval"),
42 | ("voc_2012_train", "VOC2012", "train"),
43 | ("voc_2012_val", "VOC2012", "val"),
44 | ("voc_2012_test", "VOC2012", "test"),
45 | ]
46 | for name, dirname, split in SPLITS:
47 | year = 2007 if "2007" in name else 2012
48 | register_pascal_voc(name, os.path.join(root, dirname), split, year)
49 | MetadataCatalog.get(name).evaluator_type = "pascal_voc"
50 |
51 |
52 | # ==== Predefined datasets and splits for ImageNet ==========
53 |
54 | _PREDEFINED_SPLITS_ImageNet = {}
55 | _PREDEFINED_SPLITS_ImageNet["imagenet"] = {
56 | "ilsvrc_2012_val": (
57 | "ILSVRC2012/val/",
58 | "ILSVRC2012/ILSVRC2012_img_val_converted.json",
59 | ),
60 | "ilsvrc_2012_train": (
61 | "ILSVRC2012/train/",
62 | "ILSVRC2012/ILSVRC2012_img_train_converted.json",
63 | ),
64 | }
65 |
66 | def register_all_imagenet(root):
67 | for dataset_name, splits_per_dataset in _PREDEFINED_SPLITS_ImageNet.items():
68 | for key, (image_root, json_file) in splits_per_dataset.items():
69 | # Assume pre-defined datasets live in `./datasets`.
70 | register_coco_instances(
71 | key,
72 | _get_builtin_metadata(dataset_name),
73 | os.path.join(root, json_file) if "://" not in json_file else json_file,
74 | os.path.join(root, image_root),
75 | )
76 |
77 | # True for open source;
78 | # Internally at fb, we register them elsewhere
79 | if __name__.endswith(".builtin"):
80 | # Assume pre-defined datasets live in `./datasets`.
81 | _root = os.getenv("WSOVOD_DATASETS", "datasets")
82 | register_all_pascal_voc(_root)
83 | register_all_imagenet(_root)
--------------------------------------------------------------------------------
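A quick sketch to check that the registrations above ran; importing `wsovod.data.datasets` triggers `builtin.py`, and VOC metadata comes from `register_pascal_voc` below:

    from detectron2.data import DatasetCatalog, MetadataCatalog
    import wsovod.data.datasets  # noqa: F401  (runs the register_all_* calls above)

    print("voc_2007_trainval" in DatasetCatalog.list())                # True
    print(MetadataCatalog.get("voc_2007_trainval").thing_classes[:3])  # ['aeroplane', 'bicycle', 'bird']
    print("ilsvrc_2012_train" in DatasetCatalog.list())                # True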
/wsovod/data/datasets/pascal_voc.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 |
4 | import os
5 | import xml.etree.ElementTree as ET
6 | from typing import List, Tuple, Union
7 |
8 | import numpy as np
9 | from detectron2.data import DatasetCatalog, MetadataCatalog
10 | from detectron2.structures import BoxMode
11 | from detectron2.utils.file_io import PathManager
12 | from PIL import Image
13 |
14 | __all__ = ["load_voc_instances", "register_pascal_voc"]
15 |
16 |
17 | # fmt: off
18 | CLASS_NAMES = (
19 | "aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat",
20 | "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person",
21 | "pottedplant", "sheep", "sofa", "train", "tvmonitor"
22 | )
23 | # fmt: on
24 |
25 |
26 | def load_voc_instances(dirname: str, split: str, class_names: Union[List[str], Tuple[str, ...]]):
27 | """
28 | Load Pascal VOC detection annotations to Detectron2 format.
29 |
30 | Args:
31 | dirname: Contain "Annotations", "ImageSets", "JPEGImages"
32 | split (str): one of "train", "test", "val", "trainval"
33 | class_names: list or tuple of class names
34 | """
35 | with PathManager.open(os.path.join(dirname, "ImageSets", "Main", split + ".txt")) as f:
36 | fileids = np.loadtxt(f, dtype=str)
37 |
38 |     # Needs to read many small annotation files; fetch a local copy of the directory first.
39 | annotation_dirname = PathManager.get_local_path(os.path.join(dirname, "Annotations/"))
40 | dicts = []
41 | for fileid in fileids:
42 | anno_file = os.path.join(annotation_dirname, fileid + ".xml")
43 | jpeg_file = os.path.join(dirname, "JPEGImages", fileid + ".jpg")
44 |
45 | if not os.path.isfile(anno_file):
46 | with Image.open(jpeg_file) as img:
47 | width, height = img.size
48 | r = {"file_name": jpeg_file, "image_id": fileid, "height": height, "width": width}
49 | instances = []
50 | r["annotations"] = instances
51 | dicts.append(r)
52 | continue
53 |
54 | with PathManager.open(anno_file) as f:
55 | tree = ET.parse(f)
56 |
57 | r = {
58 | "file_name": jpeg_file,
59 | "image_id": fileid,
60 | "height": int(tree.findall("./size/height")[0].text),
61 | "width": int(tree.findall("./size/width")[0].text),
62 | }
63 | instances = []
64 |
65 | for obj in tree.findall("object"):
66 | cls = obj.find("name").text
67 |             # Skip "difficult" samples: this loader excludes them from
68 |             # training and evaluation.
69 | difficult = int(obj.find("difficult").text)
70 | if difficult == 1:
71 | continue
72 | bbox = obj.find("bndbox")
73 | bbox = [float(bbox.find(x).text) for x in ["xmin", "ymin", "xmax", "ymax"]]
74 | # Original annotations are integers in the range [1, W or H]
75 | # Assuming they mean 1-based pixel indices (inclusive),
76 | # a box with annotation (xmin=1, xmax=W) covers the whole image.
77 | # In coordinate space this is represented by (xmin=0, xmax=W)
78 | bbox[0] -= 1.0
79 | bbox[1] -= 1.0
80 | instances.append(
81 | {"category_id": class_names.index(cls), "bbox": bbox, "bbox_mode": BoxMode.XYXY_ABS}
82 | )
83 | r["annotations"] = instances
84 | dicts.append(r)
85 | return dicts
86 |
87 |
88 | def register_pascal_voc(name, dirname, split, year, class_names=CLASS_NAMES):
89 | if name in DatasetCatalog:
90 | DatasetCatalog.remove(name)
91 | DatasetCatalog.register(name, lambda: load_voc_instances(dirname, split, class_names))
92 | MetadataCatalog.get(name).set(
93 | thing_classes=list(class_names), dirname=dirname, year=year, split=split
94 | )
95 |
--------------------------------------------------------------------------------
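The only coordinate arithmetic above is the shift from VOC's 1-based, inclusive pixel indices to 0-based XYXY_ABS coordinates. A worked example with a hypothetical 353x500 annotation:

    # Box annotated in the XML as (xmin=1, ymin=1, xmax=353, ymax=500): covers the whole image.
    voc_box = [1.0, 1.0, 353.0, 500.0]
    d2_box = [voc_box[0] - 1.0, voc_box[1] - 1.0, voc_box[2], voc_box[3]]
    assert d2_box == [0.0, 0.0, 353.0, 500.0]  # width 353, height 500 in coordinate space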
/wsovod/data/samplers/__init__.py:
--------------------------------------------------------------------------------
1 | from .distributed_sampler_multi_dataset import MultiDatasetTrainingSampler, InferenceSampler
2 |
3 | __all__ = [
4 | "MultiDatasetTrainingSampler",
5 | "InferenceSampler",
6 | ]
--------------------------------------------------------------------------------
/wsovod/data/samplers/distributed_sampler_multi_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | import itertools
3 | import logging
4 | import math
5 | from collections import defaultdict
6 | from typing import Optional
7 |
8 | import torch
9 | from torch.utils.data.sampler import Sampler
10 |
11 | from detectron2.data.samplers import RepeatFactorTrainingSampler
12 | from detectron2.utils import comm
13 |
14 | logger = logging.getLogger(__name__)
15 |
16 |
17 | class MultiDatasetTrainingSampler(Sampler):
18 | def __init__(self, repeat_factors, *, shuffle=True, seed=None):
19 | self._shuffle = shuffle
20 | if seed is None:
21 | seed = comm.shared_random_seed()
22 | self._seed = int(seed)
23 |
24 | self._rank = comm.get_rank()
25 | self._world_size = comm.get_world_size()
26 |
27 | # Split into whole number (_int_part) and fractional (_frac_part) parts.
28 | self._int_part = torch.trunc(repeat_factors)
29 | self._frac_part = repeat_factors - self._int_part
30 |
31 | @staticmethod
32 | def get_repeat_factors(
33 | dataset_dicts, num_datasets, dataset_ratio, use_rfs, use_cas, repeat_thresh, cas_lambda
34 | ):
35 | sizes = [0 for _ in range(num_datasets)]
36 | for d in dataset_dicts:
37 | sizes[d["dataset_id"]] += 1
38 |
39 | assert len(dataset_ratio) == len(
40 | sizes
41 | ), "length of dataset ratio {} should be equal to number if dataset {}".format(
42 | len(dataset_ratio), len(sizes)
43 | )
44 | dataset_weight = [
45 | torch.ones(s, dtype=torch.float32) * max(sizes) / s * r
46 | for i, (r, s) in enumerate(zip(dataset_ratio, sizes))
47 | ]
48 |
49 | logger = logging.getLogger(__name__)
50 | logger.info(
51 | "Training sampler dataset weight: {}".format(
52 | str([max(sizes) / s * r for i, (r, s) in enumerate(zip(dataset_ratio, sizes))])
53 | )
54 | )
55 |
56 | st = 0
57 | repeat_factors = []
58 | for i, s in enumerate(sizes):
59 | assert use_rfs[i] * use_cas[i] == 0
60 | if use_rfs[i]:
61 | repeat_factor = RepeatFactorTrainingSampler.repeat_factors_from_category_frequency(
62 | dataset_dicts[st : st + s], repeat_thresh
63 | )
64 | elif use_cas[i]:
65 | repeat_factor = MultiDatasetTrainingSampler.get_class_balance_factor_per_dataset(
66 | dataset_dicts[st : st + s], l=cas_lambda
67 | )
68 | repeat_factor = repeat_factor * (s / repeat_factor.sum())
69 | else:
70 | repeat_factor = torch.ones(s, dtype=torch.float32)
71 | logger.info(
72 | "Training sampler class weight: {} {} {}".format(
73 | repeat_factor.size(), repeat_factor.max(), repeat_factor.min()
74 | )
75 | )
76 | repeat_factors.append(repeat_factor)
77 | st = st + s
78 | repeat_factors = torch.cat(repeat_factors)
79 | dataset_weight = torch.cat(dataset_weight)
80 | repeat_factors = dataset_weight * repeat_factors
81 |
82 | return repeat_factors
83 |
84 | @staticmethod
85 | def get_class_balance_factor_per_dataset(dataset_dicts, l=1.0):
86 | rep_factors = []
87 | category_freq = defaultdict(int)
88 | for dataset_dict in dataset_dicts: # For each image (without repeats)
89 | cat_ids = {ann["category_id"] for ann in dataset_dict["annotations"]}
90 | for cat_id in cat_ids:
91 | category_freq[cat_id] += 1
92 | for dataset_dict in dataset_dicts:
93 | cat_ids = {ann["category_id"] for ann in dataset_dict["annotations"]}
94 | rep_factor = sum([1.0 / (category_freq[cat_id] ** l) for cat_id in cat_ids])
95 | rep_factors.append(rep_factor)
96 |
97 | return torch.tensor(rep_factors, dtype=torch.float32)
98 |
99 | def _get_epoch_indices(self, generator):
100 | """
101 | Create a list of dataset indices (with repeats) to use for one epoch.
102 |
103 | Args:
104 | generator (torch.Generator): pseudo random number generator used for
105 | stochastic rounding.
106 |
107 | Returns:
108 | torch.Tensor: list of dataset indices to use in one epoch. Each index
109 | is repeated based on its calculated repeat factor.
110 | """
111 | # Since repeat factors are fractional, we use stochastic rounding so
112 | # that the target repeat factor is achieved in expectation over the
113 | # course of training
114 | rands = torch.rand(len(self._frac_part), generator=generator)
115 | rep_factors = self._int_part + (rands < self._frac_part).float()
116 | # Construct a list of indices in which we repeat images as specified
117 | indices = []
118 | for dataset_index, rep_factor in enumerate(rep_factors):
119 | indices.extend([dataset_index] * int(rep_factor.item()))
120 | return torch.tensor(indices, dtype=torch.int64)
121 |
122 | def __iter__(self):
123 | start = self._rank
124 | yield from itertools.islice(self._infinite_indices(), start, None, self._world_size)
125 |
126 | def _infinite_indices(self):
127 | g = torch.Generator()
128 | g.manual_seed(self._seed)
129 | while True:
130 | # Sample indices with repeats determined by stochastic rounding; each
131 | # "epoch" may have a slightly different size due to the rounding.
132 | indices = self._get_epoch_indices(g)
133 | if self._shuffle:
134 | randperm = torch.randperm(len(indices), generator=g)
135 | yield from indices[randperm].tolist()
136 | else:
137 | yield from indices.tolist()
138 |
139 |
140 | class InferenceSampler(Sampler):
141 | """
142 | Produce indices for inference across all workers.
143 | Inference needs to run on the __exact__ set of samples,
144 | therefore when the total number of samples is not divisible by the number of workers,
145 |     this sampler produces a different number of samples on different workers.
146 | """
147 |
148 | def __init__(self, size: int):
149 | """
150 | Args:
151 | size (int): the total number of data of the underlying dataset to sample from
152 | """
153 | self._size = size
154 | assert size > 0
155 | self._rank = comm.get_rank()
156 | self._world_size = comm.get_world_size()
157 | self._local_indices = self._get_local_indices(size, self._world_size, self._rank)
158 |
159 | @staticmethod
160 | def _get_local_indices(total_size, world_size, rank):
161 | shard_size = total_size // world_size
162 | left = total_size % world_size
163 | shard_sizes = [shard_size + int(r < left) for r in range(world_size)]
164 |
165 | begin = sum(shard_sizes[:rank])
166 | end = min(sum(shard_sizes[: rank + 1]), total_size)
167 | if end - begin < max(shard_sizes):
168 | assert begin > 0
169 | begin = begin - 1
170 | return range(begin, end)
171 |
172 | def __iter__(self):
173 | yield from self._local_indices
174 |
175 | def __len__(self):
176 | return len(self._local_indices)
--------------------------------------------------------------------------------
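The fractional repeat factors above are realized by stochastic rounding in `_get_epoch_indices`. A standalone sketch of that step (the factor values are made up):

    import torch

    repeat_factors = torch.tensor([1.0, 2.3, 0.5])   # hypothetical per-image factors
    g = torch.Generator()
    g.manual_seed(0)

    int_part = torch.trunc(repeat_factors)
    frac_part = repeat_factors - int_part
    rands = torch.rand(len(frac_part), generator=g)
    rep = int_part + (rands < frac_part).float()

    indices = [i for i, r in enumerate(rep) for _ in range(int(r.item()))]
    # Image 1 appears 2 or 3 times (expected 2.3); image 2 appears 0 or 1 times (expected 0.5).
    print(indices)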
/wsovod/engine/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | from .hooks import *
3 | from .trainer import *
4 |
--------------------------------------------------------------------------------
/wsovod/engine/hooks.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 |
4 | import datetime
5 | import itertools
6 | import logging
7 | import math
8 | import operator
9 | import os
10 | import tempfile
11 | import time
12 | import warnings
13 | from collections import Counter
14 |
15 | import detectron2.utils.comm as comm
16 | import torch
17 | from detectron2.engine.train_loop import HookBase
18 | from detectron2.evaluation.testing import flatten_results_dict
19 | from detectron2.solver import LRMultiplier
20 | from detectron2.utils.events import EventStorage, EventWriter
21 | from detectron2.utils.file_io import PathManager
22 | from fvcore.common.checkpoint import Checkpointer
23 | from fvcore.common.checkpoint import \
24 | PeriodicCheckpointer as _PeriodicCheckpointer
25 | from fvcore.common.param_scheduler import ParamScheduler
26 | from fvcore.common.timer import Timer
27 | from fvcore.nn.precise_bn import get_bn_modules, update_bn_stats
28 |
29 | # __all__ = [
30 | # "CallbackHook",
31 | # "IterationTimer",
32 | # "PeriodicWriter",
33 | # "PeriodicCheckpointer",
34 | # "BestCheckpointer",
35 | # "LRScheduler",
36 | # "AutogradProfiler",
37 | # "EvalHook",
38 | # "PreciseBN",
39 | # "TorchProfiler",
40 | # "TorchMemoryStats",
41 | # ]
42 |
43 |
44 | """
45 | Implement some common hooks.
46 | """
47 |
48 | class ParametersNormInspectHook(HookBase):
49 |
50 | def __init__(self, period, model, p):
51 | self._period = period
52 | self._model = model
53 | self._p = p
54 | logger = logging.getLogger(__name__)
55 | logger.info('period, norm '+str((period,p)))
56 |
57 | @torch.no_grad()
58 | def _do_inspect(self):
59 | results = {}
60 | # logger = logging.getLogger(__name__)
61 |         for key, val in self._model.named_parameters(recurse=True):
62 |             results[key] = torch.norm(val, p=self._p)
63 |             self.trainer.storage.put_scalar('parameters norm {}/{}'.format(self._p, key), results[key])
64 | # logger.info(results)
65 |
66 | def after_step(self):
67 | next_iter = self.trainer.iter + 1
68 | if self._period > 0 and next_iter % self._period == 0:
69 | if next_iter != self.trainer.max_iter:
70 | self._do_inspect()
71 |
--------------------------------------------------------------------------------
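`ParametersNormInspectHook` above plugs into a detectron2 training loop like any other `HookBase`. A hedged sketch, assuming `trainer` is an already-built trainer instance (e.g. the one in `wsovod/engine/trainer.py`):

    from wsovod.engine.hooks import ParametersNormInspectHook

    # Log the L2 norm of every parameter to the event storage every 100 iterations.
    trainer.register_hooks([ParametersNormInspectHook(period=100, model=trainer.model, p=2)])
    trainer.train()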
/wsovod/evaluation/__init__.py:
--------------------------------------------------------------------------------
1 | from .pascal_voc_evaluation import PascalVOCDetectionEvaluator_WSL
2 | from .coco_evaluation import COCOEvaluator
3 |
4 | __all__ = [k for k in globals().keys() if not k.startswith("_")]
5 |
--------------------------------------------------------------------------------
/wsovod/layers/ROILoopPool/ROILoopPool.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | #pragma once
3 | #include <torch/types.h>
4 |
5 | namespace wsovod {
6 |
7 | std::tuple<at::Tensor, at::Tensor> ROILoopPool_forward_cpu(
8 | const at::Tensor& input,
9 | const at::Tensor& rois,
10 | const float spatial_scale,
11 | const int pooled_height,
12 | const int pooled_width);
13 |
14 | at::Tensor ROILoopPool_backward_cpu(
15 | const at::Tensor& grad,
16 | const at::Tensor& rois,
17 | const at::Tensor& argmax,
18 | const float spatial_scale,
19 | const int pooled_height,
20 | const int pooled_width,
21 | const int batch_size,
22 | const int channels,
23 | const int height,
24 | const int width);
25 |
26 | #if defined(WITH_CUDA) || defined(WITH_HIP)
27 | std::tuple<at::Tensor, at::Tensor> ROILoopPool_forward_cuda(
28 | const at::Tensor& input,
29 | const at::Tensor& rois,
30 | const float spatial_scale,
31 | const int pooled_height,
32 | const int pooled_width);
33 |
34 | at::Tensor ROILoopPool_backward_cuda(
35 | const at::Tensor& grad,
36 | const at::Tensor& rois,
37 | const at::Tensor& argmax,
38 | const float spatial_scale,
39 | const int pooled_height,
40 | const int pooled_width,
41 | const int batch_size,
42 | const int channels,
43 | const int height,
44 | const int width);
45 | #endif
46 |
47 | // Interface for Python
48 | inline std::tuple<at::Tensor, at::Tensor> ROILoopPool_forward(
49 | const at::Tensor& input,
50 | const at::Tensor& rois,
51 | const float spatial_scale,
52 | const int pooled_height,
53 | const int pooled_width) {
54 | if (input.is_cuda()) {
55 | #if defined(WITH_CUDA) || defined(WITH_HIP)
56 | return ROILoopPool_forward_cuda(
57 | input, rois, spatial_scale, pooled_height, pooled_width);
58 | #else
59 | AT_ERROR("Not compiled with GPU support");
60 | #endif
61 | }
62 | AT_ERROR("Not compiled with CPU support");
63 | return ROILoopPool_forward_cpu(
64 | input, rois, spatial_scale, pooled_height, pooled_width);
65 | }
66 |
67 | inline at::Tensor ROILoopPool_backward(
68 | const at::Tensor& grad,
69 | const at::Tensor& rois,
70 | const at::Tensor& argmax,
71 | const float spatial_scale,
72 | const int pooled_height,
73 | const int pooled_width,
74 | const int batch_size,
75 | const int channels,
76 | const int height,
77 | const int width) {
78 | if (grad.is_cuda()) {
79 | #if defined(WITH_CUDA) || defined(WITH_HIP)
80 | return ROILoopPool_backward_cuda(
81 | grad,
82 | rois,
83 | argmax,
84 | spatial_scale,
85 | pooled_height,
86 | pooled_width,
87 | batch_size,
88 | channels,
89 | height,
90 | width);
91 | #else
92 | AT_ERROR("Not compiled with GPU support");
93 | #endif
94 | }
95 | AT_ERROR("Not compiled with CPU support");
96 | return ROILoopPool_backward_cpu(
97 | grad,
98 | rois,
99 | argmax,
100 | spatial_scale,
101 | pooled_height,
102 | pooled_width,
103 | batch_size,
104 | channels,
105 | height,
106 | width);
107 | }
108 |
109 | } // namespace wsovod
110 |
--------------------------------------------------------------------------------
/wsovod/layers/ROILoopPool/ROILoopPool_cpu.cpp:
--------------------------------------------------------------------------------
1 | #include <ATen/ATen.h>
2 | #include <ATen/TensorUtils.h>
3 | #include <cfloat>
4 | #include <cmath>
5 |
6 | #include "ROILoopPool.h"
7 |
8 | template <typename T>
9 | inline void add(T* address, const T& val) {
10 | *address += val;
11 | }
12 |
13 | template <typename T>
14 | void RoILoopPoolForward(
15 | const T* input,
16 | const T spatial_scale,
17 | const int channels,
18 | const int height,
19 | const int width,
20 | const int pooled_height,
21 | const int pooled_width,
22 | const T* rois,
23 | const int num_rois,
24 | T* output,
25 | int* argmax_data) {
26 | for (int n = 0; n < num_rois; ++n) {
27 | const T* offset_rois = rois + n * 5;
28 | int roi_batch_ind = offset_rois[0];
29 | int roi_start_w = round(offset_rois[1] * spatial_scale);
30 | int roi_start_h = round(offset_rois[2] * spatial_scale);
31 | int roi_end_w = round(offset_rois[3] * spatial_scale);
32 | int roi_end_h = round(offset_rois[4] * spatial_scale);
33 |
34 | // Force malformed ROIs to be 1x1
35 | int roi_width = std::max(roi_end_w - roi_start_w + 1, 1);
36 | int roi_height = std::max(roi_end_h - roi_start_h + 1, 1);
37 |     T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
38 |     T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
39 |
40 | for (int ph = 0; ph < pooled_height; ++ph) {
41 | for (int pw = 0; pw < pooled_width; ++pw) {
42 |         int hstart = static_cast<int>(floor(static_cast<T>(ph) * bin_size_h));
43 |         int wstart = static_cast<int>(floor(static_cast<T>(pw) * bin_size_w));
44 |         int hend = static_cast<int>(ceil(static_cast<T>(ph + 1) * bin_size_h));
45 |         int wend = static_cast<int>(ceil(static_cast<T>(pw + 1) * bin_size_w));
46 |
47 | // Add roi offsets and clip to input boundaries
48 | hstart = std::min(std::max(hstart + roi_start_h, 0), height);
49 | hend = std::min(std::max(hend + roi_start_h, 0), height);
50 | wstart = std::min(std::max(wstart + roi_start_w, 0), width);
51 | wend = std::min(std::max(wend + roi_start_w, 0), width);
52 | bool is_empty = (hend <= hstart) || (wend <= wstart);
53 |
54 | for (int c = 0; c < channels; ++c) {
55 | // Define an empty pooling region to be zero
56 | T maxval = is_empty ? 0 : -FLT_MAX;
57 | // If nothing is pooled, argmax = -1 causes nothing to be backprop'd
58 | int maxidx = -1;
59 |
60 | const T* input_offset =
61 | input + (roi_batch_ind * channels + c) * height * width;
62 |
63 | for (int h = hstart; h < hend; ++h) {
64 | for (int w = wstart; w < wend; ++w) {
65 | int input_index = h * width + w;
66 | if (input_offset[input_index] > maxval) {
67 | maxval = input_offset[input_index];
68 | maxidx = input_index;
69 | }
70 | }
71 | }
72 | int index =
73 | ((n * channels + c) * pooled_height + ph) * pooled_width + pw;
74 | output[index] = maxval;
75 | argmax_data[index] = maxidx;
76 | } // channels
77 | } // pooled_width
78 | } // pooled_height
79 | } // num_rois
80 | }
81 |
82 | template <typename T>
83 | void RoILoopPoolBackward(
84 | const T* grad_output,
85 | const int* argmax_data,
86 | const int num_rois,
87 | const int channels,
88 | const int height,
89 | const int width,
90 | const int pooled_height,
91 | const int pooled_width,
92 | T* grad_input,
93 | const T* rois,
94 | const int n_stride,
95 | const int c_stride,
96 | const int h_stride,
97 | const int w_stride) {
98 | for (int n = 0; n < num_rois; ++n) {
99 | const T* offset_rois = rois + n * 5;
100 | int roi_batch_ind = offset_rois[0];
101 |
102 | for (int c = 0; c < channels; ++c) {
103 | T* grad_input_offset =
104 | grad_input + ((roi_batch_ind * channels + c) * height * width);
105 | const int* argmax_data_offset =
106 | argmax_data + (n * channels + c) * pooled_height * pooled_width;
107 |
108 | for (int ph = 0; ph < pooled_height; ++ph) {
109 | for (int pw = 0; pw < pooled_width; ++pw) {
110 | int output_offset = n * n_stride + c * c_stride;
111 | int argmax = argmax_data_offset[ph * pooled_width + pw];
112 |
113 | if (argmax != -1) {
114 | add(grad_input_offset + argmax,
115 |                 static_cast<T>(
116 | grad_output
117 | [output_offset + ph * h_stride + pw * w_stride]));
118 | }
119 | } // pooled_width
120 | } // pooled_height
121 | } // channels
122 | } // num_rois
123 | }
124 |
125 | namespace wsovod {
126 |
127 | std::tuple<at::Tensor, at::Tensor> ROILoopPool_forward_cpu(
128 | const at::Tensor& input,
129 | const at::Tensor& rois,
130 | const float spatial_scale,
131 | const int pooled_height,
132 | const int pooled_width) {
133 | AT_ASSERTM(input.device().is_cpu(), "input must be a CPU tensor");
134 | AT_ASSERTM(rois.device().is_cpu(), "rois must be a CPU tensor");
135 |
136 | at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2};
137 |
138 | at::CheckedFrom c = "ROILoopPool_forward_cpu";
139 | at::checkAllSameType(c, {input_t, rois_t});
140 |
141 | int num_rois = rois.size(0);
142 | int channels = input.size(1);
143 | int height = input.size(2);
144 | int width = input.size(3);
145 |
146 | at::Tensor output = at::zeros(
147 | {num_rois, channels, pooled_height, pooled_width}, input.options());
148 | at::Tensor argmax = at::zeros(
149 | {num_rois, channels, pooled_height, pooled_width},
150 | input.options().dtype(at::kInt));
151 |
152 | if (output.numel() == 0) {
153 | return std::make_tuple(output, argmax);
154 | }
155 |
156 | auto input_ = input.contiguous(), rois_ = rois.contiguous();
157 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(
158 | input.scalar_type(), "ROILoopPool_forward", [&] {
159 |         RoILoopPoolForward<scalar_t>(
160 |             input_.data_ptr<scalar_t>(),
161 |             spatial_scale,
162 |             channels,
163 |             height,
164 |             width,
165 |             pooled_height,
166 |             pooled_width,
167 |             rois_.data_ptr<scalar_t>(),
168 |             num_rois,
169 |             output.data_ptr<scalar_t>(),
170 |             argmax.data_ptr<int>());
171 | });
172 | return std::make_tuple(output, argmax);
173 | }
174 |
175 | at::Tensor ROILoopPool_backward_cpu(
176 | const at::Tensor& grad,
177 | const at::Tensor& rois,
178 | const at::Tensor& argmax,
179 | const float spatial_scale,
180 | const int pooled_height,
181 | const int pooled_width,
182 | const int batch_size,
183 | const int channels,
184 | const int height,
185 | const int width) {
186 | // Check if input tensors are CPU tensors
187 | AT_ASSERTM(grad.device().is_cpu(), "grad must be a CPU tensor");
188 | AT_ASSERTM(rois.device().is_cpu(), "rois must be a CPU tensor");
189 | AT_ASSERTM(argmax.device().is_cpu(), "argmax must be a CPU tensor");
190 |
191 | at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2};
192 |
193 | at::CheckedFrom c = "ROILoopPool_backward_cpu";
194 | at::checkAllSameType(c, {grad_t, rois_t});
195 |
196 | auto num_rois = rois.size(0);
197 |
198 | at::Tensor grad_input =
199 | at::zeros({batch_size, channels, height, width}, grad.options());
200 |
201 | // handle possibly empty gradients
202 | if (grad.numel() == 0) {
203 | return grad_input;
204 | }
205 |
206 | // get stride values to ensure indexing into gradients is correct.
207 | int n_stride = grad.stride(0);
208 | int c_stride = grad.stride(1);
209 | int h_stride = grad.stride(2);
210 | int w_stride = grad.stride(3);
211 |
212 | auto rois_ = rois.contiguous();
213 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(
214 | grad.scalar_type(), "ROILoopPool_backward", [&] {
215 |         RoILoopPoolBackward<scalar_t>(
216 |             grad.data_ptr<scalar_t>(),
217 |             argmax.data_ptr<int>(),
218 |             num_rois,
219 |             channels,
220 |             height,
221 |             width,
222 |             pooled_height,
223 |             pooled_width,
224 |             grad_input.data_ptr<scalar_t>(),
225 |             rois_.data_ptr<scalar_t>(),
226 | n_stride,
227 | c_stride,
228 | h_stride,
229 | w_stride);
230 | });
231 | return grad_input;
232 | }
233 |
234 | } // namespace wsovod
235 |
--------------------------------------------------------------------------------
/wsovod/layers/ROILoopPool/cuda_helpers.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #define CUDA_1D_KERNEL_LOOP(i, n) \
4 | for (int i = (blockIdx.x * blockDim.x) + threadIdx.x; i < (n); \
5 | i += (blockDim.x * gridDim.x))
6 |
7 | template <typename integer>
8 | constexpr __host__ __device__ inline integer ceil_div(integer n, integer m) {
9 | return (n + m - 1) / m;
10 | }
11 |
--------------------------------------------------------------------------------
/wsovod/layers/__init__.py:
--------------------------------------------------------------------------------
1 | from .roi_loop_pool import ROILoopPool
2 | from .csc import CSC, CSCConstraint, csc, csc_constraint
3 |
--------------------------------------------------------------------------------
/wsovod/layers/csc.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.autograd import Function
4 | from torch.autograd.function import once_differentiable
5 |
6 | from wsovod import _C
7 |
8 |
9 | class _CSC(Function):
10 | @staticmethod
11 | def forward(
12 | ctx,
13 | cpgs,
14 | labels,
15 | preds,
16 | rois,
17 | tau,
18 | debug_info,
19 | fg_threshold,
20 | mass_threshold,
21 | density_threshold,
22 | area_sqrt,
23 | context_scale,
24 | ):
25 | PL = labels.clone().detach()
26 | NL = torch.zeros(labels.size(), dtype=labels.dtype, device=labels.device)
27 |
28 | W = _C.csc_forward(
29 | cpgs,
30 | labels,
31 | preds,
32 | rois,
33 | tau,
34 | debug_info,
35 | fg_threshold,
36 | mass_threshold,
37 | density_threshold,
38 | area_sqrt,
39 | context_scale,
40 | )
41 |
42 | return W, PL, NL
43 |
44 | @staticmethod
45 | @once_differentiable
46 | def backward(ctx, dW, dPL, dNL):
47 |         return (None,) * 11  # one (absent) gradient per forward input
48 |
49 |
50 | csc = _CSC.apply
51 |
52 |
53 | class CSC(nn.Module):
54 | def __init__(
55 | self,
56 | tau=0.7,
57 | debug_info=False,
58 | fg_threshold=0.1,
59 | mass_threshold=0.2,
60 | density_threshold=0.0,
61 | area_sqrt=True,
62 | context_scale=1.8,
63 | ):
64 | super(CSC, self).__init__()
65 |
66 | self.tau = tau
67 | self.debug_info = debug_info
68 | self.fg_threshold = fg_threshold
69 | self.mass_threshold = mass_threshold
70 | self.density_threshold = density_threshold
71 | self.area_sqrt = area_sqrt
72 | self.context_scale = context_scale
73 |
74 | def forward(self, cpgs, labels, preds, rois):
75 | return csc(
76 | cpgs,
77 | labels,
78 | preds,
79 | rois,
80 | self.tau,
81 | self.debug_info,
82 | self.fg_threshold,
83 | self.mass_threshold,
84 | self.density_threshold,
85 | self.area_sqrt,
86 | self.context_scale,
87 | )
88 |
89 | def __repr__(self):
90 | tmpstr = self.__class__.__name__ + "("
91 | tmpstr += "tau=" + str(self.tau)
92 | tmpstr += ", debug_info=" + str(self.debug_info)
93 | tmpstr += ", fg_threshold=" + str(self.fg_threshold)
94 | tmpstr += ", mass_threshold=" + str(self.mass_threshold)
95 | tmpstr += ", density_threshold=" + str(self.density_threshold)
96 | tmpstr += ", area_sqrt=" + str(self.area_sqrt)
97 | tmpstr += ", context_scale=" + str(self.context_scale)
98 | tmpstr += ")"
99 | return tmpstr
100 |
101 |
102 | class _CSCConstraint(Function):
103 | @staticmethod
104 | def forward(ctx, X, W, polar):
105 |
106 | if polar:
107 | W_ = torch.clamp(W, min=0.0)
108 | else:
109 | W_ = torch.clamp(W, max=0.0)
110 | W_ = W_ * (-1.0)
111 |
112 | ctx.save_for_backward(W_)
113 |
114 | Y = X * W_
115 |
116 | return Y
117 |
118 | @staticmethod
119 | @once_differentiable
120 | def backward(ctx, dY):
121 | (W_,) = ctx.saved_tensors
122 |
123 | dX = dY * W_
124 |
125 | return dX, None, None
126 |
127 |
128 | csc_constraint = _CSCConstraint.apply
129 |
130 |
131 | class CSCConstraint(nn.Module):
132 | def __init__(self, polar=True):
133 | super(CSCConstraint, self).__init__()
134 |
135 | self.polar = polar
136 |
137 | def forward(self, X, W):
138 | return csc_constraint(X, W, self.polar)
139 |
140 | def __repr__(self):
141 | tmpstr = self.__class__.__name__ + "("
142 | tmpstr += "polar=" + str(self.polar)
143 | tmpstr += ")"
144 | return tmpstr
--------------------------------------------------------------------------------
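`CSCConstraint` above is plain PyTorch (only the `CSC` module calls the compiled `_C.csc_forward` op, though importing `wsovod.layers` still requires the built extension). A small sketch of the polarity behaviour with made-up score and weight tensors:

    import torch
    from wsovod.layers import CSCConstraint  # needs the compiled wsovod._C extension to import

    scores = torch.rand(4, 20, requires_grad=True)   # hypothetical per-class proposal scores
    W = torch.randn(4, 20)                           # hypothetical CSC weights

    pos = CSCConstraint(polar=True)(scores, W)    # keeps only W >= 0 and scales scores by it
    neg = CSCConstraint(polar=False)(scores, W)   # keeps only W <= 0, flipped to positive
    pos.sum().backward()                          # gradients flow through `scores` only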
/wsovod/layers/csc/csc.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include <torch/types.h>
3 |
4 | namespace wsovod {
5 |
6 | #ifdef WITH_CUDA
7 | at::Tensor csc_forward_cuda(
8 | const at::Tensor& cpgs,
9 | const at::Tensor& labels,
10 | const at::Tensor& preds,
11 | const at::Tensor& rois,
12 | const float tau_,
13 | const bool debug_info_,
14 | const float fg_threshold_,
15 | const float mass_threshold_,
16 | const float density_threshold_,
17 | const bool area_sqrt_,
18 | const float context_scale_);
19 | #endif
20 |
21 | // Interface for Python
22 | inline at::Tensor csc_forward(
23 | const at::Tensor& cpgs,
24 | const at::Tensor& labels,
25 | const at::Tensor& preds,
26 | const at::Tensor& rois,
27 | const float tau_ = 0.7,
28 | const bool debug_info_ = false,
29 | const float fg_threshold_ = 0.1,
30 | const float mass_threshold_ = 0.2,
31 | const float density_threshold_ = 0.0,
32 | const bool area_sqrt_ = true,
33 | const float context_scale_ = 1.8) {
34 | if (cpgs.device().type() == at::kCUDA) {
35 | #ifdef WITH_CUDA
36 | return csc_forward_cuda(
37 | cpgs,
38 | labels,
39 | preds,
40 | rois,
41 | tau_,
42 | debug_info_,
43 | fg_threshold_,
44 | mass_threshold_,
45 | density_threshold_,
46 | area_sqrt_,
47 | context_scale_);
48 | #else
49 | AT_ERROR("Not compiled with GPU support");
50 | #endif
51 | }
52 | AT_ERROR("Not compiled with CPU support");
53 | }
54 |
55 | } // namespace wsovod
56 |
--------------------------------------------------------------------------------
/wsovod/layers/roi_loop_pool.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | from torch.autograd import Function
3 | from torch.autograd.function import once_differentiable
4 | from torch.nn.modules.utils import _pair
5 |
6 | from wsovod import _C
7 |
8 |
9 | class _ROILoopPool(Function):
10 | @staticmethod
11 | def forward(ctx, input, roi, output_size, spatial_scale):
12 | ctx.output_size = _pair(output_size)
13 | ctx.spatial_scale = spatial_scale
14 | ctx.input_shape = input.size()
15 | output, argmax = _C.roi_loop_pool_forward(
16 |             input, roi, spatial_scale, ctx.output_size[0], ctx.output_size[1]
17 | )
18 | ctx.save_for_backward(roi, argmax)
19 | ctx.mark_non_differentiable(argmax)
20 | return output
21 |
22 | @staticmethod
23 | @once_differentiable
24 | def backward(ctx, grad_output):
25 | (rois, argmax) = ctx.saved_tensors
26 | output_size = ctx.output_size
27 | spatial_scale = ctx.spatial_scale
28 | bs, ch, h, w = ctx.input_shape
29 | grad_input = _C.roi_loop_pool_backward(
30 | grad_output, rois, argmax, spatial_scale, output_size[0], output_size[1], bs, ch, h, w
31 | )
32 | return grad_input, None, None, None
33 |
34 |
35 | roi_loop_pool = _ROILoopPool.apply
36 |
37 |
38 | class ROILoopPool(nn.Module):
39 | def __init__(self, output_size, spatial_scale):
40 | super(ROILoopPool, self).__init__()
41 | self.output_size = output_size
42 | self.spatial_scale = spatial_scale
43 |
44 | def forward(self, input, rois):
45 | """
46 | Args:
47 | input: NCHW images
48 | rois: Bx5 boxes. First column is the index into N. The other 4 columns are xyxy.
49 | """
50 | assert rois.dim() == 2 and rois.size(1) == 5
51 | return roi_loop_pool(input, rois, self.output_size, self.spatial_scale)
52 |
53 | def __repr__(self):
54 | tmpstr = self.__class__.__name__ + "("
55 | tmpstr += "output_size=" + str(self.output_size)
56 | tmpstr += ", spatial_scale=" + str(self.spatial_scale)
57 | tmpstr += ")"
58 | return tmpstr
59 |
--------------------------------------------------------------------------------
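A hedged usage sketch for `ROILoopPool` above; the shapes and scale are made up, and a CUDA build of the `wsovod._C` extension is assumed (see `ROILoopPool.h` for the device dispatch):

    import torch
    from wsovod.layers import ROILoopPool

    features = torch.randn(1, 512, 38, 50, device="cuda")   # e.g. a dilated-C5 feature map
    # First column is the image index into the batch; the rest is an xyxy box in image coordinates.
    rois = torch.tensor([[0., 16., 16., 320., 240.],
                         [0., 100., 60., 400., 300.]], device="cuda")

    pool = ROILoopPool(output_size=(7, 7), spatial_scale=1.0 / 8)
    pooled = pool(features, rois)
    print(pooled.shape)   # pooled per-ROI features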
/wsovod/layers/vision.cpp:
--------------------------------------------------------------------------------
1 | // Copyright (c) Facebook, Inc. and its affiliates.
2 |
3 | #include <torch/extension.h>
4 | #include "ROILoopPool/ROILoopPool.h"
5 | #include "csc/csc.h"
6 |
7 | namespace wsovod {
8 |
9 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
10 | m.def("roi_loop_pool_forward", &ROILoopPool_forward, "ROILoopPool_forward");
11 | m.def("roi_loop_pool_backward", &ROILoopPool_backward, "ROILoopPool_backward");
12 | m.def("csc_forward", &csc_forward, "csc_forward");
13 | }
14 |
15 | } // namespace wsovod
16 |
--------------------------------------------------------------------------------
/wsovod/modeling/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 |
3 | from .backbone import *
4 | from .meta_arch import *
5 | from .postprocessing import detector_postprocess
6 | from .roi_heads import *
7 | from .proposal_generator import *
8 | from .test_time_augmentation_avg import (DatasetMapperTTAAVG,
9 | GeneralizedRCNNWithTTAAVG)
10 | from .test_time_augmentation_union import (DatasetMapperTTAUNION,
11 | GeneralizedRCNNWithTTAUNION)
12 |
13 |
14 |
15 | _EXCLUDE = {"ShapeSpec"}
16 | __all__ = [k for k in globals().keys() if k not in _EXCLUDE and not k.startswith("_")]
17 |
--------------------------------------------------------------------------------
/wsovod/modeling/backbone/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | from .swin_transformer import build_swin_backbone, build_swin_fpn_backbone
3 | from .resnet_wsl import build_wsl_resnet_backbone
4 | from .resnet_wsl_mrrp import build_mrrp_wsl_resnet_backbone
5 | from .vgg import VGG16, PlainBlockBase, build_vgg_backbone
6 | from .vgg_mrrp import build_mrrp_vgg_backbone
7 |
8 | # TODO can expose more resnet blocks after careful consideration
9 |
--------------------------------------------------------------------------------
/wsovod/modeling/backbone/mrrp_conv.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | import torch
3 | from torch import nn
4 | from torch.nn import functional as F
5 | from torch.nn.modules.utils import _pair
6 |
7 | from detectron2.layers.wrappers import _NewEmptyTensorOp
8 |
9 |
10 | class MRRPConv(nn.Module):
11 | def __init__(
12 | self,
13 | in_channels,
14 | out_channels,
15 | kernel_size,
16 | stride=1,
17 | paddings=0,
18 | dilations=1,
19 | groups=1,
20 | num_branch=1,
21 | test_branch_idx=-1,
22 | bias=False,
23 | norm=None,
24 | activation=None,
25 | ):
26 | super(MRRPConv, self).__init__()
27 | self.in_channels = in_channels
28 | self.out_channels = out_channels
29 | self.kernel_size = _pair(kernel_size)
30 | self.num_branch = num_branch
31 | self.stride = _pair(stride)
32 | self.groups = groups
33 | self.with_bias = bias
34 | if isinstance(paddings, int):
35 | paddings = [paddings] * self.num_branch
36 | if isinstance(dilations, int):
37 | dilations = [dilations] * self.num_branch
38 | self.paddings = [_pair(padding) for padding in paddings]
39 | self.dilations = [_pair(dilation) for dilation in dilations]
40 | self.test_branch_idx = test_branch_idx
41 | self.norm = norm
42 | self.activation = activation
43 |
44 | assert len({self.num_branch, len(self.paddings), len(self.dilations)}) == 1
45 |
46 | self.weight = nn.Parameter(
47 | torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)
48 | )
49 | if bias:
50 | self.bias = nn.Parameter(torch.Tensor(out_channels))
51 | else:
52 | self.bias = None
53 |
54 | nn.init.kaiming_uniform_(self.weight, nonlinearity="relu")
55 | if self.bias is not None:
56 | nn.init.constant_(self.bias, 0)
57 |
58 | def forward(self, inputs):
59 | num_branch = self.num_branch if self.training or self.test_branch_idx == -1 else 1
60 | assert len(inputs) == num_branch
61 |
62 | if inputs[0].numel() == 0:
63 | output_shape = [
64 | (i + 2 * p - (di * (k - 1) + 1)) // s + 1
65 | for i, p, di, k, s in zip(
66 |                     inputs[0].shape[-2:], self.paddings[0], self.dilations[0], self.kernel_size, self.stride
67 |                 )
68 |             ]
69 |             output_shape = [inputs[0].shape[0], self.weight.shape[0]] + output_shape
70 | return [_NewEmptyTensorOp.apply(input, output_shape) for input in inputs]
71 |
72 | if self.training or self.test_branch_idx == -1:
73 | outputs = [
74 | F.conv2d(input, self.weight, self.bias, self.stride, padding, dilation, self.groups)
75 | for input, dilation, padding in zip(inputs, self.dilations, self.paddings)
76 | ]
77 | else:
78 | outputs = [
79 | F.conv2d(
80 | inputs[0],
81 | self.weight,
82 | self.bias,
83 | self.stride,
84 | self.paddings[self.test_branch_idx],
85 | self.dilations[self.test_branch_idx],
86 | self.groups,
87 | )
88 | ]
89 |
90 | if self.norm is not None:
91 | outputs = [self.norm(x) for x in outputs]
92 | if self.activation is not None:
93 | outputs = [self.activation(x) for x in outputs]
94 | return outputs
95 |
96 | def extra_repr(self):
97 | tmpstr = "in_channels=" + str(self.in_channels)
98 | tmpstr += ", out_channels=" + str(self.out_channels)
99 | tmpstr += ", kernel_size=" + str(self.kernel_size)
100 | tmpstr += ", num_branch=" + str(self.num_branch)
101 | tmpstr += ", test_branch_idx=" + str(self.test_branch_idx)
102 | tmpstr += ", stride=" + str(self.stride)
103 | tmpstr += ", paddings=" + str(self.paddings)
104 | tmpstr += ", dilations=" + str(self.dilations)
105 | tmpstr += ", groups=" + str(self.groups)
106 | tmpstr += ", bias=" + str(self.with_bias)
107 | return tmpstr
--------------------------------------------------------------------------------
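A small sketch of `MRRPConv` above with the three branches configured by `MODEL.MRRP` in `defaults.py` (the channel counts and input sizes are made up). All branches share one set of weights but use different paddings/dilations:

    import torch
    from wsovod.modeling.backbone.mrrp_conv import MRRPConv

    conv = MRRPConv(64, 64, kernel_size=3, paddings=[1, 2, 3], dilations=[1, 2, 3], num_branch=3)
    inputs = [torch.randn(2, 64, 32, 32) for _ in range(3)]   # one tensor per branch
    outputs = conv(inputs)
    # Padding matches dilation, so every branch keeps the 32x32 spatial size.
    print([tuple(o.shape) for o in outputs])   # [(2, 64, 32, 32)] * 3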
/wsovod/modeling/backbone/vgg.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | import fvcore.nn.weight_init as weight_init
3 | import torch.nn.functional as F
4 | from detectron2.layers import Conv2d, FrozenBatchNorm2d, ShapeSpec
5 | from detectron2.modeling.backbone.backbone import Backbone
6 | from detectron2.modeling.backbone.build import BACKBONE_REGISTRY
7 | from torch import nn
8 |
9 | __all__ = ["PlainBlockBase", "PlainBlock", "VGG16", "build_vgg_backbone"]
10 |
11 |
12 | class PlainBlockBase(nn.Module):
13 | def __init__(self, in_channels, out_channels, stride):
14 | """
15 | The `__init__` method of any subclass should also contain these arguments.
16 |
17 | Args:
18 | in_channels (int):
19 | out_channels (int):
20 | stride (int):
21 | """
22 | super().__init__()
23 | self.in_channels = in_channels
24 | self.out_channels = out_channels
25 | self.stride = stride
26 |
27 | def freeze(self):
28 | for p in self.parameters():
29 | p.requires_grad = False
30 | FrozenBatchNorm2d.convert_frozen_batchnorm(self)
31 | return self
32 |
33 |
34 | class PlainBlock(PlainBlockBase):
35 | def __init__(self, in_channels, out_channels, num_conv=3, dilation=1, stride=1, has_pool=False):
36 | super().__init__(in_channels, out_channels, stride)
37 |
38 | self.num_conv = num_conv
39 | self.dilation = dilation
40 |
41 | self.has_pool = has_pool
42 | self.pool_stride = stride
43 |
44 | self.conv1 = Conv2d(
45 | in_channels,
46 | out_channels,
47 | kernel_size=3,
48 | stride=1,
49 | padding=1 * dilation,
50 | bias=True,
51 | groups=1,
52 | dilation=dilation,
53 | norm=None,
54 | )
55 | weight_init.c2_msra_fill(self.conv1)
56 |
57 | self.conv2 = Conv2d(
58 | out_channels,
59 | out_channels,
60 | kernel_size=3,
61 | stride=1,
62 | padding=1 * dilation,
63 | bias=True,
64 | groups=1,
65 | dilation=dilation,
66 | norm=None,
67 | )
68 | weight_init.c2_msra_fill(self.conv2)
69 |
70 | if self.num_conv > 2:
71 | self.conv3 = Conv2d(
72 | out_channels,
73 | out_channels,
74 | kernel_size=3,
75 | stride=1,
76 | padding=1 * dilation,
77 | bias=True,
78 | groups=1,
79 | dilation=dilation,
80 | norm=None,
81 | )
82 | weight_init.c2_msra_fill(self.conv3)
83 |
84 | if self.num_conv > 3:
85 | self.conv4 = Conv2d(
86 | out_channels,
87 | out_channels,
88 | kernel_size=3,
89 | stride=1,
90 | padding=1 * dilation,
91 | bias=True,
92 | groups=1,
93 | dilation=dilation,
94 | norm=None,
95 | )
96 | weight_init.c2_msra_fill(self.conv4)
97 |
98 | if self.has_pool:
99 | self.pool = nn.MaxPool2d(kernel_size=2, stride=self.pool_stride, padding=0)
100 |
101 | assert num_conv < 5
102 |
103 | def forward(self, x):
104 | x = self.conv1(x)
105 | x = F.relu_(x)
106 |
107 | x = self.conv2(x)
108 | x = F.relu_(x)
109 |
110 | if self.num_conv > 2:
111 | x = self.conv3(x)
112 | x = F.relu_(x)
113 |
114 | if self.num_conv > 3:
115 | x = self.conv4(x)
116 | x = F.relu_(x)
117 |
118 | if self.has_pool:
119 | x = self.pool(x)
120 |
121 | return x
122 |
123 |
124 | class VGG16(Backbone):
125 | def __init__(self, conv5_dilation, freeze_at, num_classes=None, out_features=None):
126 |         """
127 |         Args:
128 |             conv5_dilation (int): dilation used in the last (plain5) stage;
129 |                 a value of 2 keeps the feature stride at 8 instead of 16.
130 |             freeze_at (int): number of leading stages whose weights are frozen.
131 |             num_classes (None or int): if None, will not perform classification.
132 |             out_features (list[str]): name of the layers whose outputs should
133 |                 be returned in forward. Can be anything in "plain1" ... "plain5".
134 |                 If None, will return the output of the last layer.
135 |         """
136 | super(VGG16, self).__init__()
137 |
138 | self.num_classes = num_classes
139 |
140 | self._out_feature_strides = {}
141 | self._out_feature_channels = {}
142 |
143 | self.stages_and_names = []
144 |
145 | name = "plain1"
146 | block = PlainBlock(3, 64, num_conv=2, stride=2, has_pool=True)
147 | blocks = [block]
148 | stage = nn.Sequential(*blocks)
149 | self.add_module(name, stage)
150 | self.stages_and_names.append((stage, name))
151 | self._out_feature_strides[name] = 2
152 | self._out_feature_channels[name] = blocks[-1].out_channels
153 | if freeze_at >= 1:
154 | for block in blocks:
155 | block.freeze()
156 |
157 | name = "plain2"
158 | block = PlainBlock(64, 128, num_conv=2, stride=2, has_pool=True)
159 | blocks = [block]
160 | stage = nn.Sequential(*blocks)
161 | self.add_module(name, stage)
162 | self.stages_and_names.append((stage, name))
163 | self._out_feature_strides[name] = 4
164 | self._out_feature_channels[name] = blocks[-1].out_channels
165 | if freeze_at >= 2:
166 | for block in blocks:
167 | block.freeze()
168 |
169 | name = "plain3"
170 | block = PlainBlock(128, 256, num_conv=3, stride=2, has_pool=True)
171 | blocks = [block]
172 | stage = nn.Sequential(*blocks)
173 | self.add_module(name, stage)
174 | self.stages_and_names.append((stage, name))
175 | self._out_feature_strides[name] = 8
176 | self._out_feature_channels[name] = blocks[-1].out_channels
177 | if freeze_at >= 3:
178 | for block in blocks:
179 | block.freeze()
180 |
181 | name = "plain4"
182 | block = PlainBlock(
183 | 256, 512, num_conv=3, stride=1 if conv5_dilation == 2 else 2, has_pool=True
184 | )
185 | blocks = [block]
186 | stage = nn.Sequential(*blocks)
187 | self.add_module(name, stage)
188 | self.stages_and_names.append((stage, name))
189 | self._out_feature_strides[name] = 8 if conv5_dilation == 2 else 16
190 | self._out_feature_channels[name] = blocks[-1].out_channels
191 | if freeze_at >= 4:
192 | for block in blocks:
193 | block.freeze()
194 |
195 | name = "plain5"
196 | block = PlainBlock(512, 512, num_conv=3, stride=1, dilation=conv5_dilation, has_pool=False)
197 | blocks = [block]
198 | stage = nn.Sequential(*blocks)
199 | self.add_module(name, stage)
200 | self.stages_and_names.append((stage, name))
201 | self._out_feature_strides[name] = 8 if conv5_dilation == 2 else 16
202 | self._out_feature_channels[name] = blocks[-1].out_channels
203 | if freeze_at >= 5:
204 | for block in blocks:
205 | block.freeze()
206 |
207 | if out_features is None:
208 | out_features = [name]
209 | self._out_features = out_features
210 | assert len(self._out_features)
211 | children = [x[0] for x in self.named_children()]
212 | for out_feature in self._out_features:
213 | assert out_feature in children, "Available children: {}".format(", ".join(children))
214 |
215 | def forward(self, x):
216 | outputs = {}
217 | for stage, name in self.stages_and_names:
218 | x = stage(x)
219 | if name in self._out_features:
220 | outputs[name] = x
221 |
222 | return outputs
223 |
224 | def output_shape(self):
225 | return {
226 | name: ShapeSpec(
227 | channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
228 | )
229 | for name in self._out_features
230 | }
231 |
232 |
233 | @BACKBONE_REGISTRY.register()
234 | def build_vgg_backbone(cfg, input_shape):
235 |
236 | # fmt: off
237 | depth = cfg.MODEL.VGG.DEPTH
238 | conv5_dilation = cfg.MODEL.VGG.CONV5_DILATION
239 | freeze_at = cfg.MODEL.BACKBONE.FREEZE_AT
240 | # fmt: on
241 |
242 |     if depth == 16:
243 |         return VGG16(conv5_dilation, freeze_at)
244 |     raise NotImplementedError("Unsupported VGG depth: {}".format(depth))
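
A minimal usage sketch (not part of the repository) of constructing this backbone directly and inspecting its declared output shape; the argument values below are illustrative assumptions, not configuration defaults.

# Hedged sketch: build the VGG16 backbone directly and check its outputs.
import torch
from wsovod.modeling.backbone.vgg import VGG16

backbone = VGG16(conv5_dilation=1, freeze_at=2, out_features=["plain5"])
print(backbone.output_shape())   # e.g. {'plain5': ShapeSpec(channels=512, stride=16, ...)}

with torch.no_grad():
    feats = backbone(torch.zeros(1, 3, 224, 224))
print(feats["plain5"].shape)     # torch.Size([1, 512, 14, 14]) at stride 16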
--------------------------------------------------------------------------------
/wsovod/modeling/class_heads/__init__.py:
--------------------------------------------------------------------------------
1 | from .open_vocabulary_classifier import OpenVocabularyClassifier
--------------------------------------------------------------------------------
/wsovod/modeling/class_heads/data_aware_features_head.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | import os
3 | from pathlib import Path
4 | from typing import Dict, List, Optional, Tuple
5 |
6 | import cv2
7 | import fvcore.nn.weight_init as weight_init
8 | import numpy as np
9 | import torch
10 | from detectron2.config import configurable
11 | from detectron2.layers import Conv2d, Linear, ShapeSpec, get_norm
12 | from detectron2.structures import Boxes, ImageList, Instances
13 | from detectron2.utils.events import get_event_storage
14 | from torch import nn
15 | from torch.nn import functional as F
16 | import logging
17 |
18 |
19 | class DataAwareFeaturesHead(nn.Module):
20 | @configurable
21 | def __init__(
22 | self,
23 | input_shape: ShapeSpec,
24 | *,
25 | datasets_prototype_num: int = 5,
26 | features_dim: int = 512,
27 | cls_in_features: List[str],
28 | mrrp_on: bool = False,
29 | mrrp_num_branch: int = 3,
30 | ):
31 | """
32 | NOTE: this interface is experimental.
33 |
34 | Args:
35 | input_shape (ShapeSpec): shape of the input feature.
36 |             datasets_prototype_num (int): number of learned dataset prototype embeddings.
37 |             features_dim (int): dimension of the data-aware feature vector.
38 |             cls_in_features (list[str]): names of the input feature maps to pool from.
39 |             mrrp_on (bool) / mrrp_num_branch (int): whether MRRP is enabled and how many branches to average.
40 | """
41 | super().__init__()
42 |
43 | self.in_features = self.cls_in_features = cls_in_features
44 | self.features_dim = features_dim
45 | self.mrrp_on = mrrp_on
46 | self.mrrp_num_branch = mrrp_num_branch
47 |
48 | self.in_channels = [input_shape[f].channels for f in self.in_features]
49 | in_channels = [input_shape[f].channels for f in self.in_features]
50 | # Check all channel counts are equal
51 | assert len(set(in_channels)) == 1, in_channels
52 | in_channels = in_channels[0]
53 |
54 | self.strides = [input_shape[f].stride for f in self.in_features]
55 |
56 | self._output_size = (in_channels,)
57 |
58 | self.datasets_prototype_num = datasets_prototype_num
59 | self.datasets_feat = nn.Embedding(self.datasets_prototype_num, self.features_dim)
60 |
61 | self.GAP = nn.AdaptiveAvgPool2d(1)
62 |
63 | fc_dims = [in_channels//16,self.datasets_prototype_num]
64 |
65 | self.fcs = []
66 | for k, fc_dim in enumerate(fc_dims):
67 | fc = Linear(np.prod(self._output_size), fc_dim)
68 | self.fcs.append(fc)
69 | self.add_module("linear{}".format(k + 1), fc)
70 | if k < len(fc_dims) - 1:
71 | relu = nn.ReLU(inplace=True)
72 | self.fcs.append(relu)
73 | self.add_module("linear_relu{}".format(k + 1), relu)
74 | else:
75 | tanh = nn.Tanh()
76 | self.fcs.append(tanh)
77 | self.add_module("linear_tanh{}".format(k + 1), tanh)
78 | self._output_size = fc_dim
79 |
80 | for layer in self.fcs:
81 | if not isinstance(layer, Linear):
82 | continue
83 | nn.init.uniform_(layer.weight, -0.01, 0.01)
84 | torch.nn.init.constant_(layer.bias, 0)
85 |
86 | @classmethod
87 | def from_config(cls, cfg, input_shape):
88 | # fmt: off
89 | in_features = cfg.MODEL.ROI_HEADS.IN_FEATURES
90 | mrrp_on = cfg.MODEL.MRRP.MRRP_ON
91 | mrrp_num_branch = cfg.MODEL.MRRP.NUM_BRANCH
92 | datasets_prototype_num = cfg.MODEL.ROI_BOX_HEAD.OPEN_VOCABULARY.PROTOTYPE_NUM
93 | # fmt: on
94 | return {
95 | "cls_in_features": in_features,
96 | "input_shape": input_shape,
97 | "datasets_prototype_num": datasets_prototype_num,
98 | "features_dim": cfg.MODEL.ROI_BOX_HEAD.DAN_DIM[-1],
99 | "mrrp_on": mrrp_on,
100 | "mrrp_num_branch": mrrp_num_branch,
101 | }
102 |
103 | def forward(
104 | self,
105 | features: Dict[str, torch.Tensor],
106 | proposals,
107 | ):
108 |
109 | features = [features[f] for f in self.cls_in_features]
110 | if self.mrrp_on:
111 | features = [torch.stack(torch.chunk(f, self.mrrp_num_branch)).mean(0) for f in features]
112 | data_features = [self._forward(f) for f in features]
113 | if len(data_features)>1:
114 | data_features = torch.stack(data_features).mean(0)
115 | else:
116 | data_features = data_features[0]
117 | results = []
118 | for i in range(len(proposals)):
119 | results.append(data_features[i].repeat(len(proposals[i]),1))
120 | results = torch.cat(results)
121 | return results
122 |
123 | def _forward(self, x):
124 | x = self.GAP(x)
125 | x = x.flatten(start_dim=1)
126 | for k, layer in enumerate(self.fcs):
127 | x = layer(x)
128 | combined = torch.matmul(x, self.datasets_feat.weight)
129 | return combined
130 |
131 |
132 |
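
A hedged sketch of what this head computes, with illustrative shapes and stand-in names only: global average pooling of the image-level feature map, a small MLP that produces prototype mixing weights, a weighted sum of the learned dataset prototypes, and finally repetition of the resulting per-image vector for every proposal.

# Illustrative only; mirrors DataAwareFeaturesHead._forward / forward.
import torch
from torch import nn

B, C, H, W = 2, 512, 32, 32              # image-level feature map (assumed sizes)
P, D = 5, 512                            # datasets_prototype_num, features_dim
feat = torch.randn(B, C, H, W)
prototypes = torch.randn(P, D)           # stand-in for self.datasets_feat.weight

mlp = nn.Sequential(nn.Linear(C, C // 16), nn.ReLU(), nn.Linear(C // 16, P), nn.Tanh())
pooled = feat.mean(dim=(2, 3))           # global average pooling -> (B, C)
weights = mlp(pooled)                    # (B, P) prototype mixing weights
data_feat = weights @ prototypes         # (B, D): one data-aware vector per image

# repeated per proposal so it can be fused with the ROI box features
num_proposals = [100, 80]
per_proposal = torch.cat([data_feat[i].repeat(n, 1) for i, n in enumerate(num_proposals)])
print(per_proposal.shape)                # torch.Size([180, 512])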
--------------------------------------------------------------------------------
/wsovod/modeling/class_heads/open_vocabulary_classifier.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | import logging
3 | import numpy as np
4 | import torch
5 | from torch import nn
6 | from torch.nn import functional as F
7 | from math import fabs
8 | from detectron2.config import configurable
9 | from detectron2.layers import Linear, ShapeSpec
10 |
11 | logger = logging.getLogger(__name__)
12 |
13 |
14 | class OpenVocabularyClassifier(nn.Module):
15 | @configurable
16 | def __init__(
17 | self,
18 | input_shape: ShapeSpec,
19 | *,
20 | num_classes: int,
21 | weight_path: str,
22 | weight_dim: int = 512,
23 | use_bias: float = 0.0,
24 | norm_weight: bool = True,
25 | norm_temperature: float = 50.0,
26 | ):
27 | super().__init__()
28 | if isinstance(input_shape, int): # some backward compatibility
29 | input_shape = ShapeSpec(channels=input_shape)
30 | input_size = input_shape.channels * (input_shape.width or 1) * (input_shape.height or 1)
31 | self.norm_weight = norm_weight
32 | self.weight_dim = weight_dim
33 | self.norm_temperature = norm_temperature
34 |
35 | self.use_bias = fabs(use_bias)>1e-9
36 | if self.use_bias:
37 | self.cls_bias = nn.Parameter(torch.ones(1) * use_bias)
38 |
39 | self.projection = nn.Sequential(
40 | nn.Linear(input_size, 1024),
41 | nn.ReLU(),
42 | nn.Linear(1024, weight_dim),
43 | nn.ReLU(),
44 | )
45 |
46 |
47 | if weight_path == "rand":
48 | class_weight = torch.randn((weight_dim, num_classes))
49 | nn.init.normal_(class_weight, std=0.01)
50 | else:
51 | logger.info("Loading " + weight_path)
52 | class_weight = (
53 | torch.tensor(np.load(weight_path,encoding='bytes', allow_pickle=True), dtype=torch.float32)
54 | .permute(1, 0)
55 | .contiguous()
56 | ) # D x C
57 | logger.info(f"Loaded class weight {class_weight.size()}")
58 |
59 | if self.norm_weight:
60 | class_weight = F.normalize(class_weight, p=2, dim=0)
61 |
62 | if weight_path == "rand":
63 | self.class_weight = nn.Parameter(class_weight)
64 | else:
65 | self.register_buffer("class_weight", class_weight)
66 |
67 | @classmethod
68 | def from_config(cls, cfg, input_shape, weight_path = None, use_bias = None, norm_weight = None, norm_temperature = None):
69 | return {
70 | "input_shape": input_shape,
71 | "num_classes": cfg.MODEL.ROI_HEADS.NUM_CLASSES,
72 | "weight_path": weight_path if weight_path is not None else cfg.MODEL.ROI_BOX_HEAD.OPEN_VOCABULARY.WEIGHT_PATH_TRAIN,
73 | "weight_dim": cfg.MODEL.ROI_BOX_HEAD.OPEN_VOCABULARY.WEIGHT_DIM,
74 | "use_bias": use_bias if use_bias is not None else cfg.MODEL.ROI_BOX_HEAD.OPEN_VOCABULARY.USE_BIAS,
75 | "norm_weight": norm_weight if norm_weight is not None else cfg.MODEL.ROI_BOX_HEAD.OPEN_VOCABULARY.NORM_WEIGHT,
76 | "norm_temperature": norm_temperature if norm_temperature is not None else cfg.MODEL.ROI_BOX_HEAD.OPEN_VOCABULARY.NORM_TEMP,
77 | }
78 |
79 | def forward(self, x, classifier=None, append_background = False):
80 | """
81 | Inputs:
82 |             x: B x input_size box features (projected to B x D internally).
83 |             classifier: optional C' x D tensor of class embeddings; if None, the preloaded class_weight is used.
84 | """
85 | x = self.projection(x)
86 |
87 | if classifier is not None:
88 | class_weight = classifier.permute(1, 0).contiguous() # D x C'
89 | class_weight = F.normalize(class_weight, p=2, dim=0) if self.norm_weight else class_weight
90 | else:
91 | class_weight = self.class_weight
92 |
93 | if self.norm_weight:
94 | x = self.norm_temperature * F.normalize(x, p=2, dim=1)
95 |
96 | if append_background:
97 | class_weight = torch.cat(
98 | [class_weight, class_weight.new_zeros((self.weight_dim, 1))], dim=1
99 | ) # D x (C + 1)
100 | # logger.info(f"Cated class weight {class_weight.size()}")
101 |
102 | x = torch.mm(x, class_weight)
103 | if self.use_bias:
104 | x = x + self.cls_bias
105 | return x
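
A hedged sketch (illustrative names and shapes) of the cosine-similarity scoring performed in forward(): project box features into the text-embedding space, L2-normalize, scale by the temperature, optionally append an all-zero background column, and take dot products with the class embeddings.

# Illustrative sketch of the normalized dot-product scoring.
import torch
import torch.nn.functional as F

B, D, C = 4, 512, 20
box_feats = torch.randn(B, D)                                # output of self.projection(x)
class_weight = F.normalize(torch.randn(D, C), p=2, dim=0)    # D x C text embeddings
temperature = 50.0

x = temperature * F.normalize(box_feats, p=2, dim=1)
class_weight = torch.cat([class_weight, class_weight.new_zeros(D, 1)], dim=1)  # + background
scores = x @ class_weight                                    # B x (C + 1) logits
print(scores.shape)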
--------------------------------------------------------------------------------
/wsovod/modeling/meta_arch/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | from .rcnn_wsovod import GeneralizedRCNN_WSOVOD
4 | from .rcnn_wsovod_mixed_datasets import GeneralizedRCNN_WSOVOD_MixedDatasets
--------------------------------------------------------------------------------
/wsovod/modeling/postprocessing.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | import torch
3 | from detectron2.structures import Instances, ROIMasks
4 | from torch.nn import functional as F
5 |
6 |
7 | # perhaps should rename to "resize_instance"
8 | def detector_postprocess(
9 | results: Instances, output_height: int, output_width: int, mask_threshold: float = 0.5
10 | ):
11 | """
12 | Resize the output instances.
13 | The input images are often resized when entering an object detector.
14 | As a result, we often need the outputs of the detector in a different
15 | resolution from its inputs.
16 |
17 | This function will resize the raw outputs of an R-CNN detector
18 | to produce outputs according to the desired output resolution.
19 |
20 | Args:
21 | results (Instances): the raw outputs from the detector.
22 | `results.image_size` contains the input image resolution the detector sees.
23 | This object might be modified in-place.
24 | output_height, output_width: the desired output resolution.
25 |
26 | Returns:
27 | Instances: the resized output from the model, based on the output resolution
28 | """
29 | # Change to 'if is_tracing' after PT1.7
30 | if isinstance(output_height, torch.Tensor):
31 | # Converts integer tensors to float temporaries to ensure true
32 | # division is performed when computing scale_x and scale_y.
33 | output_width_tmp = output_width.float()
34 | output_height_tmp = output_height.float()
35 | new_size = torch.stack([output_height, output_width])
36 | else:
37 | new_size = (output_height, output_width)
38 | output_width_tmp = output_width
39 | output_height_tmp = output_height
40 |
41 | scale_x, scale_y = (
42 | output_width_tmp / results.image_size[1],
43 | output_height_tmp / results.image_size[0],
44 | )
45 | results = Instances(new_size, **results.get_fields())
46 |
47 | if results.has("pred_boxes"):
48 | output_boxes = results.pred_boxes
49 | elif results.has("proposal_boxes"):
50 | output_boxes = results.proposal_boxes
51 | else:
52 | output_boxes = None
53 | assert output_boxes is not None, "Predictions must contain boxes!"
54 |
55 | output_boxes.scale(scale_x, scale_y)
56 | output_boxes.clip(results.image_size)
57 |
58 | results = results[output_boxes.nonempty()]
59 |
60 | if results.has("pred_masks") and results.has("no_paste"):
61 | results.pred_masks = F.interpolate(
62 | results.pred_masks,
63 | size=(output_height, output_width),
64 | mode="bilinear",
65 | align_corners=False,
66 | )
67 | results.pred_masks = (results.pred_masks[:, 0, :, :] >= mask_threshold).to(dtype=torch.bool)
68 | elif results.has("pred_masks"):
69 | if isinstance(results.pred_masks, ROIMasks):
70 | roi_masks = results.pred_masks
71 | else:
72 | # pred_masks is a tensor of shape (N, 1, M, M)
73 | roi_masks = ROIMasks(results.pred_masks[:, 0, :, :])
74 | results.pred_masks = roi_masks.to_bitmasks(
75 | results.pred_boxes, output_height, output_width, mask_threshold
76 | ).tensor # TODO return ROIMasks/BitMask object in the future
77 |
78 | if results.has("pred_keypoints"):
79 | results.pred_keypoints[:, :, 0] *= scale_x
80 | results.pred_keypoints[:, :, 1] *= scale_y
81 |
82 | return results
83 |
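
A small usage sketch (values assumed for illustration) of detector_postprocess: raw Instances produced at the resized input resolution are rescaled back to the original image size.

# Hedged example: rescale predictions from the resized input back to the original image.
import torch
from detectron2.structures import Boxes, Instances
from wsovod.modeling.postprocessing import detector_postprocess

raw = Instances((800, 1066))                       # (H, W) the detector actually saw
raw.pred_boxes = Boxes(torch.tensor([[10., 20., 200., 300.]]))
raw.scores = torch.tensor([0.9])
raw.pred_classes = torch.tensor([3])

resized = detector_postprocess(raw, output_height=480, output_width=640)
print(resized.pred_boxes.tensor)                   # boxes scaled by 640/1066 and 480/800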
--------------------------------------------------------------------------------
/wsovod/modeling/proposal_generator/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | from .rpn import WSOVODRPN_V2, WSOVODRPN
--------------------------------------------------------------------------------
/wsovod/modeling/roi_heads/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | from .box_head import DiscriminativeAdaptationNeck
--------------------------------------------------------------------------------
/wsovod/modeling/roi_heads/box_head.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | from typing import List
3 |
4 | import fvcore.nn.weight_init as weight_init
5 | import numpy as np
6 | import torch
7 | from detectron2.config import configurable
8 | from detectron2.layers import Conv2d, ShapeSpec, get_norm
9 | from detectron2.modeling.roi_heads import ROI_BOX_HEAD_REGISTRY
10 | from torch import nn
11 |
12 | __all__ = ["DiscriminativeAdaptationNeck"]
13 |
14 |
15 | # To get torchscript support, we make the head a subclass of `nn.Sequential`.
16 | # Therefore, to add new layers in this head class, please make sure they are
17 | # added in the order they will be used in forward().
18 | @ROI_BOX_HEAD_REGISTRY.register()
19 | class DiscriminativeAdaptationNeck(nn.Sequential):
20 | """
21 | A head with several 3x3 conv layers (each followed by norm & relu) and then
22 | several fc layers (each followed by relu).
23 | """
24 |
25 | @configurable
26 | def __init__(
27 | self, input_shape: ShapeSpec, *, conv_dims: List[int], fc_dims: List[int], conv_norm=""
28 | ):
29 | """
30 | NOTE: this interface is experimental.
31 |
32 | Args:
33 | input_shape (ShapeSpec): shape of the input feature.
34 | conv_dims (list[int]): the output dimensions of the conv layers
35 | fc_dims (list[int]): the output dimensions of the fc layers
36 | conv_norm (str or callable): normalization for the conv layers.
37 | See :func:`detectron2.layers.get_norm` for supported types.
38 | """
39 | super().__init__()
40 | assert len(conv_dims) + len(fc_dims) > 0
41 |
42 | self._output_size = (input_shape.channels, input_shape.height, input_shape.width)
43 |
44 | self.conv_norm_relus = []
45 | for k, conv_dim in enumerate(conv_dims):
46 | conv = Conv2d(
47 | self._output_size[0],
48 | conv_dim,
49 | kernel_size=3,
50 | padding=1,
51 | bias=not conv_norm,
52 | norm=get_norm(conv_norm, conv_dim),
53 | activation=nn.ReLU(inplace=True),
54 | )
55 | self.add_module("conv{}".format(k + 1), conv)
56 | self.conv_norm_relus.append(conv)
57 | self._output_size = (conv_dim, self._output_size[1], self._output_size[2])
58 |
59 | self.fcs = []
60 | for k, fc_dim in enumerate(fc_dims):
61 | if k == 0:
62 | self.add_module("flatten", nn.Flatten())
63 | fc = nn.Linear(int(np.prod(self._output_size)), fc_dim)
64 | self.add_module("fc{}".format(k + 1), fc)
65 | self.add_module("fc_relu{}".format(k + 1), nn.ReLU(inplace=True))
66 | self.add_module("fc_dropout{}".format(k + 1), nn.Dropout(p=0.5, inplace=False))
67 | self.fcs.append(fc)
68 | self._output_size = fc_dim
69 |
70 | for layer in self.conv_norm_relus:
71 | weight_init.c2_msra_fill(layer)
72 | for layer in self.fcs:
73 | # weight_init.c2_xavier_fill(layer)
74 | torch.nn.init.normal_(layer.weight, std=0.005)
75 | torch.nn.init.constant_(layer.bias, 0.1)
76 |
77 | @classmethod
78 | def from_config(cls, cfg, input_shape):
79 | num_conv = cfg.MODEL.ROI_BOX_HEAD.NUM_CONV
80 | conv_dim = cfg.MODEL.ROI_BOX_HEAD.CONV_DIM
81 | # num_fc = cfg.MODEL.ROI_BOX_HEAD.NUM_FC
82 | fc_dims = cfg.MODEL.ROI_BOX_HEAD.DAN_DIM
83 | return {
84 | "input_shape": input_shape,
85 | "conv_dims": [conv_dim] * num_conv,
86 | "fc_dims": fc_dims,
87 | "conv_norm": cfg.MODEL.ROI_BOX_HEAD.NORM,
88 | }
89 |
90 | def forward(self, x):
91 | for layer in self:
92 | x = layer(x)
93 | return x
94 |
95 | @property
96 | @torch.jit.unused
97 | def output_shape(self):
98 | """
99 | Returns:
100 | ShapeSpec: the output feature shape
101 | """
102 | o = self._output_size
103 | if isinstance(o, int):
104 | return ShapeSpec(channels=o)
105 | else:
106 | return ShapeSpec(channels=o[0], height=o[1], width=o[2])
107 |
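
A hedged sketch (shapes and dimensions are illustrative) of running this neck on pooled ROI features: with no conv layers it flattens the 7x7 features and passes them through the fc/ReLU/dropout stack.

# Illustrative only: run the DiscriminativeAdaptationNeck on fake pooled ROI features.
import torch
from detectron2.layers import ShapeSpec
from wsovod.modeling.roi_heads.box_head import DiscriminativeAdaptationNeck

neck = DiscriminativeAdaptationNeck(
    input_shape=ShapeSpec(channels=512, height=7, width=7),
    conv_dims=[],                              # fc-only head
    fc_dims=[4096, 4096],
)
roi_feats = torch.randn(128, 512, 7, 7)        # 128 proposals after ROI pooling
out = neck(roi_feats)
print(out.shape, neck.output_shape)            # torch.Size([128, 4096]), ShapeSpec(channels=4096)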
--------------------------------------------------------------------------------
/wsovod/solver/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | from .build import build_lr_scheduler, build_optimizer, get_default_optimizer_params
3 | from .lr_scheduler import WarmupCosineLR, WarmupMultiStepLR, LRMultiplier, WarmupParamScheduler
4 |
5 | __all__ = [k for k in globals().keys() if not k.startswith("_")]
--------------------------------------------------------------------------------
/wsovod/solver/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | import logging
3 | import math
4 | from bisect import bisect_right
5 | from typing import List
6 | import torch
7 | from fvcore.common.param_scheduler import (
8 | CompositeParamScheduler,
9 | ConstantParamScheduler,
10 | LinearParamScheduler,
11 | ParamScheduler,
12 | )
13 |
14 | logger = logging.getLogger(__name__)
15 |
16 |
17 | class WarmupParamScheduler(CompositeParamScheduler):
18 | """
19 | Add an initial warmup stage to another scheduler.
20 | """
21 |
22 | def __init__(
23 | self,
24 | scheduler: ParamScheduler,
25 | warmup_factor: float,
26 | warmup_length: float,
27 | warmup_method: str = "linear",
28 | ):
29 | """
30 | Args:
31 | scheduler: warmup will be added at the beginning of this scheduler
32 | warmup_factor: the factor w.r.t the initial value of ``scheduler``, e.g. 0.001
33 | warmup_length: the relative length (in [0, 1]) of warmup steps w.r.t the entire
34 | training, e.g. 0.01
35 | warmup_method: one of "linear" or "constant"
36 | """
37 | end_value = scheduler(warmup_length) # the value to reach when warmup ends
38 | start_value = warmup_factor * scheduler(0.0)
39 | if warmup_method == "constant":
40 | warmup = ConstantParamScheduler(start_value)
41 | elif warmup_method == "linear":
42 | warmup = LinearParamScheduler(start_value, end_value)
43 | else:
44 | raise ValueError("Unknown warmup method: {}".format(warmup_method))
45 | super().__init__(
46 | [warmup, scheduler],
47 | interval_scaling=["rescaled", "fixed"],
48 | lengths=[warmup_length, 1 - warmup_length],
49 | )
50 |
51 |
52 | class LRMultiplier(torch.optim.lr_scheduler._LRScheduler):
53 | """
54 | A LRScheduler which uses fvcore :class:`ParamScheduler` to multiply the
55 | learning rate of each param in the optimizer.
56 | Every step, the learning rate of each parameter becomes its initial value
57 | multiplied by the output of the given :class:`ParamScheduler`.
58 |
59 | The absolute learning rate value of each parameter can be different.
60 |     This scheduler can be used as long as the relative scale among them does
61 |     not change during training.
62 |
63 | Examples:
64 | ::
65 | LRMultiplier(
66 | opt,
67 | WarmupParamScheduler(
68 | MultiStepParamScheduler(
69 | [1, 0.1, 0.01],
70 | milestones=[60000, 80000],
71 | num_updates=90000,
72 | ), 0.001, 100 / 90000
73 | ),
74 | max_iter=90000
75 | )
76 | """
77 |
78 | # NOTES: in the most general case, every LR can use its own scheduler.
79 | # Supporting this requires interaction with the optimizer when its parameter
80 | # group is initialized. For example, classyvision implements its own optimizer
81 | # that allows different schedulers for every parameter group.
82 | # To avoid this complexity, we use this class to support the most common cases
83 | # where the relative scale among all LRs stay unchanged during training. In this
84 | # case we only need a total of one scheduler that defines the relative LR multiplier.
85 |
86 | def __init__(
87 | self,
88 | optimizer: torch.optim.Optimizer,
89 | multiplier: ParamScheduler,
90 | max_iter: int,
91 | last_iter: int = -1,
92 | ):
93 | """
94 | Args:
95 | optimizer, last_iter: See ``torch.optim.lr_scheduler._LRScheduler``.
96 | ``last_iter`` is the same as ``last_epoch``.
97 | multiplier: a fvcore ParamScheduler that defines the multiplier on
98 | every LR of the optimizer
99 | max_iter: the total number of training iterations
100 | """
101 | if not isinstance(multiplier, ParamScheduler):
102 | raise ValueError(
103 | "_LRMultiplier(multiplier=) must be an instance of fvcore "
104 | f"ParamScheduler. Got {multiplier} instead."
105 | )
106 | self._multiplier = multiplier
107 | self._max_iter = max_iter
108 | super().__init__(optimizer, last_epoch=last_iter)
109 |
110 | def state_dict(self):
111 | # fvcore schedulers are stateless. Only keep pytorch scheduler states
112 | return {"base_lrs": self.base_lrs, "last_epoch": self.last_epoch}
113 |
114 | def get_lr(self) -> List[float]:
115 | multiplier = self._multiplier(self.last_epoch / self._max_iter)
116 | return [base_lr * multiplier for base_lr in self.base_lrs]
117 |
118 |
119 | """
120 | Content below is no longer needed!
121 | """
122 |
123 | # NOTE: PyTorch's LR scheduler interface uses names that assume the LR changes
124 | # only on epoch boundaries. We typically use iteration based schedules instead.
125 | # As a result, "epoch" (e.g., as in self.last_epoch) should be understood to mean
126 | # "iteration" instead.
127 |
128 | # FIXME: ideally this would be achieved with a CombinedLRScheduler, separating
129 | # MultiStepLR with WarmupLR but the current LRScheduler design doesn't allow it.
130 |
131 |
132 | class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler):
133 | def __init__(
134 | self,
135 | optimizer: torch.optim.Optimizer,
136 | milestones: List[int],
137 | gamma: float = 0.1,
138 | warmup_factor: float = 0.001,
139 | warmup_iters: int = 1000,
140 | warmup_method: str = "linear",
141 | last_epoch: int = -1,
142 | ):
143 | logger.warning(
144 |             "WarmupMultiStepLR is deprecated! Use LRMultiplier with fvcore ParamScheduler instead!"
145 | )
146 | if not list(milestones) == sorted(milestones):
147 | raise ValueError(
148 |                 "Milestones should be a list of increasing integers. Got {}".format(milestones)
149 | )
150 | self.milestones = milestones
151 | self.gamma = gamma
152 | self.warmup_factor = warmup_factor
153 | self.warmup_iters = warmup_iters
154 | self.warmup_method = warmup_method
155 | super().__init__(optimizer, last_epoch)
156 |
157 | def get_lr(self) -> List[float]:
158 | warmup_factor = _get_warmup_factor_at_iter(
159 | self.warmup_method, self.last_epoch, self.warmup_iters, self.warmup_factor
160 | )
161 | return [
162 | base_lr * warmup_factor * self.gamma ** bisect_right(self.milestones, self.last_epoch)
163 | for base_lr in self.base_lrs
164 | ]
165 |
166 | def _compute_values(self) -> List[float]:
167 | # The new interface
168 | return self.get_lr()
169 |
170 |
171 | class WarmupCosineLR(torch.optim.lr_scheduler._LRScheduler):
172 | def __init__(
173 | self,
174 | optimizer: torch.optim.Optimizer,
175 | max_iters: int,
176 | warmup_factor: float = 0.001,
177 | warmup_iters: int = 1000,
178 | warmup_method: str = "linear",
179 | last_epoch: int = -1,
180 | ):
181 | logger.warning(
182 |             "WarmupCosineLR is deprecated! Use LRMultiplier with fvcore ParamScheduler instead!"
183 | )
184 | self.max_iters = max_iters
185 | self.warmup_factor = warmup_factor
186 | self.warmup_iters = warmup_iters
187 | self.warmup_method = warmup_method
188 | super().__init__(optimizer, last_epoch)
189 |
190 | def get_lr(self) -> List[float]:
191 | warmup_factor = _get_warmup_factor_at_iter(
192 | self.warmup_method, self.last_epoch, self.warmup_iters, self.warmup_factor
193 | )
194 | # Different definitions of half-cosine with warmup are possible. For
195 | # simplicity we multiply the standard half-cosine schedule by the warmup
196 | # factor. An alternative is to start the period of the cosine at warmup_iters
197 | # instead of at 0. In the case that warmup_iters << max_iters the two are
198 | # very close to each other.
199 | return [
200 | base_lr
201 | * warmup_factor
202 | * 0.5
203 | * (1.0 + math.cos(math.pi * self.last_epoch / self.max_iters))
204 | for base_lr in self.base_lrs
205 | ]
206 |
207 | def _compute_values(self) -> List[float]:
208 | # The new interface
209 | return self.get_lr()
210 |
211 |
212 | def _get_warmup_factor_at_iter(
213 | method: str, iter: int, warmup_iters: int, warmup_factor: float
214 | ) -> float:
215 | """
216 | Return the learning rate warmup factor at a specific iteration.
217 | See :paper:`ImageNet in 1h` for more details.
218 |
219 | Args:
220 | method (str): warmup method; either "constant" or "linear".
221 | iter (int): iteration at which to calculate the warmup factor.
222 | warmup_iters (int): the number of warmup iterations.
223 | warmup_factor (float): the base warmup factor (the meaning changes according
224 | to the method used).
225 |
226 | Returns:
227 | float: the effective warmup factor at the given iteration.
228 | """
229 | if iter >= warmup_iters:
230 | return 1.0
231 |
232 | if method == "constant":
233 | return warmup_factor
234 | elif method == "linear":
235 | alpha = iter / warmup_iters
236 | return warmup_factor * (1 - alpha) + alpha
237 | else:
238 | raise ValueError("Unknown warmup method: {}".format(method))
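
A runnable version of the LRMultiplier docstring example above, as a hedged sketch: the optimizer, base LR, and iteration counts are illustrative. Warmup ramps the multiplier from warmup_factor to 1, then the multi-step schedule decays it at the milestones.

# Hedged sketch of wiring WarmupParamScheduler + LRMultiplier; numbers are illustrative.
import torch
from fvcore.common.param_scheduler import MultiStepParamScheduler
from wsovod.solver.lr_scheduler import LRMultiplier, WarmupParamScheduler

model = torch.nn.Linear(10, 10)
opt = torch.optim.SGD(model.parameters(), lr=0.02)

max_iter = 90000
sched = LRMultiplier(
    opt,
    WarmupParamScheduler(
        MultiStepParamScheduler([1, 0.1, 0.01], milestones=[60000, 80000], num_updates=max_iter),
        warmup_factor=0.001,
        warmup_length=100 / max_iter,
    ),
    max_iter=max_iter,
)

for _ in range(200):        # after 100 warmup iters the LR has climbed back to 0.02
    opt.step()
    sched.step()
print(opt.param_groups[0]["lr"])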
--------------------------------------------------------------------------------
/wsovod/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HunterJ-Lin/WSOVOD/74e7647faaa0ab5de37a351532f65c57654f66be/wsovod/utils/__init__.py
--------------------------------------------------------------------------------
/wsovod/utils/sam_predictor_with_buffer.py:
--------------------------------------------------------------------------------
1 | from segment_anything.modeling import Sam
2 | from segment_anything.utils.transforms import ResizeLongestSide
3 | import numpy as np
4 | import torch
5 | from typing import Optional, Tuple
6 |
7 | class SamPredictorBuffer:
8 | def __init__(
9 | self,
10 | sam_model: Sam,
11 | ):
12 | """
13 | Uses SAM to calculate the image embedding for an image, and then
14 | allow repeated, efficient mask prediction given prompts.
15 |
16 | Arguments:
17 | sam_model (Sam): The model to use for mask prediction.
18 | """
19 | super().__init__()
20 | self.model = sam_model
21 | self.transform = ResizeLongestSide(sam_model.image_encoder.img_size)
22 | self.buffer = {}
23 |
24 | def set_image(
25 | self,
26 | image: np.ndarray,
27 | image_format: str = "RGB",
28 | file_name = None,
29 | ):
30 | """
31 | Calculates the image embeddings for the provided image, allowing
32 | masks to be predicted with the 'predict' method.
33 |
34 | Arguments:
35 | image (np.ndarray): The image for calculating masks. Expects an
36 | image in HWC uint8 format, with pixel values in [0, 255].
37 | image_format (str): The color format of the image, in ['RGB', 'BGR'].
38 | """
39 | if file_name in self.buffer.keys():
40 | return
41 | assert image_format in [
42 | "RGB",
43 | "BGR",
44 | ], f"image_format must be in ['RGB', 'BGR'], is {image_format}."
45 | if image_format != self.model.image_format:
46 | image = image[..., ::-1]
47 |
48 | # Transform the image to the form expected by the model
49 | input_image = self.transform.apply_image(image)
50 | input_image_torch = torch.as_tensor(input_image, device=self.device)
51 | input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]
52 |
53 | self.set_torch_image(input_image_torch, image.shape[:2], file_name)
54 |
55 | @torch.no_grad()
56 | def set_torch_image(
57 | self,
58 | transformed_image: torch.Tensor,
59 | original_image_size: Tuple[int, ...],
60 | file_name = None,
61 | ):
62 | """
63 | Calculates the image embeddings for the provided image, allowing
64 | masks to be predicted with the 'predict' method. Expects the input
65 | image to be already transformed to the format expected by the model.
66 |
67 | Arguments:
68 | transformed_image (torch.Tensor): The input image, with shape
69 | 1x3xHxW, which has been transformed with ResizeLongestSide.
70 | original_image_size (tuple(int, int)): The size of the image
71 | before transformation, in (H, W) format.
72 | """
73 | if file_name in self.buffer.keys():
74 | return
75 | assert (
76 | len(transformed_image.shape) == 4
77 | and transformed_image.shape[1] == 3
78 | and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size
79 | ), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}."
80 |
81 | input_size = tuple(transformed_image.shape[-2:])
82 | self.buffer[file_name] = {
83 | 'original_size' : original_image_size,
84 | 'input_size' : input_size,
85 | 'features' : self.model.image_encoder(self.model.preprocess(transformed_image)).detach()
86 | }
87 |
88 | def predict(
89 | self,
90 | point_coords: Optional[np.ndarray] = None,
91 | point_labels: Optional[np.ndarray] = None,
92 | box: Optional[np.ndarray] = None,
93 | mask_input: Optional[np.ndarray] = None,
94 | multimask_output: bool = True,
95 | return_logits: bool = False,
96 | file_name = None,
97 | ):
98 | """
99 | Predict masks for the given input prompts, using the currently set image.
100 |
101 | Arguments:
102 | point_coords (np.ndarray or None): A Nx2 array of point prompts to the
103 | model. Each point is in (X,Y) in pixels.
104 | point_labels (np.ndarray or None): A length N array of labels for the
105 | point prompts. 1 indicates a foreground point and 0 indicates a
106 | background point.
107 | box (np.ndarray or None): A length 4 array given a box prompt to the
108 | model, in XYXY format.
109 | mask_input (np.ndarray): A low resolution mask input to the model, typically
110 | coming from a previous prediction iteration. Has form 1xHxW, where
111 | for SAM, H=W=256.
112 | multimask_output (bool): If true, the model will return three masks.
113 | For ambiguous input prompts (such as a single click), this will often
114 | produce better masks than a single prediction. If only a single
115 | mask is needed, the model's predicted quality score can be used
116 | to select the best mask. For non-ambiguous prompts, such as multiple
117 | input prompts, multimask_output=False can give better results.
118 | return_logits (bool): If true, returns un-thresholded masks logits
119 | instead of a binary mask.
120 |
121 | Returns:
122 | (np.ndarray): The output masks in CxHxW format, where C is the
123 | number of masks, and (H, W) is the original image size.
124 | (np.ndarray): An array of length C containing the model's
125 | predictions for the quality of each mask.
126 | (np.ndarray): An array of shape CxHxW, where C is the number
127 | of masks and H=W=256. These low resolution logits can be passed to
128 | a subsequent iteration as mask input.
129 | """
130 |         assert file_name in self.buffer, "An image must be set with .set_image(...) before mask prediction."
131 | self.original_size = self.buffer[file_name]['original_size']
132 | self.input_size = self.buffer[file_name]['input_size']
133 | self.features = self.buffer[file_name]['features']
134 |
135 | # Transform input prompts
136 | coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None
137 | if point_coords is not None:
138 | assert (
139 | point_labels is not None
140 | ), "point_labels must be supplied if point_coords is supplied."
141 | point_coords = self.transform.apply_coords(point_coords, self.original_size)
142 | coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device)
143 | labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device)
144 | coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :]
145 | if box is not None:
146 | box = self.transform.apply_boxes(box, self.original_size)
147 | box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device)
148 | box_torch = box_torch[None, :]
149 | if mask_input is not None:
150 | mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device)
151 | mask_input_torch = mask_input_torch[None, :, :, :]
152 |
153 | masks, iou_predictions, low_res_masks = self.predict_torch(
154 | coords_torch,
155 | labels_torch,
156 | box_torch,
157 | mask_input_torch,
158 | multimask_output,
159 | return_logits=return_logits,
160 | file_name=file_name
161 | )
162 |
163 | masks_np = masks[0].detach().cpu().numpy()
164 | iou_predictions_np = iou_predictions[0].detach().cpu().numpy()
165 | low_res_masks_np = low_res_masks[0].detach().cpu().numpy()
166 | return masks_np, iou_predictions_np, low_res_masks_np
167 |
168 | @torch.no_grad()
169 | def predict_torch(
170 | self,
171 | point_coords: Optional[torch.Tensor],
172 | point_labels: Optional[torch.Tensor],
173 | boxes: Optional[torch.Tensor] = None,
174 | mask_input: Optional[torch.Tensor] = None,
175 | multimask_output: bool = True,
176 | return_logits: bool = False,
177 | file_name = None,
178 | ):
179 | """
180 | Predict masks for the given input prompts, using the currently set image.
181 | Input prompts are batched torch tensors and are expected to already be
182 | transformed to the input frame using ResizeLongestSide.
183 |
184 | Arguments:
185 | point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the
186 | model. Each point is in (X,Y) in pixels.
187 | point_labels (torch.Tensor or None): A BxN array of labels for the
188 | point prompts. 1 indicates a foreground point and 0 indicates a
189 | background point.
190 | boxes (np.ndarray or None): A Bx4 array given a box prompt to the
191 | model, in XYXY format.
192 | mask_input (np.ndarray): A low resolution mask input to the model, typically
193 | coming from a previous prediction iteration. Has form Bx1xHxW, where
194 | for SAM, H=W=256. Masks returned by a previous iteration of the
195 | predict method do not need further transformation.
196 | multimask_output (bool): If true, the model will return three masks.
197 | For ambiguous input prompts (such as a single click), this will often
198 | produce better masks than a single prediction. If only a single
199 | mask is needed, the model's predicted quality score can be used
200 | to select the best mask. For non-ambiguous prompts, such as multiple
201 | input prompts, multimask_output=False can give better results.
202 | return_logits (bool): If true, returns un-thresholded masks logits
203 | instead of a binary mask.
204 |
205 | Returns:
206 | (torch.Tensor): The output masks in BxCxHxW format, where C is the
207 | number of masks, and (H, W) is the original image size.
208 | (torch.Tensor): An array of shape BxC containing the model's
209 | predictions for the quality of each mask.
210 | (torch.Tensor): An array of shape BxCxHxW, where C is the number
211 | of masks and H=W=256. These low res logits can be passed to
212 | a subsequent iteration as mask input.
213 | """
214 |         assert file_name in self.buffer, "An image must be set with .set_image(...) before mask prediction."
215 | self.original_size = self.buffer[file_name]['original_size']
216 | self.input_size = self.buffer[file_name]['input_size']
217 | self.features = self.buffer[file_name]['features']
218 |
219 | if point_coords is not None:
220 | points = (point_coords, point_labels)
221 | else:
222 | points = None
223 |
224 | # Embed prompts
225 | sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
226 | points=points,
227 | boxes=boxes,
228 | masks=mask_input,
229 | )
230 |
231 | # Predict masks
232 | low_res_masks, iou_predictions = self.model.mask_decoder(
233 | image_embeddings=self.features,
234 | image_pe=self.model.prompt_encoder.get_dense_pe(),
235 | sparse_prompt_embeddings=sparse_embeddings,
236 | dense_prompt_embeddings=dense_embeddings,
237 | multimask_output=multimask_output,
238 | )
239 |
240 | # Upscale the masks to the original image resolution
241 | masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size)
242 |
243 | if not return_logits:
244 | masks = masks > self.model.mask_threshold
245 |
246 | return masks, iou_predictions, low_res_masks
247 |
248 | def get_image_embedding(self, file_name):
249 | """
250 |         Returns the buffered image embedding for the given file name, with
251 | shape 1xCxHxW, where C is the embedding dimension and (H,W) are
252 | the embedding spatial dimension of SAM (typically C=256, H=W=64).
253 | """
254 | if file_name in self.buffer:
255 |             return self.buffer[file_name]['features']
256 |
257 | return None
258 |
259 | @property
260 | def device(self):
261 | return self.model.device
262 |
263 | def reset_buffer(self):
264 | del self.buffer
265 | self.buffer = {}
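
A hedged usage sketch (checkpoint path, image file, and prompt box are placeholders): the buffer caches one image embedding per file name, so repeated prompts on the same image skip the expensive encoder pass.

# Illustrative usage of SamPredictorBuffer; paths and prompts are placeholders.
import cv2
import numpy as np
from segment_anything import sam_model_registry
from wsovod.utils.sam_predictor_with_buffer import SamPredictorBuffer

sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth").cuda()  # assumes a GPU
predictor = SamPredictorBuffer(sam)

image = cv2.cvtColor(cv2.imread("example.jpg"), cv2.COLOR_BGR2RGB)
predictor.set_image(image, image_format="RGB", file_name="example.jpg")      # encodes once

box = np.array([100, 100, 400, 300])                                         # XYXY prompt
masks, scores, _ = predictor.predict(box=box, multimask_output=True, file_name="example.jpg")
print(masks.shape, scores)                                                   # (3, H, W) masks + quality scores

predictor.reset_buffer()                                                      # drop cached embeddings when done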
--------------------------------------------------------------------------------