├── LICENSE ├── README.md ├── data └── json_deal.py ├── datasets ├── __init__.py ├── dataset.py ├── dataset_fewshot.py ├── dataset_inference_val.py ├── dataset_support.py ├── eval_detection.py ├── samplers.py ├── torchvision_datasets │ ├── __init__.py │ └── coco.py ├── transforms.py └── visual.py ├── dinov2 ├── __init__.py ├── layers │ ├── __init__.py │ ├── attention.py │ ├── block.py │ ├── dino_head.py │ ├── drop_path.py │ ├── layer_scale.py │ ├── mlp.py │ ├── patch_embed.py │ └── swiglu_ffn.py ├── pad.py ├── vit.py └── vit_beifen.py ├── engine.py ├── inference.py ├── main.py ├── models ├── CDFormer.py ├── __init__.py ├── attention.py ├── backbone.py ├── backbone_frozen.py ├── beifen │ └── CDFormer_beifen.py ├── deformable_transformer.py ├── dino_backbone.py ├── matcher.py ├── ops │ ├── functions │ │ ├── __init__.py │ │ └── ms_deform_attn_func.py │ ├── make.sh │ ├── modules │ │ ├── __init__.py │ │ └── ms_deform_attn.py │ ├── setup.py │ ├── src │ │ ├── cpu │ │ │ ├── ms_deform_attn_cpu.cpp │ │ │ └── ms_deform_attn_cpu.h │ │ ├── cuda │ │ │ ├── ms_deform_attn_cuda.cu │ │ │ ├── ms_deform_attn_cuda.h │ │ │ └── ms_deform_im2col_cuda.cuh │ │ ├── ms_deform_attn.h │ │ └── vision.cpp │ └── test.py └── position_encoding.py ├── scripts ├── basetrain.sh ├── eval.sh └── fsfinetune.sh ├── tools ├── launch.py ├── run_dist_launch.sh └── run_dist_slurm.sh └── util ├── __init__.py ├── box_ops.py ├── lr_scheduler.py ├── misc.py └── plot_utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 LONGXUANX 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CDFormer: Cross-Domain Few-Shot Object Detection Transformer Against Feature Confusion 2 | - In our work, We did not use any extra training data 3 | 4 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/cdformer-cross-domain-few-shot-object/cross-domain-few-shot-object-detection-on)](https://paperswithcode.com/sota/cross-domain-few-shot-object-detection-on?p=cdformer-cross-domain-few-shot-object)[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/cdformer-cross-domain-few-shot-object/cross-domain-few-shot-object-detection-on-1)](https://paperswithcode.com/sota/cross-domain-few-shot-object-detection-on-1?p=cdformer-cross-domain-few-shot-object)[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/cdformer-cross-domain-few-shot-object/cross-domain-few-shot-object-detection-on-3)](https://paperswithcode.com/sota/cross-domain-few-shot-object-detection-on-3?p=cdformer-cross-domain-few-shot-object)[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/cdformer-cross-domain-few-shot-object/cross-domain-few-shot-object-detection-on-2)](https://paperswithcode.com/sota/cross-domain-few-shot-object-detection-on-2?p=cdformer-cross-domain-few-shot-object)[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/cdformer-cross-domain-few-shot-object/cross-domain-few-shot-object-detection-on-neu)](https://paperswithcode.com/sota/cross-domain-few-shot-object-detection-on-neu?p=cdformer-cross-domain-few-shot-object)[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/cdformer-cross-domain-few-shot-object/cross-domain-few-shot-object-detection-on-4)](https://paperswithcode.com/sota/cross-domain-few-shot-object-detection-on-4?p=cdformer-cross-domain-few-shot-object) 5 | 6 | 7 | **In this paper**, our key contributions: 8 | 1) While CD-ViTO demonstrates significant performance degradation in open-set detection on the CD-FSOD benchmark, our network exhibits notable domain robustness. 9 | 2) Our single-stage framework with fixed classification heads achieves arbitrary-class inference capability through the introduction of background placeholders. 10 | 3) We propose highly effective object-object distinguishing and object-background distinguishing strategies. 11 | 12 | 13 | 14 | ### Pre-Requisites 15 | You must have NVIDIA GPUs to run the codes. 
16 | 17 | The implementation code is developed and tested with the following environment setup: 18 | - 4 x NVIDIA 4090 GPUs 19 | - CUDA 11.8 20 | - Python == 3.8 21 | - PyTorch == 2.3.1+cu118, TorchVision == 0.18.1+cu118 22 | - GCC == 11.4.0 23 | - cython, pycocotools, tqdm, scipy, opencv-python 24 | 25 | 26 | ### Compiling Deformable Attention 27 | ```bash 28 | # compile CUDA operators of Deformable Attention 29 | cd CDFormer 30 | cd ./models/ops 31 | sh ./make.sh 32 | python test.py # unit test (should see all checking is True) 33 | ``` 34 | 35 | ### Data Preparation 36 | 37 | #### MS-COCO for base training and UODD/Artaxor/Clipart1k/Dior/NEU-DET/Deepfish for evaluation or fine-tuning & evaluation 38 | 39 | Please download the [COCO 2017 dataset](https://cocodataset.org/) and the [CD-FSOD Benchmark](https://github.com/lovelyqian/CDFSOD-benchmark?tab=readme-ov-file), 40 | then organize them as follows: 41 | 42 | ``` 43 | code_root/ 44 | └── data/ 45 | └── coco/ # MS-COCO dataset 46 | ├── train2017/ 47 | ├── val2017/ 48 | └── annotations/ 49 | ├── instances_train2017.json 50 | └── instances_val2017.json 51 | └── UODD(/Artaxor/Clipart1k/Dior/NEU-DET/Deepfish/) # UODD/Artaxor/Clipart1k/Dior/NEU-DET/Deepfish dataset 52 | ├── train/ 53 | ├── test/ 54 | └── annotations/ 55 | ├── train.json 56 | ├── test.json 57 | ├── 1_shot.json 58 | ├── 5_shot.json 59 | └── 10_shot.json 60 | ``` 61 | 62 | #### Pre-Trained Model Weights 63 | 64 | - DINOv2 ViT-L/14 model: click [here](https://github.com/facebookresearch/dinov2) to download it, then put it in model_pt/dinov2. 65 | - The pre-trained weights we provide on COCO (we recommend training by yourself; replacing the model with [CDFormer_beifen.py] may yield higher metrics, but the fine-tuning epochs may then differ from the current setting): click [here](https://pan.baidu.com/s/1eoe9dkjNlqeQ75aD5PNLOA?pwd=w628) to download (Baidu Netdisk). 66 | 67 | ### Base Training 68 | Run the command below to start base training. 69 | ```bash 70 | GPUS_PER_NODE=4 ./tools/run_dist_launch.sh 4 nohup ./scripts/basetrain.sh >/dev/null 2>&1 & 71 | ``` 72 | 73 | ### Cross-domain Few-Shot Finetuning 74 | ``` 75 | We choose different fine-tuning epochs for different datasets, so please adjust the parameters epoch, save_every_epoch, and eval_every_epoch in fsfinetune.sh. 76 | Dataset | ArTaxOr | Clipart | DIOR | Deepfish | NEU-DET | UODD | 77 | epoch | 70 | 30 | 190 | 15 | 140 | 50 | 78 | In addition, we did not tune the hyperparameters further due to limited computational resources. 79 | ``` 80 | ```bash 81 | GPUS_PER_NODE=2 ./tools/run_dist_launch.sh 2 nohup ./scripts/fsfinetune.sh >/dev/null 2>&1 & 82 | ``` 83 | 84 | ### Evaluation (after base training, or after base training & finetuning) 85 | Run the command below to evaluate the metrics. 86 | ```bash 87 | ./scripts/eval.sh 88 | ``` 89 | 90 | ### Inference & visualization 91 | You can simply run inference.py, which produces visualizations of the test images and of the confusion matrix.
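A typical invocation might look like the sketch below. The argument names shown (dataset_file, num_shots, resume) and the checkpoint path are assumptions inferred from the dataset builders and from common Deformable-DETR-style repositories, not confirmed options of inference.py — please check its argparse definitions before running.

```bash
# Hypothetical example (argument names and paths are assumptions):
# visualize predictions on UODD with a 10-shot fine-tuned checkpoint.
python inference.py --dataset_file uodd --num_shots 10 --resume exps/uodd_10shot/checkpoint.pth
```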
92 | 93 | ## Citation 94 | If you find CDFormer useful or inspiring, please consider citing: 95 | 96 | ```bibtex 97 | @misc{meng2025cdformercrossdomainfewshotobject, 98 | title={CDFormer: Cross-Domain Few-Shot Object Detection Transformer Against Feature Confusion}, 99 | author={Boyuan Meng and Xiaohan Zhang and Peilin Li and Zhe Wu and Yiming Li and Wenkai Zhao and Beinan Yu and Hui-Liang Shen}, 100 | year={2025}, 101 | eprint={2505.00938}, 102 | archivePrefix={arXiv}, 103 | primaryClass={cs.CV}, 104 | url={https://arxiv.org/abs/2505.00938}, 105 | } 106 | ``` 107 | 108 | ## Acknowledgement 109 | Our proposed CDFormer is heavily inspired by many outstanding prior works, including [Deformable DETR](https://arxiv.org/pdf/2010.04159), [CDMM-FSOD](https://arxiv.org/pdf/2502.16469), and [CD-ViTO](https://arxiv.org/pdf/2402.03094). 110 | 111 | Thanks for their work. 112 | -------------------------------------------------------------------------------- /data/json_deal.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | # CD-ViTO's JSON id naming for the ArTaxOr dataset cannot be parsed by this project, so the ids need to be remapped 4 | 5 | # Load the JSON file 6 | with open('data/ArTaxOr/annotations/5_shot.json', 'r') as f: 7 | data = json.load(f) 8 | 9 | # Initialize counters 10 | image_id_counter = 1 11 | annotation_id_counter = 1 12 | 13 | # Create a mapping from the old string ids to new integer ids 14 | image_id_map = {} 15 | 16 | # Convert the ids in 'images' to integers 17 | for image in data['images']: 18 | image_id_map[image['id']] = image_id_counter # record the mapping 19 | image['id'] = image_id_counter # update to the integer id 20 | image_id_counter += 1 21 | 22 | # Convert the image_id and id fields in 'annotations' to integers 23 | for annotation in data['annotations']: 24 | annotation['image_id'] = image_id_map[annotation['image_id']] # map image_id to the new integer id 25 | annotation['id'] = annotation_id_counter # assign each annotation a new integer id 26 | annotation_id_counter += 1 27 | 28 | # Save the modified JSON file 29 | with open('data/ArTaxOr/annotations/fixed_5_shot.json', 'w') as f: 30 | json.dump(data, f) 31 | 32 | print("Conversion finished, saved as 'fixed_5_shot.json'.") 33 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import torch 3 | import torch.utils.data 4 | from .torchvision_datasets import CocoDetection 5 | 6 | from .dataset import build 7 | from .dataset_fewshot import build as build_fewshot 8 | 9 | 10 | # Meta-settings for few-shot object detection: base / novel category split 11 | coco_base_class_ids = [ 12 | 8, 10, 11, 13, 14, 15, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 46, 47, 13 | 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 65, 70, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 14 | 85, 86, 87, 88, 89, 90 15 | ] 16 | 17 | coco_novel_class_ids = [ 18 | 1, 2, 3, 4, 5, 6, 7, 9, 16, 17, 18, 19, 20, 21, 44, 62, 63, 64, 67, 72 19 | ] 20 | 21 | voc_base1_class_ids = [ 22 | 1, 2, 4, 5, 7, 8, 9, 11, 12, 13, 15, 16, 17, 19, 20 23 | ] 24 | 25 | voc_novel1_class_ids = [ 26 | 3, 6, 10, 14, 18 27 | ] 28 | 29 | voc_base2_class_ids = [ 30 | 2, 3, 4, 6, 7, 8, 9, 11, 12, 14, 15, 16, 17, 19, 20 31 | ] 32 | 33 | voc_novel2_class_ids = [ 34 | 1, 5, 10, 13, 18 35 | ] 36 | 37 | voc_base3_class_ids = [ 38 | 1, 2, 3, 5, 6, 7, 9, 10, 11, 12, 13, 15, 16, 19, 20 39 | ] 40 | 41 | voc_novel3_class_ids = [ 42 | 4, 8, 14, 17, 18 43 | ] 44 | 45 | # UODD class ids start from 0, so the list starts from 0 here 46 | uodd_class_ids = [ 47 | 0, 1, 2 48 | ] 49 | 50 | # DeepFish class ids start from 1, so the list starts from 1 here 51 |
deepfish_class_ids = [ 52 | 1 53 | ] 54 | 55 | # 因为neu的class_id是从1开始的,所以这里要从1开始 56 | neu_class_ids = [ 57 | 1, 2, 3, 4, 5, 6 58 | ] 59 | 60 | # 因为clipart的class_id是从1开始的,所以这里要从1开始 61 | clipart_class_ids = [ 62 | 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20 63 | ] 64 | 65 | # 因为artaxor的class_id是从1开始的,所以这里要从1开始 66 | artaxor_class_ids = [ 67 | 1, 2, 3, 4, 5, 6, 7 68 | ] 69 | 70 | # 因为dior的class_id是从1开始的,所以这里要从1开始 71 | dior_class_ids = [ 72 | 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20 73 | ] 74 | 75 | # 因为dataset1的class_id是从0开始的,所以这里要从0开始 76 | dataset1_class_ids = [ 77 | 0, 1, 2, 3, 4, 5, 6 78 | ] 79 | 80 | # 因为dataset2的class_id是从1开始的,所以这里要从1开始 81 | dataset2_class_ids = [ 82 | 1 83 | ] 84 | 85 | # 因为dataset3的class_id是从1开始的,所以这里要从1开始 86 | dataset3_class_ids = [ 87 | 1, 2, 3, 4, 5, 6 88 | ] 89 | 90 | def get_class_ids(dataset, type): 91 | if dataset == 'coco_base': 92 | if type == 'all': 93 | ids = (coco_base_class_ids + coco_novel_class_ids) 94 | ids.sort() 95 | return ids 96 | elif type == 'base': 97 | return coco_base_class_ids 98 | elif type == 'novel': 99 | return coco_novel_class_ids 100 | else: 101 | raise ValueError 102 | if dataset == 'coco': 103 | if type == 'all': 104 | ids = (coco_base_class_ids + coco_novel_class_ids) 105 | ids.sort() 106 | return ids 107 | else: 108 | raise ValueError 109 | if dataset == 'voc_base1': 110 | if type == 'all': 111 | ids = (voc_base1_class_ids + voc_novel1_class_ids) 112 | ids.sort() 113 | return ids 114 | elif type == 'base': 115 | return voc_base1_class_ids 116 | elif type == 'novel': 117 | return voc_novel1_class_ids 118 | else: 119 | raise ValueError 120 | if dataset == 'voc_base2': 121 | if type == 'all': 122 | ids = (voc_base2_class_ids + voc_novel2_class_ids) 123 | ids.sort() 124 | return ids 125 | elif type == 'base': 126 | return voc_base2_class_ids 127 | elif type == 'novel': 128 | return voc_novel2_class_ids 129 | else: 130 | raise ValueError 131 | if dataset == 'voc_base3': 132 | if type == 'all': 133 | ids = (voc_base3_class_ids + voc_novel3_class_ids) 134 | ids.sort() 135 | return ids 136 | elif type == 'base': 137 | return voc_base3_class_ids 138 | elif type == 'novel': 139 | return voc_novel3_class_ids 140 | else: 141 | raise ValueError 142 | if dataset == 'voc': 143 | if type == 'all': 144 | ids = (voc_base1_class_ids + voc_novel1_class_ids) 145 | ids.sort() 146 | return ids 147 | else: 148 | raise ValueError 149 | if dataset == 'uodd': 150 | ids = uodd_class_ids 151 | ids.sort() 152 | return ids 153 | if dataset == 'deepfish': 154 | ids = deepfish_class_ids 155 | ids.sort() 156 | return ids 157 | if dataset == 'neu': 158 | ids = neu_class_ids 159 | ids.sort() 160 | return ids 161 | if dataset == 'clipart': 162 | ids = clipart_class_ids 163 | ids.sort() 164 | return ids 165 | if dataset == 'artaxor': 166 | ids = artaxor_class_ids 167 | ids.sort() 168 | return ids 169 | if dataset == 'dior': 170 | ids = dior_class_ids 171 | ids.sort() 172 | return ids 173 | if dataset == 'dataset1': 174 | ids = dataset1_class_ids 175 | ids.sort() 176 | return ids 177 | if dataset == 'dataset2': 178 | ids = dataset2_class_ids 179 | ids.sort() 180 | return ids 181 | if dataset == 'dataset3': 182 | ids = dataset3_class_ids 183 | ids.sort() 184 | return ids 185 | 186 | raise ValueError 187 | 188 | 189 | def get_coco_api_from_dataset(dataset): 190 | for _ in range(10): 191 | if isinstance(dataset, torch.utils.data.Subset): 192 | dataset = dataset.dataset 193 | if isinstance(dataset, CocoDetection): 194 | 
return dataset.coco 195 | 196 | 197 | def build_dataset(image_set, args): 198 | assert image_set in ['train', 'val', 'fewshot'], "image_set must be 'train', 'val' or 'fewshot'." 199 | # For training set, need to perform base/novel category filtering. 200 | # For training set, we use dataset with support to construct meta-tasks 201 | if image_set == 'train': 202 | if args.dataset_file == 'coco': 203 | root = Path('/home/csy/datasets/mscoco') 204 | img_folder = root / "train2017" 205 | ann_file = root / "annotations" / 'instances_train2017.json' 206 | class_ids = coco_base_class_ids + coco_novel_class_ids 207 | class_ids.sort() 208 | return build(args, img_folder, ann_file, image_set, activated_class_ids=class_ids, with_support=True) 209 | if args.dataset_file == 'coco_base': 210 | root = Path('data/coco') 211 | img_folder = root / "train2017" 212 | ann_file = root / "annotations" / 'instances_train2017.json' 213 | return build(args, img_folder, ann_file, image_set, activated_class_ids=coco_base_class_ids, with_support=True) 214 | if args.dataset_file == 'voc': 215 | root = Path('data/voc') 216 | img_folder = root / "images" 217 | ann_file = root / "annotations" / 'pascal_trainval0712.json' 218 | return build(args, img_folder, ann_file, image_set, activated_class_ids=list(range(1, 20+1)), with_support=True) 219 | if args.dataset_file == 'voc_base1': 220 | root = Path('data/voc') 221 | img_folder = root / "images" 222 | ann_file = root / "annotations" / 'pascal_trainval0712.json' 223 | return build(args, img_folder, ann_file, image_set, activated_class_ids=voc_base1_class_ids, with_support=True) 224 | if args.dataset_file == 'voc_base2': 225 | root = Path('data/voc') 226 | img_folder = root / "images" 227 | ann_file = root / "annotations" / 'pascal_trainval0712.json' 228 | return build(args, img_folder, ann_file, image_set, activated_class_ids=voc_base2_class_ids, with_support=True) 229 | if args.dataset_file == 'voc_base3': 230 | root = Path('data/voc') 231 | img_folder = root / "images" 232 | ann_file = root / "annotations" / 'pascal_trainval0712.json' 233 | return build(args, img_folder, ann_file, image_set, activated_class_ids=voc_base3_class_ids, with_support=True) 234 | 235 | # For valid set, no need to perform base/novel category filtering. 236 | # This is because that evaluation should be performed on all images. 237 | # For valid set, we do not need support dataset. 
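    # [Editor's note] Illustrative usage sketch, not part of the original file. Assuming an
    # argparse namespace that carries the fields used by the builders (dataset_file, cache_mode,
    # total_num_support, max_pos_support, ...; presumably defined in main.py), building the UODD
    # test split for evaluation would look roughly like:
    #
    #     from argparse import Namespace
    #     from datasets import build_dataset
    #     args = Namespace(dataset_file='uodd', cache_mode=False,
    #                      total_num_support=15, max_pos_support=10)
    #     dataset_val = build_dataset('val', args)   # with_support=False, so no support sampling
    #
    # The exact set of required fields is an assumption; check the training scripts for the full list.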
238 | if image_set == 'val': 239 | if args.dataset_file in ['coco', 'coco_base']: 240 | root = Path('/home/csy/datasets/mscoco') 241 | img_folder = root / "val2017" 242 | ann_file = root / "annotations" / 'instances_val2017.json' 243 | class_ids = coco_base_class_ids + coco_novel_class_ids 244 | class_ids.sort() 245 | return build(args, img_folder, ann_file, image_set, activated_class_ids=class_ids, with_support=False) 246 | if args.dataset_file in ['voc', 'voc_base1', 'voc_base2', 'voc_base3']: 247 | root = Path('data/voc') 248 | img_folder = root / "images" 249 | ann_file = root / "annotations" / 'pascal_test2007.json' 250 | return build(args, img_folder, ann_file, image_set, activated_class_ids=list(range(1, 20+1)), with_support=False) 251 | if args.dataset_file in ['uodd']: 252 | root = Path('data/UODD') 253 | img_folder = root / "test" 254 | ann_file = root / f'annotations' / f'test.json' 255 | return build(args, img_folder, ann_file, image_set, activated_class_ids=uodd_class_ids, with_support=False) 256 | if args.dataset_file in ['deepfish']: 257 | root = Path('data/FISH') 258 | img_folder = root / "test" 259 | ann_file = root / f'annotations' / f'test.json' 260 | return build(args, img_folder, ann_file, image_set, activated_class_ids=deepfish_class_ids, with_support=False) 261 | if args.dataset_file in ['neu']: 262 | root = Path('data/NEU-DET') 263 | img_folder = root / "test" 264 | ann_file = root / f'annotations' / f'test.json' 265 | return build(args, img_folder, ann_file, image_set, activated_class_ids=neu_class_ids, with_support=False) 266 | if args.dataset_file in ['clipart']: 267 | root = Path('data/clipart1k') 268 | img_folder = root / "test" 269 | ann_file = root / f'annotations' / f'test.json' 270 | return build(args, img_folder, ann_file, image_set, activated_class_ids=clipart_class_ids, with_support=False) 271 | if args.dataset_file in ['artaxor']: 272 | root = Path('data/ArTaxOr') 273 | img_folder = root / "test" 274 | ann_file = root / f'annotations' / f'test.json' 275 | return build(args, img_folder, ann_file, image_set, activated_class_ids=artaxor_class_ids, with_support=False) 276 | if args.dataset_file in ['dior']: 277 | root = Path('data/DIOR') 278 | img_folder = root / "test" 279 | ann_file = root / f'annotations' / f'test.json' 280 | return build(args, img_folder, ann_file, image_set, activated_class_ids=dior_class_ids, with_support=False) 281 | if args.dataset_file in ['dataset1']: 282 | root = Path('data/dataset1') 283 | img_folder = root / "test" 284 | ann_file = root / f'annotations' / f'test.json' 285 | return build(args, img_folder, ann_file, image_set, activated_class_ids=dataset1_class_ids, with_support=False) 286 | if args.dataset_file in ['dataset2']: 287 | root = Path('data/dataset2') 288 | img_folder = root / "test" 289 | ann_file = root / f'annotations' / f'test.json' 290 | return build(args, img_folder, ann_file, image_set, activated_class_ids=dataset2_class_ids, with_support=False) 291 | if args.dataset_file in ['dataset3']: 292 | root = Path('data/dataset3') 293 | img_folder = root / "test" 294 | ann_file = root / f'annotations' / f'test.json' 295 | return build(args, img_folder, ann_file, image_set, activated_class_ids=dataset3_class_ids, with_support=False) 296 | 297 | 298 | if image_set == 'fewshot': 299 | if args.dataset_file in ['coco', 'coco_base']: 300 | class_ids = coco_base_class_ids + coco_novel_class_ids 301 | class_ids.sort() 302 | return build_fewshot(args, image_set, activated_class_ids=class_ids, with_support=True) 303 | if 
args.dataset_file in ['voc', 'voc_base1', 'voc_base2', 'voc_base3']: 304 | return build_fewshot(args, image_set, activated_class_ids=list(range(1, 20+1)), with_support=True) 305 | if args.dataset_file in ['uodd']: 306 | return build_fewshot(args, image_set, activated_class_ids=uodd_class_ids, with_support=True) 307 | if args.dataset_file in ['deepfish']: 308 | return build_fewshot(args, image_set, activated_class_ids=deepfish_class_ids, with_support=True) 309 | if args.dataset_file in ['neu']: 310 | return build_fewshot(args, image_set, activated_class_ids=neu_class_ids, with_support=True) 311 | if args.dataset_file in ['clipart']: 312 | return build_fewshot(args, image_set, activated_class_ids=clipart_class_ids, with_support=True) 313 | if args.dataset_file in ['artaxor']: 314 | return build_fewshot(args, image_set, activated_class_ids=artaxor_class_ids, with_support=True) 315 | if args.dataset_file in ['dior']: 316 | return build_fewshot(args, image_set, activated_class_ids=dior_class_ids, with_support=True) 317 | if args.dataset_file in ['dataset1']: 318 | return build_fewshot(args, image_set, activated_class_ids=dataset1_class_ids, with_support=True) 319 | if args.dataset_file in ['dataset2']: 320 | return build_fewshot(args, image_set, activated_class_ids=dataset2_class_ids, with_support=True) 321 | if args.dataset_file in ['dataset3']: 322 | return build_fewshot(args, image_set, activated_class_ids=dataset3_class_ids, with_support=True) 323 | 324 | raise ValueError(f'{image_set} of dataset {args.dataset_file} not supported.') 325 | -------------------------------------------------------------------------------- /datasets/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from PIL import Image 4 | import torch 5 | import torch.utils.data 6 | from pycocotools.coco import COCO 7 | from pycocotools import mask as coco_mask 8 | 9 | from .torchvision_datasets import CocoDetection as TvCocoDetection 10 | from util.misc import get_local_rank, get_local_size 11 | import datasets.transforms as T 12 | 13 | 14 | class DetectionDataset(TvCocoDetection): 15 | def __init__(self, args, img_folder, ann_file, transforms, support_transforms, return_masks, activated_class_ids, 16 | with_support, cache_mode=False, local_rank=0, local_size=1): 17 | super(DetectionDataset, self).__init__(img_folder, ann_file, cache_mode=cache_mode, local_rank=local_rank, local_size=local_size) 18 | self.with_support = with_support 19 | self.activated_class_ids = activated_class_ids 20 | self._transforms = transforms 21 | self.prepare = ConvertCocoPolysToMask(return_masks) 22 | """ 23 | If with_support = True, this dataset will also produce support images and support targets. 24 | with_support should be set to True for training, and should be set to False for inference. 25 | * During training, support images are sampled along with query images in this dataset. 
26 | * During inference, support images are sampled from dataset_support.py 27 | """ 28 | if self.with_support: 29 | self.NUM_SUPP = args.total_num_support 30 | self.NUM_MAX_POS_SUPP = args.max_pos_support 31 | self.support_transforms = support_transforms 32 | self.build_support_dataset(ann_file) 33 | 34 | def __getitem__(self, idx): 35 | img, target = super(DetectionDataset, self).__getitem__(idx) 36 | target = [anno for anno in target if anno['category_id'] in self.activated_class_ids] 37 | image_id = self.ids[idx] 38 | target = {'image_id': image_id, 'annotations': target} 39 | img, target = self.prepare(img, target) 40 | if self._transforms is not None: 41 | img, target = self._transforms(img, target) 42 | if self.with_support: 43 | support_images, support_class_ids, support_targets = self.sample_support_samples(target) 44 | return img, target, support_images, support_class_ids, support_targets 45 | else: 46 | return img, target 47 | 48 | def build_support_dataset(self, ann_file): 49 | self.anns_by_class = {i: [] for i in self.activated_class_ids} 50 | coco = COCO(ann_file) 51 | for classid in self.activated_class_ids: 52 | annIds = coco.getAnnIds(catIds=classid) 53 | for annId in annIds: 54 | ann = coco.loadAnns(annId)[0] 55 | if 'area' in ann: 56 | if ann['area'] < 5.0: 57 | continue 58 | if 'ignore' in ann: 59 | if ann['ignore']: 60 | continue 61 | if 'iscrowd' in ann: 62 | if ann['iscrowd'] == 1: 63 | continue 64 | ann['image_path'] = coco.loadImgs(ann['image_id'])[0]['file_name'] 65 | self.anns_by_class[classid].append(ann) 66 | 67 | def sample_support_samples(self, target): 68 | positive_labels = target['labels'].unique() 69 | num_positive_labels = positive_labels.shape[0] 70 | positive_labels_list = positive_labels.tolist() 71 | negative_labels_list = list(set(self.activated_class_ids) - set(positive_labels_list)) 72 | 73 | ''' 74 | 在跨域少样本中,类别不像原来的voc/coco一样多,所以这样的策略是有问题的 75 | 因为跨域少样本微调没有基类,导致其类别就新类那么几个,比如三个,此时下面的策略就会有问题 76 | ''' 77 | ''' 78 | # Positive labels in a batch < TRAIN_NUM_POSITIVE_SUPP: we include additional labels as negative samples 79 | if num_positive_labels <= self.NUM_MAX_POS_SUPP: 80 | sampled_labels_list = positive_labels_list 81 | sampled_labels_list += random.sample(negative_labels_list, k=self.NUM_SUPP - num_positive_labels) 82 | # Positive labels in a batch > TRAIN_NUM_POSITIVE_SUPP: remove some positive labels. 83 | else: 84 | sampled_positive_labels_list = random.sample(positive_labels_list, k=self.NUM_MAX_POS_SUPP) 85 | sampled_negative_labels_list = random.sample(negative_labels_list, k=self.NUM_SUPP - self.NUM_MAX_POS_SUPP) 86 | sampled_labels_list = sampled_positive_labels_list + sampled_negative_labels_list 87 | # ----------------------------------------------------------------------- 88 | # NOTE: There is no need to filter gt info at this stage. 89 | # Filtering is done when formulating the episodes. 
90 | # ----------------------------------------------------------------------- 91 | ''' 92 | # 下面的核心思想:不再要求每张图片对应的support都是15个,而是有多少类提多少,这对于多类和少类均适用 93 | # 但对于基类来说,比如voc那就一张图片对应20个提取,coco就是80了 94 | num_support_class = len(self.activated_class_ids) 95 | if num_support_class > 20: 96 | num_support_class = 20 97 | # Positive labels in a batch < TRAIN_NUM_POSITIVE_SUPP: we include additional labels as negative samples 98 | if num_positive_labels <= self.NUM_MAX_POS_SUPP: 99 | sampled_labels_list = positive_labels_list 100 | sampled_labels_list += random.sample(negative_labels_list, k=num_support_class - num_positive_labels) 101 | # Positive labels in a batch > TRAIN_NUM_POSITIVE_SUPP: remove some positive labels. 102 | else: 103 | sampled_positive_labels_list = random.sample(positive_labels_list, k=self.NUM_MAX_POS_SUPP) 104 | sampled_negative_labels_list = random.sample(negative_labels_list, k=num_support_class - self.NUM_MAX_POS_SUPP) 105 | sampled_labels_list = sampled_positive_labels_list + sampled_negative_labels_list 106 | # ----------------------------------------------------------------------- 107 | # NOTE: There is no need to filter gt info at this stage. 108 | # Filtering is done when formulating the episodes. 109 | # ----------------------------------------------------------------------- 110 | support_images = [] 111 | support_targets = [] 112 | support_class_ids = [] 113 | for class_id in sampled_labels_list: 114 | i = random.randint(0, len(self.anns_by_class[class_id]) - 1) 115 | support_target = self.anns_by_class[class_id][i] 116 | support_target = {'image_id': class_id, 'annotations': [support_target]} # Actually it is class_id for key 'image_id' here 117 | support_image_path = os.path.join(self.root, self.anns_by_class[class_id][i]['image_path']) 118 | support_image = Image.open(support_image_path).convert('RGB') 119 | support_image, support_target = self.prepare(support_image, support_target) 120 | if self.support_transforms is not None: 121 | org_support_target, org_support_image = support_target, support_image 122 | attempts = 0 123 | while True: 124 | support_image, support_target = self.support_transforms(org_support_image, org_support_target) 125 | attempts += 1 126 | # Make sure the object is not deleted after transforms, and it is not too small (mostly cut off) 127 | if support_target['boxes'].shape[0] == 1 and support_target['area'] >= org_support_target['area'] / 5.0: 128 | break 129 | # 在跑artaxor和clipart的train时遇到上面循环出不去的情况,主要是area的问题,所以这里加上一个条件 130 | elif support_target['boxes'].shape[0] == 1 and attempts >= 10: 131 | # print("Max attempts reached 10, breaking out of the loop.") 132 | break 133 | support_images.append(support_image) 134 | support_targets.append(support_target) 135 | support_class_ids.append(class_id) 136 | return support_images, torch.as_tensor(support_class_ids), support_targets 137 | 138 | 139 | def convert_coco_poly_to_mask(segmentations, height, width): 140 | masks = [] 141 | for polygons in segmentations: 142 | rles = coco_mask.frPyObjects(polygons, height, width) 143 | mask = coco_mask.decode(rles) 144 | if len(mask.shape) < 3: 145 | mask = mask[..., None] 146 | mask = torch.as_tensor(mask, dtype=torch.uint8) 147 | mask = mask.any(dim=2) 148 | masks.append(mask) 149 | if masks: 150 | masks = torch.stack(masks, dim=0) 151 | else: 152 | masks = torch.zeros((0, height, width), dtype=torch.uint8) 153 | return masks 154 | 155 | 156 | class ConvertCocoPolysToMask(object): 157 | def __init__(self, return_masks=False): 158 | self.return_masks = 
return_masks 159 | 160 | def __call__(self, image, target): 161 | w, h = image.size 162 | 163 | image_id = target["image_id"] 164 | image_id = torch.tensor([image_id]) 165 | 166 | anno = target["annotations"] 167 | 168 | anno = [obj for obj in anno if 'iscrowd' not in obj or obj['iscrowd'] == 0] 169 | 170 | boxes = [obj["bbox"] for obj in anno] 171 | # guard against no boxes via resizing 172 | boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4) 173 | boxes[:, 2:] += boxes[:, :2] 174 | boxes[:, 0::2].clamp_(min=0, max=w) 175 | boxes[:, 1::2].clamp_(min=0, max=h) 176 | 177 | classes = [obj["category_id"] for obj in anno] 178 | classes = torch.tensor(classes, dtype=torch.int64) 179 | 180 | if self.return_masks: 181 | segmentations = [obj["segmentation"] for obj in anno] 182 | masks = convert_coco_poly_to_mask(segmentations, h, w) 183 | 184 | keypoints = None 185 | if anno and "keypoints" in anno[0]: 186 | keypoints = [obj["keypoints"] for obj in anno] 187 | keypoints = torch.as_tensor(keypoints, dtype=torch.float32) 188 | num_keypoints = keypoints.shape[0] 189 | if num_keypoints: 190 | keypoints = keypoints.view(num_keypoints, -1, 3) 191 | 192 | keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0]) 193 | boxes = boxes[keep] 194 | classes = classes[keep] 195 | if self.return_masks: 196 | masks = masks[keep] 197 | if keypoints is not None: 198 | keypoints = keypoints[keep] 199 | 200 | target = {} 201 | target["boxes"] = boxes 202 | target["labels"] = classes 203 | if self.return_masks: 204 | target["masks"] = masks 205 | target["image_id"] = image_id 206 | if keypoints is not None: 207 | target["keypoints"] = keypoints 208 | 209 | # for conversion to coco api 210 | area = torch.tensor([obj["area"] for obj in anno]) 211 | iscrowd = torch.tensor([obj["iscrowd"] if "iscrowd" in obj else 0 for obj in anno]) 212 | target["area"] = area[keep] 213 | target["iscrowd"] = iscrowd[keep] 214 | 215 | target["orig_size"] = torch.as_tensor([int(h), int(w)]) 216 | target["size"] = torch.as_tensor([int(h), int(w)]) 217 | 218 | return image, target 219 | 220 | 221 | def make_transforms(image_set): 222 | """ 223 | Transforms for query images. 224 | """ 225 | normalize = T.Compose([ 226 | T.ToTensor(), 227 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 228 | ]) 229 | 230 | scales = [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800] 231 | 232 | if image_set == 'train': 233 | return T.Compose([ 234 | T.RandomHorizontalFlip(), 235 | T.RandomColorJitter(p=0.3333), 236 | T.RandomSelect( 237 | T.RandomResize(scales, max_size=1152), 238 | T.Compose([ 239 | T.RandomResize([400, 500, 600]), 240 | T.RandomSizeCrop(384, 600), 241 | T.RandomResize(scales, max_size=1152), 242 | ]) 243 | ), 244 | normalize, 245 | ]) 246 | 247 | if image_set == 'val' or image_set == 'test': 248 | return T.Compose([ 249 | T.RandomResize([800], max_size=1152), 250 | normalize, 251 | ]) 252 | 253 | raise ValueError(f'unknown {image_set}') 254 | 255 | 256 | def make_support_transforms(): 257 | """ 258 | Transforms for support images during the training phase. 
259 | For transforms for support images during inference, please check dataset_support.py 260 | """ 261 | normalize = T.Compose([ 262 | T.ToTensor(), 263 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 264 | ]) 265 | 266 | scales = [448, 464, 480, 496, 512, 528, 544, 560, 576, 592, 608, 624, 640, 656, 672] 267 | 268 | return T.Compose([ 269 | T.RandomHorizontalFlip(), 270 | T.RandomColorJitter(p=0.25), 271 | T.RandomSelect( 272 | T.RandomResize(scales, max_size=672), 273 | T.Compose([ 274 | T.RandomResize([400, 500, 600]), 275 | T.RandomSizeCrop(384, 600), 276 | T.RandomResize(scales, max_size=672), 277 | ]) 278 | ), 279 | normalize, 280 | ]) 281 | 282 | 283 | def build(args, img_folder, ann_file, image_set, activated_class_ids, with_support): 284 | return DetectionDataset(args, img_folder, ann_file, 285 | transforms=make_transforms(image_set), 286 | support_transforms=make_support_transforms(), 287 | return_masks=False, 288 | activated_class_ids=activated_class_ids, 289 | with_support=with_support, 290 | cache_mode=args.cache_mode, 291 | local_rank=get_local_rank(), 292 | local_size=get_local_size()) 293 | -------------------------------------------------------------------------------- /datasets/dataset_fewshot.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from datasets.dataset import DetectionDataset 4 | import datasets.transforms as T 5 | from util.misc import get_local_rank, get_local_size 6 | 7 | 8 | def make_transforms(): 9 | """ 10 | Transforms for query images during the few-shot fine-tuning stage. 11 | """ 12 | normalize = T.Compose([ 13 | T.ToTensor(), 14 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 15 | ]) 16 | 17 | scales = [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800] 18 | 19 | return T.Compose([ 20 | T.RandomHorizontalFlip(), 21 | T.RandomColorJitter(p=0.3333), 22 | T.RandomSelect( 23 | T.RandomResize(scales, max_size=1152), 24 | T.Compose([ 25 | T.RandomResize([400, 500, 600]), 26 | T.RandomSizeCrop(384, 600), 27 | T.RandomResize(scales, max_size=1152), 28 | ]) 29 | ), 30 | normalize, 31 | ]) 32 | 33 | 34 | def make_support_transforms(): 35 | """ 36 | Transforms for support images during the few-shot fine-tuning stage. 
37 | For transforms for support images during the base training stage, please check dataset.py 38 | For transforms for support images during inference, please check dataset_support.py 39 | """ 40 | normalize = T.Compose([ 41 | T.ToTensor(), 42 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 43 | ]) 44 | 45 | scales = [448, 464, 480, 496, 512, 528, 544, 560, 576, 592, 608, 624, 640, 656, 672] 46 | 47 | return T.Compose([ 48 | T.RandomHorizontalFlip(), 49 | T.RandomColorJitter(p=0.25), 50 | T.RandomSelect( 51 | T.RandomResize(scales, max_size=672), 52 | T.Compose([ 53 | T.RandomResize([400, 500, 600]), 54 | T.RandomSizeCrop(384, 600), 55 | T.RandomResize(scales, max_size=672), 56 | ]) 57 | ), 58 | normalize, 59 | ]) 60 | 61 | 62 | def build(args, image_set, activated_class_ids, with_support=True): 63 | assert image_set == "fewshot" 64 | activated_class_ids.sort() 65 | 66 | if args.dataset_file in ['coco_base']: 67 | root = Path('data/coco_fewshot') 68 | img_folder = "/home/csy/datasets/mscoco/train2017" 69 | ann_file = root / f'seed{args.fewshot_seed}' / f'{args.num_shots}shot.json' 70 | return DetectionDataset(args, img_folder, str(ann_file), 71 | transforms=make_transforms(), 72 | support_transforms=make_support_transforms(), 73 | return_masks=False, 74 | activated_class_ids=activated_class_ids, 75 | with_support=with_support, 76 | cache_mode=args.cache_mode, 77 | local_rank=get_local_rank(), 78 | local_size=get_local_size()) 79 | 80 | if args.dataset_file == "voc_base1": 81 | root = Path('data/voc_fewshot_split1') 82 | img_folder = root.parent / 'voc' / "images" 83 | ann_file = root / f'seed{args.fewshot_seed}' / f'{args.num_shots}shot.json' 84 | return DetectionDataset(args, img_folder, str(ann_file), 85 | transforms=make_transforms(), 86 | support_transforms=make_support_transforms(), 87 | return_masks=False, 88 | activated_class_ids=activated_class_ids, 89 | with_support=with_support, 90 | cache_mode=args.cache_mode, 91 | local_rank=get_local_rank(), 92 | local_size=get_local_size()) 93 | 94 | if args.dataset_file == "voc_base2": 95 | root = Path('data/voc_fewshot_split2') 96 | img_folder = root.parent / 'voc' / "images" 97 | ann_file = root / f'seed{args.fewshot_seed}' / f'{args.num_shots}shot.json' 98 | return DetectionDataset(args, img_folder, str(ann_file), 99 | transforms=make_transforms(), 100 | support_transforms=make_support_transforms(), 101 | return_masks=False, 102 | activated_class_ids=activated_class_ids, 103 | with_support=with_support, 104 | cache_mode=args.cache_mode, 105 | local_rank=get_local_rank(), 106 | local_size=get_local_size()) 107 | 108 | if args.dataset_file == "voc_base3": 109 | root = Path('data/voc_fewshot_split3') 110 | img_folder = root.parent / 'voc' / "images" 111 | ann_file = root / f'seed{args.fewshot_seed}' / f'{args.num_shots}shot.json' 112 | return DetectionDataset(args, img_folder, str(ann_file), 113 | transforms=make_transforms(), 114 | support_transforms=make_support_transforms(), 115 | return_masks=False, 116 | activated_class_ids=activated_class_ids, 117 | with_support=with_support, 118 | cache_mode=args.cache_mode, 119 | local_rank=get_local_rank(), 120 | local_size=get_local_size()) 121 | # uodd 122 | if args.dataset_file == "uodd": 123 | root = Path('data/UODD') 124 | img_folder = root / "train" 125 | ann_file = root / f'annotations' / f'{args.num_shots}_shot.json' 126 | return DetectionDataset(args, img_folder, str(ann_file), 127 | transforms=make_transforms(), 128 | support_transforms=make_support_transforms(), 129 | 
return_masks=False, 130 | activated_class_ids=activated_class_ids, 131 | with_support=with_support, 132 | cache_mode=args.cache_mode, 133 | local_rank=get_local_rank(), 134 | local_size=get_local_size()) 135 | 136 | # deepfish 137 | if args.dataset_file == "deepfish": 138 | root = Path('data/FISH') 139 | img_folder = root / "train" 140 | ann_file = root / f'annotations' / f'{args.num_shots}_shot.json' 141 | return DetectionDataset(args, img_folder, str(ann_file), 142 | transforms=make_transforms(), 143 | support_transforms=make_support_transforms(), 144 | return_masks=False, 145 | activated_class_ids=activated_class_ids, 146 | with_support=with_support, 147 | cache_mode=args.cache_mode, 148 | local_rank=get_local_rank(), 149 | local_size=get_local_size()) 150 | # neu 151 | if args.dataset_file == "neu": 152 | root = Path('data/NEU-DET') 153 | img_folder = root / "train" 154 | ann_file = root / f'annotations' / f'{args.num_shots}_shot.json' 155 | return DetectionDataset(args, img_folder, str(ann_file), 156 | transforms=make_transforms(), 157 | support_transforms=make_support_transforms(), 158 | return_masks=False, 159 | activated_class_ids=activated_class_ids, 160 | with_support=with_support, 161 | cache_mode=args.cache_mode, 162 | local_rank=get_local_rank(), 163 | local_size=get_local_size()) 164 | # clipart 165 | if args.dataset_file == "clipart": 166 | root = Path('data/clipart1k') 167 | img_folder = root / "train" 168 | ann_file = root / f'annotations' / f'{args.num_shots}_shot.json' 169 | return DetectionDataset(args, img_folder, str(ann_file), 170 | transforms=make_transforms(), 171 | support_transforms=make_support_transforms(), 172 | return_masks=False, 173 | activated_class_ids=activated_class_ids, 174 | with_support=with_support, 175 | cache_mode=args.cache_mode, 176 | local_rank=get_local_rank(), 177 | local_size=get_local_size()) 178 | # artaxor 179 | if args.dataset_file == "artaxor": 180 | root = Path('data/ArTaxOr') 181 | img_folder = root / "train" 182 | ann_file = root / f'annotations' / f'{args.num_shots}_shot.json' 183 | return DetectionDataset(args, img_folder, str(ann_file), 184 | transforms=make_transforms(), 185 | support_transforms=make_support_transforms(), 186 | return_masks=False, 187 | activated_class_ids=activated_class_ids, 188 | with_support=with_support, 189 | cache_mode=args.cache_mode, 190 | local_rank=get_local_rank(), 191 | local_size=get_local_size()) 192 | # dior 193 | if args.dataset_file == "dior": 194 | root = Path('data/DIOR') 195 | img_folder = root / "train" 196 | ann_file = root / f'annotations' / f'{args.num_shots}_shot.json' 197 | return DetectionDataset(args, img_folder, str(ann_file), 198 | transforms=make_transforms(), 199 | support_transforms=make_support_transforms(), 200 | return_masks=False, 201 | activated_class_ids=activated_class_ids, 202 | with_support=with_support, 203 | cache_mode=args.cache_mode, 204 | local_rank=get_local_rank(), 205 | local_size=get_local_size()) 206 | # dataset1 207 | if args.dataset_file == "dataset1": 208 | root = Path('data/dataset1') 209 | img_folder = root / "train" 210 | ann_file = root / f'annotations' / f'{args.num_shots}_shot.json' 211 | return DetectionDataset(args, img_folder, str(ann_file), 212 | transforms=make_transforms(), 213 | support_transforms=make_support_transforms(), 214 | return_masks=False, 215 | activated_class_ids=activated_class_ids, 216 | with_support=with_support, 217 | cache_mode=args.cache_mode, 218 | local_rank=get_local_rank(), 219 | local_size=get_local_size()) 220 | # 
dataset2 221 | if args.dataset_file == "dataset2": 222 | root = Path('data/dataset2') 223 | img_folder = root / "train" 224 | ann_file = root / f'annotations' / f'{args.num_shots}_shot.json' 225 | return DetectionDataset(args, img_folder, str(ann_file), 226 | transforms=make_transforms(), 227 | support_transforms=make_support_transforms(), 228 | return_masks=False, 229 | activated_class_ids=activated_class_ids, 230 | with_support=with_support, 231 | cache_mode=args.cache_mode, 232 | local_rank=get_local_rank(), 233 | local_size=get_local_size()) 234 | # dataset3 235 | if args.dataset_file == "dataset3": 236 | root = Path('data/dataset3') 237 | img_folder = root / "train" 238 | ann_file = root / f'annotations' / f'{args.num_shots}_shot.json' 239 | return DetectionDataset(args, img_folder, str(ann_file), 240 | transforms=make_transforms(), 241 | support_transforms=make_support_transforms(), 242 | return_masks=False, 243 | activated_class_ids=activated_class_ids, 244 | with_support=with_support, 245 | cache_mode=args.cache_mode, 246 | local_rank=get_local_rank(), 247 | local_size=get_local_size()) 248 | 249 | 250 | raise ValueError 251 | -------------------------------------------------------------------------------- /datasets/eval_detection.py: -------------------------------------------------------------------------------- 1 | """ 2 | COCO evaluator that works in distributed mode. 3 | 4 | Mostly copy-paste from https://github.com/pytorch/vision/blob/edfd5a7/references/detection/coco_eval.py 5 | The difference is that there is less copy-pasting from pycocotools 6 | in the end of the file, as python3 can suppress prints with contextlib 7 | """ 8 | import os 9 | import contextlib 10 | import copy 11 | import numpy as np 12 | import torch 13 | import json 14 | 15 | from pycocotools.cocoeval import COCOeval 16 | from pycocotools.coco import COCO 17 | import pycocotools.mask as mask_util 18 | 19 | from util.misc import all_gather 20 | 21 | 22 | class DetectionEvaluator(object): 23 | def __init__(self, coco_gt, iou_types): 24 | assert isinstance(iou_types, (list, tuple)) 25 | coco_gt = copy.deepcopy(coco_gt) 26 | self.coco_gt = coco_gt 27 | 28 | self.iou_types = iou_types 29 | self.coco_eval = {} 30 | for iou_type in iou_types: 31 | self.coco_eval[iou_type] = COCOeval(coco_gt, iouType=iou_type) 32 | 33 | self.img_ids = [] 34 | self.eval_imgs = {k: [] for k in iou_types} 35 | 36 | def update(self, predictions): 37 | img_ids = list(np.unique(list(predictions.keys()))) 38 | self.img_ids.extend(img_ids) 39 | 40 | for iou_type in self.iou_types: 41 | results = self.prepare(predictions, iou_type) 42 | 43 | # suppress pycocotools prints 44 | with open(os.devnull, 'w') as devnull: 45 | with contextlib.redirect_stdout(devnull): 46 | coco_dt = COCO.loadRes(self.coco_gt, results) if results else COCO() 47 | coco_eval = self.coco_eval[iou_type] 48 | 49 | coco_eval.cocoDt = coco_dt 50 | coco_eval.params.imgIds = list(img_ids) 51 | img_ids, eval_imgs = evaluate(coco_eval) 52 | 53 | self.eval_imgs[iou_type].append(eval_imgs) 54 | 55 | def synchronize_between_processes(self): 56 | for iou_type in self.iou_types: 57 | self.eval_imgs[iou_type] = np.concatenate(self.eval_imgs[iou_type], 2) 58 | create_common_coco_eval(self.coco_eval[iou_type], self.img_ids, self.eval_imgs[iou_type]) 59 | 60 | def accumulate(self): 61 | for coco_eval in self.coco_eval.values(): 62 | coco_eval.accumulate() 63 | 64 | def summarize(self): 65 | for iou_type, coco_eval in self.coco_eval.items(): 66 | print("IoU metric: 
{}".format(iou_type)) 67 | coco_eval.summarize() 68 | 69 | def prepare(self, predictions, iou_type): 70 | if iou_type == "bbox": 71 | return self.prepare_for_coco_detection(predictions) 72 | elif iou_type == "segm": 73 | return self.prepare_for_coco_segmentation(predictions) 74 | elif iou_type == "keypoints": 75 | return self.prepare_for_coco_keypoint(predictions) 76 | else: 77 | raise ValueError("Unknown iou type {}".format(iou_type)) 78 | 79 | def prepare_for_coco_detection(self, predictions): 80 | coco_results = [] 81 | for original_id, prediction in predictions.items(): 82 | if len(prediction) == 0: 83 | continue 84 | 85 | boxes = prediction["boxes"] 86 | boxes = convert_to_xywh(boxes).tolist() 87 | scores = prediction["scores"].tolist() 88 | labels = prediction["labels"].tolist() 89 | 90 | coco_results.extend( 91 | [ 92 | { 93 | "image_id": original_id, 94 | "category_id": labels[k], 95 | "bbox": box, 96 | "score": scores[k], 97 | } 98 | for k, box in enumerate(boxes) 99 | ] 100 | ) 101 | return coco_results 102 | 103 | def prepare_for_coco_segmentation(self, predictions): 104 | coco_results = [] 105 | for original_id, prediction in predictions.items(): 106 | if len(prediction) == 0: 107 | continue 108 | 109 | scores = prediction["scores"] 110 | labels = prediction["labels"] 111 | masks = prediction["masks"] 112 | 113 | masks = masks > 0.5 114 | 115 | scores = prediction["scores"].tolist() 116 | labels = prediction["labels"].tolist() 117 | 118 | rles = [ 119 | mask_util.encode(np.array(mask[0, :, :, np.newaxis], dtype=np.uint8, order="F"))[0] 120 | for mask in masks 121 | ] 122 | for rle in rles: 123 | rle["counts"] = rle["counts"].decode("utf-8") 124 | 125 | coco_results.extend( 126 | [ 127 | { 128 | "image_id": original_id, 129 | "category_id": labels[k], 130 | "segmentation": rle, 131 | "score": scores[k], 132 | } 133 | for k, rle in enumerate(rles) 134 | ] 135 | ) 136 | return coco_results 137 | 138 | def prepare_for_coco_keypoint(self, predictions): 139 | coco_results = [] 140 | for original_id, prediction in predictions.items(): 141 | if len(prediction) == 0: 142 | continue 143 | 144 | boxes = prediction["boxes"] 145 | boxes = convert_to_xywh(boxes).tolist() 146 | scores = prediction["scores"].tolist() 147 | labels = prediction["labels"].tolist() 148 | keypoints = prediction["keypoints"] 149 | keypoints = keypoints.flatten(start_dim=1).tolist() 150 | 151 | coco_results.extend( 152 | [ 153 | { 154 | "image_id": original_id, 155 | "category_id": labels[k], 156 | 'keypoints': keypoint, 157 | "score": scores[k], 158 | } 159 | for k, keypoint in enumerate(keypoints) 160 | ] 161 | ) 162 | return coco_results 163 | 164 | 165 | def convert_to_xywh(boxes): 166 | xmin, ymin, xmax, ymax = boxes.unbind(1) 167 | return torch.stack((xmin, ymin, xmax - xmin, ymax - ymin), dim=1) 168 | 169 | 170 | def merge(img_ids, eval_imgs): 171 | all_img_ids = all_gather(img_ids) 172 | all_eval_imgs = all_gather(eval_imgs) 173 | 174 | merged_img_ids = [] 175 | for p in all_img_ids: 176 | merged_img_ids.extend(p) 177 | 178 | merged_eval_imgs = [] 179 | for p in all_eval_imgs: 180 | merged_eval_imgs.append(p) 181 | 182 | merged_img_ids = np.array(merged_img_ids) 183 | merged_eval_imgs = np.concatenate(merged_eval_imgs, 2) 184 | 185 | # keep only unique (and in sorted order) images 186 | merged_img_ids, idx = np.unique(merged_img_ids, return_index=True) 187 | merged_eval_imgs = merged_eval_imgs[..., idx] 188 | 189 | return merged_img_ids, merged_eval_imgs 190 | 191 | 192 | def 
create_common_coco_eval(coco_eval, img_ids, eval_imgs): 193 | img_ids, eval_imgs = merge(img_ids, eval_imgs) 194 | img_ids = list(img_ids) 195 | eval_imgs = list(eval_imgs.flatten()) 196 | 197 | coco_eval.evalImgs = eval_imgs 198 | coco_eval.params.imgIds = img_ids 199 | coco_eval._paramsEval = copy.deepcopy(coco_eval.params) 200 | 201 | 202 | ################################################################# 203 | # From pycocotools, just removed the prints and fixed 204 | # a Python3 bug about unicode not defined 205 | ################################################################# 206 | 207 | 208 | def evaluate(self): 209 | ''' 210 | Run per image evaluation on given images and store results (a list of dict) in self.evalImgs 211 | :return: None 212 | ''' 213 | # tic = time.time() 214 | # print('Running per image evaluation...') 215 | p = self.params 216 | # add backward compatibility if useSegm is specified in params 217 | if p.useSegm is not None: 218 | p.iouType = 'segm' if p.useSegm == 1 else 'bbox' 219 | print('useSegm (deprecated) is not None. Running {} evaluation'.format(p.iouType)) 220 | # print('Evaluate annotation type *{}*'.format(p.iouType)) 221 | p.imgIds = list(np.unique(p.imgIds)) 222 | if p.useCats: 223 | p.catIds = list(np.unique(p.catIds)) 224 | p.maxDets = sorted(p.maxDets) 225 | self.params = p 226 | 227 | self._prepare() 228 | # loop through images, area range, max detection number 229 | catIds = p.catIds if p.useCats else [-1] 230 | 231 | if p.iouType == 'segm' or p.iouType == 'bbox': 232 | computeIoU = self.computeIoU 233 | elif p.iouType == 'keypoints': 234 | computeIoU = self.computeOks 235 | self.ious = { 236 | (imgId, catId): computeIoU(imgId, catId) 237 | for imgId in p.imgIds 238 | for catId in catIds} 239 | 240 | evaluateImg = self.evaluateImg 241 | maxDet = p.maxDets[-1] 242 | evalImgs = [ 243 | evaluateImg(imgId, catId, areaRng, maxDet) 244 | for catId in catIds 245 | for areaRng in p.areaRng 246 | for imgId in p.imgIds 247 | ] 248 | # this is NOT in the pycocotools code, but could be done outside 249 | evalImgs = np.asarray(evalImgs).reshape(len(catIds), len(p.areaRng), len(p.imgIds)) 250 | self._paramsEval = copy.deepcopy(self.params) 251 | # toc = time.time() 252 | # print('DONE (t={:0.2f}s).'.format(toc-tic)) 253 | return p.imgIds, evalImgs 254 | 255 | ################################################################# 256 | # end of straight copy from pycocotools, just removing the prints 257 | ################################################################# 258 | -------------------------------------------------------------------------------- /datasets/samplers.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import torch 4 | import torch.distributed as dist 5 | from torch.utils.data.sampler import Sampler 6 | 7 | 8 | class DistributedSampler(Sampler): 9 | """Sampler that restricts data loading to a subset of the dataset. 10 | It is especially useful in conjunction with 11 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each 12 | process can pass a DistributedSampler instance as a DataLoader sampler, 13 | and load a subset of the original dataset that is exclusive to it. 14 | .. note:: 15 | Dataset is assumed to be of constant size. 16 | Arguments: 17 | dataset: Dataset used for sampling. 18 | num_replicas (optional): Number of processes participating in 19 | distributed training. 20 | rank (optional): Rank of the current process within num_replicas. 
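        shuffle (optional): Whether to shuffle the indices; shuffling is deterministic and
            seeded by the current epoch.
    Illustrative usage (editor's addition, not part of the original docstring):
        sampler = DistributedSampler(dataset)
        loader = torch.utils.data.DataLoader(dataset, sampler=sampler, ...)
        for epoch in range(num_epochs):
            sampler.set_epoch(epoch)  # change the shuffling seed every epoch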
21 | """ 22 | 23 | def __init__(self, dataset, num_replicas=None, rank=None, local_rank=None, local_size=None, shuffle=True): 24 | if num_replicas is None: 25 | if not dist.is_available(): 26 | raise RuntimeError("Requires distributed package to be available") 27 | num_replicas = dist.get_world_size() 28 | if rank is None: 29 | if not dist.is_available(): 30 | raise RuntimeError("Requires distributed package to be available") 31 | rank = dist.get_rank() 32 | self.dataset = dataset 33 | self.num_replicas = num_replicas 34 | self.rank = rank 35 | self.epoch = 0 36 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) 37 | self.total_size = self.num_samples * self.num_replicas 38 | self.shuffle = shuffle 39 | 40 | def __iter__(self): 41 | if self.shuffle: 42 | # deterministically shuffle based on epoch 43 | g = torch.Generator() 44 | g.manual_seed(self.epoch) 45 | indices = torch.randperm(len(self.dataset), generator=g).tolist() 46 | else: 47 | indices = torch.arange(len(self.dataset)).tolist() 48 | 49 | # add extra samples to make it evenly divisible 50 | indices += indices[: (self.total_size - len(indices))] 51 | assert len(indices) == self.total_size 52 | 53 | # subsample 54 | offset = self.num_samples * self.rank 55 | indices = indices[offset : offset + self.num_samples] 56 | assert len(indices) == self.num_samples 57 | 58 | return iter(indices) 59 | 60 | def __len__(self): 61 | return self.num_samples 62 | 63 | def set_epoch(self, epoch): 64 | self.epoch = epoch 65 | 66 | 67 | class NodeDistributedSampler(Sampler): 68 | """Sampler that restricts data loading to a subset of the dataset. 69 | It is especially useful in conjunction with 70 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each 71 | process can pass a DistributedSampler instance as a DataLoader sampler, 72 | and load a subset of the original dataset that is exclusive to it. 73 | .. note:: 74 | Dataset is assumed to be of constant size. 75 | Arguments: 76 | dataset: Dataset used for sampling. 77 | num_replicas (optional): Number of processes participating in 78 | distributed training. 79 | rank (optional): Rank of the current process within num_replicas. 
80 | """ 81 | 82 | def __init__(self, dataset, num_replicas=None, rank=None, local_rank=None, local_size=None, shuffle=True): 83 | if num_replicas is None: 84 | if not dist.is_available(): 85 | raise RuntimeError("Requires distributed package to be available") 86 | num_replicas = dist.get_world_size() 87 | if rank is None: 88 | if not dist.is_available(): 89 | raise RuntimeError("Requires distributed package to be available") 90 | rank = dist.get_rank() 91 | if local_rank is None: 92 | local_rank = int(os.environ.get('LOCAL_RANK', 0)) 93 | if local_size is None: 94 | local_size = int(os.environ.get('LOCAL_SIZE', 1)) 95 | self.dataset = dataset 96 | self.shuffle = shuffle 97 | self.num_replicas = num_replicas 98 | self.num_parts = local_size 99 | self.rank = rank 100 | self.local_rank = local_rank 101 | self.epoch = 0 102 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) 103 | self.total_size = self.num_samples * self.num_replicas 104 | 105 | self.total_size_parts = self.num_samples * self.num_replicas // self.num_parts 106 | 107 | def __iter__(self): 108 | if self.shuffle: 109 | # deterministically shuffle based on epoch 110 | g = torch.Generator() 111 | g.manual_seed(self.epoch) 112 | indices = torch.randperm(len(self.dataset), generator=g).tolist() 113 | else: 114 | indices = torch.arange(len(self.dataset)).tolist() 115 | indices = [i for i in indices if i % self.num_parts == self.local_rank] 116 | 117 | # add extra samples to make it evenly divisible 118 | indices += indices[:(self.total_size_parts - len(indices))] 119 | assert len(indices) == self.total_size_parts 120 | 121 | # subsample 122 | indices = indices[self.rank // self.num_parts:self.total_size_parts:self.num_replicas // self.num_parts] 123 | assert len(indices) == self.num_samples 124 | 125 | return iter(indices) 126 | 127 | def __len__(self): 128 | return self.num_samples 129 | 130 | def set_epoch(self, epoch): 131 | self.epoch = epoch 132 | -------------------------------------------------------------------------------- /datasets/torchvision_datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .coco import CocoDetection 2 | -------------------------------------------------------------------------------- /datasets/torchvision_datasets/coco.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copy-Paste from torchvision, but add utility of caching images on memory 3 | """ 4 | from torchvision.datasets.vision import VisionDataset 5 | from PIL import Image 6 | import os 7 | import os.path 8 | from io import BytesIO 9 | 10 | 11 | class CocoDetection(VisionDataset): 12 | """`MS Coco Detection `_ Dataset. 13 | Args: 14 | root (string): Root directory where images are downloaded to. 15 | annFile (string): Path to json annotation file. 16 | transform (callable, optional): A function/transform that takes in an PIL image 17 | and returns a transformed version. E.g, ``transforms.ToTensor`` 18 | target_transform (callable, optional): A function/transform that takes in the 19 | target and transforms it. 20 | transforms (callable, optional): A function/transform that takes input sample and its target as entry 21 | and returns a transformed version. 
22 | """ 23 | 24 | def __init__(self, root, annFile, transform=None, target_transform=None, transforms=None, 25 | cache_mode=False, local_rank=0, local_size=1): 26 | super(CocoDetection, self).__init__(root, transforms, transform, target_transform) 27 | from pycocotools.coco import COCO 28 | self.coco = COCO(annFile) 29 | self.ids = list(sorted(self.coco.imgs.keys())) 30 | self.cache_mode = cache_mode 31 | self.local_rank = local_rank 32 | self.local_size = local_size 33 | if cache_mode: 34 | self.cache = {} 35 | self.cache_images() 36 | 37 | def cache_images(self): 38 | self.cache = {} 39 | for index, img_id in zip(range(len(self.ids)), self.ids): 40 | if index % self.local_size != self.local_rank: 41 | continue 42 | path = self.coco.loadImgs(img_id)[0]['file_name'] 43 | with open(os.path.join(self.root, path), 'rb') as f: 44 | self.cache[path] = f.read() 45 | 46 | def get_image(self, path): 47 | if self.cache_mode: 48 | if path not in self.cache.keys(): 49 | with open(os.path.join(self.root, path), 'rb') as f: 50 | self.cache[path] = f.read() 51 | return Image.open(BytesIO(self.cache[path])).convert('RGB') 52 | return Image.open(os.path.join(self.root, path)).convert('RGB') 53 | 54 | def __getitem__(self, index): 55 | """ 56 | Args: 57 | index (int): Index 58 | Returns: 59 | tuple: Tuple (image, target). target is the object returned by ``coco.loadAnns``. 60 | """ 61 | coco = self.coco 62 | img_id = self.ids[index] 63 | ann_ids = coco.getAnnIds(imgIds=img_id) 64 | target = coco.loadAnns(ann_ids) 65 | 66 | path = coco.loadImgs(img_id)[0]['file_name'] 67 | # print(path) 68 | img = self.get_image(path) 69 | if self.transforms is not None: 70 | img, target = self.transforms(img, target) 71 | 72 | return img, target 73 | 74 | def __len__(self): 75 | return len(self.ids) 76 | -------------------------------------------------------------------------------- /datasets/transforms.py: -------------------------------------------------------------------------------- 1 | """ 2 | Transforms and data augmentation for both image + bbox. 3 | """ 4 | import random 5 | 6 | import PIL 7 | import torch 8 | import torchvision.transforms as T 9 | import torchvision.transforms.functional as F 10 | 11 | from util.box_ops import box_xyxy_to_cxcywh 12 | from util.misc import interpolate 13 | 14 | 15 | def crop(image, target, region): 16 | cropped_image = F.crop(image, *region) 17 | 18 | target = target.copy() 19 | i, j, h, w = region 20 | 21 | # should we do something wrt the original size? 22 | target["size"] = torch.tensor([h, w]) 23 | 24 | fields = ["labels", "area", "iscrowd"] 25 | 26 | if "boxes" in target: 27 | boxes = target["boxes"] 28 | max_size = torch.as_tensor([w, h], dtype=torch.float32) 29 | cropped_boxes = boxes - torch.as_tensor([j, i, j, i]) 30 | cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size) 31 | cropped_boxes = cropped_boxes.clamp(min=0) 32 | area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1) 33 | target["boxes"] = cropped_boxes.reshape(-1, 4) 34 | target["area"] = area 35 | fields.append("boxes") 36 | 37 | if "masks" in target: 38 | # FIXME should we update the area here if there are no boxes? 
39 | target['masks'] = target['masks'][:, i:i + h, j:j + w] 40 | fields.append("masks") 41 | 42 | # remove elements for which the boxes or masks that have zero area 43 | if "boxes" in target or "masks" in target: 44 | # favor boxes selection when defining which elements to keep 45 | # this is compatible with previous implementation 46 | if "boxes" in target: 47 | cropped_boxes = target['boxes'].reshape(-1, 2, 2) 48 | keep = torch.all(cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1) 49 | else: 50 | keep = target['masks'].flatten(1).any(1) 51 | 52 | for field in fields: 53 | target[field] = target[field][keep] 54 | 55 | return cropped_image, target 56 | 57 | 58 | def hflip(image, target): 59 | flipped_image = F.hflip(image) 60 | 61 | w, h = image.size 62 | 63 | target = target.copy() 64 | if "boxes" in target: 65 | boxes = target["boxes"] 66 | boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor([-1, 1, -1, 1]) + torch.as_tensor([w, 0, w, 0]) 67 | target["boxes"] = boxes 68 | 69 | if "masks" in target: 70 | target['masks'] = target['masks'].flip(-1) 71 | 72 | return flipped_image, target 73 | 74 | 75 | def resize(image, target, size, max_size=None): 76 | # size can be min_size (scalar) or (w, h) tuple 77 | 78 | def get_size_with_aspect_ratio(image_size, size, max_size=None): 79 | w, h = image_size 80 | if max_size is not None: 81 | min_original_size = float(min((w, h))) 82 | max_original_size = float(max((w, h))) 83 | if max_original_size / min_original_size * size > max_size: 84 | size = int(round(max_size * min_original_size / max_original_size)) 85 | 86 | if (w <= h and w == size) or (h <= w and h == size): 87 | return (h, w) 88 | 89 | if w < h: 90 | ow = size 91 | oh = int(size * h / w) 92 | else: 93 | oh = size 94 | ow = int(size * w / h) 95 | 96 | return (oh, ow) 97 | 98 | def get_size(image_size, size, max_size=None): 99 | if isinstance(size, (list, tuple)): 100 | return size[::-1] 101 | else: 102 | return get_size_with_aspect_ratio(image_size, size, max_size) 103 | 104 | size = get_size(image.size, size, max_size) 105 | rescaled_image = F.resize(image, size) 106 | 107 | if target is None: 108 | return rescaled_image, None 109 | 110 | ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size)) 111 | ratio_width, ratio_height = ratios 112 | 113 | target = target.copy() 114 | if "boxes" in target: 115 | boxes = target["boxes"] 116 | scaled_boxes = boxes * torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height]) 117 | target["boxes"] = scaled_boxes 118 | 119 | if "area" in target: 120 | area = target["area"] 121 | scaled_area = area * (ratio_width * ratio_height) 122 | target["area"] = scaled_area 123 | 124 | h, w = size 125 | target["size"] = torch.tensor([h, w]) 126 | 127 | if "masks" in target: 128 | target['masks'] = interpolate( target['masks'][:, None].float(), size, mode="nearest")[:, 0] > 0.5 129 | 130 | return rescaled_image, target 131 | 132 | 133 | def pad(image, target, padding): 134 | # assumes that we only pad on the bottom right corners 135 | padded_image = F.pad(image, (0, 0, padding[0], padding[1])) 136 | if target is None: 137 | return padded_image, None 138 | target = target.copy() 139 | # should we do something wrt the original size? 
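    # NOTE (editorial): the reference DETR transform records the padded image size here as
    # torch.tensor(padded_image.size[::-1]); a PIL image itself is not subscriptable, so the
    # slice on the next line presumably refers to the .size attribute.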
140 | target["size"] = torch.tensor(padded_image[::-1]) 141 | if "masks" in target: 142 | target['masks'] = torch.nn.functional.pad(target['masks'], (0, padding[0], 0, padding[1])) 143 | return padded_image, target 144 | 145 | 146 | class RandomCrop(object): 147 | def __init__(self, size): 148 | self.size = size 149 | 150 | def __call__(self, img, target): 151 | region = T.RandomCrop.get_params(img, self.size) 152 | return crop(img, target, region) 153 | 154 | 155 | class RandomSizeCrop(object): 156 | def __init__(self, min_size: int, max_size: int): 157 | self.min_size = min_size 158 | self.max_size = max_size 159 | 160 | def __call__(self, img: PIL.Image.Image, target: dict): 161 | w = random.randint(self.min_size, min(img.width, self.max_size)) 162 | h = random.randint(self.min_size, min(img.height, self.max_size)) 163 | region = T.RandomCrop.get_params(img, [h, w]) 164 | return crop(img, target, region) 165 | 166 | 167 | class CenterCrop(object): 168 | def __init__(self, size): 169 | self.size = size 170 | 171 | def __call__(self, img, target): 172 | image_width, image_height = img.size 173 | crop_height, crop_width = self.size 174 | crop_top = int(round((image_height - crop_height) / 2.)) 175 | crop_left = int(round((image_width - crop_width) / 2.)) 176 | return crop(img, target, (crop_top, crop_left, crop_height, crop_width)) 177 | 178 | 179 | class RandomHorizontalFlip(object): 180 | def __init__(self, p=0.5): 181 | self.p = p 182 | 183 | def __call__(self, img, target): 184 | if random.random() < self.p: 185 | return hflip(img, target) 186 | return img, target 187 | 188 | 189 | class RandomResize(object): 190 | def __init__(self, sizes, max_size=None): 191 | assert isinstance(sizes, (list, tuple)) 192 | self.sizes = sizes 193 | self.max_size = max_size 194 | 195 | def __call__(self, img, target=None): 196 | size = random.choice(self.sizes) 197 | return resize(img, target, size, self.max_size) 198 | 199 | 200 | class RandomPad(object): 201 | def __init__(self, max_pad): 202 | self.max_pad = max_pad 203 | 204 | def __call__(self, img, target): 205 | pad_x = random.randint(0, self.max_pad) 206 | pad_y = random.randint(0, self.max_pad) 207 | return pad(img, target, (pad_x, pad_y)) 208 | 209 | 210 | class RandomSelect(object): 211 | """ 212 | Randomly selects between transforms1 and transforms2, 213 | with probability p for transforms1 and (1 - p) for transforms2 214 | """ 215 | def __init__(self, transforms1, transforms2, p=0.5): 216 | self.transforms1 = transforms1 217 | self.transforms2 = transforms2 218 | self.p = p 219 | 220 | def __call__(self, img, target): 221 | if random.random() < self.p: 222 | return self.transforms1(img, target) 223 | return self.transforms2(img, target) 224 | 225 | 226 | class ToTensor(object): 227 | def __call__(self, img, target): 228 | return F.to_tensor(img), target 229 | 230 | 231 | class RandomErasing(object): 232 | def __init__(self, *args, **kwargs): 233 | self.eraser = T.RandomErasing(*args, **kwargs) 234 | 235 | def __call__(self, img, target): 236 | return self.eraser(img), target 237 | 238 | 239 | class RandomColorJitter(object): 240 | def __init__(self, p=0.5): 241 | self.p = p 242 | self.colorjitter = T.ColorJitter(brightness=0.40, contrast=0.40, saturation=0.40, hue=0.20) 243 | 244 | def __call__(self, img, target): 245 | if random.random() < self.p: 246 | return self.colorjitter(img), target 247 | else: 248 | return img, target 249 | 250 | 251 | class Normalize(object): 252 | def __init__(self, mean, std): 253 | self.mean = mean 254 | 
self.std = std 255 | 256 | def __call__(self, image, target=None): 257 | image = F.normalize(image, mean=self.mean, std=self.std) 258 | if target is None: 259 | return image, None 260 | target = target.copy() 261 | h, w = image.shape[-2:] 262 | if "boxes" in target: 263 | boxes = target["boxes"] 264 | boxes = box_xyxy_to_cxcywh(boxes) 265 | boxes = boxes / torch.tensor([w, h, w, h], dtype=torch.float32) 266 | target["boxes"] = boxes 267 | return image, target 268 | 269 | 270 | class Compose(object): 271 | def __init__(self, transforms): 272 | self.transforms = transforms 273 | 274 | def __call__(self, image, target): 275 | for t in self.transforms: 276 | image, target = t(image, target) 277 | return image, target 278 | 279 | def __repr__(self): 280 | format_string = self.__class__.__name__ + "(" 281 | for t in self.transforms: 282 | format_string += "\n" 283 | format_string += " {0}".format(t) 284 | format_string += "\n)" 285 | return format_string 286 | -------------------------------------------------------------------------------- /datasets/visual.py: -------------------------------------------------------------------------------- 1 | #画图相关 2 | import util.misc as utils 3 | from datasets.eval_detection import DetectionEvaluator 4 | import math 5 | import matplotlib.pyplot as plt 6 | import matplotlib.patches as patches 7 | import numpy as np 8 | import torchvision.transforms 9 | import os 10 | import random 11 | from PIL import Image 12 | import matplotlib.colors as mcolors 13 | import matplotlib.colors as mplc 14 | import colorsys 15 | 16 | def change_color_brightness(color, brightness_factor = -0.7): 17 | """ 18 | Depending on the brightness_factor, gives a lighter or darker color i.e. a color with 19 | less or more saturation than the original color. 20 | 21 | Args: 22 | color: color of the polygon. Refer to `matplotlib.colors` for a full list of 23 | formats that are accepted. 24 | brightness_factor (float): a value in [-1.0, 1.0] range. A lightness factor of 25 | 0 will correspond to no change, a factor in [-1.0, 0) range will result in 26 | a darker color and a factor in (0, 1.0] range will result in a lighter color. 27 | 28 | Returns: 29 | modified_color (tuple[double]): a tuple containing the RGB values of the 30 | modified color. Each value in the tuple is in the [0.0, 1.0] range. 
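# Editorial sketch (not from the repo): a typical DETR-style evaluation pipeline assembled
# from the classes in datasets/transforms.py above. The ImageNet mean/std values and the
# 800/1333 sizes are conventional defaults, not values read from this repository's configs.
import datasets.transforms as T

normalize = T.Compose([
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
val_transforms = T.Compose([
    T.RandomResize([800], max_size=1333),
    normalize,
])
# img_tensor, target = val_transforms(pil_image, coco_style_target)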
31 | """ 32 | assert brightness_factor >= -1.0 and brightness_factor <= 1.0 33 | color = mplc.to_rgb(color) 34 | polygon_color = colorsys.rgb_to_hls(*mplc.to_rgb(color)) 35 | modified_lightness = polygon_color[1] + (brightness_factor * polygon_color[1]) 36 | modified_lightness = 0.0 if modified_lightness < 0.0 else modified_lightness 37 | modified_lightness = 1.0 if modified_lightness > 1.0 else modified_lightness 38 | modified_color = colorsys.hls_to_rgb(polygon_color[0], modified_lightness, polygon_color[2]) 39 | return modified_color 40 | 41 | def draw_text(ax, text, position, font_size=None, color="g", horizontal_alignment="left", rotation=0): 42 | """ 43 | 在给定的Axes上绘制文本。 44 | 45 | Args: 46 | ax: 要绘制的matplotlib Axes对象。 47 | text (str): 要绘制的文本。 48 | position (tuple): 文本位置的x和y坐标。 49 | font_size (int): 文本字体大小。 50 | color: 文本颜色。 51 | horizontal_alignment (str): 水平对齐方式。 52 | rotation: 旋转角度(单位为度)。 53 | """ 54 | # 确保文本颜色明亮 55 | color = np.maximum(list(mcolors.to_rgb(color)), 0.2) 56 | color[np.argmax(color)] = max(0.8, np.max(color)) 57 | 58 | x, y = position 59 | ax.text( 60 | x, 61 | y, 62 | text, 63 | size=font_size, 64 | family="sans-serif", 65 | bbox={"facecolor": "black", "alpha": 0.8, "pad": 0.5, "edgecolor": "none"}, 66 | verticalalignment="top", 67 | horizontalalignment=horizontal_alignment, 68 | color=color, 69 | zorder=10, 70 | rotation=rotation, 71 | ) 72 | 73 | 74 | def visualize_predictions(image, targets, predictions, output_dir, draw_gt, draw_pre): 75 | UODD_TEST = ['seacucumber', 'seaurchin', 'scallop'] 76 | clipart1k_TEST = ['sheep', 'chair', 'boat', 'bottle', 'diningtable', 'sofa', 'cow', 'motorbike', 'car', 'aeroplane', 'cat', 'train', 'person', 'bicycle', 'pottedplant', 'bird', 'dog', 'bus', 'tvmonitor', 'horse'] 77 | NEUDET_TEST = ['crazing', 'inclusion', 'patches', 'pitted_surface', 'rolled-in_scale', 'scratches'] 78 | ArTaxOr_TEST = ['Araneae', 'Coleoptera', 'Diptera', 'Hemiptera', 'Hymenoptera', 'Lepidoptera', 'Odonata'] 79 | DIOR_TEST = ['Expressway-Service-area','Expressway-toll-station','airplane','airport','baseballfield','basketballcourt','bridge','chimney','dam', 'golffield', 'groundtrackfield', 'harbor', 'overpass', 'ship', 'stadium', 'storagetank', 'tenniscourt', 'trainstation', 'vehicle', 'windmill'] 80 | FISH_TEST = ['fish'] 81 | 82 | class_names = UODD_TEST 83 | # CUSTOM_COLORS = [ 84 | # (220/255, 20/255, 60/255), (0/255, 128/255, 0/255), (0/255, 0/255, 255/255), 85 | # (255/255, 140/255, 0/255), (128/255, 0/255, 128/255), 86 | # (0/255, 255/255, 255/255), (255/255, 0/255, 255/255), (75/255, 0/255, 130/255), 87 | # (0/255, 255/255, 127/255), (255/255, 69/255, 0/255), 88 | # (255/255, 228/255, 181/255), (255/255, 20/255, 147/255), 89 | # (154/255, 205/255, 50/255), (139/255, 69/255, 19/255), 90 | # (255/255, 215/255, 0/255), 91 | # (0/255, 191/255, 255/255), (47/255, 79/255, 79/255), 92 | # (188/255, 143/255, 143/255), (255/255, 99/255, 71/255), 93 | # (205/255, 92/255, 92/255), 94 | # (144/255, 238/255, 144/255), (30/255, 144/255, 255/255), 95 | # (128/255, 128/255, 0/255), (107/255, 142/255, 35/255), 96 | # (255/255, 127/255, 80/255) 97 | # ] 98 | #DIOR 99 | CUSTOM_COLORS = [ 100 | (220/255, 20/255, 60/255), (205/255, 92/255, 92/255), (255/255, 215/255, 0/255), 101 | (255/255, 140/255, 0/255), (255/255, 215/255, 0/255), 102 | (255/255, 215/255, 0/255), (255/255, 0/255, 255/255), (75/255, 0/255, 130/255), 103 | (0/255, 255/255, 127/255), (255/255, 69/255, 0/255), 104 | (255/255, 228/255, 181/255), (255/255, 20/255, 147/255), 105 | (154/255, 205/255, 
50/255), (139/255, 69/255, 19/255), 106 | (255/255, 215/255, 0/255), 107 | (0/255, 191/255, 255/255), (47/255, 79/255, 79/255), 108 | (188/255, 143/255, 143/255), (255/255, 99/255, 71/255), 109 | (205/255, 92/255, 92/255), 110 | (144/255, 238/255, 144/255), (30/255, 144/255, 255/255), 111 | (128/255, 128/255, 0/255), (107/255, 142/255, 35/255), 112 | (255/255, 127/255, 80/255) 113 | ] 114 | fig, ax = plt.subplots(1, figsize=(12, 9)) 115 | ax.imshow(image) 116 | 117 | # 获取图像尺寸 118 | img_width, img_height = image.size 119 | default_font_size = max(np.sqrt(img_width * img_height) // 90, 10) 120 | 121 | if draw_gt: 122 | # 绘制gt框 123 | boxes = targets["boxes"].cpu().numpy() 124 | labels = targets["labels"].cpu().numpy() 125 | for i, box in enumerate(boxes): 126 | x, y, w, h = box 127 | xmin, ymin, xmax, ymax = x, y, x + w, y + h 128 | # 需要减一是uodd以外的数据集 129 | if class_names != UODD_TEST: 130 | label = class_names[labels[i] - 1] 131 | label_index = labels[i] - 1 132 | else: 133 | label = class_names[labels[i]] 134 | label_index = labels[i] 135 | # rect = patches.Rectangle((x, y), w, h, linewidth=2, edgecolor='r', line_style="-", alpha=0.5) 136 | # ax.add_patch(rect) 137 | # ax.text(x, y - 10, f'{label}', color='r', fontsize=12, backgroundcolor="none") 138 | instance_area = w * h 139 | if instance_area < 1000 or h < 40: 140 | if ymax >= h - 5: 141 | text_pos = (xmax, ymin) 142 | else: 143 | text_pos = (xmin, ymax) 144 | 145 | height_ratio = h / np.sqrt(h * w) 146 | lighter_color = change_color_brightness(CUSTOM_COLORS[label_index % len(CUSTOM_COLORS)], brightness_factor=0.7) 147 | font_size = ( 148 | np.clip((height_ratio - 0.02) / 0.08 + 1, 1.2, 2) 149 | * 0.5 150 | * default_font_size 151 | ) 152 | rect = patches.Rectangle((x, y), w, h, linewidth=3, 153 | edgecolor=CUSTOM_COLORS[label_index % len(CUSTOM_COLORS)], facecolor='none') 154 | ax.add_patch(rect) 155 | draw_text(ax, f'{label}', (x, y), color=lighter_color, font_size=font_size) # 使用 draw_text 绘制文本 156 | 157 | if draw_pre: 158 | # 绘制预测框 159 | boxes = predictions["boxes"].cpu().numpy() 160 | scores = predictions["scores"].cpu().numpy() 161 | labels = predictions["labels"].cpu().numpy() 162 | for i, box in enumerate(boxes): 163 | score = scores[i] 164 | if score > 0.2: # 只可视化置信度大于0.5的预测 165 | xmin, ymin, xmax, ymax = box 166 | x, y, w, h = xmin, ymin, xmax - xmin, ymax - ymin 167 | # 需要减一是uodd以外的数据集 168 | if class_names != UODD_TEST: 169 | label = class_names[labels[i] - 1] 170 | label_index = labels[i] - 1 171 | else: 172 | label = class_names[labels[i]] 173 | label_index = labels[i] 174 | # rect = patches.Rectangle((x, y), w, h, linewidth=2, edgecolor='r', line_style="-", alpha=0.5) 175 | # ax.add_patch(rect) 176 | # ax.text(x, y - 10, f'{label}', color='r', fontsize=12, backgroundcolor="none") 177 | instance_area = w * h 178 | if instance_area < 1000 or h < 40: 179 | if ymax >= h - 5: 180 | text_pos = (xmax, ymin) 181 | else: 182 | text_pos = (xmin, ymax) 183 | 184 | height_ratio = h / np.sqrt(h * w) 185 | lighter_color = change_color_brightness(CUSTOM_COLORS[label_index % len(CUSTOM_COLORS)], brightness_factor=0.7) 186 | font_size = ( 187 | np.clip((height_ratio - 0.02) / 0.08 + 1, 1.2, 2) 188 | * 0.5 189 | * default_font_size 190 | ) 191 | rect = patches.Rectangle((x, y), w, h, linewidth=3, 192 | edgecolor=CUSTOM_COLORS[label_index % len(CUSTOM_COLORS)], facecolor='none') 193 | ax.add_patch(rect) 194 | draw_text(ax, f'{label}', (x, y), color=lighter_color, font_size=font_size) # 使用 draw_text 绘制文本 195 | 196 | 197 | plt.axis('off') 198 | 
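    # Editorial notes on the inputs this function assumes (as used by the code above):
    #   * targets["boxes"] are ground-truth boxes in absolute (x, y, w, h) pixels;
    #   * predictions["boxes"] are post-processed boxes in absolute (xmin, ymin, xmax, ymax),
    #     and only predictions with score > 0.2 are drawn;
    #   * label indices are shifted by -1 for every dataset except UODD (see the class-name
    #     lists at the top of this function).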
199 | # 获取 image_id 并构造输出文件路径 200 | image_id = targets["image_id"].item() 201 | output_path = os.path.join(output_dir, f'{image_id}.jpg') 202 | 203 | # 保存为 jpg 文件 204 | plt.savefig(output_path, bbox_inches='tight', pad_inches=0, format='jpg') 205 | plt.close(fig) 206 | print(f"Image saved to {output_path}") 207 | -------------------------------------------------------------------------------- /dinov2/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LONGXUANX/CDFormer_code/da25d484d16a637ff5364575841e4c58239ab212/dinov2/__init__.py -------------------------------------------------------------------------------- /dinov2/layers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .dino_head import DINOHead 8 | from .mlp import Mlp 9 | from .patch_embed import PatchEmbed 10 | from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused 11 | from .block import NestedTensorBlock 12 | from .attention import MemEffAttention 13 | -------------------------------------------------------------------------------- /dinov2/layers/attention.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # References: 8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py 10 | 11 | import logging 12 | 13 | from torch import Tensor 14 | from torch import nn 15 | 16 | 17 | logger = logging.getLogger("dinov2") 18 | 19 | 20 | try: 21 | from xformers.ops import memory_efficient_attention, unbind, fmha 22 | import torch 23 | 24 | XFORMERS_AVAILABLE = torch.cuda.is_available() 25 | except ImportError: 26 | logger.warning("xFormers not available") 27 | XFORMERS_AVAILABLE = False 28 | 29 | 30 | class Attention(nn.Module): 31 | def __init__( 32 | self, 33 | dim: int, 34 | num_heads: int = 8, 35 | qkv_bias: bool = False, 36 | proj_bias: bool = True, 37 | attn_drop: float = 0.0, 38 | proj_drop: float = 0.0, 39 | ) -> None: 40 | super().__init__() 41 | self.num_heads = num_heads 42 | head_dim = dim // num_heads 43 | self.scale = head_dim**-0.5 44 | 45 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 46 | self.attn_drop = nn.Dropout(attn_drop) 47 | self.proj = nn.Linear(dim, dim, bias=proj_bias) 48 | self.proj_drop = nn.Dropout(proj_drop) 49 | 50 | def forward(self, x: Tensor) -> Tensor: 51 | B, N, C = x.shape 52 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 53 | 54 | q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] 55 | attn = q @ k.transpose(-2, -1) 56 | 57 | attn = attn.softmax(dim=-1) 58 | attn = self.attn_drop(attn) 59 | 60 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 61 | x = self.proj(x) 62 | x = self.proj_drop(x) 63 | return x 64 | 65 | 66 | class MemEffAttention(Attention): 67 | def forward(self, x: Tensor, attn_bias=None) -> Tensor: 68 | if (not XFORMERS_AVAILABLE) or (not x.is_cuda): 69 | assert attn_bias is None, "xFormers is required for nested 
tensors usage" 70 | return super().forward(x) 71 | 72 | # from pudb.remote import set_trace; set_trace(term_size=(202, 47), port=12345) 73 | B, N, C = x.shape 74 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) 75 | 76 | q, k, v = unbind(qkv, 2) 77 | 78 | x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) #! different results when in vscode and bash 79 | x = x.reshape([B, N, C]) 80 | 81 | x = self.proj(x) 82 | x = self.proj_drop(x) 83 | return x 84 | -------------------------------------------------------------------------------- /dinov2/layers/block.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # References: 8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py 10 | 11 | import logging 12 | from typing import Callable, List, Any, Tuple, Dict 13 | 14 | import torch 15 | from torch import nn, Tensor 16 | 17 | from .attention import Attention, MemEffAttention 18 | from .drop_path import DropPath 19 | from .layer_scale import LayerScale 20 | from .mlp import Mlp 21 | 22 | 23 | logger = logging.getLogger("dinov2") 24 | 25 | 26 | try: 27 | from xformers.ops import fmha 28 | from xformers.ops import scaled_index_add, index_select_cat 29 | 30 | XFORMERS_AVAILABLE = torch.cuda.is_available() 31 | except ImportError: 32 | logger.warning("xFormers not available") 33 | XFORMERS_AVAILABLE = False 34 | 35 | 36 | class Block(nn.Module): 37 | def __init__( 38 | self, 39 | dim: int, 40 | num_heads: int, 41 | mlp_ratio: float = 4.0, 42 | qkv_bias: bool = False, 43 | proj_bias: bool = True, 44 | ffn_bias: bool = True, 45 | drop: float = 0.0, 46 | attn_drop: float = 0.0, 47 | init_values=None, 48 | drop_path: float = 0.0, 49 | act_layer: Callable[..., nn.Module] = nn.GELU, 50 | norm_layer: Callable[..., nn.Module] = nn.LayerNorm, 51 | attn_class: Callable[..., nn.Module] = Attention, 52 | ffn_layer: Callable[..., nn.Module] = Mlp, 53 | ) -> None: 54 | super().__init__() 55 | # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}") 56 | self.norm1 = norm_layer(dim) 57 | self.attn = attn_class( 58 | dim, 59 | num_heads=num_heads, 60 | qkv_bias=qkv_bias, 61 | proj_bias=proj_bias, 62 | attn_drop=attn_drop, 63 | proj_drop=drop, 64 | ) 65 | self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() 66 | self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 67 | 68 | self.norm2 = norm_layer(dim) 69 | mlp_hidden_dim = int(dim * mlp_ratio) 70 | self.mlp = ffn_layer( 71 | in_features=dim, 72 | hidden_features=mlp_hidden_dim, 73 | act_layer=act_layer, 74 | drop=drop, 75 | bias=ffn_bias, 76 | ) 77 | self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() 78 | self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 79 | 80 | self.sample_drop_ratio = drop_path 81 | 82 | def forward(self, x: Tensor) -> Tensor: 83 | def attn_residual_func(x: Tensor) -> Tensor: 84 | return self.ls1(self.attn(self.norm1(x))) 85 | 86 | def ffn_residual_func(x: Tensor) -> Tensor: 87 | return self.ls2(self.mlp(self.norm2(x))) 88 | 89 | if self.training and self.sample_drop_ratio > 0.1: 90 | # the overhead is compensated 
only for a drop path rate larger than 0.1 91 | x = drop_add_residual_stochastic_depth( 92 | x, 93 | residual_func=attn_residual_func, 94 | sample_drop_ratio=self.sample_drop_ratio, 95 | ) 96 | x = drop_add_residual_stochastic_depth( 97 | x, 98 | residual_func=ffn_residual_func, 99 | sample_drop_ratio=self.sample_drop_ratio, 100 | ) 101 | elif self.training and self.sample_drop_ratio > 0.0: 102 | x = x + self.drop_path1(attn_residual_func(x)) 103 | x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2 104 | else: 105 | # from pudb.remote import set_trace; set_trace(term_size=(202, 47), port=12345) 106 | x = x + attn_residual_func(x) 107 | x = x + ffn_residual_func(x) 108 | return x 109 | 110 | 111 | def drop_add_residual_stochastic_depth( 112 | x: Tensor, 113 | residual_func: Callable[[Tensor], Tensor], 114 | sample_drop_ratio: float = 0.0, 115 | ) -> Tensor: 116 | # 1) extract subset using permutation 117 | b, n, d = x.shape 118 | sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) 119 | brange = (torch.randperm(b, device=x.device))[:sample_subset_size] 120 | x_subset = x[brange] 121 | 122 | # 2) apply residual_func to get residual 123 | residual = residual_func(x_subset) 124 | 125 | x_flat = x.flatten(1) 126 | residual = residual.flatten(1) 127 | 128 | residual_scale_factor = b / sample_subset_size 129 | 130 | # 3) add the residual 131 | x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) 132 | return x_plus_residual.view_as(x) 133 | 134 | 135 | def get_branges_scales(x, sample_drop_ratio=0.0): 136 | b, n, d = x.shape 137 | sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) 138 | brange = (torch.randperm(b, device=x.device))[:sample_subset_size] 139 | residual_scale_factor = b / sample_subset_size 140 | return brange, residual_scale_factor 141 | 142 | 143 | def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None): 144 | if scaling_vector is None: 145 | x_flat = x.flatten(1) 146 | residual = residual.flatten(1) 147 | x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) 148 | else: 149 | x_plus_residual = scaled_index_add( 150 | x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor 151 | ) 152 | return x_plus_residual 153 | 154 | 155 | attn_bias_cache: Dict[Tuple, Any] = {} 156 | 157 | 158 | def get_attn_bias_and_cat(x_list, branges=None): 159 | """ 160 | this will perform the index select, cat the tensors, and provide the attn_bias from cache 161 | """ 162 | batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list] 163 | all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list)) 164 | if all_shapes not in attn_bias_cache.keys(): 165 | seqlens = [] 166 | for b, x in zip(batch_sizes, x_list): 167 | for _ in range(b): 168 | seqlens.append(x.shape[1]) 169 | attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens) 170 | attn_bias._batch_sizes = batch_sizes 171 | attn_bias_cache[all_shapes] = attn_bias 172 | 173 | if branges is not None: 174 | cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1]) 175 | else: 176 | tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list) 177 | cat_tensors = torch.cat(tensors_bs1, dim=1) 178 | 179 | return attn_bias_cache[all_shapes], cat_tensors 180 | 181 | 182 | def drop_add_residual_stochastic_depth_list( 183 | x_list: List[Tensor], 184 
| residual_func: Callable[[Tensor, Any], Tensor], 185 | sample_drop_ratio: float = 0.0, 186 | scaling_vector=None, 187 | ) -> Tensor: 188 | # 1) generate random set of indices for dropping samples in the batch 189 | branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list] 190 | branges = [s[0] for s in branges_scales] 191 | residual_scale_factors = [s[1] for s in branges_scales] 192 | 193 | # 2) get attention bias and index+concat the tensors 194 | attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges) 195 | 196 | # 3) apply residual_func to get residual, and split the result 197 | residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore 198 | 199 | outputs = [] 200 | for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors): 201 | outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x)) 202 | return outputs 203 | 204 | 205 | class NestedTensorBlock(Block): 206 | def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]: 207 | """ 208 | x_list contains a list of tensors to nest together and run 209 | """ 210 | assert isinstance(self.attn, MemEffAttention) 211 | 212 | if self.training and self.sample_drop_ratio > 0.0: 213 | 214 | def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: 215 | return self.attn(self.norm1(x), attn_bias=attn_bias) 216 | 217 | def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: 218 | return self.mlp(self.norm2(x)) 219 | 220 | x_list = drop_add_residual_stochastic_depth_list( 221 | x_list, 222 | residual_func=attn_residual_func, 223 | sample_drop_ratio=self.sample_drop_ratio, 224 | scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None, 225 | ) 226 | x_list = drop_add_residual_stochastic_depth_list( 227 | x_list, 228 | residual_func=ffn_residual_func, 229 | sample_drop_ratio=self.sample_drop_ratio, 230 | scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None, 231 | ) 232 | return x_list 233 | else: 234 | 235 | def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: 236 | return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias)) 237 | 238 | def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: 239 | return self.ls2(self.mlp(self.norm2(x))) 240 | 241 | attn_bias, x = get_attn_bias_and_cat(x_list) 242 | x = x + attn_residual_func(x, attn_bias=attn_bias) 243 | x = x + ffn_residual_func(x) 244 | return attn_bias.split(x) 245 | 246 | def forward(self, x_or_x_list): 247 | if isinstance(x_or_x_list, Tensor): 248 | return super().forward(x_or_x_list) 249 | elif isinstance(x_or_x_list, list): 250 | assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage" 251 | return self.forward_nested(x_or_x_list) 252 | else: 253 | raise AssertionError 254 | -------------------------------------------------------------------------------- /dinov2/layers/dino_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
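# Editorial sketch (not from the repo): a quick shape check for the transformer Block defined
# in dinov2/layers/block.py above -- pre-norm attention and MLP, each added back through a
# residual connection, so the (batch, tokens, dim) shape is preserved.
import torch
from dinov2.layers.block import Block

blk = Block(dim=384, num_heads=6)      # defaults: plain Attention + Mlp, no drop-path
x = torch.randn(2, 257, 384)           # e.g. 256 patch tokens + 1 CLS token
assert blk(x).shape == x.shape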
6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch.nn.init import trunc_normal_ 10 | from torch.nn.utils import weight_norm 11 | 12 | 13 | class DINOHead(nn.Module): 14 | def __init__( 15 | self, 16 | in_dim, 17 | out_dim, 18 | use_bn=False, 19 | nlayers=3, 20 | hidden_dim=2048, 21 | bottleneck_dim=256, 22 | mlp_bias=True, 23 | ): 24 | super().__init__() 25 | nlayers = max(nlayers, 1) 26 | self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias) 27 | self.apply(self._init_weights) 28 | self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False)) 29 | self.last_layer.weight_g.data.fill_(1) 30 | 31 | def _init_weights(self, m): 32 | if isinstance(m, nn.Linear): 33 | trunc_normal_(m.weight, std=0.02) 34 | if isinstance(m, nn.Linear) and m.bias is not None: 35 | nn.init.constant_(m.bias, 0) 36 | 37 | def forward(self, x): 38 | x = self.mlp(x) 39 | eps = 1e-6 if x.dtype == torch.float16 else 1e-12 40 | x = nn.functional.normalize(x, dim=-1, p=2, eps=eps) 41 | x = self.last_layer(x) 42 | return x 43 | 44 | 45 | def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True): 46 | if nlayers == 1: 47 | return nn.Linear(in_dim, bottleneck_dim, bias=bias) 48 | else: 49 | layers = [nn.Linear(in_dim, hidden_dim, bias=bias)] 50 | if use_bn: 51 | layers.append(nn.BatchNorm1d(hidden_dim)) 52 | layers.append(nn.GELU()) 53 | for _ in range(nlayers - 2): 54 | layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias)) 55 | if use_bn: 56 | layers.append(nn.BatchNorm1d(hidden_dim)) 57 | layers.append(nn.GELU()) 58 | layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias)) 59 | return nn.Sequential(*layers) 60 | -------------------------------------------------------------------------------- /dinov2/layers/drop_path.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # References: 8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py 10 | 11 | 12 | from torch import nn 13 | 14 | 15 | def drop_path(x, drop_prob: float = 0.0, training: bool = False): 16 | if drop_prob == 0.0 or not training: 17 | return x 18 | keep_prob = 1 - drop_prob 19 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 20 | random_tensor = x.new_empty(shape).bernoulli_(keep_prob) 21 | if keep_prob > 0.0: 22 | random_tensor.div_(keep_prob) 23 | output = x * random_tensor 24 | return output 25 | 26 | 27 | class DropPath(nn.Module): 28 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" 29 | 30 | def __init__(self, drop_prob=None): 31 | super(DropPath, self).__init__() 32 | self.drop_prob = drop_prob 33 | 34 | def forward(self, x): 35 | return drop_path(x, self.drop_prob, self.training) 36 | -------------------------------------------------------------------------------- /dinov2/layers/layer_scale.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
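# Editorial sketch: the behaviour of drop_path() from dinov2/layers/drop_path.py above.
# During training, whole samples are zeroed with probability drop_prob and the survivors are
# rescaled by 1/keep_prob, so the expected value of the output matches the input.
import torch
from dinov2.layers.drop_path import drop_path

torch.manual_seed(0)
x = torch.ones(8, 4, 16)                           # (batch, tokens, dim)
y = drop_path(x, drop_prob=0.25, training=True)
print(y[:, 0, 0])                                  # each entry is either 0.0 or 1/0.75 (about 1.3333)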
3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 8 | 9 | from typing import Union 10 | 11 | import torch 12 | from torch import Tensor 13 | from torch import nn 14 | 15 | 16 | class LayerScale(nn.Module): 17 | def __init__( 18 | self, 19 | dim: int, 20 | init_values: Union[float, Tensor] = 1e-5, 21 | inplace: bool = False, 22 | ) -> None: 23 | super().__init__() 24 | self.inplace = inplace 25 | self.gamma = nn.Parameter(init_values * torch.ones(dim)) 26 | 27 | def forward(self, x: Tensor) -> Tensor: 28 | return x.mul_(self.gamma) if self.inplace else x * self.gamma 29 | -------------------------------------------------------------------------------- /dinov2/layers/mlp.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # References: 8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py 10 | 11 | 12 | from typing import Callable, Optional 13 | 14 | from torch import Tensor, nn 15 | 16 | 17 | class Mlp(nn.Module): 18 | def __init__( 19 | self, 20 | in_features: int, 21 | hidden_features: Optional[int] = None, 22 | out_features: Optional[int] = None, 23 | act_layer: Callable[..., nn.Module] = nn.GELU, 24 | drop: float = 0.0, 25 | bias: bool = True, 26 | ) -> None: 27 | super().__init__() 28 | out_features = out_features or in_features 29 | hidden_features = hidden_features or in_features 30 | self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) 31 | self.act = act_layer() 32 | self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) 33 | self.drop = nn.Dropout(drop) 34 | 35 | def forward(self, x: Tensor) -> Tensor: 36 | x = self.fc1(x) 37 | x = self.act(x) 38 | x = self.drop(x) 39 | x = self.fc2(x) 40 | x = self.drop(x) 41 | return x 42 | -------------------------------------------------------------------------------- /dinov2/layers/patch_embed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # References: 8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py 10 | 11 | from typing import Callable, Optional, Tuple, Union 12 | 13 | from torch import Tensor 14 | import torch.nn as nn 15 | 16 | 17 | def make_2tuple(x): 18 | if isinstance(x, tuple): 19 | assert len(x) == 2 20 | return x 21 | 22 | assert isinstance(x, int) 23 | return (x, x) 24 | 25 | 26 | class PatchEmbed(nn.Module): 27 | """ 28 | 2D image to patch embedding: (B,C,H,W) -> (B,N,D) 29 | 30 | Args: 31 | img_size: Image size. 32 | patch_size: Patch token size. 33 | in_chans: Number of input image channels. 34 | embed_dim: Number of linear projection output channels. 35 | norm_layer: Normalization layer. 
36 | """ 37 | 38 | def __init__( 39 | self, 40 | img_size: Union[int, Tuple[int, int]] = 224, 41 | patch_size: Union[int, Tuple[int, int]] = 16, 42 | in_chans: int = 3, 43 | embed_dim: int = 768, 44 | norm_layer: Optional[Callable] = None, 45 | flatten_embedding: bool = True, 46 | ) -> None: 47 | super().__init__() 48 | 49 | image_HW = make_2tuple(img_size) 50 | patch_HW = make_2tuple(patch_size) 51 | patch_grid_size = ( 52 | image_HW[0] // patch_HW[0], 53 | image_HW[1] // patch_HW[1], 54 | ) 55 | 56 | self.img_size = image_HW 57 | self.patch_size = patch_HW 58 | self.patches_resolution = patch_grid_size 59 | self.num_patches = patch_grid_size[0] * patch_grid_size[1] 60 | 61 | self.in_chans = in_chans 62 | self.embed_dim = embed_dim 63 | 64 | self.flatten_embedding = flatten_embedding 65 | 66 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) 67 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 68 | 69 | def forward(self, x: Tensor) -> Tensor: 70 | _, _, H, W = x.shape 71 | patch_H, patch_W = self.patch_size 72 | 73 | assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}" 74 | assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}" 75 | 76 | x = self.proj(x) # B C H W 77 | H, W = x.size(2), x.size(3) 78 | x = x.flatten(2).transpose(1, 2) # B HW C 79 | x = self.norm(x) 80 | if not self.flatten_embedding: 81 | x = x.reshape(-1, H, W, self.embed_dim) # B H W C 82 | return x 83 | 84 | def flops(self) -> float: 85 | Ho, Wo = self.patches_resolution 86 | flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) 87 | if self.norm is not None: 88 | flops += Ho * Wo * self.embed_dim 89 | return flops 90 | -------------------------------------------------------------------------------- /dinov2/layers/swiglu_ffn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from typing import Callable, Optional 8 | 9 | from torch import Tensor, nn 10 | import torch.nn.functional as F 11 | 12 | 13 | class SwiGLUFFN(nn.Module): 14 | def __init__( 15 | self, 16 | in_features: int, 17 | hidden_features: Optional[int] = None, 18 | out_features: Optional[int] = None, 19 | act_layer: Callable[..., nn.Module] = None, 20 | drop: float = 0.0, 21 | bias: bool = True, 22 | ) -> None: 23 | super().__init__() 24 | out_features = out_features or in_features 25 | hidden_features = hidden_features or in_features 26 | self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) 27 | self.w3 = nn.Linear(hidden_features, out_features, bias=bias) 28 | 29 | def forward(self, x: Tensor) -> Tensor: 30 | x12 = self.w12(x) 31 | x1, x2 = x12.chunk(2, dim=-1) 32 | hidden = F.silu(x1) * x2 33 | return self.w3(hidden) 34 | 35 | 36 | try: 37 | from xformers.ops import SwiGLU 38 | import torch 39 | 40 | XFORMERS_AVAILABLE = torch.cuda.is_available() 41 | except ImportError: 42 | SwiGLU = SwiGLUFFN 43 | XFORMERS_AVAILABLE = False 44 | 45 | 46 | class SwiGLUFFNFused(SwiGLU): 47 | def __init__( 48 | self, 49 | in_features: int, 50 | hidden_features: Optional[int] = None, 51 | out_features: Optional[int] = None, 52 | act_layer: Callable[..., nn.Module] = None, 53 | drop: float = 0.0, 54 | bias: bool = True, 55 | ) -> None: 56 | out_features = out_features or in_features 57 | hidden_features = hidden_features or in_features 58 | hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 59 | super().__init__( 60 | in_features=in_features, 61 | hidden_features=hidden_features, 62 | out_features=out_features, 63 | bias=bias, 64 | ) 65 | -------------------------------------------------------------------------------- /dinov2/pad.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import detectron2.data.transforms as T 4 | from fvcore.transforms.transform import ( 5 | PadTransform, 6 | Transform, 7 | TransformList 8 | ) 9 | 10 | from detectron2.data.transforms.augmentation import Augmentation 11 | 12 | class SizeDivisibilityPad(Augmentation): 13 | 14 | @torch.jit.unused 15 | def __init__(self, divide_by=14): 16 | """ 17 | Args: 18 | prob (float): probability of flip. 
19 | horizontal (boolean): whether to apply horizontal flipping 20 | vertical (boolean): whether to apply vertical flipping 21 | """ 22 | super().__init__() 23 | self.divide_by = divide_by 24 | 25 | @torch.jit.unused 26 | def get_transform(self, image): 27 | h, w = image.shape[:2] 28 | py = int(math.ceil(h / self.divide_by)) * self.divide_by - h 29 | px = int(math.ceil(w / self.divide_by)) * self.divide_by - w 30 | 31 | py0 = py // 2 32 | py1 = py - py0 33 | 34 | px0 = px // 2 35 | px1 = px - px0 36 | 37 | return PadTransform(px0, py0, px1, py1) -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .CDFormer import build 2 | 3 | 4 | def build_model(args): 5 | return build(args) 6 | -------------------------------------------------------------------------------- /models/attention.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class ScaledDotProductAttention(nn.Module): 8 | """ Scaled Dot-Product Attention """ 9 | def __init__(self, temperature, attn_dropout=0.0): 10 | super().__init__() 11 | self.temperature = temperature 12 | self.dropout = nn.Dropout(attn_dropout) 13 | self.softmax = nn.Softmax(dim=2) 14 | 15 | def forward(self, q, k, v): 16 | attn = torch.bmm(q, k.transpose(1, 2)) 17 | attn = attn / self.temperature 18 | log_attn = F.log_softmax(attn, 2) 19 | attn = self.softmax(attn) 20 | attn = self.dropout(attn) 21 | output = torch.bmm(attn, v) 22 | return output, attn, log_attn 23 | 24 | 25 | class SingleHeadSiameseAttention(nn.Module): 26 | # 在计算support特征编码时,不同类别之间的信息会有交互,因为此时class类不再是batch维度 27 | # 而计算query和support时,不同类别之间的信息会有交互,因为此时class类不再是batch维度 28 | # 计算特征编码(这里是下一层的输入,不是category向量)时, 29 | # (类别数为batch_size*441*256) @ (类别数为batch_size*256*5) => (441*5)@(5*256)=> out(类别数为batch_size,441,256)送入下一层 30 | # 计算query和support时,query的batch为batch_size 31 | """ Single-Head Attention Module. Weights for Q and K are shared in a Siamese manner. 
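# Editorial sketch (not from the repo): calling the module defined below with the shapes its
# in-code comments describe -- 2 real support classes in a 5-slot episode, the remaining slots
# marked as background with the placeholder id 100. The concrete class ids are arbitrary.
import torch
from models.attention import SingleHeadSiameseAttention

attn = SingleHeadSiameseAttention(d_model=256)
q   = torch.randn(2, 441, 256)          # query tokens (e.g. a 21x21 feature map), batch of 2
k   = torch.randn(2, 2, 256)            # one prototype per real support class
v   = torch.randn(2, 2, 256)
tsp = torch.randn(2, 5, 256)            # the per-slot "tsp" stream, one vector per episode slot
supp_class_id = [3, 100, 7, 100, 100]   # 100 marks a background slot
out, tsp_out = attn(q, k, v, tsp, supp_class_id)
assert out.shape == (2, 441, 256) and tsp_out.shape == (2, 441, 256)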
No proj weights for V.""" 32 | def __init__(self, d_model): 33 | super().__init__() 34 | self.n_head = 1 35 | self.d_model = d_model 36 | self.w_qk = nn.Linear(self.d_model, self.n_head * self.d_model, bias=False) 37 | self.attention = ScaledDotProductAttention(temperature=np.power(self.d_model, 0.5)) 38 | nn.init.normal_(self.w_qk.weight, mean=0, std=np.sqrt(2.0 / (self.d_model + self.d_model))) 39 | 40 | self.dummy = nn.Parameter(torch.Tensor(1, self.d_model)) 41 | nn.init.normal_(self.dummy) 42 | 43 | self.linear1 = nn.Sequential(nn.Linear(self.d_model, self.d_model // 2), nn.ReLU(inplace=True)) 44 | self.linear2 = nn.Sequential(nn.Linear(self.d_model, self.d_model // 2), nn.ReLU(inplace=True)) 45 | self.linear3 = nn.Linear(self.d_model * 2, self.d_model) 46 | 47 | def forward(self, q, k, v, tsp, supp_class_id): 48 | sz_b, len_q, _ = q.size() 49 | sz_b, len_k, _ = k.size() 50 | sz_b, len_v, _ = v.size() 51 | 52 | # tsp为[batch_size,eposide_size, 256] 53 | sz_b, len_tsp, _ = tsp.size() 54 | 55 | residual = q 56 | q = self.w_qk(q).view(sz_b, len_q, self.n_head, self.d_model) 57 | k = self.w_qk(k).view(sz_b, len_k, self.n_head, self.d_model) 58 | v = v.view(sz_b, len_v, self.n_head, self.d_model) 59 | 60 | # tsp = tsp.view(sz_b, len_v, self.n_head, self.d_model) 61 | # 这里要注意,batch_size要与qkv一致,而len_tsp则是自己的 62 | # 因为我们引入了背景信息,而tsp仍然保持为eposide(5)的设置,因为在计算交叉注意力时仍有五个tsp 63 | # 比如在支持类提取时,有两个正常类,三个背景类,则batch=2,此时qkv输入为[2,2,256],tsp输入为[2,5,256] 64 | # qkv在上面view成了[2,2,1,256],所以这里tsp也要变成[2,5,1,256]才行 65 | # k和v在下面均会变为[2,5,1,256],则tsp也要变成[2,5,1,256]才行 66 | tsp = tsp.view(sz_b, len_tsp, self.n_head, self.d_model) 67 | 68 | # 因为类别提取和正向传播时不同类别之间均要交互,所以拓展batch_size维度 69 | dummy = self.dummy.reshape(1, 1, 1, self.d_model).expand(sz_b, -1, self.n_head, -1) 70 | dummy_v = torch.zeros(sz_b, 1, self.n_head, self.d_model, device=v.device) 71 | 72 | # 创建迭代器 73 | k_iter = iter(k.split(1, dim=1)) # 将 k 按 class 维度拆分成多个形状为 [BZ, 1, n_head, 256] 的张量 74 | list_k = [next(k_iter) if x != 100 else dummy for x in supp_class_id] 75 | # 假如batch为2,则k_cat为[2,5,1,256](support提取和正向传播均适用) 76 | k_cat = torch.cat(list_k, dim=1) 77 | v_iter = iter(v.split(1, dim=1)) # 将 k 按 class 维度拆分成多个形状为 [BZ, 1, n_head, 256] 的张量 78 | list_v = [next(v_iter) if x != 100 else dummy_v for x in supp_class_id] 79 | v_cat = torch.cat(list_v, dim=1) 80 | 81 | # k = torch.cat([k, dummy], dim=1) 82 | # v = torch.cat([v, dummy_v], dim=1) 83 | # tsp = torch.cat([tsp, dummy_v], dim=1) 84 | k = torch.cat([k_cat, dummy], dim=1) 85 | v = torch.cat([v_cat, dummy_v], dim=1) 86 | tsp = torch.cat([tsp, dummy_v], dim=1) 87 | 88 | q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, self.d_model) # (n_head * b) x lq x d_model 89 | # 因为batch不为定值,故要用len_tsp(与tsp一致) 90 | k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_tsp + 1, self.d_model) # (n_head * b) x lk x d_model 91 | v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_tsp + 1, self.d_model) # (n_head * b) x lv x d_model 92 | tsp = tsp.permute(2, 0, 1, 3).contiguous().view(-1, len_tsp + 1, self.d_model) # (n_head * b) x lv x d_model 93 | 94 | # k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k + 1, self.d_model) # (n_head * b) x lk x d_model 95 | # v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v + 1, self.d_model) # (n_head * b) x lv x d_model 96 | # tsp = tsp.permute(2, 0, 1, 3).contiguous().view(-1, len_v + 1, self.d_model) # (n_head * b) x lv x d_model 97 | 98 | output, attn, log_attn = self.attention(q, k, v) 99 | tsp, _, _ = self.attention(q, k, tsp) 100 | 101 | output = 
output.view(self.n_head, sz_b, len_q, self.d_model) 102 | output = output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1) # b x lq x (n_head * d_model) 103 | 104 | tsp = tsp.view(self.n_head, sz_b, len_q, self.d_model) 105 | tsp = tsp.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1) # b x lq x (n_head * d_model) 106 | 107 | output1 = self.linear1(output * residual) 108 | output2 = self.linear2(residual - output) 109 | output = self.linear3( 110 | torch.cat([output1, output2, residual], dim=2) 111 | ) 112 | 113 | return output, tsp 114 | 115 | # def forward(self, q, k, v, tsp, supp_class_id): 116 | # sz_b, len_q, _ = q.size() 117 | # sz_b, len_k, _ = k.size() 118 | # sz_b, len_v, _ = v.size() 119 | 120 | # # tsp为[batch_size,eposide_size, 256] 121 | # sz_b, len_tsp, _ = tsp.size() 122 | 123 | # residual = q 124 | # q = self.w_qk(q).view(sz_b, len_q, self.n_head, self.d_model) 125 | # k = self.w_qk(k).view(sz_b, len_k, self.n_head, self.d_model) 126 | # v = v.view(sz_b, len_v, self.n_head, self.d_model) 127 | 128 | # # tsp = tsp.view(sz_b, len_v, self.n_head, self.d_model) 129 | # # 这里要注意,batch_size要与qkv一致,而len_tsp则是自己的 130 | # # 因为我们引入了背景信息,而tsp仍然保持为eposide(5)的设置,因为在计算交叉注意力时仍有五个tsp 131 | # # 比如在支持类提取时,有两个正常类,三个背景类,则batch=2,此时qkv输入为[2,2,256],tsp输入为[2,5,256] 132 | # # qkv在上面view成了[2,2,1,256],所以这里tsp也要变成[2,5,1,256]才行 133 | # # k和v在下面均会变为[2,5,1,256],则tsp也要变成[2,5,1,256]才行 134 | # tsp = tsp.view(sz_b, len_tsp, self.n_head, self.d_model) 135 | 136 | # # 因为类别提取和正向传播时不同类别之间均要交互,所以拓展batch_size维度 137 | # dummy = self.dummy.reshape(1, 1, 1, self.d_model).expand(sz_b, -1, self.n_head, -1) 138 | # dummy_v = torch.zeros(sz_b, 1, self.n_head, self.d_model, device=v.device) 139 | 140 | # # 创建迭代器 141 | # k_iter = iter(k.split(1, dim=1)) # 将 k 按 class 维度拆分成多个形状为 [BZ, 1, n_head, 256] 的张量 142 | # list_k = [next(k_iter) if x != 100 else dummy for x in supp_class_id] 143 | # # 假如batch为2,则k_cat为[2,5,1,256](support提取和正向传播均适用) 144 | # k_cat = torch.cat(list_k, dim=1) 145 | # v_iter = iter(v.split(1, dim=1)) # 将 k 按 class 维度拆分成多个形状为 [BZ, 1, n_head, 256] 的张量 146 | # list_v = [next(v_iter) if x != 100 else dummy_v for x in supp_class_id] 147 | # v_cat = torch.cat(list_v, dim=1) 148 | 149 | # # k = torch.cat([k, dummy], dim=1) 150 | # # v = torch.cat([v, dummy_v], dim=1) 151 | # # tsp = torch.cat([tsp, dummy_v], dim=1) 152 | # # k = torch.cat([k_cat, dummy], dim=1) 153 | # # v = torch.cat([v_cat, dummy_v], dim=1) 154 | # # tsp = torch.cat([tsp, dummy_v], dim=1) 155 | # k = k_cat 156 | # v = v_cat 157 | 158 | # q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, self.d_model) # (n_head * b) x lq x d_model 159 | # # 因为batch不为定值,故要用len_tsp(与tsp一致) 160 | # k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_tsp, self.d_model) # (n_head * b) x lk x d_model 161 | # v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_tsp, self.d_model) # (n_head * b) x lv x d_model 162 | # tsp = tsp.permute(2, 0, 1, 3).contiguous().view(-1, len_tsp, self.d_model) # (n_head * b) x lv x d_model 163 | 164 | # # k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k + 1, self.d_model) # (n_head * b) x lk x d_model 165 | # # v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v + 1, self.d_model) # (n_head * b) x lv x d_model 166 | # # tsp = tsp.permute(2, 0, 1, 3).contiguous().view(-1, len_v + 1, self.d_model) # (n_head * b) x lv x d_model 167 | 168 | # output, attn, log_attn = self.attention(q, k, v) 169 | # tsp, _, _ = self.attention(q, k, tsp) 170 | 171 | # output = output.view(self.n_head, sz_b, len_q, self.d_model) 172 | # 
output = output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1) # b x lq x (n_head * d_model) 173 | 174 | # tsp = tsp.view(self.n_head, sz_b, len_q, self.d_model) 175 | # tsp = tsp.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1) # b x lq x (n_head * d_model) 176 | 177 | # output1 = self.linear1(output * residual) 178 | # output2 = self.linear2(residual - output) 179 | # output = self.linear3( 180 | # torch.cat([output1, output2, residual], dim=2) 181 | # ) 182 | 183 | # return output, tsp 184 | -------------------------------------------------------------------------------- /models/backbone.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torchvision 7 | from torchvision.models._utils import IntermediateLayerGetter 8 | 9 | from util.misc import NestedTensor, is_main_process 10 | 11 | from .position_encoding import build_position_encoding 12 | 13 | 14 | class FrozenBatchNorm2d(torch.nn.Module): 15 | """ 16 | BatchNorm2d where the batch statistics and the affine parameters are fixed. 17 | 18 | Copy-paste from torchvision.misc.ops with added eps before rqsrt, 19 | without which any other models than torchvision.models.resnet[18,34,50,101] 20 | produce nans. 21 | """ 22 | def __init__(self, n, eps=1e-5): 23 | super(FrozenBatchNorm2d, self).__init__() 24 | self.register_buffer("weight", torch.ones(n)) 25 | self.register_buffer("bias", torch.zeros(n)) 26 | self.register_buffer("running_mean", torch.zeros(n)) 27 | self.register_buffer("running_var", torch.ones(n)) 28 | self.eps = eps 29 | 30 | def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, 31 | missing_keys, unexpected_keys, error_msgs): 32 | num_batches_tracked_key = prefix + 'num_batches_tracked' 33 | if num_batches_tracked_key in state_dict: 34 | del state_dict[num_batches_tracked_key] 35 | 36 | super(FrozenBatchNorm2d, self)._load_from_state_dict( 37 | state_dict, prefix, local_metadata, strict, 38 | missing_keys, unexpected_keys, error_msgs) 39 | 40 | def forward(self, x): 41 | # move reshapes to the beginning to make it fuser-friendly 42 | w = self.weight.reshape(1, -1, 1, 1) 43 | b = self.bias.reshape(1, -1, 1, 1) 44 | rv = self.running_var.reshape(1, -1, 1, 1) 45 | rm = self.running_mean.reshape(1, -1, 1, 1) 46 | eps = self.eps 47 | scale = w * (rv + eps).rsqrt() 48 | bias = b - rm * scale 49 | return x * scale + bias 50 | 51 | 52 | class BackboneBase(nn.Module): 53 | def __init__(self, backbone: nn.Module, train_backbone: bool, return_interm_layers: bool, args): 54 | super().__init__() 55 | self.args = args 56 | self.backbone = backbone 57 | 58 | # # Settings for freezing backbone 59 | # assert 0 <= args.freeze_backbone_at_layer <= 4 60 | # for name, parameter in backbone.named_parameters(): parameter.requires_grad_(False) # First freeze all 61 | # if train_backbone: 62 | # if args.freeze_backbone_at_layer == 0: 63 | # for name, parameter in backbone.named_parameters(): 64 | # if 'layer1' in name or 'layer2' in name or 'layer3' in name or 'layer4' in name: 65 | # parameter.requires_grad_(True) 66 | # elif args.freeze_backbone_at_layer == 1: 67 | # for name, parameter in backbone.named_parameters(): 68 | # if 'layer2' in name or 'layer3' in name or 'layer4' in name: 69 | # parameter.requires_grad_(True) 70 | # elif args.freeze_backbone_at_layer == 2: 71 | # for name, parameter in backbone.named_parameters(): 72 | # if 'layer3' in name 
or 'layer4' in name: 73 | # parameter.requires_grad_(True) 74 | # elif args.freeze_backbone_at_layer == 3: 75 | # for name, parameter in backbone.named_parameters(): 76 | # if 'layer4' in name: 77 | # parameter.requires_grad_(True) 78 | # elif args.freeze_backbone_at_layer == 4: 79 | # pass 80 | # else: 81 | # raise RuntimeError 82 | 83 | if return_interm_layers: 84 | return_layers = {"layer2": "0", "layer3": "1", "layer4": "2"} 85 | self.strides = [8, 16, 32] 86 | self.num_channels = [512, 1024, 2048] 87 | else: 88 | return_layers = {'layer4': "0"} 89 | self.strides = [32] 90 | self.num_channels = [2048] 91 | self.body = IntermediateLayerGetter(backbone, return_layers=return_layers) 92 | 93 | def support_encoding_net(self, x, return_interm_layers=False): 94 | out: Dict[str, NestedTensor] = {} 95 | m = x.mask 96 | # x = self.meta_conv(x.tensors) 97 | x = self.backbone.conv1(x.tensors) 98 | x = self.backbone.bn1(x) 99 | x = self.backbone.relu(x) 100 | x = self.backbone.maxpool(x) 101 | x = self.backbone.layer1(x) 102 | x = self.backbone.layer2(x) 103 | if return_interm_layers: 104 | mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0] 105 | out['0'] = NestedTensor(x, mask) 106 | 107 | x = self.backbone.layer3(x) 108 | if return_interm_layers: 109 | mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0] 110 | out['1'] = NestedTensor(x, mask) 111 | 112 | x = self.backbone.layer4(x) 113 | if return_interm_layers: 114 | mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0] 115 | out['2'] = NestedTensor(x, mask) 116 | 117 | if return_interm_layers: 118 | return out 119 | else: 120 | mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0] 121 | out['0'] = NestedTensor(x, mask) 122 | return out 123 | 124 | def forward(self, tensor_list: NestedTensor): 125 | xs = self.body(tensor_list.tensors) 126 | out: Dict[str, NestedTensor] = {} 127 | for name, x in xs.items(): 128 | m = tensor_list.mask 129 | assert m is not None 130 | mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0] 131 | out[name] = NestedTensor(x, mask) 132 | return out 133 | 134 | 135 | class Backbone(BackboneBase): 136 | """ResNet backbone with frozen BatchNorm.""" 137 | def __init__(self, 138 | name: str, 139 | train_backbone: bool, 140 | return_interm_layers: bool, 141 | args): 142 | self.args = args 143 | dilation = args.dilation 144 | norm_layer = FrozenBatchNorm2d 145 | backbone = getattr(torchvision.models, name)( 146 | replace_stride_with_dilation=[False, False, dilation], 147 | pretrained=is_main_process(), norm_layer=norm_layer) 148 | # for param in backbone.parameters(): 149 | # param.requires_grad = False 150 | assert name not in ('resnet18', 'resnet34'), "number of channels are hard coded, cannot use res18 & res34." 
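# ResNet-18/34 are built from BasicBlocks, so layer2-4 output 128/256/512 channels rather than the 512/1024/2048 hard-coded in BackboneBase.num_channels, hence the assertion above.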
151 | super().__init__(backbone, train_backbone, return_interm_layers, args) 152 | if dilation: 153 | self.strides[-1] = self.strides[-1] // 2 154 | 155 | 156 | class Joiner(nn.Sequential): 157 | def __init__(self, backbone, position_embedding): 158 | super().__init__(backbone, position_embedding) 159 | self.strides = backbone.strides 160 | self.num_channels = backbone.num_channels 161 | 162 | def forward(self, tensor_list: NestedTensor): 163 | xs = self[0](tensor_list) 164 | out: List[NestedTensor] = [] 165 | pos = [] 166 | for name, x in sorted(xs.items()): 167 | out.append(x) 168 | 169 | # position encoding 170 | for x in out: 171 | pos.append(self[1](x).to(x.tensors.dtype)) 172 | 173 | return out, pos 174 | 175 | def forward_supp_branch(self, tensor_list: NestedTensor, return_interm_layers=False): 176 | xs = self[0].support_encoding_net(tensor_list, return_interm_layers=return_interm_layers) 177 | out: List[NestedTensor] = [] 178 | pos = [] 179 | for name, x in sorted(xs.items()): 180 | out.append(x) 181 | 182 | # position encoding 183 | for x in out: 184 | pos.append(self[1](x).to(x.tensors.dtype)) 185 | 186 | return out, pos 187 | 188 | 189 | def build_backbone(args): 190 | position_embedding = build_position_encoding(args) 191 | train_backbone = args.lr_backbone > 0 192 | return_interm_layers = (args.num_feature_levels > 1) 193 | backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args) 194 | # for p in backbone.parameters(): p.requires_grad = False 195 | # backbone.eval() 196 | model = Joiner(backbone, position_embedding) 197 | return model 198 | -------------------------------------------------------------------------------- /models/backbone_frozen.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torchvision 7 | from torchvision.models._utils import IntermediateLayerGetter 8 | 9 | from util.misc import NestedTensor, is_main_process 10 | 11 | from .position_encoding import build_position_encoding 12 | 13 | 14 | class FrozenBatchNorm2d(torch.nn.Module): 15 | """ 16 | BatchNorm2d where the batch statistics and the affine parameters are fixed. 17 | 18 | Copy-paste from torchvision.misc.ops with added eps before rqsrt, 19 | without which any other models than torchvision.models.resnet[18,34,50,101] 20 | produce nans. 
21 | """ 22 | def __init__(self, n, eps=1e-5): 23 | super(FrozenBatchNorm2d, self).__init__() 24 | self.register_buffer("weight", torch.ones(n)) 25 | self.register_buffer("bias", torch.zeros(n)) 26 | self.register_buffer("running_mean", torch.zeros(n)) 27 | self.register_buffer("running_var", torch.ones(n)) 28 | self.eps = eps 29 | 30 | def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, 31 | missing_keys, unexpected_keys, error_msgs): 32 | num_batches_tracked_key = prefix + 'num_batches_tracked' 33 | if num_batches_tracked_key in state_dict: 34 | del state_dict[num_batches_tracked_key] 35 | 36 | super(FrozenBatchNorm2d, self)._load_from_state_dict( 37 | state_dict, prefix, local_metadata, strict, 38 | missing_keys, unexpected_keys, error_msgs) 39 | 40 | def forward(self, x): 41 | # move reshapes to the beginning to make it fuser-friendly 42 | w = self.weight.reshape(1, -1, 1, 1) 43 | b = self.bias.reshape(1, -1, 1, 1) 44 | rv = self.running_var.reshape(1, -1, 1, 1) 45 | rm = self.running_mean.reshape(1, -1, 1, 1) 46 | eps = self.eps 47 | scale = w * (rv + eps).rsqrt() 48 | bias = b - rm * scale 49 | return x * scale + bias 50 | 51 | 52 | class BackboneBase(nn.Module): 53 | def __init__(self, backbone: nn.Module, train_backbone: bool, return_interm_layers: bool, args): 54 | super().__init__() 55 | self.args = args 56 | self.backbone = backbone 57 | 58 | # # Settings for freezing backbone 59 | # assert 0 <= args.freeze_backbone_at_layer <= 4 60 | # for name, parameter in backbone.named_parameters(): parameter.requires_grad_(False) # First freeze all 61 | # if train_backbone: 62 | # if args.freeze_backbone_at_layer == 0: 63 | # for name, parameter in backbone.named_parameters(): 64 | # if 'layer1' in name or 'layer2' in name or 'layer3' in name or 'layer4' in name: 65 | # parameter.requires_grad_(True) 66 | # elif args.freeze_backbone_at_layer == 1: 67 | # for name, parameter in backbone.named_parameters(): 68 | # if 'layer2' in name or 'layer3' in name or 'layer4' in name: 69 | # parameter.requires_grad_(True) 70 | # elif args.freeze_backbone_at_layer == 2: 71 | # for name, parameter in backbone.named_parameters(): 72 | # if 'layer3' in name or 'layer4' in name: 73 | # parameter.requires_grad_(True) 74 | # elif args.freeze_backbone_at_layer == 3: 75 | # for name, parameter in backbone.named_parameters(): 76 | # if 'layer4' in name: 77 | # parameter.requires_grad_(True) 78 | # elif args.freeze_backbone_at_layer == 4: 79 | # pass 80 | # else: 81 | # raise RuntimeError 82 | 83 | if return_interm_layers: 84 | return_layers = {"layer2": "0", "layer3": "1", "layer4": "2"} 85 | self.strides = [8, 16, 32] 86 | self.num_channels = [512, 1024, 2048] 87 | else: 88 | return_layers = {'layer4': "0"} 89 | self.strides = [32] 90 | self.num_channels = [2048] 91 | self.body = IntermediateLayerGetter(backbone, return_layers=return_layers) 92 | 93 | def support_encoding_net(self, x, return_interm_layers=False): 94 | out: Dict[str, NestedTensor] = {} 95 | m = x.mask 96 | # x = self.meta_conv(x.tensors) 97 | x = self.backbone.conv1(x.tensors) 98 | x = self.backbone.bn1(x) 99 | x = self.backbone.relu(x) 100 | x = self.backbone.maxpool(x) 101 | x = self.backbone.layer1(x) 102 | x = self.backbone.layer2(x) 103 | if return_interm_layers: 104 | mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0] 105 | out['0'] = NestedTensor(x, mask) 106 | 107 | x = self.backbone.layer3(x) 108 | if return_interm_layers: 109 | mask = F.interpolate(m[None].float(), 
size=x.shape[-2:]).to(torch.bool)[0] 110 | out['1'] = NestedTensor(x, mask) 111 | 112 | x = self.backbone.layer4(x) 113 | if return_interm_layers: 114 | mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0] 115 | out['2'] = NestedTensor(x, mask) 116 | 117 | if return_interm_layers: 118 | return out 119 | else: 120 | mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0] 121 | out['0'] = NestedTensor(x, mask) 122 | return out 123 | 124 | def forward(self, tensor_list: NestedTensor): 125 | xs = self.body(tensor_list.tensors) 126 | out: Dict[str, NestedTensor] = {} 127 | for name, x in xs.items(): 128 | m = tensor_list.mask 129 | assert m is not None 130 | mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0] 131 | out[name] = NestedTensor(x, mask) 132 | return out 133 | 134 | 135 | class Backbone(BackboneBase): 136 | """ResNet backbone with frozen BatchNorm.""" 137 | def __init__(self, 138 | name: str, 139 | train_backbone: bool, 140 | return_interm_layers: bool, 141 | args): 142 | self.args = args 143 | dilation = args.dilation 144 | norm_layer = FrozenBatchNorm2d 145 | backbone = getattr(torchvision.models, name)( 146 | replace_stride_with_dilation=[False, False, dilation], 147 | pretrained=is_main_process(), norm_layer=norm_layer) 148 | for param in backbone.parameters(): 149 | param.requires_grad = False 150 | assert name not in ('resnet18', 'resnet34'), "number of channels are hard coded, cannot use res18 & res34." 151 | super().__init__(backbone, train_backbone, return_interm_layers, args) 152 | if dilation: 153 | self.strides[-1] = self.strides[-1] // 2 154 | 155 | 156 | class Joiner(nn.Sequential): 157 | def __init__(self, backbone, position_embedding): 158 | super().__init__(backbone, position_embedding) 159 | self.strides = backbone.strides 160 | self.num_channels = backbone.num_channels 161 | 162 | def forward(self, tensor_list: NestedTensor): 163 | xs = self[0](tensor_list) 164 | out: List[NestedTensor] = [] 165 | pos = [] 166 | for name, x in sorted(xs.items()): 167 | out.append(x) 168 | 169 | # position encoding 170 | for x in out: 171 | pos.append(self[1](x).to(x.tensors.dtype)) 172 | 173 | return out, pos 174 | 175 | def forward_supp_branch(self, tensor_list: NestedTensor, return_interm_layers=False): 176 | xs = self[0].support_encoding_net(tensor_list, return_interm_layers=return_interm_layers) 177 | out: List[NestedTensor] = [] 178 | pos = [] 179 | for name, x in sorted(xs.items()): 180 | out.append(x) 181 | 182 | # position encoding 183 | for x in out: 184 | pos.append(self[1](x).to(x.tensors.dtype)) 185 | 186 | return out, pos 187 | 188 | 189 | def build_backbone(args): 190 | position_embedding = build_position_encoding(args) 191 | train_backbone = args.lr_backbone > 0 192 | return_interm_layers = (args.num_feature_levels > 1) 193 | backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args) 194 | for p in backbone.parameters(): p.requires_grad = False 195 | backbone.eval() 196 | model = Joiner(backbone, position_embedding) 197 | return model 198 | -------------------------------------------------------------------------------- /models/dino_backbone.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torchvision 7 | from torchvision.models._utils import IntermediateLayerGetter 8 | 9 | from util.misc import NestedTensor, 
is_main_process 10 | 11 | from .position_encoding import build_position_encoding 12 | from dinov2.vit import DinoVisionTransformer, vit_base, vit_large 13 | 14 | class Joiner(nn.Sequential): 15 | def __init__(self, backbone, position_embedding, conv_14_28): 16 | super().__init__(backbone, position_embedding, conv_14_28) 17 | # DINOv2 uses a stride of 14 18 | # self.strides = backbone.strides 19 | self.num_channels = backbone.num_channels 20 | # i.e., the name of the last block's output 21 | self.vit_feat_name = f'res{backbone.n_blocks - 1}' 22 | 23 | def forward(self, tensor_list: NestedTensor, VPT_enable = False): 24 | out_tmp: Dict[str, NestedTensor] = {} 25 | m = tensor_list.mask 26 | if VPT_enable: 27 | x = self[0](tensor_list.tensors)[self.vit_feat_name] 28 | else: 29 | with torch.no_grad(): x = self[0](tensor_list.tensors)[self.vit_feat_name] 30 | # conv_14_28 ↓ 31 | x = self[2](x) 32 | # conv_14_28 ↑ 33 | mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0] 34 | out_tmp['0'] = NestedTensor(x, mask) 35 | 36 | xs = out_tmp 37 | out: List[NestedTensor] = [] 38 | pos = [] 39 | for name, x in sorted(xs.items()): 40 | out.append(x) 41 | 42 | # position encoding 43 | for x in out: 44 | pos.append(self[1](x).to(x.tensors.dtype)) 45 | 46 | return out, pos 47 | 48 | def forward_supp_branch(self, tensor_list: NestedTensor, return_interm_layers = False, VPT_enable = False): 49 | out_tmp: Dict[str, NestedTensor] = {} 50 | m = tensor_list.mask 51 | if VPT_enable: 52 | x = self[0](tensor_list.tensors)[self.vit_feat_name] 53 | else: 54 | with torch.no_grad(): x = self[0](tensor_list.tensors)[self.vit_feat_name] 55 | # conv_14_28 ↓ 56 | x = self[2](x) 57 | # conv_14_28 ↑ 58 | mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0] 59 | out_tmp['0'] = NestedTensor(x, mask) 60 | 61 | xs = out_tmp 62 | out: List[NestedTensor] = [] 63 | pos = [] 64 | for name, x in sorted(xs.items()): 65 | out.append(x) 66 | 67 | # position encoding 68 | for x in out: 69 | pos.append(self[1](x).to(x.tensors.dtype)) 70 | 71 | return out, pos 72 | 73 | def build_dino_v2_vit(args, input_shape, VPT_enable): 74 | # Specifies which layers of the model return feature maps. If None, no intermediate-layer features are output. 75 | out_indices = None 76 | 77 | if out_indices is not None: 78 | if isinstance(out_indices, str): 79 | out_indices = [int(m) for m in out_indices.split(",")] 80 | 81 | if args.dino_type == 'small': 82 | return DinoVisionTransformer( 83 | patch_size=14, 84 | img_size=518, 85 | init_values=1, 86 | embed_dim=384, 87 | depth=12, 88 | num_heads=6, 89 | mlp_ratio=4, 90 | out_indices=out_indices, 91 | VPT_enable=VPT_enable, 92 | ) 93 | elif args.dino_type == 'base': 94 | return vit_base(out_indices=out_indices, VPT_enable=VPT_enable) 95 | elif args.dino_type == "large": 96 | return vit_large(img_size=518, patch_size=14, init_values=1, out_indices=out_indices, VPT_enable=VPT_enable) 97 | else: 98 | raise NotImplementedError() 99 | 100 | def build_backbone(args): 101 | position_embedding = build_position_encoding(args) 102 | # The 3 passed as input_shape here should have no effect. 103 | backbone = build_dino_v2_vit(args, 3, VPT_enable=args.VPT_enable) 104 | # Freeze all backbone parameters 105 | for p in backbone.parameters(): p.requires_grad = False 106 | backbone.eval() 107 | if args.VPT_enable: 108 | # Ensure prompt_dropout and prompt_embeddings stay trainable 109 | backbone.prompt_dropout.requires_grad = True 110 | backbone.prompt_embeddings.requires_grad = True 111 | backbone.prompt_dropout.train() 112 | """ 113 | nn.Parameter itself has no train() or eval() method: train() and eval() belong to nn.Module and switch a module between training and evaluation mode, which mainly affects layers such as Dropout and BatchNorm. 114 | """ 115 | # backbone.prompt_embeddings.train()
116 | missing_keys, unexpected_keys = backbone.load_state_dict(torch.load(args.dino_weight_path), strict=False) 117 | else: 118 | missing_keys, unexpected_keys = backbone.load_state_dict(torch.load(args.dino_weight_path)) 119 | unexpected_keys = [k for k in unexpected_keys if not (k.endswith('total_params') or k.endswith('total_ops'))] 120 | """ 121 | missing_keys are keys the model needs but that are missing from the checkpoint (i.e., model parameters). 122 | unexpected_keys are keys present in the checkpoint but not in the model. 123 | """ 124 | if len(missing_keys) > 0: 125 | print('Missing Keys: {}'.format(missing_keys)) 126 | if len(unexpected_keys) > 0: 127 | print('Unexpected Keys: {}'.format(unexpected_keys)) 128 | # Turn DINO's stride of 14 into a stride of 28 129 | conv_14_28 = nn.Sequential( 130 | nn.Conv2d(backbone.num_channels[0], backbone.num_channels[0], kernel_size=3, stride=2, padding=1), 131 | nn.GroupNorm(32, backbone.num_channels[0]), 132 | ) 133 | # Same initialization as in DETR 134 | nn.init.xavier_uniform_(conv_14_28[0].weight, gain=1) 135 | nn.init.constant_(conv_14_28[0].bias, 0) 136 | model = Joiner(backbone, position_embedding, conv_14_28) 137 | return model -------------------------------------------------------------------------------- /models/matcher.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------ 6 | # Modified from DETR (https://github.com/facebookresearch/detr) 7 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 8 | # ------------------------------------------------------------------------ 9 | 10 | """ 11 | Modules to compute the matching cost and solve the corresponding LSAP. 12 | """ 13 | import torch 14 | from scipy.optimize import linear_sum_assignment 15 | from torch import nn 16 | 17 | from util.box_ops import box_cxcywh_to_xyxy, generalized_box_iou 18 | 19 | 20 | class HungarianMatcher(nn.Module): 21 | """This class computes an assignment between the targets and the predictions of the network 22 | 23 | For efficiency reasons, the targets don't include the no_object. Because of this, in general, 24 | there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, 25 | while the others are un-matched (and thus treated as non-objects).
26 | """ 27 | 28 | def __init__(self, 29 | cost_class: float = 1.0, 30 | cost_bbox: float = 1.0, 31 | cost_giou: float = 1.0): 32 | """Creates the matcher 33 | 34 | Params: 35 | cost_class: This is the relative weight of the classification error in the matching cost 36 | cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost 37 | cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost 38 | """ 39 | super().__init__() 40 | self.cost_class = cost_class 41 | self.cost_bbox = cost_bbox 42 | self.cost_giou = cost_giou 43 | assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, "all costs cant be 0" 44 | 45 | def forward(self, outputs, targets): 46 | """ Performs the matching 47 | 48 | Params: 49 | outputs: This is a dict that contains at least these entries: 50 | "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits 51 | "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates 52 | 53 | targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing: 54 | "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth 55 | objects in the target) containing the class labels 56 | "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates 57 | 58 | Returns: 59 | A list of size batch_size, containing tuples of (index_i, index_j) where: 60 | - index_i is the indices of the selected predictions (in order) 61 | - index_j is the indices of the corresponding selected targets (in order) 62 | For each batch element, it holds: 63 | len(index_i) = len(index_j) = min(num_queries, num_target_boxes) 64 | """ 65 | with torch.no_grad(): 66 | bs, num_queries = outputs["pred_logits"].shape[:2] 67 | 68 | # We flatten to compute the cost matrices in a batch 69 | out_prob = outputs["pred_logits"].flatten(0, 1).sigmoid() 70 | out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4] 71 | 72 | # Also concat the target labels and boxes 73 | tgt_ids = torch.cat([v["labels"] for v in targets]) 74 | tgt_bbox = torch.cat([v["boxes"] for v in targets]) 75 | 76 | # Compute the classification cost. 
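# Focal-style classification cost (alpha = 0.25, gamma = 2.0): for each query-target pair the cost is the
# positive focal term minus the negative focal term evaluated at the target's class, so confident correct
# predictions receive a strongly negative (i.e., favourable) matching cost.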
77 | alpha = 0.25 78 | gamma = 2.0 79 | neg_cost_class = (1 - alpha) * (out_prob ** gamma) * (-(1 - out_prob + 1e-8).log()) 80 | pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log()) 81 | cost_class = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids] 82 | 83 | # Compute the L1 cost between boxes 84 | cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1) 85 | 86 | # Compute the giou cost betwen boxes 87 | cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), 88 | box_cxcywh_to_xyxy(tgt_bbox)) 89 | 90 | # Final cost matrix 91 | C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou 92 | C = C.view(bs, num_queries, -1).cpu() 93 | 94 | sizes = [len(v["boxes"]) for v in targets] 95 | indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))] 96 | return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices] 97 | 98 | 99 | def build_matcher(args): 100 | return HungarianMatcher(cost_class=args.set_cost_class, 101 | cost_bbox=args.set_cost_bbox, 102 | cost_giou=args.set_cost_giou) 103 | -------------------------------------------------------------------------------- /models/ops/functions/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from .ms_deform_attn_func import MSDeformAttnFunction 10 | 11 | -------------------------------------------------------------------------------- /models/ops/functions/ms_deform_attn_func.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from __future__ import absolute_import 10 | from __future__ import print_function 11 | from __future__ import division 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | from torch.autograd import Function 16 | from torch.autograd.function import once_differentiable 17 | 18 | import MultiScaleDeformableAttention as MSDA 19 | 20 | 21 | class MSDeformAttnFunction(Function): 22 | @staticmethod 23 | def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step): 24 | ctx.im2col_step = im2col_step 25 | output = MSDA.ms_deform_attn_forward( 26 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step) 27 | ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights) 28 | return output 29 | 30 | @staticmethod 31 | @once_differentiable 32 | def backward(ctx, grad_output): 33 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors 34 | grad_value, grad_sampling_loc, grad_attn_weight = \ 35 | MSDA.ms_deform_attn_backward( 36 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output, ctx.im2col_step) 37 | 38 | return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None 39 | 40 | 41 | def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights): 42 | # for debug and test only, 43 | # need to use cuda version instead 44 | N_, S_, M_, D_ = value.shape 45 | _, Lq_, M_, L_, P_, _ = sampling_locations.shape 46 | value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1) 47 | sampling_grids = 2 * sampling_locations - 1 48 | sampling_value_list = [] 49 | for lid_, (H_, W_) in enumerate(value_spatial_shapes): 50 | # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_ 51 | value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_) 52 | # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2 53 | sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1) 54 | # N_*M_, D_, Lq_, P_ 55 | sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_, 56 | mode='bilinear', padding_mode='zeros', align_corners=False) 57 | sampling_value_list.append(sampling_value_l_) 58 | # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_) 59 | attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_) 60 | output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_) 61 | return output.transpose(1, 2).contiguous() 62 | -------------------------------------------------------------------------------- /models/ops/make.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # ------------------------------------------------------------------------------------------------ 3 | # Deformable DETR 4 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 
5 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | # ------------------------------------------------------------------------------------------------ 7 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | # ------------------------------------------------------------------------------------------------ 9 | 10 | python setup.py build install 11 | -------------------------------------------------------------------------------- /models/ops/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from .ms_deform_attn import MSDeformAttn 10 | -------------------------------------------------------------------------------- /models/ops/modules/ms_deform_attn.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from __future__ import absolute_import 10 | from __future__ import print_function 11 | from __future__ import division 12 | 13 | import warnings 14 | import math 15 | 16 | import torch 17 | from torch import nn 18 | import torch.nn.functional as F 19 | from torch.nn.init import xavier_uniform_, constant_ 20 | 21 | from ..functions import MSDeformAttnFunction 22 | 23 | 24 | def _is_power_of_2(n): 25 | if (not isinstance(n, int)) or (n < 0): 26 | raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n))) 27 | return (n & (n-1) == 0) and n != 0 28 | 29 | 30 | class MSDeformAttn(nn.Module): 31 | def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4): 32 | """ 33 | Multi-Scale Deformable Attention Module 34 | :param d_model hidden dimension 35 | :param n_levels number of feature levels 36 | :param n_heads number of attention heads 37 | :param n_points number of sampling points per attention head per feature level 38 | """ 39 | super().__init__() 40 | if d_model % n_heads != 0: 41 | raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads)) 42 | _d_per_head = d_model // n_heads 43 | # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation 44 | if not _is_power_of_2(_d_per_head): 45 | warnings.warn("You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 " 46 | "which is more efficient in our CUDA implementation.") 47 | 48 | self.im2col_step = 
64 49 | 50 | self.d_model = d_model 51 | self.n_levels = n_levels 52 | self.n_heads = n_heads 53 | self.n_points = n_points 54 | 55 | self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2) 56 | self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points) 57 | self.value_proj = nn.Linear(d_model, d_model) 58 | self.output_proj = nn.Linear(d_model, d_model) 59 | 60 | self._reset_parameters() 61 | 62 | def _reset_parameters(self): 63 | constant_(self.sampling_offsets.weight.data, 0.) 64 | thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads) 65 | grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) 66 | grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, self.n_points, 1) 67 | for i in range(self.n_points): 68 | grid_init[:, :, i, :] *= i + 1 69 | with torch.no_grad(): 70 | self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) 71 | constant_(self.attention_weights.weight.data, 0.) 72 | constant_(self.attention_weights.bias.data, 0.) 73 | xavier_uniform_(self.value_proj.weight.data) 74 | constant_(self.value_proj.bias.data, 0.) 75 | xavier_uniform_(self.output_proj.weight.data) 76 | constant_(self.output_proj.bias.data, 0.) 77 | 78 | def forward(self, query, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None): 79 | """ 80 | :param query (N, Length_{query}, C) 81 | :param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area 82 | or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes 83 | :param input_flatten (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C) 84 | :param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})] 85 | :param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}] 86 | :param input_padding_mask (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements 87 | 88 | :return output (N, Length_{query}, C) 89 | """ 90 | N, Len_q, _ = query.shape 91 | N, Len_in, _ = input_flatten.shape 92 | assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in 93 | 94 | value = self.value_proj(input_flatten) 95 | if input_padding_mask is not None: 96 | value = value.masked_fill(input_padding_mask[..., None], float(0)) 97 | value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads) 98 | sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2) 99 | attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points) 100 | attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points) 101 | # N, Len_q, n_heads, n_levels, n_points, 2 102 | if reference_points.shape[-1] == 2: 103 | offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1) 104 | sampling_locations = reference_points[:, :, None, :, None, :] \ 105 | + sampling_offsets / offset_normalizer[None, None, None, :, None, :] 106 | elif reference_points.shape[-1] == 4: 107 | sampling_locations = reference_points[:, :, None, :, None, :2] \ 108 | + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5 109 | else: 110 | raise ValueError( 111 | 'Last dim of 
reference_points must be 2 or 4, but get {} instead.'.format(reference_points.shape[-1])) 112 | output = MSDeformAttnFunction.apply( 113 | value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step) 114 | output = self.output_proj(output) 115 | return output 116 | -------------------------------------------------------------------------------- /models/ops/setup.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | import os 10 | import glob 11 | 12 | import torch 13 | 14 | from torch.utils.cpp_extension import CUDA_HOME 15 | from torch.utils.cpp_extension import CppExtension 16 | from torch.utils.cpp_extension import CUDAExtension 17 | 18 | from setuptools import find_packages 19 | from setuptools import setup 20 | 21 | requirements = ["torch", "torchvision"] 22 | 23 | def get_extensions(): 24 | this_dir = os.path.dirname(os.path.abspath(__file__)) 25 | extensions_dir = os.path.join(this_dir, "src") 26 | 27 | main_file = glob.glob(os.path.join(extensions_dir, "*.cpp")) 28 | source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp")) 29 | source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu")) 30 | 31 | sources = main_file + source_cpu 32 | extension = CppExtension 33 | extra_compile_args = {"cxx": []} 34 | define_macros = [] 35 | 36 | if torch.cuda.is_available() and CUDA_HOME is not None: 37 | extension = CUDAExtension 38 | sources += source_cuda 39 | define_macros += [("WITH_CUDA", None)] 40 | extra_compile_args["nvcc"] = [ 41 | "-DCUDA_HAS_FP16=1", 42 | "-D__CUDA_NO_HALF_OPERATORS__", 43 | "-D__CUDA_NO_HALF_CONVERSIONS__", 44 | "-D__CUDA_NO_HALF2_OPERATORS__", 45 | ] 46 | else: 47 | raise NotImplementedError('Cuda is not availabel') 48 | 49 | sources = [os.path.join(extensions_dir, s) for s in sources] 50 | include_dirs = [extensions_dir] 51 | ext_modules = [ 52 | extension( 53 | "MultiScaleDeformableAttention", 54 | sources, 55 | include_dirs=include_dirs, 56 | define_macros=define_macros, 57 | extra_compile_args=extra_compile_args, 58 | ) 59 | ] 60 | return ext_modules 61 | 62 | setup( 63 | name="MultiScaleDeformableAttention", 64 | version="1.0", 65 | author="Weijie Su", 66 | url="https://github.com/fundamentalvision/Deformable-DETR", 67 | description="PyTorch Wrapper for CUDA Functions of Multi-Scale Deformable Attention", 68 | packages=find_packages(exclude=("configs", "tests",)), 69 | ext_modules=get_extensions(), 70 | cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension}, 71 | ) 72 | -------------------------------------------------------------------------------- /models/ops/src/cpu/ms_deform_attn_cpu.cpp: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #include 12 | 13 | #include 14 | #include 15 | 16 | 17 | at::Tensor 18 | ms_deform_attn_cpu_forward( 19 | const at::Tensor &value, 20 | const at::Tensor &spatial_shapes, 21 | const at::Tensor &level_start_index, 22 | const at::Tensor &sampling_loc, 23 | const at::Tensor &attn_weight, 24 | const int im2col_step) 25 | { 26 | AT_ERROR("Not implement on cpu"); 27 | } 28 | 29 | std::vector 30 | ms_deform_attn_cpu_backward( 31 | const at::Tensor &value, 32 | const at::Tensor &spatial_shapes, 33 | const at::Tensor &level_start_index, 34 | const at::Tensor &sampling_loc, 35 | const at::Tensor &attn_weight, 36 | const at::Tensor &grad_output, 37 | const int im2col_step) 38 | { 39 | AT_ERROR("Not implement on cpu"); 40 | } 41 | 42 | -------------------------------------------------------------------------------- /models/ops/src/cpu/ms_deform_attn_cpu.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #pragma once 12 | #include 13 | 14 | at::Tensor 15 | ms_deform_attn_cpu_forward( 16 | const at::Tensor &value, 17 | const at::Tensor &spatial_shapes, 18 | const at::Tensor &level_start_index, 19 | const at::Tensor &sampling_loc, 20 | const at::Tensor &attn_weight, 21 | const int im2col_step); 22 | 23 | std::vector 24 | ms_deform_attn_cpu_backward( 25 | const at::Tensor &value, 26 | const at::Tensor &spatial_shapes, 27 | const at::Tensor &level_start_index, 28 | const at::Tensor &sampling_loc, 29 | const at::Tensor &attn_weight, 30 | const at::Tensor &grad_output, 31 | const int im2col_step); 32 | 33 | 34 | -------------------------------------------------------------------------------- /models/ops/src/cuda/ms_deform_attn_cuda.cu: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #include 12 | #include "cuda/ms_deform_im2col_cuda.cuh" 13 | 14 | #include 15 | #include 16 | #include 17 | #include 18 | 19 | 20 | at::Tensor ms_deform_attn_cuda_forward( 21 | const at::Tensor &value, 22 | const at::Tensor &spatial_shapes, 23 | const at::Tensor &level_start_index, 24 | const at::Tensor &sampling_loc, 25 | const at::Tensor &attn_weight, 26 | const int im2col_step) 27 | { 28 | AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous"); 29 | AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous"); 30 | AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous"); 31 | AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous"); 32 | AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous"); 33 | 34 | AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor"); 35 | AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor"); 36 | AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor"); 37 | AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor"); 38 | AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor"); 39 | 40 | const int batch = value.size(0); 41 | const int spatial_size = value.size(1); 42 | const int num_heads = value.size(2); 43 | const int channels = value.size(3); 44 | 45 | const int num_levels = spatial_shapes.size(0); 46 | 47 | const int num_query = sampling_loc.size(1); 48 | const int num_point = sampling_loc.size(4); 49 | 50 | const int im2col_step_ = std::min(batch, im2col_step); 51 | 52 | AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_); 53 | 54 | auto output = at::zeros({batch, num_query, num_heads, channels}, value.options()); 55 | 56 | const int batch_n = im2col_step_; 57 | auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels}); 58 | auto per_value_size = spatial_size * num_heads * channels; 59 | auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2; 60 | auto per_attn_weight_size = num_query * num_heads * num_levels * num_point; 61 | for (int n = 0; n < batch/im2col_step_; ++n) 62 | { 63 | auto columns = output_n.select(0, n); 64 | AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] { 65 | ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(), 66 | value.data() + n * im2col_step_ * per_value_size, 67 | spatial_shapes.data(), 68 | level_start_index.data(), 69 | sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, 70 | attn_weight.data() + n * im2col_step_ * per_attn_weight_size, 71 | batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point, 72 | columns.data()); 73 | 74 | })); 75 | } 76 | 77 | output = output.view({batch, num_query, num_heads*channels}); 78 | 79 | return output; 80 | } 81 | 82 | 83 | std::vector ms_deform_attn_cuda_backward( 84 | const at::Tensor &value, 85 | const at::Tensor &spatial_shapes, 86 | const at::Tensor &level_start_index, 87 | 
const at::Tensor &sampling_loc, 88 | const at::Tensor &attn_weight, 89 | const at::Tensor &grad_output, 90 | const int im2col_step) 91 | { 92 | 93 | AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous"); 94 | AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous"); 95 | AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous"); 96 | AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous"); 97 | AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous"); 98 | AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous"); 99 | 100 | AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor"); 101 | AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor"); 102 | AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor"); 103 | AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor"); 104 | AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor"); 105 | AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor"); 106 | 107 | const int batch = value.size(0); 108 | const int spatial_size = value.size(1); 109 | const int num_heads = value.size(2); 110 | const int channels = value.size(3); 111 | 112 | const int num_levels = spatial_shapes.size(0); 113 | 114 | const int num_query = sampling_loc.size(1); 115 | const int num_point = sampling_loc.size(4); 116 | 117 | const int im2col_step_ = std::min(batch, im2col_step); 118 | 119 | AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_); 120 | 121 | auto grad_value = at::zeros_like(value); 122 | auto grad_sampling_loc = at::zeros_like(sampling_loc); 123 | auto grad_attn_weight = at::zeros_like(attn_weight); 124 | 125 | const int batch_n = im2col_step_; 126 | auto per_value_size = spatial_size * num_heads * channels; 127 | auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2; 128 | auto per_attn_weight_size = num_query * num_heads * num_levels * num_point; 129 | auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels}); 130 | 131 | for (int n = 0; n < batch/im2col_step_; ++n) 132 | { 133 | auto grad_output_g = grad_output_n.select(0, n); 134 | AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] { 135 | ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(), 136 | grad_output_g.data(), 137 | value.data() + n * im2col_step_ * per_value_size, 138 | spatial_shapes.data(), 139 | level_start_index.data(), 140 | sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, 141 | attn_weight.data() + n * im2col_step_ * per_attn_weight_size, 142 | batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point, 143 | grad_value.data() + n * im2col_step_ * per_value_size, 144 | grad_sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, 145 | grad_attn_weight.data() + n * im2col_step_ * per_attn_weight_size); 146 | 147 | })); 148 | } 149 | 150 | return { 151 | grad_value, grad_sampling_loc, grad_attn_weight 152 | }; 153 | } -------------------------------------------------------------------------------- /models/ops/src/cuda/ms_deform_attn_cuda.h: -------------------------------------------------------------------------------- 1 | /*! 
2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #pragma once 12 | #include 13 | 14 | at::Tensor ms_deform_attn_cuda_forward( 15 | const at::Tensor &value, 16 | const at::Tensor &spatial_shapes, 17 | const at::Tensor &level_start_index, 18 | const at::Tensor &sampling_loc, 19 | const at::Tensor &attn_weight, 20 | const int im2col_step); 21 | 22 | std::vector ms_deform_attn_cuda_backward( 23 | const at::Tensor &value, 24 | const at::Tensor &spatial_shapes, 25 | const at::Tensor &level_start_index, 26 | const at::Tensor &sampling_loc, 27 | const at::Tensor &attn_weight, 28 | const at::Tensor &grad_output, 29 | const int im2col_step); 30 | 31 | -------------------------------------------------------------------------------- /models/ops/src/ms_deform_attn.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #pragma once 12 | 13 | #include "cpu/ms_deform_attn_cpu.h" 14 | 15 | #ifdef WITH_CUDA 16 | #include "cuda/ms_deform_attn_cuda.h" 17 | #endif 18 | 19 | 20 | at::Tensor 21 | ms_deform_attn_forward( 22 | const at::Tensor &value, 23 | const at::Tensor &spatial_shapes, 24 | const at::Tensor &level_start_index, 25 | const at::Tensor &sampling_loc, 26 | const at::Tensor &attn_weight, 27 | const int im2col_step) 28 | { 29 | if (value.type().is_cuda()) 30 | { 31 | #ifdef WITH_CUDA 32 | return ms_deform_attn_cuda_forward( 33 | value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step); 34 | #else 35 | AT_ERROR("Not compiled with GPU support"); 36 | #endif 37 | } 38 | AT_ERROR("Not implemented on the CPU"); 39 | } 40 | 41 | std::vector 42 | ms_deform_attn_backward( 43 | const at::Tensor &value, 44 | const at::Tensor &spatial_shapes, 45 | const at::Tensor &level_start_index, 46 | const at::Tensor &sampling_loc, 47 | const at::Tensor &attn_weight, 48 | const at::Tensor &grad_output, 49 | const int im2col_step) 50 | { 51 | if (value.type().is_cuda()) 52 | { 53 | #ifdef WITH_CUDA 54 | return ms_deform_attn_cuda_backward( 55 | value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step); 56 | #else 57 | AT_ERROR("Not compiled with GPU support"); 58 | #endif 59 | } 60 | AT_ERROR("Not implemented on the CPU"); 61 | } 62 | 63 | -------------------------------------------------------------------------------- /models/ops/src/vision.cpp: 
-------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #include "ms_deform_attn.h" 12 | 13 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 14 | m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward"); 15 | m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward"); 16 | } 17 | -------------------------------------------------------------------------------- /models/ops/test.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from __future__ import absolute_import 10 | from __future__ import print_function 11 | from __future__ import division 12 | 13 | import time 14 | import torch 15 | import torch.nn as nn 16 | from torch.autograd import gradcheck 17 | 18 | from functions.ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch 19 | 20 | 21 | N, M, D = 1, 2, 2 22 | Lq, L, P = 2, 2, 2 23 | shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda() 24 | level_start_index = torch.cat((shapes.new_zeros((1, )), shapes.prod(1).cumsum(0)[:-1])) 25 | S = sum([(H*W).item() for H, W in shapes]) 26 | 27 | 28 | torch.manual_seed(3) 29 | 30 | 31 | @torch.no_grad() 32 | def check_forward_equal_with_pytorch_double(): 33 | value = torch.rand(N, S, M, D).cuda() * 0.01 34 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() 35 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 36 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) 37 | im2col_step = 2 38 | output_pytorch = ms_deform_attn_core_pytorch(value.double(), shapes, sampling_locations.double(), attention_weights.double()).detach().cpu() 39 | output_cuda = MSDeformAttnFunction.apply(value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step).detach().cpu() 40 | fwdok = torch.allclose(output_cuda, output_pytorch) 41 | max_abs_err = (output_cuda - output_pytorch).abs().max() 42 | max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max() 43 | 44 | print(f'* {fwdok} check_forward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}') 45 | 46 | 47 | @torch.no_grad() 48 | def check_forward_equal_with_pytorch_float(): 49 | value = torch.rand(N, S, M, D).cuda() * 0.01 50 | sampling_locations = torch.rand(N, Lq, 
M, L, P, 2).cuda() 51 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 52 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) 53 | im2col_step = 2 54 | output_pytorch = ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights).detach().cpu() 55 | output_cuda = MSDeformAttnFunction.apply(value, shapes, level_start_index, sampling_locations, attention_weights, im2col_step).detach().cpu() 56 | fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3) 57 | max_abs_err = (output_cuda - output_pytorch).abs().max() 58 | max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max() 59 | 60 | print(f'* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}') 61 | 62 | 63 | def check_gradient_numerical(channels=4, grad_value=True, grad_sampling_loc=True, grad_attn_weight=True): 64 | 65 | value = torch.rand(N, S, M, channels).cuda() * 0.01 66 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() 67 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 68 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) 69 | im2col_step = 2 70 | func = MSDeformAttnFunction.apply 71 | 72 | value.requires_grad = grad_value 73 | sampling_locations.requires_grad = grad_sampling_loc 74 | attention_weights.requires_grad = grad_attn_weight 75 | 76 | gradok = gradcheck(func, (value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step)) 77 | 78 | print(f'* {gradok} check_gradient_numerical(D={channels})') 79 | 80 | 81 | if __name__ == '__main__': 82 | check_forward_equal_with_pytorch_double() 83 | check_forward_equal_with_pytorch_float() 84 | 85 | for channels in [30, 32, 64, 71, 1025, 2048, 3096]: 86 | check_gradient_numerical(channels, True, True, True) 87 | 88 | 89 | 90 | -------------------------------------------------------------------------------- /models/position_encoding.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------ 6 | # Modified from DETR (https://github.com/facebookresearch/detr) 7 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 8 | # ------------------------------------------------------------------------ 9 | 10 | """ 11 | Various positional encodings for the transformer. 12 | """ 13 | import math 14 | import torch 15 | from torch import nn 16 | from torch.autograd import Variable 17 | 18 | from util.misc import NestedTensor 19 | 20 | 21 | class PositionEmbeddingSine(nn.Module): 22 | """ 23 | This is a more standard version of the position embedding, very similar to the one 24 | used by the Attention is all you need paper, generalized to work on images. 
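Row (y) and column (x) positions are encoded separately with interleaved sine/cosine channels and then concatenated along the channel dimension.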
25 | """ 26 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): 27 | super().__init__() 28 | self.num_pos_feats = num_pos_feats 29 | self.temperature = temperature 30 | self.normalize = normalize 31 | if scale is not None and normalize is False: 32 | raise ValueError("normalize should be True if scale is passed") 33 | if scale is None: 34 | scale = 2 * math.pi 35 | self.scale = scale 36 | 37 | def forward(self, tensor_list: NestedTensor): 38 | x = tensor_list.tensors 39 | mask = tensor_list.mask 40 | assert mask is not None 41 | not_mask = ~mask 42 | y_embed = not_mask.cumsum(1, dtype=torch.float32) 43 | x_embed = not_mask.cumsum(2, dtype=torch.float32) 44 | if self.normalize: 45 | eps = 1e-6 46 | y_embed = (y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * self.scale 47 | x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale 48 | 49 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 50 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 51 | 52 | pos_x = x_embed[:, :, :, None] / dim_t 53 | pos_y = y_embed[:, :, :, None] / dim_t 54 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) 55 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) 56 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 57 | return pos 58 | 59 | 60 | class PositionEmbeddingLearned(nn.Module): 61 | """ 62 | Absolute pos embedding, learned. 63 | """ 64 | def __init__(self, num_pos_feats=256): 65 | super().__init__() 66 | self.row_embed = nn.Embedding(50, num_pos_feats) 67 | self.col_embed = nn.Embedding(50, num_pos_feats) 68 | self.reset_parameters() 69 | 70 | def reset_parameters(self): 71 | nn.init.uniform_(self.row_embed.weight) 72 | nn.init.uniform_(self.col_embed.weight) 73 | 74 | def forward(self, tensor_list: NestedTensor): 75 | x = tensor_list.tensors 76 | h, w = x.shape[-2:] 77 | i = torch.arange(w, device=x.device) 78 | j = torch.arange(h, device=x.device) 79 | x_emb = self.col_embed(i) 80 | y_emb = self.row_embed(j) 81 | pos = torch.cat([ 82 | x_emb.unsqueeze(0).repeat(h, 1, 1), 83 | y_emb.unsqueeze(1).repeat(1, w, 1), 84 | ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1) 85 | return pos 86 | 87 | """ 88 | Because the positional encoding in TaskPositionalEncoding is precomputed and stored in a buffer, and its input is a constant zero tensor, the tpe is a fixed value that does not change during training. 89 | """ 90 | class TaskPositionalEncoding(nn.Module): 91 | def __init__(self, d_model, dropout=0.05, max_len=10): 92 | super(TaskPositionalEncoding, self).__init__() 93 | self.dropout = nn.Dropout(p=dropout) 94 | # Compute the task positional encodings once and for all in log space. 95 | tpe = torch.zeros(max_len, d_model) 96 | position = torch.arange(0, max_len).unsqueeze(1) 97 | div_term = torch.exp(torch.arange(0, d_model, 2) * 98 | -(math.log(10000.0) / d_model)) 99 | tpe[:, 0::2] = torch.sin(position * div_term) 100 | tpe[:, 1::2] = torch.cos(position * div_term) 101 | self.register_buffer('tpe', tpe) 102 | 103 | def forward(self, x): 104 | x = x + torch.flip(Variable(self.tpe[:x.size(1)], requires_grad=False), [1]) 105 | return self.dropout(x) 106 | 107 | 108 | class QueryEncoding(nn.Module): 109 | def __init__(self, d_model, dropout=0.0, max_len=100): 110 | super(QueryEncoding, self).__init__() 111 | self.dropout = nn.Dropout(p=dropout) 112 | # Compute the query encodings once and for all in log space.
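# As in TaskPositionalEncoding above, even channels get sin(position * div_term) and odd channels get
# cos(position * div_term); the resulting [max_len, d_model] table is registered as a non-trainable buffer,
# and forward() simply returns this fixed table (passed through dropout, which defaults to p=0.0).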
113 | queryencoding = torch.zeros(max_len, d_model) 114 | position = torch.arange(0, max_len).unsqueeze(1) 115 | div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model)) 116 | queryencoding[:, 0::2] = torch.sin(position * div_term) 117 | queryencoding[:, 1::2] = torch.cos(position * div_term) 118 | self.register_buffer('queryencoding', queryencoding) 119 | 120 | def forward(self): 121 | x = Variable(self.queryencoding, requires_grad=False) 122 | return self.dropout(x) 123 | 124 | 125 | def build_position_encoding(args): 126 | N_steps = args.hidden_dim // 2 127 | if args.position_embedding in ('v2', 'sine'): 128 | position_embedding = PositionEmbeddingSine(N_steps, normalize=True) 129 | elif args.position_embedding in ('v3', 'learned'): 130 | position_embedding = PositionEmbeddingLearned(N_steps) 131 | else: 132 | raise ValueError(f"not supported {args.position_embedding}") 133 | 134 | return position_embedding 135 | -------------------------------------------------------------------------------- /scripts/basetrain.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | source /opt/anaconda3/etc/profile.d/conda.sh 4 | conda activate cdformer 5 | 6 | EXP_DIR=exps/CDFormer_dinov2 7 | BASE_TRAIN_DIR=${EXP_DIR}/base_train 8 | mkdir exps 9 | mkdir ${EXP_DIR} 10 | mkdir ${BASE_TRAIN_DIR} 11 | 12 | python -u main.py \ 13 | --dataset_file coco \ 14 | --backbone dinov2 \ 15 | --num_feature_levels 1 \ 16 | --enc_layers 6 \ 17 | --dec_layers 6 \ 18 | --hidden_dim 256 \ 19 | --num_queries 300 \ 20 | --batch_size 8 \ 21 | --epoch 50 \ 22 | --lr_drop_milestones 45 \ 23 | --save_every_epoch 5 \ 24 | --eval_every_epoch 5 \ 25 | --output_dir ${BASE_TRAIN_DIR} \ 26 | --category_codes_cls_loss \ 27 | 2>&1 | tee ${BASE_TRAIN_DIR}/log.txt -------------------------------------------------------------------------------- /scripts/eval.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | source /opt/anaconda3/etc/profile.d/conda.sh 4 | conda activate cdformer 5 | export CUDA_VISIBLE_DEVICES=0 6 | fewshot_seed=01 7 | num_shot=10 8 | 9 | python -u main.py \ 10 | --dataset_file uodd \ 11 | --backbone dinov2 \ 12 | --num_feature_levels 1 \ 13 | --enc_layers 6 \ 14 | --dec_layers 6 \ 15 | --hidden_dim 256 \ 16 | --num_queries 300 \ 17 | --batch_size 2 \ 18 | --resume base_train.pth \ 19 | --fewshot_finetune \ 20 | --fewshot_seed ${fewshot_seed} \ 21 | --num_shots ${num_shot} \ 22 | --eval \ 23 | 2>&1 | tee ./log_inference_base_0.txt -------------------------------------------------------------------------------- /scripts/fsfinetune.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | source /opt/anaconda3/etc/profile.d/conda.sh 4 | conda activate cdformer 5 | 6 | EXP_DIR=exps/uodd 7 | BASE_TRAIN_DIR=${EXP_DIR}/epoch20 8 | mkdir exps 9 | mkdir ${EXP_DIR} 10 | mkdir ${BASE_TRAIN_DIR} 11 | 12 | fewshot_seed=01 13 | num_shot=10 14 | epoch= 15 | save_every_epoch= 16 | lr_drop1=250 17 | lr_drop2=450 18 | lr=5e-5 19 | lr_backbone=5e-6 20 | FS_FT_DIR=${BASE_TRAIN_DIR}/seed${fewshot_seed}_${num_shot}shot_01 21 | mkdir ${FS_FT_DIR} 22 | 23 | python -u main.py \ 24 | --dataset_file dataset2 \ 25 | --backbone dinov2 \ 26 | --num_feature_levels 1 \ 27 | --enc_layers 6 \ 28 | --dec_layers 6 \ 29 | --hidden_dim 256 \ 30 | --num_queries 300 \ 31 | --batch_size 2 \ 32 | --lr ${lr} \ 33 | --lr_backbone ${lr_backbone}
\ 34 | --resume exps/dino_coco_80_size_vitl/base_train/checkpoint0049.pth \ 35 | --fewshot_finetune \ 36 | --fewshot_seed ${fewshot_seed} \ 37 | --num_shots ${num_shot} \ 38 | --epoch ${epoch} \ 39 | --lr_drop_milestones ${lr_drop1} ${lr_drop2} \ 40 | --warmup_epochs 50 \ 41 | --save_every_epoch ${save_every_epoch} \ 42 | --eval_every_epoch \ 43 | --output_dir ${FS_FT_DIR} \ 44 | --category_codes_cls_loss \ 45 | 2>&1 | tee ${FS_FT_DIR}/log.txt 46 | -------------------------------------------------------------------------------- /tools/launch.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------------------------------------------------------------------------- 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # -------------------------------------------------------------------------------------------------------------------------- 6 | # Modified from https://github.com/pytorch/pytorch/blob/173f224570017b4b1a3a1a13d0bff280a54d9cd9/torch/distributed/launch.py 7 | # -------------------------------------------------------------------------------------------------------------------------- 8 | 9 | r""" 10 | `torch.distributed.launch` is a module that spawns up multiple distributed 11 | training processes on each of the training nodes. 12 | The utility can be used for single-node distributed training, in which one or 13 | more processes per node will be spawned. The utility can be used for either 14 | CPU training or GPU training. If the utility is used for GPU training, 15 | each distributed process will be operating on a single GPU. This can achieve 16 | well-improved single-node training performance. It can also be used in 17 | multi-node distributed training, by spawning up multiple processes on each node 18 | for well-improved multi-node distributed training performance as well. 19 | This will especially be beneficial for systems with multiple Infiniband 20 | interfaces that have direct-GPU support, since all of them can be utilized for 21 | aggregated communication bandwidth. 22 | In both cases of single-node distributed training or multi-node distributed 23 | training, this utility will launch the given number of processes per node 24 | (``--nproc_per_node``). If used for GPU training, this number needs to be less 25 | than or equal to the number of GPUs on the current system (``nproc_per_node``), 26 | and each process will be operating on a single GPU from *GPU 0 to 27 | GPU (nproc_per_node - 1)*. 28 | **How to use this module:** 29 | 1. Single-Node multi-process distributed training 30 | :: 31 | >>> python -m torch.distributed.launch --nproc_per_node=NUM_GPUS_YOU_HAVE 32 | YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3 and all other 33 | arguments of your training script) 34 | 2. Multi-Node multi-process distributed training: (e.g.
two nodes) 35 | Node 1: *(IP: 192.168.1.1, and has a free port: 1234)* 36 | :: 37 | >>> python -m torch.distributed.launch --nproc_per_node=NUM_GPUS_YOU_HAVE 38 | --nnodes=2 --node_rank=0 --master_addr="192.168.1.1" 39 | --master_port=1234 YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3 40 | and all other arguments of your training script) 41 | Node 2: 42 | :: 43 | >>> python -m torch.distributed.launch --nproc_per_node=NUM_GPUS_YOU_HAVE 44 | --nnodes=2 --node_rank=1 --master_addr="192.168.1.1" 45 | --master_port=1234 YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3 46 | and all other arguments of your training script) 47 | 3. To look up what optional arguments this module offers: 48 | :: 49 | >>> python -m torch.distributed.launch --help 50 | **Important Notices:** 51 | 1. This utility and multi-process distributed (single-node or 52 | multi-node) GPU training currently only achieves the best performance using 53 | the NCCL distributed backend. Thus NCCL backend is the recommended backend to 54 | use for GPU training. 55 | 2. In your training program, you must parse the command-line argument: 56 | ``--local_rank=LOCAL_PROCESS_RANK``, which will be provided by this module. 57 | If your training program uses GPUs, you should ensure that your code only 58 | runs on the GPU device of LOCAL_PROCESS_RANK. This can be done by: 59 | Parsing the local_rank argument 60 | :: 61 | >>> import argparse 62 | >>> parser = argparse.ArgumentParser() 63 | >>> parser.add_argument("--local_rank", type=int) 64 | >>> args = parser.parse_args() 65 | Set your device to local rank using either 66 | :: 67 | >>> torch.cuda.set_device(args.local_rank) # before your code runs 68 | or 69 | :: 70 | >>> with torch.cuda.device(args.local_rank): 71 | >>> # your code to run 72 | 3. In your training program, you are supposed to call the following function 73 | at the beginning to start the distributed backend. You need to make sure that 74 | the init_method uses ``env://``, which is the only supported ``init_method`` 75 | by this module. 76 | :: 77 | torch.distributed.init_process_group(backend='YOUR BACKEND', 78 | init_method='env://') 79 | 4. In your training program, you can either use regular distributed functions 80 | or use :func:`torch.nn.parallel.DistributedDataParallel` module. If your 81 | training program uses GPUs for training and you would like to use 82 | :func:`torch.nn.parallel.DistributedDataParallel` module, 83 | here is how to configure it. 84 | :: 85 | model = torch.nn.parallel.DistributedDataParallel(model, 86 | device_ids=[args.local_rank], 87 | output_device=args.local_rank) 88 | Please ensure that ``device_ids`` argument is set to be the only GPU device id 89 | that your code will be operating on. This is generally the local rank of the 90 | process. In other words, the ``device_ids`` needs to be ``[args.local_rank]``, 91 | and ``output_device`` needs to be ``args.local_rank`` in order to use this 92 | utility. 93 | 5. Another way to pass ``local_rank`` to the subprocesses is via the environment variable 94 | ``LOCAL_RANK``. This behavior is enabled when you launch the script with 95 | ``--use_env=True``. You must adjust the subprocess example above to replace 96 | ``args.local_rank`` with ``os.environ['LOCAL_RANK']``; the launcher 97 | will not pass ``--local_rank`` when you specify this flag. 98 | .. warning:: 99 | ``local_rank`` is NOT globally unique: it is only unique per process 100 | on a machine. Thus, don't use it to decide if you should, e.g., 101 | write to a networked filesystem.
See 102 | https://github.com/pytorch/pytorch/issues/12042 for an example of 103 | how things can go wrong if you don't do this correctly. 104 | """ 105 | 106 | 107 | import sys 108 | import subprocess 109 | import os 110 | import socket 111 | from argparse import ArgumentParser, REMAINDER 112 | 113 | import torch 114 | 115 | 116 | def parse_args(): 117 | """ 118 | Helper function parsing the command line options 119 | @retval ArgumentParser 120 | """ 121 | parser = ArgumentParser(description="PyTorch distributed training launch " 122 | "helper utility that will spawn up " 123 | "multiple distributed processes") 124 | 125 | # Optional arguments for the launch helper 126 | parser.add_argument("--nnodes", type=int, default=1, 127 | help="The number of nodes to use for distributed " 128 | "training") 129 | parser.add_argument("--node_rank", type=int, default=0, 130 | help="The rank of the node for multi-node distributed " 131 | "training") 132 | parser.add_argument("--nproc_per_node", type=int, default=1, 133 | help="The number of processes to launch on each node, " 134 | "for GPU training, this is recommended to be set " 135 | "to the number of GPUs in your system so that " 136 | "each process can be bound to a single GPU.") 137 | parser.add_argument("--master_addr", default="127.0.0.1", type=str, 138 | help="Master node (rank 0)'s address, should be either " 139 | "the IP address or the hostname of node 0, for " 140 | "single node multi-proc training, the " 141 | "--master_addr can simply be 127.0.0.1") 142 | parser.add_argument("--master_port", default=29500, type=int, 143 | help="Master node (rank 0)'s free port that needs to " 144 | "be used for communication during distributed " 145 | "training") 146 | 147 | # positional 148 | parser.add_argument("training_script", type=str, 149 | help="The full path to the single GPU training " 150 | "program/script to be launched in parallel, " 151 | "followed by all the arguments for the " 152 | "training script") 153 | 154 | # rest from the training program 155 | parser.add_argument('training_script_args', nargs=REMAINDER) 156 | return parser.parse_args() 157 | 158 | 159 | def main(): 160 | args = parse_args() 161 | 162 | # world size in terms of number of processes 163 | dist_world_size = args.nproc_per_node * args.nnodes 164 | 165 | # set PyTorch distributed related environmental variables 166 | current_env = os.environ.copy() 167 | current_env["MASTER_ADDR"] = args.master_addr 168 | current_env["MASTER_PORT"] = str(args.master_port) 169 | current_env["WORLD_SIZE"] = str(dist_world_size) 170 | 171 | processes = [] 172 | 173 | for local_rank in range(0, args.nproc_per_node): 174 | # each process's rank 175 | dist_rank = args.nproc_per_node * args.node_rank + local_rank 176 | current_env["RANK"] = str(dist_rank) 177 | current_env["LOCAL_RANK"] = str(local_rank) 178 | # print(f'lauch rank {current_env["RANK"]} local rank {current_env["LOCAL_RANK"]}') 179 | 180 | cmd = [args.training_script] + args.training_script_args 181 | 182 | process = subprocess.Popen(cmd, env=current_env) 183 | processes.append(process) 184 | 185 | for process in processes: 186 | process.wait() 187 | if process.returncode != 0: 188 | raise subprocess.CalledProcessError(returncode=process.returncode, 189 | cmd=process.args) 190 | 191 | 192 | if __name__ == "__main__": 193 | main() 194 | -------------------------------------------------------------------------------- /tools/run_dist_launch.sh:
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | source /opt/anaconda3/etc/profile.d/conda.sh 4 | # conda info --env 5 | conda activate cdformer 6 | 7 | # Select the GPUs to use (here GPUs 2 and 3) 8 | export CUDA_VISIBLE_DEVICES=2,3 9 | 10 | set -x 11 | 12 | GPUS=$1 13 | RUN_COMMAND=${@:2} 14 | if [ $GPUS -lt 8 ]; then 15 | GPUS_PER_NODE=${GPUS_PER_NODE:-$GPUS} 16 | else 17 | GPUS_PER_NODE=${GPUS_PER_NODE:-8} 18 | fi 19 | MASTER_ADDR=${MASTER_ADDR:-"127.0.0.10"} 20 | # MASTER_PORT=${MASTER_PORT:-"29505"} 21 | MASTER_PORT=${MASTER_PORT:-"29505"} 22 | NODE_RANK=${NODE_RANK:-0} 23 | 24 | let "NNODES=GPUS/GPUS_PER_NODE" 25 | 26 | python ./tools/launch.py \ 27 | --nnodes ${NNODES} \ 28 | --node_rank ${NODE_RANK} \ 29 | --master_addr ${MASTER_ADDR} \ 30 | --master_port ${MASTER_PORT} \ 31 | --nproc_per_node ${GPUS_PER_NODE} \ 32 | ${RUN_COMMAND} 33 | -------------------------------------------------------------------------------- /tools/run_dist_slurm.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -x 4 | 5 | PARTITION=$1 6 | JOB_NAME=$2 7 | GPUS=$3 8 | RUN_COMMAND=${@:4} 9 | if [ $GPUS -lt 8 ]; then 10 | GPUS_PER_NODE=${GPUS_PER_NODE:-$GPUS} 11 | else 12 | GPUS_PER_NODE=${GPUS_PER_NODE:-8} 13 | fi 14 | CPUS_PER_TASK=${CPUS_PER_TASK:-2} 15 | SRUN_ARGS=${SRUN_ARGS:-""} 16 | 17 | srun -p ${PARTITION} \ 18 | --job-name=${JOB_NAME} \ 19 | --gres=gpu:${GPUS_PER_NODE} \ 20 | --ntasks=${GPUS} \ 21 | --ntasks-per-node=${GPUS_PER_NODE} \ 22 | --cpus-per-task=${CPUS_PER_TASK} \ 23 | --kill-on-bad-exit=1 \ 24 | ${SRUN_ARGS} \ 25 | ${RUN_COMMAND} 26 | 27 | -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------ 6 | # Modified from DETR (https://github.com/facebookresearch/detr) 7 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 8 | # ------------------------------------------------------------------------ 9 | -------------------------------------------------------------------------------- /util/box_ops.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for bounding box manipulation and GIoU.
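For reference, a small worked example of the conventions used below (boxes are [x0, y0, x1, y1] tensors; the numbers are illustrative only):

    >>> import torch
    >>> a = torch.tensor([[0., 0., 2., 2.]])      # area 4
    >>> b = torch.tensor([[1., 1., 3., 3.]])      # area 4, intersection 1, union 7
    >>> box_iou(a, b)[0]                          # IoU = 1/7, about 0.143
    >>> generalized_box_iou(a, b)                 # 1/7 - (9 - 7)/9, about -0.079 (enclosing box area 9)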
3 | """ 4 | import torch 5 | from torchvision.ops.boxes import box_area 6 | 7 | 8 | def box_cxcywh_to_xyxy(x): 9 | x_c, y_c, w, h = x.unbind(-1) 10 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h), 11 | (x_c + 0.5 * w), (y_c + 0.5 * h)] 12 | return torch.stack(b, dim=-1) 13 | 14 | 15 | def box_xyxy_to_cxcywh(x): 16 | x0, y0, x1, y1 = x.unbind(-1) 17 | b = [(x0 + x1) / 2, (y0 + y1) / 2, 18 | (x1 - x0), (y1 - y0)] 19 | return torch.stack(b, dim=-1) 20 | 21 | 22 | # modified from torchvision to also return the union 23 | def box_iou(boxes1, boxes2): 24 | area1 = box_area(boxes1) 25 | area2 = box_area(boxes2) 26 | 27 | lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] 28 | rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] 29 | 30 | wh = (rb - lt).clamp(min=0) # [N,M,2] 31 | inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] 32 | 33 | union = area1[:, None] + area2 - inter 34 | 35 | iou = inter / union 36 | return iou, union 37 | 38 | 39 | def generalized_box_iou(boxes1, boxes2): 40 | """ 41 | Generalized IoU from https://giou.stanford.edu/ 42 | 43 | The boxes should be in [x0, y0, x1, y1] format 44 | 45 | Returns a [N, M] pairwise matrix, where N = len(boxes1) 46 | and M = len(boxes2) 47 | """ 48 | # degenerate boxes gives inf / nan results 49 | # so do an early check 50 | assert (boxes1[:, 2:] >= boxes1[:, :2]).all() 51 | assert (boxes2[:, 2:] >= boxes2[:, :2]).all() 52 | iou, union = box_iou(boxes1, boxes2) 53 | 54 | lt = torch.min(boxes1[:, None, :2], boxes2[:, :2]) 55 | rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) 56 | 57 | wh = (rb - lt).clamp(min=0) # [N,M,2] 58 | area = wh[:, :, 0] * wh[:, :, 1] 59 | 60 | return iou - (area - union) / area 61 | 62 | 63 | def masks_to_boxes(masks): 64 | """Compute the bounding boxes around the provided masks 65 | 66 | The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions. 67 | 68 | Returns a [N, 4] tensors, with the boxes in xyxy format 69 | """ 70 | if masks.numel() == 0: 71 | return torch.zeros((0, 4), device=masks.device) 72 | 73 | h, w = masks.shape[-2:] 74 | 75 | y = torch.arange(0, h, dtype=torch.float) 76 | x = torch.arange(0, w, dtype=torch.float) 77 | y, x = torch.meshgrid(y, x) 78 | 79 | x_mask = (masks * x.unsqueeze(0)) 80 | x_max = x_mask.flatten(1).max(-1)[0] 81 | x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] 82 | 83 | y_mask = (masks * y.unsqueeze(0)) 84 | y_max = y_mask.flatten(1).max(-1)[0] 85 | y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] 86 | 87 | return torch.stack([x_min, y_min, x_max, y_max], 1) 88 | -------------------------------------------------------------------------------- /util/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from bisect import bisect_right 3 | from typing import List 4 | 5 | 6 | class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler): 7 | def __init__( 8 | self, 9 | optimizer: torch.optim.Optimizer, 10 | milestones: List[int], 11 | gamma: float = 0.1, 12 | warmup_factor: float = 0.333333333, 13 | warmup_epochs: int = 5, 14 | warmup_method: str = "linear", 15 | last_epoch: int = -1, 16 | ): 17 | if not list(milestones) == sorted(milestones): 18 | raise ValueError( 19 | "Milestones should be a list of" " increasing integers. 
Got {}", milestones 20 | ) 21 | self.milestones = milestones 22 | self.gamma = gamma 23 | self.warmup_factor = warmup_factor 24 | self.warmup_epochs = warmup_epochs 25 | self.warmup_method = warmup_method 26 | super().__init__(optimizer, last_epoch) 27 | 28 | def get_lr(self) -> List[float]: 29 | warmup_factor = _get_warmup_factor_at_iter( 30 | self.warmup_method, self.last_epoch, self.warmup_epochs, self.warmup_factor 31 | ) 32 | return [ 33 | base_lr * warmup_factor * self.gamma ** bisect_right(self.milestones, self.last_epoch) 34 | for base_lr in self.base_lrs 35 | ] 36 | 37 | def _compute_values(self) -> List[float]: 38 | # The new interface 39 | return self.get_lr() 40 | 41 | 42 | def _get_warmup_factor_at_iter( 43 | method: str, epoch: int, warmup_epochs: int, warmup_factor: float 44 | ) -> float: 45 | """ 46 | Return the learning rate warmup factor at a specific iteration. 47 | 48 | Args: 49 | method (str): warmup method; either "constant" or "linear". 50 | epoch (int): epoch at which to calculate the warmup factor. 51 | warmup_epochs (int): the number of warmup iterations. 52 | warmup_factor (float): the base warmup factor (the meaning changes according 53 | to the method used). 54 | Returns: 55 | float: the effective warmup factor at the given iteration. 56 | """ 57 | if epoch >= warmup_epochs: 58 | return 1.0 59 | 60 | if method == "constant": 61 | return warmup_factor 62 | elif method == "linear": 63 | alpha = epoch / warmup_epochs 64 | return warmup_factor * (1 - alpha) + alpha 65 | else: 66 | raise ValueError("Unknown warmup method: {}".format(method)) 67 | -------------------------------------------------------------------------------- /util/plot_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Plotting utilities to visualize training logs. 3 | """ 4 | import torch 5 | import pandas as pd 6 | import seaborn as sns 7 | import matplotlib.pyplot as plt 8 | 9 | from pathlib import Path, PurePath 10 | 11 | 12 | def plot_logs(logs, fields=('class_error', 'loss_bbox_unscaled', 'mAP'), ewm_col=0, log_name='log.txt'): 13 | ''' 14 | Function to plot specific fields from training log(s). Plots both training and test results. 15 | 16 | :: Inputs - logs = list containing Path objects, each pointing to individual dir with a log file 17 | - fields = which results to plot from each log file - plots both training and test for each field. 18 | - ewm_col = optional, which column to use as the exponential weighted smoothing of the plots 19 | - log_name = optional, name of log file if different than default 'log.txt'. 20 | 21 | :: Outputs - matplotlib plots of results in fields, color coded for each log file. 22 | - solid lines are training results, dashed lines are test results. 
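:: Example - a minimal call sketch; the directory names are placeholders (not paths from this repo) and the fields must match the train_/test_ columns present in the log file:

    >>> from pathlib import Path
    >>> plot_logs([Path('exps/run_a'), Path('exps/run_b')], fields=('loss', 'mAP'), ewm_col=5)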
23 | 24 | ''' 25 | func_name = "plot_utils.py::plot_logs" 26 | 27 | # verify logs is a list of Paths (list[Paths]) or single Pathlib object Path, 28 | # convert single Path to list to avoid 'not iterable' error 29 | 30 | if not isinstance(logs, list): 31 | if isinstance(logs, PurePath): 32 | logs = [logs] 33 | print(f"{func_name} info: logs param expects a list argument, converted to list[Path].") 34 | else: 35 | raise ValueError(f"{func_name} - invalid argument for logs parameter.\n \ 36 | Expect list[Path] or single Path obj, received {type(logs)}") 37 | 38 | # verify valid dir(s) and that every item in list is Path object 39 | for i, dir in enumerate(logs): 40 | if not isinstance(dir, PurePath): 41 | raise ValueError(f"{func_name} - non-Path object in logs argument of {type(dir)}: \n{dir}") 42 | if dir.exists(): 43 | continue 44 | raise ValueError(f"{func_name} - invalid directory in logs argument:\n{dir}") 45 | 46 | # load log file(s) and plot 47 | dfs = [pd.read_json(Path(p) / log_name, lines=True) for p in logs] 48 | 49 | fig, axs = plt.subplots(ncols=len(fields), figsize=(16, 5)) 50 | 51 | for df, color in zip(dfs, sns.color_palette(n_colors=len(logs))): 52 | for j, field in enumerate(fields): 53 | if field == 'mAP': 54 | coco_eval = pd.DataFrame(pd.np.stack(df.test_coco_eval.dropna().values)[:, 1]).ewm(com=ewm_col).mean() 55 | axs[j].plot(coco_eval, c=color) 56 | else: 57 | df.interpolate().ewm(com=ewm_col).mean().plot( 58 | y=[f'train_{field}', f'test_{field}'], 59 | ax=axs[j], 60 | color=[color] * 2, 61 | style=['-', '--'] 62 | ) 63 | for ax, field in zip(axs, fields): 64 | ax.legend([Path(p).name for p in logs]) 65 | ax.set_title(field) 66 | 67 | 68 | def plot_precision_recall(files, naming_scheme='iter'): 69 | if naming_scheme == 'exp_id': 70 | # name becomes exp_id 71 | names = [f.parts[-3] for f in files] 72 | elif naming_scheme == 'iter': 73 | names = [f.stem for f in files] 74 | else: 75 | raise ValueError(f'not supported {naming_scheme}') 76 | fig, axs = plt.subplots(ncols=2, figsize=(16, 5)) 77 | for f, color, name in zip(files, sns.color_palette("Blues", n_colors=len(files)), names): 78 | data = torch.load(f) 79 | # precision is n_iou, n_points, n_cat, n_area, max_det 80 | precision = data['precision'] 81 | recall = data['params'].recThrs 82 | scores = data['scores'] 83 | # take precision for all classes, all areas and 100 detections 84 | precision = precision[0, :, :, 0, -1].mean(1) 85 | scores = scores[0, :, :, 0, -1].mean(1) 86 | prec = precision.mean() 87 | rec = data['recall'][0, :, 0, -1].mean() 88 | print(f'{naming_scheme} {name}: mAP@50={prec * 100: 05.1f}, ' + 89 | f'score={scores.mean():0.3f}, ' + 90 | f'f1={2 * prec * rec / (prec + rec + 1e-8):0.3f}' 91 | ) 92 | axs[0].plot(recall, precision, c=color) 93 | axs[1].plot(recall, scores, c=color) 94 | 95 | axs[0].set_title('Precision / Recall') 96 | axs[0].legend(names) 97 | axs[1].set_title('Scores / Recall') 98 | axs[1].legend(names) 99 | return fig, axs 100 | 101 | 102 | 103 | --------------------------------------------------------------------------------