├── .gitignore ├── LICENSE ├── README.md ├── configs ├── COCO-Detection │ ├── Base-RCNN-DilatedC5.yaml │ ├── WSOVOD_MRRP_V_16_DC5_1x.yaml │ ├── WSOVOD_MRRP_WSR_18_DC5_1x.yaml │ ├── WSOVOD_MRRP_WSR_50_DC5_1x.yaml │ ├── WSOVOD_V_16_DC5_1x.yaml │ ├── WSOVOD_WSR_18_DC5_1x.yaml │ └── WSOVOD_WSR_50_DC5_1x.yaml ├── ImageNet-Detection │ ├── Base-RCNN-DilatedC5.yaml │ └── WSOVOD_WSR_18_DC5_1x.yaml ├── MixedDatasets-Detection │ ├── Base-RCNN-DilatedC5.yaml │ ├── WSOVOD_MRRP_WSR_18_DC5_1x_voc07+coco.yaml │ ├── WSOVOD_MRRP_WSR_50_DC5_1x_voc07+coco.yaml │ ├── WSOVOD_WSR_18_DC5_1x_voc07+coco.yaml │ └── WSOVOD_WSR_50_DC5_1x_voc07+coco.yaml └── PascalVOC-Detection │ ├── Base-RCNN-DilatedC5.yaml │ ├── WSOVOD_MRRP_V_16_DC5_1x.yaml │ ├── WSOVOD_MRRP_V_16_DC5_VOC12_1x.yaml │ ├── WSOVOD_MRRP_WSR_18_DC5_1x.yaml │ ├── WSOVOD_MRRP_WSR_18_DC5_VOC12_1x.yaml │ ├── WSOVOD_MRRP_WSR_50_DC5_1x.yaml │ ├── WSOVOD_MRRP_WSR_50_DC5_VOC12_1x.yaml │ ├── WSOVOD_V_16_DC5_1x.yaml │ ├── WSOVOD_V_16_DC5_VOC12_1x.yaml │ ├── WSOVOD_WSR_18_DC5_1x.yaml │ ├── WSOVOD_WSR_18_DC5_VOC12_1x.yaml │ ├── WSOVOD_WSR_50_DC5_1x.yaml │ └── WSOVOD_WSR_50_DC5_VOC12_1x.yaml ├── datasets └── README.md ├── requirements.txt ├── scripts ├── extract_ilsvrc.sh ├── generate_sam_proposals_cuda.sh ├── prepare_ilsvrc.sh └── train_script.sh ├── setup.py ├── teaser └── framework.png ├── tools ├── convert_ilsvrc_classes_name.py ├── generate_class_text_embedding.py ├── generate_class_text_embedding_cuda.py ├── generate_sam_proposals_cuda.py ├── ilsvrc2012_classes_name.txt ├── ilsvrc_folder.py ├── ilsvrc_info.py ├── train_net.py ├── train_net_debug.py └── train_net_eval_open_vocabulary.py └── wsovod ├── __init__.py ├── config ├── __init__.py └── defaults.py ├── data ├── __init__.py ├── build.py ├── build_multi_dataset.py ├── common.py ├── dataset_mapper.py ├── datasets │ ├── __init__.py │ ├── builtin.py │ ├── builtin_meta.py │ └── pascal_voc.py ├── detection_utils.py └── samplers │ ├── __init__.py │ └── distributed_sampler_multi_dataset.py ├── engine ├── __init__.py ├── defaults.py ├── hooks.py └── trainer.py ├── evaluation ├── __init__.py ├── coco_evaluation.py ├── ov_coco_evaluation.py └── pascal_voc_evaluation.py ├── layers ├── ROILoopPool │ ├── ROILoopPool.h │ ├── ROILoopPool_cpu.cpp │ ├── ROILoopPool_cuda.cu │ └── cuda_helpers.h ├── __init__.py ├── csc.py ├── csc │ ├── csc.h │ └── csc_cuda.cu ├── roi_loop_pool.py └── vision.cpp ├── modeling ├── __init__.py ├── backbone │ ├── __init__.py │ ├── mrrp_conv.py │ ├── resnet_wsl.py │ ├── resnet_wsl_mrrp.py │ ├── swin_transformer.py │ ├── vgg.py │ └── vgg_mrrp.py ├── class_heads │ ├── __init__.py │ ├── data_aware_features_head.py │ └── open_vocabulary_classifier.py ├── meta_arch │ ├── __init__.py │ ├── rcnn_wsovod.py │ └── rcnn_wsovod_mixed_datasets.py ├── poolers.py ├── postprocessing.py ├── proposal_generator │ ├── __init__.py │ ├── proposal_utils.py │ └── rpn.py ├── roi_heads │ ├── __init__.py │ ├── box_head.py │ ├── fast_rcnn_open_vocabulary.py │ └── roi_heads.py ├── test_time_augmentation_avg.py └── test_time_augmentation_union.py ├── solver ├── __init__.py ├── build.py └── lr_scheduler.py └── utils ├── __init__.py └── sam_predictor_with_buffer.py /.gitignore: -------------------------------------------------------------------------------- 1 | # output dir 2 | output* 3 | 4 | *.pth 5 | 6 | # compilation and distribution 7 | __pycache__ 8 | _ext 9 | *.pyc 10 | *.so 11 | wsovod.egg-info/ 12 | build/ 13 | dist/ 14 | 15 | # pytorch/python/numpy formats 16 | *.pth 17 | *.pkl 18 | *.npy 19 | 20 | # ipython/jupyter 
notebooks 21 | *.ipynb 22 | **/.ipynb_checkpoints/ 23 | 24 | # Editor temporaries 25 | *.swn 26 | *.swo 27 | *.swp 28 | *~ 29 | 30 | # Pycharm editor settings 31 | .idea 32 | 33 | # project dirs 34 | /datasets/coco 35 | /datasets/lvis 36 | /datasets/ILSVRC2012 37 | /datasets/proposals 38 | /datasets/VOC* 39 | /models -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Weakly Supervised Open-Vocabulary Object Detection 2 | This is an official implementation for AAAI2024 paper "Weakly Supervised Open-Vocabulary Object Detection". (Code is coming soon!) 3 | 4 | ## 📋 Table of content 5 | 1. [📎 Paper Link](#1) 6 | 2. [💡 Abstract](#2) 7 | 3. [📖 Method](#3) 8 | 4. [🛠️ Install](#4) 9 | 5. [✏️ Usage](#5) 10 | 1. [Start](#51) 11 | 2. [Prepare Datasets](#52) 12 | 3. 
[Training](#53) 13 | 4. [Inference](#54) 14 | 6. [🔍 Citation](#6) 15 | 7. [❤️ Acknowledgement](#7) 16 | 17 | ## 📎 Paper Link 18 | [Read our arXiv Paper](https://arxiv.org/abs/2312.12437) 19 | 20 | ## 💡 Abstract 21 | Despite weakly supervised object detection (WSOD) being a promising step toward evading strong instance-level annotations, its capability is confined to closed-set categories within a single training dataset. In this paper, we propose a novel weakly supervised open-vocabulary object detection framework, namely WSOVOD, to extend traditional WSOD to detect novel concepts and utilize diverse datasets with only image-level annotations. To achieve this, we explore three vital strategies, including dataset-level feature adaptation, image-level salient object localization, and region-level vision-language alignment. First, we perform data-aware feature extraction to produce an input-conditional coefficient, which is leveraged into dataset attribute prototypes to identify dataset bias and help achieve cross-dataset generalization. Second, a customized location-oriented weakly supervised region proposal network is proposed to utilize high-level semantic layouts from the category-agnostic segment anything model to distinguish object boundaries. Lastly, we introduce a proposal-concept synchronized multiple-instance network, i.e., object mining and refinement with visual-semantic alignment, to discover objects matched to the text embeddings of concepts. Extensive experiments on Pascal VOC and MS COCO demonstrate that the proposed WSOVOD achieves new state-of-the-art compared with previous WSOD methods in both closed-set object localization and detection tasks. Meanwhile, WSOVOD enables cross-dataset and open-vocabulary learning to achieve on-par or even better performance than well-established fully-supervised open-vocabulary object detection (FSOVOD). 22 | 23 | ## 📖 Method 24 | 25 | The overall framework of our **WSOVOD** is shown below. 26 | ![WSOVOD framework](teaser/framework.png)

27 | 28 |

29 | 30 | ## 🛠️ Install 31 | ``` 32 | conda create --name wsovod python=3.9 33 | conda activate wsovod 34 | pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113 35 | pip install -r requirements.txt 36 | pip install -e . 37 | ``` 38 | 39 | ## ✏️ Usage 40 | 1. Please follow [this](datasets/README.md) to prepare the datasets for training. 41 | 42 | 2. Download the SAM checkpoints. 43 | ``` 44 | mkdir tools/sam_checkpoints && cd tools/sam_checkpoints 45 | wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth 46 | wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth 47 | wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth 48 | ``` 49 | 50 | 3. Prepare SAM proposals for WSOVOD, using voc_2007_train as an example. 51 | ``` 52 | bash scripts/generate_sam_proposals_cuda.sh 4 --checkpoint tools/sam_checkpoints/sam_vit_h_4b8939.pth --model-type vit_h --points-per-side 32 --pred-iou-thresh 0.86 --stability-score-thresh 0.92 --crop-n-layers 1 --crop-n-points-downscale-factor 2 --min-mask-region-area 20.0 --dataset-name voc_2007_train --output datasets/proposals/sam_voc_2007_train_d2.pkl 53 | ``` 54 | 55 | 4. Prepare class text embeddings for WSOVOD, using COCO as an example (a VOC variant of this command is sketched in the note after the License section). 56 | ``` 57 | python tools/generate_class_text_embedding_cuda.py --dataset-name coco_2017_val --mode-type ViT-L/14/32 --prompt-type single --output models/coco_text_embedding_single_prompt.pkl 58 | ``` 59 | 60 | 5. Download the pretrained backbones from [here](https://1drv.ms/f/s!Am1oWgo9554dgRQ8RE1SRGvK7HW2). 61 | 62 | 6. Train on a single dataset and test on another, e.g., train on COCO and evaluate on VOC. 63 | ``` 64 | bash scripts/train_script.sh tools/train_net.py configs/COCO-Detection/WSOVOD_WSR_18_DC5_1x.yaml 4 20240301 65 | 66 | python tools/train_net.py --config-file configs/PascalVOC-Detection/WSOVOD_WSR_18_DC5_1x.yaml --num-gpus 4 --eval-only MODEL.WEIGHTS output/configs/COCO-Detection/WSOVOD_WSR_18_DC5_1x_20240301/model_final.pth 67 | ``` 68 | 7. Train on mixed datasets, e.g., VOC 2007 plus COCO; the command follows the same pattern as step 6 with a MixedDatasets config. 69 | ``` 70 | bash scripts/train_script.sh tools/train_net.py configs/MixedDatasets-Detection/WSOVOD_WSR_18_DC5_1x_voc07+coco.yaml 4 20240301 71 | ``` 72 | 73 | ## 🔍 Citation 74 | If you find WSOVOD useful in your research, please consider citing: 75 | 76 | ``` 77 | @InProceedings{WSOVOD_2024_AAAI, 78 | author = {Lin, Jianghang and Shen, Yunhang and Wang, Bingquan and Lin, Shaohui and Li, Ke and Cao, Liujuan}, 79 | title = {Weakly Supervised Open-Vocabulary Object Detection}, 80 | booktitle = {Proceedings of the AAAI Conference on Artificial Intelligence}, 81 | year = {2024}, 82 | } 83 | ``` 84 | 85 | 86 | 87 | ## License 88 | 89 | WSOVOD is released under the [Apache 2.0 license](LICENSE). 
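> **Note:** the Pascal VOC and mixed-dataset configs read class text embeddings from `models/voc_text_embedding_single_prompt.pkl`, which the usage steps above do not generate explicitly. The command below is only a sketch that mirrors the COCO example in step 4 with the dataset name and output path swapped; the dataset name and flag values are assumptions carried over from that example, so adjust them if `tools/generate_class_text_embedding_cuda.py` expects different arguments.
```
# Assumed variant of the step-4 command; flags copied from the COCO example.
python tools/generate_class_text_embedding_cuda.py --dataset-name voc_2007_val --mode-type ViT-L/14/32 --prompt-type single --output models/voc_text_embedding_single_prompt.pkl
```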
90 | 91 | ## ❤️ Acknowledgement 92 | - [UWSOD](https://github.com/shenyunhang/UWSOD) 93 | - [detectron2](https://github.com/facebookresearch/detectron2) 94 | -------------------------------------------------------------------------------- /configs/COCO-Detection/Base-RCNN-DilatedC5.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | META_ARCHITECTURE: "GeneralizedRCNN_WSOVOD" 3 | PIXEL_MEAN: [102.9801, 115.9465, 122.7717] 4 | LOAD_PROPOSALS: True 5 | PROPOSAL_GENERATOR: 6 | NAME: "WSOVODRPN_V2" 7 | MIN_SIZE: 40 8 | RPN: 9 | IN_FEATURES: ["res5"] 10 | PRE_NMS_TOPK_TRAIN: 2048 11 | PRE_NMS_TOPK_TEST: 2048 12 | POST_NMS_TOPK_TRAIN: 1024 13 | POST_NMS_TOPK_TEST: 1024 14 | NMS_THRESH: 0.7 15 | BATCH_SIZE_PER_IMAGE: 512 16 | # BATCH_SIZE_PER_IMAGE: 256 # default rpn 17 | POSITIVE_FRACTION: 0.5 18 | BBOX_REG_LOSS_TYPE: "smooth_l1" 19 | IOU_THRESHOLDS: [0.2, 0.6] 20 | # IOU_THRESHOLDS: [0.3, 0.7] # default rpn 21 | ROI_HEADS: 22 | NAME: "WSOVODROIHeads" 23 | NUM_CLASSES: 80 24 | SCORE_THRESH_TEST: 0.00001 25 | NMS_THRESH_TEST: 0.3 26 | BATCH_SIZE_PER_IMAGE: 4096 27 | POSITIVE_FRACTION: 1.0 28 | PROPOSAL_APPEND_GT: False 29 | ROI_BOX_HEAD: 30 | NAME: "DiscriminativeAdaptationNeck" 31 | POOLER_TYPE: "ROIPool" 32 | NUM_CONV: 0 33 | NUM_FC: 2 34 | DAN_DIM: [4096, 4096] 35 | POOLER_RESOLUTION: 7 36 | BBOX_REG_LOSS_TYPE: "smooth_l1_weighted" 37 | OPEN_VOCABULARY: 38 | WEIGHT_PATH_TRAIN: "models/coco_text_embedding_single_prompt.pkl" 39 | WEIGHT_PATH_TEST: "models/coco_text_embedding_single_prompt.pkl" 40 | WEIGHT_DIM: 512 41 | USE_BIAS: 0.0 42 | NORM_WEIGHT: True 43 | NORM_TEMP: 50.0 44 | DATA_AWARE: True 45 | WSOVOD: 46 | ITER_SIZE: 1 47 | BBOX_REFINE: 48 | ENABLE: True 49 | OBJECT_MINING: 50 | MEAN_LOSS: True 51 | INSTANCE_REFINEMENT: 52 | REFINE_NUM: 1 53 | REFINE_REG: [True] 54 | SAMPLING: 55 | SAMPLING_ON: True 56 | IOU_THRESHOLDS: [[0.5],] 57 | IOU_LABELS: [[0, 1],] 58 | BATCH_SIZE_PER_IMAGE: [4096,] 59 | POSITIVE_FRACTION: [1.0,] 60 | SOLVER: 61 | IMS_PER_BATCH: 4 62 | BASE_LR: 0.01 63 | STEPS: (140000,) 64 | MAX_ITER: 200000 65 | REFERENCE_WORLD_SIZE: 4 66 | INPUT: 67 | MIN_SIZE_TRAIN: (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800, 832, 864, 896, 928, 960, 992, 1024, 1056, 1088, 1120, 1152, 1184, 1216) 68 | MAX_SIZE_TRAIN: 2000 69 | MIN_SIZE_TEST: 688 70 | MAX_SIZE_TEST: 4000 71 | CROP: 72 | ENABLED: True 73 | TEST: 74 | AUG: 75 | ENABLED: True 76 | MIN_SIZES: (480, 576, 672, 768, 864, 960, 1056, 1152) 77 | MAX_SIZE: 4000 78 | FLIP: True 79 | EVAL_PERIOD: 5000 80 | EVAL_TRAIN: False 81 | DATASETS: 82 | TRAIN: ('coco_2017_train',) 83 | PROPOSAL_FILES_TRAIN: ('datasets/proposals/sam_coco_2017_train_d2.pkl', ) 84 | PRECOMPUTED_PROPOSAL_TOPK_TRAIN: 4000 85 | TEST: ('coco_2017_val',) 86 | PROPOSAL_FILES_TEST: ('datasets/proposals/sam_coco_2017_val_d2.pkl', ) 87 | PRECOMPUTED_PROPOSAL_TOPK_TEST: 4000 88 | VIS_PERIOD: 10000 89 | VERSION: 2 -------------------------------------------------------------------------------- /configs/COCO-Detection/WSOVOD_MRRP_V_16_DC5_1x.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "Base-RCNN-DilatedC5.yaml" 2 | MODEL: 3 | WEIGHTS: "models/VGG/VGG_ILSVRC_16_layers_v1_d2.pkl" 4 | PIXEL_MEAN: [103.939, 116.779, 123.68] 5 | BACKBONE: 6 | NAME: "build_mrrp_vgg_backbone" 7 | FREEZE_AT: 5 8 | MRRP: 9 | MRRP_ON: True 10 | NUM_BRANCH: 3 11 | BRANCH_DILATIONS: [1, 2, 4] 12 | TEST_BRANCH_IDX: -1 13 | MRRP_STAGE: "plain5" 14 | VGG: 15 | DEPTH: 16 16 | CONV5_DILATION: 2 17 | 
ANCHOR_GENERATOR: 18 | SIZES: [[32, 64], [128, 256], [512, 768]] 19 | ASPECT_RATIOS: [[1.0, 2.0, 0.5]] 20 | ROI_HEADS: 21 | IN_FEATURES: ["plain5"] 22 | SOLVER: 23 | STEPS: (140000,) 24 | MAX_ITER: 200000 25 | WARMUP_ITERS: 200 26 | IMS_PER_BATCH: 4 27 | BASE_LR: 0.001 28 | WEIGHT_DECAY: 0.0005 29 | BIAS_LR_FACTOR: 2.0 30 | WEIGHT_DECAY_BIAS: 0.0 -------------------------------------------------------------------------------- /configs/COCO-Detection/WSOVOD_MRRP_WSR_18_DC5_1x.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "Base-RCNN-DilatedC5.yaml" 2 | MODEL: 3 | WEIGHTS: "models/DRN-WSOD/resnet18_ws_model_120_d2.pkl" 4 | BACKBONE: 5 | NAME: "build_mrrp_wsl_resnet_backbone" 6 | FREEZE_AT: 5 7 | MRRP: 8 | MRRP_ON: True 9 | NUM_BRANCH: 3 10 | BRANCH_DILATIONS: [1, 2, 4] 11 | TEST_BRANCH_IDX: -1 12 | MRRP_STAGE: "res5" 13 | RESNETS: 14 | DEPTH: 18 15 | RES5_DILATION: 2 16 | RES2_OUT_CHANNELS: 64 17 | OUT_FEATURES: ["res5"] 18 | ANCHOR_GENERATOR: 19 | SIZES: [[32, 64], [128, 256], [512, 768]] 20 | ASPECT_RATIOS: [[1.0, 2.0, 0.5]] 21 | ROI_HEADS: 22 | IN_FEATURES: ["res5"] 23 | ROI_BOX_HEAD: 24 | POOLER_TYPE: "ROILoopPool" 25 | NUM_CONV: 0 26 | NUM_FC: 2 27 | DAN_DIM: [4096, 4096] 28 | BBOX_REG_LOSS_TYPE: "smooth_l1_weighted" 29 | SOLVER: 30 | STEPS: (140000,) 31 | MAX_ITER: 200000 32 | WARMUP_ITERS: 200 33 | IMS_PER_BATCH: 4 34 | BASE_LR: 0.01 35 | WEIGHT_DECAY: 0.0005 36 | BIAS_LR_FACTOR: 2.0 37 | WEIGHT_DECAY_BIAS: 0.0 -------------------------------------------------------------------------------- /configs/COCO-Detection/WSOVOD_MRRP_WSR_50_DC5_1x.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "Base-RCNN-DilatedC5.yaml" 2 | MODEL: 3 | WEIGHTS: "models/DRN-WSOD/resnet50_ws_model_120_d2.pkl" 4 | BACKBONE: 5 | NAME: "build_mrrp_wsl_resnet_backbone" 6 | FREEZE_AT: 5 7 | MRRP: 8 | MRRP_ON: True 9 | NUM_BRANCH: 3 10 | BRANCH_DILATIONS: [1, 2, 4] 11 | TEST_BRANCH_IDX: -1 12 | MRRP_STAGE: "res5" 13 | RESNETS: 14 | DEPTH: 50 15 | RES5_DILATION: 2 16 | OUT_FEATURES: ["res5"] 17 | ANCHOR_GENERATOR: 18 | SIZES: [[32, 64], [128, 256], [512, 768]] 19 | ASPECT_RATIOS: [[1.0, 2.0, 0.5]] 20 | ROI_HEADS: 21 | IN_FEATURES: ["res5"] 22 | ROI_BOX_HEAD: 23 | POOLER_TYPE: "ROILoopPool" 24 | NUM_CONV: 0 25 | NUM_FC: 2 26 | DAN_DIM: [4096, 4096] 27 | BBOX_REG_LOSS_TYPE: "smooth_l1_weighted" 28 | SOLVER: 29 | STEPS: (140000,) 30 | MAX_ITER: 200000 31 | WARMUP_ITERS: 200 32 | IMS_PER_BATCH: 4 33 | BASE_LR: 0.01 34 | WEIGHT_DECAY: 0.0005 35 | BIAS_LR_FACTOR: 2.0 36 | WEIGHT_DECAY_BIAS: 0.0 -------------------------------------------------------------------------------- /configs/COCO-Detection/WSOVOD_V_16_DC5_1x.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "Base-RCNN-DilatedC5.yaml" 2 | MODEL: 3 | WEIGHTS: "models/VGG/VGG_ILSVRC_16_layers_v1_d2.pkl" 4 | PIXEL_MEAN: [103.939, 116.779, 123.68] 5 | BACKBONE: 6 | NAME: "build_vgg_backbone" 7 | FREEZE_AT: 5 8 | VGG: 9 | DEPTH: 16 10 | CONV5_DILATION: 2 11 | ANCHOR_GENERATOR: 12 | SIZES: [32, 64, 128, 256, 512, 768] 13 | ASPECT_RATIOS: [[1.0, 2.0, 0.5]] 14 | ROI_HEADS: 15 | IN_FEATURES: ["plain5"] 16 | SOLVER: 17 | STEPS: (140000,) 18 | MAX_ITER: 200000 19 | WARMUP_ITERS: 200 20 | IMS_PER_BATCH: 4 21 | BASE_LR: 0.001 22 | WEIGHT_DECAY: 0.0005 23 | BIAS_LR_FACTOR: 2.0 24 | WEIGHT_DECAY_BIAS: 0.0 -------------------------------------------------------------------------------- 
/configs/COCO-Detection/WSOVOD_WSR_18_DC5_1x.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "Base-RCNN-DilatedC5.yaml" 2 | MODEL: 3 | WEIGHTS: "models/DRN-WSOD/resnet18_ws_model_120_d2.pkl" 4 | BACKBONE: 5 | NAME: "build_wsl_resnet_backbone" 6 | FREEZE_AT: 5 7 | RESNETS: 8 | DEPTH: 18 9 | RES5_DILATION: 2 10 | RES2_OUT_CHANNELS: 64 11 | OUT_FEATURES: ["res5"] 12 | ANCHOR_GENERATOR: 13 | SIZES: [32, 64, 128, 256, 512, 768] 14 | ASPECT_RATIOS: [[1.0, 2.0, 0.5]] 15 | ROI_HEADS: 16 | IN_FEATURES: ["res5"] 17 | SOLVER: 18 | STEPS: (140000,) 19 | MAX_ITER: 200000 20 | WARMUP_ITERS: 200 21 | IMS_PER_BATCH: 4 22 | BASE_LR: 0.01 23 | WEIGHT_DECAY: 0.0005 24 | BIAS_LR_FACTOR: 2.0 25 | WEIGHT_DECAY_BIAS: 0.0 26 | -------------------------------------------------------------------------------- /configs/COCO-Detection/WSOVOD_WSR_50_DC5_1x.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "Base-RCNN-DilatedC5.yaml" 2 | MODEL: 3 | WEIGHTS: "models/DRN-WSOD/resnet50_ws_model_120_d2.pkl" 4 | BACKBONE: 5 | NAME: "build_wsl_resnet_backbone" 6 | FREEZE_AT: 5 7 | RESNETS: 8 | DEPTH: 50 9 | RES5_DILATION: 2 10 | OUT_FEATURES: ["res5"] 11 | ANCHOR_GENERATOR: 12 | SIZES: [32, 64, 128, 256, 512, 768] 13 | ASPECT_RATIOS: [[1.0, 2.0, 0.5]] 14 | ROI_HEADS: 15 | IN_FEATURES: ["res5"] 16 | SOLVER: 17 | STEPS: (140000,) 18 | MAX_ITER: 200000 19 | WARMUP_ITERS: 200 20 | IMS_PER_BATCH: 4 21 | BASE_LR: 0.01 22 | WEIGHT_DECAY: 0.0005 23 | BIAS_LR_FACTOR: 2.0 24 | WEIGHT_DECAY_BIAS: 0.0 25 | -------------------------------------------------------------------------------- /configs/ImageNet-Detection/Base-RCNN-DilatedC5.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | META_ARCHITECTURE: "GeneralizedRCNN_WSOVOD" 3 | PIXEL_MEAN: [102.9801, 115.9465, 122.7717] 4 | LOAD_PROPOSALS: True 5 | PROPOSAL_GENERATOR: 6 | NAME: "WSOVODRPN_V2" 7 | MIN_SIZE: 100 8 | RPN: 9 | IN_FEATURES: ["res5"] 10 | PRE_NMS_TOPK_TRAIN: 512 11 | PRE_NMS_TOPK_TEST: 512 12 | POST_NMS_TOPK_TRAIN: 16 13 | POST_NMS_TOPK_TEST: 256 14 | NMS_THRESH: 0.7 15 | BATCH_SIZE_PER_IMAGE: 512 16 | # BATCH_SIZE_PER_IMAGE: 256 # default rpn 17 | POSITIVE_FRACTION: 0.5 18 | BBOX_REG_LOSS_TYPE: "smooth_l1" 19 | IOU_THRESHOLDS: [0.3, 0.7] 20 | # IOU_THRESHOLDS: [0.3, 0.7] # default rpn 21 | ROI_HEADS: 22 | NAME: "WSOVODROIHeads" 23 | NUM_CLASSES: 1000 24 | SCORE_THRESH_TEST: 0.00001 25 | NMS_THRESH_TEST: 0.3 26 | BATCH_SIZE_PER_IMAGE: 4096 27 | POSITIVE_FRACTION: 1.0 28 | PROPOSAL_APPEND_GT: False 29 | ROI_BOX_HEAD: 30 | NAME: "DiscriminativeAdaptationNeck" 31 | POOLER_TYPE: "ROIPool" 32 | NUM_CONV: 0 33 | NUM_FC: 2 34 | DAN_DIM: [4096, 4096] 35 | POOLER_RESOLUTION: 7 36 | BBOX_REG_LOSS_TYPE: "smooth_l1_weighted" 37 | OPEN_VOCABULARY: 38 | WEIGHT_PATH_TRAIN: "models/in1k_text_embedding_single_prompt.pkl" 39 | WEIGHT_PATH_TEST: "models/coco_text_embedding_single_prompt.pkl" 40 | WEIGHT_DIM: 512 41 | USE_BIAS: 0.0 42 | NORM_WEIGHT: True 43 | NORM_TEMP: 50.0 44 | DATA_AWARE: True 45 | WSOVOD: 46 | ITER_SIZE: 1 47 | BBOX_REFINE: 48 | ENABLE: True 49 | OBJECT_MINING: 50 | MEAN_LOSS: True 51 | INSTANCE_REFINEMENT: 52 | REFINE_NUM: 1 53 | REFINE_REG: [True] 54 | SAMPLING: 55 | SAMPLING_ON: True 56 | IOU_THRESHOLDS: [[0.5],] 57 | IOU_LABELS: [[0, 1],] 58 | BATCH_SIZE_PER_IMAGE: [4096,] 59 | POSITIVE_FRACTION: [1.0,] 60 | SOLVER: 61 | IMS_PER_BATCH: 4 62 | BASE_LR: 0.01 63 | STEPS: (140000,) 64 | MAX_ITER: 200000 65 | 
REFERENCE_WORLD_SIZE: 4 66 | INPUT: 67 | MIN_SIZE_TRAIN: (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800, 832, 864, 896) 68 | MAX_SIZE_TRAIN: 2000 69 | MIN_SIZE_TEST: 512 70 | MAX_SIZE_TEST: 2000 71 | CROP: 72 | ENABLED: True 73 | TEST: 74 | AUG: 75 | ENABLED: True 76 | MIN_SIZES: (480, 576, 672, 768, 864, 960, 1056) 77 | MAX_SIZE: 2000 78 | FLIP: True 79 | EVAL_PERIOD: 5000 80 | EVAL_TRAIN: False 81 | DATASETS: 82 | TRAIN: ('ilsvrc_2012_val',) 83 | PROPOSAL_FILES_TRAIN: ('datasets/proposals/sam_ilsvrc_2012_val_d2.pkl',) 84 | PRECOMPUTED_PROPOSAL_TOPK_TRAIN: 10 85 | TEST: ('coco_2017_val',) 86 | PROPOSAL_FILES_TEST: ('datasets/proposals/sam_coco_2017_val_d2.pkl', ) 87 | PRECOMPUTED_PROPOSAL_TOPK_TEST: 4000 88 | VIS_PERIOD: 10000 89 | VERSION: 2 -------------------------------------------------------------------------------- /configs/ImageNet-Detection/WSOVOD_WSR_18_DC5_1x.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "Base-RCNN-DilatedC5.yaml" 2 | MODEL: 3 | WEIGHTS: "models/DRN-WSOD/resnet18_ws_model_120_d2.pkl" 4 | BACKBONE: 5 | NAME: "build_wsl_resnet_backbone" 6 | FREEZE_AT: 5 7 | RESNETS: 8 | DEPTH: 18 9 | RES5_DILATION: 2 10 | RES2_OUT_CHANNELS: 64 11 | OUT_FEATURES: ["res5"] 12 | ANCHOR_GENERATOR: 13 | SIZES: [32, 64, 128, 256, 512, 768] 14 | ASPECT_RATIOS: [[1.0, 2.0, 0.5]] 15 | ROI_HEADS: 16 | IN_FEATURES: ["res5"] 17 | SOLVER: 18 | STEPS: (70000,) 19 | MAX_ITER: 100000 20 | WARMUP_ITERS: 200 21 | IMS_PER_BATCH: 4 22 | BASE_LR: 0.01 23 | WEIGHT_DECAY: 0.0005 24 | BIAS_LR_FACTOR: 2.0 25 | WEIGHT_DECAY_BIAS: 0.0 26 | -------------------------------------------------------------------------------- /configs/MixedDatasets-Detection/Base-RCNN-DilatedC5.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | META_ARCHITECTURE: "GeneralizedRCNN_WSOVOD_MixedDatasets" 3 | PIXEL_MEAN: [102.9801, 115.9465, 122.7717] 4 | LOAD_PROPOSALS: True 5 | PROPOSAL_GENERATOR: 6 | NAME: "WSOVODRPN_V2" 7 | MIN_SIZE: 40 8 | RPN: 9 | IN_FEATURES: ["res5"] 10 | PRE_NMS_TOPK_TRAIN: 2048 11 | PRE_NMS_TOPK_TEST: 2048 12 | POST_NMS_TOPK_TRAIN: 1024 13 | POST_NMS_TOPK_TEST: 1024 14 | NMS_THRESH: 0.7 15 | BATCH_SIZE_PER_IMAGE: 512 16 | # BATCH_SIZE_PER_IMAGE: 256 # default rpn 17 | POSITIVE_FRACTION: 0.5 18 | BBOX_REG_LOSS_TYPE: "smooth_l1" 19 | IOU_THRESHOLDS: [0.2, 0.6] 20 | # IOU_THRESHOLDS: [0.3, 0.7] # default rpn 21 | ROI_HEADS: 22 | NAME: "WSOVODMixedDatasetsROIHeads" 23 | NUM_CLASSES: 80 24 | SCORE_THRESH_TEST: 0.00001 25 | NMS_THRESH_TEST: 0.3 26 | BATCH_SIZE_PER_IMAGE: 4096 27 | POSITIVE_FRACTION: 1.0 28 | PROPOSAL_APPEND_GT: False 29 | ROI_BOX_HEAD: 30 | NAME: "DiscriminativeAdaptationNeck" 31 | POOLER_TYPE: "ROIPool" 32 | NUM_CONV: 0 33 | NUM_FC: 2 34 | DAN_DIM: [4096, 4096] 35 | POOLER_RESOLUTION: 7 36 | BBOX_REG_LOSS_TYPE: "smooth_l1_weighted" 37 | OPEN_VOCABULARY: 38 | WEIGHT_PATH_TRAIN: "models/coco_text_embedding_single_prompt.pkl" 39 | WEIGHT_PATH_TEST: "models/coco_text_embedding_single_prompt.pkl" 40 | WEIGHT_DIM: 512 41 | USE_BIAS: 0.0 42 | NORM_WEIGHT: True 43 | NORM_TEMP: 50.0 44 | DATA_AWARE: True 45 | WSOVOD: 46 | ITER_SIZE: 1 47 | BBOX_REFINE: 48 | ENABLE: True 49 | OBJECT_MINING: 50 | MEAN_LOSS: True 51 | INSTANCE_REFINEMENT: 52 | REFINE_NUM: 1 53 | REFINE_REG: [True] 54 | SAMPLING: 55 | SAMPLING_ON: True 56 | IOU_THRESHOLDS: [[0.5],] 57 | IOU_LABELS: [[0, 1],] 58 | BATCH_SIZE_PER_IMAGE: [4096,] 59 | POSITIVE_FRACTION: [1.0,] 60 | SOLVER: 61 | IMS_PER_BATCH: 4 62 | IMS_PER_BATCH_LIST: 
[4,4,4] 63 | BASE_LR: 0.01 64 | STEPS: (140000,) 65 | MAX_ITER: 200000 66 | REFERENCE_WORLD_SIZE: 4 67 | INPUT: 68 | MIN_SIZE_TRAIN: (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800, 832, 864, 896, 928, 960, 992, 1024, 1056, 1088, 1120, 1152, 1184, 1216) 69 | MAX_SIZE_TRAIN: 2000 70 | MIN_SIZE_TEST: 688 71 | MAX_SIZE_TEST: 4000 72 | CROP: 73 | ENABLED: True 74 | TEST: 75 | AUG: 76 | ENABLED: True 77 | MIN_SIZES: (480, 576, 672, 768, 864, 960, 1056, 1152) 78 | MAX_SIZE: 4000 79 | FLIP: True 80 | EVAL_PERIOD: 5000 81 | EVAL_TRAIN: False 82 | DATASETS: 83 | PRECOMPUTED_PROPOSAL_TOPK_TRAIN: 4000 84 | MIXED_DATASETS: 85 | WEIGHT_PATH_TRAINS: ["models/voc_text_embedding_single_prompt.pkl","models/voc_text_embedding_single_prompt.pkl","models/coco_text_embedding_single_prompt.pkl"] 86 | NAMES: ('voc_2007_train', 'voc_2007_val', 'coco_2017_train') 87 | PROPOSAL_FILES: ('datasets/proposals/sam_voc_2007_train_d2.pkl', 'datasets/proposals/sam_voc_2007_val_d2.pkl','datasets/proposals/sam_coco_2017_train_d2.pkl') 88 | NUM_CLASSES: [20,20,80] 89 | FILTER_EMPTY_ANNOTATIONS: [True, True, True] 90 | RATIOS: [1,1,20] 91 | USE_CAS: [False,False,False] 92 | USE_RFS: [False,False,False] 93 | CAS_LAMBDA: 1.0 94 | TEST: ('coco_2017_val',) 95 | PROPOSAL_FILES_TEST: ('datasets/proposals/sam_coco_2017_val_d2.pkl', ) 96 | PRECOMPUTED_PROPOSAL_TOPK_TEST: 4000 97 | DATALOADER: 98 | SAMPLER_TRAIN: "MultiDatasetTrainingSampler" 99 | VIS_PERIOD: 10000 100 | VERSION: 2 -------------------------------------------------------------------------------- /configs/MixedDatasets-Detection/WSOVOD_MRRP_WSR_18_DC5_1x_voc07+coco.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "Base-RCNN-DilatedC5.yaml" 2 | MODEL: 3 | WEIGHTS: "models/DRN-WSOD/resnet18_ws_model_120_d2.pkl" 4 | BACKBONE: 5 | NAME: "build_mrrp_wsl_resnet_backbone" 6 | FREEZE_AT: 5 7 | MRRP: 8 | MRRP_ON: True 9 | NUM_BRANCH: 3 10 | BRANCH_DILATIONS: [1, 2, 4] 11 | TEST_BRANCH_IDX: -1 12 | MRRP_STAGE: "res5" 13 | RESNETS: 14 | DEPTH: 18 15 | RES5_DILATION: 2 16 | RES2_OUT_CHANNELS: 64 17 | OUT_FEATURES: ["res5"] 18 | ANCHOR_GENERATOR: 19 | SIZES: [[32, 64], [128, 256], [512, 768]] 20 | ASPECT_RATIOS: [[1.0, 2.0, 0.5]] 21 | ROI_HEADS: 22 | IN_FEATURES: ["res5"] 23 | ROI_BOX_HEAD: 24 | OPEN_VOCABULARY: 25 | DATA_AWARE: True 26 | WEIGHT_PATH_TEST: "models/coco_text_embedding_single_prompt.pkl" 27 | POOLER_TYPE: "ROILoopPool" 28 | NUM_CONV: 0 29 | NUM_FC: 2 30 | DAN_DIM: [4096, 4096] 31 | BBOX_REG_LOSS_TYPE: "smooth_l1_weighted" 32 | SOLVER: 33 | STEPS: (140000,) 34 | MAX_ITER: 200000 35 | WARMUP_ITERS: 200 36 | IMS_PER_BATCH: 4 37 | BASE_LR: 0.01 38 | WEIGHT_DECAY: 0.0005 39 | BIAS_LR_FACTOR: 2.0 40 | WEIGHT_DECAY_BIAS: 0.0 41 | WSOVOD: 42 | BBOX_REFINE: 43 | ENABLE: True 44 | INSTANCE_REFINEMENT: 45 | REFINE_NUM: 1 46 | REFINE_REG: [True] 47 | DATASETS: 48 | PRECOMPUTED_PROPOSAL_TOPK_TRAIN: 4000 49 | MIXED_DATASETS: 50 | WEIGHT_PATH_TRAINS: ["models/voc_text_embedding_single_prompt.pkl","models/voc_text_embedding_single_prompt.pkl","models/coco_text_embedding_single_prompt.pkl"] 51 | NAMES: ('voc_2007_train', 'voc_2007_val', 'coco_2017_train') 52 | PROPOSAL_FILES: ('datasets/proposals/sam_voc_2007_train_d2.pkl', 'datasets/proposals/sam_voc_2007_val_d2.pkl','datasets/proposals/sam_coco_2017_train_d2.pkl') 53 | NUM_CLASSES: [20,20,80] 54 | FILTER_EMPTY_ANNOTATIONS: [True, True, True] 55 | RATIOS: [1,1,20] 56 | USE_CAS: [False,False,False] 57 | USE_RFS: [False,False,False] 58 | CAS_LAMBDA: 1.0 59 | TEST: 
('coco_2017_val',) 60 | PROPOSAL_FILES_TEST: ('datasets/proposals/sam_coco_2017_val_d2.pkl', ) 61 | PRECOMPUTED_PROPOSAL_TOPK_TEST: 4000 -------------------------------------------------------------------------------- /configs/MixedDatasets-Detection/WSOVOD_MRRP_WSR_50_DC5_1x_voc07+coco.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "Base-RCNN-DilatedC5.yaml" 2 | MODEL: 3 | WEIGHTS: "models/DRN-WSOD/resnet50_ws_model_120_d2.pkl" 4 | BACKBONE: 5 | NAME: "build_mrrp_wsl_resnet_backbone" 6 | FREEZE_AT: 5 7 | MRRP: 8 | MRRP_ON: True 9 | NUM_BRANCH: 3 10 | BRANCH_DILATIONS: [1, 2, 4] 11 | TEST_BRANCH_IDX: -1 12 | MRRP_STAGE: "res5" 13 | RESNETS: 14 | DEPTH: 50 15 | RES5_DILATION: 2 16 | OUT_FEATURES: ["res5"] 17 | ANCHOR_GENERATOR: 18 | SIZES: [[32, 64], [128, 256], [512, 768]] 19 | ASPECT_RATIOS: [[1.0, 2.0, 0.5]] 20 | ROI_HEADS: 21 | IN_FEATURES: ["res5"] 22 | ROI_BOX_HEAD: 23 | OPEN_VOCABULARY: 24 | DATA_AWARE: True 25 | WEIGHT_PATH_TEST: "models/coco_text_embedding_single_prompt.pkl" 26 | POOLER_TYPE: "ROILoopPool" 27 | NUM_CONV: 0 28 | NUM_FC: 2 29 | DAN_DIM: [4096, 4096] 30 | BBOX_REG_LOSS_TYPE: "smooth_l1_weighted" 31 | SOLVER: 32 | STEPS: (140000,) 33 | MAX_ITER: 200000 34 | WARMUP_ITERS: 200 35 | IMS_PER_BATCH: 4 36 | BASE_LR: 0.01 37 | WEIGHT_DECAY: 0.0005 38 | BIAS_LR_FACTOR: 2.0 39 | WEIGHT_DECAY_BIAS: 0.0 40 | WSOVOD: 41 | BBOX_REFINE: 42 | ENABLE: True 43 | INSTANCE_REFINEMENT: 44 | REFINE_NUM: 1 45 | REFINE_REG: [True] 46 | DATASETS: 47 | PRECOMPUTED_PROPOSAL_TOPK_TRAIN: 4000 48 | MIXED_DATASETS: 49 | WEIGHT_PATH_TRAINS: ["models/voc_text_embedding_single_prompt.pkl","models/voc_text_embedding_single_prompt.pkl","models/coco_text_embedding_single_prompt.pkl"] 50 | NAMES: ('voc_2007_train', 'voc_2007_val', 'coco_2017_train') 51 | PROPOSAL_FILES: ('datasets/proposals/sam_voc_2007_train_d2.pkl', 'datasets/proposals/sam_voc_2007_val_d2.pkl','datasets/proposals/sam_coco_2017_train_d2.pkl') 52 | NUM_CLASSES: [20,20,80] 53 | FILTER_EMPTY_ANNOTATIONS: [True, True, True] 54 | RATIOS: [1,1,20] 55 | USE_CAS: [False,False,False] 56 | USE_RFS: [False,False,False] 57 | CAS_LAMBDA: 1.0 58 | TEST: ('coco_2017_val',) 59 | PROPOSAL_FILES_TEST: ('datasets/proposals/sam_coco_2017_val_d2.pkl', ) 60 | PRECOMPUTED_PROPOSAL_TOPK_TEST: 4000 -------------------------------------------------------------------------------- /configs/MixedDatasets-Detection/WSOVOD_WSR_18_DC5_1x_voc07+coco.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "Base-RCNN-DilatedC5.yaml" 2 | MODEL: 3 | WEIGHTS: "models/DRN-WSOD/resnet18_ws_model_120_d2.pkl" 4 | BACKBONE: 5 | NAME: "build_wsl_resnet_backbone" 6 | FREEZE_AT: 5 7 | RESNETS: 8 | DEPTH: 18 9 | RES5_DILATION: 2 10 | RES2_OUT_CHANNELS: 64 11 | OUT_FEATURES: ["res5"] 12 | ANCHOR_GENERATOR: 13 | SIZES: [32, 64, 128, 256, 512, 768] 14 | ASPECT_RATIOS: [[1.0, 2.0, 0.5]] 15 | ROI_HEADS: 16 | IN_FEATURES: ["res5"] 17 | ROI_BOX_HEAD: 18 | OPEN_VOCABULARY: 19 | DATA_AWARE: True 20 | WEIGHT_PATH_TEST: "models/coco_text_embedding_single_prompt.pkl" 21 | SOLVER: 22 | STEPS: (140000,) 23 | MAX_ITER: 200000 24 | WARMUP_ITERS: 200 25 | IMS_PER_BATCH: 4 26 | BASE_LR: 0.01 27 | WEIGHT_DECAY: 0.0005 28 | BIAS_LR_FACTOR: 2.0 29 | WEIGHT_DECAY_BIAS: 0.0 30 | WSOVOD: 31 | BBOX_REFINE: 32 | ENABLE: True 33 | INSTANCE_REFINEMENT: 34 | REFINE_NUM: 1 35 | REFINE_REG: [True] 36 | DATASETS: 37 | PRECOMPUTED_PROPOSAL_TOPK_TRAIN: 4000 38 | MIXED_DATASETS: 39 | WEIGHT_PATH_TRAINS: 
["models/voc_text_embedding_single_prompt.pkl","models/voc_text_embedding_single_prompt.pkl","models/coco_text_embedding_single_prompt.pkl"] 40 | NAMES: ('voc_2007_train', 'voc_2007_val', 'coco_2017_train') 41 | PROPOSAL_FILES: ('datasets/proposals/sam_voc_2007_train_d2.pkl', 'datasets/proposals/sam_voc_2007_val_d2.pkl','datasets/proposals/sam_coco_2017_train_d2.pkl') 42 | NUM_CLASSES: [20,20,80] 43 | FILTER_EMPTY_ANNOTATIONS: [True, True, True] 44 | RATIOS: [1,1,20] 45 | USE_CAS: [False,False,False] 46 | USE_RFS: [False,False,False] 47 | CAS_LAMBDA: 1.0 48 | TEST: ('coco_2017_val',) 49 | PROPOSAL_FILES_TEST: ('datasets/proposals/sam_coco_2017_val_d2.pkl', ) 50 | PRECOMPUTED_PROPOSAL_TOPK_TEST: 4000 -------------------------------------------------------------------------------- /configs/MixedDatasets-Detection/WSOVOD_WSR_50_DC5_1x_voc07+coco.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "Base-RCNN-DilatedC5.yaml" 2 | MODEL: 3 | WEIGHTS: "models/DRN-WSOD/resnet50_ws_model_120_d2.pkl" 4 | BACKBONE: 5 | NAME: "build_wsl_resnet_backbone" 6 | FREEZE_AT: 5 7 | RESNETS: 8 | DEPTH: 50 9 | RES5_DILATION: 2 10 | OUT_FEATURES: ["res5"] 11 | ANCHOR_GENERATOR: 12 | SIZES: [32, 64, 128, 256, 512, 768] 13 | ASPECT_RATIOS: [[1.0, 2.0, 0.5]] 14 | ROI_HEADS: 15 | IN_FEATURES: ["res5"] 16 | ROI_BOX_HEAD: 17 | OPEN_VOCABULARY: 18 | DATA_AWARE: True 19 | WEIGHT_PATH_TEST: "models/coco_text_embedding_single_prompt.pkl" 20 | SOLVER: 21 | STEPS: (140000,) 22 | MAX_ITER: 200000 23 | WARMUP_ITERS: 200 24 | IMS_PER_BATCH: 4 25 | BASE_LR: 0.01 26 | WEIGHT_DECAY: 0.0005 27 | BIAS_LR_FACTOR: 2.0 28 | WEIGHT_DECAY_BIAS: 0.0 29 | WSOVOD: 30 | BBOX_REFINE: 31 | ENABLE: True 32 | INSTANCE_REFINEMENT: 33 | REFINE_NUM: 1 34 | REFINE_REG: [True] 35 | DATASETS: 36 | PRECOMPUTED_PROPOSAL_TOPK_TRAIN: 4000 37 | MIXED_DATASETS: 38 | WEIGHT_PATH_TRAINS: ["models/voc_text_embedding_single_prompt.pkl","models/voc_text_embedding_single_prompt.pkl","models/coco_text_embedding_single_prompt.pkl"] 39 | NAMES: ('voc_2007_train', 'voc_2007_val', 'coco_2017_train') 40 | PROPOSAL_FILES: ('datasets/proposals/sam_voc_2007_train_d2.pkl', 'datasets/proposals/sam_voc_2007_val_d2.pkl','datasets/proposals/sam_coco_2017_train_d2.pkl') 41 | NUM_CLASSES: [20,20,80] 42 | FILTER_EMPTY_ANNOTATIONS: [True, True, True] 43 | RATIOS: [1,1,20] 44 | USE_CAS: [False,False,False] 45 | USE_RFS: [False,False,False] 46 | CAS_LAMBDA: 1.0 47 | TEST: ('coco_2017_val',) 48 | PROPOSAL_FILES_TEST: ('datasets/proposals/sam_coco_2017_val_d2.pkl', ) 49 | PRECOMPUTED_PROPOSAL_TOPK_TEST: 4000 -------------------------------------------------------------------------------- /configs/PascalVOC-Detection/Base-RCNN-DilatedC5.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | META_ARCHITECTURE: "GeneralizedRCNN_WSOVOD" 3 | PIXEL_MEAN: [102.9801, 115.9465, 122.7717] 4 | LOAD_PROPOSALS: True 5 | PROPOSAL_GENERATOR: 6 | NAME: "WSOVODRPN_V2" 7 | MIN_SIZE: 40 8 | RPN: 9 | IN_FEATURES: ["res5"] 10 | PRE_NMS_TOPK_TRAIN: 2048 11 | PRE_NMS_TOPK_TEST: 2048 12 | POST_NMS_TOPK_TRAIN: 1024 13 | POST_NMS_TOPK_TEST: 1024 14 | NMS_THRESH: 0.7 15 | BATCH_SIZE_PER_IMAGE: 512 16 | # BATCH_SIZE_PER_IMAGE: 256 # default rpn 17 | POSITIVE_FRACTION: 0.5 18 | BBOX_REG_LOSS_TYPE: "smooth_l1" 19 | IOU_THRESHOLDS: [0.2, 0.6] 20 | # IOU_THRESHOLDS: [0.3, 0.7] # default rpn 21 | ROI_HEADS: 22 | NAME: "WSOVODROIHeads" 23 | NUM_CLASSES: 20 24 | SCORE_THRESH_TEST: 0.00001 25 | NMS_THRESH_TEST: 0.3 26 | 
BATCH_SIZE_PER_IMAGE: 4096 27 | POSITIVE_FRACTION: 1.0 28 | PROPOSAL_APPEND_GT: False 29 | ROI_BOX_HEAD: 30 | NAME: "DiscriminativeAdaptationNeck" 31 | POOLER_TYPE: "ROIPool" 32 | NUM_CONV: 0 33 | NUM_FC: 2 34 | DAN_DIM: [4096, 4096] 35 | POOLER_RESOLUTION: 7 36 | BBOX_REG_LOSS_TYPE: "smooth_l1_weighted" 37 | OPEN_VOCABULARY: 38 | WEIGHT_PATH_TRAIN: "models/voc_text_embedding_single_prompt.pkl" 39 | WEIGHT_PATH_TEST: "models/voc_text_embedding_single_prompt.pkl" 40 | WEIGHT_DIM: 512 41 | USE_BIAS: 0.0 42 | NORM_WEIGHT: True 43 | NORM_TEMP: 50.0 44 | DATA_AWARE: True 45 | WSOVOD: 46 | ITER_SIZE: 1 47 | BBOX_REFINE: 48 | ENABLE: True 49 | OBJECT_MINING: 50 | MEAN_LOSS: True 51 | INSTANCE_REFINEMENT: 52 | REFINE_NUM: 1 53 | REFINE_REG: [True] 54 | SAMPLING: 55 | SAMPLING_ON: True 56 | IOU_THRESHOLDS: [[0.5],] 57 | IOU_LABELS: [[0, 1],] 58 | BATCH_SIZE_PER_IMAGE: [4096,] 59 | POSITIVE_FRACTION: [1.0,] 60 | SOLVER: 61 | IMS_PER_BATCH: 4 62 | BASE_LR: 0.02 63 | STEPS: (60000, 80000) 64 | MAX_ITER: 90000 65 | REFERENCE_WORLD_SIZE: 4 66 | INPUT: 67 | MIN_SIZE_TRAIN: (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800, 832, 864, 896, 928, 960, 992, 1024, 1056, 1088, 1120, 1152, 1184, 1216) 68 | MAX_SIZE_TRAIN: 2000 69 | MIN_SIZE_TEST: 688 70 | MAX_SIZE_TEST: 4000 71 | CROP: 72 | ENABLED: True 73 | TEST: 74 | AUG: 75 | ENABLED: True 76 | MIN_SIZES: (480, 576, 672, 768, 864, 960, 1056, 1152) 77 | MAX_SIZE: 4000 78 | FLIP: True 79 | EVAL_PERIOD: 5000 80 | EVAL_TRAIN: True 81 | DATASETS: 82 | TRAIN: ('voc_2007_train', 'voc_2007_val') 83 | PROPOSAL_FILES_TRAIN: ('datasets/proposals/sam_voc_2007_train_d2.pkl', 'datasets/proposals/sam_voc_2007_val_d2.pkl') 84 | PRECOMPUTED_PROPOSAL_TOPK_TRAIN: 4000 85 | TEST: ('voc_2007_test',) 86 | PROPOSAL_FILES_TEST: ('datasets/proposals/sam_voc_2007_test_d2.pkl', ) 87 | PRECOMPUTED_PROPOSAL_TOPK_TEST: 4000 88 | VIS_PERIOD: 5000 89 | VERSION: 2 -------------------------------------------------------------------------------- /configs/PascalVOC-Detection/WSOVOD_MRRP_V_16_DC5_1x.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "Base-RCNN-DilatedC5.yaml" 2 | MODEL: 3 | WEIGHTS: "models/VGG/VGG_ILSVRC_16_layers_v1_d2.pkl" 4 | PIXEL_MEAN: [103.939, 116.779, 123.68] 5 | BACKBONE: 6 | NAME: "build_mrrp_vgg_backbone" 7 | FREEZE_AT: 5 8 | MRRP: 9 | MRRP_ON: True 10 | NUM_BRANCH: 3 11 | BRANCH_DILATIONS: [1, 2, 4] 12 | TEST_BRANCH_IDX: -1 13 | MRRP_STAGE: "plain5" 14 | VGG: 15 | DEPTH: 16 16 | CONV5_DILATION: 2 17 | ANCHOR_GENERATOR: 18 | SIZES: [[32, 64], [128, 256], [512, 768]] 19 | ASPECT_RATIOS: [[1.0, 2.0, 0.5]] 20 | ROI_HEADS: 21 | IN_FEATURES: ["plain5"] 22 | SOLVER: 23 | STEPS: (70000,) 24 | MAX_ITER: 100000 25 | WARMUP_ITERS: 200 26 | IMS_PER_BATCH: 4 27 | BASE_LR: 0.001 28 | WEIGHT_DECAY: 0.0005 29 | BIAS_LR_FACTOR: 2.0 30 | WEIGHT_DECAY_BIAS: 0.0 31 | -------------------------------------------------------------------------------- /configs/PascalVOC-Detection/WSOVOD_MRRP_V_16_DC5_VOC12_1x.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "WSOVOD_MRRP_V_16_DC5_1x.yaml" 2 | DATASETS: 3 | TRAIN: ('voc_2012_train', 'voc_2012_val') 4 | PROPOSAL_FILES_TRAIN: ('datasets/proposals/sam_voc_2012_train_d2.pkl', 'datasets/proposals/sam_voc_2012_val_d2.pkl') 5 | TEST: ('voc_2007_test',) 6 | PROPOSAL_FILES_TEST: ('datasets/proposals/sam_voc_2007_test_d2.pkl', ) -------------------------------------------------------------------------------- 
/configs/PascalVOC-Detection/WSOVOD_MRRP_WSR_18_DC5_1x.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "Base-RCNN-DilatedC5.yaml" 2 | MODEL: 3 | WEIGHTS: "models/DRN-WSOD/resnet18_ws_model_120_d2.pkl" 4 | BACKBONE: 5 | NAME: "build_mrrp_wsl_resnet_backbone" 6 | FREEZE_AT: 5 7 | MRRP: 8 | MRRP_ON: True 9 | NUM_BRANCH: 3 10 | BRANCH_DILATIONS: [1, 2, 4] 11 | TEST_BRANCH_IDX: -1 12 | MRRP_STAGE: "res5" 13 | RESNETS: 14 | DEPTH: 18 15 | RES5_DILATION: 2 16 | RES2_OUT_CHANNELS: 64 17 | OUT_FEATURES: ["res5"] 18 | ANCHOR_GENERATOR: 19 | SIZES: [[32, 64], [128, 256], [512, 768]] 20 | ASPECT_RATIOS: [[1.0, 2.0, 0.5]] 21 | ROI_HEADS: 22 | IN_FEATURES: ["res5"] 23 | ROI_BOX_HEAD: 24 | POOLER_TYPE: "ROILoopPool" 25 | NUM_CONV: 0 26 | NUM_FC: 2 27 | DAN_DIM: [4096, 4096] 28 | BBOX_REG_LOSS_TYPE: "smooth_l1_weighted" 29 | SOLVER: 30 | STEPS: (70000,) 31 | MAX_ITER: 100000 32 | WARMUP_ITERS: 200 33 | IMS_PER_BATCH: 4 34 | BASE_LR: 0.01 35 | WEIGHT_DECAY: 0.0005 36 | BIAS_LR_FACTOR: 2.0 37 | WEIGHT_DECAY_BIAS: 0.0 38 | -------------------------------------------------------------------------------- /configs/PascalVOC-Detection/WSOVOD_MRRP_WSR_18_DC5_VOC12_1x.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "WSOVOD_MRRP_WSR_18_DC5_1x.yaml" 2 | DATASETS: 3 | TRAIN: ('voc_2012_train', 'voc_2012_val') 4 | PROPOSAL_FILES_TRAIN: ('datasets/proposals/sam_voc_2012_train_d2.pkl', 'datasets/proposals/sam_voc_2012_val_d2.pkl') 5 | TEST: ('voc_2007_test',) 6 | PROPOSAL_FILES_TEST: ('datasets/proposals/sam_voc_2007_test_d2.pkl', ) -------------------------------------------------------------------------------- /configs/PascalVOC-Detection/WSOVOD_MRRP_WSR_50_DC5_1x.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "Base-RCNN-DilatedC5.yaml" 2 | MODEL: 3 | WEIGHTS: "models/DRN-WSOD/resnet50_ws_model_120_d2.pkl" 4 | BACKBONE: 5 | NAME: "build_mrrp_wsl_resnet_backbone" 6 | FREEZE_AT: 5 7 | MRRP: 8 | MRRP_ON: True 9 | NUM_BRANCH: 3 10 | BRANCH_DILATIONS: [1, 2, 4] 11 | TEST_BRANCH_IDX: -1 12 | MRRP_STAGE: "res5" 13 | RESNETS: 14 | DEPTH: 50 15 | RES5_DILATION: 2 16 | OUT_FEATURES: ["res5"] 17 | ANCHOR_GENERATOR: 18 | SIZES: [[32, 64], [128, 256], [512, 768]] 19 | ASPECT_RATIOS: [[1.0, 2.0, 0.5]] 20 | ROI_HEADS: 21 | IN_FEATURES: ["res5"] 22 | ROI_BOX_HEAD: 23 | POOLER_TYPE: "ROILoopPool" 24 | NUM_CONV: 0 25 | NUM_FC: 2 26 | DAN_DIM: [4096, 4096] 27 | BBOX_REG_LOSS_TYPE: "smooth_l1_weighted" 28 | SOLVER: 29 | STEPS: (70000,) 30 | MAX_ITER: 100000 31 | WARMUP_ITERS: 200 32 | IMS_PER_BATCH: 4 33 | BASE_LR: 0.01 34 | WEIGHT_DECAY: 0.0005 35 | BIAS_LR_FACTOR: 2.0 36 | WEIGHT_DECAY_BIAS: 0.0 37 | -------------------------------------------------------------------------------- /configs/PascalVOC-Detection/WSOVOD_MRRP_WSR_50_DC5_VOC12_1x.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "WSOVOD_MRRP_WSR_50_DC5_1x.yaml" 2 | DATASETS: 3 | TRAIN: ('voc_2012_train', 'voc_2012_val') 4 | PROPOSAL_FILES_TRAIN: ('datasets/proposals/sam_voc_2012_train_d2.pkl', 'datasets/proposals/sam_voc_2012_val_d2.pkl') 5 | TEST: ('voc_2007_test',) 6 | PROPOSAL_FILES_TEST: ('datasets/proposals/sam_voc_2007_test_d2.pkl', ) -------------------------------------------------------------------------------- /configs/PascalVOC-Detection/WSOVOD_V_16_DC5_1x.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: 
"Base-RCNN-DilatedC5.yaml" 2 | MODEL: 3 | WEIGHTS: "models/VGG/VGG_ILSVRC_16_layers_v1_d2.pkl" 4 | PIXEL_MEAN: [103.939, 116.779, 123.68] 5 | BACKBONE: 6 | NAME: "build_vgg_backbone" 7 | FREEZE_AT: 5 8 | VGG: 9 | DEPTH: 16 10 | CONV5_DILATION: 2 11 | ANCHOR_GENERATOR: 12 | SIZES: [32, 64, 128, 256, 512, 768] 13 | ASPECT_RATIOS: [[1.0, 2.0, 0.5]] 14 | ROI_HEADS: 15 | IN_FEATURES: ["plain5"] 16 | SOLVER: 17 | STEPS: (70000,) 18 | MAX_ITER: 100000 19 | WARMUP_ITERS: 200 20 | IMS_PER_BATCH: 4 21 | BASE_LR: 0.001 22 | WEIGHT_DECAY: 0.0005 23 | BIAS_LR_FACTOR: 2.0 24 | WEIGHT_DECAY_BIAS: 0.0 25 | -------------------------------------------------------------------------------- /configs/PascalVOC-Detection/WSOVOD_V_16_DC5_VOC12_1x.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "WSOVOD_V_16_DC5_1x.yaml" 2 | DATASETS: 3 | TRAIN: ('voc_2012_train', 'voc_2012_val') 4 | PROPOSAL_FILES_TRAIN: ('datasets/proposals/sam_voc_2012_train_d2.pkl', 'datasets/proposals/sam_voc_2012_val_d2.pkl') 5 | TEST: ('voc_2007_test',) 6 | PROPOSAL_FILES_TEST: ('datasets/proposals/sam_voc_2007_test_d2.pkl', ) -------------------------------------------------------------------------------- /configs/PascalVOC-Detection/WSOVOD_WSR_18_DC5_1x.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "Base-RCNN-DilatedC5.yaml" 2 | MODEL: 3 | WEIGHTS: "models/DRN-WSOD/resnet18_ws_model_120_d2.pkl" 4 | BACKBONE: 5 | NAME: "build_wsl_resnet_backbone" 6 | FREEZE_AT: 5 7 | RESNETS: 8 | DEPTH: 18 9 | RES5_DILATION: 2 10 | RES2_OUT_CHANNELS: 64 11 | OUT_FEATURES: ["res5"] 12 | ANCHOR_GENERATOR: 13 | SIZES: [32, 64, 128, 256, 512, 768] 14 | ASPECT_RATIOS: [[1.0, 2.0, 0.5]] 15 | ROI_HEADS: 16 | IN_FEATURES: ["res5"] 17 | SOLVER: 18 | STEPS: (70000,) 19 | MAX_ITER: 100000 20 | WARMUP_ITERS: 200 21 | IMS_PER_BATCH: 4 22 | BASE_LR: 0.01 23 | WEIGHT_DECAY: 0.0005 24 | BIAS_LR_FACTOR: 2.0 25 | WEIGHT_DECAY_BIAS: 0.0 26 | -------------------------------------------------------------------------------- /configs/PascalVOC-Detection/WSOVOD_WSR_18_DC5_VOC12_1x.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "WSOVOD_WSR_18_DC5_1x.yaml" 2 | DATASETS: 3 | TRAIN: ('voc_2012_train', 'voc_2012_val') 4 | PROPOSAL_FILES_TRAIN: ('datasets/proposals/sam_voc_2012_train_d2.pkl', 'datasets/proposals/sam_voc_2012_val_d2.pkl') 5 | TEST: ('voc_2007_test',) 6 | PROPOSAL_FILES_TEST: ('datasets/proposals/sam_voc_2007_test_d2.pkl', ) -------------------------------------------------------------------------------- /configs/PascalVOC-Detection/WSOVOD_WSR_50_DC5_1x.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "Base-RCNN-DilatedC5.yaml" 2 | MODEL: 3 | WEIGHTS: "models/DRN-WSOD/resnet50_ws_model_120_d2.pkl" 4 | BACKBONE: 5 | NAME: "build_wsl_resnet_backbone" 6 | FREEZE_AT: 5 7 | RESNETS: 8 | DEPTH: 50 9 | RES5_DILATION: 2 10 | OUT_FEATURES: ["res5"] 11 | ANCHOR_GENERATOR: 12 | SIZES: [32, 64, 128, 256, 512, 768] 13 | ASPECT_RATIOS: [[1.0, 2.0, 0.5]] 14 | ROI_HEADS: 15 | IN_FEATURES: ["res5"] 16 | SOLVER: 17 | STEPS: (70000,) 18 | MAX_ITER: 100000 19 | WARMUP_ITERS: 200 20 | IMS_PER_BATCH: 4 21 | BASE_LR: 0.01 22 | WEIGHT_DECAY: 0.0005 23 | BIAS_LR_FACTOR: 2.0 24 | WEIGHT_DECAY_BIAS: 0.0 25 | -------------------------------------------------------------------------------- /configs/PascalVOC-Detection/WSOVOD_WSR_50_DC5_VOC12_1x.yaml: 
-------------------------------------------------------------------------------- 1 | _BASE_: "WSOVOD_WSR_50_DC5_1x.yaml" 2 | DATASETS: 3 | TRAIN: ('voc_2012_train', 'voc_2012_val') 4 | PROPOSAL_FILES_TRAIN: ('datasets/proposals/sam_voc_2012_train_d2.pkl', 'datasets/proposals/sam_voc_2012_val_d2.pkl') 5 | TEST: ('voc_2007_test',) 6 | PROPOSAL_FILES_TEST: ('datasets/proposals/sam_voc_2007_test_d2.pkl', ) -------------------------------------------------------------------------------- /datasets/README.md: -------------------------------------------------------------------------------- 1 | # Use Builtin Datasets 2 | 3 | A dataset can be used by accessing [DatasetCatalog](https://detectron2.readthedocs.io/modules/data.html#detectron2.data.DatasetCatalog) 4 | for its data, or [MetadataCatalog](https://detectron2.readthedocs.io/modules/data.html#detectron2.data.MetadataCatalog) for its metadata (class names, etc). 5 | This document explains how to setup the builtin datasets so they can be used by the above APIs. 6 | [Use Custom Datasets](https://detectron2.readthedocs.io/tutorials/datasets.html) gives a deeper dive on how to use `DatasetCatalog` and `MetadataCatalog`, 7 | and how to add new datasets to them. 8 | 9 | Detectron2 has builtin support for a few datasets. 10 | The datasets are assumed to exist in a directory specified by the environment variable 11 | `DETECTRON2_DATASETS`. 12 | Under this directory, detectron2 will look for datasets in the structure described below, if needed. 13 | ``` 14 | $DETECTRON2_DATASETS/ 15 | coco/ 16 | lvis/ 17 | ILSVRC2012/ 18 | VOC20{07,12}/ 19 | ``` 20 | 21 | You can set the location for builtin datasets by `export DETECTRON2_DATASETS=/path/to/datasets`. 22 | If left unset, the default is `./datasets` relative to your current working directory. 23 | 24 | ## Expected dataset structure for COCO instance/keypoint detection: 25 | 26 | ``` 27 | coco/ 28 | annotations/ 29 | instances_{train,val}2017.json 30 | person_keypoints_{train,val}2017.json 31 | {train,val}2017/ 32 | # image files that are mentioned in the corresponding json 33 | ``` 34 | 35 | You can use the 2014 version of the dataset as well. 36 | 37 | Some of the builtin tests (`dev/run_*_tests.sh`) uses a tiny version of the COCO dataset, 38 | which you can download with `./prepare_for_tests.sh`. 39 | 40 | 41 | ## Expected dataset structure for LVIS instance segmentation: 42 | ``` 43 | coco/ 44 | {train,val,test}2017/ 45 | lvis/ 46 | lvis_v0.5_{train,val}.json 47 | lvis_v0.5_image_info_test.json 48 | lvis_v1_{train,val}.json 49 | lvis_v1_image_info_test{,_challenge}.json 50 | ``` 51 | 52 | Install lvis-api by: 53 | ``` 54 | pip install git+https://github.com/lvis-dataset/lvis-api.git 55 | ``` 56 | 57 | To evaluate models trained on the COCO dataset using LVIS annotations, 58 | run `python prepare_cocofied_lvis.py` to prepare "cocofied" LVIS annotations. 59 | 60 | ## Expected dataset structure for Pascal VOC: 61 | ``` 62 | VOC20{07,12}/ 63 | Annotations/ 64 | ImageSets/ 65 | Main/ 66 | trainval.txt 67 | test.txt 68 | # train.txt or val.txt, if you use these splits 69 | JPEGImages/ 70 | ``` 71 | 72 | ## Expected dataset structure for ILSVRC2012: 73 | Go to [this link](https://www.image-net.org/download-images.php) to download tar files (for training and validation) 74 | ``` 75 | ├── ILSVRC2012_img_train.tar 76 | └── ILSVRC2012_img_val.tar 77 | ``` 78 | Run bash scripts/extract_ilsvrc.sh for handling the above compressed files. 
79 | Make sure they are arranged as below: 80 | ``` 81 | ./train 82 | ├── n07693725 83 | ├── ... 84 | └── n07614500 85 | ./val 86 | ├── n01440764 87 | ├── ... 88 | └── n04458633 89 | ``` 90 | Run the scripts below to generate the JSON annotations. 91 | ``` 92 | bash scripts/prepare_ilsvrc.sh datasets/ILSVRC2012/val/ output/temp/ilsvrc_2012_val_info.json datasets/ILSVRC2012/ILSVRC2012_img_val.json tools/ilsvrc2012_classes_name.txt datasets/ILSVRC2012/ILSVRC2012_img_val_converted.json 93 | 94 | bash scripts/prepare_ilsvrc.sh datasets/ILSVRC2012/train/ output/temp/ilsvrc_2012_train_info.json datasets/ILSVRC2012/ILSVRC2012_img_train.json tools/ilsvrc2012_classes_name.txt datasets/ILSVRC2012/ILSVRC2012_img_train_converted.json 95 | ``` 96 | The converted dataset should then be organized as below: 97 | ``` 98 | ILSVRC2012/ 99 | ILSVRC2012_img_train_converted.json 100 | ILSVRC2012_img_val_converted.json 101 | {train,val}/ 102 | n01440764/*.JPEG # image files that are mentioned in the corresponding json 103 | ...... 104 | n15075141/*.JPEG # image files that are mentioned in the corresponding json 105 | 106 | ``` -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | git+https://github.com/openai/CLIP.git 2 | git+https://github.com/facebookresearch/segment-anything.git 3 | git+https://github.com/facebookresearch/detectron2.git 4 | opencv-python 5 | scikit-image 6 | shapely 7 | graphviz 8 | timm 9 | opencv-contrib-python 10 | Pillow==9.5.0 -------------------------------------------------------------------------------- /scripts/extract_ilsvrc.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | # script to extract ImageNet dataset 4 | # ILSVRC2012_img_train.tar (about 138 GB) 5 | # ILSVRC2012_img_val.tar (about 6.3 GB) 6 | # make sure ILSVRC2012_img_train.tar & ILSVRC2012_img_val.tar are in your current directory 7 | # 8 | # https://github.com/facebook/fb.resnet.torch/blob/master/INSTALL.md 9 | # 10 | # train/ 11 | # ├── n01440764 12 | # │ ├── n01440764_10026.JPEG 13 | # │ ├── n01440764_10027.JPEG 14 | # │ ├── ...... 15 | # ├── ...... 16 | # val/ 17 | # ├── n01440764 18 | # │ ├── ILSVRC2012_val_00000293.JPEG 19 | # │ ├── ILSVRC2012_val_00002138.JPEG 20 | # │ ├── ...... 21 | # ├── ...... 22 | # 23 | # 24 | # Extract the training data: 25 | # 26 | mkdir train && mv ILSVRC2012_img_train.tar train/ && cd train 27 | tar -xvf ILSVRC2012_img_train.tar && rm -f ILSVRC2012_img_train.tar 28 | find . -name "*.tar" | while read NAME ; do mkdir -p "${NAME%.tar}"; tar -xvf "${NAME}" -C "${NAME%.tar}"; rm -f "${NAME}"; done 29 | cd .. 30 | # 31 | # Extract the validation data and move images to subfolders: 32 | # 33 | mkdir val && mv ILSVRC2012_img_val.tar val/ && cd val && tar -xvf ILSVRC2012_img_val.tar 34 | wget -qO- https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh | bash 35 | # 36 | # Check total files after extraction 37 | # 38 | # $ find train/ -name "*.JPEG" | wc -l 39 | # 1281167 40 | # $ find val/ -name "*.JPEG" | wc -l 41 | # 50000 42 | # -------------------------------------------------------------------------------- /scripts/generate_sam_proposals_cuda.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Assign the first argument to GPUS, the number of GPUs to use. 4 | GPUS=$1 5 | 6 | # Set the number of nodes, defaulting to 1 if not set.
7 | NNODES=${NNODES:-1} 8 | 9 | # Set the rank of the node, defaulting to 0 if not set. 10 | NODE_RANK=${NODE_RANK:-0} 11 | 12 | # Set the port, defaulting to a random value within a specific range if not set. 13 | PORT=${PORT:-$((28500 + $RANDOM % 2000))} 14 | 15 | # Set the master address, defaulting to localhost if not set. 16 | MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} 17 | 18 | # Check if torchrun is available. 19 | if command -v torchrun &> /dev/null 20 | then 21 | echo "Using torchrun mode." 22 | # Set environment variables for Python path and thread settings, then execute the Python script with torchrun. 23 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH OMP_NUM_THREADS=1 MKL_NUM_THREADS=1 \ 24 | torchrun --nnodes=${NNODES} \ 25 | --node_rank=${NODE_RANK} \ 26 | --master_addr=${MASTER_ADDR} \ 27 | --master_port=${PORT} \ 28 | --nproc_per_node=${GPUS} \ 29 | tools/generate_sam_proposals_cuda.py "${@:2}" 30 | else 31 | echo "Using launch mode." 32 | # Fallback to using python -m torch.distributed.launch if torchrun is not available. 33 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH OMP_NUM_THREADS=1 MKL_NUM_THREADS=1 \ 34 | python -m torch.distributed.launch \ 35 | --nnodes=${NNODES} \ 36 | --node_rank=${NODE_RANK} \ 37 | --master_addr=${MASTER_ADDR} \ 38 | --master_port=${PORT} \ 39 | --nproc_per_node=${GPUS} \ 40 | tools/generate_sam_proposals_cuda.py "${@:2}" 41 | fi -------------------------------------------------------------------------------- /scripts/prepare_ilsvrc.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | set -e 4 | set -x 5 | temp="output/temp" 6 | if [ ! -d "$temp" ]; then 7 | # If the directory does not exist, create it recursively with mkdir -p 8 | mkdir -p "$temp" 9 | echo "Directory $temp created." 10 | else 11 | echo "Directory $temp already exists." 12 | fi 13 | 14 | img_root="$1" 15 | info_json="$2" 16 | out_file="$3" 17 | class_name="$4" 18 | converted_file="$5" 19 | 20 | python tools/ilsvrc_info.py --img-root ${img_root} --out-file ${info_json} 21 | python tools/ilsvrc_folder.py --out-file ${out_file} --info-json ${info_json} 22 | python tools/convert_ilsvrc_classes_name.py --ann ${out_file} --f ${class_name} --output ${converted_file} 23 | -------------------------------------------------------------------------------- /scripts/train_script.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | set -x 5 | 6 | train_file_path="$1" 7 | config_file_path="$2" 8 | GPU_NUM="$3" 9 | timestamp="$4" 10 | rest_args="${@:5}" 11 | PORT=${PORT:-$((28500 + $RANDOM % 2000))} 12 | 13 | if [ -z "$timestamp" ] 14 | then 15 | timestamp="`date +'%Y%m%d_%H%M%S'`" 16 | fi 17 | 18 | python ${train_file_path} --dist-url="tcp://127.0.0.1:${PORT}" --num-gpus ${GPU_NUM} --resume --config-file ${config_file_path} OUTPUT_DIR output/${config_file_path:0:-5}_${timestamp} ${rest_args} 19 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | 4 | import glob 5 | import os 6 | import shutil 7 | from os import path 8 | from typing import List 9 | 10 | import torch 11 | from setuptools import find_packages, setup 12 | from torch.utils.cpp_extension import CUDA_HOME, CppExtension, CUDAExtension 13 | 14 | torch_ver = [int(x) for x in torch.__version__.split(".")[:2]] 15 | assert torch_ver >= [1, 8], "Requires PyTorch >= 1.8" 16 | 17 | 18 | def get_version(): 19 | init_py_path = path.join(path.abspath(path.dirname(__file__)), "wsovod", "__init__.py") 20 | init_py = open(init_py_path, "r").readlines() 21 | version_line = [l.strip() for l in init_py if l.startswith("__version__")][0] 22 | version = version_line.split("=")[-1].strip().strip("'\"") 23 | 24 | # The following is used to build release packages. 25 | # Users should never use it. 26 | suffix = os.getenv("D2_VERSION_SUFFIX", "") 27 | version = version + suffix 28 | if os.getenv("BUILD_NIGHTLY", "0") == "1": 29 | from datetime import datetime 30 | 31 | date_str = datetime.today().strftime("%y%m%d") 32 | version = version + ".dev" + date_str 33 | 34 | new_init_py = [l for l in init_py if not l.startswith("__version__")] 35 | new_init_py.append('__version__ = "{}"\n'.format(version)) 36 | with open(init_py_path, "w") as f: 37 | f.write("".join(new_init_py)) 38 | return version 39 | 40 | 41 | def get_extensions(): 42 | this_dir = path.dirname(path.abspath(__file__)) 43 | extensions_dir = path.join(this_dir, "wsovod", "layers") 44 | 45 | main_source = path.join(extensions_dir, "vision.cpp") 46 | sources = glob.glob(path.join(extensions_dir, "**", "*.cpp")) 47 | 48 | from torch.utils.cpp_extension import ROCM_HOME 49 | 50 | is_rocm_pytorch = ( 51 | True if ((torch.version.hip is not None) and (ROCM_HOME is not None)) else False 52 | ) 53 | if is_rocm_pytorch: 54 | assert torch_ver >= [1, 8], "ROCM support requires PyTorch >= 1.8!" 55 | 56 | # common code between cuda and rocm platforms, for hipify version [1,0,0] and later. 
57 | source_cuda = glob.glob(path.join(extensions_dir, "**", "*.cu")) + glob.glob( 58 | path.join(extensions_dir, "*.cu") 59 | ) 60 | sources = [main_source] + sources 61 | 62 | extension = CppExtension 63 | 64 | extra_compile_args = {"cxx": []} 65 | define_macros = [] 66 | 67 | if (torch.cuda.is_available() and ((CUDA_HOME is not None) or is_rocm_pytorch)) or os.getenv( 68 | "FORCE_CUDA", "0" 69 | ) == "1": 70 | extension = CUDAExtension 71 | sources += source_cuda 72 | 73 | if not is_rocm_pytorch: 74 | define_macros += [("WITH_CUDA", None)] 75 | extra_compile_args["nvcc"] = [ 76 | "-O3", 77 | "-DCUDA_HAS_FP16=1", 78 | "-D__CUDA_NO_HALF_OPERATORS__", 79 | "-D__CUDA_NO_HALF_CONVERSIONS__", 80 | "-D__CUDA_NO_HALF2_OPERATORS__", 81 | ] 82 | else: 83 | define_macros += [("WITH_HIP", None)] 84 | extra_compile_args["nvcc"] = [] 85 | 86 | if torch_ver < [1, 7]: 87 | # supported by https://github.com/pytorch/pytorch/pull/43931 88 | CC = os.environ.get("CC", None) 89 | if CC is not None: 90 | extra_compile_args["nvcc"].append("-ccbin={}".format(CC)) 91 | 92 | include_dirs = [extensions_dir] 93 | 94 | ext_modules = [ 95 | extension( 96 | "wsovod._C", 97 | sources, 98 | include_dirs=include_dirs, 99 | define_macros=define_macros, 100 | extra_compile_args=extra_compile_args, 101 | ) 102 | ] 103 | 104 | return ext_modules 105 | 106 | 107 | # For projects that are relative small and provide features that are very close 108 | # to detectron2's core functionalities, we install them under detectron2.projects 109 | PROJECTS = {} 110 | 111 | setup( 112 | name="wsovod", 113 | version=get_version(), 114 | author="Hunterj Lin", 115 | url="https://github.com/HunterJ-Lin/WSOVOD", 116 | description="WSOVOD is next-generation research " 117 | "platform for open-vocabulary object detection and segmentation.", 118 | packages=find_packages(exclude=("configs", "tests*")) + list(PROJECTS.keys()), 119 | package_dir=PROJECTS, 120 | package_data={}, 121 | python_requires=">=3.8", 122 | install_requires=["graphviz"], 123 | ext_modules=get_extensions(), 124 | cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension}, 125 | ) -------------------------------------------------------------------------------- /teaser/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HunterJ-Lin/WSOVOD/74e7647faaa0ab5de37a351532f65c57654f66be/teaser/framework.png -------------------------------------------------------------------------------- /tools/convert_ilsvrc_classes_name.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | 4 | if __name__ == '__main__': 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument('--ann', type=str, default='datasets/ILSVRC2012/ILSVRC2012_img_val.json') 7 | parser.add_argument('--f', type=str, default='tools/ilsvrc2012_classes_name.txt') 8 | parser.add_argument('--output',type=str, default='datasets/ILSVRC2012/ILSVRC2012_img_val_converted.json') 9 | args = parser.parse_args() 10 | with open(args.f,'r') as f: 11 | lines = f.readlines() 12 | d = {} 13 | for line in lines: 14 | k,v = line.split(':') 15 | d[k.strip()] = v.split(',')[0].strip() 16 | 17 | print('Loading', args.ann) 18 | data = json.load(open(args.ann, 'r')) 19 | data['categories'] = [{'id':cat['id'],'name':d[cat['name']]} for cat in data['categories']] 20 | print('Saving to', args.output) 21 | json.dump(data, open(args.output, 'w')) 
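For reference, here is a minimal sketch of the wnid-to-name conversion that `tools/convert_ilsvrc_classes_name.py` above performs. The two sample entries are illustrative stand-ins for lines of `tools/ilsvrc2012_classes_name.txt`, which the script assumes to contain one `wnid: name1, name2, ...` entry per class (only the first name is kept):

```python
# Minimal sketch of the parsing performed by convert_ilsvrc_classes_name.py.
# The sample lines below are illustrative; the real mapping file is
# tools/ilsvrc2012_classes_name.txt, with one "wnid: names" entry per class.
sample_lines = [
    "n01440764: tench, Tinca tinca",
    "n01443537: goldfish, Carassius auratus",
]

wnid_to_name = {}
for line in sample_lines:
    wnid, names = line.split(":")
    # keep only the first comma-separated synonym as the readable category name
    wnid_to_name[wnid.strip()] = names.split(",")[0].strip()

print(wnid_to_name)  # {'n01440764': 'tench', 'n01443537': 'goldfish'}
```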
-------------------------------------------------------------------------------- /tools/generate_class_text_embedding.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import cv2 4 | from six.moves import cPickle as pickle 5 | from tqdm import tqdm 6 | import numpy as np 7 | from detectron2.data.catalog import DatasetCatalog 8 | from detectron2.utils.file_io import PathManager 9 | import os 10 | import clip 11 | from detectron2.data import MetadataCatalog 12 | import torch 13 | 14 | PROMPTS = [ 15 | 'There is a {category} in the scene.', 16 | 'There is my {category} in the scene.', 17 | 'There is the {category} in the scene.', 18 | 'There is one {category} in the scene.', 19 | 'a photo of a {category} in the scene.', 20 | 'a photo of my {category} in the scene.', 21 | 'a photo of the {category} in the scene.', 22 | 'a photo of one {category} in the scene.', 23 | 'itap of a {category}.', 24 | 'itap of my {category}.', 25 | 'itap of the {category}.', 26 | 'itap of one {category}.', 27 | 'a photo of a {category}.', 28 | 'a photo of my {category}.', 29 | 'a photo of the {category}.', 30 | 'a photo of one {category}.', 31 | 'a good photo of a {category}.', 32 | 'a good photo of the {category}.', 33 | 'a bad photo of a {category}.', 34 | 'a bad photo of the {category}.', 35 | 'a photo of a nice {category}.', 36 | 'a photo of the nice {category}.', 37 | 'a photo of a cool {category}.', 38 | 'a photo of the cool {category}.', 39 | 'a photo of a weird {category}.', 40 | 'a photo of the weird {category}.', 41 | 'a photo of a small {category}.', 42 | 'a photo of the small {category}.', 43 | 'a photo of a large {category}.', 44 | 'a photo of the large {category}.', 45 | 'a photo of a clean {category}.', 46 | 'a photo of the clean {category}.', 47 | 'a photo of a dirty {category}.', 48 | 'a photo of the dirty {category}.', 49 | 'a bright photo of a {category}.', 50 | 'a bright photo of the {category}.', 51 | 'a dark photo of a {category}.', 52 | 'a dark photo of the {category}.', 53 | 'a photo of a hard to see {category}.', 54 | 'a photo of the hard to see {category}.', 55 | 'a low resolution photo of a {category}.', 56 | 'a low resolution photo of the {category}.', 57 | 'a cropped photo of a {category}.', 58 | 'a cropped photo of the {category}.', 59 | 'a close-up photo of a {category}.', 60 | 'a close-up photo of the {category}.', 61 | 'a jpeg corrupted photo of a {category}.', 62 | 'a jpeg corrupted photo of the {category}.', 63 | 'a blurry photo of a {category}.', 64 | 'a blurry photo of the {category}.', 65 | 'a pixelated photo of a {category}.', 66 | 'a pixelated photo of the {category}.', 67 | ] 68 | 69 | if __name__ == '__main__': 70 | parser = argparse.ArgumentParser() 71 | parser.add_argument('--dataset-name', type=str, default='coco_2017_val') 72 | parser.add_argument('--categories', type=str, default='') 73 | # parser.add_argument('--mode-type', type=str, default='ViT-L/14/32') 74 | parser.add_argument('--model-type', type=str, default='ViT-B/32') 75 | parser.add_argument('--prompt-type', type=str, default='single') 76 | parser.add_argument('--output', type=str, default='models/coco_text_embedding_single_prompt.pkl') 77 | args = parser.parse_args() 78 | # load clip model 79 | clip_model, clip_preprocess = clip.load(args.model_type, 'cpu', jit=False) 80 | clip_model = clip_model.eval() 81 | d = [] 82 | 83 | thing_classes = MetadataCatalog.get(args.dataset_name).thing_classes 84 | if thing_classes is None: 85 | 
thing_classes = args.categories.split(',') 86 | assert len(thing_classes)>0 87 | 88 | if args.prompt_type == 'mutiple': 89 | for category in thing_classes: 90 | text_inputs = torch.cat([clip.tokenize(prompt.format(**{'category':category})) for prompt in PROMPTS]).to('cpu') 91 | text_embedding = clip_model.encode_text(text_inputs).float() 92 | d.append(text_embedding.mean(0, keepdim=True)) # average the prompt embeddings into one vector per category 93 | text_embeddings = torch.cat(d) 94 | else: 95 | text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}.") for c in thing_classes]).to('cpu') 96 | text_embeddings = clip_model.encode_text(text_inputs).float() 97 | 98 | print(text_embeddings.shape) 99 | print('save to '+args.output) 100 | with open(args.output, "wb") as f: 101 | pickle.dump(text_embeddings, f, pickle.HIGHEST_PROTOCOL) -------------------------------------------------------------------------------- /tools/generate_class_text_embedding_cuda.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import cv2 4 | from six.moves import cPickle as pickle 5 | from tqdm import tqdm 6 | import numpy as np 7 | from detectron2.data.catalog import DatasetCatalog 8 | from detectron2.utils.file_io import PathManager 9 | import os 10 | import clip 11 | from detectron2.data import MetadataCatalog 12 | from detectron2.utils.comm import is_main_process, get_world_size, get_local_rank 13 | import torch 14 | import torch.distributed as dist 15 | 16 | PROMPTS = [ 17 | 'There is a {category} in the scene.', 18 | 'There is my {category} in the scene.', 19 | 'There is the {category} in the scene.', 20 | 'There is one {category} in the scene.', 21 | 'a photo of a {category} in the scene.', 22 | 'a photo of my {category} in the scene.', 23 | 'a photo of the {category} in the scene.', 24 | 'a photo of one {category} in the scene.', 25 | 'itap of a {category}.', 26 | 'itap of my {category}.', 27 | 'itap of the {category}.', 28 | 'itap of one {category}.', 29 | 'a photo of a {category}.', 30 | 'a photo of my {category}.', 31 | 'a photo of the {category}.', 32 | 'a photo of one {category}.', 33 | 'a good photo of a {category}.', 34 | 'a good photo of the {category}.', 35 | 'a bad photo of a {category}.', 36 | 'a bad photo of the {category}.', 37 | 'a photo of a nice {category}.', 38 | 'a photo of the nice {category}.', 39 | 'a photo of a cool {category}.', 40 | 'a photo of the cool {category}.', 41 | 'a photo of a weird {category}.', 42 | 'a photo of the weird {category}.', 43 | 'a photo of a small {category}.', 44 | 'a photo of the small {category}.', 45 | 'a photo of a large {category}.', 46 | 'a photo of the large {category}.', 47 | 'a photo of a clean {category}.', 48 | 'a photo of the clean {category}.', 49 | 'a photo of a dirty {category}.', 50 | 'a photo of the dirty {category}.', 51 | 'a bright photo of a {category}.', 52 | 'a bright photo of the {category}.', 53 | 'a dark photo of a {category}.', 54 | 'a dark photo of the {category}.', 55 | 'a photo of a hard to see {category}.', 56 | 'a photo of the hard to see {category}.', 57 | 'a low resolution photo of a {category}.', 58 | 'a low resolution photo of the {category}.', 59 | 'a cropped photo of a {category}.', 60 | 'a cropped photo of the {category}.', 61 | 'a close-up photo of a {category}.', 62 | 'a close-up photo of the {category}.', 63 | 'a jpeg corrupted photo of a {category}.', 64 | 'a jpeg corrupted photo of the {category}.', 65 | 'a blurry photo of a {category}.', 66 | 'a blurry photo of the {category}.', 67 | 'a pixelated photo of a
{category}.', 68 | 'a pixelated photo of the {category}.', 69 | ] 70 | 71 | if __name__ == '__main__': 72 | parser = argparse.ArgumentParser() 73 | parser.add_argument('--dataset-name', type=str, default='voc_2007_val') 74 | parser.add_argument('--categories', type=str, default='') 75 | parser.add_argument('--bs', type=int, default=32) 76 | # parser.add_argument('--mode-type', type=str, default='ViT-L/14/32') 77 | parser.add_argument('--model-type', type=str, default='ViT-B/32') 78 | parser.add_argument('--prompt-type', type=str, default='single') 79 | parser.add_argument('--output', type=str, default='models/voc_text_embedding_single_prompt_cuda.pkl') 80 | args = parser.parse_args() 81 | 82 | # load clip model 83 | clip_model, clip_preprocess = clip.load(args.model_type, 'cuda', jit=False) 84 | clip_model = clip_model.eval() 85 | 86 | thing_classes = MetadataCatalog.get(args.dataset_name).thing_classes 87 | if thing_classes is None: 88 | thing_classes = args.categories.split(',') 89 | assert len(thing_classes)>0 90 | 91 | descriptions = [] 92 | candidates = [] 93 | for cls_name in thing_classes: 94 | if args.prompt_type == 'mutiple': 95 | candidates.append(len(PROMPTS)) 96 | for template in PROMPTS: 97 | description = template.format(**{'category':cls_name}) 98 | descriptions.append(description) 99 | else: 100 | candidates.append(1) 101 | descriptions.append(f"a photo of a {cls_name}.") 102 | 103 | with torch.no_grad(): 104 | tot = len(descriptions) 105 | bs = args.bs 106 | nb = tot // bs 107 | if tot % bs != 0: 108 | nb += 1 109 | text_embeddings_list = [] 110 | for i in range(nb): 111 | local_descriptions = descriptions[i * bs: (i + 1) * bs] 112 | text_inputs = torch.cat([clip.tokenize(ds) for ds in local_descriptions]).to('cuda') 113 | local_text_embeddings = clip_model.encode_text(text_inputs).to(device='cpu').float() 114 | text_embeddings_list.append(local_text_embeddings) 115 | text_embeddings = torch.cat(text_embeddings_list) 116 | 117 | dim = text_embeddings.shape[-1] 118 | candidate_tot = sum(candidates) 119 | text_embeddings = text_embeddings.split(candidates, dim=0) 120 | if args.prompt_type == 'mutiple': 121 | text_embeddings = [text_embedding.mean(0).unsqueeze(0) for text_embedding in text_embeddings] 122 | 123 | text_embeddings = torch.cat(text_embeddings) 124 | print('save to '+args.output) 125 | with open(args.output, "wb") as f: 126 | pickle.dump(text_embeddings, f, pickle.HIGHEST_PROTOCOL) -------------------------------------------------------------------------------- /tools/generate_sam_proposals_cuda.py: -------------------------------------------------------------------------------- 1 | from math import e 2 | import cv2 3 | import os 4 | import argparse 5 | import numpy as np 6 | import pickle 7 | from detectron2.data.catalog import DatasetCatalog 8 | import torch 9 | import torch.distributed as dist 10 | from detectron2.utils.comm import synchronize, get_world_size, get_rank 11 | from wsovod.data.datasets import builtin 12 | 13 | # Import DistributedDataParallel 14 | from torch.nn.parallel import DistributedDataParallel as DDP 15 | 16 | def process_image(mask_generator, image_info): 17 | image_path = image_info['file_name'] 18 | print("processing ", image_path) 19 | image = cv2.imread(image_path) 20 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 21 | try: 22 | masks = mask_generator.generate(image) 23 | except Exception as e: 24 | print(f"Error processing image {image_path}: {e}") 25 | mask_generator.predictor.model.to('cpu') 26 | masks = 
mask_generator.generate(image) 27 | mask_generator.predictor.model.to(f'cuda:{get_rank()}') 28 | proposals = [] 29 | scores = [] 30 | for instance in masks: 31 | score = instance['predicted_iou'] * instance['stability_score'] 32 | if score > 1.0: 33 | score = 1.0 34 | bbox = instance['bbox'] 35 | if bbox[2] <= 0 or bbox[3] <= 0: 36 | continue 37 | bbox[2] = bbox[0] + bbox[2] 38 | bbox[3] = bbox[1] + bbox[3] 39 | scores.append(score) 40 | proposals.append(bbox) 41 | proposals = np.array(proposals) 42 | scores = np.array(scores) 43 | return proposals, scores, image_info['image_id'] 44 | 45 | def main(): 46 | parser = argparse.ArgumentParser() 47 | parser.add_argument('--checkpoint', type=str, default='tools/sam_checkpoints/sam_vit_h_4b8939.pth') 48 | parser.add_argument('--model-type', type=str, default='vit_h') 49 | parser.add_argument('--dataset-name', type=str, default='coco_2017_val') 50 | parser.add_argument('--output', type=str, default='datasets/proposals/sam_coco_2017_val_d2.pkl') 51 | parser.add_argument('--points-per-side', type=int, default=32) 52 | parser.add_argument('--pred-iou-thresh', type=float, default=0.86) 53 | parser.add_argument('--stability-score-thresh', type=float, default=0.92) 54 | parser.add_argument('--crop-n-layers', type=int, default=1) 55 | parser.add_argument('--crop-n-points-downscale-factor', type=int, default=2) 56 | parser.add_argument('--min-mask-region-area', type=float, default=20.0) 57 | # Add an argument for the local rank that will be set by torchrun or the PyTorch launcher 58 | parser.add_argument('--local_rank', type=int, default=0) 59 | args = parser.parse_args() 60 | import os 61 | print("Local rank (from environment):", os.environ.get("LOCAL_RANK")) 62 | torch.distributed.init_process_group(backend='nccl', init_method='env://') 63 | dataset_dicts = DatasetCatalog.get(args.dataset_name) 64 | # dataset_dicts = dataset_dicts[:100] for debugging 65 | rank = get_rank() 66 | world_size = get_world_size() 67 | device = torch.device(f'cuda:{rank}') 68 | print(f"Rank: {rank}, World Size: {world_size}, Device: {device}") 69 | from segment_anything import sam_model_registry, SamAutomaticMaskGenerator 70 | sam = sam_model_registry[args.model_type](checkpoint=args.checkpoint) 71 | sam.to(device=device) 72 | mask_generator = SamAutomaticMaskGenerator( 73 | model=sam, 74 | points_per_side=args.points_per_side, 75 | pred_iou_thresh=args.pred_iou_thresh, 76 | stability_score_thresh=args.stability_score_thresh, 77 | crop_n_layers=args.crop_n_layers, 78 | crop_n_points_downscale_factor=args.crop_n_points_downscale_factor, 79 | min_mask_region_area=args.min_mask_region_area, # Requires open-cv to run post-processing 80 | ) 81 | 82 | # Now, adjust the data loading and processing to only work on a subset of the data based on the rank of the process 83 | # This is a simple way to split the dataset, more sophisticated methods might be needed for your use case 84 | subset_size = len(dataset_dicts) // world_size 85 | start_idx = rank * subset_size 86 | end_idx = start_idx + subset_size if rank < world_size - 1 else len(dataset_dicts) 87 | local_dataset_dicts = dataset_dicts[start_idx:end_idx] 88 | 89 | all_boxes = [] 90 | all_scores = [] 91 | all_indexes = [] 92 | if rank == 0: 93 | from tqdm import tqdm 94 | local_dataset_dicts = tqdm(local_dataset_dicts) 95 | with torch.no_grad(): 96 | for image_info in local_dataset_dicts: 97 | boxes, scores, indexes = process_image(mask_generator, image_info) 98 | all_boxes.append(boxes) 99 | all_scores.append(scores) 100 | 
all_indexes.append(indexes) 101 | 102 | # Gather results from all processes 103 | gathered_boxes = [None] * world_size 104 | gathered_scores = [None] * world_size 105 | gathered_indexes = [None] * world_size 106 | torch.cuda.set_device(rank) 107 | dist.barrier() 108 | print(f"Rank: {rank} gathering boxes results...") 109 | dist.all_gather_object(gathered_boxes, all_boxes) 110 | print(f"Rank: {rank} gathered boxes.") 111 | dist.barrier() 112 | print(f"Rank: {rank} gathering scores results...") 113 | dist.all_gather_object(gathered_scores, all_scores) 114 | print(f"Rank: {rank} gathered scores.") 115 | dist.barrier() 116 | print(f"Rank: {rank} gathering indexes results...") 117 | dist.all_gather_object(gathered_indexes, all_indexes) 118 | print(f"Rank: {rank} gathered indexes.") 119 | dist.barrier() 120 | # Collecting and saving the results should be done by one process 121 | if rank == 0: 122 | print("All results gathered.") 123 | # Flatten lists from all processes 124 | all_boxes_flat = [item for sublist in gathered_boxes for item in sublist] 125 | all_scores_flat = [item for sublist in gathered_scores for item in sublist] 126 | all_indexes_flat = [item for sublist in gathered_indexes for item in sublist] 127 | assert len(all_boxes_flat) == len(all_scores_flat) == len(all_indexes_flat) == len(dataset_dicts) 128 | assert set(all_indexes_flat) == set([image_info['image_id'] for image_info in dataset_dicts]) 129 | # Save only in the master process 130 | output_file = args.output 131 | with open(output_file, 'wb') as f: 132 | pickle.dump({'boxes': all_boxes_flat, 'scores': all_scores_flat, 'indexes': all_indexes_flat}, f) 133 | print("Proposal generation and saving completed.") 134 | # Ensure all processes have finished before exiting 135 | dist.barrier() 136 | 137 | if __name__ == '__main__': 138 | main() 139 | -------------------------------------------------------------------------------- /tools/ilsvrc_folder.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import os.path as osp 5 | import pathlib 6 | import random 7 | import xml.dom.minidom 8 | from distutils.util import strtobool 9 | 10 | import numpy as np 11 | from detectron2.data.detection_utils import read_image 12 | # from nltk.corpus import wordnet as wn 13 | from PIL import Image 14 | from tqdm import tqdm 15 | 16 | # import nltk 17 | # nltk.download('wordnet') 18 | 19 | miss_xml = [] 20 | folder_classes = [] 21 | 22 | def parse_xml(file_path): 23 | print("processing: ", file_path) 24 | 25 | dom = xml.dom.minidom.parse(file_path) 26 | document = dom.documentElement 27 | 28 | size = document.getElementsByTagName('size') 29 | bboxes = [] 30 | labels = [] 31 | for item in size: 32 | width = item.getElementsByTagName('width') 33 | width = int(width[0].firstChild.data) 34 | height = item.getElementsByTagName('height') 35 | height = int(height[0].firstChild.data) 36 | 37 | # try: 38 | if True: 39 | object = document.getElementsByTagName('object') 40 | for item in object: 41 | label = item.getElementsByTagName('name') 42 | label = str(label[0].firstChild.data) 43 | 44 | xmin = item.getElementsByTagName('xmin') 45 | xmin = float(xmin[0].firstChild.data) 46 | ymin = item.getElementsByTagName('ymin') 47 | ymin = float(ymin[0].firstChild.data) 48 | xmax = item.getElementsByTagName('xmax') 49 | xmax = float(xmax[0].firstChild.data) 50 | ymax = item.getElementsByTagName('ymax') 51 | ymax = float(ymax[0].firstChild.data) 52 | 53 | bboxes.append([xmin, ymin, 
xmax, ymax]) 54 | labels.append(label) 55 | # except: 56 | # continue 57 | 58 | return bboxes, labels 59 | 60 | 61 | def cvt_annotations(path_h_w, out_file, has_instance, has_segmentation, ann_root): 62 | label_ids = {name: i for i, name in enumerate(folder_classes)} 63 | print('cvt annotations') 64 | print("label_ids: ", label_ids) 65 | 66 | annotations = [] 67 | pbar = enumerate(path_h_w) 68 | pbar = tqdm(pbar) 69 | for i, [img_path, height, width] in pbar: 70 | # if i > 10000: 71 | # break 72 | # print(i, img_path) 73 | 74 | wnid = pathlib.PurePath(img_path).parent.name 75 | 76 | if has_instance: 77 | xml_path = ann_root + img_path[:-5] + ".xml" 78 | if os.path.exists(xml_path): 79 | bboxes, wnids = parse_xml(xml_path) 80 | ## check error wnid 81 | # for wnid_ in wnids: 82 | # assert wnid_ == wnid, [wnids, wnid] 83 | # labels = [label_ids[wnid_] for wnid_ in wnids] 84 | labels = [label_ids[wnid] for wnid_ in wnids] 85 | else: 86 | print("xml file not found", xml_path) 87 | miss_xml.append(xml_path) 88 | # continue 89 | print("genenrate pseudo annotation for", img_path) 90 | bboxes = [[1, 1, width, height]] 91 | labels = [label_ids[wnid]] 92 | 93 | bboxes = np.array(bboxes, ndmin=2) - 1 94 | labels = np.array(labels) 95 | else: 96 | bboxes = np.array(np.zeros((0, 4), dtype=np.float32)) 97 | labels = np.array(np.zeros((0,), dtype=np.int64)) 98 | 99 | annotation = { 100 | "filename": img_path, 101 | "width": width, 102 | "height": height, 103 | "ann": { 104 | "bboxes": bboxes.astype(np.float32), 105 | "labels": labels.astype(np.int64), 106 | "bboxes_ignore": np.zeros((0, 4), dtype=np.float32), 107 | "labels_ignore": np.zeros((0,), dtype=np.int64), 108 | }, 109 | } 110 | 111 | annotations.append(annotation) 112 | annotations = cvt_to_coco_json(annotations, has_segmentation) 113 | 114 | with open(out_file, "w") as f: 115 | json.dump(annotations, f) 116 | 117 | return annotations 118 | 119 | 120 | def cvt_to_coco_json(annotations, has_segmentation): 121 | image_id = 0 122 | annotation_id = 0 123 | coco = dict() 124 | coco["images"] = [] 125 | coco["type"] = "instance" 126 | coco["categories"] = [] 127 | coco["annotations"] = [] 128 | image_set = set() 129 | print('cvt to coco json') 130 | def addAnnItem(annotation_id, image_id, category_id, bbox, difficult_flag): 131 | annotation_item = dict() 132 | if has_segmentation: 133 | annotation_item["segmentation"] = [] 134 | 135 | seg = [] 136 | # bbox[] is x1,y1,x2,y2 137 | # left_top 138 | seg.append(int(bbox[0])) 139 | seg.append(int(bbox[1])) 140 | # left_bottom 141 | seg.append(int(bbox[0])) 142 | seg.append(int(bbox[3])) 143 | # right_bottom 144 | seg.append(int(bbox[2])) 145 | seg.append(int(bbox[3])) 146 | # right_top 147 | seg.append(int(bbox[2])) 148 | seg.append(int(bbox[1])) 149 | 150 | annotation_item["segmentation"].append(seg) 151 | 152 | xywh = np.array([bbox[0], bbox[1], bbox[2] - bbox[0], bbox[3] - bbox[1]]) 153 | annotation_item["area"] = int(xywh[2] * xywh[3]) 154 | if difficult_flag == 1: 155 | annotation_item["ignore"] = 0 156 | annotation_item["iscrowd"] = 1 157 | else: 158 | annotation_item["ignore"] = 0 159 | annotation_item["iscrowd"] = 0 160 | annotation_item["image_id"] = int(image_id) 161 | annotation_item["bbox"] = xywh.astype(int).tolist() 162 | annotation_item["category_id"] = int(category_id) 163 | annotation_item["id"] = int(annotation_id) 164 | coco["annotations"].append(annotation_item) 165 | return annotation_id + 1 166 | 167 | for category_id, name in enumerate(folder_classes): 168 | category_item = dict() 169 | 
category_item["supercategory"] = str("none") 170 | category_item["id"] = int(category_id) 171 | category_item["name"] = str(name) 172 | coco["categories"].append(category_item) 173 | 174 | pbar = tqdm(annotations) 175 | for ann_dict in pbar: 176 | file_name = ann_dict["filename"] 177 | ann = ann_dict["ann"] 178 | assert file_name not in image_set 179 | image_item = dict() 180 | image_item["id"] = int(image_id) 181 | image_item["file_name"] = str(file_name) 182 | image_item["height"] = int(ann_dict["height"]) 183 | image_item["width"] = int(ann_dict["width"]) 184 | coco["images"].append(image_item) 185 | image_set.add(file_name) 186 | 187 | bboxes = ann["bboxes"][:, :4] 188 | labels = ann["labels"] 189 | for bbox_id in range(len(bboxes)): 190 | bbox = bboxes[bbox_id] 191 | label = labels[bbox_id] 192 | annotation_id = addAnnItem(annotation_id, image_id, label, bbox, difficult_flag=0) 193 | 194 | bboxes_ignore = ann["bboxes_ignore"][:, :4] 195 | labels_ignore = ann["labels_ignore"] 196 | for bbox_id in range(len(bboxes_ignore)): 197 | bbox = bboxes_ignore[bbox_id] 198 | label = labels_ignore[bbox_id] 199 | annotation_id = addAnnItem(annotation_id, image_id, label, bbox, difficult_flag=1) 200 | 201 | image_id += 1 202 | 203 | return coco 204 | 205 | 206 | def _strtobool(x): 207 | return bool(strtobool(x)) 208 | 209 | 210 | def parse_args(): 211 | parser = argparse.ArgumentParser(description="Convert image list to coco format") 212 | parser.add_argument("--ann-root", help="ann root", type=str, default='') 213 | parser.add_argument("--out-file", help="output path", required=True) 214 | parser.add_argument("--info-json", help="output path", required=True) 215 | parser.add_argument('--has-instance', nargs='?', const=True, type=_strtobool, default=True, help="has instance") 216 | parser.add_argument('--has-segmentation', nargs='?', const=True, type=_strtobool, default=True, help="has segmentation") 217 | args = parser.parse_args() 218 | return args 219 | 220 | 221 | def main(): 222 | 223 | args = parse_args() 224 | print(args) 225 | ann_root = args.ann_root 226 | out_file = args.out_file 227 | 228 | has_instance = args.has_instance 229 | has_segmentation = args.has_segmentation 230 | 231 | global folder_classes 232 | with open(args.info_json,'r') as f: 233 | d = json.load(f) 234 | folder_classes = d['categories'] 235 | path_h_w = d['path_h_w'] 236 | annotations = cvt_annotations(path_h_w, out_file, has_instance, has_segmentation, ann_root) 237 | print("Done!") 238 | print('miss xml file:{}'.format(len(miss_xml))) 239 | print(miss_xml) 240 | 241 | with open(out_file, "w") as f: 242 | json.dump(annotations, f) 243 | 244 | print(args) 245 | 246 | 247 | if __name__ == "__main__": 248 | main() -------------------------------------------------------------------------------- /tools/ilsvrc_info.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import os.path as osp 5 | import pathlib 6 | import random 7 | import xml.dom.minidom 8 | from distutils.util import strtobool 9 | 10 | import numpy as np 11 | from detectron2.data.detection_utils import read_image 12 | # from nltk.corpus import wordnet as wn 13 | from PIL import Image 14 | from tqdm import tqdm 15 | 16 | 17 | def get_filename_key(x): 18 | basename = os.path.basename(x) 19 | if basename[:-4].isdigit(): 20 | return int(basename[:-4]) 21 | 22 | return basename 23 | 24 | def _strtobool(x): 25 | return bool(strtobool(x)) 26 | 27 | def parse_args(): 28 | parser = 
argparse.ArgumentParser(description="Statistical ILSVRC") 29 | parser.add_argument("--img-root", help="img root", required=True) 30 | parser.add_argument("--out-file", help="output path", required=True) 31 | parser.add_argument('--max-per-dir', nargs='?', const=1e10, type=int, default=1e10, help="max per dir") 32 | parser.add_argument('--min-per-dir', nargs='?', const=0, type=int, default=0, help="min per dir") 33 | parser.add_argument('--has-shuffle', nargs='?', const=False, type=_strtobool, default=False, help="shuffle or sort") 34 | args = parser.parse_args() 35 | return args 36 | 37 | 38 | def main(): 39 | args = parse_args() 40 | print(args) 41 | img_root = args.img_root 42 | out_file = args.out_file 43 | 44 | max_per_dir = args.max_per_dir 45 | min_per_dir = args.min_per_dir 46 | 47 | has_shuffle = args.has_shuffle 48 | 49 | folder_classes = [] 50 | path_h_w = [] 51 | pbar = tqdm(os.walk(img_root)) 52 | for root, dirs, files in pbar: 53 | if not os.path.basename(root).startswith("n"): 54 | print("\tskip folder: ", root) 55 | continue 56 | 57 | files = [f for f in files if f.endswith(('.jpg', '.JPG', '.png', '.PNG', '.jpeg', '.JPEG'))] 58 | 59 | if min_per_dir > 0 and len(files) < min_per_dir: 60 | print("\tskip folder: ", root, " #image: ", len(files)) 61 | continue 62 | 63 | if has_shuffle: 64 | random.shuffle(files) 65 | else: 66 | files = sorted(files, key=lambda x: get_filename_key(x), reverse=False) 67 | 68 | 69 | folder_classes_this = [] 70 | path_h_w_this = [] 71 | for name in files: 72 | if max_per_dir > 0 and len(path_h_w_this) >= max_per_dir: 73 | break 74 | 75 | path = os.path.join(root, name) 76 | print('parsing:',path) 77 | rpath = path.replace(img_root, "") 78 | 79 | wnid = pathlib.PurePath(path).parent.name 80 | 81 | if wnid not in folder_classes_this: 82 | folder_classes_this.append(wnid) 83 | 84 | try: 85 | img = read_image(path, format="BGR") 86 | height, width, _ = img.shape 87 | except Exception as e: 88 | print("*" * 60) 89 | print("fail to open image: ", e) 90 | print("*" * 60) 91 | continue 92 | 93 | # if width < 300 or height < 300: 94 | # continue 95 | 96 | # if width < 224 or height < 224: 97 | # continue 98 | 99 | # if width > 1000 or height > 1000: 100 | # continue 101 | 102 | path_h_w_this.append([rpath, height, width]) 103 | 104 | if len(path_h_w_this) >= min_per_dir: 105 | pass 106 | else: 107 | print("\tskip folder: ", root, " #image: ", len(path_h_w_this)) 108 | continue 109 | 110 | folder_classes.extend(folder_classes_this) 111 | path_h_w.extend(path_h_w_this) 112 | 113 | print("folder: ", root, " #image: ", len(path_h_w_this)) 114 | print("#folder: ", len(folder_classes), " #image: ", len(path_h_w)) 115 | 116 | folder_classes = list(set(folder_classes)) 117 | folder_classes.sort() 118 | 119 | print("#category: ", len(folder_classes), " categories: ", folder_classes) 120 | print("#image: ", len(path_h_w)) 121 | print("") 122 | with open(out_file,'w') as f: 123 | d = {} 124 | d['categories'] = folder_classes 125 | d['path_h_w'] = path_h_w 126 | json.dump(d,f) 127 | print(args) 128 | 129 | if __name__ == "__main__": 130 | main() -------------------------------------------------------------------------------- /tools/train_net.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | """ 4 | A main training script. 5 | 6 | This scripts reads a given config file and runs the training or evaluation. 
7 | It is an entry point that is made to train standard models in detectron2. 8 | 9 | In order to let one script support training of many models, 10 | this script contains logic that are specific to these built-in models and therefore 11 | may not be suitable for your own project. 12 | For example, your research project perhaps only needs a single "evaluator". 13 | 14 | Therefore, we recommend you to use detectron2 as an library and take 15 | this file as an example of how to use the library. 16 | You may want to write your own script with your datasets and other customizations. 17 | """ 18 | 19 | import logging 20 | import os 21 | 22 | import detectron2.utils.comm as comm 23 | from detectron2.checkpoint import DetectionCheckpointer 24 | from detectron2.config import get_cfg 25 | from detectron2.engine import (default_argument_parser, default_setup,hooks,launch) 26 | from detectron2.utils.logger import setup_logger 27 | from detectron2.evaluation import verify_results 28 | from wsovod.config import add_wsovod_config 29 | 30 | 31 | def setup(args): 32 | """ 33 | Create configs and perform basic setups. 34 | """ 35 | cfg = get_cfg() 36 | add_wsovod_config(cfg) 37 | cfg.merge_from_file(args.config_file) 38 | cfg.merge_from_list(args.opts) 39 | cfg.freeze() 40 | default_setup(cfg, args) 41 | setup_logger(cfg.OUTPUT_DIR, distributed_rank=comm.get_rank(), name="wsovod") 42 | return cfg 43 | 44 | 45 | def main(args): 46 | cfg = setup(args) 47 | if "MixedDatasets" in args.config_file: 48 | from wsovod.engine import DefaultTrainer_WSOVOD_MixedDatasets as Trainer 49 | else: 50 | from wsovod.engine import DefaultTrainer_WSOVOD as Trainer 51 | 52 | if args.eval_only: 53 | model = Trainer.build_model(cfg) 54 | DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load( 55 | cfg.MODEL.WEIGHTS, resume=args.resume 56 | ) 57 | res = {} 58 | res = Trainer.test_WSL(cfg, model) 59 | if cfg.TEST.AUG.ENABLED: 60 | res.update(Trainer.test_with_TTA_WSL(cfg, model)) 61 | if comm.is_main_process(): 62 | verify_results(cfg, res) 63 | return res 64 | 65 | """ 66 | If you'd like to do anything fancier than the standard training logic, 67 | consider writing your own training loop (see plain_train_net.py) or 68 | subclassing the trainer. 69 | """ 70 | trainer = Trainer(cfg) 71 | trainer.resume_or_load(resume=args.resume) 72 | if cfg.TEST.AUG.ENABLED: 73 | # trainer.register_hooks([hooks.EvalHook(0, lambda: trainer.test_WSL(cfg, trainer.model))]) 74 | trainer.register_hooks( 75 | [hooks.EvalHook(0, lambda: trainer.test_with_TTA_WSL(cfg, trainer.model))] 76 | ) 77 | return trainer.train() 78 | 79 | 80 | if __name__ == "__main__": 81 | args = default_argument_parser().parse_args() 82 | print("Command Line Args:", args) 83 | launch( 84 | main, 85 | args.num_gpus, 86 | num_machines=args.num_machines, 87 | machine_rank=args.machine_rank, 88 | dist_url=args.dist_url, 89 | args=(args,), 90 | ) 91 | -------------------------------------------------------------------------------- /tools/train_net_debug.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | """ 4 | A main training script. 5 | 6 | This scripts reads a given config file and runs the training or evaluation. 7 | It is an entry point that is made to train standard models in detectron2. 
8 | 9 | In order to let one script support training of many models, 10 | this script contains logic that are specific to these built-in models and therefore 11 | may not be suitable for your own project. 12 | For example, your research project perhaps only needs a single "evaluator". 13 | 14 | Therefore, we recommend you to use detectron2 as an library and take 15 | this file as an example of how to use the library. 16 | You may want to write your own script with your datasets and other customizations. 17 | """ 18 | 19 | import logging 20 | import os 21 | 22 | import detectron2.utils.comm as comm 23 | from detectron2.checkpoint import DetectionCheckpointer 24 | from detectron2.config import get_cfg 25 | from detectron2.engine import (default_argument_parser, default_setup,hooks,launch) 26 | from detectron2.utils.logger import setup_logger 27 | from detectron2.evaluation import verify_results 28 | from wsovod.config import add_wsovod_config 29 | 30 | 31 | def setup(args): 32 | """ 33 | Create configs and perform basic setups. 34 | """ 35 | cfg = get_cfg() 36 | add_wsovod_config(cfg) 37 | cfg.merge_from_file(args.config_file) 38 | cfg.merge_from_list(args.opts) 39 | cfg.freeze() 40 | default_setup(cfg, args) 41 | setup_logger(cfg.OUTPUT_DIR, distributed_rank=comm.get_rank(), name="wsovod") 42 | return cfg 43 | 44 | 45 | def main(args): 46 | cfg = setup(args) 47 | if "MixedDatasets" in args.config_file: 48 | from wsovod.engine import DefaultTrainer_WSOVOD_MixedDatasets as Trainer 49 | else: 50 | from wsovod.engine import DefaultTrainer_WSOVOD as Trainer 51 | 52 | if args.eval_only: 53 | model = Trainer.build_model(cfg) 54 | DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load( 55 | cfg.MODEL.WEIGHTS, resume=args.resume 56 | ) 57 | res = {} 58 | res = Trainer.test_WSL(cfg, model) 59 | if cfg.TEST.AUG.ENABLED: 60 | res.update(Trainer.test_with_TTA_WSL(cfg, model)) 61 | if comm.is_main_process(): 62 | verify_results(cfg, res) 63 | return res 64 | 65 | """ 66 | If you'd like to do anything fancier than the standard training logic, 67 | consider writing your own training loop (see plain_train_net.py) or 68 | subclassing the trainer. 69 | """ 70 | trainer = Trainer(cfg) 71 | trainer.resume_or_load(resume=args.resume) 72 | if cfg.TEST.AUG.ENABLED: 73 | # trainer.register_hooks([hooks.EvalHook(0, lambda: trainer.test_WSL(cfg, trainer.model))]) 74 | trainer.register_hooks( 75 | [hooks.EvalHook(0, lambda: trainer.test_with_TTA_WSL(cfg, trainer.model))] 76 | ) 77 | return trainer.train() 78 | 79 | 80 | if __name__ == "__main__": 81 | args = default_argument_parser().parse_args() 82 | args.config_file = 'configs/MixedDatasets-Detection/WSOVOD_MRRP_WSR_18_DC5_1x.yaml' 83 | # args.num_gpus = 1 84 | # # # args.eval_only = True 85 | # args.opts = ['OUTPUT_DIR','output/temp','SOLVER.IMS_PER_BATCH',1,'SOLVER.REFERENCE_WORLD_SIZE',1,] 86 | # 'MODEL.WEIGHTS','output/configs/COCO-Detection/WSOVOD_WSR_18_DC5_1x_20230621_003554/model_final.pth'] 87 | print("Command Line Args:", args) 88 | launch( 89 | main, 90 | args.num_gpus, 91 | num_machines=args.num_machines, 92 | machine_rank=args.machine_rank, 93 | dist_url=args.dist_url, 94 | args=(args,), 95 | ) 96 | -------------------------------------------------------------------------------- /tools/train_net_eval_open_vocabulary.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | """ 4 | A main training script. 
5 | 6 | This scripts reads a given config file and runs the training or evaluation. 7 | It is an entry point that is made to train standard models in detectron2. 8 | 9 | In order to let one script support training of many models, 10 | this script contains logic that are specific to these built-in models and therefore 11 | may not be suitable for your own project. 12 | For example, your research project perhaps only needs a single "evaluator". 13 | 14 | Therefore, we recommend you to use detectron2 as an library and take 15 | this file as an example of how to use the library. 16 | You may want to write your own script with your datasets and other customizations. 17 | """ 18 | 19 | import logging 20 | import os 21 | 22 | import detectron2.utils.comm as comm 23 | from detectron2.checkpoint import DetectionCheckpointer 24 | from detectron2.config import get_cfg 25 | from detectron2.engine import (default_argument_parser, default_setup,hooks,launch) 26 | from detectron2.utils.logger import setup_logger 27 | from detectron2.evaluation import verify_results 28 | from wsovod.config import add_wsovod_config 29 | 30 | 31 | def setup(args): 32 | """ 33 | Create configs and perform basic setups. 34 | """ 35 | cfg = get_cfg() 36 | add_wsovod_config(cfg) 37 | cfg.merge_from_file(args.config_file) 38 | cfg.merge_from_list(args.opts) 39 | cfg.freeze() 40 | default_setup(cfg, args) 41 | setup_logger(cfg.OUTPUT_DIR, distributed_rank=comm.get_rank(), name="wsovod") 42 | return cfg 43 | 44 | 45 | def main(args): 46 | cfg = setup(args) 47 | if "MixedDatasets" in args.config_file: 48 | from wsovod.engine import DefaultTrainer_WSOVOD_MixedDatasets as Trainer_ 49 | else: 50 | from wsovod.engine import DefaultTrainer_WSOVOD as Trainer_ 51 | from detectron2.evaluation import DatasetEvaluators 52 | from wsovod.evaluation.ov_coco_evaluation import OVCOCOEvaluator 53 | 54 | class Trainer(Trainer_): 55 | @classmethod 56 | def build_evaluator(cls, cfg, dataset_name, output_folder=None): 57 | """ 58 | Create evaluator(s) for a given dataset. 59 | This uses the special metadata "evaluator_type" associated with each builtin dataset. 60 | For your own dataset, you can simply create an evaluator manually in your 61 | script and do not have to worry about the hacky if-else logic here. 62 | """ 63 | if output_folder is None: 64 | output_folder = os.path.join(cfg.OUTPUT_DIR, "inference_" + dataset_name) 65 | evaluator_list = [] 66 | 67 | evaluator_list.append(OVCOCOEvaluator(dataset_name, output_dir=output_folder)) 68 | return DatasetEvaluators(evaluator_list) 69 | 70 | if args.eval_only: 71 | model = Trainer.build_model(cfg) 72 | DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load( 73 | cfg.MODEL.WEIGHTS, resume=args.resume 74 | ) 75 | res = {} 76 | res = Trainer.test_WSL(cfg, model) 77 | if cfg.TEST.AUG.ENABLED: 78 | res.update(Trainer.test_with_TTA_WSL(cfg, model)) 79 | if comm.is_main_process(): 80 | verify_results(cfg, res) 81 | return res 82 | 83 | """ 84 | If you'd like to do anything fancier than the standard training logic, 85 | consider writing your own training loop (see plain_train_net.py) or 86 | subclassing the trainer. 
87 | """ 88 | trainer = Trainer(cfg) 89 | trainer.resume_or_load(resume=args.resume) 90 | if cfg.TEST.AUG.ENABLED: 91 | # trainer.register_hooks([hooks.EvalHook(0, lambda: trainer.test_WSL(cfg, trainer.model))]) 92 | trainer.register_hooks( 93 | [hooks.EvalHook(0, lambda: trainer.test_with_TTA_WSL(cfg, trainer.model))] 94 | ) 95 | return trainer.train() 96 | 97 | 98 | if __name__ == "__main__": 99 | args = default_argument_parser().parse_args() 100 | print("Command Line Args:", args) 101 | launch( 102 | main, 103 | args.num_gpus, 104 | num_machines=args.num_machines, 105 | machine_rank=args.machine_rank, 106 | dist_url=args.dist_url, 107 | args=(args,), 108 | ) 109 | -------------------------------------------------------------------------------- /wsovod/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | from .config import add_wsovod_config 3 | from .data import * 4 | from .evaluation import * 5 | from .modeling import * 6 | from .solver import * 7 | 8 | # This line will be programatically read/write by setup.py. 9 | # Leave them at the bottom of this file and don't touch them. 10 | __version__ = "0.4" -------------------------------------------------------------------------------- /wsovod/config/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | from .defaults import add_wsovod_config 3 | -------------------------------------------------------------------------------- /wsovod/config/defaults.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | 4 | from detectron2.config import CfgNode as CN 5 | 6 | 7 | def add_wsovod_config(cfg): 8 | """ 9 | Add config for wsovod. 
10 | """ 11 | _C = cfg 12 | _C.WSOVOD = CN() 13 | _C.WSOVOD.ITER_SIZE = 1 14 | _C.WSOVOD.CLS_AGNOSTIC_BBOX_KNOWN = False 15 | _C.WSOVOD.SAMPLING = CN() 16 | _C.WSOVOD.SAMPLING.SAMPLING_ON = False 17 | _C.WSOVOD.SAMPLING.IOU_THRESHOLDS = [[0.5], [0.5], [0.5], [0.5]] 18 | _C.WSOVOD.SAMPLING.IOU_LABELS = [[0, 1], [0, 1], [0, 1], [0, 1]] 19 | _C.WSOVOD.SAMPLING.BATCH_SIZE_PER_IMAGE = [4096, 4096, 4096, 4096] 20 | _C.WSOVOD.SAMPLING.POSITIVE_FRACTION = [1.0, 1.0, 1.0, 1.0] 21 | _C.WSOVOD.OBJECT_MINING = CN() 22 | _C.WSOVOD.OBJECT_MINING.WEIGHT = 1.0 23 | _C.WSOVOD.OBJECT_MINING.MEAN_LOSS = True 24 | _C.WSOVOD.INSTANCE_REFINEMENT = CN() 25 | _C.WSOVOD.INSTANCE_REFINEMENT.WEIGHT = 1.0 26 | _C.WSOVOD.INSTANCE_REFINEMENT.REFINE_NUM = 3 27 | _C.WSOVOD.INSTANCE_REFINEMENT.REFINE_REG = [False, False, False] 28 | _C.WSOVOD.INSTANCE_REFINEMENT.REFINE_MIST = False 29 | _C.WSOVOD.INSTANCE_REFINEMENT.CROSS_ENTROPY_WEIGHTED = True 30 | _C.WSOVOD.BBOX_REFINE = CN() 31 | _C.WSOVOD.BBOX_REFINE.ENABLE = False 32 | _C.WSOVOD.BBOX_REFINE.MODEL_TYPE = "vit_b" # vit_h vit_l vit_b 33 | _C.WSOVOD.BBOX_REFINE.MODEL_CHECKPOINT = "tools/sam_checkpoints/sam_vit_b_01ec64.pth" # sam_checkpoints/sam_vit_h_4b8939.pth sam_checkpoints/sam_vit_l_0b3195.pth sam_checkpoints/sam_vit_b_01ec64.pth 34 | 35 | # VGG 36 | _C.MODEL.VGG = CN() 37 | _C.MODEL.VGG.DEPTH = 16 38 | _C.MODEL.VGG.OUT_FEATURES = ["plain5"] 39 | _C.MODEL.VGG.CONV5_DILATION = 1 40 | # Swin 41 | _C.MODEL.SWIN = CN() 42 | _C.MODEL.SWIN.EMBED_DIM = 96 43 | _C.MODEL.SWIN.OUT_FEATURES = ["stage2", "stage3", "stage4", "stage5"] 44 | _C.MODEL.SWIN.DEPTHS = [2, 2, 6, 2] 45 | _C.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24] 46 | _C.MODEL.SWIN.WINDOW_SIZE = 7 47 | _C.MODEL.SWIN.MLP_RATIO = 4 48 | _C.MODEL.SWIN.DROP_PATH_RATE = 0.2 49 | _C.MODEL.SWIN.APE = False 50 | _C.MODEL.SWIN.PATH_NORM = True 51 | 52 | _C.MODEL.ROI_BOX_HEAD.DAN_DIM = [4096, 4096] 53 | _C.MODEL.ROI_BOX_HEAD.OPEN_VOCABULARY = CN() 54 | _C.MODEL.ROI_BOX_HEAD.OPEN_VOCABULARY.WEIGHT_PATH_TRAIN = '' 55 | _C.MODEL.ROI_BOX_HEAD.OPEN_VOCABULARY.WEIGHT_PATH_TEST = '' 56 | _C.MODEL.ROI_BOX_HEAD.OPEN_VOCABULARY.WEIGHT_DIM = 512 57 | _C.MODEL.ROI_BOX_HEAD.OPEN_VOCABULARY.USE_BIAS = 0.0 58 | _C.MODEL.ROI_BOX_HEAD.OPEN_VOCABULARY.NORM_WEIGHT = True 59 | _C.MODEL.ROI_BOX_HEAD.OPEN_VOCABULARY.NORM_TEMP = 100.0 60 | _C.MODEL.ROI_BOX_HEAD.OPEN_VOCABULARY.DATA_AWARE = False 61 | _C.MODEL.ROI_BOX_HEAD.OPEN_VOCABULARY.PROTOTYPE_NUM = 5 62 | 63 | _C.MODEL.RPN.SCORE_THRESH_TRAIN = 0.2 64 | _C.MODEL.RPN.SCORE_THRESH_TEST = 0.2 65 | _C.MODEL.RPN.TOPK_CANDIDATES_TRAIN = 2000 66 | _C.MODEL.RPN.TOPK_CANDIDATES_TEST = 1000 67 | 68 | _C.MODEL.MRRP = CN() 69 | _C.MODEL.MRRP.MRRP_ON = False 70 | _C.MODEL.MRRP.NUM_BRANCH = 3 71 | _C.MODEL.MRRP.BRANCH_DILATIONS = [1, 2, 3] 72 | _C.MODEL.MRRP.MRRP_STAGE = "res4" 73 | _C.MODEL.MRRP.TEST_BRANCH_IDX = 1 74 | 75 | _C.DATALOADER.CLASS_ASPECT_RATIO_GROUPING = False 76 | _C.DATALOADER.GROUP_WAIT = 5 77 | 78 | _C.TEST.EVAL_TRAIN = False 79 | _C.VIS_TEST = False 80 | 81 | _C.SOLVER.OPTIMIZER = "SGD" 82 | _C.SOLVER.BACKBONE_MULTIPLIER = 1.0 83 | _C.SOLVER.BASE_LR_END = 0.1 84 | _C.SOLVER.IMS_PER_BATCH_LIST = [4] 85 | 86 | _C.DATASETS.MIXED_DATASETS = CN() 87 | _C.DATASETS.MIXED_DATASETS.NAMES = ['coco_2017_train'] 88 | _C.DATASETS.MIXED_DATASETS.WEIGHT_PATH_TRAINS = ['models/coco_text_embedding_single_prompt.pkl'] 89 | _C.DATASETS.MIXED_DATASETS.NUM_CLASSES = [80] 90 | _C.DATASETS.MIXED_DATASETS.PROPOSAL_FILES = [''] 91 | _C.DATASETS.MIXED_DATASETS.RATIOS = [1] 92 | _C.DATASETS.MIXED_DATASETS.USE_CAS = 
[False] 93 | _C.DATASETS.MIXED_DATASETS.USE_RFS = [True] 94 | _C.DATASETS.MIXED_DATASETS.FILTER_EMPTY_ANNOTATIONS = [True] 95 | _C.DATASETS.MIXED_DATASETS.CAS_LAMBDA = 1.0 96 | _C.DATASETS.MIXED_DATASETS.REPEAT_THRESHOLD = 0.001 -------------------------------------------------------------------------------- /wsovod/data/__init__.py: -------------------------------------------------------------------------------- 1 | # ensure the builtin datasets are registered 2 | from . import datasets # isort:skip 3 | 4 | from .build import build_detection_test_loader, build_detection_train_loader 5 | 6 | __all__ = [k for k in globals().keys() if not k.startswith("_")] 7 | -------------------------------------------------------------------------------- /wsovod/data/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | import copy 3 | import itertools 4 | import logging 5 | import pickle 6 | import random 7 | 8 | import numpy as np 9 | import torch.utils.data as data 10 | from detectron2.utils.serialize import PicklableWrapper 11 | from torch.utils.data.sampler import Sampler 12 | from detectron2.data.common import MapDataset, AspectRatioGroupedDataset 13 | 14 | __all__ = [ 15 | "ClassAspectRatioGroupedDataset", 16 | "AspectRatioGroupedMixedDatasets", 17 | ] 18 | 19 | 20 | class ClassAspectRatioGroupedDataset(data.IterableDataset): 21 | """ 22 | Batch data that have similar aspect ratio together. 23 | In this implementation, images whose aspect ratio < (or >) 1 will 24 | be batched together. 25 | This improves training speed because the images then need less padding 26 | to form a batch. 27 | 28 | It assumes the underlying dataset produces dicts with "width" and "height" keys. 29 | It will then produce a list of original dicts with length = batch_size, 30 | all with similar aspect ratios. 31 | """ 32 | 33 | def __init__(self, dataset, batch_size): 34 | """ 35 | Args: 36 | dataset: an iterable. Each element must be a dict with keys 37 | "width" and "height", which will be used to batch data. 38 | batch_size (int): 39 | """ 40 | self.dataset = dataset 41 | self.batch_size = batch_size 42 | self._buckets = [[] for _ in range(2 * 2000)] 43 | 44 | def __iter__(self): 45 | for d in self.dataset: 46 | w, h = d["width"], d["height"] 47 | bucket_id = 0 if w > h else 1 48 | 49 | classes = list(set(d["instances"].gt_classes.tolist())) 50 | 51 | # for c in classes: 52 | # bucket = self._buckets[2 * c + bucket_id] 53 | # bucket.append(copy.deepcopy(d)) 54 | # if len(bucket) == self.batch_size: 55 | # yield bucket[:] 56 | # del bucket[:] 57 | # continue 58 | 59 | if len(classes) > 0: 60 | c = random.choice(classes) 61 | else: 62 | c = 0 63 | 64 | bucket_id = bucket_id + c * 2 65 | 66 | bucket = self._buckets[bucket_id] 67 | bucket.append(d) 68 | if len(bucket) == self.batch_size: 69 | yield bucket[:] 70 | del bucket[:] 71 | 72 | -------------------------------------------------------------------------------- /wsovod/data/dataset_mapper.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | import copy 3 | import logging 4 | from typing import List, Optional, Union 5 | 6 | import numpy as np 7 | import torch 8 | from detectron2.config import configurable 9 | from detectron2.data import transforms as T 10 | 11 | from . 
import detection_utils as utils 12 | 13 | """ 14 | This file contains the default mapping that's applied to "dataset dicts". 15 | """ 16 | 17 | __all__ = ["DatasetMapper"] 18 | 19 | 20 | class DatasetMapper: 21 | """ 22 | A callable which takes a dataset dict in Detectron2 Dataset format, 23 | and map it into a format used by the model. 24 | 25 | This is the default callable to be used to map your dataset dict into training data. 26 | You may need to follow it to implement your own one for customized logic, 27 | such as a different way to read or transform images. 28 | See :doc:`/tutorials/data_loading` for details. 29 | 30 | The callable currently does the following: 31 | 32 | 1. Read the image from "file_name" 33 | 2. Applies cropping/geometric transforms to the image and annotations 34 | 3. Prepare data and annotations to Tensor and :class:`Instances` 35 | """ 36 | 37 | @configurable 38 | def __init__( 39 | self, 40 | is_train: bool, 41 | *, 42 | augmentations: List[Union[T.Augmentation, T.Transform]], 43 | image_format: str, 44 | use_instance_mask: bool = False, 45 | use_keypoint: bool = False, 46 | instance_mask_format: str = "polygon", 47 | keypoint_hflip_indices: Optional[np.ndarray] = None, 48 | precomputed_proposal_topk: Optional[int] = None, 49 | recompute_boxes: bool = False, 50 | ): 51 | """ 52 | NOTE: this interface is experimental. 53 | 54 | Args: 55 | is_train: whether it's used in training or inference 56 | augmentations: a list of augmentations or deterministic transforms to apply 57 | image_format: an image format supported by :func:`detection_utils.read_image`. 58 | use_instance_mask: whether to process instance segmentation annotations, if available 59 | use_keypoint: whether to process keypoint annotations if available 60 | instance_mask_format: one of "polygon" or "bitmask". Process instance segmentation 61 | masks into this format. 62 | keypoint_hflip_indices: see :func:`detection_utils.create_keypoint_hflip_indices` 63 | precomputed_proposal_topk: if given, will load pre-computed 64 | proposals from dataset_dict and keep the top k proposals for each image. 65 | recompute_boxes: whether to overwrite bounding box annotations 66 | by computing tight bounding boxes from instance mask annotations. 
67 | """ 68 | if recompute_boxes: 69 | assert use_instance_mask, "recompute_boxes requires instance masks" 70 | # fmt: off 71 | self.is_train = is_train 72 | self.augmentations = T.AugmentationList(augmentations) 73 | self.image_format = image_format 74 | self.use_instance_mask = use_instance_mask 75 | self.instance_mask_format = instance_mask_format 76 | self.use_keypoint = use_keypoint 77 | self.keypoint_hflip_indices = keypoint_hflip_indices 78 | self.proposal_topk = precomputed_proposal_topk 79 | self.recompute_boxes = recompute_boxes 80 | # fmt: on 81 | logger = logging.getLogger(__name__) 82 | mode = "training" if is_train else "inference" 83 | logger.info(f"[DatasetMapper] Augmentations used in {mode}: {augmentations}") 84 | 85 | @classmethod 86 | def from_config(cls, cfg, is_train: bool = True): 87 | augs = utils.build_augmentation(cfg, is_train) 88 | if cfg.INPUT.CROP.ENABLED and is_train: 89 | augs.insert(0, T.RandomCrop(cfg.INPUT.CROP.TYPE, cfg.INPUT.CROP.SIZE)) 90 | recompute_boxes = cfg.MODEL.MASK_ON 91 | else: 92 | recompute_boxes = False 93 | 94 | ret = { 95 | "is_train": is_train, 96 | "augmentations": augs, 97 | "image_format": cfg.INPUT.FORMAT, 98 | "use_instance_mask": cfg.MODEL.MASK_ON, 99 | "instance_mask_format": cfg.INPUT.MASK_FORMAT, 100 | "use_keypoint": cfg.MODEL.KEYPOINT_ON, 101 | "recompute_boxes": recompute_boxes, 102 | } 103 | 104 | if cfg.MODEL.KEYPOINT_ON: 105 | ret["keypoint_hflip_indices"] = utils.create_keypoint_hflip_indices(cfg.DATASETS.TRAIN) 106 | 107 | if cfg.MODEL.LOAD_PROPOSALS: 108 | ret["precomputed_proposal_topk"] = ( 109 | cfg.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TRAIN 110 | if is_train 111 | else cfg.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TEST 112 | ) 113 | return ret 114 | 115 | def _transform_annotations(self, dataset_dict, transforms, image_shape): 116 | # USER: Modify this if you want to keep them for some reason. 117 | for anno in dataset_dict["annotations"]: 118 | if not self.use_instance_mask: 119 | anno.pop("segmentation", None) 120 | if not self.use_keypoint: 121 | anno.pop("keypoints", None) 122 | 123 | # USER: Implement additional transformations if you have other types of data 124 | annos = [ 125 | utils.transform_instance_annotations( 126 | obj, transforms, image_shape, keypoint_hflip_indices=self.keypoint_hflip_indices 127 | ) 128 | for obj in dataset_dict.pop("annotations") 129 | if obj.get("iscrowd", 0) == 0 130 | ] 131 | instances = utils.annotations_to_instances( 132 | annos, image_shape, mask_format=self.instance_mask_format 133 | ) 134 | 135 | # After transforms such as cropping are applied, the bounding box may no longer 136 | # tightly bound the object. As an example, imagine a triangle object 137 | # [(0,0), (2,0), (0,2)] cropped by a box [(1,0),(2,2)] (XYXY format). The tight 138 | # bounding box of the cropped triangle should be [(1,0),(2,1)], which is not equal to 139 | # the intersection of original bounding box and the cropping box. 140 | if self.recompute_boxes: 141 | instances.gt_boxes = instances.gt_masks.get_bounding_boxes() 142 | dataset_dict["instances"] = utils.filter_empty_instances(instances) 143 | 144 | def __call__(self, dataset_dict): 145 | """ 146 | Args: 147 | dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format. 
148 | 149 | Returns: 150 | dict: a format that builtin models in detectron2 accept 151 | """ 152 | dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below 153 | # USER: Write your own image loading if it's not from a file 154 | image = utils.read_image(dataset_dict["file_name"], format=self.image_format) 155 | utils.check_image_size(dataset_dict, image) 156 | 157 | # USER: Remove if you don't do semantic/panoptic segmentation. 158 | if "sem_seg_file_name" in dataset_dict: 159 | sem_seg_gt = utils.read_image(dataset_dict.pop("sem_seg_file_name"), "L").squeeze(2) 160 | else: 161 | sem_seg_gt = None 162 | 163 | aug_input = T.AugInput(image, sem_seg=sem_seg_gt) 164 | transforms = self.augmentations(aug_input) 165 | image, sem_seg_gt = aug_input.image, aug_input.sem_seg 166 | 167 | image_shape = image.shape[:2] # h, w 168 | # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory, 169 | # but not efficient on large generic data structures due to the use of pickle & mp.Queue. 170 | # Therefore it's important to use torch.Tensor. 171 | dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1))) 172 | if sem_seg_gt is not None: 173 | dataset_dict["sem_seg"] = torch.as_tensor(sem_seg_gt.astype("long")) 174 | 175 | # USER: Remove if you don't use pre-computed proposals. 176 | # Most users would not need this feature. 177 | if self.proposal_topk is not None: 178 | utils.transform_proposals( 179 | dataset_dict, image_shape, transforms, proposal_topk=self.proposal_topk 180 | ) 181 | 182 | if not self.is_train: 183 | # USER: Modify this if you want to keep them for some reason. 184 | dataset_dict.pop("annotations", None) 185 | dataset_dict.pop("sem_seg_file_name", None) 186 | return dataset_dict 187 | 188 | if "annotations" in dataset_dict: 189 | self._transform_annotations(dataset_dict, transforms, image_shape) 190 | 191 | return dataset_dict 192 | 193 | 194 | -------------------------------------------------------------------------------- /wsovod/data/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from . import builtin # ensure the builtin datasets are registered 2 | -------------------------------------------------------------------------------- /wsovod/data/datasets/builtin.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | 5 | """ 6 | This file registers pre-defined datasets at hard-coded paths, and their metadata. 7 | 8 | We hard-code metadata for common datasets. This will enable: 9 | 1. Consistency check when loading the datasets 10 | 2. Use models on these standard datasets directly and run demos, 11 | without having to download the dataset annotations 12 | 13 | We hard-code some paths to the dataset that's assumed to 14 | exist in "./datasets/". 15 | 16 | Users SHOULD NOT use this file to create new dataset / metadata for new dataset. 17 | To add new dataset, refer to the tutorial "docs/DATASETS.md". 
18 | """ 19 | 20 | import os 21 | 22 | from detectron2.data import DatasetCatalog, MetadataCatalog 23 | from detectron2.data.datasets.coco import load_coco_json 24 | from detectron2.data.datasets.coco_panoptic import ( 25 | register_coco_panoptic, register_coco_panoptic_separated) 26 | from detectron2.data.datasets.lvis import (get_lvis_instances_meta, 27 | register_lvis_instances) 28 | from detectron2.data.datasets.register_coco import register_coco_instances 29 | 30 | from .builtin_meta import _get_builtin_metadata 31 | from .pascal_voc import register_pascal_voc 32 | 33 | 34 | # ==== Predefined splits for PASCAL VOC =========== 35 | def register_all_pascal_voc(root): 36 | SPLITS = [ 37 | ("voc_2007_trainval", "VOC2007", "trainval"), 38 | ("voc_2007_train", "VOC2007", "train"), 39 | ("voc_2007_val", "VOC2007", "val"), 40 | ("voc_2007_test", "VOC2007", "test"), 41 | ("voc_2012_trainval", "VOC2012", "trainval"), 42 | ("voc_2012_train", "VOC2012", "train"), 43 | ("voc_2012_val", "VOC2012", "val"), 44 | ("voc_2012_test", "VOC2012", "test"), 45 | ] 46 | for name, dirname, split in SPLITS: 47 | year = 2007 if "2007" in name else 2012 48 | register_pascal_voc(name, os.path.join(root, dirname), split, year) 49 | MetadataCatalog.get(name).evaluator_type = "pascal_voc" 50 | 51 | 52 | # ==== Predefined datasets and splits for ImageNet ========== 53 | 54 | _PREDEFINED_SPLITS_ImageNet = {} 55 | _PREDEFINED_SPLITS_ImageNet["imagenet"] = { 56 | "ilsvrc_2012_val": ( 57 | "ILSVRC2012/val/", 58 | "ILSVRC2012/ILSVRC2012_img_val_converted.json", 59 | ), 60 | "ilsvrc_2012_train": ( 61 | "ILSVRC2012/train/", 62 | "ILSVRC2012/ILSVRC2012_img_train_converted.json", 63 | ), 64 | } 65 | 66 | def register_all_imagenet(root): 67 | for dataset_name, splits_per_dataset in _PREDEFINED_SPLITS_ImageNet.items(): 68 | for key, (image_root, json_file) in splits_per_dataset.items(): 69 | # Assume pre-defined datasets live in `./datasets`. 70 | register_coco_instances( 71 | key, 72 | _get_builtin_metadata(dataset_name), 73 | os.path.join(root, json_file) if "://" not in json_file else json_file, 74 | os.path.join(root, image_root), 75 | ) 76 | 77 | # True for open source; 78 | # Internally at fb, we register them elsewhere 79 | if __name__.endswith(".builtin"): 80 | # Assume pre-defined datasets live in `./datasets`. 81 | _root = os.getenv("WSOVOD_DATASETS", "datasets") 82 | register_all_pascal_voc(_root) 83 | register_all_imagenet(_root) -------------------------------------------------------------------------------- /wsovod/data/datasets/pascal_voc.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Facebook, Inc. and its affiliates. 
3 | 4 | import os 5 | import xml.etree.ElementTree as ET 6 | from typing import List, Tuple, Union 7 | 8 | import numpy as np 9 | from detectron2.data import DatasetCatalog, MetadataCatalog 10 | from detectron2.structures import BoxMode 11 | from detectron2.utils.file_io import PathManager 12 | from PIL import Image 13 | 14 | __all__ = ["load_voc_instances", "register_pascal_voc"] 15 | 16 | 17 | # fmt: off 18 | CLASS_NAMES = ( 19 | "aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", 20 | "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", 21 | "pottedplant", "sheep", "sofa", "train", "tvmonitor" 22 | ) 23 | # fmt: on 24 | 25 | 26 | def load_voc_instances(dirname: str, split: str, class_names: Union[List[str], Tuple[str, ...]]): 27 | """ 28 | Load Pascal VOC detection annotations to Detectron2 format. 29 | 30 | Args: 31 | dirname: Contains "Annotations", "ImageSets", "JPEGImages" 32 | split (str): one of "train", "test", "val", "trainval" 33 | class_names: list or tuple of class names 34 | """ 35 | with PathManager.open(os.path.join(dirname, "ImageSets", "Main", split + ".txt")) as f: 36 | fileids = np.loadtxt(f, dtype=str) 37 | 38 | # Needs to read many small annotation files, so resolve a local copy of the directory first. 39 | annotation_dirname = PathManager.get_local_path(os.path.join(dirname, "Annotations/")) 40 | dicts = [] 41 | for fileid in fileids: 42 | anno_file = os.path.join(annotation_dirname, fileid + ".xml") 43 | jpeg_file = os.path.join(dirname, "JPEGImages", fileid + ".jpg") 44 | 45 | if not os.path.isfile(anno_file): 46 | with Image.open(jpeg_file) as img: 47 | width, height = img.size 48 | r = {"file_name": jpeg_file, "image_id": fileid, "height": height, "width": width} 49 | instances = [] 50 | r["annotations"] = instances 51 | dicts.append(r) 52 | continue 53 | 54 | with PathManager.open(anno_file) as f: 55 | tree = ET.parse(f) 56 | 57 | r = { 58 | "file_name": jpeg_file, 59 | "image_id": fileid, 60 | "height": int(tree.findall("./size/height")[0].text), 61 | "width": int(tree.findall("./size/width")[0].text), 62 | } 63 | instances = [] 64 | 65 | for obj in tree.findall("object"): 66 | cls = obj.find("name").text 67 | # Skip "difficult" samples here; the upstream detectron2 loader keeps them, 68 | # based on limited experiments showing they don't hurt accuracy. 69 | difficult = int(obj.find("difficult").text) 70 | if difficult == 1: 71 | continue 72 | bbox = obj.find("bndbox") 73 | bbox = [float(bbox.find(x).text) for x in ["xmin", "ymin", "xmax", "ymax"]] 74 | # Original annotations are integers in the range [1, W or H] 75 | # Assuming they mean 1-based pixel indices (inclusive), 76 | # a box with annotation (xmin=1, xmax=W) covers the whole image. 
77 | # In coordinate space this is represented by (xmin=0, xmax=W) 78 | bbox[0] -= 1.0 79 | bbox[1] -= 1.0 80 | instances.append( 81 | {"category_id": class_names.index(cls), "bbox": bbox, "bbox_mode": BoxMode.XYXY_ABS} 82 | ) 83 | r["annotations"] = instances 84 | dicts.append(r) 85 | return dicts 86 | 87 | 88 | def register_pascal_voc(name, dirname, split, year, class_names=CLASS_NAMES): 89 | if name in DatasetCatalog: 90 | DatasetCatalog.remove(name) 91 | DatasetCatalog.register(name, lambda: load_voc_instances(dirname, split, class_names)) 92 | MetadataCatalog.get(name).set( 93 | thing_classes=list(class_names), dirname=dirname, year=year, split=split 94 | ) 95 | -------------------------------------------------------------------------------- /wsovod/data/samplers/__init__.py: -------------------------------------------------------------------------------- 1 | from .distributed_sampler_multi_dataset import MultiDatasetTrainingSampler, InferenceSampler 2 | 3 | __all__ = [ 4 | "MultiDatasetTrainingSampler", 5 | "InferenceSampler", 6 | ] -------------------------------------------------------------------------------- /wsovod/data/samplers/distributed_sampler_multi_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | import itertools 3 | import logging 4 | import math 5 | from collections import defaultdict 6 | from typing import Optional 7 | 8 | import torch 9 | from torch.utils.data.sampler import Sampler 10 | 11 | from detectron2.data.samplers import RepeatFactorTrainingSampler 12 | from detectron2.utils import comm 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | class MultiDatasetTrainingSampler(Sampler): 18 | def __init__(self, repeat_factors, *, shuffle=True, seed=None): 19 | self._shuffle = shuffle 20 | if seed is None: 21 | seed = comm.shared_random_seed() 22 | self._seed = int(seed) 23 | 24 | self._rank = comm.get_rank() 25 | self._world_size = comm.get_world_size() 26 | 27 | # Split into whole number (_int_part) and fractional (_frac_part) parts. 
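# Illustrative example (made-up value): a repeat factor of 2.3 keeps an image twice in
# every epoch and adds a third copy with probability 0.3, so its expected number of
# copies matches the fractional factor; the stochastic rounding itself happens in
# _get_epoch_indices() below.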
28 | self._int_part = torch.trunc(repeat_factors) 29 | self._frac_part = repeat_factors - self._int_part 30 | 31 | @staticmethod 32 | def get_repeat_factors( 33 | dataset_dicts, num_datasets, dataset_ratio, use_rfs, use_cas, repeat_thresh, cas_lambda 34 | ): 35 | sizes = [0 for _ in range(num_datasets)] 36 | for d in dataset_dicts: 37 | sizes[d["dataset_id"]] += 1 38 | 39 | assert len(dataset_ratio) == len( 40 | sizes 41 | ), "length of dataset ratio {} should be equal to number if dataset {}".format( 42 | len(dataset_ratio), len(sizes) 43 | ) 44 | dataset_weight = [ 45 | torch.ones(s, dtype=torch.float32) * max(sizes) / s * r 46 | for i, (r, s) in enumerate(zip(dataset_ratio, sizes)) 47 | ] 48 | 49 | logger = logging.getLogger(__name__) 50 | logger.info( 51 | "Training sampler dataset weight: {}".format( 52 | str([max(sizes) / s * r for i, (r, s) in enumerate(zip(dataset_ratio, sizes))]) 53 | ) 54 | ) 55 | 56 | st = 0 57 | repeat_factors = [] 58 | for i, s in enumerate(sizes): 59 | assert use_rfs[i] * use_cas[i] == 0 60 | if use_rfs[i]: 61 | repeat_factor = RepeatFactorTrainingSampler.repeat_factors_from_category_frequency( 62 | dataset_dicts[st : st + s], repeat_thresh 63 | ) 64 | elif use_cas[i]: 65 | repeat_factor = MultiDatasetTrainingSampler.get_class_balance_factor_per_dataset( 66 | dataset_dicts[st : st + s], l=cas_lambda 67 | ) 68 | repeat_factor = repeat_factor * (s / repeat_factor.sum()) 69 | else: 70 | repeat_factor = torch.ones(s, dtype=torch.float32) 71 | logger.info( 72 | "Training sampler class weight: {} {} {}".format( 73 | repeat_factor.size(), repeat_factor.max(), repeat_factor.min() 74 | ) 75 | ) 76 | repeat_factors.append(repeat_factor) 77 | st = st + s 78 | repeat_factors = torch.cat(repeat_factors) 79 | dataset_weight = torch.cat(dataset_weight) 80 | repeat_factors = dataset_weight * repeat_factors 81 | 82 | return repeat_factors 83 | 84 | @staticmethod 85 | def get_class_balance_factor_per_dataset(dataset_dicts, l=1.0): 86 | rep_factors = [] 87 | category_freq = defaultdict(int) 88 | for dataset_dict in dataset_dicts: # For each image (without repeats) 89 | cat_ids = {ann["category_id"] for ann in dataset_dict["annotations"]} 90 | for cat_id in cat_ids: 91 | category_freq[cat_id] += 1 92 | for dataset_dict in dataset_dicts: 93 | cat_ids = {ann["category_id"] for ann in dataset_dict["annotations"]} 94 | rep_factor = sum([1.0 / (category_freq[cat_id] ** l) for cat_id in cat_ids]) 95 | rep_factors.append(rep_factor) 96 | 97 | return torch.tensor(rep_factors, dtype=torch.float32) 98 | 99 | def _get_epoch_indices(self, generator): 100 | """ 101 | Create a list of dataset indices (with repeats) to use for one epoch. 102 | 103 | Args: 104 | generator (torch.Generator): pseudo random number generator used for 105 | stochastic rounding. 106 | 107 | Returns: 108 | torch.Tensor: list of dataset indices to use in one epoch. Each index 109 | is repeated based on its calculated repeat factor. 
110 | """ 111 | # Since repeat factors are fractional, we use stochastic rounding so 112 | # that the target repeat factor is achieved in expectation over the 113 | # course of training 114 | rands = torch.rand(len(self._frac_part), generator=generator) 115 | rep_factors = self._int_part + (rands < self._frac_part).float() 116 | # Construct a list of indices in which we repeat images as specified 117 | indices = [] 118 | for dataset_index, rep_factor in enumerate(rep_factors): 119 | indices.extend([dataset_index] * int(rep_factor.item())) 120 | return torch.tensor(indices, dtype=torch.int64) 121 | 122 | def __iter__(self): 123 | start = self._rank 124 | yield from itertools.islice(self._infinite_indices(), start, None, self._world_size) 125 | 126 | def _infinite_indices(self): 127 | g = torch.Generator() 128 | g.manual_seed(self._seed) 129 | while True: 130 | # Sample indices with repeats determined by stochastic rounding; each 131 | # "epoch" may have a slightly different size due to the rounding. 132 | indices = self._get_epoch_indices(g) 133 | if self._shuffle: 134 | randperm = torch.randperm(len(indices), generator=g) 135 | yield from indices[randperm].tolist() 136 | else: 137 | yield from indices.tolist() 138 | 139 | 140 | class InferenceSampler(Sampler): 141 | """ 142 | Produce indices for inference across all workers. 143 | Inference needs to run on the __exact__ set of samples, 144 | therefore when the total number of samples is not divisible by the number of workers, 145 | this sampler produces different number of samples on different workers. 146 | """ 147 | 148 | def __init__(self, size: int): 149 | """ 150 | Args: 151 | size (int): the total number of data of the underlying dataset to sample from 152 | """ 153 | self._size = size 154 | assert size > 0 155 | self._rank = comm.get_rank() 156 | self._world_size = comm.get_world_size() 157 | self._local_indices = self._get_local_indices(size, self._world_size, self._rank) 158 | 159 | @staticmethod 160 | def _get_local_indices(total_size, world_size, rank): 161 | shard_size = total_size // world_size 162 | left = total_size % world_size 163 | shard_sizes = [shard_size + int(r < left) for r in range(world_size)] 164 | 165 | begin = sum(shard_sizes[:rank]) 166 | end = min(sum(shard_sizes[: rank + 1]), total_size) 167 | if end - begin < max(shard_sizes): 168 | assert begin > 0 169 | begin = begin - 1 170 | return range(begin, end) 171 | 172 | def __iter__(self): 173 | yield from self._local_indices 174 | 175 | def __len__(self): 176 | return len(self._local_indices) -------------------------------------------------------------------------------- /wsovod/engine/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | from .hooks import * 3 | from .trainer import * 4 | -------------------------------------------------------------------------------- /wsovod/engine/hooks.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Facebook, Inc. and its affiliates. 
3 | 4 | import datetime 5 | import itertools 6 | import logging 7 | import math 8 | import operator 9 | import os 10 | import tempfile 11 | import time 12 | import warnings 13 | from collections import Counter 14 | 15 | import detectron2.utils.comm as comm 16 | import torch 17 | from detectron2.engine.train_loop import HookBase 18 | from detectron2.evaluation.testing import flatten_results_dict 19 | from detectron2.solver import LRMultiplier 20 | from detectron2.utils.events import EventStorage, EventWriter 21 | from detectron2.utils.file_io import PathManager 22 | from fvcore.common.checkpoint import Checkpointer 23 | from fvcore.common.checkpoint import \ 24 | PeriodicCheckpointer as _PeriodicCheckpointer 25 | from fvcore.common.param_scheduler import ParamScheduler 26 | from fvcore.common.timer import Timer 27 | from fvcore.nn.precise_bn import get_bn_modules, update_bn_stats 28 | 29 | # __all__ = [ 30 | # "CallbackHook", 31 | # "IterationTimer", 32 | # "PeriodicWriter", 33 | # "PeriodicCheckpointer", 34 | # "BestCheckpointer", 35 | # "LRScheduler", 36 | # "AutogradProfiler", 37 | # "EvalHook", 38 | # "PreciseBN", 39 | # "TorchProfiler", 40 | # "TorchMemoryStats", 41 | # ] 42 | 43 | 44 | """ 45 | Implement some common hooks. 46 | """ 47 | 48 | class ParametersNormInspectHook(HookBase): 49 | 50 | def __init__(self, period, model, p): 51 | self._period = period 52 | self._model = model 53 | self._p = p 54 | logger = logging.getLogger(__name__) 55 | logger.info('period, norm '+str((period,p))) 56 | 57 | @torch.no_grad() 58 | def _do_inspect(self): 59 | results = {} 60 | # logger = logging.getLogger(__name__) 61 | for key,val in self._model.named_parameters(recurse=True): 62 | results[key] = torch.norm(val,p=self._p) 63 | self.trainer.storage.put_scalar('parameters norm {}/{}'.format(self._p,key),torch.norm(val,p=self._p)) 64 | # logger.info(results) 65 | 66 | def after_step(self): 67 | next_iter = self.trainer.iter + 1 68 | if self._period > 0 and next_iter % self._period == 0: 69 | if next_iter != self.trainer.max_iter: 70 | self._do_inspect() 71 | -------------------------------------------------------------------------------- /wsovod/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | from .pascal_voc_evaluation import PascalVOCDetectionEvaluator_WSL 2 | from .coco_evaluation import COCOEvaluator 3 | 4 | __all__ = [k for k in globals().keys() if not k.startswith("_")] 5 | -------------------------------------------------------------------------------- /wsovod/layers/ROILoopPool/ROILoopPool.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | #pragma once 3 | #include 4 | 5 | namespace wsovod { 6 | 7 | std::tuple ROILoopPool_forward_cpu( 8 | const at::Tensor& input, 9 | const at::Tensor& rois, 10 | const float spatial_scale, 11 | const int pooled_height, 12 | const int pooled_width); 13 | 14 | at::Tensor ROILoopPool_backward_cpu( 15 | const at::Tensor& grad, 16 | const at::Tensor& rois, 17 | const at::Tensor& argmax, 18 | const float spatial_scale, 19 | const int pooled_height, 20 | const int pooled_width, 21 | const int batch_size, 22 | const int channels, 23 | const int height, 24 | const int width); 25 | 26 | #if defined(WITH_CUDA) || defined(WITH_HIP) 27 | std::tuple ROILoopPool_forward_cuda( 28 | const at::Tensor& input, 29 | const at::Tensor& rois, 30 | const float spatial_scale, 31 | const int pooled_height, 32 | const int pooled_width); 33 | 34 | at::Tensor ROILoopPool_backward_cuda( 35 | const at::Tensor& grad, 36 | const at::Tensor& rois, 37 | const at::Tensor& argmax, 38 | const float spatial_scale, 39 | const int pooled_height, 40 | const int pooled_width, 41 | const int batch_size, 42 | const int channels, 43 | const int height, 44 | const int width); 45 | #endif 46 | 47 | // Interface for Python 48 | inline std::tuple ROILoopPool_forward( 49 | const at::Tensor& input, 50 | const at::Tensor& rois, 51 | const float spatial_scale, 52 | const int pooled_height, 53 | const int pooled_width) { 54 | if (input.is_cuda()) { 55 | #if defined(WITH_CUDA) || defined(WITH_HIP) 56 | return ROILoopPool_forward_cuda( 57 | input, rois, spatial_scale, pooled_height, pooled_width); 58 | #else 59 | AT_ERROR("Not compiled with GPU support"); 60 | #endif 61 | } 62 | AT_ERROR("Not compiled with CPU support"); 63 | return ROILoopPool_forward_cpu( 64 | input, rois, spatial_scale, pooled_height, pooled_width); 65 | } 66 | 67 | inline at::Tensor ROILoopPool_backward( 68 | const at::Tensor& grad, 69 | const at::Tensor& rois, 70 | const at::Tensor& argmax, 71 | const float spatial_scale, 72 | const int pooled_height, 73 | const int pooled_width, 74 | const int batch_size, 75 | const int channels, 76 | const int height, 77 | const int width) { 78 | if (grad.is_cuda()) { 79 | #if defined(WITH_CUDA) || defined(WITH_HIP) 80 | return ROILoopPool_backward_cuda( 81 | grad, 82 | rois, 83 | argmax, 84 | spatial_scale, 85 | pooled_height, 86 | pooled_width, 87 | batch_size, 88 | channels, 89 | height, 90 | width); 91 | #else 92 | AT_ERROR("Not compiled with GPU support"); 93 | #endif 94 | } 95 | AT_ERROR("Not compiled with CPU support"); 96 | return ROILoopPool_backward_cpu( 97 | grad, 98 | rois, 99 | argmax, 100 | spatial_scale, 101 | pooled_height, 102 | pooled_width, 103 | batch_size, 104 | channels, 105 | height, 106 | width); 107 | } 108 | 109 | } // namespace wsovod 110 | -------------------------------------------------------------------------------- /wsovod/layers/ROILoopPool/ROILoopPool_cpu.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | #include "ROILoopPool.h" 7 | 8 | template 9 | inline void add(T* address, const T& val) { 10 | *address += val; 11 | } 12 | 13 | template 14 | void RoILoopPoolForward( 15 | const T* input, 16 | const T spatial_scale, 17 | const int channels, 18 | const int height, 19 | const int width, 20 | const int pooled_height, 21 | const int pooled_width, 22 | const T* rois, 23 | const int num_rois, 24 | T* output, 25 | int* argmax_data) { 26 | for (int n = 0; n < num_rois; ++n) { 27 | const T* 
offset_rois = rois + n * 5; 28 | int roi_batch_ind = offset_rois[0]; 29 | int roi_start_w = round(offset_rois[1] * spatial_scale); 30 | int roi_start_h = round(offset_rois[2] * spatial_scale); 31 | int roi_end_w = round(offset_rois[3] * spatial_scale); 32 | int roi_end_h = round(offset_rois[4] * spatial_scale); 33 | 34 | // Force malformed ROIs to be 1x1 35 | int roi_width = std::max(roi_end_w - roi_start_w + 1, 1); 36 | int roi_height = std::max(roi_end_h - roi_start_h + 1, 1); 37 | T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); 38 | T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); 39 | 40 | for (int ph = 0; ph < pooled_height; ++ph) { 41 | for (int pw = 0; pw < pooled_width; ++pw) { 42 | int hstart = static_cast(floor(static_cast(ph) * bin_size_h)); 43 | int wstart = static_cast(floor(static_cast(pw) * bin_size_w)); 44 | int hend = static_cast(ceil(static_cast(ph + 1) * bin_size_h)); 45 | int wend = static_cast(ceil(static_cast(pw + 1) * bin_size_w)); 46 | 47 | // Add roi offsets and clip to input boundaries 48 | hstart = std::min(std::max(hstart + roi_start_h, 0), height); 49 | hend = std::min(std::max(hend + roi_start_h, 0), height); 50 | wstart = std::min(std::max(wstart + roi_start_w, 0), width); 51 | wend = std::min(std::max(wend + roi_start_w, 0), width); 52 | bool is_empty = (hend <= hstart) || (wend <= wstart); 53 | 54 | for (int c = 0; c < channels; ++c) { 55 | // Define an empty pooling region to be zero 56 | T maxval = is_empty ? 0 : -FLT_MAX; 57 | // If nothing is pooled, argmax = -1 causes nothing to be backprop'd 58 | int maxidx = -1; 59 | 60 | const T* input_offset = 61 | input + (roi_batch_ind * channels + c) * height * width; 62 | 63 | for (int h = hstart; h < hend; ++h) { 64 | for (int w = wstart; w < wend; ++w) { 65 | int input_index = h * width + w; 66 | if (input_offset[input_index] > maxval) { 67 | maxval = input_offset[input_index]; 68 | maxidx = input_index; 69 | } 70 | } 71 | } 72 | int index = 73 | ((n * channels + c) * pooled_height + ph) * pooled_width + pw; 74 | output[index] = maxval; 75 | argmax_data[index] = maxidx; 76 | } // channels 77 | } // pooled_width 78 | } // pooled_height 79 | } // num_rois 80 | } 81 | 82 | template 83 | void RoILoopPoolBackward( 84 | const T* grad_output, 85 | const int* argmax_data, 86 | const int num_rois, 87 | const int channels, 88 | const int height, 89 | const int width, 90 | const int pooled_height, 91 | const int pooled_width, 92 | T* grad_input, 93 | const T* rois, 94 | const int n_stride, 95 | const int c_stride, 96 | const int h_stride, 97 | const int w_stride) { 98 | for (int n = 0; n < num_rois; ++n) { 99 | const T* offset_rois = rois + n * 5; 100 | int roi_batch_ind = offset_rois[0]; 101 | 102 | for (int c = 0; c < channels; ++c) { 103 | T* grad_input_offset = 104 | grad_input + ((roi_batch_ind * channels + c) * height * width); 105 | const int* argmax_data_offset = 106 | argmax_data + (n * channels + c) * pooled_height * pooled_width; 107 | 108 | for (int ph = 0; ph < pooled_height; ++ph) { 109 | for (int pw = 0; pw < pooled_width; ++pw) { 110 | int output_offset = n * n_stride + c * c_stride; 111 | int argmax = argmax_data_offset[ph * pooled_width + pw]; 112 | 113 | if (argmax != -1) { 114 | add(grad_input_offset + argmax, 115 | static_cast( 116 | grad_output 117 | [output_offset + ph * h_stride + pw * w_stride])); 118 | } 119 | } // pooled_width 120 | } // pooled_height 121 | } // channels 122 | } // num_rois 123 | } 124 | 125 | namespace wsovod { 126 | 127 | std::tuple 
ROILoopPool_forward_cpu( 128 | const at::Tensor& input, 129 | const at::Tensor& rois, 130 | const float spatial_scale, 131 | const int pooled_height, 132 | const int pooled_width) { 133 | AT_ASSERTM(input.device().is_cpu(), "input must be a CPU tensor"); 134 | AT_ASSERTM(rois.device().is_cpu(), "rois must be a CPU tensor"); 135 | 136 | at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2}; 137 | 138 | at::CheckedFrom c = "ROILoopPool_forward_cpu"; 139 | at::checkAllSameType(c, {input_t, rois_t}); 140 | 141 | int num_rois = rois.size(0); 142 | int channels = input.size(1); 143 | int height = input.size(2); 144 | int width = input.size(3); 145 | 146 | at::Tensor output = at::zeros( 147 | {num_rois, channels, pooled_height, pooled_width}, input.options()); 148 | at::Tensor argmax = at::zeros( 149 | {num_rois, channels, pooled_height, pooled_width}, 150 | input.options().dtype(at::kInt)); 151 | 152 | if (output.numel() == 0) { 153 | return std::make_tuple(output, argmax); 154 | } 155 | 156 | auto input_ = input.contiguous(), rois_ = rois.contiguous(); 157 | AT_DISPATCH_FLOATING_TYPES_AND_HALF( 158 | input.scalar_type(), "ROILoopPool_forward", [&] { 159 | RoILoopPoolForward( 160 | input_.data_ptr(), 161 | spatial_scale, 162 | channels, 163 | height, 164 | width, 165 | pooled_height, 166 | pooled_width, 167 | rois_.data_ptr(), 168 | num_rois, 169 | output.data_ptr(), 170 | argmax.data_ptr()); 171 | }); 172 | return std::make_tuple(output, argmax); 173 | } 174 | 175 | at::Tensor ROILoopPool_backward_cpu( 176 | const at::Tensor& grad, 177 | const at::Tensor& rois, 178 | const at::Tensor& argmax, 179 | const float spatial_scale, 180 | const int pooled_height, 181 | const int pooled_width, 182 | const int batch_size, 183 | const int channels, 184 | const int height, 185 | const int width) { 186 | // Check if input tensors are CPU tensors 187 | AT_ASSERTM(grad.device().is_cpu(), "grad must be a CPU tensor"); 188 | AT_ASSERTM(rois.device().is_cpu(), "rois must be a CPU tensor"); 189 | AT_ASSERTM(argmax.device().is_cpu(), "argmax must be a CPU tensor"); 190 | 191 | at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2}; 192 | 193 | at::CheckedFrom c = "ROILoopPool_backward_cpu"; 194 | at::checkAllSameType(c, {grad_t, rois_t}); 195 | 196 | auto num_rois = rois.size(0); 197 | 198 | at::Tensor grad_input = 199 | at::zeros({batch_size, channels, height, width}, grad.options()); 200 | 201 | // handle possibly empty gradients 202 | if (grad.numel() == 0) { 203 | return grad_input; 204 | } 205 | 206 | // get stride values to ensure indexing into gradients is correct. 
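  // The incoming gradient is not made contiguous here, so the backward loop indexes it
  // through these strides rather than assuming a densely packed NCHW layout.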
207 | int n_stride = grad.stride(0); 208 | int c_stride = grad.stride(1); 209 | int h_stride = grad.stride(2); 210 | int w_stride = grad.stride(3); 211 | 212 | auto rois_ = rois.contiguous(); 213 | AT_DISPATCH_FLOATING_TYPES_AND_HALF( 214 | grad.scalar_type(), "ROILoopPool_backward", [&] { 215 | RoILoopPoolBackward( 216 | grad.data_ptr(), 217 | argmax.data_ptr(), 218 | num_rois, 219 | channels, 220 | height, 221 | width, 222 | pooled_height, 223 | pooled_width, 224 | grad_input.data_ptr(), 225 | rois_.data_ptr(), 226 | n_stride, 227 | c_stride, 228 | h_stride, 229 | w_stride); 230 | }); 231 | return grad_input; 232 | } 233 | 234 | } // namespace wsovod 235 | -------------------------------------------------------------------------------- /wsovod/layers/ROILoopPool/cuda_helpers.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #define CUDA_1D_KERNEL_LOOP(i, n) \ 4 | for (int i = (blockIdx.x * blockDim.x) + threadIdx.x; i < (n); \ 5 | i += (blockDim.x * gridDim.x)) 6 | 7 | template 8 | constexpr __host__ __device__ inline integer ceil_div(integer n, integer m) { 9 | return (n + m - 1) / m; 10 | } 11 | -------------------------------------------------------------------------------- /wsovod/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .roi_loop_pool import ROILoopPool 2 | from .csc import CSC, CSCConstraint, csc, csc_constraint 3 | -------------------------------------------------------------------------------- /wsovod/layers/csc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.autograd import Function 4 | from torch.autograd.function import once_differentiable 5 | 6 | from wsovod import _C 7 | 8 | 9 | class _CSC(Function): 10 | @staticmethod 11 | def forward( 12 | ctx, 13 | cpgs, 14 | labels, 15 | preds, 16 | rois, 17 | tau, 18 | debug_info, 19 | fg_threshold, 20 | mass_threshold, 21 | density_threshold, 22 | area_sqrt, 23 | context_scale, 24 | ): 25 | PL = labels.clone().detach() 26 | NL = torch.zeros(labels.size(), dtype=labels.dtype, device=labels.device) 27 | 28 | W = _C.csc_forward( 29 | cpgs, 30 | labels, 31 | preds, 32 | rois, 33 | tau, 34 | debug_info, 35 | fg_threshold, 36 | mass_threshold, 37 | density_threshold, 38 | area_sqrt, 39 | context_scale, 40 | ) 41 | 42 | return W, PL, NL 43 | 44 | @staticmethod 45 | @once_differentiable 46 | def backward(ctx, dW, dPL, dNL): 47 | return None, None, None, None 48 | 49 | 50 | csc = _CSC.apply 51 | 52 | 53 | class CSC(nn.Module): 54 | def __init__( 55 | self, 56 | tau=0.7, 57 | debug_info=False, 58 | fg_threshold=0.1, 59 | mass_threshold=0.2, 60 | density_threshold=0.0, 61 | area_sqrt=True, 62 | context_scale=1.8, 63 | ): 64 | super(CSC, self).__init__() 65 | 66 | self.tau = tau 67 | self.debug_info = debug_info 68 | self.fg_threshold = fg_threshold 69 | self.mass_threshold = mass_threshold 70 | self.density_threshold = density_threshold 71 | self.area_sqrt = area_sqrt 72 | self.context_scale = context_scale 73 | 74 | def forward(self, cpgs, labels, preds, rois): 75 | return csc( 76 | cpgs, 77 | labels, 78 | preds, 79 | rois, 80 | self.tau, 81 | self.debug_info, 82 | self.fg_threshold, 83 | self.mass_threshold, 84 | self.density_threshold, 85 | self.area_sqrt, 86 | self.context_scale, 87 | ) 88 | 89 | def __repr__(self): 90 | tmpstr = self.__class__.__name__ + "(" 91 | tmpstr += "tau=" + str(self.tau) 92 | tmpstr += ", 
debug_info=" + str(self.debug_info) 93 | tmpstr += ", fg_threshold=" + str(self.fg_threshold) 94 | tmpstr += ", mass_threshold=" + str(self.mass_threshold) 95 | tmpstr += ", density_threshold=" + str(self.density_threshold) 96 | tmpstr += ", area_sqrt=" + str(self.area_sqrt) 97 | tmpstr += ", context_scale=" + str(self.context_scale) 98 | tmpstr += ")" 99 | return tmpstr 100 | 101 | 102 | class _CSCConstraint(Function): 103 | @staticmethod 104 | def forward(ctx, X, W, polar): 105 | 106 | if polar: 107 | W_ = torch.clamp(W, min=0.0) 108 | else: 109 | W_ = torch.clamp(W, max=0.0) 110 | W_ = W_ * (-1.0) 111 | 112 | ctx.save_for_backward(W_) 113 | 114 | Y = X * W_ 115 | 116 | return Y 117 | 118 | @staticmethod 119 | @once_differentiable 120 | def backward(ctx, dY): 121 | (W_,) = ctx.saved_tensors 122 | 123 | dX = dY * W_ 124 | 125 | return dX, None, None 126 | 127 | 128 | csc_constraint = _CSCConstraint.apply 129 | 130 | 131 | class CSCConstraint(nn.Module): 132 | def __init__(self, polar=True): 133 | super(CSCConstraint, self).__init__() 134 | 135 | self.polar = polar 136 | 137 | def forward(self, X, W): 138 | return csc_constraint(X, W, self.polar) 139 | 140 | def __repr__(self): 141 | tmpstr = self.__class__.__name__ + "(" 142 | tmpstr += "polar=" + str(self.polar) 143 | tmpstr += ")" 144 | return tmpstr -------------------------------------------------------------------------------- /wsovod/layers/csc/csc.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | namespace wsovod { 5 | 6 | #ifdef WITH_CUDA 7 | at::Tensor csc_forward_cuda( 8 | const at::Tensor& cpgs, 9 | const at::Tensor& labels, 10 | const at::Tensor& preds, 11 | const at::Tensor& rois, 12 | const float tau_, 13 | const bool debug_info_, 14 | const float fg_threshold_, 15 | const float mass_threshold_, 16 | const float density_threshold_, 17 | const bool area_sqrt_, 18 | const float context_scale_); 19 | #endif 20 | 21 | // Interface for Python 22 | inline at::Tensor csc_forward( 23 | const at::Tensor& cpgs, 24 | const at::Tensor& labels, 25 | const at::Tensor& preds, 26 | const at::Tensor& rois, 27 | const float tau_ = 0.7, 28 | const bool debug_info_ = false, 29 | const float fg_threshold_ = 0.1, 30 | const float mass_threshold_ = 0.2, 31 | const float density_threshold_ = 0.0, 32 | const bool area_sqrt_ = true, 33 | const float context_scale_ = 1.8) { 34 | if (cpgs.device().type() == at::kCUDA) { 35 | #ifdef WITH_CUDA 36 | return csc_forward_cuda( 37 | cpgs, 38 | labels, 39 | preds, 40 | rois, 41 | tau_, 42 | debug_info_, 43 | fg_threshold_, 44 | mass_threshold_, 45 | density_threshold_, 46 | area_sqrt_, 47 | context_scale_); 48 | #else 49 | AT_ERROR("Not compiled with GPU support"); 50 | #endif 51 | } 52 | AT_ERROR("Not compiled with CPU support"); 53 | } 54 | 55 | } // namespace wsovod 56 | -------------------------------------------------------------------------------- /wsovod/layers/roi_loop_pool.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.autograd import Function 3 | from torch.autograd.function import once_differentiable 4 | from torch.nn.modules.utils import _pair 5 | 6 | from wsovod import _C 7 | 8 | 9 | class _ROILoopPool(Function): 10 | @staticmethod 11 | def forward(ctx, input, roi, output_size, spatial_scale): 12 | ctx.output_size = _pair(output_size) 13 | ctx.spatial_scale = spatial_scale 14 | ctx.input_shape = input.size() 15 | output, argmax = 
_C.roi_loop_pool_forward( 16 | input, roi, spatial_scale, output_size[0], output_size[1] 17 | ) 18 | ctx.save_for_backward(roi, argmax) 19 | ctx.mark_non_differentiable(argmax) 20 | return output 21 | 22 | @staticmethod 23 | @once_differentiable 24 | def backward(ctx, grad_output): 25 | (rois, argmax) = ctx.saved_tensors 26 | output_size = ctx.output_size 27 | spatial_scale = ctx.spatial_scale 28 | bs, ch, h, w = ctx.input_shape 29 | grad_input = _C.roi_loop_pool_backward( 30 | grad_output, rois, argmax, spatial_scale, output_size[0], output_size[1], bs, ch, h, w 31 | ) 32 | return grad_input, None, None, None 33 | 34 | 35 | roi_loop_pool = _ROILoopPool.apply 36 | 37 | 38 | class ROILoopPool(nn.Module): 39 | def __init__(self, output_size, spatial_scale): 40 | super(ROILoopPool, self).__init__() 41 | self.output_size = output_size 42 | self.spatial_scale = spatial_scale 43 | 44 | def forward(self, input, rois): 45 | """ 46 | Args: 47 | input: NCHW images 48 | rois: Bx5 boxes. First column is the index into N. The other 4 columns are xyxy. 49 | """ 50 | assert rois.dim() == 2 and rois.size(1) == 5 51 | return roi_loop_pool(input, rois, self.output_size, self.spatial_scale) 52 | 53 | def __repr__(self): 54 | tmpstr = self.__class__.__name__ + "(" 55 | tmpstr += "output_size=" + str(self.output_size) 56 | tmpstr += ", spatial_scale=" + str(self.spatial_scale) 57 | tmpstr += ")" 58 | return tmpstr 59 | -------------------------------------------------------------------------------- /wsovod/layers/vision.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | #include 4 | #include "ROILoopPool/ROILoopPool.h" 5 | #include "csc/csc.h" 6 | 7 | namespace wsovod { 8 | 9 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 10 | m.def("roi_loop_pool_forward", &ROILoopPool_forward, "ROILoopPool_forward"); 11 | m.def("roi_loop_pool_backward", &ROILoopPool_backward, "ROILoopPool_backward"); 12 | m.def("csc_forward", &csc_forward, "csc_forward"); 13 | } 14 | 15 | } // namespace wsovod 16 | -------------------------------------------------------------------------------- /wsovod/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | from .backbone import * 4 | from .meta_arch import * 5 | from .postprocessing import detector_postprocess 6 | from .roi_heads import * 7 | from .proposal_generator import * 8 | from .test_time_augmentation_avg import (DatasetMapperTTAAVG, 9 | GeneralizedRCNNWithTTAAVG) 10 | from .test_time_augmentation_union import (DatasetMapperTTAUNION, 11 | GeneralizedRCNNWithTTAUNION) 12 | 13 | 14 | 15 | _EXCLUDE = {"ShapeSpec"} 16 | __all__ = [k for k in globals().keys() if k not in _EXCLUDE and not k.startswith("_")] 17 | -------------------------------------------------------------------------------- /wsovod/modeling/backbone/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
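# Importing the modules below registers their backbone builders in detectron2's
# BACKBONE_REGISTRY (e.g. build_vgg_backbone is decorated with @BACKBONE_REGISTRY.register()
# in vgg.py), so a config picks a backbone by name. Minimal YAML sketch:
#
#   MODEL:
#     BACKBONE:
#       NAME: "build_vgg_backbone"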
2 | from .swin_transformer import build_swin_backbone, build_swin_fpn_backbone 3 | from .resnet_wsl import build_wsl_resnet_backbone 4 | from .resnet_wsl_mrrp import build_mrrp_wsl_resnet_backbone 5 | from .vgg import VGG16, PlainBlockBase, build_vgg_backbone 6 | from .vgg_mrrp import build_mrrp_vgg_backbone 7 | 8 | # TODO can expose more resnet blocks after careful consideration 9 | -------------------------------------------------------------------------------- /wsovod/modeling/backbone/mrrp_conv.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | import torch 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from torch.nn.modules.utils import _pair 6 | 7 | from detectron2.layers.wrappers import _NewEmptyTensorOp 8 | 9 | 10 | class MRRPConv(nn.Module): 11 | def __init__( 12 | self, 13 | in_channels, 14 | out_channels, 15 | kernel_size, 16 | stride=1, 17 | paddings=0, 18 | dilations=1, 19 | groups=1, 20 | num_branch=1, 21 | test_branch_idx=-1, 22 | bias=False, 23 | norm=None, 24 | activation=None, 25 | ): 26 | super(MRRPConv, self).__init__() 27 | self.in_channels = in_channels 28 | self.out_channels = out_channels 29 | self.kernel_size = _pair(kernel_size) 30 | self.num_branch = num_branch 31 | self.stride = _pair(stride) 32 | self.groups = groups 33 | self.with_bias = bias 34 | if isinstance(paddings, int): 35 | paddings = [paddings] * self.num_branch 36 | if isinstance(dilations, int): 37 | dilations = [dilations] * self.num_branch 38 | self.paddings = [_pair(padding) for padding in paddings] 39 | self.dilations = [_pair(dilation) for dilation in dilations] 40 | self.test_branch_idx = test_branch_idx 41 | self.norm = norm 42 | self.activation = activation 43 | 44 | assert len({self.num_branch, len(self.paddings), len(self.dilations)}) == 1 45 | 46 | self.weight = nn.Parameter( 47 | torch.Tensor(out_channels, in_channels // groups, *self.kernel_size) 48 | ) 49 | if bias: 50 | self.bias = nn.Parameter(torch.Tensor(out_channels)) 51 | else: 52 | self.bias = None 53 | 54 | nn.init.kaiming_uniform_(self.weight, nonlinearity="relu") 55 | if self.bias is not None: 56 | nn.init.constant_(self.bias, 0) 57 | 58 | def forward(self, inputs): 59 | num_branch = self.num_branch if self.training or self.test_branch_idx == -1 else 1 60 | assert len(inputs) == num_branch 61 | 62 | if inputs[0].numel() == 0: 63 | output_shape = [ 64 | (i + 2 * p - (di * (k - 1) + 1)) // s + 1 65 | for i, p, di, k, s in zip( 66 | inputs[0].shape[-2:], self.paddings[0], self.dilations[0], self.kernel_size, self.stride  # first branch's geometry 67 | ) 68 | ] 69 | output_shape = [inputs[0].shape[0], self.weight.shape[0]] + output_shape 70 | return [_NewEmptyTensorOp.apply(input, output_shape) for input in inputs] 71 | 72 | if self.training or self.test_branch_idx == -1: 73 | outputs = [ 74 | F.conv2d(input, self.weight, self.bias, self.stride, padding, dilation, self.groups) 75 | for input, dilation, padding in zip(inputs, self.dilations, self.paddings) 76 | ] 77 | else: 78 | outputs = [ 79 | F.conv2d( 80 | inputs[0], 81 | self.weight, 82 | self.bias, 83 | self.stride, 84 | self.paddings[self.test_branch_idx], 85 | self.dilations[self.test_branch_idx], 86 | self.groups, 87 | ) 88 | ] 89 | 90 | if self.norm is not None: 91 | outputs = [self.norm(x) for x in outputs] 92 | if self.activation is not None: 93 | outputs = [self.activation(x) for x in outputs] 94 | return outputs 95 | 96 | def extra_repr(self): 97 | tmpstr = "in_channels=" + 
str(self.in_channels) 98 | tmpstr += ", out_channels=" + str(self.out_channels) 99 | tmpstr += ", kernel_size=" + str(self.kernel_size) 100 | tmpstr += ", num_branch=" + str(self.num_branch) 101 | tmpstr += ", test_branch_idx=" + str(self.test_branch_idx) 102 | tmpstr += ", stride=" + str(self.stride) 103 | tmpstr += ", paddings=" + str(self.paddings) 104 | tmpstr += ", dilations=" + str(self.dilations) 105 | tmpstr += ", groups=" + str(self.groups) 106 | tmpstr += ", bias=" + str(self.with_bias) 107 | return tmpstr -------------------------------------------------------------------------------- /wsovod/modeling/backbone/vgg.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | import fvcore.nn.weight_init as weight_init 3 | import torch.nn.functional as F 4 | from detectron2.layers import Conv2d, FrozenBatchNorm2d, ShapeSpec 5 | from detectron2.modeling.backbone.backbone import Backbone 6 | from detectron2.modeling.backbone.build import BACKBONE_REGISTRY 7 | from torch import nn 8 | 9 | __all__ = ["PlainBlockBase", "PlainBlock", "VGG16", "build_vgg_backbone"] 10 | 11 | 12 | class PlainBlockBase(nn.Module): 13 | def __init__(self, in_channels, out_channels, stride): 14 | """ 15 | The `__init__` method of any subclass should also contain these arguments. 16 | 17 | Args: 18 | in_channels (int): 19 | out_channels (int): 20 | stride (int): 21 | """ 22 | super().__init__() 23 | self.in_channels = in_channels 24 | self.out_channels = out_channels 25 | self.stride = stride 26 | 27 | def freeze(self): 28 | for p in self.parameters(): 29 | p.requires_grad = False 30 | FrozenBatchNorm2d.convert_frozen_batchnorm(self) 31 | return self 32 | 33 | 34 | class PlainBlock(PlainBlockBase): 35 | def __init__(self, in_channels, out_channels, num_conv=3, dilation=1, stride=1, has_pool=False): 36 | super().__init__(in_channels, out_channels, stride) 37 | 38 | self.num_conv = num_conv 39 | self.dilation = dilation 40 | 41 | self.has_pool = has_pool 42 | self.pool_stride = stride 43 | 44 | self.conv1 = Conv2d( 45 | in_channels, 46 | out_channels, 47 | kernel_size=3, 48 | stride=1, 49 | padding=1 * dilation, 50 | bias=True, 51 | groups=1, 52 | dilation=dilation, 53 | norm=None, 54 | ) 55 | weight_init.c2_msra_fill(self.conv1) 56 | 57 | self.conv2 = Conv2d( 58 | out_channels, 59 | out_channels, 60 | kernel_size=3, 61 | stride=1, 62 | padding=1 * dilation, 63 | bias=True, 64 | groups=1, 65 | dilation=dilation, 66 | norm=None, 67 | ) 68 | weight_init.c2_msra_fill(self.conv2) 69 | 70 | if self.num_conv > 2: 71 | self.conv3 = Conv2d( 72 | out_channels, 73 | out_channels, 74 | kernel_size=3, 75 | stride=1, 76 | padding=1 * dilation, 77 | bias=True, 78 | groups=1, 79 | dilation=dilation, 80 | norm=None, 81 | ) 82 | weight_init.c2_msra_fill(self.conv3) 83 | 84 | if self.num_conv > 3: 85 | self.conv4 = Conv2d( 86 | out_channels, 87 | out_channels, 88 | kernel_size=3, 89 | stride=1, 90 | padding=1 * dilation, 91 | bias=True, 92 | groups=1, 93 | dilation=dilation, 94 | norm=None, 95 | ) 96 | weight_init.c2_msra_fill(self.conv4) 97 | 98 | if self.has_pool: 99 | self.pool = nn.MaxPool2d(kernel_size=2, stride=self.pool_stride, padding=0) 100 | 101 | assert num_conv < 5 102 | 103 | def forward(self, x): 104 | x = self.conv1(x) 105 | x = F.relu_(x) 106 | 107 | x = self.conv2(x) 108 | x = F.relu_(x) 109 | 110 | if self.num_conv > 2: 111 | x = self.conv3(x) 112 | x = F.relu_(x) 113 | 114 | if self.num_conv > 3: 115 | x = self.conv4(x) 116 | x 
= F.relu_(x) 117 | 118 | if self.has_pool: 119 | x = self.pool(x) 120 | 121 | return x 122 | 123 | 124 | class VGG16(Backbone): 125 | def __init__(self, conv5_dilation, freeze_at, num_classes=None, out_features=None): 126 | """ 127 | Args: 128 | stem (nn.Module): a stem module 129 | stages (list[list[ResNetBlock]]): several (typically 4) stages, 130 | each contains multiple :class:`ResNetBlockBase`. 131 | num_classes (None or int): if None, will not perform classification. 132 | out_features (list[str]): name of the layers whose outputs should 133 | be returned in forward. Can be anything in "stem", "linear", or "res2" ... 134 | If None, will return the output of the last layer. 135 | """ 136 | super(VGG16, self).__init__() 137 | 138 | self.num_classes = num_classes 139 | 140 | self._out_feature_strides = {} 141 | self._out_feature_channels = {} 142 | 143 | self.stages_and_names = [] 144 | 145 | name = "plain1" 146 | block = PlainBlock(3, 64, num_conv=2, stride=2, has_pool=True) 147 | blocks = [block] 148 | stage = nn.Sequential(*blocks) 149 | self.add_module(name, stage) 150 | self.stages_and_names.append((stage, name)) 151 | self._out_feature_strides[name] = 2 152 | self._out_feature_channels[name] = blocks[-1].out_channels 153 | if freeze_at >= 1: 154 | for block in blocks: 155 | block.freeze() 156 | 157 | name = "plain2" 158 | block = PlainBlock(64, 128, num_conv=2, stride=2, has_pool=True) 159 | blocks = [block] 160 | stage = nn.Sequential(*blocks) 161 | self.add_module(name, stage) 162 | self.stages_and_names.append((stage, name)) 163 | self._out_feature_strides[name] = 4 164 | self._out_feature_channels[name] = blocks[-1].out_channels 165 | if freeze_at >= 2: 166 | for block in blocks: 167 | block.freeze() 168 | 169 | name = "plain3" 170 | block = PlainBlock(128, 256, num_conv=3, stride=2, has_pool=True) 171 | blocks = [block] 172 | stage = nn.Sequential(*blocks) 173 | self.add_module(name, stage) 174 | self.stages_and_names.append((stage, name)) 175 | self._out_feature_strides[name] = 8 176 | self._out_feature_channels[name] = blocks[-1].out_channels 177 | if freeze_at >= 3: 178 | for block in blocks: 179 | block.freeze() 180 | 181 | name = "plain4" 182 | block = PlainBlock( 183 | 256, 512, num_conv=3, stride=1 if conv5_dilation == 2 else 2, has_pool=True 184 | ) 185 | blocks = [block] 186 | stage = nn.Sequential(*blocks) 187 | self.add_module(name, stage) 188 | self.stages_and_names.append((stage, name)) 189 | self._out_feature_strides[name] = 8 if conv5_dilation == 2 else 16 190 | self._out_feature_channels[name] = blocks[-1].out_channels 191 | if freeze_at >= 4: 192 | for block in blocks: 193 | block.freeze() 194 | 195 | name = "plain5" 196 | block = PlainBlock(512, 512, num_conv=3, stride=1, dilation=conv5_dilation, has_pool=False) 197 | blocks = [block] 198 | stage = nn.Sequential(*blocks) 199 | self.add_module(name, stage) 200 | self.stages_and_names.append((stage, name)) 201 | self._out_feature_strides[name] = 8 if conv5_dilation == 2 else 16 202 | self._out_feature_channels[name] = blocks[-1].out_channels 203 | if freeze_at >= 5: 204 | for block in blocks: 205 | block.freeze() 206 | 207 | if out_features is None: 208 | out_features = [name] 209 | self._out_features = out_features 210 | assert len(self._out_features) 211 | children = [x[0] for x in self.named_children()] 212 | for out_feature in self._out_features: 213 | assert out_feature in children, "Available children: {}".format(", ".join(children)) 214 | 215 | def forward(self, x): 216 | outputs = {} 217 | for stage, 
name in self.stages_and_names: 218 | x = stage(x) 219 | if name in self._out_features: 220 | outputs[name] = x 221 | 222 | return outputs 223 | 224 | def output_shape(self): 225 | return { 226 | name: ShapeSpec( 227 | channels=self._out_feature_channels[name], stride=self._out_feature_strides[name] 228 | ) 229 | for name in self._out_features 230 | } 231 | 232 | 233 | @BACKBONE_REGISTRY.register() 234 | def build_vgg_backbone(cfg, input_shape): 235 | 236 | # fmt: off 237 | depth = cfg.MODEL.VGG.DEPTH 238 | conv5_dilation = cfg.MODEL.VGG.CONV5_DILATION 239 | freeze_at = cfg.MODEL.BACKBONE.FREEZE_AT 240 | # fmt: on 241 | 242 | if depth == 16: 243 | return VGG16(conv5_dilation, freeze_at) 244 | -------------------------------------------------------------------------------- /wsovod/modeling/class_heads/__init__.py: -------------------------------------------------------------------------------- 1 | from .open_vocabulary_classifier import OpenVocabularyClassifier -------------------------------------------------------------------------------- /wsovod/modeling/class_heads/data_aware_features_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import os 3 | from pathlib import Path 4 | from typing import Dict, List, Optional, Tuple 5 | 6 | import cv2 7 | import fvcore.nn.weight_init as weight_init 8 | import numpy as np 9 | import torch 10 | from detectron2.config import configurable 11 | from detectron2.layers import Conv2d, Linear, ShapeSpec, get_norm 12 | from detectron2.structures import Boxes, ImageList, Instances 13 | from detectron2.utils.events import get_event_storage 14 | from torch import nn 15 | from torch.nn import functional as F 16 | import logging 17 | 18 | 19 | class DataAwareFeaturesHead(nn.Module): 20 | @configurable 21 | def __init__( 22 | self, 23 | input_shape: ShapeSpec, 24 | *, 25 | datasets_prototype_num: int = 5, 26 | features_dim: int = 512, 27 | cls_in_features: List[str], 28 | mrrp_on: bool = False, 29 | mrrp_num_branch: int = 3, 30 | ): 31 | """ 32 | NOTE: this interface is experimental. 33 | 34 | Args: 35 | input_shape (ShapeSpec): shape of the input feature. 36 | conv_dims (list[int]): the output dimensions of the conv layers 37 | fc_dims (list[int]): the output dimensions of the fc layers 38 | conv_norm (str or callable): normalization for the conv layers. 39 | See :func:`detectron2.layers.get_norm` for supported types. 
40 | """ 41 | super().__init__() 42 | 43 | self.in_features = self.cls_in_features = cls_in_features 44 | self.features_dim = features_dim 45 | self.mrrp_on = mrrp_on 46 | self.mrrp_num_branch = mrrp_num_branch 47 | 48 | self.in_channels = [input_shape[f].channels for f in self.in_features] 49 | in_channels = [input_shape[f].channels for f in self.in_features] 50 | # Check all channel counts are equal 51 | assert len(set(in_channels)) == 1, in_channels 52 | in_channels = in_channels[0] 53 | 54 | self.strides = [input_shape[f].stride for f in self.in_features] 55 | 56 | self._output_size = (in_channels,) 57 | 58 | self.datasets_prototype_num = datasets_prototype_num 59 | self.datasets_feat = nn.Embedding(self.datasets_prototype_num, self.features_dim) 60 | 61 | self.GAP = nn.AdaptiveAvgPool2d(1) 62 | 63 | fc_dims = [in_channels//16,self.datasets_prototype_num] 64 | 65 | self.fcs = [] 66 | for k, fc_dim in enumerate(fc_dims): 67 | fc = Linear(np.prod(self._output_size), fc_dim) 68 | self.fcs.append(fc) 69 | self.add_module("linear{}".format(k + 1), fc) 70 | if k < len(fc_dims) - 1: 71 | relu = nn.ReLU(inplace=True) 72 | self.fcs.append(relu) 73 | self.add_module("linear_relu{}".format(k + 1), relu) 74 | else: 75 | tanh = nn.Tanh() 76 | self.fcs.append(tanh) 77 | self.add_module("linear_tanh{}".format(k + 1), tanh) 78 | self._output_size = fc_dim 79 | 80 | for layer in self.fcs: 81 | if not isinstance(layer, Linear): 82 | continue 83 | nn.init.uniform_(layer.weight, -0.01, 0.01) 84 | torch.nn.init.constant_(layer.bias, 0) 85 | 86 | @classmethod 87 | def from_config(cls, cfg, input_shape): 88 | # fmt: off 89 | in_features = cfg.MODEL.ROI_HEADS.IN_FEATURES 90 | mrrp_on = cfg.MODEL.MRRP.MRRP_ON 91 | mrrp_num_branch = cfg.MODEL.MRRP.NUM_BRANCH 92 | datasets_prototype_num = cfg.MODEL.ROI_BOX_HEAD.OPEN_VOCABULARY.PROTOTYPE_NUM 93 | # fmt: on 94 | return { 95 | "cls_in_features": in_features, 96 | "input_shape": input_shape, 97 | "datasets_prototype_num": datasets_prototype_num, 98 | "features_dim": cfg.MODEL.ROI_BOX_HEAD.DAN_DIM[-1], 99 | "mrrp_on": mrrp_on, 100 | "mrrp_num_branch": mrrp_num_branch, 101 | } 102 | 103 | def forward( 104 | self, 105 | features: Dict[str, torch.Tensor], 106 | proposals, 107 | ): 108 | 109 | features = [features[f] for f in self.cls_in_features] 110 | if self.mrrp_on: 111 | features = [torch.stack(torch.chunk(f, self.mrrp_num_branch)).mean(0) for f in features] 112 | data_features = [self._forward(f) for f in features] 113 | if len(data_features)>1: 114 | data_features = torch.stack(data_features).mean(0) 115 | else: 116 | data_features = data_features[0] 117 | results = [] 118 | for i in range(len(proposals)): 119 | results.append(data_features[i].repeat(len(proposals[i]),1)) 120 | results = torch.cat(results) 121 | return results 122 | 123 | def _forward(self, x): 124 | x = self.GAP(x) 125 | x = x.flatten(start_dim=1) 126 | for k, layer in enumerate(self.fcs): 127 | x = layer(x) 128 | combined = torch.matmul(x, self.datasets_feat.weight) 129 | return combined 130 | 131 | 132 | -------------------------------------------------------------------------------- /wsovod/modeling/class_heads/open_vocabulary_classifier.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | import logging 3 | import numpy as np 4 | import torch 5 | from torch import nn 6 | from torch.nn import functional as F 7 | from math import fabs 8 | from detectron2.config import configurable 9 | from detectron2.layers import Linear, ShapeSpec 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | class OpenVocabularyClassifier(nn.Module): 15 | @configurable 16 | def __init__( 17 | self, 18 | input_shape: ShapeSpec, 19 | *, 20 | num_classes: int, 21 | weight_path: str, 22 | weight_dim: int = 512, 23 | use_bias: float = 0.0, 24 | norm_weight: bool = True, 25 | norm_temperature: float = 50.0, 26 | ): 27 | super().__init__() 28 | if isinstance(input_shape, int): # some backward compatibility 29 | input_shape = ShapeSpec(channels=input_shape) 30 | input_size = input_shape.channels * (input_shape.width or 1) * (input_shape.height or 1) 31 | self.norm_weight = norm_weight 32 | self.weight_dim = weight_dim 33 | self.norm_temperature = norm_temperature 34 | 35 | self.use_bias = fabs(use_bias)>1e-9 36 | if self.use_bias: 37 | self.cls_bias = nn.Parameter(torch.ones(1) * use_bias) 38 | 39 | self.projection = nn.Sequential( 40 | nn.Linear(input_size, 1024), 41 | nn.ReLU(), 42 | nn.Linear(1024, weight_dim), 43 | nn.ReLU(), 44 | ) 45 | 46 | 47 | if weight_path == "rand": 48 | class_weight = torch.randn((weight_dim, num_classes)) 49 | nn.init.normal_(class_weight, std=0.01) 50 | else: 51 | logger.info("Loading " + weight_path) 52 | class_weight = ( 53 | torch.tensor(np.load(weight_path,encoding='bytes', allow_pickle=True), dtype=torch.float32) 54 | .permute(1, 0) 55 | .contiguous() 56 | ) # D x C 57 | logger.info(f"Loaded class weight {class_weight.size()}") 58 | 59 | if self.norm_weight: 60 | class_weight = F.normalize(class_weight, p=2, dim=0) 61 | 62 | if weight_path == "rand": 63 | self.class_weight = nn.Parameter(class_weight) 64 | else: 65 | self.register_buffer("class_weight", class_weight) 66 | 67 | @classmethod 68 | def from_config(cls, cfg, input_shape, weight_path = None, use_bias = None, norm_weight = None, norm_temperature = None): 69 | return { 70 | "input_shape": input_shape, 71 | "num_classes": cfg.MODEL.ROI_HEADS.NUM_CLASSES, 72 | "weight_path": weight_path if weight_path is not None else cfg.MODEL.ROI_BOX_HEAD.OPEN_VOCABULARY.WEIGHT_PATH_TRAIN, 73 | "weight_dim": cfg.MODEL.ROI_BOX_HEAD.OPEN_VOCABULARY.WEIGHT_DIM, 74 | "use_bias": use_bias if use_bias is not None else cfg.MODEL.ROI_BOX_HEAD.OPEN_VOCABULARY.USE_BIAS, 75 | "norm_weight": norm_weight if norm_weight is not None else cfg.MODEL.ROI_BOX_HEAD.OPEN_VOCABULARY.NORM_WEIGHT, 76 | "norm_temperature": norm_temperature if norm_temperature is not None else cfg.MODEL.ROI_BOX_HEAD.OPEN_VOCABULARY.NORM_TEMP, 77 | } 78 | 79 | def forward(self, x, classifier=None, append_background = False): 80 | """ 81 | Inputs: 82 | x: B x D 83 | classifier: (C', C' x D) 84 | """ 85 | x = self.projection(x) 86 | 87 | if classifier is not None: 88 | class_weight = classifier.permute(1, 0).contiguous() # D x C' 89 | class_weight = F.normalize(class_weight, p=2, dim=0) if self.norm_weight else class_weight 90 | else: 91 | class_weight = self.class_weight 92 | 93 | if self.norm_weight: 94 | x = self.norm_temperature * F.normalize(x, p=2, dim=1) 95 | 96 | if append_background: 97 | class_weight = torch.cat( 98 | [class_weight, class_weight.new_zeros((self.weight_dim, 1))], dim=1 99 | ) # D x (C + 1) 100 | # logger.info(f"Cated class weight {class_weight.size()}") 101 | 102 | x = torch.mm(x, class_weight) 103 | if self.use_bias: 104 | x = x + 
self.cls_bias 105 | return x -------------------------------------------------------------------------------- /wsovod/modeling/meta_arch/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | import imp 4 | from .rcnn_wsovod import GeneralizedRCNN_WSOVOD 5 | from .rcnn_wsovod_mixed_datasets import GeneralizedRCNN_WSOVOD_MixedDatasets -------------------------------------------------------------------------------- /wsovod/modeling/postprocessing.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | import torch 3 | from detectron2.structures import Instances, ROIMasks 4 | from torch.nn import functional as F 5 | 6 | 7 | # perhaps should rename to "resize_instance" 8 | def detector_postprocess( 9 | results: Instances, output_height: int, output_width: int, mask_threshold: float = 0.5 10 | ): 11 | """ 12 | Resize the output instances. 13 | The input images are often resized when entering an object detector. 14 | As a result, we often need the outputs of the detector in a different 15 | resolution from its inputs. 16 | 17 | This function will resize the raw outputs of an R-CNN detector 18 | to produce outputs according to the desired output resolution. 19 | 20 | Args: 21 | results (Instances): the raw outputs from the detector. 22 | `results.image_size` contains the input image resolution the detector sees. 23 | This object might be modified in-place. 24 | output_height, output_width: the desired output resolution. 25 | 26 | Returns: 27 | Instances: the resized output from the model, based on the output resolution 28 | """ 29 | # Change to 'if is_tracing' after PT1.7 30 | if isinstance(output_height, torch.Tensor): 31 | # Converts integer tensors to float temporaries to ensure true 32 | # division is performed when computing scale_x and scale_y. 33 | output_width_tmp = output_width.float() 34 | output_height_tmp = output_height.float() 35 | new_size = torch.stack([output_height, output_width]) 36 | else: 37 | new_size = (output_height, output_width) 38 | output_width_tmp = output_width 39 | output_height_tmp = output_height 40 | 41 | scale_x, scale_y = ( 42 | output_width_tmp / results.image_size[1], 43 | output_height_tmp / results.image_size[0], 44 | ) 45 | results = Instances(new_size, **results.get_fields()) 46 | 47 | if results.has("pred_boxes"): 48 | output_boxes = results.pred_boxes 49 | elif results.has("proposal_boxes"): 50 | output_boxes = results.proposal_boxes 51 | else: 52 | output_boxes = None 53 | assert output_boxes is not None, "Predictions must contain boxes!" 
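    # Added note: the boxes above live in the resized frame the detector saw
    # (results.image_size = (H, W)); multiplying by (scale_x, scale_y) maps them
    # back to the requested output resolution. For example, with an original
    # 640x480 (W x H) image resized to 800x600 for inference, scale_x = 640 / 800 = 0.8
    # and scale_y = 480 / 600 = 0.8, so every box coordinate shrinks by 0.8.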
54 | 55 | output_boxes.scale(scale_x, scale_y) 56 | output_boxes.clip(results.image_size) 57 | 58 | results = results[output_boxes.nonempty()] 59 | 60 | if results.has("pred_masks") and results.has("no_paste"): 61 | results.pred_masks = F.interpolate( 62 | results.pred_masks, 63 | size=(output_height, output_width), 64 | mode="bilinear", 65 | align_corners=False, 66 | ) 67 | results.pred_masks = (results.pred_masks[:, 0, :, :] >= mask_threshold).to(dtype=torch.bool) 68 | elif results.has("pred_masks"): 69 | if isinstance(results.pred_masks, ROIMasks): 70 | roi_masks = results.pred_masks 71 | else: 72 | # pred_masks is a tensor of shape (N, 1, M, M) 73 | roi_masks = ROIMasks(results.pred_masks[:, 0, :, :]) 74 | results.pred_masks = roi_masks.to_bitmasks( 75 | results.pred_boxes, output_height, output_width, mask_threshold 76 | ).tensor # TODO return ROIMasks/BitMask object in the future 77 | 78 | if results.has("pred_keypoints"): 79 | results.pred_keypoints[:, :, 0] *= scale_x 80 | results.pred_keypoints[:, :, 1] *= scale_y 81 | 82 | return results 83 | -------------------------------------------------------------------------------- /wsovod/modeling/proposal_generator/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | from .rpn import WSOVODRPN_V2, WSOVODRPN -------------------------------------------------------------------------------- /wsovod/modeling/roi_heads/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | from .box_head import DiscriminativeAdaptationNeck -------------------------------------------------------------------------------- /wsovod/modeling/roi_heads/box_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | from typing import List 3 | 4 | import fvcore.nn.weight_init as weight_init 5 | import numpy as np 6 | import torch 7 | from detectron2.config import configurable 8 | from detectron2.layers import Conv2d, ShapeSpec, get_norm 9 | from detectron2.modeling.roi_heads import ROI_BOX_HEAD_REGISTRY 10 | from torch import nn 11 | 12 | __all__ = ["DiscriminativeAdaptationNeck"] 13 | 14 | 15 | # To get torchscript support, we make the head a subclass of `nn.Sequential`. 16 | # Therefore, to add new layers in this head class, please make sure they are 17 | # added in the order they will be used in forward(). 18 | @ROI_BOX_HEAD_REGISTRY.register() 19 | class DiscriminativeAdaptationNeck(nn.Sequential): 20 | """ 21 | A head with several 3x3 conv layers (each followed by norm & relu) and then 22 | several fc layers (each followed by relu). 23 | """ 24 | 25 | @configurable 26 | def __init__( 27 | self, input_shape: ShapeSpec, *, conv_dims: List[int], fc_dims: List[int], conv_norm="" 28 | ): 29 | """ 30 | NOTE: this interface is experimental. 31 | 32 | Args: 33 | input_shape (ShapeSpec): shape of the input feature. 34 | conv_dims (list[int]): the output dimensions of the conv layers 35 | fc_dims (list[int]): the output dimensions of the fc layers 36 | conv_norm (str or callable): normalization for the conv layers. 37 | See :func:`detectron2.layers.get_norm` for supported types. 
38 | """ 39 | super().__init__() 40 | assert len(conv_dims) + len(fc_dims) > 0 41 | 42 | self._output_size = (input_shape.channels, input_shape.height, input_shape.width) 43 | 44 | self.conv_norm_relus = [] 45 | for k, conv_dim in enumerate(conv_dims): 46 | conv = Conv2d( 47 | self._output_size[0], 48 | conv_dim, 49 | kernel_size=3, 50 | padding=1, 51 | bias=not conv_norm, 52 | norm=get_norm(conv_norm, conv_dim), 53 | activation=nn.ReLU(inplace=True), 54 | ) 55 | self.add_module("conv{}".format(k + 1), conv) 56 | self.conv_norm_relus.append(conv) 57 | self._output_size = (conv_dim, self._output_size[1], self._output_size[2]) 58 | 59 | self.fcs = [] 60 | for k, fc_dim in enumerate(fc_dims): 61 | if k == 0: 62 | self.add_module("flatten", nn.Flatten()) 63 | fc = nn.Linear(int(np.prod(self._output_size)), fc_dim) 64 | self.add_module("fc{}".format(k + 1), fc) 65 | self.add_module("fc_relu{}".format(k + 1), nn.ReLU(inplace=True)) 66 | self.add_module("fc_dropout{}".format(k + 1), nn.Dropout(p=0.5, inplace=False)) 67 | self.fcs.append(fc) 68 | self._output_size = fc_dim 69 | 70 | for layer in self.conv_norm_relus: 71 | weight_init.c2_msra_fill(layer) 72 | for layer in self.fcs: 73 | # weight_init.c2_xavier_fill(layer) 74 | torch.nn.init.normal_(layer.weight, std=0.005) 75 | torch.nn.init.constant_(layer.bias, 0.1) 76 | 77 | @classmethod 78 | def from_config(cls, cfg, input_shape): 79 | num_conv = cfg.MODEL.ROI_BOX_HEAD.NUM_CONV 80 | conv_dim = cfg.MODEL.ROI_BOX_HEAD.CONV_DIM 81 | # num_fc = cfg.MODEL.ROI_BOX_HEAD.NUM_FC 82 | fc_dims = cfg.MODEL.ROI_BOX_HEAD.DAN_DIM 83 | return { 84 | "input_shape": input_shape, 85 | "conv_dims": [conv_dim] * num_conv, 86 | "fc_dims": fc_dims, 87 | "conv_norm": cfg.MODEL.ROI_BOX_HEAD.NORM, 88 | } 89 | 90 | def forward(self, x): 91 | for layer in self: 92 | x = layer(x) 93 | return x 94 | 95 | @property 96 | @torch.jit.unused 97 | def output_shape(self): 98 | """ 99 | Returns: 100 | ShapeSpec: the output feature shape 101 | """ 102 | o = self._output_size 103 | if isinstance(o, int): 104 | return ShapeSpec(channels=o) 105 | else: 106 | return ShapeSpec(channels=o[0], height=o[1], width=o[2]) 107 | -------------------------------------------------------------------------------- /wsovod/solver/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | from .build import build_lr_scheduler, build_optimizer, get_default_optimizer_params 3 | from .lr_scheduler import WarmupCosineLR, WarmupMultiStepLR, LRMultiplier, WarmupParamScheduler 4 | 5 | __all__ = [k for k in globals().keys() if not k.startswith("_")] -------------------------------------------------------------------------------- /wsovod/solver/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | import logging 3 | import math 4 | from bisect import bisect_right 5 | from typing import List 6 | import torch 7 | from fvcore.common.param_scheduler import ( 8 | CompositeParamScheduler, 9 | ConstantParamScheduler, 10 | LinearParamScheduler, 11 | ParamScheduler, 12 | ) 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | class WarmupParamScheduler(CompositeParamScheduler): 18 | """ 19 | Add an initial warmup stage to another scheduler. 
20 | """ 21 | 22 | def __init__( 23 | self, 24 | scheduler: ParamScheduler, 25 | warmup_factor: float, 26 | warmup_length: float, 27 | warmup_method: str = "linear", 28 | ): 29 | """ 30 | Args: 31 | scheduler: warmup will be added at the beginning of this scheduler 32 | warmup_factor: the factor w.r.t the initial value of ``scheduler``, e.g. 0.001 33 | warmup_length: the relative length (in [0, 1]) of warmup steps w.r.t the entire 34 | training, e.g. 0.01 35 | warmup_method: one of "linear" or "constant" 36 | """ 37 | end_value = scheduler(warmup_length) # the value to reach when warmup ends 38 | start_value = warmup_factor * scheduler(0.0) 39 | if warmup_method == "constant": 40 | warmup = ConstantParamScheduler(start_value) 41 | elif warmup_method == "linear": 42 | warmup = LinearParamScheduler(start_value, end_value) 43 | else: 44 | raise ValueError("Unknown warmup method: {}".format(warmup_method)) 45 | super().__init__( 46 | [warmup, scheduler], 47 | interval_scaling=["rescaled", "fixed"], 48 | lengths=[warmup_length, 1 - warmup_length], 49 | ) 50 | 51 | 52 | class LRMultiplier(torch.optim.lr_scheduler._LRScheduler): 53 | """ 54 | A LRScheduler which uses fvcore :class:`ParamScheduler` to multiply the 55 | learning rate of each param in the optimizer. 56 | Every step, the learning rate of each parameter becomes its initial value 57 | multiplied by the output of the given :class:`ParamScheduler`. 58 | 59 | The absolute learning rate value of each parameter can be different. 60 | This scheduler can be used as long as the relative scale among them do 61 | not change during training. 62 | 63 | Examples: 64 | :: 65 | LRMultiplier( 66 | opt, 67 | WarmupParamScheduler( 68 | MultiStepParamScheduler( 69 | [1, 0.1, 0.01], 70 | milestones=[60000, 80000], 71 | num_updates=90000, 72 | ), 0.001, 100 / 90000 73 | ), 74 | max_iter=90000 75 | ) 76 | """ 77 | 78 | # NOTES: in the most general case, every LR can use its own scheduler. 79 | # Supporting this requires interaction with the optimizer when its parameter 80 | # group is initialized. For example, classyvision implements its own optimizer 81 | # that allows different schedulers for every parameter group. 82 | # To avoid this complexity, we use this class to support the most common cases 83 | # where the relative scale among all LRs stay unchanged during training. In this 84 | # case we only need a total of one scheduler that defines the relative LR multiplier. 85 | 86 | def __init__( 87 | self, 88 | optimizer: torch.optim.Optimizer, 89 | multiplier: ParamScheduler, 90 | max_iter: int, 91 | last_iter: int = -1, 92 | ): 93 | """ 94 | Args: 95 | optimizer, last_iter: See ``torch.optim.lr_scheduler._LRScheduler``. 96 | ``last_iter`` is the same as ``last_epoch``. 97 | multiplier: a fvcore ParamScheduler that defines the multiplier on 98 | every LR of the optimizer 99 | max_iter: the total number of training iterations 100 | """ 101 | if not isinstance(multiplier, ParamScheduler): 102 | raise ValueError( 103 | "_LRMultiplier(multiplier=) must be an instance of fvcore " 104 | f"ParamScheduler. Got {multiplier} instead." 105 | ) 106 | self._multiplier = multiplier 107 | self._max_iter = max_iter 108 | super().__init__(optimizer, last_epoch=last_iter) 109 | 110 | def state_dict(self): 111 | # fvcore schedulers are stateless. 
Only keep pytorch scheduler states 112 | return {"base_lrs": self.base_lrs, "last_epoch": self.last_epoch} 113 | 114 | def get_lr(self) -> List[float]: 115 | multiplier = self._multiplier(self.last_epoch / self._max_iter) 116 | return [base_lr * multiplier for base_lr in self.base_lrs] 117 | 118 | 119 | """ 120 | Content below is no longer needed! 121 | """ 122 | 123 | # NOTE: PyTorch's LR scheduler interface uses names that assume the LR changes 124 | # only on epoch boundaries. We typically use iteration based schedules instead. 125 | # As a result, "epoch" (e.g., as in self.last_epoch) should be understood to mean 126 | # "iteration" instead. 127 | 128 | # FIXME: ideally this would be achieved with a CombinedLRScheduler, separating 129 | # MultiStepLR with WarmupLR but the current LRScheduler design doesn't allow it. 130 | 131 | 132 | class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler): 133 | def __init__( 134 | self, 135 | optimizer: torch.optim.Optimizer, 136 | milestones: List[int], 137 | gamma: float = 0.1, 138 | warmup_factor: float = 0.001, 139 | warmup_iters: int = 1000, 140 | warmup_method: str = "linear", 141 | last_epoch: int = -1, 142 | ): 143 | logger.warning( 144 | "WarmupMultiStepLR is deprecated! Use LRMultipilier with fvcore ParamScheduler instead!" 145 | ) 146 | if not list(milestones) == sorted(milestones): 147 | raise ValueError( 148 | "Milestones should be a list of" " increasing integers. Got {}", milestones 149 | ) 150 | self.milestones = milestones 151 | self.gamma = gamma 152 | self.warmup_factor = warmup_factor 153 | self.warmup_iters = warmup_iters 154 | self.warmup_method = warmup_method 155 | super().__init__(optimizer, last_epoch) 156 | 157 | def get_lr(self) -> List[float]: 158 | warmup_factor = _get_warmup_factor_at_iter( 159 | self.warmup_method, self.last_epoch, self.warmup_iters, self.warmup_factor 160 | ) 161 | return [ 162 | base_lr * warmup_factor * self.gamma ** bisect_right(self.milestones, self.last_epoch) 163 | for base_lr in self.base_lrs 164 | ] 165 | 166 | def _compute_values(self) -> List[float]: 167 | # The new interface 168 | return self.get_lr() 169 | 170 | 171 | class WarmupCosineLR(torch.optim.lr_scheduler._LRScheduler): 172 | def __init__( 173 | self, 174 | optimizer: torch.optim.Optimizer, 175 | max_iters: int, 176 | warmup_factor: float = 0.001, 177 | warmup_iters: int = 1000, 178 | warmup_method: str = "linear", 179 | last_epoch: int = -1, 180 | ): 181 | logger.warning( 182 | "WarmupCosineLR is deprecated! Use LRMultipilier with fvcore ParamScheduler instead!" 183 | ) 184 | self.max_iters = max_iters 185 | self.warmup_factor = warmup_factor 186 | self.warmup_iters = warmup_iters 187 | self.warmup_method = warmup_method 188 | super().__init__(optimizer, last_epoch) 189 | 190 | def get_lr(self) -> List[float]: 191 | warmup_factor = _get_warmup_factor_at_iter( 192 | self.warmup_method, self.last_epoch, self.warmup_iters, self.warmup_factor 193 | ) 194 | # Different definitions of half-cosine with warmup are possible. For 195 | # simplicity we multiply the standard half-cosine schedule by the warmup 196 | # factor. An alternative is to start the period of the cosine at warmup_iters 197 | # instead of at 0. In the case that warmup_iters << max_iters the two are 198 | # very close to each other. 
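        # Added note: after warmup the schedule below evaluates to
        #   lr = base_lr * 0.5 * (1 + cos(pi * t / max_iters)),
        # i.e. the full base_lr at t = 0, half of it at t = max_iters / 2
        # (cos(pi / 2) = 0), and 0 at t = max_iters.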
199 | return [ 200 | base_lr 201 | * warmup_factor 202 | * 0.5 203 | * (1.0 + math.cos(math.pi * self.last_epoch / self.max_iters)) 204 | for base_lr in self.base_lrs 205 | ] 206 | 207 | def _compute_values(self) -> List[float]: 208 | # The new interface 209 | return self.get_lr() 210 | 211 | 212 | def _get_warmup_factor_at_iter( 213 | method: str, iter: int, warmup_iters: int, warmup_factor: float 214 | ) -> float: 215 | """ 216 | Return the learning rate warmup factor at a specific iteration. 217 | See :paper:`ImageNet in 1h` for more details. 218 | 219 | Args: 220 | method (str): warmup method; either "constant" or "linear". 221 | iter (int): iteration at which to calculate the warmup factor. 222 | warmup_iters (int): the number of warmup iterations. 223 | warmup_factor (float): the base warmup factor (the meaning changes according 224 | to the method used). 225 | 226 | Returns: 227 | float: the effective warmup factor at the given iteration. 228 | """ 229 | if iter >= warmup_iters: 230 | return 1.0 231 | 232 | if method == "constant": 233 | return warmup_factor 234 | elif method == "linear": 235 | alpha = iter / warmup_iters 236 | return warmup_factor * (1 - alpha) + alpha 237 | else: 238 | raise ValueError("Unknown warmup method: {}".format(method)) -------------------------------------------------------------------------------- /wsovod/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HunterJ-Lin/WSOVOD/74e7647faaa0ab5de37a351532f65c57654f66be/wsovod/utils/__init__.py -------------------------------------------------------------------------------- /wsovod/utils/sam_predictor_with_buffer.py: -------------------------------------------------------------------------------- 1 | from segment_anything.modeling import Sam 2 | from segment_anything.utils.transforms import ResizeLongestSide 3 | import numpy as np 4 | import torch 5 | from typing import Optional, Tuple 6 | 7 | class SamPredictorBuffer: 8 | def __init__( 9 | self, 10 | sam_model: Sam, 11 | ): 12 | """ 13 | Uses SAM to calculate the image embedding for an image, and then 14 | allow repeated, efficient mask prediction given prompts. 15 | 16 | Arguments: 17 | sam_model (Sam): The model to use for mask prediction. 18 | """ 19 | super().__init__() 20 | self.model = sam_model 21 | self.transform = ResizeLongestSide(sam_model.image_encoder.img_size) 22 | self.buffer = {} 23 | 24 | def set_image( 25 | self, 26 | image: np.ndarray, 27 | image_format: str = "RGB", 28 | file_name = None, 29 | ): 30 | """ 31 | Calculates the image embeddings for the provided image, allowing 32 | masks to be predicted with the 'predict' method. 33 | 34 | Arguments: 35 | image (np.ndarray): The image for calculating masks. Expects an 36 | image in HWC uint8 format, with pixel values in [0, 255]. 37 | image_format (str): The color format of the image, in ['RGB', 'BGR']. 38 | """ 39 | if file_name in self.buffer.keys(): 40 | return 41 | assert image_format in [ 42 | "RGB", 43 | "BGR", 44 | ], f"image_format must be in ['RGB', 'BGR'], is {image_format}." 
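        # Added note: the early return above means each file_name is encoded at
        # most once; repeated prompting on the same image reuses the buffered
        # embedding instead of re-running the image encoder. The channel flip
        # below only happens when the input color order differs from the one the
        # SAM model expects.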
45 | if image_format != self.model.image_format: 46 | image = image[..., ::-1] 47 | 48 | # Transform the image to the form expected by the model 49 | input_image = self.transform.apply_image(image) 50 | input_image_torch = torch.as_tensor(input_image, device=self.device) 51 | input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :] 52 | 53 | self.set_torch_image(input_image_torch, image.shape[:2], file_name) 54 | 55 | @torch.no_grad() 56 | def set_torch_image( 57 | self, 58 | transformed_image: torch.Tensor, 59 | original_image_size: Tuple[int, ...], 60 | file_name = None, 61 | ): 62 | """ 63 | Calculates the image embeddings for the provided image, allowing 64 | masks to be predicted with the 'predict' method. Expects the input 65 | image to be already transformed to the format expected by the model. 66 | 67 | Arguments: 68 | transformed_image (torch.Tensor): The input image, with shape 69 | 1x3xHxW, which has been transformed with ResizeLongestSide. 70 | original_image_size (tuple(int, int)): The size of the image 71 | before transformation, in (H, W) format. 72 | """ 73 | if file_name in self.buffer.keys(): 74 | return 75 | assert ( 76 | len(transformed_image.shape) == 4 77 | and transformed_image.shape[1] == 3 78 | and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size 79 | ), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}." 80 | 81 | input_size = tuple(transformed_image.shape[-2:]) 82 | self.buffer[file_name] = { 83 | 'original_size' : original_image_size, 84 | 'input_size' : input_size, 85 | 'features' : self.model.image_encoder(self.model.preprocess(transformed_image)).detach() 86 | } 87 | 88 | def predict( 89 | self, 90 | point_coords: Optional[np.ndarray] = None, 91 | point_labels: Optional[np.ndarray] = None, 92 | box: Optional[np.ndarray] = None, 93 | mask_input: Optional[np.ndarray] = None, 94 | multimask_output: bool = True, 95 | return_logits: bool = False, 96 | file_name = None, 97 | ): 98 | """ 99 | Predict masks for the given input prompts, using the currently set image. 100 | 101 | Arguments: 102 | point_coords (np.ndarray or None): A Nx2 array of point prompts to the 103 | model. Each point is in (X,Y) in pixels. 104 | point_labels (np.ndarray or None): A length N array of labels for the 105 | point prompts. 1 indicates a foreground point and 0 indicates a 106 | background point. 107 | box (np.ndarray or None): A length 4 array given a box prompt to the 108 | model, in XYXY format. 109 | mask_input (np.ndarray): A low resolution mask input to the model, typically 110 | coming from a previous prediction iteration. Has form 1xHxW, where 111 | for SAM, H=W=256. 112 | multimask_output (bool): If true, the model will return three masks. 113 | For ambiguous input prompts (such as a single click), this will often 114 | produce better masks than a single prediction. If only a single 115 | mask is needed, the model's predicted quality score can be used 116 | to select the best mask. For non-ambiguous prompts, such as multiple 117 | input prompts, multimask_output=False can give better results. 118 | return_logits (bool): If true, returns un-thresholded masks logits 119 | instead of a binary mask. 120 | 121 | Returns: 122 | (np.ndarray): The output masks in CxHxW format, where C is the 123 | number of masks, and (H, W) is the original image size. 124 | (np.ndarray): An array of length C containing the model's 125 | predictions for the quality of each mask. 
126 | (np.ndarray): An array of shape CxHxW, where C is the number 127 | of masks and H=W=256. These low resolution logits can be passed to 128 | a subsequent iteration as mask input. 129 | """ 130 | assert file_name in self.buffer.keys(), RuntimeError("An image must be set with .set_image(...) before mask prediction.") 131 | self.original_size = self.buffer[file_name]['original_size'] 132 | self.input_size = self.buffer[file_name]['input_size'] 133 | self.features = self.buffer[file_name]['features'] 134 | 135 | # Transform input prompts 136 | coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None 137 | if point_coords is not None: 138 | assert ( 139 | point_labels is not None 140 | ), "point_labels must be supplied if point_coords is supplied." 141 | point_coords = self.transform.apply_coords(point_coords, self.original_size) 142 | coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device) 143 | labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device) 144 | coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :] 145 | if box is not None: 146 | box = self.transform.apply_boxes(box, self.original_size) 147 | box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device) 148 | box_torch = box_torch[None, :] 149 | if mask_input is not None: 150 | mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device) 151 | mask_input_torch = mask_input_torch[None, :, :, :] 152 | 153 | masks, iou_predictions, low_res_masks = self.predict_torch( 154 | coords_torch, 155 | labels_torch, 156 | box_torch, 157 | mask_input_torch, 158 | multimask_output, 159 | return_logits=return_logits, 160 | file_name=file_name 161 | ) 162 | 163 | masks_np = masks[0].detach().cpu().numpy() 164 | iou_predictions_np = iou_predictions[0].detach().cpu().numpy() 165 | low_res_masks_np = low_res_masks[0].detach().cpu().numpy() 166 | return masks_np, iou_predictions_np, low_res_masks_np 167 | 168 | @torch.no_grad() 169 | def predict_torch( 170 | self, 171 | point_coords: Optional[torch.Tensor], 172 | point_labels: Optional[torch.Tensor], 173 | boxes: Optional[torch.Tensor] = None, 174 | mask_input: Optional[torch.Tensor] = None, 175 | multimask_output: bool = True, 176 | return_logits: bool = False, 177 | file_name = None, 178 | ): 179 | """ 180 | Predict masks for the given input prompts, using the currently set image. 181 | Input prompts are batched torch tensors and are expected to already be 182 | transformed to the input frame using ResizeLongestSide. 183 | 184 | Arguments: 185 | point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the 186 | model. Each point is in (X,Y) in pixels. 187 | point_labels (torch.Tensor or None): A BxN array of labels for the 188 | point prompts. 1 indicates a foreground point and 0 indicates a 189 | background point. 190 | boxes (np.ndarray or None): A Bx4 array given a box prompt to the 191 | model, in XYXY format. 192 | mask_input (np.ndarray): A low resolution mask input to the model, typically 193 | coming from a previous prediction iteration. Has form Bx1xHxW, where 194 | for SAM, H=W=256. Masks returned by a previous iteration of the 195 | predict method do not need further transformation. 196 | multimask_output (bool): If true, the model will return three masks. 197 | For ambiguous input prompts (such as a single click), this will often 198 | produce better masks than a single prediction. 
If only a single 199 | mask is needed, the model's predicted quality score can be used 200 | to select the best mask. For non-ambiguous prompts, such as multiple 201 | input prompts, multimask_output=False can give better results. 202 | return_logits (bool): If true, returns un-thresholded masks logits 203 | instead of a binary mask. 204 | 205 | Returns: 206 | (torch.Tensor): The output masks in BxCxHxW format, where C is the 207 | number of masks, and (H, W) is the original image size. 208 | (torch.Tensor): An array of shape BxC containing the model's 209 | predictions for the quality of each mask. 210 | (torch.Tensor): An array of shape BxCxHxW, where C is the number 211 | of masks and H=W=256. These low res logits can be passed to 212 | a subsequent iteration as mask input. 213 | """ 214 | assert file_name in self.buffer.keys(), RuntimeError("An image must be set with .set_image(...) before mask prediction.") 215 | self.original_size = self.buffer[file_name]['original_size'] 216 | self.input_size = self.buffer[file_name]['input_size'] 217 | self.features = self.buffer[file_name]['features'] 218 | 219 | if point_coords is not None: 220 | points = (point_coords, point_labels) 221 | else: 222 | points = None 223 | 224 | # Embed prompts 225 | sparse_embeddings, dense_embeddings = self.model.prompt_encoder( 226 | points=points, 227 | boxes=boxes, 228 | masks=mask_input, 229 | ) 230 | 231 | # Predict masks 232 | low_res_masks, iou_predictions = self.model.mask_decoder( 233 | image_embeddings=self.features, 234 | image_pe=self.model.prompt_encoder.get_dense_pe(), 235 | sparse_prompt_embeddings=sparse_embeddings, 236 | dense_prompt_embeddings=dense_embeddings, 237 | multimask_output=multimask_output, 238 | ) 239 | 240 | # Upscale the masks to the original image resolution 241 | masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size) 242 | 243 | if not return_logits: 244 | masks = masks > self.model.mask_threshold 245 | 246 | return masks, iou_predictions, low_res_masks 247 | 248 | def get_image_embedding(self, file_name): 249 | """ 250 | Returns the image embedding buffered for the given file_name, with 251 | shape 1xCxHxW, where C is the embedding dimension and (H,W) are 252 | the embedding spatial dimension of SAM (typically C=256, H=W=64). 253 | """ 254 | if file_name in self.buffer: 255 | return self.buffer[file_name]['features'] 256 | 257 | return None 258 | 259 | @property 260 | def device(self): 261 | return self.model.device 262 | 263 | def reset_buffer(self): 264 | del self.buffer 265 | self.buffer = {} --------------------------------------------------------------------------------
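Usage note (added): a minimal sketch of how SamPredictorBuffer above is intended to be driven — encode each image once, keyed by its file name, then prompt it repeatedly. The checkpoint path, image path, and box values below are placeholders, not files or settings shipped with this repository.

    import cv2
    import numpy as np
    from segment_anything import sam_model_registry
    from wsovod.utils.sam_predictor_with_buffer import SamPredictorBuffer

    sam = sam_model_registry["vit_h"](checkpoint="path/to/sam_vit_h.pth")  # placeholder checkpoint
    predictor = SamPredictorBuffer(sam)

    image = cv2.imread("example.jpg")  # BGR image; placeholder path
    predictor.set_image(image, image_format="BGR", file_name="example.jpg")

    box = np.array([50, 50, 200, 200])  # XYXY box prompt in pixels
    masks, scores, low_res_logits = predictor.predict(
        box=box, multimask_output=True, file_name="example.jpg"
    )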